From 5adb7bf1a4969be003fe1b9cd57d70e0e062609f Mon Sep 17 00:00:00 2001
From: satish kumar nuggu
Date: Fri, 13 Aug 2021 23:15:34 +0530
Subject: [PATCH 001/243] Combined variants to reduce redundancy in dtrsm small

1. Left Lower non-trans, Left Upper trans
2. Left Upper non-trans, Left Lower trans
3. Right Lower non-trans, Right Upper trans
4. Right Upper non-trans, Right Lower trans

Change-Id: I0b0155d7c3a55ec74d53c8f1f49f1bceb63b15f5
---
 kernels/zen/3/bli_trsm_small.c | 26856 +++++++++----------------------
 1 file changed, 7228 insertions(+), 19628 deletions(-)

diff --git a/kernels/zen/3/bli_trsm_small.c b/kernels/zen/3/bli_trsm_small.c
index b127fa4e71..ea9de2a889 100644
--- a/kernels/zen/3/bli_trsm_small.c
+++ b/kernels/zen/3/bli_trsm_small.c
@@ -37,9 +37,6 @@
 #include "immintrin.h"
 #define BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL
-#define D_MR 8
-#define D_NR 6
-
 /*
 declaration of trsm small kernels function pointer
@@ -55,7 +52,10 @@ typedef err_t (*trsmsmall_ker_ft)
 //AX = B; A is lower triangular; No transpose;
 //double precision; non-unit diagonal
-BLIS_INLINE err_t bli_dtrsm_small_AlXB
+//A.'X = B; A is upper triangular;
+//A has to be transposed; double precision
+
+BLIS_INLINE err_t bli_dtrsm_small_AutXB_AlXB
 (
   obj_t* AlphaObj,
   obj_t* a,
@@ -68,7 +68,37 @@ BLIS_INLINE err_t bli_dtrsm_small_AlXB
  * A is upper-triangular, non-transpose, non-unit diagonal
  * dimensions A: mxm X: mxn B: mxn
 */
-BLIS_INLINE err_t bli_dtrsm_small_AuXB
+//AX = B; A is lower triangular; transpose; double precision
+
+BLIS_INLINE err_t bli_dtrsm_small_AltXB_AuXB
+(
+  obj_t* AlphaObj,
+  obj_t* a,
+  obj_t* b,
+  cntx_t* cntx,
+  cntl_t* cntl
+);
+
+//XA = B; A is upper-triangular; A is transposed;
+//double precision; non-unit diagonal
+// XA = B; A is lower-triangular; No transpose;
+//double precision; non-unit diagonal
+
+BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB
+(
+  obj_t* AlphaObj,
+  obj_t* a,
+  obj_t* b,
+  cntx_t* cntx,
+  cntl_t* cntl
+);
+
+// XA = B; A is upper triangular; No transpose;
+//double precision; non-unit diagonal
+//XA = B; A is lower-triangular; A is transposed;
+// double precision; non-unit diagonal
+
+BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB
 (
   obj_t* AlphaObj,
   obj_t* a,
   obj_t* b,
   cntx_t* cntx,
   cntl_t* cntl
 );
@@ -90,28 +120,7 @@ BLIS_INLINE err_t dtrsm_AltXB_ref
   bool is_unitdiag
 );
-//A.'X = B; A is upper triangular;
-//A has to be transposed; double precision
-BLIS_INLINE err_t bli_dtrsm_small_AutXB
-(
-  obj_t* alpha,
-  obj_t* a,
-  obj_t* b,
-  cntx_t* cntx,
-  cntl_t* cntl
-);
-
-//AX = B; A is lower triangular; transpose; double precision
-BLIS_INLINE err_t bli_dtrsm_small_AltXB
-(
-  obj_t* alpha,
-  obj_t* a,
-  obj_t* b,
-  cntx_t* cntx,
-  cntl_t* cntl
-);
-
-/*
+/*
  * The preinversion of diagonal elements are enabled/disabled
  * based on configuration.
*/ @@ -146,16 +155,16 @@ BLIS_INLINE err_t dtrsm_AutXB_ref dim_t i, j, k; for (k = 0; k < M; k++) { - double lkk_inv = 1.0; - if(!unitDiagonal) lkk_inv = DIAG_ELE_INV_OPS(lkk_inv,A[k+k*lda]); - for (j = 0; j < N; j++) - { - B[k + j*ldb] = DIAG_ELE_EVAL_OPS(B[k + j*ldb] , lkk_inv); - for (i = k+1; i < M; i++) - { - B[i + j*ldb] -= A[i*lda + k] * B[k + j*ldb]; - } - } + double lkk_inv = 1.0; + if(!unitDiagonal) lkk_inv = DIAG_ELE_INV_OPS(lkk_inv,A[k+k*lda]); + for (j = 0; j < N; j++) + { + B[k + j*ldb] = DIAG_ELE_EVAL_OPS(B[k + j*ldb] , lkk_inv); + for (i = k+1; i < M; i++) + { + B[i + j*ldb] -= A[i*lda + k] * B[k + j*ldb]; + } + } }// k -loop return BLIS_SUCCESS; }// end of function @@ -178,16 +187,16 @@ BLIS_INLINE err_t dtrsm_AuXB_ref dim_t i, j, k; for (k = M-1; k >= 0; k--) { - double lkk_inv = 1.0; - if(!is_unitdiag) lkk_inv = DIAG_ELE_INV_OPS(lkk_inv,A[k+k*lda]); - for (j = N -1; j >= 0; j--) - { - B[k + j*ldb] = DIAG_ELE_EVAL_OPS(B[k + j*ldb],lkk_inv); - for (i = k-1; i >=0; i--) - { - B[i + j*ldb] -= A[i + k*lda] * B[k + j*ldb]; - } - } + double lkk_inv = 1.0; + if(!is_unitdiag) lkk_inv = DIAG_ELE_INV_OPS(lkk_inv,A[k+k*lda]); + for (j = N -1; j >= 0; j--) + { + B[k + j*ldb] = DIAG_ELE_EVAL_OPS(B[k + j*ldb],lkk_inv); + for (i = k-1; i >=0; i--) + { + B[i + j*ldb] -= A[i + k*lda] * B[k + j*ldb]; + } + } }// k -loop return BLIS_SUCCESS; }// end of function @@ -256,50 +265,6 @@ BLIS_INLINE err_t dtrsm_AltXB_ref return BLIS_SUCCESS; }// end of function -// XA = B; A is lower-traingular; No transpose; -//double precision; non-unit diagonal -BLIS_INLINE err_t bli_dtrsm_small_XAlB -( - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl -); - -//XA = B; A is lower-triangular; A is transposed; -// double precision; non-unit-diagonal -BLIS_INLINE err_t bli_dtrsm_small_XAltB -( - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl -); - -// XA = B; A is upper triangular; No transpose; -//double presicion; non-unit diagonal -BLIS_INLINE err_t bli_dtrsm_small_XAuB -( - obj_t* alpha, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl -); - -//XA = B; A is upper-triangular; A is transposed; -//double precision; non-unit diagonal -BLIS_INLINE err_t bli_dtrsm_small_XAutB -( - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl -); - /* TRSM scalar code for the case XA = alpha * B * A is upper-triangular, non-unit/unit diagonal no transpose * Dimensions: X:mxn A:nxn B:mxn @@ -341,7 +306,6 @@ BLIS_INLINE err_t dtrsm_XAlB_ref ( double *A, double *B, - double alpha, dim_t M, dim_t N, dim_t lda, @@ -350,13 +314,6 @@ BLIS_INLINE err_t dtrsm_XAlB_ref ) { dim_t i, j, k; - for(j = 0; j < N; j++) - { - for(i = 0; i < M; i++) - { - B[i+j*ldb] *= alpha; - } - } for(k = N;k--;) { @@ -383,7 +340,6 @@ BLIS_INLINE err_t dtrsm_XAutB_ref ( double *A, double *B, - double alpha, dim_t M, dim_t N, dim_t lda, @@ -392,13 +348,6 @@ BLIS_INLINE err_t dtrsm_XAutB_ref ) { dim_t i, j, k; - for(j = 0; j < N; j++) - { - for(i = 0; i < M; i++) - { - B[i+j*ldb] *=alpha; - } - } for(k = N; k--;) { @@ -474,7 +423,7 @@ BLIS_INLINE err_t dtrsm_XAltB_ref ymm15 = _mm256_setzero_pd(); /*GEMM block used in trsm small right cases*/ -#define BLIS_DTRSM_SMALL_GEMM_6x8(a01,b10,cs_b,p_lda,k_iter) \ +#define BLIS_DTRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) \ for(k = 0; k < k_iter; k++) \ {\ /*load 8x1 block of B10*/ \ @@ -512,8 +461,199 @@ BLIS_INLINE err_t dtrsm_XAltB_ref b10 += cs_b; \ } -/*GEMM block used in trsm small left cases*/ -#define 
BLIS_DTRSM_SMALL_GEMM_8x6(a10,b01,cs_b,p_lda,k_iter) \ +#define BLIS_DTRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 8x1 block of B10*/\ + ymm0 = _mm256_loadu_pd((double const *)b10); /*B10[0][0] B10[1][0] B10[2][0] B10[3][0]*/\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); /*A01[0][3]*/\ + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 4)); /*A01[0][4]*/\ + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 5)); /*A01[0][5]*/\ + ymm13 = _mm256_fmadd_pd(ymm2, ymm0, ymm13);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_DTRSM_SMALL_GEMM_4nx8m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 8x1 block of B10*/\ + ymm0 = _mm256_loadu_pd((double const *)b10);\ + ymm1 = _mm256_loadu_pd((double const *)(b10 + 4));\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3);\ + ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5);\ + ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7);\ + ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); /*A01[0][3]*/\ + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9);\ + ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_DTRSM_SMALL_GEMM_3nx8m(a01,b10,cs_b,p_lda,k_iter)\ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 8x1 block of B10*/\ + ymm0 = _mm256_loadu_pd((double const *)b10);\ + ymm1 = _mm256_loadu_pd((double const *)(b10 + 4));\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3);\ + ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5);\ + ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7);\ + ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_DTRSM_SMALL_GEMM_2nx8m(a01,b10,cs_b,p_lda,k_iter)\ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 8x1 block of B10*/\ + ymm0 = _mm256_loadu_pd((double const *)b10);/*B10[0][0] B10[1][0] B10[2][0] B10[3][0]*/\ + ymm1 = _mm256_loadu_pd((double const *)(b10 + 4));/*B10[4][0] B10[5][0] B10[6][0] B10[7][0]*/\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_sd((double const 
*)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3);\ + ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5);\ + ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_DTRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter)\ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 8x1 block of B10*/\ + ymm0 = _mm256_loadu_pd((double const *)b10);/*B10[0][0] B10[1][0] B10[2][0] B10[3][0]*/\ + ymm1 = _mm256_loadu_pd((double const *)(b10 + 4));/*B10[4][0] B10[5][0] B10[6][0] B10[7][0]*/\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3);\ + ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_DTRSM_SMALL_GEMM_4nx4m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 8x1 block of B10*/\ + ymm0 = _mm256_loadu_pd((double const *)b10);/*B10[0][0] B10[1][0] B10[2][0] B10[3][0]*/\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); /*A01[0][3]*/\ + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_DTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 8x1 block of B10*/\ + ymm0 = _mm256_loadu_pd((double const *)b10);/*B10[0][0] B10[1][0] B10[2][0] B10[3][0]*/\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_DTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 8x1 block of B10*/\ + ymm0 = _mm256_loadu_pd((double const *)b10);/*B10[0][0] B10[1][0] B10[2][0] B10[3][0]*/\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_DTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 8x1 block of B10*/\ + ymm0 = _mm256_loadu_pd((double const *)b10);/*B10[0][0] B10[1][0] B10[2][0] B10[3][0]*/\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); 
/*A01[0][0]*/\ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +/*GEMM block used in dtrsm small left cases*/ +#define BLIS_DTRSM_SMALL_GEMM_8mx6n(a10,b01,cs_b,p_lda,k_iter) \ double *b01_prefetch = b01 + 8; \ for(k = 0; k< k_iter; k++) \ { \ @@ -554,6 +694,747 @@ BLIS_INLINE err_t dtrsm_XAltB_ref a10 += p_lda; \ } +#define BLIS_DTRSM_SMALL_GEMM_8mx4n(a10,b01,cs_b,p_lda,k_iter) \ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_pd((double const *)(a10));\ + ymm1 = _mm256_loadu_pd((double const *)(a10 + 4));\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(b01));\ + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8);\ + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1));\ + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9);\ + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2));\ + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10);\ + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3));\ + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11);\ + ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15);\ +\ + b01 += 1; /*move to next row of B*/\ + a10 += p_lda; /*pointer math to calculate next block of A for GEMM*/\ + } + +#define BLIS_DTRSM_SMALL_GEMM_8mx3n(a10,b01,cs_b,p_lda,k_iter) \ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_pd((double const *)(a10));\ + ymm1 = _mm256_loadu_pd((double const *)(a10 + 4));\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0));\ + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8);\ + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1));\ + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9);\ + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2));\ + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10);\ + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14);\ +\ + b01 += 1; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + } + +#define BLIS_DTRSM_SMALL_GEMM_8mx2n(a10,b01,cs_b,p_lda,k_iter) \ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_pd((double const *)(a10));\ + ymm1 = _mm256_loadu_pd((double const *)(a10 + 4));\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0));\ + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8);\ + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1));\ + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9);\ + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13);\ +\ + b01 += 1; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + } + +#define BLIS_DTRSM_SMALL_GEMM_8mx1n(a10,b01,cs_b,p_lda,k_iter) \ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_pd((double const *)(a10));\ + ymm1 = _mm256_loadu_pd((double const *)(a10 + 4));\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0));\ + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8);\ + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12);\ + b01 += 1; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + } + +#define BLIS_DTRSM_SMALL_GEMM_4mx6n(a10,b01,cs_b,p_lda,k_iter) \ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + 
{\ + ymm0 = _mm256_loadu_pd((double const *)(a10));\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0));\ + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1));\ + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2));\ + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3));\ + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4));\ + ymm4 = _mm256_fmadd_pd(ymm2, ymm0, ymm4);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5));\ + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5);\ +\ + b01 += 1; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + } + +#define BLIS_DTRSM_SMALL_GEMM_4mx4n(a10,b01,cs_b,p_lda,k_iter) \ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_pd((double const *)(a10));\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0));\ + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1));\ + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2));\ + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3));\ + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11);\ +\ + b01 += 1; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + } + +#define BLIS_DTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) \ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_pd((double const *)(a10));\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0));\ + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1));\ + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2));\ + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10);\ +\ + b01 += 1; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + } + +#define BLIS_DTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b,p_lda,k_iter) \ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_pd((double const *)(a10));\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0));\ + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1));\ + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9);\ +\ + b01 += 1; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + } + +#define BLIS_DTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b,p_lda,k_iter) \ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_pd((double const *)(a10));\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0));\ + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8);\ +\ + b01 += 1; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + } + +/* + Load b11 of size 6x8 and multiply with alpha + Add the GEMM output and perform inregister transose of b11 + to peform DTRSM operation for left cases. 
+*/ +#define BLIS_DTRSM_SMALL_NREG_TRANSPOSE_6x8(b11,cs_b,AlphaVal) \ + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal));\ +\ + ymm0 = _mm256_loadu_pd((double const *)(b11));\ + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1));\ + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2));\ + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3));\ + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8);\ + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9);\ + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10);\ + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11);\ +\ + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); \ + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); \ + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); \ + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31);\ + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); \ + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); \ + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); \ + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); \ +\ + ymm0 = _mm256_loadu_pd((double const *)(b11 + 4));\ + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4));\ + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 4));\ + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3 + 4));\ + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm12);\ + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm13);\ + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm14);\ + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm15);\ +\ + ymm13 = _mm256_unpacklo_pd(ymm0, ymm1);\ + ymm15 = _mm256_unpacklo_pd(ymm2, ymm3);\ + ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20);\ + ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31);\ + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1);\ + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3);\ +\ + ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x20);\ + ymm15 = _mm256_permute2f128_pd(ymm0,ymm1,0x31);\ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4));\ + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5));\ + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4);\ + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5);\ + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *4 + 4));\ + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *5 + 4));\ + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm6);\ + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm7);\ +\ + ymm16 = _mm256_broadcast_sd((double const *)(&ones));\ + ymm7 = _mm256_unpacklo_pd(ymm0, ymm1);\ + ymm4 = _mm256_permute2f128_pd(ymm7,ymm16,0x20);\ + ymm6 = _mm256_permute2f128_pd(ymm7,ymm16,0x31);\ +\ + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1);\ + ymm5 = _mm256_permute2f128_pd(ymm0,ymm16,0x20);\ + ymm7 = _mm256_permute2f128_pd(ymm0,ymm16,0x31);\ + ymm18 = _mm256_unpacklo_pd(ymm2, ymm3);\ + ymm17 = _mm256_permute2f128_pd(ymm18,ymm16,0x20);\ + ymm19 = _mm256_permute2f128_pd(ymm18,ymm16,0x31);\ +\ + /*unpackhigh*/\ + ymm20 = _mm256_unpackhi_pd(ymm2, ymm3);\ +\ + /*rearrange high elements*/\ + ymm18 = _mm256_permute2f128_pd(ymm20,ymm16,0x20);\ + ymm20 = _mm256_permute2f128_pd(ymm20,ymm16,0x31); + +#define BLIS_DTRSM_SMALL_NREG_TRANSPOSE_8x6_AND_STORE(b11,cs_b)\ + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9);\ + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11);\ +\ + /*rearrange low elements*/\ + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20);\ + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31);\ +\ + /*unpack high*/\ + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9);\ + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11);\ +\ + /*rearrange high elements*/\ + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20);\ + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31);\ +\ + _mm256_storeu_pd((double *)(b11), ymm0);\ + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1);\ + _mm256_storeu_pd((double 
*)(b11 + cs_b * 2), ymm2);\ + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3);\ +\ + /*unpacklow*/\ + ymm1 = _mm256_unpacklo_pd(ymm12, ymm13);\ + ymm3 = _mm256_unpacklo_pd(ymm14, ymm15);\ +\ + /*rearrange low elements*/\ + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20);\ + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31);\ +\ + /*unpack high*/\ + ymm12 = _mm256_unpackhi_pd(ymm12, ymm13);\ + ymm13 = _mm256_unpackhi_pd(ymm14, ymm15);\ +\ + /*rearrange high elements*/\ + ymm1 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20);\ + ymm3 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31);\ +\ + _mm256_storeu_pd((double *)(b11 + 4), ymm0);\ + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm1);\ + _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm2);\ + _mm256_storeu_pd((double *)(b11 + cs_b * 3 + 4), ymm3);\ +\ + /*unpacklow*/\ + ymm1 = _mm256_unpacklo_pd(ymm4, ymm5);\ + ymm3 = _mm256_unpacklo_pd(ymm6, ymm7);\ +\ + /*rearrange low elements*/\ + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20);\ +\ + /*unpack high*/\ + ymm4 = _mm256_unpackhi_pd(ymm4, ymm5);\ + ymm5 = _mm256_unpackhi_pd(ymm6, ymm7);\ +\ + /*rearrange high elements*/\ + ymm1 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20);\ +\ + _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0);\ + _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1);\ +\ + /*unpacklow*/\ + ymm1 = _mm256_unpacklo_pd(ymm17, ymm18);\ + ymm3 = _mm256_unpacklo_pd(ymm19, ymm20);\ +\ + /*rearrange low elements*/\ + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20);\ +\ + /*unpack high*/\ + ymm17 = _mm256_unpackhi_pd(ymm17, ymm18);\ + ymm18 = _mm256_unpackhi_pd(ymm19, ymm20);\ +\ + /*rearrange high elements*/\ + ymm1 = _mm256_permute2f128_pd(ymm17, ymm18, 0x20);\ +\ + _mm256_storeu_pd((double *)(b11 + cs_b * 4 + 4), ymm0);\ + _mm256_storeu_pd((double *)(b11 + cs_b * 5 + 4), ymm1); + +#define BLIS_PRE_DTRSM_SMALL_3M_3N(AlphaVal,b11,cs_b)\ + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); /*register to hold alpha*/\ +\ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));\ + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1));\ + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2));\ + ymm2 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 2 + 2));\ + ymm2 = _mm256_insertf128_pd(ymm2, xmm5, 0);\ +\ + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8);\ + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9);\ + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10);\ +\ + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08);\ + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08);\ + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08);\ +\ + _mm256_storeu_pd((double *)(b11), ymm0); /*store(B11[0-3][0])*/\ + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); /*store(B11[0-3][1])*/\ + xmm5 = _mm256_extractf128_pd(ymm2, 0);\ + _mm_storeu_pd((double *)(b11 + cs_b * 2), xmm5);\ + _mm_storel_pd((b11 + cs_b * 2 + 2), _mm256_extractf128_pd(ymm2, 1)); + +#define BLIS_PRE_DTRSM_SMALL_3M_2N(AlphaVal,b11,cs_b)\ + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); /*register to hold alpha*/\ +\ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));\ + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1));\ + ymm1 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 1 + 2));\ + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0);\ +\ + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8);\ + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9);\ +\ + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08);\ + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08);\ +\ + _mm256_storeu_pd((double *)(b11), ymm0); /*store(B11[0-3][0])*/\ + xmm5 = _mm256_extractf128_pd(ymm1, 0);\ + 
_mm_storeu_pd((double *)(b11 + cs_b * 1), xmm5);\ + _mm_storel_pd((b11 + cs_b * 1 + 2), _mm256_extractf128_pd(ymm1, 1)); + +#define BLIS_PRE_DTRSM_SMALL_3M_1N(AlphaVal,b11,cs_b)\ + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); /*register to hold alpha*/\ +\ + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 0));\ + ymm0 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 0 + 2));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8);\ + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08);\ +\ + xmm5 = _mm256_extractf128_pd(ymm0, 0);\ + _mm_storeu_pd((double *)(b11), xmm5);\ + _mm_storel_pd((b11 + 2), _mm256_extractf128_pd(ymm0, 1)); + + +#define BLIS_PRE_DTRSM_SMALL_2M_3N(AlphaVal,b11,cs_b)\ + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); /*register to hold alpha*/\ +\ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));\ + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1));\ + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2));\ + ymm2 = _mm256_insertf128_pd(ymm2, xmm5, 0);\ +\ + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8);\ + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9);\ + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10);\ +\ + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C);\ + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C);\ + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0C);\ +\ + _mm256_storeu_pd((double *)(b11), ymm0); /*store(B11[0-3][0])*/\ + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); /*store(B11[0-3][1])*/\ + xmm5 = _mm256_extractf128_pd(ymm2, 0);\ + _mm_storeu_pd((double *)(b11 + cs_b * 2), xmm5); + +#define BLIS_PRE_DTRSM_SMALL_2M_2N(AlphaVal,b11,cs_b)\ + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); /*register to hold alpha*/\ +\ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));\ + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1));\ + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0);\ +\ + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8);\ + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9);\ +\ + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C);\ + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C);\ +\ + _mm256_storeu_pd((double *)(b11), ymm0); /*store(B11[0-3][0])*/\ + xmm5 = _mm256_extractf128_pd(ymm1, 0);\ + _mm_storeu_pd((double *)(b11 + cs_b * 1), xmm5); + +#define BLIS_PRE_DTRSM_SMALL_2M_1N(AlphaVal,b11,cs_b)\ + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); /*register to hold alpha*/\ +\ + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 0));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ +\ + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8);\ + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C);\ +\ + xmm5 = _mm256_extractf128_pd(ymm0, 0);\ + _mm_storeu_pd((double *)(b11 + cs_b * 0), xmm5); + +#define BLIS_PRE_DTRSM_SMALL_1M_3N(AlphaVal,b11,cs_b)\ + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); /*register to hold alpha*/\ +\ + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b *0));\ + ymm1 = _mm256_broadcast_sd((double const *)(b11 + cs_b *1));\ + ymm2 = _mm256_broadcast_sd((double const *)(b11 + cs_b *2));\ +\ + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8);\ + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9);\ + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10);\ +\ + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E);\ + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E);\ + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E);\ +\ + _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm0, 0));\ + _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm1, 0));\ + _mm_storel_pd((b11 + cs_b * 2), _mm256_extractf128_pd(ymm2, 0)); + +#define 
BLIS_PRE_DTRSM_SMALL_1M_2N(AlphaVal,b11,cs_b)\ + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); /*register to hold alpha*/\ +\ + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b *0));\ + ymm1 = _mm256_broadcast_sd((double const *)(b11 + cs_b *1));\ +\ + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8);\ + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9);\ +\ + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E);\ + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E);\ +\ + _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm0, 0));\ + _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm1, 0)); + +#define BLIS_PRE_DTRSM_SMALL_1M_1N(AlphaVal,b11,cs_b)\ + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); /*register to hold alpha*/\ +\ + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b *0));\ + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8);\ +\ + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E);\ +\ + _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm0, 0)); + +/* pre & post TRSM for Right remainder cases*/ +#define BLIS_PRE_DTRSM_SMALL_3N_3M(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); /*register to hold alpha*/\ +\ + ymm0 = _mm256_loadu_pd((double const *)b11);\ + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3);\ +\ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b));\ + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5);\ +\ + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2));\ + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2 + 2));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ + ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); + +#define BLIS_POST_DTRSM_SMALL_3N_3M(b11,cs_b)\ + ymm0 = _mm256_loadu_pd((double const *)b11);\ + ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x07);\ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b));\ + ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x07);\ + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2));\ + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2 + 2));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ + ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x07);\ +\ + _mm256_storeu_pd((double *)b11, ymm3);\ + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5);\ + xmm5 = _mm256_extractf128_pd(ymm7, 0);\ + _mm_storeu_pd((double *)(b11 + cs_b * 2),xmm5);\ + _mm_storel_pd((b11 + cs_b * 2 + 2), _mm256_extractf128_pd(ymm7, 1)); + +#define BLIS_PRE_DTRSM_SMALL_3N_2M(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); /*register to hold alpha*/\ +\ + ymm0 = _mm256_loadu_pd((double const *)b11);\ + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3);\ +\ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b));\ + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5);\ +\ + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ + ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); + +#define BLIS_POST_DTRSM_SMALL_3N_2M(b11,cs_b)\ + ymm0 = _mm256_loadu_pd((double const *)b11);\ + ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x03);\ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b));\ + ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x03);\ + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ + ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x03);\ +\ + _mm256_storeu_pd((double *)b11, ymm3);\ + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5);\ + xmm5 = _mm256_extractf128_pd(ymm7, 0);\ + _mm_storeu_pd((double *)(b11 + cs_b * 2),xmm5); + +#define BLIS_PRE_DTRSM_SMALL_3N_1M(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); /*register to hold alpha*/\ +\ + ymm0 = 
_mm256_broadcast_sd((double const *)b11);\ + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3);\ +\ + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b));\ + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5);\ +\ + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2));\ + ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); + +#define BLIS_POST_DTRSM_SMALL_3N_1M(b11,cs_b)\ + ymm0 = _mm256_broadcast_sd((double const *)b11);\ + ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x01);\ + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b));\ + ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x01);\ + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2));\ + ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x01);\ +\ + _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm3, 0));\ + _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm5, 0));\ + _mm_storel_pd((b11 + cs_b * 2), _mm256_extractf128_pd(ymm7, 0)); + +#define BLIS_PRE_DTRSM_SMALL_2N_3M(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); /*register to hold alpha*/\ +\ + ymm0 = _mm256_loadu_pd((double const *)b11);\ + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3);\ +\ + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1));\ + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*1 + 2));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); + +#define BLIS_POST_DTRSM_SMALL_2N_3M(b11,cs_b)\ + ymm0 = _mm256_loadu_pd((double const *)b11);\ + ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x07);\ + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1));\ + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*1 + 2));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ + ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x07);\ +\ + _mm256_storeu_pd((double *)b11, ymm3);\ + xmm5 = _mm256_extractf128_pd(ymm5, 0);\ + _mm_storeu_pd((double *)(b11 + cs_b*1), xmm5);\ + _mm_storel_pd((b11 + cs_b * 1 + 2), _mm256_extractf128_pd(ymm5, 1)); + +#define BLIS_PRE_DTRSM_SMALL_2N_2M(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); /*register to hold alpha*/\ +\ + ymm0 = _mm256_loadu_pd((double const *)b11);\ + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3);\ +\ + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); + +#define BLIS_POST_DTRSM_SMALL_2N_2M(b11,cs_b)\ + ymm0 = _mm256_loadu_pd((double const *)b11);\ + ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x03);\ + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ + ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x03);\ +\ + _mm256_storeu_pd((double *)b11, ymm3);\ + xmm5 = _mm256_extractf128_pd(ymm5, 0);\ + _mm_storeu_pd((double *)(b11 + cs_b*1), xmm5); + +#define BLIS_PRE_DTRSM_SMALL_2N_1M(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); /*register to hold alpha*/\ +\ + ymm0 = _mm256_broadcast_sd((double const *)b11);\ + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3);\ +\ + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b));\ + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); + +#define BLIS_POST_DTRSM_SMALL_2N_1M(b11,cs_b)\ + ymm0 = _mm256_broadcast_sd((double const *)b11);\ + ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x01);\ + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b));\ + ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x01);\ +\ + _mm_storel_pd(b11 , _mm256_extractf128_pd(ymm3, 0));\ + _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm5, 0)); + +#define BLIS_PRE_DTRSM_SMALL_1N_3M(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_sd((double 
const *)&AlphaVal); /*register to hold alpha*/\ +\ + xmm5 = _mm_loadu_pd((double const*)(b11));\ + ymm0 = _mm256_broadcast_sd((double const *)(b11+ 2));\ + ymm6 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ + ymm3 = _mm256_fmsub_pd(ymm6, ymm15, ymm3); + +#define BLIS_POST_DTRSM_SMALL_1N_3M(b11,cs_b)\ + xmm5 = _mm256_extractf128_pd(ymm3, 0);\ + _mm_storeu_pd((double *)(b11), xmm5);\ + _mm_storel_pd((b11 + 2), _mm256_extractf128_pd(ymm3, 1)); + +#define BLIS_PRE_DTRSM_SMALL_1N_2M(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); /*register to hold alpha*/\ +\ + xmm5 = _mm_loadu_pd((double const*)(b11));\ + ymm6 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ + ymm3 = _mm256_fmsub_pd(ymm6, ymm15, ymm3); + +#define BLIS_POST_DTRSM_SMALL_1N_2M(b11,cs_b)\ + ymm0 = _mm256_loadu_pd((double const *)b11);\ + ymm3 = _mm256_blend_pd(ymm6, ymm3, 0x03);\ +\ + xmm5 = _mm256_extractf128_pd(ymm3, 0);\ + _mm_storeu_pd((double *)(b11), xmm5); + +#define BLIS_PRE_DTRSM_SMALL_1N_1M(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); /*register to hold alpha*/\ +\ + ymm6 = _mm256_broadcast_sd((double const *)b11);\ + ymm3 = _mm256_fmsub_pd(ymm6, ymm15, ymm3); + +#define BLIS_POST_DTRSM_SMALL_1N_1M(b11,cs_b)\ + ymm3 = _mm256_blend_pd(ymm6, ymm3, 0x01);\ +\ + _mm_storel_pd(b11, _mm256_extractf128_pd(ymm3, 0)); + +/* multiply with Alpha pre TRSM for 6*8 kernel*/ +#define BLIS_PRE_DTRSM_SMALL_6x8(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal);\ +\ + ymm0 = _mm256_loadu_pd((double const *)b11);\ + ymm1 = _mm256_loadu_pd((double const *)(b11 + 4));\ +\ + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3);\ + ymm4 = _mm256_fmsub_pd(ymm1, ymm15, ymm4);\ +\ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b));\ + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b + 4));\ +\ + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5);\ + ymm6 = _mm256_fmsub_pd(ymm1, ymm15, ymm6);\ +\ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2));\ + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*2 + 4));\ +\ + ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7);\ + ymm8 = _mm256_fmsub_pd(ymm1, ymm15, ymm8);\ +\ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3));\ + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*3 + 4));\ +\ + ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9);\ + ymm10 = _mm256_fmsub_pd(ymm1, ymm15, ymm10);\ +\ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4));\ + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*4 + 4));\ +\ + ymm11 = _mm256_fmsub_pd(ymm0, ymm15, ymm11);\ + ymm12 = _mm256_fmsub_pd(ymm1, ymm15, ymm12);\ +\ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5));\ + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*5 + 4));\ +\ + ymm13 = _mm256_fmsub_pd(ymm0, ymm15, ymm13);\ + ymm14 = _mm256_fmsub_pd(ymm1, ymm15, ymm14); + +#define BLIS_PRE_DTRSM_SMALL_4x8(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal));\ +\ + ymm0 = _mm256_loadu_pd((double const *)b11);\ + ymm1 = _mm256_loadu_pd((double const *)(b11 + 4));\ +\ + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3);\ + ymm4 = _mm256_fmsub_pd(ymm1, ymm15, ymm4);\ +\ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b));\ + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b + 4));\ +\ + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5);\ + ymm6 = _mm256_fmsub_pd(ymm1, ymm15, ymm6);\ +\ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2));\ + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*2 + 4));\ +\ + ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7);\ + ymm8 = 
_mm256_fmsub_pd(ymm1, ymm15, ymm8);\ +\ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3));\ + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*3 + 4));\ +\ + ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9);\ + ymm10 = _mm256_fmsub_pd(ymm1, ymm15, ymm10); + +#define BLIS_PRE_DTRSM_SMALL_6x4(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); /*register to hold alpha*/\ +\ + ymm0 = _mm256_loadu_pd((double const *)b11);\ + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3);\ +\ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b));\ + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5);\ +\ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2));\ + ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7);\ +\ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3));\ + ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9);\ +\ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4));\ + ymm11 = _mm256_fmsub_pd(ymm0, ymm15, ymm11);\ +\ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5));\ + ymm13 = _mm256_fmsub_pd(ymm0, ymm15, ymm13); + /* Pack a block of 8xk or 6xk from input buffer into packed buffer directly or after transpose based on input params @@ -566,7 +1447,8 @@ BLIS_INLINE void bli_dtrsm_small_pack double *inbuf, dim_t cs_a, double *pbuff, - dim_t p_lda + dim_t p_lda, + dim_t mr ) { //scratch registers @@ -579,9 +1461,9 @@ BLIS_INLINE void bli_dtrsm_small_pack if(side=='L'||side=='l') { - /*Left case is 8xk*/ - if(trans) - { + /*Left case is 8xk*/ + if(trans) + { /* ------------- ------------- | | | | | @@ -591,128 +1473,21 @@ BLIS_INLINE void bli_dtrsm_small_pack | | | | | ------------- ------------- */ - for(dim_t x = 0; x < size; x += D_MR) - { - ymm0 = _mm256_loadu_pd((double const *)(inbuf)); - ymm10 = _mm256_loadu_pd((double const *)(inbuf + 4)); - ymm1 = _mm256_loadu_pd((double const *)(inbuf + cs_a)); - ymm11 = _mm256_loadu_pd((double const *)(inbuf + 4 + cs_a)); - ymm2 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 2)); - ymm12 = _mm256_loadu_pd((double const *)(inbuf + 4 + cs_a * 2)); - ymm3 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 3)); - ymm13 = _mm256_loadu_pd((double const *)(inbuf + 4 + cs_a * 3)); - - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - - _mm256_storeu_pd((double *)(pbuff), ymm6); - _mm256_storeu_pd((double *)(pbuff + p_lda), ymm7); - _mm256_storeu_pd((double *)(pbuff + p_lda*2), ymm8); - _mm256_storeu_pd((double *)(pbuff + p_lda*3), ymm9); - - ymm4 = _mm256_unpacklo_pd(ymm10, ymm11); - ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); - - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - - ymm0 = _mm256_unpackhi_pd(ymm10, ymm11); - ymm1 = _mm256_unpackhi_pd(ymm12, ymm13); - - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - - _mm256_storeu_pd((double *)(pbuff + p_lda * 4), ymm6); - _mm256_storeu_pd((double *)(pbuff + p_lda * 5), ymm7); - _mm256_storeu_pd((double *)(pbuff + p_lda * 6), ymm8); - _mm256_storeu_pd((double *)(pbuff + p_lda * 7), ymm9); - - ymm0 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 4)); - ymm10 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 4 + 4)); - ymm1 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 5)); - ymm11 = 
_mm256_loadu_pd((double const *)(inbuf + cs_a * 5 + 4)); - ymm2 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 6)); - ymm12 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 6 + 4)); - ymm3 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 7)); - ymm13 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 7 + 4)); - - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - - _mm256_storeu_pd((double *)(pbuff + 4), ymm6); - _mm256_storeu_pd((double *)(pbuff + 4 + p_lda), ymm7); - _mm256_storeu_pd((double *)(pbuff + 4 + p_lda*2), ymm8); - _mm256_storeu_pd((double *)(pbuff + 4 + p_lda*3), ymm9); - - ymm4 = _mm256_unpacklo_pd(ymm10, ymm11); - ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - ymm0 = _mm256_unpackhi_pd(ymm10, ymm11); - ymm1 = _mm256_unpackhi_pd(ymm12, ymm13); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - - _mm256_storeu_pd((double *)(pbuff + 4 + p_lda * 4), ymm6); - _mm256_storeu_pd((double *)(pbuff + 4 + p_lda * 5), ymm7); - _mm256_storeu_pd((double *)(pbuff + 4 + p_lda * 6), ymm8); - _mm256_storeu_pd((double *)(pbuff + 4 + p_lda * 7), ymm9); - - inbuf += D_MR; - pbuff += D_MR*D_MR; - } - }else - { - //Expected multiples of 4 - p_lda = 8; - for(dim_t x = 0; x < size; x++) - { - ymm0 = _mm256_loadu_pd((double const *)(inbuf)); - _mm256_storeu_pd((double *)(pbuff), ymm0); - ymm1 = _mm256_loadu_pd((double const *)(inbuf + 4)); - _mm256_storeu_pd((double *)(pbuff + 4), ymm1); - inbuf+=cs_a; - pbuff+=p_lda; - } - } - }else if(side=='R'||side=='r') - { - - if(trans) - { - /* - ------------------ ---------- - | | | | | | - | 4x4 | 4x4 | | 4x4 |4x2 | - ------------- ==> ------------- - | | | | | | - | 2x4 | 2x4 | | 2x4 |2x2 | - ------------------- ------------- - */ - for(dim_t x=0; x>2); i++) + }else + { + //Expected multiples of 4 + p_lda = 8; + for(dim_t x = 0; x < size; x++) { - ymm0 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 0 )); - _mm256_storeu_pd((double *)(pbuff + p_lda * 0), ymm0); - ymm1 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 1 )); - _mm256_storeu_pd((double *)(pbuff + p_lda * 1), ymm1); - ymm2 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 2)); - _mm256_storeu_pd((double *)(pbuff + p_lda * 2), ymm2); - ymm3 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 3 )); - _mm256_storeu_pd((double *)(pbuff + p_lda * 3), ymm3); - ymm0 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 4 )); - _mm256_storeu_pd((double *)(pbuff + p_lda * 4), ymm0); - ymm1 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 5)); - _mm256_storeu_pd((double *)(pbuff + p_lda * 5), ymm1); - inbuf += 4; - pbuff += 4; + ymm0 = _mm256_loadu_pd((double const *)(inbuf)); + _mm256_storeu_pd((double *)(pbuff), ymm0); + ymm1 = _mm256_loadu_pd((double const *)(inbuf + 4)); + _mm256_storeu_pd((double *)(pbuff + 4), ymm1); + inbuf+=cs_a; + pbuff+=p_lda; } + } + }else if(side=='R'||side=='r') + { - if(size & 0x3) + if(trans) + { + /* + ------------------ ---------- + | | | | | | + | 4x4 | 4x4 | | 4x4 |4x2 | + ------------- ==> ------------- + | | | | | | + | 2x4 | 2x4 | | 2x4 |2x2 | + ------------------- ------------- + */ + for(dim_t x=0; x>2); i++) + { + ymm0 
= _mm256_loadu_pd((double const *)(inbuf + cs_a * 0 )); + _mm256_storeu_pd((double *)(pbuff + p_lda * 0), ymm0); + ymm1 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 1 )); + _mm256_storeu_pd((double *)(pbuff + p_lda * 1), ymm1); + ymm2 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 2)); + _mm256_storeu_pd((double *)(pbuff + p_lda * 2), ymm2); + ymm3 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 3 )); + _mm256_storeu_pd((double *)(pbuff + p_lda * 3), ymm3); + ymm0 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 4 )); + _mm256_storeu_pd((double *)(pbuff + p_lda * 4), ymm0); + ymm1 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 5)); + _mm256_storeu_pd((double *)(pbuff + p_lda * 5), ymm1); + inbuf += 4; + pbuff += 4; + } + + if(size & 0x3) + { + xmm0 = _mm_loadu_pd((double const *)(inbuf + cs_a * 0)); + _mm_storeu_pd((double *)(pbuff + p_lda * 0 ), xmm0); xmm1 = _mm_loadu_pd((double const *)(inbuf + cs_a * 1)); _mm_storeu_pd((double *)(pbuff + p_lda * 1), xmm1); xmm2 = _mm_loadu_pd((double const *)(inbuf + cs_a * 2)); @@ -830,7 +1712,7 @@ BLIS_INLINE void dtrsm_small_pack_diag_element __m256d ymm0, ymm1, ymm2, ymm3; __m256d ymm4, ymm5; double ones = 1.0; - bool is_eight = (size==D_MR) ? 1 : 0; + bool is_eight = (size==8) ? 1 : 0; ymm4 = ymm5 = _mm256_broadcast_sd((double const *)&ones); if(!is_unitdiag) { @@ -879,20 +1761,20 @@ BLIS_INLINE void dtrsm_small_pack_diag_element _mm_storeu_pd((double *)(d11_pack + 4), _mm256_extractf128_pd(ymm5,0)); } } - + /* * Kernels Table */ trsmsmall_ker_ft ker_fps[8] = { - bli_dtrsm_small_AlXB, - bli_dtrsm_small_AltXB, - bli_dtrsm_small_AuXB, - bli_dtrsm_small_AutXB, - bli_dtrsm_small_XAlB, - bli_dtrsm_small_XAltB, - bli_dtrsm_small_XAuB, - bli_dtrsm_small_XAutB + bli_dtrsm_small_AutXB_AlXB, + bli_dtrsm_small_AltXB_AuXB, + bli_dtrsm_small_AltXB_AuXB, + bli_dtrsm_small_AutXB_AlXB, + bli_dtrsm_small_XAutB_XAlB, + bli_dtrsm_small_XAltB_XAuB, + bli_dtrsm_small_XAltB_XAuB, + bli_dtrsm_small_XAutB_XAlB }; /* @@ -930,7 +1812,7 @@ err_t bli_trsm_small /* ToDo: Temporary threshold condition for trsm single thread. * It will be updated with arch based threshold function which reads * tunned thresholds for all 64 (datatype,side,uplo,transa,unit,) trsm - combinations. We arrived to this condition based on performance + combinations. 
We arrived to this condition based on performance comparsion with only available native path */ if(m > 1000 || n > 1000) { @@ -985,31 +1867,42 @@ err_t bli_trsm_small return err; }; -/* TRSM for the case AX = alpha * B, Double precision - * A is lower-triangular, no-transpose, non-unit diagonal - * dimensions A: mxm X: mxn B: mxn +/*implements TRSM for the case XA = alpha * B + *A is lower triangular, non-unit diagonal/unit diagonal, transpose + *dimensions: X:mxn A:nxn B: mxn + * + * b11---> a01 ----> + ***************** *********** + *b01*b11* * * * * * * +b11 * * * * * **a01 * * a11 + | ***************** ********* | + | * * * * * *a11* * | + | * * * * * * * * | + v ***************** ****** v + * * * * * * * + * * * * * * * + ***************** * * + * + *implements TRSM for the case XA = alpha * B + *A is upper triangular, non-unit diagonal/unit diagonal, no transpose + *dimensions: X:mxn A:nxn B: mxn + * + * b11---> a01 ----> + ***************** *********** + *b01*b11* * * * * * * +b11 * * * * * **a01 * * a11 + | ***************** ********* | + | * * * * * *a11* * | + | * * * * * * * * | + v ***************** ****** v + * * * * * * * + * * * * * * * + ***************** * * + * - b01---> - * ***************** - ** * * * * * - * * * * * * * - * * *b01* * * * - * * * * * * * -a10 ****** b11 ***************** - | * * * | * * * * * - | * * * | * * * * * - | *a10*a11* | *b11* * * * - v * * * v * * * * * - *********** ***************** - * * * * * * * * * - * * * * * * * * * - * * * * * * * * * - * * * * * * * * * - **************** ***************** - a11---> */ -BLIS_INLINE err_t bli_dtrsm_small_AlXB +BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB ( obj_t* AlphaObj, obj_t* a, @@ -1018,5280 +1911,2458 @@ BLIS_INLINE err_t bli_dtrsm_small_AlXB cntl_t* cntl ) { - dim_t m = bli_obj_length(b); // number of rows of matrix B - dim_t n = bli_obj_width(b); // number of columns of matrix B - - dim_t cs_a = bli_obj_col_stride(a); // column stride of A - dim_t cs_b = bli_obj_col_stride(b); // column stride of B - - dim_t i, j, k; //loop variables - dim_t k_iter; //number of times GEMM to be performed - - double AlphaVal = *(double *)AlphaObj->buffer; //value of alpha - double *L = a->buffer; //pointer to matrix A - double *B = b->buffer; //pointer to matrix B - - //pointers that point to blocks for GEMM and TRSM - double *a10, *a11, *b01, *b11; - - double ones = 1.0; - bool is_unitdiag = bli_obj_has_unit_diag(a); - - //scratch registers - __m256d ymm0, ymm1, ymm2, ymm3; - __m256d ymm4, ymm5, ymm6, ymm7; - __m256d ymm8, ymm9, ymm10, ymm11; - __m256d ymm12, ymm13, ymm14, ymm15; - __m256d ymm16, ymm17, ymm18, ymm19; - __m256d ymm20; - - __m128d xmm5; - - gint_t required_packing_A = 1; - mem_t local_mem_buf_A_s = {0}; - double *D_A_pack = NULL; - double d11_pack[D_MR] __attribute__((aligned(64))); - rntm_t rntm; - - bli_rntm_init_from_global( &rntm ); - bli_rntm_set_num_threads_only( 1, &rntm ); - bli_membrk_rntm_set_membrk( &rntm ); - - siz_t buffer_size = bli_pool_block_size( - bli_membrk_pool( - bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), - bli_rntm_membrk(&rntm))); - - if( (D_MR * m * sizeof(double)) > buffer_size) - return BLIS_NOT_YET_IMPLEMENTED; - - if (required_packing_A == 1) - { - // Get the buffer from the pool. 
- bli_membrk_acquire_m(&rntm, - buffer_size, - BLIS_BITVAL_BUFFER_FOR_A_BLOCK, - &local_mem_buf_A_s); - if(FALSE==bli_mem_is_alloc(&local_mem_buf_A_s)) return BLIS_NULL_POINTER; - D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); - if(NULL==D_A_pack) return BLIS_NULL_POINTER; - } - - /* - Performs solving TRSM for 8 colmns at a time from 0 to m/8 in steps of D_MR - a. Load and pack A (a10 block), the size of packing 8x6 to 8x (m-8) - First there will be no GEMM and no packing of a10 because it is only TRSM - b. Using packed a10 block and b01 block perform GEMM operation - c. Use GEMM outputs, perform TRSM operaton using a11, b11 and update B - d. Repeat b,c for n rows of B in steps of D_NR - */ - for(i = 0;(i+D_MR-1) < m; i += D_MR) //loop along 'M' dimension - { - a10 = L + (i); //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - - dim_t p_lda = D_MR; // packed leading dimension - - /* - Pack current A block (a10) into packed buffer memory D_A_pack - a. This a10 block is used in GEMM portion only and this - a10 block size will be increasing by D_MR for every next itteration - untill it reaches 8x(m-8) which is the maximum GEMM alone block size in A - b. This packed buffer is reused to calculate all n rows of B matrix - */ - bli_dtrsm_small_pack('L', i, 0, a10, cs_a, D_A_pack, p_lda); + dim_t m = bli_obj_length(b); //number of rows + dim_t n = bli_obj_width(b); //number of columns + dim_t d_mr = 8,d_nr = 6; - /* - Pack 8 diagonal elements of A block into an array - a. This helps in utilze cache line efficiently in TRSM operation - b. store ones when input is unit diagonal - */ - dtrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,D_MR); + bool transa = bli_obj_has_trans(a); + dim_t cs_a, rs_a; - /* - a. Perform GEMM using a10, b01. - b. Perform TRSM on a11, b11 - c. This loop GEMM+TRSM loops operates with 8x6 block size - along n dimension for every D_NR columns of B01 where - packed A buffer is reused in computing all n rows of B. - d. Same approch is used in remaining fringe cases. - */ - for(j = 0; (j+D_NR-1) < n; j += D_NR) //loop along 'N' dimension + // Swap rs_a & cs_a in case of non-tranpose. + if(transa) { - a10 = D_A_pack; - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM - - k_iter = i ; //number of times GEMM to be performed - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); - #endif - - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS - - /* - Peform GEMM between a10 and b01 blocks - For first itteration there will be no GEMM operation - where k_iter are zero - */ - BLIS_DTRSM_SMALL_GEMM_8x6(a10,b01,cs_b,p_lda,k_iter) - - /* - Load b11 of size 6x8 and multiply with alpha - Add the GEMM output and perform inregister transose of b11 - to peform TRSM operation. 
- */ - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] * alpha -= B01[0-3][3] - - ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - ymm0 = _mm256_loadu_pd((double const *)(b11 + 4)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 4)); - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3 + 4)); - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm12); - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm13); - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm14); - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm15); - - ymm13 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm15 = _mm256_unpacklo_pd(ymm2, ymm3); - ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); - ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); - ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm15 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *4 + 4)); - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *5 + 4)); - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm6); - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm7); - - ymm16 = _mm256_broadcast_sd((double const *)(&ones)); - ymm7 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm4 = _mm256_permute2f128_pd(ymm7,ymm16,0x20); - ymm6 = _mm256_permute2f128_pd(ymm7,ymm16,0x31); - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm5 = _mm256_permute2f128_pd(ymm0,ymm16,0x20); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm16,0x31); - ymm18 = _mm256_unpacklo_pd(ymm2, ymm3); - ymm17 = _mm256_permute2f128_pd(ymm18,ymm16,0x20); - ymm19 = _mm256_permute2f128_pd(ymm18,ymm16,0x31); - ymm20 = _mm256_unpackhi_pd(ymm2, ymm3); - ymm18 = _mm256_permute2f128_pd(ymm20,ymm16,0x20); - ymm20 = _mm256_permute2f128_pd(ymm20,ymm16,0x31); - //b11 transpose end - - /* - Compute 8x6 TRSM block by using GEMM block output in register - a. The 8x6 input (gemm outputs) are stored in combinations of ymm registers - 1. 
ymm8, ymm4 2. ymm9, ymm5 3. ymm10, ymm6, 4. ymm11, ymm7 - 5. ymm12, ymm17 6. ymm13,ymm18, 7. ymm14,ymm19 8. ymm15, ymm20 - where ymm8-ymm15 holds 8x4 data and reaming 8x2 will be hold by - other registers - b. Towards the end do in regiser transpose of TRSM output and store in b11 - */ - - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); - //perform mul operation - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); - ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm1); - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - //(ROw1): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); - ymm5 = _mm256_fnmadd_pd(ymm2, ymm4, ymm5); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm8, ymm10); - ymm6 = _mm256_fnmadd_pd(ymm2, ymm4, ymm6); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); - ymm7 = _mm256_fnmadd_pd(ymm2, ymm4, ymm7); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4)); - ymm12 = _mm256_fnmadd_pd(ymm2, ymm8, ymm12); - ymm17 = _mm256_fnmadd_pd(ymm2, ymm4, ymm17); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5)); - ymm13 = _mm256_fnmadd_pd(ymm2, ymm8, ymm13); - ymm18 = _mm256_fnmadd_pd(ymm2, ymm4, ymm18); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); - ymm14 = _mm256_fnmadd_pd(ymm2, ymm8, ymm14); - ymm19 = _mm256_fnmadd_pd(ymm2, ymm4, ymm19); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); - ymm15 = _mm256_fnmadd_pd(ymm2, ymm8, ymm15); - ymm20 = _mm256_fnmadd_pd(ymm2, ymm4, ymm20); - - //perform mul operation - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm1); - - a11 += cs_a; - - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(ROw2): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm9, ymm10); - ymm6 = _mm256_fnmadd_pd(ymm2, ymm5, ymm6); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm9, ymm11); - ymm7 = _mm256_fnmadd_pd(ymm2, ymm5, ymm7); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4)); - ymm12 = _mm256_fnmadd_pd(ymm2, ymm9, ymm12); - ymm17 = _mm256_fnmadd_pd(ymm2, ymm5, ymm17); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5)); - ymm13 = _mm256_fnmadd_pd(ymm2, ymm9, ymm13); - ymm18 = _mm256_fnmadd_pd(ymm2, ymm5, ymm18); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); - ymm14 = _mm256_fnmadd_pd(ymm2, ymm9, ymm14); - ymm19 = _mm256_fnmadd_pd(ymm2, ymm5, ymm19); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); - ymm15 = _mm256_fnmadd_pd(ymm2, ymm9, ymm15); - ymm20 = _mm256_fnmadd_pd(ymm2, ymm5, ymm20); - - //perform mul operation - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); - ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm1); - - a11 += cs_a; - - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //(ROw5): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm10, ymm11); - ymm7 = _mm256_fnmadd_pd(ymm2, ymm6, ymm7); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4)); - ymm12 = _mm256_fnmadd_pd(ymm2, ymm10, ymm12); - ymm17 = _mm256_fnmadd_pd(ymm2, ymm6, ymm17); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5)); - ymm13 = _mm256_fnmadd_pd(ymm2, ymm10, ymm13); - ymm18 = _mm256_fnmadd_pd(ymm2, ymm6, ymm18); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); - 
ymm14 = _mm256_fnmadd_pd(ymm2, ymm10, ymm14); - ymm19 = _mm256_fnmadd_pd(ymm2, ymm6, ymm19); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); - ymm15 = _mm256_fnmadd_pd(ymm2, ymm10, ymm15); - ymm20 = _mm256_fnmadd_pd(ymm2, ymm6, ymm20); - - //perform mul operation - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm1); - - a11 += cs_a; - - //extract a44 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); - - //(ROw4): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4)); - ymm12 = _mm256_fnmadd_pd(ymm2, ymm11, ymm12); - ymm17 = _mm256_fnmadd_pd(ymm2, ymm7, ymm17); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5)); - ymm13 = _mm256_fnmadd_pd(ymm2, ymm11, ymm13); - ymm18 = _mm256_fnmadd_pd(ymm2, ymm7, ymm18); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); - ymm14 = _mm256_fnmadd_pd(ymm2, ymm11, ymm14); - ymm19 = _mm256_fnmadd_pd(ymm2, ymm7, ymm19); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); - ymm15 = _mm256_fnmadd_pd(ymm2, ymm11, ymm15); - ymm20 = _mm256_fnmadd_pd(ymm2, ymm7, ymm20); - - //perform mul operation - ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); - ymm17 = DTRSM_SMALL_DIV_OR_SCALE(ymm17, ymm1); - - a11 += cs_a; - - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); - - //(ROw5): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5)); - ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); - ymm18 = _mm256_fnmadd_pd(ymm2, ymm17, ymm18); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); - ymm14 = _mm256_fnmadd_pd(ymm2, ymm12, ymm14); - ymm19 = _mm256_fnmadd_pd(ymm2, ymm17, ymm19); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); - ymm15 = _mm256_fnmadd_pd(ymm2, ymm12, ymm15); - ymm20 = _mm256_fnmadd_pd(ymm2, ymm17, ymm20); - - //perform mul operation - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1); - ymm18 = DTRSM_SMALL_DIV_OR_SCALE(ymm18, ymm1); - - a11 += cs_a; - - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); - - //(ROw6): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); - ymm14 = _mm256_fnmadd_pd(ymm2, ymm13, ymm14); - ymm19 = _mm256_fnmadd_pd(ymm2, ymm18, ymm19); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); - ymm15 = _mm256_fnmadd_pd(ymm2, ymm13, ymm15); - ymm20 = _mm256_fnmadd_pd(ymm2, ymm18, ymm20); - - //perform mul operation - ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); - ymm19 = DTRSM_SMALL_DIV_OR_SCALE(ymm19, ymm1); - - a11 += cs_a; - - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); - - //(ROw7): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); - ymm15 = _mm256_fnmadd_pd(ymm2, ymm14, ymm15); - ymm20 = _mm256_fnmadd_pd(ymm2, ymm19, ymm20); - - //perform mul operation - ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); - ymm20 = DTRSM_SMALL_DIV_OR_SCALE(ymm20, ymm1); - - a11 += cs_a; - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - //rearrange high elements - ymm1 = 
_mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] - ymm3 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] - - ///unpack high/// - ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[4][1] B11[5][1] B11[4][3] B11[5][3] - ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[6][1] B11[7][1] B11[6][3] B11[7][3] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] - ymm3 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] - - _mm256_storeu_pd((double *)(b11 + 4), ymm0); //store B11[4][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm1); //store B11[5][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm2); //store B11[6][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 3 + 4), ymm3); //store B11[7][0-3] - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm4, ymm5); - ymm3 = _mm256_unpacklo_pd(ymm6, ymm7); - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); - - ///unpack high/// - ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); - ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); - - _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); - _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm17, ymm18); - ymm3 = _mm256_unpacklo_pd(ymm19, ymm20); - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); - - ///unpack high/// - ymm17 = _mm256_unpackhi_pd(ymm17, ymm18); - ymm18 = _mm256_unpackhi_pd(ymm19, ymm20); - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm17, ymm18, 0x20); - - _mm256_storeu_pd((double *)(b11 + cs_b * 4 + 4), ymm0); - _mm256_storeu_pd((double *)(b11 + cs_b * 5 + 4), ymm1); + cs_a = bli_obj_col_stride(a); // column stride of A + rs_a = bli_obj_row_stride(a); // row stride of A } - - dim_t n_rem = n-j; - if(n_rem >= 4) + else { - a10 = D_A_pack; - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM - - k_iter = i; //number of times GEMM to be performed - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - ymm12 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); - ymm14 = _mm256_setzero_pd(); - ymm15 = _mm256_setzero_pd(); - - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); 
- ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); - ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); - ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 4)); - ymm7 = _mm256_loadu_pd((double const *)(b11 + cs_b *3 + 4)); - - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] * alpha -= B01[0-3][3] - ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); - ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); - ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); - ymm7 = _mm256_fmsub_pd(ymm7, ymm16, ymm15); - - ///implement TRSM/// - - ///transpose of B11// - ///unpacklow/// - ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - - ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); - ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); - - //rearrange low elements - ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] - - ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); - ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); - - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - - ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); - ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); - - //rearrange high elements - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); - - //perform mul operation - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); - - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(ROw1): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm8, ymm10); - ymm2 = 
_mm256_broadcast_sd((double const *)(a11 + 3)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4)); - ymm12 = _mm256_fnmadd_pd(ymm2, ymm8, ymm12); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5)); - ymm13 = _mm256_fnmadd_pd(ymm2, ymm8, ymm13); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); - ymm14 = _mm256_fnmadd_pd(ymm2, ymm8, ymm14); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); - ymm15 = _mm256_fnmadd_pd(ymm2, ymm8, ymm15); - - //perform mul operation - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); - - a11 += cs_a; - - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(ROw2): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm9, ymm10); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm9, ymm11); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4)); - ymm12 = _mm256_fnmadd_pd(ymm2, ymm9, ymm12); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5)); - ymm13 = _mm256_fnmadd_pd(ymm2, ymm9, ymm13); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); - ymm14 = _mm256_fnmadd_pd(ymm2, ymm9, ymm14); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); - ymm15 = _mm256_fnmadd_pd(ymm2, ymm9, ymm15); - - //perform mul operation - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); - - a11 += cs_a; - - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //(ROw5): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm10, ymm11); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4)); - ymm12 = _mm256_fnmadd_pd(ymm2, ymm10, ymm12); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5)); - ymm13 = _mm256_fnmadd_pd(ymm2, ymm10, ymm13); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); - ymm14 = _mm256_fnmadd_pd(ymm2, ymm10, ymm14); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); - ymm15 = _mm256_fnmadd_pd(ymm2, ymm10, ymm15); - - //perform mul operation - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - - a11 += cs_a; - - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); - - //(ROw4): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4)); - ymm12 = _mm256_fnmadd_pd(ymm2, ymm11, ymm12); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5)); - ymm13 = _mm256_fnmadd_pd(ymm2, ymm11, ymm13); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); - ymm14 = _mm256_fnmadd_pd(ymm2, ymm11, ymm14); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); - ymm15 = _mm256_fnmadd_pd(ymm2, ymm11, ymm15); - - //perform mul operation - ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); - - a11 += cs_a; - - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); - - //(ROw5): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5)); - ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); - ymm14 = _mm256_fnmadd_pd(ymm2, ymm12, ymm14); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); - ymm15 = _mm256_fnmadd_pd(ymm2, ymm12, ymm15); - - //perform mul operation - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1); - - a11 += cs_a; - - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); - - //(ROw6): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); - ymm14 = _mm256_fnmadd_pd(ymm2, ymm13, ymm14); - ymm2 = _mm256_broadcast_sd((double 
const *)(a11 +7)); - ymm15 = _mm256_fnmadd_pd(ymm2, ymm13, ymm15); - - //perform mul operation - ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); - - a11 += cs_a; - - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); + cs_a = bli_obj_row_stride(a); // row stride of A + rs_a = bli_obj_col_stride(a); // column stride of A + } + dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B - //(ROw7): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); - ymm15 = _mm256_fnmadd_pd(ymm2, ymm14, ymm15); + dim_t i, j, k; //loop variablse + dim_t k_iter; //determines the number of GEMM operations to be done - //perform mul operation - ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); + double ones = 1.0; + double zero = 0.0; + bool is_unitdiag = bli_obj_has_unit_diag(a); - a11 += cs_a; - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); - ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); + double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha + double* restrict L = a->buffer; //pointer to matrix A + double* restrict B = b->buffer; //pointer to matrix B - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks - ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); - ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); + gint_t required_packing_A = 1; + mem_t local_mem_buf_A_s = {0}; + double *D_A_pack = NULL; + double d11_pack[d_mr] __attribute__((aligned(64))); + rntm_t rntm; - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + bli_rntm_init_from_global( &rntm ); + bli_rntm_set_num_threads_only( 1, &rntm ); + bli_membrk_rntm_set_membrk( &rntm ); - ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); - ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); + siz_t buffer_size = bli_pool_block_size( + bli_membrk_pool( + bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), + bli_rntm_membrk(&rntm))); - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + if( (d_nr * n * sizeof(double)) > buffer_size) + return BLIS_NOT_YET_IMPLEMENTED; - ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); - ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); + if (required_packing_A == 1) + { + // Get the buffer from the pool. 
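The d11_pack scratch array declared above caches the d_nr diagonal entries of the current A11 block so the solve loop can pick them up from one aligned cache line, and it holds 1.0 for a unit-diagonal A. A scalar sketch of that packing step follows; pack_diag_sketch() is an illustrative stand-in for dtrsm_small_pack_diag_element(), and the explicit preinvert flag stands in for the TRSM pre-inversion configuration:

#include <stdbool.h>

/* Sketch only: cache the diagonal of an n x n block whose element (p, q)
   sits at a11[p*cs_a + q*rs_a].  With pre-inversion the reciprocal is
   stored so DTRSM_SMALL_DIV_OR_SCALE can multiply instead of divide. */
static void pack_diag_sketch(bool is_unitdiag, bool preinvert,
                             const double *a11, long cs_a, long rs_a,
                             double *d11_pack, long n)
{
    for (long k = 0; k < n; k++)
    {
        if (is_unitdiag)
            d11_pack[k] = 1.0;
        else
        {
            double akk = a11[k*cs_a + k*rs_a];
            d11_pack[k] = preinvert ? (1.0 / akk) : akk;
        }
    }
}

In the kernel each packed value is then broadcast once per solved column, e.g. ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + k)).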
+ bli_membrk_acquire_m(&rntm, + buffer_size, + BLIS_BITVAL_BUFFER_FOR_A_BLOCK, + &local_mem_buf_A_s); + if(FALSE==bli_mem_is_alloc(&local_mem_buf_A_s)) return BLIS_NULL_POINTER; + D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); + if(NULL==D_A_pack) return BLIS_NULL_POINTER; + } - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); - _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm6); - _mm256_storeu_pd((double *)(b11 + cs_b * 3 + 4), ymm7); + //ymm scratch reginsters + __m256d ymm0, ymm1, ymm2, ymm3; + __m256d ymm4, ymm5, ymm6, ymm7; + __m256d ymm8, ymm9, ymm10, ymm11; + __m256d ymm12, ymm13, ymm14, ymm15; - n_rem -=4; - j +=4; + __m128d xmm5; - } + /* + Performs solving TRSM for 6 rows at a time from 0 to n/6 in steps of d_nr + a. Load and pack A (a01 block), the size of packing 6x6 to 6x (n-6) + First there will be no GEMM and no packing of a01 because it is only TRSM + b. Using packed a01 block and b10 block perform GEMM operation + c. Use GEMM outputs, perform TRSM operation using a11, b11 and update B + d. Repeat b for m cols of B in steps of d_mr + */ - if(n_rem) + for(j = 0; (j+d_nr-1) < n; j += d_nr) //loop along 'N' direction { - a10 = D_A_pack; - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM - - k_iter = i; //number of times GEMM to be performed - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm12 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); - ymm14 = _mm256_setzero_pd(); - - if(3 == n_rem) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + a01 = L + j*rs_a; //pointer to block of A to be used in GEMM + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + //double *ptr_a10_dup = D_A_pack; - b01 += 1; //move to next row of B - a10 += p_lda; - } + dim_t p_lda = j; // packed leading dimension + // perform copy of A to packed buffer D_A_pack - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); - ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); - ymm6 = _mm256_loadu_pd((double const *)(b11 
+ cs_b *2 + 4)); - - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); - - ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); - ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); - ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); - ymm7 = _mm256_broadcast_sd((double const *)(&ones)); - } - else if(2 == n_rem) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + if(transa) { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + /* + Pack current A block (a01) into packed buffer memory D_A_pack + a. This a10 block is used in GEMM portion only and this + a01 block size will be increasing by d_nr for every next iteration + until it reaches 6x(n-6) which is the maximum GEMM alone block size in A + b. This packed buffer is reused to calculate all m cols of B matrix + */ + bli_dtrsm_small_pack('R', j, 1, a01, cs_a, D_A_pack, p_lda,d_nr); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + /* + Pack 6 diagonal elements of A block into an array + a. This helps in utilze cache line efficiently in TRSM operation + b. store ones when input is unit diagonal + */ - b01 += 1; //move to next row of B - a10 += p_lda; + dtrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,d_nr); + } + else + { + bli_dtrsm_small_pack('R', j, 0, a01, rs_a, D_A_pack, p_lda,d_nr); + dtrsm_small_pack_diag_element(is_unitdiag,a11,rs_a,d11_pack,d_nr); } - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - - ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); - ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); - - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] - ymm2 = _mm256_broadcast_sd((double const *)(&ones)); - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); - - ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); - ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); - ymm6 = _mm256_broadcast_sd((double const *)(&ones)); - ymm7 = _mm256_broadcast_sd((double const *)(&ones)); - } - else if(1 == n_rem) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + /* + a. Perform GEMM using a01, b10. + b. Perform TRSM on a11, b11 + c. This loop GEMM+TRSM loops operates with 8x6 block size + along m dimension for every d_mr columns of B10 where + packed A buffer is reused in computing all m cols of B. + d. Same approach is used in remaining fringe cases. 
+ */ + for(i = 0; (i+d_mr-1) < m; i += d_mr) //loop along 'M' direction { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - b01 += 1; //move to next row of B - a10 += p_lda; - } + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + /* + Peform GEMM between a01 and b10 blocks + For first itteration there will be no GEMM operation + where k_iter are zero + */ + BLIS_DTRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - - ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); - - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_broadcast_sd((double const *)(&ones)); - ymm2 = _mm256_broadcast_sd((double const *)(&ones)); - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); - - ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); - ymm5 = _mm256_broadcast_sd((double const *)(&ones)); - ymm6 = _mm256_broadcast_sd((double const *)(&ones)); - ymm7 = _mm256_broadcast_sd((double const *)(&ones)); - } - ///implement TRSM/// - - ///transpose of B11// - ///unpacklow/// - ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - - ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); - ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); - - //rearrange low elements - ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] - - ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); - ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); - - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - - ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); - ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); - - //rearrange high elements - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); - - //perform mul operation - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); - - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(ROw1): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm8, ymm10); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4)); - ymm12 = _mm256_fnmadd_pd(ymm2, 
ymm8, ymm12); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5)); - ymm13 = _mm256_fnmadd_pd(ymm2, ymm8, ymm13); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); - ymm14 = _mm256_fnmadd_pd(ymm2, ymm8, ymm14); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); - ymm15 = _mm256_fnmadd_pd(ymm2, ymm8, ymm15); - - //perform mul operation - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); - - a11 += cs_a; - - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(ROw2): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm9, ymm10); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm9, ymm11); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4)); - ymm12 = _mm256_fnmadd_pd(ymm2, ymm9, ymm12); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5)); - ymm13 = _mm256_fnmadd_pd(ymm2, ymm9, ymm13); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); - ymm14 = _mm256_fnmadd_pd(ymm2, ymm9, ymm14); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); - ymm15 = _mm256_fnmadd_pd(ymm2, ymm9, ymm15); + /* + Load b11 of size 8x6 and multiply with alpha + Add the GEMM output to b11 + and peform TRSM operation. + */ - //perform mul operation - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + BLIS_PRE_DTRSM_SMALL_6x8(AlphaVal,b11,cs_b) - a11 += cs_a; - - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //(ROw5): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm10, ymm11); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4)); - ymm12 = _mm256_fnmadd_pd(ymm2, ymm10, ymm12); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5)); - ymm13 = _mm256_fnmadd_pd(ymm2, ymm10, ymm13); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); - ymm14 = _mm256_fnmadd_pd(ymm2, ymm10, ymm14); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); - ymm15 = _mm256_fnmadd_pd(ymm2, ymm10, ymm15); - - //perform mul operation - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - - a11 += cs_a; - - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); - - //(ROw4): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4)); - ymm12 = _mm256_fnmadd_pd(ymm2, ymm11, ymm12); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5)); - ymm13 = _mm256_fnmadd_pd(ymm2, ymm11, ymm13); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); - ymm14 = _mm256_fnmadd_pd(ymm2, ymm11, ymm14); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); - ymm15 = _mm256_fnmadd_pd(ymm2, ymm11, ymm15); - - //perform mul operation - ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); - - a11 += cs_a; - - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); - - //(ROw5): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5)); - ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); - ymm14 = _mm256_fnmadd_pd(ymm2, ymm12, ymm14); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); - ymm15 = _mm256_fnmadd_pd(ymm2, ymm12, ymm15); - - //perform mul operation - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1); - - a11 += cs_a; - - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); - - //(ROw6): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); - ymm14 = _mm256_fnmadd_pd(ymm2, ymm13, ymm14); - ymm2 = _mm256_broadcast_sd((double const *)(a11 +7)); - ymm15 = 
_mm256_fnmadd_pd(ymm2, ymm13, ymm15); - - //perform mul operation - ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); - - a11 += cs_a; - - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); - - //(ROw7): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); - ymm15 = _mm256_fnmadd_pd(ymm2, ymm14, ymm15); - - //perform mul operation - ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); - - a11 += cs_a; - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); - ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); - ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); - - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); - ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); - ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); - - if(3 == n_rem) - { - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); - _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm6); - } - else if(2 == n_rem) - { - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); - _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); - } - else if(1 == n_rem) - { - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); - } - } - } + ///implement TRSM/// - /* - Reminder cases starts here: - a. Similar logic and code flow used in computing full block (8x6) - above holds for reminder cases too. - */ - - dim_t m_rem = m-i; - //implementation for reamainder rows(when 'M' is not a multiple of D_MR) - if(m_rem>=4) - { - a10 = L + (i); //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - double *ptr_a10_dup = D_A_pack; - double *ptr_a11_dup = a11; - - dim_t p_lda = 4; // packed leading dimension - for(dim_t x =0;x < i;x++) - { - ymm0 = _mm256_loadu_pd((double const *)(a10 + cs_a * x)); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * x), ymm0); - } + /* + Compute 6x8 TRSM block by using GEMM block output in register + a. The 6x8 input (gemm outputs) are stored in combinations of ymm registers + 1. ymm3, ymm4 2. ymm5, ymm6 3. ymm7, ymm8, 4. ymm9, ymm10 + 5. ymm11, ymm12 6. ymm13,ymm14 + b. 
Towards the end TRSM output will be stored back into b11 + */ - ymm4 = _mm256_broadcast_sd((double const *)&ones); - if(!is_unitdiag) - { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_sd((double const *)(a11)); - ymm1 = _mm256_broadcast_sd((double const *)(a11+cs_a*1 + 1)); - ymm2 = _mm256_broadcast_sd((double const *)(a11+cs_a*2 + 2)); - ymm3 = _mm256_broadcast_sd((double const *)(a11+cs_a*3 + 3)); + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - //Pick one element each column and create a 4 element vector and store - ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); - ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); - #ifdef BLIS_DISABLE_TRSM_PREINVERSION - ymm4 = ymm1; - #endif - #ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm4 = _mm256_div_pd(ymm4, ymm1); - #endif - } - _mm256_storeu_pd((double *)(d11_pack), ymm4); + //extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - for(j = 0; (j+D_NR-1) < n; j += D_NR) //loop along 'N' dimension - { - a10 = D_A_pack; //pointer to block of A to be used for GEMM - a11 = ptr_a11_dup; //pointer to block of A to be used for TRSM - b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM - b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM + //(row 1):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); - k_iter = i; //number of times GEMM operation to be done + ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); + ymm6 = _mm256_fnmadd_pd(ymm1, ymm4, ymm6); - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); - #endif + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); + ymm8 = _mm256_fnmadd_pd(ymm1, ymm4, ymm8); - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm4, ymm10); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); - ymm4 = _mm256_fmadd_pd(ymm2, ymm0, ymm4); + ymm11 = _mm256_fnmadd_pd(ymm1, ymm3, ymm11); + ymm12 = _mm256_fnmadd_pd(ymm1, ymm4, ymm12); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); - b01 += 1; //move to next row of B - a10 += p_lda; 
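Taken together, the intrinsics in this loop body compute, for one 8x6 tile of B: the GEMM update against the already-solved columns (b10 times the packed a01 panel), the alpha scaling and subtraction, and a forward solve of the tile against the 6x6 diagonal block of op(A). A scalar model of that tile computation is sketched below; the fixed 8x6 tile shape, the acc[][] temporary and the preinverted flag are illustrative assumptions, not kernel parameters:

/* Scalar sketch of one 8x6 tile of the XA = alpha*B solve. */
static void trsm_tile_sketch(double alpha,
                             const double *a01,      /* packed panel, ld = p_lda      */
                             const double *a11,      /* op(A)(p,q) at p*cs_a + q*rs_a */
                             const double *b10,      /* 8 x k_iter,   ld = cs_b       */
                             double *b11,            /* 8 x 6,        ld = cs_b       */
                             const double *d11_pack, /* packed diagonal               */
                             long k_iter, long p_lda,
                             long cs_a, long rs_a, long cs_b,
                             int preinverted)
{
    double acc[6][8] = {{0.0}};

    /* roughly what BLIS_DTRSM_SMALL_GEMM_6nx8m accumulates */
    for (long k = 0; k < k_iter; k++)
        for (long c = 0; c < 6; c++)
            for (long r = 0; r < 8; r++)
                acc[c][r] += a01[k + c*p_lda] * b10[r + k*cs_b];

    /* alpha scaling and GEMM subtraction (cf. BLIS_PRE_DTRSM_SMALL_6x8) */
    for (long c = 0; c < 6; c++)
        for (long r = 0; r < 8; r++)
            b11[r + c*cs_b] = alpha*b11[r + c*cs_b] - acc[c][r];

    /* forward solve over the 6 columns of the tile */
    for (long k = 0; k < 6; k++)
    {
        for (long r = 0; r < 8; r++)
            b11[r + k*cs_b] = preinverted ? b11[r + k*cs_b] * d11_pack[k]
                                          : b11[r + k*cs_b] / d11_pack[k];

        for (long c = k + 1; c < 6; c++)
        {
            double a_kc = a11[k*cs_a + c*rs_a];   /* multiplier broadcast in the kernel */
            for (long r = 0; r < 8; r++)
                b11[r + c*cs_b] -= a_kc * b11[r + k*cs_b];
        }
    }
}

Since k_iter equals j here, the first column block (j == 0) skips the GEMM part entirely, matching the "no GEMM on the first iteration" note earlier in this function.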
- } + ymm13 = _mm256_fnmadd_pd(ymm1, ymm3, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm1, ymm4, ymm14); - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm0); - ///implement TRSM/// - ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] * alpha -= B01[0-3][3] + a11 += cs_a; - ///transpose of B11// - ///unpacklow/// - ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + //extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - //rearrange low elements - ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); + ymm8 = _mm256_fnmadd_pd(ymm1, ymm6, ymm8); - //rearrange high elements - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm6, ymm10); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); + ymm11 = _mm256_fnmadd_pd(ymm1, ymm5, ymm11); + ymm12 = _mm256_fnmadd_pd(ymm1, ymm6, ymm12); - ymm16 = _mm256_broadcast_sd((double const *)(&ones)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); - ////unpacklow//// - ymm7 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm13 = _mm256_fnmadd_pd(ymm1, ymm5, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm1, ymm6, ymm14); - //rearrange low elements - ymm4 = _mm256_permute2f128_pd(ymm7,ymm16,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm6 = _mm256_permute2f128_pd(ymm7,ymm16,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm0); - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + a11 += cs_a; - //rearrange high elements - ymm5 = _mm256_permute2f128_pd(ymm0,ymm16,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm7 = _mm256_permute2f128_pd(ymm0,ymm16,0x31); //B11[3][0] 
B11[3][1] B11[3][2] B11[3][3] + //extract a33 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - //b11 transpose end + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm8, ymm10); - //perform mul operation - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); - ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm1); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + ymm11 = _mm256_fnmadd_pd(ymm1, ymm7, ymm11); + ymm12 = _mm256_fnmadd_pd(ymm1, ymm8, ymm12); - //(ROw1): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); - ymm5 = _mm256_fnmadd_pd(ymm2, ymm4, ymm5); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm8, ymm10); - ymm6 = _mm256_fnmadd_pd(ymm2, ymm4, ymm6); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); - ymm7 = _mm256_fnmadd_pd(ymm2, ymm4, ymm7); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); - //perform mul operation - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm1); + ymm13 = _mm256_fnmadd_pd(ymm1, ymm7, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm1, ymm8, ymm14); - a11 += cs_a; + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm0); - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + a11 += cs_a; - //(ROw2): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm9, ymm10); - ymm6 = _mm256_fnmadd_pd(ymm2, ymm5, ymm6); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm9, ymm11); - ymm7 = _mm256_fnmadd_pd(ymm2, ymm5, ymm7); + //extract a44 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); - //perform mul operation - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); - ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm1); + //(row 4):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); - a11 += cs_a; + ymm11 = _mm256_fnmadd_pd(ymm1, ymm9, ymm11); + ymm12 = _mm256_fnmadd_pd(ymm1, ymm10, ymm12); - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); - //(ROw5): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm10, ymm11); - ymm7 = _mm256_fnmadd_pd(ymm2, ymm6, ymm7); + ymm13 = _mm256_fnmadd_pd(ymm1, ymm9, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm1, ymm10, ymm14); - //perform mul operation - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm1); + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); + ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm0); - a11 += cs_a; + a11 += cs_a; - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + //extract a55 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + 
//(Row 5): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + ymm13 = _mm256_fnmadd_pd(ymm1, ymm11, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm1, ymm12, ymm14); - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); + ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm0); - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + 4), ymm4); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm256_storeu_pd((double *)(b11 + cs_b + 4), ymm6); + _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); + _mm256_storeu_pd((double *)(b11 + cs_b*2 + 4), ymm8); + _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); + _mm256_storeu_pd((double *)(b11 + cs_b*3 + 4), ymm10); + _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); + _mm256_storeu_pd((double *)(b11 + cs_b*4 + 4), ymm12); + _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); + _mm256_storeu_pd((double *)(b11 + cs_b*5 + 4), ymm14); + } - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + dim_t m_remainder = m - i; + if(m_remainder >= 4) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - ///unpack high/// - ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) - _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store B11[1][0-3] - } - dim_t n_rem = n-j; - if(n_rem >= 4) - { - a10 = D_A_pack; - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + // Load b11 of size 4x6 and multiply with alpha + BLIS_PRE_DTRSM_SMALL_6x4(AlphaVal,b11,cs_b) - k_iter = i; //number of times GEMM to be performed + ///implement TRSM/// - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 
+ cs_b*3), _MM_HINT_T0); - #endif + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); + //extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); + //(row 1):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); + ymm11 = _mm256_fnmadd_pd(ymm1, ymm3, ymm11); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_pd(ymm1, ymm3, ymm13); - b01 += 1; //move to next row of B - a10 += p_lda; - } - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - ///implement TRSM/// + a11 += cs_a; - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + //extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - ///transpose of B11// - ///unpacklow/// - ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); - //rearrange low elements - ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); + ymm11 = _mm256_fnmadd_pd(ymm1, ymm5, ymm11); - //rearrange high elements - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_pd(ymm1, ymm5, ymm13); - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + 
ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - //perform mul operation - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); + a11 += cs_a; - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + //extract a33 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - //(ROw1): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm8, ymm10); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); - //perform mul operation - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); + ymm11 = _mm256_fnmadd_pd(ymm1, ymm7, ymm11); - a11 += cs_a; + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_pd(ymm1, ymm7, ymm13); - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - //(ROw2): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm9, ymm10); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm9, ymm11); + a11 += cs_a; - //perform mul operation - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + //extract a44 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); - a11 += cs_a; + //(row 4):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); + ymm11 = _mm256_fnmadd_pd(ymm1, ymm9, ymm11); - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_pd(ymm1, ymm9, ymm13); - //(ROw5): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm10, ymm11); + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); - //perform mul operation - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + a11 += cs_a; - a11 += cs_a; + //extract a55 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + //(Row 5): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_pd(ymm1, ymm11, ymm13); - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); + _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); + _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); + _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] 
B11[1][3] B11[2][3] B11[3][3] + m_remainder -= 4; + i += 4; + } - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] + if(m_remainder == 3) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - n_rem -= 4; - j += 4; - } - if(n_rem) - { - a10 = D_A_pack; - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - k_iter = i; //number of times GEMM to be performed + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) - if(3 == n_rem) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); + // Load b11 of size 4x6 and multiply with alpha + BLIS_PRE_DTRSM_SMALL_6x4(AlphaVal,b11,cs_b) - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ///implement TRSM/// - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + //extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - b01 += 1; //move to next row of B - a10 += p_lda; - } + //(row 1):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); - } - else if(2 == n_rem) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, 
ymm9); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); + ymm11 = _mm256_fnmadd_pd(ymm1, ymm3, ymm11); - b01 += 1; //move to next row of B - a10 += p_lda; - } + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_pd(ymm1, ymm3, ymm13); - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + a11 += cs_a; - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] - ymm2 = _mm256_broadcast_sd((double const *)(&ones)); - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); - } - else if(1 == n_rem) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); + //extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); - b01 += 1; //move to next row of B - a10 += p_lda; - } + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); + ymm11 = _mm256_fnmadd_pd(ymm1, ymm5, ymm11); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_pd(ymm1, ymm5, ymm13); - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_broadcast_sd((double const *)(&ones)); - ymm2 = _mm256_broadcast_sd((double const *)(&ones)); - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); - } + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - ///transpose of B11// - ///unpacklow/// - ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + a11 += cs_a; - //rearrange low elements - ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + //extract a33 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); - //rearrange high elements - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); + ymm11 = _mm256_fnmadd_pd(ymm1, ymm7, ymm11); - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); 
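+                //note (hypothetical clarifying comment): in this m_remainder == 3 path only lanes 0-2 of each B11 column register are stored;
+                //the 0x07 blends further below merge the computed results with B11's untouched 4th row before the final stores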
+ ymm13 = _mm256_fnmadd_pd(ymm1, ymm7, ymm13); - //perform mul operation - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + a11 += cs_a; - //(ROw1): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm8, ymm10); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); + //extract a44 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); - //perform mul operation - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); + //(row 4):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); + ymm11 = _mm256_fnmadd_pd(ymm1, ymm9, ymm11); - a11 += cs_a; + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_pd(ymm1, ymm9, ymm13); - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); - //(ROw2): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm9, ymm10); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm9, ymm11); + a11 += cs_a; - //perform mul operation - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + //extract a55 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); - a11 += cs_a; + //(Row 5): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_pd(ymm1, ymm11, ymm13); - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - //(ROw5): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm10, ymm11); + ymm0 = _mm256_loadu_pd((double const *)b11); + ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x07); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x07); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x07); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x07); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm11 = _mm256_blend_pd(ymm0, ymm11, 0x07); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm13 = _mm256_blend_pd(ymm0, ymm13, 0x07); - //perform mul operation - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); + _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); + _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); + _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); - a11 += cs_a; + m_remainder -= 3; + i += 3; + } + else if(m_remainder == 2) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = 
_mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + // Load b11 of size 4x6 and multiply with alpha + BLIS_PRE_DTRSM_SMALL_6x4(AlphaVal,b11,cs_b) - if(3 == n_rem) - { - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] - } - else if(2 == n_rem) - { - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - } - else if(1 == n_rem) - { - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - } - } - m_rem -=4; - i +=4; - } - - if(m_rem) - { - a10 = L + (i); //pointer to block of A to be used for GEMM - // Do transpose for a10 & store in D_A_pack - double *ptr_a10_dup = D_A_pack; - if(3 == m_rem) // Repetative A blocks will be 3*3 - { - dim_t p_lda = 4; // packed leading dimension - for(dim_t x=0;x= 4)) - { - a10 = D_A_pack; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM - b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM - - k_iter = i; //number of times GEMM to be performed - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - #endif + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); + ymm11 = _mm256_fnmadd_pd(ymm1, ymm3, ymm11); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_pd(ymm1, ymm3, ymm13); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + a11 += cs_a; - 
ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + //extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - b01 += 1; //move to next row of B - a10 += p_lda; - } + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); - ymm3 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3 + 2)); - ymm3 = _mm256_insertf128_pd(ymm3, xmm5, 0); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x08); - - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - xmm5 = _mm256_extractf128_pd(ymm3, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 3),xmm5); - _mm_storel_pd((b11 + cs_b * 3 + 2), _mm256_extractf128_pd(ymm3, 1)); - - dtrsm_AlXB_ref(a11, b11, m_rem, 4, cs_a, cs_b, is_unitdiag); - n_rem -= 4; - j +=4; - } - - if(n_rem) - { - a10 = D_A_pack; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM - b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM - - k_iter = i; //number of times GEMM to be performed - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - #endif + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); + ymm11 = _mm256_fnmadd_pd(ymm1, ymm5, ymm11); - if(3 == n_rem) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_pd(ymm1, ymm5, ymm13); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + a11 += cs_a; - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + //extract a33 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - b01 += 1; //move to next row of B - a10 += p_lda; - } + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + ymm1 = 
_mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); + ymm11 = _mm256_fnmadd_pd(ymm1, ymm7, ymm11); - ///implement TRSM/// - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2)); - ymm2 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 2 + 2)); - ymm2 = _mm256_insertf128_pd(ymm2, xmm5, 0); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_pd(ymm1, ymm7, ymm13); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); + a11 += cs_a; - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - xmm5 = _mm256_extractf128_pd(ymm2, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 2), xmm5); - _mm_storel_pd((b11 + cs_b * 2 + 2), _mm256_extractf128_pd(ymm2, 1)); + //extract a44 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); - dtrsm_AlXB_ref(a11, b11, m_rem, 3, cs_a, cs_b, is_unitdiag); - } - else if(2 == n_rem) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); + //(row 4):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); + ymm11 = _mm256_fnmadd_pd(ymm1, ymm9, ymm11); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_pd(ymm1, ymm9, ymm13); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); - b01 += 1; //move to next row of B - a10 += p_lda; - } + a11 += cs_a; - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + //extract a55 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); - ///implement TRSM/// - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1)); - ymm1 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 1 + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); + //(Row 5): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_pd(ymm1, ymm11, ymm13); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); + ymm0 = _mm256_loadu_pd((double const *)b11); + ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x03); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x03); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x03); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x03); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm11 = _mm256_blend_pd(ymm0, ymm11, 0x03); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); 
//B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm13 = _mm256_blend_pd(ymm0, ymm13, 0x03); - _mm256_storeu_pd((double *)(b11), ymm0); - xmm5 = _mm256_extractf128_pd(ymm1, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 1), xmm5); - _mm_storel_pd((b11 + cs_b * 1 + 2), _mm256_extractf128_pd(ymm1, 1)); + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); + _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); + _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); + _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); - dtrsm_AlXB_ref(a11, b11, m_rem, 2, cs_a, cs_b, is_unitdiag); + m_remainder -= 2; + i += 2; } - else if(1 == n_rem) + else if(m_remainder == 1) { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - b01 += 1; //move to next row of B - a10 += p_lda; - } + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) - ///implement TRSM/// - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 0)); - ymm0 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 0 + 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); + // Load b11 of size 4x6 and multiply with alpha + BLIS_PRE_DTRSM_SMALL_6x4(AlphaVal,b11,cs_b) - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ///implement TRSM/// - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - xmm5 = _mm256_extractf128_pd(ymm0, 0); - _mm_storeu_pd((double *)(b11), xmm5); - _mm_storel_pd((b11 + 2), _mm256_extractf128_pd(ymm0, 1)); + //extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - dtrsm_AlXB_ref(a11, b11, m_rem, 1, cs_a, cs_b, is_unitdiag); - } - } - } - else if(2 == m_rem) // Repetative A blocks will be 2*2 - { - dim_t p_lda = 4; // packed leading dimension - for(dim_t x=0;x= 4)) - { - a10 = D_A_pack; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM - b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM - - k_iter = i; //number of times GEMM to be performed - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - #endif + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); + ymm11 = _mm256_fnmadd_pd(ymm1, ymm7, ymm11); - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) 
//loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_pd(ymm1, ymm7, ymm13); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + a11 += cs_a; - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + //extract a44 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + //(row 4):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); + ymm11 = _mm256_fnmadd_pd(ymm1, ymm9, ymm11); - b01 += 1; //move to next row of B - a10 += p_lda; - } + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_pd(ymm1, ymm9, ymm13); - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); - ymm3 = _mm256_insertf128_pd(ymm3, xmm5, 0); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0C); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0C); - - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - xmm5 = _mm256_extractf128_pd(ymm3, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 3), xmm5); - - dtrsm_AlXB_ref(a11, b11, m_rem, 4, cs_a, cs_b, is_unitdiag); - n_rem -= 4; - j +=4; - } - if(n_rem) - { - a10 = D_A_pack; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM - b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM - - k_iter = i; //number of times GEMM to be performed - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - #endif + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); + a11 += cs_a; - if(3 == n_rem) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); + //extract a55 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + //(Row 5): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_pd(ymm1, ymm11, ymm13); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm13 = 
DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm0 = _mm256_loadu_pd((double const *)b11); + ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x01); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x01); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x01); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x01); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm11 = _mm256_blend_pd(ymm0, ymm11, 0x01); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm13 = _mm256_blend_pd(ymm0, ymm13, 0x01); - b01 += 1; //move to next row of B - a10 += p_lda; - } + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); + _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); + _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); + _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + m_remainder -= 1; + i += 1; + } + } - ///implement TRSM/// + dim_t n_remainder = n - j; - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2)); - ymm2 = _mm256_insertf128_pd(ymm2, xmm5, 0); + /* + Reminder cases starts here: + a. Similar logic and code flow used in computing full block (6x8) + above holds for reminder cases too. 
+ */ - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + if(n_remainder >= 4) + { + a01 = L + j*rs_a; //pointer to block of A to be used in GEMM + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0C); + double *ptr_a10_dup = D_A_pack; - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - xmm5 = _mm256_extractf128_pd(ymm2, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 2), xmm5); + dim_t p_lda = j; // packed leading dimension + // perform copy of A to packed buffer D_A_pack - dtrsm_AlXB_ref(a11, b11, m_rem, 3, cs_a, cs_b, is_unitdiag); - } - else if(2 == n_rem) + if(transa) { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + for(dim_t x =0;x < p_lda;x+=d_nr) + { + ymm0 = _mm256_loadu_pd((double const *)(a01)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); + ymm2 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2)); + ymm3 = _mm256_loadu_pd((double const *)(a01 + cs_a * 3)); - ///implement TRSM/// - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - xmm5 = _mm256_extractf128_pd(ymm1, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 1), xmm5); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - dtrsm_AlXB_ref(a11, b11, m_rem, 2, cs_a, cs_b, is_unitdiag); - } - else if(1 == n_rem) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm0 = _mm256_loadu_pd((double const *)(a01 + cs_a * 4)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a * 5)); - b01 += 1; //move to next row of B - a10 += p_lda; - } + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_broadcast_sd((double const *)&zero); - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = 
_mm256_permute2f128_pd(ymm4,ymm5,0x31); - ///implement TRSM/// - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 0)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_broadcast_sd((double const *)&zero); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - xmm5 = _mm256_extractf128_pd(ymm0, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 0), xmm5); + _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_extractf128_pd(ymm6,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_extractf128_pd(ymm7,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm8,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm9,0)); - dtrsm_AlXB_ref(a11, b11, m_rem, 1, cs_a, cs_b, is_unitdiag); + a01 += d_nr*cs_a; + ptr_a10_dup += d_nr; + } } - - } - m_rem -=2; - i+=2; - } - else if(1 == m_rem) // Repetative A blocks will be 1*1 - { - dim_t p_lda = 4; // packed leading dimension - for(dim_t x=0;x= 4)) - { - a10 = D_A_pack; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM - b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM - - k_iter = i; //number of times GEMM to be performed - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - #endif - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + for(i = 0; (i+d_mr-1) < m; i += d_mr) //loop along 'M' direction { - ymm0 = _mm256_loadu_pd((double const *)(a10)); + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_4nx8m(a01,b10,cs_b,p_lda,k_iter) - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + BLIS_PRE_DTRSM_SMALL_4x8(AlphaVal,b11,cs_b) - b01 += 1; //move to next row of B - a10 += p_lda; - } + ///implement TRSM/// - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_broadcast_sd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_broadcast_sd((double const *)(b11 + cs_b *2)); - ymm3 = _mm256_broadcast_sd((double const *)(b11 + cs_b *3)); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - 
ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0E); - - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm0, 0)); - _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm1, 0)); - _mm_storel_pd((b11 + cs_b * 2), _mm256_extractf128_pd(ymm2, 0)); - _mm_storel_pd((b11 + cs_b * 3), _mm256_extractf128_pd(ymm3, 0)); - - dtrsm_AlXB_ref(a11, b11, m_rem, 4, cs_a, cs_b, is_unitdiag); - n_rem -= 4; - j+=4; - } - - if(n_rem) - { - a10 = D_A_pack; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM - b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM - - k_iter = i; //number of times GEMM to be performed - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - #endif + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); - if(3 == n_rem) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); + //extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + //(row 1):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); + ymm6 = _mm256_fnmadd_pd(ymm1, ymm4, ymm6); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); - b01 += 1; //move to next row of B - a10 += p_lda; - } + ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); + ymm8 = _mm256_fnmadd_pd(ymm1, ymm4, ymm8); - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); - ///implement TRSM/// + ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm4, ymm10); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_broadcast_sd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_broadcast_sd((double const *)(b11 + cs_b *2)); + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm0); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + a11 += cs_a; - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); + //extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm0, 0)); - _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm1, 0)); - _mm_storel_pd((b11 + cs_b * 2), _mm256_extractf128_pd(ymm2, 0)); + //(row 2):FMA operations + ymm1 = 
_mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); - dtrsm_AlXB_ref(a11, b11, m_rem, 3, cs_a, cs_b, is_unitdiag); - } - else if(2 == n_rem) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); + ymm8 = _mm256_fnmadd_pd(ymm1, ymm6, ymm8); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm6, ymm10); - b01 += 1; //move to next row of B - a10 += p_lda; - } + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm0); - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + a11 += cs_a; - ///implement TRSM/// - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_broadcast_sd((double const *)(b11 + cs_b *1)); + //extract a33 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm8, ymm10); - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm0, 0)); - _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm1, 0)); + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm0); - dtrsm_AlXB_ref(a11, b11, m_rem, 2, cs_a, cs_b, is_unitdiag); + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + 4), ymm4); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm256_storeu_pd((double *)(b11 + cs_b + 4), ymm6); + _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); + _mm256_storeu_pd((double *)(b11 + cs_b*2 + 4), ymm8); + _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); + _mm256_storeu_pd((double *)(b11 + cs_b*3 + 4), ymm10); } - else if(1 == n_rem) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - b01 += 1; //move to next row of B - a10 += p_lda; - } + dim_t m_remainder = m - i; + if(m_remainder >= 4) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - ///implement TRSM/// - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b *0)); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_4nx4m(a01,b10,cs_b,p_lda,k_iter) - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm0, 0)); + ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to 
hold alpha - dtrsm_AlXB_ref(a11, b11, m_rem, 1, cs_a, cs_b, is_unitdiag); - } - } - m_rem -=1; - i+=1; - } - } + ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - if ((required_packing_A == 1) && - bli_mem_is_alloc( &local_mem_buf_A_s )) - { - bli_membrk_release(&rntm, &local_mem_buf_A_s); - } + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - return BLIS_SUCCESS; -} + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 -/* TRSM for the Left Upper case AX = alpha * B, Double precision - * A is Left side, upper-triangular, transpose, non-unit diagonal - * dimensions A: mxm X: mxn B: mxn - a10 ----> b11---> - *********** ***************** - * * * * *b01*b11* * * - **a10 * * a11 b11 * * * * * - ********* | | ***************** - *a11* * | | * * * * * - * * * | | * * * * * - ****** v v ***************** - * * * * * * * - * * * * * * * - * * ***************** - * - a11---> -*/ -BLIS_INLINE err_t bli_dtrsm_small_AutXB -( - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl -) -{ - dim_t m = bli_obj_length(b); // number of rows of matrix B - dim_t n = bli_obj_width(b); // number of columns of matrix B - - dim_t cs_a = bli_obj_col_stride(a); // column stride of A - dim_t cs_b = bli_obj_col_stride(b); // column stride of B - - dim_t i, j, k; //loop variables - dim_t k_iter; //number of times GEMM to be performed - - double AlphaVal = *(double *)AlphaObj->buffer; //value of alpha - double *L = a->buffer; //pointer to matrix A - double *B = b->buffer; //pointer to matrix B - - double *a10, *a11, *b01, *b11; //pointers that point to blocks for GEMM and TRSM - - double ones = 1.0; - bool is_unitdiag = bli_obj_has_unit_diag(a); - - //scratch registers - __m256d ymm0, ymm1, ymm2, ymm3; - __m256d ymm4, ymm5, ymm6, ymm7; - __m256d ymm8, ymm9, ymm10, ymm11; - __m256d ymm12, ymm13, ymm14, ymm15; - __m256d ymm16, ymm17, ymm18, ymm19; - __m256d ymm20; - - __m128d xmm5; - - gint_t required_packing_A = 1; - mem_t local_mem_buf_A_s = {0}; - double *D_A_pack = NULL; - double d11_pack[D_MR] __attribute__((aligned(64))); - rntm_t rntm; - - bli_rntm_init_from_global( &rntm ); - bli_rntm_set_num_threads_only( 1, &rntm ); - bli_membrk_rntm_set_membrk( &rntm ); - - siz_t buffer_size = bli_pool_block_size( - bli_membrk_pool( - bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), - bli_rntm_membrk(&rntm))); - - if( (D_MR * m * sizeof(double)) > buffer_size) - return BLIS_NOT_YET_IMPLEMENTED; - - if (required_packing_A == 1) - { - // Get the buffer from the pool. - bli_membrk_acquire_m(&rntm, - buffer_size, - BLIS_BITVAL_BUFFER_FOR_A_BLOCK, - &local_mem_buf_A_s); - if(FALSE==bli_mem_is_alloc(&local_mem_buf_A_s)) return BLIS_NULL_POINTER; - D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); - if(NULL==D_A_pack) return BLIS_NULL_POINTER; - } - /* - Performs solving TRSM for 8 colmns at a time from 0 to m/8 in steps of D_MR - a. Load, transpose, Pack A (a10 block), the size of packing 8x6 to 8x (m-8) - First there will be no GEMM and no packing of a10 because it is only TRSM - b. Using packed a10 block and b01 block perform GEMM operation - c. Use GEMM outputs, perform TRSM operaton using a11, b11 and update B - d. 
Repeat b,c for n rows of B in steps of D_NR - */ - for(i = 0;(i+D_MR-1) < m; i += D_MR) //loop along 'M' dimension - { - a10 = L + (i*cs_a); //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); - dim_t p_lda = D_MR; // packed leading dimension + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 - /* - Load, tranpose and pack current A block (a10) into packed buffer memory D_A_pack - a. This a10 block is used in GEMM portion only and this - a10 block size will be increasing by D_MR for every next itteration - untill it reaches 8x(m-8) which is the maximum GEMM alone block size in A - b. This packed buffer is reused to calculate all n rows of B matrix - */ - bli_dtrsm_small_pack('L', i, 1, a10, cs_a, D_A_pack, p_lda); - - /* - Pack 8 diagonal elements of A block into an array - a. This helps in utilze cache line efficiently in TRSM operation - b. store ones when input is unit diagonal - */ - dtrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,D_MR); + ///implement TRSM/// - /* - a. Perform GEMM using a10, b01. - b. Perform TRSM on a11, b11 - c. This loop GEMM+TRSM loops operates with 8x6 block size - along n dimension for every D_NR rows of b01 where - packed A buffer is reused in computing all n rows of B. - d. Same approch is used in remaining fringe cases. - */ - dim_t temp = n - D_NR + 1; - for(j = 0; j < temp; j += D_NR) //loop along 'N' dimension - { - a10 = D_A_pack; - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM - - k_iter = i; - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); - #endif - - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS - - /* - Peform GEMM between a10 and b01 blocks - For first itteration there will be no GEMM operation - where k_iter are zero - */ - BLIS_DTRSM_SMALL_GEMM_8x6(a10,b01,cs_b,p_lda,k_iter) - - /* - Load b11 of size 6x8 and multiply with alpha - Add the GEMM output and perform inregister transose of b11 - to peform TRSM operation. 
- */ - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] * alpha -= B01[0-3][3] - - ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - ymm0 = _mm256_loadu_pd((double const *)(b11 + 4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 4)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3 + 4)); //B11[0][7] B11[1][7] B11[2][7] B11[3][7] - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm13); //B11[0-3][5] * alpha -= B01[0-3][5] - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm14); //B11[0-3][6] * alpha -= B01[0-3][6] - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm15); //B11[0-3][7] * alpha -= B01[0-3][7] - - ymm13 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][4] B11[0][5] B11[2][4] B11[2][5] - ymm15 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][6] B11[0][7] B11[2][6] B11[2][7] - ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3] - ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3] - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][4] B11[1][5] B11[3][4] B11[3][5] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][6] B11[1][7] B11[3][6] B11[3][7] - ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3] - ymm15 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3] - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *4 + 4)); - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *5 + 4)); - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm6); - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm7); - - ymm16 = _mm256_broadcast_sd((double const *)(&ones)); - ymm7 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm4 = 
_mm256_permute2f128_pd(ymm7,ymm16,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm6 = _mm256_permute2f128_pd(ymm7,ymm16,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm5 = _mm256_permute2f128_pd(ymm0,ymm16,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm7 = _mm256_permute2f128_pd(ymm0,ymm16,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - ymm18 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - - ymm17 = _mm256_permute2f128_pd(ymm18,ymm16,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm19 = _mm256_permute2f128_pd(ymm18,ymm16,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] - ymm20 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm18 = _mm256_permute2f128_pd(ymm20,ymm16,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm20 = _mm256_permute2f128_pd(ymm20,ymm16,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - /* - Compute 8x6 TRSM block by using GEMM block output in register - a. The 8x6 input (gemm outputs) are stored in combinations of ymm registers - 1. ymm8, ymm4 2. ymm9, ymm5 3. ymm10, ymm6, 4. ymm11, ymm7 - 5. ymm12, ymm17 6. ymm13,ymm18, 7. ymm14,ymm19 8. ymm15, ymm20 - where ymm8-ymm15 holds 8x4 data and reaming 8x2 will be hold by - other registers - b. Towards the end do in regiser transpose of TRSM output and store in b11 - */ - ////extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); - - //perform mul operation - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); - ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm1); - - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(ROw1): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*1)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); - ymm5 = _mm256_fnmadd_pd(ymm2, ymm4, ymm5); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm8, ymm10); - ymm6 = _mm256_fnmadd_pd(ymm2, ymm4, ymm6); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); - ymm7 = _mm256_fnmadd_pd(ymm2, ymm4, ymm7); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); - ymm12 = _mm256_fnmadd_pd(ymm2, ymm8, ymm12); - ymm17 = _mm256_fnmadd_pd(ymm2, ymm4, ymm17); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); - ymm13 = _mm256_fnmadd_pd(ymm2, ymm8, ymm13); - ymm18 = _mm256_fnmadd_pd(ymm2, ymm4, ymm18); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); - ymm14 = _mm256_fnmadd_pd(ymm2, ymm8, ymm14); - ymm19 = _mm256_fnmadd_pd(ymm2, ymm4, ymm19); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); - ymm15 = _mm256_fnmadd_pd(ymm2, ymm8, ymm15); - ymm20 = _mm256_fnmadd_pd(ymm2, ymm4, ymm20); - - - //perform mul operation - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm1); - - a11 += 1; - - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(ROw2): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm9, ymm10); - ymm6 = _mm256_fnmadd_pd(ymm2, ymm5, ymm6); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm9, ymm11); - ymm7 = _mm256_fnmadd_pd(ymm2, ymm5, ymm7); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); - ymm12 = _mm256_fnmadd_pd(ymm2, ymm9, ymm12); - ymm17 = _mm256_fnmadd_pd(ymm2, ymm5, ymm17); - ymm2 = 
_mm256_broadcast_sd((double const *)(a11 + cs_a*5)); - ymm13 = _mm256_fnmadd_pd(ymm2, ymm9, ymm13); - ymm18 = _mm256_fnmadd_pd(ymm2, ymm5, ymm18); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); - ymm14 = _mm256_fnmadd_pd(ymm2, ymm9, ymm14); - ymm19 = _mm256_fnmadd_pd(ymm2, ymm5, ymm19); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); - ymm15 = _mm256_fnmadd_pd(ymm2, ymm9, ymm15); - ymm20 = _mm256_fnmadd_pd(ymm2, ymm5, ymm20); - - //perform mul operation - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); - ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm1); - - a11 += 1; - - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //(ROw5): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm10, ymm11); - ymm7 = _mm256_fnmadd_pd(ymm2, ymm6, ymm7); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); - ymm12 = _mm256_fnmadd_pd(ymm2, ymm10, ymm12); - ymm17 = _mm256_fnmadd_pd(ymm2, ymm6, ymm17); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); - ymm13 = _mm256_fnmadd_pd(ymm2, ymm10, ymm13); - ymm18 = _mm256_fnmadd_pd(ymm2, ymm6, ymm18); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); - ymm14 = _mm256_fnmadd_pd(ymm2, ymm10, ymm14); - ymm19 = _mm256_fnmadd_pd(ymm2, ymm6, ymm19); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); - ymm15 = _mm256_fnmadd_pd(ymm2, ymm10, ymm15); - ymm20 = _mm256_fnmadd_pd(ymm2, ymm6, ymm20); - - //perform mul operation - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm1); - - a11 += 1; - - //extract a44 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); - //(ROw4): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); - ymm12 = _mm256_fnmadd_pd(ymm2, ymm11, ymm12); - ymm17 = _mm256_fnmadd_pd(ymm2, ymm7, ymm17); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); - ymm13 = _mm256_fnmadd_pd(ymm2, ymm11, ymm13); - ymm18 = _mm256_fnmadd_pd(ymm2, ymm7, ymm18); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); - ymm14 = _mm256_fnmadd_pd(ymm2, ymm11, ymm14); - ymm19 = _mm256_fnmadd_pd(ymm2, ymm7, ymm19); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); - ymm15 = _mm256_fnmadd_pd(ymm2, ymm11, ymm15); - ymm20 = _mm256_fnmadd_pd(ymm2, ymm7, ymm20); - - //perform mul operation - ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); - ymm17 = DTRSM_SMALL_DIV_OR_SCALE(ymm17, ymm1); - - a11 += 1; - - //extract a55 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); - - //(ROw5): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); - ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); - ymm18 = _mm256_fnmadd_pd(ymm2, ymm17, ymm18); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); - ymm14 = _mm256_fnmadd_pd(ymm2, ymm12, ymm14); - ymm19 = _mm256_fnmadd_pd(ymm2, ymm17, ymm19); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); - ymm15 = _mm256_fnmadd_pd(ymm2, ymm12, ymm15); - ymm20 = _mm256_fnmadd_pd(ymm2, ymm17, ymm20); - - //perform mul operation - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1); - ymm18 = DTRSM_SMALL_DIV_OR_SCALE(ymm18, ymm1); - - a11 += 1; - - //extract a66 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); - - //(ROw6): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); - ymm14 = _mm256_fnmadd_pd(ymm2, ymm13, ymm14); - ymm19 = _mm256_fnmadd_pd(ymm2, ymm18, ymm19); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); 
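        /*
           Scalar view of the substitution pattern used by this register-blocked
           solve, shown for one 8-row tile. This is an illustrative sketch only;
           a11_offdiag(k,p) is shorthand for the broadcast off-diagonal element
           of A11 and is not a symbol defined in this file:

               for (k = 0; k < 8; k++)
               {
                   // divide row k by the diagonal element of A11
                   // (or multiply by its pre-inverted value from d11_pack)
                   row[k] = row[k] / a11_diag(k);
                   for (p = k + 1; p < 8; p++)
                       row[p] -= a11_offdiag(k, p) * row[k];
               }

           Each _mm256_fnmadd_pd in this sequence performs one such row update
           across four columns of the b11 tile at a time.
        */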
- ymm15 = _mm256_fnmadd_pd(ymm2, ymm13, ymm15); - ymm20 = _mm256_fnmadd_pd(ymm2, ymm18, ymm20); - - //perform mul operation - ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); - ymm19 = DTRSM_SMALL_DIV_OR_SCALE(ymm19, ymm1); - - a11 += 1; - - //extract a77 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); - - //(ROw7): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); - ymm15 = _mm256_fnmadd_pd(ymm2, ymm14, ymm15); - ymm20 = _mm256_fnmadd_pd(ymm2, ymm19, ymm20); - - //perform mul operation - ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); - ymm20 = DTRSM_SMALL_DIV_OR_SCALE(ymm20, ymm1); - - a11 += 1; - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] - ymm3 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] - - ///unpack high/// - ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[4][1] B11[5][1] B11[4][3] B11[5][3] - ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[6][1] B11[7][1] B11[6][3] B11[7][3] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] - ymm3 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - _mm256_storeu_pd((double *)(b11 + 4), ymm0); //store B11[4][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm1); //store B11[5][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm2); //store B11[6][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 3 + 4), ymm3); //store B11[7][0-3] + //extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + //(row 1):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 
2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); - ///unpack high/// - ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store B11[1][0-3] + a11 += cs_a; - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm17, ymm18); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] - ymm3 = _mm256_unpacklo_pd(ymm19, ymm20); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] + //extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); - ///unpack high/// - ymm17 = _mm256_unpackhi_pd(ymm17, ymm18); //B11[4][1] B11[5][1] B11[4][3] B11[5][3] - ymm18 = _mm256_unpackhi_pd(ymm19, ymm20); //B11[6][1] B11[7][1] B11[6][3] B11[7][3] + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm17, ymm18, 0x20); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - _mm256_storeu_pd((double *)(b11 + cs_b * 4 + 4), ymm0); //store B11[4][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 5 + 4), ymm1); //store B11[5][0-3] - } + a11 += cs_a; - dim_t n_rem = n-j; - if(n_rem >= 4) - { - a10 = D_A_pack; - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM - - k_iter = i ; //number of times GEMM to be performed(in blocks of 4x4) - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - ymm12 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); - ymm14 = _mm256_setzero_pd(); - ymm15 = _mm256_setzero_pd(); + //extract a33 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15); + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - b01 += 1; //move to next row of B - a10 += D_MR; 
//pointer math to calculate next block of A for GEMM - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] - ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] - ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 4)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] - ymm7 = _mm256_loadu_pd((double const *)(b11 + cs_b *3 + 4)); //B11[0][7] B11[1][7] B11[2][7] B11[3][7] - - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] * alpha -= B01[0-3][3] - ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] - ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] * alpha -= B01[0-3][5] - ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); //B11[0-3][6] * alpha -= B01[0-3][6] - ymm7 = _mm256_fmsub_pd(ymm7, ymm16, ymm15); //B11[0-3][7] * alpha -= B01[0-3][7] - - ///implement TRSM/// - - ///transpose of B11// - ///unpacklow/// - ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - - ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[2][4] B11[2][5] - ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[2][6] B11[2][7] - - //rearrange low elements - ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] - - ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3] - ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3] - - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - - ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[1][4] B11[1][5] B11[3][4] B11[3][5] - ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[1][6] B11[1][7] B11[3][6] B11[3][7] - - //rearrange high elements - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3] - ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3] - - ymm0 = _mm256_broadcast_sd((double const *)&ones); - - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); + _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - //perform mul operation - 
ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); + m_remainder -= 4; + i += 4; + } - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + if(m_remainder == 3) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*1)); - ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); - ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); - ymm5 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); - ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - a11 += 1; + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS - //(ROw1): FMA operations - ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); - ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); - ymm11 = _mm256_fnmadd_pd(ymm4, ymm8, ymm11); - ymm12 = _mm256_fnmadd_pd(ymm5, ymm8, ymm12); - ymm13 = _mm256_fnmadd_pd(ymm6, ymm8, ymm13); - ymm14 = _mm256_fnmadd_pd(ymm7, ymm8, ymm14); - ymm15 = _mm256_fnmadd_pd(ymm16, ymm8, ymm15); + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_4nx4m(a01,b10,cs_b,p_lda,k_iter) - //perform mul operation - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); + ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); - ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); - ymm5 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); - ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - a11 += 1; + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - //(ROw2): FMA operations - ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); - ymm11 = _mm256_fnmadd_pd(ymm4, ymm9, ymm11); - ymm12 = _mm256_fnmadd_pd(ymm5, ymm9, ymm12); - ymm13 = _mm256_fnmadd_pd(ymm6, ymm9, ymm13); - ymm14 = _mm256_fnmadd_pd(ymm7, ymm9, ymm14); - ymm15 = _mm256_fnmadd_pd(ymm16, ymm9, ymm15); + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3 + 2)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); + ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 - //perform mul operation - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + ///implement TRSM/// - ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); - ymm5 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); - ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + 
//extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - a11 += 1; + //extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + //(row 1):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - //(ROw5): FMA operations - ymm11 = _mm256_fnmadd_pd(ymm4, ymm10, ymm11); - ymm12 = _mm256_fnmadd_pd(ymm5, ymm10, ymm12); - ymm13 = _mm256_fnmadd_pd(ymm6, ymm10, ymm13); - ymm14 = _mm256_fnmadd_pd(ymm7, ymm10, ymm14); - ymm15 = _mm256_fnmadd_pd(ymm16, ymm10, ymm15); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); - //perform mul operation - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); - ymm0 = _mm256_broadcast_sd((double const *)&ones); + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - //extract a44 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + a11 += cs_a; - ymm5 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); - ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + //extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - a11 += 1; + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); - //(ROw4): FMA operations - ymm12 = _mm256_fnmadd_pd(ymm5, ymm11, ymm12); - ymm13 = _mm256_fnmadd_pd(ymm6, ymm11, ymm13); - ymm14 = _mm256_fnmadd_pd(ymm7, ymm11, ymm14); - ymm15 = _mm256_fnmadd_pd(ymm16, ymm11, ymm15); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); - //perform mul operation - ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + a11 += cs_a; - a11 += 1; + //extract a33 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - //extract a55 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - //(ROw5): FMA operations - ymm13 = _mm256_fnmadd_pd(ymm6, ymm12, ymm13); - ymm14 = _mm256_fnmadd_pd(ymm7, ymm12, ymm14); - ymm15 = _mm256_fnmadd_pd(ymm16, ymm12, ymm15); + ymm0 = _mm256_loadu_pd((double const *)b11); + ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x07); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x07); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x07); + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3 + 2)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x07); - //perform mul operation - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1); + _mm256_storeu_pd((double 
*)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); + xmm5 = _mm256_extractf128_pd(ymm9, 0); + _mm_storeu_pd((double *)(b11 + cs_b * 3),xmm5); + _mm_storel_pd((b11 + cs_b * 3 + 2), _mm256_extractf128_pd(ymm9, 1)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 +cs_a*7)); + m_remainder -= 3; + i += 3; + } + else if(m_remainder == 2) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - a11 += 1; + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - //extract a66 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS - //(ROw6): FMA operations - ymm14 = _mm256_fnmadd_pd(ymm7, ymm13, ymm14); - ymm15 = _mm256_fnmadd_pd(ymm16, ymm13, ymm15); + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_4nx4m(a01,b10,cs_b,p_lda,k_iter) - //perform mul operation - ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - //extract a77 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); + ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - a11 += 1; - //(ROw7): FMA operations - ymm15 = _mm256_fnmadd_pd(ymm16, ymm14, ymm15); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - //perform mul operation - ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); + ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + ///implement TRSM/// - ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] - ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + //extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] + //(row 1):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm9 = _mm256_unpackhi_pd(ymm10, 
ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); - ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[4][1] B11[5][1] B11[4][3] B11[5][3] - ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[6][1] B11[7][1] B11[6][3] B11[7][3] + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] - ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] + a11 += cs_a; - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); //store B11[4][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); //store B11[5][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm6); //store B11[6][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 3 + 4), ymm7); //store B11[7][0-3] + //extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - n_rem -=4; - j +=4; + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); - } - if(n_rem) - { - a10 = D_A_pack; - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM - - k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm12 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); - ymm14 = _mm256_setzero_pd(); - - if(3 == n_rem) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + a11 += cs_a; - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + //extract a33 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - b01 += 1; //move to next row of B - a10 += D_MR; //pointer math to calculate next block of A for GEMM - } + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)(b11 + 
cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] - ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] - ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 4)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] - - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); - - ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] - ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] * alpha -= B01[0-3][5] - ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); //B11[0-3][6] * alpha -= B01[0-3][6] - ymm7 = _mm256_broadcast_sd((double const *)(&ones)); - } - else if(2 == n_rem) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + ymm0 = _mm256_loadu_pd((double const *)b11); + ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x03); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x03); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x03); + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); + ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x03); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); + xmm5 = _mm256_extractf128_pd(ymm9, 0); + _mm_storeu_pd((double *)(b11 + cs_b * 3),xmm5); - b01 += 1; //move to next row of B - a10 += D_MR; //pointer math to calculate next block of A for GEMM + m_remainder -= 2; + i += 2; } + else if(m_remainder == 1) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS - ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] - ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); //B11[0][5] B11[1][5] B11[2][5] 
B11[3][5] + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_4nx4m(a01,b10,cs_b,p_lda,k_iter) - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] - ymm2 = _mm256_broadcast_sd((double const *)(&ones)); - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] - ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] * alpha -= B01[0-3][5] - ymm6 = _mm256_broadcast_sd((double const *)(&ones)); - ymm7 = _mm256_broadcast_sd((double const *)(&ones)); - } - else if(1 == n_rem) - { - ///GEMM code begins/// + ymm0 = _mm256_broadcast_sd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); - b01 += 1; //move to next row of B - a10 += D_MR; //pointer math to calculate next block of A for GEMM - } + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 - ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] + ///implement TRSM/// - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_broadcast_sd((double const *)(&ones)); - ymm2 = _mm256_broadcast_sd((double const *)(&ones)); - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] - ymm5 = _mm256_broadcast_sd((double const *)(&ones)); - ymm6 = _mm256_broadcast_sd((double const *)(&ones)); - ymm7 = _mm256_broadcast_sd((double const *)(&ones)); - } - ///implement TRSM/// + //extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - ///transpose of B11// - ///unpacklow/// - ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + //(row 1):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[2][4] B11[2][5] - ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[2][6] B11[2][7] + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + ymm7 = 
_mm256_fnmadd_pd(ymm1, ymm3, ymm7); - //rearrange low elements - ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); - ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3] - ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3] + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + a11 += cs_a; - ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[1][4] B11[1][5] B11[3][4] B11[3][5] - ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[1][6] B11[1][7] B11[3][6] B11[3][7] + //extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - //rearrange high elements - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); - ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3] - ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3] + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); - ymm0 = _mm256_broadcast_sd((double const *)&ones); + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + a11 += cs_a; - //perform mul operation - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); + //extract a33 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*1)); - ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); - ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); - ymm5 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); - ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - a11 += 1; + ymm0 = _mm256_loadu_pd((double const *)b11); + ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x01); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x01); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x01); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x01); - //(ROw1): FMA operations - ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); - ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); - ymm11 = _mm256_fnmadd_pd(ymm4, ymm8, ymm11); - ymm12 = _mm256_fnmadd_pd(ymm5, ymm8, ymm12); - ymm13 = _mm256_fnmadd_pd(ymm6, ymm8, ymm13); - ymm14 = _mm256_fnmadd_pd(ymm7, ymm8, ymm14); - ymm15 
= _mm256_fnmadd_pd(ymm16, ymm8, ymm15); + _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm3, 0)); + _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm5, 0)); + _mm_storel_pd((b11 + cs_b * 2), _mm256_extractf128_pd(ymm7, 0)); + _mm_storel_pd((b11 + cs_b * 3), _mm256_extractf128_pd(ymm9, 0)); - //perform mul operation - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); + m_remainder -= 1; + i += 1; + } + j += 4; + n_remainder -= 4; + } - ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); - ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); - ymm5 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); - ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + if(n_remainder == 3) + { + a01 = L + j*rs_a; //pointer to block of A to be used in GEMM + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM - a11 += 1; + double *ptr_a10_dup = D_A_pack; - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + dim_t p_lda = j; // packed leading dimension + // perform copy of A to packed buffer D_A_pack - //(ROw2): FMA operations - ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); - ymm11 = _mm256_fnmadd_pd(ymm4, ymm9, ymm11); - ymm12 = _mm256_fnmadd_pd(ymm5, ymm9, ymm12); - ymm13 = _mm256_fnmadd_pd(ymm6, ymm9, ymm13); - ymm14 = _mm256_fnmadd_pd(ymm7, ymm9, ymm14); - ymm15 = _mm256_fnmadd_pd(ymm16, ymm9, ymm15); + if(transa) + { + for(dim_t x =0;x < p_lda;x+=d_nr) + { + ymm0 = _mm256_loadu_pd((double const *)(a01)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); + ymm2 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2)); + ymm3 = _mm256_loadu_pd((double const *)(a01 + cs_a * 3)); - //perform mul operation - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); - ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); - ymm5 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); - ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - a11 += 1; + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - //(ROw5): FMA operations - ymm11 = _mm256_fnmadd_pd(ymm4, ymm10, ymm11); - ymm12 = _mm256_fnmadd_pd(ymm5, ymm10, ymm12); - ymm13 = _mm256_fnmadd_pd(ymm6, ymm10, ymm13); - ymm14 = _mm256_fnmadd_pd(ymm7, ymm10, ymm14); - ymm15 = _mm256_fnmadd_pd(ymm16, ymm10, ymm15); + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); - //perform mul operation - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + ymm0 = _mm256_loadu_pd((double const *)(a01 + cs_a * 4)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a * 5)); - ymm0 = _mm256_broadcast_sd((double const *)&ones); + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_broadcast_sd((double const *)&zero); - //extract a44 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + 
ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - ymm5 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); - ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_broadcast_sd((double const *)&zero); - a11 += 1; + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - //(ROw4): FMA operations - ymm12 = _mm256_fnmadd_pd(ymm5, ymm11, ymm12); - ymm13 = _mm256_fnmadd_pd(ymm6, ymm11, ymm13); - ymm14 = _mm256_fnmadd_pd(ymm7, ymm11, ymm14); - ymm15 = _mm256_fnmadd_pd(ymm16, ymm11, ymm15); + _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_extractf128_pd(ymm6,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_extractf128_pd(ymm7,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm8,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm9,0)); - //perform mul operation - ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); + a01 += d_nr*cs_a; + ptr_a10_dup += d_nr; + } + } + else + { + dim_t loop_count = p_lda/4; - ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + for(dim_t x =0;x < loop_count;x++) + { + ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 0 + x*4)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + x*4), ymm15); + ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 1 + x*4)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + x*4), ymm15); + ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 2 + x*4)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 2 + x*4), ymm15); + } - a11 += 1; + dim_t remainder_loop_count = p_lda - loop_count*4; - //extract a55 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + __m128d xmm0; + if(remainder_loop_count != 0) + { + xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 0 + loop_count*4)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + loop_count*4), xmm0); + xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 1 + loop_count*4)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + loop_count*4), xmm0); + xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 2 + loop_count*4)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 2 + loop_count*4), xmm0); + } + } + ymm4 = _mm256_broadcast_sd((double const *)&ones); + if(!is_unitdiag) + { + if(transa) + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a*1 + 1)); + ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a*2 + 2)); + } + else + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11+ rs_a*1 + 1)); + ymm2 = _mm256_broadcast_sd((double const *)(a11+ rs_a*2 + 2)); + } + ymm3 = _mm256_broadcast_sd((double const *)&ones); - //(ROw5): FMA operations - ymm13 = _mm256_fnmadd_pd(ymm6, ymm12, ymm13); - ymm14 = _mm256_fnmadd_pd(ymm7, ymm12, ymm14); - ymm15 = _mm256_fnmadd_pd(ymm16, ymm12, ymm15); - - //perform mul operation - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1); + ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + 
cs_a*6)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 +cs_a*7)); + ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); + #ifdef BLIS_DISABLE_TRSM_PREINVERSION + ymm4 = ymm1; + #endif + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + ymm4 = _mm256_div_pd(ymm4, ymm1); + #endif + } + _mm256_storeu_pd((double *)(d11_pack), ymm4); - a11 += 1; + for(i = 0; (i+d_mr-1) < m; i += d_mr) //loop along 'M' direction + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - //extract a66 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS - //(ROw6): FMA operations - ymm14 = _mm256_fnmadd_pd(ymm7, ymm13, ymm14); - ymm15 = _mm256_fnmadd_pd(ymm16, ymm13, ymm15); + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_3nx8m(a01,b10,cs_b,p_lda,k_iter) - //perform mul operation - ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); - //extract a77 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); + ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + 4)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + ymm4 = _mm256_fmsub_pd(ymm1, ymm15, ymm4); //B11[4-7][0] * alpha-= ymm1 - a11 += 1; - //(ROw7): FMA operations - ymm15 = _mm256_fnmadd_pd(ymm16, ymm14, ymm15); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b + 4)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] - //perform mul operation - ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + ymm6 = _mm256_fmsub_pd(ymm1, ymm15, ymm6); //B11[4-7][1] * alpha -= ymm3 - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*2 + 4)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] - ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] - ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] + ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 + ymm8 = _mm256_fmsub_pd(ymm1, ymm15, ymm8); //B11[4-7][2] * alpha -= ymm5 - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ///implement TRSM/// - ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] 
B11[3][1] B11[2][3] B11[3][3] + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); - ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[4][1] B11[5][1] B11[4][3] B11[5][3] - ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[6][1] B11[7][1] B11[6][3] B11[7][3] + //extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + //(row 1):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); - ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] - ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] + ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); + ymm6 = _mm256_fnmadd_pd(ymm1, ymm4, ymm6); - if(3 == n_rem) - { - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); //store B11[4][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); //store B11[5][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm6); //store B11[6][0-3] - } - else if(2 == n_rem) - { - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); //store B11[4][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); //store B11[5][0-3] - } - else if(1 == n_rem) - { - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); //store B11[4][0-3] - } - } + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); - } - //======================M remainder cases================================ - dim_t m_rem = m-i; - if(m_rem>=4) //implementation for reamainder rows(when 'M' is not a multiple of D_MR) - { - a10 = L + (i*cs_a); //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); - double *ptr_a10_dup = D_A_pack; - dim_t p_lda = 4; // packed leading dimension - for(dim_t x =0;x < i;x+=4) - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm1 = _mm256_loadu_pd((double const *)(a10 + cs_a)); - ymm2 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); - ymm3 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); + ymm8 = _mm256_fnmadd_pd(ymm1, ymm4, ymm8); - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm0); - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + a11 += cs_a; - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + //extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); - _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); - 
_mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); + ymm8 = _mm256_fnmadd_pd(ymm1, ymm6, ymm8); - a10 += 4; - ptr_a10_dup += 4*4; - } + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm0); - ymm4 = _mm256_broadcast_sd((double const *)&ones); - if(!is_unitdiag) - { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_sd((double const *)(a11)); - ymm1 = _mm256_broadcast_sd((double const *)(a11+cs_a*1 + 1)); - ymm2 = _mm256_broadcast_sd((double const *)(a11+cs_a*2 + 2)); - ymm3 = _mm256_broadcast_sd((double const *)(a11+cs_a*3 + 3)); + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + 4), ymm4); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm256_storeu_pd((double *)(b11 + cs_b + 4), ymm6); + _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); + _mm256_storeu_pd((double *)(b11 + cs_b*2 + 4), ymm8); + } - ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); - ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); - #ifdef BLIS_DISABLE_TRSM_PREINVERSION - ymm4 = ymm1; - #endif - #ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm4 = _mm256_div_pd(ymm4, ymm1); - #endif - } - _mm256_storeu_pd((double *)(d11_pack), ymm4); + dim_t m_remainder = m - i; + if(m_remainder >= 4) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - for(j = 0; (j+D_NR-1) < n; j += D_NR) //loop along 'N' dimension - { - a10 = D_A_pack; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM - b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - k_iter = i; //number of times GEMM operation to be done(in blocks of 4x4) + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); - #endif + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm0 = 
_mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ///implement TRSM/// + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + //extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); - ymm4 = _mm256_fmadd_pd(ymm2, ymm0, ymm4); + //(row 1):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); - b01 += 1; //move to next row of B - a10 += p_lda; //pointer math to calculate next block of A for GEMM - } + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + a11 += cs_a; - ///implement TRSM/// - ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + //extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] * alpha -= B01[0-3][3] + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); - ///transpose of B11// - ///unpacklow/// - ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - //rearrange low elements - ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + m_remainder -= 4; + i += 4; + } - //rearrange high elements - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + if(m_remainder == 3) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be 
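/*
 * The broadcast/fnmadd/scale sequence above is a 3x3 forward substitution
 * applied to every row of the current B11 tile: column 0 is scaled by the
 * first packed diagonal entry, then columns 1 and 2 are updated with the
 * sub-diagonal entries a(1,0), a(2,0) and a(2,1) of the effective
 * triangular factor.  A scalar sketch of the same step (helper name is
 * illustrative; d11_pack is assumed to hold the pre-inverted diagonal,
 * i.e. the BLIS_ENABLE_TRSM_PREINVERSION path):
 */
static inline void dtrsm_small_3col_solve_sketch
     (
       const double* a11, dim_t rs_a, dim_t cs_a,
       double*       b11, dim_t cs_b, dim_t m_chunk,
       const double* d11_pack
     )
{
    for ( dim_t ii = 0; ii < m_chunk; ii++ )
    {
        double x0 =   b11[ii]          * d11_pack[0];
        double x1 = ( b11[ii +   cs_b] - a11[1*rs_a]          * x0 ) * d11_pack[1];
        double x2 =   b11[ii + 2*cs_b] - a11[2*rs_a]          * x0;
               x2 = ( x2               - a11[1*cs_a + 2*rs_a] * x1 ) * d11_pack[2];
        b11[ii] = x0;  b11[ii + cs_b] = x1;  b11[ii + 2*cs_b] = x2;
    }
}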
used for TRSM + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) - ymm16 = _mm256_broadcast_sd((double const *)(&ones)); + BLIS_PRE_DTRSM_SMALL_3N_3M(AlphaVal,b11,cs_b) - ////unpacklow//// - ymm7 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - //ymm16; + ///implement TRSM/// - //rearrange low elements - ymm4 = _mm256_permute2f128_pd(ymm7,ymm16,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm6 = _mm256_permute2f128_pd(ymm7,ymm16,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - //ymm16; + //extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - //rearrange high elements - ymm5 = _mm256_permute2f128_pd(ymm0,ymm16,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm7 = _mm256_permute2f128_pd(ymm0,ymm16,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - //b11 transpose end + //(row 1):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - ////extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); - //perform mul operation - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); - ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm1); + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + a11 += cs_a; - //(ROw1): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*1)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); - ymm5 = _mm256_fnmadd_pd(ymm2, ymm4, ymm5); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm8, ymm10); - ymm6 = _mm256_fnmadd_pd(ymm2, ymm4, ymm6); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); - ymm7 = _mm256_fnmadd_pd(ymm2, ymm4, ymm7); + //extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); - //perform mul operation - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm1); + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - a11 += 1; + BLIS_POST_DTRSM_SMALL_3N_3M(b11,cs_b) - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + m_remainder -= 3; + i += 3; + } + else if(m_remainder == 2) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - //(ROw2): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm9, ymm10); - ymm6 = _mm256_fnmadd_pd(ymm2, ymm5, ymm6); - ymm2 = _mm256_broadcast_sd((double const 
*)(a11 + cs_a*3)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm9, ymm11); - ymm7 = _mm256_fnmadd_pd(ymm2, ymm5, ymm7); + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - //perform mul operation - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); - ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm1); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS - a11 += 1; + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + BLIS_PRE_DTRSM_SMALL_3N_2M(AlphaVal,b11,cs_b) - //(ROw5): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm10, ymm11); - ymm7 = _mm256_fnmadd_pd(ymm2, ymm6, ymm7); + ///implement TRSM/// - //perform mul operation - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm1); + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - a11 += 1; + //extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - - ///unpack high/// - ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - - _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store B11[1][0-3] - } - dim_t n_rem = n-j; - if(n_rem >= 4) - { - a10 = D_A_pack; - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM - - k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), 
_MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - #endif - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); + //(row 1):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + a11 += cs_a; - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + //extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); - b01 += 1; //move to next row of B - a10 += p_lda; //pointer math to calculate next block of A for GEMM - } - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - ///implement TRSM/// + BLIS_POST_DTRSM_SMALL_3N_2M(b11,cs_b) - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + m_remainder -= 2; + i += 2; + } + else if(m_remainder == 1) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - ///transpose of B11// - ///unpacklow/// - ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - //rearrange low elements - ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) - //rearrange high elements - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + BLIS_PRE_DTRSM_SMALL_3N_1M(AlphaVal,b11,cs_b) - ymm0 = _mm256_broadcast_sd((double const *)&ones); + ///implement TRSM/// - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + //extract 
a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - //perform mul operation - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); + //extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + //(row 1):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*1)); - ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); - ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); - a11 += 1; + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - //(ROw1): FMA operations - ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); - ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); - ymm11 = _mm256_fnmadd_pd(ymm4, ymm8, ymm11); + a11 += cs_a; - //perform mul operation - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); + //extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); - ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); - a11 += 1; + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + BLIS_POST_DTRSM_SMALL_3N_1M(b11,cs_b) - //(ROw2): FMA operations - ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); - ymm11 = _mm256_fnmadd_pd(ymm4, ymm9, ymm11); + m_remainder -= 1; + i += 1; + } + j += 3; + n_remainder -= 3; + } + else if(n_remainder == 2) + { + a01 = L + j*rs_a; //pointer to block of A to be used in GEMM + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM - //perform mul operation - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + double *ptr_a10_dup = D_A_pack; - ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + dim_t p_lda = j; // packed leading dimension + // perform copy of A to packed buffer D_A_pack - a11 += 1; + if(transa) + { + for(dim_t x =0;x < p_lda;x+=d_nr) + { + ymm0 = _mm256_loadu_pd((double const *)(a01)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); + ymm2 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2)); + ymm3 = _mm256_loadu_pd((double const *)(a01 + cs_a * 3)); - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); - //(ROw5): FMA operations - ymm11 = _mm256_fnmadd_pd(ymm4, ymm10, ymm11); + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - //perform mul operation - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + 
_mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + ymm0 = _mm256_loadu_pd((double const *)(a01 + cs_a * 4)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a * 5)); - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_broadcast_sd((double const *)&zero); - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - n_rem -= 4; - j += 4; - } - if(n_rem) - { - a10 = D_A_pack; - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_broadcast_sd((double const *)&zero); - k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); + _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_extractf128_pd(ymm6,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_extractf128_pd(ymm7,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm8,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm9,0)); - if(3 == n_rem) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + a01 += d_nr*cs_a; + ptr_a10_dup += d_nr; + } + } + else { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + dim_t loop_count = p_lda/4; - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + for(dim_t x =0;x < loop_count;x++) + { + ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 0 + x*4)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + x*4), ymm15); + ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 1 + x*4)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + x*4), ymm15); + } - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + dim_t remainder_loop_count = p_lda - loop_count*4; - b01 += 1; //move to next row of B - a10 += p_lda; //pointer math to calculate next block of A for GEMM + __m128d xmm0; + if(remainder_loop_count != 0) + { + xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 0 + loop_count*4)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + loop_count*4), xmm0); + xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 1 + loop_count*4)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + loop_count*4), xmm0); + } } - 
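/*
 * The non-transpose packing branch above is a plain strided copy: each row
 * of the A01 panel contributes p_lda contiguous doubles, moved 4 at a time
 * with a 2-wide tail, into the packed buffer whose leading dimension is
 * p_lda.  A scalar sketch of the same copy (helper name illustrative):
 */
static inline void dtrsm_small_pack_copy_sketch
     (
       const double* a01,  dim_t rs_a,
       double*       pack, dim_t p_lda,
       dim_t         nrows     // rows being packed (n_remainder here)
     )
{
    for ( dim_t r = 0; r < nrows; r++ )
        for ( dim_t e = 0; e < p_lda; e++ )
            pack[e + r*p_lda] = a01[e + r*rs_a];
}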
ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); - } - else if(2 == n_rem) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + ymm4 = _mm256_broadcast_sd((double const *)&ones); + if(!is_unitdiag) { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + if(transa) + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11+cs_a*1 + 1)); + } + else + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11+rs_a*1 + 1)); + } + ymm2 = _mm256_broadcast_sd((double const *)&ones); + ymm3 = _mm256_broadcast_sd((double const *)&ones); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); - b01 += 1; //move to next row of B - a10 += p_lda; //pointer math to calculate next block of A for GEMM + ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); + #ifdef BLIS_DISABLE_TRSM_PREINVERSION + ymm4 = ymm1; + #endif + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + ymm4 = _mm256_div_pd(ymm4, ymm1); + #endif } + _mm256_storeu_pd((double *)(d11_pack), ymm4); - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] - ymm2 = _mm256_broadcast_sd((double const *)(&ones)); - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); - } - else if(1 == n_rem) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + for(i = 0; (i+d_mr-1) < m; i += d_mr) //loop along 'M' direction { - ymm0 = _mm256_loadu_pd((double const *)(a10)); + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - b01 += 1; //move to next row of B - a10 += p_lda; //pointer math to calculate next block of A for GEMM - } + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + ///GEMM implementation starts/// 
+ BLIS_DTRSM_SMALL_GEMM_2nx8m(a01,b10,cs_b,p_lda,k_iter) - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_broadcast_sd((double const *)(&ones)); - ymm2 = _mm256_broadcast_sd((double const *)(&ones)); - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); - } + ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + 4)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - ///transpose of B11// - ///unpacklow/// - ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + ymm4 = _mm256_fmsub_pd(ymm1, ymm15, ymm4); //B11[4-7][0] * alpha-= ymm1 - //rearrange low elements - ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b + 4)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + ymm6 = _mm256_fmsub_pd(ymm1, ymm15, ymm6); //B11[4-7][1] * alpha -= ymm3 - //rearrange high elements - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + ///implement TRSM/// - ymm0 = _mm256_broadcast_sd((double const *)&ones); + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ////extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); - //perform mul operation - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); + //extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + //(row 1):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*1)); - ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); - ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); + ymm6 = _mm256_fnmadd_pd(ymm1, ymm4, ymm6); - a11 += 1; + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm0); - //(ROw1): FMA operations - ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); - ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); - ymm11 = _mm256_fnmadd_pd(ymm4, ymm8, ymm11); + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + 4), ymm4); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm256_storeu_pd((double *)(b11 + cs_b + 4), ymm6); + } - //perform mul operation - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); + dim_t m_remainder = m - i; + if(m_remainder >= 4) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to 
be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); - ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - a11 += 1; + ymm3 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) - //(ROw2): FMA operations - ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); - ymm11 = _mm256_fnmadd_pd(ymm4, ymm9, ymm11); + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - //perform mul operation - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - a11 += 1; + ///implement TRSM/// - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - //(ROw5): FMA operations - ymm11 = _mm256_fnmadd_pd(ymm4, ymm10, ymm11); + //extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - //perform mul operation - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + //(row 1):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + m_remainder -= 4; + i += 4; + } - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + if(m_remainder == 3) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - if(3 == n_rem) - { - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] - } - else if(2 == n_rem) - { - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - } - else if(1 == n_rem) - { - _mm256_storeu_pd((double *)(b11 + 
cs_b * 0), ymm0); //store B11[0][0-3] - } + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - } - m_rem -=4; - i +=4; - } - - if(m_rem) - { - a10 = L + (i*cs_a); //pointer to block of A to be used for GEMM - // Do transpose for a10 & store in D_A_pack - double *ptr_a10_dup = D_A_pack; - if(3 == m_rem) // Repetative A blocks will be 3*3 - { - dim_t p_lda = 4; // packed leading dimension - for(dim_t x=0;x= 4)) - { - a10 = D_A_pack; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM - b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM - - k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - #endif - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm3 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + BLIS_PRE_DTRSM_SMALL_2N_2M(AlphaVal,b11,cs_b) - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ///implement TRSM/// + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + //extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - b01 += 1; //move to next row of B - a10 += 4; //pointer math to calculate next block of A for GEMM - } + //(row 1):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); - ymm3 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3 + 2)); - ymm3 = _mm256_insertf128_pd(ymm3, xmm5, 0); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x08); - - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - xmm5 = _mm256_extractf128_pd(ymm3, 0); - 
_mm_storeu_pd((double *)(b11 + cs_b * 3),xmm5); - _mm_storel_pd((b11 + cs_b * 3 + 2), _mm256_extractf128_pd(ymm3, 1)); - - dtrsm_AutXB_ref(a11, b11, m_rem, 4, cs_a, cs_b,is_unitdiag); - n_rem -= 4; - j +=4; - } - - if(n_rem) - { - a10 = D_A_pack; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM - b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM - - k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - #endif + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); + BLIS_POST_DTRSM_SMALL_2N_2M(b11,cs_b) - if(3 == n_rem) + m_remainder -= 2; + i += 2; + } + else if(m_remainder == 1) { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ymm3 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) - b01 += 1; //move to next row of B - a10 += 4; //pointer math to calculate next block of A for GEMM - } + BLIS_PRE_DTRSM_SMALL_2N_1M(AlphaVal,b11,cs_b) - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + ///implement TRSM/// - ///implement TRSM/// + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2)); - ymm2 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 2 + 2)); - ymm2 = _mm256_insertf128_pd(ymm2, xmm5, 0); + //extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + //(row 1):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - xmm5 = _mm256_extractf128_pd(ymm2, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 2), xmm5); - _mm_storel_pd((b11 + cs_b * 2 + 2), _mm256_extractf128_pd(ymm2, 1)); + BLIS_POST_DTRSM_SMALL_2N_1M(b11,cs_b) - dtrsm_AutXB_ref(a11, b11, m_rem, 3, cs_a, 
cs_b,is_unitdiag); + m_remainder -= 1; + i += 1; } - else if(2 == n_rem) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); + j += 2; + n_remainder -= 2; + } + else if(n_remainder == 1) + { + a01 = L + j*rs_a; //pointer to block of A to be used in GEMM + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + double *ptr_a10_dup = D_A_pack; - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + dim_t p_lda = j; // packed leading dimension + // perform copy of A to packed buffer D_A_pack - b01 += 1; //move to next row of B - a10 += 4; //pointer math to calculate next block of A for GEMM - } + if(transa) + { + for(dim_t x =0;x < p_lda;x+=d_nr) + { + ymm0 = _mm256_loadu_pd((double const *)(a01)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); + ymm2 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2)); + ymm3 = _mm256_loadu_pd((double const *)(a01 + cs_a * 3)); - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); - ///implement TRSM/// - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1)); - ymm1 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 1 + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); - - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - xmm5 = _mm256_extractf128_pd(ymm1, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 1), xmm5); - _mm_storel_pd((b11 + cs_b * 1 + 2), _mm256_extractf128_pd(ymm1, 1)); + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - dtrsm_AutXB_ref(a11, b11, m_rem, 2, cs_a, cs_b,is_unitdiag); - } - else if(1 == n_rem) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - b01 += 1; //move to next row of B - a10 += 4; //pointer math to calculate next block of A for GEMM - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 0)); - ymm0 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 0 + 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - - xmm5 = _mm256_extractf128_pd(ymm0, 0); - _mm_storeu_pd((double *)(b11), xmm5); - _mm_storel_pd((b11 + 2), _mm256_extractf128_pd(ymm0, 1)); - - dtrsm_AutXB_ref(a11, b11, m_rem, 1, cs_a, cs_b, is_unitdiag); - } - } - } - else if(2 == m_rem) // Repetative A blocks will be 2*2 - { - dim_t p_lda = 4; // packed leading dimension - for(dim_t x=0;x= 4)) - { - a10 = D_A_pack; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM - b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM - - k_iter = i; //number 
of times GEMM to be performed(in blocks of 4x4) - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); - #endif - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + else { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + dim_t loop_count = p_lda/4; - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + for(dim_t x =0;x < loop_count;x++) + { + ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 0 + x*4)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + x*4), ymm15); + } - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + dim_t remainder_loop_count = p_lda - loop_count*4; - b01 += 1; //move to next row of B - a10 += 4; //pointer math to calculate next block of A for GEMM + __m128d xmm0; + if(remainder_loop_count != 0) + { + xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 0 + loop_count*4)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + loop_count*4), xmm0); + } } - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); - ymm3 = _mm256_insertf128_pd(ymm3, xmm5, 0); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0C); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0C); - - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - xmm5 = _mm256_extractf128_pd(ymm3, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 3), xmm5); - - dtrsm_AutXB_ref(a11, b11, m_rem, 4, cs_a, cs_b, is_unitdiag); - n_rem -= 4; - j +=4; - } - if(n_rem) - { - a10 = D_A_pack; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM - b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM - - k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - #endif - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - - if(3 == n_rem) + ymm4 = 
_mm256_broadcast_sd((double const *)&ones); + if(!is_unitdiag) { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - - b01 += 1; //move to next row of B - a10 += 4; //pointer math to calculate next block of A for GEMM - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2)); - ymm2 = _mm256_insertf128_pd(ymm2, xmm5, 0); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0C); + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)&ones); + ymm2 = _mm256_broadcast_sd((double const *)&ones); + ymm3 = _mm256_broadcast_sd((double const *)&ones); - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - xmm5 = _mm256_extractf128_pd(ymm2, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 2), xmm5); + ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); - dtrsm_AutXB_ref(a11, b11, m_rem, 3, cs_a, cs_b, is_unitdiag); + ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); + #ifdef BLIS_DISABLE_TRSM_PREINVERSION + ymm4 = ymm1; + #endif + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + ymm4 = _mm256_div_pd(ymm4, ymm1); + #endif } - else if(2 == n_rem) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - b01 += 1; //move to next row of B - a10 += 4; //pointer math to calculate next block of A for GEMM - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + _mm256_storeu_pd((double *)(d11_pack), ymm4); - ///implement TRSM/// - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); + for(i = 0; (i+d_mr-1) < m; i += d_mr) //loop along 'M' direction + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); + ymm3 = _mm256_setzero_pd(); + ymm4 = _mm256_setzero_pd(); + ///GEMM implementation starts/// + 
BLIS_DTRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter) - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - xmm5 = _mm256_extractf128_pd(ymm1, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 1), xmm5); + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); - dtrsm_AutXB_ref(a11, b11, m_rem, 2, cs_a, cs_b, is_unitdiag); - } - else if(1 == n_rem) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - b01 += 1; //move to next row of B - a10 += 4; //pointer math to calculate next block of A for GEMM - } + ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + 4)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + ymm4 = _mm256_fmsub_pd(ymm1, ymm15, ymm4); //B11[4-7][0] * alpha-= ymm1 - ///implement TRSM/// - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 0)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); + ///implement TRSM/// - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - xmm5 = _mm256_extractf128_pd(ymm0, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 0), xmm5); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); - dtrsm_AutXB_ref(a11, b11, m_rem, 1, cs_a, cs_b, is_unitdiag); + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + 4), ymm4); } - } - m_rem -=2; - i+=2; - } - else if(1 == m_rem) // Repetative A blocks will be 1*1 - { - dim_t p_lda = 4; // packed leading dimension - for(dim_t x=0;x= 4) { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); - ymm4 = _mm256_fmadd_pd(ymm2, ymm0, ymm4); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - b01 += 1; //move to next row of B - a10 += p_lda; //pointer math to calculate next block of A for GEMM - } + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - ///GEMM code ends/// - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to store alpha value - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - 
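/*
 * With a single remaining column the triangular factor reduces to the lone
 * diagonal entry, so after the GEMM update the solve collapses to one scale
 * per row.  Scalar sketch (helper name illustrative; d11_pack[0] assumed to
 * hold the pre-inverted diagonal as above):
 */
static inline void dtrsm_small_1col_solve_sketch
     (
       double* b11, dim_t m_chunk, const double* d11_pack
     )
{
    for ( dim_t ii = 0; ii < m_chunk; ii++ )
        b11[ii] *= d11_pack[0];   // b11 already holds alpha*B11 - A01*B10
}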
ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0E); - - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); - - _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store(B11[0-3][3]) - - dtrsm_AutXB_ref(a11, b11, m_rem, 6, cs_a, cs_b, is_unitdiag); - } - dim_t n_rem = n-j; - if((n_rem >= 4)) - { - a10 = D_A_pack; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM - b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM - - k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); - #endif + ymm3 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + ///implement TRSM/// - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + _mm256_storeu_pd((double *)b11, ymm3); - b01 += 1; //move to next row of B - a10 += 4; //pointer math to calculate next block of A for GEMM + m_remainder -= 4; + i += 4; } - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_broadcast_sd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_broadcast_sd((double const *)(b11 + cs_b *2)); - ymm3 = _mm256_broadcast_sd((double const *)(b11 + cs_b *3)); - - ymm8 = 
_mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0E); - - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm0, 0)); - _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm1, 0)); - _mm_storel_pd((b11 + cs_b * 2), _mm256_extractf128_pd(ymm2, 0)); - _mm_storel_pd((b11 + cs_b * 3), _mm256_extractf128_pd(ymm3, 0)); - - dtrsm_AutXB_ref(a11, b11, m_rem, 4, cs_a, cs_b, is_unitdiag); - n_rem -= 4; - j+=4; - } - - if(n_rem) - { - a10 = D_A_pack; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM - b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM - - k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - #endif - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - - if(3 == n_rem) + if(m_remainder == 3) { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - b01 += 1; //move to next row of B - a10 += 4; //pointer math to calculate next block of A for GEMM - } + ymm3 = _mm256_setzero_pd(); - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) - ///implement TRSM/// - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_broadcast_sd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_broadcast_sd((double const *)(b11 + cs_b *2)); + BLIS_PRE_DTRSM_SMALL_1N_3M(AlphaVal,b11,cs_b) - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ///implement TRSM/// - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm0, 0)); - _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm1, 0)); - _mm_storel_pd((b11 + cs_b * 2), _mm256_extractf128_pd(ymm2, 0)); + BLIS_POST_DTRSM_SMALL_1N_3M(b11,cs_b) - dtrsm_AutXB_ref(a11, b11, m_rem, 3, cs_a, cs_b, is_unitdiag); + m_remainder -= 3; + i += 3; } - else if(2 == n_rem) + else if(m_remainder == 2) { - 
///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - b01 += 1; //move to next row of B - a10 += 4; //pointer math to calculate next block of A for GEMM - } + ymm3 = _mm256_setzero_pd(); - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) - ///implement TRSM/// - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_broadcast_sd((double const *)(b11 + cs_b *1)); + BLIS_PRE_DTRSM_SMALL_1N_2M(AlphaVal,b11,cs_b) - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ///implement TRSM/// - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm0, 0)); - _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm1, 0)); + BLIS_POST_DTRSM_SMALL_1N_2M(b11,cs_b) - dtrsm_AutXB_ref(a11, b11, m_rem, 2, cs_a, cs_b, is_unitdiag); + m_remainder -= 2; + i += 2; } - else if(1 == n_rem) + else if(m_remainder == 1) { - ///GEMM code begins/// - - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - b01 += 1; //move to next row of B - a10 += 4; //pointer math to calculate next block of A for GEMM - } + ymm3 = _mm256_setzero_pd(); - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) - ///implement TRSM/// + BLIS_PRE_DTRSM_SMALL_1N_1M(AlphaVal,b11,cs_b) - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b *0)); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ///implement TRSM/// - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm0, 0)); + BLIS_POST_DTRSM_SMALL_1N_1M(b11,cs_b) - dtrsm_AutXB_ref(a11, b11, m_rem, 1, cs_a, cs_b, is_unitdiag); + m_remainder -= 1; + i += 1; } - } - m_rem -=1; - i+=1; + j += 1; + n_remainder -= 1; } - } - if ((required_packing_A == 1) && - bli_mem_is_alloc( &local_mem_buf_A_s )) - { - bli_membrk_release(&rntm, &local_mem_buf_A_s); - } - return BLIS_SUCCESS; + if ((required_packing_A == 1) && bli_mem_is_alloc( &local_mem_buf_A_s )) + { + bli_membrk_release(&rntm, + 
&local_mem_buf_A_s); + } + + return BLIS_SUCCESS; } -/* TRSM for the case AX = alpha * B, Double precision - * A is lower-triangular, transpose, non-unit diagonal - * dimensions A: mxm X: mxn B: mxn +/*implements TRSM for the case XA = alpha * B + *A is upper triangular, non-unit diagonal/unit diagonal, transpose + *dimensions: X:mxn A:nxn B: mxn + * + * <---b11 <---a11 + ***************** * + *b01*b11* * * * * + ^ * * * * * ^ * * + | ***************** | ******* + | * * * * * | * * * + | * * * * * a01* * * +b10 ***************** ************* + * * * * * * * * * + * * * * * * * * * + ***************** ******************* + + *implements TRSM for the case XA = alpha * B + *A is lower triangular, non-unit diagonal/unit diagonal, no transpose + *dimensions: X:mxn A:nxn B: mxn + * + * <---b11 <---a11 + ***************** * + *b01*b11* * * * * + ^ * * * * * ^ * * + | ***************** | ******* + | * * * * * | * * * + | * * * * * a01* * * +b10 ***************** ************* + * * * * * * * * * + * * * * * * * * * + ***************** ******************* + */ -BLIS_INLINE err_t bli_dtrsm_small_AltXB +BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB ( obj_t* AlphaObj, obj_t* a, @@ -6300,12152 +4371,924 @@ BLIS_INLINE err_t bli_dtrsm_small_AltXB cntl_t* cntl ) { - dim_t m = bli_obj_length(b); // number of rows of matrix B - dim_t n = bli_obj_width(b); // number of columns of matrix B - - dim_t cs_a = bli_obj_col_stride(a); // column stride of A - dim_t cs_b = bli_obj_col_stride(b); // column stride of B - - dim_t i, j, k; //loop variables - dim_t k_iter; //number of times GEMM to be performed - - double AlphaVal = *(double *)AlphaObj->buffer; //value of alpha - double *L = a->buffer; //pointer to matrix A - double *B = b->buffer; //pointer to matrix B - - //pointers that point to blocks for GEMM and TRSM - double *a10, *a11, *b01, *b11; - - double ones = 1.0; - bool is_unitdiag = bli_obj_has_unit_diag(a); - - //scratch registers - __m256d ymm0, ymm1, ymm2, ymm3; - __m256d ymm4, ymm5, ymm6, ymm7; - __m256d ymm8, ymm9, ymm10, ymm11; - __m256d ymm12, ymm13, ymm14, ymm15; - __m256d ymm16, ymm17, ymm18, ymm19; - __m256d ymm20; - - __m128d xmm5; - - gint_t required_packing_A = 1; - mem_t local_mem_buf_A_s = {0}; - double *D_A_pack = NULL; - double d11_pack[D_MR] __attribute__((aligned(64))); - rntm_t rntm; - - bli_rntm_init_from_global( &rntm ); - bli_rntm_set_num_threads_only( 1, &rntm ); - bli_membrk_rntm_set_membrk( &rntm ); - - siz_t buffer_size = bli_pool_block_size( - bli_membrk_pool( - bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), - bli_rntm_membrk(&rntm))); - - if( ( D_MR * m * sizeof(double)) > buffer_size) - return BLIS_NOT_YET_IMPLEMENTED; - - if(required_packing_A == 1) - { - // Get the buffer from the pool. - bli_membrk_acquire_m(&rntm, - buffer_size, - BLIS_BITVAL_BUFFER_FOR_A_BLOCK, - &local_mem_buf_A_s); - - D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); - if(NULL==D_A_pack) return BLIS_NULL_POINTER; - } - - /* - Performs solving TRSM for 8 colmns at a time from 0 to m/D_MR in steps of D_MR - a. Load, transpose, Pack A (a10 block), the size of packing 8x6 to 8x (m-D_MR) - First there will be no GEMM and no packing of a10 because it is only TRSM - b. Using packed a10 block and b01 block perform GEMM operation - c. Use GEMM outputs, perform TRSM operaton using a11, b11 and update B - d. 
Repeat b,c for n rows of B in steps of D_NR - */ - for(i = (m - D_MR); (i + 1) > 0; i -= D_MR) - { - a10 = L + (i*cs_a) + i + D_MR; //pointer to block of A to be used for GEMM - a11 = L + (i*cs_a) + i; //pointer to block of A to be used for TRSM - - // Do transpose for a10 & store in D_A_pack - //ptr_a10_dup = D_A_pack; - - dim_t p_lda = D_MR; // packed leading dimension - /* - Load, transpose and pack current A block (a10) into packed buffer memory D_A_pack - a. This a10 block is used in GEMM portion only and this - a10 block size will be increasing by D_MR for every next itteration - untill it reaches 8x(m-8) which is the maximum GEMM alone block size in A - b. This packed buffer is reused to calculate all n rows of B matrix - */ - bli_dtrsm_small_pack('L', (m-i-D_MR), 1, a10, cs_a, D_A_pack,p_lda); + dim_t m = bli_obj_length(b); //number of rows + dim_t n = bli_obj_width(b); //number of columns - /* - Pack 8 diagonal elements of A block into an array - a. This helps in utilze cache line efficiently in TRSM operation - b. store ones when input is unit diagonal - */ - dtrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,D_MR); + bool transa = bli_obj_has_trans(a); + dim_t cs_a, rs_a; + dim_t d_mr = 8,d_nr = 6; - /* - a. Perform GEMM using a10, b01. - b. Perform TRSM on a11, b11 - c. This loop GEMM+TRSM loops operates with 8x6 block size - along n dimension for every D_NR rows of b01 where - packed A buffer is reused in computing all n rows of B. - d. Same approch is used in remaining fringe cases. - */ - for(j = (n - D_NR); (j + 1) > 0; j -= D_NR) + // Swap rs_a & cs_a in case of non-tranpose. + if(transa) { - a10 = D_A_pack; - b01 = B + (j * cs_b) + i + D_MR; //pointer to block of B to be used for GEMM - b11 = B + (j * cs_b) + i; //pointer to block of B to be used for TRSM - - k_iter = (m - i - D_MR); - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); - #endif + cs_a = bli_obj_col_stride(a); // column stride of A + rs_a = bli_obj_row_stride(a); // row stride of A + } + else + { + cs_a = bli_obj_row_stride(a); // row stride of A + rs_a = bli_obj_col_stride(a); // column stride of A + } + dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS + dim_t i, j, k; //loop variablse + dim_t k_iter; //determines the number of GEMM operations to be done - /* - Peform GEMM between a10 and b01 blocks - For first itteration there will be no GEMM operation - where k_iter are zero - */ - BLIS_DTRSM_SMALL_GEMM_8x6(a10,b01,cs_b,p_lda,k_iter) + double ones = 1.0; + double zero = 0.0; + bool is_unitdiag = bli_obj_has_unit_diag(a); - /* - Load b11 of size 6x8 and multiply with alpha - Add the GEMM output and perform inregister transose of b11 - to peform TRSM operation. 
- */ + double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha + double* restrict L = a->buffer; //pointer to matrix A + double* restrict B = b->buffer; //pointer to matrix B - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - - ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 4)); - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3 + 4)); - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm12); - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm13); - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm14); - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm15); - - ymm13 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm15 = _mm256_unpacklo_pd(ymm2, ymm3); - ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); - ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); - - ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm15 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *4 + 4)); - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *5 + 4)); - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm6); - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm7); - - ymm16 = _mm256_broadcast_sd((double const *)(&ones)); - ymm7 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm4 = _mm256_permute2f128_pd(ymm7,ymm16,0x20); - ymm6 = _mm256_permute2f128_pd(ymm7,ymm16,0x31); - - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm5 = _mm256_permute2f128_pd(ymm0,ymm16,0x20); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm16,0x31); - ymm18 = _mm256_unpacklo_pd(ymm2, ymm3); - ymm17 = _mm256_permute2f128_pd(ymm18,ymm16,0x20); - ymm19 = _mm256_permute2f128_pd(ymm18,ymm16,0x31); - - ////unpackhigh//// - ymm20 = _mm256_unpackhi_pd(ymm2, ymm3); - - //rearrange high elements - ymm18 = _mm256_permute2f128_pd(ymm20,ymm16,0x20); - ymm20 = _mm256_permute2f128_pd(ymm20,ymm16,0x31); + double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks - /* - Compute 8x6 TRSM block by using GEMM block output in register - a. The 8x6 input (gemm outputs) are stored in combinations of ymm registers - 1. ymm15, ymm20 2. ymm14, ymm19 3. 
ymm13, ymm18 , 4. ymm12, ymm17 - 5. ymm11, ymm7 6. ymm10, ymm6, 7.ymm9, ymm5 8. ymm8, ymm4 - where ymm15-ymm8 holds 8x4 data and reaming 8x2 will be hold by - other registers - b. Towards the end do in regiser transpose of TRSM output and store in b11 - */ - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); - - //perform mul operation - ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); - ymm20 = DTRSM_SMALL_DIV_OR_SCALE(ymm20, ymm1); - - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); - - //(ROw7): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6*cs_a + 7)); - ymm14 = _mm256_fnmadd_pd(ymm2, ymm15, ymm14); - ymm19 = _mm256_fnmadd_pd(ymm2, ymm20, ymm19); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 7)); - ymm13 = _mm256_fnmadd_pd(ymm2, ymm15, ymm13); - ymm18 = _mm256_fnmadd_pd(ymm2, ymm20, ymm18); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 7)); - ymm12 = _mm256_fnmadd_pd(ymm2, ymm15, ymm12); - ymm17 = _mm256_fnmadd_pd(ymm2, ymm20, ymm17); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 7)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm15, ymm11); - ymm7 = _mm256_fnmadd_pd(ymm2, ymm20, ymm7); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 7)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm15, ymm10); - ymm6 = _mm256_fnmadd_pd(ymm2, ymm20, ymm6); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 7)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm15, ymm9); - ymm5 = _mm256_fnmadd_pd(ymm2, ymm20, ymm5); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm15, ymm8); - ymm4 = _mm256_fnmadd_pd(ymm2, ymm20, ymm4); - - //perform mul operation - ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); - ymm19 = DTRSM_SMALL_DIV_OR_SCALE(ymm19, ymm1); - - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); - - //(ROw6): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 6)); - ymm13 = _mm256_fnmadd_pd(ymm2, ymm14, ymm13); - ymm18 = _mm256_fnmadd_pd(ymm2, ymm19, ymm18); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 6)); - ymm12 = _mm256_fnmadd_pd(ymm2, ymm14, ymm12); - ymm17 = _mm256_fnmadd_pd(ymm2, ymm19, ymm17); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 6)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm14, ymm11); - ymm7 = _mm256_fnmadd_pd(ymm2, ymm19, ymm7); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 6)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm14, ymm10); - ymm6 = _mm256_fnmadd_pd(ymm2, ymm19, ymm6); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 6)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm14, ymm9); - ymm5 = _mm256_fnmadd_pd(ymm2, ymm19, ymm5); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm14, ymm8); - ymm4 = _mm256_fnmadd_pd(ymm2, ymm19, ymm4); - - //perform mul operation - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1); - ymm18 = DTRSM_SMALL_DIV_OR_SCALE(ymm18, ymm1); - - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); - - //(ROw5): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 5)); - ymm12 = _mm256_fnmadd_pd(ymm2, ymm13, ymm12); - ymm17 = _mm256_fnmadd_pd(ymm2, ymm18, ymm17); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 5)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm13, ymm11); - ymm7 = _mm256_fnmadd_pd(ymm2, ymm18, ymm7); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 5)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm13, ymm10); - ymm6 = _mm256_fnmadd_pd(ymm2, 
ymm18, ymm6); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 5)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm13, ymm9); - ymm5 = _mm256_fnmadd_pd(ymm2, ymm18, ymm5); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm13, ymm8); - ymm4 = _mm256_fnmadd_pd(ymm2, ymm18, ymm4); - - //perform mul operation - ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); - ymm17 = DTRSM_SMALL_DIV_OR_SCALE(ymm17, ymm1); - - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //(ROw4): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 4)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm12, ymm11); - ymm7 = _mm256_fnmadd_pd(ymm2, ymm17, ymm7); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 4)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm12, ymm10); - ymm6 = _mm256_fnmadd_pd(ymm2, ymm17, ymm6); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 4)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm12, ymm9); - ymm5 = _mm256_fnmadd_pd(ymm2, ymm17, ymm5); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm12, ymm8); - ymm4 = _mm256_fnmadd_pd(ymm2, ymm17, ymm4); - - //perform mul operation - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm1); - - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(ROw3): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm11, ymm10); - ymm6 = _mm256_fnmadd_pd(ymm2, ymm7, ymm6); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm11, ymm9); - ymm5 = _mm256_fnmadd_pd(ymm2, ymm7, ymm5); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm11, ymm8); - ymm4 = _mm256_fnmadd_pd(ymm2, ymm7, ymm4); - - //perform mul operation - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); - ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm1); - - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(ROw2): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm10, ymm9); - ymm5 = _mm256_fnmadd_pd(ymm2, ymm6, ymm5); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm10, ymm8); - ymm4 = _mm256_fnmadd_pd(ymm2, ymm6, ymm4); - - //perform mul operation - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm1); - - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); - - //(ROw2): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm9, ymm8); - ymm4 = _mm256_fnmadd_pd(ymm2, ymm5, ymm4); - - //perform mul operation - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); - ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm1); - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); - - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); - - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); - 
_mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm12, ymm13); - ymm3 = _mm256_unpacklo_pd(ymm14, ymm15); - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); - - ///unpack high/// - ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); - ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); - ymm3 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); - - _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm0); - _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm1); - _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm2); - _mm256_storeu_pd((double *)(b11 + cs_b * 3 + 4), ymm3); - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm4, ymm5); - ymm3 = _mm256_unpacklo_pd(ymm6, ymm7); - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); - - ///unpack high/// - ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); - ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); - - _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); - _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm17, ymm18); - ymm3 = _mm256_unpacklo_pd(ymm19, ymm20); - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); + gint_t required_packing_A = 1; + mem_t local_mem_buf_A_s = {0}; + double *D_A_pack = NULL; + double d11_pack[d_mr] __attribute__((aligned(64))); + rntm_t rntm; - ///unpack high/// - ymm17 = _mm256_unpackhi_pd(ymm17, ymm18); - ymm18 = _mm256_unpackhi_pd(ymm19, ymm20); + bli_rntm_init_from_global( &rntm ); + bli_rntm_set_num_threads_only( 1, &rntm ); + bli_membrk_rntm_set_membrk( &rntm ); - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm17, ymm18, 0x20); + siz_t buffer_size = bli_pool_block_size( + bli_membrk_pool( + bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), + bli_rntm_membrk(&rntm))); - _mm256_storeu_pd((double *)(b11 + cs_b * 4 + 4), ymm0); - _mm256_storeu_pd((double *)(b11 + cs_b * 5 + 4), ymm1); + if( (d_nr * n * sizeof(double)) > buffer_size) + return BLIS_NOT_YET_IMPLEMENTED; + if (required_packing_A == 1) + { + // Get the buffer from the pool. 
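+        // The block acquired below serves as the packing buffer D_A_pack for
+        // the a01 panel (at most d_nr x n doubles, which the size check above
+        // guarantees will fit). A rough sketch of the acquire/use/release
+        // pattern, using the same calls that appear in this file:
+        //
+        //     bli_membrk_acquire_m(&rntm, buffer_size,
+        //                          BLIS_BITVAL_BUFFER_FOR_A_BLOCK,
+        //                          &local_mem_buf_A_s);
+        //     D_A_pack = bli_mem_buffer(&local_mem_buf_A_s);
+        //     ... pack a01 into D_A_pack and run the GEMM+TRSM loops ...
+        //     bli_membrk_release(&rntm, &local_mem_buf_A_s);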
+ bli_membrk_acquire_m(&rntm, + buffer_size, + BLIS_BITVAL_BUFFER_FOR_A_BLOCK, + &local_mem_buf_A_s); + if(FALSE==bli_mem_is_alloc(&local_mem_buf_A_s)) return BLIS_NULL_POINTER; + D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); + if(NULL==D_A_pack) return BLIS_NULL_POINTER; } - dim_t n_remainder = j + D_NR; - if(n_remainder >= 4) - { - a10 = D_A_pack; - a11 = L + (i*cs_a) + i; - b01 = B + ((n_remainder - 4)* cs_b) + i + D_MR; - b11 = B + ((n_remainder - 4)* cs_b) + i; - - k_iter = (m - i - D_MR); - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - ymm12 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); - ymm14 = _mm256_setzero_pd(); - ymm15 = _mm256_setzero_pd(); - - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); + //ymm scratch reginsters + __m256d ymm0, ymm1, ymm2, ymm3; + __m256d ymm4, ymm5, ymm6, ymm7; + __m256d ymm8, ymm9, ymm10, ymm11; + __m256d ymm12, ymm13, ymm14, ymm15; - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + __m128d xmm5; - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + /* + Performs solving TRSM for 6 rows at a time from 0 to n/6 in steps of d_nr + a. Load and pack A (a01 block), the size of packing 6x6 to 6x (n-6) + First there will be no GEMM and no packing of a01 because it is only TRSM + b. Using packed a01 block and b10 block perform GEMM operation + c. Use GEMM outputs, perform TRSM operation using a11, b11 and update B + d. 
Repeat b for m cols of B in steps of d_mr + */ - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + for(j = (n-d_nr); (j+1) > 0; j -= d_nr) //loop along 'N' direction + { + a01 = L + (j*rs_a) + (j+d_nr)*cs_a; //pointer to block of A to be used in GEMM + a11 = L + (j*cs_a) + (j*rs_a); //pointer to block of A to be used for TRSM - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15); + //double *ptr_a10_dup = D_A_pack; - b01 += 1; //move to next row of B - a10 += p_lda; - } + dim_t p_lda = (n-j-d_nr); // packed leading dimension + // perform copy of A to packed buffer D_A_pack - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] - ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] - ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 4)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] - ymm7 = _mm256_loadu_pd((double const *)(b11 + cs_b *3 + 4)); //B11[0][7] B11[1][7] B11[2][7] B11[3][7] - - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] * alpha -= B01[0-3][3] - ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] - ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] * alpha -= B01[0-3][5] - ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); //B11[0-3][6] * alpha -= B01[0-3][6] - ymm7 = _mm256_fmsub_pd(ymm7, ymm16, ymm15); //B11[0-3][7] * alpha -= B01[0-3][7] - - ///implement TRSM/// - - ///transpose of B11// - ///unpacklow/// - ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - - ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[2][4] B11[2][5] - ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[2][6] B11[2][7] - - //rearrange low elements - ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] - - ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3] - ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3] - - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - - ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[1][4] B11[1][5] B11[3][4] B11[3][5] - ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[1][6] B11[1][7] B11[3][6] B11[3][7] - - //rearrange high 
elements - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3] - ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3] - - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); - - //perform mul operation - ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); - - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); - - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6*cs_a + 7)); - ymm3 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 7)); - ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 7)); - ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 7)); - ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 7)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 7)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + 7)); - - //(ROw7): FMA operations - ymm14 = _mm256_fnmadd_pd(ymm2, ymm15, ymm14); - ymm13 = _mm256_fnmadd_pd(ymm3, ymm15, ymm13); - ymm12 = _mm256_fnmadd_pd(ymm4, ymm15, ymm12); - ymm11 = _mm256_fnmadd_pd(ymm5, ymm15, ymm11); - ymm10 = _mm256_fnmadd_pd(ymm6, ymm15, ymm10); - ymm9 = _mm256_fnmadd_pd(ymm7, ymm15, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm16, ymm15, ymm8); + if(transa) + { + /* + Pack current A block (a01) into packed buffer memory D_A_pack + a. This a10 block is used in GEMM portion only and this + a01 block size will be increasing by d_nr for every next iteration + until it reaches 6x(n-6) which is the maximum GEMM alone block size in A + b. This packed buffer is reused to calculate all m cols of B matrix + */ + bli_dtrsm_small_pack('R', p_lda, 1, a01, cs_a, D_A_pack, p_lda,d_nr); - //perform mul operation - ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); + /* + Pack 6 diagonal elements of A block into an array + a. This helps in utilze cache line efficiently in TRSM operation + b. store ones when input is unit diagonal + */ + dtrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,d_nr); + } + else + { + bli_dtrsm_small_pack('R', p_lda, 0, a01, rs_a, D_A_pack, p_lda,d_nr); + dtrsm_small_pack_diag_element(is_unitdiag,a11,rs_a,d11_pack,d_nr); + } - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + /* + a. Perform GEMM using a01, b10. + b. Perform TRSM on a11, b11 + c. This loop GEMM+TRSM loops operates with 8x6 block size + along m dimension for every d_mr columns of B10 where + packed A buffer is reused in computing all m cols of B. + d. Same approach is used in remaining fringe cases. 
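+            Schematically, the loop below performs (a sketch only, with
+            d_mr = 8 and d_nr = 6 as set above and the macro bodies elided):
+
+                for(i = (m - d_mr); (i + 1) > 0; i -= d_mr)
+                {
+                    GEMM  : accumulate a01(packed) * b10 into ymm registers
+                    update: b11 = alpha * b11 - accumulator
+                    TRSM  : solve the d_mr x d_nr block of b11 against a11
+                }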
+ */ + for(i = (m-d_mr); (i+1) > 0; i -= d_mr) //loop along 'M' direction + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i + (j+d_nr)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (i) + (j)*cs_b; //pointer to block of B to be used for TRSM - ymm3 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 6)); - ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 6)); - ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 6)); - ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 6)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 6)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + 6)); + k_iter = (n-j-d_nr); //number of GEMM operations to be done(in blocks of 4x4) - //(ROw6): FMA operations - ymm13 = _mm256_fnmadd_pd(ymm3, ymm14, ymm13); - ymm12 = _mm256_fnmadd_pd(ymm4, ymm14, ymm12); - ymm11 = _mm256_fnmadd_pd(ymm5, ymm14, ymm11); - ymm10 = _mm256_fnmadd_pd(ymm6, ymm14, ymm10); - ymm9 = _mm256_fnmadd_pd(ymm7, ymm14, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm16, ymm14, ymm8); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS - //perform mul operation - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1); + /* + Peform GEMM between a01 and b10 blocks + For first itteration there will be no GEMM operation + where k_iter are zero + */ - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + BLIS_DTRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) - ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 5)); - ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 5)); - ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 5)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 5)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + 5)); - - //(ROw5): FMA operations - ymm12 = _mm256_fnmadd_pd(ymm4, ymm13, ymm12); - ymm11 = _mm256_fnmadd_pd(ymm5, ymm13, ymm11); - ymm10 = _mm256_fnmadd_pd(ymm6, ymm13, ymm10); - ymm9 = _mm256_fnmadd_pd(ymm7, ymm13, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm16, ymm13, ymm8); - - //perform mul operation - ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); + /* + Load b11 of size 8x6 and multiply with alpha + Add the GEMM output to b11 + and peform TRSM operation. 
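+                In scalar terms the pre-TRSM update is expected to be the
+                following (a sketch; gemm_acc is a stand-in for the ymm
+                accumulators filled by the GEMM macro above):
+
+                    for(jj = 0; jj < d_nr; jj++)
+                        for(ii = 0; ii < d_mr; ii++)
+                            b11[ii + jj*cs_b] = AlphaVal * b11[ii + jj*cs_b]
+                                                - gemm_acc[ii][jj];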
+ */ - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 4)); - ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 4)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 4)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + 4)); - - //(ROw4): FMA operations - ymm11 = _mm256_fnmadd_pd(ymm5, ymm12, ymm11); - ymm10 = _mm256_fnmadd_pd(ymm6, ymm12, ymm10); - ymm9 = _mm256_fnmadd_pd(ymm7, ymm12, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm16, ymm12, ymm8); - - //perform mul operation - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + 3)); - - //(ROw3): FMA operations - ymm10 = _mm256_fnmadd_pd(ymm6, ymm11, ymm10); - ymm9 = _mm256_fnmadd_pd(ymm7, ymm11, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm16, ymm11, ymm8); - - //perform mul operation - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); - - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + 2)); - - //(ROw2): FMA operations - ymm9 = _mm256_fnmadd_pd(ymm7, ymm10, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm16, ymm10, ymm8); - - //perform mul operation - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); - - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); - - ymm16 = _mm256_broadcast_sd((double const *)(a11 + 1)); - - //(ROw2): FMA operations - ymm8 = _mm256_fnmadd_pd(ymm16, ymm9, ymm8); - - //perform mul operation - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); + BLIS_PRE_DTRSM_SMALL_6x8(AlphaVal,b11,cs_b) - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + ///implement TRSM/// - ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] - ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] + /* + Compute 6x8 TRSM block by using GEMM block output in register + a. The 6x8 input (gemm outputs) are stored in combinations of ymm registers + 1. ymm3, ymm4 2. ymm5, ymm6 3. ymm7, ymm8, 4. ymm9, ymm10 + 5. ymm11, ymm12 6. ymm13,ymm14 + b. 
Towards the end TRSM output will be stored back into b11 + */ - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + //extract a55 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); - ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] + ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); + ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm0); - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + //extract a44 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); - ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[4][1] B11[5][1] B11[4][3] B11[5][3] - ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[6][1] B11[7][1] B11[6][3] B11[7][3] + //(row 5):FMA operations + //ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 4*rs_a)); - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm11 = _mm256_fnmadd_pd(ymm1, ymm13, ymm11); + ymm12 = _mm256_fnmadd_pd(ymm1, ymm14, ymm12); - ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] - ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 3*rs_a)); - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); //store B11[4][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); //store B11[5][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm6); //store B11[6][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 3 + 4), ymm7); //store B11[7][0-3] - n_remainder -=4; - } + ymm9 = _mm256_fnmadd_pd(ymm1, ymm13, ymm9); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm14, ymm10); - if(n_remainder) //implementation fo remaining columns(when 'N' is not a multiple of D_NR)() n = 3 - { - a10 = D_A_pack; - a11 = L + (i*cs_a) + i; - b01 = B + i + D_MR; - b11 = B + i; + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 2*rs_a)); - k_iter = (m - i - D_MR) ; + ymm7 = _mm256_fnmadd_pd(ymm1, ymm13, ymm7); + ymm8 = _mm256_fnmadd_pd(ymm1, ymm14, ymm8); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm12 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); - ymm14 = _mm256_setzero_pd(); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 1*rs_a)); - if(3 == n_remainder) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm13, ymm5); + ymm6 = _mm256_fnmadd_pd(ymm1, ymm14, ymm6); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = 
_mm256_fmadd_pd(ymm2, ymm0, ymm8); - ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a)); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm13, ymm3); + ymm4 = _mm256_fnmadd_pd(ymm1, ymm14, ymm4); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); + ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm0); - b01 += 1; //move to next row of B - a10 += p_lda; - } + //extract a33 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + //(row 4):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 3*rs_a)); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm9 = _mm256_fnmadd_pd(ymm1, ymm11, ymm9); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm12, ymm10); - ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] - ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] - ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 4)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 2*rs_a)); - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm11, ymm7); + ymm8 = _mm256_fnmadd_pd(ymm1, ymm12, ymm8); - ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] - ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] * alpha -= B01[0-3][5] - ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); //B11[0-3][6] * alpha -= B01[0-3][6] - ymm7 = _mm256_broadcast_sd((double const *)(&ones)); - } - else if(2 == n_remainder) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 1*rs_a)); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm11, ymm5); + ymm6 = _mm256_fnmadd_pd(ymm1, ymm12, ymm6); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a)); - b01 += 1; //move to next row of B - a10 += p_lda; - } + ymm3 = _mm256_fnmadd_pd(ymm1, ymm11, ymm3); + ymm4 = _mm256_fnmadd_pd(ymm1, ymm12, ymm4); - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm0); - ymm0 = 
_mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + //extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] - ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2*rs_a)); - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] - ymm2 = _mm256_broadcast_sd((double const *)(&ones)); - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); + ymm8 = _mm256_fnmadd_pd(ymm1, ymm10, ymm8); - ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] - ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] * alpha -= B01[0-3][5] - ymm6 = _mm256_broadcast_sd((double const *)(&ones)); - ymm7 = _mm256_broadcast_sd((double const *)(&ones)); - } - else if(1 == n_remainder) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1*rs_a)); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); + ymm6 = _mm256_fnmadd_pd(ymm1, ymm10, ymm6); - b01 += 1; //move to next row of B - a10 += p_lda; - } + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + ymm3 = _mm256_fnmadd_pd(ymm1, ymm9, ymm3); + ymm4 = _mm256_fnmadd_pd(ymm1, ymm10, ymm4); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm0); - ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] + //extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_broadcast_sd((double const *)(&ones)); - ymm2 = _mm256_broadcast_sd((double const *)(&ones)); - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1*rs_a)); - ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] - ymm5 = _mm256_broadcast_sd((double const *)(&ones)); - ymm6 = _mm256_broadcast_sd((double const *)(&ones)); - ymm7 = _mm256_broadcast_sd((double const *)(&ones)); - } - ///implement TRSM/// - - ///transpose of B11// - ///unpacklow/// - ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - - ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[2][4] B11[2][5] - ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[2][6] B11[2][7] - - //rearrange low elements - ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] 
B11[0][3] - ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] - - ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3] - ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3] - - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - - ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[1][4] B11[1][5] B11[3][4] B11[3][5] - ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[1][6] B11[1][7] B11[3][6] B11[3][7] - - //rearrange high elements - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3] - ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3] - - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); - - //perform mul operation - ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); - - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); - - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6*cs_a + 7)); - ymm3 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 7)); - ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 7)); - ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 7)); - ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 7)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 7)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + 7)); - - //(ROw7): FMA operations - ymm14 = _mm256_fnmadd_pd(ymm2, ymm15, ymm14); - ymm13 = _mm256_fnmadd_pd(ymm3, ymm15, ymm13); - ymm12 = _mm256_fnmadd_pd(ymm4, ymm15, ymm12); - ymm11 = _mm256_fnmadd_pd(ymm5, ymm15, ymm11); - ymm10 = _mm256_fnmadd_pd(ymm6, ymm15, ymm10); - ymm9 = _mm256_fnmadd_pd(ymm7, ymm15, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm16, ymm15, ymm8); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); + ymm6 = _mm256_fnmadd_pd(ymm1, ymm8, ymm6); - //perform mul operation - ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); + ymm4 = _mm256_fnmadd_pd(ymm1, ymm8, ymm4); - ymm3 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 6)); - ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 6)); - ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 6)); - ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 6)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 6)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + 6)); + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm0); - //(ROw6): FMA operations - ymm13 = _mm256_fnmadd_pd(ymm3, ymm14, ymm13); - ymm12 = _mm256_fnmadd_pd(ymm4, ymm14, ymm12); - ymm11 = _mm256_fnmadd_pd(ymm5, ymm14, ymm11); - ymm10 = _mm256_fnmadd_pd(ymm6, ymm14, ymm10); - ymm9 = _mm256_fnmadd_pd(ymm7, ymm14, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm16, ymm14, ymm8); + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); - //perform mul operation - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1); + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - //extract a00 - ymm1 = 
_mm256_broadcast_sd((double const *)(d11_pack + 4)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); + ymm4 = _mm256_fnmadd_pd(ymm1, ymm6, ymm4); - ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 5)); - ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 5)); - ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 5)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 5)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + 5)); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); - //(ROw5): FMA operations - ymm12 = _mm256_fnmadd_pd(ymm4, ymm13, ymm12); - ymm11 = _mm256_fnmadd_pd(ymm5, ymm13, ymm11); - ymm10 = _mm256_fnmadd_pd(ymm6, ymm13, ymm10); - ymm9 = _mm256_fnmadd_pd(ymm7, ymm13, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm16, ymm13, ymm8); + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + 4), ymm4); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm256_storeu_pd((double *)(b11 + cs_b + 4), ymm6); + _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); + _mm256_storeu_pd((double *)(b11 + cs_b*2 + 4), ymm8); + _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); + _mm256_storeu_pd((double *)(b11 + cs_b*3 + 4), ymm10); + _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); + _mm256_storeu_pd((double *)(b11 + cs_b*4 + 4), ymm12); + _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); + _mm256_storeu_pd((double *)(b11 + cs_b*5 + 4), ymm14); + } - //perform mul operation - ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); + dim_t m_remainder = i + d_mr; + if(m_remainder >= 4) + { + a01 = D_A_pack; + a11 = L + (j*cs_a) + (j*rs_a); + b10 = B + (m_remainder - 4) + (j+d_nr)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 4) + (j*cs_b); - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + k_iter = (n-j-d_nr); //number of GEMM operations to be done(in blocks of 4x4) - ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 4)); - ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 4)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 4)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + 4)); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS - //(ROw4): FMA operations - ymm11 = _mm256_fnmadd_pd(ymm5, ymm12, ymm11); - ymm10 = _mm256_fnmadd_pd(ymm6, ymm12, ymm10); - ymm9 = _mm256_fnmadd_pd(ymm7, ymm12, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm16, ymm12, ymm8); + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) - //perform mul operation - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + // Load b11 of size 4x6 and multiply with alpha + BLIS_PRE_DTRSM_SMALL_6x4(AlphaVal,b11,cs_b) - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + ///implement TRSM/// - ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + 3)); + //extract a55 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - //(ROw3): FMA operations - ymm10 = _mm256_fnmadd_pd(ymm6, ymm11, ymm10); - ymm9 = _mm256_fnmadd_pd(ymm7, ymm11, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm16, ymm11, ymm8); + //extract a44 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); - //perform mul operation - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + //(row 5):FMA operations + ymm1 = _mm256_broadcast_sd((double 
const *)(a11 + 5*cs_a + 4*rs_a)); + ymm11 = _mm256_fnmadd_pd(ymm1, ymm13, ymm11); - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm13, ymm9); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + 2)); - - //(ROw2): FMA operations - ymm9 = _mm256_fnmadd_pd(ymm7, ymm10, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm16, ymm10, ymm8); - - //perform mul operation - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm13, ymm7); - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm13, ymm5); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm13, ymm3); - //(ROw2): FMA operations - ymm8 = _mm256_fnmadd_pd(ymm16, ymm9, ymm8); + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); - //perform mul operation - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); + //extract a33 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + //(row 4):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm11, ymm9); - ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] - ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm11, ymm7); - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm11, ymm5); - ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm11, ymm3); - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[4][1] B11[5][1] B11[4][3] B11[5][3] - ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[6][1] B11[7][1] B11[6][3] B11[7][3] + //extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); - ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] - ymm7 = 
_mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); - if(3 == n_remainder) - { - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); //store B11[4][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); //store B11[5][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm6); //store B11[6][0-3] - } - else if(2 == n_remainder) - { - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); //store B11[4][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); //store B11[5][0-3] - } - else if(1 == n_remainder) - { - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); //store B11[4][0-3] - } - } - }// End of multiples of D_MR blocks in m-dimension - - // Repetative A blocks will be 4*4 - dim_t m_remainder = i + D_MR; - if(m_remainder >= 4) - { - i = m_remainder - 4; - a10 = L + (i*cs_a) + i + 4; //pointer to block of A to be used for GEMM - a11 = L + (i*cs_a) + i; //pointer to block of A to be used for TRSM - - // Do transpose for a10 & store in D_A_pack - double *ptr_a10_dup = D_A_pack; - dim_t p_lda = 4; // packed leading dimension - for(dim_t x =0;x < m-i+4;x+=4) - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm1 = _mm256_loadu_pd((double const *)(a10 + cs_a)); - ymm2 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); - ymm3 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm9, ymm3); - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + //extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); - _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - a10 += 4; - ptr_a10_dup += 4*4; - } + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); - ymm4 = _mm256_broadcast_sd((double const *)&ones); - if(!is_unitdiag) - { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_sd((double const *)(a11)); - ymm1 = _mm256_broadcast_sd((double const *)(a11+cs_a*1 + 1)); - ymm2 = _mm256_broadcast_sd((double const *)(a11+cs_a*2 + 2)); - ymm3 = _mm256_broadcast_sd((double const *)(a11+cs_a*3 + 3)); + //(Row 1): FMA operations + 
ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); - ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); - ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); - #ifdef BLIS_DISABLE_TRSM_PREINVERSION - ymm4 = ymm1; - #endif - #ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm4 = _mm256_div_pd(ymm4, ymm1); - #endif - } - _mm256_storeu_pd((double *)(d11_pack), ymm4); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - //cols - for(j = (n - D_NR); (j + 1) > 0; j -= D_NR) //loop along 'N' dimension - { - a10 = D_A_pack; - a11 = L + (i*cs_a) + i; //pointer to block of A to be used for TRSM - b01 = B + (j*cs_b) + i + 4; //pointer to block of B to be used for GEMM - b11 = B + (j* cs_b) + i; //pointer to block of B to be used for TRSM - - k_iter = (m - i - 4); //number of times GEMM to be performed(in blocks of 4x4) - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); - #endif + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); + _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); + _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); + _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); + m_remainder -=4; + } - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + if(m_remainder) { - ymm0 = _mm256_loadu_pd((double const *)(a10)); + if(3 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (j*cs_a) + (j*rs_a); + b10 = B + (j+d_nr)*cs_b + (m_remainder - 3); //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 3) + (j*cs_b); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + k_iter = (n-j-d_nr); //number of GEMM operations to be done(in blocks of 4x4) - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + // Load b11 of size 4x6 and multiply with alpha + BLIS_PRE_DTRSM_SMALL_6x4(AlphaVal,b11,cs_b) - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); - ymm4 = _mm256_fmadd_pd(ymm2, ymm0, ymm4); + ///implement TRSM/// - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); + //extract a55 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - b01 += 1; //move to next row of B - a10 += p_lda; - } + //extract a44 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + //(row 5):FMA operations + ymm1 = 
_mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 4*rs_a)); + ymm11 = _mm256_fnmadd_pd(ymm1, ymm13, ymm11); - ///implement TRSM/// + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm13, ymm9); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm13, ymm7); - ///transpose of B11// - ///unpacklow/// - ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm13, ymm5); - //rearrange low elements - ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm13, ymm3); - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); - //rearrange high elements - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + //extract a33 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); - - ymm16 = _mm256_broadcast_sd((double const *)(&ones)); + //(row 4):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm11, ymm9); - ////unpacklow//// - ymm7 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm11, ymm7); - //rearrange low elements - ymm4 = _mm256_permute2f128_pd(ymm7,ymm16,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm6 = _mm256_permute2f128_pd(ymm7,ymm16,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm11, ymm5); - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm11, ymm3); - //rearrange high elements - ymm5 = _mm256_permute2f128_pd(ymm0,ymm16,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm7 = _mm256_permute2f128_pd(ymm0,ymm16,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + //extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + //(Row 3): 
FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); - //perform mul operation - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm1); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm9, ymm3); - //(ROw3): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm11, ymm10); - ymm6 = _mm256_fnmadd_pd(ymm2, ymm7, ymm6); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm11, ymm9); - ymm5 = _mm256_fnmadd_pd(ymm2, ymm7, ymm5); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm11, ymm8); - ymm4 = _mm256_fnmadd_pd(ymm2, ymm7, ymm4); + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - //perform mul operation - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); - ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm1); + //extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); - //(ROw2): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm10, ymm9); - ymm5 = _mm256_fnmadd_pd(ymm2, ymm6, ymm5); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm10, ymm8); - ymm4 = _mm256_fnmadd_pd(ymm2, ymm6, ymm4); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); - //perform mul operation - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm1); + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); - //(ROw2): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm9, ymm8); - ymm4 = _mm256_fnmadd_pd(ymm2, ymm5, ymm4); + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); - //perform mul operation - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); - ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm1); - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + ymm0 = _mm256_loadu_pd((double const *)b11); + ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x07); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_blend_pd(ymm0, ymm5, 
0x07); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x07); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x07); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm11 = _mm256_blend_pd(ymm0, ymm11, 0x07); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm13 = _mm256_blend_pd(ymm0, ymm13, 0x07); - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); + _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); + _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); + _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] + m_remainder -=3; + } + else if(2 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (j*cs_a) + (j*rs_a); + b10 = B + (j+d_nr)*cs_b + (m_remainder - 2); //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 2) + (j*cs_b); - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + k_iter = (n-j-d_nr); //number of GEMM operations to be done(in blocks of 4x4) - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS - ///unpack high/// - ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + // Load b11 of size 4x6 and multiply with alpha + BLIS_PRE_DTRSM_SMALL_6x4(AlphaVal,b11,cs_b) - _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store B11[1][0-3] - } - dim_t n_remainder = j + D_NR; - if((n_remainder >= 4)) - { - a10 = D_A_pack; - a11 = L + (i*cs_a) + i; //pointer to block of A to be used for TRSM - b01 = B + ((n_remainder - 4)* cs_b) + i + 4; //pointer to block of B to be used for GEMM - b11 = B + ((n_remainder - 4)* cs_b) + i; //pointer to block of B to be used for TRSM - - k_iter = (m - i - 4); //number of times GEMM to be performed(in blocks of 4x4) - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - #endif + ///implement TRSM/// - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = 
_mm256_setzero_pd(); - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); + //extract a55 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + //extract a44 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + //(row 5):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 4*rs_a)); + ymm11 = _mm256_fnmadd_pd(ymm1, ymm13, ymm11); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm13, ymm9); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm13, ymm7); - b01 += 1; //move to next row of B - a10 += p_lda; - } + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm13, ymm5); - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm13, ymm3); - ///implement TRSM/// + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + //extract a33 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - ///transpose of B11// - ///unpacklow/// - ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + //(row 4):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm11, ymm9); - //rearrange low elements - ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm11, ymm7); - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm11, ymm5); - //rearrange high elements - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm11, ymm3); - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - 
//perform mul operation - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + //extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); - //(ROw3): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm11, ymm10); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm11, ymm9); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm11, ymm8); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); - //perform mul operation - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm9, ymm3); - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - //(ROw2): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm10, ymm9); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm10, ymm8); + //extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - //perform mul operation - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); - //(ROw2): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm9, ymm8); + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - //perform mul operation - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + ymm0 = _mm256_loadu_pd((double const *)b11); + ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x03); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x03); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x03); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x03); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] 
B11[2][1] B11[3][1] + ymm11 = _mm256_blend_pd(ymm0, ymm11, 0x03); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm13 = _mm256_blend_pd(ymm0, ymm13, 0x03); - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); + _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); + _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); + _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] - n_remainder = n_remainder - 4; - } + m_remainder -=2; + } + else if (1 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (j*cs_a) + (j*rs_a); + b10 = B + (j+d_nr)*cs_b + (m_remainder - 1); //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 1) + (j*cs_b); - if(n_remainder) //implementation fo remaining columns(when 'N' is not a multiple of D_NR)() n = 3 - { - a10 = D_A_pack; - a11 = L + (i*cs_a) + i; - b01 = B + i + 4; - b11 = B + i; + k_iter = (n-j-d_nr); //number of GEMM operations to be done(in blocks of 4x4) - k_iter = (m - i - 4); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) - if(3 == n_remainder) - { - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); + // Load b11 of size 4x6 and multiply with alpha + BLIS_PRE_DTRSM_SMALL_6x4(AlphaVal,b11,cs_b) - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ///implement TRSM/// - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + //extract a55 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + //extract a44 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); - b01 += 1; //move to next row of B - a10 += p_lda; - } + //(row 5):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 4*rs_a)); + ymm11 = _mm256_fnmadd_pd(ymm1, ymm13, ymm11); - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm13, ymm9); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm13, ymm7); - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= 
B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); - } - else if(2 == n_remainder) - { + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm13, ymm5); - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm13, ymm3); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); + //extract a33 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - b01 += 1; //move to next row of B - a10 += p_lda; - } + //(row 4):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm11, ymm9); - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm11, ymm7); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm11, ymm5); - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] - ymm2 = _mm256_broadcast_sd((double const *)(&ones)); - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); - } - else if(1 == n_remainder) - { + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm11, ymm3); - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); + //extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - b01 += 1; //move to next row of B - a10 += p_lda; - } + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm9, ymm3); - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_broadcast_sd((double const *)(&ones)); - ymm2 = _mm256_broadcast_sd((double const *)(&ones)); - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); - } + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - ///implement TRSM/// + //extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - ///transpose of B11// - ///unpacklow/// - ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] 
B11[0][1] B11[2][0] B11[2][1]
-    ymm11 = _mm256_unpacklo_pd(ymm2, ymm3);    //B11[0][2] B11[0][3] B11[2][2] B11[2][3]
+    //(row 2):FMA operations
+    ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1*rs_a));
+    ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5);
-    //rearrange low elements
-    ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20);    //B11[0][0] B11[0][1] B11[0][2] B11[0][3]
-    ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31);   //B11[2][0] B11[2][1] B11[2][2] B11[2][3]
+    ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a));
+    ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3);
-    ////unpackhigh////
-    ymm0 = _mm256_unpackhi_pd(ymm0, ymm1);    //B11[1][0] B11[1][1] B11[3][0] B11[3][1]
-    ymm1 = _mm256_unpackhi_pd(ymm2, ymm3);    //B11[1][2] B11[1][3] B11[3][2] B11[3][3]
+    ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0);
-    //rearrange high elements
-    ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20);    //B11[1][0] B11[1][1] B11[1][2] B11[1][3]
-    ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31);   //B11[3][0] B11[3][1] B11[3][2] B11[3][3]
+    //extract a00
+    ymm0 = _mm256_broadcast_sd((double const *)(d11_pack ));
-    //extract a33
-    ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3));
+    //(Row 1): FMA operations
+    ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a));
+    ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3);
-    //perform mul operation
-    ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1);
+    ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0);
-    //extract a22
-    ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2));
+    ymm0 = _mm256_loadu_pd((double const *)b11);
+    ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x01);
+    ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b));    //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
+    ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x01);
+    ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2));    //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
+    ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x01);
+    ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3));    //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
+    ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x01);
+    ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4));    //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
+    ymm11 = _mm256_blend_pd(ymm0, ymm11, 0x01);
+    ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5));    //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
+    ymm13 = _mm256_blend_pd(ymm0, ymm13, 0x01);
-    ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3));
-    ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3));
-    ymm16 = _mm256_broadcast_sd((double const *)(a11 + 3));
+    _mm256_storeu_pd((double *)b11, ymm3);
+    _mm256_storeu_pd((double *)(b11 + cs_b), ymm5);
+    _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7);
+    _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9);
+    _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11);
+    _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13);
-    //(ROw3): FMA operations
-    ymm10 = _mm256_fnmadd_pd(ymm6, ymm11, ymm10);
-    ymm9 = _mm256_fnmadd_pd(ymm7, ymm11, ymm9);
-    ymm8 = _mm256_fnmadd_pd(ymm16, ymm11, ymm8);
+    m_remainder -=1;
+    }
+    }
+    }
-    //perform mul operation
-    ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1);
+    dim_t n_remainder = j + d_nr;
-    //extract a11
-    ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1));
+    /*
+        Remainder cases start here:
+        a. Similar logic and code flow used in computing full block (6x8)
+           above holds for remainder cases too.
+ */ - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + 2)); + if(n_remainder >= 4) + { + a01 = L + (n_remainder - 4)*rs_a + n_remainder*cs_a; //pointer to block of A to be used in GEMM + a11 = L + (n_remainder - 4)*cs_a + (n_remainder - 4)*rs_a; //pointer to block of A to be used for TRSM - //(ROw2): FMA operations - ymm9 = _mm256_fnmadd_pd(ymm7, ymm10, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm16, ymm10, ymm8); + double *ptr_a10_dup = D_A_pack; - //perform mul operation - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); + dim_t p_lda = (n-n_remainder); // packed leading dimension + // perform copy of A to packed buffer D_A_pack - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + if(transa) + { + for(dim_t x =0;x < p_lda;x+=d_nr) + { + ymm0 = _mm256_loadu_pd((double const *)(a01)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); + ymm2 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2)); + ymm3 = _mm256_loadu_pd((double const *)(a01 + cs_a * 3)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + 1)); + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); - //(ROw2): FMA operations - ymm8 = _mm256_fnmadd_pd(ymm16, ymm9, ymm8); + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - //perform mul operation - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + ymm0 = _mm256_loadu_pd((double const *)(a01 + cs_a * 4)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a * 5)); - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_broadcast_sd((double const *)&zero); - if(3 == n_remainder) - { - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] - } - else if(2 == n_remainder) - { - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - } - else if(1 == n_remainder) - { - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - } - } - m_remainder -= 4; - } + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - if(m_remainder) - { 
- a10 = L + m_remainder; + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_broadcast_sd((double const *)&zero); - // Do transpose for a10 & store in D_A_pack - double *ptr_a10_dup = D_A_pack; - if(3 == m_remainder) // Repetative A blocks will be 3*3 - { - dim_t p_lda = 4; // packed leading dimension - for(dim_t x =0;x < m-m_remainder;x+=4) - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm1 = _mm256_loadu_pd((double const *)(a10 + cs_a)); - ymm2 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); - ymm3 = _mm256_broadcast_sd((double const *)&ones); - - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); - - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); - - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - - _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); - - a10 += 4; - ptr_a10_dup += 4*4; - } - - //cols - for(j = (n - D_NR); (j + 1) > 0; j -= D_NR) //loop along 'N' dimension - { - a10 = D_A_pack; - a11 = L; //pointer to block of A to be used for TRSM - b01 = B + (j* cs_b) + m_remainder; //pointer to block of B to be used for GEMM - b11 = B + (j* cs_b); //pointer to block of B to be used for TRSM - - k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); - #endif - - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); - ymm4 = _mm256_fmadd_pd(ymm2, ymm0, ymm4); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ///GEMM code ends/// - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to store alpha value - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - ymm11 = _mm256_fmsub_pd(ymm3, ymm16, 
ymm11); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x08); - - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); - - _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store(B11[0-3][3]) - - dtrsm_AltXB_ref(a11, b11, m_remainder, 6, cs_a, cs_b, is_unitdiag); - } - - dim_t n_remainder = j + D_NR; - if((n_remainder >= 4)) - { - a10 = D_A_pack; - a11 = L; //pointer to block of A to be used for TRSM - b01 = B + ((n_remainder - 4)* cs_b) + m_remainder; //pointer to block of B to be used for GEMM - b11 = B + ((n_remainder - 4)* cs_b); //pointer to block of B to be used for TRSM - - k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - #endif - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_extractf128_pd(ymm6,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_extractf128_pd(ymm7,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm8,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm9,0)); - b01 += 1; //move to next row of B - a10 += p_lda; + a01 += d_nr*cs_a; + ptr_a10_dup += d_nr; } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); - ymm3 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3 + 2)); - ymm3 = _mm256_insertf128_pd(ymm3, xmm5, 0); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - ymm11 = _mm256_fmsub_pd(ymm3, ymm16, 
ymm11); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x08); - - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - xmm5 = _mm256_extractf128_pd(ymm3, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 3),xmm5); - _mm_storel_pd((b11 + cs_b * 3 + 2), _mm256_extractf128_pd(ymm3, 1)); - - dtrsm_AltXB_ref(a11, b11, m_remainder, 4, cs_a, cs_b, is_unitdiag); - n_remainder -= 4; } - - if(n_remainder) + else { - a10 = D_A_pack; - a11 = L; //pointer to block of A to be used for TRSM - b01 = B + m_remainder; //pointer to block of B to be used for GEMM - b11 = B; //pointer to block of B to be used for TRSM - - k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - #endif - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - - if(3 == n_remainder) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2)); - ymm2 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 2 + 2)); - ymm2 = _mm256_insertf128_pd(ymm2, xmm5, 0); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); - - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - xmm5 = _mm256_extractf128_pd(ymm2, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 2), xmm5); - _mm_storel_pd((b11 + cs_b * 2 + 2), _mm256_extractf128_pd(ymm2, 1)); - - dtrsm_AltXB_ref(a11, b11, m_remainder, 3, cs_a, cs_b, is_unitdiag); - } - else if(2 == n_remainder) - { - - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - b01 += 1; //move to next row of B - a10 += p_lda; - - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - xmm5 = _mm_loadu_pd((double 
const*)(b11 + cs_b * 1)); - ymm1 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 1 + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); + dim_t loop_count = (n-n_remainder)/4; - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - xmm5 = _mm256_extractf128_pd(ymm1, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 1), xmm5); - _mm_storel_pd((b11 + cs_b * 1 + 2), _mm256_extractf128_pd(ymm1, 1)); - - dtrsm_AltXB_ref(a11, b11, m_remainder, 2, cs_a, cs_b, is_unitdiag); - } - else if(1 == n_remainder) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + for(dim_t x =0;x < loop_count;x++) { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 0)); - ymm0 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 0 + 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - - xmm5 = _mm256_extractf128_pd(ymm0, 0); - _mm_storeu_pd((double *)(b11), xmm5); - _mm_storel_pd((b11 + 2), _mm256_extractf128_pd(ymm0, 1)); - - dtrsm_AltXB_ref(a11, b11, m_remainder, 1, cs_a, cs_b, is_unitdiag); + ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 0 + x*4)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + x*4), ymm15); + ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 1 + x*4)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + x*4), ymm15); + ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 2 + x*4)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 2 + x*4), ymm15); + ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 3 + x*4)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 3 + x*4), ymm15); } - } - } - else if(2 == m_remainder) // Repetative A blocks will be 2*2 - { - dim_t p_lda = 4; // packed leading dimension - for(dim_t x =0;x < m-m_remainder;x+=4) - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm1 = _mm256_loadu_pd((double const *)(a10 + cs_a)); - ymm2 = _mm256_broadcast_sd((double const *)&ones); - ymm3 = _mm256_broadcast_sd((double const *)&ones); - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); - - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); - - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - - _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); - - a10 += 4; - ptr_a10_dup += 4*4; - } - //cols - for(j = (n - D_NR); (j + 1) > 0; j -= D_NR) //loop along 'N' dimension - { - a10 = D_A_pack; - a11 = L; //pointer to block of A to be used for TRSM - b01 = B + (j* cs_b) + m_remainder; //pointer to block of B to be used for GEMM - b11 = B + (j* cs_b); //pointer to block of B to be used for TRSM - - k_iter = (m - m_remainder); //number of times 
GEMM to be performed(in blocks of 4x4) - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); - #endif + dim_t remainder_loop_count = p_lda - loop_count*4; - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + __m128d xmm0; + if(remainder_loop_count != 0) { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); - ymm4 = _mm256_fmadd_pd(ymm2, ymm0, ymm4); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); - - b01 += 1; //move to next row of B - a10 += p_lda; + xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 0 + loop_count*4)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + loop_count*4), xmm0); + xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 1 + loop_count*4)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + loop_count*4), xmm0); + xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 2 + loop_count*4)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 2 + loop_count*4), xmm0); + xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 3 + loop_count*4)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 3 + loop_count*4), xmm0); } - - ///GEMM code ends/// - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to store alpha value - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0C); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0C); - - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); - - _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + 
cs_b * 5), ymm1); //store(B11[0-3][3]) - - dtrsm_AltXB_ref(a11, b11, m_remainder, 6, cs_a, cs_b, is_unitdiag); } - dim_t n_remainder = j + D_NR; - if((n_remainder >= 4)) - { - a10 = D_A_pack; - a11 = L; //pointer to block of A to be used for TRSM - b01 = B + ((n_remainder - 4)* cs_b) + m_remainder; //pointer to block of B to be used for GEMM - b11 = B + ((n_remainder - 4)* cs_b); //pointer to block of B to be used for TRSM - - k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - #endif - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - ///implement TRSM/// - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); - ymm3 = _mm256_insertf128_pd(ymm3, xmm5, 0); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0C); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0C); - - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - xmm5 = _mm256_extractf128_pd(ymm3, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 3), xmm5); - - dtrsm_AltXB_ref(a11, b11, m_remainder, 4, cs_a, cs_b, is_unitdiag); - n_remainder -= 4; - } - if(n_remainder) + ymm4 = _mm256_broadcast_sd((double const *)&ones); + if(!is_unitdiag) { - a10 = D_A_pack; - a11 = L; //pointer to block of A to be used for TRSM - b01 = B + m_remainder; //pointer to block of B to be used for GEMM - b11 = B; //pointer to block of B to be used for TRSM - - k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - #endif - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - if(3 == n_remainder) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + if(transa) { - ymm0 = _mm256_loadu_pd((double const 
*)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2)); - ymm2 = _mm256_insertf128_pd(ymm2, xmm5, 0); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0C); - - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - xmm5 = _mm256_extractf128_pd(ymm2, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 2), xmm5); - - dtrsm_AltXB_ref(a11, b11, m_remainder, 3, cs_a, cs_b, is_unitdiag); + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a*1 + 1)); + ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a*2 + 2)); + ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a*3 + 3)); } - else if(2 == n_remainder) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations + else { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - b01 += 1; //move to next row of B - a10 += p_lda; + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11+ rs_a*1 + 1)); + ymm2 = _mm256_broadcast_sd((double const *)(a11+ rs_a*2 + 2)); + ymm3 = _mm256_broadcast_sd((double const *)(a11+ rs_a*3 + 3)); } - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); - - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - xmm5 = _mm256_extractf128_pd(ymm1, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 1), xmm5); - - dtrsm_AltXB_ref(a11, b11, m_remainder, 2, cs_a, cs_b, is_unitdiag); - } - else if(1 == n_remainder) - { - ///GEMM code begins/// - - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 0)); - ymm0 = 
_mm256_insertf128_pd(ymm0, xmm5, 0); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); - - xmm5 = _mm256_extractf128_pd(ymm0, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 0), xmm5); - - dtrsm_AltXB_ref(a11, b11, m_remainder, 1, cs_a, cs_b, is_unitdiag); - } - } - - } - else if(1 == m_remainder) // Repetative A blocks will be 1*1 - { - dim_t p_lda = 4; // packed leading dimension - for(dim_t x =0;x < m-m_remainder;x+=4) - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm1 = _mm256_broadcast_sd((double const *)&ones); - ymm2 = _mm256_broadcast_sd((double const *)&ones); - ymm3 = _mm256_broadcast_sd((double const *)&ones); - - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); - - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); - - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - - _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); - - a10 += 4; - ptr_a10_dup += 4*4; - } - //cols - for(j = (n - D_NR); (j + 1) > 0; j -= D_NR) //loop along 'N' dimension - { - a10 = D_A_pack; - a11 = L; //pointer to block of A to be used for TRSM - b01 = B + (j* cs_b) + m_remainder; //pointer to block of B to be used for GEMM - b11 = B + (j* cs_b); //pointer to block of B to be used for TRSM - - k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); - #endif - - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); - ymm4 = _mm256_fmadd_pd(ymm2, ymm0, ymm4); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ///GEMM code ends/// - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to store alpha value - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, 
ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0E); - - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); - - _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store(B11[0-3][3]) - - dtrsm_AltXB_ref(a11, b11, m_remainder, 6, cs_a, cs_b, is_unitdiag); - } - dim_t n_remainder = j + D_NR; - if((n_remainder >= 4)) - { - a10 = D_A_pack; - a11 = L; //pointer to block of A to be used for TRSM - b01 = B + ((n_remainder - 4)* cs_b) + m_remainder; //pointer to block of B to be used for GEMM - b11 = B + ((n_remainder - 4)* cs_b); //pointer to block of B to be used for TRSM - - k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - #endif - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_broadcast_sd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_broadcast_sd((double const *)(b11 + cs_b *2)); - ymm3 = _mm256_broadcast_sd((double const *)(b11 + cs_b *3)); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0E); - - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm0, 0)); - _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm1, 0)); - _mm_storel_pd((b11 + cs_b * 2), _mm256_extractf128_pd(ymm2, 0)); - _mm_storel_pd((b11 + cs_b * 3), _mm256_extractf128_pd(ymm3, 0)); - - 
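
Editor's note on the 0x0E / 0x0C blends used in these m_remainder paths: the kernel loads and stores full 4-double vectors even though only the first one or two lanes belong to the fringe rows, and _mm256_blend_pd merges the freshly computed lanes with the values originally loaded from b11, so wherever a full-width store follows, the rows outside the fringe are written back with the contents they already held. A minimal standalone illustration of the mask convention follows; blend_fringe_pd is a hypothetical helper name, not a routine in this file.

    #include <immintrin.h>

    /* Keep the first m_remainder lanes from 'computed' and leave the remaining
       lanes equal to 'original' (bit i of the immediate selects 'original'). */
    static inline __m256d blend_fringe_pd(__m256d computed, __m256d original,
                                          int m_remainder)
    {
        switch (m_remainder)
        {
            case 1:  return _mm256_blend_pd(computed, original, 0x0E);
            case 2:  return _mm256_blend_pd(computed, original, 0x0C);
            case 3:  return _mm256_blend_pd(computed, original, 0x08);
            default: return computed; /* all four rows are valid */
        }
    }

With m_remainder == 1 this reproduces the 0x0E blends in the code above; the 0x0C form is the two-row variant used in the preceding block.
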
dtrsm_AltXB_ref(a11, b11, m_remainder, 4, cs_a, cs_b, is_unitdiag); - n_remainder -= 4; - } - if(n_remainder) - { - a10 = D_A_pack; - a11 = L; //pointer to block of A to be used for TRSM - b01 = B + m_remainder; //pointer to block of B to be used for GEMM - b11 = B; //pointer to block of B to be used for TRSM - - k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - #endif - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - - if(3 == n_remainder) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_broadcast_sd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_broadcast_sd((double const *)(b11 + cs_b *2)); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); - - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm0, 0)); - _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm1, 0)); - _mm_storel_pd((b11 + cs_b * 2), _mm256_extractf128_pd(ymm2, 0)); - - dtrsm_AltXB_ref(a11, b11, m_remainder, 3, cs_a, cs_b, is_unitdiag); - } - else if(2 == n_remainder) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_broadcast_sd((double const *)(b11 + cs_b *1)); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); - - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm0, 0)); - _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm1, 0)); - - dtrsm_AltXB_ref(a11, b11, m_remainder, 2, cs_a, cs_b, is_unitdiag); - } - else if(1 == n_remainder) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - b01 += 1; //move to next row of B - 
a10 += p_lda; - } - - //register to hold alpha - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); - - ///implement TRSM/// - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b *0)); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm0, 0)); - dtrsm_AltXB_ref(a11, b11, m_remainder, 1, cs_a, cs_b, is_unitdiag); - } - } - } - } - - if ((required_packing_A == 1) && - bli_mem_is_alloc( &local_mem_buf_A_s )) - { - bli_membrk_release(&rntm,&local_mem_buf_A_s); - } - return BLIS_SUCCESS; -} - -/* - * TRSM for the case AX = alpha * B, Double precision - * A is upper-triangular, non-transpose, non-unit diagonal - * dimensions A: mxm X: mxn B: mxn -*/ -BLIS_INLINE err_t bli_dtrsm_small_AuXB -( - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl -) -{ - dim_t m = bli_obj_length(b); // number of rows of matrix B - dim_t n = bli_obj_width(b); // number of columns of matrix B - - dim_t cs_a = bli_obj_col_stride(a); // column stride of A - dim_t cs_b = bli_obj_col_stride(b); // column stride of B - - dim_t i, j, k; //loop variables - dim_t k_iter; //number of times GEMM to be performed - - double AlphaVal = *(double *)AlphaObj->buffer; //value of alpha - double *L = a->buffer; //pointer to matrix A - double *B = b->buffer; //pointer to matrix B - - //pointers that point to blocks for GEMM and TRSM - double *a10, *a11, *b01, *b11; - //double *ptr_a10_dup; - - double ones = 1.0; - bool is_unitdiag = bli_obj_has_unit_diag(a); - //scratch registers - __m256d ymm0, ymm1, ymm2, ymm3; - __m256d ymm4, ymm5, ymm6, ymm7; - __m256d ymm8, ymm9, ymm10, ymm11; - __m256d ymm12, ymm13, ymm14, ymm15; - __m256d ymm16, ymm17, ymm18, ymm19; - __m256d ymm20; - - __m128d xmm5; - - gint_t required_packing_A = 1; - mem_t local_mem_buf_A_s = {0}; - double *D_A_pack = NULL; - double d11_pack[D_MR] __attribute__((aligned(64))); - rntm_t rntm; - - bli_rntm_init_from_global( &rntm ); - bli_rntm_set_num_threads_only( 1, &rntm ); - bli_membrk_rntm_set_membrk( &rntm ); - - siz_t buffer_size = bli_pool_block_size( - bli_membrk_pool( - bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), - bli_rntm_membrk(&rntm))); - - if( (D_MR * m * sizeof(double)) > buffer_size) - return BLIS_NOT_YET_IMPLEMENTED; - - if (required_packing_A == 1) - { - // Get the buffer from the pool. - bli_membrk_acquire_m(&rntm, - buffer_size, - BLIS_BITVAL_BUFFER_FOR_A_BLOCK, - &local_mem_buf_A_s); - if(FALSE==bli_mem_is_alloc(&local_mem_buf_A_s)) return BLIS_NULL_POINTER; - D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); - if(NULL==D_A_pack) return BLIS_NULL_POINTER; - } - - /* - Performs solving TRSM for 8 colmns at a time from 0 to m/8 in steps of D_MR - a. Load, transpose, Pack A (a10 block), the size of packing 8x6 to 8x (m-8) - First there will be no GEMM and no packing of a10 because it is only TRSM - b. Using packed a10 block and b01 block perform GEMM operation - c. Use GEMM outputs, perform TRSM operaton using a11, b11 and update B - d. Repeat b,c for n row of B in steps of D_NR - */ - for(i = (m - D_MR); (i + 1) > 0; i -= D_MR) - { - a10 = L + (i + D_MR)*cs_a + i; //pointer to block of A to be used for GEMM - a11 = L + (i*cs_a) + i; //pointer to block of A to be used for TRSM - - // Do transpose for a10 & store in D_A_pack - //ptr_a10_dup = D_A_pack; //ptr_a11_dup = a11; - dim_t p_lda = D_MR; // packed leading dimension - - /* - Pack current A block (a10) into packed buffer memory D_A_pack - a. 
This a10 block is used in GEMM portion only and this
-        a10 block size will be increasing by D_MR for every next iteration
-        until it reaches 8x(m-8), which is the maximum GEMM-alone block size in A
-        b. This packed buffer is reused to calculate all n rows of B matrix
-      */
-      bli_dtrsm_small_pack('L', (m-i-D_MR), 0, a10, cs_a, D_A_pack, p_lda);
-
-      /*
-         Pack 8 diagonal elements of A block into an array
-         a. This helps utilize the cache line efficiently in the TRSM operation
-         b. store ones when input is unit diagonal
-      */
-      dtrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,D_MR);
-
-      /*
-         a. Perform GEMM using a10, b01.
-         b. Perform TRSM on a11, b11
-         c. This GEMM+TRSM loop operates with 8x6 block size
-            along the n dimension for every D_NR rows of b01, where the
-            packed A buffer is reused in computing all n rows of B.
-         d. The same approach is used in the remaining fringe cases.
-      */
-      for(j = (n - D_NR); (j + 1) > 0; j -= D_NR) //loop along 'N' dimension
-      {
-          a10 = D_A_pack;
-          b01 = B + (j*cs_b) + i + D_MR; //pointer to block of B to be used for GEMM
-          b11 = B + (j* cs_b) + i;       //pointer to block of B to be used for TRSM
-
-          k_iter = (m - i - D_MR); //number of times GEMM is to be performed (in blocks of 4x4)
-
-          #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL
-              _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0);
-              _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0);
-              _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0);
-              _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0);
-              _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0);
-              _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0);
-          #endif
-
-          /*Fill zeros into ymm registers used in GEMM accumulations */
-          BLIS_SET_YMM_REG_ZEROS
-
-          /*
-             Perform GEMM between a10 and b01 blocks
-             For the first iteration there will be no GEMM operation
-             because k_iter is zero
-          */
-          BLIS_DTRSM_SMALL_GEMM_8x6(a10,b01,cs_b,p_lda,k_iter)
-
-          /*
-             Load b11 of size 6x8 and multiply with alpha
-             Add the GEMM output and perform an in-register transpose of b11
-             to perform the TRSM operation.
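
Editor's note on the diagonal packing mentioned a few lines above: d11_pack and the DTRSM_SMALL_DIV_OR_SCALE macro used in the solve below work as a pair. A hedged sketch of that convention, assuming the usual pre-inversion split (the actual helper and macro in this file may differ in details): with BLIS_ENABLE_TRSM_PREINVERSION the reciprocals of the diagonal are stored once and each per-row "divide" in the solve becomes a multiply; otherwise the diagonal entries are stored as-is and a true divide is issued; for unit-diagonal inputs ones are stored so the scale is a no-op. The names below are illustrative, not the kernel's.

    #include <immintrin.h>
    #include <stdbool.h>

    #ifdef BLIS_ENABLE_TRSM_PREINVERSION
      #define DIV_OR_SCALE_SKETCH(b, d) _mm256_mul_pd((b), (d))  /* d holds 1/a(k,k) */
    #else
      #define DIV_OR_SCALE_SKETCH(b, d) _mm256_div_pd((b), (d))  /* d holds   a(k,k) */
    #endif

    /* dim_t as defined by BLIS; cs_a is the stride along the diagonal of a11. */
    static void pack_diag_sketch(bool is_unitdiag, const double *a11, dim_t cs_a,
                                 double *d11_pack, dim_t mr)
    {
        for (dim_t k = 0; k < mr; ++k)
        {
            double dkk = is_unitdiag ? 1.0 : a11[k + k * cs_a]; /* k-th diagonal entry */
    #ifdef BLIS_ENABLE_TRSM_PREINVERSION
            d11_pack[k] = 1.0 / dkk;   /* invert once, multiply many times in the solve */
    #else
            d11_pack[k] = dkk;         /* keep as-is, divide inside the solve           */
    #endif
        }
    }

Packing the diagonal this way means the back-substitution never re-reads (or re-inverts) A's diagonal from strided memory inside the register-resident solve.
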
- */ - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - - ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 4)); - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3 + 4)); - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm12); - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm13); - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm14); - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm15); - - ymm13 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm15 = _mm256_unpacklo_pd(ymm2, ymm3); - ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); - ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); - ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm15 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *4 + 4)); - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *5 + 4)); - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm6); - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm7); - - ymm16 = _mm256_broadcast_sd((double const *)(&ones)); - ymm7 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm4 = _mm256_permute2f128_pd(ymm7,ymm16,0x20); - ymm6 = _mm256_permute2f128_pd(ymm7,ymm16,0x31); - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - - ymm5 = _mm256_permute2f128_pd(ymm0,ymm16,0x20); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm16,0x31); - ymm18 = _mm256_unpacklo_pd(ymm2, ymm3); - - ymm17 = _mm256_permute2f128_pd(ymm18,ymm16,0x20); - ymm19 = _mm256_permute2f128_pd(ymm18,ymm16,0x31); - ymm20 = _mm256_unpackhi_pd(ymm2, ymm3); - ymm18 = _mm256_permute2f128_pd(ymm20,ymm16,0x20); - ymm20 = _mm256_permute2f128_pd(ymm20,ymm16,0x31); - - /* - Compute 8x6 TRSM block by using GEMM block output in register - a. The 8x6 input (gemm outputs) are stored in combinations of ymm registers - 1. ymm15, ymm20 2. ymm14, ymm19 3. ymm13, ymm18 , 4. ymm12, ymm17 - 5. ymm11, ymm7 6. ymm10, ymm6, 7.ymm9, ymm5 8. ymm8, ymm4 - where ymm15-ymm8 holds 8x4 data and reaming 8x2 will be hold by - other registers - b. 
Towards the end do in regiser transpose of TRSM output and store in b11 - */ - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); - - //perform mul operation - ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); - ymm20 = DTRSM_SMALL_DIV_OR_SCALE(ymm20, ymm1); - - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); - - //(ROw7): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6 + 7*cs_a)); - ymm14 = _mm256_fnmadd_pd(ymm2, ymm15, ymm14); - ymm19 = _mm256_fnmadd_pd(ymm2, ymm20, ymm19); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5 + 7*cs_a)); - ymm13 = _mm256_fnmadd_pd(ymm2, ymm15, ymm13); - ymm18 = _mm256_fnmadd_pd(ymm2, ymm20, ymm18); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4 + 7*cs_a)); - ymm12 = _mm256_fnmadd_pd(ymm2, ymm15, ymm12); - ymm17 = _mm256_fnmadd_pd(ymm2, ymm20, ymm17); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3 + 7*cs_a)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm15, ymm11); - ymm7 = _mm256_fnmadd_pd(ymm2, ymm20, ymm7); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2 + 7*cs_a)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm15, ymm10); - ymm6 = _mm256_fnmadd_pd(ymm2, ymm20, ymm6); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 7*cs_a)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm15, ymm9); - ymm5 = _mm256_fnmadd_pd(ymm2, ymm20, ymm5); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7*cs_a)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm15, ymm8); - ymm4 = _mm256_fnmadd_pd(ymm2, ymm20, ymm4); - - //perform mul operation - ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); - ymm19 = DTRSM_SMALL_DIV_OR_SCALE(ymm19, ymm1); - - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); - - //(ROw6): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5 + 6*cs_a)); - ymm13 = _mm256_fnmadd_pd(ymm2, ymm14, ymm13); - ymm18 = _mm256_fnmadd_pd(ymm2, ymm19, ymm18); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4 + 6*cs_a)); - ymm12 = _mm256_fnmadd_pd(ymm2, ymm14, ymm12); - ymm17 = _mm256_fnmadd_pd(ymm2, ymm19, ymm17); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3 + 6*cs_a)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm14, ymm11); - ymm7 = _mm256_fnmadd_pd(ymm2, ymm19, ymm7); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2 + 6*cs_a)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm14, ymm10); - ymm6 = _mm256_fnmadd_pd(ymm2, ymm19, ymm6); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 6*cs_a)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm14, ymm9); - ymm5 = _mm256_fnmadd_pd(ymm2, ymm19, ymm5); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6*cs_a)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm14, ymm8); - ymm4 = _mm256_fnmadd_pd(ymm2, ymm19, ymm4); - - //perform mul operation - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1); - ymm18 = DTRSM_SMALL_DIV_OR_SCALE(ymm18, ymm1); - - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); - - //(ROw5): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4 + 5*cs_a)); - ymm12 = _mm256_fnmadd_pd(ymm2, ymm13, ymm12); - ymm17 = _mm256_fnmadd_pd(ymm2, ymm18, ymm17); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3 + 5*cs_a)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm13, ymm11); - ymm7 = _mm256_fnmadd_pd(ymm2, ymm18, ymm7); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2 + 5*cs_a)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm13, ymm10); - ymm6 = _mm256_fnmadd_pd(ymm2, ymm18, ymm6); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 5*cs_a)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm13, ymm9); - ymm5 = _mm256_fnmadd_pd(ymm2, ymm18, 
ymm5); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm13, ymm8); - ymm4 = _mm256_fnmadd_pd(ymm2, ymm18, ymm4); - - //perform mul operation - ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); - ymm17 = DTRSM_SMALL_DIV_OR_SCALE(ymm17, ymm1); - - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //(ROw4): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3 + 4*cs_a)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm12, ymm11); - ymm7 = _mm256_fnmadd_pd(ymm2, ymm17, ymm7); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2 + 4*cs_a)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm12, ymm10); - ymm6 = _mm256_fnmadd_pd(ymm2, ymm17, ymm6); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 4*cs_a)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm12, ymm9); - ymm5 = _mm256_fnmadd_pd(ymm2, ymm17, ymm5); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm12, ymm8); - ymm4 = _mm256_fnmadd_pd(ymm2, ymm17, ymm4); - - //perform mul operation - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm1); - - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(ROw3): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2 + 3*cs_a)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm11, ymm10); - ymm6 = _mm256_fnmadd_pd(ymm2, ymm7, ymm6); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 3*cs_a)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm11, ymm9); - ymm5 = _mm256_fnmadd_pd(ymm2, ymm7, ymm5); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm11, ymm8); - ymm4 = _mm256_fnmadd_pd(ymm2, ymm7, ymm4); - - //perform mul operation - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); - ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm1); - - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(ROw2): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 2*cs_a)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm10, ymm9); - ymm5 = _mm256_fnmadd_pd(ymm2, ymm6, ymm5); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm10, ymm8); - ymm4 = _mm256_fnmadd_pd(ymm2, ymm6, ymm4); - - //perform mul operation - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm1); - - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); - - //(ROw2): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm9, ymm8); - ymm4 = _mm256_fnmadd_pd(ymm2, ymm5, ymm4); - - //perform mul operation - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); - ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm1); - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); - - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); - - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm12, ymm13); - ymm3 = 
_mm256_unpacklo_pd(ymm14, ymm15); - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); - - ///unpack high/// - ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); - ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); - ymm3 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); - - _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm0); - _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm1); - _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm2); - _mm256_storeu_pd((double *)(b11 + cs_b * 3 + 4), ymm3); - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm4, ymm5); - ymm3 = _mm256_unpacklo_pd(ymm6, ymm7); - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); - - ///unpack high/// - ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); - ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); - - _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); - _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm17, ymm18); - ymm3 = _mm256_unpacklo_pd(ymm19, ymm20); - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); - - ///unpack high/// - ymm17 = _mm256_unpackhi_pd(ymm17, ymm18); - ymm18 = _mm256_unpackhi_pd(ymm19, ymm20); - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm17, ymm18, 0x20); - - _mm256_storeu_pd((double *)(b11 + cs_b * 4 + 4), ymm0); - _mm256_storeu_pd((double *)(b11 + cs_b * 5 + 4), ymm1); - - } - - dim_t n_remainder = j + D_NR; - if(n_remainder >= 4) - { - a10 = D_A_pack; - a11 = L + (i*cs_a) + i; - b01 = B + ((n_remainder - 4)* cs_b) + i + D_MR; - b11 = B + ((n_remainder - 4)* cs_b) + i; - - k_iter = (m - i - D_MR); - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - ymm12 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); - ymm14 = _mm256_setzero_pd(); - ymm15 = _mm256_setzero_pd(); - - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15); - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); - ymm5 = 
_mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); - ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 4)); - ymm7 = _mm256_loadu_pd((double const *)(b11 + cs_b *3 + 4)); - - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] * alpha -= B01[0-3][3] - ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); - ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); - ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); - ymm7 = _mm256_fmsub_pd(ymm7, ymm16, ymm15); - - ///implement TRSM/// - - ///transpose of B11// - ///unpacklow/// - ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - - ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); - ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); - - //rearrange low elements - ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] - - ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); - ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); - - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - - ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); - ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); - - //rearrange high elements - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); - - //perform mul operation - ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); - - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); - - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6 + 7*cs_a)); - ymm3 = _mm256_broadcast_sd((double const *)(a11 + 5 + 7*cs_a)); - ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4 + 7*cs_a)); - ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3 + 7*cs_a)); - ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2 + 7*cs_a)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + 1 + 7*cs_a)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + 7*cs_a)); - - //(ROw7): FMA operations - ymm14 = _mm256_fnmadd_pd(ymm2, ymm15, ymm14); - ymm13 = _mm256_fnmadd_pd(ymm3, ymm15, ymm13); - ymm12 = _mm256_fnmadd_pd(ymm4, ymm15, ymm12); - ymm11 = _mm256_fnmadd_pd(ymm5, ymm15, ymm11); - ymm10 = _mm256_fnmadd_pd(ymm6, ymm15, ymm10); - ymm9 = _mm256_fnmadd_pd(ymm7, ymm15, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm16, ymm15, ymm8); - - //perform mul operation - ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); - - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); - - ymm3 = _mm256_broadcast_sd((double const *)(a11 + 5 + 6*cs_a)); - ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4 + 6*cs_a)); - ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3 + 6*cs_a)); - ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2 + 6*cs_a)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + 1 + 6*cs_a)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + 6*cs_a)); - - //(ROw6): 
FMA operations - ymm13 = _mm256_fnmadd_pd(ymm3, ymm14, ymm13); - ymm12 = _mm256_fnmadd_pd(ymm4, ymm14, ymm12); - ymm11 = _mm256_fnmadd_pd(ymm5, ymm14, ymm11); - ymm10 = _mm256_fnmadd_pd(ymm6, ymm14, ymm10); - ymm9 = _mm256_fnmadd_pd(ymm7, ymm14, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm16, ymm14, ymm8); - - //perform mul operation - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1); - - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); - - ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4 + 5*cs_a)); - ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3 + 5*cs_a)); - ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2 + 5*cs_a)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + 1 + 5*cs_a)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a)); - - //(ROw5): FMA operations - ymm12 = _mm256_fnmadd_pd(ymm4, ymm13, ymm12); - ymm11 = _mm256_fnmadd_pd(ymm5, ymm13, ymm11); - ymm10 = _mm256_fnmadd_pd(ymm6, ymm13, ymm10); - ymm9 = _mm256_fnmadd_pd(ymm7, ymm13, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm16, ymm13, ymm8); - - //perform mul operation - ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); - - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3 + 4*cs_a)); - ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2 + 4*cs_a)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + 1 + 4*cs_a)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a)); - - //(ROw4): FMA operations - ymm11 = _mm256_fnmadd_pd(ymm5, ymm12, ymm11); - ymm10 = _mm256_fnmadd_pd(ymm6, ymm12, ymm10); - ymm9 = _mm256_fnmadd_pd(ymm7, ymm12, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm16, ymm12, ymm8); - - //perform mul operation - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2 + 3*cs_a)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + 1 + 3*cs_a)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); - - //(ROw3): FMA operations - ymm10 = _mm256_fnmadd_pd(ymm6, ymm11, ymm10); - ymm9 = _mm256_fnmadd_pd(ymm7, ymm11, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm16, ymm11, ymm8); - - //perform mul operation - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); - - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - ymm7 = _mm256_broadcast_sd((double const *)(a11 + 1 + 2*cs_a)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); - - //(ROw2): FMA operations - ymm9 = _mm256_fnmadd_pd(ymm7, ymm10, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm16, ymm10, ymm8); - - //perform mul operation - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); - - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); - - ymm16 = _mm256_broadcast_sd((double const *)(a11 + 1*cs_a)); - - //(ROw2): FMA operations - ymm8 = _mm256_fnmadd_pd(ymm16, ymm9, ymm8); - - //perform mul operation - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); - ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); - ymm6 = 
_mm256_permute2f128_pd(ymm5, ymm7, 0x31); - - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); - ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); - ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); - - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); - _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm6); - _mm256_storeu_pd((double *)(b11 + cs_b * 3 + 4), ymm7); - n_remainder -=4; - } - - if(n_remainder) //implementation fo remaining columns(when 'N' is not a multiple of D_NR)() n = 3 - { - a10 = D_A_pack; - a11 = L + (i*cs_a) + i; - b01 = B + i + D_MR; - b11 = B + i; - - k_iter = (m - i - D_MR); - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm12 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); - ymm14 = _mm256_setzero_pd(); - - if(3 == n_remainder) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); - ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); - ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 4)); - - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); - - ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); - ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); - ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); - ymm7 = _mm256_broadcast_sd((double const *)(&ones)); - } - else if(2 == n_remainder) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM 
operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); - - b01 += 1; //move to next row of B - a10 += p_lda; - - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - - ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); - ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); - - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] - ymm2 = _mm256_broadcast_sd((double const *)(&ones)); - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); - - ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); - ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); - ymm6 = _mm256_broadcast_sd((double const *)(&ones)); - ymm7 = _mm256_broadcast_sd((double const *)(&ones)); - } - else if(1 == n_remainder) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - - ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); - - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_broadcast_sd((double const *)(&ones)); - ymm2 = _mm256_broadcast_sd((double const *)(&ones)); - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); - - ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); - ymm5 = _mm256_broadcast_sd((double const *)(&ones)); - ymm6 = _mm256_broadcast_sd((double const *)(&ones)); - ymm7 = _mm256_broadcast_sd((double const *)(&ones)); - } - ///implement TRSM/// - - ///transpose of B11// - ///unpacklow/// - ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - - ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); - ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); - - //rearrange low elements - ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] - - ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); - ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); - - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - - ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); - ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); - - //rearrange high elements - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); 
//B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); - - //perform mul operation - ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); - - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); - - //(ROw7): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6 + 7*cs_a)); - ymm14 = _mm256_fnmadd_pd(ymm2, ymm15, ymm14); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5 + 7*cs_a)); - ymm13 = _mm256_fnmadd_pd(ymm2, ymm15, ymm13); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4 + 7*cs_a)); - ymm12 = _mm256_fnmadd_pd(ymm2, ymm15, ymm12); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3 + 7*cs_a)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm15, ymm11); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2 + 7*cs_a)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm15, ymm10); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 7*cs_a)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm15, ymm9); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7*cs_a)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm15, ymm8); - - //perform mul operation - ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); - - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); - - //(ROw6): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5 + 6*cs_a)); - ymm13 = _mm256_fnmadd_pd(ymm2, ymm14, ymm13); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4 + 6*cs_a)); - ymm12 = _mm256_fnmadd_pd(ymm2, ymm14, ymm12); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3 + 6*cs_a)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm14, ymm11); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2 + 6*cs_a)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm14, ymm10); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 6*cs_a)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm14, ymm9); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6*cs_a)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm14, ymm8); - - //perform mul operation - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1); - - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); - - //(ROw5): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4 + 5*cs_a)); - ymm12 = _mm256_fnmadd_pd(ymm2, ymm13, ymm12); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3 + 5*cs_a)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm13, ymm11); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2 + 5*cs_a)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm13, ymm10); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 5*cs_a)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm13, ymm9); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm13, ymm8); - - //perform mul operation - ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); - - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //(ROw4): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3 + 4*cs_a)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm12, ymm11); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2 + 4*cs_a)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm12, ymm10); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 4*cs_a)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm12, ymm9); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm12, ymm8); - - //perform mul 
operation - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(ROw3): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2 + 3*cs_a)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm11, ymm10); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 3*cs_a)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm11, ymm9); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm11, ymm8); - - //perform mul operation - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); - - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(ROw2): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 2*cs_a)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm10, ymm9); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm10, ymm8); - - //perform mul operation - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); - - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); - - //(ROw2): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1*cs_a)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm9, ymm8); - - //perform mul operation - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); - ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); - ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); - - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); - ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); - ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); - - if(3 == n_remainder) - { - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); - _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm6); - } - else if(2 == n_remainder) - { - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); - _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); - } - else if(1 == n_remainder) - { - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); - } - } - }// End of multiples of D_MR blocks in m-dimension - - // Repetative A blocks will be 4*4 - dim_t m_remainder = i + D_MR; - if(m_remainder >= 4) - { - i = m_remainder - 4; - a10 = L + 
(i + 4)*cs_a + i; //pointer to block of A to be used for GEMM - a11 = L + (i*cs_a) + i; //pointer to block of A to be used for TRSM - - // Do transpose for a10 & store in D_A_pack - double *ptr_a10_dup = D_A_pack; - double *ptr_a11_dup = a11; - dim_t p_lda = 4; // packed leading dimension - for(dim_t x =0;x < m-i-4;x++) - { - ymm0 = _mm256_loadu_pd((double const *)(a10 + x*cs_a)); - _mm256_storeu_pd((double *)(ptr_a10_dup + x*p_lda), ymm0); - } - - ymm4 = _mm256_broadcast_sd((double const *)&ones); - if(!is_unitdiag) - { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_sd((double const *)(a11)); - ymm1 = _mm256_broadcast_sd((double const *)(a11+cs_a*1 + 1)); - ymm2 = _mm256_broadcast_sd((double const *)(a11+cs_a*2 + 2)); - ymm3 = _mm256_broadcast_sd((double const *)(a11+cs_a*3 + 3)); - - //Pick one element each column and create a 4 element vector and store - ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); - ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); - #ifdef BLIS_DISABLE_TRSM_PREINVERSION - ymm4 = ymm1; - #endif - #ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm4 = _mm256_div_pd(ymm4, ymm1); - #endif - } - _mm256_storeu_pd((double *)(d11_pack), ymm4); - - //cols - for(j = (n - D_NR); (j + 1) > 0; j -= D_NR) //loop along 'N' dimension - { - a10 = D_A_pack; - a11 = ptr_a11_dup; //pointer to block of A to be used for TRSM - b01 = B + (j*cs_b) + i + 4; //pointer to block of B to be used for GEMM - b11 = B + (j* cs_b) + i; //pointer to block of B to be used for TRSM - - k_iter = (m - i - 4); //number of times GEMM to be performed(in blocks of 4x4) - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); - #endif - - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); - ymm4 = _mm256_fmadd_pd(ymm2, ymm0, ymm4); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); 
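
Editor's note before the "transpose of B11" block that follows: the unpacklo/unpackhi plus _mm256_permute2f128_pd sequence used here (and throughout this file) is the standard AVX idiom for transposing a 4x4 block of doubles held in ymm registers. A minimal standalone sketch; transpose_4x4_pd and the r/c names are illustrative, not routines in this file.

    #include <immintrin.h>

    /* Transpose a 4x4 block of doubles: r0..r3 are rows on input,
       *c0..*c3 receive the corresponding columns on output. */
    static inline void transpose_4x4_pd(__m256d r0, __m256d r1,
                                        __m256d r2, __m256d r3,
                                        __m256d *c0, __m256d *c1,
                                        __m256d *c2, __m256d *c3)
    {
        __m256d lo01 = _mm256_unpacklo_pd(r0, r1); /* r0[0] r1[0] r0[2] r1[2] */
        __m256d hi01 = _mm256_unpackhi_pd(r0, r1); /* r0[1] r1[1] r0[3] r1[3] */
        __m256d lo23 = _mm256_unpacklo_pd(r2, r3); /* r2[0] r3[0] r2[2] r3[2] */
        __m256d hi23 = _mm256_unpackhi_pd(r2, r3); /* r2[1] r3[1] r2[3] r3[3] */

        *c0 = _mm256_permute2f128_pd(lo01, lo23, 0x20); /* r0[0] r1[0] r2[0] r3[0] */
        *c1 = _mm256_permute2f128_pd(hi01, hi23, 0x20); /* r0[1] r1[1] r2[1] r3[1] */
        *c2 = _mm256_permute2f128_pd(lo01, lo23, 0x31); /* r0[2] r1[2] r2[2] r3[2] */
        *c3 = _mm256_permute2f128_pd(hi01, hi23, 0x31); /* r0[3] r1[3] r2[3] r3[3] */
    }

In the kernel the same shuffles are interleaved with the alpha fmsub step, and the wider 8-row panels apply the idiom separately to the upper and lower 4-row halves.
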
- - ///transpose of B11// - ///unpacklow/// - ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - - //rearrange low elements - ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] - - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - - //rearrange high elements - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); - - ymm16 = _mm256_broadcast_sd((double const *)(&ones)); - - ////unpacklow//// - ymm7 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - - //rearrange low elements - ymm4 = _mm256_permute2f128_pd(ymm7,ymm16,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm6 = _mm256_permute2f128_pd(ymm7,ymm16,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] - - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - - //rearrange high elements - ymm5 = _mm256_permute2f128_pd(ymm0,ymm16,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm7 = _mm256_permute2f128_pd(ymm0,ymm16,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //perform mul operation - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm1); - - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(ROw3): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2 + 3*cs_a)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm11, ymm10); - ymm6 = _mm256_fnmadd_pd(ymm2, ymm7, ymm6); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 3*cs_a)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm11, ymm9); - ymm5 = _mm256_fnmadd_pd(ymm2, ymm7, ymm5); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm11, ymm8); - ymm4 = _mm256_fnmadd_pd(ymm2, ymm7, ymm4); - - //perform mul operation - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); - ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm1); - - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(ROw2): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 2*cs_a)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm10, ymm9); - ymm5 = _mm256_fnmadd_pd(ymm2, ymm6, ymm5); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm10, ymm8); - ymm4 = _mm256_fnmadd_pd(ymm2, ymm6, ymm4); - - //perform mul operation - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm1); - - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); - - //(ROw2): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm9, ymm8); - ymm4 = _mm256_fnmadd_pd(ymm2, ymm5, ymm4); - - //perform mul operation - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); - ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, 
ymm1); - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - - ///unpack high/// - ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - - _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store B11[1][0-3] - } - - dim_t n_remainder = j + D_NR; - if((n_remainder >= 4)) - { - a10 = D_A_pack; - a11 = L + (i*cs_a) + i; //pointer to block of A to be used for TRSM - b01 = B + ((n_remainder - 4)* cs_b) + i + 4; //pointer to block of B to be used for GEMM - b11 = B + ((n_remainder - 4)* cs_b) + i; //pointer to block of B to be used for TRSM - - k_iter = (m - i - 4); //number of times GEMM to be performed - - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - #endif - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = 
_mm256_loadu_pd((double const *)(b11 + cs_b *2)); - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - - ///transpose of B11// - ///unpacklow/// - ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - - //rearrange low elements - ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] - - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - - //rearrange high elements - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //perform mul operation - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(ROw3): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2 + 3*cs_a)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm11, ymm10); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 3*cs_a)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm11, ymm9); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm11, ymm8); - - //perform mul operation - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); - - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(ROw2): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 2*cs_a)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm10, ymm9); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm10, ymm8); - - //perform mul operation - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); - - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); - - //(ROw2): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm9, ymm8); - - //perform mul operation - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] - _mm256_storeu_pd((double *)(b11 + 
cs_b * 3), ymm3); //store B11[3][0-3] - n_remainder = n_remainder - 4; - } - - if(n_remainder) //implementation fo remaining columns(when 'N' is not a multiple of D_NR)() n = 3 - { - a10 = D_A_pack; - a11 = L + (i*cs_a) + i; - b01 = B + i + 4; - b11 = B + i; - - k_iter = (m - i - 4); - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - - if(3 == n_remainder) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); - } - else if(2 == n_remainder) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] - ymm2 = _mm256_broadcast_sd((double const *)(&ones)); - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); - } - else if(1 == n_remainder) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_broadcast_sd((double const *)(&ones)); - ymm2 = _mm256_broadcast_sd((double const *)(&ones)); - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); - } - - ///implement TRSM/// - - ///transpose of B11// - ///unpacklow/// - ymm9 = 
_mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - - //rearrange low elements - ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] - - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - - //rearrange high elements - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //perform mul operation - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(ROw3): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2 + 3*cs_a)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm11, ymm10); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 3*cs_a)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm11, ymm9); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm11, ymm8); - - //perform mul operation - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); - - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(ROw2): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 2*cs_a)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm10, ymm9); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm10, ymm8); - - //perform mul operation - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); - - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); - - //(ROw2): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm9, ymm8); - - //perform mul operation - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - if(3 == n_remainder) - { - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] - } - else if(2 == n_remainder) - { - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - } - else if(1 == n_remainder) - { - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - } - } - m_remainder -= 4; - } - - if(m_remainder) - { - a10 
= L + m_remainder*cs_a; - - // Do transpose for a10 & store in D_A_pack - double *ptr_a10_dup = D_A_pack; - if(3 == m_remainder) // Repetative A blocks will be 3*3 - { - dim_t p_lda = 4; // packed leading dimension - for(dim_t x =0;x < m-m_remainder;x++) - { - ymm0 = _mm256_loadu_pd((double const *)(a10 + x*cs_a)); - _mm256_storeu_pd((double *)(ptr_a10_dup + x*p_lda), ymm0); - } - //cols - for(j = (n - D_NR); (j + 1) > 0; j -= D_NR) //loop along 'N' dimension - { - a10 = D_A_pack; - a11 = L; //pointer to block of A to be used for TRSM - b01 = B + (j* cs_b) + m_remainder; //pointer to block of B to be used for GEMM - b11 = B + (j* cs_b); //pointer to block of B to be used for TRSM - - k_iter = (m - m_remainder); //number of times GEMM to be performed - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); - #endif - - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); - ymm4 = _mm256_fmadd_pd(ymm2, ymm0, ymm4); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ///GEMM code ends/// - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to store alpha value - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x08); - - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); - - _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); 
//store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store(B11[0-3][3]) - - dtrsm_AuXB_ref(a11, b11, m_remainder, 6, cs_a, cs_b, is_unitdiag); - } - - dim_t n_remainder = j + D_NR; - if((n_remainder >= 4)) - { - a10 = D_A_pack; - a11 = L; //pointer to block of A to be used for TRSM - b01 = B + ((n_remainder - 4)* cs_b) + m_remainder; //pointer to block of B to be used for GEMM - b11 = B + ((n_remainder - 4)* cs_b); //pointer to block of B to be used for TRSM - - k_iter = (m - m_remainder); //number of times GEMM to be performed - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - #endif - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); - ymm3 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3 + 2)); - ymm3 = _mm256_insertf128_pd(ymm3, xmm5, 0); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x08); - - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - xmm5 = _mm256_extractf128_pd(ymm3, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 3),xmm5); - _mm_storel_pd((b11 + cs_b * 3 + 2), _mm256_extractf128_pd(ymm3, 1)); - - dtrsm_AuXB_ref(a11, b11, m_remainder, 4, cs_a, cs_b, is_unitdiag); - n_remainder -= 4; - } - if(n_remainder) - { - a10 = D_A_pack; - a11 = L; //pointer to block of A to be used for TRSM - b01 = B + m_remainder; //pointer to block of B to be used for GEMM - b11 = B; //pointer to block of B to be used for TRSM - - k_iter = (m - m_remainder); //number of times GEMM to be performed - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - #endif - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - - if(3 == n_remainder) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) 
//loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2)); - ymm2 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 2 + 2)); - ymm2 = _mm256_insertf128_pd(ymm2, xmm5, 0); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); - - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - xmm5 = _mm256_extractf128_pd(ymm2, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 2), xmm5); - _mm_storel_pd((b11 + cs_b * 2 + 2), _mm256_extractf128_pd(ymm2, 1)); - - dtrsm_AuXB_ref(a11, b11, m_remainder, 3, cs_a, cs_b, is_unitdiag); - } - else if(2 == n_remainder) - { - - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1)); - ymm1 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 1 + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); - - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - xmm5 = _mm256_extractf128_pd(ymm1, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 1), xmm5); - _mm_storel_pd((b11 + cs_b * 1 + 2), _mm256_extractf128_pd(ymm1, 1)); - - dtrsm_AuXB_ref(a11, b11, m_remainder, 2, cs_a, cs_b, is_unitdiag); - } - else if(1 == n_remainder) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 0)); - ymm0 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 0 + 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - - xmm5 = 
_mm256_extractf128_pd(ymm0, 0); - _mm_storeu_pd((double *)(b11), xmm5); - _mm_storel_pd((b11 + 2), _mm256_extractf128_pd(ymm0, 1)); - - dtrsm_AuXB_ref(a11, b11, m_remainder, 1, cs_a, cs_b, is_unitdiag); - } - } - } - else if(2 == m_remainder) // Repetative A blocks will be 2*2 - { - dim_t p_lda = 4; // packed leading dimension - for(dim_t x =0;x < m-m_remainder;x++) - { - ymm0 = _mm256_loadu_pd((double const *)(a10 + x*cs_a)); - _mm256_storeu_pd((double *)(ptr_a10_dup + x*p_lda), ymm0); - } - //cols - for(j = (n - D_NR); (j + 1) > 0; j -= D_NR) //loop along 'N' dimension - { - a10 = D_A_pack; - a11 = L; //pointer to block of A to be used for TRSM - b01 = B + (j* cs_b) + m_remainder; //pointer to block of B to be used for GEMM - b11 = B + (j* cs_b); //pointer to block of B to be used for TRSM - - k_iter = (m - m_remainder); //number of times GEMM to be performed - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); - #endif - - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); - ymm4 = _mm256_fmadd_pd(ymm2, ymm0, ymm4); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ///GEMM code ends/// - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to store alpha value - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0C); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0C); - - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); - 
ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); - - _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store(B11[0-3][3]) - - dtrsm_AuXB_ref(a11, b11, m_remainder, 6, cs_a, cs_b, is_unitdiag); - } - dim_t n_remainder = j + D_NR; - if((n_remainder >= 4)) - { - a10 = D_A_pack; - a11 = L; //pointer to block of A to be used for TRSM - b01 = B + ((n_remainder - 4)* cs_b) + m_remainder; //pointer to block of B to be used for GEMM - b11 = B + ((n_remainder - 4)* cs_b); //pointer to block of B to be used for TRSM - - k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - #endif - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); - ymm3 = _mm256_insertf128_pd(ymm3, xmm5, 0); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0C); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0C); - - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - xmm5 = _mm256_extractf128_pd(ymm3, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 3), xmm5); - - dtrsm_AuXB_ref(a11, b11, m_remainder, 4, cs_a, cs_b, is_unitdiag); - n_remainder -= 4; - } - if(n_remainder) - { - a10 = D_A_pack; - a11 = L; //pointer to block of A to be used for TRSM - b01 = B + m_remainder; //pointer to block of B to be used for GEMM - b11 = B; //pointer to block of B to be used for TRSM - - k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - #endif - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - - if(3 == n_remainder) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop 
for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2)); - ymm2 = _mm256_insertf128_pd(ymm2, xmm5, 0); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0C); - - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - xmm5 = _mm256_extractf128_pd(ymm2, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 2), xmm5); - - dtrsm_AuXB_ref(a11, b11, m_remainder, 3, cs_a, cs_b, is_unitdiag); - } - else if(2 == n_remainder) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); - - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - xmm5 = _mm256_extractf128_pd(ymm1, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 1), xmm5); - - dtrsm_AuXB_ref(a11, b11, m_remainder, 2, cs_a, cs_b, is_unitdiag); - } - else if(1 == n_remainder) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 0)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); - - xmm5 = _mm256_extractf128_pd(ymm0, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 0), xmm5); - - dtrsm_AuXB_ref(a11, b11, m_remainder, 1, cs_a, cs_b, is_unitdiag); - } - } - - } - else if(1 == m_remainder) // Repetative A blocks will be 1*1 - { - dim_t p_lda = 4; // packed leading dimension - for(dim_t x =0;x < m-m_remainder;x++) - { - ymm0 = _mm256_loadu_pd((double const 
*)(a10 + x*cs_a)); - _mm256_storeu_pd((double *)(ptr_a10_dup + x*p_lda), ymm0); - } - //cols - for(j = (n - D_NR); (j + 1) > 0; j -= D_NR) //loop along 'N' dimension - { - a10 = D_A_pack; - a11 = L; //pointer to block of A to be used for TRSM - b01 = B + (j* cs_b) + m_remainder; //pointer to block of B to be used for GEMM - b11 = B + (j* cs_b); //pointer to block of B to be used for TRSM - - k_iter = (m - m_remainder); //number of times GEMM to be performed - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - #endif - - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); - ymm4 = _mm256_fmadd_pd(ymm2, ymm0, ymm4); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ///GEMM code ends/// - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to store alpha value - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0E); - - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); - - _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store(B11[0-3][3]) - - dtrsm_AuXB_ref(a11, b11, m_remainder, 6, cs_a, cs_b, is_unitdiag); - } - dim_t n_remainder = j + D_NR; - if((n_remainder >= 4)) - { - a10 = D_A_pack; - a11 = L; //pointer to block of A to be used for TRSM - b01 = B + ((n_remainder - 4)* cs_b) + m_remainder; //pointer to block of B to be used for GEMM - b11 = B + 
((n_remainder - 4)* cs_b); //pointer to block of B to be used for TRSM - - k_iter = (m - m_remainder); //number of times GEMM to be performed - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - #endif - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_broadcast_sd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_broadcast_sd((double const *)(b11 + cs_b *2)); - ymm3 = _mm256_broadcast_sd((double const *)(b11 + cs_b *3)); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0E); - - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm0, 0)); - _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm1, 0)); - _mm_storel_pd((b11 + cs_b * 2), _mm256_extractf128_pd(ymm2, 0)); - _mm_storel_pd((b11 + cs_b * 3), _mm256_extractf128_pd(ymm3, 0)); - - dtrsm_AuXB_ref(a11, b11, m_remainder, 4, cs_a, cs_b, is_unitdiag); - n_remainder -= 4; - } - if(n_remainder) - { - a10 = D_A_pack; - a11 = L; //pointer to block of A to be used for TRSM - b01 = B + m_remainder; //pointer to block of B to be used for GEMM - b11 = B; //pointer to block of B to be used for TRSM - - k_iter = (m - m_remainder); //number of times GEMM to be performed - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - #endif - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - - if(3 == n_remainder) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - ymm0 = 
_mm256_broadcast_sd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_broadcast_sd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_broadcast_sd((double const *)(b11 + cs_b *2)); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); - - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm0, 0)); - _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm1, 0)); - _mm_storel_pd((b11 + cs_b * 2), _mm256_extractf128_pd(ymm2, 0)); - - dtrsm_AuXB_ref(a11, b11, m_remainder, 3, cs_a, cs_b, is_unitdiag); - } - else if(2 == n_remainder) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_broadcast_sd((double const *)(b11 + cs_b *1)); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); - - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm0, 0)); - _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm1, 0)); - - dtrsm_AuXB_ref(a11, b11, m_remainder, 2, cs_a, cs_b, is_unitdiag); - } - else if(1 == n_remainder) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - //register to hold alpha - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); - - ///implement TRSM/// - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b *0)); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm0, 0)); - dtrsm_AuXB_ref(a11, b11, m_remainder, 1, cs_a, cs_b, is_unitdiag); - } - } - } - } - - if ((required_packing_A == 1) && - bli_mem_is_alloc( &local_mem_buf_A_s )) - { - bli_membrk_release(&rntm, &local_mem_buf_A_s); - } - return BLIS_SUCCESS; -} - -/*implements TRSM for the case XA = alpha * B - *A is upper triangular, non-unit diagonal/unit diagonal, no transpose - *dimensions: X:mxn A:nxn B: mxn - * - * b11---> a01 ----> - ***************** *********** - *b01*b11* * * * * * * -b11 * * * * * **a01 * * a11 - | ***************** ********* | - | * * * * * *a11* * | - | * * * * * * * * | - v ***************** ****** v - * * * * * * * - * * * * * * * - ***************** * * - * - -*/ - -BLIS_INLINE err_t bli_dtrsm_small_XAuB -( - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl -) -{ - dim_t m = bli_obj_length(b); //number of rows - dim_t n = bli_obj_width(b); //number of columns - - dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A - dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B - - dim_t 
i, j, k; //loop variablse - dim_t k_iter; //determines the number of GEMM operations to be done - - double ones = 1.0; - bool is_unitdiag = bli_obj_has_unit_diag(a); - - double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha - double* restrict L = a->buffer; //pointer to matrix A - double* restrict B = b->buffer; //pointer to matrix B - - double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks - - gint_t required_packing_A = 1; - mem_t local_mem_buf_A_s = {0}; - double *D_A_pack = NULL; - double d11_pack[D_MR] __attribute__((aligned(64))); - rntm_t rntm; - - bli_rntm_init_from_global( &rntm ); - bli_rntm_set_num_threads_only( 1, &rntm ); - bli_membrk_rntm_set_membrk( &rntm ); - - siz_t buffer_size = bli_pool_block_size( - bli_membrk_pool( - bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), - bli_rntm_membrk(&rntm))); - - if( (D_NR * n * sizeof(double)) > buffer_size) - return BLIS_NOT_YET_IMPLEMENTED; - - if (required_packing_A == 1) - { - // Get the buffer from the pool. - bli_membrk_acquire_m(&rntm, - buffer_size, - BLIS_BITVAL_BUFFER_FOR_A_BLOCK, - &local_mem_buf_A_s); - if(FALSE==bli_mem_is_alloc(&local_mem_buf_A_s)) return BLIS_NULL_POINTER; - D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); - if(NULL==D_A_pack) return BLIS_NULL_POINTER; - } - - //ymm scratch reginsters - __m256d ymm0, ymm1, ymm2, ymm3; - __m256d ymm4, ymm5, ymm6, ymm7; - __m256d ymm8, ymm9, ymm10, ymm11; - __m256d ymm12, ymm13, ymm14, ymm15; - - __m128d xmm5; - - /* - Performs solving TRSM for 6 rows at a time from 0 to n/6 in steps of D_NR - a. Load and pack A (a01 block), the size of packing 6x6 to 6x (n-6) - First there will be no GEMM and no packing of a01 because it is only TRSM - b. Using packed a01 block and b10 block perform GEMM operation - c. Use GEMM outputs, perform TRSM operation using a11, b11 and update B - d. Repeat b for m cols of B in steps of D_MR - */ - - for(j = 0; (j+D_NR-1) < n; j += D_NR) //loop along 'N' direction - { - a01 = L + j*cs_a; //pointer to block of A to be used in GEMM - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - dim_t p_lda = j; // packed leading dimension - - /* - Pack current A block (a01) into packed buffer memory D_A_pack - a. This a10 block is used in GEMM portion only and this - a01 block size will be increasing by D_NR for every next iteration - until it reaches 6x(n-6) which is the maximum GEMM alone block size in A - b. This packed buffer is reused to calculate all m cols of B matrix - */ - bli_dtrsm_small_pack('R', j, 0, a01, cs_a, D_A_pack, p_lda); - - /* - Pack 6 diagonal elements of A block into an array - a. This helps in utilze cache line efficiently in TRSM operation - b. store ones when input is unit diagonal - */ - dtrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,D_NR); - - /* - a. Perform GEMM using a01, b10. - b. Perform TRSM on a11, b11 - c. This loop GEMM+TRSM loops operates with 8x6 block size - along m dimension for every D_MR columns of B10 where - packed A buffer is reused in computing all m cols of B. - d. Same approach is used in remaining fringe cases. 
- */ - for(i = 0; (i+D_MR-1) < m; i += D_MR) //loop along 'M' direction - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); - #endif - - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS - - /* - Peform GEMM between a01 and b10 blocks - For first itteration there will be no GEMM operation - where k_iter are zero - */ - BLIS_DTRSM_SMALL_GEMM_6x8(a01,b10,cs_b,p_lda,k_iter) - - /* - Load b11 of size 8x6 and multiply with alpha - Add the GEMM output to b11 - and peform TRSM operation. - */ - - ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + 4)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - ymm4 = _mm256_fmsub_pd(ymm1, ymm15, ymm4); //B11[4-7][0] * alpha-= ymm1 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b + 4)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] - - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - ymm6 = _mm256_fmsub_pd(ymm1, ymm15, ymm6); //B11[4-7][1] * alpha -= ymm3 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*2 + 4)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] - - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - ymm8 = _mm256_fmsub_pd(ymm1, ymm15, ymm8); //B11[4-7][2] * alpha -= ymm5 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*3 + 4)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] - - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 - ymm10 = _mm256_fmsub_pd(ymm1, ymm15, ymm10); //B11[4-7][3] * alpha -= ymm7 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*4 + 4)); - - ymm11 = _mm256_fmsub_pd(ymm0, ymm15, ymm11); - ymm12 = _mm256_fmsub_pd(ymm1, ymm15, ymm12); - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*5 + 4)); - - ymm13 = _mm256_fmsub_pd(ymm0, ymm15, ymm13); - ymm14 = _mm256_fmsub_pd(ymm1, ymm15, ymm14); - - ///implement TRSM/// - - /* - Compute 6x8 TRSM block by using GEMM block output in register - a. The 6x8 input (gemm outputs) are stored in combinations of ymm registers - 1. ymm3, ymm4 2. ymm5, ymm6 3. ymm7, ymm8, 4. ymm9, ymm10 - 5. ymm11, ymm12 6. ymm13,ymm14 - b. 
Towards the end TRSM output will be stored back into b11 - */ - - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - ymm6 = _mm256_fnmadd_pd(ymm1, ymm4, ymm6); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); - - ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); - ymm8 = _mm256_fnmadd_pd(ymm1, ymm4, ymm8); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); - - ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); - ymm10 = _mm256_fnmadd_pd(ymm1, ymm4, ymm10); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a)); - - ymm11 = _mm256_fnmadd_pd(ymm1, ymm3, ymm11); - ymm12 = _mm256_fnmadd_pd(ymm1, ymm4, ymm12); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a)); - - ymm13 = _mm256_fnmadd_pd(ymm1, ymm3, ymm13); - ymm14 = _mm256_fnmadd_pd(ymm1, ymm4, ymm14); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm0); - - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1)); - - ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); - ymm8 = _mm256_fnmadd_pd(ymm1, ymm6, ymm8); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1)); - - ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); - ymm10 = _mm256_fnmadd_pd(ymm1, ymm6, ymm10); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 1)); - - ymm11 = _mm256_fnmadd_pd(ymm1, ymm5, ymm11); - ymm12 = _mm256_fnmadd_pd(ymm1, ymm6, ymm12); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 1)); - - ymm13 = _mm256_fnmadd_pd(ymm1, ymm5, ymm13); - ymm14 = _mm256_fnmadd_pd(ymm1, ymm6, ymm14); - - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm0); - - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2)); - - ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); - ymm10 = _mm256_fnmadd_pd(ymm1, ymm8, ymm10); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 2)); - - ymm11 = _mm256_fnmadd_pd(ymm1, ymm7, ymm11); - ymm12 = _mm256_fnmadd_pd(ymm1, ymm8, ymm12); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 2)); - - ymm13 = _mm256_fnmadd_pd(ymm1, ymm7, ymm13); - ymm14 = _mm256_fnmadd_pd(ymm1, ymm8, ymm14); - - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm0); - - //extract a44 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); - - //(row 4):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 3)); - - ymm11 = _mm256_fnmadd_pd(ymm1, ymm9, ymm11); - ymm12 = _mm256_fnmadd_pd(ymm1, ymm10, ymm12); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 3)); - - ymm13 = _mm256_fnmadd_pd(ymm1, ymm9, ymm13); - ymm14 = _mm256_fnmadd_pd(ymm1, ymm10, ymm14); - - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); - ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm0); - - //extract a55 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); - - //(Row 5): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 4)); - - ymm13 = _mm256_fnmadd_pd(ymm1, ymm11, ymm13); - ymm14 = 
_mm256_fnmadd_pd(ymm1, ymm12, ymm14); - - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm0); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + 4), ymm4); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b + 4), ymm6); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*2 + 4), ymm8); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - _mm256_storeu_pd((double *)(b11 + cs_b*3 + 4), ymm10); - _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); - _mm256_storeu_pd((double *)(b11 + cs_b*4 + 4), ymm12); - _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); - _mm256_storeu_pd((double *)(b11 + cs_b*5 + 4), ymm14); - } - - dim_t m_remainder = m - i; - if(m_remainder >= 4) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 4)); //A01[0][4] - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 5)); //A01[0][5] - ymm13 = _mm256_fmadd_pd(ymm2, ymm0, ymm13); - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); 
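    /*
       Editor's note: the _mm256_fmsub_pd() calls in this block fold the alpha
       scaling and the GEMM correction into one step. Each ymm accumulator
       already holds the A01*B10 partial product for one column of B11, and
       FMSUB computes (a*b) - c, so each update is effectively
           B11(:,col) = AlphaVal*B11(:,col) - gemm_acc(:,col)
       An equivalent scalar sketch (illustrative only; 'row', 'col' and
       'gemm_acc' are hypothetical names, not identifiers from this file):
           for(row = 0; row < 4; row++)
               b11[row + col*cs_b] = AlphaVal*b11[row + col*cs_b]
                                     - gemm_acc[row + col*4];
    */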
- ymm11 = _mm256_fmsub_pd(ymm0, ymm15, ymm11); - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); - ymm13 = _mm256_fmsub_pd(ymm0, ymm15, ymm13); - - ///implement TRSM/// - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a )); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a )); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a )); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a )); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm3, ymm11); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm3, ymm13); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 1)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm5, ymm11); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 1)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm5, ymm13); - - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 2)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm7, ymm11); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 2)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm7, ymm13); - - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - - //extract a44 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); - - //(row 4):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 3)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm9, ymm11); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 3)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm9, ymm13); - - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); - - //extract a55 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); - - //(Row 5): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 4)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm11, ymm13); - - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); - _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); - - m_remainder -= 4; - i += 4; - } - - if(m_remainder == 3) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = 
_mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 4)); //A01[0][4] - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 5)); //A01[0][5] - ymm13 = _mm256_fmadd_pd(ymm2, ymm0, ymm13); - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); - ymm11 = _mm256_fmsub_pd(ymm0, ymm15, ymm11); - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); - ymm13 = _mm256_fmsub_pd(ymm0, ymm15, ymm13); - - ///implement TRSM/// - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm3, ymm11); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm3, ymm13); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, 
ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 1)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm5, ymm11); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 1)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm5, ymm13); - - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 2)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm7, ymm11); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 2)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm7, ymm13); - - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - - //extract a44 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); - - //(row 4):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 3)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm9, ymm11); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 3)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm9, ymm13); - - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); - - //extract a55 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); - - //(Row 5): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 4)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm11, ymm13); - - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_pd(ymm0, ymm11, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_pd(ymm0, ymm13, 0x07); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); - _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); - - m_remainder -= 3; - i += 3; - } - else if(m_remainder == 2) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 
= _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 4)); //A01[0][4] - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 5)); //A01[0][5] - ymm13 = _mm256_fmadd_pd(ymm2, ymm0, ymm13); - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); - ymm11 = _mm256_fmsub_pd(ymm0, ymm15, ymm11); - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); - ymm13 = _mm256_fmsub_pd(ymm0, ymm15, ymm13); - - ///implement TRSM/// - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm3, ymm11); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm3, ymm13); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 1)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm5, ymm11); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 1)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm5, ymm13); - - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - - //extract a33 - 
ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 2)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm7, ymm11); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 2)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm7, ymm13); - - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - - //extract a44 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); - - //(row 4):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 3)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm9, ymm11); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 3)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm9, ymm13); - - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); - - //extract a55 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); - - //(Row 5): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 4)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm11, ymm13); - - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_pd(ymm0, ymm11, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_pd(ymm0, ymm13, 0x03); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); - _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); - - m_remainder -= 2; - i += 2; - } - else if(m_remainder == 1) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); 
//A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 4)); //A01[0][4] - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 5)); //A01[0][5] - ymm13 = _mm256_fmadd_pd(ymm2, ymm0, ymm13); - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); - ymm11 = _mm256_fmsub_pd(ymm0, ymm15, ymm11); - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); - ymm13 = _mm256_fmsub_pd(ymm0, ymm15, ymm13); - - ///implement TRSM/// - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm3, ymm11); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm3, ymm13); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 1)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm5, ymm11); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 1)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm5, ymm13); - - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 2)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm7, ymm11); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 2)); - ymm13 = 
_mm256_fnmadd_pd(ymm1, ymm7, ymm13); - - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - - //extract a44 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); - - //(row 4):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 3)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm9, ymm11); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 3)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm9, ymm13); - - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); - - //extract a55 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); - - //(Row 5): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 4)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm11, ymm13); - - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_pd(ymm0, ymm11, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_pd(ymm0, ymm13, 0x01); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); - _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); - - m_remainder -= 1; - i += 1; - } - } - - dim_t n_remainder = n - j; - - /* - Reminder cases starts here: - a. Similar logic and code flow used in computing full block (6x8) - above holds for reminder cases too. 
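       b. The n-remainder paths that follow (for example the 4-column and
          3-column cases visible below) pack only as many columns of A as are
          needed and drop the unused ymm accumulators; within each path the
          m-remainder handling again merges partial results into B via masked
          blends or partial stores, as in the 6-column code above.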
- */ - - if(n_remainder >= 4) - { - a01 = L + j*cs_a; //pointer to block of A to be used in GEMM - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - - double *ptr_a10_dup = D_A_pack; - - dim_t p_lda = j; // packed leading dimension - // perform copy of A to packed buffer D_A_pack - - dim_t loop_count = j/4; - - for(dim_t x =0;x < loop_count;x++) - { - ymm15 = _mm256_loadu_pd((double const *)(a01 + cs_a * 0 + x*4)); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + x*4), ymm15); - ymm15 = _mm256_loadu_pd((double const *)(a01 + cs_a * 1 + x*4)); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + x*4), ymm15); - ymm15 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2 + x*4)); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 2 + x*4), ymm15); - ymm15 = _mm256_loadu_pd((double const *)(a01 + cs_a * 3 + x*4)); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 3 + x*4), ymm15); - } - - dim_t remainder_loop_count = p_lda - loop_count*4; - - __m128d xmm0; - if(remainder_loop_count != 0) - { - xmm0 = _mm_loadu_pd((double const *)(a01 + cs_a * 0 + loop_count*4)); - _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + loop_count*4), xmm0); - xmm0 = _mm_loadu_pd((double const *)(a01 + cs_a * 1 + loop_count*4)); - _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + loop_count*4), xmm0); - xmm0 = _mm_loadu_pd((double const *)(a01 + cs_a * 2 + loop_count*4)); - _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 2 + loop_count*4), xmm0); - xmm0 = _mm_loadu_pd((double const *)(a01 + cs_a * 3 + loop_count*4)); - _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 3 + loop_count*4), xmm0); - } - - ymm4 = _mm256_broadcast_sd((double const *)&ones); - if(!is_unitdiag) - { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_sd((double const *)(a11)); - ymm1 = _mm256_broadcast_sd((double const *)(a11+cs_a*1 + 1)); - ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a*2 + 2)); - ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a*3 + 3)); - - ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); - - ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); - #ifdef BLIS_DISABLE_TRSM_PREINVERSION - ymm4 = ymm1; - #endif - #ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm4 = _mm256_div_pd(ymm4, ymm1); - #endif - } - _mm256_storeu_pd((double *)(d11_pack), ymm4); - - for(i = 0; (i+D_MR-1) < m; i += D_MR) //loop along 'M' direction - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); - #endif - - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b10 + 4)); 
//B10[4][0] B10[5][0] B10[6][0] B10[7][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + 4)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - ymm4 = _mm256_fmsub_pd(ymm1, ymm15, ymm4); //B11[4-7][0] * alpha-= ymm1 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b + 4)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] - - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - ymm6 = _mm256_fmsub_pd(ymm1, ymm15, ymm6); //B11[4-7][1] * alpha -= ymm3 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*2 + 4)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] - - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - ymm8 = _mm256_fmsub_pd(ymm1, ymm15, ymm8); //B11[4-7][2] * alpha -= ymm5 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*3 + 4)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] - - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 - ymm10 = _mm256_fmsub_pd(ymm1, ymm15, ymm10); //B11[4-7][3] * alpha -= ymm7 - - ///implement TRSM/// - - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - ymm6 = _mm256_fnmadd_pd(ymm1, ymm4, ymm6); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); - - ymm7 = _mm256_fnmadd_pd(ymm1, 
ymm3, ymm7); - ymm8 = _mm256_fnmadd_pd(ymm1, ymm4, ymm8); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a )); - - ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); - ymm10 = _mm256_fnmadd_pd(ymm1, ymm4, ymm10); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm0); - - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1)); - - ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); - ymm8 = _mm256_fnmadd_pd(ymm1, ymm6, ymm8); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1)); - - ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); - ymm10 = _mm256_fnmadd_pd(ymm1, ymm6, ymm10); - - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm0); - - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2)); - - ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); - ymm10 = _mm256_fnmadd_pd(ymm1, ymm8, ymm10); - - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm0); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + 4), ymm4); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b + 4), ymm6); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*2 + 4), ymm8); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - _mm256_storeu_pd((double *)(b11 + cs_b*3 + 4), ymm10); - } - - dim_t m_remainder = m - i; - if(m_remainder >= 4) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + 
cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 - - ///implement TRSM/// - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); - - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); - - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - - m_remainder -= 4; - i += 4; - } - - if(m_remainder == 3) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - a01 += 1; 
//move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3 + 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 - - ///implement TRSM/// - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); - - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); - - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x07); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3 + 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x07); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - xmm5 = _mm256_extractf128_pd(ymm9, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 3),xmm5); - _mm_storel_pd((b11 + cs_b * 3 + 2), _mm256_extractf128_pd(ymm9, 1)); - - m_remainder -= 3; - i += 3; - } - else if(m_remainder == 2) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - 
ymm9 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 - - ///implement TRSM/// - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a )); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); - - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); - - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); 
//B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x03); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x03); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - xmm5 = _mm256_extractf128_pd(ymm9, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 3),xmm5); - - m_remainder -= 2; - i += 2; - } - else if(m_remainder == 1) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_broadcast_sd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 - - ///implement TRSM/// - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); - - ymm5 = 
DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); - - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); - - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - - ymm0 = _mm256_broadcast_sd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x01); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x01); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x01); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x01); - - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm3, 0)); - _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm5, 0)); - _mm_storel_pd((b11 + cs_b * 2), _mm256_extractf128_pd(ymm7, 0)); - _mm_storel_pd((b11 + cs_b * 3), _mm256_extractf128_pd(ymm9, 0)); - - m_remainder -= 1; - i += 1; - } - j += 4; - n_remainder -= 4; - } - - if(n_remainder == 3) - { - a01 = L + j*cs_a; //pointer to block of A to be used in GEMM - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - - double *ptr_a10_dup = D_A_pack; - - dim_t p_lda = j; // packed leading dimension - // perform copy of A to packed buffer D_A_pack - - dim_t loop_count = j/4; - - for(dim_t x =0;x < loop_count;x++) - { - ymm15 = _mm256_loadu_pd((double const *)(a01 + cs_a * 0 + x*4)); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + x*4), ymm15); - ymm15 = _mm256_loadu_pd((double const *)(a01 + cs_a * 1 + x*4)); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + x*4), ymm15); - ymm15 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2 + x*4)); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 2 + x*4), ymm15); - } - - dim_t remainder_loop_count = p_lda - loop_count*4; - - __m128d xmm0; - if(remainder_loop_count != 0) - { - xmm0 = _mm_loadu_pd((double const *)(a01 + cs_a * 0 + loop_count*4)); - _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + loop_count*4), xmm0); - xmm0 = _mm_loadu_pd((double const *)(a01 + cs_a * 1 + loop_count*4)); - _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + loop_count*4), xmm0); - xmm0 = _mm_loadu_pd((double const *)(a01 + cs_a * 2 + loop_count*4)); - _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 2 + loop_count*4), xmm0); - } - - ymm4 = _mm256_broadcast_sd((double const *)&ones); - if(!is_unitdiag) - { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_sd((double const *)(a11)); - ymm1 = _mm256_broadcast_sd((double const *)(a11+cs_a*1 + 1)); - ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a*2 + 2)); - ymm3 = _mm256_broadcast_sd((double const *)&ones); - - ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); - - ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); - #ifdef BLIS_DISABLE_TRSM_PREINVERSION - ymm4 = ymm1; - #endif - #ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm4 = _mm256_div_pd(ymm4, ymm1); - #endif - } - 
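    /*
       Editor's note: for the non-unit-diagonal case, ymm4 now holds
       {a00, a11, a22, 1.0} for the 3x3 A11 block (the fourth lane is padded
       from 'ones' because only three diagonal entries exist here); for a unit
       diagonal it stays all ones. With BLIS_ENABLE_TRSM_PREINVERSION the
       lanes are already reciprocals (see the _mm256_div_pd above), so the
       DTRSM_SMALL_DIV_OR_SCALE() used later can presumably multiply by the
       packed value, while with preinversion disabled the raw diagonal is
       stored and a divide is expected instead.
    */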
_mm256_storeu_pd((double *)(d11_pack), ymm4); - - for(i = 0; (i+D_MR-1) < m; i += D_MR) //loop along 'M' direction - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + 4), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + 4 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + 4 + cs_b*2), _MM_HINT_T0); - #endif - - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b10 + 4)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + 4)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - ymm4 = _mm256_fmsub_pd(ymm1, ymm15, ymm4); //B11[4-7][0] * alpha-= ymm1 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b + 4)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] - - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - ymm6 = _mm256_fmsub_pd(ymm1, ymm15, ymm6); //B11[4-7][1] * alpha -= ymm3 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*2 + 4)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] - - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - ymm8 = _mm256_fmsub_pd(ymm1, ymm15, ymm8); //B11[4-7][2] * alpha -= ymm5 - - ///implement TRSM/// - - //extract a00 - 
ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - ymm6 = _mm256_fnmadd_pd(ymm1, ymm4, ymm6); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); - - ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); - ymm8 = _mm256_fnmadd_pd(ymm1, ymm4, ymm8); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm0); - - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1)); - - ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); - ymm8 = _mm256_fnmadd_pd(ymm1, ymm6, ymm8); - - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm0); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + 4), ymm4); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b + 4), ymm6); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*2 + 4), ymm8); - } - - dim_t m_remainder = m - i; - if(m_remainder >= 4) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - - ///implement TRSM/// - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - 
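/*
 * A minimal scalar sketch, with assumed/illustrative names (not taken from this
 * kernel), of the three-column right-side substitution that the ymm3/ymm5/ymm7
 * sequence in this n_remainder == 3 section performs on one strip of B.
 * A is column-major with leading dimension cs_a; the vector code reads the
 * diagonal from d11_pack, which holds either the diagonal elements or their
 * reciprocals depending on preinversion, and skips the divide for unit-diagonal A.
 */
static void dtrsm_small_xa_3col_sketch
     (
       double const *a11, dim_t cs_a,
       double       *b11, dim_t cs_b,
       dim_t         m_strip
     )
{
    for (dim_t r = 0; r < m_strip; r++)
    {
        /* column 0 is only scaled; columns 1 and 2 subtract the already
           solved columns weighted by the corresponding a11 entries */
        double x0 =  b11[r + 0*cs_b]                         / a11[0*cs_a + 0];
        double x1 = (b11[r + 1*cs_b] - a11[1*cs_a + 0]*x0)   / a11[1*cs_a + 1];
        double x2 = (b11[r + 2*cs_b] - a11[2*cs_a + 0]*x0
                                     - a11[2*cs_a + 1]*x1)   / a11[2*cs_a + 2];
        b11[r + 0*cs_b] = x0;
        b11[r + 1*cs_b] = x1;
        b11[r + 2*cs_b] = x2;
    }
}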
//(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); - - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - - m_remainder -= 4; - i += 4; - } - - if(m_remainder == 3) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2)); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2 + 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - - ///implement TRSM/// - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); - - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - 
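/*
 * A minimal sketch, assuming immintrin.h and illustrative helper names, of the
 * partial load/store idiom used in this m_remainder == 3 path: three doubles of
 * a B column are assembled into a __m256d (two via a 128-bit load, the third
 * broadcast into the remaining lanes) and written back as 2 + 1 elements so the
 * fourth element of the column is never touched.
 */
static inline __m256d dtrsm_small_load3_pd_sketch(double const *p)
{
    __m128d lo = _mm_loadu_pd(p);             /* p[0], p[1]             */
    __m256d v  = _mm256_broadcast_sd(p + 2);  /* p[2] in all four lanes */
    return _mm256_insertf128_pd(v, lo, 0);    /* p[0] p[1] p[2] p[2]    */
}

static inline void dtrsm_small_store3_pd_sketch(double *p, __m256d v)
{
    _mm_storeu_pd(p, _mm256_extractf128_pd(v, 0));      /* p[0], p[1] */
    _mm_storel_pd(p + 2, _mm256_extractf128_pd(v, 1));  /* p[2]       */
}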
- ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x07); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2)); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2 + 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x07); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - xmm5 = _mm256_extractf128_pd(ymm7, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 2),xmm5); - _mm_storel_pd((b11 + cs_b * 2 + 2), _mm256_extractf128_pd(ymm7, 1)); - - m_remainder -= 3; - i += 3; - } - else if(m_remainder == 2) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - - ///implement TRSM/// - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); - - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - - ymm0 = 
_mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x03); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x03); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - xmm5 = _mm256_extractf128_pd(ymm7, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 2),xmm5); - - m_remainder -= 2; - i += 2; - } - else if(m_remainder == 1) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_broadcast_sd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - - ///implement TRSM/// - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); - - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - - ymm0 = _mm256_broadcast_sd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x01); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b)); 
//B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x01); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x01); - - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm3, 0)); - _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm5, 0)); - _mm_storel_pd((b11 + cs_b * 2), _mm256_extractf128_pd(ymm7, 0)); - - m_remainder -= 1; - i += 1; - } - j += 3; - n_remainder -= 3; - } - else if(n_remainder == 2) - { - a01 = L + j*cs_a; //pointer to block of A to be used in GEMM - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - - double *ptr_a10_dup = D_A_pack; - - dim_t p_lda = j; // packed leading dimension - // perform copy of A to packed buffer D_A_pack - - dim_t loop_count = j/4; - - for(dim_t x =0;x < loop_count;x++) - { - ymm15 = _mm256_loadu_pd((double const *)(a01 + cs_a * 0 + x*4)); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + x*4), ymm15); - ymm15 = _mm256_loadu_pd((double const *)(a01 + cs_a * 1 + x*4)); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + x*4), ymm15); - } - - dim_t remainder_loop_count = p_lda - loop_count*4; - - __m128d xmm0; - if(remainder_loop_count != 0) - { - xmm0 = _mm_loadu_pd((double const *)(a01 + cs_a * 0 + loop_count*4)); - _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + loop_count*4), xmm0); - xmm0 = _mm_loadu_pd((double const *)(a01 + cs_a * 1 + loop_count*4)); - _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + loop_count*4), xmm0); - } - - ymm4 = _mm256_broadcast_sd((double const *)&ones); - if(!is_unitdiag) - { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_sd((double const *)(a11)); - ymm1 = _mm256_broadcast_sd((double const *)(a11+cs_a*1 + 1)); - ymm2 = _mm256_broadcast_sd((double const *)&ones); - ymm3 = _mm256_broadcast_sd((double const *)&ones); - - ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); - - ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); - #ifdef BLIS_DISABLE_TRSM_PREINVERSION - ymm4 = ymm1; - #endif - #ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm4 = _mm256_div_pd(ymm4, ymm1); - #endif - } - _mm256_storeu_pd((double *)(d11_pack), ymm4); - - for(i = 0; (i+D_MR-1) < m; i += D_MR) //loop along 'M' direction - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); - #endif - - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b10 + 4)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += 
(B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + 4)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - ymm4 = _mm256_fmsub_pd(ymm1, ymm15, ymm4); //B11[4-7][0] * alpha-= ymm1 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b + 4)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] - - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - ymm6 = _mm256_fmsub_pd(ymm1, ymm15, ymm6); //B11[4-7][1] * alpha -= ymm3 - - ///implement TRSM/// - - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - ymm6 = _mm256_fnmadd_pd(ymm1, ymm4, ymm6); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm0); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + 4), ymm4); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b + 4), ymm6); - } - - dim_t m_remainder = m - i; - if(m_remainder >= 4) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_loadu_pd((double 
const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ///implement TRSM/// - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - - m_remainder -= 4; - i += 4; - } - - if(m_remainder == 3) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1)); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*1 + 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ///implement TRSM/// - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x07); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1)); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*1 + 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x07); - - _mm256_storeu_pd((double *)b11, ymm3); - xmm5 = _mm256_extractf128_pd(ymm5, 0); - _mm_storeu_pd((double *)(b11 + cs_b*1), xmm5); - _mm_storel_pd((b11 + cs_b * 1 + 2), _mm256_extractf128_pd(ymm5, 1)); - - m_remainder -= 3; - i += 3; - } - else if(m_remainder == 2) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of 
GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ///implement TRSM/// - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x03); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x03); - - _mm256_storeu_pd((double *)b11, ymm3); - xmm5 = _mm256_extractf128_pd(ymm5, 0); - _mm_storeu_pd((double *)(b11 + cs_b*1), xmm5); - - m_remainder -= 2; - i += 2; - } - else if(m_remainder == 1) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_broadcast_sd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_broadcast_sd((double const *)(b11 + 
cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ///implement TRSM/// - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - ymm0 = _mm256_broadcast_sd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x01); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x01); - - _mm_storel_pd(b11 , _mm256_extractf128_pd(ymm3, 0)); - _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm5, 0)); - - m_remainder -= 1; - i += 1; - } - j += 2; - n_remainder -= 2; - } - else if(n_remainder == 1) - { - a01 = L + j*cs_a; //pointer to block of A to be used in GEMM - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - - double *ptr_a10_dup = D_A_pack; - - dim_t p_lda = j; // packed leading dimension - // perform copy of A to packed buffer D_A_pack - - dim_t loop_count = j/4; - - for(dim_t x =0;x < loop_count;x++) - { - ymm15 = _mm256_loadu_pd((double const *)(a01 + cs_a * 0 + x*4)); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + x*4), ymm15); - } - - dim_t remainder_loop_count = p_lda - loop_count*4; - - __m128d xmm0; - if(remainder_loop_count != 0) - { - xmm0 = _mm_loadu_pd((double const *)(a01 + cs_a * 0 + loop_count*4)); - _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + loop_count*4), xmm0); - } - - ymm4 = _mm256_broadcast_sd((double const *)&ones); - if(!is_unitdiag) - { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_sd((double const *)(a11)); - ymm1 = _mm256_broadcast_sd((double const *)&ones); - ymm2 = _mm256_broadcast_sd((double const *)&ones); - ymm3 = _mm256_broadcast_sd((double const *)&ones); - - ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); - - ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); - #ifdef BLIS_DISABLE_TRSM_PREINVERSION - ymm4 = ymm1; - #endif - #ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm4 = _mm256_div_pd(ymm4, ymm1); - #endif - } - _mm256_storeu_pd((double *)(d11_pack), ymm4); - - for(i = 0; (i+D_MR-1) < m; i += D_MR) //loop along 'M' direction - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + 4), _MM_HINT_T0); - #endif - - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b10 + 4)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm4 = 
_mm256_fmadd_pd(ymm2, ymm1, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + 4)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - ymm4 = _mm256_fmsub_pd(ymm1, ymm15, ymm4); //B11[4-7][0] * alpha-= ymm1 - - ///implement TRSM/// - - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + 4), ymm4); - } - - dim_t m_remainder = m - i; - if(m_remainder >= 4) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ///implement TRSM/// - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - _mm256_storeu_pd((double *)b11, ymm3); - - m_remainder -= 4; - i += 4; - } - - if(m_remainder == 3) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - xmm5 = _mm_loadu_pd((double const*)(b11)); - ymm0 = _mm256_broadcast_sd((double const *)(b11+ 2)); - ymm6 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm3 = _mm256_fmsub_pd(ymm6, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ///implement TRSM/// - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = 
DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - ymm3 = _mm256_blend_pd(ymm6, ymm3, 0x07); - - xmm5 = _mm256_extractf128_pd(ymm3, 0); - _mm_storeu_pd((double *)(b11), xmm5); - _mm_storel_pd((b11 + 2), _mm256_extractf128_pd(ymm3, 1)); - - m_remainder -= 3; - i += 3; - } - else if(m_remainder == 2) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - xmm5 = _mm_loadu_pd((double const*)(b11)); - ymm6 = _mm256_insertf128_pd(ymm0, xmm5, 0); - - ymm3 = _mm256_fmsub_pd(ymm6, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ///implement TRSM/// - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - ymm3 = _mm256_blend_pd(ymm6, ymm3, 0x03); - - xmm5 = _mm256_extractf128_pd(ymm3, 0); - _mm_storeu_pd((double *)(b11), xmm5); - - m_remainder -= 2; - i += 2; - } - else if(m_remainder == 1) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm6 = _mm256_broadcast_sd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm6, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ///implement TRSM/// - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - ymm3 = _mm256_blend_pd(ymm6, ymm3, 0x01); - - _mm_storel_pd(b11, _mm256_extractf128_pd(ymm3, 0)); - - m_remainder -= 1; - i += 1; - } - j += 1; - n_remainder -= 1; - } - - if ((required_packing_A == 1) && bli_mem_is_alloc( &local_mem_buf_A_s )) - { - bli_membrk_release(&rntm, - &local_mem_buf_A_s); - } - return BLIS_SUCCESS; -} - -/*implements TRSM for the case XA = alpha * B - *A is lower triangular, non-unit diagonal/unit diagonal, transpose - *dimensions: X:mxn A:nxn B: mxn - * - * b11---> a01 ----> - ***************** *********** - *b01*b11* * * * * * * -b11 * * * * * 
**a01 * * a11 - | ***************** ********* | - | * * * * * *a11* * | - | * * * * * * * * | - v ***************** ****** v - * * * * * * * - * * * * * * * - ***************** * * - * - -*/ - -BLIS_INLINE err_t bli_dtrsm_small_XAltB -( - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl -) -{ - dim_t m = bli_obj_length(b); //number of rows - dim_t n = bli_obj_width(b); //number of columns - - dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A - dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B - - dim_t i, j, k; //loop variablse - dim_t k_iter; //determines the number of GEMM operations to be done - - double ones = 1.0; - double zero = 0.0; - bool is_unitdiag = bli_obj_has_unit_diag(a); - - double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha - double* restrict L = a->buffer; //pointer to matrix A - double* restrict B = b->buffer; //pointer to matrix B - - double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks - - gint_t required_packing_A = 1; - mem_t local_mem_buf_A_s = {0}; - double *D_A_pack = NULL; - double d11_pack[D_MR] __attribute__((aligned(64))); - rntm_t rntm; - - bli_rntm_init_from_global( &rntm ); - bli_rntm_set_num_threads_only( 1, &rntm ); - bli_membrk_rntm_set_membrk( &rntm ); - - siz_t buffer_size = bli_pool_block_size( - bli_membrk_pool( - bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), - bli_rntm_membrk(&rntm))); - - if( (D_NR * n * sizeof(double)) > buffer_size) - return BLIS_NOT_YET_IMPLEMENTED; - - if (required_packing_A == 1) - { - // Get the buffer from the pool. - bli_membrk_acquire_m(&rntm, - buffer_size, - BLIS_BITVAL_BUFFER_FOR_A_BLOCK, - &local_mem_buf_A_s); - if(FALSE==bli_mem_is_alloc(&local_mem_buf_A_s)) return BLIS_NULL_POINTER; - D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); - if(NULL==D_A_pack) return BLIS_NULL_POINTER; - } - - //ymm scratch reginsters - __m256d ymm0, ymm1, ymm2, ymm3; - __m256d ymm4, ymm5, ymm6, ymm7; - __m256d ymm8, ymm9, ymm10, ymm11; - __m256d ymm12, ymm13, ymm14, ymm15; - - __m128d xmm5; - - /* - Performs solving TRSM for 6 rows at a time from 0 to n/6 in steps of D_NR - a. Load and pack A (a01 block), the size of packing 6x6 to 6x (n-6) - First there will be no GEMM and no packing of a01 because it is only TRSM - b. Using packed a01 block and b10 block perform GEMM operation - c. Use GEMM outputs, perform TRSM operation using a11, b11 and update B - d. Repeat b for m cols of B in steps of D_MR - */ - - for(j = 0; (j+D_NR-1) < n; j += D_NR) //loop along 'N' direction - { - a01 = L + j; //pointer to block of A to be used in GEMM - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - - dim_t p_lda = j; // packed leading dimension - // perform copy of A to packed buffer D_A_pack - - /* - Pack current A block (a01) into packed buffer memory D_A_pack - a. This a10 block is used in GEMM portion only and this - a01 block size will be increasing by D_NR for every next iteration - until it reaches 6x(n-6) which is the maximum GEMM alone block size in A - b. This packed buffer is reused to calculate all m cols of B matrix - */ - bli_dtrsm_small_pack('R', j, 1, a01, cs_a, D_A_pack, p_lda); - - /* - Pack 6 diagonal elements of A block into an array - a. This helps in utilze cache line efficiently in TRSM operation - b. store ones when input is unit diagonal - */ - - dtrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,D_NR); - - /* - a. Perform GEMM using a01, b10. - b. Perform TRSM on a11, b11 - c. 
This loop GEMM+TRSM loops operates with 8x6 block size - along m dimension for every D_MR columns of B10 where - packed A buffer is reused in computing all m cols of B. - d. Same approach is used in remaining fringe cases. - */ - for(i = 0; (i+D_MR-1) < m; i += D_MR) //loop along 'M' direction - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); - #endif - - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS - - /* - Peform GEMM between a01 and b10 blocks - For first itteration there will be no GEMM operation - where k_iter are zero - */ - BLIS_DTRSM_SMALL_GEMM_6x8(a01,b10,cs_b,p_lda,k_iter) - - /* - Load b11 of size 8x6 and multiply with alpha - Add the GEMM output to b11 - and peform TRSM operation. - */ - - ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + 4)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - ymm4 = _mm256_fmsub_pd(ymm1, ymm15, ymm4); //B11[4-7][0] * alpha -= ymm1 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b + 4)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] - - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha -= ymm2 - ymm6 = _mm256_fmsub_pd(ymm1, ymm15, ymm6); //B11[4-7][1] * alpha -= ymm3 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*2 + 4)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] - - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - ymm8 = _mm256_fmsub_pd(ymm1, ymm15, ymm8); //B11[4-7][2] * alpha -= ymm5 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*3 + 4)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] - - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 - ymm10 = _mm256_fmsub_pd(ymm1, ymm15, ymm10); //B11[4-7][3] * alpha -= ymm7 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*4 + 4)); - - ymm11 = _mm256_fmsub_pd(ymm0, ymm15, ymm11); - ymm12 = _mm256_fmsub_pd(ymm1, ymm15, ymm12); - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*5 + 4)); - - ymm13 = _mm256_fmsub_pd(ymm0, ymm15, ymm13); - ymm14 = _mm256_fmsub_pd(ymm1, ymm15, ymm14); - - ///implement TRSM/// - - /* - Compute 6x8 TRSM block by using GEMM block output in register - a. The 6x8 input (gemm outputs) are stored in combinations of ymm registers - 1. ymm3, ymm4 2. ymm5, ymm6 3. ymm7, ymm8, 4. ymm9, ymm10 - 5. ymm11, ymm12 6. ymm13,ymm14 - b. 
Towards the end TRSM output will be stored back into b11 - */ - - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); - - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - ymm6 = _mm256_fnmadd_pd(ymm1, ymm4, ymm6); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); - - ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); - ymm8 = _mm256_fnmadd_pd(ymm1, ymm4, ymm8); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); - - ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); - ymm10 = _mm256_fnmadd_pd(ymm1, ymm4, ymm10); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4)); - - ymm11 = _mm256_fnmadd_pd(ymm1, ymm3, ymm11); - ymm12 = _mm256_fnmadd_pd(ymm1, ymm4, ymm12); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5)); - - ymm13 = _mm256_fnmadd_pd(ymm1, ymm3, ymm13); - ymm14 = _mm256_fnmadd_pd(ymm1, ymm4, ymm14); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm0); - - a11 += cs_a; - - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); - - ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); - ymm8 = _mm256_fnmadd_pd(ymm1, ymm6, ymm8); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); - - ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); - ymm10 = _mm256_fnmadd_pd(ymm1, ymm6, ymm10); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4)); - - ymm11 = _mm256_fnmadd_pd(ymm1, ymm5, ymm11); - ymm12 = _mm256_fnmadd_pd(ymm1, ymm6, ymm12); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5)); - - ymm13 = _mm256_fnmadd_pd(ymm1, ymm5, ymm13); - ymm14 = _mm256_fnmadd_pd(ymm1, ymm6, ymm14); - - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm0); - - a11 += cs_a; - - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); - - ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); - ymm10 = _mm256_fnmadd_pd(ymm1, ymm8, ymm10); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4)); - - ymm11 = _mm256_fnmadd_pd(ymm1, ymm7, ymm11); - ymm12 = _mm256_fnmadd_pd(ymm1, ymm8, ymm12); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5)); - - ymm13 = _mm256_fnmadd_pd(ymm1, ymm7, ymm13); - ymm14 = _mm256_fnmadd_pd(ymm1, ymm8, ymm14); - - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm0); - - a11 += cs_a; - - //extract a44 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); - - //(row 4):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4)); - - ymm11 = _mm256_fnmadd_pd(ymm1, ymm9, ymm11); - ymm12 = _mm256_fnmadd_pd(ymm1, ymm10, ymm12); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5)); - - ymm13 = _mm256_fnmadd_pd(ymm1, ymm9, ymm13); - ymm14 = _mm256_fnmadd_pd(ymm1, ymm10, ymm14); - - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); - ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm0); - - a11 += cs_a; - - //extract a55 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); - - //(Row 5): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5)); - - ymm13 = _mm256_fnmadd_pd(ymm1, ymm11, ymm13); - ymm14 = _mm256_fnmadd_pd(ymm1, ymm12, ymm14); - - ymm13 = 
DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm0); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + 4), ymm4); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b + 4), ymm6); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*2 + 4), ymm8); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - _mm256_storeu_pd((double *)(b11 + cs_b*3 + 4), ymm10); - _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); - _mm256_storeu_pd((double *)(b11 + cs_b*4 + 4), ymm12); - _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); - _mm256_storeu_pd((double *)(b11 + cs_b*5 + 4), ymm14); - } - - dim_t m_remainder = m - i; - if(m_remainder >= 4) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 4)); //A01[0][4] - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 5)); //A01[0][5] - ymm13 = _mm256_fmadd_pd(ymm2, ymm0, ymm13); - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); - ymm11 = _mm256_fmsub_pd(ymm0, ymm15, ymm11); - 
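/*
 * A minimal scalar sketch (assumed, illustrative name) of what the GEMM loop
 * above combined with these fmsub lines computes for this 4x6 strip before the
 * substitution that follows:  b11 := alpha*b11 - b10 * a01_pack.
 */
static void dtrsm_small_gemm_alpha_4x6_sketch
     (
       double const *a01_pack, dim_t p_lda,
       double const *b10, double *b11, dim_t cs_b,
       dim_t k_iter, double alpha
     )
{
    for (dim_t col = 0; col < 6; col++)
    {
        for (dim_t row = 0; row < 4; row++)
        {
            /* accumulate the rank-k_iter update, then fold in alpha */
            double acc = 0.0;
            for (dim_t k = 0; k < k_iter; k++)
                acc += b10[row + k*cs_b] * a01_pack[k + col*p_lda];
            b11[row + col*cs_b] = alpha * b11[row + col*cs_b] - acc;
        }
    }
}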
- ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); - ymm13 = _mm256_fmsub_pd(ymm0, ymm15, ymm13); - - ///implement TRSM/// - - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm3, ymm11); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm3, ymm13); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - a11 += cs_a; - - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm5, ymm11); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm5, ymm13); - - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - - a11 += cs_a; - - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm7, ymm11); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm7, ymm13); - - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - - a11 += cs_a; - - //extract a44 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); - - //(row 4):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm9, ymm11); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm9, ymm13); - - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); - - a11 += cs_a; - - //extract a55 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); - - //(Row 5): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm11, ymm13); - - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); - _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); - - m_remainder -= 4; - i += 4; - } - - if(m_remainder == 3) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); - - ///GEMM implementation 
starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 4)); //A01[0][4] - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 5)); //A01[0][5] - ymm13 = _mm256_fmadd_pd(ymm2, ymm0, ymm13); - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); - ymm11 = _mm256_fmsub_pd(ymm0, ymm15, ymm11); - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); - ymm13 = _mm256_fmsub_pd(ymm0, ymm15, ymm13); - - ///implement TRSM/// - - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm3, ymm11); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm3, ymm13); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - a11 += cs_a; - - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); - - ymm1 = 
_mm256_broadcast_sd((double const *)(a11 + 4)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm5, ymm11); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm5, ymm13); - - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - - a11 += cs_a; - - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm7, ymm11); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm7, ymm13); - - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - - a11 += cs_a; - - //extract a44 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); - - //(row 4):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm9, ymm11); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm9, ymm13); - - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); - - a11 += cs_a; - - //extract a55 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); - - //(Row 5): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm11, ymm13); - - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_pd(ymm0, ymm11, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_pd(ymm0, ymm13, 0x07); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); - _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); - - m_remainder -= 3; - i += 3; - } - else if(m_remainder == 2) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = 
_mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 4)); //A01[0][4] - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 5)); //A01[0][5] - ymm13 = _mm256_fmadd_pd(ymm2, ymm0, ymm13); - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); - ymm11 = _mm256_fmsub_pd(ymm0, ymm15, ymm11); - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); - ymm13 = _mm256_fmsub_pd(ymm0, ymm15, ymm13); - - ///implement TRSM/// - - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm3, ymm11); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm3, ymm13); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - a11 += cs_a; - - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm5, ymm11); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm5, ymm13); - - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - - a11 += cs_a; - - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm9 = 
_mm256_fnmadd_pd(ymm1, ymm7, ymm9); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm7, ymm11); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm7, ymm13); - - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - - a11 += cs_a; - - //extract a44 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); - - //(row 4):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm9, ymm11); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm9, ymm13); - - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); - - a11 += cs_a; - - //extract a55 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); - - //(Row 5): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm11, ymm13); - - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_pd(ymm0, ymm11, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_pd(ymm0, ymm13, 0x03); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); - _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); - - m_remainder -= 2; - i += 2; - } - else if(m_remainder == 1) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - ymm2 = 
_mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 4)); //A01[0][4] - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 5)); //A01[0][5] - ymm13 = _mm256_fmadd_pd(ymm2, ymm0, ymm13); - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); - ymm11 = _mm256_fmsub_pd(ymm0, ymm15, ymm11); - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); - ymm13 = _mm256_fmsub_pd(ymm0, ymm15, ymm13); - - ///implement TRSM/// - - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm3, ymm11); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm3, ymm13); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - a11 += cs_a; - - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm5, ymm11); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm5, ymm13); - - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - - a11 += cs_a; - - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm7, ymm11); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm7, ymm13); - - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - - a11 += cs_a; - - //extract a44 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); - - //(row 4):FMA operations - ymm1 = 
_mm256_broadcast_sd((double const *)(a11 + 4)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm9, ymm11); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm9, ymm13); - - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); - - a11 += cs_a; - - //extract a55 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); - - //(Row 5): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm11, ymm13); - - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_pd(ymm0, ymm11, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_pd(ymm0, ymm13, 0x01); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); - _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); - - m_remainder -= 1; - i += 1; - } - } - - dim_t n_remainder = n - j; - - /* - Reminder cases starts here: - a. Similar logic and code flow used in computing full block (6x8) - above holds for reminder cases too. 
- */ - - if(n_remainder >= 4) - { - a01 = L + j; //pointer to block of A to be used in GEMM - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - - double *ptr_a10_dup = D_A_pack; - - dim_t p_lda = j; // packed leading dimension - // perform copy of A to packed buffer D_A_pack - - for(dim_t x =0;x < p_lda;x+=D_NR) - { - ymm0 = _mm256_loadu_pd((double const *)(a01)); - ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); - ymm2 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2)); - ymm3 = _mm256_loadu_pd((double const *)(a01 + cs_a * 3)); - - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); - - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); - - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - - _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); - - ymm0 = _mm256_loadu_pd((double const *)(a01 + cs_a * 4)); - ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a * 5)); - - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_broadcast_sd((double const *)&zero); - - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_broadcast_sd((double const *)&zero); - - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - - _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_extractf128_pd(ymm6,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_extractf128_pd(ymm7,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm8,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm9,0)); - - a01 += D_NR*cs_a; - ptr_a10_dup += D_NR; - } - - ymm4 = _mm256_broadcast_sd((double const *)&ones); - if(!is_unitdiag) - { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_sd((double const *)(a11)); - ymm1 = _mm256_broadcast_sd((double const *)(a11+cs_a*1 + 1)); - ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a*2 + 2)); - ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a*3 + 3)); - - ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); - - ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); - #ifdef BLIS_DISABLE_TRSM_PREINVERSION - ymm4 = ymm1; - #endif - #ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm4 = _mm256_div_pd(ymm4, ymm1); - #endif - } - _mm256_storeu_pd((double *)(d11_pack), ymm4); - - for(i = 0; (i+D_MR-1) < m; i += D_MR) //loop along 'M' direction - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); - #endif - - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = 
_mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b10 + 4)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + 4)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - ymm4 = _mm256_fmsub_pd(ymm1, ymm15, ymm4); //B11[4-7][0] * alpha-= ymm1 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b + 4)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] - - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - ymm6 = _mm256_fmsub_pd(ymm1, ymm15, ymm6); //B11[4-7][1] * alpha -= ymm3 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*2 + 4)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] - - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - ymm8 = _mm256_fmsub_pd(ymm1, ymm15, ymm8); //B11[4-7][2] * alpha -= ymm5 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*3 + 4)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] - - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 - ymm10 = _mm256_fmsub_pd(ymm1, ymm15, ymm10); //B11[4-7][3] * alpha -= ymm7 - - ///implement TRSM/// - - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - - 
ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); - - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - ymm6 = _mm256_fnmadd_pd(ymm1, ymm4, ymm6); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); - - ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); - ymm8 = _mm256_fnmadd_pd(ymm1, ymm4, ymm8); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); - - ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); - ymm10 = _mm256_fnmadd_pd(ymm1, ymm4, ymm10); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm0); - - a11 += cs_a; - - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); - - ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); - ymm8 = _mm256_fnmadd_pd(ymm1, ymm6, ymm8); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); - - ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); - ymm10 = _mm256_fnmadd_pd(ymm1, ymm6, ymm10); - - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm0); - - a11 += cs_a; - - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); - - ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); - ymm10 = _mm256_fnmadd_pd(ymm1, ymm8, ymm10); - - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm0); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + 4), ymm4); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b + 4), ymm6); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*2 + 4), ymm8); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - _mm256_storeu_pd((double *)(b11 + cs_b*3 + 4), ymm10); - } - - dim_t m_remainder = m - i; - if(m_remainder >= 4) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); 
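/*
 * Minimal scalar sketch of what the broadcast+FMA statements above accumulate
 * for this 4x4 tile (the helper name gemm_ref_4x4 and the acc buffer are
 * illustrative assumptions, not part of this kernel; dim_t is the BLIS index
 * type already used throughout this file): packed A01 is read with row stride
 * 1 and column stride p_lda, B10 with column stride cs_b, and the accumulated
 * product is later subtracted from alpha*B11 (the fmsub step) before the
 * triangular solve.
 */
static void gemm_ref_4x4
     (
       const double* a01_pack, /* packed A01: element (k,j) at k + j*p_lda */
       const double* b10,      /* B10 panel:  element (i,k) at i + k*cs_b  */
       double*       acc,      /* 4x4 accumulator: element (i,j) at i + 4*j */
       dim_t         k_iter,
       dim_t         p_lda,
       dim_t         cs_b
     )
{
    for ( dim_t k = 0; k < k_iter; k++ )
        for ( dim_t j = 0; j < 4; j++ )      /* tile column (one ymm accumulator each) */
            for ( dim_t i = 0; i < 4; i++ )  /* row lane within that ymm register      */
                acc[ i + 4*j ] += b10[ i + k*cs_b ] * a01_pack[ k + j*p_lda ];
}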
//ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 - - ///implement TRSM/// - - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - a11 += cs_a; - - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); - - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - - a11 += cs_a; - - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); - - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - - m_remainder -= 4; - i += 4; - } - - if(m_remainder == 3) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - 
ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3 + 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 - - ///implement TRSM/// - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - a11 += cs_a; - - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); - - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - - a11 += cs_a; - - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); - - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x07); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3 + 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x07); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - xmm5 = _mm256_extractf128_pd(ymm9, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 3),xmm5); - _mm_storel_pd((b11 + cs_b * 3 + 2), _mm256_extractf128_pd(ymm9, 1)); - - 
m_remainder -= 3; - i += 3; - } - else if(m_remainder == 2) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 - - ///implement TRSM/// - - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - a11 += cs_a; - - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); - - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - - a11 += cs_a; - - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //(Row 3): FMA operations - ymm1 = 
_mm256_broadcast_sd((double const *)(a11 + 3)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); - - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x03); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x03); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - xmm5 = _mm256_extractf128_pd(ymm9, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 3),xmm5); - - m_remainder -= 2; - i += 2; - } - else if(m_remainder == 1) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - ymm0 = _mm256_broadcast_sd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 - - ///implement TRSM/// - - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const 
*)(d11_pack + 1)); - - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - a11 += cs_a; - - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); - - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - - a11 += cs_a; - - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); - - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - - ymm0 = _mm256_broadcast_sd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x01); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x01); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x01); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x01); - - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm3, 0)); - _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm5, 0)); - _mm_storel_pd((b11 + cs_b * 2), _mm256_extractf128_pd(ymm7, 0)); - _mm_storel_pd((b11 + cs_b * 3), _mm256_extractf128_pd(ymm9, 0)); - - m_remainder -= 1; - i += 1; - } - j += 4; - n_remainder -= 4; - } - - if(n_remainder == 3) - { - a01 = L + j; //pointer to block of A to be used in GEMM - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - - double *ptr_a10_dup = D_A_pack; - - dim_t p_lda = j; // packed leading dimension - // perform copy of A to packed buffer D_A_pack - - for(dim_t x =0;x < p_lda;x+=D_NR) - { - ymm0 = _mm256_loadu_pd((double const *)(a01)); - ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); - ymm2 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2)); - ymm3 = _mm256_loadu_pd((double const *)(a01 + cs_a * 3)); - - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); - - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); - - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - - _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); - - ymm0 = _mm256_loadu_pd((double const *)(a01 + cs_a * 4)); - ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a * 5)); - - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_broadcast_sd((double const *)&zero); - - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_broadcast_sd((double const *)&zero); - - ymm7 = 
_mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - - _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_extractf128_pd(ymm6,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_extractf128_pd(ymm7,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm8,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm9,0)); - - a01 += D_NR*cs_a; - ptr_a10_dup += D_NR; - } - - ymm4 = _mm256_broadcast_sd((double const *)&ones); - if(!is_unitdiag) - { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_sd((double const *)(a11)); - ymm1 = _mm256_broadcast_sd((double const *)(a11+cs_a*1 + 1)); - ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a*2 + 2)); - ymm3 = _mm256_broadcast_sd((double const *)&ones); - - ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); - - ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); - #ifdef BLIS_DISABLE_TRSM_PREINVERSION - ymm4 = ymm1; - #endif - #ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm4 = _mm256_div_pd(ymm4, ymm1); - #endif - } - _mm256_storeu_pd((double *)(d11_pack), ymm4); - - for(i = 0; (i+D_MR-1) < m; i += D_MR) //loop along 'M' direction - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); - #endif - - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b10 + 4)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const 
*)&AlphaVal); - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + 4)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - ymm4 = _mm256_fmsub_pd(ymm1, ymm15, ymm4); //B11[4-7][0] * alpha-= ymm1 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b + 4)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] - - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - ymm6 = _mm256_fmsub_pd(ymm1, ymm15, ymm6); //B11[4-7][1] * alpha -= ymm3 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*2 + 4)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] - - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - ymm8 = _mm256_fmsub_pd(ymm1, ymm15, ymm8); //B11[4-7][2] * alpha -= ymm5 - - ///implement TRSM/// - - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); - - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - ymm6 = _mm256_fnmadd_pd(ymm1, ymm4, ymm6); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); - - ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); - ymm8 = _mm256_fnmadd_pd(ymm1, ymm4, ymm8); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm0); - - a11 += cs_a; - - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); - - ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); - ymm8 = _mm256_fnmadd_pd(ymm1, ymm6, ymm8); - - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm0); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + 4), ymm4); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b + 4), ymm6); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*2 + 4), ymm8); - } - - dim_t m_remainder = m - i; - if(m_remainder >= 4) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] 
B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - - ///implement TRSM/// - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - a11 += cs_a; - - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); - - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - - m_remainder -= 4; - i += 4; - } - - if(m_remainder == 3) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] 
B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2)); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2 + 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - - ///implement TRSM/// - - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - a11 += cs_a; - - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); - - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x07); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2)); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2 + 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x07); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - xmm5 = _mm256_extractf128_pd(ymm7, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 2),xmm5); - _mm_storel_pd((b11 + cs_b * 2 + 2), _mm256_extractf128_pd(ymm7, 1)); - - m_remainder -= 3; - i += 3; - } - else if(m_remainder == 2) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_loadu_pd((double const *)(b11 
+ cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - - ///implement TRSM/// - - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - a11 += cs_a; - - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); - - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x03); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x03); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - xmm5 = _mm256_extractf128_pd(ymm7, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 2),xmm5); - - m_remainder -= 2; - i += 2; - } - else if(m_remainder == 1) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - ymm0 = _mm256_broadcast_sd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2)); 
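The m_remainder paths above assemble a partial column of B11 inside a full ymm register: two doubles come in through a 128-bit load, the third through a scalar broadcast, and the pieces are stitched together with _mm256_insertf128_pd; the write-back goes out through _mm256_extractf128_pd plus _mm_storel_pd so the unused fourth lane of B is never stored. A minimal standalone sketch of that load/store pattern follows (array names and values are illustrative only, not taken from the kernel); keeping the accesses within the valid m_remainder elements is what avoids reading or writing past the end of the B panel.

/* build with: cc -mavx demo_partial_col.c */
#include <immintrin.h>
#include <stdio.h>

int main(void)
{
    double col[4] = {1.0, 2.0, 3.0, -1.0};  /* -1.0 marks the lane that must stay untouched */
    double out[4] = {0.0, 0.0, 0.0, -1.0};

    /* load elements 0..2 into a ymm register: two via an xmm load, one via broadcast */
    __m128d lo  = _mm_loadu_pd(col);                        /* col[0] col[1]              */
    __m256d ymm = _mm256_broadcast_sd(col + 2);             /* col[2] in all four lanes   */
    ymm         = _mm256_insertf128_pd(ymm, lo, 0);         /* col[0] col[1] col[2] col[2]*/

    /* ... arithmetic on ymm would go here ... */

    /* store only elements 0..2, leaving out[3] untouched */
    _mm_storeu_pd(out, _mm256_extractf128_pd(ymm, 0));      /* out[0], out[1] */
    _mm_storel_pd(out + 2, _mm256_extractf128_pd(ymm, 1));  /* out[2]         */

    printf("%g %g %g %g\n", out[0], out[1], out[2], out[3]); /* prints: 1 2 3 -1 */
    return 0;
}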
//B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - - ///implement TRSM/// - - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - a11 += cs_a; - - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); - - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - - ymm0 = _mm256_broadcast_sd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x01); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x01); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x01); - - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm3, 0)); - _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm5, 0)); - _mm_storel_pd((b11 + cs_b * 2), _mm256_extractf128_pd(ymm7, 0)); - - m_remainder -= 1; - i += 1; - } - j += 3; - n_remainder -= 3; - } - else if(n_remainder == 2) - { - a01 = L + j; //pointer to block of A to be used in GEMM - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - - double *ptr_a10_dup = D_A_pack; - - dim_t p_lda = j; // packed leading dimension - // perform copy of A to packed buffer D_A_pack - - for(dim_t x =0;x < p_lda;x+=D_NR) - { - ymm0 = _mm256_loadu_pd((double const *)(a01)); - ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); - ymm2 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2)); - ymm3 = _mm256_loadu_pd((double const *)(a01 + cs_a * 3)); - - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); - - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); - - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - - _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); - - ymm0 = _mm256_loadu_pd((double const *)(a01 + cs_a * 4)); - ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a * 5)); - - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_broadcast_sd((double const *)&zero); - - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_broadcast_sd((double const *)&zero); - - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - - _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_extractf128_pd(ymm6,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_extractf128_pd(ymm7,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm8,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), 
_mm256_extractf128_pd(ymm9,0)); - - a01 += D_NR*cs_a; - ptr_a10_dup += D_NR; - } - - ymm4 = _mm256_broadcast_sd((double const *)&ones); - if(!is_unitdiag) - { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_sd((double const *)(a11)); - ymm1 = _mm256_broadcast_sd((double const *)(a11+cs_a*1 + 1)); - ymm2 = _mm256_broadcast_sd((double const *)&ones); - ymm3 = _mm256_broadcast_sd((double const *)&ones); - - ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); - - ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); - #ifdef BLIS_DISABLE_TRSM_PREINVERSION - ymm4 = ymm1; - #endif - #ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm4 = _mm256_div_pd(ymm4, ymm1); - #endif - } - _mm256_storeu_pd((double *)(d11_pack), ymm4); - - for(i = 0; (i+D_MR-1) < m; i += D_MR) //loop along 'M' direction - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); - #endif - - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b10 + 4)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + 4)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - ymm4 = _mm256_fmsub_pd(ymm1, ymm15, ymm4); //B11[4-7][0] * alpha-= ymm1 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b + 4)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] - - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - ymm6 = _mm256_fmsub_pd(ymm1, ymm15, ymm6); //B11[4-7][1] * alpha -= ymm3 - - ///implement TRSM/// - - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - - ymm3 = 
DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); - - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - ymm6 = _mm256_fnmadd_pd(ymm1, ymm4, ymm6); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm0); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + 4), ymm4); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b + 4), ymm6); - } - - dim_t m_remainder = m - i; - if(m_remainder >= 4) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ///implement TRSM/// - - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - - m_remainder -= 4; - i += 4; - } - - if(m_remainder == 3) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] 
B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1)); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*1 + 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ///implement TRSM/// - - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x07); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1)); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*1 + 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x07); - - _mm256_storeu_pd((double *)b11, ymm3); - xmm5 = _mm256_extractf128_pd(ymm5, 0); - _mm_storeu_pd((double *)(b11 + cs_b*1), xmm5); - _mm_storel_pd((b11 + cs_b * 1 + 2), _mm256_extractf128_pd(ymm5, 1)); - - m_remainder -= 3; - i += 3; - } - else if(m_remainder == 2) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ///implement TRSM/// - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 1):FMA operations - 
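The substitution steps marked "FMA operations" above all rely on _mm256_fnmadd_pd(a, b, c), which evaluates c - a*b: once a column of X has been solved, it is eliminated from the remaining right-hand sides by subtracting the broadcast A element times the solved column. The short sketch below shows one such elimination step in isolation, with made-up values (it is not a fragment of the kernel itself).

/* build with: cc -mavx -mfma demo_fnmadd_step.c */
#include <immintrin.h>
#include <stdio.h>

int main(void)
{
    /* one elimination step of the solve: b_col -= a_elem * x_col */
    __m256d x_col  = _mm256_set_pd(4.0, 3.0, 2.0, 1.0);    /* already-solved column {1,2,3,4} */
    __m256d b_col  = _mm256_set_pd(40.0, 30.0, 20.0, 10.0);/* pending right-hand side         */
    __m256d a_elem = _mm256_set1_pd(2.0);                  /* broadcast element of A          */

    /* _mm256_fnmadd_pd(a, b, c) computes c - a*b lane-wise */
    b_col = _mm256_fnmadd_pd(a_elem, x_col, b_col);

    double r[4];
    _mm256_storeu_pd(r, b_col);
    printf("%g %g %g %g\n", r[0], r[1], r[2], r[3]);       /* prints: 8 16 24 32 */
    return 0;
}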
ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x03); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x03); - - _mm256_storeu_pd((double *)b11, ymm3); - xmm5 = _mm256_extractf128_pd(ymm5, 0); - _mm_storeu_pd((double *)(b11 + cs_b*1), xmm5); - - m_remainder -= 2; - i += 2; - } - else if(m_remainder == 1) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - ymm0 = _mm256_broadcast_sd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ///implement TRSM/// - - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - ymm0 = _mm256_broadcast_sd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x01); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x01); - - _mm_storel_pd(b11 , _mm256_extractf128_pd(ymm3, 0)); - _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm5, 0)); - - m_remainder -= 1; - i += 1; - } - j += 2; - n_remainder -= 2; - } - else if(n_remainder == 1) - { - a01 = L + j; //pointer to block of A to be used in GEMM - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - - double *ptr_a10_dup = D_A_pack; - - dim_t p_lda = j; // packed leading dimension - // perform copy of A to packed buffer D_A_pack - - for(dim_t x =0;x < p_lda;x+=D_NR) - { - ymm0 = _mm256_loadu_pd((double const *)(a01)); - ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); - ymm2 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2)); - ymm3 = _mm256_loadu_pd((double const *)(a01 + cs_a * 3)); - - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - 
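The _mm256_blend_pd calls with masks 0x07, 0x03 and 0x01 in the remainder paths merge the freshly computed lanes with the values just reloaded from B, so that a subsequent full-width (or extracted) store only changes the first m_remainder entries of each column. A small illustrative sketch of that merge, with made-up data, follows; it assumes the lane being preserved is addressable, which holds here because it was loaded immediately beforehand.

/* build with: cc -mavx demo_blend_writeback.c */
#include <immintrin.h>
#include <stdio.h>

int main(void)
{
    double b[4] = {10.0, 20.0, 30.0, 40.0};   /* existing column of B     */
    double x[4] = { 1.0,  2.0,  3.0,  9.9};   /* freshly computed results */

    __m256d old_b = _mm256_loadu_pd(b);
    __m256d res   = _mm256_loadu_pd(x);

    /* mask 0x07: take lanes 0..2 from the computed result, lane 3 from memory */
    __m256d mix   = _mm256_blend_pd(old_b, res, 0x07);
    _mm256_storeu_pd(b, mix);                 /* full-width store now leaves b[3] unchanged */

    printf("%g %g %g %g\n", b[0], b[1], b[2], b[3]);  /* prints: 1 2 3 40 */
    return 0;
}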
ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); - - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); - - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - - _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); - - ymm0 = _mm256_loadu_pd((double const *)(a01 + cs_a * 4)); - ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a * 5)); - - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_broadcast_sd((double const *)&zero); - - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_broadcast_sd((double const *)&zero); - - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - - _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_extractf128_pd(ymm6,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_extractf128_pd(ymm7,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm8,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm9,0)); - - a01 += D_NR*cs_a; - ptr_a10_dup += D_NR; - } - - ymm4 = _mm256_broadcast_sd((double const *)&ones); - if(!is_unitdiag) - { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_sd((double const *)(a11)); - ymm1 = _mm256_broadcast_sd((double const *)&ones); - ymm2 = _mm256_broadcast_sd((double const *)&ones); - ymm3 = _mm256_broadcast_sd((double const *)&ones); - - ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); - - ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); - #ifdef BLIS_DISABLE_TRSM_PREINVERSION - ymm4 = ymm1; - #endif - #ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm4 = _mm256_div_pd(ymm4, ymm1); - #endif - } - _mm256_storeu_pd((double *)(d11_pack), ymm4); - - for(i = 0; (i+D_MR-1) < m; i += D_MR) //loop along 'M' direction - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + 4), _MM_HINT_T0); - #endif - - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b10 + 4)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] 
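The packing loops above read four columns of a01 and write them out as rows of the packed buffer by combining _mm256_unpacklo_pd/_mm256_unpackhi_pd with _mm256_permute2f128_pd, i.e. an in-register 4x4 transpose. Below is a self-contained sketch of that same shuffle sequence on a made-up 4x4 block (column-major input with lda = 4, row-major output with p_lda = 4); the array names are illustrative, not the kernel's.

/* build with: cc -mavx demo_pack_transpose.c */
#include <immintrin.h>
#include <stdio.h>

int main(void)
{
    /* 4x4 block stored column-major, A[r][c] = a[r + 4*c] = r + 4*c */
    double a[16] = { 0, 1, 2, 3,   4, 5, 6, 7,   8, 9,10,11,  12,13,14,15 };
    double p[16];                                    /* packed, row-major */

    __m256d c0 = _mm256_loadu_pd(a +  0);            /* column 0 */
    __m256d c1 = _mm256_loadu_pd(a +  4);            /* column 1 */
    __m256d c2 = _mm256_loadu_pd(a +  8);            /* column 2 */
    __m256d c3 = _mm256_loadu_pd(a + 12);            /* column 3 */

    __m256d lo01 = _mm256_unpacklo_pd(c0, c1);       /* a00 a01 | a20 a21 */
    __m256d lo23 = _mm256_unpacklo_pd(c2, c3);       /* a02 a03 | a22 a23 */
    __m256d hi01 = _mm256_unpackhi_pd(c0, c1);       /* a10 a11 | a30 a31 */
    __m256d hi23 = _mm256_unpackhi_pd(c2, c3);       /* a12 a13 | a32 a33 */

    __m256d r0 = _mm256_permute2f128_pd(lo01, lo23, 0x20);  /* row 0 */
    __m256d r1 = _mm256_permute2f128_pd(hi01, hi23, 0x20);  /* row 1 */
    __m256d r2 = _mm256_permute2f128_pd(lo01, lo23, 0x31);  /* row 2 */
    __m256d r3 = _mm256_permute2f128_pd(hi01, hi23, 0x31);  /* row 3 */

    _mm256_storeu_pd(p +  0, r0);
    _mm256_storeu_pd(p +  4, r1);
    _mm256_storeu_pd(p +  8, r2);
    _mm256_storeu_pd(p + 12, r3);

    for (int r = 0; r < 4; r++)                      /* prints 0 4 8 12 / 1 5 9 13 / ... */
        printf("%2g %2g %2g %2g\n", p[4*r+0], p[4*r+1], p[4*r+2], p[4*r+3]);
    return 0;
}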
B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + 4)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - ymm4 = _mm256_fmsub_pd(ymm1, ymm15, ymm4); //B11[4-7][0] * alpha-= ymm1 - - ///implement TRSM/// - - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + 4), ymm4); - } - - dim_t m_remainder = m - i; - if(m_remainder >= 4) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ///implement TRSM/// - - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - _mm256_storeu_pd((double *)b11, ymm3); - - m_remainder -= 4; - i += 4; - } - - if(m_remainder == 3) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - xmm5 = _mm_loadu_pd((double const*)(b11)); - ymm0 = _mm256_broadcast_sd((double const *)(b11+ 2)); - ymm6 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm3 = _mm256_fmsub_pd(ymm6, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ///implement TRSM/// - - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm6, ymm3, 0x07); - - xmm5 = _mm256_extractf128_pd(ymm3, 0); - _mm_storeu_pd((double *)(b11), xmm5); - _mm_storel_pd((b11 + 2), _mm256_extractf128_pd(ymm3, 1)); - - m_remainder -= 3; - i += 3; - } - else if(m_remainder 
== 2) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - xmm5 = _mm_loadu_pd((double const*)(b11)); - ymm6 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm3 = _mm256_fmsub_pd(ymm6, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ///implement TRSM/// - - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm6, ymm3, 0x03); - - xmm5 = _mm256_extractf128_pd(ymm3, 0); - _mm_storeu_pd((double *)(b11), xmm5); - - m_remainder -= 2; - i += 2; - } - else if(m_remainder == 1) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - ymm6 = _mm256_broadcast_sd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm6, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ///implement TRSM/// - - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - ymm3 = _mm256_blend_pd(ymm6, ymm3, 0x01); - - _mm_storel_pd(b11, _mm256_extractf128_pd(ymm3, 0)); - - m_remainder -= 1; - i += 1; - } - j += 1; - n_remainder -= 1; - } - - if ((required_packing_A == 1) && bli_mem_is_alloc( &local_mem_buf_A_s )) - { - bli_membrk_release(&rntm, - &local_mem_buf_A_s); - } - - return BLIS_SUCCESS; -} - -/*implements TRSM for the case XA = alpha * B - *A is lower triangular, non-unit diagonal/unit diagonal, no transpose - *dimensions: X:mxn A:nxn B: mxn - * - * <---b11 <---a11 - ***************** * - *b01*b11* * * * * - ^ * * * * * ^ * * - | ***************** | ******* - | * * * * * | * * * - | * * * * * a01* * * -b10 ***************** ************* - * * * * * * * * * - * * * * * * * * * - ***************** ******************* - -*/ -BLIS_INLINE err_t bli_dtrsm_small_XAlB -( - 
obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl -) -{ - dim_t m = bli_obj_length(b); //number of rows - dim_t n = bli_obj_width(b); //number of columns - - dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A - dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B - - dim_t i, j, k; //loop variablse - dim_t k_iter; //determines the number of GEMM operations to be done - - double ones = 1.0; - bool is_unitdiag = bli_obj_has_unit_diag(a); - - double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha - double* restrict L = a->buffer; //pointer to matrix A - double* restrict B = b->buffer; //pointer to matrix B - - double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks - - gint_t required_packing_A = 1; - mem_t local_mem_buf_A_s = {0}; - double *D_A_pack = NULL; - double d11_pack[D_MR] __attribute__((aligned(64))); - rntm_t rntm; - - bli_rntm_init_from_global( &rntm ); - bli_rntm_set_num_threads_only( 1, &rntm ); - bli_membrk_rntm_set_membrk( &rntm ); - - siz_t buffer_size = bli_pool_block_size( - bli_membrk_pool( - bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), - bli_rntm_membrk(&rntm))); - - if( (D_NR * n * sizeof(double)) > buffer_size) - return BLIS_NOT_YET_IMPLEMENTED; - - if (required_packing_A == 1) - { - // Get the buffer from the pool. - bli_membrk_acquire_m(&rntm, - buffer_size, - BLIS_BITVAL_BUFFER_FOR_A_BLOCK, - &local_mem_buf_A_s); - if(FALSE==bli_mem_is_alloc(&local_mem_buf_A_s)) return BLIS_NULL_POINTER; - D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); - if(NULL==D_A_pack) return BLIS_NULL_POINTER; - } - - //ymm scratch reginsters - __m256d ymm0, ymm1, ymm2, ymm3; - __m256d ymm4, ymm5, ymm6, ymm7; - __m256d ymm8, ymm9, ymm10, ymm11; - __m256d ymm12, ymm13, ymm14, ymm15; - - __m128d xmm5; - - /* - Performs solving TRSM for 6 rows at a time from 0 to n/6 in steps of D_NR - a. Load and pack A (a01 block), the size of packing 6x6 to 6x (n-6) - First there will be no GEMM and no packing of a01 because it is only TRSM - b. Using packed a01 block and b10 block perform GEMM operation - c. Use GEMM outputs, perform TRSM operation using a11, b11 and update B - d. Repeat b for m cols of B in steps of D_MR - */ - - for(j = (n-D_NR); (j+1) > 0; j -= D_NR) //loop along 'N' direction - { - a01 = L + j*cs_a +(j+D_NR); //pointer to block of A to be used in GEMM - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - - dim_t p_lda = (n-j-D_NR); // packed leading dimension - // perform copy of A to packed buffer D_A_pack - - /* - Pack current A block (a01) into packed buffer memory D_A_pack - a. This a10 block is used in GEMM portion only and this - a01 block size will be increasing by D_NR for every next iteration - until it reaches 6x(n-6) which is the maximum GEMM alone block size in A - b. This packed buffer is reused to calculate all m cols of B matrix - */ - bli_dtrsm_small_pack('R', p_lda, 0, a01, cs_a, D_A_pack, p_lda); - - /* - Pack 6 diagonal elements of A block into an array - a. This helps in utilze cache line efficiently in TRSM operation - b. store ones when input is unit diagonal - */ - dtrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,D_NR); - - /* - a. Perform GEMM using a01, b10. - b. Perform TRSM on a11, b11 - c. This loop GEMM+TRSM loops operates with 8x6 block size - along m dimension for every D_MR columns of B10 where - packed A buffer is reused in computing all m cols of B. - d. Same approach is used in remaining fringe cases. 
- */ - - for(i = (m-D_MR); (i+1) > 0; i -= D_MR)//loop along 'M' direction - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (i) + (j)*cs_b; //pointer to block of B to be used for TRSM - - k_iter = (n-j-D_NR);//no. of GEMM operations to be done(in blocks of 4x4) - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); - #endif - - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS - - /* - Peform GEMM between a01 and b10 blocks - For first itteration there will be no GEMM operation - where k_iter are zero - */ - - BLIS_DTRSM_SMALL_GEMM_6x8(a01,b10,cs_b,p_lda,k_iter) - - /* - Load b11 of size 8x6 and multiply with alpha - Add the GEMM output to b11 - and peform TRSM operation. - */ - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + 4)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - ymm4 = _mm256_fmsub_pd(ymm1, ymm15, ymm4); //B11[4-7][0] * alpha-= ymm1 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b + 4)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] - - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - ymm6 = _mm256_fmsub_pd(ymm1, ymm15, ymm6); //B11[4-7][1] * alpha -= ymm3 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*2 + 4)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] - - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - ymm8 = _mm256_fmsub_pd(ymm1, ymm15, ymm8); //B11[4-7][2] * alpha -= ymm5 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*3 + 4)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] - - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 - ymm10 = _mm256_fmsub_pd(ymm1, ymm15, ymm10); //B11[4-7][3] * alpha -= ymm7 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*4 + 4)); - - ymm11 = _mm256_fmsub_pd(ymm0, ymm15, ymm11); - ymm12 = _mm256_fmsub_pd(ymm1, ymm15, ymm12); - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*5 + 4)); - - ymm13 = _mm256_fmsub_pd(ymm0, ymm15, ymm13); - ymm14 = _mm256_fmsub_pd(ymm1, ymm15, ymm14); - - ///implement TRSM/// - - /* - Compute 6x8 TRSM block by using GEMM block output in register - a. The 6x8 input (gemm outputs) are stored in combinations of ymm registers - 1. ymm3, ymm4 2. ymm5, ymm6 3. ymm7, ymm8, 4. ymm9, ymm10 - 5. ymm11, ymm12 6. ymm13,ymm14 - b. 
Towards the end TRSM output will be stored back into b11 - */ - - //extract a55 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); - - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm0); - - //extract a44 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); - - //(row 5):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 5)); - - ymm11 = _mm256_fnmadd_pd(ymm1, ymm13, ymm11); - ymm12 = _mm256_fnmadd_pd(ymm1, ymm14, ymm12); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 5)); - - ymm9 = _mm256_fnmadd_pd(ymm1, ymm13, ymm9); - ymm10 = _mm256_fnmadd_pd(ymm1, ymm14, ymm10); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 5)); - - ymm7 = _mm256_fnmadd_pd(ymm1, ymm13, ymm7); - ymm8 = _mm256_fnmadd_pd(ymm1, ymm14, ymm8); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 5)); - - ymm5 = _mm256_fnmadd_pd(ymm1, ymm13, ymm5); - ymm6 = _mm256_fnmadd_pd(ymm1, ymm14, ymm6); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5)); - - ymm3 = _mm256_fnmadd_pd(ymm1, ymm13, ymm3); - ymm4 = _mm256_fnmadd_pd(ymm1, ymm14, ymm4); - - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); - ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm0); - - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //(row 4):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 4)); - - ymm9 = _mm256_fnmadd_pd(ymm1, ymm11, ymm9); - ymm10 = _mm256_fnmadd_pd(ymm1, ymm12, ymm10); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 4)); - - ymm7 = _mm256_fnmadd_pd(ymm1, ymm11, ymm7); - ymm8 = _mm256_fnmadd_pd(ymm1, ymm12, ymm8); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 4)); - - ymm5 = _mm256_fnmadd_pd(ymm1, ymm11, ymm5); - ymm6 = _mm256_fnmadd_pd(ymm1, ymm12, ymm6); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4)); - - ymm3 = _mm256_fnmadd_pd(ymm1, ymm11, ymm3); - ymm4 = _mm256_fnmadd_pd(ymm1, ymm12, ymm4); - - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm0); - - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3)); - - ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); - ymm8 = _mm256_fnmadd_pd(ymm1, ymm10, ymm8); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3)); - - ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); - ymm6 = _mm256_fnmadd_pd(ymm1, ymm10, ymm6); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); - - ymm3 = _mm256_fnmadd_pd(ymm1, ymm9, ymm3); - ymm4 = _mm256_fnmadd_pd(ymm1, ymm10, ymm4); - - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2)); - - ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); - ymm6 = _mm256_fnmadd_pd(ymm1, ymm8, ymm6); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); - - ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); - ymm4 = _mm256_fnmadd_pd(ymm1, ymm8, ymm4); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm0); - - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); - - //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); - - ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); - ymm4 = _mm256_fnmadd_pd(ymm1, ymm6, ymm4); - - ymm3 = 
DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + 4), ymm4); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b + 4), ymm6); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*2 + 4), ymm8); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - _mm256_storeu_pd((double *)(b11 + cs_b*3 + 4), ymm10); - _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); - _mm256_storeu_pd((double *)(b11 + cs_b*4 + 4), ymm12); - _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); - _mm256_storeu_pd((double *)(b11 + cs_b*5 + 4), ymm14); - } - - dim_t m_remainder = i + D_MR; - if(m_remainder >= 4) - { - a01 = D_A_pack; - a11 = L + (j*cs_a) + j; - b10 = B + (m_remainder - 4) + (j+D_NR)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (m_remainder - 4) + (j*cs_b); - - k_iter = (n-j-D_NR); //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 4)); //A01[0][4] - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 5)); //A01[0][5] - ymm13 = _mm256_fmadd_pd(ymm2, ymm0, ymm13); - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); - ymm11 = _mm256_fmsub_pd(ymm0, ymm15, ymm11); - - ymm0 = 
_mm256_loadu_pd((double const *)(b11 + cs_b*5)); - ymm13 = _mm256_fmsub_pd(ymm0, ymm15, ymm13); - - ///implement TRSM/// - - //extract a55 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); - - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - - //extract a44 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); - - //(row 5):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 5)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm13, ymm11); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 5)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm13, ymm9); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 5)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm13, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 5)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm13, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm13, ymm3); - - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); - - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //(row 4):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 4)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm11, ymm9); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 4)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm11, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 4)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm11, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm11, ymm3); - - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm9, ymm3); - - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); - - //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); - - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); - _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); - - m_remainder -=4; - } - - if(m_remainder) - { - if(3 == m_remainder) - { - a01 = D_A_pack; - a11 = L + (j*cs_a) + j; - b10 = B + (j+D_NR)*cs_b + (m_remainder - 3); //pointer to block of B to be used in GEMM - b11 = B + (m_remainder - 3) + (j*cs_b); - - k_iter = (n-j-D_NR); //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 
0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 4)); //A01[0][4] - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 5)); //A01[0][5] - ymm13 = _mm256_fmadd_pd(ymm2, ymm0, ymm13); - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); - ymm11 = _mm256_fmsub_pd(ymm0, ymm15, ymm11); - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); - ymm13 = _mm256_fmsub_pd(ymm0, ymm15, ymm13); - - ///implement TRSM/// - - //extract a55 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - - //extract a44 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); - - //(row 5):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 5)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm13, ymm11); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 5)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm13, ymm9); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 5)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm13, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 5)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm13, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm13, ymm3); - - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); - - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //(row 4):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 4)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm11, ymm9); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 4)); - ymm7 = _mm256_fnmadd_pd(ymm1, 
ymm11, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 4)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm11, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm11, ymm3); - - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm9, ymm3); - - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); - - //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); - - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_pd(ymm0, ymm11, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_pd(ymm0, ymm13, 0x07); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); - _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); - - m_remainder -=3; - } - else if(2 == m_remainder) - { - a01 = D_A_pack; - a11 = L + (j*cs_a) + j; - b10 = B + (j+D_NR)*cs_b + (m_remainder - 2); //pointer to block of B to be used in GEMM - b11 = B + (m_remainder - 2) + (j*cs_b); - - k_iter = (n-j-D_NR); //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); 
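DTRSM_SMALL_DIV_OR_SCALE is not defined in this hunk; judging from the diagonal-packing code above, which divides a register of ones by the diagonal elements only when BLIS_ENABLE_TRSM_PREINVERSION is set, d11_pack presumably holds the reciprocals 1/a_kk under preinversion (making the solve step a multiply) and the raw a_kk values otherwise (making it a divide). The stand-in helpers below are illustrative only and are not the macro from this file; both paths produce the same quotient on made-up data.

/* build with: cc -mavx demo_div_or_scale.c */
#include <immintrin.h>
#include <stdio.h>

/* hypothetical stand-ins for the two configurations of DTRSM_SMALL_DIV_OR_SCALE */
static __m256d div_or_scale_mul(__m256d b, __m256d d) { return _mm256_mul_pd(b, d); } /* preinversion on  */
static __m256d div_or_scale_div(__m256d b, __m256d d) { return _mm256_div_pd(b, d); } /* preinversion off */

int main(void)
{
    double  a_kk  = 4.0;
    __m256d b     = _mm256_set1_pd(8.0);
    __m256d d_inv = _mm256_set1_pd(1.0 / a_kk);  /* what d11_pack would hold with preinversion    */
    __m256d d_raw = _mm256_set1_pd(a_kk);        /* what d11_pack would hold without preinversion */

    double r1[4], r2[4];
    _mm256_storeu_pd(r1, div_or_scale_mul(b, d_inv));
    _mm256_storeu_pd(r2, div_or_scale_div(b, d_raw));
    printf("%g %g\n", r1[0], r2[0]);             /* prints: 2 2 */
    return 0;
}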
//A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 4)); //A01[0][4] - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 5)); //A01[0][5] - ymm13 = _mm256_fmadd_pd(ymm2, ymm0, ymm13); - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); - ymm11 = _mm256_fmsub_pd(ymm0, ymm15, ymm11); - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); - ymm13 = _mm256_fmsub_pd(ymm0, ymm15, ymm13); - - ///implement TRSM/// - - //extract a55 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - - //extract a44 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); - - //(row 5):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 5)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm13, ymm11); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 5)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm13, ymm9); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 5)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm13, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 5)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm13, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm13, ymm3); - - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); - - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //(row 4):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 4)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm11, ymm9); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 4)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm11, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 4)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm11, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm11, ymm3); - - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3)); - ymm7 = _mm256_fnmadd_pd(ymm1, 
ymm9, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm9, ymm3); - - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); - - //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); - - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_pd(ymm0, ymm11, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_pd(ymm0, ymm13, 0x03); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); - _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); - - m_remainder -=2; - } - else if (1 == m_remainder) - { - a01 = D_A_pack; - a11 = L + (j*cs_a) + j; - b10 = B + (j+D_NR)*cs_b + (m_remainder - 1); //pointer to block of B to be used in GEMM - b11 = B + (m_remainder - 1) + (j*cs_b); - - k_iter = (n-j-D_NR); //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += 
(B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 4)); //A01[0][4] - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 5)); //A01[0][5] - ymm13 = _mm256_fmadd_pd(ymm2, ymm0, ymm13); - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); - ymm11 = _mm256_fmsub_pd(ymm0, ymm15, ymm11); - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); - ymm13 = _mm256_fmsub_pd(ymm0, ymm15, ymm13); - - ///implement TRSM/// - - //extract a55 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - - //extract a44 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); - - //(row 5):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 5)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm13, ymm11); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 5)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm13, ymm9); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 5)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm13, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 5)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm13, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm13, ymm3); - - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); - - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //(row 4):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 4)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm11, ymm9); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 4)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm11, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 4)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm11, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm11, ymm3); - - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm9, ymm3); - - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); 
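For reference, the solve above proceeds column-by-column from the last column of B back to the first: each column is scaled by the packed diagonal entry (or its reciprocal) and then folded into the earlier columns with fnmadd. A minimal scalar sketch of that ordering, not part of this patch, assuming preinversion is enabled, a non-unit diagonal, and placeholder names A/B/lda/ldb for the operands (dim_t as used elsewhere in this file):

    /* sketch only: XA = B with an effectively lower-triangular A,
       solved from the last column of X backwards */
    static void xalb_colwise_sketch( const double *A, double *B,
                                     dim_t M, dim_t N, dim_t lda, dim_t ldb )
    {
        for( dim_t k = N; k--; )                    /* last column of X first      */
        {
            double akk_inv = 1.0 / A[k + k*lda];    /* preinverted diagonal entry  */
            for( dim_t i = 0; i < M; i++ )
            {
                B[i + k*ldb] *= akk_inv;            /* X(:,k) = B(:,k) / A(k,k)    */
                for( dim_t j = 0; j < k; j++ )      /* fold X(:,k) into cols j < k */
                    B[i + j*ldb] -= A[k + j*lda] * B[i + k*ldb];
            }
        }
    }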
- - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); - - //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); - - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_pd(ymm0, ymm11, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_pd(ymm0, ymm13, 0x01); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); - _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); - - m_remainder -=1; - } - } - } - - dim_t n_remainder = j + D_NR; - - /* - Reminder cases starts here: - a. Similar logic and code flow used in computing full block (6x8) - above holds for reminder cases too. - */ - - if(n_remainder >= 4) - { - a01 = L + (n_remainder - 4)*cs_a + n_remainder; //pointer to block of A to be used in GEMM - a11 = L + (n_remainder - 4)*cs_a + (n_remainder - 4); //pointer to block of A to be used for TRSM - - double *ptr_a10_dup = D_A_pack; - - dim_t p_lda = (n-n_remainder); // packed leading dimension - // perform copy of A to packed buffer D_A_pack - - dim_t loop_count = (n-n_remainder)/4; - - for(dim_t x =0;x < loop_count;x++) - { - ymm15 = _mm256_loadu_pd((double const *)(a01 + cs_a * 0 + x*4)); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + x*4), ymm15); - ymm15 = _mm256_loadu_pd((double const *)(a01 + cs_a * 1 + x*4)); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + x*4), ymm15); - ymm15 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2 + x*4)); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 2 + x*4), ymm15); - ymm15 = _mm256_loadu_pd((double const *)(a01 + cs_a * 3 + x*4)); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 3 + x*4), ymm15); - } - - dim_t remainder_loop_count = p_lda - loop_count*4; - - __m128d xmm0; - if(remainder_loop_count != 0) - { - xmm0 = _mm_loadu_pd((double const *)(a01 + cs_a * 0 + loop_count*4)); - _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + loop_count*4), xmm0); - xmm0 = _mm_loadu_pd((double const *)(a01 + cs_a * 1 + loop_count*4)); - _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + loop_count*4), xmm0); - xmm0 = _mm_loadu_pd((double const *)(a01 + cs_a * 2 + loop_count*4)); - _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 2 + loop_count*4), xmm0); - xmm0 = _mm_loadu_pd((double const *)(a01 + cs_a * 3 + loop_count*4)); - _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 3 + loop_count*4), xmm0); - } - - ymm4 = _mm256_broadcast_sd((double const *)&ones); - if(!is_unitdiag) - { - //broadcast diagonal 
elements of A11 - ymm0 = _mm256_broadcast_sd((double const *)(a11)); - ymm1 = _mm256_broadcast_sd((double const *)(a11+cs_a*1 + 1)); - ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a*2 + 2)); - ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a*3 + 3)); - - ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); - - ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); - #ifdef BLIS_DISABLE_TRSM_PREINVERSION - ymm4 = ymm1; - #endif - #ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm4 = _mm256_div_pd(ymm4, ymm1); - #endif - } - _mm256_storeu_pd((double *)(d11_pack), ymm4); - - for(i = (m-D_MR); (i+1) > 0; i -= D_MR) //loop along 'M' direction - { - a01 = D_A_pack; - a11 = L + (n_remainder - 4)*cs_a + (n_remainder - 4); //pointer to block of A to be used for TRSM - b10 = B + i + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (i) + (n_remainder - 4)*cs_b; //pointer to block of B to be used for TRSM - - k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); - #endif - - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b10 + 4)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); - - ymm0 = _mm256_loadu_pd((double const *)b11); 
//B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + 4)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - ymm4 = _mm256_fmsub_pd(ymm1, ymm15, ymm4); //B11[4-7][0] * alpha-= ymm1 + ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b + 4)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] + ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); + #ifdef BLIS_DISABLE_TRSM_PREINVERSION + ymm4 = ymm1; + #endif + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + ymm4 = _mm256_div_pd(ymm4, ymm1); + #endif + } + _mm256_storeu_pd((double *)(d11_pack), ymm4); - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - ymm6 = _mm256_fmsub_pd(ymm1, ymm15, ymm6); //B11[4-7][1] * alpha -= ymm3 + for(i = (m-d_mr); (i+1) > 0; i -= d_mr) //loop along 'M' direction + { + a01 = D_A_pack; + a11 = L + (n_remainder - 4)*cs_a + (n_remainder - 4)*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (i) + (n_remainder - 4)*cs_b; //pointer to block of B to be used for TRSM - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*2 + 4)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] + k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - ymm8 = _mm256_fmsub_pd(ymm1, ymm15, ymm8); //B11[4-7][2] * alpha -= ymm5 + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*3 + 4)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_4nx8m(a01,b10,cs_b,p_lda,k_iter) - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 - ymm10 = _mm256_fmsub_pd(ymm1, ymm15, ymm10); //B11[4-7][3] * alpha -= ymm7 + BLIS_PRE_DTRSM_SMALL_4x8(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -18459,17 +5302,17 @@ BLIS_INLINE err_t bli_dtrsm_small_XAlB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2*rs_a)); ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); ymm8 = _mm256_fnmadd_pd(ymm1, ymm10, ymm8); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); ymm6 = _mm256_fnmadd_pd(ymm1, ymm10, ymm6); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); ymm3 = _mm256_fnmadd_pd(ymm1, ymm9, ymm3); ymm4 = _mm256_fnmadd_pd(ymm1, ymm10, ymm4); @@ -18481,12 +5324,12 @@ BLIS_INLINE err_t bli_dtrsm_small_XAlB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); ymm6 = _mm256_fnmadd_pd(ymm1, ymm8, ymm6); - ymm1 = 
_mm256_broadcast_sd((double const *)(a11 + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); ymm4 = _mm256_fnmadd_pd(ymm1, ymm8, ymm4); @@ -18498,7 +5341,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAlB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); ymm4 = _mm256_fnmadd_pd(ymm1, ymm6, ymm4); @@ -18516,43 +5359,21 @@ BLIS_INLINE err_t bli_dtrsm_small_XAlB _mm256_storeu_pd((double *)(b11 + cs_b*3 + 4), ymm10); } - dim_t m_remainder = i + D_MR; + dim_t m_remainder = i + d_mr; if(m_remainder >= 4) { a01 = D_A_pack; - a11 = L + (n_remainder - 4)*cs_a + (n_remainder - 4); //pointer to block of A to be used for TRSM + a11 = L + (n_remainder - 4)*cs_a + (n_remainder - 4)*rs_a; //pointer to block of A to be used for TRSM b10 = B + (m_remainder - 4) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM b11 = B + (m_remainder - 4) + (n_remainder - 4)*cs_b; //pointer to block of B to be used for TRSM k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - a01 += 1; //move to next row - b10 += cs_b; - } + BLIS_DTRSM_SMALL_GEMM_4nx4m(a01,b10,cs_b,p_lda,k_iter) ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha @@ -18572,20 +5393,19 @@ BLIS_INLINE err_t bli_dtrsm_small_XAlB //extract a33 ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); //extract a22 ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2*rs_a)); ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); ymm3 = _mm256_fnmadd_pd(ymm1, ymm9, 
ymm3); ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); @@ -18594,10 +5414,10 @@ BLIS_INLINE err_t bli_dtrsm_small_XAlB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); @@ -18606,7 +5426,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAlB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); @@ -18624,39 +5444,17 @@ BLIS_INLINE err_t bli_dtrsm_small_XAlB if(3 == m_remainder) { a01 = D_A_pack; - a11 = L + (n_remainder - 4)*cs_a + (n_remainder - 4); //pointer to block of A to be used for TRSM + a11 = L + (n_remainder - 4)*cs_a + (n_remainder - 4)*rs_a; //pointer to block of A to be used for TRSM b10 = B + (m_remainder - 3) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM b11 = B + (m_remainder - 3) + (n_remainder - 4)*cs_b; //pointer to block of B to be used for TRSM k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - a01 += 1; //move to next row - b10 += cs_b; - } + BLIS_DTRSM_SMALL_GEMM_4nx4m(a01,b10,cs_b,p_lda,k_iter) ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha @@ -18678,20 +5476,19 @@ BLIS_INLINE err_t bli_dtrsm_small_XAlB //extract a33 ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); //extract a22 ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2*rs_a)); ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3)); 
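The d11_pack fill above stores either the diagonal elements themselves or their reciprocals, depending on the preinversion switch, and DTRSM_SMALL_DIV_OR_SCALE consumes them accordingly. A hedged sketch of how such a macro pair is typically wired; the real definitions live earlier in this file and may differ in detail:

    /* assumption: d holds 1/a_kk when preinversion is enabled, a_kk otherwise */
    #ifdef BLIS_ENABLE_TRSM_PREINVERSION
      #define DTRSM_SMALL_DIV_OR_SCALE(b, d)  _mm256_mul_pd((b), (d))
    #else /* BLIS_DISABLE_TRSM_PREINVERSION */
      #define DTRSM_SMALL_DIV_OR_SCALE(b, d)  _mm256_div_pd((b), (d))
    #endif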
+ ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); ymm3 = _mm256_fnmadd_pd(ymm1, ymm9, ymm3); ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); @@ -18700,10 +5497,10 @@ BLIS_INLINE err_t bli_dtrsm_small_XAlB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); @@ -18712,7 +5509,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAlB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); @@ -18740,39 +5537,17 @@ BLIS_INLINE err_t bli_dtrsm_small_XAlB else if(2 == m_remainder) { a01 = D_A_pack; - a11 = L + (n_remainder - 4)*cs_a + (n_remainder - 4); //pointer to block of A to be used for TRSM + a11 = L + (n_remainder - 4)*cs_a + (n_remainder - 4)*rs_a; //pointer to block of A to be used for TRSM b10 = B + (m_remainder - 2) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM b11 = B + (m_remainder - 2) + (n_remainder - 4)*cs_b; //pointer to block of B to be used for TRSM k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - a01 += 1; //move to next row - b10 += cs_b; - } + BLIS_DTRSM_SMALL_GEMM_4nx4m(a01,b10,cs_b,p_lda,k_iter) ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha @@ -18793,20 +5568,19 @@ BLIS_INLINE err_t bli_dtrsm_small_XAlB //extract a33 ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); //extract a22 ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 
2)); //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2*rs_a)); ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); ymm3 = _mm256_fnmadd_pd(ymm1, ymm9, ymm3); ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); @@ -18815,10 +5589,10 @@ BLIS_INLINE err_t bli_dtrsm_small_XAlB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); @@ -18827,7 +5601,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAlB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); @@ -18853,39 +5627,17 @@ BLIS_INLINE err_t bli_dtrsm_small_XAlB else if (1 == m_remainder) { a01 = D_A_pack; - a11 = L + (n_remainder - 4)*cs_a + (n_remainder - 4); //pointer to block of A to be used for TRSM + a11 = L + (n_remainder - 4)*cs_a + (n_remainder - 4)*rs_a; //pointer to block of A to be used for TRSM b10 = B + (m_remainder - 1) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM b11 = B + (m_remainder - 1) + (n_remainder - 4)*cs_b; //pointer to block of B to be used for TRSM k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - a01 += 1; //move to next row - b10 += cs_b; - } + BLIS_DTRSM_SMALL_GEMM_4nx4m(a01,b10,cs_b,p_lda,k_iter) ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); 
//register to hold alpha @@ -18905,20 +5657,19 @@ BLIS_INLINE err_t bli_dtrsm_small_XAlB //extract a33 ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); //extract a22 ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2*rs_a)); ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); ymm3 = _mm256_fnmadd_pd(ymm1, ymm9, ymm3); ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); @@ -18927,10 +5678,10 @@ BLIS_INLINE err_t bli_dtrsm_small_XAlB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); @@ -18939,25 +5690,25 @@ BLIS_INLINE err_t bli_dtrsm_small_XAlB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_broadcast_sd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x01); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x01); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x01); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x01); - - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm3, 0)); - _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm5, 0)); - _mm_storel_pd((b11 + cs_b * 2), _mm256_extractf128_pd(ymm7, 0)); - _mm_storel_pd((b11 + cs_b * 3), _mm256_extractf128_pd(ymm9, 0)); - - m_remainder -=1; + ymm0 = _mm256_broadcast_sd((double const *)b11); + ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x01); + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x01); + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x01); + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x01); + + _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm3, 0)); + _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm5, 0)); + _mm_storel_pd((b11 + cs_b * 2), _mm256_extractf128_pd(ymm7, 0)); + _mm_storel_pd((b11 + cs_b * 3), _mm256_extractf128_pd(ymm9, 0)); + + m_remainder -=1; } } n_remainder -= 4; @@ -18965,47 +5716,110 @@ BLIS_INLINE err_t bli_dtrsm_small_XAlB if(n_remainder == 3) { - a01 = L + (n_remainder - 
3)*cs_a + n_remainder; //pointer to block of A to be used in GEMM - a11 = L + (n_remainder - 3)*cs_a + (n_remainder - 3); //pointer to block of A to be used for TRSM + a01 = L + (n_remainder - 3)*rs_a + n_remainder*cs_a; //pointer to block of A to be used in GEMM + a11 = L + (n_remainder - 3)*cs_a + (n_remainder - 3)*rs_a; //pointer to block of A to be used for TRSM double *ptr_a10_dup = D_A_pack; dim_t p_lda = (n-n_remainder); // packed leading dimension // perform copy of A to packed buffer D_A_pack - dim_t loop_count = (n-n_remainder)/4; - - for(dim_t x =0;x < loop_count;x++) + if(transa) { - ymm15 = _mm256_loadu_pd((double const *)(a01 + cs_a * 0 + x*4)); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + x*4), ymm15); - ymm15 = _mm256_loadu_pd((double const *)(a01 + cs_a * 1 + x*4)); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + x*4), ymm15); - ymm15 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2 + x*4)); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 2 + x*4), ymm15); - } + for(dim_t x =0;x < p_lda;x+=d_nr) + { + ymm0 = _mm256_loadu_pd((double const *)(a01)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); + ymm2 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2)); + ymm3 = _mm256_loadu_pd((double const *)(a01 + cs_a * 3)); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + + ymm0 = _mm256_loadu_pd((double const *)(a01 + cs_a * 4)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a * 5)); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_broadcast_sd((double const *)&zero); + + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_broadcast_sd((double const *)&zero); + + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - dim_t remainder_loop_count = p_lda - loop_count*4; + _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_extractf128_pd(ymm6,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_extractf128_pd(ymm7,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm8,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm9,0)); - __m128d xmm0; - if(remainder_loop_count != 0) + a01 += d_nr*cs_a; + ptr_a10_dup += d_nr; + } + } + else { - xmm0 = _mm_loadu_pd((double const *)(a01 + cs_a * 0 + loop_count*4)); - _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + loop_count*4), xmm0); - xmm0 = _mm_loadu_pd((double const *)(a01 + cs_a * 1 + loop_count*4)); - _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + loop_count*4), xmm0); - xmm0 = _mm_loadu_pd((double const *)(a01 + cs_a * 2 + loop_count*4)); - _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 2 + loop_count*4), xmm0); + dim_t loop_count = (n-n_remainder)/4; + + for(dim_t x =0;x < loop_count;x++) + { + ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 0 + x*4)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + x*4), ymm15); 
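The if(transa) packing loops added above transpose A on the fly with the unpacklo/unpackhi + permute2f128 idiom. A stand-alone sketch of that 4x4 double-precision transpose, with placeholder names a/lda for the source panel and p/p_lda for the packed buffer (illustration only, not the patch's literal code):

    #include <immintrin.h>

    static void pack_4x4_trans( const double *a, long lda, double *p, long p_lda )
    {
        __m256d c0 = _mm256_loadu_pd( a + 0*lda );    /* a00 a10 a20 a30 */
        __m256d c1 = _mm256_loadu_pd( a + 1*lda );    /* a01 a11 a21 a31 */
        __m256d c2 = _mm256_loadu_pd( a + 2*lda );    /* a02 a12 a22 a32 */
        __m256d c3 = _mm256_loadu_pd( a + 3*lda );    /* a03 a13 a23 a33 */

        __m256d lo01 = _mm256_unpacklo_pd( c0, c1 );  /* a00 a01 | a20 a21 */
        __m256d lo23 = _mm256_unpacklo_pd( c2, c3 );  /* a02 a03 | a22 a23 */
        __m256d hi01 = _mm256_unpackhi_pd( c0, c1 );  /* a10 a11 | a30 a31 */
        __m256d hi23 = _mm256_unpackhi_pd( c2, c3 );  /* a12 a13 | a32 a33 */

        /* recombine 128-bit halves into rows of A (= columns of A^T) */
        _mm256_storeu_pd( p + 0*p_lda, _mm256_permute2f128_pd( lo01, lo23, 0x20 ) );
        _mm256_storeu_pd( p + 1*p_lda, _mm256_permute2f128_pd( hi01, hi23, 0x20 ) );
        _mm256_storeu_pd( p + 2*p_lda, _mm256_permute2f128_pd( lo01, lo23, 0x31 ) );
        _mm256_storeu_pd( p + 3*p_lda, _mm256_permute2f128_pd( hi01, hi23, 0x31 ) );
    }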
+ ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 1 + x*4)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + x*4), ymm15); + ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 2 + x*4)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 2 + x*4), ymm15); + } + + dim_t remainder_loop_count = p_lda - loop_count*4; + + __m128d xmm0; + if(remainder_loop_count != 0) + { + xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 0 + loop_count*4)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + loop_count*4), xmm0); + xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 1 + loop_count*4)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + loop_count*4), xmm0); + xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 2 + loop_count*4)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 2 + loop_count*4), xmm0); + } } ymm4 = _mm256_broadcast_sd((double const *)&ones); if(!is_unitdiag) { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_sd((double const *)(a11)); - ymm1 = _mm256_broadcast_sd((double const *)(a11+cs_a*1 + 1)); - ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a*2 + 2)); - ymm3 = _mm256_broadcast_sd((double const *)&ones); + if(transa) + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a*1 + 1)); + ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a*2 + 2)); + ymm3 = _mm256_broadcast_sd((double const *)&ones); + } + else + { + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11+ rs_a*1 + 1)); + ymm2 = _mm256_broadcast_sd((double const *)(a11+ rs_a*2 + 2)); + ymm3 = _mm256_broadcast_sd((double const *)&ones); + } ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); @@ -19020,62 +5834,28 @@ BLIS_INLINE err_t bli_dtrsm_small_XAlB } _mm256_storeu_pd((double *)(d11_pack), ymm4); - for(i = (m-D_MR); (i+1) > 0; i -= D_MR) //loop along 'M' direction + for(i = (m-d_mr); (i+1) > 0; i -= d_mr) //loop along 'M' direction { a01 = D_A_pack; - a11 = L + (n_remainder - 3)*cs_a + (n_remainder - 3); //pointer to block of A to be used for TRSM + a11 = L + (n_remainder - 3)*cs_a + (n_remainder - 3)*rs_a; //pointer to block of A to be used for TRSM b10 = B + i + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM b11 = B + (i) + (n_remainder - 3)*cs_b; //pointer to block of B to be used for TRSM k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); - #endif - - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b10 + 4)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 
0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) - - a01 += 1; //move to next row - b10 += cs_b; - } + BLIS_DTRSM_SMALL_GEMM_3nx8m(a01,b10,cs_b,p_lda,k_iter) ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + 4)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] + ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + 4)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - ymm4 = _mm256_fmsub_pd(ymm1, ymm15, ymm4); //B11[4-7][0] * alpha-= ymm1 + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + ymm4 = _mm256_fmsub_pd(ymm1, ymm15, ymm4); //B11[4-7][0] * alpha-= ymm1 ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b + 4)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] @@ -19101,12 +5881,12 @@ BLIS_INLINE err_t bli_dtrsm_small_XAlB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); ymm6 = _mm256_fnmadd_pd(ymm1, ymm8, ymm6); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); ymm4 = _mm256_fnmadd_pd(ymm1, ymm8, ymm4); @@ -19118,7 +5898,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAlB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); ymm4 = _mm256_fnmadd_pd(ymm1, ymm6, ymm4); @@ -19134,39 +5914,21 @@ BLIS_INLINE err_t bli_dtrsm_small_XAlB _mm256_storeu_pd((double *)(b11 + cs_b*2 + 4), ymm8); } - dim_t m_remainder = i + D_MR; + dim_t m_remainder = i + d_mr; if(m_remainder >= 4) { a01 = D_A_pack; - a11 = L + (n_remainder - 3)*cs_a + (n_remainder - 3); //pointer to block of A to be used for TRSM + a11 = L + (n_remainder - 3)*cs_a + (n_remainder - 3)*rs_a; //pointer to block of A to be used for TRSM b10 = B + (m_remainder - 4) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM b11 = B + (m_remainder - 4) + (n_remainder - 3)*cs_b; //pointer to block of B to be used for 
TRSM k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - a01 += 1; //move to next row - b10 += cs_b; - } + BLIS_DTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha @@ -19183,17 +5945,16 @@ BLIS_INLINE err_t bli_dtrsm_small_XAlB //extract a22 ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); //extract a11 ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); @@ -19202,7 +5963,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAlB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); @@ -19219,64 +5980,33 @@ BLIS_INLINE err_t bli_dtrsm_small_XAlB if(3 == m_remainder) { a01 = D_A_pack; - a11 = L + (n_remainder - 3)*cs_a + (n_remainder - 3); //pointer to block of A to be used for TRSM + a11 = L + (n_remainder - 3)*cs_a + (n_remainder - 3)*rs_a; //pointer to block of A to be used for TRSM b10 = B + (m_remainder - 3) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM b11 = B + (m_remainder - 3) + (n_remainder - 3)*cs_b; //pointer to block of B to be used for TRSM k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] 
B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + BLIS_DTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2)); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2 + 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 + BLIS_PRE_DTRSM_SMALL_3N_3M(AlphaVal,b11,cs_b) ///implement TRSM/// - //extract a22 ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); //extract a11 ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); @@ -19285,88 +6015,46 @@ BLIS_INLINE err_t bli_dtrsm_small_XAlB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x07); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2)); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2 + 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x07); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - xmm5 = _mm256_extractf128_pd(ymm7, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 2),xmm5); - _mm_storel_pd((b11 + cs_b * 2 + 2), _mm256_extractf128_pd(ymm7, 1)); + BLIS_POST_DTRSM_SMALL_3N_3M(b11,cs_b) m_remainder -=3; } else if(2 == m_remainder) { a01 = D_A_pack; - a11 = L + (n_remainder - 3)*cs_a + (n_remainder - 3); //pointer to block of A to be used for TRSM + a11 = L + (n_remainder - 3)*cs_a + (n_remainder - 3)*rs_a; //pointer to block of A to be used for TRSM b10 = B + (m_remainder - 2) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM b11 = B + (m_remainder - 2) + (n_remainder - 3)*cs_b; //pointer to block of B to be used for TRSM k_iter = (n-n_remainder); //number of 
GEMM operations to be done(in blocks of 4x4) - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + BLIS_DTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 + BLIS_PRE_DTRSM_SMALL_3N_2M(AlphaVal,b11,cs_b) ///implement TRSM/// //extract a22 ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); //extract a11 ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); @@ -19375,85 +6063,46 @@ BLIS_INLINE err_t bli_dtrsm_small_XAlB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x03); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x03); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - xmm5 = _mm256_extractf128_pd(ymm7, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 2),xmm5); + BLIS_POST_DTRSM_SMALL_3N_2M(b11,cs_b) m_remainder -=2; } else if (1 == m_remainder) { a01 = D_A_pack; - a11 = L + (n_remainder - 
3)*cs_a + (n_remainder - 3); //pointer to block of A to be used for TRSM + a11 = L + (n_remainder - 3)*cs_a + (n_remainder - 3)*rs_a; //pointer to block of A to be used for TRSM b10 = B + (m_remainder - 1) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM b11 = B + (m_remainder - 1) + (n_remainder - 3)*cs_b; //pointer to block of B to be used for TRSM k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + BLIS_DTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) - ymm0 = _mm256_broadcast_sd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 + BLIS_PRE_DTRSM_SMALL_3N_1M(AlphaVal,b11,cs_b) ///implement TRSM/// //extract a22 ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); //extract a11 ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); @@ -19462,21 +6111,12 @@ BLIS_INLINE err_t bli_dtrsm_small_XAlB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_broadcast_sd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x01); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x01); - ymm0 = _mm256_broadcast_sd((double const 
*)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x01); - - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm3, 0)); - _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm5, 0)); - _mm_storel_pd((b11 + cs_b * 2), _mm256_extractf128_pd(ymm7, 0)); + BLIS_POST_DTRSM_SMALL_3N_1M(b11,cs_b) m_remainder -=1; } @@ -19485,41 +6125,103 @@ BLIS_INLINE err_t bli_dtrsm_small_XAlB } else if(n_remainder == 2) { - a01 = L + (n_remainder - 2)*cs_a + n_remainder; //pointer to block of A to be used in GEMM - a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2); //pointer to block of A to be used for TRSM + a01 = L + (n_remainder - 2)*rs_a + n_remainder*cs_a; //pointer to block of A to be used in GEMM + a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2)*rs_a; //pointer to block of A to be used for TRSM double *ptr_a10_dup = D_A_pack; dim_t p_lda = (n-n_remainder); // packed leading dimension // perform copy of A to packed buffer D_A_pack - dim_t loop_count = (n-n_remainder)/4; - - for(dim_t x =0;x < loop_count;x++) + if(transa) { - ymm15 = _mm256_loadu_pd((double const *)(a01 + cs_a * 0 + x*4)); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + x*4), ymm15); - ymm15 = _mm256_loadu_pd((double const *)(a01 + cs_a * 1 + x*4)); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + x*4), ymm15); - } + for(dim_t x =0;x < p_lda;x+=d_nr) + { + ymm0 = _mm256_loadu_pd((double const *)(a01)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); + ymm2 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2)); + ymm3 = _mm256_loadu_pd((double const *)(a01 + cs_a * 3)); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + + ymm0 = _mm256_loadu_pd((double const *)(a01 + cs_a * 4)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a * 5)); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_broadcast_sd((double const *)&zero); + + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_broadcast_sd((double const *)&zero); + + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - dim_t remainder_loop_count = p_lda - loop_count*4; + _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_extractf128_pd(ymm6,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_extractf128_pd(ymm7,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm8,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm9,0)); - __m128d xmm0; - if(remainder_loop_count != 0) + a01 += d_nr*cs_a; + ptr_a10_dup += d_nr; + } + } + else { - xmm0 = _mm_loadu_pd((double const *)(a01 + cs_a * 0 + loop_count*4)); - _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + loop_count*4), xmm0); - xmm0 = _mm_loadu_pd((double const *)(a01 + cs_a * 1 + loop_count*4)); - _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + loop_count*4), xmm0); + dim_t 
loop_count = (n-n_remainder)/4; + + for(dim_t x =0;x < loop_count;x++) + { + ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 0 + x*4)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + x*4), ymm15); + ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 1 + x*4)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + x*4), ymm15); + } + + dim_t remainder_loop_count = p_lda - loop_count*4; + + __m128d xmm0; + if(remainder_loop_count != 0) + { + xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 0 + loop_count*4)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + loop_count*4), xmm0); + xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 1 + loop_count*4)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + loop_count*4), xmm0); + } } ymm4 = _mm256_broadcast_sd((double const *)&ones); if(!is_unitdiag) { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_sd((double const *)(a11)); - ymm1 = _mm256_broadcast_sd((double const *)(a11+cs_a*1 + 1)); + if(transa) + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11+cs_a*1 + 1)); + } + else + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11+rs_a*1 + 1)); + } ymm2 = _mm256_broadcast_sd((double const *)&ones); ymm3 = _mm256_broadcast_sd((double const *)&ones); @@ -19536,45 +6238,20 @@ BLIS_INLINE err_t bli_dtrsm_small_XAlB } _mm256_storeu_pd((double *)(d11_pack), ymm4); - for(i = (m-D_MR); (i+1) > 0; i -= D_MR) //loop along 'M' direction + for(i = (m-d_mr); (i+1) > 0; i -= d_mr) //loop along 'M' direction { a01 = D_A_pack; - a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2); //pointer to block of A to be used for TRSM + a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2)*rs_a; //pointer to block of A to be used for TRSM b10 = B + i + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM b11 = B + (i) + (n_remainder - 2)*cs_b; //pointer to block of B to be used for TRSM k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + 4), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*1), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + 4 + cs_b*1), _MM_HINT_T0); - #endif + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b10 + 4)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm6 = 
_mm256_fmadd_pd(ymm2, ymm1, ymm6); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) - - a01 += 1; //move to next row - b10 += cs_b; - } + BLIS_DTRSM_SMALL_GEMM_2nx8m(a01,b10,cs_b,p_lda,k_iter) ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); @@ -19591,8 +6268,10 @@ BLIS_INLINE err_t bli_dtrsm_small_XAlB ymm6 = _mm256_fmsub_pd(ymm1, ymm15, ymm6); //B11[4-7][1] * alpha -= ymm3 ///implement TRSM/// + //extract a11 ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm0); @@ -19600,7 +6279,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAlB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); ymm4 = _mm256_fnmadd_pd(ymm1, ymm6, ymm4); @@ -19614,35 +6293,21 @@ BLIS_INLINE err_t bli_dtrsm_small_XAlB _mm256_storeu_pd((double *)(b11 + cs_b + 4), ymm6); } - dim_t m_remainder = i + D_MR; + dim_t m_remainder = i + d_mr; if(m_remainder >= 4) { a01 = D_A_pack; - a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2); //pointer to block of A to be used for TRSM + a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2)*rs_a; //pointer to block of A to be used for TRSM b10 = B + (m_remainder - 4) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM b11 = B + (m_remainder - 4) + (n_remainder - 2)*cs_b; //pointer to block of B to be used for TRSM k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - a01 += 1; //move to next row - b10 += cs_b; - } + BLIS_DTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha @@ -19653,6 +6318,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAlB ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 ///implement TRSM/// + //extract a11 ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); @@ -19661,7 +6327,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAlB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); @@ -19677,43 +6343,22 @@ BLIS_INLINE err_t bli_dtrsm_small_XAlB if(3 == m_remainder) { a01 = D_A_pack; - a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2); //pointer to block of A to be used for TRSM + a11 = L + (n_remainder - 2)*cs_a + 
(n_remainder - 2)*rs_a; //pointer to block of A to be used for TRSM b10 = B + (m_remainder - 3) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM b11 = B + (m_remainder - 3) + (n_remainder - 2)*cs_b; //pointer to block of B to be used for TRSM k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - a01 += 1; //move to next row - b10 += cs_b; - } + BLIS_DTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1)); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*1 + 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + BLIS_PRE_DTRSM_SMALL_2N_3M(AlphaVal,b11,cs_b) ///implement TRSM/// + //extract a11 ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); @@ -19722,64 +6367,33 @@ BLIS_INLINE err_t bli_dtrsm_small_XAlB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x07); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1)); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*1 + 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x07); - - _mm256_storeu_pd((double *)b11, ymm3); - xmm5 = _mm256_extractf128_pd(ymm5, 0); - _mm_storeu_pd((double *)(b11 + cs_b*1), xmm5); - _mm_storel_pd((b11 + cs_b * 1 + 2), _mm256_extractf128_pd(ymm5, 1)); + BLIS_POST_DTRSM_SMALL_2N_3M(b11,cs_b) m_remainder -=3; } else if(2 == m_remainder) { a01 = D_A_pack; - a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2); //pointer to block of A to be used for TRSM + a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2)*rs_a; //pointer to block of A to be used for TRSM b10 = B + (m_remainder - 2) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM b11 = B + (m_remainder - 2) + (n_remainder - 2)*cs_b; //pointer to block of B to be used for TRSM k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS 
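/*
 * Minimal scalar sketch of the 2-column fringe sequence below
 * (BLIS_DTRSM_SMALL_GEMM_2nx4m, BLIS_PRE_DTRSM_SMALL_2N_*M, the
 * DTRSM_SMALL_DIV_OR_SCALE/fnmadd pair and the POST store), shown only to
 * document what the intrinsics compute. acc0/acc1 stand in for the ymm3/ymm5
 * GEMM accumulators, m_block for the number of rows handled, and the sketch
 * assumes the no-preinversion configuration in which d11_pack holds the
 * diagonal entries and DTRSM_SMALL_DIV_OR_SCALE divides by them.
 */
static inline void dtrsm_small_2n_fringe_model
     (
       double *b11, dim_t cs_b, dim_t m_block,
       const double *a11, dim_t cs_a,
       const double *d11_pack, double alpha,
       const double *acc0, const double *acc1
     )
{
    for(dim_t r = 0; r < m_block; r++)
    {
        /* alpha*B11 minus the GEMM partial sums (the fmsub step) */
        double x0 = alpha * b11[r + 0*cs_b] - acc0[r];
        double x1 = alpha * b11[r + 1*cs_b] - acc1[r];

        x1 = x1 / d11_pack[1];        /* solve the trailing column        */
        x0 -= a11[cs_a] * x1;         /* fnmadd: remove its contribution  */
        x0 = x0 / d11_pack[0];        /* solve the leading column         */

        b11[r + 0*cs_b] = x0;         /* POST: store back to B            */
        b11[r + 1*cs_b] = x1;
    }
}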
///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + BLIS_DTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_PRE_DTRSM_SMALL_2N_2M(AlphaVal,b11,cs_b) ///implement TRSM/// + //extract a11 ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); @@ -19788,61 +6402,33 @@ BLIS_INLINE err_t bli_dtrsm_small_XAlB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x03); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x03); - - _mm256_storeu_pd((double *)b11, ymm3); - xmm5 = _mm256_extractf128_pd(ymm5, 0); - _mm_storeu_pd((double *)(b11 + cs_b*1), xmm5); + BLIS_POST_DTRSM_SMALL_2N_2M(b11,cs_b) m_remainder -=2; } else if (1 == m_remainder) { a01 = D_A_pack; - a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2); //pointer to block of A to be used for TRSM + a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2)*rs_a; //pointer to block of A to be used for TRSM b10 = B + (m_remainder - 1) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM b11 = B + (m_remainder - 1) + (n_remainder - 2)*cs_b; //pointer to block of B to be used for TRSM k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - a01 += 1; 
//move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - ymm0 = _mm256_broadcast_sd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + BLIS_DTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_PRE_DTRSM_SMALL_2N_1M(AlphaVal,b11,cs_b) ///implement TRSM/// + //extract a11 ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); @@ -19851,18 +6437,12 @@ BLIS_INLINE err_t bli_dtrsm_small_XAlB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_broadcast_sd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x01); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x01); - - _mm_storel_pd(b11 , _mm256_extractf128_pd(ymm3, 0)); - _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm5, 0)); + BLIS_POST_DTRSM_SMALL_2N_1M(b11,cs_b) m_remainder -=1; } @@ -19871,29 +6451,82 @@ BLIS_INLINE err_t bli_dtrsm_small_XAlB } else if(n_remainder == 1) { - a01 = L + (n_remainder - 1)*cs_a + n_remainder; //pointer to block of A to be used in GEMM - a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1); //pointer to block of A to be used for TRSM + a01 = L + (n_remainder - 1)*rs_a + n_remainder*cs_a; //pointer to block of A to be used in GEMM + a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1)*rs_a; //pointer to block of A to be used for TRSM double *ptr_a10_dup = D_A_pack; dim_t p_lda = (n-n_remainder); // packed leading dimension // perform copy of A to packed buffer D_A_pack - dim_t loop_count = (n-n_remainder)/4; - - for(dim_t x =0;x < loop_count;x++) + if(transa) { - ymm15 = _mm256_loadu_pd((double const *)(a01 + cs_a * 0 + x*4)); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + x*4), ymm15); - } + for(dim_t x =0;x < p_lda;x+=d_nr) + { + ymm0 = _mm256_loadu_pd((double const *)(a01)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); + ymm2 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2)); + ymm3 = _mm256_loadu_pd((double const *)(a01 + cs_a * 3)); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + + ymm0 = _mm256_loadu_pd((double const *)(a01 + cs_a * 4)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a * 5)); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_broadcast_sd((double const *)&zero); + + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ymm0 = 
_mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_broadcast_sd((double const *)&zero); + + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - dim_t remainder_loop_count = p_lda - loop_count*4; + _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_extractf128_pd(ymm6,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_extractf128_pd(ymm7,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm8,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm9,0)); - __m128d xmm0; - if(remainder_loop_count != 0) + a01 += d_nr*cs_a; + ptr_a10_dup += d_nr; + } + } + else { - xmm0 = _mm_loadu_pd((double const *)(a01 + cs_a * 0 + loop_count*4)); - _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + loop_count*4), xmm0); + dim_t loop_count = (n-n_remainder)/4; + + for(dim_t x =0;x < loop_count;x++) + { + ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 0 + x*4)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + x*4), ymm15); + } + + dim_t remainder_loop_count = p_lda - loop_count*4; + + __m128d xmm0; + if(remainder_loop_count != 0) + { + xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 0 + loop_count*4)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + loop_count*4), xmm0); + } } ymm4 = _mm256_broadcast_sd((double const *)&ones); @@ -19918,41 +6551,20 @@ BLIS_INLINE err_t bli_dtrsm_small_XAlB } _mm256_storeu_pd((double *)(d11_pack), ymm4); - for(i = (m-D_MR); (i+1) > 0; i -= D_MR) //loop along 'M' direction + for(i = (m-d_mr); (i+1) > 0; i -= d_mr) //loop along 'M' direction { a01 = D_A_pack; - a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1); //pointer to block of A to be used for TRSM + a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1)*rs_a; //pointer to block of A to be used for TRSM b10 = B + i + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM b11 = B + (i) + (n_remainder - 1)*cs_b; //pointer to block of B to be used for TRSM k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); - #endif + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b10 + 4)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) - - a01 += 1; //move to next row - b10 += cs_b; - } + BLIS_DTRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter) ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); @@ -19972,30 +6584,20 @@ BLIS_INLINE err_t bli_dtrsm_small_XAlB _mm256_storeu_pd((double *)(b11 + 4), 
ymm4); } - dim_t m_remainder = i + D_MR; + dim_t m_remainder = i + d_mr; if(m_remainder >= 4) { a01 = D_A_pack; - a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1); //pointer to block of A to be used for TRSM + a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1)*rs_a; //pointer to block of A to be used for TRSM b10 = B + (m_remainder - 4) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM b11 = B + (m_remainder - 4) + (n_remainder - 1)*cs_b; //pointer to block of B to be used for TRSM k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) ymm3 = _mm256_setzero_pd(); - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - a01 += 1; //move to next row - b10 += cs_b; - } + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha @@ -20017,7 +6619,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAlB if(3 == m_remainder) { a01 = D_A_pack; - a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1); //pointer to block of A to be used for TRSM + a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1)*rs_a; //pointer to block of A to be used for TRSM b10 = B + (m_remainder - 3) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM b11 = B + (m_remainder - 3) + (n_remainder - 1)*cs_b; //pointer to block of B to be used for TRSM @@ -20026,25 +6628,9 @@ BLIS_INLINE err_t bli_dtrsm_small_XAlB ymm3 = _mm256_setzero_pd(); ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + BLIS_DTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - xmm5 = _mm_loadu_pd((double const*)(b11)); - ymm0 = _mm256_broadcast_sd((double const *)(b11+ 2)); - ymm6 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm3 = _mm256_fmsub_pd(ymm6, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + BLIS_PRE_DTRSM_SMALL_1N_3M(AlphaVal,b11,cs_b) ///implement TRSM/// //extract a00 @@ -20054,16 +6640,14 @@ BLIS_INLINE err_t bli_dtrsm_small_XAlB ymm0 = _mm256_loadu_pd((double const *)b11); ymm3 = _mm256_blend_pd(ymm6, ymm3, 0x07); - xmm5 = _mm256_extractf128_pd(ymm3, 0); - _mm_storeu_pd((double *)(b11), xmm5); - _mm_storel_pd((b11 + 2), _mm256_extractf128_pd(ymm3, 1)); + BLIS_POST_DTRSM_SMALL_1N_3M(b11,cs_b) m_remainder -=3; } else if(2 == m_remainder) { a01 = D_A_pack; - a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1); //pointer to block of A to be used for TRSM + a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1)*rs_a; //pointer to block of A to be used for TRSM b10 = B + (m_remainder - 2) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM b11 
= B + (m_remainder - 2) + (n_remainder - 1)*cs_b; //pointer to block of B to be used for TRSM @@ -20072,76 +6656,41 @@ BLIS_INLINE err_t bli_dtrsm_small_XAlB ymm3 = _mm256_setzero_pd(); ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + BLIS_DTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) - xmm5 = _mm_loadu_pd((double const*)(b11)); - ymm6 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm3 = _mm256_fmsub_pd(ymm6, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + BLIS_PRE_DTRSM_SMALL_1N_2M(AlphaVal,b11,cs_b) ///implement TRSM/// //extract a00 ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm6, ymm3, 0x03); - - xmm5 = _mm256_extractf128_pd(ymm3, 0); - _mm_storeu_pd((double *)(b11), xmm5); + BLIS_POST_DTRSM_SMALL_1N_2M(b11,cs_b) m_remainder -=2; } else if (1 == m_remainder) { a01 = D_A_pack; - a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1); //pointer to block of A to be used for TRSM - b10 = B + (m_remainder - 1) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (m_remainder - 1) + (n_remainder - 1)*cs_b; //pointer to block of B to be used for TRSM + a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1)*rs_a; //pointer to block of A to be used for TRSM + b10 = B + (m_remainder - 1) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 1) + (n_remainder - 1)*cs_b; //pointer to block of B to be used for TRSM k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) ymm3 = _mm256_setzero_pd(); ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + BLIS_DTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - ymm6 = _mm256_broadcast_sd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm6, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + BLIS_PRE_DTRSM_SMALL_1N_1M(AlphaVal,b11,cs_b) ///implement TRSM/// //extract a00 ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm3 = _mm256_blend_pd(ymm6, ymm3, 0x01); - - _mm_storel_pd(b11, _mm256_extractf128_pd(ymm3, 0)); + BLIS_POST_DTRSM_SMALL_1N_1M(b11,cs_b) m_remainder -=1; } @@ -20154,29 +6703,14 @@ BLIS_INLINE err_t bli_dtrsm_small_XAlB bli_membrk_release(&rntm, &local_mem_buf_A_s); } - return BLIS_SUCCESS; } -/*implements TRSM for the case XA = alpha * B - *A is upper 
triangular, non-unit diagonal/unit diagonal, transpose - *dimensions: X:mxn A:nxn B: mxn - * - * <---b11 <---a11 - ***************** * - *b01*b11* * * * * - ^ * * * * * ^ * * - | ***************** | ******* - | * * * * * | * * * - | * * * * * a01* * * -b10 ***************** ************* - * * * * * * * * * - * * * * * * * * * - ***************** ******************* - +/* TRSM for the case AX = alpha * B, Double precision + * A is lower-triangular, transpose, non-unit diagonal + * dimensions A: mxm X: mxn B: mxn */ - -BLIS_INLINE err_t bli_dtrsm_small_XAutB +BLIS_INLINE err_t bli_dtrsm_small_AltXB_AuXB ( obj_t* AlphaObj, obj_t* a, @@ -20185,29 +6719,53 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB cntl_t* cntl ) { - dim_t m = bli_obj_length(b); //number of rows - dim_t n = bli_obj_width(b); //number of columns + dim_t m = bli_obj_length(b); // number of rows of matrix B + dim_t n = bli_obj_width(b); // number of columns of matrix B - dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A - dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B + bool transa = bli_obj_has_trans(a); + dim_t cs_a, rs_a; + dim_t d_mr = 8,d_nr = 6; - dim_t i, j, k; //loop variablse - dim_t k_iter; //determines the number of GEMM operations to be done + // Swap rs_a & cs_a in case of non-tranpose. + if(transa) + { + cs_a = bli_obj_col_stride(a); // column stride of A + rs_a = bli_obj_row_stride(a); // row stride of A + } + else + { + cs_a = bli_obj_row_stride(a); // row stride of A + rs_a = bli_obj_col_stride(a); // column stride of A + } + dim_t cs_b = bli_obj_col_stride(b); // column stride of B + + dim_t i, j, k; //loop variables + dim_t k_iter; //number of times GEMM to be performed + + double AlphaVal = *(double *)AlphaObj->buffer; //value of alpha + double *L = a->buffer; //pointer to matrix A + double *B = b->buffer; //pointer to matrix B + + //pointers that point to blocks for GEMM and TRSM + double *a10, *a11, *b01, *b11; double ones = 1.0; - double zero = 0.0; bool is_unitdiag = bli_obj_has_unit_diag(a); - double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha - double* restrict L = a->buffer; //pointer to matrix A - double* restrict B = b->buffer; //pointer to matrix B + //scratch registers + __m256d ymm0, ymm1, ymm2, ymm3; + __m256d ymm4, ymm5, ymm6, ymm7; + __m256d ymm8, ymm9, ymm10, ymm11; + __m256d ymm12, ymm13, ymm14, ymm15; + __m256d ymm16, ymm17, ymm18, ymm19; + __m256d ymm20; - double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks + __m128d xmm5; gint_t required_packing_A = 1; mem_t local_mem_buf_A_s = {0}; double *D_A_pack = NULL; - double d11_pack[D_MR] __attribute__((aligned(64))); + double d11_pack[d_mr] __attribute__((aligned(64))); rntm_t rntm; bli_rntm_init_from_global( &rntm ); @@ -20215,2909 +6773,3951 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB bli_membrk_rntm_set_membrk( &rntm ); siz_t buffer_size = bli_pool_block_size( - bli_membrk_pool( + bli_membrk_pool( bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), bli_rntm_membrk(&rntm))); - if( (D_NR * n * sizeof(double)) > buffer_size) + if((d_mr * m * sizeof(double)) > buffer_size) return BLIS_NOT_YET_IMPLEMENTED; - - if (required_packing_A == 1) + + if(required_packing_A == 1) { - // Get the buffer from the pool. 
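/*
 * Minimal sketch (illustration only, not a BLIS API) of the stride swap
 * performed in the transa branch above: indexing A with row stride rs_a and
 * column stride cs_a, and exchanging the two strides for the other variant,
 * lets the same loop body read either A or its transpose, which is what
 * allows the AltXB and AuXB cases to share one kernel. op_a_elem() and its
 * arguments are hypothetical names used only for this illustration.
 */
static inline double op_a_elem
     (
       const double *A, dim_t i, dim_t j,
       dim_t rs_a, dim_t cs_a
     )
{
    /* With rs_a/cs_a taken as stored this returns A(i,j);
       with the two strides exchanged by the caller it returns A(j,i),
       i.e. the transposed operand, without a second kernel. */
    return A[i*rs_a + j*cs_a];
}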
- bli_membrk_acquire_m(&rntm, - buffer_size, - BLIS_BITVAL_BUFFER_FOR_A_BLOCK, - &local_mem_buf_A_s); - if(FALSE==bli_mem_is_alloc(&local_mem_buf_A_s)) return BLIS_NULL_POINTER; - D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); - if(NULL==D_A_pack) return BLIS_NULL_POINTER; + // Get the buffer from the pool. + bli_membrk_acquire_m(&rntm, + buffer_size, + BLIS_BITVAL_BUFFER_FOR_A_BLOCK, + &local_mem_buf_A_s); + if(FALSE==bli_mem_is_alloc(&local_mem_buf_A_s)) return BLIS_NULL_POINTER; + D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); + if(NULL==D_A_pack) return BLIS_NULL_POINTER; } - //ymm scratch reginsters - __m256d ymm0, ymm1, ymm2, ymm3; - __m256d ymm4, ymm5, ymm6, ymm7; - __m256d ymm8, ymm9, ymm10, ymm11; - __m256d ymm12, ymm13, ymm14, ymm15; - - __m128d xmm5; - /* - Performs solving TRSM for 6 rows at a time from 0 to n/6 in steps of D_NR - a. Load and pack A (a01 block), the size of packing 6x6 to 6x (n-6) - First there will be no GEMM and no packing of a01 because it is only TRSM - b. Using packed a01 block and b10 block perform GEMM operation - c. Use GEMM outputs, perform TRSM operation using a11, b11 and update B - d. Repeat b for m cols of B in steps of D_MR + Performs solving TRSM for 8 colmns at a time from 0 to m/d_mr in steps of d_mr + a. Load, transpose, Pack A (a10 block), the size of packing 8x6 to 8x (m-d_mr) + First there will be no GEMM and no packing of a10 because it is only TRSM + b. Using packed a10 block and b01 block perform GEMM operation + c. Use GEMM outputs, perform TRSM operaton using a11, b11 and update B + d. Repeat b,c for n rows of B in steps of d_nr */ - - for(j = (n-D_NR); (j+1) > 0; j -= D_NR) //loop along 'N' direction + for(i = (m - d_mr); (i + 1) > 0; i -= d_mr) { - a01 = L + j +(j+D_NR)*cs_a; //pointer to block of A to be used in GEMM - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM + a10 = L + (i*cs_a) + (i + d_mr)*rs_a; //pointer to block of A to be used for GEMM + a11 = L + (i*cs_a) + (i*rs_a); //pointer to block of A to be used for TRSM - //double *ptr_a10_dup = D_A_pack; + // Do transpose for a10 & store in D_A_pack + //ptr_a10_dup = D_A_pack; - dim_t p_lda = (n-j-D_NR); // packed leading dimension - // perform copy of A to packed buffer D_A_pack + dim_t p_lda = d_mr; // packed leading dimension - /* - Pack current A block (a01) into packed buffer memory D_A_pack - a. This a10 block is used in GEMM portion only and this - a01 block size will be increasing by D_NR for every next iteration - until it reaches 6x(n-6) which is the maximum GEMM alone block size in A - b. This packed buffer is reused to calculate all m cols of B matrix - */ - bli_dtrsm_small_pack('R', p_lda, 1, a01, cs_a, D_A_pack, p_lda); + if(transa) + { + /* + Load, transpose and pack current A block (a10) into packed buffer memory D_A_pack + a. This a10 block is used in GEMM portion only and this + a10 block size will be increasing by d_mr for every next itteration + untill it reaches 8x(m-8) which is the maximum GEMM alone block size in A + b. This packed buffer is reused to calculate all n rows of B matrix + */ + bli_dtrsm_small_pack('L', (m-i-d_mr), 1, a10, cs_a, D_A_pack,p_lda,d_mr); - /* - Pack 6 diagonal elements of A block into an array - a. This helps in utilze cache line efficiently in TRSM operation - b. store ones when input is unit diagonal - */ - dtrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,D_NR); + /* + Pack 8 diagonal elements of A block into an array + a. This helps in utilze cache line efficiently in TRSM operation + b. 
store ones when input is unit diagonal + */ + dtrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,d_mr); + } + else + { + bli_dtrsm_small_pack('L', (m-i-d_mr), 0, a10, rs_a, D_A_pack,p_lda,d_mr); + dtrsm_small_pack_diag_element(is_unitdiag,a11,rs_a,d11_pack,d_mr); + } /* - a. Perform GEMM using a01, b10. - b. Perform TRSM on a11, b11 - c. This loop GEMM+TRSM loops operates with 8x6 block size - along m dimension for every D_MR columns of B10 where - packed A buffer is reused in computing all m cols of B. - d. Same approach is used in remaining fringe cases. + a. Perform GEMM using a10, b01. + b. Perform TRSM on a11, b11 + c. This loop GEMM+TRSM loops operates with 8x6 block size + along n dimension for every d_nr rows of b01 where + packed A buffer is reused in computing all n rows of B. + d. Same approch is used in remaining fringe cases. */ - for(i = (m-D_MR); (i+1) > 0; i -= D_MR) //loop along 'M' direction + for(j = (n - d_nr); (j + 1) > 0; j -= d_nr) { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (i) + (j)*cs_b; //pointer to block of B to be used for TRSM - - k_iter = (n-j-D_NR); //number of GEMM operations to be done(in blocks of 4x4) + a10 = D_A_pack; + b01 = B + (j * cs_b) + i + d_mr; //pointer to block of B to be used for GEMM + b11 = B + (j * cs_b) + i; //pointer to block of B to be used for TRSM - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); - #endif + k_iter = (m - i - d_mr); /*Fill zeros into ymm registers used in gemm accumulations */ BLIS_SET_YMM_REG_ZEROS /* - Peform GEMM between a01 and b10 blocks - For first itteration there will be no GEMM operation - where k_iter are zero + Peform GEMM between a10 and b01 blocks + For first itteration there will be no GEMM operation + where k_iter are zero */ - - BLIS_DTRSM_SMALL_GEMM_6x8(a01,b10,cs_b,p_lda,k_iter) + BLIS_DTRSM_SMALL_GEMM_8mx6n(a10,b01,cs_b,p_lda,k_iter) /* - Load b11 of size 8x6 and multiply with alpha - Add the GEMM output to b11 - and peform TRSM operation. + Load b11 of size 6x8 and multiply with alpha + Add the GEMM output and perform inregister transose of b11 + to peform TRSM operation. 
*/ - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + 4)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - ymm4 = _mm256_fmsub_pd(ymm1, ymm15, ymm4); //B11[4-7][0] * alpha-= ymm1 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b + 4)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] - - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - ymm6 = _mm256_fmsub_pd(ymm1, ymm15, ymm6); //B11[4-7][1] * alpha -= ymm3 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*2 + 4)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] - - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - ymm8 = _mm256_fmsub_pd(ymm1, ymm15, ymm8); //B11[4-7][2] * alpha -= ymm5 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*3 + 4)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] - - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 - ymm10 = _mm256_fmsub_pd(ymm1, ymm15, ymm10); //B11[4-7][3] * alpha -= ymm7 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*4 + 4)); - - ymm11 = _mm256_fmsub_pd(ymm0, ymm15, ymm11); - ymm12 = _mm256_fmsub_pd(ymm1, ymm15, ymm12); - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*5 + 4)); - - ymm13 = _mm256_fmsub_pd(ymm0, ymm15, ymm13); - ymm14 = _mm256_fmsub_pd(ymm1, ymm15, ymm14); - - ///implement TRSM/// + BLIS_DTRSM_SMALL_NREG_TRANSPOSE_6x8(b11,cs_b,AlphaVal) /* - Compute 6x8 TRSM block by using GEMM block output in register - a. The 6x8 input (gemm outputs) are stored in combinations of ymm registers - 1. ymm3, ymm4 2. ymm5, ymm6 3. ymm7, ymm8, 4. ymm9, ymm10 - 5. ymm11, ymm12 6. ymm13,ymm14 - b. Towards the end TRSM output will be stored back into b11 + Compute 8x6 TRSM block by using GEMM block output in register + a. The 8x6 input (gemm outputs) are stored in combinations of ymm registers + 1. ymm15, ymm20 2. ymm14, ymm19 3. ymm13, ymm18 , 4. ymm12, ymm17 + 5. ymm11, ymm7 6. ymm10, ymm6, 7.ymm9, ymm5 8. ymm8, ymm4 + where ymm15-ymm8 holds 8x4 data and reaming 8x2 will be hold by + other registers + b. 
Towards the end do in regiser transpose of TRSM output and store in b11 */ + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); - //extract a55 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); - - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm0); - - //extract a44 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); - - //(row 5):FMA operations - //ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 4)); - - ymm11 = _mm256_fnmadd_pd(ymm1, ymm13, ymm11); - ymm12 = _mm256_fnmadd_pd(ymm1, ymm14, ymm12); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 3)); - - ymm9 = _mm256_fnmadd_pd(ymm1, ymm13, ymm9); - ymm10 = _mm256_fnmadd_pd(ymm1, ymm14, ymm10); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 2)); - - ymm7 = _mm256_fnmadd_pd(ymm1, ymm13, ymm7); - ymm8 = _mm256_fnmadd_pd(ymm1, ymm14, ymm8); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 1)); - - ymm5 = _mm256_fnmadd_pd(ymm1, ymm13, ymm5); - ymm6 = _mm256_fnmadd_pd(ymm1, ymm14, ymm6); + //perform mul operation + ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); + ymm20 = DTRSM_SMALL_DIV_OR_SCALE(ymm20, ymm1); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a)); + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); + + //(ROw7): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6*cs_a + 7*rs_a)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm15, ymm14); + ymm19 = _mm256_fnmadd_pd(ymm2, ymm20, ymm19); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 7*rs_a)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm15, ymm13); + ymm18 = _mm256_fnmadd_pd(ymm2, ymm20, ymm18); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 7*rs_a)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm15, ymm12); + ymm17 = _mm256_fnmadd_pd(ymm2, ymm20, ymm17); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 7*rs_a)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm15, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, ymm20, ymm7); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 7*rs_a)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm15, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, ymm20, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 7*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm15, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm20, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7*rs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm15, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm20, ymm4); + + //perform mul operation + ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); + ymm19 = DTRSM_SMALL_DIV_OR_SCALE(ymm19, ymm1); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm13, ymm3); - ymm4 = _mm256_fnmadd_pd(ymm1, ymm14, ymm4); + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + + //(ROw6): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 6*rs_a)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm14, ymm13); + ymm18 = _mm256_fnmadd_pd(ymm2, ymm19, ymm18); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 6*rs_a)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm14, ymm12); + ymm17 = _mm256_fnmadd_pd(ymm2, ymm19, ymm17); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 6*rs_a)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm14, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, ymm19, ymm7); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 6*rs_a)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm14, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, 
ymm19, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 6*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm14, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm19, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6*rs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm14, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm19, ymm4); + + //perform mul operation + ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1); + ymm18 = DTRSM_SMALL_DIV_OR_SCALE(ymm18, ymm1); - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); - ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm0); + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + + //(ROw5): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 5*rs_a)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm13, ymm12); + ymm17 = _mm256_fnmadd_pd(ymm2, ymm18, ymm17); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 5*rs_a)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm13, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, ymm18, ymm7); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 5*rs_a)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm13, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, ymm18, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 5*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm13, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm18, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm13, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm18, ymm4); + + //perform mul operation + ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); + ymm17 = DTRSM_SMALL_DIV_OR_SCALE(ymm17, ymm1); //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //(row 4):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 3)); - - ymm9 = _mm256_fnmadd_pd(ymm1, ymm11, ymm9); - ymm10 = _mm256_fnmadd_pd(ymm1, ymm12, ymm10); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 2)); - - ymm7 = _mm256_fnmadd_pd(ymm1, ymm11, ymm7); - ymm8 = _mm256_fnmadd_pd(ymm1, ymm12, ymm8); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 1)); - - ymm5 = _mm256_fnmadd_pd(ymm1, ymm11, ymm5); - ymm6 = _mm256_fnmadd_pd(ymm1, ymm12, ymm6); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a)); - - ymm3 = _mm256_fnmadd_pd(ymm1, ymm11, ymm3); - ymm4 = _mm256_fnmadd_pd(ymm1, ymm12, ymm4); - - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm0); + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(ROw4): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 4*rs_a)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm12, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, ymm17, ymm7); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 4*rs_a)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm12, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, ymm17, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 4*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm12, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm17, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm12, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm17, ymm4); + + //perform mul operation + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm1); //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2)); - - ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); - ymm8 = 
_mm256_fnmadd_pd(ymm1, ymm10, ymm8); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1)); - - ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); - ymm6 = _mm256_fnmadd_pd(ymm1, ymm10, ymm6); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); - - ymm3 = _mm256_fnmadd_pd(ymm1, ymm9, ymm3); - ymm4 = _mm256_fnmadd_pd(ymm1, ymm10, ymm4); - - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm0); + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw3): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3*rs_a)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm11, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, ymm7, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm11, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm7, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm11, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm7, ymm4); + + //perform mul operation + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm1); //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1)); - - ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); - ymm6 = _mm256_fnmadd_pd(ymm1, ymm8, ymm6); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); - ymm4 = _mm256_fnmadd_pd(ymm1, ymm8, ymm4); + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm10, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm6, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm10, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm6, ymm4); - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm0); + //perform mul operation + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm1); //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); - - //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); - ymm4 = _mm256_fnmadd_pd(ymm1, ymm6, ymm4); + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm9, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm5, ymm4); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); + //perform mul operation + ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); + ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm1); - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + 4), ymm4); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b + 4), ymm6); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*2 + 4), ymm8); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - _mm256_storeu_pd((double *)(b11 + cs_b*3 + 4), ymm10); - _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); - _mm256_storeu_pd((double *)(b11 + cs_b*4 + 4), ymm12); - _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); - _mm256_storeu_pd((double *)(b11 + cs_b*5 + 4), ymm14); + 
BLIS_DTRSM_SMALL_NREG_TRANSPOSE_8x6_AND_STORE(b11,cs_b) } - dim_t m_remainder = i + D_MR; - if(m_remainder >= 4) + dim_t n_remainder = j + d_nr; + if(n_remainder >= 4) { - a01 = D_A_pack; - a11 = L + (j*cs_a) + j; - b10 = B + (m_remainder - 4) + (j+D_NR)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (m_remainder - 4) + (j*cs_b); - - k_iter = (n-j-D_NR); //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + a10 = D_A_pack; + a11 = L + (i*cs_a) + (i*rs_a); + b01 = B + ((n_remainder - 4)* cs_b) + i + d_mr; + b11 = B + ((n_remainder - 4)* cs_b) + i; - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + k_iter = (m - i - d_mr); - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_8mx4n(a10,b01,cs_b,p_lda,k_iter) - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 4)); //A01[0][4] - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 5)); //A01[0][5] - ymm13 = _mm256_fmadd_pd(ymm2, ymm0, ymm13); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] + ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 4)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] + ymm7 = _mm256_loadu_pd((double const *)(b11 + cs_b *3 + 4)); //B11[0][7] B11[1][7] B11[2][7] B11[3][7] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] * alpha -= B01[0-3][3] + 
ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] + ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] * alpha -= B01[0-3][5] + ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); //B11[0-3][6] * alpha -= B01[0-3][6] + ymm7 = _mm256_fmsub_pd(ymm7, ymm16, ymm15); //B11[0-3][7] * alpha -= B01[0-3][7] - a01 += 1; //move to next row - b10 += cs_b; - } + ///implement TRSM/// - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[2][4] B11[2][5] + ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[2][6] B11[2][7] - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 + ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3] + ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3] - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); - ymm11 = _mm256_fmsub_pd(ymm0, ymm15, ymm11); + ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[1][4] B11[1][5] B11[3][4] B11[3][5] + ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[1][6] B11[1][7] B11[3][6] B11[3][7] - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); - ymm13 = _mm256_fmsub_pd(ymm0, ymm15, ymm13); + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - ///implement TRSM/// + ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3] + ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3] - //extract a55 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); - //extract a44 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + //perform mul operation + ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); - //(row 5):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 4)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm13, ymm11); + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); + + ymm2 = _mm256_broadcast_sd((double const 
*)(a11 + 6*cs_a + 7*rs_a)); + ymm3 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 7*rs_a)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 7*rs_a)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 7*rs_a)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 7*rs_a)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 7*rs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 7*rs_a)); + + //(ROw7): FMA operations + ymm14 = _mm256_fnmadd_pd(ymm2, ymm15, ymm14); + ymm13 = _mm256_fnmadd_pd(ymm3, ymm15, ymm13); + ymm12 = _mm256_fnmadd_pd(ymm4, ymm15, ymm12); + ymm11 = _mm256_fnmadd_pd(ymm5, ymm15, ymm11); + ymm10 = _mm256_fnmadd_pd(ymm6, ymm15, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm15, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm15, ymm8); + + //perform mul operation + ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 3)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm13, ymm9); + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + + ymm3 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 6*rs_a)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 6*rs_a)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 6*rs_a)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 6*rs_a)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 6*rs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 6*rs_a)); + + //(ROw6): FMA operations + ymm13 = _mm256_fnmadd_pd(ymm3, ymm14, ymm13); + ymm12 = _mm256_fnmadd_pd(ymm4, ymm14, ymm12); + ymm11 = _mm256_fnmadd_pd(ymm5, ymm14, ymm11); + ymm10 = _mm256_fnmadd_pd(ymm6, ymm14, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm14, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm14, ymm8); + + //perform mul operation + ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 2)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm13, ymm7); + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 1)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm13, ymm5); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 5*rs_a)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 5*rs_a)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 5*rs_a)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 5*rs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm13, ymm3); + //(ROw5): FMA operations + ymm12 = _mm256_fnmadd_pd(ymm4, ymm13, ymm12); + ymm11 = _mm256_fnmadd_pd(ymm5, ymm13, ymm11); + ymm10 = _mm256_fnmadd_pd(ymm6, ymm13, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm13, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm13, ymm8); - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); + //perform mul operation + ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //(row 4):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 3)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm11, ymm9); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 2)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm11, ymm7); + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 1)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm11, ymm5); + ymm5 = 
_mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 4*rs_a)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 4*rs_a)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 4*rs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm11, ymm3); + //(ROw4): FMA operations + ymm11 = _mm256_fnmadd_pd(ymm5, ymm12, ymm11); + ymm10 = _mm256_fnmadd_pd(ymm6, ymm12, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm12, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm12, ymm8); - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + //perform mul operation + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3*rs_a)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3*rs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm9, ymm3); + //(ROw3): FMA operations + ymm10 = _mm256_fnmadd_pd(ymm6, ymm11, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm11, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm11, ymm8); - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + //perform mul operation + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2*rs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); + //(ROw2): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm7, ymm10, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm10, ymm8); - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + //perform mul operation + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); - //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + //(ROw2): FMA operations + ymm8 = _mm256_fnmadd_pd(ymm16, ymm9, ymm8); - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); - _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); + //perform mul operation + ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); - m_remainder -=4; - } + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - if(m_remainder) - { - if(3 == m_remainder) - { - a01 = D_A_pack; - 
a11 = L + (j*cs_a) + j; - b10 = B + (j+D_NR)*cs_b + (m_remainder - 3); //pointer to block of B to be used in GEMM - b11 = B + (m_remainder - 3) + (j*cs_b); + ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] + ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] - k_iter = (n-j-D_NR); //number of GEMM operations to be done(in blocks of 4x4) + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); + ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] + ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[4][1] B11[5][1] B11[4][3] B11[5][3] + ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[6][1] B11[7][1] B11[6][3] B11[7][3] - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] + ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); //store B11[4][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); //store B11[5][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm6); //store B11[6][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3 + 4), ymm7); //store B11[7][0-3] + n_remainder -=4; + } - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 4)); 
//A01[0][4] - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + if(n_remainder) //implementation fo remaining columns(when 'N' is not a multiple of d_nr)() n = 3 + { + a10 = D_A_pack; + a11 = L + (i*cs_a) + (i*rs_a); + b01 = B + i + d_mr; + b11 = B + i; - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 5)); //A01[0][5] - ymm13 = _mm256_fmadd_pd(ymm2, ymm0, ymm13); + k_iter = (m - i - d_mr) ; - a01 += 1; //move to next row - b10 += cs_b; - } + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + if(3 == n_remainder) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_8mx3n(a10,b01,cs_b,p_lda,k_iter) - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] + ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 4)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); - ymm11 = _mm256_fmsub_pd(ymm0, ymm15, ymm11); + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] + ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] * alpha -= B01[0-3][5] + ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); //B11[0-3][6] * alpha -= B01[0-3][6] + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + } + else if(2 == n_remainder) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_8mx2n(a10,b01,cs_b,p_lda,k_iter) - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); - ymm13 = _mm256_fmsub_pd(ymm0, ymm15, ymm13); + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - ///implement TRSM/// + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - //extract a55 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); 
//B11[0][4] B11[1][4] B11[2][4] B11[3][4] + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] - //extract a44 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); - //(row 5):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 4)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm13, ymm11); + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] + ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] * alpha -= B01[0-3][5] + ymm6 = _mm256_broadcast_sd((double const *)(&ones)); + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + } + else if(1 == n_remainder) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_8mx1n(a10,b01,cs_b,p_lda,k_iter) - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 3)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm13, ymm9); + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 2)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm13, ymm7); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 1)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm13, ymm5); + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm13, ymm3); + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_broadcast_sd((double const *)(&ones)); + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] + ymm5 = _mm256_broadcast_sd((double const *)(&ones)); + ymm6 = _mm256_broadcast_sd((double const *)(&ones)); + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + } + ///implement TRSM/// - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - //(row 4):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 3)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm11, ymm9); + ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[2][4] B11[2][5] + ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[2][6] B11[2][7] - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 2)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm11, ymm7); + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 1)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm11, ymm5); + ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3] + ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] 
B11[6][3] - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm11, ymm3); + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[1][4] B11[1][5] B11[3][4] B11[3][5] + ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[1][6] B11[1][7] B11[3][6] B11[3][7] - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); + ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3] + ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3] - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm9, ymm3); + //perform mul operation + ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); + + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6*cs_a + 7*rs_a)); + ymm3 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 7*rs_a)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 7*rs_a)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 7*rs_a)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 7*rs_a)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 7*rs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 7*rs_a)); + + //(ROw7): FMA operations + ymm14 = _mm256_fnmadd_pd(ymm2, ymm15, ymm14); + ymm13 = _mm256_fnmadd_pd(ymm3, ymm15, ymm13); + ymm12 = _mm256_fnmadd_pd(ymm4, ymm15, ymm12); + ymm11 = _mm256_fnmadd_pd(ymm5, ymm15, ymm11); + ymm10 = _mm256_fnmadd_pd(ymm6, ymm15, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm15, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm15, ymm8); + + //perform mul operation + ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + + ymm3 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 6*rs_a)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 6*rs_a)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 6*rs_a)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 6*rs_a)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 6*rs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 6*rs_a)); + + //(ROw6): FMA operations + ymm13 = _mm256_fnmadd_pd(ymm3, ymm14, ymm13); + ymm12 = _mm256_fnmadd_pd(ymm4, ymm14, ymm12); + ymm11 = _mm256_fnmadd_pd(ymm5, ymm14, ymm11); + ymm10 = _mm256_fnmadd_pd(ymm6, ymm14, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm14, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm14, ymm8); + + //perform mul operation + ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1); - //(row 2):FMA operations - ymm1 = 
_mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 5*rs_a)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 5*rs_a)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 5*rs_a)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 5*rs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + //(ROw5): FMA operations + ymm12 = _mm256_fnmadd_pd(ymm4, ymm13, ymm12); + ymm11 = _mm256_fnmadd_pd(ymm5, ymm13, ymm11); + ymm10 = _mm256_fnmadd_pd(ymm6, ymm13, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm13, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm13, ymm8); - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); + //perform mul operation + ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); - //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 4*rs_a)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 4*rs_a)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 4*rs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_pd(ymm0, ymm11, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_pd(ymm0, ymm13, 0x07); + //(ROw4): FMA operations + ymm11 = _mm256_fnmadd_pd(ymm5, ymm12, ymm11); + ymm10 = _mm256_fnmadd_pd(ymm6, ymm12, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm12, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm12, ymm8); - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); - _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); + //perform mul operation + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - m_remainder -=3; - } - else if(2 == m_remainder) - { - a01 = D_A_pack; - a11 = L + (j*cs_a) + j; - b10 = B + (j+D_NR)*cs_b + (m_remainder - 2); //pointer to block of B to be used in GEMM - b11 = B + (m_remainder - 2) + (j*cs_b); + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - k_iter = (n-j-D_NR); //number of GEMM operations to be done(in blocks of 4x4) + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3*rs_a)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 
3*rs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); + //(ROw3): FMA operations + ymm10 = _mm256_fnmadd_pd(ymm6, ymm11, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm11, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm11, ymm8); - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + //perform mul operation + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2*rs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + //(ROw2): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm7, ymm10, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm10, ymm8); - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + //perform mul operation + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 4)); //A01[0][4] - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 5)); //A01[0][5] - ymm13 = _mm256_fmadd_pd(ymm2, ymm0, ymm13); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); - a01 += 1; //move to next row - b10 += cs_b; - } + //(ROw2): FMA operations + ymm8 = _mm256_fnmadd_pd(ymm16, ymm9, ymm8); - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + //perform mul operation + ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] + ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 + //rearrange low 
elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 + ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] + ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); - ymm11 = _mm256_fmsub_pd(ymm0, ymm15, ymm11); + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); - ymm13 = _mm256_fmsub_pd(ymm0, ymm15, ymm13); + ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[4][1] B11[5][1] B11[4][3] B11[5][3] + ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[6][1] B11[7][1] B11[6][3] B11[7][3] - ///implement TRSM/// + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - //extract a55 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); + ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] + ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] - //extract a44 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + if(3 == n_remainder) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); //store B11[4][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); //store B11[5][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm6); //store B11[6][0-3] + } + else if(2 == n_remainder) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); //store B11[4][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); //store B11[5][0-3] + } + else if(1 == n_remainder) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); //store B11[4][0-3] + } + } + }// End of multiples of d_mr blocks in m-dimension - //(row 5):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 4)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm13, ymm11); + // Repetative A blocks will be 4*4 + dim_t m_remainder = i + d_mr; + if(m_remainder >= 4) + { + i = m_remainder - 4; + a10 = L + (i*cs_a) + (i + 4)*rs_a; //pointer to block of A to be used for GEMM + a11 = L + (i*cs_a) + (i*rs_a); //pointer to block of A to be used for TRSM - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 3)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm13, ymm9); + // Do transpose for a10 & store in D_A_pack + double *ptr_a10_dup = D_A_pack; + dim_t p_lda = 4; // packed leading dimension + if(transa) + { + for(dim_t x =0;x < m-i+4;x+=p_lda) + { + ymm0 = 
_mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + cs_a)); + ymm2 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); + ymm3 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 2)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm13, ymm7); + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 1)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm13, ymm5); + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm13, ymm3); + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); - //(row 4):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 3)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm11, ymm9); + a10 += p_lda; + ptr_a10_dup += p_lda*p_lda; + } + } + else + { + for(dim_t x =0;x < m-i-4;x++) + { + ymm0 = _mm256_loadu_pd((double const *)(a10 + x*rs_a)); + _mm256_storeu_pd((double *)(ptr_a10_dup + x*p_lda), ymm0); + } + } - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 2)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm11, ymm7); + ymm4 = _mm256_broadcast_sd((double const *)&ones); + if(!is_unitdiag) + { + if(transa) + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11+cs_a*1 + 1)); + ymm2 = _mm256_broadcast_sd((double const *)(a11+cs_a*2 + 2)); + ymm3 = _mm256_broadcast_sd((double const *)(a11+cs_a*3 + 3)); + } + else + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11+rs_a*1 + 1)); + ymm2 = _mm256_broadcast_sd((double const *)(a11+rs_a*2 + 2)); + ymm3 = _mm256_broadcast_sd((double const *)(a11+rs_a*3 + 3)); + } - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 1)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm11, ymm5); + ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); + #ifdef BLIS_DISABLE_TRSM_PREINVERSION + ymm4 = ymm1; + #endif + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + ymm4 = _mm256_div_pd(ymm4, ymm1); + #endif + } + _mm256_storeu_pd((double *)(d11_pack), ymm4); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm11, ymm3); + //cols + for(j = (n - d_nr); (j + 1) > 0; j -= d_nr) //loop along 'N' dimension + { + a10 = D_A_pack; + a11 = L + (i*cs_a) + (i*rs_a); //pointer to block of A to be used for TRSM + b01 = B + (j*cs_b) + i + 4; //pointer to block of B to be used for GEMM + b11 = B + (j* cs_b) + i; //pointer to block of B to be used for TRSM - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + k_iter = (m - i - 4); //number of times GEMM to be performed(in blocks of 4x4) - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + /*Fill zeros into ymm registers used in gemm 
accumulations */ + BLIS_SET_YMM_REG_ZEROS - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx6n(a10,b01,cs_b,p_lda,k_iter) - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm9, ymm3); + ///implement TRSM/// - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); + + ymm16 = _mm256_broadcast_sd((double const *)(&ones)); - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + ////unpacklow//// + ymm7 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); + //rearrange low elements + ymm4 = _mm256_permute2f128_pd(ymm7,ymm16,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm6 = _mm256_permute2f128_pd(ymm7,ymm16,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + //rearrange high elements + ymm5 = _mm256_permute2f128_pd(ymm0,ymm16,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm7 = _mm256_permute2f128_pd(ymm0,ymm16,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); - //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + //perform mul operation + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, 
ymm1); + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm1); - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_pd(ymm0, ymm11, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_pd(ymm0, ymm13, 0x03); + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw3): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3*rs_a)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm11, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, ymm7, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm11, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm7, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm11, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm7, ymm4); + + //perform mul operation + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm1); - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); - _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - m_remainder -=2; - } - else if (1 == m_remainder) - { - a01 = D_A_pack; - a11 = L + (j*cs_a) + j; - b10 = B + (j+D_NR)*cs_b + (m_remainder - 1); //pointer to block of B to be used in GEMM - b11 = B + (m_remainder - 1) + (j*cs_b); + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm10, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm6, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm10, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm6, ymm4); - k_iter = (n-j-D_NR); //number of GEMM operations to be done(in blocks of 4x4) + //perform mul operation + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm1); - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm9, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm5, ymm4); - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += 
(B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + //perform mul operation + ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); + ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm1); - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 4)); //A01[0][4] - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 5)); //A01[0][5] - ymm13 = _mm256_fmadd_pd(ymm2, ymm0, ymm13); + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] - a01 += 1; //move to next row - b10 += cs_b; - } + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + ///unpack high/// + ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 + _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 
5), ymm1); //store B11[1][0-3] + } + dim_t n_remainder = j + d_nr; + if((n_remainder >= 4)) + { + a10 = D_A_pack; + a11 = L + (i*cs_a) + (i*rs_a); //pointer to block of A to be used for TRSM + b01 = B + ((n_remainder - 4)* cs_b) + i + 4; //pointer to block of B to be used for GEMM + b11 = B + ((n_remainder - 4)* cs_b) + i; //pointer to block of B to be used for TRSM - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 + k_iter = (m - i - 4); //number of times GEMM to be performed(in blocks of 4x4) - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); - ymm11 = _mm256_fmsub_pd(ymm0, ymm15, ymm11); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); - ymm13 = _mm256_fmsub_pd(ymm0, ymm15, ymm13); + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx4n(a10,b01,cs_b,p_lda,k_iter) - ///implement TRSM/// + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - //extract a55 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); + ///implement TRSM/// - //extract a44 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - //(row 5):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 4)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm13, ymm11); + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 3)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm13, ymm9); + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 2)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm13, ymm7); + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 1)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm13, ymm5); + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm13, ymm3); + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); + //perform mul operation + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - //(row 4):FMA operations - 
ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 3)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm11, ymm9); + //(ROw3): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3*rs_a)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm11, ymm10); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm11, ymm9); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm11, ymm8); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 2)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm11, ymm7); + //perform mul operation + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 1)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm11, ymm5); + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm11, ymm3); + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm10, ymm9); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm10, ymm8); - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + //perform mul operation + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm9, ymm8); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); + //perform mul operation + ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm9, ymm3); + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] + n_remainder = n_remainder - 4; + } - 
ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + if(n_remainder) //implementation fo remaining columns(when 'N' is not a multiple of d_nr)() n = 3 + { + a10 = D_A_pack; + a11 = L + (i*cs_a) + (i*rs_a); + b01 = B + i + 4; + b11 = B + i; - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); + k_iter = (m - i - 4); - //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + if(3 == n_remainder) + { + BLIS_DTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_pd(ymm0, ymm11, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_pd(ymm0, ymm13, 0x01); + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); - _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - m_remainder -=1; + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); } - } - } - - dim_t n_remainder = j + D_NR; + else if(2 == n_remainder) + { + BLIS_DTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b,p_lda,k_iter) - /* - Reminder cases starts here: - a. Similar logic and code flow used in computing full block (6x8) - above holds for reminder cases too. 
- */ + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - if(n_remainder >= 4) - { - a01 = L + (n_remainder - 4) + n_remainder*cs_a; //pointer to block of A to be used in GEMM - a11 = L + (n_remainder - 4)*cs_a + (n_remainder - 4); //pointer to block of A to be used for TRSM + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - double *ptr_a10_dup = D_A_pack; + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + } + else if(1 == n_remainder) + { + BLIS_DTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b,p_lda,k_iter) - dim_t p_lda = (n-n_remainder); // packed leading dimension - // perform copy of A to packed buffer D_A_pack + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - for(dim_t x =0;x < p_lda;x+=D_NR) - { - ymm0 = _mm256_loadu_pd((double const *)(a01)); - ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); - ymm2 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2)); - ymm3 = _mm256_loadu_pd((double const *)(a01 + cs_a * 3)); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_broadcast_sd((double const *)(&ones)); + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + } - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + ///implement TRSM/// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] - _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - ymm0 = _mm256_loadu_pd((double const *)(a01 + cs_a * 4)); - ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a * 5)); + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_broadcast_sd((double const *)&zero); + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + 
//perform mul operation + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_broadcast_sd((double const *)&zero); + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3*rs_a)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3*rs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_extractf128_pd(ymm6,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_extractf128_pd(ymm7,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm8,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm9,0)); + //(ROw3): FMA operations + ymm10 = _mm256_fnmadd_pd(ymm6, ymm11, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm11, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm11, ymm8); - a01 += D_NR*cs_a; - ptr_a10_dup += D_NR; - } + //perform mul operation + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); - ymm4 = _mm256_broadcast_sd((double const *)&ones); - if(!is_unitdiag) - { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_sd((double const *)(a11)); - ymm1 = _mm256_broadcast_sd((double const *)(a11+cs_a*1 + 1)); - ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a*2 + 2)); - ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a*3 + 3)); + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2*rs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); - ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); - #ifdef BLIS_DISABLE_TRSM_PREINVERSION - ymm4 = ymm1; - #endif - #ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm4 = _mm256_div_pd(ymm4, ymm1); - #endif - } - _mm256_storeu_pd((double *)(d11_pack), ymm4); + //(ROw2): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm7, ymm10, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm10, ymm8); - for(i = (m-D_MR); (i+1) > 0; i -= D_MR) //loop along 'M' direction - { - a01 = D_A_pack; - a11 = L + (n_remainder - 4)*cs_a + (n_remainder - 4); //pointer to block of A to be used for TRSM - b10 = B + i + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (i) + (n_remainder - 4)*cs_b; //pointer to block of B to be used for TRSM + //perform mul operation + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); - k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); - #endif + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); + //(ROw2): FMA operations + ymm8 = _mm256_fnmadd_pd(ymm16, ymm9, ymm8); - ///GEMM implementation starts/// - 
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b10 + 4)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + //perform mul operation + ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - a01 += 1; //move to next row - b10 += cs_b; + if(3 == n_remainder) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + } + else if(2 == n_remainder) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + } + else if(1 == n_remainder) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] } + } + m_remainder -= 4; + } - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); + a10 = L + m_remainder*rs_a; - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + 4)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] + // Do transpose for a10 & store in D_A_pack + double *ptr_a10_dup = D_A_pack; + 
if(3 == m_remainder) // Repetative A blocks will be 3*3 + { + dim_t p_lda = 4; // packed leading dimension + if(transa) + { + for(dim_t x =0;x < m-m_remainder;x+=p_lda) + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + cs_a)); + ymm2 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); + ymm3 = _mm256_broadcast_sd((double const *)&ones); - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - ymm4 = _mm256_fmsub_pd(ymm1, ymm15, ymm4); //B11[4-7][0] * alpha-= ymm1 + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b + 4)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - ymm6 = _mm256_fmsub_pd(ymm1, ymm15, ymm6); //B11[4-7][1] * alpha -= ymm3 + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*2 + 4)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - ymm8 = _mm256_fmsub_pd(ymm1, ymm15, ymm8); //B11[4-7][2] * alpha -= ymm5 + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*3 + 4)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] + a10 += p_lda; + ptr_a10_dup += p_lda*p_lda; + } + } + else + { + for(dim_t x =0;x < m-m_remainder;x++) + { + ymm0 = _mm256_loadu_pd((double const *)(a10 + x*rs_a)); + _mm256_storeu_pd((double *)(ptr_a10_dup + x*p_lda), ymm0); + } + } - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 - ymm10 = _mm256_fmsub_pd(ymm1, ymm15, ymm10); //B11[4-7][3] * alpha -= ymm7 + //cols + for(j = (n - d_nr); (j + 1) > 0; j -= d_nr) //loop along 'N' dimension + { + a10 = D_A_pack; + a11 = L; //pointer to block of A to be used for TRSM + b01 = B + (j* cs_b) + m_remainder; //pointer to block of B to be used for GEMM + b11 = B + (j* cs_b); //pointer to block of B to be used for TRSM - ///implement TRSM/// + k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm0); + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx6n(a10,b01,cs_b,p_lda,k_iter) + + ///GEMM code ends/// + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to store alpha value + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm8 = 
_mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); + ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x08); - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2)); + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) - ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); - ymm8 = _mm256_fnmadd_pd(ymm1, ymm10, ymm8); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1)); + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); - ymm6 = _mm256_fnmadd_pd(ymm1, ymm10, ymm6); + _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store(B11[0-3][3]) + + if(transa) + dtrsm_AltXB_ref(a11, b11, m_remainder, 6, cs_a, cs_b, is_unitdiag); + else + dtrsm_AuXB_ref(a11, b11, m_remainder, 6, rs_a, cs_b, is_unitdiag); + } + + dim_t n_remainder = j + d_nr; + if((n_remainder >= 4)) + { + a10 = D_A_pack; + a11 = L; //pointer to block of A to be used for TRSM + b01 = B + ((n_remainder - 4)* cs_b) + m_remainder; //pointer to block of B to be used for GEMM + b11 = B + ((n_remainder - 4)* cs_b); //pointer to block of B to be used for TRSM - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); + k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) - ymm3 = _mm256_fnmadd_pd(ymm1, ymm9, ymm3); - ymm4 = _mm256_fnmadd_pd(ymm1, ymm10, ymm4); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm0); + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx4n(a10,b01,cs_b,p_lda,k_iter) - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1)); + ///implement TRSM/// - ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); - ymm6 = _mm256_fnmadd_pd(ymm1, ymm8, ymm6); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); + ymm3 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3 + 2)); + ymm3 = _mm256_insertf128_pd(ymm3, xmm5, 0); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); - ymm4 = _mm256_fnmadd_pd(ymm1, ymm8, ymm4); + ymm0 = 
_mm256_blend_pd(ymm8, ymm0, 0x08); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); + ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x08); - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm0); + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + xmm5 = _mm256_extractf128_pd(ymm3, 0); + _mm_storeu_pd((double *)(b11 + cs_b * 3),xmm5); + _mm_storel_pd((b11 + cs_b * 3 + 2), _mm256_extractf128_pd(ymm3, 1)); - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); + if(transa) + dtrsm_AltXB_ref(a11, b11, m_remainder, 4, cs_a, cs_b, is_unitdiag); + else + dtrsm_AuXB_ref(a11, b11, m_remainder, 4, rs_a, cs_b, is_unitdiag); + n_remainder -= 4; + } - //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); + if(n_remainder) + { + a10 = D_A_pack; + a11 = L; //pointer to block of A to be used for TRSM + b01 = B + m_remainder; //pointer to block of B to be used for GEMM + b11 = B; //pointer to block of B to be used for TRSM - ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); - ymm4 = _mm256_fnmadd_pd(ymm1, ymm6, ymm4); + k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + 4), ymm4); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b + 4), ymm6); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*2 + 4), ymm8); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - _mm256_storeu_pd((double *)(b11 + cs_b*3 + 4), ymm10); - } + if(3 == n_remainder) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) - dim_t m_remainder = i + D_MR; - if(m_remainder >= 4) - { - a01 = D_A_pack; - a11 = L + (n_remainder - 4)*cs_a + (n_remainder - 4); //pointer to block of A to be used for TRSM - b10 = B + (m_remainder - 4) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (m_remainder - 4) + (n_remainder - 4)*cs_b; //pointer to block of B to be used for TRSM + BLIS_PRE_DTRSM_SMALL_3M_3N(AlphaVal,b11,cs_b) - k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + if(transa) + dtrsm_AltXB_ref(a11, b11, m_remainder, 3, cs_a, cs_b, is_unitdiag); + else + dtrsm_AuXB_ref(a11, b11, m_remainder, 3, rs_a, cs_b, is_unitdiag); + } + else if(2 == n_remainder) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b,p_lda,k_iter) - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); + BLIS_PRE_DTRSM_SMALL_3M_2N(AlphaVal,b11,cs_b) - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + if(transa) + dtrsm_AltXB_ref(a11, b11, m_remainder, 2, cs_a, cs_b, is_unitdiag); + else + dtrsm_AuXB_ref(a11, b11, m_remainder, 2, rs_a, cs_b, is_unitdiag); + } + else if(1 == n_remainder) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b,p_lda,k_iter) - //broadcast 1st 
row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + BLIS_PRE_DTRSM_SMALL_3M_1N(AlphaVal,b11,cs_b) - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + if(transa) + dtrsm_AltXB_ref(a11, b11, m_remainder, 1, cs_a, cs_b, is_unitdiag); + else + dtrsm_AuXB_ref(a11, b11, m_remainder, 1, rs_a, cs_b, is_unitdiag); + } + } + } + else if(2 == m_remainder) // Repetative A blocks will be 2*2 + { + dim_t p_lda = 4; // packed leading dimension + if(transa) + { + for(dim_t x =0;x < m-m_remainder;x+=p_lda) + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + cs_a)); + ymm2 = _mm256_broadcast_sd((double const *)&ones); + ymm3 = _mm256_broadcast_sd((double const *)&ones); - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - a01 += 1; //move to next row - b10 += cs_b; - } + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + a10 += p_lda; + ptr_a10_dup += p_lda*p_lda; + } + } + else + { + for(dim_t x =0;x < m-m_remainder;x++) + { + ymm0 = _mm256_loadu_pd((double const *)(a10 + x*rs_a)); + _mm256_storeu_pd((double *)(ptr_a10_dup + x*p_lda), ymm0); + } + } + //cols + for(j = (n - d_nr); (j + 1) > 0; j -= d_nr) //loop along 'N' dimension + { + a10 = D_A_pack; + a11 = L; //pointer to block of A to be used for TRSM + b01 = B + (j* cs_b) + m_remainder; //pointer to block of B to be used for GEMM + b11 = B + (j* cs_b); //pointer to block of B to be used for TRSM - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 + k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 + /*Fill zeros into ymm registers used 
in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS - ///implement TRSM/// + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx6n(a10,b01,cs_b,p_lda,k_iter) + + ///GEMM code ends/// + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to store alpha value + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0C); + ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0C); - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm9, ymm3); + _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store(B11[0-3][3]) - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + if(transa) + dtrsm_AltXB_ref(a11, b11, m_remainder, 6, cs_a, cs_b, is_unitdiag); + else + dtrsm_AuXB_ref(a11, b11, m_remainder, 6, rs_a, cs_b, is_unitdiag); + } + dim_t n_remainder = j + d_nr; + if((n_remainder >= 4)) + { + a10 = D_A_pack; + a11 = L; //pointer to block of A to be used for TRSM + b01 = B + ((n_remainder - 4)* cs_b) + m_remainder; //pointer to block of B to be used for GEMM + b11 = B + ((n_remainder - 4)* cs_b); //pointer to block of B to be used for TRSM - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx4n(a10,b01,cs_b,p_lda,k_iter) - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); + ///implement TRSM/// - //(Row 1): FMA operations - ymm1 = 
_mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); + ymm3 = _mm256_insertf128_pd(ymm3, xmm5, 0); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0C); + ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0C); - m_remainder -=4; - } + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + xmm5 = _mm256_extractf128_pd(ymm3, 0); + _mm_storeu_pd((double *)(b11 + cs_b * 3), xmm5); - if(m_remainder) - { - if(3 == m_remainder) + if(transa) + dtrsm_AltXB_ref(a11, b11, m_remainder, 4, cs_a, cs_b, is_unitdiag); + else + dtrsm_AuXB_ref(a11, b11, m_remainder, 4, rs_a, cs_b, is_unitdiag); + n_remainder -= 4; + } + if(n_remainder) { - a01 = D_A_pack; - a11 = L + (n_remainder - 4)*cs_a + (n_remainder - 4); //pointer to block of A to be used for TRSM - b10 = B + (m_remainder - 3) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (m_remainder - 3) + (n_remainder - 4)*cs_b; //pointer to block of B to be used for TRSM + a10 = D_A_pack; + a11 = L; //pointer to block of A to be used for TRSM + b01 = B + m_remainder; //pointer to block of B to be used for GEMM + b11 = B; //pointer to block of B to be used for TRSM - k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + if(3 == n_remainder) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + BLIS_PRE_DTRSM_SMALL_2M_3N(AlphaVal,b11,cs_b) + + if(transa) + dtrsm_AltXB_ref(a11, b11, m_remainder, 3, cs_a, cs_b, is_unitdiag); + else + dtrsm_AuXB_ref(a11, b11, m_remainder, 3, rs_a, cs_b, is_unitdiag); + } + else if(2 == n_remainder) { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b,p_lda,k_iter) - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + BLIS_PRE_DTRSM_SMALL_2M_2N(AlphaVal,b11,cs_b) - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] 
B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + if(transa) + dtrsm_AltXB_ref(a11, b11, m_remainder, 2, cs_a, cs_b, is_unitdiag); + else + dtrsm_AuXB_ref(a11, b11, m_remainder, 2, rs_a, cs_b, is_unitdiag); + } + else if(1 == n_remainder) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b,p_lda,k_iter) + + BLIS_PRE_DTRSM_SMALL_2M_1N(AlphaVal,b11,cs_b) + if(transa) + dtrsm_AltXB_ref(a11, b11, m_remainder, 1, cs_a, cs_b, is_unitdiag); + else + dtrsm_AuXB_ref(a11, b11, m_remainder, 1, rs_a, cs_b, is_unitdiag); + } + } - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + } + else if(1 == m_remainder) // Repetative A blocks will be 1*1 + { + dim_t p_lda = 4; // packed leading dimension + if(transa) + { + for(dim_t x =0;x < m-m_remainder;x+=p_lda) + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_broadcast_sd((double const *)&ones); + ymm2 = _mm256_broadcast_sd((double const *)&ones); + ymm3 = _mm256_broadcast_sd((double const *)&ones); - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); - a01 += 1; //move to next row - b10 += cs_b; - } + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 + a10 += p_lda; + ptr_a10_dup += p_lda*p_lda; + } + } + else + { + for(dim_t x =0;x < m-m_remainder;x++) + { + ymm0 = _mm256_loadu_pd((double const *)(a10 + x*rs_a)); + _mm256_storeu_pd((double *)(ptr_a10_dup + x*p_lda), ymm0); + } + } + //cols + for(j = (n - d_nr); (j + 1) > 0; j -= d_nr) //loop along 'N' dimension + { + a10 = D_A_pack; + a11 = L; //pointer to block of A to be used for TRSM + b01 = B + (j* cs_b) + m_remainder; //pointer to block of B to be used for GEMM + b11 = B + (j* cs_b); //pointer to block of B to be used for TRSM - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3 + 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 + k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) - ///implement TRSM/// + /*Fill zeros into ymm registers used in gemm 
accumulations */ + BLIS_SET_YMM_REG_ZEROS - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx6n(a10,b01,cs_b,p_lda,k_iter) + + ///GEMM code ends/// + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to store alpha value + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); + ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0E); - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm9, ymm3); + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store(B11[0-3][3]) - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + if(transa) + dtrsm_AltXB_ref(a11, b11, m_remainder, 6, cs_a, cs_b, is_unitdiag); + else + dtrsm_AuXB_ref(a11, b11, m_remainder, 6, rs_a, cs_b, is_unitdiag); + } + dim_t n_remainder = j + d_nr; + if((n_remainder >= 4)) + { + a10 = D_A_pack; + a11 = L; //pointer to block of A to be used for TRSM + b01 = B + ((n_remainder - 4)* cs_b) + m_remainder; //pointer to block of B to be used for GEMM + b11 = B + ((n_remainder - 4)* cs_b); //pointer to block of B to be used for TRSM - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); + k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx4n(a10,b01,cs_b,p_lda,k_iter) - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm3 = 
_mm256_fnmadd_pd(ymm1, ymm5, ymm3); + ///implement TRSM/// - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x07); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3 + 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x07); + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); + ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0E); - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - xmm5 = _mm256_extractf128_pd(ymm9, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 3),xmm5); - _mm_storel_pd((b11 + cs_b * 3 + 2), _mm256_extractf128_pd(ymm9, 1)); + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) - m_remainder -=3; + if(transa) + dtrsm_AltXB_ref(a11, b11, m_remainder, 4, cs_a, cs_b, is_unitdiag); + else + dtrsm_AuXB_ref(a11, b11, m_remainder, 4, rs_a, cs_b, is_unitdiag); + n_remainder -= 4; } - else if(2 == m_remainder) + if(n_remainder) { - a01 = D_A_pack; - a11 = L + (n_remainder - 4)*cs_a + (n_remainder - 4); //pointer to block of A to be used for TRSM - b10 = B + (m_remainder - 2) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (m_remainder - 2) + (n_remainder - 4)*cs_b; //pointer to block of B to be used for TRSM + a10 = D_A_pack; + a11 = L; //pointer to block of A to be used for TRSM + b01 = B + m_remainder; //pointer to block of B to be used for GEMM + b11 = B; //pointer to block of B to be used for TRSM - k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + if(3 == n_remainder) { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += 
(B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + BLIS_PRE_DTRSM_SMALL_1M_3N(AlphaVal,b11,cs_b) + + if(transa) + dtrsm_AltXB_ref(a11, b11, m_remainder, 3, cs_a, cs_b, is_unitdiag); + else + dtrsm_AuXB_ref(a11, b11, m_remainder, 3, rs_a, cs_b, is_unitdiag); + } + else if(2 == n_remainder) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b,p_lda,k_iter) - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + BLIS_PRE_DTRSM_SMALL_1M_2N(AlphaVal,b11,cs_b) - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + if(transa) + dtrsm_AltXB_ref(a11, b11, m_remainder, 2, cs_a, cs_b, is_unitdiag); + else + dtrsm_AuXB_ref(a11, b11, m_remainder, 2, rs_a, cs_b, is_unitdiag); + } + else if(1 == n_remainder) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b,p_lda,k_iter) - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + BLIS_PRE_DTRSM_SMALL_1M_1N(AlphaVal,b11,cs_b) - a01 += 1; //move to next row - b10 += cs_b; + if(transa) + dtrsm_AltXB_ref(a11, b11, m_remainder, 1, cs_a, cs_b, is_unitdiag); + else + dtrsm_AuXB_ref(a11, b11, m_remainder, 1, rs_a, cs_b, is_unitdiag); } + } + } - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + if ((required_packing_A == 1) && + bli_mem_is_alloc( &local_mem_buf_A_s )) + { + bli_membrk_release(&rntm,&local_mem_buf_A_s); + } + return BLIS_SUCCESS; +} + +/* TRSM for the Left Upper case AX = alpha * B, Double precision + * A is Left side, upper-triangular, transpose, non-unit/unit diagonal + * dimensions A: mxm X: mxn B: mxn + a10 ----> b11---> + *********** ***************** + * * * * *b01*b11* * * + **a10 * * a11 b11 * * * * * + ********* | | ***************** + *a11* * | | * * * * * + * * * | | * * * * * + ****** v v ***************** + * * * * * * * + * * * * * * * + * * ***************** + * + a11---> - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + * TRSM for the case AX = alpha * B, Double precision + * A is Left side, lower-triangular, no-transpose, non-unit/unit diagonal + * dimensions A: mxm X: mxn B: mxn - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + b01---> + * ***************** + ** * * * * * + * * * * * * * + * * *b01* * * * + * * * * * * * +a10 ****** b11 ***************** + | * * * | * * * * * + | * * * | * * * * * + | *a10*a11* | *b11* * * * + v * * * v * * * * * + *********** ***************** + * * * * * * * * * + * * * * * * * * * + * * * * * * * * * + * * * * * * * * * + **************** ***************** + a11---> +*/ +BLIS_INLINE err_t bli_dtrsm_small_AutXB_AlXB +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +) +{ + dim_t m = bli_obj_length(b); // number of rows of matrix B + dim_t n = bli_obj_width(b); // number of columns of matrix B - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] 
B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 + bool transa = bli_obj_has_trans(a); + dim_t cs_a, rs_a; + dim_t d_mr = 8,d_nr = 6; - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 + // Swap rs_a & cs_a in case of non-tranpose. + if(transa) + { + cs_a = bli_obj_col_stride(a); // column stride of A + rs_a = bli_obj_row_stride(a); // row stride of A + } + else + { + cs_a = bli_obj_row_stride(a); // row stride of A + rs_a = bli_obj_col_stride(a); // column stride of A + } + dim_t cs_b = bli_obj_col_stride(b); // column stride of B - ///implement TRSM/// + dim_t i, j, k; //loop variables + dim_t k_iter; //number of times GEMM to be performed - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + double AlphaVal = *(double *)AlphaObj->buffer; //value of alpha + double *L = a->buffer; //pointer to matrix A + double *B = b->buffer; //pointer to matrix B - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + double *a10, *a11, *b01, *b11; //pointers that point to blocks for GEMM and TRSM - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); + double ones = 1.0; + bool is_unitdiag = bli_obj_has_unit_diag(a); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); + //scratch registers + __m256d ymm0, ymm1, ymm2, ymm3; + __m256d ymm4, ymm5, ymm6, ymm7; + __m256d ymm8, ymm9, ymm10, ymm11; + __m256d ymm12, ymm13, ymm14, ymm15; + __m256d ymm16, ymm17, ymm18, ymm19; + __m256d ymm20; - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm9, ymm3); + __m128d xmm5; - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + gint_t required_packing_A = 1; + mem_t local_mem_buf_A_s = {0}; + double *D_A_pack = NULL; + double d11_pack[d_mr] __attribute__((aligned(64))); + rntm_t rntm; - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + bli_rntm_init_from_global( &rntm ); + bli_rntm_set_num_threads_only( 1, &rntm ); + bli_membrk_rntm_set_membrk( &rntm ); - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); + siz_t buffer_size = bli_pool_block_size( + bli_membrk_pool( + bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), + bli_rntm_membrk(&rntm))); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); + if ( (d_mr * m * sizeof(double)) > buffer_size) + return BLIS_NOT_YET_IMPLEMENTED; - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + if (required_packing_A == 1) + { + // Get the buffer from the pool. + bli_membrk_acquire_m(&rntm, + buffer_size, + BLIS_BITVAL_BUFFER_FOR_A_BLOCK, + &local_mem_buf_A_s); + if(FALSE==bli_mem_is_alloc(&local_mem_buf_A_s)) return BLIS_NULL_POINTER; + D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); + if(NULL==D_A_pack) return BLIS_NULL_POINTER; + } - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); + /* + Performs solving TRSM for 8 colmns at a time from 0 to m/8 in steps of d_mr + a. Load, transpose, Pack A (a10 block), the size of packing 8x6 to 8x (m-8) + First there will be no GEMM and no packing of a10 because it is only TRSM + b. 
Using packed a10 block and b01 block perform GEMM operation + c. Use GEMM outputs, perform TRSM operaton using a11, b11 and update B + d. Repeat b,c for n rows of B in steps of d_nr + */ + for(i = 0;(i+d_mr-1) < m; i += d_mr) //loop along 'M' dimension + { + a10 = L + (i*cs_a); //pointer to block of A to be used for GEMM + a11 = L + (i*rs_a) + (i*cs_a); + dim_t p_lda = d_mr; // packed leading dimension - //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); + if(transa) + { + /* + Load, tranpose and pack current A block (a10) into packed buffer memory D_A_pack + a. This a10 block is used in GEMM portion only and this + a10 block size will be increasing by d_mr for every next itteration + untill it reaches 8x(m-8) which is the maximum GEMM alone block size in A + b. This packed buffer is reused to calculate all n rows of B matrix + */ + bli_dtrsm_small_pack('L', i, 1, a10, cs_a, D_A_pack, p_lda,d_mr); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + /* + Pack 8 diagonal elements of A block into an array + a. This helps in utilze cache line efficiently in TRSM operation + b. store ones when input is unit diagonal + */ + dtrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,d_mr); + } + else + { + bli_dtrsm_small_pack('L', i, 0, a10, rs_a, D_A_pack, p_lda,d_mr); + dtrsm_small_pack_diag_element(is_unitdiag,a11,rs_a,d11_pack,d_mr); + } - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x03); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x03); + /* + a. Perform GEMM using a10, b01. + b. Perform TRSM on a11, b11 + c. This loop GEMM+TRSM loops operates with 8x6 block size + along n dimension for every d_nr rows of b01 where + packed A buffer is reused in computing all n rows of B. + d. Same approch is used in remaining fringe cases. 
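+            e. For reference only, a hedged sketch (not called by this kernel,
+               assumes immintrin.h): the in-register transposes applied to the
+               b11 tile before and after the solve follow the standard AVX 4x4
+               double transpose built from the same unpacklo/unpackhi and
+               permute2f128 steps used in the 4xN paths elsewhere in this file:
+
+                   // r[0]..r[3] hold four columns of a 4x4 double tile of B;
+                   // after the sequence they hold the four rows instead.
+                   void dtrsm_small_tr4x4_sketch(__m256d r[4])
+                   {
+                       __m256d t0 = _mm256_unpacklo_pd(r[0], r[1]);
+                       __m256d t1 = _mm256_unpackhi_pd(r[0], r[1]);
+                       __m256d t2 = _mm256_unpacklo_pd(r[2], r[3]);
+                       __m256d t3 = _mm256_unpackhi_pd(r[2], r[3]);
+                       r[0] = _mm256_permute2f128_pd(t0, t2, 0x20);
+                       r[1] = _mm256_permute2f128_pd(t1, t3, 0x20);
+                       r[2] = _mm256_permute2f128_pd(t0, t2, 0x31);
+                       r[3] = _mm256_permute2f128_pd(t1, t3, 0x31);
+                   }
+
+               BLIS_DTRSM_SMALL_NREG_TRANSPOSE_6x8 and the matching
+               _8x6_AND_STORE step are assumed to apply the same idea, widened
+               to the 8x6 tile held in ymm8-ymm15 plus ymm4-ymm7/ymm17-ymm20.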
+ */ + dim_t temp = n - d_nr + 1; + for(j = 0; j < temp; j += d_nr) //loop along 'N' dimension + { + a10 = D_A_pack; + a11 = L + (i*rs_a) + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + j*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - xmm5 = _mm256_extractf128_pd(ymm9, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 3),xmm5); + k_iter = i; - m_remainder -=2; - } - else if (1 == m_remainder) - { - a01 = D_A_pack; - a11 = L + (n_remainder - 4)*cs_a + (n_remainder - 4); //pointer to block of A to be used for TRSM - b10 = B + (m_remainder - 1) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (m_remainder - 1) + (n_remainder - 4)*cs_b; //pointer to block of B to be used for TRSM + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS - k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + /* + Peform GEMM between a10 and b01 blocks + For first itteration there will be no GEMM operation + where k_iter are zero + */ + BLIS_DTRSM_SMALL_GEMM_8mx6n(a10,b01,cs_b,p_lda,k_iter) - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); + /* + Load b11 of size 6x8 and multiply with alpha + Add the GEMM output and perform inregister transose of b11 + to peform TRSM operation. + */ + BLIS_DTRSM_SMALL_NREG_TRANSPOSE_6x8(b11,cs_b,AlphaVal) - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + /* + Compute 8x6 TRSM block by using GEMM block output in register + a. The 8x6 input (gemm outputs) are stored in combinations of ymm registers + 1. ymm8, ymm4 2. ymm9, ymm5 3. ymm10, ymm6, 4. ymm11, ymm7 + 5. ymm12, ymm17 6. ymm13,ymm18, 7. ymm14,ymm19 8. ymm15, ymm20 + where ymm8-ymm15 holds 8x4 data and reaming 8x2 will be hold by + other registers + b. 
Towards the end do in regiser transpose of TRSM output and store in b11 + */ + ////extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + //perform mul operation + ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); + ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm1); - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(ROw1): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*1)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm4, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm8, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, ymm4, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, ymm4, ymm7); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm8, ymm12); + ymm17 = _mm256_fnmadd_pd(ymm2, ymm4, ymm17); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm8, ymm13); + ymm18 = _mm256_fnmadd_pd(ymm2, ymm4, ymm18); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm8, ymm14); + ymm19 = _mm256_fnmadd_pd(ymm2, ymm4, ymm19); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm8, ymm15); + ymm20 = _mm256_fnmadd_pd(ymm2, ymm4, ymm20); + + + //perform mul operation + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm1); + + a11 += rs_a; - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm9, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, ymm5, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm9, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, ymm5, ymm7); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm9, ymm12); + ymm17 = _mm256_fnmadd_pd(ymm2, ymm5, ymm17); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm9, ymm13); + ymm18 = _mm256_fnmadd_pd(ymm2, ymm5, ymm18); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm9, ymm14); + ymm19 = _mm256_fnmadd_pd(ymm2, ymm5, ymm19); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm9, ymm15); + ymm20 = _mm256_fnmadd_pd(ymm2, ymm5, ymm20); + + //perform mul operation + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm1); + + a11 += rs_a; - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = 
_mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(ROw5): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm10, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, ymm6, ymm7); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm10, ymm12); + ymm17 = _mm256_fnmadd_pd(ymm2, ymm6, ymm17); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm10, ymm13); + ymm18 = _mm256_fnmadd_pd(ymm2, ymm6, ymm18); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm10, ymm14); + ymm19 = _mm256_fnmadd_pd(ymm2, ymm6, ymm19); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm10, ymm15); + ymm20 = _mm256_fnmadd_pd(ymm2, ymm6, ymm20); + + //perform mul operation + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm1); + + a11 += rs_a; - a01 += 1; //move to next row - b10 += cs_b; - } + //extract a44 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + //(ROw4): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm11, ymm12); + ymm17 = _mm256_fnmadd_pd(ymm2, ymm7, ymm17); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm11, ymm13); + ymm18 = _mm256_fnmadd_pd(ymm2, ymm7, ymm18); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm11, ymm14); + ymm19 = _mm256_fnmadd_pd(ymm2, ymm7, ymm19); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm11, ymm15); + ymm20 = _mm256_fnmadd_pd(ymm2, ymm7, ymm20); + + //perform mul operation + ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); + ymm17 = DTRSM_SMALL_DIV_OR_SCALE(ymm17, ymm1); + + a11 += rs_a; - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + //extract a55 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); - ymm0 = _mm256_broadcast_sd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + //(ROw5): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); + ymm18 = _mm256_fnmadd_pd(ymm2, ymm17, ymm18); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm12, ymm14); + ymm19 = _mm256_fnmadd_pd(ymm2, ymm17, ymm19); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm12, ymm15); + ymm20 = _mm256_fnmadd_pd(ymm2, ymm17, ymm20); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + //perform mul operation + ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1); + ymm18 = DTRSM_SMALL_DIV_OR_SCALE(ymm18, ymm1); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 + a11 += rs_a; - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, 
ymm9); //B11[0-3][3] * alpha -= ymm6 + //extract a66 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); - ///implement TRSM/// + //(ROw6): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm13, ymm14); + ymm19 = _mm256_fnmadd_pd(ymm2, ymm18, ymm19); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm13, ymm15); + ymm20 = _mm256_fnmadd_pd(ymm2, ymm18, ymm20); - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + //perform mul operation + ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); + ymm19 = DTRSM_SMALL_DIV_OR_SCALE(ymm19, ymm1); - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + a11 += rs_a; - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); + //extract a77 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); + //(ROw7): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm14, ymm15); + ymm20 = _mm256_fnmadd_pd(ymm2, ymm19, ymm20); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm9, ymm3); + //perform mul operation + ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); + ymm20 = DTRSM_SMALL_DIV_OR_SCALE(ymm20, ymm1); - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + a11 += rs_a; - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + BLIS_DTRSM_SMALL_NREG_TRANSPOSE_8x6_AND_STORE(b11,cs_b) + } - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); + dim_t n_rem = n-j; + if(n_rem >= 4) + { + a10 = D_A_pack; + a11 = L + (i*rs_a) + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + j*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); + k_iter = i ; //number of times GEMM to be performed(in blocks of 4x4) - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_8mx4n(a10,b01,cs_b,p_lda,k_iter) - //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_broadcast_sd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x01); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x01); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x01); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x01); - - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm3, 0)); - _mm_storel_pd((b11 + cs_b * 1), 
_mm256_extractf128_pd(ymm5, 0)); - _mm_storel_pd((b11 + cs_b * 2), _mm256_extractf128_pd(ymm7, 0)); - _mm_storel_pd((b11 + cs_b * 3), _mm256_extractf128_pd(ymm9, 0)); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] + ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 4)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] + ymm7 = _mm256_loadu_pd((double const *)(b11 + cs_b *3 + 4)); //B11[0][7] B11[1][7] B11[2][7] B11[3][7] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] * alpha -= B01[0-3][3] + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] + ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] * alpha -= B01[0-3][5] + ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); //B11[0-3][6] * alpha -= B01[0-3][6] + ymm7 = _mm256_fmsub_pd(ymm7, ymm16, ymm15); //B11[0-3][7] * alpha -= B01[0-3][7] - m_remainder -=1; - } - } - n_remainder -= 4; - } + ///implement TRSM/// - if(n_remainder == 3) - { - a01 = L + (n_remainder - 3) + n_remainder*cs_a; //pointer to block of A to be used in GEMM - a11 = L + (n_remainder - 3)*cs_a + (n_remainder - 3); //pointer to block of A to be used for TRSM + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - double *ptr_a10_dup = D_A_pack; + ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[2][4] B11[2][5] + ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[2][6] B11[2][7] - dim_t p_lda = (n-n_remainder); // packed leading dimension - // perform copy of A to packed buffer D_A_pack + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] - for(dim_t x =0;x < p_lda;x+=D_NR) - { - ymm0 = _mm256_loadu_pd((double const *)(a01)); - ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); - ymm2 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2)); - ymm3 = _mm256_loadu_pd((double const *)(a01 + cs_a * 3)); + ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3] + ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3] - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[1][4] 
B11[1][5] B11[3][4] B11[3][5] + ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[1][6] B11[1][7] B11[3][6] B11[3][7] - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3] + ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3] - _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + ymm0 = _mm256_broadcast_sd((double const *)&ones); - ymm0 = _mm256_loadu_pd((double const *)(a01 + cs_a * 4)); - ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a * 5)); + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_broadcast_sd((double const *)&zero); + //perform mul operation + ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*1)); + ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + + a11 += rs_a; + + //(ROw1): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); + ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); + ymm11 = _mm256_fnmadd_pd(ymm4, ymm8, ymm11); + ymm12 = _mm256_fnmadd_pd(ymm5, ymm8, ymm12); + ymm13 = _mm256_fnmadd_pd(ymm6, ymm8, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm7, ymm8, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm8, ymm15); + + //perform mul operation + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); + + ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + + a11 += rs_a; - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_broadcast_sd((double const *)&zero); + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + //(ROw2): FMA operations + ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); + ymm11 = _mm256_fnmadd_pd(ymm4, ymm9, ymm11); + ymm12 = _mm256_fnmadd_pd(ymm5, ymm9, ymm12); + ymm13 = _mm256_fnmadd_pd(ymm6, ymm9, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm7, ymm9, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm9, ymm15); - _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_extractf128_pd(ymm6,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_extractf128_pd(ymm7,0)); - _mm_storeu_pd((double 
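The unpacklo/unpackhi plus permute2f128 sequence applied to B11 here (and in the other blocks of these kernels) is the standard in-register 4x4 transpose of doubles. A self-contained sketch of that idiom follows; the function name and the row-in/column-out convention are illustrative, not taken from the patch.

#include <immintrin.h>

/* Transpose a 4x4 block of doubles held in four __m256d registers.
 * r0..r3 hold rows 0..3 on input and columns 0..3 on output.             */
void transpose_4x4_pd(__m256d *r0, __m256d *r1, __m256d *r2, __m256d *r3)
{
    __m256d t0 = _mm256_unpacklo_pd(*r0, *r1);  /* a00 a10 a02 a12 */
    __m256d t1 = _mm256_unpackhi_pd(*r0, *r1);  /* a01 a11 a03 a13 */
    __m256d t2 = _mm256_unpacklo_pd(*r2, *r3);  /* a20 a30 a22 a32 */
    __m256d t3 = _mm256_unpackhi_pd(*r2, *r3);  /* a21 a31 a23 a33 */

    *r0 = _mm256_permute2f128_pd(t0, t2, 0x20); /* a00 a10 a20 a30 */
    *r1 = _mm256_permute2f128_pd(t1, t3, 0x20); /* a01 a11 a21 a31 */
    *r2 = _mm256_permute2f128_pd(t0, t2, 0x31); /* a02 a12 a22 a32 */
    *r3 = _mm256_permute2f128_pd(t1, t3, 0x31); /* a03 a13 a23 a33 */
}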
*)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm8,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm9,0)); + //perform mul operation + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); - a01 += D_NR*cs_a; - ptr_a10_dup += D_NR; - } + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); - ymm4 = _mm256_broadcast_sd((double const *)&ones); - if(!is_unitdiag) - { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_sd((double const *)(a11)); - ymm1 = _mm256_broadcast_sd((double const *)(a11+cs_a*1 + 1)); - ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a*2 + 2)); - ymm3 = _mm256_broadcast_sd((double const *)&ones); + a11 += rs_a; - ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); - #ifdef BLIS_DISABLE_TRSM_PREINVERSION - ymm4 = ymm1; - #endif - #ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm4 = _mm256_div_pd(ymm4, ymm1); - #endif - } - _mm256_storeu_pd((double *)(d11_pack), ymm4); + //(ROw5): FMA operations + ymm11 = _mm256_fnmadd_pd(ymm4, ymm10, ymm11); + ymm12 = _mm256_fnmadd_pd(ymm5, ymm10, ymm12); + ymm13 = _mm256_fnmadd_pd(ymm6, ymm10, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm7, ymm10, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm10, ymm15); - for(i = (m-D_MR); (i+1) > 0; i -= D_MR) //loop along 'M' direction - { - a01 = D_A_pack; - a11 = L + (n_remainder - 3)*cs_a + (n_remainder - 3); //pointer to block of A to be used for TRSM - b10 = B + i + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (i) + (n_remainder - 3)*cs_b; //pointer to block of B to be used for TRSM + //perform mul operation + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + ymm0 = _mm256_broadcast_sd((double const *)&ones); - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + 4), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + 4 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + 4 + cs_b*2), _MM_HINT_T0); - #endif + //extract a44 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b10 + 4)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + a11 += rs_a; - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += 
(B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) + //(ROw4): FMA operations + ymm12 = _mm256_fnmadd_pd(ymm5, ymm11, ymm12); + ymm13 = _mm256_fnmadd_pd(ymm6, ymm11, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm7, ymm11, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm11, ymm15); - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) + //perform mul operation + ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) + ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); - a01 += 1; //move to next row - b10 += cs_b; - } + a11 += rs_a; - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); + //extract a55 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + 4)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] + //(ROw5): FMA operations + ymm13 = _mm256_fnmadd_pd(ymm6, ymm12, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm7, ymm12, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm12, ymm15); - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - ymm4 = _mm256_fmsub_pd(ymm1, ymm15, ymm4); //B11[4-7][0] * alpha-= ymm1 + //perform mul operation + ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b + 4)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 +cs_a*7)); - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - ymm6 = _mm256_fmsub_pd(ymm1, ymm15, ymm6); //B11[4-7][1] * alpha -= ymm3 + a11 += rs_a; - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*2 + 4)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] + //extract a66 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - ymm8 = _mm256_fmsub_pd(ymm1, ymm15, ymm8); //B11[4-7][2] * alpha -= ymm5 + //(ROw6): FMA operations + ymm14 = _mm256_fnmadd_pd(ymm7, ymm13, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm13, ymm15); - ///implement TRSM/// + //perform mul operation + ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + //extract a77 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - ymm8 = 
DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm0); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + a11 += rs_a; + //(ROw7): FMA operations + ymm15 = _mm256_fnmadd_pd(ymm16, ymm14, ymm15); - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1)); + //perform mul operation + ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); - ymm6 = _mm256_fnmadd_pd(ymm1, ymm8, ymm6); + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); + ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] + ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] - ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); - ymm4 = _mm256_fnmadd_pd(ymm1, ymm8, ymm4); + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm0); + ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] + ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); + ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[4][1] B11[5][1] B11[4][3] B11[5][3] + ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[6][1] B11[7][1] B11[6][3] B11[7][3] - ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); - ymm4 = _mm256_fnmadd_pd(ymm1, ymm6, ymm4); + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); + ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] + ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + 4), ymm4); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b + 4), ymm6); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*2 + 4), ymm8); + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); //store B11[4][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); //store B11[5][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm6); //store B11[6][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3 + 4), ymm7); //store B11[7][0-3] + + 
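The BLIS_DTRSM_SMALL_GEMM_8mxNn macros used above (and in the column-remainder cases that follow) accumulate the A10*B01 product into ymm registers, and the _mm256_fmsub_pd lines then fold the alpha scaling of B11 into the same registers. In scalar form the combined update is simply the following; the function name, the p_lda packed stride and the loop structure are illustrative only, not the macro bodies themselves.

#include <stddef.h>

/* Scalar sketch of "GEMM accumulate, then alpha-scale and subtract":
 * acc collects sum_k a10(ii,kk) * b01(kk,jj), and the final line mirrors
 * ymm = _mm256_fmsub_pd(b11, alpha, acc), i.e. b11*alpha - acc.          */
void gemm_then_alpha_update_sketch(const double *a10, size_t p_lda,
                                   const double *b01, double *b11,
                                   size_t cs_b, double alpha,
                                   size_t m_blk, size_t n_blk, size_t k_iter)
{
    for (size_t jj = 0; jj < n_blk; jj++)
        for (size_t ii = 0; ii < m_blk; ii++)
        {
            double acc = 0.0;
            for (size_t kk = 0; kk < k_iter; kk++)
                acc += a10[ii + kk*p_lda] * b01[kk + jj*cs_b];
            b11[ii + jj*cs_b] = b11[ii + jj*cs_b] * alpha - acc;
        }
}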
n_rem -=4; + j +=4; } - dim_t m_remainder = i + D_MR; - if(m_remainder >= 4) + if(n_rem) { - a01 = D_A_pack; - a11 = L + (n_remainder - 3)*cs_a + (n_remainder - 3); //pointer to block of A to be used for TRSM - b10 = B + (m_remainder - 4) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (m_remainder - 4) + (n_remainder - 3)*cs_b; //pointer to block of B to be used for TRSM + a10 = D_A_pack; + a11 = L + (i*rs_a) + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + j*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM - k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + if(3 == n_rem) { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_8mx3n(a10,b01,cs_b,p_lda,k_iter) - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] + ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 4)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); - a01 += 1; //move to next row - b10 += cs_b; + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] + ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] * alpha -= B01[0-3][5] + ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); //B11[0-3][6] * alpha -= B01[0-3][6] + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); } + else if(2 == n_rem) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_8mx2n(a10,b01,cs_b,p_lda,k_iter) - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + ymm16 = _mm256_broadcast_sd((double const 
*)(&AlphaVal)); //register to hold alpha - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] + ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] * alpha -= B01[0-3][5] + ymm6 = _mm256_broadcast_sd((double const *)(&ones)); + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + } + else if(1 == n_rem) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_8mx1n(a10,b01,cs_b,p_lda,k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_broadcast_sd((double const *)(&ones)); + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] + ymm5 = _mm256_broadcast_sd((double const *)(&ones)); + ymm6 = _mm256_broadcast_sd((double const *)(&ones)); + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + } ///implement TRSM/// - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[2][4] B11[2][5] + ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[2][6] B11[2][7] - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); + ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] 
B11[4][1] B11[4][2] B11[4][3] + ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3] - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); + ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[1][4] B11[1][5] B11[3][4] B11[3][5] + ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[1][6] B11[1][7] B11[3][6] B11[3][7] - //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3] + ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3] - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); + ymm0 = _mm256_broadcast_sd((double const *)&ones); - m_remainder -=4; - } + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); - if(m_remainder) - { - if(3 == m_remainder) - { - a01 = D_A_pack; - a11 = L + (n_remainder - 3)*cs_a + (n_remainder - 3); //pointer to block of A to be used for TRSM - b10 = B + (m_remainder - 3) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (m_remainder - 3) + (n_remainder - 3)*cs_b; //pointer to block of B to be used for TRSM + //perform mul operation + ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); - k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*1)); + ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + + a11 += rs_a; + + //(ROw1): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); + ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); + ymm11 = _mm256_fnmadd_pd(ymm4, ymm8, ymm11); + ymm12 = _mm256_fnmadd_pd(ymm5, ymm8, ymm12); + ymm13 = _mm256_fnmadd_pd(ymm6, ymm8, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm7, ymm8, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm8, ymm15); + + //perform mul operation + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); + + ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + + a11 += rs_a; - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - ///GEMM implementation starts/// - for(k 
= 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + //(ROw2): FMA operations + ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); + ymm11 = _mm256_fnmadd_pd(ymm4, ymm9, ymm11); + ymm12 = _mm256_fnmadd_pd(ymm5, ymm9, ymm12); + ymm13 = _mm256_fnmadd_pd(ymm6, ymm9, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm7, ymm9, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm9, ymm15); - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + //perform mul operation + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + a11 += rs_a; - a01 += 1; //move to next row - b10 += cs_b; - } + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + //(ROw5): FMA operations + ymm11 = _mm256_fnmadd_pd(ymm4, ymm10, ymm11); + ymm12 = _mm256_fnmadd_pd(ymm5, ymm10, ymm12); + ymm13 = _mm256_fnmadd_pd(ymm6, ymm10, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm7, ymm10, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm10, ymm15); - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + //perform mul operation + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + ymm0 = _mm256_broadcast_sd((double const *)&ones); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2)); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2 + 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 + //extract a44 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); - ///implement TRSM/// + ymm5 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + a11 += rs_a; - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + //(ROw4): FMA operations + ymm12 = _mm256_fnmadd_pd(ymm5, ymm11, ymm12); + ymm13 = _mm256_fnmadd_pd(ymm6, ymm11, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm7, ymm11, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm11, ymm15); - //(row 2):FMA 
operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); + //perform mul operation + ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + a11 += rs_a; - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); + //extract a55 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); - //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + //(ROw5): FMA operations + ymm13 = _mm256_fnmadd_pd(ymm6, ymm12, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm7, ymm12, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm12, ymm15); - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x07); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2)); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2 + 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x07); + //perform mul operation + ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1); - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - xmm5 = _mm256_extractf128_pd(ymm7, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 2),xmm5); - _mm_storel_pd((b11 + cs_b * 2 + 2), _mm256_extractf128_pd(ymm7, 1)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 +cs_a*7)); - m_remainder -=3; - } - else if(2 == m_remainder) - { - a01 = D_A_pack; - a11 = L + (n_remainder - 3)*cs_a + (n_remainder - 3); //pointer to block of A to be used for TRSM - b10 = B + (m_remainder - 2) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (m_remainder - 2) + (n_remainder - 3)*cs_b; //pointer to block of B to be used for TRSM + a11 += rs_a; - k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + //extract a66 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + //(ROw6): FMA operations + ymm14 = _mm256_fnmadd_pd(ymm7, ymm13, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm13, ymm15); - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + //perform mul operation + ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + //extract a77 + ymm1 = 
_mm256_broadcast_sd((double const *)(d11_pack + 7)); - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); - a01 += 1; //move to next row - b10 += cs_b; - } + a11 += rs_a; + //(ROw7): FMA operations + ymm15 = _mm256_fnmadd_pd(ymm16, ymm14, ymm15); - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + //perform mul operation + ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] + ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ///implement TRSM/// + ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] + ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[4][1] B11[5][1] B11[4][3] B11[5][3] + ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[6][1] B11[7][1] B11[6][3] B11[7][3] - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); + ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] + ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + if(3 == n_rem) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); //store B11[4][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); 
//store B11[5][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm6); //store B11[6][0-3] + } + else if(2 == n_rem) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); //store B11[4][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); //store B11[5][0-3] + } + else if(1 == n_rem) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); //store B11[4][0-3] + } + } + } - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); + //======================M remainder cases================================ + dim_t m_rem = m-i; + if(m_rem>=4) //implementation for reamainder rows(when 'M' is not a multiple of d_mr) + { + a10 = L + (i*cs_a); //pointer to block of A to be used for GEMM + a11 = L + (i*rs_a) + (i*cs_a); + double *ptr_a10_dup = D_A_pack; + dim_t p_lda = 4; // packed leading dimension - //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); + if(transa) + { + for(dim_t x =0;x < i;x+=p_lda) + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + cs_a)); + ymm2 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); + ymm3 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x03); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x03); + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - xmm5 = _mm256_extractf128_pd(ymm7, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 2),xmm5); + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); - m_remainder -=2; + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + + a10 += p_lda; + ptr_a10_dup += p_lda*p_lda; } - else if (1 == m_remainder) + } + else + { + for(dim_t x =0;x < i;x++) { - a01 = D_A_pack; - a11 = L + (n_remainder - 3)*cs_a + (n_remainder - 3); //pointer to block of A to be used for TRSM - b10 = B + (m_remainder - 1) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (m_remainder - 1) + (n_remainder - 3)*cs_b; //pointer to block of B to be used for TRSM + ymm0 = _mm256_loadu_pd((double const *)(a10 + rs_a * x)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * x), ymm0); + } + } - k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + ymm4 = _mm256_broadcast_sd((double const *)&ones); + if(!is_unitdiag) + { + if(transa) + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = 
_mm256_broadcast_sd((double const *)(a11+cs_a*1 + 1)); + ymm2 = _mm256_broadcast_sd((double const *)(a11+cs_a*2 + 2)); + ymm3 = _mm256_broadcast_sd((double const *)(a11+cs_a*3 + 3)); + } + else + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11+rs_a*1 + 1)); + ymm2 = _mm256_broadcast_sd((double const *)(a11+rs_a*2 + 2)); + ymm3 = _mm256_broadcast_sd((double const *)(a11+rs_a*3 + 3)); + } - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); + ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); + #ifdef BLIS_DISABLE_TRSM_PREINVERSION + ymm4 = ymm1; + #endif + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + ymm4 = _mm256_div_pd(ymm4, ymm1); + #endif + } + _mm256_storeu_pd((double *)(d11_pack), ymm4); - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + for(j = 0; (j+d_nr-1) < n; j += d_nr) //loop along 'N' dimension + { + a10 = D_A_pack; //pointer to block of A to be used for GEMM + a11 = L + (i*rs_a) + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM + b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM + + k_iter = i; //number of times GEMM operation to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx6n(a10,b01,cs_b,p_lda,k_iter) - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + ///implement TRSM/// + ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] * alpha -= B01[0-3][3] - a01 += 1; //move to next row - b10 += cs_b; - } + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to 
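The d11_pack buffer filled just above caches the diagonal of A11, or its reciprocals when BLIS_ENABLE_TRSM_PREINVERSION is defined, so that DTRSM_SMALL_DIV_OR_SCALE in the solve loops costs a multiply instead of a divide per row. A scalar sketch of that caching step; the function name, the bool flags and the loop form are illustrative, not the macro definitions from this file.

#include <stdbool.h>
#include <stddef.h>

/* Sketch of filling d11_pack: cache either the diagonal entries of A11 or
 * their reciprocals, so the inner solve can scale rows with one multiply
 * (pre-inversion enabled) or one divide (pre-inversion disabled).         */
void pack_diag_sketch(const double *a11, size_t lda, size_t nd,
                      bool is_unitdiag, bool preinvert, double *d11_pack)
{
    for (size_t r = 0; r < nd; r++)
    {
        double d = is_unitdiag ? 1.0 : a11[r + r*lda];
        d11_pack[r] = preinvert ? 1.0 / d : d;
    }
}

Pre-inversion trades one divide per diagonal entry for a multiply in every row update; it also rounds slightly differently (x*(1/d) instead of x/d), which is presumably why it stays behind a build-time switch.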
hold alpha + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] - ymm0 = _mm256_broadcast_sd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); + + + ymm16 = _mm256_broadcast_sd((double const *)(&ones)); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 + ////unpacklow//// + ymm7 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + //ymm16; - ///implement TRSM/// + //rearrange low elements + ymm4 = _mm256_permute2f128_pd(ymm7,ymm16,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm6 = _mm256_permute2f128_pd(ymm7,ymm16,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + //ymm16; - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + //rearrange high elements + ymm5 = _mm256_permute2f128_pd(ymm0,ymm16,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm7 = _mm256_permute2f128_pd(ymm0,ymm16,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + //b11 transpose end - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); + ////extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); + //perform mul operation + ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); + ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm1); - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); + //(ROw1): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*1)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm4, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm8, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, ymm4, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, ymm4, ymm7); - //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm3 = 
_mm256_fnmadd_pd(ymm1, ymm5, ymm3); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + //perform mul operation + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm1); - ymm0 = _mm256_broadcast_sd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x01); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x01); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x01); + a11 += rs_a; - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm3, 0)); - _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm5, 0)); - _mm_storel_pd((b11 + cs_b * 2), _mm256_extractf128_pd(ymm7, 0)); + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - m_remainder -=1; - } - } - n_remainder -= 3; - } - else if(n_remainder == 2) - { - a01 = L + (n_remainder - 2) + n_remainder*cs_a; //pointer to block of A to be used in GEMM - a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2); //pointer to block of A to be used for TRSM + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm9, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, ymm5, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm9, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, ymm5, ymm7); - double *ptr_a10_dup = D_A_pack; + //perform mul operation + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm1); - dim_t p_lda = (n-n_remainder); // packed leading dimension - // perform copy of A to packed buffer D_A_pack + a11 += rs_a; - for(dim_t x =0;x < p_lda;x+=D_NR) - { - ymm0 = _mm256_loadu_pd((double const *)(a01)); - ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); - ymm2 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2)); - ymm3 = _mm256_loadu_pd((double const *)(a01 + cs_a * 3)); + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + //(ROw5): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm10, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, ymm6, ymm7); - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + //perform mul operation + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm1); - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + a11 += rs_a; - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm0 = _mm256_loadu_pd((double const *)(a01 + cs_a * 4)); - ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a * 5)); + 
///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_broadcast_sd((double const *)&zero); + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_broadcast_sd((double const *)&zero); + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_extractf128_pd(ymm6,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_extractf128_pd(ymm7,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm8,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm9,0)); + ///unpack high/// + ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - a01 += D_NR*cs_a; - ptr_a10_dup += D_NR; + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + + _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store B11[1][0-3] } - ymm4 = _mm256_broadcast_sd((double const *)&ones); - if(!is_unitdiag) + dim_t n_rem = n-j; + if(n_rem >= 4) { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_sd((double const *)(a11)); - ymm1 = _mm256_broadcast_sd((double const *)(a11+cs_a*1 + 1)); - ymm2 = _mm256_broadcast_sd((double const *)&ones); - ymm3 = _mm256_broadcast_sd((double const *)&ones); + a10 = D_A_pack; + a11 = L + (i*rs_a) + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + j*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM - ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); + k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) - ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); - #ifdef BLIS_DISABLE_TRSM_PREINVERSION - ymm4 = ymm1; - #endif - #ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm4 = _mm256_div_pd(ymm4, ymm1); - #endif - } - _mm256_storeu_pd((double *)(d11_pack), ymm4); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + BLIS_DTRSM_SMALL_GEMM_4mx4n(a10,b01,cs_b,p_lda,k_iter) - for(i = (m-D_MR); (i+1) > 0; i -= D_MR) //loop along 'M' direction - { - a01 = D_A_pack; - a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2); //pointer to block of A to be used for 
TRSM - b10 = B + i + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (i) + (n_remainder - 2)*cs_b; //pointer to block of B to be used for TRSM + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + ///implement TRSM/// - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + 4), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b+4), _MM_HINT_T0); - #endif + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b10 + 4)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - a01 += 1; //move to next row - b10 += cs_b; - } + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); + ymm0 = _mm256_broadcast_sd((double const *)&ones); - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + 4)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - ymm4 = _mm256_fmsub_pd(ymm1, ymm15, ymm4); 
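/*
 * Illustrative sketch: d11_pack caches the diagonal of A11 -- the reciprocals
 * when BLIS_ENABLE_TRSM_PREINVERSION is defined (note the div_pd in the packing
 * step above) and the raw diagonal values otherwise -- which is why
 * DTRSM_SMALL_DIV_OR_SCALE can resolve to a multiply in one build and a divide
 * in the other.  Per element the update amounts to:
 */
static inline double div_or_scale_sketch( double b, double d )
{
#ifdef BLIS_ENABLE_TRSM_PREINVERSION
    return b * d;    /* d already holds 1.0 / a_kk */
#else
    return b / d;    /* d holds a_kk itself        */
#endif
}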
//B11[4-7][0] * alpha-= ymm1 + //perform mul operation + ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b + 4)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - ymm6 = _mm256_fmsub_pd(ymm1, ymm15, ymm6); //B11[4-7][1] * alpha -= ymm3 + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*1)); + ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); - ///implement TRSM/// + a11 += rs_a; - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + //(ROw1): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); + ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); + ymm11 = _mm256_fnmadd_pd(ymm4, ymm8, ymm11); - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm0); + //perform mul operation + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); + ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); - //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); + a11 += rs_a; - ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); - ymm4 = _mm256_fnmadd_pd(ymm1, ymm6, ymm4); + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); + //(ROw2): FMA operations + ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); + ymm11 = _mm256_fnmadd_pd(ymm4, ymm9, ymm11); - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + 4), ymm4); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b + 4), ymm6); - } + //perform mul operation + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); - dim_t m_remainder = i + D_MR; - if(m_remainder >= 4) - { - a01 = D_A_pack; - a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2); //pointer to block of A to be used for TRSM - b10 = B + (m_remainder - 4) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (m_remainder - 4) + (n_remainder - 2)*cs_b; //pointer to block of B to be used for TRSM + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); - k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + a11 += rs_a; - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + //(ROw5): FMA operations + ymm11 = _mm256_fnmadd_pd(ymm4, ymm10, ymm11); - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + //perform mul operation + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += 
(B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - a01 += 1; //move to next row - b10 += cs_b; - } + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] - ///implement TRSM/// + n_rem -= 4; + j += 4; + } + if(n_rem) + { + a10 = D_A_pack; + a11 = L + (i*rs_a) + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + j*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); - //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); + if(3 == n_rem) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - m_remainder -=4; - } + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + } + else if(2 == n_rem) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b,p_lda,k_iter) - if(m_remainder) - { - if(3 == m_remainder) + ymm16 
= _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + } + else if(1 == n_rem) { - a01 = D_A_pack; - a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2); //pointer to block of A to be used for TRSM - b10 = B + (m_remainder - 3) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (m_remainder - 3) + (n_remainder - 2)*cs_b; //pointer to block of B to be used for TRSM + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b,p_lda,k_iter) - k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_broadcast_sd((double const *)(&ones)); + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + } - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] - a01 += 1; //move to next row - b10 += cs_b; - } + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + ymm0 = _mm256_broadcast_sd((double const *)&ones); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1)); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*1 + 2)); - ymm0 = 
_mm256_insertf128_pd(ymm0, xmm5, 0); - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + ////extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); - ///implement TRSM/// + //perform mul operation + ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*1)); + ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); - //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); + a11 += rs_a; - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + //(ROw1): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); + ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); + ymm11 = _mm256_fnmadd_pd(ymm4, ymm8, ymm11); - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x07); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1)); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*1 + 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x07); + //perform mul operation + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); - _mm256_storeu_pd((double *)b11, ymm3); - xmm5 = _mm256_extractf128_pd(ymm5, 0); - _mm_storeu_pd((double *)(b11 + cs_b*1), xmm5); - _mm_storel_pd((b11 + cs_b * 1 + 2), _mm256_extractf128_pd(ymm5, 1)); + ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); - m_remainder -=3; - } - else if(2 == m_remainder) - { - a01 = D_A_pack; - a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2); //pointer to block of A to be used for TRSM - b10 = B + (m_remainder - 2) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (m_remainder - 2) + (n_remainder - 2)*cs_b; //pointer to block of B to be used for TRSM + a11 += rs_a; - k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); + //(ROw2): FMA operations + ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); + ymm11 = _mm256_fnmadd_pd(ymm4, ymm9, ymm11); - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + //perform mul operation + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + a11 += rs_a; - a01 += 1; //move to next row - b10 += cs_b; - } + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); 
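/*
 * Illustrative sketch: every remainder branch repeats the same preparation
 * step -- after the small GEMM accumulates the contribution of the already
 * solved panel, the fmsub against AlphaVal folds it into B11 as
 *     b11 = alpha * b11 - acc
 * per element, before the triangular solve runs.  A scalar rendering for an
 * mb x nb tile (names here are placeholders) is:
 */
static void fold_alpha_and_gemm_sketch
     (
       double        alpha,
       const double* acc,   /* accumulated products, column-major, ld = mb */
       double*       b11,   /* tile of B being solved, column stride cs_b  */
       dim_t         mb,
       dim_t         nb,
       dim_t         cs_b
     )
{
    for ( dim_t jj = 0; jj < nb; ++jj )
        for ( dim_t ii = 0; ii < mb; ++ii )
            b11[ ii + jj * cs_b ] = alpha * b11[ ii + jj * cs_b ]
                                  - acc[ ii + jj * mb ];
}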
//register to hold alpha + //(ROw5): FMA operations + ymm11 = _mm256_fnmadd_pd(ymm4, ymm10, ymm11); - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + //perform mul operation + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - ///implement TRSM/// + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); + if(3 == n_rem) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + } + else if(2 == n_rem) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + } + else if(1 == n_rem) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + } + } + m_rem -=4; + i +=4; + } - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + if(m_rem) + { + a10 = L + (i*cs_a); //pointer to block of A to be used for GEMM + // Do transpose for a10 & store in D_A_pack + double *ptr_a10_dup = D_A_pack; + if(3 == m_rem) // Repetative A blocks will be 3*3 + { + dim_t p_lda = 4; // packed leading dimension + if(transa) + { + for(dim_t x=0;x= 4)) + { + a10 = D_A_pack; //pointer to block of A to be used for GEMM + a11 = L + (i*rs_a) + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM + b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) - ymm0 = _mm256_broadcast_sd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x01); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x01); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS - _mm_storel_pd(b11 , _mm256_extractf128_pd(ymm3, 0)); - _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm5, 0)); + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx4n(a10,b01,cs_b,p_lda,k_iter) - m_remainder -=1; - } - } - n_remainder 
-= 2; - } - else if(n_remainder == 1) - { - a01 = L + (n_remainder - 1) + n_remainder*cs_a; //pointer to block of A to be used in GEMM - a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1); //pointer to block of A to be used for TRSM + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - double *ptr_a10_dup = D_A_pack; + ///implement TRSM/// - dim_t p_lda = (n-n_remainder); // packed leading dimension - // perform copy of A to packed buffer D_A_pack + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); + ymm3 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3 + 2)); + ymm3 = _mm256_insertf128_pd(ymm3, xmm5, 0); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - for(dim_t x =0;x < p_lda;x+=D_NR) - { - ymm0 = _mm256_loadu_pd((double const *)(a01)); - ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); - ymm2 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2)); - ymm3 = _mm256_loadu_pd((double const *)(a01 + cs_a * 3)); + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); + ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x08); - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + xmm5 = _mm256_extractf128_pd(ymm3, 0); + _mm_storeu_pd((double *)(b11 + cs_b * 3),xmm5); + _mm_storel_pd((b11 + cs_b * 3 + 2), _mm256_extractf128_pd(ymm3, 1)); - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + if(transa) + dtrsm_AutXB_ref(a11, b11, m_rem, 4, cs_a, cs_b,is_unitdiag); + else + dtrsm_AlXB_ref(a11, b11, m_rem, 4, rs_a, cs_b, is_unitdiag); + n_rem -= 4; + j +=4; + } - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + if(n_rem) + { + a10 = D_A_pack; //pointer to block of A to be used for GEMM + a11 = L + (i*rs_a) + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM + b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) - _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS - ymm0 = _mm256_loadu_pd((double const *)(a01 + cs_a * 4)); - ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a * 5)); + if(3 == n_rem) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_broadcast_sd((double const *)&zero); + BLIS_PRE_DTRSM_SMALL_3M_3N(AlphaVal,b11,cs_b) - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + if(transa) + dtrsm_AutXB_ref(a11, 
b11, m_rem, 3, cs_a, cs_b,is_unitdiag); + else + dtrsm_AlXB_ref(a11, b11, m_rem, 3, rs_a, cs_b, is_unitdiag); + } + else if(2 == n_rem) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b,p_lda,k_iter) - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_broadcast_sd((double const *)&zero); + BLIS_PRE_DTRSM_SMALL_3M_2N(AlphaVal,b11,cs_b) - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + if(transa) + dtrsm_AutXB_ref(a11, b11, m_rem, 2, cs_a, cs_b,is_unitdiag); + else + dtrsm_AlXB_ref(a11, b11, m_rem, 2, rs_a, cs_b, is_unitdiag); + } + else if(1 == n_rem) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b,p_lda,k_iter) - _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_extractf128_pd(ymm6,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_extractf128_pd(ymm7,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm8,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm9,0)); + BLIS_PRE_DTRSM_SMALL_3M_1N(AlphaVal,b11,cs_b) - a01 += D_NR*cs_a; - ptr_a10_dup += D_NR; + if(transa) + dtrsm_AutXB_ref(a11, b11, m_rem, 1, cs_a, cs_b, is_unitdiag); + else + dtrsm_AlXB_ref(a11, b11, m_rem, 1, rs_a, cs_b, is_unitdiag); + } + } } - - ymm4 = _mm256_broadcast_sd((double const *)&ones); - if(!is_unitdiag) + else if(2 == m_rem) // Repetative A blocks will be 2*2 { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_sd((double const *)(a11)); - ymm1 = _mm256_broadcast_sd((double const *)&ones); - ymm2 = _mm256_broadcast_sd((double const *)&ones); - ymm3 = _mm256_broadcast_sd((double const *)&ones); + dim_t p_lda = 4; // packed leading dimension + if(transa) + { + for(dim_t x=0;x 0; i -= D_MR) //loop along 'M' direction - { - a01 = D_A_pack; - a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1); //pointer to block of A to be used for TRSM - b10 = B + i + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (i) + (n_remainder - 1)*cs_b; //pointer to block of B to be used for TRSM + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); - k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + 4), _MM_HINT_T0); - #endif + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + a10 += p_lda; + ptr_a10_dup += p_lda*p_lda; + } + } + else + { + for(dim_t x=0;x= 4) - { - a01 = D_A_pack; - a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1); //pointer to block of A to be used for TRSM - b10 = B + (m_remainder - 4) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (m_remainder - 4) + (n_remainder - 1)*cs_b; //pointer to block of B to be used for TRSM + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); - k_iter = (n-n_remainder); //number 
of GEMM operations to be done(in blocks of 4x4) + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); - ymm3 = _mm256_setzero_pd(); + _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store(B11[0-3][3]) - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + if(transa) + dtrsm_AutXB_ref(a11, b11, m_rem, 6, cs_a, cs_b, is_unitdiag); + else + dtrsm_AlXB_ref(a11, b11, m_rem, 6, rs_a, cs_b, is_unitdiag); + } + + dim_t n_rem = n-j; + if((n_rem >= 4)) { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + a10 = D_A_pack; //pointer to block of A to be used for GEMM + a11 = L + (i*rs_a) + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM + b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) - a01 += 1; //move to next row - b10 += cs_b; - } + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx4n(a10,b01,cs_b,p_lda,k_iter) - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - ///implement TRSM/// - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ///implement TRSM/// - _mm256_storeu_pd((double *)b11, ymm3); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); + ymm3 = _mm256_insertf128_pd(ymm3, xmm5, 0); - m_remainder -=4; - } + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - if(m_remainder) - { - if(3 == m_remainder) + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0C); + ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0C); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + xmm5 = _mm256_extractf128_pd(ymm3, 0); + _mm_storeu_pd((double *)(b11 + cs_b * 3), xmm5); + + if(transa) + dtrsm_AutXB_ref(a11, b11, m_rem, 4, cs_a, cs_b, is_unitdiag); + else + dtrsm_AlXB_ref(a11, b11, m_rem, 4, rs_a, cs_b, is_unitdiag); + n_rem -= 4; + j +=4; + } + if(n_rem) { - a01 = D_A_pack; - a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1); //pointer to block of A to be used for TRSM - b10 = B + (m_remainder - 3) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (m_remainder - 3) + (n_remainder - 1)*cs_b; 
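/*
 * Illustrative sketch: in these m-remainder branches only the first
 * m_remainder lanes of each 256-bit vector are meaningful, so loads are
 * stitched together from 128-bit pieces (insertf128) and the blend masks
 * (0x07, 0x03, 0x01) splice the solved lanes back into the original contents,
 * leaving the extra lanes of B unchanged whether a full or a partial store
 * follows.  The net effect per column is simply:
 */
static void store_partial_column_sketch
     (
       const double* solved,  /* 4-lane result, only first mb lanes valid */
       double*       b11,     /* destination column of B                  */
       dim_t         mb       /* 1, 2 or 3                                */
     )
{
    for ( dim_t ii = 0; ii < mb; ++ii )
        b11[ ii ] = solved[ ii ];  /* lanes mb..3 of the column keep their old values */
}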
//pointer to block of B to be used for TRSM + a10 = D_A_pack; //pointer to block of A to be used for GEMM + a11 = L + (i*rs_a) + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM + b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM - k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) - ymm3 = _mm256_setzero_pd(); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + if(3 == n_rem) { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + BLIS_PRE_DTRSM_SMALL_2M_3N(AlphaVal,b11,cs_b) - a01 += 1; //move to next row - b10 += cs_b; + if(transa) + dtrsm_AutXB_ref(a11, b11, m_rem, 3, cs_a, cs_b, is_unitdiag); + else + dtrsm_AlXB_ref(a11, b11, m_rem, 3, rs_a, cs_b, is_unitdiag); } + else if(2 == n_rem) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b,p_lda,k_iter) - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + BLIS_PRE_DTRSM_SMALL_2M_2N(AlphaVal,b11,cs_b) - xmm5 = _mm_loadu_pd((double const*)(b11)); - ymm0 = _mm256_broadcast_sd((double const *)(b11+ 2)); - ymm6 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm3 = _mm256_fmsub_pd(ymm6, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + if(transa) + dtrsm_AutXB_ref(a11, b11, m_rem, 2, cs_a, cs_b, is_unitdiag); + else + dtrsm_AlXB_ref(a11, b11, m_rem, 2, rs_a, cs_b, is_unitdiag); + } + else if(1 == n_rem) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b,p_lda,k_iter) - ///implement TRSM/// - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + BLIS_PRE_DTRSM_SMALL_2M_1N(AlphaVal,b11,cs_b) - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm6, ymm3, 0x07); + if(transa) + dtrsm_AutXB_ref(a11, b11, m_rem, 1, cs_a, cs_b, is_unitdiag); + else + dtrsm_AlXB_ref(a11, b11, m_rem, 1, rs_a, cs_b, is_unitdiag); + } + } + m_rem -=2; + i+=2; + } + else if(1 == m_rem) // Repetative A blocks will be 1*1 + { + dim_t p_lda = 4; // packed leading dimension + if(transa) + { + for(dim_t x=0;x= 4)) + { + a10 = D_A_pack; //pointer to block of A to be used for GEMM + a11 = L + (i*rs_a) + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM + b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM - xmm5 = _mm_loadu_pd((double const*)(b11)); - ymm6 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm3 = _mm256_fmsub_pd(ymm6, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx4n(a10,b01,cs_b,p_lda,k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha ///implement TRSM/// - //extract a00 - ymm0 = 
_mm256_broadcast_sd((double const *)(d11_pack )); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_broadcast_sd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_broadcast_sd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_broadcast_sd((double const *)(b11 + cs_b *3)); - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm6, ymm3, 0x03); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); + ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0E); - xmm5 = _mm256_extractf128_pd(ymm3, 0); - _mm_storeu_pd((double *)(b11), xmm5); + _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm0, 0)); + _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm1, 0)); + _mm_storel_pd((b11 + cs_b * 2), _mm256_extractf128_pd(ymm2, 0)); + _mm_storel_pd((b11 + cs_b * 3), _mm256_extractf128_pd(ymm3, 0)); - m_remainder -=2; + if(transa) + dtrsm_AutXB_ref(a11, b11, m_rem, 4, cs_a, cs_b, is_unitdiag); + else + dtrsm_AlXB_ref(a11, b11, m_rem, 4, rs_a, cs_b, is_unitdiag); + n_rem -= 4; + j+=4; } - else if (1 == m_remainder) + + if(n_rem) { - a01 = D_A_pack; - a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1); //pointer to block of A to be used for TRSM - b10 = B + (m_remainder - 1) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (m_remainder - 1) + (n_remainder - 1)*cs_b; //pointer to block of B to be used for TRSM + a10 = D_A_pack; //pointer to block of A to be used for GEMM + a11 = L + (i*rs_a) + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM + b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM - k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) - ymm3 = _mm256_setzero_pd(); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations + if(3 == n_rem) { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + BLIS_PRE_DTRSM_SMALL_1M_3N(AlphaVal,b11,cs_b) - a01 += 1; //move to next row - b10 += cs_b; + if(transa) + dtrsm_AutXB_ref(a11, b11, m_rem, 3, cs_a, cs_b, is_unitdiag); + else + dtrsm_AlXB_ref(a11, b11, m_rem, 3, rs_a, cs_b, is_unitdiag); } + else if(2 == n_rem) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b,p_lda,k_iter) - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - ymm6 = _mm256_broadcast_sd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm6, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ///implement TRSM/// - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); - ymm3 = 
DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + BLIS_PRE_DTRSM_SMALL_1M_2N(AlphaVal,b11,cs_b) - ymm3 = _mm256_blend_pd(ymm6, ymm3, 0x01); + if(transa) + dtrsm_AutXB_ref(a11, b11, m_rem, 2, cs_a, cs_b, is_unitdiag); + else + dtrsm_AlXB_ref(a11, b11, m_rem, 2, rs_a, cs_b, is_unitdiag); + } + else if(1 == n_rem) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b,p_lda,k_iter) - _mm_storel_pd(b11, _mm256_extractf128_pd(ymm3, 0)); + BLIS_PRE_DTRSM_SMALL_1M_1N(AlphaVal,b11,cs_b) - m_remainder -=1; + if(transa) + dtrsm_AutXB_ref(a11, b11, m_rem, 1, cs_a, cs_b, is_unitdiag); + else + dtrsm_AutXB_ref(a11, b11, m_rem, 1, rs_a, cs_b, is_unitdiag); + } } + m_rem -=1; + i+=1; } - n_remainder -= 1; } - if ((required_packing_A == 1) && bli_mem_is_alloc( &local_mem_buf_A_s )) + if ((required_packing_A == 1) && + bli_mem_is_alloc( &local_mem_buf_A_s )) { - bli_membrk_release(&rntm, - &local_mem_buf_A_s); + bli_membrk_release(&rntm, &local_mem_buf_A_s); } - return BLIS_SUCCESS; + return BLIS_SUCCESS; } -#endif //BLIS_ENABLE_SMALL_MATRIX_TRSM +#endif //BLIS_ENABLE_SMALL_MATRIX_TRSM \ No newline at end of file From 5cf260f1ecab6ba9c6d31a943fdd41409c6804ae Mon Sep 17 00:00:00 2001 From: Meghana Vankadari Date: Tue, 17 Aug 2021 14:03:45 +0530 Subject: [PATCH 002/243] Fixed a bug in trsm blocksize determining function. - The function "bli_determine_blocksize_b_sub" uses a modulo operation using b_alg. In cases where trsm blocksizes are not set, using this function with trsm blocksizes being zero will lead to FPE. - Added a check to avoid calling the above function with zeroes as parameters. Change-Id: I770aaa13125b55320a68ff9fc3da782111e0978a --- frame/base/bli_blksz.c | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/frame/base/bli_blksz.c b/frame/base/bli_blksz.c index 83a724e98e..f3891dbbba 100644 --- a/frame/base/bli_blksz.c +++ b/frame/base/bli_blksz.c @@ -272,12 +272,11 @@ dim_t bli_determine_blocksize_f b_alg = bli_blksz_get_def( dt, bsize ); b_max = bli_blksz_get_max( dt, bsize ); - b_use = bli_determine_blocksize_f_sub( i, dim, b_alg, b_max ); - // If b_use != 0, this means that trsm blocksizes are set // and we continue with trsm-specific blocksizes. // Else, we query L3 blocksizes and use them for TRSM execution. - if( b_use > 0 ) return b_use; + if( b_alg > 0 ) return bli_determine_blocksize_f_sub( i, dim, b_alg, b_max); + } bsize = bli_cntx_get_blksz( bszid, cntx ); @@ -314,12 +313,11 @@ dim_t bli_determine_blocksize_b b_alg = bli_blksz_get_def( dt, bsize ); b_max = bli_blksz_get_max( dt, bsize ); - b_use = bli_determine_blocksize_b_sub( i, dim, b_alg, b_max ); - // If b_use != 0, this means that trsm blocksizes are set // and we continue with trsm-specific blocksizes. // Else, we query L3 blocksizes and use them for TRSM execution. - if( b_use > 0 ) return b_use; + if( b_alg > 0 ) bli_determine_blocksize_b_sub( i, dim, b_alg, b_max ); + } bsize = bli_cntx_get_blksz( bszid, cntx ); From bcd9591b3f89b86c30430ba2f0dcf6104342ab1b Mon Sep 17 00:00:00 2001 From: Dipal M Zambare Date: Tue, 17 Aug 2021 12:46:38 +0530 Subject: [PATCH 003/243] Added support for amdepyc fat binary -- Created new configuration amdepyc to include fat binary which includes zen, zen2, zen3 and generic architecture for fallback. -- Updated amdepyc family makefiles to include macros needed in amdepyc family binary. This file must include all macros, compiler options to be used for non architecture specific code. 
-- Added 'workaround' to exclude ZEN family specific code in some of the framework files. There are still lot of places were ZEN family specific code is added in framework files. They will be addressed with proper design later. - Moved definition of BLIS_CONFIG_EPYC from header files to makefile so that it is enabled only for framework and kernels -- Removed redundant flag AOCL_BLIS_ZEN, used BLIS_CONFIG_EPYC wherever it was needed. -- Removed un-used, obsolete macros, some of them may be needed for debugging which can be added in the individual workspaces. - BLIS_DEFAULT_MR_THREAD_MAX - BLIS_DEFAULT_NR_THREAD_MAX - BLIS_ENABLE_ZEN_BLOCK_SIZES - BLIS_SMALL_MATRIX_THRES_TRSM - BLIS_ENABLE_SINGLE_INSTANCE_BLOCK_SIZES - BLIS_ENABLE_SUP_MR_EXT - BLIS_ENABLE_SUP_NR_EXT -- Corrected implementation of exiting amd64_legacy configuration. AMD-Internal: [CPUPL-1626, CPUPL-1628] Change-Id: I46b0ab3ea3ac7d9ff737fef66c462e85601ee29c --- CMakeLists.txt | 16 ++-- config/amd64/bli_family_amd64.h | 92 ------------------- ...mily_amd64.h => bli_family_amd64_legacy.h} | 6 +- config/amd64_legacy/make_defs.mk | 4 +- config/amdepyc/bli_family_amdepyc.h | 64 +++++++++++++ config/{amd64 => amdepyc}/make_defs.mk | 29 ++++-- config/zen/bli_family_zen.h | 31 +------ config/zen/make_defs.mk | 20 +++- config/zen2/bli_family_zen2.h | 9 -- config/zen2/make_defs.mk | 19 +++- config/zen3/bli_family_zen3.h | 58 ------------ config/zen3/make_defs.mk | 21 ++++- config_registry | 2 +- frame/base/bli_arch.c | 11 ++- frame/base/bli_cntx.c | 5 +- frame/base/bli_cpuid.c | 26 +++--- frame/compat/bla_gemm.c | 4 +- frame/include/bli_arch_config.h | 9 +- 18 files changed, 183 insertions(+), 243 deletions(-) delete mode 100644 config/amd64/bli_family_amd64.h rename config/amd64_legacy/{bli_family_amd64.h => bli_family_amd64_legacy.h} (93%) create mode 100644 config/amdepyc/bli_family_amdepyc.h rename config/{amd64 => amdepyc}/make_defs.mk (71%) diff --git a/CMakeLists.txt b/CMakeLists.txt index 1657948215..77a7e702ba 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -52,7 +52,7 @@ elseif (${AOCL_BLIS_FAMILY} STREQUAL "zen3") add_definitions(-DBLIS_KERNELS_HASWELL) elseif (${AOCL_BLIS_FAMILY} STREQUAL "amd64") set(AOCL_BLIS_ZEN FALSE) - add_definitions(-DBLIS_FAMILY_AMD64) + add_definitions(-DBLIS_FAMILY_AMDEPYC) add_definitions(-DBLIS_CONFIG_ZEN3) add_definitions(-DBLIS_CONFIG_ZEN2) add_definitions(-DBLIS_CONFIG_ZEN) @@ -294,7 +294,7 @@ endif() set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /W0 ") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /Oi") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /MP") -set(INTR_GENERAL_LINK_FLAGS "${INTR_GENERAL_LINK_FLAGS} /RELEGE") +set(INTR_GENERAL_LINK_FLAGS "${INTR_GENERAL_LINK_FLAGS} /RELEGE") add_definitions(-D_CRT_SECURE_NO_DEPRECATE) @@ -397,13 +397,13 @@ find_package(PythonLibs 3 REQUIRED) string(APPEND HEADER_PATH if(${AOCL_BLIS_FAMILY} STREQUAL "zen") - " ${CMAKE_CURRENT_SOURCE_DIR}/config/zen/" - " ${CMAKE_CURRENT_SOURCE_DIR}/kernels/zen/" - " ${CMAKE_CURRENT_SOURCE_DIR}/kernels/haswell/" + " ${CMAKE_CURRENT_SOURCE_DIR}/config/zen/" + " ${CMAKE_CURRENT_SOURCE_DIR}/kernels/zen/" + " ${CMAKE_CURRENT_SOURCE_DIR}/kernels/haswell/" elseif (${AOCL_BLIS_FAMILY} STREQUAL "zen2") - " ${CMAKE_CURRENT_SOURCE_DIR}/config/zen2/" - " ${CMAKE_CURRENT_SOURCE_DIR}/kernels/zen/" - " ${CMAKE_CURRENT_SOURCE_DIR}/kernels/haswell/" + " ${CMAKE_CURRENT_SOURCE_DIR}/config/zen2/" + " ${CMAKE_CURRENT_SOURCE_DIR}/kernels/zen/" + " ${CMAKE_CURRENT_SOURCE_DIR}/kernels/haswell/" elseif (${AOCL_BLIS_FAMILY} STREQUAL "amd64") " 
${CMAKE_CURRENT_SOURCE_DIR}/config/amd64/" " ${CMAKE_CURRENT_SOURCE_DIR}/config/bulldozer/" diff --git a/config/amd64/bli_family_amd64.h b/config/amd64/bli_family_amd64.h deleted file mode 100644 index 31ae3ecb82..0000000000 --- a/config/amd64/bli_family_amd64.h +++ /dev/null @@ -1,92 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#ifndef BLIS_FAMILY_AMD64_H -#define BLIS_FAMILY_AMD64_H - -//To enable framework optimizations for EPYC family processors. -//With this macro defined, we can call kernels directly from -//BLAS interfaces for levels 1 & 2. -//This macro needs to be defined for all EPYC configurations. -#define BLIS_CONFIG_EPYC - - -// For zen3 architecture we dynamically change block sizes -// based on number of threads. These values were determined -// by running benchmarks on zen3 platform. 
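/*
 * Illustrative sketch: the macro removed below chose DGEMM cache blocksizes at
 * run time from the thread count and the shape of C -- for double, KC was
 * raised from 256 to 512 only when nt >= 32 and either dimension of C exceeded
 * 7800, while MC and NC were identical in both branches.  In plain C the
 * decision amounts to (function name is a placeholder):
 */
static dim_t pick_dgemm_kc_sketch( dim_t nt, dim_t m, dim_t n )
{
    /* thresholds as encoded in BLIS_GEMM_DYNAMIC_BLOCK_SIZE_UPDATE below */
    return ( nt >= 32 && ( m > 7800 || n > 7800 ) ) ? 512 : 256;
}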
- -#ifdef BLIS_ENABLE_MULTITHREADING - -#define BLIS_GEMM_DYNAMIC_BLOCK_SIZE_UPDATE(cntx, rntm, c) { \ - \ - if (bli_is_double(bli_obj_dt(&c))) { \ - const dim_t nt = rntm->num_threads; \ - const dim_t m = bli_obj_length(&c); \ - const dim_t n = bli_obj_width(&c); \ - \ - blksz_t blkszs[BLIS_NUM_BLKSZS]; \ - if (nt >= 32 && (m > 7800 || n > 7800)) { \ - bli_blksz_init_easy(&blkszs[BLIS_MC], 144, 72, 144, 72 ); \ - bli_blksz_init_easy(&blkszs[BLIS_KC], 256, 512, 256, 256 ); \ - bli_blksz_init_easy(&blkszs[BLIS_NC], 4080, 4080, 4080, 4080 ); \ - \ - bli_cntx_set_blkszs( \ - BLIS_NAT, 3, \ - BLIS_NC, &blkszs[BLIS_NC], BLIS_NR, \ - BLIS_KC, &blkszs[BLIS_KC], BLIS_KR, \ - BLIS_MC, &blkszs[BLIS_MC], BLIS_MR, \ - cntx); \ - } else { \ - bli_blksz_init_easy(&blkszs[BLIS_MC], 144, 72, 144, 72 ); \ - bli_blksz_init_easy(&blkszs[BLIS_KC], 256, 256, 256, 256 ); \ - bli_blksz_init_easy(&blkszs[BLIS_NC], 4080, 4080, 4080, 4080 ); \ - \ - bli_cntx_set_blkszs( \ - BLIS_NAT, 3, \ - BLIS_NC, &blkszs[BLIS_NC], BLIS_NR, \ - BLIS_KC, &blkszs[BLIS_KC], BLIS_KR, \ - BLIS_MC, &blkszs[BLIS_MC], BLIS_MR, \ - cntx); \ - } \ - } \ -} -#else -#define BLIS_GEMM_DYNAMIC_BLOCK_SIZE_UPDATE(cntx, rntm, c) {} -#endif - -// Place holder for bundle configuration. - -#endif - diff --git a/config/amd64_legacy/bli_family_amd64.h b/config/amd64_legacy/bli_family_amd64_legacy.h similarity index 93% rename from config/amd64_legacy/bli_family_amd64.h rename to config/amd64_legacy/bli_family_amd64_legacy.h index 1b4109dfeb..5629b9a2d3 100644 --- a/config/amd64_legacy/bli_family_amd64.h +++ b/config/amd64_legacy/bli_family_amd64_legacy.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020, Advanced Micro Devices, Inc + Copyright (C) 2021, Advanced Micro Devices, Inc Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -33,8 +33,8 @@ */ -#ifndef BLIS_FAMILY_AMD64_LEG_H -#define BLIS_FAMILY_AMD64_LEG_H +#ifndef BLIS_FAMILY_AMD64_LEGACY_H +#define BLIS_FAMILY_AMD64_LEGACY_H // Place holder for bundle configuration. diff --git a/config/amd64_legacy/make_defs.mk b/config/amd64_legacy/make_defs.mk index a2cc80fc5f..5f0d613cbb 100644 --- a/config/amd64_legacy/make_defs.mk +++ b/config/amd64_legacy/make_defs.mk @@ -1,11 +1,11 @@ # # -# BLIS +# BLIS # An object-based framework for developing high-performance BLAS-like # libraries. # # Copyright (C) 2014, The University of Texas at Austin -# Copyright (C) 2020, Advanced Micro Devices, Inc +# Copyright (C) 2021, Advanced Micro Devices, Inc # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are diff --git a/config/amdepyc/bli_family_amdepyc.h b/config/amdepyc/bli_family_amdepyc.h new file mode 100644 index 0000000000..c3f4370692 --- /dev/null +++ b/config/amdepyc/bli_family_amdepyc.h @@ -0,0 +1,64 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. 
+ - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_FAMILY_AMD64_H +#define BLIS_FAMILY_AMD64_H + +// By default, it is effective to parallelize the outer loops. +// Setting these macros to 1 will force JR and IR inner loops +// to be not paralleized. +// +#define BLIS_THREAD_MAX_IR 1 +#define BLIS_THREAD_MAX_JR 1 + + +#define BLIS_ENABLE_SMALL_MATRIX +#define BLIS_ENABLE_SMALL_MATRIX_TRSM + + +// This will select the threshold below which small matrix code will be called. +#define BLIS_SMALL_MATRIX_THRES 700 +#define BLIS_SMALL_M_RECT_MATRIX_THRES 160 +#define BLIS_SMALL_K_RECT_MATRIX_THRES 128 + +#define BLIS_SMALL_MATRIX_A_THRES_M_SYRK 96 +#define BLIS_SMALL_MATRIX_A_THRES_N_SYRK 128 + +// When running HPL with pure MPI without DGEMM threading (Single-threaded +// BLIS), defining this macro as 1 yields better performance. +#define AOCL_BLIS_MULTIINSTANCE 0 + +#endif + diff --git a/config/amd64/make_defs.mk b/config/amdepyc/make_defs.mk similarity index 71% rename from config/amd64/make_defs.mk rename to config/amdepyc/make_defs.mk index 491289ff8d..d7e1b73226 100644 --- a/config/amd64/make_defs.mk +++ b/config/amdepyc/make_defs.mk @@ -5,6 +5,7 @@ # libraries. # # Copyright (C) 2014, The University of Texas at Austin +# Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -35,15 +36,29 @@ # Declare the name of the current configuration and add it to the # running list of configurations included by common.mk. -THIS_CONFIG := amd64 -#CONFIGS_INCL += $(THIS_CONFIG) +THIS_CONFIG := amdepyc -# -# --- Determine the C compiler and related flags --- -# +# For architecture independent files we still need to define +# the required flags +ifneq ($(DEBUG_TYPE),off) +CDBGFLAGS := -g +endif + +ifeq ($(DEBUG_TYPE),noopt) +COPTFLAGS := -O0 +else +COPTFLAGS := -O3 +endif -# These setting should come from makefiles for individial configuration -# included in this bundle. +# This will add BLIS_CONFIG_EPYC for all framework files +# FIXME: framework files should not have architecture specific +# checks at least at compile time. Once the macro +# is defined it is applicable to every build in the +# Family including any non AMD configuration. 
+# However, it is still better to define it in makefiles +# instead of headers so we can have slighly more +# control on this. +COPTFLAGS += -DBLIS_CONFIG_EPYC # Store all of the variables here to new variables containing the # configuration name. diff --git a/config/zen/bli_family_zen.h b/config/zen/bli_family_zen.h index 9e9d09afd9..737ca5c597 100644 --- a/config/zen/bli_family_zen.h +++ b/config/zen/bli_family_zen.h @@ -33,22 +33,15 @@ */ -//#ifndef BLIS_FAMILY_H -//#define BLIS_FAMILY_H - -//To enable framework optimizations for EPYC family processors. -//With this macro defined, we can call kernels directly from -//BLAS interfaces for levels 1 & 2. -//This macro needs to be defined for all EPYC configurations. -#define BLIS_CONFIG_EPYC +#ifndef BLIS_FAMILY_ZEN_H +#define BLIS_FAMILY_ZEN_H // By default, it is effective to parallelize the outer loops. // Setting these macros to 1 will force JR and IR inner loops // to be not paralleized. -#define BLIS_DEFAULT_MR_THREAD_MAX 1 -#define BLIS_DEFAULT_NR_THREAD_MAX 1 +#define BLIS_THREAD_MAX_IR 1 +#define BLIS_THREAD_MAX_JR 1 -#define BLIS_ENABLE_ZEN_BLOCK_SIZES #define BLIS_ENABLE_SMALL_MATRIX #define BLIS_ENABLE_SMALL_MATRIX_TRSM @@ -57,21 +50,7 @@ #define BLIS_SMALL_M_RECT_MATRIX_THRES 160 #define BLIS_SMALL_K_RECT_MATRIX_THRES 128 -#define BLIS_SMALL_MATRIX_THRES_TRSM 32768 //128(128+128) => m*(m+n) #define BLIS_SMALL_MATRIX_A_THRES_M_SYRK 96 #define BLIS_SMALL_MATRIX_A_THRES_N_SYRK 128 -//This macro will enable BLIS DGEMM to choose block sizes for a single instance mode -#define BLIS_ENABLE_SINGLE_INSTANCE_BLOCK_SIZES 0 - -// Allow the sup implementation to combine some small edge case iterations in -// the 2nd loop of the panel-block algorithm (MR) and/or the 2nd loop of the -// block-panel algorithm (NR) with the last full iteration that precedes it. -// NOTE: These cpp macros need to be explicitly set to an integer since they -// are used at compile-time to create unconditional branches or dead code -// regions. -#define BLIS_ENABLE_SUP_MR_EXT 1 -#define BLIS_ENABLE_SUP_NR_EXT 0 - -//#endif - +#endif \ No newline at end of file diff --git a/config/zen/make_defs.mk b/config/zen/make_defs.mk index f600ef8663..be1086a1de 100644 --- a/config/zen/make_defs.mk +++ b/config/zen/make_defs.mk @@ -1,11 +1,11 @@ # # -# BLIS +# BLIS # An object-based framework for developing high-performance BLAS-like # libraries. # # Copyright (C) 2014, The University of Texas at Austin -# Copyright (C) 2019, Advanced Micro Devices, Inc. +# Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -46,6 +46,18 @@ AMD_CONFIG_FILE := amd_config.mk AMD_CONFIG_PATH := $(BASE_SHARE_PATH)/config/zen -include $(AMD_CONFIG_PATH)/$(AMD_CONFIG_FILE) + +# Since we removed BLIS_CONFIG_EPYC from header file, we need to +# add it here at two places, +# CPPROCFLAGS = This will enable it for framework code +# This flag is used when configure is invoked with specific architecture +# CKOPTFLAGS = This will enable it for architecture specific kernels +# This flag is used for kernels assocaited with this architecture +# irrespective of the configuration it is built for. 
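# (Net effect: framework and kernel sources keep testing the very same
# BLIS_CONFIG_EPYC guard they tested when the macro lived in
# bli_family_zen.h; only the place the definition originates from moves
# out of that header and into the make variables below.)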
+ +CPPROCFLAGS := -DBLIS_CONFIG_EPYC + + ifeq ($(DEBUG_TYPE),noopt) COPTFLAGS := -O0 else @@ -74,6 +86,10 @@ else CRVECFLAGS := $(CKVECFLAGS) endif +# Add this after updating variables for reference kernels +# we don't want this defined for them +CKOPTFLAGS += -DBLIS_CONFIG_EPYC + # Store all of the variables here to new variables containing the # configuration name. $(eval $(call store-make-defs,$(THIS_CONFIG))) diff --git a/config/zen2/bli_family_zen2.h b/config/zen2/bli_family_zen2.h index 5b1d68896b..fedc422ad1 100644 --- a/config/zen2/bli_family_zen2.h +++ b/config/zen2/bli_family_zen2.h @@ -36,12 +36,6 @@ #ifndef BLI_FAMILY_ZEN2_ #define BLI_FAMILY_ZEN2_ -//To enable framework optimizations for EPYC family processors. -//With this macro defined, we can call kernels directly from BLAS interfaces -//for levels 1 & 2. -//This macro needs to be defined for a;; EPYC configurations. -#define BLIS_CONFIG_EPYC - // By default, it is effective to parallelize the outer loops. // Setting these macros to 1 will force JR and IR inner loops // to be not paralleized. @@ -59,9 +53,6 @@ #define BLIS_SMALL_MATRIX_A_THRES_M_SYRK 96 #define BLIS_SMALL_MATRIX_A_THRES_N_SYRK 128 -#define BLIS_ENABLE_SMALL_MATRIX_ROME -#define BLIS_SMALL_MATRIX_THRES_ROME 400 - // When running HPL with pure MPI without DGEMM threading (Single-threaded // BLIS), defining this macro as 1 yields better performance. #define AOCL_BLIS_MULTIINSTANCE 0 diff --git a/config/zen2/make_defs.mk b/config/zen2/make_defs.mk index 5205c92911..c936487fc3 100644 --- a/config/zen2/make_defs.mk +++ b/config/zen2/make_defs.mk @@ -1,11 +1,11 @@ # # -# BLIS +# BLIS # An object-based framework for developing high-performance BLAS-like # libraries. # # Copyright (C) 2014, The University of Texas at Austin -# Copyright (C) 2019, Advanced Micro Devices, Inc. +# Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -49,7 +49,16 @@ THIS_CONFIG := zen2 # NOTE: The build system will append these variables with various # general-purpose/configuration-agnostic flags in common.mk. You # may specify additional flags here as needed. -CPPROCFLAGS := + +# Since we removed BLIS_CONFIG_EPYC from header file, we need to +# add it here at two places, +# CPPROCFLAGS = This will enable it for framework code +# This flag is used when configure is invoked with specific architecture +# CKOPTFLAGS = This will enable it for architecture specific kernels +# This flag is used for kernels assocaited with this architecture +# irrespective of the configuration it is built for. + +CPPROCFLAGS := -DBLIS_CONFIG_EPYC CMISCFLAGS := CPICFLAGS := CWARNFLAGS := @@ -102,6 +111,10 @@ endif CROPTFLAGS := $(CKOPTFLAGS) CRVECFLAGS := $(CKVECFLAGS) +# Add this after updating variables for reference kernels +# we don't want this defined for them +CKOPTFLAGS += -DBLIS_CONFIG_EPYC + # Store all of the variables here to new variables containing the # configuration name. $(eval $(call store-make-defs,$(THIS_CONFIG))) diff --git a/config/zen3/bli_family_zen3.h b/config/zen3/bli_family_zen3.h index 3b543b6107..0a6b210d62 100644 --- a/config/zen3/bli_family_zen3.h +++ b/config/zen3/bli_family_zen3.h @@ -44,20 +44,9 @@ #define BLIS_THREAD_MAX_IR 1 #define BLIS_THREAD_MAX_JR 1 -//To enable framework optimizations for EPYC family processors. -//With this macro defined, we can call kernels directly from BLAS interfaces -//for levels 1 & 2. 
-//This macro needs to be defined for all EPYC configurations. -#define BLIS_CONFIG_EPYC - -// To enable framework optimizations for zen3 platform -// All zen3 specific code should be included in this macro -#define BLIS_CONFIG_ZEN3 - #define BLIS_ENABLE_SMALL_MATRIX #define BLIS_ENABLE_SMALL_MATRIX_TRSM - // This will select the threshold below which small matrix code will be called. #define BLIS_SMALL_MATRIX_THRES 700 #define BLIS_SMALL_M_RECT_MATRIX_THRES 160 @@ -66,51 +55,4 @@ #define BLIS_SMALL_MATRIX_A_THRES_M_SYRK 96 #define BLIS_SMALL_MATRIX_A_THRES_N_SYRK 128 -#define BLIS_ENABLE_SMALL_MATRIX_ROME -#define BLIS_SMALL_MATRIX_THRES_ROME 400 - - -// For zen3 architecture we dynamically change block sizes -// based on number of threads. These values were determined -// by running benchmarks on zen3 platform. - -#ifdef BLIS_ENABLE_MULTITHREADING - -#define BLIS_GEMM_DYNAMIC_BLOCK_SIZE_UPDATE(cntx, rntm, c) { \ - \ - if (bli_is_double(bli_obj_dt(&c))) { \ - const dim_t nt = rntm->num_threads; \ - const dim_t m = bli_obj_length(&c); \ - const dim_t n = bli_obj_width(&c); \ - \ - blksz_t blkszs[BLIS_NUM_BLKSZS]; \ - if (nt >= 32 && (m > 7800 || n > 7800)) { \ - bli_blksz_init_easy(&blkszs[BLIS_MC], 144, 72, 144, 72 ); \ - bli_blksz_init_easy(&blkszs[BLIS_KC], 256, 512, 256, 256 ); \ - bli_blksz_init_easy(&blkszs[BLIS_NC], 4080, 4080, 4080, 4080 ); \ - \ - bli_cntx_set_blkszs( \ - BLIS_NAT, 3, \ - BLIS_NC, &blkszs[BLIS_NC], BLIS_NR, \ - BLIS_KC, &blkszs[BLIS_KC], BLIS_KR, \ - BLIS_MC, &blkszs[BLIS_MC], BLIS_MR, \ - cntx); \ - } else { \ - bli_blksz_init_easy(&blkszs[BLIS_MC], 144, 72, 144, 72 ); \ - bli_blksz_init_easy(&blkszs[BLIS_KC], 256, 256, 256, 256 ); \ - bli_blksz_init_easy(&blkszs[BLIS_NC], 4080, 4080, 4080, 4080 ); \ - \ - bli_cntx_set_blkszs( \ - BLIS_NAT, 3, \ - BLIS_NC, &blkszs[BLIS_NC], BLIS_NR, \ - BLIS_KC, &blkszs[BLIS_KC], BLIS_KR, \ - BLIS_MC, &blkszs[BLIS_MC], BLIS_MR, \ - cntx); \ - } \ - } \ -} -#else -#define BLIS_GEMM_DYNAMIC_BLOCK_SIZE_UPDATE(cntx, rntm, c) {} -#endif - #endif diff --git a/config/zen3/make_defs.mk b/config/zen3/make_defs.mk index 149100b802..bc36e6ae94 100644 --- a/config/zen3/make_defs.mk +++ b/config/zen3/make_defs.mk @@ -1,11 +1,11 @@ # # -# BLIS +# BLIS # An object-based framework for developing high-performance BLAS-like # libraries. # # Copyright (C) 2014, The University of Texas at Austin -# Copyright (C) 2020, Advanced Micro Devices, Inc. +# Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -39,7 +39,7 @@ # Declare the name of the current configuration and add it to the # running list of configurations included by common.mk. -THIS_CONFIG := zen3 +THIS_CONFIG := zen3 #CONFIGS_INCL += $(THIS_CONFIG) # @@ -49,7 +49,16 @@ THIS_CONFIG := zen3 # NOTE: The build system will append these variables with various # general-purpose/configuration-agnostic flags in common.mk. You # may specify additional flags here as needed. -CPPROCFLAGS := + +# Since we removed BLIS_CONFIG_EPYC from header file, we need to +# add it here at two places, +# CPPROCFLAGS = This will enable it for framework code +# This flag is used when configure is invoked with specific architecture +# CKOPTFLAGS = This will enable it for architecture specific kernels +# This flag is used for kernels assocaited with this architecture +# irrespective of the configuration it is built for. 
+ +CPPROCFLAGS := -DBLIS_CONFIG_EPYC CMISCFLAGS := CPICFLAGS := CWARNFLAGS := @@ -119,6 +128,10 @@ endif # gcc CROPTFLAGS := $(CKOPTFLAGS) CRVECFLAGS := $(CKVECFLAGS) +# Add this after updating variables for reference kernels +# we don't want this defined for them +CKOPTFLAGS += -DBLIS_CONFIG_EPYC + # Store all of the variables here to new variables containing the # configuration name. $(eval $(call store-make-defs,$(THIS_CONFIG))) diff --git a/config_registry b/config_registry index f606e09e62..97dbcf5ae5 100644 --- a/config_registry +++ b/config_registry @@ -11,7 +11,7 @@ x86_64: intel64 amd64 amd64_legacy intel64: skx knl haswell sandybridge penryn generic amd64_legacy: excavator steamroller piledriver bulldozer generic -amd64: zen3 zen2 zen generic +amdepyc: zen3 zen2 zen generic # NOTE: ARM families will remain disabled until runtime hardware detection # logic is added to BLIS. diff --git a/frame/base/bli_arch.c b/frame/base/bli_arch.c index b52781f056..3df2c3688b 100644 --- a/frame/base/bli_arch.c +++ b/frame/base/bli_arch.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018-2020, Advanced Micro Devices, Inc. + Copyright (C) 2018-2021, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -124,10 +124,11 @@ void bli_arch_set_id( void ) // selection behavior. // Architecture families. - #if defined BLIS_FAMILY_INTEL64 || \ - defined BLIS_FAMILY_AMD64 || \ - defined BLIS_FAMILY_X86_64 || \ - defined BLIS_FAMILY_ARM64 || \ + #if defined BLIS_FAMILY_INTEL64 || \ + defined BLIS_FAMILY_AMDEPYC || \ + defined BLIS_FAMILY_AMD64_LEGACY || \ + defined BLIS_FAMILY_X86_64 || \ + defined BLIS_FAMILY_ARM64 || \ defined BLIS_FAMILY_ARM32 id = bli_cpuid_query_id(); #endif diff --git a/frame/base/bli_cntx.c b/frame/base/bli_cntx.c index 82087b696a..2ff56c0ba6 100644 --- a/frame/base/bli_cntx.c +++ b/frame/base/bli_cntx.c @@ -1263,9 +1263,6 @@ void bli_cntx_set_l3_sup_kers( dim_t n_ukrs, ... ) } // ----------------------------------------------------------------------------- - -#ifdef AOCL_BLIS_ZEN - void bli_cntx_set_trsm_blkszs( dim_t n_bs, ... ) { // This function should be called from the bli_cntx_init_*() function for @@ -1366,7 +1363,7 @@ void bli_cntx_set_trsm_blkszs( dim_t n_bs, ... ) #endif bli_free_intl( bszids ); } -#endif + // ----------------------------------------------------------------------------- void bli_cntx_set_l1f_kers( dim_t n_kers, ... ) diff --git a/frame/base/bli_cpuid.c b/frame/base/bli_cpuid.c index 11e27123ae..4b3837544f 100644 --- a/frame/base/bli_cpuid.c +++ b/frame/base/bli_cpuid.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018-2020, Advanced Micro Devices, Inc. + Copyright (C) 2018-2021, Advanced Micro Devices, Inc. All rights reserved. Copyright (C) 2019, Dave Love, University of Manchester Redistribution and use in source and binary forms, with or without @@ -77,7 +77,7 @@ arch_t bli_cpuid_query_id( void ) printf( "vendor = %s\n", vendor==1 ? 
"AMD": "INTEL" ); printf("family = %x\n", family ); printf( "model = %x\n", model ); - + printf( "features = %x\n", features ); #endif @@ -117,7 +117,7 @@ arch_t bli_cpuid_query_id( void ) #ifdef BLIS_CONFIG_ZEN3 if ( bli_cpuid_is_zen3( family, model, features ) ) return BLIS_ARCH_ZEN3; -#endif +#endif #ifdef BLIS_CONFIG_ZEN2 if ( bli_cpuid_is_zen2( family, model, features ) ) return BLIS_ARCH_ZEN2; @@ -282,10 +282,12 @@ bool bli_cpuid_is_zen3 if ( family != 0x19 ) return FALSE; // Finally, check for specific models: - // - 0x00-0xff (THIS NEEDS UPDATING) + // Zen 3 maps to couple of different model number ranges + // we check for all of them. const bool is_arch = - ( 0x00 <= model && model <= 0xff ); + (( model <= 0x0f ) || + (0x30 <= model && model <= 0x3f )); if ( !is_arch ) return FALSE; @@ -310,7 +312,6 @@ bool bli_cpuid_is_zen2 if ( family != 0x17 ) return FALSE; // Finally, check for specific models: - // - 0x30-0xff (THIS NEEDS UPDATING) const bool is_arch = ( 0x30 <= model && model <= 0xff ); @@ -338,10 +339,7 @@ bool bli_cpuid_is_zen if ( family != 0x17 ) return FALSE; // Finally, check for specific models: - // - 0x00-0xff (THIS NEEDS UPDATING) - const bool is_arch - = - ( 0x00 <= model && model <= 0xff ); + const bool is_arch = (model <= 0x30 ); if ( !is_arch ) return FALSE; @@ -811,7 +809,7 @@ uint32_t bli_cpuid_query if ( bli_cpuid_has_features( ecx, FEATURE_MASK_AVX ) ) *features |= FEATURE_AVX; if ( bli_cpuid_has_features( ecx, FEATURE_MASK_FMA3 ) ) *features |= FEATURE_FMA3; - // Check whether the hardware supports xsave/xrestor/xsetbv/xgetbv AND + // Check whether the hardware supports xsave/xrestor/xsetbv/xgetbv AND // support for these is enabled by the OS. If so, then we proceed with // checking that various register-state saving features are available. if ( bli_cpuid_has_features( ecx, FEATURE_MASK_XGETBV ) ) @@ -843,7 +841,7 @@ uint32_t bli_cpuid_query // The OS can manage the state of 512-bit zmm (AVX-512) registers // only if the xcr[7:5] bits are set. If they are not set, then - // clear all feature bits related to AVX-512. + // clear all feature bits related to AVX-512. if ( !bli_cpuid_has_features( eax, XGETBV_MASK_XMM | XGETBV_MASK_YMM | XGETBV_MASK_ZMM ) ) @@ -859,7 +857,7 @@ uint32_t bli_cpuid_query // The OS can manage the state of 256-bit ymm (AVX) registers // only if the xcr[2] bit is set. If it is not set, then - // clear all feature bits related to AVX. + // clear all feature bits related to AVX. if ( !bli_cpuid_has_features( eax, XGETBV_MASK_XMM | XGETBV_MASK_YMM ) ) { @@ -872,7 +870,7 @@ uint32_t bli_cpuid_query // The OS can manage the state of 128-bit xmm (SSE) registers // only if the xcr[1] bit is set. If it is not set, then // clear all feature bits related to SSE (which means the - // entire bitfield is clear). + // entire bitfield is clear). if ( !bli_cpuid_has_features( eax, XGETBV_MASK_XMM ) ) { *features = 0; diff --git a/frame/compat/bla_gemm.c b/frame/compat/bla_gemm.c index d071dc75dc..0e77c3bb1f 100644 --- a/frame/compat/bla_gemm.c +++ b/frame/compat/bla_gemm.c @@ -457,7 +457,7 @@ void dgemm_ //cntx_t* cntx = bli_gks_query_cntx(); //dim_t nt = bli_thread_get_num_threads(); // get number of threads bool nt = bli_thread_get_is_parallel(); // Check if parallel dgemm is invoked. - + // if m0 is large and (n0 & k0) < 10 - SMALL GEMM - ST is better // @@ -488,7 +488,7 @@ void dgemm_ // The code below will be called when number of threads = 1. 
#ifdef BLIS_ENABLE_SMALL_MATRIX - + //if( ((m0 + n0 -k0) < 2000) && ((m0 + k0-n0) < 2000) && ((n0 + k0-m0) < 2000) && (n0 > 2)) if( ( ( (m0 + n0 -k0) < 2000) && ((m0 + k0-n0) < 2000) && ((n0 + k0-m0) < 2000) ) || ((n0 <= 10) && (k0 <=10)) ) diff --git a/frame/include/bli_arch_config.h b/frame/include/bli_arch_config.h index 93904253ce..b341eaee3c 100644 --- a/frame/include/bli_arch_config.h +++ b/frame/include/bli_arch_config.h @@ -6,7 +6,7 @@ Copyright (C) 2014, The University of Texas at Austin Copyright (C) 2016, Hewlett Packard Enterprise Development LP - Copyright (C) 2019 - 2020, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2021, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -136,8 +136,11 @@ CNTX_INIT_PROTS( generic ) #ifdef BLIS_FAMILY_INTEL64 #include "bli_family_intel64.h" #endif -#ifdef BLIS_FAMILY_AMD64 -#include "bli_family_amd64.h" +#ifdef BLIS_FAMILY_AMDEPYC +#include "bli_family_amdepyc.h" +#endif +#ifdef BLIS_FAMILY_AMD64_LEGACY +#include "bli_family_amd64_legacy.h" #endif #ifdef BLIS_FAMILY_X86_64 #include "bli_family_x86_64.h" From bdb5e32176c91a5a59fc48b1a87548d3ed98c008 Mon Sep 17 00:00:00 2001 From: Harihara Sudhan S Date: Thu, 26 Aug 2021 20:24:13 +0530 Subject: [PATCH 004/243] Level 1 Kernel: damaxv AVX512 Details: - Developed damaxv for AVX512 extension - Implemented removeNAN function that converts NAN values to negative values based on the location - Usage COMPARE256/COMPARE128 avoided in AVX512 implementation for better performance - Unrolled the loop by order of 4. Change-Id: Icf2a3606cf311ecc646aeb3db0628b293b9a3326 --- config/skx/bli_cntx_init_skx.c | 2 +- kernels/zen/1/bli_amaxv_zen_int.c | 347 +++++++++++++++++++++++++++++- kernels/zen/bli_kernels_zen.h | 1 + 3 files changed, 348 insertions(+), 2 deletions(-) diff --git a/config/skx/bli_cntx_init_skx.c b/config/skx/bli_cntx_init_skx.c index f18503a7a7..302ea63562 100644 --- a/config/skx/bli_cntx_init_skx.c +++ b/config/skx/bli_cntx_init_skx.c @@ -74,7 +74,7 @@ void bli_cntx_init_skx( cntx_t* cntx ) #if 1 // amaxv BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int, - BLIS_AMAXV_KER, BLIS_DOUBLE, bli_damaxv_zen_int, + BLIS_AMAXV_KER, BLIS_DOUBLE, bli_damaxv_zen_int_avx512, #endif // axpyv #if 0 diff --git a/kernels/zen/1/bli_amaxv_zen_int.c b/kernels/zen/1/bli_amaxv_zen_int.c index e72705340e..d6bf24cf62 100644 --- a/kernels/zen/1/bli_amaxv_zen_int.c +++ b/kernels/zen/1/bli_amaxv_zen_int.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2016 - 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2016 - 2021, Advanced Micro Devices, Inc. Copyright (C) 2018, The University of Texas at Austin Redistribution and use in source and binary forms, with or without @@ -37,6 +37,15 @@ #include "blis.h" +/* Union data structure to access AVX registers + One 512-bit AVX register holds 8 DP elements. */ +typedef union +{ + __m512d v; + double d[8] __attribute__((aligned(64))); +} v8df_t; + + /* Union data structure to access AVX registers One 256-bit AVX register holds 8 SP elements. 
*/ typedef union @@ -521,3 +530,339 @@ void PASTEMAC(ch,varname) \ GENTFUNCR( scomplex, float, c, s, amaxv_zen_int ) GENTFUNCR( dcomplex, double, z, d, amaxv_zen_int ) #endif + + +/* Converts all the NAN to a negative number less than previously encountered NANs*/ +__m512d remove_NAN_512d(__m512d vec) +{ + + static int iter; + static __m512d sign_mask; + + __m512d vec_mask; + __m512i int_mask_vec; + __mmask8 vec_mask8; + + iter = iter - 1; + + sign_mask = _mm512_set1_pd(-0.f); + + //numbers other than NAN will become 0 + vec_mask = _mm512_mul_pd(vec, sign_mask); + + //producing an 8-bit mask + int_mask_vec = _mm512_castpd_si512(vec_mask); + vec_mask8 = _mm512_movepi64_mask(int_mask_vec); + + //replacing all the NAN with negative numbers + vec = _mm512_mask_blend_pd(vec_mask8, _mm512_set1_pd(-1 + iter), vec); + + return vec; +} + +//---------------------------------------------------------------------------------------------------- + +void bli_damaxv_zen_int_avx512( + dim_t n, + double *restrict x, inc_t incx, + dim_t *restrict i_max, + cntx_t *restrict cntx) +{ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3) + double *minus_one = PASTEMAC(d, m1); + dim_t *zero_i = PASTEMAC(i, 0); + + double chi1_r; + //double chi1_i; + double abs_chi1; + double abs_chi1_max; + dim_t i_max_l; + dim_t i; + + /* If the vector length is zero, return early. This directly emulates + the behavior of netlib BLAS's i?amax() routines. */ + if (bli_zero_dim1(n)) + { + PASTEMAC(i, copys) + (*zero_i, *i_max); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3) + return; + } + + /* Initialize the index of the maximum absolute value to zero. */ + PASTEMAC(i, copys) + (*zero_i, i_max_l); + + /* Initialize the maximum absolute value search candidate with + -1, which is guaranteed to be less than all values we will + compute. */ + PASTEMAC(d, copys) + (*minus_one, abs_chi1_max); + + // For non-unit strides, or very small vector lengths, compute with + // scalar code. + if (incx != 1 || n < 8) + { + for (i = 0; i < n; ++i) + { + double *chi1 = x + (i)*incx; + + /* Get the real and imaginary components of chi1. */ + chi1_r = *chi1; + + /* Replace chi1_r and chi1_i with their absolute values. */ + chi1_r = fabs(chi1_r); + + /* Add the real and imaginary absolute values together. */ + abs_chi1 = chi1_r; + + /* If the absolute value of the current element exceeds that of + the previous largest, save it and its index. If NaN is + encountered, then treat it the same as if it were a valid + value that was smaller than any previously seen. This + behavior mimics that of LAPACK's i?amax(). 
*/ + if (abs_chi1_max < abs_chi1 || (isnan(abs_chi1) && !isnan(abs_chi1_max))) + { + abs_chi1_max = abs_chi1; + i_max_l = i; + } + } + } + else + { + + dim_t iterations, n_left, vector_length = 8, unrollCount = 0; + + //mask bits + __mmask8 mask_got_01, mask_got_23; + + //YMM0 - YMM6 registers + v4df_t max_hi, max_lo, max_ind_hi, max_ind_lo, + mask_final, inter_result, inter_ind; + + //XMM0 to XMM4 registers + v2dd_t max_vec_hi, max_vec_lo, max_ind_hi_128, + max_ind_lo_128, mask_vec_lo; + + //ZMM0 to ZMM13 registers + v8df_t zmm0, zmm1, zmm2, zmm3, zmm4_Ind, + zmm5_Ind, zmm6_Ind, zmm7_Ind, max_01, + max_23, final_max, max_array, max_ind, inc_vec; + + //ZMM14 to ZMM16 registers + __m512d mask_01, mask_23, sign_mask; + + //Intermediate int mask values + __m512i int_mask_01, int_mask_23; + + // Initialize sign mask + sign_mask = _mm512_set1_pd(-0.f); + + //Initializing the indexes of the base case of max vector + zmm4_Ind.v = _mm512_set_pd(7, 6, 5, 4, 3, 2, 1, 0); + inc_vec.v = _mm512_set1_pd(8); //Vector for incrementing + + // Initializing the max array as vec [ 0 : 512 ] + max_array.v = _mm512_loadu_pd(x); + + // Taking the absolute value and removing the NAN + max_array.v = _mm512_andnot_pd(sign_mask, max_array.v); + max_array.v = remove_NAN_512d(max_array.v); + + // Initializing the maximumum index + max_ind.v = _mm512_set_pd(7, 6, 5, 4, 3, 2, 1, 0); + x += vector_length; + + //Incrementing to make the vector + //to point to the next 8 elements + zmm4_Ind.v = _mm512_add_pd(zmm4_Ind.v, inc_vec.v); + + /* Loop unrolled by a factor of 4 + At the end of the loop max_array holds the largest element + in each corresponding vector index */ + for (unrollCount = 8; (unrollCount + 31) < n; unrollCount += 32) + { + // Taking 32 elements + // Taking only the absolute values of the registers + // Removing the NAN values and replacing it + // with negative numbers + zmm0.v = _mm512_loadu_pd(x); + zmm0.v = _mm512_andnot_pd(sign_mask, zmm0.v); + zmm0.v = remove_NAN_512d(zmm0.v); + x += vector_length; + + zmm1.v = _mm512_loadu_pd(x); + zmm5_Ind.v = _mm512_add_pd(zmm4_Ind.v, inc_vec.v); + zmm1.v = _mm512_andnot_pd(sign_mask, zmm1.v); + zmm1.v = remove_NAN_512d(zmm1.v); + x += vector_length; + + zmm2.v = _mm512_loadu_pd(x); + zmm6_Ind.v = _mm512_add_pd(zmm5_Ind.v, inc_vec.v); + zmm2.v = _mm512_andnot_pd(sign_mask, zmm2.v); + zmm2.v = remove_NAN_512d(zmm2.v); + x += vector_length; + + zmm3.v = _mm512_loadu_pd(x); + zmm7_Ind.v = _mm512_add_pd(zmm6_Ind.v, inc_vec.v); + zmm3.v = _mm512_andnot_pd(sign_mask, zmm3.v); + zmm3.v = remove_NAN_512d(zmm3.v); + x += vector_length; + + /*Using sub function to generating the mask + as a 512d type*/ + mask_01 = _mm512_sub_pd(zmm0.v, zmm1.v); + mask_23 = _mm512_sub_pd(zmm2.v, zmm3.v); + + //Converting the 512d mask to a 512i mask + int_mask_01 = _mm512_castpd_si512(mask_01); + int_mask_23 = _mm512_castpd_si512(mask_23); + + /*Converting the 512i mask + to mmask type to use the mask bits*/ + mask_got_01 = _mm512_movepi64_mask(int_mask_01); + mask_got_23 = _mm512_movepi64_mask(int_mask_23); + + //Storing the largest elements in index % 8 position for + //vector 1 and 2, and the index of the corresponding element + max_01.v = _mm512_mask_blend_pd(mask_got_01, zmm0.v, zmm1.v); + zmm5_Ind.v = _mm512_mask_blend_pd(mask_got_01, zmm4_Ind.v, zmm5_Ind.v); + + //Storing the largest elements in index % 8 position for + //vector 3 and 4, and the index of the corresponding element + max_23.v = _mm512_mask_blend_pd(mask_got_23, zmm2.v, zmm3.v); + zmm6_Ind.v = 
_mm512_mask_blend_pd(mask_got_23, zmm6_Ind.v, zmm7_Ind.v); + + //Generating mask for the intermediate max vector + mask_01 = _mm512_sub_pd(max_01.v, max_23.v); + int_mask_01 = _mm512_castpd_si512(mask_01); + mask_got_01 = _mm512_movepi64_mask(int_mask_01); + + /*Storing the largest elements in index % 8 position for + the intermediate max vectors, + and the index of the corresponding element*/ + final_max.v = _mm512_mask_blend_pd(mask_got_01, max_01.v, max_23.v); + zmm5_Ind.v = _mm512_mask_blend_pd(mask_got_01, zmm5_Ind.v, zmm6_Ind.v); + + //Generating the mask for final max vector and base max vector + mask_01 = _mm512_sub_pd(max_array.v, final_max.v); + int_mask_01 = _mm512_castpd_si512(mask_01); + mask_got_01 = _mm512_movepi64_mask(int_mask_01); + + // Result is the maximum of all index % 8 locations + max_array.v = _mm512_mask_blend_pd(mask_got_01, max_array.v, final_max.v); + max_ind.v = _mm512_mask_blend_pd(mask_got_01, max_ind.v, zmm5_Ind.v); + + // Incrementing the index to point to the next 8 locations + zmm4_Ind.v = _mm512_add_pd(zmm7_Ind.v, inc_vec.v); + } + + // Calculating the number of iterations left + iterations = (n - unrollCount) / vector_length; + n_left = (n - unrollCount) % vector_length; + + /* At the end of the loop max_array holds the largest element + in each corresponding vector index */ + for (int i = 1; i < iterations; ++i) + { + // Taking 32 elements + // Taking only the absolute values of the registers + // Removing the NAN values and replacing it + // with negative numbers + zmm0.v = _mm512_loadu_pd(x); + zmm0.v = _mm512_abs_pd(zmm0.v); + zmm0.v = remove_NAN_512d(zmm0.v); + + //Generating mask for the intermediate max vector + mask_01 = _mm512_sub_pd(max_array.v, zmm0.v); + int_mask_01 = _mm512_castpd_si512(mask_01); + mask_got_01 = _mm512_movepi64_mask(int_mask_01); + + // Result is the maximum of all index % 8 locations + max_array.v = _mm512_mask_blend_pd(mask_got_01, max_array.v, zmm0.v); + + //Storing the index of the corresponding max array elemets + max_ind.v = _mm512_mask_blend_pd(mask_got_01, max_ind.v, zmm4_Ind.v); + + //Incrementing the vector the point to the next location + //Incrementing the vector indexes + x += vector_length; + zmm4_Ind.v = _mm512_add_pd(zmm4_Ind.v, inc_vec.v); + } + + //Breaking max array into vectors of length 4 + //Taking upper and lower halves + max_hi.v = _mm512_extractf64x4_pd(max_array.v, 1); + max_ind_hi.v = _mm512_extractf64x4_pd(max_ind.v, 1); + max_lo.v = _mm512_extractf64x4_pd(max_array.v, 0); + max_ind_lo.v = _mm512_extractf64x4_pd(max_ind.v, 0); + + //Generating the mask for blending + mask_final.v = _mm256_sub_pd(max_hi.v, max_lo.v); + + // Storing the max of max array index % 4 + inter_result.v = _mm256_blendv_pd(max_hi.v, max_lo.v, mask_final.v); + inter_ind.v = _mm256_blendv_pd(max_ind_hi.v, max_ind_lo.v, mask_final.v); + + //Breaking max array into vectors of length 2 + max_vec_lo.v = _mm256_extractf128_pd(inter_result.v, 0); + max_vec_hi.v = _mm256_extractf128_pd(inter_result.v, 1); + max_ind_hi_128.v = _mm256_extractf128_pd(inter_ind.v, 1); + max_ind_lo_128.v = _mm256_extractf128_pd(inter_ind.v, 0); + + //Generating the mask for blending + mask_vec_lo.v = _mm_sub_pd(max_vec_lo.v, max_vec_hi.v); + + // Storing the max of max array index % 2 + max_vec_lo.v = _mm_blendv_pd(max_vec_lo.v, max_vec_hi.v, mask_vec_lo.v); + max_ind_lo_128.v = _mm_blendv_pd(max_ind_lo_128.v, max_ind_hi_128.v, mask_vec_lo.v); + + max_vec_hi.v = _mm_permute_pd(max_vec_lo.v, 1); + max_ind_hi_128.v = _mm_permute_pd(max_ind_lo_128.v, 1); 
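        /*
         * (The sub/blend pairs used throughout this reduction emulate a
         *  greater-than compare: since the NaNs were rewritten as negative
         *  sentinels up front, the sign bit of (a - b) is exactly the
         *  "a < b" mask that the blend consumes.  Each selection step is
         *  therefore the lane-wise equivalent of the scalar
         *
         *      if( ( a - b ) < 0.0 ) { a = b; a_idx = b_idx; }
         *
         *  applied to the value vector and its index vector together.)
         */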
+ + //Performing work of CMP128 i.e generating mask + mask_vec_lo.v = _mm_sub_pd(max_vec_lo.v, max_vec_hi.v); + + //Finding the maximum element + max_vec_lo.v = _mm_blendv_pd(max_vec_lo.v, max_vec_hi.v, mask_vec_lo.v); + max_ind_lo_128.v = _mm_blendv_pd(max_ind_lo_128.v, max_ind_hi_128.v, mask_vec_lo.v); + + abs_chi1_max = max_vec_lo.d[0]; + + //If the largest number is negative it is NAN + if (abs_chi1_max < 0) + abs_chi1_max = NAN; + + i_max_l = max_ind_lo_128.d[0]; + + for (i = n - n_left; i < n; i++) + { + double *chi1 = x; + + /* Get the real and imaginary components of chi1. */ + chi1_r = *chi1; + + /* Replace chi1_r and chi1_i with their absolute values. */ + abs_chi1 = fabs(chi1_r); + + /* If the absolute value of the current element exceeds that of + the previous largest, save it and its index. If NaN is + encountered, return the index of the first NaN. This + behavior mimics that of LAPACK's i?amax(). */ + if (abs_chi1_max < abs_chi1 || (isnan(abs_chi1) && !isnan(abs_chi1_max))) + { + abs_chi1_max = abs_chi1; + i_max_l = i; + } + + x += 1; + } + } + + // Return value + *i_max = i_max_l; + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3) +} + +// --------------------------------------------------------------------------------- diff --git a/kernels/zen/bli_kernels_zen.h b/kernels/zen/bli_kernels_zen.h index 87ce504d86..0daa54936e 100644 --- a/kernels/zen/bli_kernels_zen.h +++ b/kernels/zen/bli_kernels_zen.h @@ -45,6 +45,7 @@ PACKM_KER_PROT(double, d, packm_6xk_nn_zen) // amaxv (intrinsics) AMAXV_KER_PROT( float, s, amaxv_zen_int ) AMAXV_KER_PROT( double, d, amaxv_zen_int ) +AMAXV_KER_PROT( double, d, amaxv_zen_int_avx512 ) // axpyv (intrinsics) AXPYV_KER_PROT( float, s, axpyv_zen_int ) From 7787bc79b11ad4221af6a48235c3f420ee504e2a Mon Sep 17 00:00:00 2001 From: Abhiram S Date: Thu, 26 Aug 2021 19:23:44 +0530 Subject: [PATCH 005/243] Level1 samaxv: AVX512 implementation Details: 1. Unrolled by a factor 5. This gave around 1GFLOPS gain 2. Changed CMP to subs and remove nan. CMP uses a lot of compare, which is higher in latency and more number of instructions. Replacing with subs and remove nan reduced it to 3 instructions and lighter ones. 3. Added remove nan function. 4. Added AVX512 definition in skx context. 5. 
Disabled code in AMAXV kernel depending on AVX512 flag exists or not Change-Id: I191725a55bc33edf8d537156292cf997d6a5fe35 --- config/skx/bli_cntx_init_skx.c | 2 +- config/skx/bli_family_skx.h | 2 + kernels/zen/1/bli_amaxv_zen_int.c | 560 +++++++++++++++++++++++++++++- kernels/zen/bli_kernels_zen.h | 1 + 4 files changed, 562 insertions(+), 3 deletions(-) diff --git a/config/skx/bli_cntx_init_skx.c b/config/skx/bli_cntx_init_skx.c index 302ea63562..c14311bf21 100644 --- a/config/skx/bli_cntx_init_skx.c +++ b/config/skx/bli_cntx_init_skx.c @@ -73,7 +73,7 @@ void bli_cntx_init_skx( cntx_t* cntx ) 10, #if 1 // amaxv - BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int, + BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int_avx512, BLIS_AMAXV_KER, BLIS_DOUBLE, bli_damaxv_zen_int_avx512, #endif // axpyv diff --git a/config/skx/bli_family_skx.h b/config/skx/bli_family_skx.h index ac9478f8ba..cbba06358e 100644 --- a/config/skx/bli_family_skx.h +++ b/config/skx/bli_family_skx.h @@ -50,6 +50,8 @@ #define BLIS_SIMD_SIZE 64 #define BLIS_SIMD_NUM_REGISTERS 32 +#define AVX512 + //#include //#define BLIS_MALLOC_POOL malloc diff --git a/kernels/zen/1/bli_amaxv_zen_int.c b/kernels/zen/1/bli_amaxv_zen_int.c index d6bf24cf62..8487bdce4b 100644 --- a/kernels/zen/1/bli_amaxv_zen_int.c +++ b/kernels/zen/1/bli_amaxv_zen_int.c @@ -36,7 +36,9 @@ #include "immintrin.h" #include "blis.h" - +// Disable for all context without AVX512 support +// Please define it in bli_family_xxx.h in config directory if there is AVX512 support +#ifdef AVX512 /* Union data structure to access AVX registers One 512-bit AVX register holds 8 DP elements. */ typedef union @@ -45,6 +47,14 @@ typedef union double d[8] __attribute__((aligned(64))); } v8df_t; +/* Union data structure to access AVX registers + One 512-bit AVX register holds 16 SP elements. */ +typedef union +{ + __m512 v; + float f[16] __attribute__((aligned(64))); +} v16sf_t; +#endif /* Union data structure to access AVX registers One 256-bit AVX register holds 8 SP elements. */ @@ -74,6 +84,42 @@ typedef union double d[2]; }v2dd_t; +// Disable for all context without AVX512 support +// Please define it in bli_family_xxx.h in config directory if there is AVX512 support +#ifdef AVX512 +/* Convert the nan to -ve numbers decrementing with + the times the function is called to ensure that + bigger numbers are assigned for nan which showed + up first.*/ +__m512 remove_NAN_512_s(__m512 vec) +{ + // Sign extraction mask + __m512 sign_mask; + // Temporary place to store vector's sign extracted 16xdouble word + __m512 vec_mask; + // k register to store the mask to do blend operation to remove NAN + __mmask16 vec_mask16; + // Static to preserve accross the function calls + static int iter = -1; + iter -= 1; + + // Extracting sign from the vec into int_mask_vec + // Sign is -0.f in IEEE754 is just signbit set, all others 0 + sign_mask = _mm512_set1_ps(-0.f); + // And with -0.f will keep just signbits, all others will be 0 + vec_mask = _mm512_mul_ps(vec, sign_mask); + // Typecast mask into int type no clock cycle is taken just to + // convince compiler. 
+ __m512i int_mask_vec = _mm512_castps_si512(vec_mask); + // Extract the signbits and put it in a 16bit mask register + vec_mask16 = _mm512_movepi32_mask(int_mask_vec); + + // Swap NAN with -ve number + vec = _mm512_mask_blend_ps(vec_mask16, _mm512_set1_ps(iter), vec); + return vec; +} +#endif + // return a mask which indicates either: // - v1 > v2 // - v1 is NaN and v2 is not @@ -274,6 +320,509 @@ void bli_samaxv_zen_int AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3) } +// Disable for all context without AVX512 support +// Please define it in bli_family_xxx.h in config directory if there is AVX512 support +#ifdef AVX512 +void bli_samaxv_zen_int_avx512( + dim_t n, + float *restrict x, inc_t incx, + dim_t *restrict i_max, + cntx_t *restrict cntx) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_3) + // *minus_one = -1 + float *minus_one = PASTEMAC(s, m1); // bli_sm1() + // *zero_i = 0 + dim_t *zero_i = PASTEMAC(i, 0); // bli_i0() + + float fndMaxVal; // Max value will be stored in this + dim_t fndInd; // Max value's index will be stored in this + // Iterator for loops to keep continuity throughout the loops + dim_t i; + + /* If the vector length is zero, return early. This directly emulates + the behavior of netlib BLAS's i?amax() routines. */ + if (bli_zero_dim1(n)) + { + /* Set i_max to zero if dimension is 0, no need to compute */ + // Copy zero_i, that is 0 to i_max (i_max = 0) + PASTEMAC(i, copys) // bli_icopys + (*zero_i, *i_max); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3) + return; + } + + /* Initialize the index of the maximum absolute value to zero. */ + // Copy zero_i, that is 0 to fndInd (fndInd = 0) + PASTEMAC(i, copys) // bli_icopys + (*zero_i, fndInd); + + /* Initialize the maximum absolute value search candidate with + -1, which is guaranteed to be less than all values we will + compute. */ + // Copy minus_one to fndMaxVal real and imaginary. + PASTEMAC(s, copys) // bli_scopys + (*minus_one, fndMaxVal); + + // For non-unit strides, or very small vector lengths, compute with + // scalar code. + // n is less than the single vector length or non unit stride. + if (incx != 1 || n < 16) + { + for (i = 0; i < n; ++i) + { + // Call math.h fabsf to take absolute value of *(x +(i)*incx) + float absval = fabsf(*(x + (i)*incx)); + if (fndMaxVal < absval || (isnan(absval) && !isnan(fndMaxVal))) + { + // If max value is found, set the value and index + fndMaxVal = absval; + fndInd = i; + } + } + } + else + { + dim_t num_iter, num_remain; + dim_t num_vector_elements = 16; + /* Total Registers used is + * xmm0-xmm4 + * ymm5-ymm9 + * zmm10-zmm26 + * There are 6 free registers to use + */ + // zmm register 15x + v16sf_t x_vec_1, x_vec_2, x_vec_3, max_vec_1, max_vec_2, + max_vec_3, maxInd_vec_1, maxInd_vec_2, + maxInd_vec_3, index_vec_1, ind_vec_2, + ind_vec_3, inc_vec, mask, + abs_mask; + // ymm register 5x + v8sf_t max_vec_lo, max_vec_hi, + maxInd_vec_lo, maxInd_vec_hi, + mask_vec_lo; + // xmm register 5x + v4sf_t max_vec_lo_lo, max_vec_lo_hi, + maxInd_vec_lo_lo, maxInd_vec_lo_hi, + mask_vec_lo_lo; + // zmm register 1x + __m512i intMask; + // k register 3x + __mmask16 mask_vec_1, mask_vec_2, + mask_vec_3; + + // Number of iterations for main loop. + num_iter = n / num_vector_elements; + // Number of iterations remaining for residual non vector loop + num_remain = n % num_vector_elements; + // A number with signbit one and others 0 IEEE-754 + abs_mask.v = _mm512_set1_ps(-0.f); + // index_vector after loading max_vector with initial values. 
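        // (Lanes 16..31 here are the element numbers of the first vector the
        // main loop will load; the running maximum max_vec_1/maxInd_vec_1,
        // set up just below, starts out holding elements 0..15.)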
+ index_vec_1.v = _mm512_setr_ps(16, 17, 18, 19, 20, 21, + 22, 23, 24, 25, 26, 27, + 28, 29, 30, 31); + // Broadcast 16. This is to increment the vector easily + inc_vec.v = _mm512_set1_ps(16); + // Load 16 float values from memory + max_vec_1.v = _mm512_loadu_ps(x); + // max_vector = abs(max_vector) + max_vec_1.v = _mm512_andnot_ps(abs_mask.v, max_vec_1.v); + // Remove nan and replace with -ve values + max_vec_1.v = remove_NAN_512_s(max_vec_1.v); + + // Increment x vector as we have loaded 16 values + x += num_vector_elements; + // indexes for values present in max vector. + maxInd_vec_1.v = _mm512_setr_ps(0, 1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15); + + int i = 1; + for (; (i + 4) < num_iter; i += 5) + { + /* + Unrolled to process 5 at a time. It basically works + by taking a master max_vec_1 and a maxInd_vec_1 + holding indexes. Elements are taken from the RAM on a batch + of 5 (1 master max_vec_1 already exists to compare so + 6 elements). Now each 2 of them is compared with each other + and an intermediate result is obtained. This intermediate + result is again with each other and combined until we reach + one vector in max_vector and maxIndex_vector. + */ + + // Load the vector and subs NAN + // Load Value x values + x_vec_1.v = _mm512_loadu_ps(x); + // x_vec_1 = abs(x_vec_1) + x_vec_1.v = _mm512_andnot_ps(abs_mask.v, x_vec_1.v); + // Increment x vector as we have loaded 16 values + x += num_vector_elements; + // Remove nan and replace with -ve values + x_vec_1.v = remove_NAN_512_s(x_vec_1.v); + + // Mask Generation of 1st(can be previous max) and 2nd element + // mask = max_vector - x_vec_1 + mask.v = _mm512_sub_ps(max_vec_1.v, x_vec_1.v); + // Type cast mask from IEEE754 (float) to integer type + // This operation will not need a new register, its just to convince + // the compiler. But its accounted as seperate register in the + // above calculations + intMask = _mm512_castps_si512(mask.v); + // Extract the signbit and build the mask. + mask_vec_1 = _mm512_movepi32_mask(intMask); + + // Load 2 elements to 2nd max and x vector, set indexes + // Load Value x values + max_vec_2.v = _mm512_loadu_ps(x); + // max_vec_2 = abs(max_vec_2) + max_vec_2.v = _mm512_andnot_ps(abs_mask.v, max_vec_2.v); + // Remove nan and replace with -ve values + max_vec_2.v = remove_NAN_512_s(max_vec_2.v); + // Increment x vector as we have loaded 16 values + x += num_vector_elements; + // Increment the index vector to point to next indexes. + maxInd_vec_2.v = _mm512_add_ps(index_vec_1.v, inc_vec.v); + + // Load Value x values + x_vec_2.v = _mm512_loadu_ps(x); + // x_vec_2 = abs(x_vec_2) + x_vec_2.v = _mm512_andnot_ps(abs_mask.v, x_vec_2.v); + // Remove nan and replace with -ve values + x_vec_2.v = remove_NAN_512_s(x_vec_2.v); + // Increment x vector as we have loaded 16 values + x += num_vector_elements; + // Increment the index vector to point to next indexes. + ind_vec_2.v = _mm512_add_ps(maxInd_vec_2.v, inc_vec.v); + + // Mask generation for last loaded 2 elements into x and max vectors. + // mask = max_vec_2 - x_vec_2 + mask.v = _mm512_sub_ps(max_vec_2.v, x_vec_2.v); + // Type cast mask from IEEE754 (float) to integer type + // This operation will not need a new register, its just to convince + // the compiler. But its accounted as seperate register in the + // above calculations + intMask = _mm512_castps_si512(mask.v); + // Extract the signbit and build the mask. 
+ mask_vec_2 = _mm512_movepi32_mask(intMask); + + // Load 2 more elements to 3rd max and x vector, set indexes + // Load Value x values + max_vec_3.v = _mm512_loadu_ps(x); + // max_vec_3 = abs(max_vec_3) + max_vec_3.v = _mm512_andnot_ps(abs_mask.v, max_vec_3.v); + // Remove nan and replace with -ve values + max_vec_3.v = remove_NAN_512_s(max_vec_3.v); + // Increment x vector as we have loaded 16 values + x += num_vector_elements; + // Increment the index vector to point to next indexes. + maxInd_vec_3.v = _mm512_add_ps(ind_vec_2.v, inc_vec.v); + // Load Value x values + x_vec_3.v = _mm512_loadu_ps(x); + // x_vec_3 = abs(x_vec_3) + x_vec_3.v = _mm512_andnot_ps(abs_mask.v, x_vec_3.v); + // Remove nan and replace with -ve values + x_vec_3.v = remove_NAN_512_s(x_vec_3.v); + // Increment x vector as we have loaded 16 values + x += num_vector_elements; + // Increment the index vector to point to next indexes. + ind_vec_3.v = _mm512_add_ps(maxInd_vec_3.v, inc_vec.v); + + // Mask generation for last 2 elements loaded into x and max vectors. + // mask = max_vec_3 - x_vec_3 + mask.v = _mm512_sub_ps(max_vec_3.v, x_vec_3.v); + // Type cast mask from IEEE754 (float) to integer type + // This operation will not need a new register, its just to convince + // the compiler. But its accounted as seperate register in the + // above calculations + intMask = _mm512_castps_si512(mask.v); + // Extract the signbit and build the mask. + mask_vec_3 = _mm512_movepi32_mask(intMask); + + // Blend max vector and index vector (3 pairs of elements needs to be blended). + /* Take values from max_vector if corresponding bit in mask_vector is 0 + * otherwise take value from x_vector, this is accumulated maximum value + * from max_vector and x_vector to mask_vector */ + max_vec_1.v = _mm512_mask_blend_ps(mask_vec_1, + max_vec_1.v, + x_vec_1.v); + /* Take values from max_vector if corresponding bit in mask_vector is 0 + * otherwise take value from x_vector, this is accumulated maximum value + * from max_vector and x_vector to mask_vector */ + max_vec_2.v = _mm512_mask_blend_ps(mask_vec_2, + max_vec_2.v, + x_vec_2.v); + /* Take values from max_vector if corresponding bit in mask_vector is 0 + * otherwise take value from x_vector, this is accumulated maximum value + * from max_vector and x_vector to mask_vector */ + max_vec_3.v = _mm512_mask_blend_ps(mask_vec_3, + max_vec_3.v, + x_vec_3.v); + /* Take values from maxIndex_vector if corresponding bit in mask_vector + * is 0 otherwise take value from index_vec_1, this is accumulated + * maximum value index from maxIndex_vector and index_vec_1 + * to maxIndex_vector */ + maxInd_vec_1.v = _mm512_mask_blend_ps(mask_vec_1, + maxInd_vec_1.v, + index_vec_1.v); + /* Take values from maxIndex_vector if corresponding bit in mask_vector + * is 0 otherwise take value from index_vec_1, this is accumulated + * maximum value index from maxIndex_vector and index_vec_1 + * to maxIndex_vector */ + maxInd_vec_2.v = _mm512_mask_blend_ps(mask_vec_2, + maxInd_vec_2.v, + ind_vec_2.v); + /* Take values from maxIndex_vector if corresponding bit in mask_vector + * is 0 otherwise take value from index_vec_1, this is accumulated + * maximum value index from maxIndex_vector and index_vec_1 + * to maxIndex_vector */ + maxInd_vec_3.v = _mm512_mask_blend_ps(mask_vec_3, + maxInd_vec_3.v, + ind_vec_3.v); + + // Mask generation for blending max_vec_2 and max_vec_3 to max_vec_2. 
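            /* (max_vec_1..3 and their index vectors now hold the lane-wise
               winners of the three pairwise compares above; the two
               compare/blend rounds below fold them back into the single
               running maximum kept in max_vec_1 / maxInd_vec_1.) */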
+ // mask = max_vec_2 - max_vec_3 + mask.v = _mm512_sub_ps(max_vec_2.v, max_vec_3.v); + // Type cast mask from IEEE754 (float) to integer type + // This operation will not need a new register, its just to convince + // the compiler. But its accounted as seperate register in the + // above calculations + intMask = _mm512_castps_si512(mask.v); + // Extract the signbit and build the mask. + mask_vec_2 = _mm512_movepi32_mask(intMask); + + // Blend to obtain 1 vector each of max values and index. + /* Take values from max_vec_2 if corresponding bit in mask_vec_2 + * is 0 otherwise take value from max_vec_3, this is accumulated + * maximum value from max_vec_2 and max_vec_3 to mask_vec_2 */ + max_vec_2.v = _mm512_mask_blend_ps(mask_vec_2, + max_vec_2.v, + max_vec_3.v); + /* Take values from maxInd_vec_2 if corresponding bit in mask_vector + * is 0 otherwise take value from maxInd_vec_3, this is accumulated + * maximum value index from maxInd_vec_2 and maxInd_vec_3 + * to maxInd_vec_2 */ + maxInd_vec_2.v = _mm512_mask_blend_ps(mask_vec_2, + maxInd_vec_2.v, + maxInd_vec_3.v); + + // Mask generation for blending max_vec_1 and max_vec_2 into max_vec_1. + // mask = max_vec_1 - max_vec_2 + mask.v = _mm512_sub_ps(max_vec_1.v, max_vec_2.v); + // Type cast mask from IEEE754 (float) to integer type + // This operation will not need a new register, its just to convince + // the compiler. But its accounted as seperate register in the + // above calculations + intMask = _mm512_castps_si512(mask.v); + // Extract the signbit and build the mask. + mask_vec_1 = _mm512_movepi32_mask(intMask); + + // Final blend to the master max_vec_1 and maxInd_vec_1 + /* Take values from max_vec_1 if corresponding bit in mask_vec_1 + * is 0 otherwise take value from max_vec_2, this is accumulated + * maximum value from max_vec_1 and max_vec_2 to mask_vec_1 */ + max_vec_1.v = _mm512_mask_blend_ps(mask_vec_1, max_vec_1.v, max_vec_2.v); + /* Take values from maxInd_vec_1 if corresponding bit in mask_vector + * is 0 otherwise take value from maxInd_vec_2, this is accumulated + * maximum value index from maxInd_vec_1 and maxInd_vec_2 + * to maxInd_vec_1 */ + maxInd_vec_1.v = _mm512_mask_blend_ps(mask_vec_1, + maxInd_vec_1.v, + maxInd_vec_2.v); + + // Increment the index vector to point to next indexes. + index_vec_1.v = _mm512_add_ps(ind_vec_3.v, inc_vec.v); + } + + for (; i < num_iter; i++) + { + /* + Take vector one by one, above code makes max_vec_1 + contain the first 16 elements, now with the max vector + as first 16 elements (abs), we need to load next 16 elements + into x_vec_1 (abs). Now with those we can safely removeNan + which will put -ve values as NAN. + + These -ve values of NAN decreases by 1 in each iteration, + this helps us find the first NAN value. + */ + // Load Value x values + x_vec_1.v = _mm512_loadu_ps(x); + // x_vec_1 = abs(x_vec_1) + x_vec_1.v = _mm512_andnot_ps(abs_mask.v, x_vec_1.v); + // Remove nan and replace with -ve values + x_vec_1.v = remove_NAN_512_s(x_vec_1.v); + + // Mask Generation + // mask = max_vec_1 - x_vec_1 + mask.v = _mm512_sub_ps(max_vec_1.v, x_vec_1.v); + // Extract the signbit and build the mask. 
+ mask_vec_1 = _mm512_movepi32_mask(_mm512_castps_si512(mask.v)); + /* Take values from max_vec_1 if corresponding bit in + * mask_vec_1 is 0 otherwise take value from x_vec_1, + * this is accumulated maximum value from max_vec_1 and + * x_vec_1 to mask_vec_1 */ + max_vec_1.v = _mm512_mask_blend_ps(mask_vec_1, + max_vec_1.v, + x_vec_1.v); + /* Take values from maxInd_vec_1 if corresponding bit in + * mask_vector is 0 otherwise take value from index_vec_1, + * this is accumulated maximum value index from maxInd_vec_1 + * and index_vec_1 to maxInd_vec_1 */ + maxInd_vec_1.v = _mm512_mask_blend_ps(mask_vec_1, + maxInd_vec_1.v, + index_vec_1.v); + + // Increment the index vector to point to next indexes. + index_vec_1.v = _mm512_add_ps(index_vec_1.v, inc_vec.v); + + // Increment x vector as we have loaded 16 values + x += num_vector_elements; + } + + num_remain = (n - ((i)*16)); + + /* + Now take the max vector and produce the max value from + the max vector by slicing and comparing with itself, + until we are left with just one index position and max value. + */ + // Split max to hi and lo + max_vec_hi.v = _mm512_extractf32x8_ps(max_vec_1.v, 1); + max_vec_lo.v = _mm512_extractf32x8_ps(max_vec_1.v, 0); + + // Split maxIndex to hi and lo + maxInd_vec_hi.v = _mm512_extractf32x8_ps(maxInd_vec_1.v, 1); + maxInd_vec_lo.v = _mm512_extractf32x8_ps(maxInd_vec_1.v, 0); + + // Compare max_vec_hi > max_vec_1 + // mask_vec_lo = max_vec_lo - max_vec_hi + mask_vec_lo.v = _mm256_sub_ps(max_vec_lo.v, max_vec_hi.v); + + /* Take values from max_vec_lo if corresponding bit in mask_vec_lo + * is 0 otherwise take value from max_vec_hi, this is accumulated + * maximum value from max_vec_lo and max_vec_hi to max_vec_lo */ + max_vec_lo.v = _mm256_blendv_ps(max_vec_lo.v, + max_vec_hi.v, + mask_vec_lo.v); + /* Take values from maxInd_vec_lo if corresponding bit + * in mask_vec_lo is 0 otherwise take value from maxInd_vec_hi, + * this is accumulated maximum value from maxInd_vec_lo and + * maxInd_vec_hi to maxInd_vec_lo */ + maxInd_vec_lo.v = _mm256_blendv_ps(maxInd_vec_lo.v, + maxInd_vec_hi.v, + mask_vec_lo.v); + + // Split max_lo to hi and lo + max_vec_lo_hi.v = _mm256_extractf128_ps(max_vec_lo.v, 1); + max_vec_lo_lo.v = _mm256_extractf128_ps(max_vec_lo.v, 0); + + // Split maxIndex_lo to hi and lo + maxInd_vec_lo_hi.v = _mm256_extractf128_ps(maxInd_vec_lo.v, 1); + maxInd_vec_lo_lo.v = _mm256_extractf128_ps(maxInd_vec_lo.v, 0); + + // mask_vec_lo_lo = max_vec_lo_lo - max_vec_lo_hi + mask_vec_lo_lo.v = _mm_sub_ps(max_vec_lo_lo.v, max_vec_lo_hi.v); + /* Take values from max_vec_lo_lo if corresponding bit in + * mask_vec_lo_lo is 0 otherwise take value from max_vec_lo_hi, + * this is accumulated maximum value from max_vec_lo_lo and + * max_vec_lo_hi to max_vec_lo_lo */ + max_vec_lo_lo.v = _mm_blendv_ps(max_vec_lo_lo.v, + max_vec_lo_hi.v, + mask_vec_lo_lo.v); + /* Take values from maxInd_vec_lo if corresponding bit + * in mask_vec_lo_lo is 0 otherwise take value from maxInd_vec_hi, + * this is accumulated maximum value from maxInd_vec_lo and + * maxInd_vec_hi to maxInd_vec_lo */ + maxInd_vec_lo_lo.v = _mm_blendv_ps(maxInd_vec_lo_lo.v, + maxInd_vec_lo_hi.v, + mask_vec_lo_lo.v); + + // Take 64 high bits of max_lo_lo and put it to 64 low bits, rest 1st value + /* Example max_vec_lo_lo is {a, b, x, y} + * After max_vec_lo_hi.v = _mm_permute_ps(max_vec_lo_lo.v, 14); + * max_vec_lo_hi is {x, y, a, a} (essentially folding the vector) + */ + max_vec_lo_hi.v = _mm_permute_ps(max_vec_lo_lo.v, 14); + // Fold the vector same as 
max_vector + maxInd_vec_lo_hi.v = _mm_permute_ps(maxInd_vec_lo_lo.v, 14); + + // mask_vec_lo_lo = max_vec_lo_lo - max_vec_lo_hi + mask_vec_lo_lo.v = _mm_sub_ps(max_vec_lo_lo.v, max_vec_lo_hi.v); + /* Take values from max_vec_lo_lo if corresponding bit in + * mask_vec_lo_lo is 0 otherwise take value from max_vec_lo_hi, + * this is accumulated maximum value from max_vec_lo_lo and + * max_vec_lo_hi to max_vec_lo_lo */ + max_vec_lo_lo.v = _mm_blendv_ps(max_vec_lo_lo.v, + max_vec_lo_hi.v, + mask_vec_lo_lo.v); + /* Take values from maxInd_vec_lo if corresponding bit + * in mask_vec_lo_lo is 0 otherwise take value from maxInd_vec_hi, + * this is accumulated maximum value from maxInd_vec_lo and + * maxInd_vec_hi to maxInd_vec_lo */ + maxInd_vec_lo_lo.v = _mm_blendv_ps(maxInd_vec_lo_lo.v, + maxInd_vec_lo_hi.v, + mask_vec_lo_lo.v); + + // Take max_vec_lo_lo.f[1] and put it to max_vec_lo_hi.f[0] + /* Example max_vec_lo_lo is {a, b, x, y} + * After max_vec_lo_hi.v = _mm_permute_ps(max_vec_lo_lo.v, 1); + * max_vec_lo_hi is {b, a, a, a} (essentially folding the vector) + */ + max_vec_lo_hi.v = _mm_permute_ps(max_vec_lo_lo.v, 1); + // Do the same operation. + maxInd_vec_lo_hi.v = _mm_permute_ps(maxInd_vec_lo_lo.v, 1); + + // mask_vec_lo_lo = max_vec_lo_lo - max_vec_lo_hi + mask_vec_lo_lo.v = _mm_sub_ps(max_vec_lo_lo.v, max_vec_lo_hi.v); + /* Take values from max_vec_lo_lo if corresponding bit in + * mask_vec_lo_lo is 0 otherwise take value from max_vec_lo_hi, + * this is accumulated maximum value from max_vec_lo_lo and + * max_vec_lo_hi to max_vec_lo_lo */ + max_vec_lo_lo.v = _mm_blendv_ps(max_vec_lo_lo.v, + max_vec_lo_hi.v, + mask_vec_lo_lo.v); + /* Take values from maxInd_vec_lo if corresponding bit + * in mask_vec_lo_lo is 0 otherwise take value from maxInd_vec_hi, + * this is accumulated maximum value from maxInd_vec_lo and + * maxInd_vec_hi to maxInd_vec_lo */ + maxInd_vec_lo_lo.v = _mm_blendv_ps(maxInd_vec_lo_lo.v, + maxInd_vec_lo_hi.v, + mask_vec_lo_lo.v); + /* We have kept on folding and comparing until we got one single index + * and max value so that is the final answer so set it as the final + * answer.*/ + fndInd = maxInd_vec_lo_lo.f[0]; + fndMaxVal = max_vec_lo_lo.f[0]; + // Found value is < 0 means it was the max NAN which was accumulated. + if (fndMaxVal < 0) + { + // So just set it as NAN + fndMaxVal = NAN; + } + // Finish off the remaining values using normal instructions + for (int i = n - num_remain; i < n; i++) + { + float absval = fabsf(*(x)); + if (fndMaxVal < absval || (isnan(absval) && !isnan(fndMaxVal))) + { + fndMaxVal = absval; + fndInd = i; + } + x += 1; + } + } + + // Issue vzeroupper instruction to clear upper lanes of ymm registers. + // This avoids a performance penalty caused by false dependencies when + // transitioning from from AVX to SSE instructions (which may occur + // later, especially if BLIS is compiled with -mfpmath=sse). + _mm256_zeroupper(); + + /* Store final index to output variable. 
*/ + *i_max = fndInd; + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3) +} +#endif // ----------------------------------------------------------------------------- void bli_damaxv_zen_int @@ -531,7 +1080,9 @@ GENTFUNCR( scomplex, float, c, s, amaxv_zen_int ) GENTFUNCR( dcomplex, double, z, d, amaxv_zen_int ) #endif - +// Disable for all context without AVX512 support +// Please define it in bli_family_xxx.h in config directory if there is AVX512 support +#ifdef AVX512 /* Converts all the NAN to a negative number less than previously encountered NANs*/ __m512d remove_NAN_512d(__m512d vec) { @@ -560,8 +1111,12 @@ __m512d remove_NAN_512d(__m512d vec) return vec; } +#endif //---------------------------------------------------------------------------------------------------- +// Disable for all context without AVX512 support +// Please define it in bli_family_xxx.h in config directory if there is AVX512 support +#ifdef AVX512 void bli_damaxv_zen_int_avx512( dim_t n, double *restrict x, inc_t incx, @@ -864,5 +1419,6 @@ void bli_damaxv_zen_int_avx512( AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3) } +#endif // --------------------------------------------------------------------------------- diff --git a/kernels/zen/bli_kernels_zen.h b/kernels/zen/bli_kernels_zen.h index 0daa54936e..e43742e6e1 100644 --- a/kernels/zen/bli_kernels_zen.h +++ b/kernels/zen/bli_kernels_zen.h @@ -44,6 +44,7 @@ PACKM_KER_PROT(double, d, packm_6xk_nn_zen) // amaxv (intrinsics) AMAXV_KER_PROT( float, s, amaxv_zen_int ) +AMAXV_KER_PROT( float, s, amaxv_zen_int_avx512 ) AMAXV_KER_PROT( double, d, amaxv_zen_int ) AMAXV_KER_PROT( double, d, amaxv_zen_int_avx512 ) From 5d6012c50dc97ce0ca5590ec31dd4ad3cc29edcf Mon Sep 17 00:00:00 2001 From: Saitharun Date: Wed, 28 Jul 2021 20:29:12 +0530 Subject: [PATCH 006/243] Added additional symbols for BLIS APIs using wrapper functions Details: BLIS currently supports BLAS and CBLAS interfaces with lowercase. With this commit - we also supports uppercase with and without trailing underscore, lowercase without trailing underscore symbol names. 
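
In effect, each extra spelling becomes a thin wrapper that forwards its
argument list to the canonical lowercase, trailing-underscore entry point.
A minimal hand-written sketch for one routine (hypothetical illustration,
not the generated bli_util_api_wrap.c added by this patch; plain char/int
pointers stand in for the framework's Fortran-compatible typedefs):

/* Canonical BLAS symbol (lowercase with trailing underscore), as
   currently exported. */
extern void dgemm_( const char* transa, const char* transb,
                    const int* m, const int* n, const int* k,
                    const double* alpha, const double* a, const int* lda,
                    const double* b, const int* ldb,
                    const double* beta, double* c, const int* ldc );

/* Extra spellings: each wrapper simply forwards to dgemm_. */
void DGEMM( const char* transa, const char* transb,
            const int* m, const int* n, const int* k,
            const double* alpha, const double* a, const int* lda,
            const double* b, const int* ldb,
            const double* beta, double* c, const int* ldc )
{
    dgemm_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc );
}

void DGEMM_( const char* transa, const char* transb,
             const int* m, const int* n, const int* k,
             const double* alpha, const double* a, const int* lda,
             const double* b, const int* ldb,
             const double* beta, double* c, const int* ldc )
{
    dgemm_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc );
}

void dgemm( const char* transa, const char* transb,
            const int* m, const int* n, const int* k,
            const double* alpha, const double* a, const int* lda,
            const double* b, const int* ldb,
            const double* beta, double* c, const int* ldc )
{
    dgemm_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc );
}

Because the wrappers forward unchanged argument lists, callers emitting any
of the four spellings resolve to the same underlying implementation without
relying on compiler-specific name-decoration options.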
Change-Id: Ibb06121821ab937b25d492409625916f542b2135 --- CMakeLists.txt | 9 - frame/compat/cblas/src/cblas_f77.h | 190 +- frame/include/bli_macro_defs.h | 351 +-- frame/util/CMakeLists.txt | 1 + frame/util/bli_util.h | 6 +- frame/util/bli_util_api_wrap.c | 3218 ++++++++++++++++++++++++++++ frame/util/bli_util_api_wrap.h | 1727 +++++++++++++++ 7 files changed, 4955 insertions(+), 547 deletions(-) create mode 100644 frame/util/bli_util_api_wrap.c create mode 100644 frame/util/bli_util_api_wrap.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 77a7e702ba..78a380fece 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -90,15 +90,10 @@ option(BLIS_ENABLE_ILP64 "ENABLE BLIS ILP64" OFF) option(ENABLE_INT_TYPE_SIZE " Internal BLIS integers ,used in native BLIS interfaces based on architecture dependent " ON) option(ENABLE_BLASTEST "Enable the blastest" OFF) option(ENABLE_TESTCPP_TESTING "Enabling testcpp" OFF) -option (ENABLE_NO_UNDERSCORE_API "export APIs without underscore" ON) -option (ENABLE_UPPERCASE "export APIs with uppercase" OFF) option (ENABLE_COMPLEX_RETURN_INTEL "Enable complex_return_intel" OFF) option (ENABLE_TRSM_PREINVERSION "Enable TRSM preinversion" ON) option (ENABLE_AOCL_DYNAMIC "Enable Dynamic Multi-threading" OFF) -if(ENABLE_NO_UNDERSCORE_API) - add_definitions(-DBLIS_ENABLE_NO_UNDERSCORE_API) -endif() if(ENABLE_COMPLEX_RETURN_INTEL) set(BLIS_ENABLE_COMPLEX_RETURN_INTEL TRUE) @@ -106,10 +101,6 @@ else() set(BLIS_DISABLE_COMPLEX_RETURN_INTEL TRUE) endif() -if(ENABLE_UPPERCASE) - add_definitions(-DBLIS_ENABLE_UPPERCASE) -endif() - if(ENABLE_AOCL_DYNAMIC) set(AOCL_DYNAMIC TRUE) endif() diff --git a/frame/compat/cblas/src/cblas_f77.h b/frame/compat/cblas/src/cblas_f77.h index 302cfd8151..b09963eec7 100644 --- a/frame/compat/cblas/src/cblas_f77.h +++ b/frame/compat/cblas/src/cblas_f77.h @@ -7,199 +7,13 @@ * * (Heavily hacked down from the original) * - * Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2020 - 2021, Advanced Micro Devices, Inc. All rights reserved. 
* */ #ifndef CBLAS_F77_H #define CBLAS_F77_H -#if defined(BLIS_ENABLE_NO_UNDERSCORE_API) - /* - * Level 1 BLAS - */ -#define F77_xerbla xerbla -#define F77_srotg srotg -#define F77_srotmg srotmg -#define F77_srot srot -#define F77_srotm srotm -#define F77_drotg drotg -#define F77_drotmg drotmg -#define F77_drot drot -#define F77_drotm drotm -#define F77_sswap sswap -#define F77_scopy scopy -#define F77_saxpy saxpy -#define F77_isamax_sub isamaxsub -#define F77_dswap dswap -#define F77_dcopy dcopy -#define F77_daxpy daxpy -#define F77_idamax_sub idamaxsub -#define F77_cswap cswap -#define F77_ccopy ccopy -#define F77_caxpy caxpy -#define F77_icamax_sub icamaxsub -#define F77_zswap zswap -#define F77_zcopy zcopy -#define F77_zaxpy zaxpy -#define F77_zaxpby zaxpby -#define F77_izamax_sub izamaxsub -#define F77_sdot_sub sdotsub -#define F77_ddot_sub ddotsub -#define F77_dsdot_sub dsdotsub -#define F77_sscal sscal -#define F77_dscal dscal -#define F77_cscal cscal -#define F77_zscal zscal -#define F77_csscal csscal -#define F77_zdscal zdscal -#define F77_cdotu_sub cdotusub -#define F77_cdotc_sub cdotcsub -#define F77_zdotu_sub zdotusub -#define F77_zdotc_sub zdotcsub -#define F77_snrm2_sub snrm2sub -#define F77_sasum_sub sasumsub -#define F77_dnrm2_sub dnrm2sub -#define F77_dasum_sub dasumsub -#define F77_scnrm2_sub scnrm2sub -#define F77_scasum_sub scasumsub -#define F77_dznrm2_sub dznrm2sub -#define F77_dzasum_sub dzasumsub -#define F77_sdsdot_sub sdsdotsub -/* -* Level 2 BLAS -*/ -#define F77_ssymv ssymv -#define F77_ssbmv ssbmv -#define F77_sspmv sspmv -#define F77_sger sger -#define F77_ssyr ssyr -#define F77_sspr sspr -#define F77_ssyr2 ssyr2 -#define F77_sspr2 sspr2 -#define F77_dsymv dsymv -#define F77_dsbmv dsbmv -#define F77_dspmv dspmv -#define F77_dger dger -#define F77_dsyr dsyr -#define F77_dspr dspr -#define F77_dsyr2 dsyr2 -#define F77_dspr2 dspr2 -#define F77_chemv chemv -#define F77_chbmv chbmv -#define F77_chpmv chpmv -#define F77_cgeru cgeru -#define F77_cgerc cgerc -#define F77_cher cher -#define F77_chpr chpr -#define F77_cher2 cher2 -#define F77_chpr2 chpr2 -#define F77_zhemv zhemv -#define F77_zhbmv zhbmv -#define F77_zhpmv zhpmv -#define F77_zgeru zgeru -#define F77_zgerc zgerc -#define F77_zher zher -#define F77_zhpr zhpr -#define F77_zher2 zher2 -#define F77_zhpr2 zhpr2 -#define F77_sgemv sgemv -#define F77_sgbmv sgbmv -#define F77_strmv strmv -#define F77_stbmv stbmv -#define F77_stpmv stpmv -#define F77_strsv strsv -#define F77_stbsv stbsv -#define F77_stpsv stpsv -#define F77_dgemv dgemv -#define F77_dgbmv dgbmv -#define F77_dtrmv dtrmv -#define F77_dtbmv dtbmv -#define F77_dtpmv dtpmv -#define F77_dtrsv dtrsv -#define F77_dtbsv dtbsv -#define F77_dtpsv dtpsv -#define F77_cgemv cgemv -#define F77_cgbmv cgbmv -#define F77_ctrmv ctrmv -#define F77_ctbmv ctbmv -#define F77_ctpmv ctpmv -#define F77_ctrsv ctrsv -#define F77_ctbsv ctbsv -#define F77_ctpsv ctpsv -#define F77_zgemv zgemv -#define F77_zgbmv zgbmv -#define F77_ztrmv ztrmv -#define F77_ztbmv ztbmv -#define F77_ztpmv ztpmv -#define F77_ztrsv ztrsv -#define F77_ztbsv ztbsv -#define F77_ztpsv ztpsv -/* -* Level 3 BLAS -*/ -#define F77_chemm chemm -#define F77_cherk cherk -#define F77_cher2k cher2k -#define F77_zhemm zhemm -#define F77_zherk zherk -#define F77_zher2k zher2k -#define F77_sgemm sgemm -#define F77_ssymm ssymm -#define F77_ssyrk ssyrk -#define F77_ssyr2k ssyr2k -#define F77_strmm strmm -#define F77_strsm strsm -#define F77_dgemm dgemm -#define F77_dsymm dsymm -#define F77_dsyrk dsyrk -#define 
F77_dsyr2k dsyr2k -#define F77_dtrmm dtrmm -#define F77_dtrsm dtrsm -#define F77_cgemm cgemm -#define F77_csymm csymm -#define F77_csyrk csyrk -#define F77_csyr2k csyr2k -#define F77_ctrmm ctrmm -#define F77_ctrsm ctrsm -#define F77_zgemm zgemm -#define F77_zsymm zsymm -#define F77_zsyrk zsyrk -#define F77_zsyr2k zsyr2k -#define F77_ztrmm ztrmm -#define F77_ztrsm ztrsm -#define F77_dgemmt dgemmt -#define F77_sgemmt sgemmt -#define F77_cgemmt cgemmt -#define F77_zgemmt zgemmt -/* -* Aux Function -*/ -#define F77_scabs1 scabs1 -#define F77_dcabs1 dcabs1 - -/* - * -- BLAS Extension APIs -- - */ - -#define F77_saxpby saxpby -#define F77_daxpby daxpby -#define F77_caxpby caxpby -#define F77_zaxpby zaxpby -#define F77_cgemm3m cgemm3m -#define F77_zgemm3m zgemm3m - -#define F77_isamin_sub isaminsub -#define F77_idamin_sub idaminsub -#define F77_icamin_sub icaminsub -#define F77_izamin_sub izaminsub - -// -- Batch APIs -- -#define F77_sgemm_batch sgemm_batch -#define F77_dgemm_batch dgemm_batch -#define F77_cgemm_batch cgemm_batch -#define F77_zgemm_batch zgemm_batch -#else /* * Level 1 BLAS */ @@ -387,4 +201,4 @@ #define F77_zgemm_batch zgemm_batch_ #endif -#endif /* CBLAS_F77_H */ +/* CBLAS_F77_H */ diff --git a/frame/include/bli_macro_defs.h b/frame/include/bli_macro_defs.h index 61fe4bc557..9808590393 100644 --- a/frame/include/bli_macro_defs.h +++ b/frame/include/bli_macro_defs.h @@ -156,18 +156,12 @@ #define STRINGIFY_INT( s ) MKSTR( s ) #define PASTEMACT(ch1, ch2, ch3, ch4) bli_ ## ch1 ## ch2 ## _ ## ch3 ## _ ## ch4 -// Fortran-77 name-mangling macros. -#ifdef BLIS_ENABLE_NO_UNDERSCORE_API -#define PASTEF770(name) name -#define PASTEF77(ch1,name) ch1 ## name -#define PASTEF772(ch1,ch2,name) ch1 ## ch2 ## name -#define PASTEF773(ch1,ch2,ch3,name) ch1 ## ch2 ## ch3 ## name -#else + #define PASTEF770(name) name ## _ #define PASTEF77(ch1,name) ch1 ## name ## _ #define PASTEF772(ch1,ch2,name) ch1 ## ch2 ## name ## _ #define PASTEF773(ch1,ch2,ch3,name) ch1 ## ch2 ## ch3 ## name ## _ -#endif + // -- Include other groups of macros @@ -187,345 +181,4 @@ #include "bli_oapi_macro_defs.h" #include "bli_tapi_macro_defs.h" - -#ifdef BLIS_ENABLE_NO_UNDERSCORE_API -#define isamax_ isamax -#define idamax_ idamax -#define icamax_ icamax -#define izamax_ izamax -#define sasum_ sasum -#define dasum_ dasum -#define scasum_ scasum -#define dzasum_ dzasum -#define saxpy_ saxpy -#define daxpy_ daxpy -#define caxpy_ caxpy -#define zaxpy_ zaxpy -#define scopy_ scopy -#define dcopy_ dcopy -#define ccopy_ ccopy -#define zcopy_ zcopy -#define sdot_ sdot -#define ddot_ ddot -#define cdotc_ cdotc -#define zdotc_ zdotc -#define cdotu_ cdotu -#define zdotu_ zdotu -#define snrm2_ snrm2 -#define dnrm2_ dnrm2 -#define scnrm2_ scnrm2 -#define dznrm2_ dznrm2 -#define sscal_ sscal -#define dscal_ dscal -#define cscal_ cscal -#define csscal_ csscal -#define zscal_ zscal -#define zdscal_ zdscal -#define sswap_ sswap -#define dswap_ dswap -#define cswap_ cswap -#define zswap_ zswap -#define sgemv_ sgemv -#define dgemv_ dgemv -#define cgemv_ cgemv -#define zgemv_ zgemv -#define sger_ sger -#define dger_ dger -#define cgerc_ cgerc -#define cgeru_ cgeru -#define zgerc_ zgerc -#define zgeru_ zgeru -#define chemv_ chemv -#define zhemv_ zhemv -#define cher_ cher -#define zher_ zher -#define cher2_ cher2 -#define zher2_ zher2 -#define ssymv_ ssymv -#define dsymv_ dsymv -#define csymm_ csymm -#define zsymm_ zsymm -#define ssyr_ ssyr -#define dsyr_ dsyr -#define csyrk_ csyrk -#define csyrk_ csyrk -#define zsyrk_ zsyrk -#define ssyr2_ ssyr2 
-#define dsyr2_ dsyr2 -#define csyr2k_ csyr2k -#define zsyr2k_ zsyr2k -#define strmv_ strmv -#define dtrmv_ dtrmv -#define ctrmv_ ctrmv -#define ztrmv_ ztrmv -#define strsv_ strsv -#define dtrsv_ dtrsv -#define ctrsv_ ctrsv -#define ztrsv_ ztrsv -#define sgemm_ sgemm -#define dgemm_ dgemm -#define cgemm_ cgemm -#define zgemm_ zgemm -#define chemm_ chemm -#define zhemm_ zhemm -#define dgemmt_ dgemmt -#define sgemmt_ sgemmt -#define zgemmt_ zgemmt -#define cgemmt_ cgemmt -#define sgemm_batch_ sgemm_batch -#define dgemm_batch_ dgemm_batch -#define cgemm_batch_ cgemm_batch -#define zgemm_batch_ zgemm_batch -#define saxpby_ saxpby -#define daxpby_ daxpby -#define caxpby_ caxpby -#define zaxpby_ zaxpby -#define cher2k_ cher2k -#define zher2k_ zher2k -#define cherk_ cherk -#define zherk_ zherk -#define ssymm_ ssymm -#define dsymm_ dsymm -#define ssyr2k_ ssyr2k -#define dsyr2k_ dsyr2k -#define ssyrk_ ssyrk -#define dsyrk_ dsyrk -#define strmm_ strmm -#define dtrmm_ dtrmm -#define ctrmm_ ctrmm -#define ztrmm_ ztrmm -#define strsm_ strsm -#define dtrsm_ dtrsm -#define ctrsm_ ctrsm -#define ztrsm_ ztrsm -#define lsame_ lsame -#define cimatcopy_ cimatcopy -#define comatadd_ comatadd -#define comatcopy2_ comatcopy2 -#define comatcopy_ comatcopy -#define dimatcopy_ dimatcopy -#define domatadd_ domatadd -#define domatcopy2_ domatcopy2 -#define domatcopy_ domatcopy -#define simatcopy_ simatcopy -#define somatadd_ somatadd -#define somatcopy2_ somatcopy2 -#define somatcopy_ somatcopy -#define zimatcopy_ zimatcopy -#define zomatadd_ zomatadd -#define zomatcopy2_ zomatcopy2 -#define zomatcopy_ zomatcopy -#endif - -#ifdef BLIS_ENABLE_UPPERCASE -#define caxpby CAXPBY -#define caxpy CAXPY -#define ccopy CCOPY -#define cdotc CDOTC -#define cdotcsub CDOTCSUB -#define cdotu CDOTU -#define cdotusub CDOTUSUB -#define cgbmv CGBMV -#define cgemm CGEMM -#define cgemm3m CGEMM3M -#define cgemm_batch CGEMM_BATCH -#define cgemmt CGEMMT -#define cgemv CGEMV -#define cgerc CGERC -#define cgeru CGERU -#define chbmv CHBMV -#define chemm CHEMM -#define chemv CHEMV -#define cher CHER -#define cher2 CHER2 -#define cher2k CHER2K -#define cherk CHERK -#define chpmv CHPMV -#define chpr CHPR -#define chpr2 CHPR2 -#define cimatcopy CIMATCOPY -#define comatadd COMATADD -#define comatcopy2 COMATCOPY2 -#define comatcopy COMATCOPY -#define crotg CROTG -#define cscal CSCAL -#define csrot CSROT -#define csscal CSSCAL -#define cswap CSWAP -#define csymm CSYMM -#define csyr2k CSYR2K -#define csyrk CSYRK -#define ctbmv CTBMV -#define ctbsv CTBSV -#define ctpmv CTPMV -#define ctpsv CTPSV -#define ctrmm CTRMM -#define ctrmv CTRMV -#define ctrsm CTRSM -#define ctrsv CTRSV -#define dasum DASUM -#define dasumsub DASUMSUB -#define daxpby DAXPBY -#define daxpy DAXPY -#define dcabs1 DCABS1 -#define dcopy DCOPY -#define ddot DDOT -#define ddotsub DDOTSUB -#define dgbmv DGBMV -#define dgemm DGEMM -#define dgemm_batch DGEMM_BATCH -#define dgemmt DGEMMT -#define dgemv DGEMV -#define dger DGER -#define dnrm2 DNRM2 -#define dnrm2sub DNRM2SUB -#define dimatcopy DIMATCOPY -#define domatadd DOMATADD -#define domatcopy2 DOMATCOPY2 -#define domatcopy DOMATCOPY -#define drot DROT -#define drotg DROTG -#define drotm DROTM -#define drotmg DROTMG -#define dsbmv DSBMV -#define dscal DSCAL -#define dsdot DSDOT -#define dsdotsub DSDOTSUB -#define dspmv DSPMV -#define dspr DSPR -#define dspr2 DSPR2 -#define dswap DSWAP -#define dsymm DSYMM -#define dsymv DSYMV -#define dsyr DSYR -#define dsyr2 DSYR2 -#define dsyr2k DSYR2K -#define dsyrk DSYRK -#define dtbmv DTBMV 
-#define dtbsv DTBSV -#define dtpmv DTPMV -#define dtpsv DTPSV -#define dtrmm DTRMM -#define dtrmv DTRMV -#define dtrsm DTRSM -#define dtrsv DTRSV -#define dzasum DZASUM -#define dzasumsub DZASUMSUB -#define dznrm2 DZNRM2 -#define dznrm2sub DZNRM2SUB -#define icamax ICAMAX -#define icamaxsub ICAMAXSUB -#define icamin ICAMIN -#define icaminsub ICAMINSUB -#define idamax IDAMAX -#define idamaxsub IDAMAXSUB -#define idamin IDAMIN -#define idaminsub IDAMINSUB -#define isamax ISAMAX -#define isamaxsub ISAMAXSUB -#define isamin ISAMIN -#define isaminsub ISAMINSUB -#define izamax IZAMAX -#define izamaxsub IZAMAXSUB -#define izamin IZAMIN -#define izaminsub IZAMINSUB -#define lsame LSAME -#define sasum SASUM -#define sasumsub SASUMSUB -#define saxpby SAXPBY -#define saxpy SAXPY -#define scabs1 SCABS1 -#define scasum SCASUM -#define scasumsub SCASUMSUB -#define scnrm2 SCNRM2 -#define scnrm2sub SCNRM2SUB -#define scopy SCOPY -#define sdot SDOT -#define sdotsub SDOTSUB -#define sdsdot SDSDOT -#define sdsdotsub SDSDOTSUB -#define sgbmv SGBMV -#define sgemm SGEMM -#define sgemm_batch SGEMM_BATCH -#define sgemmt SGEMMT -#define sgemv SGEMV -#define sger SGER -#define snrm2 SNRM2 -#define snrm2sub SNRM2SUB -#define simatcopy SIMATCOPY -#define somatadd SOMATADD -#define somatcopy2 SOMATCOPY2 -#define somatcopy SOMATCOPY -#define srot SROT -#define srotg SROTG -#define srotm SROTM -#define srotmg SROTMG -#define ssbmv SSBMV -#define sscal SSCAL -#define sspmv SSPMV -#define sspr SSPR -#define sspr2 SSPR2 -#define sswap SSWAP -#define ssymm SSYMM -#define ssymv SSYMV -#define ssyr SSYR -#define ssyr2 SSYR2 -#define ssyr2k SSYR2K -#define ssyrk SSYRK -#define stbmv STBMV -#define stbsv STBSV -#define stpmv STPMV -#define stpsv STPSV -#define strmm STRMM -#define strmv STRMV -#define strsm STRSM -#define strsv STRSV -#define xerbla XERBLA -#define zaxpby ZAXPBY -#define zaxpy ZAXPY -#define zcopy ZCOPY -#define zdotc ZDOTC -#define zdotcsub ZDOTCSUB -#define zdotu ZDOTU -#define zdotusub ZDOTUSUB -#define zdrot ZDROT -#define zdscal ZDSCAL -#define zgbmv ZGBMV -#define zgemm ZGEMM -#define zgemm3m ZGEMM3M -#define zgemm_batch ZGEMM_BATCH -#define zgemmt ZGEMMT -#define zgemv ZGEMV -#define zgerc ZGERC -#define zgeru ZGERU -#define zhbmv ZHBMV -#define zhemm ZHEMM -#define zhemv ZHEMV -#define zher ZHER -#define zher2 ZHER2 -#define zher2k ZHER2K -#define zherk ZHERK -#define zhpmv ZHPMV -#define zhpr ZHPR -#define zhpr2 ZHPR2 -#define zimatcopy ZIMATCOPY -#define zomatadd ZOMATADD -#define zomatcopy2 ZOMATCOPY2 -#define zomatcopy ZOMATCOPY -#define zrotg ZROTG -#define zscal ZSCAL -#define zswap ZSWAP -#define zsymm ZSYMM -#define zsyr2k ZSYR2K -#define zsyrk ZSYRK -#define ztbmv ZTBMV -#define ztbsv ZTBSV -#define ztpmv ZTPMV -#define ztpsv ZTPSV -#define ztrmm ZTRMM -#define ztrmv ZTRMV -#define ztrsm ZTRSM -#define ztrsv ZTRSV -#endif - #endif diff --git a/frame/util/CMakeLists.txt b/frame/util/CMakeLists.txt index 16ba2fd5a4..c20d7c525d 100644 --- a/frame/util/CMakeLists.txt +++ b/frame/util/CMakeLists.txt @@ -12,4 +12,5 @@ target_sources("${PROJECT_NAME}" ${CMAKE_CURRENT_SOURCE_DIR}/bli_util_tapi_ex.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_util_unb_var1.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_util_update.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_util_api_wrap.c ) diff --git a/frame/util/bli_util.h b/frame/util/bli_util.h index b13c34cdc5..3c4e5722af 100644 --- a/frame/util/bli_util.h +++ b/frame/util/bli_util.h @@ -5,7 +5,7 @@ libraries. 
Copyright (C) 2014, The University of Texas at Austin
-   Copyright (C) 2020, Advanced Micro Devices, Inc.
+   Copyright (C) 2020 - 2021, Advanced Micro Devices, Inc. All rights reserved.

   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions are
@@ -59,3 +59,7 @@
 //Routines to copy certain portion of a matrix to another
 #include "bli_util_update.h"
+
+// Header file defining the different formats of BLAS APIs: uppercase with
+// and without a trailing underscore, and lowercase without an underscore.
+#include "bli_util_api_wrap.h"
diff --git a/frame/util/bli_util_api_wrap.c b/frame/util/bli_util_api_wrap.c
new file mode 100644
index 0000000000..393a56e143
--- /dev/null
+++ b/frame/util/bli_util_api_wrap.c
@@ -0,0 +1,3218 @@
+/*
+
+   BLIS
+   An object-based framework for developing high-performance BLAS-like
+   libraries.
+
+   Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved.
+
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions are
+   met:
+    - Redistributions of source code must retain the above copyright
+      notice, this list of conditions and the following disclaimer.
+    - Redistributions in binary form must reproduce the above copyright
+      notice, this list of conditions and the following disclaimer in the
+      documentation and/or other materials provided with the distribution.
+    - Neither the name(s) of the copyright holder(s) nor the names of its
+      contributors may be used to endorse or promote products derived
+      from this software without specific prior written permission.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+   HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+   SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+   LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+   DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+   THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+   (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+   OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ +*/ + +#include "blis.h" +#include "bli_util_api_wrap.h" + +// wrapper functions to support additional symbols + +void CAXPY(const f77_int *n,const scomplex *ca,const scomplex *cx,const f77_int *incx,scomplex *cy,const f77_int *incy) +{ + caxpy_( n, ca, cx, incx, cy, incy); +} + +void caxpy(const f77_int *n,const scomplex *ca,const scomplex *cx,const f77_int *incx,scomplex *cy,const f77_int *incy) +{ + caxpy_( n, ca, cx, incx, cy, incy); +} + +void CAXPY_(const f77_int *n,const scomplex *ca,const scomplex *cx,const f77_int *incx,scomplex *cy,const f77_int *incy) +{ + caxpy_( n, ca, cx, incx, cy, incy); +} + +void CCOPY(const f77_int *n,const scomplex *cx,const f77_int *incx,scomplex *cy,const f77_int *incy) +{ + ccopy_( n, cx, incx, cy, incy); +} + +void ccopy(const f77_int *n,const scomplex *cx,const f77_int *incx,scomplex *cy,const f77_int *incy) +{ + ccopy_( n, cx, incx, cy, incy); +} + +void CCOPY_(const f77_int *n,const scomplex *cx,const f77_int *incx,scomplex *cy,const f77_int *incy) +{ + ccopy_( n, cx, incx, cy, incy); +} + +#ifdef BLIS_DISABLE_COMPLEX_RETURN_INTEL +scomplex CDOTC(const f77_int* n,const scomplex* x, const f77_int* incx,const scomplex* y, const f77_int* incy) +{ + return cdotc_ ( n, x, incx, y, incy); +} + +scomplex cdotc(const f77_int* n,const scomplex* x, const f77_int* incx,const scomplex* y, const f77_int* incy) +{ + return cdotc_ ( n, x, incx, y, incy); +} + +scomplex CDOTC_(const f77_int* n,const scomplex* x, const f77_int* incx,const scomplex* y, const f77_int* incy) +{ + return cdotc_ ( n, x, incx, y, incy); +} + +scomplex CDOTU(const f77_int* n,const scomplex* x, const f77_int* incx,const scomplex* y, const f77_int* incy) +{ + return cdotu_ ( n, x, incx, y, incy); +} + +scomplex cdotu(const f77_int* n,const scomplex* x, const f77_int* incx,const scomplex* y, const f77_int* incy) +{ + return cdotu_ ( n, x, incx, y, incy); +} + +scomplex CDOTU_(const f77_int* n,const scomplex* x, const f77_int* incx,const scomplex* y, const f77_int* incy) +{ + return cdotu_ ( n, x, incx, y, incy); +} + +dcomplex ZDOTC(const f77_int* n, const dcomplex* x, const f77_int* incx, const dcomplex* y, const f77_int* incy) +{ + return zdotc_ ( n, x, incx, y, incy); +} + +dcomplex zdotc(const f77_int* n, const dcomplex* x, const f77_int* incx, const dcomplex* y, const f77_int* incy) +{ + return zdotc_ ( n, x, incx, y, incy); +} + +dcomplex ZDOTC_(const f77_int* n, const dcomplex* x, const f77_int* incx, const dcomplex* y, const f77_int* incy) +{ + return zdotc_ ( n, x, incx, y, incy); +} + +dcomplex ZDOTU (const f77_int* n, const dcomplex* x, const f77_int* incx, const dcomplex* y, const f77_int* incy) +{ + return zdotu_ ( n, x, incx, y, incy); +} + +dcomplex zdotu (const f77_int* n, const dcomplex* x, const f77_int* incx, const dcomplex* y, const f77_int* incy) +{ + return zdotu_ ( n, x, incx, y, incy); +} + +dcomplex ZDOTU_(const f77_int* n, const dcomplex* x, const f77_int* incx, const dcomplex* y, const f77_int* incy) +{ + return zdotu_ ( n, x, incx, y, incy); +} +#else +void CDOTC(scomplex* retval,const f77_int *n, const scomplex *cx, const f77_int *incx, const scomplex *cy, const f77_int *incy) +{ + cdotc_( retval, n, cx, incx, cy, incy); +} + +void cdotc(scomplex* retval,const f77_int *n, const scomplex *cx, const f77_int *incx, const scomplex *cy, const f77_int *incy) +{ + cdotc_( retval, n, cx, incx, cy, incy); +} + +void CDOTC_(scomplex* retval,const f77_int *n, const scomplex *cx, const f77_int *incx, const scomplex *cy, const f77_int *incy) +{ + cdotc_( retval, n, cx, 
incx, cy, incy); +} + +void CDOTU(scomplex* retval,const f77_int *n, const scomplex *cx, const f77_int *incx, const scomplex *cy, const f77_int *incy) +{ + cdotu_( retval, n, cx, incx, cy, incy); +} + +void cdotu(scomplex* retval,const f77_int *n, const scomplex *cx, const f77_int *incx, const scomplex *cy, const f77_int *incy) +{ + cdotu_( retval, n, cx, incx, cy, incy); +} + +void CDOTU_(scomplex* retval,const f77_int *n, const scomplex *cx, const f77_int *incx, const scomplex *cy, const f77_int *incy) +{ + cdotu_( retval, n, cx, incx, cy, incy); +} + +void ZDOTC(dcomplex* retval,const f77_int *n, const dcomplex *zx, const f77_int *incx, const dcomplex *zy, const f77_int *incy) +{ + zdotc_( retval, n, zx, incx, zy, incy); +} + +void zdotc(dcomplex* retval,const f77_int *n, const dcomplex *zx, const f77_int *incx, const dcomplex *zy, const f77_int *incy) +{ + zdotc_( retval, n, zx, incx, zy, incy); +} + +void ZDOTC_(dcomplex* retval,const f77_int *n, const dcomplex *zx, const f77_int *incx, const dcomplex *zy, const f77_int *incy) +{ + zdotc_( retval, n, zx, incx, zy, incy); +} + +void ZDOTU(dcomplex* retval,const f77_int *n, const dcomplex *zx, const f77_int *incx, const dcomplex *zy, const f77_int *incy) +{ + zdotu_( retval, n, zx, incx, zy, incy); +} + +void zdotu(dcomplex* retval,const f77_int *n, const dcomplex *zx, const f77_int *incx, const dcomplex *zy, const f77_int *incy) +{ + zdotu_( retval, n, zx, incx, zy, incy); +} + +void ZDOTU_(dcomplex* retval,const f77_int *n, const dcomplex *zx, const f77_int *incx, const dcomplex *zy, const f77_int *incy) +{ + zdotu_( retval, n, zx, incx, zy, incy); +} +#endif + +void CGBMV(const char *trans,const f77_int *m,const f77_int *n,const f77_int *kl,const f77_int *ku,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy) +{ + cgbmv_( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); +} + +void cgbmv(const char *trans,const f77_int *m,const f77_int *n,const f77_int *kl,const f77_int *ku,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy) +{ + cgbmv_( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); +} + +void CGBMV_(const char *trans,const f77_int *m,const f77_int *n,const f77_int *kl,const f77_int *ku,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy) +{ + cgbmv_( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); +} + +void CGEMM(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) +{ + cgemm_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void cgemm(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) +{ + cgemm_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void CGEMM_(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) +{ + 
cgemm_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void CGEMV(const char *trans,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy) +{ + cgemv_( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); +} + +void cgemv(const char *trans,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy) +{ + cgemv_( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); +} + +void CGEMV_(const char *trans,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy) +{ + cgemv_( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); +} + +void CGERC(const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *x,const f77_int *incx,const scomplex *y,const f77_int *incy,scomplex *a,const f77_int *lda) +{ + cgerc_( m, n, alpha, x, incx, y, incy, a, lda); +} + +void cgerc(const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *x,const f77_int *incx,const scomplex *y,const f77_int *incy,scomplex *a,const f77_int *lda) +{ + cgerc_( m, n, alpha, x, incx, y, incy, a, lda); +} + +void CGERC_(const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *x,const f77_int *incx,const scomplex *y,const f77_int *incy,scomplex *a,const f77_int *lda) +{ + cgerc_( m, n, alpha, x, incx, y, incy, a, lda); +} + +void CGERU(const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *x,const f77_int *incx,const scomplex *y,const f77_int *incy,scomplex *a,const f77_int *lda) +{ + cgeru_( m, n, alpha, x, incx, y, incy, a, lda); +} + +void cgeru(const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *x,const f77_int *incx,const scomplex *y,const f77_int *incy,scomplex *a,const f77_int *lda) +{ + cgeru_( m, n, alpha, x, incx, y, incy, a, lda); +} + +void CGERU_(const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *x,const f77_int *incx,const scomplex *y,const f77_int *incy,scomplex *a,const f77_int *lda) +{ + cgeru_( m, n, alpha, x, incx, y, incy, a, lda); +} + +void CHBMV(const char *uplo,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy) +{ + chbmv_( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); +} + +void chbmv(const char *uplo,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy) +{ + chbmv_( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); +} + +void CHBMV_(const char *uplo,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy) +{ + chbmv_( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); +} + +void CHEMM(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) +{ + chemm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void chemm(const char *side,const char *uplo,const 
f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) +{ + chemm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void CHEMM_(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) +{ + chemm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void CHEMV(const char *uplo,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy) +{ + chemv_( uplo, n, alpha, a, lda, x, incx, beta, y, incy); +} + +void chemv(const char *uplo,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy) +{ + chemv_( uplo, n, alpha, a, lda, x, incx, beta, y, incy); +} + +void CHEMV_(const char *uplo,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy) +{ + chemv_( uplo, n, alpha, a, lda, x, incx, beta, y, incy); +} + +void CHER(const char *uplo,const f77_int *n,const float *alpha,const scomplex *x,const f77_int *incx,scomplex *a,const f77_int *lda) +{ + cher_( uplo, n, alpha, x, incx, a, lda); +} + +void cher(const char *uplo,const f77_int *n,const float *alpha,const scomplex *x,const f77_int *incx,scomplex *a,const f77_int *lda) +{ + cher_( uplo, n, alpha, x, incx, a, lda); +} + +void CHER_(const char *uplo,const f77_int *n,const float *alpha,const scomplex *x,const f77_int *incx,scomplex *a,const f77_int *lda) +{ + cher_( uplo, n, alpha, x, incx, a, lda); +} + +void CHER2(const char *uplo,const f77_int *n,const scomplex *alpha,const scomplex *x,const f77_int *incx,const scomplex *y,const f77_int *incy,scomplex *a,const f77_int *lda) +{ + cher2_( uplo, n, alpha, x, incx, y, incy, a, lda); +} + +void cher2(const char *uplo,const f77_int *n,const scomplex *alpha,const scomplex *x,const f77_int *incx,const scomplex *y,const f77_int *incy,scomplex *a,const f77_int *lda) +{ + cher2_( uplo, n, alpha, x, incx, y, incy, a, lda); +} + +void CHER2_(const char *uplo,const f77_int *n,const scomplex *alpha,const scomplex *x,const f77_int *incx,const scomplex *y,const f77_int *incy,scomplex *a,const f77_int *lda) +{ + cher2_( uplo, n, alpha, x, incx, y, incy, a, lda); +} + +void CHER2K(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const float *beta,scomplex *c,const f77_int *ldc) +{ + cher2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void cher2k(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const float *beta,scomplex *c,const f77_int *ldc) +{ + cher2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void CHER2K_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const float *beta,scomplex *c,const f77_int *ldc) +{ + cher2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void CHERK(const char 
*uplo,const char *trans,const f77_int *n,const f77_int *k,const float *alpha,const scomplex *a,const f77_int *lda,const float *beta,scomplex *c,const f77_int *ldc) +{ + cherk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); +} + +void cherk(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const float *alpha,const scomplex *a,const f77_int *lda,const float *beta,scomplex *c,const f77_int *ldc) +{ + cherk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); +} + +void CHERK_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const float *alpha,const scomplex *a,const f77_int *lda,const float *beta,scomplex *c,const f77_int *ldc) +{ + cherk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); +} + +void CHPMV(const char *uplo,const f77_int *n,const scomplex *alpha,const scomplex *ap,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy) +{ + chpmv_( uplo, n, alpha, ap, x, incx, beta, y, incy); +} + +void chpmv(const char *uplo,const f77_int *n,const scomplex *alpha,const scomplex *ap,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy) +{ + chpmv_( uplo, n, alpha, ap, x, incx, beta, y, incy); +} + +void CHPMV_(const char *uplo,const f77_int *n,const scomplex *alpha,const scomplex *ap,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy) +{ + chpmv_( uplo, n, alpha, ap, x, incx, beta, y, incy); +} + +void CHPR(const char *uplo,const f77_int *n,const float *alpha,const scomplex *x,const f77_int *incx,scomplex *ap) +{ + chpr_( uplo, n, alpha, x, incx, ap); +} + +void chpr(const char *uplo,const f77_int *n,const float *alpha,const scomplex *x,const f77_int *incx,scomplex *ap) +{ + chpr_( uplo, n, alpha, x, incx, ap); +} + +void CHPR_(const char *uplo,const f77_int *n,const float *alpha,const scomplex *x,const f77_int *incx,scomplex *ap) +{ + chpr_( uplo, n, alpha, x, incx, ap); +} + +void CHPR2(const char *uplo,const f77_int *n,const scomplex *alpha,const scomplex *x,const f77_int *incx,const scomplex *y,const f77_int *incy,scomplex *ap) +{ + chpr2_( uplo, n, alpha, x, incx, y, incy, ap); +} + +void chpr2(const char *uplo,const f77_int *n,const scomplex *alpha,const scomplex *x,const f77_int *incx,const scomplex *y,const f77_int *incy,scomplex *ap) +{ + chpr2_( uplo, n, alpha, x, incx, y, incy, ap); +} + +void CHPR2_(const char *uplo,const f77_int *n,const scomplex *alpha,const scomplex *x,const f77_int *incx,const scomplex *y,const f77_int *incy,scomplex *ap) +{ + chpr2_( uplo, n, alpha, x, incx, y, incy, ap); +} + +void CROTG(scomplex *ca, bla_scomplex *cb, bla_real *c,scomplex *s) +{ + crotg_( ca, cb, c, s); +} + +void crotg(scomplex *ca, bla_scomplex *cb, bla_real *c,scomplex *s) +{ + crotg_( ca, cb, c, s); +} + +void CROTG_(scomplex *ca, bla_scomplex *cb, bla_real *c,scomplex *s) +{ + crotg_( ca, cb, c, s); +} + +void CSCAL(const f77_int *n,const scomplex *ca,scomplex *cx,const f77_int *incx) +{ + cscal_( n, ca, cx, incx); +} + +void cscal(const f77_int *n,const scomplex *ca,scomplex *cx,const f77_int *incx) +{ + cscal_( n, ca, cx, incx); +} + +void CSCAL_(const f77_int *n,const scomplex *ca,scomplex *cx,const f77_int *incx) +{ + cscal_( n, ca, cx, incx); +} + +void CSROT(const f77_int *n,scomplex *cx,const f77_int *incx,scomplex *cy,const f77_int *incy,const float *c,const float *s) +{ + csrot_( n, cx, incx, cy, incy, c, s); +} + +void csrot(const f77_int *n,scomplex *cx,const f77_int *incx,scomplex *cy,const f77_int *incy,const float *c,const 
float *s) +{ + csrot_( n, cx, incx, cy, incy, c, s); +} + +void CSROT_(const f77_int *n,scomplex *cx,const f77_int *incx,scomplex *cy,const f77_int *incy,const float *c,const float *s) +{ + csrot_( n, cx, incx, cy, incy, c, s); +} + +void CSSCAL(const f77_int *n,const float *sa,scomplex *cx,const f77_int *incx) +{ + csscal_( n, sa, cx, incx); +} + +void csscal(const f77_int *n,const float *sa,scomplex *cx,const f77_int *incx) +{ + csscal_( n, sa, cx, incx); +} + +void CSSCAL_(const f77_int *n,const float *sa,scomplex *cx,const f77_int *incx) +{ + csscal_( n, sa, cx, incx); +} + +void CSWAP(const f77_int *n,scomplex *cx,const f77_int *incx,scomplex *cy,const f77_int *incy) +{ + cswap_( n, cx, incx, cy, incy); +} + +void cswap(const f77_int *n,scomplex *cx,const f77_int *incx,scomplex *cy,const f77_int *incy) +{ + cswap_( n, cx, incx, cy, incy); +} + +void CSWAP_(const f77_int *n,scomplex *cx,const f77_int *incx,scomplex *cy,const f77_int *incy) +{ + cswap_( n, cx, incx, cy, incy); +} + +void CSYMM(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) +{ + csymm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void csymm(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) +{ + csymm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void CSYMM_(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) +{ + csymm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void CSYR2K(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) +{ + csyr2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void csyr2k(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) +{ + csyr2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void CSYR2K_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) +{ + csyr2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void CSYRK(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *beta,scomplex *c,const f77_int *ldc) +{ + csyrk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); +} + +void csyrk(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *beta,scomplex *c,const f77_int *ldc) +{ + csyrk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); +} + +void CSYRK_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *beta,scomplex *c,const f77_int *ldc) +{ + 
csyrk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); +} + +void CTBMV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const scomplex *a,const f77_int *lda,scomplex *x,const f77_int *incx) +{ + ctbmv_( uplo, trans, diag, n, k, a, lda, x, incx); +} + +void ctbmv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const scomplex *a,const f77_int *lda,scomplex *x,const f77_int *incx) +{ + ctbmv_( uplo, trans, diag, n, k, a, lda, x, incx); +} + +void CTBMV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const scomplex *a,const f77_int *lda,scomplex *x,const f77_int *incx) +{ + ctbmv_( uplo, trans, diag, n, k, a, lda, x, incx); +} + +void CTBSV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const scomplex *a,const f77_int *lda,scomplex *x,const f77_int *incx) +{ + ctbsv_( uplo, trans, diag, n, k, a, lda, x, incx); +} + +void ctbsv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const scomplex *a,const f77_int *lda,scomplex *x,const f77_int *incx) +{ + ctbsv_( uplo, trans, diag, n, k, a, lda, x, incx); +} + +void CTBSV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const scomplex *a,const f77_int *lda,scomplex *x,const f77_int *incx) +{ + ctbsv_( uplo, trans, diag, n, k, a, lda, x, incx); +} + +void CTPMV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const scomplex *ap,scomplex *x,const f77_int *incx) +{ + ctpmv_( uplo, trans, diag, n, ap, x, incx); +} + +void ctpmv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const scomplex *ap,scomplex *x,const f77_int *incx) +{ + ctpmv_( uplo, trans, diag, n, ap, x, incx); +} + +void CTPMV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const scomplex *ap,scomplex *x,const f77_int *incx) +{ + ctpmv_( uplo, trans, diag, n, ap, x, incx); +} + +void CTPSV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const scomplex *ap,scomplex *x,const f77_int *incx) +{ + ctpsv_( uplo, trans, diag, n, ap, x, incx); +} + +void ctpsv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const scomplex *ap,scomplex *x,const f77_int *incx) +{ + ctpsv_( uplo, trans, diag, n, ap, x, incx); +} + +void CTPSV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const scomplex *ap,scomplex *x,const f77_int *incx) +{ + ctpsv_( uplo, trans, diag, n, ap, x, incx); +} + +void CTRMM(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,scomplex *b,const f77_int *ldb) +{ + ctrmm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); +} + +void ctrmm(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,scomplex *b,const f77_int *ldb) +{ + ctrmm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); +} + +void CTRMM_(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,scomplex *b,const f77_int *ldb) +{ + ctrmm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); +} + +void CTRMV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const scomplex *a,const f77_int *lda,scomplex *x,const f77_int 
*incx) +{ + ctrmv_( uplo, trans, diag, n, a, lda, x, incx); +} + +void ctrmv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const scomplex *a,const f77_int *lda,scomplex *x,const f77_int *incx) +{ + ctrmv_( uplo, trans, diag, n, a, lda, x, incx); +} + +void CTRMV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const scomplex *a,const f77_int *lda,scomplex *x,const f77_int *incx) +{ + ctrmv_( uplo, trans, diag, n, a, lda, x, incx); +} + +void CTRSM(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,scomplex *b,const f77_int *ldb) +{ + ctrsm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); +} + +void ctrsm(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,scomplex *b,const f77_int *ldb) +{ + ctrsm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); +} + +void CTRSM_(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,scomplex *b,const f77_int *ldb) +{ + ctrsm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); +} + +void CTRSV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const scomplex *a,const f77_int *lda,scomplex *x,const f77_int *incx) +{ + ctrsv_( uplo, trans, diag, n, a, lda, x, incx); +} + +void ctrsv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const scomplex *a,const f77_int *lda,scomplex *x,const f77_int *incx) +{ + ctrsv_( uplo, trans, diag, n, a, lda, x, incx); +} + +void CTRSV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const scomplex *a,const f77_int *lda,scomplex *x,const f77_int *incx) +{ + ctrsv_( uplo, trans, diag, n, a, lda, x, incx); +} + +double DASUM(const f77_int *n,const double *dx,const f77_int *incx) +{ + return dasum_( n, dx, incx); +} + +double dasum(const f77_int *n,const double *dx,const f77_int *incx) +{ + return dasum_( n, dx, incx); +} + +double DASUM_(const f77_int *n,const double *dx,const f77_int *incx) +{ + return dasum_( n, dx, incx); +} + +void DAXPY(const f77_int *n,const double *da,const double *dx,const f77_int *incx,double *dy,const f77_int *incy) +{ + daxpy_( n, da, dx, incx, dy, incy); +} + +void daxpy(const f77_int *n,const double *da,const double *dx,const f77_int *incx,double *dy,const f77_int *incy) +{ + daxpy_( n, da, dx, incx, dy, incy); +} + +void DAXPY_(const f77_int *n,const double *da,const double *dx,const f77_int *incx,double *dy,const f77_int *incy) +{ + daxpy_( n, da, dx, incx, dy, incy); +} + +double DCABS1(bla_dcomplex *z) +{ + return dcabs1_( z); +} + +double dcabs1(bla_dcomplex *z) +{ + return dcabs1_( z); +} + +double DCABS1_(bla_dcomplex *z) +{ + return dcabs1_( z); +} + +void DCOPY(const f77_int *n,const double *dx,const f77_int *incx,double *dy,const f77_int *incy) +{ + dcopy_( n, dx, incx, dy, incy); +} + +void dcopy(const f77_int *n,const double *dx,const f77_int *incx,double *dy,const f77_int *incy) +{ + dcopy_( n, dx, incx, dy, incy); +} + +void DCOPY_(const f77_int *n,const double *dx,const f77_int *incx,double *dy,const f77_int *incy) +{ + dcopy_( n, dx, incx, dy, incy); +} + +double DDOT(const f77_int *n,const double *dx,const f77_int *incx,const double *dy,const f77_int *incy) +{ + return ddot_( n, dx, incx, dy, incy); +} + +double ddot(const f77_int 
*n,const double *dx,const f77_int *incx,const double *dy,const f77_int *incy) +{ + return ddot_( n, dx, incx, dy, incy); +} + +double DDOT_(const f77_int *n,const double *dx,const f77_int *incx,const double *dy,const f77_int *incy) +{ + return ddot_( n, dx, incx, dy, incy); +} + +void DGBMV(const char *trans,const f77_int *m,const f77_int *n,const f77_int *kl,const f77_int *ku,const double *alpha,const double *a,const f77_int *lda,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) +{ + dgbmv_( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); +} + +void dgbmv(const char *trans,const f77_int *m,const f77_int *n,const f77_int *kl,const f77_int *ku,const double *alpha,const double *a,const f77_int *lda,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) +{ + dgbmv_( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); +} + +void DGBMV_(const char *trans,const f77_int *m,const f77_int *n,const f77_int *kl,const f77_int *ku,const double *alpha,const double *a,const f77_int *lda,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) +{ + dgbmv_( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); +} + +void DGEMM(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *b,const f77_int *ldb,const double *beta,double *c,const f77_int *ldc) +{ + dgemm_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void dgemm(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *b,const f77_int *ldb,const double *beta,double *c,const f77_int *ldc) +{ + dgemm_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void DGEMM_(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *b,const f77_int *ldb,const double *beta,double *c,const f77_int *ldc) +{ + dgemm_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void DGEMV(const char *trans,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) +{ + dgemv_( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); +} + +void dgemv(const char *trans,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) +{ + dgemv_( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); +} + +void DGEMV_(const char *trans,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) +{ + dgemv_( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); +} + +void DGER(const f77_int *m,const f77_int *n,const double *alpha,const double *x,const f77_int *incx,const double *y,const f77_int *incy,double *a,const f77_int *lda) +{ + dger_( m, n, alpha, x, incx, y, incy, a, lda); +} + +void dger(const f77_int *m,const f77_int *n,const double *alpha,const double *x,const f77_int *incx,const double *y,const f77_int *incy,double *a,const f77_int *lda) +{ + dger_( m, n, alpha, x, incx, y, incy, a, lda); +} + +void DGER_(const f77_int *m,const f77_int *n,const double *alpha,const 
double *x,const f77_int *incx,const double *y,const f77_int *incy,double *a,const f77_int *lda) +{ + dger_( m, n, alpha, x, incx, y, incy, a, lda); +} + +double DNRM2(const f77_int *n,const double *x,const f77_int *incx) +{ + return dnrm2_( n, x, incx); +} + +double dnrm2(const f77_int *n,const double *x,const f77_int *incx) +{ + return dnrm2_( n, x, incx); +} + +double DNRM2_(const f77_int *n,const double *x,const f77_int *incx) +{ + return dnrm2_( n, x, incx); +} + +void DROT(const f77_int *n,double *dx,const f77_int *incx,double *dy,const f77_int *incy,const double *c,const double *s) +{ + drot_( n, dx, incx, dy, incy, c, s); +} + +void drot(const f77_int *n,double *dx,const f77_int *incx,double *dy,const f77_int *incy,const double *c,const double *s) +{ + drot_( n, dx, incx, dy, incy, c, s); +} + +void DROT_(const f77_int *n,double *dx,const f77_int *incx,double *dy,const f77_int *incy,const double *c,const double *s) +{ + drot_( n, dx, incx, dy, incy, c, s); +} + +void DROTG(double *da,double *db,double *c,double *s) +{ + drotg_( da, db, c, s); +} + +void drotg(double *da,double *db,double *c,double *s) +{ + drotg_( da, db, c, s); +} + +void DROTG_(double *da,double *db,double *c,double *s) +{ + drotg_( da, db, c, s); +} + +void DROTM(const f77_int *n,double *dx,const f77_int *incx,double *dy,const f77_int *incy,const double *dparam) +{ + drotm_( n, dx, incx, dy, incy, dparam); +} + +void drotm(const f77_int *n,double *dx,const f77_int *incx,double *dy,const f77_int *incy,const double *dparam) +{ + drotm_( n, dx, incx, dy, incy, dparam); +} + +void DROTM_(const f77_int *n,double *dx,const f77_int *incx,double *dy,const f77_int *incy,const double *dparam) +{ + drotm_( n, dx, incx, dy, incy, dparam); +} + +void DROTMG(double *dd1,double *dd2,double *dx1,const double *dy1,double *dparam) +{ + drotmg_( dd1, dd2, dx1, dy1, dparam); +} + +void drotmg(double *dd1,double *dd2,double *dx1,const double *dy1,double *dparam) +{ + drotmg_( dd1, dd2, dx1, dy1, dparam); +} + +void DROTMG_(double *dd1,double *dd2,double *dx1,const double *dy1,double *dparam) +{ + drotmg_( dd1, dd2, dx1, dy1, dparam); +} + +void DSBMV(const char *uplo,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) +{ + dsbmv_( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); +} + +void dsbmv(const char *uplo,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) +{ + dsbmv_( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); +} + +void DSBMV_(const char *uplo,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) +{ + dsbmv_( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); +} + +void DSCAL(const f77_int *n,const double *da,double *dx,const f77_int *incx) +{ + dscal_( n, da, dx, incx); +} + +void dscal(const f77_int *n,const double *da,double *dx,const f77_int *incx) +{ + dscal_( n, da, dx, incx); +} + +void DSCAL_(const f77_int *n,const double *da,double *dx,const f77_int *incx) +{ + dscal_( n, da, dx, incx); +} + +double DSDOT(const f77_int *n,const float *sx,const f77_int *incx,const float *sy,const f77_int *incy) +{ + return dsdot_( n, sx, incx, sy, incy); +} + +double dsdot(const f77_int *n,const float *sx,const f77_int *incx,const float *sy,const f77_int *incy) 
+{ + return dsdot_( n, sx, incx, sy, incy); +} + +double DSDOT_(const f77_int *n,const float *sx,const f77_int *incx,const float *sy,const f77_int *incy) +{ + return dsdot_( n, sx, incx, sy, incy); +} + +void DSPMV(const char *uplo,const f77_int *n,const double *alpha,const double *ap,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) +{ + dspmv_( uplo, n, alpha, ap, x, incx, beta, y, incy); +} + +void dspmv(const char *uplo,const f77_int *n,const double *alpha,const double *ap,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) +{ + dspmv_( uplo, n, alpha, ap, x, incx, beta, y, incy); +} + +void DSPMV_(const char *uplo,const f77_int *n,const double *alpha,const double *ap,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) +{ + dspmv_( uplo, n, alpha, ap, x, incx, beta, y, incy); +} + +void DSPR(const char *uplo,const f77_int *n,const double *alpha,const double *x,const f77_int *incx,double *ap) +{ + dspr_( uplo, n, alpha, x, incx, ap); +} + +void dspr(const char *uplo,const f77_int *n,const double *alpha,const double *x,const f77_int *incx,double *ap) +{ + dspr_( uplo, n, alpha, x, incx, ap); +} + +void DSPR_(const char *uplo,const f77_int *n,const double *alpha,const double *x,const f77_int *incx,double *ap) +{ + dspr_( uplo, n, alpha, x, incx, ap); +} + +void DSPR2(const char *uplo,const f77_int *n,const double *alpha,const double *x,const f77_int *incx,const double *y,const f77_int *incy,double *ap) +{ + dspr2_( uplo, n, alpha, x, incx, y, incy, ap); +} + +void dspr2(const char *uplo,const f77_int *n,const double *alpha,const double *x,const f77_int *incx,const double *y,const f77_int *incy,double *ap) +{ + dspr2_( uplo, n, alpha, x, incx, y, incy, ap); +} + +void DSPR2_(const char *uplo,const f77_int *n,const double *alpha,const double *x,const f77_int *incx,const double *y,const f77_int *incy,double *ap) +{ + dspr2_( uplo, n, alpha, x, incx, y, incy, ap); +} + +void DSWAP(const f77_int *n,double *dx,const f77_int *incx,double *dy,const f77_int *incy) +{ + dswap_( n, dx, incx, dy, incy); +} + +void dswap(const f77_int *n,double *dx,const f77_int *incx,double *dy,const f77_int *incy) +{ + dswap_( n, dx, incx, dy, incy); +} + +void DSWAP_(const f77_int *n,double *dx,const f77_int *incx,double *dy,const f77_int *incy) +{ + dswap_( n, dx, incx, dy, incy); +} + +void DSYMM(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,const double *b,const f77_int *ldb,const double *beta,double *c,const f77_int *ldc) +{ + dsymm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void dsymm(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,const double *b,const f77_int *ldb,const double *beta,double *c,const f77_int *ldc) +{ + dsymm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void DSYMM_(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,const double *b,const f77_int *ldb,const double *beta,double *c,const f77_int *ldc) +{ + dsymm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void DSYMV(const char *uplo,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) +{ + dsymv_( uplo, n, alpha, a, lda, x, incx, beta, y, incy); +} + +void 
dsymv(const char *uplo,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) +{ + dsymv_( uplo, n, alpha, a, lda, x, incx, beta, y, incy); +} + +void DSYMV_(const char *uplo,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) +{ + dsymv_( uplo, n, alpha, a, lda, x, incx, beta, y, incy); +} + +void DSYR(const char *uplo,const f77_int *n,const double *alpha,const double *x,const f77_int *incx,double *a,const f77_int *lda) +{ + dsyr_( uplo, n, alpha, x, incx, a, lda); +} + +void dsyr(const char *uplo,const f77_int *n,const double *alpha,const double *x,const f77_int *incx,double *a,const f77_int *lda) +{ + dsyr_( uplo, n, alpha, x, incx, a, lda); +} + +void DSYR_(const char *uplo,const f77_int *n,const double *alpha,const double *x,const f77_int *incx,double *a,const f77_int *lda) +{ + dsyr_( uplo, n, alpha, x, incx, a, lda); +} + +void DSYR2(const char *uplo,const f77_int *n,const double *alpha,const double *x,const f77_int *incx,const double *y,const f77_int *incy,double *a,const f77_int *lda) +{ + dsyr2_( uplo, n, alpha, x, incx, y, incy, a, lda); +} + +void dsyr2(const char *uplo,const f77_int *n,const double *alpha,const double *x,const f77_int *incx,const double *y,const f77_int *incy,double *a,const f77_int *lda) +{ + dsyr2_( uplo, n, alpha, x, incx, y, incy, a, lda); +} + +void DSYR2_(const char *uplo,const f77_int *n,const double *alpha,const double *x,const f77_int *incx,const double *y,const f77_int *incy,double *a,const f77_int *lda) +{ + dsyr2_( uplo, n, alpha, x, incx, y, incy, a, lda); +} + +void DSYR2K(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *b,const f77_int *ldb,const double *beta,double *c,const f77_int *ldc) +{ + dsyr2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void dsyr2k(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *b,const f77_int *ldb,const double *beta,double *c,const f77_int *ldc) +{ + dsyr2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void DSYR2K_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *b,const f77_int *ldb,const double *beta,double *c,const f77_int *ldc) +{ + dsyr2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void DSYRK(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *beta,double *c,const f77_int *ldc) +{ + dsyrk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); +} + +void dsyrk(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *beta,double *c,const f77_int *ldc) +{ + dsyrk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); +} + +void DSYRK_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *beta,double *c,const f77_int *ldc) +{ + dsyrk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); +} + +void DTBMV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const double *a,const f77_int *lda,double *x,const f77_int *incx) +{ 
+ dtbmv_( uplo, trans, diag, n, k, a, lda, x, incx); +} + +void dtbmv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const double *a,const f77_int *lda,double *x,const f77_int *incx) +{ + dtbmv_( uplo, trans, diag, n, k, a, lda, x, incx); +} + +void DTBMV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const double *a,const f77_int *lda,double *x,const f77_int *incx) +{ + dtbmv_( uplo, trans, diag, n, k, a, lda, x, incx); +} + +void DTBSV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const double *a,const f77_int *lda,double *x,const f77_int *incx) +{ + dtbsv_( uplo, trans, diag, n, k, a, lda, x, incx); +} + +void dtbsv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const double *a,const f77_int *lda,double *x,const f77_int *incx) +{ + dtbsv_( uplo, trans, diag, n, k, a, lda, x, incx); +} + +void DTBSV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const double *a,const f77_int *lda,double *x,const f77_int *incx) +{ + dtbsv_( uplo, trans, diag, n, k, a, lda, x, incx); +} + +void DTPMV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const double *ap,double *x,const f77_int *incx) +{ + dtpmv_( uplo, trans, diag, n, ap, x, incx); +} + +void dtpmv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const double *ap,double *x,const f77_int *incx) +{ + dtpmv_( uplo, trans, diag, n, ap, x, incx); +} + +void DTPMV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const double *ap,double *x,const f77_int *incx) +{ + dtpmv_( uplo, trans, diag, n, ap, x, incx); +} + +void DTPSV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const double *ap,double *x,const f77_int *incx) +{ + dtpsv_( uplo, trans, diag, n, ap, x, incx); +} + +void dtpsv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const double *ap,double *x,const f77_int *incx) +{ + dtpsv_( uplo, trans, diag, n, ap, x, incx); +} + +void DTPSV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const double *ap,double *x,const f77_int *incx) +{ + dtpsv_( uplo, trans, diag, n, ap, x, incx); +} + +void DTRMM(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,double *b,const f77_int *ldb) +{ + dtrmm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); +} + +void dtrmm(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,double *b,const f77_int *ldb) +{ + dtrmm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); +} + +void DTRMM_(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,double *b,const f77_int *ldb) +{ + dtrmm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); +} + +void DTRMV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const double *a,const f77_int *lda,double *x,const f77_int *incx) +{ + dtrmv_( uplo, trans, diag, n, a, lda, x, incx); +} + +void dtrmv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const double *a,const f77_int *lda,double *x,const f77_int *incx) +{ + dtrmv_( uplo, trans, diag, n, a, lda, x, incx); +} + +void DTRMV_(const char 
*uplo,const char *trans,const char *diag,const f77_int *n,const double *a,const f77_int *lda,double *x,const f77_int *incx) +{ + dtrmv_( uplo, trans, diag, n, a, lda, x, incx); +} + +void DTRSM(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,double *b,const f77_int *ldb) +{ + dtrsm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); +} + +void dtrsm(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,double *b,const f77_int *ldb) +{ + dtrsm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); +} + +void DTRSM_(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,double *b,const f77_int *ldb) +{ + dtrsm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); +} + +void DTRSV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const double *a,const f77_int *lda,double *x,const f77_int *incx) +{ + dtrsv_( uplo, trans, diag, n, a, lda, x, incx); +} + +void dtrsv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const double *a,const f77_int *lda,double *x,const f77_int *incx) +{ + dtrsv_( uplo, trans, diag, n, a, lda, x, incx); +} + +void DTRSV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const double *a,const f77_int *lda,double *x,const f77_int *incx) +{ + dtrsv_( uplo, trans, diag, n, a, lda, x, incx); +} + +double DZASUM(const f77_int *n,const dcomplex *zx,const f77_int *incx) +{ + return dzasum_( n, zx, incx); +} + +double dzasum(const f77_int *n,const dcomplex *zx,const f77_int *incx) +{ + return dzasum_( n, zx, incx); +} + +double DZASUM_(const f77_int *n,const dcomplex *zx,const f77_int *incx) +{ + return dzasum_( n, zx, incx); +} + +double DZNRM2(const f77_int *n,const dcomplex *x,const f77_int *incx) +{ + return dznrm2_( n, x, incx); +} + +double dznrm2(const f77_int *n,const dcomplex *x,const f77_int *incx) +{ + return dznrm2_( n, x, incx); +} + +double DZNRM2_(const f77_int *n,const dcomplex *x,const f77_int *incx) +{ + return dznrm2_( n, x, incx); +} + +f77_int ICAMAX(const f77_int *n,const scomplex *cx,const f77_int *incx) +{ + return icamax_( n, cx, incx); +} + +f77_int icamax(const f77_int *n,const scomplex *cx,const f77_int *incx) +{ + return icamax_( n, cx, incx); +} + +f77_int ICAMAX_(const f77_int *n,const scomplex *cx,const f77_int *incx) +{ + return icamax_( n, cx, incx); +} + +f77_int IDAMAX(const f77_int *n,const double *dx,const f77_int *incx) +{ + return idamax_( n, dx, incx); +} + +f77_int idamax(const f77_int *n,const double *dx,const f77_int *incx) +{ + return idamax_( n, dx, incx); +} + +f77_int IDAMAX_(const f77_int *n,const double *dx,const f77_int *incx) +{ + return idamax_( n, dx, incx); +} + +f77_int ISAMAX(const f77_int *n,const float *sx,const f77_int *incx) +{ + return isamax_( n, sx, incx); +} + +f77_int isamax(const f77_int *n,const float *sx,const f77_int *incx) +{ + return isamax_( n, sx, incx); +} + +f77_int ISAMAX_(const f77_int *n,const float *sx,const f77_int *incx) +{ + return isamax_( n, sx, incx); +} + +f77_int IZAMAX(const f77_int *n,const dcomplex *zx,const f77_int *incx) +{ + return izamax_( n, zx, incx); +} + +f77_int izamax(const f77_int *n,const dcomplex *zx,const f77_int *incx) +{ + return izamax_( n, zx, incx); +} + +f77_int IZAMAX_(const f77_int 
*n,const dcomplex *zx,const f77_int *incx) +{ + return izamax_( n, zx, incx); +} + +f77_int LSAME(const char *ca,const char *cb,const f77_int a,const f77_int b) +{ + return lsame_( ca, cb, a, b); +} + +f77_int LSAME_(const char *ca,const char *cb,const f77_int a,const f77_int b) +{ + return lsame_( ca, cb, a, b); +} + +f77_int lsame(const char *ca,const char *cb,const f77_int a,const f77_int b) +{ + return lsame_( ca, cb, a, b); +} + +float SASUM(const f77_int *n,const float *sx, const f77_int *incx) +{ + return sasum_( n, sx, incx); +} + +float sasum(const f77_int *n,const float *sx, const f77_int *incx) +{ + return sasum_( n, sx, incx); +} + +float SASUM_(const f77_int *n,const float *sx, const f77_int *incx) +{ + return sasum_( n, sx, incx); +} + +void SAXPY(const f77_int *n,const float *sa,const float *sx,const f77_int *incx,float *sy,const f77_int *incy) +{ + saxpy_( n, sa, sx, incx, sy, incy); +} + +void saxpy(const f77_int *n,const float *sa,const float *sx,const f77_int *incx,float *sy,const f77_int *incy) +{ + saxpy_( n, sa, sx, incx, sy, incy); +} + +void SAXPY_(const f77_int *n,const float *sa,const float *sx,const f77_int *incx,float *sy,const f77_int *incy) +{ + saxpy_( n, sa, sx, incx, sy, incy); +} + + +float SCASUM(const f77_int *n,const scomplex *cx, const f77_int *incx) +{ + return scasum_( n, cx, incx); +} + +float scasum(const f77_int *n,const scomplex *cx, const f77_int *incx) +{ + return scasum_( n, cx, incx); +} + +float SCASUM_(const f77_int *n,const scomplex *cx, const f77_int *incx) +{ + return scasum_( n, cx, incx); +} + + + +float SCNRM2(const f77_int *n,const scomplex *x, const f77_int *incx) +{ + return scnrm2_( n, x, incx); +} + +float scnrm2(const f77_int *n,const scomplex *x, const f77_int *incx) +{ + return scnrm2_( n, x, incx); +} + +float SCNRM2_(const f77_int *n,const scomplex *x, const f77_int *incx) +{ + return scnrm2_( n, x, incx); +} + + +void SCOPY(const f77_int *n,const float *sx,const f77_int *incx,float *sy,const f77_int *incy) +{ + scopy_( n, sx, incx, sy, incy); +} + +void scopy(const f77_int *n,const float *sx,const f77_int *incx,float *sy,const f77_int *incy) +{ + scopy_( n, sx, incx, sy, incy); +} + +void SCOPY_(const f77_int *n,const float *sx,const f77_int *incx,float *sy,const f77_int *incy) +{ + scopy_( n, sx, incx, sy, incy); +} + + +float SDOT(const f77_int *n,const float *sx, const f77_int *incx, const float *sy, const f77_int *incy) +{ + return sdot_( n, sx, incx, sy, incy); +} + +float sdot(const f77_int *n,const float *sx, const f77_int *incx, const float *sy, const f77_int *incy) +{ + return sdot_( n, sx, incx, sy, incy); +} + +float SDOT_(const f77_int *n,const float *sx, const f77_int *incx, const float *sy, const f77_int *incy) +{ + return sdot_( n, sx, incx, sy, incy); +} + + +float SDSDOT(const f77_int *n,const float *sb, const float *sx, const f77_int *incx, const float *sy, const f77_int *incy) +{ + return sdsdot_( n, sb, sx, incx, sy, incy); +} + +float sdsdot(const f77_int *n,const float *sb, const float *sx, const f77_int *incx, const float *sy, const f77_int *incy) +{ + return sdsdot_( n, sb, sx, incx, sy, incy); +} + +float SDSDOT_(const f77_int *n,const float *sb, const float *sx, const f77_int *incx, const float *sy, const f77_int *incy) +{ + return sdsdot_( n, sb, sx, incx, sy, incy); +} + + +void SGBMV(const char *trans,const f77_int *m,const f77_int *n,const f77_int *kl,const f77_int *ku,const float *alpha,const float *a,const f77_int *lda,const float *x,const f77_int *incx,const float *beta,float *y,const 
f77_int *incy) +{ + sgbmv_( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); +} + +void sgbmv(const char *trans,const f77_int *m,const f77_int *n,const f77_int *kl,const f77_int *ku,const float *alpha,const float *a,const f77_int *lda,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) +{ + sgbmv_( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); +} + +void SGBMV_(const char *trans,const f77_int *m,const f77_int *n,const f77_int *kl,const f77_int *ku,const float *alpha,const float *a,const f77_int *lda,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) +{ + sgbmv_( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); +} + +void SGEMM(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *b,const f77_int *ldb,const float *beta,float *c,const f77_int *ldc) +{ + sgemm_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void sgemm(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *b,const f77_int *ldb,const float *beta,float *c,const f77_int *ldc) +{ + sgemm_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void SGEMM_(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *b,const f77_int *ldb,const float *beta,float *c,const f77_int *ldc) +{ + sgemm_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void SGEMV(const char *trans,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) +{ + sgemv_( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); +} + +void sgemv(const char *trans,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) +{ + sgemv_( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); +} + +void SGEMV_(const char *trans,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) +{ + sgemv_( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); +} + +void SGER(const f77_int *m,const f77_int *n,const float *alpha,const float *x,const f77_int *incx,const float *y,const f77_int *incy,float *a,const f77_int *lda) +{ + sger_( m, n, alpha, x, incx, y, incy, a, lda); +} + +void sger(const f77_int *m,const f77_int *n,const float *alpha,const float *x,const f77_int *incx,const float *y,const f77_int *incy,float *a,const f77_int *lda) +{ + sger_( m, n, alpha, x, incx, y, incy, a, lda); +} + +void SGER_(const f77_int *m,const f77_int *n,const float *alpha,const float *x,const f77_int *incx,const float *y,const f77_int *incy,float *a,const f77_int *lda) +{ + sger_( m, n, alpha, x, incx, y, incy, a, lda); +} + + +float SNRM2(const f77_int *n,const float *x, const f77_int *incx) +{ + return snrm2_( n, x, incx); +} + +float snrm2(const f77_int *n,const float *x, const f77_int *incx) +{ + return snrm2_( n, x, incx); +} + +float SNRM2_(const f77_int *n,const float *x, const f77_int *incx) +{ + return snrm2_( n, x, incx); +} + + +void SROT(const f77_int *n,float *sx,const f77_int *incx,float *sy,const f77_int 
*incy,const float *c,const float *s) +{ + srot_( n, sx, incx, sy, incy, c, s); +} + +void srot(const f77_int *n,float *sx,const f77_int *incx,float *sy,const f77_int *incy,const float *c,const float *s) +{ + srot_( n, sx, incx, sy, incy, c, s); +} + +void SROT_(const f77_int *n,float *sx,const f77_int *incx,float *sy,const f77_int *incy,const float *c,const float *s) +{ + srot_( n, sx, incx, sy, incy, c, s); +} + +void SROTG(float *sa,float *sb,float *c,float *s) +{ + srotg_( sa, sb, c, s); +} + +void srotg(float *sa,float *sb,float *c,float *s) +{ + srotg_( sa, sb, c, s); +} + +void SROTG_(float *sa,float *sb,float *c,float *s) +{ + srotg_( sa, sb, c, s); +} + +void SROTM(const f77_int *n,float *sx,const f77_int *incx,float *sy,const f77_int *incy,const float *sparam) +{ + srotm_( n, sx, incx, sy, incy, sparam); +} + +void srotm(const f77_int *n,float *sx,const f77_int *incx,float *sy,const f77_int *incy,const float *sparam) +{ + srotm_( n, sx, incx, sy, incy, sparam); +} + +void SROTM_(const f77_int *n,float *sx,const f77_int *incx,float *sy,const f77_int *incy,const float *sparam) +{ + srotm_( n, sx, incx, sy, incy, sparam); +} + +void SROTMG(float *sd1,float *sd2,float *sx1,const float *sy1,float *sparam) +{ + srotmg_( sd1, sd2, sx1, sy1, sparam); +} + +void srotmg(float *sd1,float *sd2,float *sx1,const float *sy1,float *sparam) +{ + srotmg_( sd1, sd2, sx1, sy1, sparam); +} + +void SROTMG_(float *sd1,float *sd2,float *sx1,const float *sy1,float *sparam) +{ + srotmg_( sd1, sd2, sx1, sy1, sparam); +} + +void SSBMV(const char *uplo,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) +{ + ssbmv_( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); +} + +void ssbmv(const char *uplo,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) +{ + ssbmv_( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); +} + +void SSBMV_(const char *uplo,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) +{ + ssbmv_( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); +} + +void SSCAL(const f77_int *n,const float *sa,float *sx,const f77_int *incx) +{ + sscal_( n, sa, sx, incx); +} + +void sscal(const f77_int *n,const float *sa,float *sx,const f77_int *incx) +{ + sscal_( n, sa, sx, incx); +} + +void SSCAL_(const f77_int *n,const float *sa,float *sx,const f77_int *incx) +{ + sscal_( n, sa, sx, incx); +} + +void SSPMV(const char *uplo,const f77_int *n,const float *alpha,const float *ap,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) +{ + sspmv_( uplo, n, alpha, ap, x, incx, beta, y, incy); +} + +void sspmv(const char *uplo,const f77_int *n,const float *alpha,const float *ap,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) +{ + sspmv_( uplo, n, alpha, ap, x, incx, beta, y, incy); +} + +void SSPMV_(const char *uplo,const f77_int *n,const float *alpha,const float *ap,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) +{ + sspmv_( uplo, n, alpha, ap, x, incx, beta, y, incy); +} + +void SSPR(const char *uplo,const f77_int *n,const float *alpha,const float *x,const f77_int *incx,float *ap) +{ + sspr_( uplo, n, alpha, x, incx, ap); +} + +void sspr(const char *uplo,const 
f77_int *n,const float *alpha,const float *x,const f77_int *incx,float *ap) +{ + sspr_( uplo, n, alpha, x, incx, ap); +} + +void SSPR_(const char *uplo,const f77_int *n,const float *alpha,const float *x,const f77_int *incx,float *ap) +{ + sspr_( uplo, n, alpha, x, incx, ap); +} + +void SSPR2(const char *uplo,const f77_int *n,const float *alpha,const float *x,const f77_int *incx,const float *y,const f77_int *incy,float *ap) +{ + sspr2_( uplo, n, alpha, x, incx, y, incy, ap); +} + +void sspr2(const char *uplo,const f77_int *n,const float *alpha,const float *x,const f77_int *incx,const float *y,const f77_int *incy,float *ap) +{ + sspr2_( uplo, n, alpha, x, incx, y, incy, ap); +} + +void SSPR2_(const char *uplo,const f77_int *n,const float *alpha,const float *x,const f77_int *incx,const float *y,const f77_int *incy,float *ap) +{ + sspr2_( uplo, n, alpha, x, incx, y, incy, ap); +} + +void SSWAP(const f77_int *n,float *sx,const f77_int *incx,float *sy,const f77_int *incy) +{ + sswap_( n, sx, incx, sy, incy); +} + +void sswap(const f77_int *n,float *sx,const f77_int *incx,float *sy,const f77_int *incy) +{ + sswap_( n, sx, incx, sy, incy); +} + +void SSWAP_(const f77_int *n,float *sx,const f77_int *incx,float *sy,const f77_int *incy) +{ + sswap_( n, sx, incx, sy, incy); +} + +void SSYMM(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,const float *b,const f77_int *ldb,const float *beta,float *c,const f77_int *ldc) +{ + ssymm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void ssymm(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,const float *b,const f77_int *ldb,const float *beta,float *c,const f77_int *ldc) +{ + ssymm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void SSYMM_(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,const float *b,const f77_int *ldb,const float *beta,float *c,const f77_int *ldc) +{ + ssymm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void SSYMV(const char *uplo,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) +{ + ssymv_( uplo, n, alpha, a, lda, x, incx, beta, y, incy); +} + +void ssymv(const char *uplo,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) +{ + ssymv_( uplo, n, alpha, a, lda, x, incx, beta, y, incy); +} + +void SSYMV_(const char *uplo,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) +{ + ssymv_( uplo, n, alpha, a, lda, x, incx, beta, y, incy); +} + +void SSYR(const char *uplo,const f77_int *n,const float *alpha,const float *x,const f77_int *incx,float *a,const f77_int *lda) +{ + ssyr_( uplo, n, alpha, x, incx, a, lda); +} + +void ssyr(const char *uplo,const f77_int *n,const float *alpha,const float *x,const f77_int *incx,float *a,const f77_int *lda) +{ + ssyr_( uplo, n, alpha, x, incx, a, lda); +} + +void SSYR_(const char *uplo,const f77_int *n,const float *alpha,const float *x,const f77_int *incx,float *a,const f77_int *lda) +{ + ssyr_( uplo, n, alpha, x, incx, a, lda); +} + +void SSYR2(const char *uplo,const f77_int *n,const float *alpha,const float *x,const f77_int *incx,const float 
*y,const f77_int *incy,float *a,const f77_int *lda) +{ + ssyr2_( uplo, n, alpha, x, incx, y, incy, a, lda); +} + +void ssyr2(const char *uplo,const f77_int *n,const float *alpha,const float *x,const f77_int *incx,const float *y,const f77_int *incy,float *a,const f77_int *lda) +{ + ssyr2_( uplo, n, alpha, x, incx, y, incy, a, lda); +} + +void SSYR2_(const char *uplo,const f77_int *n,const float *alpha,const float *x,const f77_int *incx,const float *y,const f77_int *incy,float *a,const f77_int *lda) +{ + ssyr2_( uplo, n, alpha, x, incx, y, incy, a, lda); +} + +void SSYR2K(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *b,const f77_int *ldb,const float *beta,float *c,const f77_int *ldc) +{ + ssyr2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void ssyr2k(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *b,const f77_int *ldb,const float *beta,float *c,const f77_int *ldc) +{ + ssyr2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void SSYR2K_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *b,const f77_int *ldb,const float *beta,float *c,const f77_int *ldc) +{ + ssyr2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void SSYRK(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *beta,float *c,const f77_int *ldc) +{ + ssyrk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); +} + +void ssyrk(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *beta,float *c,const f77_int *ldc) +{ + ssyrk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); +} + +void SSYRK_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *beta,float *c,const f77_int *ldc) +{ + ssyrk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); +} + +void STBMV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const float *a,const f77_int *lda,float *x,const f77_int *incx) +{ + stbmv_( uplo, trans, diag, n, k, a, lda, x, incx); +} + +void stbmv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const float *a,const f77_int *lda,float *x,const f77_int *incx) +{ + stbmv_( uplo, trans, diag, n, k, a, lda, x, incx); +} + +void STBMV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const float *a,const f77_int *lda,float *x,const f77_int *incx) +{ + stbmv_( uplo, trans, diag, n, k, a, lda, x, incx); +} + +void STBSV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const float *a,const f77_int *lda,float *x,const f77_int *incx) +{ + stbsv_( uplo, trans, diag, n, k, a, lda, x, incx); +} + +void stbsv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const float *a,const f77_int *lda,float *x,const f77_int *incx) +{ + stbsv_( uplo, trans, diag, n, k, a, lda, x, incx); +} + +void STBSV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const float *a,const f77_int *lda,float *x,const f77_int *incx) +{ + stbsv_( uplo, trans, diag, n, k, a, lda, x, incx); +} + +void STPMV(const char 
*uplo,const char *trans,const char *diag,const f77_int *n,const float *ap,float *x,const f77_int *incx) +{ + stpmv_( uplo, trans, diag, n, ap, x, incx); +} + +void stpmv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const float *ap,float *x,const f77_int *incx) +{ + stpmv_( uplo, trans, diag, n, ap, x, incx); +} + +void STPMV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const float *ap,float *x,const f77_int *incx) +{ + stpmv_( uplo, trans, diag, n, ap, x, incx); +} + +void STPSV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const float *ap,float *x,const f77_int *incx) +{ + stpsv_( uplo, trans, diag, n, ap, x, incx); +} + +void stpsv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const float *ap,float *x,const f77_int *incx) +{ + stpsv_( uplo, trans, diag, n, ap, x, incx); +} + +void STPSV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const float *ap,float *x,const f77_int *incx) +{ + stpsv_( uplo, trans, diag, n, ap, x, incx); +} + +void STRMM(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,float *b,const f77_int *ldb) +{ + strmm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); +} + +void strmm(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,float *b,const f77_int *ldb) +{ + strmm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); +} + +void STRMM_(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,float *b,const f77_int *ldb) +{ + strmm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); +} + +void STRMV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const float *a,const f77_int *lda,float *x,const f77_int *incx) +{ + strmv_( uplo, trans, diag, n, a, lda, x, incx); +} + +void strmv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const float *a,const f77_int *lda,float *x,const f77_int *incx) +{ + strmv_( uplo, trans, diag, n, a, lda, x, incx); +} + +void STRMV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const float *a,const f77_int *lda,float *x,const f77_int *incx) +{ + strmv_( uplo, trans, diag, n, a, lda, x, incx); +} + +void STRSM(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,float *b,const f77_int *ldb) +{ + strsm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); +} + +void strsm(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,float *b,const f77_int *ldb) +{ + strsm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); +} + +void STRSM_(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,float *b,const f77_int *ldb) +{ + strsm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); +} + +void STRSV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const float *a,const f77_int *lda,float *x,const f77_int *incx) +{ + strsv_( uplo, trans, diag, n, a, lda, x, incx); +} + +void strsv(const char *uplo,const char 
*trans,const char *diag,const f77_int *n,const float *a,const f77_int *lda,float *x,const f77_int *incx) +{ + strsv_( uplo, trans, diag, n, a, lda, x, incx); +} + +void STRSV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const float *a,const f77_int *lda,float *x,const f77_int *incx) +{ + strsv_( uplo, trans, diag, n, a, lda, x, incx); +} + +int XERBLA(const char *srname,const f77_int *info, ftnlen n) +{ + return xerbla_( srname, info, n); +} + +int XERBLA_(const char *srname,const f77_int *info, ftnlen n) +{ + return xerbla_( srname, info, n); +} + +int xerbla(const char *srname,const f77_int *info, ftnlen n) +{ + return xerbla_( srname, info, n); +} + +void ZAXPY(const f77_int *n,const dcomplex *za,const dcomplex *zx,const f77_int *incx,dcomplex *zy,const f77_int *incy) +{ + zaxpy_( n, za, zx, incx, zy, incy); +} + +void zaxpy(const f77_int *n,const dcomplex *za,const dcomplex *zx,const f77_int *incx,dcomplex *zy,const f77_int *incy) +{ + zaxpy_( n, za, zx, incx, zy, incy); +} + +void ZAXPY_(const f77_int *n,const dcomplex *za,const dcomplex *zx,const f77_int *incx,dcomplex *zy,const f77_int *incy) +{ + zaxpy_( n, za, zx, incx, zy, incy); +} + +void ZCOPY(const f77_int *n,const dcomplex *zx,const f77_int *incx,dcomplex *zy,const f77_int *incy) +{ + zcopy_( n, zx, incx, zy, incy); +} + +void zcopy(const f77_int *n,const dcomplex *zx,const f77_int *incx,dcomplex *zy,const f77_int *incy) +{ + zcopy_( n, zx, incx, zy, incy); +} + +void ZCOPY_(const f77_int *n,const dcomplex *zx,const f77_int *incx,dcomplex *zy,const f77_int *incy) +{ + zcopy_( n, zx, incx, zy, incy); +} + +void ZDROT(const f77_int *n,dcomplex *cx,const f77_int *incx,dcomplex *cy,const f77_int *incy,const double *c,const double *s) +{ + zdrot_( n, cx, incx, cy, incy, c, s); +} + +void zdrot(const f77_int *n,dcomplex *cx,const f77_int *incx,dcomplex *cy,const f77_int *incy,const double *c,const double *s) +{ + zdrot_( n, cx, incx, cy, incy, c, s); +} + +void ZDROT_(const f77_int *n,dcomplex *cx,const f77_int *incx,dcomplex *cy,const f77_int *incy,const double *c,const double *s) +{ + zdrot_( n, cx, incx, cy, incy, c, s); +} + +void ZDSCAL(const f77_int *n,const double *da,dcomplex *zx,const f77_int *incx) +{ + zdscal_( n, da, zx, incx); +} + +void zdscal(const f77_int *n,const double *da,dcomplex *zx,const f77_int *incx) +{ + zdscal_( n, da, zx, incx); +} + +void ZDSCAL_(const f77_int *n,const double *da,dcomplex *zx,const f77_int *incx) +{ + zdscal_( n, da, zx, incx); +} + +void ZGBMV(const char *trans,const f77_int *m,const f77_int *n,const f77_int *kl,const f77_int *ku,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) +{ + zgbmv_( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); +} + +void zgbmv(const char *trans,const f77_int *m,const f77_int *n,const f77_int *kl,const f77_int *ku,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) +{ + zgbmv_( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); +} + +void ZGBMV_(const char *trans,const f77_int *m,const f77_int *n,const f77_int *kl,const f77_int *ku,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) +{ + zgbmv_( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); +} + +void ZGEMM(const char *transa,const char *transb,const 
f77_int *m,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) +{ + zgemm_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void zgemm(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) +{ + zgemm_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void ZGEMM_(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) +{ + zgemm_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void ZGEMV(const char *trans,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) +{ + zgemv_( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); +} + +void zgemv(const char *trans,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) +{ + zgemv_( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); +} + +void ZGEMV_(const char *trans,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) +{ + zgemv_( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); +} + +void ZGERC(const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *x,const f77_int *incx,const dcomplex *y,const f77_int *incy,dcomplex *a,const f77_int *lda) +{ + zgerc_( m, n, alpha, x, incx, y, incy, a, lda); +} + +void zgerc(const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *x,const f77_int *incx,const dcomplex *y,const f77_int *incy,dcomplex *a,const f77_int *lda) +{ + zgerc_( m, n, alpha, x, incx, y, incy, a, lda); +} + +void ZGERC_(const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *x,const f77_int *incx,const dcomplex *y,const f77_int *incy,dcomplex *a,const f77_int *lda) +{ + zgerc_( m, n, alpha, x, incx, y, incy, a, lda); +} + +void ZGERU(const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *x,const f77_int *incx,const dcomplex *y,const f77_int *incy,dcomplex *a,const f77_int *lda) +{ + zgeru_( m, n, alpha, x, incx, y, incy, a, lda); +} + +void zgeru(const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *x,const f77_int *incx,const dcomplex *y,const f77_int *incy,dcomplex *a,const f77_int *lda) +{ + zgeru_( m, n, alpha, x, incx, y, incy, a, lda); +} + +void ZGERU_(const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *x,const f77_int *incx,const dcomplex *y,const f77_int *incy,dcomplex *a,const f77_int *lda) +{ + zgeru_( m, n, alpha, x, incx, y, incy, a, lda); +} + +void ZHBMV(const char *uplo,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) +{ + zhbmv_( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); +} + +void zhbmv(const char *uplo,const f77_int *n,const f77_int 
*k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) +{ + zhbmv_( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); +} + +void ZHBMV_(const char *uplo,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) +{ + zhbmv_( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); +} + +void ZHEMM(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) +{ + zhemm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void zhemm(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) +{ + zhemm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void ZHEMM_(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) +{ + zhemm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void ZHEMV(const char *uplo,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) +{ + zhemv_( uplo, n, alpha, a, lda, x, incx, beta, y, incy); +} + +void zhemv(const char *uplo,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) +{ + zhemv_( uplo, n, alpha, a, lda, x, incx, beta, y, incy); +} + +void ZHEMV_(const char *uplo,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) +{ + zhemv_( uplo, n, alpha, a, lda, x, incx, beta, y, incy); +} + +void ZHER(const char *uplo,const f77_int *n,const double *alpha,const dcomplex *x,const f77_int *incx,dcomplex *a,const f77_int *lda) +{ + zher_( uplo, n, alpha, x, incx, a, lda); +} + +void zher(const char *uplo,const f77_int *n,const double *alpha,const dcomplex *x,const f77_int *incx,dcomplex *a,const f77_int *lda) +{ + zher_( uplo, n, alpha, x, incx, a, lda); +} + +void ZHER_(const char *uplo,const f77_int *n,const double *alpha,const dcomplex *x,const f77_int *incx,dcomplex *a,const f77_int *lda) +{ + zher_( uplo, n, alpha, x, incx, a, lda); +} + +void ZHER2(const char *uplo,const f77_int *n,const dcomplex *alpha,const dcomplex *x,const f77_int *incx,const dcomplex *y,const f77_int *incy,dcomplex *a,const f77_int *lda) +{ + zher2_( uplo, n, alpha, x, incx, y, incy, a, lda); +} + +void zher2(const char *uplo,const f77_int *n,const dcomplex *alpha,const dcomplex *x,const f77_int *incx,const dcomplex *y,const f77_int *incy,dcomplex *a,const f77_int *lda) +{ + zher2_( uplo, n, alpha, x, incx, y, incy, a, lda); +} + +void ZHER2_(const char *uplo,const f77_int *n,const dcomplex *alpha,const dcomplex *x,const f77_int *incx,const dcomplex *y,const f77_int *incy,dcomplex *a,const f77_int *lda) +{ + zher2_( uplo, n, alpha, x, incx, y, incy, a, lda); +} + +void ZHER2K(const char *uplo,const char *trans,const f77_int *n,const 
f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const double *beta,dcomplex *c,const f77_int *ldc) +{ + zher2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void zher2k(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const double *beta,dcomplex *c,const f77_int *ldc) +{ + zher2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void ZHER2K_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const double *beta,dcomplex *c,const f77_int *ldc) +{ + zher2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void ZHERK(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const double *alpha,const dcomplex *a,const f77_int *lda,const double *beta,dcomplex *c,const f77_int *ldc) +{ + zherk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); +} + +void zherk(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const double *alpha,const dcomplex *a,const f77_int *lda,const double *beta,dcomplex *c,const f77_int *ldc) +{ + zherk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); +} + +void ZHERK_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const double *alpha,const dcomplex *a,const f77_int *lda,const double *beta,dcomplex *c,const f77_int *ldc) +{ + zherk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); +} + +void ZHPMV(const char *uplo,const f77_int *n,const dcomplex *alpha,const dcomplex *ap,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) +{ + zhpmv_( uplo, n, alpha, ap, x, incx, beta, y, incy); +} + +void zhpmv(const char *uplo,const f77_int *n,const dcomplex *alpha,const dcomplex *ap,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) +{ + zhpmv_( uplo, n, alpha, ap, x, incx, beta, y, incy); +} + +void ZHPMV_(const char *uplo,const f77_int *n,const dcomplex *alpha,const dcomplex *ap,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) +{ + zhpmv_( uplo, n, alpha, ap, x, incx, beta, y, incy); +} + +void ZHPR(const char *uplo,const f77_int *n,const bla_double *alpha,const dcomplex *x,const f77_int *incx,dcomplex *ap) +{ + zhpr_( uplo, n, alpha, x, incx, ap); +} + +void zhpr(const char *uplo,const f77_int *n,const bla_double *alpha,const dcomplex *x,const f77_int *incx,dcomplex *ap) +{ + zhpr_( uplo, n, alpha, x, incx, ap); +} + +void ZHPR_(const char *uplo,const f77_int *n,const bla_double *alpha,const dcomplex *x,const f77_int *incx,dcomplex *ap) +{ + zhpr_( uplo, n, alpha, x, incx, ap); +} + +void ZHPR2(const char *uplo,const f77_int *n,const dcomplex *alpha,const dcomplex *x,const f77_int *incx,const dcomplex *y,const f77_int *incy,dcomplex *ap) +{ + zhpr2_( uplo, n, alpha, x, incx, y, incy, ap); +} + +void zhpr2(const char *uplo,const f77_int *n,const dcomplex *alpha,const dcomplex *x,const f77_int *incx,const dcomplex *y,const f77_int *incy,dcomplex *ap) +{ + zhpr2_( uplo, n, alpha, x, incx, y, incy, ap); +} + +void ZHPR2_(const char *uplo,const f77_int *n,const dcomplex *alpha,const dcomplex *x,const f77_int *incx,const dcomplex *y,const f77_int *incy,dcomplex *ap) +{ + zhpr2_( uplo, n, alpha, x, incx, y, incy, ap); +} + +void ZROTG(dcomplex *ca,bla_dcomplex *cb,bla_double 
*c,dcomplex *s) +{ + zrotg_( ca, cb, c, s); +} + +void zrotg(dcomplex *ca,bla_dcomplex *cb,bla_double *c,dcomplex *s) +{ + zrotg_( ca, cb, c, s); +} + +void ZROTG_(dcomplex *ca,bla_dcomplex *cb,bla_double *c,dcomplex *s) +{ + zrotg_( ca, cb, c, s); +} + +void ZSCAL(const f77_int *n,const dcomplex *za,dcomplex *zx,const f77_int *incx) +{ + zscal_( n, za, zx, incx); +} + +void zscal(const f77_int *n,const dcomplex *za,dcomplex *zx,const f77_int *incx) +{ + zscal_( n, za, zx, incx); +} + +void ZSCAL_(const f77_int *n,const dcomplex *za,dcomplex *zx,const f77_int *incx) +{ + zscal_( n, za, zx, incx); +} + +void ZSWAP(const f77_int *n,dcomplex *zx,const f77_int *incx,dcomplex *zy,const f77_int *incy) +{ + zswap_( n, zx, incx, zy, incy); +} + +void zswap(const f77_int *n,dcomplex *zx,const f77_int *incx,dcomplex *zy,const f77_int *incy) +{ + zswap_( n, zx, incx, zy, incy); +} + +void ZSWAP_(const f77_int *n,dcomplex *zx,const f77_int *incx,dcomplex *zy,const f77_int *incy) +{ + zswap_( n, zx, incx, zy, incy); +} + +void ZSYMM(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) +{ + zsymm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void zsymm(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) +{ + zsymm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void ZSYMM_(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) +{ + zsymm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void ZSYR2K(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) +{ + zsyr2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void zsyr2k(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) +{ + zsyr2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void ZSYR2K_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) +{ + zsyr2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void ZSYRK(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *beta,dcomplex *c,const f77_int *ldc) +{ + zsyrk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); +} + +void zsyrk(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *beta,dcomplex *c,const f77_int *ldc) +{ + zsyrk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); +} + +void ZSYRK_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *beta,dcomplex 
*c,const f77_int *ldc) +{ + zsyrk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); +} + +void ZTBMV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const dcomplex *a,const f77_int *lda,dcomplex *x,const f77_int *incx) +{ + ztbmv_( uplo, trans, diag, n, k, a, lda, x, incx); +} + +void ztbmv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const dcomplex *a,const f77_int *lda,dcomplex *x,const f77_int *incx) +{ + ztbmv_( uplo, trans, diag, n, k, a, lda, x, incx); +} + +void ZTBMV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const dcomplex *a,const f77_int *lda,dcomplex *x,const f77_int *incx) +{ + ztbmv_( uplo, trans, diag, n, k, a, lda, x, incx); +} + +void ZTBSV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const dcomplex *a,const f77_int *lda,dcomplex *x,const f77_int *incx) +{ + ztbsv_( uplo, trans, diag, n, k, a, lda, x, incx); +} + +void ztbsv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const dcomplex *a,const f77_int *lda,dcomplex *x,const f77_int *incx) +{ + ztbsv_( uplo, trans, diag, n, k, a, lda, x, incx); +} + +void ZTBSV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const dcomplex *a,const f77_int *lda,dcomplex *x,const f77_int *incx) +{ + ztbsv_( uplo, trans, diag, n, k, a, lda, x, incx); +} + +void ZTPMV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const dcomplex *ap,dcomplex *x,const f77_int *incx) +{ + ztpmv_( uplo, trans, diag, n, ap, x, incx); +} + +void ztpmv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const dcomplex *ap,dcomplex *x,const f77_int *incx) +{ + ztpmv_( uplo, trans, diag, n, ap, x, incx); +} + +void ZTPMV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const dcomplex *ap,dcomplex *x,const f77_int *incx) +{ + ztpmv_( uplo, trans, diag, n, ap, x, incx); +} + +void ZTPSV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const dcomplex *ap,dcomplex *x,const f77_int *incx) +{ + ztpsv_( uplo, trans, diag, n, ap, x, incx); +} + +void ztpsv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const dcomplex *ap,dcomplex *x,const f77_int *incx) +{ + ztpsv_( uplo, trans, diag, n, ap, x, incx); +} + +void ZTPSV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const dcomplex *ap,dcomplex *x,const f77_int *incx) +{ + ztpsv_( uplo, trans, diag, n, ap, x, incx); +} + +void ZTRMM(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,dcomplex *b,const f77_int *ldb) +{ + ztrmm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); +} + +void ztrmm(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,dcomplex *b,const f77_int *ldb) +{ + ztrmm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); +} + +void ZTRMM_(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,dcomplex *b,const f77_int *ldb) +{ + ztrmm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); +} + +void ZTRMV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const dcomplex *a,const f77_int 
*lda,dcomplex *x,const f77_int *incx)
+{
+ ztrmv_( uplo, trans, diag, n, a, lda, x, incx);
+}
+
+void ztrmv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const dcomplex *a,const f77_int *lda,dcomplex *x,const f77_int *incx)
+{
+ ztrmv_( uplo, trans, diag, n, a, lda, x, incx);
+}
+
+void ZTRMV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const dcomplex *a,const f77_int *lda,dcomplex *x,const f77_int *incx)
+{
+ ztrmv_( uplo, trans, diag, n, a, lda, x, incx);
+}
+
+void ZTRSM(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,dcomplex *b,const f77_int *ldb)
+{
+ ztrsm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb);
+}
+
+void ztrsm(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,dcomplex *b,const f77_int *ldb)
+{
+ ztrsm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb);
+}
+
+void ZTRSM_(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,dcomplex *b,const f77_int *ldb)
+{
+ ztrsm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb);
+}
+
+void ZTRSV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const dcomplex *a,const f77_int *lda,dcomplex *x,const f77_int *incx)
+{
+ ztrsv_( uplo, trans, diag, n, a, lda, x, incx);
+}
+
+void ztrsv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const dcomplex *a,const f77_int *lda,dcomplex *x,const f77_int *incx)
+{
+ ztrsv_( uplo, trans, diag, n, a, lda, x, incx);
+}
+
+void ZTRSV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const dcomplex *a,const f77_int *lda,dcomplex *x,const f77_int *incx)
+{
+ ztrsv_( uplo, trans, diag, n, a, lda, x, incx);
+}
+
+
+void CDOTCSUB( const f77_int* n, const scomplex* x,const f77_int* incx, const scomplex* y, const f77_int* incy, scomplex* rval)
+{
+ cdotcsub_( n, x, incx, y, incy, rval);
+}
+
+void cdotcsub( const f77_int* n, const scomplex* x,const f77_int* incx, const scomplex* y, const f77_int* incy, scomplex* rval)
+{
+ cdotcsub_( n, x, incx, y, incy, rval);
+}
+
+void CDOTCSUB_( const f77_int* n, const scomplex* x,const f77_int* incx, const scomplex* y, const f77_int* incy, scomplex* rval)
+{
+ cdotcsub_( n, x, incx, y, incy, rval);
+}
+
+void CDOTUSUB( const f77_int* n, const scomplex* x,const f77_int* incx, const scomplex* y, const f77_int* incy, scomplex* rval)
+{
+ cdotusub_( n, x, incx, y, incy, rval);
+}
+
+void cdotusub( const f77_int* n, const scomplex* x,const f77_int* incx, const scomplex* y, const f77_int* incy, scomplex* rval)
+{
+ cdotusub_( n, x, incx, y, incy, rval);
+}
+
+void CDOTUSUB_( const f77_int* n, const scomplex* x,const f77_int* incx, const scomplex* y, const f77_int* incy, scomplex* rval)
+{
+ cdotusub_( n, x, incx, y, incy, rval);
+}
+
+void CGEMM3M( const f77_char* transa, const f77_char* transb, const f77_int* m, const f77_int* n, const f77_int* k, const scomplex* alpha, const scomplex* a, const f77_int* lda, const scomplex* b, const f77_int* ldb, const scomplex* beta, scomplex* c, const f77_int* ldc)
+{
+ cgemm3m_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
+}
+
+void cgemm3m( const f77_char* transa, const f77_char* transb, const f77_int* m, const f77_int* n, const f77_int* k, const scomplex* 
alpha, const scomplex* a, const f77_int* lda, const scomplex* b, const f77_int* ldb, const scomplex* beta, scomplex* c, const f77_int* ldc) +{ + cgemm3m_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void CGEMM3M_( const f77_char* transa, const f77_char* transb, const f77_int* m, const f77_int* n, const f77_int* k, const scomplex* alpha, const scomplex* a, const f77_int* lda, const scomplex* b, const f77_int* ldb, const scomplex* beta, scomplex* c, const f77_int* ldc) +{ + cgemm3m_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void CGEMM_BATCH( const f77_char* transa_array, const f77_char* transb_array,const f77_int *m_array, const f77_int *n_array, const f77_int *k_array,const scomplex* alpha_array, const scomplex** a_array, const f77_int *lda_array, const scomplex** b_array, const f77_int *ldb_array, const scomplex* beta_array, scomplex** c_array, const f77_int *ldc_array, const f77_int* group_count, const f77_int *group_size) +{ + cgemm_batch_( transa_array, transb_array, m_array, n_array, k_array, alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array, group_count, group_size); +} + +void cgemm_batch( const f77_char* transa_array, const f77_char* transb_array,const f77_int *m_array, const f77_int *n_array, const f77_int *k_array,const scomplex* alpha_array, const scomplex** a_array, const f77_int *lda_array, const scomplex** b_array, const f77_int *ldb_array, const scomplex* beta_array, scomplex** c_array, const f77_int *ldc_array, const f77_int* group_count, const f77_int *group_size) +{ + cgemm_batch_( transa_array, transb_array, m_array, n_array, k_array, alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array, group_count, group_size); +} + +void CGEMM_BATCH_( const f77_char* transa_array, const f77_char* transb_array,const f77_int *m_array, const f77_int *n_array, const f77_int *k_array,const scomplex* alpha_array, const scomplex** a_array, const f77_int *lda_array, const scomplex** b_array, const f77_int *ldb_array, const scomplex* beta_array, scomplex** c_array, const f77_int *ldc_array, const f77_int* group_count, const f77_int *group_size) +{ + cgemm_batch_( transa_array, transb_array, m_array, n_array, k_array, alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array, group_count, group_size); +} + +void CGEMMT( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const scomplex* alpha, const scomplex* a, const f77_int* lda, const scomplex* b, const f77_int* ldb, const scomplex* beta, scomplex* c, const f77_int* ldc) +{ + cgemmt_( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void cgemmt( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const scomplex* alpha, const scomplex* a, const f77_int* lda, const scomplex* b, const f77_int* ldb, const scomplex* beta, scomplex* c, const f77_int* ldc) +{ + cgemmt_( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void CGEMMT_( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const scomplex* alpha, const scomplex* a, const f77_int* lda, const scomplex* b, const f77_int* ldb, const scomplex* beta, scomplex* c, const f77_int* ldc) +{ + cgemmt_( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void CIMATCOPY(f77_char* trans, f77_int* rows, f77_int* cols, const scomplex* 
alpha,scomplex* aptr, f77_int* lda, f77_int* ldb) +{ + cimatcopy_( trans, rows, cols, alpha, aptr, lda, ldb); +} + +void cimatcopy(f77_char* trans, f77_int* rows, f77_int* cols, const scomplex* alpha,scomplex* aptr, f77_int* lda, f77_int* ldb) +{ + cimatcopy_( trans, rows, cols, alpha, aptr, lda, ldb); +} + +void CIMATCOPY_(f77_char* trans, f77_int* rows, f77_int* cols, const scomplex* alpha,scomplex* aptr, f77_int* lda, f77_int* ldb) +{ + cimatcopy_( trans, rows, cols, alpha, aptr, lda, ldb); +} + +void COMATADD(f77_char* transa,f77_char* transb, f77_int* m, f77_int* n, const scomplex* alpha, const scomplex* A, f77_int* lda,const scomplex* beta, scomplex* B, f77_int* ldb, scomplex* C, f77_int* ldc) +{ + comatadd_( transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc); +} + +void comatadd(f77_char* transa,f77_char* transb, f77_int* m, f77_int* n, const scomplex* alpha, const scomplex* A, f77_int* lda,const scomplex* beta, scomplex* B, f77_int* ldb, scomplex* C, f77_int* ldc) +{ + comatadd_( transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc); +} + +void COMATADD_(f77_char* transa,f77_char* transb, f77_int* m, f77_int* n, const scomplex* alpha, const scomplex* A, f77_int* lda,const scomplex* beta, scomplex* B, f77_int* ldb, scomplex* C, f77_int* ldc) +{ + comatadd_( transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc); +} + +void COMATCOPY2(f77_char* trans, f77_int* rows, f77_int* cols, const scomplex* alpha, const scomplex* aptr, f77_int* lda,f77_int* stridea, scomplex* bptr, f77_int* ldb,f77_int* strideb) +{ + comatcopy2_( trans, rows, cols, alpha, aptr, lda, stridea, bptr, ldb, strideb); +} + +void comatcopy2(f77_char* trans, f77_int* rows, f77_int* cols, const scomplex* alpha, const scomplex* aptr, f77_int* lda,f77_int* stridea, scomplex* bptr, f77_int* ldb,f77_int* strideb) +{ + comatcopy2_( trans, rows, cols, alpha, aptr, lda, stridea, bptr, ldb, strideb); +} + +void COMATCOPY2_(f77_char* trans, f77_int* rows, f77_int* cols, const scomplex* alpha, const scomplex* aptr, f77_int* lda,f77_int* stridea, scomplex* bptr, f77_int* ldb,f77_int* strideb) +{ + comatcopy2_( trans, rows, cols, alpha, aptr, lda, stridea, bptr, ldb, strideb); +} + +void COMATCOPY(f77_char* trans, f77_int* rows, f77_int* cols, const scomplex* alpha, const scomplex* aptr, f77_int* lda, scomplex* bptr, f77_int* ldb) +{ + comatcopy_( trans, rows, cols, alpha, aptr, lda, bptr, ldb); +} + +void comatcopy(f77_char* trans, f77_int* rows, f77_int* cols, const scomplex* alpha, const scomplex* aptr, f77_int* lda, scomplex* bptr, f77_int* ldb) +{ + comatcopy_( trans, rows, cols, alpha, aptr, lda, bptr, ldb); +} + +void COMATCOPY_(f77_char* trans, f77_int* rows, f77_int* cols, const scomplex* alpha, const scomplex* aptr, f77_int* lda, scomplex* bptr, f77_int* ldb) +{ + comatcopy_( trans, rows, cols, alpha, aptr, lda, bptr, ldb); +} + +void DASUMSUB(const f77_int* n, const double* x, const f77_int* incx, double* rval) +{ + dasumsub_( n, x, incx, rval); +} + +void dasumsub(const f77_int* n, const double* x, const f77_int* incx, double* rval) +{ + dasumsub_( n, x, incx, rval); +} + +void DASUMSUB_(const f77_int* n, const double* x, const f77_int* incx, double* rval) +{ + dasumsub_( n, x, incx, rval); +} + +void DAXPBY(const f77_int* n, const double* alpha, const double *x, const f77_int* incx, const double* beta, double *y, const f77_int* incy) +{ + daxpby_( n, alpha, x, incx, beta, y, incy); +} + +void daxpby(const f77_int* n, const double* alpha, const double *x, const f77_int* incx, const double* beta, double 
*y, const f77_int* incy) +{ + daxpby_( n, alpha, x, incx, beta, y, incy); +} + +void DAXPBY_(const f77_int* n, const double* alpha, const double *x, const f77_int* incx, const double* beta, double *y, const f77_int* incy) +{ + daxpby_( n, alpha, x, incx, beta, y, incy); +} + +void DDOTSUB(const f77_int* n, const double* x, const f77_int* incx, const double* y, const f77_int* incy, double* rval) +{ + ddotsub_( n, x, incx, y, incy, rval); +} + +void ddotsub(const f77_int* n, const double* x, const f77_int* incx, const double* y, const f77_int* incy, double* rval) +{ + ddotsub_( n, x, incx, y, incy, rval); +} + +void DDOTSUB_(const f77_int* n, const double* x, const f77_int* incx, const double* y, const f77_int* incy, double* rval) +{ + ddotsub_( n, x, incx, y, incy, rval); +} + +void DGEMM_BATCH( const f77_char* transa_array, const f77_char* transb_array,const f77_int *m_array, const f77_int *n_array, const f77_int *k_array,const double* alpha_array, const double** a_array, const f77_int *lda_array, const double** b_array, const f77_int *ldb_array, const double* beta_array, double** c_array, const f77_int *ldc_array, const f77_int* group_count, const f77_int *group_size) +{ + dgemm_batch_( transa_array, transb_array, m_array, n_array, k_array, alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array, group_count, group_size); +} + +void dgemm_batch( const f77_char* transa_array, const f77_char* transb_array,const f77_int *m_array, const f77_int *n_array, const f77_int *k_array,const double* alpha_array, const double** a_array, const f77_int *lda_array, const double** b_array, const f77_int *ldb_array, const double* beta_array, double** c_array, const f77_int *ldc_array, const f77_int* group_count, const f77_int *group_size) +{ + dgemm_batch_( transa_array, transb_array, m_array, n_array, k_array, alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array, group_count, group_size); +} + +void DGEMM_BATCH_( const f77_char* transa_array, const f77_char* transb_array,const f77_int *m_array, const f77_int *n_array, const f77_int *k_array,const double* alpha_array, const double** a_array, const f77_int *lda_array, const double** b_array, const f77_int *ldb_array, const double* beta_array, double** c_array, const f77_int *ldc_array, const f77_int* group_count, const f77_int *group_size) +{ + dgemm_batch_( transa_array, transb_array, m_array, n_array, k_array, alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array, group_count, group_size); +} + +void DGEMMT( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const double* alpha, const double* a, const f77_int* lda, const double* b, const f77_int* ldb, const double* beta, double* c, const f77_int* ldc) +{ + dgemmt_( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void dgemmt( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const double* alpha, const double* a, const f77_int* lda, const double* b, const f77_int* ldb, const double* beta, double* c, const f77_int* ldc) +{ + dgemmt_( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void DGEMMT_( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const double* alpha, const double* a, const f77_int* lda, const double* b, const f77_int* ldb, const double* beta, double* c, const f77_int* ldc) +{ + dgemmt_( 
uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void DNRM2SUB(const f77_int* n, const double* x, const f77_int* incx, double *rval) +{ + dnrm2sub_( n, x, incx, rval); +} + +void dnrm2sub(const f77_int* n, const double* x, const f77_int* incx, double *rval) +{ + dnrm2sub_( n, x, incx, rval); +} + +void DNRM2SUB_(const f77_int* n, const double* x, const f77_int* incx, double *rval) +{ + dnrm2sub_( n, x, incx, rval); +} + +void DOMATADD(f77_char* transa,f77_char* transb, f77_int* m, f77_int* n, const double* alpha, const double* A, f77_int* lda, const double* beta, const double* B, f77_int* ldb, double* C, f77_int* ldc) +{ + domatadd_( transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc); +} + +void domatadd(f77_char* transa,f77_char* transb, f77_int* m, f77_int* n, const double* alpha, const double* A, f77_int* lda, const double* beta, const double* B, f77_int* ldb, double* C, f77_int* ldc) +{ + domatadd_( transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc); +} + +void DOMATADD_(f77_char* transa,f77_char* transb, f77_int* m, f77_int* n, const double* alpha, const double* A, f77_int* lda, const double* beta, const double* B, f77_int* ldb, double* C, f77_int* ldc) +{ + domatadd_( transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc); +} + +void DOMATCOPY2(f77_char* trans, f77_int* rows, f77_int* cols, const double* alpha, const double* aptr, f77_int* lda,f77_int* stridea, double* bptr, f77_int* ldb,f77_int* strideb) +{ + domatcopy2_( trans, rows, cols, alpha, aptr, lda, stridea, bptr, ldb, strideb); +} + +void domatcopy2(f77_char* trans, f77_int* rows, f77_int* cols, const double* alpha, const double* aptr, f77_int* lda,f77_int* stridea, double* bptr, f77_int* ldb,f77_int* strideb) +{ + domatcopy2_( trans, rows, cols, alpha, aptr, lda, stridea, bptr, ldb, strideb); +} + +void DOMATCOPY2_(f77_char* trans, f77_int* rows, f77_int* cols, const double* alpha, const double* aptr, f77_int* lda,f77_int* stridea, double* bptr, f77_int* ldb,f77_int* strideb) +{ + domatcopy2_( trans, rows, cols, alpha, aptr, lda, stridea, bptr, ldb, strideb); +} + +void DOMATCOPY(f77_char* trans, f77_int* rows, f77_int* cols, const double* alpha, const double* aptr, f77_int* lda, double* bptr, f77_int* ldb) +{ + domatcopy_( trans, rows, cols, alpha, aptr, lda, bptr, ldb); +} + +void domatcopy(f77_char* trans, f77_int* rows, f77_int* cols, const double* alpha, const double* aptr, f77_int* lda, double* bptr, f77_int* ldb) +{ + domatcopy_( trans, rows, cols, alpha, aptr, lda, bptr, ldb); +} + +void DOMATCOPY_(f77_char* trans, f77_int* rows, f77_int* cols, const double* alpha, const double* aptr, f77_int* lda, double* bptr, f77_int* ldb) +{ + domatcopy_( trans, rows, cols, alpha, aptr, lda, bptr, ldb); +} + +void DZASUMSUB(const f77_int* n, const dcomplex* x, const f77_int* incx, double* rval) +{ + dzasumsub_( n, x, incx, rval); +} + +void dzasumsub(const f77_int* n, const dcomplex* x, const f77_int* incx, double* rval) +{ + dzasumsub_( n, x, incx, rval); +} + +void DZASUMSUB_(const f77_int* n, const dcomplex* x, const f77_int* incx, double* rval) +{ + dzasumsub_( n, x, incx, rval); +} + +void DZNRM2SUB(const f77_int* n, const dcomplex* x, const f77_int* incx, double* rval) +{ + dznrm2sub_( n, x, incx, rval); +} + +void dznrm2sub(const f77_int* n, const dcomplex* x, const f77_int* incx, double* rval) +{ + dznrm2sub_( n, x, incx, rval); +} + +void DZNRM2SUB_(const f77_int* n, const dcomplex* x, const f77_int* incx, double* rval) +{ + dznrm2sub_( n, x, incx, rval); +} + +void 
ICAMAXSUB(const f77_int* n, const scomplex* x, const f77_int* incx, f77_int* rval) +{ + icamaxsub_( n, x, incx, rval); +} + +void icamaxsub(const f77_int* n, const scomplex* x, const f77_int* incx, f77_int* rval) +{ + icamaxsub_( n, x, incx, rval); +} + +void ICAMAXSUB_(const f77_int* n, const scomplex* x, const f77_int* incx, f77_int* rval) +{ + icamaxsub_( n, x, incx, rval); +} + +f77_int ICAMIN( const f77_int* n, const scomplex* x, const f77_int* incx) +{ + return icamin_( n, x, incx); +} + +f77_int icamin( const f77_int* n, const scomplex* x, const f77_int* incx) +{ + return icamin_( n, x, incx); +} + +f77_int ICAMIN_( const f77_int* n, const scomplex* x, const f77_int* incx) +{ + return icamin_( n, x, incx); +} + +void ICAMINSUB( const f77_int* n, const scomplex* x, const f77_int* incx, f77_int* rval) +{ + icaminsub_( n, x, incx, rval); +} + +void icaminsub( const f77_int* n, const scomplex* x, const f77_int* incx, f77_int* rval) +{ + icaminsub_( n, x, incx, rval); +} + +void ICAMINSUB_( const f77_int* n, const scomplex* x, const f77_int* incx, f77_int* rval) +{ + icaminsub_( n, x, incx, rval); +} + +void IDAMAXSUB( const f77_int* n, const double* x, const f77_int* incx, f77_int* rval) +{ + idamaxsub_( n, x, incx, rval); +} + +void idamaxsub( const f77_int* n, const double* x, const f77_int* incx, f77_int* rval) +{ + idamaxsub_( n, x, incx, rval); +} + +void IDAMAXSUB_( const f77_int* n, const double* x, const f77_int* incx, f77_int* rval) +{ + idamaxsub_( n, x, incx, rval); +} + +f77_int IDAMIN( const f77_int* n, const double* x, const f77_int* incx) +{ + return idamin_( n, x, incx); +} + +f77_int idamin( const f77_int* n, const double* x, const f77_int* incx) +{ + return idamin_( n, x, incx); +} + +f77_int IDAMIN_( const f77_int* n, const double* x, const f77_int* incx) +{ + return idamin_( n, x, incx); +} + +void IDAMINSUB(const f77_int* n, const double* x, const f77_int* incx, f77_int* rval) +{ + idaminsub_( n, x, incx, rval); +} + +void idaminsub(const f77_int* n, const double* x, const f77_int* incx, f77_int* rval) +{ + idaminsub_( n, x, incx, rval); +} + +void IDAMINSUB_(const f77_int* n, const double* x, const f77_int* incx, f77_int* rval) +{ + idaminsub_( n, x, incx, rval); +} + +void ISAMAXSUB( const f77_int* n, const float* x, const f77_int* incx, f77_int* rval) +{ + isamaxsub_( n, x, incx, rval); +} + +void isamaxsub( const f77_int* n, const float* x, const f77_int* incx, f77_int* rval) +{ + isamaxsub_( n, x, incx, rval); +} + +void ISAMAXSUB_( const f77_int* n, const float* x, const f77_int* incx, f77_int* rval) +{ + isamaxsub_( n, x, incx, rval); +} + +f77_int ISAMIN( const f77_int* n, const float* x, const f77_int* incx) +{ + return isamin_( n, x, incx); +} + +f77_int isamin( const f77_int* n, const float* x, const f77_int* incx) +{ + return isamin_( n, x, incx); +} + +f77_int ISAMIN_( const f77_int* n, const float* x, const f77_int* incx) +{ + return isamin_( n, x, incx); +} + +void ISAMINSUB( const f77_int* n, const float* x, const f77_int* incx, f77_int* rval) +{ + isaminsub_( n, x, incx, rval); +} + +void isaminsub( const f77_int* n, const float* x, const f77_int* incx, f77_int* rval) +{ + isaminsub_( n, x, incx, rval); +} + +void ISAMINSUB_( const f77_int* n, const float* x, const f77_int* incx, f77_int* rval) +{ + isaminsub_( n, x, incx, rval); +} + +void IZAMAXSUB( const f77_int* n, const dcomplex* x, const f77_int* incx, f77_int* rval) +{ + izamaxsub_( n, x, incx, rval); +} + +void izamaxsub( const f77_int* n, const dcomplex* x, const f77_int* incx, f77_int* 
rval) +{ + izamaxsub_( n, x, incx, rval); +} + +void IZAMAXSUB_( const f77_int* n, const dcomplex* x, const f77_int* incx, f77_int* rval) +{ + izamaxsub_( n, x, incx, rval); +} + +f77_int IZAMIN( const f77_int* n, const dcomplex* x, const f77_int* incx) +{ + return izamin_( n, x, incx); +} + +f77_int izamin( const f77_int* n, const dcomplex* x, const f77_int* incx) +{ + return izamin_( n, x, incx); +} + +f77_int IZAMIN_( const f77_int* n, const dcomplex* x, const f77_int* incx) +{ + return izamin_( n, x, incx); +} + +void IZAMINSUB( const f77_int* n, const dcomplex* x, const f77_int* incx, f77_int* rval) +{ + izaminsub_( n, x, incx, rval); +} + +void izaminsub( const f77_int* n, const dcomplex* x, const f77_int* incx, f77_int* rval) +{ + izaminsub_( n, x, incx, rval); +} + +void IZAMINSUB_( const f77_int* n, const dcomplex* x, const f77_int* incx, f77_int* rval) +{ + izaminsub_( n, x, incx, rval); +} + +void SASUMSUB( const f77_int* n, const float* x, const f77_int* incx, float* rval) +{ + sasumsub_( n, x, incx, rval); +} + +void sasumsub( const f77_int* n, const float* x, const f77_int* incx, float* rval) +{ + sasumsub_( n, x, incx, rval); +} + +void SASUMSUB_( const f77_int* n, const float* x, const f77_int* incx, float* rval) +{ + sasumsub_( n, x, incx, rval); +} + +void SAXPBY( const f77_int* n, const float* alpha, const float *x, const f77_int* incx, const float* beta, float *y, const f77_int* incy) +{ + saxpby_( n, alpha, x, incx, beta, y, incy); +} + +void saxpby( const f77_int* n, const float* alpha, const float *x, const f77_int* incx, const float* beta, float *y, const f77_int* incy) +{ + saxpby_( n, alpha, x, incx, beta, y, incy); +} + +void SAXPBY_( const f77_int* n, const float* alpha, const float *x, const f77_int* incx, const float* beta, float *y, const f77_int* incy) +{ + saxpby_( n, alpha, x, incx, beta, y, incy); +} + +void SCASUMSUB( const f77_int* n, const scomplex* x, const f77_int* incx, float* rval) +{ + scasumsub_( n, x, incx, rval); +} + +void scasumsub( const f77_int* n, const scomplex* x, const f77_int* incx, float* rval) +{ + scasumsub_( n, x, incx, rval); +} + +void SCASUMSUB_( const f77_int* n, const scomplex* x, const f77_int* incx, float* rval) +{ + scasumsub_( n, x, incx, rval); +} + +void SCNRM2SUB( const f77_int* n, const scomplex* x, const f77_int* incx, float* rval) +{ + scnrm2sub_( n, x, incx, rval); +} + +void scnrm2sub( const f77_int* n, const scomplex* x, const f77_int* incx, float* rval) +{ + scnrm2sub_( n, x, incx, rval); +} + +void SCNRM2SUB_( const f77_int* n, const scomplex* x, const f77_int* incx, float* rval) +{ + scnrm2sub_( n, x, incx, rval); +} + +void SDOTSUB( const f77_int* n, const float* x, const f77_int* incx, const float* y, const f77_int* incy, float* rval) +{ + sdotsub_( n, x, incx, y, incy, rval); +} + +void sdotsub( const f77_int* n, const float* x, const f77_int* incx, const float* y, const f77_int* incy, float* rval) +{ + sdotsub_( n, x, incx, y, incy, rval); +} + +void SDOTSUB_( const f77_int* n, const float* x, const f77_int* incx, const float* y, const f77_int* incy, float* rval) +{ + sdotsub_( n, x, incx, y, incy, rval); +} + +void SGEMM_BATCH(const f77_char* transa_array, const f77_char* transb_array,const f77_int *m_array, const f77_int *n_array, const f77_int *k_array,const float* alpha_array, const float** a_array, const f77_int *lda_array, const float** b_array, const f77_int *ldb_array, const float* beta_array, float** c_array, const f77_int *ldc_array, const f77_int* group_count, const f77_int *group_size) +{ + 
sgemm_batch_( transa_array, transb_array, m_array, n_array, k_array, alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array, group_count, group_size); +} + +void sgemm_batch(const f77_char* transa_array, const f77_char* transb_array,const f77_int *m_array, const f77_int *n_array, const f77_int *k_array,const float* alpha_array, const float** a_array, const f77_int *lda_array, const float** b_array, const f77_int *ldb_array, const float* beta_array, float** c_array, const f77_int *ldc_array, const f77_int* group_count, const f77_int *group_size) +{ + sgemm_batch_( transa_array, transb_array, m_array, n_array, k_array, alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array, group_count, group_size); +} + +void SGEMM_BATCH_(const f77_char* transa_array, const f77_char* transb_array,const f77_int *m_array, const f77_int *n_array, const f77_int *k_array,const float* alpha_array, const float** a_array, const f77_int *lda_array, const float** b_array, const f77_int *ldb_array, const float* beta_array, float** c_array, const f77_int *ldc_array, const f77_int* group_count, const f77_int *group_size) +{ + sgemm_batch_( transa_array, transb_array, m_array, n_array, k_array, alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array, group_count, group_size); +} + +void SGEMMT( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const float* alpha, const float* a, const f77_int* lda, const float* b, const f77_int* ldb, const float* beta, float* c, const f77_int* ldc) +{ + sgemmt_( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void sgemmt( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const float* alpha, const float* a, const f77_int* lda, const float* b, const f77_int* ldb, const float* beta, float* c, const f77_int* ldc) +{ + sgemmt_( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void SGEMMT_( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const float* alpha, const float* a, const f77_int* lda, const float* b, const f77_int* ldb, const float* beta, float* c, const f77_int* ldc) +{ + sgemmt_( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void SIMATCOPY( f77_char* trans, f77_int* rows, f77_int* cols, const float* alpha,float* aptr, f77_int* lda, f77_int* ldb) +{ + simatcopy_( trans, rows, cols, alpha, aptr, lda, ldb); +} + +void simatcopy( f77_char* trans, f77_int* rows, f77_int* cols, const float* alpha,float* aptr, f77_int* lda, f77_int* ldb) +{ + simatcopy_( trans, rows, cols, alpha, aptr, lda, ldb); +} + +void SIMATCOPY_( f77_char* trans, f77_int* rows, f77_int* cols, const float* alpha,float* aptr, f77_int* lda, f77_int* ldb) +{ + simatcopy_( trans, rows, cols, alpha, aptr, lda, ldb); +} + +void SNRM2SUB( const f77_int* n, const float* x, const f77_int* incx, float *rval) +{ + snrm2sub_( n, x, incx, rval); +} + +void snrm2sub( const f77_int* n, const float* x, const f77_int* incx, float *rval) +{ + snrm2sub_( n, x, incx, rval); +} + +void SNRM2SUB_( const f77_int* n, const float* x, const f77_int* incx, float *rval) +{ + snrm2sub_( n, x, incx, rval); +} + +void SOMATADD( f77_char* transa,f77_char* transb, f77_int* m, f77_int* n, const float* alpha, const float* A, f77_int* lda, const float* beta, const float* B, f77_int* ldb, float* C, f77_int* ldc) +{ + 
somatadd_( transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc); +} + +void somatadd( f77_char* transa,f77_char* transb, f77_int* m, f77_int* n, const float* alpha, const float* A, f77_int* lda, const float* beta, const float* B, f77_int* ldb, float* C, f77_int* ldc) +{ + somatadd_( transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc); +} + +void SOMATADD_( f77_char* transa,f77_char* transb, f77_int* m, f77_int* n, const float* alpha, const float* A, f77_int* lda, const float* beta, const float* B, f77_int* ldb, float* C, f77_int* ldc) +{ + somatadd_( transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc); +} + +void SOMATCOPY2( f77_char* trans, f77_int* rows, f77_int* cols, const float* alpha, const float* aptr, f77_int* lda,f77_int* stridea, float* bptr, f77_int* ldb,f77_int* strideb) +{ + somatcopy2_( trans, rows, cols, alpha, aptr, lda, stridea, bptr, ldb, strideb); +} + +void somatcopy2( f77_char* trans, f77_int* rows, f77_int* cols, const float* alpha, const float* aptr, f77_int* lda,f77_int* stridea, float* bptr, f77_int* ldb,f77_int* strideb) +{ + somatcopy2_( trans, rows, cols, alpha, aptr, lda, stridea, bptr, ldb, strideb); +} + +void SOMATCOPY2_( f77_char* trans, f77_int* rows, f77_int* cols, const float* alpha, const float* aptr, f77_int* lda,f77_int* stridea, float* bptr, f77_int* ldb,f77_int* strideb) +{ + somatcopy2_( trans, rows, cols, alpha, aptr, lda, stridea, bptr, ldb, strideb); +} + +void SOMATCOPY( f77_char* trans, f77_int* rows, f77_int* cols, const float* alpha, const float* aptr, f77_int* lda, float* bptr, f77_int* ldb) +{ + somatcopy_( trans, rows, cols, alpha, aptr, lda, bptr, ldb); +} + +void somatcopy( f77_char* trans, f77_int* rows, f77_int* cols, const float* alpha, const float* aptr, f77_int* lda, float* bptr, f77_int* ldb) +{ + somatcopy_( trans, rows, cols, alpha, aptr, lda, bptr, ldb); +} + +void SOMATCOPY_( f77_char* trans, f77_int* rows, f77_int* cols, const float* alpha, const float* aptr, f77_int* lda, float* bptr, f77_int* ldb) +{ + somatcopy_( trans, rows, cols, alpha, aptr, lda, bptr, ldb); +} + +void ZAXPBY( const f77_int* n, const dcomplex* alpha, const dcomplex *x, const f77_int* incx, const dcomplex* beta, dcomplex *y, const f77_int* incy) +{ + zaxpby_( n, alpha, x, incx, beta, y, incy); +} + +void zaxpby( const f77_int* n, const dcomplex* alpha, const dcomplex *x, const f77_int* incx, const dcomplex* beta, dcomplex *y, const f77_int* incy) +{ + zaxpby_( n, alpha, x, incx, beta, y, incy); +} + +void ZAXPBY_( const f77_int* n, const dcomplex* alpha, const dcomplex *x, const f77_int* incx, const dcomplex* beta, dcomplex *y, const f77_int* incy) +{ + zaxpby_( n, alpha, x, incx, beta, y, incy); +} + +void ZDOTCSUB( const f77_int* n, const dcomplex* x, const f77_int* incx, const dcomplex* y, const f77_int* incy, dcomplex* rval) +{ + zdotcsub_( n, x, incx, y, incy, rval); +} + +void zdotcsub( const f77_int* n, const dcomplex* x, const f77_int* incx, const dcomplex* y, const f77_int* incy, dcomplex* rval) +{ + zdotcsub_( n, x, incx, y, incy, rval); +} + +void ZDOTCSUB_( const f77_int* n, const dcomplex* x, const f77_int* incx, const dcomplex* y, const f77_int* incy, dcomplex* rval) +{ + zdotcsub_( n, x, incx, y, incy, rval); +} + +void ZDOTUSUB( const f77_int* n, const dcomplex* x, const f77_int* incx,const dcomplex* y, const f77_int* incy, dcomplex* rval) +{ + zdotusub_( n, x, incx, y, incy, rval); +} + +void zdotusub( const f77_int* n, const dcomplex* x, const f77_int* incx,const dcomplex* y, const f77_int* incy, dcomplex* rval) +{ + 
zdotusub_( n, x, incx, y, incy, rval); +} + +void ZDOTUSUB_( const f77_int* n, const dcomplex* x, const f77_int* incx,const dcomplex* y, const f77_int* incy, dcomplex* rval) +{ + zdotusub_( n, x, incx, y, incy, rval); +} + +void ZGEMM3M( const f77_char* transa, const f77_char* transb, const f77_int* m, const f77_int* n, const f77_int* k, const dcomplex* alpha, const dcomplex* a, const f77_int* lda, const dcomplex* b, const f77_int* ldb, const dcomplex* beta, dcomplex* c, const f77_int* ldc) +{ + zgemm3m_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void zgemm3m( const f77_char* transa, const f77_char* transb, const f77_int* m, const f77_int* n, const f77_int* k, const dcomplex* alpha, const dcomplex* a, const f77_int* lda, const dcomplex* b, const f77_int* ldb, const dcomplex* beta, dcomplex* c, const f77_int* ldc) +{ + zgemm3m_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void ZGEMM3M_( const f77_char* transa, const f77_char* transb, const f77_int* m, const f77_int* n, const f77_int* k, const dcomplex* alpha, const dcomplex* a, const f77_int* lda, const dcomplex* b, const f77_int* ldb, const dcomplex* beta, dcomplex* c, const f77_int* ldc) +{ + zgemm3m_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void ZGEMM_BATCH( const f77_char* transa_array, const f77_char* transb_array,const f77_int *m_array, const f77_int *n_array, const f77_int *k_array,const dcomplex* alpha_array, const dcomplex** a_array, const f77_int *lda_array, const dcomplex** b_array, const f77_int *ldb_array, const dcomplex* beta_array, dcomplex** c_array, const f77_int *ldc_array, const f77_int* group_count, const f77_int *group_size) +{ + zgemm_batch_( transa_array, transb_array, m_array, n_array, k_array, alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array, group_count, group_size); +} + +void zgemm_batch( const f77_char* transa_array, const f77_char* transb_array,const f77_int *m_array, const f77_int *n_array, const f77_int *k_array,const dcomplex* alpha_array, const dcomplex** a_array, const f77_int *lda_array, const dcomplex** b_array, const f77_int *ldb_array, const dcomplex* beta_array, dcomplex** c_array, const f77_int *ldc_array, const f77_int* group_count, const f77_int *group_size) +{ + zgemm_batch_( transa_array, transb_array, m_array, n_array, k_array, alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array, group_count, group_size); +} + +void ZGEMM_BATCH_( const f77_char* transa_array, const f77_char* transb_array,const f77_int *m_array, const f77_int *n_array, const f77_int *k_array,const dcomplex* alpha_array, const dcomplex** a_array, const f77_int *lda_array, const dcomplex** b_array, const f77_int *ldb_array, const dcomplex* beta_array, dcomplex** c_array, const f77_int *ldc_array, const f77_int* group_count, const f77_int *group_size) +{ + zgemm_batch_( transa_array, transb_array, m_array, n_array, k_array, alpha_array, a_array, lda_array, b_array, ldb_array, beta_array, c_array, ldc_array, group_count, group_size); +} + +void ZGEMMT( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const dcomplex* alpha, const dcomplex* a, const f77_int* lda, const dcomplex* b, const f77_int* ldb, const dcomplex* beta, dcomplex* c, const f77_int* ldc) +{ + zgemmt_( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void zgemmt( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const 
f77_int* n, const f77_int* k, const dcomplex* alpha, const dcomplex* a, const f77_int* lda, const dcomplex* b, const f77_int* ldb, const dcomplex* beta, dcomplex* c, const f77_int* ldc) +{ + zgemmt_( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void ZGEMMT_( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const dcomplex* alpha, const dcomplex* a, const f77_int* lda, const dcomplex* b, const f77_int* ldb, const dcomplex* beta, dcomplex* c, const f77_int* ldc) +{ + zgemmt_( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void ZIMATCOPY(f77_char* trans, f77_int* rows, f77_int* cols, const dcomplex* alpha,dcomplex* aptr, f77_int* lda, f77_int* ldb) +{ + zimatcopy_( trans, rows, cols, alpha, aptr, lda, ldb); +} + +void zimatcopy(f77_char* trans, f77_int* rows, f77_int* cols, const dcomplex* alpha,dcomplex* aptr, f77_int* lda, f77_int* ldb) +{ + zimatcopy_( trans, rows, cols, alpha, aptr, lda, ldb); +} + +void ZIMATCOPY_(f77_char* trans, f77_int* rows, f77_int* cols, const dcomplex* alpha,dcomplex* aptr, f77_int* lda, f77_int* ldb) +{ + zimatcopy_( trans, rows, cols, alpha, aptr, lda, ldb); +} + +void ZOMATADD(f77_char* transa,f77_char* transb, f77_int* m, f77_int* n, const dcomplex* alpha, const dcomplex* A, f77_int* lda,const dcomplex* beta, dcomplex* B, f77_int* ldb, dcomplex* C, f77_int* ldc) +{ + zomatadd_( transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc); +} + +void zomatadd(f77_char* transa,f77_char* transb, f77_int* m, f77_int* n, const dcomplex* alpha, const dcomplex* A, f77_int* lda,const dcomplex* beta, dcomplex* B, f77_int* ldb, dcomplex* C, f77_int* ldc) +{ + zomatadd_( transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc); +} + +void ZOMATADD_(f77_char* transa,f77_char* transb, f77_int* m, f77_int* n, const dcomplex* alpha, const dcomplex* A, f77_int* lda,const dcomplex* beta, dcomplex* B, f77_int* ldb, dcomplex* C, f77_int* ldc) +{ + zomatadd_( transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc); +} + +void ZOMATCOPY2(f77_char* trans, f77_int* rows, f77_int* cols, const dcomplex* alpha, const dcomplex* aptr, f77_int* lda,f77_int* stridea, dcomplex* bptr, f77_int* ldb,f77_int* strideb) +{ + zomatcopy2_( trans, rows, cols, alpha, aptr, lda, stridea, bptr, ldb, strideb); +} + +void zomatcopy2(f77_char* trans, f77_int* rows, f77_int* cols, const dcomplex* alpha, const dcomplex* aptr, f77_int* lda,f77_int* stridea, dcomplex* bptr, f77_int* ldb,f77_int* strideb) +{ + zomatcopy2_( trans, rows, cols, alpha, aptr, lda, stridea, bptr, ldb, strideb); +} + +void ZOMATCOPY2_(f77_char* trans, f77_int* rows, f77_int* cols, const dcomplex* alpha, const dcomplex* aptr, f77_int* lda,f77_int* stridea, dcomplex* bptr, f77_int* ldb,f77_int* strideb) +{ + zomatcopy2_( trans, rows, cols, alpha, aptr, lda, stridea, bptr, ldb, strideb); +} + +void ZOMATCOPY(f77_char* trans, f77_int* rows, f77_int* cols, const dcomplex* alpha, const dcomplex* aptr, f77_int* lda, dcomplex* bptr, f77_int* ldb) +{ + zomatcopy_( trans, rows, cols, alpha, aptr, lda, bptr, ldb); +} + +void zomatcopy(f77_char* trans, f77_int* rows, f77_int* cols, const dcomplex* alpha, const dcomplex* aptr, f77_int* lda, dcomplex* bptr, f77_int* ldb) +{ + zomatcopy_( trans, rows, cols, alpha, aptr, lda, bptr, ldb); +} + +void ZOMATCOPY_(f77_char* trans, f77_int* rows, f77_int* cols, const dcomplex* alpha, const dcomplex* aptr, f77_int* lda, dcomplex* bptr, f77_int* ldb) +{ + zomatcopy_( trans, rows, cols, alpha, 
aptr, lda, bptr, ldb); +} + + + +float SCABS1(bla_scomplex* z) +{ + return scabs1_( z); +} + +float scabs1(bla_scomplex* z) +{ + return scabs1_( z); +} + +float SCABS1_(bla_scomplex* z) +{ + return scabs1_( z); + +} + +void SDSDOTSUB( const f77_int* n, float* sb, const float* x, const f77_int* incx, const float* y, const f77_int* incy, float* dot) +{ + sdsdotsub_( n, sb, x, incx, y, incy, dot); +} + +void sdsdotsub( const f77_int* n, float* sb, const float* x, const f77_int* incx, const float* y, const f77_int* incy, float* dot) +{ + sdsdotsub_( n, sb, x, incx, y, incy, dot); +} + +void SDSDOTSUB_( const f77_int* n, float* sb, const float* x, const f77_int* incx, const float* y, const f77_int* incy, float* dot) +{ + sdsdotsub_( n, sb, x, incx, y, incy, dot); +} + +void DSDOTSUB( const f77_int* n, const float* x, const f77_int* incx, const float* y, const f77_int* incy, double* dot) +{ + dsdotsub_( n, x, incx, y, incy, dot); +} + +void dsdotsub( const f77_int* n, const float* x, const f77_int* incx, const float* y, const f77_int* incy, double* dot) +{ + dsdotsub_( n, x, incx, y, incy, dot); +} + +void DSDOTSUB_( const f77_int* n, const float* x, const f77_int* incx, const float* y, const f77_int* incy, double* dot) +{ + dsdotsub_( n, x, incx, y, incy, dot); +} + +void CAXPBY( const f77_int* n, const scomplex* alpha, const scomplex *x, const f77_int* incx, const scomplex* beta, scomplex *y, const f77_int* incy) +{ + caxpby_(n, alpha, x, incx, beta, y, incy); +} + +void caxpby( const f77_int* n, const scomplex* alpha, const scomplex *x, const f77_int* incx, const scomplex* beta, scomplex *y, const f77_int* incy) +{ + caxpby_(n, alpha, x, incx, beta, y, incy); +} + +void CAXPBY_( const f77_int* n, const scomplex* alpha, const scomplex *x, const f77_int* incx, const scomplex* beta, scomplex *y, const f77_int* incy) +{ + caxpby_(n, alpha, x, incx, beta, y, incy); +} \ No newline at end of file diff --git a/frame/util/bli_util_api_wrap.h b/frame/util/bli_util_api_wrap.h new file mode 100644 index 0000000000..46d5a636a2 --- /dev/null +++ b/frame/util/bli_util_api_wrap.h @@ -0,0 +1,1727 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +//Level 1 APIs +BLIS_EXPORT_BLIS void SROTG(float *sa, float *sb, float *c, float *s); + +BLIS_EXPORT_BLIS void srotg(float *sa, float *sb, float *c, float *s); + +BLIS_EXPORT_BLIS void SROTG_(float *sa, float *sb, float *c, float *s); + + + +BLIS_EXPORT_BLIS void SROTMG(float *sd1, float *sd2, float *sx1, const float *sy1, float *sparam); + +BLIS_EXPORT_BLIS void srotmg(float *sd1, float *sd2, float *sx1, const float *sy1, float *sparam); + +BLIS_EXPORT_BLIS void SROTMG_(float *sd1, float *sd2, float *sx1, const float *sy1, float *sparam); + + + +BLIS_EXPORT_BLIS void SROT(const f77_int *n, float *sx, const f77_int *incx, float *sy, const f77_int *incy, const float *c, const float *s); + +BLIS_EXPORT_BLIS void srot(const f77_int *n, float *sx, const f77_int *incx, float *sy, const f77_int *incy, const float *c, const float *s); + +BLIS_EXPORT_BLIS void SROT_(const f77_int *n, float *sx, const f77_int *incx, float *sy, const f77_int *incy, const float *c, const float *s); + + + +BLIS_EXPORT_BLIS void SROTM(const f77_int *n, float *sx, const f77_int *incx, float *sy, const f77_int *incy, const float *sparam); + +BLIS_EXPORT_BLIS void srotm(const f77_int *n, float *sx, const f77_int *incx, float *sy, const f77_int *incy, const float *sparam); + +BLIS_EXPORT_BLIS void SROTM_(const f77_int *n, float *sx, const f77_int *incx, float *sy, const f77_int *incy, const float *sparam); + + + +BLIS_EXPORT_BLIS void SSWAP(const f77_int *n, float *sx, const f77_int *incx, float *sy, const f77_int *incy); + +BLIS_EXPORT_BLIS void sswap(const f77_int *n, float *sx, const f77_int *incx, float *sy, const f77_int *incy); + +BLIS_EXPORT_BLIS void SSWAP_(const f77_int *n, float *sx, const f77_int *incx, float *sy, const f77_int *incy); + + + +BLIS_EXPORT_BLIS void SSCAL(const f77_int *n, const float *sa, float *sx, const f77_int *incx); + +BLIS_EXPORT_BLIS void sscal(const f77_int *n, const float *sa, float *sx, const f77_int *incx); + +BLIS_EXPORT_BLIS void SSCAL_(const f77_int *n, const float *sa, float *sx, const f77_int *incx); + + + +BLIS_EXPORT_BLIS void SCOPY(const f77_int *n, const float *sx, const f77_int *incx, float *sy, const f77_int *incy); + +BLIS_EXPORT_BLIS void scopy(const f77_int *n, const float *sx, const f77_int *incx, float *sy, const f77_int *incy); + +BLIS_EXPORT_BLIS void SCOPY_(const f77_int *n, const float *sx, const f77_int *incx, float *sy, const f77_int *incy); + + + +BLIS_EXPORT_BLIS void SAXPY(const f77_int *n, const float *sa, const float *sx, const f77_int *incx, float *sy, const f77_int *incy); + +BLIS_EXPORT_BLIS void saxpy(const f77_int *n, const float *sa, const float *sx, const f77_int *incx, float *sy, const f77_int *incy); + +BLIS_EXPORT_BLIS void SAXPY_(const f77_int *n, const float *sa, const float *sx, const f77_int *incx, float *sy, const f77_int *incy); + + + +BLIS_EXPORT_BLIS float SDOT(const f77_int *n, const float *sx, const f77_int *incx, const float *sy, const f77_int *incy); + +BLIS_EXPORT_BLIS float sdot(const f77_int *n, const 
float *sx, const f77_int *incx, const float *sy, const f77_int *incy); + +BLIS_EXPORT_BLIS float SDOT_(const f77_int *n, const float *sx, const f77_int *incx, const float *sy, const f77_int *incy); + + + +BLIS_EXPORT_BLIS float SDSDOT(const f77_int *n, const float *sb, const float *sx, const f77_int *incx, const float *sy, const f77_int *incy); + +BLIS_EXPORT_BLIS float sdsdot(const f77_int *n, const float *sb, const float *sx, const f77_int *incx, const float *sy, const f77_int *incy); + +BLIS_EXPORT_BLIS float SDSDOT_(const f77_int *n, const float *sb, const float *sx, const f77_int *incx, const float *sy, const f77_int *incy); + + + +BLIS_EXPORT_BLIS float SNRM2(const f77_int *n, const float *x, const f77_int *incx); + +BLIS_EXPORT_BLIS float snrm2(const f77_int *n, const float *x, const f77_int *incx); + +BLIS_EXPORT_BLIS float SNRM2_(const f77_int *n, const float *x, const f77_int *incx); + + + +BLIS_EXPORT_BLIS float SCNRM2(const f77_int *n, const scomplex *x, const f77_int *incx); + +BLIS_EXPORT_BLIS float scnrm2(const f77_int *n, const scomplex *x, const f77_int *incx); + +BLIS_EXPORT_BLIS float SCNRM2_(const f77_int *n, const scomplex *x, const f77_int *incx); + + + +BLIS_EXPORT_BLIS float SASUM(const f77_int *n, const float *sx, const f77_int *incx); + +BLIS_EXPORT_BLIS float sasum(const f77_int *n, const float *sx, const f77_int *incx); + +BLIS_EXPORT_BLIS float SASUM_(const f77_int *n, const float *sx, const f77_int *incx); + + + +BLIS_EXPORT_BLIS f77_int ISAMAX(const f77_int *n, const float *sx, const f77_int *incx); + +BLIS_EXPORT_BLIS f77_int isamax(const f77_int *n, const float *sx, const f77_int *incx); + +BLIS_EXPORT_BLIS f77_int ISAMAX_(const f77_int *n, const float *sx, const f77_int *incx); + + + +BLIS_EXPORT_BLIS void DROTG(double *da, double *db, double *c, double *s); + +BLIS_EXPORT_BLIS void drotg(double *da, double *db, double *c, double *s); + +BLIS_EXPORT_BLIS void DROTG_(double *da, double *db, double *c, double *s); + + + +BLIS_EXPORT_BLIS void DROTMG(double *dd1, double *dd2, double *dx1, const double *dy1, double *dparam); + +BLIS_EXPORT_BLIS void drotmg(double *dd1, double *dd2, double *dx1, const double *dy1, double *dparam); + +BLIS_EXPORT_BLIS void DROTMG_(double *dd1, double *dd2, double *dx1, const double *dy1, double *dparam); + + + +BLIS_EXPORT_BLIS void DROT(const f77_int *n, double *dx, const f77_int *incx, double *dy, const f77_int *incy, const double *c, const double *s); + +BLIS_EXPORT_BLIS void drot(const f77_int *n, double *dx, const f77_int *incx, double *dy, const f77_int *incy, const double *c, const double *s); + +BLIS_EXPORT_BLIS void DROT_(const f77_int *n, double *dx, const f77_int *incx, double *dy, const f77_int *incy, const double *c, const double *s); + + + +BLIS_EXPORT_BLIS void DROTM(const f77_int *n, double *dx, const f77_int *incx, double *dy, const f77_int *incy, const double *dparam); + +BLIS_EXPORT_BLIS void drotm(const f77_int *n, double *dx, const f77_int *incx, double *dy, const f77_int *incy, const double *dparam); + +BLIS_EXPORT_BLIS void DROTM_(const f77_int *n, double *dx, const f77_int *incx, double *dy, const f77_int *incy, const double *dparam); + + + +BLIS_EXPORT_BLIS void DSWAP(const f77_int *n, double *dx, const f77_int *incx, double *dy, const f77_int *incy); + +BLIS_EXPORT_BLIS void dswap(const f77_int *n, double *dx, const f77_int *incx, double *dy, const f77_int *incy); + +BLIS_EXPORT_BLIS void DSWAP_(const f77_int *n, double *dx, const f77_int *incx, double *dy, const f77_int *incy); + + + +BLIS_EXPORT_BLIS 
void DSCAL(const f77_int *n, const double *da, double *dx, const f77_int *incx); + +BLIS_EXPORT_BLIS void dscal(const f77_int *n, const double *da, double *dx, const f77_int *incx); + +BLIS_EXPORT_BLIS void DSCAL_(const f77_int *n, const double *da, double *dx, const f77_int *incx); + + + +BLIS_EXPORT_BLIS void DCOPY(const f77_int *n, const double *dx, const f77_int *incx, double *dy, const f77_int *incy); + +BLIS_EXPORT_BLIS void dcopy(const f77_int *n, const double *dx, const f77_int *incx, double *dy, const f77_int *incy); + +BLIS_EXPORT_BLIS void DCOPY_(const f77_int *n, const double *dx, const f77_int *incx, double *dy, const f77_int *incy); + + + +BLIS_EXPORT_BLIS void DAXPY(const f77_int *n, const double *da, const double *dx, const f77_int *incx, double *dy, const f77_int *incy); + +BLIS_EXPORT_BLIS void daxpy(const f77_int *n, const double *da, const double *dx, const f77_int *incx, double *dy, const f77_int *incy); + +BLIS_EXPORT_BLIS void DAXPY_(const f77_int *n, const double *da, const double *dx, const f77_int *incx, double *dy, const f77_int *incy); + + + +BLIS_EXPORT_BLIS double DDOT(const f77_int *n, const double *dx, const f77_int *incx, const double *dy, const f77_int *incy); + +BLIS_EXPORT_BLIS double ddot(const f77_int *n, const double *dx, const f77_int *incx, const double *dy, const f77_int *incy); + +BLIS_EXPORT_BLIS double DDOT_(const f77_int *n, const double *dx, const f77_int *incx, const double *dy, const f77_int *incy); + + + +BLIS_EXPORT_BLIS double DSDOT(const f77_int *n, const float *sx, const f77_int *incx, const float *sy, const f77_int *incy); + +BLIS_EXPORT_BLIS double dsdot(const f77_int *n, const float *sx, const f77_int *incx, const float *sy, const f77_int *incy); + +BLIS_EXPORT_BLIS double DSDOT_(const f77_int *n, const float *sx, const f77_int *incx, const float *sy, const f77_int *incy); + + + +BLIS_EXPORT_BLIS double DNRM2(const f77_int *n, const double *x, const f77_int *incx); + +BLIS_EXPORT_BLIS double dnrm2(const f77_int *n, const double *x, const f77_int *incx); + +BLIS_EXPORT_BLIS double DNRM2_(const f77_int *n, const double *x, const f77_int *incx); + + + +BLIS_EXPORT_BLIS double DZNRM2(const f77_int *n, const dcomplex *x, const f77_int *incx); + +BLIS_EXPORT_BLIS double dznrm2(const f77_int *n, const dcomplex *x, const f77_int *incx); + +BLIS_EXPORT_BLIS double DZNRM2_(const f77_int *n, const dcomplex *x, const f77_int *incx); + + + +BLIS_EXPORT_BLIS double DASUM(const f77_int *n, const double *dx, const f77_int *incx); + +BLIS_EXPORT_BLIS double dasum(const f77_int *n, const double *dx, const f77_int *incx); + +BLIS_EXPORT_BLIS double DASUM_(const f77_int *n, const double *dx, const f77_int *incx); + + + +BLIS_EXPORT_BLIS f77_int IDAMAX(const f77_int *n, const double *dx, const f77_int *incx); + +BLIS_EXPORT_BLIS f77_int idamax(const f77_int *n, const double *dx, const f77_int *incx); + +BLIS_EXPORT_BLIS f77_int IDAMAX_(const f77_int *n, const double *dx, const f77_int *incx); + + + +BLIS_EXPORT_BLIS void CROTG(scomplex *ca, bla_scomplex *cb, bla_real *c, scomplex *s); + +BLIS_EXPORT_BLIS void crotg(scomplex *ca, bla_scomplex *cb, bla_real *c, scomplex *s); + +BLIS_EXPORT_BLIS void CROTG_(scomplex *ca, bla_scomplex *cb, bla_real *c, scomplex *s); + + + +BLIS_EXPORT_BLIS void CSROT(const f77_int *n, scomplex *cx, const f77_int *incx, scomplex *cy, const f77_int *incy, const float *c, const float *s); + +BLIS_EXPORT_BLIS void csrot(const f77_int *n, scomplex *cx, const f77_int *incx, scomplex *cy, const f77_int *incy, const float *c, const 
float *s); + +BLIS_EXPORT_BLIS void CSROT_(const f77_int *n, scomplex *cx, const f77_int *incx, scomplex *cy, const f77_int *incy, const float *c, const float *s); + + + +BLIS_EXPORT_BLIS void CSWAP(const f77_int *n, scomplex *cx, const f77_int *incx, scomplex *cy, const f77_int *incy); + +BLIS_EXPORT_BLIS void cswap(const f77_int *n, scomplex *cx, const f77_int *incx, scomplex *cy, const f77_int *incy); + +BLIS_EXPORT_BLIS void CSWAP_(const f77_int *n, scomplex *cx, const f77_int *incx, scomplex *cy, const f77_int *incy); + + + +BLIS_EXPORT_BLIS void CSCAL(const f77_int *n, const scomplex *ca, scomplex *cx, const f77_int *incx); + +BLIS_EXPORT_BLIS void cscal(const f77_int *n, const scomplex *ca, scomplex *cx, const f77_int *incx); + +BLIS_EXPORT_BLIS void CSCAL_(const f77_int *n, const scomplex *ca, scomplex *cx, const f77_int *incx); + + +BLIS_EXPORT_BLIS void CSSCAL(const f77_int *n, const float *sa, scomplex *cx, const f77_int *incx); + +BLIS_EXPORT_BLIS void csscal(const f77_int *n, const float *sa, scomplex *cx, const f77_int *incx); + +BLIS_EXPORT_BLIS void CSSCAL_(const f77_int *n, const float *sa, scomplex *cx, const f77_int *incx); + + +BLIS_EXPORT_BLIS void CCOPY(const f77_int *n, const scomplex *cx, const f77_int *incx, scomplex *cy, const f77_int *incy); + +BLIS_EXPORT_BLIS void ccopy(const f77_int *n, const scomplex *cx, const f77_int *incx, scomplex *cy, const f77_int *incy); + +BLIS_EXPORT_BLIS void CCOPY_(const f77_int *n, const scomplex *cx, const f77_int *incx, scomplex *cy, const f77_int *incy); + + +BLIS_EXPORT_BLIS void CAXPY(const f77_int *n, const scomplex *ca, const scomplex *cx, const f77_int *incx, scomplex *cy, const f77_int *incy); + +BLIS_EXPORT_BLIS void caxpy(const f77_int *n, const scomplex *ca, const scomplex *cx, const f77_int *incx, scomplex *cy, const f77_int *incy); + +BLIS_EXPORT_BLIS void CAXPY_(const f77_int *n, const scomplex *ca, const scomplex *cx, const f77_int *incx,scomplex *cy, const f77_int *incy); + + +#ifdef BLIS_DISABLE_COMPLEX_RETURN_INTEL + +BLIS_EXPORT_BLIS scomplex CDOTC(const f77_int* n, const scomplex* x, const f77_int* incx, const scomplex* y, const f77_int* incy); + +BLIS_EXPORT_BLIS scomplex cdotc(const f77_int* n, const scomplex* x, const f77_int* incx, const scomplex* y, const f77_int* incy); + +BLIS_EXPORT_BLIS scomplex CDOTC_ (const f77_int* n, const scomplex* x, const f77_int* incx, const scomplex* y, const f77_int* incy); + + + +BLIS_EXPORT_BLIS scomplex CDOTU(const f77_int* n, const scomplex* x, const f77_int* incx,const scomplex* y, const f77_int* incy); + +BLIS_EXPORT_BLIS scomplex cdotu(const f77_int* n, const scomplex* x, const f77_int* incx,const scomplex* y, const f77_int* incy); + +BLIS_EXPORT_BLIS scomplex CDOTU_(const f77_int* n, const scomplex* x, const f77_int* incx,const scomplex* y, const f77_int* incy); + + + +BLIS_EXPORT_BLIS dcomplex ZDOTC(const f77_int* n, const dcomplex* x, const f77_int* incx, const dcomplex* y, const f77_int* incy); + +BLIS_EXPORT_BLIS dcomplex zdotc (const f77_int* n, const dcomplex* x, const f77_int* incx, const dcomplex* y, const f77_int* incy); + +BLIS_EXPORT_BLIS dcomplex ZDOTC_ (const f77_int* n, const dcomplex* x, const f77_int* incx, const dcomplex* y, const f77_int* incy); + + + +BLIS_EXPORT_BLIS dcomplex ZDOTU(const f77_int* n, const dcomplex* x, const f77_int* incx, const dcomplex* y, const f77_int* incy); + +BLIS_EXPORT_BLIS dcomplex zdotu (const f77_int* n, const dcomplex* x, const f77_int* incx, const dcomplex* y, const f77_int* incy); + +BLIS_EXPORT_BLIS dcomplex 
ZDOTU_(const f77_int* n, const dcomplex* x, const f77_int* incx, const dcomplex* y, const f77_int* incy); + +#else + +BLIS_EXPORT_BLIS void CDOTC(scomplex* retval, const f77_int *n, const scomplex *cx, const f77_int *incx, const scomplex *cy, const f77_int *incy); + +BLIS_EXPORT_BLIS void cdotc(scomplex* retval, const f77_int *n, const scomplex *cx, const f77_int *incx, const scomplex *cy, const f77_int *incy); + +BLIS_EXPORT_BLIS void CDOTC_(scomplex* retval, const f77_int *n, const scomplex *cx, const f77_int *incx, const scomplex *cy, const f77_int *incy); + + + +BLIS_EXPORT_BLIS void CDOTU(scomplex* retval, const f77_int *n, const scomplex *cx, const f77_int *incx, const scomplex *cy, const f77_int *incy); + +BLIS_EXPORT_BLIS void cdotu(scomplex* retval, const f77_int *n, const scomplex *cx, const f77_int *incx, const scomplex *cy, const f77_int *incy); + +BLIS_EXPORT_BLIS void CDOTU_(scomplex* retval, const f77_int *n, const scomplex *cx, const f77_int *incx, const scomplex *cy, const f77_int *incy); + + + +BLIS_EXPORT_BLIS void ZDOTC(dcomplex* retval, const f77_int *n, const dcomplex *zx, const f77_int *incx, const dcomplex *zy, const f77_int *incy); + +BLIS_EXPORT_BLIS void zdotc(dcomplex* retval, const f77_int *n, const dcomplex *zx, const f77_int *incx, const dcomplex *zy, const f77_int *incy); + +BLIS_EXPORT_BLIS void ZDOTC_(dcomplex* retval, const f77_int *n, const dcomplex *zx, const f77_int *incx, const dcomplex *zy, const f77_int *incy); + + + +BLIS_EXPORT_BLIS void ZDOTU(dcomplex* retval, const f77_int *n, const dcomplex *zx, const f77_int *incx, const dcomplex *zy, const f77_int *incy); + +BLIS_EXPORT_BLIS void zdotu(dcomplex* retval, const f77_int *n, const dcomplex *zx, const f77_int *incx, const dcomplex *zy, const f77_int *incy); + +BLIS_EXPORT_BLIS void ZDOTU_(dcomplex* retval, const f77_int *n, const dcomplex *zx, const f77_int *incx, const dcomplex *zy, const f77_int *incy); + +#endif + + +BLIS_EXPORT_BLIS float SCASUM(const f77_int *n, const scomplex *cx, const f77_int *incx); + +BLIS_EXPORT_BLIS float scasum(const f77_int *n, const scomplex *cx, const f77_int *incx); + +BLIS_EXPORT_BLIS float SCASUM_(const f77_int *n, const scomplex *cx, const f77_int *incx); + + + +BLIS_EXPORT_BLIS f77_int ICAMAX(const f77_int *n, const scomplex *cx, const f77_int *incx); + +BLIS_EXPORT_BLIS f77_int icamax(const f77_int *n, const scomplex *cx, const f77_int *incx); + +BLIS_EXPORT_BLIS f77_int ICAMAX_(const f77_int *n, const scomplex *cx, const f77_int *incx); + + + +BLIS_EXPORT_BLIS void ZROTG(dcomplex *ca, bla_dcomplex *cb, bla_double *c, dcomplex *s); + +BLIS_EXPORT_BLIS void zrotg(dcomplex *ca, bla_dcomplex *cb, bla_double *c, dcomplex *s); + +BLIS_EXPORT_BLIS void ZROTG_(dcomplex *ca, bla_dcomplex *cb, bla_double *c, dcomplex *s); + + + +BLIS_EXPORT_BLIS void ZDROT(const f77_int *n, dcomplex *cx, const f77_int *incx, dcomplex *cy, const f77_int *incy, const double *c, const double *s); + +BLIS_EXPORT_BLIS void zdrot(const f77_int *n, dcomplex *cx, const f77_int *incx, dcomplex *cy, const f77_int *incy, const double *c, const double *s); + +BLIS_EXPORT_BLIS void ZDROT_(const f77_int *n, dcomplex *cx, const f77_int *incx, dcomplex *cy, const f77_int *incy, const double *c, const double *s); + + + +BLIS_EXPORT_BLIS void ZSWAP(const f77_int *n, dcomplex *zx, const f77_int *incx, dcomplex *zy, const f77_int *incy); + +BLIS_EXPORT_BLIS void zswap(const f77_int *n, dcomplex *zx, const f77_int *incx, dcomplex *zy, const f77_int *incy); + +BLIS_EXPORT_BLIS void ZSWAP_(const f77_int *n, 
dcomplex *zx, const f77_int *incx, dcomplex *zy, const f77_int *incy); + + + +BLIS_EXPORT_BLIS void ZSCAL(const f77_int *n, const dcomplex *za, dcomplex *zx, const f77_int *incx); + +BLIS_EXPORT_BLIS void zscal(const f77_int *n, const dcomplex *za, dcomplex *zx, const f77_int *incx); + +BLIS_EXPORT_BLIS void ZSCAL_(const f77_int *n, const dcomplex *za, dcomplex *zx, const f77_int *incx); + + + +BLIS_EXPORT_BLIS void ZDSCAL(const f77_int *n, const double *da, dcomplex *zx, const f77_int *incx); + +BLIS_EXPORT_BLIS void zdscal(const f77_int *n, const double *da, dcomplex *zx, const f77_int *incx); + +BLIS_EXPORT_BLIS void ZDSCAL_(const f77_int *n, const double *da, dcomplex *zx, const f77_int *incx); + + + +BLIS_EXPORT_BLIS void ZCOPY(const f77_int *n, const dcomplex *zx, const f77_int *incx, dcomplex *zy, const f77_int *incy); + +BLIS_EXPORT_BLIS void zcopy(const f77_int *n, const dcomplex *zx, const f77_int *incx, dcomplex *zy, const f77_int *incy); + +BLIS_EXPORT_BLIS void ZCOPY_(const f77_int *n, const dcomplex *zx, const f77_int *incx, dcomplex *zy, const f77_int *incy); + + + +BLIS_EXPORT_BLIS void ZAXPY(const f77_int *n, const dcomplex *za, const dcomplex *zx, const f77_int *incx, dcomplex *zy, const f77_int *incy); + +BLIS_EXPORT_BLIS void zaxpy(const f77_int *n, const dcomplex *za, const dcomplex *zx, const f77_int *incx, dcomplex *zy, const f77_int *incy); + +BLIS_EXPORT_BLIS void ZAXPY_(const f77_int *n, const dcomplex *za, const dcomplex *zx, const f77_int *incx, dcomplex *zy, const f77_int *incy); + + + +BLIS_EXPORT_BLIS double DZASUM(const f77_int *n, const dcomplex *zx, const f77_int *incx); + +BLIS_EXPORT_BLIS double dzasum(const f77_int *n, const dcomplex *zx, const f77_int *incx); + +BLIS_EXPORT_BLIS double DZASUM_(const f77_int *n, const dcomplex *zx, const f77_int *incx); + + + +BLIS_EXPORT_BLIS f77_int IZAMAX(const f77_int *n, const dcomplex *zx, const f77_int *incx); + +BLIS_EXPORT_BLIS f77_int izamax(const f77_int *n, const dcomplex *zx, const f77_int *incx); + +BLIS_EXPORT_BLIS f77_int IZAMAX_(const f77_int *n, const dcomplex *zx, const f77_int *incx); + + + +BLIS_EXPORT_BLIS f77_int ICAMIN( const f77_int* n, const scomplex* x, const f77_int* incx); + +BLIS_EXPORT_BLIS f77_int icamin( const f77_int* n, const scomplex* x, const f77_int* incx); + +BLIS_EXPORT_BLIS f77_int ICAMIN_( const f77_int* n, const scomplex* x, const f77_int* incx); + + + +BLIS_EXPORT_BLIS f77_int IDAMIN( const f77_int* n, const double* x, const f77_int* incx); + +BLIS_EXPORT_BLIS f77_int idamin( const f77_int* n, const double* x, const f77_int* incx); + +BLIS_EXPORT_BLIS f77_int IDAMIN_( const f77_int* n, const double* x, const f77_int* incx); + + + +BLIS_EXPORT_BLIS f77_int ISAMIN( const f77_int* n, const float* x, const f77_int* incx); + +BLIS_EXPORT_BLIS f77_int isamin( const f77_int* n, const float* x, const f77_int* incx); + +BLIS_EXPORT_BLIS f77_int ISAMIN_( const f77_int* n, const float* x, const f77_int* incx); + + + +BLIS_EXPORT_BLIS f77_int IZAMIN( const f77_int* n, const dcomplex* x, const f77_int* incx); + +BLIS_EXPORT_BLIS f77_int izamin( const f77_int* n, const dcomplex* x, const f77_int* incx); + +BLIS_EXPORT_BLIS f77_int IZAMIN_( const f77_int* n, const dcomplex* x, const f77_int* incx); + + + +//Level 2 APIs +BLIS_EXPORT_BLIS void SGEMV(const char *trans, const f77_int *m, const f77_int *n, const float *alpha, const float *a, const f77_int *lda, const float *x, const f77_int *incx, const float *beta, float *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void sgemv(const char 
*trans, const f77_int *m, const f77_int *n, const float *alpha, const float *a, const f77_int *lda, const float *x, const f77_int *incx, const float *beta, float *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void SGEMV_(const char *trans, const f77_int *m, const f77_int *n, const float *alpha, const float *a, const f77_int *lda, const float *x, const f77_int *incx, const float *beta, float *y, const f77_int *incy); + + + +BLIS_EXPORT_BLIS void SGBMV(const char *trans, const f77_int *m, const f77_int *n, const f77_int *kl, const f77_int *ku, const float *alpha, const float *a, const f77_int *lda, const float *x, const f77_int *incx, const float *beta, float *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void sgbmv(const char *trans, const f77_int *m, const f77_int *n, const f77_int *kl, const f77_int *ku, const float *alpha, const float *a, const f77_int *lda, const float *x, const f77_int *incx, const float *beta, float *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void SGBMV_(const char *trans, const f77_int *m, const f77_int *n, const f77_int *kl, const f77_int *ku, const float *alpha, const float *a, const f77_int *lda, const float *x, const f77_int *incx, const float *beta, float *y, const f77_int *incy); + + + +BLIS_EXPORT_BLIS void SSYMV(const char *uplo, const f77_int *n, const float *alpha, const float *a, const f77_int *lda, const float *x, const f77_int *incx, const float *beta, float *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void ssymv(const char *uplo, const f77_int *n, const float *alpha, const float *a, const f77_int *lda, const float *x, const f77_int *incx, const float *beta, float *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void SSYMV_(const char *uplo, const f77_int *n, const float *alpha, const float *a, const f77_int *lda, const float *x, const f77_int *incx, const float *beta, float *y, const f77_int *incy); + + + +BLIS_EXPORT_BLIS void SSBMV(const char *uplo, const f77_int *n, const f77_int *k, const float *alpha, const float *a, const f77_int *lda, const float *x, const f77_int *incx, const float *beta, float *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void ssbmv(const char *uplo, const f77_int *n, const f77_int *k, const float *alpha, const float *a, const f77_int *lda, const float *x, const f77_int *incx, const float *beta, float *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void SSBMV_(const char *uplo, const f77_int *n, const f77_int *k, const float *alpha, const float *a, const f77_int *lda, const float *x, const f77_int *incx, const float *beta, float *y, const f77_int *incy); + + + +BLIS_EXPORT_BLIS void SSPMV(const char *uplo, const f77_int *n, const float *alpha, const float *ap, const float *x, const f77_int *incx, const float *beta, float *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void sspmv(const char *uplo, const f77_int *n, const float *alpha, const float *ap, const float *x, const f77_int *incx, const float *beta, float *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void SSPMV_(const char *uplo, const f77_int *n, const float *alpha, const float *ap, const float *x, const f77_int *incx, const float *beta, float *y, const f77_int *incy); + + + +BLIS_EXPORT_BLIS void STRMV(const char *uplo, const char *trans, const char *diag, const f77_int *n, const float *a, const f77_int *lda, float *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void strmv(const char *uplo, const char *trans, const char *diag, const f77_int *n, const float *a, const f77_int *lda, float *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void STRMV_(const char *uplo, const char 
*trans, const char *diag, const f77_int *n, const float *a, const f77_int *lda, float *x, const f77_int *incx); + + + +BLIS_EXPORT_BLIS void STBMV(const char *uplo, const char *trans, const char *diag, const f77_int *n, const f77_int *k, const float *a, const f77_int *lda, float *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void stbmv(const char *uplo, const char *trans, const char *diag, const f77_int *n, const f77_int *k, const float *a, const f77_int *lda, float *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void STBMV_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const f77_int *k, const float *a, const f77_int *lda, float *x, const f77_int *incx); + + + +BLIS_EXPORT_BLIS void STPMV(const char *uplo, const char *trans, const char *diag, const f77_int *n, const float *ap, float *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void stpmv(const char *uplo, const char *trans, const char *diag, const f77_int *n, const float *ap, float *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void STPMV_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const float *ap, float *x, const f77_int *incx); + + + +BLIS_EXPORT_BLIS void STRSV(const char *uplo, const char *trans, const char *diag, const f77_int *n, const float *a, const f77_int *lda, float *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void strsv(const char *uplo, const char *trans, const char *diag, const f77_int *n, const float *a, const f77_int *lda, float *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void STRSV_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const float *a, const f77_int *lda, float *x, const f77_int *incx); + + + +BLIS_EXPORT_BLIS void STBSV(const char *uplo, const char *trans, const char *diag, const f77_int *n, const f77_int *k, const float *a, const f77_int *lda, float *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void stbsv(const char *uplo, const char *trans, const char *diag, const f77_int *n, const f77_int *k, const float *a, const f77_int *lda, float *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void STBSV_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const f77_int *k, const float *a, const f77_int *lda, float *x, const f77_int *incx); + + + +BLIS_EXPORT_BLIS void STPSV(const char *uplo, const char *trans, const char *diag, const f77_int *n, const float *ap, float *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void stpsv(const char *uplo, const char *trans, const char *diag, const f77_int *n, const float *ap, float *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void STPSV_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const float *ap, float *x, const f77_int *incx); + + + +BLIS_EXPORT_BLIS void SGER(const f77_int *m, const f77_int *n, const float *alpha, const float *x, const f77_int *incx, const float *y, const f77_int *incy, float *a, const f77_int *lda); + +BLIS_EXPORT_BLIS void sger(const f77_int *m, const f77_int *n, const float *alpha, const float *x, const f77_int *incx, const float *y, const f77_int *incy, float *a, const f77_int *lda); + +BLIS_EXPORT_BLIS void SGER_(const f77_int *m, const f77_int *n, const float *alpha, const float *x, const f77_int *incx, const float *y, const f77_int *incy, float *a, const f77_int *lda); + + + +BLIS_EXPORT_BLIS void SSYR(const char *uplo, const f77_int *n, const float *alpha, const float *x, const f77_int *incx, float *a, const f77_int *lda); + +BLIS_EXPORT_BLIS void ssyr(const char *uplo, const f77_int *n, const float *alpha, const float *x, const 
f77_int *incx, float *a, const f77_int *lda); + +BLIS_EXPORT_BLIS void SSYR_(const char *uplo, const f77_int *n, const float *alpha, const float *x, const f77_int *incx, float *a, const f77_int *lda); + + + +BLIS_EXPORT_BLIS void SSPR(const char *uplo, const f77_int *n, const float *alpha, const float *x, const f77_int *incx, float *ap); + +BLIS_EXPORT_BLIS void sspr(const char *uplo, const f77_int *n, const float *alpha, const float *x, const f77_int *incx, float *ap); + +BLIS_EXPORT_BLIS void SSPR_(const char *uplo, const f77_int *n, const float *alpha, const float *x, const f77_int *incx, float *ap); + + + +BLIS_EXPORT_BLIS void SSYR2(const char *uplo, const f77_int *n, const float *alpha, const float *x, const f77_int *incx, const float *y, const f77_int *incy, float *a, const f77_int *lda); + +BLIS_EXPORT_BLIS void ssyr2(const char *uplo, const f77_int *n, const float *alpha, const float *x, const f77_int *incx, const float *y, const f77_int *incy, float *a, const f77_int *lda); + +BLIS_EXPORT_BLIS void SSYR2_(const char *uplo, const f77_int *n, const float *alpha, const float *x, const f77_int *incx, const float *y, const f77_int *incy, float *a, const f77_int *lda); + + + +BLIS_EXPORT_BLIS void SSPR2(const char *uplo, const f77_int *n, const float *alpha, const float *x, const f77_int *incx, const float *y, const f77_int *incy, float *ap); + +BLIS_EXPORT_BLIS void sspr2(const char *uplo, const f77_int *n, const float *alpha, const float *x, const f77_int *incx, const float *y, const f77_int *incy, float *ap); + +BLIS_EXPORT_BLIS void SSPR2_(const char *uplo, const f77_int *n, const float *alpha, const float *x, const f77_int *incx, const float *y, const f77_int *incy, float *ap); + + + +BLIS_EXPORT_BLIS void DGEMV(const char *trans, const f77_int *m, const f77_int *n, const double *alpha, const double *a, const f77_int *lda, const double *x, const f77_int *incx, const double *beta, double *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void dgemv(const char *trans, const f77_int *m, const f77_int *n, const double *alpha, const double *a, const f77_int *lda, const double *x, const f77_int *incx, const double *beta, double *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void DGEMV_(const char *trans, const f77_int *m, const f77_int *n, const double *alpha, const double *a, const f77_int *lda, const double *x, const f77_int *incx, const double *beta, double *y, const f77_int *incy); + + + +BLIS_EXPORT_BLIS void DGBMV(const char *trans, const f77_int *m, const f77_int *n, const f77_int *kl, const f77_int *ku, const double *alpha, const double *a, const f77_int *lda, const double *x, const f77_int *incx, const double *beta, double *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void dgbmv(const char *trans, const f77_int *m, const f77_int *n, const f77_int *kl, const f77_int *ku, const double *alpha, const double *a, const f77_int *lda, const double *x, const f77_int *incx, const double *beta, double *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void DGBMV_(const char *trans, const f77_int *m, const f77_int *n, const f77_int *kl, const f77_int *ku, const double *alpha, const double *a, const f77_int *lda, const double *x, const f77_int *incx, const double *beta, double *y, const f77_int *incy); + + + +BLIS_EXPORT_BLIS void DSYMV(const char *uplo, const f77_int *n, const double *alpha, const double *a, const f77_int *lda, const double *x, const f77_int *incx, const double *beta, double *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void dsymv(const char *uplo, const f77_int *n, const double 
*alpha, const double *a, const f77_int *lda, const double *x, const f77_int *incx, const double *beta, double *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void DSYMV_(const char *uplo, const f77_int *n, const double *alpha, const double *a, const f77_int *lda, const double *x, const f77_int *incx, const double *beta, double *y, const f77_int *incy); + + + +BLIS_EXPORT_BLIS void DSBMV(const char *uplo, const f77_int *n, const f77_int *k, const double *alpha, const double *a, const f77_int *lda, const double *x, const f77_int *incx, const double *beta, double *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void dsbmv(const char *uplo, const f77_int *n, const f77_int *k, const double *alpha, const double *a, const f77_int *lda, const double *x, const f77_int *incx, const double *beta, double *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void DSBMV_(const char *uplo, const f77_int *n, const f77_int *k, const double *alpha, const double *a, const f77_int *lda, const double *x, const f77_int *incx, const double *beta, double *y, const f77_int *incy); + + + +BLIS_EXPORT_BLIS void DSPMV(const char *uplo, const f77_int *n, const double *alpha, const double *ap, const double *x, const f77_int *incx, const double *beta, double *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void dspmv(const char *uplo, const f77_int *n, const double *alpha, const double *ap, const double *x, const f77_int *incx, const double *beta, double *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void DSPMV_(const char *uplo, const f77_int *n, const double *alpha, const double *ap, const double *x, const f77_int *incx, const double *beta, double *y, const f77_int *incy); + + + +BLIS_EXPORT_BLIS void DTRMV(const char *uplo, const char *trans, const char *diag, const f77_int *n, const double *a, const f77_int *lda, double *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void dtrmv(const char *uplo, const char *trans, const char *diag, const f77_int *n, const double *a, const f77_int *lda, double *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void DTRMV_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const double *a, const f77_int *lda, double *x, const f77_int *incx); + + + +BLIS_EXPORT_BLIS void DTBMV(const char *uplo, const char *trans, const char *diag, const f77_int *n, const f77_int *k, const double *a, const f77_int *lda, double *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void dtbmv(const char *uplo, const char *trans, const char *diag, const f77_int *n, const f77_int *k, const double *a, const f77_int *lda, double *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void DTBMV_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const f77_int *k, const double *a, const f77_int *lda, double *x, const f77_int *incx); + + + +BLIS_EXPORT_BLIS void DTPMV(const char *uplo, const char *trans, const char *diag, const f77_int *n, const double *ap, double *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void dtpmv(const char *uplo, const char *trans, const char *diag, const f77_int *n, const double *ap, double *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void DTPMV_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const double *ap, double *x, const f77_int *incx); + + + +BLIS_EXPORT_BLIS void DTRSV(const char *uplo, const char *trans, const char *diag, const f77_int *n, const double *a, const f77_int *lda, double *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void dtrsv(const char *uplo, const char *trans, const char *diag, const f77_int *n, const double *a, const f77_int *lda, 
double *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void DTRSV_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const double *a, const f77_int *lda, double *x, const f77_int *incx); + + + +BLIS_EXPORT_BLIS void DTBSV(const char *uplo, const char *trans, const char *diag, const f77_int *n, const f77_int *k, const double *a, const f77_int *lda, double *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void dtbsv(const char *uplo, const char *trans, const char *diag, const f77_int *n, const f77_int *k, const double *a, const f77_int *lda, double *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void DTBSV_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const f77_int *k, const double *a, const f77_int *lda, double *x, const f77_int *incx); + + + +BLIS_EXPORT_BLIS void DTPSV(const char *uplo, const char *trans, const char *diag, const f77_int *n, const double *ap, double *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void dtpsv(const char *uplo, const char *trans, const char *diag, const f77_int *n, const double *ap, double *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void DTPSV_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const double *ap, double *x, const f77_int *incx); + + + +BLIS_EXPORT_BLIS void DGER(const f77_int *m, const f77_int *n, const double *alpha, const double *x, const f77_int *incx, const double *y, const f77_int *incy, double *a, const f77_int *lda); + +BLIS_EXPORT_BLIS void dger(const f77_int *m, const f77_int *n, const double *alpha, const double *x, const f77_int *incx, const double *y, const f77_int *incy, double *a, const f77_int *lda); + +BLIS_EXPORT_BLIS void DGER_(const f77_int *m, const f77_int *n, const double *alpha, const double *x, const f77_int *incx, const double *y, const f77_int *incy, double *a, const f77_int *lda); + + + +BLIS_EXPORT_BLIS void DSYR(const char *uplo, const f77_int *n, const double *alpha, const double *x, const f77_int *incx, double *a, const f77_int *lda); + +BLIS_EXPORT_BLIS void dsyr(const char *uplo, const f77_int *n, const double *alpha, const double *x, const f77_int *incx, double *a, const f77_int *lda); + +BLIS_EXPORT_BLIS void DSYR_(const char *uplo, const f77_int *n, const double *alpha, const double *x, const f77_int *incx, double *a, const f77_int *lda); + + + +BLIS_EXPORT_BLIS void DSPR(const char *uplo, const f77_int *n, const double *alpha, const double *x, const f77_int *incx, double *ap); + +BLIS_EXPORT_BLIS void dspr(const char *uplo, const f77_int *n, const double *alpha, const double *x, const f77_int *incx, double *ap); + +BLIS_EXPORT_BLIS void DSPR_(const char *uplo, const f77_int *n, const double *alpha, const double *x, const f77_int *incx, double *ap); + + + +BLIS_EXPORT_BLIS void DSYR2(const char *uplo, const f77_int *n, const double *alpha, const double *x, const f77_int *incx, const double *y, const f77_int *incy, double *a, const f77_int *lda); + +BLIS_EXPORT_BLIS void dsyr2(const char *uplo, const f77_int *n, const double *alpha, const double *x, const f77_int *incx, const double *y, const f77_int *incy, double *a, const f77_int *lda); + +BLIS_EXPORT_BLIS void DSYR2_(const char *uplo, const f77_int *n, const double *alpha, const double *x, const f77_int *incx, const double *y, const f77_int *incy, double *a, const f77_int *lda); + + + +BLIS_EXPORT_BLIS void DSPR2(const char *uplo, const f77_int *n, const double *alpha, const double *x, const f77_int *incx, const double *y, const f77_int *incy, double *ap); + +BLIS_EXPORT_BLIS void dspr2(const 
char *uplo, const f77_int *n, const double *alpha, const double *x, const f77_int *incx, const double *y, const f77_int *incy, double *ap); + +BLIS_EXPORT_BLIS void DSPR2_(const char *uplo, const f77_int *n, const double *alpha, const double *x, const f77_int *incx, const double *y, const f77_int *incy, double *ap); + + + +BLIS_EXPORT_BLIS void CGEMV(const char *trans, const f77_int *m, const f77_int *n, const scomplex *alpha, const scomplex *a, const f77_int *lda, const scomplex *x, const f77_int *incx, const scomplex *beta, scomplex *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void cgemv(const char *trans, const f77_int *m, const f77_int *n, const scomplex *alpha, const scomplex *a, const f77_int *lda, const scomplex *x, const f77_int *incx, const scomplex *beta, scomplex *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void CGEMV_(const char *trans, const f77_int *m, const f77_int *n, const scomplex *alpha, const scomplex *a, const f77_int *lda, const scomplex *x, const f77_int *incx, const scomplex *beta, scomplex *y, const f77_int *incy); + + + +BLIS_EXPORT_BLIS void CGBMV(const char *trans, const f77_int *m, const f77_int *n, const f77_int *kl, const f77_int *ku, const scomplex *alpha, const scomplex *a, const f77_int *lda, const scomplex *x, const f77_int *incx, const scomplex *beta, scomplex *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void cgbmv(const char *trans, const f77_int *m, const f77_int *n, const f77_int *kl, const f77_int *ku, const scomplex *alpha, const scomplex *a, const f77_int *lda, const scomplex *x, const f77_int *incx, const scomplex *beta, scomplex *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void CGBMV_(const char *trans, const f77_int *m, const f77_int *n, const f77_int *kl, const f77_int *ku, const scomplex *alpha, const scomplex *a, const f77_int *lda, const scomplex *x, const f77_int *incx, const scomplex *beta, scomplex *y, const f77_int *incy); + + + +BLIS_EXPORT_BLIS void CHEMV(const char *uplo, const f77_int *n, const scomplex *alpha, const scomplex *a, const f77_int *lda, const scomplex *x, const f77_int *incx, const scomplex *beta, scomplex *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void chemv(const char *uplo, const f77_int *n, const scomplex *alpha, const scomplex *a, const f77_int *lda, const scomplex *x, const f77_int *incx, const scomplex *beta, scomplex *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void CHEMV_(const char *uplo, const f77_int *n, const scomplex *alpha, const scomplex *a, const f77_int *lda, const scomplex *x, const f77_int *incx, const scomplex *beta, scomplex *y, const f77_int *incy); + + + +BLIS_EXPORT_BLIS void CHBMV(const char *uplo, const f77_int *n, const f77_int *k, const scomplex *alpha, const scomplex *a, const f77_int *lda, const scomplex *x, const f77_int *incx, const scomplex *beta, scomplex *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void chbmv(const char *uplo, const f77_int *n, const f77_int *k, const scomplex *alpha, const scomplex *a, const f77_int *lda, const scomplex *x, const f77_int *incx, const scomplex *beta, scomplex *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void CHBMV_(const char *uplo, const f77_int *n, const f77_int *k, const scomplex *alpha, const scomplex *a,const f77_int *lda, const scomplex *x, const f77_int *incx, const scomplex *beta, scomplex *y, const f77_int *incy); + + + +BLIS_EXPORT_BLIS void CHPMV(const char *uplo, const f77_int *n, const scomplex *alpha, const scomplex *ap, const scomplex *x, const f77_int *incx, const scomplex *beta, scomplex *y, const f77_int *incy); + 
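Every routine in this compatibility layer is declared under three spellings with identical signatures (uppercase, lowercase, and uppercase with a trailing underscore), and all of them follow the Fortran-77 calling convention: every argument is passed by pointer and matrices are stored column-major with an explicit leading dimension. A minimal sketch of a call through this interface, assuming blis.h is on the include path and the program is linked against the BLIS library (e.g. -lblis):

/* Computes y := alpha*A*x + beta*y via the dgemv prototype declared above.
 * Assumptions: f77_int is provided by blis.h; link with -lblis. */
#include <stdio.h>
#include "blis.h"

int main(void)
{
    f77_int m = 2, n = 2, lda = 2, incx = 1, incy = 1;
    double alpha = 1.0, beta = 0.0;
    double A[] = { 1.0, 2.0, 3.0, 4.0 };  /* column-major: A = [1 3; 2 4] */
    double x[] = { 1.0, 1.0 };
    double y[] = { 0.0, 0.0 };

    /* "N" = no transpose; DGEMV and DGEMV_ are declared with the same
     * signature and could be used interchangeably here. */
    dgemv("N", &m, &n, &alpha, A, &lda, x, &incx, &beta, y, &incy);

    printf("y = [ %g %g ]\n", y[0], y[1]);  /* expected: [ 4 6 ] */
    return 0;
}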
+BLIS_EXPORT_BLIS void chpmv(const char *uplo, const f77_int *n, const scomplex *alpha, const scomplex *ap, const scomplex *x, const f77_int *incx, const scomplex *beta, scomplex *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void CHPMV_(const char *uplo, const f77_int *n, const scomplex *alpha, const scomplex *ap, const scomplex *x, const f77_int *incx, const scomplex *beta, scomplex *y, const f77_int *incy); + + + +BLIS_EXPORT_BLIS void CTRMV(const char *uplo, const char *trans, const char *diag, const f77_int *n, const scomplex *a, const f77_int *lda, scomplex *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void ctrmv(const char *uplo, const char *trans, const char *diag, const f77_int *n, const scomplex *a, const f77_int *lda, scomplex *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void CTRMV_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const scomplex *a, const f77_int *lda, scomplex *x, const f77_int *incx); + + + +BLIS_EXPORT_BLIS void CTBMV(const char *uplo, const char *trans, const char *diag, const f77_int *n, const f77_int *k, const scomplex *a, const f77_int *lda, scomplex *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void ctbmv(const char *uplo, const char *trans, const char *diag, const f77_int *n, const f77_int *k, const scomplex *a, const f77_int *lda, scomplex *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void CTBMV_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const f77_int *k, const scomplex *a, const f77_int *lda, scomplex *x, const f77_int *incx); + + + +BLIS_EXPORT_BLIS void CTPMV(const char *uplo, const char *trans, const char *diag, const f77_int *n, const scomplex *ap, scomplex *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void ctpmv(const char *uplo, const char *trans, const char *diag, const f77_int *n, const scomplex *ap, scomplex *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void CTPMV_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const scomplex *ap, scomplex *x, const f77_int *incx); + + + +BLIS_EXPORT_BLIS void CTRSV(const char *uplo, const char *trans, const char *diag, const f77_int *n, const scomplex *a, const f77_int *lda, scomplex *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void ctrsv(const char *uplo, const char *trans, const char *diag, const f77_int *n, const scomplex *a, const f77_int *lda, scomplex *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void CTRSV_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const scomplex *a, const f77_int *lda, scomplex *x, const f77_int *incx); + + + +BLIS_EXPORT_BLIS void CTBSV(const char *uplo, const char *trans, const char *diag, const f77_int *n, const f77_int *k, const scomplex *a, const f77_int *lda, scomplex *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void ctbsv(const char *uplo, const char *trans, const char *diag, const f77_int *n, const f77_int *k, const scomplex *a, const f77_int *lda, scomplex *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void CTBSV_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const f77_int *k, const scomplex *a, const f77_int *lda, scomplex *x, const f77_int *incx); + + + +BLIS_EXPORT_BLIS void CTPSV(const char *uplo, const char *trans, const char *diag, const f77_int *n, const scomplex *ap, scomplex *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void ctpsv(const char *uplo, const char *trans, const char *diag, const f77_int *n, const scomplex *ap, scomplex *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void CTPSV_(const char *uplo, const char *trans, 
const char *diag, const f77_int *n, const scomplex *ap, scomplex *x, const f77_int *incx); + + + +BLIS_EXPORT_BLIS void CGERC(const f77_int *m, const f77_int *n, const scomplex *alpha, const scomplex *x, const f77_int *incx, const scomplex *y, const f77_int *incy, scomplex *a, const f77_int *lda); + +BLIS_EXPORT_BLIS void cgerc(const f77_int *m, const f77_int *n, const scomplex *alpha, const scomplex *x, const f77_int *incx, const scomplex *y, const f77_int *incy, scomplex *a, const f77_int *lda); + +BLIS_EXPORT_BLIS void CGERC_(const f77_int *m, const f77_int *n, const scomplex *alpha, const scomplex *x, const f77_int *incx, const scomplex *y, const f77_int *incy, scomplex *a, const f77_int *lda); + + + +BLIS_EXPORT_BLIS void CGERU(const f77_int *m, const f77_int *n, const scomplex *alpha, const scomplex *x, const f77_int *incx, const scomplex *y, const f77_int *incy, scomplex *a, const f77_int *lda); + +BLIS_EXPORT_BLIS void cgeru(const f77_int *m, const f77_int *n, const scomplex *alpha, const scomplex *x, const f77_int *incx, const scomplex *y, const f77_int *incy, scomplex *a, const f77_int *lda); + +BLIS_EXPORT_BLIS void CGERU_(const f77_int *m, const f77_int *n, const scomplex *alpha, const scomplex *x, const f77_int *incx, const scomplex *y, const f77_int *incy, scomplex *a, const f77_int *lda); + + + +BLIS_EXPORT_BLIS void CHER(const char *uplo, const f77_int *n, const float *alpha, const scomplex *x, const f77_int *incx, scomplex *a, const f77_int *lda); + +BLIS_EXPORT_BLIS void cher(const char *uplo, const f77_int *n, const float *alpha, const scomplex *x, const f77_int *incx, scomplex *a, const f77_int *lda); + +BLIS_EXPORT_BLIS void CHER_(const char *uplo, const f77_int *n, const float *alpha, const scomplex *x, const f77_int *incx, scomplex *a, const f77_int *lda); + + + +BLIS_EXPORT_BLIS void CHPR(const char *uplo, const f77_int *n, const float *alpha, const scomplex *x, const f77_int *incx, scomplex *ap); + +BLIS_EXPORT_BLIS void chpr(const char *uplo, const f77_int *n, const float *alpha, const scomplex *x, const f77_int *incx, scomplex *ap); + +BLIS_EXPORT_BLIS void CHPR_(const char *uplo, const f77_int *n, const float *alpha, const scomplex *x, const f77_int *incx, scomplex *ap); + + + +BLIS_EXPORT_BLIS void CHER2(const char *uplo, const f77_int *n, const scomplex *alpha, const scomplex *x, const f77_int *incx, const scomplex *y, const f77_int *incy, scomplex *a, const f77_int *lda); + +BLIS_EXPORT_BLIS void cher2(const char *uplo, const f77_int *n, const scomplex *alpha, const scomplex *x, const f77_int *incx, const scomplex *y, const f77_int *incy, scomplex *a, const f77_int *lda); + +BLIS_EXPORT_BLIS void CHER2_(const char *uplo, const f77_int *n, const scomplex *alpha, const scomplex *x, const f77_int *incx, const scomplex *y, const f77_int *incy, scomplex *a, const f77_int *lda); + + + +BLIS_EXPORT_BLIS void CHPR2(const char *uplo, const f77_int *n, const scomplex *alpha, const scomplex *x, const f77_int *incx, const scomplex *y, const f77_int *incy, scomplex *ap); + +BLIS_EXPORT_BLIS void chpr2(const char *uplo, const f77_int *n, const scomplex *alpha, const scomplex *x, const f77_int *incx, const scomplex *y, const f77_int *incy, scomplex *ap); + +BLIS_EXPORT_BLIS void CHPR2_(const char *uplo, const f77_int *n, const scomplex *alpha, const scomplex *x, const f77_int *incx, const scomplex *y, const f77_int *incy, scomplex *ap); + + + +BLIS_EXPORT_BLIS void ZGEMV(const char *trans, const f77_int *m, const f77_int *n, const dcomplex *alpha, const dcomplex *a, const 
f77_int *lda, const dcomplex *x, const f77_int *incx, const dcomplex *beta, dcomplex *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void zgemv(const char *trans, const f77_int *m, const f77_int *n, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, const dcomplex *x, const f77_int *incx, const dcomplex *beta, dcomplex *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void ZGEMV_(const char *trans, const f77_int *m, const f77_int *n, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, const dcomplex *x, const f77_int *incx, const dcomplex *beta, dcomplex *y, const f77_int *incy); + + + +BLIS_EXPORT_BLIS void ZGBMV(const char *trans, const f77_int *m, const f77_int *n, const f77_int *kl, const f77_int *ku, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, const dcomplex *x, const f77_int *incx, const dcomplex *beta, dcomplex *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void zgbmv(const char *trans, const f77_int *m, const f77_int *n, const f77_int *kl, const f77_int *ku, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, const dcomplex *x, const f77_int *incx, const dcomplex *beta, dcomplex *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void ZGBMV_(const char *trans, const f77_int *m, const f77_int *n, const f77_int *kl, const f77_int *ku, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, const dcomplex *x, const f77_int *incx, const dcomplex *beta, dcomplex *y, const f77_int *incy); + + + +BLIS_EXPORT_BLIS void ZHEMV(const char *uplo, const f77_int *n, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, const dcomplex *x, const f77_int *incx, const dcomplex *beta, dcomplex *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void zhemv(const char *uplo, const f77_int *n, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, const dcomplex *x, const f77_int *incx, const dcomplex *beta, dcomplex *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void ZHEMV_(const char *uplo, const f77_int *n, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, const dcomplex *x, const f77_int *incx, const dcomplex *beta, dcomplex *y, const f77_int *incy); + + + +BLIS_EXPORT_BLIS void ZHBMV(const char *uplo, const f77_int *n, const f77_int *k, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, const dcomplex *x, const f77_int *incx, const dcomplex *beta, dcomplex *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void zhbmv(const char *uplo, const f77_int *n, const f77_int *k, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, const dcomplex *x, const f77_int *incx, const dcomplex *beta, dcomplex *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void ZHBMV_(const char *uplo, const f77_int *n, const f77_int *k, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, const dcomplex *x, const f77_int *incx, const dcomplex *beta, dcomplex *y, const f77_int *incy); + + + +BLIS_EXPORT_BLIS void ZHPMV(const char *uplo, const f77_int *n, const dcomplex *alpha, const dcomplex *ap, const dcomplex *x, const f77_int *incx, const dcomplex *beta, dcomplex *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void zhpmv(const char *uplo, const f77_int *n, const dcomplex *alpha, const dcomplex *ap, const dcomplex *x, const f77_int *incx, const dcomplex *beta, dcomplex *y, const f77_int *incy); + +BLIS_EXPORT_BLIS void ZHPMV_(const char *uplo, const f77_int *n, const dcomplex *alpha, const dcomplex *ap, const dcomplex *x, const f77_int *incx, const dcomplex *beta, dcomplex *y, const f77_int *incy); + + + +BLIS_EXPORT_BLIS void ZTRMV(const char *uplo, 
const char *trans, const char *diag, const f77_int *n, const dcomplex *a, const f77_int *lda, dcomplex *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void ztrmv(const char *uplo, const char *trans, const char *diag, const f77_int *n, const dcomplex *a, const f77_int *lda, dcomplex *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void ZTRMV_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const dcomplex *a, const f77_int *lda, dcomplex *x, const f77_int *incx); + + + +BLIS_EXPORT_BLIS void ZTBMV(const char *uplo, const char *trans, const char *diag, const f77_int *n, const f77_int *k, const dcomplex *a, const f77_int *lda, dcomplex *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void ztbmv(const char *uplo, const char *trans, const char *diag, const f77_int *n, const f77_int *k, const dcomplex *a, const f77_int *lda, dcomplex *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void ZTBMV_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const f77_int *k, const dcomplex *a, const f77_int *lda, dcomplex *x, const f77_int *incx); + + + +BLIS_EXPORT_BLIS void ZTPMV(const char *uplo, const char *trans, const char *diag, const f77_int *n, const dcomplex *ap, dcomplex *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void ztpmv(const char *uplo, const char *trans, const char *diag, const f77_int *n, const dcomplex *ap, dcomplex *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void ZTPMV_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const dcomplex *ap, dcomplex *x, const f77_int *incx); + + + +BLIS_EXPORT_BLIS void ZTRSV(const char *uplo, const char *trans, const char *diag, const f77_int *n, const dcomplex *a, const f77_int *lda, dcomplex *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void ztrsv(const char *uplo, const char *trans, const char *diag, const f77_int *n, const dcomplex *a, const f77_int *lda, dcomplex *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void ZTRSV_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const dcomplex *a, const f77_int *lda, dcomplex *x, const f77_int *incx); + + + +BLIS_EXPORT_BLIS void ZTBSV(const char *uplo, const char *trans, const char *diag, const f77_int *n, const f77_int *k, const dcomplex *a, const f77_int *lda, dcomplex *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void ztbsv(const char *uplo, const char *trans, const char *diag, const f77_int *n, const f77_int *k, const dcomplex *a, const f77_int *lda, dcomplex *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void ZTBSV_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const f77_int *k, const dcomplex *a, const f77_int *lda, dcomplex *x, const f77_int *incx); + + + +BLIS_EXPORT_BLIS void ZTPSV(const char *uplo, const char *trans, const char *diag, const f77_int *n, const dcomplex *ap, dcomplex *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void ztpsv(const char *uplo, const char *trans, const char *diag, const f77_int *n, const dcomplex *ap, dcomplex *x, const f77_int *incx); + +BLIS_EXPORT_BLIS void ZTPSV_(const char *uplo, const char *trans, const char *diag, const f77_int *n, const dcomplex *ap, dcomplex *x, const f77_int *incx); + + + +BLIS_EXPORT_BLIS void ZGERU(const f77_int *m, const f77_int *n, const dcomplex *alpha, const dcomplex *x, const f77_int *incx, const dcomplex *y, const f77_int *incy, dcomplex *a, const f77_int *lda); + +BLIS_EXPORT_BLIS void zgeru(const f77_int *m, const f77_int *n, const dcomplex *alpha, const dcomplex *x, const f77_int *incx, const dcomplex *y, const f77_int *incy, 
dcomplex *a, const f77_int *lda); + +BLIS_EXPORT_BLIS void ZGERU_(const f77_int *m, const f77_int *n, const dcomplex *alpha, const dcomplex *x, const f77_int *incx, const dcomplex *y, const f77_int *incy, dcomplex *a, const f77_int *lda); + + + +BLIS_EXPORT_BLIS void ZGERC(const f77_int *m, const f77_int *n, const dcomplex *alpha, const dcomplex *x, const f77_int *incx, const dcomplex *y, const f77_int *incy, dcomplex *a, const f77_int *lda); + +BLIS_EXPORT_BLIS void zgerc(const f77_int *m, const f77_int *n, const dcomplex *alpha, const dcomplex *x, const f77_int *incx, const dcomplex *y, const f77_int *incy, dcomplex *a, const f77_int *lda); + +BLIS_EXPORT_BLIS void ZGERC_(const f77_int *m, const f77_int *n, const dcomplex *alpha, const dcomplex *x, const f77_int *incx, const dcomplex *y, const f77_int *incy, dcomplex *a, const f77_int *lda); + + + +BLIS_EXPORT_BLIS void ZHER(const char *uplo, const f77_int *n, const double *alpha, const dcomplex *x, const f77_int *incx, dcomplex *a, const f77_int *lda); + +BLIS_EXPORT_BLIS void zher(const char *uplo, const f77_int *n, const double *alpha, const dcomplex *x, const f77_int *incx, dcomplex *a, const f77_int *lda); + +BLIS_EXPORT_BLIS void ZHER_(const char *uplo, const f77_int *n, const double *alpha, const dcomplex *x, const f77_int *incx, dcomplex *a, const f77_int *lda); + + + +BLIS_EXPORT_BLIS void ZHPR(const char *uplo, const f77_int *n, const bla_double *alpha, const dcomplex *x, const f77_int *incx, dcomplex *ap); + +BLIS_EXPORT_BLIS void zhpr(const char *uplo, const f77_int *n, const bla_double *alpha, const dcomplex *x, const f77_int *incx, dcomplex *ap); + +BLIS_EXPORT_BLIS void ZHPR_(const char *uplo, const f77_int *n, const bla_double *alpha, const dcomplex *x, const f77_int *incx, dcomplex *ap); + + + +BLIS_EXPORT_BLIS void ZHER2(const char *uplo, const f77_int *n, const dcomplex *alpha, const dcomplex *x, const f77_int *incx, const dcomplex *y, const f77_int *incy, dcomplex *a, const f77_int *lda); + +BLIS_EXPORT_BLIS void zher2(const char *uplo, const f77_int *n, const dcomplex *alpha, const dcomplex *x, const f77_int *incx, const dcomplex *y, const f77_int *incy, dcomplex *a, const f77_int *lda); + +BLIS_EXPORT_BLIS void ZHER2_(const char *uplo, const f77_int *n, const dcomplex *alpha, const dcomplex *x, const f77_int *incx, const dcomplex *y, const f77_int *incy, dcomplex *a, const f77_int *lda); + + + +BLIS_EXPORT_BLIS void ZHPR2(const char *uplo, const f77_int *n, const dcomplex *alpha, const dcomplex *x, const f77_int *incx, const dcomplex *y, const f77_int *incy, dcomplex *ap); + +BLIS_EXPORT_BLIS void zhpr2(const char *uplo, const f77_int *n, const dcomplex *alpha, const dcomplex *x, const f77_int *incx, const dcomplex *y, const f77_int *incy, dcomplex *ap); + +BLIS_EXPORT_BLIS void ZHPR2_(const char *uplo, const f77_int *n, const dcomplex *alpha, const dcomplex *x, const f77_int *incx, const dcomplex *y, const f77_int *incy, dcomplex *ap); + + + +//Level 3 APIs +BLIS_EXPORT_BLIS void SGEMM(const char *transa, const char *transb, const f77_int *m, const f77_int *n, const f77_int *k, const float *alpha, const float *a, const f77_int *lda, const float *b, const f77_int *ldb, const float *beta, float *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void sgemm(const char *transa, const char *transb, const f77_int *m, const f77_int *n, const f77_int *k, const float *alpha, const float *a, const f77_int *lda, const float *b, const f77_int *ldb, const float *beta, float *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void 
SGEMM_(const char *transa, const char *transb, const f77_int *m, const f77_int *n, const f77_int *k, const float *alpha, const float *a, const f77_int *lda, const float *b, const f77_int *ldb, const float *beta, float *c, const f77_int *ldc); + + + +BLIS_EXPORT_BLIS void SSYMM(const char *side, const char *uplo, const f77_int *m, const f77_int *n, const float *alpha, const float *a, const f77_int *lda, const float *b, const f77_int *ldb, const float *beta, float *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void ssymm(const char *side, const char *uplo, const f77_int *m, const f77_int *n, const float *alpha, const float *a, const f77_int *lda, const float *b, const f77_int *ldb, const float *beta, float *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void SSYMM_(const char *side, const char *uplo, const f77_int *m, const f77_int *n, const float *alpha, const float *a, const f77_int *lda, const float *b, const f77_int *ldb, const float *beta, float *c, const f77_int *ldc); + + + +BLIS_EXPORT_BLIS void SSYRK(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const float *alpha, const float *a, const f77_int *lda, const float *beta, float *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void ssyrk(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const float *alpha, const float *a, const f77_int *lda, const float *beta, float *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void SSYRK_(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const float *alpha, const float *a, const f77_int *lda, const float *beta, float *c, const f77_int *ldc); + + + +BLIS_EXPORT_BLIS void SSYR2K(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const float *alpha, const float *a, const f77_int *lda, const float *b, const f77_int *ldb, const float *beta, float *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void ssyr2k(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const float *alpha, const float *a, const f77_int *lda, const float *b, const f77_int *ldb, const float *beta, float *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void SSYR2K_(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const float *alpha, const float *a, const f77_int *lda, const float *b, const f77_int *ldb, const float *beta, float *c, const f77_int *ldc); + + + +BLIS_EXPORT_BLIS void STRMM(const char *side, const char *uplo, const char *transa, const char *diag, const f77_int *m, const f77_int *n, const float *alpha, const float *a, const f77_int *lda, float *b, const f77_int *ldb); + +BLIS_EXPORT_BLIS void strmm(const char *side, const char *uplo, const char *transa, const char *diag, const f77_int *m, const f77_int *n, const float *alpha, const float *a, const f77_int *lda, float *b, const f77_int *ldb); + +BLIS_EXPORT_BLIS void STRMM_(const char *side, const char *uplo, const char *transa, const char *diag, const f77_int *m, const f77_int *n, const float *alpha, const float *a, const f77_int *lda, float *b, const f77_int *ldb); + + + +BLIS_EXPORT_BLIS void STRSM(const char *side, const char *uplo, const char *transa, const char *diag, const f77_int *m, const f77_int *n, const float *alpha, const float *a, const f77_int *lda, float *b, const f77_int *ldb); + +BLIS_EXPORT_BLIS void strsm(const char *side, const char *uplo, const char *transa, const char *diag, const f77_int *m, const f77_int *n, const float *alpha, const float *a, const f77_int *lda, float *b, const f77_int *ldb); + +BLIS_EXPORT_BLIS void STRSM_(const 
char *side, const char *uplo, const char *transa, const char *diag, const f77_int *m, const f77_int *n, const float *alpha, const float *a, const f77_int *lda, float *b, const f77_int *ldb); + + + +BLIS_EXPORT_BLIS void DGEMM(const char *transa, const char *transb, const f77_int *m, const f77_int *n, const f77_int *k, const double *alpha, const double *a, const f77_int *lda, const double *b, const f77_int *ldb, const double *beta, double *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void dgemm(const char *transa, const char *transb, const f77_int *m, const f77_int *n, const f77_int *k, const double *alpha, const double *a, const f77_int *lda, const double *b, const f77_int *ldb, const double *beta, double *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void DGEMM_(const char *transa, const char *transb, const f77_int *m, const f77_int *n, const f77_int *k, const double *alpha, const double *a, const f77_int *lda, const double *b, const f77_int *ldb, const double *beta, double *c, const f77_int *ldc); + + + +BLIS_EXPORT_BLIS void DSYMM(const char *side, const char *uplo, const f77_int *m, const f77_int *n, const double *alpha, const double *a, const f77_int *lda, const double *b, const f77_int *ldb, const double *beta, double *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void dsymm(const char *side, const char *uplo, const f77_int *m, const f77_int *n, const double *alpha, const double *a, const f77_int *lda, const double *b, const f77_int *ldb, const double *beta, double *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void DSYMM_(const char *side, const char *uplo, const f77_int *m, const f77_int *n, const double *alpha, const double *a, const f77_int *lda, const double *b, const f77_int *ldb, const double *beta, double *c, const f77_int *ldc); + + + +BLIS_EXPORT_BLIS void DSYRK(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const double *alpha, const double *a, const f77_int *lda, const double *beta, double *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void dsyrk(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const double *alpha, const double *a, const f77_int *lda, const double *beta, double *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void DSYRK_(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const double *alpha, const double *a, const f77_int *lda, const double *beta, double *c, const f77_int *ldc); + + + +BLIS_EXPORT_BLIS void DSYR2K(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const double *alpha, const double *a, const f77_int *lda, const double *b, const f77_int *ldb, const double *beta, double *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void dsyr2k(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const double *alpha, const double *a, const f77_int *lda, const double *b, const f77_int *ldb, const double *beta, double *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void DSYR2K_(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const double *alpha, const double *a, const f77_int *lda, const double *b, const f77_int *ldb, const double *beta, double *c, const f77_int *ldc); + + + +BLIS_EXPORT_BLIS void DTRMM(const char *side, const char *uplo, const char *transa, const char *diag, const f77_int *m, const f77_int *n, const double *alpha, const double *a, const f77_int *lda, double *b, const f77_int *ldb); + +BLIS_EXPORT_BLIS void dtrmm(const char *side, const char *uplo, const char *transa, const char *diag, const f77_int *m, const f77_int *n, 
const double *alpha, const double *a, const f77_int *lda, double *b, const f77_int *ldb); + +BLIS_EXPORT_BLIS void DTRMM_(const char *side, const char *uplo, const char *transa, const char *diag, const f77_int *m, const f77_int *n, const double *alpha, const double *a, const f77_int *lda, double *b, const f77_int *ldb); + + + +BLIS_EXPORT_BLIS void DTRSM(const char *side, const char *uplo, const char *transa, const char *diag, const f77_int *m, const f77_int *n, const double *alpha, const double *a, const f77_int *lda, double *b, const f77_int *ldb); + +BLIS_EXPORT_BLIS void dtrsm(const char *side, const char *uplo, const char *transa, const char *diag, const f77_int *m, const f77_int *n, const double *alpha, const double *a, const f77_int *lda, double *b, const f77_int *ldb); + +BLIS_EXPORT_BLIS void DTRSM_(const char *side, const char *uplo, const char *transa, const char *diag, const f77_int *m, const f77_int *n, const double *alpha, const double *a, const f77_int *lda, double *b, const f77_int *ldb); + + + +BLIS_EXPORT_BLIS void CGEMM(const char *transa, const char *transb, const f77_int *m, const f77_int *n, const f77_int *k, const scomplex *alpha, const scomplex *a, const f77_int *lda, const scomplex *b, const f77_int *ldb, const scomplex *beta, scomplex *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void cgemm(const char *transa, const char *transb, const f77_int *m, const f77_int *n, const f77_int *k, const scomplex *alpha, const scomplex *a, const f77_int *lda, const scomplex *b, const f77_int *ldb, const scomplex *beta, scomplex *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void CGEMM_(const char *transa, const char *transb, const f77_int *m, const f77_int *n, const f77_int *k, const scomplex *alpha, const scomplex *a, const f77_int *lda, const scomplex *b, const f77_int *ldb, const scomplex *beta, scomplex *c, const f77_int *ldc); + + + +BLIS_EXPORT_BLIS void CSYMM(const char *side, const char *uplo, const f77_int *m, const f77_int *n, const scomplex *alpha, const scomplex *a, const f77_int *lda, const scomplex *b, const f77_int *ldb, const scomplex *beta, scomplex *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void csymm(const char *side, const char *uplo, const f77_int *m, const f77_int *n, const scomplex *alpha, const scomplex *a, const f77_int *lda, const scomplex *b, const f77_int *ldb, const scomplex *beta, scomplex *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void CSYMM_(const char *side, const char *uplo, const f77_int *m, const f77_int *n, const scomplex *alpha, const scomplex *a, const f77_int *lda, const scomplex *b, const f77_int *ldb, const scomplex *beta, scomplex *c, const f77_int *ldc); + + + +BLIS_EXPORT_BLIS void CHEMM(const char *side, const char *uplo, const f77_int *m, const f77_int *n, const scomplex *alpha, const scomplex *a, const f77_int *lda, const scomplex *b, const f77_int *ldb, const scomplex *beta, scomplex *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void chemm(const char *side, const char *uplo, const f77_int *m, const f77_int *n, const scomplex *alpha, const scomplex *a, const f77_int *lda, const scomplex *b, const f77_int *ldb, const scomplex *beta, scomplex *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void CHEMM_(const char *side, const char *uplo, const f77_int *m, const f77_int *n, const scomplex *alpha, const scomplex *a, const f77_int *lda, const scomplex *b, const f77_int *ldb, const scomplex *beta, scomplex *c, const f77_int *ldc); + + + +BLIS_EXPORT_BLIS void CSYRK(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const 
scomplex *alpha, const scomplex *a, const f77_int *lda, const scomplex *beta, scomplex *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void csyrk(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const scomplex *alpha, const scomplex *a, const f77_int *lda, const scomplex *beta, scomplex *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void CSYRK_(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const scomplex *alpha, const scomplex *a, const f77_int *lda, const scomplex *beta, scomplex *c, const f77_int *ldc); + + + +BLIS_EXPORT_BLIS void CHERK(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const float *alpha, const scomplex *a, const f77_int *lda, const float *beta, scomplex *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void cherk(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const float *alpha, const scomplex *a, const f77_int *lda, const float *beta, scomplex *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void CHERK_(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const float *alpha, const scomplex *a, const f77_int *lda, const float *beta, scomplex *c, const f77_int *ldc); + + + +BLIS_EXPORT_BLIS void CSYR2K(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const scomplex *alpha, const scomplex *a, const f77_int *lda, const scomplex *b, const f77_int *ldb, const scomplex *beta, scomplex *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void csyr2k(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const scomplex *alpha, const scomplex *a, const f77_int *lda, const scomplex *b, const f77_int *ldb, const scomplex *beta, scomplex *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void CSYR2K_(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const scomplex *alpha, const scomplex *a, const f77_int *lda, const scomplex *b, const f77_int *ldb, const scomplex *beta, scomplex *c, const f77_int *ldc); + + + +BLIS_EXPORT_BLIS void CHER2K(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const scomplex *alpha, const scomplex *a, const f77_int *lda, const scomplex *b, const f77_int *ldb, const float *beta, scomplex *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void cher2k(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const scomplex *alpha, const scomplex *a, const f77_int *lda, const scomplex *b, const f77_int *ldb, const float *beta, scomplex *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void CHER2K_(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const scomplex *alpha, const scomplex *a, const f77_int *lda, const scomplex *b, const f77_int *ldb, const float *beta, scomplex *c, const f77_int *ldc); + + + +BLIS_EXPORT_BLIS void CTRMM(const char *side, const char *uplo, const char *transa, const char *diag, const f77_int *m, const f77_int *n, const scomplex *alpha, const scomplex *a, const f77_int *lda, scomplex *b, const f77_int *ldb); + +BLIS_EXPORT_BLIS void ctrmm(const char *side, const char *uplo, const char *transa, const char *diag, const f77_int *m, const f77_int *n, const scomplex *alpha, const scomplex *a, const f77_int *lda, scomplex *b, const f77_int *ldb); + +BLIS_EXPORT_BLIS void CTRMM_(const char *side, const char *uplo, const char *transa, const char *diag, const f77_int *m, const f77_int *n, const scomplex *alpha, const scomplex *a, const f77_int *lda, scomplex *b, const f77_int *ldb); + + + +BLIS_EXPORT_BLIS void CTRSM(const char *side, 
const char *uplo, const char *transa, const char *diag, const f77_int *m, const f77_int *n, const scomplex *alpha, const scomplex *a, const f77_int *lda, scomplex *b, const f77_int *ldb); + +BLIS_EXPORT_BLIS void ctrsm(const char *side, const char *uplo, const char *transa, const char *diag, const f77_int *m, const f77_int *n, const scomplex *alpha, const scomplex *a, const f77_int *lda, scomplex *b, const f77_int *ldb); + +BLIS_EXPORT_BLIS void CTRSM_(const char *side, const char *uplo, const char *transa, const char *diag, const f77_int *m, const f77_int *n, const scomplex *alpha, const scomplex *a, const f77_int *lda, scomplex *b, const f77_int *ldb); + + + +BLIS_EXPORT_BLIS void ZGEMM(const char *transa, const char *transb, const f77_int *m, const f77_int *n, const f77_int *k, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, const dcomplex *b, const f77_int *ldb, const dcomplex *beta, dcomplex *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void zgemm(const char *transa, const char *transb, const f77_int *m, const f77_int *n, const f77_int *k, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, const dcomplex *b, const f77_int *ldb, const dcomplex *beta, dcomplex *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void ZGEMM_(const char *transa, const char *transb, const f77_int *m, const f77_int *n, const f77_int *k, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, const dcomplex *b, const f77_int *ldb, const dcomplex *beta, dcomplex *c, const f77_int *ldc); + + + +BLIS_EXPORT_BLIS void ZSYMM(const char *side, const char *uplo, const f77_int *m, const f77_int *n, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, const dcomplex *b, const f77_int *ldb, const dcomplex *beta, dcomplex *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void zsymm(const char *side, const char *uplo, const f77_int *m, const f77_int *n, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, const dcomplex *b, const f77_int *ldb, const dcomplex *beta, dcomplex *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void ZSYMM_(const char *side, const char *uplo, const f77_int *m, const f77_int *n, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, const dcomplex *b, const f77_int *ldb, const dcomplex *beta, dcomplex *c, const f77_int *ldc); + + + +BLIS_EXPORT_BLIS void ZHEMM(const char *side, const char *uplo, const f77_int *m, const f77_int *n, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, const dcomplex *b, const f77_int *ldb, const dcomplex *beta, dcomplex *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void zhemm(const char *side, const char *uplo, const f77_int *m, const f77_int *n, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, const dcomplex *b, const f77_int *ldb, const dcomplex *beta, dcomplex *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void ZHEMM_(const char *side, const char *uplo, const f77_int *m, const f77_int *n, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, const dcomplex *b, const f77_int *ldb, const dcomplex *beta, dcomplex *c, const f77_int *ldc); + + + +BLIS_EXPORT_BLIS void ZSYRK(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, const dcomplex *beta, dcomplex *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void zsyrk(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, const dcomplex *beta, dcomplex *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS 
void ZSYRK_(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, const dcomplex *beta, dcomplex *c, const f77_int *ldc); + + + +BLIS_EXPORT_BLIS void ZHERK(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const double *alpha, const dcomplex *a, const f77_int *lda, const double *beta, dcomplex *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void zherk(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const double *alpha, const dcomplex *a, const f77_int *lda, const double *beta, dcomplex *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void ZHERK_(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const double *alpha, const dcomplex *a, const f77_int *lda, const double *beta, dcomplex *c, const f77_int *ldc); + + + +BLIS_EXPORT_BLIS void ZSYR2K(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, const dcomplex *b, const f77_int *ldb, const dcomplex *beta, dcomplex *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void zsyr2k(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, const dcomplex *b, const f77_int *ldb, const dcomplex *beta, dcomplex *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void ZSYR2K_(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, const dcomplex *b, const f77_int *ldb, const dcomplex *beta, dcomplex *c, const f77_int *ldc); + + + +BLIS_EXPORT_BLIS void ZHER2K(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, const dcomplex *b, const f77_int *ldb, const double *beta, dcomplex *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void zher2k(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, const dcomplex *b, const f77_int *ldb, const double *beta, dcomplex *c, const f77_int *ldc); + +BLIS_EXPORT_BLIS void ZHER2K_(const char *uplo, const char *trans, const f77_int *n, const f77_int *k, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, const dcomplex *b, const f77_int *ldb, const double *beta, dcomplex *c, const f77_int *ldc); + + + +BLIS_EXPORT_BLIS void ZTRMM(const char *side, const char *uplo, const char *transa, const char *diag, const f77_int *m, const f77_int *n, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, dcomplex *b, const f77_int *ldb); + +BLIS_EXPORT_BLIS void ztrmm(const char *side, const char *uplo, const char *transa, const char *diag, const f77_int *m, const f77_int *n, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, dcomplex *b, const f77_int *ldb); + +BLIS_EXPORT_BLIS void ZTRMM_(const char *side, const char *uplo, const char *transa, const char *diag, const f77_int *m, const f77_int *n, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, dcomplex *b, const f77_int *ldb); + + + +BLIS_EXPORT_BLIS void ZTRSM(const char *side, const char *uplo, const char *transa, const char *diag, const f77_int *m, const f77_int *n, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, dcomplex *b, const f77_int *ldb); + +BLIS_EXPORT_BLIS void ztrsm(const char *side, const char *uplo, const char *transa, const char *diag, const f77_int *m, const f77_int *n, const dcomplex 
*alpha, const dcomplex *a, const f77_int *lda, dcomplex *b, const f77_int *ldb); + +BLIS_EXPORT_BLIS void ZTRSM_(const char *side, const char *uplo, const char *transa, const char *diag, const f77_int *m, const f77_int *n, const dcomplex *alpha, const dcomplex *a, const f77_int *lda, dcomplex *b, const f77_int *ldb); + + + +// Miscellaneous APIs +BLIS_EXPORT_BLIS void CDOTCSUB( const f77_int* n, const scomplex* x, const f77_int* incx, const scomplex* y, const f77_int* incy, scomplex* rval); + +BLIS_EXPORT_BLIS void cdotcsub( const f77_int* n, const scomplex* x, const f77_int* incx, const scomplex* y, const f77_int* incy, scomplex* rval); + +BLIS_EXPORT_BLIS void CDOTCSUB_( const f77_int* n, const scomplex* x, const f77_int* incx, const scomplex* y, const f77_int* incy, scomplex* rval); + + + +BLIS_EXPORT_BLIS void CDOTUSUB( const f77_int* n, const scomplex* x, const f77_int* incxy, const scomplex* y, const f77_int* incy, scomplex* rval); + +BLIS_EXPORT_BLIS void cdotusub( const f77_int* n, const scomplex* x, const f77_int* incxy, const scomplex* y, const f77_int* incy, scomplex* rval); + +BLIS_EXPORT_BLIS void CDOTUSUB_( const f77_int* n, const scomplex* x, const f77_int* incxy, const scomplex* y, const f77_int* incy, scomplex* rval); + + + +BLIS_EXPORT_BLIS void DASUMSUB(const f77_int* n, const double* x, const f77_int* incx, double* rval); + +BLIS_EXPORT_BLIS void dasumsub(const f77_int* n, const double* x, const f77_int* incx, double* rval); + +BLIS_EXPORT_BLIS void DASUMSUB_(const f77_int* n, const double* x, const f77_int* incx, double* rval); + + + +BLIS_EXPORT_BLIS void DDOTSUB(const f77_int* n, const double* x, const f77_int* incx, const double* y, const f77_int* incy, double* rval); + +BLIS_EXPORT_BLIS void ddotsub(const f77_int* n, const double* x, const f77_int* incx, const double* y, const f77_int* incy, double* rval); + +BLIS_EXPORT_BLIS void DDOTSUB_(const f77_int* n, const double* x, const f77_int* incx, const double* y, const f77_int* incy, double* rval); + + + +BLIS_EXPORT_BLIS void DNRM2SUB(const f77_int* n, const double* x, const f77_int* incx, double *rval); + +BLIS_EXPORT_BLIS void dnrm2sub(const f77_int* n, const double* x, const f77_int* incx, double *rval); + +BLIS_EXPORT_BLIS void DNRM2SUB_(const f77_int* n, const double* x, const f77_int* incx, double *rval); + + + +BLIS_EXPORT_BLIS void DZASUMSUB(const f77_int* n, const dcomplex* x, const f77_int* incx, double* rval); + +BLIS_EXPORT_BLIS void dzasumsub(const f77_int* n, const dcomplex* x, const f77_int* incx, double* rval); + +BLIS_EXPORT_BLIS void DZASUMSUB_(const f77_int* n, const dcomplex* x, const f77_int* incx, double* rval); + + + +BLIS_EXPORT_BLIS void DZNRM2SUB(const f77_int* n, const dcomplex* x, const f77_int* incx, double* rval); + +BLIS_EXPORT_BLIS void dznrm2sub(const f77_int* n, const dcomplex* x, const f77_int* incx, double* rval); + +BLIS_EXPORT_BLIS void DZNRM2SUB_(const f77_int* n, const dcomplex* x, const f77_int* incx, double* rval); + + + +BLIS_EXPORT_BLIS void ICAMAXSUB(const f77_int* n, const scomplex* x, const f77_int* incx, f77_int* rval); + +BLIS_EXPORT_BLIS void icamaxsub(const f77_int* n, const scomplex* x, const f77_int* incx, f77_int* rval); + +BLIS_EXPORT_BLIS void ICAMAXSUB_(const f77_int* n, const scomplex* x, const f77_int* incx, f77_int* rval); + + + +BLIS_EXPORT_BLIS void ICAMINSUB( const f77_int* n, const scomplex* x, const f77_int* incx, f77_int* rval); + +BLIS_EXPORT_BLIS void icaminsub( const f77_int* n, const scomplex* x, const f77_int* incx, f77_int* rval); + 
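The *SUB entry points in this miscellaneous group are subroutine-style wrappers around BLAS functions that return a value (dot products, norms, absolute-value sums, and index-of-extreme searches); as the prototypes above show, the result is written through the trailing output pointer rather than returned. A small sketch using ddotsub, under the same assumptions as the earlier example (blis.h on the include path, linked against BLIS):

/* dot := x . y computed through the ddotsub wrapper declared above. */
#include <stdio.h>
#include "blis.h"

int main(void)
{
    f77_int n = 3, incx = 1, incy = 1;
    double x[] = { 1.0, 2.0, 3.0 };
    double y[] = { 4.0, 5.0, 6.0 };
    double dot = 0.0;

    /* Result is stored in dot: 1*4 + 2*5 + 3*6 = 32. */
    ddotsub(&n, x, &incx, y, &incy, &dot);

    printf("dot = %g\n", dot);
    return 0;
}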
+BLIS_EXPORT_BLIS void ICAMINSUB_( const f77_int* n, const scomplex* x, const f77_int* incx, f77_int* rval); + + + +BLIS_EXPORT_BLIS void IDAMAXSUB( const f77_int* n, const double* x, const f77_int* incx, f77_int* rval); + +BLIS_EXPORT_BLIS void idamaxsub( const f77_int* n, const double* x, const f77_int* incx, f77_int* rval); + +BLIS_EXPORT_BLIS void IDAMAXSUB_( const f77_int* n, const double* x, const f77_int* incx, f77_int* rval); + + + +BLIS_EXPORT_BLIS void IDAMINSUB(const f77_int* n, const double* x, const f77_int* incx, f77_int* rval); + +BLIS_EXPORT_BLIS void idaminsub(const f77_int* n, const double* x, const f77_int* incx, f77_int* rval); + +BLIS_EXPORT_BLIS void IDAMINSUB_(const f77_int* n, const double* x, const f77_int* incx, f77_int* rval); + + + +BLIS_EXPORT_BLIS void ISAMAXSUB( const f77_int* n, const float* x, const f77_int* incx, f77_int* rval); + +BLIS_EXPORT_BLIS void isamaxsub( const f77_int* n, const float* x, const f77_int* incx, f77_int* rval); + +BLIS_EXPORT_BLIS void ISAMAXSUB_( const f77_int* n, const float* x, const f77_int* incx, f77_int* rval); + + + +BLIS_EXPORT_BLIS void ISAMINSUB( const f77_int* n, const float* x, const f77_int* incx, f77_int* rval); + +BLIS_EXPORT_BLIS void isaminsub( const f77_int* n, const float* x, const f77_int* incx, f77_int* rval); + +BLIS_EXPORT_BLIS void ISAMINSUB_( const f77_int* n, const float* x, const f77_int* incx, f77_int* rval); + + + +BLIS_EXPORT_BLIS void IZAMINSUB( const f77_int* n, const dcomplex* x, const f77_int* incx, f77_int* rval); + +BLIS_EXPORT_BLIS void izaminsub( const f77_int* n, const dcomplex* x, const f77_int* incx, f77_int* rval); + +BLIS_EXPORT_BLIS void IZAMINSUB_( const f77_int* n, const dcomplex* x, const f77_int* incx, f77_int* rval); + + + +BLIS_EXPORT_BLIS void IZAMAXSUB( const f77_int* n, const dcomplex* x, const f77_int* incx, f77_int* rval); + +BLIS_EXPORT_BLIS void izamaxsub( const f77_int* n, const dcomplex* x, const f77_int* incx, f77_int* rval); + +BLIS_EXPORT_BLIS void IZAMAXSUB_( const f77_int* n, const dcomplex* x, const f77_int* incx, f77_int* rval); + + + +BLIS_EXPORT_BLIS void SASUMSUB( const f77_int* n, const float* x, const f77_int* incx, float* rval); + +BLIS_EXPORT_BLIS void sasumsub( const f77_int* n, const float* x, const f77_int* incx, float* rval); + +BLIS_EXPORT_BLIS void SASUMSUB_( const f77_int* n, const float* x, const f77_int* incx, float* rval); + + + +BLIS_EXPORT_BLIS void SCASUMSUB( const f77_int* n, const scomplex* x, const f77_int* incx, float* rval); + +BLIS_EXPORT_BLIS void scasumsub( const f77_int* n, const scomplex* x, const f77_int* incx, float* rval); + +BLIS_EXPORT_BLIS void SCASUMSUB_( const f77_int* n, const scomplex* x, const f77_int* incx, float* rval); + + + +BLIS_EXPORT_BLIS void SCNRM2SUB( const f77_int* n, const scomplex* x, const f77_int* incx, float* rval); + +BLIS_EXPORT_BLIS void scnrm2sub( const f77_int* n, const scomplex* x, const f77_int* incx, float* rval); + +BLIS_EXPORT_BLIS void SCNRM2SUB_( const f77_int* n, const scomplex* x, const f77_int* incx, float* rval); + + + +BLIS_EXPORT_BLIS void SDOTSUB( const f77_int* n, const float* x, const f77_int* incx, const float* y, const f77_int* incy, float* rval); + +BLIS_EXPORT_BLIS void sdotsub( const f77_int* n, const float* x, const f77_int* incx, const float* y, const f77_int* incy, float* rval); + +BLIS_EXPORT_BLIS void SDOTSUB_( const f77_int* n, const float* x, const f77_int* incx, const float* y, const f77_int* incy, float* rval); + + + +BLIS_EXPORT_BLIS void SNRM2SUB( const f77_int* n, const 
float* x, const f77_int* incx, float *rval); + +BLIS_EXPORT_BLIS void snrm2sub( const f77_int* n, const float* x, const f77_int* incx, float *rval); + +BLIS_EXPORT_BLIS void SNRM2SUB_( const f77_int* n, const float* x, const f77_int* incx, float *rval); + + + +BLIS_EXPORT_BLIS void ZDOTCSUB( const f77_int* n, const dcomplex* x, const f77_int* incx, const dcomplex* y, const f77_int* incy, dcomplex* rval); + +BLIS_EXPORT_BLIS void zdotcsub( const f77_int* n, const dcomplex* x, const f77_int* incx, const dcomplex* y, const f77_int* incy, dcomplex* rval); + +BLIS_EXPORT_BLIS void ZDOTCSUB_( const f77_int* n, const dcomplex* x, const f77_int* incx, const dcomplex* y, const f77_int* incy, dcomplex* rval); + + + +BLIS_EXPORT_BLIS void ZDOTUSUB( const f77_int* n, const dcomplex* x, const f77_int* incx, const dcomplex* y, const f77_int* incy, dcomplex* rval); + +BLIS_EXPORT_BLIS void zdotusub( const f77_int* n, const dcomplex* x, const f77_int* incx, const dcomplex* y, const f77_int* incy, dcomplex* rval); + +BLIS_EXPORT_BLIS void ZDOTUSUB_( const f77_int* n, const dcomplex* x, const f77_int* incx, const dcomplex* y, const f77_int* incy, dcomplex* rval); + + + +BLIS_EXPORT_BLIS void SDSDOTSUB( const f77_int* n, float* sb, const float* x, const f77_int* incx, const float* y, const f77_int* incy, float* dot); + +BLIS_EXPORT_BLIS void sdsdotsub( const f77_int* n, float* sb, const float* x, const f77_int* incx, const float* y, const f77_int* incy, float* dot); + +BLIS_EXPORT_BLIS void SDSDOTSUB_( const f77_int* n, float* sb, const float* x, const f77_int* incx, const float* y, const f77_int* incy, float* dot); + + + +BLIS_EXPORT_BLIS void DSDOTSUB( const f77_int* n, const float* x, const f77_int* incx, const float* y, const f77_int* incy, double* dot); + +BLIS_EXPORT_BLIS void dsdotsub( const f77_int* n, const float* x, const f77_int* incx, const float* y, const f77_int* incy, double* dot); + +BLIS_EXPORT_BLIS void DSDOTSUB_( const f77_int* n, const float* x, const f77_int* incx, const float* y, const f77_int* incy, double* dot); + + + +BLIS_EXPORT_BLIS f77_int LSAME(const char *ca, const char *cb, const f77_int a, const f77_int b); + +BLIS_EXPORT_BLIS f77_int lsame(const char *ca, const char *cb, const f77_int a, const f77_int b); + +BLIS_EXPORT_BLIS f77_int LSAME_(const char *ca, const char *cb, const f77_int a, const f77_int b); + + + +BLIS_EXPORT_BLIS int XERBLA(const char *srname, const f77_int *info, ftnlen n); + +BLIS_EXPORT_BLIS int xerbla(const char *srname, const f77_int *info, ftnlen n); + +BLIS_EXPORT_BLIS int XERBLA_(const char *srname, const f77_int *info, ftnlen n); + + + +//Auxiliary APIs +BLIS_EXPORT_BLIS double DCABS1(bla_dcomplex *z); + +BLIS_EXPORT_BLIS double dcabs1(bla_dcomplex *z); + +BLIS_EXPORT_BLIS double DCABS1_(bla_dcomplex *z); + + + +BLIS_EXPORT_BLIS float SCABS1(bla_scomplex* z); + +BLIS_EXPORT_BLIS float scabs1(bla_scomplex* z); + +BLIS_EXPORT_BLIS float SCABS1_(bla_scomplex* z); + + + +//BLAS Extension APIs +BLIS_EXPORT_BLIS void CAXPBY( const f77_int* n, const scomplex* alpha, const scomplex *x, const f77_int* incx, const scomplex* beta, scomplex *y, const f77_int* incy); + +BLIS_EXPORT_BLIS void caxpby( const f77_int* n, const scomplex* alpha, const scomplex *x, const f77_int* incx, const scomplex* beta, scomplex *y, const f77_int* incy); + +BLIS_EXPORT_BLIS void CAXPBY_( const f77_int* n, const scomplex* alpha, const scomplex *x, const f77_int* incx, const scomplex* beta, scomplex *y, const f77_int* incy); + + + +BLIS_EXPORT_BLIS void CGEMM3M( const f77_char* transa, 
const f77_char* transb, const f77_int* m, const f77_int* n, const f77_int* k, const scomplex* alpha, const scomplex* a, const f77_int* lda, const scomplex* b, const f77_int* ldb, const scomplex* beta, scomplex* c, const f77_int* ldc); + +BLIS_EXPORT_BLIS void cgemm3m( const f77_char* transa, const f77_char* transb, const f77_int* m, const f77_int* n, const f77_int* k, const scomplex* alpha, const scomplex* a, const f77_int* lda, const scomplex* b, const f77_int* ldb, const scomplex* beta, scomplex* c, const f77_int* ldc); + +BLIS_EXPORT_BLIS void CGEMM3M_( const f77_char* transa, const f77_char* transb, const f77_int* m, const f77_int* n, const f77_int* k, const scomplex* alpha, const scomplex* a, const f77_int* lda, const scomplex* b, const f77_int* ldb, const scomplex* beta, scomplex* c, const f77_int* ldc); + + + +BLIS_EXPORT_BLIS void CGEMM_BATCH( const f77_char* transa_array, const f77_char* transb_array, const f77_int *m_array, const f77_int *n_array, const f77_int *k_array, const scomplex* alpha_array, const scomplex** a_array, const f77_int *lda_array, const scomplex** b_array, const f77_int *ldb_array, const scomplex* beta_array, scomplex** c_array, const f77_int *ldc_array, const f77_int* group_count, const f77_int *group_size); + +BLIS_EXPORT_BLIS void cgemm_batch( const f77_char* transa_array, const f77_char* transb_array, const f77_int *m_array, const f77_int *n_array, const f77_int *k_array, const scomplex* alpha_array, const scomplex** a_array, const f77_int *lda_array, const scomplex** b_array, const f77_int *ldb_array, const scomplex* beta_array, scomplex** c_array, const f77_int *ldc_array, const f77_int* group_count, const f77_int *group_size); + +BLIS_EXPORT_BLIS void CGEMM_BATCH_( const f77_char* transa_array, const f77_char* transb_array, const f77_int *m_array, const f77_int *n_array, const f77_int *k_array, const scomplex* alpha_array, const scomplex** a_array, const f77_int *lda_array, const scomplex** b_array, const f77_int *ldb_array, const scomplex* beta_array, scomplex** c_array, const f77_int *ldc_array, const f77_int* group_count, const f77_int *group_size); + + + +BLIS_EXPORT_BLIS void CGEMMT( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const scomplex* alpha, const scomplex* a, const f77_int* lda, const scomplex* b, const f77_int* ldb, const scomplex* beta, scomplex* c, const f77_int* ldc); + +BLIS_EXPORT_BLIS void cgemmt( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const scomplex* alpha, const scomplex* a, const f77_int* lda, const scomplex* b, const f77_int* ldb, const scomplex* beta, scomplex* c, const f77_int* ldc); + +BLIS_EXPORT_BLIS void CGEMMT_( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const scomplex* alpha, const scomplex* a, const f77_int* lda, const scomplex* b, const f77_int* ldb, const scomplex* beta, scomplex* c, const f77_int* ldc); + + + +BLIS_EXPORT_BLIS void DAXPBY(const f77_int* n, const double* alpha, const double *x, const f77_int* incx, const double* beta, double *y, const f77_int* incy); + +BLIS_EXPORT_BLIS void daxpby(const f77_int* n, const double* alpha, const double *x, const f77_int* incx, const double* beta, double *y, const f77_int* incy); + +BLIS_EXPORT_BLIS void DAXPBY_(const f77_int* n, const double* alpha, const double *x, const f77_int* incx, const double* beta, double *y, const f77_int* incy); + + + +BLIS_EXPORT_BLIS void DGEMM_BATCH( 
const f77_char* transa_array, const f77_char* transb_array, const f77_int *m_array, const f77_int *n_array, const f77_int *k_array, const double* alpha_array, const double** a_array, const f77_int *lda_array, const double** b_array, const f77_int *ldb_array, const double* beta_array, double** c_array, const f77_int *ldc_array, const f77_int* group_count, const f77_int *group_size); + +BLIS_EXPORT_BLIS void dgemm_batch( const f77_char* transa_array, const f77_char* transb_array, const f77_int *m_array, const f77_int *n_array, const f77_int *k_array, const double* alpha_array, const double** a_array, const f77_int *lda_array, const double** b_array, const f77_int *ldb_array, const double* beta_array, double** c_array, const f77_int *ldc_array, const f77_int* group_count, const f77_int *group_size); + +BLIS_EXPORT_BLIS void DGEMM_BATCH_( const f77_char* transa_array, const f77_char* transb_array, const f77_int *m_array, const f77_int *n_array, const f77_int *k_array, const double* alpha_array, const double** a_array, const f77_int *lda_array, const double** b_array, const f77_int *ldb_array, const double* beta_array, double** c_array, const f77_int *ldc_array, const f77_int* group_count, const f77_int *group_size); + + + +BLIS_EXPORT_BLIS void DGEMMT( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const double* alpha, const double* a, const f77_int* lda, const double* b, const f77_int* ldb, const double* beta, double* c, const f77_int* ldc); + +BLIS_EXPORT_BLIS void dgemmt( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const double* alpha, const double* a, const f77_int* lda, const double* b, const f77_int* ldb, const double* beta, double* c, const f77_int* ldc); + +BLIS_EXPORT_BLIS void DGEMMT_( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const double* alpha, const double* a, const f77_int* lda, const double* b, const f77_int* ldb, const double* beta, double* c, const f77_int* ldc); + + + +BLIS_EXPORT_BLIS void SAXPBY( const f77_int* n, const float* alpha, const float *x, const f77_int* incx, const float* beta, float *y, const f77_int* incy); + +BLIS_EXPORT_BLIS void saxpby( const f77_int* n, const float* alpha, const float *x, const f77_int* incx, const float* beta, float *y, const f77_int* incy); + +BLIS_EXPORT_BLIS void SAXPBY_( const f77_int* n, const float* alpha, const float *x, const f77_int* incx, const float* beta, float *y, const f77_int* incy); + + + +BLIS_EXPORT_BLIS void SGEMM_BATCH(const f77_char* transa_array, const f77_char* transb_array, const f77_int *m_array, const f77_int *n_array, const f77_int *k_array, const float* alpha_array, const float** a_array, const f77_int *lda_array, const float** b_array, const f77_int *ldb_array, const float* beta_array, float** c_array, const f77_int *ldc_array, const f77_int* group_count, const f77_int *group_size); + +BLIS_EXPORT_BLIS void sgemm_batch(const f77_char* transa_array, const f77_char* transb_array, const f77_int *m_array, const f77_int *n_array, const f77_int *k_array, const float* alpha_array, const float** a_array, const f77_int *lda_array, const float** b_array, const f77_int *ldb_array, const float* beta_array, float** c_array, const f77_int *ldc_array, const f77_int* group_count, const f77_int *group_size); + +BLIS_EXPORT_BLIS void SGEMM_BATCH_(const f77_char* transa_array, const f77_char* transb_array, const f77_int *m_array, const f77_int 
*n_array, const f77_int *k_array, const float* alpha_array, const float** a_array, const f77_int *lda_array, const float** b_array, const f77_int *ldb_array, const float* beta_array, float** c_array, const f77_int *ldc_array, const f77_int* group_count, const f77_int *group_size); + + + +BLIS_EXPORT_BLIS void SGEMMT( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const float* alpha, const float* a, const f77_int* lda, const float* b, const f77_int* ldb, const float* beta, float* c, const f77_int* ldc); + +BLIS_EXPORT_BLIS void sgemmt( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const float* alpha, const float* a, const f77_int* lda, const float* b, const f77_int* ldb, const float* beta, float* c, const f77_int* ldc); + +BLIS_EXPORT_BLIS void SGEMMT_( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const float* alpha, const float* a, const f77_int* lda, const float* b, const f77_int* ldb, const float* beta, float* c, const f77_int* ldc); + + + +BLIS_EXPORT_BLIS void ZAXPBY( const f77_int* n, const dcomplex* alpha, const dcomplex *x, const f77_int* incx, const dcomplex* beta, dcomplex *y, const f77_int* incy); + +BLIS_EXPORT_BLIS void zaxpby( const f77_int* n, const dcomplex* alpha, const dcomplex *x, const f77_int* incx, const dcomplex* beta, dcomplex *y, const f77_int* incy); + +BLIS_EXPORT_BLIS void ZAXPBY_( const f77_int* n, const dcomplex* alpha, const dcomplex *x, const f77_int* incx, const dcomplex* beta, dcomplex *y, const f77_int* incy); + + + +BLIS_EXPORT_BLIS void ZGEMM3M( const f77_char* transa, const f77_char* transb, const f77_int* m, const f77_int* n, const f77_int* k, const dcomplex* alpha, const dcomplex* a, const f77_int* lda, const dcomplex* b, const f77_int* ldb, const dcomplex* beta, dcomplex* c, const f77_int* ldc); + +BLIS_EXPORT_BLIS void zgemm3m( const f77_char* transa, const f77_char* transb, const f77_int* m, const f77_int* n, const f77_int* k, const dcomplex* alpha, const dcomplex* a, const f77_int* lda, const dcomplex* b, const f77_int* ldb, const dcomplex* beta, dcomplex* c, const f77_int* ldc); + +BLIS_EXPORT_BLIS void ZGEMM3M_( const f77_char* transa, const f77_char* transb, const f77_int* m, const f77_int* n, const f77_int* k, const dcomplex* alpha, const dcomplex* a, const f77_int* lda, const dcomplex* b, const f77_int* ldb, const dcomplex* beta, dcomplex* c, const f77_int* ldc); + + + +BLIS_EXPORT_BLIS void ZGEMM_BATCH( const f77_char* transa_array, const f77_char* transb_array, const f77_int *m_array, const f77_int *n_array, const f77_int *k_array, const dcomplex* alpha_array, const dcomplex** a_array, const f77_int *lda_array, const dcomplex** b_array, const f77_int *ldb_array, const dcomplex* beta_array, dcomplex** c_array, const f77_int *ldc_array, const f77_int* group_count, const f77_int *group_size); + +BLIS_EXPORT_BLIS void zgemm_batch( const f77_char* transa_array, const f77_char* transb_array, const f77_int *m_array, const f77_int *n_array, const f77_int *k_array, const dcomplex* alpha_array, const dcomplex** a_array, const f77_int *lda_array, const dcomplex** b_array, const f77_int *ldb_array, const dcomplex* beta_array, dcomplex** c_array, const f77_int *ldc_array, const f77_int* group_count, const f77_int *group_size); + +BLIS_EXPORT_BLIS void ZGEMM_BATCH_( const f77_char* transa_array, const f77_char* transb_array, const f77_int *m_array, const f77_int *n_array, 
const f77_int *k_array, const dcomplex* alpha_array, const dcomplex** a_array, const f77_int *lda_array, const dcomplex** b_array, const f77_int *ldb_array, const dcomplex* beta_array, dcomplex** c_array, const f77_int *ldc_array, const f77_int* group_count, const f77_int *group_size); + + + +BLIS_EXPORT_BLIS void ZGEMMT( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const dcomplex* alpha, const dcomplex* a, const f77_int* lda, const dcomplex* b, const f77_int* ldb, const dcomplex* beta, dcomplex* c, const f77_int* ldc); + +BLIS_EXPORT_BLIS void zgemmt( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const dcomplex* alpha, const dcomplex* a, const f77_int* lda, const dcomplex* b, const f77_int* ldb, const dcomplex* beta, dcomplex* c, const f77_int* ldc); + +BLIS_EXPORT_BLIS void ZGEMMT_( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const dcomplex* alpha, const dcomplex* a, const f77_int* lda, const dcomplex* b, const f77_int* ldb, const dcomplex* beta, dcomplex* c, const f77_int* ldc); + + + +BLIS_EXPORT_BLIS void CIMATCOPY(f77_char* trans, f77_int* rows, f77_int* cols, const scomplex* alpha, scomplex* aptr, f77_int* lda, f77_int* ldb); + +BLIS_EXPORT_BLIS void cimatcopy(f77_char* trans, f77_int* rows, f77_int* cols, const scomplex* alpha, scomplex* aptr, f77_int* lda, f77_int* ldb); + +BLIS_EXPORT_BLIS void CIMATCOPY_(f77_char* trans, f77_int* rows, f77_int* cols, const scomplex* alpha, scomplex* aptr, f77_int* lda, f77_int* ldb); + + + +BLIS_EXPORT_BLIS void COMATADD(f77_char* transa, f77_char* transb, f77_int* m, f77_int* n, const scomplex* alpha, const scomplex* A, f77_int* lda, const scomplex* beta, scomplex* B, f77_int* ldb, scomplex* C, f77_int* ldc); + +BLIS_EXPORT_BLIS void comatadd(f77_char* transa, f77_char* transb, f77_int* m, f77_int* n, const scomplex* alpha, const scomplex* A, f77_int* lda, const scomplex* beta, scomplex* B, f77_int* ldb, scomplex* C, f77_int* ldc); + +BLIS_EXPORT_BLIS void COMATADD_(f77_char* transa, f77_char* transb, f77_int* m, f77_int* n, const scomplex* alpha, const scomplex* A, f77_int* lda, const scomplex* beta, scomplex* B, f77_int* ldb, scomplex* C, f77_int* ldc); + + + +BLIS_EXPORT_BLIS void COMATCOPY2(f77_char* trans, f77_int* rows, f77_int* cols, const scomplex* alpha, const scomplex* aptr, f77_int* lda, f77_int* stridea, scomplex* bptr, f77_int* ldb, f77_int* strideb); + +BLIS_EXPORT_BLIS void comatcopy2(f77_char* trans, f77_int* rows, f77_int* cols, const scomplex* alpha, const scomplex* aptr, f77_int* lda, f77_int* stridea, scomplex* bptr, f77_int* ldb, f77_int* strideb); + +BLIS_EXPORT_BLIS void COMATCOPY2_(f77_char* trans, f77_int* rows, f77_int* cols, const scomplex* alpha, const scomplex* aptr, f77_int* lda, f77_int* stridea, scomplex* bptr, f77_int* ldb, f77_int* strideb); + + + +BLIS_EXPORT_BLIS void COMATCOPY(f77_char* trans, f77_int* rows, f77_int* cols, const scomplex* alpha, const scomplex* aptr, f77_int* lda, scomplex* bptr, f77_int* ldb); + +BLIS_EXPORT_BLIS void comatcopy(f77_char* trans, f77_int* rows, f77_int* cols, const scomplex* alpha, const scomplex* aptr, f77_int* lda, scomplex* bptr, f77_int* ldb); + +BLIS_EXPORT_BLIS void COMATCOPY_(f77_char* trans, f77_int* rows, f77_int* cols, const scomplex* alpha, const scomplex* aptr, f77_int* lda, scomplex* bptr, f77_int* ldb); + + + +BLIS_EXPORT_BLIS void DOMATADD(f77_char* transa, f77_char* transb, 
f77_int* m, f77_int* n, const double* alpha, const double* A, f77_int* lda, const double* beta, const double* B, f77_int* ldb, double* C, f77_int* ldc); + +BLIS_EXPORT_BLIS void domatadd(f77_char* transa, f77_char* transb, f77_int* m, f77_int* n, const double* alpha, const double* A, f77_int* lda, const double* beta, const double* B, f77_int* ldb, double* C, f77_int* ldc); + +BLIS_EXPORT_BLIS void DOMATADD_(f77_char* transa, f77_char* transb, f77_int* m, f77_int* n, const double* alpha, const double* A, f77_int* lda, const double* beta, const double* B, f77_int* ldb, double* C, f77_int* ldc); + + + +BLIS_EXPORT_BLIS void DOMATCOPY2(f77_char* trans, f77_int* rows, f77_int* cols, const double* alpha, const double* aptr, f77_int* lda, f77_int* stridea, double* bptr, f77_int* ldb, f77_int* strideb); + +BLIS_EXPORT_BLIS void domatcopy2(f77_char* trans, f77_int* rows, f77_int* cols, const double* alpha, const double* aptr, f77_int* lda, f77_int* stridea, double* bptr, f77_int* ldb, f77_int* strideb); + +BLIS_EXPORT_BLIS void DOMATCOPY2_(f77_char* trans, f77_int* rows, f77_int* cols, const double* alpha, const double* aptr, f77_int* lda, f77_int* stridea, double* bptr, f77_int* ldb, f77_int* strideb); + + + +BLIS_EXPORT_BLIS void DOMATCOPY(f77_char* trans, f77_int* rows, f77_int* cols, const double* alpha, const double* aptr, f77_int* lda, double* bptr, f77_int* ldb); + +BLIS_EXPORT_BLIS void domatcopy(f77_char* trans, f77_int* rows, f77_int* cols, const double* alpha, const double* aptr, f77_int* lda, double* bptr, f77_int* ldb); + +BLIS_EXPORT_BLIS void DOMATCOPY_(f77_char* trans, f77_int* rows, f77_int* cols, const double* alpha, const double* aptr, f77_int* lda, double* bptr, f77_int* ldb); + + + +BLIS_EXPORT_BLIS void SIMATCOPY( f77_char* trans, f77_int* rows, f77_int* cols, const float* alpha, float* aptr, f77_int* lda, f77_int* ldb); + +BLIS_EXPORT_BLIS void simatcopy( f77_char* trans, f77_int* rows, f77_int* cols, const float* alpha, float* aptr, f77_int* lda, f77_int* ldb); + +BLIS_EXPORT_BLIS void SIMATCOPY_( f77_char* trans, f77_int* rows, f77_int* cols, const float* alpha, float* aptr, f77_int* lda, f77_int* ldb); + + + +BLIS_EXPORT_BLIS void SOMATADD( f77_char* transa, f77_char* transb, f77_int* m, f77_int* n, const float* alpha, const float* A, f77_int* lda, const float* beta, const float* B, f77_int* ldb, float* C, f77_int* ldc); + +BLIS_EXPORT_BLIS void somatadd( f77_char* transa, f77_char* transb, f77_int* m, f77_int* n, const float* alpha, const float* A, f77_int* lda, const float* beta, const float* B, f77_int* ldb, float* C, f77_int* ldc); + +BLIS_EXPORT_BLIS void SOMATADD_( f77_char* transa, f77_char* transb, f77_int* m, f77_int* n, const float* alpha, const float* A, f77_int* lda, const float* beta, const float* B, f77_int* ldb, float* C, f77_int* ldc); + + + +BLIS_EXPORT_BLIS void SOMATCOPY2( f77_char* trans, f77_int* rows, f77_int* cols, const float* alpha, const float* aptr, f77_int* lda, f77_int* stridea, float* bptr, f77_int* ldb, f77_int* strideb); + +BLIS_EXPORT_BLIS void somatcopy2( f77_char* trans, f77_int* rows, f77_int* cols, const float* alpha, const float* aptr, f77_int* lda, f77_int* stridea, float* bptr, f77_int* ldb, f77_int* strideb); + +BLIS_EXPORT_BLIS void SOMATCOPY2_( f77_char* trans, f77_int* rows, f77_int* cols, const float* alpha, const float* aptr, f77_int* lda, f77_int* stridea, float* bptr, f77_int* ldb, f77_int* strideb); + + + +BLIS_EXPORT_BLIS void SOMATCOPY( f77_char* trans, f77_int* rows, f77_int* cols, const float* alpha, const float* aptr, 
f77_int* lda, float* bptr, f77_int* ldb); + +BLIS_EXPORT_BLIS void somatcopy( f77_char* trans, f77_int* rows, f77_int* cols, const float* alpha, const float* aptr, f77_int* lda, float* bptr, f77_int* ldb); + +BLIS_EXPORT_BLIS void SOMATCOPY_( f77_char* trans, f77_int* rows, f77_int* cols, const float* alpha, const float* aptr, f77_int* lda, float* bptr, f77_int* ldb); + + + +BLIS_EXPORT_BLIS void ZIMATCOPY(f77_char* trans, f77_int* rows, f77_int* cols, const dcomplex* alpha, dcomplex* aptr, f77_int* lda, f77_int* ldb); + +BLIS_EXPORT_BLIS void zimatcopy(f77_char* trans, f77_int* rows, f77_int* cols, const dcomplex* alpha, dcomplex* aptr, f77_int* lda, f77_int* ldb); + +BLIS_EXPORT_BLIS void ZIMATCOPY_(f77_char* trans, f77_int* rows, f77_int* cols, const dcomplex* alpha, dcomplex* aptr, f77_int* lda, f77_int* ldb); + + + +BLIS_EXPORT_BLIS void ZOMATADD(f77_char* transa, f77_char* transb, f77_int* m, f77_int* n, const dcomplex* alpha, const dcomplex* A, f77_int* lda, const dcomplex* beta, dcomplex* B, f77_int* ldb, dcomplex* C, f77_int* ldc); + +BLIS_EXPORT_BLIS void zomatadd(f77_char* transa, f77_char* transb, f77_int* m, f77_int* n, const dcomplex* alpha, const dcomplex* A, f77_int* lda, const dcomplex* beta, dcomplex* B, f77_int* ldb, dcomplex* C, f77_int* ldc); + +BLIS_EXPORT_BLIS void ZOMATADD_(f77_char* transa, f77_char* transb, f77_int* m, f77_int* n, const dcomplex* alpha, const dcomplex* A, f77_int* lda, const dcomplex* beta, dcomplex* B, f77_int* ldb, dcomplex* C, f77_int* ldc); + + + +BLIS_EXPORT_BLIS void ZOMATCOPY2(f77_char* trans, f77_int* rows, f77_int* cols, const dcomplex* alpha, const dcomplex* aptr, f77_int* lda, f77_int* stridea, dcomplex* bptr, f77_int* ldb, f77_int* strideb); + +BLIS_EXPORT_BLIS void zomatcopy2(f77_char* trans, f77_int* rows, f77_int* cols, const dcomplex* alpha, const dcomplex* aptr, f77_int* lda, f77_int* stridea, dcomplex* bptr, f77_int* ldb, f77_int* strideb); + +BLIS_EXPORT_BLIS void ZOMATCOPY2_(f77_char* trans, f77_int* rows, f77_int* cols, const dcomplex* alpha, const dcomplex* aptr, f77_int* lda, f77_int* stridea, dcomplex* bptr, f77_int* ldb, f77_int* strideb); + + + +BLIS_EXPORT_BLIS void ZOMATCOPY(f77_char* trans, f77_int* rows, f77_int* cols, const dcomplex* alpha, const dcomplex* aptr, f77_int* lda, dcomplex* bptr, f77_int* ldb); + +BLIS_EXPORT_BLIS void zomatcopy(f77_char* trans, f77_int* rows, f77_int* cols, const dcomplex* alpha, const dcomplex* aptr, f77_int* lda, dcomplex* bptr, f77_int* ldb); + +BLIS_EXPORT_BLIS void ZOMATCOPY_(f77_char* trans, f77_int* rows, f77_int* cols, const dcomplex* alpha, const dcomplex* aptr, f77_int* lda, dcomplex* bptr, f77_int* ldb); + + + From f698afc5676cad30ab4fead67c96dadca022cf50 Mon Sep 17 00:00:00 2001 From: Dipal M Zambare Date: Thu, 26 Aug 2021 09:07:01 +0530 Subject: [PATCH 007/243] Updated version number to 3.1.0 AMD-Internal: [CPUPL-1811] Change-Id: I6b485e7622e526791094ae621d9f84d2526e6569 --- so_version | 2 +- version | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/so_version b/so_version index 5d7b33d376..a831c0e579 100644 --- a/so_version +++ b/so_version @@ -1,2 +1,2 @@ 3 -0.1 +1.0 diff --git a/version b/version index d4c4b54b7c..0c6173b5f1 100644 --- a/version +++ b/version @@ -1,2 +1,2 @@ -3.0.1 +3.1.0 From 41a86c24632d68c2eb0764a77c1eee3a54db5114 Mon Sep 17 00:00:00 2001 From: Dipal M Zambare Date: Wed, 25 Aug 2021 14:43:24 +0530 Subject: [PATCH 008/243] Enabled znver3 flag for gcc version above 11 -- Added -march=znver3 flag if the library is built for zen3 
configuration with gcc compiler version 11 or above. -- Replaced hardcoded compiler names 'gcc' and 'clang' with variable $CC so that options are chosen as per the compiler specified at configure time (instead of compiler in path). AMD-Internal: [CPUPL-1823] Change-Id: I2659349c998201ebd4480735c544e48a5ed76bb4 --- config/zen/amd_config.mk | 4 ++-- config/zen2/make_defs.mk | 4 ++-- config/zen3/make_defs.mk | 14 +++++++++----- 3 files changed, 13 insertions(+), 9 deletions(-) diff --git a/config/zen/amd_config.mk b/config/zen/amd_config.mk index f479386d4f..5ca32b268a 100644 --- a/config/zen/amd_config.mk +++ b/config/zen/amd_config.mk @@ -4,7 +4,7 @@ # An object-based framework for developing high-performance BLAS-like # libraries. # -# Copyright (C) 2019, Advanced Micro Devices, Inc. +# Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -61,7 +61,7 @@ CKVECFLAGS := -mavx2 -mfpmath=sse -mfma else ifeq ($(CC_VENDOR),clang) CKVECFLAGS := -mavx2 -mfpmath=sse -mfma -mno-fma4 -mno-tbm -mno-xop -mno-lwp -ifeq ($(strip $(shell clang -v |&head -1 |grep -c 'AOCC.LLVM')),1) +ifeq ($(strip $(shell $(CC) -v |&head -1 |grep -c 'AOCC.LLVM')),1) CKVECFLAGS += -mllvm -disable-licm-vrp endif else diff --git a/config/zen2/make_defs.mk b/config/zen2/make_defs.mk index c936487fc3..ba91f722ab 100644 --- a/config/zen2/make_defs.mk +++ b/config/zen2/make_defs.mk @@ -78,7 +78,7 @@ endif # they make explicit use of the rbp register. CKOPTFLAGS := $(COPTFLAGS) -fomit-frame-pointer ifeq ($(CC_VENDOR),gcc) -GCC_VERSION := $(strip $(shell gcc -dumpversion | cut -d. -f1)) +GCC_VERSION := $(strip $(shell $(CC) -dumpversion | cut -d. -f1)) #gcc or clang version must be atleast 4.0 # gcc 9.0 or later: ifeq ($(shell test $(GCC_VERSION) -ge 9; echo $$?),0) @@ -91,7 +91,7 @@ CKVECFLAGS += -march=znver1 -mno-avx256-split-unaligned-store endif else ifeq ($(CC_VENDOR),clang) -ifeq ($(strip $(shell clang -v |&head -1 |grep -c 'AOCC.LLVM.2\|AOCC_2')),1) +ifeq ($(strip $(shell $(CC) -v |&head -1 |grep -c 'AOCC.LLVM.2\|AOCC_2')),1) CKVECFLAGS += -march=znver2 else #if compiling with clang diff --git a/config/zen3/make_defs.mk b/config/zen3/make_defs.mk index bc36e6ae94..a479acf8a5 100644 --- a/config/zen3/make_defs.mk +++ b/config/zen3/make_defs.mk @@ -78,9 +78,12 @@ endif # they make explicit use of the rbp register. CKOPTFLAGS := $(COPTFLAGS) -fomit-frame-pointer ifeq ($(CC_VENDOR),gcc) -GCC_VERSION := $(strip $(shell gcc -dumpversion | cut -d. -f1)) -#gcc or clang version must be atleast 4.0 +GCC_VERSION := $(strip $(shell $(CC) -dumpversion | cut -d. -f1)) +# gcc or clang version must be atleast 4.0 # gcc 9.0 or later: +ifeq ($(shell test $(GCC_VERSION) -ge 11; echo $$?),0) +CKVECFLAGS += -march=znver3 +else ifeq ($(shell test $(GCC_VERSION) -ge 9; echo $$?),0) CKVECFLAGS += -march=znver2 else @@ -88,7 +91,8 @@ else # as the fallback option. 
CRVECFLAGS += -march=znver1 -mno-avx256-split-unaligned-store CKVECFLAGS += -march=znver1 -mno-avx256-split-unaligned-store -endif +endif # GCC 9 +endif # GCC 11 else ifeq ($(CC_VENDOR),clang) @@ -103,11 +107,11 @@ ifeq ($(CC_VENDOR),clang) # For our prupose we just want to know if it version 2x or 3x # for version 3x we will enable znver3 -ifeq ($(strip $(shell clang -v |&head -1 |grep -c 'AOCC_3')),1) +ifeq ($(strip $(shell $(CC) -v |&head -1 |grep -c 'AOCC_3')),1) CKVECFLAGS += -march=znver3 else # for version 2x we will enable znver2 -ifeq ($(strip $(shell clang -v |&head -1 |grep -c 'AOCC.LLVM.2\|AOCC_2')),1) +ifeq ($(strip $(shell $(CC) -v |&head -1 |grep -c 'AOCC.LLVM.2\|AOCC_2')),1) CKVECFLAGS += -march=znver2 else #if compiling with clang From d1fb770f1cde034682c76225495f148ee785ab0b Mon Sep 17 00:00:00 2001 From: Dipal M Zambare Date: Wed, 1 Sep 2021 10:01:39 +0530 Subject: [PATCH 009/243] Removed duplicate cpp and testcpp folders. The cpp and testcpp folder exists in root directory as well as vendor directory. Only folder in vendor directory are needed. Removed duplicate directories and updated makefiles to pick the sources from vendor folder. AMD-Internal: [CPUPL-1834] Change-Id: I178043a09fd746660938b89ecce73c53d6c53409 --- CMakeLists.txt | 2 +- Makefile | 15 +- common.mk | 2 - cpp/blis.hh | 3820 -------------------- cpp/cblas.hh | 1705 --------- testcpp/Makefile | 208 -- testcpp/test.hh | 219 -- testcpp/test.sh | 46 - testcpp/test_asum.cc | 127 - testcpp/test_axpy.cc | 138 - testcpp/test_copy.cc | 132 - testcpp/test_dot.cc | 131 - testcpp/test_dotc.cc | 127 - testcpp/test_gbmv.cc | 109 - testcpp/test_gemm.cc | 163 - testcpp/test_gemm.hh | 110 - testcpp/test_gemv.cc | 162 - testcpp/test_ger.cc | 150 - testcpp/test_gerc.cc | 174 - testcpp/test_geru.cc | 169 - testcpp/test_hemm.cc | 164 - testcpp/test_hemv.cc | 157 - testcpp/test_her.cc | 141 - testcpp/test_her2.cc | 147 - testcpp/test_herk.cc | 155 - testcpp/test_hpr.cc | 112 - testcpp/test_hpr2.cc | 93 - testcpp/test_nrm2.cc | 100 - testcpp/test_rot.cc | 102 - testcpp/test_rotg.cc | 108 - testcpp/test_rotm.cc | 106 - testcpp/test_rotmg.cc | 137 - testcpp/test_scal.cc | 138 - testcpp/test_sdsdot.cc | 134 - testcpp/test_spr.cc | 97 - testcpp/test_spr2.cc | 107 - testcpp/test_swap.cc | 136 - testcpp/test_symm.cc | 164 - testcpp/test_syr.cc | 140 - testcpp/test_syr2.cc | 149 - testcpp/test_syr2k.cc | 163 - testcpp/test_syrk.cc | 152 - testcpp/test_tbmv.cc | 103 - testcpp/test_tbsv.cc | 104 - testcpp/test_tpmv.cc | 84 - testcpp/test_tpsv.cc | 87 - testcpp/test_trmm.cc | 153 - testcpp/test_trsm.cc | 154 - testcpp/test_trsv.cc | 142 - {testcpp => vendor/testcpp}/CMakeLists.txt | 0 50 files changed, 5 insertions(+), 11433 deletions(-) delete mode 100644 cpp/blis.hh delete mode 100644 cpp/cblas.hh delete mode 100644 testcpp/Makefile delete mode 100644 testcpp/test.hh delete mode 100644 testcpp/test.sh delete mode 100644 testcpp/test_asum.cc delete mode 100644 testcpp/test_axpy.cc delete mode 100644 testcpp/test_copy.cc delete mode 100644 testcpp/test_dot.cc delete mode 100644 testcpp/test_dotc.cc delete mode 100644 testcpp/test_gbmv.cc delete mode 100644 testcpp/test_gemm.cc delete mode 100644 testcpp/test_gemm.hh delete mode 100644 testcpp/test_gemv.cc delete mode 100644 testcpp/test_ger.cc delete mode 100644 testcpp/test_gerc.cc delete mode 100644 testcpp/test_geru.cc delete mode 100644 testcpp/test_hemm.cc delete mode 100644 testcpp/test_hemv.cc delete mode 100644 testcpp/test_her.cc delete mode 100644 testcpp/test_her2.cc delete 
mode 100644 testcpp/test_herk.cc delete mode 100644 testcpp/test_hpr.cc delete mode 100644 testcpp/test_hpr2.cc delete mode 100644 testcpp/test_nrm2.cc delete mode 100644 testcpp/test_rot.cc delete mode 100644 testcpp/test_rotg.cc delete mode 100644 testcpp/test_rotm.cc delete mode 100644 testcpp/test_rotmg.cc delete mode 100644 testcpp/test_scal.cc delete mode 100644 testcpp/test_sdsdot.cc delete mode 100644 testcpp/test_spr.cc delete mode 100644 testcpp/test_spr2.cc delete mode 100644 testcpp/test_swap.cc delete mode 100644 testcpp/test_symm.cc delete mode 100644 testcpp/test_syr.cc delete mode 100644 testcpp/test_syr2.cc delete mode 100644 testcpp/test_syr2k.cc delete mode 100644 testcpp/test_syrk.cc delete mode 100644 testcpp/test_tbmv.cc delete mode 100644 testcpp/test_tbsv.cc delete mode 100644 testcpp/test_tpmv.cc delete mode 100644 testcpp/test_tpsv.cc delete mode 100644 testcpp/test_trmm.cc delete mode 100644 testcpp/test_trsm.cc delete mode 100644 testcpp/test_trsv.cc rename {testcpp => vendor/testcpp}/CMakeLists.txt (100%) diff --git a/CMakeLists.txt b/CMakeLists.txt index 78a380fece..8572aad640 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -532,7 +532,7 @@ add_subdirectory(aocl_dtl) add_subdirectory(test) add_subdirectory(testsuite) if(ENABLE_TESTCPP_TESTING) - add_subdirectory(testcpp) + add_subdirectory(vendor/testcpp) endif() if (ENABLE_BLASTEST) add_subdirectory(blastest) diff --git a/Makefile b/Makefile index 99b43b2bb3..38cc8144ab 100644 --- a/Makefile +++ b/Makefile @@ -257,16 +257,9 @@ ifeq ($(MK_ENABLE_CBLAS),yes) HEADERS_TO_INSTALL += $(CBLAS_H_FLAT) endif -# Install BLIS CPP Template header files -HEADERS_TO_INSTALL += $(CPP_HEADER_DIR)/*.hh - -# If requested, include AMD's C++ template header files in the list of headers +# Include AMD's C++ template header files in the list of headers # to install. -ifeq ($(INSTALL_HH),yes) HEADERS_TO_INSTALL += $(wildcard $(VEND_CPP_PATH)/*.hh) -endif - - # # --- public makefile fragment definitions ------------------------------------- @@ -903,7 +896,7 @@ endif # Check results of BLIS CPP Template tests checkbliscpp: - $(MAKE) -C $(CPP_TEST_DIR) + $(MAKE) -C $(VEND_TESTCPP_DIR) # Check the results of the BLIS testsuite. checkblis: testsuite-run @@ -1246,13 +1239,13 @@ ifeq ($(IS_CONFIGURED),yes) ifeq ($(ENABLE_VERBOSE),yes) - $(FIND) $(TESTSUITE_DIR)/$(OBJ_DIR) -name "*.o" | $(XARGS) $(RM_F) - $(RM_F) $(TESTSUITE_DIR)/$(TESTSUITE_BIN) - - $(MAKE) -C $(CPP_TEST_DIR) clean + - $(MAKE) -C $(VEND_TESTCPP_DIR) clean else @echo "Removing object files from $(TESTSUITE_DIR)/$(OBJ_DIR)" @- $(FIND) $(TESTSUITE_DIR)/$(OBJ_DIR) -name "*.o" | $(XARGS) $(RM_F) @echo "Removing binary $(TESTSUITE_DIR)/$(TESTSUITE_BIN)" @- $(RM_F) $(TESTSUITE_DIR)/$(TESTSUITE_BIN) - @$(MAKE) -C $(CPP_TEST_DIR) clean + @$(MAKE) -C $(VEND_TESTCPP_DIR) clean endif # ENABLE_VERBOSE endif # IS_CONFIGURED diff --git a/common.mk b/common.mk index 07548f6e80..a05e2160f0 100644 --- a/common.mk +++ b/common.mk @@ -303,8 +303,6 @@ LIB_DIR := lib INCLUDE_DIR := include BLASTEST_DIR := blastest TESTSUITE_DIR := testsuite -CPP_HEADER_DIR := cpp -CPP_TEST_DIR := testcpp VEND_DIR := vendor VEND_CPP_DIR := $(VEND_DIR)/cpp diff --git a/cpp/blis.hh b/cpp/blis.hh deleted file mode 100644 index 39dc258647..0000000000 --- a/cpp/blis.hh +++ /dev/null @@ -1,3820 +0,0 @@ -/****************************************************************************** -* Copyright (c) 2019 - present Advanced Micro Devices, Inc. All rights reserved. 
-* -* Permission is hereby granted, free of charge, to any person obtaining a copy -* of this software and associated documentation files (the "Software"), to deal -* in the Software without restriction, including without limitation the rights -* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -* copies of the Software, and to permit persons to whom the Software is -* furnished to do so, subject to the following conditions: -* -* The above copyright notice and this permission notice shall be included in -* all copies or substantial portions of the Software. -* -* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -* THE SOFTWARE. -*******************************************************************************/ - -/*! @file blis.hh - * blis.hh defines all the BLAS CPP templated public interfaces - * */ -#ifndef BLIS_HH -#define BLIS_HH - -#include "cblas.hh" - -namespace blis { - -/*! \brief Construct plane rotation for arbitrary data types - - \b Purpose: - - ROTG construct plane rotation that eliminates b for arbitrary data types, such that \n - - [ z ] = [ c s ] [ a ] \n - [ 0 ] [ -s c ] [ b ] \n - Data precisions supported include SINGLE/DOUBLE PRECISION REAL - - \param[in, out] a - SINGLE/DOUBLE PRECISION REAL - On entry, scalar a. On exit, set to z. - - \param[in, out] b - SINGLE/DOUBLE PRECISION REAL - On entry, scalar b. On exit, set to s, 1/c, or 0. - - \param[out] c - Cosine of rotation; SINGLE/DOUBLE PRECISION REAL. - - \param[out] s - Sine of rotation; SINGLE/DOUBLE PRECISION REAL. - */ -template< typename T > -void rotg( - T *a, - T *b, - T *c, - T *s ) -{ - cblas_rotg(a, b, c, s); -} - -/*! \brief Construct the modified givens transformation matrix for arbitrary data types - - \b Purpose: - - ROTMG construct modified (fast) plane rotation, H, that eliminates b, such that \n - [ z ] = H [ sqrt(d1) 0 ] [ a ] \n - [ 0 ] [ 0 sqrt(d2) ] [ b ] \n - Data precisions supported include SINGLE/DOUBLE PRECISION REAL - - \param[in, out] d1 - SINGLE/DOUBLE PRECISION REAL - sqrt(d1) is scaling factor for vector x. - - \param[in, out] d2 - SINGLE/DOUBLE PRECISION REAL - sqrt(d2) is scaling factor for vector y. - - \param[in, out] a - On entry, scalar a. On exit, set to z. SINGLE/DOUBLE PRECISION REAL. - - \param[in, out] b - On entry, scalar b. SINGLE/DOUBLE PRECISION REAL. - - \param[out] param - SINGLE/DOUBLE PRECISION REAL array, dimension (5),giving parameters - of modified plane rotation - param(1)=DFLAG - param(2)=DH11 - param(3)=DH21 - param(4)=DH12 - param(5)=DH22 - */ -template< typename T > -void rotmg( - T *d1, - T *d2, - T *a, - T b, - T param[5] ) -{ - cblas_rotmg(d1, d2, a, b, param ); -} - -/*! \brief Apply plane rotation for arbitrary data types - - \b Purpose: - - ROT applies a plane rotation: \n - [ x^T ] [ c s ] [ x^T ] \n - [ y^T ] = [ -s c ] [ y^T ] \n - Data precisions supported include SINGLE/DOUBLE PRECISION REAL - - \param[in] n - Number of elements in x and y. n >= 0. - - \param[in, out] x - SINGLE/DOUBLE PRECISION REAL array - The n-element vector x, in an array of length (n-1)*abs(incx) + 1. 
- - \param[in] incx - incx is INTEGER - Stride between elements of x. incx must not be zero. - If incx < 0, uses elements of x in reverse order: x(n-1), ..., x(0). - - \param[in, out] y - SINGLE/DOUBLE PRECISION REAL array - The n-element vector y, in an array of length (n-1)*abs(incy) + 1. - - \param[in] incy - incy is INTEGER - Stride between elements of y. incy must not be zero. - If incy < 0, uses elements of y in reverse order: y(n-1), ..., y(0). - - \param[in] c - Cosine of rotation; SINGLE/DOUBLE PRECISION REAL. - - \param[in] s - Sine of rotation; SINGLE/DOUBLE PRECISION REAL. - */ -template< typename T > -void rot( - int64_t n, - T *x, int64_t incx, - T *y, int64_t incy, - T c, - T s ) -{ - cblas_rot( n, x, incx, y, incy, c, s ); -} - -/*! \brief Apply the modified givens transformation for arbitrary data types - - \b Purpose: - - ROTM applies modified (fast) plane rotation, H: \n - [ x^T ] = H [ x^T ] \n - [ y^T ] [ y^T ] \n - - Data precisions supported include SINGLE/DOUBLE PRECISION REAL - - \param[in] n - Number of elements in x and y. n >= 0. - - \param[in, out] x - SINGLE/DOUBLE PRECISION REAL array - The n-element vector x, in an array of length (n-1)*abs(incx) + 1. - - \param[in] incx - incx is INTEGER - Stride between elements of x. incx must not be zero. - If incx < 0, uses elements of x in reverse order: x(n-1), ..., x(0). - - \param[in, out] y - SINGLE/DOUBLE PRECISION REAL array - The n-element vector y, in an array of length (n-1)*abs(incy) + 1. - - \param[in] incy - incy is INTEGER - Stride between elements of y. incy must not be zero. - If incy < 0, uses elements of y in reverse order: y(n-1), ..., y(0). - - \param[in] P - SINGLE/DOUBLE PRECISION REAL array, dimension (5),giving parameters - of modified plane rotation - param(1)=DFLAG - param(2)=DH11 - param(3)=DH21 - param(4)=DH12 - param(5)=DH22 - */ -template< typename T > -void rotm( - int64_t n, - T *x, int64_t incx, - T *y, int64_t incy, - const T *P) -{ - cblas_rotm( n, x, incx, y, incy, P ); -} - -/*! \brief Interchanges two vectors of arbitrary data types - - \b Purpose: - - SWAP interchanges two vectors uses unrolled loops for increments equal to 1.\n - x <=> y \n - Data precisions supported include SINGLE/DOUBLE PRECISION REAL, - SINGLE PRECISION COMPLEX, DOUBLE PRECISION COMPLEX(COMPLEX*16) - - \param[in] n - n is INTEGER - Number of elements in x and y. n >= 0. - - \param[in] x - REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array. - The n-element vector x, in an array of length (n-1)*abs(incx) + 1. - - \param[in] incx - incx is INTEGER. - Stride between elements of x. incx must not be zero. - If incx < 0, uses elements of x in reverse order: x(n-1), ..., x(0). - - \param[in, out] y - REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array. - The n-element vector y, in an array of length (n-1)*abs(incy) + 1. - - \param[in] incy - incy is INTEGER. - Stride between elements of y. incy must not be zero. - If incy < 0, uses elements of y in reverse order: y(n-1), ..., y(0). - */ -template< typename T > -void swap( - int64_t n, - T *x, int64_t incx, - T *y, int64_t incy ) -{ - cblas_swap( n, x, incx, y, incy ); -} - -/*! \brief Scales a vector of arbitrary data types by a constant. - - \b Purpose: - - SCAL scales a vector by a constant, uses unrolled loops for increment equal to 1.\n - x = alpha * x \n - Data precisions of vector & constant include SINGLE/DOUBLE PRECISION REAL, - SINGLE PRECISION COMPLEX, DOUBLE PRECISION COMPLEX(COMPLEX*16) - - \param[in] n - n is INTEGER - Number of elements in x. n >= 0. 
- - \param[in] alpha - alpha is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 - On entry, alpha specifies the scalar alpha. - - \param[in ,out] x - REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array - The n-element vector x, in an array of length (n-1)*abs(incx) + 1. - - \param[in] incx - incx is INTEGER - Stride between elements of x. incx must not be zero. - If incx < 0, uses elements of x in reverse order: x(n-1), ..., x(0). - */ -template< typename TA, typename TB > -void scal( - int64_t n, - TA alpha, - TB* x, int64_t incx ) -{ - cblas_scal( n, alpha, x, incx ); -} - -/*! \brief Copies a vector x to a vector y for arbitrary data types - - \b Purpose: - - COPY copies a vector x to a vector y.\n - y = x \n - Data precisions supported include SINGLE/DOUBLE PRECISION REAL, - SINGLE PRECISION COMPLEX, DOUBLE PRECISION COMPLEX(COMPLEX*16) - - \param[in] n - n is INTEGER - Number of elements in x and y. n >= 0. - - \param[in] x - REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array. - The n-element vector x, in an array of length (n-1)*abs(incx) + 1. - - \param[in] incx - incx is INTEGER. - Stride between elements of x. incx must not be zero. - If incx < 0, uses elements of x in reverse order: x(n-1), ..., x(0). - - \param[out] y - REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array. - The n-element vector y, in an array of length (n-1)*abs(incy) + 1. - - \param[in] incy - incy is INTEGER. - Stride between elements of y. incy must not be zero. - If incy < 0, uses elements of y in reverse order: y(n-1), ..., y(0). - */ -template< typename T > -void copy( - int64_t n, - T const *x, int64_t incx, - T *y, int64_t incy ) -{ - cblas_copy( n, x, incx, y, incy ); -} - -/*! \brief Performs addition of scaled vector for arbitrary data types - - \b Purpose: - - AXPY constant times a vector plus a vector.\n - y = alpha*x + y \n - Data precisions supported include SINGLE/DOUBLE PRECISION REAL, - SINGLE PRECISION COMPLEX, DOUBLE PRECISION COMPLEX(COMPLEX*16) - - \param[in] n - n is INTEGER - Number of elements in x and y. n >= 0. - - \param[in] alpha - alpha is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 - On entry, alpha specifies the scalar alpha.\n - If alpha is zero, y is not updated. - - \param[in] x - REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array. - The n-element vector x, in an array of length (n-1)*abs(incx) + 1. - - \param[in] incx - incx is INTEGER. - Stride between elements of x. incx must not be zero. - If incx < 0, uses elements of x in reverse order: x(n-1), ..., x(0). - - \param[out] y - REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array. - The n-element vector y, in an array of length (n-1)*abs(incy) + 1. - - \param[in] incy - incy is INTEGER. - Stride between elements of y. incy must not be zero. - If incy < 0, uses elements of y in reverse order: y(n-1), ..., y(0). - */ -template< typename T > -void axpy( - int64_t n, - T alpha, - T const *x, int64_t incx, - T *y, int64_t incy ) -{ - cblas_axpy( n, alpha, x, incx, y, incy ); -} - -/*! \brief Performs the dot product of two vectors for arbitrary data types - - \b Purpose: - - DOT forms the dot product of two vectors - uses unrolled loops for increments equal to one.\n - dot = x^T * y \n - Data precisions supported include SINGLE/DOUBLE PRECISION REAL - - \param[in] n - n is INTEGER - Number of elements in x and y. n >= 0. - - \param[in] x - REAL/DOUBLE PRECISION array. - The n-element vector x, in an array of length (n-1)*abs(incx) + 1. - - \param[in] incx - incx is INTEGER. - Stride between elements of x. incx must not be zero. 
- If incx < 0, uses elements of x in reverse order: x(n-1), ..., x(0). - - \param[in] y - REAL/DOUBLE PRECISION array. - The n-element vector y, in an array of length (n-1)*abs(incy) + 1. - - \param[in] incy - incy is INTEGER. - Stride between elements of y. incy must not be zero. - If incy < 0, uses elements of y in reverse order: y(n-1), ..., y(0). - - \return Unconjugated dot product, x^T * y. - REAL/DOUBLE PRECISION - */ -template< typename T, typename TR > -TR dot( - int64_t n, - T const *x, int64_t incx, - T const *y, int64_t incy ) -{ - return cblas_dot( n, x, incx, y, incy ); -} - -/*! \brief Performs the dot product of two complex vectors - - \b Purpose: - - DOTU forms the dot product of two complex vectors. \n - CDOTU = X^T * Y \n - Data precisions supported include SINGLE/DOUBLE PRECISION COMPLEX - - \param[in] n - n is INTEGER - Number of elements in x and y. n >= 0. - - \param[in] x - REAL/DOUBLE PRECISION COMPLEX array. - The n-element vector x, in an array of length (n-1)*abs(incx) + 1. - - \param[in] incx - incx is INTEGER. - Stride between elements of x. incx must not be zero. - If incx < 0, uses elements of x in reverse order: x(n-1), ..., x(0). - - \param[in] y - REAL/DOUBLE PRECISION COMPLEX array. - The n-element vector y, in an array of length (n-1)*abs(incy) + 1. - - \param[in] incy - incy is INTEGER. - Stride between elements of y. incy must not be zero. - If incy < 0, uses elements of y in reverse order: y(n-1), ..., y(0). - - \return Unconjugated dot product, x^T * y. - REAL/DOUBLE PRECISION COMPLEX - */ -template< typename T > -T dotu( - int64_t n, - T const *x, int64_t incx, - T const *y, int64_t incy ) -{ - return cblas_dotu( n, x, incx, y, incy ); -} - -/*! \brief Performs the dot product of two complex vectors - - \b Purpose: - - DOTC forms the dot product of two complex vectors. \n - CDOTU = X^H * Y \n - Data precisions supported include SINGLE/DOUBLE PRECISION COMPLEX - - \param[in] n - n is INTEGER - Number of elements in x and y. n >= 0. - - \param[in] x - REAL/DOUBLE PRECISION COMPLEX array. - The n-element vector x, in an array of length (n-1)*abs(incx) + 1. - - \param[in] incx - incx is INTEGER. - Stride between elements of x. incx must not be zero. - If incx < 0, uses elements of x in reverse order: x(n-1), ..., x(0). - - \param[in] y - REAL/DOUBLE PRECISION COMPLEX array. - The n-element vector y, in an array of length (n-1)*abs(incy) + 1. - - \param[in] incy - incy is INTEGER. - Stride between elements of y. incy must not be zero. - If incy < 0, uses elements of y in reverse order: y(n-1), ..., y(0). - - \return Conjugated dot product, x^H * y. - REAL/DOUBLE PRECISION COMPLEX - */ -template< typename T > -T dotc( - int64_t n, - T const *x, int64_t incx, - T const *y, int64_t incy ) -{ - return cblas_dotc( n, x, incx, y, incy ); -} - -/*! \brief Performs inner product of two vectors with extended precision accumulation - - \b Purpose: - - DOTC forms the inner product of two vectors with extended precision accumulation. 
\n - Data precisions supported include SINGLE PRECISION REAL - - \param[in] n - n is INTEGER\n - number of elements in input vector(s) - - \param[in] alpha - alpha is REAL\n - single precision scalar to be added to inner product - - \param[in] x - x is REAL array, dimension ( 1 + ( n - 1 )*abs( incx ) )\n - single precision vector with n elements - - \param[in] incx - incx is INTEGER\n - storage spacing between elements of x - - \param[in] y - y is REAL array, dimension ( 1 + ( n - 1 )*abs( incx ) )\n - single precision vector with n elements - - \param[in] incy - incy is INTEGER\n - storage spacing between elements of y - - \return S.P. result with dot product accumulated in D.P. - */ -template< typename T > -T sdsdot( - int64_t n, - T alpha, - T const *x, int64_t incx, - T const *y, int64_t incy ) -{ - return cblas_sdsdot( n, alpha, x, incx, y, incy ); -} - -/*! \brief return 2-norm of vectors of arbitrary data types - - \b Purpose: - - NRM2 returns the euclidean norm of a vector via the function name, so that - SNRM2 := sqrt( x'*x ). \n - Data precisions supported include SINGLE PRECISION REAL, DOUBLE PRECISION REAL, - SINGLE PRECISION COMPLEX, DOUBLE PRECISION COMPLEX(COMPLEX*16) - - \param[in] n - n is INTEGER\n - number of elements in input vector(s) - - \param[in] x - x is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array, - dimension ( 1 + ( n - 1 )*abs( incx ) )\n - single precision vector with n elements - - \param[in] incx - incx is INTEGER\n - storage spacing between elements of x - - \return 2-norm of vector - REAL SINGLE/DOUBLE PRECISION - */ -template< typename T > -real_type -nrm2( - int64_t n, - T const * x, int64_t incx ) -{ - return cblas_nrm2( n, x, incx ); -} - -/*! \brief return 1-norm of vector of arbitrary data types - - \b Purpose: - - ASUM takes the sum of the absolute values, uses unrolled loops for - increment equal to one. \n - ASUM := || Re(x) ||_1 + || Im(x) ||_1. \n - Data precisions supported include SINGLE PRECISION REAL, DOUBLE PRECISION REAL, - SINGLE PRECISION COMPLEX, DOUBLE PRECISION COMPLEX(COMPLEX*16) - - \param[in] n - n is INTEGER\n - number of elements in input vector(s) - - \param[in] x - x is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array, - dimension ( 1 + ( n - 1 )*abs( incx ) )\n - single precision vector with n elements - - \param[in] incx - incx is INTEGER\n - storage spacing between elements of x - - \return 1-norm of vector - REAL SINGLE/DOUBLE PRECISION - */ -template< typename T > -real_type -asum( - int64_t n, - T const *x, int64_t incx ) -{ - return cblas_asum( n, x, incx ); -} - -/*! \brief Return Index of infinity-norm of vectors of arbitrary types. - - \b Purpose: - - IAMAX finds the index of the first element having maximum |Re(.)| + |Im(.)|. \n - Data precisions supported include SINGLE PRECISION REAL, DOUBLE PRECISION REAL, - SINGLE PRECISION COMPLEX, DOUBLE PRECISION COMPLEX(COMPLEX*16) - - \param[in] n - n is INTEGER\n - number of elements in input vector(s) - - \param[in] x - x is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array, - dimension ( 1 + ( n - 1 )*abs( incx ) ) \n - single precision vector with n elements - - \param[in] incx - incx is INTEGER\n - storage spacing between elements of x - - \return Index of infinity-norm of vector - INTEGER - */ -template< typename T > -int64_t iamax( - int64_t n, - T const *x, int64_t incx ) -{ - return cblas_iamax( n, x, incx ); -} - -/*! 
\brief Solve General matrix-vector multiply for arbitrary data types - - \b Purpose: - - GEMV performs one of the matrix-vector operations for arbitrary data types - Data precisions supported include SINGLE PRECISION REAL, DOUBLE PRECISION REAL, - SINGLE PRECISION COMPLEX, DOUBLE PRECISION COMPLEX(COMPLEX*16) - - y := alpha*A*x + beta*y, or y := alpha*A**T*x + beta*y, - - where alpha and beta are scalars, x and y are vectors and A is an - m by n matrix. - - \param[in] layout - layout is enum CBLAS_LAYOUT - layout specifies Matrix storage as follows: - layout = CBLAS_LAYOUT::CblasRowMajor or Layout::CblasColMajor. - - \param[in] trans - trans is CBLAS_TRANSPOSE - On entry, trans specifies the operation to be used as follows: \n - trans = CBLAS_TRANSPOSE::CblasNoTrans,y := alpha*A*x + beta*y. \n - trans = CBLAS_TRANSPOSE::CblasTrans, y := alpha*A**T*x + beta*y. \n - trans = CBLAS_TRANSPOSE::CblasConjTrans, y := alpha*A**T*x + beta*y. - - \param[in] m - m is INTEGER - On entry, m specifies the number of rows of the matrix A. - m must be at least zero. - - \param[in] n - n is INTEGER - On entry, n specifies the number of columns of the matrix A. - n must be at least zero. - - \param[in] alpha - alpha is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 - On entry, alpha specifies the scalar alpha. - - \param[in] A - A is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array,dimension : - m-by-n , stored in an lda-by-n array [RowMajor: m-by-lda]. - - \param[in] lda - lda is INTEGER - On entry, lda specifies the Leading dimension of A - lda >= max(1, m) [RowMajor: lda >= max(1, n)]. - - \param[in] x - x is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array,dimension : \n - If trans = CblasNoTrans: - at least ( 1 + ( n - 1 )*abs( incx ) ). \n - Otherwise: - at least ( 1 + ( m - 1 )*abs( incx ) ). - - \param[in] incx - incx is INTEGER - On entry, incx specifies the increment for the elements of x. - incx must not be zero. - - \param[in] beta - beta is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 - On entry, beta specifies the scalar alpha.When beta is - supplied as zero then y need not be set on input. - - \param[in,out] y - y is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array, dimension : \n - If trans = CblasNoTrans: - at least ( 1 + ( m - 1 )*abs( incy ) ). \n - Otherwise: - at least ( 1 + ( n - 1 )*abs( incy ) ). \n - Before entry with beta non-zero, the incremented array y - must contain the vector y. On exit, y is overwritten by the - updated vector y. - - \param[in] incy - incy is INTEGER - On entry, incy specifies the increment for the elements of y. - incy must not be zero. - */ -template< typename T > -void gemv( - CBLAS_ORDER layout, - CBLAS_TRANSPOSE trans, - int64_t m, int64_t n, - T alpha, - T const *A, int64_t lda, - T const *x, int64_t incx, - T beta, - T *y, int64_t incy ) -{ - cblas_gemv(layout, trans, m, n, alpha, A, lda, x, incx, beta, y, incy); -} - -/*! \brief Solve General matrix-vector multiply for arbitrary data types - - \b Purpose: - - GBMV performs one of the matrix-vector operations for arbitrary data types - Data precisions supported include SINGLE PRECISION REAL, DOUBLE PRECISION REAL, - SINGLE PRECISION COMPLEX, DOUBLE PRECISION COMPLEX(COMPLEX*16) - - y := alpha*A*x + beta*y, or y := alpha*A**T*x + beta*y, or - - y := alpha*A**H*x + beta*y, - - where alpha and beta are scalars, x and y are vectors and A is an - m by n matrix with kl sub-diagonals and ku super-diagonals. 
- - \param[in] layout - layout is enum CBLAS_LAYOUT - layout specifies Matrix storage as follows: - layout = CBLAS_LAYOUT::CblasRowMajor or Layout::CblasColMajor. - - \param[in] trans - trans is CBLAS_TRANSPOSE - On entry, trans specifies the operation to be used as follows: \n - trans = CBLAS_TRANSPOSE::CblasNoTrans,y := alpha*A*x + beta*y. \n - trans = CBLAS_TRANSPOSE::CblasTrans, y := alpha*A**T*x + beta*y. \n - trans = CBLAS_TRANSPOSE::CblasConjTrans, y := alpha*A**H*x + beta*y. - - \param[in] m - m is INTEGER - On entry, m specifies the number of rows of the matrix A. - m must be at least zero. - - \param[in] n - n is INTEGER - On entry, n specifies the number of columns of the matrix A. - n must be at least zero. - - \param[in] kl - kl is INTEGER - On entry, kl specifies the number of sub-diagonals of the matrix A. - kl must be at least zero. - - \param[in] ku - ku is INTEGER - On entry, ku specifies the number of super-diagonals of the matrix A. - ku must be at least zero. - - \param[in] alpha - alpha is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 - On entry, alpha specifies the scalar alpha. - - \param[in] A - A is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array,dimension lda-by-n. - Before entry, the leading ( kl + ku + 1 ) by n part of the - array A must contain the matrix of coefficients, supplied - column by column, with the leading diagonal of the matrix in - row ( ku + 1 ) of the array, the first super-diagonal - starting at position 2 in row ku, the first sub-diagonal - starting at position 1 in row ( ku + 2 ), and so on. - Elements in the array A that do not correspond to elements - in the band matrix (such as the top left ku by ku triangle) - are not referenced. - - \param[in] lda - lda is INTEGER - On entry, lda specifies the Leading dimension of A - lda >= ( kl + ku + 1 ) - - \param[in] x - x is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array,dimension : \n - If trans = CblasNoTrans: - at least ( 1 + ( n - 1 )*abs( incx ) ). \n - Otherwise: - at least ( 1 + ( m - 1 )*abs( incx ) ). \n - Before entry, the incremented array x must contain the - vector x. - - \param[in] incx - incx is INTEGER - On entry, incx specifies the increment for the elements of x. - incx must not be zero. - - \param[in] beta - beta is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 - On entry, beta specifies the scalar alpha.When beta is - supplied as zero then y need not be set on input. - - \param[in,out] y - y is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array, dimension : \n - If trans = CblasNoTrans: - at least ( 1 + ( m - 1 )*abs( incy ) ). \n - Otherwise: - at least ( 1 + ( n - 1 )*abs( incy ) ). \n - Before entry with beta non-zero, the incremented array y - must contain the vector y. On exit, y is overwritten by the - updated vector y. - - \param[in] incy - incy is INTEGER - On entry, incy specifies the increment for the elements of y. - incy must not be zero. - */ -template< typename T > -void gbmv( - CBLAS_ORDER layout, - CBLAS_TRANSPOSE trans, - int64_t m, int64_t n, - int64_t kl, int64_t ku, - T alpha, - T const *A, int64_t lda, - T const *x, int64_t incx, - T beta, - T *y, int64_t incy ) -{ - cblas_gbmv(layout, trans, m, n, kl, ku, alpha, A, lda, x, incx, beta, y, incy); -} - -/*! 
\brief Solves Hermitian matrix-vector multiply for arbitrary data types - - \b Purpose: - - HEMV performs one of the matrix-vector operations for arbitrary data types - Data precisions supported include SINGLE PRECISION COMPLEX, - DOUBLE PRECISION COMPLEX(COMPLEX*16) - - y := alpha*A*x + beta*y, - - where alpha and beta are scalars, x and y are n element vectors and - A is an n by n hermitian matrix. - - \param[in] layout - layout is enum CBLAS_LAYOUT - layout specifies Matrix storage as follows: - layout = CBLAS_LAYOUT::CblasRowMajor or Layout::CblasColMajor. - - \param[in] uplo - uplo is enum CBLAS_UPLO - uplo specifies specifies whether the matrix A is an upper or - lower triangular matrix as follows: \n - uplo = CBLAS_UPLO::CblasUpper A is an upper triangular matrix. \n - uplo = CBLAS_UPLO::CblasLower A is a lower triangular matrix. - - \param[in] n - n is INTEGER - On entry, n specifies the order of the matrix A.n must be at least zero. - - \param[in] alpha - alpha is COMPLEX/COMPLEX*16 - On entry, alpha specifies the scalar alpha. - - \param[in] A - A is COMPLEX/COMPLEX*16 array,dimension lda-by-n. \n - Before entry with UPLO = CblasUpper, the leading n by n - upper triangular part of the array A must contain the upper - triangular part of the hermitian matrix and the strictly - lower triangular part of A is not referenced. - Before entry with UPLO = CblasLower, the leading n by n - lower triangular part of the array A must contain the lower - triangular part of the hermitian matrix and the strictly - upper triangular part of A is not referenced. \n - Note that the imaginary parts of the diagonal elements need - not be set and are assumed to be zero. - - \param[in] lda - lda is INTEGER - On entry, lda specifies the Leading dimension of A - lda must be at least max( 1, n ). - - \param[in] x - x is COMPLEX/COMPLEX*16 array,dimension : \n - at least ( 1 + ( n - 1 )*abs( incx ) ). \n - Before entry, the incremented array x must contain the - vector x. - - \param[in] incx - incx is INTEGER - On entry, incx specifies the increment for the elements of x. - incx must not be zero. - - \param[in] beta - beta is COMPLEX/COMPLEX*16 - On entry, beta specifies the scalar alpha.When beta is - supplied as zero then y need not be set on input. - - \param[in,out] y - y is COMPLEX/COMPLEX*16 array, dimension : \n - at least ( 1 + ( n - 1 )*abs( incy ) ). \n - Before entry with beta non-zero, the incremented array y - must contain the vector y. On exit, y is overwritten by the - updated vector y. - - \param[in] incy - incy is INTEGER - On entry, incy specifies the increment for the elements of y. - incy must not be zero. - */ -template< typename T > -void hemv( - CBLAS_ORDER layout, - CBLAS_UPLO uplo, - int64_t n, - T alpha, - T const *A, int64_t lda, - T const *x, int64_t incx, - T beta, - T *y, int64_t incy ) -{ - cblas_hemv(layout, uplo, n, alpha, A, lda, x, incx, beta, y, incy); -} - -/*! \brief Solves Hermitian matrix-vector multiply for arbitrary data types - - \b Purpose: - - HBMV performs one of the matrix-vector operations for arbitrary data types - Data precisions supported include SINGLE PRECISION COMPLEX, - DOUBLE PRECISION COMPLEX(COMPLEX*16) - - y := alpha*A*x + beta*y, - - where alpha and beta are scalars, x and y are n element vectors and - A is an n by n hermitian matrix with k super-diagonals. - - \param[in] layout - layout is enum CBLAS_LAYOUT - layout specifies Matrix storage as follows: - layout = CBLAS_LAYOUT::CblasRowMajor or Layout::CblasColMajor. 
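A small sketch of the hemv wrapper above with std::complex<double> data; the header name and the values are assumptions for illustration only.

#include <cstdint>
#include <complex>
#include <vector>
#include "blis.hh"   // assumed header name

int main()
{
    using cd = std::complex<double>;
    int64_t n = 2;
    // Column-major 2x2 Hermitian matrix; with uplo = CblasUpper only
    // A(0,0), A(0,1) and A(1,1) are read.
    std::vector<cd> A = { cd(2,0), cd(0,0),      // column 0
                          cd(1,1), cd(3,0) };    // column 1
    std::vector<cd> x = { cd(1,0), cd(1,0) };
    std::vector<cd> y = { cd(0,0), cd(0,0) };

    // y := alpha*A*x + beta*y  ->  y = { 3+i, 4-i }
    hemv( CblasColMajor, CblasUpper, n,
          cd(1,0), A.data(), n, x.data(), 1, cd(0,0), y.data(), 1 );
    return 0;
}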
- - \param[in] uplo - uplo is enum CBLAS_UPLO - uplo specifies specifies whether the the upper or lower triangular - part of the band matrix A is being supplied as follows: \n - uplo = CBLAS_UPLO::CblasUpper A is an upper triangular matrix. \n - uplo = CBLAS_UPLO::CblasLower A is a lower triangular matrix. - - \param[in] n - n is INTEGER - On entry, n specifies the order of the matrix A.n must be at least zero. - - \param[in] k - k is INTEGER - On entry, k specifies the number of super-diagonals of the matrix A. - k must be at least zero. - - \param[in] alpha - alpha is COMPLEX/COMPLEX*16 - On entry, alpha specifies the scalar alpha. - - \param[in] A - A is COMPLEX/COMPLEX*16 array,dimension lda-by-n. \n - Before entry with UPLO = CblasUpper, the leading ( k + 1 ) - by n part of the array A must contain the upper triangular - band part of the hermitian matrix, supplied column by - column, with the leading diagonal of the matrix in row - ( k + 1 ) of the array, the first super-diagonal starting at - position 2 in row k, and so on. The top left k by k triangle - of the array A is not referenced. \n - Before entry with UPLO = CblasLower, the leading ( k + 1 ) - by n part of the array A must contain the lower triangular - band part of the hermitian matrix, supplied column by - column, with the leading diagonal of the matrix in row 1 of - the array, the first sub-diagonal starting at position 1 in - row 2, and so on. The bottom right k by k triangle of the - array A is not referenced. \n - Note that the imaginary parts of the diagonal elements need - not be set and are assumed to be zero. - - \param[in] lda - lda is INTEGER - On entry, lda specifies the Leading dimension of A - lda must be at least ( k + 1 ). - - \param[in] x - x is COMPLEX/COMPLEX*16 array,dimension : \n - at least ( 1 + ( n - 1 )*abs( incx ) ). \n - Before entry, the incremented array x must contain the - vector x. - - \param[in] incx - incx is INTEGER - On entry, incx specifies the increment for the elements of x. - incx must not be zero. - - \param[in] beta - beta is COMPLEX/COMPLEX*16 - On entry, beta specifies the scalar alpha. - - \param[in,out] y - y is COMPLEX/COMPLEX*16 array, dimension : \n - at least ( 1 + ( n - 1 )*abs( incy ) ). \n - Before entry with beta non-zero, the incremented array y - must contain the vector y. On exit, y is overwritten by the - updated vector y. - - \param[in] incy - incy is INTEGER - On entry, incy specifies the increment for the elements of y. - incy must not be zero. - */ -template< typename T > -void hbmv( - CBLAS_ORDER layout, - CBLAS_UPLO uplo, - int64_t n, int64_t k, - T alpha, - T const *A, int64_t lda, - T const *x, int64_t incx, - T beta, - T *y, int64_t incy ) -{ - cblas_hbmv(layout, uplo, n, k, alpha, A, lda, x, incx, beta, y, incy); -} - -/*! \brief Solves Hermitian matrix-vector multiply for arbitrary data types - - \b Purpose: - - HPMV performs one of the matrix-vector operations for arbitrary data types - Data precisions supported include SINGLE PRECISION COMPLEX, - DOUBLE PRECISION COMPLEX(COMPLEX*16) - - y := alpha*A*x + beta*y, - - where alpha and beta are scalars, x and y are n element vectors and - A is an n by n hermitian matrix, supplied in packed form. - - \param[in] layout - layout is enum CBLAS_LAYOUT - layout specifies Matrix storage as follows: - layout = CBLAS_LAYOUT::CblasRowMajor or Layout::CblasColMajor. 
- - \param[in] uplo - uplo is enum CBLAS_UPLO - uplo specifies specifies whether the the upper or lower triangular - part of the band matrix A is supplied in the packed array Ap as follows: \n - uplo = CBLAS_UPLO::CblasUpper A is an upper triangular matrix. \n - uplo = CBLAS_UPLO::CblasLower A is a lower triangular matrix. - - \param[in] n - n is INTEGER - On entry, n specifies the order of the matrix A.n must be at least zero. - - \param[in] alpha - alpha is COMPLEX/COMPLEX*16 - On entry, alpha specifies the scalar alpha. - - \param[in] Ap - Ap is COMPLEX/COMPLEX*16 array,dimension atleast ( ( n*( n + 1 ) )/2 ). \n - Before entry with UPLO = CblasUpper, the array Ap must - contain the upper triangular part of the hermitian matrix - packed sequentially, column by column, so that Ap( 1 ) - contains a( 1, 1 ), Ap( 2 ) and Ap( 3 ) contain a( 1, 2 ) - and a( 2, 2 ) respectively, and so on. \n - Before entry with UPLO = CblasLower, the array Ap must - contain the lower triangular part of the hermitian matrix - packed sequentially, column by column, so that Ap( 1 ) - contains a( 1, 1 ), Ap( 2 ) and Ap( 3 ) contain a( 2, 1 ) - and a( 3, 1 ) respectively, and so on. \n - Note that the imaginary parts of the diagonal elements need - not be set and are assumed to be zero. - - \param[in] x - x is COMPLEX/COMPLEX*16 array,dimension : \n - at least ( 1 + ( n - 1 )*abs( incx ) ). \n - Before entry, the incremented array x must contain the - vector x. - - \param[in] incx - incx is INTEGER - On entry, incx specifies the increment for the elements of x. - incx must not be zero. - - \param[in] beta - beta is COMPLEX/COMPLEX*16 - On entry, beta specifies the scalar alpha.When beta is - supplied as zero then y need not be set on input. - - \param[in,out] y - y is COMPLEX/COMPLEX*16 array, dimension : \n - at least ( 1 + ( n - 1 )*abs( incy ) ). \n - Before entry with beta non-zero, the incremented array y - must contain the vector y. On exit, y is overwritten by the - updated vector y. - - \param[in] incy - incy is INTEGER - On entry, incy specifies the increment for the elements of y. - incy must not be zero. - */ -template< typename T > -void hpmv( - CBLAS_ORDER layout, - CBLAS_UPLO uplo, - int64_t n, - T alpha, - T const *Ap, - T const *x, int64_t incx, - T beta, - T *y, int64_t incy ) -{ - cblas_hpmv(layout, uplo, n, alpha, Ap, x, incx, beta, y, incy); -} - -/*! \brief Solves Symmetric matrix-vector multiply for arbitrary data types - - \b Purpose: - - SYMV performs one of the matrix-vector operations for arbitrary data types - Data precisions supported include SINGLE PRECISION REAL, DOUBLE PRECISION REAL - - y := alpha*A*x + beta*y, - - where alpha and beta are scalars, x and y are n element vectors and - A is an n by n symmetric matrix. - - \param[in] layout - layout is enum CBLAS_LAYOUT - layout specifies Matrix storage as follows: - layout = CBLAS_LAYOUT::CblasRowMajor or Layout::CblasColMajor. - - \param[in] uplo - uplo is enum CBLAS_UPLO - uplo specifies specifies whether the matrix A is an upper or - lower triangular matrix as follows: \n - uplo = CBLAS_UPLO::CblasUpper A is an upper triangular matrix. \n - uplo = CBLAS_UPLO::CblasLower A is a lower triangular matrix. - - \param[in] n - n is INTEGER - On entry, n specifies the order of the matrix A.n must be at least zero. - - \param[in] alpha - alpha is SINGLE/DOUBLE PRECISION REAL - On entry, alpha specifies the scalar alpha. - - \param[in] A - A is SINGLE/DOUBLE PRECISION REAL array,dimension lda-by-n. 
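To make the packed layout concrete, here is a sketch using the hpmv wrapper above with a 2x2 Hermitian matrix stored column by column in Ap; the header name and data are illustrative assumptions.

#include <cstdint>
#include <complex>
#include <vector>
#include "blis.hh"   // assumed header name

int main()
{
    using cd = std::complex<double>;
    int64_t n = 2;
    // Upper-triangle packed storage, column by column:
    // Ap = { a(0,0), a(0,1), a(1,1) }
    std::vector<cd> Ap = { cd(2,0), cd(1,-1), cd(4,0) };
    std::vector<cd> x  = { cd(1,0), cd(0,1) };
    std::vector<cd> y  = { cd(0,0), cd(0,0) };

    // y := alpha*A*x + beta*y with A supplied in packed form -> y = { 3+i, 1+5i }
    hpmv( CblasColMajor, CblasUpper, n,
          cd(1,0), Ap.data(), x.data(), 1, cd(0,0), y.data(), 1 );
    return 0;
}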
\n - Before entry with UPLO = CblasUpper, the leading n by n - upper triangular part of the array A must contain the upper - triangular part of the symmetric matrix and the strictly - lower triangular part of A is not referenced. - Before entry with UPLO = CblasLower, the leading n by n - lower triangular part of the array A must contain the lower - triangular part of the symmetric matrix and the strictly - upper triangular part of A is not referenced. \n - - \param[in] lda - lda is INTEGER - On entry, lda specifies the Leading dimension of A - lda must be at least max( 1, n ). - - \param[in] x - x is SINGLE/DOUBLE PRECISION REAL array,dimension : \n - at least ( 1 + ( n - 1 )*abs( incx ) ). \n - Before entry, the incremented array x must contain the - vector x. - - \param[in] incx - incx is INTEGER - On entry, incx specifies the increment for the elements of x. - incx must not be zero. - - \param[in] beta - beta is SINGLE/DOUBLE PRECISION REAL - On entry, beta specifies the scalar alpha.When beta is - supplied as zero then y need not be set on input. - - \param[in,out] y - y is SINGLE/DOUBLE PRECISION REAL array, dimension : \n - at least ( 1 + ( n - 1 )*abs( incy ) ). \n - Before entry with beta non-zero, the incremented array y - must contain the vector y. On exit, y is overwritten by the - updated vector y. - - \param[in] incy - incy is INTEGER - On entry, incy specifies the increment for the elements of y. - incy must not be zero. - */ -template< typename T > -void symv( - CBLAS_ORDER layout, - CBLAS_UPLO uplo, - int64_t n, - T alpha, - T const *A, int64_t lda, - T const *x, int64_t incx, - T beta, - T *y, int64_t incy ) -{ - cblas_symv(layout, uplo, n, alpha, A, lda, x, incx, beta, y, incy); -} - -/*! \brief Solves symmetric matrix-vector multiply for arbitrary data types - - \b Purpose: - - SBMV performs one of the matrix-vector operations for arbitrary data types - Data precisions supported include SINGLE PRECISION REAL, DOUBLE PRECISION REAL - - y := alpha*A*x + beta*y, - - where alpha and beta are scalars, x and y are n element vectors and - A is an n by n symmetric matrix with k super-diagonals. - - \param[in] layout - layout is enum CBLAS_LAYOUT - layout specifies Matrix storage as follows: - layout = CBLAS_LAYOUT::CblasRowMajor or Layout::CblasColMajor. - - \param[in] uplo - uplo is enum CBLAS_UPLO - uplo specifies specifies whether the the upper or lower triangular - part of the band matrix A is being supplied as follows: \n - uplo = CBLAS_UPLO::CblasUpper A is an upper triangular matrix. \n - uplo = CBLAS_UPLO::CblasLower A is a lower triangular matrix. - - \param[in] n - n is INTEGER - On entry, n specifies the order of the matrix A.n must be at least zero. - - \param[in] k - k is INTEGER - On entry, k specifies the number of super-diagonals of the matrix A. - k must be at least zero. - - \param[in] alpha - alpha is SINGLE/DOUBLE PRECISION REAL - On entry, alpha specifies the scalar alpha. - - \param[in] A - A is SINGLE/DOUBLE PRECISION REAL array,dimension lda-by-n. \n - Before entry with UPLO = CblasUpper, the leading ( k + 1 ) - by n part of the array A must contain the upper triangular - band part of the symmetric matrix, supplied column by - column, with the leading diagonal of the matrix in row - ( k + 1 ) of the array, the first super-diagonal starting at - position 2 in row k, and so on. The top left k by k triangle - of the array A is not referenced. 
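A brief sketch of the symv wrapper above; only the selected triangle of A needs to be populated. The header name and values are assumptions.

#include <cstdint>
#include <vector>
#include "blis.hh"   // assumed header name

int main()
{
    int64_t n = 3;
    // Column-major 3x3 symmetric matrix; only the lower triangle is read
    // when uplo = CblasLower, so the strictly upper entries may be left zero.
    std::vector<double> A = { 4.0, 1.0, 2.0,
                              0.0, 5.0, 3.0,
                              0.0, 0.0, 6.0 };
    std::vector<double> x = { 1.0, 1.0, 1.0 };
    std::vector<double> y(n, 0.0);

    // y := 1.0*A*x + 0.0*y with symmetric A  ->  y = { 7, 9, 11 }
    symv( CblasColMajor, CblasLower, n,
          1.0, A.data(), n, x.data(), 1, 0.0, y.data(), 1 );
    return 0;
}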
\n - Before entry with UPLO = CblasLower, the leading ( k + 1 ) - by n part of the array A must contain the lower triangular - band part of the symmetric matrix, supplied column by - column, with the leading diagonal of the matrix in row 1 of - the array, the first sub-diagonal starting at position 1 in - row 2, and so on. The bottom right k by k triangle of the - array A is not referenced. \n - Note that the imaginary parts of the diagonal elements need - not be set and are assumed to be zero. - - \param[in] lda - lda is INTEGER - On entry, lda specifies the Leading dimension of A - lda must be at least ( k + 1 ). - - \param[in] x - x is SINGLE/DOUBLE PRECISION REAL array,dimension : \n - at least ( 1 + ( n - 1 )*abs( incx ) ). \n - Before entry, the incremented array x must contain the - vector x. - - \param[in] incx - incx is INTEGER - On entry, incx specifies the increment for the elements of x. - incx must not be zero. - - \param[in] beta - beta is SINGLE/DOUBLE PRECISION REAL - On entry, beta specifies the scalar alpha. - - \param[in,out] y - y is SINGLE/DOUBLE PRECISION REAL array, dimension : \n - at least ( 1 + ( n - 1 )*abs( incy ) ). \n - Before entry with beta non-zero, the incremented array y - must contain the vector y. On exit, y is overwritten by the - updated vector y. - - \param[in] incy - incy is INTEGER - On entry, incy specifies the increment for the elements of y. - incy must not be zero. - */ -template< typename T > -void sbmv( - CBLAS_ORDER layout, - CBLAS_UPLO uplo, - int64_t n, int64_t k, - T alpha, - T const *A, int64_t lda, - T const *x, int64_t incx, - T beta, - T *y, int64_t incy ) -{ - cblas_sbmv(layout, uplo, n, k, alpha, A, lda, x, incx, beta, y, incy); -} - -/*! \brief Solves symmetric matrix-vector multiply for arbitrary data types - - \b Purpose: - - SPMV performs one of the matrix-vector operations for arbitrary data types - Data precisions supported include SINGLE PRECISION REAL, DOUBLE PRECISION REAL - - y := alpha*A*x + beta*y, - - where alpha and beta are scalars, x and y are n element vectors and - A is an n by n symmetric matrix, supplied in packed form. - - \param[in] layout - layout is enum CBLAS_LAYOUT - layout specifies Matrix storage as follows: - layout = CBLAS_LAYOUT::CblasRowMajor or Layout::CblasColMajor. - - \param[in] uplo - uplo is enum CBLAS_UPLO - uplo specifies specifies whether the the upper or lower triangular - part of the band matrix A is supplied in the packed array Ap as follows: \n - uplo = CBLAS_UPLO::CblasUpper A is an upper triangular matrix. \n - uplo = CBLAS_UPLO::CblasLower A is a lower triangular matrix. - - \param[in] n - n is INTEGER - On entry, n specifies the order of the matrix A.n must be at least zero. - - \param[in] alpha - alpha is SINGLE/DOUBLE PRECISION REAL - On entry, alpha specifies the scalar alpha. - - \param[in] Ap - Ap is SINGLE/DOUBLE PRECISION REAL array,dimension atleast ( ( n*( n + 1 ) )/2 ). \n - Before entry with UPLO = CblasUpper, the array Ap must - contain the upper triangular part of the symmetric matrix - packed sequentially, column by column, so that Ap( 1 ) - contains a( 1, 1 ), Ap( 2 ) and Ap( 3 ) contain a( 1, 2 ) - and a( 2, 2 ) respectively, and so on. \n - Before entry with UPLO = CblasLower, the array Ap must - contain the lower triangular part of the symmetric matrix - packed sequentially, column by column, so that Ap( 1 ) - contains a( 1, 1 ), Ap( 2 ) and Ap( 3 ) contain a( 2, 1 ) - and a( 3, 1 ) respectively, and so on. 
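Band storage is the subtle part here, so the following sketch of the sbmv wrapper above spells out the (k+1)-by-n band layout for a symmetric tridiagonal matrix; the header name and data are illustrative assumptions.

#include <cstdint>
#include <vector>
#include "blis.hh"   // assumed header name

int main()
{
    int64_t n = 3, k = 1, lda = k + 1;
    // Upper band storage, column-major: row k of the band array holds the
    // diagonal, row k-1 holds the first super-diagonal (its entry in
    // column 0 is not referenced).
    std::vector<double> AB = { 0.0, 2.0,    // column 0: { unused, A(0,0) }
                               1.0, 2.0,    // column 1: { A(0,1), A(1,1) }
                               1.0, 2.0 };  // column 2: { A(1,2), A(2,2) }
    std::vector<double> x = { 1.0, 1.0, 1.0 };
    std::vector<double> y(n, 0.0);

    // y := A*x for the tridiagonal A = [[2,1,0],[1,2,1],[0,1,2]]  ->  y = { 3, 4, 3 }
    sbmv( CblasColMajor, CblasUpper, n, k,
          1.0, AB.data(), lda, x.data(), 1, 0.0, y.data(), 1 );
    return 0;
}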
\n - Note that the imaginary parts of the diagonal elements need - not be set and are assumed to be zero. - - \param[in] x - x is SINGLE/DOUBLE PRECISION REAL array,dimension : \n - at least ( 1 + ( n - 1 )*abs( incx ) ). \n - Before entry, the incremented array x must contain the - vector x. - - \param[in] incx - incx is INTEGER - On entry, incx specifies the increment for the elements of x. - incx must not be zero. - - \param[in] beta - beta is SINGLE/DOUBLE PRECISION REAL - On entry, beta specifies the scalar alpha.When beta is - supplied as zero then y need not be set on input. - - \param[in,out] y - y is SINGLE/DOUBLE PRECISION REAL array, dimension : \n - at least ( 1 + ( n - 1 )*abs( incy ) ). \n - Before entry with beta non-zero, the incremented array y - must contain the vector y. On exit, y is overwritten by the - updated vector y. - - \param[in] incy - incy is INTEGER - On entry, incy specifies the increment for the elements of y. - incy must not be zero. - */ -template< typename T > -void spmv( - CBLAS_ORDER layout, - CBLAS_UPLO uplo, - int64_t n, - T alpha, - T const *Ap, - T const *x, int64_t incx, - T beta, - T *y, int64_t incy ) -{ - cblas_spmv(layout, uplo, n, alpha, Ap, x, incx, beta, y, incy); -} - -/*! \brief Solve the one of the matrix-vector operations for arbitrary data types - - \b Purpose: - - TRMV performs one of the matrix-vector operations for arbitrary data types - Data precisions supported include SINGLE PRECISION REAL, DOUBLE PRECISION REAL, - SINGLE PRECISION COMPLEX, DOUBLE PRECISION COMPLEX(COMPLEX*16) - - x := A*x, or x := A**T*x, - - where x is an n element vector and A is an n by n unit, or non-unit, - upper or lower triangular matrix. - - \param[in] layout - layout is enum CBLAS_ORDER - layout specifies Matrix storage as follows: - layout = CBLAS_ORDER::CblasRowMajor or Layout::CblasColMajor. - - \param[in] uplo - uplo is enum CBLAS_UPLO. - uplo specifies specifies whether the matrix A is an upper or - lower triangular matrix as follows: \n - uplo = CBLAS_UPLO::CblasUpper A is an upper triangular matrix. \n - uplo = CBLAS_UPLO::CblasLower A is a lower triangular matrix. - - \param[in] trans - trans is CBLAS_TRANSPOSE - On entry, trans specifies the operation to be performed as follows: - trans = CBLAS_TRANSPOSE::CblasNoTrans, x := A*x. \n - trans = CBLAS_TRANSPOSE::CblasTrans, x := A**T*x. \n - trans = CBLAS_TRANSPOSE::CblasConjTrans, x := A**T*x. - - \param[in] diag - diag is enum CBLAS_DIAG - diag specifies specifies whether or not A is unit triangular - as follows: \n - diag = CBLAS_DIAG::CblasUnit A is assumed to be unit triangular.\n - diag = CBLAS_DIAG::CblasNonUnit A is not assumed to be unit - triangular. - - \param[in] n - n is INTEGER - On entry, n specifies the order of the matrix A.n must be at least zero. - - \param[in] A - A is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array,dimension ( lda, n )\n - Before entry with UPLO = CblasUpper, the leading n by n - upper triangular part of the array A must contain the upper - triangular matrix and the strictly lower triangular part of - A is not referenced. \n - Before entry with UPLO = CblasLower, the leading n by n - lower triangular part of the array A must contain the lower - triangular matrix and the strictly upper triangular part of - A is not referenced. \n - Note that when DIAG = CblasUnit, the diagonal elements of - A are not referenced either, but are assumed to be unity. 
- - \param[in] lda - lda is INTEGER - On entry, lda specifies the Leading dimension of A - lda must be at least max( 1, n ). - - \param[in, out] x - x is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array,dimension : \n - at least ( 1 + ( n - 1 )*abs( incx ) ). \n - Before entry, the incremented array x must contain the - vector x.On exit, x is overwritten with the transformed vector x. - - \param[in] incx - incx is INTEGER - On entry, incx specifies the increment for the elements of x. - incx must not be zero. - */ -template< typename T > -void trmv( - CBLAS_ORDER layout, - CBLAS_UPLO uplo, - CBLAS_TRANSPOSE trans, - CBLAS_DIAG diag, - int64_t n, - T const *A, int64_t lda, - T *x, int64_t incx ) -{ - cblas_trmv(layout, uplo, trans, diag, n, A, lda, x, incx); -} - -/*! \brief Solve the one of the matrix-vector operations for arbitrary data types - - \b Purpose: - - TBMV performs one of the matrix-vector operations for arbitrary data types - Data precisions supported include SINGLE PRECISION REAL, DOUBLE PRECISION REAL, - SINGLE PRECISION COMPLEX, DOUBLE PRECISION COMPLEX(COMPLEX*16) - - x := A*x, or x := A**T*x, - - where x is an n element vector and A is an n by n unit, or non-unit, - upper or lower triangular band matrix, with ( k + 1 ) diagonals. - - \param[in] layout - layout is enum CBLAS_ORDER - layout specifies Matrix storage as follows: - layout = CBLAS_ORDER::CblasRowMajor or Layout::CblasColMajor. - - \param[in] uplo - uplo is enum CBLAS_UPLO. - uplo specifies specifies whether the matrix A is an upper or - lower triangular matrix as follows: \n - uplo = CBLAS_UPLO::CblasUpper A is an upper triangular matrix. \n - uplo = CBLAS_UPLO::CblasLower A is a lower triangular matrix. - - \param[in] trans - trans is CBLAS_TRANSPOSE - On entry, trans specifies the operation to be performed as follows: - trans = CBLAS_TRANSPOSE::CblasNoTrans, x := A*x. \n - trans = CBLAS_TRANSPOSE::CblasTrans, x := A**T*x. \n - trans = CBLAS_TRANSPOSE::CblasConjTrans, x := A**T*x. - - \param[in] diag - diag is enum CBLAS_DIAG - diag specifies specifies whether or not A is unit triangular - as follows: \n - diag = CBLAS_DIAG::CblasUnit A is assumed to be unit triangular.\n - diag = CBLAS_DIAG::CblasNonUnit A is not assumed to be unit - triangular. - - \param[in] n - n is INTEGER - On entry, n specifies the order of the matrix A.n must be at least zero. - - \param[in] k - k is INTEGER - On entry with UPLO = CblasUpper, k specifies the number of - super-diagonals of the matrix A. - On entry with UPLO = CblasLower, k specifies the number of - sub-diagonals of the matrix A. - k must at least zero. - - \param[in] A - A is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array,dimension ( lda, n )\n - Before entry with UPLO = CblasUpper, the leading ( k + 1 ) - by n part of the array A must contain the upper triangular - band part of the matrix of coefficients, supplied column by - column, with the leading diagonal of the matrix in row - ( k + 1 ) of the array, the first super-diagonal starting at - position 2 in row k, and so on. The top left k by k triangle - of the array A is not referenced. \n - Before entry with UPLO = CblasLower, the leading ( k + 1 ) - by n part of the array A must contain the lower triangular - band part of the matrix of coefficients, supplied column by - column, with the leading diagonal of the matrix in row 1 of - the array, the first sub-diagonal starting at position 1 in - row 2, and so on. The bottom right k by k triangle of the - array A is not referenced. 
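A short sketch of the in-place triangular multiply performed by the trmv wrapper above; the header name and values are assumptions.

#include <cstdint>
#include <vector>
#include "blis.hh"   // assumed header name

int main()
{
    int64_t n = 3;
    // Column-major upper-triangular matrix; the strictly lower part is not read.
    std::vector<double> A = { 1.0, 0.0, 0.0,
                              2.0, 3.0, 0.0,
                              4.0, 5.0, 6.0 };
    std::vector<double> x = { 1.0, 1.0, 1.0 };

    // x := A*x, performed in place  ->  x = { 7, 8, 6 }
    trmv( CblasColMajor, CblasUpper, CblasNoTrans, CblasNonUnit,
          n, A.data(), n, x.data(), 1 );
    return 0;
}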
\n - Note that when DIAG = CblasUnit the elements of the array A - corresponding to the diagonal elements of the matrix are not - referenced, but are assumed to be unity. - - \param[in] lda - lda is INTEGER - On entry, lda specifies the Leading dimension of A - lda must be at least max( 1, ( k + 1 ) ). - - \param[in, out] x - x is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array,dimension : \n - at least ( 1 + ( n - 1 )*abs( incx ) ). \n - Before entry, the incremented array x must contain the - vector x.On exit, x is overwritten with the transformed vector x. - - \param[in] incx - incx is INTEGER - On entry, incx specifies the increment for the elements of x. - incx must not be zero. - */ -template< typename T > -void tbmv( - CBLAS_ORDER layout, - CBLAS_UPLO uplo, - CBLAS_TRANSPOSE trans, - CBLAS_DIAG diag, - int64_t n, int64_t k, - T const *A, int64_t lda, - T *x, int64_t incx ) -{ - cblas_tbmv(layout, uplo, trans, diag, n, k, A, lda, x, incx); -} - - -/*! \brief Solve the one of the matrix-vector operations for arbitrary data types - - \b Purpose: - - TPMV performs one of the matrix-vector operations for arbitrary data types - Data precisions supported include SINGLE PRECISION REAL, DOUBLE PRECISION REAL, - SINGLE PRECISION COMPLEX, DOUBLE PRECISION COMPLEX(COMPLEX*16) - - x := A*x, or x := A**T*x, - - where x is an n element vector and A is an n by n unit, or non-unit, - upper or lower triangular matrix, supplied in packed form. - - \param[in] layout - layout is enum CBLAS_ORDER - layout specifies Matrix storage as follows: - layout = CBLAS_ORDER::CblasRowMajor or Layout::CblasColMajor. - - \param[in] uplo - uplo is enum CBLAS_UPLO. - uplo specifies specifies whether the matrix A is an upper or - lower triangular matrix as follows: \n - uplo = CBLAS_UPLO::CblasUpper A is an upper triangular matrix. \n - uplo = CBLAS_UPLO::CblasLower A is a lower triangular matrix. - - \param[in] trans - trans is CBLAS_TRANSPOSE - On entry, trans specifies the operation to be performed as follows: - trans = CBLAS_TRANSPOSE::CblasNoTrans, x := A*x. \n - trans = CBLAS_TRANSPOSE::CblasTrans, x := A**T*x. \n - trans = CBLAS_TRANSPOSE::CblasConjTrans, x := A**T*x. - - \param[in] diag - diag is enum CBLAS_DIAG - diag specifies specifies whether or not A is unit triangular - as follows: \n - diag = CBLAS_DIAG::CblasUnit A is assumed to be unit triangular.\n - diag = CBLAS_DIAG::CblasNonUnit A is not assumed to be unit - triangular. - - \param[in] n - n is INTEGER - On entry, n specifies the order of the matrix A.n must be at least zero. - - \param[in] Ap - Ap is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array,dimension - ( ( n*( n + 1 ) )/2 ). \n - Before entry with UPLO = CblasUpper, the array Ap must - contain the upper triangular matrix packed sequentially, - column by column, so that Ap( 1 ) contains a( 1, 1 ), - Ap( 2 ) and Ap( 3 ) contain a( 1, 2 ) and a( 2, 2 ) - respectively, and so on. \n - Before entry with UPLO = CblasLower, the array Ap must - contain the lower triangular matrix packed sequentially, - column by column, so that Ap( 1 ) contains a( 1, 1 ), - Ap( 2 ) and Ap( 3 ) contain a( 2, 1 ) and a( 3, 1 ) - respectively, and so on. \n - Note that when DIAG = CblasUnit, the diagonal elements of - A are not referenced, but are assumed to be unity. - - \param[in, out] x - x is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array,dimension : \n - at least ( 1 + ( n - 1 )*abs( incx ) ). 
\n - Before entry, the incremented array x must contain the - vector x.On exit, x is overwritten with the transformed vector x. - - \param[in] incx - incx is INTEGER - On entry, incx specifies the increment for the elements of x. - incx must not be zero. - */ -template< typename T > -void tpmv( - CBLAS_ORDER layout, - CBLAS_UPLO uplo, - CBLAS_TRANSPOSE trans, - CBLAS_DIAG diag, - int64_t n, - T const *Ap, - T *x, int64_t incx ) -{ - cblas_tpmv(layout, uplo, trans, diag, n, Ap, x, incx); -} - -/*! \brief Solve the one of the triangular matrix-vector equation for arbitrary data types - - \b Purpose: - - TRSV solves one of the systems of equations for arbitrary data types - Data precisions supported include SINGLE PRECISION REAL, DOUBLE PRECISION REAL, - SINGLE PRECISION COMPLEX, DOUBLE PRECISION COMPLEX(COMPLEX*16) - - A*x = b, or A**T*x = b, - - where b and x are n element vectors and A is an n by n unit, or - non-unit, upper or lower triangular matrix - - \param[in] layout - layout is enum CBLAS_ORDER - layout specifies Matrix storage as follows: - layout = CBLAS_ORDER::CblasRowMajor or Layout::CblasColMajor. - - \param[in] uplo - uplo is enum CBLAS_UPLO. - uplo specifies specifies whether the matrix A is an upper or - lower triangular matrix as follows: \n - uplo = CBLAS_UPLO::CblasUpper A is an upper triangular matrix. \n - uplo = CBLAS_UPLO::CblasLower A is a lower triangular matrix. - - \param[in] trans - trans is CBLAS_TRANSPOSE - On entry, trans specifies the operation to be performed as follows: - trans = CBLAS_TRANSPOSE::CblasNoTrans, A*x = b. \n - trans = CBLAS_TRANSPOSE::CblasTrans, A**T*x = b. \n - trans = CBLAS_TRANSPOSE::CblasConjTrans, A**T*x = b. - - \param[in] diag - diag is enum CBLAS_DIAG - diag specifies specifies whether or not A is unit triangular - as follows: \n - diag = CBLAS_DIAG::CblasUnit A is assumed to be unit triangular.\n - diag = CBLAS_DIAG::CblasNonUnit A is not assumed to be unit - triangular. - - \param[in] n - n is INTEGER - On entry, n specifies the order of the matrix A.n must be at least zero. - - \param[in] A - A is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array,dimension ( lda, n )\n - Before entry with UPLO = CblasUpper, the leading n by n - upper triangular part of the array A must contain the upper - triangular matrix and the strictly lower triangular part of - A is not referenced. \n - Before entry with UPLO = CblasLower, the leading n by n - lower triangular part of the array A must contain the lower - triangular matrix and the strictly upper triangular part of - A is not referenced. \n - Note that when DIAG = CblasUnit, the diagonal elements of - A are not referenced either, but are assumed to be unity. - - \param[in] lda - lda is INTEGER - On entry, lda specifies the Leading dimension of A - lda must be at least max( 1, n ). - - \param[in, out] x - x is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array,dimension : - at least ( 1 + ( n - 1 )*abs( incx ) ). \n - Before entry, the incremented array x must contain the - element right-hand side vector b.On exit, x is overwritten - with the transformed vector x. - - \param[in] incx - incx is INTEGER - On entry, incx specifies the increment for the elements of x. - incx must not be zero. - */ -template< typename T > -void trsv( - CBLAS_ORDER layout, - CBLAS_UPLO uplo, - CBLAS_TRANSPOSE trans, - CBLAS_DIAG diag, - int64_t n, - T const *A, int64_t lda, - T *x, int64_t incx ) -{ - cblas_trsv(layout, uplo, trans, diag, n, A, lda, x, incx); -} - -/*! 
\brief Solve the one of the triangular matrix-vector equation for arbitrary data types - - \b Purpose: - - TBSV solves one of the systems of equations for arbitrary data types - Data precisions supported include SINGLE PRECISION REAL, DOUBLE PRECISION REAL, - SINGLE PRECISION COMPLEX, DOUBLE PRECISION COMPLEX(COMPLEX*16) - - A*x = b, or A**T*x = b, - - where b and x are n element vectors and A is an n by n unit, or - non-unit, upper or lower triangular band matrix, with ( k + 1 ) - diagonals. - - \param[in] layout - layout is enum CBLAS_ORDER - layout specifies Matrix storage as follows: - layout = CBLAS_ORDER::CblasRowMajor or Layout::CblasColMajor. - - \param[in] uplo - uplo is enum CBLAS_UPLO. - uplo specifies specifies whether the matrix A is an upper or - lower triangular matrix as follows: \n - uplo = CBLAS_UPLO::CblasUpper A is an upper triangular matrix. \n - uplo = CBLAS_UPLO::CblasLower A is a lower triangular matrix. - - \param[in] trans - trans is CBLAS_TRANSPOSE - On entry, trans specifies the operation to be performed as follows: - trans = CBLAS_TRANSPOSE::CblasNoTrans, A*x = b. \n - trans = CBLAS_TRANSPOSE::CblasTrans, A**T*x = b. \n - trans = CBLAS_TRANSPOSE::CblasConjTrans, A**T*x = b. - - \param[in] diag - diag is enum CBLAS_DIAG - diag specifies specifies whether or not A is unit triangular - as follows: \n - diag = CBLAS_DIAG::CblasUnit A is assumed to be unit triangular.\n - diag = CBLAS_DIAG::CblasNonUnit A is not assumed to be unit - triangular. - - \param[in] n - n is INTEGER - On entry, n specifies the order of the matrix A.n must be at least zero. - - \param[in] k - k is INTEGER - On entry with UPLO = CblasUpper, k specifies the number of - super-diagonals of the matrix A. - On entry with UPLO = CblasLower, k specifies the number of - sub-diagonals of the matrix A. - k must at least zero. - - \param[in] A - A is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array,dimension ( lda, n )\n - Before entry with UPLO = CblasUpper, the leading ( k + 1 ) - by n part of the array A must contain the upper triangular - band part of the matrix of coefficients, supplied column by - column, with the leading diagonal of the matrix in row - ( k + 1 ) of the array, the first super-diagonal starting at - position 2 in row k, and so on. The top left k by k triangle - of the array A is not referenced. \n - Before entry with UPLO = CblasLower, the leading ( k + 1 ) - by n part of the array A must contain the lower triangular - band part of the matrix of coefficients, supplied column by - column, with the leading diagonal of the matrix in row 1 of - the array, the first sub-diagonal starting at position 1 in - row 2, and so on. The bottom right k by k triangle of the - array A is not referenced. \n - Note that when DIAG = CblasUnit, the elements of the array A - corresponding to the diagonal elements of the matrix are not - referenced, but are assumed to be unity. - - \param[in] lda - lda is INTEGER - On entry, lda specifies the Leading dimension of A - lda must be at least max( 1, k+1 ). - - \param[in, out] x - x is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array,dimension : - at least ( 1 + ( n - 1 )*abs( incx ) ). \n - Before entry, the incremented array x must contain the - element right-hand side vector b.On exit, x is overwritten - with the solution vector x. - - \param[in] incx - incx is INTEGER - On entry, incx specifies the increment for the elements of x. - incx must not be zero. 
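For comparison with the banded and packed solvers, here is a minimal sketch of the dense trsv wrapper above performing a forward substitution; the header name and right-hand side are illustrative assumptions.

#include <cstdint>
#include <vector>
#include "blis.hh"   // assumed header name

int main()
{
    int64_t n = 3;
    // Column-major lower-triangular matrix L with a non-unit diagonal.
    std::vector<double> L = { 2.0, 1.0, 1.0,
                              0.0, 3.0, 2.0,
                              0.0, 0.0, 4.0 };
    // On entry x holds the right-hand side b; on exit it holds the solution.
    std::vector<double> x = { 2.0, 7.0, 9.0 };

    // Solve L*x = b by forward substitution  ->  x = { 1, 2, 1 }
    trsv( CblasColMajor, CblasLower, CblasNoTrans, CblasNonUnit,
          n, L.data(), n, x.data(), 1 );
    return 0;
}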
- */ -template< typename T > -void tbsv( - CBLAS_ORDER layout, - CBLAS_UPLO uplo, - CBLAS_TRANSPOSE trans, - CBLAS_DIAG diag, - int64_t n, int64_t k, - T const *A, int64_t lda, - T *x, int64_t incx ) -{ - cblas_tbsv(layout, uplo, trans, diag, n, k, A, lda, x, incx); -} - - -/*! \brief Solve the one of the triangular matrix-vector equation for arbitrary data types - - \b Purpose: - - TPSV solves one of the systems of equations for arbitrary data types - Data precisions supported include SINGLE PRECISION REAL, DOUBLE PRECISION REAL, - SINGLE PRECISION COMPLEX, DOUBLE PRECISION COMPLEX(COMPLEX*16) - - A*x = b, or A**T*x = b, - - where b and x are n element vectors and A is an n by n unit, or - non-unit, upper or lower triangular band matrix, supplied in packed form. - - \param[in] layout - layout is enum CBLAS_ORDER - layout specifies Matrix storage as follows: - layout = CBLAS_ORDER::CblasRowMajor or Layout::CblasColMajor. - - \param[in] uplo - uplo is enum CBLAS_UPLO. - uplo specifies specifies whether the matrix A is an upper or - lower triangular matrix as follows: \n - uplo = CBLAS_UPLO::CblasUpper A is an upper triangular matrix. \n - uplo = CBLAS_UPLO::CblasLower A is a lower triangular matrix. - - \param[in] trans - trans is CBLAS_TRANSPOSE - On entry, trans specifies the operation to be performed as follows: - trans = CBLAS_TRANSPOSE::CblasNoTrans, A*x = b. \n - trans = CBLAS_TRANSPOSE::CblasTrans, A**T*x = b. \n - trans = CBLAS_TRANSPOSE::CblasConjTrans, A**T*x = b. - - \param[in] diag - diag is enum CBLAS_DIAG - diag specifies specifies whether or not A is unit triangular - as follows: \n - diag = CBLAS_DIAG::CblasUnit A is assumed to be unit triangular.\n - diag = CBLAS_DIAG::CblasNonUnit A is not assumed to be unit - triangular. - - \param[in] n - n is INTEGER - On entry, n specifies the order of the matrix A.n must be at least zero. - - \param[in] Ap - Ap is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array,dimension - ( ( n*( n + 1 ) )/2 ). \n - Before entry with UPLO = CblasUpper, the array Ap must - contain the upper triangular matrix packed sequentially, - column by column, so that Ap( 1 ) contains a( 1, 1 ), - Ap( 2 ) and Ap( 3 ) contain a( 1, 2 ) and a( 2, 2 ) - respectively, and so on. \n - Before entry with UPLO = CblasLower, the array Ap must - contain the lower triangular matrix packed sequentially, - column by column, so that Ap( 1 ) contains a( 1, 1 ), - Ap( 2 ) and Ap( 3 ) contain a( 2, 1 ) and a( 3, 1 ) - respectively, and so on. \n - Note that when DIAG = CblasUnit, the diagonal elements of - A are not referenced, but are assumed to be unity. - - \param[in, out] x - x is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array,dimension : - at least ( 1 + ( n - 1 )*abs( incx ) ). \n - Before entry, the incremented array x must contain the - element right-hand side vector b.On exit, x is overwritten - with the solution vector x. - - \param[in] incx - incx is INTEGER - On entry, incx specifies the increment for the elements of x. - incx must not be zero. - */ -template< typename T > -void tpsv( - CBLAS_ORDER layout, - CBLAS_UPLO uplo, - CBLAS_TRANSPOSE trans, - CBLAS_DIAG diag, - int64_t n, - T const *Ap, - T *x, int64_t incx ) -{ - cblas_tpsv(layout, uplo, trans, diag, n, Ap, x, incx); -} - -/*! 
\brief Perform the General matrix rank-1 update for arbitrary data types - - \b Purpose: - - GER performs the rank 1 operation for arbitrary data types - Data precisions supported include SINGLE PRECISION REAL, DOUBLE PRECISION REAL, - - A := alpha*x*y**T + A, - - where alpha is a scalar, x is an m element vector, y is an n element - vector and A is an m by n matrix. - - \param[in] layout - layout is enum CBLAS_ORDER - layout specifies Matrix storage as follows: - layout = CBLAS_ORDER::CblasRowMajor or Layout::CblasColMajor. - - \param[in] m - m is INTEGER - On entry, m specifies the number of rows of the matrix A. - m must be at least zero. - - \param[in] n - n is INTEGER - On entry, n specifies the number of columns of the matrix A. - n must be at least zero. - - \param[in] alpha - alpha is REAL/DOUBLE PRECISION - On entry, alpha specifies the scalar alpha. - - \param[in] x - x is REAL/DOUBLE PRECISION array,dimension : - at least ( 1 + ( m - 1 )*abs( incx ) ). \n - Before entry, the incremented array x must contain the m - element vector x. - - \param[in] incx - incx is INTEGER - On entry, incx specifies the increment for the elements of x. - incx must not be zero. - - \param[in] y - y is REAL/DOUBLE PRECISION array,dimension : - at least ( 1 + ( n - 1 )*abs( incy ) ). \n - Before entry, the incremented array y must contain the n - element vector y. - - \param[in] incy - incy is INTEGER - On entry, incy specifies the increment for the elements of y. - incy must not be zero. - - \param[in,out] A - A is REAL/DOUBLE PRECISION array,dimension ( lda, n )\n - Before entry, the leading m by n part of the array A must - contain the matrix of coefficients. On exit, A is - overwritten by the updated matrix. - - \param[in] lda - lda is INTEGER - On entry, lda specifies the Leading dimension of A - lda must be at least max( 1, m ). - */ -template< typename T > -void ger( - CBLAS_ORDER layout, - int64_t m, int64_t n, - T alpha, - T const *x, int64_t incx, - T const *y, int64_t incy, - T *A, int64_t lda ) -{ - cblas_ger(layout, m, n, alpha, x, incx, y, incy, A, lda); -} - -/*! \brief Perform the General matrix rank-1 update for arbitrary data types - - \b Purpose: - - GERU performs the rank 1 operation for arbitrary data types - Data precisions supported include SINGLE/DOUBLE PRECISION COMPLEX(COMPLEX*16) - - A := alpha*x*y**T + A, - - where alpha is a scalar, x is an m element vector, y is an n element - vector and A is an m by n matrix. - - \param[in] layout - layout is enum CBLAS_ORDER - layout specifies Matrix storage as follows: - layout = CBLAS_ORDER::CblasRowMajor or Layout::CblasColMajor. - - \param[in] m - m is INTEGER - On entry, m specifies the number of rows of the matrix A. - m must be at least zero. - - \param[in] n - n is INTEGER - On entry, n specifies the number of columns of the matrix A. - n must be at least zero. - - \param[in] alpha - alpha is SINGLE/DOUBLE PRECISION COMPLEX - On entry, alpha specifies the scalar alpha. - - \param[in] x - x is SINGLE/DOUBLE PRECISION COMPLEX array,dimension : - at least ( 1 + ( m - 1 )*abs( incx ) ). \n - Before entry, the incremented array x must contain the m - element vector x. - - \param[in] incx - incx is INTEGER - On entry, incx specifies the increment for the elements of x. - incx must not be zero. - - \param[in] y - y is SINGLE/DOUBLE PRECISION COMPLEX array,dimension : - at least ( 1 + ( n - 1 )*abs( incy ) ). \n - Before entry, the incremented array y must contain the n - element vector y. 
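A minimal sketch of the ger wrapper above applying a real rank-1 update; the header name and sizes are assumptions.

#include <cstdint>
#include <vector>
#include "blis.hh"   // assumed header name

int main()
{
    int64_t m = 2, n = 3;
    std::vector<double> A(m * n, 0.0);            // column-major 2x3, lda = m
    std::vector<double> x = { 1.0, 2.0 };         // length m
    std::vector<double> y = { 1.0, 10.0, 100.0 }; // length n

    // A := 1.0*x*y**T + A  ->  A(i,j) = x[i]*y[j]
    ger( CblasColMajor, m, n, 1.0, x.data(), 1, y.data(), 1, A.data(), m );
    return 0;
}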
- - \param[in] incy - incy is INTEGER - On entry, incy specifies the increment for the elements of y. - incy must not be zero. - - \param[in,out] A - A is SINGLE/DOUBLE PRECISION COMPLEX array,dimension ( lda, n )\n - Before entry, the leading m by n part of the array A must - contain the matrix of coefficients. On exit, A is - overwritten by the updated matrix. - - \param[in] lda - lda is INTEGER - On entry, lda specifies the Leading dimension of A - lda must be at least max( 1, m ). - */ -template< typename T > -void geru( - CBLAS_ORDER layout, - int64_t m, int64_t n, - T alpha, - T const *x, int64_t incx, - T const *y, int64_t incy, - T *A, int64_t lda ) -{ - cblas_geru(layout, m, n, alpha, x, incx, y, incy, A, lda); -} - -/*! \brief Perform the General matrix rank-1 update for arbitrary data types - - \b Purpose: - - GERC performs the rank 1 operation for arbitrary data types - Data precisions supported include SINGLE/DOUBLE PRECISION COMPLEX(COMPLEX*16) - - A := alpha*x*y**T + A, - - where alpha is a scalar, x is an m element vector, y is an n element - vector and A is an m by n matrix. - - \param[in] layout - layout is enum CBLAS_ORDER - layout specifies Matrix storage as follows: - layout = CBLAS_ORDER::CblasRowMajor or Layout::CblasColMajor. - - \param[in] m - m is INTEGER - On entry, m specifies the number of rows of the matrix A. - m must be at least zero. - - \param[in] n - n is INTEGER - On entry, n specifies the number of columns of the matrix A. - n must be at least zero. - - \param[in] alpha - alpha is SINGLE/DOUBLE PRECISION COMPLEX - On entry, alpha specifies the scalar alpha. - - \param[in] x - x is SINGLE/DOUBLE PRECISION COMPLEX array,dimension : - at least ( 1 + ( m - 1 )*abs( incx ) ). \n - Before entry, the incremented array x must contain the m - element vector x. - - \param[in] incx - incx is INTEGER - On entry, incx specifies the increment for the elements of x. - incx must not be zero. - - \param[in] y - y is SINGLE/DOUBLE PRECISION COMPLEX array,dimension : - at least ( 1 + ( n - 1 )*abs( incy ) ). \n - Before entry, the incremented array y must contain the n - element vector y. - - \param[in] incy - incy is INTEGER - On entry, incy specifies the increment for the elements of y. - incy must not be zero. - - \param[in,out] A - A is SINGLE/DOUBLE PRECISION COMPLEX array,dimension ( lda, n )\n - Before entry, the leading m by n part of the array A must - contain the matrix of coefficients. On exit, A is - overwritten by the updated matrix. - - \param[in] lda - lda is INTEGER - On entry, lda specifies the Leading dimension of A - lda must be at least max( 1, m ). - */ -template< typename T > -void gerc( - CBLAS_ORDER layout, - int64_t m, int64_t n, - T alpha, - T const *x, int64_t incx, - T const *y, int64_t incy, - T *A, int64_t lda ) -{ - cblas_gerc(layout, m, n, alpha, x, incx, y, incy, A, lda); -} - -/*! \brief Perform the hermitian rank 1 operation for arbitrary data types - - \b Purpose: - - HER performs the hermitian rank 1 operation for arbitrary data types - Data precisions supported include SINGLE/DOUBLE PRECISION COMPLEX(COMPLEX*16) - - A := alpha*x*x**H + A, - - where alpha is a real scalar, x is an n element vector, A is an n by n - hermitian matrix. - - \param[in] layout - layout is enum CBLAS_ORDER - layout specifies Matrix storage as follows: - layout = CBLAS_ORDER::CblasRowMajor or Layout::CblasColMajor. - - \param[in] uplo - uplo is enum CBLAS_UPLO. 
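A sketch contrasting the geru and gerc wrappers above on complex data; note that in standard BLAS the GERC update conjugates y (A := alpha*x*y**H + A) while GERU does not. The header name and values are illustrative assumptions.

#include <cstdint>
#include <complex>
#include <vector>
#include "blis.hh"   // assumed header name

int main()
{
    using cd = std::complex<double>;
    int64_t m = 2, n = 2;
    std::vector<cd> A(m * n, cd(0,0));   // column-major 2x2, lda = m
    std::vector<cd> x = { cd(1,1), cd(0,2) };
    std::vector<cd> y = { cd(1,0), cd(0,1) };

    // geru: A := alpha*x*y**T + A   (y is not conjugated)
    geru( CblasColMajor, m, n, cd(1,0), x.data(), 1, y.data(), 1, A.data(), m );
    // gerc: A := alpha*x*y**H + A   (y is conjugated, as in standard BLAS)
    gerc( CblasColMajor, m, n, cd(1,0), x.data(), 1, y.data(), 1, A.data(), m );
    return 0;
}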
- uplo specifies specifies whether the upper or lower triangular - part of the array A is to be referenced as follows: \n - uplo = CBLAS_UPLO::CblasUpper A is an upper triangular matrix. \n - uplo = CBLAS_UPLO::CblasLower A is a lower triangular matrix. - - \param[in] n - n is INTEGER - On entry, n specifies the order of the matrix A. - n must be at least zero. - - \param[in] alpha - alpha is SINGLE/DOUBLE PRECISION REAL - On entry, alpha specifies the scalar alpha. - - \param[in] x - x is SINGLE/DOUBLE PRECISION COMPLEX array,dimension : - at least ( 1 + ( n - 1 )*abs( incx ) ). \n - Before entry, the incremented array x must contain the n - element vector x. - - \param[in] incx - incx is INTEGER - On entry, incx specifies the increment for the elements of x. - incx must not be zero. - - \param[in,out] A - A is SINGLE/DOUBLE PRECISION COMPLEX array,dimension ( lda, n )\n - Before entry with UPLO = CblasUpper, the leading n by n - upper triangular part of the array A must contain the upper - triangular part of the hermitian matrix and the strictly - lower triangular part of A is not referenced. On exit, the - upper triangular part of the array A is overwritten by the - upper triangular part of the updated matrix. \n - Before entry with UPLO = CblasLower, the leading n by n - lower triangular part of the array A must contain the lower - triangular part of the hermitian matrix and the strictly - upper triangular part of A is not referenced. On exit, the - lower triangular part of the array A is overwritten by the - lower triangular part of the updated matrix. \n - Note that the imaginary parts of the diagonal elements need - not be set, they are assumed to be zero, and on exit they - are set to zero. - - \param[in] lda - lda is INTEGER - On entry, lda specifies the Leading dimension of A - lda must be at least max( 1, n ). - */ -template< typename T > -void her( - CBLAS_ORDER layout, - CBLAS_UPLO uplo, - int64_t n, - real_type alpha, // zher takes double alpha; use real - T const *x, int64_t incx, - T *A, int64_t lda ) -{ - cblas_her(layout, uplo, n, alpha, x, incx, A, lda); -} - -/*! \brief Perform the hermitian rank 1 operation for arbitrary data types - - \b Purpose: - - HPR performs the hermitian rank 1 operation for arbitrary data types - Data precisions supported include SINGLE/DOUBLE PRECISION COMPLEX(COMPLEX*16) - - A := alpha*x*x**H + A, - - where alpha is a real scalar, x is an n element vector, A is an n by n - hermitian matrix, supplied in packed form. - - \param[in] layout - layout is enum CBLAS_ORDER - layout specifies Matrix storage as follows: - layout = CBLAS_ORDER::CblasRowMajor or Layout::CblasColMajor. - - \param[in] uplo - uplo is enum CBLAS_UPLO. - uplo specifies specifies whether the upper or lower triangular - part of the array A is to be referenced as follows: \n - uplo = CBLAS_UPLO::CblasUpper The upper triangular part of A is - supplied in Ap. \n - uplo = CBLAS_UPLO::CblasLower The lower triangular part of A is - supplied in Ap. - - \param[in] n - n is INTEGER - On entry, n specifies the order of the matrix A. - n must be at least zero. - - \param[in] alpha - alpha is SINGLE/DOUBLE PRECISION REAL - On entry, alpha specifies the scalar alpha. - - \param[in] x - x is SINGLE/DOUBLE PRECISION COMPLEX array,dimension : - at least ( 1 + ( n - 1 )*abs( incx ) ). \n - Before entry, the incremented array x must contain the n - element vector x. - - \param[in] incx - incx is INTEGER - On entry, incx specifies the increment for the elements of x. - incx must not be zero. 
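A hedged sketch of the her wrapper above; passing alpha as a plain double assumes that real_type resolves to the real scalar type associated with T in this header. Values are illustrative.

#include <cstdint>
#include <complex>
#include <vector>
#include "blis.hh"   // assumed header name

int main()
{
    using cd = std::complex<double>;
    int64_t n = 2;
    // Column-major 2x2 Hermitian matrix; the lower triangle is referenced and updated.
    std::vector<cd> A = { cd(1,0), cd(2,-1),     // column 0: A(0,0), A(1,0)
                          cd(0,0), cd(3,0) };    // column 1: (unused), A(1,1)
    std::vector<cd> x = { cd(1,1), cd(0,2) };

    // A := alpha*x*x**H + A with real alpha
    //   -> lower triangle of A becomes { 3, 4+i, 7 }
    her( CblasColMajor, CblasLower, n, 1.0, x.data(), 1, A.data(), n );
    return 0;
}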
- - \param[in,out] Ap - Ap is SINGLE/DOUBLE PRECISION COMPLEX array,dimension - atleast ( ( n*( n + 1 ) )/2 ).\n - Before entry with UPLO = CblasUpper, the array Ap must - contain the upper triangular part of the hermitian matrix - packed sequentially, column by column, so that Ap( 1 ) - contains a( 1, 1 ), Ap( 2 ) and Ap( 3 ) contain a( 1, 2 ) - and a( 2, 2 ) respectively, and so on. On exit, the array - Ap is overwritten by the upper triangular part of the - updated matrix. \n - Before entry with UPLO = CblasLower, the array Ap must - contain the lower triangular part of the hermitian matrix - packed sequentially, column by column, so that Ap( 1 ) - contains a( 1, 1 ), Ap( 2 ) and Ap( 3 ) contain a( 2, 1 ) - and a( 3, 1 ) respectively, and so on. On exit, the array - Ap is overwritten by the lower triangular part of the - updated matrix. \n - Note that the imaginary parts of the diagonal elements need - not be set, they are assumed to be zero, and on exit they - are set to zero. - */ -template< typename T > -void hpr( - CBLAS_ORDER layout, - CBLAS_UPLO uplo, - int64_t n, - real_type alpha, // zher takes double alpha; use real - T const *x, int64_t incx, - T *Ap ) -{ - cblas_hpr(layout, uplo, n, alpha, x, incx, Ap); -} - -/*! \brief Perform the hermitian rank 2 operation for arbitrary data types - - \b Purpose: - - HER2 performs the hermitian rank 2 operation for arbitrary data types - Data precisions supported include SINGLE/DOUBLE PRECISION COMPLEX(COMPLEX*16) - - A := alpha*x*y**H + conjg( alpha )*y*x**H + A, - - where alpha is a scalar, x and y are n element vector, A is an n by n - hermitian matrix. - - \param[in] layout - layout is enum CBLAS_ORDER - layout specifies Matrix storage as follows: - layout = CBLAS_ORDER::CblasRowMajor or Layout::CblasColMajor. - - \param[in] uplo - uplo is enum CBLAS_UPLO. - uplo specifies whether the upper or lower triangular part of the - array A is to be referenced as follows: \n - UPLO = CblasUpper Only the upper triangular part of A - is to be referenced. \n - UPLO = CblasLower Only the lower triangular part of A - is to be referenced. - - \param[in] n - n is INTEGER - On entry, n specifies the order of the matrix A. - n must be at least zero. - - \param[in] alpha - alpha is SINGLE/DOUBLE PRECISION COMPLEX - On entry, alpha specifies the scalar alpha. - - \param[in] x - x is SINGLE/DOUBLE PRECISION COMPLEX array,dimension : - at least ( 1 + ( n - 1 )*abs( incx ) ). \n - Before entry, the incremented array x must contain the n - element vector x. - - \param[in] incx - incx is INTEGER - On entry, incx specifies the increment for the elements of x. - incx must not be zero. - - \param[in] y - y is SINGLE/DOUBLE PRECISION COMPLEX array,dimension : - at least ( 1 + ( n - 1 )*abs( incy ) ). \n - Before entry, the incremented array y must contain the n - element vector y. - - \param[in] incy - incy is INTEGER - On entry, incy specifies the increment for the elements of y. - incy must not be zero. - - \param[in,out] A - A is SINGLE/DOUBLE PRECISION COMPLEX array,dimension ( lda, n )\n - Before entry with UPLO = CblasUpper, the leading n by n - upper triangular part of the array A must contain the upper - triangular part of the hermitian matrix and the strictly - lower triangular part of A is not referenced. On exit, the - upper triangular part of the array A is overwritten by the - upper triangular part of the updated matrix. 
\n - Before entry with UPLO = CblasLower, the leading n by n - lower triangular part of the array A must contain the lower - triangular part of the hermitian matrix and the strictly - upper triangular part of A is not referenced. On exit, the - lower triangular part of the array A is overwritten by the - lower triangular part of the updated matrix. \n - Note that the imaginary parts of the diagonal elements need - not be set, they are assumed to be zero, and on exit they - are set to zero. - - \param[in] lda - lda is INTEGER - On entry, lda specifies the Leading dimension of A - lda must be at least max( 1, n ). - */ -template< typename T > -void her2( - CBLAS_ORDER layout, - CBLAS_UPLO uplo, - int64_t n, - T alpha, - T const *x, int64_t incx, - T const *y, int64_t incy, - T *A, int64_t lda ) -{ - cblas_her2(layout, uplo, n, alpha, x, incx, y, incy, A, lda); -} - -/*! \brief Perform the hermitian rank 2 operation for arbitrary data types - - \b Purpose: - - HPR2 performs the hermitian rank 2 operation for arbitrary data types - Data precisions supported include SINGLE/DOUBLE PRECISION COMPLEX(COMPLEX*16) - - A := alpha*x*y**H + conjg( alpha )*y*x**H + A, - - where alpha is a scalar, x and y are n element vector, A is an n by n - hermitian matrix, supplied in packed form. - - \param[in] layout - layout is enum CBLAS_ORDER - layout specifies Matrix storage as follows: - layout = CBLAS_ORDER::CblasRowMajor or Layout::CblasColMajor. - - \param[in] uplo - uplo is enum CBLAS_UPLO. - uplo specifies specifies whether the upper or lower triangular - part of the array A is to be referenced as follows: \n - uplo = CBLAS_UPLO::CblasUpper The upper triangular part of A is - supplied in Ap. \n - uplo = CBLAS_UPLO::CblasLower The lower triangular part of A is - supplied in Ap. - - \param[in] n - n is INTEGER - On entry, n specifies the order of the matrix A. - n must be at least zero. - - \param[in] alpha - alpha is SINGLE/DOUBLE PRECISION COMPLEX - On entry, alpha specifies the scalar alpha. - - \param[in] x - x is SINGLE/DOUBLE PRECISION COMPLEX array,dimension : - at least ( 1 + ( n - 1 )*abs( incx ) ). \n - Before entry, the incremented array x must contain the n - element vector x. - - \param[in] incx - incx is INTEGER - On entry, incx specifies the increment for the elements of x. - incx must not be zero. - - \param[in] y - y is SINGLE/DOUBLE PRECISION REAL array,dimension : - at least ( 1 + ( n - 1 )*abs( incy ) ). \n - Before entry, the incremented array y must contain the n - element vector y. - - \param[in] incy - incy is INTEGER - On entry, incy specifies the increment for the elements of y. - incy must not be zero. - - \param[in,out] Ap - Ap is SINGLE/DOUBLE PRECISION COMPLEX array,dimension - atleast ( ( n*( n + 1 ) )/2 ).\n - Before entry with UPLO = CblasUpper, the array Ap must - contain the upper triangular part of the hermitian matrix - packed sequentially, column by column, so that Ap( 1 ) - contains a( 1, 1 ), Ap( 2 ) and Ap( 3 ) contain a( 1, 2 ) - and a( 2, 2 ) respectively, and so on. On exit, the array - Ap is overwritten by the upper triangular part of the - updated matrix. \n - Before entry with UPLO = CblasLower, the array Ap must - contain the lower triangular part of the hermitian matrix - packed sequentially, column by column, so that Ap( 1 ) - contains a( 1, 1 ), Ap( 2 ) and Ap( 3 ) contain a( 2, 1 ) - and a( 3, 1 ) respectively, and so on. On exit, the array - Ap is overwritten by the lower triangular part of the - updated matrix. 
\n - Note that the imaginary parts of the diagonal elements need - not be set, they are assumed to be zero, and on exit they - are set to zero. - */ -template< typename T > -void hpr2( - CBLAS_ORDER layout, - CBLAS_UPLO uplo, - int64_t n, - T alpha, - T const *x, int64_t incx, - T const *y, int64_t incy, - T *Ap ) -{ - cblas_hpr2(layout, uplo, n, alpha, x, incx, y, incy, Ap); -} - -/*! \brief Perform the symmetric rank 1 operation for arbitrary data types - - \b Purpose: - - SYR performs the symmetric rank 1 operation for arbitrary data types - Data precisions supported include SINGLE/DOUBLE PRECISION REAL - - A := alpha*x*x**T + A, - - where alpha is a real scalar, x is an n element vector, A is an n by n - symmetric matrix. - - \param[in] layout - layout is enum CBLAS_ORDER - layout specifies Matrix storage as follows: - layout = CBLAS_ORDER::CblasRowMajor or Layout::CblasColMajor. - - \param[in] uplo - uplo is enum CBLAS_UPLO. - uplo specifies specifies whether the upper or lower triangular - part of the array A is to be referenced as follows: \n - uplo = CBLAS_UPLO::CblasUpper A is an upper triangular matrix. \n - uplo = CBLAS_UPLO::CblasLower A is a lower triangular matrix. - - \param[in] n - n is INTEGER - On entry, n specifies the order of the matrix A. - n must be at least zero. - - \param[in] alpha - alpha is SINGLE/DOUBLE PRECISION REAL - On entry, alpha specifies the scalar alpha. - - \param[in] x - x is SINGLE/DOUBLE PRECISION REAL array,dimension : - at least ( 1 + ( n - 1 )*abs( incx ) ). \n - Before entry, the incremented array x must contain the n - element vector x. - - \param[in] incx - incx is INTEGER - On entry, incx specifies the increment for the elements of x. - incx must not be zero. - - \param[in,out] A - A is SINGLE/DOUBLE PRECISION REAL array,dimension ( lda, n )\n - Before entry with UPLO = CblasUpper, the leading n by n - upper triangular part of the array A must contain the upper - triangular part of the symmetric matrix and the strictly - lower triangular part of A is not referenced. On exit, the - upper triangular part of the array A is overwritten by the - upper triangular part of the updated matrix. \n - Before entry with UPLO = CblasLower, the leading n by n - lower triangular part of the array A must contain the lower - triangular part of the symmetric matrix and the strictly - upper triangular part of A is not referenced. On exit, the - lower triangular part of the array A is overwritten by the - lower triangular part of the updated matrix. \n - - \param[in] lda - lda is INTEGER - On entry, lda specifies the Leading dimension of A - lda must be at least max( 1, n ). - */ -template< typename T > -void syr( - CBLAS_ORDER layout, - CBLAS_UPLO uplo, - int64_t n, - T alpha, - T const *x, int64_t incx, - T *A, int64_t lda ) -{ - cblas_syr(layout, uplo, n, alpha, x, incx, A, lda); -} - -/*! \brief Perform the symmetric rank 1 operation for arbitrary data types - - \b Purpose: - - SPR performs the symmetric rank 1 operation for arbitrary data types - Data precisions supported include SINGLE PRECISION REAL, DOUBLE PRECISION REAL - - A := alpha*x*x**T + A, - - where alpha is a real scalar, x is an n element vector, A is an n by n - symmetric matrix, supplied in packed form. - - \param[in] layout - layout is enum CBLAS_ORDER - layout specifies Matrix storage as follows: - layout = CBLAS_ORDER::CblasRowMajor or Layout::CblasColMajor. - - \param[in] uplo - uplo is enum CBLAS_UPLO. 
- uplo specifies specifies whether the upper or lower triangular - part of the array A is to be referenced as follows: \n - uplo = CBLAS_UPLO::CblasUpper The upper triangular part of A is - supplied in Ap. \n - uplo = CBLAS_UPLO::CblasLower The lower triangular part of A is - supplied in Ap. - - \param[in] n - n is INTEGER - On entry, n specifies the order of the matrix A. - n must be at least zero. - - \param[in] alpha - alpha is SINGLE/DOUBLE PRECISION REAL - On entry, alpha specifies the scalar alpha. - - \param[in] x - x is SINGLE/DOUBLE PRECISION REAL array,dimension : - at least ( 1 + ( n - 1 )*abs( incx ) ). \n - Before entry, the incremented array x must contain the n - element vector x. - - \param[in] incx - incx is INTEGER - On entry, incx specifies the increment for the elements of x. - incx must not be zero. - - \param[in,out] Ap - Ap is SINGLE/DOUBLE PRECISION REAL array,dimension - atleast ( ( n*( n + 1 ) )/2 ).\n - Before entry with UPLO = CblasUpper, the array Ap must - contain the upper triangular part of the symmetric matrix - packed sequentially, column by column, so that Ap( 1 ) - contains a( 1, 1 ), Ap( 2 ) and Ap( 3 ) contain a( 1, 2 ) - and a( 2, 2 ) respectively, and so on. On exit, the array - Ap is overwritten by the upper triangular part of the - updated matrix. \n - Before entry with UPLO = CblasLower, the array Ap must - contain the lower triangular part of the symmetric matrix - packed sequentially, column by column, so that Ap( 1 ) - contains a( 1, 1 ), Ap( 2 ) and Ap( 3 ) contain a( 2, 1 ) - and a( 3, 1 ) respectively, and so on. On exit, the array - Ap is overwritten by the lower triangular part of the - updated matrix. \n - */ -template< typename T > -void spr( - CBLAS_ORDER layout, - CBLAS_UPLO uplo, - int64_t n, - T alpha, - T const *x, int64_t incx, - T *Ap ) -{ - cblas_spr(layout, uplo, n, alpha, x, incx, Ap); -} - -/*! \brief Perform the symmetric rank 2 operation for arbitrary data types - - \b Purpose: - - SYR2 performs the symmetric rank 2 operation for arbitrary data types - Data precisions supported include SINGLE/DOUBLE PRECISION REAL - - A := alpha*x*y**T + alpha*y*x**T + A, - - where alpha is a scalar, x and y are n element vector, A is an n by n - symmetric matrix. - - \param[in] layout - layout is enum CBLAS_ORDER - layout specifies Matrix storage as follows: - layout = CBLAS_ORDER::CblasRowMajor or Layout::CblasColMajor. - - \param[in] uplo - uplo is enum CBLAS_UPLO. - uplo specifies whether the upper or lower triangular part of the - array A is to be referenced as follows: \n - UPLO = CblasUpper Only the upper triangular part of A - is to be referenced. \n - UPLO = CblasLower Only the lower triangular part of A - is to be referenced. - - \param[in] n - n is INTEGER - On entry, n specifies the order of the matrix A. - n must be at least zero. - - \param[in] alpha - alpha is SINGLE/DOUBLE PRECISION REAL - On entry, alpha specifies the scalar alpha. - - \param[in] x - x is SINGLE/DOUBLE PRECISION REAL array,dimension : - at least ( 1 + ( n - 1 )*abs( incx ) ). \n - Before entry, the incremented array x must contain the n - element vector x. - - \param[in] incx - incx is INTEGER - On entry, incx specifies the increment for the elements of x. - incx must not be zero. - - \param[in] y - y is SINGLE/DOUBLE PRECISION REAL array,dimension : - at least ( 1 + ( n - 1 )*abs( incy ) ). \n - Before entry, the incremented array y must contain the n - element vector y. 
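A minimal sketch of the spr wrapper above, mainly to illustrate the packed column-by-column layout just described; the "blis.hh" include name and the numbers are assumptions.

#include "blis.hh"   // assumed include name

int main()
{
    // 3x3 symmetric matrix, upper triangle packed column by column:
    // Ap = { a(1,1), a(1,2), a(2,2), a(1,3), a(2,3), a(3,3) }
    double Ap[6] = { 4.0, 1.0, 3.0, 0.0, 2.0, 6.0 };
    double x[3]  = { 1.0, -1.0, 2.0 };

    // A := alpha*x*x**T + A on the packed upper triangle
    blis::spr( CblasColMajor, CblasUpper, 3, 0.5, x, 1, Ap );
    return 0;
}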
- - \param[in] incy - incy is INTEGER - On entry, incy specifies the increment for the elements of y. - incy must not be zero. - - \param[in,out] A - A is SINGLE/DOUBLE PRECISION REAL array,dimension ( lda, n )\n - Before entry with UPLO = CblasUpper, the leading n by n - upper triangular part of the array A must contain the upper - triangular part of the symmetric matrix and the strictly - lower triangular part of A is not referenced. On exit, the - upper triangular part of the array A is overwritten by the - upper triangular part of the updated matrix. \n - Before entry with UPLO = CblasLower, the leading n by n - lower triangular part of the array A must contain the lower - triangular part of the symmetric matrix and the strictly - upper triangular part of A is not referenced. On exit, the - lower triangular part of the array A is overwritten by the - lower triangular part of the updated matrix. \n - - \param[in] lda - lda is INTEGER - On entry, lda specifies the Leading dimension of A - lda must be at least max( 1, n ). - */ -template< typename T > -void syr2( - CBLAS_ORDER layout, - CBLAS_UPLO uplo, - int64_t n, - T alpha, - T const *x, int64_t incx, - T const *y, int64_t incy, - T *A, int64_t lda ) -{ - cblas_syr2(layout, uplo, n, alpha, x, incx, y, incy, A, lda); -} - -/*! \brief Perform the symmetric rank 2 operation for arbitrary data types - - \b Purpose: - - SPR2 performs the symmetric rank 2 operation for arbitrary data types - Data precisions supported include SINGLE/DOUBLE PRECISION REAL - - A := alpha*x*y**T + alpha*y*x**T + A, - - where alpha is a scalar, x and y are n element vector, A is an n by n - symmetric matrix, supplied in packed form. - - \param[in] layout - layout is enum CBLAS_ORDER - layout specifies Matrix storage as follows: - layout = CBLAS_ORDER::CblasRowMajor or Layout::CblasColMajor. - - \param[in] uplo - uplo is enum CBLAS_UPLO. - uplo specifies specifies whether the upper or lower triangular - part of the array A is to be referenced as follows: \n - uplo = CBLAS_UPLO::CblasUpper The upper triangular part of A is - supplied in Ap. \n - uplo = CBLAS_UPLO::CblasLower The lower triangular part of A is - supplied in Ap. - - \param[in] n - n is INTEGER - On entry, n specifies the order of the matrix A. - n must be at least zero. - - \param[in] alpha - alpha is SINGLE/DOUBLE PRECISION REAL - On entry, alpha specifies the scalar alpha. - - \param[in] x - x is SINGLE/DOUBLE PRECISION REAL array,dimension : - at least ( 1 + ( n - 1 )*abs( incx ) ). \n - Before entry, the incremented array x must contain the n - element vector x. - - \param[in] incx - incx is INTEGER - On entry, incx specifies the increment for the elements of x. - incx must not be zero. - - \param[in] y - y is SINGLE/DOUBLE PRECISION REAL array,dimension : - at least ( 1 + ( n - 1 )*abs( incy ) ). \n - Before entry, the incremented array y must contain the n - element vector y. - - \param[in] incy - incy is INTEGER - On entry, incy specifies the increment for the elements of y. - incy must not be zero. - - \param[in,out] Ap - Ap is SINGLE/DOUBLE PRECISION REAL array,dimension - atleast ( ( n*( n + 1 ) )/2 ).\n - Before entry with UPLO = CblasUpper, the array Ap must - contain the upper triangular part of the symmetric matrix - packed sequentially, column by column, so that Ap( 1 ) - contains a( 1, 1 ), Ap( 2 ) and Ap( 3 ) contain a( 1, 2 ) - and a( 2, 2 ) respectively, and so on. On exit, the array - Ap is overwritten by the upper triangular part of the - updated matrix. 
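A minimal sketch of the syr2 wrapper above ("blis.hh" include name and operand values assumed):

#include "blis.hh"   // assumed include name

int main()
{
    // 2x2 symmetric A in column-major order, lda = 2; with CblasLower the
    // strictly upper part (a(1,2)) is not referenced
    double A[4] = { 2.0, 1.0,
                    0.0, 3.0 };
    double x[2] = { 1.0, 2.0 };
    double y[2] = { -1.0, 0.5 };

    // A := alpha*x*y**T + alpha*y*x**T + A on the lower triangle
    blis::syr2( CblasColMajor, CblasLower, 2, 1.0, x, 1, y, 1, A, 2 );
    return 0;
}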
\n - Before entry with UPLO = CblasLower, the array Ap must - contain the lower triangular part of the symmetric matrix - packed sequentially, column by column, so that Ap( 1 ) - contains a( 1, 1 ), Ap( 2 ) and Ap( 3 ) contain a( 2, 1 ) - and a( 3, 1 ) respectively, and so on. On exit, the array - Ap is overwritten by the lower triangular part of the - updated matrix. \n - */ -template< typename T > -void spr2( - CBLAS_ORDER layout, - CBLAS_UPLO uplo, - int64_t n, - T alpha, - T const *x, int64_t incx, - T const *y, int64_t incy, - T *Ap ) -{ - cblas_spr2(layout, uplo, n, alpha, x, incx, y, incy, Ap); -} - -/*! \brief General matrix-matrix multiply for arbitrary data types - - \b Purpose: - - GEMM performs general matrix-matrix multiply for arbitrary data types - Data precisions supported include SINGLE PRECISION REAL, DOUBLE PRECISION REAL, - SINGLE PRECISION COMPLEX, DOUBLE PRECISION COMPLEX(COMPLEX*16) - - C := alpha*op( A )*op( B ) + beta*C, - - where op( X ) is one of - - op( X ) = X or op( X ) = X**T or op( X ) = X**H, - - alpha and beta are scalars, and A, B and C are matrices, with op( A ) - an m by k matrix, op( B ) a k by n matrix and C an m by n matrix. - - \param[in] layout - layout is enum CBLAS_ORDER - layout specifies Matrix storage as follows: - layout = CBLAS_ORDER::CblasRowMajor or Layout::CblasColMajor. - - \param[in] transA - transA is CBLAS_TRANSPOSE - On entry, transA specifies the form of op( A ) to be used in - the matrix multiplication as follows: - transA = CBLAS_TRANSPOSE::CblasNoTrans, op( A ) = A. - transA = CBLAS_TRANSPOSE::CblasTrans, op( A ) = A**T. - transA = CBLAS_TRANSPOSE::CblasConjTrans, op( A ) = A**H. - - \param[in] transB - transB is CBLAS_TRANSPOSE - On entry, transB specifies the form of op( B ) to be used in - the matrix multiplication as follows: - transB = CBLAS_TRANSPOSE::CblasNoTrans, op( B ) = B. - transB = CBLAS_TRANSPOSE::CblasTrans, op( B ) = B**T. - transB = CBLAS_TRANSPOSE::CblasConjTrans, op( B ) = B**H. - - \param[in] m - m is INTEGER - On entry, m specifies the number of rows of the matrix - op( A ) and of the matrix C. m must be at least zero. - - \param[in] n - n is INTEGER - On entry, n specifies the number of columns of the matrix - op( B ) and the number of columns of the matrix C. n must be - at least zero. - - \param[in] k - k is INTEGER - On entry, k specifies the number of columns of the matrix - op( A ) and the number of rows of the matrix op( B ). k must - be at least zero. - - \param[in] alpha - alpha is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 - On entry, alpha specifies the scalar alpha. - - \param[in] A - A is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array,dimension : - If transA = CblasNoTrans: - m-by-k , stored in an lda-by-k array [RowMajor: m-by-lda]. - Otherwise: - k-by-m , stored in an lda-by-m array [RowMajor: k-by-lda]. - - \param[in] lda - lda is INTEGER - On entry, lda specifies the Leading dimension of A - If transA = CblasNoTrans: lda >= max(1, m) [RowMajor: lda >= max(1, k)]. - Otherwise: lda >= max(1, k) [RowMajor: lda >= max(1, m)]. - - \param[in] B - B is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array,dimension : - If transA = CblasNoTrans: - k-by-n , stored in an ldb-by-n array [RowMajor: k-by-ldb]. - Otherwise: - n-by-k , stored in an ldb-by-k array [RowMajor: n-by-ldb]. - - \param[in] ldb - ldb is INTEGER - On entry, ldb specifies the Leading dimension of B - If transA = CblasNoTrans: ldb >= max(1, k) [RowMajor: ldb >= max(1, n)]. - Otherwise: ldb >= max(1, n) [RowMajor: ldb >= max(1, k)]. 
- - \param[in] beta - beta is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 - On entry, beta specifies the scalar beta. When beta is - supplied as zero then C need not be set on input. - - \param[in,out] C - C is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array, dimension : - m-by-n stored in an ldc-by-n array [RowMajor: m-by-ldc]. - Before entry, the leading m by n part of the array C must - contain the matrix C, except when beta is zero, in which - case C need not be set on entry. - On exit, the array C is overwritten by the m by n matrix - ( alpha*op( A )*op( B ) + beta*C ). - - \param[in] ldc - ldc is INTEGER - On entry, ldc specifies the first dimension of C - ldc >= max(1, m) [RowMajor: ldc >= max(1, n)]. - */ -template< typename T > -void gemm( - CBLAS_ORDER layout, - CBLAS_TRANSPOSE transA, - CBLAS_TRANSPOSE transB, - int64_t m, int64_t n, int64_t k, - T alpha, - T const *A, int64_t lda, - T const *B, int64_t ldb, - T beta, - T *C, int64_t ldc ) -{ - cblas_gemm(layout, transA, transB, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); -} - -/*! \brief Solve the triangular matrix-matrix equation for arbitrary data types - - \b Purpose: - - TRSM solves one of the matrix equations for arbitrary data types - Data precisions supported include SINGLE PRECISION REAL, DOUBLE PRECISION REAL, - SINGLE PRECISION COMPLEX, DOUBLE PRECISION COMPLEX(COMPLEX*16) - - op( A )*X = alpha*B, or X*op( A ) = alpha*B, - - where alpha is a scalar, X and B are m by n matrices, A is a unit, or - non-unit, upper or lower triangular matrix and op( A ) is one of - - op( A ) = A or op( A ) = A**T or op( A ) = A**H. - - The matrix X is overwritten on B. - - \param[in] layout - layout is enum CBLAS_ORDER - layout specifies Matrix storage as follows: - layout = CBLAS_ORDER::CblasRowMajor or Layout::CblasColMajor. - - \param[in] side - side is enum CBLAS_SIDE - side specifies whether op( A ) appears on the left - or right of X as follows: - side = CBLAS_SIDE::CblasLeft op( A )*X = alpha*B. - side = CBLAS_SIDE::CblasRight X*op( A ) = alpha*B. - - \param[in] uplo - uplo is enum CBLAS_UPLO - uplo specifies whether the matrix A is an upper or - lower triangular matrix as follows: - uplo = CBLAS_UPLO::CblasUpper A is an upper triangular matrix. - uplo = CBLAS_UPLO::CblasLower A is a lower triangular matrix. - - \param[in] trans - trans is CBLAS_TRANSPOSE - On entry, trans specifies the form of op( A ) to be used in - the matrix multiplication as follows: - trans = CBLAS_TRANSPOSE::CblasNoTrans, op( A ) = A. - trans = CBLAS_TRANSPOSE::CblasTrans, op( A ) = A**T. - trans = CBLAS_TRANSPOSE::CblasConjTrans, op( A ) = A**H. - - \param[in] diag - diag is enum CBLAS_DIAG - diag specifies whether or not A is unit triangular - as follows: - diag = CBLAS_DIAG::CblasUnit A is assumed to be unit triangular. - diag = CBLAS_DIAG::CblasNonUnit A is not assumed to be unit - triangular. - - \param[in] m - m is INTEGER - On entry, m specifies the number of rows of the matrix - B. m must be at least zero. - - \param[in] n - n is INTEGER - On entry, n specifies the number of columns of the matrix - B. n must be at least zero. - - \param[in] alpha - alpha is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 - On entry, alpha specifies the scalar alpha. - - \param[in] A - A is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array,dimension : - If side = CblasLeft: - the m-by-m matrix A, stored in an lda-by-m array [RowMajor: m-by-lda]. 
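A minimal sketch of the gemm wrapper defined above; the "blis.hh" include name and the operands are assumptions for illustration.

#include "blis.hh"   // assumed include name

int main()
{
    // C(2x2) := alpha * A(2x3) * B(3x2) + beta * C, all column-major
    double A[6] = { 1.0, 4.0,   2.0, 5.0,   3.0, 6.0 };   // lda = 2
    double B[6] = { 1.0, 0.0, 1.0,   0.0, 1.0, 1.0 };     // ldb = 3
    double C[4] = { 0.0, 0.0, 0.0, 0.0 };                 // ldc = 2

    blis::gemm( CblasColMajor, CblasNoTrans, CblasNoTrans,
                2, 2, 3, 1.0, A, 2, B, 3, 0.0, C, 2 );
    return 0;
}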
- If side = CblasRight: - the n-by-n matrix A, stored in an lda-by-n array [RowMajor: n-by-lda]. - - \param[in] lda - lda is INTEGER - On entry, lda specifies the Leading dimension of A - If side = CblasLeft: lda >= max(1, m) . - If side = CblasRight:lda >= max(1, k) . - - \param[in,out] B - B is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array,dimension : - m-by-n , stored in an ldb-by-n array [RowMajor: m-by-ldb]. - on exit is overwritten by the solution matrix X. - - \param[in] ldb - ldb is INTEGER - On entry, ldb specifies the Leading dimension of B - ldb >= max(1, m) [RowMajor: ldb >= max(1, n)]. - */ -template< typename T > -void trsm( - CBLAS_ORDER layout, - CBLAS_SIDE side, - CBLAS_UPLO uplo, - CBLAS_TRANSPOSE trans, - CBLAS_DIAG diag, - int64_t m, - int64_t n, - T alpha, - T const *A, int64_t lda, - T *B, int64_t ldb ) -{ - cblas_trsm( layout, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb); -} -/*! \brief Solve the Triangular matrix-matrix multiply for arbitrary data types - - \b Purpose: - - TRMM performs solves one of the matrix equations for arbitrary data types - Data precisions supported include SINGLE PRECISION REAL, DOUBLE PRECISION REAL, - SINGLE PRECISION COMPLEX, DOUBLE PRECISION COMPLEX(COMPLEX*16) - - B := alpha*op( A )*B, or B := alpha*B*op( A ), - - where alpha is a scalar, B is an m by n matrices, A is a unit, or - non-unit, upper or lower triangular matrix and op( A ) is one of - op( A ) = A or op( A ) = A**T. - - \param[in] layout - layout is enum CBLAS_ORDER - layout specifies Matrix storage as follows: - layout = CBLAS_ORDER::CblasRowMajor or Layout::CblasColMajor. - - \param[in] side - side is enum CBLAS_SIDE - side specifies whether op( A ) multiplies B from left or right of X - as follows: - side = CBLAS_SIDE::CblasLeft B := alpha*op( A )*B. - side = CBLAS_SIDE::CblasRight B := alpha*B*op( A ). - - \param[in] uplo - uplo is enum CBLAS_UPLO - uplo specifies whether the matrix A is an upper or lower triangular - matrix as follows: - uplo = CBLAS_UPLO::CblasUpper A is an upper triangular matrix. - uplo = CBLAS_UPLO::CblasLower A is a lower triangular matrix. - - \param[in] trans - trans is CBLAS_TRANSPOSE - On entry, trans specifies the form of op( A ) to be used in - the matrix multiplication as follows: - trans = CBLAS_TRANSPOSE::CblasNoTrans, op( A ) = A. - trans = CBLAS_TRANSPOSE::CblasTrans, op( A ) = A**T. - trans = CBLAS_TRANSPOSE::CblasConjTrans, op( A ) = A**T. - - \param[in] diag - diag is enum CBLAS_DIAG - diag specifies specifies whether or not A is unit triangular - as follows: - diag = CBLAS_DIAG::CblasUnit A is assumed to be unit triangular. - diag = CBLAS_DIAG::CblasNonUnit A is not assumed to be unit - triangular. - - \param[in] m - m is INTEGER - On entry, m specifies the number of rows of the matrix - B. m must be at least zero. - - \param[in] n - n is INTEGER - On entry, n specifies the number of columns of the matrix - B. n must be at least zero. - - \param[in] alpha - alpha is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 - On entry, alpha specifies the scalar alpha.When alpha is - zero then A is not referenced and B need not be set before - entry. - - \param[in] A - A is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array,dimension : - If side = CblasLeft: - the m-by-m matrix A, stored in an lda-by-m array [RowMajor: m-by-lda]. - If side = CblasRight: - the n-by-n matrix A, stored in an lda-by-n array [RowMajor: n-by-lda]. 
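A minimal sketch of the trsm wrapper above, solving a small lower-triangular system in place; the include name and values are assumptions.

#include "blis.hh"   // assumed include name

int main()
{
    // Solve op( A )*X = alpha*B with A 2x2 lower triangular, column-major
    double A[4] = { 2.0, 1.0,
                    0.0, 4.0 };          // lda = 2, strictly upper part ignored
    double B[4] = { 2.0, 5.0,
                    4.0, 6.0 };          // ldb = 2, overwritten by the solution X

    blis::trsm( CblasColMajor, CblasLeft, CblasLower, CblasNoTrans, CblasNonUnit,
                2, 2, 1.0, A, 2, B, 2 );
    return 0;
}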
- - \param[in] lda - lda is INTEGER - On entry, lda specifies the Leading dimension of A - If side = CblasLeft: lda >= max(1, m) . - If side = CblasRight:lda >= max(1, n) . - - \param[in,out] B - B is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array,dimension : - m-by-n , stored in an ldb-by-n array [RowMajor: m-by-ldb]. - - \param[in] ldb - ldb is INTEGER - On entry, ldb specifies the Leading dimension of B - ldb >= max(1, m) [RowMajor: ldb >= max(1, n)]. - */ -template< typename T > -void trmm( - CBLAS_ORDER layout, - CBLAS_SIDE side, - CBLAS_UPLO uplo, - CBLAS_TRANSPOSE trans, - CBLAS_DIAG diag, - int64_t m, - int64_t n, - T alpha, - T const *A, int64_t lda, - T *B, int64_t ldb ) -{ - cblas_trmm( layout, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb); -} - -/*! \brief Solve the Hermitian matrix-matrix multiply for arbitrary data types - - \b Purpose: - - HEMM performs solves one of the matrix-matrix operations for arbitrary data types - Data precisions supported include SINGLE PRECISION COMPLEX, DOUBLE PRECISION COMPLEX(COMPLEX*16) - - C := alpha*A*B + beta*C - - or - - C := alpha*B*A + beta*C, - - where alpha is a scalar, A is an hermitian matrix - C and B are m by n matrices - - \param[in] layout - layout is enum CBLAS_ORDER - layout specifies Matrix storage as follows: - layout = CBLAS_ORDER::CblasRowMajor or Layout::CblasColMajor. - - \param[in] side - side is enum CBLAS_SIDE - side specifies specifies whether the hermitian matrix A - appears on the left or right in the operation as follows: - side = CBLAS_SIDE::CblasLeft C := alpha*A*B + beta*C, - side = CBLAS_SIDE::CblasRight C := alpha*B*A + beta*C - - \param[in] uplo - uplo is enum CBLAS_UPLO - uplo specifies specifies whether the upper or lower - triangular part of the hermitian matrix A is to be - referenced as follows: - uplo = CBLAS_UPLO::CblasUpper Only the upper triangular part of the - hermitian matrix is to be referenced. - uplo = CBLAS_UPLO::CblasLower Only the lower triangular part of the - hermitian matrix is to be referenced. - - \param[in] m - m is INTEGER - On entry, m specifies the number of rows of the matrix - C. m must be at least zero. - - \param[in] n - n is INTEGER - On entry, n specifies the number of columns of the matrix - C. n must be at least zero. - - \param[in] alpha - alpha is COMPLEX/COMPLEX*16 - On entry, alpha specifies the scalar alpha. - - \param[in] A - A is COMPLEX/COMPLEX*16 array,dimension : - If side = CblasLeft: - the m-by-m matrix A, stored in an lda-by-m array [RowMajor: m-by-lda]. - If side = CblasRight: - the n-by-n matrix A, stored in an lda-by-n array [RowMajor: n-by-lda]. - - \param[in] lda - lda is INTEGER - On entry, lda specifies the Leading dimension of A - If side = CblasLeft: lda >= max(1, m) . - If side = CblasRight:lda >= max(1, k) . - - \param[in] B - B is COMPLEX/COMPLEX*16 array,dimension : - m-by-n , stored in an ldb-by-n array [RowMajor: m-by-ldb]. - - \param[in] ldb - ldb is INTEGER - On entry, ldb specifies the Leading dimension of B - ldb >= max(1, m) [RowMajor: ldb >= max(1, n)]. - - \param[in] beta - beta is COMPLEX/COMPLEX*16 - On entry, beta specifies the scalar beta. - If beta is zero, C need not be set on input - - \param[in,out] C - C is COMPLEX/COMPLEX*16 array,dimension : - m-by-n , stored in an ldc-by-n array [RowMajor: m-by-ldc]. - - \param[in] ldc - ldc is INTEGER - On entry, ldc specifies the Leading dimension of C - ldc >= max(1, m) [RowMajor: ldc >= max(1, n)]. 
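A minimal sketch of the trmm wrapper above (include name and values assumed):

#include "blis.hh"   // assumed include name

int main()
{
    // B := alpha * op( A ) * B with A 2x2 upper triangular, column-major
    double A[4] = { 3.0, 0.0,
                    1.0, 2.0 };          // lda = 2, strictly lower part ignored
    double B[4] = { 1.0, 2.0,
                    3.0, 4.0 };          // ldb = 2, overwritten by the product

    blis::trmm( CblasColMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit,
                2, 2, 1.0, A, 2, B, 2 );
    return 0;
}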
- */ -template< typename T > -void hemm( - CBLAS_ORDER layout, - CBLAS_SIDE side, - CBLAS_UPLO uplo, - int64_t m, int64_t n, - T alpha, - T const *A, int64_t lda, - T const *B, int64_t ldb, - T beta, - T *C, int64_t ldc ) -{ - cblas_hemm( layout, side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc); -} - -/*! \brief Solve the Symmetric matrix-matrix multiply for arbitrary data types - - \b Purpose: - - SYMM performs solves one of the matrix-matrix operations for arbitrary data types - Data precisions supported include SINGLE PRECISION REAL, DOUBLE PRECISION REAL, - SINGLE PRECISION COMPLEX, DOUBLE PRECISION COMPLEX(COMPLEX*16) - - C := alpha*A*B + beta*C - - or - - C := alpha*B*A + beta*C, - - where alpha is a scalar, A is an symmetric matrix - C and B are m by n matrices - - \param[in] layout - layout is enum CBLAS_ORDER - layout specifies Matrix storage as follows: - layout = CBLAS_ORDER::CblasRowMajor or Layout::CblasColMajor. - - \param[in] side - side is enum CBLAS_SIDE - side specifies specifies whether the symmetric matrix A - appears on the left or right in the operation as follows: - side = CBLAS_SIDE::CblasLeft C := alpha*A*B + beta*C, - side = CBLAS_SIDE::CblasRight C := alpha*B*A + beta*C - - \param[in] uplo - uplo is enum CBLAS_UPLO - uplo specifies specifies whether the upper or lower - triangular part of the symmetric matrix A is to be - referenced as follows: - uplo = CBLAS_UPLO::CblasUpper Only the upper triangular part of the - symmetric matrix is to be referenced. - uplo = CBLAS_UPLO::CblasLower Only the lower triangular part of the - symmetric matrix is to be referenced. - - \param[in] m - m is INTEGER - On entry, m specifies the number of rows of the matrix - C. m must be at least zero. - - \param[in] n - n is INTEGER - On entry, n specifies the number of columns of the matrix - C. n must be at least zero. - - \param[in] alpha - alpha is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 - On entry, alpha specifies the scalar alpha. - - \param[in] A - A is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array,dimension : - If side = CblasLeft: - the m-by-m matrix A, stored in an lda-by-m array [RowMajor: m-by-lda]. - If side = CblasRight: - the n-by-n matrix A, stored in an lda-by-n array [RowMajor: n-by-lda]. - - \param[in] lda - lda is INTEGER - On entry, lda specifies the Leading dimension of A - If side = CblasLeft: lda >= max(1, m) . - If side = CblasRight:lda >= max(1, k) . - - \param[in] B - B is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array,dimension : - m-by-n , stored in an ldb-by-n array [RowMajor: m-by-ldb]. - - \param[in] ldb - ldb is INTEGER - On entry, ldb specifies the Leading dimension of B - ldb >= max(1, m) [RowMajor: ldb >= max(1, n)]. - - \param[in] beta - beta is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 - On entry, beta specifies the scalar beta. - If beta is zero, C need not be set on input - - \param[in, out] C - C is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array,dimension : - m-by-n , stored in an ldc-by-n array [RowMajor: m-by-ldc]. - - \param[in] ldc - ldc is INTEGER - On entry, ldc specifies the Leading dimension of C - ldc >= max(1, m) [RowMajor: ldc >= max(1, n)]. - */ -template< typename T > -void symm( - CBLAS_ORDER layout, - CBLAS_SIDE side, - CBLAS_UPLO uplo, - int64_t m, int64_t n, - T alpha, - T const *A, int64_t lda, - T const *B, int64_t ldb, - T beta, - T *C, int64_t ldc ) -{ - cblas_symm( layout, side, uplo, m, n, alpha, A, lda, B, ldb, beta, C, ldc); -} - -/*! 
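A minimal sketch of the symm wrapper above (include name and operand values assumed):

#include "blis.hh"   // assumed include name

int main()
{
    // C(2x2) := alpha * A * B + beta * C with A 2x2 symmetric, column-major
    double A[4] = { 4.0, 1.0,
                    1.0, 3.0 };          // lda = 2, lower triangle referenced
    double B[4] = { 1.0, 0.0,
                    2.0, 1.0 };          // ldb = 2
    double C[4] = { 0.0, 0.0,
                    0.0, 0.0 };          // ldc = 2

    blis::symm( CblasColMajor, CblasLeft, CblasLower,
                2, 2, 1.0, A, 2, B, 2, 0.0, C, 2 );
    return 0;
}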
\brief Solve the Symmetric rank-k operations for arbitrary data types - - \b Purpose: - - SYRK performs one of the symmetric rank k operations for arbitrary data types - Data precisions supported include SINGLE PRECISION REAL, DOUBLE PRECISION REAL, - SINGLE PRECISION COMPLEX, DOUBLE PRECISION COMPLEX(COMPLEX*16) - - C := alpha*A*A**T + beta*C, - - or - - C := alpha*A**T*A + beta*C, - - where alpha and beta are scalars, C is an n by n symmetric matrix - and A is an n by k matrix in the first case and a k by n matrix - in the second case. - - \param[in] layout - layout is enum CBLAS_LAYOUT - layout specifies Matrix storage as follows: - layout = CBLAS_LAYOUT::CblasRowMajor or Layout::CblasColMajor. - - \param[in] uplo - uplo is enum CBLAS_UPLO - uplo specifies specifies whether the upper or lower - triangular part of the array C is to be referenced - as follows: - uplo = CBLAS_UPLO::CblasUpper Only the upper triangular part of C - is to be referenced. - uplo = CBLAS_UPLO::CblasLower Only the lower triangular part of C - is to be referenced. - - \param[in] trans - trans is CBLAS_TRANSPOSE - On entry, trans specifies the operation to be used as follows: - trans = CBLAS_TRANSPOSE::CblasNoTrans,C := alpha*A*A**T + beta*C. - trans = CBLAS_TRANSPOSE::CblasTrans,C := alpha*A**T*A + beta*C. - - \param[in] n - n is INTEGER - On entry, n specifies the order of the matrix C. n must be - at least zero. - - \param[in] k - k is INTEGER - If trans = CblasNoTrans: k is number of columns of the matrix A. - Otherwise: k is number of rows of the matrix A. - k must be at least zero. - - \param[in] alpha - alpha is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 - On entry, alpha specifies the scalar alpha. - - \param[in] A - A is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array,dimension : - If transA = CblasNoTrans: - n-by-k , stored in an lda-by-k array [RowMajor: n-by-lda]. - Otherwise: - k-by-n , stored in an lda-by-n array [RowMajor: k-by-lda]. - - \param[in] lda - lda is INTEGER - On entry, lda specifies the Leading dimension of A - If transA = CblasNoTrans: lda >= max(1, n) [RowMajor: lda >= max(1, k)]. - Otherwise: lda >= max(1, k) [RowMajor: lda >= max(1, n)]. - - \param[in] beta - beta is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 - On entry, beta specifies the scalar alpha.When beta is - supplied as zero then C need not be set on input. - - \param[in,out] C - C is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array, dimension : - The n-by-n symmetric matrix C, - stored in an ldc-by-n array [RowMajor: n-by-ldc]. - On exit, the array C is overwritten by the lower/upper - triangular part of the updated matrix. - - \param[in] ldc - ldc is INTEGER - On entry, ldc specifies the first dimension of C - ldc >= max(1, n) - */ -template< typename T > -void syrk( - CBLAS_ORDER layout, - CBLAS_UPLO uplo, - CBLAS_TRANSPOSE trans, - int64_t n, int64_t k, - T alpha, - T const *A, int64_t lda, - T beta, - T *C, int64_t ldc ) -{ - cblas_syrk( layout, uplo, trans, n, k, alpha, A, lda, beta, C, ldc); -} - -/*! 
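A minimal sketch of the syrk wrapper above (include name and values assumed):

#include "blis.hh"   // assumed include name

int main()
{
    // C(3x3, symmetric) := alpha * A * A**T + beta * C with A 3x2, column-major
    double A[6] = { 1.0, 2.0, 3.0,
                    4.0, 5.0, 6.0 };     // n = 3, k = 2, lda = 3
    double C[9] = { 0.0 };               // ldc = 3, only the upper triangle is updated

    blis::syrk( CblasColMajor, CblasUpper, CblasNoTrans,
                3, 2, 1.0, A, 3, 0.0, C, 3 );
    return 0;
}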
\brief Solve the Symmetric rank 2k operations for arbitrary data types - - \b Purpose: - - SYR2K performs one of the symmetric rank 2k operations for arbitrary data types - Data precisions supported include SINGLE PRECISION REAL, DOUBLE PRECISION REAL, - SINGLE PRECISION COMPLEX, DOUBLE PRECISION COMPLEX(COMPLEX*16) - - C := alpha*A*B**T + alpha*B*A**T + beta*C, - - or - - C := alpha*A**T*B + alpha*B**T*A + beta*C, - - where alpha and beta are scalars, C is an n by n symmetric matrix - and A and B are n by k matrices in the first case and k by n matrices - in the second case. - - \param[in] layout - layout is enum CBLAS_LAYOUT - layout specifies Matrix storage as follows: - layout = CBLAS_LAYOUT::CblasRowMajor or Layout::CblasColMajor. - - \param[in] uplo - uplo is enum CBLAS_UPLO - uplo specifies specifies whether the upper or lower - triangular part of the array C is to be referenced - as follows: - uplo = CBLAS_UPLO::CblasUpper Only the upper triangular part of C - is to be referenced. - uplo = CBLAS_UPLO::CblasLower Only the lower triangular part of C - is to be referenced. - - \param[in] trans - trans is CBLAS_TRANSPOSE - On entry, trans specifies the operation to be used as follows: - trans = CBLAS_TRANSPOSE::CblasNoTrans,C := alpha*A*B**T + alpha*B*A**T + beta*C. - trans = CBLAS_TRANSPOSE::CblasTrans, C := alpha*A**T*B + alpha*B**T*A + beta*C. - - \param[in] n - n is INTEGER - On entry, n specifies the order of the matrix C. n must be - at least zero. - - \param[in] k - k is INTEGER - If trans = CblasNoTrans: k is number of columns of the matrices A & B. - Otherwise: k is number of rows of the matrices A & B. - k must be at least zero. - - \param[in] alpha - alpha is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 - On entry, alpha specifies the scalar alpha. - - \param[in] A - A is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array,dimension : - If trans = CblasNoTrans: - n-by-k , stored in an lda-by-k array [RowMajor: n-by-lda]. - Otherwise: - k-by-n , stored in an lda-by-n array [RowMajor: k-by-lda]. - - \param[in] lda - lda is INTEGER - On entry, lda specifies the Leading dimension of A - If trans = CblasNoTrans: lda >= max(1, n) [RowMajor: lda >= max(1, k)]. - Otherwise: lda >= max(1, k) [RowMajor: lda >= max(1, n)]. - - \param[in] B - B is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array,dimension : - If trans = CblasNoTrans: - n-by-k , stored in an ldb-by-k array [RowMajor: n-by-ldb]. - Otherwise: - k-by-n , stored in an ldb-by-n array [RowMajor: k-by-ldb] - - \param[in] ldb - ldb is INTEGER - On entry, ldb specifies the Leading dimension of B - If trans = CblasNoTrans: ldb >= max(1, n) [RowMajor: ldb >= max(1, k)]. - Otherwise: ldb >= max(1, k) [RowMajor: ldb >= max(1, n)]. - - \param[in] beta - beta is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 - On entry, beta specifies the scalar alpha.When beta is - supplied as zero then C need not be set on input. - - \param[in,out] C - C is REAL/DOUBLE PRECISION/COMPLEX/COMPLEX*16 array, dimension : - The n-by-n symmetric matrix C, - stored in an ldc-by-n array [RowMajor: n-by-ldc]. - On exit, the array C is overwritten by the lower/upper - triangular part of the updated matrix. 
- - \param[in] ldc - ldc is INTEGER - On entry, ldc specifies the first dimension of C - ldc >= max(1, n) - */ -template< typename T > -void syr2k( - CBLAS_ORDER layout, - CBLAS_UPLO uplo, - CBLAS_TRANSPOSE trans, - int64_t n, int64_t k, - T alpha, - T const *A, int64_t lda, - T const *B, int64_t ldb, - T beta, - T *C, int64_t ldc ) -{ - cblas_syr2k( layout, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc ); -} - -/*! \brief Solve the Hermitian rank k operations for arbitrary data types - - \b Purpose: - - HERK performs one of the hermitian rank k operations for arbitrary data types - Data precisions supported include SINGLE PRECISION COMPLEX, - DOUBLE PRECISION COMPLEX(COMPLEX*16) - - C := alpha*A*A**H + beta*C, - - or - - C := alpha*A**H*A + beta*C, - - where alpha and beta are real scalars, C is an n by n hermitian - matrix and A is an n by k matrix in the first case and - k by n matrix in the second case. - - \param[in] layout - layout is enum CBLAS_LAYOUT - layout specifies Matrix storage as follows: - layout = CBLAS_LAYOUT::CblasRowMajor or Layout::CblasColMajor. - - \param[in] uplo - uplo is enum CBLAS_UPLO - uplo specifies whether the upper or lower - triangular part of the array C is to be referenced - as follows: - uplo = CBLAS_UPLO::CblasUpper Only the upper triangular part of C - is to be referenced. - uplo = CBLAS_UPLO::CblasLower Only the lower triangular part of C - is to be referenced. - - \param[in] trans - trans is CBLAS_TRANSPOSE - On entry, trans specifies the operation to be used as follows: - trans = CBLAS_TRANSPOSE::CblasNoTrans, C := alpha*A*A**H + beta*C. - trans = CBLAS_TRANSPOSE::CblasConjTrans,C := alpha*A**H*A + beta*C. - - \param[in] n - n is INTEGER - On entry, n specifies the order of the matrix C. n must be - at least zero. - - \param[in] k - k is INTEGER - If trans = CblasNoTrans: k is number of columns of the matrix A. - Otherwise: k is number of rows of the matrix A. - k must be at least zero. - - \param[in] alpha - alpha is REAL/DOUBLE PRECISION - On entry, alpha specifies the scalar alpha. - - \param[in] A - A is COMPLEX/COMPLEX*16 array,dimension : - If trans = CblasNoTrans: - n-by-k , stored in an lda-by-k array [RowMajor: n-by-lda]. - Otherwise: - k-by-n , stored in an lda-by-n array [RowMajor: k-by-lda]. - - \param[in] lda - lda is INTEGER - On entry, lda specifies the Leading dimension of A - If trans = CblasNoTrans: lda >= max(1, n) [RowMajor: lda >= max(1, k)]. - Otherwise: lda >= max(1, k) [RowMajor: lda >= max(1, n)]. - - \param[in] beta - beta is REAL/DOUBLE PRECISION - On entry, beta specifies the scalar beta. When beta is - supplied as zero then C need not be set on input. - - \param[in,out] C - C is COMPLEX/COMPLEX*16 array, dimension : - The n-by-n Hermitian matrix C, - stored in an ldc-by-n array [RowMajor: n-by-ldc]. - On exit, the array C is overwritten by the lower/upper - triangular part of the updated matrix. - - \param[in] ldc - ldc is INTEGER - On entry, ldc specifies the first dimension of C - ldc >= max(1, n) - */ -template< typename T > -void herk( - CBLAS_ORDER layout, - CBLAS_UPLO uplo, - CBLAS_TRANSPOSE trans, - int64_t n, int64_t k, - real_type alpha, - T const *A, int64_t lda, - real_type beta, - T *C, int64_t ldc ) -{ - cblas_herk( layout, uplo, trans, n, k, alpha, A, lda, beta, C, ldc ); -} - -/*! 
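A minimal sketch of the herk wrapper above; note that alpha and beta are real scalars here, unlike the complex alpha of her2k below. The include name and values are assumptions.

#include <complex>
#include "blis.hh"   // assumed include name

int main()
{
    // C(2x2, Hermitian) := alpha * A * A**H + beta * C with A 2x3, column-major
    std::complex<double> A[6] = { {1.0, 1.0}, {2.0, 0.0},
                                  {0.0, 1.0}, {1.0, -1.0},
                                  {3.0, 0.0}, {0.0, 2.0} };   // n = 2, k = 3, lda = 2
    std::complex<double> C[4] = {};                           // ldc = 2

    blis::herk( CblasColMajor, CblasLower, CblasNoTrans,
                2, 3, 1.0, A, 2, 0.0, C, 2 );
    return 0;
}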
\brief Solve the Hermitian rank 2k operations for arbitrary data types - - \b Purpose: - - HER2K performs one of the hermitian rank 2k operations for arbitrary data types - Data precisions supported include SINGLE PRECISION COMPLEX, - DOUBLE PRECISION COMPLEX(COMPLEX*16) - - C := alpha*A*B**H + conjg( alpha )*B*A**H + beta*C, - - or - - C := alpha*A**H*B + conjg( alpha )*B**H*A + beta*C, - - where alpha and beta are scalars with beta real, C is an n by n - hermitian matrix and A and B are n by k matrices in the first case - and k by n matrices in the second case. - - \param[in] layout - layout is enum CBLAS_LAYOUT - layout specifies Matrix storage as follows: - layout = CBLAS_LAYOUT::CblasRowMajor or Layout::CblasColMajor. - - \param[in] uplo - uplo is enum CBLAS_UPLO - uplo specifies specifies whether the upper or lower - triangular part of the array C is to be referenced - as follows: - uplo = CBLAS_UPLO::CblasUpper Only the upper triangular part of C - is to be referenced. - uplo = CBLAS_UPLO::CblasLower Only the lower triangular part of C - is to be referenced. - - \param[in] trans - trans is CBLAS_TRANSPOSE - On entry, trans specifies the operation to be used as follows: - trans = CBLAS_TRANSPOSE::CblasNoTrans, C := alpha*A*B**H + conjg( alpha )*B*A**H + beta*C. - trans = CBLAS_TRANSPOSE::CblasConjTrans,C := alpha*A**H*B + conjg( alpha )*B**H*A + beta*C. - - \param[in] n - n is INTEGER - On entry, n specifies the order of the matrix C. n must be - at least zero. - - \param[in] k - k is INTEGER - If trans = CblasNoTrans: k is number of columns of the matrices A & B. - Otherwise: k is number of rows of the matrices A & B. - k must be at least zero. - - \param[in] alpha - alpha is COMPLEX/COMPLEX*16 - On entry, alpha specifies the scalar alpha. - - \param[in] A - A is COMPLEX/COMPLEX*16 array,dimension : - If trans = CblasNoTrans: - n-by-k , stored in an lda-by-k array [RowMajor: n-by-lda]. - Otherwise: - k-by-n , stored in an lda-by-n array [RowMajor: k-by-lda]. - - \param[in] lda - lda is INTEGER - On entry, lda specifies the Leading dimension of A - If trans = CblasNoTrans: lda >= max(1, n) [RowMajor: lda >= max(1, k)]. - Otherwise: lda >= max(1, k) [RowMajor: lda >= max(1, n)]. - - \param[in] B - B is COMPLEX/COMPLEX*16 array,dimension : - If trans = CblasNoTrans: - n-by-k , stored in an ldb-by-k array [RowMajor: n-by-ldb]. - Otherwise: - k-by-n , stored in an ldb-by-n array [RowMajor: k-by-ldb] - - \param[in] ldb - ldb is INTEGER - On entry, ldb specifies the Leading dimension of B - If trans = CblasNoTrans: ldb >= max(1, n) [RowMajor: ldb >= max(1, k)]. - Otherwise: ldb >= max(1, k) [RowMajor: ldb >= max(1, n)]. - - \param[in] beta - beta is REAL/DOUBLE PRECISION - On entry, beta specifies the scalar alpha.When beta is - supplied as zero then C need not be set on input. - - \param[in,out] C - C is COMPLEX/COMPLEX*16 array, dimension : - The n-by-n Hermitian matrix C, - stored in an ldc-by-n array [RowMajor: n-by-ldc]. - On exit, the array C is overwritten by the lower/upper - triangular part of the updated matrix. 
- - \param[in] ldc - ldc is INTEGER - On entry, ldc specifies the first dimension of C - ldc >= max(1, n) - */ -template< typename T > -void her2k( - CBLAS_ORDER layout, - CBLAS_UPLO uplo, - CBLAS_TRANSPOSE trans, - int64_t n, int64_t k, - T alpha, - T const *A, int64_t lda, - T const *B, int64_t ldb, - real_type beta, - T *C, int64_t ldc ) -{ - cblas_her2k( layout, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc ); -} - -} // namespace blis -#endif // #ifndef BLIS_HH diff --git a/cpp/cblas.hh b/cpp/cblas.hh deleted file mode 100644 index b656ed28e1..0000000000 --- a/cpp/cblas.hh +++ /dev/null @@ -1,1705 +0,0 @@ -/****************************************************************************** -* Copyright (c) 2019 - present Advanced Micro Devices, Inc. All rights reserved. -* -* Permission is hereby granted, free of charge, to any person obtaining a copy -* of this software and associated documentation files (the "Software"), to deal -* in the Software without restriction, including without limitation the rights -* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -* copies of the Software, and to permit persons to whom the Software is -* furnished to do so, subject to the following conditions: -* -* The above copyright notice and this permission notice shall be included in -* all copies or substantial portions of the Software. -* -* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -* THE SOFTWARE. -*******************************************************************************/ - -/*! @file cblas.hh - * cblas.hh defines all the overloaded CPP functions to be invoked from - * template interfaces - * */ -#ifndef CBLAS_HH -#define CBLAS_HH - -extern "C" { -#include -} - -#include - -namespace blis{ - -template< typename... Types > struct real_type_traits; - -//define real_type<> type alias -template< typename... Types > -using real_type = typename real_type_traits< Types... 
>::real_t; - -// for one type -template< typename T > -struct real_type_traits -{ - using real_t = T; -}; - -// for one complex type, strip complex -template< typename T > -struct real_type_traits< std::complex > -{ - using real_t = T; -}; - -// ============================================================================= -// Level 1 BLAS -// ----------------------------------------------------------------------------- -inline void -cblas_rotg( - float *a, float *b, - float *c, float *s ) -{ - cblas_srotg( a, b, c, s ); -} - -inline void -cblas_rotg( - double *a, double *b, - double *c, double *s ) -{ - cblas_drotg( a, b, c, s ); -} - -// ----------------------------------------------------------------------------- -inline void -cblas_rotmg( - float *d1, float *d2, float *x1, float y1, float param[5] ) -{ - cblas_srotmg( d1, d2, x1, y1, param ); -} - -inline void -cblas_rotmg( - double *d1, double *d2, double *x1, double y1, double param[5] ) -{ - cblas_drotmg( d1, d2, x1, y1, param ); -} - -// ----------------------------------------------------------------------------- -inline void -cblas_rot( - int n, - float *x, int incx, - float *y, int incy, - float c, float s ) -{ - cblas_srot( n, x, incx, y, incy, c, s ); -} - -inline void -cblas_rot( - int n, - double *x, int incx, - double *y, int incy, - double c, double s ) -{ - cblas_drot( n, x, incx, y, incy, c, s ); -} - -// ----------------------------------------------------------------------------- -inline void -cblas_rotm( - int n, - float *x, int incx, - float *y, int incy, - const float p[5] ) -{ - cblas_srotm( n, x, incx, y, incy, p ); -} - -inline void -cblas_rotm( - int n, - double *x, int incx, - double *y, int incy, - const double p[5] ) -{ - cblas_drotm( n, x, incx, y, incy, p ); -} - -// ----------------------------------------------------------------------------- -inline void -cblas_swap( - int n, - float* x, int incx, - float* y, int incy ) -{ - cblas_sswap( n, x, incx, y, incy ); -} - -inline void -cblas_swap( - int n, - double* x, int incx, - double* y, int incy ) -{ - cblas_dswap( n, x, incx, y, incy ); -} - -inline void -cblas_swap( - int n, - std::complex* x, int incx, - std::complex* y, int incy ) -{ - cblas_cswap( n, x, incx, y, incy ); -} - -inline void -cblas_swap( - int n, - std::complex* x, int incx, - std::complex* y, int incy ) -{ - cblas_zswap( n, x, incx, y, incy ); -} - -// ----------------------------------------------------------------------------- -inline void -cblas_scal( - int n, float alpha, - float* x, int incx ) -{ - cblas_sscal( n, alpha, x, incx ); -} - -inline void -cblas_scal( - int n, double alpha, - double* x, int incx ) -{ - cblas_dscal( n, alpha, x, incx ); -} - -inline void -cblas_scal( - int n, std::complex alpha, - std::complex* x, int incx ) -{ - cblas_cscal( n, &alpha, x, incx ); -} - -inline void -cblas_scal( - int n, std::complex alpha, - std::complex* x, int incx ) -{ - cblas_zscal( n, &alpha, x, incx ); -} - -inline void -cblas_scal( - int n, float alpha, - std::complex* x, int incx ) -{ - cblas_csscal( n, alpha, x, incx ); -} - -inline void -cblas_scal( - int n, double alpha, - std::complex* x, int incx ) -{ - cblas_zdscal( n, alpha, x, incx ); -} - -// ----------------------------------------------------------------------------- -inline void -cblas_copy( - int n, - float const *x, int incx, - float* y, int incy ) -{ - cblas_scopy( n, x, incx, y, incy ); -} - -inline void -cblas_copy( - int n, - double const *x, int incx, - double* y, int incy ) -{ - cblas_dcopy( n, x, incx, y, incy ); 
-} - -inline void -cblas_copy( - int n, - std::complex const *x, int incx, - std::complex* y, int incy ) -{ - cblas_ccopy( n, x, incx, y, incy ); -} - -inline void -cblas_copy( - int n, - std::complex const *x, int incx, - std::complex* y, int incy ) -{ - cblas_zcopy( n, x, incx, y, incy ); -} - -// ----------------------------------------------------------------------------- -inline void -cblas_axpy( - int n, float alpha, - float const *x, int incx, - float* y, int incy ) -{ - cblas_saxpy( n, alpha, x, incx, y, incy ); -} - -inline void -cblas_axpy( - int n, double alpha, - double const *x, int incx, - double* y, int incy ) -{ - cblas_daxpy( n, alpha, x, incx, y, incy ); -} - -inline void -cblas_axpy( - int n, std::complex alpha, - std::complex const *x, int incx, - std::complex* y, int incy ) -{ - cblas_caxpy( n, &alpha, x, incx, y, incy ); -} - -inline void -cblas_axpy( - int n, std::complex alpha, - std::complex const *x, int incx, - std::complex* y, int incy ) -{ - cblas_zaxpy( n, &alpha, x, incx, y, incy ); -} - -// ----------------------------------------------------------------------------- -inline float -cblas_dot( - int n, - float const *x, int incx, - float const *y, int incy ) -{ - return cblas_sdot( n, x, incx, y, incy ); -} - -inline double -cblas_dot( - int n, - double const *x, int incx, - double const *y, int incy ) -{ - return cblas_ddot( n, x, incx, y, incy ); -} -// ----------------------------------------------------------------------------- -inline std::complex -cblas_dotu( - int n, - std::complex const *x, int incx, - std::complex const *y, int incy ) -{ - std::complex result; - cblas_cdotu_sub( n, x, incx, y, incy, &result ); - return result; -} - -inline std::complex -cblas_dotu( - int n, - std::complex const *x, int incx, - std::complex const *y, int incy ) -{ - std::complex result; - cblas_zdotu_sub( n, x, incx, y, incy, &result ); - return result; -} - -// ----------------------------------------------------------------------------- -inline std::complex -cblas_dotc( - int n, - std::complex const *x, int incx, - std::complex const *y, int incy ) -{ - std::complex result; - cblas_cdotc_sub( n, x, incx, y, incy, &result ); - return result; -} - -inline std::complex -cblas_dotc( - int n, - std::complex const *x, int incx, - std::complex const *y, int incy ) -{ - std::complex result; - cblas_zdotc_sub( n, x, incx, y, incy, &result ); - return result; -} - -// ----------------------------------------------------------------------------- -inline int -cblas_iamax( - int n, float const *x, int incx ) -{ - return cblas_isamax( n, x, incx ); -} - -inline int -cblas_iamax( - int n, double const *x, int incx ) -{ - return cblas_idamax( n, x, incx ); -} - -inline int -cblas_iamax( - int n, std::complex const *x, int incx ) -{ - return cblas_icamax( n, x, incx ); -} - -inline int -cblas_iamax( - int n, std::complex const *x, int incx ) -{ - return cblas_izamax( n, x, incx ); -} - - -// ----------------------------------------------------------------------------- -inline float -cblas_nrm2( - int n, float const *x, int incx ) -{ - return cblas_snrm2( n, x, incx ); -} - -inline double -cblas_nrm2( - int n, double const *x, int incx ) -{ - return cblas_dnrm2( n, x, incx ); -} - -inline float -cblas_nrm2( - int n, std::complex const *x, int incx ) -{ - return cblas_scnrm2( n, x, incx ); -} - -inline double -cblas_nrm2( - int n, std::complex const *x, int incx ) -{ - return cblas_dznrm2( n, x, incx ); -} - -// 
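As a usage note for the overload set in this header (the cblas.hh removed by this patch): the unprefixed names dispatch to the s/d/c/z CBLAS routines purely by argument type. A minimal sketch, with the include name assumed:

#include "cblas.hh"   // assumed include name for this header

int main()
{
    double x[3] = { 1.0, 2.0, 3.0 };
    double y[3] = { 4.0, 5.0, 6.0 };

    // double arguments resolve to cblas_daxpy / cblas_ddot
    blis::cblas_axpy( 3, 2.0, x, 1, y, 1 );        // y := 2*x + y
    double d = blis::cblas_dot( 3, x, 1, y, 1 );   // d := x . y
    (void)d;
    return 0;
}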
----------------------------------------------------------------------------- -inline float -cblas_asum( - int n, float const *x, int incx ) -{ - return cblas_sasum( n, x, incx ); -} - -inline double -cblas_asum( - int n, double const *x, int incx ) -{ - return cblas_dasum( n, x, incx ); -} - -inline float -cblas_asum( - int n, std::complex const *x, int incx ) -{ - return cblas_scasum( n, x, incx ); -} - -inline double -cblas_asum( - int n, std::complex const *x, int incx ) -{ - return cblas_dzasum( n, x, incx ); -} -// ============================================================================= -// Level 2 BLAS - -// ----------------------------------------------------------------------------- -inline void -cblas_gemv( - CBLAS_ORDER layout, CBLAS_TRANSPOSE trans, int m, int n, - float alpha, - float const *A, int lda, - float const *x, int incx, - float beta, - float* y, int incy ) -{ - cblas_sgemv( layout, trans, m, n, - alpha, A, lda, x, incx, beta, y, incy ); -} - -inline void -cblas_gemv( - CBLAS_ORDER layout, CBLAS_TRANSPOSE trans, int m, int n, - double alpha, - double const *A, int lda, - double const *x, int incx, - double beta, - double* y, int incy ) -{ - cblas_dgemv( layout, trans, m, n, - alpha, A, lda, x, incx, beta, y, incy ); -} - -inline void -cblas_gemv( - CBLAS_ORDER layout, CBLAS_TRANSPOSE trans, int m, int n, - std::complex alpha, - std::complex const *A, int lda, - std::complex const *x, int incx, - std::complex beta, - std::complex* y, int incy ) -{ - cblas_cgemv( layout, trans, m, n, - &alpha, A, lda, x, incx, - &beta, y, incy ); -} - -inline void -cblas_gemv( - CBLAS_ORDER layout, CBLAS_TRANSPOSE trans, int m, int n, - std::complex alpha, - std::complex const *A, int lda, - std::complex const *x, int incx, - std::complex beta, - std::complex* y, int incy ) -{ - cblas_zgemv( layout, trans, m, n, - &alpha, A, lda, x, incx, - &beta, y, incy ); -} -inline void -cblas_gbmv( - CBLAS_ORDER layout, CBLAS_TRANSPOSE trans, - int m, int n, int kl, int ku, - float alpha, - float const *A, int lda, - float const *x, int incx, - float beta, - float* y, int incy ) -{ - cblas_sgbmv( layout, trans, m, n, kl, ku, - alpha, A, lda, x, incx, beta, y, incy ); -} - -inline void -cblas_gbmv( - CBLAS_ORDER layout, CBLAS_TRANSPOSE trans, - int m, int n, int kl, int ku, - double alpha, - double const *A, int lda, - double const *x, int incx, - double beta, - double* y, int incy ) -{ - cblas_dgbmv( layout, trans, m, n, kl, ku, - alpha, A, lda, x, incx, beta, y, incy ); -} - -inline void -cblas_gbmv( - CBLAS_ORDER layout, CBLAS_TRANSPOSE trans, - int m, int n, int kl, int ku, - std::complex alpha, - std::complex const *A, int lda, - std::complex const *x, int incx, - std::complex beta, - std::complex* y, int incy ) -{ - cblas_cgbmv( layout, trans, m, n, kl, ku, - &alpha, A, lda, x, incx, - &beta, y, incy ); -} - -inline void -cblas_gbmv( - CBLAS_ORDER layout, CBLAS_TRANSPOSE trans, - int m, int n, int kl, int ku, - std::complex alpha, - std::complex const *A, int lda, - std::complex const *x, int incx, - std::complex beta, - std::complex* y, int incy ) -{ - cblas_zgbmv( layout, trans, m, n, kl, ku, - &alpha, A, lda, x, incx, - &beta, y, incy ); -} - -// ----------------------------------------------------------------------------- -inline void -cblas_hemv( - CBLAS_ORDER layout, CBLAS_UPLO uplo, int n, - std::complex alpha, - std::complex const *A, int lda, - std::complex const *x, int incx, - std::complex beta, - std::complex* y, int incy ) -{ - cblas_chemv( layout, uplo, n, - &alpha, A, lda, 
x, incx, - &beta, y, incy ); -} - -inline void -cblas_hemv( - CBLAS_ORDER layout, CBLAS_UPLO uplo, int n, - std::complex alpha, - std::complex const *A, int lda, - std::complex const *x, int incx, - std::complex beta, - std::complex* y, int incy ) -{ - cblas_zhemv( layout, uplo, n, - &alpha, A, lda, x, incx, - &beta, y, incy ); -} - -// ----------------------------------------------------------------------------- -inline void -cblas_hbmv( - CBLAS_ORDER layout, CBLAS_UPLO uplo, int n, int k, - std::complex alpha, - std::complex const *A, int lda, - std::complex const *x, int incx, - std::complex beta, - std::complex* y, int incy ) -{ - cblas_chbmv( layout, uplo, n, k, - &alpha, A, lda, x, incx, - &beta, y, incy ); -} - -inline void -cblas_hbmv( - CBLAS_ORDER layout, CBLAS_UPLO uplo, int n, int k, - std::complex alpha, - std::complex const *A, int lda, - std::complex const *x, int incx, - std::complex beta, - std::complex* y, int incy ) -{ - cblas_zhbmv( layout, uplo, n, k, - &alpha, A, lda, x, incx, - &beta, y, incy ); -} - -// ----------------------------------------------------------------------------- -inline void -cblas_hpmv( - CBLAS_ORDER layout, CBLAS_UPLO uplo, int n, - std::complex alpha, - std::complex const *Ap, - std::complex const *x, int incx, - std::complex beta, - std::complex* y, int incy ) -{ - cblas_chpmv( layout, uplo, n, - &alpha, Ap, x, incx, - &beta, y, incy ); -} - -inline void -cblas_hpmv( - CBLAS_ORDER layout, CBLAS_UPLO uplo, int n, - std::complex alpha, - std::complex const *Ap, - std::complex const *x, int incx, - std::complex beta, - std::complex* y, int incy ) -{ - cblas_zhpmv( layout, uplo, n, - &alpha, Ap, x, incx, - &beta, y, incy ); -} - -// ----------------------------------------------------------------------------- -inline void -cblas_symv( - CBLAS_ORDER layout, CBLAS_UPLO uplo, int n, - float alpha, - float const *A, int lda, - float const *x, int incx, - float beta, - float* y, int incy ) -{ - cblas_ssymv( layout, uplo, n, - alpha, A, lda, x, incx, beta, y, incy ); -} - -inline void -cblas_symv( - CBLAS_ORDER layout, CBLAS_UPLO uplo, int n, - double alpha, - double const *A, int lda, - double const *x, int incx, - double beta, - double* y, int incy ) -{ - cblas_dsymv( layout, uplo, n, - alpha, A, lda, x, incx, beta, y, incy ); -} - -// ----------------------------------------------------------------------------- -inline void -cblas_sbmv( - CBLAS_ORDER layout, CBLAS_UPLO uplo, int n, int k, - float alpha, - float const *A, int lda, - float const *x, int incx, - float beta, - float* y, int incy ) -{ - cblas_ssbmv( layout, uplo, n, k, - alpha, A, lda, x, incx, beta, y, incy ); -} - -inline void -cblas_sbmv( - CBLAS_ORDER layout, CBLAS_UPLO uplo, int n, int k, - double alpha, - double const *A, int lda, - double const *x, int incx, - double beta, - double* y, int incy ) -{ - cblas_dsbmv( layout, uplo, n, k, - alpha, A, lda, x, incx, beta, y, incy ); -} - -// ----------------------------------------------------------------------------- -inline void -cblas_spmv( - CBLAS_ORDER layout, CBLAS_UPLO uplo, int n, - float alpha, - float const *Ap, - float const *x, int incx, - float beta, - float* y, int incy ) -{ - cblas_sspmv( layout, uplo, n, - alpha, Ap, x, incx, beta, y, incy ); -} - -inline void -cblas_spmv( - CBLAS_ORDER layout, CBLAS_UPLO uplo, int n, - double alpha, - double const *Ap, - double const *x, int incx, - double beta, - double* y, int incy ) -{ - cblas_dspmv( layout, uplo, n, - alpha, Ap, x, incx, beta, y, incy ); -} - -// 
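A minimal sketch of the level-2 overloads above, using the double gemv overload (which forwards to cblas_dgemv); the include name and operand values are assumptions.

#include "cblas.hh"   // assumed include name for this header

int main()
{
    // y := alpha*A*x + beta*y with A 2x3 in column-major order
    double A[6] = { 1.0, 4.0,   2.0, 5.0,   3.0, 6.0 };   // lda = 2
    double x[3] = { 1.0, 1.0, 1.0 };
    double y[2] = { 0.0, 0.0 };

    blis::cblas_gemv( CblasColMajor, CblasNoTrans, 2, 3,
                      1.0, A, 2, x, 1, 0.0, y, 1 );
    return 0;
}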
----------------------------------------------------------------------------- -inline void -cblas_trmv( - CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, CBLAS_DIAG diag, int n, - float const *A, int lda, - float* x, int incx ) -{ - cblas_strmv( layout, uplo, trans, diag, n, - A, lda, x, incx ); -} - -inline void -cblas_trmv( - CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, CBLAS_DIAG diag, int n, - double const *A, int lda, - double* x, int incx ) -{ - cblas_dtrmv( layout, uplo, trans, diag, n, - A, lda, x, incx ); -} - -inline void -cblas_trmv( - CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, CBLAS_DIAG diag, int n, - std::complex const *A, int lda, - std::complex* x, int incx ) -{ - cblas_ctrmv( layout, uplo, trans, diag, n, - A, lda, x, incx ); -} - -inline void -cblas_trmv( - CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, CBLAS_DIAG diag, int n, - std::complex const *A, int lda, - std::complex* x, int incx ) -{ - cblas_ztrmv( layout, uplo, trans, diag, n, - A, lda, x, incx ); -} - -// ----------------------------------------------------------------------------- -inline void -cblas_tbmv( - CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, CBLAS_DIAG diag, - int n, int k, - float const *A, int lda, - float* x, int incx ) -{ - cblas_stbmv( layout, uplo, trans, diag, n, k, - A, lda, x, incx ); -} - -inline void -cblas_tbmv( - CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, CBLAS_DIAG diag, - int n, int k, - double const *A, int lda, - double* x, int incx ) -{ - cblas_dtbmv( layout, uplo, trans, diag, n, k, - A, lda, x, incx ); -} - -inline void -cblas_tbmv( - CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, CBLAS_DIAG diag, - int n, int k, - std::complex const *A, int lda, - std::complex* x, int incx ) -{ - cblas_ctbmv( layout, uplo, trans, diag, n, k, - A, lda, x, incx ); -} - -inline void -cblas_tbmv( - CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, CBLAS_DIAG diag, - int n, int k, - std::complex const *A, int lda, - std::complex* x, int incx ) -{ - cblas_ztbmv( layout, uplo, trans, diag, n, k, - A, lda, x, incx ); -} - -// ----------------------------------------------------------------------------- -inline void -cblas_tpmv( - CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, CBLAS_DIAG diag, int n, - float const *Ap, - float* x, int incx ) -{ - cblas_stpmv( layout, uplo, trans, diag, n, - Ap, x, incx ); -} - -inline void -cblas_tpmv( - CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, CBLAS_DIAG diag, int n, - double const *Ap, - double* x, int incx ) -{ - cblas_dtpmv( layout, uplo, trans, diag, n, - Ap, x, incx ); -} - -inline void -cblas_tpmv( - CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, CBLAS_DIAG diag, int n, - std::complex const *Ap, - std::complex* x, int incx ) -{ - cblas_ctpmv( layout, uplo, trans, diag, n, - Ap, x, incx ); -} - -inline void -cblas_tpmv( - CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, CBLAS_DIAG diag, int n, - std::complex const *Ap, - std::complex* x, int incx ) -{ - cblas_ztpmv( layout, uplo, trans, diag, n, - Ap, x, incx ); -} - -// ----------------------------------------------------------------------------- -inline void -cblas_trsv( - CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, CBLAS_DIAG diag, int n, - float const *A, int lda, - float* x, int incx ) -{ - cblas_strsv( layout, uplo, trans, diag, n, - A, lda, x, incx ); -} - -inline void -cblas_trsv( - CBLAS_ORDER layout, CBLAS_UPLO uplo, 
CBLAS_TRANSPOSE trans, CBLAS_DIAG diag, int n, - double const *A, int lda, - double* x, int incx ) -{ - cblas_dtrsv( layout, uplo, trans, diag, n, - A, lda, x, incx ); -} - -inline void -cblas_trsv( - CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, CBLAS_DIAG diag, int n, - std::complex const *A, int lda, - std::complex* x, int incx ) -{ - cblas_ctrsv( layout, uplo, trans, diag, n, - A, lda, x, incx ); -} - -inline void -cblas_trsv( - CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, CBLAS_DIAG diag, int n, - std::complex const *A, int lda, - std::complex* x, int incx ) -{ - cblas_ztrsv( layout, uplo, trans, diag, n, - A, lda, x, incx ); -} - -// ----------------------------------------------------------------------------- -inline void -cblas_tbsv( - CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, CBLAS_DIAG diag, - int n, int k, - float const *A, int lda, - float* x, int incx ) -{ - cblas_stbsv( layout, uplo, trans, diag, n, k, - A, lda, x, incx ); -} - -inline void -cblas_tbsv( - CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, CBLAS_DIAG diag, - int n, int k, - double const *A, int lda, - double* x, int incx ) -{ - cblas_dtbsv( layout, uplo, trans, diag, n, k, - A, lda, x, incx ); -} - -inline void -cblas_tbsv( - CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, CBLAS_DIAG diag, - int n, int k, - std::complex const *A, int lda, - std::complex* x, int incx ) -{ - cblas_ctbsv( layout, uplo, trans, diag, n, k, - A, lda, x, incx ); -} - -inline void -cblas_tbsv( - CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, CBLAS_DIAG diag, - int n, int k, - std::complex const *A, int lda, - std::complex* x, int incx ) -{ - cblas_ztbsv( layout, uplo, trans, diag, n, k, - A, lda, x, incx ); -} - -// ----------------------------------------------------------------------------- -inline void -cblas_tpsv( - CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, CBLAS_DIAG diag, int n, - float const *Ap, - float* x, int incx ) -{ - cblas_stpsv( layout, uplo, trans, diag, n, - Ap, x, incx ); -} - -inline void -cblas_tpsv( - CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, CBLAS_DIAG diag, int n, - double const *Ap, - double* x, int incx ) -{ - cblas_dtpsv( layout, uplo, trans, diag, n, - Ap, x, incx ); -} - -inline void -cblas_tpsv( - CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, CBLAS_DIAG diag, int n, - std::complex const *Ap, - std::complex* x, int incx ) -{ - cblas_ctpsv( layout, uplo, trans, diag, n, - Ap, x, incx ); -} - -inline void -cblas_tpsv( - CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, CBLAS_DIAG diag, int n, - std::complex const *Ap, - std::complex* x, int incx ) -{ - cblas_ztpsv( layout, uplo, trans, diag, n, - Ap, x, incx ); -} - -// ----------------------------------------------------------------------------- -inline void -cblas_ger( - CBLAS_ORDER layout, int m, int n, - float alpha, - float const *x, int incx, - float const *y, int incy, - float* A, int lda ) -{ - cblas_sger( layout, m, n, alpha, x, incx, y, incy, A, lda ); -} - -inline void -cblas_ger( - CBLAS_ORDER layout, int m, int n, - double alpha, - double const *x, int incx, - double const *y, int incy, - double* A, int lda ) -{ - cblas_dger( layout, m, n, alpha, x, incx, y, incy, A, lda ); -} - -// ----------------------------------------------------------------------------- -inline void -cblas_geru( - CBLAS_ORDER layout, int m, int n, - std::complex alpha, - std::complex const *x, int incx, - std::complex const *y, int 
incy, - std::complex* A, int lda ) -{ - cblas_cgeru( layout, m, n, &alpha, - x, incx, y, incy, A, lda ); -} - -inline void -cblas_geru( - CBLAS_ORDER layout, int m, int n, - std::complex alpha, - std::complex const *x, int incx, - std::complex const *y, int incy, - std::complex* A, int lda ) -{ - cblas_zgeru( layout, m, n, &alpha, - x, incx, y, incy, A, lda ); -} - -// ----------------------------------------------------------------------------- -inline void -cblas_gerc( - CBLAS_ORDER layout, int m, int n, - std::complex alpha, - std::complex const *x, int incx, - std::complex const *y, int incy, - std::complex* A, int lda ) -{ - cblas_cgerc( layout, m, n, &alpha, - x, incx, y, incy, A, lda ); -} - -inline void -cblas_gerc( - CBLAS_ORDER layout, int m, int n, - std::complex alpha, - std::complex const *x, int incx, - std::complex const *y, int incy, - std::complex* A, int lda ) -{ - cblas_zgerc( layout, m, n, &alpha, - x, incx, y, incy, A, lda ); -} - -// ----------------------------------------------------------------------------- -inline void -cblas_her( - CBLAS_ORDER layout, CBLAS_UPLO uplo, int n, - float alpha, - std::complex const *x, int incx, - std::complex* A, int lda ) -{ - cblas_cher( layout, uplo, n, alpha, x, incx, A, lda ); -} - -inline void -cblas_her( - CBLAS_ORDER layout, CBLAS_UPLO uplo, int n, - double alpha, - std::complex const *x, int incx, - std::complex* A, int lda ) -{ - cblas_zher( layout, uplo, n, alpha, x, incx, A, lda ); -} - -// ----------------------------------------------------------------------------- -inline void -cblas_hpr( - CBLAS_ORDER layout, CBLAS_UPLO uplo, int n, - float alpha, - std::complex const *x, int incx, - std::complex* Ap ) -{ - cblas_chpr( layout, uplo, n, alpha, x, incx, Ap ); -} - -inline void -cblas_hpr( - CBLAS_ORDER layout, CBLAS_UPLO uplo, int n, - double alpha, - std::complex const *x, int incx, - std::complex* Ap ) -{ - cblas_zhpr( layout, uplo, n, alpha, x, incx, Ap ); -} - -// ----------------------------------------------------------------------------- -inline void -cblas_her2( - CBLAS_ORDER layout, CBLAS_UPLO uplo, int n, - std::complex alpha, - std::complex const *x, int incx, - std::complex const *y, int incy, - std::complex* A, int lda ) -{ - cblas_cher2( layout, uplo, n, &alpha, x, incx, y, incy, A, lda ); -} - -inline void -cblas_her2( - CBLAS_ORDER layout, CBLAS_UPLO uplo, int n, - std::complex alpha, - std::complex const *x, int incx, - std::complex const *y, int incy, - std::complex* A, int lda ) -{ - cblas_zher2( layout, uplo, n, &alpha, x, incx, y, incy, A, lda ); -} - -// ----------------------------------------------------------------------------- -inline void -cblas_hpr2( - CBLAS_ORDER layout, CBLAS_UPLO uplo, int n, - std::complex alpha, - std::complex const *x, int incx, - std::complex const *y, int incy, - std::complex* Ap ) -{ - cblas_chpr2( layout, uplo, n, &alpha, x, incx, y, incy, Ap ); -} - -inline void -cblas_hpr2( - CBLAS_ORDER layout, CBLAS_UPLO uplo, int n, - std::complex alpha, - std::complex const *x, int incx, - std::complex const *y, int incy, - std::complex* Ap ) -{ - cblas_zhpr2( layout, uplo, n, &alpha, x, incx, y, incy, Ap ); -} -// ----------------------------------------------------------------------------- -inline void -cblas_syr( - CBLAS_ORDER layout, CBLAS_UPLO uplo, int n, - float alpha, - float const *x, int incx, - float* A, int lda ) -{ - cblas_ssyr( layout, uplo, n, alpha, x, incx, A, lda ); -} - -inline void -cblas_syr( - CBLAS_ORDER layout, CBLAS_UPLO uplo, int n, - double alpha, - 
double const *x, int incx, - double* A, int lda ) -{ - cblas_dsyr( layout, uplo, n, alpha, x, incx, A, lda ); -} - -// ----------------------------------------------------------------------------- -inline void -cblas_spr( - CBLAS_ORDER layout, CBLAS_UPLO uplo, int n, - float alpha, - float const *x, int incx, - float* Ap ) -{ - cblas_sspr( layout, uplo, n, alpha, x, incx, Ap ); -} - -inline void -cblas_spr( - CBLAS_ORDER layout, CBLAS_UPLO uplo, int n, - double alpha, - double const *x, int incx, - double* Ap ) -{ - cblas_dspr( layout, uplo, n, alpha, x, incx, Ap ); -} - -// ----------------------------------------------------------------------------- -inline void -cblas_syr2( - CBLAS_ORDER layout, CBLAS_UPLO uplo, int n, - float alpha, - float const *x, int incx, - float const *y, int incy, - float* A, int lda ) -{ - cblas_ssyr2( layout, uplo, n, alpha, x, incx, y, incy, A, lda ); -} - -inline void -cblas_syr2( - CBLAS_ORDER layout, CBLAS_UPLO uplo, int n, - double alpha, - double const *x, int incx, - double const *y, int incy, - double* A, int lda ) -{ - cblas_dsyr2( layout, uplo, n, alpha, x, incx, y, incy, A, lda ); -} - -// ----------------------------------------------------------------------------- -inline void -cblas_spr2( - CBLAS_ORDER layout, CBLAS_UPLO uplo, int n, - float alpha, - float const *x, int incx, - float const *y, int incy, - float* Ap ) -{ - cblas_sspr2( layout, uplo, n, alpha, x, incx, y, incy, Ap ); -} - -inline void -cblas_spr2( - CBLAS_ORDER layout, CBLAS_UPLO uplo, int n, - double alpha, - double const *x, int incx, - double const *y, int incy, - double* Ap ) -{ - cblas_dspr2( layout, uplo, n, alpha, x, incx, y, incy, Ap ); -} - -// ============================================================================= -// Level 3 BLAS - -// ----------------------------------------------------------------------------- -inline void -cblas_gemm( - CBLAS_ORDER layout, CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, - int m, int n, int k, - float alpha, - float const *A, int lda, - float const *B, int ldb, - float beta, - float* C, int ldc ) -{ - cblas_sgemm( layout, transA, transB, m, n, k, - alpha, A, lda, B, ldb, - beta, C, ldc ); -} - -inline void -cblas_gemm( - CBLAS_ORDER layout, CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, - int m, int n, int k, - double alpha, - double const *A, int lda, - double const *B, int ldb, - double beta, - double* C, int ldc ) -{ - cblas_dgemm( layout, transA, transB, m, n, k, - alpha, A, lda, B, ldb, - beta, C, ldc ); -} - -inline void -cblas_gemm( - CBLAS_ORDER layout, CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, - int m, int n, int k, - std::complex alpha, - std::complex const *A, int lda, - std::complex const *B, int ldb, - std::complex beta, - std::complex* C, int ldc ) -{ - cblas_cgemm( layout, transA, transB, m, n, k, - &alpha, A, lda, B, ldb, - &beta, C, ldc ); -} - -inline void -cblas_gemm( - CBLAS_ORDER layout, CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, - int m, int n, int k, - std::complex alpha, - std::complex const *A, int lda, - std::complex const *B, int ldb, - std::complex beta, - std::complex* C, int ldc ) -{ - cblas_zgemm( layout, transA, transB, m, n, k, - &alpha, A, lda, B, ldb, - &beta, C, ldc ); -} - -// ----------------------------------------------------------------------------- -inline void -cblas_trmm( - CBLAS_ORDER layout, CBLAS_SIDE side, CBLAS_UPLO uplo, - CBLAS_TRANSPOSE trans, CBLAS_DIAG diag, - int m, int n, - float alpha, - float const *A, int lda, - float *B, int ldb ) -{ - cblas_strmm( layout, 
side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb); -} - -inline void -cblas_trmm( - CBLAS_ORDER layout, CBLAS_SIDE side, CBLAS_UPLO uplo, - CBLAS_TRANSPOSE trans, CBLAS_DIAG diag, - int m, int n, - double alpha, - double const *A, int lda, - double *B, int ldb ) -{ - cblas_dtrmm( layout, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb); -} - -inline void -cblas_trmm( - CBLAS_ORDER layout, CBLAS_SIDE side, CBLAS_UPLO uplo, - CBLAS_TRANSPOSE trans, CBLAS_DIAG diag, - int m, int n, - std::complex alpha, - std::complex const *A, int lda, - std::complex *B, int ldb ) -{ - cblas_ctrmm( layout, side, uplo, trans, diag, m, n, &alpha, A, lda, B, ldb ); -} - -inline void -cblas_trmm( - CBLAS_ORDER layout, CBLAS_SIDE side, CBLAS_UPLO uplo, - CBLAS_TRANSPOSE trans, CBLAS_DIAG diag, - int m, int n, - std::complex alpha, - std::complex const *A, int lda, - std::complex *B, int ldb ) -{ - cblas_ztrmm( layout, side, uplo, trans, diag, m, n, &alpha, A, lda, B, ldb ); -} - - -// ----------------------------------------------------------------------------- -inline void -cblas_trsm( - CBLAS_ORDER layout, CBLAS_SIDE side, CBLAS_UPLO uplo, - CBLAS_TRANSPOSE trans, CBLAS_DIAG diag, - int m, int n, - float alpha, - float const *A, int lda, - float *B, int ldb ) -{ - cblas_strsm( layout, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb); -} - -inline void -cblas_trsm( - CBLAS_ORDER layout, CBLAS_SIDE side, CBLAS_UPLO uplo, - CBLAS_TRANSPOSE trans, CBLAS_DIAG diag, - int m, int n, - double alpha, - double const *A, int lda, - double *B, int ldb ) -{ - cblas_dtrsm( layout, side, uplo, trans, diag, m, n, alpha, A, lda, B, ldb); -} - -inline void -cblas_trsm( - CBLAS_ORDER layout, CBLAS_SIDE side, CBLAS_UPLO uplo, - CBLAS_TRANSPOSE trans, CBLAS_DIAG diag, - int m, int n, - std::complex alpha, - std::complex const *A, int lda, - std::complex *B, int ldb ) -{ - cblas_ctrsm( layout, side, uplo, trans, diag, m, n, &alpha, A, lda, B, ldb ); -} - -inline void -cblas_trsm( - CBLAS_ORDER layout, CBLAS_SIDE side, CBLAS_UPLO uplo, - CBLAS_TRANSPOSE trans, CBLAS_DIAG diag, - int m, int n, - std::complex alpha, - std::complex const *A, int lda, - std::complex *B, int ldb ) -{ - cblas_ztrsm( layout, side, uplo, trans, diag, m, n, &alpha, A, lda, B, ldb ); -} - -// ----------------------------------------------------------------------------- -inline void -cblas_hemm( - CBLAS_ORDER layout, CBLAS_SIDE side, CBLAS_UPLO uplo, - int m, int n, - float alpha, - float const *A, int lda, - float const *B, int ldb, - float beta, - float* C, int ldc ) -{ - cblas_ssymm( layout, side, uplo, m, n, - alpha, A, lda, B, ldb, - beta, C, ldc ); -} - -inline void -cblas_hemm( - CBLAS_ORDER layout, CBLAS_SIDE side, CBLAS_UPLO uplo, - int m, int n, - double alpha, - double const *A, int lda, - double const *B, int ldb, - double beta, - double* C, int ldc ) -{ - cblas_dsymm( layout, side, uplo, m, n, - alpha, A, lda, B, ldb, - beta, C, ldc ); -} - -inline void -cblas_hemm( - CBLAS_ORDER layout, CBLAS_SIDE side, CBLAS_UPLO uplo, - int m, int n, - std::complex alpha, - std::complex const *A, int lda, - std::complex const *B, int ldb, - std::complex beta, - std::complex* C, int ldc ) -{ - cblas_chemm( layout, side, uplo, m, n, - &alpha, A, lda, B, ldb, - &beta, C, ldc ); -} - -inline void -cblas_hemm( - CBLAS_ORDER layout, CBLAS_SIDE side, CBLAS_UPLO uplo, - int m, int n, - std::complex alpha, - std::complex const *A, int lda, - std::complex const *B, int ldb, - std::complex beta, - std::complex* C, int ldc ) -{ - cblas_zhemm( layout, side, 
uplo, m, n, - &alpha, A, lda, B, ldb, - &beta, C, ldc ); -} - -// ----------------------------------------------------------------------------- -inline void -cblas_symm( - CBLAS_ORDER layout, CBLAS_SIDE side, CBLAS_UPLO uplo, - int m, int n, - float alpha, - float const *A, int lda, - float const *B, int ldb, - float beta, - float* C, int ldc ) -{ - cblas_ssymm( layout, side, uplo, m, n, - alpha, A, lda, B, ldb, - beta, C, ldc ); -} - -inline void -cblas_symm( - CBLAS_ORDER layout, CBLAS_SIDE side, CBLAS_UPLO uplo, - int m, int n, - double alpha, - double const *A, int lda, - double const *B, int ldb, - double beta, - double* C, int ldc ) -{ - cblas_dsymm( layout, side, uplo, m, n, - alpha, A, lda, B, ldb, - beta, C, ldc ); -} - -inline void -cblas_symm( - CBLAS_ORDER layout, CBLAS_SIDE side, CBLAS_UPLO uplo, - int m, int n, - std::complex alpha, - std::complex const *A, int lda, - std::complex const *B, int ldb, - std::complex beta, - std::complex* C, int ldc ) -{ - cblas_csymm( layout, side, uplo, m, n, - &alpha, A, lda, B, ldb, - &beta, C, ldc ); -} - -inline void -cblas_symm( - CBLAS_ORDER layout, CBLAS_SIDE side, CBLAS_UPLO uplo, - int m, int n, - std::complex alpha, - std::complex const *A, int lda, - std::complex const *B, int ldb, - std::complex beta, - std::complex* C, int ldc ) -{ - cblas_zsymm( layout, side, uplo, m, n, - &alpha, A, lda, B, ldb, - &beta, C, ldc ); -} - - -// ----------------------------------------------------------------------------- -inline void -cblas_syrk( - CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, int n, int k, - float alpha, - float const *A, int lda, - float beta, - float* C, int ldc ) -{ - cblas_ssyrk( layout, uplo, trans, n, k, alpha, A, lda, beta, C, ldc ); -} - -inline void -cblas_syrk( - CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, int n, int k, - double alpha, - double const *A, int lda, - double beta, - double* C, int ldc ) -{ - cblas_dsyrk( layout, uplo, trans, n, k, alpha, A, lda, beta, C, ldc ); -} - -inline void -cblas_syrk( - CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, int n, int k, - std::complex alpha, - std::complex const *A, int lda, - std::complex beta, - std::complex* C, int ldc ) -{ - cblas_csyrk( layout, uplo, trans, n, k, &alpha, A, lda, &beta, C, ldc ); -} - -inline void -cblas_syrk( - CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, int n, int k, - std::complex alpha, - std::complex const *A, int lda, - std::complex beta, - std::complex* C, int ldc ) -{ - cblas_zsyrk( layout, uplo, trans, n, k, &alpha, A, lda, &beta, C, ldc ); -} - -// ----------------------------------------------------------------------------- -inline void -cblas_herk( - CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, int n, int k, - float alpha, - float const *A, int lda, - float beta, - float* C, int ldc ) -{ - cblas_ssyrk( layout, uplo, trans, n, k, alpha, A, lda, beta, C, ldc ); -} - -inline void -cblas_herk( - CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, int n, int k, - double alpha, - double const *A, int lda, - double beta, - double* C, int ldc ) -{ - cblas_dsyrk( layout, uplo, trans, n, k, alpha, A, lda, beta, C, ldc ); -} - -inline void -cblas_herk( - CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, int n, int k, - float alpha, // note: real - std::complex const *A, int lda, - float beta, // note: real - std::complex* C, int ldc ) -{ - cblas_cherk( layout, uplo, trans, n, k, alpha, A, lda, beta, C, ldc ); -} - -inline void -cblas_herk( - CBLAS_ORDER layout, 
CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, int n, int k, - double alpha, // note: real - std::complex const *A, int lda, - double beta, // note: real - std::complex* C, int ldc ) -{ - cblas_zherk( layout, uplo, trans, n, k, alpha, A, lda, beta, C, ldc ); -} - -// ----------------------------------------------------------------------------- -inline void -cblas_syr2k( - CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, int n, int k, - float alpha, - float const *A, int lda, - float const *B, int ldb, - float beta, - float* C, int ldc ) -{ - cblas_ssyr2k( layout, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc ); -} - -inline void -cblas_syr2k( - CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, int n, int k, - double alpha, - double const *A, int lda, - double const *B, int ldb, - double beta, - double* C, int ldc ) -{ - cblas_dsyr2k( layout, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc ); -} - -inline void -cblas_syr2k( - CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, int n, int k, - std::complex alpha, - std::complex const *A, int lda, - std::complex const *B, int ldb, - std::complex beta, - std::complex* C, int ldc ) -{ - cblas_csyr2k( layout, uplo, trans, n, k, &alpha, A, lda, B, ldb, &beta, C, ldc ); -} - -inline void -cblas_syr2k( - CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, int n, int k, - std::complex alpha, - std::complex const *A, int lda, - std::complex const *B, int ldb, - std::complex beta, - std::complex* C, int ldc ) -{ - cblas_zsyr2k( layout, uplo, trans, n, k, &alpha, A, lda, B, ldb, &beta, C, ldc ); -} - -// ----------------------------------------------------------------------------- -inline void -cblas_her2k( - CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, int n, int k, - float alpha, - float const *A, int lda, - float const *B, int ldb, - float beta, - float* C, int ldc ) -{ - cblas_ssyr2k( layout, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc ); -} - -inline void -cblas_her2k( - CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, int n, int k, - double alpha, - double const *A, int lda, - double const *B, int ldb, - double beta, - double* C, int ldc ) -{ - cblas_dsyr2k( layout, uplo, trans, n, k, alpha, A, lda, B, ldb, beta, C, ldc ); -} - -inline void -cblas_her2k( - CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, int n, int k, - std::complex alpha, - std::complex const *A, int lda, - std::complex const *B, int ldb, - float beta, // note: real - std::complex* C, int ldc ) -{ - cblas_cher2k( layout, uplo, trans, n, k, &alpha, A, lda, B, ldb, beta, C, ldc ); -} - -inline void -cblas_her2k( - CBLAS_ORDER layout, CBLAS_UPLO uplo, CBLAS_TRANSPOSE trans, int n, int k, - std::complex alpha, - std::complex const *A, int lda, - std::complex const *B, int ldb, - double beta, // note: real - std::complex* C, int ldc ) -{ - cblas_zher2k( layout, uplo, trans, n, k, &alpha, A, lda, B, ldb, beta, C, ldc ); -} -}//namespace blis - -#endif // #ifndef CBLAS_HH diff --git a/testcpp/Makefile b/testcpp/Makefile deleted file mode 100644 index ccd44172b8..0000000000 --- a/testcpp/Makefile +++ /dev/null @@ -1,208 +0,0 @@ -# BLIS -# An object-based framework for developing high-performance BLAS-like -# libraries. -# -# Copyright (C) 2014, The University of Texas at Austin -# Copyright (C) 2017 - 2019, Advanced Micro Devices, Inc. 
-# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are -# met: -# - Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# - Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. -# - Neither the name(s) of the copyright holder(s) nor the names of its -# contributors may be used to endorse or promote products derived -# from this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -# - -# -# Makefile -# -# Field G. Van Zee -# -# Makefile for standalone BLIS test drivers. -# - -# -# --- Makefile PHONY target definitions ---------------------------------------- -# - -.PHONY: all \ - blis \ - clean cleanx - - - -# -# --- Determine makefile fragment location ------------------------------------- -# - -# Comments: -# - DIST_PATH is assumed to not exist if BLIS_INSTALL_PATH is given. -# - We must use recursively expanded assignment for LIB_PATH and INC_PATH in -# the second case because CONFIG_NAME is not yet set. -ifneq ($(strip $(BLIS_INSTALL_PATH)),) -LIB_PATH := $(BLIS_INSTALL_PATH)/lib -INC_PATH := $(BLIS_INSTALL_PATH)/include/blis -SHARE_PATH := $(BLIS_INSTALL_PATH)/share/blis -else -DIST_PATH := .. -LIB_PATH = ../lib/$(CONFIG_NAME) -INC_PATH = ../include/$(CONFIG_NAME) -SHARE_PATH := .. -endif - - - -# -# --- Include common makefile definitions -------------------------------------- -# - -# Include the common makefile fragment. --include $(SHARE_PATH)/common.mk - - - -# -# --- BLAS and LAPACK implementations ------------------------------------------ -# - -# BLIS library and header path. This is simply wherever it was installed. -#BLIS_LIB_PATH := $(INSTALL_PREFIX)/lib -#BLIS_INC_PATH := $(INSTALL_PREFIX)/include/blis - -# BLIS library. -#BLIS_LIB := $(BLIS_LIB_PATH)/libblis.a - -# BLAS library path(s). This is where the BLAS libraries reside. -BLAS_LIB_PATH := $(HOME)/flame/lib - - -# -# --- General build definitions ------------------------------------------------ -# - -TEST_SRC_PATH := . -CPP_SRC_PATH := ../cpp/ -TEST_OBJ_PATH := . - -# Gather all local object files. -TEST_OBJS := $(patsubst $(TEST_SRC_PATH)/%.c, \ - $(TEST_OBJ_PATH)/%.o, \ - $(wildcard $(TEST_SRC_PATH)/*.c)) - -# Override the value of CINCFLAGS so that the value of CFLAGS returned by -# get-user-cflags-for() is not cluttered up with include paths needed only -# while building BLIS. -CINCFLAGS := -I$(INC_PATH) - -CXX = g++ - -# Use the CFLAGS for the configuration family. 
-override CFLAGS += $(call get-sandbox-cxxflags-for,$(CONFIG_NAME)) - -# Add local header paths to CFLAGS -#CFLAGS = -O0 -g -Wall -#CFLAGS += -I$(INC_PATH) -override CFLAGS += -I$(TEST_SRC_PATH) -override CFLAGS += -I$(CPP_SRC_PATH) - -LINKER = $(CXX) - -# Locate the libblis library to which we will link. -LIBBLIS_LINK := $(LIB_PATH)/$(LIBBLIS_L) - - - -# -# --- Targets/rules ------------------------------------------------------------ -# - -# Complete list of possible targets when defining 'all': -# -# blis -# -all: blis - - -blis: test_asum_blis.x \ - test_axpy_blis.x \ - test_copy_blis.x \ - test_dot_blis.x \ - test_dotc_blis.x \ - test_gbmv_blis.x \ - test_gemm_blis.x \ - test_gemv_blis.x \ - test_ger_blis.x \ - test_gerc_blis.x \ - test_geru_blis.x \ - test_hemm_blis.x \ - test_hemv_blis.x \ - test_her2_blis.x \ - test_her_blis.x \ - test_herk_blis.x \ - test_hpr2_blis.x \ - test_hpr_blis.x \ - test_nrm2_blis.x \ - test_rot_blis.x \ - test_rotg_blis.x \ - test_rotm_blis.x \ - test_rotmg_blis.x \ - test_scal_blis.x \ - test_sdsdot_blis.x \ - test_spr2_blis.x \ - test_spr_blis.x \ - test_swap_blis.x \ - test_symm_blis.x \ - test_syr2_blis.x \ - test_syr2k_blis.x \ - test_syr_blis.x \ - test_syrk_blis.x \ - test_tbmv_blis.x \ - test_tbsv_blis.x \ - test_tpmv_blis.x \ - test_tpsv_blis.x \ - test_trmm_blis.x \ - test_trsm_blis.x \ - test_trsv_blis.x - - - -# --Object file rules -- - -$(TEST_OBJ_PATH)/%.o: $(TEST_SRC_PATH)/%.cc - $(CXX) $(CFLAGS) -c $< -o $@ - -test_%_blis.o: test_%.cc - @$(CXX) $(CFLAGS) -c $< -o $@ - - -# -- Executable file rules -- - -test_%_blis.x: test_%_blis.o $(LIBBLIS_LINK) - @$(LINKER) $^ $(LIBBLIS_LINK) $(LDFLAGS) -o $@ - ./$@ - -# -- Clean rules -- - -clean: cleanx - -cleanx: - - $(RM_F) *.o *.x - diff --git a/testcpp/test.hh b/testcpp/test.hh deleted file mode 100644 index b1be412d64..0000000000 --- a/testcpp/test.hh +++ /dev/null @@ -1,219 +0,0 @@ -/* - * -------------------------------------------------------------------------- - * BLISLAB - * -------------------------------------------------------------------------- - * Copyright (C) 2016, The University of Texas at Austin - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are - * met: - * - Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * - Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * - Neither the name of The University of Texas nor the names of its - * contributors may be used to endorse or promote products derived - * from this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - * A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT - * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - * - * test.hh - * - * - * Purpose: - * this header file contains all function prototypes. - * - * Todo: - * - * - * Modification: - * - * - * */ - - -#ifndef TEST_HH -#define TEST_HH - -#include <complex> - -#include <iostream> -#include <stdio.h> - -using namespace std; -#define min( i, j ) ( (i)<(j) ? (i): (j) ) - -#define A( i, j ) A[ (j)*lda + (i) ] -#define A_ref( i, j ) A_ref[ (j)*lda_ref + (i) ] - -#define B( i, j ) B[ (j)*ldb + (i) ] -#define B_ref( i, j ) B_ref[ (j)*ldb_ref + (i) ] - -#define C( i, j ) C[ (j)*ldc + (i) ] -#define C_ref( i, j ) C_ref[ (j)*ldc_ref + (i) ] - -#define X( i ) X[ incx + (i) ] -#define X_ref( i, j ) X_ref[ (j)*incx_ref + (i) ] - -#define Y( i ) Y[ incy + (i) ] -#define Y_ref( i ) Y_ref[ incy_ref + (i) ]\ - -// Allocate memory and initialise memory with random values -void allocate_init_buffer(int *aIn, int m, int n) -{ - aIn = new int [m*n]; - for ( int i = 0; i < m*n; i ++ ) { - aIn[ i ] = ((int) rand() / ((int) RAND_MAX / 2.0)) - 1.0; - } -} - -void allocate_init_buffer(float *&aIn, int m, int n) -{ - aIn = new float [m*n]; - for ( int i = 0; i < m*n; i ++ ) { - aIn[ i ] = ((float) rand() / ((float) RAND_MAX / 2.0)) - 1.0; - } -} - -void allocate_init_buffer(double *&aIn, int m, int n) -{ - aIn = new double [m*n]; - for ( int i = 0; i < m*n; i ++ ) { - aIn[ i ] = ((double) rand() / ((double) RAND_MAX / 2.0)) - 1.0; - } -} -void allocate_init_buffer(complex<float> *&aIn, int m, int n) -{ - aIn = new complex<float> [m*n]; - for ( int i = 0; i < m*n; i ++ ) { - float real = ((float) rand() / ((float) RAND_MAX / 2.0)) - 1.0; - float imag = ((float) rand() / ((float) RAND_MAX / 2.0)) - 1.0; - aIn[i] = {real,imag}; - } -} -void allocate_init_buffer(complex<double> *&aIn, int m, int n) -{ - aIn = new complex<double> [m*n]; - for ( int i = 0; i < m*n; i ++ ) { - double real = ((double) rand() / ((double) RAND_MAX / 2.0)) - 1.0; - double imag = ((double) rand() / ((double) RAND_MAX / 2.0)) - 1.0; - aIn[i] = {real,imag}; - } -} - -template< typename T > -void copy_buffer(T *aSrc, T *&aDest, int m, int n) -{ - aDest = new T [m*n]; - for ( int i = 0; i < m*n; i ++ ) { - aDest[i] = aSrc[i]; - } -} - -template< typename T > -int computeErrorM( - int lda, - int lda_ref, - int m, - int n, - T *A, - T *A_ref - ) -{ - - int i, j; - int ret = 0; - for ( i = 0; i < m; i ++ ) { - for ( j = 0; j < n; j ++ ) { - if ( (fabs (A( i, j )) - fabs( A_ref( i, j ))) > 0.0000001 ) { - cout << A(i,j) << A_ref(i,j); - ret = 1; - break; - } - } - } - return ret; - -} - - - - template< typename T > - int computeErrorV( - int incy, - int incy_ref, - int n, - T *Y, - T *Y_ref - ) - { - int i; - int ret = 0; - for ( i = 0; i < n; i ++ ) { - if ( (fabs( Y_ref[ i ]) - fabs(Y[ i ] ) ) > 0.00001) { - cout << Y[i] << Y_ref[i]; - ret = 1; - break; - } - } - - return ret; - - } - -/* - *printing matrix and vector - * - */ - -template<typename T> -void printmatrix( - T *A, - int lda, - int m, - int n, - char *func_str - ) -{ - int i, j; - cout << func_str <<"\n"; - for ( i = 0; i < m; i ++ ) { - for ( j = 0; j < n; j ++ ) { - cout<< A[j * lda + i]<<" "; - } - printf("\n"); - } - printf("\n"); -} - -template<typename T> -void printvector( - T *X, - int m, - char *func_str - ) - { - int i; - cout << func_str <<"\n"; - for ( i = 0; i < m; i ++ ) { - cout<< X[i]<<" "; - cout<<"\n"; - } - printf("\n"); - - } - - -#endif diff --git a/testcpp/test.sh b/testcpp/test.sh deleted file mode 100644 index 6d06c867bb..0000000000 --- a/testcpp/test.sh +++ /dev/null @@ -1,46 +0,0 @@ - -echo Build BLIS CPP Template tests -make clean -make - -echo Run tests -./test_asum_blis.x -./test_axpy_blis.x -./test_copy_blis.x -./test_dot_blis.x -./test_dotc_blis.x -./test_gbmv_blis.x -./test_gemm_blis.x -./test_gemv_blis.x -./test_ger_blis.x -./test_gerc_blis.x -./test_geru_blis.x -./test_hemm_blis.x -./test_hemv_blis.x -./test_her2_blis.x -./test_her_blis.x -./test_herk_blis.x -./test_hpr2_blis.x -./test_hpr_blis.x -./test_nrm2_blis.x -./test_rot_blis.x -./test_rotg_blis.x -./test_rotm_blis.x -./test_rotmg_blis.x -./test_scal_blis.x -./test_sdsdot_blis.x -./test_spr2_blis.x -./test_spr_blis.x -./test_swap_blis.x -./test_symm_blis.x -./test_syr2_blis.x -./test_syr2k_blis.x -./test_syr_blis.x -./test_syrk_blis.x -./test_tbmv_blis.x -./test_tbsv_blis.x -./test_tpmv_blis.x -./test_tpsv_blis.x -./test_trmm_blis.x -./test_trsm_blis.x -./test_trsv_blis.x diff --git a/testcpp/test_asum.cc b/testcpp/test_asum.cc deleted file mode 100644 index 948f4250fd..0000000000 --- a/testcpp/test_asum.cc +++ /dev/null @@ -1,127 +0,0 @@ -/* - - BLISPP - C++ test driver for BLIS CPP asum routine and reference blis asum routine. - - Copyright (C) 2019, Advanced Micro Devices, Inc. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
- -*/ - -#include <complex> -#include <iostream> -#include "blis.hh" -#include "test.hh" - -using namespace blis; -using namespace std; -//#define PRINT -#define N 6 -#define ALPHA 0.5 - -template< typename T, typename TR> -void ref_asum(int64_t n, - T *X, - TR *asum - ) -{ - obj_t obj_x; - obj_t obj_asum; - num_t dt, dtR; - - if(is_same<T, float>::value) - dt = BLIS_FLOAT; - else if(is_same<T, double>::value) - dt = BLIS_DOUBLE; - else if(is_same<T, complex<float>>::value) - dt = BLIS_SCOMPLEX; - else if(is_same<T, complex<double>>::value) - dt = BLIS_DCOMPLEX; - - if(is_same<TR, float>::value) - dtR = BLIS_FLOAT; - else if(is_same<TR, double>::value) - dtR = BLIS_DOUBLE; - - bli_obj_create_with_attached_buffer( dt, n, 1, X, 1, n,&obj_x ); - bli_obj_create_with_attached_buffer( dtR, 1, 1, asum, 1, 1,&obj_asum ); - - bli_asumv(&obj_x, &obj_asum); - -} -template< typename T, typename TR> -void test_asum() -{ - - T *X, *X_ref; - TR asum, asum_ref; - int n; - int incx; - - n = N; - incx = 1; - srand (time(NULL)); - allocate_init_buffer(X , n , 1); - copy_buffer(X, X_ref , n ,1); - -#ifdef PRINT - printvector(X, n,(char *) "X"); -#endif - - asum = blis::asum( - n, - X, - incx - ); - -#ifdef PRINT - cout<< "Sum of all values in Vector X: " << asum << "\n"; -#endif - - ref_asum(n, X_ref, &asum_ref ); - -#ifdef PRINT - cout<< "Ref Sum of all values in Vector X: " << asum_ref << "\n"; -#endif - if(computeErrorV(incx, incx, 1, &asum, &asum_ref )==1) - printf("%s TEST FAIL\n" , __PRETTY_FUNCTION__); - else - printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); - - delete[]( X ); - delete[]( X_ref ); -} - -// ----------------------------------------------------------------------------- -int main( int argc, char** argv ) -{ - test_asum<float, float>( ); - test_asum<double, double>( ); - test_asum<complex<float>, float>( ); - test_asum<complex<double>, double>( ); - return 0; - -} diff --git a/testcpp/test_axpy.cc b/testcpp/test_axpy.cc deleted file mode 100644 index 45035198c3..0000000000 --- a/testcpp/test_axpy.cc +++ /dev/null @@ -1,138 +0,0 @@ -/* - - BLISPP - C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. - - Copyright (C) 2019, Advanced Micro Devices, Inc. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
- -*/ - -#include -#include -#include "blis.hh" -#include "test.hh" - -using namespace blis; -using namespace std; -//#define PRINT -#define N 6 -#define ALPHA 1.0 - -/* - * Test application assumes matrices to be column major, non-transposed - */ -template< typename T> -void ref_axpy(int64_t n, - T * alpha, - T *X, - T *Y - ) - -{ - obj_t obj_x, obj_y, obj_alpha; - num_t dt; - - if(is_same::value) - dt = BLIS_FLOAT; - else if(is_same::value) - dt = BLIS_DOUBLE; - else if(is_same>::value) - dt = BLIS_SCOMPLEX; - else if(is_same>::value) - dt = BLIS_DCOMPLEX; - - - bli_obj_create_with_attached_buffer( dt, 1, 1, alpha, 1,1,&obj_alpha ); - bli_obj_create_with_attached_buffer( dt, n, 1, X, 1, n,&obj_x ); - bli_obj_create_with_attached_buffer( dt, n, 1, Y, 1, n,&obj_y ); - - bli_axpyv( &obj_alpha, - &obj_x, - &obj_y - ); - -} -template< typename T > -void test_axpy( ) -{ - T *X, *Y,*Y_ref; - T alpha = ALPHA; - int n; - int incx, incy; - - n = N; - - incx = 1; - incy = 1; - - srand (time(NULL)); - allocate_init_buffer(X , n , 1); - allocate_init_buffer(Y , n , 1); - copy_buffer(Y, Y_ref , n ,1); - -#ifdef PRINT - printvector(X, n,(char *) "X"); - printvector(Y, n, (char *) "Y"); -#endif - blis::axpy( - n, - alpha, - X, - incx, - Y, - incy - ); - -#ifdef PRINT - printvector(Y, n,(char *) "Y output"); -#endif - ref_axpy(n , &alpha , X, Y_ref ); - -#ifdef PRINT - printvector(Y_ref, n, (char *) "Y ref output"); -#endif - if(computeErrorV(incy, incy , n, Y, Y_ref )==1) - printf("%s TEST FAIL\n" , __PRETTY_FUNCTION__); - else - printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); - - - delete[]( X ); - delete[]( Y ); - delete[]( Y_ref ); -} - -// ----------------------------------------------------------------------------- -int main( int argc, char** argv ) -{ - test_axpy( ); - test_axpy( ); - test_axpy>( ); - test_axpy>( ); - return 0; - -} diff --git a/testcpp/test_copy.cc b/testcpp/test_copy.cc deleted file mode 100644 index a1042d1c9b..0000000000 --- a/testcpp/test_copy.cc +++ /dev/null @@ -1,132 +0,0 @@ -/* - - BLISPP - C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. - - Copyright (C) 2019, Advanced Micro Devices, Inc. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include -#include "blis.hh" -#include "test.hh" - -using namespace blis; -using namespace std; -//#define PRINT -#define N 6 - -/* - * Test application assumes matrices to be column major, non-transposed - */ -template< typename T> -void ref_copy(int64_t n, - T *X, - T *Y - ) - -{ - obj_t obj_x, obj_y; - num_t dt; - - if(is_same::value) - dt = BLIS_FLOAT; - else if(is_same::value) - dt = BLIS_DOUBLE; - else if(is_same>::value) - dt = BLIS_SCOMPLEX; - else if(is_same>::value) - dt = BLIS_DCOMPLEX; - - - bli_obj_create_with_attached_buffer( dt, n, 1, X, 1, n,&obj_x ); - bli_obj_create_with_attached_buffer( dt, n, 1, Y, 1, n,&obj_y ); - - bli_copyv( &obj_x, - &obj_y - ); - -} -template< typename T > -void test_copy( ) -{ - T *X, *X_ref, *Y,*Y_ref; - int n; - int incx, incy; - - n = N; - - incx = 1; - incy = 1; - - Y = new T[n]; - Y_ref = new T[n]; - srand (time(NULL)); - allocate_init_buffer(X , n , 1); - copy_buffer(X, X_ref , n ,1); - -#ifdef PRINT - printvector(X, n,(char *) "X"); -#endif - blis::copy( - n, - X, - incx, - Y, - incy - ); - -#ifdef PRINT - printvector(Y, n,(char *) "Y output"); -#endif - ref_copy(n , X_ref, Y_ref ); - -#ifdef PRINT - printvector(Y_ref, n,(char *) "Y ref output"); -#endif - if(computeErrorV(incy , incy , n, Y, Y_ref )==1) - printf("%s TEST FAIL\n" , __PRETTY_FUNCTION__); - else - printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); - - delete[]( X ); - delete[]( X_ref ); - delete[]( Y ); - delete[]( Y_ref ); -} - -// ----------------------------------------------------------------------------- -int main( int argc, char** argv ) -{ - test_copy( ); - test_copy( ); - test_copy>(); - test_copy>(); - return 0; - -} diff --git a/testcpp/test_dot.cc b/testcpp/test_dot.cc deleted file mode 100644 index 553287784a..0000000000 --- a/testcpp/test_dot.cc +++ /dev/null @@ -1,131 +0,0 @@ -/* - - BLISPP - C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. - - Copyright (C) 2019, Advanced Micro Devices, Inc. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include -#include "blis.hh" -#include "test.hh" - -using namespace blis; -using namespace std; -//#define PRINT -#define N 6 - -/* - * Test application assumes matrices to be column major, non-transposed - */ -template< typename T, typename TR> -void ref_dot(int64_t n, - T *X, - T *Y, - TR *res_ref - ) - -{ - obj_t obj_x; - obj_t obj_y; - obj_t obj_res; - num_t dt; - - if(is_same::value) - dt = BLIS_FLOAT; - else if(is_same::value) - dt = BLIS_DOUBLE; - - bli_obj_create_with_attached_buffer( dt, n, 1, X, 1, n,&obj_x ); - bli_obj_create_with_attached_buffer( dt, n, 1, Y, 1, n,&obj_y ); - bli_obj_create_with_attached_buffer( dt, 1, 1, res_ref, 1, 1,&obj_res ); - - bli_dotv(&obj_x, - &obj_y, - &obj_res ); - -} -template< typename T, typename TR> -void test_dot() -{ - T *X, *Y; - int n; - int incx, incy; - TR res = 0, res_ref = 0; - - n = N; - - incx = 1; - incy = 1; - - srand (time(NULL)); - allocate_init_buffer(X , n , 1); - allocate_init_buffer(Y , n , 1); - -#ifdef PRINT - printvector(X, n, (char *)"X"); - printvector(Y, n, (char *)"Y"); -#endif - res = blis::dot( - n, - X, - incx, - Y, - incy - ); - -#ifdef PRINT - printf("Dot product = %E \n", res); - -#endif - ref_dot(n, X, Y , &res_ref ); - -#ifdef PRINT - printf("Dot product ref_dot %E \n", res_ref); - -#endif - if(res != res_ref ) - printf("%s TEST FAIL\n" ,__PRETTY_FUNCTION__); - else - printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); - - - delete[]( X ); - delete[]( Y ); -} - -// ----------------------------------------------------------------------------- -int main( int argc, char** argv ) -{ - test_dot( ); - //test_dot( ); - test_dot( ); - return 0; - -} diff --git a/testcpp/test_dotc.cc b/testcpp/test_dotc.cc deleted file mode 100644 index 88ffe19c4d..0000000000 --- a/testcpp/test_dotc.cc +++ /dev/null @@ -1,127 +0,0 @@ -/* - - BLISPP - C++ test driver for BLIS CPP dotc routine and reference blis dotc routine. - - Copyright (C) 2019, Advanced Micro Devices, Inc. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include <complex> -#include <iostream> -#include "blis.hh" -#include "test.hh" - -using namespace blis; -using namespace std; -//#define PRINT -#define N 16 - -template< typename T > -void ref_dotc(int64_t n, - T *X, - T *Y, - T *res_ref - ) -{ - obj_t obj_x; - obj_t obj_y; - obj_t obj_res; - num_t dt; - - if(is_same<T, complex<float>>::value) - dt = BLIS_SCOMPLEX; - else if(is_same<T, complex<double>>::value) - dt = BLIS_DCOMPLEX; - - bli_obj_create_with_attached_buffer( dt, n, 1, X, 1, n, &obj_x ); - bli_obj_create_with_attached_buffer( dt, n, 1, Y, 1, n, &obj_y ); - bli_obj_set_conj(BLIS_CONJUGATE,&obj_x); - bli_obj_create_with_attached_buffer( dt, 1, 1, res_ref, 1, 1,&obj_res ); - - bli_dotv(&obj_x, - &obj_y, - &obj_res ); - -} - -template< typename T > -void test_dotc() -{ - T *X, *Y; - int n; - int incx, incy; - T res = 0, res_ref = 0; - - n = N; - - incx = 1; - incy = 1; - - srand (time(NULL)); - allocate_init_buffer(X , n , 1); - allocate_init_buffer(Y , n , 1); - -#ifdef PRINT - printvector(X, n,(char *) "X"); - printvector(Y, n,(char *) "Y"); -#endif - - res = blis::dotc( - n, - X, - incx, - Y, - incy - ); - -#ifdef PRINT - cout<< "Dot product \n" << res << "\n"; -#endif - ref_dotc(n, X, Y , &res_ref ); - -#ifdef PRINT - cout<< "Dot product ref\n" << res_ref << "\n";; -#endif - - if(res != res_ref ) - printf("%s TEST FAIL\n" ,__PRETTY_FUNCTION__); - else - printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); - - delete[]( X ); - delete[]( Y ); -} - -// ----------------------------------------------------------------------------- -int main( int argc, char** argv ) -{ - test_dotc<complex<float>>( ); - test_dotc<complex<double>>( ); - return 0; - -} diff --git a/testcpp/test_gbmv.cc b/testcpp/test_gbmv.cc deleted file mode 100644 index 6d64f42ee3..0000000000 --- a/testcpp/test_gbmv.cc +++ /dev/null @@ -1,109 +0,0 @@ -/* - - BLISPP - C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. - - Copyright (C) 2019, Advanced Micro Devices, Inc. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include -#include "blis.hh" -#include "test.hh" - -using namespace blis; -using namespace std; -//#define PRINT -#define ALPHA -1.0 -#define BETA -1.0 -#define M 3 -#define N 4 - -template< typename T > -void test_gbmv( ) -{ -// int i, j, p; - T alpha, beta; - int m,n; - int KL = 1; - int KU = 1; - int lda = 4; - T A[] = { 0.423f, -0.143f, -0.182f, -0.076f, -0.855f, 0.599f, 0.389f, -0.473f, 0.493f, -0.902f, -0.889f, -0.256f, 0.112f, 0.128f, -0.277f, -0.777f }; - T X[] = { 0.488f, 0.029f, -0.633f, 0.84f }; - int incX = -1; - T Y[] = { 0.874f, 0.322f, -0.477f }; - int incY = -1; - T Y_ref[] = { -0.656261f, 0.19575f, 0.055905f }; - alpha = ALPHA; - beta = BETA; - m = M; - n = N; - - -#ifdef PRINT - printmatrix(A, lda ,m,n,(char *) "A"); - printvector(Y, m, (char *)"m"); -#endif - blis::gbmv( - CblasColMajor, - CblasNoTrans, - m, - n,KL,KU, - alpha, - A, - lda, - X, - incX, - beta, - Y, - incY - ); - -#ifdef PRINT - printvector(Y, m,(char *)"Y blis:gbmv"); - printvector(Y_ref, m, (char *) "Y_ref blis:gbmv" ); - -#endif - - if(computeErrorV(incY,incY, m, Y, Y_ref )==1) - printf("%s TEST FAIL\n" , __PRETTY_FUNCTION__ ); - else - printf("%s TEST PASS\n" , __PRETTY_FUNCTION__ ); - -} - -// ----------------------------------------------------------------------------- -int main( int argc, char** argv ) -{ - test_gbmv( ); - test_gbmv( ); - test_gbmv>( ); - test_gbmv>( ); - return 0; - -} diff --git a/testcpp/test_gemm.cc b/testcpp/test_gemm.cc deleted file mode 100644 index 2fe6e55a7c..0000000000 --- a/testcpp/test_gemm.cc +++ /dev/null @@ -1,163 +0,0 @@ -/* - - BLISPP - C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. - - Copyright (C) 2019, Advanced Micro Devices, Inc. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include -#include "blis.hh" -#include "test.hh" - -using namespace blis; -using namespace std; -//#define PRINT -#define ALPHA 1.0 -#define BETA 0.0 -#define M 5 -#define N 6 -#define K 4 - -/* - * Test application assumes matrices to be column major, non-transposed - */ -template< typename T > -void ref_gemm(int64_t m, int64_t n, int64_t k, - T * alpha, - T *A, - T *B, - T * beta, - T *C ) - -{ - obj_t obj_a, obj_b, obj_c; - obj_t obj_alpha, obj_beta; - num_t dt; - if(is_same::value) - dt = BLIS_FLOAT; - else if(is_same::value) - dt = BLIS_DOUBLE; - else if(is_same>::value) - dt = BLIS_SCOMPLEX; - else if(is_same>::value) - dt = BLIS_DCOMPLEX; - - bli_obj_create_with_attached_buffer( dt, 1, 1, alpha, 1,1,&obj_alpha ); - bli_obj_create_with_attached_buffer( dt, 1, 1, beta, 1,1,&obj_beta ); - bli_obj_create_with_attached_buffer( dt, m, k, A, 1,m,&obj_a ); - bli_obj_create_with_attached_buffer( dt, k, n, B,1,k,&obj_b ); - bli_obj_create_with_attached_buffer( dt, m, n, C, 1,m,&obj_c ); - - bli_obj_set_conjtrans( BLIS_NO_TRANSPOSE, &obj_a ); - bli_obj_set_conjtrans( BLIS_NO_TRANSPOSE, &obj_b ); - bli_gemm( &obj_alpha, - &obj_a, - &obj_b, - &obj_beta, - &obj_c ); - -} -template< typename T > -void test_gemm( ) -{ - T *A, *B, *C, *C_ref; - T alpha, beta; - int m,n,k; - int lda, ldb, ldc, ldc_ref; - - alpha = ALPHA; - beta = BETA; - m = M; - k = K; - n = N; - - lda = m; - ldb = k; - ldc = m; - ldc_ref = m; - srand (time(NULL)); - allocate_init_buffer(A , m , k); - allocate_init_buffer(B , k , n); - allocate_init_buffer(C , m , n); - copy_buffer(C, C_ref , m ,n); - -#ifdef PRINT - printmatrix(A, lda ,m,k , (char *)"A"); - printmatrix(B, ldb ,k,n, (char *)"B"); - printmatrix(C, ldc ,m,n, (char *)"C"); -#endif - blis::gemm( - CblasColMajor, - CblasNoTrans, - CblasNoTrans, - m, - n, - k, - alpha, - A, - lda, - B, - ldb, - beta, - C, - ldc - ); - -#ifdef PRINT - printmatrix(C,ldc ,m,n , (char *)"C output"); -#endif - ref_gemm(m, n, k, &alpha, A, B, &beta, C_ref); - -#ifdef PRINT - printmatrix(C_ref, ldc_ref ,m,n, (char *)"C ref output"); -#endif - if(computeErrorM(ldc, ldc_ref, m, n, C, C_ref )==1) - printf("%s TEST FAIL\n" , __PRETTY_FUNCTION__ ); - else - printf("%s TEST PASS\n" , __PRETTY_FUNCTION__ ); - - - - delete[]( A ); - delete[]( B ); - delete[]( C ); - delete[]( C_ref ); -} - -// ----------------------------------------------------------------------------- -int main( int argc, char** argv ) -{ - test_gemm( ); - test_gemm( ); - test_gemm>( ); - test_gemm>( ); - return 0; - -} diff --git a/testcpp/test_gemm.hh b/testcpp/test_gemm.hh deleted file mode 100644 index 876ac16658..0000000000 --- a/testcpp/test_gemm.hh +++ /dev/null @@ -1,110 +0,0 @@ -/* - * -------------------------------------------------------------------------- - * BLISLAB - * -------------------------------------------------------------------------- - * Copyright (C) 2016, The University of Texas at Austin - * - * Redistribution and use in source and binary forms, with or without - * 
modification, are permitted provided that the following conditions are - * met: - * - Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * - Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * - Neither the name of The University of Texas nor the names of its - * contributors may be used to endorse or promote products derived - * from this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - * - * test_gemm.hh - * - * - * Purpose: - * this header file contains all function prototypes. - * - * Todo: - * - * - * Modification: - * - * - * */ - - -#ifndef TEST_GEMM_HH -#define TEST_GEMM_HH - -#include - -#include -#include - -using namespace std; -#define min( i, j ) ( (i)<(j) ? (i): (j) ) - -#define A( i, j ) A[ (j)*lda + (i) ] -#define B( i, j ) B[ (j)*ldb + (i) ] -#define C( i, j ) C[ (j)*ldc + (i) ] -#define C_ref( i, j ) C_ref[ (j)*ldc_ref + (i) ] - -template< typename T > -int computeError( - int ldc, - int ldc_ref, - int m, - int n, - T *C, - T *C_ref - ) -{ - int i, j; - int ret = 0; - for ( i = 0; i < m; i ++ ) { - for ( j = 0; j < n; j ++ ) { - if ( C( i, j ) != C_ref( i, j ) ) { - printf( "C[ %d ][ %d ] != C_ref, %E, %E\n", i, j, C( i, j ), C_ref( i, j ) ); - ret = 1; - break; - } - } - } - return ret; - -} - -/* - * - * - */ -template -void bl_dgemm_printmatrix( - T *A, - int lda, - int m, - int n - ) -{ - int i, j; - for ( i = 0; i < m; i ++ ) { - for ( j = 0; j < n; j ++ ) { - cout<< A[j * lda + i]<<" "; - } - printf("\n"); - } - printf("\n"); -} - -#endif diff --git a/testcpp/test_gemv.cc b/testcpp/test_gemv.cc deleted file mode 100644 index ca36a61d29..0000000000 --- a/testcpp/test_gemv.cc +++ /dev/null @@ -1,162 +0,0 @@ -/* - - BLISPP - C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. - - Copyright (C) 2019, Advanced Micro Devices, Inc. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. 
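
The test_gemm.cc driver and the test_gemm.hh helpers deleted above exercise blis::gemm on column-major, non-transposed operands and check the result element-wise against bli_gemm run on objects wrapped around the same buffers. A minimal standalone sketch of that call pattern, assuming only that blis.hh is on the include path (plain stack arrays stand in for the allocate_init_buffer/copy_buffer helpers from test.hh):

    #include <iostream>
    #include "blis.hh"   // C++ wrappers used by the deleted testcpp drivers

    int main()
    {
        // Column-major 2x2 operands for C := alpha*A*B + beta*C.
        double A[4] = { 1.0, 2.0,    // column 0
                        3.0, 4.0 };  // column 1
        double B[4] = { 1.0, 0.0,    // B is the identity,
                        0.0, 1.0 };  // so C should come back equal to A.
        double C[4] = { 0.0, 0.0,
                        0.0, 0.0 };

        blis::gemm( CblasColMajor, CblasNoTrans, CblasNoTrans,
                    2, 2, 2,          // m, n, k
                    1.0, A, 2,        // alpha, A, lda
                    B, 2,             // B, ldb
                    0.0, C, 2 );      // beta, C, ldc

        for ( int i = 0; i < 4; i++ ) std::cout << C[i] << " ";
        std::cout << "\n";            // expected: 1 2 3 4
        return 0;
    }
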
- - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include -#include "blis.hh" -#include "test.hh" - -using namespace blis; -using namespace std; -//#define PRINT -#define ALPHA 1.0 -#define BETA 0.0 -#define M 5 -#define N 6 - -/* - * Test application assumes matrices to be column major, non-transposed - */ -template< typename T > -void ref_gemv(int64_t m, int64_t n, - T * alpha, - T *A, - T *X, - T * beta, - T *Y ) - -{ - obj_t obj_a, obj_x, obj_y; - obj_t obj_alpha, obj_beta; - num_t dt; - - if(is_same::value) - dt = BLIS_FLOAT; - else if(is_same::value) - dt = BLIS_DOUBLE; - else if(is_same>::value) - dt = BLIS_SCOMPLEX; - else if(is_same>::value) - dt = BLIS_DCOMPLEX; - - bli_obj_create_with_attached_buffer( dt, 1, 1, alpha, 1,1,&obj_alpha ); - bli_obj_create_with_attached_buffer( dt, 1, 1, beta, 1,1,&obj_beta ); - bli_obj_create_with_attached_buffer( dt, m, n, A, 1,m,&obj_a ); - bli_obj_create_with_attached_buffer( dt, n, 1, X, 1,n,&obj_x ); - bli_obj_create_with_attached_buffer( dt, m, 1, Y, 1,m,&obj_y ); - - bli_obj_set_conjtrans( BLIS_NO_TRANSPOSE, &obj_a ); - bli_obj_set_conjtrans( BLIS_NO_TRANSPOSE, &obj_x); - bli_obj_set_conjtrans( BLIS_NO_TRANSPOSE, &obj_y); - bli_gemv( &obj_alpha, - &obj_a, - &obj_x, - &obj_beta, - &obj_y ); - -} -template< typename T > -void test_gemv( ) -{ - T *A, *Y, *Y_ref, *X; - T alpha, beta; - int m,n; - int lda, incx, incy, incy_ref; - - alpha = ALPHA; - beta = BETA; - m = M; - n = N; - - lda = m; - incx = 1; - incy = 1; - incy_ref = 1; - - srand (time(NULL)); - allocate_init_buffer(A , m , n); - allocate_init_buffer(X , m , 1); - allocate_init_buffer(Y , m , 1); - copy_buffer(Y, Y_ref , m ,1); - -#ifdef PRINT - printmatrix(A, lda ,m,n,(char *) "A"); - printvector(X, m,(char *) "X"); - printvector(Y, m, (char *)"Y"); -#endif - blis::gemv( - CblasColMajor, - CblasNoTrans, - m, - n, - alpha, - A, - lda, - X, - incx, - beta, - Y, - incy - ); - -#ifdef PRINT - printvector(Y, m, (char *)"Y output"); -#endif - ref_gemv(m, n, &alpha, A, X, &beta, Y_ref); - -#ifdef PRINT - printvector(Y_ref, m, (char *) "Y_Ref output"); -#endif - if(computeErrorV(incy,incy_ref, m , Y, Y_ref )==1) - printf("%s TEST FAIL\n" , __PRETTY_FUNCTION__ ); - else - printf("%s TEST PASS\n" , __PRETTY_FUNCTION__ ); - - - - delete[]( A ); - delete[]( X ); - delete[]( Y ); - delete[]( Y_ref ); -} - -// ----------------------------------------------------------------------------- -int main( int argc, char** argv ) -{ - test_gemv( ); - test_gemv( ); - test_gemv>( ); - test_gemv>( ); - return 0; - -} diff --git a/testcpp/test_ger.cc b/testcpp/test_ger.cc deleted file mode 100644 index 15b018ce60..0000000000 --- a/testcpp/test_ger.cc +++ /dev/null @@ -1,150 +0,0 @@ -/* - - BLISPP - C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. 
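
test_gemv.cc, removed above, drives the level-2 wrapper blis::gemv, y := alpha*A*x + beta*y, and validates it against bli_gemv on attached-buffer objects. A compact sketch of the same call, again assuming blis.hh is available, with a 2x3 column-major A:

    #include <iostream>
    #include "blis.hh"

    int main()
    {
        // Column-major 2x3 matrix (lda = 2): columns are (1,4), (2,5), (3,6).
        double A[6] = { 1.0, 4.0, 2.0, 5.0, 3.0, 6.0 };
        double x[3] = { 1.0, 1.0, 1.0 };
        double y[2] = { 0.0, 0.0 };

        // y := 1.0*A*x + 0.0*y  ->  y = { 1+2+3, 4+5+6 } = { 6, 15 }
        blis::gemv( CblasColMajor, CblasNoTrans, 2, 3,
                    1.0, A, 2, x, 1, 0.0, y, 1 );

        std::cout << y[0] << " " << y[1] << "\n";
        return 0;
    }
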
- - Copyright (C) 2019, Advanced Micro Devices, Inc. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include -#include "blis.hh" -#include "test.hh" - -using namespace blis; -using namespace std; -//#define PRINT -#define ALPHA 1.0 -#define M 5 -#define N 6 - -/* - * Test application assumes matrices to be column major, non-transposed - */ -template< typename T > -void ref_ger(int64_t m, int64_t n, - T * alpha, - T *X, - T *Y, - T *A ) - -{ - obj_t obj_a; - obj_t obj_x; - obj_t obj_y; - obj_t obj_alpha; - num_t dt; - - if(is_same::value) - dt = BLIS_FLOAT; - else if(is_same::value) - dt = BLIS_DOUBLE; - - bli_obj_create_with_attached_buffer( dt, 1, 1, alpha, 1,1,&obj_alpha ); - - bli_obj_create_with_attached_buffer( dt, m, n, A, 1, m, &obj_a ); - bli_obj_create_with_attached_buffer( dt, m, 1, X, 1, m,&obj_x ); - bli_obj_create_with_attached_buffer( dt, n, 1, Y, 1, n,&obj_y ); - - //bli_obj_set_struc( BLIS_HERMITIAN, &obj_a ); - //bli_obj_set_uplo( BLIS_LOWER, &obj_a); - bli_ger( &obj_alpha, - &obj_x, - &obj_y, - &obj_a ); - -} -template< typename T > -void test_ger( ) -{ - T *A, *X, *Y, *A_ref; - T alpha; - int m,n; - int lda, incx, incy, lda_ref; - - alpha = ALPHA; - m = M; - n = N; - - lda = m; - lda_ref = m; - incx = 1; - incy = 1; - - srand (time(NULL)); - allocate_init_buffer(A , m , n); - allocate_init_buffer(X , m , 1); - allocate_init_buffer(Y , n , 1); - copy_buffer(A, A_ref , m ,n); - -#ifdef PRINT - printmatrix(A, lda ,m,n,(char *) "A"); - printvector(X, m,(char *) "X"); - printvector(Y, n,(char *) "Y"); -#endif - blis::ger( - CblasColMajor, - m, - n, - alpha, - X, - incx, - Y, - incy, - A, - lda - ); - -#ifdef PRINT - printmatrix(A, lda , m ,n ,(char *) "A output"); -#endif - ref_ger(m, n, &alpha, X, Y, A_ref); - -#ifdef PRINT - printmatrix(A_ref, lda ,m,n, (char *)"A_ref output"); -#endif - if(computeErrorM(lda, lda_ref, m, n, A, A_ref )==1) - printf("%s TEST FAIL\n" ,__PRETTY_FUNCTION__); - else - printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); - - - delete[]( A ); - delete[]( X ); - delete[]( Y ); - delete[]( A_ref ); -} - -// 
----------------------------------------------------------------------------- -int main( int argc, char** argv ) -{ - test_ger( ); - test_ger( ); - return 0; - -} diff --git a/testcpp/test_gerc.cc b/testcpp/test_gerc.cc deleted file mode 100644 index 332405b7c1..0000000000 --- a/testcpp/test_gerc.cc +++ /dev/null @@ -1,174 +0,0 @@ -/* - - BLISPP - C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. - - Copyright (C) 2019, Advanced Micro Devices, Inc. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
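
The test_ger.cc driver above performs the real rank-1 update A := alpha*x*y^T + A through blis::ger and compares against bli_ger. A small sketch of the same update, under the same blis.hh assumption:

    #include <iostream>
    #include "blis.hh"

    int main()
    {
        // 2x2 column-major A, updated in place: A := alpha*x*y^T + A.
        double A[4] = { 0.0, 0.0, 0.0, 0.0 };
        double x[2] = { 1.0, 2.0 };
        double y[2] = { 3.0, 4.0 };

        blis::ger( CblasColMajor, 2, 2, 1.0, x, 1, y, 1, A, 2 );

        // x*y^T = [ 3 4 ; 6 8 ], stored column-major: 3 6 4 8.
        for ( int i = 0; i < 4; i++ ) std::cout << A[i] << " ";
        std::cout << "\n";
        return 0;
    }
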
- -*/ - -#include -#include -#include "blis.hh" -#include "test.hh" - -using namespace blis; -using namespace std; -//#define PRINT -#define ALPHA 1.0 -#define M 5 -#define N 6 - -/* - * Test application assumes matrices to be column major, non-transposed - */ -template< typename T > -void ref_gerc(int64_t m, int64_t n, - T * alpha, - T *X, - T *Y, - T *A ) - -{ -obj_t obj_a; -obj_t obj_x; -obj_t obj_y; -obj_t obj_alpha; -num_t dt; - -if(is_same::value) - dt = BLIS_FLOAT; -else if(is_same::value) - dt = BLIS_DOUBLE; -else if(is_same>::value) - dt = BLIS_SCOMPLEX; -else if(is_same>::value) - dt = BLIS_DCOMPLEX; - - -if(dt == BLIS_FLOAT){ - bli_obj_create_with_attached_buffer( BLIS_FLOAT, 1, 1, alpha, 1,1,&obj_alpha ); - } -else if(dt == BLIS_DOUBLE){ - bli_obj_create_with_attached_buffer( BLIS_DOUBLE, 1, 1, alpha, 1,1,&obj_alpha ); - } - -if(dt == BLIS_SCOMPLEX){ - bli_obj_create_with_attached_buffer( BLIS_SCOMPLEX, 1, 1, alpha, 1,1,&obj_alpha ); - } -else if(dt == BLIS_DCOMPLEX){ - bli_obj_create_with_attached_buffer( BLIS_DCOMPLEX, 1, 1, alpha, 1,1,&obj_alpha ); - } - -bli_obj_create_with_attached_buffer( dt, m, n, A, 1, m, &obj_a ); -bli_obj_create_with_attached_buffer( dt, m, 1, X, 1, m,&obj_x ); -bli_obj_create_with_attached_buffer( dt, n, 1, Y, 1, n,&obj_y ); - - bli_obj_set_conj(BLIS_CONJUGATE,&obj_y); -bli_ger( &obj_alpha, - &obj_x, - &obj_y, - &obj_a ); -} - - -template< typename T > -void test_gerc( ) -{ - T *A, *X, *Y, *A_ref; - T alpha; - int m,n; - int lda, incx, incy, lda_ref; - - alpha = ALPHA; - m = M; - n = N; - - lda = m; - lda_ref = m; - incx = 1; - incy = 1; - - srand (time(NULL)); - allocate_init_buffer(A , m , n); - allocate_init_buffer(X , m , 1); - allocate_init_buffer(Y , n , 1); - copy_buffer(A, A_ref , m ,n); - - - -#ifdef PRINT - printmatrix(A, lda ,m,n,(char *)"A"); - printvector(X, m, (char *)"X"); - -#endif - blis::gerc( - CblasColMajor, - m, - n, - alpha, - X, - incx, - Y, - incy, - A, - lda - ); - -#ifdef PRINT - printmatrix (A, lda ,m , n,(char *)"A blis::gerc\n"); - -#endif - ref_gerc(m, n, &alpha, X, Y, A_ref); - -#ifdef PRINT - printmatrix(A_ref, lda_ref, m, n, (char *)"A_ref output\n"); - - - - -#endif - if(computeErrorM(lda, lda_ref, m, n, A, A_ref )==1) - printf("%s TEST FAIL\n" ,__PRETTY_FUNCTION__); - else - printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); - - - delete[]( A ); - delete[]( X ); - delete[]( Y ); - delete[]( A_ref ); -} - -// ----------------------------------------------------------------------------- -int main( int argc, char** argv ) -{ - test_gerc>( ); - test_gerc>( ); - return 0; - -} diff --git a/testcpp/test_geru.cc b/testcpp/test_geru.cc deleted file mode 100644 index 03e3e6a271..0000000000 --- a/testcpp/test_geru.cc +++ /dev/null @@ -1,169 +0,0 @@ -/* - - BLISPP - C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. - - Copyright (C) 2019, Advanced Micro Devices, Inc. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. 
- - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include -#include "blis.hh" -#include "test.hh" - -using namespace blis; -using namespace std; -//#define PRINT -#define ALPHA 1.0 -#define M 5 -#define N 6 - -/* - * Test application assumes matrices to be column major, non-transposed - */ -template< typename T > -void ref_geru(int64_t m, int64_t n, - T * alpha, - T *X, - T *Y, - T *A ) - -{ -obj_t obj_a; -obj_t obj_x; -obj_t obj_y; -obj_t obj_alpha; -num_t dt; - -if(is_same::value) - dt = BLIS_FLOAT; -else if(is_same::value) - dt = BLIS_DOUBLE; -else if(is_same>::value) - dt = BLIS_SCOMPLEX; -else if(is_same>::value) - dt = BLIS_DCOMPLEX; - - -if(dt == BLIS_FLOAT){ - bli_obj_create_with_attached_buffer( BLIS_FLOAT, 1, 1, alpha, 1,1,&obj_alpha ); - } -else if(dt == BLIS_DOUBLE){ - bli_obj_create_with_attached_buffer( BLIS_DOUBLE, 1, 1, alpha, 1,1,&obj_alpha ); - } - -if(dt == BLIS_SCOMPLEX){ - bli_obj_create_with_attached_buffer( BLIS_SCOMPLEX, 1, 1, alpha, 1,1,&obj_alpha ); - } -else if(dt == BLIS_DCOMPLEX){ - bli_obj_create_with_attached_buffer( BLIS_DCOMPLEX, 1, 1, alpha, 1,1,&obj_alpha ); - } - -bli_obj_create_with_attached_buffer( dt, m, n, A, 1, m, &obj_a ); -bli_obj_create_with_attached_buffer( dt, m, 1, X, 1, m,&obj_x ); -bli_obj_create_with_attached_buffer( dt, n, 1, Y, 1, n,&obj_y ); - -bli_ger( &obj_alpha, - &obj_x, - &obj_y, - &obj_a ); -} - - -template< typename T > -void test_geru( ) -{ - T *A, *X, *Y, *A_ref; - T alpha; - int m,n; - int lda, incx, incy, lda_ref; - - alpha = ALPHA; - m = M; - n = N; - - lda = m; - lda_ref = m; - incx = 1; - incy = 1; - - srand (time(NULL)); - allocate_init_buffer(A , m , n); - allocate_init_buffer(X , m , 1); - allocate_init_buffer(Y , n , 1); -copy_buffer(A, A_ref , m ,n); - - -#ifdef PRINT - printmatrix(A, lda ,m,n,(char *)"A"); - printvector(X, m,(char *) "X"); -#endif - blis::geru( - CblasColMajor, - m, - n, - alpha, - X, - incx, - Y, - incy, - A, - lda - ); - -#ifdef PRINT - printmatrix (A, lda ,m,n,(char *)"A output"); - printvector (X, m,(char *) "X"); - -#endif - ref_geru(m, n, &alpha, X, Y, A_ref); - -#ifdef PRINT - printmatrix(A_ref, lda_ref, m,n,(char *)"A_ref output" ); - -#endif - if(computeErrorM(lda, lda_ref, m, n, A, A_ref )==1) - printf("%s TEST FAIL\n" ,__PRETTY_FUNCTION__); - else - printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); - - - delete[]( A ); - delete[]( X ); - delete[]( Y ); - delete[]( A_ref ); -} - -// ----------------------------------------------------------------------------- -int main( int argc, char** argv ) -{ - test_geru>( ); - test_geru>( ); - return 0; - -} diff --git 
a/testcpp/test_hemm.cc b/testcpp/test_hemm.cc deleted file mode 100644 index 8b88bcad35..0000000000 --- a/testcpp/test_hemm.cc +++ /dev/null @@ -1,164 +0,0 @@ -/* - - BLISPP - C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. - - Copyright (C) 2019, Advanced Micro Devices, Inc. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
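
The test_gerc.cc and test_geru.cc drivers removed above cover the two complex rank-1 updates: gerc conjugates y (A := alpha*x*conj(y)^T + A; the reference path sets BLIS_CONJUGATE on the y object before bli_ger), while geru uses y as-is. A minimal sketch contrasting the two, assuming the blis.hh wrappers accept std::complex operands as the deleted tests do:

    #include <complex>
    #include <iostream>
    #include "blis.hh"   // assumed to provide complex overloads, as in the tests above

    int main()
    {
        std::complex<double> x[1]  = { { 0.0, 1.0 } };   // x = i
        std::complex<double> y[1]  = { { 0.0, 1.0 } };   // y = i
        std::complex<double> Ac[1] = { { 0.0, 0.0 } };   // updated by gerc
        std::complex<double> Au[1] = { { 0.0, 0.0 } };   // updated by geru
        std::complex<double> alpha = { 1.0, 0.0 };

        // gerc: Ac := alpha * x * conj(y)^T + Ac  =  i * (-i)  =  1
        blis::gerc( CblasColMajor, 1, 1, alpha, x, 1, y, 1, Ac, 1 );
        // geru: Au := alpha * x * y^T + Au        =  i *   i   = -1
        blis::geru( CblasColMajor, 1, 1, alpha, x, 1, y, 1, Au, 1 );

        std::cout << "gerc: " << Ac[0] << "  geru: " << Au[0] << "\n";
        return 0;
    }
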
- -*/ - -#include -#include -#include "blis.hh" -#include "test.hh" - -using namespace blis; -using namespace std; -//#define PRINT -#define ALPHA 1.0 -#define BETA 0.0 -#define M 5 -#define N 5 - -/* - * Test application assumes matrices to be column major, non-transposed - */ -template< typename T > -void ref_hemm(int64_t m, int64_t n, - T * alpha, - T *A, - T *B, - T * beta, - T *C ) - -{ - obj_t obj_a, obj_b, obj_c; - obj_t obj_alpha, obj_beta; - num_t dt; - - if(is_same::value) - dt = BLIS_FLOAT; - else if(is_same::value) - dt = BLIS_DOUBLE; - else if(is_same>::value) - dt = BLIS_SCOMPLEX; - else if(is_same>::value) - dt = BLIS_DCOMPLEX; - - bli_obj_create_with_attached_buffer( dt, 1, 1, alpha, 1,1,&obj_alpha ); - bli_obj_create_with_attached_buffer( dt, 1, 1, beta, 1,1,&obj_beta ); - bli_obj_create_with_attached_buffer( dt, m, m, A, 1,m,&obj_a ); - bli_obj_create_with_attached_buffer( dt, m, n, B, 1,n,&obj_b ); - bli_obj_create_with_attached_buffer( dt, m, n, C, 1,m,&obj_c ); - - bli_obj_set_struc( BLIS_HERMITIAN, &obj_a ); - bli_obj_set_uplo( BLIS_LOWER, &obj_a ); - bli_mkherm(&obj_a); - bli_mktrim(&obj_a); - bli_hemm( BLIS_LEFT, - &obj_alpha, - &obj_a, - &obj_b, - &obj_beta, - &obj_c ); - -} -template< typename T > -void test_hemm( ) -{ - T *A, *B, *C, *C_ref; - T alpha, beta; - int m,n; - int lda, ldb, ldc, ldc_ref; - - alpha = ALPHA; - beta = BETA; - m = M; - n = N; - - lda = m; - ldb = n; - ldc = m; - ldc_ref = m; - - srand48 (time(NULL)); - srand (time(NULL)); - allocate_init_buffer(A , m , m); - allocate_init_buffer(B , m , n); - allocate_init_buffer(C , m , n); - copy_buffer(C, C_ref , m ,n); - -#ifdef PRINT - printmatrix(A, lda ,m,m,(char *) "A"); - printmatrix(B, ldb ,m,n,(char *) "B"); - printmatrix(C, ldc ,m,n,(char *) "C"); -#endif - blis::hemm( - CblasColMajor, - CblasLeft, - CblasLower, - m, - n, - alpha, - A, - lda, - B, - ldb, - beta, - C, - ldc - ); - -#ifdef PRINT - printmatrix(C, ldc ,m,n,(char *) "C output"); -#endif - ref_hemm(m, n, &alpha, A, B, &beta, C_ref); - -#ifdef PRINT - printmatrix(C_ref, ldc_ref ,m,n,(char *) "C ref output"); -#endif - if(computeErrorM(ldc, ldc_ref, m, n, C, C_ref )==1) - printf("%s TEST FAIL\n" , __PRETTY_FUNCTION__ ); - else - printf("%s TEST PASS\n" , __PRETTY_FUNCTION__ ); - - - - delete[]( A ); - delete[]( B ); - delete[]( C ); - delete[]( C_ref ); -} - -// ----------------------------------------------------------------------------- -int main( int argc, char** argv ) -{ - test_hemm>( ); - test_hemm>( ); - return 0; - -} diff --git a/testcpp/test_hemv.cc b/testcpp/test_hemv.cc deleted file mode 100644 index 463fdf557f..0000000000 --- a/testcpp/test_hemv.cc +++ /dev/null @@ -1,157 +0,0 @@ -/* - - BLISPP - C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. - - Copyright (C) 2019, Advanced Micro Devices, Inc. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. 
- - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include -#include "blis.hh" -#include "test.hh" - -using namespace blis; -using namespace std; -//#define PRINT -#define ALPHA 1.0 -#define BETA 0.0 -#define N 6 - -/* - * Test application assumes matrices to be column major, non-transposed - */ -template< typename T > -void ref_hemv(int64_t n, - T * alpha, - T *A, - T *X, - T * beta, - T *Y ) - -{ - obj_t obj_a, obj_x, obj_y; - obj_t obj_alpha, obj_beta; - num_t dt; - - if(is_same::value) - dt = BLIS_FLOAT; - else if(is_same::value) - dt = BLIS_DOUBLE; - else if(is_same>::value) - dt = BLIS_SCOMPLEX; - else if(is_same>::value) - dt = BLIS_DCOMPLEX; - - bli_obj_create_with_attached_buffer( dt, 1, 1, alpha, 1,1,&obj_alpha ); - bli_obj_create_with_attached_buffer( dt, 1, 1, beta, 1,1,&obj_beta ); - bli_obj_create_with_attached_buffer( dt, n, n, A, 1,n,&obj_a ); - bli_obj_create_with_attached_buffer( dt, n, 1, X, 1,n,&obj_x ); - bli_obj_create_with_attached_buffer( dt, n, 1, Y, 1,n,&obj_y ); - - bli_obj_set_struc( BLIS_HERMITIAN, &obj_a ); - bli_obj_set_uplo( BLIS_LOWER, &obj_a ); - - bli_hemv( &obj_alpha, - &obj_a, - &obj_x, - &obj_beta, - &obj_y ); - -} -template< typename T > -void test_hemv( ) -{ - T *A, *Y, *Y_ref, *X; - T alpha, beta; - int n; - int lda, incx, incy, incy_ref; - - alpha = ALPHA; - beta = BETA; - n = N; - - lda = n; - incx = 1; - incy = 1; - incy_ref = 1; - - srand (time(NULL)); - allocate_init_buffer(A , n , n); - allocate_init_buffer(X , n , 1); - allocate_init_buffer(Y , n , 1); - copy_buffer(Y, Y_ref , n ,1); - -#ifdef PRINT - printmatrix(A, lda ,n,n, (char *)"A"); - printvector(X, n, (char *)"X"); - printvector(Y, n, (char *)"Y"); -#endif - blis::hemv( - CblasColMajor, - CblasLower, - n, - alpha, - A, - lda, - X, - incx, - beta, - Y, - incy - ); - -#ifdef PRINT - printvector(Y, n, (char *)"Y output"); -#endif - ref_hemv(n, &alpha, A, X, &beta, Y_ref); - -#ifdef PRINT - printvector(Y_ref, n,(char *) "Y_ref output"); -#endif - if(computeErrorV(incy,incy_ref, n, Y, Y_ref )==1) - printf("%s TEST FAIL\n" , __PRETTY_FUNCTION__ ); - else - printf("%s TEST PASS\n" , __PRETTY_FUNCTION__ ); - - - - delete[]( A ); - delete[]( X ); - delete[]( Y ); - delete[]( Y_ref ); -} - -// ----------------------------------------------------------------------------- -int main( int argc, char** argv ) -{ - test_hemv>( ); - test_hemv>( ); - return 0; - -} diff --git a/testcpp/test_her.cc b/testcpp/test_her.cc deleted file mode 100644 index 687d1e90d8..0000000000 --- a/testcpp/test_her.cc +++ /dev/null @@ -1,141 +0,0 @@ -/* - - BLISPP - C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. - - Copyright (C) 2019, Advanced Micro Devices, Inc. 
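
test_hemv.cc above exercises the Hermitian matrix-vector product y := alpha*A*x + beta*y, where only the lower triangle of A is referenced (CblasLower); the reference path marks the object BLIS_HERMITIAN and BLIS_LOWER before calling bli_hemv. A small sketch, assuming std::complex buffers are accepted as in the deleted test:

    #include <complex>
    #include <iostream>
    #include "blis.hh"

    int main()
    {
        typedef std::complex<double> dc;

        // Hermitian 2x2, column-major (lda = 2); with CblasLower only
        // the lower triangle is read.
        //   A = [ 2     1-i ]
        //       [ 1+i   3   ]
        dc A[4] = { {2.0, 0.0}, {1.0, 1.0},     // column 0
                    {1.0,-1.0}, {3.0, 0.0} };   // column 1
        dc x[2] = { {1.0, 0.0}, {1.0, 0.0} };
        dc y[2] = { {0.0, 0.0}, {0.0, 0.0} };
        dc alpha = {1.0, 0.0}, beta = {0.0, 0.0};

        // y := alpha*A*x + beta*y  ->  y = { 3-i, 4+i }
        blis::hemv( CblasColMajor, CblasLower, 2,
                    alpha, A, 2, x, 1, beta, y, 1 );

        std::cout << y[0] << " " << y[1] << "\n";
        return 0;
    }
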
- - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include -#include "blis.hh" -#include "test.hh" - -using namespace blis; -using namespace std; -//#define PRINT -#define ALPHA 1.0 -#define N 6 - -/* - * Test application assumes matrices to be column major, non-transposed - */ -template< typename T > -void ref_her(int64_t n, - real_type * alpha, - T *X, - T *A ) - -{ - obj_t obj_a; - obj_t obj_x; - obj_t obj_alpha; - num_t dt; - - if(is_same>::value) - dt = BLIS_SCOMPLEX; - else if(is_same>::value) - dt = BLIS_DCOMPLEX; - - if(dt == BLIS_SCOMPLEX){ - bli_obj_create_with_attached_buffer( BLIS_FLOAT, 1, 1, alpha, 1,1,&obj_alpha ); - } - else if(dt == BLIS_DCOMPLEX){ - bli_obj_create_with_attached_buffer( BLIS_DOUBLE, 1, 1, alpha, 1,1,&obj_alpha ); - } - - bli_obj_create_with_attached_buffer( dt, n, n, A, 1, n, &obj_a ); - bli_obj_create_with_attached_buffer( dt, n, 1, X, 1, n,&obj_x ); - - bli_obj_set_struc( BLIS_HERMITIAN, &obj_a ); - bli_obj_set_uplo( BLIS_LOWER, &obj_a); - bli_her( &obj_alpha, - &obj_x, - &obj_a ); - -} -template< typename T > -void test_her( ) -{ - T *A, *X, *A_ref; - real_type alpha; - int n; - int lda, incx, lda_ref; - - alpha = ALPHA; - n = N; - - lda = n; - lda_ref = n; - incx = 1; - srand (time(NULL)); - allocate_init_buffer(A , n , n); - allocate_init_buffer(X , n , 1); - copy_buffer(A, A_ref , n ,n); - -#ifdef PRINT - printmatrix(A, lda ,n,n,(char *) "A"); - printvector(X, n,(char *) "X"); -#endif - blis::her( - CblasColMajor, - CblasLower, - n, - alpha, - X, - incx, - A, - lda - ); - -#ifdef PRINT - printmatrix(A, lda ,n,n, (char *)"A output"); -#endif - ref_her(n, &alpha, X, A_ref); -#ifdef PRINT - printmatrix(A_ref, lda_ref, n,n ,(char *) "A refoutput"); -#endif - if(computeErrorM(lda, lda_ref, n, n, A, A_ref )==1) - printf("%s TEST FAIL\n" ,__PRETTY_FUNCTION__); - else - printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); - - - delete[]( A ); - delete[]( X ); - delete[]( A_ref ); -} - -// ----------------------------------------------------------------------------- -int main( int argc, char** argv ) -{ - test_her>( ); - test_her>( ); - return 0; - 
-} diff --git a/testcpp/test_her2.cc b/testcpp/test_her2.cc deleted file mode 100644 index 2f3ca253ac..0000000000 --- a/testcpp/test_her2.cc +++ /dev/null @@ -1,147 +0,0 @@ -/* - - BLISPP - C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. - - Copyright (C) 2019, Advanced Micro Devices, Inc. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
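
test_her.cc above checks the Hermitian rank-1 update A := alpha*x*x^H + A, where alpha is real (real_type<T> in the wrapper) and only the lower triangle of A is touched. A compact sketch under the same assumptions:

    #include <complex>
    #include <iostream>
    #include "blis.hh"

    int main()
    {
        typedef std::complex<float> fc;

        // 2x2 column-major A, initially zero; her updates only the
        // lower triangle with alpha*x*x^H, alpha being real.
        fc A[4] = { {0.0f,0.0f}, {0.0f,0.0f},
                    {0.0f,0.0f}, {0.0f,0.0f} };
        fc x[2] = { {1.0f,1.0f}, {2.0f,0.0f} };
        float alpha = 1.0f;

        blis::her( CblasColMajor, CblasLower, 2, alpha, x, 1, A, 2 );

        // Expected lower triangle: A(0,0)=2, A(1,0)=2-2i, A(1,1)=4.
        std::cout << A[0] << " " << A[1] << " " << A[3] << "\n";
        return 0;
    }
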
- -*/ - -#include -#include -#include "blis.hh" -#include "test.hh" - -using namespace blis; -using namespace std; -//#define PRINT -#define ALPHA 1.0 -#define N 6 - -/* - * Test application assumes matrices to be column major, non-transposed - */ -template< typename T > -void ref_her2(int64_t n, - T * alpha, - T *X, - T *Y, - T *A ) - -{ - obj_t obj_a; - obj_t obj_x, obj_y; - obj_t obj_alpha; - num_t dt; - - if(is_same>::value) - dt = BLIS_SCOMPLEX; - else if(is_same>::value) - dt = BLIS_DCOMPLEX; - - bli_obj_create_with_attached_buffer(dt, 1, 1, alpha, 1,1,&obj_alpha ); - - bli_obj_create_with_attached_buffer( dt, n, n, A, 1, n, &obj_a ); - bli_obj_create_with_attached_buffer( dt, n, 1, X, 1, n,&obj_x ); - bli_obj_create_with_attached_buffer( dt, n, 1, Y, 1, n,&obj_y ); - - bli_obj_set_struc( BLIS_HERMITIAN, &obj_a ); - bli_obj_set_uplo( BLIS_LOWER, &obj_a); - bli_her2( &obj_alpha, - &obj_x, - &obj_y, - &obj_a ); - -} -template< typename T > -void test_her2( ) -{ - T *A, *X, *Y, *A_ref; - T alpha; - int n; - int lda, incx, incy, lda_ref; - - alpha = ALPHA; - n = N; - - lda = n; - lda_ref = n; - incx = 1; - incy = 1; - - - srand (time(NULL)); - allocate_init_buffer(A , n , n); - allocate_init_buffer(X , n , 1); - allocate_init_buffer(Y , n , 1); - copy_buffer(A, A_ref , n ,n); - -#ifdef PRINT - printmatrix(A, lda ,n,n,(char *) "A"); - printvector(X, n,(char *) "X"); - printvector(Y, n, (char *)"Y"); -#endif - blis::her2( - CblasColMajor, - CblasLower, - n, - alpha, - X, - incx, - Y, - incy, - A, - lda - ); - -#ifdef PRINT - printmatrix(A, lda , n , n,(char *) "A output"); -#endif - ref_her2(n, &alpha, X, Y, A_ref); -#ifdef PRINT - printmatrix(A_ref, lda , n, n, (char *)"A_ref output"); -#endif - if(computeErrorM(lda, lda_ref, n, n, A, A_ref )==1) - printf("%s TEST FAIL\n" ,__PRETTY_FUNCTION__); - else - printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); - - - delete[]( A ); - delete[]( X ); - delete[]( Y ); - delete[]( A_ref ); -} - -// ----------------------------------------------------------------------------- -int main( int argc, char** argv ) -{ - test_her2>( ); - test_her2>( ); - return 0; - -} diff --git a/testcpp/test_herk.cc b/testcpp/test_herk.cc deleted file mode 100644 index 3febf3e6f1..0000000000 --- a/testcpp/test_herk.cc +++ /dev/null @@ -1,155 +0,0 @@ -/* - - BLISPP - C++ test driver for BLIS CPP herk routine and reference blis herk routine. - - Copyright (C) 2019, Advanced Micro Devices, Inc. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ -#include -#include -#include "blis.hh" -#include "test.hh" - -using namespace blis; -using namespace std; -//#define PRINT -#define ALPHA 1.0 -#define BETA 0.0 -#define N 6 -#define K 6 - -/* - * Test application assumes matrices to be column major, non-transposed - */ -template< typename T > -void ref_herk(int64_t n, int64_t k, - real_type * alpha, - T *A, - real_type * beta, - T *C ) - -{ - obj_t obj_a,obj_c; - obj_t obj_alpha, obj_beta; - num_t dt; - - if(is_same>::value) - dt = BLIS_SCOMPLEX; - else if(is_same>::value) - dt = BLIS_DCOMPLEX; - - if(dt == BLIS_SCOMPLEX){ - bli_obj_create_with_attached_buffer( BLIS_FLOAT, 1, 1, alpha, 1,1,&obj_alpha ); - bli_obj_create_with_attached_buffer( BLIS_FLOAT, 1, 1, beta, 1,1,&obj_beta ); - } - else if(dt == BLIS_DCOMPLEX){ - bli_obj_create_with_attached_buffer( BLIS_DOUBLE, 1, 1, alpha, 1,1,&obj_alpha ); - bli_obj_create_with_attached_buffer( BLIS_DOUBLE, 1, 1, beta, 1,1,&obj_beta ); - } - - bli_obj_create_with_attached_buffer( dt, n, k, A, 1,n,&obj_a ); - bli_obj_create_with_attached_buffer( dt, n, n, C, 1,n,&obj_c ); - - bli_obj_set_struc( BLIS_HERMITIAN, &obj_c ); - bli_obj_set_uplo( BLIS_LOWER, &obj_c ); - bli_obj_set_conjtrans( BLIS_NO_TRANSPOSE, &obj_c ); - bli_herk( &obj_alpha, - &obj_a, - &obj_beta, - &obj_c ); - -} -template< typename T > -void test_herk( ) -{ - T *A, *C, *C_ref; - real_type alpha; - real_type beta; - int n,k; - int lda, ldc, ldc_ref; - - alpha = ALPHA; - beta = BETA; - k = K; - n = N; - - - lda = k; - ldc = n; - ldc_ref = n; - srand (time(NULL)); - allocate_init_buffer(A , n , k); - allocate_init_buffer(C , n , n); - copy_buffer(C, C_ref , n ,n); - -#ifdef PRINT - printmatrix(A, lda ,n,k, (char *)"A"); - printmatrix(C, ldc ,n,n, (char *)"C"); -#endif - blis::herk( - CblasColMajor, - CblasLower, - CblasNoTrans, - n, - k, - alpha, - A, - lda, - beta, - C, - ldc - ); - -#ifdef PRINT - printmatrix(C, ldc ,n,n, (char *)"C output"); -#endif - ref_herk(n, k, &alpha, A, &beta, C_ref); - -#ifdef PRINT - printmatrix(C_ref, ldc_ref ,n,n, (char *)"C ref output"); -#endif - if(computeErrorM(ldc, ldc_ref, n, n, C, C_ref )==1) - printf("%s TEST FAIL\n" ,__PRETTY_FUNCTION__); - else - printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); - - - - delete[]( A ); - delete[]( C ); - delete[]( C_ref ); -} - -// ----------------------------------------------------------------------------- -int main( int argc, char** argv ) -{ - test_herk>( ); - test_herk>( ); - return 0; - -} diff --git a/testcpp/test_hpr.cc b/testcpp/test_hpr.cc deleted file mode 100644 index dfc7bdd4a9..0000000000 --- a/testcpp/test_hpr.cc +++ /dev/null @@ -1,112 +0,0 @@ -/* - - BLISPP - C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. - - Copyright (C) 2019, Advanced Micro Devices, Inc. 
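
test_herk.cc above drives the Hermitian rank-k update C := alpha*A*A^H + beta*C with real alpha and beta, writing only the lower triangle of C; the reference path marks C as BLIS_HERMITIAN/BLIS_LOWER before bli_herk. A minimal sketch of the same call, with n = 2 and k = 1 so the expected values are easy to check by hand:

    #include <complex>
    #include <iostream>
    #include "blis.hh"

    int main()
    {
        typedef std::complex<double> dc;

        // C := alpha*A*A^H + beta*C, A is 2x1 column-major (lda = 2);
        // only the lower triangle of C is written.
        dc A[2] = { {1.0, 1.0}, {2.0, 0.0} };
        dc C[4] = { {0.0, 0.0}, {0.0, 0.0},
                    {0.0, 0.0}, {0.0, 0.0} };
        double alpha = 1.0, beta = 0.0;   // herk takes real alpha and beta

        blis::herk( CblasColMajor, CblasLower, CblasNoTrans,
                    2, 1, alpha, A, 2, beta, C, 2 );

        // Expected lower triangle: C(0,0)=2, C(1,0)=2-2i, C(1,1)=4.
        std::cout << C[0] << " " << C[1] << " " << C[3] << "\n";
        return 0;
    }
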
- - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include -#include "blis.hh" -#include "test.hh" - -using namespace blis; -using namespace std; -//#define PRINT -#define N 2 - -/* - * Test application assumes matrices to be column major, non-transposed - */ - -template< typename T > -void test_hpr( ) -{ -int n; -real_type alpha; -int incX = -1; - -alpha = 1.0; -n = N; - - -T A[4]; - A[0] = { 0.265, 0.362}; - A[1] = {-0.855, 0.035}; - A[2] = {0.136, 0.133 }; - A[3] = { 0.00, 0.00}; - -T X[2]; - X[0] = { -0.278, -0.686}; - X[1] = {-0.736, -0.918 }; - -T A_ref[4]; - A_ref[0] = { 1.64942, 0.0}; - A_ref[1] = {-0.020644, 0.284692}; - A_ref[2] = {0.68388, 0.0 }; - A_ref[3] = {0.00, 0.00 }; - - - -#ifdef PRINT - printmatrix(A, n,n, n,(char *) "A"); - printvector(X, n, (char *)"X"); -#endif - blis::hpr( - CblasColMajor, - CblasLower, - n, - alpha, - X, - incX, - A - ); - -#ifdef PRINT - printmatrix(A, n , n, n,(char *)"A blis:hpr\n"); - - printmatrix(A_ref, n, n, n,(char *)"A_ref output\n"); -#endif - - if(computeErrorM(n, n, n, n, A, A_ref )==1) - printf("%s TEST FAIL\n" ,__PRETTY_FUNCTION__); - else - printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); - - -} - -// ----------------------------------------------------------------------------- -int main( int argc, char** argv ) -{ - test_hpr>( ); - test_hpr>( ); - return 0; - -} diff --git a/testcpp/test_hpr2.cc b/testcpp/test_hpr2.cc deleted file mode 100644 index 1b8b9b2b4f..0000000000 --- a/testcpp/test_hpr2.cc +++ /dev/null @@ -1,93 +0,0 @@ -/* - - BLISPP - C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. - - Copyright (C) 2019, Advanced Micro Devices, Inc. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. 
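
test_hpr.cc above uses packed Hermitian storage: with CblasLower the columns of the lower triangle are stored contiguously, so an n x n Hermitian matrix occupies n*(n+1)/2 elements and no lda argument is passed. A small sketch of the packed rank-1 update A := alpha*x*x^H + A, again assuming std::complex buffers as in the deleted test:

    #include <complex>
    #include <iostream>
    #include "blis.hh"

    int main()
    {
        typedef std::complex<double> dc;

        // Packed lower storage for a 2x2 Hermitian A:
        //   AP[0] = A(0,0), AP[1] = A(1,0), AP[2] = A(1,1).
        dc AP[3] = { {0.0, 0.0}, {0.0, 0.0}, {0.0, 0.0} };
        dc x[2]  = { {1.0, 1.0}, {2.0, 0.0} };
        double alpha = 1.0;   // hpr takes a real alpha

        // AP := alpha*x*x^H + AP, packed; note there is no lda argument.
        blis::hpr( CblasColMajor, CblasLower, 2, alpha, x, 1, AP );

        // Expected: AP = { 2, 2-2i, 4 }.
        std::cout << AP[0] << " " << AP[1] << " " << AP[2] << "\n";
        return 0;
    }
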
- - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include -#include "blis.hh" -#include "test.hh" - -using namespace blis; -using namespace std; -#define N 1 - -/* - * Test application assumes matrices to be column major, non-transposed - */ - -template< typename T > -void test_hpr2( ) -{ -int n; -int incX = -1; -int incY = -1; - n = N; - -T alpha = {-0.3, 0.1}; - -T A[1]; - A[0] = { 0.772, 0.997 }; -T X[1]; - X[0] = { -0.173, -0.839 }; -T Y[1]; - Y[0] = { 0.941, -0.422 }; -T A_ref[1]; - A_ref[0] = { 0.829742, 0.0 }; - - blis::hpr2( - CblasColMajor, - CblasLower, - n, - alpha, - X, - incX, - Y, - incY, - A - ); - - - if(computeErrorM(1, 1, n, n, A, A_ref )==1) - printf("%s TEST FAIL\n" ,__PRETTY_FUNCTION__); - else - printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); -} - -// ----------------------------------------------------------------------------- -int main( int argc, char** argv ) -{ - test_hpr2>( ); - printf("**************\n"); - test_hpr2>( ); - return 0; - -} diff --git a/testcpp/test_nrm2.cc b/testcpp/test_nrm2.cc deleted file mode 100644 index d29ec77788..0000000000 --- a/testcpp/test_nrm2.cc +++ /dev/null @@ -1,100 +0,0 @@ -/* - - BLISPP - C++ test driver for BLIS CPP nrm2 routine and reference blis nrm2 routine. - - Copyright (C) 2019, Advanced Micro Devices, Inc. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include -#include "blis.hh" -#include "test.hh" - -using namespace blis; -using namespace std; -//#define PRINT -#define N 2 -#define ALPHA 0.5 - -#define TOLERANCE 0.0000001 -/* - * Test application assumes matrices to be column major, non-transposed - */ -template< typename T> -void test_nrm2() -{ - - T X[N]; - T nrm2, nrm2_ref; - int n; - int incx; - - n = N; - incx = 1; - - if(is_same::value) - { - X[0] = 0.14f; - X[1] = -0.632f; - nrm2_ref = 0.647320631527f; - } - else if(is_same::value) - { - X[0] = 0.696; - X[1] = -0.804; - nrm2_ref = 1.06340584915; - } - -#ifdef PRINT - printvector(X, n,(char *) "Vector X after blis::nrm2"); -#endif - nrm2 = blis::nrm2( - n, - X, - incx - ); -#ifdef PRINT - printf("Norm of a Vector %E \n", nrm2); - printf("Ref Norm of a Vector %E \n", nrm2_ref); -#endif - - if (fabs(nrm2 - nrm2_ref) > TOLERANCE) - printf("%s TEST FAIL\n" , __PRETTY_FUNCTION__); - else - printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); -} - -// ----------------------------------------------------------------------------- -int main( int argc, char** argv ) -{ - test_nrm2( ); - test_nrm2( ); - return 0; - -} diff --git a/testcpp/test_rot.cc b/testcpp/test_rot.cc deleted file mode 100644 index a2e3fb7086..0000000000 --- a/testcpp/test_rot.cc +++ /dev/null @@ -1,102 +0,0 @@ -/* - - BLISPP - C++ test driver for BLIS CPP rot routine and reference blis rot routine. - - Copyright (C) 2019, Advanced Micro Devices, Inc. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
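
test_nrm2.cc above validates the Euclidean norm ||x||_2 = sqrt(sum_i x_i^2) against hard-coded reference values, comparing with an absolute tolerance rather than exact equality since the result is accumulated in floating point. A minimal sketch using a 3-4-5 triangle so the expected value is exact:

    #include <cmath>
    #include <cstdio>
    #include "blis.hh"

    int main()
    {
        double X[2] = { 3.0, 4.0 };

        // ||X||_2 = sqrt(3^2 + 4^2) = 5
        double nrm = blis::nrm2( 2, X, 1 );

        // Compare with a small tolerance, as the deleted test does.
        std::printf( "nrm2 = %f (%s)\n", nrm,
                     std::fabs( nrm - 5.0 ) < 1e-7 ? "ok" : "mismatch" );
        return 0;
    }
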
- -*/ - -#include -#include -#include "blis.hh" -#include "test.hh" - -using namespace blis; -using namespace std; -//#define PRINT -#define N 1 - -/* - * Test application assumes matrices to be column major, non-transposed - */ -template< typename T> -void test_rot() -{ - - T c, s; - T X[N], X_ref[N]; - T Y[N], Y_ref[N]; - int n; - int incx, incy; - - n = N; - incx = 1; - incy = 1; - if(is_same::value){ - c = -1.0f; - s = 0.0f; - X[0] = { -0.314f }; - Y[0] = { -0.406f }; - X_ref[0] = { 0.314f }; - Y_ref[0] = { 0.406f }; - }else{ - c = -1; - s = 0; - X[0] = { -0.176 }; - Y[0] = { -0.165 }; - X_ref[0] = { 0.176 }; - Y_ref[0] = { 0.165 }; - } - -#ifdef PRINT - printvector(X, n, (char *)"Before blis::rot\nVector X"); - printvector(Y, n, (char *)"Vector Y"); -#endif - blis::rot( N, X, incx, Y, incy, c, s); -#ifdef PRINT - printvector(X, n, (char *)"After blis::rot\nVector X"); - printvector(Y, n, (char *) "Vector Y"); - printvector(X, n, (char *) "Expected Output from blis::rot\nVector X"); - printvector(Y, n, (char *)"Vector Y"); -#endif - - if((computeErrorV(incx, incx , n, X, X_ref )==1) || (computeErrorV(incy, incy , n, Y, Y_ref )==1)) - printf("%s TEST FAIL\n" , __PRETTY_FUNCTION__); - else - printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); - -} - -// ----------------------------------------------------------------------------- -int main( int argc, char** argv ) -{ - test_rot( ); - test_rot( ); - return 0; - -} diff --git a/testcpp/test_rotg.cc b/testcpp/test_rotg.cc deleted file mode 100644 index e11571ae3c..0000000000 --- a/testcpp/test_rotg.cc +++ /dev/null @@ -1,108 +0,0 @@ -/* - - BLISPP - C++ test driver for BLIS CPP rotg routine and reference blis rotg routine. - - Copyright (C) 2019, Advanced Micro Devices, Inc. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
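
test_rot.cc above applies a plane (Givens) rotation to a pair of vectors: for each i, x_i is replaced by c*x_i + s*y_i and y_i by c*y_i - s*x_i. A short sketch with c = 0 and s = 1, which swaps the vectors up to sign:

    #include <iostream>
    #include "blis.hh"

    int main()
    {
        double x[2] = { 1.0, 2.0 };
        double y[2] = { 3.0, 4.0 };
        double c = 0.0, s = 1.0;

        // x_i := c*x_i + s*y_i ;  y_i := c*y_i - s*x_i
        // With c = 0, s = 1: x becomes {3,4}, y becomes {-1,-2}.
        blis::rot( 2, x, 1, y, 1, c, s );

        std::cout << x[0] << " " << x[1] << "  |  "
                  << y[0] << " " << y[1] << "\n";
        return 0;
    }
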
- -*/ - -#include -#include -#include "blis.hh" -#include "test.hh" - -using namespace blis; -using namespace std; -//#define PRINT - -/* - * Test application assumes matrices to be column major, non-transposed - */ -template< typename T> -void test_rotg() -{ - - T a, b, c, s; - T a_ref, b_ref, c_ref, s_ref; - - if(is_same::value) - { - a = 1.0f; - b = 1.0f; - a_ref = 1.41421356237f; - b_ref = 1.41421356237f; - c_ref = 0.707106781187f; - s_ref = 0.707106781187f; - }else{ - a = 1; - b = 0; - a_ref = 1; - b_ref = 0; - c_ref = 1; - s_ref = 0; - } - -#ifdef PRINT - cout<< "Before blis::rotg \na Value : " << a << "\n" ; - cout<< "b Value : " << b << "\n" ; -#endif - blis::rotg( - &a, - &b, - &c, - &s - ); - -#ifdef PRINT - cout<< "After blis::rotg \na Value : " << a << "\n" ; - cout<< "b Value : " << b << "\n" ; - cout<< "c Value : " << c << "\n" ; - cout<< "s Value : " << s << "\n" ; -#endif - -#ifdef PRINT - cout<< "Expected Output\na Value : " << a_ref << "\n" ; - cout<< "b Value : " << b_ref << "\n" ; - cout<< "c Value : " << c_ref << "\n" ; - cout<< "s Value : " << s_ref << "\n" ; -#endif - if( (a != a_ref ) || (b != b_ref ) || (c != c_ref ) || (s != s_ref )) - printf("%s TEST FAIL\n" , __PRETTY_FUNCTION__); - else - printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); - -} - -// ----------------------------------------------------------------------------- -int main( int argc, char** argv ) -{ - test_rotg( ); - test_rotg( ); - return 0; - -} diff --git a/testcpp/test_rotm.cc b/testcpp/test_rotm.cc deleted file mode 100644 index aad4504b83..0000000000 --- a/testcpp/test_rotm.cc +++ /dev/null @@ -1,106 +0,0 @@ -/* - - BLISPP - C++ test driver for BLIS CPP rotm routine and reference blis rotm routine. - - Copyright (C) 2019, Advanced Micro Devices, Inc. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- -*/ - -#include -#include -#include "blis.hh" -#include "test.hh" - -using namespace blis; -using namespace std; -//#define PRINT -#define N 1 - -/* - * Test application assumes matrices to be column major, non-transposed - */ -template< typename T> -void test_rotm() -{ - - T X[N], X_ref[N]; - T Y[N], Y_ref[N]; - int n; - int incx, incy; - const T P[5] = { -1.0f, -4.44982e+03f, -15.5826f, 7.091334e+04f, 2.95912e+04f }; - const T P_double[5] = { 1.0, -1.244580625511e+03, 1.11154682624, - 2.269384716089e-05, -0.0143785338883 }; - n = N; - incx = 1; - incy = 1; - if(is_same::value) - { - X[0] = { -0.034f }; - Y[0] = { -0.56f }; - X_ref[0] = { -3.956017e+04f }; - Y_ref[0] = { -1.657054e+04f }; - }else{ - X[0] = { 0.84 }; - Y[0] = { -0.711 }; - X_ref[0] = { -1.046158725429e+03 }; - Y_ref[0] = { -0.829776862405 }; - } - -#ifdef PRINT - printvector(X, n, (char *)"Before blis::rot\nVector X"); - printvector(Y, n, (char *)"Vector Y"); -#endif - if(is_same::value) - { - blis::rotm( N, X, incx, Y, incy, P); - }else{ - blis::rotm( N, X, incx, Y, incy, P_double); - } -#ifdef PRINT - printvector(X, n, (char *)"After blis::rot\nVector X"); - printvector(Y, n, (char *)"Vector Y"); - printvector(X, n, (char *)"Expected Output from blis::rot\nVector X"); - printvector(Y, n, (char *)"Vector Y"); -#endif - - if((computeErrorV(incx, incx , n, X, X_ref )==1) - || (computeErrorV(incy, incy , n, Y, Y_ref )==1)) - printf("%s TEST FAIL\n" , __PRETTY_FUNCTION__); - else - printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); - -} - -// ----------------------------------------------------------------------------- -int main( int argc, char** argv ) -{ - test_rotm( ); - test_rotm( ); - return 0; - -} diff --git a/testcpp/test_rotmg.cc b/testcpp/test_rotmg.cc deleted file mode 100644 index b2325bb241..0000000000 --- a/testcpp/test_rotmg.cc +++ /dev/null @@ -1,137 +0,0 @@ -/* - - BLISPP - C++ test driver for BLIS CPP rotmg routine and reference blis rotmg routine. - - Copyright (C) 2019, Advanced Micro Devices, Inc. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- -*/ - -#include -#include -#include "blis.hh" -#include "test.hh" - -using namespace blis; -using namespace std; -//#define PRINT - -/* - * Test application assumes matrices to be column major, non-transposed - */ -template< typename T> -void test_rotmg() -{ - T d1, d2, b1, b2; - T d1_ref, d2_ref, b1_ref; - T h[5] = { -999.0f, -999.1f, -999.2f, -999.3f, -999.4f }; - T h_ref[5] = {-1.0f, 0.0f, 0.0f, 0.0f,0.0f}; - T h_double[5] = { -999.0, -999.1, -999.2, -999.3, -999.4 }; - T h_ref_double[5] = { 1, 0, 0, 0}; - - if(is_same::value) - { - d1 = -1630.28519312f; - d2 = 44320.1964703f; - b1 = 1274.7681352f; - b2 = 0.983006912864f; - d1_ref= 0.0f; - d2_ref= 0.0f; - b1_ref= 0.0f; - }else{ - d1 = -49.1978123005; - d2 = 0.228703451277; - b1 = 1.8901039144; - b2 = 7081.47754386; - d1_ref= 0; - d2_ref= 0; - b1_ref= 0; - } - -#ifdef PRINT - cout<< "Before blis::rotmg \nd1 Value : " << d1 << "\n" ; - cout<< "d2 Value : " << d2 << "\n" ; - cout<< "b1 Value : " << b1 << "\n" ; - printvector(h, 5,(char *) "param"); -#endif - if(is_same::value) - { - blis::rotmg( - &d1, - &d2, - &b1, - b2, - h - ); - }else{ - blis::rotmg( - &d1, - &d2, - &b1, - b2, - h_double - ); - } - -#ifdef PRINT - cout<< "After blis::rotmg \nd1 Value : " << d1 << "\n" ; - cout<< "d2 Value : " << d2 << "\n" ; - cout<< "b1 Value : " << b1 << "\n" ; - printvector(h, 5,(char *) "param"); -#endif - -#ifdef PRINT - cout<< "Expected Output from blis::rotmg \nd1 Value : " << d1_ref << "\n" ; - cout<< "d2 Value : " << d2_ref << "\n" ; - cout<< "b1 Value : " << b1_ref << "\n" ; - printvector(h_ref, 5,(char *) "param"); -#endif - if( (d1 != d1_ref ) || (d2 != d2_ref ) || (b1 != b1_ref ) ) - printf("%s TEST FAIL\n" , __PRETTY_FUNCTION__); - else if(is_same::value){ - if(computeErrorV(1, 1 , 5, h, h_ref )==1) - printf("%s TEST FAIL\n" , __PRETTY_FUNCTION__); - else - printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); - }else if(is_same::value){ - if(computeErrorV(1, 1 , 5, h_double, h_ref_double )==1) - printf("%s TEST FAIL\n" , __PRETTY_FUNCTION__); - else - printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); - }else - printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); - -} - -// ----------------------------------------------------------------------------- -int main( int argc, char** argv ) -{ - test_rotmg( ); - test_rotmg( ); - return 0; - -} diff --git a/testcpp/test_scal.cc b/testcpp/test_scal.cc deleted file mode 100644 index 82b2821a66..0000000000 --- a/testcpp/test_scal.cc +++ /dev/null @@ -1,138 +0,0 @@ -/* - - BLISPP - C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. - - Copyright (C) 2019, Advanced Micro Devices, Inc. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. 
- - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include -#include "blis.hh" -#include "test.hh" - -using namespace blis; -using namespace std; -//#define PRINT -#define N 6 -#define ALPHA 0.5 - -/* - * Test application assumes matrices to be column major, non-transposed - */ -template< typename TA,typename TB> -void ref_scal(int64_t n, - TA * alpha, - TB *X - ) - -{ - obj_t obj_x; - obj_t obj_alpha; - num_t dt_x , dt_alpha; - if(is_same::value) - dt_x = BLIS_FLOAT; - else if(is_same::value) - dt_x = BLIS_DOUBLE; - else if(is_same>::value) - dt_x = BLIS_SCOMPLEX; - else if(is_same>::value) - dt_x = BLIS_DCOMPLEX; - - if(is_same::value) - dt_alpha = BLIS_FLOAT; - else if(is_same::value) - dt_alpha = BLIS_DOUBLE; - else if(is_same>::value) - dt_alpha = BLIS_SCOMPLEX; - else if(is_same>::value) - dt_alpha = BLIS_DCOMPLEX; - - bli_obj_create_with_attached_buffer( dt_alpha, 1, 1, alpha, 1,1,&obj_alpha ); - bli_obj_create_with_attached_buffer( dt_x, n, 1, X, 1, n,&obj_x ); - - bli_scalv(&obj_alpha, - &obj_x - ); - -} -template< typename TA, typename TB> -void test_scal() -{ - TB *X, *X_ref; - TA alpha = ALPHA; - int n; - int incx; - - n = N; - - incx = 1; - srand (time(NULL)); - allocate_init_buffer(X , n , 1); - copy_buffer(X, X_ref , n ,1); - -#ifdef PRINT - printvector(X, n, (char *)"X"); -#endif - blis::scal( - n, - alpha, - X, - incx - ); - -#ifdef PRINT - printvector(X, n, (char *)"X output"); -#endif - ref_scal(n , &alpha , X_ref ); - -#ifdef PRINT - printvector(X_ref, n, (char *)"X ref output"); -#endif - if(computeErrorV(incx, incx , n, X, X_ref )==1) - printf("%s TEST FAIL\n" , __PRETTY_FUNCTION__); - else - printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); - - - delete[]( X ); - delete[]( X_ref ); -} - -// ----------------------------------------------------------------------------- -int main( int argc, char** argv ) -{ - test_scal( ); - test_scal( ); - test_scal , std::complex>( ); - test_scal , std::complex>( ); - test_scal>( ); - test_scal>( ); - return 0; - -} diff --git a/testcpp/test_sdsdot.cc b/testcpp/test_sdsdot.cc deleted file mode 100644 index c903c97d33..0000000000 --- a/testcpp/test_sdsdot.cc +++ /dev/null @@ -1,134 +0,0 @@ -/* - - BLISPP - C++ test driver for BLIS CPP sdsdot routine and reference blis sdsdot routine. - - Copyright (C) 2019, Advanced Micro Devices, Inc. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. 
- - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include -#include "blis.hh" -#include "test.hh" - -using namespace blis; -using namespace std; -//#define PRINT -#define N 1 -#define ALPHA 0 - -/* - * Test application assumes matrices to be column major, non-transposed - */ - - #if 0 -template< typename T > -void ref_sdsot(int64_t n, - T alpha, - T *X, - T *Y, - T *res_ref - ) - -{ - obj_t obj_x; - obj_t obj_y; - obj_t obj_res; - obj_t obj_alpha; - num_t dt; - - if(is_same>::value) - dt = BLIS_SCOMPLEX; - else if(is_same>::value) - dt = BLIS_DCOMPLEX; - - bli_obj_create_with_attached_buffer( dt, n, 1, X, 1, n,&obj_x ); - bli_obj_create_with_attached_buffer( dt, n, 1, Y, 1, n,&obj_y ); - bli_obj_create_with_attached_buffer( dt, 1, 1, &alpha, 1,1,&obj_alpha ); - bli_obj_create_with_attached_buffer( dt, 1, 1, res_ref, 1, 1,&obj_res ); - - bli_ddots( &obj_x, - &obj_y, - &obj_res ); - -} -#endif - -template< typename T > -void test_sdsdot() -{ - - T X[N], Y[N]; - int n; - int incx, incy; - T res = 0, res_ref = 0; - - n = N; - - incx = 1; - incy = 1; - - //srand (time(NULL)); - //allocate_init_buffer(X , n , 1); - //allocate_init_buffer(Y , n , 1); - - X[0] = { 0.733f }; - Y[0] = { 0.825f }; - res_ref = 0.604725f; - res = blis::sdsdot( - n, - ALPHA, - X, - incx, - Y, - incy - ); - -#ifdef PRINT - printf("Dot product = %E \n", res); - -#endif - //ref_sdsot(n, aplha, X, Y , &res_ref ); - -#ifdef PRINT - printf("Ref Dot product %E \n", res_ref); -#endif - if(res != res_ref ) - printf("%s TEST FAIL\n" ,__PRETTY_FUNCTION__); - else - printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); - -} - -// ----------------------------------------------------------------------------- -int main( int argc, char** argv ) -{ - test_sdsdot( ); - return 0; - -} diff --git a/testcpp/test_spr.cc b/testcpp/test_spr.cc deleted file mode 100644 index edb7aa81a9..0000000000 --- a/testcpp/test_spr.cc +++ /dev/null @@ -1,97 +0,0 @@ -/* - - BLISPP - C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. - - Copyright (C) 2019, Advanced Micro Devices, Inc. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. 
- - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include -#include "blis.hh" -#include "test.hh" - -using namespace blis; -using namespace std; -//#define PRINT -#define N 2 - -/* - * Test application assumes matrices to be column major, non-transposed - */ - -template< typename T > -void test_spr( ) -{ - int n; - int incX = -1; - T alpha = -1; - - n = N; - - - T A[] = { 0.819, 0.175, -0.809 }; - T X[] = { -0.645, -0.222 }; - T A_ref[] = { 0.769716, 0.03181, -1.225025 }; - - -#ifdef PRINT - printmatrix(A, n, n, n,(char *) "A"); - printvector(X, n,(char *) "X"); -#endif - blis::spr( - CblasColMajor, - CblasLower, - n, - alpha, - X, - incX, - A - ); - -#ifdef PRINT - printmatrix (A, n ,n, n, (char *)"A blis:spr\n"); - printmatrix(A_ref, n, n, n,(char *)"A_ref blis:spr \n"); -#endif - - if(computeErrorM(1, 1, n, n, A, A_ref )==1) - printf("%s TEST FAIL\n" ,__PRETTY_FUNCTION__); - else - printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); - - -} - -// ----------------------------------------------------------------------------- -int main( int argc, char** argv ) -{ - test_spr( ); - test_spr( ); - return 0; - -} diff --git a/testcpp/test_spr2.cc b/testcpp/test_spr2.cc deleted file mode 100644 index 24f364b8e1..0000000000 --- a/testcpp/test_spr2.cc +++ /dev/null @@ -1,107 +0,0 @@ -/* - - BLISPP - C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. - - Copyright (C) 2019, Advanced Micro Devices, Inc. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include -#include "blis.hh" -#include "test.hh" - -using namespace blis; -using namespace std; -//#define PRINT -#define ALPHA -1.0f -#define N 2 - -/* - * Test application assumes matrices to be column major, non-transposed - */ - -template< typename T > -void test_spr2( ) -{ - int n; - int incX = -1; - int incY = -1; - T alpha; - - alpha = ALPHA; - n = N; - - T A[] = { 0.493f, -0.175f, -0.831f }; - T X[] = { -0.163f, 0.489f }; - T Y[] = { 0.154f, 0.769f }; - T A_ref[]= { -0.259082f, -0.124959f, -0.780796f }; - - - -#ifdef PRINT - printf("Matrix A\n"); - printmatrix(A, incX, n,n,(char *)"A"); - printf("Vector X \n"); - printvector(X, n, (char *)"X"); -#endif - blis::spr2( - CblasColMajor, - CblasLower, - n, - alpha, - X, - incX, - Y, - incY, - A - ); - -#ifdef PRINT - printf("Matrix A after blis:spr2\n"); - printmatrix (A,1 ,n, n,(char *)"A"); - printf("A_ref \n"); - printmatrix(A_ref, 1, n,n,(char *)"A_ref output"); -#endif - - if(computeErrorM(1, 1, n, n, A, A_ref )==1) - printf("%s TEST FAIL\n" ,__PRETTY_FUNCTION__); - else - printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); - - -} - -// ----------------------------------------------------------------------------- -int main( int argc, char** argv ) -{ - test_spr2( ); - test_spr2( ); - return 0; - -} diff --git a/testcpp/test_swap.cc b/testcpp/test_swap.cc deleted file mode 100644 index 8979d90bdf..0000000000 --- a/testcpp/test_swap.cc +++ /dev/null @@ -1,136 +0,0 @@ -/* - - BLISPP - C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. - - Copyright (C) 2019, Advanced Micro Devices, Inc. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include -#include "blis.hh" -#include "test.hh" - -using namespace blis; -using namespace std; -//#define PRINT -#define N 6 - -/* - * Test application assumes matrices to be column major, non-transposed - */ -template< typename T> -void ref_swap(int64_t n, - T *X, - T *Y - ) - -{ - obj_t obj_x, obj_y; - num_t dt; - - if(is_same::value) - dt = BLIS_FLOAT; - else if(is_same::value) - dt = BLIS_DOUBLE; - else if(is_same>::value) - dt = BLIS_SCOMPLEX; - else if(is_same>::value) - dt = BLIS_DCOMPLEX; - - - bli_obj_create_with_attached_buffer( dt, n, 1, X, 1, n,&obj_x ); - bli_obj_create_with_attached_buffer( dt, n, 1, Y, 1, n,&obj_y ); - - bli_swapv( &obj_x, - &obj_y - ); - -} -template< typename T > -void test_swap( ) -{ - T *X, *X_ref, *Y,*Y_ref; - int n; - int incx, incy; - - n = N; - - incx = 1; - incy = 1; - - srand (time(NULL)); - allocate_init_buffer(X , n , 1); - allocate_init_buffer(Y , n , 1); - copy_buffer(X, X_ref , n ,1); - copy_buffer(Y, Y_ref , n ,1); - -#ifdef PRINT - printvector(X, n, (char *)"X"); - printvector(Y, n, (char *)"Y"); -#endif - blis::swap( - n, - X, - incx, - Y, - incy - ); - -#ifdef PRINT - printvector(X, n, (char *)"X output"); - printvector(Y, n, (char *)"Y output"); -#endif - ref_swap(n , X_ref, Y_ref ); - -#ifdef PRINT - printvector(X_ref, n, (char *)"X ref output"); - printvector(Y_ref, n, (char *)"Y ref output"); -#endif - if((computeErrorV(incy, incy,n, Y, Y_ref )==1)||(computeErrorV(incx, incx, n, X, X_ref )==1)) - printf("%s TEST FAIL\n" , __PRETTY_FUNCTION__); - else - printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); - - - delete[]( X ); - delete[]( Y ); - delete[]( Y_ref ); - delete[]( X_ref ); -} - -// ----------------------------------------------------------------------------- -int main( int argc, char** argv ) -{ - test_swap( ); - test_swap( ); - test_swap>( ); - test_swap>( ); - return 0; - -} diff --git a/testcpp/test_symm.cc b/testcpp/test_symm.cc deleted file mode 100644 index b4e10398ff..0000000000 --- a/testcpp/test_symm.cc +++ /dev/null @@ -1,164 +0,0 @@ -/* - - BLISPP - C++ test driver for BLIS CPP symm routine and reference blis symm routine. - - Copyright (C) 2019, Advanced Micro Devices, Inc. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. 
- - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include -#include "blis.hh" -#include "test.hh" - -using namespace blis; -using namespace std; -//#define PRINT -#define ALPHA 1.0 -#define BETA 0.0 -#define M 5 -#define N 5 -/* - * Test application assumes matrices to be column major, non-transposed - */ -template< typename T > -void ref_symm(int64_t m, int64_t n, - // side_t side, - T * alpha, - T *A, - T *B, - T * beta, - T *C ) - -{ - obj_t obj_a, obj_b, obj_c; - obj_t obj_alpha, obj_beta; - num_t dt; - - if(is_same::value) - dt = BLIS_FLOAT; - else if(is_same::value) - dt = BLIS_DOUBLE; - else if(is_same>::value) - dt = BLIS_SCOMPLEX; - else if(is_same>::value) - dt = BLIS_DCOMPLEX; - - - bli_obj_create_with_attached_buffer( dt, 1, 1, alpha, 1,1,&obj_alpha ); - bli_obj_create_with_attached_buffer( dt, 1, 1, beta, 1,1,&obj_beta ); - bli_obj_create_with_attached_buffer( dt, m, m, A, 1,m,&obj_a ); - bli_obj_create_with_attached_buffer( dt, m, n, B, 1,n,&obj_b ); - bli_obj_create_with_attached_buffer( dt, m, n, C, 1,m,&obj_c ); - - bli_obj_set_struc( BLIS_SYMMETRIC, &obj_a ); - bli_obj_set_uplo( BLIS_LOWER, &obj_a ); - bli_symm( BLIS_LEFT, - &obj_alpha, - &obj_a, - &obj_b, - &obj_beta, - &obj_c ); - -} -template< typename T > -void test_symm( ) -{ - T *A, *B, *C, *C_ref; - T alpha, beta; - int m,n; - int lda, ldb, ldc, ldc_ref; - - alpha = ALPHA; - beta = BETA; - m = M; - n = N; - - lda = m; - ldb = n; - ldc = m; - ldc_ref = m; - - srand (time(NULL)); - allocate_init_buffer(A , m , m); - allocate_init_buffer(B , m , n); - allocate_init_buffer(C , m , n); - copy_buffer(C, C_ref , m ,n); - -#ifdef PRINT - printmatrix(A, lda ,m,m, (char *)"A"); - printmatrix(B, ldb ,m,n, (char *)"B"); - printmatrix(C, ldc ,m,n, (char *)"C"); -#endif - blis::symm( - CblasColMajor, - CblasLeft, - CblasLower, - m, - n, - alpha, - A, - lda, - B, - ldb, - beta, - C, - ldc - ); - -#ifdef PRINT - printmatrix(C, ldc ,m,n, (char *)"C output"); -#endif - // ref_symm(m, n, side, &alpha, A, B, &beta, C_ref); - ref_symm(m, n, &alpha, A, B, &beta, C_ref); - -#ifdef PRINT - printmatrix(C_ref, ldc_ref ,m,n, (char *)"C ref output"); -#endif - if(computeErrorM(ldc, ldc_ref, m, n, C, C_ref )==1) - printf("%s TEST FAIL\n" , __PRETTY_FUNCTION__ ); - else - printf("%s TEST PASS\n" , __PRETTY_FUNCTION__ ); - - - delete[]( A ); - delete[]( B ); - delete[]( C ); - delete[]( C_ref ); -} - -// ----------------------------------------------------------------------------- -int main( int argc, char** argv ) -{ - test_symm( ); - test_symm( ); - test_symm>( ); - test_symm>( ); - return 0; - -} diff --git a/testcpp/test_syr.cc b/testcpp/test_syr.cc deleted file mode 100644 index 327cd93947..0000000000 --- a/testcpp/test_syr.cc +++ /dev/null @@ -1,140 +0,0 @@ -/* - - BLISPP - C++ test driver for BLIS CPP 
gemm routine and reference blis gemm routine. - - Copyright (C) 2019, Advanced Micro Devices, Inc. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include -#include "blis.hh" -#include "test.hh" - -using namespace blis; -using namespace std; -//#define PRINT -#define ALPHA 1.0 -#define N 6 - -/* - * Test application assumes matrices to be column major, non-transposed - */ -template< typename T > -void ref_syr(int64_t n, - T * alpha, - T *X, - T *A ) - -{ - obj_t obj_a; - obj_t obj_x; - obj_t obj_alpha; - num_t dt; - - if(is_same::value) - dt = BLIS_FLOAT; - else if(is_same::value) - dt = BLIS_DOUBLE; - else if(is_same>::value) - dt = BLIS_SCOMPLEX; - else if(is_same>::value) - dt = BLIS_DCOMPLEX; - - bli_obj_create_with_attached_buffer( dt, 1, 1, alpha, 1,1,&obj_alpha ); - bli_obj_create_with_attached_buffer( dt, n, n, A, 1, n, &obj_a ); - bli_obj_create_with_attached_buffer( dt, n, 1, X, 1, n,&obj_x ); - - bli_obj_set_struc( BLIS_SYMMETRIC, &obj_a ); - bli_obj_set_uplo( BLIS_LOWER, &obj_a); - bli_syr( &obj_alpha, - &obj_x, - &obj_a ); - -} -template< typename T > -void test_syr( ) -{ - T *A, *X, *A_ref; - T alpha; - int n; - int lda, incx, lda_ref; - - alpha = ALPHA; - n = N; - - lda = n; - lda_ref = n; - incx = 1; - - srand (time(NULL)); - allocate_init_buffer(A , n , n); - allocate_init_buffer(X , n , 1); - copy_buffer(A, A_ref , n ,n); - -#ifdef PRINT - printmatrix(A, lda ,n,n, (char *)"A"); - printvector(X, n,(char *) "X"); -#endif - blis::syr( - CblasColMajor, - CblasLower, - n, - alpha, - X, - incx, - A, - lda - ); - -#ifdef PRINT - printmatrix(A, lda , n , n,(char *) "A output"); -#endif - ref_syr(n, &alpha, X, A_ref); -#ifdef PRINT - printmatrix(A_ref, lda , n, n, (char *)"A ref output"); -#endif - if(computeErrorM(lda, lda_ref, n, n, A, A_ref )==1) - printf("%s TEST FAIL\n" ,__PRETTY_FUNCTION__); - else - printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); - - - delete[]( A ); - delete[]( X ); - delete[]( A_ref ); -} - -// ----------------------------------------------------------------------------- -int main( int argc, char** argv ) -{ - test_syr( ); - test_syr( 
); - return 0; - -} diff --git a/testcpp/test_syr2.cc b/testcpp/test_syr2.cc deleted file mode 100644 index 165ca146f6..0000000000 --- a/testcpp/test_syr2.cc +++ /dev/null @@ -1,149 +0,0 @@ -/* - - BLISPP - C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. - - Copyright (C) 2019, Advanced Micro Devices, Inc. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- -*/ - -#include -#include -#include "blis.hh" -#include "test.hh" - -using namespace blis; -using namespace std; -//#define PRINT -#define ALPHA 1.0 -#define N 6 - -/* - * Test application assumes matrices to be column major, non-transposed - */ -template< typename T > -void ref_syr2(int64_t n, - T * alpha, - T *X, - T *Y, - T *A ) - -{ - obj_t obj_a; - obj_t obj_x, obj_y; - obj_t obj_alpha; - num_t dt; - - if(is_same::value) - dt = BLIS_FLOAT; - else if(is_same::value) - dt = BLIS_DOUBLE; - else if(is_same>::value) - dt = BLIS_SCOMPLEX; - else if(is_same>::value) - dt = BLIS_DCOMPLEX; - - bli_obj_create_with_attached_buffer( dt, 1, 1, alpha, 1,1,&obj_alpha ); - bli_obj_create_with_attached_buffer( dt, n, n, A, 1, n, &obj_a ); - bli_obj_create_with_attached_buffer( dt, n, 1, X, 1, n,&obj_x ); - bli_obj_create_with_attached_buffer( dt, n, 1, Y, 1, n,&obj_y ); - - bli_obj_set_struc( BLIS_SYMMETRIC, &obj_a ); - bli_obj_set_uplo( BLIS_LOWER, &obj_a); - bli_syr2( &obj_alpha, - &obj_x, - &obj_y, - &obj_a ); - -} -template< typename T > -void test_syr2( ) -{ - T *A, *X, *Y, *A_ref; - T alpha; - int n; - int lda, incx, incy, lda_ref; - - alpha = ALPHA; - n = N; - - lda = n; - lda_ref = n; - incx = 1; - incy = 1; - srand (time(NULL)); - allocate_init_buffer(A , n , n); - allocate_init_buffer(X , n , 1); - allocate_init_buffer(Y , n , 1); - copy_buffer(A, A_ref , n ,n); - -#ifdef PRINT - printmatrix(A, lda ,n,n,(char *) "A"); - printvector(X, n, (char *)"X"); - printvector(Y, n, (char *)"Y"); -#endif - blis::syr2( - CblasColMajor, - CblasLower, - n, - alpha, - X, - incx, - Y, - incy, - A, - lda - ); - -#ifdef PRINT - printmatrix(A, lda , n , n,(char *) "A output"); -#endif - ref_syr2(n, &alpha, X, Y, A_ref); - -#ifdef PRINT - printmatrix(A_ref, lda , n, n, (char *)"A_ref output"); -#endif - if(computeErrorM(lda, lda_ref, n, n, A, A_ref )==1) - printf("%s TEST FAIL\n" ,__PRETTY_FUNCTION__); - else - printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); - - - delete[]( A ); - delete[]( X ); - delete[]( Y ); - delete[]( A_ref ); -} - -// ----------------------------------------------------------------------------- -int main( int argc, char** argv ) -{ - test_syr2( ); - test_syr2( ); - return 0; - -} diff --git a/testcpp/test_syr2k.cc b/testcpp/test_syr2k.cc deleted file mode 100644 index d56ff97a31..0000000000 --- a/testcpp/test_syr2k.cc +++ /dev/null @@ -1,163 +0,0 @@ -/* - - BLISPP - C++ test driver for BLIS CPP syr2k routine and reference blis syr2k routine. - - Copyright (C) 2019, Advanced Micro Devices, Inc. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include -#include "blis.hh" -#include "test.hh" - -using namespace blis; -using namespace std; -//#define PRINT -#define ALPHA 1.0 -#define BETA 0.0 -#define N 6 -#define K 6 - -/* - * Test application assumes matrices to be column major, non-transposed - */ -template< typename T > -void ref_syr2k(int64_t n, int64_t k, - T * alpha, - T *A, - T *B, - T * beta, - T *C ) - -{ - obj_t obj_a, obj_b, obj_c; - obj_t obj_alpha, obj_beta; - num_t dt; - - if(is_same::value) - dt = BLIS_FLOAT; - else if(is_same::value) - dt = BLIS_DOUBLE; - else if(is_same>::value) - dt = BLIS_SCOMPLEX; - else if(is_same>::value) - dt = BLIS_DCOMPLEX; - - bli_obj_create_with_attached_buffer( dt, 1, 1, alpha, 1,1,&obj_alpha ); - bli_obj_create_with_attached_buffer( dt, 1, 1, beta, 1,1,&obj_beta ); - bli_obj_create_with_attached_buffer( dt, n, k, A, 1,n,&obj_a ); - bli_obj_create_with_attached_buffer( dt, k, n, B,1,k,&obj_b ); - bli_obj_create_with_attached_buffer( dt, n, n, C, 1,n,&obj_c ); - - bli_obj_set_struc( BLIS_SYMMETRIC, &obj_c ); - bli_obj_set_uplo( BLIS_LOWER, &obj_c ); - bli_obj_set_conjtrans( BLIS_NO_TRANSPOSE, &obj_c ); - bli_syr2k( &obj_alpha, - &obj_a, - &obj_b, - &obj_beta, - &obj_c ); - -} -template< typename T > -void test_syr2k( ) -{ - T *A, *B, *C, *C_ref; - T alpha; - T beta; - int n,k; - int ldb, lda, ldc, ldc_ref; - - alpha = ALPHA; - beta = BETA; - k = K; - n = N; - - lda = n; - ldb = k; - ldc = n; - ldc_ref = n; - srand (time(NULL)); - allocate_init_buffer(A , n , k); - allocate_init_buffer(B , k , n); - allocate_init_buffer(C , n , n); - copy_buffer(C, C_ref , n ,n); - -#ifdef PRINT - printmatrix(A, lda ,n,k,(char *) "A"); - printmatrix(B, ldb ,k,n,(char *) "B"); - printmatrix(C, ldc ,n,n,(char *) "C"); -#endif - blis::syr2k( - CblasColMajor, - CblasLower, - CblasNoTrans, - n, - k, - alpha, - A, - lda, - B, - ldb, - beta, - C, - ldc - ); - -#ifdef PRINT - printmatrix(C, ldc ,n,n,(char *) "C output"); -#endif - ref_syr2k(n, k, &alpha, A, B, &beta, C_ref); - -#ifdef PRINT - printmatrix(C_ref, ldc_ref ,n,n,(char *) "C ref output"); -#endif - - if(computeErrorM(ldc, ldc_ref, n, n, C, C_ref )==1) - printf("%s TEST FAIL\n" ,__PRETTY_FUNCTION__); - else - printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); - - - delete[]( A ); - delete[]( B ); - delete[]( C ); - delete[]( C_ref ); -} - -// ----------------------------------------------------------------------------- -int main( int argc, char** argv ) -{ - test_syr2k( ); - test_syr2k( ); - test_syr2k>( ); - test_syr2k>( ); - return 0; - -} diff --git a/testcpp/test_syrk.cc b/testcpp/test_syrk.cc deleted file mode 100644 index 3defc22519..0000000000 --- a/testcpp/test_syrk.cc +++ /dev/null @@ -1,152 +0,0 @@ -/* - - BLISPP - C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. - - Copyright (C) 2019, Advanced Micro Devices, Inc. 
- - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include -#include "blis.hh" -#include "test.hh" - -using namespace blis; -using namespace std; -//#define PRINT -#define ALPHA 1.0 -#define BETA 0.0 -#define N 6 -#define K 4 -/* - * Test application assumes matrices to be column major, non-transposed - */ -template< typename T > -void ref_syrk(int64_t n, int64_t k, - T * alpha, - T *A, - T * beta, - T *C ) - -{ - obj_t obj_a,obj_c; - obj_t obj_alpha, obj_beta; - num_t dt; - - if(is_same::value) - dt = BLIS_FLOAT; - else if(is_same::value) - dt = BLIS_DOUBLE; - else if(is_same>::value) - dt = BLIS_SCOMPLEX; - else if(is_same>::value) - dt = BLIS_DCOMPLEX; - - bli_obj_create_with_attached_buffer( dt, 1, 1, alpha, 1,1,&obj_alpha ); - bli_obj_create_with_attached_buffer( dt, 1, 1, beta, 1,1,&obj_beta ); - bli_obj_create_with_attached_buffer( dt, n, k, A, 1,n,&obj_a ); - bli_obj_create_with_attached_buffer( dt, n, n, C, 1,n,&obj_c ); - - bli_obj_set_struc( BLIS_SYMMETRIC, &obj_c ); - bli_obj_set_conjtrans( BLIS_NO_TRANSPOSE, &obj_c ); - bli_obj_set_uplo( BLIS_LOWER, &obj_c ); - bli_syrk( &obj_alpha, - &obj_a, - &obj_beta, - &obj_c ); - -} -template< typename T > -void test_syrk( ) -{ - T *A, *C, *C_ref; - T alpha, beta; - int n,k; - int lda, ldc, ldc_ref; - - alpha = ALPHA; - beta = BETA; - k = K; - n = N; - - lda = n; - ldc = n; - ldc_ref = n; - - srand (time(NULL)); - allocate_init_buffer(A , n , k); - allocate_init_buffer(C , n , n); - copy_buffer(C, C_ref , n ,n); - -#ifdef PRINT - printmatrix(A, lda ,n,k, (char *)"A"); - printmatrix(C, ldc ,n,n, (char *)"C"); -#endif - blis::syrk( - CblasColMajor, - CblasLower, - CblasNoTrans, - n, - k, - alpha, - A, - lda, - beta, - C, - ldc - ); - -#ifdef PRINT - printmatrix(C, ldc ,n,n, (char *)"C output"); -#endif - ref_syrk(n, k, &alpha, A, &beta, C_ref); - -#ifdef PRINT - printmatrix(C_ref, ldc_ref ,n,n, (char *)"C ref output"); -#endif - if(computeErrorM(ldc, ldc_ref, n, n, C, C_ref )==1) - printf("%s TEST FAIL\n" ,__PRETTY_FUNCTION__); - else - printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); - - - delete[]( A ); - delete[]( C ); - 
delete[]( C_ref ); -} - -// ----------------------------------------------------------------------------- -int main( int argc, char** argv ) -{ - test_syrk( ); - test_syrk( ); - test_syrk>( ); - test_syrk>( ); - return 0; - -} diff --git a/testcpp/test_tbmv.cc b/testcpp/test_tbmv.cc deleted file mode 100644 index ba9d565232..0000000000 --- a/testcpp/test_tbmv.cc +++ /dev/null @@ -1,103 +0,0 @@ -/* - - BLISPP - C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. - - Copyright (C) 2019, Advanced Micro Devices, Inc. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include -#include "blis.hh" -#include "test.hh" - -using namespace blis; -using namespace std; -//#define PRINT -//#define PRINT -#define N 3 -#define K 1 -/* - * Test application assumes matrices to be column major, non-transposed - */ - -template< typename T > -void test_tbmv( ) -{ - int n,k,lda; - - k = K; - n = N; - - - lda = n; - T A[] = { 0.439f, -0.484f, -0.952f, -0.508f, 0.381f, -0.889f, -0.192f, -0.279f, -0.155f }; - T X[] = { -0.089f, -0.688f, -0.203f }; - int incX = -1; - T X_ref[] = { -0.24504f, 0.447756f, -0.089117f }; - - -#ifdef PRINT - printmatrix(A, lda ,n,n,(char *)"A"); - printvector(X, n,(char *)"X"); -#endif - blis::tbmv( - CblasColMajor, - CblasLower, - CblasNoTrans, - CblasNonUnit, - n, - k, - A, - lda, - X, - incX - ); - -#ifdef PRINT - printvector(X, n,(char *)"X"); - printvector(X_ref ,n,(char *) "X output"); -#endif - if(computeErrorV(incX, incX, n, X, X_ref )==1) - printf("%s TEST FAIL\n" , __PRETTY_FUNCTION__); - else - printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); - - -} - -// ----------------------------------------------------------------------------- -int main( int argc, char** argv ) -{ - test_tbmv( ); - test_tbmv( ); - test_tbmv>( ); - test_tbmv>( ); - return 0; - -} diff --git a/testcpp/test_tbsv.cc b/testcpp/test_tbsv.cc deleted file mode 100644 index 85bcdb4ffd..0000000000 --- a/testcpp/test_tbsv.cc +++ /dev/null @@ -1,104 +0,0 @@ -/* - - BLISPP - C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. 
- - Copyright (C) 2019, Advanced Micro Devices, Inc. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include -#include "blis.hh" -#include "test.hh" - -using namespace blis; -using namespace std; -//#define PRINT -//#define PRINT -#define K 1 -#define N 3 -/* - * Test application assumes matrices to be column major, non-transposed - */ - -template< typename T > -void test_tbsv( ) -{ - int n,k,lda; - - k = K; - n = N; - lda = n; - - T A[] = { -0.681f, 0.209f, 0.436f, -0.369f, 0.786f, -0.84f, 0.86f, -0.233f, 0.734f }; - T X[] = { -0.305f, 0.61f, -0.831f }; - int incX = -1; - T X_ref[] = { 0.524539f, -0.961964f, 1.22026f }; - - -#ifdef PRINT - printmatrix(A, lda ,n,n,(char *)"A"); - printvector(X, n,(char *) "X"); -#endif - blis::tbsv( - CblasColMajor, - CblasLower, - CblasNoTrans, - CblasNonUnit, - n, - k, - A, - lda, - X, - incX - ); - -#ifdef PRINT - printvector(X, n, (char *)"X blis::tbsv\n"); - printvector(X_ref, n,(char *) "X_ref blis::tbsv output"); - -#endif - - if(computeErrorV(1,1, n, X, X_ref )==1) - printf("%s TEST FAIL\n" , __PRETTY_FUNCTION__); - else - printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); - - -} - -// ----------------------------------------------------------------------------- -int main( int argc, char** argv ) -{ - test_tbsv( ); - test_tbsv( ); - test_tbsv>( ); - test_tbsv>( ); - return 0; - -} diff --git a/testcpp/test_tpmv.cc b/testcpp/test_tpmv.cc deleted file mode 100644 index e2a41d34aa..0000000000 --- a/testcpp/test_tpmv.cc +++ /dev/null @@ -1,84 +0,0 @@ -/* - - BLISPP - C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. - - Copyright (C) 2019, Advanced Micro Devices, Inc. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. 
- - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include -#include "blis.hh" -#include "test.hh" - -using namespace blis; -using namespace std; -#define N 2 -/* - * Test application assumes matrices to be column major, non-transposed - */ -template< typename T > -void test_tpmv( ) -{ - int n; - - n = N; - - T A[] = { -0.587f, 0.14f, 0.841f }; - T X[] = { -0.213f, 0.885f }; - int incX = -1; - T X_ref[] = { -0.055233f, -0.519495f }; - - blis::tpmv( - CblasColMajor, - CblasLower, - CblasNoTrans, - CblasNonUnit, - n, - A, - X, - incX - ); - - if(computeErrorV(incX, incX, n, X, X_ref )==1) - printf("%s TEST FAIL\n" , __PRETTY_FUNCTION__); - else - printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); - -} - -// ----------------------------------------------------------------------------- -int main( int argc, char** argv ) -{ - test_tpmv( ); - test_tpmv( ); - test_tpmv>( ); - test_tpmv>( ); - return 0; - -} diff --git a/testcpp/test_tpsv.cc b/testcpp/test_tpsv.cc deleted file mode 100644 index a9c3c2109f..0000000000 --- a/testcpp/test_tpsv.cc +++ /dev/null @@ -1,87 +0,0 @@ -/* - - BLISPP - C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. - - Copyright (C) 2019, Advanced Micro Devices, Inc. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include -#include "blis.hh" -#include "test.hh" - -using namespace blis; -using namespace std; -#define N 2 -/* - * Test application assumes matrices to be column major, non-transposed - */ -template< typename T > -void test_tpsv( ) -{ - int n; - n = N; - - T A[] = { -0.381f, 0.53f, 0.451f }; - T X[] = { 0.144f, 0.032f }; - int incX = -1; - T X_ref[] = { 0.417992f, -0.0839895f }; - - - - blis::tpsv( - CblasColMajor, - CblasLower, - CblasNoTrans, - CblasNonUnit, - n, - A, - X, - incX - ); - - - if(computeErrorV(1,1, n, X, X_ref )==1) - printf("%s TEST FAIL\n" , __PRETTY_FUNCTION__); - else - printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); - - -} - -// ----------------------------------------------------------------------------- -int main( int argc, char** argv ) -{ - test_tpsv( ); - test_tpsv( ); - test_tpsv>( ); - test_tpsv>( ); - return 0; - -} diff --git a/testcpp/test_trmm.cc b/testcpp/test_trmm.cc deleted file mode 100644 index c6301f0134..0000000000 --- a/testcpp/test_trmm.cc +++ /dev/null @@ -1,153 +0,0 @@ -/* - - BLISPP - C++ test driver for BLIS CPP trmm routine and reference blis trmm routine. - - Copyright (C) 2019, Advanced Micro Devices, Inc. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- -*/ - -#include -#include -#include "blis.hh" -#include "test.hh" - -using namespace blis; -using namespace std; -//#define PRINT -#define ALPHA 1.0 -#define M 6 -#define N 4 -/* - * Test application assumes matrices to be column major, non-transposed - */ -template< typename T > -void ref_trmm(int64_t m, int64_t n, - T * alpha, - T *A, - T *B - ) - -{ - obj_t obj_a, obj_b; - obj_t obj_alpha; - num_t dt; - - if(is_same::value) - dt = BLIS_FLOAT; - else if(is_same::value) - dt = BLIS_DOUBLE; - else if(is_same>::value) - dt = BLIS_SCOMPLEX; - else if(is_same>::value) - dt = BLIS_DCOMPLEX; - - bli_obj_create_with_attached_buffer( dt, 1, 1, alpha, 1,1,&obj_alpha ); - bli_obj_create_with_attached_buffer( dt, m, m, A, 1,m,&obj_a ); - bli_obj_create_with_attached_buffer( dt, m, n, B, 1,m,&obj_b ); - - bli_obj_set_struc( BLIS_TRIANGULAR, &obj_a ); - bli_obj_set_uplo( BLIS_LOWER, &obj_a ); - bli_obj_set_conjtrans( BLIS_NO_TRANSPOSE, &obj_a ); - bli_obj_set_diag( BLIS_NONUNIT_DIAG, &obj_a ); - bli_trmm( BLIS_LEFT, - &obj_alpha, - &obj_a, - &obj_b - ); - -} -template< typename T > -void test_trmm( ) -{ - T *A, *B, *B_ref; - T alpha; - int m,n; - int lda, ldb, ldb_ref; - - alpha = ALPHA; - m = M; - n = N; - - lda = m; - ldb = m; - ldb_ref = m; - - srand (time(NULL)); - allocate_init_buffer(A , m , m); - allocate_init_buffer(B , m , n); - copy_buffer(B, B_ref , m ,n); - -#ifdef PRINT - printmatrix(A, lda ,m,m, (char *)"A"); - printmatrix(B, ldb ,m,n, (char *)"B"); -#endif - blis::trmm( - CblasColMajor, - CblasLeft, - CblasLower, - CblasNoTrans, - CblasNonUnit, - m, - n, - alpha, - A, - lda, - B, - ldb - ); - -#ifdef PRINT - printmatrix(B, ldb ,m,n, (char *)"B output"); -#endif - ref_trmm(m, n, &alpha, A, B_ref); - -#ifdef PRINT - printmatrix(B_ref, ldb_ref ,m,n, (char *)"B ref output"); -#endif - if(computeErrorM(ldb, ldb_ref, m, n, B, B_ref )==1) - printf("%s TEST FAIL\n" , __PRETTY_FUNCTION__); - else - printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); - - - - delete[]( A ); - delete[]( B ); - delete[]( B_ref ); -} - -// ----------------------------------------------------------------------------- -int main( int argc, char** argv ) -{ - test_trmm( ); - test_trmm( ); - test_trmm>( ); - test_trmm>( ); - return 0; - -} diff --git a/testcpp/test_trsm.cc b/testcpp/test_trsm.cc deleted file mode 100644 index 4c5ead3bcf..0000000000 --- a/testcpp/test_trsm.cc +++ /dev/null @@ -1,154 +0,0 @@ -/* - - BLISPP - C++ test driver for BLIS CPP trsm routine and reference blis trsm routine. - - Copyright (C) 2019, Advanced Micro Devices, Inc. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include -#include "blis.hh" -#include "test.hh" - -using namespace blis; -using namespace std; -//#define PRINT -#define ALPHA 1.0 -#define M 5 -#define N 4 -/* - * Test application assumes matrices to be column major, non-transposed - */ -template< typename T > -void ref_trsm(int64_t m, int64_t n, - T * alpha, - T *A, - T *B - ) - -{ - obj_t obj_a, obj_b; - obj_t obj_alpha; - num_t dt; - - if(is_same::value) - dt = BLIS_FLOAT; - else if(is_same::value) - dt = BLIS_DOUBLE; - else if(is_same>::value) - dt = BLIS_SCOMPLEX; - else if(is_same>::value) - dt = BLIS_DCOMPLEX; - - bli_obj_create_with_attached_buffer( dt, 1, 1, alpha, 1,1,&obj_alpha ); - bli_obj_create_with_attached_buffer( dt, m, m, A, 1,m,&obj_a ); - bli_obj_create_with_attached_buffer( dt, m, n, B, 1,m,&obj_b ); - - bli_obj_set_struc( BLIS_TRIANGULAR, &obj_a ); - bli_obj_set_uplo( BLIS_LOWER, &obj_a ); - bli_obj_set_conjtrans( BLIS_NO_TRANSPOSE, &obj_a ); - bli_obj_set_diag( BLIS_NONUNIT_DIAG, &obj_a ); - bli_trsm( BLIS_LEFT, - &obj_alpha, - &obj_a, - &obj_b - ); - -} -template< typename T > -void test_trsm( ) -{ - T *A, *B, *B_ref; - T alpha; - int m,n; - int lda, ldb, ldb_ref; - - alpha = ALPHA; - m = M; - n = N; - - lda = m; - ldb = m; - ldb_ref = m; - - srand (time(NULL)); - allocate_init_buffer(A , m , m); - allocate_init_buffer(B , m , n); - copy_buffer(B, B_ref , m ,n); - -#ifdef PRINT - printmatrix(A, lda ,m,m, (char *)"A"); - printmatrix(B, ldb ,m,n, (char *)"B"); -#endif - - blis::trsm( - CblasColMajor, - CblasLeft, - CblasLower, - CblasNoTrans, - CblasNonUnit, - m, - n, - alpha, - A, - lda, - B, - ldb - ); - -#ifdef PRINT - printmatrix(B, ldb ,m,n, (char *)"B output"); -#endif - ref_trsm(m, n, &alpha, A, B_ref); - -#ifdef PRINT - printmatrix(B_ref, ldb_ref ,m,n, (char *)"B ref output"); -#endif - if(computeErrorM(ldb, ldb_ref, m, n, B, B_ref )==1) - printf("%s TEST FAIL\n" , __PRETTY_FUNCTION__); - else - printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); - - - - delete[]( A ); - delete[]( B ); - delete[]( B_ref ); -} - -// ----------------------------------------------------------------------------- -int main( int argc, char** argv ) -{ - test_trsm( ); - test_trsm( ); - test_trsm>( ); - test_trsm>( ); - return 0; - -} diff --git a/testcpp/test_trsv.cc b/testcpp/test_trsv.cc deleted file mode 100644 index d194f097b7..0000000000 --- a/testcpp/test_trsv.cc +++ /dev/null @@ -1,142 +0,0 @@ -/* - - BLISPP - C++ test driver for BLIS CPP gemm routine and reference blis gemm routine. - - Copyright (C) 2019, Advanced Micro Devices, Inc. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. 
- - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include -#include -#include "blis.hh" -#include "test.hh" - -using namespace blis; -using namespace std; -//#define PRINT -//#define PRINT -#define M 5 -#define N 6 -/* - * Test application assumes matrices to be column major, non-transposed - */ -template< typename T > -void ref_trsv(int64_t n, - T *A, - T *X - ) - -{ - obj_t obj_a, obj_x; - num_t dt; - - if(is_same::value) - dt = BLIS_FLOAT; - else if(is_same::value) - dt = BLIS_DOUBLE; - else if(is_same>::value) - dt = BLIS_SCOMPLEX; - else if(is_same>::value) - dt = BLIS_DCOMPLEX; - - bli_obj_create_with_attached_buffer( dt, n, n, A, 1,n,&obj_a ); - bli_obj_create_with_attached_buffer( dt, n, 1, X, 1,n,&obj_x ); - - bli_obj_set_struc( BLIS_TRIANGULAR, &obj_a ); - bli_obj_set_uplo( BLIS_LOWER, &obj_a ); - bli_obj_set_onlytrans( BLIS_NO_TRANSPOSE, &obj_a ); - bli_obj_set_diag( BLIS_NONUNIT_DIAG, &obj_a ); - bli_trsv( &BLIS_ONE, - &obj_a, - &obj_x - ); - -} -template< typename T > -void test_trsv( ) -{ - T *A, *X, *X_ref; - int n; - int lda, incx, incx_ref; - - n = N; - - lda = n; - incx = 1; - incx_ref = 1; - - srand (time(NULL)); - allocate_init_buffer(A , n , n); - allocate_init_buffer(X , n , 1); - copy_buffer(X, X_ref , n ,1); - -#ifdef PRINT - printmatrix(A, lda ,n,n,(char *) "A"); - printvector(X, n,(char *) "X"); -#endif - blis::trsv( - CblasColMajor, - CblasLower, - CblasNoTrans, - CblasNonUnit, - n, - A, - lda, - X, - incx - ); - -#ifdef PRINT - printvector(X, n,(char *) "X output"); -#endif - ref_trsv(n, A, X_ref); - -#ifdef PRINT - printvector(X_ref, n,(char *) "X ref output"); -#endif - if(computeErrorV(incx, incx_ref, n, X, X_ref )==1) - printf("%s TEST FAIL\n" , __PRETTY_FUNCTION__); - else - printf("%s TEST PASS\n" , __PRETTY_FUNCTION__); - - - delete[]( A ); - delete[]( X ); - delete[]( X_ref ); -} - -// ----------------------------------------------------------------------------- -int main( int argc, char** argv ) -{ - test_trsv( ); - test_trsv( ); - test_trsv>( ); - test_trsv>( ); - return 0; - -} diff --git a/testcpp/CMakeLists.txt b/vendor/testcpp/CMakeLists.txt similarity index 100% rename from testcpp/CMakeLists.txt rename to vendor/testcpp/CMakeLists.txt From 7bbb7ee7f229692814e16c1d0ff3546a5eb961f0 Mon Sep 17 00:00:00 2001 From: Meghana Vankadari Date: Mon, 6 Sep 2021 13:12:44 +0530 Subject: [PATCH 010/243] Added weighted thread 
distribution for SUP GEMMT/SYRK Change-Id: Ia080b8a76e788d923bb3545b3f8f97e39f85cebf --- frame/3/gemmt/bli_gemmt_sup_var1n2m.c | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/frame/3/gemmt/bli_gemmt_sup_var1n2m.c b/frame/3/gemmt/bli_gemmt_sup_var1n2m.c index c45ac56722..ff46d1f52c 100644 --- a/frame/3/gemmt/bli_gemmt_sup_var1n2m.c +++ b/frame/3/gemmt/bli_gemmt_sup_var1n2m.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2020, Advanced Micro Devices, Inc. + Copyright (C) 2020 - 21, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -1569,7 +1569,7 @@ void PASTEMACT(ch,opname,uplo,varname) \ \ /* Compute the JC loop thread range for the current thread. */ \ dim_t jc_start, jc_end; \ - bli_thread_range_sub( thread_jc, n, NR, FALSE, &jc_start, &jc_end ); \ + bli_thread_range_weighted_sub( thread_jc, 0, BLIS_LOWER, m, n, NR, FALSE, &jc_start, &jc_end ); \ const dim_t n_local = jc_end - jc_start; \ \ /* Compute number of primary and leftover components of the JC loop. */ \ @@ -1579,6 +1579,7 @@ void PASTEMACT(ch,opname,uplo,varname) \ dim_t m_off_cblock, n_off_cblock; \ dim_t m_off = 0; \ dim_t n_off = 0; \ + doff_t diagoffc; \ \ /* Loop over the n dimension (NC rows/columns at a time). */ \ /*for ( dim_t jj = 0; jj < jc_iter; jj += 1 )*/ \ @@ -1589,8 +1590,6 @@ void PASTEMACT(ch,opname,uplo,varname) \ \ ctype* restrict b_jc = b_00 + jj * jcstep_b; \ ctype* restrict c_jc = c_00 + jj * jcstep_c; \ -\ - n_off = jj; \ \ /* Grow the thrinfo_t tree. */ \ bszid_t* restrict bszids_pc = &bszids_jc[1]; \ @@ -1617,6 +1616,10 @@ void PASTEMACT(ch,opname,uplo,varname) \ \ /* Only apply beta to the first iteration of the pc loop. */ \ ctype* restrict beta_use = ( pp == 0 ? &beta_local : &one_local ); \ +\ + m_off = 0; \ + n_off = jj; \ + diagoffc = m_off - n_off; \ \ ctype* b_use; \ inc_t rs_b_use, cs_b_use, ps_b_use; \ @@ -1675,7 +1678,7 @@ void PASTEMACT(ch,opname,uplo,varname) \ \ /* Compute the IC loop thread range for the current thread. */ \ dim_t ic_start, ic_end; \ - bli_thread_range_sub( thread_ic, m, MR, FALSE, &ic_start, &ic_end ); \ + bli_thread_range_weighted_sub( thread_ic, -diagoffc, BLIS_UPPER, nc_cur, m, MR, FALSE, &ic_start, &ic_end ); \ const dim_t m_local = ic_end - ic_start; \ \ /* Compute number of primary and leftover components of the IC loop. */ \ @@ -2087,11 +2090,12 @@ void PASTEMACT(ch,opname,uplo,varname) \ \ /* Compute the JC loop thread range for the current thread. */ \ dim_t jc_start, jc_end; \ - bli_thread_range_sub( thread_jc, n, NR, FALSE, &jc_start, &jc_end ); \ + bli_thread_range_weighted_sub( thread_jc, 0, BLIS_UPPER, m, n, NR, FALSE, &jc_start, &jc_end ); \ const dim_t n_local = jc_end - jc_start; \ \ dim_t m_off = 0; \ dim_t n_off = 0; \ + doff_t diagoffc; \ dim_t m_off_cblock, n_off_cblock; \ \ /* Compute number of primary and leftover components of the JC loop. */ \ @@ -2107,8 +2111,6 @@ void PASTEMACT(ch,opname,uplo,varname) \ \ ctype* restrict b_jc = b_00 + jj * jcstep_b; \ ctype* restrict c_jc = c_00 + jj * jcstep_c; \ -\ - n_off = jj; \ \ /* Grow the thrinfo_t tree. */ \ bszid_t* restrict bszids_pc = &bszids_jc[1]; \ @@ -2135,6 +2137,10 @@ void PASTEMACT(ch,opname,uplo,varname) \ \ /* Only apply beta to the first iteration of the pc loop. */ \ ctype* restrict beta_use = ( pp == 0 ?
&beta_local : &one_local ); \ +\ + m_off = 0; \ + n_off = jj; \ + diagoffc = m_off - n_off; \ \ ctype* b_use; \ inc_t rs_b_use, cs_b_use, ps_b_use; \ @@ -2193,7 +2199,7 @@ void PASTEMACT(ch,opname,uplo,varname) \ \ /* Compute the IC loop thread range for the current thread. */ \ dim_t ic_start, ic_end; \ - bli_thread_range_sub( thread_ic, m, MR, FALSE, &ic_start, &ic_end ); \ + bli_thread_range_weighted_sub( thread_ic, -diagoffc, BLIS_LOWER, nc_cur, m, MR, FALSE, &ic_start, &ic_end ); \ const dim_t m_local = ic_end - ic_start; \ \ /* Compute number of primary and leftover components of the IC loop. */ \ From 5c770cafeeaca0a89ba4e5f707d1aefc660811a9 Mon Sep 17 00:00:00 2001 From: Meghana Vankadari Date: Wed, 1 Sep 2021 14:20:17 +0530 Subject: [PATCH 011/243] Removed syrk_small code - The current implementation of syrk_small computes the entire C matrix rather than computing only the triangular part. This implementation is not efficient. AMD-Internal: [CPUPL-1571] Change-Id: I9a153207471a55e52634429062d18ba1a225fed9 --- frame/3/syrk/bli_syrk_front.c | 6 - frame/3/syrk/bli_syrk_front.h | 12 - kernels/zen/3/CMakeLists.txt | 1 - kernels/zen/3/bli_syrk_small.c | 4210 -------------------------------- 4 files changed, 4229 deletions(-) delete mode 100644 kernels/zen/3/bli_syrk_small.c diff --git a/frame/3/syrk/bli_syrk_front.c b/frame/3/syrk/bli_syrk_front.c index 7b8231b352..4b7c8cd75a 100644 --- a/frame/3/syrk/bli_syrk_front.c +++ b/frame/3/syrk/bli_syrk_front.c @@ -61,12 +61,6 @@ void bli_syrk_front bli_obj_alias_to( a, &at_local ); bli_obj_induce_trans( &at_local ); -#ifdef BLIS_ENABLE_SMALL_MATRIX - gint_t status = bli_syrk_small( alpha, &a_local, &at_local, beta, &c_local, - cntx, cntl ); - if ( status == BLIS_SUCCESS ) return; -#endif - // Check parameters. if ( bli_error_checking_is_enabled() ) bli_syrk_check( alpha, a, beta, c, cntx ); diff --git a/frame/3/syrk/bli_syrk_front.h b/frame/3/syrk/bli_syrk_front.h index bf8d26a52c..9c19b0798f 100644 --- a/frame/3/syrk/bli_syrk_front.h +++ b/frame/3/syrk/bli_syrk_front.h @@ -43,16 +43,4 @@ void bli_syrk_front cntl_t* cntl ); -#ifdef BLIS_ENABLE_SMALL_MATRIX -err_t bli_syrk_small - ( - obj_t* alpha, - obj_t* a, - obj_t* b, - obj_t* beta, - obj_t* c, - cntx_t* cntx, - cntl_t* cntl - ); -#endif diff --git a/kernels/zen/3/CMakeLists.txt b/kernels/zen/3/CMakeLists.txt index d27d6cc78d..80f78b471b 100644 --- a/kernels/zen/3/CMakeLists.txt +++ b/kernels/zen/3/CMakeLists.txt @@ -3,7 +3,6 @@ target_sources("${PROJECT_NAME}" PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemm_small.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_syrk_small.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_trsm_small.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_dgemm_ref_k1.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemm_sqp_kernels.c diff --git a/kernels/zen/3/bli_syrk_small.c b/kernels/zen/3/bli_syrk_small.c deleted file mode 100644 index 23d47298c6..0000000000 --- a/kernels/zen/3/bli_syrk_small.c +++ /dev/null @@ -1,4210 +0,0 @@ -/* - -BLIS -An object-based framework for developing high-performance BLAS-like -libraries. - -Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are -met: -- Redistributions of source code must retain the above copyright -notice, this list of conditions and the following disclaimer.
-- Redistributions in binary form must reproduce the above copyright -notice, this list of conditions and the following disclaimer in the -documentation and/or other materials provided with the distribution. -- Neither the name of The University of Texas at Austin nor the names -of its contributors may be used to endorse or promote products -derived from this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -THEORY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "immintrin.h" -#include "xmmintrin.h" -#include "blis.h" - -#ifdef BLIS_ENABLE_SMALL_MATRIX - -#define MR 32 -#define D_MR (MR >> 1) -#define NR 3 - -#define BLIS_ENABLE_PREFETCH -#define F_SCRATCH_DIM (BLIS_SMALL_MATRIX_THRES * BLIS_SMALL_MATRIX_THRES) -static float A_pack[F_SCRATCH_DIM] __attribute__((aligned(64))); -static float C_pack[F_SCRATCH_DIM] __attribute__((aligned(64))); -#define D_BLIS_SMALL_MATRIX_THRES (BLIS_SMALL_MATRIX_THRES / 2 ) -#define D_BLIS_SMALL_M_RECT_MATRIX_THRES (BLIS_SMALL_M_RECT_MATRIX_THRES / 2) -#define D_BLIS_SMALL_K_RECT_MATRIX_THRES (BLIS_SMALL_K_RECT_MATRIX_THRES / 2) -#define D_SCRATCH_DIM (D_BLIS_SMALL_MATRIX_THRES * D_BLIS_SMALL_MATRIX_THRES) -static double D_A_pack[D_SCRATCH_DIM] __attribute__((aligned(64))); -static double D_C_pack[D_SCRATCH_DIM] __attribute__((aligned(64))); -#define BLIS_ATBN_M_THRES 40 // Threshold value of M for/below which small matrix code is called. -#define AT_MR 4 // The kernel dimension of the A transpose SYRK kernel.(AT_MR * NR). -static err_t bli_ssyrk_small - ( - obj_t* alpha, - obj_t* a, - obj_t* b, - obj_t* beta, - obj_t* c, - cntx_t* cntx, - cntl_t* cntl - ); - -static err_t bli_dsyrk_small - ( - obj_t* alpha, - obj_t* a, - obj_t* b, - obj_t* beta, - obj_t* c, - cntx_t* cntx, - cntl_t* cntl - ); - -static err_t bli_ssyrk_small_atbn - ( - obj_t* alpha, - obj_t* a, - obj_t* b, - obj_t* beta, - obj_t* c, - cntx_t* cntx, - cntl_t* cntl - ); - -static err_t bli_dsyrk_small_atbn - ( - obj_t* alpha, - obj_t* a, - obj_t* b, - obj_t* beta, - obj_t* c, - cntx_t* cntx, - cntl_t* cntl - ); -/* -* The bli_syrk_small function will use the -* custom MRxNR kernels, to perform the computation. -* The custom kernels are used if the [M * N] < 240 * 240 -*/ -err_t bli_syrk_small - ( - obj_t* alpha, - obj_t* a, - obj_t* b, - obj_t* beta, - obj_t* c, - cntx_t* cntx, - cntl_t* cntl - ) -{ - // FGVZ: This code was originally in bli_syrk_front(). However, it really - // fits more naturally here within the bli_syrk_small() function. This - // becomes a bit more obvious now that the code is here, as it contains - // cpp macros such as BLIS_SMALL_MATRIX_A_THRES_M_SYRK, which are specific - // to this implementation. - if ( bli_obj_has_trans( a ) ) - { - // Continue with small implementation. 
- ; - } - else if ( ( bli_obj_length( a ) <= BLIS_SMALL_MATRIX_A_THRES_M_SYRK && - bli_obj_width( a ) < BLIS_SMALL_MATRIX_A_THRES_N_SYRK ) || - ( bli_obj_length( a ) < BLIS_SMALL_MATRIX_A_THRES_M_SYRK && - bli_obj_width( a ) <= BLIS_SMALL_MATRIX_A_THRES_N_SYRK ) ) - { - // Continue with small implementation. - ; - } - else - { - // Reject the problem and return to large code path. - return BLIS_FAILURE; - } - -#ifdef BLIS_ENABLE_MULTITHREADING - return BLIS_NOT_YET_IMPLEMENTED; -#endif - // If alpha is zero, scale by beta and return. - if (bli_obj_equals(alpha, &BLIS_ZERO)) - { - return BLIS_NOT_YET_IMPLEMENTED; - } - - // if row major format return. - if ((bli_obj_row_stride( a ) != 1) || - (bli_obj_row_stride( b ) != 1) || - (bli_obj_row_stride( c ) != 1)) - { - return BLIS_INVALID_ROW_STRIDE; - } - - num_t dt = ((*c).info & (0x7 << 0)); - - if (bli_obj_has_trans( a )) - { - if (bli_obj_has_notrans( b )) - { - if (dt == BLIS_FLOAT) - { - return bli_ssyrk_small_atbn(alpha, a, b, beta, c, cntx, cntl); - } - else if (dt == BLIS_DOUBLE) - { - return bli_dsyrk_small_atbn(alpha, a, b, beta, c, cntx, cntl); - } - } - - return BLIS_NOT_YET_IMPLEMENTED; - } - - if (dt == BLIS_DOUBLE) - { - return bli_dsyrk_small(alpha, a, b, beta, c, cntx, cntl); - } - - if (dt == BLIS_FLOAT) - { - return bli_ssyrk_small(alpha, a, b, beta, c, cntx, cntl); - } - - return BLIS_NOT_YET_IMPLEMENTED; -}; - - -static err_t bli_ssyrk_small - ( - obj_t* alpha, - obj_t* a, - obj_t* b, - obj_t* beta, - obj_t* c, - cntx_t* cntx, - cntl_t* cntl - ) -{ - - int M = bli_obj_length( c ); // number of rows of Matrix C - int N = bli_obj_width( c ); // number of columns of Matrix C - int K = bli_obj_width( a ); // number of columns of OP(A), will be updated if OP(A) is Transpose(A) . - int L = M * N; - - if ((((L) < (BLIS_SMALL_MATRIX_THRES * BLIS_SMALL_MATRIX_THRES)) - || ((M < BLIS_SMALL_M_RECT_MATRIX_THRES) && (K < BLIS_SMALL_K_RECT_MATRIX_THRES))) && ((L!=0) && (K!=0))) - { - - int lda = bli_obj_col_stride(a); // column stride of matrix OP(A), where OP(A) is Transpose(A) if transA enabled. - int ldb = bli_obj_col_stride(b); // column stride of matrix OP(B), where OP(B) is Transpose(B) if transB enabled. - int ldc_matC = bli_obj_col_stride( c ); // column stride of matrix C - int ldc = M;//bli_obj_col_stride( c ); // column stride of static buffer for matrix C - int row_idx, col_idx, k; - int rs_matC = bli_obj_row_stride( c ); - int rsc = 1; - float *A = a->buffer; // pointer to elements of Matrix A - float *B = b->buffer; // pointer to elements of Matrix B - float *C = C_pack; // pointer to elements of Matrix C - float *matCbuf = c->buffer; - - float *tA = A, *tB = B, *tC = C;//, *tA_pack; - float *tA_packed; // temprorary pointer to hold packed A memory pointer - int row_idx_packed; //packed A memory row index - int lda_packed; //lda of packed A - int col_idx_start; //starting index after A matrix is packed. 
- dim_t tb_inc_row = 1; // row stride of matrix B - dim_t tb_inc_col = ldb; // column stride of matrix B - __m256 ymm4, ymm5, ymm6, ymm7; - __m256 ymm8, ymm9, ymm10, ymm11; - __m256 ymm12, ymm13, ymm14, ymm15; - __m256 ymm0, ymm1, ymm2, ymm3; - - int n_remainder; // If the N is non multiple of 3.(N%3) - int m_remainder; // If the M is non multiple of 32.(M%32) - - float *alpha_cast, *beta_cast; // alpha, beta multiples - alpha_cast = (alpha->buffer); - beta_cast = (beta->buffer); - int required_packing_A = 1; - - // when N is equal to 1 call GEMV instead of SYRK - if (N == 1) - { - bli_gemv - ( - alpha, - a, - b, - beta, - c - ); - return BLIS_SUCCESS; - } - - //update the pointer math if matrix B needs to be transposed. - if (bli_obj_has_trans( b )) - { - tb_inc_col = 1; //switch row and column strides - tb_inc_row = ldb; - } - - if ((N <= 3) || ((MR * K) > F_SCRATCH_DIM)) - { - required_packing_A = 0; - } - /* - * The computation loop runs for MRxN columns of C matrix, thus - * accessing the MRxK A matrix data and KxNR B matrix data. - * The computation is organized as inner loops of dimension MRxNR. - */ - // Process MR rows of C matrix at a time. - for (row_idx = 0; (row_idx + (MR - 1)) < M; row_idx += MR) - { - - col_idx_start = 0; - tA_packed = A; - row_idx_packed = row_idx; - lda_packed = lda; - - // This is the part of the pack and compute optimization. - // During the first column iteration, we store the accessed A matrix into - // contiguous static memory. This helps to keep te A matrix in Cache and - // aviods the TLB misses. - if (required_packing_A) - { - col_idx = 0; - - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - tA_packed = A_pack; - -#if 0//def BLIS_ENABLE_PREFETCH - _mm_prefetch((char*)(tC + 0), _MM_HINT_T0); - _mm_prefetch((char*)(tC + 16), _MM_HINT_T0); - _mm_prefetch((char*)(tC + ldc), _MM_HINT_T0); - _mm_prefetch((char*)(tC + ldc + 16), _MM_HINT_T0); - _mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0); - _mm_prefetch((char*)(tC + 2 * ldc + 16), _MM_HINT_T0); -#endif - // clear scratch registers. - ymm4 = _mm256_setzero_ps(); - ymm5 = _mm256_setzero_ps(); - ymm6 = _mm256_setzero_ps(); - ymm7 = _mm256_setzero_ps(); - ymm8 = _mm256_setzero_ps(); - ymm9 = _mm256_setzero_ps(); - ymm10 = _mm256_setzero_ps(); - ymm11 = _mm256_setzero_ps(); - ymm12 = _mm256_setzero_ps(); - ymm13 = _mm256_setzero_ps(); - ymm14 = _mm256_setzero_ps(); - ymm15 = _mm256_setzero_ps(); - - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix data and - // multiplies it with the A matrix. - // This loop is processing MR x K - ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0); - ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1); - ymm2 = _mm256_broadcast_ss(tB + tb_inc_col * 2); - tB += tb_inc_row; - - //broadcasted matrix B elements are multiplied - //with matrix A columns. 
- ymm3 = _mm256_loadu_ps(tA); - _mm256_storeu_ps(tA_packed, ymm3); // the packing of matrix A - // ymm4 += ymm0 * ymm3; - ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4); - // ymm8 += ymm1 * ymm3; - ymm8 = _mm256_fmadd_ps(ymm1, ymm3, ymm8); - // ymm12 += ymm2 * ymm3; - ymm12 = _mm256_fmadd_ps(ymm2, ymm3, ymm12); - - ymm3 = _mm256_loadu_ps(tA + 8); - _mm256_storeu_ps(tA_packed + 8, ymm3); // the packing of matrix A - // ymm5 += ymm0 * ymm3; - ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5); - // ymm9 += ymm1 * ymm3; - ymm9 = _mm256_fmadd_ps(ymm1, ymm3, ymm9); - // ymm13 += ymm2 * ymm3; - ymm13 = _mm256_fmadd_ps(ymm2, ymm3, ymm13); - - ymm3 = _mm256_loadu_ps(tA + 16); - _mm256_storeu_ps(tA_packed + 16, ymm3); // the packing of matrix A - // ymm6 += ymm0 * ymm3; - ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6); - // ymm10 += ymm1 * ymm3; - ymm10 = _mm256_fmadd_ps(ymm1, ymm3, ymm10); - // ymm14 += ymm2 * ymm3; - ymm14 = _mm256_fmadd_ps(ymm2, ymm3, ymm14); - - ymm3 = _mm256_loadu_ps(tA + 24); - _mm256_storeu_ps(tA_packed + 24, ymm3); // the packing of matrix A - // ymm7 += ymm0 * ymm3; - ymm7 = _mm256_fmadd_ps(ymm0, ymm3, ymm7); - // ymm11 += ymm1 * ymm3; - ymm11 = _mm256_fmadd_ps(ymm1, ymm3, ymm11); - // ymm15 += ymm2 * ymm3; - ymm15 = _mm256_fmadd_ps(ymm2, ymm3, ymm15); - - tA += lda; - tA_packed += MR; - } - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_ss(alpha_cast); - //ymm1 = _mm256_broadcast_ss(beta_cast); - - //multiply A*B by alpha. - ymm4 = _mm256_mul_ps(ymm4, ymm0); - ymm5 = _mm256_mul_ps(ymm5, ymm0); - ymm6 = _mm256_mul_ps(ymm6, ymm0); - ymm7 = _mm256_mul_ps(ymm7, ymm0); - ymm8 = _mm256_mul_ps(ymm8, ymm0); - ymm9 = _mm256_mul_ps(ymm9, ymm0); - ymm10 = _mm256_mul_ps(ymm10, ymm0); - ymm11 = _mm256_mul_ps(ymm11, ymm0); - ymm12 = _mm256_mul_ps(ymm12, ymm0); - ymm13 = _mm256_mul_ps(ymm13, ymm0); - ymm14 = _mm256_mul_ps(ymm14, ymm0); - ymm15 = _mm256_mul_ps(ymm15, ymm0); - - // multiply C by beta and accumulate col 1. - /*ymm2 = _mm256_loadu_ps(tC); - ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4); - ymm2 = _mm256_loadu_ps(tC + 8); - ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5); - ymm2 = _mm256_loadu_ps(tC + 16); - ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6); - ymm2 = _mm256_loadu_ps(tC + 24); - ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);*/ - _mm256_storeu_ps(tC, ymm4); - _mm256_storeu_ps(tC + 8, ymm5); - _mm256_storeu_ps(tC + 16, ymm6); - _mm256_storeu_ps(tC + 24, ymm7); - - // multiply C by beta and accumulate, col 2. - tC += ldc; - /*ymm2 = _mm256_loadu_ps(tC); - ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8); - ymm2 = _mm256_loadu_ps(tC + 8); - ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9); - ymm2 = _mm256_loadu_ps(tC + 16); - ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10); - ymm2 = _mm256_loadu_ps(tC + 24); - ymm11 = _mm256_fmadd_ps(ymm2, ymm1, ymm11);*/ - _mm256_storeu_ps(tC, ymm8); - _mm256_storeu_ps(tC + 8, ymm9); - _mm256_storeu_ps(tC + 16, ymm10); - _mm256_storeu_ps(tC + 24, ymm11); - - // multiply C by beta and accumulate, col 3. - tC += ldc; - /*ymm2 = _mm256_loadu_ps(tC); - ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12); - ymm2 = _mm256_loadu_ps(tC + 8); - ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13); - ymm2 = _mm256_loadu_ps(tC + 16); - ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14); - ymm2 = _mm256_loadu_ps(tC + 24); - ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15);*/ - _mm256_storeu_ps(tC, ymm12); - _mm256_storeu_ps(tC + 8, ymm13); - _mm256_storeu_ps(tC + 16, ymm14); - _mm256_storeu_ps(tC + 24, ymm15); - - // modify the pointer arithematic to use packed A matrix. 
- col_idx_start = NR; - tA_packed = A_pack; - row_idx_packed = 0; - lda_packed = MR; - } - // Process NR columns of C matrix at a time. - for (col_idx = col_idx_start; (col_idx + (NR - 1)) < N; col_idx += NR) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = tA_packed + row_idx_packed; - -#if 0//def BLIS_ENABLE_PREFETCH - _mm_prefetch((char*)(tC + 0), _MM_HINT_T0); - _mm_prefetch((char*)(tC + 16), _MM_HINT_T0); - _mm_prefetch((char*)(tC + ldc), _MM_HINT_T0); - _mm_prefetch((char*)(tC + ldc + 16), _MM_HINT_T0); - _mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0); - _mm_prefetch((char*)(tC + 2 * ldc + 16), _MM_HINT_T0); -#endif - // clear scratch registers. - ymm4 = _mm256_setzero_ps(); - ymm5 = _mm256_setzero_ps(); - ymm6 = _mm256_setzero_ps(); - ymm7 = _mm256_setzero_ps(); - ymm8 = _mm256_setzero_ps(); - ymm9 = _mm256_setzero_ps(); - ymm10 = _mm256_setzero_ps(); - ymm11 = _mm256_setzero_ps(); - ymm12 = _mm256_setzero_ps(); - ymm13 = _mm256_setzero_ps(); - ymm14 = _mm256_setzero_ps(); - ymm15 = _mm256_setzero_ps(); - - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix data and - // multiplies it with the A matrix. - // This loop is processing MR x K - ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0); - ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1); - ymm2 = _mm256_broadcast_ss(tB + tb_inc_col * 2); - tB += tb_inc_row; - - //broadcasted matrix B elements are multiplied - //with matrix A columns. - ymm3 = _mm256_loadu_ps(tA); - // ymm4 += ymm0 * ymm3; - ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4); - // ymm8 += ymm1 * ymm3; - ymm8 = _mm256_fmadd_ps(ymm1, ymm3, ymm8); - // ymm12 += ymm2 * ymm3; - ymm12 = _mm256_fmadd_ps(ymm2, ymm3, ymm12); - - ymm3 = _mm256_loadu_ps(tA + 8); - // ymm5 += ymm0 * ymm3; - ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5); - // ymm9 += ymm1 * ymm3; - ymm9 = _mm256_fmadd_ps(ymm1, ymm3, ymm9); - // ymm13 += ymm2 * ymm3; - ymm13 = _mm256_fmadd_ps(ymm2, ymm3, ymm13); - - ymm3 = _mm256_loadu_ps(tA + 16); - // ymm6 += ymm0 * ymm3; - ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6); - // ymm10 += ymm1 * ymm3; - ymm10 = _mm256_fmadd_ps(ymm1, ymm3, ymm10); - // ymm14 += ymm2 * ymm3; - ymm14 = _mm256_fmadd_ps(ymm2, ymm3, ymm14); - - ymm3 = _mm256_loadu_ps(tA + 24); - // ymm7 += ymm0 * ymm3; - ymm7 = _mm256_fmadd_ps(ymm0, ymm3, ymm7); - // ymm11 += ymm1 * ymm3; - ymm11 = _mm256_fmadd_ps(ymm1, ymm3, ymm11); - // ymm15 += ymm2 * ymm3; - ymm15 = _mm256_fmadd_ps(ymm2, ymm3, ymm15); - - tA += lda_packed; - } - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_ss(alpha_cast); - //ymm1 = _mm256_broadcast_ss(beta_cast); - - //multiply A*B by alpha. - ymm4 = _mm256_mul_ps(ymm4, ymm0); - ymm5 = _mm256_mul_ps(ymm5, ymm0); - ymm6 = _mm256_mul_ps(ymm6, ymm0); - ymm7 = _mm256_mul_ps(ymm7, ymm0); - ymm8 = _mm256_mul_ps(ymm8, ymm0); - ymm9 = _mm256_mul_ps(ymm9, ymm0); - ymm10 = _mm256_mul_ps(ymm10, ymm0); - ymm11 = _mm256_mul_ps(ymm11, ymm0); - ymm12 = _mm256_mul_ps(ymm12, ymm0); - ymm13 = _mm256_mul_ps(ymm13, ymm0); - ymm14 = _mm256_mul_ps(ymm14, ymm0); - ymm15 = _mm256_mul_ps(ymm15, ymm0); - - // multiply C by beta and accumulate col 1. 
- /*ymm2 = _mm256_loadu_ps(tC); - ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4); - ymm2 = _mm256_loadu_ps(tC + 8); - ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5); - ymm2 = _mm256_loadu_ps(tC + 16); - ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6); - ymm2 = _mm256_loadu_ps(tC + 24); - ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);*/ - _mm256_storeu_ps(tC, ymm4); - _mm256_storeu_ps(tC + 8, ymm5); - _mm256_storeu_ps(tC + 16, ymm6); - _mm256_storeu_ps(tC + 24, ymm7); - - // multiply C by beta and accumulate, col 2. - tC += ldc; - /*ymm2 = _mm256_loadu_ps(tC); - ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8); - ymm2 = _mm256_loadu_ps(tC + 8); - ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9); - ymm2 = _mm256_loadu_ps(tC + 16); - ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10); - ymm2 = _mm256_loadu_ps(tC + 24); - ymm11 = _mm256_fmadd_ps(ymm2, ymm1, ymm11);*/ - _mm256_storeu_ps(tC, ymm8); - _mm256_storeu_ps(tC + 8, ymm9); - _mm256_storeu_ps(tC + 16, ymm10); - _mm256_storeu_ps(tC + 24, ymm11); - - // multiply C by beta and accumulate, col 3. - tC += ldc; - /*ymm2 = _mm256_loadu_ps(tC); - ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12); - ymm2 = _mm256_loadu_ps(tC + 8); - ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13); - ymm2 = _mm256_loadu_ps(tC + 16); - ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14); - ymm2 = _mm256_loadu_ps(tC + 24); - ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15);*/ - _mm256_storeu_ps(tC, ymm12); - _mm256_storeu_ps(tC + 8, ymm13); - _mm256_storeu_ps(tC + 16, ymm14); - _mm256_storeu_ps(tC + 24, ymm15); - - } - n_remainder = N - col_idx; - - // if the N is not multiple of 3. - // handling edge case. - if (n_remainder == 2) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - - // clear scratch registers. - ymm8 = _mm256_setzero_ps(); - ymm9 = _mm256_setzero_ps(); - ymm10 = _mm256_setzero_ps(); - ymm11 = _mm256_setzero_ps(); - ymm12 = _mm256_setzero_ps(); - ymm13 = _mm256_setzero_ps(); - ymm14 = _mm256_setzero_ps(); - ymm15 = _mm256_setzero_ps(); - - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix data and - // multiplies it with the A matrix. - ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0); - ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1); - tB += tb_inc_row; - - //broadcasted matrix B elements are multiplied - //with matrix A columns. - ymm3 = _mm256_loadu_ps(tA); - ymm8 = _mm256_fmadd_ps(ymm0, ymm3, ymm8); - ymm12 = _mm256_fmadd_ps(ymm1, ymm3, ymm12); - - ymm3 = _mm256_loadu_ps(tA + 8); - ymm9 = _mm256_fmadd_ps(ymm0, ymm3, ymm9); - ymm13 = _mm256_fmadd_ps(ymm1, ymm3, ymm13); - - ymm3 = _mm256_loadu_ps(tA + 16); - ymm10 = _mm256_fmadd_ps(ymm0, ymm3, ymm10); - ymm14 = _mm256_fmadd_ps(ymm1, ymm3, ymm14); - - ymm3 = _mm256_loadu_ps(tA + 24); - ymm11 = _mm256_fmadd_ps(ymm0, ymm3, ymm11); - ymm15 = _mm256_fmadd_ps(ymm1, ymm3, ymm15); - - tA += lda; - - } - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_ss(alpha_cast); - //ymm1 = _mm256_broadcast_ss(beta_cast); - - //multiply A*B by alpha. - ymm8 = _mm256_mul_ps(ymm8, ymm0); - ymm9 = _mm256_mul_ps(ymm9, ymm0); - ymm10 = _mm256_mul_ps(ymm10, ymm0); - ymm11 = _mm256_mul_ps(ymm11, ymm0); - ymm12 = _mm256_mul_ps(ymm12, ymm0); - ymm13 = _mm256_mul_ps(ymm13, ymm0); - ymm14 = _mm256_mul_ps(ymm14, ymm0); - ymm15 = _mm256_mul_ps(ymm15, ymm0); - - // multiply C by beta and accumulate, col 1. 
- /*ymm2 = _mm256_loadu_ps(tC + 0); - ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8); - ymm2 = _mm256_loadu_ps(tC + 8); - ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9); - ymm2 = _mm256_loadu_ps(tC + 16); - ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10); - ymm2 = _mm256_loadu_ps(tC + 24); - ymm11 = _mm256_fmadd_ps(ymm2, ymm1, ymm11);*/ - _mm256_storeu_ps(tC + 0, ymm8); - _mm256_storeu_ps(tC + 8, ymm9); - _mm256_storeu_ps(tC + 16, ymm10); - _mm256_storeu_ps(tC + 24, ymm11); - - // multiply C by beta and accumulate, col 2. - tC += ldc; - /*ymm2 = _mm256_loadu_ps(tC); - ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12); - ymm2 = _mm256_loadu_ps(tC + 8); - ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13); - ymm2 = _mm256_loadu_ps(tC + 16); - ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14); - ymm2 = _mm256_loadu_ps(tC + 24); - ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15);*/ - _mm256_storeu_ps(tC, ymm12); - _mm256_storeu_ps(tC + 8, ymm13); - _mm256_storeu_ps(tC + 16, ymm14); - _mm256_storeu_ps(tC + 24, ymm15); - - col_idx += 2; - } - // if the N is not multiple of 3. - // handling edge case. - if (n_remainder == 1) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - - // clear scratch registers. - ymm12 = _mm256_setzero_ps(); - ymm13 = _mm256_setzero_ps(); - ymm14 = _mm256_setzero_ps(); - ymm15 = _mm256_setzero_ps(); - - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix data and - // multiplies it with the A matrix. - ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0); - tB += tb_inc_row; - - //broadcasted matrix B elements are multiplied - //with matrix A columns. - ymm3 = _mm256_loadu_ps(tA); - ymm12 = _mm256_fmadd_ps(ymm0, ymm3, ymm12); - - ymm3 = _mm256_loadu_ps(tA + 8); - ymm13 = _mm256_fmadd_ps(ymm0, ymm3, ymm13); - - ymm3 = _mm256_loadu_ps(tA + 16); - ymm14 = _mm256_fmadd_ps(ymm0, ymm3, ymm14); - - ymm3 = _mm256_loadu_ps(tA + 24); - ymm15 = _mm256_fmadd_ps(ymm0, ymm3, ymm15); - - tA += lda; - - } - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_ss(alpha_cast); - //ymm1 = _mm256_broadcast_ss(beta_cast); - - //multiply A*B by alpha. - ymm12 = _mm256_mul_ps(ymm12, ymm0); - ymm13 = _mm256_mul_ps(ymm13, ymm0); - ymm14 = _mm256_mul_ps(ymm14, ymm0); - ymm15 = _mm256_mul_ps(ymm15, ymm0); - - // multiply C by beta and accumulate. - /*ymm2 = _mm256_loadu_ps(tC + 0); - ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12); - ymm2 = _mm256_loadu_ps(tC + 8); - ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13); - ymm2 = _mm256_loadu_ps(tC + 16); - ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14); - ymm2 = _mm256_loadu_ps(tC + 24); - ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15);*/ - - _mm256_storeu_ps(tC + 0, ymm12); - _mm256_storeu_ps(tC + 8, ymm13); - _mm256_storeu_ps(tC + 16, ymm14); - _mm256_storeu_ps(tC + 24, ymm15); - } - } - - m_remainder = M - row_idx; - - if (m_remainder >= 24) - { - m_remainder -= 24; - - for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - - // clear scratch registers. - ymm4 = _mm256_setzero_ps(); - ymm5 = _mm256_setzero_ps(); - ymm6 = _mm256_setzero_ps(); - ymm8 = _mm256_setzero_ps(); - ymm9 = _mm256_setzero_ps(); - ymm10 = _mm256_setzero_ps(); - ymm12 = _mm256_setzero_ps(); - ymm13 = _mm256_setzero_ps(); - ymm14 = _mm256_setzero_ps(); - - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix data and - // multiplies it with the A matrix. 
- ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0); - ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1); - ymm2 = _mm256_broadcast_ss(tB + tb_inc_col * 2); - tB += tb_inc_row; - - //broadcasted matrix B elements are multiplied - //with matrix A columns. - ymm3 = _mm256_loadu_ps(tA); - // ymm4 += ymm0 * ymm3; - ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4); - // ymm8 += ymm1 * ymm3; - ymm8 = _mm256_fmadd_ps(ymm1, ymm3, ymm8); - // ymm12 += ymm2 * ymm3; - ymm12 = _mm256_fmadd_ps(ymm2, ymm3, ymm12); - - ymm3 = _mm256_loadu_ps(tA + 8); - // ymm5 += ymm0 * ymm3; - ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5); - // ymm9 += ymm1 * ymm3; - ymm9 = _mm256_fmadd_ps(ymm1, ymm3, ymm9); - // ymm13 += ymm2 * ymm3; - ymm13 = _mm256_fmadd_ps(ymm2, ymm3, ymm13); - - ymm3 = _mm256_loadu_ps(tA + 16); - // ymm6 += ymm0 * ymm3; - ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6); - // ymm10 += ymm1 * ymm3; - ymm10 = _mm256_fmadd_ps(ymm1, ymm3, ymm10); - // ymm14 += ymm2 * ymm3; - ymm14 = _mm256_fmadd_ps(ymm2, ymm3, ymm14); - - tA += lda; - } - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_ss(alpha_cast); - //ymm1 = _mm256_broadcast_ss(beta_cast); - - //multiply A*B by alpha. - ymm4 = _mm256_mul_ps(ymm4, ymm0); - ymm5 = _mm256_mul_ps(ymm5, ymm0); - ymm6 = _mm256_mul_ps(ymm6, ymm0); - ymm8 = _mm256_mul_ps(ymm8, ymm0); - ymm9 = _mm256_mul_ps(ymm9, ymm0); - ymm10 = _mm256_mul_ps(ymm10, ymm0); - ymm12 = _mm256_mul_ps(ymm12, ymm0); - ymm13 = _mm256_mul_ps(ymm13, ymm0); - ymm14 = _mm256_mul_ps(ymm14, ymm0); - - // multiply C by beta and accumulate. - /*ymm2 = _mm256_loadu_ps(tC); - ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4); - ymm2 = _mm256_loadu_ps(tC + 8); - ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5); - ymm2 = _mm256_loadu_ps(tC + 16); - ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);*/ - _mm256_storeu_ps(tC, ymm4); - _mm256_storeu_ps(tC + 8, ymm5); - _mm256_storeu_ps(tC + 16, ymm6); - - // multiply C by beta and accumulate. - tC += ldc; - /*ymm2 = _mm256_loadu_ps(tC); - ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8); - ymm2 = _mm256_loadu_ps(tC + 8); - ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9); - ymm2 = _mm256_loadu_ps(tC + 16); - ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10);*/ - _mm256_storeu_ps(tC, ymm8); - _mm256_storeu_ps(tC + 8, ymm9); - _mm256_storeu_ps(tC + 16, ymm10); - - // multiply C by beta and accumulate. - tC += ldc; - /*ymm2 = _mm256_loadu_ps(tC); - ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12); - ymm2 = _mm256_loadu_ps(tC + 8); - ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13); - ymm2 = _mm256_loadu_ps(tC + 16); - ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);*/ - _mm256_storeu_ps(tC, ymm12); - _mm256_storeu_ps(tC + 8, ymm13); - _mm256_storeu_ps(tC + 16, ymm14); - - } - n_remainder = N - col_idx; - // if the N is not multiple of 3. - // handling edge case. - if (n_remainder == 2) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - - // clear scratch registers. - ymm8 = _mm256_setzero_ps(); - ymm9 = _mm256_setzero_ps(); - ymm10 = _mm256_setzero_ps(); - ymm12 = _mm256_setzero_ps(); - ymm13 = _mm256_setzero_ps(); - ymm14 = _mm256_setzero_ps(); - - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix data and - // multiplies it with the A matrix. - ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0); - ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1); - tB += tb_inc_row; - - //broadcasted matrix B elements are multiplied - //with matrix A columns. 
- ymm3 = _mm256_loadu_ps(tA); - ymm8 = _mm256_fmadd_ps(ymm0, ymm3, ymm8); - ymm12 = _mm256_fmadd_ps(ymm1, ymm3, ymm12); - - ymm3 = _mm256_loadu_ps(tA + 8); - ymm9 = _mm256_fmadd_ps(ymm0, ymm3, ymm9); - ymm13 = _mm256_fmadd_ps(ymm1, ymm3, ymm13); - - ymm3 = _mm256_loadu_ps(tA + 16); - ymm10 = _mm256_fmadd_ps(ymm0, ymm3, ymm10); - ymm14 = _mm256_fmadd_ps(ymm1, ymm3, ymm14); - - tA += lda; - - } - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_ss(alpha_cast); - //ymm1 = _mm256_broadcast_ss(beta_cast); - - //multiply A*B by alpha. - ymm8 = _mm256_mul_ps(ymm8, ymm0); - ymm9 = _mm256_mul_ps(ymm9, ymm0); - ymm10 = _mm256_mul_ps(ymm10, ymm0); - ymm12 = _mm256_mul_ps(ymm12, ymm0); - ymm13 = _mm256_mul_ps(ymm13, ymm0); - ymm14 = _mm256_mul_ps(ymm14, ymm0); - - // multiply C by beta and accumulate. - /*ymm2 = _mm256_loadu_ps(tC + 0); - ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8); - ymm2 = _mm256_loadu_ps(tC + 8); - ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9); - ymm2 = _mm256_loadu_ps(tC + 16); - ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10);*/ - _mm256_storeu_ps(tC + 0, ymm8); - _mm256_storeu_ps(tC + 8, ymm9); - _mm256_storeu_ps(tC + 16, ymm10); - - // multiply C by beta and accumulate. - tC += ldc; - /*ymm2 = _mm256_loadu_ps(tC); - ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12); - ymm2 = _mm256_loadu_ps(tC + 8); - ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13); - ymm2 = _mm256_loadu_ps(tC + 16); - ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);*/ - _mm256_storeu_ps(tC, ymm12); - _mm256_storeu_ps(tC + 8, ymm13); - _mm256_storeu_ps(tC + 16, ymm14); - - col_idx += 2; - } - // if the N is not multiple of 3. - // handling edge case. - if (n_remainder == 1) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - - // clear scratch registers. - ymm12 = _mm256_setzero_ps(); - ymm13 = _mm256_setzero_ps(); - ymm14 = _mm256_setzero_ps(); - - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix data and - // multiplies it with the A matrix. - ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0); - tB += tb_inc_row; - - //broadcasted matrix B elements are multiplied - //with matrix A columns. - ymm3 = _mm256_loadu_ps(tA); - ymm12 = _mm256_fmadd_ps(ymm0, ymm3, ymm12); - - ymm3 = _mm256_loadu_ps(tA + 8); - ymm13 = _mm256_fmadd_ps(ymm0, ymm3, ymm13); - - ymm3 = _mm256_loadu_ps(tA + 16); - ymm14 = _mm256_fmadd_ps(ymm0, ymm3, ymm14); - - tA += lda; - - } - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_ss(alpha_cast); - //ymm1 = _mm256_broadcast_ss(beta_cast); - - //multiply A*B by alpha. - ymm12 = _mm256_mul_ps(ymm12, ymm0); - ymm13 = _mm256_mul_ps(ymm13, ymm0); - ymm14 = _mm256_mul_ps(ymm14, ymm0); - - // multiply C by beta and accumulate. - /*ymm2 = _mm256_loadu_ps(tC + 0); - ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12); - ymm2 = _mm256_loadu_ps(tC + 8); - ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13); - ymm2 = _mm256_loadu_ps(tC + 16); - ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);*/ - - _mm256_storeu_ps(tC + 0, ymm12); - _mm256_storeu_ps(tC + 8, ymm13); - _mm256_storeu_ps(tC + 16, ymm14); - } - - row_idx += 24; - } - - if (m_remainder >= 16) - { - m_remainder -= 16; - - for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - - // clear scratch registers. 
- ymm4 = _mm256_setzero_ps(); - ymm5 = _mm256_setzero_ps(); - ymm6 = _mm256_setzero_ps(); - ymm7 = _mm256_setzero_ps(); - ymm8 = _mm256_setzero_ps(); - ymm9 = _mm256_setzero_ps(); - - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix data and - // multiplies it with the A matrix. - ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0); - ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1); - ymm2 = _mm256_broadcast_ss(tB + tb_inc_col * 2); - tB += tb_inc_row; - - //broadcasted matrix B elements are multiplied - //with matrix A columns. - ymm3 = _mm256_loadu_ps(tA); - ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4); - ymm6 = _mm256_fmadd_ps(ymm1, ymm3, ymm6); - ymm8 = _mm256_fmadd_ps(ymm2, ymm3, ymm8); - - ymm3 = _mm256_loadu_ps(tA + 8); - ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5); - ymm7 = _mm256_fmadd_ps(ymm1, ymm3, ymm7); - ymm9 = _mm256_fmadd_ps(ymm2, ymm3, ymm9); - - tA += lda; - } - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_ss(alpha_cast); - //ymm1 = _mm256_broadcast_ss(beta_cast); - - //multiply A*B by alpha. - ymm4 = _mm256_mul_ps(ymm4, ymm0); - ymm5 = _mm256_mul_ps(ymm5, ymm0); - ymm6 = _mm256_mul_ps(ymm6, ymm0); - ymm7 = _mm256_mul_ps(ymm7, ymm0); - ymm8 = _mm256_mul_ps(ymm8, ymm0); - ymm9 = _mm256_mul_ps(ymm9, ymm0); - - // multiply C by beta and accumulate. - /*ymm2 = _mm256_loadu_ps(tC); - ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4); - ymm2 = _mm256_loadu_ps(tC + 8); - ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);*/ - _mm256_storeu_ps(tC, ymm4); - _mm256_storeu_ps(tC + 8, ymm5); - - // multiply C by beta and accumulate. - tC += ldc; - /*ymm2 = _mm256_loadu_ps(tC); - ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6); - ymm2 = _mm256_loadu_ps(tC + 8); - ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);*/ - _mm256_storeu_ps(tC, ymm6); - _mm256_storeu_ps(tC + 8, ymm7); - - // multiply C by beta and accumulate. - tC += ldc; - /*ymm2 = _mm256_loadu_ps(tC); - ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8); - ymm2 = _mm256_loadu_ps(tC + 8); - ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);*/ - _mm256_storeu_ps(tC, ymm8); - _mm256_storeu_ps(tC + 8, ymm9); - - } - n_remainder = N - col_idx; - // if the N is not multiple of 3. - // handling edge case. - if (n_remainder == 2) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - - // clear scratch registers. - ymm4 = _mm256_setzero_ps(); - ymm5 = _mm256_setzero_ps(); - ymm6 = _mm256_setzero_ps(); - ymm7 = _mm256_setzero_ps(); - - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix data and - // multiplies it with the A matrix. - ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0); - ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1); - tB += tb_inc_row; - - //broadcasted matrix B elements are multiplied - //with matrix A columns. - ymm3 = _mm256_loadu_ps(tA); - ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4); - ymm6 = _mm256_fmadd_ps(ymm1, ymm3, ymm6); - - ymm3 = _mm256_loadu_ps(tA + 8); - ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5); - ymm7 = _mm256_fmadd_ps(ymm1, ymm3, ymm7); - - tA += lda; - } - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_ss(alpha_cast); - //ymm1 = _mm256_broadcast_ss(beta_cast); - - //multiply A*B by alpha. - ymm4 = _mm256_mul_ps(ymm4, ymm0); - ymm5 = _mm256_mul_ps(ymm5, ymm0); - ymm6 = _mm256_mul_ps(ymm6, ymm0); - ymm7 = _mm256_mul_ps(ymm7, ymm0); - - // multiply C by beta and accumulate. 
- /*ymm2 = _mm256_loadu_ps(tC); - ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4); - ymm2 = _mm256_loadu_ps(tC + 8); - ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);*/ - _mm256_storeu_ps(tC, ymm4); - _mm256_storeu_ps(tC + 8, ymm5); - - // multiply C by beta and accumulate. - tC += ldc; - /*ymm2 = _mm256_loadu_ps(tC); - ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6); - ymm2 = _mm256_loadu_ps(tC + 8); - ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);*/ - _mm256_storeu_ps(tC, ymm6); - _mm256_storeu_ps(tC + 8, ymm7); - - col_idx += 2; - - } - // if the N is not multiple of 3. - // handling edge case. - if (n_remainder == 1) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - - ymm4 = _mm256_setzero_ps(); - ymm5 = _mm256_setzero_ps(); - - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix data and - // multiplies it with the A matrix. - ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0); - tB += tb_inc_row; - - //broadcasted matrix B elements are multiplied - //with matrix A columns. - ymm3 = _mm256_loadu_ps(tA); - ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4); - - ymm3 = _mm256_loadu_ps(tA + 8); - ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5); - - tA += lda; - } - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_ss(alpha_cast); - //ymm1 = _mm256_broadcast_ss(beta_cast); - - ymm4 = _mm256_mul_ps(ymm4, ymm0); - ymm5 = _mm256_mul_ps(ymm5, ymm0); - - // multiply C by beta and accumulate. - /*ymm2 = _mm256_loadu_ps(tC); - ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4); - ymm2 = _mm256_loadu_ps(tC + 8); - ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);*/ - _mm256_storeu_ps(tC, ymm4); - _mm256_storeu_ps(tC + 8, ymm5); - - } - - row_idx += 16; - } - - if (m_remainder >= 8) - { - m_remainder -= 8; - - for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - - // clear scratch registers. - ymm4 = _mm256_setzero_ps(); - ymm5 = _mm256_setzero_ps(); - ymm6 = _mm256_setzero_ps(); - - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix data and - // multiplies it with the A matrix. - ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0); - ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1); - ymm2 = _mm256_broadcast_ss(tB + tb_inc_col * 2); - tB += tb_inc_row; - - //broadcasted matrix B elements are multiplied - //with matrix A columns. - ymm3 = _mm256_loadu_ps(tA); - ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_ps(ymm1, ymm3, ymm5); - ymm6 = _mm256_fmadd_ps(ymm2, ymm3, ymm6); - - tA += lda; - } - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_ss(alpha_cast); - //ymm1 = _mm256_broadcast_ss(beta_cast); - - //multiply A*B by alpha. - ymm4 = _mm256_mul_ps(ymm4, ymm0); - ymm5 = _mm256_mul_ps(ymm5, ymm0); - ymm6 = _mm256_mul_ps(ymm6, ymm0); - - // multiply C by beta and accumulate. - /*ymm2 = _mm256_loadu_ps(tC); - ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);*/ - _mm256_storeu_ps(tC, ymm4); - - // multiply C by beta and accumulate. - tC += ldc; - /*ymm2 = _mm256_loadu_ps(tC); - ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);*/ - _mm256_storeu_ps(tC, ymm5); - - // multiply C by beta and accumulate. - tC += ldc; - /*ymm2 = _mm256_loadu_ps(tC); - ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);*/ - _mm256_storeu_ps(tC, ymm6); - } - n_remainder = N - col_idx; - // if the N is not multiple of 3. - // handling edge case. 
- if (n_remainder == 2) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - - ymm4 = _mm256_setzero_ps(); - ymm5 = _mm256_setzero_ps(); - - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix data and - // multiplies it with the A matrix. - ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0); - ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1); - tB += tb_inc_row; - - //broadcasted matrix B elements are multiplied - //with matrix A columns. - ymm3 = _mm256_loadu_ps(tA); - ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_ps(ymm1, ymm3, ymm5); - - tA += lda; - } - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_ss(alpha_cast); - //ymm1 = _mm256_broadcast_ss(beta_cast); - - //multiply A*B by alpha. - ymm4 = _mm256_mul_ps(ymm4, ymm0); - ymm5 = _mm256_mul_ps(ymm5, ymm0); - - // multiply C by beta and accumulate. - /*ymm2 = _mm256_loadu_ps(tC); - ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);*/ - _mm256_storeu_ps(tC, ymm4); - - // multiply C by beta and accumulate. - tC += ldc; - /*ymm2 = _mm256_loadu_ps(tC); - ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);*/ - _mm256_storeu_ps(tC, ymm5); - - col_idx += 2; - - } - // if the N is not multiple of 3. - // handling edge case. - if (n_remainder == 1) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - - ymm4 = _mm256_setzero_ps(); - - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix data and - // multiplies it with the A matrix. - ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0); - tB += tb_inc_row; - - //broadcasted matrix B elements are multiplied - //with matrix A columns. - ymm3 = _mm256_loadu_ps(tA); - ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4); - - tA += lda; - } - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_ss(alpha_cast); - //ymm1 = _mm256_broadcast_ss(beta_cast); - - ymm4 = _mm256_mul_ps(ymm4, ymm0); - - // multiply C by beta and accumulate. - /*ymm2 = _mm256_loadu_ps(tC); - ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);*/ - _mm256_storeu_ps(tC, ymm4); - - } - - row_idx += 8; - } - // M is not a multiple of 32. - // The handling of edge case where the remainder - // dimension is less than 8. The padding takes place - // to handle this case. - if ((m_remainder) && (lda > 7)) - { - float f_temp[8]; - - for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - - // clear scratch registers. - ymm5 = _mm256_setzero_ps(); - ymm7 = _mm256_setzero_ps(); - ymm9 = _mm256_setzero_ps(); - - for (k = 0; k < (K - 1); ++k) - { - // The inner loop broadcasts the B matrix data and - // multiplies it with the A matrix. - ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0); - ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1); - ymm2 = _mm256_broadcast_ss(tB + tb_inc_col * 2); - tB += tb_inc_row; - - //broadcasted matrix B elements are multiplied - //with matrix A columns. - ymm3 = _mm256_loadu_ps(tA); - ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5); - ymm7 = _mm256_fmadd_ps(ymm1, ymm3, ymm7); - ymm9 = _mm256_fmadd_ps(ymm2, ymm3, ymm9); - - tA += lda; - } - // alpha, beta multiplication. 
- ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0); - ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1); - ymm2 = _mm256_broadcast_ss(tB + tb_inc_col * 2); - tB += tb_inc_row; - - for (int i = 0; i < m_remainder; i++) - { - f_temp[i] = tA[i]; - } - ymm3 = _mm256_loadu_ps(f_temp); - ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5); - ymm7 = _mm256_fmadd_ps(ymm1, ymm3, ymm7); - ymm9 = _mm256_fmadd_ps(ymm2, ymm3, ymm9); - - ymm0 = _mm256_broadcast_ss(alpha_cast); - //ymm1 = _mm256_broadcast_ss(beta_cast); - - //multiply A*B by alpha. - ymm5 = _mm256_mul_ps(ymm5, ymm0); - ymm7 = _mm256_mul_ps(ymm7, ymm0); - ymm9 = _mm256_mul_ps(ymm9, ymm0); - - - /*for (int i = 0; i < m_remainder; i++) - { - f_temp[i] = tC[i]; - } - ymm2 = _mm256_loadu_ps(f_temp); - ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);*/ - _mm256_storeu_ps(f_temp, ymm5); - for (int i = 0; i < m_remainder; i++) - { - tC[i] = f_temp[i]; - } - - tC += ldc; - /*for (int i = 0; i < m_remainder; i++) - { - f_temp[i] = tC[i]; - } - ymm2 = _mm256_loadu_ps(f_temp); - ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);*/ - _mm256_storeu_ps(f_temp, ymm7); - for (int i = 0; i < m_remainder; i++) - { - tC[i] = f_temp[i]; - } - - tC += ldc; - /*for (int i = 0; i < m_remainder; i++) - { - f_temp[i] = tC[i]; - } - ymm2 = _mm256_loadu_ps(f_temp); - ymm9 = _mm256_fmadd_ps(ymm2, ymm1, ymm9);*/ - _mm256_storeu_ps(f_temp, ymm9); - for (int i = 0; i < m_remainder; i++) - { - tC[i] = f_temp[i]; - } - } - n_remainder = N - col_idx; - // if the N is not multiple of 3. - // handling edge case. - if (n_remainder == 2) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - - ymm5 = _mm256_setzero_ps(); - ymm7 = _mm256_setzero_ps(); - - for (k = 0; k < (K - 1); ++k) - { - // The inner loop broadcasts the B matrix data and - // multiplies it with the A matrix. - ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0); - ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1); - tB += tb_inc_row; - - ymm3 = _mm256_loadu_ps(tA); - ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5); - ymm7 = _mm256_fmadd_ps(ymm1, ymm3, ymm7); - - tA += lda; - } - - ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0); - ymm1 = _mm256_broadcast_ss(tB + tb_inc_col * 1); - tB += tb_inc_row; - - for (int i = 0; i < m_remainder; i++) - { - f_temp[i] = tA[i]; - } - ymm3 = _mm256_loadu_ps(f_temp); - ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5); - ymm7 = _mm256_fmadd_ps(ymm1, ymm3, ymm7); - - ymm0 = _mm256_broadcast_ss(alpha_cast); - //ymm1 = _mm256_broadcast_ss(beta_cast); - - ymm5 = _mm256_mul_ps(ymm5, ymm0); - ymm7 = _mm256_mul_ps(ymm7, ymm0); - - /*for (int i = 0; i < m_remainder; i++) - { - f_temp[i] = tC[i]; - } - ymm2 = _mm256_loadu_ps(f_temp); - ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);*/ - _mm256_storeu_ps(f_temp, ymm5); - for (int i = 0; i < m_remainder; i++) - { - tC[i] = f_temp[i]; - } - - tC += ldc; - /*for (int i = 0; i < m_remainder; i++) - { - f_temp[i] = tC[i]; - } - ymm2 = _mm256_loadu_ps(f_temp); - ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7);*/ - _mm256_storeu_ps(f_temp, ymm7); - for (int i = 0; i < m_remainder; i++) - { - tC[i] = f_temp[i]; - } - } - // if the N is not multiple of 3. - // handling edge case. - if (n_remainder == 1) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - - ymm5 = _mm256_setzero_ps(); - - for (k = 0; k < (K - 1); ++k) - { - // The inner loop broadcasts the B matrix data and - // multiplies it with the A matrix. 
- ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0); - tB += tb_inc_row; - - ymm3 = _mm256_loadu_ps(tA); - ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5); - - tA += lda; - } - - ymm0 = _mm256_broadcast_ss(tB + tb_inc_col * 0); - tB += tb_inc_row; - - for (int i = 0; i < m_remainder; i++) - { - f_temp[i] = tA[i]; - } - ymm3 = _mm256_loadu_ps(f_temp); - ymm5 = _mm256_fmadd_ps(ymm0, ymm3, ymm5); - - ymm0 = _mm256_broadcast_ss(alpha_cast); - //ymm1 = _mm256_broadcast_ss(beta_cast); - - // multiply C by beta and accumulate. - ymm5 = _mm256_mul_ps(ymm5, ymm0); - - /*for (int i = 0; i < m_remainder; i++) - { - f_temp[i] = tC[i]; - } - ymm2 = _mm256_loadu_ps(f_temp); - ymm5 = _mm256_fmadd_ps(ymm2, ymm1, ymm5);*/ - _mm256_storeu_ps(f_temp, ymm5); - for (int i = 0; i < m_remainder; i++) - { - tC[i] = f_temp[i]; - } - } - m_remainder = 0; - } - - if (m_remainder) - { - float result; - for (; row_idx < M; row_idx += 1) - { - for (col_idx = 0; col_idx < N; col_idx += 1) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - - result = 0; - for (k = 0; k < K; ++k) - { - result += (*tA) * (*tB); - tA += lda; - tB += tb_inc_row; - } - - result *= (*alpha_cast); - (*tC) = /*(*tC) * (*beta_cast) + */result; - } - } - } - - //copy/compute sryk values back to C using SIMD - if ( bli_seq0( *beta_cast ) ) - {//just copy in case of beta = 0 - dim_t _i, _j, k, _l; - if(bli_obj_is_lower(c)) // c is lower - { - //first column - _j = 0; - k = M >> 3; - _i = 0; - for ( _l = 0; _l < k; _l++ ) - { - ymm0 = _mm256_loadu_ps((C + _i*rsc)); - _mm256_storeu_ps((matCbuf + _i*rs_matC), ymm0); - _i += 8; - } - while (_i < M ) - { - bli_sscopys( *(C + _i*rsc + _j*ldc), - *(matCbuf + _i*rs_matC + _j*ldc_matC) ); - _i++; - } - _j++; - while ( _j < N ) //next column - { - //k = (_j + (8 - (_j & 7))); - _l = _j & 7; - k = (_l != 0) ? (_j + (8 - _l)) : _j; - k = (k <= M) ? k : M; - for ( _i = _j; _i < k; ++_i ) - { - bli_sscopys( *(C + _i*rsc + _j*ldc), - *(matCbuf + _i*rs_matC + _j*ldc_matC) ); - } - k = (M - _i) >> 3; - _l = 0; - while ( _l < k ) - { - ymm0 = _mm256_loadu_ps((C + _i*rsc + _j*ldc)); - _mm256_storeu_ps((matCbuf + _i*rs_matC + _j*ldc_matC), ymm0); - - _i += 8; - _l++; - } - while (_i < M ) - { - bli_sscopys( *(C + _i*rsc + _j*ldc), - *(matCbuf + _i*rs_matC + _j*ldc_matC) ); - _i++; - } - _j++; - } - } - else //c is upper - { - for ( _j = 0; _j < N; ++_j ) - { - k = (_j + 1) >> 3; - _i = 0; - _l = 0; - while ( _l < k ) - { - ymm0 = _mm256_loadu_ps((C + _i*rsc + _j*ldc)); - _mm256_storeu_ps((matCbuf + _i*rs_matC + _j*ldc_matC), ymm0); - _i += 8; - _l++; - } - while (_i <= _j ) - { - bli_sscopys( *(C + _i*rsc + _j*ldc), - *(matCbuf + _i*rs_matC + _j*ldc_matC) ); - ++_i; - } - } - } - } - else - {//when beta is non-zero, fmadd and store the results - dim_t _i, _j, k, _l; - ymm1 = _mm256_broadcast_ss(beta_cast); - if(bli_obj_is_lower(c)) //c is lower - { - //first column - _j = 0; - k = M >> 3; - _i = 0; - for ( _l = 0; _l < k; _l++ ) - { - ymm2 = _mm256_loadu_ps((matCbuf + _i*rs_matC)); - ymm0 = _mm256_loadu_ps((C + _i*rsc)); - ymm0 = _mm256_fmadd_ps(ymm2, ymm1, ymm0); - _mm256_storeu_ps((matCbuf + _i*rs_matC), ymm0); - _i += 8; - } - while (_i < M ) - { - bli_sssxpbys( *(C + _i*rsc + _j*ldc), - *(beta_cast), - *(matCbuf + _i*rs_matC + _j*ldc_matC) ); - _i++; - } - _j++; - while ( _j < N ) //next column - { - //k = (_j + (8 - (_j & 7))); - _l = _j & 7; - k = (_l != 0) ? (_j + (8 - _l)) : _j; - k = (k <= M) ? 
k : M; - for ( _i = _j; _i < k; ++_i ) - { - bli_sssxpbys( *(C + _i*rsc + _j*ldc), - *(beta_cast), - *(matCbuf + _i*rs_matC + _j*ldc_matC) ); - } - k = (M - _i) >> 3; - _l = 0; - while ( _l < k ) - { - ymm2 = _mm256_loadu_ps((matCbuf + _i*rs_matC + _j*ldc_matC)); - ymm0 = _mm256_loadu_ps((C + _i*rsc + _j*ldc)); - ymm0 = _mm256_fmadd_ps(ymm2, ymm1, ymm0); - _mm256_storeu_ps((matCbuf + _i*rs_matC + _j*ldc_matC), ymm0); - - _i += 8; - _l++; - } - while (_i < M ) - { - bli_sssxpbys( *(C + _i*rsc + _j*ldc), - *(beta_cast), - *(matCbuf + _i*rs_matC + _j*ldc_matC) ); - _i++; - } - _j++; - } - } - else //c is upper - { - for ( _j = 0; _j < N; ++_j ) - { - k = (_j + 1) >> 3; - _i = 0; - _l = 0; - while ( _l < k ) - { - ymm2 = _mm256_loadu_ps((matCbuf + _i*rs_matC + _j*ldc_matC)); - ymm0 = _mm256_loadu_ps((C + _i*rsc + _j*ldc)); - ymm0 = _mm256_fmadd_ps(ymm2, ymm1, ymm0); - _mm256_storeu_ps((matCbuf + _i*rs_matC + _j*ldc_matC), ymm0); - _i += 8; - _l++; - } - while (_i <= _j ) - { - bli_sssxpbys( *(C + _i*rsc + _j*ldc), - *(beta_cast), - *(matCbuf + _i*rs_matC + _j*ldc_matC) ); - ++_i; - } - } - } - } - - return BLIS_SUCCESS; - } - else - return BLIS_NONCONFORMAL_DIMENSIONS; - - -}; - -static err_t bli_dsyrk_small - ( - obj_t* alpha, - obj_t* a, - obj_t* b, - obj_t* beta, - obj_t* c, - cntx_t* cntx, - cntl_t* cntl - ) -{ - - int M = bli_obj_length( c ); // number of rows of Matrix C - int N = bli_obj_width( c ); // number of columns of Matrix C - int K = bli_obj_width( a ); // number of columns of OP(A), will be updated if OP(A) is Transpose(A) . - int L = M * N; - - // If alpha is zero, scale by beta and return. - if ((((L) < (D_BLIS_SMALL_MATRIX_THRES * D_BLIS_SMALL_MATRIX_THRES)) - || ((M < D_BLIS_SMALL_M_RECT_MATRIX_THRES) && (K < D_BLIS_SMALL_K_RECT_MATRIX_THRES))) && ((L!=0) && (K!=0))) - { - - int lda = bli_obj_col_stride( a ); // column stride of matrix OP(A), where OP(A) is Transpose(A) if transA enabled. - int ldb = bli_obj_col_stride( b ); // column stride of matrix OP(B), where OP(B) is Transpose(B) if transB enabled. - int ldc_matC = bli_obj_col_stride( c ); // column stride of matrix C - int ldc = M;//bli_obj_col_stride( c ); // column stride of static buffer for matrix C - int row_idx, col_idx, k; - int rs_matC = bli_obj_row_stride( c ); - int rsc = 1; - double *A = a->buffer; // pointer to elements of Matrix A - double *B = b->buffer; // pointer to elements of Matrix B - double *C = D_C_pack; // pointer to elements of Matrix C - double *matCbuf = c->buffer; - - double *tA = A, *tB = B, *tC = C;//, *tA_pack; - double *tA_packed; // temprorary pointer to hold packed A memory pointer - int row_idx_packed; //packed A memory row index - int lda_packed; //lda of packed A - int col_idx_start; //starting index after A matrix is packed. - dim_t tb_inc_row = 1; // row stride of matrix B - dim_t tb_inc_col = ldb; // column stride of matrix B - __m256d ymm4, ymm5, ymm6, ymm7; - __m256d ymm8, ymm9, ymm10, ymm11; - __m256d ymm12, ymm13, ymm14, ymm15; - __m256d ymm0, ymm1, ymm2, ymm3; - - int n_remainder; // If the N is non multiple of 3.(N%3) - int m_remainder; // If the M is non multiple of 16.(M%16) - - double *alpha_cast, *beta_cast; // alpha, beta multiples - alpha_cast = (alpha->buffer); - beta_cast = (beta->buffer); - int required_packing_A = 1; - - // when N is equal to 1 call GEMV instead of SYRK - if (N == 1) - { - bli_gemv - ( - alpha, - a, - b, - beta, - c - ); - return BLIS_SUCCESS; - } - - //update the pointer math if matrix B needs to be transposed. 
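/*
 * Editorial aside (illustrative sketch, not part of the original kernel): op(B)
 * is addressed through a row increment and a column increment rather than two
 * separate code paths.  For a column-major B with leading dimension ldb, element
 * op(B)(k, j) is read as tB[k*tb_inc_row + j*tb_inc_col], with
 * (tb_inc_row, tb_inc_col) = (1, ldb) in the non-transposed case and (ldb, 1)
 * when B is transposed, exactly as set up just below.  A hypothetical helper
 * showing the same addressing:
 */
static double opB_at(const double *B, int ldb, int transB, int k, int j)
{
    const int tb_inc_row = transB ? ldb : 1;   /* stride between consecutive k */
    const int tb_inc_col = transB ? 1   : ldb; /* stride between consecutive j */
    return B[k*tb_inc_row + j*tb_inc_col];
}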
- if (bli_obj_has_trans( b )) - { - tb_inc_col = 1; //switch row and column strides - tb_inc_row = ldb; - } - - if ((N <= 3) || ((D_MR * K) > D_SCRATCH_DIM)) - { - required_packing_A = 0; - } - /* - * The computation loop runs for D_MRxN columns of C matrix, thus - * accessing the D_MRxK A matrix data and KxNR B matrix data. - * The computation is organized as inner loops of dimension D_MRxNR. - */ - // Process D_MR rows of C matrix at a time. - for (row_idx = 0; (row_idx + (D_MR - 1)) < M; row_idx += D_MR) - { - - col_idx_start = 0; - tA_packed = A; - row_idx_packed = row_idx; - lda_packed = lda; - - // This is the part of the pack and compute optimization. - // During the first column iteration, we store the accessed A matrix into - // contiguous static memory. This helps to keep te A matrix in Cache and - // aviods the TLB misses. - if (required_packing_A) - { - col_idx = 0; - - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - tA_packed = D_A_pack; - -#if 0//def BLIS_ENABLE_PREFETCH - _mm_prefetch((char*)(tC + 0), _MM_HINT_T0); - _mm_prefetch((char*)(tC + 8), _MM_HINT_T0); - _mm_prefetch((char*)(tC + ldc), _MM_HINT_T0); - _mm_prefetch((char*)(tC + ldc + 8), _MM_HINT_T0); - _mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0); - _mm_prefetch((char*)(tC + 2 * ldc + 8), _MM_HINT_T0); -#endif - // clear scratch registers. - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - ymm12 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); - ymm14 = _mm256_setzero_pd(); - ymm15 = _mm256_setzero_pd(); - - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); - ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1); - ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2); - tB += tb_inc_row; - - //broadcasted matrix B elements are multiplied - //with matrix A columns. - ymm3 = _mm256_loadu_pd(tA); - _mm256_storeu_pd(tA_packed, ymm3); // the packing of matrix A - // ymm4 += ymm0 * ymm3; - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - // ymm8 += ymm1 * ymm3; - ymm8 = _mm256_fmadd_pd(ymm1, ymm3, ymm8); - // ymm12 += ymm2 * ymm3; - ymm12 = _mm256_fmadd_pd(ymm2, ymm3, ymm12); - - ymm3 = _mm256_loadu_pd(tA + 4); - _mm256_storeu_pd(tA_packed + 4, ymm3); // the packing of matrix A - // ymm5 += ymm0 * ymm3; - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - // ymm9 += ymm1 * ymm3; - ymm9 = _mm256_fmadd_pd(ymm1, ymm3, ymm9); - // ymm13 += ymm2 * ymm3; - ymm13 = _mm256_fmadd_pd(ymm2, ymm3, ymm13); - - ymm3 = _mm256_loadu_pd(tA + 8); - _mm256_storeu_pd(tA_packed + 8, ymm3); // the packing of matrix A - // ymm6 += ymm0 * ymm3; - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - // ymm10 += ymm1 * ymm3; - ymm10 = _mm256_fmadd_pd(ymm1, ymm3, ymm10); - // ymm14 += ymm2 * ymm3; - ymm14 = _mm256_fmadd_pd(ymm2, ymm3, ymm14); - - ymm3 = _mm256_loadu_pd(tA + 12); - _mm256_storeu_pd(tA_packed + 12, ymm3); // the packing of matrix A - // ymm7 += ymm0 * ymm3; - ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7); - // ymm11 += ymm1 * ymm3; - ymm11 = _mm256_fmadd_pd(ymm1, ymm3, ymm11); - // ymm15 += ymm2 * ymm3; - ymm15 = _mm256_fmadd_pd(ymm2, ymm3, ymm15); - - tA += lda; - tA_packed += D_MR; - } - // alpha, beta multiplication. 
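/*
 * Editorial aside (illustrative sketch, not part of the original kernel): the
 * loop above packs the current D_MR x K panel of A into the contiguous scratch
 * buffer D_A_pack *while* computing the first NR columns of C, so that the
 * remaining column sweeps read the packed, cache/TLB-friendly copy instead of
 * striding through A with lda.  A minimal scalar sketch of just the packing
 * step, with generic names mr/a_pack standing in for D_MR/D_A_pack:
 */
static void pack_a_panel(const double *A, int lda,   /* source panel, column-major      */
                         double *a_pack,             /* destination, mr x K, ld = mr    */
                         int mr, int K)
{
    for (int k = 0; k < K; k++)                      /* one panel column per k          */
        for (int i = 0; i < mr; i++)
            a_pack[i + k*mr] = A[i + k*lda];
}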
- ymm0 = _mm256_broadcast_sd(alpha_cast); - //ymm1 = _mm256_broadcast_sd(beta_cast); - - //multiply A*B by alpha. - ymm4 = _mm256_mul_pd(ymm4, ymm0); - ymm5 = _mm256_mul_pd(ymm5, ymm0); - ymm6 = _mm256_mul_pd(ymm6, ymm0); - ymm7 = _mm256_mul_pd(ymm7, ymm0); - ymm8 = _mm256_mul_pd(ymm8, ymm0); - ymm9 = _mm256_mul_pd(ymm9, ymm0); - ymm10 = _mm256_mul_pd(ymm10, ymm0); - ymm11 = _mm256_mul_pd(ymm11, ymm0); - ymm12 = _mm256_mul_pd(ymm12, ymm0); - ymm13 = _mm256_mul_pd(ymm13, ymm0); - ymm14 = _mm256_mul_pd(ymm14, ymm0); - ymm15 = _mm256_mul_pd(ymm15, ymm0); - - // multiply C by beta and accumulate col 1. - /*ymm2 = _mm256_loadu_pd(tC); - ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); - ymm2 = _mm256_loadu_pd(tC + 4); - ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5); - ymm2 = _mm256_loadu_pd(tC + 8); - ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); - ymm2 = _mm256_loadu_pd(tC + 12); - ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7);*/ - _mm256_storeu_pd(tC, ymm4); - _mm256_storeu_pd(tC + 4, ymm5); - _mm256_storeu_pd(tC + 8, ymm6); - _mm256_storeu_pd(tC + 12, ymm7); - - // multiply C by beta and accumulate, col 2. - tC += ldc; - /*ymm2 = _mm256_loadu_pd(tC); - ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8); - ymm2 = _mm256_loadu_pd(tC + 4); - ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9); - ymm2 = _mm256_loadu_pd(tC + 8); - ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10); - ymm2 = _mm256_loadu_pd(tC + 12); - ymm11 = _mm256_fmadd_pd(ymm2, ymm1, ymm11);*/ - _mm256_storeu_pd(tC, ymm8); - _mm256_storeu_pd(tC + 4, ymm9); - _mm256_storeu_pd(tC + 8, ymm10); - _mm256_storeu_pd(tC + 12, ymm11); - - // multiply C by beta and accumulate, col 3. - tC += ldc; - /*ymm2 = _mm256_loadu_pd(tC); - ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); - ymm2 = _mm256_loadu_pd(tC + 4); - ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); - ymm2 = _mm256_loadu_pd(tC + 8); - ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); - ymm2 = _mm256_loadu_pd(tC + 12); - ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15);*/ - _mm256_storeu_pd(tC, ymm12); - _mm256_storeu_pd(tC + 4, ymm13); - _mm256_storeu_pd(tC + 8, ymm14); - _mm256_storeu_pd(tC + 12, ymm15); - - // modify the pointer arithematic to use packed A matrix. - col_idx_start = NR; - tA_packed = D_A_pack; - row_idx_packed = 0; - lda_packed = D_MR; - } - // Process NR columns of C matrix at a time. - for (col_idx = col_idx_start; (col_idx + (NR - 1)) < N; col_idx += NR) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = tA_packed + row_idx_packed; - -#if 0//def BLIS_ENABLE_PREFETCH - _mm_prefetch((char*)(tC + 0), _MM_HINT_T0); - _mm_prefetch((char*)(tC + 8), _MM_HINT_T0); - _mm_prefetch((char*)(tC + ldc), _MM_HINT_T0); - _mm_prefetch((char*)(tC + ldc + 8), _MM_HINT_T0); - _mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0); - _mm_prefetch((char*)(tC + 2 * ldc + 8), _MM_HINT_T0); -#endif - // clear scratch registers. - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - ymm12 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); - ymm14 = _mm256_setzero_pd(); - ymm15 = _mm256_setzero_pd(); - - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix data and - // multiplies it with the A matrix. 
- // This loop is processing D_MR x K - ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); - ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1); - ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2); - tB += tb_inc_row; - - //broadcasted matrix B elements are multiplied - //with matrix A columns. - ymm3 = _mm256_loadu_pd(tA); - // ymm4 += ymm0 * ymm3; - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - // ymm8 += ymm1 * ymm3; - ymm8 = _mm256_fmadd_pd(ymm1, ymm3, ymm8); - // ymm12 += ymm2 * ymm3; - ymm12 = _mm256_fmadd_pd(ymm2, ymm3, ymm12); - - ymm3 = _mm256_loadu_pd(tA + 4); - // ymm5 += ymm0 * ymm3; - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - // ymm9 += ymm1 * ymm3; - ymm9 = _mm256_fmadd_pd(ymm1, ymm3, ymm9); - // ymm13 += ymm2 * ymm3; - ymm13 = _mm256_fmadd_pd(ymm2, ymm3, ymm13); - - ymm3 = _mm256_loadu_pd(tA + 8); - // ymm6 += ymm0 * ymm3; - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - // ymm10 += ymm1 * ymm3; - ymm10 = _mm256_fmadd_pd(ymm1, ymm3, ymm10); - // ymm14 += ymm2 * ymm3; - ymm14 = _mm256_fmadd_pd(ymm2, ymm3, ymm14); - - ymm3 = _mm256_loadu_pd(tA + 12); - // ymm7 += ymm0 * ymm3; - ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7); - // ymm11 += ymm1 * ymm3; - ymm11 = _mm256_fmadd_pd(ymm1, ymm3, ymm11); - // ymm15 += ymm2 * ymm3; - ymm15 = _mm256_fmadd_pd(ymm2, ymm3, ymm15); - - tA += lda_packed; - } - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_sd(alpha_cast); - //ymm1 = _mm256_broadcast_sd(beta_cast); - - //multiply A*B by alpha. - ymm4 = _mm256_mul_pd(ymm4, ymm0); - ymm5 = _mm256_mul_pd(ymm5, ymm0); - ymm6 = _mm256_mul_pd(ymm6, ymm0); - ymm7 = _mm256_mul_pd(ymm7, ymm0); - ymm8 = _mm256_mul_pd(ymm8, ymm0); - ymm9 = _mm256_mul_pd(ymm9, ymm0); - ymm10 = _mm256_mul_pd(ymm10, ymm0); - ymm11 = _mm256_mul_pd(ymm11, ymm0); - ymm12 = _mm256_mul_pd(ymm12, ymm0); - ymm13 = _mm256_mul_pd(ymm13, ymm0); - ymm14 = _mm256_mul_pd(ymm14, ymm0); - ymm15 = _mm256_mul_pd(ymm15, ymm0); - - // multiply C by beta and accumulate col 1. - /*ymm2 = _mm256_loadu_pd(tC); - ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); - ymm2 = _mm256_loadu_pd(tC + 4); - ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5); - ymm2 = _mm256_loadu_pd(tC + 8); - ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); - ymm2 = _mm256_loadu_pd(tC + 12); - ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7);*/ - _mm256_storeu_pd(tC, ymm4); - _mm256_storeu_pd(tC + 4, ymm5); - _mm256_storeu_pd(tC + 8, ymm6); - _mm256_storeu_pd(tC + 12, ymm7); - - // multiply C by beta and accumulate, col 2. - tC += ldc; - /*ymm2 = _mm256_loadu_pd(tC); - ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8); - ymm2 = _mm256_loadu_pd(tC + 4); - ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9); - ymm2 = _mm256_loadu_pd(tC + 8); - ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10); - ymm2 = _mm256_loadu_pd(tC + 12); - ymm11 = _mm256_fmadd_pd(ymm2, ymm1, ymm11);*/ - _mm256_storeu_pd(tC, ymm8); - _mm256_storeu_pd(tC + 4, ymm9); - _mm256_storeu_pd(tC + 8, ymm10); - _mm256_storeu_pd(tC + 12, ymm11); - - // multiply C by beta and accumulate, col 3. - tC += ldc; - /*ymm2 = _mm256_loadu_pd(tC); - ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); - ymm2 = _mm256_loadu_pd(tC + 4); - ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); - ymm2 = _mm256_loadu_pd(tC + 8); - ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); - ymm2 = _mm256_loadu_pd(tC + 12); - ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15);*/ - _mm256_storeu_pd(tC, ymm12); - _mm256_storeu_pd(tC + 4, ymm13); - _mm256_storeu_pd(tC + 8, ymm14); - _mm256_storeu_pd(tC + 12, ymm15); - - } - n_remainder = N - col_idx; - - // if the N is not multiple of 3. - // handling edge case. 
- if (n_remainder == 2) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - - // clear scratch registers. - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - ymm12 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); - ymm14 = _mm256_setzero_pd(); - ymm15 = _mm256_setzero_pd(); - - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix data and - // multiplies it with the A matrix. - ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); - ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1); - tB += tb_inc_row; - - //broadcasted matrix B elements are multiplied - //with matrix A columns. - ymm3 = _mm256_loadu_pd(tA); - ymm8 = _mm256_fmadd_pd(ymm0, ymm3, ymm8); - ymm12 = _mm256_fmadd_pd(ymm1, ymm3, ymm12); - - ymm3 = _mm256_loadu_pd(tA + 4); - ymm9 = _mm256_fmadd_pd(ymm0, ymm3, ymm9); - ymm13 = _mm256_fmadd_pd(ymm1, ymm3, ymm13); - - ymm3 = _mm256_loadu_pd(tA + 8); - ymm10 = _mm256_fmadd_pd(ymm0, ymm3, ymm10); - ymm14 = _mm256_fmadd_pd(ymm1, ymm3, ymm14); - - ymm3 = _mm256_loadu_pd(tA + 12); - ymm11 = _mm256_fmadd_pd(ymm0, ymm3, ymm11); - ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); - - tA += lda; - - } - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_sd(alpha_cast); - //ymm1 = _mm256_broadcast_sd(beta_cast); - - //multiply A*B by alpha. - ymm8 = _mm256_mul_pd(ymm8, ymm0); - ymm9 = _mm256_mul_pd(ymm9, ymm0); - ymm10 = _mm256_mul_pd(ymm10, ymm0); - ymm11 = _mm256_mul_pd(ymm11, ymm0); - ymm12 = _mm256_mul_pd(ymm12, ymm0); - ymm13 = _mm256_mul_pd(ymm13, ymm0); - ymm14 = _mm256_mul_pd(ymm14, ymm0); - ymm15 = _mm256_mul_pd(ymm15, ymm0); - - // multiply C by beta and accumulate, col 1. - /*ymm2 = _mm256_loadu_pd(tC + 0); - ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8); - ymm2 = _mm256_loadu_pd(tC + 4); - ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9); - ymm2 = _mm256_loadu_pd(tC + 8); - ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10); - ymm2 = _mm256_loadu_pd(tC + 12); - ymm11 = _mm256_fmadd_pd(ymm2, ymm1, ymm11);*/ - _mm256_storeu_pd(tC + 0, ymm8); - _mm256_storeu_pd(tC + 4, ymm9); - _mm256_storeu_pd(tC + 8, ymm10); - _mm256_storeu_pd(tC + 12, ymm11); - - // multiply C by beta and accumulate, col 2. - tC += ldc; - /*ymm2 = _mm256_loadu_pd(tC); - ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); - ymm2 = _mm256_loadu_pd(tC + 4); - ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); - ymm2 = _mm256_loadu_pd(tC + 8); - ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); - ymm2 = _mm256_loadu_pd(tC + 12); - ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15);*/ - _mm256_storeu_pd(tC, ymm12); - _mm256_storeu_pd(tC + 4, ymm13); - _mm256_storeu_pd(tC + 8, ymm14); - _mm256_storeu_pd(tC + 12, ymm15); - - col_idx += 2; - } - // if the N is not multiple of 3. - // handling edge case. - if (n_remainder == 1) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - - // clear scratch registers. - ymm12 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); - ymm14 = _mm256_setzero_pd(); - ymm15 = _mm256_setzero_pd(); - - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix data and - // multiplies it with the A matrix. - ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); - tB += tb_inc_row; - - //broadcasted matrix B elements are multiplied - //with matrix A columns. 
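/*
 * Editorial aside (illustrative sketch, not part of the original kernel): each
 * iteration of the K-loops in these blocks is a rank-1 update -- one element of
 * op(B) is broadcast and multiplied with a strip of an A column, accumulating
 * into the per-column C accumulators.  The scalar equivalent of what the
 * broadcast + FMA intrinsics compute per k:
 */
static void rank1_update(double *c_acc, int ldacc,   /* mr x nr accumulator block            */
                         const double *a_col,        /* mr elements of A(:, k)               */
                         const double *b_row,        /* nr elements, b_row[j] = op(B)(k, j)  */
                         int mr, int nr)
{
    for (int j = 0; j < nr; j++)
        for (int i = 0; i < mr; i++)
            c_acc[i + j*ldacc] += b_row[j] * a_col[i];   /* one fmadd per element */
}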
- ymm3 = _mm256_loadu_pd(tA); - ymm12 = _mm256_fmadd_pd(ymm0, ymm3, ymm12); - - ymm3 = _mm256_loadu_pd(tA + 4); - ymm13 = _mm256_fmadd_pd(ymm0, ymm3, ymm13); - - ymm3 = _mm256_loadu_pd(tA + 8); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - - ymm3 = _mm256_loadu_pd(tA + 12); - ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); - - tA += lda; - - } - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_sd(alpha_cast); - //ymm1 = _mm256_broadcast_sd(beta_cast); - - //multiply A*B by alpha. - ymm12 = _mm256_mul_pd(ymm12, ymm0); - ymm13 = _mm256_mul_pd(ymm13, ymm0); - ymm14 = _mm256_mul_pd(ymm14, ymm0); - ymm15 = _mm256_mul_pd(ymm15, ymm0); - - // multiply C by beta and accumulate. - /*ymm2 = _mm256_loadu_pd(tC + 0); - ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); - ymm2 = _mm256_loadu_pd(tC + 4); - ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); - ymm2 = _mm256_loadu_pd(tC + 8); - ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); - ymm2 = _mm256_loadu_pd(tC + 12); - ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15);*/ - - _mm256_storeu_pd(tC + 0, ymm12); - _mm256_storeu_pd(tC + 4, ymm13); - _mm256_storeu_pd(tC + 8, ymm14); - _mm256_storeu_pd(tC + 12, ymm15); - } - } - - m_remainder = M - row_idx; - - if (m_remainder >= 12) - { - m_remainder -= 12; - - for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - - // clear scratch registers. - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm12 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); - ymm14 = _mm256_setzero_pd(); - - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix data and - // multiplies it with the A matrix. - ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); - ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1); - ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2); - tB += tb_inc_row; - - //broadcasted matrix B elements are multiplied - //with matrix A columns. - ymm3 = _mm256_loadu_pd(tA); - // ymm4 += ymm0 * ymm3; - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - // ymm8 += ymm1 * ymm3; - ymm8 = _mm256_fmadd_pd(ymm1, ymm3, ymm8); - // ymm12 += ymm2 * ymm3; - ymm12 = _mm256_fmadd_pd(ymm2, ymm3, ymm12); - - ymm3 = _mm256_loadu_pd(tA + 4); - // ymm5 += ymm0 * ymm3; - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - // ymm9 += ymm1 * ymm3; - ymm9 = _mm256_fmadd_pd(ymm1, ymm3, ymm9); - // ymm13 += ymm2 * ymm3; - ymm13 = _mm256_fmadd_pd(ymm2, ymm3, ymm13); - - ymm3 = _mm256_loadu_pd(tA + 8); - // ymm6 += ymm0 * ymm3; - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - // ymm10 += ymm1 * ymm3; - ymm10 = _mm256_fmadd_pd(ymm1, ymm3, ymm10); - // ymm14 += ymm2 * ymm3; - ymm14 = _mm256_fmadd_pd(ymm2, ymm3, ymm14); - - tA += lda; - } - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_sd(alpha_cast); - //ymm1 = _mm256_broadcast_sd(beta_cast); - - //multiply A*B by alpha. - ymm4 = _mm256_mul_pd(ymm4, ymm0); - ymm5 = _mm256_mul_pd(ymm5, ymm0); - ymm6 = _mm256_mul_pd(ymm6, ymm0); - ymm8 = _mm256_mul_pd(ymm8, ymm0); - ymm9 = _mm256_mul_pd(ymm9, ymm0); - ymm10 = _mm256_mul_pd(ymm10, ymm0); - ymm12 = _mm256_mul_pd(ymm12, ymm0); - ymm13 = _mm256_mul_pd(ymm13, ymm0); - ymm14 = _mm256_mul_pd(ymm14, ymm0); - - // multiply C by beta and accumulate. 
- /*ymm2 = _mm256_loadu_pd(tC); - ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); - ymm2 = _mm256_loadu_pd(tC + 4); - ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5); - ymm2 = _mm256_loadu_pd(tC + 8); - ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6);*/ - _mm256_storeu_pd(tC, ymm4); - _mm256_storeu_pd(tC + 4, ymm5); - _mm256_storeu_pd(tC + 8, ymm6); - - // multiply C by beta and accumulate. - tC += ldc; - /*ymm2 = _mm256_loadu_pd(tC); - ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8); - ymm2 = _mm256_loadu_pd(tC + 4); - ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9); - ymm2 = _mm256_loadu_pd(tC + 8); - ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10);*/ - _mm256_storeu_pd(tC, ymm8); - _mm256_storeu_pd(tC + 4, ymm9); - _mm256_storeu_pd(tC + 8, ymm10); - - // multiply C by beta and accumulate. - tC += ldc; - /*ymm2 = _mm256_loadu_pd(tC); - ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); - ymm2 = _mm256_loadu_pd(tC + 4); - ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); - ymm2 = _mm256_loadu_pd(tC + 8); - ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14);*/ - _mm256_storeu_pd(tC, ymm12); - _mm256_storeu_pd(tC + 4, ymm13); - _mm256_storeu_pd(tC + 8, ymm14); - - } - n_remainder = N - col_idx; - // if the N is not multiple of 3. - // handling edge case. - if (n_remainder == 2) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - - // clear scratch registers. - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm12 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); - ymm14 = _mm256_setzero_pd(); - - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix data and - // multiplies it with the A matrix. - ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); - ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1); - tB += tb_inc_row; - - //broadcasted matrix B elements are multiplied - //with matrix A columns. - ymm3 = _mm256_loadu_pd(tA); - ymm8 = _mm256_fmadd_pd(ymm0, ymm3, ymm8); - ymm12 = _mm256_fmadd_pd(ymm1, ymm3, ymm12); - - ymm3 = _mm256_loadu_pd(tA + 4); - ymm9 = _mm256_fmadd_pd(ymm0, ymm3, ymm9); - ymm13 = _mm256_fmadd_pd(ymm1, ymm3, ymm13); - - ymm3 = _mm256_loadu_pd(tA + 8); - ymm10 = _mm256_fmadd_pd(ymm0, ymm3, ymm10); - ymm14 = _mm256_fmadd_pd(ymm1, ymm3, ymm14); - - tA += lda; - - } - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_sd(alpha_cast); - //ymm1 = _mm256_broadcast_sd(beta_cast); - - //multiply A*B by alpha. - ymm8 = _mm256_mul_pd(ymm8, ymm0); - ymm9 = _mm256_mul_pd(ymm9, ymm0); - ymm10 = _mm256_mul_pd(ymm10, ymm0); - ymm12 = _mm256_mul_pd(ymm12, ymm0); - ymm13 = _mm256_mul_pd(ymm13, ymm0); - ymm14 = _mm256_mul_pd(ymm14, ymm0); - - // multiply C by beta and accumulate. - /*ymm2 = _mm256_loadu_pd(tC + 0); - ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8); - ymm2 = _mm256_loadu_pd(tC + 4); - ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9); - ymm2 = _mm256_loadu_pd(tC + 8); - ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10);*/ - _mm256_storeu_pd(tC + 0, ymm8); - _mm256_storeu_pd(tC + 4, ymm9); - _mm256_storeu_pd(tC + 8, ymm10); - - // multiply C by beta and accumulate. - tC += ldc; - /*ymm2 = _mm256_loadu_pd(tC); - ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); - ymm2 = _mm256_loadu_pd(tC + 4); - ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); - ymm2 = _mm256_loadu_pd(tC + 8); - ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14);*/ - _mm256_storeu_pd(tC, ymm12); - _mm256_storeu_pd(tC + 4, ymm13); - _mm256_storeu_pd(tC + 8, ymm14); - - col_idx += 2; - } - // if the N is not multiple of 3. - // handling edge case. 
- if (n_remainder == 1) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - - // clear scratch registers. - ymm12 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); - ymm14 = _mm256_setzero_pd(); - - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix data and - // multiplies it with the A matrix. - ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); - tB += tb_inc_row; - - //broadcasted matrix B elements are multiplied - //with matrix A columns. - ymm3 = _mm256_loadu_pd(tA); - ymm12 = _mm256_fmadd_pd(ymm0, ymm3, ymm12); - - ymm3 = _mm256_loadu_pd(tA + 4); - ymm13 = _mm256_fmadd_pd(ymm0, ymm3, ymm13); - - ymm3 = _mm256_loadu_pd(tA + 8); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - - tA += lda; - - } - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_sd(alpha_cast); - //ymm1 = _mm256_broadcast_sd(beta_cast); - - //multiply A*B by alpha. - ymm12 = _mm256_mul_pd(ymm12, ymm0); - ymm13 = _mm256_mul_pd(ymm13, ymm0); - ymm14 = _mm256_mul_pd(ymm14, ymm0); - - // multiply C by beta and accumulate. - /*ymm2 = _mm256_loadu_pd(tC + 0); - ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); - ymm2 = _mm256_loadu_pd(tC + 4); - ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); - ymm2 = _mm256_loadu_pd(tC + 8); - ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14);*/ - - _mm256_storeu_pd(tC + 0, ymm12); - _mm256_storeu_pd(tC + 4, ymm13); - _mm256_storeu_pd(tC + 8, ymm14); - } - - row_idx += 12; - } - - if (m_remainder >= 8) - { - m_remainder -= 8; - - for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - - // clear scratch registers. - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix data and - // multiplies it with the A matrix. - ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); - ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1); - ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2); - tB += tb_inc_row; - - //broadcasted matrix B elements are multiplied - //with matrix A columns. - ymm3 = _mm256_loadu_pd(tA); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm6 = _mm256_fmadd_pd(ymm1, ymm3, ymm6); - ymm8 = _mm256_fmadd_pd(ymm2, ymm3, ymm8); - - ymm3 = _mm256_loadu_pd(tA + 4); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - ymm9 = _mm256_fmadd_pd(ymm2, ymm3, ymm9); - - tA += lda; - } - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_sd(alpha_cast); - //ymm1 = _mm256_broadcast_sd(beta_cast); - - //multiply A*B by alpha. - ymm4 = _mm256_mul_pd(ymm4, ymm0); - ymm5 = _mm256_mul_pd(ymm5, ymm0); - ymm6 = _mm256_mul_pd(ymm6, ymm0); - ymm7 = _mm256_mul_pd(ymm7, ymm0); - ymm8 = _mm256_mul_pd(ymm8, ymm0); - ymm9 = _mm256_mul_pd(ymm9, ymm0); - - // multiply C by beta and accumulate. - /*ymm2 = _mm256_loadu_pd(tC); - ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); - ymm2 = _mm256_loadu_pd(tC + 4); - ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);*/ - _mm256_storeu_pd(tC, ymm4); - _mm256_storeu_pd(tC + 4, ymm5); - - // multiply C by beta and accumulate. 
- tC += ldc; - /*ymm2 = _mm256_loadu_pd(tC); - ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); - ymm2 = _mm256_loadu_pd(tC + 4); - ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7);*/ - _mm256_storeu_pd(tC, ymm6); - _mm256_storeu_pd(tC + 4, ymm7); - - // multiply C by beta and accumulate. - tC += ldc; - /*ymm2 = _mm256_loadu_pd(tC); - ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8); - ymm2 = _mm256_loadu_pd(tC + 4); - ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9);*/ - _mm256_storeu_pd(tC, ymm8); - _mm256_storeu_pd(tC + 4, ymm9); - - } - n_remainder = N - col_idx; - // if the N is not multiple of 3. - // handling edge case. - if (n_remainder == 2) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - - // clear scratch registers. - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix data and - // multiplies it with the A matrix. - ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); - ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1); - tB += tb_inc_row; - - //broadcasted matrix B elements are multiplied - //with matrix A columns. - ymm3 = _mm256_loadu_pd(tA); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm6 = _mm256_fmadd_pd(ymm1, ymm3, ymm6); - - ymm3 = _mm256_loadu_pd(tA + 4); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - tA += lda; - } - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_sd(alpha_cast); - //ymm1 = _mm256_broadcast_sd(beta_cast); - - //multiply A*B by alpha. - ymm4 = _mm256_mul_pd(ymm4, ymm0); - ymm5 = _mm256_mul_pd(ymm5, ymm0); - ymm6 = _mm256_mul_pd(ymm6, ymm0); - ymm7 = _mm256_mul_pd(ymm7, ymm0); - - // multiply C by beta and accumulate. - /*ymm2 = _mm256_loadu_pd(tC); - ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); - ymm2 = _mm256_loadu_pd(tC + 4); - ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);*/ - _mm256_storeu_pd(tC, ymm4); - _mm256_storeu_pd(tC + 4, ymm5); - - // multiply C by beta and accumulate. - tC += ldc; - /*ymm2 = _mm256_loadu_pd(tC); - ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); - ymm2 = _mm256_loadu_pd(tC + 4); - ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7);*/ - _mm256_storeu_pd(tC, ymm6); - _mm256_storeu_pd(tC + 4, ymm7); - - col_idx += 2; - - } - // if the N is not multiple of 3. - // handling edge case. - if (n_remainder == 1) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix data and - // multiplies it with the A matrix. - ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); - tB += tb_inc_row; - - //broadcasted matrix B elements are multiplied - //with matrix A columns. - ymm3 = _mm256_loadu_pd(tA); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - ymm3 = _mm256_loadu_pd(tA + 4); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - - tA += lda; - } - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_sd(alpha_cast); - //ymm1 = _mm256_broadcast_sd(beta_cast); - - ymm4 = _mm256_mul_pd(ymm4, ymm0); - ymm5 = _mm256_mul_pd(ymm5, ymm0); - - // multiply C by beta and accumulate. 
- /*ymm2 = _mm256_loadu_pd(tC); - ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); - ymm2 = _mm256_loadu_pd(tC + 4); - ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);*/ - _mm256_storeu_pd(tC, ymm4); - _mm256_storeu_pd(tC + 4, ymm5); - - } - - row_idx += 8; - } - - if (m_remainder >= 4) - { - m_remainder -= 4; - - for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - - // clear scratch registers. - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix data and - // multiplies it with the A matrix. - ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); - ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1); - ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2); - tB += tb_inc_row; - - //broadcasted matrix B elements are multiplied - //with matrix A columns. - ymm3 = _mm256_loadu_pd(tA); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - ymm6 = _mm256_fmadd_pd(ymm2, ymm3, ymm6); - - tA += lda; - } - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_sd(alpha_cast); - //ymm1 = _mm256_broadcast_sd(beta_cast); - - //multiply A*B by alpha. - ymm4 = _mm256_mul_pd(ymm4, ymm0); - ymm5 = _mm256_mul_pd(ymm5, ymm0); - ymm6 = _mm256_mul_pd(ymm6, ymm0); - - // multiply C by beta and accumulate. - /*ymm2 = _mm256_loadu_pd(tC); - ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);*/ - _mm256_storeu_pd(tC, ymm4); - - // multiply C by beta and accumulate. - tC += ldc; - /*ymm2 = _mm256_loadu_pd(tC); - ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);*/ - _mm256_storeu_pd(tC, ymm5); - - // multiply C by beta and accumulate. - tC += ldc; - /*ymm2 = _mm256_loadu_pd(tC); - ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6);*/ - _mm256_storeu_pd(tC, ymm6); - } - n_remainder = N - col_idx; - // if the N is not multiple of 3. - // handling edge case. - if (n_remainder == 2) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix data and - // multiplies it with the A matrix. - ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); - ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1); - tB += tb_inc_row; - - //broadcasted matrix B elements are multiplied - //with matrix A columns. - ymm3 = _mm256_loadu_pd(tA); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - tA += lda; - } - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_sd(alpha_cast); - //ymm1 = _mm256_broadcast_sd(beta_cast); - - //multiply A*B by alpha. - ymm4 = _mm256_mul_pd(ymm4, ymm0); - ymm5 = _mm256_mul_pd(ymm5, ymm0); - - // multiply C by beta and accumulate. - /*ymm2 = _mm256_loadu_pd(tC); - ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);*/ - _mm256_storeu_pd(tC, ymm4); - - // multiply C by beta and accumulate. - tC += ldc; - /*ymm2 = _mm256_loadu_pd(tC); - ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);*/ - _mm256_storeu_pd(tC, ymm5); - - col_idx += 2; - - } - // if the N is not multiple of 3. - // handling edge case. 
- if (n_remainder == 1) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - - ymm4 = _mm256_setzero_pd(); - - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix data and - // multiplies it with the A matrix. - ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); - tB += tb_inc_row; - - //broadcasted matrix B elements are multiplied - //with matrix A columns. - ymm3 = _mm256_loadu_pd(tA); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - tA += lda; - } - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_sd(alpha_cast); - //ymm1 = _mm256_broadcast_sd(beta_cast); - - ymm4 = _mm256_mul_pd(ymm4, ymm0); - - // multiply C by beta and accumulate. - /*ymm2 = _mm256_loadu_pd(tC); - ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);*/ - _mm256_storeu_pd(tC, ymm4); - - } - - row_idx += 4; - } - // M is not a multiple of 32. - // The handling of edge case where the remainder - // dimension is less than 8. The padding takes place - // to handle this case. - if ((m_remainder) && (lda > 3)) - { - double f_temp[8]; - - for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - - // clear scratch registers. - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - - for (k = 0; k < (K - 1); ++k) - { - // The inner loop broadcasts the B matrix data and - // multiplies it with the A matrix. - ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); - ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1); - ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2); - tB += tb_inc_row; - - //broadcasted matrix B elements are multiplied - //with matrix A columns. - ymm3 = _mm256_loadu_pd(tA); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - ymm9 = _mm256_fmadd_pd(ymm2, ymm3, ymm9); - - tA += lda; - } - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); - ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1); - ymm2 = _mm256_broadcast_sd(tB + tb_inc_col * 2); - tB += tb_inc_row; - - for (int i = 0; i < m_remainder; i++) - { - f_temp[i] = tA[i]; - } - ymm3 = _mm256_loadu_pd(f_temp); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - ymm9 = _mm256_fmadd_pd(ymm2, ymm3, ymm9); - - ymm0 = _mm256_broadcast_sd(alpha_cast); - //ymm1 = _mm256_broadcast_sd(beta_cast); - - //multiply A*B by alpha. - ymm5 = _mm256_mul_pd(ymm5, ymm0); - ymm7 = _mm256_mul_pd(ymm7, ymm0); - ymm9 = _mm256_mul_pd(ymm9, ymm0); - - - /*for (int i = 0; i < m_remainder; i++) - { - f_temp[i] = tC[i]; - } - ymm2 = _mm256_loadu_pd(f_temp); - ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);*/ - _mm256_storeu_pd(f_temp, ymm5); - for (int i = 0; i < m_remainder; i++) - { - tC[i] = f_temp[i]; - } - - tC += ldc; - /*for (int i = 0; i < m_remainder; i++) - { - f_temp[i] = tC[i]; - } - ymm2 = _mm256_loadu_pd(f_temp); - ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7);*/ - _mm256_storeu_pd(f_temp, ymm7); - for (int i = 0; i < m_remainder; i++) - { - tC[i] = f_temp[i]; - } - - tC += ldc; - /*for (int i = 0; i < m_remainder; i++) - { - f_temp[i] = tC[i]; - } - ymm2 = _mm256_loadu_pd(f_temp); - ymm9 = _mm256_fmadd_pd(ymm2, ymm1, ymm9);*/ - _mm256_storeu_pd(f_temp, ymm9); - for (int i = 0; i < m_remainder; i++) - { - tC[i] = f_temp[i]; - } - } - n_remainder = N - col_idx; - // if the N is not multiple of 3. - // handling edge case. 
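/*
 * Editorial aside (illustrative sketch, not part of the original kernel): when
 * fewer than a full vector of rows remains (m_remainder < 4 doubles here), the
 * kernel stages the valid elements through a small stack buffer (f_temp above)
 * so that full-width unaligned loads and stores never touch memory past the end
 * of the column; only the valid elements are copied back.  A minimal sketch of
 * that pattern, assuming immintrin.h is already included (it is, in this file):
 */
static void padded_scale_tail(double *c_col, const double *a_col,
                              double alpha, int m_remainder /* 1..3 */)
{
    double tmp[4] = { 0.0, 0.0, 0.0, 0.0 };
    for (int i = 0; i < m_remainder; i++) tmp[i] = a_col[i];    /* zero-pad the tail   */
    __m256d va = _mm256_loadu_pd(tmp);                          /* full-width load     */
    __m256d vr = _mm256_mul_pd(va, _mm256_set1_pd(alpha));      /* vector arithmetic   */
    _mm256_storeu_pd(tmp, vr);                                  /* spill to the buffer */
    for (int i = 0; i < m_remainder; i++) c_col[i] = tmp[i];    /* write back only the
                                                                   valid elements      */
}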
- if (n_remainder == 2) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - for (k = 0; k < (K - 1); ++k) - { - // The inner loop broadcasts the B matrix data and - // multiplies it with the A matrix. - ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); - ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1); - tB += tb_inc_row; - - ymm3 = _mm256_loadu_pd(tA); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - tA += lda; - } - - ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); - ymm1 = _mm256_broadcast_sd(tB + tb_inc_col * 1); - tB += tb_inc_row; - - for (int i = 0; i < m_remainder; i++) - { - f_temp[i] = tA[i]; - } - ymm3 = _mm256_loadu_pd(f_temp); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - ymm0 = _mm256_broadcast_sd(alpha_cast); - //ymm1 = _mm256_broadcast_sd(beta_cast); - - ymm5 = _mm256_mul_pd(ymm5, ymm0); - ymm7 = _mm256_mul_pd(ymm7, ymm0); - - /*for (int i = 0; i < m_remainder; i++) - { - f_temp[i] = tC[i]; - } - ymm2 = _mm256_loadu_pd(f_temp); - ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);*/ - _mm256_storeu_pd(f_temp, ymm5); - for (int i = 0; i < m_remainder; i++) - { - tC[i] = f_temp[i]; - } - - tC += ldc; - /*for (int i = 0; i < m_remainder; i++) - { - f_temp[i] = tC[i]; - } - ymm2 = _mm256_loadu_pd(f_temp); - ymm7 = _mm256_fmadd_pd(ymm2, ymm1, ymm7);*/ - _mm256_storeu_pd(f_temp, ymm7); - for (int i = 0; i < m_remainder; i++) - { - tC[i] = f_temp[i]; - } - } - // if the N is not multiple of 3. - // handling edge case. - if (n_remainder == 1) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - - ymm5 = _mm256_setzero_pd(); - - for (k = 0; k < (K - 1); ++k) - { - // The inner loop broadcasts the B matrix data and - // multiplies it with the A matrix. - ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); - tB += tb_inc_row; - - ymm3 = _mm256_loadu_pd(tA); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - - tA += lda; - } - - ymm0 = _mm256_broadcast_sd(tB + tb_inc_col * 0); - tB += tb_inc_row; - - for (int i = 0; i < m_remainder; i++) - { - f_temp[i] = tA[i]; - } - ymm3 = _mm256_loadu_pd(f_temp); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - - ymm0 = _mm256_broadcast_sd(alpha_cast); - //ymm1 = _mm256_broadcast_sd(beta_cast); - - // multiply C by beta and accumulate. 
- ymm5 = _mm256_mul_pd(ymm5, ymm0); - - /*for (int i = 0; i < m_remainder; i++) - { - f_temp[i] = tC[i]; - } - ymm2 = _mm256_loadu_pd(f_temp); - ymm5 = _mm256_fmadd_pd(ymm2, ymm1, ymm5);*/ - _mm256_storeu_pd(f_temp, ymm5); - for (int i = 0; i < m_remainder; i++) - { - tC[i] = f_temp[i]; - } - } - m_remainder = 0; - } - - if (m_remainder) - { - double result; - for (; row_idx < M; row_idx += 1) - { - for (col_idx = 0; col_idx < N; col_idx += 1) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - - result = 0; - for (k = 0; k < K; ++k) - { - result += (*tA) * (*tB); - tA += lda; - tB += tb_inc_row; - } - - result *= (*alpha_cast); - (*tC) = /*(*tC) * (*beta_cast) + */result; - } - } - } - - //copy/compute sryk values back to C using SIMD - if ( bli_seq0( *beta_cast ) ) - {//just copy for beta = 0 - dim_t _i, _j, k, _l; - if(bli_obj_is_lower(c)) //c is lower - { - //first column - _j = 0; - k = M >> 2; - _i = 0; - for ( _l = 0; _l < k; _l++ ) - { - ymm0 = _mm256_loadu_pd((C + _i*rsc)); - _mm256_storeu_pd((matCbuf + _i*rs_matC), ymm0); - _i += 4; - } - while (_i < M ) - { - bli_ddcopys( *(C + _i*rsc + _j*ldc), - *(matCbuf + _i*rs_matC + _j*ldc_matC) ); - _i++; - } - _j++; - while ( _j < N ) //next column - { - //k = (_j + (4 - (_j & 3))); - _l = _j & 3; - k = (_l != 0) ? (_j + (4 - _l)) : _j; - k = (k <= M) ? k : M; - for ( _i = _j; _i < k; ++_i ) - { - bli_ddcopys( *(C + _i*rsc + _j*ldc), - *(matCbuf + _i*rs_matC + _j*ldc_matC) ); - } - k = (M - _i) >> 2; - _l = 0; - while ( _l < k ) - { - ymm0 = _mm256_loadu_pd((C + _i*rsc + _j*ldc)); - _mm256_storeu_pd((matCbuf + _i*rs_matC + _j*ldc_matC), ymm0); - - _i += 4; - _l++; - } - while (_i < M ) - { - bli_ddcopys( *(C + _i*rsc + _j*ldc), - *(matCbuf + _i*rs_matC + _j*ldc_matC) ); - _i++; - } - _j++; - } - } - else //c is upper - { - for ( _j = 0; _j < N; ++_j ) - { - k = (_j + 1) >> 2; - _i = 0; - _l = 0; - while ( _l < k ) - { - ymm0 = _mm256_loadu_pd((C + _i*rsc + _j*ldc)); - _mm256_storeu_pd((matCbuf + _i*rs_matC + _j*ldc_matC), ymm0); - _i += 4; - _l++; - } - while (_i <= _j ) - { - bli_ddcopys( *(C + _i*rsc + _j*ldc), - *(matCbuf + _i*rs_matC + _j*ldc_matC) ); - ++_i; - } - } - } - } - else - {//when beta is non-zero, fmadd and store the results - dim_t _i, _j, k, _l; - ymm1 = _mm256_broadcast_sd(beta_cast); - if(bli_obj_is_lower(c)) //c is lower - { - //first column - _j = 0; - k = M >> 2; - _i = 0; - for ( _l = 0; _l < k; _l++ ) - { - ymm2 = _mm256_loadu_pd((matCbuf + _i*rs_matC)); - ymm0 = _mm256_loadu_pd((C + _i*rsc)); - ymm0 = _mm256_fmadd_pd(ymm2, ymm1, ymm0); - _mm256_storeu_pd((matCbuf + _i*rs_matC), ymm0); - _i += 4; - } - while (_i < M ) - { - bli_dddxpbys( *(C + _i*rsc + _j*ldc), - *(beta_cast), - *(matCbuf + _i*rs_matC + _j*ldc_matC) ); - _i++; - } - _j++; - while ( _j < N ) //next column - { - //k = (_j + (4 - (_j & 3))); - _l = _j & 3; - k = (_l != 0) ? (_j + (4 - _l)) : _j; - k = (k <= M) ? 
k : M; - for ( _i = _j; _i < k; ++_i ) - { - bli_dddxpbys( *(C + _i*rsc + _j*ldc), - *(beta_cast), - *(matCbuf + _i*rs_matC + _j*ldc_matC) ); - } - k = (M - _i) >> 2; - _l = 0; - while ( _l < k ) - { - ymm2 = _mm256_loadu_pd((matCbuf + _i*rs_matC + _j*ldc_matC)); - ymm0 = _mm256_loadu_pd((C + _i*rsc + _j*ldc)); - ymm0 = _mm256_fmadd_pd(ymm2, ymm1, ymm0); - _mm256_storeu_pd((matCbuf + _i*rs_matC + _j*ldc_matC), ymm0); - - _i += 4; - _l++; - } - while (_i < M ) - { - bli_dddxpbys( *(C + _i*rsc + _j*ldc), - *(beta_cast), - *(matCbuf + _i*rs_matC + _j*ldc_matC) ); - _i++; - } - _j++; - } - } - else //c is upper - { - for ( _j = 0; _j < N; ++_j ) - { - k = (_j + 1) >> 2; - _i = 0; - _l = 0; - while ( _l < k ) - { - ymm2 = _mm256_loadu_pd((matCbuf + _i*rs_matC + _j*ldc_matC)); - ymm0 = _mm256_loadu_pd((C + _i*rsc + _j*ldc)); - ymm0 = _mm256_fmadd_pd(ymm2, ymm1, ymm0); - _mm256_storeu_pd((matCbuf + _i*rs_matC + _j*ldc_matC), ymm0); - _i += 4; - _l++; - } - while (_i <= _j ) - { - bli_dddxpbys( *(C + _i*rsc + _j*ldc), - *(beta_cast), - *(matCbuf + _i*rs_matC + _j*ldc_matC) ); - ++_i; - } - } - } - } - - return BLIS_SUCCESS; - } - else - return BLIS_NONCONFORMAL_DIMENSIONS; - - -}; - -static err_t bli_ssyrk_small_atbn - ( - obj_t* alpha, - obj_t* a, - obj_t* b, - obj_t* beta, - obj_t* c, - cntx_t* cntx, - cntl_t* cntl - ) -{ - int M = bli_obj_length(c); // number of rows of Matrix C - int N = bli_obj_width(c); // number of columns of Matrix C - int K = bli_obj_length(b); // number of rows of Matrix B - int lda = bli_obj_col_stride(a); // column stride of matrix OP(A), where OP(A) is Transpose(A) if transA enabled. - int ldb = bli_obj_col_stride(b); // column stride of matrix OP(B), where OP(B) is Transpose(B) if transB enabled. - int ldc_matC = bli_obj_col_stride( c ); // column stride of matrix C - int ldc = M;//bli_obj_col_stride( c ); // column stride of static buffer for matrix C - int row_idx = 0, col_idx = 0, k; - int rs_matC = bli_obj_row_stride( c ); - int rsc = 1; - float *A = a->buffer; // pointer to matrix A elements, stored in row major format - float *B = b->buffer; // pointer to matrix B elements, stored in column major format - float *C = C_pack; // pointer to matrix C elements, stored in column major format - float *matCbuf = c->buffer; - - float *tA = A, *tB = B, *tC = C; - - __m256 ymm4, ymm5, ymm6, ymm7; - __m256 ymm8, ymm9, ymm10, ymm11; - __m256 ymm12, ymm13, ymm14, ymm15; - __m256 ymm0, ymm1, ymm2, ymm3; - - float result, scratch[8]; - float *alpha_cast, *beta_cast; // alpha, beta multiples - alpha_cast = (alpha->buffer); - beta_cast = (beta->buffer); - - // The non-copy version of the A^T SYRK gives better performance for the small M cases. - // The threshold is controlled by BLIS_ATBN_M_THRES - if (M <= BLIS_ATBN_M_THRES) - { - for (col_idx = 0; (col_idx + (NR - 1)) < N; col_idx += NR) - { - for (row_idx = 0; (row_idx + (AT_MR - 1)) < M; row_idx += AT_MR) - { - tA = A + row_idx * lda; - tB = B + col_idx * ldb; - tC = C + col_idx * ldc + row_idx; - // clear scratch registers. - ymm4 = _mm256_setzero_ps(); - ymm5 = _mm256_setzero_ps(); - ymm6 = _mm256_setzero_ps(); - ymm7 = _mm256_setzero_ps(); - ymm8 = _mm256_setzero_ps(); - ymm9 = _mm256_setzero_ps(); - ymm10 = _mm256_setzero_ps(); - ymm11 = _mm256_setzero_ps(); - ymm12 = _mm256_setzero_ps(); - ymm13 = _mm256_setzero_ps(); - ymm14 = _mm256_setzero_ps(); - ymm15 = _mm256_setzero_ps(); - - //The inner loop computes the 4x3 values of the matrix. 
- //The computation pattern is: - // ymm4 ymm5 ymm6 - // ymm7 ymm8 ymm9 - // ymm10 ymm11 ymm12 - // ymm13 ymm14 ymm15 - - //The Dot operation is performed in the inner loop, 8 float elements fit - //in the YMM register hence loop count incremented by 8 - for (k = 0; (k + 7) < K; k += 8) - { - ymm0 = _mm256_loadu_ps(tB + 0); - ymm1 = _mm256_loadu_ps(tB + ldb); - ymm2 = _mm256_loadu_ps(tB + 2 * ldb); - - ymm3 = _mm256_loadu_ps(tA); - ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_ps(ymm1, ymm3, ymm5); - ymm6 = _mm256_fmadd_ps(ymm2, ymm3, ymm6); - - ymm3 = _mm256_loadu_ps(tA + lda); - ymm7 = _mm256_fmadd_ps(ymm0, ymm3, ymm7); - ymm8 = _mm256_fmadd_ps(ymm1, ymm3, ymm8); - ymm9 = _mm256_fmadd_ps(ymm2, ymm3, ymm9); - - ymm3 = _mm256_loadu_ps(tA + 2 * lda); - ymm10 = _mm256_fmadd_ps(ymm0, ymm3, ymm10); - ymm11 = _mm256_fmadd_ps(ymm1, ymm3, ymm11); - ymm12 = _mm256_fmadd_ps(ymm2, ymm3, ymm12); - - ymm3 = _mm256_loadu_ps(tA + 3 * lda); - ymm13 = _mm256_fmadd_ps(ymm0, ymm3, ymm13); - ymm14 = _mm256_fmadd_ps(ymm1, ymm3, ymm14); - ymm15 = _mm256_fmadd_ps(ymm2, ymm3, ymm15); - - tA += 8; - tB += 8; - - } - - // if K is not a multiple of 8, padding is done before load using temproary array. - if (k < K) - { - int iter; - float data_feeder[8] = { 0.0 }; - - for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter]; - ymm0 = _mm256_loadu_ps(data_feeder); - for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter + ldb]; - ymm1 = _mm256_loadu_ps(data_feeder); - for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter + 2 * ldb]; - ymm2 = _mm256_loadu_ps(data_feeder); - - for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[iter]; - ymm3 = _mm256_loadu_ps(data_feeder); - ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_ps(ymm1, ymm3, ymm5); - ymm6 = _mm256_fmadd_ps(ymm2, ymm3, ymm6); - - for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[lda + iter]; - ymm3 = _mm256_loadu_ps(data_feeder); - ymm7 = _mm256_fmadd_ps(ymm0, ymm3, ymm7); - ymm8 = _mm256_fmadd_ps(ymm1, ymm3, ymm8); - ymm9 = _mm256_fmadd_ps(ymm2, ymm3, ymm9); - - for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[2 * lda + iter]; - ymm3 = _mm256_loadu_ps(data_feeder); - ymm10 = _mm256_fmadd_ps(ymm0, ymm3, ymm10); - ymm11 = _mm256_fmadd_ps(ymm1, ymm3, ymm11); - ymm12 = _mm256_fmadd_ps(ymm2, ymm3, ymm12); - - for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[3 * lda + iter]; - ymm3 = _mm256_loadu_ps(data_feeder); - ymm13 = _mm256_fmadd_ps(ymm0, ymm3, ymm13); - ymm14 = _mm256_fmadd_ps(ymm1, ymm3, ymm14); - ymm15 = _mm256_fmadd_ps(ymm2, ymm3, ymm15); - - } - - //horizontal addition and storage of the data. 
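/*
 * Editorial aside (illustrative sketch, not part of the original kernel): the
 * A^T (atbn) kernels keep one dot product per YMM accumulator and reduce it to
 * a scalar exactly as in the code below -- two in-lane horizontal adds followed
 * by summing the low and high 128-bit lanes (scratch[0] + scratch[4]).
 * Packaged as a hypothetical helper for clarity:
 */
static float hsum_ps(__m256 v)
{
    float scratch[8];
    v = _mm256_hadd_ps(v, v);            /* pairwise add within each 128-bit lane    */
    v = _mm256_hadd_ps(v, v);            /* each lane now holds its sum in slot 0    */
    _mm256_storeu_ps(scratch, v);
    return scratch[0] + scratch[4];      /* combine the two lanes                    */
}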
- //Results for 4x3 blocks of C is stored here - ymm4 = _mm256_hadd_ps(ymm4, ymm4); - ymm4 = _mm256_hadd_ps(ymm4, ymm4); - _mm256_storeu_ps(scratch, ymm4); - result = scratch[0] + scratch[4]; - result *= (*alpha_cast); - tC[0] = result/* + tC[0] * (*beta_cast)*/; - - ymm7 = _mm256_hadd_ps(ymm7, ymm7); - ymm7 = _mm256_hadd_ps(ymm7, ymm7); - _mm256_storeu_ps(scratch, ymm7); - result = scratch[0] + scratch[4]; - result *= (*alpha_cast); - tC[1] = result/* + tC[1] * (*beta_cast)*/; - - ymm10 = _mm256_hadd_ps(ymm10, ymm10); - ymm10 = _mm256_hadd_ps(ymm10, ymm10); - _mm256_storeu_ps(scratch, ymm10); - result = scratch[0] + scratch[4]; - result *= (*alpha_cast); - tC[2] = result/* + tC[2] * (*beta_cast)*/; - - ymm13 = _mm256_hadd_ps(ymm13, ymm13); - ymm13 = _mm256_hadd_ps(ymm13, ymm13); - _mm256_storeu_ps(scratch, ymm13); - result = scratch[0] + scratch[4]; - result *= (*alpha_cast); - tC[3] = result/* + tC[3] * (*beta_cast)*/; - - tC += ldc; - ymm5 = _mm256_hadd_ps(ymm5, ymm5); - ymm5 = _mm256_hadd_ps(ymm5, ymm5); - _mm256_storeu_ps(scratch, ymm5); - result = scratch[0] + scratch[4]; - result *= (*alpha_cast); - tC[0] = result/* + tC[0] * (*beta_cast)*/; - - ymm8 = _mm256_hadd_ps(ymm8, ymm8); - ymm8 = _mm256_hadd_ps(ymm8, ymm8); - _mm256_storeu_ps(scratch, ymm8); - result = scratch[0] + scratch[4]; - result *= (*alpha_cast); - tC[1] = result/* + tC[1] * (*beta_cast)*/; - - ymm11 = _mm256_hadd_ps(ymm11, ymm11); - ymm11 = _mm256_hadd_ps(ymm11, ymm11); - _mm256_storeu_ps(scratch, ymm11); - result = scratch[0] + scratch[4]; - result *= (*alpha_cast); - tC[2] = result/* + tC[2] * (*beta_cast)*/; - - ymm14 = _mm256_hadd_ps(ymm14, ymm14); - ymm14 = _mm256_hadd_ps(ymm14, ymm14); - _mm256_storeu_ps(scratch, ymm14); - result = scratch[0] + scratch[4]; - result *= (*alpha_cast); - tC[3] = result/* + tC[3] * (*beta_cast)*/; - - tC += ldc; - ymm6 = _mm256_hadd_ps(ymm6, ymm6); - ymm6 = _mm256_hadd_ps(ymm6, ymm6); - _mm256_storeu_ps(scratch, ymm6); - result = scratch[0] + scratch[4]; - result *= (*alpha_cast); - tC[0] = result/* + tC[0] * (*beta_cast)*/; - - ymm9 = _mm256_hadd_ps(ymm9, ymm9); - ymm9 = _mm256_hadd_ps(ymm9, ymm9); - _mm256_storeu_ps(scratch, ymm9); - result = scratch[0] + scratch[4]; - result *= (*alpha_cast); - tC[1] = result/* + tC[1] * (*beta_cast)*/; - - ymm12 = _mm256_hadd_ps(ymm12, ymm12); - ymm12 = _mm256_hadd_ps(ymm12, ymm12); - _mm256_storeu_ps(scratch, ymm12); - result = scratch[0] + scratch[4]; - result *= (*alpha_cast); - tC[2] = result/* + tC[2] * (*beta_cast)*/; - - ymm15 = _mm256_hadd_ps(ymm15, ymm15); - ymm15 = _mm256_hadd_ps(ymm15, ymm15); - _mm256_storeu_ps(scratch, ymm15); - result = scratch[0] + scratch[4]; - result *= (*alpha_cast); - tC[3] = result/* + tC[3] * (*beta_cast)*/; - } - } - - int processed_col = col_idx; - int processed_row = row_idx; - - //The edge case handling where N is not a multiple of 3 - if (processed_col < N) - { - for (col_idx = processed_col; col_idx < N; col_idx += 1) - { - for (row_idx = 0; (row_idx + (AT_MR - 1)) < M; row_idx += AT_MR) - { - tA = A + row_idx * lda; - tB = B + col_idx * ldb; - tC = C + col_idx * ldc + row_idx; - // clear scratch registers. - ymm4 = _mm256_setzero_ps(); - ymm7 = _mm256_setzero_ps(); - ymm10 = _mm256_setzero_ps(); - ymm13 = _mm256_setzero_ps(); - - //The inner loop computes the 4x1 values of the matrix. 
- //The computation pattern is: - // ymm4 - // ymm7 - // ymm10 - // ymm13 - - for (k = 0; (k + 7) < K; k += 8) - { - ymm0 = _mm256_loadu_ps(tB + 0); - - ymm3 = _mm256_loadu_ps(tA); - ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4); - - ymm3 = _mm256_loadu_ps(tA + lda); - ymm7 = _mm256_fmadd_ps(ymm0, ymm3, ymm7); - - ymm3 = _mm256_loadu_ps(tA + 2 * lda); - ymm10 = _mm256_fmadd_ps(ymm0, ymm3, ymm10); - - ymm3 = _mm256_loadu_ps(tA + 3 * lda); - ymm13 = _mm256_fmadd_ps(ymm0, ymm3, ymm13); - - tA += 8; - tB += 8; - } - - // if K is not a multiple of 8, padding is done before load using temproary array. - if (k < K) - { - int iter; - float data_feeder[8] = { 0.0 }; - - for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter]; - ymm0 = _mm256_loadu_ps(data_feeder); - - for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[iter]; - ymm3 = _mm256_loadu_ps(data_feeder); - ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4); - - for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[lda + iter]; - ymm3 = _mm256_loadu_ps(data_feeder); - ymm7 = _mm256_fmadd_ps(ymm0, ymm3, ymm7); - - for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[2 * lda + iter]; - ymm3 = _mm256_loadu_ps(data_feeder); - ymm10 = _mm256_fmadd_ps(ymm0, ymm3, ymm10); - - for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[3 * lda + iter]; - ymm3 = _mm256_loadu_ps(data_feeder); - ymm13 = _mm256_fmadd_ps(ymm0, ymm3, ymm13); - - } - - //horizontal addition and storage of the data. - //Results for 4x1 blocks of C is stored here - ymm4 = _mm256_hadd_ps(ymm4, ymm4); - ymm4 = _mm256_hadd_ps(ymm4, ymm4); - _mm256_storeu_ps(scratch, ymm4); - result = scratch[0] + scratch[4]; - result *= (*alpha_cast); - tC[0] = result/* + tC[0] * (*beta_cast)*/; - - ymm7 = _mm256_hadd_ps(ymm7, ymm7); - ymm7 = _mm256_hadd_ps(ymm7, ymm7); - _mm256_storeu_ps(scratch, ymm7); - result = scratch[0] + scratch[4]; - result *= (*alpha_cast); - tC[1] = result/* + tC[1] * (*beta_cast)*/; - - ymm10 = _mm256_hadd_ps(ymm10, ymm10); - ymm10 = _mm256_hadd_ps(ymm10, ymm10); - _mm256_storeu_ps(scratch, ymm10); - result = scratch[0] + scratch[4]; - result *= (*alpha_cast); - tC[2] = result/* + tC[2] * (*beta_cast)*/; - - ymm13 = _mm256_hadd_ps(ymm13, ymm13); - ymm13 = _mm256_hadd_ps(ymm13, ymm13); - _mm256_storeu_ps(scratch, ymm13); - result = scratch[0] + scratch[4]; - result *= (*alpha_cast); - tC[3] = result/* + tC[3] * (*beta_cast)*/; - - } - } - processed_row = row_idx; - } - - //The edge case handling where M is not a multiple of 4 - if (processed_row < M) - { - for (row_idx = processed_row; row_idx < M; row_idx += 1) - { - for (col_idx = 0; col_idx < N; col_idx += 1) - { - tA = A + row_idx * lda; - tB = B + col_idx * ldb; - tC = C + col_idx * ldc + row_idx; - // clear scratch registers. - ymm4 = _mm256_setzero_ps(); - - for (k = 0; (k + 7) < K; k += 8) - { - ymm0 = _mm256_loadu_ps(tB + 0); - ymm3 = _mm256_loadu_ps(tA); - ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4); - - tA += 8; - tB += 8; - } - - // if K is not a multiple of 8, padding is done before load using temproary array. - if (k < K) - { - int iter; - float data_feeder[8] = { 0.0 }; - - for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter]; - ymm0 = _mm256_loadu_ps(data_feeder); - - for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[iter]; - ymm3 = _mm256_loadu_ps(data_feeder); - ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4); - - } - - //horizontal addition and storage of the data. 
- ymm4 = _mm256_hadd_ps(ymm4, ymm4); - ymm4 = _mm256_hadd_ps(ymm4, ymm4); - _mm256_storeu_ps(scratch, ymm4); - result = scratch[0] + scratch[4]; - result *= (*alpha_cast); - tC[0] = result/* + tC[0] * (*beta_cast)*/; - - } - } - } - - //copy/compute sryk values back to C - if ( bli_seq0( *beta_cast ) ) //when beta is 0, just copy result to C - { - dim_t _i, _j; - if(bli_obj_is_lower(c)) //c is lower - { - for ( _j = 0; _j < N; ++_j ) - for ( _i = 0; _i < M; ++_i ) - if ( (doff_t)_j - (doff_t)_i <= 0 ) - { - bli_sscopys( *(C + _i*rsc + _j*ldc), - *(matCbuf + _i*rs_matC + _j*ldc_matC) ); - } - } - else //c is upper - { - for ( _j = 0; _j < N; ++_j ) - for ( _i = 0; _i < M; ++_i ) - if ( (doff_t)_j - (doff_t)_i >= 0 ) - { - bli_sscopys( *(C + _i*rsc + _j*ldc), - *(matCbuf + _i*rs_matC + _j*ldc_matC) ); - } - } - } - else //when beta is non-zero, multiply and store result to C - { - dim_t _i, _j; - if(bli_obj_is_lower(c)) //c is lower - { - for ( _j = 0; _j < N; ++_j ) - for ( _i = 0; _i < M; ++_i ) - if ( (doff_t)_j - (doff_t)_i <= 0 ) - { - bli_sssxpbys( *(C + _i*rsc + _j*ldc), - *(beta_cast), - *(matCbuf + _i*rs_matC + _j*ldc_matC) ); - } - } - else //c is upper - { - for ( _j = 0; _j < N; ++_j ) - for ( _i = 0; _i < M; ++_i ) - if ( (doff_t)_j - (doff_t)_i >= 0 ) - { - bli_sssxpbys( *(C + _i*rsc + _j*ldc), - *(beta_cast), - *(matCbuf + _i*rs_matC + _j*ldc_matC) ); - } - } - } - - return BLIS_SUCCESS; - } - else - return BLIS_NONCONFORMAL_DIMENSIONS; -} - -static err_t bli_dsyrk_small_atbn - ( - obj_t* alpha, - obj_t* a, - obj_t* b, - obj_t* beta, - obj_t* c, - cntx_t* cntx, - cntl_t* cntl - ) -{ - int M = bli_obj_length( c ); // number of rows of Matrix C - int N = bli_obj_width( c ); // number of columns of Matrix C - int K = bli_obj_length( b ); // number of rows of Matrix B - int lda = bli_obj_col_stride( a ); // column stride of matrix OP(A), where OP(A) is Transpose(A) if transA enabled. - int ldb = bli_obj_col_stride( b ); // column stride of matrix OP(B), where OP(B) is Transpose(B) if transB enabled. - int ldc_matC = bli_obj_col_stride( c ); // column stride of matrix C - int ldc = M;//bli_obj_col_stride( c ); // column stride of static buffer for matrix C - int row_idx = 0, col_idx = 0, k; - int rs_matC = bli_obj_row_stride( c ); - int rsc = 1; - double *A = a->buffer; // pointer to matrix A elements, stored in row major format - double *B = b->buffer; // pointer to matrix B elements, stored in column major format - double *C = D_C_pack; // pointer to matrix C elements, stored in column major format - double *matCbuf = c->buffer; - - double *tA = A, *tB = B, *tC = C; - - __m256d ymm4, ymm5, ymm6, ymm7; - __m256d ymm8, ymm9, ymm10, ymm11; - __m256d ymm12, ymm13, ymm14, ymm15; - __m256d ymm0, ymm1, ymm2, ymm3; - - double result, scratch[8]; - double *alpha_cast, *beta_cast; // alpha, beta multiples - alpha_cast = (alpha->buffer); - beta_cast = (beta->buffer); - - // The non-copy version of the A^T SYRK gives better performance for the small M cases. - // The threshold is controlled by BLIS_ATBN_M_THRES - if (M <= BLIS_ATBN_M_THRES) - { - for (col_idx = 0; (col_idx + (NR - 1)) < N; col_idx += NR) - { - for (row_idx = 0; (row_idx + (AT_MR - 1)) < M; row_idx += AT_MR) - { - tA = A + row_idx * lda; - tB = B + col_idx * ldb; - tC = C + col_idx * ldc + row_idx; - // clear scratch registers. 
- ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - ymm12 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); - ymm14 = _mm256_setzero_pd(); - ymm15 = _mm256_setzero_pd(); - - //The inner loop computes the 4x3 values of the matrix. - //The computation pattern is: - // ymm4 ymm5 ymm6 - // ymm7 ymm8 ymm9 - // ymm10 ymm11 ymm12 - // ymm13 ymm14 ymm15 - - //The Dot operation is performed in the inner loop, 4 double elements fit - //in the YMM register hence loop count incremented by 4 - for (k = 0; (k + 3) < K; k += 4) - { - ymm0 = _mm256_loadu_pd(tB + 0); - ymm1 = _mm256_loadu_pd(tB + ldb); - ymm2 = _mm256_loadu_pd(tB + 2 * ldb); - - ymm3 = _mm256_loadu_pd(tA); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - ymm6 = _mm256_fmadd_pd(ymm2, ymm3, ymm6); - - ymm3 = _mm256_loadu_pd(tA + lda); - ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7); - ymm8 = _mm256_fmadd_pd(ymm1, ymm3, ymm8); - ymm9 = _mm256_fmadd_pd(ymm2, ymm3, ymm9); - - ymm3 = _mm256_loadu_pd(tA + 2 * lda); - ymm10 = _mm256_fmadd_pd(ymm0, ymm3, ymm10); - ymm11 = _mm256_fmadd_pd(ymm1, ymm3, ymm11); - ymm12 = _mm256_fmadd_pd(ymm2, ymm3, ymm12); - - ymm3 = _mm256_loadu_pd(tA + 3 * lda); - ymm13 = _mm256_fmadd_pd(ymm0, ymm3, ymm13); - ymm14 = _mm256_fmadd_pd(ymm1, ymm3, ymm14); - ymm15 = _mm256_fmadd_pd(ymm2, ymm3, ymm15); - - tA += 4; - tB += 4; - - } - - // if K is not a multiple of 4, padding is done before load using temproary array. - if (k < K) - { - int iter; - double data_feeder[4] = { 0.0 }; - - for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter]; - ymm0 = _mm256_loadu_pd(data_feeder); - for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter + ldb]; - ymm1 = _mm256_loadu_pd(data_feeder); - for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter + 2 * ldb]; - ymm2 = _mm256_loadu_pd(data_feeder); - - for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[iter]; - ymm3 = _mm256_loadu_pd(data_feeder); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - ymm6 = _mm256_fmadd_pd(ymm2, ymm3, ymm6); - - for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[lda + iter]; - ymm3 = _mm256_loadu_pd(data_feeder); - ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7); - ymm8 = _mm256_fmadd_pd(ymm1, ymm3, ymm8); - ymm9 = _mm256_fmadd_pd(ymm2, ymm3, ymm9); - - for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[2 * lda + iter]; - ymm3 = _mm256_loadu_pd(data_feeder); - ymm10 = _mm256_fmadd_pd(ymm0, ymm3, ymm10); - ymm11 = _mm256_fmadd_pd(ymm1, ymm3, ymm11); - ymm12 = _mm256_fmadd_pd(ymm2, ymm3, ymm12); - - for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[3 * lda + iter]; - ymm3 = _mm256_loadu_pd(data_feeder); - ymm13 = _mm256_fmadd_pd(ymm0, ymm3, ymm13); - ymm14 = _mm256_fmadd_pd(ymm1, ymm3, ymm14); - ymm15 = _mm256_fmadd_pd(ymm2, ymm3, ymm15); - - } - - //horizontal addition and storage of the data. 
- //Results for 4x3 blocks of C is stored here - ymm4 = _mm256_hadd_pd(ymm4, ymm4); - _mm256_storeu_pd(scratch, ymm4); - result = scratch[0] + scratch[2]; - result *= (*alpha_cast); - tC[0] = result/* + tC[0] * (*beta_cast)*/; - - ymm7 = _mm256_hadd_pd(ymm7, ymm7); - _mm256_storeu_pd(scratch, ymm7); - result = scratch[0] + scratch[2]; - result *= (*alpha_cast); - tC[1] = result/* + tC[1] * (*beta_cast)*/; - - ymm10 = _mm256_hadd_pd(ymm10, ymm10); - _mm256_storeu_pd(scratch, ymm10); - result = scratch[0] + scratch[2]; - result *= (*alpha_cast); - tC[2] = result/* + tC[2] * (*beta_cast)*/; - - ymm13 = _mm256_hadd_pd(ymm13, ymm13); - _mm256_storeu_pd(scratch, ymm13); - result = scratch[0] + scratch[2]; - result *= (*alpha_cast); - tC[3] = result/* + tC[3] * (*beta_cast)*/; - - - tC += ldc; - ymm5 = _mm256_hadd_pd(ymm5, ymm5); - _mm256_storeu_pd(scratch, ymm5); - result = scratch[0] + scratch[2]; - result *= (*alpha_cast); - tC[0] = result/* + tC[0] * (*beta_cast)*/; - - ymm8 = _mm256_hadd_pd(ymm8, ymm8); - _mm256_storeu_pd(scratch, ymm8); - result = scratch[0] + scratch[2]; - result *= (*alpha_cast); - tC[1] = result/* + tC[1] * (*beta_cast)*/; - - ymm11 = _mm256_hadd_pd(ymm11, ymm11); - _mm256_storeu_pd(scratch, ymm11); - result = scratch[0] + scratch[2]; - result *= (*alpha_cast); - tC[2] = result/* + tC[2] * (*beta_cast)*/; - - ymm14 = _mm256_hadd_pd(ymm14, ymm14); - _mm256_storeu_pd(scratch, ymm14); - result = scratch[0] + scratch[2]; - result *= (*alpha_cast); - tC[3] = result/* + tC[3] * (*beta_cast)*/; - - - tC += ldc; - ymm6 = _mm256_hadd_pd(ymm6, ymm6); - _mm256_storeu_pd(scratch, ymm6); - result = scratch[0] + scratch[2]; - result *= (*alpha_cast); - tC[0] = result/* + tC[0] * (*beta_cast)*/; - - ymm9 = _mm256_hadd_pd(ymm9, ymm9); - _mm256_storeu_pd(scratch, ymm9); - result = scratch[0] + scratch[2]; - result *= (*alpha_cast); - tC[1] = result/* + tC[1] * (*beta_cast)*/; - - ymm12 = _mm256_hadd_pd(ymm12, ymm12); - _mm256_storeu_pd(scratch, ymm12); - result = scratch[0] + scratch[2]; - result *= (*alpha_cast); - tC[2] = result/* + tC[2] * (*beta_cast)*/; - - ymm15 = _mm256_hadd_pd(ymm15, ymm15); - _mm256_storeu_pd(scratch, ymm15); - result = scratch[0] + scratch[2]; - result *= (*alpha_cast); - tC[3] = result/* + tC[3] * (*beta_cast)*/; - } - } - - int processed_col = col_idx; - int processed_row = row_idx; - - //The edge case handling where N is not a multiple of 3 - if (processed_col < N) - { - for (col_idx = processed_col; col_idx < N; col_idx += 1) - { - for (row_idx = 0; (row_idx + (AT_MR - 1)) < M; row_idx += AT_MR) - { - tA = A + row_idx * lda; - tB = B + col_idx * ldb; - tC = C + col_idx * ldc + row_idx; - // clear scratch registers. - ymm4 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); - - //The inner loop computes the 4x1 values of the matrix. - //The computation pattern is: - // ymm4 - // ymm7 - // ymm10 - // ymm13 - - for (k = 0; (k + 3) < K; k += 4) - { - ymm0 = _mm256_loadu_pd(tB + 0); - - ymm3 = _mm256_loadu_pd(tA); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - ymm3 = _mm256_loadu_pd(tA + lda); - ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7); - - ymm3 = _mm256_loadu_pd(tA + 2 * lda); - ymm10 = _mm256_fmadd_pd(ymm0, ymm3, ymm10); - - ymm3 = _mm256_loadu_pd(tA + 3 * lda); - ymm13 = _mm256_fmadd_pd(ymm0, ymm3, ymm13); - - tA += 4; - tB += 4; - } - // if K is not a multiple of 4, padding is done before load using temproary array. 
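/*
 * Editorial note (illustrative sketch, not part of this patch): the
 * remainder handling below works because data_feeder is zero-initialized,
 * so the lanes beyond the K remainder contribute 0.0 to the FMA
 * accumulation and no out-of-bounds load is issued.  A minimal standalone
 * helper showing the same idea (hypothetical name, assumes AVX2/FMA,
 * immintrin.h and rem <= 4):
 */
static inline __m256d fma_tail_sketch( const double* x, const double* y,
                                       dim_t rem, __m256d acc )
{
    double xbuf[4] = { 0.0 }, ybuf[4] = { 0.0 };
    for ( dim_t i = 0; i < rem; ++i ) { xbuf[i] = x[i]; ybuf[i] = y[i]; }
    return _mm256_fmadd_pd( _mm256_loadu_pd( xbuf ),
                            _mm256_loadu_pd( ybuf ), acc );
}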
- if (k < K) - { - int iter; - double data_feeder[4] = { 0.0 }; - - for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter]; - ymm0 = _mm256_loadu_pd(data_feeder); - - for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[iter]; - ymm3 = _mm256_loadu_pd(data_feeder); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[lda + iter]; - ymm3 = _mm256_loadu_pd(data_feeder); - ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7); - - for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[2 * lda + iter]; - ymm3 = _mm256_loadu_pd(data_feeder); - ymm10 = _mm256_fmadd_pd(ymm0, ymm3, ymm10); - - for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[3 * lda + iter]; - ymm3 = _mm256_loadu_pd(data_feeder); - ymm13 = _mm256_fmadd_pd(ymm0, ymm3, ymm13); - - } - - //horizontal addition and storage of the data. - //Results for 4x1 blocks of C is stored here - ymm4 = _mm256_hadd_pd(ymm4, ymm4); - _mm256_storeu_pd(scratch, ymm4); - result = scratch[0] + scratch[2]; - result *= (*alpha_cast); - tC[0] = result/* + tC[0] * (*beta_cast)*/; - - ymm7 = _mm256_hadd_pd(ymm7, ymm7); - _mm256_storeu_pd(scratch, ymm7); - result = scratch[0] + scratch[2]; - result *= (*alpha_cast); - tC[1] = result/* + tC[1] * (*beta_cast)*/; - - ymm10 = _mm256_hadd_pd(ymm10, ymm10); - _mm256_storeu_pd(scratch, ymm10); - result = scratch[0] + scratch[2]; - result *= (*alpha_cast); - tC[2] = result/* + tC[2] * (*beta_cast)*/; - - ymm13 = _mm256_hadd_pd(ymm13, ymm13); - _mm256_storeu_pd(scratch, ymm13); - result = scratch[0] + scratch[2]; - result *= (*alpha_cast); - tC[3] = result/* + tC[3] * (*beta_cast)*/; - - } - } - processed_row = row_idx; - } - - // The edge case handling where M is not a multiple of 4 - if (processed_row < M) - { - for (row_idx = processed_row; row_idx < M; row_idx += 1) - { - for (col_idx = 0; col_idx < N; col_idx += 1) - { - tA = A + row_idx * lda; - tB = B + col_idx * ldb; - tC = C + col_idx * ldc + row_idx; - // clear scratch registers. - ymm4 = _mm256_setzero_pd(); - - for (k = 0; (k + 3) < K; k += 4) - { - ymm0 = _mm256_loadu_pd(tB + 0); - ymm3 = _mm256_loadu_pd(tA); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - tA += 4; - tB += 4; - } - - // if K is not a multiple of 4, padding is done before load using temproary array. - if (k < K) - { - int iter; - double data_feeder[4] = { 0.0 }; - - for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tB[iter]; - ymm0 = _mm256_loadu_pd(data_feeder); - - for (iter = 0; iter < (K - k); iter++) data_feeder[iter] = tA[iter]; - ymm3 = _mm256_loadu_pd(data_feeder); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - } - - //horizontal addition and storage of the data. 
- ymm4 = _mm256_hadd_pd(ymm4, ymm4); - _mm256_storeu_pd(scratch, ymm4); - result = scratch[0] + scratch[2]; - result *= (*alpha_cast); - tC[0] = result/* + tC[0] * (*beta_cast)*/; - - } - } - } - - //copy/compute sryk values back to C - if ( bli_seq0( *beta_cast ) ) //when beta is 0, just copy result to C - { - dim_t _i, _j; - if(bli_obj_is_lower(c)) //c is lower - { - for ( _j = 0; _j < N; ++_j ) - for ( _i = 0; _i < M; ++_i ) - if ( (doff_t)_j - (doff_t)_i <= 0 ) - { - bli_ddcopys( *(C + _i*rsc + _j*ldc), - *(matCbuf + _i*rs_matC + _j*ldc_matC) ); - } - } - else //c is upper - { - for ( _j = 0; _j < N; ++_j ) - for ( _i = 0; _i < M; ++_i ) - if ( (doff_t)_j - (doff_t)_i >= 0 ) - { - bli_ddcopys( *(C + _i*rsc + _j*ldc), - *(matCbuf + _i*rs_matC + _j*ldc_matC) ); - } - } - } - else //when beta is non-zero, multiply and store result to C - { - dim_t _i, _j; - if(bli_obj_is_lower(c)) //c is lower - { - for ( _j = 0; _j < N; ++_j ) - for ( _i = 0; _i < M; ++_i ) - if ( (doff_t)_j - (doff_t)_i <= 0 ) - { - bli_dddxpbys( *(C + _i*rsc + _j*ldc), - *(beta_cast), - *(matCbuf + _i*rs_matC + _j*ldc_matC) ); - } - } - else //c is upper - { - for ( _j = 0; _j < N; ++_j ) - for ( _i = 0; _i < M; ++_i ) - if ( (doff_t)_j - (doff_t)_i >= 0 ) - { - bli_dddxpbys( *(C + _i*rsc + _j*ldc), - *(beta_cast), - *(matCbuf + _i*rs_matC + _j*ldc_matC) ); - } - } - } - - return BLIS_SUCCESS; - } - else - return BLIS_NONCONFORMAL_DIMENSIONS; -} - -#endif - From ffcb33853197753b05f5424ecb6668167e68ca2f Mon Sep 17 00:00:00 2001 From: Saitharun Date: Mon, 6 Sep 2021 17:13:21 +0530 Subject: [PATCH 012/243] Adding ENABLE_WRAPPER CMAKE OPTION details: Wrapper code will be enabled when selecting the cmake option ENABLE_WRAPPER and also this commit will fixing the ScaLAPACK build error on windows. 
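For context on the mechanism: the option takes effect through the alternate
PASTEF77* name-mangling macros added to frame/include/bli_macro_defs.h later
in this patch, which drop the trailing underscore from the Fortran-style
symbol names. A minimal standalone sketch of the preprocessor behaviour
(illustrative only; it does not include blis.h and simply mirrors the two
macro variants added below):

#include <stdio.h>

#define STR2(x) #x
#define STR(x)  STR2(x)

#ifdef BLIS_ENABLE_NO_UNDERSCORE_API
#define PASTEF77(ch1,name) ch1 ## name        /* no trailing underscore */
#else
#define PASTEF77(ch1,name) ch1 ## name ## _   /* default Fortran mangling */
#endif

int main(void)
{
    printf("%s\n", STR(PASTEF77(d,gemm)));    /* "dgemm" vs "dgemm_" */
    return 0;
}

Compiled with -DBLIS_ENABLE_NO_UNDERSCORE_API this prints dgemm; without the
flag it prints dgemm_, which is the difference the new CMake option exposes.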
AMD-Internal: [CPUPL-1848] Change-Id: I3d687cbc00e7603fdfb45937a00daf86bd07878e --- CMakeLists.txt | 14 ++ frame/compat/cblas/src/cblas_f77.h | 191 +++++++++++++++- frame/include/bli_macro_defs.h | 354 ++++++++++++++++++++++++++++- frame/util/bli_util_api_wrap.c | 9 +- frame/util/bli_util_api_wrap.h | 6 +- 5 files changed, 567 insertions(+), 7 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 8572aad640..3fa559abc6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -90,10 +90,16 @@ option(BLIS_ENABLE_ILP64 "ENABLE BLIS ILP64" OFF) option(ENABLE_INT_TYPE_SIZE " Internal BLIS integers ,used in native BLIS interfaces based on architecture dependent " ON) option(ENABLE_BLASTEST "Enable the blastest" OFF) option(ENABLE_TESTCPP_TESTING "Enabling testcpp" OFF) +option (ENABLE_NO_UNDERSCORE_API "export APIs without underscore" ON) +option (ENABLE_UPPERCASE_API "export APIs with uppercase" OFF) +option (ENABLE_API_WRAPPER "Enable wrapper code" OFF) option (ENABLE_COMPLEX_RETURN_INTEL "Enable complex_return_intel" OFF) option (ENABLE_TRSM_PREINVERSION "Enable TRSM preinversion" ON) option (ENABLE_AOCL_DYNAMIC "Enable Dynamic Multi-threading" OFF) +if(ENABLE_NO_UNDERSCORE_API) + add_definitions(-DBLIS_ENABLE_NO_UNDERSCORE_API) +endif() if(ENABLE_COMPLEX_RETURN_INTEL) set(BLIS_ENABLE_COMPLEX_RETURN_INTEL TRUE) @@ -101,6 +107,14 @@ else() set(BLIS_DISABLE_COMPLEX_RETURN_INTEL TRUE) endif() +if(ENABLE_UPPERCASE_API) + add_definitions(-DBLIS_ENABLE_UPPERCASE_API) +endif() + +if(ENABLE_API_WRAPPER) + add_definitions(-DBLIS_ENABLE_API_WRAPPER) +endif() + if(ENABLE_AOCL_DYNAMIC) set(AOCL_DYNAMIC TRUE) endif() diff --git a/frame/compat/cblas/src/cblas_f77.h b/frame/compat/cblas/src/cblas_f77.h index b09963eec7..fabf3efb1c 100644 --- a/frame/compat/cblas/src/cblas_f77.h +++ b/frame/compat/cblas/src/cblas_f77.h @@ -14,6 +14,195 @@ #ifndef CBLAS_F77_H #define CBLAS_F77_H +#if defined(BLIS_ENABLE_NO_UNDERSCORE_API) + /* + * Level 1 BLAS + */ +#define F77_xerbla xerbla +#define F77_srotg srotg +#define F77_srotmg srotmg +#define F77_srot srot +#define F77_srotm srotm +#define F77_drotg drotg +#define F77_drotmg drotmg +#define F77_drot drot +#define F77_drotm drotm +#define F77_sswap sswap +#define F77_scopy scopy +#define F77_saxpy saxpy +#define F77_isamax_sub isamaxsub +#define F77_dswap dswap +#define F77_dcopy dcopy +#define F77_daxpy daxpy +#define F77_idamax_sub idamaxsub +#define F77_cswap cswap +#define F77_ccopy ccopy +#define F77_caxpy caxpy +#define F77_icamax_sub icamaxsub +#define F77_zswap zswap +#define F77_zcopy zcopy +#define F77_zaxpy zaxpy +#define F77_zaxpby zaxpby +#define F77_izamax_sub izamaxsub +#define F77_sdot_sub sdotsub +#define F77_ddot_sub ddotsub +#define F77_dsdot_sub dsdotsub +#define F77_sscal sscal +#define F77_dscal dscal +#define F77_cscal cscal +#define F77_zscal zscal +#define F77_csscal csscal +#define F77_zdscal zdscal +#define F77_cdotu_sub cdotusub +#define F77_cdotc_sub cdotcsub +#define F77_zdotu_sub zdotusub +#define F77_zdotc_sub zdotcsub +#define F77_snrm2_sub snrm2sub +#define F77_sasum_sub sasumsub +#define F77_dnrm2_sub dnrm2sub +#define F77_dasum_sub dasumsub +#define F77_scnrm2_sub scnrm2sub +#define F77_scasum_sub scasumsub +#define F77_dznrm2_sub dznrm2sub +#define F77_dzasum_sub dzasumsub +#define F77_sdsdot_sub sdsdotsub +/* +* Level 2 BLAS +*/ +#define F77_ssymv ssymv +#define F77_ssbmv ssbmv +#define F77_sspmv sspmv +#define F77_sger sger +#define F77_ssyr ssyr +#define F77_sspr sspr +#define F77_ssyr2 ssyr2 +#define F77_sspr2 sspr2 
+#define F77_dsymv dsymv +#define F77_dsbmv dsbmv +#define F77_dspmv dspmv +#define F77_dger dger +#define F77_dsyr dsyr +#define F77_dspr dspr +#define F77_dsyr2 dsyr2 +#define F77_dspr2 dspr2 +#define F77_chemv chemv +#define F77_chbmv chbmv +#define F77_chpmv chpmv +#define F77_cgeru cgeru +#define F77_cgerc cgerc +#define F77_cher cher +#define F77_chpr chpr +#define F77_cher2 cher2 +#define F77_chpr2 chpr2 +#define F77_zhemv zhemv +#define F77_zhbmv zhbmv +#define F77_zhpmv zhpmv +#define F77_zgeru zgeru +#define F77_zgerc zgerc +#define F77_zher zher +#define F77_zhpr zhpr +#define F77_zher2 zher2 +#define F77_zhpr2 zhpr2 +#define F77_sgemv sgemv +#define F77_sgbmv sgbmv +#define F77_strmv strmv +#define F77_stbmv stbmv +#define F77_stpmv stpmv +#define F77_strsv strsv +#define F77_stbsv stbsv +#define F77_stpsv stpsv +#define F77_dgemv dgemv +#define F77_dgbmv dgbmv +#define F77_dtrmv dtrmv +#define F77_dtbmv dtbmv +#define F77_dtpmv dtpmv +#define F77_dtrsv dtrsv +#define F77_dtbsv dtbsv +#define F77_dtpsv dtpsv +#define F77_cgemv cgemv +#define F77_cgbmv cgbmv +#define F77_ctrmv ctrmv +#define F77_ctbmv ctbmv +#define F77_ctpmv ctpmv +#define F77_ctrsv ctrsv +#define F77_ctbsv ctbsv +#define F77_ctpsv ctpsv +#define F77_zgemv zgemv +#define F77_zgbmv zgbmv +#define F77_ztrmv ztrmv +#define F77_ztbmv ztbmv +#define F77_ztpmv ztpmv +#define F77_ztrsv ztrsv +#define F77_ztbsv ztbsv +#define F77_ztpsv ztpsv +/* +* Level 3 BLAS +*/ +#define F77_chemm chemm +#define F77_cherk cherk +#define F77_cher2k cher2k +#define F77_zhemm zhemm +#define F77_zherk zherk +#define F77_zher2k zher2k +#define F77_sgemm sgemm +#define F77_ssymm ssymm +#define F77_ssyrk ssyrk +#define F77_ssyr2k ssyr2k +#define F77_strmm strmm +#define F77_strsm strsm +#define F77_dgemm dgemm +#define F77_dsymm dsymm +#define F77_dsyrk dsyrk +#define F77_dsyr2k dsyr2k +#define F77_dtrmm dtrmm +#define F77_dtrsm dtrsm +#define F77_cgemm cgemm +#define F77_csymm csymm +#define F77_csyrk csyrk +#define F77_csyr2k csyr2k +#define F77_ctrmm ctrmm +#define F77_ctrsm ctrsm +#define F77_zgemm zgemm +#define F77_zsymm zsymm +#define F77_zsyrk zsyrk +#define F77_zsyr2k zsyr2k +#define F77_ztrmm ztrmm +#define F77_ztrsm ztrsm +#define F77_dgemmt dgemmt +#define F77_sgemmt sgemmt +#define F77_cgemmt cgemmt +#define F77_zgemmt zgemmt + +/* +* Aux Function +*/ +#define F77_scabs1 scabs1 +#define F77_dcabs1 dcabs1 + +/* + * -- BLAS Extension APIs -- + */ + +#define F77_saxpby saxpby +#define F77_daxpby daxpby +#define F77_caxpby caxpby +#define F77_zaxpby zaxpby +#define F77_cgemm3m cgemm3m +#define F77_zgemm3m zgemm3m + +#define F77_isamin_sub isaminsub +#define F77_idamin_sub idaminsub +#define F77_icamin_sub icaminsub +#define F77_izamin_sub izaminsub + +// -- Batch APIs -- +#define F77_sgemm_batch sgemm_batch +#define F77_dgemm_batch dgemm_batch +#define F77_cgemm_batch cgemm_batch +#define F77_zgemm_batch zgemm_batch + +// (BLIS_ENABLE_NO_UNDERSCORE_API) ends +#else /* * Level 1 BLAS */ @@ -201,4 +390,4 @@ #define F77_zgemm_batch zgemm_batch_ #endif -/* CBLAS_F77_H */ +#endif /* CBLAS_F77_H */ \ No newline at end of file diff --git a/frame/include/bli_macro_defs.h b/frame/include/bli_macro_defs.h index 9808590393..f29fdc1fe4 100644 --- a/frame/include/bli_macro_defs.h +++ b/frame/include/bli_macro_defs.h @@ -6,7 +6,7 @@ Copyright (C) 2014, The University of Texas at Austin Copyright (C) 2018-2021, Advanced Micro Devices, Inc. All rights reserved. 
- + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -156,12 +156,18 @@ #define STRINGIFY_INT( s ) MKSTR( s ) #define PASTEMACT(ch1, ch2, ch3, ch4) bli_ ## ch1 ## ch2 ## _ ## ch3 ## _ ## ch4 - +// name-mangling macros. +#ifdef BLIS_ENABLE_NO_UNDERSCORE_API +#define PASTEF770(name) name +#define PASTEF77(ch1,name) ch1 ## name +#define PASTEF772(ch1,ch2,name) ch1 ## ch2 ## name +#define PASTEF773(ch1,ch2,ch3,name) ch1 ## ch2 ## ch3 ## name +#else #define PASTEF770(name) name ## _ #define PASTEF77(ch1,name) ch1 ## name ## _ #define PASTEF772(ch1,ch2,name) ch1 ## ch2 ## name ## _ #define PASTEF773(ch1,ch2,ch3,name) ch1 ## ch2 ## ch3 ## name ## _ - +#endif // -- Include other groups of macros @@ -181,4 +187,346 @@ #include "bli_oapi_macro_defs.h" #include "bli_tapi_macro_defs.h" + +#ifdef BLIS_ENABLE_NO_UNDERSCORE_API +#define isamax_ isamax +#define idamax_ idamax +#define icamax_ icamax +#define izamax_ izamax +#define sasum_ sasum +#define dasum_ dasum +#define scasum_ scasum +#define dzasum_ dzasum +#define saxpy_ saxpy +#define daxpy_ daxpy +#define caxpy_ caxpy +#define zaxpy_ zaxpy +#define scopy_ scopy +#define dcopy_ dcopy +#define ccopy_ ccopy +#define zcopy_ zcopy +#define sdot_ sdot +#define ddot_ ddot +#define cdotc_ cdotc +#define zdotc_ zdotc +#define cdotu_ cdotu +#define zdotu_ zdotu +#define snrm2_ snrm2 +#define dnrm2_ dnrm2 +#define scnrm2_ scnrm2 +#define dznrm2_ dznrm2 +#define sscal_ sscal +#define dscal_ dscal +#define cscal_ cscal +#define csscal_ csscal +#define zscal_ zscal +#define zdscal_ zdscal +#define sswap_ sswap +#define dswap_ dswap +#define cswap_ cswap +#define zswap_ zswap +#define sgemv_ sgemv +#define dgemv_ dgemv +#define cgemv_ cgemv +#define zgemv_ zgemv +#define sger_ sger +#define dger_ dger +#define cgerc_ cgerc +#define cgeru_ cgeru +#define zgerc_ zgerc +#define zgeru_ zgeru +#define chemv_ chemv +#define zhemv_ zhemv +#define cher_ cher +#define zher_ zher +#define cher2_ cher2 +#define zher2_ zher2 +#define ssymv_ ssymv +#define dsymv_ dsymv +#define csymm_ csymm +#define zsymm_ zsymm +#define ssyr_ ssyr +#define dsyr_ dsyr +#define csyrk_ csyrk +#define csyrk_ csyrk +#define zsyrk_ zsyrk +#define ssyr2_ ssyr2 +#define dsyr2_ dsyr2 +#define csyr2k_ csyr2k +#define zsyr2k_ zsyr2k +#define strmv_ strmv +#define dtrmv_ dtrmv +#define ctrmv_ ctrmv +#define ztrmv_ ztrmv +#define strsv_ strsv +#define dtrsv_ dtrsv +#define ctrsv_ ctrsv +#define ztrsv_ ztrsv +#define sgemm_ sgemm +#define dgemm_ dgemm +#define cgemm_ cgemm +#define zgemm_ zgemm +#define chemm_ chemm +#define zhemm_ zhemm +#define dgemmt_ dgemmt +#define sgemmt_ sgemmt +#define zgemmt_ zgemmt +#define cgemmt_ cgemmt +#define sgemm_batch_ sgemm_batch +#define dgemm_batch_ dgemm_batch +#define cgemm_batch_ cgemm_batch +#define zgemm_batch_ zgemm_batch +#define saxpby_ saxpby +#define daxpby_ daxpby +#define caxpby_ caxpby +#define zaxpby_ zaxpby +#define cher2k_ cher2k +#define zher2k_ zher2k +#define cherk_ cherk +#define zherk_ zherk +#define ssymm_ ssymm +#define dsymm_ dsymm +#define ssyr2k_ ssyr2k +#define dsyr2k_ dsyr2k +#define ssyrk_ ssyrk +#define dsyrk_ dsyrk +#define strmm_ strmm +#define dtrmm_ dtrmm +#define ctrmm_ ctrmm +#define ztrmm_ ztrmm +#define strsm_ strsm +#define dtrsm_ dtrsm +#define ctrsm_ ctrsm +#define ztrsm_ ztrsm +#define lsame_ lsame +#define cimatcopy_ cimatcopy +#define comatadd_ comatadd +#define comatcopy2_ comatcopy2 +#define comatcopy_ comatcopy +#define 
dimatcopy_ dimatcopy +#define domatadd_ domatadd +#define domatcopy2_ domatcopy2 +#define domatcopy_ domatcopy +#define simatcopy_ simatcopy +#define somatadd_ somatadd +#define somatcopy2_ somatcopy2 +#define somatcopy_ somatcopy +#define zimatcopy_ zimatcopy +#define zomatadd_ zomatadd +#define zomatcopy2_ zomatcopy2 +#define zomatcopy_ zomatcopy #endif + +#ifdef BLIS_ENABLE_UPPERCASE_API +#define caxpby CAXPBY +#define caxpy CAXPY +#define ccopy CCOPY +#define cdotc CDOTC +#define cdotcsub CDOTCSUB +#define cdotu CDOTU +#define cdotusub CDOTUSUB +#define cgbmv CGBMV +#define cgemm CGEMM +#define cgemm3m CGEMM3M +#define cgemm_batch CGEMM_BATCH +#define cgemmt CGEMMT +#define cgemv CGEMV +#define cgerc CGERC +#define cgeru CGERU +#define chbmv CHBMV +#define chemm CHEMM +#define chemv CHEMV +#define cher CHER +#define cher2 CHER2 +#define cher2k CHER2K +#define cherk CHERK +#define chpmv CHPMV +#define chpr CHPR +#define chpr2 CHPR2 +#define cimatcopy CIMATCOPY +#define comatadd COMATADD +#define comatcopy2 COMATCOPY2 +#define comatcopy COMATCOPY +#define crotg CROTG +#define cscal CSCAL +#define csrot CSROT +#define csscal CSSCAL +#define cswap CSWAP +#define csymm CSYMM +#define csyr2k CSYR2K +#define csyrk CSYRK +#define ctbmv CTBMV +#define ctbsv CTBSV +#define ctpmv CTPMV +#define ctpsv CTPSV +#define ctrmm CTRMM +#define ctrmv CTRMV +#define ctrsm CTRSM +#define ctrsv CTRSV +#define dasum DASUM +#define dasumsub DASUMSUB +#define daxpby DAXPBY +#define daxpy DAXPY +#define dcabs1 DCABS1 +#define dcopy DCOPY +#define ddot DDOT +#define ddotsub DDOTSUB +#define dgbmv DGBMV +#define dgemm DGEMM +#define dgemm_batch DGEMM_BATCH +#define dgemmt DGEMMT +#define dgemv DGEMV +#define dger DGER +#define dnrm2 DNRM2 +#define dnrm2sub DNRM2SUB +#define dimatcopy DIMATCOPY +#define domatadd DOMATADD +#define domatcopy2 DOMATCOPY2 +#define domatcopy DOMATCOPY +#define drot DROT +#define drotg DROTG +#define drotm DROTM +#define drotmg DROTMG +#define dsbmv DSBMV +#define dscal DSCAL +#define dsdot DSDOT +#define dsdotsub DSDOTSUB +#define dspmv DSPMV +#define dspr DSPR +#define dspr2 DSPR2 +#define dswap DSWAP +#define dsymm DSYMM +#define dsymv DSYMV +#define dsyr DSYR +#define dsyr2 DSYR2 +#define dsyr2k DSYR2K +#define dsyrk DSYRK +#define dtbmv DTBMV +#define dtbsv DTBSV +#define dtpmv DTPMV +#define dtpsv DTPSV +#define dtrmm DTRMM +#define dtrmv DTRMV +#define dtrsm DTRSM +#define dtrsv DTRSV +#define dzasum DZASUM +#define dzasumsub DZASUMSUB +#define dznrm2 DZNRM2 +#define dznrm2sub DZNRM2SUB +#define icamax ICAMAX +#define icamaxsub ICAMAXSUB +#define icamin ICAMIN +#define icaminsub ICAMINSUB +#define idamax IDAMAX +#define idamaxsub IDAMAXSUB +#define idamin IDAMIN +#define idaminsub IDAMINSUB +#define isamax ISAMAX +#define isamaxsub ISAMAXSUB +#define isamin ISAMIN +#define isaminsub ISAMINSUB +#define izamax IZAMAX +#define izamaxsub IZAMAXSUB +#define izamin IZAMIN +#define izaminsub IZAMINSUB +#define lsame LSAME +#define sasum SASUM +#define sasumsub SASUMSUB +#define saxpby SAXPBY +#define saxpy SAXPY +#define scabs1 SCABS1 +#define scasum SCASUM +#define scasumsub SCASUMSUB +#define scnrm2 SCNRM2 +#define scnrm2sub SCNRM2SUB +#define scopy SCOPY +#define sdot SDOT +#define sdotsub SDOTSUB +#define sdsdot SDSDOT +#define sdsdotsub SDSDOTSUB +#define sgbmv SGBMV +#define sgemm SGEMM +#define sgemm_batch SGEMM_BATCH +#define sgemmt SGEMMT +#define sgemv SGEMV +#define sger SGER +#define snrm2 SNRM2 +#define snrm2sub SNRM2SUB +#define simatcopy SIMATCOPY +#define somatadd 
SOMATADD +#define somatcopy2 SOMATCOPY2 +#define somatcopy SOMATCOPY +#define srot SROT +#define srotg SROTG +#define srotm SROTM +#define srotmg SROTMG +#define ssbmv SSBMV +#define sscal SSCAL +#define sspmv SSPMV +#define sspr SSPR +#define sspr2 SSPR2 +#define sswap SSWAP +#define ssymm SSYMM +#define ssymv SSYMV +#define ssyr SSYR +#define ssyr2 SSYR2 +#define ssyr2k SSYR2K +#define ssyrk SSYRK +#define stbmv STBMV +#define stbsv STBSV +#define stpmv STPMV +#define stpsv STPSV +#define strmm STRMM +#define strmv STRMV +#define strsm STRSM +#define strsv STRSV +#define xerbla XERBLA +#define zaxpby ZAXPBY +#define zaxpy ZAXPY +#define zcopy ZCOPY +#define zdotc ZDOTC +#define zdotcsub ZDOTCSUB +#define zdotu ZDOTU +#define zdotusub ZDOTUSUB +#define zdrot ZDROT +#define zdscal ZDSCAL +#define zgbmv ZGBMV +#define zgemm ZGEMM +#define zgemm3m ZGEMM3M +#define zgemm_batch ZGEMM_BATCH +#define zgemmt ZGEMMT +#define zgemv ZGEMV +#define zgerc ZGERC +#define zgeru ZGERU +#define zhbmv ZHBMV +#define zhemm ZHEMM +#define zhemv ZHEMV +#define zher ZHER +#define zher2 ZHER2 +#define zher2k ZHER2K +#define zherk ZHERK +#define zhpmv ZHPMV +#define zhpr ZHPR +#define zhpr2 ZHPR2 +#define zimatcopy ZIMATCOPY +#define zomatadd ZOMATADD +#define zomatcopy2 ZOMATCOPY2 +#define zomatcopy ZOMATCOPY +#define zrotg ZROTG +#define zscal ZSCAL +#define zswap ZSWAP +#define zsymm ZSYMM +#define zsyr2k ZSYR2K +#define zsyrk ZSYRK +#define ztbmv ZTBMV +#define ztbsv ZTBSV +#define ztpmv ZTPMV +#define ztpsv ZTPSV +#define ztrmm ZTRMM +#define ztrmv ZTRMV +#define ztrsm ZTRSM +#define ztrsv ZTRSV +#endif + +#endif + diff --git a/frame/util/bli_util_api_wrap.c b/frame/util/bli_util_api_wrap.c index 393a56e143..128fba8b87 100644 --- a/frame/util/bli_util_api_wrap.c +++ b/frame/util/bli_util_api_wrap.c @@ -32,11 +32,14 @@ */ +// file define different formats of BLAS APIs- uppercase with +// and without underscore, lowercase without underscore. + #include "blis.h" #include "bli_util_api_wrap.h" // wrapper functions to support additional symbols - +#ifdef BLIS_ENABLE_API_WRAPPER void CAXPY(const f77_int *n,const scomplex *ca,const scomplex *cx,const f77_int *incx,scomplex *cy,const f77_int *incy) { caxpy_( n, ca, cx, incx, cy, incy); @@ -3215,4 +3218,6 @@ void caxpby( const f77_int* n, const scomplex* alpha, const scomplex *x, cons void CAXPBY_( const f77_int* n, const scomplex* alpha, const scomplex *x, const f77_int* incx, const scomplex* beta, scomplex *y, const f77_int* incy) { caxpby_(n, alpha, x, incx, beta, y, incy); -} \ No newline at end of file +} + +#endif diff --git a/frame/util/bli_util_api_wrap.h b/frame/util/bli_util_api_wrap.h index 46d5a636a2..f0aff49ff2 100644 --- a/frame/util/bli_util_api_wrap.h +++ b/frame/util/bli_util_api_wrap.h @@ -32,6 +32,10 @@ */ +// file define different formats of BLAS APIs- uppercase with +// and without underscore, lowercase without underscore. + +#ifdef BLIS_ENABLE_API_WRAPPER //Level 1 APIs BLIS_EXPORT_BLIS void SROTG(float *sa, float *sb, float *c, float *s); @@ -1724,4 +1728,4 @@ BLIS_EXPORT_BLIS void zomatcopy(f77_char* trans, f77_int* rows, f77_int* cols, BLIS_EXPORT_BLIS void ZOMATCOPY_(f77_char* trans, f77_int* rows, f77_int* cols, const dcomplex* alpha, const dcomplex* aptr, f77_int* lda, dcomplex* bptr, f77_int* ldb); - +#endif From 15b7fff1598a3d2119d25e3f788fe9eaee8e58ab Mon Sep 17 00:00:00 2001 From: Nallani Bhaskar Date: Mon, 13 Sep 2021 22:03:28 +0530 Subject: [PATCH 013/243] Fixed reading C when beta=0 in few sgemm asm kernels Description: 1. 
When beta is zero we should not be doing any arthemetic operation on C data and should not assume anything on values of C matrix 2. It is taken care in all sgemmsup kernels already except in bli_sgemmsup_rv_zen_asm_2x8 and bli_sgemmsup_rv_zen_asm_3x8 kernels, when beta zero and C is column storage case. Fixed this issue by removing reading C matrix in these kernels. 3. When C has NaN or Inf and when we multiply NaN or Inf with zero (beta) the result becomes NaN only. Change-Id: I3fb8c0cd37cf1d52a7909f6b402aa9c40c7c3846 --- .../zen/3/sup/bli_gemmsup_rv_zen_asm_s6x16.c | 34 ------------------- 1 file changed, 34 deletions(-) diff --git a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_s6x16.c b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_s6x16.c index 32d5b65841..347384aa65 100644 --- a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_s6x16.c +++ b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_s6x16.c @@ -4125,33 +4125,16 @@ void bli_sgemmsup_rv_zen_asm_3x8 vshufpd(imm(0x01), xmm0, xmm0, xmm1)//a1b1 vshufpd(imm(0x01), xmm2, xmm2, xmm10)//a3b3 - vmovsd(mem(rcx),xmm4) - vmovsd(mem(rcx, rsi, 1),xmm6) - vfmadd231ps(xmm4, xmm3, xmm0) - vfmadd231ps(xmm6, xmm3, xmm1) vmovsd(xmm0, mem(rcx)) // store ( gamma00..gamma10 ) vmovsd(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) - vmovsd(mem(rcx, rsi, 2),xmm4) - vmovsd(mem(rcx, rax, 1),xmm6) - vfmadd231ps(xmm4, xmm3, xmm2) - vfmadd231ps(xmm6, xmm3, xmm10) vmovsd(xmm2, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) vmovsd(xmm10, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) lea(mem(rcx, rsi, 4), rcx) // rcx += cs_c vshufpd(imm(0x01), xmm11, xmm11, xmm1)//a1b1 vshufpd(imm(0x01), xmm12, xmm12, xmm10)//a3b3 - vmovsd(mem(rcx),xmm4) - vmovsd(mem(rcx, rsi, 1),xmm6) - vfmadd231ps(xmm4, xmm3, xmm11) - vfmadd231ps(xmm6, xmm3, xmm1) vmovsd(xmm11, mem(rcx)) // store ( gamma00..gamma10 ) vmovsd(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) - - vmovsd(mem(rcx, rsi, 2),xmm4) - vmovsd(mem(rcx, rax, 1),xmm6) - vfmadd231ps(xmm4, xmm3, xmm12) - vfmadd231ps(xmm6, xmm3, xmm10) vmovsd(xmm12, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) vmovsd(xmm10, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) @@ -4474,33 +4457,16 @@ void bli_sgemmsup_rv_zen_asm_2x8 vshufpd(imm(0x01), xmm0, xmm0, xmm1)//a1b1 vshufpd(imm(0x01), xmm2, xmm2, xmm10)//a3b3 - vmovsd(mem(rcx),xmm4) - vmovsd(mem(rcx, rsi, 1),xmm6) - vfmadd231ps(xmm4, xmm3, xmm0) - vfmadd231ps(xmm6, xmm3, xmm1) vmovsd(xmm0, mem(rcx)) // store ( gamma00..gamma10 ) vmovsd(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) - vmovsd(mem(rcx, rsi, 2),xmm4) - vmovsd(mem(rcx, rax, 1),xmm6) - vfmadd231ps(xmm4, xmm3, xmm2) - vfmadd231ps(xmm6, xmm3, xmm10) vmovsd(xmm2, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) vmovsd(xmm10, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) lea(mem(rcx, rsi, 4), rcx) // rcx += cs_c vshufpd(imm(0x01), xmm11, xmm11, xmm1)//a1b1 vshufpd(imm(0x01), xmm12, xmm12, xmm10)//a3b3 - vmovsd(mem(rcx),xmm4) - vmovsd(mem(rcx, rsi, 1),xmm6) - vfmadd231ps(xmm4, xmm3, xmm11) - vfmadd231ps(xmm6, xmm3, xmm1) vmovsd(xmm11, mem(rcx)) // store ( gamma00..gamma10 ) vmovsd(xmm1, mem(rcx, rsi, 1)) // store ( gamma01..gamma11 ) - - vmovsd(mem(rcx, rsi, 2),xmm4) - vmovsd(mem(rcx, rax, 1),xmm6) - vfmadd231ps(xmm4, xmm3, xmm12) - vfmadd231ps(xmm6, xmm3, xmm10) vmovsd(xmm12, mem(rcx, rsi, 2)) // store ( gamma02..gamma12 ) vmovsd(xmm10, mem(rcx, rax, 1)) // store ( gamma03..gamma13 ) From f1291fc957128478b805ff4d555eee7206d4e5e1 Mon Sep 17 00:00:00 2001 From: Meghana Vankadari Date: Wed, 8 Sep 2021 12:18:10 +0530 Subject: [PATCH 
014/243] Disabled dzgemm blas interfaces AMD-Internal: [CPUPL-1825] Change-Id: I786a2ad83be418bd4c55d24e30ae6c3a86fd8844 --- frame/compat/bla_gemm.c | 6 +++++- frame/compat/bla_gemm.h | 4 +++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/frame/compat/bla_gemm.c b/frame/compat/bla_gemm.c index 0e77c3bb1f..30e3b58e8e 100644 --- a/frame/compat/bla_gemm.c +++ b/frame/compat/bla_gemm.c @@ -719,6 +719,10 @@ INSERT_GENTFUNC_BLAS_SC( gemm, gemm ) #else INSERT_GENTFUNC_BLAS( gemm,gemm ) #endif + +// Observed a regression in dgemm with this function addition. +// Disabling temporarily. +#if 0 void dzgemm_ ( const f77_char* transa, @@ -808,5 +812,5 @@ void dzgemm_ /* Finalize BLIS. */ bli_finalize_auto(); }// end of dzgemm_ - +#endif #endif diff --git a/frame/compat/bla_gemm.h b/frame/compat/bla_gemm.h index 425b01e7b0..25aef8d11f 100644 --- a/frame/compat/bla_gemm.h +++ b/frame/compat/bla_gemm.h @@ -54,6 +54,8 @@ BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ ); #ifdef BLIS_ENABLE_BLAS +// Disabling temporarily +#if 0 BLIS_EXPORT_BLAS void dzgemm_ ( const f77_char* transa, \ @@ -67,7 +69,7 @@ BLIS_EXPORT_BLAS void dzgemm_ const dcomplex* beta, \ dcomplex* c, const f77_int* ldc \ ); - +#endif INSERT_GENTPROT_BLAS( gemm ) #endif From 3dda4ebf229e4f2ab2c13cc1bb82b42f95ed76d9 Mon Sep 17 00:00:00 2001 From: Madan mohan Manokar Date: Fri, 10 Sep 2021 09:12:59 +0530 Subject: [PATCH 015/243] Induced method turned off, fix for beta=0 & C = NAN 1. Induced Method turned off, till the path fully tested for different alpha,beta conditions. 2. Fix for Beta =0, and C = NAN done. Change-Id: I5a7bd1393ac245c2ebb72f9a634728af4c0d4000 --- frame/compat/bla_gemm.c | 2 +- kernels/zen/3/bli_gemm_sqp_kernels.c | 17 ++++++++++++++++- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/frame/compat/bla_gemm.c b/frame/compat/bla_gemm.c index 30e3b58e8e..faa7459a49 100644 --- a/frame/compat/bla_gemm.c +++ b/frame/compat/bla_gemm.c @@ -38,7 +38,7 @@ // // Define BLAS-to-BLIS interfaces. // -#define ENABLE_INDUCED_METHOD 1 +#define ENABLE_INDUCED_METHOD 0 #ifdef BLIS_BLAS3_CALLS_TAPI #undef GENTFUNC diff --git a/kernels/zen/3/bli_gemm_sqp_kernels.c b/kernels/zen/3/bli_gemm_sqp_kernels.c index 9cac5e83eb..0f20c0a956 100644 --- a/kernels/zen/3/bli_gemm_sqp_kernels.c +++ b/kernels/zen/3/bli_gemm_sqp_kernels.c @@ -1278,7 +1278,22 @@ void bli_3m_sqp_packC_real_imag(double* pc, } } } - else /* handles alpha or beta is not equal +/- 1.0 */ + else if(mul==0) /* handles alpha or beta is equal to zero */ + { + double br_ = 0; + double bi_ = 0; + for (j = 0; j < n; j++) + { + for (p = 0; p < (m*2); p += 2)// (real + imag)*m + { + *pcr = br_; + *pci = bi_; + pcr++; pci++; + } + pc = pc + ldc; + } + } + else /* handles alpha or beta is not equal +/- 1.0 and zero */ { for (j = 0; j < n; j++) { From 628f4a41b8c307b6cbf0c9d467b705c3548c7400 Mon Sep 17 00:00:00 2001 From: Kiran Varaganti Date: Wed, 15 Sep 2021 12:02:28 +0530 Subject: [PATCH 016/243] Fixed bug in packing API bli_pack_set_pack_b() API wrongly inits pack_a element of rntm object instead of pack_b. Fixed this bug. Change-Id: I267493ab3ff0bade478d1157799a6fa5b51c7970 --- frame/base/bli_pack.c | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/frame/base/bli_pack.c b/frame/base/bli_pack.c index 5f4cca575e..382b960d4b 100644 --- a/frame/base/bli_pack.c +++ b/frame/base/bli_pack.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. 
+ Copyright (C) 2018-2021, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -101,7 +101,7 @@ void bli_pack_set_pack_b( bool pack_b ) // Acquire the mutex protecting global_rntm. bli_pthread_mutex_lock( &global_rntm_mutex ); - bli_rntm_set_pack_a( pack_b, &global_rntm ); + bli_rntm_set_pack_b( pack_b, &global_rntm ); // Release the mutex protecting global_rntm. bli_pthread_mutex_unlock( &global_rntm_mutex ); From 09bfe3e37236c5ecdd31f5991801790eee3ade12 Mon Sep 17 00:00:00 2001 From: Kiran Varaganti Date: Mon, 6 Sep 2021 17:07:19 +0530 Subject: [PATCH 017/243] Improve DGEMM sup for RCR case When m and n are very large values, larger KC is preferred to reduce L3 cache misses, therefore increased KC value to KC0. This improved performance of DGEMM considerably on EPYC processors Kindly note (PACKA and PACKB disabled). Change-Id: I7b3f5b53a01fca0e55bc4479eeb644c7001d3463 --- frame/3/bli_l3_sup_var1n2m.c | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/frame/3/bli_l3_sup_var1n2m.c b/frame/3/bli_l3_sup_var1n2m.c index 467ca31e6a..e8c4d7845b 100644 --- a/frame/3/bli_l3_sup_var1n2m.c +++ b/frame/3/bli_l3_sup_var1n2m.c @@ -971,7 +971,8 @@ void PASTEMAC(ch,varname) \ else if ( m <= 2*MR && n <= 2*NR ) KC = KC0 / 2; \ else if ( m <= 3*MR && n <= 3*NR ) KC = (( KC0 / 3 ) / 4 ) * 4; \ else if ( m <= 4*MR && n <= 4*NR ) KC = KC0 / 4; \ - else KC = (( KC0 / 5 ) / 4 ) * 4; \ + /* Will revisit setting this KC value - larger m and n demands larger KC */ \ + else KC = KC0; /* (( KC0 / 5 ) / 4 ) * 4; VK */ \ } \ \ /* Query the maximum blocksize for NR, which implies a maximum blocksize From 5ecb4fd0dbccef842d7efe0a0c92933b383dbb05 Mon Sep 17 00:00:00 2001 From: Kiran Varaganti Date: Sun, 19 Sep 2021 20:41:31 +0530 Subject: [PATCH 018/243] Removed dead code This sanity check (checking top_index != 0) which has been disabled earlier in bli_pool_reinit() by commenting out bli_abort() was unnecessarily computing top_index - this whole statement is commented out. Change-Id: If296754ca8cba3a69d023d4a7ec891f1cbce1d6a --- frame/base/bli_pool.c | 3 +++ 1 file changed, 3 insertions(+) diff --git a/frame/base/bli_pool.c b/frame/base/bli_pool.c index d994fe9782..7e561983c6 100644 --- a/frame/base/bli_pool.c +++ b/frame/base/bli_pool.c @@ -122,6 +122,7 @@ void bli_pool_finalize // Query the total number of blocks currently allocated. const siz_t num_blocks = bli_pool_num_blocks( pool ); +#if 0 // Removing dead code // Query the top_index of the pool. const siz_t top_index = bli_pool_top_index( pool ); @@ -142,6 +143,8 @@ void bli_pool_finalize //bli_abort(); } +#endif + // Query the free() function pointer for the pool. free_ft free_fp = bli_pool_free_fp( pool ); From 7196b86f0514be71a91ff950094b4e9e16f60f52 Mon Sep 17 00:00:00 2001 From: Kiran Varaganti Date: Mon, 20 Sep 2021 12:51:32 +0530 Subject: [PATCH 019/243] Removed packm kernels of zen intrinsic optimized packm kernels written for zen are no longer used. Therefore removing it. Currently packm kernels from haswell configuration are being used for zen2 and zen3 configs. 
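As an editorial aside for readers following the performance changes in this
series: after the "Improve DGEMM sup for RCR case" patch above, the KC
selection visible in that hunk reduces to roughly the following (illustrative
sketch only, not the actual BLIS control path; KC0, MR and NR stand for the
context-supplied cache and register blocksizes, and the initial small-m/n
branch is omitted because it lies outside the hunk):

    if      ( m <= 2*MR && n <= 2*NR ) KC = KC0 / 2;
    else if ( m <= 3*MR && n <= 3*NR ) KC = (( KC0 / 3 ) / 4 ) * 4; /* keep KC a multiple of 4 */
    else if ( m <= 4*MR && n <= 4*NR ) KC = KC0 / 4;
    else                               KC = KC0; /* large m and n: full KC0 to reduce L3 misses */

The only behavioural change is the final branch, which previously used
(( KC0 / 5 ) / 4 ) * 4.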
--- kernels/zen/1m/CMakeLists.txt | 6 - kernels/zen/1m/bli_packm_zen_int.c | 544 ----------------------------- kernels/zen/CMakeLists.txt | 2 +- 3 files changed, 1 insertion(+), 551 deletions(-) delete mode 100644 kernels/zen/1m/CMakeLists.txt delete mode 100644 kernels/zen/1m/bli_packm_zen_int.c diff --git a/kernels/zen/1m/CMakeLists.txt b/kernels/zen/1m/CMakeLists.txt deleted file mode 100644 index 6af90a6a46..0000000000 --- a/kernels/zen/1m/CMakeLists.txt +++ /dev/null @@ -1,6 +0,0 @@ -##Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.## - -target_sources("${PROJECT_NAME}" - PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/bli_packm_zen_int.c - ) diff --git a/kernels/zen/1m/bli_packm_zen_int.c b/kernels/zen/1m/bli_packm_zen_int.c deleted file mode 100644 index 174557efdc..0000000000 --- a/kernels/zen/1m/bli_packm_zen_int.c +++ /dev/null @@ -1,544 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2020, Advanced Micro Devices, Inc. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "immintrin.h" -#include "blis.h" - - -// Union data structure to access AVX registers -// One 256-bit AVX register holds 4 DP elements. 
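/*
 * Editorial note (illustrative sketch, not part of this file): a union like
 * the v4df_t defined below provides both a vector view (.v) and a per-lane
 * scalar view (.d[]) of one YMM register.  A minimal usage example of the
 * pattern (hypothetical helper, assumes AVX, immintrin.h, and the v4df_t
 * definition that follows):
 */
static inline double first_lane_sketch( __m256d x )
{
    v4df_t t;        /* union of __m256d and double[4], as defined below */
    t.v = x;
    return t.d[0];   /* lowest 64-bit lane */
}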
-typedef union -{ - __m256d v; - double d[4] __attribute__((aligned(64))); -} v4df_t; - - - - -// packing routine for dgemm/trsm -// when op(A) = n & op(B) = n -void bli_dpackm_8xk_nn_zen -( - conj_t conja, - pack_t schema, - dim_t cdim, - dim_t n, - dim_t n_max, - double* restrict kappa, - double* restrict a, inc_t inca, inc_t lda, // inca = 1 - double* restrict p, inc_t ldp, - cntx_t* restrict cntx -) -{ - double* restrict alpha1 = a; - double* restrict pi1 = p; - - dim_t n_iter = n / 2; - dim_t n_left = n % 2; - - if (cdim == 8) - { - // (*kappa_cast) = 1.0 for GEMM - __m256d ymmSrc_0_0123; // source registers - __m256d ymmSrc_0_4567; - __m256d ymmSrc_1_0123; - __m256d ymmSrc_1_4567; - - for (; n_iter != 0; --n_iter) - { - // Works when inca = 1, which is the case for op(A) = n and op(B) = n - ymmSrc_0_0123 = _mm256_loadu_pd(alpha1 + 0 * inca + 0 * lda); - ymmSrc_0_4567 = _mm256_loadu_pd(alpha1 + 4 * inca + 0 * lda); - ymmSrc_1_0123 = _mm256_loadu_pd(alpha1 + 0 * inca + 1 * lda); - ymmSrc_1_4567 = _mm256_loadu_pd(alpha1 + 4 * inca + 1 * lda); - - // Store -#if 1 - _mm256_storeu_pd((pi1 + 0 + 0 * ldp), ymmSrc_0_0123); - _mm256_storeu_pd((pi1 + 4 + 0 * ldp), ymmSrc_0_4567); - - _mm256_storeu_pd((pi1 + 0 + 1 * ldp), ymmSrc_1_0123); - _mm256_storeu_pd((pi1 + 4 + 1 * ldp), ymmSrc_1_4567); -#else - _mm256_stream_pd((pi1 + 0), ymmSrc_0_0123); - _mm256_stream_pd((pi1 + 4), ymmSrc_0_4567); - - _mm256_stream_pd((pi1 + 0 + 1 * ldp), ymmSrc_1_0123); - _mm256_stream_pd((pi1 + 4 + 1 * ldp), ymmSrc_1_4567); -#endif - alpha1 += 2 * lda; - pi1 += 2 * ldp; - } - - if (n_left & 1) //for (; n_left != 0; --n_left) - { - ymmSrc_0_0123 = _mm256_loadu_pd(alpha1 + 0 * inca); - ymmSrc_0_4567 = _mm256_loadu_pd(alpha1 + 4 * inca); - - _mm256_storeu_pd((pi1 + 0), ymmSrc_0_0123); - _mm256_storeu_pd((pi1 + 4), ymmSrc_0_4567); - - alpha1 += lda; - pi1 += ldp; - } - } - else /* if ( cdim < mnr ) */ - { - double* restrict a_cast = a; - double* restrict p_cast = p; - // (*kappa_cast == 1.0) for GEMM - - PRAGMA_SIMD - for (dim_t j = 0; j < n; ++j) - for (dim_t i = 0; i < cdim; ++i) - p_cast[i + j*ldp] = a_cast[i + j*lda]; - - - const dim_t i = cdim; - const dim_t m_edge = 8 - cdim; - const dim_t n_edge = n_max; - // double* restrict p_cast = p; - double* restrict p_edge = p_cast + (i) * 1; - - PRAGMA_SIMD - for (dim_t j = 0; j < n_edge; ++j) - for (dim_t i = 0; i < m_edge; ++i) - *(p_edge + i + j*ldp) = 0.0; - } - - if (n < n_max) - { - const dim_t j = n; - const dim_t m_edge = 8; - const dim_t n_edge = n_max - n; - double* restrict p_cast = p; - double* restrict p_edge = p_cast + (j)*ldp; - - PRAGMA_SIMD - for (dim_t j = 0; j < n_edge; ++j) - for (dim_t i = 0; i < m_edge; ++i) - *(p_edge + i + j*ldp) = 0.0; - } -}// End of function - - -void bli_dpackm_6xk_nn_zen -( - conj_t conja, - pack_t schema, - dim_t cdim, - dim_t n, - dim_t n_max, - double* restrict kappa, - double* restrict a, inc_t inca, inc_t lda, - double* restrict p, inc_t ldp, - cntx_t* restrict cntx -) -{ - double* restrict alpha1 = a; - double* restrict pi1 = p; - - if (cdim == 6) - { - //if ( (*kappa_cast) == 1.0 ) // Kappa_cast = 1.0 for dgemm - for (dim_t k = n; k != 0; --k) - { - (*(pi1 + 0)) = (*(alpha1 + 0 * inca)); - (*(pi1 + 1)) = (*(alpha1 + 1 * inca)); - (*(pi1 + 2)) = (*(alpha1 + 2 * inca)); - (*(pi1 + 3)) = (*(alpha1 + 3 * inca)); - - (*(pi1 + 4)) = (*(alpha1 + 4 * inca)); - (*(pi1 + 5)) = (*(alpha1 + 5 * inca)); - - alpha1 += lda; - pi1 += ldp; - } - } - else /* if ( cdim < mnr ) */ - { - double* restrict a_cast = a; - double* restrict p_cast 
= p; - - // (*kappa_cast == 1.0) for GEMM - // a will be in row-major, inca != 1 and lda = 1 - PRAGMA_SIMD - for (dim_t i = 0; i < cdim; ++i) - for(dim_t j = 0; j < n; ++j) - p_cast[i + j*ldp] = a_cast[i * inca + j]; // i * inca + j * lda, lda = 1 - - - const dim_t m_edge = 6 - cdim; - const dim_t n_edge = n_max; - // double* restrict p_cast = p; - double* restrict p_edge = p_cast + (cdim) * 1; - - PRAGMA_SIMD - for (dim_t j = 0; j < n_edge; ++j) - for (dim_t i = 0; i < m_edge; ++i) - *(p_edge + i + j*ldp) = 0.0; - } - - if (n < n_max) - { - const dim_t j = n; - const dim_t m_edge = 6; - const dim_t n_edge = n_max - n; - double* restrict p_cast = p; - double* restrict p_edge = p_cast + (j)*ldp; - - PRAGMA_SIMD - for (dim_t j = 0; j < n_edge; ++j) - for (dim_t i = 0; i < m_edge; ++i) - *(p_edge + i + j*ldp) = 0.0; - } -}// end of function - - - -// Packing routine for general operations op(A) = ? op(B) = ? -void bli_dpackm_8xk_gen_zen -( - conj_t conja, - pack_t schema, - dim_t cdim, - dim_t n, - dim_t n_max, - double* restrict kappa, - double* restrict a, inc_t inca, inc_t lda, - double* restrict p, inc_t ldp, - cntx_t* restrict cntx -) -{ - double* restrict kappa_cast = kappa; - double* restrict alpha1 = a; - double* restrict pi1 = p; - - dim_t n_iter = n / 2; - dim_t n_left = n % 2; - - if (cdim == 8) - { - if ((*kappa_cast) == (1.0)) - { - if (bli_is_conj(conja)) - { - //for (dim_t k = n; k != 0; --k) - for (dim_t k = n; k--;) - { - (((*(pi1 + 0)))) = (((*(alpha1 + 0 * inca)))); - (((*(pi1 + 1)))) = (((*(alpha1 + 1 * inca)))); - (((*(pi1 + 2)))) = (((*(alpha1 + 2 * inca)))); - (((*(pi1 + 3)))) = (((*(alpha1 + 3 * inca)))); - (((*(pi1 + 4)))) = (((*(alpha1 + 4 * inca)))); - (((*(pi1 + 5)))) = (((*(alpha1 + 5 * inca)))); - (((*(pi1 + 6)))) = (((*(alpha1 + 6 * inca)))); - (((*(pi1 + 7)))) = (((*(alpha1 + 7 * inca)))); - - alpha1 += lda; - pi1 += ldp; - } - } - else - { - for (; n_iter != 0; --n_iter) - { - - ((*(pi1 + 0 + 0 * ldp))) = ((*(alpha1 + 0 * inca + 0 * lda))); - ((*(pi1 + 1 + 0 * ldp))) = ((*(alpha1 + 1 * inca + 0 * lda))); - ((*(pi1 + 2 + 0 * ldp))) = ((*(alpha1 + 2 * inca + 0 * lda))); - ((*(pi1 + 3 + 0 * ldp))) = ((*(alpha1 + 3 * inca + 0 * lda))); - ((*(pi1 + 4 + 0 * ldp))) = ((*(alpha1 + 4 * inca + 0 * lda))); - ((*(pi1 + 5 + 0 * ldp))) = ((*(alpha1 + 5 * inca + 0 * lda))); - ((*(pi1 + 6 + 0 * ldp))) = ((*(alpha1 + 6 * inca + 0 * lda))); - ((*(pi1 + 7 + 0 * ldp))) = ((*(alpha1 + 7 * inca + 0 * lda))); - - ((*(pi1 + 0 + 1 * ldp))) = ((*(alpha1 + 0 * inca + 1 * lda))); - ((*(pi1 + 1 + 1 * ldp))) = ((*(alpha1 + 1 * inca + 1 * lda))); - ((*(pi1 + 2 + 1 * ldp))) = ((*(alpha1 + 2 * inca + 1 * lda))); - ((*(pi1 + 3 + 1 * ldp))) = ((*(alpha1 + 3 * inca + 1 * lda))); - ((*(pi1 + 4 + 1 * ldp))) = ((*(alpha1 + 4 * inca + 1 * lda))); - ((*(pi1 + 5 + 1 * ldp))) = ((*(alpha1 + 5 * inca + 1 * lda))); - ((*(pi1 + 6 + 1 * ldp))) = ((*(alpha1 + 6 * inca + 1 * lda))); - ((*(pi1 + 7 + 1 * ldp))) = ((*(alpha1 + 7 * inca + 1 * lda))); - - alpha1 += 2 * lda; - pi1 += 2 * ldp; - } - - //for (; n_left != 0; --n_left) - if (n_left == 1) - { - ((*(pi1 + 0))) = ((*(alpha1 + 0 * inca))); - ((*(pi1 + 1))) = ((*(alpha1 + 1 * inca))); - ((*(pi1 + 2))) = ((*(alpha1 + 2 * inca))); - ((*(pi1 + 3))) = ((*(alpha1 + 3 * inca))); - ((*(pi1 + 4))) = ((*(alpha1 + 4 * inca))); - ((*(pi1 + 5))) = ((*(alpha1 + 5 * inca))); - ((*(pi1 + 6))) = ((*(alpha1 + 6 * inca))); - ((*(pi1 + 7))) = ((*(alpha1 + 7 * inca))); - - alpha1 += lda; - pi1 += ldp; - } - } - } - else - { - if (bli_is_conj(conja)) - { - for (dim_t k = n; k 
!= 0; --k) - { - ((*(pi1 + 0))) = ((*kappa_cast)) * ((*(alpha1 + 0 * inca))); - ((*(pi1 + 1))) = ((*kappa_cast)) * ((*(alpha1 + 1 * inca))); - ((*(pi1 + 2))) = ((*kappa_cast)) * ((*(alpha1 + 2 * inca))); - ((*(pi1 + 3))) = ((*kappa_cast)) * ((*(alpha1 + 3 * inca))); - ((*(pi1 + 4))) = ((*kappa_cast)) * ((*(alpha1 + 4 * inca))); - ((*(pi1 + 5))) = ((*kappa_cast)) * ((*(alpha1 + 5 * inca))); - ((*(pi1 + 6))) = ((*kappa_cast)) * ((*(alpha1 + 6 * inca))); - ((*(pi1 + 7))) = ((*kappa_cast)) * ((*(alpha1 + 7 * inca))); - - alpha1 += lda; - pi1 += ldp; - } - } - else - { - for (dim_t k = n; k != 0; --k) - { - ((*(pi1 + 0))) = ((*kappa_cast)) * ((*(alpha1 + 0 * inca))); - ((*(pi1 + 1))) = ((*kappa_cast)) * ((*(alpha1 + 1 * inca))); - ((*(pi1 + 2))) = ((*kappa_cast)) * ((*(alpha1 + 2 * inca))); - ((*(pi1 + 3))) = ((*kappa_cast)) * ((*(alpha1 + 3 * inca))); - ((*(pi1 + 4))) = ((*kappa_cast)) * ((*(alpha1 + 4 * inca))); - ((*(pi1 + 5))) = ((*kappa_cast)) * ((*(alpha1 + 5 * inca))); - ((*(pi1 + 6))) = ((*kappa_cast)) * ((*(alpha1 + 6 * inca))); - ((*(pi1 + 7))) = ((*kappa_cast)) * ((*(alpha1 + 7 * inca))); - - alpha1 += lda; - pi1 += ldp; - } - } - } - } - else /* if ( cdim < mnr ) */ - { - bli_dscal2m_ex - ( - 0, - BLIS_NONUNIT_DIAG, - BLIS_DENSE, - (trans_t)conja, - cdim, - n, - kappa, - a, inca, lda, - p, 1, ldp, - cntx, - NULL - ); - - /* if ( cdim < mnr ) */ - { - const dim_t i = cdim; - const dim_t m_edge = 8 - cdim; - const dim_t n_edge = n_max; - double* restrict p_cast = p; - double* restrict p_edge = p_cast + (i) * 1; - - bli_dset0s_mxn - ( - m_edge, - n_edge, - p_edge, 1, ldp - ); - } - } - - if (n < n_max) - { - const dim_t j = n; - const dim_t m_edge = 8; - const dim_t n_edge = n_max - n; - double* restrict p_cast = p; - double* restrict p_edge = p_cast + (j)*ldp; - - bli_dset0s_mxn - ( - m_edge, - n_edge, - p_edge, 1, ldp - ); - } -}// End of function - -// Packing routine for general operations -void bli_dpackm_6xk_gen_zen -( - conj_t conja, - pack_t schema, - dim_t cdim, - dim_t n, - dim_t n_max, - double* restrict kappa, - double* restrict a, inc_t inca, inc_t lda, - double* restrict p, inc_t ldp, - cntx_t* restrict cntx -) -{ - double* restrict kappa_cast = kappa; - double* restrict alpha1 = a; - double* restrict pi1 = p; - - if (cdim == 6) - { - if ((((*kappa_cast)) == (1.0))) - { - if (bli_is_conj(conja)) - { - for (dim_t k = n; k != 0; --k) - { - ((*(pi1 + 0))) = (((*(alpha1 + 0 * inca)))); - ((*(pi1 + 1))) = (((*(alpha1 + 1 * inca)))); - ((*(pi1 + 2))) = (((*(alpha1 + 2 * inca)))); - ((*(pi1 + 3))) = (((*(alpha1 + 3 * inca)))); - ((*(pi1 + 4))) = (((*(alpha1 + 4 * inca)))); - ((*(pi1 + 5))) = (((*(alpha1 + 5 * inca)))); - - alpha1 += lda; - pi1 += ldp; - } - } - else - { - for (dim_t k = n; k != 0; --k) - { - ((*(pi1 + 0))) = ((*(alpha1 + 0 * inca))); - ((*(pi1 + 1))) = ((*(alpha1 + 1 * inca))); - ((*(pi1 + 2))) = ((*(alpha1 + 2 * inca))); - ((*(pi1 + 3))) = ((*(alpha1 + 3 * inca))); - - ((*(pi1 + 4))) = ((*(alpha1 + 4 * inca))); - ((*(pi1 + 5))) = ((*(alpha1 + 5 * inca))); - - alpha1 += lda; - pi1 += ldp; - } - } - } - else - { - if (bli_is_conj(conja)) - { - for (dim_t k = n; k != 0; --k) - { - ((*(pi1 + 0))) = ((*kappa_cast)) * ((*(alpha1 + 0 * inca))); - ((*(pi1 + 1))) = ((*kappa_cast)) * ((*(alpha1 + 1 * inca))); - ((*(pi1 + 2))) = ((*kappa_cast)) * ((*(alpha1 + 2 * inca))); - ((*(pi1 + 3))) = ((*kappa_cast)) * ((*(alpha1 + 3 * inca))); - ((*(pi1 + 4))) = ((*kappa_cast)) * ((*(alpha1 + 4 * inca))); - ((*(pi1 + 5))) = ((*kappa_cast)) * ((*(alpha1 + 5 * inca))); - - alpha1 += 
lda; - pi1 += ldp; - } - } - else - { - for (dim_t k = n; k != 0; --k) - { - ((*(pi1 + 0))) = ((*kappa_cast)) * ((*(alpha1 + 0 * inca))); - ((*(pi1 + 1))) = ((*kappa_cast)) * ((*(alpha1 + 1 * inca))); - ((*(pi1 + 2))) = ((*kappa_cast)) * ((*(alpha1 + 2 * inca))); - ((*(pi1 + 3))) = ((*kappa_cast)) * ((*(alpha1 + 3 * inca))); - ((*(pi1 + 4))) = ((*kappa_cast)) * ((*(alpha1 + 4 * inca))); - ((*(pi1 + 5))) = ((*kappa_cast)) * ((*(alpha1 + 5 * inca))); - - alpha1 += lda; - pi1 += ldp; - } - } - } - } - else /* if ( cdim < mnr ) */ - { - bli_dscal2m_ex - ( - 0, - BLIS_NONUNIT_DIAG, - BLIS_DENSE, - (trans_t)conja, - cdim, - n, - kappa, - a, inca, lda, - p, 1, ldp, - cntx, - NULL - ); - - /* if ( cdim < mnr ) */ - { - const dim_t i = cdim; - const dim_t m_edge = 6 - cdim; - const dim_t n_edge = n_max; - double* restrict p_cast = p; - double* restrict p_edge = p_cast + (i) * 1; - - bli_dset0s_mxn - ( - m_edge, - n_edge, - p_edge, 1, ldp - ); - } - } - - if (n < n_max) - { - const dim_t j = n; - const dim_t m_edge = 6; - const dim_t n_edge = n_max - n; - double* restrict p_cast = p; - double* restrict p_edge = p_cast + (j)*ldp; - - bli_dset0s_mxn - ( - m_edge, - n_edge, - p_edge, 1, ldp - ); - } -}// end of function diff --git a/kernels/zen/CMakeLists.txt b/kernels/zen/CMakeLists.txt index bc17272d54..0ac346fb3e 100644 --- a/kernels/zen/CMakeLists.txt +++ b/kernels/zen/CMakeLists.txt @@ -1,7 +1,7 @@ ##Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.## -set(SUBDIRECTORIES "1" "1f" "1m" "2" "3" "util") +set(SUBDIRECTORIES "1" "1f" "2" "3" "util") #Add all subdirectories foreach(VAR ${SUBDIRECTORIES}) From a8e47a82c08f2301175a7a53ddde1f7f6a05652a Mon Sep 17 00:00:00 2001 From: Dipal M Zambare Date: Sun, 19 Sep 2021 23:02:57 +0530 Subject: [PATCH 020/243] Fixed build issue with clang in 'make checkcpp' The makefile in vendor/testcpp folder has hardcoded g++ as cpp compiler and linker. Updated the makefile to take the compiler which is used to build the library. AMD-Internal: [CPUPL-1873] Change-Id: Ib0bcbb8fccd0ff6f90b49b3b1e4a272cf3bad361 --- vendor/testcpp/Makefile | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vendor/testcpp/Makefile b/vendor/testcpp/Makefile index 01506c9966..9a5a466f59 100644 --- a/vendor/testcpp/Makefile +++ b/vendor/testcpp/Makefile @@ -3,7 +3,7 @@ # libraries. # # Copyright (C) 2014, The University of Texas at Austin -# Copyright (C) 2017 - 2019, Advanced Micro Devices, Inc. +# Copyright (C) 2017 - 2021, Advanced Micro Devices, Inc. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -112,7 +112,9 @@ TEST_OBJS := $(patsubst $(TEST_SRC_PATH)/%.c, \ # while building BLIS. CINCFLAGS := -I$(INC_PATH) -CXX = g++ +# Use CXX from the blis configuration, this will insure that +# correct compiler and compiler version is used to build testcpp folder +#CXX = g++ # Use the CFLAGS for the configuration family. override CFLAGS += $(call get-sandbox-cxxflags-for,$(CONFIG_NAME)) From 5d287fdba01a8886c930a169900798413d62001c Mon Sep 17 00:00:00 2001 From: Dipal M Zambare Date: Wed, 22 Sep 2021 20:21:09 +0530 Subject: [PATCH 021/243] Include LP64/ILP64 in BLIS binary name Binary name will be chosen based on multi-threading and BLAS integer size configuration as given below. 
libblis-[mt]-lp64 - when configured to use 32 bit integers libblis-[mt]-ilp64 - when configured to use 64 bit integers AMD-Internal: [CPUPL-1879] Change-Id: I865023c63235a0a72bdfce7057b2cfb8158b1d87 --- build/config.mk.in | 4 ++++ common.mk | 16 ++++++++++++---- configure | 3 ++- 3 files changed, 18 insertions(+), 5 deletions(-) diff --git a/build/config.mk.in b/build/config.mk.in index da8b3ad285..709e0f543c 100644 --- a/build/config.mk.in +++ b/build/config.mk.in @@ -5,6 +5,7 @@ # libraries. # # Copyright (C) 2014, The University of Texas at Austin +# Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -200,5 +201,8 @@ MK_COMPLEX_RETURN_SCHEME := @complex_return@ # Status of aocl dynamic configuration MK_ENABLE_AOCL_DYNAMIC := @enable_aocl_dynamic@ +# BLAS int size +MK_BLAS_INT_TYPE_SIZE := @blas_int_type_size@ + # end of ifndef CONFIG_MK_INCLUDED conditional block endif diff --git a/common.mk b/common.mk index a05e2160f0..26e8627adb 100644 --- a/common.mk +++ b/common.mk @@ -1,11 +1,11 @@ # # -# BLIS +# BLIS # An object-based framework for developing high-performance BLAS-like # libraries. # # Copyright (C) 2014, The University of Texas at Austin -# Copyright (C) 2020, Advanced Micro Devices, Inc. +# Copyright (C) 2020-2021, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -411,9 +411,17 @@ BASE_LIB_PATH := $(LIB_PATH) # The base name of the BLIS library that we will build. ifeq ($(THREADING_MODEL),off) -LIBBLIS := libblis +ifeq ($(MK_BLAS_INT_TYPE_SIZE), 64) +LIBBLIS := libblis-ilp64 else -LIBBLIS := libblis-mt +LIBBLIS := libblis-lp64 +endif +else +ifeq ($(MK_BLAS_INT_TYPE_SIZE), 64) +LIBBLIS := libblis-mt-ilp64 +else +LIBBLIS := libblis-mt-lp64 +endif endif # The shared (dynamic) library file suffix is different for Linux and OS X. diff --git a/configure b/configure index 4858f5aa65..bec498d3cf 100755 --- a/configure +++ b/configure @@ -5,7 +5,7 @@ # libraries. # # Copyright (C) 2014, The University of Texas at Austin -# Copyright (C) 2020-2021, Advanced Micro Devices, Inc. +# Copyright (C) 2020-2021, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -3369,6 +3369,7 @@ main() | sed -e "s/@enable_trsm_preinversion@/${enable_trsm_preinversion}/g" \ | sed -e "s/@enable_aocl_dynamic@/${enable_aocl_dynamic}/g" \ | sed -e "s/@complex_return@/${complex_return}/g" \ + | sed -e "s/@blas_int_type_size@/${blas_int_type_size}/g" \ > "${config_mk_out_path}" From faeb79f2b980182b1f883cf36ddb7408a8967b6b Mon Sep 17 00:00:00 2001 From: Nageshwar Singh Date: Fri, 17 Sep 2021 16:01:23 +0530 Subject: [PATCH 022/243] Trsm bench utility missmatch DTL logs and bench AOCL-Internal: [CPUPL-1585] Change-Id: I2896d695e6bb40ec39a4f840240499927de16962 --- bench/bench_trsm.c | 228 +++++++++++++++++++------------------------- bench/inputtrsm.txt | 28 ++++-- 2 files changed, 116 insertions(+), 140 deletions(-) diff --git a/bench/bench_trsm.c b/bench/bench_trsm.c index c95f51cc86..a7d62ebecc 100644 --- a/bench/bench_trsm.c +++ b/bench/bench_trsm.c @@ -1,12 +1,10 @@ -/* +/* BLIS An object-based framework for developing high-performance BLAS-like libraries. 
- Copyright (C) 2014, The University of Texas at Austin Copyright (C) 2020-2021, Advanced Micro Devices, Inc. All rights reserved. - Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -18,7 +16,6 @@ - Neither the name of The University of Texas nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR @@ -30,29 +27,22 @@ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - */ - #ifdef WIN32 #include #else #include #endif #include "blis.h" - /* -Benchmark application to process aocl logs generated +Benchmark application to process aocl logs generated by BLIS library for trsm. */ - #ifndef N_REPEAT #define N_REPEAT 30 #endif - #define AOCL_MATRIX_INITIALISATION -//#define BLIS_ENABLE_CBLAS - /* For BLIS since logs are collected at BLAS interfaces * we disable cblas interfaces for this benchmark application */ @@ -65,33 +55,29 @@ by BLIS library for trsm. int main( int argc, char** argv ) { - obj_t a, c; - obj_t c_save; + obj_t a, b; + obj_t b_save; obj_t alpha; dim_t m, n; + dim_t p_inc = 0; // to keep track of number of inputs num_t dt = BLIS_DOUBLE; - int r, n_repeats; - char side; + dim_t r, n_repeats; + f77_char side; uplo_t uploa; trans_t transa; diag_t diaga; - f77_char f77_side; f77_char f77_uploa; f77_char f77_transa; f77_char f77_diaga; - double dtime; double dtime_save; double gflops; double alphaR; double alphaI; - FILE* fin = NULL; FILE* fout = NULL; - n_repeats = N_REPEAT; - if(argc < 3) { printf("Usage: ./test_trsm_XX.x input.csv output.csv\n"); @@ -103,69 +89,65 @@ int main( int argc, char** argv ) printf("Error opening the file %s\n", argv[1]); exit(1); } - fout = fopen(argv[2], "w"); if(fout == NULL) { printf("Error opening the file %s\n", argv[2]); exit(1); } - - fprintf(fout,"dt\t side\t uploa\t transa\t diaga\t m\t n\t lda\t ldb\t\ - lphaR\t alphaI\t gflops\n"); - - inc_t cs_a,rs_a,cs_c,rs_c; - char dt_type, side_c, uploa_c, transa_c, diaga_c; - char logline[255]; - - while(fscanf(fin, "%s %c %c %c %ld %ld %lu %lu %lu %lu %c %c %lf %lf\n", - logline,&dt_type, &side_c,&uploa_c, &m, &n, &cs_a, &cs_c,&rs_a,&rs_c, - &transa_c,&diaga_c,&alphaR,&alphaI) == 14) + fprintf(fout,"dt\t side\t uploa\t transa\t diaga\t m\t n\t lda\t ldb\t alphaR\t alphaI\t gflops\n"); + + dim_t lda,ldb; + f77_char dt_type_arg, side_arg, uploa_arg, transa_arg, diaga_arg; + f77_char logline[255]; + // input order: {S,D,C,Z} {side, uplo, transa, diag, m, n, lda, ldb, alphaR, alphaI} + while(fscanf(fin, "%s %c %c %c %c %c %ld %ld %ld %ld %lf %lf\n", + logline, &dt_type_arg, &side_arg, &uploa_arg, &transa_arg, &diaga_arg, &m, &n, &lda, &ldb, + &alphaR, &alphaI) == 12) { - if((dt_type=='S')||(dt_type=='s')) dt = BLIS_FLOAT; - if((dt_type=='D')||(dt_type=='d')) dt = BLIS_DOUBLE; - if((dt_type=='C')||(dt_type=='c')) dt = BLIS_SCOMPLEX; - if((dt_type=='Z')||(dt_type=='z')) dt = BLIS_DCOMPLEX; - - if( 'l' == side_c|| 'L' == side_c) + if( (dt_type_arg=='S') || (dt_type_arg=='s') ) dt = BLIS_FLOAT; + if( (dt_type_arg=='D') || (dt_type_arg=='d') ) dt = BLIS_DOUBLE; + if( 
(dt_type_arg=='C') || (dt_type_arg=='c') ) dt = BLIS_SCOMPLEX; + if( (dt_type_arg=='Z') || (dt_type_arg=='z') ) dt = BLIS_DCOMPLEX; + if( 'l' == side_arg|| 'L' == side_arg ) side = BLIS_LEFT; - else if('r' == side_c || 'R' == side_c) + else if( 'r' == side_arg || 'R' == side_arg ) side = BLIS_RIGHT; else { - printf("Invalid entry for the argument 'side':%c\n",side_c); + printf("Invalid entry for the argument 'side':%c\n", side_arg); continue; } - if('l' == uploa_c || 'L' == uploa_c) + if('l' == uploa_arg || 'L' == uploa_arg) uploa = BLIS_LOWER; - else if('u' == uploa_c || 'U' == uploa_c) + else if('u' == uploa_arg || 'U' == uploa_arg) uploa = BLIS_UPPER; else { - printf("Invalid entry for the argument 'uplo':%c\n",uploa_c); + printf("Invalid entry for the argument 'uplo':%c\n",uploa_arg); continue; } - if('t' == transa_c || 'T' == transa_c) + if('t' == transa_arg || 'T' == transa_arg) transa = BLIS_TRANSPOSE; - else if('n' == transa_c || 'N' == transa_c) + else if('n' == transa_arg || 'N' == transa_arg) transa = BLIS_NO_TRANSPOSE; - else if('c' == transa_c || 'C' == transa_c) - transa = BLIS_CONJ_TRANSPOSE; - else + else if('c' == transa_arg || 'C' == transa_arg) + transa = BLIS_CONJ_TRANSPOSE; + else { - printf("Invalid entry for the argument 'transa':%c\n",transa_c); + printf("Invalid entry for the argument 'transa':%c\n",transa_arg); continue; } - if('u' == diaga_c || 'U' == diaga_c) + if('u' == diaga_arg || 'U' == diaga_arg) diaga = BLIS_UNIT_DIAG; - else if('n' == diaga_c || 'N' == diaga_c) + else if('n' == diaga_arg || 'N' == diaga_arg) diaga = BLIS_NONUNIT_DIAG; else { - printf("Invalid entry for the argument 'diaga':%c\n", diaga_c); + printf("Invalid entry for the argument 'diaga':%c\n", diaga_arg); continue; } @@ -175,56 +157,45 @@ int main( int argc, char** argv ) bli_param_map_blis_to_netlib_diag( diaga, &f77_diaga ); if ( bli_is_left( side ) ) - bli_obj_create( dt, m, m, rs_a, cs_a, &a ); + bli_obj_create( dt, m, m, 1, lda, &a ); else - bli_obj_create( dt, n, n, rs_a, cs_a, &a ); + bli_obj_create( dt, n, n, 1, lda, &a ); - bli_obj_create( dt, m, n, rs_c, cs_c, &c ); - bli_obj_create( dt, m, n, rs_c, cs_c, &c_save ); + bli_obj_create( dt, m, n, 1, ldb, &b ); + bli_obj_create( dt, m, n, 1, ldb, &b_save ); #ifdef AOCL_MATRIX_INITIALISATION bli_randm( &a ); - bli_randm( &c ); + bli_randm( &b ); #endif - bli_obj_set_struc( BLIS_TRIANGULAR, &a ); bli_obj_set_uplo( uploa, &a ); bli_obj_set_conjtrans( transa, &a ); bli_obj_set_diag( diaga, &a ); - // Randomize A and zero the unstored triangle to ensure the // implementation reads only from the stored region. bli_randm( &a ); bli_mktrim( &a ); - // Load the diagonal of A to make it more likely to be invertible. 
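        // (bli_shiftd below adds the real scalar BLIS_TWO, i.e. 2.0, to each
        // diagonal entry of A, keeping the randomized triangular factor safely
        // away from singularity for the timed solves.)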
bli_shiftd( &BLIS_TWO, &a ); - bli_obj_create( dt, 1, 1, 0, 0, &alpha ); bli_setsc( alphaR, alphaI, &alpha ); - - bli_copym( &c, &c_save ); - + bli_copym( &b, &b_save ); dtime_save = DBL_MAX; - for ( r = 0; r < n_repeats; ++r ) { - bli_copym( &c_save, &c ); + bli_copym( &b_save, &b ); #ifdef PRINT bli_printm( "a", &a, "%4.1f", "" ); - bli_printm( "c", &c, "%4.1f", "" ); + bli_printm( "b", &b, "%4.1f", "" ); #endif dtime = bli_clock(); - #ifdef BLIS - bli_trsm( &side, &alpha, &a, - &c ); - + &b ); #else - #ifdef CBLAS enum CBLAS_ORDER cblas_order; enum CBLAS_TRANSPOSE cblas_transa; @@ -232,7 +203,7 @@ int main( int argc, char** argv ) enum CBLAS_SIDE cblas_side; enum CBLAS_DIAG cblas_diag; - if ( ( stor_scheme == 'C' ) || ( stor_scheme == 'c' ) ) + if ( bli_obj_row_stride( &b ) == 1 ) cblas_order = CblasColMajor; else cblas_order = CblasRowMajor; @@ -244,20 +215,30 @@ int main( int argc, char** argv ) else cblas_transa = CblasNoTrans; - if(bli_is_upper(uploa)) - cblas_uplo = CblasUpper; + if ('u' == diaga_arg || 'U' == diaga_arg) + cblas_diag = CblasUnit; else - cblas_uplo = CblasLower; + cblas_diag = CblasNonUnit; - if(bli_is_left(side)) + if( 'l' == side_arg || 'L' == side_arg ) cblas_side = CblasLeft; - else + else if( 'r' == side_arg || 'R' == side_arg ) cblas_side = CblasRight; + else + { + printf("Invalid entry for the argument 'side':%c\n", side_arg); + continue; + } - if(bli_is_unit_diag(diaga)) - cblas_diag = CblasUnit; + if('l' == uploa_arg || 'L' == uploa_arg) + cblas_uplo = CblasLower; + else if('u' == uploa_arg || 'U' == uploa_arg) + cblas_uplo = CblasUpper; else - cblas_diag = CblasNonUnit; + { + printf("Invalid entry for the argument 'uplo':%c\n",uploa_arg); + continue; + } #else f77_char f77_transa; @@ -265,15 +246,13 @@ int main( int argc, char** argv ) #endif if ( bli_is_float( dt ) ) { - f77_int mm = bli_obj_length( &c ); - f77_int nn = bli_obj_width( &c ); + f77_int mm = bli_obj_length( &b ); + f77_int nn = bli_obj_width( &b ); f77_int lda = bli_obj_col_stride( &a ); - f77_int ldc = bli_obj_col_stride( &c ); - + f77_int ldb = bli_obj_col_stride( &b ); float* alphap = bli_obj_buffer( &alpha ); float* ap = bli_obj_buffer( &a ); - float* cp = bli_obj_buffer( &c ); - + float* bp = bli_obj_buffer( &b ); #ifdef CBLAS cblas_strsm( cblas_order, cblas_side, @@ -284,7 +263,7 @@ int main( int argc, char** argv ) nn, *alphap, ap, lda, - cp, ldc + bp, ldb ); #else strsm_( &f77_side, @@ -295,19 +274,18 @@ int main( int argc, char** argv ) &nn, alphap, ap, &lda, - cp, &ldc ); + bp, &ldb ); #endif } else if ( bli_is_double( dt ) ) { - f77_int mm = bli_obj_length( &c ); - f77_int nn = bli_obj_width( &c ); + f77_int mm = bli_obj_length( &b ); + f77_int nn = bli_obj_width( &b ); f77_int lda = bli_obj_col_stride( &a ); - f77_int ldc = bli_obj_col_stride( &c ); + f77_int ldb = bli_obj_col_stride( &b ); double* alphap = bli_obj_buffer( &alpha ); double* ap = bli_obj_buffer( &a ); - double* cp = bli_obj_buffer( &c ); - + double* bp = bli_obj_buffer( &b ); #ifdef CBLAS cblas_dtrsm( cblas_order, cblas_side, @@ -318,9 +296,9 @@ int main( int argc, char** argv ) nn, *alphap, ap, lda, - cp, ldc + bp, ldb ); -#else +#else dtrsm_( &f77_side, &f77_uploa, &f77_transa, @@ -329,20 +307,18 @@ int main( int argc, char** argv ) &nn, alphap, ap, &lda, - cp, &ldc ); + bp, &ldb ); #endif - } else if ( bli_is_scomplex( dt ) ) { - f77_int mm = bli_obj_length( &c ); - f77_int nn = bli_obj_width( &c ); + f77_int mm = bli_obj_length( &b ); + f77_int nn = bli_obj_width( &b ); f77_int lda = bli_obj_col_stride( &a ); 
- f77_int ldc = bli_obj_col_stride( &c ); + f77_int ldb = bli_obj_col_stride( &b ); scomplex* alphap = bli_obj_buffer( &alpha ); scomplex* ap = bli_obj_buffer( &a ); - scomplex* cp = bli_obj_buffer( &c ); - + scomplex* bp = bli_obj_buffer( &b ); #ifdef CBLAS cblas_ctrsm( cblas_order, cblas_side, @@ -353,7 +329,7 @@ int main( int argc, char** argv ) nn, alphap, ap, lda, - cp, ldc + bp, ldb ); #else ctrsm_( &f77_side, @@ -364,18 +340,18 @@ int main( int argc, char** argv ) &nn, alphap, ap, &lda, - cp, &ldc ); + bp, &ldb ); #endif } else if ( bli_is_dcomplex( dt ) ) { - f77_int mm = bli_obj_length( &c ); - f77_int nn = bli_obj_width( &c ); + f77_int mm = bli_obj_length( &b ); + f77_int nn = bli_obj_width( &b ); f77_int lda = bli_obj_col_stride( &a ); - f77_int ldc = bli_obj_col_stride( &c ); + f77_int ldb = bli_obj_col_stride( &b ); dcomplex* alphap = bli_obj_buffer( &alpha ); dcomplex* ap = bli_obj_buffer( &a ); - dcomplex* cp = bli_obj_buffer( &c ); + dcomplex* bp = bli_obj_buffer( &b ); #ifdef CBLAS cblas_ztrsm( cblas_order, cblas_side, @@ -386,7 +362,7 @@ int main( int argc, char** argv ) nn, alphap, ap, lda, - cp, ldc + bp, ldb ); #else ztrsm_( &f77_side, @@ -397,7 +373,7 @@ int main( int argc, char** argv ) &nn, alphap, ap, &lda, - cp, &ldc ); + bp, &ldb ); #endif }else{ printf("Invalid data type! Exiting!\n"); @@ -406,42 +382,32 @@ int main( int argc, char** argv ) #endif dtime_save = bli_clock_min_diff( dtime_save, dtime ); } - if ( bli_is_left( side ) ) gflops = ( 1.0 * m * m * n ) / ( dtime_save * 1.0e9 ); else gflops = ( 1.0 * m * n * n ) / ( dtime_save * 1.0e9 ); - if ( bli_is_complex( dt ) ) gflops *= 4.0; - #ifdef BLIS printf( "data_trsm_blis\t\t"); #else printf( "data_trsm_%s\t\t",BLAS ); #endif - - printf("%c\t %c\t %c\t %c\t %c\t %4lu\t %4lu\t %4lu\t %4lu\t %6.3f\t %6.3f\t %6.3f\n", - dt_type,side_c, uploa_c, transa_c, - diaga_c, (unsigned long )m, (unsigned long ) n, (unsigned long )cs_a, - (unsigned long )cs_c, alphaR, alphaI, gflops); - + p_inc++; + printf( "( %2lu, 1:2 ) = [ %4lu %7.2f ];\n", + ( unsigned long )p_inc, + ( unsigned long )m, gflops ); fprintf(fout,"%c\t %c\t %c\t %c\t %c\t %4lu\t %4lu\t %4lu\t %4lu\t %6.3f\t %6.3f\t %6.3f\n", - dt_type,side_c, uploa_c, transa_c, - diaga_c, (unsigned long )m, (unsigned long ) n, (unsigned long )cs_a, - (unsigned long )cs_c, alphaR, alphaI, gflops); - + dt_type_arg, side_arg, uploa_arg, transa_arg, + diaga_arg, (unsigned long )m, (unsigned long ) n, (unsigned long )lda, + (unsigned long )ldb, alphaR, alphaI, gflops); fflush(fout); - bli_obj_free( &alpha ); bli_obj_free( &a ); - bli_obj_free( &c ); - bli_obj_free( &c_save ); + bli_obj_free( &b ); + bli_obj_free( &b_save ); } - fclose(fin); fclose(fout); - //bli_finalize(); - return 0; } diff --git a/bench/inputtrsm.txt b/bench/inputtrsm.txt index 54254e919e..ec42f13c97 100644 --- a/bench/inputtrsm.txt +++ b/bench/inputtrsm.txt @@ -1,9 +1,19 @@ - bli_trsm_ex:375: D L L 10 12 228 228 1 1 n n 1.000000 0.000000 - bli_trsm_ex:375: D L L 1017 1095 2112 2112 1 1 n n 1.000000 0.000000 - bli_trsm_ex:375: D L U 99 1 753 99 1 1 t u 1.000000 0.000000 - bli_trsm_ex:375: D L L 10 17 417 417 1 1 n n 1.000000 0.000000 - bli_trsm_ex:375: D L L 1020 5958 6978 6978 1 1 n n 1.000000 0.000000 - bli_trsm_ex:375: D L L 102 1032 1134 1134 1 1 n n 1.000000 0.000000 - bli_trsm_ex:375: D L L 102 1 1005 4602001 1 1 t n 1.000000 0.000000 - bli_trsm_ex:375: D L L 16 25 609 609 1 1 n n 1.000000 0.000000 - \ No newline at end of file +dtrsm_:400: d L L N N 1000 1000 1000 1000 2.000000 0.000000 +dtrsm_:377: d 
L L N U 16 96 1000 1000 1.000000 0.000000 +dtrsm_:377: d L L N U 16 96 10000 10000 1.000000 0.000000 +dtrsm_:377: d L L N U 16 96 30000 30000 1.000000 0.000000 +dtrsm_:377: d L L N U 5 5 5 5 1.000000 0.000000 +dtrsm_:377: d L U N N 10 10 10 10 1.000000 0.000000 +dtrsm_:377: d L U N N 100 100 100 100 1.000000 0.000000 +dtrsm_:377: d L U N N 1000 1000 1000 1000 1.000000 0.000000 +dtrsm_:400: d R U N N 1000 1000 1000 1000 2.000000 0.000000 +dtrsm_:400: d R U N N 1200 1200 1200 1200 2.000000 0.000000 +dtrsm_:400: d R U N N 1400 1400 1400 1400 2.000000 0.000000 +dtrsm_:400: d R U N N 1600 1600 1600 1600 2.000000 0.000000 +dtrsm_:400: d R U N N 1800 1800 1800 1800 2.000000 0.000000 +dtrsm_:400: d R U N N 2000 2000 2000 2000 2.000000 0.000000 +dtrsm_:400: d R U N N 200 200 200 200 2.000000 0.000000 +dtrsm_:400: d R U N N 400 400 400 400 2.000000 0.000000 +dtrsm_:400: d R U N N 600 600 600 600 2.000000 0.000000 +dtrsm_:400: d R U N N 800 800 800 800 2.000000 0.000000 + From 10ca8710f0264f92c6d547d6c0dfdbc1f7f97bb7 Mon Sep 17 00:00:00 2001 From: Meghana Vankadari Date: Fri, 18 Jun 2021 17:32:25 +0530 Subject: [PATCH 023/243] Optimized SUP code for GEMMT Details: - Eliminated the IR loop in ref_var2m functions. - Handled the rectangular and triangular portions of C matrix separately. - Added a condition to check and eliminate zero regions inside IC loop. - modified kc selection logic to choose optimal KC in SUP - Updated thresholds to choose between SUP and native. Change-Id: I21908eaa6bc3a8f37bdea29f7bfca7e6fcfee724 --- frame/3/gemmt/bli_gemmt_sup_var1n2m.c | 346 ++++++++++++++---------- frame/util/bli_util_update.c | 42 ++- kernels/zen/util/bli_thresh_funcs_zen.c | 2 +- 3 files changed, 242 insertions(+), 148 deletions(-) diff --git a/frame/3/gemmt/bli_gemmt_sup_var1n2m.c b/frame/3/gemmt/bli_gemmt_sup_var1n2m.c index ff46d1f52c..382ca6f67d 100644 --- a/frame/3/gemmt/bli_gemmt_sup_var1n2m.c +++ b/frame/3/gemmt/bli_gemmt_sup_var1n2m.c @@ -1497,7 +1497,7 @@ void PASTEMACT(ch,opname,uplo,varname) \ function pointer type. */ \ PASTECH(ch,gemmsup_ker_ft) \ gemmsup_ker = bli_cntx_get_l3_sup_ker_dt( dt, stor_id, cntx ); \ - ctype ct[ BLIS_STACK_BUF_MAX_SIZE / sizeof( ctype ) ] __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ + ctype ct[ BLIS_STACK_BUF_MAX_SIZE / sizeof( ctype ) ] __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ \ /* storage-scheme of ct should be same as that of C. Since update routines only support row-major order, @@ -1580,6 +1580,7 @@ void PASTEMACT(ch,opname,uplo,varname) \ dim_t m_off = 0; \ dim_t n_off = 0; \ doff_t diagoffc; \ + dim_t i, ip; \ \ /* Loop over the n dimension (NC rows/columns at a time). */ \ /*for ( dim_t jj = 0; jj < jc_iter; jj += 1 )*/ \ @@ -1690,7 +1691,8 @@ void PASTEMACT(ch,opname,uplo,varname) \ for ( dim_t ii = ic_start; ii < ic_end; ii += MC ) \ { \ /* Calculate the thread's current IC block dimension. */ \ - const dim_t mc_cur = ( MC <= ic_end - ii ? MC : ic_left ); \ + dim_t mc_cur = ( MC <= ic_end - ii ? 
MC : ic_left ); \ + dim_t nc_pruned = nc_cur; \ \ ctype* restrict a_ic = a_pc + ii * icstep_a; \ ctype* restrict c_ic = c_jc + ii * icstep_c; \ @@ -1699,7 +1701,24 @@ void PASTEMACT(ch,opname,uplo,varname) \ \ if(bli_gemmt_is_strictly_above_diag( m_off, n_off, mc_cur, nc_cur ) ) continue; \ \ - PASTEMAC(ch,set0s_mxn) ( MR, NR, ct, rs_ct, cs_ct ); \ + diagoffc = m_off - n_off; \ +\ + if( diagoffc < 0 ) \ + { \ + ip = -diagoffc / MR; \ + i = ip * MR; \ + mc_cur = mc_cur - i; \ + diagoffc = -diagoffc % MR; \ + m_off += i; \ + c_ic = c_ic + ( i ) * rs_c; \ + a_ic = a_ic + ( i ) * rs_a; \ + } \ +\ + if( ( diagoffc + mc_cur ) < nc_cur ) \ + { \ + nc_pruned = diagoffc + mc_cur; \ + } \ +\ ctype* a_use; \ inc_t rs_a_use, cs_a_use, ps_a_use; \ \ @@ -1755,8 +1774,8 @@ void PASTEMACT(ch,opname,uplo,varname) \ bli_thrinfo_sup_grow( rntm, bszids_jr, thread_jr ); \ \ /* Compute number of primary and leftover components of the JR loop. */ \ - dim_t jr_iter = ( nc_cur + NR - 1 ) / NR; \ - dim_t jr_left = nc_cur % NR; \ + dim_t jr_iter = ( nc_pruned + NR - 1 ) / NR; \ + dim_t jr_left = nc_pruned % NR; \ \ /* Compute the JR loop thread range for the current thread. */ \ dim_t jr_start, jr_end; \ @@ -1785,76 +1804,92 @@ void PASTEMACT(ch,opname,uplo,varname) \ ctype* restrict b_jr = b_pc_use + j * ps_b_use; \ ctype* restrict c_jr = c_ic + j * jrstep_c; \ \ - const dim_t ir_iter = ( mc_cur + MR - 1 ) / MR; \ - const dim_t ir_left = mc_cur % MR; \ + dim_t i; \ + dim_t m_zero = 0; \ + dim_t n_iter_zero = 0; \ +\ + m_off_cblock = m_off; \ + n_off_cblock = n_off + j * NR; \ +\ + if(bli_gemmt_is_strictly_below_diag(m_off_cblock, n_off_cblock, mc_cur, nc_cur)) \ + { \ + m_zero = 0; \ + } \ + else \ + { \ + /* compute number of rows that are filled with zeroes and can be ignored */ \ + n_iter_zero = (n_off_cblock < m_off_cblock)? 0 : (n_off_cblock - m_off)/MR; \ + m_zero = n_iter_zero * MR; \ + } \ +\ + ctype* restrict a_ir = a_ic_use + n_iter_zero * ps_a_use; \ + ctype* restrict c_ir = c_jr + n_iter_zero * irstep_c; \ \ - /* Loop over the m dimension (MR rows at a time). */ \ - for(dim_t i = 0; i < ir_iter; i += 1 ) \ + /* Ignore the zero region */ \ + m_off_cblock += m_zero; \ +\ + /* Compute the triangular part */ \ + for( i = m_zero; (i < mc_cur) && ( m_off_cblock < n_off_cblock + nr_cur); i += MR ) \ { \ - const dim_t mr_cur = ( bli_is_not_edge_f( i, ir_iter, ir_left ) ? MR : ir_left ); \ -\ - m_off_cblock = m_off + i * MR; \ - n_off_cblock = n_off + j * NR; \ - if(bli_gemmt_is_strictly_above_diag( m_off_cblock, n_off_cblock, mr_cur, nr_cur )) continue; \ - ctype* restrict a_ir = a_ic_use + i * ps_a_use; \ - ctype* restrict c_ir = c_jr + i * irstep_c; \ - if( bli_gemmt_is_strictly_below_diag(m_off_cblock, n_off_cblock, mr_cur, nr_cur ) ) \ + const dim_t mr_cur = (i+MR-1) < mc_cur ? MR : mc_cur - i; \ +\ + /* Invoke the gemmsup millikernel. */ \ + gemmsup_ker \ + ( \ + conja, \ + conjb, \ + mr_cur, \ + nr_cur, \ + kc_cur, \ + alpha_cast, \ + a_ir, rs_a_use, cs_a_use, \ + b_jr, rs_b_use, cs_b_use, \ + zero, \ + ct, rs_ct, cs_ct, \ + &aux, \ + cntx \ + ); \ + /* Scale the bottom edge of C and add the result from above. */ \ + /* If c and ct are col-major, induce transpose and call update for upper-triangle of C */ \ + if( col_pref ) \ { \ - /* Invoke the gemmsup millikernel. 
*/ \ - gemmsup_ker \ - ( \ - conja, \ - conjb, \ - mr_cur, \ - nr_cur, \ - kc_cur, \ - alpha_cast, \ - a_ir, rs_a_use, cs_a_use, \ - b_jr, rs_b_use, cs_b_use, \ - beta_use, \ - c_ir, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ + PASTEMAC(ch,update_upper_triang)( n_off_cblock, m_off_cblock, \ + nr_cur, mr_cur, \ + ct, cs_ct, rs_ct, \ + beta_use, \ + c_ir, cs_c, rs_c ); \ } \ else \ { \ - /* Invoke the gemmsup millikernel. */ \ - gemmsup_ker \ - ( \ - conja, \ - conjb, \ - mr_cur, \ - nr_cur, \ - kc_cur, \ - alpha_cast, \ - a_ir, rs_a_use, cs_a_use, \ - b_jr, rs_b_use, cs_b_use, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ - /* Scale the bottom edge of C and add the result from above. */ \ - /* If c and ct are col-major, induce transpose and call update for upper-triangle of C */ \ - if( col_pref ) \ - { \ - PASTEMAC(ch,update_upper_triang)( n_off_cblock, m_off_cblock, \ - nr_cur, mr_cur, \ - ct, cs_ct, rs_ct, \ - beta_use, \ - c_ir, cs_c, rs_c ); \ - } \ - else \ - { \ - PASTEMAC(ch,update_lower_triang)( m_off_cblock, n_off_cblock, \ - mr_cur, nr_cur, \ - ct, rs_ct, cs_ct, \ - beta_use, \ - c_ir, rs_c, cs_c ); \ - } \ + PASTEMAC(ch,update_lower_triang)( m_off_cblock, n_off_cblock, \ + mr_cur, nr_cur, \ + ct, rs_ct, cs_ct, \ + beta_use, \ + c_ir, rs_c, cs_c ); \ } \ +\ + a_ir += ps_a_use; \ + c_ir += irstep_c; \ + m_off_cblock += mr_cur; \ } \ +\ + /* Invoke the gemmsup millikerneli for remaining rectangular part. */ \ + gemmsup_ker \ + ( \ + conja, \ + conjb, \ + (i > mc_cur)? 0: mc_cur - i, \ + nr_cur, \ + kc_cur, \ + alpha_cast, \ + a_ir, rs_a_use, cs_a_use, \ + b_jr, rs_b_use, cs_b_use, \ + beta_use, \ + c_ir, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ +\ } \ } \ \ @@ -1889,8 +1924,6 @@ PASTEMAC(ch,fprintm)( stdout, "gemmsup_ref_var2: c ", mr_cur, nr_cur, c_ir, rs_c INSERT_GENTFUNC_L( gemmtsup, ref_var2m ) - - #undef GENTFUNC #define GENTFUNC( ctype, ch, opname, uplo, varname ) \ \ @@ -1978,6 +2011,13 @@ void PASTEMACT(ch,opname,uplo,varname) \ stor_id == BLIS_CCC ) KC = KC0; \ else if ( stor_id == BLIS_RRC || \ stor_id == BLIS_CRC ) KC = KC0; \ + else if ( stor_id == BLIS_RCR ) \ + { \ + if ( m <= 4*MR ) KC = KC0; \ + else if ( m <= 36*MR ) KC = KC0 / 2; \ + else if ( m <= 56*MR ) KC = (( KC0 / 3 ) / 4 ) * 4; \ + else KC = KC0 / 4; \ + } \ else if ( m <= MR && n <= NR ) KC = KC0; \ else if ( m <= 2*MR && n <= 2*NR ) KC = KC0 / 2; \ else if ( m <= 3*MR && n <= 3*NR ) KC = (( KC0 / 3 ) / 4 ) * 4; \ @@ -2026,8 +2066,6 @@ void PASTEMACT(ch,opname,uplo,varname) \ \ const inc_t rs_ct = ( col_pref ? 1 : NR ); \ const inc_t cs_ct = ( col_pref ? MR : 1 ); \ -\ - PASTEMAC(ch,set0s_mxn) ( MR, NR, ct, rs_ct, cs_ct ); \ \ ctype* restrict a_00 = a; \ ctype* restrict b_00 = b; \ @@ -2097,6 +2135,7 @@ void PASTEMACT(ch,opname,uplo,varname) \ dim_t n_off = 0; \ doff_t diagoffc; \ dim_t m_off_cblock, n_off_cblock; \ + dim_t jp, j; \ \ /* Compute number of primary and leftover components of the JC loop. */ \ /*const dim_t jc_iter = ( n_local + NC - 1 ) / NC;*/ \ @@ -2211,14 +2250,37 @@ void PASTEMACT(ch,opname,uplo,varname) \ for ( dim_t ii = ic_start; ii < ic_end; ii += MC ) \ { \ /* Calculate the thread's current IC block dimension. */ \ - const dim_t mc_cur = ( MC <= ic_end - ii ? MC : ic_left ); \ + dim_t mc_cur = ( MC <= ic_end - ii ? 
MC : ic_left ); \ +\ + dim_t nc_pruned = nc_cur; \ \ m_off = ii; \ + n_off = jj; \ \ if(bli_gemmt_is_strictly_below_diag(m_off, n_off, mc_cur, nc_cur)) continue; \ \ ctype* restrict a_ic = a_pc + ii * icstep_a; \ ctype* restrict c_ic = c_jc + ii * icstep_c; \ +\ + doff_t diagoffc = m_off - n_off; \ +\ + ctype* restrict b_pc_pruned = b_pc_use; \ +\ + if(diagoffc > 0 ) \ + { \ + jp = diagoffc / NR; \ + j = jp * NR; \ + nc_pruned = nc_cur - j; \ + n_off += j; \ + diagoffc = diagoffc % NR; \ + c_ic = c_ic + ( j ) * cs_c; \ + b_pc_pruned = b_pc_use + ( jp ) * ps_b_use; \ + } \ +\ + if( ( ( -diagoffc ) + nc_pruned ) < mc_cur ) \ + { \ + mc_cur = -diagoffc + nc_pruned; \ + } \ \ ctype* a_use; \ inc_t rs_a_use, cs_a_use, ps_a_use; \ @@ -2275,8 +2337,8 @@ void PASTEMACT(ch,opname,uplo,varname) \ bli_thrinfo_sup_grow( rntm, bszids_jr, thread_jr ); \ \ /* Compute number of primary and leftover components of the JR loop. */ \ - dim_t jr_iter = ( nc_cur + NR - 1 ) / NR; \ - dim_t jr_left = nc_cur % NR; \ + dim_t jr_iter = ( nc_pruned + NR - 1 ) / NR; \ + dim_t jr_left = nc_pruned % NR; \ \ /* Compute the JR loop thread range for the current thread. */ \ dim_t jr_start, jr_end; \ @@ -2302,77 +2364,89 @@ void PASTEMACT(ch,opname,uplo,varname) \ /* ctype* restrict b_jr = b_pc_use + j * jrstep_b; \ */ \ - ctype* restrict b_jr = b_pc_use + j * ps_b_use; \ + ctype* restrict b_jr = b_pc_pruned + j * ps_b_use; \ ctype* restrict c_jr = c_ic + j * jrstep_c; \ + dim_t m_rect = 0; \ + dim_t n_iter_rect = 0; \ +\ + m_off_cblock = m_off; \ + n_off_cblock = n_off + j * NR; \ \ - const dim_t ir_iter = ( mc_cur + MR - 1 ) / MR; \ - const dim_t ir_left = mc_cur % MR; \ + if(bli_gemmt_is_strictly_above_diag(m_off_cblock, n_off_cblock, mc_cur, nr_cur)) \ + { \ + m_rect = mc_cur; \ + } \ + else \ + { \ + /* calculate the number of rows in rectangular region of the block */ \ + n_iter_rect = n_off_cblock < m_off_cblock ? 0: (n_off_cblock - m_off_cblock) / MR; \ + m_rect = n_iter_rect * MR; \ + } \ \ - /* Loop over the m dimension (MR rows at a time). */ \ - for(dim_t i = 0; i < ir_iter; i += 1 ) \ + /* Compute the rectangular part */ \ + gemmsup_ker \ + ( \ + conja, \ + conjb, \ + m_rect, \ + nr_cur, \ + kc_cur, \ + alpha_cast, \ + a_ic_use, rs_a_use, cs_a_use, \ + b_jr, rs_b_use, cs_b_use, \ + beta_use, \ + c_jr, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ +\ + m_off_cblock = m_off + m_rect; \ +\ + ctype* restrict a_ir = a_ic_use + n_iter_rect * ps_a_use; \ + ctype* restrict c_ir = c_jr + n_iter_rect * irstep_c; \ +\ + /* compute the remaining triangular part */ \ + for( dim_t i = m_rect;( i < mc_cur) && (m_off_cblock < n_off_cblock + nr_cur); i += MR ) \ { \ - const dim_t mr_cur = ( bli_is_not_edge_f( i, ir_iter, ir_left ) ? MR : ir_left ); \ - m_off_cblock = m_off + i * MR; \ - n_off_cblock = n_off + j * NR; \ - if( bli_gemmt_is_strictly_below_diag( m_off_cblock, n_off_cblock, mr_cur, nr_cur )) continue; \ - ctype* restrict a_ir = a_ic_use + i * ps_a_use; \ - ctype* restrict c_ir = c_jr + i * irstep_c; \ - if(bli_gemmt_is_strictly_above_diag( m_off_cblock, n_off_cblock, mr_cur, nr_cur )) \ + const dim_t mr_cur = (i+MR-1) < mc_cur ? MR : mc_cur - i; \ +\ + /* Invoke the gemmsup millikernel. */ \ + gemmsup_ker \ + ( \ + conja, \ + conjb, \ + mr_cur, \ + nr_cur, \ + kc_cur, \ + alpha_cast, \ + a_ir, rs_a_use, cs_a_use, \ + b_jr, rs_b_use, cs_b_use, \ + zero, \ + ct, rs_ct, cs_ct, \ + &aux, \ + cntx \ + ); \ +\ + if( col_pref ) \ { \ - /* Invoke the gemmsup millikernel. 
*/ \ - gemmsup_ker \ - ( \ - conja, \ - conjb, \ - mr_cur, \ - nr_cur, \ - kc_cur, \ - alpha_cast, \ - a_ir, rs_a_use, cs_a_use, \ - b_jr, rs_b_use, cs_b_use, \ - beta_use, \ - c_ir, rs_c, cs_c, \ - &aux, \ - cntx \ - ); \ + PASTEMAC(ch,update_lower_triang)( n_off_cblock, m_off_cblock, \ + nr_cur, mr_cur, \ + ct, cs_ct, rs_ct, \ + beta_use, \ + c_ir, cs_c, rs_c ); \ } \ else \ { \ - /* Invoke the gemmsup millikernel. */ \ - gemmsup_ker \ - ( \ - conja, \ - conjb, \ - mr_cur, \ - nr_cur, \ - kc_cur, \ - alpha_cast, \ - a_ir, rs_a_use, cs_a_use, \ - b_jr, rs_b_use, cs_b_use, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - /* If c and ct are col-major, induce transpose and call update for lower-triangle of C */ \ - if( col_pref ) \ - { \ - PASTEMAC(ch,update_lower_triang)( n_off_cblock, m_off_cblock, \ - nr_cur, mr_cur, \ - ct, cs_ct, rs_ct, \ - beta_use, \ - c_ir, cs_c, rs_c ); \ - } \ - else \ - { \ - PASTEMAC(ch,update_upper_triang)( m_off_cblock, n_off_cblock, \ - mr_cur, nr_cur, \ - ct, rs_ct, cs_ct, \ - beta_use, \ - c_ir, rs_c, cs_c ); \ - } \ + PASTEMAC(ch,update_upper_triang)( m_off_cblock, n_off_cblock, \ + mr_cur, nr_cur, \ + ct, rs_ct, cs_ct, \ + beta_use, \ + c_ir, rs_c, cs_c ); \ } \ + a_ir += ps_a_use; \ + c_ir += irstep_c; \ + m_off_cblock += mr_cur; \ +\ } \ } \ } \ diff --git a/frame/util/bli_util_update.c b/frame/util/bli_util_update.c index 0f23424c88..b57c065721 100644 --- a/frame/util/bli_util_update.c +++ b/frame/util/bli_util_update.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2020, Advanced Micro Devices, Inc. + Copyright (C) 2020 - 21, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -61,17 +61,17 @@ void PASTEMAC(ch, varname) \ start = ((n_off < m_off) && (m_off < n_off + n_cur)) ? m_off: n_off; \ end = ((n_off < m_off + m_cur) && (m_off + m_cur < n_off + n_cur))? (m_off + m_cur):(n_off + n_cur); \ \ - if( beta_val != 0.0 ) \ + if ( beta_val == 1.0 ) \ { \ for(diag = start, m= start-m_off; diag < end; diag++, m++) \ for(n = 0; n <= diag-n_off; n++) \ - c[m*rs_c + n] = c[m * rs_c + n] * beta_val + ct[m*rs_ct + n]; \ + c[m*rs_c + n] += ct[m*rs_ct + n]; \ \ for(; m < m_cur; m++) \ for(n = 0; n < n_cur; n++) \ - c[m*rs_c + n] = c[m * rs_c + n] * beta_val + ct[m*rs_ct + n]; \ + c[m*rs_c + n] += ct[m*rs_ct + n]; \ } \ - else \ + else if( beta_val == 0.0 )\ { \ for(diag = start, m= start-m_off; diag < end; diag++, m++) \ for(n = 0; n <= diag-n_off; n++) \ @@ -81,6 +81,16 @@ void PASTEMAC(ch, varname) \ for(n = 0; n < n_cur; n++) \ c[m*rs_c + n] = ct[m*rs_ct + n]; \ } \ + else \ + { \ + for(diag = start, m= start-m_off; diag < end; diag++, m++) \ + for(n = 0; n <= diag-n_off; n++) \ + c[m*rs_c + n] = c[m * rs_c + n] * beta_val + ct[m*rs_ct + n]; \ +\ + for(; m < m_cur; m++) \ + for(n = 0; n < n_cur; n++) \ + c[m*rs_c + n] = c[m * rs_c + n] * beta_val + ct[m*rs_ct + n]; \ + } \ \ return; \ } @@ -109,17 +119,17 @@ void PASTEMAC(ch, varname) \ start = ((n_off < m_off) && (m_off < n_off + n_cur)) ? m_off: n_off; \ end = ((n_off < m_off + m_cur) && (m_off + m_cur < n_off + n_cur))? 
(m_off + m_cur):(n_off + n_cur); \ \ - if( beta_val != 0.0 ) \ + if( beta_val == 1.0 ) \ { \ for(m = 0; m < start-m_off; m++) \ for(n = 0; n < n_cur; n++) \ - c[m*rs_c + n] = c[m * rs_c + n] * beta_val + ct[m*rs_ct + n]; \ + c[m*rs_c + n] += ct[m*rs_ct + n]; \ \ - for(diag = start, m= start-m_off; diag < end; diag++, m++) \ - for(n = diag-n_off; n < n_cur; n++) \ - c[m*rs_c + n] = c[m * rs_c + n] * beta_val + ct[m*rs_ct + n]; \ + for(diag = start, m= start-m_off; diag < end; diag++, m++) \ + for(n = diag-n_off; n < n_cur; n++) \ + c[m*rs_c + n] += ct[m*rs_ct + n]; \ } \ - else \ + else if ( beta_val == 0.0 )\ { \ for(m = 0; m < start-m_off; m++) \ for(n = 0; n < n_cur; n++) \ @@ -129,6 +139,16 @@ void PASTEMAC(ch, varname) \ for(n = diag-n_off; n < n_cur; n++) \ c[m*rs_c + n] = ct[m*rs_ct + n]; \ } \ + else \ + { \ + for(m = 0; m < start-m_off; m++) \ + for(n = 0; n < n_cur; n++) \ + c[m*rs_c + n] = c[m * rs_c + n] * beta_val + ct[m*rs_ct + n]; \ +\ + for(diag = start, m= start-m_off; diag < end; diag++, m++) \ + for(n = diag-n_off; n < n_cur; n++) \ + c[m*rs_c + n] = c[m * rs_c + n] * beta_val + ct[m*rs_ct + n]; \ + } \ \ return; \ } diff --git a/kernels/zen/util/bli_thresh_funcs_zen.c b/kernels/zen/util/bli_thresh_funcs_zen.c index 3aed8bf5bf..1b5fc86998 100644 --- a/kernels/zen/util/bli_thresh_funcs_zen.c +++ b/kernels/zen/util/bli_thresh_funcs_zen.c @@ -79,7 +79,7 @@ bool bli_cntx_syrksup_thresh_is_met_zen( obj_t* a, obj_t* b, obj_t* c, cntx_t* c } else { - if( n < 150 ) return TRUE; + if( n <= 432 ) return TRUE; else return FALSE; } } From 9a5b15da68a768173c2c963db6e188f2dcb2ad96 Mon Sep 17 00:00:00 2001 From: Meghana Vankadari Date: Tue, 21 Sep 2021 12:01:25 +0530 Subject: [PATCH 024/243] Disabled dzgemm_ function call in test_gemm.c Change-Id: I31da12cbaf50cf7fe44baf97de3c39896c4ccfb1 --- test/test_gemm.c | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/test/test_gemm.c b/test/test_gemm.c index b5c2163739..772d73c7b1 100644 --- a/test/test_gemm.c +++ b/test/test_gemm.c @@ -382,6 +382,8 @@ int main( int argc, char** argv ) cp, ldc ); #else +//Disabled dzgemm function temporarily. +#if 0 if( bli_is_double( dt_a ) ) { dzgemm_( @@ -399,6 +401,7 @@ int main( int argc, char** argv ) } else { +#else zgemm_( &f77_transa, &f77_transb, &mm, @@ -409,7 +412,8 @@ int main( int argc, char** argv ) bp, (f77_int*)&ldb, betap, cp, (f77_int*)&ldc ); - } +// } +#endif #endif } #endif From 9f1ce594a5826bb188ec6a7b78b3ebce0e868da9 Mon Sep 17 00:00:00 2001 From: mkurumel Date: Wed, 16 Jun 2021 06:55:40 +0530 Subject: [PATCH 025/243] BLIS : Compiler warning fixes Details : - Fixed warnings with AOCC and GCC compilers. AMD-Internal: [CPUPL-1662] Change-Id: Ia0e298a169d4dd4664b11e03a4e3cd340e9fdfce --- frame/base/bli_cntx.c | 2 +- .../haswell/1m/bli_packm_haswell_asm_c3xk.c | 4 +- .../haswell/1m/bli_packm_haswell_asm_c8xk.c | 4 +- .../haswell/1m/bli_packm_haswell_asm_z3xk.c | 4 +- .../haswell/1m/bli_packm_haswell_asm_z4xk.c | 4 +- kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c | 48 +- kernels/zen/1f/bli_axpyf_zen_int_6.c | 5 +- .../zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4m.c | 1886 +++++++------- .../zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4n.c | 2260 ++++++++--------- 9 files changed, 2084 insertions(+), 2133 deletions(-) diff --git a/frame/base/bli_cntx.c b/frame/base/bli_cntx.c index 2ff56c0ba6..3a8a2f0d70 100644 --- a/frame/base/bli_cntx.c +++ b/frame/base/bli_cntx.c @@ -1631,7 +1631,7 @@ void bli_cntx_set_l3_thresh_funcs( dim_t n_funcs, ... 
) #ifdef BLIS_ENABLE_MEM_TRACING printf( "bli_cntx_set_l3_thresh_funcs(): " ); #endif - l1vkr_t* func_ids = bli_malloc_intl( n_funcs * sizeof( opid_t ) ); + opid_t* func_ids = bli_malloc_intl( n_funcs * sizeof( opid_t ) ); #ifdef BLIS_ENABLE_MEM_TRACING printf( "bli_cntx_set_l3_thresh_funcs(): " ); diff --git a/kernels/haswell/1m/bli_packm_haswell_asm_c3xk.c b/kernels/haswell/1m/bli_packm_haswell_asm_c3xk.c index c31384cc45..b99b6eef26 100644 --- a/kernels/haswell/1m/bli_packm_haswell_asm_c3xk.c +++ b/kernels/haswell/1m/bli_packm_haswell_asm_c3xk.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019 - 2020, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 2021, Advanced Micro Devices, Inc.All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -104,7 +104,7 @@ void bli_cpackm_haswell_asm_3xk // ------------------------------------------------------------------------- - if ( cdim0 == mnr && !gs && !bli_does_conj( conja ) && unitk ) + if ( cdim0 == mnr && !gs && !conja && unitk ) { begin_asm() diff --git a/kernels/haswell/1m/bli_packm_haswell_asm_c8xk.c b/kernels/haswell/1m/bli_packm_haswell_asm_c8xk.c index 02c894a393..4cad0c90c3 100644 --- a/kernels/haswell/1m/bli_packm_haswell_asm_c8xk.c +++ b/kernels/haswell/1m/bli_packm_haswell_asm_c8xk.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019 - 2020, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 2021, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -104,7 +104,7 @@ void bli_cpackm_haswell_asm_8xk // ------------------------------------------------------------------------- - if ( cdim0 == mnr && !gs && !bli_does_conj( conja ) && unitk ) + if ( cdim0 == mnr && !gs && !conja && unitk ) { begin_asm() diff --git a/kernels/haswell/1m/bli_packm_haswell_asm_z3xk.c b/kernels/haswell/1m/bli_packm_haswell_asm_z3xk.c index 26b98f4daf..06fcf1438a 100644 --- a/kernels/haswell/1m/bli_packm_haswell_asm_z3xk.c +++ b/kernels/haswell/1m/bli_packm_haswell_asm_z3xk.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019 - 2020, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 2021, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -104,7 +104,7 @@ void bli_zpackm_haswell_asm_3xk // ------------------------------------------------------------------------- - if ( cdim0 == mnr && !gs && !bli_does_conj( conja ) && unitk ) + if ( cdim0 == mnr && !gs && !conja && unitk ) { begin_asm() diff --git a/kernels/haswell/1m/bli_packm_haswell_asm_z4xk.c b/kernels/haswell/1m/bli_packm_haswell_asm_z4xk.c index 6552317541..25a8b6181e 100644 --- a/kernels/haswell/1m/bli_packm_haswell_asm_z4xk.c +++ b/kernels/haswell/1m/bli_packm_haswell_asm_z4xk.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019 - 2020, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 2021, Advanced Micro Devices, Inc.All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -104,7 +104,7 @@ void bli_zpackm_haswell_asm_4xk // ------------------------------------------------------------------------- - if ( cdim0 == mnr && !gs && !bli_does_conj( conja ) && unitk ) + if ( cdim0 == mnr && !gs && !conja && unitk ) { begin_asm() diff --git a/kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c b/kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c index 315894b171..b4ac979e1a 100644 --- a/kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c +++ b/kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2021, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2021, Advanced Micro Devices, Inc.All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -2224,40 +2224,24 @@ void bli_zgemm_haswell_asm_3x4 uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; - //handling case when alpha and beta are real and +/-1. - uint64_t alpha_real_one = *((uint64_t*)(&alpha->real)); - uint64_t beta_real_one = *((uint64_t*)(&beta->real)); - - uint64_t alpha_real_one_abs = ((alpha_real_one << 1) >> 1); - uint64_t beta_real_one_abs = ((beta_real_one << 1) >> 1); - char alpha_mul_type = BLIS_MUL_DEFAULT; char beta_mul_type = BLIS_MUL_DEFAULT; - if((alpha_real_one_abs == BLIS_DOUBLE_TO_UINT64_ONE_ABS) && (alpha->imag==0))// (alpha is real and +/-1) - { - alpha_mul_type = BLIS_MUL_ONE; //alpha real and 1 - if(alpha_real_one == BLIS_DOUBLE_TO_UINT64_MINUS_ONE) - { - alpha_mul_type = BLIS_MUL_MINUS_ONE; //alpha real and -1 - } - } - - if(beta->imag == 0)// beta is real - { - if(beta_real_one_abs == BLIS_DOUBLE_TO_UINT64_ONE_ABS)// (beta +/-1) - { - beta_mul_type = BLIS_MUL_ONE; - if(beta_real_one == BLIS_DOUBLE_TO_UINT64_MINUS_ONE) - { - beta_mul_type = BLIS_MUL_MINUS_ONE; - } - } - else if(beta_real_one == 0) - { - beta_mul_type = BLIS_MUL_ZERO; - } - } + //handling case when alpha and beta are real and +/-1. + + if(alpha->imag == 0.0)// (alpha is real) + { + if(alpha->real == 1.0) alpha_mul_type = BLIS_MUL_ONE; + else if(alpha->real == -1.0) alpha_mul_type = BLIS_MUL_MINUS_ONE; + else if(alpha->real == 0.0) alpha_mul_type = BLIS_MUL_ZERO; + } + + if(beta->imag == 0.0)// (beta is real) + { + if(beta->real == 1.0) beta_mul_type = BLIS_MUL_ONE; + else if(beta->real == -1.0) beta_mul_type = BLIS_MUL_MINUS_ONE; + else if(beta->real == 0.0) beta_mul_type = BLIS_MUL_ZERO; + } begin_asm() diff --git a/kernels/zen/1f/bli_axpyf_zen_int_6.c b/kernels/zen/1f/bli_axpyf_zen_int_6.c index d27dce6cf2..99b544db15 100644 --- a/kernels/zen/1f/bli_axpyf_zen_int_6.c +++ b/kernels/zen/1f/bli_axpyf_zen_int_6.c @@ -83,10 +83,9 @@ void bli_saxpyf_zen_int_6 v8sf_t chi0v, chi1v, chi2v, chi3v; v8sf_t chi4v,chi5v; - v8sf_t a00v, a01v, a02v, a03v; - v8sf_t a04v,a05v; + v8sf_t a00v, a01v; - v8sf_t y0v, y1v; + v8sf_t y0v; float chi0, chi1, chi2, chi3; float chi4,chi5; diff --git a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4m.c b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4m.c index 1e9bacd9ae..64aedb8791 100644 --- a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4m.c +++ b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4m.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2021, Advanced Micro Devices, Inc. 
+ Copyright (C) 2020 - 2021, Advanced Micro Devices, Inc.All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -42,57 +42,57 @@ and store outputs to ymm0 (creal,cimag)*(betar,beati) where c is stored in col major order*/ #define ZGEMM_INPUT_SCALE_CS_BETA_NZ \ - vmovupd(mem(rcx), xmm0) \ - vmovupd(mem(rcx, rsi, 1), xmm3) \ - vinsertf128(imm(1), xmm3, ymm0, ymm0) \ - vpermilpd(imm(0x5), ymm0, ymm3) \ - vmulpd(ymm1, ymm0, ymm0) \ - vmulpd(ymm2, ymm3, ymm3) \ - vaddsubpd(ymm3, ymm0, ymm0) + vmovupd(mem(rcx), xmm0) \ + vmovupd(mem(rcx, rsi, 1), xmm3) \ + vinsertf128(imm(1), xmm3, ymm0, ymm0) \ + vpermilpd(imm(0x5), ymm0, ymm3) \ + vmulpd(ymm1, ymm0, ymm0) \ + vmulpd(ymm2, ymm3, ymm3) \ + vaddsubpd(ymm3, ymm0, ymm0) //(creal,cimag)*(betar,beati) where c is stored in row major order #define ZGEMM_INPUT_SCALE_RS_BETA_NZ \ - vmovupd(mem(rcx), ymm0) \ - vpermilpd(imm(0x5), ymm0, ymm3) \ - vmulpd(ymm1, ymm0, ymm0) \ - vmulpd(ymm2, ymm3, ymm3) \ - vaddsubpd(ymm3, ymm0, ymm0) + vmovupd(mem(rcx), ymm0) \ + vpermilpd(imm(0x5), ymm0, ymm3) \ + vmulpd(ymm1, ymm0, ymm0) \ + vmulpd(ymm2, ymm3, ymm3) \ + vaddsubpd(ymm3, ymm0, ymm0) #define ZGEMM_INPUT_RS_BETA_ONE \ - vmovupd(mem(rcx), ymm0) + vmovupd(mem(rcx), ymm0) #define ZGEMM_OUTPUT_RS \ - vmovupd(ymm0, mem(rcx)) \ + vmovupd(ymm0, mem(rcx)) \ -/*(cNextRowreal,cNextRowimag)*(betar,beati) +/*(cNextRowreal,cNextRowimag)*(betar,beati) where c is stored in row major order rsi = cs_c * sizeof((real +imag)dt)*numofElements numofElements = 2, 2 elements are processed at a time*/ #define ZGEMM_INPUT_SCALE_RS_BETA_NZ_NEXT \ - vmovupd(mem(rcx, rsi, 1), ymm0) \ - vpermilpd(imm(0x5), ymm0, ymm3) \ - vmulpd(ymm1, ymm0, ymm0) \ - vmulpd(ymm2, ymm3, ymm3) \ - vaddsubpd(ymm3, ymm0, ymm0) + vmovupd(mem(rcx, rsi, 1), ymm0) \ + vpermilpd(imm(0x5), ymm0, ymm3) \ + vmulpd(ymm1, ymm0, ymm0) \ + vmulpd(ymm2, ymm3, ymm3) \ + vaddsubpd(ymm3, ymm0, ymm0) #define ZGEMM_INPUT_RS_BETA_ONE_NEXT \ - vmovupd(mem(rcx, rsi, 1), ymm0) + vmovupd(mem(rcx, rsi, 1), ymm0) #define ZGEMM_OUTPUT_RS_NEXT \ - vmovupd(ymm0, mem(rcx, rsi, 1)) + vmovupd(ymm0, mem(rcx, rsi, 1)) /* rrr: - -------- ------ -------- - -------- += ------ ... -------- - -------- ------ -------- - -------- ------ : + -------- ------ -------- + -------- += ------ ... -------- + -------- ------ -------- + -------- ------ : rcr: - -------- | | | | -------- - -------- += | | | | ... -------- - -------- | | | | -------- - -------- | | | | : + -------- | | | | -------- + -------- += | | | | ... -------- + -------- | | | | -------- + -------- | | | | : Assumptions: - B is row-stored; @@ -108,11 +108,11 @@ cost of the in-register transpose). crr: - | | | | | | | | ------ -------- - | | | | | | | | += ------ - -------- - | | | | | | | | ------ -------- - | | | | | | | | ------ : + | | | | | | | | ------ -------- + | | | | | | | | += ------ + -------- + | | | | | | | | ------ -------- + | | | | | | | | ------ : */ void bli_zgemmsup_rv_zen_asm_3x4m ( @@ -130,666 +130,650 @@ void bli_zgemmsup_rv_zen_asm_3x4m cntx_t* restrict cntx ) { - uint64_t n_left = n0 % 4; + uint64_t n_left = n0 % 4; - // First check whether this is a edge case in the n dimension. If so, - // dispatch other 3x?m kernels, as needed. - if (n_left ) - { + // First check whether this is a edge case in the n dimension. If so, + // dispatch other 3x?m kernels, as needed. 
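	// (A pair of leftover columns is handed to the 3x2m kernel and a final
	// single column to a gemv call, after which the routine returns; the
	// 3x4 main body below therefore only runs when n0 is a multiple of 4.)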
+ if (n_left ) + { dcomplex* cij = c; dcomplex* bj = b; dcomplex* ai = a; - if ( 2 <= n_left ) - { - const dim_t nr_cur = 2; - - bli_zgemmsup_rv_zen_asm_3x2m - ( - conja, conjb, m0, nr_cur, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, - beta, cij, rs_c0, cs_c0, data, cntx - ); - cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; - } - if ( 1 == n_left ) - { - bli_zgemv_ex - ( - BLIS_NO_TRANSPOSE, conjb, m0, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, - beta, cij, rs_c0, cntx, NULL - ); - } - - return; - } - - //void* a_next = bli_auxinfo_next_a( data ); - //void* b_next = bli_auxinfo_next_b( data ); - - // Typecast local copies of integers in case dim_t and inc_t are a - // different size than is expected by load instructions. - - uint64_t k_iter = k0 / 4; - uint64_t k_left = k0 % 4; - - uint64_t m_iter = m0 / 3; - uint64_t m_left = m0 % 3; - - uint64_t rs_a = rs_a0; - uint64_t cs_a = cs_a0; - uint64_t rs_b = rs_b0; - uint64_t rs_c = rs_c0; - uint64_t cs_c = cs_c0; - - if ( m_iter == 0 ) goto consider_edge_cases; - - //handling case when alpha and beta are real and +/-1. - uint64_t alpha_real_one = *((uint64_t*)(&alpha->real)); - uint64_t beta_real_one = *((uint64_t*)(&beta->real)); - - uint64_t alpha_real_one_abs = ((alpha_real_one << 1) >> 1); - uint64_t beta_real_one_abs = ((beta_real_one << 1) >> 1); - - char alpha_mul_type = BLIS_MUL_DEFAULT; - char beta_mul_type = BLIS_MUL_DEFAULT; - - if((alpha_real_one_abs == BLIS_DOUBLE_TO_UINT64_ONE_ABS) && (alpha->imag==0))// (alpha is real and +/-1) - { - alpha_mul_type = BLIS_MUL_ONE; //alpha real and 1 - if(alpha_real_one == BLIS_DOUBLE_TO_UINT64_MINUS_ONE) - { - alpha_mul_type = BLIS_MUL_MINUS_ONE; //alpha real and -1 - } - } - - if(beta->imag == 0)// beta is real - { - if(beta_real_one_abs == BLIS_DOUBLE_TO_UINT64_ONE_ABS)// (beta +/-1) - { - beta_mul_type = BLIS_MUL_ONE; - if(beta_real_one == BLIS_DOUBLE_TO_UINT64_MINUS_ONE) - { - beta_mul_type = BLIS_MUL_MINUS_ONE; - } - } - else if(beta_real_one == 0) - { - beta_mul_type = BLIS_MUL_ZERO; - } - } - - // ------------------------------------------------------------------------- - - begin_asm() - - mov(var(a), r14) // load address of a. - mov(var(rs_a), r8) // load rs_a - mov(var(cs_a), r9) // load cs_a - lea(mem(, r8, 8), r8) // rs_a *= sizeof(real dt) - lea(mem(, r8, 2), r8) // rs_a *= sizeof((real + imag) dt) - lea(mem(, r9, 8), r9) // cs_a *= sizeof( real dt) - lea(mem(, r9, 2), r9) // cs_a *= sizeof((real + imag) dt) - - mov(var(rs_b), r10) // load rs_b - lea(mem(, r10, 8), r10) // rs_b *= sizeof(real dt) - lea(mem(, r10, 2), r10) // rs_b *= sizeof((real +imag) dt) - - // NOTE: We cannot pre-load elements of a or b - // because it could eventually, in the last - // unrolled iter or the cleanup loop, result - // in reading beyond the bounds allocated mem - // (the likely result: a segmentation fault). 
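	// (Structure of the code below: the outer loop at .ZLOOP3X4I walks the m
	// dimension three rows at a time (m_iter = m0 / 3); for each 3x4 tile of C
	// the k dimension is driven by .ZLOOPKITER, unrolled by a factor of four,
	// and the k0 % 4 remainder is handled one iteration at a time in
	// .ZLOOPKLEFT.)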
- - mov(var(c), r12) // load address of c - mov(var(rs_c), rdi) // load rs_c - lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(dt) - lea(mem(, rdi, 2), rdi) // rs_c *= sizeof(dt) - - // During preamble and loops: - // r12 = rcx = c - // r14 = rax = a - // read rbx from var(b) near beginning of loop - // r11 = m dim index ii - - mov(var(m_iter), r11) // ii = m_iter; + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_zgemmsup_rv_zen_asm_3x2m + ( + conja, conjb, m0, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { + bli_zgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, m0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); + } + + return; + } + + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; - label(.ZLOOP3X4I) // LOOP OVER ii = [ m_iter ... 1 0 ] + uint64_t m_iter = m0 / 3; + uint64_t m_left = m0 % 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + if ( m_iter == 0 ) goto consider_edge_cases; + + char alpha_mul_type = BLIS_MUL_DEFAULT; + char beta_mul_type = BLIS_MUL_DEFAULT; + + //handling case when alpha and beta are real and +/-1. - vzeroall() // zero all xmm/ymm registers. + if(alpha->imag == 0.0)// (alpha is real) + { + if(alpha->real == 1.0) alpha_mul_type = BLIS_MUL_ONE; + else if(alpha->real == -1.0) alpha_mul_type = BLIS_MUL_MINUS_ONE; + else if(alpha->real == 0.0) alpha_mul_type = BLIS_MUL_ZERO; + } - mov(var(b), rbx) // load address of b. - mov(r14, rax) // reset rax to current upanel of a. + if(beta->imag == 0.0)// (beta is real) + { + if(beta->real == 1.0) beta_mul_type = BLIS_MUL_ONE; + else if(beta->real == -1.0) beta_mul_type = BLIS_MUL_MINUS_ONE; + else if(beta->real == 0.0) beta_mul_type = BLIS_MUL_ZERO; + } - cmp(imm(16), rdi) // set ZF if (16*rs_c) == 16. - jz(.ZCOLPFETCH) // jump to column storage case - label(.ZROWPFETCH) // row-stored pre-fetching on c // not used + // ------------------------------------------------------------------------- - jmp(.ZPOSTPFETCH) // jump to end of pre-fetching c - label(.ZCOLPFETCH) // column-stored pre-fetching c + begin_asm() - mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) - lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(dt) - label(.ZPOSTPFETCH) // done prefetching c + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(real dt) + lea(mem(, r8, 2), r8) // rs_a *= sizeof((real + imag) dt) + lea(mem(, r9, 8), r9) // cs_a *= sizeof( real dt) + lea(mem(, r9, 2), r9) // cs_a *= sizeof((real + imag) dt) - mov(var(k_iter), rsi) // i = k_iter; - test(rsi, rsi) // check i via logical AND. - je(.ZCONSIDKLEFT) // if i == 0, jump to code that - // contains the k_left loop. 
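	// (Register roles in the k loop that follows: ymm0/ymm1 hold one 1x4 row
	// of B, two dcomplex elements each; ymm2 broadcasts the real part and
	// ymm3 the imaginary part of an element of A. The products accumulate as
	// ymm4/ymm5, ymm8/ymm9, ymm12/ymm13 for the real-part contributions of
	// rows 0, 1, 2 of the micro-tile, and ymm6/ymm7, ymm10/ymm11, ymm14/ymm15
	// for the matching imaginary-part contributions.)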
+ mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(real dt) + lea(mem(, r10, 2), r10) // rs_b *= sizeof((real +imag) dt) - label(.ZLOOPKITER) // MAIN LOOP + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). - // ---------------------------------- iteration 0 + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(dt) + lea(mem(, rdi, 2), rdi) // rs_c *= sizeof(dt) - vmovupd(mem(rbx, 0*32), ymm0) - vmovupd(mem(rbx, 1*32), ymm1) - add(r10, rbx) // b += rs_b; + // During preamble and loops: + // r12 = rcx = c + // r14 = rax = a + // read rbx from var(b) near beginning of loop + // r11 = m dim index ii - vbroadcastsd(mem(rax ), ymm2) - vfmadd231pd(ymm0, ymm2, ymm4) - vfmadd231pd(ymm1, ymm2, ymm5) + mov(var(m_iter), r11) // ii = m_iter; - vbroadcastsd(mem(rax, r8, 1), ymm2) - vfmadd231pd(ymm0, ymm2, ymm8) - vfmadd231pd(ymm1, ymm2, ymm9) + label(.ZLOOP3X4I) // LOOP OVER ii = [ m_iter ... 1 0 ] - vbroadcastsd(mem(rax, r8, 2), ymm2) - vfmadd231pd(ymm0, ymm2, ymm12) - vfmadd231pd(ymm1, ymm2, ymm13) + vzeroall() // zero all xmm/ymm registers. - vbroadcastsd(mem(rax, 8), ymm3) - vfmadd231pd(ymm0, ymm3, ymm6) - vfmadd231pd(ymm1, ymm3, ymm7) + mov(var(b), rbx) // load address of b. + mov(r14, rax) // reset rax to current upanel of a. - vbroadcastsd(mem(rax, r8, 1, 8), ymm3) - vfmadd231pd(ymm0, ymm3, ymm10) - vfmadd231pd(ymm1, ymm3, ymm11) + cmp(imm(16), rdi) // set ZF if (16*rs_c) == 16. + jz(.ZCOLPFETCH) // jump to column storage case + label(.ZROWPFETCH) // row-stored pre-fetching on c // not used - vbroadcastsd(mem(rax, r8, 2, 8), ymm3) - vfmadd231pd(ymm0, ymm3, ymm14) - vfmadd231pd(ymm1, ymm3, ymm15) + jmp(.ZPOSTPFETCH) // jump to end of pre-fetching c + label(.ZCOLPFETCH) // column-stored pre-fetching c - add(r9, rax) // a += cs_a; + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(dt) + label(.ZPOSTPFETCH) // done prefetching c - // ---------------------------------- iteration 1 + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.ZCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. 
- vmovupd(mem(rbx, 0*32), ymm0) - vmovupd(mem(rbx, 1*32), ymm1) - add(r10, rbx) // b += rs_b; + label(.ZLOOPKITER) // MAIN LOOP - vbroadcastsd(mem(rax ), ymm2) - vfmadd231pd(ymm0, ymm2, ymm4) - vfmadd231pd(ymm1, ymm2, ymm5) + // ---------------------------------- iteration 0 - vbroadcastsd(mem(rax, r8, 1), ymm2) - vfmadd231pd(ymm0, ymm2, ymm8) - vfmadd231pd(ymm1, ymm2, ymm9) + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; - vbroadcastsd(mem(rax, r8, 2), ymm2) - vfmadd231pd(ymm0, ymm2, ymm12) - vfmadd231pd(ymm1, ymm2, ymm13) + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) - vbroadcastsd(mem(rax, 8 ), ymm3) - vfmadd231pd(ymm0, ymm3, ymm6) - vfmadd231pd(ymm1, ymm3, ymm7) + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) - vbroadcastsd(mem(rax, r8, 1, 8), ymm3) - vfmadd231pd(ymm0, ymm3, ymm10) - vfmadd231pd(ymm1, ymm3, ymm11) + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) - vbroadcastsd(mem(rax, r8, 2, 8), ymm3) - vfmadd231pd(ymm0, ymm3, ymm14) - vfmadd231pd(ymm1, ymm3, ymm15) + vbroadcastsd(mem(rax, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) - add(r9, rax) // a += cs_a; + vbroadcastsd(mem(rax, r8, 1, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) - // ---------------------------------- iteration 2 + vbroadcastsd(mem(rax, r8, 2, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) - vmovupd(mem(rbx, 0*32), ymm0) - vmovupd(mem(rbx, 1*32), ymm1) - add(r10, rbx) // b += rs_b; + add(r9, rax) // a += cs_a; - vbroadcastsd(mem(rax ), ymm2) - vfmadd231pd(ymm0, ymm2, ymm4) - vfmadd231pd(ymm1, ymm2, ymm5) + // ---------------------------------- iteration 1 - vbroadcastsd(mem(rax, r8, 1), ymm2) - vfmadd231pd(ymm0, ymm2, ymm8) - vfmadd231pd(ymm1, ymm2, ymm9) + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; - vbroadcastsd(mem(rax, r8, 2), ymm2) - vfmadd231pd(ymm0, ymm2, ymm12) - vfmadd231pd(ymm1, ymm2, ymm13) + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) - vbroadcastsd(mem(rax, 8 ), ymm3) - vfmadd231pd(ymm0, ymm3, ymm6) - vfmadd231pd(ymm1, ymm3, ymm7) + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) - vbroadcastsd(mem(rax, r8, 1, 8), ymm3) - vfmadd231pd(ymm0, ymm3, ymm10) - vfmadd231pd(ymm1, ymm3, ymm11) + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) - vbroadcastsd(mem(rax, r8, 2, 8), ymm3) - vfmadd231pd(ymm0, ymm3, ymm14) - vfmadd231pd(ymm1, ymm3, ymm15) + vbroadcastsd(mem(rax, 8 ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) - add(r9, rax) // a += cs_a; + vbroadcastsd(mem(rax, r8, 1, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) - // ---------------------------------- iteration 3 - vmovupd(mem(rbx, 0*32), ymm0) - vmovupd(mem(rbx, 1*32), ymm1) - add(r10, rbx) // b += rs_b; + vbroadcastsd(mem(rax, r8, 2, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) - vbroadcastsd(mem(rax ), ymm2) - vfmadd231pd(ymm0, ymm2, ymm4) - vfmadd231pd(ymm1, ymm2, ymm5) + add(r9, rax) // a += cs_a; - vbroadcastsd(mem(rax, r8, 1), ymm2) - vfmadd231pd(ymm0, ymm2, ymm8) - vfmadd231pd(ymm1, ymm2, ymm9) + // ---------------------------------- iteration 2 - 
vbroadcastsd(mem(rax, r8, 2), ymm2) - vfmadd231pd(ymm0, ymm2, ymm12) - vfmadd231pd(ymm1, ymm2, ymm13) + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; - vbroadcastsd(mem(rax, 8), ymm3) - vfmadd231pd(ymm0, ymm3, ymm6) - vfmadd231pd(ymm1, ymm3, ymm7) + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) - vbroadcastsd(mem(rax, r8, 1, 8), ymm3) - vfmadd231pd(ymm0, ymm3, ymm10) - vfmadd231pd(ymm1, ymm3, ymm11) + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) - vbroadcastsd(mem(rax, r8, 2, 8), ymm3) - vfmadd231pd(ymm0, ymm3, ymm14) - vfmadd231pd(ymm1, ymm3, ymm15) + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) - add(r9, rax) // a += cs_a; + vbroadcastsd(mem(rax, 8 ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) - dec(rsi) // i -= 1; - jne(.ZLOOPKITER) // iterate again if i != 0. + vbroadcastsd(mem(rax, r8, 1, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) - label(.ZCONSIDKLEFT) + vbroadcastsd(mem(rax, r8, 2, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) - mov(var(k_left), rsi) // i = k_left; - test(rsi, rsi) // check i via logical AND. - je(.ZPOSTACCUM) // if i == 0, we're done; jump to end. - // else, we prepare to enter k_left loop. + add(r9, rax) // a += cs_a; - label(.ZLOOPKLEFT) // EDGE LOOP + // ---------------------------------- iteration 3 + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; - vmovupd(mem(rbx, 0*32), ymm0) - vmovupd(mem(rbx, 1*32), ymm1) - add(r10, rbx) // b += rs_b; + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) - vbroadcastsd(mem(rax ), ymm2) - vfmadd231pd(ymm0, ymm2, ymm4) - vfmadd231pd(ymm1, ymm2, ymm5) + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) - vbroadcastsd(mem(rax, r8, 1), ymm2) - vfmadd231pd(ymm0, ymm2, ymm8) - vfmadd231pd(ymm1, ymm2, ymm9) + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) - vbroadcastsd(mem(rax, r8, 2), ymm2) - vfmadd231pd(ymm0, ymm2, ymm12) - vfmadd231pd(ymm1, ymm2, ymm13) + vbroadcastsd(mem(rax, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) - vbroadcastsd(mem(rax, 8), ymm3) - vfmadd231pd(ymm0, ymm3, ymm6) - vfmadd231pd(ymm1, ymm3, ymm7) + vbroadcastsd(mem(rax, r8, 1, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) - vbroadcastsd(mem(rax, r8, 1, 8), ymm3) - vfmadd231pd(ymm0, ymm3, ymm10) - vfmadd231pd(ymm1, ymm3, ymm11) + vbroadcastsd(mem(rax, r8, 2, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) - vbroadcastsd(mem(rax, r8, 2, 8), ymm3) - vfmadd231pd(ymm0, ymm3, ymm14) - vfmadd231pd(ymm1, ymm3, ymm15) + add(r9, rax) // a += cs_a; - add(r9, rax) // a += cs_a; + dec(rsi) // i -= 1; + jne(.ZLOOPKITER) // iterate again if i != 0. - dec(rsi) // i -= 1; - jne(.ZLOOPKLEFT) // iterate again if i != 0. + label(.ZCONSIDKLEFT) - label(.ZPOSTACCUM) + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.ZPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. - mov(r12, rcx) // reset rcx to current utile of c. 
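The k dimension is split ahead of the assembly into k_iter = k0/4 passes through the 4x-unrolled .ZLOOPKITER body above and k_left = k0%4 single steps handled by the .ZLOOPKLEFT edge loop that follows. Stripped of the vector work, the control flow reduces to the sketch below; rank1_update is a hypothetical stand-in for the broadcast+FMA sequence, not a BLIS routine.

    #include <stdint.h>
    #include <stdio.h>

    /* Hypothetical helper: one rank-1 update of the accumulator tile. */
    static void rank1_update( uint64_t k )
    {
        (void)k;   /* the real kernel issues the broadcast + FMA sequence here */
    }

    static void k_loop( uint64_t k0 )
    {
        uint64_t k_iter = k0 / 4;    /* number of 4x-unrolled iterations      */
        uint64_t k_left = k0 % 4;    /* remainder handled one step at a time  */
        uint64_t k = 0, i;

        for( i = 0; i < k_iter; ++i )    /* .ZLOOPKITER                        */
        {
            rank1_update( k++ );         /* iteration 0                        */
            rank1_update( k++ );         /* iteration 1                        */
            rank1_update( k++ );         /* iteration 2                        */
            rank1_update( k++ );         /* iteration 3                        */
        }
        for( i = 0; i < k_left; ++i )    /* .ZLOOPKLEFT                        */
            rank1_update( k++ );
    }

    int main( void ) { k_loop( 10 ); puts( "done" ); return 0; }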
+ label(.ZLOOPKLEFT) // EDGE LOOP - // permute even and odd elements - // of ymm6/7, ymm10/11, ymm/14/15 - vpermilpd(imm(0x5), ymm6, ymm6) - vpermilpd(imm(0x5), ymm7, ymm7) - vpermilpd(imm(0x5), ymm10, ymm10) - vpermilpd(imm(0x5), ymm11, ymm11) - vpermilpd(imm(0x5), ymm14, ymm14) - vpermilpd(imm(0x5), ymm15, ymm15) + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; - // subtract/add even/odd elements - vaddsubpd(ymm6, ymm4, ymm4) - vaddsubpd(ymm7, ymm5, ymm5) + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) - vaddsubpd(ymm10, ymm8, ymm8) - vaddsubpd(ymm11, ymm9, ymm9) + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) - vaddsubpd(ymm14, ymm12, ymm12) - vaddsubpd(ymm15, ymm13, ymm13) + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) - mov(var(cs_c), rsi) // load cs_c - lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(real dt) - lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof((real +imag)dt) + vbroadcastsd(mem(rax, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) - //if(alpha_mul_type == BLIS_MUL_MINUS_ONE) - mov(var(alpha_mul_type), al) - cmp(imm(0xFF), al) - jne(.ALPHA_NOT_MINUS1) + vbroadcastsd(mem(rax, r8, 1, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) - // when alpha = -1 and real. - vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. - vsubpd(ymm4, ymm0, ymm4) - vsubpd(ymm5, ymm0, ymm5) - vsubpd(ymm8, ymm0, ymm8) - vsubpd(ymm9, ymm0, ymm9) - vsubpd(ymm12, ymm0, ymm12) - vsubpd(ymm13, ymm0, ymm13) - jmp(.ALPHA_REAL_ONE) - - label(.ALPHA_NOT_MINUS1) - //when alpha is real and +/-1, multiplication is skipped. - cmp(imm(2), al)//if(alpha_mul_type != BLIS_MUL_DEFAULT) skip below multiplication. - jne(.ALPHA_REAL_ONE) + vbroadcastsd(mem(rax, r8, 2, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) - /* (ar + ai) x AB */ - mov(var(alpha), rax) // load address of alpha - vbroadcastsd(mem(rax), ymm0) // load alpha_r and duplicate - vbroadcastsd(mem(rax, 8), ymm1) // load alpha_i and duplicate + add(r9, rax) // a += cs_a; - vpermilpd(imm(0x5), ymm4, ymm3) - vmulpd(ymm0, ymm4, ymm4) - vmulpd(ymm1, ymm3, ymm3) - vaddsubpd(ymm3, ymm4, ymm4) - - vpermilpd(imm(0x5), ymm5, ymm3) - vmulpd(ymm0, ymm5, ymm5) - vmulpd(ymm1, ymm3, ymm3) - vaddsubpd(ymm3, ymm5, ymm5) - - vpermilpd(imm(0x5), ymm8, ymm3) - vmulpd(ymm0, ymm8, ymm8) - vmulpd(ymm1, ymm3, ymm3) - vaddsubpd(ymm3, ymm8, ymm8) - - vpermilpd(imm(0x5), ymm9, ymm3) - vmulpd(ymm0, ymm9, ymm9) - vmulpd(ymm1, ymm3, ymm3) - vaddsubpd(ymm3, ymm9, ymm9) - - vpermilpd(imm(0x5), ymm12, ymm3) - vmulpd(ymm0, ymm12, ymm12) - vmulpd(ymm1, ymm3, ymm3) - vaddsubpd(ymm3, ymm12, ymm12) - - vpermilpd(imm(0x5), ymm13, ymm3) - vmulpd(ymm0, ymm13, ymm13) - vmulpd(ymm1, ymm3, ymm3) - vaddsubpd(ymm3, ymm13, ymm13) - - label(.ALPHA_REAL_ONE) - // Beta multiplication - /* (br + bi)x C + ((ar + ai) x AB) */ - - mov(var(beta_mul_type), al) - cmp(imm(0), al) //if(beta_mul_type == BLIS_MUL_ZERO) - je(.ZBETAZERO) //jump to beta == 0 case - - cmp(imm(16), rdi) // set ZF if (16*rs_c) ==16. - jz(.ZCOLSTORED) // jump to column storage case - - label(.ZROWSTORED) - - lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof((real +imag)dt) * numofElements - - cmp(imm(2), al) // if(beta_mul_type == BLIS_MUL_DEFAULT) - je(.ROW_BETA_NOT_REAL_ONE) // jump to beta handling with multiplication. 
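Each k step above broadcasts the real part of an A element into one set of accumulators (ymm4/ymm5, ymm8/ymm9, ymm12/ymm13) and the imaginary part into a second set (ymm6/ymm7, ymm10/ymm11, ymm14/ymm15). After the k loop (.ZPOSTACCUM), the imaginary-part accumulators are permuted so each re/im pair is swapped and then folded in with vaddsubpd, which yields the interleaved complex products ar*br - ai*bi and ar*bi + ai*br. The following is a standalone AVX2 intrinsics sketch of that split-accumulator scheme for a single register of two dcomplex values (one k step, alpha = 1, beta = 0); it is illustrative and not taken from BLIS, and the input values are arbitrary.

    #include <immintrin.h>
    #include <stdio.h>

    int main( void )
    {
        double b[4] = { 1.0, 2.0, 3.0, 4.0 };   /* (1+2i, 3+4i), interleaved re/im */
        double ar = 0.5, ai = -1.5;             /* a = 0.5 - 1.5i                  */
        double out[4];

        __m256d vb     = _mm256_loadu_pd( b );
        __m256d acc_re = _mm256_setzero_pd();   /* ymm4/ymm5-style accumulator     */
        __m256d acc_im = _mm256_setzero_pd();   /* ymm6/ymm7-style accumulator     */

        /* One "k iteration": broadcast ar and ai separately, two FMAs. */
        acc_re = _mm256_fmadd_pd( vb, _mm256_set1_pd( ar ), acc_re );
        acc_im = _mm256_fmadd_pd( vb, _mm256_set1_pd( ai ), acc_im );

        /* Post-accumulation combine: swap re/im of the imaginary-part
           accumulator, then alternately subtract/add into the real one. */
        __m256d swapped = _mm256_permute_pd( acc_im, 0x5 );
        __m256d prod    = _mm256_addsub_pd( acc_re, swapped );

        _mm256_storeu_pd( out, prod );
        printf( "(%g%+gi) (%g%+gi)\n", out[0], out[1], out[2], out[3] );
        /* expected: (3.5-0.5i) (7.5-2.5i) */
        return 0;
    }

Splitting the real and imaginary contributions this way keeps the inner loop to pure FMAs; the shuffle and addsub work is paid only once per tile, after the loop.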
- - cmp(imm(0xFF), al) // if(beta_mul_type == BLIS_MUL_MINUS_ONE) - je(.ROW_BETA_REAL_MINUS1) // jump to beta real = -1 section. - - //CASE 1: beta is real = 1 - ZGEMM_INPUT_RS_BETA_ONE - vaddpd(ymm4, ymm0, ymm0) - ZGEMM_OUTPUT_RS - - ZGEMM_INPUT_RS_BETA_ONE_NEXT - vaddpd(ymm5, ymm0, ymm0) - ZGEMM_OUTPUT_RS_NEXT - add(rdi, rcx) // rcx = c + 1*rs_c - - ZGEMM_INPUT_RS_BETA_ONE - vaddpd(ymm8, ymm0, ymm0) - ZGEMM_OUTPUT_RS - - ZGEMM_INPUT_RS_BETA_ONE_NEXT - vaddpd(ymm9, ymm0, ymm0) - ZGEMM_OUTPUT_RS_NEXT - add(rdi, rcx) // rcx = c + 2*rs_c - - ZGEMM_INPUT_RS_BETA_ONE - vaddpd(ymm12, ymm0, ymm0) - ZGEMM_OUTPUT_RS - - ZGEMM_INPUT_RS_BETA_ONE_NEXT - vaddpd(ymm13, ymm0, ymm0) - ZGEMM_OUTPUT_RS_NEXT - jmp(.ZDONE) - - - //CASE 2: beta is real = -1 - label(.ROW_BETA_REAL_MINUS1) - ZGEMM_INPUT_RS_BETA_ONE - vsubpd(ymm0, ymm4, ymm0) - ZGEMM_OUTPUT_RS - - ZGEMM_INPUT_RS_BETA_ONE_NEXT - vsubpd(ymm0, ymm5, ymm0) - ZGEMM_OUTPUT_RS_NEXT - add(rdi, rcx) // rcx = c + 1*rs_c - - ZGEMM_INPUT_RS_BETA_ONE - vsubpd(ymm0, ymm8, ymm0) - ZGEMM_OUTPUT_RS - - ZGEMM_INPUT_RS_BETA_ONE_NEXT - vsubpd(ymm0, ymm9, ymm0) - ZGEMM_OUTPUT_RS_NEXT - add(rdi, rcx) // rcx = c + 2*rs_c - - ZGEMM_INPUT_RS_BETA_ONE - vsubpd(ymm0, ymm12, ymm0) - ZGEMM_OUTPUT_RS - - ZGEMM_INPUT_RS_BETA_ONE_NEXT - vsubpd(ymm0, ymm13, ymm0) - ZGEMM_OUTPUT_RS_NEXT - jmp(.ZDONE) - - - //CASE 3: Default case with multiplication - // beta not equal to (+/-1) or zero, do normal multiplication. - label(.ROW_BETA_NOT_REAL_ONE) - mov(var(beta), rbx) // load address of beta - vbroadcastsd(mem(rbx), ymm1) // load beta_r and duplicate - vbroadcastsd(mem(rbx, 8), ymm2) // load beta_i and duplicate - - ZGEMM_INPUT_SCALE_RS_BETA_NZ - vaddpd(ymm4, ymm0, ymm0) - ZGEMM_OUTPUT_RS - - ZGEMM_INPUT_SCALE_RS_BETA_NZ_NEXT - vaddpd(ymm5, ymm0, ymm0) - ZGEMM_OUTPUT_RS_NEXT - add(rdi, rcx) // rcx = c + 1*rs_c - - ZGEMM_INPUT_SCALE_RS_BETA_NZ - vaddpd(ymm8, ymm0, ymm0) - ZGEMM_OUTPUT_RS - - ZGEMM_INPUT_SCALE_RS_BETA_NZ_NEXT - vaddpd(ymm9, ymm0, ymm0) - ZGEMM_OUTPUT_RS_NEXT - add(rdi, rcx) // rcx = c + 2*rs_c - - ZGEMM_INPUT_SCALE_RS_BETA_NZ - vaddpd(ymm12, ymm0, ymm0) - ZGEMM_OUTPUT_RS - - ZGEMM_INPUT_SCALE_RS_BETA_NZ_NEXT - vaddpd(ymm13, ymm0, ymm0) - ZGEMM_OUTPUT_RS_NEXT - jmp(.ZDONE) // jump to end. - - label(.ZCOLSTORED) - mov(var(beta), rbx) // load address of beta - vbroadcastsd(mem(rbx), ymm1) // load beta_r and duplicate - vbroadcastsd(mem(rbx, 8), ymm2) // load beta_i and duplicate - /*|--------| |-------| - | | | | - | 3x4 | | 4x3 | - |--------| |-------| - */ - - ZGEMM_INPUT_SCALE_CS_BETA_NZ - vaddpd(ymm4, ymm0, ymm4) - - add(rdi, rcx) - ZGEMM_INPUT_SCALE_CS_BETA_NZ - vaddpd(ymm8, ymm0, ymm8) - add(rdi, rcx) - - ZGEMM_INPUT_SCALE_CS_BETA_NZ - vaddpd(ymm12, ymm0, ymm12) - - lea(mem(r12, rsi, 2), rcx) - - ZGEMM_INPUT_SCALE_CS_BETA_NZ - vaddpd(ymm5, ymm0, ymm5) - add(rdi, rcx) - - ZGEMM_INPUT_SCALE_CS_BETA_NZ - vaddpd(ymm9, ymm0, ymm9) - add(rdi, rcx) - - ZGEMM_INPUT_SCALE_CS_BETA_NZ - vaddpd(ymm13, ymm0, ymm13) - - mov(r12, rcx) // reset rcx to current utile of c. 
- - - /****3x4 tile going to save into 4x3 tile in C*****/ - - /******************Transpose top tile 4x3***************************/ - vmovups(xmm4, mem(rcx)) - vmovups(xmm8, mem(rcx, 16)) - vmovups(xmm12, mem(rcx,32)) - - add(rsi, rcx) - - vextractf128(imm(0x1), ymm4, xmm4) - vextractf128(imm(0x1), ymm8, xmm8) - vextractf128(imm(0x1), ymm12, xmm12) - vmovups(xmm4, mem(rcx)) - vmovups(xmm8, mem(rcx, 16)) - vmovups(xmm12, mem(rcx,32)) - - add(rsi, rcx) - - vmovups(xmm5, mem(rcx)) - vmovups(xmm9, mem(rcx, 16)) - vmovups(xmm13,mem(rcx,32)) - - add(rsi, rcx) - - vextractf128(imm(0x1), ymm5, xmm5) - vextractf128(imm(0x1), ymm9, xmm9) - vextractf128(imm(0x1), ymm13, xmm13) - vmovups(xmm5, mem(rcx)) - vmovups(xmm9, mem(rcx, 16)) - vmovups(xmm13,mem(rcx,32)) - - jmp(.ZDONE) // jump to end. - - label(.ZBETAZERO) - cmp(imm(16), rdi) // set ZF if (16*rs_c) == 16. - jz(.ZCOLSTORBZ) // jump to column storage case - - label(.ZROWSTORBZ) - /* Store 3x4 elements to C matrix where is C row major order*/ + dec(rsi) // i -= 1; + jne(.ZLOOPKLEFT) // iterate again if i != 0. + + label(.ZPOSTACCUM) + + mov(r12, rcx) // reset rcx to current utile of c. + + // permute even and odd elements + // of ymm6/7, ymm10/11, ymm/14/15 + vpermilpd(imm(0x5), ymm6, ymm6) + vpermilpd(imm(0x5), ymm7, ymm7) + vpermilpd(imm(0x5), ymm10, ymm10) + vpermilpd(imm(0x5), ymm11, ymm11) + vpermilpd(imm(0x5), ymm14, ymm14) + vpermilpd(imm(0x5), ymm15, ymm15) + + // subtract/add even/odd elements + vaddsubpd(ymm6, ymm4, ymm4) + vaddsubpd(ymm7, ymm5, ymm5) + + vaddsubpd(ymm10, ymm8, ymm8) + vaddsubpd(ymm11, ymm9, ymm9) + + vaddsubpd(ymm14, ymm12, ymm12) + vaddsubpd(ymm15, ymm13, ymm13) + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(real dt) + lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof((real +imag)dt) + + //if(alpha_mul_type == BLIS_MUL_MINUS_ONE) + mov(var(alpha_mul_type), al) + cmp(imm(0xFF), al) + jne(.ALPHA_NOT_MINUS1) + + // when alpha = -1 and real. + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vsubpd(ymm4, ymm0, ymm4) + vsubpd(ymm5, ymm0, ymm5) + vsubpd(ymm8, ymm0, ymm8) + vsubpd(ymm9, ymm0, ymm9) + vsubpd(ymm12, ymm0, ymm12) + vsubpd(ymm13, ymm0, ymm13) + jmp(.ALPHA_REAL_ONE) + + label(.ALPHA_NOT_MINUS1) + //when alpha is real and +/-1, multiplication is skipped. + cmp(imm(2), al)//if(alpha_mul_type != BLIS_MUL_DEFAULT) skip below multiplication. 
+ jne(.ALPHA_REAL_ONE) + + /* (ar + ai) x AB */ + mov(var(alpha), rax) // load address of alpha + vbroadcastsd(mem(rax), ymm0) // load alpha_r and duplicate + vbroadcastsd(mem(rax, 8), ymm1) // load alpha_i and duplicate + + vpermilpd(imm(0x5), ymm4, ymm3) + vmulpd(ymm0, ymm4, ymm4) + vmulpd(ymm1, ymm3, ymm3) + vaddsubpd(ymm3, ymm4, ymm4) + + vpermilpd(imm(0x5), ymm5, ymm3) + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm1, ymm3, ymm3) + vaddsubpd(ymm3, ymm5, ymm5) + + vpermilpd(imm(0x5), ymm8, ymm3) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm1, ymm3, ymm3) + vaddsubpd(ymm3, ymm8, ymm8) + + vpermilpd(imm(0x5), ymm9, ymm3) + vmulpd(ymm0, ymm9, ymm9) + vmulpd(ymm1, ymm3, ymm3) + vaddsubpd(ymm3, ymm9, ymm9) + + vpermilpd(imm(0x5), ymm12, ymm3) + vmulpd(ymm0, ymm12, ymm12) + vmulpd(ymm1, ymm3, ymm3) + vaddsubpd(ymm3, ymm12, ymm12) + + vpermilpd(imm(0x5), ymm13, ymm3) + vmulpd(ymm0, ymm13, ymm13) + vmulpd(ymm1, ymm3, ymm3) + vaddsubpd(ymm3, ymm13, ymm13) + + label(.ALPHA_REAL_ONE) + // Beta multiplication + /* (br + bi)x C + ((ar + ai) x AB) */ + + mov(var(beta_mul_type), al) + cmp(imm(0), al) //if(beta_mul_type == BLIS_MUL_ZERO) + je(.ZBETAZERO) //jump to beta == 0 case + + cmp(imm(16), rdi) // set ZF if (16*rs_c) ==16. + jz(.ZCOLSTORED) // jump to column storage case + + label(.ZROWSTORED) + + lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof((real +imag)dt) * numofElements + + cmp(imm(2), al) // if(beta_mul_type == BLIS_MUL_DEFAULT) + je(.ROW_BETA_NOT_REAL_ONE) // jump to beta handling with multiplication. + + cmp(imm(0xFF), al) // if(beta_mul_type == BLIS_MUL_MINUS_ONE) + je(.ROW_BETA_REAL_MINUS1) // jump to beta real = -1 section. + + //CASE 1: beta is real = 1 + ZGEMM_INPUT_RS_BETA_ONE + vaddpd(ymm4, ymm0, ymm0) + ZGEMM_OUTPUT_RS + + ZGEMM_INPUT_RS_BETA_ONE_NEXT + vaddpd(ymm5, ymm0, ymm0) + ZGEMM_OUTPUT_RS_NEXT + add(rdi, rcx) // rcx = c + 1*rs_c + + ZGEMM_INPUT_RS_BETA_ONE + vaddpd(ymm8, ymm0, ymm0) + ZGEMM_OUTPUT_RS + + ZGEMM_INPUT_RS_BETA_ONE_NEXT + vaddpd(ymm9, ymm0, ymm0) + ZGEMM_OUTPUT_RS_NEXT + add(rdi, rcx) // rcx = c + 2*rs_c + + ZGEMM_INPUT_RS_BETA_ONE + vaddpd(ymm12, ymm0, ymm0) + ZGEMM_OUTPUT_RS + + ZGEMM_INPUT_RS_BETA_ONE_NEXT + vaddpd(ymm13, ymm0, ymm0) + ZGEMM_OUTPUT_RS_NEXT + jmp(.ZDONE) + + + //CASE 2: beta is real = -1 + label(.ROW_BETA_REAL_MINUS1) + ZGEMM_INPUT_RS_BETA_ONE + vsubpd(ymm0, ymm4, ymm0) + ZGEMM_OUTPUT_RS + + ZGEMM_INPUT_RS_BETA_ONE_NEXT + vsubpd(ymm0, ymm5, ymm0) + ZGEMM_OUTPUT_RS_NEXT + add(rdi, rcx) // rcx = c + 1*rs_c + + ZGEMM_INPUT_RS_BETA_ONE + vsubpd(ymm0, ymm8, ymm0) + ZGEMM_OUTPUT_RS + + ZGEMM_INPUT_RS_BETA_ONE_NEXT + vsubpd(ymm0, ymm9, ymm0) + ZGEMM_OUTPUT_RS_NEXT + add(rdi, rcx) // rcx = c + 2*rs_c + + ZGEMM_INPUT_RS_BETA_ONE + vsubpd(ymm0, ymm12, ymm0) + ZGEMM_OUTPUT_RS + + ZGEMM_INPUT_RS_BETA_ONE_NEXT + vsubpd(ymm0, ymm13, ymm0) + ZGEMM_OUTPUT_RS_NEXT + jmp(.ZDONE) + + + //CASE 3: Default case with multiplication + // beta not equal to (+/-1) or zero, do normal multiplication. 
+ label(.ROW_BETA_NOT_REAL_ONE) + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rbx), ymm1) // load beta_r and duplicate + vbroadcastsd(mem(rbx, 8), ymm2) // load beta_i and duplicate + + ZGEMM_INPUT_SCALE_RS_BETA_NZ + vaddpd(ymm4, ymm0, ymm0) + ZGEMM_OUTPUT_RS + + ZGEMM_INPUT_SCALE_RS_BETA_NZ_NEXT + vaddpd(ymm5, ymm0, ymm0) + ZGEMM_OUTPUT_RS_NEXT + add(rdi, rcx) // rcx = c + 1*rs_c + + ZGEMM_INPUT_SCALE_RS_BETA_NZ + vaddpd(ymm8, ymm0, ymm0) + ZGEMM_OUTPUT_RS + + ZGEMM_INPUT_SCALE_RS_BETA_NZ_NEXT + vaddpd(ymm9, ymm0, ymm0) + ZGEMM_OUTPUT_RS_NEXT + add(rdi, rcx) // rcx = c + 2*rs_c + + ZGEMM_INPUT_SCALE_RS_BETA_NZ + vaddpd(ymm12, ymm0, ymm0) + ZGEMM_OUTPUT_RS + + ZGEMM_INPUT_SCALE_RS_BETA_NZ_NEXT + vaddpd(ymm13, ymm0, ymm0) + ZGEMM_OUTPUT_RS_NEXT + jmp(.ZDONE) // jump to end. + + label(.ZCOLSTORED) + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rbx), ymm1) // load beta_r and duplicate + vbroadcastsd(mem(rbx, 8), ymm2) // load beta_i and duplicate + /*|--------| |-------| + | | | | + | 3x4 | | 4x3 | + |--------| |-------| + */ + + ZGEMM_INPUT_SCALE_CS_BETA_NZ + vaddpd(ymm4, ymm0, ymm4) + + add(rdi, rcx) + ZGEMM_INPUT_SCALE_CS_BETA_NZ + vaddpd(ymm8, ymm0, ymm8) + add(rdi, rcx) + + ZGEMM_INPUT_SCALE_CS_BETA_NZ + vaddpd(ymm12, ymm0, ymm12) + + lea(mem(r12, rsi, 2), rcx) + + ZGEMM_INPUT_SCALE_CS_BETA_NZ + vaddpd(ymm5, ymm0, ymm5) + add(rdi, rcx) + + ZGEMM_INPUT_SCALE_CS_BETA_NZ + vaddpd(ymm9, ymm0, ymm9) + add(rdi, rcx) + + ZGEMM_INPUT_SCALE_CS_BETA_NZ + vaddpd(ymm13, ymm0, ymm13) + + mov(r12, rcx) // reset rcx to current utile of c. + + + /****3x4 tile going to save into 4x3 tile in C*****/ + + /******************Transpose top tile 4x3***************************/ + vmovups(xmm4, mem(rcx)) + vmovups(xmm8, mem(rcx, 16)) + vmovups(xmm12, mem(rcx,32)) + + add(rsi, rcx) + + vextractf128(imm(0x1), ymm4, xmm4) + vextractf128(imm(0x1), ymm8, xmm8) + vextractf128(imm(0x1), ymm12, xmm12) + vmovups(xmm4, mem(rcx)) + vmovups(xmm8, mem(rcx, 16)) + vmovups(xmm12, mem(rcx,32)) + + add(rsi, rcx) + + vmovups(xmm5, mem(rcx)) + vmovups(xmm9, mem(rcx, 16)) + vmovups(xmm13,mem(rcx,32)) + + add(rsi, rcx) + + vextractf128(imm(0x1), ymm5, xmm5) + vextractf128(imm(0x1), ymm9, xmm9) + vextractf128(imm(0x1), ymm13, xmm13) + vmovups(xmm5, mem(rcx)) + vmovups(xmm9, mem(rcx, 16)) + vmovups(xmm13,mem(rcx,32)) + + jmp(.ZDONE) // jump to end. + + label(.ZBETAZERO) + cmp(imm(16), rdi) // set ZF if (16*rs_c) == 16. + jz(.ZCOLSTORBZ) // jump to column storage case + + label(.ZROWSTORBZ) + /* Store 3x4 elements to C matrix where is C row major order*/ // rsi = cs_c * sizeof((real +imag)dt) *numofElements - lea(mem(, rsi, 2), rsi) + lea(mem(, rsi, 2), rsi) - vmovupd(ymm4, mem(rcx)) - vmovupd(ymm5, mem(rcx, rsi, 1)) - add(rdi, rcx) + vmovupd(ymm4, mem(rcx)) + vmovupd(ymm5, mem(rcx, rsi, 1)) + add(rdi, rcx) - vmovupd(ymm8, mem(rcx)) - vmovupd(ymm9, mem(rcx, rsi, 1)) - add(rdi, rcx) + vmovupd(ymm8, mem(rcx)) + vmovupd(ymm9, mem(rcx, rsi, 1)) + add(rdi, rcx) - vmovupd(ymm12, mem(rcx)) - vmovupd(ymm13, mem(rcx, rsi, 1)) + vmovupd(ymm12, mem(rcx)) + vmovupd(ymm13, mem(rcx, rsi, 1)) - jmp(.ZDONE) // jump to end. + jmp(.ZDONE) // jump to end. 
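In the column-stored paths each ymm accumulator holds two consecutive dcomplex elements of one C row, so the transpose implied by the "3x4 tile into 4x3 tile" comments happens at store time: writing the low xmm halves of ymm4/ymm8/ymm12 back-to-back emits one three-element column of the destination tile, and vextractf128 supplies the next column. Below is a reduced intrinsics sketch for a 3x2 dcomplex block stored into column-major C (ldc in elements); store_3x2_colmajor is an illustrative name and covers only half of the kernel's 3x4 case.

    #include <immintrin.h>

    typedef struct { double real, imag; } dcomplex_t;

    /* Store a 3x2 block of dcomplex accumulators (one ymm per row, two
       elements per ymm) into column-major C, mirroring the vmovups /
       vextractf128 sequence of the .ZCOLSTORED/.ZCOLSTORBZ paths.
       row0/row1/row2 play the role of ymm4/ymm8/ymm12. */
    static void store_3x2_colmajor( dcomplex_t *c, long ldc,
                                    __m256d row0, __m256d row1, __m256d row2 )
    {
        /* Column 0: low 128-bit halves of the three row accumulators. */
        _mm_storeu_pd( (double *)( c + 0 ), _mm256_castpd256_pd128( row0 ) );
        _mm_storeu_pd( (double *)( c + 1 ), _mm256_castpd256_pd128( row1 ) );
        _mm_storeu_pd( (double *)( c + 2 ), _mm256_castpd256_pd128( row2 ) );

        /* Column 1: high 128-bit halves, one column stride further on. */
        c += ldc;
        _mm_storeu_pd( (double *)( c + 0 ), _mm256_extractf128_pd( row0, 1 ) );
        _mm_storeu_pd( (double *)( c + 1 ), _mm256_extractf128_pd( row1, 1 ) );
        _mm_storeu_pd( (double *)( c + 2 ), _mm256_extractf128_pd( row2, 1 ) );
    }

The same pattern appears later in the intrinsics-based z3x4n kernel, where the _mm256_castpd256_pd128/_mm256_extractf128_pd pair is written out explicitly.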
- label(.ZCOLSTORBZ) + label(.ZCOLSTORBZ) - /****3x4 tile going to save into 4x3 tile in C*****/ + /****3x4 tile going to save into 4x3 tile in C*****/ - /******************Transpose top tile 4x3***************************/ - vmovups(xmm4, mem(rcx)) - vmovups(xmm8, mem(rcx, 16)) - vmovups(xmm12, mem(rcx,32)) + /******************Transpose top tile 4x3***************************/ + vmovups(xmm4, mem(rcx)) + vmovups(xmm8, mem(rcx, 16)) + vmovups(xmm12, mem(rcx,32)) - add(rsi, rcx) + add(rsi, rcx) - vextractf128(imm(0x1), ymm4, xmm4) - vextractf128(imm(0x1), ymm8, xmm8) - vextractf128(imm(0x1), ymm12, xmm12) - vmovups(xmm4, mem(rcx)) - vmovups(xmm8, mem(rcx, 16)) - vmovups(xmm12, mem(rcx,32)) + vextractf128(imm(0x1), ymm4, xmm4) + vextractf128(imm(0x1), ymm8, xmm8) + vextractf128(imm(0x1), ymm12, xmm12) + vmovups(xmm4, mem(rcx)) + vmovups(xmm8, mem(rcx, 16)) + vmovups(xmm12, mem(rcx,32)) - add(rsi, rcx) - - vmovups(xmm5, mem(rcx)) - vmovups(xmm9, mem(rcx, 16)) - vmovups(xmm13,mem(rcx,32)) - - add(rsi, rcx) - - vextractf128(imm(0x1), ymm5, xmm5) - vextractf128(imm(0x1), ymm9, xmm9) - vextractf128(imm(0x1), ymm13, xmm13) - vmovups(xmm5, mem(rcx)) - vmovups(xmm9, mem(rcx, 16)) - vmovups(xmm13,mem(rcx,32)) + add(rsi, rcx) - label(.ZDONE) + vmovups(xmm5, mem(rcx)) + vmovups(xmm9, mem(rcx, 16)) + vmovups(xmm13,mem(rcx,32)) - lea(mem(r12, rdi, 2), r12) - lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + add(rsi, rcx) - lea(mem(r14, r8, 2), r14) - lea(mem(r14, r8, 1), r14) //a_ii = r14 += 3*rs_a + vextractf128(imm(0x1), ymm5, xmm5) + vextractf128(imm(0x1), ymm9, xmm9) + vextractf128(imm(0x1), ymm13, xmm13) + vmovups(xmm5, mem(rcx)) + vmovups(xmm9, mem(rcx, 16)) + vmovups(xmm13,mem(rcx,32)) - dec(r11) // ii -= 1; - jne(.ZLOOP3X4I) // iterate again if ii != 0. + label(.ZDONE) - label(.ZRETURN) + lea(mem(r12, rdi, 2), r12) + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c - end_asm( - : // output operands (none) - : // input operands + lea(mem(r14, r8, 2), r14) + lea(mem(r14, r8, 1), r14) //a_ii = r14 += 3*rs_a + + dec(r11) // ii -= 1; + jne(.ZLOOP3X4I) // iterate again if ii != 0. + + label(.ZRETURN) + + end_asm( + : // output operands (none) + : // input operands [alpha_mul_type] "m" (alpha_mul_type), [beta_mul_type] "m" (beta_mul_type), [m_iter] "m" (m_iter), @@ -807,46 +791,46 @@ void bli_zgemmsup_rv_zen_asm_3x4m [cs_c] "m" (cs_c)/*, [a_next] "m" (a_next), [b_next] "m" (b_next)*/ - : // register clobber list - "rax", "rbx", "rcx", "rsi", "rdi", - "r8", "r9", "r10", "r11", "r12", "r14", "r15", - "xmm0", "xmm1", "xmm2", "xmm3", - "xmm4", "xmm5", "xmm6", "xmm7", - "xmm8", "xmm9", "xmm10", "xmm11", - "xmm12", "xmm13", "xmm14", "xmm15", - "memory" - ) - - consider_edge_cases: - - // Handle edge cases in the m dimension, if they exist. - if ( m_left ) - { - const dim_t nr_cur = 4; - const dim_t i_edge = m0 - ( dim_t )m_left; + : // register clobber list + "rax", "rbx", "rcx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. 
+ if ( m_left ) + { + const dim_t nr_cur = 4; + const dim_t i_edge = m0 - ( dim_t )m_left; dcomplex* cij = c + i_edge*rs_c; dcomplex* ai = a + i_edge*rs_a; dcomplex* bj = b; - zgemmsup_ker_ft ker_fps[3] = - { - NULL, - bli_zgemmsup_rv_zen_asm_1x4, - bli_zgemmsup_rv_zen_asm_2x4, - }; + zgemmsup_ker_ft ker_fps[3] = + { + NULL, + bli_zgemmsup_rv_zen_asm_1x4, + bli_zgemmsup_rv_zen_asm_2x4, + }; - zgemmsup_ker_ft ker_fp = ker_fps[ m_left ]; + zgemmsup_ker_ft ker_fp = ker_fps[ m_left ]; - ker_fp - ( - conja, conjb, m_left, nr_cur, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, - beta, cij, rs_c0, cs_c0, data, cntx - ); - return; + ker_fp + ( + conja, conjb, m_left, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + return; - } + } } @@ -867,393 +851,393 @@ void bli_zgemmsup_rv_zen_asm_3x2m ) { - //void* a_next = bli_auxinfo_next_a( data ); - //void* b_next = bli_auxinfo_next_b( data ); + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t m_iter = m0 / 3; + uint64_t m_left = m0 % 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; - // Typecast local copies of integers in case dim_t and inc_t are a - // different size than is expected by load instructions. + if ( m_iter == 0 ) goto consider_edge_cases; - uint64_t k_iter = k0 / 4; - uint64_t k_left = k0 % 4; + // ------------------------------------------------------------------------- - uint64_t m_iter = m0 / 3; - uint64_t m_left = m0 % 3; + begin_asm() - uint64_t rs_a = rs_a0; - uint64_t cs_a = cs_a0; - uint64_t rs_b = rs_b0; - uint64_t rs_c = rs_c0; - uint64_t cs_c = cs_c0; + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(dt) + lea(mem(, r8, 2), r8) // rs_a *= sizeof(dt) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(dt) + lea(mem(, r9, 2), r9) // cs_a *= sizeof(dt) - if ( m_iter == 0 ) goto consider_edge_cases; + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(dt) + lea(mem(, r10, 2), r10) // rs_b *= sizeof(dt) - // ------------------------------------------------------------------------- + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). - begin_asm() + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(dt) + lea(mem(, rdi, 2), rdi) // rs_c *= sizeof(dt) - mov(var(a), r14) // load address of a. 
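Both kernels handle the m0 % 3 leftover rows by indexing a small table of narrower kernels with m_left rather than branching case by case. A minimal sketch of that dispatch pattern follows; edge_ker_ft, ker_1xN/ker_2xN and handle_m_edge are simplified stand-ins for the real zgemmsup_ker_ft signature and kernel names, not BLIS symbols.

    #include <stdio.h>

    /* Simplified stand-in for the real zgemmsup_ker_ft signature. */
    typedef void (*edge_ker_ft)( long m, long n, long k );

    static void ker_1xN( long m, long n, long k )
    { printf( "1-row edge kernel: m=%ld n=%ld k=%ld\n", m, n, k ); }

    static void ker_2xN( long m, long n, long k )
    { printf( "2-row edge kernel: m=%ld n=%ld k=%ld\n", m, n, k ); }

    static void handle_m_edge( long m_left, long n, long k )
    {
        /* Index 0 is unused: m_left == 0 means there is no edge case. */
        edge_ker_ft ker_fps[3] = { NULL, ker_1xN, ker_2xN };

        if( m_left )
            ker_fps[ m_left ]( m_left, n, k );
    }

    int main( void )
    {
        handle_m_edge( 7 % 3, 4, 16 );   /* m0 = 7 -> m_left = 1 */
        return 0;
    }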
- mov(var(rs_a), r8) // load rs_a - mov(var(cs_a), r9) // load cs_a - lea(mem(, r8, 8), r8) // rs_a *= sizeof(dt) - lea(mem(, r8, 2), r8) // rs_a *= sizeof(dt) - lea(mem(, r9, 8), r9) // cs_a *= sizeof(dt) - lea(mem(, r9, 2), r9) // cs_a *= sizeof(dt) + // During preamble and loops: + // r12 = rcx = c + // r14 = rax = a + // read rbx from var(b) near beginning of loop + // r11 = m dim index ii - mov(var(rs_b), r10) // load rs_b - lea(mem(, r10, 8), r10) // rs_b *= sizeof(dt) - lea(mem(, r10, 2), r10) // rs_b *= sizeof(dt) + mov(var(m_iter), r11) // ii = m_iter; - // NOTE: We cannot pre-load elements of a or b - // because it could eventually, in the last - // unrolled iter or the cleanup loop, result - // in reading beyond the bounds allocated mem - // (the likely result: a segmentation fault). + label(.ZLOOP3X2I) // LOOP OVER ii = [ m_iter ... 1 0 ] - mov(var(c), r12) // load address of c - mov(var(rs_c), rdi) // load rs_c - lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(dt) - lea(mem(, rdi, 2), rdi) // rs_c *= sizeof(dt) + vzeroall() // zero all xmm/ymm registers. - // During preamble and loops: - // r12 = rcx = c - // r14 = rax = a - // read rbx from var(b) near beginning of loop - // r11 = m dim index ii + mov(var(b), rbx) // load address of b. + mov(r14, rax) // reset rax to current upanel of a. - mov(var(m_iter), r11) // ii = m_iter; + cmp(imm(16), rdi) // set ZF if (16*rs_c) == 16. + jz(.ZCOLPFETCH) // jump to column storage case + label(.ZROWPFETCH) // row-stored pre-fetching on c // not used - label(.ZLOOP3X2I) // LOOP OVER ii = [ m_iter ... 1 0 ] + jmp(.ZPOSTPFETCH) // jump to end of pre-fetching c + label(.ZCOLPFETCH) // column-stored pre-fetching c - vzeroall() // zero all xmm/ymm registers. + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(dt) - mov(var(b), rbx) // load address of b. - mov(r14, rax) // reset rax to current upanel of a. + label(.ZPOSTPFETCH) // done prefetching c - cmp(imm(16), rdi) // set ZF if (16*rs_c) == 16. - jz(.ZCOLPFETCH) // jump to column storage case - label(.ZROWPFETCH) // row-stored pre-fetching on c // not used + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.ZCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. - jmp(.ZPOSTPFETCH) // jump to end of pre-fetching c - label(.ZCOLPFETCH) // column-stored pre-fetching c + label(.ZLOOPKITER) // MAIN LOOP - mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) - lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(dt) + // ---------------------------------- iteration 0 - label(.ZPOSTPFETCH) // done prefetching c + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; - mov(var(k_iter), rsi) // i = k_iter; - test(rsi, rsi) // check i via logical AND. - je(.ZCONSIDKLEFT) // if i == 0, jump to code that - // contains the k_left loop. 
+ vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm4) - label(.ZLOOPKITER) // MAIN LOOP + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm8) - // ---------------------------------- iteration 0 + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm12) - vmovupd(mem(rbx, 0*32), ymm0) - add(r10, rbx) // b += rs_b; + vbroadcastsd(mem(rax, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) - vbroadcastsd(mem(rax ), ymm2) - vfmadd231pd(ymm0, ymm2, ymm4) + vbroadcastsd(mem(rax, r8, 1, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) - vbroadcastsd(mem(rax, r8, 1), ymm2) - vfmadd231pd(ymm0, ymm2, ymm8) + vbroadcastsd(mem(rax, r8, 2, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm14) - vbroadcastsd(mem(rax, r8, 2), ymm2) - vfmadd231pd(ymm0, ymm2, ymm12) + add(r9, rax) // a += cs_a; - vbroadcastsd(mem(rax, 8), ymm3) - vfmadd231pd(ymm0, ymm3, ymm6) + // ---------------------------------- iteration 1 - vbroadcastsd(mem(rax, r8, 1, 8), ymm3) - vfmadd231pd(ymm0, ymm3, ymm10) + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; - vbroadcastsd(mem(rax, r8, 2, 8), ymm3) - vfmadd231pd(ymm0, ymm3, ymm14) + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm4) - add(r9, rax) // a += cs_a; + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm8) - // ---------------------------------- iteration 1 + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm12) - vmovupd(mem(rbx, 0*32), ymm0) - add(r10, rbx) // b += rs_b; + vbroadcastsd(mem(rax, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) - vbroadcastsd(mem(rax ), ymm2) - vfmadd231pd(ymm0, ymm2, ymm4) + vbroadcastsd(mem(rax, r8, 1, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) - vbroadcastsd(mem(rax, r8, 1), ymm2) - vfmadd231pd(ymm0, ymm2, ymm8) + vbroadcastsd(mem(rax, r8, 2, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm14) - vbroadcastsd(mem(rax, r8, 2), ymm2) - vfmadd231pd(ymm0, ymm2, ymm12) + add(r9, rax) // a += cs_a; - vbroadcastsd(mem(rax, 8), ymm3) - vfmadd231pd(ymm0, ymm3, ymm6) + // ---------------------------------- iteration 2 - vbroadcastsd(mem(rax, r8, 1, 8), ymm3) - vfmadd231pd(ymm0, ymm3, ymm10) + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; - vbroadcastsd(mem(rax, r8, 2, 8), ymm3) - vfmadd231pd(ymm0, ymm3, ymm14) + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm4) - add(r9, rax) // a += cs_a; + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm8) - // ---------------------------------- iteration 2 + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm12) - vmovupd(mem(rbx, 0*32), ymm0) - add(r10, rbx) // b += rs_b; + vbroadcastsd(mem(rax, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) - vbroadcastsd(mem(rax ), ymm2) - vfmadd231pd(ymm0, ymm2, ymm4) + vbroadcastsd(mem(rax, r8, 1, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) - vbroadcastsd(mem(rax, r8, 1), ymm2) - vfmadd231pd(ymm0, ymm2, ymm8) + vbroadcastsd(mem(rax, r8, 2, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm14) - vbroadcastsd(mem(rax, r8, 2), ymm2) - vfmadd231pd(ymm0, ymm2, ymm12) + add(r9, rax) // a += cs_a; - vbroadcastsd(mem(rax, 8), ymm3) - vfmadd231pd(ymm0, ymm3, ymm6) + // ---------------------------------- iteration 3 + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; - vbroadcastsd(mem(rax, r8, 1, 8), ymm3) - vfmadd231pd(ymm0, ymm3, ymm10) + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm4) - vbroadcastsd(mem(rax, r8, 2, 8), ymm3) - vfmadd231pd(ymm0, ymm3, ymm14) + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm8) - add(r9, rax) // a += cs_a; + 
vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm12) - // ---------------------------------- iteration 3 - vmovupd(mem(rbx, 0*32), ymm0) - add(r10, rbx) // b += rs_b; + vbroadcastsd(mem(rax, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) - vbroadcastsd(mem(rax ), ymm2) - vfmadd231pd(ymm0, ymm2, ymm4) + vbroadcastsd(mem(rax, r8, 1, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) - vbroadcastsd(mem(rax, r8, 1), ymm2) - vfmadd231pd(ymm0, ymm2, ymm8) + vbroadcastsd(mem(rax, r8, 2, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm14) - vbroadcastsd(mem(rax, r8, 2), ymm2) - vfmadd231pd(ymm0, ymm2, ymm12) + add(r9, rax) // a += cs_a; - vbroadcastsd(mem(rax, 8), ymm3) - vfmadd231pd(ymm0, ymm3, ymm6) + dec(rsi) // i -= 1; + jne(.ZLOOPKITER) // iterate again if i != 0. - vbroadcastsd(mem(rax, r8, 1, 8), ymm3) - vfmadd231pd(ymm0, ymm3, ymm10) + label(.ZCONSIDKLEFT) - vbroadcastsd(mem(rax, r8, 2, 8), ymm3) - vfmadd231pd(ymm0, ymm3, ymm14) + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.ZPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. - add(r9, rax) // a += cs_a; + label(.ZLOOPKLEFT) // EDGE LOOP - dec(rsi) // i -= 1; - jne(.ZLOOPKITER) // iterate again if i != 0. + vmovupd(mem(rbx, 0*32), ymm0) + add(r10, rbx) // b += rs_b; - label(.ZCONSIDKLEFT) + vbroadcastsd(mem(rax ), ymm2) + vfmadd231pd(ymm0, ymm2, ymm4) - mov(var(k_left), rsi) // i = k_left; - test(rsi, rsi) // check i via logical AND. - je(.ZPOSTACCUM) // if i == 0, we're done; jump to end. - // else, we prepare to enter k_left loop. + vbroadcastsd(mem(rax, r8, 1), ymm2) + vfmadd231pd(ymm0, ymm2, ymm8) - label(.ZLOOPKLEFT) // EDGE LOOP + vbroadcastsd(mem(rax, r8, 2), ymm2) + vfmadd231pd(ymm0, ymm2, ymm12) - vmovupd(mem(rbx, 0*32), ymm0) - add(r10, rbx) // b += rs_b; + vbroadcastsd(mem(rax, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm6) - vbroadcastsd(mem(rax ), ymm2) - vfmadd231pd(ymm0, ymm2, ymm4) + vbroadcastsd(mem(rax, r8, 1, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) - vbroadcastsd(mem(rax, r8, 1), ymm2) - vfmadd231pd(ymm0, ymm2, ymm8) + vbroadcastsd(mem(rax, r8, 2, 8), ymm3) + vfmadd231pd(ymm0, ymm3, ymm14) - vbroadcastsd(mem(rax, r8, 2), ymm2) - vfmadd231pd(ymm0, ymm2, ymm12) + add(r9, rax) // a += cs_a; - vbroadcastsd(mem(rax, 8), ymm3) - vfmadd231pd(ymm0, ymm3, ymm6) + dec(rsi) // i -= 1; + jne(.ZLOOPKLEFT) // iterate again if i != 0. - vbroadcastsd(mem(rax, r8, 1, 8), ymm3) - vfmadd231pd(ymm0, ymm3, ymm10) + label(.ZPOSTACCUM) - vbroadcastsd(mem(rax, r8, 2, 8), ymm3) - vfmadd231pd(ymm0, ymm3, ymm14) + mov(r12, rcx) // reset rcx to current utile of c. - add(r9, rax) // a += cs_a; + // permute even and odd elements + // of ymm6/7, ymm10/11, ymm/14/15 + vpermilpd(imm(0x5), ymm6, ymm6) + vpermilpd(imm(0x5), ymm10, ymm10) + vpermilpd(imm(0x5), ymm14, ymm14) - dec(rsi) // i -= 1; - jne(.ZLOOPKLEFT) // iterate again if i != 0. + // subtract/add even/odd elements + vaddsubpd(ymm6, ymm4, ymm4) + vaddsubpd(ymm10, ymm8, ymm8) + vaddsubpd(ymm14, ymm12, ymm12) - label(.ZPOSTACCUM) + /* (ar + ai) x AB */ + mov(var(alpha), rax) // load address of alpha + vbroadcastsd(mem(rax), ymm0) // load alpha_r and duplicate + vbroadcastsd(mem(rax, 8), ymm1) // load alpha_i and duplicate - mov(r12, rcx) // reset rcx to current utile of c. 
+ vpermilpd(imm(0x5), ymm4, ymm3) + vmulpd(ymm0, ymm4, ymm4) + vmulpd(ymm1, ymm3, ymm3) + vaddsubpd(ymm3, ymm4, ymm4) - // permute even and odd elements - // of ymm6/7, ymm10/11, ymm/14/15 - vpermilpd(imm(0x5), ymm6, ymm6) - vpermilpd(imm(0x5), ymm10, ymm10) - vpermilpd(imm(0x5), ymm14, ymm14) + vpermilpd(imm(0x5), ymm8, ymm3) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm1, ymm3, ymm3) + vaddsubpd(ymm3, ymm8, ymm8) - // subtract/add even/odd elements - vaddsubpd(ymm6, ymm4, ymm4) - vaddsubpd(ymm10, ymm8, ymm8) - vaddsubpd(ymm14, ymm12, ymm12) + vpermilpd(imm(0x5), ymm12, ymm3) + vmulpd(ymm0, ymm12, ymm12) + vmulpd(ymm1, ymm3, ymm3) + vaddsubpd(ymm3, ymm12, ymm12) - /* (ar + ai) x AB */ - mov(var(alpha), rax) // load address of alpha - vbroadcastsd(mem(rax), ymm0) // load alpha_r and duplicate - vbroadcastsd(mem(rax, 8), ymm1) // load alpha_i and duplicate + /* (br + bi)x C + ((ar + ai) x AB) */ + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rbx), ymm1) // load beta_r and duplicate + vbroadcastsd(mem(rbx, 8), ymm2) // load beta_i and duplicate - vpermilpd(imm(0x5), ymm4, ymm3) - vmulpd(ymm0, ymm4, ymm4) - vmulpd(ymm1, ymm3, ymm3) - vaddsubpd(ymm3, ymm4, ymm4) + // now avoid loading C if beta == 0 + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm1) // set ZF if beta_r == 0. + sete(r13b) // r13b = ( ZF == 1 ? 1 : 0 ); + vucomisd(xmm0, xmm2) // set ZF if beta_i == 0. + sete(r15b) // r15b = ( ZF == 1 ? 1 : 0 ); + and(r13b, r15b) // set ZF if r13b & r15b == 1. + jne(.ZBETAZERO) // if ZF = 1, jump to beta == 0 case - vpermilpd(imm(0x5), ymm8, ymm3) - vmulpd(ymm0, ymm8, ymm8) - vmulpd(ymm1, ymm3, ymm3) - vaddsubpd(ymm3, ymm8, ymm8) + cmp(imm(16), rdi) // set ZF if (16*rs_c) == 16. + jz(.ZCOLSTORED) // jump to column storage case - vpermilpd(imm(0x5), ymm12, ymm3) - vmulpd(ymm0, ymm12, ymm12) - vmulpd(ymm1, ymm3, ymm3) - vaddsubpd(ymm3, ymm12, ymm12) + label(.ZROWSTORED) - /* (br + bi)x C + ((ar + ai) x AB) */ - mov(var(beta), rbx) // load address of beta - vbroadcastsd(mem(rbx), ymm1) // load beta_r and duplicate - vbroadcastsd(mem(rbx, 8), ymm2) // load beta_i and duplicate + ZGEMM_INPUT_SCALE_RS_BETA_NZ + vaddpd(ymm4, ymm0, ymm0) + ZGEMM_OUTPUT_RS - // now avoid loading C if beta == 0 - vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. - vucomisd(xmm0, xmm1) // set ZF if beta_r == 0. - sete(r13b) // r13b = ( ZF == 1 ? 1 : 0 ); - vucomisd(xmm0, xmm2) // set ZF if beta_i == 0. - sete(r15b) // r15b = ( ZF == 1 ? 1 : 0 ); - and(r13b, r15b) // set ZF if r13b & r15b == 1. - jne(.ZBETAZERO) // if ZF = 1, jump to beta == 0 case + add(rdi, rcx) // rcx = c + 1*rs_c - cmp(imm(16), rdi) // set ZF if (16*rs_c) == 16. - jz(.ZCOLSTORED) // jump to column storage case + ZGEMM_INPUT_SCALE_RS_BETA_NZ + vaddpd(ymm8, ymm0, ymm0) + ZGEMM_OUTPUT_RS - label(.ZROWSTORED) + add(rdi, rcx) // rcx = c + 2*rs_c - ZGEMM_INPUT_SCALE_RS_BETA_NZ - vaddpd(ymm4, ymm0, ymm0) - ZGEMM_OUTPUT_RS + ZGEMM_INPUT_SCALE_RS_BETA_NZ + vaddpd(ymm12, ymm0, ymm0) + ZGEMM_OUTPUT_RS - add(rdi, rcx) // rcx = c + 1*rs_c + jmp(.ZDONE) // jump to end. 
- ZGEMM_INPUT_SCALE_RS_BETA_NZ - vaddpd(ymm8, ymm0, ymm0) - ZGEMM_OUTPUT_RS + label(.ZCOLSTORED) + /*|--------| |-------| + | | | | + | 3x2 | | 2x3 | + |--------| |-------| + */ - add(rdi, rcx) // rcx = c + 2*rs_c + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(real dt) + lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof((real+imag) dt) - ZGEMM_INPUT_SCALE_RS_BETA_NZ - vaddpd(ymm12, ymm0, ymm0) - ZGEMM_OUTPUT_RS + ZGEMM_INPUT_SCALE_CS_BETA_NZ + vaddpd(ymm4, ymm0, ymm4) - jmp(.ZDONE) // jump to end. + add(rdi, rcx) + ZGEMM_INPUT_SCALE_CS_BETA_NZ + vaddpd(ymm8, ymm0, ymm8) + add(rdi, rcx) - label(.ZCOLSTORED) - /*|--------| |-------| - | | | | - | 3x2 | | 2x3 | - |--------| |-------| - */ + ZGEMM_INPUT_SCALE_CS_BETA_NZ + vaddpd(ymm12, ymm0, ymm12) - mov(var(cs_c), rsi) // load cs_c - lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(real dt) - lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof((real+imag) dt) - - ZGEMM_INPUT_SCALE_CS_BETA_NZ - vaddpd(ymm4, ymm0, ymm4) + mov(r12, rcx) // reset rcx to current utile of c. - add(rdi, rcx) - ZGEMM_INPUT_SCALE_CS_BETA_NZ - vaddpd(ymm8, ymm0, ymm8) - add(rdi, rcx) - - ZGEMM_INPUT_SCALE_CS_BETA_NZ - vaddpd(ymm12, ymm0, ymm12) + /****3x2 tile going to save into 2x3 tile in C*****/ - mov(r12, rcx) // reset rcx to current utile of c. + /******************Transpose top tile 2x3***************************/ + vmovups(xmm4, mem(rcx)) + vmovups(xmm8, mem(rcx, 16)) + vmovups(xmm12, mem(rcx,32)) - /****3x2 tile going to save into 2x3 tile in C*****/ + add(rsi, rcx) - /******************Transpose top tile 2x3***************************/ - vmovups(xmm4, mem(rcx)) - vmovups(xmm8, mem(rcx, 16)) - vmovups(xmm12, mem(rcx,32)) + vextractf128(imm(0x1), ymm4, xmm4) + vextractf128(imm(0x1), ymm8, xmm8) + vextractf128(imm(0x1), ymm12, xmm12) + vmovups(xmm4, mem(rcx)) + vmovups(xmm8, mem(rcx, 16)) + vmovups(xmm12, mem(rcx,32)) - add(rsi, rcx) - vextractf128(imm(0x1), ymm4, xmm4) - vextractf128(imm(0x1), ymm8, xmm8) - vextractf128(imm(0x1), ymm12, xmm12) - vmovups(xmm4, mem(rcx)) - vmovups(xmm8, mem(rcx, 16)) - vmovups(xmm12, mem(rcx,32)) + jmp(.ZDONE) // jump to end. + label(.ZBETAZERO) - jmp(.ZDONE) // jump to end. + cmp(imm(16), rdi) // set ZF if (8*rs_c) == 8. + jz(.ZCOLSTORBZ) // jump to column storage case - label(.ZBETAZERO) + label(.ZROWSTORBZ) - cmp(imm(16), rdi) // set ZF if (8*rs_c) == 8. - jz(.ZCOLSTORBZ) // jump to column storage case + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) - label(.ZROWSTORBZ) + vmovupd(ymm8, mem(rcx)) + add(rdi, rcx) - vmovupd(ymm4, mem(rcx)) - add(rdi, rcx) - - vmovupd(ymm8, mem(rcx)) - add(rdi, rcx) - - vmovupd(ymm12, mem(rcx)) + vmovupd(ymm12, mem(rcx)) - jmp(.ZDONE) // jump to end. + jmp(.ZDONE) // jump to end. 
- label(.ZCOLSTORBZ) + label(.ZCOLSTORBZ) - /****3x2 tile going to save into 2x3 tile in C*****/ - mov(var(cs_c), rsi) // load cs_c - lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(real dt) - lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof((real+imag) dt) + /****3x2 tile going to save into 2x3 tile in C*****/ + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(real dt) + lea(mem(, rsi, 2), rsi) // rsi = cs_c * sizeof((real+imag) dt) - /******************Transpose tile 3x2***************************/ - vmovups(xmm4, mem(rcx)) - vmovups(xmm8, mem(rcx, 16)) - vmovups(xmm12, mem(rcx,32)) + /******************Transpose tile 3x2***************************/ + vmovups(xmm4, mem(rcx)) + vmovups(xmm8, mem(rcx, 16)) + vmovups(xmm12, mem(rcx,32)) - add(rsi, rcx) + add(rsi, rcx) - vextractf128(imm(0x1), ymm4, xmm4) - vextractf128(imm(0x1), ymm8, xmm8) - vextractf128(imm(0x1), ymm12, xmm12) - vmovups(xmm4, mem(rcx)) - vmovups(xmm8, mem(rcx, 16)) - vmovups(xmm12, mem(rcx,32)) + vextractf128(imm(0x1), ymm4, xmm4) + vextractf128(imm(0x1), ymm8, xmm8) + vextractf128(imm(0x1), ymm12, xmm12) + vmovups(xmm4, mem(rcx)) + vmovups(xmm8, mem(rcx, 16)) + vmovups(xmm12, mem(rcx,32)) - label(.ZDONE) + label(.ZDONE) - lea(mem(r12, rdi, 2), r12) - lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + lea(mem(r12, rdi, 2), r12) + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c - lea(mem(r14, r8, 2), r14) - lea(mem(r14, r8, 1), r14) //a_ii = r14 += 3*rs_a + lea(mem(r14, r8, 2), r14) + lea(mem(r14, r8, 1), r14) //a_ii = r14 += 3*rs_a - dec(r11) // ii -= 1; - jne(.ZLOOP3X2I) // iterate again if ii != 0. + dec(r11) // ii -= 1; + jne(.ZLOOP3X2I) // iterate again if ii != 0. - label(.ZRETURN) + label(.ZRETURN) - end_asm( - : // output operands (none) - : // input operands + end_asm( + : // output operands (none) + : // input operands [m_iter] "m" (m_iter), [k_iter] "m" (k_iter), [k_left] "m" (k_left), @@ -1269,43 +1253,43 @@ void bli_zgemmsup_rv_zen_asm_3x2m [cs_c] "m" (cs_c)/*, [a_next] "m" (a_next), [b_next] "m" (b_next)*/ - : // register clobber list - "rax", "rbx", "rcx", "rsi", "rdi", - "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", - "xmm0", "xmm1", "xmm2", "xmm3", - "xmm4", "xmm5", "xmm6", "xmm7", - "xmm8", "xmm9", "xmm10", "xmm11", - "xmm12", "xmm13", "xmm14", "xmm15", - "memory" - ) - - consider_edge_cases: - - // Handle edge cases in the m dimension, if they exist. - if ( m_left ) - { - const dim_t nr_cur = 4; - const dim_t i_edge = m0 - ( dim_t )m_left; - - dcomplex* cij = c + i_edge*rs_c; - dcomplex* ai = a + i_edge*rs_a; - dcomplex* bj = b; - - zgemmsup_ker_ft ker_fps[3] = - { - NULL, - bli_zgemmsup_rv_zen_asm_1x2, - bli_zgemmsup_rv_zen_asm_2x2, - }; - - zgemmsup_ker_ft ker_fp = ker_fps[ m_left ]; - - ker_fp - ( - conja, conjb, m_left, nr_cur, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, - beta, cij, rs_c0, cs_c0, data, cntx - ); - return; - } -} \ No newline at end of file + : // register clobber list + "rax", "rbx", "rcx", "rsi", "rdi", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. 
+ if ( m_left ) + { + const dim_t nr_cur = 4; + const dim_t i_edge = m0 - ( dim_t )m_left; + + dcomplex* cij = c + i_edge*rs_c; + dcomplex* ai = a + i_edge*rs_a; + dcomplex* bj = b; + + zgemmsup_ker_ft ker_fps[3] = + { + NULL, + bli_zgemmsup_rv_zen_asm_1x2, + bli_zgemmsup_rv_zen_asm_2x2, + }; + + zgemmsup_ker_ft ker_fp = ker_fps[ m_left ]; + + ker_fp + ( + conja, conjb, m_left, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + return; + } +} diff --git a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4n.c b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4n.c index 44b43e7418..b12f67ca9d 100644 --- a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4n.c +++ b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_z3x4n.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2021, Advanced Micro Devices, Inc. + Copyright (C) 2020 - 2021, Advanced Micro Devices, Inc.All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -37,16 +37,16 @@ /* rrr: - -------- ------ -------- - -------- += ------ ... -------- - -------- ------ -------- - -------- ------ : + -------- ------ -------- + -------- += ------ ... -------- + -------- ------ -------- + -------- ------ : rcr: - -------- | | | | -------- - -------- += | | | | ... -------- - -------- | | | | -------- - -------- | | | | : + -------- | | | | -------- + -------- += | | | | ... -------- + -------- | | | | -------- + -------- | | | | : Assumptions: - B is row-stored; @@ -62,10 +62,10 @@ cost of the in-register transpose). crr: - | | | | | | | | ------ -------- - | | | | | | | | += ------ ... -------- - | | | | | | | | ------ -------- - | | | | | | | | ------ : + | | | | | | | | ------ -------- + | | | | | | | | += ------ ... -------- + | | | | | | | | ------ -------- + | | | | | | | | ------ : */ void bli_zgemmsup_rv_zen_asm_3x4n ( @@ -83,452 +83,436 @@ void bli_zgemmsup_rv_zen_asm_3x4n cntx_t* restrict cntx ) { - uint64_t m_left = m0 % 3; - if ( m_left ) - { - zgemmsup_ker_ft ker_fps[3] = - { - NULL, - bli_zgemmsup_rv_zen_asm_1x4n, - bli_zgemmsup_rv_zen_asm_2x4n, - }; - zgemmsup_ker_ft ker_fp = ker_fps[ m_left ]; - ker_fp - ( - conja, conjb, m_left, n0, k0, - alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, - beta, c, rs_c0, cs_c0, data, cntx - ); - return; - } - //void* a_next = bli_auxinfo_next_a( data ); - //void* b_next = bli_auxinfo_next_b( data ); - - // Typecast local copies of integers in case dim_t and inc_t are a - // different size than is expected by load instructions. - - uint64_t k_iter = 0; - - - uint64_t n_iter = n0 / 4; - uint64_t n_left = n0 % 4; - - uint64_t rs_a = rs_a0; - uint64_t cs_a = cs_a0; - uint64_t rs_b = rs_b0; - uint64_t cs_b = cs_b0; - uint64_t rs_c = rs_c0; - uint64_t cs_c = cs_c0; - - - if ( n_iter == 0 ) goto consider_edge_cases; - - //handling case when alpha and beta are real and +/-1. 
- uint64_t alpha_real_one = *((uint64_t*)(&alpha->real)); - uint64_t beta_real_one = *((uint64_t*)(&beta->real)); - - uint64_t alpha_real_one_abs = ((alpha_real_one << 1) >> 1); - uint64_t beta_real_one_abs = ((beta_real_one << 1) >> 1); - - char alpha_mul_type = BLIS_MUL_DEFAULT; - char beta_mul_type = BLIS_MUL_DEFAULT; - - if((alpha_real_one_abs == BLIS_DOUBLE_TO_UINT64_ONE_ABS) && (alpha->imag==0))// (alpha is real and +/-1) - { - alpha_mul_type = BLIS_MUL_ONE; //alpha real and 1 - if(alpha_real_one == BLIS_DOUBLE_TO_UINT64_MINUS_ONE) - { - alpha_mul_type = BLIS_MUL_MINUS_ONE; //alpha real and -1 - } - } - - if(beta->imag == 0)// beta is real - { - if(beta_real_one_abs == BLIS_DOUBLE_TO_UINT64_ONE_ABS)// (beta +/-1) - { - beta_mul_type = BLIS_MUL_ONE; - if(beta_real_one == BLIS_DOUBLE_TO_UINT64_MINUS_ONE) - { - beta_mul_type = BLIS_MUL_MINUS_ONE; - } - } - else if(beta_real_one == 0) - { - beta_mul_type = BLIS_MUL_ZERO; - } - } - - // ------------------------------------------------------------------------- - //scratch registers - __m256d ymm0, ymm1, ymm2, ymm3; - __m256d ymm4, ymm5, ymm6, ymm7; - __m256d ymm8, ymm9, ymm10, ymm11; - __m256d ymm12, ymm13, ymm14, ymm15; - __m128d xmm0, xmm3; - - dcomplex *tA = a; - double *tAimag = &a->imag; - dcomplex *tB = b; - dcomplex *tC = c; - for (n_iter = 0; n_iter < n0 / 4; n_iter++) - { - // clear scratch registers. - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - ymm12 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); - ymm14 = _mm256_setzero_pd(); - ymm15 = _mm256_setzero_pd(); - - dim_t ta_inc_row = rs_a; - dim_t tb_inc_row = rs_b; - dim_t tc_inc_row = rs_c; - - dim_t ta_inc_col = cs_a; - dim_t tb_inc_col = cs_b; - dim_t tc_inc_col = cs_c; - - tA = a; - tAimag = &a->imag; - tB = b + n_iter*tb_inc_col*4; - tC = c + n_iter*tc_inc_col*4; - for (k_iter = 0; k_iter real == -1.0) - { - ymm0 = _mm256_setzero_pd(); - ymm4 = _mm256_sub_pd(ymm0,ymm4); - ymm5 = _mm256_sub_pd(ymm0, ymm5); - ymm8 = _mm256_sub_pd(ymm0, ymm8); - ymm9 = _mm256_sub_pd(ymm0, ymm9); - ymm12 = _mm256_sub_pd(ymm0, ymm12); - ymm13 = _mm256_sub_pd(ymm0, ymm13); - } - - //when alpha is real and +/-1, multiplication is skipped. - if(alpha_mul_type == BLIS_MUL_DEFAULT) - { - // alpha, beta multiplication. 
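The block being removed above detected alpha, beta = +/-1 by reinterpreting the double's bit pattern: shifting the sign bit out gives |x|, and one integer compare against the bits of 1.0 then covers both +1 and -1. The sketch below shows that idiom in isolation; the concrete values of BLIS_DOUBLE_TO_UINT64_ONE_ABS and _MINUS_ONE are assumed to be the IEEE-754 encodings of 1.0 and -1.0, and memcpy is used here in place of the original pointer cast.

    #include <stdint.h>
    #include <string.h>
    #include <stdio.h>

    /* Assumed IEEE-754 bit patterns behind the removed macros. */
    #define ONE_ABS_BITS    0x3FF0000000000000ULL   /* bits of  1.0 */
    #define MINUS_ONE_BITS  0xBFF0000000000000ULL   /* bits of -1.0 */

    /* Returns +1 or -1 when x is exactly +/-1.0, otherwise 0. */
    static int sign_of_unit( double x )
    {
        uint64_t bits;
        memcpy( &bits, &x, sizeof bits );              /* safe type pun     */
        if( ( ( bits << 1 ) >> 1 ) != ONE_ABS_BITS )   /* |x| != 1.0        */
            return 0;
        return ( bits == MINUS_ONE_BITS ) ? -1 : +1;
    }

    int main( void )
    {
        printf( "%d %d %d\n", sign_of_unit( 1.0 ),
                              sign_of_unit( -1.0 ),
                              sign_of_unit( 0.5 ) );   /* 1 -1 0 */
        return 0;
    }

The replacement code added further down in this hunk drops the type punning in favour of plain floating-point comparisons against 1.0, -1.0 and 0.0, which is easier to read and gives the same classification.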
- /* (ar + ai) x AB */ - ymm0 = _mm256_broadcast_sd((double const *)(alpha)); // load alpha_r and duplicate - ymm1 = _mm256_broadcast_sd((double const *)(&alpha->imag)); // load alpha_i and duplicate - - ymm3 = _mm256_permute_pd(ymm4, 5); - ymm4 = _mm256_mul_pd(ymm0, ymm4); - ymm3 =_mm256_mul_pd(ymm1, ymm3); - ymm4 = _mm256_addsub_pd(ymm4, ymm3); - - ymm3 = _mm256_permute_pd(ymm5, 5); - ymm5 = _mm256_mul_pd(ymm0, ymm5); - ymm3 = _mm256_mul_pd(ymm1, ymm3); - ymm5 = _mm256_addsub_pd(ymm5, ymm3); - - ymm3 = _mm256_permute_pd(ymm8, 5); - ymm8 = _mm256_mul_pd(ymm0, ymm8); - ymm3 = _mm256_mul_pd(ymm1, ymm3); - ymm8 = _mm256_addsub_pd(ymm8, ymm3); - - ymm3 = _mm256_permute_pd(ymm9, 5); - ymm9 = _mm256_mul_pd(ymm0, ymm9); - ymm3 = _mm256_mul_pd(ymm1, ymm3); - ymm9 = _mm256_addsub_pd(ymm9, ymm3); - - ymm3 = _mm256_permute_pd(ymm12, 5); - ymm12 = _mm256_mul_pd(ymm0, ymm12); - ymm3 = _mm256_mul_pd(ymm1, ymm3); - ymm12 = _mm256_addsub_pd(ymm12, ymm3); - - ymm3 = _mm256_permute_pd(ymm13, 5); - ymm13 = _mm256_mul_pd(ymm0, ymm13); - ymm3 = _mm256_mul_pd(ymm1, ymm3); - ymm13 = _mm256_addsub_pd(ymm13, ymm3); - } - - if(tc_inc_row == 1) //col stored - { - if(beta_mul_type == BLIS_MUL_ZERO) - { - //transpose left 3x2 - _mm_storeu_pd((double *)(tC ), _mm256_castpd256_pd128(ymm4)); - _mm_storeu_pd((double *)(tC+1), _mm256_castpd256_pd128(ymm8)); - _mm_storeu_pd((double *)(tC+2), _mm256_castpd256_pd128(ymm12)); - tC += tc_inc_col; - - _mm_storeu_pd((double *)(tC ),_mm256_extractf128_pd (ymm4,1)); - _mm_storeu_pd((double *)(tC+1) ,_mm256_extractf128_pd (ymm8,1)); - _mm_storeu_pd((double *)(tC+2), _mm256_extractf128_pd(ymm12, 1)); - tC += tc_inc_col; - - //transpose right 3x2 - _mm_storeu_pd((double *)(tC ), _mm256_castpd256_pd128(ymm5)); - _mm_storeu_pd((double *)(tC+1), _mm256_castpd256_pd128(ymm9)); - _mm_storeu_pd((double *)(tC+2), _mm256_castpd256_pd128(ymm13)); - tC += tc_inc_col; - - _mm_storeu_pd((double *)(tC ),_mm256_extractf128_pd (ymm5,1)); - _mm_storeu_pd((double *)(tC+1) ,_mm256_extractf128_pd (ymm9,1)); - _mm_storeu_pd((double *)(tC+2), _mm256_extractf128_pd(ymm13, 1)); - } - else{ - ymm1 = _mm256_broadcast_sd((double const *)(beta)); // load alpha_r and duplicate - ymm2 = _mm256_broadcast_sd((double const *)(&beta->imag)); // load alpha_i and duplicate - //Multiply ymm4 with beta - xmm0 = _mm_loadu_pd((double *)(tC)) ; - xmm3 = _mm_loadu_pd((double *)(tC + tc_inc_col)) ; - ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; - ymm3 = _mm256_permute_pd(ymm0, 5); - ymm0 = _mm256_mul_pd(ymm1, ymm0); - ymm3 = _mm256_mul_pd(ymm2, ymm3); - ymm0 = _mm256_addsub_pd(ymm0, ymm3); - ymm4 = _mm256_add_pd(ymm4, ymm0); - //Multiply ymm8 with beta - xmm0 = _mm_loadu_pd((double *)(tC + 1)) ; - xmm3 = _mm_loadu_pd((double *)(tC + 1 + tc_inc_col)) ; - ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; - ymm3 = _mm256_permute_pd(ymm0, 5); - ymm0 = _mm256_mul_pd(ymm1, ymm0); - ymm3 = _mm256_mul_pd(ymm2, ymm3); - ymm0 = _mm256_addsub_pd(ymm0, ymm3); - ymm8 = _mm256_add_pd(ymm8, ymm0); - - //Multiply ymm12 with beta - xmm0 = _mm_loadu_pd((double *)(tC + 2)) ; - xmm3 = _mm_loadu_pd((double *)(tC + 2 + tc_inc_col)) ; - ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; - ymm3 = _mm256_permute_pd(ymm0, 5); - ymm0 = _mm256_mul_pd(ymm1, ymm0); - ymm3 = _mm256_mul_pd(ymm2, ymm3); - ymm0 = _mm256_addsub_pd(ymm0, ymm3); - ymm12 = _mm256_add_pd(ymm12, ymm0); - - //transpose left 3x2 - _mm_storeu_pd((double *)(tC ), _mm256_castpd256_pd128(ymm4)); - _mm_storeu_pd((double 
*)(tC+1), _mm256_castpd256_pd128(ymm8)); - _mm_storeu_pd((double *)(tC+2), _mm256_castpd256_pd128(ymm12)); - tC += tc_inc_col; - - _mm_storeu_pd((double *)(tC ),_mm256_extractf128_pd (ymm4,1)); - _mm_storeu_pd((double *)(tC+1) ,_mm256_extractf128_pd (ymm8,1)); - _mm_storeu_pd((double *)(tC+2), _mm256_extractf128_pd(ymm12, 1)); - tC += tc_inc_col; - - //Multiply ymm5 with beta - xmm0 = _mm_loadu_pd((double *)(tC)) ; - xmm3 = _mm_loadu_pd((double *)(tC + tc_inc_col)) ; - ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; - ymm3 = _mm256_permute_pd(ymm0, 5); - ymm0 = _mm256_mul_pd(ymm1, ymm0); - ymm3 = _mm256_mul_pd(ymm2, ymm3); - ymm0 = _mm256_addsub_pd(ymm0, ymm3); - ymm5 = _mm256_add_pd(ymm5, ymm0); - //Multiply ymm9 with beta - xmm0 = _mm_loadu_pd((double *)(tC + 1)) ; - xmm3 = _mm_loadu_pd((double *)(tC + 1 + tc_inc_col)) ; - ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; - ymm3 = _mm256_permute_pd(ymm0, 5); - ymm0 = _mm256_mul_pd(ymm1, ymm0); - ymm3 = _mm256_mul_pd(ymm2, ymm3); - ymm0 = _mm256_addsub_pd(ymm0, ymm3); - ymm9 = _mm256_add_pd(ymm9, ymm0); - - //Multiply ymm13 with beta - xmm0 = _mm_loadu_pd((double *)(tC + 2)) ; - xmm3 = _mm_loadu_pd((double *)(tC + 2 + tc_inc_col)) ; - ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; - ymm3 = _mm256_permute_pd(ymm0, 5); - ymm0 = _mm256_mul_pd(ymm1, ymm0); - ymm3 = _mm256_mul_pd(ymm2, ymm3); - ymm0 = _mm256_addsub_pd(ymm0, ymm3); - ymm13 = _mm256_add_pd(ymm13, ymm0); - - //transpose right 3x2 - _mm_storeu_pd((double *)(tC ), _mm256_castpd256_pd128(ymm5)); - _mm_storeu_pd((double *)(tC+1), _mm256_castpd256_pd128(ymm9)); - _mm_storeu_pd((double *)(tC+2), _mm256_castpd256_pd128(ymm13)); - tC += tc_inc_col; - - _mm_storeu_pd((double *)(tC ),_mm256_extractf128_pd (ymm5,1)); - _mm_storeu_pd((double *)(tC+1) ,_mm256_extractf128_pd (ymm9,1)); - _mm_storeu_pd((double *)(tC+2), _mm256_extractf128_pd(ymm13, 1)); - } - - } - else - { - if(beta_mul_type == BLIS_MUL_ZERO) - { - _mm256_storeu_pd((double *)(tC), ymm4); - _mm256_storeu_pd((double *)(tC + 2), ymm5); - _mm256_storeu_pd((double *)(tC + tc_inc_row) , ymm8); - _mm256_storeu_pd((double *)(tC + tc_inc_row + 2), ymm9); - _mm256_storeu_pd((double *)(tC + tc_inc_row *2), ymm12); - _mm256_storeu_pd((double *)(tC + tc_inc_row *2+ 2), ymm13); - } - else if(beta_mul_type == BLIS_MUL_ONE)// equivalent to if(beta->real == 1.0) - { - ymm2 = _mm256_loadu_pd((double const *)(tC)); - ymm4 = _mm256_add_pd(ymm4,ymm2); - ymm2 = _mm256_loadu_pd((double const *)(tC+2)); - ymm5 = _mm256_add_pd(ymm5,ymm2); - ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row)); - ymm8 = _mm256_add_pd(ymm8,ymm2); - ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row + 2)); - ymm9 = _mm256_add_pd(ymm9,ymm2); - ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row*2)); - ymm12 = _mm256_add_pd(ymm12,ymm2); - ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row*2 +2)); - ymm13 = _mm256_add_pd(ymm13,ymm2); - - _mm256_storeu_pd((double *)(tC), ymm4); - _mm256_storeu_pd((double *)(tC + 2), ymm5); - _mm256_storeu_pd((double *)(tC + tc_inc_row) , ymm8); - _mm256_storeu_pd((double *)(tC + tc_inc_row + 2), ymm9); - _mm256_storeu_pd((double *)(tC + tc_inc_row *2), ymm12); - _mm256_storeu_pd((double *)(tC + tc_inc_row *2+ 2), ymm13); - } - else{ - /* (br + bi) C + (ar + ai) AB */ - ymm0 = _mm256_broadcast_sd((double const *)(beta)); // load beta_r and duplicate - ymm1 = _mm256_broadcast_sd((double const *)(&beta->imag)); // load beta_i and duplicate - - ymm2 = 
_mm256_loadu_pd((double const *)(tC)); - ymm3 = _mm256_permute_pd(ymm2, 5); - ymm2 = _mm256_mul_pd(ymm0, ymm2); - ymm3 =_mm256_mul_pd(ymm1, ymm3); - ymm4 = _mm256_add_pd(ymm4, _mm256_addsub_pd(ymm2, ymm3)); - - ymm2 = _mm256_loadu_pd((double const *)(tC+2)); - ymm3 = _mm256_permute_pd(ymm2, 5); - ymm2 = _mm256_mul_pd(ymm0, ymm2); - ymm3 = _mm256_mul_pd(ymm1, ymm3); - ymm5 = _mm256_add_pd(ymm5, _mm256_addsub_pd(ymm2, ymm3)); - - ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row)); - ymm3 = _mm256_permute_pd(ymm2, 5); - ymm2 = _mm256_mul_pd(ymm0, ymm2); - ymm3 = _mm256_mul_pd(ymm1, ymm3); - ymm8 = _mm256_add_pd(ymm8, _mm256_addsub_pd(ymm2, ymm3)); - - ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row + 2)); - ymm3 = _mm256_permute_pd(ymm2, 5); - ymm2 = _mm256_mul_pd(ymm0, ymm2); - ymm3 = _mm256_mul_pd(ymm1, ymm3); - ymm9 = _mm256_add_pd(ymm9, _mm256_addsub_pd(ymm2, ymm3)); - - ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row*2)); - ymm3 = _mm256_permute_pd(ymm2, 5); - ymm2 = _mm256_mul_pd(ymm0, ymm2); - ymm3 = _mm256_mul_pd(ymm1, ymm3); - ymm12 = _mm256_add_pd(ymm12, _mm256_addsub_pd(ymm2, ymm3)); - - ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row*2 +2)); - ymm3 = _mm256_permute_pd(ymm2, 5); - ymm2 = _mm256_mul_pd(ymm0, ymm2); - ymm3 = _mm256_mul_pd(ymm1, ymm3); - ymm13 = _mm256_add_pd(ymm13, _mm256_addsub_pd(ymm2, ymm3)); - - _mm256_storeu_pd((double *)(tC), ymm4); - _mm256_storeu_pd((double *)(tC + 2), ymm5); - _mm256_storeu_pd((double *)(tC + tc_inc_row) , ymm8); - _mm256_storeu_pd((double *)(tC + tc_inc_row + 2), ymm9); - _mm256_storeu_pd((double *)(tC + tc_inc_row *2), ymm12); - _mm256_storeu_pd((double *)(tC + tc_inc_row *2+ 2), ymm13); - } - } - } - - consider_edge_cases: - // Handle edge cases in the m dimension, if they exist. - if ( n_left ) - { - const dim_t mr_cur = 3; - const dim_t j_edge = n0 - ( dim_t )n_left; - - dcomplex* restrict cij = c + j_edge*cs_c; - dcomplex* restrict ai = a; - dcomplex* restrict bj = b + n_iter * 4; - - if ( 2 <= n_left ) - { - const dim_t nr_cur = 2; - - bli_zgemmsup_rv_zen_asm_3x2 - ( - conja, conjb, mr_cur, nr_cur, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, - beta, cij, rs_c0, cs_c0, data, cntx - ); - cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; - } - if ( 1 == n_left ) - { - bli_zgemv_ex - ( - BLIS_NO_TRANSPOSE, conjb, m0, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, - beta, cij, rs_c0, cntx, NULL - ); - } - } + uint64_t m_left = m0 % 3; + if ( m_left ) + { + zgemmsup_ker_ft ker_fps[3] = + { + NULL, + bli_zgemmsup_rv_zen_asm_1x4n, + bli_zgemmsup_rv_zen_asm_2x4n, + }; + zgemmsup_ker_ft ker_fp = ker_fps[ m_left ]; + ker_fp + ( + conja, conjb, m_left, n0, k0, + alpha, a, rs_a0, cs_a0, b, rs_b0, cs_b0, + beta, c, rs_c0, cs_c0, data, cntx + ); + return; + } + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + + uint64_t k_iter = 0; + + + uint64_t n_iter = n0 / 4; + uint64_t n_left = n0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + + if ( n_iter == 0 ) goto consider_edge_cases; + + char alpha_mul_type = BLIS_MUL_DEFAULT; + char beta_mul_type = BLIS_MUL_DEFAULT; + + //handling case when alpha and beta are real and +/-1. 
+ + if(alpha->imag == 0.0)// (alpha is real) + { + if(alpha->real == 1.0) alpha_mul_type = BLIS_MUL_ONE; + else if(alpha->real == -1.0) alpha_mul_type = BLIS_MUL_MINUS_ONE; + else if(alpha->real == 0.0) alpha_mul_type = BLIS_MUL_ZERO; + } + + if(beta->imag == 0.0)// (beta is real) + { + if(beta->real == 1.0) beta_mul_type = BLIS_MUL_ONE; + else if(beta->real == -1.0) beta_mul_type = BLIS_MUL_MINUS_ONE; + else if(beta->real == 0.0) beta_mul_type = BLIS_MUL_ZERO; + } + + // ------------------------------------------------------------------------- + //scratch registers + __m256d ymm0, ymm1, ymm2, ymm3; + __m256d ymm4, ymm5, ymm6, ymm7; + __m256d ymm8, ymm9, ymm10, ymm11; + __m256d ymm12, ymm13, ymm14, ymm15; + __m128d xmm0, xmm3; + + dcomplex *tA = a; + double *tAimag = &a->imag; + dcomplex *tB = b; + dcomplex *tC = c; + for (n_iter = 0; n_iter < n0 / 4; n_iter++) + { + // clear scratch registers. + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + dim_t ta_inc_row = rs_a; + dim_t tb_inc_row = rs_b; + dim_t tc_inc_row = rs_c; + + dim_t ta_inc_col = cs_a; + dim_t tb_inc_col = cs_b; + dim_t tc_inc_col = cs_c; + + tA = a; + tAimag = &a->imag; + tB = b + n_iter*tb_inc_col*4; + tC = c + n_iter*tc_inc_col*4; + for (k_iter = 0; k_iter real == -1.0) + { + ymm0 = _mm256_setzero_pd(); + ymm4 = _mm256_sub_pd(ymm0,ymm4); + ymm5 = _mm256_sub_pd(ymm0, ymm5); + ymm8 = _mm256_sub_pd(ymm0, ymm8); + ymm9 = _mm256_sub_pd(ymm0, ymm9); + ymm12 = _mm256_sub_pd(ymm0, ymm12); + ymm13 = _mm256_sub_pd(ymm0, ymm13); + } + + //when alpha is real and +/-1, multiplication is skipped. + if(alpha_mul_type == BLIS_MUL_DEFAULT) + { + // alpha, beta multiplication. 
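/*
 * A minimal scalar sketch of what each permute_pd/mul_pd/addsub_pd triple
 * below computes: one complex multiply per packed {real, imag} pair, i.e.
 * scaling an accumulator element x by alpha as
 *
 *     y.real = alpha.real*x.real - alpha.imag*x.imag
 *     y.imag = alpha.real*x.imag + alpha.imag*x.real
 *
 * (_mm256_addsub_pd subtracts in the even slots, which hold the real parts,
 * and adds in the odd slots, which hold the imaginary parts.)  The
 * BLIS_MUL_ONE / BLIS_MUL_MINUS_ONE / BLIS_MUL_ZERO branches above exist
 * precisely to skip this work when alpha (or beta) is a real 1.0, -1.0 or
 * 0.0.  The type and function names in this sketch are illustrative only,
 * not BLIS identifiers.
 */
typedef struct { double real, imag; } zscal_sketch_t;

static inline zscal_sketch_t zscal_sketch( zscal_sketch_t alpha, zscal_sketch_t x )
{
    zscal_sketch_t y;
    y.real = alpha.real * x.real - alpha.imag * x.imag; /* even slot: subtract */
    y.imag = alpha.real * x.imag + alpha.imag * x.real; /* odd  slot: add      */
    return y;
}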
+ /* (ar + ai) x AB */ + ymm0 = _mm256_broadcast_sd((double const *)(alpha)); // load alpha_r and duplicate + ymm1 = _mm256_broadcast_sd((double const *)(&alpha->imag)); // load alpha_i and duplicate + + ymm3 = _mm256_permute_pd(ymm4, 5); + ymm4 = _mm256_mul_pd(ymm0, ymm4); + ymm3 =_mm256_mul_pd(ymm1, ymm3); + ymm4 = _mm256_addsub_pd(ymm4, ymm3); + + ymm3 = _mm256_permute_pd(ymm5, 5); + ymm5 = _mm256_mul_pd(ymm0, ymm5); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm5 = _mm256_addsub_pd(ymm5, ymm3); + + ymm3 = _mm256_permute_pd(ymm8, 5); + ymm8 = _mm256_mul_pd(ymm0, ymm8); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm8 = _mm256_addsub_pd(ymm8, ymm3); + + ymm3 = _mm256_permute_pd(ymm9, 5); + ymm9 = _mm256_mul_pd(ymm0, ymm9); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm9 = _mm256_addsub_pd(ymm9, ymm3); + + ymm3 = _mm256_permute_pd(ymm12, 5); + ymm12 = _mm256_mul_pd(ymm0, ymm12); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm12 = _mm256_addsub_pd(ymm12, ymm3); + + ymm3 = _mm256_permute_pd(ymm13, 5); + ymm13 = _mm256_mul_pd(ymm0, ymm13); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm13 = _mm256_addsub_pd(ymm13, ymm3); + } + + if(tc_inc_row == 1) //col stored + { + if(beta_mul_type == BLIS_MUL_ZERO) + { + //transpose left 3x2 + _mm_storeu_pd((double *)(tC ), _mm256_castpd256_pd128(ymm4)); + _mm_storeu_pd((double *)(tC+1), _mm256_castpd256_pd128(ymm8)); + _mm_storeu_pd((double *)(tC+2), _mm256_castpd256_pd128(ymm12)); + tC += tc_inc_col; + + _mm_storeu_pd((double *)(tC ),_mm256_extractf128_pd (ymm4,1)); + _mm_storeu_pd((double *)(tC+1) ,_mm256_extractf128_pd (ymm8,1)); + _mm_storeu_pd((double *)(tC+2), _mm256_extractf128_pd(ymm12, 1)); + tC += tc_inc_col; + + //transpose right 3x2 + _mm_storeu_pd((double *)(tC ), _mm256_castpd256_pd128(ymm5)); + _mm_storeu_pd((double *)(tC+1), _mm256_castpd256_pd128(ymm9)); + _mm_storeu_pd((double *)(tC+2), _mm256_castpd256_pd128(ymm13)); + tC += tc_inc_col; + + _mm_storeu_pd((double *)(tC ),_mm256_extractf128_pd (ymm5,1)); + _mm_storeu_pd((double *)(tC+1) ,_mm256_extractf128_pd (ymm9,1)); + _mm_storeu_pd((double *)(tC+2), _mm256_extractf128_pd(ymm13, 1)); + } + else{ + ymm1 = _mm256_broadcast_sd((double const *)(beta)); // load alpha_r and duplicate + ymm2 = _mm256_broadcast_sd((double const *)(&beta->imag)); // load alpha_i and duplicate + //Multiply ymm4 with beta + xmm0 = _mm_loadu_pd((double *)(tC)) ; + xmm3 = _mm_loadu_pd((double *)(tC + tc_inc_col)) ; + ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_pd(ymm0, 5); + ymm0 = _mm256_mul_pd(ymm1, ymm0); + ymm3 = _mm256_mul_pd(ymm2, ymm3); + ymm0 = _mm256_addsub_pd(ymm0, ymm3); + ymm4 = _mm256_add_pd(ymm4, ymm0); + //Multiply ymm8 with beta + xmm0 = _mm_loadu_pd((double *)(tC + 1)) ; + xmm3 = _mm_loadu_pd((double *)(tC + 1 + tc_inc_col)) ; + ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_pd(ymm0, 5); + ymm0 = _mm256_mul_pd(ymm1, ymm0); + ymm3 = _mm256_mul_pd(ymm2, ymm3); + ymm0 = _mm256_addsub_pd(ymm0, ymm3); + ymm8 = _mm256_add_pd(ymm8, ymm0); + + //Multiply ymm12 with beta + xmm0 = _mm_loadu_pd((double *)(tC + 2)) ; + xmm3 = _mm_loadu_pd((double *)(tC + 2 + tc_inc_col)) ; + ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_pd(ymm0, 5); + ymm0 = _mm256_mul_pd(ymm1, ymm0); + ymm3 = _mm256_mul_pd(ymm2, ymm3); + ymm0 = _mm256_addsub_pd(ymm0, ymm3); + ymm12 = _mm256_add_pd(ymm12, ymm0); + + //transpose left 3x2 + _mm_storeu_pd((double *)(tC ), _mm256_castpd256_pd128(ymm4)); + _mm_storeu_pd((double 
*)(tC+1), _mm256_castpd256_pd128(ymm8)); + _mm_storeu_pd((double *)(tC+2), _mm256_castpd256_pd128(ymm12)); + tC += tc_inc_col; + + _mm_storeu_pd((double *)(tC ),_mm256_extractf128_pd (ymm4,1)); + _mm_storeu_pd((double *)(tC+1) ,_mm256_extractf128_pd (ymm8,1)); + _mm_storeu_pd((double *)(tC+2), _mm256_extractf128_pd(ymm12, 1)); + tC += tc_inc_col; + + //Multiply ymm5 with beta + xmm0 = _mm_loadu_pd((double *)(tC)) ; + xmm3 = _mm_loadu_pd((double *)(tC + tc_inc_col)) ; + ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_pd(ymm0, 5); + ymm0 = _mm256_mul_pd(ymm1, ymm0); + ymm3 = _mm256_mul_pd(ymm2, ymm3); + ymm0 = _mm256_addsub_pd(ymm0, ymm3); + ymm5 = _mm256_add_pd(ymm5, ymm0); + //Multiply ymm9 with beta + xmm0 = _mm_loadu_pd((double *)(tC + 1)) ; + xmm3 = _mm_loadu_pd((double *)(tC + 1 + tc_inc_col)) ; + ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_pd(ymm0, 5); + ymm0 = _mm256_mul_pd(ymm1, ymm0); + ymm3 = _mm256_mul_pd(ymm2, ymm3); + ymm0 = _mm256_addsub_pd(ymm0, ymm3); + ymm9 = _mm256_add_pd(ymm9, ymm0); + + //Multiply ymm13 with beta + xmm0 = _mm_loadu_pd((double *)(tC + 2)) ; + xmm3 = _mm_loadu_pd((double *)(tC + 2 + tc_inc_col)) ; + ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_pd(ymm0, 5); + ymm0 = _mm256_mul_pd(ymm1, ymm0); + ymm3 = _mm256_mul_pd(ymm2, ymm3); + ymm0 = _mm256_addsub_pd(ymm0, ymm3); + ymm13 = _mm256_add_pd(ymm13, ymm0); + + //transpose right 3x2 + _mm_storeu_pd((double *)(tC ), _mm256_castpd256_pd128(ymm5)); + _mm_storeu_pd((double *)(tC+1), _mm256_castpd256_pd128(ymm9)); + _mm_storeu_pd((double *)(tC+2), _mm256_castpd256_pd128(ymm13)); + tC += tc_inc_col; + + _mm_storeu_pd((double *)(tC ),_mm256_extractf128_pd (ymm5,1)); + _mm_storeu_pd((double *)(tC+1) ,_mm256_extractf128_pd (ymm9,1)); + _mm_storeu_pd((double *)(tC+2), _mm256_extractf128_pd(ymm13, 1)); + } + + } + else + { + if(beta_mul_type == BLIS_MUL_ZERO) + { + _mm256_storeu_pd((double *)(tC), ymm4); + _mm256_storeu_pd((double *)(tC + 2), ymm5); + _mm256_storeu_pd((double *)(tC + tc_inc_row) , ymm8); + _mm256_storeu_pd((double *)(tC + tc_inc_row + 2), ymm9); + _mm256_storeu_pd((double *)(tC + tc_inc_row *2), ymm12); + _mm256_storeu_pd((double *)(tC + tc_inc_row *2+ 2), ymm13); + } + else if(beta_mul_type == BLIS_MUL_ONE)// equivalent to if(beta->real == 1.0) + { + ymm2 = _mm256_loadu_pd((double const *)(tC)); + ymm4 = _mm256_add_pd(ymm4,ymm2); + ymm2 = _mm256_loadu_pd((double const *)(tC+2)); + ymm5 = _mm256_add_pd(ymm5,ymm2); + ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row)); + ymm8 = _mm256_add_pd(ymm8,ymm2); + ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row + 2)); + ymm9 = _mm256_add_pd(ymm9,ymm2); + ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row*2)); + ymm12 = _mm256_add_pd(ymm12,ymm2); + ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row*2 +2)); + ymm13 = _mm256_add_pd(ymm13,ymm2); + + _mm256_storeu_pd((double *)(tC), ymm4); + _mm256_storeu_pd((double *)(tC + 2), ymm5); + _mm256_storeu_pd((double *)(tC + tc_inc_row) , ymm8); + _mm256_storeu_pd((double *)(tC + tc_inc_row + 2), ymm9); + _mm256_storeu_pd((double *)(tC + tc_inc_row *2), ymm12); + _mm256_storeu_pd((double *)(tC + tc_inc_row *2+ 2), ymm13); + } + else{ + /* (br + bi) C + (ar + ai) AB */ + ymm0 = _mm256_broadcast_sd((double const *)(beta)); // load beta_r and duplicate + ymm1 = _mm256_broadcast_sd((double const *)(&beta->imag)); // load beta_i and duplicate + + ymm2 = 
_mm256_loadu_pd((double const *)(tC)); + ymm3 = _mm256_permute_pd(ymm2, 5); + ymm2 = _mm256_mul_pd(ymm0, ymm2); + ymm3 =_mm256_mul_pd(ymm1, ymm3); + ymm4 = _mm256_add_pd(ymm4, _mm256_addsub_pd(ymm2, ymm3)); + + ymm2 = _mm256_loadu_pd((double const *)(tC+2)); + ymm3 = _mm256_permute_pd(ymm2, 5); + ymm2 = _mm256_mul_pd(ymm0, ymm2); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm5 = _mm256_add_pd(ymm5, _mm256_addsub_pd(ymm2, ymm3)); + + ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row)); + ymm3 = _mm256_permute_pd(ymm2, 5); + ymm2 = _mm256_mul_pd(ymm0, ymm2); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm8 = _mm256_add_pd(ymm8, _mm256_addsub_pd(ymm2, ymm3)); + + ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row + 2)); + ymm3 = _mm256_permute_pd(ymm2, 5); + ymm2 = _mm256_mul_pd(ymm0, ymm2); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm9 = _mm256_add_pd(ymm9, _mm256_addsub_pd(ymm2, ymm3)); + + ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row*2)); + ymm3 = _mm256_permute_pd(ymm2, 5); + ymm2 = _mm256_mul_pd(ymm0, ymm2); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm12 = _mm256_add_pd(ymm12, _mm256_addsub_pd(ymm2, ymm3)); + + ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row*2 +2)); + ymm3 = _mm256_permute_pd(ymm2, 5); + ymm2 = _mm256_mul_pd(ymm0, ymm2); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm13 = _mm256_add_pd(ymm13, _mm256_addsub_pd(ymm2, ymm3)); + + _mm256_storeu_pd((double *)(tC), ymm4); + _mm256_storeu_pd((double *)(tC + 2), ymm5); + _mm256_storeu_pd((double *)(tC + tc_inc_row) , ymm8); + _mm256_storeu_pd((double *)(tC + tc_inc_row + 2), ymm9); + _mm256_storeu_pd((double *)(tC + tc_inc_row *2), ymm12); + _mm256_storeu_pd((double *)(tC + tc_inc_row *2+ 2), ymm13); + } + } + } + + consider_edge_cases: + // Handle edge cases in the m dimension, if they exist. + if ( n_left ) + { + const dim_t mr_cur = 3; + const dim_t j_edge = n0 - ( dim_t )n_left; + + dcomplex* restrict cij = c + j_edge*cs_c; + dcomplex* restrict ai = a; + dcomplex* restrict bj = b + n_iter * 4; + + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_zgemmsup_rv_zen_asm_3x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { + bli_zgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, m0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); + } + } } @@ -549,286 +533,286 @@ void bli_zgemmsup_rv_zen_asm_2x4n ) { - uint64_t k_iter = 0; - - - uint64_t n_iter = n0 / 4; - uint64_t n_left = n0 % 4; - - uint64_t rs_a = rs_a0; - uint64_t cs_a = cs_a0; - uint64_t rs_b = rs_b0; - uint64_t cs_b = cs_b0; - uint64_t rs_c = rs_c0; - uint64_t cs_c = cs_c0; - - - if ( n_iter == 0 ) goto consider_edge_cases; - - // ------------------------------------------------------------------------- - //scratch registers - __m256d ymm0, ymm1, ymm2, ymm3; - __m256d ymm4, ymm5, ymm6, ymm7; - __m256d ymm8, ymm9, ymm10, ymm11; - __m128d xmm0, xmm3; - - dcomplex *tA = a; - double *tAimag = &a->imag; - dcomplex *tB = b; - dcomplex *tC = c; - for (n_iter = 0; n_iter < n0 / 4; n_iter++) - { - // clear scratch registers. 
- ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - - dim_t ta_inc_row = rs_a; - dim_t tb_inc_row = rs_b; - dim_t tc_inc_row = rs_c; - - dim_t ta_inc_col = cs_a; - dim_t tb_inc_col = cs_b; - dim_t tc_inc_col = cs_c; - - tA = a; - tAimag = &a->imag; - tB = b + n_iter*tb_inc_col*4; - tC = c + n_iter*tc_inc_col*4; - for (k_iter = 0; k_iter imag)); // load alpha_i and duplicate - - ymm3 = _mm256_permute_pd(ymm4, 5); - ymm4 = _mm256_mul_pd(ymm0, ymm4); - ymm3 =_mm256_mul_pd(ymm1, ymm3); - ymm4 = _mm256_addsub_pd(ymm4, ymm3); - - ymm3 = _mm256_permute_pd(ymm5, 5); - ymm5 = _mm256_mul_pd(ymm0, ymm5); - ymm3 = _mm256_mul_pd(ymm1, ymm3); - ymm5 = _mm256_addsub_pd(ymm5, ymm3); - - ymm3 = _mm256_permute_pd(ymm8, 5); - ymm8 = _mm256_mul_pd(ymm0, ymm8); - ymm3 = _mm256_mul_pd(ymm1, ymm3); - ymm8 = _mm256_addsub_pd(ymm8, ymm3); - - ymm3 = _mm256_permute_pd(ymm9, 5); - ymm9 = _mm256_mul_pd(ymm0, ymm9); - ymm3 = _mm256_mul_pd(ymm1, ymm3); - ymm9 = _mm256_addsub_pd(ymm9, ymm3); - - if(tc_inc_row == 1) //col stored - { - if(beta->real == 0.0 && beta->imag == 0.0) - { - //transpose left 2x2 - _mm_storeu_pd((double *)(tC ), _mm256_castpd256_pd128(ymm4)); - _mm_storeu_pd((double *)(tC+1), _mm256_castpd256_pd128(ymm8)); - tC += tc_inc_col; - - _mm_storeu_pd((double *)(tC ),_mm256_extractf128_pd (ymm4,1)); - _mm_storeu_pd((double *)(tC+1) ,_mm256_extractf128_pd (ymm8,1)); - tC += tc_inc_col; - - //transpose right 2x2 - _mm_storeu_pd((double *)(tC ), _mm256_castpd256_pd128(ymm5)); - _mm_storeu_pd((double *)(tC+1), _mm256_castpd256_pd128(ymm9)); - tC += tc_inc_col; - - _mm_storeu_pd((double *)(tC ),_mm256_extractf128_pd (ymm5,1)); - _mm_storeu_pd((double *)(tC+1) ,_mm256_extractf128_pd (ymm9,1)); - } - else{ - ymm1 = _mm256_broadcast_sd((double const *)(beta)); // load alpha_r and duplicate - ymm2 = _mm256_broadcast_sd((double const *)(&beta->imag)); // load alpha_i and duplicate - //Multiply ymm4 with beta - xmm0 = _mm_loadu_pd((double *)(tC)) ; - xmm3 = _mm_loadu_pd((double *)(tC + tc_inc_col)) ; - ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; - ymm3 = _mm256_permute_pd(ymm0, 5); - ymm0 = _mm256_mul_pd(ymm1, ymm0); - ymm3 = _mm256_mul_pd(ymm2, ymm3); - ymm0 = _mm256_addsub_pd(ymm0, ymm3); - ymm4 = _mm256_add_pd(ymm4, ymm0); - //Multiply ymm8 with beta - xmm0 = _mm_loadu_pd((double *)(tC + 1)) ; - xmm3 = _mm_loadu_pd((double *)(tC + 1 + tc_inc_col)) ; - ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; - ymm3 = _mm256_permute_pd(ymm0, 5); - ymm0 = _mm256_mul_pd(ymm1, ymm0); - ymm3 = _mm256_mul_pd(ymm2, ymm3); - ymm0 = _mm256_addsub_pd(ymm0, ymm3); - ymm8 = _mm256_add_pd(ymm8, ymm0); - - //transpose left 2x2 - _mm_storeu_pd((double *)(tC), _mm256_castpd256_pd128(ymm4)); - _mm_storeu_pd((double *)(tC+1), _mm256_castpd256_pd128(ymm8)); - tC += tc_inc_col; - - _mm_storeu_pd((double *)(tC ) ,_mm256_extractf128_pd (ymm4,1)); - _mm_storeu_pd((double *)(tC+1) ,_mm256_extractf128_pd (ymm8,1)); - tC += tc_inc_col; - - - //Multiply ymm5 with beta - xmm0 = _mm_loadu_pd((double *)(tC)) ; - xmm3 = _mm_loadu_pd((double *)(tC + tc_inc_col)) ; - ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; - ymm3 = _mm256_permute_pd(ymm0, 5); - ymm0 = _mm256_mul_pd(ymm1, ymm0); - ymm3 = _mm256_mul_pd(ymm2, ymm3); - ymm0 = _mm256_addsub_pd(ymm0, ymm3); - ymm5 = _mm256_add_pd(ymm5, ymm0); - 
//Multiply ymm9 with beta - xmm0 = _mm_loadu_pd((double *)(tC + 1)) ; - xmm3 = _mm_loadu_pd((double *)(tC + 1 + tc_inc_col)) ; - ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; - ymm3 = _mm256_permute_pd(ymm0, 5); - ymm0 = _mm256_mul_pd(ymm1, ymm0); - ymm3 = _mm256_mul_pd(ymm2, ymm3); - ymm0 = _mm256_addsub_pd(ymm0, ymm3); - ymm9 = _mm256_add_pd(ymm9, ymm0); - - //transpose right 2x2 - _mm_storeu_pd((double *)(tC), _mm256_castpd256_pd128(ymm5)); - _mm_storeu_pd((double *)(tC+1), _mm256_castpd256_pd128(ymm9)); - tC += tc_inc_col; - - _mm_storeu_pd((double *)(tC ) ,_mm256_extractf128_pd (ymm5,1)); - _mm_storeu_pd((double *)(tC+1) ,_mm256_extractf128_pd (ymm9,1)); - } - - } - else - { - if(beta->real == 0.0 && beta->imag == 0.0) - { - _mm256_storeu_pd((double *)(tC), ymm4); - _mm256_storeu_pd((double *)(tC + 2), ymm5); - _mm256_storeu_pd((double *)(tC + tc_inc_row) , ymm8); - _mm256_storeu_pd((double *)(tC + tc_inc_row + 2), ymm9); - } - else{ - /* (br + bi) C + (ar + ai) AB */ - ymm0 = _mm256_broadcast_sd((double const *)(beta)); // load beta_r and duplicate - ymm1 = _mm256_broadcast_sd((double const *)(&beta->imag)); // load beta_i and duplicate - - ymm2 = _mm256_loadu_pd((double const *)(tC)); - ymm3 = _mm256_permute_pd(ymm2, 5); - ymm2 = _mm256_mul_pd(ymm0, ymm2); - ymm3 = _mm256_mul_pd(ymm1, ymm3); - ymm4 = _mm256_add_pd(ymm4, _mm256_addsub_pd(ymm2, ymm3)); - - ymm2 = _mm256_loadu_pd((double const *)(tC+2)); - ymm3 = _mm256_permute_pd(ymm2, 5); - ymm2 = _mm256_mul_pd(ymm0, ymm2); - ymm3 = _mm256_mul_pd(ymm1, ymm3); - ymm5 = _mm256_add_pd(ymm5, _mm256_addsub_pd(ymm2, ymm3)); - - ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row)); - ymm3 = _mm256_permute_pd(ymm2, 5); - ymm2 = _mm256_mul_pd(ymm0, ymm2); - ymm3 = _mm256_mul_pd(ymm1, ymm3); - ymm8 = _mm256_add_pd(ymm8, _mm256_addsub_pd(ymm2, ymm3)); - - ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row + 2)); - ymm3 = _mm256_permute_pd(ymm2, 5); - ymm2 = _mm256_mul_pd(ymm0, ymm2); - ymm3 = _mm256_mul_pd(ymm1, ymm3); - ymm9 = _mm256_add_pd(ymm9, _mm256_addsub_pd(ymm2, ymm3)); - - _mm256_storeu_pd((double *)(tC), ymm4); - _mm256_storeu_pd((double *)(tC + 2), ymm5); - _mm256_storeu_pd((double *)(tC + tc_inc_row) , ymm8); - _mm256_storeu_pd((double *)(tC + tc_inc_row + 2), ymm9); - } - } - } - - consider_edge_cases: - // Handle edge cases in the m dimension, if they exist. 
- if ( n_left ) - { - const dim_t mr_cur = 3; - const dim_t j_edge = n0 - ( dim_t )n_left; - - dcomplex* restrict cij = c + j_edge*cs_c; - dcomplex* restrict ai = a; - dcomplex* restrict bj = b + n_iter * 4; - - if ( 2 <= n_left ) - { - const dim_t nr_cur = 2; - - bli_zgemmsup_rv_zen_asm_2x2 - ( - conja, conjb, mr_cur, nr_cur, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, - beta, cij, rs_c0, cs_c0, data, cntx - ); - cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; - } - if ( 1 == n_left ) - { - bli_zgemv_ex - ( - BLIS_NO_TRANSPOSE, conjb, m0, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, - beta, cij, rs_c0, cntx, NULL - ); - } - } + uint64_t k_iter = 0; + + + uint64_t n_iter = n0 / 4; + uint64_t n_left = n0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + + if ( n_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + //scratch registers + __m256d ymm0, ymm1, ymm2, ymm3; + __m256d ymm4, ymm5, ymm6, ymm7; + __m256d ymm8, ymm9, ymm10, ymm11; + __m128d xmm0, xmm3; + + dcomplex *tA = a; + double *tAimag = &a->imag; + dcomplex *tB = b; + dcomplex *tC = c; + for (n_iter = 0; n_iter < n0 / 4; n_iter++) + { + // clear scratch registers. + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + + dim_t ta_inc_row = rs_a; + dim_t tb_inc_row = rs_b; + dim_t tc_inc_row = rs_c; + + dim_t ta_inc_col = cs_a; + dim_t tb_inc_col = cs_b; + dim_t tc_inc_col = cs_c; + + tA = a; + tAimag = &a->imag; + tB = b + n_iter*tb_inc_col*4; + tC = c + n_iter*tc_inc_col*4; + for (k_iter = 0; k_iter imag)); // load alpha_i and duplicate + + ymm3 = _mm256_permute_pd(ymm4, 5); + ymm4 = _mm256_mul_pd(ymm0, ymm4); + ymm3 =_mm256_mul_pd(ymm1, ymm3); + ymm4 = _mm256_addsub_pd(ymm4, ymm3); + + ymm3 = _mm256_permute_pd(ymm5, 5); + ymm5 = _mm256_mul_pd(ymm0, ymm5); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm5 = _mm256_addsub_pd(ymm5, ymm3); + + ymm3 = _mm256_permute_pd(ymm8, 5); + ymm8 = _mm256_mul_pd(ymm0, ymm8); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm8 = _mm256_addsub_pd(ymm8, ymm3); + + ymm3 = _mm256_permute_pd(ymm9, 5); + ymm9 = _mm256_mul_pd(ymm0, ymm9); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm9 = _mm256_addsub_pd(ymm9, ymm3); + + if(tc_inc_row == 1) //col stored + { + if(beta->real == 0.0 && beta->imag == 0.0) + { + //transpose left 2x2 + _mm_storeu_pd((double *)(tC ), _mm256_castpd256_pd128(ymm4)); + _mm_storeu_pd((double *)(tC+1), _mm256_castpd256_pd128(ymm8)); + tC += tc_inc_col; + + _mm_storeu_pd((double *)(tC ),_mm256_extractf128_pd (ymm4,1)); + _mm_storeu_pd((double *)(tC+1) ,_mm256_extractf128_pd (ymm8,1)); + tC += tc_inc_col; + + //transpose right 2x2 + _mm_storeu_pd((double *)(tC ), _mm256_castpd256_pd128(ymm5)); + _mm_storeu_pd((double *)(tC+1), _mm256_castpd256_pd128(ymm9)); + tC += tc_inc_col; + + _mm_storeu_pd((double *)(tC ),_mm256_extractf128_pd (ymm5,1)); + _mm_storeu_pd((double *)(tC+1) ,_mm256_extractf128_pd (ymm9,1)); + } + else{ + ymm1 = _mm256_broadcast_sd((double const *)(beta)); // load alpha_r and duplicate + ymm2 = _mm256_broadcast_sd((double const *)(&beta->imag)); // load alpha_i and duplicate + //Multiply ymm4 with beta + xmm0 = _mm_loadu_pd((double *)(tC)) ; + xmm3 = _mm_loadu_pd((double *)(tC + tc_inc_col)) ; + ymm0 = 
_mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_pd(ymm0, 5); + ymm0 = _mm256_mul_pd(ymm1, ymm0); + ymm3 = _mm256_mul_pd(ymm2, ymm3); + ymm0 = _mm256_addsub_pd(ymm0, ymm3); + ymm4 = _mm256_add_pd(ymm4, ymm0); + //Multiply ymm8 with beta + xmm0 = _mm_loadu_pd((double *)(tC + 1)) ; + xmm3 = _mm_loadu_pd((double *)(tC + 1 + tc_inc_col)) ; + ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_pd(ymm0, 5); + ymm0 = _mm256_mul_pd(ymm1, ymm0); + ymm3 = _mm256_mul_pd(ymm2, ymm3); + ymm0 = _mm256_addsub_pd(ymm0, ymm3); + ymm8 = _mm256_add_pd(ymm8, ymm0); + + //transpose left 2x2 + _mm_storeu_pd((double *)(tC), _mm256_castpd256_pd128(ymm4)); + _mm_storeu_pd((double *)(tC+1), _mm256_castpd256_pd128(ymm8)); + tC += tc_inc_col; + + _mm_storeu_pd((double *)(tC ) ,_mm256_extractf128_pd (ymm4,1)); + _mm_storeu_pd((double *)(tC+1) ,_mm256_extractf128_pd (ymm8,1)); + tC += tc_inc_col; + + + //Multiply ymm5 with beta + xmm0 = _mm_loadu_pd((double *)(tC)) ; + xmm3 = _mm_loadu_pd((double *)(tC + tc_inc_col)) ; + ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_pd(ymm0, 5); + ymm0 = _mm256_mul_pd(ymm1, ymm0); + ymm3 = _mm256_mul_pd(ymm2, ymm3); + ymm0 = _mm256_addsub_pd(ymm0, ymm3); + ymm5 = _mm256_add_pd(ymm5, ymm0); + //Multiply ymm9 with beta + xmm0 = _mm_loadu_pd((double *)(tC + 1)) ; + xmm3 = _mm_loadu_pd((double *)(tC + 1 + tc_inc_col)) ; + ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_pd(ymm0, 5); + ymm0 = _mm256_mul_pd(ymm1, ymm0); + ymm3 = _mm256_mul_pd(ymm2, ymm3); + ymm0 = _mm256_addsub_pd(ymm0, ymm3); + ymm9 = _mm256_add_pd(ymm9, ymm0); + + //transpose right 2x2 + _mm_storeu_pd((double *)(tC), _mm256_castpd256_pd128(ymm5)); + _mm_storeu_pd((double *)(tC+1), _mm256_castpd256_pd128(ymm9)); + tC += tc_inc_col; + + _mm_storeu_pd((double *)(tC ) ,_mm256_extractf128_pd (ymm5,1)); + _mm_storeu_pd((double *)(tC+1) ,_mm256_extractf128_pd (ymm9,1)); + } + + } + else + { + if(beta->real == 0.0 && beta->imag == 0.0) + { + _mm256_storeu_pd((double *)(tC), ymm4); + _mm256_storeu_pd((double *)(tC + 2), ymm5); + _mm256_storeu_pd((double *)(tC + tc_inc_row) , ymm8); + _mm256_storeu_pd((double *)(tC + tc_inc_row + 2), ymm9); + } + else{ + /* (br + bi) C + (ar + ai) AB */ + ymm0 = _mm256_broadcast_sd((double const *)(beta)); // load beta_r and duplicate + ymm1 = _mm256_broadcast_sd((double const *)(&beta->imag)); // load beta_i and duplicate + + ymm2 = _mm256_loadu_pd((double const *)(tC)); + ymm3 = _mm256_permute_pd(ymm2, 5); + ymm2 = _mm256_mul_pd(ymm0, ymm2); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm4 = _mm256_add_pd(ymm4, _mm256_addsub_pd(ymm2, ymm3)); + + ymm2 = _mm256_loadu_pd((double const *)(tC+2)); + ymm3 = _mm256_permute_pd(ymm2, 5); + ymm2 = _mm256_mul_pd(ymm0, ymm2); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm5 = _mm256_add_pd(ymm5, _mm256_addsub_pd(ymm2, ymm3)); + + ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row)); + ymm3 = _mm256_permute_pd(ymm2, 5); + ymm2 = _mm256_mul_pd(ymm0, ymm2); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm8 = _mm256_add_pd(ymm8, _mm256_addsub_pd(ymm2, ymm3)); + + ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row + 2)); + ymm3 = _mm256_permute_pd(ymm2, 5); + ymm2 = _mm256_mul_pd(ymm0, ymm2); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm9 = _mm256_add_pd(ymm9, _mm256_addsub_pd(ymm2, ymm3)); + + _mm256_storeu_pd((double *)(tC), ymm4); + _mm256_storeu_pd((double *)(tC + 2), ymm5); + _mm256_storeu_pd((double *)(tC + 
tc_inc_row) , ymm8); + _mm256_storeu_pd((double *)(tC + tc_inc_row + 2), ymm9); + } + } + } + + consider_edge_cases: + // Handle edge cases in the m dimension, if they exist. + if ( n_left ) + { + const dim_t mr_cur = 3; + const dim_t j_edge = n0 - ( dim_t )n_left; + + dcomplex* restrict cij = c + j_edge*cs_c; + dcomplex* restrict ai = a; + dcomplex* restrict bj = b + n_iter * 4; + + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + + bli_zgemmsup_rv_zen_asm_2x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { + bli_zgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, m0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); + } + } } @@ -848,215 +832,215 @@ void bli_zgemmsup_rv_zen_asm_1x4n cntx_t* restrict cntx ) { - //void* a_next = bli_auxinfo_next_a( data ); - //void* b_next = bli_auxinfo_next_b( data ); - - // Typecast local copies of integers in case dim_t and inc_t are a - // different size than is expected by load instructions. - - uint64_t k_iter = 0; - - uint64_t n_iter = n0 / 4; - uint64_t n_left = n0 % 4; - - uint64_t cs_a = cs_a0; - uint64_t rs_b = rs_b0; - uint64_t cs_b = cs_b0; - uint64_t rs_c = rs_c0; - uint64_t cs_c = cs_c0; - - - if ( n_iter == 0 ) goto consider_edge_cases; - - // ------------------------------------------------------------------------- - //scratch registers - __m256d ymm0, ymm1, ymm2, ymm3; - __m256d ymm4, ymm5, ymm6, ymm7; - __m128d xmm0, xmm3; - - dcomplex *tA = a; - double *tAimag = &a->imag; - dcomplex *tB = b; - dcomplex *tC = c; - for (n_iter = 0; n_iter < n0 / 4; n_iter++) - { - // clear scratch registers. - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - dim_t tb_inc_row = rs_b; - dim_t tc_inc_row = rs_c; - - dim_t ta_inc_col = cs_a; - dim_t tb_inc_col = cs_b; - dim_t tc_inc_col = cs_c; - - tA = a; - tAimag = &a->imag; - tB = b + n_iter*tb_inc_col*4; - tC = c + n_iter*tc_inc_col*4; - for (k_iter = 0; k_iter imag)); // load alpha_i and duplicate - - ymm3 = _mm256_permute_pd(ymm4, 5); - ymm4 = _mm256_mul_pd(ymm0, ymm4); - ymm3 =_mm256_mul_pd(ymm1, ymm3); - ymm4 = _mm256_addsub_pd(ymm4, ymm3); - - ymm3 = _mm256_permute_pd(ymm5, 5); - ymm5 = _mm256_mul_pd(ymm0, ymm5); - ymm3 = _mm256_mul_pd(ymm1, ymm3); - ymm5 = _mm256_addsub_pd(ymm5, ymm3); - - if(tc_inc_row == 1) //col stored - { - if(beta->real == 0.0 && beta->imag == 0.0) - { - //transpose left 1x2 - _mm_storeu_pd((double *)(tC), _mm256_castpd256_pd128(ymm4)); - tC += tc_inc_col; - - _mm_storeu_pd((double *)(tC) ,_mm256_extractf128_pd (ymm4,1)); - tC += tc_inc_col; - - //transpose right 1x2 - _mm_storeu_pd((double *)(tC), _mm256_castpd256_pd128(ymm5)); - tC += tc_inc_col; - - _mm_storeu_pd((double *)(tC) ,_mm256_extractf128_pd (ymm5,1)); - } - else{ - ymm1 = _mm256_broadcast_sd((double const *)(beta)); // load alpha_r and duplicate - ymm2 = _mm256_broadcast_sd((double const *)(&beta->imag)); // load alpha_i and duplicate - //Multiply ymm4 with beta - xmm0 = _mm_loadu_pd((double *)(tC)) ; - xmm3 = _mm_loadu_pd((double *)(tC + tc_inc_col)) ; - ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; - ymm3 = _mm256_permute_pd(ymm0, 5); - ymm0 = _mm256_mul_pd(ymm1, ymm0); - ymm3 = _mm256_mul_pd(ymm2, ymm3); - ymm0 = _mm256_addsub_pd(ymm0, ymm3); - ymm4 = _mm256_add_pd(ymm4, ymm0); - - _mm_storeu_pd((double *)(tC), 
_mm256_castpd256_pd128(ymm4)); - tC += tc_inc_col; - - _mm_storeu_pd((double *)(tC ) ,_mm256_extractf128_pd (ymm4,1)); - tC += tc_inc_col; - - //Multiply ymm5 with beta - xmm0 = _mm_loadu_pd((double *)(tC)) ; - xmm3 = _mm_loadu_pd((double *)(tC + tc_inc_col)) ; - ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; - ymm3 = _mm256_permute_pd(ymm0, 5); - ymm0 = _mm256_mul_pd(ymm1, ymm0); - ymm3 = _mm256_mul_pd(ymm2, ymm3); - ymm0 = _mm256_addsub_pd(ymm0, ymm3); - ymm5 = _mm256_add_pd(ymm5, ymm0); - - _mm_storeu_pd((double *)(tC), _mm256_castpd256_pd128(ymm5)); - tC += tc_inc_col; - - _mm_storeu_pd((double *)(tC) ,_mm256_extractf128_pd (ymm5,1)); - } - - } - else - { - if(beta->real == 0.0 && beta->imag == 0.0) - { - _mm256_storeu_pd((double *)(tC), ymm4); - _mm256_storeu_pd((double *)(tC + 2), ymm5); - } - else{ - /* (br + bi) C + (ar + ai) AB */ - ymm0 = _mm256_broadcast_sd((double const *)(beta)); // load beta_r and duplicate - ymm1 = _mm256_broadcast_sd((double const *)(&beta->imag)); // load beta_i and duplicate - - ymm2 = _mm256_loadu_pd((double const *)(tC)); - ymm3 = _mm256_permute_pd(ymm2, 5); - ymm2 = _mm256_mul_pd(ymm0, ymm2); - ymm3 =_mm256_mul_pd(ymm1, ymm3); - ymm4 = _mm256_add_pd(ymm4, _mm256_addsub_pd(ymm2, ymm3)); - - ymm2 = _mm256_loadu_pd((double const *)(tC+2)); - ymm3 = _mm256_permute_pd(ymm2, 5); - ymm2 = _mm256_mul_pd(ymm0, ymm2); - ymm3 = _mm256_mul_pd(ymm1, ymm3); - ymm5 = _mm256_add_pd(ymm5, _mm256_addsub_pd(ymm2, ymm3)); - - _mm256_storeu_pd((double *)(tC), ymm4); - _mm256_storeu_pd((double *)(tC + 2), ymm5); - } - } - } - - consider_edge_cases: - // Handle edge cases in the m dimension, if they exist. - if ( n_left ) - { - const dim_t mr_cur = 3; - const dim_t j_edge = n0 - ( dim_t )n_left; - - dcomplex* restrict cij = c + j_edge*cs_c; - dcomplex* restrict ai = a; - dcomplex* restrict bj = b + n_iter * 4; - - if ( 2 <= n_left ) - { - const dim_t nr_cur = 2; - bli_zgemmsup_rv_zen_asm_1x2 - ( - conja, conjb, mr_cur, nr_cur, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, - beta, cij, rs_c0, cs_c0, data, cntx - ); - cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; - } - if ( 1 == n_left ) - { - bli_zgemv_ex - ( - BLIS_NO_TRANSPOSE, conjb, m0, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, - beta, cij, rs_c0, cntx, NULL - ); - } - } + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + + uint64_t k_iter = 0; + + uint64_t n_iter = n0 / 4; + uint64_t n_left = n0 % 4; + + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + + if ( n_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + //scratch registers + __m256d ymm0, ymm1, ymm2, ymm3; + __m256d ymm4, ymm5, ymm6, ymm7; + __m128d xmm0, xmm3; + + dcomplex *tA = a; + double *tAimag = &a->imag; + dcomplex *tB = b; + dcomplex *tC = c; + for (n_iter = 0; n_iter < n0 / 4; n_iter++) + { + // clear scratch registers. 
+ ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + + dim_t tb_inc_row = rs_b; + dim_t tc_inc_row = rs_c; + + dim_t ta_inc_col = cs_a; + dim_t tb_inc_col = cs_b; + dim_t tc_inc_col = cs_c; + + tA = a; + tAimag = &a->imag; + tB = b + n_iter*tb_inc_col*4; + tC = c + n_iter*tc_inc_col*4; + for (k_iter = 0; k_iter imag)); // load alpha_i and duplicate + + ymm3 = _mm256_permute_pd(ymm4, 5); + ymm4 = _mm256_mul_pd(ymm0, ymm4); + ymm3 =_mm256_mul_pd(ymm1, ymm3); + ymm4 = _mm256_addsub_pd(ymm4, ymm3); + + ymm3 = _mm256_permute_pd(ymm5, 5); + ymm5 = _mm256_mul_pd(ymm0, ymm5); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm5 = _mm256_addsub_pd(ymm5, ymm3); + + if(tc_inc_row == 1) //col stored + { + if(beta->real == 0.0 && beta->imag == 0.0) + { + //transpose left 1x2 + _mm_storeu_pd((double *)(tC), _mm256_castpd256_pd128(ymm4)); + tC += tc_inc_col; + + _mm_storeu_pd((double *)(tC) ,_mm256_extractf128_pd (ymm4,1)); + tC += tc_inc_col; + + //transpose right 1x2 + _mm_storeu_pd((double *)(tC), _mm256_castpd256_pd128(ymm5)); + tC += tc_inc_col; + + _mm_storeu_pd((double *)(tC) ,_mm256_extractf128_pd (ymm5,1)); + } + else{ + ymm1 = _mm256_broadcast_sd((double const *)(beta)); // load alpha_r and duplicate + ymm2 = _mm256_broadcast_sd((double const *)(&beta->imag)); // load alpha_i and duplicate + //Multiply ymm4 with beta + xmm0 = _mm_loadu_pd((double *)(tC)) ; + xmm3 = _mm_loadu_pd((double *)(tC + tc_inc_col)) ; + ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_pd(ymm0, 5); + ymm0 = _mm256_mul_pd(ymm1, ymm0); + ymm3 = _mm256_mul_pd(ymm2, ymm3); + ymm0 = _mm256_addsub_pd(ymm0, ymm3); + ymm4 = _mm256_add_pd(ymm4, ymm0); + + _mm_storeu_pd((double *)(tC), _mm256_castpd256_pd128(ymm4)); + tC += tc_inc_col; + + _mm_storeu_pd((double *)(tC ) ,_mm256_extractf128_pd (ymm4,1)); + tC += tc_inc_col; + + //Multiply ymm5 with beta + xmm0 = _mm_loadu_pd((double *)(tC)) ; + xmm3 = _mm_loadu_pd((double *)(tC + tc_inc_col)) ; + ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_pd(ymm0, 5); + ymm0 = _mm256_mul_pd(ymm1, ymm0); + ymm3 = _mm256_mul_pd(ymm2, ymm3); + ymm0 = _mm256_addsub_pd(ymm0, ymm3); + ymm5 = _mm256_add_pd(ymm5, ymm0); + + _mm_storeu_pd((double *)(tC), _mm256_castpd256_pd128(ymm5)); + tC += tc_inc_col; + + _mm_storeu_pd((double *)(tC) ,_mm256_extractf128_pd (ymm5,1)); + } + + } + else + { + if(beta->real == 0.0 && beta->imag == 0.0) + { + _mm256_storeu_pd((double *)(tC), ymm4); + _mm256_storeu_pd((double *)(tC + 2), ymm5); + } + else{ + /* (br + bi) C + (ar + ai) AB */ + ymm0 = _mm256_broadcast_sd((double const *)(beta)); // load beta_r and duplicate + ymm1 = _mm256_broadcast_sd((double const *)(&beta->imag)); // load beta_i and duplicate + + ymm2 = _mm256_loadu_pd((double const *)(tC)); + ymm3 = _mm256_permute_pd(ymm2, 5); + ymm2 = _mm256_mul_pd(ymm0, ymm2); + ymm3 =_mm256_mul_pd(ymm1, ymm3); + ymm4 = _mm256_add_pd(ymm4, _mm256_addsub_pd(ymm2, ymm3)); + + ymm2 = _mm256_loadu_pd((double const *)(tC+2)); + ymm3 = _mm256_permute_pd(ymm2, 5); + ymm2 = _mm256_mul_pd(ymm0, ymm2); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm5 = _mm256_add_pd(ymm5, _mm256_addsub_pd(ymm2, ymm3)); + + _mm256_storeu_pd((double *)(tC), ymm4); + _mm256_storeu_pd((double *)(tC + 2), ymm5); + } + } + } + + consider_edge_cases: + // Handle edge cases in the m dimension, if they exist. 
+ if ( n_left ) + { + const dim_t mr_cur = 3; + const dim_t j_edge = n0 - ( dim_t )n_left; + + dcomplex* restrict cij = c + j_edge*cs_c; + dcomplex* restrict ai = a; + dcomplex* restrict bj = b + n_iter * 4; + + if ( 2 <= n_left ) + { + const dim_t nr_cur = 2; + bli_zgemmsup_rv_zen_asm_1x2 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + cij += nr_cur*cs_c0; bj += nr_cur*cs_b0; n_left -= nr_cur; + } + if ( 1 == n_left ) + { + bli_zgemv_ex + ( + BLIS_NO_TRANSPOSE, conjb, m0, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, + beta, cij, rs_c0, cntx, NULL + ); + } + } } void bli_zgemmsup_rv_zen_asm_3x2 @@ -1075,194 +1059,194 @@ void bli_zgemmsup_rv_zen_asm_3x2 cntx_t* restrict cntx ) { - uint64_t k_iter = 0; - - uint64_t rs_a = rs_a0; - uint64_t cs_a = cs_a0; - uint64_t rs_b = rs_b0; - uint64_t rs_c = rs_c0; - uint64_t cs_c = cs_c0; - - - - // ------------------------------------------------------------------------- - //scratch registers - __m256d ymm0, ymm1, ymm2, ymm3; - __m256d ymm4, ymm6; - __m256d ymm8, ymm10; - __m256d ymm12, ymm14; - __m128d xmm0, xmm3; - - dcomplex *tA = a; - double *tAimag = &a->imag; - dcomplex *tB = b; - dcomplex *tC = c; - // clear scratch registers. - ymm4 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm12 = _mm256_setzero_pd(); - ymm14 = _mm256_setzero_pd(); - - dim_t ta_inc_row = rs_a; - dim_t tb_inc_row = rs_b; - dim_t tc_inc_row = rs_c; - - dim_t ta_inc_col = cs_a; - dim_t tc_inc_col = cs_c; - - for (k_iter = 0; k_iter imag)); // load alpha_i and duplicate - - ymm3 = _mm256_permute_pd(ymm4, 5); - ymm4 = _mm256_mul_pd(ymm0, ymm4); - ymm3 =_mm256_mul_pd(ymm1, ymm3); - ymm4 = _mm256_addsub_pd(ymm4, ymm3); - - ymm3 = _mm256_permute_pd(ymm8, 5); - ymm8 = _mm256_mul_pd(ymm0, ymm8); - ymm3 = _mm256_mul_pd(ymm1, ymm3); - ymm8 = _mm256_addsub_pd(ymm8, ymm3); - - ymm3 = _mm256_permute_pd(ymm12, 5); - ymm12 = _mm256_mul_pd(ymm0, ymm12); - ymm3 = _mm256_mul_pd(ymm1, ymm3); - ymm12 = _mm256_addsub_pd(ymm12, ymm3); - - if(tc_inc_row == 1) //col stored - { - if(beta->real == 0.0 && beta->imag == 0.0) - { - //transpose left 3x2 - _mm_storeu_pd((double *)(tC), _mm256_castpd256_pd128(ymm4)); - _mm_storeu_pd((double *)(tC+1), _mm256_castpd256_pd128(ymm8)); - _mm_storeu_pd((double *)(tC+2), _mm256_castpd256_pd128(ymm12)); - tC += tc_inc_col; - - _mm_storeu_pd((double *)(tC ),_mm256_extractf128_pd (ymm4,1)); - _mm_storeu_pd((double *)(tC+1) ,_mm256_extractf128_pd (ymm8,1)); - _mm_storeu_pd((double *)(tC+2), _mm256_extractf128_pd(ymm12, 1)); - } - else{ - ymm1 = _mm256_broadcast_sd((double const *)(beta)); // load alpha_r and duplicate - ymm2 = _mm256_broadcast_sd((double const *)(&beta->imag)); // load alpha_i and duplicate - //Multiply ymm4 with beta - xmm0 = _mm_loadu_pd((double *)(tC)) ; - xmm3 = _mm_loadu_pd((double *)(tC + tc_inc_col)) ; - ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; - ymm3 = _mm256_permute_pd(ymm0, 5); - ymm0 = _mm256_mul_pd(ymm1, ymm0); - ymm3 = _mm256_mul_pd(ymm2, ymm3); - ymm0 = _mm256_addsub_pd(ymm0, ymm3); - ymm4 = _mm256_add_pd(ymm4, ymm0); - //Multiply ymm8 with beta - xmm0 = _mm_loadu_pd((double *)(tC + 1)) ; - xmm3 = _mm_loadu_pd((double *)(tC + 1 + tc_inc_col)) ; - ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; - ymm3 = _mm256_permute_pd(ymm0, 5); - ymm0 = _mm256_mul_pd(ymm1, ymm0); - ymm3 = _mm256_mul_pd(ymm2, ymm3); - ymm0 = _mm256_addsub_pd(ymm0, ymm3); - 
ymm8 = _mm256_add_pd(ymm8, ymm0); - - //Multiply ymm12 with beta - xmm0 = _mm_loadu_pd((double *)(tC + 2)) ; - xmm3 = _mm_loadu_pd((double *)(tC + 2 + tc_inc_col)) ; - ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; - ymm3 = _mm256_permute_pd(ymm0, 5); - ymm0 = _mm256_mul_pd(ymm1, ymm0); - ymm3 = _mm256_mul_pd(ymm2, ymm3); - ymm0 = _mm256_addsub_pd(ymm0, ymm3); - ymm12 = _mm256_add_pd(ymm12, ymm0); - - _mm_storeu_pd((double *)(tC), _mm256_castpd256_pd128(ymm4)); - _mm_storeu_pd((double *)(tC+1), _mm256_castpd256_pd128(ymm8)); - _mm_storeu_pd((double *)(tC+2), _mm256_castpd256_pd128(ymm12)); - tC += tc_inc_col; - _mm_storeu_pd((double *)(tC ),_mm256_extractf128_pd (ymm4,1)); - _mm_storeu_pd((double *)(tC+1) ,_mm256_extractf128_pd (ymm8,1)); - _mm_storeu_pd((double *)(tC+2), _mm256_extractf128_pd(ymm12, 1)); - } - } - else - { - if(beta->real == 0.0 && beta->imag == 0.0) - { - _mm256_storeu_pd((double *)(tC), ymm4); - _mm256_storeu_pd((double *)(tC + tc_inc_row ), ymm8); - _mm256_storeu_pd((double *)(tC + tc_inc_row *2), ymm12); - } - else{ - /* (br + bi) C + (ar + ai) AB */ - ymm0 = _mm256_broadcast_sd((double const *)(beta)); // load beta_r and duplicate - ymm1 = _mm256_broadcast_sd((double const *)(&beta->imag)); // load beta_i and duplicate - - ymm2 = _mm256_loadu_pd((double const *)(tC)); - ymm3 = _mm256_permute_pd(ymm2, 5); - ymm2 = _mm256_mul_pd(ymm0, ymm2); - ymm3 =_mm256_mul_pd(ymm1, ymm3); - ymm4 = _mm256_add_pd(ymm4, _mm256_addsub_pd(ymm2, ymm3)); - - ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row)); - ymm3 = _mm256_permute_pd(ymm2, 5); - ymm2 = _mm256_mul_pd(ymm0, ymm2); - ymm3 = _mm256_mul_pd(ymm1, ymm3); - ymm8 = _mm256_add_pd(ymm8, _mm256_addsub_pd(ymm2, ymm3)); - - ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row*2)); - ymm3 = _mm256_permute_pd(ymm2, 5); - ymm2 = _mm256_mul_pd(ymm0, ymm2); - ymm3 = _mm256_mul_pd(ymm1, ymm3); - ymm12 = _mm256_add_pd(ymm12, _mm256_addsub_pd(ymm2, ymm3)); - - _mm256_storeu_pd((double *)(tC), ymm4); - _mm256_storeu_pd((double *)(tC + tc_inc_row) , ymm8); - _mm256_storeu_pd((double *)(tC + tc_inc_row *2), ymm12); - } - } + uint64_t k_iter = 0; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + + + // ------------------------------------------------------------------------- + //scratch registers + __m256d ymm0, ymm1, ymm2, ymm3; + __m256d ymm4, ymm6; + __m256d ymm8, ymm10; + __m256d ymm12, ymm14; + __m128d xmm0, xmm3; + + dcomplex *tA = a; + double *tAimag = &a->imag; + dcomplex *tB = b; + dcomplex *tC = c; + // clear scratch registers. 
+ ymm4 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + + dim_t ta_inc_row = rs_a; + dim_t tb_inc_row = rs_b; + dim_t tc_inc_row = rs_c; + + dim_t ta_inc_col = cs_a; + dim_t tc_inc_col = cs_c; + + for (k_iter = 0; k_iter imag)); // load alpha_i and duplicate + + ymm3 = _mm256_permute_pd(ymm4, 5); + ymm4 = _mm256_mul_pd(ymm0, ymm4); + ymm3 =_mm256_mul_pd(ymm1, ymm3); + ymm4 = _mm256_addsub_pd(ymm4, ymm3); + + ymm3 = _mm256_permute_pd(ymm8, 5); + ymm8 = _mm256_mul_pd(ymm0, ymm8); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm8 = _mm256_addsub_pd(ymm8, ymm3); + + ymm3 = _mm256_permute_pd(ymm12, 5); + ymm12 = _mm256_mul_pd(ymm0, ymm12); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm12 = _mm256_addsub_pd(ymm12, ymm3); + + if(tc_inc_row == 1) //col stored + { + if(beta->real == 0.0 && beta->imag == 0.0) + { + //transpose left 3x2 + _mm_storeu_pd((double *)(tC), _mm256_castpd256_pd128(ymm4)); + _mm_storeu_pd((double *)(tC+1), _mm256_castpd256_pd128(ymm8)); + _mm_storeu_pd((double *)(tC+2), _mm256_castpd256_pd128(ymm12)); + tC += tc_inc_col; + + _mm_storeu_pd((double *)(tC ),_mm256_extractf128_pd (ymm4,1)); + _mm_storeu_pd((double *)(tC+1) ,_mm256_extractf128_pd (ymm8,1)); + _mm_storeu_pd((double *)(tC+2), _mm256_extractf128_pd(ymm12, 1)); + } + else{ + ymm1 = _mm256_broadcast_sd((double const *)(beta)); // load alpha_r and duplicate + ymm2 = _mm256_broadcast_sd((double const *)(&beta->imag)); // load alpha_i and duplicate + //Multiply ymm4 with beta + xmm0 = _mm_loadu_pd((double *)(tC)) ; + xmm3 = _mm_loadu_pd((double *)(tC + tc_inc_col)) ; + ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_pd(ymm0, 5); + ymm0 = _mm256_mul_pd(ymm1, ymm0); + ymm3 = _mm256_mul_pd(ymm2, ymm3); + ymm0 = _mm256_addsub_pd(ymm0, ymm3); + ymm4 = _mm256_add_pd(ymm4, ymm0); + //Multiply ymm8 with beta + xmm0 = _mm_loadu_pd((double *)(tC + 1)) ; + xmm3 = _mm_loadu_pd((double *)(tC + 1 + tc_inc_col)) ; + ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_pd(ymm0, 5); + ymm0 = _mm256_mul_pd(ymm1, ymm0); + ymm3 = _mm256_mul_pd(ymm2, ymm3); + ymm0 = _mm256_addsub_pd(ymm0, ymm3); + ymm8 = _mm256_add_pd(ymm8, ymm0); + + //Multiply ymm12 with beta + xmm0 = _mm_loadu_pd((double *)(tC + 2)) ; + xmm3 = _mm_loadu_pd((double *)(tC + 2 + tc_inc_col)) ; + ymm0 = _mm256_insertf128_pd(_mm256_castpd128_pd256(xmm0), xmm3, 1) ; + ymm3 = _mm256_permute_pd(ymm0, 5); + ymm0 = _mm256_mul_pd(ymm1, ymm0); + ymm3 = _mm256_mul_pd(ymm2, ymm3); + ymm0 = _mm256_addsub_pd(ymm0, ymm3); + ymm12 = _mm256_add_pd(ymm12, ymm0); + + _mm_storeu_pd((double *)(tC), _mm256_castpd256_pd128(ymm4)); + _mm_storeu_pd((double *)(tC+1), _mm256_castpd256_pd128(ymm8)); + _mm_storeu_pd((double *)(tC+2), _mm256_castpd256_pd128(ymm12)); + tC += tc_inc_col; + _mm_storeu_pd((double *)(tC ),_mm256_extractf128_pd (ymm4,1)); + _mm_storeu_pd((double *)(tC+1) ,_mm256_extractf128_pd (ymm8,1)); + _mm_storeu_pd((double *)(tC+2), _mm256_extractf128_pd(ymm12, 1)); + } + } + else + { + if(beta->real == 0.0 && beta->imag == 0.0) + { + _mm256_storeu_pd((double *)(tC), ymm4); + _mm256_storeu_pd((double *)(tC + tc_inc_row ), ymm8); + _mm256_storeu_pd((double *)(tC + tc_inc_row *2), ymm12); + } + else{ + /* (br + bi) C + (ar + ai) AB */ + ymm0 = _mm256_broadcast_sd((double const *)(beta)); // load beta_r and duplicate + ymm1 = _mm256_broadcast_sd((double const *)(&beta->imag)); // load beta_i 
and duplicate + + ymm2 = _mm256_loadu_pd((double const *)(tC)); + ymm3 = _mm256_permute_pd(ymm2, 5); + ymm2 = _mm256_mul_pd(ymm0, ymm2); + ymm3 =_mm256_mul_pd(ymm1, ymm3); + ymm4 = _mm256_add_pd(ymm4, _mm256_addsub_pd(ymm2, ymm3)); + + ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row)); + ymm3 = _mm256_permute_pd(ymm2, 5); + ymm2 = _mm256_mul_pd(ymm0, ymm2); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm8 = _mm256_add_pd(ymm8, _mm256_addsub_pd(ymm2, ymm3)); + + ymm2 = _mm256_loadu_pd((double const *)(tC+tc_inc_row*2)); + ymm3 = _mm256_permute_pd(ymm2, 5); + ymm2 = _mm256_mul_pd(ymm0, ymm2); + ymm3 = _mm256_mul_pd(ymm1, ymm3); + ymm12 = _mm256_add_pd(ymm12, _mm256_addsub_pd(ymm2, ymm3)); + + _mm256_storeu_pd((double *)(tC), ymm4); + _mm256_storeu_pd((double *)(tC + tc_inc_row) , ymm8); + _mm256_storeu_pd((double *)(tC + tc_inc_row *2), ymm12); + } + } } From 595f7b7edf0981e4b03bb01a864d59d264896f90 Mon Sep 17 00:00:00 2001 From: mkurumel Date: Thu, 5 Aug 2021 17:58:20 +0530 Subject: [PATCH 026/243] dnrm2 optimization with dot method 1. Added new kernel bli_dnorm2fv_unb_var1 kernel to compute norm with dot operation. 2. Added vectorization to compute square of 32 double element block size from vector X. 3. Defined a new Macro BLIS_ENABLE_DNRM2_FAST under config header to compute nrm2 using new kernel. 4. Dot kernel definitions and implementation have a possibility for accuracy issues .we can switch to traditional implementation by disabling the MACRO BLIS_ENABLE_DNRM2_FAST to compute L2-norm for Vector X . AMD-Internal: [CPUPL-1757] Change-Id: I1adcaf1b3b4e33837758593c998c25705ff0fe11 --- config/amdepyc/bli_family_amdepyc.h | 2 + config/zen/bli_family_zen.h | 4 +- config/zen2/bli_family_zen2.h | 2 +- config/zen3/bli_family_zen3.h | 2 + frame/util/bli_util_unb_var1.c | 1682 ++++++++++++++------------- kernels/zen/1/CMakeLists.txt | 1 + kernels/zen/1/bli_norm2_zen_int.c | 236 ++++ kernels/zen/bli_kernels_zen.h | 14 +- 8 files changed, 1111 insertions(+), 832 deletions(-) create mode 100644 kernels/zen/1/bli_norm2_zen_int.c diff --git a/config/amdepyc/bli_family_amdepyc.h b/config/amdepyc/bli_family_amdepyc.h index c3f4370692..5ae4460442 100644 --- a/config/amdepyc/bli_family_amdepyc.h +++ b/config/amdepyc/bli_family_amdepyc.h @@ -60,5 +60,7 @@ // BLIS), defining this macro as 1 yields better performance. #define AOCL_BLIS_MULTIINSTANCE 0 +#define BLIS_ENABLE_FAST_MATH + #endif diff --git a/config/zen/bli_family_zen.h b/config/zen/bli_family_zen.h index 737ca5c597..a166125889 100644 --- a/config/zen/bli_family_zen.h +++ b/config/zen/bli_family_zen.h @@ -53,4 +53,6 @@ #define BLIS_SMALL_MATRIX_A_THRES_M_SYRK 96 #define BLIS_SMALL_MATRIX_A_THRES_N_SYRK 128 -#endif \ No newline at end of file +#define BLIS_ENABLE_FAST_MATH + +#endif diff --git a/config/zen2/bli_family_zen2.h b/config/zen2/bli_family_zen2.h index fedc422ad1..dbaa8f4f73 100644 --- a/config/zen2/bli_family_zen2.h +++ b/config/zen2/bli_family_zen2.h @@ -56,6 +56,6 @@ // When running HPL with pure MPI without DGEMM threading (Single-threaded // BLIS), defining this macro as 1 yields better performance. 
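/*
 * Sketch of the nrm2-with-dot scheme this patch describes (see the commit
 * message above and the normfv_unb_var1 changes below): accumulate x[i]*x[i]
 * directly in one pass -- the actual kernel does this on vectorized blocks --
 * and fall back to the classic scale/ssq update only if the straight
 * accumulation raised an overflow/invalid floating-point exception, mirroring
 * the fetestexcept( FE_OVERFLOW | FE_INVALID ) check in bli_util_unb_var1.c.
 * This is an illustrative standalone routine under those assumptions, not the
 * bli_dnorm2fv_unb_var1 kernel itself; the function name and the scalar loops
 * are assumptions made for the sketch.
 */
#include <fenv.h>
#include <math.h>
#include <stddef.h>

static double dnrm2_dot_sketch( size_t n, const double* x, size_t incx )
{
    feclearexcept( FE_ALL_EXCEPT );

    /* Fast path: plain dot product x.x, one pass over the data. */
    double sumsq = 0.0;
    for ( size_t i = 0; i < n; ++i )
        sumsq += x[i*incx] * x[i*incx];

    if ( !fetestexcept( FE_OVERFLOW | FE_INVALID ) )
        return sqrt( sumsq );

    /* Fallback: scale/ssq update, safe against intermediate overflow.
       The commit message notes the dot-based path may have accuracy
       issues; disabling the fast-math macro restores the traditional
       implementation. */
    double scale = 0.0, ssq = 1.0;
    for ( size_t i = 0; i < n; ++i )
    {
        double xi = fabs( x[i*incx] );
        if ( xi == 0.0 ) continue;
        if ( scale < xi )
        {
            ssq   = 1.0 + ssq * ( scale / xi ) * ( scale / xi );
            scale = xi;
        }
        else
            ssq += ( xi / scale ) * ( xi / scale );
    }
    return scale * sqrt( ssq );
}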
#define AOCL_BLIS_MULTIINSTANCE 0 +#define BLIS_ENABLE_FAST_MATH #endif - diff --git a/config/zen3/bli_family_zen3.h b/config/zen3/bli_family_zen3.h index 0a6b210d62..78e2c9de97 100644 --- a/config/zen3/bli_family_zen3.h +++ b/config/zen3/bli_family_zen3.h @@ -55,4 +55,6 @@ #define BLIS_SMALL_MATRIX_A_THRES_M_SYRK 96 #define BLIS_SMALL_MATRIX_A_THRES_N_SYRK 128 +#define BLIS_ENABLE_FAST_MATH + #endif diff --git a/frame/util/bli_util_unb_var1.c b/frame/util/bli_util_unb_var1.c index e4042dd3b1..a2166b7b1f 100644 --- a/frame/util/bli_util_unb_var1.c +++ b/frame/util/bli_util_unb_var1.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2021, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -52,33 +52,33 @@ void PASTEMAC(ch,varname) \ rntm_t* rntm \ ) \ { \ - ctype* chi1; \ - ctype_r chi1_r; \ - ctype_r chi1_i; \ - ctype_r absum; \ - dim_t i; \ + ctype* chi1; \ + ctype_r chi1_r; \ + ctype_r chi1_i; \ + ctype_r absum; \ + dim_t i; \ \ - /* Initialize the absolute sum accumulator to zero. */ \ - PASTEMAC(chr,set0s)( absum ); \ + /* Initialize the absolute sum accumulator to zero. */ \ + PASTEMAC(chr,set0s)( absum ); \ \ - for ( i = 0; i < n; ++i ) \ - { \ - chi1 = x + (i )*incx; \ + for ( i = 0; i < n; ++i ) \ + { \ + chi1 = x + (i )*incx; \ \ - /* Get the real and imaginary components of chi1. */ \ - PASTEMAC2(ch,chr,gets)( *chi1, chi1_r, chi1_i ); \ + /* Get the real and imaginary components of chi1. */ \ + PASTEMAC2(ch,chr,gets)( *chi1, chi1_r, chi1_i ); \ \ - /* Replace chi1_r and chi1_i with their absolute values. */ \ - chi1_r = bli_fabs( chi1_r ); \ - chi1_i = bli_fabs( chi1_i ); \ + /* Replace chi1_r and chi1_i with their absolute values. */ \ + chi1_r = bli_fabs( chi1_r ); \ + chi1_i = bli_fabs( chi1_i ); \ \ - /* Accumulate the real and imaginary components into absum. */ \ - PASTEMAC(chr,adds)( chi1_r, absum ); \ - PASTEMAC(chr,adds)( chi1_i, absum ); \ - } \ + /* Accumulate the real and imaginary components into absum. */ \ + PASTEMAC(chr,adds)( chi1_r, absum ); \ + PASTEMAC(chr,adds)( chi1_i, absum ); \ + } \ \ - /* Store the final value of absum to the output variable. */ \ - PASTEMAC(chr,copys)( absum, *asum ); \ + /* Store the final value of absum to the output variable. */ \ + PASTEMAC(chr,copys)( absum, *asum ); \ } INSERT_GENTFUNCR_BASIC0( asumv_unb_var1 ) @@ -96,45 +96,45 @@ void PASTEMAC(ch,varname) \ rntm_t* rntm \ ) \ { \ - ctype_r* zeror = PASTEMAC(chr,0); \ - doff_t diagoffa; \ -\ - /* If the dimension is zero, return early. */ \ - if ( bli_zero_dim1( m ) ) return; \ -\ - /* In order to avoid the main diagonal, we must nudge the diagonal either - up or down by one, depending on which triangle is currently stored. */ \ - if ( bli_is_upper( uploa ) ) diagoffa = 1; \ - else /*if ( bli_is_lower( uploa ) )*/ diagoffa = -1; \ -\ - /* We will be reflecting the stored region over the diagonal into the - unstored region, so a transposition is necessary. Furthermore, since - we are creating a Hermitian matrix, we must also conjugate. */ \ - PASTEMAC2(ch,copym,BLIS_TAPI_EX_SUF) \ - ( \ - diagoffa, \ - BLIS_NONUNIT_DIAG, \ - uploa, \ - BLIS_CONJ_TRANSPOSE, \ - m, \ - m, \ - a, rs_a, cs_a, \ - a, rs_a, cs_a, \ - cntx, \ - rntm \ - ); \ -\ - /* Set the imaginary parts of the diagonal elements to zero. 
*/ \ - PASTEMAC2(ch,setid,BLIS_TAPI_EX_SUF) \ - ( \ - 0, \ - m, \ - m, \ - zeror, \ - a, rs_a, cs_a, \ - cntx, \ - rntm \ - ); \ + ctype_r* zeror = PASTEMAC(chr,0); \ + doff_t diagoffa; \ +\ + /* If the dimension is zero, return early. */ \ + if ( bli_zero_dim1( m ) ) return; \ +\ + /* In order to avoid the main diagonal, we must nudge the diagonal either + up or down by one, depending on which triangle is currently stored. */ \ + if ( bli_is_upper( uploa ) ) diagoffa = 1; \ + else /*if ( bli_is_lower( uploa ) )*/ diagoffa = -1; \ +\ + /* We will be reflecting the stored region over the diagonal into the + unstored region, so a transposition is necessary. Furthermore, since + we are creating a Hermitian matrix, we must also conjugate. */ \ + PASTEMAC2(ch,copym,BLIS_TAPI_EX_SUF) \ + ( \ + diagoffa, \ + BLIS_NONUNIT_DIAG, \ + uploa, \ + BLIS_CONJ_TRANSPOSE, \ + m, \ + m, \ + a, rs_a, cs_a, \ + a, rs_a, cs_a, \ + cntx, \ + rntm \ + ); \ +\ + /* Set the imaginary parts of the diagonal elements to zero. */ \ + PASTEMAC2(ch,setid,BLIS_TAPI_EX_SUF) \ + ( \ + 0, \ + m, \ + m, \ + zeror, \ + a, rs_a, cs_a, \ + cntx, \ + rntm \ + ); \ } INSERT_GENTFUNCR_BASIC0( mkherm_unb_var1 ) @@ -152,31 +152,31 @@ void PASTEMAC(ch,varname) \ rntm_t* rntm \ ) \ { \ - doff_t diagoffa; \ -\ - /* If the dimension is zero, return early. */ \ - if ( bli_zero_dim1( m ) ) return; \ -\ - /* In order to avoid the main diagonal, we must nudge the diagonal either - up or down by one, depending on which triangle is currently stored. */ \ - if ( bli_is_upper( uploa ) ) diagoffa = 1; \ - else /*if ( bli_is_lower( uploa ) )*/ diagoffa = -1; \ -\ - /* We will be reflecting the stored region over the diagonal into the - unstored region, so a transposition is necessary. */ \ - PASTEMAC2(ch,copym,BLIS_TAPI_EX_SUF) \ - ( \ - diagoffa, \ - BLIS_NONUNIT_DIAG, \ - uploa, \ - BLIS_TRANSPOSE, \ - m, \ - m, \ - a, rs_a, cs_a, \ - a, rs_a, cs_a, \ - cntx, \ - rntm \ - ); \ + doff_t diagoffa; \ +\ + /* If the dimension is zero, return early. */ \ + if ( bli_zero_dim1( m ) ) return; \ +\ + /* In order to avoid the main diagonal, we must nudge the diagonal either + up or down by one, depending on which triangle is currently stored. */ \ + if ( bli_is_upper( uploa ) ) diagoffa = 1; \ + else /*if ( bli_is_lower( uploa ) )*/ diagoffa = -1; \ +\ + /* We will be reflecting the stored region over the diagonal into the + unstored region, so a transposition is necessary. */ \ + PASTEMAC2(ch,copym,BLIS_TAPI_EX_SUF) \ + ( \ + diagoffa, \ + BLIS_NONUNIT_DIAG, \ + uploa, \ + BLIS_TRANSPOSE, \ + m, \ + m, \ + a, rs_a, cs_a, \ + a, rs_a, cs_a, \ + cntx, \ + rntm \ + ); \ } INSERT_GENTFUNC_BASIC0( mksymm_unb_var1 ) @@ -194,34 +194,34 @@ void PASTEMAC(ch,varname) \ rntm_t* rntm \ ) \ { \ - ctype* zero = PASTEMAC(ch,0); \ - doff_t diagoffa; \ -\ - /* If the dimension is zero, return early. */ \ - if ( bli_zero_dim1( m ) ) return; \ -\ - /* Toggle uplo so that it refers to the unstored triangle. */ \ - bli_toggle_uplo( &uploa ); \ -\ - /* In order to avoid the main diagonal, we must nudge the diagonal either - up or down by one, depending on which triangle is to be zeroed. */ \ - if ( bli_is_upper( uploa ) ) diagoffa = 1; \ - else /*if ( bli_is_lower( uploa ) )*/ diagoffa = -1; \ -\ - /* Set the unstored triangle to zero. 
*/ \ - PASTEMAC2(ch,setm,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - diagoffa, \ - BLIS_NONUNIT_DIAG, \ - uploa, \ - m, \ - m, \ - zero, \ - a, rs_a, cs_a, \ - cntx, \ - rntm \ - ); \ + ctype* zero = PASTEMAC(ch,0); \ + doff_t diagoffa; \ +\ + /* If the dimension is zero, return early. */ \ + if ( bli_zero_dim1( m ) ) return; \ +\ + /* Toggle uplo so that it refers to the unstored triangle. */ \ + bli_toggle_uplo( &uploa ); \ +\ + /* In order to avoid the main diagonal, we must nudge the diagonal either + up or down by one, depending on which triangle is to be zeroed. */ \ + if ( bli_is_upper( uploa ) ) diagoffa = 1; \ + else /*if ( bli_is_lower( uploa ) )*/ diagoffa = -1; \ +\ + /* Set the unstored triangle to zero. */ \ + PASTEMAC2(ch,setm,BLIS_TAPI_EX_SUF) \ + ( \ + BLIS_NO_CONJUGATE, \ + diagoffa, \ + BLIS_NONUNIT_DIAG, \ + uploa, \ + m, \ + m, \ + zero, \ + a, rs_a, cs_a, \ + cntx, \ + rntm \ + ); \ } INSERT_GENTFUNC_BASIC0( mktrim_unb_var1 ) @@ -239,27 +239,27 @@ void PASTEMAC(ch,varname) \ rntm_t* rntm \ ) \ { \ - ctype* chi1; \ - ctype_r abs_chi1; \ - ctype_r absum; \ - dim_t i; \ + ctype* chi1; \ + ctype_r abs_chi1; \ + ctype_r absum; \ + dim_t i; \ \ - /* Initialize the absolute sum accumulator to zero. */ \ - PASTEMAC(chr,set0s)( absum ); \ + /* Initialize the absolute sum accumulator to zero. */ \ + PASTEMAC(chr,set0s)( absum ); \ \ - for ( i = 0; i < n; ++i ) \ - { \ - chi1 = x + (i )*incx; \ + for ( i = 0; i < n; ++i ) \ + { \ + chi1 = x + (i )*incx; \ \ - /* Compute the absolute value (or complex magnitude) of chi1. */ \ - PASTEMAC2(ch,chr,abval2s)( *chi1, abs_chi1 ); \ + /* Compute the absolute value (or complex magnitude) of chi1. */ \ + PASTEMAC2(ch,chr,abval2s)( *chi1, abs_chi1 ); \ \ - /* Accumulate the absolute value of chi1 into absum. */ \ - PASTEMAC(chr,adds)( abs_chi1, absum ); \ - } \ + /* Accumulate the absolute value of chi1 into absum. */ \ + PASTEMAC(chr,adds)( abs_chi1, absum ); \ + } \ \ - /* Store final value of absum to the output variable. */ \ - PASTEMAC(chr,copys)( absum, *norm ); \ + /* Store final value of absum to the output variable. */ \ + PASTEMAC(chr,copys)( absum, *norm ); \ } INSERT_GENTFUNCR_BASIC0( norm1v_unb_var1 ) @@ -277,33 +277,33 @@ void PASTEMAC(ch,varname) \ rntm_t* rntm \ ) \ { \ - ctype_r* zero = PASTEMAC(chr,0); \ - ctype_r* one = PASTEMAC(chr,1); \ - ctype_r scale; \ - ctype_r sumsq; \ - ctype_r sqrt_sumsq; \ -\ - /* Initialize scale and sumsq to begin the summation. */ \ - PASTEMAC(chr,copys)( *zero, scale ); \ - PASTEMAC(chr,copys)( *one, sumsq ); \ -\ - /* Compute the sum of the squares of the vector. */ \ - PASTEMAC(ch,kername) \ - ( \ - n, \ - x, incx, \ - &scale, \ - &sumsq, \ - cntx, \ - rntm \ - ); \ -\ - /* Compute: norm = scale * sqrt( sumsq ) */ \ - PASTEMAC(chr,sqrt2s)( sumsq, sqrt_sumsq ); \ - PASTEMAC(chr,scals)( scale, sqrt_sumsq ); \ -\ - /* Store the final value to the output variable. */ \ - PASTEMAC(chr,copys)( sqrt_sumsq, *norm ); \ + ctype_r* zero = PASTEMAC(chr,0); \ + ctype_r* one = PASTEMAC(chr,1); \ + ctype_r scale; \ + ctype_r sumsq; \ + ctype_r sqrt_sumsq; \ +\ + /* Initialize scale and sumsq to begin the summation. */ \ + PASTEMAC(chr,copys)( *zero, scale ); \ + PASTEMAC(chr,copys)( *one, sumsq ); \ +\ + /* Compute the sum of the squares of the vector. 
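The asumv variant near the top of this file accumulates |Re(chi)| + |Im(chi)| per element (the BLAS *asum convention), while norm1v above accumulates the true complex magnitude via abval2s. A small contrast of the two accumulations, illustrative only (the helper and C99 complex usage are assumptions, not BLIS code):

#include <complex.h>
#include <math.h>

static void asum_vs_norm1_sketch( int n, const double complex* x,
                                  double* asum, double* norm1 )
{
    double s_asum = 0.0, s_norm1 = 0.0;
    for ( int i = 0; i < n; ++i )
    {
        s_asum  += fabs( creal( x[i] ) ) + fabs( cimag( x[i] ) );  // asumv
        s_norm1 += cabs( x[i] );                                   // norm1v
    }
    *asum  = s_asum;
    *norm1 = s_norm1;
}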
*/ \ + PASTEMAC(ch,kername) \ + ( \ + n, \ + x, incx, \ + &scale, \ + &sumsq, \ + cntx, \ + rntm \ + ); \ +\ + /* Compute: norm = scale * sqrt( sumsq ) */ \ + PASTEMAC(chr,sqrt2s)( sumsq, sqrt_sumsq ); \ + PASTEMAC(chr,scals)( scale, sqrt_sumsq ); \ +\ + /* Store the final value to the output variable. */ \ + PASTEMAC(chr,copys)( sqrt_sumsq, *norm ); \ } //INSERT_GENTFUNCR_BASIC( normfv_unb_var1, sumsqv_unb_var1 ) @@ -330,72 +330,72 @@ void PASTEMAC(ch,varname) \ rntm_t* rntm \ ) \ { \ - ctype_r* zero = PASTEMAC(chr,0); \ - ctype_r* one = PASTEMAC(chr,1); \ - ctype_r scale; \ - ctype_r sumsq; \ - ctype_r sqrt_sumsq; \ -\ - /* Initialize scale and sumsq to begin the summation. */ \ - PASTEMAC(chr,copys)( *zero, scale ); \ - PASTEMAC(chr,copys)( *one, sumsq ); \ -\ - /* An optimization: first try to use dotv to compute the sum of - the squares of the vector. If no floating-point exceptions - (specifically, overflow and invalid exceptions) were produced, - then we accept the computed value and returne early. The cost - of this optimization is the "sunk" cost of the initial dotv - when sumsqv must be used instead. However, we expect that the - vast majority of use cases will not produce exceptions, and - therefore only one pass through the data, via dotv, will be - required. */ \ - if ( TRUE ) \ - { \ - int f_exp_raised;\ - ctype sumsqc; \ -\ - feclearexcept( FE_ALL_EXCEPT );\ -\ - PASTEMAC2(ch,dotv,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE, \ - BLIS_NO_CONJUGATE, \ - n,\ - x, incx, \ - x, incx, \ - &sumsqc, \ - cntx, \ - rntm \ - ); \ -\ - PASTEMAC2(ch,chr,copys)( sumsqc, sumsq ); \ -\ - f_exp_raised = fetestexcept( FE_OVERFLOW | FE_INVALID );\ -\ - if ( !f_exp_raised ) \ - { \ - PASTEMAC(chr,sqrt2s)( sumsq, *norm ); \ - return; \ - } \ - } \ -\ - /* Compute the sum of the squares of the vector. */ \ - PASTEMAC(ch,kername) \ - ( \ - n, \ - x, incx, \ - &scale, \ - &sumsq, \ - cntx, \ - rntm \ - ); \ -\ - /* Compute: norm = scale * sqrt( sumsq ) */ \ - PASTEMAC(chr,sqrt2s)( sumsq, sqrt_sumsq ); \ - PASTEMAC(chr,scals)( scale, sqrt_sumsq ); \ -\ - /* Store the final value to the output variable. */ \ - PASTEMAC(chr,copys)( sqrt_sumsq, *norm ); \ + ctype_r* zero = PASTEMAC(chr,0); \ + ctype_r* one = PASTEMAC(chr,1); \ + ctype_r scale; \ + ctype_r sumsq; \ + ctype_r sqrt_sumsq; \ +\ + /* Initialize scale and sumsq to begin the summation. */ \ + PASTEMAC(chr,copys)( *zero, scale ); \ + PASTEMAC(chr,copys)( *one, sumsq ); \ +\ + /* An optimization: first try to use dotv to compute the sum of + the squares of the vector. If no floating-point exceptions + (specifically, overflow and invalid exceptions) were produced, + then we accept the computed value and returne early. The cost + of this optimization is the "sunk" cost of the initial dotv + when sumsqv must be used instead. However, we expect that the + vast majority of use cases will not produce exceptions, and + therefore only one pass through the data, via dotv, will be + required. */ \ + if ( TRUE ) \ + { \ + int f_exp_raised;\ + ctype sumsqc; \ +\ + feclearexcept( FE_ALL_EXCEPT );\ +\ + PASTEMAC2(ch,dotv,BLIS_TAPI_EX_SUF) \ + ( \ + BLIS_NO_CONJUGATE, \ + BLIS_NO_CONJUGATE, \ + n,\ + x, incx, \ + x, incx, \ + &sumsqc, \ + cntx, \ + rntm \ + ); \ +\ + PASTEMAC2(ch,chr,copys)( sumsqc, sumsq ); \ +\ + f_exp_raised = fetestexcept( FE_OVERFLOW | FE_INVALID );\ +\ + if ( !f_exp_raised ) \ + { \ + PASTEMAC(chr,sqrt2s)( sumsq, *norm ); \ + return; \ + } \ + } \ +\ + /* Compute the sum of the squares of the vector. 
*/ \ + PASTEMAC(ch,kername) \ + ( \ + n, \ + x, incx, \ + &scale, \ + &sumsq, \ + cntx, \ + rntm \ + ); \ +\ + /* Compute: norm = scale * sqrt( sumsq ) */ \ + PASTEMAC(chr,sqrt2s)( sumsq, sqrt_sumsq ); \ + PASTEMAC(chr,scals)( scale, sqrt_sumsq ); \ +\ + /* Store the final value to the output variable. */ \ + PASTEMAC(chr,copys)( sqrt_sumsq, *norm ); \ } #else #define GENTFUNCR( ctype, ctype_r, ch, chr, varname, kername ) \ @@ -409,39 +409,65 @@ void PASTEMAC(ch,varname) \ rntm_t* rntm \ ) \ { \ - ctype_r* zero = PASTEMAC(chr,0); \ - ctype_r* one = PASTEMAC(chr,1); \ - ctype_r scale; \ - ctype_r sumsq; \ - ctype_r sqrt_sumsq; \ -\ - /* Initialize scale and sumsq to begin the summation. */ \ - PASTEMAC(chr,copys)( *zero, scale ); \ - PASTEMAC(chr,copys)( *one, sumsq ); \ -\ - /* Compute the sum of the squares of the vector. */ \ -\ - PASTEMAC(ch,kername) \ - ( \ - n, \ - x, incx, \ - &scale, \ - &sumsq, \ - cntx, \ - rntm \ - ); \ -\ - /* Compute: norm = scale * sqrt( sumsq ) */ \ - PASTEMAC(chr,sqrt2s)( sumsq, sqrt_sumsq ); \ - PASTEMAC(chr,scals)( scale, sqrt_sumsq ); \ -\ - /* Store the final value to the output variable. */ \ - PASTEMAC(chr,copys)( sqrt_sumsq, *norm ); \ + ctype_r* zero = PASTEMAC(chr,0); \ + ctype_r* one = PASTEMAC(chr,1); \ + ctype_r scale; \ + ctype_r sumsq; \ + ctype_r sqrt_sumsq; \ +\ + /* Initialize scale and sumsq to begin the summation. */ \ + PASTEMAC(chr,copys)( *zero, scale ); \ + PASTEMAC(chr,copys)( *one, sumsq ); \ +\ + /* Compute the sum of the squares of the vector. */ \ +\ + PASTEMAC(ch,kername) \ + ( \ + n, \ + x, incx, \ + &scale, \ + &sumsq, \ + cntx, \ + rntm \ + ); \ +\ + /* Compute: norm = scale * sqrt( sumsq ) */ \ + PASTEMAC(chr,sqrt2s)( sumsq, sqrt_sumsq ); \ + PASTEMAC(chr,scals)( scale, sqrt_sumsq ); \ +\ + /* Store the final value to the output variable. */ \ + PASTEMAC(chr,copys)( sqrt_sumsq, *norm ); \ } #endif GENTFUNCR( float, float, s, s, normfv_unb_var1, sumsqv_unb_var1 ) -GENTFUNCR( double, double, d, d, normfv_unb_var1, sumsqv_unb_var1 ) - +/*call sumsqv_unb_var1 if FAST_MATH is not defined else call dot-norm method*/\ +#ifndef BLIS_ENABLE_FAST_MATH +GENTFUNCR( double, double, d, d, normfv_unb_var1, sumsqv_unb_var1 ) +#else +#undef GENTFUNCR +#define GENTFUNCR( ctype, ctype_r, ch, chr, varname, kername ) \ +\ +void PASTEMAC(ch,varname) \ + ( \ + dim_t n, \ + ctype* x, inc_t incx, \ + ctype_r* norm, \ + cntx_t* cntx, \ + rntm_t* rntm \ + ) \ +{ \ +\ + /* Compute the sum of the squares of the vector. */ \ + PASTEMAC(ch,kername) \ + ( \ + n, \ + x, incx, \ + norm, \ + cntx \ + ); \ +} +GENTFUNCR( double, double, d, d, normfv_unb_var1, norm2fv_unb_var1 ) +#endif #undef GENTFUNCR #define GENTFUNCR( ctype, ctype_r, ch, chr, varname ) \ @@ -455,34 +481,34 @@ void PASTEMAC(ch,varname) \ rntm_t* rntm \ ) \ { \ - ctype* chi1; \ - ctype_r abs_chi1; \ - ctype_r abs_chi1_max; \ - dim_t i; \ -\ - /* Initialize the maximum absolute value to zero. */ \ - PASTEMAC(chr,set0s)( abs_chi1_max ); \ -\ - for ( i = 0; i < n; ++i ) \ - { \ - chi1 = x + (i )*incx; \ -\ - /* Compute the absolute value (or complex magnitude) of chi1. */ \ - PASTEMAC2(ch,chr,abval2s)( *chi1, abs_chi1 ); \ -\ - /* If the absolute value of the current element exceeds that of - the previous largest, save it and its index. If NaN is - encountered, then treat it the same as if it were a valid - value that was larger than any previously seen. This - behavior mimics that of LAPACK's ?lange(). 
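The double-precision normfv dispatch above chooses between the safe sumsqv path and a dot-product-based path, and the dotv optimization in the preceding definition only keeps the fast result when no overflow/invalid floating-point exception was raised. A scalar sketch of that strategy, illustrative only (the helper name is an assumption):

#include <fenv.h>
#include <math.h>

static double norm2_fast_or_safe_sketch( int n, const double* x )
{
    feclearexcept( FE_ALL_EXCEPT );

    // Fast path: plain sum of squares, i.e. dot(x, x).
    double sumsq = 0.0;
    for ( int i = 0; i < n; ++i ) sumsq += x[i] * x[i];

    if ( !fetestexcept( FE_OVERFLOW | FE_INVALID ) )
        return sqrt( sumsq );

    // Fallback: scaled sum of squares (sumsqv-style), robust to
    // overflow/underflow of the intermediate sum.
    double scale = 0.0;
    sumsq = 1.0;
    for ( int i = 0; i < n; ++i )
    {
        double ax = fabs( x[i] );
        if ( ax == 0.0 ) continue;
        if ( scale < ax )
        {
            sumsq = 1.0 + sumsq * ( scale / ax ) * ( scale / ax );
            scale = ax;
        }
        else
            sumsq += ( ax / scale ) * ( ax / scale );
    }
    return scale * sqrt( sumsq );
}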
*/ \ - if ( abs_chi1_max < abs_chi1 || bli_isnan( abs_chi1 ) ) \ - { \ - PASTEMAC(chr,copys)( abs_chi1, abs_chi1_max ); \ - } \ - } \ -\ - /* Store the final value to the output variable. */ \ - PASTEMAC(chr,copys)( abs_chi1_max, *norm ); \ + ctype* chi1; \ + ctype_r abs_chi1; \ + ctype_r abs_chi1_max; \ + dim_t i; \ +\ + /* Initialize the maximum absolute value to zero. */ \ + PASTEMAC(chr,set0s)( abs_chi1_max ); \ +\ + for ( i = 0; i < n; ++i ) \ + { \ + chi1 = x + (i )*incx; \ +\ + /* Compute the absolute value (or complex magnitude) of chi1. */ \ + PASTEMAC2(ch,chr,abval2s)( *chi1, abs_chi1 ); \ +\ + /* If the absolute value of the current element exceeds that of + the previous largest, save it and its index. If NaN is + encountered, then treat it the same as if it were a valid + value that was larger than any previously seen. This + behavior mimics that of LAPACK's ?lange(). */ \ + if ( abs_chi1_max < abs_chi1 || bli_isnan( abs_chi1 ) ) \ + { \ + PASTEMAC(chr,copys)( abs_chi1, abs_chi1_max ); \ + } \ + } \ +\ + /* Store the final value to the output variable. */ \ + PASTEMAC(chr,copys)( abs_chi1_max, *norm ); \ } INSERT_GENTFUNCR_BASIC0( normiv_unb_var1 ) @@ -505,149 +531,149 @@ void PASTEMAC(ch,varname) \ rntm_t* rntm \ ) \ { \ - ctype* one = PASTEMAC(ch,1); \ - ctype* x0; \ - ctype* chi1; \ - ctype* x2; \ - ctype_r absum_max; \ - ctype_r absum_j; \ - ctype_r abval_chi1; \ - uplo_t uplox_eff; \ - dim_t n_iter; \ - dim_t n_elem, n_elem_max; \ - inc_t ldx, incx; \ - dim_t j, i; \ - dim_t ij0, n_shift; \ -\ - /* Initialize the maximum absolute column sum to zero. */ \ - PASTEMAC(chr,set0s)( absum_max ); \ -\ - /* If either dimension is zero, return with absum_max equal to zero. */ \ - if ( bli_zero_dim2( m, n ) ) \ - { \ - PASTEMAC(chr,copys)( absum_max, *norm ); \ - return; \ - } \ -\ - /* Set various loop parameters. */ \ - bli_set_dims_incs_uplo_1m_noswap \ - ( \ - diagoffx, BLIS_NONUNIT_DIAG, \ - uplox, m, n, rs_x, cs_x, \ - &uplox_eff, &n_elem_max, &n_iter, &incx, &ldx, \ - &ij0, &n_shift \ - ); \ -\ - /* If the matrix is zeros, return with absum_max equal to zero. */ \ - if ( bli_is_zeros( uplox_eff ) ) \ - { \ - PASTEMAC(chr,copys)( absum_max, *norm ); \ - return; \ - } \ -\ -\ - /* Handle dense and upper/lower storage cases separately. */ \ - if ( bli_is_dense( uplox_eff ) ) \ - { \ - for ( j = 0; j < n_iter; ++j ) \ - { \ - n_elem = n_elem_max; \ -\ - x0 = x + (j )*ldx + (0 )*incx; \ -\ - /* Compute the norm of the current column. */ \ - PASTEMAC(ch,kername) \ - ( \ - n_elem, \ - x0, incx, \ - &absum_j, \ - cntx, \ - rntm \ - ); \ -\ - /* If absum_j is greater than the previous maximum value, - then save it. */ \ - if ( absum_max < absum_j || bli_isnan( absum_j ) ) \ - { \ - PASTEMAC(chr,copys)( absum_j, absum_max ); \ - } \ - } \ - } \ - else \ - { \ - if ( bli_is_upper( uplox_eff ) ) \ - { \ - for ( j = 0; j < n_iter; ++j ) \ - { \ - n_elem = bli_min( n_shift + j + 1, n_elem_max ); \ -\ - x0 = x + (ij0+j )*ldx + (0 )*incx; \ - chi1 = x + (ij0+j )*ldx + (n_elem-1)*incx; \ -\ - /* Compute the norm of the super-diagonal elements. */ \ - PASTEMAC(ch,kername) \ - ( \ - n_elem - 1, \ - x0, incx, \ - &absum_j, \ - cntx, \ - rntm \ - ); \ -\ - if ( bli_is_unit_diag( diagx ) ) chi1 = one; \ -\ - /* Handle the diagonal element separately in case it's - unit. */ \ - PASTEMAC2(ch,chr,abval2s)( *chi1, abval_chi1 ); \ - PASTEMAC(chr,adds)( abval_chi1, absum_j ); \ -\ - /* If absum_j is greater than the previous maximum value, - then save it. 
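A compact view of the NaN handling used by normiv above, where the comparison is written so that a NaN overrides any previously seen maximum, mimicking LAPACK's ?lange(). Illustrative only; the helper name is an assumption:

#include <math.h>

static double normiv_sketch( int n, const double* x, int incx )
{
    double vmax = 0.0;
    for ( int i = 0; i < n; ++i )
    {
        double a = fabs( x[ i*incx ] );
        // "vmax < a" is false when a is NaN, so NaN is caught explicitly
        // and then propagates to the result.
        if ( vmax < a || isnan( a ) ) vmax = a;
    }
    return vmax;
}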
*/ \ - if ( absum_max < absum_j || bli_isnan( absum_j ) ) \ - { \ - PASTEMAC(chr,copys)( absum_j, absum_max ); \ - } \ - } \ - } \ - else if ( bli_is_lower( uplox_eff ) ) \ - { \ - for ( j = 0; j < n_iter; ++j ) \ - { \ - i = bli_max( 0, ( doff_t )j - ( doff_t )n_shift ); \ - n_elem = n_elem_max - i; \ -\ - chi1 = x + (j )*ldx + (ij0+i )*incx; \ - x2 = x + (j )*ldx + (ij0+i+1)*incx; \ -\ - /* Compute the norm of the sub-diagonal elements. */ \ - PASTEMAC(ch,kername) \ - ( \ - n_elem - 1, \ - x2, incx, \ - &absum_j, \ - cntx, \ - rntm \ - ); \ -\ - if ( bli_is_unit_diag( diagx ) ) chi1 = one; \ -\ - /* Handle the diagonal element separately in case it's - unit. */ \ - PASTEMAC2(ch,chr,abval2s)( *chi1, abval_chi1 ); \ - PASTEMAC(chr,adds)( abval_chi1, absum_j ); \ -\ - /* If absum_j is greater than the previous maximum value, - then save it. */ \ - if ( absum_max < absum_j || bli_isnan( absum_j ) ) \ - { \ - PASTEMAC(chr,copys)( absum_j, absum_max ); \ - } \ - } \ - } \ - } \ -\ - /* Store final value of absum_max to the output variable. */ \ - PASTEMAC(chr,copys)( absum_max, *norm ); \ + ctype* one = PASTEMAC(ch,1); \ + ctype* x0; \ + ctype* chi1; \ + ctype* x2; \ + ctype_r absum_max; \ + ctype_r absum_j; \ + ctype_r abval_chi1; \ + uplo_t uplox_eff; \ + dim_t n_iter; \ + dim_t n_elem, n_elem_max; \ + inc_t ldx, incx; \ + dim_t j, i; \ + dim_t ij0, n_shift; \ +\ + /* Initialize the maximum absolute column sum to zero. */ \ + PASTEMAC(chr,set0s)( absum_max ); \ +\ + /* If either dimension is zero, return with absum_max equal to zero. */ \ + if ( bli_zero_dim2( m, n ) ) \ + { \ + PASTEMAC(chr,copys)( absum_max, *norm ); \ + return; \ + } \ +\ + /* Set various loop parameters. */ \ + bli_set_dims_incs_uplo_1m_noswap \ + ( \ + diagoffx, BLIS_NONUNIT_DIAG, \ + uplox, m, n, rs_x, cs_x, \ + &uplox_eff, &n_elem_max, &n_iter, &incx, &ldx, \ + &ij0, &n_shift \ + ); \ +\ + /* If the matrix is zeros, return with absum_max equal to zero. */ \ + if ( bli_is_zeros( uplox_eff ) ) \ + { \ + PASTEMAC(chr,copys)( absum_max, *norm ); \ + return; \ + } \ +\ +\ + /* Handle dense and upper/lower storage cases separately. */ \ + if ( bli_is_dense( uplox_eff ) ) \ + { \ + for ( j = 0; j < n_iter; ++j ) \ + { \ + n_elem = n_elem_max; \ +\ + x0 = x + (j )*ldx + (0 )*incx; \ +\ + /* Compute the norm of the current column. */ \ + PASTEMAC(ch,kername) \ + ( \ + n_elem, \ + x0, incx, \ + &absum_j, \ + cntx, \ + rntm \ + ); \ +\ + /* If absum_j is greater than the previous maximum value, + then save it. */ \ + if ( absum_max < absum_j || bli_isnan( absum_j ) ) \ + { \ + PASTEMAC(chr,copys)( absum_j, absum_max ); \ + } \ + } \ + } \ + else \ + { \ + if ( bli_is_upper( uplox_eff ) ) \ + { \ + for ( j = 0; j < n_iter; ++j ) \ + { \ + n_elem = bli_min( n_shift + j + 1, n_elem_max ); \ +\ + x0 = x + (ij0+j )*ldx + (0 )*incx; \ + chi1 = x + (ij0+j )*ldx + (n_elem-1)*incx; \ +\ + /* Compute the norm of the super-diagonal elements. */ \ + PASTEMAC(ch,kername) \ + ( \ + n_elem - 1, \ + x0, incx, \ + &absum_j, \ + cntx, \ + rntm \ + ); \ +\ + if ( bli_is_unit_diag( diagx ) ) chi1 = one; \ +\ + /* Handle the diagonal element separately in case it's + unit. */ \ + PASTEMAC2(ch,chr,abval2s)( *chi1, abval_chi1 ); \ + PASTEMAC(chr,adds)( abval_chi1, absum_j ); \ +\ + /* If absum_j is greater than the previous maximum value, + then save it. 
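The dense case of norm1m above reduces to the familiar matrix 1-norm: the maximum over columns of the column absolute sum, again letting NaN win any comparison. A scalar sketch for a real column-major matrix, illustrative only:

#include <math.h>

static double norm1m_dense_sketch( int m, int n, const double* x, int ldx )
{
    double absum_max = 0.0;
    for ( int j = 0; j < n; ++j )
    {
        double absum_j = 0.0;
        for ( int i = 0; i < m; ++i )
            absum_j += fabs( x[ i + j*ldx ] );   // column absolute sum

        if ( absum_max < absum_j || isnan( absum_j ) )
            absum_max = absum_j;
    }
    return absum_max;
}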
*/ \ + if ( absum_max < absum_j || bli_isnan( absum_j ) ) \ + { \ + PASTEMAC(chr,copys)( absum_j, absum_max ); \ + } \ + } \ + } \ + else if ( bli_is_lower( uplox_eff ) ) \ + { \ + for ( j = 0; j < n_iter; ++j ) \ + { \ + i = bli_max( 0, ( doff_t )j - ( doff_t )n_shift ); \ + n_elem = n_elem_max - i; \ +\ + chi1 = x + (j )*ldx + (ij0+i )*incx; \ + x2 = x + (j )*ldx + (ij0+i+1)*incx; \ +\ + /* Compute the norm of the sub-diagonal elements. */ \ + PASTEMAC(ch,kername) \ + ( \ + n_elem - 1, \ + x2, incx, \ + &absum_j, \ + cntx, \ + rntm \ + ); \ +\ + if ( bli_is_unit_diag( diagx ) ) chi1 = one; \ +\ + /* Handle the diagonal element separately in case it's + unit. */ \ + PASTEMAC2(ch,chr,abval2s)( *chi1, abval_chi1 ); \ + PASTEMAC(chr,adds)( abval_chi1, absum_j ); \ +\ + /* If absum_j is greater than the previous maximum value, + then save it. */ \ + if ( absum_max < absum_j || bli_isnan( absum_j ) ) \ + { \ + PASTEMAC(chr,copys)( absum_j, absum_max ); \ + } \ + } \ + } \ + } \ +\ + /* Store final value of absum_max to the output variable. */ \ + PASTEMAC(chr,copys)( absum_max, *norm ); \ } INSERT_GENTFUNCR_BASIC( norm1m_unb_var1, norm1v_unb_var1 ) @@ -669,152 +695,152 @@ void PASTEMAC(ch,varname) \ rntm_t* rntm \ ) \ { \ - ctype* one = PASTEMAC(ch,1); \ - ctype_r* one_r = PASTEMAC(chr,1); \ - ctype_r* zero_r = PASTEMAC(chr,0); \ - ctype* x0; \ - ctype* chi1; \ - ctype* x2; \ - ctype_r scale; \ - ctype_r sumsq; \ - ctype_r sqrt_sumsq; \ - uplo_t uplox_eff; \ - dim_t n_iter; \ - dim_t n_elem, n_elem_max; \ - inc_t ldx, incx; \ - dim_t j, i; \ - dim_t ij0, n_shift; \ -\ - /* Return a norm of zero if either dimension is zero. */ \ - if ( bli_zero_dim2( m, n ) ) \ - { \ - PASTEMAC(chr,set0s)( *norm ); \ - return; \ - } \ -\ - /* Set various loop parameters. Here, we pretend that diagx is equal to - BLIS_NONUNIT_DIAG because we handle the unit diagonal case manually. */ \ - bli_set_dims_incs_uplo_1m \ - ( \ - diagoffx, BLIS_NONUNIT_DIAG, \ - uplox, m, n, rs_x, cs_x, \ - &uplox_eff, &n_elem_max, &n_iter, &incx, &ldx, \ - &ij0, &n_shift \ - ); \ -\ - /* Check the effective uplo; if it's zeros, then our norm is zero. */ \ - if ( bli_is_zeros( uplox_eff ) ) \ - { \ - PASTEMAC(chr,set0s)( *norm ); \ - return; \ - } \ -\ - /* Initialize scale and sumsq to begin the summation. */ \ - PASTEMAC(chr,copys)( *zero_r, scale ); \ - PASTEMAC(chr,copys)( *one_r, sumsq ); \ -\ - /* Handle dense and upper/lower storage cases separately. */ \ - if ( bli_is_dense( uplox_eff ) ) \ - { \ - for ( j = 0; j < n_iter; ++j ) \ - { \ - n_elem = n_elem_max; \ -\ - x0 = x + (j )*ldx + (0 )*incx; \ -\ - /* Compute the norm of the current column. */ \ - PASTEMAC(ch,kername) \ - ( \ - n_elem, \ - x0, incx, \ - &scale, \ - &sumsq, \ - cntx, \ - rntm \ - ); \ - } \ - } \ - else \ - { \ - if ( bli_is_upper( uplox_eff ) ) \ - { \ - for ( j = 0; j < n_iter; ++j ) \ - { \ - n_elem = bli_min( n_shift + j + 1, n_elem_max ); \ -\ - x0 = x + (ij0+j )*ldx + (0 )*incx; \ - chi1 = x + (ij0+j )*ldx + (n_elem-1)*incx; \ -\ - /* Sum the squares of the super-diagonal elements. */ \ - PASTEMAC(ch,kername) \ - ( \ - n_elem - 1, \ - x0, incx, \ - &scale, \ - &sumsq, \ - cntx, \ - rntm \ - ); \ -\ - if ( bli_is_unit_diag( diagx ) ) chi1 = one; \ -\ - /* Handle the diagonal element separately in case it's - unit. 
*/ \ - PASTEMAC(ch,kername) \ - ( \ - 1, \ - chi1, incx, \ - &scale, \ - &sumsq, \ - cntx, \ - rntm \ - ); \ - } \ - } \ - else if ( bli_is_lower( uplox_eff ) ) \ - { \ - for ( j = 0; j < n_iter; ++j ) \ - { \ - i = bli_max( 0, ( doff_t )j - ( doff_t )n_shift ); \ - n_elem = n_elem_max - i; \ -\ - chi1 = x + (j )*ldx + (ij0+i )*incx; \ - x2 = x + (j )*ldx + (ij0+i+1)*incx; \ -\ - /* Sum the squares of the sub-diagonal elements. */ \ - PASTEMAC(ch,kername) \ - ( \ - n_elem - 1, \ - x2, incx, \ - &scale, \ - &sumsq, \ - cntx, \ - rntm \ - ); \ -\ - if ( bli_is_unit_diag( diagx ) ) chi1 = one; \ -\ - /* Handle the diagonal element separately in case it's - unit. */ \ - PASTEMAC(ch,kername) \ - ( \ - 1, \ - chi1, incx, \ - &scale, \ - &sumsq, \ - cntx, \ - rntm \ - ); \ - } \ - } \ - } \ -\ - /* Compute: norm = scale * sqrt( sumsq ) */ \ - PASTEMAC(chr,sqrt2s)( sumsq, sqrt_sumsq ); \ - PASTEMAC(chr,scals)( scale, sqrt_sumsq ); \ -\ - /* Store the final value to the output variable. */ \ - PASTEMAC(chr,copys)( sqrt_sumsq, *norm ); \ + ctype* one = PASTEMAC(ch,1); \ + ctype_r* one_r = PASTEMAC(chr,1); \ + ctype_r* zero_r = PASTEMAC(chr,0); \ + ctype* x0; \ + ctype* chi1; \ + ctype* x2; \ + ctype_r scale; \ + ctype_r sumsq; \ + ctype_r sqrt_sumsq; \ + uplo_t uplox_eff; \ + dim_t n_iter; \ + dim_t n_elem, n_elem_max; \ + inc_t ldx, incx; \ + dim_t j, i; \ + dim_t ij0, n_shift; \ +\ + /* Return a norm of zero if either dimension is zero. */ \ + if ( bli_zero_dim2( m, n ) ) \ + { \ + PASTEMAC(chr,set0s)( *norm ); \ + return; \ + } \ +\ + /* Set various loop parameters. Here, we pretend that diagx is equal to + BLIS_NONUNIT_DIAG because we handle the unit diagonal case manually. */ \ + bli_set_dims_incs_uplo_1m \ + ( \ + diagoffx, BLIS_NONUNIT_DIAG, \ + uplox, m, n, rs_x, cs_x, \ + &uplox_eff, &n_elem_max, &n_iter, &incx, &ldx, \ + &ij0, &n_shift \ + ); \ +\ + /* Check the effective uplo; if it's zeros, then our norm is zero. */ \ + if ( bli_is_zeros( uplox_eff ) ) \ + { \ + PASTEMAC(chr,set0s)( *norm ); \ + return; \ + } \ +\ + /* Initialize scale and sumsq to begin the summation. */ \ + PASTEMAC(chr,copys)( *zero_r, scale ); \ + PASTEMAC(chr,copys)( *one_r, sumsq ); \ +\ + /* Handle dense and upper/lower storage cases separately. */ \ + if ( bli_is_dense( uplox_eff ) ) \ + { \ + for ( j = 0; j < n_iter; ++j ) \ + { \ + n_elem = n_elem_max; \ +\ + x0 = x + (j )*ldx + (0 )*incx; \ +\ + /* Compute the norm of the current column. */ \ + PASTEMAC(ch,kername) \ + ( \ + n_elem, \ + x0, incx, \ + &scale, \ + &sumsq, \ + cntx, \ + rntm \ + ); \ + } \ + } \ + else \ + { \ + if ( bli_is_upper( uplox_eff ) ) \ + { \ + for ( j = 0; j < n_iter; ++j ) \ + { \ + n_elem = bli_min( n_shift + j + 1, n_elem_max ); \ +\ + x0 = x + (ij0+j )*ldx + (0 )*incx; \ + chi1 = x + (ij0+j )*ldx + (n_elem-1)*incx; \ +\ + /* Sum the squares of the super-diagonal elements. */ \ + PASTEMAC(ch,kername) \ + ( \ + n_elem - 1, \ + x0, incx, \ + &scale, \ + &sumsq, \ + cntx, \ + rntm \ + ); \ +\ + if ( bli_is_unit_diag( diagx ) ) chi1 = one; \ +\ + /* Handle the diagonal element separately in case it's + unit. */ \ + PASTEMAC(ch,kername) \ + ( \ + 1, \ + chi1, incx, \ + &scale, \ + &sumsq, \ + cntx, \ + rntm \ + ); \ + } \ + } \ + else if ( bli_is_lower( uplox_eff ) ) \ + { \ + for ( j = 0; j < n_iter; ++j ) \ + { \ + i = bli_max( 0, ( doff_t )j - ( doff_t )n_shift ); \ + n_elem = n_elem_max - i; \ +\ + chi1 = x + (j )*ldx + (ij0+i )*incx; \ + x2 = x + (j )*ldx + (ij0+i+1)*incx; \ +\ + /* Sum the squares of the sub-diagonal elements. 
*/ \ + PASTEMAC(ch,kername) \ + ( \ + n_elem - 1, \ + x2, incx, \ + &scale, \ + &sumsq, \ + cntx, \ + rntm \ + ); \ +\ + if ( bli_is_unit_diag( diagx ) ) chi1 = one; \ +\ + /* Handle the diagonal element separately in case it's + unit. */ \ + PASTEMAC(ch,kername) \ + ( \ + 1, \ + chi1, incx, \ + &scale, \ + &sumsq, \ + cntx, \ + rntm \ + ); \ + } \ + } \ + } \ +\ + /* Compute: norm = scale * sqrt( sumsq ) */ \ + PASTEMAC(chr,sqrt2s)( sumsq, sqrt_sumsq ); \ + PASTEMAC(chr,scals)( scale, sqrt_sumsq ); \ +\ + /* Store the final value to the output variable. */ \ + PASTEMAC(chr,copys)( sqrt_sumsq, *norm ); \ } INSERT_GENTFUNCR_BASIC( normfm_unb_var1, sumsqv_unb_var1 ) @@ -836,27 +862,27 @@ void PASTEMAC(ch,varname) \ rntm_t* rntm \ ) \ { \ - /* Induce a transposition so that rows become columns. */ \ - bli_swap_dims( &m, &n ); \ - bli_swap_incs( &rs_x, &cs_x ); \ - bli_toggle_uplo( &uplox ); \ - bli_negate_diag_offset( &diagoffx ); \ -\ - /* Now we can simply compute the 1-norm of this transposed matrix, - which will be equivalent to the infinity-norm of the original - matrix. */ \ - PASTEMAC(ch,kername) \ - ( \ - diagoffx, \ - diagx, \ - uplox, \ - m, \ - n, \ - x, rs_x, cs_x, \ - norm, \ - cntx, \ - rntm \ - ); \ + /* Induce a transposition so that rows become columns. */ \ + bli_swap_dims( &m, &n ); \ + bli_swap_incs( &rs_x, &cs_x ); \ + bli_toggle_uplo( &uplox ); \ + bli_negate_diag_offset( &diagoffx ); \ +\ + /* Now we can simply compute the 1-norm of this transposed matrix, + which will be equivalent to the infinity-norm of the original + matrix. */ \ + PASTEMAC(ch,kername) \ + ( \ + diagoffx, \ + diagx, \ + uplox, \ + m, \ + n, \ + x, rs_x, cs_x, \ + norm, \ + cntx, \ + rntm \ + ); \ } INSERT_GENTFUNCR_BASIC( normim_unb_var1, norm1m_unb_var1 ) @@ -875,25 +901,25 @@ void PASTEMAC(ch,opname) \ char* s2 \ ) \ { \ - dim_t i; \ - ctype* chi1; \ - char default_spec[32] = PASTEMAC(ch,formatspec)(); \ + dim_t i; \ + ctype* chi1; \ + char default_spec[32] = PASTEMAC(ch,formatspec)(); \ \ - if ( format == NULL ) format = default_spec; \ + if ( format == NULL ) format = default_spec; \ \ - chi1 = x; \ + chi1 = x; \ \ - fprintf( file, "%s\n", s1 ); \ + fprintf( file, "%s\n", s1 ); \ \ - for ( i = 0; i < n; ++i ) \ - { \ - PASTEMAC(ch,fprints)( file, format, *chi1 ); \ - fprintf( file, "\n" ); \ + for ( i = 0; i < n; ++i ) \ + { \ + PASTEMAC(ch,fprints)( file, format, *chi1 ); \ + fprintf( file, "\n" ); \ \ - chi1 += incx; \ - } \ + chi1 += incx; \ + } \ \ - fprintf( file, "%s\n", s2 ); \ + fprintf( file, "%s\n", s2 ); \ } INSERT_GENTFUNC_BASIC0_I( fprintv ) @@ -913,29 +939,29 @@ void PASTEMAC(ch,opname) \ char* s2 \ ) \ { \ - dim_t i, j; \ - ctype* chi1; \ - char default_spec[32] = PASTEMAC(ch,formatspec)(); \ + dim_t i, j; \ + ctype* chi1; \ + char default_spec[32] = PASTEMAC(ch,formatspec)(); \ \ - if ( format == NULL ) format = default_spec; \ + if ( format == NULL ) format = default_spec; \ \ - fprintf( file, "%s\n", s1 ); \ + fprintf( file, "%s\n", s1 ); \ \ - for ( i = 0; i < m; ++i ) \ - { \ - for ( j = 0; j < n; ++j ) \ - { \ - chi1 = (( ctype* ) x) + i*rs_x + j*cs_x; \ + for ( i = 0; i < m; ++i ) \ + { \ + for ( j = 0; j < n; ++j ) \ + { \ + chi1 = (( ctype* ) x) + i*rs_x + j*cs_x; \ \ - PASTEMAC(ch,fprints)( file, format, *chi1 ); \ - fprintf( file, " " ); \ - } \ + PASTEMAC(ch,fprints)( file, format, *chi1 ); \ + fprintf( file, " " ); \ + } \ \ - fprintf( file, "\n" ); \ - } \ + fprintf( file, "\n" ); \ + } \ \ - fprintf( file, "%s\n", s2 ); \ - fflush( file ); \ + fprintf( file, "%s\n", 
s2 ); \ + fflush( file ); \ } INSERT_GENTFUNC_BASIC0_I( fprintm ) @@ -952,17 +978,17 @@ void PASTEMAC(ch,varname) \ rntm_t* rntm \ ) \ { \ - ctype* chi1; \ - dim_t i; \ + ctype* chi1; \ + dim_t i; \ \ - chi1 = x; \ + chi1 = x; \ \ - for ( i = 0; i < n; ++i ) \ - { \ - PASTEMAC(ch,randmac)( *chi1 ); \ + for ( i = 0; i < n; ++i ) \ + { \ + PASTEMAC(ch,randmac)( *chi1 ); \ \ - chi1 += incx; \ - } \ + chi1 += incx; \ + } \ } INSERT_GENTFUNC_BASIC( randv_unb_var1, rands ) @@ -983,142 +1009,142 @@ void PASTEMAC(ch,varname) \ rntm_t* rntm \ ) \ { \ - ctype* one = PASTEMAC(ch,1); \ - ctype* x0; \ - ctype* x1; \ - ctype* x2; \ - ctype* chi1; \ - ctype beta; \ - ctype omega; \ - double max_m_n; \ - uplo_t uplox_eff; \ - dim_t n_iter; \ - dim_t n_elem, n_elem_max; \ - inc_t ldx, incx; \ - dim_t j, i; \ - dim_t ij0, n_shift; \ -\ - /* Set various loop parameters. Here, we pretend that diagx is equal to - BLIS_NONUNIT_DIAG because we handle the unit diagonal case manually. */ \ - bli_set_dims_incs_uplo_1m \ - ( \ - diagoffx, BLIS_NONUNIT_DIAG, \ - uplox, m, n, rs_x, cs_x, \ - &uplox_eff, &n_elem_max, &n_iter, &incx, &ldx, \ - &ij0, &n_shift \ - ); \ -\ - if ( bli_is_zeros( uplox_eff ) ) return; \ -\ - /* Handle dense and upper/lower storage cases separately. */ \ - if ( bli_is_dense( uplox_eff ) ) \ - { \ - for ( j = 0; j < n_iter; ++j ) \ - { \ - n_elem = n_elem_max; \ -\ - x1 = x + (j )*ldx + (0 )*incx; \ -\ - /*PASTEMAC2(ch,kername,BLIS_TAPI_EX_SUF)*/ \ - PASTEMAC(ch,kername) \ - ( \ - n_elem, \ - x1, incx, \ - cntx, \ - rntm \ - ); \ - } \ - } \ - else \ - { \ - max_m_n = bli_max( m, n ); \ -\ - PASTEMAC2(d,ch,sets)( max_m_n, 0.0, omega ); \ - PASTEMAC(ch,copys)( *one, beta ); \ - PASTEMAC(ch,invscals)( omega, beta ); \ -\ - if ( bli_is_upper( uplox_eff ) ) \ - { \ - for ( j = 0; j < n_iter; ++j ) \ - { \ - n_elem = bli_min( n_shift + j + 1, n_elem_max ); \ -\ - x1 = x + (ij0+j )*ldx + (0 )*incx; \ - x0 = x1; \ - chi1 = x1 + (n_elem-1)*incx; \ -\ - /*PASTEMAC2(ch,kername,BLIS_TAPI_EX_SUF)*/ \ - PASTEMAC(ch,kername) \ - ( \ - n_elem, \ - x1, incx, \ - cntx, \ - rntm \ - ); \ -\ - ( void )x0; \ - ( void )chi1; \ - /* We want positive diagonal elements between 1 and 2. */ \ + ctype* one = PASTEMAC(ch,1); \ + ctype* x0; \ + ctype* x1; \ + ctype* x2; \ + ctype* chi1; \ + ctype beta; \ + ctype omega; \ + double max_m_n; \ + uplo_t uplox_eff; \ + dim_t n_iter; \ + dim_t n_elem, n_elem_max; \ + inc_t ldx, incx; \ + dim_t j, i; \ + dim_t ij0, n_shift; \ +\ + /* Set various loop parameters. Here, we pretend that diagx is equal to + BLIS_NONUNIT_DIAG because we handle the unit diagonal case manually. */ \ + bli_set_dims_incs_uplo_1m \ + ( \ + diagoffx, BLIS_NONUNIT_DIAG, \ + uplox, m, n, rs_x, cs_x, \ + &uplox_eff, &n_elem_max, &n_iter, &incx, &ldx, \ + &ij0, &n_shift \ + ); \ +\ + if ( bli_is_zeros( uplox_eff ) ) return; \ +\ + /* Handle dense and upper/lower storage cases separately. 
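The normim variant earlier in this hunk reuses norm1m by inducing a transpose: the infinity norm (maximum row absolute sum) of X equals the 1-norm (maximum column absolute sum) of X transposed, and the transpose costs nothing because it amounts to swapping the row and column strides. A sketch of the equivalent direct row-sum computation, illustrative only:

#include <math.h>

static double normim_sketch( int m, int n, const double* x,
                             int rs_x, int cs_x )
{
    double absum_max = 0.0;
    for ( int i = 0; i < m; ++i )
    {
        double absum_i = 0.0;
        for ( int j = 0; j < n; ++j )
            absum_i += fabs( x[ i*rs_x + j*cs_x ] );   // row absolute sum

        if ( absum_max < absum_i ) absum_max = absum_i;
    }
    return absum_max;
}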
*/ \ + if ( bli_is_dense( uplox_eff ) ) \ + { \ + for ( j = 0; j < n_iter; ++j ) \ + { \ + n_elem = n_elem_max; \ +\ + x1 = x + (j )*ldx + (0 )*incx; \ +\ + /*PASTEMAC2(ch,kername,BLIS_TAPI_EX_SUF)*/ \ + PASTEMAC(ch,kername) \ + ( \ + n_elem, \ + x1, incx, \ + cntx, \ + rntm \ + ); \ + } \ + } \ + else \ + { \ + max_m_n = bli_max( m, n ); \ +\ + PASTEMAC2(d,ch,sets)( max_m_n, 0.0, omega ); \ + PASTEMAC(ch,copys)( *one, beta ); \ + PASTEMAC(ch,invscals)( omega, beta ); \ +\ + if ( bli_is_upper( uplox_eff ) ) \ + { \ + for ( j = 0; j < n_iter; ++j ) \ + { \ + n_elem = bli_min( n_shift + j + 1, n_elem_max ); \ +\ + x1 = x + (ij0+j )*ldx + (0 )*incx; \ + x0 = x1; \ + chi1 = x1 + (n_elem-1)*incx; \ +\ + /*PASTEMAC2(ch,kername,BLIS_TAPI_EX_SUF)*/ \ + PASTEMAC(ch,kername) \ + ( \ + n_elem, \ + x1, incx, \ + cntx, \ + rntm \ + ); \ +\ + ( void )x0; \ + ( void )chi1; \ + /* We want positive diagonal elements between 1 and 2. */ \ /* - PASTEMAC(ch,abval2s)( *chi1, *chi1 ); \ - PASTEMAC(ch,adds)( *one, *chi1 ); \ + PASTEMAC(ch,abval2s)( *chi1, *chi1 ); \ + PASTEMAC(ch,adds)( *one, *chi1 ); \ */ \ \ - /* Scale the super-diagonal elements by 1/max(m,n). */ \ + /* Scale the super-diagonal elements by 1/max(m,n). */ \ /* - PASTEMAC(ch,scalv) \ - ( \ - BLIS_NO_CONJUGATE, \ - n_elem - 1, \ - &beta, \ - x0, incx, \ - cntx \ - ); \ + PASTEMAC(ch,scalv) \ + ( \ + BLIS_NO_CONJUGATE, \ + n_elem - 1, \ + &beta, \ + x0, incx, \ + cntx \ + ); \ */ \ - } \ - } \ - else if ( bli_is_lower( uplox_eff ) ) \ - { \ - for ( j = 0; j < n_iter; ++j ) \ - { \ - i = bli_max( 0, ( doff_t )j - ( doff_t )n_shift ); \ - n_elem = n_elem_max - i; \ -\ - x1 = x + (j )*ldx + (ij0+i )*incx; \ - x2 = x1 + incx; \ - chi1 = x1; \ -\ - /*PASTEMAC2(ch,kername,BLIS_TAPI_EX_SUF)*/ \ - PASTEMAC(ch,kername) \ - ( \ - n_elem, \ - x1, incx, \ - cntx, \ - rntm \ - ); \ -\ - ( void )x2; \ - ( void )chi1; \ - /* We want positive diagonal elements between 1 and 2. */ \ + } \ + } \ + else if ( bli_is_lower( uplox_eff ) ) \ + { \ + for ( j = 0; j < n_iter; ++j ) \ + { \ + i = bli_max( 0, ( doff_t )j - ( doff_t )n_shift ); \ + n_elem = n_elem_max - i; \ +\ + x1 = x + (j )*ldx + (ij0+i )*incx; \ + x2 = x1 + incx; \ + chi1 = x1; \ +\ + /*PASTEMAC2(ch,kername,BLIS_TAPI_EX_SUF)*/ \ + PASTEMAC(ch,kername) \ + ( \ + n_elem, \ + x1, incx, \ + cntx, \ + rntm \ + ); \ +\ + ( void )x2; \ + ( void )chi1; \ + /* We want positive diagonal elements between 1 and 2. */ \ /* - PASTEMAC(ch,abval2s)( *chi1, *chi1 ); \ - PASTEMAC(ch,adds)( *one, *chi1 ); \ + PASTEMAC(ch,abval2s)( *chi1, *chi1 ); \ + PASTEMAC(ch,adds)( *one, *chi1 ); \ */ \ \ - /* Scale the sub-diagonal elements by 1/max(m,n). */ \ + /* Scale the sub-diagonal elements by 1/max(m,n). */ \ /* - PASTEMAC(ch,scalv) \ - ( \ - BLIS_NO_CONJUGATE, \ - n_elem - 1, \ - &beta, \ - x2, incx, \ - cntx \ - ); \ + PASTEMAC(ch,scalv) \ + ( \ + BLIS_NO_CONJUGATE, \ + n_elem - 1, \ + &beta, \ + x2, incx, \ + cntx \ + ); \ */ \ - } \ - } \ - } \ + } \ + } \ + } \ } INSERT_GENTFUNC_BASIC( randm_unb_var1, randv_unb_var1 ) @@ -1138,79 +1164,79 @@ void PASTEMAC(ch,varname) \ rntm_t* rntm \ ) \ { \ - const ctype_r zero_r = *PASTEMAC(chr,0); \ - const ctype_r one_r = *PASTEMAC(chr,1); \ -\ - ctype* chi1; \ - ctype_r chi1_r; \ - ctype_r chi1_i; \ - ctype_r scale_r; \ - ctype_r sumsq_r; \ - ctype_r abs_chi1_r; \ - dim_t i; \ -\ - /* NOTE: This function attempts to mimic the algorithm for computing - the Frobenius norm in netlib LAPACK's ?lassq(). */ \ -\ - /* Copy scale and sumsq to local variables. 
*/ \ - PASTEMAC(chr,copys)( *scale, scale_r ); \ - PASTEMAC(chr,copys)( *sumsq, sumsq_r ); \ -\ - chi1 = x; \ -\ - for ( i = 0; i < n; ++i ) \ - { \ - /* Get the real and imaginary components of chi1. */ \ - PASTEMAC2(ch,chr,gets)( *chi1, chi1_r, chi1_i ); \ -\ - abs_chi1_r = bli_fabs( chi1_r ); \ -\ - /* Accumulate real component into sumsq, adjusting scale if - needed. */ \ - if ( abs_chi1_r > zero_r || bli_isnan( abs_chi1_r) ) \ - { \ - if ( scale_r < abs_chi1_r ) \ - { \ - sumsq_r = one_r + \ - sumsq_r * ( scale_r / abs_chi1_r ) * \ - ( scale_r / abs_chi1_r ); \ -\ - PASTEMAC(chr,copys)( abs_chi1_r, scale_r ); \ - } \ - else \ - { \ - sumsq_r = sumsq_r + ( abs_chi1_r / scale_r ) * \ - ( abs_chi1_r / scale_r ); \ - } \ - } \ -\ - abs_chi1_r = bli_fabs( chi1_i ); \ -\ - /* Accumulate imaginary component into sumsq, adjusting scale if - needed. */ \ - if ( abs_chi1_r > zero_r || bli_isnan( abs_chi1_r) ) \ - { \ - if ( scale_r < abs_chi1_r ) \ - { \ - sumsq_r = one_r + \ - sumsq_r * ( scale_r / abs_chi1_r ) * \ - ( scale_r / abs_chi1_r ); \ -\ - PASTEMAC(chr,copys)( abs_chi1_r, scale_r ); \ - } \ - else \ - { \ - sumsq_r = sumsq_r + ( abs_chi1_r / scale_r ) * \ - ( abs_chi1_r / scale_r ); \ - } \ - } \ -\ - chi1 += incx; \ - } \ -\ - /* Store final values of scale and sumsq to output variables. */ \ - PASTEMAC(chr,copys)( scale_r, *scale ); \ - PASTEMAC(chr,copys)( sumsq_r, *sumsq ); \ + const ctype_r zero_r = *PASTEMAC(chr,0); \ + const ctype_r one_r = *PASTEMAC(chr,1); \ +\ + ctype* chi1; \ + ctype_r chi1_r; \ + ctype_r chi1_i; \ + ctype_r scale_r; \ + ctype_r sumsq_r; \ + ctype_r abs_chi1_r; \ + dim_t i; \ +\ + /* NOTE: This function attempts to mimic the algorithm for computing + the Frobenius norm in netlib LAPACK's ?lassq(). */ \ +\ + /* Copy scale and sumsq to local variables. */ \ + PASTEMAC(chr,copys)( *scale, scale_r ); \ + PASTEMAC(chr,copys)( *sumsq, sumsq_r ); \ +\ + chi1 = x; \ +\ + for ( i = 0; i < n; ++i ) \ + { \ + /* Get the real and imaginary components of chi1. */ \ + PASTEMAC2(ch,chr,gets)( *chi1, chi1_r, chi1_i ); \ +\ + abs_chi1_r = bli_fabs( chi1_r ); \ +\ + /* Accumulate real component into sumsq, adjusting scale if + needed. */ \ + if ( abs_chi1_r > zero_r || bli_isnan( abs_chi1_r) ) \ + { \ + if ( scale_r < abs_chi1_r ) \ + { \ + sumsq_r = one_r + \ + sumsq_r * ( scale_r / abs_chi1_r ) * \ + ( scale_r / abs_chi1_r ); \ +\ + PASTEMAC(chr,copys)( abs_chi1_r, scale_r ); \ + } \ + else \ + { \ + sumsq_r = sumsq_r + ( abs_chi1_r / scale_r ) * \ + ( abs_chi1_r / scale_r ); \ + } \ + } \ +\ + abs_chi1_r = bli_fabs( chi1_i ); \ +\ + /* Accumulate imaginary component into sumsq, adjusting scale if + needed. */ \ + if ( abs_chi1_r > zero_r || bli_isnan( abs_chi1_r) ) \ + { \ + if ( scale_r < abs_chi1_r ) \ + { \ + sumsq_r = one_r + \ + sumsq_r * ( scale_r / abs_chi1_r ) * \ + ( scale_r / abs_chi1_r ); \ +\ + PASTEMAC(chr,copys)( abs_chi1_r, scale_r ); \ + } \ + else \ + { \ + sumsq_r = sumsq_r + ( abs_chi1_r / scale_r ) * \ + ( abs_chi1_r / scale_r ); \ + } \ + } \ +\ + chi1 += incx; \ + } \ +\ + /* Store final values of scale and sumsq to output variables. 
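The scale/sumsq update performed above for each real and imaginary component follows netlib LAPACK's ?lassq() and maintains the invariant scale*scale*sumsq == (sum of squares accumulated so far), which keeps the intermediate sum from overflowing or underflowing. One update step as a scalar sketch, illustrative only:

#include <math.h>

static void sumsq_update_sketch( double chi, double* scale, double* sumsq )
{
    double a = fabs( chi );
    if ( a > 0.0 || isnan( a ) )
    {
        if ( *scale < a )
        {
            // New largest magnitude seen: rescale the accumulated sum.
            *sumsq = 1.0 + (*sumsq) * ( *scale / a ) * ( *scale / a );
            *scale = a;
        }
        else
        {
            *sumsq += ( a / *scale ) * ( a / *scale );
        }
    }
}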
*/ \ + PASTEMAC(chr,copys)( scale_r, *scale ); \ + PASTEMAC(chr,copys)( sumsq_r, *sumsq ); \ } INSERT_GENTFUNCR_BASIC0( sumsqv_unb_var1 ) diff --git a/kernels/zen/1/CMakeLists.txt b/kernels/zen/1/CMakeLists.txt index 00ebbdeca7..669a3ba89a 100644 --- a/kernels/zen/1/CMakeLists.txt +++ b/kernels/zen/1/CMakeLists.txt @@ -13,4 +13,5 @@ target_sources("${PROJECT_NAME}" ${CMAKE_CURRENT_SOURCE_DIR}/bli_scalv_zen_int10.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_setv_zen_int.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_swapv_zen_int8.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_norm2_zen_int.c ) diff --git a/kernels/zen/1/bli_norm2_zen_int.c b/kernels/zen/1/bli_norm2_zen_int.c new file mode 100644 index 0000000000..0a0f92e36c --- /dev/null +++ b/kernels/zen/1/bli_norm2_zen_int.c @@ -0,0 +1,236 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ +#include "immintrin.h" +#include "blis.h" + +#ifdef BLIS_ENABLE_FAST_MATH +/* Union data structure to access AVX registers + One 256-bit AVX register holds 8 SP elements. */ +typedef union +{ + __m256 v; + float f[8] __attribute__((aligned(64))); +} v8sf_t; + +/* Union data structure to access AVX registers +* One 256-bit AVX register holds 4 DP elements. */ +typedef union +{ + __m256d v; + double d[4] __attribute__((aligned(64))); +} v4df_t; + +// ----------------------------------------------------------------------------- + +void bli_dnorm2fv_unb_var1 + ( + dim_t n, + double* x, inc_t incx, + double* norm, + cntx_t* cntx + ) +{ + double sumsq = 0; + double rem_sumsq = 0; /*sum of squares accumulated for n_remainder<8 cases.*/ + dim_t n_remainder = 0; + dim_t i; + /*memory pool declarations for packing vector X. + Initialize mem pool buffer to NULL and size to 0 + "buf" and "size" fields are assigned once memory + is allocated from the pool in bli_membrk_acquire_m(). 
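A scalar equivalent of what this vectorized kernel computes: it squares the elements directly and takes the square root of the plain sum, rather than using the scaled ?lassq-style update, which is presumably why it is gated behind BLIS_ENABLE_FAST_MATH. Illustrative only; the helper name is an assumption and dim_t/inc_t come from blis.h:

#include <math.h>

static double dnorm2_scalar_equiv( dim_t n, const double* x, inc_t incx )
{
    double sumsq = 0.0;
    for ( dim_t i = 0; i < n; ++i )
    {
        double xi = x[ i*incx ];
        sumsq += xi * xi;
    }
    return sqrt( sumsq );
}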
+ This will ensure bli_mem_is_alloc() will be passed on + an allocated memory if created or a NULL .*/ + mem_t mem_bufX = {0}; + rntm_t rntm; + double *x_buf = x; + + /*early return if n<=0 or incx =0 */ + if((n <= 0) || (incx == 0)) + return; + + /*packing for non-unit strided Vector X*/ + if(incx != 1) + { + /* In order to get the buffer from pool via rntm access to memory broker + is needed.Following are initializations for rntm */ + + bli_rntm_init_from_global( &rntm ); + bli_rntm_set_num_threads_only( 1, &rntm ); + bli_membrk_rntm_set_membrk( &rntm ); + + //calculate the size required for "n" double elements in vector X. + size_t buffer_size = n * sizeof(double); + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_dnorm2fv_unb_var1(): get mem pool block\n" ); + #endif + + /*acquire a Buffer(n*size(double)) from the memory broker + and save the associated mem_t entry to mem_bufX.*/ + bli_membrk_acquire_m(&rntm, + buffer_size, + BLIS_BUFFER_FOR_B_PANEL, + &mem_bufX); + + /*Continue packing X if buffer memory is allocated*/ + if ((bli_mem_is_alloc( &mem_bufX ))) + { + x_buf = bli_mem_buffer(&mem_bufX); + + /*pack X vector with non-unit stride to a temp buffer x_buf with unit stride*/ + for(dim_t x_index = 0 ; x_index < n ; x_index++) + { + *(x_buf + x_index) = *(x + (x_index * incx)) ; + } + } + } + + v4df_t x0v, x1v, x2v, x3v, x4v, x5v, x6v, x7v; + /* Initialize rho vector accumulators to zero.*/ + v4df_t rho0v; rho0v.v = _mm256_setzero_pd(); + v4df_t rho1v; rho1v.v = _mm256_setzero_pd(); + v4df_t rho2v; rho2v.v = _mm256_setzero_pd(); + v4df_t rho3v; rho3v.v = _mm256_setzero_pd(); + v4df_t rho4v; rho4v.v = _mm256_setzero_pd(); + v4df_t rho5v; rho5v.v = _mm256_setzero_pd(); + v4df_t rho6v; rho6v.v = _mm256_setzero_pd(); + v4df_t rho7v; rho7v.v = _mm256_setzero_pd(); + + double *x0 = x_buf; + + for(i = 0 ; i+31 < n ; i = i + 32) + { + + x0v.v = _mm256_loadu_pd( x0 ); + x1v.v = _mm256_loadu_pd( x0 + 4 ); + x2v.v = _mm256_loadu_pd( x0 + 8 ); + x3v.v = _mm256_loadu_pd( x0 + 12 ); + x4v.v = _mm256_loadu_pd( x0 + 16 ); + x5v.v = _mm256_loadu_pd( x0 + 20 ); + x6v.v = _mm256_loadu_pd( x0 + 24 ); + x7v.v = _mm256_loadu_pd( x0 + 28 ); + + rho0v.v = _mm256_fmadd_pd(x0v.v, x0v.v, rho0v.v); + rho1v.v = _mm256_fmadd_pd(x1v.v, x1v.v, rho1v.v); + rho2v.v = _mm256_fmadd_pd(x2v.v, x2v.v, rho2v.v); + rho3v.v = _mm256_fmadd_pd(x3v.v, x3v.v, rho3v.v); + rho4v.v = _mm256_fmadd_pd(x4v.v, x4v.v, rho4v.v); + rho5v.v = _mm256_fmadd_pd(x5v.v, x5v.v, rho5v.v); + rho6v.v = _mm256_fmadd_pd(x6v.v, x6v.v, rho6v.v); + rho7v.v = _mm256_fmadd_pd(x7v.v, x7v.v, rho7v.v); + + x0 += 32; + } + + n_remainder = n - i; + + if(n_remainder) + { + if(n_remainder >= 16) + { + x0v.v = _mm256_loadu_pd( x0 ); + x1v.v = _mm256_loadu_pd( x0 + 4 ); + x2v.v = _mm256_loadu_pd( x0 + 8 ); + x3v.v = _mm256_loadu_pd( x0 + 12 ); + + rho0v.v = _mm256_fmadd_pd(x0v.v, x0v.v, rho0v.v); + rho1v.v = _mm256_fmadd_pd(x1v.v, x1v.v, rho1v.v); + rho2v.v = _mm256_fmadd_pd(x2v.v, x2v.v, rho2v.v); + rho3v.v = _mm256_fmadd_pd(x3v.v, x3v.v, rho3v.v); + + x0 += 16; + n_remainder -= 16; + } + if(n_remainder >= 8) + { + x0v.v = _mm256_loadu_pd( x0 ); + x1v.v = _mm256_loadu_pd( x0 + 4 ); + + rho0v.v = _mm256_fmadd_pd(x0v.v, x0v.v, rho0v.v); + rho1v.v = _mm256_fmadd_pd(x1v.v, x1v.v, rho1v.v); + + x0 += 8; + n_remainder -= 8; + } + if(n_remainder >= 4) + { + x0v.v = _mm256_loadu_pd( x0 ); + + rho0v.v = _mm256_fmadd_pd(x0v.v, x0v.v, rho0v.v); + + x0 += 4; + n_remainder -= 4; + } + if(n_remainder) + { + for(i=0; i< n_remainder ;i++) + { + double x_temp = *x0; + rem_sumsq 
+= x_temp * x_temp ; + x0 += 1; + } + } + } + + /*add all the dot product of x*x into one vector .*/ + rho0v.v = _mm256_add_pd ( rho0v.v, rho1v.v ); + rho1v.v = _mm256_add_pd ( rho2v.v, rho3v.v ); + rho2v.v = _mm256_add_pd ( rho4v.v, rho5v.v ); + rho3v.v = _mm256_add_pd ( rho6v.v, rho7v.v ); + + rho4v.v = _mm256_add_pd ( rho0v.v, rho1v.v ); + rho5v.v = _mm256_add_pd ( rho2v.v, rho3v.v ); + + rho6v.v = _mm256_add_pd ( rho4v.v, rho5v.v ); + + rho7v.v = _mm256_hadd_pd( rho6v.v, rho6v.v ); + + /*rem_sumsq will have sum of squares of n_remainder < 4 cases . + Accumulate all the sum of squares to sumsq*/ + sumsq = rem_sumsq + rho7v.d[0] + rho7v.d[2]; + + PASTEMAC(d,sqrt2s)( sumsq, *norm ); + + if ((incx != 1) && bli_mem_is_alloc( &mem_bufX )) + { + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_dnorm2fv_unb_var1(): releasing mem pool block\n" ); + #endif + /* Return the buffer to pool*/ + bli_membrk_release(&rntm , &mem_bufX); + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); + return ; +} +#endif diff --git a/kernels/zen/bli_kernels_zen.h b/kernels/zen/bli_kernels_zen.h index e43742e6e1..d21eb6fe28 100644 --- a/kernels/zen/bli_kernels_zen.h +++ b/kernels/zen/bli_kernels_zen.h @@ -81,8 +81,8 @@ SCALV_KER_PROT( float, s, scalv_zen_int10 ) SCALV_KER_PROT( double, d, scalv_zen_int10 ) // swapv (intrinsics) -SWAPV_KER_PROT(float, s, swapv_zen_int8 ) -SWAPV_KER_PROT(double, d, swapv_zen_int8 ) +SWAPV_KER_PROT(float, s, swapv_zen_int8 ) +SWAPV_KER_PROT(double, d, swapv_zen_int8 ) // copyv (intrinsics) COPYV_KER_PROT( float, s, copyv_zen_int ) @@ -291,3 +291,13 @@ bool bli_cntx_syrksup_thresh_is_met_zen obj_t* c, cntx_t* cntx ); + +#ifdef BLIS_ENABLE_FAST_MATH +void bli_dnorm2fv_unb_var1 + ( + dim_t n, + double* x, inc_t incx, + double* norm, + cntx_t* cntx + ); +#endif From 3bafdf392302fb80f8ff6390c2b629ff3498ecd6 Mon Sep 17 00:00:00 2001 From: satish kumar nuggu Date: Thu, 19 Aug 2021 09:09:35 +0530 Subject: [PATCH 027/243] DGEMM kernel implementation for case k = 1. Details : - DGEMM kernel implementation for case k = 1, vectorized with 8x6 block implementation (Rank-1 update in DGEMM Optimization). 
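For reference, with k = 1 the operation below degenerates to a rank-1 update of C by the outer product of the single column of A and the single row of B, which is what the 8x6 register blocks accumulate. A scalar sketch under the same column-major layout, illustrative only and not the kernel itself (dim_t/inc_t as declared in blis.h):

static void dgemm_k1_reference( dim_t m, dim_t n,
                                double alpha, const double* a,   /* m x 1 */
                                const double* b, inc_t ldb,      /* 1 x n, stride ldb */
                                double beta, double* c, inc_t ldc )
{
    for ( dim_t j = 0; j < n; ++j )
        for ( dim_t i = 0; i < m; ++i )
            c[ i + j*ldc ] = beta * c[ i + j*ldc ]
                           + alpha * a[ i ] * b[ j*ldb ];
}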
Change-Id: I7d06378adeb8bcc5b965e2a94314d731629d0b4c --- kernels/zen/3/bli_dgemm_ref_k1.c | 1113 ++++++++++++++++++++++++++++-- 1 file changed, 1039 insertions(+), 74 deletions(-) diff --git a/kernels/zen/3/bli_dgemm_ref_k1.c b/kernels/zen/3/bli_dgemm_ref_k1.c index 8170a35ca5..659975cdb7 100644 --- a/kernels/zen/3/bli_dgemm_ref_k1.c +++ b/kernels/zen/3/bli_dgemm_ref_k1.c @@ -35,85 +35,1050 @@ #include #include "blis.h" -#define C(i,j) *(c+j*ldc+i) -#define A(i,j) *(a+j*lda+i) -#define B(i,j) *(b+j*ldb+i) +#include "immintrin.h" +#define D_MR 8 +#define D_NR 6 void bli_dgemm_ref_k1_nn ( - dim_t m, - dim_t n, - dim_t k, - double* alpha, - double* a, const inc_t lda, - double* b, const inc_t ldb, - double* beta, - double* c, const inc_t ldc + dim_t m, + dim_t n, + dim_t k, + double* alpha, + double* a, const inc_t lda, + double* b, const inc_t ldb, + double* beta, + double* c, const inc_t ldc ) { + double alpha_val, beta_val; - double alpha_val, beta_val, temp; - dim_t i, j, K; - - beta_val = *beta; - alpha_val = *alpha; - - if((m == 0) || (n == 0) || (((alpha_val == 0.0) || (k == 0)) && (beta_val == 1.0))){ - return; - } - - /* If alpha = 0 */ - - if(alpha_val == 0.0) - { - if(beta_val == 0.0) - { - for(j = 0; j < n; j++){ - for(i = 0; i < m; i++){ - C(i,j) = 0.0; - } - } - } - else - { - for(j = 0; j < n; j++){ - for(i = 0; i < m; i++){ - C(i,j) = beta_val * C(i,j); - } - } - } - return; - } - - - /* Start the operation */ - - - /* Form C = alpha*A*B + beta*c */ - if(beta_val == 0.0){ - for(j =0; j < n; j++){ - for(i = 0; i < m; i++){ - C(i,j) = 0.0; - } - } - } - else if(beta_val != 1.0){ - for(j = 0 ; j < n; j++){ - for(i = 0; i < m; i++){ - C(i,j) = beta_val * C(i,j); - } - } - } - - for(j = 0; j < n; j++){ - for(K = 0; K < k; K++) - { - temp = alpha_val * B(K,j); - for(i = 0; i < m; i++){ - C(i,j) = C(i,j) + temp*A(i,K); - } - } - } - return; + beta_val = *beta; + alpha_val = *alpha; + + if((m == 0) || (n == 0) || (((alpha_val == 0.0) || (k == 0)) && (beta_val == 1.0))){ + return; + } + + dim_t m_remainder = (m % D_MR); + dim_t n_remainder = (n % D_NR); + + //scratch registers + __m256d ymm0, ymm1, ymm2, ymm3; + __m256d ymm4, ymm5, ymm6, ymm7; + __m256d ymm8, ymm9, ymm10, ymm11; + __m256d ymm12, ymm13, ymm14, ymm15; + __m128d xmm5; + + /* Form C = alpha*A*B + beta*c */ + for(dim_t j = 0;j < (n-D_NR+1);j=j+D_NR) + { + double* temp_b = b + j*ldb; + double* temp_a = a; + double* temp_c = c + j*ldc; + + for(dim_t i = 0;i < (m-D_MR+1);i=i+D_MR) + { + ymm3 = _mm256_setzero_pd(); + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + if(alpha_val != 0.0) + { + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_val, + where alpha_val is not zero. + b. This loop operates with 8x6 block size + along n dimension for every D_NR columns of temp_b where + computing all D_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. 
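               d. Register layout used here: ymm0/ymm1 hold alpha*a[0..3] and
                  alpha*a[4..7], ymm2 re-broadcasts one element of b per
                  column, and the register pairs (ymm3,ymm4) through
                  (ymm13,ymm14) accumulate columns 0 through 5 of the
                  8x6 tile of C.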
+ */ + ymm0 = _mm256_loadu_pd((double const *)(temp_a)); //a[0][0] a[1][0] a[2][0] a[3][0] + ymm1 = _mm256_loadu_pd((double const *)(temp_a + 4)); //a[4][0] a[5][0] a[6][0] a[7][0] + _mm_prefetch((char*)( temp_a + 64), _MM_HINT_T0); + + ymm15 = _mm256_broadcast_sd((double const *)(&alpha_val)); + + ymm0 = _mm256_mul_pd(ymm0,ymm15); //ymm0 = (alpha_val*a[0][0] alpha_val*a[1][0] alpha_val*a[2][0] alpha_val*a[3][0]) + ymm1 = _mm256_mul_pd(ymm1,ymm15); //ymm1 = (alpha_val*a[4][0] alpha_val*a[5][0] alpha_val*a[6][0] alpha_val*a[7][0]) + + ymm2 = _mm256_broadcast_sd((double const *)(temp_b)); //b[0][0] + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm3 += (b[0][0]*a[0][0] b[0][0]*a[1][0] b[0][0]*a[2][0] b[0][0]*a[3][0]) + ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); //ymm4 += (b[0][0]*a[4][0] b[0][0]*a[5][0] b[0][0]*a[6][0] b[0][0]*a[7][0]) + + ymm2 = _mm256_broadcast_sd((double const *)(temp_b + ldb * 1)); //b[0][1] + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm5 += (b[0][1]*a[0][0] b[0][1]*a[1][0] b[0][1]*a[2][0] b[0][1]*a[3][0]) + ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); //ymm6 += (b[0][1]*a[4][0] b[0][1]*a[5][0] b[0][1]*a[6][0] b[0][1]*a[7][0]) + + ymm2 = _mm256_broadcast_sd((double const *)(temp_b + ldb * 2)); //b[0][2] + ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm7 += (b[0][2]*a[0][0] b[0][2]*a[1][0] b[0][2]*a[2][0] b[0][2]*a[3][0]) + ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8); //ymm8 += (b[0][2]*a[4][0] b[0][2]*a[5][0] b[0][2]*a[6][0] b[0][2]*a[7][0]) + + ymm2 = _mm256_broadcast_sd((double const *)(temp_b + ldb * 3)); //b[0][3] + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm9 += (b[0][3]*a[0][0] b[0][3]*a[1][0] b[0][3]*a[2][0] b[0][3]*a[3][0]) + ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10); //ymm10 += (b[0][3]*a[4][0] b[0][3]*a[5][0] b[0][3]*a[6][0] b[0][3]*a[7][0]) + + ymm2 = _mm256_broadcast_sd((double const *)(temp_b + ldb * 4)); //b[0][4] + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); //ymm11 += (b[0][4]*a[0][0] b[0][4]*a[1][0] b[0][4]*a[2][0] b[0][4]*a[3][0]) + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); //ymm12 += (b[0][4]*a[4][0] b[0][4]*a[5][0] b[0][4]*a[6][0] b[0][4]*a[7][0]) + + ymm2 = _mm256_broadcast_sd((double const *)(temp_b + ldb * 5)); //b[0][5] + ymm13 = _mm256_fmadd_pd(ymm2, ymm0, ymm13); //ymm13 += (b[0][5]*a[0][0] b[0][5]*a[1][0] b[0][5]*a[2][0] b[0][5]*a[3][0]) + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); //ymm14 += (b[0][5]*a[4][0] b[0][5]*a[5][0] b[0][5]*a[6][0] b[0][5]*a[7][0]) + } + + if(beta_val != 0.0) + { + /* + a. Perform beta*C using temp_c, beta, + where beta_val is not zero. + b. This loop operates with 8x6 block size + along n dimension for every D_NR columns of temp_c where + computing all D_MR rows of temp_c. + c. Accumulated alpha*A*B into registers will be added to beta*C + d. Same approach is used in remaining fringe cases. 
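               e. Note that beta*C is folded into the same accumulators with
                  fused multiply-adds (ymm15 holds the broadcast beta_val), so
                  by the time of the stores each register already holds its
                  part of alpha*A*B + beta*C and no separate addition pass
                  is needed.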
+ */ + ymm15 = _mm256_broadcast_sd((double const *)(&beta_val)); + + ymm0 = _mm256_loadu_pd((double const *)(temp_c)); //c[0][0] c[1][0] c[2][0] c[3][0] + ymm1 = _mm256_loadu_pd((double const *)(temp_c + 4)); //c[4][0] c[5][0] c[6][0] c[7][0] + + ymm3 = _mm256_fmadd_pd(ymm15, ymm0, ymm3); //ymm3 += (beta_val*c[0][0] beta_val*c[1][0] beta_val*c[2][0] beta_val*c[3][0]) + ymm4 = _mm256_fmadd_pd(ymm15, ymm1, ymm4); //ymm4 += (beta_val*c[4][0] beta_val*c[5][0] beta_val*c[6][0] beta_val*c[7][0]) + + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc)); //c[0][1] c[1][1] c[2][1] c[3][1] + ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc + 4)); //c[4][1] c[5][1] c[6][1] c[7][1] + + ymm5 = _mm256_fmadd_pd(ymm15, ymm0, ymm5); //ymm5 += (beta_val*c[0][1] beta_val*c[1][1] beta_val*c[2][1] beta_val*c[3][1]) + ymm6 = _mm256_fmadd_pd(ymm15, ymm1, ymm6); //ymm6 += (beta_val*c[4][1] beta_val*c[5][1] beta_val*c[6][1] beta_val*c[7][1]) + + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*2)); //c[0][2] c[1][2] c[2][2] c[3][2] + ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc*2 + 4)); //c[4][2] c[5][2] c[6][2] c[7][2] + + ymm7 = _mm256_fmadd_pd(ymm15, ymm0, ymm7); //ymm7 += (beta_val*c[0][2] beta_val*c[1][2] beta_val*c[2][2] beta_val*c[3][2]) + ymm8 = _mm256_fmadd_pd(ymm15, ymm1, ymm8); //ymm8 += (beta_val*c[4][2] beta_val*c[5][2] beta_val*c[6][2] beta_val*c[7][2]) + + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*3)); //c[0][3] c[1][3] c[2][3] c[3][3] + ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc*3 + 4)); //c[4][3] c[5][3] c[6][3] c[7][3] + + ymm9 = _mm256_fmadd_pd(ymm15, ymm0, ymm9); //ymm9 += (beta_val*c[0][3] beta_val*c[1][3] beta_val*c[2][3] beta_val*c[3][3]) + ymm10 = _mm256_fmadd_pd(ymm15, ymm1, ymm10); //ymm10 += (beta_val*c[4][3] beta_val*c[5][3] beta_val*c[6][3] beta_val*c[7][3]) + + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*4)); //c[0][4] c[1][4] c[2][4] c[3][4] + ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc*4 + 4)); //c[4][4] c[5][4] c[6][4] c[7][4] + + ymm11 = _mm256_fmadd_pd(ymm15, ymm0, ymm11); //ymm11 += (beta_val*c[0][4] beta_val*c[1][4] beta_val*c[2][4] beta_val*c[3][4]) + ymm12 = _mm256_fmadd_pd(ymm15, ymm1, ymm12); //ymm12 += (beta_val*c[4][4] beta_val*c[5][4] beta_val*c[6][4] beta_val*c[7][4]) + + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*5)); //c[0][5] c[1][5] c[2][5] c[3][5] + ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc*5 + 4)); //c[4][5] c[5][5] c[6][5] c[7][5] + + ymm13 = _mm256_fmadd_pd(ymm15, ymm0, ymm13); //ymm13 += (beta_val*c[0][5] beta_val*c[1][5] beta_val*c[2][5] beta_val*c[3][5]) + ymm14 = _mm256_fmadd_pd(ymm15, ymm1, ymm14); //ymm14 += (beta_val*c[4][5] beta_val*c[5][5] beta_val*c[6][5] beta_val*c[7][5]) + } + + /* + a. If both alpha_val & beta_val are zeros, + C matix will be filled with all zeros. + b. If only alpha_val is zero, + accumulated alpha*A*B will be stored into C. + c. If only beta_val is zero, + accumulated beta*C will be stored into C. + d. If both alpha_val & beta_val are not zeros, + accumulated alpha*A*B + beta*C will be stored into C. + e. Same approach is used in remaining fringe cases. 
+ */ + + _mm256_storeu_pd((double *)(temp_c), ymm3); //c[0][0] c[1][0] c[2][0] c[3][0] + _mm256_storeu_pd((double *)(temp_c + 4), ymm4); //c[4][0] c[5][0] c[6][0] c[7][0] + + _mm256_storeu_pd((double *)(temp_c + ldc), ymm5); //c[0][1] c[1][1] c[2][1] c[3][1] + _mm256_storeu_pd((double *)(temp_c + ldc + 4), ymm6); //c[4][1] c[5][1] c[6][1] c[7][1] + + _mm256_storeu_pd((double *)(temp_c + ldc*2), ymm7); //c[0][2] c[1][2] c[2][2] c[3][2] + _mm256_storeu_pd((double *)(temp_c + ldc*2 + 4), ymm8); //c[4][2] c[5][2] c[6][2] c[7][2] + + _mm256_storeu_pd((double *)(temp_c + ldc*3), ymm9); //c[0][3] c[1][3] c[2][3] c[3][3] + _mm256_storeu_pd((double *)(temp_c + ldc*3 +4), ymm10); //c[4][3] c[5][3] c[6][3] c[7][3] + + _mm256_storeu_pd((double *)(temp_c + ldc*4), ymm11); //c[0][4] c[1][4] c[2][4] c[3][4] + _mm256_storeu_pd((double *)(temp_c + ldc*4 + 4), ymm12); //c[4][4] c[5][4] c[6][4] c[7][4] + + _mm256_storeu_pd((double *)(temp_c + ldc*5), ymm13); //c[0][5] c[1][5] c[2][5] c[3][5] + _mm256_storeu_pd((double *)(temp_c + ldc*5 + 4), ymm14); //c[4][5] c[5][5] c[6][5] c[7][5] + + temp_c += D_MR; + temp_a += D_MR; + } + + dim_t m_rem = m_remainder; + if(m_remainder >= 4) + { + ymm3 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + if(alpha_val != 0.0) + { + ymm0 = _mm256_loadu_pd((double const *)(temp_a)); //a[0][0] a[1][0] a[2][0] a[3][0] + + ymm15 = _mm256_broadcast_sd((double const *)(&alpha_val)); + ymm0 = _mm256_mul_pd(ymm0,ymm15); //ymm0 = (alpha_val*a[0][0] alpha_val*a[1][0] alpha_val*a[2][0] alpha_val*a[3][0]) + + ymm2 = _mm256_broadcast_sd((double const *)(temp_b)); //b[0][0] + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm3 += (b[0][0]*a[0][0] b[0][0]*a[1][0] b[0][0]*a[2][0] b[0][0]*a[3][0] + + ymm2 = _mm256_broadcast_sd((double const *)(temp_b + ldb * 1)); //b[0][1] + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm5 += (b[0][1]*a[0][0] b[0][1]*a[1][0] b[0][1]*a[2][0] b[0][1]*a[3][0]) + + ymm2 = _mm256_broadcast_sd((double const *)(temp_b + ldb * 2)); //b[0][2] + ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm7 += (b[0][2]*a[0][0] b[0][2]*a[1][0] b[0][2]*a[2][0] b[0][2]*a[3][0]) + + ymm2 = _mm256_broadcast_sd((double const *)(temp_b + ldb * 3)); //b[0][3] + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm9 += (b[0][3]*a[0][0] b[0][3]*a[1][0] b[0][3]*a[2][0] b[0][3]*a[3][0]) + + ymm2 = _mm256_broadcast_sd((double const *)(temp_b + ldb * 4)); //b[0][4] + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); //ymm11 += (b[0][4]*a[0][0] b[0][4]*a[1][0] b[0][4]*a[2][0] b[0][4]*a[3][0]) + + ymm2 = _mm256_broadcast_sd((double const *)(temp_b + ldb * 5)); //b[0][5] + ymm13 = _mm256_fmadd_pd(ymm2, ymm0, ymm13); //ymm13 += (b[0][5]*a[0][0] b[0][5]*a[1][0] b[0][5]*a[2][0] b[0][5]*a[3][0]) + } + + if(beta_val != 0.0) + { + ymm15 = _mm256_broadcast_sd((double const *)(&beta_val)); + + ymm0 = _mm256_loadu_pd((double const *)(temp_c)); //c[0][0] c[1][0] c[2][0] c[3][0] + ymm3 = _mm256_fmadd_pd(ymm15, ymm0, ymm3); //ymm3 += (beta_val*c[0][0] beta_val*c[1][0] beta_val*c[2][0] beta_val*c[3][0]) + + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc)); //c[0][1] c[1][1] c[2][1] c[3][1] + ymm5 = _mm256_fmadd_pd(ymm15, ymm0, ymm5); //ymm5 += (beta_val*c[0][1] beta_val*c[1][1] beta_val*c[2][1] beta_val*c[3][1]) + + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*2)); //c[0][2] c[1][2] c[2][2] c[3][2] + ymm7 = _mm256_fmadd_pd(ymm15, ymm0, ymm7); //ymm7 += 
(beta_val*c[0][2] beta_val*c[1][2] beta_val*c[2][2] beta_val*c[3][2]) + + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*3)); //c[0][3] c[1][3] c[2][3] c[3][3] + ymm9 = _mm256_fmadd_pd(ymm15, ymm0, ymm9); //ymm9 += (beta_val*c[0][3] beta_val*c[1][3] beta_val*c[2][3] beta_val*c[3][3]) + + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*4)); //c[0][4] c[1][4] c[2][4] c[3][4] + ymm11 = _mm256_fmadd_pd(ymm15, ymm0, ymm11); //ymm11 += (beta_val*c[0][4] beta_val*c[1][4] beta_val*c[2][4] beta_val*c[3][4]) + + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*5)); //c[0][5] c[1][5] c[2][5] c[3][5] + ymm13 = _mm256_fmadd_pd(ymm15, ymm0, ymm13); //ymm13 += (beta_val*c[0][5] beta_val*c[1][5] beta_val*c[2][5] beta_val*c[3][5]) + + } + _mm256_storeu_pd((double *)(temp_c), ymm3); //c[0][0] c[1][0] c[2][0] c[3][0] + _mm256_storeu_pd((double *)(temp_c + ldc), ymm5); //c[0][1] c[1][1] c[2][1] c[3][1] + _mm256_storeu_pd((double *)(temp_c + ldc*2), ymm7); //c[0][2] c[1][2] c[2][2] c[3][2] + _mm256_storeu_pd((double *)(temp_c + ldc*3), ymm9); //c[0][3] c[1][3] c[2][3] c[3][3] + _mm256_storeu_pd((double *)(temp_c + ldc*4), ymm11); //c[0][4] c[1][4] c[2][4] c[3][4] + _mm256_storeu_pd((double *)(temp_c + ldc*5), ymm13); //c[0][5] c[1][5] c[2][5] c[3][5] + + temp_c += 4; + temp_a += 4; + m_rem = m_remainder - 4; + } + + if(m_rem >= 2) + { + ymm3 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + if(alpha_val != 0.0) + { + __m128d xmm5; + xmm5 = _mm_loadu_pd((double const*)(temp_a)); //a[0][0] a[1][0] + ymm0 = _mm256_broadcast_sd((double const*)(temp_a)); //a[0][0] a[0][0] a[0][0] a[0][0] + ymm0 = _mm256_insertf128_pd(ymm1, xmm5, 0); //a[0][0] a[1][0] a[0][0] a[1][0] + + ymm15 = _mm256_broadcast_sd((double const *)(&alpha_val)); + + ymm0 = _mm256_mul_pd(ymm0,ymm15); //ymm0 = (alpha_val*a[0][0] alpha_val*a[1][0] alpha_val*a[0][0] alpha_val*a[1][0]) + + ymm2 = _mm256_broadcast_sd((double const *)(temp_b)); //b[0][0] + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm3 += (b[0][0]*a[0][0] b[0][0]*a[1][0] b[0][0]*a[0][0] b[0][0]*a[1][0] + + ymm2 = _mm256_broadcast_sd((double const *)(temp_b + ldb * 1)); //b[0][1] + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm5 += (b[0][1]*a[0][0] b[0][1]*a[1][0] b[0][1]*a[0][0] b[0][1]*a[1][0] + + ymm2 = _mm256_broadcast_sd((double const *)(temp_b + ldb * 2)); //b[0][2] + ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm7 += (b[0][2]*a[0][0] b[0][2]*a[1][0] b[0][2]*a[0][0] b[0][2]*a[1][0] + + ymm2 = _mm256_broadcast_sd((double const *)(temp_b + ldb * 3)); //b[0][3] + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm9 += (b[0][3]*a[0][0] b[0][3]*a[1][0] b[0][3]*a[0][0] b[0][3]*a[1][0] + + ymm2 = _mm256_broadcast_sd((double const *)(temp_b + ldb * 4)); //b[0][4] + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11);//ymm11 += (b[0][4]*a[0][0] b[0][4]*a[1][0] b[0][4]*a[0][0] b[0][4]*a[1][0] + + ymm2 = _mm256_broadcast_sd((double const *)(temp_b + ldb * 5)); //b[0][5] + ymm13 = _mm256_fmadd_pd(ymm2, ymm0, ymm13);//ymm13 += (b[0][5]*a[0][0] b[0][5]*a[1][0] b[0][5]*a[0][0] b[0][5]*a[1][0] + } + + if(beta_val != 0.0) + { + ymm15 = _mm256_broadcast_sd((double const *)(&beta_val)); + + xmm5 = _mm_loadu_pd((double const*)(temp_c)); //c[0][0] c[1][0] + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); //c[0][0] c[1][0] c[0][0] c[1][0] + + ymm3 = _mm256_fmadd_pd(ymm15, ymm0, ymm3); //ymm3 += (beta_val*c[0][0] beta_val*c[1][0] beta_val*c[0][0] 
beta_val*c[1][0]) + + xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc)); //c[0][1] c[1][1] + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); //c[0][1] c[1][1] c[0][1] c[1][1] + + ymm5 = _mm256_fmadd_pd(ymm15, ymm0, ymm5); //ymm5 += (beta_val*c[0][1] beta_val*c[1][1] beta_val*c[0][1] beta_val*c[1][1]) + + xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc*2)); //c[0][2] c[1][2] + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); //c[0][2] c[1][2] c[0][2] c[1][2] + + ymm7 = _mm256_fmadd_pd(ymm15, ymm0, ymm7); //ymm7 += (beta_val*c[0][2] beta_val*c[1][2] beta_val*c[0][2] beta_val*c[1][2]) + + xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc*3)); //c[0][3] c[1][3] + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); //c[0][3] c[1][3] c[0][3] c[1][3] + + ymm9 = _mm256_fmadd_pd(ymm15, ymm0, ymm9); //ymm7 += (beta_val*c[0][3] beta_val*c[1][3] beta_val*c[0][3] beta_val*c[1][3]) + + xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc*4)); //c[0][4] c[1][4] + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); //c[0][4] c[1][4] c[0][4] c[1][4] + + ymm11 = _mm256_fmadd_pd(ymm15, ymm0, ymm11); //ymm11 += (beta_val*c[0][4] beta_val*c[1][4] beta_val*c[0][4] beta_val*c[1][4]) + + xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc*5)); //c[0][5] c[1][5] + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); //c[0][5] c[1][5] c[0][5] c[1][5] + + ymm13 = _mm256_fmadd_pd(ymm15, ymm0, ymm13); //ymm13 += (beta_val*c[0][5] beta_val*c[1][5] beta_val*c[0][5] beta_val*c[1][5]) + } + + xmm5 = _mm256_extractf128_pd(ymm3, 0); // xmm5 = ymm3[0] ymm3[1] + _mm_storeu_pd((double *)(temp_c), xmm5); //c[0][0] c[1][0] + + xmm5 = _mm256_extractf128_pd(ymm5, 0); // xmm5 = ymm5[0] ymm5[1] + _mm_storeu_pd((double *)(temp_c + ldc), xmm5); //c[0][1] c[1][1] + + xmm5 = _mm256_extractf128_pd(ymm7, 0); // xmm5 = ymm7[0] ymm7[1] + _mm_storeu_pd((double *)(temp_c + ldc*2), xmm5); //c[0][2] c[1][2] + + xmm5 = _mm256_extractf128_pd(ymm9, 0); // xmm5 = ymm9[0] ymm9[1] + _mm_storeu_pd((double *)(temp_c + ldc*3), xmm5); //c[0][3] c[1][3] + + xmm5 = _mm256_extractf128_pd(ymm11, 0); // xmm5 = ymm11[0] ymm11[1] + _mm_storeu_pd((double *)(temp_c + ldc*4), xmm5); //c[0][4] c[1][4] + + xmm5 = _mm256_extractf128_pd(ymm13, 0); // xmm5 = ymm13[0] ymm13[1] + _mm_storeu_pd((double *)(temp_c + ldc*5), xmm5); //c[0][5] c[1][5] + + temp_c += 2; + temp_a += 2; + m_rem = m_rem - 2; + } + + if(m_rem == 1) + { + ymm3 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm13 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + if(alpha_val != 0.0) + { + ymm0 = _mm256_broadcast_sd((double const *)(temp_a)); //a[0][0] a[0][0] a[0][0] a[0][0] + + ymm15 = _mm256_broadcast_sd((double const *)(&alpha_val)); + ymm0 = _mm256_mul_pd(ymm0,ymm15); //ymm0 = (alpha_val*a[0][0] alpha_val*a[0][0] alpha_val*a[0][0] alpha_val*a[0][0]) + + ymm2 = _mm256_broadcast_sd((double const *)(temp_b)); //b[0][0] + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm3 += (b[0][0]*a[0][0] b[0][0]*a[0][0] b[0][0]*a[0][0] b[0][0]*a[0][0] + + ymm2 = _mm256_broadcast_sd((double const *)(temp_b + ldb * 1)); // + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm5 += (b[0][1]*a[0][0] b[0][1]*a[0][0] b[0][1]*a[0][0] b[0][1]*a[0][0] + + ymm2 = _mm256_broadcast_sd((double const *)(temp_b + ldb * 2)); //b[0][2] + ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm7 += (b[0][2]*a[0][0] b[0][2]*a[0][0] b[0][2]*a[0][0] b[0][2]*a[0][0] + + ymm2 = _mm256_broadcast_sd((double const *)(temp_b + ldb * 3)); //b[0][3] + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); 
//ymm9 += (b[0][3]*a[0][0] b[0][3]*a[0][0] b[0][3]*a[0][0] b[0][3]*a[0][0] + + ymm2 = _mm256_broadcast_sd((double const *)(temp_b + ldb * 4)); //b[0][4] + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); //ymm11 += (b[0][4]*a[0][0] b[0][4]*a[0][0] b[0][4]*a[0][0] b[0][4]*a[0][0] + + ymm2 = _mm256_broadcast_sd((double const *)(temp_b + ldb * 5)); //b[0][5] + ymm13 = _mm256_fmadd_pd(ymm2, ymm0, ymm13); //ymm13 += (b[0][5]*a[0][0] b[0][5]*a[0][0] b[0][5]*a[0][0] b[0][5]*a[0][0] + } + + if(beta_val != 0.0) + { + ymm15 = _mm256_broadcast_sd((double const *)(&beta_val)); + + ymm0 = _mm256_broadcast_sd((double const *)(temp_c)); //c[0][0] c[0][0] c[0][0] c[0][0] + ymm3 = _mm256_fmadd_pd(ymm15, ymm0, ymm3); //ymm3 += (beta_val*c[0][0] beta_val*c[0][0] beta_val*c[0][0] beta_val*c[0][0]) + + ymm0 = _mm256_broadcast_sd((double const *)(temp_c + ldc)); //c[0][1] c[0][1] c[0][1] c[0][1] + ymm5 = _mm256_fmadd_pd(ymm15, ymm0, ymm5); //ymm5 += (beta_val*c[0][1] beta_val*c[0][1] beta_val*c[0][1] beta_val*c[0][1]) + + ymm0 = _mm256_broadcast_sd((double const *)(temp_c + ldc*2)); //c[0][2] c[0][2] c[0][2] c[0][2] + ymm7 = _mm256_fmadd_pd(ymm15, ymm0, ymm7); //ymm7 += (beta_val*c[0][2] beta_val*c[0][2] beta_val*c[0][2] beta_val*c[0][2]) + + ymm0 = _mm256_broadcast_sd((double const *)(temp_c + ldc*3)); //c[0][3] c[0][3] c[0][3] c[0][3] + ymm9 = _mm256_fmadd_pd(ymm15, ymm0, ymm9); //ymm9 += (beta_val*c[0][3] beta_val*c[0][3] beta_val*c[0][3] beta_val*c[0][3]) + + ymm0 = _mm256_broadcast_sd((double const *)(temp_c + ldc*4)); //c[0][4] c[0][4] c[0][4] c[0][4] + ymm11 = _mm256_fmadd_pd(ymm15, ymm0, ymm11); //ymm11 += (beta_val*c[0][4] beta_val*c[0][4] beta_val*c[0][4] beta_val*c[0][4]) + + ymm0 = _mm256_broadcast_sd((double const *)(temp_c + ldc*5)); //c[0][5] c[0][5] c[0][5] c[0][5] + ymm13 = _mm256_fmadd_pd(ymm15, ymm0, ymm13); //ymm13 += (beta_val*c[0][5] beta_val*c[0][5] beta_val*c[0][5] beta_val*c[0][5]) + } + ymm0 = _mm256_blend_pd(ymm3, ymm0, 0x0E); // ymm0 = ymm3[0] ymm0[1] ymm0[2] ymm0[3] + _mm_storel_pd((temp_c), _mm256_extractf128_pd(ymm0, 0)); //c[0][0] + + ymm0 = _mm256_blend_pd(ymm5, ymm0, 0x0E); // ymm0 = ymm5[0] ymm0[1] ymm0[2] ymm0[3] + _mm_storel_pd((temp_c + ldc), _mm256_extractf128_pd(ymm0, 0)); //c[0][1] + + ymm0 = _mm256_blend_pd(ymm7, ymm0, 0x0E); // ymm0 = ymm7[0] ymm0[1] ymm0[2] ymm0[3] + _mm_storel_pd((temp_c + ldc*2), _mm256_extractf128_pd(ymm0, 0)); //c[0][2] + + ymm0 = _mm256_blend_pd(ymm9, ymm0, 0x0E); // ymm0 = ymm9[0] ymm0[1] ymm0[2] ymm0[3] + _mm_storel_pd((temp_c + ldc*3), _mm256_extractf128_pd(ymm0, 0)); //c[0][3] + + ymm0 = _mm256_blend_pd(ymm11, ymm0, 0x0E); // ymm0 = ymm11[0] ymm0[1] ymm0[2] ymm0[3] + _mm_storel_pd((temp_c + ldc*4), _mm256_extractf128_pd(ymm0, 0)); //c[0][4] + + ymm0 = _mm256_blend_pd(ymm13, ymm0, 0x0E); // ymm0 = ymm13[0] ymm0[1] ymm0[2] ymm0[3] + _mm_storel_pd((temp_c + ldc*5), _mm256_extractf128_pd(ymm0, 0)); //c[0][5] + } + } + + if(n_remainder >=4) + { + double* temp_b = b + (n - n_remainder)*ldb; + double* temp_a = a; + double* temp_c = c + (n - n_remainder)*ldc; + + for(dim_t i = 0;i < (m-D_MR+1);i=i+D_MR) + { + ymm3 = _mm256_setzero_pd(); + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + if(alpha_val != 0.0) + { + ymm0 = _mm256_loadu_pd((double const *)(temp_a)); //a[0][0] a[1][0] a[2][0] a[3][0] + ymm1 = _mm256_loadu_pd((double const *)(temp_a + 4)); //a[4][0] 
a[5][0] a[6][0] a[7][0] + + ymm15 = _mm256_broadcast_sd((double const *)(&alpha_val)); + + ymm0 = _mm256_mul_pd(ymm0,ymm15); //ymm0 = (alpha_val*a[0][0] alpha_val*a[1][0] alpha_val*a[2][0] alpha_val*a[3][0]) + ymm1 = _mm256_mul_pd(ymm1,ymm15); //ymm1 = (alpha_val*a[4][0] alpha_val*a[5][0] alpha_val*a[6][0] alpha_val*a[7][0]) + + ymm2 = _mm256_broadcast_sd((double const *)(temp_b)); //b[0][0] + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm3 += (b[0][0]*a[0][0] b[0][0]*a[1][0] b[0][0]*a[2][0] b[0][0]*a[3][0]) + ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); //ymm4 += (b[0][0]*a[4][0] b[0][0]*a[5][0] b[0][0]*a[6][0] b[0][0]*a[7][0]) + + ymm2 = _mm256_broadcast_sd((double const *)(temp_b + ldb * 1)); //b[0][1] + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm5 += (b[0][1]*a[0][0] b[0][1]*a[1][0] b[0][1]*a[2][0] b[0][1]*a[3][0]) + ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); //ymm6 += (b[0][1]*a[4][0] b[0][1]*a[5][0] b[0][1]*a[6][0] b[0][1]*a[7][0]) + + ymm2 = _mm256_broadcast_sd((double const *)(temp_b + ldb * 2)); //b[0][2] + ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm7 += (b[0][2]*a[0][0] b[0][2]*a[1][0] b[0][2]*a[2][0] b[0][2]*a[3][0]) + ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8); //ymm8 += (b[0][2]*a[4][0] b[0][2]*a[5][0] b[0][2]*a[6][0] b[0][2]*a[7][0]) + + ymm2 = _mm256_broadcast_sd((double const *)(temp_b + ldb * 3)); //b[0][3] + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm9 += (b[0][3]*a[0][0] b[0][3]*a[1][0] b[0][3]*a[2][0] b[0][3]*a[3][0]) + ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10); //ymm10 += (b[0][3]*a[4][0] b[0][3]*a[5][0] b[0][3]*a[6][0] b[0][3]*a[7][0]) + } + + if(beta_val != 0.0) + { + ymm15 = _mm256_broadcast_sd((double const *)(&beta_val)); + + ymm0 = _mm256_loadu_pd((double const *)(temp_c)); //c[0][0] c[1][0] c[2][0] c[3][0] + ymm1 = _mm256_loadu_pd((double const *)(temp_c + 4)); //c[4][0] c[5][0] c[6][0] c[7][0] + + ymm3 = _mm256_fmadd_pd(ymm15, ymm0, ymm3); //ymm3 += (beta_val*c[0][0] beta_val*c[1][0] beta_val*c[2][0] beta_val*c[3][0]) + ymm4 = _mm256_fmadd_pd(ymm15, ymm1, ymm4); //ymm4 += (beta_val*c[4][0] beta_val*c[5][0] beta_val*c[6][0] beta_val*c[7][0]) + + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc)); //c[0][1] c[1][1] c[2][1] c[3][1] + ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc + 4));//c[4][1] c[5][1] c[6][1] c[7][1] + + ymm5 = _mm256_fmadd_pd(ymm15, ymm0, ymm5); //ymm5 += (beta_val*c[0][1] beta_val*c[1][1] beta_val*c[2][1] beta_val*c[3][1]) + ymm6 = _mm256_fmadd_pd(ymm15, ymm1, ymm6); //ymm6 += (beta_val*c[4][1] beta_val*c[5][1] beta_val*c[6][1] beta_val*c[7][1]) + + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*2)); //c[0][2] c[1][2] c[2][2] c[3][2] + ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc*2 + 4));//c[4][2] c[5][2] c[6][2] c[7][2] + + ymm7 = _mm256_fmadd_pd(ymm15, ymm0, ymm7); //ymm7 += (beta_val*c[0][2] beta_val*c[1][2] beta_val*c[2][2] beta_val*c[3][2]) + ymm8 = _mm256_fmadd_pd(ymm15, ymm1, ymm8); //ymm8 += (beta_val*c[4][2] beta_val*c[5][2] beta_val*c[6][2] beta_val*c[7][2]) + + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*3)); //c[0][3] c[1][3] c[2][3] c[3][3] + ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc*3 + 4));//c[4][3] c[5][3] c[6][3] c[7][3] + + ymm9 = _mm256_fmadd_pd(ymm15, ymm0, ymm9); //ymm9 += (beta_val*c[0][3] beta_val*c[1][3] beta_val*c[2][3] beta_val*c[3][3]) + ymm10 = _mm256_fmadd_pd(ymm15, ymm1, ymm10); //ymm10 += (beta_val*c[4][3] beta_val*c[5][3] beta_val*c[6][3] beta_val*c[7][3]) + } + + _mm256_storeu_pd((double *)(temp_c), ymm3); //c[0][0] c[1][0] c[2][0] c[3][0] + 
_mm256_storeu_pd((double *)(temp_c + 4), ymm4); //c[4][0] c[5][0] c[6][0] c[7][0] + + _mm256_storeu_pd((double *)(temp_c + ldc), ymm5); //c[0][1] c[1][1] c[2][1] c[3][1] + _mm256_storeu_pd((double *)(temp_c + ldc + 4), ymm6); //c[4][1] c[5][1] c[6][1] c[7][1] + + _mm256_storeu_pd((double *)(temp_c + ldc*2), ymm7); //c[0][2] c[1][2] c[2][2] c[3][2] + _mm256_storeu_pd((double *)(temp_c + ldc*2 + 4), ymm8); //c[4][2] c[5][2] c[6][2] c[7][2] + + _mm256_storeu_pd((double *)(temp_c + ldc*3), ymm9); //c[0][3] c[1][3] c[2][3] c[3][3] + _mm256_storeu_pd((double *)(temp_c + ldc*3 +4), ymm10); //c[4][3] c[5][3] c[6][3] c[7][3] + + temp_c += D_MR; + temp_a += D_MR; + } + + dim_t m_rem = m_remainder; + if(m_remainder >= 4) + { + ymm3 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + if(alpha_val != 0.0) + { + ymm0 = _mm256_loadu_pd((double const *)(temp_a)); //a[0][0] a[1][0] a[2][0] a[3][0] + + ymm15 = _mm256_broadcast_sd((double const *)(&alpha_val)); + ymm0 = _mm256_mul_pd(ymm0,ymm15); //ymm0 = (alpha_val*a[0][0] alpha_val*a[1][0] alpha_val*a[2][0] alpha_val*a[3][0]) + + ymm2 = _mm256_broadcast_sd((double const *)(temp_b)); //b[0][0] + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm3 += (b[0][0]*a[0][0] b[0][0]*a[1][0] b[0][0]*a[2][0] b[0][0]*a[3][0]) + + ymm2 = _mm256_broadcast_sd((double const *)(temp_b + ldb * 1)); //b[0][1] + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm5 += (b[0][1]*a[0][0] b[0][1]*a[1][0] b[0][1]*a[2][0] b[0][1]*a[3][0]) + + ymm2 = _mm256_broadcast_sd((double const *)(temp_b + ldb * 2)); //b[0][2] + ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm7 += (b[0][2]*a[0][0] b[0][2]*a[1][0] b[0][2]*a[2][0] b[0][2]*a[3][0]) + + ymm2 = _mm256_broadcast_sd((double const *)(temp_b + ldb * 3)); //b[0][3] + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm9 += (b[0][3]*a[0][0] b[0][3]*a[1][0] b[0][3]*a[2][0] b[0][3]*a[3][0]) + } + + if(beta_val != 0.0) + { + ymm15 = _mm256_broadcast_sd((double const *)(&beta_val)); + + ymm0 = _mm256_loadu_pd((double const *)(temp_c)); //c[0][0] c[1][0] c[2][0] c[3][0] + ymm3 = _mm256_fmadd_pd(ymm15, ymm0, ymm3); //ymm3 += (beta_val*c[0][0] beta_val*c[1][0] beta_val*c[2][0] beta_val*c[3][0]) + + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc)); //c[0][1] c[1][1] c[2][1] c[3][1] + ymm5 = _mm256_fmadd_pd(ymm15, ymm0, ymm5); //ymm5 += (beta_val*c[0][1] beta_val*c[1][1] beta_val*c[2][1] beta_val*c[3][1]) + + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*2)); //c[0][2] c[1][2] c[2][2] c[3][2] + ymm7 = _mm256_fmadd_pd(ymm15, ymm0, ymm7); //ymm7 += (beta_val*c[0][2] beta_val*c[1][2] beta_val*c[2][2] beta_val*c[3][2]) + + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*3)); //c[0][3] c[1][3] c[2][3] c[3][3] + ymm9 = _mm256_fmadd_pd(ymm15, ymm0, ymm9); //ymm9 += (beta_val*c[0][3] beta_val*c[1][3] beta_val*c[2][3] beta_val*c[3][3]) + } + _mm256_storeu_pd((double *)(temp_c), ymm3); //c[0][0] c[1][0] c[2][0] c[3][0] + _mm256_storeu_pd((double *)(temp_c + ldc), ymm5); //c[0][1] c[1][1] c[2][1] c[3][1] + _mm256_storeu_pd((double *)(temp_c + ldc*2), ymm7); //c[0][2] c[1][2] c[2][2] c[3][2] + _mm256_storeu_pd((double *)(temp_c + ldc*3), ymm9); //c[0][3] c[1][3] c[2][3] c[3][3] + + temp_c += 4; + temp_a += 4; + m_rem = m_remainder - 4; + } + + if(m_rem >= 2) + { + ymm3 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + if(alpha_val != 0.0) + { + __m128d xmm5; + xmm5 
= _mm_loadu_pd((double const*)(temp_a)); //a[0][0] a[1][0] + ymm0 = _mm256_broadcast_sd((double const*)(temp_a)); + ymm0 = _mm256_insertf128_pd(ymm1, xmm5, 0); //a[0][0] a[1][0] a[0][0] a[1][0] + + ymm15 = _mm256_broadcast_sd((double const *)(&alpha_val)); + ymm0 = _mm256_mul_pd(ymm0,ymm15); //ymm0 = (alpha_val*a[0][0] alpha_val*a[1][0] alpha_val*a[0][0] alpha_val*a[1][0]) + + ymm2 = _mm256_broadcast_sd((double const *)(temp_b)); //b[0][0] + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm3 += (b[0][0]*a[0][0] b[0][0]*a[1][0] b[0][0]*a[0][0] b[0][0]*a[1][0]) + + ymm2 = _mm256_broadcast_sd((double const *)(temp_b + ldb * 1)); //b[0][1] + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm5 += (b[0][1]*a[0][0] b[0][1]*a[1][0] b[0][1]*a[0][0] b[0][1]*a[1][0]) + + ymm2 = _mm256_broadcast_sd((double const *)(temp_b + ldb * 2)); //b[0][2] + ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm7 += (b[0][2]*a[0][0] b[0][2]*a[1][0] b[0][2]*a[0][0] b[0][2]*a[1][0]) + + ymm2 = _mm256_broadcast_sd((double const *)(temp_b + ldb * 3)); //b[0][3] + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm9 += (b[0][3]*a[0][0] b[0][3]*a[1][0] b[0][3]*a[0][0] b[0][3]*a[1][0]) + } + + if(beta_val != 0.0) + { + ymm15 = _mm256_broadcast_sd((double const *)(&beta_val)); + + xmm5 = _mm_loadu_pd((double const*)(temp_c)); //c[0][0] c[1][0] + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); //c[0][0] c[1][0] c[0][0] c[1][0] + + ymm3 = _mm256_fmadd_pd(ymm15, ymm0, ymm3); //ymm3 += (beta_val*c[0][0] beta_val*c[1][0] beta_val*c[0][0] beta_val*c[1][0]) + + xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc)); //c[0][1] c[1][1] + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); //c[0][1] c[1][1] c[0][1] c[1][1] + + ymm5 = _mm256_fmadd_pd(ymm15, ymm0, ymm5); //ymm5 += (beta_val*c[0][1] beta_val*c[1][1] beta_val*c[0][1] beta_val*c[1][1]) + + xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc*2)); //c[0][2] c[1][2] + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); //c[0][2] c[1][2] c[0][2] c[1][2] + + ymm7 = _mm256_fmadd_pd(ymm15, ymm0, ymm7); //ymm7 += (beta_val*c[0][2] beta_val*c[1][2] beta_val*c[0][2] beta_val*c[1][2]) + + xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc*3)); //c[0][3] c[1][3] + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); //c[0][3] c[1][3] c[0][3] c[1][3] + + ymm9 = _mm256_fmadd_pd(ymm15, ymm0, ymm9); //ymm9 += (beta_val*c[0][3] beta_val*c[1][3] beta_val*c[0][3] beta_val*c[1][3]) + } + + xmm5 = _mm256_extractf128_pd(ymm3, 0); // xmm5 = ymm3[0] ymm3[1] + _mm_storeu_pd((double *)(temp_c), xmm5); //c[0][0] c[1][0] + + xmm5 = _mm256_extractf128_pd(ymm5, 0); // xmm5 = ymm5[0] ymm5[1] + _mm_storeu_pd((double *)(temp_c + ldc), xmm5); //c[0][1] c[1][1] + + xmm5 = _mm256_extractf128_pd(ymm7, 0); // xmm5 = ymm7[0] ymm7[1] + _mm_storeu_pd((double *)(temp_c + ldc*2), xmm5); //c[0][2] c[1][2] + + xmm5 = _mm256_extractf128_pd(ymm9, 0); // xmm5 = ymm9[0] ymm9[1] + _mm_storeu_pd((double *)(temp_c + ldc*3), xmm5); //c[0][3] c[1][3] + + temp_c += 2; + temp_a += 2; + m_rem = m_rem - 2; + } + + if(m_rem == 1) + { + ymm3 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + if(alpha_val != 0.0) + { + ymm0 = _mm256_broadcast_sd((double const *)(temp_a)); //a[0][0] + + ymm15 = _mm256_broadcast_sd((double const *)(&alpha_val)); + ymm0 = _mm256_mul_pd(ymm0,ymm15); //ymm0 = (alpha_val*a[0][0] alpha_val*a[0][0] alpha_val*a[0][0] alpha_val*a[0][0]) + + ymm2 = _mm256_broadcast_sd((double const *)(temp_b)); //b[0][0] + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm3 += 
(b[0][0]*a[0][0] b[0][0]*a[0][0] b[0][0]*a[0][0] b[0][0]*a[0][0]) + + ymm2 = _mm256_broadcast_sd((double const *)(temp_b + ldb * 1)); //b[0][1] + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm5 += (b[0][1]*a[0][0] b[0][1]*a[0][0] b[0][1]*a[0][0] b[0][1]*a[0][0]) + + ymm2 = _mm256_broadcast_sd((double const *)(temp_b + ldb * 2)); //b[0][2] + ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm7 += (b[0][2]*a[0][0] b[0][2]*a[0][0] b[0][2]*a[0][0] b[0][2]*a[0][0]) + + ymm2 = _mm256_broadcast_sd((double const *)(temp_b + ldb * 3)); //b[0][3] + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm9 += (b[0][3]*a[0][0] b[0][3]*a[0][0] b[0][3]*a[0][0] b[0][3]*a[0][0]) + } + + if(beta_val != 0.0) + { + ymm15 = _mm256_broadcast_sd((double const *)(&beta_val)); + + ymm0 = _mm256_broadcast_sd((double const *)(temp_c)); //c[0][0] + ymm3 = _mm256_fmadd_pd(ymm15, ymm0, ymm3); //ymm3 += (beta_val*c[0][0] beta_val*c[0][0] beta_val*c[0][0] beta_val*c[0][0]) + + ymm0 = _mm256_broadcast_sd((double const *)(temp_c + ldc)); //c[0][1] + ymm5 = _mm256_fmadd_pd(ymm15, ymm0, ymm5); //ymm5 += (beta_val*c[0][1] beta_val*c[0][1] beta_val*c[0][1] beta_val*c[0][1]) + + ymm0 = _mm256_broadcast_sd((double const *)(temp_c + ldc*2)); //c[0][2] + ymm7 = _mm256_fmadd_pd(ymm15, ymm0, ymm7); //ymm7 += (beta_val*c[0][2] beta_val*c[0][2] beta_val*c[0][2] beta_val*c[0][2]) + + ymm0 = _mm256_broadcast_sd((double const *)(temp_c + ldc*3)); //c[0][3] + ymm9 = _mm256_fmadd_pd(ymm15, ymm0, ymm9); //ymm9 += (beta_val*c[0][3] beta_val*c[0][3] beta_val*c[0][3] beta_val*c[0][3]) + } + ymm0 = _mm256_blend_pd(ymm3, ymm0, 0x0E); // ymm0 = ymm3[0] ymm0[1] ymm0[2] ymm0[3] + _mm_storel_pd((temp_c), _mm256_extractf128_pd(ymm0, 0)); //c[0][0] + + ymm0 = _mm256_blend_pd(ymm5, ymm0, 0x0E); // ymm0 = ymm5[0] ymm0[1] ymm0[2] ymm0[3] + _mm_storel_pd((temp_c + ldc), _mm256_extractf128_pd(ymm0, 0)); //c[0][1] + + ymm0 = _mm256_blend_pd(ymm7, ymm0, 0x0E); // ymm0 = ymm7[0] ymm0[1] ymm0[2] ymm0[3] + _mm_storel_pd((temp_c + ldc*2), _mm256_extractf128_pd(ymm0, 0)); //c[0][2] + + ymm0 = _mm256_blend_pd(ymm9, ymm0, 0x0E); // ymm0 = ymm9[0] ymm0[1] ymm0[2] ymm0[3] + _mm_storel_pd((temp_c + ldc*3), _mm256_extractf128_pd(ymm0, 0)); //c[0][3] + } + n_remainder = n_remainder - 4; + } + + if(n_remainder >=2) + { + double* temp_b = b + (n - n_remainder)*ldb; + double* temp_a = a; + double* temp_c = c + (n - n_remainder)*ldc; + + for(dim_t i = 0;i < (m-D_MR+1);i=i+D_MR) + { + ymm3 = _mm256_setzero_pd(); + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + if(alpha_val != 0.0) + { + ymm0 = _mm256_loadu_pd((double const *)(temp_a)); //a[0][0] a[1][0] a[2][0] a[3][0] + ymm1 = _mm256_loadu_pd((double const *)(temp_a + 4)); //a[4][0] a[5][0] a[6][0] a[7][0] + + ymm15 = _mm256_broadcast_sd((double const *)(&alpha_val)); + + ymm0 = _mm256_mul_pd(ymm0,ymm15); //ymm0 = (alpha_val*a[0][0] alpha_val*a[1][0] alpha_val*a[2][0] alpha_val*a[3][0]) + ymm1 = _mm256_mul_pd(ymm1,ymm15); //ymm1 = (alpha_val*a[4][0] alpha_val*a[5][0] alpha_val*a[6][0] alpha_val*a[7][0]) + + ymm2 = _mm256_broadcast_sd((double const *)(temp_b)); //b[0][0] + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm3 += (b[0][0]*a[0][0] b[0][0]*a[1][0] b[0][0]*a[2][0] b[0][0]*a[3][0]) + ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); //ymm4 += (b[0][0]*a[4][0] b[0][0]*a[5][0] b[0][0]*a[6][0] b[0][0]*a[7][0]) + + ymm2 = _mm256_broadcast_sd((double const *)(temp_b + ldb)); //b[0][1] + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm5 += (b[0][1]*a[0][0] 
b[0][1]*a[1][0] b[0][1]*a[2][0] b[0][1]*a[3][0]) + ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); //ymm6 += (b[0][1]*a[4][0] b[0][1]*a[5][0] b[0][1]*a[6][0] b[0][1]*a[7][0]) + } + + if(beta_val != 0.0) + { + ymm15 = _mm256_broadcast_sd((double const *)(&beta_val)); + + ymm0 = _mm256_loadu_pd((double const *)(temp_c)); //c[0][0] c[1][0] c[2][0] c[3][0] + ymm1 = _mm256_loadu_pd((double const *)(temp_c + 4)); //c[4][0] c[5][0] c[6][0] c[7][0] + + ymm3 = _mm256_fmadd_pd(ymm15, ymm0, ymm3); //ymm3 += (beta_val*c[0][0] beta_val*c[1][0] beta_val*c[2][0] beta_val*c[3][0]) + ymm4 = _mm256_fmadd_pd(ymm15, ymm1, ymm4); //ymm4 += (beta_val*c[4][0] beta_val*c[5][0] beta_val*c[6][0] beta_val*c[7][0]) + + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc)); //c[0][1] c[1][1] c[2][1] c[3][1] + ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc + 4));//c[4][1] c[5][1] c[6][1] c[7][1] + + ymm5 = _mm256_fmadd_pd(ymm15, ymm0, ymm5); //ymm5 += (beta_val*c[0][1] beta_val*c[1][1] beta_val*c[2][1] beta_val*c[3][1]) + ymm6 = _mm256_fmadd_pd(ymm15, ymm1, ymm6); //ymm6 += (beta_val*c[4][1] beta_val*c[5][1] beta_val*c[6][1] beta_val*c[7][1]) + } + + _mm256_storeu_pd((double *)(temp_c), ymm3); //c[0][0] c[1][0] c[2][0] c[3][0] + _mm256_storeu_pd((double *)(temp_c + 4), ymm4); //c[4][0] c[5][0] c[6][0] c[7][0] + + _mm256_storeu_pd((double *)(temp_c + ldc), ymm5); //c[0][1] c[1][1] c[2][1] c[3][1] + _mm256_storeu_pd((double *)(temp_c + ldc + 4), ymm6); //c[4][1] c[5][1] c[6][1] c[7][1] + + temp_c += D_MR; + temp_a += D_MR; + } + + dim_t m_rem = m_remainder; + if(m_remainder >= 4) + { + ymm3 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + if(alpha_val != 0.0) + { + ymm0 = _mm256_loadu_pd((double const *)(temp_a)); //a[0][0] a[1][0] a[2][0] a[3][0] + + ymm15 = _mm256_broadcast_sd((double const *)(&alpha_val)); + ymm0 = _mm256_mul_pd(ymm0,ymm15); //ymm0 = (alpha_val*a[0][0] alpha_val*a[1][0] alpha_val*a[2][0] alpha_val*a[3][0]) + + ymm2 = _mm256_broadcast_sd((double const *)(temp_b)); //b[0][0] + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm3 += (b[0][0]*a[0][0] b[0][0]*a[1][0] b[0][0]*a[2][0] b[0][0]*a[3][0]) + + ymm2 = _mm256_broadcast_sd((double const *)(temp_b + ldb)); //b[0][1] + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm5 += (b[0][1]*a[0][0] b[0][1]*a[1][0] b[0][1]*a[2][0] b[0][1]*a[3][0]) + } + + if(beta_val != 0.0) + { + ymm15 = _mm256_broadcast_sd((double const *)(&beta_val)); + + ymm0 = _mm256_loadu_pd((double const *)(temp_c)); //c[0][0] c[1][0] c[2][0] c[3][0] + ymm3 = _mm256_fmadd_pd(ymm15, ymm0, ymm3); //ymm3 += (beta_val*c[0][0] beta_val*c[1][0] beta_val*c[2][0] beta_val*c[3][0]) + + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc)); //c[0][1] c[1][1] c[2][1] c[3][1] + ymm5 = _mm256_fmadd_pd(ymm15, ymm0, ymm5); //ymm5 += (beta_val*c[0][1] beta_val*c[1][1] beta_val*c[2][1] beta_val*c[3][1]) + } + + _mm256_storeu_pd((double *)(temp_c), ymm3); //c[0][0] c[1][0] c[2][0] c[3][0] + _mm256_storeu_pd((double *)(temp_c + ldc), ymm5); //c[0][1] c[1][1] c[2][1] c[3][1] + + temp_c += 4; + temp_a += 4; + m_rem = m_remainder - 4; + } + + if(m_rem >= 2) + { + ymm3 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + if(alpha_val != 0.0) + { + __m128d xmm5; + xmm5 = _mm_loadu_pd((double const*)(temp_a)); //a[0][0] a[1][0] + ymm0 = _mm256_broadcast_sd((double const*)(temp_a)); + ymm0 = _mm256_insertf128_pd(ymm1, xmm5, 0); //a[0][0] a[1][0] a[0][0] a[1][0] + + ymm15 = _mm256_broadcast_sd((double const *)(&alpha_val)); + ymm0 = 
_mm256_mul_pd(ymm0,ymm15); //ymm0 = (alpha_val*a[0][0] alpha_val*a[1][0] alpha_val*a[0][0] alpha_val*a[1][0]) + + ymm2 = _mm256_broadcast_sd((double const *)(temp_b)); //b[0][0] + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm3 += (b[0][0]*a[0][0] b[0][0]*a[1][0] b[0][0]*a[0][0] b[0][0]*a[1][0]) + + ymm2 = _mm256_broadcast_sd((double const *)(temp_b + ldb * 1)); //b[0][1] + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm5 += (b[0][1]*a[0][0] b[0][1]*a[1][0] b[0][1]*a[0][0] b[0][1]*a[1][0]) + } + + if(beta_val != 0.0) + { + ymm15 = _mm256_broadcast_sd((double const *)(&beta_val)); + + xmm5 = _mm_loadu_pd((double const*)(temp_c)); //c[0][0] c[1][0] + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); //c[0][0] c[1][0] c[0][0] c[1][0] + + ymm3 = _mm256_fmadd_pd(ymm15, ymm0, ymm3); //ymm3 += (beta_val*c[0][0] beta_val*c[1][0] beta_val*c[0][0] beta_val*c[1][0]) + + xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc)); //c[0][1] c[1][1] + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); //c[0][1] c[1][1] c[0][1] c[1][1] + + ymm5 = _mm256_fmadd_pd(ymm15, ymm0, ymm5); //ymm5 += (beta_val*c[0][1] beta_val*c[1][1] beta_val*c[0][1] beta_val*c[1][1]) + } + xmm5 = _mm256_extractf128_pd(ymm3, 0); // xmm5 = ymm3[0] ymm3[1] + _mm_storeu_pd((double *)(temp_c), xmm5); //c[0][0] c[1][0] + + xmm5 = _mm256_extractf128_pd(ymm5, 0); // xmm5 = ymm5[0] ymm5[1] + _mm_storeu_pd((double *)(temp_c + ldc), xmm5); //c[0][1] c[1][1] + + temp_c += 2; + temp_a += 2; + m_rem = m_rem - 2; + } + + if(m_rem == 1) + { + ymm3 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + if(alpha_val != 0.0) + { + ymm0 = _mm256_broadcast_sd((double const *)(temp_a)); //a[0][0] a[0][0] a[0][0] a[0][0] + + ymm15 = _mm256_broadcast_sd((double const *)(&alpha_val)); + ymm0 = _mm256_mul_pd(ymm0,ymm15); //ymm0 = (alpha_val*a[0][0] alpha_val*a[0][0] alpha_val*a[0][0] alpha_val*a[0][0]) + + ymm2 = _mm256_broadcast_sd((double const *)(temp_b)); //b[0][0] + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm3 += (b[0][0]*a[0][0] b[0][0]*a[0][0] b[0][0]*a[0][0] b[0][0]*a[0][0]) + + ymm2 = _mm256_broadcast_sd((double const *)(temp_b + ldb)); //b[0][1] + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm5 += (b[0][1]*a[0][0] b[0][1]*a[0][0] b[0][1]*a[0][0] b[0][1]*a[0][0]) + } + + if(beta_val != 0.0) + { + ymm15 = _mm256_broadcast_sd((double const *)(&beta_val)); + + ymm0 = _mm256_broadcast_sd((double const *)(temp_c)); //c[0][0] c[0][0] c[0][0] c[0][0] + ymm3 = _mm256_fmadd_pd(ymm15, ymm0, ymm3); //ymm3 += (beta_val*c[0][0] beta_val*c[0][0] beta_val*c[0][0] beta_val*c[0][0]) + + ymm0 = _mm256_broadcast_sd((double const *)(temp_c + ldc)); //c[0][1] c[0][1] c[0][1] c[0][1] + ymm5 = _mm256_fmadd_pd(ymm15, ymm0, ymm5); //ymm5 += (beta_val*c[0][1] beta_val*c[0][1] beta_val*c[0][1] beta_val*c[0][1]) + } + + ymm0 = _mm256_blend_pd(ymm3, ymm0, 0x0E); // ymm0 = ymm3[0] ymm0[1] ymm0[2] ymm0[3] + _mm_storel_pd((temp_c), _mm256_extractf128_pd(ymm0, 0)); // c[0][0] + + ymm0 = _mm256_blend_pd(ymm5, ymm0, 0x0E); //ymm0 = ymm5[0] ymm0[1] ymm0[2] ymm0[3] + _mm_storel_pd((temp_c + ldc), _mm256_extractf128_pd(ymm0, 0)); // c[0][1] + } + n_remainder = n_remainder - 2; + } + + if(n_remainder == 1) + { + double* temp_b = b + (n - n_remainder)*ldb; + double* temp_a = a; + double* temp_c = c + (n - n_remainder)*ldc; + + for(dim_t i = 0;i < (m-D_MR+1);i=i+D_MR) + { + ymm3 = _mm256_setzero_pd(); + ymm4 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + if(alpha_val != 0.0) + { + ymm0 = _mm256_loadu_pd((double const *)(temp_a)); //a[0][0] a[1][0] a[2][0] 
a[3][0] + ymm1 = _mm256_loadu_pd((double const *)(temp_a + 4)); //a[4][0] a[5][0] a[6][0] a[7][0] + + ymm15 = _mm256_broadcast_sd((double const *)(&alpha_val)); + + ymm0 = _mm256_mul_pd(ymm0,ymm15); //ymm0 = (alpha_val*a[0][0] alpha_val*a[1][0] alpha_val*a[2][0] alpha_val*a[3][0]) + ymm1 = _mm256_mul_pd(ymm1,ymm15); //ymm1 = (alpha_val*a[4][0] alpha_val*a[5][0] alpha_val*a[6][0] alpha_val*a[7][0]) + + ymm2 = _mm256_broadcast_sd((double const *)(temp_b)); //b[0][0] + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3);//ymm3 += (b[0][0]*a[0][0] b[0][0]*a[1][0] b[0][0]*a[2][0] b[0][0]*a[3][0]) + ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);//ymm4 += (b[0][0]*a[4][0] b[0][0]*a[5][0] b[0][0]*a[6][0] b[0][0]*a[7][0]) + } + + if(beta_val != 0.0) + { + ymm15 = _mm256_broadcast_sd((double const *)(&beta_val)); + + ymm0 = _mm256_loadu_pd((double const *)(temp_c)); //c[0][0] c[1][0] c[2][0] c[3][0] + ymm1 = _mm256_loadu_pd((double const *)(temp_c + 4)); //c[4][0] c[5][0] c[6][0] c[7][0] + + ymm3 = _mm256_fmadd_pd(ymm15, ymm0, ymm3);//ymm3 += (beta_val*c[0][0] beta_val*c[1][0] beta_val*c[2][0] beta_val*c[3][0]) + ymm4 = _mm256_fmadd_pd(ymm15, ymm1, ymm4);//ymm4 += (beta_val*c[4][0] beta_val*c[5][0] beta_val*c[6][0] beta_val*c[7][0]) + } + + _mm256_storeu_pd((double *)(temp_c), ymm3); //c[0][0] c[1][0] c[2][0] c[3][0] + _mm256_storeu_pd((double *)(temp_c + 4), ymm4); //c[4][0] c[5][0] c[6][0] c[7][0] + + temp_c += D_MR; + temp_a += D_MR; + } + + dim_t m_rem = m_remainder; + if(m_remainder >= 4) + { + ymm3 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + if(alpha_val != 0.0) + { + ymm0 = _mm256_loadu_pd((double const *)(temp_a)); //a[0][0] a[1][0] a[2][0] a[3][0] + + ymm15 = _mm256_broadcast_sd((double const *)(&alpha_val)); + ymm0 = _mm256_mul_pd(ymm0,ymm15); //ymm0 = (alpha_val*a[0][0] alpha_val*a[1][0] alpha_val*a[2][0] alpha_val*a[3][0]) + + ymm2 = _mm256_broadcast_sd((double const *)(temp_b)); //b[0][0] + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3);//ymm3 += (b[0][0]*a[0][0] b[0][0]*a[1][0] b[0][0]*a[2][0] b[0][0]*a[3][0]) + } + + if(beta_val != 0.0) + { + ymm15 = _mm256_broadcast_sd((double const *)(&beta_val)); + + ymm0 = _mm256_loadu_pd((double const *)(temp_c)); //c[0][0] c[1][0] c[2][0] c[3][0] + ymm3 = _mm256_fmadd_pd(ymm15, ymm0, ymm3);//ymm3 += (beta_val*c[0][0] beta_val*c[1][0] beta_val*c[2][0] beta_val*c[3][0]) + } + + _mm256_storeu_pd((double *)(temp_c), ymm3); //c[0][0] c[1][0] c[2][0] c[3][0] + + temp_c += 4; + temp_a += 4; + m_rem = m_remainder - 4; + } + + if(m_rem >= 2) + { + ymm3 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + if(alpha_val != 0.0) + { + __m128d xmm5; + xmm5 = _mm_loadu_pd((double const*)(temp_a)); //a[0][0] a[1][0] + ymm0 = _mm256_broadcast_sd((double const*)(temp_a)); + ymm0 = _mm256_insertf128_pd(ymm1, xmm5, 0); //a[0][0] a[1][0] a[0][0] a[1][0] + + ymm15 = _mm256_broadcast_sd((double const *)(&alpha_val)); + + ymm0 = _mm256_mul_pd(ymm0,ymm15); //ymm0 = (alpha_val*a[0][0] alpha_val*a[1][0] alpha_val*a[0][0] alpha_val*a[1][0]) + + ymm2 = _mm256_broadcast_sd((double const *)(temp_b)); //b[0][0] + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3);//ymm3 += (b[0][0]*a[0][0] b[0][0]*a[1][0] b[0][0]*a[0][0] b[0][0]*a[1][0]) + } + + if(beta_val != 0.0) + { + ymm15 = _mm256_broadcast_sd((double const *)(&beta_val)); + + xmm5 = _mm_loadu_pd((double const*)(temp_c)); //c[0][0] c[1][0] + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); //c[0][0] c[1][0] c[0][0] c[1][0] + + ymm3 = _mm256_fmadd_pd(ymm15, ymm0, ymm3);//ymm3 += (beta_val*c[0][0] beta_val*c[1][0] beta_val*c[0][0] 
beta_val*c[1][0]) + } + + xmm5 = _mm256_extractf128_pd(ymm3, 0); // xmm5 = ymm3[0] ymm3[1] + _mm_storeu_pd((double *)(temp_c), xmm5); //c[0][0] c[1][0] + + temp_c += 2; + temp_a += 2; + m_rem = m_rem - 2; + } + + if(m_rem == 1) + { + ymm3 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + if(alpha_val != 0.0) + { + ymm0 = _mm256_broadcast_sd((double const *)(temp_a)); //a[0][0] a[0][0] a[0][0] a[0][0] + + ymm15 = _mm256_broadcast_sd((double const *)(&alpha_val)); + ymm0 = _mm256_mul_pd(ymm0,ymm15); //ymm0 = (alpha_val*a[0][0] alpha_val*a[0][0] alpha_val*a[0][0] alpha_val*a[0][0]) + + ymm2 = _mm256_broadcast_sd((double const *)(temp_b)); //b[0][0] + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3);//ymm3 += (b[0][0]*a[0][0] b[0][0]*a[0][0] b[0][0]*a[0][0] b[0][0]*a[0][0]) + } + + if(beta_val != 0.0) + { + ymm15 = _mm256_broadcast_sd((double const *)(&beta_val)); + + ymm0 = _mm256_broadcast_sd((double const *)(temp_c)); //c[0][0] c[0][0] c[0][0] c[0][0] + ymm3 = _mm256_fmadd_pd(ymm15, ymm0, ymm3);//ymm3 += (beta_val*c[0][0] beta_val*c[0][0] beta_val*c[0][0] beta_val*c[0][0]) + } + + ymm0 = _mm256_blend_pd(ymm3, ymm0, 0x0E); // ymm0 = ymm3[0] ymm0[1] ymm0[2] ymm0[3] + + _mm_storel_pd((temp_c), _mm256_extractf128_pd(ymm0, 0)); //c[0][0] + } + n_remainder = n_remainder - 2; + } + return; } From a263146a4c805f0bbdca37fe702f171d317c928e Mon Sep 17 00:00:00 2001 From: Nageshwar Singh Date: Fri, 13 Aug 2021 20:21:08 +0530 Subject: [PATCH 028/243] Optimized scalv for complex data-types c and z (cscalv and zscalv) AMD-Internal: [CPUPL-1551] Change-Id: Ie6855409d89f1edfd2a27f9e5f9efa6cd94bc0c9 --- config/zen/bli_cntx_init_zen.c | 4 +- config/zen2/bli_cntx_init_zen2.c | 4 +- config/zen3/bli_cntx_init_zen3.c | 4 +- frame/compat/bla_scal.c | 142 ++++- frame/include/bli_gentfunc_macro_defs.h | 4 +- kernels/zen/1/bli_scalv_zen_int10.c | 728 +++++++++++++++++++++++- kernels/zen/bli_kernels_zen.h | 2 + test/Makefile | 2 +- 8 files changed, 879 insertions(+), 11 deletions(-) diff --git a/config/zen/bli_cntx_init_zen.c b/config/zen/bli_cntx_init_zen.c index 7595849866..de4cbfb130 100644 --- a/config/zen/bli_cntx_init_zen.c +++ b/config/zen/bli_cntx_init_zen.c @@ -95,7 +95,7 @@ void bli_cntx_init_zen( cntx_t* cntx ) // Update the context with optimized level-1v kernels. bli_cntx_set_l1v_kers ( - 20, + 22, #if 1 // amaxv BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int, @@ -128,6 +128,8 @@ void bli_cntx_init_zen( cntx_t* cntx ) #else BLIS_SCALV_KER, BLIS_FLOAT, bli_sscalv_zen_int10, BLIS_SCALV_KER, BLIS_DOUBLE, bli_dscalv_zen_int10, + BLIS_SCALV_KER, BLIS_SCOMPLEX, bli_cscalv_zen_int10, + BLIS_SCALV_KER, BLIS_DCOMPLEX, bli_zscalv_zen_int10, #endif BLIS_SWAPV_KER, BLIS_FLOAT, bli_sswapv_zen_int8, BLIS_SWAPV_KER, BLIS_DOUBLE, bli_dswapv_zen_int8, diff --git a/config/zen2/bli_cntx_init_zen2.c b/config/zen2/bli_cntx_init_zen2.c index 4f56316a7a..6f3bbf3da9 100644 --- a/config/zen2/bli_cntx_init_zen2.c +++ b/config/zen2/bli_cntx_init_zen2.c @@ -107,7 +107,7 @@ void bli_cntx_init_zen2( cntx_t* cntx ) // Update the context with optimized level-1v kernels. 
bli_cntx_set_l1v_kers ( - 20, + 22, #if 1 // amaxv BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int, @@ -134,6 +134,8 @@ void bli_cntx_init_zen2( cntx_t* cntx ) // scalv BLIS_SCALV_KER, BLIS_FLOAT, bli_sscalv_zen_int10, BLIS_SCALV_KER, BLIS_DOUBLE, bli_dscalv_zen_int10, + BLIS_SCALV_KER, BLIS_SCOMPLEX, bli_cscalv_zen_int10, + BLIS_SCALV_KER, BLIS_DCOMPLEX, bli_zscalv_zen_int10, //swap BLIS_SWAPV_KER, BLIS_FLOAT, bli_sswapv_zen_int8, diff --git a/config/zen3/bli_cntx_init_zen3.c b/config/zen3/bli_cntx_init_zen3.c index fc7dbcb808..6b97f6bbf2 100644 --- a/config/zen3/bli_cntx_init_zen3.c +++ b/config/zen3/bli_cntx_init_zen3.c @@ -107,7 +107,7 @@ void bli_cntx_init_zen3( cntx_t* cntx ) // Update the context with optimized level-1v kernels. bli_cntx_set_l1v_kers ( - 20, + 22, #if 1 // amaxv BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int, @@ -134,6 +134,8 @@ void bli_cntx_init_zen3( cntx_t* cntx ) // scalv BLIS_SCALV_KER, BLIS_FLOAT, bli_sscalv_zen_int10, BLIS_SCALV_KER, BLIS_DOUBLE, bli_dscalv_zen_int10, + BLIS_SCALV_KER, BLIS_SCOMPLEX, bli_cscalv_zen_int10, + BLIS_SCALV_KER, BLIS_DCOMPLEX, bli_zscalv_zen_int10, //swap BLIS_SWAPV_KER, BLIS_FLOAT, bli_sswapv_zen_int8, diff --git a/frame/compat/bla_scal.c b/frame/compat/bla_scal.c index 184b14eda0..821aca6160 100644 --- a/frame/compat/bla_scal.c +++ b/frame/compat/bla_scal.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020-21, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -226,8 +226,146 @@ void dscal_ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) } -INSERT_GENTFUNCSCAL_BLAS_CZ( scal, scalv ) +void cscal_ + ( + const f77_int* n, + const scomplex* alpha, + scomplex* x, + const f77_int* incx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); + AOCL_DTL_LOG_SCAL_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'C', (void *)alpha, *n, *incx ); + dim_t n0; + scomplex* x0; + inc_t incx0; + + /* Initialize BLIS */ + //bli_init_auto(); + + /* Convert typecast negative values of n to zero. */ + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + if (*n == 0 || alpha == NULL) { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return; + } + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + if ( *incx < 0 ) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. */ + + x0 = (x) + (n0-1)*(-*incx); + incx0 = ( inc_t )(*incx); + + } + else + { + x0 = (x); + incx0 = ( inc_t )(*incx); + } + + /* Call BLIS kernel */ + bli_cscalv_zen_int10 + ( + BLIS_NO_CONJUGATE, + n0, + (scomplex*) alpha, + x0, incx0, + NULL + ); + + /* Finalize BLIS. 
*/ + // bli_finalize_auto(); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) +} + +void zscal_ + ( + const f77_int* n, + const dcomplex* alpha, + dcomplex* x, + const f77_int* incx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) + AOCL_DTL_LOG_SCAL_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'Z', (void *)alpha, *n, *incx ); + dim_t n0; + dcomplex* x0; + inc_t incx0; + + /* Initialize BLIS */ + //bli_init_auto(); + + /* Convert typecast negative values of n to zero. */ + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + if (*n == 0 || alpha == NULL) { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return; + } + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + if ( *incx < 0 ) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. */ + + x0 = (x) + (n0-1)*(-*incx); + incx0 = ( inc_t )(*incx); + + } + else + { + x0 = (x); + incx0 = ( inc_t )(*incx); + } + + /* Call BLIS kernel */ + bli_zscalv_zen_int10 + ( + BLIS_NO_CONJUGATE, + n0, + (dcomplex*) alpha, + x0, incx0, + NULL + ); + + /* Finalize BLIS. */ + // bli_finalize_auto(); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) +} + +INSERT_GENTFUNCSCAL_BLAS_CsZd(scal, scalv) + #else INSERT_GENTFUNCSCAL_BLAS( scal, scalv ) #endif #endif + diff --git a/frame/include/bli_gentfunc_macro_defs.h b/frame/include/bli_gentfunc_macro_defs.h index 1bac7aa7c4..ae0b1f3857 100644 --- a/frame/include/bli_gentfunc_macro_defs.h +++ b/frame/include/bli_gentfunc_macro_defs.h @@ -151,10 +151,8 @@ GENTFUNCR2( dcomplex, double, z, d, blasname, blisname ) // -- Extended two-operand macro (used only for scal) -- -#define INSERT_GENTFUNCSCAL_BLAS_CZ( blasname, blisname ) \ +#define INSERT_GENTFUNCSCAL_BLAS_CsZd( blasname, blisname ) \ \ -GENTFUNCSCAL( scomplex, scomplex, c, , blasname, blisname ) \ -GENTFUNCSCAL( dcomplex, dcomplex, z, , blasname, blisname ) \ GENTFUNCSCAL( scomplex, float, c, s, blasname, blisname ) \ GENTFUNCSCAL( dcomplex, double, z, d, blasname, blisname ) diff --git a/kernels/zen/1/bli_scalv_zen_int10.c b/kernels/zen/1/bli_scalv_zen_int10.c index 6c7f52e161..fef490196d 100644 --- a/kernels/zen/1/bli_scalv_zen_int10.c +++ b/kernels/zen/1/bli_scalv_zen_int10.c @@ -4,8 +4,8 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2017 - 2020, Advanced Micro Devices, Inc. All rights reserved. - Copyright (C) 2018, The University of Texas at Austin + Copyright (C) 2017 - 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018, The University of Texas at Austin. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -454,3 +454,727 @@ void bli_dscalv_zen_int10 } } +// ----------------------------------------------------------------------------- + +void bli_cscalv_zen_int10 + ( + conj_t conjalpha, + dim_t n, + scomplex* restrict alpha, + scomplex* restrict x, + inc_t incx, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_4) + + const dim_t n_elem_per_reg = 8; + + dim_t i; + + float* restrict x0; + float* restrict alpha0; + float alphaR, alphaI; + + __m256 alphaRv; + __m256 alphaIv; + __m256 xv[10]; + __m256 x_sufv[10]; + + conj_t conjx_use = conjalpha; + + // If the vector dimension is zero, or if alpha is unit, return early. + if ( bli_zero_dim1( n ) || PASTEMAC(c,eq1)( *alpha ) ) return; + + // If alpha is zero, use setv. + if ( PASTEMAC(c,eq0)( *alpha ) ) + { + scomplex* zero = bli_c0; + if (cntx == NULL) + cntx = bli_gks_query_cntx(); + csetv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_SCOMPLEX, BLIS_SETV_KER, cntx ); + f + ( + BLIS_NO_CONJUGATE, + n, + zero, + x, incx, + cntx + ); + return; + } + + // Initialize local pointers. + x0 = (float*)x; + alpha0 = (float*)alpha; + + alphaR = alpha->real; + alphaI = alpha->imag; + + if ( incx == 1 ) + { + // Broadcast the alpha scalar to all elements of a vector register. + if ( !bli_is_conj (conjx_use) ) // If BLIS_NO_CONJUGATE + { + alphaRv = _mm256_broadcast_ss( &alphaR ); + alphaIv = _mm256_set_ps(alphaI, -alphaI, alphaI, -alphaI, alphaI, -alphaI, alphaI, -alphaI); + } + else + { + alphaIv = _mm256_broadcast_ss( &alphaI ); + alphaRv = _mm256_set_ps(-alphaR, alphaR, -alphaR, alphaR, -alphaR, alphaR, -alphaR, alphaR); + } + + /* + = (alpha_r + alpha_i) * (x_r + x_i) + = alpha_r*x_r + alpha_r*x_i + alpha_i*x_r + (-alpha_i*x_i) + = (alpha_r*x_r - alpha_i*x_i) + (alpha_r*x_i + alpha_i*x_r)I + + x = x_r , x_i , x_r , x_i , x_r , x_i , x_r , x_i + x_suf = x_i , x_r , x_i , x_r , x_i , x_r , x_i , x_r + alphaR = ar , ar , ar , ar , ar , ar , ar , ar + alphaI = -ai , ai ,-ai , ai ,-ai , ai ,-ai, ai + + step 1) Load x. + step 2) Shuffle x. + step 3) mul x <= x*alphaR => ar*x_r , ar*x_i + step 4) fma x <= x_suf*alphaI + x => (-ai*x_i , ai*x_r) + (ar*x_r , ar*x_i) + => (ar*x_r - ai*x_i), (ar*x_i + ai*x_r ) + */ + + for ( i = 0; (i + 39) < n; i += 40 ) + { + // Load the input values. + xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); + xv[2] = _mm256_loadu_ps( x0 + 2*n_elem_per_reg ); + xv[3] = _mm256_loadu_ps( x0 + 3*n_elem_per_reg ); + xv[4] = _mm256_loadu_ps( x0 + 4*n_elem_per_reg ); + xv[5] = _mm256_loadu_ps( x0 + 5*n_elem_per_reg ); + xv[6] = _mm256_loadu_ps( x0 + 6*n_elem_per_reg ); + xv[7] = _mm256_loadu_ps( x0 + 7*n_elem_per_reg ); + xv[8] = _mm256_loadu_ps( x0 + 8*n_elem_per_reg ); + xv[9] = _mm256_loadu_ps( x0 + 9*n_elem_per_reg ); + + // x = xr0 , xi0, xr1, xi1 .... + // x_suf = xi0 , xr0, xi1, xr1 .... + x_sufv[0] = _mm256_permute_ps( xv[0], 0xB1); + x_sufv[1] = _mm256_permute_ps( xv[1], 0xB1); + x_sufv[2] = _mm256_permute_ps( xv[2], 0xB1); + x_sufv[3] = _mm256_permute_ps( xv[3], 0xB1); + x_sufv[4] = _mm256_permute_ps( xv[4], 0xB1); + x_sufv[5] = _mm256_permute_ps( xv[5], 0xB1); + x_sufv[6] = _mm256_permute_ps( xv[6], 0xB1); + x_sufv[7] = _mm256_permute_ps( xv[7], 0xB1); + x_sufv[8] = _mm256_permute_ps( xv[8], 0xB1); + x_sufv[9] = _mm256_permute_ps( xv[9], 0xB1); + + // mul x <= x*alphaR + // aphhaR = ar , ar , ar , ar , .... 
+ // x = xr , xi , xr , xi , .... + // mul = ar*xr, ar*xi , ar*xr , ar*xi, .... + xv[0] = _mm256_mul_ps( alphaRv, xv[0] ); + xv[1] = _mm256_mul_ps( alphaRv, xv[1] ); + xv[2] = _mm256_mul_ps( alphaRv, xv[2] ); + xv[3] = _mm256_mul_ps( alphaRv, xv[3] ); + xv[4] = _mm256_mul_ps( alphaRv, xv[4] ); + xv[5] = _mm256_mul_ps( alphaRv, xv[5] ); + xv[6] = _mm256_mul_ps( alphaRv, xv[6] ); + xv[7] = _mm256_mul_ps( alphaRv, xv[7] ); + xv[8] = _mm256_mul_ps( alphaRv, xv[8] ); + xv[9] = _mm256_mul_ps( alphaRv, xv[9] ); + + // fma x <= x_suf*alphaI + x + // alphaI = -ai , ai , -ai , ai .... + // X suf = xi , xr , xi , xr .... + // mul = -ai*xi, ai*xr , -ai*xi, ai*xi .... + // add x = ar*xr - ai*xi, ar*xi + ai*xr, .... + xv[0] = _mm256_fmadd_ps( alphaIv, x_sufv[0], xv[0] ); + xv[1] = _mm256_fmadd_ps( alphaIv, x_sufv[1], xv[1] ); + xv[2] = _mm256_fmadd_ps( alphaIv, x_sufv[2], xv[2] ); + xv[3] = _mm256_fmadd_ps( alphaIv, x_sufv[3], xv[3] ); + xv[4] = _mm256_fmadd_ps( alphaIv, x_sufv[4], xv[4] ); + xv[5] = _mm256_fmadd_ps( alphaIv, x_sufv[5], xv[5] ); + xv[6] = _mm256_fmadd_ps( alphaIv, x_sufv[6], xv[6] ); + xv[7] = _mm256_fmadd_ps( alphaIv, x_sufv[7], xv[7] ); + xv[8] = _mm256_fmadd_ps( alphaIv, x_sufv[8], xv[8] ); + xv[9] = _mm256_fmadd_ps( alphaIv, x_sufv[9], xv[9] ); + + // Store the output. + _mm256_storeu_ps( (x0 + 0*n_elem_per_reg), xv[0] ); + _mm256_storeu_ps( (x0 + 1*n_elem_per_reg), xv[1] ); + _mm256_storeu_ps( (x0 + 2*n_elem_per_reg), xv[2] ); + _mm256_storeu_ps( (x0 + 3*n_elem_per_reg), xv[3] ); + _mm256_storeu_ps( (x0 + 4*n_elem_per_reg), xv[4] ); + _mm256_storeu_ps( (x0 + 5*n_elem_per_reg), xv[5] ); + _mm256_storeu_ps( (x0 + 6*n_elem_per_reg), xv[6] ); + _mm256_storeu_ps( (x0 + 7*n_elem_per_reg), xv[7] ); + _mm256_storeu_ps( (x0 + 8*n_elem_per_reg), xv[8] ); + _mm256_storeu_ps( (x0 + 9*n_elem_per_reg), xv[9] ); + + x0 += 10*n_elem_per_reg; + } + + for ( ; (i + 19) < n; i += 20 ) + { + // Load the input values. + xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); + xv[2] = _mm256_loadu_ps( x0 + 2*n_elem_per_reg ); + xv[3] = _mm256_loadu_ps( x0 + 3*n_elem_per_reg ); + xv[4] = _mm256_loadu_ps( x0 + 4*n_elem_per_reg ); + + // x = xr0 , xi0, xr1, xi1 .... + // x_suf = xi0 , xr0, xi1, xr1 .... + x_sufv[0] = _mm256_permute_ps( xv[0], 0xB1); + x_sufv[1] = _mm256_permute_ps( xv[1], 0xB1); + x_sufv[2] = _mm256_permute_ps( xv[2], 0xB1); + x_sufv[3] = _mm256_permute_ps( xv[3], 0xB1); + x_sufv[4] = _mm256_permute_ps( xv[4], 0xB1); + + // mul x <= x*alphaR + // aphhaR = ar , ar , ar , ar , .... + // x = xr , xi , xr , xi , .... + // mul = ar*xr, ar*xi , ar*xr , ar*xi, .... + xv[0] = _mm256_mul_ps( alphaRv, xv[0] ); + xv[1] = _mm256_mul_ps( alphaRv, xv[1] ); + xv[2] = _mm256_mul_ps( alphaRv, xv[2] ); + xv[3] = _mm256_mul_ps( alphaRv, xv[3] ); + xv[4] = _mm256_mul_ps( alphaRv, xv[4] ); + + // fma x <= x_suf*alphaI + x + // alphaI = -ai , ai , -ai , ai .... + // X = xi , xr , xi , xr .... + // mul = -ai*xi, ai*xr , -ai*xi, ai*xi .... + // add x = ar*xr - ai*xi, ar*xi + ai*xr, + xv[0] = _mm256_fmadd_ps( alphaIv, x_sufv[0], xv[0] ); + xv[1] = _mm256_fmadd_ps( alphaIv, x_sufv[1], xv[1] ); + xv[2] = _mm256_fmadd_ps( alphaIv, x_sufv[2], xv[2] ); + xv[3] = _mm256_fmadd_ps( alphaIv, x_sufv[3], xv[3] ); + xv[4] = _mm256_fmadd_ps( alphaIv, x_sufv[4], xv[4] ); + + // Store the output. 
+ _mm256_storeu_ps( (x0 + 0*n_elem_per_reg), xv[0] ); + _mm256_storeu_ps( (x0 + 1*n_elem_per_reg), xv[1] ); + _mm256_storeu_ps( (x0 + 2*n_elem_per_reg), xv[2] ); + _mm256_storeu_ps( (x0 + 3*n_elem_per_reg), xv[3] ); + _mm256_storeu_ps( (x0 + 4*n_elem_per_reg), xv[4] ); + + x0 += 5*n_elem_per_reg; + } + + for ( ; (i + 15) < n; i += 16 ) + { + // Load the input values. + xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); + xv[2] = _mm256_loadu_ps( x0 + 2*n_elem_per_reg ); + xv[3] = _mm256_loadu_ps( x0 + 3*n_elem_per_reg ); + + // x = xr0 , xi0, xr1, xi1 .... + // x_suf = xi0 , xr0, xi1, xr1 .... + x_sufv[0] = _mm256_permute_ps( xv[0], 0xB1); + x_sufv[1] = _mm256_permute_ps( xv[1], 0xB1); + x_sufv[2] = _mm256_permute_ps( xv[2], 0xB1); + x_sufv[3] = _mm256_permute_ps( xv[3], 0xB1); + + // mul x <= x*alphaR + // aphhaR = ar , ar , ar , ar , .... + // x = xr , xi , xr , xi , .... + // mul = ar*xr, ar*xi , ar*xr , ar*xi, .... + xv[0] = _mm256_mul_ps( alphaRv, xv[0] ); + xv[1] = _mm256_mul_ps( alphaRv, xv[1] ); + xv[2] = _mm256_mul_ps( alphaRv, xv[2] ); + xv[3] = _mm256_mul_ps( alphaRv, xv[3] ); + + // fma x <= x_suf*alphaI + x + // alphaI = -ai , ai , -ai , ai .... + // X = xi , xr , xi , xr .... + // mul = -ai*xi, ai*xr , -ai*xi, ai*xi .... + // add x = ar*xr - ai*xi, ar*xi + ai*xr, + xv[0] = _mm256_fmadd_ps( alphaIv, x_sufv[0], xv[0] ); + xv[1] = _mm256_fmadd_ps( alphaIv, x_sufv[1], xv[1] ); + xv[2] = _mm256_fmadd_ps( alphaIv, x_sufv[2], xv[2] ); + xv[3] = _mm256_fmadd_ps( alphaIv, x_sufv[3], xv[3] ); + + // Store the output. + _mm256_storeu_ps( (x0 + 0*n_elem_per_reg), xv[0] ); + _mm256_storeu_ps( (x0 + 1*n_elem_per_reg), xv[1] ); + _mm256_storeu_ps( (x0 + 2*n_elem_per_reg), xv[2] ); + _mm256_storeu_ps( (x0 + 3*n_elem_per_reg), xv[3] ); + + x0 += 4*n_elem_per_reg; + } + + for ( ; (i + 7) < n; i += 8 ) + { + // Load the input values. + xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); + + // x = xr0 , xi0, xr1, xi1 .... + // x_suf = xi0 , xr0, xi1, xr1 .... + x_sufv[0] = _mm256_permute_ps( xv[0], 0xB1); + x_sufv[1] = _mm256_permute_ps( xv[1], 0xB1); + + // mul x <= x*alphaR + // aphhaR = ar , ar , ar , ar , .... + // x = xr , xi , xr , xi , .... + // mul = ar*xr, ar*xi , ar*xr , ar*xi, .... + xv[0] = _mm256_mul_ps( alphaRv, xv[0] ); + xv[1] = _mm256_mul_ps( alphaRv, xv[1] ); + + // fma x <= x_suf*alphaI + x + // alphaI = -ai , ai , -ai , ai .... + // X = xi , xr , xi , xr .... + // mul = -ai*xi, ai*xr , -ai*xi, ai*xi .... + // add x = ar*xr - ai*xi, ar*xi + ai*xr, + xv[0] = _mm256_fmadd_ps( alphaIv, x_sufv[0], xv[0] ); + xv[1] = _mm256_fmadd_ps( alphaIv, x_sufv[1], xv[1] ); + + // Store the output. + _mm256_storeu_ps( (x0 + 0*n_elem_per_reg), xv[0] ); + _mm256_storeu_ps( (x0 + 1*n_elem_per_reg), xv[1] ); + + x0 += 2*n_elem_per_reg; + } + + for ( ; (i + 3) < n; i += 4 ) + { + // Load the input values. + xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); + + // x = xr0 , xi0, xr1, xi1 .... + // x_suf = xi0 , xr0, xi1, xr1 .... + x_sufv[0] = _mm256_permute_ps( xv[0], 0xB1); + + // mul x <= x*alphaR + // aphhaR = ar , ar , ar , ar , .... + // x = xr , xi , xr , xi , .... + // mul = ar*xr, ar*xi , ar*xr , ar*xi, .... + xv[0] = _mm256_mul_ps( alphaRv, xv[0] ); + + // fma x <= x_suf*alphaI + x + // alphaI = -ai , ai , -ai , ai .... + // X = xi , xr , xi , xr .... + // mul = -ai*xi, ai*xr , -ai*xi, ai*xi .... 
+ // add x = ar*xr - ai*xi, ar*xi + ai*xr, + xv[0] = _mm256_fmadd_ps( alphaIv, x_sufv[0], xv[0] ); + + // Store the output. + _mm256_storeu_ps( (x0 + 0*n_elem_per_reg), xv[0] ); + + x0 += 1*n_elem_per_reg; + } + + for ( ; (i + 0) < n; i += 1 ) + { + float real; + + // real part: ( aR.xR - aIxI ) + real = *alpha0 * (*x0) - (*(alpha0 + 1)) * (*(x0+1)); + // img part: ( aR.xI + aI.xR ) + *(x0 + 1) = *alpha0 * (*(x0+1)) + (*(alpha0 + 1)) * (*x0); + + *x0 = real; + + x0 += 2; + } + } + else + { + const float alphar = *alpha0; + const float alphai = *(alpha0 + 1); + + if ( !bli_is_conj(conjx_use) ) // BLIS_NO_CONJUGATE + { + for ( i = 0; i < n; ++i ) + { + const float x0c = *x0; + const float x1c = *( x0+1 ); + + *x0 = alphar * x0c - alphai * x1c; + *(x0 + 1) = alphar * x1c + alphai * x0c; + + x0 += incx*2; + } + } + else // BLIS_CONJUGATE + { + for ( i = 0; i < n; ++i ) + { + const float x0c = *x0; + const float x1c = *( x0+1 ); + + *x0 = alphar * x0c + alphai * x1c; + *(x0 + 1) = alphai * x0c - alphar * x1c; + + x0 += incx*2; + } + } + } +} + +// ----------------------------------------------------------------------------- + +void bli_zscalv_zen_int10 + ( + conj_t conjalpha, + dim_t n, + dcomplex* restrict alpha, + dcomplex* restrict x, + inc_t incx, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_4) + + const dim_t n_elem_per_reg = 4; + + dim_t i; + + double* restrict x0; + double* restrict alpha0; + double alphaR, alphaI; + + __m256d alphaRv; + __m256d alphaIv; + __m256d xv[10]; + __m256d x_sufv[10]; + + conj_t conjx_use = conjalpha; + + // If the vector dimension is zero, or if alpha is unit, return early. + if ( bli_zero_dim1( n ) || PASTEMAC(z,eq1)( *alpha ) ) return; + + // If alpha is zero, use setv. + if ( PASTEMAC(z,eq0)( *alpha ) ) + { + dcomplex* zero = bli_z0; + + if (cntx == NULL) + cntx = bli_gks_query_cntx(); + zsetv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_DCOMPLEX, BLIS_SETV_KER, cntx ); + f + ( + BLIS_NO_CONJUGATE, + n, + zero, + x, incx, + cntx + ); + + return; + } + + // Initialize local pointers. + x0 = (double*)x; + alpha0 = (double*)alpha; + + alphaR = alpha->real; + alphaI = alpha->imag; + + if ( incx == 1 ) + { + // Broadcast the alpha scalar to all elements of a vector register. + if ( !bli_is_conj (conjx_use) ) // If BLIS_NO_CONJUGATE + { + alphaRv = _mm256_broadcast_sd( &alphaR ); + alphaIv = _mm256_set_pd(alphaI, -alphaI, alphaI, -alphaI); + } + else + { + alphaIv = _mm256_broadcast_sd( &alphaI ); + alphaRv = _mm256_set_pd(alphaR, -alphaR, alphaR, -alphaR); + } + + /* + = (alpha_r + alpha_i) * (x_r + x_i) + = alpha_r*x_r + alpha_r*x_i + alpha_i*x_r + (-alpha_i*x_i) + = (alpha_r*x_r - alpha_i*x_i) + (alpha_r*x_i + alpha_i*x_r)I + + x = x_r , x_i , x_r , x_i , x_r , x_i , x_r , x_i + x_suf = x_i , x_r , x_i , x_r , x_i , x_r , x_i , x_r + alphaR = ar , ar , ar , ar , ar , ar , ar , ar + alphaI = -ai , ai ,-ai , ai ,-ai , ai ,-ai, ai + + step 1) Load x. + step 2) Shuffle x. + step 3) mul x <= x*alphaR => ar*x_r , ar*x_i + step 4) fma x <= x_suf*alphaI + x => (-ai*x_i , ai*x_r) + (ar*x_r , ar*x_i) + => (ar*x_r - ai*x_i), (ar*x_i + ai*x_r ) + */ + + for ( i = 0; (i + 19) < n; i += 20 ) + { + // Load the input values. 
+ xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); + xv[2] = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); + xv[3] = _mm256_loadu_pd( x0 + 3*n_elem_per_reg ); + xv[4] = _mm256_loadu_pd( x0 + 4*n_elem_per_reg ); + xv[5] = _mm256_loadu_pd( x0 + 5*n_elem_per_reg ); + xv[6] = _mm256_loadu_pd( x0 + 6*n_elem_per_reg ); + xv[7] = _mm256_loadu_pd( x0 + 7*n_elem_per_reg ); + xv[8] = _mm256_loadu_pd( x0 + 8*n_elem_per_reg ); + xv[9] = _mm256_loadu_pd( x0 + 9*n_elem_per_reg ); + + // x = xr0 , xi0, xr1, xi1 .... + // x_suf = xi0 , xr0, xi1, xr1 .... + x_sufv[0] = _mm256_permute_pd( xv[0], 5); + x_sufv[1] = _mm256_permute_pd( xv[1], 5); + x_sufv[2] = _mm256_permute_pd( xv[2], 5); + x_sufv[3] = _mm256_permute_pd( xv[3], 5); + x_sufv[4] = _mm256_permute_pd( xv[4], 5); + x_sufv[5] = _mm256_permute_pd( xv[5], 5); + x_sufv[6] = _mm256_permute_pd( xv[6], 5); + x_sufv[7] = _mm256_permute_pd( xv[7], 5); + x_sufv[8] = _mm256_permute_pd( xv[8], 5); + x_sufv[9] = _mm256_permute_pd( xv[9], 5); + + // mul x <= x*alphaR + // aphhaR = ar , ar , ar , ar , .... + // x = xr , xi , xr , xi , .... + // mul = ar*xr, ar*xi , ar*xr , ar*xi, .... + xv[0] = _mm256_mul_pd( alphaRv, xv[0] ); + xv[1] = _mm256_mul_pd( alphaRv, xv[1] ); + xv[2] = _mm256_mul_pd( alphaRv, xv[2] ); + xv[3] = _mm256_mul_pd( alphaRv, xv[3] ); + xv[4] = _mm256_mul_pd( alphaRv, xv[4] ); + xv[5] = _mm256_mul_pd( alphaRv, xv[5] ); + xv[6] = _mm256_mul_pd( alphaRv, xv[6] ); + xv[7] = _mm256_mul_pd( alphaRv, xv[7] ); + xv[8] = _mm256_mul_pd( alphaRv, xv[8] ); + xv[9] = _mm256_mul_pd( alphaRv, xv[9] ); + + // fma x <= x_suf*alphaI + x + // alphaI = -ai , ai , -ai , ai .... + // X suf = xi , xr , xi , xr .... + // mul = -ai*xi, ai*xr , -ai*xi, ai*xi .... + // add x = ar*xr - ai*xi, ar*xi + ai*xr, .... + xv[0] = _mm256_fmadd_pd( alphaIv, x_sufv[0], xv[0] ); + xv[1] = _mm256_fmadd_pd( alphaIv, x_sufv[1], xv[1] ); + xv[2] = _mm256_fmadd_pd( alphaIv, x_sufv[2], xv[2] ); + xv[3] = _mm256_fmadd_pd( alphaIv, x_sufv[3], xv[3] ); + xv[4] = _mm256_fmadd_pd( alphaIv, x_sufv[4], xv[4] ); + xv[5] = _mm256_fmadd_pd( alphaIv, x_sufv[5], xv[5] ); + xv[6] = _mm256_fmadd_pd( alphaIv, x_sufv[6], xv[6] ); + xv[7] = _mm256_fmadd_pd( alphaIv, x_sufv[7], xv[7] ); + xv[8] = _mm256_fmadd_pd( alphaIv, x_sufv[8], xv[8] ); + xv[9] = _mm256_fmadd_pd( alphaIv, x_sufv[9], xv[9] ); + + // Store the output. + _mm256_storeu_pd( (x0 + 0*n_elem_per_reg), xv[0] ); + _mm256_storeu_pd( (x0 + 1*n_elem_per_reg), xv[1] ); + _mm256_storeu_pd( (x0 + 2*n_elem_per_reg), xv[2] ); + _mm256_storeu_pd( (x0 + 3*n_elem_per_reg), xv[3] ); + _mm256_storeu_pd( (x0 + 4*n_elem_per_reg), xv[4] ); + _mm256_storeu_pd( (x0 + 5*n_elem_per_reg), xv[5] ); + _mm256_storeu_pd( (x0 + 6*n_elem_per_reg), xv[6] ); + _mm256_storeu_pd( (x0 + 7*n_elem_per_reg), xv[7] ); + _mm256_storeu_pd( (x0 + 8*n_elem_per_reg), xv[8] ); + _mm256_storeu_pd( (x0 + 9*n_elem_per_reg), xv[9] ); + + x0 += 10*n_elem_per_reg; + } + + for ( ; (i + 9) < n; i += 10 ) + { + // Load the input values. 
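+		// Five ymm registers x two dcomplex elements each:
+		// 10 elements are processed per iteration.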
+ xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); + xv[2] = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); + xv[3] = _mm256_loadu_pd( x0 + 3*n_elem_per_reg ); + xv[4] = _mm256_loadu_pd( x0 + 4*n_elem_per_reg ); + + // x = xr0 , xi0, xr1, xi1 + // x_suf = xi0 , xr0, xi1, xr1 + x_sufv[0] = _mm256_permute_pd( xv[0], 5); + x_sufv[1] = _mm256_permute_pd( xv[1], 5); + x_sufv[2] = _mm256_permute_pd( xv[2], 5); + x_sufv[3] = _mm256_permute_pd( xv[3], 5); + x_sufv[4] = _mm256_permute_pd( xv[4], 5); + + // mul x <= x*alphaR + // aphhaR = ar , ar , ar , ar + // x = xr , xi , xr , xi + // mul = ar*xr, ar*xi , ar*xr , ar*xi + xv[0] = _mm256_mul_pd( alphaRv, xv[0] ); + xv[1] = _mm256_mul_pd( alphaRv, xv[1] ); + xv[2] = _mm256_mul_pd( alphaRv, xv[2] ); + xv[3] = _mm256_mul_pd( alphaRv, xv[3] ); + xv[4] = _mm256_mul_pd( alphaRv, xv[4] ); + + // fma x <= x_suf*alphaI + x + // alphaI = -ai , ai , -ai , ai + // X = xi , xr , xi , xr + // mul = -ai*xi, ai*xr , -ai*xi, ai*xi + // add x = ar*xr - ai*xi, ar*xi + ai*xr, + xv[0] = _mm256_fmadd_pd( alphaIv, x_sufv[0], xv[0] ); + xv[1] = _mm256_fmadd_pd( alphaIv, x_sufv[1], xv[1] ); + xv[2] = _mm256_fmadd_pd( alphaIv, x_sufv[2], xv[2] ); + xv[3] = _mm256_fmadd_pd( alphaIv, x_sufv[3], xv[3] ); + xv[4] = _mm256_fmadd_pd( alphaIv, x_sufv[4], xv[4] ); + + // Store the output. + _mm256_storeu_pd( (x0 + 0*n_elem_per_reg), xv[0] ); + _mm256_storeu_pd( (x0 + 1*n_elem_per_reg), xv[1] ); + _mm256_storeu_pd( (x0 + 2*n_elem_per_reg), xv[2] ); + _mm256_storeu_pd( (x0 + 3*n_elem_per_reg), xv[3] ); + _mm256_storeu_pd( (x0 + 4*n_elem_per_reg), xv[4] ); + + x0 += 5*n_elem_per_reg; + } + + for ( ; (i + 7) < n; i += 8 ) + { + // Load the input values. + xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); + xv[2] = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); + xv[3] = _mm256_loadu_pd( x0 + 3*n_elem_per_reg ); + + // x = xr0 , xi0, xr1, xi1 .... + // x_suf = xi0 , xr0, xi1, xr1 .... + x_sufv[0] = _mm256_permute_pd( xv[0], 5); + x_sufv[1] = _mm256_permute_pd( xv[1], 5); + x_sufv[2] = _mm256_permute_pd( xv[2], 5); + x_sufv[3] = _mm256_permute_pd( xv[3], 5); + + // mul x <= x*alphaR + // aphhaR = ar , ar , ar , ar , .... + // x = xr , xi , xr , xi , .... + // mul = ar*xr, ar*xi , ar*xr , ar*xi, .... + xv[0] = _mm256_mul_pd( alphaRv, xv[0] ); + xv[1] = _mm256_mul_pd( alphaRv, xv[1] ); + xv[2] = _mm256_mul_pd( alphaRv, xv[2] ); + xv[3] = _mm256_mul_pd( alphaRv, xv[3] ); + + // fma x <= x_suf*alphaI + x + // alphaI = -ai , ai , -ai , ai .... + // X = xi , xr , xi , xr .... + // mul = -ai*xi, ai*xr , -ai*xi, ai*xi .... + // add x = ar*xr - ai*xi, ar*xi + ai*xr, + xv[0] = _mm256_fmadd_pd( alphaIv, x_sufv[0], xv[0] ); + xv[1] = _mm256_fmadd_pd( alphaIv, x_sufv[1], xv[1] ); + xv[2] = _mm256_fmadd_pd( alphaIv, x_sufv[2], xv[2] ); + xv[3] = _mm256_fmadd_pd( alphaIv, x_sufv[3], xv[3] ); + + // Store the output. + _mm256_storeu_pd( (x0 + 0*n_elem_per_reg), xv[0] ); + _mm256_storeu_pd( (x0 + 1*n_elem_per_reg), xv[1] ); + _mm256_storeu_pd( (x0 + 2*n_elem_per_reg), xv[2] ); + _mm256_storeu_pd( (x0 + 3*n_elem_per_reg), xv[3] ); + + x0 += 4*n_elem_per_reg; + } + + + for ( ; (i + 3) < n; i += 4 ) + { + // Load the input values. + xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); + + // x = xr0 , xi0, xr1, xi1 .... + // x_suf = xi0 , xr0, xi1, xr1 .... 
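+		// _mm256_permute_pd with immediate 5 (0b0101) swaps the real
+		// and imaginary doubles within each 128-bit lane to form x_suf.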
+ x_sufv[0] = _mm256_permute_pd( xv[0], 5); + x_sufv[1] = _mm256_permute_pd( xv[1], 5); + + // mul x <= x*alphaR + // aphhaR = ar , ar , ar , ar , .... + // x = xr , xi , xr , xi , .... + // mul = ar*xr, ar*xi , ar*xr , ar*xi, .... + xv[0] = _mm256_mul_pd( alphaRv, xv[0] ); + xv[1] = _mm256_mul_pd( alphaRv, xv[1] ); + + // fma x <= x_suf*alphaI + x + // alphaI = -ai , ai , -ai , ai .... + // X = xi , xr , xi , xr .... + // mul = -ai*xi, ai*xr , -ai*xi, ai*xi .... + // add x = ar*xr - ai*xi, ar*xi + ai*xr, + xv[0] = _mm256_fmadd_pd( alphaIv, x_sufv[0], xv[0] ); + xv[1] = _mm256_fmadd_pd( alphaIv, x_sufv[1], xv[1] ); + + // Store the output. + _mm256_storeu_pd( (x0 + 0*n_elem_per_reg), xv[0] ); + _mm256_storeu_pd( (x0 + 1*n_elem_per_reg), xv[1] ); + + x0 += 2*n_elem_per_reg; + } + + for ( ; (i + 1) < n; i += 2 ) + { + // Load the input values. + xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + + // x = xr0 , xi0, xr1, xi1 .... + // x_suf = xi0 , xr0, xi1, xr1 .... + x_sufv[0] = _mm256_permute_pd( xv[0], 5); + + // mul x <= x*alphaR + // aphhaR = ar , ar , ar , ar , .... + // x = xr , xi , xr , xi , .... + // mul = ar*xr, ar*xi , ar*xr , ar*xi, .... + xv[0] = _mm256_mul_pd( alphaRv, xv[0] ); + + // fma x <= x_suf*alphaI + x + // alphaI = -ai , ai , -ai , ai .... + // X = xi , xr , xi , xr .... + // mul = -ai*xi, ai*xr , -ai*xi, ai*xi .... + // add x = ar*xr - ai*xi, ar*xi + ai*xr, + xv[0] = _mm256_fmadd_pd( alphaIv, x_sufv[0], xv[0] ); + + // Store the output. + _mm256_storeu_pd( (x0 + 0*n_elem_per_reg), xv[0] ); + + x0 += 1*n_elem_per_reg; + } + + for ( ; (i + 0) < n; i += 1 ) + { + double real; + + // real part: ( aR.xR - aIxI ) + real = *alpha0 * (*x0) - (*(alpha0 + 1)) * (*(x0+1)); + // img part: ( aR.xI + aI.xR ) + *(x0 + 1) = *alpha0 * (*(x0+1)) + (*(alpha0 + 1)) * (*x0); + + *x0 = real; + + x0 += 2; + } + } + else + { + const double alphar = *alpha0; + const double alphai = *(alpha0 + 1); + + if ( !bli_is_conj(conjx_use) ) // BLIS_NO_CONJUGATE + { + for ( i = 0; i < n; ++i ) + { + const double x0c = *x0; + const double x1c = *( x0 + 1 ); + + *x0 = alphar * x0c - alphai * x1c; + *(x0 + 1) = alphar * x1c + alphai * x0c; + + x0 += incx*2; + } + } + else // BLIS_CONJUGATE + { + for ( i = 0; i < n; ++i ) + { + const double x0c = *x0; + const double x1c = *( x0 + 1 ); + + *x0 = alphar * x0c + alphai * x1c; + *(x0 + 1) = alphai * x0c - alphar * x1c; + + x0 += incx*2; + } + } + } +} diff --git a/kernels/zen/bli_kernels_zen.h b/kernels/zen/bli_kernels_zen.h index d21eb6fe28..02b73ba16c 100644 --- a/kernels/zen/bli_kernels_zen.h +++ b/kernels/zen/bli_kernels_zen.h @@ -79,6 +79,8 @@ SCALV_KER_PROT( double, d, scalv_zen_int ) // scalv (intrinsics unrolled x10) SCALV_KER_PROT( float, s, scalv_zen_int10 ) SCALV_KER_PROT( double, d, scalv_zen_int10 ) +SCALV_KER_PROT( scomplex, c, scalv_zen_int10 ) +SCALV_KER_PROT( dcomplex, z, scalv_zen_int10 ) // swapv (intrinsics) SWAPV_KER_PROT(float, s, swapv_zen_int8 ) diff --git a/test/Makefile b/test/Makefile index 3370ce7157..7521fb7f13 100644 --- a/test/Makefile +++ b/test/Makefile @@ -155,7 +155,7 @@ CFLAGS += -I$(TEST_SRC_PATH) # # Define the operations we will test. 
-TEST_OPS := dotv axpyv \ +TEST_OPS := dotv axpyv scalv \ gemv ger hemv her her2 trmv trsv \ gemm hemm herk her2k trmm trsm \ From 40e1fd186048ee29147f75a82a2adbe25e953625 Mon Sep 17 00:00:00 2001 From: Nallani Bhaskar Date: Mon, 27 Sep 2021 19:10:43 +0530 Subject: [PATCH 029/243] Enabled processing of zero inputs of x in tbsv to provide NaN Details: The basic idea is inverse of singular matrix doesn't exist, therefore we should be returning NAN. BLAS standard and BLIS is optimizing by not doing any compute when x[j]== 0. As a result BLIS is generating finite values for inverse calculation of singular matrices which in reality is not the right answer. Fix is provided in this commit to generate NAN/INF values incase this API is called to compute inverses of singular matrices. But according to the standard, this API shouldn't be called in the first place, the check for singularity or near singularity should be done by the calling application Change-Id: Iccdbc07744de3892626f4066ee4a63eb30bc06cd --- frame/compat/f2c/bla_tbsv.c | 99 +++++++++++++++++++++++++++++++------ 1 file changed, 83 insertions(+), 16 deletions(-) diff --git a/frame/compat/f2c/bla_tbsv.c b/frame/compat/f2c/bla_tbsv.c index 6914882d2b..819456f029 100644 --- a/frame/compat/f2c/bla_tbsv.c +++ b/frame/compat/f2c/bla_tbsv.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -267,7 +268,13 @@ if (*incx == 1) { for (j = *n; j >= 1; --j) { i__1 = j; - if (bli_creal(x[i__1]) != 0.f || bli_cimag(x[i__1]) != 0.f) { + //When matrix A is singular or near singular, the solution to Ax = b is non-trivial + //Therefore inverse of A doesn't exist. Here by commenting out the below lines, + //we end up generating NAN or inf + //Therefore inverse of A doesn't exist. Here by commenting out the below lines, + //we end up generating NAN or inf + //if (bli_creal(x[i__1]) != 0.f || bli_cimag(x[i__1]) != 0.f) + { l = kplus1 - j; if (nounit) { i__1 = j; @@ -297,7 +304,11 @@ for (j = *n; j >= 1; --j) { kx -= *incx; i__1 = jx; - if (bli_creal(x[i__1]) != 0.f || bli_cimag(x[i__1]) != 0.f) { + //When matrix A is singular or near singular, the solution to Ax = b is non-trivial + //Therefore inverse of A doesn't exist. Here by commenting out the below lines, + //we end up generating NAN or inf + //if (bli_creal(x[i__1]) != 0.f || bli_cimag(x[i__1]) != 0.f) + { ix = kx; l = kplus1 - j; if (nounit) { @@ -330,7 +341,11 @@ i__1 = *n; for (j = 1; j <= i__1; ++j) { i__2 = j; - if (bli_creal(x[i__2]) != 0.f || bli_cimag(x[i__2]) != 0.f) { + //When matrix A is singular or near singular, the solution to Ax = b is non-trivial + //Therefore inverse of A doesn't exist. Here by commenting out the below lines, + //we end up generating NAN or inf + //if (bli_creal(x[i__2]) != 0.f || bli_cimag(x[i__2]) != 0.f) + { l = 1 - j; if (nounit) { i__2 = j; @@ -360,7 +375,11 @@ for (j = 1; j <= i__1; ++j) { kx += *incx; i__2 = jx; - if (bli_creal(x[i__2]) != 0.f || bli_cimag(x[i__2]) != 0.f) { + //When matrix A is singular or near singular, the solution to Ax = b is non-trivial + //Therefore inverse of A doesn't exist. 
Here by commenting out the below lines, + //we end up generating NAN or inf + //if (bli_creal(x[i__2]) != 0.f || bli_cimag(x[i__2]) != 0.f) + { ix = kx; l = 1 - j; if (nounit) { @@ -823,7 +842,11 @@ kplus1 = *k + 1; if (*incx == 1) { for (j = *n; j >= 1; --j) { - if (x[j] != 0.) { + //When matrix A is singular or near singular, the solution to Ax = b is non-trivial + //Therefore inverse of A doesn't exist. Here by commenting out the below lines, + //we end up generating NAN or inf + //if (x[j] != 0.) + { l = kplus1 - j; if (nounit) { x[j] /= a[kplus1 + j * a_dim1]; @@ -844,7 +867,11 @@ jx = kx; for (j = *n; j >= 1; --j) { kx -= *incx; - if (x[jx] != 0.) { + //When matrix A is singular or near singular, the solution to Ax = b is non-trivial + //Therefore inverse of A doesn't exist. Here by commenting out the below lines, + //we end up generating NAN or inf + //if (x[jx] != 0.) + { ix = kx; l = kplus1 - j; if (nounit) { @@ -868,7 +895,11 @@ if (*incx == 1) { i__1 = *n; for (j = 1; j <= i__1; ++j) { - if (x[j] != 0.) { + //When matrix A is singular or near singular, the solution to Ax = b is non-trivial + //Therefore inverse of A doesn't exist. Here by commenting out the below lines, + //we end up generating NAN or inf + //if (x[j] != 0.) + { l = 1 - j; if (nounit) { x[j] /= a[j * a_dim1 + 1]; @@ -889,7 +920,11 @@ i__1 = *n; for (j = 1; j <= i__1; ++j) { kx += *incx; - if (x[jx] != 0.) { + //When matrix A is singular or near singular, the solution to Ax = b is non-trivial + //Therefore inverse of A doesn't exist. Here by commenting out the below lines, + //we end up generating NAN or inf + //if (x[jx] != 0.) + { ix = kx; l = 1 - j; if (nounit) { @@ -1238,7 +1273,11 @@ kplus1 = *k + 1; if (*incx == 1) { for (j = *n; j >= 1; --j) { - if (x[j] != 0.f) { + //When matrix A is singular or near singular, the solution to Ax = b is non-trivial + //Therefore inverse of A doesn't exist. Here by commenting out the below lines, + //we end up generating NAN or inf + //if (x[j] != 0.f) + { l = kplus1 - j; if (nounit) { x[j] /= a[kplus1 + j * a_dim1]; @@ -1259,7 +1298,11 @@ jx = kx; for (j = *n; j >= 1; --j) { kx -= *incx; - if (x[jx] != 0.f) { + //When matrix A is singular or near singular, the solution to Ax = b is non-trivial + //Therefore inverse of A doesn't exist. Here by commenting out the below lines, + //we end up generating NAN or inf + //if (x[jx] != 0.f) + { ix = kx; l = kplus1 - j; if (nounit) { @@ -1283,7 +1326,11 @@ if (*incx == 1) { i__1 = *n; for (j = 1; j <= i__1; ++j) { - if (x[j] != 0.f) { + //When matrix A is singular or near singular, the solution to Ax = b is non-trivial + //Therefore inverse of A doesn't exist. Here by commenting out the below lines, + //we end up generating NAN or inf + //if (x[j] != 0.f) + { l = 1 - j; if (nounit) { x[j] /= a[j * a_dim1 + 1]; @@ -1304,7 +1351,11 @@ i__1 = *n; for (j = 1; j <= i__1; ++j) { kx += *incx; - if (x[jx] != 0.f) { + //When matrix A is singular or near singular, the solution to Ax = b is non-trivial + //Therefore inverse of A doesn't exist. Here by commenting out the below lines, + //we end up generating NAN or inf + //if (x[jx] != 0.f) + { ix = kx; l = 1 - j; if (nounit) { @@ -1660,7 +1711,11 @@ if (*incx == 1) { for (j = *n; j >= 1; --j) { i__1 = j; - if (bli_zreal(x[i__1]) != 0. || bli_zimag(x[i__1]) != 0.) { + //When matrix A is singular or near singular, the solution to Ax = b is non-trivial + //Therefore inverse of A doesn't exist. 
Here by commenting out the below lines, + //we end up generating NAN or inf + //if (bli_zreal(x[i__1]) != 0. || bli_zimag(x[i__1]) != 0.) + { l = kplus1 - j; if (nounit) { i__1 = j; @@ -1690,7 +1745,11 @@ for (j = *n; j >= 1; --j) { kx -= *incx; i__1 = jx; - if (bli_zreal(x[i__1]) != 0. || bli_zimag(x[i__1]) != 0.) { + //When matrix A is singular or near singular, the solution to Ax = b is non-trivial + //Therefore inverse of A doesn't exist. Here by commenting out the below lines, + //we end up generating NAN or inf + //if (bli_zreal(x[i__1]) != 0. || bli_zimag(x[i__1]) != 0.) + { ix = kx; l = kplus1 - j; if (nounit) { @@ -1723,7 +1782,11 @@ i__1 = *n; for (j = 1; j <= i__1; ++j) { i__2 = j; - if (bli_zreal(x[i__2]) != 0. || bli_zimag(x[i__2]) != 0.) { + //When matrix A is singular or near singular, the solution to Ax = b is non-trivial + //Therefore inverse of A doesn't exist. Here by commenting out the below lines, + //we end up generating NAN or inf + //if (bli_zreal(x[i__2]) != 0. || bli_zimag(x[i__2]) != 0.) + { l = 1 - j; if (nounit) { i__2 = j; @@ -1753,7 +1816,11 @@ for (j = 1; j <= i__1; ++j) { kx += *incx; i__2 = jx; - if (bli_zreal(x[i__2]) != 0. || bli_zimag(x[i__2]) != 0.) { + //When matrix A is singular or near singular, the solution to Ax = b is non-trivial + //Therefore inverse of A doesn't exist. Here by commenting out the below lines, + //we end up generating NAN or inf + //if (bli_zreal(x[i__2]) != 0. || bli_zimag(x[i__2]) != 0.) + { ix = kx; l = 1 - j; if (nounit) { From a7f600b3a44a9cd48569251114c2649b799fca88 Mon Sep 17 00:00:00 2001 From: Harsh Dave Date: Wed, 29 Sep 2021 05:36:54 -0500 Subject: [PATCH 030/243] Implemented ztrsm small kernels Details: -- AMD Internal Id: CPUPL-1702 -- Used 4x3 ZGEMM kernel with vector fma by utilizing ymm registers efficiently to produce 12 dcomplex outputs at a time -- Used packing of matrix A to effectively cache and reuse -- Implemented kernels using macro based modular approach -- Added ztrsm_small for in ztrsm_ BLAS path for single thread when (m,n)<500 and multithread (m+n)<128 -- Taken care of --disable_pre_inversion configuration -- Achieved 10% average performance improvement for sizes less than 500 -- modularized all 16 combinations of trsm into 4 kernels Change-Id: I3cb42a1385f6b3b82d6c470912242675789cce75 --- frame/compat/bla_trsm.c | 301 +- kernels/zen/3/bli_trsm_small.c | 7479 +++++++++++++++++++++++++++++++- 2 files changed, 7764 insertions(+), 16 deletions(-) diff --git a/frame/compat/bla_trsm.c b/frame/compat/bla_trsm.c index 95ae079cc1..943164d367 100644 --- a/frame/compat/bla_trsm.c +++ b/frame/compat/bla_trsm.c @@ -639,8 +639,307 @@ void dtrsm_ bli_finalize_auto(); } +void ztrsm_ +( + const f77_char* side, + const f77_char* uploa, + const f77_char* transa, + const f77_char* diaga, + const f77_int* m, + const f77_int* n, + const dcomplex* alpha, + const dcomplex* a, const f77_int* lda, + dcomplex* b, const f77_int* ldb +) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO) + AOCL_DTL_LOG_TRSM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'z', + *side, *uploa,*transa, *diaga, *m, *n, + (void*)alpha,*lda, *ldb); + + side_t blis_side; + uplo_t blis_uploa; + trans_t blis_transa; + diag_t blis_diaga; + dim_t m0, n0; + conj_t conja = BLIS_NO_CONJUGATE ; + + /* Initialize BLIS. */ + bli_init_auto(); + + /* Perform BLAS parameter checking. */ + PASTEBLACHK(trsm) + ( + MKSTR(z), + MKSTR(trsm), + side, + uploa, + transa, + diaga, + m, + n, + lda, + ldb + ); + + /* Map BLAS chars to their corresponding BLIS enumerated type value. 
*/ + bli_param_map_netlib_to_blis_side( *side, &blis_side ); + bli_param_map_netlib_to_blis_uplo( *uploa, &blis_uploa ); + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); + bli_param_map_netlib_to_blis_diag( *diaga, &blis_diaga ); + + /* Typecast BLAS integers to BLIS integers. */ + bli_convert_blas_dim1( *m, m0 ); + bli_convert_blas_dim1( *n, n0 ); + + /* Set the row and column strides of the matrix operands. */ + const inc_t rs_a = 1; + const inc_t cs_a = *lda; + const inc_t rs_b = 1; + const inc_t cs_b = *ldb; + const num_t dt = BLIS_DCOMPLEX; + + + if( n0 == 1 ) + { + if( blis_side == BLIS_LEFT ) + { + if(bli_is_notrans(blis_transa)) + { + bli_ztrsv_unf_var2 + ( + blis_uploa, + blis_transa, + blis_diaga, + m0, + (dcomplex*)alpha, + (dcomplex*)a, rs_a, cs_a, + (dcomplex*)b, rs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + else if(bli_is_trans(blis_transa)) + { + bli_ztrsv_unf_var1 + ( + blis_uploa, + blis_transa, + blis_diaga, + m0, + (dcomplex*)alpha, + (dcomplex*)a, rs_a, cs_a, + (dcomplex*)b, rs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + } + else if( ( blis_side == BLIS_RIGHT ) && ( m0 != 1 ) ) + { + /** NOTE: Since for RUCN kernel, function seem to + * be having issue with the computation, which is + * causing make check to fail, For time being, letting + * this particular case through small ztrsm for sake + * of make check. + * TODO: code snippet needs to be enabled, once + * fix is done. + */ + + /* b = alpha * b; */ +/* bli_zscalv_ex + ( + conja, + m0, + (dcomplex*)alpha, + (dcomplex*)b, rs_b, + NULL, + NULL + ); + if(blis_diaga == BLIS_NONUNIT_DIAG) + { + dcomplex inva = {0, 0}; + inva.real = a->real; + inva.imag = (a->imag * -1.0); + double dnm = (a->real * a->real); + dnm += ( (-1.0 * (a->imag * a->imag )) * -1.0 ); + inva.real /= dnm; + inva.imag /= dnm; + for(int indx = 0; indx < m0; indx ++) + { + double real = (inva.real * b[indx].real); + real += ((inva.imag * b[indx].imag) * -1.0); + double imag = (inva.real * b[indx].imag); + imag += (inva.imag * b[indx].real); + b[indx].real = real; + b[indx].imag = imag; + } + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return;*/ + } + } + else if( m0 == 1 ) + { + if(blis_side == BLIS_RIGHT) + { + if(bli_is_notrans(blis_transa)) + { + if(blis_uploa == BLIS_UPPER) + blis_uploa = BLIS_LOWER; + else + blis_uploa = BLIS_UPPER; + + bli_ztrsv_unf_var1 + ( + blis_uploa, + blis_transa, + blis_diaga, + n0, + (dcomplex*)alpha, + (dcomplex*)a, cs_a, rs_a, + (dcomplex*)b, cs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + else if(bli_is_trans(blis_transa)) + { + if(blis_uploa == BLIS_UPPER) + blis_uploa = BLIS_LOWER; + else + blis_uploa = BLIS_UPPER; + + bli_ztrsv_unf_var2 + ( + blis_uploa, + blis_transa, + blis_diaga, + n0, + (dcomplex*)alpha, + (dcomplex*)a, cs_a, rs_a, + (dcomplex*)b, cs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + } + else if(( blis_side == BLIS_LEFT ) && ( n0 != 1 )) + { + /** NOTE: Since for LUCN kernel, function seem to + * be having issue with the computation, which is + * causing make check to fail, For time being, letting + * this particular case through small ztrsm for sake + * of make check. + * TODO: code snippet needs to be enabled, once + * fix is done. 
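+          * The commented-out fallback below scales b by alpha and, for the
+          * non-unit-diagonal case, multiplies it by the reciprocal of the
+          * single diagonal element of A.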
+ */ + /* b = alpha * b; */ +/* bli_zscalv_ex + ( + conja, + n0, + (dcomplex*)alpha, + (dcomplex*)b, cs_b, + NULL, + NULL + ); + if(blis_diaga == BLIS_NONUNIT_DIAG) + { + dcomplex inva = {0, 0}; + inva.real = a->real; + inva.imag = (a->imag * -1.0); + double dnm = (a->real * a->real); + dnm += ( (-1.0 * (a->imag * a->imag )) * -1.0 ); + inva.real /= dnm; + inva.imag /= dnm; + for(int indx = 0; indx < n0; indx ++) + { + double real = (inva.real * b[indx*cs_b].real); + real += ((inva.imag * b[indx*cs_b].imag) * -1.0); + double imag = (inva.real * b[indx*cs_b].imag); + imag += (inva.imag * b[indx*cs_b].real); + b[indx*cs_b].real = real; + b[indx*cs_b].imag = imag; + } + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return;*/ + } + } + + const struc_t struca = BLIS_TRIANGULAR; + + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; + obj_t ao = BLIS_OBJECT_INITIALIZER; + obj_t bo = BLIS_OBJECT_INITIALIZER; + + dim_t mn0_a; + + bli_set_dim_with_side( blis_side, m0, n0, &mn0_a ); + + bli_obj_init_finish_1x1( dt, (dcomplex*)alpha, &alphao ); + + bli_obj_init_finish( dt, mn0_a, mn0_a, (dcomplex*)a, rs_a, cs_a, &ao ); + bli_obj_init_finish( dt, m0, n0, (dcomplex*)b, rs_b, cs_b, &bo ); + + bli_obj_set_uplo( blis_uploa, &ao ); + bli_obj_set_diag( blis_diaga, &ao ); + bli_obj_set_conjtrans( blis_transa, &ao ); + + bli_obj_set_struc( struca, &ao ); + +#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM + /* bli_ztrsm_small is performing better existing native + * implementations for [m,n]<=1000 for single thread. + * In case of multithread when [m,n]<=128 sinlge thread implemenation + * is doing better than native multithread */ + bool nt = bli_thread_get_is_parallel(); + if((nt==0 && m0<500 && n0<500) || + (nt && (m0+n0)<128) ) + { + err_t status; + status = bli_trsm_small + ( + blis_side, + &alphao, + &ao, + &bo, + NULL, + NULL + ); + if (status == BLIS_SUCCESS) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + /* Finalize BLIS. */ + bli_finalize_auto(); + return; + } + } +#endif + + bli_trsmnat + ( + blis_side, + &alphao, + &ao, + &bo, + NULL, + NULL + ); + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) + /* Finalize BLIS. 
*/ + bli_finalize_auto(); +} + + GENTFUNC( float, s, trsm, trsm ) -INSERT_GENTFUNC_BLAS_CZ( trsm, trsm ) +GENTFUNC( scomplex, c, trsm, trsm ) #else INSERT_GENTFUNC_BLAS( trsm, trsm ) #endif diff --git a/kernels/zen/3/bli_trsm_small.c b/kernels/zen/3/bli_trsm_small.c index ea9de2a889..6e8455d024 100644 --- a/kernels/zen/3/bli_trsm_small.c +++ b/kernels/zen/3/bli_trsm_small.c @@ -119,6 +119,121 @@ BLIS_INLINE err_t dtrsm_AltXB_ref dim_t ldb, bool is_unitdiag ); +/* + * ZTRSM kernel declaration + */ +BLIS_INLINE err_t bli_ztrsm_small_AutXB_AlXB +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +); + +BLIS_INLINE err_t bli_ztrsm_small_AltXB_AuXB +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +); + +BLIS_INLINE err_t bli_ztrsm_small_XAutB_XAlB +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +); + +BLIS_INLINE err_t bli_ztrsm_small_XAltB_XAuB +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +); +/* + * CTRSM kernel declaration + */ +BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +); + +BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +); + +BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +); + +BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +); +/* + * STRSM kernel declaration + */ +BLIS_INLINE err_t bli_strsm_small_AutXB_AlXB +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +); + +BLIS_INLINE err_t bli_strsm_small_AltXB_AuXB +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +); + +BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +); + +BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +); + /* * The preinversion of diagonal elements are enabled/disabled @@ -1765,16 +1880,43 @@ BLIS_INLINE void dtrsm_small_pack_diag_element /* * Kernels Table */ -trsmsmall_ker_ft ker_fps[8] = +trsmsmall_ker_ft ker_fps[4][8] = { - bli_dtrsm_small_AutXB_AlXB, - bli_dtrsm_small_AltXB_AuXB, - bli_dtrsm_small_AltXB_AuXB, - bli_dtrsm_small_AutXB_AlXB, - bli_dtrsm_small_XAutB_XAlB, - bli_dtrsm_small_XAltB_XAuB, - bli_dtrsm_small_XAltB_XAuB, - bli_dtrsm_small_XAutB_XAlB + {bli_strsm_small_AutXB_AlXB, + bli_strsm_small_AltXB_AuXB, + bli_strsm_small_AltXB_AuXB, + bli_strsm_small_AutXB_AlXB, + bli_strsm_small_XAutB_XAlB, + bli_strsm_small_XAltB_XAuB, + bli_strsm_small_XAltB_XAuB, + bli_strsm_small_XAutB_XAlB }, + + {bli_ctrsm_small_AutXB_AlXB, + bli_ctrsm_small_AltXB_AuXB, + bli_ctrsm_small_AltXB_AuXB, + bli_ctrsm_small_AutXB_AlXB, + bli_ctrsm_small_XAutB_XAlB, + bli_ctrsm_small_XAltB_XAuB, + bli_ctrsm_small_XAltB_XAuB, + bli_ctrsm_small_XAutB_XAlB }, + + {bli_dtrsm_small_AutXB_AlXB, + bli_dtrsm_small_AltXB_AuXB, + bli_dtrsm_small_AltXB_AuXB, + bli_dtrsm_small_AutXB_AlXB, + bli_dtrsm_small_XAutB_XAlB, + bli_dtrsm_small_XAltB_XAuB, + bli_dtrsm_small_XAltB_XAuB, + bli_dtrsm_small_XAutB_XAlB }, + + {bli_ztrsm_small_AutXB_AlXB, + bli_ztrsm_small_AltXB_AuXB, + bli_ztrsm_small_AltXB_AuXB, + bli_ztrsm_small_AutXB_AlXB, + bli_ztrsm_small_XAutB_XAlB, + bli_ztrsm_small_XAltB_XAuB, + bli_ztrsm_small_XAltB_XAuB, + 
bli_ztrsm_small_XAutB_XAlB }, }; /* @@ -1834,7 +1976,7 @@ err_t bli_trsm_small //Curretnly optimized for double data type only num_t dt = bli_obj_dt(a); - if (dt != BLIS_DOUBLE) { + if (dt != BLIS_DOUBLE && dt != BLIS_DCOMPLEX) { return BLIS_NOT_YET_IMPLEMENTED; } @@ -1847,12 +1989,12 @@ err_t bli_trsm_small * Compose kernel index based on inputs */ + dim_t keridx = ( (( side & 0x1) << 2) | (( uplo & 0x1) << 1) | ( transa & 0x1) ); - - trsmsmall_ker_ft ker_fp = ker_fps[ keridx ]; + trsmsmall_ker_ft ker_fp = ker_fps[dt][ keridx ]; /*Call the kernel*/ err = ker_fp @@ -1863,7 +2005,6 @@ err_t bli_trsm_small cntx, cntl ); - return err; }; @@ -6154,7 +6295,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); @@ -10720,4 +10861,7312 @@ BLIS_INLINE err_t bli_dtrsm_small_AutXB_AlXB } return BLIS_SUCCESS; } -#endif //BLIS_ENABLE_SMALL_MATRIX_TRSM \ No newline at end of file + +/* + * ZTRSM utilities and kernel functions + */ + +#define DCOMPLEX_INV(a, b) {\ + a.real = b.real;\ + a.imag = (b.imag * -1.0);\ + /*Compute denominator eliminating imaginary component*/\ + double dnm = (b.real * b.real);\ + /*multiply two times with -1 for correct result as + * dcomplex number with positive imaginary part will + * invert the sign if not multiplied twice with -1*/\ + dnm += ((-1.0 * (b.imag * b.imag)) * -1.0);\ + /*Compute the final result by dividing real and imag part by dnm*/\ + a.real /= dnm;\ + a.imag /= dnm;\ +} + +#define DCOMPLEX_MUL(a, b, c) {\ + double real = a.real * b.real;\ + real += ((a.imag * b.imag) * -1.0);\ + double imag = (a.real * b.imag);\ + imag += (a.imag * b.real);\ + c.real = real;\ + c.imag = imag;\ +} + +#define DCOMPLEX_DIV(a, b){\ + double dnm = b.real * b.real;\ + dnm += (-1.0 * (b.imag * (b.imag * -1.0) ));\ + a.real /= dnm;\ + a.imag /= dnm;\ +} + + +#ifdef BLIS_ENABLE_TRSM_PREINVERSION +#define ZTRSM_DIAG_ELE_INV_OPS(a,b){\ + DCOMPLEX_INV(a, b)\ +} +#endif + +#ifdef BLIS_DISABLE_TRSM_PREINVERSION +#define ZTRSM_DIAG_ELE_INV_OPS(a,b) {\ + a.real = b.real;\ + a.imag = b.imag;\ +} +#endif + + +#ifdef BLIS_ENABLE_TRSM_PREINVERSION +#define ZTRSM_DIAG_ELE_EVAL_OPS(a,b,c){\ + if(!is_unitdiag)\ + DCOMPLEX_MUL(b, c, c)\ +} +#endif + +#ifdef BLIS_DISABLE_TRSM_PREINVERSION +#define ZTRSM_DIAG_ELE_EVAL_OPS(a,b,c){\ + if(!is_unitdiag)\ + {\ + a.real = b.real;\ + a.imag = (b.imag * -1.0);\ + DCOMPLEX_MUL(c, a, c)\ + DCOMPLEX_DIV(c, b)\ + }\ +} +#endif + +BLIS_INLINE err_t ztrsm_AltXB_ref +( + dcomplex *A, + dcomplex *B, + dim_t M, + dim_t N, + dim_t lda, + dim_t ldb, + bool is_unitdiag, + bool conjtransa +) +{ + dim_t i, j, k; + for (k = M-1; k >= 0; k--) + { + dcomplex lkk_inv = {1.0, 1.0}, cur_compute = {0.0, 0.0}, A_trans = {0.0, 0.0}; + if(!is_unitdiag) + { + ZTRSM_DIAG_ELE_INV_OPS(lkk_inv, A[k+k*lda]) + if(conjtransa) + { + lkk_inv.imag *= -1.0; + } + } + for (j = N -1; j >= 0; j--) + { + ZTRSM_DIAG_ELE_EVAL_OPS(cur_compute, lkk_inv, B[k + j*ldb]) + for (i = k-1; i >=0; i--) + { + if(conjtransa) + { + A_trans.real = A[i*lda + k].real; + A_trans.imag = A[i*lda + k].imag * -1.0; + } + else + { + A_trans.real = A[i*lda + k].real; + A_trans.imag = A[i*lda + k].imag; + } + + + DCOMPLEX_MUL(A_trans, B[k+j*ldb], cur_compute) + 
B[i + j*ldb].real -= cur_compute.real; + B[i + j*ldb].imag -= cur_compute.imag; + } + } + } + return BLIS_SUCCESS; +} + +BLIS_INLINE err_t ztrsm_AutXB_ref +( + dcomplex *A, + dcomplex *B, + dim_t M, + dim_t N, + dim_t lda, + dim_t ldb, + bool is_unitdiag, + bool conjtransa +) +{ + dim_t i, j, k; + for (k = 0; k < M; k++) + { + dcomplex lkk_inv = {1.0, 1.0}, cur_compute = {0.0, 0.0}, A_trans = {0.0, 0.0}; + if(!is_unitdiag) + { + ZTRSM_DIAG_ELE_INV_OPS(lkk_inv, A[k+k*lda]) + if(conjtransa) + { + lkk_inv.imag *= -1.0; + } + } + + for (j = 0; j < N; j++) + { + ZTRSM_DIAG_ELE_EVAL_OPS(cur_compute, lkk_inv, B[k + j*ldb]) + for (i = k+1; i < M; i++) + { + if(conjtransa) + { + A_trans.real = A[k+i*lda].real; + A_trans.imag = A[k+i*lda].imag * -1.0; + } + else + { + A_trans.real = A[k+i*lda].real; + A_trans.imag = A[k+i*lda].imag; + } + + DCOMPLEX_MUL(A_trans, B[k+j*ldb], cur_compute) + B[i + j*ldb].real -= cur_compute.real; + B[i + j*ldb].imag -= cur_compute.imag; + } + + } + + } + return BLIS_SUCCESS; +} + +BLIS_INLINE err_t ztrsm_AlXB_ref +( + dcomplex *A, + dcomplex *B, + dim_t M, + dim_t N, + dim_t lda, + dim_t ldb, + bool is_unitdiag, + bool conjtransa +) +{ + dim_t i, j, k; + for (k = 0; k < M; k++) + { + dcomplex lkk_inv = {1.0, 1.0}, cur_compute = {0.0, 0.0}, A_trans = {0.0, 0.0}; + if(!is_unitdiag) + { + ZTRSM_DIAG_ELE_INV_OPS(lkk_inv, A[k+k*lda]) + if(conjtransa) + { + lkk_inv.imag *= -1.0; + } + } + for (j = 0; j < N; j++) + { + ZTRSM_DIAG_ELE_EVAL_OPS(cur_compute, lkk_inv, B[k + j*ldb]) + for (i = k+1; i < M; i++) + { + if(conjtransa) + { + A_trans.real = A[i+k*lda].real; + A_trans.imag = A[i+k*lda].imag * -1.0; + } + else + { + A_trans.real = A[i+k*lda].real; + A_trans.imag = A[i+k*lda].imag; + } + DCOMPLEX_MUL(A_trans, B[k+j*ldb], cur_compute) + B[i + j*ldb].real -= cur_compute.real; + B[i + j*ldb].imag -= cur_compute.imag; + } + } + } + return BLIS_SUCCESS; +} + +BLIS_INLINE err_t ztrsm_AuXB_ref +( + dcomplex *A, + dcomplex *B, + dim_t M, + dim_t N, + dim_t lda, + dim_t ldb, + bool is_unitdiag, + bool conjtransa +) +{ + dim_t i, j, k; + for (k = M-1; k >= 0; k--) + { + dcomplex lkk_inv = {1.0, 1.0}, cur_compute = {0.0, 0.0}, A_trans = {0.0, 0.0}; + if(!is_unitdiag) + { + ZTRSM_DIAG_ELE_INV_OPS(lkk_inv, A[k+k*lda]) + if(conjtransa) + { + lkk_inv.imag *= -1.0; + } + + } + for (j = N -1; j >= 0; j--) + { + ZTRSM_DIAG_ELE_EVAL_OPS(cur_compute, lkk_inv, B[k + j*ldb]) + for (i = k-1; i >=0; i--) + { + if(conjtransa) + { + A_trans.real = A[i+k*lda].real; + A_trans.imag = A[i+k*lda].imag * -1.0; + } + else + { + A_trans.real = A[i+k*lda].real; + A_trans.imag = A[i+k*lda].imag; + } + + DCOMPLEX_MUL(A_trans, B[k+j*ldb], cur_compute) + B[i + j*ldb].real -= cur_compute.real; + B[i + j*ldb].imag -= cur_compute.imag; + } + } + } + return BLIS_SUCCESS; +} + +/** + * Multiplies Alpha with one dcomplex + * element of one column. 
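+ * Computes alpha*b11 minus the GEMM accumulator held in ymm8 and
+ * stores the result back to b11.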
+ * One xmm register holds one dcomplex + * element only(real(64 bit) + imaginary(64 bit)) + */ +#define BLIS_PRE_ZTRSM_SMALL_1M_1N(AlphaVal,b11,cs_b) {\ + /*register to hold alpha*/\ + ymm16 = _mm256_broadcast_pd(( __m128d const *)(&AlphaVal));\ + \ + /*load dcomplex elements*/\ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 0));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ + /*to negate the real part of complex number*/\ + ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ + /*dcomplex multiplication and substraction*/\ + /*swaps position of real and imag components of complex number*/\ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + /*multiply with modified vec2 */\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm0, ymm16);\ + /*multiply with vec2 */\ + ymm14 = _mm256_mul_pd(ymm0, ymm14);\ + /*get the dcomplex mul answer into register*/\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm8 = _mm256_sub_pd(ymm15,ymm8);\ + xmm5 = _mm256_extractf128_pd(ymm8, 0);\ + /*store dcomplex elements*/\ + _mm_storeu_pd((double *)(b11 + cs_b * 0), xmm5);\ +} + +/** + * Multiplies Alpha with one dcomplex + * element of two columns. + */ +#define BLIS_PRE_ZTRSM_SMALL_1M_2N(AlphaVal,b11,cs_b) {\ + /*register to hold alpha*/\ + ymm16 = _mm256_broadcast_pd(( __m128d const*)(&AlphaVal));\ + \ + /*ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));*/\ + xmm4 = _mm_loadu_pd((double const *)(b11 + cs_b * 0));\ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 1));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm4, 0);\ + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0);\ + /*to negate the real part of complex number*/\ + ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ + /*swaps position of real and imag components of complex number*/\ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + /*dcomplex multiplication and substraction*/\ + /*multiply with modified vec2 */\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm0, ymm16);\ + /*multiply with vec2 */\ + ymm14 = _mm256_mul_pd(ymm0, ymm14);\ + /*get the dcomplex mul answer into register*/\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm8 = _mm256_sub_pd(ymm15,ymm8);\ + \ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm1, ymm16);\ + ymm14 = _mm256_mul_pd(ymm1, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm9 = _mm256_sub_pd(ymm15,ymm9);\ + xmm4 = _mm256_extractf128_pd(ymm8, 0);\ + _mm_storeu_pd((double *)(b11 + cs_b * 0), xmm4);\ + xmm5 = _mm256_extractf128_pd(ymm9, 0);\ + _mm_storeu_pd((double *)(b11 + cs_b * 1), xmm5);\ +} + +#define BLIS_ZTRSM_SMALL_NREG_TRANSPOSE_1x4(b11,cs_b,AlphaVal) {\ + ymm16 = _mm256_broadcast_pd(( __m128d const *)&AlphaVal);\ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));\ + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 2));\ + ymm1 = _mm256_broadcast_pd((__m128d const *)(&ones));\ + ymm5 = _mm256_broadcast_pd((__m128d const *)(&ones));\ + \ + ymm14 = _mm256_shuffle_pd(ymm16, ymm16, 5);\ + \ + /*dcomplex multiplication and substraction*/\ + ymm17 = _mm256_shuffle_pd(ymm0, ymm0, 15);\ + ymm18 = _mm256_shuffle_pd(ymm0, ymm0,0);\ + ymm19 = _mm256_mul_pd(ymm17, ymm14);\ + ymm15 = _mm256_fmaddsub_pd(ymm18, ymm16, ymm19);\ + ymm0 = _mm256_sub_pd(ymm15, ymm8);\ + \ + /*dcomplex multiplication and substraction*/\ + ymm17 = _mm256_shuffle_pd(ymm4, ymm4, 15);\ + ymm18 = _mm256_shuffle_pd(ymm4, ymm4,0);\ + ymm19 = _mm256_mul_pd(ymm17, ymm14);\ + ymm15 = _mm256_fmaddsub_pd(ymm18, ymm16, ymm19);\ + ymm4 = _mm256_sub_pd(ymm15, ymm12);\ +} + 
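+/*
+ * Illustrative sketch (editorial note, not referenced by the kernels in
+ * this file): the shuffle + fmaddsub complex-multiply pattern used by the
+ * NREG_TRANSPOSE and PRE_ZTRSM_SMALL macros, written out for one __m256d
+ * holding two dcomplex values.  The helper name is made up for exposition;
+ * assumes AVX2 and FMA, e.g. build with gcc -O2 -mavx2 -mfma.
+ */
+#include "immintrin.h"  /* already included at the top of this file */
+
+static inline __m256d zscal2_sketch( __m256d alphav, __m256d bv )
+{
+    /* alphav = ar, ai, ar, ai          bv = br0, bi0, br1, bi1          */
+    __m256d alpha_sw = _mm256_shuffle_pd( alphav, alphav, 5 );  /* ai, ar, ai, ar     */
+    __m256d b_imag   = _mm256_shuffle_pd( bv, bv, 15 );         /* bi0, bi0, bi1, bi1 */
+    __m256d b_real   = _mm256_shuffle_pd( bv, bv, 0 );          /* br0, br0, br1, br1 */
+    __m256d t        = _mm256_mul_pd( b_imag, alpha_sw );       /* ai*bi, ar*bi, ...  */
+    /* fmaddsub subtracts in the even lanes and adds in the odd lanes,
+     * giving (ar*br - ai*bi) and (ar*bi + ai*br) per dcomplex, e.g.
+     * (1+2i)*(3+4i) = -5+10i.  The TRSM macros then subtract the GEMM
+     * accumulator from this product with _mm256_sub_pd. */
+    return _mm256_fmaddsub_pd( b_real, alphav, t );
+}
+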
+/** + * Multiplies Alpha with two dcomplex + * elements of one column and store it into + * buffer b11. + */ +#define BLIS_PRE_ZTRSM_SMALL_2M_1N(AlphaVal,b11,cs_b) {\ + ymm16 = _mm256_broadcast_pd(( __m128d const*)(&AlphaVal));\ + \ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0));\ + ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ + /*dcomplex multiplication and substraction*/\ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm0, ymm16);\ + ymm14 = _mm256_mul_pd(ymm0, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm8 = _mm256_sub_pd(ymm15,ymm8);\ + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm8);\ +} + +/** + * Multiplies Alpha with two elements of + * two columns and store the result in buffer b11 + * + */ +#define BLIS_PRE_ZTRSM_SMALL_2M_2N(AlphaVal,b11,cs_b){\ + ymm16 = _mm256_broadcast_pd(( __m128d const*)(&AlphaVal));\ + \ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));\ + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1));\ + ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ + /*dcomplex multiplication and substraction*/\ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm0, ymm16);\ + ymm14 = _mm256_mul_pd(ymm0, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm8 = _mm256_sub_pd(ymm15,ymm8);\ + \ + /*dcomplex multiplication and substraction*/\ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm1, ymm16);\ + ymm14 = _mm256_mul_pd(ymm1, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm9 = _mm256_sub_pd(ymm15,ymm9);\ + \ + _mm256_storeu_pd((double *)(b11), ymm8);\ + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm9);\ +} + +/** + * Performs GEMM operation. + * Two elements of column in ymm0 + * ymm1, ymm2 holds respective broadcasted element. 
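+ * Real-part and imaginary-part partial products are accumulated
+ * separately (ymm8-ymm10 and ymm4-ymm6) and merged with addsub at the end.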
+ */ +#define BLIS_ZTRSM_SMALL_GEMM_2mx3n(a10,b01,cs_b,p_lda,k_iter){\ + double *tptr = (double *)b01;\ + if(conjtransa) {\ + ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_pd((double const *)(a10));\ + ymm0 = _mm256_mul_pd(ymm0, ymm18);\ + \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0));\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0 + 1));\ + \ + ymm8 = _mm256_fmadd_pd(ymm0, ymm1, ymm8);\ + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4);\ + \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1));\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1 + 1));\ + \ + ymm9 = _mm256_fmadd_pd(ymm0, ymm1, ymm9);\ + ymm5 = _mm256_fmadd_pd(ymm0, ymm2, ymm5);\ + \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 2));\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 2 + 1));\ + \ + ymm10 = _mm256_fmadd_pd(ymm0, ymm1, ymm10);\ + ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6);\ + \ + tptr += 2; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_pd((double const *)(a10));\ + \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0));\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0 + 1));\ + \ + ymm8 = _mm256_fmadd_pd(ymm0, ymm1, ymm8);\ + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4);\ + \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1));\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1 + 1));\ + \ + ymm9 = _mm256_fmadd_pd(ymm0, ymm1, ymm9);\ + ymm5 = _mm256_fmadd_pd(ymm0, ymm2, ymm5);\ + \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 2));\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 2 + 1));\ + \ + ymm10 = _mm256_fmadd_pd(ymm0, ymm1, ymm10);\ + ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6);\ + \ + tptr += 2; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + }\ + }\ + ymm4 = _mm256_permute_pd(ymm4, 0x5);\ + ymm5 = _mm256_permute_pd(ymm5, 0x5);\ + ymm6 = _mm256_permute_pd(ymm6, 0x5);\ + ymm8 = _mm256_addsub_pd(ymm8, ymm4);\ + ymm9 = _mm256_addsub_pd(ymm9, ymm5);\ + ymm10 = _mm256_addsub_pd(ymm10, ymm6);\ +} + +/** + * Performs GEMM operation. + * Four elements of column in ymm0, ymm1. + * ymm2, ymm7 holds respective broadcasted element. 
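+ * ymm3/ymm4 collect the real-broadcast products and ymm5/ymm6 the
+ * imaginary-broadcast products; they are merged with addsub at the end.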
+ */ +#define BLIS_ZTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) {\ + double *tptr = (double *)a01;\ + if(conjtransa) {\ + ymm18 = _mm256_set_pd(-1.0, -1.0, -1.0, -1.0);\ + for(k = 0; k < k_iter; k++)\ + {\ + ymm0 = _mm256_loadu_pd((double const *)b10);\ + ymm1 = _mm256_loadu_pd((double const *)(b10 + 2));\ + \ + _mm_prefetch((char*)( b10 + 4*cs_b), _MM_HINT_T0); \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0));\ + ymm7 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1));\ + ymm7 = _mm256_mul_pd(ymm7, ymm18);\ + /*dcomplex multiplication and substraction*/\ + \ + ymm3 = _mm256_fmadd_pd(ymm0, ymm2, ymm3);\ + ymm4 = _mm256_fmadd_pd(ymm1, ymm2, ymm4);\ + ymm5 = _mm256_fmadd_pd(ymm0, ymm7, ymm5);\ + ymm6 = _mm256_fmadd_pd(ymm1, ymm7, ymm6);\ + /*dcomplex multiplication and substraction*/\ + \ + tptr += 2;\ + b10 += cs_b;\ + }\ + }\ + else {\ + for(k = 0; k < k_iter; k++)\ + {\ + ymm0 = _mm256_loadu_pd((double const *)b10);\ + ymm1 = _mm256_loadu_pd((double const *)(b10 + 2));\ + \ + _mm_prefetch((char*)( b10 + 4*cs_b), _MM_HINT_T0); \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0));\ + ymm7 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1));\ + /*dcomplex multiplication and substraction*/\ + \ + ymm3 = _mm256_fmadd_pd(ymm0, ymm2, ymm3);\ + ymm4 = _mm256_fmadd_pd(ymm1, ymm2, ymm4);\ + ymm5 = _mm256_fmadd_pd(ymm0, ymm7, ymm5);\ + ymm6 = _mm256_fmadd_pd(ymm1, ymm7, ymm6);\ + /*ymm3 = _mm256_add_pd(ymm15, ymm3);*/\ + /*dcomplex multiplication and substraction*/\ + \ + tptr += 2;\ + b10 += cs_b;\ + }\ + }\ + ymm5 = _mm256_permute_pd(ymm5, 0x5);\ + ymm6 = _mm256_permute_pd(ymm6, 0x5);\ +\ + ymm3 = _mm256_addsub_pd(ymm3, ymm5);\ + ymm4 = _mm256_addsub_pd(ymm4, ymm6);\ +} + +/** + * Multiplies Alpha with 4 elements of column + */ +#define BLIS_PRE_ZTRSM_SMALL_1x4(b11,cs_b,AlphaVal) {\ + ymm16 = _mm256_broadcast_pd((__m128d const *)&AlphaVal);\ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));\ + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 2));\ +\ + ymm14 = _mm256_shuffle_pd(ymm16, ymm16, 5);\ +\ + ymm17 = _mm256_shuffle_pd(ymm0, ymm0, 15);\ + ymm18 = _mm256_shuffle_pd(ymm0, ymm0,0);\ + ymm19 = _mm256_mul_pd(ymm17, ymm14);\ + ymm15 = _mm256_fmaddsub_pd(ymm18, ymm16, ymm19);\ + ymm3 = _mm256_sub_pd(ymm15, ymm3);\ +\ + ymm17 = _mm256_shuffle_pd(ymm1, ymm1, 15);\ + ymm18 = _mm256_shuffle_pd(ymm1, ymm1,0);\ + ymm19 = _mm256_mul_pd(ymm17, ymm14);\ + ymm15 = _mm256_fmaddsub_pd(ymm18, ymm16, ymm19);\ + ymm4 = _mm256_sub_pd(ymm15, ymm4);\ +} + +/** + * Multiplies Alpha with 3 elements of column. + * ymm0 holds first 2 element and xmm5 holds the + * 3rd one. 
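+ * The results, alpha*b11 minus the GEMM accumulators, are left in
+ * ymm3 (first two elements) and ymm4 (third element); nothing is
+ * stored back here.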
+ */ +#define BLIS_PRE_ZTRSM_SMALL_1x3(b11,cs_b,AlphaVal) {\ + ymm16 = _mm256_broadcast_pd((__m128d const *)&AlphaVal);\ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));\ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 0 + 2));\ + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0);\ +\ + ymm14 = _mm256_shuffle_pd(ymm16, ymm16, 5);\ +\ + ymm17 = _mm256_shuffle_pd(ymm0, ymm0, 15);\ + ymm18 = _mm256_shuffle_pd(ymm0, ymm0,0);\ + ymm19 = _mm256_mul_pd(ymm17, ymm14);\ + ymm15 = _mm256_fmaddsub_pd(ymm18, ymm16, ymm19);\ + ymm3 = _mm256_sub_pd(ymm15, ymm3);\ +\ + ymm17 = _mm256_shuffle_pd(ymm1, ymm1, 15);\ + ymm18 = _mm256_shuffle_pd(ymm1, ymm1,0);\ + ymm19 = _mm256_mul_pd(ymm17, ymm14);\ + ymm15 = _mm256_fmaddsub_pd(ymm18, ymm16, ymm19);\ + ymm4 = _mm256_sub_pd(ymm15, ymm4);\ +} + +#define BLIS_ZTRSM_SMALL_NREG_TRANSPOSE_2x4(b11,cs_b,AlphaVal) {\ + ymm16 = _mm256_broadcast_pd((__m128d const *)&AlphaVal);\ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));\ + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1));\ + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 2));\ + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 2));\ + ymm14 = _mm256_shuffle_pd(ymm16, ymm16, 5);\ +\ + ymm17 = _mm256_shuffle_pd(ymm0, ymm0, 15);\ + ymm18 = _mm256_shuffle_pd(ymm0, ymm0,0);\ + ymm19 = _mm256_mul_pd(ymm17, ymm14);\ + ymm15 = _mm256_fmaddsub_pd(ymm18, ymm16, ymm19);\ + ymm0 = _mm256_sub_pd(ymm15, ymm8);\ +\ + ymm17 = _mm256_shuffle_pd(ymm1, ymm1, 15);\ + ymm18 = _mm256_shuffle_pd(ymm1, ymm1,0);\ + ymm19 = _mm256_mul_pd(ymm17, ymm14);\ + ymm15 = _mm256_fmaddsub_pd(ymm18, ymm16, ymm19);\ + ymm1 = _mm256_sub_pd(ymm15, ymm9);\ +\ + ymm17 = _mm256_shuffle_pd(ymm4, ymm4, 15);\ + ymm18 = _mm256_shuffle_pd(ymm4, ymm4,0);\ + ymm19 = _mm256_mul_pd(ymm17, ymm14);\ + ymm15 = _mm256_fmaddsub_pd(ymm18, ymm16, ymm19);\ + ymm4 = _mm256_sub_pd(ymm15, ymm12);\ +\ + ymm17 = _mm256_shuffle_pd(ymm5, ymm5, 15);\ + ymm18 = _mm256_shuffle_pd(ymm5, ymm5,0);\ + ymm19 = _mm256_mul_pd(ymm17, ymm14);\ + ymm15 = _mm256_fmaddsub_pd(ymm18, ymm16, ymm19);\ + ymm5 = _mm256_sub_pd(ymm15, ymm13);\ +} + +#define BLIS_PRE_ZTRSM_SMALL_3M_1N(AlphaVal,b11,cs_b){\ + ymm16 = _mm256_broadcast_pd(( __m128d const *)(&AlphaVal));\ + \ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0));\ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 0 + 2));\ + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0);\ + \ + ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ + /*dcomplex multiplication and substraction*/\ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm0, ymm16);\ + ymm14 = _mm256_mul_pd(ymm0, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm8 = _mm256_sub_pd(ymm15,ymm8);\ + \ + /*dcomplex multiplication and substraction*/\ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm1, ymm16);\ + ymm14 = _mm256_mul_pd(ymm1, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm12 = _mm256_sub_pd(ymm15,ymm12);\ + \ + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm8);\ + xmm5 = _mm256_extractf128_pd(ymm12, 0);\ + _mm_storeu_pd((double *)(b11 + cs_b * 0 + 2), xmm5);\ +} + +/** + * Multiplies Alpha with 3 elements of 2 columns + * and store into buffer b11. + * ymm0 ymm1 holds first 2 elements of 2 columns. + * xmm4 xmm5 holds the 3rd elements of 2 columns. 
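+ * The scaled results (alpha*b11 minus the accumulators) are stored
+ * straight back to b11, with 128-bit stores used for the third
+ * element of each column.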
+ */ +#define BLIS_PRE_ZTRSM_SMALL_3M_2N(AlphaVal,b11,cs_b){\ + ymm16 = _mm256_broadcast_pd(( __m128d const*)(&AlphaVal));\ + \ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));\ + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1));\ + xmm4 = _mm_loadu_pd((double const *)(b11 + cs_b * 0 + 2));\ + ymm3 = _mm256_insertf128_pd(ymm3, xmm4, 0);\ +\ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 1 + 2));\ + ymm4 = _mm256_insertf128_pd(ymm4, xmm5, 0);\ +\ + ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ + /*dcomplex multiplication and substraction*/\ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm0, ymm16);\ + ymm14 = _mm256_mul_pd(ymm0, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm8 = _mm256_sub_pd(ymm15,ymm8);\ + \ + /*dcomplex multiplication and substraction*/\ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm1, ymm16);\ + ymm14 = _mm256_mul_pd(ymm1, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm9 = _mm256_sub_pd(ymm15,ymm9);\ + \ + /*dcomplex multiplication and substraction*/\ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm3, ymm16);\ + ymm14 = _mm256_mul_pd(ymm3, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm12 = _mm256_sub_pd(ymm15,ymm12);\ + \ + /*dcomplex multiplication and substraction*/\ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm4, ymm16);\ + ymm14 = _mm256_mul_pd(ymm4, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm13 = _mm256_sub_pd(ymm15,ymm13);\ + \ + _mm256_storeu_pd((double *)(b11), ymm8);\ + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm9);\ + xmm4 = _mm256_extractf128_pd(ymm12, 0);\ + _mm_storeu_pd((double *)(b11 + cs_b * 0 + 2), xmm4);\ + xmm5 = _mm256_extractf128_pd(ymm13, 0);\ + _mm_storeu_pd((double *)(b11 + cs_b * 1 + 2), xmm5);\ +} + +/** + * Performs GEMM operation + * ymm0 holds 2 elements of column. 
+ * ymm4 ymm6 holds broadcasted elements respectively + */ +#define BLIS_ZTRSM_SMALL_GEMM_3nx2m(a01,b10,cs_b,p_lda,k_iter) {\ + double *tptr = (double *)a01;\ + if(conjtransa) {\ + ymm18 = _mm256_set_pd(-1.0, -1.0, -1.0, -1.0);\ + for(k = 0; k< k_iter; k++) \ + {\ + ymm0 = _mm256_loadu_pd((double const *)(b10)); \ + \ + _mm_prefetch((char*)( b10 + 2*cs_b), _MM_HINT_T0); \ + ymm4 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0)); \ + ymm6 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1)); \ + ymm6 = _mm256_mul_pd(ymm6, ymm18);\ + /*dcomplex multiplication and substraction*/\ + \ + ymm3 = _mm256_fmadd_pd(ymm0, ymm4, ymm3);\ + ymm8 = _mm256_fmadd_pd(ymm0, ymm6, ymm8);\ + \ + ymm4 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1)); \ + ymm6 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1 + 1)); \ + ymm6 = _mm256_mul_pd(ymm6, ymm18);\ + \ + /*dcomplex multiplication and substraction*/\ + \ + ymm5 = _mm256_fmadd_pd(ymm0, ymm4, ymm5);\ + ymm9 = _mm256_fmadd_pd(ymm0, ymm6, ymm9);\ + \ + ymm4 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 2)); \ + ymm6 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 2 + 1)); \ + ymm6 = _mm256_mul_pd(ymm6, ymm18);\ + \ + /*dcomplex multiplication and substraction*/\ + \ + ymm7 = _mm256_fmadd_pd(ymm0, ymm4, ymm7);\ + ymm10 = _mm256_fmadd_pd(ymm0, ymm6, ymm10);\ + \ + tptr += 2; \ + b10 += cs_b; \ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) \ + {\ + ymm0 = _mm256_loadu_pd((double const *)(b10)); \ + \ + _mm_prefetch((char*)( b10 + 2*cs_b), _MM_HINT_T0); \ + ymm4 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0)); \ + ymm6 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1)); \ + /*dcomplex multiplication and substraction*/\ + \ + ymm3 = _mm256_fmadd_pd(ymm0, ymm4, ymm3);\ + ymm8 = _mm256_fmadd_pd(ymm0, ymm6, ymm8);\ + /*ymm3 = _mm256_add_pd(ymm15, ymm3);*/\ + \ + ymm4 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1)); \ + ymm6 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1 + 1)); \ + \ + /*dcomplex multiplication and substraction*/\ + \ + ymm5 = _mm256_fmadd_pd(ymm0, ymm4, ymm5);\ + ymm9 = _mm256_fmadd_pd(ymm0, ymm6, ymm9);\ + /*ymm5 = _mm256_add_pd(ymm15, ymm5);*/\ + \ + ymm4 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 2)); \ + ymm6 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 2 + 1)); \ + \ + /*dcomplex multiplication and substraction*/\ + \ + ymm7 = _mm256_fmadd_pd(ymm0, ymm4, ymm7);\ + ymm10 = _mm256_fmadd_pd(ymm0, ymm6, ymm10);\ + /*ymm7 = _mm256_add_pd(ymm15, ymm7);*/\ + \ + tptr += 2; \ + b10 += cs_b; \ + }\ + }\ + ymm8 = _mm256_permute_pd(ymm8, 0x5);\ + ymm9 = _mm256_permute_pd(ymm9, 0x5);\ + ymm10 = _mm256_permute_pd(ymm10, 0x5);\ + ymm3 = _mm256_addsub_pd(ymm3, ymm8);\ + ymm5 = _mm256_addsub_pd(ymm5, ymm9);\ + ymm7 = _mm256_addsub_pd(ymm7, ymm10);\ +} + +/** + * Multiplies Alpha with 2 elements of 3 columns + * ymm0 holds 2 elements of columns, once computation + * is done, it holds 2 elements of next columns after + * saving computed result into some other register. + * ymm3 ymm5 ymm7. 
+ */ +#define BLIS_PRE_ZTRSM_SMALL_3x2(AlphaVal,b11,cs_b) {\ + ymm16 = _mm256_broadcast_pd(( __m128d const*)(&AlphaVal));\ + \ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));\ + ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ + \ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm0, ymm16);\ + ymm14 = _mm256_mul_pd(ymm0, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm3 = _mm256_sub_pd(ymm15,ymm3);\ + \ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *1));\ +\ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm0, ymm16);\ + ymm14 = _mm256_mul_pd(ymm0, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm5 = _mm256_sub_pd(ymm15,ymm5);\ + \ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *2));\ + \ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm0, ymm16);\ + ymm14 = _mm256_mul_pd(ymm0, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm7 = _mm256_sub_pd(ymm15,ymm7);\ + \ +} + +/** + * Performs GEMM + * ymm0 and ymm1 together holds 4 elements of column. + */ +#define BLIS_ZTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) {\ + double *tptr = (double *)a01;\ + if(conjtransa) {\ + ymm18 = _mm256_set_pd(-1.0, -1.0, -1.0, -1.0);\ + for(k = 0; k< k_iter; k++) \ + { \ + ymm0 = _mm256_loadu_pd((double const *)(b10)); \ + ymm1 = _mm256_loadu_pd((double const *)(b10 + 2)); \ + \ + _mm_prefetch((char*)( b10 + 4*cs_b), _MM_HINT_T0); \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0)); \ + ymm12 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1)); \ + ymm12 = _mm256_mul_pd(ymm12, ymm18);\ + \ + ymm3 = _mm256_fmadd_pd(ymm0, ymm2, ymm3);\ + ymm4 = _mm256_fmadd_pd(ymm1, ymm2, ymm4);\ + ymm8 = _mm256_fmadd_pd(ymm0, ymm12, ymm8);\ + ymm9 = _mm256_fmadd_pd(ymm1, ymm12, ymm9);\ + \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1)); \ + ymm12 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1 + 1)); \ + ymm12 = _mm256_mul_pd(ymm12, ymm18);\ + \ + ymm5 = _mm256_fmadd_pd(ymm0, ymm2, ymm5);\ + ymm6 = _mm256_fmadd_pd(ymm1, ymm2, ymm6);\ + ymm10 = _mm256_fmadd_pd(ymm0, ymm12, ymm10);\ + ymm11 = _mm256_fmadd_pd(ymm1, ymm12, ymm11);\ + \ + tptr += 2; \ + b10 += cs_b; \ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) \ + { \ + ymm0 = _mm256_loadu_pd((double const *)(b10)); \ + ymm1 = _mm256_loadu_pd((double const *)(b10 + 2)); \ + \ + _mm_prefetch((char*)( b10 + 4*cs_b), _MM_HINT_T0); \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0)); \ + ymm12 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1)); \ + \ + ymm3 = _mm256_fmadd_pd(ymm0, ymm2, ymm3);\ + ymm4 = _mm256_fmadd_pd(ymm1, ymm2, ymm4);\ + ymm8 = _mm256_fmadd_pd(ymm0, ymm12, ymm8);\ + ymm9 = _mm256_fmadd_pd(ymm1, ymm12, ymm9);\ + \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1)); \ + ymm12 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1 + 1)); \ + \ + ymm5 = _mm256_fmadd_pd(ymm0, ymm2, ymm5);\ + ymm6 = _mm256_fmadd_pd(ymm1, ymm2, ymm6);\ + ymm10 = _mm256_fmadd_pd(ymm0, ymm12, ymm10);\ + ymm11 = _mm256_fmadd_pd(ymm1, ymm12, ymm11);\ + \ + tptr += 2; \ + b10 += cs_b; \ + }\ + }\ + ymm8 = _mm256_permute_pd(ymm8, 0x5);\ + ymm9 = _mm256_permute_pd(ymm9, 0x5);\ + ymm10 = _mm256_permute_pd(ymm10, 0x5);\ + ymm11 = _mm256_permute_pd(ymm11, 0x5);\ + ymm3 = _mm256_addsub_pd(ymm3, ymm8);\ + ymm4 = _mm256_addsub_pd(ymm4, ymm9);\ + ymm5 = 
_mm256_addsub_pd(ymm5, ymm10);\ + ymm6 = _mm256_addsub_pd(ymm6, ymm11);\ +} + +/** + * Performs GEMM operation + * ymm0 holds 2 elements of a column. + */ +#define BLIS_ZTRSM_SMALL_GEMM_2nx2m(a01,b10,cs_b,p_lda,k_iter){\ + double *tptr = (double *)a01;\ + if(conjtransa) {\ + ymm18 = _mm256_set_pd(-1.0, -1.0, -1.0, -1.0);\ + for(k = 0; k< k_iter; k++) \ + { \ + ymm0 = _mm256_loadu_pd((double const *)(b10)); \ + \ + _mm_prefetch((char*)( b10 + 2*cs_b), _MM_HINT_T0); \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0)); \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1)); \ + ymm2 = _mm256_mul_pd(ymm2, ymm18);\ + \ + ymm3 = _mm256_fmadd_pd(ymm0, ymm1, ymm3);\ + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4);\ + \ + \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1)); \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1 + 1)); \ + ymm2 = _mm256_mul_pd(ymm2, ymm18);\ + \ + ymm5 = _mm256_fmadd_pd(ymm0, ymm1, ymm5);\ + ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6);\ + \ + tptr += 2; \ + b10 += cs_b; \ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) \ + { \ + ymm0 = _mm256_loadu_pd((double const *)(b10)); \ + \ + _mm_prefetch((char*)( b10 + 2*cs_b), _MM_HINT_T0); \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0)); \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1)); \ + \ + ymm3 = _mm256_fmadd_pd(ymm0, ymm1, ymm3);\ + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4);\ + \ + \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1)); \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1 + 1)); \ + \ + ymm5 = _mm256_fmadd_pd(ymm0, ymm1, ymm5);\ + ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6);\ + \ + tptr += 2; \ + b10 += cs_b; \ + }\ + }\ + ymm4 = _mm256_permute_pd(ymm4, 0x5);\ + ymm6 = _mm256_permute_pd(ymm6, 0x5);\ + ymm3 = _mm256_addsub_pd(ymm3, ymm4);\ + ymm5 = _mm256_addsub_pd(ymm5, ymm6);\ +} + +/** + * Multiplies Alpha with 2 elements of a column. + * ymm0 holds the 2 element of a column. + */ +#define BLIS_PRE_ZTRSM_SMALL_1x1(AlphaVal,b11,cs_b){\ + ymm16 = _mm256_broadcast_pd(( __m128d const*)(&AlphaVal));\ + \ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 0));\ + ymm0 = _mm256_insertf128_pd(ymm1, xmm5, 0);\ + ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ + \ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm0, ymm16);\ + ymm14 = _mm256_mul_pd(ymm0, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm3 = _mm256_sub_pd(ymm15,ymm3);\ +} + +/** + * Multiplies Alpha with 2 elements of a column. + * ymm0 holds the 2 element of a column. + */ +#define BLIS_PRE_ZTRSM_SMALL_1x2(AlphaVal,b11,cs_b){\ + ymm16 = _mm256_broadcast_pd(( __m128d const*)(&AlphaVal));\ + \ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));\ + ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ + \ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm0, ymm16);\ + ymm14 = _mm256_mul_pd(ymm0, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm3 = _mm256_sub_pd(ymm15,ymm3);\ +} + +/** + * Multiplies Alpha with 2 elements of 2 columns. 
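+ * As in the other BLIS_PRE_ZTRSM_SMALL_* helpers, the value left in the
+ * result registers is the right-hand side of the solve,
+ *   B11 := alpha*B11 - (GEMM accumulator),
+ * computed before the per-diagonal TRSM update.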
+ * ymm0 holds 2 elements of a columns respectively, + * once computation is done, gets stored in registers + * ymm3, ymm5 + */ +#define BLIS_PRE_ZTRSM_SMALL_2x2(AlphaVal,b11,cs_b){\ + ymm16 = _mm256_broadcast_pd(( __m128d const*)(&AlphaVal));\ + \ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));\ + ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ + \ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm0, ymm16);\ + ymm14 = _mm256_mul_pd(ymm0, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm3 = _mm256_sub_pd(ymm15,ymm3);\ + \ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *1));\ +\ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm0, ymm16);\ + ymm14 = _mm256_mul_pd(ymm0, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm5 = _mm256_sub_pd(ymm15,ymm5);\ +} + +/** + * Performs GEMM operation + * 3 elements of a columns get held by ymm0(2 element) + * and xmm5 (1 element). + */ +#define BLIS_ZTRSM_SMALL_GEMM_1nx3m(a01,b10,cs_b,p_lda,k_iter) {\ + double *tptr = (double *)a01;\ + if(conjtransa) {\ + ymm18 = _mm256_set_pd(-1.0, -1.0, -1.0, -1.0);\ + for(k = 0; k< k_iter; k++) \ + {\ + ymm0 = _mm256_loadu_pd((double const *)(b10)); \ + /*ymm1 = _mm256_loadu_pd((double const *)(b10 + 2));*/\ + xmm5 = _mm_loadu_pd((double const *)(b10 + 2));\ + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0);\ + \ + _mm_prefetch((char*)( b10 + 4*cs_b), _MM_HINT_T0); \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0)); \ + ymm5 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1)); \ + ymm5 = _mm256_mul_pd(ymm5, ymm18);\ + \ + ymm3 = _mm256_fmadd_pd(ymm0, ymm2, ymm3);\ + ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6);\ + ymm4 = _mm256_fmadd_pd(ymm1, ymm5, ymm4);\ + ymm7 = _mm256_fmadd_pd(ymm1, ymm5, ymm7);\ + \ + tptr += 2;\ + b10 += cs_b;\ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) \ + {\ + ymm0 = _mm256_loadu_pd((double const *)(b10)); \ + /*ymm1 = _mm256_loadu_pd((double const *)(b10 + 2));*/\ + xmm5 = _mm_loadu_pd((double const *)(b10 + 2));\ + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0);\ + \ + _mm_prefetch((char*)( b10 + 4*cs_b), _MM_HINT_T0); \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0)); \ + ymm5 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1)); \ + \ + ymm3 = _mm256_fmadd_pd(ymm0, ymm2, ymm3);\ + ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6);\ + ymm4 = _mm256_fmadd_pd(ymm1, ymm5, ymm4);\ + ymm7 = _mm256_fmadd_pd(ymm1, ymm5, ymm7);\ + \ + tptr += 2;\ + b10 += cs_b;\ + }\ + }\ + ymm6 = _mm256_permute_pd(ymm6, 0x5);\ + ymm7 = _mm256_permute_pd(ymm7, 0x5);\ + ymm3 = _mm256_addsub_pd(ymm3, ymm6);\ + ymm4 = _mm256_addsub_pd(ymm5, ymm7);\ +} + + +/** + * Performs GEMM operation. + * 1 elements of a column are kept in ymm0. 
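+ * In the conjtransa path the broadcast imaginary parts of A are negated
+ * (multiplied by ymm18 = {-1,-1,-1,-1}), so the accumulation effectively
+ * uses conj(A) without needing a separate kernel.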
+ */ +#define BLIS_ZTRSM_SMALL_GEMM_1nx1m(a01,b10,cs_b,p_lda,k_iter) {\ + double *tptr = (double *)a01;\ + if(conjtransa) {\ + ymm18 = _mm256_set_pd(-1.0, -1.0, -1.0, -1.0);\ + for(k = 0; k< k_iter; k++) \ + { \ + xmm5 = _mm_loadu_pd((double const *)(b10));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ + \ + _mm_prefetch((char*)( b10 + 2*cs_b), _MM_HINT_T0); \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0)); \ + ymm5 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1)); \ + ymm5 = _mm256_mul_pd(ymm5, ymm18);\ + \ + ymm3 = _mm256_fmadd_pd(ymm0, ymm2, ymm3);\ + ymm4 = _mm256_fmadd_pd(ymm0, ymm5, ymm4);\ + \ + tptr += 2; \ + b10 += cs_b; \ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) \ + { \ + xmm5 = _mm_loadu_pd((double const *)(b10));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ + \ + _mm_prefetch((char*)( b10 + 2*cs_b), _MM_HINT_T0); \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0)); \ + ymm5 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1)); \ + \ + ymm3 = _mm256_fmadd_pd(ymm0, ymm2, ymm3);\ + ymm4 = _mm256_fmadd_pd(ymm0, ymm5, ymm4);\ + \ + tptr += 2; \ + b10 += cs_b; \ + }\ + }\ + ymm4 = _mm256_permute_pd(ymm4, 0x5);\ + ymm3 = _mm256_addsub_pd(ymm3, ymm4);\ +} + + +/** + * Performs GEMM operation. + * 2 elements of a column are kept in ymm0. + */ +#define BLIS_ZTRSM_SMALL_GEMM_1nx2m(a01,b10,cs_b,p_lda,k_iter) {\ + double *tptr = (double *)a01;\ + if(conjtransa) {\ + ymm18 = _mm256_set_pd(-1.0, -1.0, -1.0, -1.0);\ + for(k = 0; k< k_iter; k++) \ + { \ + ymm0 = _mm256_loadu_pd((double const *)(b10)); \ + \ + _mm_prefetch((char*)( b10 + 2*cs_b), _MM_HINT_T0); \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0)); \ + ymm5 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1)); \ + ymm5 = _mm256_mul_pd(ymm5, ymm18);\ + \ + ymm3 = _mm256_fmadd_pd(ymm0, ymm2, ymm3);\ + ymm4 = _mm256_fmadd_pd(ymm0, ymm5, ymm4);\ + \ + tptr += 2; \ + b10 += cs_b; \ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) \ + { \ + ymm0 = _mm256_loadu_pd((double const *)(b10)); \ + \ + _mm_prefetch((char*)( b10 + 2*cs_b), _MM_HINT_T0); \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0)); \ + ymm5 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1)); \ + \ + ymm3 = _mm256_fmadd_pd(ymm0, ymm2, ymm3);\ + ymm4 = _mm256_fmadd_pd(ymm0, ymm5, ymm4);\ + \ + tptr += 2; \ + b10 += cs_b; \ + }\ + }\ + ymm4 = _mm256_permute_pd(ymm4, 0x5);\ + ymm3 = _mm256_addsub_pd(ymm3, ymm4);\ +} + +/** + * Performs GEMM operation + * 4 elements of columns are kept in ymm0 and ymm1. 
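+ * Each dcomplex product is accumulated in two halves: one register set
+ * collects b10 * real(a01) and another collects b10 * imag(a01). After
+ * the k-loop the imaginary accumulators are lane-swapped with
+ * _mm256_permute_pd(..., 0x5) and folded in with _mm256_addsub_pd, which
+ * yields (br*ar - bi*ai, bi*ar + br*ai) per element, i.e. the complex
+ * product.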
+ */ +#define BLIS_ZTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) {\ + double *tptr = (double *)a01;\ + if(conjtransa) {\ + ymm18 = _mm256_set_pd(-1.0, -1.0, -1.0, -1.0);\ + for(k = 0; k< k_iter; k++) \ + { \ + ymm0 = _mm256_loadu_pd((double const *)(b10)); \ + ymm1 = _mm256_loadu_pd((double const *)(b10 + 2)); \ + \ + _mm_prefetch((char*)( b10 + 4*cs_b), _MM_HINT_T0); \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0));\ + ymm9 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1));\ + ymm9 = _mm256_mul_pd(ymm9, ymm18);\ + \ + ymm3 = _mm256_fmadd_pd(ymm0, ymm2, ymm3);\ + ymm4 = _mm256_fmadd_pd(ymm1, ymm2, ymm4);\ + ymm10 = _mm256_fmadd_pd(ymm0, ymm9, ymm10);\ + ymm11 = _mm256_fmadd_pd(ymm1, ymm9, ymm11);\ + \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1)); \ + ymm9 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1 + 1)); \ + ymm9 = _mm256_mul_pd(ymm9, ymm18);\ + \ + ymm5 = _mm256_fmadd_pd(ymm0, ymm2, ymm5);\ + ymm6 = _mm256_fmadd_pd(ymm1, ymm2, ymm6);\ + ymm12 = _mm256_fmadd_pd(ymm0, ymm9, ymm12);\ + ymm13 = _mm256_fmadd_pd(ymm1, ymm9, ymm13);\ + \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 2)); \ + ymm9 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 2 + 1)); \ + ymm9 = _mm256_mul_pd(ymm9, ymm18);\ + \ + ymm7 = _mm256_fmadd_pd(ymm0, ymm2, ymm7);\ + ymm8 = _mm256_fmadd_pd(ymm1, ymm2, ymm8);\ + ymm14 = _mm256_fmadd_pd(ymm0, ymm9, ymm14);\ + ymm15 = _mm256_fmadd_pd(ymm1, ymm9, ymm15);\ + \ + tptr += 2; \ + b10 += cs_b; \ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) \ + { \ + ymm0 = _mm256_loadu_pd((double const *)(b10)); \ + ymm1 = _mm256_loadu_pd((double const *)(b10 + 2)); \ + \ + _mm_prefetch((char*)( b10 + 4*cs_b), _MM_HINT_T0); \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0));\ + ymm9 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1));\ + \ + ymm3 = _mm256_fmadd_pd(ymm0, ymm2, ymm3);\ + ymm4 = _mm256_fmadd_pd(ymm1, ymm2, ymm4);\ + ymm10 = _mm256_fmadd_pd(ymm0, ymm9, ymm10);\ + ymm11 = _mm256_fmadd_pd(ymm1, ymm9, ymm11);\ + \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1)); \ + ymm9 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1 + 1)); \ + \ + ymm5 = _mm256_fmadd_pd(ymm0, ymm2, ymm5);\ + ymm6 = _mm256_fmadd_pd(ymm1, ymm2, ymm6);\ + ymm12 = _mm256_fmadd_pd(ymm0, ymm9, ymm12);\ + ymm13 = _mm256_fmadd_pd(ymm1, ymm9, ymm13);\ + \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 2)); \ + ymm9 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 2 + 1)); \ + \ + ymm7 = _mm256_fmadd_pd(ymm0, ymm2, ymm7);\ + ymm8 = _mm256_fmadd_pd(ymm1, ymm2, ymm8);\ + ymm14 = _mm256_fmadd_pd(ymm0, ymm9, ymm14);\ + ymm15 = _mm256_fmadd_pd(ymm1, ymm9, ymm15);\ + \ + tptr += 2; \ + b10 += cs_b; \ + }\ + }\ + ymm10 = _mm256_permute_pd(ymm10, 0x5);\ + ymm11 = _mm256_permute_pd(ymm11, 0x5);\ + ymm12 = _mm256_permute_pd(ymm12, 0x5);\ + ymm13 = _mm256_permute_pd(ymm13, 0x5);\ + ymm14 = _mm256_permute_pd(ymm14, 0x5);\ + ymm15 = _mm256_permute_pd(ymm15, 0x5);\ +\ + ymm3 = _mm256_addsub_pd(ymm3, ymm10);\ + ymm4 = _mm256_addsub_pd(ymm4, ymm11);\ + ymm5 = _mm256_addsub_pd(ymm5, ymm12);\ + ymm6 = _mm256_addsub_pd(ymm6, ymm13);\ + ymm7 = _mm256_addsub_pd(ymm7, ymm14);\ + ymm8 = _mm256_addsub_pd(ymm8, ymm15);\ +} + +/** + * Multiplies Alpha with 4 element of 2 columns. + * ymm0 and ymm1 holds 4 elements of a column. 
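+ * ("4 elements" here means four dcomplex values per column: ymm0 carries
+ * elements 0-1 and ymm1 carries elements 2-3, two doubles each.)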
+ */ +#define BLIS_PRE_ZTRSM_SMALL_2x4(AlphaVal,b11,cs_b) {\ + ymm16 = _mm256_broadcast_pd(( __m128d const*)(&AlphaVal));\ + \ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));\ + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 2));\ + ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ + \ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm0, ymm16);\ + ymm14 = _mm256_mul_pd(ymm0, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm3 = _mm256_sub_pd(ymm15,ymm3);\ + \ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm1, ymm16);\ + ymm14 = _mm256_mul_pd(ymm1, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm4 = _mm256_sub_pd(ymm15,ymm4);\ + \ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *1));\ + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 2));\ +\ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm0, ymm16);\ + ymm14 = _mm256_mul_pd(ymm0, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm5 = _mm256_sub_pd(ymm15,ymm5);\ + \ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm1, ymm16);\ + ymm14 = _mm256_mul_pd(ymm1, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm6 = _mm256_sub_pd(ymm15,ymm6);\ +} + +/** + * Multiplies Alpha with 4 element of 3 columns. + * ymm0 and ymm1 holds 4 elements of a column. + */ +#define BLIS_PRE_ZTRSM_SMALL_3x4(AlphaVal,b11,cs_b) {\ + ymm16 = _mm256_broadcast_pd(( __m128d const*)(&AlphaVal));\ + \ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));\ + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 2));\ + ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ + \ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm0, ymm16);\ + ymm14 = _mm256_mul_pd(ymm0, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm3 = _mm256_sub_pd(ymm15,ymm3);\ + \ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm1, ymm16);\ + ymm14 = _mm256_mul_pd(ymm1, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm4 = _mm256_sub_pd(ymm15,ymm4);\ + \ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *1));\ + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 2));\ +\ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm0, ymm16);\ + ymm14 = _mm256_mul_pd(ymm0, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm5 = _mm256_sub_pd(ymm15,ymm5);\ + \ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm1, ymm16);\ + ymm14 = _mm256_mul_pd(ymm1, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm6 = _mm256_sub_pd(ymm15,ymm6);\ + \ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *2));\ + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 2));\ + \ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm0, ymm16);\ + ymm14 = _mm256_mul_pd(ymm0, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm7 = _mm256_sub_pd(ymm15,ymm7);\ + \ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm1, ymm16);\ + ymm14 = _mm256_mul_pd(ymm1, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm8 = _mm256_sub_pd(ymm15,ymm8);\ + \ +} + +/* + * Pack a block of 4xk or 3xk 
from input buffer into packed buffer + * directly or after transpose based on input params + */ + +/* + * Load b11 of size 3x4 and multiply with alpha + * Add the GEMM output and perform inregister transose of b11 + * to peform ZTRSM operation for left cases. + */ +#define BLIS_ZTRSM_SMALL_NREG_TRANSPOSE_3x4(b11,cs_b,AlphaVal) {\ + ymm16 = _mm256_broadcast_pd(( __m128d const *)(&AlphaVal));\ +\ + ymm0 = _mm256_loadu_pd((double const *)(b11));\ + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1));\ + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2));\ + ymm3 = _mm256_broadcast_pd((__m128d const *)&ones);\ + /*in register transpose + * ymm0,ymm1,ymm2 holds + * two dcomplex elements of b11 cols*/\ + ymm14 = _mm256_shuffle_pd(ymm16, ymm16, 5);\ + ymm5 = _mm256_shuffle_pd(ymm0, ymm0, 15);\ + ymm6 = _mm256_shuffle_pd(ymm0, ymm0,0);\ + ymm7 = _mm256_mul_pd(ymm5, ymm14);\ + ymm15 = _mm256_fmaddsub_pd(ymm6, ymm16, ymm7);\ + ymm0 = _mm256_sub_pd(ymm15, ymm8);\ +\ + ymm5 = _mm256_shuffle_pd(ymm1, ymm1, 15);\ + ymm6 = _mm256_shuffle_pd(ymm1, ymm1,0);\ + ymm7 = _mm256_mul_pd(ymm5, ymm14);\ + ymm15 = _mm256_fmaddsub_pd(ymm6, ymm16, ymm7);\ + ymm1 = _mm256_sub_pd(ymm15, ymm9);\ +\ + ymm5 = _mm256_shuffle_pd(ymm2, ymm2, 15);\ + ymm6 = _mm256_shuffle_pd(ymm2, ymm2,0);\ + ymm7 = _mm256_mul_pd(ymm5, ymm14);\ + ymm15 = _mm256_fmaddsub_pd(ymm6, ymm16, ymm7);\ + ymm2 = _mm256_sub_pd(ymm15, ymm10);\ +\ + /*in register transpose of computed b11 col*/\ + ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); \ + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31);\ + ymm4 = _mm256_permute2f128_pd(ymm2,ymm3,0x20); \ + ymm5 = _mm256_permute2f128_pd(ymm2,ymm3,0x31); \ +\ + /*in register transpose + * ymm0,ymm1,ymm2 holds + * next two dcomplex elements of b11 cols*/\ + ymm0 = _mm256_loadu_pd((double const *)(b11 + 2));\ + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 2));\ + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 2));\ +\ + ymm17 = _mm256_shuffle_pd(ymm0, ymm0, 15);\ + ymm18 = _mm256_shuffle_pd(ymm0, ymm0, 0);\ + ymm19 = _mm256_mul_pd(ymm17, ymm14);\ + ymm15 = _mm256_fmaddsub_pd(ymm18, ymm16, ymm19);\ + ymm0 = _mm256_sub_pd(ymm15, ymm11);\ +\ + ymm17 = _mm256_shuffle_pd(ymm1, ymm1, 15);\ + ymm18 = _mm256_shuffle_pd(ymm1, ymm1, 0);\ + ymm19 = _mm256_mul_pd(ymm17, ymm14);\ + ymm15 = _mm256_fmaddsub_pd(ymm18, ymm16, ymm19);\ + ymm1 = _mm256_sub_pd(ymm15, ymm12);\ +\ + ymm17 = _mm256_shuffle_pd(ymm2, ymm2, 15);\ + ymm18 = _mm256_shuffle_pd(ymm2, ymm2, 0);\ + ymm19 = _mm256_mul_pd(ymm17, ymm14);\ + ymm15 = _mm256_fmaddsub_pd(ymm18, ymm16, ymm19);\ + ymm2 = _mm256_sub_pd(ymm15, ymm13);\ +\ + /*in register transpose of computed b11 col*/\ + ymm10 = _mm256_permute2f128_pd(ymm0,ymm1,0x20);\ + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31);\ + ymm6 = _mm256_permute2f128_pd(ymm2,ymm3,0x20);\ + ymm7 = _mm256_permute2f128_pd(ymm2,ymm3,0x31);\ +} + +/** + * Performs GEMM operation. 
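+ * (Left-side kernel: here the a10 columns are loaded as vectors and the
+ * b01 element is broadcast, the reverse of the *nx*m kernels above; the
+ * real/imaginary accumulators are folded with the same permute + addsub
+ * step.)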
+ * 4 elements of a column are kept inymm0 and ymm1 + */ +#define BLIS_ZTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b,p_lda,k_iter) {\ + double *tptr = (double *)b01;\ + if(conjtransa) {\ + ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_pd((double const *)(a10));\ + ymm1 = _mm256_loadu_pd((double const *)(a10 + 2));\ + ymm0 = _mm256_mul_pd(ymm0, ymm18);\ + ymm1 = _mm256_mul_pd(ymm1, ymm18);\ + \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0));\ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0 + 1)); \ + \ + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8);\ + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12);\ + \ + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);\ + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);\ + tptr += 2; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_pd((double const *)(a10));\ + ymm1 = _mm256_loadu_pd((double const *)(a10 + 2));\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0));\ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0 + 1)); \ + \ + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8);\ + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12);\ + \ + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);\ + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);\ + tptr += 2; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + }\ + }\ + ymm4 = _mm256_permute_pd(ymm4, 0x5);\ + ymm5 = _mm256_permute_pd(ymm5, 0x5);\ + ymm8 = _mm256_addsub_pd(ymm8, ymm4);\ + ymm12 = _mm256_addsub_pd(ymm12, ymm5);\ +} + +/** + * Performs the GEMM operation. + * 2 elements of a column are kept in ymm0. 
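+ * Note that in these left-side kernels the conjtransa path conjugates the
+ * loaded a10 vectors directly (ymm18 = {1,-1,1,-1} negates the imaginary
+ * lanes of ymm0/ymm1) rather than the broadcast value.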
+ */ +#define BLIS_ZTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b,p_lda,k_iter) {\ + double *tptr = (double * )b01;\ + if(conjtransa) {\ + ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_pd((double const *)(a10));\ + ymm1 = _mm256_loadu_pd((double const *)(a10 + 2));\ + ymm0 = _mm256_mul_pd(ymm0, ymm18);\ + ymm1 = _mm256_mul_pd(ymm1, ymm18);\ + \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0));\ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0 + 1)); \ + \ + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8);\ + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12);\ + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);\ + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1)); \ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1 + 1)); \ + \ + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9);\ + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13);\ + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6);\ + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);\ + tptr += 2; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_pd((double const *)(a10));\ + ymm1 = _mm256_loadu_pd((double const *)(a10 + 2));\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0));\ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0 + 1)); \ + \ + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8);\ + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12);\ + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);\ + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1)); \ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1 + 1)); \ + \ + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9);\ + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13);\ + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6);\ + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);\ + tptr += 2; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + }\ + }\ + ymm4 = _mm256_permute_pd(ymm4, 0x5);\ + ymm5 = _mm256_permute_pd(ymm5, 0x5);\ + ymm6 = _mm256_permute_pd(ymm6, 0x5);\ + ymm7 = _mm256_permute_pd(ymm7, 0x5);\ +\ + ymm8 = _mm256_addsub_pd(ymm8, ymm4);\ + ymm12 = _mm256_addsub_pd(ymm12, ymm5);\ + ymm9 = _mm256_addsub_pd(ymm9, ymm6);\ + ymm13 = _mm256_addsub_pd(ymm13, ymm7);\ +} + +/*GEMM block used in ztrsm small left cases*/ +#define BLIS_ZTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) {\ + double *tptr = (double *)b01;\ + if(conjtransa) {\ + ymm16 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ + for(k = 0; k< k_iter; k++) \ + { \ + ymm0 = _mm256_loadu_pd((double const *)(a10)); \ + ymm1 = _mm256_loadu_pd((double const *)(a10 + 2)); \ + ymm0 = _mm256_mul_pd(ymm0, ymm16);\ + ymm1 = _mm256_mul_pd(ymm1, ymm16);\ + \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr)); \ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + 1)); \ + \ + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8);\ + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11);\ + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);\ + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);\ + \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 1 * 2)); \ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 1 * 2 + 1)); \ + \ + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9);\ + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12);\ + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6);\ + 
ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);\ + \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b *2 * 2)); \ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 2 + 1)); \ + \ + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10);\ + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13);\ + \ + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14);\ + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15);\ + \ + tptr += 2; \ + a10 += p_lda; \ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) \ + { \ + ymm0 = _mm256_loadu_pd((double const *)(a10)); \ + ymm1 = _mm256_loadu_pd((double const *)(a10 + 2)); \ + \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr)); \ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + 1)); \ + \ + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8);\ + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11);\ + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);\ + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);\ + \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 1 * 2)); \ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 1 * 2 + 1)); \ + \ + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9);\ + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12);\ + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6);\ + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);\ + \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b *2 * 2)); \ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 2 + 1)); \ + \ + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10);\ + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13);\ + \ + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14);\ + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15);\ + \ + tptr += 2; \ + a10 += p_lda; \ + }\ + }\ + ymm4 = _mm256_permute_pd(ymm4, 0x5);\ + ymm5 = _mm256_permute_pd(ymm5, 0x5);\ + ymm6 = _mm256_permute_pd(ymm6, 0x5);\ + ymm7 = _mm256_permute_pd(ymm7, 0x5);\ + ymm14 = _mm256_permute_pd(ymm14, 0x5);\ + ymm15 = _mm256_permute_pd(ymm15, 0x5);\ + \ + ymm8 = _mm256_addsub_pd(ymm8, ymm4);\ + ymm11 = _mm256_addsub_pd(ymm11, ymm5);\ + ymm9 = _mm256_addsub_pd(ymm9, ymm6);\ + ymm12 = _mm256_addsub_pd(ymm12, ymm7);\ + ymm10 = _mm256_addsub_pd(ymm10, ymm14);\ + ymm13 = _mm256_addsub_pd(ymm13, ymm15);\ +} + + +#define BLIS_ZTRSM_SMALL_NREG_TRANSPOSE_4x3_AND_STORE(b11,cs_b){\ + ymm0 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20);\ + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31);\ + ymm2 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20);\ + _mm256_storeu_pd((double *)(b11), ymm0);\ + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1);\ + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2);\ +\ + ymm0 = _mm256_permute2f128_pd(ymm10, ymm11, 0x20);\ + ymm1 = _mm256_permute2f128_pd(ymm10, ymm11, 0x31);\ + ymm2 = _mm256_permute2f128_pd(ymm6, ymm7, 0x20);\ + _mm256_storeu_pd((double *)(b11 + 2), ymm0);\ + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 2), ymm1);\ + _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 2), ymm2);\ +} + +/** + * Performs dcomplex division of vec1 and vec2 with ymm1. + * vec1 and vec2 gets divided by ymm1 which holds + * diagonal element from buffer. + * Function gets called while performing TRSM. 
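+ *
+ * For reference, the sequence below implements the identity
+ *   x / d = x * conj(d) / |d|^2, i.e.
+ *   (a + b*i) / (c + d*i) = ((a*c + b*d) + (b*c - a*d)*i) / (c*c + d*d);
+ * the numerator comes from the permute/mul/hsub steps and the denominator
+ * from the mul/permute4x64/add steps before the final divide.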
+ */ +#define BLIS_ZTRSM_TWO_DIV(vec1, vec2) {\ + if(!is_unitdiag) {\ + if(conjtransa){\ + ymm1 = _mm256_mul_pd(ymm1, ymm0);\ + }\ + ymm12 = _mm256_mul_pd(ymm1, ymm0);\ + /*perform decomplex multiplication*/\ + /* Switch the real and imaginary elements of vec2 */\ + ymm14 = _mm256_permute_pd(ymm12, 0x5);\ + /* Negate the imaginary elements of vec2 */\ + ymm14 = _mm256_mul_pd(ymm14, ymm0);\ + /* Multiply vec1 and vec2 */ \ + ymm13 = _mm256_mul_pd(vec1, ymm12); /*vec3*/\ + /* Multiply vec1 and the modified vec2 */\ + ymm14 = _mm256_mul_pd(vec1, ymm14); /*vec4*/\ + /* Horizontally subtract the elements in vec3 and vec4 */\ + vec1 = _mm256_hsub_pd(ymm13, ymm14);\ + \ + ymm14 = _mm256_permute_pd(ymm12, 0x5);\ + /* Negate the imaginary elements of vec2 */\ + ymm14 = _mm256_mul_pd(ymm14, ymm0);\ + ymm13 = _mm256_mul_pd(vec2, ymm12);\ + ymm14 = _mm256_mul_pd(vec2, ymm14);\ + vec2 = _mm256_hsub_pd(ymm13, ymm14);\ + /*dcomplex multiplication is done*/\ + /*Swapping real & imaginary component position for addition with respective + * components*/\ + ymm12 = _mm256_mul_pd(ymm1, ymm1);\ + ymm13 = _mm256_permute4x64_pd(ymm12, 0xb1);\ + ymm14 = _mm256_add_pd(ymm12, ymm13);\ + \ + /*Finally dividing numerator by denominator*/\ + vec1 = _mm256_div_pd(vec1, ymm14);\ + vec2 = _mm256_div_pd(vec2, ymm14);\ + }\ +} + +/** + * Performs dcomplex division of vec1 with ymm1. + * ymm1 holds diagonal element from buffer. + * Function gets called while performing TRSM. + */ +#define BLIS_ZTRSM_DIV(vec1) {\ + if(!is_unitdiag){\ + if(conjtransa){\ + ymm1 = _mm256_mul_pd(ymm1, ymm0);\ + }\ + ymm12 = _mm256_mul_pd(ymm1, ymm0); /*vec2 and ymm8 is vec1*/\ + ymm14 = _mm256_permute_pd(ymm12, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm0);\ + ymm13 = _mm256_mul_pd(vec1, ymm12); /*vec3*/\ + ymm14 = _mm256_mul_pd(vec1, ymm14); /*vec4*/\ + vec1 = _mm256_hsub_pd(ymm13, ymm14);\ + \ + ymm12 = _mm256_mul_pd(ymm1, ymm1);\ + ymm13 = _mm256_permute4x64_pd(ymm12, 0xb1);\ + ymm14 = _mm256_add_pd(ymm12, ymm13);\ + \ + /*Finally dividing numerator by denominator*/\ + vec1 = _mm256_div_pd(vec1, ymm14);\ + }\ +} + +/** + * Performs dcomplex multiplication of vec1 with ymm1. + * ymm1 holds diagonal element from buffer. + * Function gets called while performing TRSM. 
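+ *
+ * This is the BLIS_ENABLE_TRSM_PREINVERSION counterpart of the division
+ * macros above: ztrsm_small_pack_diag_element() stores the reciprocal
+ * conj(d) / |d|^2 of each diagonal element, so the solve step here reduces
+ * to a complex multiply instead of a divide.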
+ */ +#define BLIS_ZTRSM_MUL(vec1) {\ + if(!is_unitdiag){\ + if(conjtransa){\ + ymm19 = _mm256_mul_pd(ymm1, ymm0);\ + }\ + else{\ + ymm19 = ymm1;\ + }\ + ymm14 = _mm256_permute_pd(ymm19, 0x5);\ + /* Negate the imaginary elements of vec2 */\ + ymm14 = _mm256_mul_pd(ymm14, ymm0);\ + /* Multiply vec1 and vec2 */\ + ymm13 = _mm256_mul_pd(vec1, ymm19); /*vec3*/\ + /* Multiply vec1 and the modified vec2 */\ + ymm14 = _mm256_mul_pd(vec1, ymm14); /*vec4*/\ + /* Horizontally subtract the elements in vec3 and vec4 */\ + vec1 = _mm256_hsub_pd(ymm13, ymm14);\ + }\ +} + +BLIS_INLINE void bli_ztrsm_small_pack +( + char side, + dim_t size, + bool trans, + dcomplex *inbuf, + dim_t cs_a, + dcomplex *pbuff, + dim_t p_lda, + dim_t mr +) +{ + //scratch registers + __m256d ymm0, ymm1, ymm2; + __m256d ymm5, ymm6, ymm7; + __m256d ymm8, ymm9, ymm10, ymm11; + __m128d xmm0,xmm1,xmm2; + double zero = 0.0; + + if(side=='L'||side=='l') + { + /*Left case is 4xk*/ + if(trans) + { + /* + ------------- ------------- + | | | | | + | 2x4 | | | | + ------------- ==> | 4x2 | 4x2 | + | 2x4 | | | | + | | | | | + ------------- ------------- + */ + for(dim_t x = 0; x < size; x += mr) + { + ymm0 = _mm256_loadu_pd((double const *)(inbuf)); + ymm10 = _mm256_loadu_pd((double const *)(inbuf + 2)); + ymm1 = _mm256_loadu_pd((double const *)(inbuf + cs_a)); + ymm11 = _mm256_loadu_pd((double const *)(inbuf + 2 + cs_a)); + + ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + ymm8 = _mm256_permute2f128_pd(ymm10,ymm11,0x20); + ymm9 = _mm256_permute2f128_pd(ymm10,ymm11,0x31); + + _mm256_storeu_pd((double *)(pbuff), ymm6); + _mm256_storeu_pd((double *)(pbuff + p_lda), ymm7); + _mm256_storeu_pd((double *)(pbuff + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(pbuff + p_lda*3), ymm9); + + ymm0 = _mm256_loadu_pd((double const *)(inbuf + 2 * cs_a)); + ymm10 = _mm256_loadu_pd((double const *)(inbuf + 2 * cs_a + 2)); + ymm1 = _mm256_loadu_pd((double const *)(inbuf + 3 * cs_a)); + ymm11 = _mm256_loadu_pd((double const *)(inbuf + 3 * cs_a + 2)); + + ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + ymm8 = _mm256_permute2f128_pd(ymm10,ymm11,0x20); + ymm9 = _mm256_permute2f128_pd(ymm10,ymm11,0x31); + + _mm256_storeu_pd((double *)(pbuff + 2), ymm6); + _mm256_storeu_pd((double *)(pbuff + p_lda + 2), ymm7); + _mm256_storeu_pd((double *)(pbuff + p_lda*2 + 2), ymm8); + _mm256_storeu_pd((double *)(pbuff + p_lda*3 + 2), ymm9); + + inbuf += mr; + pbuff += mr*mr; + } + }else + { + //Expected multiples of 4 + p_lda = 4; + for(dim_t x = 0; x < size; x++) + { + ymm0 = _mm256_loadu_pd((double const *)(inbuf)); + _mm256_storeu_pd((double *)(pbuff), ymm0); + ymm1 = _mm256_loadu_pd((double const *)(inbuf + 2)); + _mm256_storeu_pd((double *)(pbuff + 2), ymm1); + inbuf+=cs_a; + pbuff+=p_lda; + } + } + }else if(side=='R'||side=='r') + { + + if(trans) + { + for(dim_t x=0; x>1); i++) + { + ymm0 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 0 )); + _mm256_storeu_pd((double *)(pbuff + p_lda * 0), ymm0); + ymm1 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 1 )); + _mm256_storeu_pd((double *)(pbuff + p_lda * 1), ymm1); + ymm2 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 2)); + _mm256_storeu_pd((double *)(pbuff + p_lda * 2), ymm2); + inbuf += 2; + pbuff += 2; + } + if(size & 0x1) + { + xmm0 = _mm_loadu_pd((double const *)(inbuf + cs_a * 0)); + _mm_storeu_pd((double *)(pbuff + p_lda * 0 ), xmm0); + xmm1 = _mm_loadu_pd((double const *)(inbuf + cs_a * 1)); + 
_mm_storeu_pd((double *)(pbuff + p_lda * 1), xmm1); + xmm2 = _mm_loadu_pd((double const *)(inbuf + cs_a * 2)); + _mm_storeu_pd((double *)(pbuff + p_lda * 2), xmm2); + } + } + } + +} + + +BLIS_INLINE void ztrsm_small_pack_diag_element +( + bool is_unitdiag, + dcomplex *a11, + dim_t cs_a, + dcomplex *d11_pack, + dim_t size +) +{ + __m256d ymm1, ymm2, ymm3, ymm4, ymm5, ymm6, ymm7, ymm8; + bool is_four = (size == 4) ? 1 : 0; + dcomplex ones = {1.0, 1.0}; + ymm2 = ymm1 = _mm256_broadcast_pd((__m128d const *)&ones); + ymm7 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + if(!is_unitdiag) + { + //broadcast diagonal elements of A11 + ymm1 = _mm256_broadcast_pd((__m128d const *)a11); + ymm2 = _mm256_broadcast_pd((__m128d const *)a11+ cs_a +1); + /*Pick one element frome each column and create 3 element vector + and store it*/ + ymm1 = _mm256_permute2f128_pd(ymm1, ymm2, 0x20); + ymm2 = _mm256_broadcast_pd((__m128d const *)a11+ cs_a*2 + 2); + + if(is_four) + { + ymm3 = _mm256_broadcast_pd((__m128d const *)a11+ cs_a*2 + 2); + ymm2 = _mm256_broadcast_pd((__m128d const *)a11+ cs_a*3 + 3); + ymm2 = _mm256_permute2f128_pd(ymm3, ymm2, 0x20); + } + +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + /*Taking denomerator multiplication of real & imaginary components*/ + ymm4 = _mm256_mul_pd(ymm1, ymm1); + ymm5 = _mm256_mul_pd(ymm2,ymm2); + /*Swapping real & imaginary component position for addition with + * respective components*/ + ymm6 = _mm256_permute4x64_pd(ymm4, 0xb1); + ymm4 = _mm256_add_pd(ymm4, ymm6); + ymm8 = _mm256_permute4x64_pd(ymm5, 0xb1); + + ymm5 = _mm256_add_pd(ymm5, ymm8); + /*Negating imaginary component of numerator*/ + ymm1 = _mm256_mul_pd(ymm1, ymm7); + ymm2 = _mm256_mul_pd(ymm2, ymm7); + /*Dividing numerator by denominator*/ + ymm1 = _mm256_div_pd(ymm1, ymm4); + ymm2 = _mm256_div_pd(ymm2, ymm5); +#endif + + } + _mm256_store_pd((double *)d11_pack, ymm1); + if(is_four) + { + _mm256_store_pd((double *)(d11_pack + 2), ymm2); + } + else + { + _mm_store_pd((double *)(d11_pack + 2), + _mm256_extractf128_pd(ymm2,0)); + + } +} + +BLIS_INLINE err_t bli_ztrsm_small_AutXB_AlXB +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +) +{ + dim_t m = bli_obj_length(b); // number of rows of matrix B + dim_t n = bli_obj_width(b); // number of columns of matrix B + + bool transa = bli_obj_has_trans(a); + bool conjtransa = bli_obj_has_conj(a); + + dim_t cs_a, rs_a; + dim_t d_mr = 4,d_nr = 3; + + // Swap rs_a & cs_a in case of non-tranpose. 
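+    // Both variants handled by this routine (A'X = B with A upper, and
+    // AX = B with A lower, no transpose) share the same body; the
+    // non-transpose case is mapped onto it simply by exchanging the row
+    // and column strides of A here.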
+ if(transa) + { + cs_a = bli_obj_col_stride(a); // column stride of A + rs_a = bli_obj_row_stride(a); // row stride of A + } + else + { + cs_a = bli_obj_row_stride(a); // row stride of A + rs_a = bli_obj_col_stride(a); // column stride of A + } + dim_t cs_b = bli_obj_col_stride(b); // column stride of B + + dim_t i, j, k; //loop variables + dim_t k_iter; //number of times GEMM to be performed + + dcomplex AlphaVal = *(dcomplex *)AlphaObj->buffer; //value of alpha + dcomplex *L = a->buffer; //pointer to matrix A + dcomplex *B = b->buffer; //pointer to matrix B + + dcomplex *a10, *a11, *b01, *b11; //pointers that point to blocks for GEMM and TRSM + + dcomplex ones = {1.0, 1.0}; + bool is_unitdiag = bli_obj_has_unit_diag(a); + + //scratch registers + __m256d ymm0, ymm1, ymm2, ymm3; + __m256d ymm4, ymm5, ymm6, ymm7; + __m256d ymm8, ymm9, ymm10, ymm11; + __m256d ymm12, ymm13, ymm14, ymm15; + __m256d ymm16, ymm17, ymm18, ymm19; + + __m128d xmm5, xmm4, xmm3; + + gint_t required_packing_A = 1; + mem_t local_mem_buf_A_s = {0}; + dcomplex *D_A_pack = NULL; + dcomplex d11_pack[d_mr] __attribute__((aligned(64))); + rntm_t rntm; + + bli_rntm_init_from_global( &rntm ); + bli_rntm_set_num_threads_only( 1, &rntm ); + bli_membrk_rntm_set_membrk( &rntm ); + + siz_t buffer_size = bli_pool_block_size( + bli_membrk_pool( + bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), + bli_rntm_membrk(&rntm))); + + if ( (d_mr * m * sizeof(dcomplex)) > buffer_size) + return BLIS_NOT_YET_IMPLEMENTED; + + if (required_packing_A == 1) + { + // Get the buffer from the pool. + bli_membrk_acquire_m(&rntm, + buffer_size, + BLIS_BITVAL_BUFFER_FOR_A_BLOCK, + &local_mem_buf_A_s); + if(FALSE==bli_mem_is_alloc(&local_mem_buf_A_s)) return BLIS_NULL_POINTER; + D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); + if(NULL==D_A_pack) return BLIS_NULL_POINTER; + } + + /* + Performs solving TRSM for 4 colmns at a time from 0 to m/4 in steps of d_mr + a. Load, transpose, Pack A (a10 block), the size of packing 4x3 to 4x (m-4) + First there will be no GEMM and no packing of a10 because it is only TRSM + b. Using packed a10 block and b01 block perform GEMM operation + c. Use GEMM outputs, perform TRSM operaton using a11, b11 and update B + d. Repeat b,c for n rows of B in steps of d_nr + */ + for(i = 0;(i+d_mr-1) < m; i += d_mr) //loop along 'M' dimension + { + a10 = L + (i*cs_a); //pointer to block of A to be used for GEMM + a11 = L + (i*rs_a) + (i*cs_a); + dim_t p_lda = d_mr; // packed leading dimension + + if(transa) + { + /* + Load, tranpose and pack current A block (a10) into packed buffer memory + D_A_pack + a. This a10 block is used in GEMM portion only and this + a10 block size will be increasing by d_mr for every next itteration + untill it reaches 4x(m-4) which is the maximum GEMM alone block size + in A + b. This packed buffer is reused to calculate all n rows of B matrix + */ + bli_ztrsm_small_pack('L', i, 1, a10, cs_a, D_A_pack, p_lda,d_mr); + + /* + Pack 4 diagonal elements of A block into an array + a. This helps in utilze cache line efficiently in TRSM operation + b. store ones when input is unit diagonal + */ + ztrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,d_mr); + } + else + { + bli_ztrsm_small_pack('L', i, 0, a10, rs_a, D_A_pack, p_lda,d_mr); + ztrsm_small_pack_diag_element(is_unitdiag,a11,rs_a,d11_pack,d_mr); + } + /* + a. Perform GEMM using a10, b01. + b. Perform TRSM on a11, b11 + c. 
This loop GEMM+TRSM loops operates with 4x3 block size + along n dimension for every d_nr rows of b01 where + packed A buffer is reused in computing all n rows of B. + d. Same approch is used in remaining fringe cases. + */ + dim_t temp = n - d_nr + 1; + for(j = 0; j < temp; j += d_nr) //loop along 'N' dimension + { + a10 = D_A_pack; + a11 = L + (i*rs_a) + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + j*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + + k_iter = i; + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + /* + Peform GEMM between a10 and b01 blocks + For first itteration there will be no GEMM operation + where k_iter are zero + */ + BLIS_ZTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) + + /* + Load b11 of size 3x4 and multiply with alpha + Add the GEMM output and perform inregister transose of b11 + to peform TRSM operation. + */ + BLIS_ZTRSM_SMALL_NREG_TRANSPOSE_3x4(b11,cs_b,AlphaVal) + /* + Compute 4x3 TRSM block by using GEMM block output in register + a. The 4x3 input (gemm outputs) are stored in combinations of ymm + registers + 1. ymm8, ymm4 2. ymm9, ymm5 3. ymm10, ymm6, 4. ymm11, ymm7 + where ymm8-ymm11 holds 4x2 data and reaming 4x1 will be hold by + other registers + b. Towards the end do in regiser transpose of TRSM output and store in + b11 + */ + ////extract a00 + ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); + +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm8 and ymm4 with ymm1*/ + BLIS_ZTRSM_TWO_DIV(ymm8,ymm4) +#else + /*performs dcomplex multiplication of ymm8 and ymm4 with ymm1*/ + BLIS_ZTRSM_MUL(ymm8) + BLIS_ZTRSM_MUL(ymm4) +#endif + //extract a11 + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 1)); + //(ROW1): FMA operations + ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a*1)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + /* Step1 dcomplex multiply ymm2, ymm8 + * Step2 negate the result + * Step3 add ymm9*/ + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm8 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm8, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm8, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + + //For ymm4 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + ymm13 = _mm256_mul_pd(ymm4, ymm2); + ymm14 = _mm256_mul_pd(ymm4, ymm14); + ymm17 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + ymm17 = _mm256_mul_pd(ymm17, ymm15); + + //Step 3 + ymm9 = _mm256_add_pd(ymm16, ymm9); + ymm5 = _mm256_add_pd(ymm17, ymm5); + + ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a*2)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm8 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm8, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm8, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = 
_mm256_hsub_pd(ymm13, ymm14); + //For ymm4 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + ymm13 = _mm256_mul_pd(ymm4, ymm2); + ymm14 = _mm256_mul_pd(ymm4, ymm14); + ymm17 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + ymm17 = _mm256_mul_pd(ymm17, ymm15); + + //Step 3 + ymm10 = _mm256_add_pd(ymm16, ymm10); + ymm6 = _mm256_add_pd(ymm17, ymm6); + + ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a*3)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm8 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm8, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm8, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //For ymm4 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + ymm13 = _mm256_mul_pd(ymm4, ymm2); + ymm14 = _mm256_mul_pd(ymm4, ymm14); + ymm17 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + ymm17 = _mm256_mul_pd(ymm17, ymm15); + //Step 3 + ymm11 = _mm256_add_pd(ymm16, ymm11); + ymm7 = _mm256_add_pd(ymm17, ymm7); + +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm9 and ymm5 with ymm1*/ + BLIS_ZTRSM_TWO_DIV(ymm9,ymm5) +#else + /*performs dcomplex multiplication of ymm9 and ymm5 with ymm1*/ + BLIS_ZTRSM_MUL(ymm9) + BLIS_ZTRSM_MUL(ymm5) +#endif + a11 += rs_a; + //extract a22 + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 2)); + + //(ROW2): FMA operations + ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a*2)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + //For ymm9 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm9, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm9, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //For ymm5 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + ymm13 = _mm256_mul_pd(ymm5, ymm2); + ymm14 = _mm256_mul_pd(ymm5, ymm14); + ymm17 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + ymm17 = _mm256_mul_pd(ymm17, ymm15); + //Step 3 + ymm10 = _mm256_add_pd(ymm16, ymm10); + ymm6 = _mm256_add_pd(ymm17, ymm6); + + ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a*3)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + //For ymm9 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm9, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm9, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //For ymm5 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + ymm13 = _mm256_mul_pd(ymm5, ymm2); + ymm14 = _mm256_mul_pd(ymm5, ymm14); + 
ymm17 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + ymm17 = _mm256_mul_pd(ymm17, ymm15); + //Step 3 + ymm11 = _mm256_add_pd(ymm16, ymm11); + ymm7 = _mm256_add_pd(ymm17, ymm7); + +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm10 and ymm6 with ymm1*/ + BLIS_ZTRSM_TWO_DIV(ymm10,ymm6) +#else + /*performs dcomplex multiplication of ymm10 and ymm6 with ymm1*/ + BLIS_ZTRSM_MUL(ymm10) + BLIS_ZTRSM_MUL(ymm6) +#endif + a11 += rs_a; + //extract a44 + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 3)); + //(ROW3): FMA operations + ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a*3)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + //For ymm10 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm10, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm10, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //For ymm6 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + ymm13 = _mm256_mul_pd(ymm6, ymm2); + ymm14 = _mm256_mul_pd(ymm6, ymm14); + ymm17 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + ymm17 = _mm256_mul_pd(ymm17, ymm15); + //Step 3 + ymm11 = _mm256_add_pd(ymm16, ymm11); + ymm7 = _mm256_add_pd(ymm17, ymm7); + +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm11 and ymm7 with ymm1*/ + BLIS_ZTRSM_TWO_DIV(ymm11,ymm7) +#else + /*performs dcomplex nultiplication of ymm11 and ymm7 with ymm1*/ + BLIS_ZTRSM_MUL(ymm11) + BLIS_ZTRSM_MUL(ymm7) +#endif + a11 += rs_a; + BLIS_ZTRSM_SMALL_NREG_TRANSPOSE_4x3_AND_STORE(b11,cs_b) + } + + dim_t n_rem = n-j; + if(n_rem) + { + a10 = D_A_pack; + a11 = L + (i*rs_a) + (i*cs_a);//pointer to block of A to be used for TRSM + b01 = B + j*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + + k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + if(2 == n_rem) + { + ///GEMM code begins/// + BLIS_ZTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b,p_lda,k_iter) + BLIS_ZTRSM_SMALL_NREG_TRANSPOSE_2x4(b11,cs_b,AlphaVal) + } + else if(1 == n_rem) + { + ///GEMM code begins/// + BLIS_ZTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b,p_lda,k_iter) + BLIS_ZTRSM_SMALL_NREG_TRANSPOSE_1x4(b11,cs_b,AlphaVal) + } + ///implement TRSM/// + + ///transpose of B11// + ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + ymm10 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm11 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ////extract a00 + ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); + +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_ZTRSM_DIV(ymm8) +#else + BLIS_ZTRSM_MUL(ymm8) +#endif + + //extract a11 + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 1)); + //(ROW1): FMA operations + ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a*1)); + ymm3 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a*2)); + ymm4 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a*3)); + + if(conjtransa){ + ymm2 = _mm256_mul_pd(ymm2, 
ymm0); + ymm3 = _mm256_mul_pd(ymm3, ymm0); + ymm4 = _mm256_mul_pd(ymm4, ymm0); + } + + a11 += rs_a; + /*Step1 dcomplex multiply ymmx, ymmx + * Step2 negate the result + * Step3 add ymmx*/ + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm8 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm8, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm8, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + + //Step 3 + ymm9 = _mm256_add_pd(ymm16, ymm9); + + //Step 1 + ymm14 = _mm256_permute_pd(ymm3, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm8 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm8, ymm3); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm8, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + + //Step 3 + ymm10 = _mm256_add_pd(ymm16, ymm10); + + //Step 1 + ymm14 = _mm256_permute_pd(ymm4, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm8 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm8, ymm4); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm8, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + //Step 3 + ymm11 = _mm256_add_pd(ymm16, ymm11); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_ZTRSM_DIV(ymm9) +#else + BLIS_ZTRSM_MUL(ymm9) +#endif + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 2)); + ymm3 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a*2)); + ymm4 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a*3)); + + if(conjtransa){ + ymm3 = _mm256_mul_pd(ymm3, ymm0); + ymm4 = _mm256_mul_pd(ymm4, ymm0); + } + + a11 += rs_a; + //Step 1 + ymm14 = _mm256_permute_pd(ymm3, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm9 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm9, ymm3); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm9, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + + //Step 3 + ymm10 = _mm256_add_pd(ymm16, ymm10); + + //Step 1 + ymm14 = _mm256_permute_pd(ymm4, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm8 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm9, ymm4); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm9, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + //Step 3 + ymm11 = _mm256_add_pd(ymm16, ymm11); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_ZTRSM_DIV(ymm10) +#else + BLIS_ZTRSM_MUL(ymm10) +#endif + + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 3)); + ymm4 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a*3)); + + if(conjtransa){ + ymm4 = _mm256_mul_pd(ymm4, ymm0); + } + + //Step 1 + ymm14 = _mm256_permute_pd(ymm4, 0x5); + /* Negate the imaginary 
elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm10 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm10, ymm4); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm10, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + //Step 3 + ymm11 = _mm256_add_pd(ymm16, ymm11); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_ZTRSM_DIV(ymm11) +#else + BLIS_ZTRSM_MUL(ymm11) +#endif + if(n_rem == 1) + { + ymm0 = _mm256_permute2f128_pd(ymm8,ymm9,0x20); + ymm4 = _mm256_permute2f128_pd(ymm10,ymm11,0x20); + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 2), ymm4); + } + else if(n_rem == 2) + { + ymm0 = _mm256_permute2f128_pd(ymm8,ymm9,0x20); + ymm4 = _mm256_permute2f128_pd(ymm10,ymm11,0x20); + ymm1 = _mm256_permute2f128_pd(ymm8,ymm9,0x31); + ymm3 = _mm256_permute2f128_pd(ymm10,ymm11,0x31); + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 2), ymm4); + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 2), ymm3); + } + } + } + dim_t m_rem = m-i; + if(m_rem) + { + a10 = L + (i*cs_a); + dcomplex *ptr_a10_dup = D_A_pack; + if(m_rem == 3) + { + dim_t p_lda = 4; + if(transa) + { + for(dim_t x = 0; x < i; x += p_lda) + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm10 = _mm256_loadu_pd((double const *) + (a10 + 2)); + ymm1 = _mm256_loadu_pd((double const *) + (a10 + cs_a)); + ymm11 = _mm256_loadu_pd((double const *) + (a10 + 2 + cs_a)); + + ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + ymm8 = _mm256_permute2f128_pd(ymm10,ymm11,0x20); + ymm9 = _mm256_permute2f128_pd(ymm10,ymm11,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + + p_lda*3), ymm9); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + + 2 * cs_a)); + ymm10 = _mm256_loadu_pd((double const *)(a10 + + 2 * cs_a + 2)); + + ymm1 = _mm256_loadu_pd((double const *)(a10 + + 3 * cs_a)); + ymm11 = _mm256_loadu_pd((double const *)(a10 + + 3 * cs_a + 2)); + + ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + ymm8 = _mm256_permute2f128_pd(ymm10,ymm11,0x20); + ymm9 = _mm256_permute2f128_pd(ymm10,ymm11,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup + 2), + ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + + p_lda + 2), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + + p_lda*2 + 2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + + p_lda*3 + 2), ymm9); + + a10 += p_lda; + ptr_a10_dup += p_lda * p_lda; + } + + } + else + { + for(dim_t x=0;xbuffer; //value of alpha + dcomplex *L = a->buffer; //pointer to matrix A + dcomplex *B = b->buffer; //pointer to matrix B + + //pointers that point to blocks for GEMM and TRSM + dcomplex *a10, *a11, *b01, *b11; + + dcomplex ones = {1.0, 1.0}; + bool is_unitdiag = bli_obj_has_unit_diag(a); + + //scratch registers + __m256d ymm0, ymm1, ymm2, ymm3; + __m256d ymm4, ymm5, ymm6, ymm7; + __m256d ymm8, ymm9, ymm10, ymm11; + __m256d ymm12, ymm13, ymm14, ymm15; + __m256d ymm16, ymm17, ymm18, ymm19; + + __m128d xmm5, xmm4, xmm3; + + gint_t required_packing_A = 1; + mem_t local_mem_buf_A_s = {0}; + dcomplex *D_A_pack = NULL; + dcomplex 
d11_pack[d_mr] __attribute__((aligned(64))); + rntm_t rntm; + + bli_rntm_init_from_global( &rntm ); + bli_rntm_set_num_threads_only( 1, &rntm ); + bli_membrk_rntm_set_membrk( &rntm ); + + siz_t buffer_size = bli_pool_block_size( + bli_membrk_pool( + bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), + bli_rntm_membrk(&rntm))); + + if((d_mr * m * sizeof(dcomplex)) > buffer_size) + return BLIS_NOT_YET_IMPLEMENTED; + + if(required_packing_A == 1) + { + // Get the buffer from the pool. + bli_membrk_acquire_m(&rntm, + buffer_size, + BLIS_BITVAL_BUFFER_FOR_A_BLOCK, + &local_mem_buf_A_s); + if(FALSE==bli_mem_is_alloc(&local_mem_buf_A_s)) return BLIS_NULL_POINTER; + D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); + if(NULL==D_A_pack) return BLIS_NULL_POINTER; + } + + /* + Performs solving TRSM for 4 colmns at a time from 0 to m/d_mr in steps of d_mr + a. Load, transpose, Pack A (a10 block), the size of packing 8x6 to 8x (m-d_mr) + First there will be no GEMM and no packing of a10 because it is only TRSM + b. Using packed a10 block and b01 block perform GEMM operation + c. Use GEMM outputs, perform TRSM operaton using a11, b11 and update B + d. Repeat b,c for n rows of B in steps of d_nr + */ + for(i = (m - d_mr); (i + 1) > 0; i -= d_mr) + { + a10 = L + (i*cs_a) + (i + d_mr)*rs_a;//pointer to block of A to be used for GEMM + a11 = L + (i*cs_a) + (i*rs_a);//pointer to block of A to be used for TRSM + + // Do transpose for a10 & store in D_A_pack + //ptr_a10_dup = D_A_pack; + + dim_t p_lda = d_mr; // packed leading dimension + + if(transa) + { + /* + Load, transpose and pack current A block (a10) into packed buffer memory + D_A_pack + a. This a10 block is used in GEMM portion only and this + a10 block size will be increasing by d_mr for every next itteration + untill it reaches 4x(m-4) which is the maximum GEMM alone block size + in A + b. This packed buffer is reused to calculate all n rows of B matrix + */ + bli_ztrsm_small_pack('L', (m-i-d_mr), 1, a10, cs_a, D_A_pack,p_lda,d_mr); + + /* + Pack 8 diagonal elements of A block into an array + a. This helps in utilze cache line efficiently in TRSM operation + b. store ones when input is unit diagonal + */ + ztrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,d_mr); + } + else + { + bli_ztrsm_small_pack('L', (m-i-d_mr), 0, a10, rs_a, D_A_pack,p_lda,d_mr); + ztrsm_small_pack_diag_element(is_unitdiag,a11,rs_a,d11_pack,d_mr); + } + + /* + a. Perform GEMM using a10, b01. + b. Perform TRSM on a11, b11 + c. This loop GEMM+TRSM loops operates with 8x6 block size + along n dimension for every d_nr rows of b01 where + packed A buffer is reused in computing all n rows of B. + d. Same approch is used in remaining fringe cases. + */ + for(j = (n - d_nr); (j + 1) > 0; j -= d_nr) + { + a10 = D_A_pack; + b01 = B + (j * cs_b) + i + d_mr;//pointer to block of B to be used for GEMM + b11 = B + (j * cs_b) + i;//pointer to block of B to be used for TRSM + + k_iter = (m - i - d_mr); + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + /* + Peform GEMM between a10 and b01 blocks + For first itteration there will be no GEMM operation + where k_iter are zero + */ + BLIS_ZTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) + + /* + Load b11 of size 6x8 and multiply with alpha + Add the GEMM output and perform inregister transose of b11 + to peform TRSM operation. + */ + BLIS_ZTRSM_SMALL_NREG_TRANSPOSE_3x4(b11,cs_b,AlphaVal) + + /* + Compute 4x3 TRSM block by using GEMM block output in register + a. 
The 4x3 input (gemm outputs) are stored in combinations of ymm + registers + 1. ymm8, ymm4 2. ymm9, ymm5 3. ymm10, ymm6, 4. ymm11, ymm7 + where ymm8-ymm11 holds 4x2 data and reaming 4x1 will be hold by + other registers + b. Towards the end do in regiser transpose of TRSM output and store in + b11 + */ + ////extract a00 + ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 3)); + +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm11 and ymm7 with ymm1*/ + BLIS_ZTRSM_TWO_DIV(ymm11,ymm7) +#else + /*performs dcomplex multiplication of ymm11 and ymm7 with ymm1*/ + BLIS_ZTRSM_MUL(ymm11) + BLIS_ZTRSM_MUL(ymm7) +#endif + //extract a11 + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 2)); + //(ROW1): FMA operations + ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a*2 + rs_a*3)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + /* Step1 dcomplex multiply ymm2, ymm8 + * Step2 negate the result + * Step3 add ymm9*/ + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm11 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm11, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm11, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + + //For ymm7 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + ymm13 = _mm256_mul_pd(ymm7, ymm2); + ymm14 = _mm256_mul_pd(ymm7, ymm14); + ymm17 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + ymm17 = _mm256_mul_pd(ymm17, ymm15); + + //Step 3 + ymm10 = _mm256_add_pd(ymm16, ymm10); + ymm6 = _mm256_add_pd(ymm17, ymm6); + + ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a*1 + rs_a*3)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm11 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm11, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm11, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //For ymm7 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + ymm13 = _mm256_mul_pd(ymm7, ymm2); + ymm14 = _mm256_mul_pd(ymm7, ymm14); + ymm17 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + ymm17 = _mm256_mul_pd(ymm17, ymm15); + + //Step 3 + ymm9 = _mm256_add_pd(ymm16, ymm9); + ymm5 = _mm256_add_pd(ymm17, ymm5); + + ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + rs_a*3)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm11 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm11, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm11, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //For ymm7 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ 
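+				/*
+				   Added note (illustrative, not part of the original patch):
+				   with ymm0 = {1,-1,1,-1}, this permute plus the multiply
+				   below form {a.imag, -a.real} from the broadcast A element
+				   in ymm2; the two multiplies and _mm256_hsub_pd that follow
+				   then produce the complex product
+				   {a.r*b.r - a.i*b.i, a.r*b.i + a.i*b.r} in each 128-bit lane.
+				*/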
+ ymm14 = _mm256_mul_pd(ymm14, ymm0); + + ymm13 = _mm256_mul_pd(ymm7, ymm2); + ymm14 = _mm256_mul_pd(ymm7, ymm14); + ymm17 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + ymm17 = _mm256_mul_pd(ymm17, ymm15); + //Step 3 + ymm8 = _mm256_add_pd(ymm16, ymm8); + ymm4 = _mm256_add_pd(ymm17, ymm4); + +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm10 and ymm6 with ymm1*/ + BLIS_ZTRSM_TWO_DIV(ymm10,ymm6) +#else + /*performs dcomplex multiplication of ymm10 and ymm6 with ymm1*/ + BLIS_ZTRSM_MUL(ymm10) + BLIS_ZTRSM_MUL(ymm6) +#endif + //extract a22 + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 1)); + + //(ROW2): FMA operations + ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a*1 + rs_a*2)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + //For ymm10 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm10, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm10, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //For ymm6 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + ymm13 = _mm256_mul_pd(ymm6, ymm2); + ymm14 = _mm256_mul_pd(ymm6, ymm14); + ymm17 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + ymm17 = _mm256_mul_pd(ymm17, ymm15); + //Step 3 + ymm9 = _mm256_add_pd(ymm16, ymm9); + ymm5 = _mm256_add_pd(ymm17, ymm5); + + ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + rs_a*2)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + //For ymm10 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm10, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm10, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //For ymm6 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + ymm13 = _mm256_mul_pd(ymm6, ymm2); + ymm14 = _mm256_mul_pd(ymm6, ymm14); + ymm17 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + ymm17 = _mm256_mul_pd(ymm17, ymm15); + //Step 3 + ymm8 = _mm256_add_pd(ymm16, ymm8); + ymm4 = _mm256_add_pd(ymm17, ymm4); + +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm9 and ymm5 with ymm1*/ + BLIS_ZTRSM_TWO_DIV(ymm9,ymm5) +#else + /*performs dcomplex multiplication of ymm9 and ymm5 with ymm1*/ + BLIS_ZTRSM_MUL(ymm9) + BLIS_ZTRSM_MUL(ymm5) +#endif + //extract a44 + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); + //(ROW3): FMA operations + ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + rs_a)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + //For ymm9 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm9, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm9, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, 
ymm14); + //For ymm5 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + ymm13 = _mm256_mul_pd(ymm5, ymm2); + ymm14 = _mm256_mul_pd(ymm5, ymm14); + ymm17 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + ymm17 = _mm256_mul_pd(ymm17, ymm15); + //Step 3 + ymm8 = _mm256_add_pd(ymm16, ymm8); + ymm4 = _mm256_add_pd(ymm17, ymm4); + +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm8 and ymm4 with ymm1*/ + BLIS_ZTRSM_TWO_DIV(ymm8,ymm4) +#else + /*performs dcomplex nultiplication of ymm8 and ymm4 with ymm1*/ + BLIS_ZTRSM_MUL(ymm8) + BLIS_ZTRSM_MUL(ymm4) + +#endif + BLIS_ZTRSM_SMALL_NREG_TRANSPOSE_4x3_AND_STORE(b11,cs_b) + + + } + dim_t n_remainder = j + d_nr; + if(n_remainder) + { + a10 = D_A_pack; + a11 = L + (i*cs_a) + (i*rs_a); + b01 = B + i + d_mr; + b11 = B + i; + + k_iter = (m - i - d_mr) ; + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + if(2 == n_remainder) + { + ///GEMM code begins/// + BLIS_ZTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b,p_lda,k_iter) + + ymm16 = _mm256_broadcast_pd((__m128d const *)(&AlphaVal)); + //register to hold alpha + BLIS_ZTRSM_SMALL_NREG_TRANSPOSE_2x4(b11,cs_b,AlphaVal) + } + else if(1 == n_remainder) + { + ///GEMM code begins/// + BLIS_ZTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b,p_lda,k_iter) + BLIS_ZTRSM_SMALL_NREG_TRANSPOSE_1x4(b11,cs_b,AlphaVal) + } + ///implement TRSM/// + ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + ymm10 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm11 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ////extract a00 + ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 3)); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_ZTRSM_DIV(ymm11) +#else + BLIS_ZTRSM_MUL(ymm11) +#endif + + //extract a11 + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 2)); + //(ROW1): FMA operations + ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a*2 + rs_a*3)); + ymm3 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a*1 + rs_a*3)); + ymm4 = _mm256_broadcast_pd((__m128d const *)(a11 + rs_a*3)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + ymm3 = _mm256_mul_pd(ymm3, ymm0); + ymm4 = _mm256_mul_pd(ymm4, ymm0); + } + /*Step1 dcomplex multiply ymmx, ymmx + * Step2 negate the result + * Step3 add ymmx*/ + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm8 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm11, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm11, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + + //Step 3 + ymm10 = _mm256_add_pd(ymm16, ymm10); + + //Step 1 + ymm14 = _mm256_permute_pd(ymm3, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm8 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm11, ymm3); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm11, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + + //Step 3 + ymm9 = _mm256_add_pd(ymm16, ymm9); + + //Step 
1 + ymm14 = _mm256_permute_pd(ymm4, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm8 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm11, ymm4); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm11, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + //Step 3 + ymm8 = _mm256_add_pd(ymm16, ymm8); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_ZTRSM_DIV(ymm10) +#else + BLIS_ZTRSM_MUL(ymm10) +#endif + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 1)); + ymm3 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a*1 + rs_a*2)); + ymm4 = _mm256_broadcast_pd((__m128d const *)(a11 + rs_a*2)); + if(conjtransa) + { + ymm3 = _mm256_mul_pd(ymm3, ymm0); + ymm4 = _mm256_mul_pd(ymm4, ymm0); + } + //Step 1 + ymm14 = _mm256_permute_pd(ymm3, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm9 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm10, ymm3); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm10, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + + //Step 3 + ymm9 = _mm256_add_pd(ymm16, ymm9); + + //Step 1 + ymm14 = _mm256_permute_pd(ymm4, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm8 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm10, ymm4); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm10, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + //Step 3 + ymm8 = _mm256_add_pd(ymm16, ymm8); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_ZTRSM_DIV(ymm9) +#else + BLIS_ZTRSM_MUL(ymm9) +#endif + + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); + ymm4 = _mm256_broadcast_pd((__m128d const *)(a11 + rs_a)); + if(conjtransa) + { + ymm4 = _mm256_mul_pd(ymm4, ymm0); + } + //Step 1 + ymm14 = _mm256_permute_pd(ymm4, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm10 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm9, ymm4); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm9, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + //Step 3 + ymm8 = _mm256_add_pd(ymm16, ymm8); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_ZTRSM_DIV(ymm8) +#else + BLIS_ZTRSM_MUL(ymm8) +#endif + + if(2 == n_remainder) + { + ymm0 = _mm256_permute2f128_pd(ymm8,ymm9,0x20); + ymm4 = _mm256_permute2f128_pd(ymm10,ymm11,0x20); + ymm1 = _mm256_permute2f128_pd(ymm8,ymm9,0x31); + ymm3 = _mm256_permute2f128_pd(ymm10,ymm11,0x31); + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 2), ymm4); + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 2), ymm3); + + } + else if(1 == n_remainder) + { + ymm0 = _mm256_permute2f128_pd(ymm8,ymm9,0x20); + ymm4 = _mm256_permute2f128_pd(ymm10,ymm11,0x20); + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 2), ymm4); + } + } + } + 
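+	/*
+	   Added note: each "Step 1/2/3" sequence above is the AVX2 form of the
+	   scalar complex update b = b - a*x, where a is the broadcast A11
+	   element (conjugated first when conjtransa is set) and x is an
+	   already-solved B entry. A minimal scalar sketch of one such update,
+	   using illustrative names only:
+
+	       dcomplex prod;
+	       prod.real = a.real*x.real - a.imag*x.imag;  // Step 1: permute/mul/hsub
+	       prod.imag = a.real*x.imag + a.imag*x.real;
+	       b.real -= prod.real;                        // Step 2 negates, Step 3 adds
+	       b.imag -= prod.imag;
+
+	   The diagonal solve (BLIS_ZTRSM_DIV / BLIS_ZTRSM_MUL) then either
+	   divides by the diagonal element or, when BLIS_ENABLE_TRSM_PREINVERSION
+	   is defined, multiplies by the reciprocal packed in d11_pack.
+	*/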
+ dim_t m_remainder = i + d_mr; + a10 = L + m_remainder*rs_a; + dcomplex *ptr_a10_dup = D_A_pack; + if(m_remainder == 3) + { + dim_t p_lda = 4; + if(transa) + { + for(dim_t x = 0; x < m-m_remainder; x += p_lda) + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm10 = _mm256_loadu_pd((double const *) + (a10 + 2)); + ymm1 = _mm256_loadu_pd((double const *) + (a10 + cs_a)); + ymm11 = _mm256_loadu_pd((double const *) + (a10 + 2 + cs_a)); + + ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + ymm8 = _mm256_permute2f128_pd(ymm10,ymm11,0x20); + ymm9 = _mm256_permute2f128_pd(ymm10,ymm11,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + + p_lda*3), ymm9); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + + 2 * cs_a)); + ymm10 = _mm256_loadu_pd((double const *)(a10 + + 2 * cs_a + 2)); + + ymm1 = _mm256_loadu_pd((double const *)(a10 + + 3 * cs_a)); + ymm11 = _mm256_loadu_pd((double const *)(a10 + + 3 * cs_a + 2)); + + ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + ymm8 = _mm256_permute2f128_pd(ymm10,ymm11,0x20); + ymm9 = _mm256_permute2f128_pd(ymm10,ymm11,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup + 2), + ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + + p_lda + 2), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + + p_lda*2 + 2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + + p_lda*3 + 2), ymm9); + + a10 += p_lda; + ptr_a10_dup += p_lda * p_lda; + } + + } + else + { + for(dim_t x=0;x < m-m_remainder;x++) + { + ymm0 = _mm256_loadu_pd((double const *) + (a10 + rs_a * x)); + _mm256_storeu_pd((double *) + (ptr_a10_dup + p_lda * x), ymm0); + ymm0 = _mm256_loadu_pd((double const *) + (a10 + rs_a * x + 2)); + _mm256_storeu_pd((double *) + (ptr_a10_dup + p_lda * x + 2), + ymm0); + } + } + //cols + for(j = (n - d_nr); (j + 1) > 0; j -= d_nr) + { + a10 = D_A_pack; + a11 = L; + b01 = B + (j*cs_b) + m_remainder; + b11 = B + (j* cs_b); + k_iter = (m - m_remainder); + + BLIS_SET_YMM_REG_ZEROS + ///GEMM code begins/// + BLIS_ZTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) + ///GEMM code ends/// + ymm16 = _mm256_broadcast_pd((__m128d const *) + (&AlphaVal)); + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + + ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm16, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm18); + ymm17 = _mm256_mul_pd(ymm0, ymm16); + ymm14 = _mm256_mul_pd(ymm0, ymm14); + ymm15 = _mm256_hsub_pd(ymm17, ymm14); + + ymm8 = _mm256_sub_pd(ymm15,ymm8); + + ymm14 = _mm256_permute_pd(ymm16, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm18); + ymm17 = _mm256_mul_pd(ymm1, ymm16); + ymm14 = _mm256_mul_pd(ymm1, ymm14); + ymm15 = _mm256_hsub_pd(ymm17, ymm14); + + ymm9 = _mm256_sub_pd(ymm15,ymm9); + + ymm14 = _mm256_permute_pd(ymm16, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm18); + ymm17 = _mm256_mul_pd(ymm2, ymm16); + ymm14 = _mm256_mul_pd(ymm2, ymm14); + ymm15 = _mm256_hsub_pd(ymm17, ymm14); + + ymm10 = _mm256_sub_pd(ymm15,ymm10); + + _mm256_storeu_pd((double *)(b11), ymm8); + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm9); + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm10); + + ymm0 = _mm256_loadu_pd((double const *) + (b11 + cs_b *0 + 2)); + ymm1 = 
_mm256_loadu_pd((double const *) + (b11 + cs_b *1 + 2)); + ymm2 = _mm256_loadu_pd((double const *) + (b11 + cs_b *2 + 2)); + + ymm14 = _mm256_permute_pd(ymm16, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm18); + ymm17 = _mm256_mul_pd(ymm0, ymm16); + ymm14 = _mm256_mul_pd(ymm0, ymm14); + ymm15 = _mm256_hsub_pd(ymm17, ymm14); + + ymm11 = _mm256_sub_pd(ymm15,ymm11); + + ymm14 = _mm256_permute_pd(ymm16, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm18); + ymm17 = _mm256_mul_pd(ymm1, ymm16); + ymm14 = _mm256_mul_pd(ymm1, ymm14); + ymm15 = _mm256_hsub_pd(ymm17, ymm14); + + ymm12 = _mm256_sub_pd(ymm15,ymm12); + ymm14 = _mm256_permute_pd(ymm16, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm18); + ymm17 = _mm256_mul_pd(ymm2, ymm16); + ymm14 = _mm256_mul_pd(ymm2, ymm14); + ymm15 = _mm256_hsub_pd(ymm17, ymm14); + + ymm13 = _mm256_sub_pd(ymm15,ymm13); + _mm_storeu_pd((double *)(b11 + 2), + _mm256_extractf128_pd(ymm11,0)); + _mm_storeu_pd((double *)(b11 + cs_b * 1 + 2), + _mm256_extractf128_pd(ymm12,0)); + _mm_storeu_pd((double *)(b11 + cs_b * 2 + 2), + _mm256_extractf128_pd(ymm13,0)); + + if(transa) + ztrsm_AltXB_ref(a11, b11, m_remainder, 3, + cs_a, cs_b, is_unitdiag, + conjtransa); + else + ztrsm_AuXB_ref(a11, b11, m_remainder, 3, + rs_a, cs_b, is_unitdiag, + conjtransa); + } + dim_t n_remainder = j + d_nr; + if(n_remainder) + { + a10 = D_A_pack; + a11 = L; + b01 = B + m_remainder; + b11 = B; + k_iter = (m - m_remainder); + BLIS_SET_YMM_REG_ZEROS + if(2 == n_remainder) + { + ///GEMM code begins/// + BLIS_ZTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b, + p_lda,k_iter) + BLIS_PRE_ZTRSM_SMALL_3M_2N(AlphaVal,b11,cs_b) + + if(transa) + ztrsm_AltXB_ref(a11, b11, m_remainder, 2, + cs_a, cs_b, is_unitdiag, + conjtransa); + + else + ztrsm_AuXB_ref(a11, b11, m_remainder, 2, + rs_a, cs_b, is_unitdiag, + conjtransa); + } + else if(1 == n_remainder) + { + ///GEMM code begins/// + BLIS_ZTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b, + p_lda,k_iter) + BLIS_PRE_ZTRSM_SMALL_3M_1N(AlphaVal,b11,cs_b) + + if(transa) + ztrsm_AltXB_ref(a11, b11, m_remainder, 1, + cs_a, cs_b, is_unitdiag, + conjtransa); + else + ztrsm_AuXB_ref(a11, b11, m_remainder, 1, + rs_a, cs_b, is_unitdiag, + conjtransa); + + } + } + } + else if(m_remainder == 2) + { + dim_t p_lda = 2; + if(transa) + { + for(dim_t x = 0; x < m-m_remainder; x += p_lda) + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *) + (a10 + cs_a)); + + ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + + p_lda), ymm7); + + a10 += p_lda; + ptr_a10_dup += p_lda * p_lda; + } + + } + else + { + for(dim_t x=0;x < m-m_remainder;x++) + { + ymm0 = _mm256_loadu_pd((double const *) + (a10 + rs_a * x)); + _mm256_storeu_pd((double *) + (ptr_a10_dup + p_lda * x), ymm0); + } + } + //cols + for(j = (n - d_nr); (j + 1) > 0; j -= d_nr) + { + a10 = D_A_pack; + a11 = L; + b01 = B + (j*cs_b) + m_remainder; + b11 = B + (j* cs_b); + k_iter = (m - m_remainder); + + BLIS_SET_YMM_REG_ZEROS + ///GEMM code begins/// + BLIS_ZTRSM_SMALL_GEMM_2mx3n(a10,b01,cs_b,p_lda,k_iter) + ///GEMM code ends/// + ymm16 = _mm256_broadcast_pd((__m128d const *) + (&AlphaVal)); + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + + ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm16, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm18); 
+ ymm17 = _mm256_mul_pd(ymm0, ymm16); + ymm14 = _mm256_mul_pd(ymm0, ymm14); + ymm15 = _mm256_hsub_pd(ymm17, ymm14); + + ymm8 = _mm256_sub_pd(ymm15,ymm8); + + ymm14 = _mm256_permute_pd(ymm16, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm18); + ymm17 = _mm256_mul_pd(ymm1, ymm16); + ymm14 = _mm256_mul_pd(ymm1, ymm14); + ymm15 = _mm256_hsub_pd(ymm17, ymm14); + + ymm9 = _mm256_sub_pd(ymm15,ymm9); + + ymm14 = _mm256_permute_pd(ymm16, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm18); + ymm17 = _mm256_mul_pd(ymm2, ymm16); + ymm14 = _mm256_mul_pd(ymm2, ymm14); + ymm15 = _mm256_hsub_pd(ymm17, ymm14); + + ymm10 = _mm256_sub_pd(ymm15,ymm10); + + _mm256_storeu_pd((double *)(b11), ymm8); + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm9); + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm10); + + if(transa) + ztrsm_AltXB_ref(a11, b11, m_remainder, 3, + cs_a, cs_b, is_unitdiag, + conjtransa); + else + ztrsm_AuXB_ref(a11, b11, m_remainder, 3, + rs_a, cs_b, is_unitdiag, + conjtransa); + } + dim_t n_remainder = j + d_nr; + if(n_remainder) + { + a10 = D_A_pack; + a11 = L; + b01 = B + m_remainder; + b11 = B; + k_iter = (m - m_remainder); + BLIS_SET_YMM_REG_ZEROS + if(2 == n_remainder) + { + ///GEMM code begins/// + BLIS_ZTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b, + p_lda,k_iter) + BLIS_PRE_ZTRSM_SMALL_2M_2N(AlphaVal,b11,cs_b) + + if(transa) + ztrsm_AltXB_ref(a11, b11, m_remainder, 2, + cs_a, cs_b, is_unitdiag, + conjtransa); + + else + ztrsm_AuXB_ref(a11, b11, m_remainder, 2, + rs_a, cs_b, is_unitdiag, + conjtransa); + } + else if(1 == n_remainder) + { + ///GEMM code begins/// + BLIS_ZTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b, + p_lda,k_iter) + BLIS_PRE_ZTRSM_SMALL_2M_1N(AlphaVal,b11,cs_b) + + if(transa) + ztrsm_AltXB_ref(a11, b11, m_remainder, 1, + cs_a, cs_b, is_unitdiag, + conjtransa); + else + ztrsm_AuXB_ref(a11, b11, m_remainder, 1, + rs_a, cs_b, is_unitdiag, + conjtransa); + + } + } + } + else if(m_remainder == 1) + { + dim_t p_lda = 2; // packed leading dimension + if(transa) + { + for(dim_t x = 0; x < m-m_remainder; x += p_lda) + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *) + (a10 + cs_a)); + + ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + + p_lda), ymm7); + + a10 += p_lda; + ptr_a10_dup += p_lda * p_lda; + } + + } + else + { + for(dim_t x=0;x 0; j -= d_nr) + { + a10 = D_A_pack; + a11 = L; + b01 = B + (j*cs_b) + m_remainder; + b11 = B + (j* cs_b); + k_iter = (m - m_remainder); + + BLIS_SET_YMM_REG_ZEROS + ///GEMM code begins/// + BLIS_ZTRSM_SMALL_GEMM_2mx3n(a10,b01,cs_b,p_lda,k_iter) + ///GEMM code ends/// + ymm16 = _mm256_broadcast_pd((__m128d const *) + (&AlphaVal)); + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm16, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm18); + ymm17 = _mm256_mul_pd(ymm0, ymm16); + ymm14 = _mm256_mul_pd(ymm0, ymm14); + ymm15 = _mm256_hsub_pd(ymm17, ymm14); + + ymm8 = _mm256_sub_pd(ymm15,ymm8); + + ymm14 = _mm256_permute_pd(ymm16, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm18); + ymm17 = _mm256_mul_pd(ymm1, ymm16); + ymm14 = _mm256_mul_pd(ymm1, ymm14); + ymm15 = _mm256_hsub_pd(ymm17, ymm14); + + ymm9 = _mm256_sub_pd(ymm15,ymm9); + ymm14 = _mm256_permute_pd(ymm16, 0x5); + ymm14 = _mm256_mul_pd(ymm14, 
ymm18); + ymm17 = _mm256_mul_pd(ymm2, ymm16); + ymm14 = _mm256_mul_pd(ymm2, ymm14); + ymm15 = _mm256_hsub_pd(ymm17, ymm14); + + ymm10 = _mm256_sub_pd(ymm15,ymm10); + + _mm_storeu_pd((double *)(b11), + _mm256_extractf128_pd(ymm8,0)); + _mm_storeu_pd((double *)(b11 + cs_b * 1), + _mm256_extractf128_pd(ymm9,0) ); + _mm_storeu_pd((double *)(b11 + cs_b * 2), + _mm256_extractf128_pd(ymm10,0)); + + if(transa) + ztrsm_AltXB_ref(a11, b11, m_remainder, 3, + cs_a, cs_b, is_unitdiag, + conjtransa); + + else + ztrsm_AuXB_ref(a11, b11, m_remainder, 3, rs_a, + cs_b, is_unitdiag, + conjtransa); + } + dim_t n_remainder = j + d_nr; + if(n_remainder) + { + a10 = D_A_pack; + a11 = L ; + b01 = B + m_remainder; + b11 = B; + k_iter = (m - m_remainder); + BLIS_SET_YMM_REG_ZEROS + if(2 == n_remainder) + { + + ///GEMM code begins/// + BLIS_ZTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b, + p_lda,k_iter) + BLIS_PRE_ZTRSM_SMALL_1M_2N(AlphaVal,b11,cs_b) + + if(transa) + ztrsm_AltXB_ref(a11, b11, m_remainder, 2, + cs_a, cs_b, is_unitdiag, + conjtransa); + + else + ztrsm_AuXB_ref(a11, b11, m_remainder, 2, + rs_a, cs_b, is_unitdiag, + conjtransa); + } + else if(1 == n_remainder) + { + ///GEMM code begins/// + BLIS_ZTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b, + p_lda,k_iter) + + BLIS_PRE_ZTRSM_SMALL_1M_1N(AlphaVal,b11,cs_b) + + if(transa) + ztrsm_AltXB_ref(a11, b11, m_remainder, 1, + cs_a, cs_b, is_unitdiag, + conjtransa); + + else + ztrsm_AuXB_ref(a11, b11, m_remainder, 1, + rs_a, cs_b, is_unitdiag, + conjtransa); + } + } + } + + if ((required_packing_A == 1) && + bli_mem_is_alloc( &local_mem_buf_A_s )) + { + bli_membrk_release(&rntm, &local_mem_buf_A_s); + } + + return BLIS_SUCCESS; +} + +BLIS_INLINE err_t bli_ztrsm_small_XAutB_XAlB +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +) +{ + dim_t m = bli_obj_length(b); //number of rows + dim_t n = bli_obj_width(b); //number of columns + + bool transa = bli_obj_has_trans(a); + bool conjtransa = bli_obj_has_conj(a); + + dim_t cs_a, rs_a; + dim_t d_mr = 4,d_nr = 3; + + // Swap rs_a & cs_a in case of non-tranpose. + if(transa) + { + cs_a = bli_obj_col_stride(a); // column stride of A + rs_a = bli_obj_row_stride(a); // row stride of A + } + else + { + cs_a = bli_obj_row_stride(a); // row stride of A + rs_a = bli_obj_col_stride(a); // column stride of A + } + dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B + + dim_t i, j, k; //loop variablse + dim_t k_iter; //determines the number of GEMM operations to be done + + dcomplex ones = {1.0, 1.0}; + dcomplex zero = {0.0, 0.0}; + bool is_unitdiag = bli_obj_has_unit_diag(a); + + dcomplex AlphaVal = *(dcomplex *)AlphaObj->buffer; //value of Alpha + dcomplex* restrict L = a->buffer; //pointer to matrix A + dcomplex* restrict B = b->buffer; //pointer to matrix B + + dcomplex *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks + + gint_t required_packing_A = 1; + mem_t local_mem_buf_A_s = {0}; + dcomplex *D_A_pack = NULL; + dcomplex d11_pack[d_mr] __attribute__((aligned(64))); + rntm_t rntm; + + bli_rntm_init_from_global( &rntm ); + bli_rntm_set_num_threads_only( 1, &rntm ); + bli_membrk_rntm_set_membrk( &rntm ); + + siz_t buffer_size = bli_pool_block_size( + bli_membrk_pool( + bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), + bli_rntm_membrk(&rntm))); + + if( (d_nr * n * sizeof(dcomplex)) > buffer_size) + return BLIS_NOT_YET_IMPLEMENTED; + + if (required_packing_A == 1) + { + // Get the buffer from the pool. 
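+		/*
+		   Added note: local_mem_buf_A_s is drawn from the BLIS memory broker
+		   pool whose block size was queried into buffer_size above; as in
+		   the preceding kernel, it is expected to be released with
+		   bli_membrk_release() once all blocks of B have been solved.
+		*/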
+ bli_membrk_acquire_m(&rntm, + buffer_size, + BLIS_BITVAL_BUFFER_FOR_A_BLOCK, + &local_mem_buf_A_s); + if(FALSE==bli_mem_is_alloc(&local_mem_buf_A_s)) return BLIS_NULL_POINTER; + D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); + if(NULL==D_A_pack) return BLIS_NULL_POINTER; + } + + //ymm scratch reginsters + __m256d ymm0, ymm1, ymm2, ymm3; + __m256d ymm4, ymm5, ymm6, ymm7; + __m256d ymm8, ymm9, ymm10, ymm11; + __m256d ymm12, ymm13, ymm14, ymm15; + __m256d ymm16, ymm17, ymm18, ymm19; + + __m128d xmm5, xmm4, xmm3; + + for(j = (n-d_nr); (j+1) > 0; j -= d_nr) //loop along 'N' direction + { + a01 = L + (j*rs_a) + (j+d_nr)*cs_a; + a11 = L + (j*cs_a) + (j*rs_a); + + dim_t p_lda = (n-j-d_nr); // packed leading dimension + // perform copy of A to packed buffer D_A_pack + + if(transa) + { + /* + Pack current A block (a01) into packed buffer memory D_A_pack + a. This a10 block is used in GEMM portion only and this + a01 block size will be increasing by d_nr for every next + iteration until it reaches 3x(n-3) which is the maximum GEMM + alone block size in A + b. This packed buffer is reused to calculate all m cols of B + matrix + */ + bli_ztrsm_small_pack('R', p_lda, 1, a01, cs_a, D_A_pack, + p_lda,d_nr); + + /* + Pack 3 diagonal elements of A block into an array + a. This helps in utilze cache line efficiently in TRSM + operation + b. store ones when input is unit diagonal + */ + ztrsm_small_pack_diag_element(is_unitdiag,a11,cs_a, + d11_pack,d_nr); + } + else + { + bli_ztrsm_small_pack('R', p_lda, 0, a01, rs_a, D_A_pack, + p_lda,d_nr); + ztrsm_small_pack_diag_element(is_unitdiag,a11,rs_a, + d11_pack,d_nr); + } + + /* + a. Perform GEMM using a01, b10. + b. Perform TRSM on a11, b11 + c. This loop GEMM+TRSM loops operates with 8x6 block size + along m dimension for every d_mr columns of B10 where + packed A buffer is reused in computing all m cols of B. + d. Same approach is used in remaining fringe cases. + */ + for(i = (m-d_mr); (i+1) > 0; i -= d_mr) //loop along 'M' direction + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; + b10 = B + i + (j+d_nr)*cs_b; + b11 = B + (i) + (j)*cs_b; + + k_iter = (n-j-d_nr); + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + /* + Peform GEMM between a01 and b10 blocks + For first itteration there will be no GEMM operation + where k_iter are zero + */ + + BLIS_ZTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) + + /* + Load b11 multiply with alpha + Add the GEMM output to b11 + and peform TRSM operation. + */ + + BLIS_PRE_ZTRSM_SMALL_3x4(AlphaVal,b11,cs_b) + ///implement TRSM/// + /* + Compute 3x3 TRSM block by using GEMM block output in register + a. The 4x3 input (gemm outputs) are stored in combinations of + ymm registers + 1. ymm7, ymm8 2. ymm5, ymm6 3. ymm3, ymm4 + b. 
Towards the end do in regiser transpose of TRSM output and + store in b11 + */ + ////extract a00 + ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 2)); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm7 and ymm8 with ymm1*/ + BLIS_ZTRSM_TWO_DIV(ymm7,ymm8) +#else + /*performs dcomplex multiplication of ymm7 and ymm8 with ymm1*/ + BLIS_ZTRSM_MUL(ymm7) + BLIS_ZTRSM_MUL(ymm8) +#endif + //extract a11 + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 1)); + //(ROW1): FMA operations + ymm2 = _mm256_broadcast_pd((__m128d const *) + (a11 + cs_a*2 + rs_a*1)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + /* Step1 dcomplex multiply ymm2, ymm7 + * Step2 negate the result + * Step3 add ymmx*/ + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm7 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm7, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm7, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + + //For ymm8 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + ymm13 = _mm256_mul_pd(ymm8, ymm2); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm17 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + ymm17 = _mm256_mul_pd(ymm17, ymm15); + + //Step 3 + ymm5 = _mm256_add_pd(ymm16, ymm5); + ymm6 = _mm256_add_pd(ymm17, ymm6); + + ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a*2)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm7 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm7, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm7, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //For ymm8 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + ymm13 = _mm256_mul_pd(ymm8, ymm2); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm17 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + ymm17 = _mm256_mul_pd(ymm17, ymm15); + + //Step 3 + ymm3 = _mm256_add_pd(ymm16, ymm3); + ymm4 = _mm256_add_pd(ymm17, ymm4); + + +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm5 and ymm6 with ymm1*/ + BLIS_ZTRSM_TWO_DIV(ymm5,ymm6) +#else + /*performs dcomplex multiplication of ymm5 and ymm6 with ymm1*/ + BLIS_ZTRSM_MUL(ymm5) + BLIS_ZTRSM_MUL(ymm6) +#endif + //extract a22 + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); + + //(ROW2): FMA operations + ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + //For ymm5 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm5, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm5, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = 
_mm256_hsub_pd(ymm13, ymm14); + //For ymm6 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + ymm13 = _mm256_mul_pd(ymm6, ymm2); + ymm14 = _mm256_mul_pd(ymm6, ymm14); + ymm17 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + ymm17 = _mm256_mul_pd(ymm17, ymm15); + //Step 3 + ymm3 = _mm256_add_pd(ymm16, ymm3); + ymm4 = _mm256_add_pd(ymm17, ymm4); + + +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm3 and ymm4 with ymm1*/ + BLIS_ZTRSM_TWO_DIV(ymm3,ymm4) +#else + /*performs dcomplex multiplication of ymm3 and ymm4 with ymm1*/ + BLIS_ZTRSM_MUL(ymm3) + BLIS_ZTRSM_MUL(ymm4) +#endif + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + 2), ymm4); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm256_storeu_pd((double *)(b11 + cs_b + 2), ymm6); + _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); + _mm256_storeu_pd((double *)(b11 + cs_b*2 + 2), ymm8); + + } + dim_t m_remainder = i + d_mr; + if(m_remainder) + { + if(3 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (j*cs_a) + (j*rs_a); + b10 = B + (j+d_nr)*cs_b + (m_remainder - 3); + b11 = B + (m_remainder - 3) + (j*cs_b); + k_iter = (n-j-d_nr); + /*Fill zeros into ymm registers used in gemm + * accumulations */ + BLIS_SET_YMM_REG_ZEROS + /* + Peform GEMM between a01 and b10 blocks + For first itteration there will be no GEMM operation + where k_iter are zero + */ + + BLIS_ZTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) + + /* + Load b11 multiply with alpha + Add the GEMM output to b11 + and peform TRSM operation. + */ + + BLIS_PRE_ZTRSM_SMALL_3x4(AlphaVal,b11,cs_b) + ///implement TRSM/// + /* + Compute 3x3 TRSM block by using GEMM block output in + register + a. The 4x3 input (gemm outputs) are stored in + combinations of ymm registers + 1. ymm7, ymm8 2. ymm5, ymm6 3. ymm3, ymm4 + b. 
Towards the end do in regiser transpose of TRSM + output and store in b11 + */ + ////extract a00 + ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + ymm1 = _mm256_broadcast_pd((__m128d const *) + (d11_pack + 2)); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm7 and ymm8 with ymm1*/ + BLIS_ZTRSM_TWO_DIV(ymm7,ymm8) +#else + /*performs dcomplex multiplication of ymm7 and + * ymm8 with ymm1*/ + BLIS_ZTRSM_MUL(ymm7) + BLIS_ZTRSM_MUL(ymm8) +#endif + //extract a11 + ymm1 = _mm256_broadcast_pd((__m128d const *) + (d11_pack + 1)); + //(ROW1): FMA operations + ymm2 = _mm256_broadcast_pd((__m128d const *) + (a11 + cs_a*2 + rs_a*1)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + /* Step1 dcomplex multiply ymm2, ymm7 + * Step2 negate the result + * Step3 add ymmx*/ + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm7 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm7, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm7, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + + //For ymm8 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + ymm13 = _mm256_mul_pd(ymm8, ymm2); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm17 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + ymm17 = _mm256_mul_pd(ymm17, ymm15); + + //Step 3 + ymm5 = _mm256_add_pd(ymm16, ymm5); + ymm6 = _mm256_add_pd(ymm17, ymm6); + + ymm2 = _mm256_broadcast_pd((__m128d const *) + (a11 + cs_a*2)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm7 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm7, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm7, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //For ymm8 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + ymm13 = _mm256_mul_pd(ymm8, ymm2); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm17 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + ymm17 = _mm256_mul_pd(ymm17, ymm15); + + //Step 3 + ymm3 = _mm256_add_pd(ymm16, ymm3); + ymm4 = _mm256_add_pd(ymm17, ymm4); + + +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm5 and ymm6 with ymm1*/ + BLIS_ZTRSM_TWO_DIV(ymm5,ymm6) +#else + /*performs dcomplex multiplication of ymm5 and + * ymm6 with ymm1*/ + BLIS_ZTRSM_MUL(ymm5) + BLIS_ZTRSM_MUL(ymm6) +#endif + //extract a22 + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); + + //(ROW2): FMA operations + ymm2 = _mm256_broadcast_pd((__m128d const *) + (a11 + cs_a)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + //For ymm5 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm5, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm5, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + 
ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //For ymm6 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + ymm13 = _mm256_mul_pd(ymm6, ymm2); + ymm14 = _mm256_mul_pd(ymm6, ymm14); + ymm17 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + ymm17 = _mm256_mul_pd(ymm17, ymm15); + //Step 3 + ymm3 = _mm256_add_pd(ymm16, ymm3); + ymm4 = _mm256_add_pd(ymm17, ymm4); + + +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm3 and ymm4 with ymm1*/ + BLIS_ZTRSM_TWO_DIV(ymm3,ymm4) +#else + /*performs dcomplex multiplication of ymm3 and + * ymm4 with ymm1*/ + BLIS_ZTRSM_MUL(ymm3) + BLIS_ZTRSM_MUL(ymm4) +#endif + _mm256_storeu_pd((double *)b11, ymm3); + _mm_storeu_pd((double *)(b11 + 2), + _mm256_extractf128_pd(ymm4,0)); + + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm_storeu_pd((double *)(b11 + cs_b + 2), + _mm256_extractf128_pd(ymm6,0)); + + _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); + _mm_storeu_pd((double *)(b11 + cs_b*2 + 2), + _mm256_extractf128_pd(ymm8,0)); + m_remainder -=3; + } + else if(2 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (j*cs_a) + (j*rs_a); + b10 = B + (j+d_nr)*cs_b + (m_remainder - 2); + b11 = B + (m_remainder - 2) + (j*cs_b); + k_iter = (n-j-d_nr); + /*Fill zeros into ymm registers used in gemm + * accumulations */ + BLIS_SET_YMM_REG_ZEROS + /* + Peform GEMM between a01 and b10 blocks + For first itteration there will be no GEMM operation + where k_iter are zero + */ + + BLIS_ZTRSM_SMALL_GEMM_3nx2m(a01,b10,cs_b,p_lda,k_iter) + + /* + Load b11 of size 8x6 and multiply with alpha + Add the GEMM output to b11 + and peform TRSM operation. + */ + + BLIS_PRE_ZTRSM_SMALL_3x2(AlphaVal,b11,cs_b) + ///implement TRSM/// + /* + Compute 3x3 TRSM block by using GEMM block output + in register + a. The 4x3 input (gemm outputs) are stored in + combinations of ymm registers + 1. ymm8, ymm11 2. ymm9, ymm12 3. ymm10, ymm13 + b. 
Towards the end do in regiser transpose of TRSM + output and store in b11 + */ + ////extract a00 + ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + ymm1 = _mm256_broadcast_pd((__m128d const *) + (d11_pack + 2)); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm7 with ymm1*/ + BLIS_ZTRSM_DIV(ymm7) +#else + /*performs dcomplex multiplication of ymm7 with ymm1*/ + BLIS_ZTRSM_MUL(ymm7) +#endif + //extract a11 + ymm1 = _mm256_broadcast_pd((__m128d const *) + (d11_pack + 1)); + //(ROW1): FMA operations + ymm2 = _mm256_broadcast_pd((__m128d const *) + (a11 + cs_a*2 + rs_a*1)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + /* Step1 dcomplex multiply ymm2, ymm7 + * Step2 negate the result + * Step3 add ymmx*/ + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm7 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm7, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm7, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + + //Step 3 + ymm5 = _mm256_add_pd(ymm16, ymm5); + + ymm2 = _mm256_broadcast_pd((__m128d const *) + (a11 + cs_a*2)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm7 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm7, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm7, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + + //Step 3 + ymm3 = _mm256_add_pd(ymm16, ymm3); + + +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm5 with ymm1*/ + BLIS_ZTRSM_DIV(ymm5) +#else + /*performs dcomplex multiplication of ymm5 with ymm1*/ + BLIS_ZTRSM_MUL(ymm5) +#endif + //extract a22 + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); + + //(ROW2): FMA operations + ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + //For ymm5 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm5, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm5, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + //Step 3 + ymm3 = _mm256_add_pd(ymm16, ymm3); + + +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm3 with ymm1*/ + BLIS_ZTRSM_DIV(ymm3) +#else + /*performs dcomplex multiplication of ymm3 with ymm1*/ + BLIS_ZTRSM_MUL(ymm3) +#endif + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); + m_remainder -=2; + } + else if(1 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (j*cs_a) + (j*rs_a); + b10 = B + (j+d_nr)*cs_b + (m_remainder - 1); + b11 = B + (m_remainder - 1) + (j*cs_b); + k_iter = (n-j-d_nr); + /*Fill zeros into ymm registers used in gemm + * accumulations 
*/ + BLIS_SET_YMM_REG_ZEROS + /* + Peform GEMM between a01 and b10 blocks + For first itteration there will be no GEMM operation + where k_iter are zero + */ + + BLIS_ZTRSM_SMALL_GEMM_3nx2m(a01,b10,cs_b,p_lda,k_iter) + + /* + Load b11 and multiply with alpha + Add the GEMM output to b11 + and peform TRSM operation. + */ + + BLIS_PRE_ZTRSM_SMALL_3x2(AlphaVal,b11,cs_b) + ///implement TRSM/// + /* + Compute 3x3 TRSM block by using GEMM block output + in register + a. The 4x3 input (gemm outputs) are stored in + combinations of ymm registers + 1. ymm7, ymm8 2. ymm5, ymm6 3. ymm3, ymm4 + b. Towards the end do in regiser transpose of TRSM + output and store in + b11 + */ + ////extract a00 + ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + ymm1 = _mm256_broadcast_pd((__m128d const *) + (d11_pack + 2)); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm7 with ymm1*/ + BLIS_ZTRSM_DIV(ymm7) +#else + /*performs dcomplex multiplication of ymm7 with ymm1*/ + BLIS_ZTRSM_MUL(ymm7) +#endif + //extract a11 + ymm1 = _mm256_broadcast_pd((__m128d const *) + (d11_pack + 1)); + //(ROW1): FMA operations + ymm2 = _mm256_broadcast_pd((__m128d const *) + (a11 + cs_a*2 + rs_a*1)); + + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + /* Step1 dcomplex multiply ymm2, ymm7 + * Step2 negate the result + * Step3 add ymmx*/ + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm7 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm7, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm7, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + + //Step 3 + ymm5 = _mm256_add_pd(ymm16, ymm5); + + ymm2 = _mm256_broadcast_pd((__m128d const *) + (a11 + cs_a*2)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm7 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm7, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm7, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + + //Step 3 + ymm3 = _mm256_add_pd(ymm16, ymm3); + + +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm5 with ymm1*/ + BLIS_ZTRSM_DIV(ymm5) +#else + /*performs dcomplex multiplication of ymm5 with ymm1*/ + BLIS_ZTRSM_MUL(ymm5) +#endif + //extract a22 + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); + + //(ROW2): FMA operations + ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + //For ymm5 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm5, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm5, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + //Step 3 + ymm3 = _mm256_add_pd(ymm16, ymm3); + + +#ifndef 
BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm3 with ymm1*/ + BLIS_ZTRSM_DIV(ymm3) +#else + /*performs dcomplex multiplication of ymm3 and with ymm1*/ + BLIS_ZTRSM_MUL(ymm3) +#endif + _mm_storeu_pd((double *)b11, + _mm256_extractf128_pd(ymm3,0)); + _mm_storeu_pd((double *)(b11 + cs_b), + _mm256_extractf128_pd(ymm5,0)); + _mm_storeu_pd((double *)(b11 + cs_b*2), + _mm256_extractf128_pd(ymm7,0)); + m_remainder -=1; + } + } + + } + dim_t n_remainder = j + d_nr; + if(n_remainder == 2) + { + a01 = L + (n_remainder - 2)*rs_a + n_remainder*cs_a; + a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2)*rs_a; + + dcomplex *ptr_a10_dup = D_A_pack; + + dim_t p_lda = (n-n_remainder); + + if(transa) + { + for(dim_t x =0;x < p_lda;x+=d_nr) + { + ymm0 = _mm256_loadu_pd((double const *)(a01)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); + ymm3 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm4 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm3); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm4); + ymm0 = _mm256_loadu_pd((double const *)(a01 + 2)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a + 2)); + ymm3 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 2), + ymm3); + + ymm0 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2)); + ymm1 = _mm256_loadu_pd((double const *) + (a01 + cs_a * 2 + 2)); + ymm5 = _mm256_broadcast_pd((__m128d const *)&zero); + + ymm3 = _mm256_permute2f128_pd(ymm0,ymm5,0x20); + ymm4 = _mm256_permute2f128_pd(ymm0,ymm5,0x31); + ymm5 = _mm256_permute2f128_pd(ymm1,ymm5,0x20); + + _mm_storeu_pd((double *)(ptr_a10_dup + 2), + _mm256_extractf128_pd(ymm3,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda + 2), + _mm256_extractf128_pd(ymm4,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 2 + 2), + _mm256_extractf128_pd(ymm5, 0)); + a01 += d_nr*cs_a; + ptr_a10_dup += d_nr; + } + } + else + { + dim_t loop_count = (n-n_remainder)/2; + + for(dim_t x =0;x < loop_count;x++) + { + ymm15 = _mm256_loadu_pd((double const *) + (a01 + rs_a * 0 + x*2)); + _mm256_storeu_pd((double *) + (ptr_a10_dup + p_lda * 0 + x*2), ymm15); + ymm15 = _mm256_loadu_pd((double const *) + (a01 + rs_a * 1 + x*2)); + _mm256_storeu_pd((double *) + (ptr_a10_dup + p_lda * 1 + x*2), ymm15); + } + + dim_t remainder_loop_count = p_lda - loop_count*2; + + __m128d xmm0; + if(remainder_loop_count != 0) + { + xmm0 = _mm_loadu_pd((double const *) + (a01 + rs_a * 0 + loop_count*2)); + _mm_storeu_pd((double *) + (ptr_a10_dup + p_lda * 0 + loop_count*2), + xmm0); + xmm0 = _mm_loadu_pd((double const *) + (a01 + rs_a * 1 + loop_count*2)); + _mm_storeu_pd((double *) + (ptr_a10_dup + p_lda * 1 + loop_count*2), + xmm0); + } + } + if(!is_unitdiag) + { + if(transa) + { + ymm0 = _mm256_broadcast_pd((__m128d const *)(a11)); + ymm1 = _mm256_broadcast_pd((__m128d const *) + (a11+cs_a*1 + 1)); + } + else + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_pd((__m128d const *)(a11)); + ymm1 = _mm256_broadcast_pd((__m128d const *) + (a11+rs_a*1 + 1)); + } + ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + ymm7 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + /*Taking denomerator multiplication of real & + * imaginary components*/ + ymm4 = _mm256_mul_pd(ymm1, ymm1); + /*Swapping real & imaginary component position for addition with + * respective components*/ + ymm6 = _mm256_permute4x64_pd(ymm4, 0xb1); + ymm4 = _mm256_add_pd(ymm4, ymm6); + /*Negating imaginary component of 
numerator*/ + ymm1 = _mm256_mul_pd(ymm1, ymm7); + /*Dividing numerator by denominator*/ + ymm1 = _mm256_div_pd(ymm1, ymm4); +#endif + } + else + { + ymm1 = _mm256_broadcast_pd((__m128d const*)&ones); + } + _mm256_storeu_pd((double *)(d11_pack), ymm1); + for(i = (m-d_mr); (i+1) > 0; i -= d_mr) //loop along 'M' direction + { + a01 = D_A_pack; + a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2)*rs_a; + b10 = B + i + (n_remainder)*cs_b; + b11 = B + (i) + (n_remainder - 2)*cs_b; + + k_iter = (n-n_remainder); + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + BLIS_ZTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_PRE_ZTRSM_SMALL_2x4(AlphaVal,b11,cs_b) + ///implement TRSM/// + ////extract a00 + ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 1)); + +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm5 and ymm6 with ymm1*/ + BLIS_ZTRSM_TWO_DIV(ymm5,ymm6) +#else + /*performs dcomplex multiplication of ymm5 and ymm6 with ymm1*/ + BLIS_ZTRSM_MUL(ymm5) + BLIS_ZTRSM_MUL(ymm6) +#endif + //extract a22 + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); + + //(ROW2): FMA operations + ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + //For ymm5 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm5, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm5, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //For ymm6 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + ymm13 = _mm256_mul_pd(ymm6, ymm2); + ymm14 = _mm256_mul_pd(ymm6, ymm14); + ymm17 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + ymm17 = _mm256_mul_pd(ymm17, ymm15); + //Step 3 + ymm3 = _mm256_add_pd(ymm16, ymm3); + ymm4 = _mm256_add_pd(ymm17, ymm4); + + +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm3 and ymm4 with ymm1*/ + BLIS_ZTRSM_TWO_DIV(ymm3,ymm4) +#else + /*performs dcomplex multiplication of ymm3 and ymm4 with ymm1*/ + BLIS_ZTRSM_MUL(ymm3) + BLIS_ZTRSM_MUL(ymm4) +#endif + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + 2), ymm4); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm256_storeu_pd((double *)(b11 + cs_b + 2), ymm6); + + } + dim_t m_remainder = i + d_mr; + if(3 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2)*rs_a; + b10 = B + (m_remainder - 3) + (n_remainder)*cs_b; + b11 = B + (m_remainder - 3) + (n_remainder - 2)*cs_b; + + k_iter = (n-n_remainder); + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + /* + Peform GEMM between a01 and b10 blocks + For first itteration there will be no GEMM operation + where k_iter are zero + */ + BLIS_ZTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) + + // Load b11 and multiply with alpha + BLIS_PRE_ZTRSM_SMALL_3x4(AlphaVal,b11,cs_b) + ////extract a00 + ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 1)); + +#ifndef 
BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm5 and ymm6 with ymm1*/ + BLIS_ZTRSM_TWO_DIV(ymm5,ymm6) +#else + /*performs dcomplex multiplication of ymm5 and ymm6 with ymm1*/ + BLIS_ZTRSM_MUL(ymm5) + BLIS_ZTRSM_MUL(ymm6) +#endif + //extract a22 + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); + + //(ROW2): FMA operations + ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + //For ymm5 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm5, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm5, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //For ymm6 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + ymm13 = _mm256_mul_pd(ymm6, ymm2); + ymm14 = _mm256_mul_pd(ymm6, ymm14); + ymm17 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + ymm17 = _mm256_mul_pd(ymm17, ymm15); + //Step 3 + ymm3 = _mm256_add_pd(ymm16, ymm3); + ymm4 = _mm256_add_pd(ymm17, ymm4); + + +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm3 and ymm4 with ymm1*/ + BLIS_ZTRSM_TWO_DIV(ymm3,ymm4) +#else + /*performs dcomplex multiplication of ymm3 and ymm4 with ymm1*/ + BLIS_ZTRSM_MUL(ymm3) + BLIS_ZTRSM_MUL(ymm4) +#endif + _mm256_storeu_pd((double *)b11, ymm3); + _mm_storeu_pd((double *)(b11 + 2), + _mm256_extractf128_pd(ymm4,0)); + + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm_storeu_pd((double *)(b11 + cs_b + 2), + _mm256_extractf128_pd(ymm6,0)); + m_remainder -=3; + } + if(2 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2)*rs_a; + b10 = B + (m_remainder - 2) + (n_remainder)*cs_b; + b11 = B + (m_remainder - 2) + (n_remainder - 2)*cs_b; + + k_iter = (n-n_remainder); + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + /* + Peform GEMM between a01 and b10 blocks + For first itteration there will be no GEMM operation + where k_iter are zero + */ + BLIS_ZTRSM_SMALL_GEMM_3nx2m(a01,b10,cs_b,p_lda,k_iter) + + // Load b11 and multiply with alpha + BLIS_PRE_ZTRSM_SMALL_3x2(AlphaVal,b11,cs_b) + ////extract a00 + ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 1)); + +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm5 with ymm1*/ + BLIS_ZTRSM_DIV(ymm5) +#else + /*performs dcomplex multiplication of ymm5 with ymm1*/ + BLIS_ZTRSM_MUL(ymm5) +#endif + //extract a22 + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); + + //(ROW2): FMA operations + ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + //For ymm5 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm5, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm5, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + //Step 3 + ymm3 
= _mm256_add_pd(ymm16, ymm3); + + +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm3 with ymm1*/ + BLIS_ZTRSM_DIV(ymm3) +#else + /*performs dcomplex multiplication of ymm3 with ymm1*/ + BLIS_ZTRSM_MUL(ymm3) +#endif + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + m_remainder -=2; + } + if(1 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2)*rs_a; + b10 = B + (m_remainder - 1) + (n_remainder)*cs_b; + b11 = B + (m_remainder - 1) + (n_remainder - 2)*cs_b; + + k_iter = (n-n_remainder); + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + /* + Peform GEMM between a01 and b10 blocks + For first itteration there will be no GEMM operation + where k_iter are zero + */ + BLIS_ZTRSM_SMALL_GEMM_3nx2m(a01,b10,cs_b,p_lda,k_iter) + + // Load b11 and multiply with alpha + BLIS_PRE_ZTRSM_SMALL_3x2(AlphaVal,b11,cs_b) + ////extract a00 + ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 1)); + +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm5 with ymm1*/ + BLIS_ZTRSM_DIV(ymm5) +#else + /*performs dcomplex multiplication of ymm5 with ymm1*/ + BLIS_ZTRSM_MUL(ymm5) +#endif + //extract a22 + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); + + //(ROW2): FMA operations + ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + //For ymm5 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm5, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm5, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + //Step 3 + ymm3 = _mm256_add_pd(ymm16, ymm3); + + +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm3 with ymm1*/ + BLIS_ZTRSM_DIV(ymm3) +#else + /*performs dcomplex multiplication of ymm3 with ymm1*/ + BLIS_ZTRSM_MUL(ymm3) +#endif + _mm_storeu_pd((double *)b11, + _mm256_extractf128_pd(ymm3,0)); + _mm_storeu_pd((double *)(b11 + cs_b), + _mm256_extractf128_pd(ymm5,0)); + m_remainder -=1; + } + n_remainder -= 2; + } + else if(n_remainder == 1) + { + a01 = L + (n_remainder - 1)*rs_a + n_remainder*cs_a; + a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1)*rs_a; + + dcomplex *ptr_a10_dup = D_A_pack; + + dim_t p_lda = (n-n_remainder); // packed leading dimension + // perform copy of A to packed buffer D_A_pack + if(transa) + { + for(dim_t x =0;x < p_lda;x+=d_nr) + { + ymm0 = _mm256_loadu_pd((double const *)(a01)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); + ymm3 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm4 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm3); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm4); + + ymm0 = _mm256_loadu_pd((double const *)(a01 + 2)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a + 2)); + ymm3 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + _mm256_storeu_pd((double *) + (ptr_a10_dup + p_lda * 2), ymm3); + + ymm0 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2)); + ymm1 = _mm256_loadu_pd((double const *) + (a01 + cs_a * 2 + 2)); + ymm5 = 
_mm256_broadcast_pd((__m128d const *)&zero); + + ymm3 = _mm256_permute2f128_pd(ymm0,ymm5,0x20); + ymm4 = _mm256_permute2f128_pd(ymm0,ymm5,0x31); + ymm5 = _mm256_permute2f128_pd(ymm1,ymm5,0x20); + + _mm_storeu_pd((double *)(ptr_a10_dup + 2), + _mm256_extractf128_pd(ymm3,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda + 2), + _mm256_extractf128_pd(ymm4,0)); + _mm_storeu_pd((double *) + (ptr_a10_dup + p_lda * 2 + 2), + _mm256_extractf128_pd(ymm5, 0)); + a01 += d_nr*cs_a; + ptr_a10_dup += d_nr; + } + + } + else + { + dim_t loop_count = (n-n_remainder)/2; + + for(dim_t x =0;x < loop_count;x++) + { + ymm15 = _mm256_loadu_pd((double const *) + (a01 + rs_a * 0 + x*2)); + _mm256_storeu_pd((double *) + (ptr_a10_dup + p_lda * 0 + x*2), ymm15); + } + + dim_t remainder_loop_count = p_lda - loop_count*2; + + __m128d xmm0; + if(remainder_loop_count != 0) + { + xmm0 = _mm_loadu_pd((double const *) + (a01 + rs_a * 0 + loop_count*2)); + _mm_storeu_pd((double *) + (ptr_a10_dup + p_lda * 0 + loop_count*2), + xmm0); + } + } + if(!is_unitdiag) + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_pd((__m128d const *)(a11)); + ymm1 = _mm256_broadcast_pd((__m128d const *)&ones); + ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + ymm7 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + /*Taking denomerator multiplication of real & + * imaginary components*/ + ymm4 = _mm256_mul_pd(ymm1, ymm1); + /*Swapping real & imaginary component position for addition with + * respective components*/ + ymm6 = _mm256_permute4x64_pd(ymm4, 0xb1); + ymm4 = _mm256_add_pd(ymm4, ymm6); + /*Negating imaginary component of numerator*/ + ymm1 = _mm256_mul_pd(ymm1, ymm7); + /*Dividing numerator by denominator*/ + ymm1 = _mm256_div_pd(ymm1, ymm4); +#endif + } + else + { + ymm1 = _mm256_broadcast_pd((__m128d const*)&ones); + } + _mm256_storeu_pd((double *)(d11_pack), ymm1); + for(i = (m-d_mr); (i+1) > 0; i -= d_mr) //loop along 'M' direction + { + a01 = D_A_pack; + a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1)*rs_a; + b10 = B + i + (n_remainder)*cs_b; + b11 = B + (i) + (n_remainder - 1)*cs_b; + + k_iter = (n-n_remainder); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + ///GEMM implementation starts/// + BLIS_ZTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_PRE_ZTRSM_SMALL_1x4(b11,cs_b,AlphaVal) + ///implement TRSM/// + ////extract a00 + ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm3 and ymm4 with ymm1*/ + BLIS_ZTRSM_TWO_DIV(ymm3,ymm4) +#else + /*performs dcomplex multiplication of ymm3 and ymm4 with ymm1*/ + BLIS_ZTRSM_MUL(ymm3) + BLIS_ZTRSM_MUL(ymm4) +#endif + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + 2),ymm4); + + } + dim_t m_remainder = i + d_mr; + if(3 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1)*rs_a; + b10 = B + (m_remainder - 3) + (n_remainder)*cs_b; + b11 = B + (m_remainder - 3) + (n_remainder - 1)*cs_b; + + k_iter = (n-n_remainder); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_ZTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_PRE_ZTRSM_SMALL_1x3(b11,cs_b,AlphaVal) + + ///implement TRSM/// + ////extract a00 + ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm15 = 
_mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm3 and ymm4 with ymm1*/ + BLIS_ZTRSM_TWO_DIV(ymm3,ymm4) +#else + /*performs dcomplex multiplication of ymm3 and ymm4 with ymm1*/ + BLIS_ZTRSM_MUL(ymm3) + BLIS_ZTRSM_MUL(ymm4) +#endif + + _mm256_storeu_pd((double *)b11, ymm3); + _mm_storeu_pd((double *)(b11 + 2), + _mm256_extractf128_pd(ymm4,0)); + m_remainder -=3; + + } + else if(2 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1)*rs_a; + b10 = B + (m_remainder - 2) + (n_remainder)*cs_b; + b11 = B + (m_remainder - 2) + (n_remainder - 1)*cs_b; + + k_iter = (n-n_remainder); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_ZTRSM_SMALL_GEMM_1nx2m(a01,b10,cs_b,p_lda,k_iter) + + // Load b11 of size 2x1 and multiply with alpha + BLIS_PRE_ZTRSM_SMALL_1x2(AlphaVal,b11,cs_b) + + ///implement TRSM/// + ////extract a00 + ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm3 with ymm1*/ + BLIS_ZTRSM_DIV(ymm3) +#else + /*performs dcomplex multiplication of ymm3 with ymm1*/ + BLIS_ZTRSM_MUL(ymm3) +#endif + + _mm256_storeu_pd((double *)b11, ymm3); + m_remainder -=2; + + } + else if (1 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1)*rs_a; + b10 = B + (m_remainder - 1) + (n_remainder)*cs_b; + b11 = B + (m_remainder - 1) + (n_remainder - 1)*cs_b; + + k_iter = (n-n_remainder); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_ZTRSM_SMALL_GEMM_1nx1m(a01,b10,cs_b,p_lda,k_iter) + + // Load b11 of size 4x6 and multiply with alpha + BLIS_PRE_ZTRSM_SMALL_1x1(AlphaVal,b11,cs_b) + + ///implement TRSM/// + ////extract a00 + ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm3 with ymm1*/ + BLIS_ZTRSM_DIV(ymm3) +#else + /*performs dcomplex multiplication of ymm3 with ymm1*/ + BLIS_ZTRSM_MUL(ymm3) +#endif + _mm_storeu_pd((double *)b11, + _mm256_extractf128_pd(ymm3,0)); + m_remainder -=1; + } + n_remainder -= 1; + } + + if ((required_packing_A == 1) && + bli_mem_is_alloc( &local_mem_buf_A_s )) + { + bli_membrk_release(&rntm, &local_mem_buf_A_s); + } + + + return BLIS_SUCCESS; +} + +BLIS_INLINE err_t bli_ztrsm_small_XAltB_XAuB +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +) +{ + dim_t m = bli_obj_length(b); //number of rows + dim_t n = bli_obj_width(b); //number of columns + + bool transa = bli_obj_has_trans(a); + bool conjtransa = bli_obj_has_conj(a); + + dim_t cs_a, rs_a; + dim_t d_mr = 4,d_nr = 3; + + // Swap rs_a & cs_a in case of non-tranpose. 
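+	/*
+	 * Note on the stride swap (an illustrative aside, not kernel logic):
+	 * element A(i,j) is addressed as A[i*rs_a + j*cs_a], so exchanging
+	 * rs_a and cs_a makes the very same loads walk A transposed. That is
+	 * how the transpose and non-transpose variants share one code path.
+	 * A minimal scalar sketch, assuming a hypothetical accessor:
+	 *
+	 *   // reads op(A)(i,j); op() is selected purely by the strides passed
+	 *   static inline dcomplex z_at(const dcomplex *A, dim_t i, dim_t j,
+	 *                               dim_t rs, dim_t cs)
+	 *   { return A[i*rs + j*cs]; }  // pass (cs_a, rs_a) to read A transposed
+	 */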
+ if(transa) + { + cs_a = bli_obj_col_stride(a); // column stride of A + rs_a = bli_obj_row_stride(a); // row stride of A + } + else + { + cs_a = bli_obj_row_stride(a); // row stride of A + rs_a = bli_obj_col_stride(a); // column stride of A + } + dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B + + dim_t i, j, k; //loop variablse + dim_t k_iter; //determines the number of GEMM operations to be done + + dcomplex ones = {1.0, 1.0}; + dcomplex zero = {0.0, 0.0}; + bool is_unitdiag = bli_obj_has_unit_diag(a); + + dcomplex AlphaVal = *(dcomplex *)AlphaObj->buffer; //value of Alpha + dcomplex* restrict L = a->buffer; //pointer to matrix A + dcomplex* restrict B = b->buffer; //pointer to matrix B + + dcomplex *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks + + gint_t required_packing_A = 1; + mem_t local_mem_buf_A_s = {0}; + dcomplex *D_A_pack = NULL; + dcomplex d11_pack[d_mr] __attribute__((aligned(64))); + rntm_t rntm; + + bli_rntm_init_from_global( &rntm ); + bli_rntm_set_num_threads_only( 1, &rntm ); + bli_membrk_rntm_set_membrk( &rntm ); + + siz_t buffer_size = bli_pool_block_size( + bli_membrk_pool( + bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), + bli_rntm_membrk(&rntm))); + + if( (d_nr * n * sizeof(dcomplex)) > buffer_size) + return BLIS_NOT_YET_IMPLEMENTED; + + if (required_packing_A == 1) + { + // Get the buffer from the pool. + bli_membrk_acquire_m(&rntm, + buffer_size, + BLIS_BITVAL_BUFFER_FOR_A_BLOCK, + &local_mem_buf_A_s); + if(FALSE==bli_mem_is_alloc(&local_mem_buf_A_s)) return BLIS_NULL_POINTER; + D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); + if(NULL==D_A_pack) return BLIS_NULL_POINTER; + } + + //ymm scratch reginsters + __m256d ymm0, ymm1, ymm2, ymm3; + __m256d ymm4, ymm5, ymm6, ymm7; + __m256d ymm8, ymm9, ymm10, ymm11; + __m256d ymm12, ymm13, ymm14, ymm15; + __m256d ymm16, ymm17, ymm18, ymm19; + + __m128d xmm5, xmm4, xmm3; + + for(j = 0; (j+d_nr-1) < n; j += d_nr) //loop along 'N' direction + { + a01 = L + j*rs_a;//pointer to block of A to be used in GEMM + a11 = L + j*cs_a + j*rs_a;//pointer to block of A to be used for TRSM + + dim_t p_lda = j; // packed leading dimension + // perform copy of A to packed buffer D_A_pack + + if(transa) + { + /* + Pack current A block (a01) into packed buffer memory D_A_pack + a. This a10 block is used in GEMM portion only and this + a01 block size will be increasing by d_nr for every next + iteration until it reaches 3x(n-3) which is the maximum GEMM + alone block size in A + b. This packed buffer is reused to calculate all m cols of + B matrix + */ + bli_ztrsm_small_pack('R', j, 1, a01, cs_a, D_A_pack, p_lda,d_nr); + + /* + Pack 3 diagonal elements of A block into an array + a. This helps in utilze cache line efficiently in TRSM + operation + b. store ones when input is unit diagonal + */ + ztrsm_small_pack_diag_element(is_unitdiag,a11,cs_a, + d11_pack,d_nr); + } + else + { + bli_ztrsm_small_pack('R', j, 0, a01, rs_a, D_A_pack, + p_lda,d_nr); + ztrsm_small_pack_diag_element(is_unitdiag,a11,rs_a, + d11_pack,d_nr); + } + + /* + a. Perform GEMM using a01, b10. + b. Perform TRSM on a11, b11 + c. This loop GEMM+TRSM loops operates with 8x6 block size + along m dimension for every d_mr columns of B10 where + packed A buffer is reused in computing all m cols of B. + d. Same approach is used in remaining fringe cases. 
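+          e. In outline (a rough sketch of this routine, using the d_mr x d_nr
+             block sizes declared above, i.e. 4x3 for dcomplex):
+
+               for each column block j of B, in steps of d_nr:
+                   pack a01 (off-diagonal panel of A, GEMM depth = j) into D_A_pack
+                   for each row block i of B, in steps of d_mr:
+                       b11 = alpha*b11 - b10*a01     (GEMM update, k_iter = j)
+                       b11 = b11 * inv(a11)          (TRSM with the d_nr x d_nr
+                                                      diagonal block)
+                   m-remainders of 3/2/1 rows are handled the same way
+               n-remainders of 2/1 columns are handled after this loop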
+ */ + for(i = 0; (i+d_mr-1) < m; i += d_mr) //loop along 'M' direction + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; + b10 = B + i; + b11 = B + i + j*cs_b; + + k_iter = j; + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + /* + Peform GEMM between a01 and b10 blocks + For first itteration there will be no GEMM operation + where k_iter are zero + */ + + BLIS_ZTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) + + /* + Load b11 of size 4x3 and multiply with alpha + Add the GEMM output to b11 + and peform TRSM operation. + */ + + BLIS_PRE_ZTRSM_SMALL_3x4(AlphaVal,b11,cs_b) + ///implement TRSM/// + /* + Compute 3x3 TRSM block by using GEMM block output in register + a. The 3x4 input (gemm outputs) are stored in combinations of + ymm registers + 1. ymm3, ymm4 2. ymm5, ymm6 3. ymm7, ymm8 + b. Towards the end do in regiser transpose of TRSM output + and store in b11 + */ + ////extract a00 + ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm3 and ymm4 with ymm1*/ + BLIS_ZTRSM_TWO_DIV(ymm3,ymm4) +#else + /*performs dcomplex multiplication of ymm3 and ymm4 with ymm1*/ + BLIS_ZTRSM_MUL(ymm3) + BLIS_ZTRSM_MUL(ymm4) +#endif + //extract a11 + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 1)); + //(ROW1): FMA operations + ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + rs_a*1)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + /* Step1 dcomplex multiply ymm2, ymm3 + * Step2 negate the result + * Step3 add ymmx*/ + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm3 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm3, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm3, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + + //For ymm4 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + ymm13 = _mm256_mul_pd(ymm4, ymm2); + ymm14 = _mm256_mul_pd(ymm4, ymm14); + ymm17 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + ymm17 = _mm256_mul_pd(ymm17, ymm15); + + //Step 3 + ymm5 = _mm256_add_pd(ymm16, ymm5); + ymm6 = _mm256_add_pd(ymm17, ymm6); + + ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + rs_a*2)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm3 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm3, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm3, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //For ymm4 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + ymm13 = _mm256_mul_pd(ymm4, ymm2); + ymm14 = _mm256_mul_pd(ymm4, ymm14); + ymm17 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + ymm17 = _mm256_mul_pd(ymm17, ymm15); + + //Step 3 + ymm7 = _mm256_add_pd(ymm16, ymm7); + ymm8 = _mm256_add_pd(ymm17, ymm8); + + +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs 
dcomplex divison of ymm5 and ymm6 with ymm1*/ + BLIS_ZTRSM_TWO_DIV(ymm5,ymm6) +#else + /*performs dcomplex multiplication of ymm5 and ymm6 with ymm1*/ + BLIS_ZTRSM_MUL(ymm5) + BLIS_ZTRSM_MUL(ymm6) +#endif + a11 += cs_a; + + //extract a22 + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 2)); + //(ROW2): FMA operations + ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + rs_a * 2)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + //For ymm5 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm5, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm5, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //For ymm6 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + ymm13 = _mm256_mul_pd(ymm6, ymm2); + ymm14 = _mm256_mul_pd(ymm6, ymm14); + ymm17 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + ymm17 = _mm256_mul_pd(ymm17, ymm15); + //Step 3 + ymm7 = _mm256_add_pd(ymm16, ymm7); + ymm8 = _mm256_add_pd(ymm17, ymm8); + + +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm7 and ymm8 with ymm1*/ + BLIS_ZTRSM_TWO_DIV(ymm7,ymm8) +#else + /*performs dcomplex multiplication of ymm7 and ymm8 with ymm1*/ + BLIS_ZTRSM_MUL(ymm7) + BLIS_ZTRSM_MUL(ymm8) +#endif + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + 2), ymm4); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm256_storeu_pd((double *)(b11 + cs_b + 2), ymm6); + _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); + _mm256_storeu_pd((double *)(b11 + cs_b*2 + 2), ymm8); + + } + + dim_t m_remainder = m - i; + if(m_remainder) + { + if(m_remainder == 3) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; + b10 = B + i; + b11 = B + i + j*cs_b; + + k_iter = j; + + /*Fill zeros into ymm registers used in gemm + * accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_ZTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) + + // Load b11 of size 4x6 and multiply with alpha + BLIS_PRE_ZTRSM_SMALL_3x4(AlphaVal,b11,cs_b) + + ///implement TRSM/// + ////extract a00 + ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm3 and ymm4 with ymm1*/ + BLIS_ZTRSM_TWO_DIV(ymm3,ymm4) +#else + /*performs dcomplex multiplication of ymm3 and ymm4 + * with ymm1*/ + BLIS_ZTRSM_MUL(ymm3) + BLIS_ZTRSM_MUL(ymm4) +#endif + //extract a11 + ymm1 = _mm256_broadcast_pd((__m128d const *) + (d11_pack + 1)); + //(ROW1): FMA operations + ymm2 = _mm256_broadcast_pd((__m128d const *) + (a11 + rs_a*1)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + /* Step1 dcomplex multiply ymm2, ymm3 + * Step2 negate the result + * Step3 add ymmx*/ + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm3 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm3, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm3, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + + //For ymm4 + 
ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + ymm13 = _mm256_mul_pd(ymm4, ymm2); + ymm14 = _mm256_mul_pd(ymm4, ymm14); + ymm17 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + ymm17 = _mm256_mul_pd(ymm17, ymm15); + + //Step 3 + ymm5 = _mm256_add_pd(ymm16, ymm5); + ymm6 = _mm256_add_pd(ymm17, ymm6); + + ymm2 = _mm256_broadcast_pd((__m128d const *) + (a11 + rs_a*2)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm3 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm3, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm3, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //For ymm4 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + ymm13 = _mm256_mul_pd(ymm4, ymm2); + ymm14 = _mm256_mul_pd(ymm4, ymm14); + ymm17 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + ymm17 = _mm256_mul_pd(ymm17, ymm15); + + //Step 3 + ymm7 = _mm256_add_pd(ymm16, ymm7); + ymm8 = _mm256_add_pd(ymm17, ymm8); + + +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm5 and ymm6 with ymm1*/ + BLIS_ZTRSM_TWO_DIV(ymm5,ymm6) +#else + /*performs dcomplex multiplication of ymm5 and ymm6 with + * ymm1*/ + BLIS_ZTRSM_MUL(ymm5) + BLIS_ZTRSM_MUL(ymm6) +#endif + a11 += cs_a; + + //extract a22 + ymm1 = _mm256_broadcast_pd((__m128d const *) + (d11_pack + 2)); + //(ROW2): FMA operations + ymm2 = _mm256_broadcast_pd((__m128d const *) + (a11 + rs_a * 2)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + //For ymm5 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm5, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm5, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //For ymm6 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + ymm13 = _mm256_mul_pd(ymm6, ymm2); + ymm14 = _mm256_mul_pd(ymm6, ymm14); + ymm17 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + ymm17 = _mm256_mul_pd(ymm17, ymm15); + //Step 3 + ymm7 = _mm256_add_pd(ymm16, ymm7); + ymm8 = _mm256_add_pd(ymm17, ymm8); + + +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm7 and ymm8 with ymm1*/ + BLIS_ZTRSM_TWO_DIV(ymm7,ymm8) +#else + /*performs dcomplex multiplication of ymm7 and ymm8 + * with ymm1*/ + BLIS_ZTRSM_MUL(ymm7) + BLIS_ZTRSM_MUL(ymm8) +#endif + + _mm256_storeu_pd((double *)b11, ymm3); + _mm_storeu_pd((double *)(b11 + 2), + _mm256_extractf128_pd(ymm4,0)); + + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm_storeu_pd((double *)(b11 + cs_b + 2), + _mm256_extractf128_pd(ymm6,0)); + + _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); + _mm_storeu_pd((double *)(b11 + cs_b*2 + 2), + _mm256_extractf128_pd(ymm8,0)); + + m_remainder -= 3; + i += 3; + } + else if(m_remainder == 2) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; + b10 = B + i; + b11 = B + i + j*cs_b; + 
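+				/*
+				 * Block roles in this fringe branch (same meaning as in the
+				 * full d_mr x d_nr path above):
+				 *   a01 - packed off-diagonal panel of A, used in the GEMM update
+				 *   a11 - d_nr x d_nr diagonal block of A, used in the TRSM solve
+				 *   b10 - the j already-solved columns of this row block of B
+				 *   b11 - the block being solved now (2 rows here, d_nr columns)
+				 */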
+ k_iter = j; + + /*Fill zeros into ymm registers used in gemm + * accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_ZTRSM_SMALL_GEMM_3nx2m(a01,b10,cs_b,p_lda,k_iter) + + // Load b11 of size 4x6 and multiply with alpha + BLIS_PRE_ZTRSM_SMALL_3x2(AlphaVal,b11,cs_b) + ////extract a00 + ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm3 with ymm1*/ + BLIS_ZTRSM_DIV(ymm3) +#else + /*performs dcomplex multiplication of ymm3 + * with ymm1*/ + BLIS_ZTRSM_MUL(ymm3) +#endif + //extract a11 + ymm1 = _mm256_broadcast_pd((__m128d const *) + (d11_pack + 1)); + //(ROW1): FMA operations + ymm2 = _mm256_broadcast_pd((__m128d const *) + (a11 + rs_a*1)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + /* Step1 dcomplex multiply ymm2, ymm3 + * Step2 negate the result + * Step3 add ymmx*/ + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm3 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm3, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm3, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + + //Step 3 + ymm5 = _mm256_add_pd(ymm16, ymm5); + + ymm2 = _mm256_broadcast_pd((__m128d const *) + (a11 + rs_a*2)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm3 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm3, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm3, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + + //Step 3 + ymm7 = _mm256_add_pd(ymm16, ymm7); + +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm5 with ymm1*/ + BLIS_ZTRSM_DIV(ymm5) +#else + /*performs dcomplex multiplication of ymm5 + * with ymm1*/ + BLIS_ZTRSM_MUL(ymm5) +#endif + a11 += cs_a; + + //extract a22 + ymm1 = _mm256_broadcast_pd((__m128d const *) + (d11_pack + 2)); + //(ROW2): FMA operations + ymm2 = _mm256_broadcast_pd((__m128d const *) + (a11 + rs_a * 2)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + //For ymm5 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm5, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm5, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + //Step 3 + ymm7 = _mm256_add_pd(ymm16, ymm7); + +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm7 with ymm1*/ + BLIS_ZTRSM_DIV(ymm7) +#else + /*performs dcomplex multiplication of ymm7 + * with ymm1*/ + BLIS_ZTRSM_MUL(ymm7) +#endif + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); + m_remainder -= 2; + i += 2; + } + else 
if(m_remainder == 1) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; + b10 = B + i; + b11 = B + i + j*cs_b; + + k_iter = j; + + /*Fill zeros into ymm registers used in gemm + * accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_ZTRSM_SMALL_GEMM_3nx2m(a01,b10,cs_b,p_lda,k_iter) + + // Load b11 of size 2x3 and multiply with alpha + BLIS_PRE_ZTRSM_SMALL_3x2(AlphaVal,b11,cs_b) + + ///implement TRSM/// + ////extract a00 + ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm3 with ymm1*/ + BLIS_ZTRSM_DIV(ymm3) +#else + /*performs dcomplex multiplication of ymm3 + * with ymm1*/ + BLIS_ZTRSM_MUL(ymm3) +#endif + //extract a11 + ymm1 = _mm256_broadcast_pd((__m128d const *) + (d11_pack + 1)); + //(ROW1): FMA operations + ymm2 = _mm256_broadcast_pd((__m128d const *) + (a11 + rs_a*1)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + /* Step1 dcomplex multiply ymm2, ymm3 + * Step2 negate the result + * Step3 add ymmx*/ + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm3 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm3, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm3, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + + //Step 3 + ymm5 = _mm256_add_pd(ymm16, ymm5); + + ymm2 = _mm256_broadcast_pd((__m128d const *) + (a11 + rs_a*2)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm3 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm3, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm3, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + + //Step 3 + ymm7 = _mm256_add_pd(ymm16, ymm7); + +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm5 with ymm1*/ + BLIS_ZTRSM_DIV(ymm5) +#else + /*performs dcomplex multiplication of ymm5 + * with ymm1*/ + BLIS_ZTRSM_MUL(ymm5) +#endif + a11 += cs_a; + + //extract a22 + ymm1 = _mm256_broadcast_pd((__m128d const *) + (d11_pack + 2)); + //(ROW2): FMA operations + ymm2 = _mm256_broadcast_pd((__m128d const *) + (a11 + rs_a * 2)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + //For ymm5 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm5, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm5, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + //Step 3 + ymm7 = _mm256_add_pd(ymm16, ymm7); + +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm7 with ymm1*/ + BLIS_ZTRSM_DIV(ymm7) +#else + /*performs dcomplex multiplication of ymm7 + * with ymm1*/ + BLIS_ZTRSM_MUL(ymm7) +#endif + + + _mm_storeu_pd((double *)b11, + 
_mm256_extractf128_pd(ymm3,0)); + _mm_storeu_pd((double *)(b11 + cs_b), + _mm256_extractf128_pd(ymm5,0)); + _mm_storeu_pd((double *)(b11 + cs_b*2), + _mm256_extractf128_pd(ymm7,0)); + + m_remainder -= 1; + i += 1; + } + } + + } + dim_t n_remainder = n - j; + if(n_remainder == 2) + { + a01 = L + j*rs_a; + a11 = L + j*cs_a + j*rs_a; + dcomplex *ptr_a10_dup = D_A_pack; + + dim_t p_lda = j; + // perform copy of A to packed buffer D_A_pack + + if(transa) + { + for(dim_t x =0;x < p_lda;x+=d_nr) + { + ymm0 = _mm256_loadu_pd((double const *)(a01)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); + ymm3 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm4 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm3); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm4); + + ymm0 = _mm256_loadu_pd((double const *)(a01 + 2)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a + 2)); + ymm3 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 2), + ymm3); + + ymm0 = _mm256_loadu_pd((double const *) + (a01 + cs_a * 2)); + ymm1 = _mm256_loadu_pd((double const *) + (a01 + cs_a * 2 + 2)); + ymm5 = _mm256_broadcast_pd((__m128d const *)&zero); + + ymm3 = _mm256_permute2f128_pd(ymm0,ymm5,0x20); + ymm4 = _mm256_permute2f128_pd(ymm0,ymm5,0x31); + ymm5 = _mm256_permute2f128_pd(ymm1,ymm5,0x20); + + _mm_storeu_pd((double *)(ptr_a10_dup + 2), + _mm256_extractf128_pd(ymm3,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda + 2), + _mm256_extractf128_pd(ymm4,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 2 + 2), + _mm256_extractf128_pd(ymm5, 0)); + a01 += d_nr*cs_a; + ptr_a10_dup += d_nr; + } + } + else + { + dim_t loop_count = p_lda/2; + + for(dim_t x =0;x < loop_count;x++) + { + ymm15 = _mm256_loadu_pd((double const *) + (a01 + rs_a * 0 + x*2)); + _mm256_storeu_pd((double *) + (ptr_a10_dup + p_lda * 0 + x*2), ymm15); + ymm15 = _mm256_loadu_pd((double const *) + (a01 + rs_a * 1 + x*2)); + _mm256_storeu_pd((double *) + (ptr_a10_dup + p_lda * 1 + x*2), + ymm15); + } + + dim_t remainder_loop_count = p_lda - loop_count*2; + + __m128d xmm0; + if(remainder_loop_count != 0) + { + xmm0 = _mm_loadu_pd((double const *) + (a01 + rs_a * 0 + loop_count*2)); + _mm_storeu_pd((double *) + (ptr_a10_dup + p_lda * 0 + loop_count*2), + xmm0); + xmm0 = _mm_loadu_pd((double const *) + (a01 + rs_a * 1 + loop_count*2)); + _mm_storeu_pd((double *) + (ptr_a10_dup + p_lda * 1 + loop_count*2), + xmm0); + } + } + if(!is_unitdiag) + { + if(transa) + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_pd((__m128d const *)(a11)); + ymm1 = _mm256_broadcast_pd((__m128d const *) + (a11+cs_a*1 + 1)); + } + else + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_pd((__m128d const *)(a11)); + ymm1 = _mm256_broadcast_pd((__m128d const *) + (a11+rs_a*1 + 1)); + } + ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + ymm7 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + /*Taking denomerator multiplication of real & + * imaginary components*/ + ymm4 = _mm256_mul_pd(ymm1, ymm1); + /*Swapping real & imaginary component position for addition with + * respective components*/ + ymm6 = _mm256_permute4x64_pd(ymm4, 0xb1); + ymm4 = _mm256_add_pd(ymm4, ymm6); + /*Negating imaginary component of numerator*/ + ymm1 = _mm256_mul_pd(ymm1, ymm7); + /*Dividing numerator by denominator*/ + ymm1 = _mm256_div_pd(ymm1, ymm4); +#endif + } + else + { + ymm1 = _mm256_broadcast_pd((__m128d const *)&ones); + } + 
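+		/*
+		 * For reference, what d11_pack receives just below: with
+		 * BLIS_ENABLE_TRSM_PREINVERSION each diagonal entry d = re + i*im is
+		 * stored pre-inverted as 1/d = conj(d)/(re^2 + im^2)
+		 *                            = (re - i*im)/(re^2 + im^2),
+		 * so the solve multiplies via BLIS_ZTRSM_MUL; without pre-inversion
+		 * the diagonal is stored as-is and BLIS_ZTRSM_DIV divides instead.
+		 */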
_mm256_storeu_pd((double *)(d11_pack), ymm1); + + for(i = 0; (i+d_mr-1) < m; i += d_mr) //loop along 'M' direction + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; + b10 = B + i; + b11 = B + i + j*cs_b; + + k_iter = j; + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + ///GEMM implementation starts/// + BLIS_ZTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_PRE_ZTRSM_SMALL_2x4(AlphaVal,b11,cs_b) + ///implement TRSM/// + ////extract a00 + ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm3 and ymm4 with ymm1*/ + BLIS_ZTRSM_TWO_DIV(ymm3,ymm4) +#else + /*performs dcomplex multiplication of ymm3 and ymm4 with ymm1*/ + BLIS_ZTRSM_MUL(ymm3) + BLIS_ZTRSM_MUL(ymm4) +#endif + //extract a11 + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 1)); + //(ROW1): FMA operations + ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + rs_a*1)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + /* Step1 dcomplex multiply ymm2, ymm3 + * Step2 negate the result + * Step3 add ymmx*/ + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm3 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm3, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm3, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + + //For ymm4 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + ymm13 = _mm256_mul_pd(ymm4, ymm2); + ymm14 = _mm256_mul_pd(ymm4, ymm14); + ymm17 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + ymm17 = _mm256_mul_pd(ymm17, ymm15); + + //Step 3 + ymm5 = _mm256_add_pd(ymm16, ymm5); + ymm6 = _mm256_add_pd(ymm17, ymm6); + +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm5 and ymm6 with ymm1*/ + BLIS_ZTRSM_TWO_DIV(ymm5,ymm6) +#else + /*performs dcomplex multiplication of ymm5 and ymm6 with ymm1*/ + BLIS_ZTRSM_MUL(ymm5) + BLIS_ZTRSM_MUL(ymm6) +#endif + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + 2), ymm4); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm256_storeu_pd((double *)(b11 + cs_b + 2), ymm6); + } + dim_t m_remainder = m - i; + if(m_remainder == 3) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; + b10 = B + i; + b11 = B + i + j*cs_b; + + k_iter = j; + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_ZTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) + + // Load b11 of size 4x6 and multiply with alpha + BLIS_PRE_ZTRSM_SMALL_3x4(AlphaVal,b11,cs_b) + + ///implement TRSM/// + ////extract a00 + ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm3 and ymm4 with ymm1*/ + BLIS_ZTRSM_TWO_DIV(ymm3,ymm4) +#else + /*performs dcomplex multiplication of ymm3 and ymm4 with ymm1*/ + BLIS_ZTRSM_MUL(ymm3) + BLIS_ZTRSM_MUL(ymm4) +#endif + //extract a11 + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 1)); + //(ROW1): FMA operations + ymm2 = 
_mm256_broadcast_pd((__m128d const *)(a11 + rs_a*1)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + /* Step1 dcomplex multiply ymm2, ymm3 + * Step2 negate the result + * Step3 add ymmx*/ + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm3 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm3, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm3, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + + //For ymm4 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + ymm13 = _mm256_mul_pd(ymm4, ymm2); + ymm14 = _mm256_mul_pd(ymm4, ymm14); + ymm17 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + ymm17 = _mm256_mul_pd(ymm17, ymm15); + + //Step 3 + ymm5 = _mm256_add_pd(ymm16, ymm5); + ymm6 = _mm256_add_pd(ymm17, ymm6); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm5 and ymm6 with ymm1*/ + BLIS_ZTRSM_TWO_DIV(ymm5,ymm6) +#else + /*performs dcomplex multiplication of ymm5 and ymm6 with ymm1*/ + BLIS_ZTRSM_MUL(ymm5) + BLIS_ZTRSM_MUL(ymm6) +#endif + + _mm256_storeu_pd((double *)b11, ymm3); + _mm_storeu_pd((double *)(b11 + 2), + _mm256_extractf128_pd(ymm4,0)); + + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm_storeu_pd((double *)(b11 + cs_b + 2), + _mm256_extractf128_pd(ymm6,0)); + m_remainder -= 3; + i += 3; + } + if(m_remainder == 2) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; + b10 = B + i; + b11 = B + i + j*cs_b; + + k_iter = j; + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_ZTRSM_SMALL_GEMM_2nx2m(a01,b10,cs_b,p_lda,k_iter) + + // Load b11 of size 4x6 and multiply with alpha + BLIS_PRE_ZTRSM_SMALL_2x2(AlphaVal,b11,cs_b) + + ///implement TRSM/// + ////extract a00 + ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm3 with ymm1*/ + BLIS_ZTRSM_DIV(ymm3) +#else + /*performs dcomplex multiplication of ymm3 with ymm1*/ + BLIS_ZTRSM_MUL(ymm3) +#endif + //extract a11 + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 1)); + //(ROW1): FMA operations + ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + rs_a*1)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + /* Step1 dcomplex multiply ymm2, ymm3 + * Step2 negate the result + * Step3 add ymmx*/ + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm3 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm3, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm3, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + + //Step 3 + ymm5 = _mm256_add_pd(ymm16, ymm5); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm5 with ymm1*/ + BLIS_ZTRSM_DIV(ymm5) +#else + /*performs dcomplex multiplication of ymm5 with ymm1*/ + BLIS_ZTRSM_MUL(ymm5) +#endif + + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + m_remainder 
-= 2; + i += 2; + } + if(m_remainder == 1) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; + b10 = B + i; + b11 = B + i + j*cs_b; + + k_iter = j; + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_ZTRSM_SMALL_GEMM_2nx2m(a01,b10,cs_b,p_lda,k_iter) + + // Load b11 of size 4x6 and multiply with alpha + BLIS_PRE_ZTRSM_SMALL_2x2(AlphaVal,b11,cs_b) + + ///implement TRSM/// + ////extract a00 + ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm3 with ymm1*/ + BLIS_ZTRSM_DIV(ymm3) +#else + /*performs dcomplex multiplication of ymm3 with ymm1*/ + BLIS_ZTRSM_MUL(ymm3) +#endif + //extract a11 + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 1)); + //(ROW1): FMA operations + ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + rs_a*1)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + /* Step1 dcomplex multiply ymm2, ymm3 + * Step2 negate the result + * Step3 add ymmx*/ + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm3 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm3, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm3, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + + //Step 3 + ymm5 = _mm256_add_pd(ymm16, ymm5); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm5 with ymm1*/ + BLIS_ZTRSM_DIV(ymm5) +#else + /*performs dcomplex multiplication of ymm5 with ymm1*/ + BLIS_ZTRSM_MUL(ymm5) +#endif + _mm_storeu_pd((double *)b11, + _mm256_extractf128_pd(ymm3,0)); + _mm_storeu_pd((double *)(b11 + cs_b), + _mm256_extractf128_pd(ymm5,0)); + m_remainder -= 1; + i += 1; + } + j += 2; + n_remainder -= 2; + } + else if(n_remainder == 1) + { + a01 = L + j*rs_a; + a11 = L + j*cs_a + j*rs_a; + dcomplex *ptr_a10_dup = D_A_pack; + dim_t p_lda = j; + // perform copy of A to packed buffer D_A_pack + + if(transa) + { + for(dim_t x =0;x < p_lda;x+=d_nr) + { + ymm0 = _mm256_loadu_pd((double const *)(a01)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); + ymm3 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm4 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm3); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm4); + + ymm0 = _mm256_loadu_pd((double const *)(a01 + 2)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a + 2)); + ymm3 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 2), + ymm3); + + ymm0 = _mm256_loadu_pd((double const *) + (a01 + cs_a * 2)); + ymm1 = _mm256_loadu_pd((double const *) + (a01 + cs_a * 2 + 2)); + ymm5 = _mm256_broadcast_pd((__m128d const *)&zero); + + ymm3 = _mm256_permute2f128_pd(ymm0,ymm5,0x20); + ymm4 = _mm256_permute2f128_pd(ymm0,ymm5,0x31); + ymm5 = _mm256_permute2f128_pd(ymm1,ymm5,0x20); + + _mm_storeu_pd((double *)(ptr_a10_dup + 2), + _mm256_extractf128_pd(ymm3,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda + 2), + _mm256_extractf128_pd(ymm4,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 2 + 2), + _mm256_extractf128_pd(ymm5, 0)); + a01 += d_nr*cs_a; + ptr_a10_dup += d_nr; + } + + } + else + { + dim_t loop_count 
= p_lda/2; + + for(dim_t x =0;x < loop_count;x++) + { + ymm15 = _mm256_loadu_pd((double const *) + (a01 + rs_a * 0 + x*2)); + _mm256_storeu_pd((double *) + (ptr_a10_dup + p_lda * 0 + x*2), ymm15); + } + + dim_t remainder_loop_count = p_lda - loop_count*2; + + __m128d xmm0; + if(remainder_loop_count != 0) + { + xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 0 + + loop_count*2)); + + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + + loop_count*2), xmm0); + } + } + if(!is_unitdiag) + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_pd((__m128d const *)(a11)); + ymm1 = _mm256_broadcast_pd((__m128d const *)&ones); + ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + ymm7 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + /*Taking denomerator multiplication of real & + * imaginary components*/ + ymm4 = _mm256_mul_pd(ymm1, ymm1); + /*Swapping real & imaginary component position for addition with + * respective components*/ + ymm6 = _mm256_permute4x64_pd(ymm4, 0xb1); + ymm4 = _mm256_add_pd(ymm4, ymm6); + /*Negating imaginary component of numerator*/ + ymm1 = _mm256_mul_pd(ymm1, ymm7); + /*Dividing numerator by denominator*/ + ymm1 = _mm256_div_pd(ymm1, ymm4); +#endif + } + else + { + ymm1 = _mm256_broadcast_pd((__m128d const *)&ones); + } + _mm256_storeu_pd((double *)(d11_pack), ymm1); + + for(i = 0; (i+d_mr-1) < m; i += d_mr) //loop along 'M' direction + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; + b10 = B + i; + b11 = B + i + j*cs_b; + + k_iter = j; + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + ///GEMM implementation starts/// + BLIS_ZTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_PRE_ZTRSM_SMALL_1x4(b11,cs_b,AlphaVal) + ///implement TRSM/// + ////extract a00 + ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm3 and ymm4 with ymm1*/ + BLIS_ZTRSM_TWO_DIV(ymm3,ymm4) +#else + /*performs dcomplex multiplication of ymm3 and ymm4 with ymm1*/ + BLIS_ZTRSM_MUL(ymm3) + BLIS_ZTRSM_MUL(ymm4) +#endif + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + 2),ymm4); + } + dim_t m_remainder = m - i; + if(m_remainder == 3) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; + b10 = B + i; + b11 = B + i + j*cs_b; + + k_iter = j; + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_ZTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_PRE_ZTRSM_SMALL_1x3(b11,cs_b,AlphaVal) + + ///implement TRSM/// + ////extract a00 + ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm3 and ymm4 with ymm1*/ + BLIS_ZTRSM_TWO_DIV(ymm3,ymm4) +#else + /*performs dcomplex multiplication of ymm3 and ymm4 with ymm1*/ + BLIS_ZTRSM_MUL(ymm3) + BLIS_ZTRSM_MUL(ymm4) +#endif + + _mm256_storeu_pd((double *)b11, ymm3); + _mm_storeu_pd((double *)(b11 + 2), + _mm256_extractf128_pd(ymm4,0)); + m_remainder -= 3; + i += 3; + } + if(m_remainder == 2) + { + a01 = D_A_pack; + //pointer to block of A to be used for TRSM + a11 = L + j*cs_a + j*rs_a; + //pointer to block of B to be used in GEMM + b10 = B + i; + //pointer to block of B to be used for TRSM + b11 = B + i + j*cs_b; + //number of GEMM 
operations to be done + k_iter = j; + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_ZTRSM_SMALL_GEMM_1nx2m(a01,b10,cs_b,p_lda,k_iter) + + // Load b11 of size 4x6 and multiply with alpha + BLIS_PRE_ZTRSM_SMALL_1x2(AlphaVal,b11,cs_b) + + ///implement TRSM/// + ////extract a00 + ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm3 with ymm1*/ + BLIS_ZTRSM_DIV(ymm3) +#else + /*performs dcomplex multiplication of ymm3 with ymm1*/ + BLIS_ZTRSM_MUL(ymm3) +#endif + + _mm256_storeu_pd((double *)b11, ymm3); + m_remainder -= 2; + i += 2; + } + if(m_remainder == 1) + { + a01 = D_A_pack; + //pointer to block of A to be used for TRSM + a11 = L + j*cs_a + j*rs_a; + //pointer to block of B to be used in GEMM + b10 = B + i; + //pointer to block of B to be used for TRSM + b11 = B + i + j*cs_b; + + //number of GEMM operations to be done(in blocks of 4x4) + k_iter = j; + + /*Fill zeros into ymm registers used in gemm accumulations*/ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_ZTRSM_SMALL_GEMM_1nx1m(a01,b10,cs_b,p_lda,k_iter) + + // Load b11 of size 4x6 and multiply with alpha + BLIS_PRE_ZTRSM_SMALL_1x1(AlphaVal,b11,cs_b) + + ///implement TRSM/// + ////extract a00 + ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm3 with ymm1*/ + BLIS_ZTRSM_DIV(ymm3) +#else + /*performs dcomplex multiplication of ymm3 with ymm1*/ + BLIS_ZTRSM_MUL(ymm3) +#endif + _mm_storeu_pd((double *)b11, + _mm256_extractf128_pd(ymm3,0)); + m_remainder -= 1; + i += 1; + } + j += 1; + n_remainder -= 1; + } + + if ((required_packing_A == 1) && + bli_mem_is_alloc( &local_mem_buf_A_s )) + { + bli_membrk_release(&rntm, &local_mem_buf_A_s); + } + + + return BLIS_SUCCESS; +} + +BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +) +{ + return BLIS_SUCCESS; +} + +BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +) +{ + return BLIS_SUCCESS; +} + +BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +) +{ + return BLIS_SUCCESS; +} + +BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +) +{ + return BLIS_SUCCESS; +} + +BLIS_INLINE err_t bli_strsm_small_AutXB_AlXB +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +) +{ + return BLIS_SUCCESS; +} + +BLIS_INLINE err_t bli_strsm_small_AltXB_AuXB +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +) +{ + return BLIS_SUCCESS; +} + +BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +) +{ + return BLIS_SUCCESS; +} + +BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +) +{ + return BLIS_SUCCESS; +} + +#endif //BLIS_ENABLE_SMALL_MATRIX_TRSM From 8f310c3384742dabd8934b154e40e5bdcee40721 Mon Sep 17 00:00:00 2001 From: Dipal M Zambare Date: Wed, 7 Jul 2021 21:48:05 +0530 Subject: 
[PATCH 031/243] AOCL DTL - Added thread and execution time details in logs -- Added number of threads used in DTL logs -- Added support for timestamps in DTL traces -- Added time taken by API at BLAS layer in the DTL logs -- Added GFLOPS achieved in DTL logs -- Added support to enable/disable execution time and gflops printing for individual API's. We may not want it for all API's. Also it will help us migrate API's to execution time and gflops logs in stages. -- Updated GEMM bench to match new logs -- Refactored aocldtl_blis.c to remove code duplication. -- Clean up logs generation and reading to use spaces consistently to separate various fields. -- Updated AOCL_gettid() to return correct thread id when using pthreads. AMD-Internal: [CPUPL-1691] Change-Id: Iddb8a3be2a5cd624a07ccdbf5ae0695799d8ae8e --- aocl_dtl/aocldtl.c | 111 ++- aocl_dtl/aocldtl.h | 27 +- aocl_dtl/aocldtl_blis.c | 1244 ++++++++++++--------------------- aocl_dtl/aocldtl_blis.h | 23 +- aocl_dtl/aoclflist.c | 40 +- aocl_dtl/aoclflist.h | 11 +- aocl_dtl/aoclos.c | 20 +- aocl_dtl/aocltpdef.h | 6 +- bench/bench_gemm.c | 56 +- bench/inputgemm.txt | 50 +- bench/outTemp.txt | 21 - frame/3/bli_l3_sup.c | 1 - frame/3/bli_l3_sup_ref.c | 1 - frame/3/gemm/bli_gemm_front.c | 3 +- frame/3/gemm/bli_gemm_int.c | 5 +- frame/compat/bla_gemm.c | 96 ++- 16 files changed, 772 insertions(+), 943 deletions(-) delete mode 100644 bench/outTemp.txt diff --git a/aocl_dtl/aocldtl.c b/aocl_dtl/aocldtl.c index 1c377e0928..148e99d888 100644 --- a/aocl_dtl/aocldtl.c +++ b/aocl_dtl/aocldtl.c @@ -5,7 +5,7 @@ * These functions are invoked though macros by * end user. * - * Copyright (C) 2020, Advanced Micro Devices, Inc. All rights Reserved. + * Copyright (C) 2020-2021, Advanced Micro Devices, Inc. All rights reserved. * *=======================================================================*/ #include "blis.h" @@ -23,10 +23,23 @@ #endif #endif +/* + * Client should provide this function, it should return + * number of threads used by the API + */ +extern dim_t AOCL_get_requested_threads_count(void); + /* By default the trace level will be set to ALL User can configure this parameter at run time using command line argument */ uint32 gui32TraceLogLevel = AOCL_DTL_TRACE_LEVEL; +/* + * Time elapsed in the function will be logged from main thread only, + * we will save the main thread id. This will be compared with the id + * of the logging thread. 
+ */ +AOCL_TID gtidMainThreadID = -1; + /* The user can configure the file name in which he wants to dump the data */ #if AOCL_DTL_TRACE_ENABLE /* The file name for storing traced log added manually in the code */ @@ -117,6 +130,9 @@ void DTL_Initialize( } #endif + /* Save Id for main thread */ + gtidMainThreadID = AOCL_gettid(); + } /* DTL_Initialize */ #endif @@ -162,6 +178,7 @@ void DTL_Uninitialize(void) * pi8FunctionName - Function Name * ui32LineNumber - Line number * pi8Message - Message to be printed +* * Output Parameter(s) : None * Return parameter(s) : None *==================================================================*/ @@ -176,6 +193,8 @@ void DTL_Trace( { uint8 i = 0; AOCL_FAL_FILE *pOutFile = NULL; + uint64 u64EventTime = AOCL_getTimestamp(); + dim_t u64RequestedThreadsCount = AOCL_get_requested_threads_count(); bli_init_auto(); @@ -226,7 +245,6 @@ void DTL_Trace( level set while initialization */ if (ui8LogLevel <= gui32TraceLogLevel) { - /* Indent as per level if is function call trace */ if ((ui8LogLevel >= AOCL_DTL_LEVEL_TRACE_1) && (ui8LogLevel <= AOCL_DTL_LEVEL_TRACE_8)) @@ -242,26 +260,39 @@ void DTL_Trace( switch (ui8LogType) { case TRACE_TYPE_FENTRY: - fprintf(pOutFile, "In %s()...\n", pi8FunctionName); + fprintf(pOutFile, "nt=%ld,ts=%ld: In %s()...\n", + u64RequestedThreadsCount, + u64EventTime, + pi8FunctionName); break; case TRACE_TYPE_FEXIT: if (pi8Message == NULL) { /* Function returned successfully */ - fprintf(pOutFile, "Out of %s()\n", pi8FunctionName); + fprintf(pOutFile, "ts=%ld: Out of %s()\n", + u64EventTime, + pi8FunctionName); } else { /* Function failed to complete, use message to get error */ - fprintf(pOutFile, "Out of %s() with error %s\n", pi8FunctionName, pi8Message); + fprintf(pOutFile, "ts=%ld: Out of %s() with error %s\n", + u64EventTime, + pi8FunctionName, + pi8Message); } break; case TRACE_TYPE_LOG: - fprintf(pOutFile, "%s:%d:%s\n", pi8FileName, ui32LineNumber, pi8Message); + fprintf(pOutFile, "%s %s", + pi8FileName, + pi8Message + ); + break; case TRACE_TYPE_RAW: - fprintf(pOutFile, "%s\n", pi8Message); + fprintf(pOutFile, "%s\n", + pi8Message); break; } fflush(pOutFile); @@ -407,6 +438,72 @@ void DTL_DumpData( } /* DTL_DumpData */ #endif +#if (AOCL_DTL_TRACE_ENABLE || AOCL_DTL_LOG_ENABLE) +void AOCL_DTL_start_perf_timer(void) +{ + AOCL_TID current_thread = AOCL_gettid(); + + // Automatic duration calulation is currently + // supported from main thread only, in other words + // at BLAS interface. + if (current_thread != gtidMainThreadID) { + return; + } + + AOCL_FLIST_Node *pFileNode = AOCL_FLIST_GetNode(gpLogFileList, current_thread); + + if (NULL == pFileNode) { + /* It might be the first call from the current thread, try to create + new trace for this thread. */ + AOCL_FAL_FILE *pOutFile = AOCL_FLIST_AddFile(pchDTL_LOG_FILE, &gpLogFileList, current_thread); + + if (NULL == pOutFile) + { + AOCL_DEBUGPRINT("File does not exists to dump the trace data \n"); + return; + } else { + pFileNode = AOCL_FLIST_GetNode(gpLogFileList, current_thread); + } + } + + pFileNode->u64SavedTimeStamp = AOCL_getTimestamp(); + fflush(stdout); +} + + +uint64 AOCL_DTL_get_time_spent(void) +{ + AOCL_TID current_thread = AOCL_gettid(); + + // Automatic duration calulation is currently + // supported from main thread only, in other words + // at BLAS interface. 
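+    //
+    // Intended pairing, as a sketch (the real call sites are in the BLAS
+    // compatibility layer touched by this patch, e.g. bla_gemm.c):
+    //
+    //     AOCL_DTL_START_PERF_TIMER();           // at API entry: save timestamp
+    //     ... run the BLAS computation ...
+    //     uint64 dt = AOCL_DTL_get_time_spent(); // at exit: elapsed time for the log
+    //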
+ if (current_thread != gtidMainThreadID) { + return 0; + } + + uint64 u64CurrentTimeStamp = AOCL_getTimestamp(); + AOCL_FLIST_Node *pFileNode = AOCL_FLIST_GetNode(gpLogFileList, AOCL_gettid()); + + if (NULL == pFileNode) { + /* It might be the first call from the current thread, try to create + new trace for this thread. */ + AOCL_FAL_FILE *pOutFile = AOCL_FLIST_AddFile(pchDTL_LOG_FILE, &gpLogFileList, AOCL_gettid()); + + if (NULL == pOutFile) + { + AOCL_DEBUGPRINT("File does not exists to dump the trace data \n"); + return 0; + } else { + pFileNode = AOCL_FLIST_GetNode(gpLogFileList, AOCL_gettid()); + } + } + + return (u64CurrentTimeStamp - pFileNode->u64SavedTimeStamp); +} + +#endif + /* This is enabled by passing ETRACE_ENABLE=1 to make */ #ifdef AOCL_DTL_AUTO_TRACE_ENABLE diff --git a/aocl_dtl/aocldtl.h b/aocl_dtl/aocldtl.h index 90d79ca17f..58c1a56079 100644 --- a/aocl_dtl/aocldtl.h +++ b/aocl_dtl/aocldtl.h @@ -1,12 +1,12 @@ /*=================================================================== * File Name : aocldtl.h - * + * * Description : This is main interface file for the end user - * It provides defination for all macros to be + * It provides defination for all macros to be * used by user to add debug/trace information. * - * Copyright (C) 2020, Advanced Micro Devices, Inc - * + * Copyright (C) 2020-2021, Advanced Micro Devices, Inc. All rights reserved. + * *==================================================================*/ #ifndef _AOCLDTL_H_ @@ -47,7 +47,7 @@ #endif #if AOCL_DTL_TRACE_ENABLE -/* Exit macro to trace the flow of control The parameter LogLevel specifies +/* Exit macro to trace the flow of control The parameter LogLevel specifies log level String will preferably contains the function name in which this macro is invoked */ #define AOCL_DTL_TRACE_EXIT(LogLevel) \ @@ -72,8 +72,8 @@ #endif #if AOCL_DTL_DUMP_ENABLE -/* Macro to Dump the DATA The parameters Buffer contains the data to be - dumped BufferSize specifies the no. of bytes to be dumped DataType +/* Macro to Dump the DATA The parameters Buffer contains the data to be + dumped BufferSize specifies the no. 
of bytes to be dumped DataType specifies the data type of Buffer */ #define AOCL_DTL_DUMP(LogLevel, Buffer, BufferSize, DataType, String, OutputType) \ /* Call the Dump function to Dump the DATA */ \ @@ -103,6 +103,19 @@ #define AOCL_DTL_LOG(LogLevel, Message) #endif +#if AOCL_DTL_LOG_ENABLE + +void AOCL_DTL_start_perf_timer(void); +uint64 AOCL_DTL_get_time_spent(void); + +/* Macro to log the Data */ +#define AOCL_DTL_START_PERF_TIMER() \ + AOCL_DTL_start_perf_timer() +#else +/* Dummy macro definition if the AOCL_DTL_LOG_ENABLE macro is not enabled */ +#define AOCL_DTL_START_PERF_TIMER() +#endif + /* Macro to initialize the prerequisite for debuging */ #ifdef AOCL_DTL_INITIALIZE_ENABLE #define AOCL_DTL_INITIALIZE(CURRENT_LOG_LEVEL) \ diff --git a/aocl_dtl/aocldtl_blis.c b/aocl_dtl/aocldtl_blis.c index dc0e2c7b91..c4de2bfcda 100755 --- a/aocl_dtl/aocldtl_blis.c +++ b/aocl_dtl/aocldtl_blis.c @@ -7,13 +7,47 @@ * *==================================================================*/ - #include "blis.h" +dim_t AOCL_get_requested_threads_count(void) +{ + return bli_thread_get_num_threads(); +} + #if AOCL_DTL_LOG_ENABLE +// Helper functions + +void DTL_get_complex_parts(char dt_type, + const void *complex_input, + double *real, + double *imag) +{ + if (dt_type == 'S' || dt_type == 's') + { + *real = *((float *)complex_input); + *imag = 0.0; + } + else if (dt_type == 'D' || dt_type == 'd') + { + *real = *((double *)complex_input); + *imag = 0.0; + } + else if (dt_type == 'c' || dt_type == 'C') + { + *real = (float)(((scomplex *)complex_input)->real); + *imag = (float)(((scomplex *)complex_input)->imag); + } + else if (dt_type == 'z' || dt_type == 'Z') + { + *real = ((dcomplex *)complex_input)->real; + *imag = ((dcomplex *)complex_input)->imag; + } +} + // Level-3 + void AOCL_DTL_log_gemm_sizes(int8 loglevel, char dt_type, const f77_char transa, @@ -21,13 +55,13 @@ void AOCL_DTL_log_gemm_sizes(int8 loglevel, const f77_int m, const f77_int n, const f77_int k, - const void* alpha, + const void *alpha, const f77_int lda, const f77_int ldb, - const void* beta, + const void *beta, const f77_int ldc, - const char* filename, - const char* function_name, + const char *filename, + const char *function_name, int line) { char buffer[256]; @@ -37,113 +71,92 @@ void AOCL_DTL_log_gemm_sizes(int8 loglevel, double beta_real = 0.0; double beta_imag = 0.0; - if( dt_type == 'S' || dt_type == 's' ) - { - alpha_real = *((float*)alpha); - alpha_imag = 0.0; - beta_real = *((float*)beta); - beta_imag = 0.0; - } - else if( dt_type == 'D' || dt_type == 'd' ) - { - alpha_real = *((double*)alpha); - alpha_imag = 0.0; - beta_real = *((double*)beta); - beta_imag = 0.0; - } - else if( dt_type == 'c' || dt_type == 'C' ) - { - alpha_real = (float)(((scomplex*)alpha)->real); - alpha_imag = (float)(((scomplex*)alpha)->imag); - beta_real = (float)(((scomplex*)beta)->real); - beta_imag = (float)(((scomplex*)beta)->imag); - } - else if( dt_type == 'z' || dt_type == 'Z' ) - { - alpha_real = ((dcomplex*)alpha)->real; - alpha_imag = ((dcomplex*)alpha)->imag; - beta_real = ((dcomplex*)beta)->real; - beta_imag = ((dcomplex*)beta)->imag; - } + AOCL_DTL_START_PERF_TIMER(); + + DTL_get_complex_parts(dt_type, alpha, &alpha_real, &alpha_imag); + DTL_get_complex_parts(dt_type, beta, &beta_real, &beta_imag); - //{S, D, C, Z} m, n, k, lda, ldb, ldc, transa, transb, alpha_real, alpha_imag, beta_real, beta_imag - sprintf(buffer, " %c %ld %ld %ld %ld %ld %ld %c %c %lf %lf %lf %lf", - dt_type, - (dim_t)m, (dim_t)n, (dim_t)k, - (dim_t)lda, (dim_t)ldb, 
(dim_t)ldc, - transa, transb, - alpha_real, alpha_imag, beta_real, beta_imag - ); + // Ordering as per cblas/blas interfaces + // {S, D, C, Z} transa, transb, m, n, k, alpha_real, alpha_imag, + // lda, ldb, beta_real, beta_imag, ldc + sprintf(buffer, "%c %c %c %ld %ld %ld %lf %lf %ld %ld %lf %lf %ld", + toupper(dt_type), + toupper(transa), toupper(transb), + (dim_t)m, (dim_t)n, (dim_t)k, + alpha_real, alpha_imag, + (inc_t)lda, (inc_t)ldb, + beta_real, beta_imag, + (inc_t)ldc); DTL_Trace(loglevel, TRACE_TYPE_LOG, function_name, function_name, line, buffer); } +void AOCL_DTL_log_gemm_stats(int8 loglevel, + const f77_int m, + const f77_int n, + const f77_int k) +{ + char buffer[256]; + + double flops = 2.0 * m * n * k; + + // Execution time is in micro seconds. + Double execution_time = AOCL_DTL_get_time_spent(); + + sprintf(buffer, " nt=%ld %.3f ms %0.3f GFLOPS", + AOCL_get_requested_threads_count(), + execution_time/1000.0, + flops/(execution_time * 1e3)); + + DTL_Trace(loglevel, TRACE_TYPE_RAW, NULL, NULL, 0, buffer); +} + void AOCL_DTL_log_trsm_sizes(int8 loglevel, - char dt_type, - f77_char side, - f77_char uploa, - f77_char transa, - f77_char diaga, - const f77_int m, - const f77_int n, - const void* alpha, - f77_int lda, - f77_int ldb, - const char* filename, - const char* function_name, - int line) + char dt_type, + f77_char side, + f77_char uploa, + f77_char transa, + f77_char diaga, + const f77_int m, + const f77_int n, + const void *alpha, + f77_int lda, + f77_int ldb, + const char *filename, + const char *function_name, + int line) { char buffer[256]; double alpha_real = 0.0; double alpha_imag = 0.0; - if( dt_type == 'S' || dt_type == 's' ) - { - alpha_real = *((float*)alpha); - alpha_imag = 0.0; - } - else if( dt_type == 'D' || dt_type == 'd' ) - { - alpha_real = *((double*)alpha); - alpha_imag = 0.0; - } - else if( dt_type == 'C' || dt_type == 'c' ) - { - alpha_real = (float)(((scomplex*)alpha)->real); - alpha_imag = (float)(((scomplex*)alpha)->imag); - } - else if( dt_type == 'z' || dt_type == 'Z' ) - { - alpha_real = ((dcomplex*)alpha)->real; - alpha_imag = ((dcomplex*)alpha)->imag; - } + DTL_get_complex_parts(dt_type, alpha, &alpha_real, &alpha_imag); //{S, D, C, Z} side, uplo, transa, diaga, m, n, lda, ldb, alpha_real, alpha_imag - sprintf(buffer, " %c %c %c %c %c %ld %ld %ld %ld %lf %lf",dt_type, - side, uploa, transa, diaga, - (dim_t)m, (dim_t)n, (dim_t)lda, (dim_t)ldb, - alpha_real, alpha_imag - ); + sprintf(buffer, "%c %c %c %c %c %ld %ld %ld %ld %lf %lf\n", dt_type, + side, uploa, transa, diaga, + (dim_t)m, (dim_t)n, (dim_t)lda, (dim_t)ldb, + alpha_real, alpha_imag); DTL_Trace(loglevel, TRACE_TYPE_LOG, function_name, function_name, line, buffer); } void AOCL_DTL_log_gemmt_sizes(int8 loglevel, - char dt_type, - char uplo, - char transa, - char transb, - const f77_int n, - const f77_int k, - const void* alpha, - const f77_int lda, - const f77_int ldb, - const void* beta, - const f77_int ldc, - const char* filename, - const char* function_name, - int line) + char dt_type, + char uplo, + char transa, + char transb, + const f77_int n, + const f77_int k, + const void *alpha, + const f77_int lda, + const f77_int ldb, + const void *beta, + const f77_int ldc, + const char *filename, + const char *function_name, + int line) { char buffer[256]; @@ -152,43 +165,17 @@ void AOCL_DTL_log_gemmt_sizes(int8 loglevel, double beta_real = 0.0; double beta_imag = 0.0; - if( dt_type == 's' || dt_type == 'S' ) - { - alpha_real = *((float*)alpha); - alpha_imag = 0.0; - beta_real = 
*((float*)beta); - beta_imag = 0.0; - } - else if( dt_type == 'd' || dt_type == 'D' ) - { - alpha_real = *((double*)alpha); - alpha_imag = 0.0; - beta_real = *((double*)beta); - beta_imag = 0.0; - } - else if( dt_type == 'c' || dt_type == 'C' ) - { - alpha_real = (float)(((scomplex*)alpha)->real); - alpha_imag = (float)(((scomplex*)alpha)->imag); - beta_real = (float)(((scomplex*)beta)->real); - beta_imag = (float)(((scomplex*)beta)->imag); - } - else if( dt_type == 'z' || dt_type == 'Z' ) - { - alpha_real = ((dcomplex*)alpha)->real; - alpha_imag = ((dcomplex*)alpha)->imag; - beta_real = ((dcomplex*)beta)->real; - beta_imag = ((dcomplex*)beta)->imag; - } + DTL_get_complex_parts(dt_type, alpha, &alpha_real, &alpha_imag); + DTL_get_complex_parts(dt_type, beta, &beta_real, &beta_imag); // {S,D,C,Z} {triangC : l or u} {n k lda ldb ldc transa transb alpha_real alpha_imaginary // beta_real, beta_imaginary} - sprintf(buffer, " %c %c %ld %ld %lu %lu %lu %c %c %lf %lf %lf %lf", + sprintf(buffer, "%c %c %ld %ld %lu %lu %lu %c %c %lf %lf %lf %lf\n", dt_type, uplo, (dim_t)n, (dim_t)k, - (dim_t)lda, (dim_t)ldb, (dim_t)ldc, - transa, transb, - alpha_real, alpha_imag, - beta_real, beta_imag); + (dim_t)lda, (dim_t)ldb, (dim_t)ldc, + transa, transb, + alpha_real, alpha_imag, + beta_real, beta_imag); DTL_Trace(loglevel, TRACE_TYPE_LOG, function_name, function_name, line, buffer); } @@ -197,15 +184,15 @@ void AOCL_DTL_log_hemm_sizes(int8 loglevel, char dt_type, const f77_char side, const f77_char uploa, - const f77_int m, - const f77_int n, - const void* alpha, + const f77_int m, + const f77_int n, + const void *alpha, const f77_int lda, const f77_int ldb, - const void* beta, + const void *beta, const f77_int ldc, - const char* filename, - const char* function_name, + const char *filename, + const char *function_name, int line) { char buffer[256]; @@ -214,24 +201,12 @@ void AOCL_DTL_log_hemm_sizes(int8 loglevel, double beta_real = 0.0; double beta_imag = 0.0; - if(dt_type == 'c' || dt_type == 'C' ) - { - alpha_real = (float)(((scomplex*)alpha)->real); - alpha_imag = (float)(((scomplex*)alpha)->imag); - beta_real = (float)(((scomplex*)beta)->real); - beta_imag = (float)(((scomplex*)beta)->imag); - } - else if(dt_type == 'z' || dt_type == 'Z' ) - { - alpha_real = ((dcomplex*)alpha)->real; - alpha_imag = ((dcomplex*)alpha)->imag; - beta_real = ((dcomplex*)beta)->real; - beta_imag = ((dcomplex*)beta)->imag; - } + DTL_get_complex_parts(dt_type, alpha, &alpha_real, &alpha_imag); + DTL_get_complex_parts(dt_type, beta, &beta_real, &beta_imag); // {C, Z} { side, uploa, m, n, alpha_real, alpha_imag, lda, incx, beta_real, beta_imag, incy} - sprintf(buffer, " %c %c %c %ld %ld %lf %lf %ld %ld %lf %lf %ld", + sprintf(buffer, "%c %c %c %ld %ld %lf %lf %ld %ld %lf %lf %ld\n", dt_type, side, uploa, (dim_t)m, (dim_t)n, alpha_real, alpha_imag, (dim_t)lda, (dim_t)ldb, beta_real, beta_imag, (dim_t)ldc); @@ -239,60 +214,49 @@ void AOCL_DTL_log_hemm_sizes(int8 loglevel, } // Level-3 -void AOCL_DTL_log_herk_sizes( int8 loglevel, - char dt_type, - const f77_char uploc, - const f77_char transa, - const f77_int m, - const f77_int k, - const void* alpha, - const f77_int lda, - const void* beta, - const f77_int ldc, - const char* filename, - const char* function_name, - int line) +void AOCL_DTL_log_herk_sizes(int8 loglevel, + char dt_type, + const f77_char uploc, + const f77_char transa, + const f77_int m, + const f77_int k, + const void *alpha, + const f77_int lda, + const void *beta, + const f77_int ldc, + const char *filename, + 
const char *function_name, + int line) { char buffer[256]; double alpha_real = 0.0; double alpha_imag = 0.0; double beta_real = 0.0; double beta_imag = 0.0; - if(dt_type == 'c' || dt_type == 'C' ) - { - alpha_real = (double)(((scomplex*)alpha)->real); - alpha_imag = (double)(((scomplex*)alpha)->imag); - beta_real = (double)(((scomplex*)beta)->real); - beta_imag = (double)(((scomplex*)beta)->imag); - } - else if(dt_type == 'z' || dt_type == 'Z' ) - { - alpha_real = (double)((dcomplex*)alpha)->real; - alpha_imag = (double)((dcomplex*)alpha)->imag; - beta_real = (double)((dcomplex*)beta)->real; - beta_imag = (double)((dcomplex*)beta)->imag; - } + + DTL_get_complex_parts(dt_type, alpha, &alpha_real, &alpha_imag); + DTL_get_complex_parts(dt_type, beta, &beta_real, &beta_imag); + // {C, Z} {uploc, transa, m, k, alpha_real, alpha_imag, lda, beta_real, beta_imag, ldc} - sprintf(buffer, " %c %c %c %ld %ld %lf %lf %ld %lf %lf %ld", - dt_type, uploc, transa, (dim_t)m, (dim_t)k, alpha_real, alpha_imag, (dim_t)lda, beta_real, beta_imag, (dim_t)ldc); + sprintf(buffer, "%c %c %c %ld %ld %lf %lf %ld %lf %lf %ld\n", + dt_type, uploc, transa, (dim_t)m, (dim_t)k, alpha_real, alpha_imag, (dim_t)lda, beta_real, beta_imag, (dim_t)ldc); DTL_Trace(loglevel, TRACE_TYPE_LOG, function_name, function_name, line, buffer); - } void AOCL_DTL_log_her2k_sizes(int8 loglevel, char dt_type, const f77_char uploc, const f77_char transa, - const f77_int m, - const f77_int k, - const void* alpha, + const f77_int m, + const f77_int k, + const void *alpha, const f77_int lda, const f77_int ldb, - const void* beta, + const void *beta, const f77_int ldc, - const char* filename, - const char* function_name, + const char *filename, + const char *function_name, int line) { char buffer[256]; @@ -300,42 +264,31 @@ void AOCL_DTL_log_her2k_sizes(int8 loglevel, double alpha_imag = 0.0; double beta_real = 0.0; double beta_imag = 0.0; - if(dt_type == 'c' || dt_type == 'C' ) - { - alpha_real = (double)(((scomplex*)alpha)->real); - alpha_imag = (double)(((scomplex*)alpha)->imag); - beta_real = (double)(((scomplex*)beta)->real); - beta_imag = (double)(((scomplex*)beta)->imag); - } - else if(dt_type == 'z' || dt_type == 'Z' ) - { - alpha_real = (double)((dcomplex*)alpha)->real; - alpha_imag = (double)((dcomplex*)alpha)->imag; - beta_real = (double)((dcomplex*)beta)->real; - beta_imag = (double)((dcomplex*)beta)->imag; - } + + DTL_get_complex_parts(dt_type, alpha, &alpha_real, &alpha_imag); + DTL_get_complex_parts(dt_type, beta, &beta_real, &beta_imag); + // {C, Z} { uploc, transa, m, k, alpha_real, alpha_imag, lda, ldb, beta_real, beta_imag, ldc} - sprintf(buffer, " %c %c %c %ld %ld %lf %lf %ld %ld %lf %lf %ld", - dt_type, uploc, transa, (dim_t)m, (dim_t)k, alpha_real, alpha_imag, (dim_t)lda, (dim_t)ldb, beta_real, beta_imag, (dim_t)ldc); + sprintf(buffer, "%c %c %c %ld %ld %lf %lf %ld %ld %lf %lf %ld\n", + dt_type, uploc, transa, (dim_t)m, (dim_t)k, alpha_real, alpha_imag, (dim_t)lda, (dim_t)ldb, beta_real, beta_imag, (dim_t)ldc); DTL_Trace(loglevel, TRACE_TYPE_LOG, function_name, function_name, line, buffer); - } -void AOCL_DTL_log_symm_sizes( int8 loglevel, - char dt_type, - const f77_char side, - const f77_char uploa, - const f77_int m, - const f77_int n, - const void* alpha, - const f77_int lda, - const f77_int ldb, - const void* beta, - const f77_int ldc, - const char* filename, - const char* function_name, - int line) +void AOCL_DTL_log_symm_sizes(int8 loglevel, + char dt_type, + const f77_char side, + const f77_char uploa, + const f77_int 
m, + const f77_int n, + const void *alpha, + const f77_int lda, + const f77_int ldb, + const void *beta, + const f77_int ldc, + const char *filename, + const char *function_name, + int line) { char buffer[256]; double alpha_real = 0.0; @@ -343,92 +296,65 @@ void AOCL_DTL_log_symm_sizes( int8 loglevel, double beta_real = 0.0; double beta_imag = 0.0; - if(dt_type == 's' || dt_type == 'S' ) - { - alpha_real = *((float*)alpha); - alpha_imag = 0.0; - beta_real = *((float*)beta); - beta_imag = 0.0; - } - else if(dt_type == 'd' || dt_type == 'D' ) - { - alpha_real = *((double*)alpha); - alpha_imag = 0.0; - beta_real = *((double*)beta); - beta_imag = 0.0; - } - else if(dt_type == 'c' || dt_type == 'C' ) - { - alpha_real = (float)(((scomplex*)alpha)->real); - alpha_imag = (float)(((scomplex*)alpha)->imag); - beta_real = (float)(((scomplex*)beta)->real); - beta_imag = (float)(((scomplex*)beta)->imag); - } - else if(dt_type == 'z' || dt_type == 'Z' ) - { - alpha_real = ((dcomplex*)alpha)->real; - alpha_imag = ((dcomplex*)alpha)->imag; - beta_real = ((dcomplex*)beta)->real; - beta_imag = ((dcomplex*)beta)->imag; - } + DTL_get_complex_parts(dt_type, alpha, &alpha_real, &alpha_imag); + DTL_get_complex_parts(dt_type, beta, &beta_real, &beta_imag); // {S, D, C, Z} { side, uploa, m, n, alpha_real, alpha_imag, lda, ldb, beta_real, beta_imag, ldc} - sprintf(buffer, " %c %c %c %ld %ld %lf %lf %ld %ld %lf %lf %ld", - dt_type, side, uploa, (dim_t)m, (dim_t)n, alpha_real, alpha_imag, (dim_t)lda, (dim_t)ldb, beta_real, beta_imag, (dim_t)ldc); + sprintf(buffer, "%c %c %c %ld %ld %lf %lf %ld %ld %lf %lf %ld\n", + dt_type, side, uploa, (dim_t)m, (dim_t)n, alpha_real, alpha_imag, (dim_t)lda, (dim_t)ldb, beta_real, beta_imag, (dim_t)ldc); DTL_Trace(loglevel, TRACE_TYPE_LOG, function_name, function_name, line, buffer); - } // Level-2 -void AOCL_DTL_log_symv_sizes( int8 loglevel, - char dt_type, - const f77_char uploa, - const f77_int m, - const void* alpha, - const f77_int lda, - const f77_int incx, - const void* beta, - const f77_int incy, - const char* filename, - const char* function_name, - int line) +void AOCL_DTL_log_symv_sizes(int8 loglevel, + char dt_type, + const f77_char uploa, + const f77_int m, + const void *alpha, + const f77_int lda, + const f77_int incx, + const void *beta, + const f77_int incy, + const char *filename, + const char *function_name, + int line) { char buffer[256]; double alpha_d = 0.0; double beta_d = 0.0; - if(dt_type == 's' || dt_type == 'S' ) + + if (dt_type == 's' || dt_type == 'S') { - alpha_d = *((float*)alpha); - beta_d = *((float*)beta); + alpha_d = *((float *)alpha); + beta_d = *((float *)beta); } - else if(dt_type == 'd' || dt_type == 'D' ) + else if (dt_type == 'd' || dt_type == 'D') { - alpha_d = *((double*)alpha); - beta_d = *((double*)beta); + alpha_d = *((double *)alpha); + beta_d = *((double *)beta); } // {S, D} { uploa, m, alpha_d, lda, incx, beta_d, incy} - sprintf(buffer, " %c %c %ld %lf %ld %ld %lf %ld", - dt_type, uploa, (dim_t)m, alpha_d, (dim_t)lda, (dim_t)incx, beta_d, (dim_t)incy); + sprintf(buffer, "%c %c %ld %lf %ld %ld %lf %ld\n", + dt_type, uploa, (dim_t)m, alpha_d, (dim_t)lda, (dim_t)incx, beta_d, (dim_t)incy); DTL_Trace(loglevel, TRACE_TYPE_LOG, function_name, function_name, line, buffer); - } -void AOCL_DTL_log_gemv_sizes( int8 loglevel, - char dt_type, - const f77_char transa, - const f77_int m, - const f77_int n, - const void* alpha, - const f77_int lda, - const f77_int incx, - const void* beta, - const f77_int incy, - const char* filename, - const char* 
function_name, - int line) +void AOCL_DTL_log_gemv_sizes(int8 loglevel, + char dt_type, + const f77_char transa, + const f77_int m, + const f77_int n, + const void *alpha, + const f77_int lda, + const f77_int incx, + const void *beta, + const f77_int incy, + const char *filename, + const char *function_name, + int line) { char buffer[256]; double alpha_real = 0.0; @@ -436,152 +362,93 @@ void AOCL_DTL_log_gemv_sizes( int8 loglevel, double beta_real = 0.0; double beta_imag = 0.0; - if(dt_type == 's' || dt_type == 'S' ) - { - alpha_real = *((float*)alpha); - alpha_imag = 0.0; - beta_real = *((float*)beta); - beta_imag = 0.0; - } - else if(dt_type == 'd' || dt_type == 'D' ) - { - alpha_real = *((double*)alpha); - alpha_imag = 0.0; - beta_real = *((double*)beta); - beta_imag = 0.0; - } - else if(dt_type == 'c' || dt_type == 'C' ) - { - alpha_real = (float)(((scomplex*)alpha)->real); - alpha_imag = (float)(((scomplex*)alpha)->imag); - beta_real = (float)(((scomplex*)beta)->real); - beta_imag = (float)(((scomplex*)beta)->imag); - } - else if(dt_type == 'z' || dt_type == 'Z' ) - { - alpha_real = ((dcomplex*)alpha)->real; - alpha_imag = ((dcomplex*)alpha)->imag; - beta_real = ((dcomplex*)beta)->real; - beta_imag = ((dcomplex*)beta)->imag; - } + DTL_get_complex_parts(dt_type, alpha, &alpha_real, &alpha_imag); + DTL_get_complex_parts(dt_type, beta, &beta_real, &beta_imag); // {S, D,C, Z} { transa, m, n, alpha, lda, incx, beta, incy} - sprintf(buffer, " %c %c %ld %ld %lf %lf %ld %ld %lf %lf %ld", - dt_type, transa, (dim_t)m, (dim_t)n, alpha_real, alpha_imag, - (dim_t)lda, (dim_t)incx, beta_real, beta_imag, (dim_t)incy); - + sprintf(buffer, "%c %c %ld %ld %lf %lf %ld %ld %lf %lf %ld\n", + dt_type, transa, (dim_t)m, (dim_t)n, alpha_real, alpha_imag, + (dim_t)lda, (dim_t)incx, beta_real, beta_imag, (dim_t)incy); DTL_Trace(loglevel, TRACE_TYPE_LOG, function_name, function_name, line, buffer); - } -void AOCL_DTL_log_ger_sizes( int8 loglevel, - char dt_type, - const f77_int m, - const f77_int n, - const void* alpha, - const f77_int incx, - const f77_int incy, - const f77_int lda, - const char* filename, - const char* function_name, - int line - ) +void AOCL_DTL_log_ger_sizes(int8 loglevel, + char dt_type, + const f77_int m, + const f77_int n, + const void *alpha, + const f77_int incx, + const f77_int incy, + const f77_int lda, + const char *filename, + const char *function_name, + int line) { char buffer[256]; double alpha_real = 0.0; double alpha_imag = 0.0; - if(dt_type == 's' || dt_type == 'S' ) - { - alpha_real = *((float*)alpha); - alpha_imag = 0.0; - } - else if(dt_type == 'd' || dt_type == 'D' ) - { - alpha_real = *((double*)alpha); - alpha_imag = 0.0; - } - else if(dt_type == 'c' || dt_type == 'C' ) - { - alpha_real = (float)(((scomplex*)alpha)->real); - alpha_imag = (float)(((scomplex*)alpha)->imag); - } - else if(dt_type == 'z' || dt_type == 'Z' ) - { - alpha_real = ((dcomplex*)alpha)->real; - alpha_imag = ((dcomplex*)alpha)->imag; - } + DTL_get_complex_parts(dt_type, alpha, &alpha_real, &alpha_imag); - sprintf(buffer, " %c %ld %ld %lf %lf %ld %ld %ld", dt_type, (dim_t)m, (dim_t)n, alpha_real, alpha_imag, (dim_t)incx, (dim_t)incy, (dim_t)lda ); + sprintf(buffer, "%c %ld %ld %lf %lf %ld %ld %ld\n", dt_type, (dim_t)m, (dim_t)n, alpha_real, alpha_imag, (dim_t)incx, (dim_t)incy, (dim_t)lda); DTL_Trace(loglevel, TRACE_TYPE_LOG, function_name, function_name, line, buffer); - } -void AOCL_DTL_log_her_sizes( int8 loglevel, - char dt_type, - const f77_char uploa, - const f77_int m, - const void* alpha, - 
const f77_int incx, - const f77_int lda, - const char* filename, - const char* function_name, - int line) +void AOCL_DTL_log_her_sizes(int8 loglevel, + char dt_type, + const f77_char uploa, + const f77_int m, + const void *alpha, + const f77_int incx, + const f77_int lda, + const char *filename, + const char *function_name, + int line) { char buffer[256]; double alpha_real = 0.0; double alpha_imag = 0.0; - if(dt_type == 'c' || dt_type == 'C' ) - { - alpha_real = (double)(((scomplex*)alpha)->real); - alpha_imag = (double)(((scomplex*)alpha)->imag); - } - else if(dt_type == 'z' || dt_type == 'Z' ) - { - alpha_real = (double)((dcomplex*)alpha)->real; - alpha_imag = (double)((dcomplex*)alpha)->imag; - } + + DTL_get_complex_parts(dt_type, alpha, &alpha_real, &alpha_imag); + // {C, Z} {uploa, m alpha_real, alpha_imag incx lda} - sprintf(buffer, " %c %c %ld %lf %lf %ld %ld", + sprintf(buffer, "%c %c %ld %lf %lf %ld %ld\n", dt_type, uploa, (dim_t)m, alpha_real, alpha_imag, (dim_t)incx, (dim_t)lda); DTL_Trace(loglevel, TRACE_TYPE_LOG, function_name, function_name, line, buffer); - } -void AOCL_DTL_log_dotv_sizes( int8 loglevel, - char dt_type, - const f77_int n, - const f77_int incx, - const f77_int incy, - const char* filename, - const char* function_name, - int line) +void AOCL_DTL_log_dotv_sizes(int8 loglevel, + char dt_type, + const f77_int n, + const f77_int incx, + const f77_int incy, + const char *filename, + const char *function_name, + int line) { char buffer[256]; // { n, incx, incy} - sprintf(buffer, " %c %ld %ld %ld", dt_type, (dim_t)n, (dim_t)incx, (dim_t)incy); - + sprintf(buffer, "%c %ld %ld %ld\n", dt_type, (dim_t)n, (dim_t)incx, (dim_t)incy); DTL_Trace(loglevel, TRACE_TYPE_LOG, function_name, function_name, line, buffer); - } -void AOCL_DTL_log_hemv_sizes ( int8 loglevel, - char dt_type, - const f77_char uploa, - const f77_int m, - const void* alpha, - const f77_int lda, - const f77_int incx, - const void* beta, - const f77_int incy, - const char* filename, - const char* function_name, - int line) +void AOCL_DTL_log_hemv_sizes(int8 loglevel, + char dt_type, + const f77_char uploa, + const f77_int m, + const void *alpha, + const f77_int lda, + const f77_int incx, + const void *beta, + const f77_int incy, + const char *filename, + const char *function_name, + int line) { char buffer[256]; double alpha_real = 0.0; @@ -589,131 +456,83 @@ void AOCL_DTL_log_hemv_sizes ( int8 loglevel, double beta_real = 0.0; double beta_imag = 0.0; - if(dt_type == 's' || dt_type == 'S' ) - { - alpha_real = *((float*)alpha); - alpha_imag = 0.0; - beta_real = *((float*)beta); - beta_imag = 0.0; - } - else if(dt_type == 'd' || dt_type == 'D' ) - { - alpha_real = *((double*)alpha); - alpha_imag = 0.0; - beta_real = *((double*)beta); - beta_imag = 0.0; - } - else if(dt_type == 'c' || dt_type == 'C' ) - { - alpha_real = (float)(((scomplex*)alpha)->real); - alpha_imag = (float)(((scomplex*)alpha)->imag); - beta_real = (float)(((scomplex*)beta)->real); - beta_imag = (float)(((scomplex*)beta)->imag); - } - else if(dt_type == 'z' || dt_type == 'Z' ) - { - alpha_real = ((dcomplex*)alpha)->real; - alpha_imag = ((dcomplex*)alpha)->imag; - beta_real = ((dcomplex*)beta)->real; - beta_imag = ((dcomplex*)beta)->imag; - } + DTL_get_complex_parts(dt_type, alpha, &alpha_real, &alpha_imag); + DTL_get_complex_parts(dt_type, beta, &beta_real, &beta_imag); + // {S, D,C, Z} { uploa, m, alpha_real, alpha_imag, lda, incx, beta_real, beta_imag, incy} - sprintf(buffer, " %c %c %ld %lf %lf %ld %ld %lf %lf %ld", + sprintf(buffer, "%c 
%c %ld %lf %lf %ld %ld %lf %lf %ld\n", dt_type, uploa, (dim_t)m, alpha_real, alpha_imag, (dim_t)lda, (dim_t)incx, beta_real, beta_imag, (dim_t)incy); - DTL_Trace(loglevel, TRACE_TYPE_LOG, function_name, function_name, line, buffer); } - -void AOCL_DTL_log_her2_sizes ( int8 loglevel, - char dt_type, - const f77_char uploa, - const f77_int m, - const void* alpha, - const f77_int incx, - const f77_int incy, - const f77_int lda, - const char* filename, - const char* function_name, - int line) +void AOCL_DTL_log_her2_sizes(int8 loglevel, + char dt_type, + const f77_char uploa, + const f77_int m, + const void *alpha, + const f77_int incx, + const f77_int incy, + const f77_int lda, + const char *filename, + const char *function_name, + int line) { char buffer[256]; double alpha_real = 0.0; double alpha_imag = 0.0; - if(dt_type == 's' || dt_type == 'S' ) - { - alpha_real = *((float*)alpha); - alpha_imag = 0.0; - } - else if(dt_type == 'd' || dt_type == 'D' ) - { - alpha_real = *((double*)alpha); - alpha_imag = 0.0; - } - else if(dt_type == 'c' || dt_type == 'C' ) - { - alpha_real = (float)(((scomplex*)alpha)->real); - alpha_imag = (float)(((scomplex*)alpha)->imag); - } - else if(dt_type == 'z' || dt_type == 'Z' ) - { - alpha_real = ((dcomplex*)alpha)->real; - alpha_imag = ((dcomplex*)alpha)->imag; - } + DTL_get_complex_parts(dt_type, alpha, &alpha_real, &alpha_imag); // {S, D, C, Z} {uploa, m, alpha_real, alpha_imag, incx, incy} - sprintf(buffer, " %c %c %ld %lf %lf %ld %ld", - dt_type, uploa, (dim_t)m, alpha_real, alpha_imag, (dim_t)incx, (dim_t)incy); + sprintf(buffer, "%c %c %ld %lf %lf %ld %ld\n", + dt_type, uploa, (dim_t)m, alpha_real, alpha_imag, (dim_t)incx, (dim_t)incy); DTL_Trace(loglevel, TRACE_TYPE_LOG, function_name, function_name, line, buffer); } // Level-1 -void AOCL_DTL_log_amax_sizes ( int8 loglevel, - char dt_type, - const f77_int n, - const f77_int incx, - const char* filename, - const char* function_name, - int line) +void AOCL_DTL_log_amax_sizes(int8 loglevel, + char dt_type, + const f77_int n, + const f77_int incx, + const char *filename, + const char *function_name, + int line) { char buffer[256]; // {S, D, C, Z} {n, incx} - sprintf(buffer, " %c %ld %ld", dt_type, (dim_t)n, (dim_t)incx); + sprintf(buffer, "%c %ld %ld\n", dt_type, (dim_t)n, (dim_t)incx); DTL_Trace(loglevel, TRACE_TYPE_LOG, function_name, function_name, line, buffer); - } -void AOCL_DTL_log_asum_sizes ( int8 loglevel, - char dt_type, - const f77_int n, - const f77_int incx, - const char* filename, - const char* function_name, - int line) +void AOCL_DTL_log_asum_sizes(int8 loglevel, + char dt_type, + const f77_int n, + const f77_int incx, + const char *filename, + const char *function_name, + int line) { char buffer[256]; // {S, D, C, Z} {n, incx} - sprintf(buffer, " %c %ld %ld", dt_type, (dim_t)n, (dim_t)incx); + sprintf(buffer, "%c %ld %ld\n", dt_type, (dim_t)n, (dim_t)incx); DTL_Trace(loglevel, TRACE_TYPE_LOG, function_name, function_name, line, buffer); - } -void AOCL_DTL_log_axpby_sizes ( int8 loglevel, - char dt_type, - const f77_int n, - const void* alpha, - const f77_int incx, - const void* beta, - const f77_int incy, - const char* filename, - const char* function_name, - int line) +void AOCL_DTL_log_axpby_sizes(int8 loglevel, + char dt_type, + const f77_int n, + const void *alpha, + const f77_int incx, + const void *beta, + const f77_int incy, + const char *filename, + const char *function_name, + int line) { char buffer[256]; double alpha_real = 0.0; @@ -721,240 +540,151 @@ void 
AOCL_DTL_log_axpby_sizes ( int8 loglevel, double beta_real = 0.0; double beta_imag = 0.0; - if(dt_type == 's' || dt_type == 'S' ) - { - alpha_real = *((float*)alpha); - alpha_imag = 0.0; - beta_real = *((float*)beta); - beta_imag = 0.0; - } - else if(dt_type == 'd' || dt_type == 'D' ) - { - alpha_real = *((double*)alpha); - alpha_imag = 0.0; - beta_real = *((double*)beta); - beta_imag = 0.0; - } - else if(dt_type == 'c' || dt_type == 'C' ) - { - alpha_real = (float)(((scomplex*)alpha)->real); - alpha_imag = (float)(((scomplex*)alpha)->imag); - beta_real = (float)(((scomplex*)beta)->real); - beta_imag = (float)(((scomplex*)beta)->imag); - } - else if(dt_type == 'z' || dt_type == 'Z' ) - { - alpha_real = ((dcomplex*)alpha)->real; - alpha_imag = ((dcomplex*)alpha)->imag; - beta_real = ((dcomplex*)beta)->real; - beta_imag = ((dcomplex*)beta)->imag; - } + DTL_get_complex_parts(dt_type, alpha, &alpha_real, &alpha_imag); + DTL_get_complex_parts(dt_type, beta, &beta_real, &beta_imag); // {S, D, C, Z} {n, alpha_real, alpha_imag, incx, beta_real, beta_imag, incy} - sprintf(buffer, " %c %ld %lf %lf %ld %lf %lf %ld", - dt_type, (dim_t)n, alpha_real, alpha_imag, (dim_t)incx, - beta_real, beta_imag, (dim_t)incy); + sprintf(buffer, "%c %ld %lf %lf %ld %lf %lf %ld\n", + dt_type, (dim_t)n, alpha_real, alpha_imag, (dim_t)incx, + beta_real, beta_imag, (dim_t)incy); DTL_Trace(loglevel, TRACE_TYPE_LOG, function_name, function_name, line, buffer); - } -void AOCL_DTL_log_axpy_sizes ( int8 loglevel, - char dt_type, - const f77_int n, - const void* alpha, - const f77_int incx, - const f77_int incy, - const char* filename, - const char* function_name, - int line) +void AOCL_DTL_log_axpy_sizes(int8 loglevel, + char dt_type, + const f77_int n, + const void *alpha, + const f77_int incx, + const f77_int incy, + const char *filename, + const char *function_name, + int line) { char buffer[256]; double alpha_real = 0.0; double alpha_imag = 0.0; - if(dt_type == 's' || dt_type == 'S' ) - { - alpha_real = *((float*)alpha); - alpha_imag = 0.0; - } - else if(dt_type == 'd' || dt_type == 'D' ) - { - alpha_real = *((double*)alpha); - alpha_imag = 0.0; - } - else if(dt_type == 'c' || dt_type == 'C' ) - { - alpha_real = (float)(((scomplex*)alpha)->real); - alpha_imag = (float)(((scomplex*)alpha)->imag); - } - else if(dt_type == 'z' || dt_type == 'Z' ) - { - alpha_real = ((dcomplex*)alpha)->real; - alpha_imag = ((dcomplex*)alpha)->imag; - } + DTL_get_complex_parts(dt_type, alpha, &alpha_real, &alpha_imag); // {S, D, C, Z} {n, alpha_real, alpha_imag, incx, incy} - sprintf(buffer, " %c %ld %lf %lf %ld %ld", - dt_type, (dim_t)n, alpha_real, alpha_imag, (dim_t)incx, (dim_t)incy); + sprintf(buffer, "%c %ld %lf %lf %ld %ld\n", + dt_type, (dim_t)n, alpha_real, alpha_imag, (dim_t)incx, (dim_t)incy); DTL_Trace(loglevel, TRACE_TYPE_LOG, function_name, function_name, line, buffer); } -void AOCL_DTL_log_copy_sizes( int8 loglevel, - char dt_type, - const f77_int n, - const f77_int incx, - const f77_int incy, - const char* filename, - const char* function_name, - int line - ) +void AOCL_DTL_log_copy_sizes(int8 loglevel, + char dt_type, + const f77_int n, + const f77_int incx, + const f77_int incy, + const char *filename, + const char *function_name, + int line) { char buffer[256]; // {S, D, C, Z} {n, incx, incy} - sprintf(buffer, " %c %ld %ld %ld", dt_type, (dim_t)n, (dim_t)incx, (dim_t)incy); + sprintf(buffer, "%c %ld %ld %ld\n", dt_type, (dim_t)n, (dim_t)incx, (dim_t)incy); DTL_Trace(loglevel, TRACE_TYPE_LOG, function_name, function_name, 
line, buffer); - } - -void AOCL_DTL_log_scal_sizes( int8 loglevel, - char dt_type, - const void* alpha, - const f77_int n, - const f77_int incx, - const char* filename, - const char* function_name, - int line) +void AOCL_DTL_log_scal_sizes(int8 loglevel, + char dt_type, + const void *alpha, + const f77_int n, + const f77_int incx, + const char *filename, + const char *function_name, + int line) { char buffer[256]; double alpha_real = 0.0; double alpha_imag = 0.0; - if(dt_type == 's' || dt_type == 'S' ) - { - alpha_real = *((float*)alpha); - alpha_imag = 0.0; - } - else if(dt_type == 'd' || dt_type == 'D' ) - { - alpha_real = *((double*)alpha); - alpha_imag = 0.0; - } - else if(dt_type == 'c' || dt_type == 'C' ) - { - alpha_real = (float)(((scomplex*)alpha)->real); - alpha_imag = (float)(((scomplex*)alpha)->imag); - } - else if(dt_type == 'z' || dt_type == 'Z' ) - { - alpha_real = ((dcomplex*)alpha)->real; - alpha_imag = ((dcomplex*)alpha)->imag; - } + DTL_get_complex_parts(dt_type, alpha, &alpha_real, &alpha_imag); // {S, D, C, Z} { alpha, n, incx} - sprintf(buffer, " %c %lf %lf %ld %ld", - dt_type, alpha_real, alpha_imag, (dim_t)n, (dim_t)incx); + sprintf(buffer, "%c %lf %lf %ld %ld\n", + dt_type, alpha_real, alpha_imag, (dim_t)n, (dim_t)incx); DTL_Trace(loglevel, TRACE_TYPE_LOG, function_name, function_name, line, buffer); - } -void AOCL_DTL_log_swap_sizes( int8 loglevel, - char dt_type, - const f77_int n, - const f77_int incx, - const f77_int incy, - const char* filename, - const char* function_name, - int line) +void AOCL_DTL_log_swap_sizes(int8 loglevel, + char dt_type, + const f77_int n, + const f77_int incx, + const f77_int incy, + const char *filename, + const char *function_name, + int line) { - char buffer[256]; + char buffer[256]; // {S, D, C, Z} {n, incx, incy} - sprintf(buffer, " %c %ld %ld %ld", - dt_type, (dim_t)n, (dim_t)incx, (dim_t)incy); + sprintf(buffer, "%c %ld %ld %ld\n", + dt_type, (dim_t)n, (dim_t)incx, (dim_t)incy); DTL_Trace(loglevel, TRACE_TYPE_LOG, function_name, function_name, line, buffer); - } -void AOCL_DTL_log_nrm2_sizes( int8 loglevel, - char dt_type, - const f77_int n, - const f77_int incx, - const char* filename, - const char* function_name, - int line) +void AOCL_DTL_log_nrm2_sizes(int8 loglevel, + char dt_type, + const f77_int n, + const f77_int incx, + const char *filename, + const char *function_name, + int line) { - char buffer[256]; + char buffer[256]; // {S, D, C, Z} {n, incx} - sprintf(buffer, " %c %ld %ld", - dt_type, (dim_t)n, (dim_t)incx); + sprintf(buffer, "%c %ld %ld\n", + dt_type, (dim_t)n, (dim_t)incx); DTL_Trace(loglevel, TRACE_TYPE_LOG, function_name, function_name, line, buffer); - } //Level-2 void AOCL_DTL_log_syr2_sizes(int8 loglevel, char dt_type, const f77_char uploa, - const f77_int m, - const void* alpha, - const f77_int incx, - const f77_int incy, - const f77_int lda, - const char* filename, - const char* function_name, - int line) + const f77_int m, + const void *alpha, + const f77_int incx, + const f77_int incy, + const f77_int lda, + const char *filename, + const char *function_name, + int line) { char buffer[256]; double alpha_real = 0.0; double alpha_imag = 0.0; - if(dt_type == 's' || dt_type == 'S' ) - { - alpha_real = *((float*)alpha); - alpha_imag = 0.0; - } - else if(dt_type == 'd' || dt_type == 'D' ) - { - alpha_real = *((double*)alpha); - alpha_imag = 0.0; - } - else if(dt_type == 'c' || dt_type == 'C' ) - { - alpha_real = (float)(((scomplex*)alpha)->real); - alpha_imag = (float)(((scomplex*)alpha)->imag); - } - else 
if(dt_type == 'z' || dt_type == 'Z' ) - { - alpha_real = ((dcomplex*)alpha)->real; - alpha_imag = ((dcomplex*)alpha)->imag; - } + DTL_get_complex_parts(dt_type, alpha, &alpha_real, &alpha_imag); + // { uploa, m, alpha_real, alpha_imag, incx, incy, lda} - sprintf(buffer, " %c %c %ld %lf %lf %ld %ld %ld", - dt_type, uploa, (dim_t)m, alpha_real, alpha_imag, (dim_t)incx, (dim_t)incy, (dim_t)lda); + sprintf(buffer, "%c %c %ld %lf %lf %ld %ld %ld\n", + dt_type, uploa, (dim_t)m, alpha_real, alpha_imag, (dim_t)incx, (dim_t)incy, (dim_t)lda); DTL_Trace(loglevel, TRACE_TYPE_LOG, function_name, function_name, line, buffer); } -void AOCL_DTL_log_syr2k_sizes(int8 loglevel, - char dt_type, - const f77_char uploc, - const f77_char transa, - const f77_int m, - const f77_int k, - const void* alpha, - const f77_int lda, - const f77_int ldb, - const void* beta, - const f77_int ldc, - const char* filename, - const char* function_name, - int line) +void AOCL_DTL_log_syr2k_sizes(int8 loglevel, + char dt_type, + const f77_char uploc, + const f77_char transa, + const f77_int m, + const f77_int k, + const void *alpha, + const f77_int lda, + const f77_int ldb, + const void *beta, + const f77_int ldc, + const char *filename, + const char *function_name, + int line) { char buffer[256]; double alpha_real = 0.0; @@ -962,78 +692,35 @@ void AOCL_DTL_log_syr2k_sizes(int8 loglevel, double beta_real = 0.0; double beta_imag = 0.0; - if(dt_type == 's' || dt_type == 'S' ) - { - alpha_real = *((float*)alpha); - alpha_imag = 0.0; - beta_real = *((float*)beta); - beta_imag = 0.0; - } - else if(dt_type == 'd' || dt_type == 'D' ) - { - alpha_real = *((double*)alpha); - alpha_imag = 0.0; - beta_real = *((double*)beta); - beta_imag = 0.0; - } - else if(dt_type == 'c' || dt_type == 'C' ) - { - alpha_real = (float)(((scomplex*)alpha)->real); - alpha_imag = (float)(((scomplex*)alpha)->imag); - beta_real = (float)(((scomplex*)beta)->real); - beta_imag = (float)(((scomplex*)beta)->imag); - } - else if(dt_type == 'z' || dt_type == 'Z' ) - { - alpha_real = ((dcomplex*)alpha)->real; - alpha_imag = ((dcomplex*)alpha)->imag; - beta_real = ((dcomplex*)beta)->real; - beta_imag = ((dcomplex*)beta)->imag; - } + DTL_get_complex_parts(dt_type, alpha, &alpha_real, &alpha_imag); + DTL_get_complex_parts(dt_type, beta, &beta_real, &beta_imag); + // { uploc, transa, m, k, alpha_real, alpha_imag, lda, ldb, beta_real, beta_imag, ldc} - sprintf(buffer, " %c %c %c %ld %ld %lf %lf %ld %ld %lf %lf %ld", - dt_type, uploc, transa, (dim_t)m, (dim_t)k, alpha_real, alpha_imag, (dim_t)lda, (dim_t)ldb, beta_real, beta_imag ,(dim_t)ldc); + sprintf(buffer, "%c %c %c %ld %ld %lf %lf %ld %ld %lf %lf %ld\n", + dt_type, uploc, transa, (dim_t)m, (dim_t)k, alpha_real, alpha_imag, (dim_t)lda, (dim_t)ldb, beta_real, beta_imag, (dim_t)ldc); DTL_Trace(loglevel, TRACE_TYPE_LOG, function_name, function_name, line, buffer); } -void AOCL_DTL_log_syr_sizes(int8 loglevel, - char dt_type, - const f77_char uploa, - const f77_int m, - const void* alpha, - const f77_int incx, - const f77_int lda, - const char* filename, - const char* function_name, - int line) +void AOCL_DTL_log_syr_sizes(int8 loglevel, + char dt_type, + const f77_char uploa, + const f77_int m, + const void *alpha, + const f77_int incx, + const f77_int lda, + const char *filename, + const char *function_name, + int line) { char buffer[256]; double alpha_real = 0.0; double alpha_imag = 0.0; - if(dt_type == 's' || dt_type == 'S' ) - { - alpha_real = *((float*)alpha); - alpha_imag = 0.0; - } - else if(dt_type == 'd' || 
dt_type == 'D' ) - { - alpha_real = *((double*)alpha); - alpha_imag = 0.0; - } - else if(dt_type == 'c' || dt_type == 'C' ) - { - alpha_real = (float)(((scomplex*)alpha)->real); - alpha_imag = (float)(((scomplex*)alpha)->imag); - } - else if(dt_type == 'z' || dt_type == 'Z' ) - { - alpha_real = ((dcomplex*)alpha)->real; - alpha_imag = ((dcomplex*)alpha)->imag; - } + DTL_get_complex_parts(dt_type, alpha, &alpha_real, &alpha_imag); + // {S, D,C, Z} { uploa, m, alpha_real, alpha_imag, incx, lda} - sprintf(buffer, " %c %c %ld %lf %lf %ld %ld", + sprintf(buffer, "%c %c %ld %lf %lf %ld %ld\n", dt_type, uploa, (dim_t)m, alpha_real, alpha_imag, (dim_t)incx, (dim_t)lda); DTL_Trace(loglevel, TRACE_TYPE_LOG, function_name, function_name, line, buffer); @@ -1043,15 +730,15 @@ void AOCL_DTL_log_syrk_sizes(int8 loglevel, char dt_type, const f77_char uploc, const f77_char transa, - const f77_int m, - const f77_int k, - const void* alpha, - const f77_int lda, - const void* beta, - const f77_int ldc, - const char* filename, - const char* function_name, - int line) + const f77_int m, + const f77_int k, + const void *alpha, + const f77_int lda, + const void *beta, + const f77_int ldc, + const char *filename, + const char *function_name, + int line) { char buffer[256]; double alpha_real = 0.0; @@ -1059,37 +746,12 @@ void AOCL_DTL_log_syrk_sizes(int8 loglevel, double beta_real = 0.0; double beta_imag = 0.0; - if(dt_type == 's' || dt_type == 'S' ) - { - alpha_real = *((float*)alpha); - alpha_imag = 0.0; - beta_real = *((float*)beta); - beta_imag = 0.0; - } - else if(dt_type == 'd' || dt_type == 'D' ) - { - alpha_real = *((double*)alpha); - alpha_imag = 0.0; - beta_real = *((double*)beta); - beta_imag = 0.0; - } - else if(dt_type == 'c' || dt_type == 'C' ) - { - alpha_real = (float)(((scomplex*)alpha)->real); - alpha_imag = (float)(((scomplex*)alpha)->imag); - beta_real = (float)(((scomplex*)beta)->real); - beta_imag = (float)(((scomplex*)beta)->imag); - } - else if(dt_type == 'z' || dt_type == 'Z' ) - { - alpha_real = ((dcomplex*)alpha)->real; - alpha_imag = ((dcomplex*)alpha)->imag; - beta_real = ((dcomplex*)beta)->real; - beta_imag = ((dcomplex*)beta)->imag; - } + DTL_get_complex_parts(dt_type, alpha, &alpha_real, &alpha_imag); + DTL_get_complex_parts(dt_type, beta, &beta_real, &beta_imag); + // {S, D,C, Z} { uploc, transa, m, k, alpha_real, alpha_imag, lda, beta_real, beta_imag, ldc} - sprintf(buffer, " %c %c %c %ld %ld %lf %lf %ld %lf %lf %ld", - dt_type, uploc, transa, (dim_t)m, (dim_t)k, alpha_real, alpha_imag, (dim_t)lda, beta_real, beta_imag, (dim_t)ldc); + sprintf(buffer, "%c %c %c %ld %ld %lf %lf %ld %lf %lf %ld\n", + dt_type, uploc, transa, (dim_t)m, (dim_t)k, alpha_real, alpha_imag, (dim_t)lda, beta_real, beta_imag, (dim_t)ldc); DTL_Trace(loglevel, TRACE_TYPE_LOG, function_name, function_name, line, buffer); } @@ -1100,42 +762,24 @@ void AOCL_DTL_log_trmm_sizes(int8 loglevel, const f77_char uploa, const f77_char transa, const f77_char diaga, - const f77_int m, - const f77_int n, - const void* alpha, - const f77_int lda, - const f77_int ldb, - const char* filename, - const char* function_name, - int line) + const f77_int m, + const f77_int n, + const void *alpha, + const f77_int lda, + const f77_int ldb, + const char *filename, + const char *function_name, + int line) { char buffer[256]; double alpha_real = 0.0; double alpha_imag = 0.0; - if(dt_type == 's' || dt_type == 'S' ) - { - alpha_real = *((float*)alpha); - alpha_imag = 0.0; - } - else if(dt_type == 'd' || dt_type == 'D' ) - { - alpha_real = 
*((double*)alpha); - alpha_imag = 0.0; - } - else if(dt_type == 'c' || dt_type == 'C' ) - { - alpha_real = (float)(((scomplex*)alpha)->real); - alpha_imag = (float)(((scomplex*)alpha)->imag); - } - else if(dt_type == 'z' || dt_type == 'Z' ) - { - alpha_real = ((dcomplex*)alpha)->real; - alpha_imag = ((dcomplex*)alpha)->imag; - } + DTL_get_complex_parts(dt_type, alpha, &alpha_real, &alpha_imag); + // {S, D,C, Z} { side, uploa, transa, diaga, m, n, alpha_real, alpha_imag, lda, ldb} - sprintf(buffer, " %c %c %c %c %c %ld %ld %lf %lf %ld %ld", - dt_type, side, uploa, transa, diaga, (dim_t)m, (dim_t)n, alpha_real, alpha_imag, (dim_t)lda, (dim_t)ldb); + sprintf(buffer, "%c %c %c %c %c %ld %ld %lf %lf %ld %ld\n", + dt_type, side, uploa, transa, diaga, (dim_t)m, (dim_t)n, alpha_real, alpha_imag, (dim_t)lda, (dim_t)ldb); DTL_Trace(loglevel, TRACE_TYPE_LOG, function_name, function_name, line, buffer); } @@ -1145,16 +789,16 @@ void AOCL_DTL_log_trmv_sizes(int8 loglevel, const f77_char uploa, const f77_char transa, const f77_char diaga, - const f77_int m, - const f77_int lda, - const f77_int incx, - const char* filename, - const char* function_name, - int line) + const f77_int m, + const f77_int lda, + const f77_int incx, + const char *filename, + const char *function_name, + int line) { char buffer[256]; // {S, D,C, Z} { side, uploa, transa, diaga, m, lda, incx} - sprintf(buffer, " %c %c %c %c %ld %ld %ld", + sprintf(buffer, "%c %c %c %c %ld %ld %ld\n", dt_type, uploa, transa, diaga, (dim_t)m, (dim_t)lda, (dim_t)incx); DTL_Trace(loglevel, TRACE_TYPE_LOG, function_name, function_name, line, buffer); @@ -1165,16 +809,16 @@ void AOCL_DTL_log_trsv_sizes(int8 loglevel, const f77_char uploa, const f77_char transa, const f77_char diaga, - const f77_int m, - const f77_int lda, - const f77_int incx, - const char* filename, - const char* function_name, - int line) + const f77_int m, + const f77_int lda, + const f77_int incx, + const char *filename, + const char *function_name, + int line) { char buffer[256]; // {S, D,C, Z} { side, uploa, transa, diaga, m, lda, incx} - sprintf(buffer, " %c %c %c %c %ld %ld %ld", + sprintf(buffer, "%c %c %c %c %ld %ld %ld\n", dt_type, uploa, transa, diaga, (dim_t)m, (dim_t)lda, (dim_t)incx); DTL_Trace(loglevel, TRACE_TYPE_LOG, function_name, function_name, line, buffer); diff --git a/aocl_dtl/aocldtl_blis.h b/aocl_dtl/aocldtl_blis.h index 8c995bc36a..a9ea3368f9 100755 --- a/aocl_dtl/aocldtl_blis.h +++ b/aocl_dtl/aocldtl_blis.h @@ -14,22 +14,29 @@ #include "blis.h" #if AOCL_DTL_LOG_ENABLE +dim_t AOCL_get_requested_threads_count(void); + void AOCL_DTL_log_gemm_sizes(int8 loglevel, - char dt, + char dt_type, const f77_char transa, const f77_char transb, const f77_int m, const f77_int n, const f77_int k, - const void* alpha, + const void *alpha, const f77_int lda, const f77_int ldb, - const void* beta, + const void *beta, const f77_int ldc, - const char* filename, - const char* functionn_name, + const char *filename, + const char *function_name, int line); +void AOCL_DTL_log_gemm_stats(int8 loglevel, + const f77_int m, + const f77_int n, + const f77_int k); + void AOCL_DTL_log_trsm_sizes(int8 loglevel, char dt, f77_char side, @@ -376,9 +383,13 @@ void AOCL_DTL_log_trmm_sizes(int8 loglevel, const char* function_name, int line); + #define AOCL_DTL_LOG_GEMM_INPUTS(loglevel, dt, transa, transb, m, n, k, alpha, lda, ldb, beta, ldc) \ AOCL_DTL_log_gemm_sizes(loglevel, dt, transa, transb, m, n, k, alpha, lda, ldb, beta, ldc, __FILE__, __FUNCTION__, __LINE__); +#define 
AOCL_DTL_LOG_GEMM_STATS(loglevel, m, n, k) \ + AOCL_DTL_log_gemm_stats(loglevel, m, n, k); + #define AOCL_DTL_LOG_TRSM_INPUTS(loglevel, dt, side, uploa, transa, diaga, m, n, alpha, lda, ldb) \ AOCL_DTL_log_trsm_sizes(loglevel, dt, side, uploa, transa, diaga, m, n, alpha, lda, ldb, __FILE__, __FUNCTION__, __LINE__); @@ -487,6 +498,8 @@ void AOCL_DTL_log_trmm_sizes(int8 loglevel, #define AOCL_DTL_LOG_GEMM_INPUTS(loglevel, dt, transa, transb, m, n, k, alpha, lda, ldb, beta, ldc) +#define AOCL_DTL_LOG_GEMM_STATS(loglevel, m, n, k) + #define AOCL_DTL_LOG_TRSM_INPUTS(loglevel, dt, side, uploa, transa, diaga, m, n, alpha, lda, ldb) #define AOCL_DTL_LOG_GEMMT_INPUTS(loglevel, dt, uplo, transa, transb, n, k, alpha, lda, ldb, beta, ldc) diff --git a/aocl_dtl/aoclflist.c b/aocl_dtl/aoclflist.c index 5bba38fb50..5d44fdba87 100644 --- a/aocl_dtl/aoclflist.c +++ b/aocl_dtl/aoclflist.c @@ -1,12 +1,12 @@ /*=================================================================== * File Name : aoclflist.c - * - * Description : Linked list of open files assocaited with + * + * Description : Linked list of open files assocaited with * each thread. This is used to log the data * to correct file as per the current thread id. * * Copyright (C) 2020, Advanced Micro Devices, Inc - * + * *==================================================================*/ #include "aocltpdef.h" @@ -16,7 +16,7 @@ #include "aoclos.h" -/* Disable instrumentation for following function, since they are called from +/* Disable instrumentation for following function, since they are called from * Auto Generated execution trace handlers. */ Bool AOCL_FLIST_IsEmpty( AOCL_FLIST_Node *plist) __attribute__((no_instrument_function)); @@ -45,6 +45,35 @@ Bool AOCL_FLIST_IsEmpty(AOCL_FLIST_Node *plist) } /* AOCL_FLIST_IsEmpty */ +AOCL_FLIST_Node * AOCL_FLIST_GetNode(AOCL_FLIST_Node *plist, AOCL_TID tid) +{ + AOCL_FLIST_Node *temp; + + if (AOCL_FLIST_IsEmpty(plist) == 1) + { + return NULL; + } + + temp = plist; + + /* if list is not empty search for the file handle in all nodes */ + while (temp != NULL) + { + if (temp->tid == tid) + { + if (temp->fp == NULL) + { + AOCL_DEBUGPRINT("Could not get saved time stamp for thread = %d", tid); + } + return temp; + } + temp = temp->pNext; + } + + return NULL; + +} /* AOCL_FLIST_GetNode */ + AOCL_FAL_FILE *AOCL_FLIST_GetFile(AOCL_FLIST_Node *plist, AOCL_TID tid) { AOCL_FLIST_Node *temp; @@ -89,7 +118,7 @@ AOCL_FAL_FILE *AOCL_FLIST_AddFile(const int8 *pchFilePrefix, AOCL_FLIST_Node **p } /* We don't have exiting file, lets try to open new one */ - sprintf(pchFileName, "P%d_T%d_%s", AOCL_getpid(), tid, pchFilePrefix); + sprintf(pchFileName, "P%d_T%u_%s", AOCL_getpid(), tid, pchFilePrefix); file = AOCL_FAL_Open(pchFileName, "wb"); if (file == NULL) @@ -108,6 +137,7 @@ AOCL_FAL_FILE *AOCL_FLIST_AddFile(const int8 *pchFilePrefix, AOCL_FLIST_Node **p newNode->pNext = NULL; newNode->tid = tid; + newNode->u64SavedTimeStamp = AOCL_getTimestamp(); newNode->fp = file; if (AOCL_FLIST_IsEmpty(*plist) == 1) diff --git a/aocl_dtl/aoclflist.h b/aocl_dtl/aoclflist.h index 849713bb0d..a4e45ca328 100644 --- a/aocl_dtl/aoclflist.h +++ b/aocl_dtl/aoclflist.h @@ -1,12 +1,12 @@ /*=================================================================== * File Name : aoclflist.h - * - * Description : Linked list of open files assocaited with + * + * Description : Linked list of open files assocaited with * each thread. This is used to log the deta * to correct file as per the current thread id. 
* * Copyright (C) 2020, Advanced Micro Devices, Inc - * + * *==================================================================*/ #ifndef _AOCL_FLIST_H_ @@ -19,12 +19,17 @@ typedef struct AOCL_FLIST_Node_t { AOCL_TID tid; AOCL_FAL_FILE *fp; + uint64 u64SavedTimeStamp; struct AOCL_FLIST_Node_t *pNext; } AOCL_FLIST_Node; Bool AOCL_FLIST_IsEmpty( AOCL_FLIST_Node *plist); +AOCL_FLIST_Node * AOCL_FLIST_GetNode( + AOCL_FLIST_Node *plist, + AOCL_TID tid); + AOCL_FAL_FILE *AOCL_FLIST_GetFile( AOCL_FLIST_Node *plist, AOCL_TID tid); diff --git a/aocl_dtl/aoclos.c b/aocl_dtl/aoclos.c index 0e554eb952..92a489564e 100644 --- a/aocl_dtl/aoclos.c +++ b/aocl_dtl/aoclos.c @@ -19,7 +19,7 @@ #include #endif -// BLIS TODO: This is workaround to check if BLIS is built with +// BLIS TODO: This is workaround to check if BLIS is built with // openmp support. Ideally we dont' want any library // specific code in dtl. #include @@ -36,19 +36,23 @@ */ -uint32 AOCL_gettid(void) __attribute__((no_instrument_function)); +AOCL_TID AOCL_gettid(void) __attribute__((no_instrument_function)); pid_t AOCL_getpid(void) __attribute__((no_instrument_function)); uint64 AOCL_getTimestamp(void) __attribute__((no_instrument_function)); -uint32 AOCL_gettid(void) +AOCL_TID AOCL_gettid(void) { #ifdef BLIS_ENABLE_OPENMP return omp_get_thread_num(); #else - return 0; // will not work for pthread-based parallelization - +#ifdef BLIS_ENABLE_PTHREADS + return pthread_self(); +#else + return 0; #endif +#endif + } pid_t AOCL_getpid(void) @@ -63,7 +67,7 @@ uint64 AOCL_getTimestamp(void) /* The C11 way */ if (clock_gettime(CLOCK_REALTIME, &tms)) { - return -1; + return -1; } /* seconds, multiplied with 1 million */ @@ -73,13 +77,13 @@ uint64 AOCL_getTimestamp(void) /* round up if necessary */ if (tms.tv_nsec % 1000 >= 500) { - ++micros; + ++micros; } return micros; } #else /* Non linux support */ -uint32 AOCL_gettid(void) +AOCL_TID AOCL_gettid(void) { /* stub for other os's */ return 0; diff --git a/aocl_dtl/aocltpdef.h b/aocl_dtl/aocltpdef.h index 896731c584..7c08455369 100644 --- a/aocl_dtl/aocltpdef.h +++ b/aocl_dtl/aocltpdef.h @@ -1,11 +1,11 @@ /*=================================================================== * File Name : aocltpdef.h - * + * * Description : Abstraction for various datatypes used by DTL. * - * Copyright (C) 2020, Advanced Micro Devices, Inc - * + * Copyright (C) 2020-2021, Advanced Micro Devices, Inc. All rights reserved. + * *==================================================================*/ #ifndef AOCL_TYPEDEF_H_ #define AOCL_TYPEDEF_H_ diff --git a/bench/bench_gemm.c b/bench/bench_gemm.c index ecb22c432f..8258b61d18 100755 --- a/bench/bench_gemm.c +++ b/bench/bench_gemm.c @@ -57,15 +57,15 @@ #define AOCL_MATRIX_INITIALISATION - +#define BUFFER_SIZE 256 /* For BLIS since logs are collected at BLAS interfaces * we disable cblas interfaces for this benchmark application */ -#ifdef BLIS_ENABLE_CBLAS -//#define CBLAS -#endif +#ifdef BLIS_ENABLE_CBLAS +//#define CBLAS +#endif int main( int argc, char** argv ) { @@ -110,26 +110,36 @@ int main( int argc, char** argv ) exit(1); } - fprintf(fout, "Dt m\t n\t k\t lda\t ldb\t ldc\t rs_a rs_b rs_c transa transb \ - alphaR\t alphaI\t betaR\t betaI\t gflops\n"); + fprintf(fout, "Dt transa transb m n k alphaR alphaI lda ldb betaR betaI ldc gflops\n"); + + // Following variables are needed for scanf to read inputs properly + // however they are not used in bench. 
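/*
 * Illustration (not part of the patch): each GEMM call now produces one log
 * line combining the input log and the stats log, e.g. from bench/inputgemm.txt:
 *
 *   dgemm_ D N N 1000 3000 2000 0.900000 0.000000 4000 5000 -1.100000 0.000000 6000 nt=4 1542.854 ms 7.778 GFLOPS
 *
 * The bench consumes the fields up to ldc and discards the trailing
 * "nt=... ms ... GFLOPS" statistics.  A self-contained sketch of that parsing
 * step is shown here (the real code below uses fscanf on the input file
 * followed by fgets to drop the remainder of the line):
 */
#include <stdio.h>

static int parse_gemm_log_line_sketch( const char* line )
{
    char   api[64], dt, transa, transb;
    long   m, n, k, lda, ldb, ldc;
    double alpha_r, alpha_i, beta_r, beta_i;

    int nfields = sscanf( line, "%63s %c %c %c %ld %ld %ld %lf %lf %ld %ld %lf %lf %ld",
                          api, &dt, &transa, &transb, &m, &n, &k,
                          &alpha_r, &alpha_i, &lda, &ldb, &beta_r, &beta_i, &ldc );

    return ( nfields == 14 );   /* per-call stats at the end are ignored */
}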
+ char api_name[BUFFER_SIZE]; // to store function name, line no present in logs + char dummy_buffer[BUFFER_SIZE]; + // Variables extracted from the logs which are used by bench char stor_scheme, transA_c, transB_c; double alpha_r, beta_r, alpha_i, beta_i; dim_t m_trans, n_trans; - char tmp[256]; // to store function name, line no present in logs. - dim_t rs_a, rs_b, rs_c; - dim_t cs_a, cs_b, cs_c; inc_t lda, ldb, ldc; - stor_scheme = 'C'; // since logs are collected at BLAS APIs + stor_scheme = 'C'; // By default set it to Column Major - while (fscanf(fin, "%s %c %ld %ld %ld %ld %ld %ld %ld %ld %ld %c %c %lf %lf %lf %lf\n", - tmp, &dt_ch, &m, &n, &k, &cs_a, &cs_b, &cs_c, &rs_a, &rs_b, &rs_c, - &transA_c, &transB_c, &alpha_r, &alpha_i, &beta_r, &beta_i) == 17) + //{S, D, C, Z} transa, transb, m, n, k, alpha_real, alpha_imag, lda ldb + // beta_real, beta_imag, ldc, + // + // number of threads, execution time, gflops ---> ignored by bench + + while (fscanf(fin, "%s %c %c %c %ld %ld %ld %lf %lf %ld %ld %lf %lf %ld[^\n]", + api_name, &dt_ch, &transA_c, &transB_c, &m, &n, &k, &alpha_r, &alpha_i, + &lda, &ldb, &beta_r, &beta_i, &ldc) == 14) { - if(cs_a==1 && cs_b==1 && cs_c==1) stor_scheme = 'R'; - if(rs_a==1 && rs_b==1 && rs_c==1) stor_scheme = 'C'; + // Discard any extra data on current line in the input file. + fgets(dummy_buffer, BUFFER_SIZE, fin ); + + // At BLAS level only column major order is supported. + stor_scheme = 'C'; if (dt_ch == 'D' || dt_ch == 'd') dt = BLIS_DOUBLE; else if (dt_ch == 'Z' || dt_ch == 'z') dt = BLIS_DCOMPLEX; @@ -164,10 +174,7 @@ int main( int argc, char** argv ) if( (stor_scheme == 'C') || (stor_scheme == 'c') ) { - // Column storage - lda = cs_a; ldb = cs_b; ldc = cs_c; - - // leading dimension should be greater than number of rows + // leading dimension should be greater than number of rows // if ((m > lda) || (k > ldb) || (m > ldc)) continue; // Since this bench app is run on logs generated by AOCL trace logs // - we have relaxed the checks on the input parameters. @@ -190,14 +197,12 @@ int main( int argc, char** argv ) } else if( (stor_scheme == 'r') || (stor_scheme == 'R') ) { - // Row-major order - lda = rs_a; ldb = rs_b; ldc = rs_c; //leading dimension should be greater than number of columns //if ((k > lda) || (n > ldb) || (n > ldc)) continue; // Since this bench app is run on logs generated by AOCL trace logs // - we have relaxed the checks on the input parameters. 
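/*
 * Sketch (not in the patch): the strict row-major leading-dimension checks
 * that the comments just below spell out, and that this bench deliberately
 * relaxes because its inputs come from AOCL trace logs.  The helper name is
 * hypothetical; dim_t/inc_t are the usual blis.h types and bool assumes
 * <stdbool.h>.
 */
static bool row_major_ld_ok_sketch( char transa, char transb,
                                    dim_t m, dim_t n, dim_t k,
                                    inc_t lda, inc_t ldb, inc_t ldc )
{
    /* op(A) is m x k, op(B) is k x n, C is m x n (row-major storage). */
    dim_t a_cols = ( transa == 'n' || transa == 'N' ) ? k : m;
    dim_t b_cols = ( transb == 'n' || transb == 'N' ) ? n : k;

    return ( lda >= ( a_cols > 1 ? a_cols : 1 ) ) &&
           ( ldb >= ( b_cols > 1 ? b_cols : 1 ) ) &&
           ( ldc >= ( n      > 1 ? n      : 1 ) );
}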
- // if A is transpose - A(k x lda), lda >= max(1,m) + // if A is transpose - A(k x lda), lda >= max(1,m) // if A is non-transpose - A (m x lda), lda >= max(1,k) // if B is transpose - B (n x ldb), ldb >= max(1,k) // if B is non-transpose - B (k x ldb ), ldb >= max(1,n) @@ -228,7 +233,7 @@ int main( int argc, char** argv ) } #endif #endif - + #ifdef AOCL_MATRIX_INITIALISATION bli_randm( &a ); bli_randm( &b ); @@ -474,9 +479,8 @@ int main( int argc, char** argv ) (unsigned long)n, (unsigned long)k, gflops); - fprintf (fout, "%c %ld\t %ld\t %ld\t %ld\t %ld\t %ld\t %ld %ld %ld %c %c %lf\t %lf\t %lf\t %lf\t %6.3f\n", \ - dt_ch, m, n, k, lda, ldb, ldc, rs_a, rs_b, rs_c, \ - transA_c, transB_c, alpha_r, alpha_i, beta_r, beta_i, gflops); + fprintf (fout, "%c %c %c %ld %ld %ld %lf %lf %ld %ld %lf %lf %ld %6.3f\n", \ + dt_ch, transA_c, transB_c, m, n, k, alpha_r, alpha_i, lda, ldb, beta_r, beta_i, ldc, gflops); fflush(fout); diff --git a/bench/inputgemm.txt b/bench/inputgemm.txt index 5274b7735f..3334636b0b 100644 --- a/bench/inputgemm.txt +++ b/bench/inputgemm.txt @@ -1,18 +1,32 @@ - bli_gemm_ex:125: D 173 23 1 173 174 174 1 1 1 t n -1.000000 0.000000 1.000000 0.000000 - bli_gemm_ex:125: D 173 23 1 1 1 1 1 23 23 t n -1.000000 0.000000 1.000000 0.000000 - bli_gemm_ex:125: D 173 23 1 1 1 1 1 23 23 n t -1.000000 0.000000 1.000000 0.000000 - bli_gemm_ex:125: D 83 23 1 83 84 84 1 1 1 n n -1.000000 0.000000 1.000000 0.000000 - bli_gemm_ex:125: D 41 2 1 41 42 42 1 1 1 n n -1.000000 0.000000 1.000000 0.000000 - bli_gemm_ex:125: D 77 8 1 77 78 78 1 1 1 n t -1.000000 0.000000 1.000000 0.000000 - bli_gemm_ex:125: D 77 8 1 77 78 78 1 1 1 n n -2.000000 0.000000 3.000000 0.000000 - bli_gemm_ex:125: D 41 5 1 41 42 42 1 1 1 n n -1.000000 0.000000 1.000000 0.000000 - bli_gemm_ex:125: D 41 5 1 41 42 42 1 1 1 t n -1.000000 0.000000 1.000000 0.000000 - bli_gemm_ex:125: D 65 8 1 65 66 66 1 1 1 n n -3.000000 0.000000 1.000000 0.000000 - bli_gemm_ex:125: D 53 8 1 53 54 54 1 1 1 n n -1.000000 0.000000 1.000000 0.000000 - bli_gemm_ex:125: D 68 8 1 68 69 69 1 1 1 n n -1.000000 0.000000 1.000000 0.000000 - bli_gemm_ex:125: D 41 5 1 41 42 42 1 1 1 n t -1.000000 0.000000 2.000000 0.000000 - bli_gemm_ex:125: D 41 5 1 41 42 42 1 1 1 n n -1.000000 0.000000 1.000000 0.000000 - bli_gemm_ex:125: D 53 5 1 53 54 54 1 1 1 n n -1.000000 0.000000 1.000000 0.000000 - bli_gemm_ex:125: D 95 14 1 95 96 96 1 1 1 t n -1.000000 0.000000 1.000000 0.000000 - bli_gemm_ex:125: D 110 17 1 1 1 1 1 17 17 n n -1.000000 0.000000 1.000000 0.000000 - bli_gemm_ex:125: D 95 14 1 95 96 96 1 1 1 n n -1.000000 0.000000 1.000000 0.000000 +dgemm_ D N N 1000 3000 2000 0.900000 0.000000 4000 5000 -1.100000 0.000000 6000 nt=4 1542.854 ms 7.778 GFLOPS +dgemm_ D N N 100 100 100 0.900000 0.000000 104 104 -1.100000 0.000000 104 nt=4 0.307 ms 6.515 GFLOPS +dgemm_ D N N 500 500 500 0.900000 0.000000 504 504 -1.100000 0.000000 504 nt=4 32.442 ms 7.706 GFLOPS +dgemm_ D N N 900 900 900 0.900000 0.000000 904 904 -1.100000 0.000000 904 nt=4 172.170 ms 8.468 GFLOPS +dgemm_ D N N 1300 1300 1300 0.900000 0.000000 1304 1304 -1.100000 0.000000 1304 nt=4 655.381 ms 6.704 GFLOPS +dgemm_ D N T 1700 1700 1700 0.900000 0.000000 1704 1704 -1.100000 0.000000 1704 nt=4 1302.928 ms 7.541 GFLOPS +dgemm_ D T N 2100 2100 2100 0.900000 0.000000 2104 2104 -1.100000 0.000000 2104 nt=4 3278.541 ms 5.649 GFLOPS +dgemm_ D T T 2500 2500 2500 0.900000 0.000000 2504 2504 -1.100000 0.000000 2504 nt=4 5292.842 ms 5.904 GFLOPS +zgemm_ Z N N 1000 3000 2000 0.900000 0.000000 4000 5000 -1.100000 0.000000 
6000 nt=4 300.940 ms 159.500 GFLOPS +zgemm_ Z N N 100 100 100 0.900000 0.000000 104 104 -1.100000 0.000000 104 nt=4 0.748 ms 10.695 GFLOPS +zgemm_ Z N N 500 500 500 0.900000 0.000000 504 504 -1.100000 0.000000 504 nt=4 8.618 ms 116.036 GFLOPS +zgemm_ Z N N 900 900 900 0.900000 0.000000 904 904 -1.100000 0.000000 904 nt=4 42.717 ms 136.526 GFLOPS +zgemm_ Z N N 1300 1300 1300 0.900000 0.000000 1304 1304 -1.100000 0.000000 1304 nt=4 124.652 ms 141.001 GFLOPS +zgemm_ Z N T 1700 1700 1700 0.900000 0.000000 1704 1704 -1.100000 0.000000 1704 nt=4 277.029 ms 141.877 GFLOPS +zgemm_ Z T N 2100 2100 2100 0.900000 0.000000 2104 2104 -1.100000 0.000000 2104 nt=4 494.360 ms 149.866 GFLOPS +zgemm_ Z T T 2500 2500 2500 0.900000 0.000000 2504 2504 -1.100000 0.000000 2504 nt=4 803.699 ms 155.531 GFLOPS +cgemm_ C N N 1000 3000 2000 0.900000 0.000000 4000 5000 -1.100000 0.000000 6000 nt=4 135.321 ms 354.712 GFLOPS +cgemm_ C N N 100 100 100 0.900000 0.000000 104 104 -1.100000 0.000000 104 nt=4 0.429 ms 18.648 GFLOPS +cgemm_ C N N 500 500 500 0.900000 0.000000 504 504 -1.100000 0.000000 504 nt=4 5.045 ms 198.216 GFLOPS +cgemm_ C N N 900 900 900 0.900000 0.000000 904 904 -1.100000 0.000000 904 nt=4 20.003 ms 291.556 GFLOPS +cgemm_ C N N 1300 1300 1300 0.900000 0.000000 1304 1304 -1.100000 0.000000 1304 nt=4 56.253 ms 312.446 GFLOPS +cgemm_ C N T 1700 1700 1700 0.900000 0.000000 1704 1704 -1.100000 0.000000 1704 nt=4 116.948 ms 336.081 GFLOPS +cgemm_ C T N 2100 2100 2100 0.900000 0.000000 2104 2104 -1.100000 0.000000 2104 nt=4 207.581 ms 356.911 GFLOPS +cgemm_ C T T 2500 2500 2500 0.900000 0.000000 2504 2504 -1.100000 0.000000 2504 nt=4 346.031 ms 361.239 GFLOPS +sgemm_ S N N 1000 3000 2000 0.900000 0.000000 4000 5000 -1.100000 0.000000 6000 nt=4 1024.360 ms 11.715 GFLOPS +sgemm_ S N N 100 100 100 0.900000 0.000000 104 104 -1.100000 0.000000 104 nt=4 0.362 ms 5.525 GFLOPS +sgemm_ S N N 500 500 500 0.900000 0.000000 504 504 -1.100000 0.000000 504 nt=4 1.688 ms 148.104 GFLOPS +sgemm_ S N N 900 900 900 0.900000 0.000000 904 904 -1.100000 0.000000 904 nt=4 147.791 ms 9.865 GFLOPS +sgemm_ S N N 1300 1300 1300 0.900000 0.000000 1304 1304 -1.100000 0.000000 1304 nt=4 451.156 ms 9.739 GFLOPS +sgemm_ S N T 1700 1700 1700 0.900000 0.000000 1704 1704 -1.100000 0.000000 1704 nt=4 873.577 ms 11.248 GFLOPS +sgemm_ S T N 2100 2100 2100 0.900000 0.000000 2104 2104 -1.100000 0.000000 2104 nt=4 1699.278 ms 10.900 GFLOPS +sgemm_ S T T 2500 2500 2500 0.900000 0.000000 2504 2504 -1.100000 0.000000 2504 nt=4 2651.917 ms 11.784 GFLOPS diff --git a/bench/outTemp.txt b/bench/outTemp.txt deleted file mode 100644 index ba29f25000..0000000000 --- a/bench/outTemp.txt +++ /dev/null @@ -1,21 +0,0 @@ -Dt n incx incy gflops -isamax_:183: S 100 1 29 0.043 -isamax_:183: S 200 1 65 0.065 -isamax_:183: S 300 1 185 0.078 -isamax_:183: S 400 1 86 0.261 -isamax_:183: S 500 1 271 0.279 -idamax_:183: D 100 1 64 0.099 -idamax_:183: D 200 1 175 0.131 -idamax_:183: D 300 1 102 0.148 -idamax_:183: D 400 1 249 0.157 -idamax_:183: D 500 1 197 0.165 -icamax_:183: C 100 1 1 0.185 -icamax_:183: C 200 1 108 0.242 -icamax_:183: C 300 1 76 0.271 -icamax_:183: C 400 1 178 0.283 -icamax_:183: C 500 1 403 0.304 -izamax_:183: Z 100 1 51 0.178 -izamax_:183: Z 200 1 175 0.232 -izamax_:183: Z 300 1 240 0.260 -izamax_:183: Z 400 1 108 0.293 -izamax_:183: Z 500 1 411 0.294 diff --git a/frame/3/bli_l3_sup.c b/frame/3/bli_l3_sup.c index 6944c465cc..163a828f86 100644 --- a/frame/3/bli_l3_sup.c +++ b/frame/3/bli_l3_sup.c @@ -46,7 +46,6 @@ err_t bli_gemmsup ) { 
AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_2); -// AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_2, alpha, a, b, beta, c); // Return early if small matrix handling is disabled at configure-time. #ifdef BLIS_DISABLE_SUP_HANDLING diff --git a/frame/3/bli_l3_sup_ref.c b/frame/3/bli_l3_sup_ref.c index 885467979e..b140cc5f1a 100644 --- a/frame/3/bli_l3_sup_ref.c +++ b/frame/3/bli_l3_sup_ref.c @@ -46,7 +46,6 @@ err_t bli_gemmsup_ref ) { AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_3); -// AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_3, alpha, a, b, beta, c); // This function implements the default gemmsup handler. If you are a // BLIS developer and wish to use a different gemmsup handler, please // register a different function pointer in the context in your diff --git a/frame/3/gemm/bli_gemm_front.c b/frame/3/gemm/bli_gemm_front.c index 21c0353710..662a6da9bb 100644 --- a/frame/3/gemm/bli_gemm_front.c +++ b/frame/3/gemm/bli_gemm_front.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2020, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2021, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -48,7 +48,6 @@ void bli_gemm_front ) { AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_3); -// AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_3, alpha, a, b, beta, c); bli_init_once(); obj_t a_local; diff --git a/frame/3/gemm/bli_gemm_int.c b/frame/3/gemm/bli_gemm_int.c index 76ee08bfb6..405c74d76b 100644 --- a/frame/3/gemm/bli_gemm_int.c +++ b/frame/3/gemm/bli_gemm_int.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2021, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -54,8 +54,7 @@ void bli_gemm_int gemm_var_oft f; AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_4); -// AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_4, alpha, a, b, beta, c); - + // Check parameters. if ( bli_error_checking_is_enabled() ) bli_gemm_basic_check( alpha, a, b, beta, c, cntx ); diff --git a/frame/compat/bla_gemm.c b/frame/compat/bla_gemm.c index faa7459a49..1bdb2397b2 100644 --- a/frame/compat/bla_gemm.c +++ b/frame/compat/bla_gemm.c @@ -65,9 +65,12 @@ void PASTEF77(ch,blasname) \ inc_t rs_b, cs_b; \ inc_t rs_c, cs_c; \ \ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); \ /* Initialize BLIS. */ \ bli_init_auto(); \ +\ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); \ + AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(ch), *transa, *transb, *m, *n, *k, \ + (void*)alpha, *lda, *ldb, (void*)beta, *ldc); \ \ /* Perform BLAS parameter checking. */ \ PASTEBLACHK(blasname) \ @@ -118,6 +121,7 @@ void PASTEF77(ch,blasname) \ NULL \ ); \ \ + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ /* Finalize BLIS. 
*/ \ bli_finalize_auto(); \ @@ -142,18 +146,20 @@ void PASTEF77(ch,blasname) \ ftype* c, const f77_int* ldc \ ) \ { \ - AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(ch), *transa, *transb, *m, *n, *k, (void*)alpha, *lda, *ldb, (void*)beta, *ldc); \ \ trans_t blis_transa; \ trans_t blis_transb; \ dim_t m0, n0, k0; \ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO) \ \ dim_t m0_a, n0_a; \ dim_t m0_b, n0_b; \ \ /* Initialize BLIS. */ \ bli_init_auto(); \ +\ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); \ + AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(ch), *transa, *transb, *m, *n, *k, \ + (void*)alpha, *lda, *ldb, (void*)beta, *ldc); \ \ /* Perform BLAS parameter checking. */ \ PASTEBLACHK(blasname) \ @@ -217,6 +223,7 @@ void PASTEF77(ch,blasname) \ NULL \ ); \ } \ + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); \ return; \ } \ else if( m0 == 1 ) \ @@ -249,6 +256,7 @@ void PASTEF77(ch,blasname) \ NULL \ ); \ } \ + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); \ return; \ } \ \ @@ -284,7 +292,8 @@ void PASTEF77(ch,blasname) \ NULL \ ); \ \ - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) \ + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ /* Finalize BLIS. */ \ bli_finalize_auto(); \ } @@ -306,15 +315,19 @@ void dgemm_ double* c, const f77_int* ldc ) { - AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'D', *transa, *transb, *m, *n, *k, (void*)alpha, *lda, *ldb, (void*)beta, *ldc); + + trans_t blis_transa; trans_t blis_transb; dim_t m0, n0, k0; - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO) - /* Initialize BLIS. */ - bli_init_auto(); + /* Initialize BLIS. */ + bli_init_auto(); + + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) + AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(d), *transa, *transb, *m, *n, *k, \ + (void*)alpha, *lda, *ldb, (void*)beta, *ldc); /* Perform BLAS parameter checking. */ PASTEBLACHK(gemm) @@ -358,7 +371,8 @@ void dgemm_ (double*)beta, c, *ldc ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); /* Finalize BLIS */ bli_finalize_auto(); @@ -395,6 +409,9 @@ void dgemm_ ((void*)0) ); } + + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + return; } else if (m0 == 1) @@ -427,6 +444,7 @@ void dgemm_ ((void*)0) ); } + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); return; } @@ -478,8 +496,9 @@ void dgemm_ NULL, NULL ); + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); /* Finalize BLIS. */ bli_finalize_auto(); return; @@ -519,7 +538,8 @@ void dgemm_ if (status == BLIS_SUCCESS) { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); /* Finalize BLIS. */ bli_finalize_auto(); @@ -532,7 +552,8 @@ void dgemm_ err_t status = bli_gemmsup(&alphao, &ao, &bo, &betao, &co, NULL, NULL); if (status == BLIS_SUCCESS) { - return; + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + return; } // fall back on native path when dgemm is not handled in sup path. @@ -550,7 +571,8 @@ void dgemm_ /* NULL */ /* ); */ - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); /* Finalize BLIS. 
*/ bli_finalize_auto(); } // end of dgemm_ @@ -569,15 +591,16 @@ void zgemm_ dcomplex* c, const f77_int* ldc ) { - AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'Z', *transa, *transb, *m, *n, *k, (void*)alpha, *lda, *ldb, (void*)beta, *ldc); - trans_t blis_transa; trans_t blis_transb; dim_t m0, n0, k0; - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO) - /* Initialize BLIS. */ - bli_init_auto(); + /* Initialize BLIS. */ + bli_init_auto(); + + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) + AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(z), *transa, *transb, *m, *n, *k, + (void*)alpha, *lda, *ldb, (void*)beta, *ldc); /* Perform BLAS parameter checking. */ PASTEBLACHK(gemm) @@ -655,11 +678,12 @@ void zgemm_ NULL ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - /* Finalize BLIS. */ - bli_finalize_auto(); - return; - } + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS. */ + bli_finalize_auto(); + return; + } // The code below will be called when number of threads = 1. #if ENABLE_INDUCED_METHOD @@ -686,7 +710,8 @@ void zgemm_ //sqp algo is found better for n > 40 if(bli_gemm_sqp(&alphao, &ao, &bo, &betao, &co, NULL, NULL)==BLIS_SUCCESS) { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) return; } } @@ -699,17 +724,20 @@ void zgemm_ err_t status = bli_gemmsup(&alphao, &ao, &bo, &betao, &co, NULL, NULL); if(status==BLIS_SUCCESS) { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) return; } + } // fall back on native path when zgemm is not handled in sup path. bli_gemmnat(&alphao, &ao, &bo, &betao, &co, NULL, NULL); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) return; - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) /* Finalize BLIS. */ bli_finalize_auto(); }// end of zgemm_ @@ -738,15 +766,16 @@ void dzgemm_ ) { - AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'Z', *transa, *transb, *m, *n, *k, (void*)alpha, *lda, *ldb, (void*)beta, *ldc); - trans_t blis_transa; trans_t blis_transb; dim_t m0, n0, k0; - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO) - /* Initialize BLIS. */ - bli_init_auto(); + /* Initialize BLIS. */ + bli_init_auto(); + + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) + AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(z), *transa, *transb, *m, *n, *k, + (void*)alpha, *lda, *ldb, (void*)beta, *ldc); /* Perform BLAS parameter checking. */ PASTEBLACHK(gemm) @@ -808,7 +837,8 @@ void dzgemm_ // fall back on native path when zgemm is not handled in sup path. bli_gemmnat(&alphao, &ao, &bo, &betao, &co, NULL, NULL); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) /* Finalize BLIS. */ bli_finalize_auto(); }// end of dzgemm_ From a3d04a21a0a76d744ccc08b3cfedc88ffb30abdf Mon Sep 17 00:00:00 2001 From: Nageshwar Singh Date: Tue, 31 Aug 2021 11:07:13 +0530 Subject: [PATCH 032/243] Complex double standalone gemv implementation independent of axpyf. Details - For axpyf implementation there are function(axpyf) calling overhead. - New implementations reduces function calling overhead. - This implementation uses kernel of size 4x4. 
- This implementation gives better performance for smaller sizes when compared to axpyf based implementation AMD-Internal: [CPUPL-1402] Change-Id: I5fa421b8c1d2b44c991c2a05e8f5b01b83eb4b37 --- frame/2/bli_l2_ker_prot.h | 8 +- frame/2/gemv/bli_gemv_unf_var2.c | 76 ++++++--- frame/compat/bla_scal.c | 16 +- kernels/zen/2/bli_gemv_zen_int_4.c | 260 +++++++++++++++++++++++++++++ kernels/zen/2/bli_gemv_zen_ref.c | 12 +- kernels/zen/bli_kernels_zen.h | 1 + 6 files changed, 332 insertions(+), 41 deletions(-) create mode 100644 kernels/zen/2/bli_gemv_zen_int_4.c diff --git a/frame/2/bli_l2_ker_prot.h b/frame/2/bli_l2_ker_prot.h index 15888a55f2..82febd761f 100644 --- a/frame/2/bli_l2_ker_prot.h +++ b/frame/2/bli_l2_ker_prot.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020, Advanced Micro Devices, Inc. + Copyright (C) 2020-21, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -35,17 +35,19 @@ // -// Define template prototypes for level-1f kernels. +// Define template prototypes for level-2 kernels. // #define GEMV_KER_PROT( ctype, ch, opname ) \ \ void PASTEMAC(ch,opname) \ ( \ + conj_t conja,\ + conj_t conjx,\ dim_t m, \ dim_t n, \ ctype* restrict alpha, \ - ctype* restrict a, inc_t lda, \ + ctype* restrict a, inc_t rs, inc_t cs, \ ctype* restrict x, inc_t incx, \ ctype* restrict beta, \ ctype* restrict y, inc_t incy, \ diff --git a/frame/2/gemv/bli_gemv_unf_var2.c b/frame/2/gemv/bli_gemv_unf_var2.c index cb77e30735..34c11f758b 100644 --- a/frame/2/gemv/bli_gemv_unf_var2.c +++ b/frame/2/gemv/bli_gemv_unf_var2.c @@ -394,15 +394,14 @@ void bli_zgemv_unf_var2 /* If beta is zero, use setv. Otherwise, scale by beta. */ /* y = beta * y; */ /* beta=0 case is hadled by scalv internally */ - - bli_zscalv_ex + bli_zscalv_zen_int10 ( BLIS_NO_CONJUGATE, n_elem, beta, - y, incy, - cntx, - NULL + y, + incy, + cntx ); if( bli_zeq0( *alpha ) ) @@ -411,30 +410,57 @@ void bli_zgemv_unf_var2 return; } - /* fusing factor */ - b_fuse = 4; - - for ( i = 0; i < n_iter; i += f ) + // for non-unit incx, incy and rs_at and conjugate will be added in the next patch + if( (incx == 1 && incy == 1 && rs_at == 1 ) && + !bli_is_conj(conja) && !bli_is_conj(conjx) && !bli_is_trans(transa)) { - f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); - A1 = a + (0 )*rs_at + (i )*cs_at; - x1 = x + (i )*incx; - y1 = y + (0 )*incy; - - /* y = y + alpha * A1 * x1; */ - bli_zaxpyf_zen_int_4 + // This gemv code deals with the followint conditions only + // 1. incx, incy, and row stride equal to one + // 2. Non conjugate A matrix and X vector + // 3. 
No Transpose for A Martix + // Rest is taken care by the else part (axpyf implementation) + bli_zgemv_zen_int_4x4 ( - conja, - conjx, - n_elem, - f, - alpha, - A1, rs_at, cs_at, - x1, incx, - y1, incy, - NULL + conja, + conjx, + m, + n, + alpha, + a, rs_at, cs_at, + x, incx, + beta, + y, incy, + NULL ); } + else + { + /* fusing factor */ + b_fuse = 4; + + for ( i = 0; i < n_iter; i += f ) + { + f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); + A1 = a + (0 )*rs_at + (i )*cs_at; + x1 = x + (i )*incx; + y1 = y + (0 )*incy; + + /* y = y + alpha * A1 * x1; */ + bli_zaxpyf_zen_int_4 + ( + conja, + conjx, + n_elem, + f, + alpha, + A1, rs_at, cs_at, + x1, incx, + y1, incy, + NULL + ); + } + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); } diff --git a/frame/compat/bla_scal.c b/frame/compat/bla_scal.c index 821aca6160..b08fac87f5 100644 --- a/frame/compat/bla_scal.c +++ b/frame/compat/bla_scal.c @@ -243,15 +243,15 @@ void cscal_ /* Initialize BLIS */ //bli_init_auto(); - /* Convert typecast negative values of n to zero. */ - if ( *n < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(*n); - if (*n == 0 || alpha == NULL) { AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); return; } + /* Convert typecast negative values of n to zero. */ + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + /* If the input increments are negative, adjust the pointers so we can use positive increments instead. */ if ( *incx < 0 ) @@ -311,15 +311,15 @@ void zscal_ /* Initialize BLIS */ //bli_init_auto(); - /* Convert typecast negative values of n to zero. */ - if ( *n < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(*n); - if (*n == 0 || alpha == NULL) { AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); return; } + /* Convert typecast negative values of n to zero. */ + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + /* If the input increments are negative, adjust the pointers so we can use positive increments instead. */ if ( *incx < 0 ) diff --git a/kernels/zen/2/bli_gemv_zen_int_4.c b/kernels/zen/2/bli_gemv_zen_int_4.c new file mode 100644 index 0000000000..95060f57e2 --- /dev/null +++ b/kernels/zen/2/bli_gemv_zen_int_4.c @@ -0,0 +1,260 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "immintrin.h" +#include "blis.h" + +/* + This implementation uses 512 bits of cache line efficiently for + column stored matrix and vectors. + To achieve this, at each iteration we use 2 ymm registers + i.e. .512 bits for arithmetic operation. By this we use the + cache efficiently. +*/ +void bli_zgemv_zen_int_4x4 + ( + conj_t conja, + conj_t conjx, + dim_t m, + dim_t n, + dcomplex* restrict alpha, + dcomplex* restrict a, inc_t inca, inc_t lda, + dcomplex* restrict x, inc_t incx, + dcomplex* restrict beta, + dcomplex* restrict y, inc_t incy, + cntx_t* restrict cntx + ) +{ + + const dim_t S_MR = 4; // Kernel size , m = 4 + const dim_t S_NR = 4; // Kernel size , n = 4 + + dcomplex chi0; + dcomplex chi1; + dcomplex chi2; + dcomplex chi3; + + inc_t lda2 = 2*lda; + inc_t lda3 = 3*lda; + + inc_t incy2 = 2*incy; + inc_t incx2 = 2*incx; + inc_t incx3 = 3*incx; + inc_t inca2 = 2*inca; + + dcomplex* restrict x0 = x; + dcomplex* restrict y0 = y; + dcomplex* restrict a0 = a; + + dim_t i,j; + + __m256d ymm0, ymm1, ymm2, ymm3; + __m256d ymm4, ymm5, ymm6, ymm7; + __m256d ymm8, ymm9, ymm10, ymm11; + __m256d ymm12, ymm13, ymm14, ymm15; + + for( i = 0; i+S_NR-1 < n; i+=S_NR ) + { + a0 = a + (i )*lda; + x0 = x + (i )*incx; + y0 = y;// For each kernel, y should start form beginning + + chi0 = *( x0);// + 0*incx ); + chi1 = *( x0 + incx ); + chi2 = *( x0 + incx2 ); + chi3 = *( x0 + incx3 ); + + // Scale each chi scalar by alpha. 
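+        // (These scaled chi values are broadcast in the loop below: one
+        // accumulator collects the A * real(chi) terms, another the
+        // A * imag(chi) terms, and a permute followed by addsub combines
+        // them into the complex products that are added into y.)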
+ bli_zscals( *alpha, chi0 ); + bli_zscals( *alpha, chi1 ); + bli_zscals( *alpha, chi2 ); + bli_zscals( *alpha, chi3 ); + + // broadcast x0,x1,x2,x3 + // broadcast real & imag parts of 4 elements of x + ymm0 = _mm256_broadcast_sd(&chi0.real); // real part of x0 + ymm1 = _mm256_broadcast_sd(&chi0.imag); // imag part of x0 + ymm2 = _mm256_broadcast_sd(&chi1.real); // real part of x1 + ymm3 = _mm256_broadcast_sd(&chi1.imag); // imag part of x1 + ymm4 = _mm256_broadcast_sd(&chi2.real); // real part of x2 + ymm5 = _mm256_broadcast_sd(&chi2.imag); // imag part of x2 + ymm6 = _mm256_broadcast_sd(&chi3.real); // real part of x3 + ymm7 = _mm256_broadcast_sd(&chi3.imag); // imag part of x3 + + for( j = 0 ; j+S_MR-1 < m ; j+=S_MR ) + { + //load columns of A + ymm8 = _mm256_loadu_pd((double const *)(a0)); + ymm9 = _mm256_loadu_pd((double const *)(a0 + lda)); + ymm10 = _mm256_loadu_pd((double const *)(a0 + lda2)); + ymm11 = _mm256_loadu_pd((double const *)(a0 + lda3)); + +//-------------------- + //Ar*Xr Ai*Xr Ar*Xr Ai*Xr + ymm14 = _mm256_mul_pd(ymm8, ymm0); + //Ar*Xi Ai*Xi Ar*Xi Ai*Xi + ymm15 = _mm256_mul_pd(ymm8, ymm1); + + /* Next set of A mult by real and imag, + Add into the previous real and imag results */ + + // (Ar*Xr Ai*Xr Ar*Xr Ai*Xr) + (prev iteration real results) + ymm14 = _mm256_fmadd_pd(ymm9, ymm2, ymm14); + // (Ar*Xi Ai*Xi Ar*Xi Ai*Xi) + + (prev iteration imag results) + ymm15 = _mm256_fmadd_pd(ymm9, ymm3, ymm15); + + // (Ar*Xr Ai*Xr Ar*Xr Ai*Xr) + (prev iteration real results) + ymm14 = _mm256_fmadd_pd(ymm10, ymm4, ymm14); + // (Ar*Xi Ai*Xi Ar*Xi Ai*Xi) + + (prev iteration imag results) + ymm15 = _mm256_fmadd_pd(ymm10, ymm5, ymm15); + + // (Ar*Xr Ai*Xr Ar*Xr Ai*Xr) + (prev iteration real results) + ymm14 = _mm256_fmadd_pd(ymm11, ymm6, ymm14); + // (Ar*Xi Ai*Xi Ar*Xi Ai*Xi) + + (prev iteration imag results) + ymm15 = _mm256_fmadd_pd(ymm11, ymm7, ymm15); + + /*Permute the imag acc register to addsub to real accu results */ + // (Ar*Xi Ai*Xi Ar*Xi Ai*Xi) => (Ai*Xi Ar*Xi Ai*Xi Ar*Xi) + ymm15 = _mm256_permute_pd(ymm15, 5); + + /*AddSub to get the 2 proper complex multipled value*/ + /* Ar*Xi - Ai*Xi, Ai*Xi + Ar*Xi, Ar*Xi - Ai*Xi, Ai*Xi + Ar*Xi*/ + ymm12 = _mm256_addsub_pd(ymm14, ymm15); + + //load Y vector + ymm14 = _mm256_loadu_pd((double*)y0); + //Add the results into y + ymm12 = _mm256_add_pd(ymm14, ymm12); + // Store the results back + _mm256_storeu_pd((double*)(y0), ymm12); +//----------------------- + + // Load Next Set of A matrix elements for the same col + // Ar2 Ai2 Ar3 Ai3 + ymm8 = _mm256_loadu_pd((double const *)(a0 + (inca2))); + ymm9 = _mm256_loadu_pd((double const *)(a0 + (inca2) + lda)); + ymm10 = _mm256_loadu_pd((double const *)(a0 + (inca2) + lda2)); + ymm11 = _mm256_loadu_pd((double const *)(a0 + (inca2) + lda3)); + + //Ar0*Xr Ai0*Xr Ar1*Xr Ai1*Xr + ymm14 = _mm256_mul_pd(ymm8, ymm0); + //Ar0*Xi Ai0*Xi Ar1*Xi Ai1*Xi + ymm15 = _mm256_mul_pd(ymm8, ymm1); + + /* Next set of A mult by real and imag, + Add into the previous real and imag results */ + + // (Ar*Xr Ai*Xr Ar*Xr Ai*Xr) + (prev iteration real results) + ymm14 = _mm256_fmadd_pd(ymm9, ymm2, ymm14); + // (Ar*Xi Ai*Xi Ar*Xi Ai*Xi) + + (prev iteration imag results) + ymm15 = _mm256_fmadd_pd(ymm9, ymm3, ymm15); + + // (Ar*Xr Ai*Xr Ar*Xr Ai*Xr) + (prev iteration real results) + ymm14 = _mm256_fmadd_pd(ymm10, ymm4, ymm14); + // (Ar*Xi Ai*Xi Ar*Xi Ai*Xi) + + (prev iteration imag results) + ymm15 = _mm256_fmadd_pd(ymm10, ymm5, ymm15); + + // (Ar*Xr Ai*Xr Ar*Xr Ai*Xr) + (prev iteration real results) + ymm14 = 
_mm256_fmadd_pd(ymm11, ymm6, ymm14); + // (Ar*Xi Ai*Xi Ar*Xi Ai*Xi) + + (prev iteration imag results) + ymm15 = _mm256_fmadd_pd(ymm11, ymm7, ymm15); + + /*Permute the imag acc register to addsub to real accu results */ + // (Ar*Xi Ai*Xi Ar*Xi Ai*Xi) => (Ai*Xi Ar*Xi Ai*Xi Ar*Xi) + ymm15 = _mm256_permute_pd(ymm15, 5); + /*AddSub to get the 2 proper complex multipled value*/ + /* Ar*Xi - Ai*Xi, Ai*Xi + Ar*Xi, Ar*Xi - Ai*Xi, Ai*Xi + Ar*Xi*/ + ymm13 = _mm256_addsub_pd(ymm14, ymm15); + + // load Y vector + ymm14 = _mm256_loadu_pd((double *)(y0 + (incy2))); + // Add the results into y + ymm13 = _mm256_add_pd(ymm14, ymm13); + // Store the results back + _mm256_storeu_pd((double*)(y0 + (incy2)), ymm13); +//----------------------- + + y0 += S_MR*incy ; // Next Set of y0 vector + a0 += S_MR*inca ; // Next Set of a0 matrix elements in the same col + } + + // For resisual m + for( ; j < m ; ++j ) + { + dcomplex y0c = *(dcomplex*)y0; + + const dcomplex a0c = *a0; + const dcomplex a1c = *(a0 + lda); + const dcomplex a2c = *(a0 + lda2); + const dcomplex a3c = *(a0 + lda3); + + y0c.real += chi0.real * a0c.real - chi0.imag * a0c.imag; + y0c.real += chi1.real * a1c.real - chi1.imag * a1c.imag; + y0c.real += chi2.real * a2c.real - chi2.imag * a2c.imag; + y0c.real += chi3.real * a3c.real - chi3.imag * a3c.imag; + + y0c.imag += chi0.imag * a0c.real + chi0.real * a0c.imag; + y0c.imag += chi1.imag * a1c.real + chi1.real * a1c.imag; + y0c.imag += chi2.imag * a2c.real + chi2.real * a2c.imag; + y0c.imag += chi3.imag * a3c.real + chi3.real * a3c.imag; + + *(dcomplex*)y0 = y0c; + + a0 += 1; + y0 += 1; + } + } + + // For resisual n, axpyv is used + for ( ; i < n; ++i ) + { + dcomplex* a1 = a + (i )*lda; + dcomplex* chi1 = x + (i )*incx; + dcomplex* y1 = y; + dcomplex alpha_chi1; + + bli_zcopycjs( conjx, *chi1, alpha_chi1 ); + bli_zscals( *alpha, alpha_chi1 ); + + bli_zaxpyv_zen_int5 + ( + conja, + m, + &alpha_chi1, + a1, inca, + y1, incy, + cntx + ); + } +} \ No newline at end of file diff --git a/kernels/zen/2/bli_gemv_zen_ref.c b/kernels/zen/2/bli_gemv_zen_ref.c index fd36e73cd5..0d71522c3c 100644 --- a/kernels/zen/2/bli_gemv_zen_ref.c +++ b/kernels/zen/2/bli_gemv_zen_ref.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020-21, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -40,10 +40,12 @@ */ void bli_dgemv_zen_ref_c ( + conj_t conja, + conj_t conjx, dim_t m, dim_t n, double* restrict alpha, - double* restrict a, inc_t lda, + double* restrict a, inc_t inca, inc_t lda, double* restrict x, inc_t incx, double* restrict beta, double* restrict y, inc_t incy, @@ -75,7 +77,7 @@ void bli_dgemv_zen_ref_c { PRAGMA_SIMD for(i = 0; i < m; i++) - (y0[i]) = (a0[i]) * (x0_val) * (*alpha) + y0[i] * (*beta); + (y0[i]) = (a0[i]) * (x0_val) * (*alpha) + y0[i] * (*beta); } a0 += lda; @@ -86,7 +88,7 @@ void bli_dgemv_zen_ref_c PRAGMA_SIMD for(i = 0; i < m; i++) { - (y0[i]) += (a0[i]) * xp * (*alpha); + (y0[i]) += (a0[i]) * xp * (*alpha); } a0 += lda; } @@ -112,7 +114,7 @@ void bli_dgemv_zen_ref_c const double xp = *(x0+j*incx); for(i = 0; i < m; i++) { - *(y0 + i*incy) += (a0[j*lda+i]) * xp * (*alpha); + *(y0 + i*incy) += (a0[j*lda+i]) * xp * (*alpha); } } } diff --git a/kernels/zen/bli_kernels_zen.h b/kernels/zen/bli_kernels_zen.h index 02b73ba16c..b39ccec577 100644 --- a/kernels/zen/bli_kernels_zen.h +++ b/kernels/zen/bli_kernels_zen.h @@ -117,6 +117,7 @@ DOTXF_KER_PROT( double, d, dotxf_zen_int_8 ) //gemv(scalar code) GEMV_KER_PROT( double, d, gemv_zen_ref_c ) +GEMV_KER_PROT( dcomplex, z, gemv_zen_int_4x4 ) // -- level-3 sup -------------------------------------------------------------- // semmsup_rv From 23278627f40bc58e87b6918d0baa176122cd08fd Mon Sep 17 00:00:00 2001 From: satish kumar nuggu Date: Fri, 24 Sep 2021 20:51:22 +0530 Subject: [PATCH 033/243] STRSM small kernel implementation Details: -- AMD Internal Id: [CPUPL-1702] -- Used 16x6 SGEMM kernel with vector fma by utilizing ymm registers -- Used packing of matrix A to effectively cache and reuse -- Implemented kernels using macro based modular approach -- Taken care of --disable_pre_inversion configuration -- modularized strsm 16 combinations of trsm into 4 kernels Change-Id: I30a1551967c36f6bae33be3b7ae5b7fcc7c905ea --- frame/compat/bla_trsm.c | 262 +- kernels/zen/3/bli_trsm_small.c | 45800 ++++++++++++++++++++++--------- 2 files changed, 32352 insertions(+), 13710 deletions(-) diff --git a/frame/compat/bla_trsm.c b/frame/compat/bla_trsm.c index 943164d367..a2703d1cdd 100644 --- a/frame/compat/bla_trsm.c +++ b/frame/compat/bla_trsm.c @@ -381,6 +381,264 @@ void PASTEF77(ch,blasname) \ #ifdef BLIS_ENABLE_BLAS #ifdef BLIS_CONFIG_EPYC + +void strsm_ +( + const f77_char* side, + const f77_char* uploa, + const f77_char* transa, + const f77_char* diaga, + const f77_int* m, + const f77_int* n, + const float* alpha, + const float* a, const f77_int* lda, + float* b, const f77_int* ldb +) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO) + AOCL_DTL_LOG_TRSM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'd', + *side, *uploa,*transa, *diaga, *m, *n, + (void*)alpha,*lda, *ldb); + + side_t blis_side; + uplo_t blis_uploa; + trans_t blis_transa; + diag_t blis_diaga; + dim_t m0, n0; + conj_t conja = BLIS_NO_CONJUGATE ; + + /* Initialize BLIS. */ + bli_init_auto(); + + /* Perform BLAS parameter checking. */ + PASTEBLACHK(trsm) + ( + MKSTR(s), + MKSTR(trsm), + side, + uploa, + transa, + diaga, + m, + n, + lda, + ldb + ); + + /* Map BLAS chars to their corresponding BLIS enumerated type value. 
*/ + bli_param_map_netlib_to_blis_side( *side, &blis_side ); + bli_param_map_netlib_to_blis_uplo( *uploa, &blis_uploa ); + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); + bli_param_map_netlib_to_blis_diag( *diaga, &blis_diaga ); + + /* Typecast BLAS integers to BLIS integers. */ + bli_convert_blas_dim1( *m, m0 ); + bli_convert_blas_dim1( *n, n0 ); + + /* Set the row and column strides of the matrix operands. */ + const inc_t rs_a = 1; + const inc_t cs_a = *lda; + const inc_t rs_b = 1; + const inc_t cs_b = *ldb; + const num_t dt = BLIS_FLOAT; + + if( n0 == 1 ) + { + if( blis_side == BLIS_LEFT ) + { + if(bli_is_notrans(blis_transa)) + { + bli_strsv_unf_var2 + ( + blis_uploa, + blis_transa, + blis_diaga, + m0, + (float*)alpha, + (float*)a, rs_a, cs_a, + (float*)b, rs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + else if(bli_is_trans(blis_transa)) + { + bli_strsv_unf_var1 + ( + blis_uploa, + blis_transa, + blis_diaga, + m0, + (float*)alpha, + (float*)a, rs_a, cs_a, + (float*)b, rs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + } + else if( ( blis_side == BLIS_RIGHT ) && ( m0 != 1 ) ) + { + /* b = alpha * b; */ + bli_sscalv_ex + ( + conja, + m0, + (float*)alpha, + b, rs_b, + NULL, + NULL + ); + if(blis_diaga == BLIS_NONUNIT_DIAG) + { + float inva = 1.0/ *a; + for(int indx = 0; indx < m0; indx ++) + { + b[indx] = ( inva * b[indx] ); + } + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + } + else if( m0 == 1 ) + { + if(blis_side == BLIS_RIGHT) + { + if(bli_is_notrans(blis_transa)) + { + if(blis_uploa == BLIS_UPPER) + blis_uploa = BLIS_LOWER; + else + blis_uploa = BLIS_UPPER; + + bli_strsv_unf_var1 + ( + blis_uploa, + blis_transa, + blis_diaga, + n0, + (float*)alpha, + (float*)a, cs_a, rs_a, + (float*)b, cs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + else if(bli_is_trans(blis_transa)) + { + if(blis_uploa == BLIS_UPPER) + blis_uploa = BLIS_LOWER; + else + blis_uploa = BLIS_UPPER; + + bli_strsv_unf_var2 + ( + blis_uploa, + blis_transa, + blis_diaga, + n0, + (float*)alpha, + (float*)a, cs_a, rs_a, + (float*)b, cs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + } + else if(( blis_side == BLIS_LEFT ) && ( n0 != 1 )) + { + /* b = alpha * b; */ + bli_sscalv_ex + ( + conja, + n0, + (float*)alpha, + b, cs_b, + NULL, + NULL + ); + if(blis_diaga == BLIS_NONUNIT_DIAG) + { + float inva = 1.0/ *a; + for(int indx = 0; indx < n0; indx ++) + { + b[indx*cs_b] = (inva * b[indx*cs_b] ); + } + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + } + const struc_t struca = BLIS_TRIANGULAR; + + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; + obj_t ao = BLIS_OBJECT_INITIALIZER; + obj_t bo = BLIS_OBJECT_INITIALIZER; + + dim_t mn0_a; + + bli_set_dim_with_side( blis_side, m0, n0, &mn0_a ); + + bli_obj_init_finish_1x1( dt, (float*)alpha, &alphao ); + + bli_obj_init_finish( dt, mn0_a, mn0_a, (float*)a, rs_a, cs_a, &ao ); + bli_obj_init_finish( dt, m0, n0, (float*)b, rs_b, cs_b, &bo ); + + bli_obj_set_uplo( blis_uploa, &ao ); + bli_obj_set_diag( blis_diaga, &ao ); + bli_obj_set_conjtrans( blis_transa, &ao ); + + bli_obj_set_struc( struca, &ao ); + +#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM + /* bli_strsm_small is performing better existing native + * implementations for [m,n]<=1000 for single thread. 
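+     * (Note: the check just below takes this small-TRSM path when running
+     * single threaded with m0,n0 <= 1000, or when threading is enabled and
+     * m0 + n0 < 320.)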
+ * In case of multithread when [m,n]<=128 sinlge thread implemenation + * is doing better than native multithread */ + bool nt = bli_thread_get_is_parallel(); + if((nt==0 && m0<=1000 && n0<=1000) || + (nt && (m0+n0)<320) ) + { + err_t status; + status = bli_trsm_small + ( + blis_side, + &alphao, + &ao, + &bo, + NULL, + NULL + ); + if (status == BLIS_SUCCESS) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + /* Finalize BLIS. */ + bli_finalize_auto(); + return; + } + } +#endif + + bli_trsmnat + ( + blis_side, + &alphao, + &ao, + &bo, + NULL, + NULL + ); + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) + /* Finalize BLIS. */ + bli_finalize_auto(); +} + void dtrsm_ ( const f77_char* side, @@ -662,7 +920,7 @@ void ztrsm_ trans_t blis_transa; diag_t blis_diaga; dim_t m0, n0; - conj_t conja = BLIS_NO_CONJUGATE ; + //conj_t conja = BLIS_NO_CONJUGATE ; /* Initialize BLIS. */ bli_init_auto(); @@ -937,8 +1195,6 @@ void ztrsm_ bli_finalize_auto(); } - -GENTFUNC( float, s, trsm, trsm ) GENTFUNC( scomplex, c, trsm, trsm ) #else INSERT_GENTFUNC_BLAS( trsm, trsm ) diff --git a/kernels/zen/3/bli_trsm_small.c b/kernels/zen/3/bli_trsm_small.c index 6e8455d024..f4428564c5 100644 --- a/kernels/zen/3/bli_trsm_small.c +++ b/kernels/zen/3/bli_trsm_small.c @@ -234,7 +234,6 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB cntl_t* cntl ); - /* * The preinversion of diagonal elements are enabled/disabled * based on configuration. @@ -284,6 +283,41 @@ BLIS_INLINE err_t dtrsm_AutXB_ref return BLIS_SUCCESS; }// end of function +/* + * Reference implementations + * ToDo: We can combine all these reference implementation + into a macro +*/ +//A'X = B; A is upper triangular; transpose; +//non-unitDiagonal double precision +BLIS_INLINE err_t strsm_AutXB_ref +( + float *A, + float *B, + dim_t M, + dim_t N, + dim_t lda, + dim_t ldb, + bool unitDiagonal +) +{ + dim_t i, j, k; + for (k = 0; k < M; k++) + { + float lkk_inv = 1.0; + if(!unitDiagonal) lkk_inv = DIAG_ELE_INV_OPS(lkk_inv,A[k+k*lda]); + for (j = 0; j < N; j++) + { + B[k + j*ldb] = DIAG_ELE_EVAL_OPS(B[k + j*ldb] , lkk_inv); + for (i = k+1; i < M; i++) + { + B[i + j*ldb] -= A[i*lda + k] * B[k + j*ldb]; + } + } + }// k -loop + return BLIS_SUCCESS; +}// end of function + /* TRSM scalar code for the case AX = alpha * B * A is upper-triangular, non-unit-diagonal * Dimensions: A: mxm X: mxn B:mxn @@ -316,6 +350,38 @@ BLIS_INLINE err_t dtrsm_AuXB_ref return BLIS_SUCCESS; }// end of function +/* TRSM scalar code for the case AX = alpha * B + * A is upper-triangular, non-unit-diagonal + * Dimensions: A: mxm X: mxn B:mxn + */ +BLIS_INLINE err_t strsm_AuXB_ref +( + float *A, + float *B, + dim_t M, + dim_t N, + dim_t lda, + dim_t ldb, + bool is_unitdiag +) +{ + dim_t i, j, k; + for (k = M-1; k >= 0; k--) + { + float lkk_inv = 1.0; + if(!is_unitdiag) lkk_inv = DIAG_ELE_INV_OPS(lkk_inv,A[k+k*lda]); + for (j = N -1; j >= 0; j--) + { + B[k + j*ldb] = DIAG_ELE_EVAL_OPS(B[k + j*ldb],lkk_inv); + for (i = k-1; i >=0; i--) + { + B[i + j*ldb] -= A[i + k*lda] * B[k + j*ldb]; + } + } + }// k -loop + return BLIS_SUCCESS; +}// end of function + /* TRSM scalar code for the case AX = alpha * B * A is lower-triangular, non-unit-diagonal, no transpose * Dimensions: A: mxm X: mxn B:mxn @@ -348,6 +414,38 @@ BLIS_INLINE err_t dtrsm_AlXB_ref return BLIS_SUCCESS; }// end of function +/* TRSM scalar code for the case AX = alpha * B + * A is lower-triangular, non-unit-diagonal, no transpose + * Dimensions: A: mxm X: mxn B:mxn + */ +BLIS_INLINE err_t strsm_AlXB_ref +( + float *A, + float *B, + dim_t M, + 
dim_t N, + dim_t lda, + dim_t ldb, + bool is_unitdiag +) +{ + dim_t i, j, k; + for (k = 0; k < M; k++) + { + float lkk_inv = 1.0; + if(!is_unitdiag) lkk_inv = DIAG_ELE_INV_OPS(lkk_inv,A[k+k*lda]); + for (j = 0; j < N; j++) + { + B[k + j*ldb] = DIAG_ELE_EVAL_OPS(B[k + j*ldb],lkk_inv); + for (i = k+1; i < M; i++) + { + B[i + j*ldb] -= A[i + k*lda] * B[k + j*ldb]; + } + } + }// k -loop + return BLIS_SUCCESS; +}// end of function + /* TRSM scalar code for the case AX = alpha * B * A is lower-triangular, non-unit-diagonal, transpose * Dimensions: A: mxm X: mxn B:mxn @@ -380,6 +478,38 @@ BLIS_INLINE err_t dtrsm_AltXB_ref return BLIS_SUCCESS; }// end of function +/* TRSM scalar code for the case AX = alpha * B + * A is lower-triangular, non-unit-diagonal, transpose + * Dimensions: A: mxm X: mxn B:mxn + */ +BLIS_INLINE err_t strsm_AltXB_ref +( + float *A, + float *B, + dim_t M, + dim_t N, + dim_t lda, + dim_t ldb, + bool is_unitdiag +) +{ + dim_t i, j, k; + for (k = M-1; k >= 0; k--) + { + float lkk_inv = 1.0; + if(!is_unitdiag) lkk_inv = DIAG_ELE_INV_OPS(lkk_inv,A[k+k*lda]); + for (j = N -1; j >= 0; j--) + { + B[k + j*ldb] = DIAG_ELE_EVAL_OPS(B[k + j*ldb],lkk_inv); + for (i = k-1; i >=0; i--) + { + B[i + j*ldb] -= A[i*lda + k] * B[k + j*ldb]; + } + } + }// k -loop + return BLIS_SUCCESS; +}// end of function + /* TRSM scalar code for the case XA = alpha * B * A is upper-triangular, non-unit/unit diagonal no transpose * Dimensions: X:mxn A:nxn B:mxn @@ -1550,87 +1680,1821 @@ BLIS_INLINE err_t dtrsm_XAltB_ref ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5));\ ymm13 = _mm256_fmsub_pd(ymm0, ymm15, ymm13); -/* - Pack a block of 8xk or 6xk from input buffer into packed buffer - directly or after transpose based on input params -*/ -BLIS_INLINE void bli_dtrsm_small_pack -( - char side, - dim_t size, - bool trans, - double *inbuf, - dim_t cs_a, - double *pbuff, - dim_t p_lda, - dim_t mr -) -{ - //scratch registers - __m256d ymm0, ymm1, ymm2, ymm3; - __m256d ymm4, ymm5, ymm6, ymm7; - __m256d ymm8, ymm9, ymm10, ymm11; - __m256d ymm12, ymm13; - __m128d xmm0,xmm1,xmm2,xmm3; - double zero = 0.0; - - if(side=='L'||side=='l') - { - /*Left case is 8xk*/ - if(trans) - { - /* - ------------- ------------- - | | | | | - | 4x8 | | | | - ------------- ==> | 8x4 | 8x4 | - | 4x8 | | | | - | | | | | - ------------- ------------- - */ - for(dim_t x = 0; x < size; x += mr) - { - ymm0 = _mm256_loadu_pd((double const *)(inbuf)); - ymm10 = _mm256_loadu_pd((double const *)(inbuf + 4)); - ymm1 = _mm256_loadu_pd((double const *)(inbuf + cs_a)); - ymm11 = _mm256_loadu_pd((double const *)(inbuf + 4 + cs_a)); - ymm2 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 2)); - ymm12 = _mm256_loadu_pd((double const *)(inbuf + 4 + cs_a * 2)); - ymm3 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 3)); - ymm13 = _mm256_loadu_pd((double const *)(inbuf + 4 + cs_a * 3)); +#ifdef BLIS_DISABLE_TRSM_PREINVERSION +#define STRSM_SMALL_DIV_OR_SCALE _mm256_div_ps +#endif - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); +#ifdef BLIS_ENABLE_TRSM_PREINVERSION +#define STRSM_SMALL_DIV_OR_SCALE _mm256_mul_ps +#endif - _mm256_storeu_pd((double *)(pbuff), ymm6); - _mm256_storeu_pd((double *)(pbuff + p_lda), ymm7); - _mm256_storeu_pd((double 
*)(pbuff + p_lda*2), ymm8); - _mm256_storeu_pd((double *)(pbuff + p_lda*3), ymm9); +/*Initialize */ +#define BLIS_SET_S_YMM_REG_ZEROS \ + ymm3 = _mm256_setzero_ps(); \ + ymm4 = _mm256_setzero_ps(); \ + ymm5 = _mm256_setzero_ps(); \ + ymm6 = _mm256_setzero_ps(); \ + ymm7 = _mm256_setzero_ps(); \ + ymm8 = _mm256_setzero_ps(); \ + ymm9 = _mm256_setzero_ps(); \ + ymm10 = _mm256_setzero_ps(); \ + ymm11 = _mm256_setzero_ps(); \ + ymm12 = _mm256_setzero_ps(); \ + ymm13 = _mm256_setzero_ps(); \ + ymm14 = _mm256_setzero_ps(); \ + ymm15 = _mm256_setzero_ps(); - ymm4 = _mm256_unpacklo_pd(ymm10, ymm11); - ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); +/*GEMM block used in trsm small right cases*/ +#define BLIS_STRSM_SMALL_GEMM_6nx16m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) \ + {\ + /*load 8x1 block of B10*/ \ + ymm0 = _mm256_loadu_ps((float const *)b10); \ + ymm1 = _mm256_loadu_ps((float const *)(b10 + 8)); \ + \ + /*broadcast 1st row of A01*/ \ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); \ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3); \ + ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4); \ + \ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 1)); \ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5); \ + ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6); \ + \ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 2)); \ + ymm7 = _mm256_fmadd_ps(ymm2, ymm0, ymm7); \ + ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8); \ + \ + /*Prefetch the next micro panel*/\ + _mm_prefetch((char*)( b10 + 8*cs_b), _MM_HINT_T0);\ + \ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 3)); \ + ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9); \ + ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10); \ + \ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 4)); \ + ymm11 = _mm256_fmadd_ps(ymm2, ymm0, ymm11); \ + ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12); \ + \ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 5)); \ + ymm13 = _mm256_fmadd_ps(ymm2, ymm0, ymm13); \ + ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14); \ + a01 += 1;\ + b10 += cs_b; \ + } - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); +#define BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 8x1 block of B10*/\ + ymm0 = _mm256_loadu_ps((float const *)b10); /*B10[0][0] B10[1][0] B10[2][0] B10[3][0]*/\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_ps(ymm2, ymm0, ymm7);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 3)); /*A01[0][3]*/\ + ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 4)); /*A01[0][4]*/\ + ymm11 = _mm256_fmadd_ps(ymm2, ymm0, ymm11);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 5)); /*A01[0][5]*/\ + ymm13 = _mm256_fmadd_ps(ymm2, ymm0, ymm13);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } - ymm0 = _mm256_unpackhi_pd(ymm10, ymm11); - ymm1 = _mm256_unpackhi_pd(ymm12, ymm13); +#define BLIS_STRSM_SMALL_GEMM_4nx16m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 8x1 block of B10*/\ + ymm0 = 
_mm256_loadu_ps((float const *)b10);\ + ymm1 = _mm256_loadu_ps((float const *)(b10 + 8));\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ + ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ + ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_ps(ymm2, ymm0, ymm7);\ + ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 3)); /*A01[0][3]*/\ + ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ + ymm10 = _mm256_fmadd_ps(ymm2, ymm1, ymm10);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); +#define BLIS_STRSM_SMALL_GEMM_3nx16m(a01,b10,cs_b,p_lda,k_iter)\ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 8x1 block of B10*/\ + ymm0 = _mm256_loadu_ps((float const *)b10);\ + ymm1 = _mm256_loadu_ps((float const *)(b10 + 8));\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ + ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ + ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_ps(ymm2, ymm0, ymm7);\ + ymm8 = _mm256_fmadd_ps(ymm2, ymm1, ymm8);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } - _mm256_storeu_pd((double *)(pbuff + p_lda * 4), ymm6); - _mm256_storeu_pd((double *)(pbuff + p_lda * 5), ymm7); - _mm256_storeu_pd((double *)(pbuff + p_lda * 6), ymm8); - _mm256_storeu_pd((double *)(pbuff + p_lda * 7), ymm9); +#define BLIS_STRSM_SMALL_GEMM_2nx16m(a01,b10,cs_b,p_lda,k_iter)\ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 8x1 block of B10*/\ + ymm0 = _mm256_loadu_ps((float const *)b10);/*B10[0][0] B10[1][0] B10[2][0] B10[3][0]*/\ + ymm1 = _mm256_loadu_ps((float const *)(b10 + 8));/*B10[4][0] B10[5][0] B10[6][0] B10[7][0]*/\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ + ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ + ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } - ymm0 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 4)); +#define BLIS_STRSM_SMALL_GEMM_1nx16m(a01,b10,cs_b,p_lda,k_iter)\ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 8x1 block of B10*/\ + ymm0 = _mm256_loadu_ps((float const *)b10);/*B10[0][0] B10[1][0] B10[2][0] B10[3][0]*/\ + ymm1 = _mm256_loadu_ps((float const *)(b10 + 8));/*B10[4][0] B10[5][0] B10[6][0] B10[7][0]*/\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ + ymm4 = _mm256_fmadd_ps(ymm2, ymm1, ymm4);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define 
BLIS_STRSM_SMALL_GEMM_4nx8m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 8x1 block of B10*/\ + ymm0 = _mm256_loadu_ps((float const *)b10);/*B10[0][0] B10[1][0] B10[2][0] B10[3][0]*/\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_ps(ymm2, ymm0, ymm7);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 3)); /*A01[0][3]*/\ + ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_STRSM_SMALL_GEMM_3nx8m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 8x1 block of B10*/\ + ymm0 = _mm256_loadu_ps((float const *)b10);/*B10[0][0] B10[1][0] B10[2][0] B10[3][0]*/\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_ps(ymm2, ymm0, ymm7);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_STRSM_SMALL_GEMM_2nx8m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 8x1 block of B10*/\ + ymm0 = _mm256_loadu_ps((float const *)b10);/*B10[0][0] B10[1][0] B10[2][0] B10[3][0]*/\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_STRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 8x1 block of B10*/\ + ymm0 = _mm256_loadu_ps((float const *)b10);/*B10[0][0] B10[1][0] B10[2][0] B10[3][0]*/\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +/*GEMM block used in strsm small left cases*/ +#define BLIS_STRSM_SMALL_GEMM_16mx6n(a10,b01,cs_b,p_lda,k_iter) \ + float *b01_prefetch = b01 + 8; \ + for(k = 0; k< k_iter; k++) \ + { \ + ymm0 = _mm256_loadu_ps((float const *)(a10)); \ + ymm1 = _mm256_loadu_ps((float const *)(a10 + 8)); \ + _mm_prefetch((char*)( a10 + 64), _MM_HINT_T0); \ + /*Calculate the next micro pannel address to prefetch*/ \ + if(k & 0x7) b01_prefetch += cs_b; \ + else b01_prefetch = b01+ 8; \ + ymm2 = _mm256_broadcast_ss((float const *)(b01)); \ + ymm8 = _mm256_fmadd_ps(ymm2, ymm0, ymm8); \ + ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12); \ + \ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 1)); \ + ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9); \ + ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13); \ + \ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 2)); \ + ymm10 = _mm256_fmadd_ps(ymm2, ymm0, 
ymm10); \ + ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14); \ + \ + /*Prefetch the next 6x8 micro panelof B */ \ + _mm_prefetch((char*)( b01_prefetch), _MM_HINT_T0); \ + \ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 3)); \ + ymm11 = _mm256_fmadd_ps(ymm2, ymm0, ymm11); \ + ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15); \ + \ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 4)); \ + ymm4 = _mm256_fmadd_ps(ymm2, ymm0, ymm4); \ + ymm6 = _mm256_fmadd_ps(ymm2, ymm1, ymm6); \ + \ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 5)); \ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5); \ + ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7); \ + \ + b01 += 1; \ + a10 += p_lda; \ + } + +#define BLIS_STRSM_SMALL_GEMM_16mx4n(a10,b01,cs_b,p_lda,k_iter) \ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_ps((float const *)(a10));\ + ymm1 = _mm256_loadu_ps((float const *)(a10 + 8));\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01));\ + ymm8 = _mm256_fmadd_ps(ymm2, ymm0, ymm8);\ + ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 1));\ + ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ + ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 2));\ + ymm10 = _mm256_fmadd_ps(ymm2, ymm0, ymm10);\ + ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 3));\ + ymm11 = _mm256_fmadd_ps(ymm2, ymm0, ymm11);\ + ymm15 = _mm256_fmadd_ps(ymm2, ymm1, ymm15);\ +\ + b01 += 1; /*move to next row of B*/\ + a10 += p_lda; /*pointer math to calculate next block of A for GEMM*/\ + } + +#define BLIS_STRSM_SMALL_GEMM_16mx3n(a10,b01,cs_b,p_lda,k_iter) \ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_ps((float const *)(a10));\ + ymm1 = _mm256_loadu_ps((float const *)(a10 + 8));\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 0));\ + ymm8 = _mm256_fmadd_ps(ymm2, ymm0, ymm8);\ + ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 1));\ + ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ + ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 2));\ + ymm10 = _mm256_fmadd_ps(ymm2, ymm0, ymm10);\ + ymm14 = _mm256_fmadd_ps(ymm2, ymm1, ymm14);\ +\ + b01 += 1; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + } + +#define BLIS_STRSM_SMALL_GEMM_16mx2n(a10,b01,cs_b,p_lda,k_iter) \ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_ps((float const *)(a10));\ + ymm1 = _mm256_loadu_ps((float const *)(a10 + 8));\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 0));\ + ymm8 = _mm256_fmadd_ps(ymm2, ymm0, ymm8);\ + ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 1));\ + ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ + ymm13 = _mm256_fmadd_ps(ymm2, ymm1, ymm13);\ +\ + b01 += 1; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + } + +#define BLIS_STRSM_SMALL_GEMM_16mx1n(a10,b01,cs_b,p_lda,k_iter) \ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_ps((float const *)(a10));\ + ymm1 = _mm256_loadu_ps((float const *)(a10 + 8));\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 0));\ + ymm8 = _mm256_fmadd_ps(ymm2, ymm0, 
ymm8);\ + ymm12 = _mm256_fmadd_ps(ymm2, ymm1, ymm12);\ + b01 += 1; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + } + +#define BLIS_STRSM_SMALL_GEMM_8mx6n(a10,b01,cs_b,p_lda,k_iter) \ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_ps((float const *)(a10));\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 0));\ + ymm8 = _mm256_fmadd_ps(ymm2, ymm0, ymm8);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 1));\ + ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 2));\ + ymm10 = _mm256_fmadd_ps(ymm2, ymm0, ymm10);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 3));\ + ymm11 = _mm256_fmadd_ps(ymm2, ymm0, ymm11);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 4));\ + ymm4 = _mm256_fmadd_ps(ymm2, ymm0, ymm4);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 5));\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ +\ + b01 += 1; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + } + +#define BLIS_STRSM_SMALL_GEMM_8mx4n(a10,b01,cs_b,p_lda,k_iter) \ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_ps((float const *)(a10));\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 0));\ + ymm8 = _mm256_fmadd_ps(ymm2, ymm0, ymm8);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 1));\ + ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 2));\ + ymm10 = _mm256_fmadd_ps(ymm2, ymm0, ymm10);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 3));\ + ymm11 = _mm256_fmadd_ps(ymm2, ymm0, ymm11);\ +\ + b01 += 1; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + } + +#define BLIS_STRSM_SMALL_GEMM_8mx3n(a10,b01,cs_b,p_lda,k_iter) \ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_ps((float const *)(a10));\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 0));\ + ymm8 = _mm256_fmadd_ps(ymm2, ymm0, ymm8);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 1));\ + ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 2));\ + ymm10 = _mm256_fmadd_ps(ymm2, ymm0, ymm10);\ +\ + b01 += 1; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + } + +#define BLIS_STRSM_SMALL_GEMM_8mx2n(a10,b01,cs_b,p_lda,k_iter) \ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_ps((float const *)(a10));\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 0));\ + ymm8 = _mm256_fmadd_ps(ymm2, ymm0, ymm8);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 1));\ + ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ +\ + b01 += 1; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + } + +#define BLIS_STRSM_SMALL_GEMM_8mx1n(a10,b01,cs_b,p_lda,k_iter) \ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_ps((float const *)(a10));\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 0));\ + ymm8 = _mm256_fmadd_ps(ymm2, ymm0, ymm8);\ +\ + b01 += 1; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + } + +#define 
BLIS_STRSM_SMALL_GEMM_4mx6n(a10,b01,cs_b,p_lda,k_iter) \ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_insertf128_ps(ymm1, _mm_loadu_ps((float const*)(a10)), 0);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 0));\ + ymm8 = _mm256_fmadd_ps(ymm2, ymm0, ymm8);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 1));\ + ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 2));\ + ymm10 = _mm256_fmadd_ps(ymm2, ymm0, ymm10);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 3));\ + ymm11 = _mm256_fmadd_ps(ymm2, ymm0, ymm11);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 4));\ + ymm4 = _mm256_fmadd_ps(ymm2, ymm0, ymm4);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 5));\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ +\ + b01 += 1; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + } + +#define BLIS_STRSM_SMALL_GEMM_4mx4n(a10,b01,cs_b,p_lda,k_iter) \ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_insertf128_ps(ymm1, _mm_loadu_ps((float const*)(a10)), 0);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 0));\ + ymm8 = _mm256_fmadd_ps(ymm2, ymm0, ymm8);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 1));\ + ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 2));\ + ymm10 = _mm256_fmadd_ps(ymm2, ymm0, ymm10);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 3));\ + ymm11 = _mm256_fmadd_ps(ymm2, ymm0, ymm11);\ +\ + b01 += 1; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + } + +#define BLIS_STRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) \ +\ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_insertf128_ps(ymm1, _mm_loadu_ps((float const*)(a10)), 0);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 0));\ + ymm8 = _mm256_fmadd_ps(ymm2, ymm0, ymm8);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 1));\ + ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 2));\ + ymm10 = _mm256_fmadd_ps(ymm2, ymm0, ymm10);\ +\ + b01 += 1; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + } + +#define BLIS_STRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b,p_lda,k_iter) \ +\ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_insertf128_ps(ymm1, _mm_loadu_ps((float const*)(a10)), 0);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 0));\ + ymm8 = _mm256_fmadd_ps(ymm2, ymm0, ymm8);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 1));\ + ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ +\ + b01 += 1; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + } + +#define BLIS_STRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b,p_lda,k_iter) \ +\ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_insertf128_ps(ymm1, _mm_loadu_ps((float const*)(a10)), 0);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 0));\ + ymm8 = _mm256_fmadd_ps(ymm2, ymm0, ymm8);\ +\ + b01 += 1; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + } + +#define BLIS_STRSM_SMALL_GEMM_3mx6n(a10,b01,cs_b,p_lda,k_iter) \ +\ + for(k 
= 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + xmm4 = _mm_broadcast_ss((float const*)(a10 + 2));\ + ymm0 = _mm256_insertf128_ps(ymm1, _mm_loadl_pi(xmm4,(__m64 *)(a10)), 0);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 0));\ + ymm8 = _mm256_fmadd_ps(ymm2, ymm0, ymm8);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 1));\ + ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 2));\ + ymm10 = _mm256_fmadd_ps(ymm2, ymm0, ymm10);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 3));\ + ymm11 = _mm256_fmadd_ps(ymm2, ymm0, ymm11);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 4));\ + ymm4 = _mm256_fmadd_ps(ymm2, ymm0, ymm4);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 5));\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ +\ + b01 += 1; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + } + +#define BLIS_STRSM_SMALL_GEMM_3mx4n(a10,b01,cs_b,p_lda,k_iter) \ +\ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + xmm4 = _mm_broadcast_ss((float const*)(a10 + 2));\ + ymm0 = _mm256_insertf128_ps(ymm1, _mm_loadl_pi(xmm4,(__m64 *)(a10)), 0);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 0));\ + ymm8 = _mm256_fmadd_ps(ymm2, ymm0, ymm8);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 1));\ + ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 2));\ + ymm10 = _mm256_fmadd_ps(ymm2, ymm0, ymm10);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 3));\ + ymm11 = _mm256_fmadd_ps(ymm2, ymm0, ymm11);\ +\ + b01 += 1; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + } + +#define BLIS_STRSM_SMALL_GEMM_3mx3n(a10,b01,cs_b,p_lda,k_iter) \ +\ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + xmm4 = _mm_broadcast_ss((float const*)(a10 + 2));\ + ymm0 = _mm256_insertf128_ps(ymm1, _mm_loadl_pi(xmm4,(__m64 *)(a10)), 0);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 0));\ + ymm8 = _mm256_fmadd_ps(ymm2, ymm0, ymm8);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 1));\ + ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 2));\ + ymm10 = _mm256_fmadd_ps(ymm2, ymm0, ymm10);\ +\ + b01 += 1; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + } + +#define BLIS_STRSM_SMALL_GEMM_3mx2n(a10,b01,cs_b,p_lda,k_iter) \ +\ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + xmm4 = _mm_broadcast_ss((float const*)(a10 + 2));\ + ymm0 = _mm256_insertf128_ps(ymm1, _mm_loadl_pi(xmm4,(__m64 *)(a10)), 0);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 0));\ + ymm8 = _mm256_fmadd_ps(ymm2, ymm0, ymm8);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 1));\ + ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ +\ + b01 += 1; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + } + +#define BLIS_STRSM_SMALL_GEMM_3mx1n(a10,b01,cs_b,p_lda,k_iter) \ +\ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + xmm4 = _mm_broadcast_ss((float const*)(a10 + 2));\ + ymm0 = _mm256_insertf128_ps(ymm1, _mm_loadl_pi(xmm4,(__m64 *)(a10)), 0);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 0));\ + ymm8 = _mm256_fmadd_ps(ymm2, ymm0, 
ymm8);\ +\ + b01 += 1; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + } + +#define BLIS_STRSM_SMALL_GEMM_2mx6n(a10,b01,cs_b,p_lda,k_iter) \ +\ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_insertf128_ps(ymm1, _mm_loadl_pi(xmm4,(__m64 *)(a10)), 0);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 0));\ + ymm8 = _mm256_fmadd_ps(ymm2, ymm0, ymm8);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 1));\ + ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 2));\ + ymm10 = _mm256_fmadd_ps(ymm2, ymm0, ymm10);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 3));\ + ymm11 = _mm256_fmadd_ps(ymm2, ymm0, ymm11);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 4));\ + ymm4 = _mm256_fmadd_ps(ymm2, ymm0, ymm4);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 5));\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ +\ + b01 += 1; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + } + +#define BLIS_STRSM_SMALL_GEMM_2mx4n(a10,b01,cs_b,p_lda,k_iter) \ +\ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_insertf128_ps(ymm1, _mm_loadl_pi(xmm4,(__m64 *)(a10)), 0);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 0));\ + ymm8 = _mm256_fmadd_ps(ymm2, ymm0, ymm8);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 1));\ + ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 2));\ + ymm10 = _mm256_fmadd_ps(ymm2, ymm0, ymm10);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 3));\ + ymm11 = _mm256_fmadd_ps(ymm2, ymm0, ymm11);\ +\ + b01 += 1; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + } + +#define BLIS_STRSM_SMALL_GEMM_2mx3n(a10,b01,cs_b,p_lda,k_iter) \ +\ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_insertf128_ps(ymm1, _mm_loadl_pi(xmm4,(__m64 *)(a10)), 0);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 0));\ + ymm8 = _mm256_fmadd_ps(ymm2, ymm0, ymm8);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 1));\ + ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 2));\ + ymm10 = _mm256_fmadd_ps(ymm2, ymm0, ymm10);\ +\ + b01 += 1; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + } + +#define BLIS_STRSM_SMALL_GEMM_2mx2n(a10,b01,cs_b,p_lda,k_iter) \ +\ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_insertf128_ps(ymm1, _mm_loadl_pi(xmm4,(__m64 *)(a10)), 0);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 0));\ + ymm8 = _mm256_fmadd_ps(ymm2, ymm0, ymm8);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 1));\ + ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ +\ + b01 += 1; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + } + +#define BLIS_STRSM_SMALL_GEMM_2mx1n(a10,b01,cs_b,p_lda,k_iter) \ +\ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_insertf128_ps(ymm1, _mm_loadl_pi(xmm4,(__m64 *)(a10)), 0);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 0));\ + ymm8 = _mm256_fmadd_ps(ymm2, ymm0, ymm8);\ +\ + b01 += 1; /*move to next row of B*/\ + a10 += 
p_lda;/*pointer math to calculate next block of A for GEMM*/\ + } + +#define BLIS_STRSM_SMALL_GEMM_1mx6n(a10,b01,cs_b,p_lda,k_iter) \ +\ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_broadcast_ss((float const*)(a10));\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 0));\ + ymm8 = _mm256_fmadd_ps(ymm2, ymm0, ymm8);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 1));\ + ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 2));\ + ymm10 = _mm256_fmadd_ps(ymm2, ymm0, ymm10);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 3));\ + ymm11 = _mm256_fmadd_ps(ymm2, ymm0, ymm11);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 4));\ + ymm4 = _mm256_fmadd_ps(ymm2, ymm0, ymm4);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 5));\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ +\ + b01 += 1; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + } + +#define BLIS_STRSM_SMALL_GEMM_1mx4n(a10,b01,cs_b,p_lda,k_iter) \ +\ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_broadcast_ss((float const*)(a10));\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 0));\ + ymm8 = _mm256_fmadd_ps(ymm2, ymm0, ymm8);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 1));\ + ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 2));\ + ymm10 = _mm256_fmadd_ps(ymm2, ymm0, ymm10);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 3));\ + ymm11 = _mm256_fmadd_ps(ymm2, ymm0, ymm11);\ +\ + b01 += 1; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + } + +#define BLIS_STRSM_SMALL_GEMM_1mx3n(a10,b01,cs_b,p_lda,k_iter) \ +\ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_broadcast_ss((float const*)(a10));\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 0));\ + ymm8 = _mm256_fmadd_ps(ymm2, ymm0, ymm8);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 1));\ + ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 2));\ + ymm10 = _mm256_fmadd_ps(ymm2, ymm0, ymm10);\ +\ + b01 += 1; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + } + +#define BLIS_STRSM_SMALL_GEMM_1mx2n(a10,b01,cs_b,p_lda,k_iter) \ +\ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_broadcast_ss((float const*)(a10));\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 0));\ + ymm8 = _mm256_fmadd_ps(ymm2, ymm0, ymm8);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 1));\ + ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ +\ + b01 += 1; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + } + +#define BLIS_STRSM_SMALL_GEMM_1mx1n(a10,b01,cs_b,p_lda,k_iter) \ +\ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_broadcast_ss((float const*)(a10));\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(b01 + cs_b * 0));\ + ymm8 = _mm256_fmadd_ps(ymm2, ymm0, ymm8);\ +\ + b01 += 1; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + } + +#define BLIS_PRE_STRSM_SMALL_3M_3N(AlphaVal,b11,cs_b)\ + ymm16 = _mm256_broadcast_ss((float const *)(&AlphaVal)); 
/*register to hold alpha*/\ +\ + __m128 xmm4 = _mm_broadcast_ss((float const *)(b11 + 2));\ + xmm5 = _mm_loadl_pi(xmm4,(__m64 *)(b11));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + xmm4 = _mm_broadcast_ss((float const *)(b11 + cs_b + 2));\ + xmm5 = _mm_loadl_pi(xmm4,(__m64 *)(b11 + cs_b));\ + ymm1 = _mm256_insertf128_ps(ymm1, xmm5, 0);\ + xmm4 = _mm_broadcast_ss((float const *)(b11 + cs_b*2 + 2));\ + xmm5 = _mm_loadl_pi(xmm4,(__m64 *)(b11 + cs_b*2));\ + ymm2 = _mm256_insertf128_ps(ymm2, xmm5, 0);\ +\ + ymm8 = _mm256_fmsub_ps(ymm0, ymm16, ymm8);\ + ymm9 = _mm256_fmsub_ps(ymm1, ymm16, ymm9);\ + ymm10 = _mm256_fmsub_ps(ymm2, ymm16, ymm10);\ +\ + xmm5 = _mm256_extractf128_ps(ymm8, 0);\ + _mm_storel_pi((__m64 *)(b11), xmm5);\ + _mm_store_ss((float *)(b11 + 2), _mm_permute_ps(_mm256_extractf128_ps(ymm8, 0),0x02));\ + xmm5 = _mm256_extractf128_ps(ymm9, 0);\ + _mm_storel_pi((__m64 *)(b11 + cs_b), xmm5);\ + _mm_store_ss((float *)(b11 + cs_b + 2), _mm_permute_ps(_mm256_extractf128_ps(ymm9, 0),0x02));\ + xmm5 = _mm256_extractf128_ps(ymm10, 0);\ + _mm_storel_pi((__m64 *)(b11 + cs_b*2), xmm5);\ + _mm_store_ss((float *)(b11 + cs_b*2 + 2), _mm_permute_ps(_mm256_extractf128_ps(ymm10, 0),0x02)); + +#define BLIS_PRE_STRSM_SMALL_3M_2N(AlphaVal,b11,cs_b)\ + ymm16 = _mm256_broadcast_ss((float const *)(&AlphaVal)); /*register to hold alpha*/\ +\ + __m128 xmm4 = _mm_broadcast_ss((float const *)(b11 + 2));\ + xmm5 = _mm_loadl_pi(xmm4,(__m64 *)(b11));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + xmm4 = _mm_broadcast_ss((float const *)(b11 + cs_b + 2));\ + xmm5 = _mm_loadl_pi(xmm4,(__m64 *)(b11 + cs_b));\ + ymm1 = _mm256_insertf128_ps(ymm1, xmm5, 0);\ +\ + ymm8 = _mm256_fmsub_ps(ymm0, ymm16, ymm8);\ + ymm9 = _mm256_fmsub_ps(ymm1, ymm16, ymm9);\ +\ + xmm5 = _mm256_extractf128_ps(ymm8, 0);\ + _mm_storel_pi((__m64 *)(b11), xmm5);\ + _mm_store_ss((float *)(b11 + 2), _mm_permute_ps(_mm256_extractf128_ps(ymm8, 0),0x02));\ + xmm5 = _mm256_extractf128_ps(ymm9, 0);\ + _mm_storel_pi((__m64 *)(b11 + cs_b), xmm5);\ + _mm_store_ss((float *)(b11 + cs_b + 2), _mm_permute_ps(_mm256_extractf128_ps(ymm9, 0),0x02)); + +#define BLIS_PRE_STRSM_SMALL_3M_1N(AlphaVal,b11,cs_b)\ + ymm16 = _mm256_broadcast_ss((float const *)(&AlphaVal)); /*register to hold alpha*/\ +\ + __m128 xmm4 = _mm_broadcast_ss((float const *)(b11 + 2));\ + xmm5 = _mm_loadl_pi(xmm4,(__m64 *)(b11 + cs_b * 0));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm8 = _mm256_fmsub_ps(ymm0, ymm16, ymm8);\ +\ + xmm5 = _mm256_extractf128_ps(ymm8, 0);\ + _mm_storel_pi((__m64 *)(b11), xmm5);\ + _mm_store_ss((float *)(b11 + 2), _mm_permute_ps(_mm256_extractf128_ps(ymm8, 0),0x02)); + +#define BLIS_PRE_STRSM_SMALL_2M_3N(AlphaVal,b11,cs_b)\ + ymm16 = _mm256_broadcast_ss((float const *)(&AlphaVal)); /*register to hold alpha*/\ +\ + __m128 xmm4 = _mm_broadcast_ss((float const *)(&zero));\ + xmm5 = _mm_loadl_pi(xmm4,(__m64 *)(b11 + cs_b * 0));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + xmm5 = _mm_loadl_pi(xmm4,(__m64 *)(b11 + cs_b * 1));\ + ymm1 = _mm256_insertf128_ps(ymm1, xmm5, 0);\ + xmm5 = _mm_loadl_pi(xmm4,(__m64 *)(b11 + cs_b * 2));\ + ymm2 = _mm256_insertf128_ps(ymm2, xmm5, 0);\ +\ + ymm8 = _mm256_fmsub_ps(ymm0, ymm16, ymm8);\ + ymm9 = _mm256_fmsub_ps(ymm1, ymm16, ymm9);\ + ymm10 = _mm256_fmsub_ps(ymm2, ymm16, ymm10);\ +\ + xmm5 = _mm256_extractf128_ps(ymm8, 0);\ + _mm_storel_pi((__m64 *)(b11), xmm5); /*store(B11[0-3][0])*/\ + xmm5 = _mm256_extractf128_ps(ymm9, 0);\ + _mm_storel_pi((__m64 *)(b11 + cs_b * 1), xmm5); /*store(B11[0-3][1])*/\ + xmm5 = 
_mm256_extractf128_ps(ymm10, 0);\ + _mm_storel_pi((__m64 *)(b11 + cs_b * 2), xmm5); + +#define BLIS_PRE_STRSM_SMALL_2M_2N(AlphaVal,b11,cs_b)\ + ymm16 = _mm256_broadcast_ss((float const *)(&AlphaVal)); /*register to hold alpha*/\ +\ + __m128 xmm4 = _mm_broadcast_ss((float const *)(&zero));\ + xmm5 = _mm_loadl_pi(xmm4,(__m64 *)(b11 + cs_b * 0));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + xmm5 = _mm_loadl_pi(xmm4,(__m64 *)(b11 + cs_b * 1));\ + ymm1 = _mm256_insertf128_ps(ymm1, xmm5, 0);\ +\ + ymm8 = _mm256_fmsub_ps(ymm0, ymm16, ymm8);\ + ymm9 = _mm256_fmsub_ps(ymm1, ymm16, ymm9);\ +\ + xmm5 = _mm256_extractf128_ps(ymm8, 0);\ + _mm_storel_pi((__m64 *)(b11), xmm5); /*store(B11[0-3][0])*/\ + xmm5 = _mm256_extractf128_ps(ymm9, 0);\ + _mm_storel_pi((__m64 *)(b11 + cs_b * 1), xmm5); + +#define BLIS_PRE_STRSM_SMALL_2M_1N(AlphaVal,b11,cs_b)\ + ymm16 = _mm256_broadcast_ss((float const *)(&AlphaVal)); /*register to hold alpha*/\ +\ + __m128 xmm4 = _mm_broadcast_ss((float const *)(&zero));\ + xmm5 = _mm_loadl_pi(xmm4,(__m64 *)(b11 + cs_b * 0));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm8 = _mm256_fmsub_ps(ymm0, ymm16, ymm8);\ +\ + xmm5 = _mm256_extractf128_ps(ymm8, 0);\ + _mm_storel_pi((__m64 *)(b11 + cs_b * 0), xmm5); + +#define BLIS_PRE_STRSM_SMALL_1M_3N(AlphaVal,b11,cs_b)\ + ymm16 = _mm256_broadcast_ss((float const *)(&AlphaVal)); /*register to hold alpha*/\ +\ + ymm0 = _mm256_broadcast_ss((float const *)(b11 + cs_b *0));\ + ymm1 = _mm256_broadcast_ss((float const *)(b11 + cs_b *1));\ + ymm2 = _mm256_broadcast_ss((float const *)(b11 + cs_b *2));\ +\ + ymm8 = _mm256_fmsub_ps(ymm0, ymm16, ymm8);\ + ymm9 = _mm256_fmsub_ps(ymm1, ymm16, ymm9);\ + ymm10 = _mm256_fmsub_ps(ymm2, ymm16, ymm10);\ +\ + _mm_store_ss((b11 + cs_b * 0), _mm256_extractf128_ps(ymm8, 0));\ + _mm_store_ss((b11 + cs_b * 1), _mm256_extractf128_ps(ymm9, 0));\ + _mm_store_ss((b11 + cs_b * 2), _mm256_extractf128_ps(ymm10, 0)); + +#define BLIS_PRE_STRSM_SMALL_1M_2N(AlphaVal,b11,cs_b)\ + ymm16 = _mm256_broadcast_ss((float const *)(&AlphaVal)); /*register to hold alpha*/\ +\ + ymm0 = _mm256_broadcast_ss((float const *)(b11 + cs_b *0));\ + ymm1 = _mm256_broadcast_ss((float const *)(b11 + cs_b *1));\ +\ + ymm8 = _mm256_fmsub_ps(ymm0, ymm16, ymm8);\ + ymm9 = _mm256_fmsub_ps(ymm1, ymm16, ymm9);\ +\ + _mm_store_ss((b11 + cs_b * 0), _mm256_extractf128_ps(ymm8, 0));\ + _mm_store_ss((b11 + cs_b * 1), _mm256_extractf128_ps(ymm9, 0)); + +#define BLIS_PRE_STRSM_SMALL_1M_1N(AlphaVal,b11,cs_b)\ + ymm16 = _mm256_broadcast_ss((float const *)(&AlphaVal)); /*register to hold alpha*/\ +\ + ymm0 = _mm256_broadcast_ss((float const *)(b11 + cs_b *0));\ + ymm8 = _mm256_fmsub_ps(ymm0, ymm16, ymm8);\ +\ + _mm_store_ss((b11 + cs_b * 0), _mm256_extractf128_ps(ymm8, 0)); + + +/* pre & post TRSM for Right remainder cases*/ +#define BLIS_PRE_STRSM_SMALL_3N_7M(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); /*register to hold alpha*/\ +\ + __m128 xmm6;\ + xmm6 = _mm_broadcast_ss((float const*)(b11 + 6));\ + xmm5 = _mm_loadu_ps((float const*)(b11));\ + ymm6 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + xmm5 = _mm_loadl_pi(xmm6,(__m64*)(b11 + 4));\ + ymm6 = _mm256_insertf128_ps(ymm6, xmm5, 1);\ + ymm3 = _mm256_fmsub_ps(ymm6, ymm15, ymm3);\ +\ + xmm6 = _mm_broadcast_ss((float const*)(b11 + 6 + cs_b));\ + xmm5 = _mm_loadu_ps((float const*)(b11 + cs_b));\ + ymm6 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + xmm5 = _mm_loadl_pi(xmm6,(__m64*)(b11 + 4 + cs_b));\ + ymm6 = _mm256_insertf128_ps(ymm6, xmm5, 1);\ + ymm5 = _mm256_fmsub_ps(ymm6, 
ymm15, ymm5);\ +\ + xmm6 = _mm_broadcast_ss((float const*)(b11 + 6 + cs_b*2));\ + xmm5 = _mm_loadu_ps((float const*)(b11 + cs_b*2));\ + ymm6 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + xmm5 = _mm_loadl_pi(xmm6,(__m64*)(b11 + 4 + cs_b*2));\ + ymm6 = _mm256_insertf128_ps(ymm6, xmm5, 1);\ + ymm7 = _mm256_fmsub_ps(ymm6, ymm15, ymm7); + +#define BLIS_POST_STRSM_SMALL_3N_7M(b11,cs_b)\ + xmm5 = _mm256_extractf128_ps(ymm3, 0);\ + _mm_storeu_ps((float *)(b11), xmm5);\ + _mm_storel_pi((__m64 *)(b11 + 4), _mm256_extractf128_ps(ymm3, 1));\ + _mm_store_ss((float *)(b11 + 6),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm3,ymm3), 1));\ + xmm5 = _mm256_extractf128_ps(ymm5, 0);\ + _mm_storeu_ps((float *)(b11 + cs_b*1), xmm5);\ + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*1), _mm256_extractf128_ps(ymm5, 1));\ + _mm_store_ss((float *)(b11 + 6 + cs_b),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm5,ymm5), 1));\ + xmm5 = _mm256_extractf128_ps(ymm7, 0);\ + _mm_storeu_ps((float *)(b11 + cs_b*2), xmm5);\ + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*2), _mm256_extractf128_ps(ymm7, 1));\ + _mm_store_ss((float *)(b11 + 6 + cs_b*2),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm7,ymm7), 1)); + +#define BLIS_PRE_STRSM_SMALL_3N_6M(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); /*register to hold alpha*/\ +\ + __m128 xmm6;\ + xmm6 = _mm_broadcast_ss((float const*)&zero);\ + xmm5 = _mm_loadu_ps((float const*)(b11));\ + ymm6 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + xmm5 = _mm_loadl_pi(xmm6,(__m64*)(b11 + 4));\ + ymm6 = _mm256_insertf128_ps(ymm6, xmm5, 1);\ + ymm3 = _mm256_fmsub_ps(ymm6, ymm15, ymm3);\ +\ + xmm5 = _mm_loadu_ps((float const*)(b11 + cs_b));\ + ymm6 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + xmm5 = _mm_loadl_pi(xmm6,(__m64*)(b11 + 4 + cs_b));\ + ymm6 = _mm256_insertf128_ps(ymm6, xmm5, 1);\ + ymm5 = _mm256_fmsub_ps(ymm6, ymm15, ymm5);\ +\ + xmm5 = _mm_loadu_ps((float const*)(b11 + cs_b*2));\ + ymm6 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + xmm5 = _mm_loadl_pi(xmm6,(__m64*)(b11 + 4 + cs_b*2));\ + ymm6 = _mm256_insertf128_ps(ymm6, xmm5, 1);\ + ymm7 = _mm256_fmsub_ps(ymm6, ymm15, ymm7); + +#define BLIS_POST_STRSM_SMALL_3N_6M(b11,cs_b)\ + xmm5 = _mm256_extractf128_ps(ymm3, 0);\ + _mm_storeu_ps((float *)(b11), xmm5);\ + _mm_storel_pi((__m64 *)(b11 + 4), _mm256_extractf128_ps(ymm3, 1));\ + xmm5 = _mm256_extractf128_ps(ymm5, 0);\ + _mm_storeu_ps((float *)(b11 + cs_b*1), xmm5);\ + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*1), _mm256_extractf128_ps(ymm5, 1));\ + xmm5 = _mm256_extractf128_ps(ymm7, 0);\ + _mm_storeu_ps((float *)(b11 + cs_b*2), xmm5);\ + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*2), _mm256_extractf128_ps(ymm7, 1)); + +#define BLIS_PRE_STRSM_SMALL_3N_5M(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); /*register to hold alpha*/\ +\ + ymm0 = _mm256_broadcast_ss((float const *)(b11 + 4));\ + xmm5 = _mm_loadu_ps((float const*)(b11));\ + ymm6 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + ymm3 = _mm256_fmsub_ps(ymm6, ymm15, ymm3);\ +\ + ymm0 = _mm256_broadcast_ss((float const *)(b11 + 4 + cs_b));\ + xmm5 = _mm_loadu_ps((float const*)(b11 + cs_b));\ + ymm6 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + ymm5 = _mm256_fmsub_ps(ymm6, ymm15, ymm5);\ +\ + ymm0 = _mm256_broadcast_ss((float const *)(b11 + 4 + cs_b*2));\ + xmm5 = _mm_loadu_ps((float const*)(b11 + cs_b*2));\ + ymm6 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + ymm7 = _mm256_fmsub_ps(ymm6, ymm15, ymm7); + +#define BLIS_POST_STRSM_SMALL_3N_5M(b11,cs_b)\ + xmm5 = _mm256_extractf128_ps(ymm3, 0);\ + _mm_storeu_ps((float *)(b11), xmm5);\ 
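+    /* the remaining fifth row of each column is stored from the upper 128-bit lane (5M remainder) */\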
+ _mm_store_ss((float *)(b11 + 4), _mm256_extractf128_ps(ymm3, 1));\ + xmm5 = _mm256_extractf128_ps(ymm5, 0);\ + _mm_storeu_ps((float *)(b11 + cs_b*1), xmm5);\ + _mm_store_ss((float *)(b11 + 4 + cs_b*1), _mm256_extractf128_ps(ymm5, 1));\ + xmm5 = _mm256_extractf128_ps(ymm7, 0);\ + _mm_storeu_ps((float *)(b11 + cs_b*2), xmm5);\ + _mm_store_ss((float *)(b11 + 4 + cs_b*2), _mm256_extractf128_ps(ymm7, 1)); + +#define BLIS_PRE_STRSM_SMALL_3N_4M(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); /*register to hold alpha*/\ +\ + /*__m128 xmm6 = _mm_broadcast_ss((float const *)(b11+ 2));*/\ + xmm5 = _mm_loadu_ps((float const*)(b11));\ + ymm6 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + ymm3 = _mm256_fmsub_ps(ymm6, ymm15, ymm3);\ +\ + xmm5 = _mm_loadu_ps((float const*)(b11 + cs_b));\ + ymm6 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + ymm5 = _mm256_fmsub_ps(ymm6, ymm15, ymm5);\ +\ + xmm5 = _mm_loadu_ps((float const*)(b11 + cs_b*2));\ + ymm6 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + ymm7 = _mm256_fmsub_ps(ymm6, ymm15, ymm7); + +#define BLIS_POST_STRSM_SMALL_3N_4M(b11,cs_b)\ + xmm5 = _mm256_extractf128_ps(ymm3, 0);\ + _mm_storeu_ps((float *)(b11), xmm5);\ + xmm5 = _mm256_extractf128_ps(ymm5, 0);\ + _mm_storeu_ps((float *)(b11 + cs_b*1), xmm5);\ + xmm5 = _mm256_extractf128_ps(ymm7, 0);\ + _mm_storeu_ps((float *)(b11 + cs_b*2), xmm5); + +#define BLIS_PRE_STRSM_SMALL_3N_3M(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); /*register to hold alpha*/\ +\ + __m128 xmm6 = _mm_broadcast_ss((float const *)(b11+ 2));\ + xmm5 = _mm_loadl_pi(xmm6,(__m64*)(b11));\ + ymm6 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + ymm3 = _mm256_fmsub_ps(ymm6, ymm15, ymm3);\ + xmm6 = _mm_broadcast_ss((float const *)(b11 + cs_b + 2));\ + xmm5 = _mm_loadl_pi(xmm6,(__m64*)(b11 + cs_b));\ + ymm6 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + ymm5 = _mm256_fmsub_ps(ymm6, ymm15, ymm5);\ + xmm6 = _mm_broadcast_ss((float const *)(b11 + cs_b*2 + 2));\ + xmm5 = _mm_loadl_pi(xmm6,(__m64*)(b11 + cs_b*2));\ + ymm6 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + ymm7 = _mm256_fmsub_ps(ymm6, ymm15, ymm7); + +#define BLIS_POST_STRSM_SMALL_3N_3M(b11,cs_b)\ + xmm5 = _mm256_extractf128_ps(ymm3, 0);\ + _mm_storel_pi((__m64 *)(b11), xmm5);\ + _mm_store_ss((float *)(b11 + 2),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm3,ymm3), 0));\ + xmm5 = _mm256_extractf128_ps(ymm5, 0);\ + _mm_storel_pi((__m64 *)(b11 + cs_b*1), xmm5);\ + _mm_store_ss((float *)(b11 + 2 + cs_b*1),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm5,ymm5), 0));\ + xmm5 = _mm256_extractf128_ps(ymm7, 0);\ + _mm_storel_pi((__m64 *)(b11 + cs_b*2), xmm5);\ + _mm_store_ss((float *)(b11 + 2 + cs_b*2),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm7,ymm7), 0)); + +#define BLIS_PRE_STRSM_SMALL_3N_2M(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); /*register to hold alpha*/\ +\ + xmm5 = _mm_loadl_pi(xmm5,(__m64*)(b11));\ + ymm6 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + ymm3 = _mm256_fmsub_ps(ymm6, ymm15, ymm3);\ + xmm5 = _mm_loadl_pi(xmm5,(__m64*)(b11 + cs_b));\ + ymm6 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + ymm5 = _mm256_fmsub_ps(ymm6, ymm15, ymm5);\ + xmm5 = _mm_loadl_pi(xmm5,(__m64*)(b11 + cs_b * 2));\ + ymm6 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + ymm7 = _mm256_fmsub_ps(ymm6, ymm15, ymm7); + +#define BLIS_POST_STRSM_SMALL_3N_2M(b11,cs_b)\ + /*ymm0 = _mm256_loadu_ps((float const *)b11);*/\ + ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x03);\ + /*ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b));*/\ + ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x03);\ 
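+    /* 2M remainder: blend mask 0x03 keeps only the two computed rows before the 64-bit stores below */\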
+ /*xmm5 = _mm_loadu_pd((float const*)(b11 + cs_b * 2));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);*/\ + ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x03);\ +\ + xmm5 = _mm256_extractf128_ps(ymm3, 0);\ + _mm_storel_pi((__m64 *)(b11), xmm5);\ + xmm5 = _mm256_extractf128_ps(ymm5, 0);\ + _mm_storel_pi((__m64 *)(b11 + cs_b), xmm5);\ + xmm5 = _mm256_extractf128_ps(ymm7, 0);\ + _mm_storel_pi((__m64 *)(b11 + cs_b * 2), xmm5); + +#define BLIS_PRE_STRSM_SMALL_3N_1M(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); /*register to hold alpha*/\ +\ + ymm0 = _mm256_broadcast_ss((float const *)b11);\ + ymm3 = _mm256_fmsub_ps(ymm0, ymm15, ymm3);\ +\ + ymm0 = _mm256_broadcast_ss((float const *)(b11 + cs_b));\ + ymm5 = _mm256_fmsub_ps(ymm0, ymm15, ymm5);\ +\ + ymm0 = _mm256_broadcast_ss((float const *)(b11 + cs_b*2));\ + ymm7 = _mm256_fmsub_ps(ymm0, ymm15, ymm7); + +#define BLIS_POST_STRSM_SMALL_3N_1M(b11,cs_b)\ + ymm0 = _mm256_broadcast_ss((float const *)b11);\ + ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x01);\ + ymm0 = _mm256_broadcast_ss((float const *)(b11 + cs_b));\ + ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x01);\ + ymm0 = _mm256_broadcast_ss((float const *)(b11 + cs_b*2));\ + ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x01);\ +\ + _mm_store_ss((b11 + cs_b * 0), _mm256_extractf128_ps(ymm3, 0));\ + _mm_store_ss((b11 + cs_b * 1), _mm256_extractf128_ps(ymm5, 0));\ + _mm_store_ss((b11 + cs_b * 2), _mm256_extractf128_ps(ymm7, 0)); + +#define BLIS_PRE_STRSM_SMALL_2N_7M(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); /*register to hold alpha*/\ +\ + __m128 xmm6;\ + xmm6 = _mm_broadcast_ss((float const*)(b11 + 6));\ + xmm5 = _mm_loadu_ps((float const*)(b11));\ + ymm6 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + xmm5 = _mm_loadl_pi(xmm6,(__m64*)(b11 + 4));\ + ymm6 = _mm256_insertf128_ps(ymm6, xmm5, 1);\ + ymm3 = _mm256_fmsub_ps(ymm6, ymm15, ymm3);\ +\ + xmm6 = _mm_broadcast_ss((float const*)(b11 + 6 + cs_b));\ + xmm5 = _mm_loadu_ps((float const*)(b11 + cs_b));\ + ymm6 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + xmm5 = _mm_loadl_pi(xmm6,(__m64*)(b11 + 4 + cs_b));\ + ymm6 = _mm256_insertf128_ps(ymm6, xmm5, 1);\ + ymm5 = _mm256_fmsub_ps(ymm6, ymm15, ymm5); + +#define BLIS_POST_STRSM_SMALL_2N_7M(b11,cs_b)\ + xmm5 = _mm256_extractf128_ps(ymm3, 0);\ + _mm_storeu_ps((float *)(b11), xmm5);\ + _mm_storel_pi((__m64 *)(b11 + 4), _mm256_extractf128_ps(ymm3, 1));\ + _mm_store_ss((float *)(b11 + 6),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm3,ymm3), 1));\ + xmm5 = _mm256_extractf128_ps(ymm5, 0);\ + _mm_storeu_ps((float *)(b11 + cs_b), xmm5);\ + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b), _mm256_extractf128_ps(ymm5, 1));\ + _mm_store_ss((float *)(b11 + 6 + cs_b),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm5,ymm5), 1)); + +#define BLIS_PRE_STRSM_SMALL_2N_6M(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); /*register to hold alpha*/\ +\ + __m128 xmm6;\ + xmm6 = _mm_broadcast_ss((float const*)&zero);\ + xmm5 = _mm_loadu_ps((float const*)(b11));\ + ymm6 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + xmm5 = _mm_loadl_pi(xmm6,(__m64*)(b11 + 4));\ + ymm6 = _mm256_insertf128_ps(ymm6, xmm5, 1);\ + ymm3 = _mm256_fmsub_ps(ymm6, ymm15, ymm3);\ +\ + xmm5 = _mm_loadu_ps((float const*)(b11 + cs_b));\ + ymm6 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + xmm5 = _mm_loadl_pi(xmm6,(__m64*)(b11 + 4 + cs_b));\ + ymm6 = _mm256_insertf128_ps(ymm6, xmm5, 1);\ + ymm5 = _mm256_fmsub_ps(ymm6, ymm15, ymm5); + +#define BLIS_POST_STRSM_SMALL_2N_6M(b11,cs_b)\ + xmm5 = _mm256_extractf128_ps(ymm3, 0);\ + 
_mm_storeu_ps((float *)(b11), xmm5);\ + _mm_storel_pi((__m64 *)(b11 + 4), _mm256_extractf128_ps(ymm3, 1));\ + xmm5 = _mm256_extractf128_ps(ymm5, 0);\ + _mm_storeu_ps((float *)(b11 + cs_b*1), xmm5);\ + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*1), _mm256_extractf128_ps(ymm5, 1)); + +#define BLIS_PRE_STRSM_SMALL_2N_5M(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); /*register to hold alpha*/\ +\ + ymm0 = _mm256_broadcast_ss((float const *)(b11 + 4));\ + xmm5 = _mm_loadu_ps((float const*)(b11));\ + ymm6 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + ymm3 = _mm256_fmsub_ps(ymm6, ymm15, ymm3);\ + ymm0 = _mm256_broadcast_ss((float const *)(b11 + 4 + cs_b * 1));\ + xmm5 = _mm_loadu_ps((float const*)(b11 + cs_b * 1));\ + ymm6 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + ymm5 = _mm256_fmsub_ps(ymm6, ymm15, ymm5); + +#define BLIS_POST_STRSM_SMALL_2N_5M(b11,cs_b)\ + xmm5 = _mm256_extractf128_ps(ymm3, 0);\ + _mm_storeu_ps((float *)(b11), xmm5);\ + _mm_store_ss((float *)(b11 + 4), _mm256_extractf128_ps(ymm3, 1));\ + xmm5 = _mm256_extractf128_ps(ymm5, 0);\ + _mm_storeu_ps((float *)(b11 + cs_b*1), xmm5);\ + _mm_store_ss((float *)(b11 + cs_b*1 + 4), _mm256_extractf128_ps(ymm5, 1));\ + +#define BLIS_PRE_STRSM_SMALL_2N_4M(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); /*register to hold alpha*/\ +\ + xmm5 = _mm_loadu_ps((float const*)(b11));\ + ymm6 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + ymm3 = _mm256_fmsub_ps(ymm6, ymm15, ymm3);\ + xmm5 = _mm_loadu_ps((float const*)(b11 + cs_b * 1));\ + ymm6 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + ymm5 = _mm256_fmsub_ps(ymm6, ymm15, ymm5); + +#define BLIS_POST_STRSM_SMALL_2N_4M(b11,cs_b)\ + xmm5 = _mm256_extractf128_ps(ymm3, 0);\ + _mm_storeu_ps((float *)(b11), xmm5);\ + xmm5 = _mm256_extractf128_ps(ymm5, 0);\ + _mm_storeu_ps((float *)(b11 + cs_b*1), xmm5); + +#define BLIS_PRE_STRSM_SMALL_2N_3M(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); /*register to hold alpha*/\ +\ + __m128 xmm6 = _mm_broadcast_ss((float const *)(b11+ 2));\ + xmm5 = _mm_loadl_pi(xmm6,(__m64*)(b11));\ + ymm6 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + ymm3 = _mm256_fmsub_ps(ymm6, ymm15, ymm3);\ + xmm6 = _mm_broadcast_ss((float const *)(b11 + cs_b * 1+ 2));\ + xmm5 = _mm_loadl_pi(xmm6,(__m64*)(b11 + cs_b * 1));\ + ymm6 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + ymm5 = _mm256_fmsub_ps(ymm6, ymm15, ymm5); + +#define BLIS_POST_STRSM_SMALL_2N_3M(b11,cs_b)\ + xmm5 = _mm256_extractf128_ps(ymm3, 0);\ + _mm_storel_pi((__m64 *)(b11), xmm5);\ + _mm_store_ss((float *)(b11 + 2),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm3,ymm3), 0));\ + xmm5 = _mm256_extractf128_ps(ymm5, 0);\ + _mm_storel_pi((__m64 *)(b11 + cs_b*1), xmm5);\ + _mm_store_ss((float *)(b11 + 2 + cs_b*1),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm5,ymm5), 0)); + +#define BLIS_PRE_STRSM_SMALL_2N_2M(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); /*register to hold alpha*/\ +\ + xmm5 = _mm_loadl_pi(xmm5,(__m64*)(b11));\ + ymm6 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + ymm3 = _mm256_fmsub_ps(ymm6, ymm15, ymm3);\ + xmm5 = _mm_loadl_pi(xmm5,(__m64*)(b11 + cs_b * 1));\ + ymm6 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + ymm5 = _mm256_fmsub_ps(ymm6, ymm15, ymm5); + +#define BLIS_POST_STRSM_SMALL_2N_2M(b11,cs_b)\ + /*ymm0 = _mm256_loadu_ps((float const *)b11);*/\ + /*ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x03);\ + xmm5 = _mm_loadu_pd((float const*)(b11 + cs_b * 1));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ + ymm5 = _mm256_blend_ps(ymm0, ymm5, 
0x03);*/\ +\ + xmm5 = _mm256_extractf128_ps(ymm3, 0);\ + _mm_storel_pi((__m64 *)(b11), xmm5);\ + xmm5 = _mm256_extractf128_ps(ymm5, 0);\ + _mm_storel_pi((__m64 *)(b11 + cs_b * 1), xmm5); + +#define BLIS_PRE_STRSM_SMALL_2N_1M(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); /*register to hold alpha*/\ +\ + ymm0 = _mm256_broadcast_ss((float const *)b11);\ + ymm3 = _mm256_fmsub_ps(ymm0, ymm15, ymm3);\ +\ + ymm0 = _mm256_broadcast_ss((float const *)(b11 + cs_b));\ + ymm5 = _mm256_fmsub_ps(ymm0, ymm15, ymm5); + +#define BLIS_POST_STRSM_SMALL_2N_1M(b11,cs_b)\ + /*ymm0 = _mm256_broadcast_ss((float const *)b11);\ + ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x01);\ + ymm0 = _mm256_broadcast_ss((float const *)(b11 + cs_b));\ + ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x01);*/\ +\ + _mm_store_ss(b11, _mm256_extractf128_ps(ymm3, 0));\ + _mm_store_ss((b11 + cs_b * 1), _mm256_extractf128_ps(ymm5, 0)); + +#define BLIS_PRE_STRSM_SMALL_1N_7M(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); /*register to hold alpha*/\ +\ + xmm5 = _mm_loadu_ps((float const*)(b11));\ + ymm6 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + xmm5 = _mm_broadcast_ss((float const*)(b11 + 6));\ + xmm5 = _mm_loadl_pi(xmm5,(__m64*)(b11 + 4));\ + ymm6 = _mm256_insertf128_ps(ymm6, xmm5, 1);\ + ymm3 = _mm256_fmsub_ps(ymm6, ymm15, ymm3); + +#define BLIS_POST_STRSM_SMALL_1N_7M(b11,cs_b)\ + xmm5 = _mm256_extractf128_ps(ymm3, 0);\ + _mm_storeu_ps((float *)(b11), xmm5);\ + _mm_storel_pi((__m64 *)(b11 + 4), _mm256_extractf128_ps(ymm3, 1));\ + _mm_store_ss((float *)(b11 + 6),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm3,ymm3), 1)); + +#define BLIS_PRE_STRSM_SMALL_1N_6M(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); /*register to hold alpha*/\ +\ + xmm5 = _mm_loadu_ps((float const*)(b11));\ + ymm6 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + xmm5 = _mm_loadl_pi(xmm5,(__m64*)(b11 + 4));\ + ymm6 = _mm256_insertf128_ps(ymm6, xmm5, 1);\ + ymm3 = _mm256_fmsub_ps(ymm6, ymm15, ymm3); + +#define BLIS_POST_STRSM_SMALL_1N_6M(b11,cs_b)\ + xmm5 = _mm256_extractf128_ps(ymm3, 0);\ + _mm_storeu_ps((float *)(b11), xmm5);\ + _mm_storel_pi((__m64 *)(b11 + 4), _mm256_extractf128_ps(ymm3, 1)); + +#define BLIS_PRE_STRSM_SMALL_1N_5M(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); /*register to hold alpha*/\ +\ + ymm0 = _mm256_broadcast_ss((float const *)(b11 + 4));\ + xmm5 = _mm_loadu_ps((float const*)(b11));\ + ymm6 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + ymm3 = _mm256_fmsub_ps(ymm6, ymm15, ymm3); + +#define BLIS_POST_STRSM_SMALL_1N_5M(b11,cs_b)\ + xmm5 = _mm256_extractf128_ps(ymm3, 0);\ + _mm_storeu_ps((float *)(b11), xmm5);\ + _mm_store_ss((float *)(b11 + 4), _mm256_extractf128_ps(ymm3, 1)); + +#define BLIS_PRE_STRSM_SMALL_1N_4M(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); /*register to hold alpha*/\ +\ + xmm5 = _mm_loadu_ps((float const*)(b11));\ + ymm6 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + ymm3 = _mm256_fmsub_ps(ymm6, ymm15, ymm3); + +#define BLIS_POST_STRSM_SMALL_1N_4M(b11,cs_b)\ + xmm5 = _mm256_extractf128_ps(ymm3, 0);\ + _mm_storeu_ps((float *)(b11), xmm5); + +#define BLIS_PRE_STRSM_SMALL_1N_3M(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); /*register to hold alpha*/\ +\ + __m128 xmm6 = _mm_broadcast_ss((float const *)(b11+ 2));\ + xmm5 = _mm_loadl_pi(xmm6,(__m64*)(b11));\ + ymm6 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + ymm3 = _mm256_fmsub_ps(ymm6, ymm15, ymm3); + +#define 
BLIS_POST_STRSM_SMALL_1N_3M(b11,cs_b)\ + xmm5 = _mm256_extractf128_ps(ymm3, 0);\ + _mm_storel_pi((__m64 *)(b11), xmm5);\ + _mm_store_ss((float *)(b11 + 2),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm3,ymm3), 0)); + +#define BLIS_PRE_STRSM_SMALL_1N_2M(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); /*register to hold alpha*/\ +\ + xmm5 = _mm_loadl_pi(xmm5,(__m64*)(b11));\ + ymm6 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + ymm3 = _mm256_fmsub_ps(ymm6, ymm15, ymm3); + +#define BLIS_POST_STRSM_SMALL_1N_2M(b11,cs_b)\ +\ + xmm5 = _mm256_extractf128_ps(ymm3, 0);\ + _mm_storel_pi((__m64 *)(b11), xmm5); + +#define BLIS_PRE_STRSM_SMALL_1N_1M(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); /*register to hold alpha*/\ +\ + ymm6 = _mm256_broadcast_ss((float const *)b11);\ + ymm3 = _mm256_fmsub_ps(ymm6, ymm15, ymm3); + +#define BLIS_POST_STRSM_SMALL_1N_1M(b11,cs_b)\ +\ + _mm_store_ss(b11, _mm256_extractf128_ps(ymm3, 0)); + +/* multiply with Alpha pre TRSM for 6*16 kernel*/ +#define BLIS_PRE_STRSM_SMALL_6x16(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal);\ +\ + ymm0 = _mm256_loadu_ps((float const *)b11);\ + ymm1 = _mm256_loadu_ps((float const *)(b11 + 8));\ +\ + ymm3 = _mm256_fmsub_ps(ymm0, ymm15, ymm3);\ + ymm4 = _mm256_fmsub_ps(ymm1, ymm15, ymm4);\ +\ + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b));\ + ymm1 = _mm256_loadu_ps((float const *)(b11 + cs_b + 8));\ +\ + ymm5 = _mm256_fmsub_ps(ymm0, ymm15, ymm5);\ + ymm6 = _mm256_fmsub_ps(ymm1, ymm15, ymm6);\ +\ + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2));\ + ymm1 = _mm256_loadu_ps((float const *)(b11 + cs_b*2 + 8));\ +\ + ymm7 = _mm256_fmsub_ps(ymm0, ymm15, ymm7);\ + ymm8 = _mm256_fmsub_ps(ymm1, ymm15, ymm8);\ +\ + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3));\ + ymm1 = _mm256_loadu_ps((float const *)(b11 + cs_b*3 + 8));\ +\ + ymm9 = _mm256_fmsub_ps(ymm0, ymm15, ymm9);\ + ymm10 = _mm256_fmsub_ps(ymm1, ymm15, ymm10);\ +\ + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4));\ + ymm1 = _mm256_loadu_ps((float const *)(b11 + cs_b*4 + 8));\ +\ + ymm11 = _mm256_fmsub_ps(ymm0, ymm15, ymm11);\ + ymm12 = _mm256_fmsub_ps(ymm1, ymm15, ymm12);\ +\ + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5));\ + ymm1 = _mm256_loadu_ps((float const *)(b11 + cs_b*5 + 8));\ +\ + ymm13 = _mm256_fmsub_ps(ymm0, ymm15, ymm13);\ + ymm14 = _mm256_fmsub_ps(ymm1, ymm15, ymm14); + +#define BLIS_PRE_STRSM_SMALL_4x16(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_ss((float const *)(&AlphaVal));\ +\ + ymm0 = _mm256_loadu_ps((float const *)b11);\ + ymm1 = _mm256_loadu_ps((float const *)(b11 + 8));\ +\ + ymm3 = _mm256_fmsub_ps(ymm0, ymm15, ymm3);\ + ymm4 = _mm256_fmsub_ps(ymm1, ymm15, ymm4);\ +\ + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b));\ + ymm1 = _mm256_loadu_ps((float const *)(b11 + cs_b + 8));\ +\ + ymm5 = _mm256_fmsub_ps(ymm0, ymm15, ymm5);\ + ymm6 = _mm256_fmsub_ps(ymm1, ymm15, ymm6);\ +\ + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2));\ + ymm1 = _mm256_loadu_ps((float const *)(b11 + cs_b*2 + 8));\ +\ + ymm7 = _mm256_fmsub_ps(ymm0, ymm15, ymm7);\ + ymm8 = _mm256_fmsub_ps(ymm1, ymm15, ymm8);\ +\ + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3));\ + ymm1 = _mm256_loadu_ps((float const *)(b11 + cs_b*3 + 8));\ +\ + ymm9 = _mm256_fmsub_ps(ymm0, ymm15, ymm9);\ + ymm10 = _mm256_fmsub_ps(ymm1, ymm15, ymm10); + +#define BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); /*register to hold alpha*/\ +\ + ymm0 = 
_mm256_loadu_ps((float const *)b11);\ + ymm3 = _mm256_fmsub_ps(ymm0, ymm15, ymm3);\ +\ + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b));\ + ymm5 = _mm256_fmsub_ps(ymm0, ymm15, ymm5);\ +\ + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2));\ + ymm7 = _mm256_fmsub_ps(ymm0, ymm15, ymm7);\ +\ + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3));\ + ymm9 = _mm256_fmsub_ps(ymm0, ymm15, ymm9);\ +\ + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4));\ + ymm11 = _mm256_fmsub_ps(ymm0, ymm15, ymm11);\ +\ + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5));\ + ymm13 = _mm256_fmsub_ps(ymm0, ymm15, ymm13); + +/* + Load b11 of size 6x8 and multiply with alpha + Add the GEMM output and perform inregister transose of b11 + to peform DTRSM operation for left cases. +*/ +#define BLIS_STRSM_SMALL_NREG_TRANSPOSE_6x16(b11,cs_b,AlphaVal) \ + ymm16 = _mm256_broadcast_ss((float const *)(&AlphaVal));\ +\ + ymm17 = _mm256_loadu_ps((float const *)(b11));\ + ymm18 = _mm256_loadu_ps((float const *)(b11 + cs_b));\ + ymm19 = _mm256_loadu_ps((float const *)(b11 + cs_b*2));\ + ymm20 = _mm256_loadu_ps((float const *)(b11 + cs_b*3));\ + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4));\ + ymm1 = _mm256_loadu_ps((float const *)(b11 + cs_b*5));\ +\ + ymm17 = _mm256_fmsub_ps(ymm17, ymm16, ymm8);\ + ymm18 = _mm256_fmsub_ps(ymm18, ymm16, ymm9);\ + ymm19 = _mm256_fmsub_ps(ymm19, ymm16, ymm10);\ + ymm20 = _mm256_fmsub_ps(ymm20, ymm16, ymm11);\ + ymm0 = _mm256_fmsub_ps(ymm0 , ymm16, ymm4);\ + ymm1 = _mm256_fmsub_ps(ymm1 , ymm16, ymm5);\ +\ + ymm8 = _mm256_unpacklo_ps(ymm17, ymm18);\ + ymm9 = _mm256_unpacklo_ps(ymm19, ymm20);\ +\ + ymm16 = _mm256_unpacklo_ps(ymm0, ymm1);\ +\ + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b01000100);\ + ymm5 = _mm256_shuffle_ps(ymm16,ymm16,0b01000100);\ +\ + ymm10 = _mm256_permute2f128_ps(ymm4,ymm5,0x20);/*1*/\ + ymm2 = _mm256_permute2f128_ps(ymm4,ymm5,0x31);/*5*/\ +\ + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b11101110);\ + ymm5 = _mm256_shuffle_ps(ymm16,ymm16,0b11101110);\ +\ + ymm11 = _mm256_permute2f128_ps(ymm4,ymm5,0x20);/*2*/\ + ymm3 = _mm256_permute2f128_ps(ymm4,ymm5,0x31);/*6*/\ +\ + ymm8 = _mm256_unpackhi_ps(ymm17, ymm18);\ + ymm9 = _mm256_unpackhi_ps(ymm19, ymm20);\ +\ + ymm16 = _mm256_unpackhi_ps(ymm0, ymm1);\ +\ + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b01000100);\ + ymm5 = _mm256_shuffle_ps(ymm16,ymm16,0b01000100);\ +\ + ymm17 = _mm256_permute2f128_ps(ymm4,ymm5,0x20);/*3*/\ + ymm19 = _mm256_permute2f128_ps(ymm4,ymm5,0x31);/*7*/\ +\ + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b11101110);\ + ymm5 = _mm256_shuffle_ps(ymm16,ymm16,0b11101110);\ +\ + ymm18 = _mm256_permute2f128_ps(ymm4,ymm5,0x20);/*4*/\ + ymm20 = _mm256_permute2f128_ps(ymm4,ymm5,0x31);/*8*/\ +\ + ymm16 = _mm256_broadcast_ss((float const *)(&AlphaVal));\ +\ + ymm8 = _mm256_loadu_ps((float const *)(b11 + 8));\ + ymm9 = _mm256_loadu_ps((float const *)(b11 + cs_b + 8));\ + ymm4 = _mm256_loadu_ps((float const *)(b11 + cs_b*2 + 8));\ + ymm5 = _mm256_loadu_ps((float const *)(b11 + cs_b*3 + 8));\ + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4 + 8));\ + ymm1 = _mm256_loadu_ps((float const *)(b11 + cs_b*5 + 8));\ +\ + ymm8 = _mm256_fmsub_ps(ymm8, ymm16, ymm12);\ + ymm9 = _mm256_fmsub_ps(ymm9, ymm16, ymm13);\ + ymm4 = _mm256_fmsub_ps(ymm4, ymm16, ymm14);\ + ymm5 = _mm256_fmsub_ps(ymm5, ymm16, ymm15);\ + ymm0 = _mm256_fmsub_ps(ymm0, ymm16, ymm6);\ + ymm1 = _mm256_fmsub_ps(ymm1, ymm16, ymm7);\ +\ + ymm12 = _mm256_unpacklo_ps(ymm8, ymm9);\ + ymm13 = _mm256_unpacklo_ps(ymm4, ymm5);\ +\ + ymm16 = _mm256_unpacklo_ps(ymm0, ymm1);\ +\ + 
ymm6 = _mm256_shuffle_ps(ymm12,ymm13,0b01000100);\ + ymm7 = _mm256_shuffle_ps(ymm16,ymm16,0b01000100);\ +\ + ymm14 = _mm256_permute2f128_ps(ymm6,ymm7,0x20);/*1*/\ + ymm21 = _mm256_permute2f128_ps(ymm6,ymm7,0x31);/*5*/\ +\ + ymm6 = _mm256_shuffle_ps(ymm12,ymm13,0b11101110);\ + ymm7 = _mm256_shuffle_ps(ymm16,ymm16,0b11101110);\ +\ + ymm15 = _mm256_permute2f128_ps(ymm6,ymm7,0x20);/*2*/\ + ymm22 = _mm256_permute2f128_ps(ymm6,ymm7,0x31);/*6*/\ +\ + ymm12 = _mm256_unpackhi_ps(ymm8, ymm9);\ + ymm13 = _mm256_unpackhi_ps(ymm4, ymm5);\ +\ + ymm16 = _mm256_unpackhi_ps(ymm0, ymm1);\ +\ + ymm6 = _mm256_shuffle_ps(ymm12,ymm13,0b01000100);\ + ymm7 = _mm256_shuffle_ps(ymm16,ymm16,0b01000100);\ +\ + ymm8 = _mm256_permute2f128_ps(ymm6,ymm7,0x20);/*3*/\ + ymm4 = _mm256_permute2f128_ps(ymm6,ymm7,0x31);/*7*/\ +\ + ymm6 = _mm256_shuffle_ps(ymm12,ymm13,0b11101110);\ + ymm7 = _mm256_shuffle_ps(ymm16,ymm16,0b11101110);\ +\ + ymm9 = _mm256_permute2f128_ps(ymm6,ymm7,0x20);/*4*/\ + ymm5 = _mm256_permute2f128_ps(ymm6,ymm7,0x31);/*8*/ + +#define BLIS_STRSM_SMALL_NREG_TRANSPOSE_16x6_AND_STORE(b11,cs_b)\ + ymm0 = _mm256_unpacklo_ps(ymm10, ymm11);\ + ymm1 = _mm256_unpacklo_ps(ymm17, ymm18);\ +\ + ymm6 = _mm256_unpacklo_ps(ymm2, ymm3);\ + ymm7 = _mm256_unpacklo_ps(ymm19, ymm20);\ +\ + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b01000100);\ + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b01000100);\ +\ + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);/*1*/\ + _mm256_storeu_ps((float *)(b11), ymm16);\ +\ + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x31);/*5*/\ + _mm256_storeu_ps((float *)(b11 + 4*cs_b), ymm16);\ +\ + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b11101110);\ + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b11101110);\ +\ + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);/*2*/\ + _mm256_storeu_ps((float *)(b11 + cs_b), ymm16);\ +\ + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x31);/*6*/\ + _mm256_storeu_ps((float *)(b11 + 5*cs_b), ymm16);\ +\ + ymm0 = _mm256_unpackhi_ps(ymm10, ymm11);\ + ymm1 = _mm256_unpackhi_ps(ymm17, ymm18);\ +\ + ymm6 = _mm256_unpackhi_ps(ymm2, ymm3);\ + ymm7 = _mm256_unpackhi_ps(ymm19, ymm20);\ +\ + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b01000100);\ + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b01000100);\ +\ + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);/*3*/\ + _mm256_storeu_ps((float *)(b11 + 2*cs_b), ymm16);\ +\ + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b11101110);\ + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b11101110);\ +\ + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);/*4*/\ + _mm256_storeu_ps((float *)(b11 + 3*cs_b), ymm16);\ +\ + ymm0 = _mm256_unpacklo_ps(ymm14, ymm15);\ + ymm1 = _mm256_unpacklo_ps(ymm8, ymm9);\ +\ + ymm6 = _mm256_unpacklo_ps(ymm21, ymm22);\ + ymm7 = _mm256_unpacklo_ps(ymm4, ymm5);\ +\ + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b01000100);\ + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b01000100);\ +\ + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);/*1*/\ + _mm256_storeu_ps((float *)(b11 + 8), ymm16);\ +\ + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x31);/*5*/\ + _mm256_storeu_ps((float *)(b11 + 4*cs_b + 8), ymm16);\ +\ + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b11101110);\ + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b11101110);\ +\ + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);/*2*/\ + _mm256_storeu_ps((float *)(b11 + cs_b + 8), ymm16);\ +\ + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x31);/*6*/\ + _mm256_storeu_ps((float *)(b11 + 5*cs_b + 8), ymm16);\ +\ + ymm0 = _mm256_unpackhi_ps(ymm14, ymm15);\ + ymm1 = _mm256_unpackhi_ps(ymm8, ymm9);\ +\ + ymm6 = _mm256_unpackhi_ps(ymm21, ymm22);\ + ymm7 = 
_mm256_unpackhi_ps(ymm4, ymm5);\ +\ + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b01000100);\ + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b01000100);\ +\ + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);/*3*/\ + _mm256_storeu_ps((float *)(b11 + 2*cs_b + 8), ymm16);\ +\ + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b11101110);\ + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b11101110);\ +\ + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);/*4*/\ + _mm256_storeu_ps((float *)(b11 + 3*cs_b + 8), ymm16); + +/* + Pack a block of 8xk or 6xk from input buffer into packed buffer + directly or after transpose based on input params +*/ +BLIS_INLINE void bli_dtrsm_small_pack +( + char side, + dim_t size, + bool trans, + double *inbuf, + dim_t cs_a, + double *pbuff, + dim_t p_lda, + dim_t mr +) +{ + //scratch registers + __m256d ymm0, ymm1, ymm2, ymm3; + __m256d ymm4, ymm5, ymm6, ymm7; + __m256d ymm8, ymm9, ymm10, ymm11; + __m256d ymm12, ymm13; + __m128d xmm0,xmm1,xmm2,xmm3; + double zero = 0.0; + + if(side=='L'||side=='l') + { + /*Left case is 8xk*/ + if(trans) + { + /* + ------------- ------------- + | | | | | + | 4x8 | | | | + ------------- ==> | 8x4 | 8x4 | + | 4x8 | | | | + | | | | | + ------------- ------------- + */ + for(dim_t x = 0; x < size; x += mr) + { + ymm0 = _mm256_loadu_pd((double const *)(inbuf)); + ymm10 = _mm256_loadu_pd((double const *)(inbuf + 4)); + ymm1 = _mm256_loadu_pd((double const *)(inbuf + cs_a)); + ymm11 = _mm256_loadu_pd((double const *)(inbuf + 4 + cs_a)); + ymm2 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 2)); + ymm12 = _mm256_loadu_pd((double const *)(inbuf + 4 + cs_a * 2)); + ymm3 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 3)); + ymm13 = _mm256_loadu_pd((double const *)(inbuf + 4 + cs_a * 3)); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(pbuff), ymm6); + _mm256_storeu_pd((double *)(pbuff + p_lda), ymm7); + _mm256_storeu_pd((double *)(pbuff + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(pbuff + p_lda*3), ymm9); + + ymm4 = _mm256_unpacklo_pd(ymm10, ymm11); + ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); + + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm10, ymm11); + ymm1 = _mm256_unpackhi_pd(ymm12, ymm13); + + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(pbuff + p_lda * 4), ymm6); + _mm256_storeu_pd((double *)(pbuff + p_lda * 5), ymm7); + _mm256_storeu_pd((double *)(pbuff + p_lda * 6), ymm8); + _mm256_storeu_pd((double *)(pbuff + p_lda * 7), ymm9); + + ymm0 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 4)); ymm10 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 4 + 4)); ymm1 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 5)); ymm11 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 5 + 4)); @@ -1639,6097 +3503,25840 @@ BLIS_INLINE void bli_dtrsm_small_pack ymm3 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 7)); ymm13 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 7 + 4)); - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - ymm0 = 
_mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(pbuff + 4), ymm6); + _mm256_storeu_pd((double *)(pbuff + 4 + p_lda), ymm7); + _mm256_storeu_pd((double *)(pbuff + 4 + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(pbuff + 4 + p_lda*3), ymm9); + + ymm4 = _mm256_unpacklo_pd(ymm10, ymm11); + ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + ymm0 = _mm256_unpackhi_pd(ymm10, ymm11); + ymm1 = _mm256_unpackhi_pd(ymm12, ymm13); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(pbuff + 4 + p_lda * 4), ymm6); + _mm256_storeu_pd((double *)(pbuff + 4 + p_lda * 5), ymm7); + _mm256_storeu_pd((double *)(pbuff + 4 + p_lda * 6), ymm8); + _mm256_storeu_pd((double *)(pbuff + 4 + p_lda * 7), ymm9); + + inbuf += mr; + pbuff += mr*mr; + } + }else + { + //Expected multiples of 4 + p_lda = 8; + for(dim_t x = 0; x < size; x++) + { + ymm0 = _mm256_loadu_pd((double const *)(inbuf)); + _mm256_storeu_pd((double *)(pbuff), ymm0); + ymm1 = _mm256_loadu_pd((double const *)(inbuf + 4)); + _mm256_storeu_pd((double *)(pbuff + 4), ymm1); + inbuf+=cs_a; + pbuff+=p_lda; + } + } + }else if(side=='R'||side=='r') + { + + if(trans) + { + /* + ------------------ ---------- + | | | | | | + | 4x4 | 4x4 | | 4x4 |4x2 | + ------------- ==> ------------- + | | | | | | + | 2x4 | 2x4 | | 2x4 |2x2 | + ------------------- ------------- + */ + for(dim_t x=0; x>2); i++) + { + ymm0 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 0 )); + _mm256_storeu_pd((double *)(pbuff + p_lda * 0), ymm0); + ymm1 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 1 )); + _mm256_storeu_pd((double *)(pbuff + p_lda * 1), ymm1); + ymm2 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 2)); + _mm256_storeu_pd((double *)(pbuff + p_lda * 2), ymm2); + ymm3 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 3 )); + _mm256_storeu_pd((double *)(pbuff + p_lda * 3), ymm3); + ymm0 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 4 )); + _mm256_storeu_pd((double *)(pbuff + p_lda * 4), ymm0); + ymm1 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 5)); + _mm256_storeu_pd((double *)(pbuff + p_lda * 5), ymm1); + inbuf += 4; + pbuff += 4; + } + + if(size & 0x3) + { + xmm0 = _mm_loadu_pd((double const *)(inbuf + cs_a * 0)); + _mm_storeu_pd((double *)(pbuff + p_lda * 0 ), xmm0); + xmm1 = _mm_loadu_pd((double const *)(inbuf + cs_a * 1)); + _mm_storeu_pd((double *)(pbuff + p_lda * 1), xmm1); + xmm2 = _mm_loadu_pd((double const *)(inbuf + cs_a * 2)); + _mm_storeu_pd((double *)(pbuff + p_lda * 2), xmm2); + xmm3 = _mm_loadu_pd((double const *)(inbuf + cs_a * 3)); + _mm_storeu_pd((double *)(pbuff + p_lda * 3), xmm3); + xmm0 = _mm_loadu_pd((double const *)(inbuf + cs_a * 4)); + _mm_storeu_pd((double *)(pbuff + p_lda * 4), xmm0); + xmm1 = _mm_loadu_pd((double const *)(inbuf + cs_a * 5)); + _mm_storeu_pd((double *)(pbuff + p_lda * 5), xmm1); + } + } + } +} +/* + Pack diagonal elements of A block (8 or 6) into an 
array + a. This helps in utilze cache line efficiently in TRSM operation + b. store ones when input is unit diagonal +*/ +BLIS_INLINE void dtrsm_small_pack_diag_element +( + bool is_unitdiag, + double *a11, + dim_t cs_a, + double *d11_pack, + dim_t size +) +{ + __m256d ymm0, ymm1, ymm2, ymm3; + __m256d ymm4, ymm5; + double ones = 1.0; + bool is_eight = (size==8) ? 1 : 0; + ymm4 = ymm5 = _mm256_broadcast_sd((double const *)&ones); + if(!is_unitdiag) + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a +1)); + ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a*2 + 2)); + ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a*3 + 3)); + + //Pick one element each column and create a 4 element vector and store + ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); + + #ifdef BLIS_DISABLE_TRSM_PREINVERSION + ymm4 = ymm1; + #endif + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + ymm4 = _mm256_div_pd(ymm4, ymm1); + #endif + + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11 + 4 + cs_a*4)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5 + cs_a*5)); + //Pick one element each column and create a 4 element vector and store + ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); + if(is_eight) { + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6 + cs_a*6)); + ymm3 = _mm256_broadcast_sd((double const *)(a11 + 7 + cs_a*7)); + ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); + } + ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); + + #ifdef BLIS_DISABLE_TRSM_PREINVERSION + ymm5 = ymm1; + #endif + + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + ymm5 = _mm256_div_pd(ymm5, ymm1); + #endif + } + _mm256_store_pd((double *)(d11_pack), ymm4); + if(is_eight){ + _mm256_store_pd((double *)(d11_pack + 4), ymm5); + }else{ + _mm_storeu_pd((double *)(d11_pack + 4), _mm256_extractf128_pd(ymm5,0)); + } +} + +/* + * Kernels Table +*/ +trsmsmall_ker_ft ker_fps[4][8] = +{ + {bli_strsm_small_AutXB_AlXB, + bli_strsm_small_AltXB_AuXB, + bli_strsm_small_AltXB_AuXB, + bli_strsm_small_AutXB_AlXB, + bli_strsm_small_XAutB_XAlB, + bli_strsm_small_XAltB_XAuB, + bli_strsm_small_XAltB_XAuB, + bli_strsm_small_XAutB_XAlB }, + + {bli_ctrsm_small_AutXB_AlXB, + bli_ctrsm_small_AltXB_AuXB, + bli_ctrsm_small_AltXB_AuXB, + bli_ctrsm_small_AutXB_AlXB, + bli_ctrsm_small_XAutB_XAlB, + bli_ctrsm_small_XAltB_XAuB, + bli_ctrsm_small_XAltB_XAuB, + bli_ctrsm_small_XAutB_XAlB }, + + {bli_dtrsm_small_AutXB_AlXB, + bli_dtrsm_small_AltXB_AuXB, + bli_dtrsm_small_AltXB_AuXB, + bli_dtrsm_small_AutXB_AlXB, + bli_dtrsm_small_XAutB_XAlB, + bli_dtrsm_small_XAltB_XAuB, + bli_dtrsm_small_XAltB_XAuB, + bli_dtrsm_small_XAutB_XAlB }, + + {bli_ztrsm_small_AutXB_AlXB, + bli_ztrsm_small_AltXB_AuXB, + bli_ztrsm_small_AltXB_AuXB, + bli_ztrsm_small_AutXB_AlXB, + bli_ztrsm_small_XAutB_XAlB, + bli_ztrsm_small_XAltB_XAuB, + bli_ztrsm_small_XAltB_XAuB, + bli_ztrsm_small_XAutB_XAlB }, +}; + +/* +* The bli_trsm_small implements a version of TRSM where A is packed and reused +* +* Input: A: MxM (triangular matrix) +* B: MxN matrix +* Output: X: MxN matrix such that + AX = alpha*B or XA = alpha*B or A'X = alpha*B or XA' = alpha*B +* Here the output X is stored in B +* +* Note: Currently only dtrsm is supported when A & B are column-major +*/ +err_t bli_trsm_small +( + side_t side, + obj_t* alpha, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +) +{ + err_t err; + dim_t m = bli_obj_length(b); + dim_t n = 
bli_obj_width(b); + + if(!(m && n)) { + return BLIS_SUCCESS; + } + + bool uplo = bli_obj_is_upper(a); + bool transa = bli_obj_has_trans(a); + + /* ToDo: Temporary threshold condition for trsm single thread. + * It will be updated with arch based threshold function which reads + * tunned thresholds for all 64 (datatype,side,uplo,transa,unit,) trsm + combinations. We arrived to this condition based on performance + comparsion with only available native path + */ + if(m > 1000 || n > 1000) { + return BLIS_NOT_YET_IMPLEMENTED; + } + + /* If alpha is zero, B matrix will become zero after scaling + hence solution is also zero matrix */ + if (bli_obj_equals(alpha, &BLIS_ZERO)) { + return BLIS_NOT_YET_IMPLEMENTED; // scale B by alpha + } + + // Return if inputs are row major as currently + // we are supporing col major only + if ((bli_obj_row_stride(a) != 1) || + (bli_obj_row_stride(b) != 1)) { + return BLIS_INVALID_ROW_STRIDE; + } + + //Curretnly optimized for double data type only + num_t dt = bli_obj_dt(a); + if (dt != BLIS_DOUBLE && dt != BLIS_FLOAT && dt != BLIS_DCOMPLEX) { + return BLIS_NOT_YET_IMPLEMENTED; + } + + // A is expected to be triangular in trsm + if (!bli_obj_is_upper_or_lower (a)) { + return BLIS_EXPECTED_TRIANGULAR_OBJECT; + } + + /* + * Compose kernel index based on inputs + */ + + dim_t keridx = ( (( side & 0x1) << 2) | + (( uplo & 0x1) << 1) | + ( transa & 0x1) ); + + + trsmsmall_ker_ft ker_fp = ker_fps[dt][ keridx ]; + + /*Call the kernel*/ + err = ker_fp + ( + alpha, + a, + b, + cntx, + cntl + ); + + return err; +}; + +/* + * ZTRSM utilities and kernel functions + */ + +#define DCOMPLEX_INV(a, b) {\ + a.real = b.real;\ + a.imag = (b.imag * -1.0);\ + /*Compute denominator eliminating imaginary component*/\ + double dnm = (b.real * b.real);\ + /*multiply two times with -1 for correct result as + * dcomplex number with positive imaginary part will + * invert the sign if not multiplied twice with -1*/\ + dnm += ((-1.0 * (b.imag * b.imag)) * -1.0);\ + /*Compute the final result by dividing real and imag part by dnm*/\ + a.real /= dnm;\ + a.imag /= dnm;\ +} + +#define DCOMPLEX_MUL(a, b, c) {\ + double real = a.real * b.real;\ + real += ((a.imag * b.imag) * -1.0);\ + double imag = (a.real * b.imag);\ + imag += (a.imag * b.real);\ + c.real = real;\ + c.imag = imag;\ +} + +#define DCOMPLEX_DIV(a, b){\ + double dnm = b.real * b.real;\ + dnm += (-1.0 * (b.imag * (b.imag * -1.0) ));\ + a.real /= dnm;\ + a.imag /= dnm;\ +} + + +#ifdef BLIS_ENABLE_TRSM_PREINVERSION +#define ZTRSM_DIAG_ELE_INV_OPS(a,b){\ + DCOMPLEX_INV(a, b)\ +} +#endif + +#ifdef BLIS_DISABLE_TRSM_PREINVERSION +#define ZTRSM_DIAG_ELE_INV_OPS(a,b) {\ + a.real = b.real;\ + a.imag = b.imag;\ +} +#endif + + +#ifdef BLIS_ENABLE_TRSM_PREINVERSION +#define ZTRSM_DIAG_ELE_EVAL_OPS(a,b,c){\ + if(!is_unitdiag)\ + DCOMPLEX_MUL(b, c, c)\ +} +#endif + +#ifdef BLIS_DISABLE_TRSM_PREINVERSION +#define ZTRSM_DIAG_ELE_EVAL_OPS(a,b,c){\ + if(!is_unitdiag)\ + {\ + a.real = b.real;\ + a.imag = (b.imag * -1.0);\ + DCOMPLEX_MUL(c, a, c)\ + DCOMPLEX_DIV(c, b)\ + }\ +} +#endif + +BLIS_INLINE err_t ztrsm_AltXB_ref +( + dcomplex *A, + dcomplex *B, + dim_t M, + dim_t N, + dim_t lda, + dim_t ldb, + bool is_unitdiag, + bool conjtransa +) +{ + dim_t i, j, k; + for (k = M-1; k >= 0; k--) + { + dcomplex lkk_inv = {1.0, 1.0}, cur_compute = {0.0, 0.0}, A_trans = {0.0, 0.0}; + if(!is_unitdiag) + { + ZTRSM_DIAG_ELE_INV_OPS(lkk_inv, A[k+k*lda]) + if(conjtransa) + { + lkk_inv.imag *= -1.0; + } + } + for (j = N -1; j >= 0; j--) + { + 
ZTRSM_DIAG_ELE_EVAL_OPS(cur_compute, lkk_inv, B[k + j*ldb]) + for (i = k-1; i >=0; i--) + { + if(conjtransa) + { + A_trans.real = A[i*lda + k].real; + A_trans.imag = A[i*lda + k].imag * -1.0; + } + else + { + A_trans.real = A[i*lda + k].real; + A_trans.imag = A[i*lda + k].imag; + } + + + DCOMPLEX_MUL(A_trans, B[k+j*ldb], cur_compute) + B[i + j*ldb].real -= cur_compute.real; + B[i + j*ldb].imag -= cur_compute.imag; + } + } + } + return BLIS_SUCCESS; +} + +BLIS_INLINE err_t ztrsm_AutXB_ref +( + dcomplex *A, + dcomplex *B, + dim_t M, + dim_t N, + dim_t lda, + dim_t ldb, + bool is_unitdiag, + bool conjtransa +) +{ + dim_t i, j, k; + for (k = 0; k < M; k++) + { + dcomplex lkk_inv = {1.0, 1.0}, cur_compute = {0.0, 0.0}, A_trans = {0.0, 0.0}; + if(!is_unitdiag) + { + ZTRSM_DIAG_ELE_INV_OPS(lkk_inv, A[k+k*lda]) + if(conjtransa) + { + lkk_inv.imag *= -1.0; + } + } + + for (j = 0; j < N; j++) + { + ZTRSM_DIAG_ELE_EVAL_OPS(cur_compute, lkk_inv, B[k + j*ldb]) + for (i = k+1; i < M; i++) + { + if(conjtransa) + { + A_trans.real = A[k+i*lda].real; + A_trans.imag = A[k+i*lda].imag * -1.0; + } + else + { + A_trans.real = A[k+i*lda].real; + A_trans.imag = A[k+i*lda].imag; + } + + DCOMPLEX_MUL(A_trans, B[k+j*ldb], cur_compute) + B[i + j*ldb].real -= cur_compute.real; + B[i + j*ldb].imag -= cur_compute.imag; + } + + } + + } + return BLIS_SUCCESS; +} + +BLIS_INLINE err_t ztrsm_AlXB_ref +( + dcomplex *A, + dcomplex *B, + dim_t M, + dim_t N, + dim_t lda, + dim_t ldb, + bool is_unitdiag, + bool conjtransa +) +{ + dim_t i, j, k; + for (k = 0; k < M; k++) + { + dcomplex lkk_inv = {1.0, 1.0}, cur_compute = {0.0, 0.0}, A_trans = {0.0, 0.0}; + if(!is_unitdiag) + { + ZTRSM_DIAG_ELE_INV_OPS(lkk_inv, A[k+k*lda]) + if(conjtransa) + { + lkk_inv.imag *= -1.0; + } + } + for (j = 0; j < N; j++) + { + ZTRSM_DIAG_ELE_EVAL_OPS(cur_compute, lkk_inv, B[k + j*ldb]) + for (i = k+1; i < M; i++) + { + if(conjtransa) + { + A_trans.real = A[i+k*lda].real; + A_trans.imag = A[i+k*lda].imag * -1.0; + } + else + { + A_trans.real = A[i+k*lda].real; + A_trans.imag = A[i+k*lda].imag; + } + DCOMPLEX_MUL(A_trans, B[k+j*ldb], cur_compute) + B[i + j*ldb].real -= cur_compute.real; + B[i + j*ldb].imag -= cur_compute.imag; + } + } + } + return BLIS_SUCCESS; +} + +BLIS_INLINE err_t ztrsm_AuXB_ref +( + dcomplex *A, + dcomplex *B, + dim_t M, + dim_t N, + dim_t lda, + dim_t ldb, + bool is_unitdiag, + bool conjtransa +) +{ + dim_t i, j, k; + for (k = M-1; k >= 0; k--) + { + dcomplex lkk_inv = {1.0, 1.0}, cur_compute = {0.0, 0.0}, A_trans = {0.0, 0.0}; + if(!is_unitdiag) + { + ZTRSM_DIAG_ELE_INV_OPS(lkk_inv, A[k+k*lda]) + if(conjtransa) + { + lkk_inv.imag *= -1.0; + } + + } + for (j = N -1; j >= 0; j--) + { + ZTRSM_DIAG_ELE_EVAL_OPS(cur_compute, lkk_inv, B[k + j*ldb]) + for (i = k-1; i >=0; i--) + { + if(conjtransa) + { + A_trans.real = A[i+k*lda].real; + A_trans.imag = A[i+k*lda].imag * -1.0; + } + else + { + A_trans.real = A[i+k*lda].real; + A_trans.imag = A[i+k*lda].imag; + } + + DCOMPLEX_MUL(A_trans, B[k+j*ldb], cur_compute) + B[i + j*ldb].real -= cur_compute.real; + B[i + j*ldb].imag -= cur_compute.imag; + } + } + } + return BLIS_SUCCESS; +} + +/** + * Multiplies Alpha with one dcomplex + * element of one column. 
+ * One xmm register holds one dcomplex + * element only(real(64 bit) + imaginary(64 bit)) + */ +#define BLIS_PRE_ZTRSM_SMALL_1M_1N(AlphaVal,b11,cs_b) {\ + /*register to hold alpha*/\ + ymm16 = _mm256_broadcast_pd(( __m128d const *)(&AlphaVal));\ + \ + /*load dcomplex elements*/\ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 0));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ + /*to negate the real part of complex number*/\ + ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ + /*dcomplex multiplication and substraction*/\ + /*swaps position of real and imag components of complex number*/\ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + /*multiply with modified vec2 */\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm0, ymm16);\ + /*multiply with vec2 */\ + ymm14 = _mm256_mul_pd(ymm0, ymm14);\ + /*get the dcomplex mul answer into register*/\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm8 = _mm256_sub_pd(ymm15,ymm8);\ + xmm5 = _mm256_extractf128_pd(ymm8, 0);\ + /*store dcomplex elements*/\ + _mm_storeu_pd((double *)(b11 + cs_b * 0), xmm5);\ +} + +/** + * Multiplies Alpha with one dcomplex + * element of two columns. + */ +#define BLIS_PRE_ZTRSM_SMALL_1M_2N(AlphaVal,b11,cs_b) {\ + /*register to hold alpha*/\ + ymm16 = _mm256_broadcast_pd(( __m128d const*)(&AlphaVal));\ + \ + /*ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));*/\ + xmm4 = _mm_loadu_pd((double const *)(b11 + cs_b * 0));\ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 1));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm4, 0);\ + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0);\ + /*to negate the real part of complex number*/\ + ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ + /*swaps position of real and imag components of complex number*/\ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + /*dcomplex multiplication and substraction*/\ + /*multiply with modified vec2 */\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm0, ymm16);\ + /*multiply with vec2 */\ + ymm14 = _mm256_mul_pd(ymm0, ymm14);\ + /*get the dcomplex mul answer into register*/\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm8 = _mm256_sub_pd(ymm15,ymm8);\ + \ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm1, ymm16);\ + ymm14 = _mm256_mul_pd(ymm1, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm9 = _mm256_sub_pd(ymm15,ymm9);\ + xmm4 = _mm256_extractf128_pd(ymm8, 0);\ + _mm_storeu_pd((double *)(b11 + cs_b * 0), xmm4);\ + xmm5 = _mm256_extractf128_pd(ymm9, 0);\ + _mm_storeu_pd((double *)(b11 + cs_b * 1), xmm5);\ +} + +#define BLIS_ZTRSM_SMALL_NREG_TRANSPOSE_1x4(b11,cs_b,AlphaVal) {\ + ymm16 = _mm256_broadcast_pd(( __m128d const *)&AlphaVal);\ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));\ + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 2));\ + ymm1 = _mm256_broadcast_pd((__m128d const *)(&ones));\ + ymm5 = _mm256_broadcast_pd((__m128d const *)(&ones));\ + \ + ymm14 = _mm256_shuffle_pd(ymm16, ymm16, 5);\ + \ + /*dcomplex multiplication and substraction*/\ + ymm17 = _mm256_shuffle_pd(ymm0, ymm0, 15);\ + ymm18 = _mm256_shuffle_pd(ymm0, ymm0,0);\ + ymm19 = _mm256_mul_pd(ymm17, ymm14);\ + ymm15 = _mm256_fmaddsub_pd(ymm18, ymm16, ymm19);\ + ymm0 = _mm256_sub_pd(ymm15, ymm8);\ + \ + /*dcomplex multiplication and substraction*/\ + ymm17 = _mm256_shuffle_pd(ymm4, ymm4, 15);\ + ymm18 = _mm256_shuffle_pd(ymm4, ymm4,0);\ + ymm19 = _mm256_mul_pd(ymm17, ymm14);\ + ymm15 = _mm256_fmaddsub_pd(ymm18, ymm16, ymm19);\ + ymm4 = _mm256_sub_pd(ymm15, ymm12);\ +} + 
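/*
 * Editor's note: the BLIS_PRE_ZTRSM_SMALL_* macros above and below all follow
 * the same pattern: load a few dcomplex elements of b11, multiply them by
 * alpha using the permute/mul/hsub sequence, subtract the accumulated GEMM
 * result, and write the result back. A minimal scalar sketch of that
 * per-element computation is given here for reference only; the helper name
 * zscal_sub_ref and its argument layout are hypothetical and not part of
 * this patch.
 */
static inline void zscal_sub_ref(dcomplex alpha, dcomplex *b11, dcomplex acc)
{
    /* complex multiply alpha * b11 (what the mul/permute/hsub steps compute) */
    double re = (alpha.real * b11->real) - (alpha.imag * b11->imag);
    double im = (alpha.real * b11->imag) + (alpha.imag * b11->real);

    /* subtract the GEMM accumulator (the _mm256_sub_pd(ymm15, ymm8) step) */
    b11->real = re - acc.real;
    b11->imag = im - acc.imag;
}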
+/** + * Multiplies Alpha with two dcomplex + * elements of one column and store it into + * buffer b11. + */ +#define BLIS_PRE_ZTRSM_SMALL_2M_1N(AlphaVal,b11,cs_b) {\ + ymm16 = _mm256_broadcast_pd(( __m128d const*)(&AlphaVal));\ + \ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0));\ + ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ + /*dcomplex multiplication and substraction*/\ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm0, ymm16);\ + ymm14 = _mm256_mul_pd(ymm0, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm8 = _mm256_sub_pd(ymm15,ymm8);\ + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm8);\ +} + +/** + * Multiplies Alpha with two elements of + * two columns and store the result in buffer b11 + * + */ +#define BLIS_PRE_ZTRSM_SMALL_2M_2N(AlphaVal,b11,cs_b){\ + ymm16 = _mm256_broadcast_pd(( __m128d const*)(&AlphaVal));\ + \ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));\ + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1));\ + ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ + /*dcomplex multiplication and substraction*/\ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm0, ymm16);\ + ymm14 = _mm256_mul_pd(ymm0, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm8 = _mm256_sub_pd(ymm15,ymm8);\ + \ + /*dcomplex multiplication and substraction*/\ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm1, ymm16);\ + ymm14 = _mm256_mul_pd(ymm1, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm9 = _mm256_sub_pd(ymm15,ymm9);\ + \ + _mm256_storeu_pd((double *)(b11), ymm8);\ + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm9);\ +} + +/** + * Performs GEMM operation. + * Two elements of column in ymm0 + * ymm1, ymm2 holds respective broadcasted element. 
+ */ +#define BLIS_ZTRSM_SMALL_GEMM_2mx3n(a10,b01,cs_b,p_lda,k_iter){\ + double *tptr = (double *)b01;\ + if(conjtransa) {\ + ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_pd((double const *)(a10));\ + ymm0 = _mm256_mul_pd(ymm0, ymm18);\ + \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0));\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0 + 1));\ + \ + ymm8 = _mm256_fmadd_pd(ymm0, ymm1, ymm8);\ + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4);\ + \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1));\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1 + 1));\ + \ + ymm9 = _mm256_fmadd_pd(ymm0, ymm1, ymm9);\ + ymm5 = _mm256_fmadd_pd(ymm0, ymm2, ymm5);\ + \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 2));\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 2 + 1));\ + \ + ymm10 = _mm256_fmadd_pd(ymm0, ymm1, ymm10);\ + ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6);\ + \ + tptr += 2; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_pd((double const *)(a10));\ + \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0));\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0 + 1));\ + \ + ymm8 = _mm256_fmadd_pd(ymm0, ymm1, ymm8);\ + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4);\ + \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1));\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1 + 1));\ + \ + ymm9 = _mm256_fmadd_pd(ymm0, ymm1, ymm9);\ + ymm5 = _mm256_fmadd_pd(ymm0, ymm2, ymm5);\ + \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 2));\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 2 + 1));\ + \ + ymm10 = _mm256_fmadd_pd(ymm0, ymm1, ymm10);\ + ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6);\ + \ + tptr += 2; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + }\ + }\ + ymm4 = _mm256_permute_pd(ymm4, 0x5);\ + ymm5 = _mm256_permute_pd(ymm5, 0x5);\ + ymm6 = _mm256_permute_pd(ymm6, 0x5);\ + ymm8 = _mm256_addsub_pd(ymm8, ymm4);\ + ymm9 = _mm256_addsub_pd(ymm9, ymm5);\ + ymm10 = _mm256_addsub_pd(ymm10, ymm6);\ +} + +/** + * Performs GEMM operation. + * Four elements of column in ymm0, ymm1. + * ymm2, ymm7 holds respective broadcasted element. 
+ */ +#define BLIS_ZTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) {\ + double *tptr = (double *)a01;\ + if(conjtransa) {\ + ymm18 = _mm256_set_pd(-1.0, -1.0, -1.0, -1.0);\ + for(k = 0; k < k_iter; k++)\ + {\ + ymm0 = _mm256_loadu_pd((double const *)b10);\ + ymm1 = _mm256_loadu_pd((double const *)(b10 + 2));\ + \ + _mm_prefetch((char*)( b10 + 4*cs_b), _MM_HINT_T0); \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0));\ + ymm7 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1));\ + ymm7 = _mm256_mul_pd(ymm7, ymm18);\ + /*dcomplex multiplication and substraction*/\ + \ + ymm3 = _mm256_fmadd_pd(ymm0, ymm2, ymm3);\ + ymm4 = _mm256_fmadd_pd(ymm1, ymm2, ymm4);\ + ymm5 = _mm256_fmadd_pd(ymm0, ymm7, ymm5);\ + ymm6 = _mm256_fmadd_pd(ymm1, ymm7, ymm6);\ + /*dcomplex multiplication and substraction*/\ + \ + tptr += 2;\ + b10 += cs_b;\ + }\ + }\ + else {\ + for(k = 0; k < k_iter; k++)\ + {\ + ymm0 = _mm256_loadu_pd((double const *)b10);\ + ymm1 = _mm256_loadu_pd((double const *)(b10 + 2));\ + \ + _mm_prefetch((char*)( b10 + 4*cs_b), _MM_HINT_T0); \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0));\ + ymm7 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1));\ + /*dcomplex multiplication and substraction*/\ + \ + ymm3 = _mm256_fmadd_pd(ymm0, ymm2, ymm3);\ + ymm4 = _mm256_fmadd_pd(ymm1, ymm2, ymm4);\ + ymm5 = _mm256_fmadd_pd(ymm0, ymm7, ymm5);\ + ymm6 = _mm256_fmadd_pd(ymm1, ymm7, ymm6);\ + /*ymm3 = _mm256_add_pd(ymm15, ymm3);*/\ + /*dcomplex multiplication and substraction*/\ + \ + tptr += 2;\ + b10 += cs_b;\ + }\ + }\ + ymm5 = _mm256_permute_pd(ymm5, 0x5);\ + ymm6 = _mm256_permute_pd(ymm6, 0x5);\ +\ + ymm3 = _mm256_addsub_pd(ymm3, ymm5);\ + ymm4 = _mm256_addsub_pd(ymm4, ymm6);\ +} + +/** + * Multiplies Alpha with 4 elements of column + */ +#define BLIS_PRE_ZTRSM_SMALL_1x4(b11,cs_b,AlphaVal) {\ + ymm16 = _mm256_broadcast_pd((__m128d const *)&AlphaVal);\ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));\ + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 2));\ +\ + ymm14 = _mm256_shuffle_pd(ymm16, ymm16, 5);\ +\ + ymm17 = _mm256_shuffle_pd(ymm0, ymm0, 15);\ + ymm18 = _mm256_shuffle_pd(ymm0, ymm0,0);\ + ymm19 = _mm256_mul_pd(ymm17, ymm14);\ + ymm15 = _mm256_fmaddsub_pd(ymm18, ymm16, ymm19);\ + ymm3 = _mm256_sub_pd(ymm15, ymm3);\ +\ + ymm17 = _mm256_shuffle_pd(ymm1, ymm1, 15);\ + ymm18 = _mm256_shuffle_pd(ymm1, ymm1,0);\ + ymm19 = _mm256_mul_pd(ymm17, ymm14);\ + ymm15 = _mm256_fmaddsub_pd(ymm18, ymm16, ymm19);\ + ymm4 = _mm256_sub_pd(ymm15, ymm4);\ +} + +/** + * Multiplies Alpha with 3 elements of column. + * ymm0 holds first 2 element and xmm5 holds the + * 3rd one. 
+ */ +#define BLIS_PRE_ZTRSM_SMALL_1x3(b11,cs_b,AlphaVal) {\ + ymm16 = _mm256_broadcast_pd((__m128d const *)&AlphaVal);\ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));\ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 0 + 2));\ + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0);\ +\ + ymm14 = _mm256_shuffle_pd(ymm16, ymm16, 5);\ +\ + ymm17 = _mm256_shuffle_pd(ymm0, ymm0, 15);\ + ymm18 = _mm256_shuffle_pd(ymm0, ymm0,0);\ + ymm19 = _mm256_mul_pd(ymm17, ymm14);\ + ymm15 = _mm256_fmaddsub_pd(ymm18, ymm16, ymm19);\ + ymm3 = _mm256_sub_pd(ymm15, ymm3);\ +\ + ymm17 = _mm256_shuffle_pd(ymm1, ymm1, 15);\ + ymm18 = _mm256_shuffle_pd(ymm1, ymm1,0);\ + ymm19 = _mm256_mul_pd(ymm17, ymm14);\ + ymm15 = _mm256_fmaddsub_pd(ymm18, ymm16, ymm19);\ + ymm4 = _mm256_sub_pd(ymm15, ymm4);\ +} + +#define BLIS_ZTRSM_SMALL_NREG_TRANSPOSE_2x4(b11,cs_b,AlphaVal) {\ + ymm16 = _mm256_broadcast_pd((__m128d const *)&AlphaVal);\ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));\ + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1));\ + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 2));\ + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 2));\ + ymm14 = _mm256_shuffle_pd(ymm16, ymm16, 5);\ +\ + ymm17 = _mm256_shuffle_pd(ymm0, ymm0, 15);\ + ymm18 = _mm256_shuffle_pd(ymm0, ymm0,0);\ + ymm19 = _mm256_mul_pd(ymm17, ymm14);\ + ymm15 = _mm256_fmaddsub_pd(ymm18, ymm16, ymm19);\ + ymm0 = _mm256_sub_pd(ymm15, ymm8);\ +\ + ymm17 = _mm256_shuffle_pd(ymm1, ymm1, 15);\ + ymm18 = _mm256_shuffle_pd(ymm1, ymm1,0);\ + ymm19 = _mm256_mul_pd(ymm17, ymm14);\ + ymm15 = _mm256_fmaddsub_pd(ymm18, ymm16, ymm19);\ + ymm1 = _mm256_sub_pd(ymm15, ymm9);\ +\ + ymm17 = _mm256_shuffle_pd(ymm4, ymm4, 15);\ + ymm18 = _mm256_shuffle_pd(ymm4, ymm4,0);\ + ymm19 = _mm256_mul_pd(ymm17, ymm14);\ + ymm15 = _mm256_fmaddsub_pd(ymm18, ymm16, ymm19);\ + ymm4 = _mm256_sub_pd(ymm15, ymm12);\ +\ + ymm17 = _mm256_shuffle_pd(ymm5, ymm5, 15);\ + ymm18 = _mm256_shuffle_pd(ymm5, ymm5,0);\ + ymm19 = _mm256_mul_pd(ymm17, ymm14);\ + ymm15 = _mm256_fmaddsub_pd(ymm18, ymm16, ymm19);\ + ymm5 = _mm256_sub_pd(ymm15, ymm13);\ +} + +#define BLIS_PRE_ZTRSM_SMALL_3M_1N(AlphaVal,b11,cs_b){\ + ymm16 = _mm256_broadcast_pd(( __m128d const *)(&AlphaVal));\ + \ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0));\ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 0 + 2));\ + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0);\ + \ + ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ + /*dcomplex multiplication and substraction*/\ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm0, ymm16);\ + ymm14 = _mm256_mul_pd(ymm0, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm8 = _mm256_sub_pd(ymm15,ymm8);\ + \ + /*dcomplex multiplication and substraction*/\ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm1, ymm16);\ + ymm14 = _mm256_mul_pd(ymm1, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm12 = _mm256_sub_pd(ymm15,ymm12);\ + \ + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm8);\ + xmm5 = _mm256_extractf128_pd(ymm12, 0);\ + _mm_storeu_pd((double *)(b11 + cs_b * 0 + 2), xmm5);\ +} + +/** + * Multiplies Alpha with 3 elements of 2 columns + * and store into buffer b11. + * ymm0 ymm1 holds first 2 elements of 2 columns. + * xmm4 xmm5 holds the 3rd elements of 2 columns. 
+ */ +#define BLIS_PRE_ZTRSM_SMALL_3M_2N(AlphaVal,b11,cs_b){\ + ymm16 = _mm256_broadcast_pd(( __m128d const*)(&AlphaVal));\ + \ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));\ + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1));\ + xmm4 = _mm_loadu_pd((double const *)(b11 + cs_b * 0 + 2));\ + ymm3 = _mm256_insertf128_pd(ymm3, xmm4, 0);\ +\ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 1 + 2));\ + ymm4 = _mm256_insertf128_pd(ymm4, xmm5, 0);\ +\ + ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ + /*dcomplex multiplication and substraction*/\ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm0, ymm16);\ + ymm14 = _mm256_mul_pd(ymm0, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm8 = _mm256_sub_pd(ymm15,ymm8);\ + \ + /*dcomplex multiplication and substraction*/\ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm1, ymm16);\ + ymm14 = _mm256_mul_pd(ymm1, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm9 = _mm256_sub_pd(ymm15,ymm9);\ + \ + /*dcomplex multiplication and substraction*/\ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm3, ymm16);\ + ymm14 = _mm256_mul_pd(ymm3, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm12 = _mm256_sub_pd(ymm15,ymm12);\ + \ + /*dcomplex multiplication and substraction*/\ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm4, ymm16);\ + ymm14 = _mm256_mul_pd(ymm4, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm13 = _mm256_sub_pd(ymm15,ymm13);\ + \ + _mm256_storeu_pd((double *)(b11), ymm8);\ + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm9);\ + xmm4 = _mm256_extractf128_pd(ymm12, 0);\ + _mm_storeu_pd((double *)(b11 + cs_b * 0 + 2), xmm4);\ + xmm5 = _mm256_extractf128_pd(ymm13, 0);\ + _mm_storeu_pd((double *)(b11 + cs_b * 1 + 2), xmm5);\ +} + +/** + * Performs GEMM operation + * ymm0 holds 2 elements of column. 
+ * ymm4 ymm6 holds broadcasted elements respectively + */ +#define BLIS_ZTRSM_SMALL_GEMM_3nx2m(a01,b10,cs_b,p_lda,k_iter) {\ + double *tptr = (double *)a01;\ + if(conjtransa) {\ + ymm18 = _mm256_set_pd(-1.0, -1.0, -1.0, -1.0);\ + for(k = 0; k< k_iter; k++) \ + {\ + ymm0 = _mm256_loadu_pd((double const *)(b10)); \ + \ + _mm_prefetch((char*)( b10 + 2*cs_b), _MM_HINT_T0); \ + ymm4 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0)); \ + ymm6 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1)); \ + ymm6 = _mm256_mul_pd(ymm6, ymm18);\ + /*dcomplex multiplication and substraction*/\ + \ + ymm3 = _mm256_fmadd_pd(ymm0, ymm4, ymm3);\ + ymm8 = _mm256_fmadd_pd(ymm0, ymm6, ymm8);\ + \ + ymm4 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1)); \ + ymm6 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1 + 1)); \ + ymm6 = _mm256_mul_pd(ymm6, ymm18);\ + \ + /*dcomplex multiplication and substraction*/\ + \ + ymm5 = _mm256_fmadd_pd(ymm0, ymm4, ymm5);\ + ymm9 = _mm256_fmadd_pd(ymm0, ymm6, ymm9);\ + \ + ymm4 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 2)); \ + ymm6 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 2 + 1)); \ + ymm6 = _mm256_mul_pd(ymm6, ymm18);\ + \ + /*dcomplex multiplication and substraction*/\ + \ + ymm7 = _mm256_fmadd_pd(ymm0, ymm4, ymm7);\ + ymm10 = _mm256_fmadd_pd(ymm0, ymm6, ymm10);\ + \ + tptr += 2; \ + b10 += cs_b; \ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) \ + {\ + ymm0 = _mm256_loadu_pd((double const *)(b10)); \ + \ + _mm_prefetch((char*)( b10 + 2*cs_b), _MM_HINT_T0); \ + ymm4 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0)); \ + ymm6 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1)); \ + /*dcomplex multiplication and substraction*/\ + \ + ymm3 = _mm256_fmadd_pd(ymm0, ymm4, ymm3);\ + ymm8 = _mm256_fmadd_pd(ymm0, ymm6, ymm8);\ + /*ymm3 = _mm256_add_pd(ymm15, ymm3);*/\ + \ + ymm4 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1)); \ + ymm6 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1 + 1)); \ + \ + /*dcomplex multiplication and substraction*/\ + \ + ymm5 = _mm256_fmadd_pd(ymm0, ymm4, ymm5);\ + ymm9 = _mm256_fmadd_pd(ymm0, ymm6, ymm9);\ + /*ymm5 = _mm256_add_pd(ymm15, ymm5);*/\ + \ + ymm4 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 2)); \ + ymm6 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 2 + 1)); \ + \ + /*dcomplex multiplication and substraction*/\ + \ + ymm7 = _mm256_fmadd_pd(ymm0, ymm4, ymm7);\ + ymm10 = _mm256_fmadd_pd(ymm0, ymm6, ymm10);\ + /*ymm7 = _mm256_add_pd(ymm15, ymm7);*/\ + \ + tptr += 2; \ + b10 += cs_b; \ + }\ + }\ + ymm8 = _mm256_permute_pd(ymm8, 0x5);\ + ymm9 = _mm256_permute_pd(ymm9, 0x5);\ + ymm10 = _mm256_permute_pd(ymm10, 0x5);\ + ymm3 = _mm256_addsub_pd(ymm3, ymm8);\ + ymm5 = _mm256_addsub_pd(ymm5, ymm9);\ + ymm7 = _mm256_addsub_pd(ymm7, ymm10);\ +} + +/** + * Multiplies Alpha with 2 elements of 3 columns + * ymm0 holds 2 elements of columns, once computation + * is done, it holds 2 elements of next columns after + * saving computed result into some other register. + * ymm3 ymm5 ymm7. 
+ */ +#define BLIS_PRE_ZTRSM_SMALL_3x2(AlphaVal,b11,cs_b) {\ + ymm16 = _mm256_broadcast_pd(( __m128d const*)(&AlphaVal));\ + \ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));\ + ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ + \ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm0, ymm16);\ + ymm14 = _mm256_mul_pd(ymm0, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm3 = _mm256_sub_pd(ymm15,ymm3);\ + \ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *1));\ +\ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm0, ymm16);\ + ymm14 = _mm256_mul_pd(ymm0, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm5 = _mm256_sub_pd(ymm15,ymm5);\ + \ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *2));\ + \ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm0, ymm16);\ + ymm14 = _mm256_mul_pd(ymm0, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm7 = _mm256_sub_pd(ymm15,ymm7);\ + \ +} + +/** + * Performs GEMM + * ymm0 and ymm1 together holds 4 elements of column. + */ +#define BLIS_ZTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) {\ + double *tptr = (double *)a01;\ + if(conjtransa) {\ + ymm18 = _mm256_set_pd(-1.0, -1.0, -1.0, -1.0);\ + for(k = 0; k< k_iter; k++) \ + { \ + ymm0 = _mm256_loadu_pd((double const *)(b10)); \ + ymm1 = _mm256_loadu_pd((double const *)(b10 + 2)); \ + \ + _mm_prefetch((char*)( b10 + 4*cs_b), _MM_HINT_T0); \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0)); \ + ymm12 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1)); \ + ymm12 = _mm256_mul_pd(ymm12, ymm18);\ + \ + ymm3 = _mm256_fmadd_pd(ymm0, ymm2, ymm3);\ + ymm4 = _mm256_fmadd_pd(ymm1, ymm2, ymm4);\ + ymm8 = _mm256_fmadd_pd(ymm0, ymm12, ymm8);\ + ymm9 = _mm256_fmadd_pd(ymm1, ymm12, ymm9);\ + \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1)); \ + ymm12 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1 + 1)); \ + ymm12 = _mm256_mul_pd(ymm12, ymm18);\ + \ + ymm5 = _mm256_fmadd_pd(ymm0, ymm2, ymm5);\ + ymm6 = _mm256_fmadd_pd(ymm1, ymm2, ymm6);\ + ymm10 = _mm256_fmadd_pd(ymm0, ymm12, ymm10);\ + ymm11 = _mm256_fmadd_pd(ymm1, ymm12, ymm11);\ + \ + tptr += 2; \ + b10 += cs_b; \ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) \ + { \ + ymm0 = _mm256_loadu_pd((double const *)(b10)); \ + ymm1 = _mm256_loadu_pd((double const *)(b10 + 2)); \ + \ + _mm_prefetch((char*)( b10 + 4*cs_b), _MM_HINT_T0); \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0)); \ + ymm12 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1)); \ + \ + ymm3 = _mm256_fmadd_pd(ymm0, ymm2, ymm3);\ + ymm4 = _mm256_fmadd_pd(ymm1, ymm2, ymm4);\ + ymm8 = _mm256_fmadd_pd(ymm0, ymm12, ymm8);\ + ymm9 = _mm256_fmadd_pd(ymm1, ymm12, ymm9);\ + \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1)); \ + ymm12 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1 + 1)); \ + \ + ymm5 = _mm256_fmadd_pd(ymm0, ymm2, ymm5);\ + ymm6 = _mm256_fmadd_pd(ymm1, ymm2, ymm6);\ + ymm10 = _mm256_fmadd_pd(ymm0, ymm12, ymm10);\ + ymm11 = _mm256_fmadd_pd(ymm1, ymm12, ymm11);\ + \ + tptr += 2; \ + b10 += cs_b; \ + }\ + }\ + ymm8 = _mm256_permute_pd(ymm8, 0x5);\ + ymm9 = _mm256_permute_pd(ymm9, 0x5);\ + ymm10 = _mm256_permute_pd(ymm10, 0x5);\ + ymm11 = _mm256_permute_pd(ymm11, 0x5);\ + ymm3 = _mm256_addsub_pd(ymm3, ymm8);\ + ymm4 = _mm256_addsub_pd(ymm4, ymm9);\ + ymm5 = 
_mm256_addsub_pd(ymm5, ymm10);\ + ymm6 = _mm256_addsub_pd(ymm6, ymm11);\ +} + +/** + * Performs GEMM operation + * ymm0 holds 2 elements of a column. + */ +#define BLIS_ZTRSM_SMALL_GEMM_2nx2m(a01,b10,cs_b,p_lda,k_iter){\ + double *tptr = (double *)a01;\ + if(conjtransa) {\ + ymm18 = _mm256_set_pd(-1.0, -1.0, -1.0, -1.0);\ + for(k = 0; k< k_iter; k++) \ + { \ + ymm0 = _mm256_loadu_pd((double const *)(b10)); \ + \ + _mm_prefetch((char*)( b10 + 2*cs_b), _MM_HINT_T0); \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0)); \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1)); \ + ymm2 = _mm256_mul_pd(ymm2, ymm18);\ + \ + ymm3 = _mm256_fmadd_pd(ymm0, ymm1, ymm3);\ + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4);\ + \ + \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1)); \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1 + 1)); \ + ymm2 = _mm256_mul_pd(ymm2, ymm18);\ + \ + ymm5 = _mm256_fmadd_pd(ymm0, ymm1, ymm5);\ + ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6);\ + \ + tptr += 2; \ + b10 += cs_b; \ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) \ + { \ + ymm0 = _mm256_loadu_pd((double const *)(b10)); \ + \ + _mm_prefetch((char*)( b10 + 2*cs_b), _MM_HINT_T0); \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0)); \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1)); \ + \ + ymm3 = _mm256_fmadd_pd(ymm0, ymm1, ymm3);\ + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4);\ + \ + \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1)); \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1 + 1)); \ + \ + ymm5 = _mm256_fmadd_pd(ymm0, ymm1, ymm5);\ + ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6);\ + \ + tptr += 2; \ + b10 += cs_b; \ + }\ + }\ + ymm4 = _mm256_permute_pd(ymm4, 0x5);\ + ymm6 = _mm256_permute_pd(ymm6, 0x5);\ + ymm3 = _mm256_addsub_pd(ymm3, ymm4);\ + ymm5 = _mm256_addsub_pd(ymm5, ymm6);\ +} + +/** + * Multiplies Alpha with 2 elements of a column. + * ymm0 holds the 2 element of a column. + */ +#define BLIS_PRE_ZTRSM_SMALL_1x1(AlphaVal,b11,cs_b){\ + ymm16 = _mm256_broadcast_pd(( __m128d const*)(&AlphaVal));\ + \ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 0));\ + ymm0 = _mm256_insertf128_pd(ymm1, xmm5, 0);\ + ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ + \ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm0, ymm16);\ + ymm14 = _mm256_mul_pd(ymm0, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm3 = _mm256_sub_pd(ymm15,ymm3);\ +} + +/** + * Multiplies Alpha with 2 elements of a column. + * ymm0 holds the 2 element of a column. + */ +#define BLIS_PRE_ZTRSM_SMALL_1x2(AlphaVal,b11,cs_b){\ + ymm16 = _mm256_broadcast_pd(( __m128d const*)(&AlphaVal));\ + \ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));\ + ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ + \ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm0, ymm16);\ + ymm14 = _mm256_mul_pd(ymm0, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm3 = _mm256_sub_pd(ymm15,ymm3);\ +} + +/** + * Multiplies Alpha with 2 elements of 2 columns. 
+ * ymm0 holds 2 elements of a columns respectively, + * once computation is done, gets stored in registers + * ymm3, ymm5 + */ +#define BLIS_PRE_ZTRSM_SMALL_2x2(AlphaVal,b11,cs_b){\ + ymm16 = _mm256_broadcast_pd(( __m128d const*)(&AlphaVal));\ + \ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));\ + ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ + \ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm0, ymm16);\ + ymm14 = _mm256_mul_pd(ymm0, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm3 = _mm256_sub_pd(ymm15,ymm3);\ + \ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *1));\ +\ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm0, ymm16);\ + ymm14 = _mm256_mul_pd(ymm0, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm5 = _mm256_sub_pd(ymm15,ymm5);\ +} + +/** + * Performs GEMM operation + * 3 elements of a columns get held by ymm0(2 element) + * and xmm5 (1 element). + */ +#define BLIS_ZTRSM_SMALL_GEMM_1nx3m(a01,b10,cs_b,p_lda,k_iter) {\ + double *tptr = (double *)a01;\ + if(conjtransa) {\ + ymm18 = _mm256_set_pd(-1.0, -1.0, -1.0, -1.0);\ + for(k = 0; k< k_iter; k++) \ + {\ + ymm0 = _mm256_loadu_pd((double const *)(b10)); \ + /*ymm1 = _mm256_loadu_pd((double const *)(b10 + 2));*/\ + xmm5 = _mm_loadu_pd((double const *)(b10 + 2));\ + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0);\ + \ + _mm_prefetch((char*)( b10 + 4*cs_b), _MM_HINT_T0); \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0)); \ + ymm5 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1)); \ + ymm5 = _mm256_mul_pd(ymm5, ymm18);\ + \ + ymm3 = _mm256_fmadd_pd(ymm0, ymm2, ymm3);\ + ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6);\ + ymm4 = _mm256_fmadd_pd(ymm1, ymm5, ymm4);\ + ymm7 = _mm256_fmadd_pd(ymm1, ymm5, ymm7);\ + \ + tptr += 2;\ + b10 += cs_b;\ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) \ + {\ + ymm0 = _mm256_loadu_pd((double const *)(b10)); \ + /*ymm1 = _mm256_loadu_pd((double const *)(b10 + 2));*/\ + xmm5 = _mm_loadu_pd((double const *)(b10 + 2));\ + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0);\ + \ + _mm_prefetch((char*)( b10 + 4*cs_b), _MM_HINT_T0); \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0)); \ + ymm5 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1)); \ + \ + ymm3 = _mm256_fmadd_pd(ymm0, ymm2, ymm3);\ + ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6);\ + ymm4 = _mm256_fmadd_pd(ymm1, ymm5, ymm4);\ + ymm7 = _mm256_fmadd_pd(ymm1, ymm5, ymm7);\ + \ + tptr += 2;\ + b10 += cs_b;\ + }\ + }\ + ymm6 = _mm256_permute_pd(ymm6, 0x5);\ + ymm7 = _mm256_permute_pd(ymm7, 0x5);\ + ymm3 = _mm256_addsub_pd(ymm3, ymm6);\ + ymm4 = _mm256_addsub_pd(ymm5, ymm7);\ +} + + +/** + * Performs GEMM operation. + * 1 elements of a column are kept in ymm0. 
+ */ +#define BLIS_ZTRSM_SMALL_GEMM_1nx1m(a01,b10,cs_b,p_lda,k_iter) {\ + double *tptr = (double *)a01;\ + if(conjtransa) {\ + ymm18 = _mm256_set_pd(-1.0, -1.0, -1.0, -1.0);\ + for(k = 0; k< k_iter; k++) \ + { \ + xmm5 = _mm_loadu_pd((double const *)(b10));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ + \ + _mm_prefetch((char*)( b10 + 2*cs_b), _MM_HINT_T0); \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0)); \ + ymm5 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1)); \ + ymm5 = _mm256_mul_pd(ymm5, ymm18);\ + \ + ymm3 = _mm256_fmadd_pd(ymm0, ymm2, ymm3);\ + ymm4 = _mm256_fmadd_pd(ymm0, ymm5, ymm4);\ + \ + tptr += 2; \ + b10 += cs_b; \ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) \ + { \ + xmm5 = _mm_loadu_pd((double const *)(b10));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ + \ + _mm_prefetch((char*)( b10 + 2*cs_b), _MM_HINT_T0); \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0)); \ + ymm5 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1)); \ + \ + ymm3 = _mm256_fmadd_pd(ymm0, ymm2, ymm3);\ + ymm4 = _mm256_fmadd_pd(ymm0, ymm5, ymm4);\ + \ + tptr += 2; \ + b10 += cs_b; \ + }\ + }\ + ymm4 = _mm256_permute_pd(ymm4, 0x5);\ + ymm3 = _mm256_addsub_pd(ymm3, ymm4);\ +} + + +/** + * Performs GEMM operation. + * 2 elements of a column are kept in ymm0. + */ +#define BLIS_ZTRSM_SMALL_GEMM_1nx2m(a01,b10,cs_b,p_lda,k_iter) {\ + double *tptr = (double *)a01;\ + if(conjtransa) {\ + ymm18 = _mm256_set_pd(-1.0, -1.0, -1.0, -1.0);\ + for(k = 0; k< k_iter; k++) \ + { \ + ymm0 = _mm256_loadu_pd((double const *)(b10)); \ + \ + _mm_prefetch((char*)( b10 + 2*cs_b), _MM_HINT_T0); \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0)); \ + ymm5 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1)); \ + ymm5 = _mm256_mul_pd(ymm5, ymm18);\ + \ + ymm3 = _mm256_fmadd_pd(ymm0, ymm2, ymm3);\ + ymm4 = _mm256_fmadd_pd(ymm0, ymm5, ymm4);\ + \ + tptr += 2; \ + b10 += cs_b; \ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) \ + { \ + ymm0 = _mm256_loadu_pd((double const *)(b10)); \ + \ + _mm_prefetch((char*)( b10 + 2*cs_b), _MM_HINT_T0); \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0)); \ + ymm5 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1)); \ + \ + ymm3 = _mm256_fmadd_pd(ymm0, ymm2, ymm3);\ + ymm4 = _mm256_fmadd_pd(ymm0, ymm5, ymm4);\ + \ + tptr += 2; \ + b10 += cs_b; \ + }\ + }\ + ymm4 = _mm256_permute_pd(ymm4, 0x5);\ + ymm3 = _mm256_addsub_pd(ymm3, ymm4);\ +} + +/** + * Performs GEMM operation + * 4 elements of columns are kept in ymm0 and ymm1. 
+ */ +#define BLIS_ZTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) {\ + double *tptr = (double *)a01;\ + if(conjtransa) {\ + ymm18 = _mm256_set_pd(-1.0, -1.0, -1.0, -1.0);\ + for(k = 0; k< k_iter; k++) \ + { \ + ymm0 = _mm256_loadu_pd((double const *)(b10)); \ + ymm1 = _mm256_loadu_pd((double const *)(b10 + 2)); \ + \ + _mm_prefetch((char*)( b10 + 4*cs_b), _MM_HINT_T0); \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0));\ + ymm9 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1));\ + ymm9 = _mm256_mul_pd(ymm9, ymm18);\ + \ + ymm3 = _mm256_fmadd_pd(ymm0, ymm2, ymm3);\ + ymm4 = _mm256_fmadd_pd(ymm1, ymm2, ymm4);\ + ymm10 = _mm256_fmadd_pd(ymm0, ymm9, ymm10);\ + ymm11 = _mm256_fmadd_pd(ymm1, ymm9, ymm11);\ + \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1)); \ + ymm9 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1 + 1)); \ + ymm9 = _mm256_mul_pd(ymm9, ymm18);\ + \ + ymm5 = _mm256_fmadd_pd(ymm0, ymm2, ymm5);\ + ymm6 = _mm256_fmadd_pd(ymm1, ymm2, ymm6);\ + ymm12 = _mm256_fmadd_pd(ymm0, ymm9, ymm12);\ + ymm13 = _mm256_fmadd_pd(ymm1, ymm9, ymm13);\ + \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 2)); \ + ymm9 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 2 + 1)); \ + ymm9 = _mm256_mul_pd(ymm9, ymm18);\ + \ + ymm7 = _mm256_fmadd_pd(ymm0, ymm2, ymm7);\ + ymm8 = _mm256_fmadd_pd(ymm1, ymm2, ymm8);\ + ymm14 = _mm256_fmadd_pd(ymm0, ymm9, ymm14);\ + ymm15 = _mm256_fmadd_pd(ymm1, ymm9, ymm15);\ + \ + tptr += 2; \ + b10 += cs_b; \ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) \ + { \ + ymm0 = _mm256_loadu_pd((double const *)(b10)); \ + ymm1 = _mm256_loadu_pd((double const *)(b10 + 2)); \ + \ + _mm_prefetch((char*)( b10 + 4*cs_b), _MM_HINT_T0); \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0));\ + ymm9 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1));\ + \ + ymm3 = _mm256_fmadd_pd(ymm0, ymm2, ymm3);\ + ymm4 = _mm256_fmadd_pd(ymm1, ymm2, ymm4);\ + ymm10 = _mm256_fmadd_pd(ymm0, ymm9, ymm10);\ + ymm11 = _mm256_fmadd_pd(ymm1, ymm9, ymm11);\ + \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1)); \ + ymm9 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1 + 1)); \ + \ + ymm5 = _mm256_fmadd_pd(ymm0, ymm2, ymm5);\ + ymm6 = _mm256_fmadd_pd(ymm1, ymm2, ymm6);\ + ymm12 = _mm256_fmadd_pd(ymm0, ymm9, ymm12);\ + ymm13 = _mm256_fmadd_pd(ymm1, ymm9, ymm13);\ + \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 2)); \ + ymm9 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 2 + 1)); \ + \ + ymm7 = _mm256_fmadd_pd(ymm0, ymm2, ymm7);\ + ymm8 = _mm256_fmadd_pd(ymm1, ymm2, ymm8);\ + ymm14 = _mm256_fmadd_pd(ymm0, ymm9, ymm14);\ + ymm15 = _mm256_fmadd_pd(ymm1, ymm9, ymm15);\ + \ + tptr += 2; \ + b10 += cs_b; \ + }\ + }\ + ymm10 = _mm256_permute_pd(ymm10, 0x5);\ + ymm11 = _mm256_permute_pd(ymm11, 0x5);\ + ymm12 = _mm256_permute_pd(ymm12, 0x5);\ + ymm13 = _mm256_permute_pd(ymm13, 0x5);\ + ymm14 = _mm256_permute_pd(ymm14, 0x5);\ + ymm15 = _mm256_permute_pd(ymm15, 0x5);\ +\ + ymm3 = _mm256_addsub_pd(ymm3, ymm10);\ + ymm4 = _mm256_addsub_pd(ymm4, ymm11);\ + ymm5 = _mm256_addsub_pd(ymm5, ymm12);\ + ymm6 = _mm256_addsub_pd(ymm6, ymm13);\ + ymm7 = _mm256_addsub_pd(ymm7, ymm14);\ + ymm8 = _mm256_addsub_pd(ymm8, ymm15);\ +} + +/** + * Multiplies Alpha with 4 element of 2 columns. + * ymm0 and ymm1 holds 4 elements of a column. 
+ */ +#define BLIS_PRE_ZTRSM_SMALL_2x4(AlphaVal,b11,cs_b) {\ + ymm16 = _mm256_broadcast_pd(( __m128d const*)(&AlphaVal));\ + \ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));\ + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 2));\ + ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ + \ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm0, ymm16);\ + ymm14 = _mm256_mul_pd(ymm0, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm3 = _mm256_sub_pd(ymm15,ymm3);\ + \ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm1, ymm16);\ + ymm14 = _mm256_mul_pd(ymm1, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm4 = _mm256_sub_pd(ymm15,ymm4);\ + \ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *1));\ + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 2));\ +\ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm0, ymm16);\ + ymm14 = _mm256_mul_pd(ymm0, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm5 = _mm256_sub_pd(ymm15,ymm5);\ + \ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm1, ymm16);\ + ymm14 = _mm256_mul_pd(ymm1, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm6 = _mm256_sub_pd(ymm15,ymm6);\ +} + +/** + * Multiplies Alpha with 4 element of 3 columns. + * ymm0 and ymm1 holds 4 elements of a column. + */ +#define BLIS_PRE_ZTRSM_SMALL_3x4(AlphaVal,b11,cs_b) {\ + ymm16 = _mm256_broadcast_pd(( __m128d const*)(&AlphaVal));\ + \ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));\ + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 2));\ + ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ + \ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm0, ymm16);\ + ymm14 = _mm256_mul_pd(ymm0, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm3 = _mm256_sub_pd(ymm15,ymm3);\ + \ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm1, ymm16);\ + ymm14 = _mm256_mul_pd(ymm1, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm4 = _mm256_sub_pd(ymm15,ymm4);\ + \ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *1));\ + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 2));\ +\ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm0, ymm16);\ + ymm14 = _mm256_mul_pd(ymm0, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm5 = _mm256_sub_pd(ymm15,ymm5);\ + \ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm1, ymm16);\ + ymm14 = _mm256_mul_pd(ymm1, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm6 = _mm256_sub_pd(ymm15,ymm6);\ + \ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *2));\ + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 2));\ + \ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm0, ymm16);\ + ymm14 = _mm256_mul_pd(ymm0, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm7 = _mm256_sub_pd(ymm15,ymm7);\ + \ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm1, ymm16);\ + ymm14 = _mm256_mul_pd(ymm1, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm8 = _mm256_sub_pd(ymm15,ymm8);\ + \ +} + +/* + * Pack a block of 4xk or 3xk 
from input buffer into packed buffer + * directly or after transpose based on input params + */ + +/* + * Load b11 of size 3x4 and multiply with alpha + * Add the GEMM output and perform inregister transose of b11 + * to peform ZTRSM operation for left cases. + */ +#define BLIS_ZTRSM_SMALL_NREG_TRANSPOSE_3x4(b11,cs_b,AlphaVal) {\ + ymm16 = _mm256_broadcast_pd(( __m128d const *)(&AlphaVal));\ +\ + ymm0 = _mm256_loadu_pd((double const *)(b11));\ + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1));\ + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2));\ + ymm3 = _mm256_broadcast_pd((__m128d const *)&ones);\ + /*in register transpose + * ymm0,ymm1,ymm2 holds + * two dcomplex elements of b11 cols*/\ + ymm14 = _mm256_shuffle_pd(ymm16, ymm16, 5);\ + ymm5 = _mm256_shuffle_pd(ymm0, ymm0, 15);\ + ymm6 = _mm256_shuffle_pd(ymm0, ymm0,0);\ + ymm7 = _mm256_mul_pd(ymm5, ymm14);\ + ymm15 = _mm256_fmaddsub_pd(ymm6, ymm16, ymm7);\ + ymm0 = _mm256_sub_pd(ymm15, ymm8);\ +\ + ymm5 = _mm256_shuffle_pd(ymm1, ymm1, 15);\ + ymm6 = _mm256_shuffle_pd(ymm1, ymm1,0);\ + ymm7 = _mm256_mul_pd(ymm5, ymm14);\ + ymm15 = _mm256_fmaddsub_pd(ymm6, ymm16, ymm7);\ + ymm1 = _mm256_sub_pd(ymm15, ymm9);\ +\ + ymm5 = _mm256_shuffle_pd(ymm2, ymm2, 15);\ + ymm6 = _mm256_shuffle_pd(ymm2, ymm2,0);\ + ymm7 = _mm256_mul_pd(ymm5, ymm14);\ + ymm15 = _mm256_fmaddsub_pd(ymm6, ymm16, ymm7);\ + ymm2 = _mm256_sub_pd(ymm15, ymm10);\ +\ + /*in register transpose of computed b11 col*/\ + ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); \ + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31);\ + ymm4 = _mm256_permute2f128_pd(ymm2,ymm3,0x20); \ + ymm5 = _mm256_permute2f128_pd(ymm2,ymm3,0x31); \ +\ + /*in register transpose + * ymm0,ymm1,ymm2 holds + * next two dcomplex elements of b11 cols*/\ + ymm0 = _mm256_loadu_pd((double const *)(b11 + 2));\ + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 2));\ + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 2));\ +\ + ymm17 = _mm256_shuffle_pd(ymm0, ymm0, 15);\ + ymm18 = _mm256_shuffle_pd(ymm0, ymm0, 0);\ + ymm19 = _mm256_mul_pd(ymm17, ymm14);\ + ymm15 = _mm256_fmaddsub_pd(ymm18, ymm16, ymm19);\ + ymm0 = _mm256_sub_pd(ymm15, ymm11);\ +\ + ymm17 = _mm256_shuffle_pd(ymm1, ymm1, 15);\ + ymm18 = _mm256_shuffle_pd(ymm1, ymm1, 0);\ + ymm19 = _mm256_mul_pd(ymm17, ymm14);\ + ymm15 = _mm256_fmaddsub_pd(ymm18, ymm16, ymm19);\ + ymm1 = _mm256_sub_pd(ymm15, ymm12);\ +\ + ymm17 = _mm256_shuffle_pd(ymm2, ymm2, 15);\ + ymm18 = _mm256_shuffle_pd(ymm2, ymm2, 0);\ + ymm19 = _mm256_mul_pd(ymm17, ymm14);\ + ymm15 = _mm256_fmaddsub_pd(ymm18, ymm16, ymm19);\ + ymm2 = _mm256_sub_pd(ymm15, ymm13);\ +\ + /*in register transpose of computed b11 col*/\ + ymm10 = _mm256_permute2f128_pd(ymm0,ymm1,0x20);\ + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31);\ + ymm6 = _mm256_permute2f128_pd(ymm2,ymm3,0x20);\ + ymm7 = _mm256_permute2f128_pd(ymm2,ymm3,0x31);\ +} + +/** + * Performs GEMM operation. 
+ * 4 elements of a column are kept inymm0 and ymm1 + */ +#define BLIS_ZTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b,p_lda,k_iter) {\ + double *tptr = (double *)b01;\ + if(conjtransa) {\ + ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_pd((double const *)(a10));\ + ymm1 = _mm256_loadu_pd((double const *)(a10 + 2));\ + ymm0 = _mm256_mul_pd(ymm0, ymm18);\ + ymm1 = _mm256_mul_pd(ymm1, ymm18);\ + \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0));\ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0 + 1)); \ + \ + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8);\ + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12);\ + \ + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);\ + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);\ + tptr += 2; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_pd((double const *)(a10));\ + ymm1 = _mm256_loadu_pd((double const *)(a10 + 2));\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0));\ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0 + 1)); \ + \ + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8);\ + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12);\ + \ + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);\ + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);\ + tptr += 2; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + }\ + }\ + ymm4 = _mm256_permute_pd(ymm4, 0x5);\ + ymm5 = _mm256_permute_pd(ymm5, 0x5);\ + ymm8 = _mm256_addsub_pd(ymm8, ymm4);\ + ymm12 = _mm256_addsub_pd(ymm12, ymm5);\ +} + +/** + * Performs the GEMM operation. + * 2 elements of a column are kept in ymm0. 
+ */ +#define BLIS_ZTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b,p_lda,k_iter) {\ + double *tptr = (double * )b01;\ + if(conjtransa) {\ + ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_pd((double const *)(a10));\ + ymm1 = _mm256_loadu_pd((double const *)(a10 + 2));\ + ymm0 = _mm256_mul_pd(ymm0, ymm18);\ + ymm1 = _mm256_mul_pd(ymm1, ymm18);\ + \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0));\ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0 + 1)); \ + \ + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8);\ + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12);\ + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);\ + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1)); \ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1 + 1)); \ + \ + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9);\ + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13);\ + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6);\ + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);\ + tptr += 2; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_pd((double const *)(a10));\ + ymm1 = _mm256_loadu_pd((double const *)(a10 + 2));\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0));\ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0 + 1)); \ + \ + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8);\ + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12);\ + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);\ + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1)); \ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1 + 1)); \ + \ + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9);\ + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13);\ + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6);\ + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);\ + tptr += 2; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + }\ + }\ + ymm4 = _mm256_permute_pd(ymm4, 0x5);\ + ymm5 = _mm256_permute_pd(ymm5, 0x5);\ + ymm6 = _mm256_permute_pd(ymm6, 0x5);\ + ymm7 = _mm256_permute_pd(ymm7, 0x5);\ +\ + ymm8 = _mm256_addsub_pd(ymm8, ymm4);\ + ymm12 = _mm256_addsub_pd(ymm12, ymm5);\ + ymm9 = _mm256_addsub_pd(ymm9, ymm6);\ + ymm13 = _mm256_addsub_pd(ymm13, ymm7);\ +} + +/*GEMM block used in ztrsm small left cases*/ +#define BLIS_ZTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) {\ + double *tptr = (double *)b01;\ + if(conjtransa) {\ + ymm16 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ + for(k = 0; k< k_iter; k++) \ + { \ + ymm0 = _mm256_loadu_pd((double const *)(a10)); \ + ymm1 = _mm256_loadu_pd((double const *)(a10 + 2)); \ + ymm0 = _mm256_mul_pd(ymm0, ymm16);\ + ymm1 = _mm256_mul_pd(ymm1, ymm16);\ + \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr)); \ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + 1)); \ + \ + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8);\ + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11);\ + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);\ + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);\ + \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 1 * 2)); \ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 1 * 2 + 1)); \ + \ + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9);\ + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12);\ + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6);\ + 
ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);\ + \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b *2 * 2)); \ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 2 + 1)); \ + \ + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10);\ + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13);\ + \ + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14);\ + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15);\ + \ + tptr += 2; \ + a10 += p_lda; \ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) \ + { \ + ymm0 = _mm256_loadu_pd((double const *)(a10)); \ + ymm1 = _mm256_loadu_pd((double const *)(a10 + 2)); \ + \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr)); \ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + 1)); \ + \ + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8);\ + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11);\ + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);\ + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);\ + \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 1 * 2)); \ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 1 * 2 + 1)); \ + \ + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9);\ + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12);\ + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6);\ + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);\ + \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b *2 * 2)); \ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 2 + 1)); \ + \ + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10);\ + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13);\ + \ + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14);\ + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15);\ + \ + tptr += 2; \ + a10 += p_lda; \ + }\ + }\ + ymm4 = _mm256_permute_pd(ymm4, 0x5);\ + ymm5 = _mm256_permute_pd(ymm5, 0x5);\ + ymm6 = _mm256_permute_pd(ymm6, 0x5);\ + ymm7 = _mm256_permute_pd(ymm7, 0x5);\ + ymm14 = _mm256_permute_pd(ymm14, 0x5);\ + ymm15 = _mm256_permute_pd(ymm15, 0x5);\ + \ + ymm8 = _mm256_addsub_pd(ymm8, ymm4);\ + ymm11 = _mm256_addsub_pd(ymm11, ymm5);\ + ymm9 = _mm256_addsub_pd(ymm9, ymm6);\ + ymm12 = _mm256_addsub_pd(ymm12, ymm7);\ + ymm10 = _mm256_addsub_pd(ymm10, ymm14);\ + ymm13 = _mm256_addsub_pd(ymm13, ymm15);\ +} + + +#define BLIS_ZTRSM_SMALL_NREG_TRANSPOSE_4x3_AND_STORE(b11,cs_b){\ + ymm0 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20);\ + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31);\ + ymm2 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20);\ + _mm256_storeu_pd((double *)(b11), ymm0);\ + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1);\ + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2);\ +\ + ymm0 = _mm256_permute2f128_pd(ymm10, ymm11, 0x20);\ + ymm1 = _mm256_permute2f128_pd(ymm10, ymm11, 0x31);\ + ymm2 = _mm256_permute2f128_pd(ymm6, ymm7, 0x20);\ + _mm256_storeu_pd((double *)(b11 + 2), ymm0);\ + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 2), ymm1);\ + _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 2), ymm2);\ +} + +/** + * Performs dcomplex division of vec1 and vec2 with ymm1. + * vec1 and vec2 gets divided by ymm1 which holds + * diagonal element from buffer. + * Function gets called while performing TRSM. 
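+ *
+ * The division is done by multiplying the numerator with the conjugate
+ * of the divisor and then dividing by |divisor|^2. A scalar sketch of
+ * the same computation (illustrative only; xr/xi and dr/di are the real
+ * and imaginary parts of a numerator element and of the diagonal
+ * element held in ymm1):
+ *
+ *   double denom = dr*dr + di*di;           // |d|^2
+ *   double qr    = (xr*dr + xi*di) / denom; // real part of x/d
+ *   double qi    = (xi*dr - xr*di) / denom; // imaginary part of x/d
+ *
+ * In the vector code ymm0 is expected to hold the sign mask
+ * {1.0, -1.0, 1.0, -1.0}; multiplying by it negates the imaginary
+ * lanes, and _mm256_hsub_pd folds the partial products into the real
+ * and imaginary parts of the quotient.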
+ */
+#define BLIS_ZTRSM_TWO_DIV(vec1, vec2) {\
+ if(!is_unitdiag) {\
+ if(conjtransa){\
+ ymm1 = _mm256_mul_pd(ymm1, ymm0);\
+ }\
+ ymm12 = _mm256_mul_pd(ymm1, ymm0);\
+ /*perform dcomplex multiplication*/\
+ /* Switch the real and imaginary elements of vec2 */\
+ ymm14 = _mm256_permute_pd(ymm12, 0x5);\
+ /* Negate the imaginary elements of vec2 */\
+ ymm14 = _mm256_mul_pd(ymm14, ymm0);\
+ /* Multiply vec1 and vec2 */ \
+ ymm13 = _mm256_mul_pd(vec1, ymm12); /*vec3*/\
+ /* Multiply vec1 and the modified vec2 */\
+ ymm14 = _mm256_mul_pd(vec1, ymm14); /*vec4*/\
+ /* Horizontally subtract the elements in vec3 and vec4 */\
+ vec1 = _mm256_hsub_pd(ymm13, ymm14);\
+ \
+ ymm14 = _mm256_permute_pd(ymm12, 0x5);\
+ /* Negate the imaginary elements of vec2 */\
+ ymm14 = _mm256_mul_pd(ymm14, ymm0);\
+ ymm13 = _mm256_mul_pd(vec2, ymm12);\
+ ymm14 = _mm256_mul_pd(vec2, ymm14);\
+ vec2 = _mm256_hsub_pd(ymm13, ymm14);\
+ /*dcomplex multiplication is done*/\
+ /*Swapping real & imaginary component position for addition with respective
+ * components*/\
+ ymm12 = _mm256_mul_pd(ymm1, ymm1);\
+ ymm13 = _mm256_permute4x64_pd(ymm12, 0xb1);\
+ ymm14 = _mm256_add_pd(ymm12, ymm13);\
+ \
+ /*Finally dividing numerator by denominator*/\
+ vec1 = _mm256_div_pd(vec1, ymm14);\
+ vec2 = _mm256_div_pd(vec2, ymm14);\
+ }\
+}
+
+/**
+ * Performs dcomplex division of vec1 with ymm1.
+ * ymm1 holds diagonal element from buffer.
+ * Function gets called while performing TRSM.
+ */
+#define BLIS_ZTRSM_DIV(vec1) {\
+ if(!is_unitdiag){\
+ if(conjtransa){\
+ ymm1 = _mm256_mul_pd(ymm1, ymm0);\
+ }\
+ ymm12 = _mm256_mul_pd(ymm1, ymm0); /*ymm12 is vec2 and ymm8 is vec1*/\
+ ymm14 = _mm256_permute_pd(ymm12, 0x5);\
+ ymm14 = _mm256_mul_pd(ymm14, ymm0);\
+ ymm13 = _mm256_mul_pd(vec1, ymm12); /*vec3*/\
+ ymm14 = _mm256_mul_pd(vec1, ymm14); /*vec4*/\
+ vec1 = _mm256_hsub_pd(ymm13, ymm14);\
+ \
+ ymm12 = _mm256_mul_pd(ymm1, ymm1);\
+ ymm13 = _mm256_permute4x64_pd(ymm12, 0xb1);\
+ ymm14 = _mm256_add_pd(ymm12, ymm13);\
+ \
+ /*Finally dividing numerator by denominator*/\
+ vec1 = _mm256_div_pd(vec1, ymm14);\
+ }\
+}
+
+/**
+ * Performs dcomplex multiplication of vec1 with ymm1.
+ * ymm1 holds diagonal element from buffer.
+ * Function gets called while performing TRSM.
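+ *
+ * This is the BLIS_ENABLE_TRSM_PREINVERSION counterpart of
+ * BLIS_ZTRSM_DIV: ztrsm_small_pack_diag_element() is expected to have
+ * already stored conj(d)/|d|^2 for each diagonal element d, so the
+ * solve step reduces to one complex multiplication. A scalar sketch
+ * (illustrative only; wr/wi denote the pre-inverted diagonal held in
+ * ymm1, xr/xi an element of vec1):
+ *
+ *   double pr = xr*wr - xi*wi; // real part of x*w
+ *   double pi = xr*wi + xi*wr; // imaginary part of x*w
+ *
+ * This trades the per-element divide for a multiply, at the cost of a
+ * slightly different rounding than an explicit division.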
+ */ +#define BLIS_ZTRSM_MUL(vec1) {\ + if(!is_unitdiag){\ + if(conjtransa){\ + ymm19 = _mm256_mul_pd(ymm1, ymm0);\ + }\ + else{\ + ymm19 = ymm1;\ + }\ + ymm14 = _mm256_permute_pd(ymm19, 0x5);\ + /* Negate the imaginary elements of vec2 */\ + ymm14 = _mm256_mul_pd(ymm14, ymm0);\ + /* Multiply vec1 and vec2 */\ + ymm13 = _mm256_mul_pd(vec1, ymm19); /*vec3*/\ + /* Multiply vec1 and the modified vec2 */\ + ymm14 = _mm256_mul_pd(vec1, ymm14); /*vec4*/\ + /* Horizontally subtract the elements in vec3 and vec4 */\ + vec1 = _mm256_hsub_pd(ymm13, ymm14);\ + }\ +} + +BLIS_INLINE void bli_ztrsm_small_pack +( + char side, + dim_t size, + bool trans, + dcomplex *inbuf, + dim_t cs_a, + dcomplex *pbuff, + dim_t p_lda, + dim_t mr +) +{ + //scratch registers + __m256d ymm0, ymm1, ymm2; + __m256d ymm5, ymm6, ymm7; + __m256d ymm8, ymm9, ymm10, ymm11; + __m128d xmm0,xmm1,xmm2; + double zero = 0.0; + + if(side=='L'||side=='l') + { + /*Left case is 4xk*/ + if(trans) + { + /* + ------------- ------------- + | | | | | + | 2x4 | | | | + ------------- ==> | 4x2 | 4x2 | + | 2x4 | | | | + | | | | | + ------------- ------------- + */ + for(dim_t x = 0; x < size; x += mr) + { + ymm0 = _mm256_loadu_pd((double const *)(inbuf)); + ymm10 = _mm256_loadu_pd((double const *)(inbuf + 2)); + ymm1 = _mm256_loadu_pd((double const *)(inbuf + cs_a)); + ymm11 = _mm256_loadu_pd((double const *)(inbuf + 2 + cs_a)); + + ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + ymm8 = _mm256_permute2f128_pd(ymm10,ymm11,0x20); + ymm9 = _mm256_permute2f128_pd(ymm10,ymm11,0x31); + + _mm256_storeu_pd((double *)(pbuff), ymm6); + _mm256_storeu_pd((double *)(pbuff + p_lda), ymm7); + _mm256_storeu_pd((double *)(pbuff + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(pbuff + p_lda*3), ymm9); + + ymm0 = _mm256_loadu_pd((double const *)(inbuf + 2 * cs_a)); + ymm10 = _mm256_loadu_pd((double const *)(inbuf + 2 * cs_a + 2)); + ymm1 = _mm256_loadu_pd((double const *)(inbuf + 3 * cs_a)); + ymm11 = _mm256_loadu_pd((double const *)(inbuf + 3 * cs_a + 2)); + + ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + ymm8 = _mm256_permute2f128_pd(ymm10,ymm11,0x20); + ymm9 = _mm256_permute2f128_pd(ymm10,ymm11,0x31); + + _mm256_storeu_pd((double *)(pbuff + 2), ymm6); + _mm256_storeu_pd((double *)(pbuff + p_lda + 2), ymm7); + _mm256_storeu_pd((double *)(pbuff + p_lda*2 + 2), ymm8); + _mm256_storeu_pd((double *)(pbuff + p_lda*3 + 2), ymm9); + + inbuf += mr; + pbuff += mr*mr; + } + }else + { + //Expected multiples of 4 + p_lda = 4; + for(dim_t x = 0; x < size; x++) + { + ymm0 = _mm256_loadu_pd((double const *)(inbuf)); + _mm256_storeu_pd((double *)(pbuff), ymm0); + ymm1 = _mm256_loadu_pd((double const *)(inbuf + 2)); + _mm256_storeu_pd((double *)(pbuff + 2), ymm1); + inbuf+=cs_a; + pbuff+=p_lda; + } + } + }else if(side=='R'||side=='r') + { + + if(trans) + { + for(dim_t x=0; x>1); i++) + { + ymm0 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 0 )); + _mm256_storeu_pd((double *)(pbuff + p_lda * 0), ymm0); + ymm1 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 1 )); + _mm256_storeu_pd((double *)(pbuff + p_lda * 1), ymm1); + ymm2 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 2)); + _mm256_storeu_pd((double *)(pbuff + p_lda * 2), ymm2); + inbuf += 2; + pbuff += 2; + } + if(size & 0x1) + { + xmm0 = _mm_loadu_pd((double const *)(inbuf + cs_a * 0)); + _mm_storeu_pd((double *)(pbuff + p_lda * 0 ), xmm0); + xmm1 = _mm_loadu_pd((double const *)(inbuf + cs_a * 1)); + 
_mm_storeu_pd((double *)(pbuff + p_lda * 1), xmm1); + xmm2 = _mm_loadu_pd((double const *)(inbuf + cs_a * 2)); + _mm_storeu_pd((double *)(pbuff + p_lda * 2), xmm2); + } + } + } + +} + + +BLIS_INLINE void ztrsm_small_pack_diag_element +( + bool is_unitdiag, + dcomplex *a11, + dim_t cs_a, + dcomplex *d11_pack, + dim_t size +) +{ + __m256d ymm1, ymm2, ymm3, ymm4, ymm5, ymm6, ymm7, ymm8; + bool is_four = (size == 4) ? 1 : 0; + dcomplex ones = {1.0, 1.0}; + ymm2 = ymm1 = _mm256_broadcast_pd((__m128d const *)&ones); + ymm7 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + if(!is_unitdiag) + { + //broadcast diagonal elements of A11 + ymm1 = _mm256_broadcast_pd((__m128d const *)a11); + ymm2 = _mm256_broadcast_pd((__m128d const *)a11+ cs_a +1); + /*Pick one element frome each column and create 3 element vector + and store it*/ + ymm1 = _mm256_permute2f128_pd(ymm1, ymm2, 0x20); + ymm2 = _mm256_broadcast_pd((__m128d const *)a11+ cs_a*2 + 2); + + if(is_four) + { + ymm3 = _mm256_broadcast_pd((__m128d const *)a11+ cs_a*2 + 2); + ymm2 = _mm256_broadcast_pd((__m128d const *)a11+ cs_a*3 + 3); + ymm2 = _mm256_permute2f128_pd(ymm3, ymm2, 0x20); + } + +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + /*Taking denomerator multiplication of real & imaginary components*/ + ymm4 = _mm256_mul_pd(ymm1, ymm1); + ymm5 = _mm256_mul_pd(ymm2,ymm2); + /*Swapping real & imaginary component position for addition with + * respective components*/ + ymm6 = _mm256_permute4x64_pd(ymm4, 0xb1); + ymm4 = _mm256_add_pd(ymm4, ymm6); + ymm8 = _mm256_permute4x64_pd(ymm5, 0xb1); + + ymm5 = _mm256_add_pd(ymm5, ymm8); + /*Negating imaginary component of numerator*/ + ymm1 = _mm256_mul_pd(ymm1, ymm7); + ymm2 = _mm256_mul_pd(ymm2, ymm7); + /*Dividing numerator by denominator*/ + ymm1 = _mm256_div_pd(ymm1, ymm4); + ymm2 = _mm256_div_pd(ymm2, ymm5); +#endif + + } + _mm256_store_pd((double *)d11_pack, ymm1); + if(is_four) + { + _mm256_store_pd((double *)(d11_pack + 2), ymm2); + } + else + { + _mm_store_pd((double *)(d11_pack + 2), + _mm256_extractf128_pd(ymm2,0)); + + } +} + +/*implements TRSM for the case XA = alpha * B + *A is lower triangular, non-unit diagonal/unit diagonal, transpose + *dimensions: X:mxn A:nxn B: mxn + * + * b11---> a01 ----> + ***************** *********** + *b01*b11* * * * * * * +b11 * * * * * **a01 * * a11 + | ***************** ********* | + | * * * * * *a11* * | + | * * * * * * * * | + v ***************** ****** v + * * * * * * * + * * * * * * * + ***************** * * + * + *implements TRSM for the case XA = alpha * B + *A is upper triangular, non-unit diagonal/unit diagonal, no transpose + *dimensions: X:mxn A:nxn B: mxn + * + * b11---> a01 ----> + ***************** *********** + *b01*b11* * * * * * * +b11 * * * * * **a01 * * a11 + | ***************** ********* | + | * * * * * *a11* * | + | * * * * * * * * | + v ***************** ****** v + * * * * * * * + * * * * * * * + ***************** * * + * + +*/ + +BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +) +{ + dim_t m = bli_obj_length(b); //number of rows + dim_t n = bli_obj_width(b); //number of columns + dim_t d_mr = 8,d_nr = 6; + + bool transa = bli_obj_has_trans(a); + dim_t cs_a, rs_a; + + // Swap rs_a & cs_a in case of non-tranpose. 
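+ // The TRSM loops below address A as a11 + p*cs_a + q*rs_a. With the
+ // strides exchanged for the non-transpose case, the same expression
+ // picks A(p,q) when A is not transposed and A(q,p) (i.e. A'(p,q)) when
+ // it is, so a single code path serves both the XAuB and XAltB variants.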
+ if(transa) + { + cs_a = bli_obj_col_stride(a); // column stride of A + rs_a = bli_obj_row_stride(a); // row stride of A + } + else + { + cs_a = bli_obj_row_stride(a); // row stride of A + rs_a = bli_obj_col_stride(a); // column stride of A + } + dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B + + dim_t i, j, k; //loop variablse + dim_t k_iter; //determines the number of GEMM operations to be done + + double ones = 1.0; + double zero = 0.0; + bool is_unitdiag = bli_obj_has_unit_diag(a); + + double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha + double* restrict L = a->buffer; //pointer to matrix A + double* restrict B = b->buffer; //pointer to matrix B + + double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks + + gint_t required_packing_A = 1; + mem_t local_mem_buf_A_s = {0}; + double *D_A_pack = NULL; + double d11_pack[d_mr] __attribute__((aligned(64))); + rntm_t rntm; + + bli_rntm_init_from_global( &rntm ); + bli_rntm_set_num_threads_only( 1, &rntm ); + bli_membrk_rntm_set_membrk( &rntm ); + + siz_t buffer_size = bli_pool_block_size( + bli_membrk_pool( + bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), + bli_rntm_membrk(&rntm))); + + if( (d_nr * n * sizeof(double)) > buffer_size) + return BLIS_NOT_YET_IMPLEMENTED; + + if (required_packing_A == 1) + { + // Get the buffer from the pool. + bli_membrk_acquire_m(&rntm, + buffer_size, + BLIS_BITVAL_BUFFER_FOR_A_BLOCK, + &local_mem_buf_A_s); + if(FALSE==bli_mem_is_alloc(&local_mem_buf_A_s)) return BLIS_NULL_POINTER; + D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); + if(NULL==D_A_pack) return BLIS_NULL_POINTER; + } + + //ymm scratch reginsters + __m256d ymm0, ymm1, ymm2, ymm3; + __m256d ymm4, ymm5, ymm6, ymm7; + __m256d ymm8, ymm9, ymm10, ymm11; + __m256d ymm12, ymm13, ymm14, ymm15; + + __m128d xmm5; + + /* + Performs solving TRSM for 6 rows at a time from 0 to n/6 in steps of d_nr + a. Load and pack A (a01 block), the size of packing 6x6 to 6x (n-6) + First there will be no GEMM and no packing of a01 because it is only TRSM + b. Using packed a01 block and b10 block perform GEMM operation + c. Use GEMM outputs, perform TRSM operation using a11, b11 and update B + d. Repeat b for m cols of B in steps of d_mr + */ + + for(j = 0; (j+d_nr-1) < n; j += d_nr) //loop along 'N' direction + { + a01 = L + j*rs_a; //pointer to block of A to be used in GEMM + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + + //double *ptr_a10_dup = D_A_pack; + + dim_t p_lda = j; // packed leading dimension + // perform copy of A to packed buffer D_A_pack + + if(transa) + { + /* + Pack current A block (a01) into packed buffer memory D_A_pack + a. This a10 block is used in GEMM portion only and this + a01 block size will be increasing by d_nr for every next iteration + until it reaches 6x(n-6) which is the maximum GEMM alone block size in A + b. This packed buffer is reused to calculate all m cols of B matrix + */ + bli_dtrsm_small_pack('R', j, 1, a01, cs_a, D_A_pack, p_lda,d_nr); + + /* + Pack 6 diagonal elements of A block into an array + a. This helps in utilze cache line efficiently in TRSM operation + b. store ones when input is unit diagonal + */ + + dtrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,d_nr); + } + else + { + bli_dtrsm_small_pack('R', j, 0, a01, rs_a, D_A_pack, p_lda,d_nr); + dtrsm_small_pack_diag_element(is_unitdiag,a11,rs_a,d11_pack,d_nr); + } + + /* + a. Perform GEMM using a01, b10. + b. Perform TRSM on a11, b11 + c. 
This loop GEMM+TRSM loops operates with 8x6 block size + along m dimension for every d_mr columns of B10 where + packed A buffer is reused in computing all m cols of B. + d. Same approach is used in remaining fringe cases. + */ + for(i = 0; (i+d_mr-1) < m; i += d_mr) //loop along 'M' direction + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + /* + Peform GEMM between a01 and b10 blocks + For first itteration there will be no GEMM operation + where k_iter are zero + */ + BLIS_DTRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) + + /* + Load b11 of size 8x6 and multiply with alpha + Add the GEMM output to b11 + and peform TRSM operation. + */ + + BLIS_PRE_DTRSM_SMALL_6x8(AlphaVal,b11,cs_b) + + ///implement TRSM/// + + /* + Compute 6x8 TRSM block by using GEMM block output in register + a. The 6x8 input (gemm outputs) are stored in combinations of ymm registers + 1. ymm3, ymm4 2. ymm5, ymm6 3. ymm7, ymm8, 4. ymm9, ymm10 + 5. ymm11, ymm12 6. ymm13,ymm14 + b. Towards the end TRSM output will be stored back into b11 + */ + + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); + + ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); + ymm6 = _mm256_fnmadd_pd(ymm1, ymm4, ymm6); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + + ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); + ymm8 = _mm256_fnmadd_pd(ymm1, ymm4, ymm8); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + + ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm4, ymm10); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); + + ymm11 = _mm256_fnmadd_pd(ymm1, ymm3, ymm11); + ymm12 = _mm256_fnmadd_pd(ymm1, ymm4, ymm12); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); + + ymm13 = _mm256_fnmadd_pd(ymm1, ymm3, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm1, ymm4, ymm14); + + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm0); + + a11 += cs_a; + + //extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + + ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); + ymm8 = _mm256_fnmadd_pd(ymm1, ymm6, ymm8); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + + ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm6, ymm10); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); + + ymm11 = _mm256_fnmadd_pd(ymm1, ymm5, ymm11); + ymm12 = _mm256_fnmadd_pd(ymm1, ymm6, ymm12); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); + + ymm13 = _mm256_fnmadd_pd(ymm1, ymm5, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm1, ymm6, ymm14); + + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm0); + + a11 += cs_a; + + //extract a33 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_sd((double 
const *)(a11 + 3*rs_a)); + + ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm8, ymm10); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); + + ymm11 = _mm256_fnmadd_pd(ymm1, ymm7, ymm11); + ymm12 = _mm256_fnmadd_pd(ymm1, ymm8, ymm12); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); + + ymm13 = _mm256_fnmadd_pd(ymm1, ymm7, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm1, ymm8, ymm14); + + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm0); + + a11 += cs_a; + + //extract a44 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + + //(row 4):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); + + ymm11 = _mm256_fnmadd_pd(ymm1, ymm9, ymm11); + ymm12 = _mm256_fnmadd_pd(ymm1, ymm10, ymm12); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); + + ymm13 = _mm256_fnmadd_pd(ymm1, ymm9, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm1, ymm10, ymm14); + + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); + ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm0); + + a11 += cs_a; + + //extract a55 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + + //(Row 5): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); + + ymm13 = _mm256_fnmadd_pd(ymm1, ymm11, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm1, ymm12, ymm14); + + ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); + ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm0); + + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + 4), ymm4); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm256_storeu_pd((double *)(b11 + cs_b + 4), ymm6); + _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); + _mm256_storeu_pd((double *)(b11 + cs_b*2 + 4), ymm8); + _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); + _mm256_storeu_pd((double *)(b11 + cs_b*3 + 4), ymm10); + _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); + _mm256_storeu_pd((double *)(b11 + cs_b*4 + 4), ymm12); + _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); + _mm256_storeu_pd((double *)(b11 + cs_b*5 + 4), ymm14); + } + + dim_t m_remainder = m - i; + if(m_remainder >= 4) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) + + // Load b11 of size 4x6 and multiply with alpha + BLIS_PRE_DTRSM_SMALL_6x4(AlphaVal,b11,cs_b) + + ///implement TRSM/// + + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); + ymm11 = _mm256_fnmadd_pd(ymm1, ymm3, ymm11); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_pd(ymm1, ymm3, ymm13); + + ymm5 = 
DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + a11 += cs_a; + + //extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); + ymm11 = _mm256_fnmadd_pd(ymm1, ymm5, ymm11); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_pd(ymm1, ymm5, ymm13); + + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + a11 += cs_a; + + //extract a33 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); + ymm11 = _mm256_fnmadd_pd(ymm1, ymm7, ymm11); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_pd(ymm1, ymm7, ymm13); + + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + + a11 += cs_a; + + //extract a44 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + + //(row 4):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); + ymm11 = _mm256_fnmadd_pd(ymm1, ymm9, ymm11); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_pd(ymm1, ymm9, ymm13); + + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); + + a11 += cs_a; + + //extract a55 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + + //(Row 5): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_pd(ymm1, ymm11, ymm13); + + ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); + + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); + _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); + _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); + _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); + + m_remainder -= 4; + i += 4; + } + + if(m_remainder == 3) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) + + // Load b11 of size 4x6 and multiply with alpha + BLIS_PRE_DTRSM_SMALL_6x4(AlphaVal,b11,cs_b) + + ///implement TRSM/// + + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); + ymm11 = _mm256_fnmadd_pd(ymm1, ymm3, ymm11); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); + ymm13 = 
_mm256_fnmadd_pd(ymm1, ymm3, ymm13); + + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + a11 += cs_a; + + //extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); + ymm11 = _mm256_fnmadd_pd(ymm1, ymm5, ymm11); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_pd(ymm1, ymm5, ymm13); + + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + a11 += cs_a; + + //extract a33 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); + ymm11 = _mm256_fnmadd_pd(ymm1, ymm7, ymm11); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_pd(ymm1, ymm7, ymm13); + + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + + a11 += cs_a; + + //extract a44 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + + //(row 4):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); + ymm11 = _mm256_fnmadd_pd(ymm1, ymm9, ymm11); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_pd(ymm1, ymm9, ymm13); + + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); + + a11 += cs_a; + + //extract a55 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + + //(Row 5): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_pd(ymm1, ymm11, ymm13); + + ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); + + ymm0 = _mm256_loadu_pd((double const *)b11); + ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x07); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x07); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x07); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x07); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm11 = _mm256_blend_pd(ymm0, ymm11, 0x07); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm13 = _mm256_blend_pd(ymm0, ymm13, 0x07); + + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); + _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); + _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); + _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); + + m_remainder -= 3; + i += 3; + } + else if(m_remainder == 2) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) + + 
// Load b11 of size 4x6 and multiply with alpha + BLIS_PRE_DTRSM_SMALL_6x4(AlphaVal,b11,cs_b) + + ///implement TRSM/// + + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); + ymm11 = _mm256_fnmadd_pd(ymm1, ymm3, ymm11); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_pd(ymm1, ymm3, ymm13); + + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + a11 += cs_a; + + //extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); + ymm11 = _mm256_fnmadd_pd(ymm1, ymm5, ymm11); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_pd(ymm1, ymm5, ymm13); + + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + a11 += cs_a; + + //extract a33 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); + ymm11 = _mm256_fnmadd_pd(ymm1, ymm7, ymm11); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_pd(ymm1, ymm7, ymm13); + + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + + a11 += cs_a; + + //extract a44 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + + //(row 4):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); + ymm11 = _mm256_fnmadd_pd(ymm1, ymm9, ymm11); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_pd(ymm1, ymm9, ymm13); + + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); + + a11 += cs_a; + + //extract a55 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + + //(Row 5): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_pd(ymm1, ymm11, ymm13); + + ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); + + ymm0 = _mm256_loadu_pd((double const *)b11); + ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x03); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x03); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x03); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x03); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm11 = _mm256_blend_pd(ymm0, ymm11, 0x03); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm13 = _mm256_blend_pd(ymm0, 
ymm13, 0x03); + + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); + _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); + _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); + _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); + + m_remainder -= 2; + i += 2; + } + else if(m_remainder == 1) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) + + // Load b11 of size 4x6 and multiply with alpha + BLIS_PRE_DTRSM_SMALL_6x4(AlphaVal,b11,cs_b) + + ///implement TRSM/// + + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); + ymm11 = _mm256_fnmadd_pd(ymm1, ymm3, ymm11); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_pd(ymm1, ymm3, ymm13); + + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + a11 += cs_a; + + //extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); + ymm11 = _mm256_fnmadd_pd(ymm1, ymm5, ymm11); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_pd(ymm1, ymm5, ymm13); + + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + a11 += cs_a; + + //extract a33 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); + ymm11 = _mm256_fnmadd_pd(ymm1, ymm7, ymm11); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_pd(ymm1, ymm7, ymm13); + + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + + a11 += cs_a; + + //extract a44 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + + //(row 4):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); + ymm11 = _mm256_fnmadd_pd(ymm1, ymm9, ymm11); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_pd(ymm1, ymm9, ymm13); + + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); + + a11 += cs_a; + + //extract a55 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + + //(Row 5): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_pd(ymm1, ymm11, 
ymm13); + + ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); + + ymm0 = _mm256_loadu_pd((double const *)b11); + ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x01); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x01); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x01); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x01); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm11 = _mm256_blend_pd(ymm0, ymm11, 0x01); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm13 = _mm256_blend_pd(ymm0, ymm13, 0x01); + + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); + _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); + _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); + _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); + + m_remainder -= 1; + i += 1; + } + } + + dim_t n_remainder = n - j; + + /* + Reminder cases starts here: + a. Similar logic and code flow used in computing full block (6x8) + above holds for reminder cases too. + */ + + if(n_remainder >= 4) + { + a01 = L + j*rs_a; //pointer to block of A to be used in GEMM + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + + double *ptr_a10_dup = D_A_pack; + + dim_t p_lda = j; // packed leading dimension + // perform copy of A to packed buffer D_A_pack + + if(transa) + { + for(dim_t x =0;x < p_lda;x+=d_nr) + { + ymm0 = _mm256_loadu_pd((double const *)(a01)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); + ymm2 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2)); + ymm3 = _mm256_loadu_pd((double const *)(a01 + cs_a * 3)); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + + ymm0 = _mm256_loadu_pd((double const *)(a01 + cs_a * 4)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a * 5)); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_broadcast_sd((double const *)&zero); + + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_broadcast_sd((double const *)&zero); + + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_extractf128_pd(ymm6,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_extractf128_pd(ymm7,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm8,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm9,0)); + + a01 += d_nr*cs_a; + ptr_a10_dup += d_nr; + } + } + else + { + dim_t loop_count = p_lda/4; + + for(dim_t x =0;x < 
loop_count;x++) + { + ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 0 + x*4)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + x*4), ymm15); + ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 1 + x*4)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + x*4), ymm15); + ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 2 + x*4)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 2 + x*4), ymm15); + ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 3 + x*4)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 3 + x*4), ymm15); + } + + dim_t remainder_loop_count = p_lda - loop_count*4; + + __m128d xmm0; + if(remainder_loop_count != 0) + { + xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 0 + loop_count*4)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + loop_count*4), xmm0); + xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 1 + loop_count*4)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + loop_count*4), xmm0); + xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 2 + loop_count*4)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 2 + loop_count*4), xmm0); + xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 3 + loop_count*4)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 3 + loop_count*4), xmm0); + } + } + + ymm4 = _mm256_broadcast_sd((double const *)&ones); + if(!is_unitdiag) + { + if(transa) + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a*1 + 1)); + ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a*2 + 2)); + ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a*3 + 3)); + } + else + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11+ rs_a*1 + 1)); + ymm2 = _mm256_broadcast_sd((double const *)(a11+ rs_a*2 + 2)); + ymm3 = _mm256_broadcast_sd((double const *)(a11+ rs_a*3 + 3)); + } + + ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); + #ifdef BLIS_DISABLE_TRSM_PREINVERSION + ymm4 = ymm1; + #endif + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + ymm4 = _mm256_div_pd(ymm4, ymm1); + #endif + } + _mm256_storeu_pd((double *)(d11_pack), ymm4); + + for(i = 0; (i+d_mr-1) < m; i += d_mr) //loop along 'M' direction + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_4nx8m(a01,b10,cs_b,p_lda,k_iter) + + BLIS_PRE_DTRSM_SMALL_4x8(AlphaVal,b11,cs_b) + + ///implement TRSM/// + + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); + + ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); + ymm6 = _mm256_fnmadd_pd(ymm1, ymm4, ymm6); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + + ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); + ymm8 = _mm256_fnmadd_pd(ymm1, ymm4, ymm8); + + ymm1 = _mm256_broadcast_sd((double const 
*)(a11 + 3*rs_a)); + + ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm4, ymm10); + + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm0); + + a11 += cs_a; + + //extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + + ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); + ymm8 = _mm256_fnmadd_pd(ymm1, ymm6, ymm8); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + + ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm6, ymm10); + + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm0); + + a11 += cs_a; + + //extract a33 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + + ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm8, ymm10); + + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm0); + + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + 4), ymm4); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm256_storeu_pd((double *)(b11 + cs_b + 4), ymm6); + _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); + _mm256_storeu_pd((double *)(b11 + cs_b*2 + 4), ymm8); + _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); + _mm256_storeu_pd((double *)(b11 + cs_b*3 + 4), ymm10); + } + + dim_t m_remainder = m - i; + if(m_remainder >= 4) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_4nx4m(a01,b10,cs_b,p_lda,k_iter) + + ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 + + ///implement TRSM/// + + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); + + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + a11 += cs_a; + + //extract a22 + ymm0 = 
_mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); + + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + a11 += cs_a; + + //extract a33 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); + + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); + _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); + + m_remainder -= 4; + i += 4; + } + + if(m_remainder == 3) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_4nx4m(a01,b10,cs_b,p_lda,k_iter) + + ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 + + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3 + 2)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); + ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 + + ///implement TRSM/// + + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); + + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + a11 += cs_a; + + //extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); + + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + a11 += cs_a; + + //extract a33 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); + + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + + ymm0 = 
_mm256_loadu_pd((double const *)b11); + ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x07); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x07); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x07); + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3 + 2)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x07); + + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); + xmm5 = _mm256_extractf128_pd(ymm9, 0); + _mm_storeu_pd((double *)(b11 + cs_b * 3),xmm5); + _mm_storel_pd((b11 + cs_b * 3 + 2), _mm256_extractf128_pd(ymm9, 1)); + + m_remainder -= 3; + i += 3; + } + else if(m_remainder == 2) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_4nx4m(a01,b10,cs_b,p_lda,k_iter) + + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 + + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); + ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 + + ///implement TRSM/// + + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); + + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + a11 += cs_a; + + //extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); + + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + a11 += cs_a; + + //extract a33 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); + + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, 
ymm0); + + ymm0 = _mm256_loadu_pd((double const *)b11); + ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x03); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x03); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x03); + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); + ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x03); + + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); + xmm5 = _mm256_extractf128_pd(ymm9, 0); + _mm_storeu_pd((double *)(b11 + cs_b * 3),xmm5); + + m_remainder -= 2; + i += 2; + } + else if(m_remainder == 1) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_4nx4m(a01,b10,cs_b,p_lda,k_iter) + + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_broadcast_sd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 + + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 + + ///implement TRSM/// + + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); + + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + a11 += cs_a; + + //extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); + + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + a11 += cs_a; + + //extract a33 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); + + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + + ymm0 = _mm256_loadu_pd((double const *)b11); + ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x01); + ymm0 = _mm256_loadu_pd((double const *)(b11 + 
cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x01); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x01); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x01); + + _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm3, 0)); + _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm5, 0)); + _mm_storel_pd((b11 + cs_b * 2), _mm256_extractf128_pd(ymm7, 0)); + _mm_storel_pd((b11 + cs_b * 3), _mm256_extractf128_pd(ymm9, 0)); + + m_remainder -= 1; + i += 1; + } + j += 4; + n_remainder -= 4; + } + + if(n_remainder == 3) + { + a01 = L + j*rs_a; //pointer to block of A to be used in GEMM + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + + double *ptr_a10_dup = D_A_pack; + + dim_t p_lda = j; // packed leading dimension + // perform copy of A to packed buffer D_A_pack + + if(transa) + { + for(dim_t x =0;x < p_lda;x+=d_nr) + { + ymm0 = _mm256_loadu_pd((double const *)(a01)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); + ymm2 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2)); + ymm3 = _mm256_loadu_pd((double const *)(a01 + cs_a * 3)); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + + ymm0 = _mm256_loadu_pd((double const *)(a01 + cs_a * 4)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a * 5)); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_broadcast_sd((double const *)&zero); + + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_broadcast_sd((double const *)&zero); + + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_extractf128_pd(ymm6,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_extractf128_pd(ymm7,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm8,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm9,0)); + + a01 += d_nr*cs_a; + ptr_a10_dup += d_nr; + } + } + else + { + dim_t loop_count = p_lda/4; + + for(dim_t x =0;x < loop_count;x++) + { + ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 0 + x*4)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + x*4), ymm15); + ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 1 + x*4)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + x*4), ymm15); + ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 2 + x*4)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 2 + x*4), ymm15); + } + + dim_t remainder_loop_count = p_lda - loop_count*4; + + __m128d xmm0; + if(remainder_loop_count != 0) + { + xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 0 + loop_count*4)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + 
loop_count*4), xmm0); + xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 1 + loop_count*4)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + loop_count*4), xmm0); + xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 2 + loop_count*4)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 2 + loop_count*4), xmm0); + } + } + + ymm4 = _mm256_broadcast_sd((double const *)&ones); + if(!is_unitdiag) + { + if(transa) + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a*1 + 1)); + ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a*2 + 2)); + } + else + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11+ rs_a*1 + 1)); + ymm2 = _mm256_broadcast_sd((double const *)(a11+ rs_a*2 + 2)); + } + ymm3 = _mm256_broadcast_sd((double const *)&ones); + + ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); + #ifdef BLIS_DISABLE_TRSM_PREINVERSION + ymm4 = ymm1; + #endif + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + ymm4 = _mm256_div_pd(ymm4, ymm1); + #endif + } + _mm256_storeu_pd((double *)(d11_pack), ymm4); + + for(i = 0; (i+d_mr-1) < m; i += d_mr) //loop along 'M' direction + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_3nx8m(a01,b10,cs_b,p_lda,k_iter) + + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); + + ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + 4)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] + + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + ymm4 = _mm256_fmsub_pd(ymm1, ymm15, ymm4); //B11[4-7][0] * alpha-= ymm1 + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b + 4)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] + + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + ymm6 = _mm256_fmsub_pd(ymm1, ymm15, ymm6); //B11[4-7][1] * alpha -= ymm3 + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*2 + 4)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] + + ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 + ymm8 = _mm256_fmsub_pd(ymm1, ymm15, ymm8); //B11[4-7][2] * alpha -= ymm5 + + ///implement TRSM/// + + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); + + ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); + ymm6 = _mm256_fnmadd_pd(ymm1, ymm4, ymm6); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + + ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); + ymm8 = _mm256_fnmadd_pd(ymm1, ymm4, ymm8); + + ymm5 = 
DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm0); + + a11 += cs_a; + + //extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + + ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); + ymm8 = _mm256_fnmadd_pd(ymm1, ymm6, ymm8); + + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm0); + + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + 4), ymm4); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm256_storeu_pd((double *)(b11 + cs_b + 4), ymm6); + _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); + _mm256_storeu_pd((double *)(b11 + cs_b*2 + 4), ymm8); + } + + dim_t m_remainder = m - i; + if(m_remainder >= 4) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) + + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 + + ///implement TRSM/// + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); + + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + a11 += cs_a; + + //extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); + + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); + + m_remainder -= 4; + i += 4; + } + + if(m_remainder == 3) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) + + BLIS_PRE_DTRSM_SMALL_3N_3M(AlphaVal,b11,cs_b) + + ///implement TRSM/// + + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm3 = 
DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); + + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + a11 += cs_a; + + //extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); + + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + BLIS_POST_DTRSM_SMALL_3N_3M(b11,cs_b) + + m_remainder -= 3; + i += 3; + } + else if(m_remainder == 2) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) + + BLIS_PRE_DTRSM_SMALL_3N_2M(AlphaVal,b11,cs_b) + + ///implement TRSM/// + + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); + + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + a11 += cs_a; + + //extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); + + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + BLIS_POST_DTRSM_SMALL_3N_2M(b11,cs_b) + + m_remainder -= 2; + i += 2; + } + else if(m_remainder == 1) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) + + BLIS_PRE_DTRSM_SMALL_3N_1M(AlphaVal,b11,cs_b) + + ///implement TRSM/// + + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); + + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + a11 += cs_a; + + //extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); + + ymm7 = 
DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + BLIS_POST_DTRSM_SMALL_3N_1M(b11,cs_b) + + m_remainder -= 1; + i += 1; + } + j += 3; + n_remainder -= 3; + } + else if(n_remainder == 2) + { + a01 = L + j*rs_a; //pointer to block of A to be used in GEMM + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + + double *ptr_a10_dup = D_A_pack; + + dim_t p_lda = j; // packed leading dimension + // perform copy of A to packed buffer D_A_pack + + if(transa) + { + for(dim_t x =0;x < p_lda;x+=d_nr) + { + ymm0 = _mm256_loadu_pd((double const *)(a01)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); + ymm2 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2)); + ymm3 = _mm256_loadu_pd((double const *)(a01 + cs_a * 3)); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + + ymm0 = _mm256_loadu_pd((double const *)(a01 + cs_a * 4)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a * 5)); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_broadcast_sd((double const *)&zero); + + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_broadcast_sd((double const *)&zero); + + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_extractf128_pd(ymm6,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_extractf128_pd(ymm7,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm8,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm9,0)); + + a01 += d_nr*cs_a; + ptr_a10_dup += d_nr; + } + } + else + { + dim_t loop_count = p_lda/4; + + for(dim_t x =0;x < loop_count;x++) + { + ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 0 + x*4)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + x*4), ymm15); + ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 1 + x*4)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + x*4), ymm15); + } + + dim_t remainder_loop_count = p_lda - loop_count*4; + + __m128d xmm0; + if(remainder_loop_count != 0) + { + xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 0 + loop_count*4)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + loop_count*4), xmm0); + xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 1 + loop_count*4)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + loop_count*4), xmm0); + } + } + + ymm4 = _mm256_broadcast_sd((double const *)&ones); + if(!is_unitdiag) + { + if(transa) + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11+cs_a*1 + 1)); + } + else + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11+rs_a*1 + 1)); + } + ymm2 = _mm256_broadcast_sd((double const *)&ones); + ymm3 = _mm256_broadcast_sd((double const *)&ones); + + 
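+                /*
+                   The unpack/blend sequence below packs [a00, a11, 1.0, 1.0]
+                   into a single ymm register for this 2-column block. With
+                   BLIS_ENABLE_TRSM_PREINVERSION that register is divided into
+                   a vector of ones, so d11_pack holds the reciprocals of the
+                   diagonal and the later DTRSM_SMALL_DIV_OR_SCALE step can be
+                   a multiply; with BLIS_DISABLE_TRSM_PREINVERSION the raw
+                   diagonal values are stored and a divide is expected instead.
+                */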
ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); + #ifdef BLIS_DISABLE_TRSM_PREINVERSION + ymm4 = ymm1; + #endif + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + ymm4 = _mm256_div_pd(ymm4, ymm1); + #endif + } + _mm256_storeu_pd((double *)(d11_pack), ymm4); + + for(i = 0; (i+d_mr-1) < m; i += d_mr) //loop along 'M' direction + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_2nx8m(a01,b10,cs_b,p_lda,k_iter) + + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); + + ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + 4)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] + + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + ymm4 = _mm256_fmsub_pd(ymm1, ymm15, ymm4); //B11[4-7][0] * alpha-= ymm1 + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b + 4)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] + + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + ymm6 = _mm256_fmsub_pd(ymm1, ymm15, ymm6); //B11[4-7][1] * alpha -= ymm3 + + ///implement TRSM/// + + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); + + ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); + ymm6 = _mm256_fnmadd_pd(ymm1, ymm4, ymm6); + + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm0); + + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + 4), ymm4); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm256_storeu_pd((double *)(b11 + cs_b + 4), ymm6); + } + + dim_t m_remainder = m - i; + if(m_remainder >= 4) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + ymm3 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) + + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + + ///implement TRSM/// + + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(row 1):FMA 
operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); + + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + + m_remainder -= 4; + i += 4; + } + + if(m_remainder == 3) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + ymm3 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) + + BLIS_PRE_DTRSM_SMALL_2N_3M(AlphaVal,b11,cs_b) + + ///implement TRSM/// + + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); + + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + BLIS_POST_DTRSM_SMALL_2N_3M(b11,cs_b) + + m_remainder -= 3; + i += 3; + } + else if(m_remainder == 2) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + ymm3 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) + + BLIS_PRE_DTRSM_SMALL_2N_2M(AlphaVal,b11,cs_b) + + ///implement TRSM/// + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); + + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + BLIS_POST_DTRSM_SMALL_2N_2M(b11,cs_b) + + m_remainder -= 2; + i += 2; + } + else if(m_remainder == 1) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + ymm3 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) + + BLIS_PRE_DTRSM_SMALL_2N_1M(AlphaVal,b11,cs_b) + + ///implement TRSM/// + + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); + + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + BLIS_POST_DTRSM_SMALL_2N_1M(b11,cs_b) + + m_remainder -= 1; + i += 1; + } + j += 2; + n_remainder -= 2; + } + else if(n_remainder == 1) + { + a01 = L + j*rs_a; //pointer to block of A to be used in GEMM + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + + 
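+        /*
+           Single remaining column of A: the GEMM update with the already
+           solved columns of B is still performed below, but the TRSM step
+           reduces to scaling b11 by the packed diagonal element a00; there
+           are no cross-column FMA updates in this branch.
+        */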
double *ptr_a10_dup = D_A_pack; + + dim_t p_lda = j; // packed leading dimension + // perform copy of A to packed buffer D_A_pack + + if(transa) + { + for(dim_t x =0;x < p_lda;x+=d_nr) + { + ymm0 = _mm256_loadu_pd((double const *)(a01)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); + ymm2 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2)); + ymm3 = _mm256_loadu_pd((double const *)(a01 + cs_a * 3)); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + + ymm0 = _mm256_loadu_pd((double const *)(a01 + cs_a * 4)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a * 5)); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_broadcast_sd((double const *)&zero); + + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_broadcast_sd((double const *)&zero); + + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_extractf128_pd(ymm6,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_extractf128_pd(ymm7,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm8,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm9,0)); + + a01 += d_nr*cs_a; + ptr_a10_dup += d_nr; + } + } + else + { + dim_t loop_count = p_lda/4; + + for(dim_t x =0;x < loop_count;x++) + { + ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 0 + x*4)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + x*4), ymm15); + } + + dim_t remainder_loop_count = p_lda - loop_count*4; + + __m128d xmm0; + if(remainder_loop_count != 0) + { + xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 0 + loop_count*4)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + loop_count*4), xmm0); + } + } + + ymm4 = _mm256_broadcast_sd((double const *)&ones); + if(!is_unitdiag) + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)&ones); + ymm2 = _mm256_broadcast_sd((double const *)&ones); + ymm3 = _mm256_broadcast_sd((double const *)&ones); + + ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); + #ifdef BLIS_DISABLE_TRSM_PREINVERSION + ymm4 = ymm1; + #endif + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + ymm4 = _mm256_div_pd(ymm4, ymm1); + #endif + } + _mm256_storeu_pd((double *)(d11_pack), ymm4); + + for(i = 0; (i+d_mr-1) < m; i += d_mr) //loop along 'M' direction + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + ymm3 = _mm256_setzero_pd(); + ymm4 = _mm256_setzero_pd(); + ///GEMM implementation starts/// + 
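+            /*
+               Accumulates the b10 * a01 contribution of the j previously
+               solved columns into ymm3/ymm4; the fmsub below then forms
+               alpha*b11 minus this accumulator before the diagonal scale.
+            */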
BLIS_DTRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter) + + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); + + ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + 4)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] + + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + ymm4 = _mm256_fmsub_pd(ymm1, ymm15, ymm4); //B11[4-7][0] * alpha-= ymm1 + + ///implement TRSM/// + + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); + + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + 4), ymm4); + } + + dim_t m_remainder = m - i; + if(m_remainder >= 4) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + ymm3 = _mm256_setzero_pd(); + + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) + + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + + ///implement TRSM/// + + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + _mm256_storeu_pd((double *)b11, ymm3); + + m_remainder -= 4; + i += 4; + } + + if(m_remainder == 3) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + ymm3 = _mm256_setzero_pd(); + + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) + + BLIS_PRE_DTRSM_SMALL_1N_3M(AlphaVal,b11,cs_b) + + ///implement TRSM/// + + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + BLIS_POST_DTRSM_SMALL_1N_3M(b11,cs_b) + + m_remainder -= 3; + i += 3; + } + else if(m_remainder == 2) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + ymm3 = _mm256_setzero_pd(); + + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) + + BLIS_PRE_DTRSM_SMALL_1N_2M(AlphaVal,b11,cs_b) + + ///implement TRSM/// + + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + BLIS_POST_DTRSM_SMALL_1N_2M(b11,cs_b) + + m_remainder -= 2; + i += 2; + } + else if(m_remainder == 1) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + ymm3 = _mm256_setzero_pd(); + + ///GEMM implementation starts/// + 
BLIS_DTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) + + BLIS_PRE_DTRSM_SMALL_1N_1M(AlphaVal,b11,cs_b) + + ///implement TRSM/// + + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + BLIS_POST_DTRSM_SMALL_1N_1M(b11,cs_b) + + m_remainder -= 1; + i += 1; + } + j += 1; + n_remainder -= 1; + } + + if ((required_packing_A == 1) && bli_mem_is_alloc( &local_mem_buf_A_s )) + { + bli_membrk_release(&rntm, + &local_mem_buf_A_s); + } + + return BLIS_SUCCESS; +} + +/*implements TRSM for the case XA = alpha * B + *A is upper triangular, non-unit diagonal/unit diagonal, transpose + *dimensions: X:mxn A:nxn B: mxn + * + * <---b11 <---a11 + ***************** * + *b01*b11* * * * * + ^ * * * * * ^ * * + | ***************** | ******* + | * * * * * | * * * + | * * * * * a01* * * +b10 ***************** ************* + * * * * * * * * * + * * * * * * * * * + ***************** ******************* + + *implements TRSM for the case XA = alpha * B + *A is lower triangular, non-unit diagonal/unit diagonal, no transpose + *dimensions: X:mxn A:nxn B: mxn + * + * <---b11 <---a11 + ***************** * + *b01*b11* * * * * + ^ * * * * * ^ * * + | ***************** | ******* + | * * * * * | * * * + | * * * * * a01* * * +b10 ***************** ************* + * * * * * * * * * + * * * * * * * * * + ***************** ******************* + +*/ +BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +) +{ + dim_t m = bli_obj_length(b); //number of rows + dim_t n = bli_obj_width(b); //number of columns + + bool transa = bli_obj_has_trans(a); + dim_t cs_a, rs_a; + dim_t d_mr = 8,d_nr = 6; + + // Swap rs_a & cs_a in case of non-tranpose. + if(transa) + { + cs_a = bli_obj_col_stride(a); // column stride of A + rs_a = bli_obj_row_stride(a); // row stride of A + } + else + { + cs_a = bli_obj_row_stride(a); // row stride of A + rs_a = bli_obj_col_stride(a); // column stride of A + } + dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B + + dim_t i, j, k; //loop variablse + dim_t k_iter; //determines the number of GEMM operations to be done + + double ones = 1.0; + double zero = 0.0; + bool is_unitdiag = bli_obj_has_unit_diag(a); + + double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha + double* restrict L = a->buffer; //pointer to matrix A + double* restrict B = b->buffer; //pointer to matrix B + + double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks + + gint_t required_packing_A = 1; + mem_t local_mem_buf_A_s = {0}; + double *D_A_pack = NULL; + double d11_pack[d_mr] __attribute__((aligned(64))); + rntm_t rntm; + + bli_rntm_init_from_global( &rntm ); + bli_rntm_set_num_threads_only( 1, &rntm ); + bli_membrk_rntm_set_membrk( &rntm ); + + siz_t buffer_size = bli_pool_block_size( + bli_membrk_pool( + bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), + bli_rntm_membrk(&rntm))); + + if( (d_nr * n * sizeof(double)) > buffer_size) + return BLIS_NOT_YET_IMPLEMENTED; + + if (required_packing_A == 1) + { + // Get the buffer from the pool. 
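+        // The packed a01 panel needs at most d_nr * n doubles (checked against
+        // the pool block size above), so a single A-block buffer from the
+        // memory broker can be reused across all iterations of the j loop.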
+ bli_membrk_acquire_m(&rntm, + buffer_size, + BLIS_BITVAL_BUFFER_FOR_A_BLOCK, + &local_mem_buf_A_s); + if(FALSE==bli_mem_is_alloc(&local_mem_buf_A_s)) return BLIS_NULL_POINTER; + D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); + if(NULL==D_A_pack) return BLIS_NULL_POINTER; + } + + //ymm scratch reginsters + __m256d ymm0, ymm1, ymm2, ymm3; + __m256d ymm4, ymm5, ymm6, ymm7; + __m256d ymm8, ymm9, ymm10, ymm11; + __m256d ymm12, ymm13, ymm14, ymm15; + + __m128d xmm5; + + /* + Performs solving TRSM for 6 rows at a time from 0 to n/6 in steps of d_nr + a. Load and pack A (a01 block), the size of packing 6x6 to 6x (n-6) + First there will be no GEMM and no packing of a01 because it is only TRSM + b. Using packed a01 block and b10 block perform GEMM operation + c. Use GEMM outputs, perform TRSM operation using a11, b11 and update B + d. Repeat b for m cols of B in steps of d_mr + */ + + for(j = (n-d_nr); (j+1) > 0; j -= d_nr) //loop along 'N' direction + { + a01 = L + (j*rs_a) + (j+d_nr)*cs_a; //pointer to block of A to be used in GEMM + a11 = L + (j*cs_a) + (j*rs_a); //pointer to block of A to be used for TRSM + + //double *ptr_a10_dup = D_A_pack; + + dim_t p_lda = (n-j-d_nr); // packed leading dimension + // perform copy of A to packed buffer D_A_pack + + if(transa) + { + /* + Pack current A block (a01) into packed buffer memory D_A_pack + a. This a10 block is used in GEMM portion only and this + a01 block size will be increasing by d_nr for every next iteration + until it reaches 6x(n-6) which is the maximum GEMM alone block size in A + b. This packed buffer is reused to calculate all m cols of B matrix + */ + bli_dtrsm_small_pack('R', p_lda, 1, a01, cs_a, D_A_pack, p_lda,d_nr); + + /* + Pack 6 diagonal elements of A block into an array + a. This helps in utilze cache line efficiently in TRSM operation + b. store ones when input is unit diagonal + */ + dtrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,d_nr); + } + else + { + bli_dtrsm_small_pack('R', p_lda, 0, a01, rs_a, D_A_pack, p_lda,d_nr); + dtrsm_small_pack_diag_element(is_unitdiag,a11,rs_a,d11_pack,d_nr); + } + + /* + a. Perform GEMM using a01, b10. + b. Perform TRSM on a11, b11 + c. This loop GEMM+TRSM loops operates with 8x6 block size + along m dimension for every d_mr columns of B10 where + packed A buffer is reused in computing all m cols of B. + d. Same approach is used in remaining fringe cases. + */ + for(i = (m-d_mr); (i+1) > 0; i -= d_mr) //loop along 'M' direction + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i + (j+d_nr)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (i) + (j)*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-j-d_nr); //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + /* + Peform GEMM between a01 and b10 blocks + For first itteration there will be no GEMM operation + where k_iter are zero + */ + + BLIS_DTRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) + + /* + Load b11 of size 8x6 and multiply with alpha + Add the GEMM output to b11 + and peform TRSM operation. + */ + + BLIS_PRE_DTRSM_SMALL_6x8(AlphaVal,b11,cs_b) + + ///implement TRSM/// + + /* + Compute 6x8 TRSM block by using GEMM block output in register + a. The 6x8 input (gemm outputs) are stored in combinations of ymm registers + 1. ymm3, ymm4 2. ymm5, ymm6 3. ymm7, ymm8, 4. ymm9, ymm10 + 5. ymm11, ymm12 6. ymm13,ymm14 + b. 
Towards the end TRSM output will be stored back into b11 + */ + + //extract a55 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + + ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); + ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm0); + + //extract a44 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + + //(row 5):FMA operations + //ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 4*rs_a)); + + ymm11 = _mm256_fnmadd_pd(ymm1, ymm13, ymm11); + ymm12 = _mm256_fnmadd_pd(ymm1, ymm14, ymm12); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 3*rs_a)); + + ymm9 = _mm256_fnmadd_pd(ymm1, ymm13, ymm9); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm14, ymm10); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 2*rs_a)); + + ymm7 = _mm256_fnmadd_pd(ymm1, ymm13, ymm7); + ymm8 = _mm256_fnmadd_pd(ymm1, ymm14, ymm8); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 1*rs_a)); + + ymm5 = _mm256_fnmadd_pd(ymm1, ymm13, ymm5); + ymm6 = _mm256_fnmadd_pd(ymm1, ymm14, ymm6); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a)); + + ymm3 = _mm256_fnmadd_pd(ymm1, ymm13, ymm3); + ymm4 = _mm256_fnmadd_pd(ymm1, ymm14, ymm4); + + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); + ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm0); + + //extract a33 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(row 4):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 3*rs_a)); + + ymm9 = _mm256_fnmadd_pd(ymm1, ymm11, ymm9); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm12, ymm10); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 2*rs_a)); + + ymm7 = _mm256_fnmadd_pd(ymm1, ymm11, ymm7); + ymm8 = _mm256_fnmadd_pd(ymm1, ymm12, ymm8); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 1*rs_a)); + + ymm5 = _mm256_fnmadd_pd(ymm1, ymm11, ymm5); + ymm6 = _mm256_fnmadd_pd(ymm1, ymm12, ymm6); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a)); + + ymm3 = _mm256_fnmadd_pd(ymm1, ymm11, ymm3); + ymm4 = _mm256_fnmadd_pd(ymm1, ymm12, ymm4); + + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm0); + + //extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2*rs_a)); + + ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); + ymm8 = _mm256_fnmadd_pd(ymm1, ymm10, ymm8); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1*rs_a)); + + ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); + ymm6 = _mm256_fnmadd_pd(ymm1, ymm10, ymm6); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); + + ymm3 = _mm256_fnmadd_pd(ymm1, ymm9, ymm3); + ymm4 = _mm256_fnmadd_pd(ymm1, ymm10, ymm4); + + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1*rs_a)); + + ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); + ymm6 = _mm256_fnmadd_pd(ymm1, ymm8, ymm6); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); + + ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); + ymm4 = _mm256_fnmadd_pd(ymm1, ymm8, ymm4); + + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm0); + + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); + + //(Row 1): FMA operations + ymm1 = 
_mm256_broadcast_sd((double const *)(a11 + cs_a)); + + ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); + ymm4 = _mm256_fnmadd_pd(ymm1, ymm6, ymm4); + + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); + + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + 4), ymm4); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm256_storeu_pd((double *)(b11 + cs_b + 4), ymm6); + _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); + _mm256_storeu_pd((double *)(b11 + cs_b*2 + 4), ymm8); + _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); + _mm256_storeu_pd((double *)(b11 + cs_b*3 + 4), ymm10); + _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); + _mm256_storeu_pd((double *)(b11 + cs_b*4 + 4), ymm12); + _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); + _mm256_storeu_pd((double *)(b11 + cs_b*5 + 4), ymm14); + } + + dim_t m_remainder = i + d_mr; + if(m_remainder >= 4) + { + a01 = D_A_pack; + a11 = L + (j*cs_a) + (j*rs_a); + b10 = B + (m_remainder - 4) + (j+d_nr)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 4) + (j*cs_b); + + k_iter = (n-j-d_nr); //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) + + // Load b11 of size 4x6 and multiply with alpha + BLIS_PRE_DTRSM_SMALL_6x4(AlphaVal,b11,cs_b) + + ///implement TRSM/// + + //extract a55 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); + + //extract a44 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + + //(row 5):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 4*rs_a)); + ymm11 = _mm256_fnmadd_pd(ymm1, ymm13, ymm11); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm13, ymm9); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm13, ymm7); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm13, ymm5); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm13, ymm3); + + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); + + //extract a33 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(row 4):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm11, ymm9); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm11, ymm7); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm11, ymm5); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm11, ymm3); + + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + + //extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm9, ymm3); + + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + //extract a11 + ymm0 = 
_mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); + + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); + + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); + _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); + _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); + _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); + + m_remainder -=4; + } + + if(m_remainder) + { + if(3 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (j*cs_a) + (j*rs_a); + b10 = B + (j+d_nr)*cs_b + (m_remainder - 3); //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 3) + (j*cs_b); + + k_iter = (n-j-d_nr); //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) + + // Load b11 of size 4x6 and multiply with alpha + BLIS_PRE_DTRSM_SMALL_6x4(AlphaVal,b11,cs_b) + + ///implement TRSM/// + + //extract a55 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); + + //extract a44 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + + //(row 5):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 4*rs_a)); + ymm11 = _mm256_fnmadd_pd(ymm1, ymm13, ymm11); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm13, ymm9); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm13, ymm7); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm13, ymm5); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm13, ymm3); + + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); + + //extract a33 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(row 4):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm11, ymm9); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm11, ymm7); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm11, ymm5); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm11, ymm3); + + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + + //extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm9, ymm3); + + ymm7 = 
DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); + + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); + + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + ymm0 = _mm256_loadu_pd((double const *)b11); + ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x07); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x07); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x07); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x07); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm11 = _mm256_blend_pd(ymm0, ymm11, 0x07); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm13 = _mm256_blend_pd(ymm0, ymm13, 0x07); + + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); + _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); + _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); + _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); + + m_remainder -=3; + } + else if(2 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (j*cs_a) + (j*rs_a); + b10 = B + (j+d_nr)*cs_b + (m_remainder - 2); //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 2) + (j*cs_b); + + k_iter = (n-j-d_nr); //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) + + // Load b11 of size 4x6 and multiply with alpha + BLIS_PRE_DTRSM_SMALL_6x4(AlphaVal,b11,cs_b) + + ///implement TRSM/// + + //extract a55 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); + + //extract a44 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + + //(row 5):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 4*rs_a)); + ymm11 = _mm256_fnmadd_pd(ymm1, ymm13, ymm11); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm13, ymm9); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm13, ymm7); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm13, ymm5); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm13, ymm3); + + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); + + //extract a33 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(row 4):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm11, ymm9); + + ymm1 = 
_mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm11, ymm7); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm11, ymm5); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm11, ymm3); + + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + + //extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm9, ymm3); + + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); + + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); + + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + ymm0 = _mm256_loadu_pd((double const *)b11); + ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x03); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x03); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x03); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x03); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm11 = _mm256_blend_pd(ymm0, ymm11, 0x03); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm13 = _mm256_blend_pd(ymm0, ymm13, 0x03); + + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); + _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); + _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); + _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); + + m_remainder -=2; + } + else if (1 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (j*cs_a) + (j*rs_a); + b10 = B + (j+d_nr)*cs_b + (m_remainder - 1); //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 1) + (j*cs_b); + + k_iter = (n-j-d_nr); //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) + + // Load b11 of size 4x6 and multiply with alpha + BLIS_PRE_DTRSM_SMALL_6x4(AlphaVal,b11,cs_b) + + ///implement TRSM/// + + //extract a55 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); + + //extract a44 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + + //(row 5):FMA operations + ymm1 = 
_mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 4*rs_a)); + ymm11 = _mm256_fnmadd_pd(ymm1, ymm13, ymm11); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm13, ymm9); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm13, ymm7); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm13, ymm5); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm13, ymm3); + + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); + + //extract a33 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(row 4):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm11, ymm9); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm11, ymm7); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm11, ymm5); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm11, ymm3); + + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + + //extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm9, ymm3); + + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); + + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); + + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + ymm0 = _mm256_loadu_pd((double const *)b11); + ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x01); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x01); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x01); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x01); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm11 = _mm256_blend_pd(ymm0, ymm11, 0x01); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm13 = _mm256_blend_pd(ymm0, ymm13, 0x01); + + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); + _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); + _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); + _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); + + m_remainder -=1; + } + } + } + + 
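+    /*
+       The loop above steps j down by d_nr and exits once j turns negative,
+       leaving j in [-d_nr, -1]. Hence j + d_nr recovers the number of columns
+       (0 to d_nr-1) that are still unsolved; they are handled by the
+       remainder paths that follow.
+    */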
dim_t n_remainder = j + d_nr; + + /* + Reminder cases starts here: + a. Similar logic and code flow used in computing full block (6x8) + above holds for reminder cases too. + */ + + if(n_remainder >= 4) + { + a01 = L + (n_remainder - 4)*rs_a + n_remainder*cs_a; //pointer to block of A to be used in GEMM + a11 = L + (n_remainder - 4)*cs_a + (n_remainder - 4)*rs_a; //pointer to block of A to be used for TRSM + + double *ptr_a10_dup = D_A_pack; + + dim_t p_lda = (n-n_remainder); // packed leading dimension + // perform copy of A to packed buffer D_A_pack + + if(transa) + { + for(dim_t x =0;x < p_lda;x+=d_nr) + { + ymm0 = _mm256_loadu_pd((double const *)(a01)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); + ymm2 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2)); + ymm3 = _mm256_loadu_pd((double const *)(a01 + cs_a * 3)); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + + ymm0 = _mm256_loadu_pd((double const *)(a01 + cs_a * 4)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a * 5)); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_broadcast_sd((double const *)&zero); + + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_broadcast_sd((double const *)&zero); + + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_extractf128_pd(ymm6,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_extractf128_pd(ymm7,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm8,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm9,0)); + + a01 += d_nr*cs_a; + ptr_a10_dup += d_nr; + } + } + else + { + dim_t loop_count = (n-n_remainder)/4; + + for(dim_t x =0;x < loop_count;x++) + { + ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 0 + x*4)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + x*4), ymm15); + ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 1 + x*4)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + x*4), ymm15); + ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 2 + x*4)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 2 + x*4), ymm15); + ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 3 + x*4)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 3 + x*4), ymm15); + } + + dim_t remainder_loop_count = p_lda - loop_count*4; + + __m128d xmm0; + if(remainder_loop_count != 0) + { + xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 0 + loop_count*4)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + loop_count*4), xmm0); + xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 1 + loop_count*4)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + loop_count*4), xmm0); + xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 2 + loop_count*4)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 2 + loop_count*4), xmm0); + xmm0 = 
_mm_loadu_pd((double const *)(a01 + rs_a * 3 + loop_count*4)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 3 + loop_count*4), xmm0); + } + } + + ymm4 = _mm256_broadcast_sd((double const *)&ones); + if(!is_unitdiag) + { + if(transa) + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a*1 + 1)); + ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a*2 + 2)); + ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a*3 + 3)); + } + else + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11+ rs_a*1 + 1)); + ymm2 = _mm256_broadcast_sd((double const *)(a11+ rs_a*2 + 2)); + ymm3 = _mm256_broadcast_sd((double const *)(a11+ rs_a*3 + 3)); + } + + ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); + #ifdef BLIS_DISABLE_TRSM_PREINVERSION + ymm4 = ymm1; + #endif + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + ymm4 = _mm256_div_pd(ymm4, ymm1); + #endif + } + _mm256_storeu_pd((double *)(d11_pack), ymm4); + + for(i = (m-d_mr); (i+1) > 0; i -= d_mr) //loop along 'M' direction + { + a01 = D_A_pack; + a11 = L + (n_remainder - 4)*cs_a + (n_remainder - 4)*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (i) + (n_remainder - 4)*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_4nx8m(a01,b10,cs_b,p_lda,k_iter) + + BLIS_PRE_DTRSM_SMALL_4x8(AlphaVal,b11,cs_b) + + ///implement TRSM/// + + //extract a33 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm0); + + //extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2*rs_a)); + + ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); + ymm8 = _mm256_fnmadd_pd(ymm1, ymm10, ymm8); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1*rs_a)); + + ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); + ymm6 = _mm256_fnmadd_pd(ymm1, ymm10, ymm6); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); + + ymm3 = _mm256_fnmadd_pd(ymm1, ymm9, ymm3); + ymm4 = _mm256_fnmadd_pd(ymm1, ymm10, ymm4); + + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1*rs_a)); + + ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); + ymm6 = _mm256_fnmadd_pd(ymm1, ymm8, ymm6); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); + + ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); + ymm4 = _mm256_fnmadd_pd(ymm1, ymm8, ymm4); + + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm0); + + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); + + ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); + ymm4 = _mm256_fnmadd_pd(ymm1, ymm6, ymm4); + + ymm3 = 
DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); + + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + 4), ymm4); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm256_storeu_pd((double *)(b11 + cs_b + 4), ymm6); + _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); + _mm256_storeu_pd((double *)(b11 + cs_b*2 + 4), ymm8); + _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); + _mm256_storeu_pd((double *)(b11 + cs_b*3 + 4), ymm10); + } + + dim_t m_remainder = i + d_mr; + if(m_remainder >= 4) + { + a01 = D_A_pack; + a11 = L + (n_remainder - 4)*cs_a + (n_remainder - 4)*rs_a; //pointer to block of A to be used for TRSM + b10 = B + (m_remainder - 4) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 4) + (n_remainder - 4)*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_4nx4m(a01,b10,cs_b,p_lda,k_iter) + + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 + + ///implement TRSM/// + + //extract a33 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + + //extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm9, ymm3); + + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); + + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); + + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); + _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); + + m_remainder -=4; + } + + if(m_remainder) + { + if(3 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (n_remainder - 4)*cs_a 
+ (n_remainder - 4)*rs_a; //pointer to block of A to be used for TRSM + b10 = B + (m_remainder - 3) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 3) + (n_remainder - 4)*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_4nx4m(a01,b10,cs_b,p_lda,k_iter) + + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 + + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3 + 2)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); + ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 + + ///implement TRSM/// + + //extract a33 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + + //extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm9, ymm3); + + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); + + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); + + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + ymm0 = _mm256_loadu_pd((double const *)b11); + ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x07); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x07); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x07); + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3 + 2)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x07); + + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); + xmm5 = _mm256_extractf128_pd(ymm9, 0); + _mm_storeu_pd((double *)(b11 + cs_b * 3),xmm5); + 
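+                    // third element of the last column: low double of ymm9's upper 128-bit lane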
_mm_storel_pd((b11 + cs_b * 3 + 2), _mm256_extractf128_pd(ymm9, 1)); + + m_remainder -=3; + } + else if(2 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (n_remainder - 4)*cs_a + (n_remainder - 4)*rs_a; //pointer to block of A to be used for TRSM + b10 = B + (m_remainder - 2) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 2) + (n_remainder - 4)*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_4nx4m(a01,b10,cs_b,p_lda,k_iter) + + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 + + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); + ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 + + ///implement TRSM/// + + //extract a33 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + + //extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm9, ymm3); + + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); + + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); + + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + ymm0 = _mm256_loadu_pd((double const *)b11); + ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x03); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x03); + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x03); + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); + ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x03); + + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); + xmm5 = _mm256_extractf128_pd(ymm9, 0); + _mm_storeu_pd((double *)(b11 + cs_b * 
3),xmm5); + + m_remainder -=2; + } + else if (1 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (n_remainder - 4)*cs_a + (n_remainder - 4)*rs_a; //pointer to block of A to be used for TRSM + b10 = B + (m_remainder - 1) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 1) + (n_remainder - 4)*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_4nx4m(a01,b10,cs_b,p_lda,k_iter) + + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_broadcast_sd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 + + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 + + ///implement TRSM/// + + //extract a33 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + + //extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2*rs_a)); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm9, ymm3); + + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); + + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); + + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm0 = _mm256_broadcast_sd((double const *)b11); + ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x01); + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x01); + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x01); + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x01); + + _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm3, 0)); + _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm5, 0)); + _mm_storel_pd((b11 + cs_b * 2), _mm256_extractf128_pd(ymm7, 0)); + _mm_storel_pd((b11 + cs_b * 3), _mm256_extractf128_pd(ymm9, 
0)); + + m_remainder -=1; + } + } + n_remainder -= 4; + } + + if(n_remainder == 3) + { + a01 = L + (n_remainder - 3)*rs_a + n_remainder*cs_a; //pointer to block of A to be used in GEMM + a11 = L + (n_remainder - 3)*cs_a + (n_remainder - 3)*rs_a; //pointer to block of A to be used for TRSM + + double *ptr_a10_dup = D_A_pack; + + dim_t p_lda = (n-n_remainder); // packed leading dimension + // perform copy of A to packed buffer D_A_pack + + if(transa) + { + for(dim_t x =0;x < p_lda;x+=d_nr) + { + ymm0 = _mm256_loadu_pd((double const *)(a01)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); + ymm2 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2)); + ymm3 = _mm256_loadu_pd((double const *)(a01 + cs_a * 3)); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + + ymm0 = _mm256_loadu_pd((double const *)(a01 + cs_a * 4)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a * 5)); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_broadcast_sd((double const *)&zero); + + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_broadcast_sd((double const *)&zero); + + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_extractf128_pd(ymm6,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_extractf128_pd(ymm7,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm8,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm9,0)); + + a01 += d_nr*cs_a; + ptr_a10_dup += d_nr; + } + } + else + { + dim_t loop_count = (n-n_remainder)/4; + + for(dim_t x =0;x < loop_count;x++) + { + ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 0 + x*4)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + x*4), ymm15); + ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 1 + x*4)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + x*4), ymm15); + ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 2 + x*4)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 2 + x*4), ymm15); + } + + dim_t remainder_loop_count = p_lda - loop_count*4; + + __m128d xmm0; + if(remainder_loop_count != 0) + { + xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 0 + loop_count*4)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + loop_count*4), xmm0); + xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 1 + loop_count*4)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + loop_count*4), xmm0); + xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 2 + loop_count*4)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 2 + loop_count*4), xmm0); + } + } + + ymm4 = _mm256_broadcast_sd((double const *)&ones); + if(!is_unitdiag) + { + if(transa) + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a*1 + 1)); + ymm2 
= _mm256_broadcast_sd((double const *)(a11+ cs_a*2 + 2)); + ymm3 = _mm256_broadcast_sd((double const *)&ones); + } + else + { + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11+ rs_a*1 + 1)); + ymm2 = _mm256_broadcast_sd((double const *)(a11+ rs_a*2 + 2)); + ymm3 = _mm256_broadcast_sd((double const *)&ones); + } + + ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); + #ifdef BLIS_DISABLE_TRSM_PREINVERSION + ymm4 = ymm1; + #endif + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + ymm4 = _mm256_div_pd(ymm4, ymm1); + #endif + } + _mm256_storeu_pd((double *)(d11_pack), ymm4); + + for(i = (m-d_mr); (i+1) > 0; i -= d_mr) //loop along 'M' direction + { + a01 = D_A_pack; + a11 = L + (n_remainder - 3)*cs_a + (n_remainder - 3)*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (i) + (n_remainder - 3)*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_3nx8m(a01,b10,cs_b,p_lda,k_iter) + + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); + + ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + 4)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] + + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + ymm4 = _mm256_fmsub_pd(ymm1, ymm15, ymm4); //B11[4-7][0] * alpha-= ymm1 + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b + 4)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] + + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + ymm6 = _mm256_fmsub_pd(ymm1, ymm15, ymm6); //B11[4-7][1] * alpha -= ymm3 + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*2 + 4)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] + + ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 + ymm8 = _mm256_fmsub_pd(ymm1, ymm15, ymm8); //B11[4-7][2] * alpha -= ymm5 + + ///implement TRSM/// + + //extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1*rs_a)); + + ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); + ymm6 = _mm256_fnmadd_pd(ymm1, ymm8, ymm6); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); + + ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); + ymm4 = _mm256_fnmadd_pd(ymm1, ymm8, ymm4); + + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm0); + + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); + + ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); + ymm4 = _mm256_fnmadd_pd(ymm1, ymm6, ymm4); + + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); + + _mm256_storeu_pd((double *)b11, 
ymm3); + _mm256_storeu_pd((double *)(b11 + 4), ymm4); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm256_storeu_pd((double *)(b11 + cs_b + 4), ymm6); + _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); + _mm256_storeu_pd((double *)(b11 + cs_b*2 + 4), ymm8); + } + + dim_t m_remainder = i + d_mr; + if(m_remainder >= 4) + { + a01 = D_A_pack; + a11 = L + (n_remainder - 3)*cs_a + (n_remainder - 3)*rs_a; //pointer to block of A to be used for TRSM + b10 = B + (m_remainder - 4) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 4) + (n_remainder - 3)*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) + + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 + + ///implement TRSM/// + + //extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); + + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); + + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); + + m_remainder -=4; + } + + if(m_remainder) + { + if(3 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (n_remainder - 3)*cs_a + (n_remainder - 3)*rs_a; //pointer to block of A to be used for TRSM + b10 = B + (m_remainder - 3) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 3) + (n_remainder - 3)*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) + + BLIS_PRE_DTRSM_SMALL_3N_3M(AlphaVal,b11,cs_b) + + ///implement TRSM/// + //extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); 
+ + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); + + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); + + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + BLIS_POST_DTRSM_SMALL_3N_3M(b11,cs_b) + + m_remainder -=3; + } + else if(2 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (n_remainder - 3)*cs_a + (n_remainder - 3)*rs_a; //pointer to block of A to be used for TRSM + b10 = B + (m_remainder - 2) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 2) + (n_remainder - 3)*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) + + BLIS_PRE_DTRSM_SMALL_3N_2M(AlphaVal,b11,cs_b) + + ///implement TRSM/// + + //extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); + + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); + + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + BLIS_POST_DTRSM_SMALL_3N_2M(b11,cs_b) + + m_remainder -=2; + } + else if (1 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (n_remainder - 3)*cs_a + (n_remainder - 3)*rs_a; //pointer to block of A to be used for TRSM + b10 = B + (m_remainder - 1) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 1) + (n_remainder - 3)*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) + + BLIS_PRE_DTRSM_SMALL_3N_1M(AlphaVal,b11,cs_b) + + ///implement TRSM/// + + //extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); + + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); + + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + BLIS_POST_DTRSM_SMALL_3N_1M(b11,cs_b) + + m_remainder -=1; + 
} + } + n_remainder -= 3; + } + else if(n_remainder == 2) + { + a01 = L + (n_remainder - 2)*rs_a + n_remainder*cs_a; //pointer to block of A to be used in GEMM + a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2)*rs_a; //pointer to block of A to be used for TRSM + + double *ptr_a10_dup = D_A_pack; + + dim_t p_lda = (n-n_remainder); // packed leading dimension + // perform copy of A to packed buffer D_A_pack + + if(transa) + { + for(dim_t x =0;x < p_lda;x+=d_nr) + { + ymm0 = _mm256_loadu_pd((double const *)(a01)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); + ymm2 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2)); + ymm3 = _mm256_loadu_pd((double const *)(a01 + cs_a * 3)); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + + ymm0 = _mm256_loadu_pd((double const *)(a01 + cs_a * 4)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a * 5)); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_broadcast_sd((double const *)&zero); + + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_broadcast_sd((double const *)&zero); + + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_extractf128_pd(ymm6,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_extractf128_pd(ymm7,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm8,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm9,0)); + + a01 += d_nr*cs_a; + ptr_a10_dup += d_nr; + } + } + else + { + dim_t loop_count = (n-n_remainder)/4; + + for(dim_t x =0;x < loop_count;x++) + { + ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 0 + x*4)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + x*4), ymm15); + ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 1 + x*4)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + x*4), ymm15); + } + + dim_t remainder_loop_count = p_lda - loop_count*4; + + __m128d xmm0; + if(remainder_loop_count != 0) + { + xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 0 + loop_count*4)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + loop_count*4), xmm0); + xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 1 + loop_count*4)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + loop_count*4), xmm0); + } + } + + ymm4 = _mm256_broadcast_sd((double const *)&ones); + if(!is_unitdiag) + { + if(transa) + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11+cs_a*1 + 1)); + } + else + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11+rs_a*1 + 1)); + } + ymm2 = _mm256_broadcast_sd((double const *)&ones); + ymm3 = _mm256_broadcast_sd((double const *)&ones); + + ymm0 = _mm256_unpacklo_pd(ymm0, 
ymm1); + ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); + #ifdef BLIS_DISABLE_TRSM_PREINVERSION + ymm4 = ymm1; + #endif + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + ymm4 = _mm256_div_pd(ymm4, ymm1); + #endif + } + _mm256_storeu_pd((double *)(d11_pack), ymm4); + + for(i = (m-d_mr); (i+1) > 0; i -= d_mr) //loop along 'M' direction + { + a01 = D_A_pack; + a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2)*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (i) + (n_remainder - 2)*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_2nx8m(a01,b10,cs_b,p_lda,k_iter) + + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); + + ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + 4)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] + + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + ymm4 = _mm256_fmsub_pd(ymm1, ymm15, ymm4); //B11[4-7][0] * alpha-= ymm1 + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b + 4)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] + + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + ymm6 = _mm256_fmsub_pd(ymm1, ymm15, ymm6); //B11[4-7][1] * alpha -= ymm3 + + ///implement TRSM/// + + //extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm0); + + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); + + ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); + ymm4 = _mm256_fnmadd_pd(ymm1, ymm6, ymm4); + + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); + + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + 4), ymm4); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm256_storeu_pd((double *)(b11 + cs_b + 4), ymm6); + } + + dim_t m_remainder = i + d_mr; + if(m_remainder >= 4) + { + a01 = D_A_pack; + a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2)*rs_a; //pointer to block of A to be used for TRSM + b10 = B + (m_remainder - 4) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 4) + (n_remainder - 2)*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) + + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + + ///implement TRSM/// + + //extract a11 + 
ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); + + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + + m_remainder -=4; + } + + if(m_remainder) + { + if(3 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2)*rs_a; //pointer to block of A to be used for TRSM + b10 = B + (m_remainder - 3) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 3) + (n_remainder - 2)*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) + + BLIS_PRE_DTRSM_SMALL_2N_3M(AlphaVal,b11,cs_b) + + ///implement TRSM/// + + //extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); + + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + BLIS_POST_DTRSM_SMALL_2N_3M(b11,cs_b) + + m_remainder -=3; + } + else if(2 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2)*rs_a; //pointer to block of A to be used for TRSM + b10 = B + (m_remainder - 2) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 2) + (n_remainder - 2)*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) + + BLIS_PRE_DTRSM_SMALL_2N_2M(AlphaVal,b11,cs_b) + ///implement TRSM/// + + //extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); + + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + BLIS_POST_DTRSM_SMALL_2N_2M(b11,cs_b) + + m_remainder -=2; + } + else if (1 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2)*rs_a; //pointer to block of A to be used for TRSM + b10 = B + (m_remainder - 1) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 1) + (n_remainder - 2)*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) + + BLIS_PRE_DTRSM_SMALL_2N_1M(AlphaVal,b11,cs_b) + ///implement TRSM/// + + //extract a11 + ymm0 = 
_mm256_broadcast_sd((double const *)(d11_pack + 1)); + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); + + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + BLIS_POST_DTRSM_SMALL_2N_1M(b11,cs_b) + + m_remainder -=1; + } + } + n_remainder -= 2; + } + else if(n_remainder == 1) + { + a01 = L + (n_remainder - 1)*rs_a + n_remainder*cs_a; //pointer to block of A to be used in GEMM + a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1)*rs_a; //pointer to block of A to be used for TRSM + + double *ptr_a10_dup = D_A_pack; + + dim_t p_lda = (n-n_remainder); // packed leading dimension + // perform copy of A to packed buffer D_A_pack + + if(transa) + { + for(dim_t x =0;x < p_lda;x+=d_nr) + { + ymm0 = _mm256_loadu_pd((double const *)(a01)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); + ymm2 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2)); + ymm3 = _mm256_loadu_pd((double const *)(a01 + cs_a * 3)); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + + ymm0 = _mm256_loadu_pd((double const *)(a01 + cs_a * 4)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a * 5)); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_broadcast_sd((double const *)&zero); + + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_broadcast_sd((double const *)&zero); + + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_extractf128_pd(ymm6,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_extractf128_pd(ymm7,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm8,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm9,0)); + + a01 += d_nr*cs_a; + ptr_a10_dup += d_nr; + } + } + else + { + dim_t loop_count = (n-n_remainder)/4; + + for(dim_t x =0;x < loop_count;x++) + { + ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 0 + x*4)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + x*4), ymm15); + } + + dim_t remainder_loop_count = p_lda - loop_count*4; + + __m128d xmm0; + if(remainder_loop_count != 0) + { + xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 0 + loop_count*4)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + loop_count*4), xmm0); + } + } + + ymm4 = _mm256_broadcast_sd((double const *)&ones); + if(!is_unitdiag) + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)&ones); + ymm2 = _mm256_broadcast_sd((double const *)&ones); + ymm3 = _mm256_broadcast_sd((double const *)&ones); + + ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm1 = 
_mm256_blend_pd(ymm0, ymm1, 0x0C); + #ifdef BLIS_DISABLE_TRSM_PREINVERSION + ymm4 = ymm1; + #endif + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + ymm4 = _mm256_div_pd(ymm4, ymm1); + #endif + } + _mm256_storeu_pd((double *)(d11_pack), ymm4); + + for(i = (m-d_mr); (i+1) > 0; i -= d_mr) //loop along 'M' direction + { + a01 = D_A_pack; + a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1)*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (i) + (n_remainder - 1)*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter) + + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); + + ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + 4)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] + + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + ymm4 = _mm256_fmsub_pd(ymm1, ymm15, ymm4); //B11[4-7][0] * alpha-= ymm1 + + ///implement TRSM/// + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); + + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + 4), ymm4); + } + + dim_t m_remainder = i + d_mr; + if(m_remainder >= 4) + { + a01 = D_A_pack; + a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1)*rs_a; //pointer to block of A to be used for TRSM + b10 = B + (m_remainder - 4) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 4) + (n_remainder - 1)*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + + ymm3 = _mm256_setzero_pd(); + + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) + + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + + ///implement TRSM/// + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + _mm256_storeu_pd((double *)b11, ymm3); + + m_remainder -=4; + } + + if(m_remainder) + { + if(3 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1)*rs_a; //pointer to block of A to be used for TRSM + b10 = B + (m_remainder - 3) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 3) + (n_remainder - 1)*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + + ymm3 = _mm256_setzero_pd(); + + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) + + BLIS_PRE_DTRSM_SMALL_1N_3M(AlphaVal,b11,cs_b) + + ///implement TRSM/// + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + ymm0 = _mm256_loadu_pd((double const *)b11); + ymm3 = _mm256_blend_pd(ymm6, ymm3, 0x07); + + BLIS_POST_DTRSM_SMALL_1N_3M(b11,cs_b) + + m_remainder -=3; + } + 
else if(2 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1)*rs_a; //pointer to block of A to be used for TRSM + b10 = B + (m_remainder - 2) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 2) + (n_remainder - 1)*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + + ymm3 = _mm256_setzero_pd(); + + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) + + BLIS_PRE_DTRSM_SMALL_1N_2M(AlphaVal,b11,cs_b) + + ///implement TRSM/// + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + BLIS_POST_DTRSM_SMALL_1N_2M(b11,cs_b) + + m_remainder -=2; + } + else if (1 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1)*rs_a; //pointer to block of A to be used for TRSM + b10 = B + (m_remainder - 1) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 1) + (n_remainder - 1)*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + + ymm3 = _mm256_setzero_pd(); + + ///GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) + + BLIS_PRE_DTRSM_SMALL_1N_1M(AlphaVal,b11,cs_b) + + ///implement TRSM/// + //extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + BLIS_POST_DTRSM_SMALL_1N_1M(b11,cs_b) + + m_remainder -=1; + } + } + n_remainder -= 1; + } + + if ((required_packing_A == 1) && bli_mem_is_alloc( &local_mem_buf_A_s )) + { + bli_membrk_release(&rntm, + &local_mem_buf_A_s); + } + return BLIS_SUCCESS; +} + +/* TRSM for the case AX = alpha * B, Double precision + * A is lower-triangular, transpose, non-unit diagonal + * dimensions A: mxm X: mxn B: mxn +*/ +BLIS_INLINE err_t bli_dtrsm_small_AltXB_AuXB +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +) +{ + dim_t m = bli_obj_length(b); // number of rows of matrix B + dim_t n = bli_obj_width(b); // number of columns of matrix B + + bool transa = bli_obj_has_trans(a); + dim_t cs_a, rs_a; + dim_t d_mr = 8,d_nr = 6; + + // Swap rs_a & cs_a in case of non-tranpose. 
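+    // Reading the upper, non-transpose input with swapped strides is
+    // equivalent to reading the transpose of a lower-triangular matrix,
+    // which is how one code path serves both combined variants.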
+ if(transa) + { + cs_a = bli_obj_col_stride(a); // column stride of A + rs_a = bli_obj_row_stride(a); // row stride of A + } + else + { + cs_a = bli_obj_row_stride(a); // row stride of A + rs_a = bli_obj_col_stride(a); // column stride of A + } + dim_t cs_b = bli_obj_col_stride(b); // column stride of B + + dim_t i, j, k; //loop variables + dim_t k_iter; //number of times GEMM to be performed + + double AlphaVal = *(double *)AlphaObj->buffer; //value of alpha + double *L = a->buffer; //pointer to matrix A + double *B = b->buffer; //pointer to matrix B + + //pointers that point to blocks for GEMM and TRSM + double *a10, *a11, *b01, *b11; + + double ones = 1.0; + bool is_unitdiag = bli_obj_has_unit_diag(a); + + //scratch registers + __m256d ymm0, ymm1, ymm2, ymm3; + __m256d ymm4, ymm5, ymm6, ymm7; + __m256d ymm8, ymm9, ymm10, ymm11; + __m256d ymm12, ymm13, ymm14, ymm15; + __m256d ymm16, ymm17, ymm18, ymm19; + __m256d ymm20; + + __m128d xmm5; + + gint_t required_packing_A = 1; + mem_t local_mem_buf_A_s = {0}; + double *D_A_pack = NULL; + double d11_pack[d_mr] __attribute__((aligned(64))); + rntm_t rntm; + + bli_rntm_init_from_global( &rntm ); + bli_rntm_set_num_threads_only( 1, &rntm ); + bli_membrk_rntm_set_membrk( &rntm ); + + siz_t buffer_size = bli_pool_block_size( + bli_membrk_pool( + bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), + bli_rntm_membrk(&rntm))); + + if((d_mr * m * sizeof(double)) > buffer_size) + return BLIS_NOT_YET_IMPLEMENTED; + + if(required_packing_A == 1) + { + // Get the buffer from the pool. + bli_membrk_acquire_m(&rntm, + buffer_size, + BLIS_BITVAL_BUFFER_FOR_A_BLOCK, + &local_mem_buf_A_s); + if(FALSE==bli_mem_is_alloc(&local_mem_buf_A_s)) return BLIS_NULL_POINTER; + D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); + if(NULL==D_A_pack) return BLIS_NULL_POINTER; + } + + /* + Performs solving TRSM for 8 colmns at a time from 0 to m/d_mr in steps of d_mr + a. Load, transpose, Pack A (a10 block), the size of packing 8x6 to 8x (m-d_mr) + First there will be no GEMM and no packing of a10 because it is only TRSM + b. Using packed a10 block and b01 block perform GEMM operation + c. Use GEMM outputs, perform TRSM operaton using a11, b11 and update B + d. Repeat b,c for n rows of B in steps of d_nr + */ + for(i = (m - d_mr); (i + 1) > 0; i -= d_mr) + { + a10 = L + (i*cs_a) + (i + d_mr)*rs_a; //pointer to block of A to be used for GEMM + a11 = L + (i*cs_a) + (i*rs_a); //pointer to block of A to be used for TRSM + + // Do transpose for a10 & store in D_A_pack + //ptr_a10_dup = D_A_pack; + + dim_t p_lda = d_mr; // packed leading dimension + + if(transa) + { + /* + Load, transpose and pack current A block (a10) into packed buffer memory D_A_pack + a. This a10 block is used in GEMM portion only and this + a10 block size will be increasing by d_mr for every next itteration + untill it reaches 8x(m-8) which is the maximum GEMM alone block size in A + b. This packed buffer is reused to calculate all n rows of B matrix + */ + bli_dtrsm_small_pack('L', (m-i-d_mr), 1, a10, cs_a, D_A_pack,p_lda,d_mr); + + /* + Pack 8 diagonal elements of A block into an array + a. This helps in utilze cache line efficiently in TRSM operation + b. store ones when input is unit diagonal + */ + dtrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,d_mr); + } + else + { + bli_dtrsm_small_pack('L', (m-i-d_mr), 0, a10, rs_a, D_A_pack,p_lda,d_mr); + dtrsm_small_pack_diag_element(is_unitdiag,a11,rs_a,d11_pack,d_mr); + } + + /* + a. Perform GEMM using a10, b01. + b. Perform TRSM on a11, b11 + c. 
This loop GEMM+TRSM loops operates with 8x6 block size + along n dimension for every d_nr rows of b01 where + packed A buffer is reused in computing all n rows of B. + d. Same approch is used in remaining fringe cases. + */ + for(j = (n - d_nr); (j + 1) > 0; j -= d_nr) + { + a10 = D_A_pack; + b01 = B + (j * cs_b) + i + d_mr; //pointer to block of B to be used for GEMM + b11 = B + (j * cs_b) + i; //pointer to block of B to be used for TRSM + + k_iter = (m - i - d_mr); + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + /* + Peform GEMM between a10 and b01 blocks + For first itteration there will be no GEMM operation + where k_iter are zero + */ + BLIS_DTRSM_SMALL_GEMM_8mx6n(a10,b01,cs_b,p_lda,k_iter) + + /* + Load b11 of size 6x8 and multiply with alpha + Add the GEMM output and perform inregister transose of b11 + to peform TRSM operation. + */ + BLIS_DTRSM_SMALL_NREG_TRANSPOSE_6x8(b11,cs_b,AlphaVal) + + /* + Compute 8x6 TRSM block by using GEMM block output in register + a. The 8x6 input (gemm outputs) are stored in combinations of ymm registers + 1. ymm15, ymm20 2. ymm14, ymm19 3. ymm13, ymm18 , 4. ymm12, ymm17 + 5. ymm11, ymm7 6. ymm10, ymm6, 7.ymm9, ymm5 8. ymm8, ymm4 + where ymm15-ymm8 holds 8x4 data and reaming 8x2 will be hold by + other registers + b. Towards the end do in regiser transpose of TRSM output and store in b11 + */ + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); + + //perform mul operation + ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); + ymm20 = DTRSM_SMALL_DIV_OR_SCALE(ymm20, ymm1); + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); + + //(ROw7): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6*cs_a + 7*rs_a)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm15, ymm14); + ymm19 = _mm256_fnmadd_pd(ymm2, ymm20, ymm19); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 7*rs_a)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm15, ymm13); + ymm18 = _mm256_fnmadd_pd(ymm2, ymm20, ymm18); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 7*rs_a)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm15, ymm12); + ymm17 = _mm256_fnmadd_pd(ymm2, ymm20, ymm17); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 7*rs_a)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm15, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, ymm20, ymm7); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 7*rs_a)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm15, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, ymm20, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 7*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm15, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm20, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7*rs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm15, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm20, ymm4); + + //perform mul operation + ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); + ymm19 = DTRSM_SMALL_DIV_OR_SCALE(ymm19, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + + //(ROw6): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 6*rs_a)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm14, ymm13); + ymm18 = _mm256_fnmadd_pd(ymm2, ymm19, ymm18); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 6*rs_a)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm14, ymm12); + ymm17 = _mm256_fnmadd_pd(ymm2, ymm19, ymm17); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 6*rs_a)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm14, ymm11); + ymm7 = 
_mm256_fnmadd_pd(ymm2, ymm19, ymm7); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 6*rs_a)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm14, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, ymm19, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 6*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm14, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm19, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6*rs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm14, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm19, ymm4); + + //perform mul operation + ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1); + ymm18 = DTRSM_SMALL_DIV_OR_SCALE(ymm18, ymm1); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + + //(ROw5): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 5*rs_a)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm13, ymm12); + ymm17 = _mm256_fnmadd_pd(ymm2, ymm18, ymm17); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 5*rs_a)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm13, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, ymm18, ymm7); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 5*rs_a)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm13, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, ymm18, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 5*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm13, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm18, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm13, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm18, ymm4); + + //perform mul operation + ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); + ymm17 = DTRSM_SMALL_DIV_OR_SCALE(ymm17, ymm1); + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(ROw4): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 4*rs_a)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm12, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, ymm17, ymm7); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 4*rs_a)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm12, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, ymm17, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 4*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm12, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm17, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm12, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm17, ymm4); + + //perform mul operation + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm1); + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw3): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3*rs_a)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm11, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, ymm7, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm11, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm7, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm11, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm7, ymm4); + + //perform mul operation + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm10, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm6, ymm5); + ymm2 
= _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm10, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm6, ymm4); + + //perform mul operation + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm1); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm9, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm5, ymm4); + + //perform mul operation + ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); + ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm1); + + BLIS_DTRSM_SMALL_NREG_TRANSPOSE_8x6_AND_STORE(b11,cs_b) + } + + dim_t n_remainder = j + d_nr; + if(n_remainder >= 4) + { + a10 = D_A_pack; + a11 = L + (i*cs_a) + (i*rs_a); + b01 = B + ((n_remainder - 4)* cs_b) + i + d_mr; + b11 = B + ((n_remainder - 4)* cs_b) + i; + + k_iter = (m - i - d_mr); + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_8mx4n(a10,b01,cs_b,p_lda,k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] + ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 4)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] + ymm7 = _mm256_loadu_pd((double const *)(b11 + cs_b *3 + 4)); //B11[0][7] B11[1][7] B11[2][7] B11[3][7] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] * alpha -= B01[0-3][3] + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] + ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] * alpha -= B01[0-3][5] + ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); //B11[0-3][6] * alpha -= B01[0-3][6] + ymm7 = _mm256_fmsub_pd(ymm7, ymm16, ymm15); //B11[0-3][7] * alpha -= B01[0-3][7] + + ///implement TRSM/// + + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[2][4] B11[2][5] + ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[2][6] B11[2][7] + + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3] + ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3] + + ////unpackhigh//// + ymm0 = 
_mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[1][4] B11[1][5] B11[3][4] B11[3][5] + ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[1][6] B11[1][7] B11[3][6] B11[3][7] + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3] + ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3] + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); + + //perform mul operation + ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); + + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6*cs_a + 7*rs_a)); + ymm3 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 7*rs_a)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 7*rs_a)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 7*rs_a)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 7*rs_a)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 7*rs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 7*rs_a)); + + //(ROw7): FMA operations + ymm14 = _mm256_fnmadd_pd(ymm2, ymm15, ymm14); + ymm13 = _mm256_fnmadd_pd(ymm3, ymm15, ymm13); + ymm12 = _mm256_fnmadd_pd(ymm4, ymm15, ymm12); + ymm11 = _mm256_fnmadd_pd(ymm5, ymm15, ymm11); + ymm10 = _mm256_fnmadd_pd(ymm6, ymm15, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm15, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm15, ymm8); + + //perform mul operation + ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + + ymm3 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 6*rs_a)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 6*rs_a)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 6*rs_a)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 6*rs_a)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 6*rs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 6*rs_a)); + + //(ROw6): FMA operations + ymm13 = _mm256_fnmadd_pd(ymm3, ymm14, ymm13); + ymm12 = _mm256_fnmadd_pd(ymm4, ymm14, ymm12); + ymm11 = _mm256_fnmadd_pd(ymm5, ymm14, ymm11); + ymm10 = _mm256_fnmadd_pd(ymm6, ymm14, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm14, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm14, ymm8); + + //perform mul operation + ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + + ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 5*rs_a)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 5*rs_a)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 5*rs_a)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 5*rs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); + + //(ROw5): FMA operations + ymm12 = _mm256_fnmadd_pd(ymm4, ymm13, ymm12); + ymm11 = _mm256_fnmadd_pd(ymm5, ymm13, ymm11); + ymm10 = _mm256_fnmadd_pd(ymm6, ymm13, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm13, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm13, ymm8); + + //perform mul operation + ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); + + 
//extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 4*rs_a)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 4*rs_a)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 4*rs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); + + //(ROw4): FMA operations + ymm11 = _mm256_fnmadd_pd(ymm5, ymm12, ymm11); + ymm10 = _mm256_fnmadd_pd(ymm6, ymm12, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm12, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm12, ymm8); + + //perform mul operation + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3*rs_a)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3*rs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + + //(ROw3): FMA operations + ymm10 = _mm256_fnmadd_pd(ymm6, ymm11, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm11, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm11, ymm8); + + //perform mul operation + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2*rs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + + //(ROw2): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm7, ymm10, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm10, ymm8); + + //perform mul operation + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); + + //(ROw2): FMA operations + ymm8 = _mm256_fnmadd_pd(ymm16, ymm9, ymm8); + + //perform mul operation + ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] + ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] + ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[4][1] B11[5][1] B11[4][3] B11[5][3] + ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[6][1] B11[7][1] B11[6][3] B11[7][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] + ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] + + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store 
B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); //store B11[4][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); //store B11[5][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm6); //store B11[6][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3 + 4), ymm7); //store B11[7][0-3] + n_remainder -=4; + } + + if(n_remainder) //implementation fo remaining columns(when 'N' is not a multiple of d_nr)() n = 3 + { + a10 = D_A_pack; + a11 = L + (i*cs_a) + (i*rs_a); + b01 = B + i + d_mr; + b11 = B + i; + + k_iter = (m - i - d_mr) ; + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + if(3 == n_remainder) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_8mx3n(a10,b01,cs_b,p_lda,k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] + ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 4)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] + ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] * alpha -= B01[0-3][5] + ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); //B11[0-3][6] * alpha -= B01[0-3][6] + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + } + else if(2 == n_remainder) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_8mx2n(a10,b01,cs_b,p_lda,k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] + ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] * alpha -= B01[0-3][5] + ymm6 = _mm256_broadcast_sd((double const *)(&ones)); + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + } + else if(1 == n_remainder) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_8mx1n(a10,b01,cs_b,p_lda,k_iter) + + ymm16 = 
_mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_broadcast_sd((double const *)(&ones)); + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] + ymm5 = _mm256_broadcast_sd((double const *)(&ones)); + ymm6 = _mm256_broadcast_sd((double const *)(&ones)); + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + } + ///implement TRSM/// + + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[2][4] B11[2][5] + ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[2][6] B11[2][7] + + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3] + ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[1][4] B11[1][5] B11[3][4] B11[3][5] + ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[1][6] B11[1][7] B11[3][6] B11[3][7] + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3] + ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3] + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); + + //perform mul operation + ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); + + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6*cs_a + 7*rs_a)); + ymm3 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 7*rs_a)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 7*rs_a)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 7*rs_a)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 7*rs_a)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 7*rs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 7*rs_a)); + + //(ROw7): FMA operations + ymm14 = _mm256_fnmadd_pd(ymm2, ymm15, ymm14); + ymm13 = _mm256_fnmadd_pd(ymm3, ymm15, ymm13); + ymm12 = _mm256_fnmadd_pd(ymm4, ymm15, ymm12); + ymm11 = _mm256_fnmadd_pd(ymm5, ymm15, ymm11); + ymm10 = _mm256_fnmadd_pd(ymm6, ymm15, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm15, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm15, ymm8); + + //perform mul operation + ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); + + //extract a11 + ymm1 = 
_mm256_broadcast_sd((double const *)(d11_pack + 5)); + + ymm3 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 6*rs_a)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 6*rs_a)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 6*rs_a)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 6*rs_a)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 6*rs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 6*rs_a)); + + //(ROw6): FMA operations + ymm13 = _mm256_fnmadd_pd(ymm3, ymm14, ymm13); + ymm12 = _mm256_fnmadd_pd(ymm4, ymm14, ymm12); + ymm11 = _mm256_fnmadd_pd(ymm5, ymm14, ymm11); + ymm10 = _mm256_fnmadd_pd(ymm6, ymm14, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm14, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm14, ymm8); + + //perform mul operation + ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + + ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 5*rs_a)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 5*rs_a)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 5*rs_a)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 5*rs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); + + //(ROw5): FMA operations + ymm12 = _mm256_fnmadd_pd(ymm4, ymm13, ymm12); + ymm11 = _mm256_fnmadd_pd(ymm5, ymm13, ymm11); + ymm10 = _mm256_fnmadd_pd(ymm6, ymm13, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm13, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm13, ymm8); + + //perform mul operation + ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 4*rs_a)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 4*rs_a)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 4*rs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); + + //(ROw4): FMA operations + ymm11 = _mm256_fnmadd_pd(ymm5, ymm12, ymm11); + ymm10 = _mm256_fnmadd_pd(ymm6, ymm12, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm12, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm12, ymm8); + + //perform mul operation + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3*rs_a)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3*rs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + + //(ROw3): FMA operations + ymm10 = _mm256_fnmadd_pd(ymm6, ymm11, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm11, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm11, ymm8); + + //perform mul operation + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2*rs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + + //(ROw2): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm7, ymm10, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm10, ymm8); + + //perform mul operation + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); + + //(ROw2): FMA operations + ymm8 = _mm256_fnmadd_pd(ymm16, ymm9, ymm8); + + //perform mul operation + ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); + + //unpacklow// + ymm1 = 
_mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] + ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] + ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[4][1] B11[5][1] B11[4][3] B11[5][3] + ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[6][1] B11[7][1] B11[6][3] B11[7][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] + ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] + + if(3 == n_remainder) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); //store B11[4][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); //store B11[5][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm6); //store B11[6][0-3] + } + else if(2 == n_remainder) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); //store B11[4][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); //store B11[5][0-3] + } + else if(1 == n_remainder) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); //store B11[4][0-3] + } + } + }// End of multiples of d_mr blocks in m-dimension + + // Repetative A blocks will be 4*4 + dim_t m_remainder = i + d_mr; + if(m_remainder >= 4) + { + i = m_remainder - 4; + a10 = L + (i*cs_a) + (i + 4)*rs_a; //pointer to block of A to be used for GEMM + a11 = L + (i*cs_a) + (i*rs_a); //pointer to block of A to be used for TRSM + + // Do transpose for a10 & store in D_A_pack + double *ptr_a10_dup = D_A_pack; + dim_t p_lda = 4; // packed leading dimension + if(transa) + { + for(dim_t x =0;x < m-i+4;x+=p_lda) + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + cs_a)); + ymm2 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); + ymm3 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + + ymm7 = 
_mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + + a10 += p_lda; + ptr_a10_dup += p_lda*p_lda; + } + } + else + { + for(dim_t x =0;x < m-i-4;x++) + { + ymm0 = _mm256_loadu_pd((double const *)(a10 + x*rs_a)); + _mm256_storeu_pd((double *)(ptr_a10_dup + x*p_lda), ymm0); + } + } + + ymm4 = _mm256_broadcast_sd((double const *)&ones); + if(!is_unitdiag) + { + if(transa) + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11+cs_a*1 + 1)); + ymm2 = _mm256_broadcast_sd((double const *)(a11+cs_a*2 + 2)); + ymm3 = _mm256_broadcast_sd((double const *)(a11+cs_a*3 + 3)); + } + else + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11+rs_a*1 + 1)); + ymm2 = _mm256_broadcast_sd((double const *)(a11+rs_a*2 + 2)); + ymm3 = _mm256_broadcast_sd((double const *)(a11+rs_a*3 + 3)); + } + + ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); + #ifdef BLIS_DISABLE_TRSM_PREINVERSION + ymm4 = ymm1; + #endif + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + ymm4 = _mm256_div_pd(ymm4, ymm1); + #endif + } + _mm256_storeu_pd((double *)(d11_pack), ymm4); + + //cols + for(j = (n - d_nr); (j + 1) > 0; j -= d_nr) //loop along 'N' dimension + { + a10 = D_A_pack; + a11 = L + (i*cs_a) + (i*rs_a); //pointer to block of A to be used for TRSM + b01 = B + (j*cs_b) + i + 4; //pointer to block of B to be used for GEMM + b11 = B + (j* cs_b) + i; //pointer to block of B to be used for TRSM + + k_iter = (m - i - 4); //number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx6n(a10,b01,cs_b,p_lda,k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + ymm0 = _mm256_loadu_pd((double const 
*)(b11 + cs_b *4)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); + + ymm16 = _mm256_broadcast_sd((double const *)(&ones)); + + ////unpacklow//// + ymm7 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + + //rearrange low elements + ymm4 = _mm256_permute2f128_pd(ymm7,ymm16,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm6 = _mm256_permute2f128_pd(ymm7,ymm16,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + + //rearrange high elements + ymm5 = _mm256_permute2f128_pd(ymm0,ymm16,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm7 = _mm256_permute2f128_pd(ymm0,ymm16,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //perform mul operation + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm1); + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw3): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3*rs_a)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm11, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, ymm7, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm11, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm7, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm11, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm7, ymm4); + + //perform mul operation + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm10, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm6, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm10, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm6, ymm4); + + //perform mul operation + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm1); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm9, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm5, ymm4); + + //perform mul operation + ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); + ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm1); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + _mm256_storeu_pd((double *)(b11 + cs_b 
* 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + + ///unpack high/// + ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + + _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store B11[1][0-3] + } + dim_t n_remainder = j + d_nr; + if((n_remainder >= 4)) + { + a10 = D_A_pack; + a11 = L + (i*cs_a) + (i*rs_a); //pointer to block of A to be used for TRSM + b01 = B + ((n_remainder - 4)* cs_b) + i + 4; //pointer to block of B to be used for GEMM + b11 = B + ((n_remainder - 4)* cs_b) + i; //pointer to block of B to be used for TRSM + + k_iter = (m - i - 4); //number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx4n(a10,b01,cs_b,p_lda,k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //perform mul operation + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw3): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3*rs_a)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm11, ymm10); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm11, ymm9); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + ymm8 = 
_mm256_fnmadd_pd(ymm2, ymm11, ymm8); + + //perform mul operation + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm10, ymm9); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm10, ymm8); + + //perform mul operation + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm9, ymm8); + + //perform mul operation + ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] + n_remainder = n_remainder - 4; + } + + if(n_remainder) //implementation fo remaining columns(when 'N' is not a multiple of d_nr)() n = 3 + { + a10 = D_A_pack; + a11 = L + (i*cs_a) + (i*rs_a); + b01 = B + i + 4; + b11 = B + i; + + k_iter = (m - i - 4); + + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + + if(3 == n_remainder) + { + BLIS_DTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + } + else if(2 == n_remainder) + { + BLIS_DTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b,p_lda,k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); 
//B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + } + else if(1 == n_remainder) + { + BLIS_DTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b,p_lda,k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_broadcast_sd((double const *)(&ones)); + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + } + + ///implement TRSM/// + + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //perform mul operation + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3*rs_a)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3*rs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + + //(ROw3): FMA operations + ymm10 = _mm256_fnmadd_pd(ymm6, ymm11, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm11, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm11, ymm8); + + //perform mul operation + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2*rs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + + //(ROw2): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm7, ymm10, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm10, ymm8); + + //perform mul operation + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); + + //(ROw2): FMA operations + ymm8 = _mm256_fnmadd_pd(ymm16, ymm9, ymm8); + + //perform mul operation + ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + //rearrange high elements + ymm1 
= _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + if(3 == n_remainder) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + } + else if(2 == n_remainder) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + } + else if(1 == n_remainder) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + } + } + m_remainder -= 4; + } + + a10 = L + m_remainder*rs_a; + + // Do transpose for a10 & store in D_A_pack + double *ptr_a10_dup = D_A_pack; + if(3 == m_remainder) // Repetative A blocks will be 3*3 + { + dim_t p_lda = 4; // packed leading dimension + if(transa) + { + for(dim_t x =0;x < m-m_remainder;x+=p_lda) + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + cs_a)); + ymm2 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); + ymm3 = _mm256_broadcast_sd((double const *)&ones); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + + a10 += p_lda; + ptr_a10_dup += p_lda*p_lda; + } + } + else + { + for(dim_t x =0;x < m-m_remainder;x++) + { + ymm0 = _mm256_loadu_pd((double const *)(a10 + x*rs_a)); + _mm256_storeu_pd((double *)(ptr_a10_dup + x*p_lda), ymm0); + } + } + + //cols + for(j = (n - d_nr); (j + 1) > 0; j -= d_nr) //loop along 'N' dimension + { + a10 = D_A_pack; + a11 = L; //pointer to block of A to be used for TRSM + b01 = B + (j* cs_b) + m_remainder; //pointer to block of B to be used for GEMM + b11 = B + (j* cs_b); //pointer to block of B to be used for TRSM + + k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx6n(a10,b01,cs_b,p_lda,k_iter) + + ///GEMM code ends/// + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to store alpha value + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); + ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x08); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + 
_mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); + + _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store(B11[0-3][3]) + + if(transa) + dtrsm_AltXB_ref(a11, b11, m_remainder, 6, cs_a, cs_b, is_unitdiag); + else + dtrsm_AuXB_ref(a11, b11, m_remainder, 6, rs_a, cs_b, is_unitdiag); + } + + dim_t n_remainder = j + d_nr; + if((n_remainder >= 4)) + { + a10 = D_A_pack; + a11 = L; //pointer to block of A to be used for TRSM + b01 = B + ((n_remainder - 4)* cs_b) + m_remainder; //pointer to block of B to be used for GEMM + b11 = B + ((n_remainder - 4)* cs_b); //pointer to block of B to be used for TRSM + + k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx4n(a10,b01,cs_b,p_lda,k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); + ymm3 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3 + 2)); + ymm3 = _mm256_insertf128_pd(ymm3, xmm5, 0); + + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); + ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x08); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + xmm5 = _mm256_extractf128_pd(ymm3, 0); + _mm_storeu_pd((double *)(b11 + cs_b * 3),xmm5); + _mm_storel_pd((b11 + cs_b * 3 + 2), _mm256_extractf128_pd(ymm3, 1)); + + if(transa) + dtrsm_AltXB_ref(a11, b11, m_remainder, 4, cs_a, cs_b, is_unitdiag); + else + dtrsm_AuXB_ref(a11, b11, m_remainder, 4, rs_a, cs_b, is_unitdiag); + n_remainder -= 4; + } + + if(n_remainder) + { + a10 = D_A_pack; + a11 = L; //pointer to block of A to be used for TRSM + b01 = B + m_remainder; //pointer to block of B to be used for GEMM + b11 = B; //pointer to block of B to be used for TRSM + + k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + if(3 == n_remainder) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) + + BLIS_PRE_DTRSM_SMALL_3M_3N(AlphaVal,b11,cs_b) + + if(transa) + dtrsm_AltXB_ref(a11, b11, m_remainder, 3, cs_a, cs_b, is_unitdiag); + else + dtrsm_AuXB_ref(a11, b11, m_remainder, 3, rs_a, cs_b, is_unitdiag); + } + else if(2 == n_remainder) + { + ///GEMM code begins/// + 
BLIS_DTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b,p_lda,k_iter) + + BLIS_PRE_DTRSM_SMALL_3M_2N(AlphaVal,b11,cs_b) + + if(transa) + dtrsm_AltXB_ref(a11, b11, m_remainder, 2, cs_a, cs_b, is_unitdiag); + else + dtrsm_AuXB_ref(a11, b11, m_remainder, 2, rs_a, cs_b, is_unitdiag); + } + else if(1 == n_remainder) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b,p_lda,k_iter) + + BLIS_PRE_DTRSM_SMALL_3M_1N(AlphaVal,b11,cs_b) + + if(transa) + dtrsm_AltXB_ref(a11, b11, m_remainder, 1, cs_a, cs_b, is_unitdiag); + else + dtrsm_AuXB_ref(a11, b11, m_remainder, 1, rs_a, cs_b, is_unitdiag); + } + } + } + else if(2 == m_remainder) // Repetative A blocks will be 2*2 + { + dim_t p_lda = 4; // packed leading dimension + if(transa) + { + for(dim_t x =0;x < m-m_remainder;x+=p_lda) + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + cs_a)); + ymm2 = _mm256_broadcast_sd((double const *)&ones); + ymm3 = _mm256_broadcast_sd((double const *)&ones); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + + a10 += p_lda; + ptr_a10_dup += p_lda*p_lda; + } + } + else + { + for(dim_t x =0;x < m-m_remainder;x++) + { + ymm0 = _mm256_loadu_pd((double const *)(a10 + x*rs_a)); + _mm256_storeu_pd((double *)(ptr_a10_dup + x*p_lda), ymm0); + } + } + //cols + for(j = (n - d_nr); (j + 1) > 0; j -= d_nr) //loop along 'N' dimension + { + a10 = D_A_pack; + a11 = L; //pointer to block of A to be used for TRSM + b01 = B + (j* cs_b) + m_remainder; //pointer to block of B to be used for GEMM + b11 = B + (j* cs_b); //pointer to block of B to be used for TRSM + + k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx6n(a10,b01,cs_b,p_lda,k_iter) + + ///GEMM code ends/// + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to store alpha value + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0C); + ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0C); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); + ymm8 = 
_mm256_fmsub_pd(ymm0, ymm16, ymm4); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); + + _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store(B11[0-3][3]) + + if(transa) + dtrsm_AltXB_ref(a11, b11, m_remainder, 6, cs_a, cs_b, is_unitdiag); + else + dtrsm_AuXB_ref(a11, b11, m_remainder, 6, rs_a, cs_b, is_unitdiag); + } + dim_t n_remainder = j + d_nr; + if((n_remainder >= 4)) + { + a10 = D_A_pack; + a11 = L; //pointer to block of A to be used for TRSM + b01 = B + ((n_remainder - 4)* cs_b) + m_remainder; //pointer to block of B to be used for GEMM + b11 = B + ((n_remainder - 4)* cs_b); //pointer to block of B to be used for TRSM + + k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx4n(a10,b01,cs_b,p_lda,k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); + ymm3 = _mm256_insertf128_pd(ymm3, xmm5, 0); + + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0C); + ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0C); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + xmm5 = _mm256_extractf128_pd(ymm3, 0); + _mm_storeu_pd((double *)(b11 + cs_b * 3), xmm5); + + if(transa) + dtrsm_AltXB_ref(a11, b11, m_remainder, 4, cs_a, cs_b, is_unitdiag); + else + dtrsm_AuXB_ref(a11, b11, m_remainder, 4, rs_a, cs_b, is_unitdiag); + n_remainder -= 4; + } + if(n_remainder) + { + a10 = D_A_pack; + a11 = L; //pointer to block of A to be used for TRSM + b01 = B + m_remainder; //pointer to block of B to be used for GEMM + b11 = B; //pointer to block of B to be used for TRSM + + k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + if(3 == n_remainder) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) + + BLIS_PRE_DTRSM_SMALL_2M_3N(AlphaVal,b11,cs_b) + + if(transa) + dtrsm_AltXB_ref(a11, b11, m_remainder, 3, cs_a, cs_b, is_unitdiag); + else + dtrsm_AuXB_ref(a11, b11, m_remainder, 3, rs_a, cs_b, is_unitdiag); + } + else if(2 == n_remainder) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b,p_lda,k_iter) + + BLIS_PRE_DTRSM_SMALL_2M_2N(AlphaVal,b11,cs_b) + + if(transa) + dtrsm_AltXB_ref(a11, b11, m_remainder, 2, cs_a, cs_b, is_unitdiag); + else + dtrsm_AuXB_ref(a11, b11, m_remainder, 2, rs_a, cs_b, is_unitdiag); + } + else if(1 == n_remainder) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b,p_lda,k_iter) + + BLIS_PRE_DTRSM_SMALL_2M_1N(AlphaVal,b11,cs_b) + if(transa) + 
dtrsm_AltXB_ref(a11, b11, m_remainder, 1, cs_a, cs_b, is_unitdiag); + else + dtrsm_AuXB_ref(a11, b11, m_remainder, 1, rs_a, cs_b, is_unitdiag); + } + } + + } + else if(1 == m_remainder) // Repetative A blocks will be 1*1 + { + dim_t p_lda = 4; // packed leading dimension + if(transa) + { + for(dim_t x =0;x < m-m_remainder;x+=p_lda) + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_broadcast_sd((double const *)&ones); + ymm2 = _mm256_broadcast_sd((double const *)&ones); + ymm3 = _mm256_broadcast_sd((double const *)&ones); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + + a10 += p_lda; + ptr_a10_dup += p_lda*p_lda; + } + } + else + { + for(dim_t x =0;x < m-m_remainder;x++) + { + ymm0 = _mm256_loadu_pd((double const *)(a10 + x*rs_a)); + _mm256_storeu_pd((double *)(ptr_a10_dup + x*p_lda), ymm0); + } + } + //cols + for(j = (n - d_nr); (j + 1) > 0; j -= d_nr) //loop along 'N' dimension + { + a10 = D_A_pack; + a11 = L; //pointer to block of A to be used for TRSM + b01 = B + (j* cs_b) + m_remainder; //pointer to block of B to be used for GEMM + b11 = B + (j* cs_b); //pointer to block of B to be used for TRSM + + k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx6n(a10,b01,cs_b,p_lda,k_iter) + + ///GEMM code ends/// + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to store alpha value + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); + ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0E); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); + + _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store(B11[0-3][3]) + + if(transa) + dtrsm_AltXB_ref(a11, b11, m_remainder, 6, cs_a, cs_b, is_unitdiag); + else + dtrsm_AuXB_ref(a11, b11, m_remainder, 6, 
rs_a, cs_b, is_unitdiag); + } + dim_t n_remainder = j + d_nr; + if((n_remainder >= 4)) + { + a10 = D_A_pack; + a11 = L; //pointer to block of A to be used for TRSM + b01 = B + ((n_remainder - 4)* cs_b) + m_remainder; //pointer to block of B to be used for GEMM + b11 = B + ((n_remainder - 4)* cs_b); //pointer to block of B to be used for TRSM + + k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx4n(a10,b01,cs_b,p_lda,k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); + ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0E); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) + + if(transa) + dtrsm_AltXB_ref(a11, b11, m_remainder, 4, cs_a, cs_b, is_unitdiag); + else + dtrsm_AuXB_ref(a11, b11, m_remainder, 4, rs_a, cs_b, is_unitdiag); + n_remainder -= 4; + } + if(n_remainder) + { + a10 = D_A_pack; + a11 = L; //pointer to block of A to be used for TRSM + b01 = B + m_remainder; //pointer to block of B to be used for GEMM + b11 = B; //pointer to block of B to be used for TRSM + + k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + if(3 == n_remainder) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) + + BLIS_PRE_DTRSM_SMALL_1M_3N(AlphaVal,b11,cs_b) + + if(transa) + dtrsm_AltXB_ref(a11, b11, m_remainder, 3, cs_a, cs_b, is_unitdiag); + else + dtrsm_AuXB_ref(a11, b11, m_remainder, 3, rs_a, cs_b, is_unitdiag); + } + else if(2 == n_remainder) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b,p_lda,k_iter) + + BLIS_PRE_DTRSM_SMALL_1M_2N(AlphaVal,b11,cs_b) + + if(transa) + dtrsm_AltXB_ref(a11, b11, m_remainder, 2, cs_a, cs_b, is_unitdiag); + else + dtrsm_AuXB_ref(a11, b11, m_remainder, 2, rs_a, cs_b, is_unitdiag); + } + else if(1 == n_remainder) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b,p_lda,k_iter) + + BLIS_PRE_DTRSM_SMALL_1M_1N(AlphaVal,b11,cs_b) + + if(transa) + dtrsm_AltXB_ref(a11, b11, m_remainder, 1, cs_a, cs_b, is_unitdiag); + else + dtrsm_AuXB_ref(a11, b11, m_remainder, 1, rs_a, cs_b, is_unitdiag); + } + } + } + + if ((required_packing_A == 1) && + bli_mem_is_alloc( &local_mem_buf_A_s )) + { + bli_membrk_release(&rntm,&local_mem_buf_A_s); + } + return BLIS_SUCCESS; +} + +/* TRSM for the Left Upper case AX = alpha * B, Double precision + * A is Left side, upper-triangular, transpose, non-unit/unit diagonal + * dimensions A: mxm X: mxn B: mxn + a10 ----> b11---> + 
*********** ***************** + * * * * *b01*b11* * * + **a10 * * a11 b11 * * * * * + ********* | | ***************** + *a11* * | | * * * * * + * * * | | * * * * * + ****** v v ***************** + * * * * * * * + * * * * * * * + * * ***************** + * + a11---> + + * TRSM for the case AX = alpha * B, Double precision + * A is Left side, lower-triangular, no-transpose, non-unit/unit diagonal + * dimensions A: mxm X: mxn B: mxn + + b01---> + * ***************** + ** * * * * * + * * * * * * * + * * *b01* * * * + * * * * * * * +a10 ****** b11 ***************** + | * * * | * * * * * + | * * * | * * * * * + | *a10*a11* | *b11* * * * + v * * * v * * * * * + *********** ***************** + * * * * * * * * * + * * * * * * * * * + * * * * * * * * * + * * * * * * * * * + **************** ***************** + a11---> +*/ +BLIS_INLINE err_t bli_dtrsm_small_AutXB_AlXB +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +) +{ + dim_t m = bli_obj_length(b); // number of rows of matrix B + dim_t n = bli_obj_width(b); // number of columns of matrix B + + bool transa = bli_obj_has_trans(a); + dim_t cs_a, rs_a; + dim_t d_mr = 8,d_nr = 6; + + // Swap rs_a & cs_a in case of non-tranpose. + if(transa) + { + cs_a = bli_obj_col_stride(a); // column stride of A + rs_a = bli_obj_row_stride(a); // row stride of A + } + else + { + cs_a = bli_obj_row_stride(a); // row stride of A + rs_a = bli_obj_col_stride(a); // column stride of A + } + dim_t cs_b = bli_obj_col_stride(b); // column stride of B + + dim_t i, j, k; //loop variables + dim_t k_iter; //number of times GEMM to be performed + + double AlphaVal = *(double *)AlphaObj->buffer; //value of alpha + double *L = a->buffer; //pointer to matrix A + double *B = b->buffer; //pointer to matrix B + + double *a10, *a11, *b01, *b11; //pointers that point to blocks for GEMM and TRSM + + double ones = 1.0; + bool is_unitdiag = bli_obj_has_unit_diag(a); + + //scratch registers + __m256d ymm0, ymm1, ymm2, ymm3; + __m256d ymm4, ymm5, ymm6, ymm7; + __m256d ymm8, ymm9, ymm10, ymm11; + __m256d ymm12, ymm13, ymm14, ymm15; + __m256d ymm16, ymm17, ymm18, ymm19; + __m256d ymm20; + + __m128d xmm5; + + gint_t required_packing_A = 1; + mem_t local_mem_buf_A_s = {0}; + double *D_A_pack = NULL; + double d11_pack[d_mr] __attribute__((aligned(64))); + rntm_t rntm; + + bli_rntm_init_from_global( &rntm ); + bli_rntm_set_num_threads_only( 1, &rntm ); + bli_membrk_rntm_set_membrk( &rntm ); + + siz_t buffer_size = bli_pool_block_size( + bli_membrk_pool( + bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), + bli_rntm_membrk(&rntm))); + + if ( (d_mr * m * sizeof(double)) > buffer_size) + return BLIS_NOT_YET_IMPLEMENTED; + + if (required_packing_A == 1) + { + // Get the buffer from the pool. + bli_membrk_acquire_m(&rntm, + buffer_size, + BLIS_BITVAL_BUFFER_FOR_A_BLOCK, + &local_mem_buf_A_s); + if(FALSE==bli_mem_is_alloc(&local_mem_buf_A_s)) return BLIS_NULL_POINTER; + D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); + if(NULL==D_A_pack) return BLIS_NULL_POINTER; + } + + /* + Performs solving TRSM for 8 colmns at a time from 0 to m/8 in steps of d_mr + a. Load, transpose, Pack A (a10 block), the size of packing 8x6 to 8x (m-8) + First there will be no GEMM and no packing of a10 because it is only TRSM + b. Using packed a10 block and b01 block perform GEMM operation + c. Use GEMM outputs, perform TRSM operaton using a11, b11 and update B + d. 
Repeat steps b and c for the n columns of B in steps of d_nr
+    */
+    for(i = 0;(i+d_mr-1) < m; i += d_mr)     //loop along 'M' dimension
+    {
+        a10 = L + (i*cs_a);                  //pointer to block of A to be used for GEMM
+        a11 = L + (i*rs_a) + (i*cs_a);
+        dim_t p_lda = d_mr; // packed leading dimension
+
+        if(transa)
+        {
+            /*
+              Load, transpose and pack the current A block (a10) into the packed buffer D_A_pack
+              a. This a10 block is used only in the GEMM portion, and its size
+                 grows by d_mr on every iteration until it reaches 8x(m-8),
+                 which is the largest GEMM-only block size in A
+              b. This packed buffer is reused to calculate all n columns of the B matrix
+            */
+            bli_dtrsm_small_pack('L', i, 1, a10, cs_a, D_A_pack, p_lda,d_mr);
+
+            /*
+               Pack 8 diagonal elements of the A block into an array
+               a. This helps utilize cache lines efficiently in the TRSM operation
+               b. Store ones when the input is unit diagonal
+            */
+            dtrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,d_mr);
+        }
+        else
+        {
+            bli_dtrsm_small_pack('L', i, 0, a10, rs_a, D_A_pack, p_lda,d_mr);
+            dtrsm_small_pack_diag_element(is_unitdiag,a11,rs_a,d11_pack,d_mr);
+        }
+
+        /*
+            a. Perform GEMM using a10, b01.
+            b. Perform TRSM on a11, b11
+            c. This GEMM+TRSM loop operates with an 8x6 block size along the
+               n dimension, for every d_nr columns of b01; the packed A buffer
+               is reused while computing all n columns of B.
+            d. The same approach is used in the remaining fringe cases.
+        */
+        dim_t temp = n - d_nr + 1;
+        for(j = 0; j < temp; j += d_nr)     //loop along 'N' dimension
+        {
+            a10 = D_A_pack;
+            a11 = L + (i*rs_a) + (i*cs_a);     //pointer to block of A to be used for TRSM
+            b01 = B + j*cs_b;                  //pointer to block of B to be used for GEMM
+            b11 = B + i + j* cs_b;             //pointer to block of B to be used for TRSM
+
+            k_iter = i;
+
+            /*Fill zeros into ymm registers used in gemm accumulations */
+            BLIS_SET_YMM_REG_ZEROS
+
+            /*
+               Perform GEMM between the a10 and b01 blocks.
+               For the first iteration there is no GEMM operation,
+               since k_iter is zero.
+            */
+            BLIS_DTRSM_SMALL_GEMM_8mx6n(a10,b01,cs_b,p_lda,k_iter)
+
+            /*
+               Load b11 of size 6x8, multiply it with alpha and subtract the
+               GEMM output, then perform an in-register transpose of b11
+               to perform the TRSM operation.
+            */
+            BLIS_DTRSM_SMALL_NREG_TRANSPOSE_6x8(b11,cs_b,AlphaVal)
+
+            /*
+              Compute the 8x6 TRSM block using the GEMM block output held in registers
+              a. The 8x6 input (gemm outputs) is stored in combinations of ymm registers
+                 1. ymm8, ymm4 2. ymm9, ymm5 3. ymm10, ymm6, 4. ymm11, ymm7
+                 5. ymm12, ymm17 6. ymm13,ymm18, 7. ymm14,ymm19 8. ymm15, ymm20
+                 where ymm8-ymm15 hold the 8x4 data and the remaining 8x2 are
+                 held by the other registers
+              b.
Towards the end do in regiser transpose of TRSM output and store in b11 + */ + ////extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //perform mul operation + ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); + ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(ROw1): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*1)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm4, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm8, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, ymm4, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, ymm4, ymm7); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm8, ymm12); + ymm17 = _mm256_fnmadd_pd(ymm2, ymm4, ymm17); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm8, ymm13); + ymm18 = _mm256_fnmadd_pd(ymm2, ymm4, ymm18); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm8, ymm14); + ymm19 = _mm256_fnmadd_pd(ymm2, ymm4, ymm19); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm8, ymm15); + ymm20 = _mm256_fnmadd_pd(ymm2, ymm4, ymm20); + + + //perform mul operation + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm1); + + a11 += rs_a; + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm9, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, ymm5, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm9, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, ymm5, ymm7); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm9, ymm12); + ymm17 = _mm256_fnmadd_pd(ymm2, ymm5, ymm17); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm9, ymm13); + ymm18 = _mm256_fnmadd_pd(ymm2, ymm5, ymm18); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm9, ymm14); + ymm19 = _mm256_fnmadd_pd(ymm2, ymm5, ymm19); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm9, ymm15); + ymm20 = _mm256_fnmadd_pd(ymm2, ymm5, ymm20); + + //perform mul operation + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm1); + + a11 += rs_a; + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(ROw5): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm10, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, ymm6, ymm7); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm10, ymm12); + ymm17 = _mm256_fnmadd_pd(ymm2, ymm6, ymm17); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm10, ymm13); + ymm18 = _mm256_fnmadd_pd(ymm2, ymm6, ymm18); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm10, ymm14); + ymm19 = _mm256_fnmadd_pd(ymm2, ymm6, ymm19); + ymm2 = 
_mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm10, ymm15); + ymm20 = _mm256_fnmadd_pd(ymm2, ymm6, ymm20); + + //perform mul operation + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm1); + + a11 += rs_a; + + //extract a44 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + //(ROw4): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm11, ymm12); + ymm17 = _mm256_fnmadd_pd(ymm2, ymm7, ymm17); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm11, ymm13); + ymm18 = _mm256_fnmadd_pd(ymm2, ymm7, ymm18); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm11, ymm14); + ymm19 = _mm256_fnmadd_pd(ymm2, ymm7, ymm19); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm11, ymm15); + ymm20 = _mm256_fnmadd_pd(ymm2, ymm7, ymm20); + + //perform mul operation + ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); + ymm17 = DTRSM_SMALL_DIV_OR_SCALE(ymm17, ymm1); + + a11 += rs_a; + + //extract a55 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + + //(ROw5): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); + ymm18 = _mm256_fnmadd_pd(ymm2, ymm17, ymm18); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm12, ymm14); + ymm19 = _mm256_fnmadd_pd(ymm2, ymm17, ymm19); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm12, ymm15); + ymm20 = _mm256_fnmadd_pd(ymm2, ymm17, ymm20); + + //perform mul operation + ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1); + ymm18 = DTRSM_SMALL_DIV_OR_SCALE(ymm18, ymm1); + + a11 += rs_a; + + //extract a66 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); + + //(ROw6): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm13, ymm14); + ymm19 = _mm256_fnmadd_pd(ymm2, ymm18, ymm19); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm13, ymm15); + ymm20 = _mm256_fnmadd_pd(ymm2, ymm18, ymm20); + + //perform mul operation + ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); + ymm19 = DTRSM_SMALL_DIV_OR_SCALE(ymm19, ymm1); + + a11 += rs_a; + + //extract a77 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); + + //(ROw7): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm14, ymm15); + ymm20 = _mm256_fnmadd_pd(ymm2, ymm19, ymm20); + + //perform mul operation + ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); + ymm20 = DTRSM_SMALL_DIV_OR_SCALE(ymm20, ymm1); + + a11 += rs_a; + + BLIS_DTRSM_SMALL_NREG_TRANSPOSE_8x6_AND_STORE(b11,cs_b) + } + + dim_t n_rem = n-j; + if(n_rem >= 4) + { + a10 = D_A_pack; + a11 = L + (i*rs_a) + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + j*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + + k_iter = i ; //number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_8mx4n(a10,b01,cs_b,p_lda,k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + + ymm0 = 
_mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] + ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 4)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] + ymm7 = _mm256_loadu_pd((double const *)(b11 + cs_b *3 + 4)); //B11[0][7] B11[1][7] B11[2][7] B11[3][7] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] * alpha -= B01[0-3][3] + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] + ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] * alpha -= B01[0-3][5] + ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); //B11[0-3][6] * alpha -= B01[0-3][6] + ymm7 = _mm256_fmsub_pd(ymm7, ymm16, ymm15); //B11[0-3][7] * alpha -= B01[0-3][7] + + ///implement TRSM/// + + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[2][4] B11[2][5] + ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[2][6] B11[2][7] + + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3] + ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[1][4] B11[1][5] B11[3][4] B11[3][5] + ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[1][6] B11[1][7] B11[3][6] B11[3][7] + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3] + ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3] + + ymm0 = _mm256_broadcast_sd((double const *)&ones); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //perform mul operation + ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*1)); + ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + ymm5 = 
_mm256_broadcast_sd((double const *)(a11 + cs_a*4)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + + a11 += rs_a; + + //(ROw1): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); + ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); + ymm11 = _mm256_fnmadd_pd(ymm4, ymm8, ymm11); + ymm12 = _mm256_fnmadd_pd(ymm5, ymm8, ymm12); + ymm13 = _mm256_fnmadd_pd(ymm6, ymm8, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm7, ymm8, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm8, ymm15); + + //perform mul operation + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); + + ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + + a11 += rs_a; + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw2): FMA operations + ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); + ymm11 = _mm256_fnmadd_pd(ymm4, ymm9, ymm11); + ymm12 = _mm256_fnmadd_pd(ymm5, ymm9, ymm12); + ymm13 = _mm256_fnmadd_pd(ymm6, ymm9, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm7, ymm9, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm9, ymm15); + + //perform mul operation + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + + a11 += rs_a; + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(ROw5): FMA operations + ymm11 = _mm256_fnmadd_pd(ymm4, ymm10, ymm11); + ymm12 = _mm256_fnmadd_pd(ymm5, ymm10, ymm12); + ymm13 = _mm256_fnmadd_pd(ymm6, ymm10, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm7, ymm10, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm10, ymm15); + + //perform mul operation + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + + ymm0 = _mm256_broadcast_sd((double const *)&ones); + + //extract a44 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + + ymm5 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + + a11 += rs_a; + + //(ROw4): FMA operations + ymm12 = _mm256_fnmadd_pd(ymm5, ymm11, ymm12); + ymm13 = _mm256_fnmadd_pd(ymm6, ymm11, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm7, ymm11, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm11, ymm15); + + //perform mul operation + ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); + + ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + + a11 += rs_a; + + //extract a55 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + + //(ROw5): FMA operations + ymm13 = _mm256_fnmadd_pd(ymm6, ymm12, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm7, ymm12, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm12, ymm15); + + //perform mul operation + ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1); + + ymm7 
= _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 +cs_a*7)); + + a11 += rs_a; + + //extract a66 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); + + //(ROw6): FMA operations + ymm14 = _mm256_fnmadd_pd(ymm7, ymm13, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm13, ymm15); + + //perform mul operation + ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); + + //extract a77 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); + + ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + + a11 += rs_a; + //(ROw7): FMA operations + ymm15 = _mm256_fnmadd_pd(ymm16, ymm14, ymm15); + + //perform mul operation + ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] + ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] + ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[4][1] B11[5][1] B11[4][3] B11[5][3] + ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[6][1] B11[7][1] B11[6][3] B11[7][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] + ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] + + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); //store B11[4][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); //store B11[5][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm6); //store B11[6][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3 + 4), ymm7); //store B11[7][0-3] + + n_rem -=4; + j +=4; + } + + if(n_rem) + { + a10 = D_A_pack; + a11 = L + (i*rs_a) + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + j*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + + k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + if(3 == n_rem) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_8mx3n(a10,b01,cs_b,p_lda,k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + 
cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] + ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 4)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] + ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] * alpha -= B01[0-3][5] + ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); //B11[0-3][6] * alpha -= B01[0-3][6] + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + } + else if(2 == n_rem) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_8mx2n(a10,b01,cs_b,p_lda,k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] + ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] * alpha -= B01[0-3][5] + ymm6 = _mm256_broadcast_sd((double const *)(&ones)); + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + } + else if(1 == n_rem) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_8mx1n(a10,b01,cs_b,p_lda,k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_broadcast_sd((double const *)(&ones)); + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] + ymm5 = _mm256_broadcast_sd((double const *)(&ones)); + ymm6 = _mm256_broadcast_sd((double const *)(&ones)); + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + } + ///implement TRSM/// + + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] 
B11[0][5] B11[2][4] B11[2][5] + ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[2][6] B11[2][7] + + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3] + ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[1][4] B11[1][5] B11[3][4] B11[3][5] + ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[1][6] B11[1][7] B11[3][6] B11[3][7] + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3] + ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3] + + ymm0 = _mm256_broadcast_sd((double const *)&ones); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //perform mul operation + ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*1)); + ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + + a11 += rs_a; + + //(ROw1): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); + ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); + ymm11 = _mm256_fnmadd_pd(ymm4, ymm8, ymm11); + ymm12 = _mm256_fnmadd_pd(ymm5, ymm8, ymm12); + ymm13 = _mm256_fnmadd_pd(ymm6, ymm8, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm7, ymm8, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm8, ymm15); + + //perform mul operation + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); + + ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + + a11 += rs_a; + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw2): FMA operations + ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); + ymm11 = _mm256_fnmadd_pd(ymm4, ymm9, ymm11); + ymm12 = _mm256_fnmadd_pd(ymm5, ymm9, ymm12); + ymm13 = _mm256_fnmadd_pd(ymm6, ymm9, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm7, ymm9, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm9, ymm15); + + //perform mul operation + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + 
cs_a*6)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + + a11 += rs_a; + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(ROw5): FMA operations + ymm11 = _mm256_fnmadd_pd(ymm4, ymm10, ymm11); + ymm12 = _mm256_fnmadd_pd(ymm5, ymm10, ymm12); + ymm13 = _mm256_fnmadd_pd(ymm6, ymm10, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm7, ymm10, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm10, ymm15); + + //perform mul operation + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + + ymm0 = _mm256_broadcast_sd((double const *)&ones); + + //extract a44 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + + ymm5 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + + a11 += rs_a; + + //(ROw4): FMA operations + ymm12 = _mm256_fnmadd_pd(ymm5, ymm11, ymm12); + ymm13 = _mm256_fnmadd_pd(ymm6, ymm11, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm7, ymm11, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm11, ymm15); + + //perform mul operation + ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); + + ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + + a11 += rs_a; + + //extract a55 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + + + //(ROw5): FMA operations + ymm13 = _mm256_fnmadd_pd(ymm6, ymm12, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm7, ymm12, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm12, ymm15); + + //perform mul operation + ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1); + + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 +cs_a*7)); + + a11 += rs_a; + + //extract a66 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); + + + //(ROw6): FMA operations + ymm14 = _mm256_fnmadd_pd(ymm7, ymm13, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm13, ymm15); + + //perform mul operation + ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); + + //extract a77 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); + + ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + + a11 += rs_a; + //(ROw7): FMA operations + ymm15 = _mm256_fnmadd_pd(ymm16, ymm14, ymm15); + + //perform mul operation + ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] + ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] + ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[4][1] B11[5][1] B11[4][3] B11[5][3] + ymm13 = 
_mm256_unpackhi_pd(ymm14, ymm15); //B11[6][1] B11[7][1] B11[6][3] B11[7][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] + ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] + + if(3 == n_rem) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); //store B11[4][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); //store B11[5][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm6); //store B11[6][0-3] + } + else if(2 == n_rem) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); //store B11[4][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); //store B11[5][0-3] + } + else if(1 == n_rem) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); //store B11[4][0-3] + } + } + } + + //======================M remainder cases================================ + dim_t m_rem = m-i; + if(m_rem>=4) //implementation for reamainder rows(when 'M' is not a multiple of d_mr) + { + a10 = L + (i*cs_a); //pointer to block of A to be used for GEMM + a11 = L + (i*rs_a) + (i*cs_a); + double *ptr_a10_dup = D_A_pack; + dim_t p_lda = 4; // packed leading dimension + + if(transa) + { + for(dim_t x =0;x < i;x+=p_lda) + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + cs_a)); + ymm2 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); + ymm3 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + + a10 += p_lda; + ptr_a10_dup += p_lda*p_lda; + } + } + else + { + for(dim_t x =0;x < i;x++) + { + ymm0 = _mm256_loadu_pd((double const *)(a10 + rs_a * x)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * x), ymm0); + } + } + + ymm4 = _mm256_broadcast_sd((double const *)&ones); + if(!is_unitdiag) + { + if(transa) + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11+cs_a*1 + 1)); + ymm2 = _mm256_broadcast_sd((double const *)(a11+cs_a*2 + 2)); + ymm3 = _mm256_broadcast_sd((double const *)(a11+cs_a*3 + 3)); + } + else + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11+rs_a*1 + 1)); + ymm2 = 
_mm256_broadcast_sd((double const *)(a11+rs_a*2 + 2)); + ymm3 = _mm256_broadcast_sd((double const *)(a11+rs_a*3 + 3)); + } + + ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); + #ifdef BLIS_DISABLE_TRSM_PREINVERSION + ymm4 = ymm1; + #endif + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + ymm4 = _mm256_div_pd(ymm4, ymm1); + #endif + } + _mm256_storeu_pd((double *)(d11_pack), ymm4); + + for(j = 0; (j+d_nr-1) < n; j += d_nr) //loop along 'N' dimension + { + a10 = D_A_pack; //pointer to block of A to be used for GEMM + a11 = L + (i*rs_a) + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM + b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM + + k_iter = i; //number of times GEMM operation to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx6n(a10,b01,cs_b,p_lda,k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + + ///implement TRSM/// + ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] * alpha -= B01[0-3][3] + + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); + + + ymm16 = _mm256_broadcast_sd((double const *)(&ones)); + + ////unpacklow//// + ymm7 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + //ymm16; + + //rearrange low elements + ymm4 = _mm256_permute2f128_pd(ymm7,ymm16,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm6 = _mm256_permute2f128_pd(ymm7,ymm16,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + //ymm16; + + //rearrange high elements + ymm5 = _mm256_permute2f128_pd(ymm0,ymm16,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] 
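+			/*
+			   Reminder on the _mm256_permute2f128_pd selectors used here
+			   (illustrative): imm 0x20 builds { low 128 bits of src1,
+			   low 128 bits of src2 }, while imm 0x31 builds { high 128 bits
+			   of src1, high 128 bits of src2 }.  The permute above and the
+			   one below therefore split the odd rows (rows 1 and 3) of the
+			   columns-4/5 tile, with ymm16 (all ones) padding the unused lanes.
+			*/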
+ ymm7 = _mm256_permute2f128_pd(ymm0,ymm16,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + //b11 transpose end + + ////extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //perform mul operation + ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); + ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(ROw1): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*1)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm4, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm8, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, ymm4, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, ymm4, ymm7); + + + //perform mul operation + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm1); + + a11 += rs_a; + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm9, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, ymm5, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm9, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, ymm5, ymm7); + + //perform mul operation + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm1); + + a11 += rs_a; + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(ROw5): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm10, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, ymm6, ymm7); + + //perform mul operation + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm1); + + a11 += rs_a; + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + + ///unpack high/// + ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); 
//B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + + _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store B11[1][0-3] + } + + dim_t n_rem = n-j; + if(n_rem >= 4) + { + a10 = D_A_pack; + a11 = L + (i*rs_a) + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + j*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + + k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + BLIS_DTRSM_SMALL_GEMM_4mx4n(a10,b01,cs_b,p_lda,k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + ymm0 = _mm256_broadcast_sd((double const *)&ones); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //perform mul operation + ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*1)); + ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + + a11 += rs_a; + + //(ROw1): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); + ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); + ymm11 = _mm256_fnmadd_pd(ymm4, ymm8, ymm11); + + //perform mul operation + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); + + ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + + a11 += rs_a; + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw2): FMA operations + ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); + ymm11 = _mm256_fnmadd_pd(ymm4, ymm9, ymm11); + + //perform mul operation + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + + a11 += rs_a; + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(ROw5): FMA operations + ymm11 = _mm256_fnmadd_pd(ymm4, 
ymm10, ymm11); + + //perform mul operation + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] + + n_rem -= 4; + j += 4; + } + if(n_rem) + { + a10 = D_A_pack; + a11 = L + (i*rs_a) + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + j*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + + k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) + + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + + if(3 == n_rem) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + } + else if(2 == n_rem) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b,p_lda,k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + } + else if(1 == n_rem) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b,p_lda,k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_broadcast_sd((double const 
*)(&ones)); + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + } + + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + ymm0 = _mm256_broadcast_sd((double const *)&ones); + + ////extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //perform mul operation + ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*1)); + ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + + a11 += rs_a; + + //(ROw1): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); + ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); + ymm11 = _mm256_fnmadd_pd(ymm4, ymm8, ymm11); + + //perform mul operation + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); + + ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + + a11 += rs_a; + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw2): FMA operations + ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); + ymm11 = _mm256_fnmadd_pd(ymm4, ymm9, ymm11); + + //perform mul operation + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + + a11 += rs_a; + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(ROw5): FMA operations + ymm11 = _mm256_fnmadd_pd(ymm4, ymm10, ymm11); + + //perform mul operation + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + if(3 == n_rem) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + } + else if(2 == n_rem) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 
0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + } + else if(1 == n_rem) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + } + } + m_rem -=4; + i +=4; + } + + if(m_rem) + { + a10 = L + (i*cs_a); //pointer to block of A to be used for GEMM + // Do transpose for a10 & store in D_A_pack + double *ptr_a10_dup = D_A_pack; + if(3 == m_rem) // Repetative A blocks will be 3*3 + { + dim_t p_lda = 4; // packed leading dimension + if(transa) + { + for(dim_t x=0;x= 4)) + { + a10 = D_A_pack; //pointer to block of A to be used for GEMM + a11 = L + (i*rs_a) + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM + b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM + + k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx4n(a10,b01,cs_b,p_lda,k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); + ymm3 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3 + 2)); + ymm3 = _mm256_insertf128_pd(ymm3, xmm5, 0); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); + ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x08); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + xmm5 = _mm256_extractf128_pd(ymm3, 0); + _mm_storeu_pd((double *)(b11 + cs_b * 3),xmm5); + _mm_storel_pd((b11 + cs_b * 3 + 2), _mm256_extractf128_pd(ymm3, 1)); + + if(transa) + dtrsm_AutXB_ref(a11, b11, m_rem, 4, cs_a, cs_b,is_unitdiag); + else + dtrsm_AlXB_ref(a11, b11, m_rem, 4, rs_a, cs_b, is_unitdiag); + n_rem -= 4; + j +=4; + } + + if(n_rem) + { + a10 = D_A_pack; //pointer to block of A to be used for GEMM + a11 = L + (i*rs_a) + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM + b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM + + k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + if(3 == n_rem) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) + + BLIS_PRE_DTRSM_SMALL_3M_3N(AlphaVal,b11,cs_b) + + if(transa) + dtrsm_AutXB_ref(a11, b11, m_rem, 3, cs_a, cs_b,is_unitdiag); + else + dtrsm_AlXB_ref(a11, b11, m_rem, 3, rs_a, cs_b, is_unitdiag); + } + else if(2 == n_rem) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b,p_lda,k_iter) + + BLIS_PRE_DTRSM_SMALL_3M_2N(AlphaVal,b11,cs_b) + + if(transa) + dtrsm_AutXB_ref(a11, b11, m_rem, 2, cs_a, cs_b,is_unitdiag); + else + dtrsm_AlXB_ref(a11, b11, m_rem, 2, rs_a, cs_b, is_unitdiag); + } + else if(1 == n_rem) + { 
+ ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b,p_lda,k_iter) + + BLIS_PRE_DTRSM_SMALL_3M_1N(AlphaVal,b11,cs_b) + + if(transa) + dtrsm_AutXB_ref(a11, b11, m_rem, 1, cs_a, cs_b, is_unitdiag); + else + dtrsm_AlXB_ref(a11, b11, m_rem, 1, rs_a, cs_b, is_unitdiag); + } + } + } + else if(2 == m_rem) // Repetative A blocks will be 2*2 + { + dim_t p_lda = 4; // packed leading dimension + if(transa) + { + for(dim_t x=0;x= 4)) + { + a10 = D_A_pack; //pointer to block of A to be used for GEMM + a11 = L + (i*rs_a) + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM + b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM + + k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx4n(a10,b01,cs_b,p_lda,k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); + ymm3 = _mm256_insertf128_pd(ymm3, xmm5, 0); + + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0C); + ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0C); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + xmm5 = _mm256_extractf128_pd(ymm3, 0); + _mm_storeu_pd((double *)(b11 + cs_b * 3), xmm5); + + if(transa) + dtrsm_AutXB_ref(a11, b11, m_rem, 4, cs_a, cs_b, is_unitdiag); + else + dtrsm_AlXB_ref(a11, b11, m_rem, 4, rs_a, cs_b, is_unitdiag); + n_rem -= 4; + j +=4; + } + if(n_rem) + { + a10 = D_A_pack; //pointer to block of A to be used for GEMM + a11 = L + (i*rs_a) + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM + b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM + + k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + if(3 == n_rem) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) + + BLIS_PRE_DTRSM_SMALL_2M_3N(AlphaVal,b11,cs_b) + + if(transa) + dtrsm_AutXB_ref(a11, b11, m_rem, 3, cs_a, cs_b, is_unitdiag); + else + dtrsm_AlXB_ref(a11, b11, m_rem, 3, rs_a, cs_b, is_unitdiag); + } + else if(2 == n_rem) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b,p_lda,k_iter) + + BLIS_PRE_DTRSM_SMALL_2M_2N(AlphaVal,b11,cs_b) + + if(transa) + dtrsm_AutXB_ref(a11, b11, m_rem, 2, cs_a, cs_b, is_unitdiag); + else + dtrsm_AlXB_ref(a11, b11, m_rem, 2, rs_a, cs_b, is_unitdiag); + } + else if(1 == n_rem) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b,p_lda,k_iter) + + BLIS_PRE_DTRSM_SMALL_2M_1N(AlphaVal,b11,cs_b) + + if(transa) + dtrsm_AutXB_ref(a11, b11, m_rem, 1, cs_a, cs_b, is_unitdiag); + else + dtrsm_AlXB_ref(a11, 
b11, m_rem, 1, rs_a, cs_b, is_unitdiag); + } + } + m_rem -=2; + i+=2; + } + else if(1 == m_rem) // Repetative A blocks will be 1*1 + { + dim_t p_lda = 4; // packed leading dimension + if(transa) + { + for(dim_t x=0;x= 4)) + { + a10 = D_A_pack; //pointer to block of A to be used for GEMM + a11 = L + (i*rs_a) + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM + b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM + + k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx4n(a10,b01,cs_b,p_lda,k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + + ///implement TRSM/// + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_broadcast_sd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_broadcast_sd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_broadcast_sd((double const *)(b11 + cs_b *3)); + + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); + ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0E); + + _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm0, 0)); + _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm1, 0)); + _mm_storel_pd((b11 + cs_b * 2), _mm256_extractf128_pd(ymm2, 0)); + _mm_storel_pd((b11 + cs_b * 3), _mm256_extractf128_pd(ymm3, 0)); + + if(transa) + dtrsm_AutXB_ref(a11, b11, m_rem, 4, cs_a, cs_b, is_unitdiag); + else + dtrsm_AlXB_ref(a11, b11, m_rem, 4, rs_a, cs_b, is_unitdiag); + n_rem -= 4; + j+=4; + } + + if(n_rem) + { + a10 = D_A_pack; //pointer to block of A to be used for GEMM + a11 = L + (i*rs_a) + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM + b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM + + k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + if(3 == n_rem) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) + + BLIS_PRE_DTRSM_SMALL_1M_3N(AlphaVal,b11,cs_b) + + if(transa) + dtrsm_AutXB_ref(a11, b11, m_rem, 3, cs_a, cs_b, is_unitdiag); + else + dtrsm_AlXB_ref(a11, b11, m_rem, 3, rs_a, cs_b, is_unitdiag); + } + else if(2 == n_rem) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b,p_lda,k_iter) + + BLIS_PRE_DTRSM_SMALL_1M_2N(AlphaVal,b11,cs_b) + + if(transa) + dtrsm_AutXB_ref(a11, b11, m_rem, 2, cs_a, cs_b, is_unitdiag); + else + dtrsm_AlXB_ref(a11, b11, m_rem, 2, rs_a, cs_b, is_unitdiag); + } + else if(1 == n_rem) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b,p_lda,k_iter) + + BLIS_PRE_DTRSM_SMALL_1M_1N(AlphaVal,b11,cs_b) + + if(transa) + dtrsm_AutXB_ref(a11, b11, m_rem, 1, cs_a, cs_b, is_unitdiag); + else + dtrsm_AutXB_ref(a11, b11, m_rem, 1, rs_a, cs_b, is_unitdiag); + } + } + m_rem -=1; + i+=1; + } + } + + if ((required_packing_A == 1) && + bli_mem_is_alloc( &local_mem_buf_A_s )) + { + bli_membrk_release(&rntm, &local_mem_buf_A_s); + } + return BLIS_SUCCESS; +} + +/* + Pack diagonal elements of A block (16 
or 6) into an array + a. This helps in utilze cache line efficiently in TRSM operation + b. store ones when input is unit diagonal +*/ +BLIS_INLINE void strsm_small_pack_diag_element +( + char side, + bool is_unitdiag, + float *a11, + dim_t cs_a, + float *d11_pack, + dim_t size +) +{ + __m256 ymm0, ymm1, ymm2, ymm3; + __m256 ymm4, ymm5, ymm6, ymm7; + __m256 ymm8, ymm9, ymm10,ymm11; + __m256 ymm14, ymm15, ymm12,ymm13; + float ones = 1.0; + ymm13 = ymm14 = ymm15 = _mm256_broadcast_ss((float const *)&ones); + if(side=='L'||side=='l') + { + if(!is_unitdiag) + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_ss((float const *)(a11)); + ymm1 = _mm256_broadcast_ss((float const *)(a11+ cs_a +1)); + ymm2 = _mm256_broadcast_ss((float const *)(a11+ cs_a*2 + 2)); + ymm3 = _mm256_broadcast_ss((float const *)(a11+ cs_a*3 + 3)); + ymm4 = _mm256_broadcast_ss((float const *)(a11+ cs_a*4 + 4)); + ymm5 = _mm256_broadcast_ss((float const *)(a11+ cs_a*5 + 5)); + ymm6 = _mm256_broadcast_ss((float const *)(a11+ cs_a*6 + 6)); + ymm7 = _mm256_broadcast_ss((float const *)(a11+ cs_a*7 + 7)); + + ymm8 = _mm256_unpacklo_ps(ymm0, ymm1); + ymm9 = _mm256_unpacklo_ps(ymm2, ymm3); + ymm10 = _mm256_blend_ps(ymm8, ymm9, 0xCC); + + ymm8 = _mm256_unpacklo_ps(ymm4, ymm5); + ymm9 = _mm256_unpacklo_ps(ymm6, ymm7); + ymm11 = _mm256_blend_ps(ymm8, ymm9, 0xCC); + + ymm12 = _mm256_blend_ps(ymm10, ymm11, 0xF0); + + #ifdef BLIS_DISABLE_TRSM_PREINVERSION + ymm14 = ymm12; + #endif + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + ymm14 = _mm256_div_ps(ymm13, ymm12); + #endif + + ymm0 = _mm256_broadcast_ss((float const *)(a11+ cs_a*8 + 8)); + ymm1 = _mm256_broadcast_ss((float const *)(a11+ cs_a*9 + 9)); + ymm2 = _mm256_broadcast_ss((float const *)(a11+ cs_a*10 + 10)); + ymm3 = _mm256_broadcast_ss((float const *)(a11+ cs_a*11 + 11)); + ymm4 = _mm256_broadcast_ss((float const *)(a11+ cs_a*12 + 12)); + ymm5 = _mm256_broadcast_ss((float const *)(a11+ cs_a*13 + 13)); + ymm6 = _mm256_broadcast_ss((float const *)(a11+ cs_a*14 + 14)); + ymm7 = _mm256_broadcast_ss((float const *)(a11+ cs_a*15 + 15)); + + ymm8 = _mm256_unpacklo_ps(ymm0, ymm1); + ymm9 = _mm256_unpacklo_ps(ymm2, ymm3); + ymm10 = _mm256_blend_ps(ymm8, ymm9, 0xCC); + + ymm8 = _mm256_unpacklo_ps(ymm4, ymm5); + ymm9 = _mm256_unpacklo_ps(ymm6, ymm7); + ymm11 = _mm256_blend_ps(ymm8, ymm9, 0xCC); + + ymm12 = _mm256_blend_ps(ymm10, ymm11, 0xF0); + + #ifdef BLIS_DISABLE_TRSM_PREINVERSION + ymm15 = ymm12; + #endif + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + ymm15 = _mm256_div_ps(ymm13, ymm12); + #endif + } + _mm256_store_ps((float *)(d11_pack), ymm14); + _mm256_store_ps((float *)(d11_pack + 8), ymm15); + } + else if(side=='R'||side=='r') + { + if(!is_unitdiag) + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_ss((float const *)(a11)); + ymm1 = _mm256_broadcast_ss((float const *)(a11+ cs_a +1)); + ymm2 = _mm256_broadcast_ss((float const *)(a11+ cs_a*2 + 2)); + ymm3 = _mm256_broadcast_ss((float const *)(a11+ cs_a*3 + 3)); + ymm4 = _mm256_broadcast_ss((float const *)(a11+ cs_a*4 + 4)); + ymm5 = _mm256_broadcast_ss((float const *)(a11+ cs_a*5 + 5)); + + ymm8 = _mm256_unpacklo_ps(ymm0, ymm1); + ymm9 = _mm256_unpacklo_ps(ymm2, ymm3); + ymm10 = _mm256_blend_ps(ymm8, ymm9, 0x0C); + + ymm8 = _mm256_unpacklo_ps(ymm4, ymm5); + ymm11 = _mm256_blend_ps(ymm8, ymm8, 0x0C); + ymm12 = _mm256_blend_ps(ymm10, ymm11, 0xF0); + + #ifdef BLIS_DISABLE_TRSM_PREINVERSION + ymm14 = ymm12; + #endif + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + ymm14 = _mm256_div_ps(ymm13, ymm12); + #endif + } + 
_mm_storeu_ps((float *)(d11_pack), _mm256_extractf128_ps(ymm14,0)); + __m128 xmm5 = _mm256_extractf128_ps(ymm14,1); + _mm_storel_pi((__m64 *)(d11_pack + 4),xmm5); + } +} + +/* + Pack a block of 16xk or 6xk from input buffer into packed buffer + directly or after transpose based on input params +*/ +BLIS_INLINE void bli_strsm_small_pack +( + char side, + dim_t size, + bool trans, + float *inbuf, + dim_t cs_a, + float *pbuff, + dim_t p_lda, + dim_t mr +) +{ + //scratch registers + __m256 ymm0, ymm1, ymm2, ymm3; + __m256 ymm4, ymm5, ymm6, ymm7; + __m256 ymm8, ymm9, ymm10, ymm11; + __m256 ymm12, ymm13,ymm14,ymm15; + float zero = 0.0; + + if(side=='L'||side=='l') + { + /*Left case is 16xk*/ + if(trans) + { + /* + ------------- ------------- + | | | | | + | 8x16 | | | | + ------------- ==> | 16x8 | 16x8 | + | 8x16 | | | | + | | | | | + ------------- ------------- + */ + for(dim_t x = 0; x < size; x += p_lda) + { + ymm0 = _mm256_loadu_ps((float const *)(inbuf)); + ymm1 = _mm256_loadu_ps((float const *)(inbuf + cs_a)); + ymm2 = _mm256_loadu_ps((float const *)(inbuf + cs_a*2)); + ymm3 = _mm256_loadu_ps((float const *)(inbuf + cs_a*3)); + ymm4 = _mm256_loadu_ps((float const *)(inbuf + cs_a*4)); + ymm5 = _mm256_loadu_ps((float const *)(inbuf + cs_a*5)); + ymm6 = _mm256_loadu_ps((float const *)(inbuf + cs_a*6)); + ymm7 = _mm256_loadu_ps((float const *)(inbuf + cs_a*7)); + + ymm8 = _mm256_unpacklo_ps(ymm0, ymm1); + ymm9 = _mm256_unpacklo_ps(ymm2, ymm3); + + ymm10 = _mm256_unpacklo_ps(ymm4, ymm5); + ymm11 = _mm256_unpacklo_ps(ymm6, ymm7); + + ymm12 = _mm256_shuffle_ps(ymm8,ymm9,0b01000100); + ymm13 = _mm256_shuffle_ps(ymm10,ymm11,0b01000100); + + ymm14 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//1 + ymm15 = _mm256_permute2f128_ps(ymm12,ymm13,0x31);//5 + + _mm256_storeu_ps((float *)(pbuff), ymm14); + _mm256_storeu_ps((float *)(pbuff + 4*p_lda), ymm15); + + ymm12 = _mm256_shuffle_ps(ymm8,ymm9,0b11101110); + ymm13 = _mm256_shuffle_ps(ymm10,ymm11,0b11101110); + + ymm14 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//2 + ymm15 = _mm256_permute2f128_ps(ymm12,ymm13,0x31);//6 + _mm256_storeu_ps((float *)(pbuff + p_lda), ymm14); + _mm256_storeu_ps((float *)(pbuff + 5*p_lda), ymm15); + + ymm8 = _mm256_unpackhi_ps(ymm0, ymm1); + ymm9 = _mm256_unpackhi_ps(ymm2, ymm3); + + ymm10 = _mm256_unpackhi_ps(ymm4, ymm5); + ymm11 = _mm256_unpackhi_ps(ymm6, ymm7); + + ymm12 = _mm256_shuffle_ps(ymm8,ymm9,0b01000100); + ymm13 = _mm256_shuffle_ps(ymm10,ymm11,0b01000100); + + ymm14 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//3 + ymm15 = _mm256_permute2f128_ps(ymm12,ymm13,0x31);//7 + _mm256_storeu_ps((float *)(pbuff + 2*p_lda), ymm14); + _mm256_storeu_ps((float *)(pbuff + 6*p_lda), ymm15); + + ymm12 = _mm256_shuffle_ps(ymm8,ymm9,0b11101110); + ymm13 = _mm256_shuffle_ps(ymm10,ymm11,0b11101110); + + ymm14 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//4 + ymm15 = _mm256_permute2f128_ps(ymm12,ymm13,0x31);//8 + _mm256_storeu_ps((float *)(pbuff + 3*p_lda), ymm14); + _mm256_storeu_ps((float *)(pbuff + 7*p_lda), ymm15); + + ymm0 = _mm256_loadu_ps((float const *)(inbuf + 8)); + ymm1 = _mm256_loadu_ps((float const *)(inbuf + cs_a + 8)); + ymm2 = _mm256_loadu_ps((float const *)(inbuf + cs_a*2 + 8)); + ymm3 = _mm256_loadu_ps((float const *)(inbuf + cs_a*3 + 8)); + ymm4 = _mm256_loadu_ps((float const *)(inbuf + cs_a*4 + 8)); + ymm5 = _mm256_loadu_ps((float const *)(inbuf + cs_a*5 + 8)); + ymm6 = _mm256_loadu_ps((float const *)(inbuf + cs_a*6 + 8)); + ymm7 = _mm256_loadu_ps((float const *)(inbuf + cs_a*7 + 8)); + + ymm8 = _mm256_unpacklo_ps(ymm0, 
ymm1); + ymm9 = _mm256_unpacklo_ps(ymm2, ymm3); + + ymm10 = _mm256_unpacklo_ps(ymm4, ymm5); + ymm11 = _mm256_unpacklo_ps(ymm6, ymm7); + + ymm12 = _mm256_shuffle_ps(ymm8,ymm9,0b01000100); + ymm13 = _mm256_shuffle_ps(ymm10,ymm11,0b01000100); + + ymm14 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//1 + ymm15 = _mm256_permute2f128_ps(ymm12,ymm13,0x31);//5 + + _mm256_storeu_ps((float *)(pbuff + 8*p_lda), ymm14); + _mm256_storeu_ps((float *)(pbuff + 12*p_lda), ymm15); + + ymm12 = _mm256_shuffle_ps(ymm8,ymm9,0b11101110); + ymm13 = _mm256_shuffle_ps(ymm10,ymm11,0b11101110); + + ymm14 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//2 + ymm15 = _mm256_permute2f128_ps(ymm12,ymm13,0x31);//6 + _mm256_storeu_ps((float *)(pbuff + 9*p_lda), ymm14); + _mm256_storeu_ps((float *)(pbuff + 13*p_lda), ymm15); + + ymm8 = _mm256_unpackhi_ps(ymm0, ymm1); + ymm9 = _mm256_unpackhi_ps(ymm2, ymm3); + + ymm10 = _mm256_unpackhi_ps(ymm4, ymm5); + ymm11 = _mm256_unpackhi_ps(ymm6, ymm7); + + ymm12 = _mm256_shuffle_ps(ymm8,ymm9,0b01000100); + ymm13 = _mm256_shuffle_ps(ymm10,ymm11,0b01000100); + + ymm14 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//3 + ymm15 = _mm256_permute2f128_ps(ymm12,ymm13,0x31);//7 + _mm256_storeu_ps((float *)(pbuff + 10*p_lda), ymm14); + _mm256_storeu_ps((float *)(pbuff + 14*p_lda), ymm15); + + ymm12 = _mm256_shuffle_ps(ymm8,ymm9,0b11101110); + ymm13 = _mm256_shuffle_ps(ymm10,ymm11,0b11101110); + + ymm14 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//4 + ymm15 = _mm256_permute2f128_ps(ymm12,ymm13,0x31);//8 + _mm256_storeu_ps((float *)(pbuff + 11*p_lda), ymm14); + _mm256_storeu_ps((float *)(pbuff + 15*p_lda), ymm15); + + ymm0 = _mm256_loadu_ps((float const *)(inbuf + cs_a*8)); + ymm1 = _mm256_loadu_ps((float const *)(inbuf + cs_a*9)); + ymm2 = _mm256_loadu_ps((float const *)(inbuf + cs_a*10)); + ymm3 = _mm256_loadu_ps((float const *)(inbuf + cs_a*11)); + ymm4 = _mm256_loadu_ps((float const *)(inbuf + cs_a*12)); + ymm5 = _mm256_loadu_ps((float const *)(inbuf + cs_a*13)); + ymm6 = _mm256_loadu_ps((float const *)(inbuf + cs_a*14)); + ymm7 = _mm256_loadu_ps((float const *)(inbuf + cs_a*15)); + + ymm8 = _mm256_unpacklo_ps(ymm0, ymm1); + ymm9 = _mm256_unpacklo_ps(ymm2, ymm3); + + ymm10 = _mm256_unpacklo_ps(ymm4, ymm5); + ymm11 = _mm256_unpacklo_ps(ymm6, ymm7); + + ymm12 = _mm256_shuffle_ps(ymm8,ymm9,0b01000100); + ymm13 = _mm256_shuffle_ps(ymm10,ymm11,0b01000100); + + ymm14 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//1 + ymm15 = _mm256_permute2f128_ps(ymm12,ymm13,0x31);//5 + + _mm256_storeu_ps((float *)(pbuff + 8), ymm14); + _mm256_storeu_ps((float *)(pbuff + 4*p_lda + 8), ymm15); + + ymm12 = _mm256_shuffle_ps(ymm8,ymm9,0b11101110); + ymm13 = _mm256_shuffle_ps(ymm10,ymm11,0b11101110); + + ymm14 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//2 + ymm15 = _mm256_permute2f128_ps(ymm12,ymm13,0x31);//6 + _mm256_storeu_ps((float *)(pbuff + p_lda + 8), ymm14); + _mm256_storeu_ps((float *)(pbuff + 5*p_lda + 8), ymm15); + + ymm8 = _mm256_unpackhi_ps(ymm0, ymm1); + ymm9 = _mm256_unpackhi_ps(ymm2, ymm3); + + ymm10 = _mm256_unpackhi_ps(ymm4, ymm5); + ymm11 = _mm256_unpackhi_ps(ymm6, ymm7); + + ymm12 = _mm256_shuffle_ps(ymm8,ymm9,0b01000100); + ymm13 = _mm256_shuffle_ps(ymm10,ymm11,0b01000100); + + ymm14 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//3 + ymm15 = _mm256_permute2f128_ps(ymm12,ymm13,0x31);//7 + _mm256_storeu_ps((float *)(pbuff + 2*p_lda + 8), ymm14); + _mm256_storeu_ps((float *)(pbuff + 6*p_lda + 8), ymm15); + + ymm12 = _mm256_shuffle_ps(ymm8,ymm9,0b11101110); + ymm13 = _mm256_shuffle_ps(ymm10,ymm11,0b11101110); + + 
ymm14 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//4 + ymm15 = _mm256_permute2f128_ps(ymm12,ymm13,0x31);//8 + _mm256_storeu_ps((float *)(pbuff + 3*p_lda + 8), ymm14); + _mm256_storeu_ps((float *)(pbuff + 7*p_lda + 8), ymm15); + + ymm0 = _mm256_loadu_ps((float const *)(inbuf + cs_a*8 + 8)); + ymm1 = _mm256_loadu_ps((float const *)(inbuf + cs_a*9 + 8)); + ymm2 = _mm256_loadu_ps((float const *)(inbuf + cs_a*10 + 8)); + ymm3 = _mm256_loadu_ps((float const *)(inbuf + cs_a*11 + 8)); + ymm4 = _mm256_loadu_ps((float const *)(inbuf + cs_a*12 + 8)); + ymm5 = _mm256_loadu_ps((float const *)(inbuf + cs_a*13 + 8)); + ymm6 = _mm256_loadu_ps((float const *)(inbuf + cs_a*14 + 8)); + ymm7 = _mm256_loadu_ps((float const *)(inbuf + cs_a*15 + 8)); + + ymm8 = _mm256_unpacklo_ps(ymm0, ymm1); + ymm9 = _mm256_unpacklo_ps(ymm2, ymm3); + + ymm10 = _mm256_unpacklo_ps(ymm4, ymm5); + ymm11 = _mm256_unpacklo_ps(ymm6, ymm7); + + ymm12 = _mm256_shuffle_ps(ymm8,ymm9,0b01000100); + ymm13 = _mm256_shuffle_ps(ymm10,ymm11,0b01000100); + + ymm14 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//1 + ymm15 = _mm256_permute2f128_ps(ymm12,ymm13,0x31);//5 + + _mm256_storeu_ps((float *)(pbuff + 8*p_lda + 8), ymm14); + _mm256_storeu_ps((float *)(pbuff + 12*p_lda + 8), ymm15); + + ymm12 = _mm256_shuffle_ps(ymm8,ymm9,0b11101110); + ymm13 = _mm256_shuffle_ps(ymm10,ymm11,0b11101110); + + ymm14 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//2 + ymm15 = _mm256_permute2f128_ps(ymm12,ymm13,0x31);//6 + _mm256_storeu_ps((float *)(pbuff + 9*p_lda + 8), ymm14); + _mm256_storeu_ps((float *)(pbuff + 13*p_lda + 8), ymm15); + + ymm8 = _mm256_unpackhi_ps(ymm0, ymm1); + ymm9 = _mm256_unpackhi_ps(ymm2, ymm3); + + ymm10 = _mm256_unpackhi_ps(ymm4, ymm5); + ymm11 = _mm256_unpackhi_ps(ymm6, ymm7); + + ymm12 = _mm256_shuffle_ps(ymm8,ymm9,0b01000100); + ymm13 = _mm256_shuffle_ps(ymm10,ymm11,0b01000100); + + ymm14 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//3 + ymm15 = _mm256_permute2f128_ps(ymm12,ymm13,0x31);//7 + _mm256_storeu_ps((float *)(pbuff + 10*p_lda + 8), ymm14); + _mm256_storeu_ps((float *)(pbuff + 14*p_lda + 8), ymm15); + + ymm12 = _mm256_shuffle_ps(ymm8,ymm9,0b11101110); + ymm13 = _mm256_shuffle_ps(ymm10,ymm11,0b11101110); + + ymm14 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//4 + ymm15 = _mm256_permute2f128_ps(ymm12,ymm13,0x31);//8 + _mm256_storeu_ps((float *)(pbuff + 11*p_lda + 8), ymm14); + _mm256_storeu_ps((float *)(pbuff + 15*p_lda + 8), ymm15); + + inbuf += p_lda; + pbuff += p_lda*p_lda; + } + } + else + { + //Expected multiples of 8 + //p_lda = 16; + for(dim_t x = 0; x < size; x++) + { + ymm0 = _mm256_loadu_ps((float const *)(inbuf)); + _mm256_storeu_ps((float *)(pbuff), ymm0); + ymm1 = _mm256_loadu_ps((float const *)(inbuf + 8)); + _mm256_storeu_ps((float *)(pbuff + 8), ymm1); + inbuf+=cs_a; + pbuff+=p_lda; + } + } + } + else if(side=='R'||side=='r') + { + if(trans) + { + /* + ------------------ ---------- + | | | | | | + | 4x4 | 4x2 | | 4x4 |4x2 | + ------------- ==> ------------- + | | | | | | + | 2x4 | 2x2 | | 2x4 |2x2 | + ------------------- ------------- + */ + __m128 xmm0, xmm1, xmm2, xmm3; + __m128 xmm4, xmm5, xmm6, xmm7; + __m128 xmm8, xmm9, xmm10, xmm11; + __m128 xmm12, xmm13; + + for(dim_t x=0; xbuffer; //value of Alpha + float* restrict L = a->buffer; //pointer to matrix A + float* restrict B = b->buffer; //pointer to matrix B + + float *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks + + gint_t required_packing_A = 1; + mem_t local_mem_buf_A_s = {0}; + float *D_A_pack = NULL; + float d11_pack[d_mr] 
__attribute__((aligned(64))); + rntm_t rntm; + + bli_rntm_init_from_global( &rntm ); + bli_rntm_set_num_threads_only( 1, &rntm ); + bli_membrk_rntm_set_membrk( &rntm ); + + siz_t buffer_size = bli_pool_block_size( + bli_membrk_pool( + bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), + bli_rntm_membrk(&rntm))); + + if( (d_nr * n * sizeof(float)) > buffer_size) + return BLIS_NOT_YET_IMPLEMENTED; + + if (required_packing_A == 1) + { + // Get the buffer from the pool. + bli_membrk_acquire_m(&rntm, + buffer_size, + BLIS_BITVAL_BUFFER_FOR_A_BLOCK, + &local_mem_buf_A_s); + if(FALSE==bli_mem_is_alloc(&local_mem_buf_A_s)) return BLIS_NULL_POINTER; + D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); + if(NULL==D_A_pack) return BLIS_NULL_POINTER; + } + + //ymm scratch reginsters + __m256 ymm0, ymm1, ymm2, ymm3; + __m256 ymm4, ymm5, ymm6, ymm7; + __m256 ymm8, ymm9, ymm10, ymm11; + __m256 ymm12, ymm13, ymm14, ymm15; + + __m128 xmm5; + + /* + Performs solving TRSM for 6 rows at a time from 0 to n/6 in steps of d_nr + a. Load and pack A (a01 block), the size of packing 6x6 to 6x (n-6) + First there will be no GEMM and no packing of a01 because it is only TRSM + b. Using packed a01 block and b10 block perform GEMM operation + c. Use GEMM outputs, perform TRSM operation using a11, b11 and update B + d. Repeat b for m cols of B in steps of d_mr + */ + + for(j = (n-d_nr); (j+1) > 0; j -= d_nr) //loop along 'N' direction + { + a01 = L + (j*rs_a) + (j+d_nr)*cs_a; //pointer to block of A to be used in GEMM + a11 = L + (j*cs_a) + (j*rs_a); //pointer to block of A to be used for TRSM + + dim_t p_lda = (n-j-d_nr); // packed leading dimension + // perform copy of A to packed buffer D_A_pack + + if(transa) + { + /* + Pack current A block (a01) into packed buffer memory D_A_pack + a. This a10 block is used in GEMM portion only and this + a01 block size will be increasing by d_nr for every next iteration + until it reaches 6x(n-6) which is the maximum GEMM alone block size in A + b. This packed buffer is reused to calculate all m cols of B matrix + */ + bli_strsm_small_pack('R', p_lda, 1, a01, cs_a, D_A_pack, p_lda,d_nr); + + /* + Pack 6 diagonal elements of A block into an array + a. This helps in utilze cache line efficiently in TRSM operation + b. store ones when input is unit diagonal + */ + strsm_small_pack_diag_element('R',is_unitdiag,a11,cs_a,d11_pack,d_nr); + } + else + { + bli_strsm_small_pack('R', p_lda, 0, a01, rs_a, D_A_pack, p_lda,d_nr); + strsm_small_pack_diag_element('R',is_unitdiag,a11,rs_a,d11_pack,d_nr); + } + + /* + a. Perform GEMM using a01, b10. + b. Perform TRSM on a11, b11 + c. This loop GEMM+TRSM loops operates with 16x6 block size + along m dimension for every d_mr columns of B10 where + packed A buffer is reused in computing all m cols of B. + d. Same approach is used in remaining fringe cases. 
+ */ + for(i = (m-d_mr); (i+1) > 0; i -= d_mr) //loop along 'M' direction + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i + (j+d_nr)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (i) + (j)*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-j-d_nr); //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + /* + Peform GEMM between a01 and b10 blocks + For first itteration there will be no GEMM operation + where k_iter are zero + */ + BLIS_STRSM_SMALL_GEMM_6nx16m(a01,b10,cs_b,p_lda,k_iter) + + /* + Load b11 of size 16x6 and multiply with alpha + Add the GEMM output to b11 + and peform TRSM operation. + */ + + BLIS_PRE_STRSM_SMALL_6x16(AlphaVal,b11,cs_b) + + ///implement TRSM/// + + /* + Compute 6x16 TRSM block by using GEMM block output in register + a. The 6x16 input (gemm outputs) are stored in combinations of ymm registers + b. Towards the end TRSM output will be stored back into b11 + */ + + //extract a55 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 5)); + + ymm13 = STRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); + ymm14 = STRSM_SMALL_DIV_OR_SCALE(ymm14, ymm0); + + //extract a44 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 4)); + + //(row 5):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*cs_a + 4*rs_a)); + + ymm11 = _mm256_fnmadd_ps(ymm1, ymm13, ymm11); + ymm12 = _mm256_fnmadd_ps(ymm1, ymm14, ymm12); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*cs_a + 3*rs_a)); + + ymm9 = _mm256_fnmadd_ps(ymm1, ymm13, ymm9); + ymm10 = _mm256_fnmadd_ps(ymm1, ymm14, ymm10); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*cs_a + 2*rs_a)); + + ymm7 = _mm256_fnmadd_ps(ymm1, ymm13, ymm7); + ymm8 = _mm256_fnmadd_ps(ymm1, ymm14, ymm8); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*cs_a + 1*rs_a)); + + ymm5 = _mm256_fnmadd_ps(ymm1, ymm13, ymm5); + ymm6 = _mm256_fnmadd_ps(ymm1, ymm14, ymm6); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*cs_a)); + + ymm3 = _mm256_fnmadd_ps(ymm1, ymm13, ymm3); + ymm4 = _mm256_fnmadd_ps(ymm1, ymm14, ymm4); + + ymm11 = STRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); + ymm12 = STRSM_SMALL_DIV_OR_SCALE(ymm12, ymm0); + + //extract a33 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 3)); + + //(row 4):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*cs_a + 3*rs_a)); + + ymm9 = _mm256_fnmadd_ps(ymm1, ymm11, ymm9); + ymm10 = _mm256_fnmadd_ps(ymm1, ymm12, ymm10); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*cs_a + 2*rs_a)); + + ymm7 = _mm256_fnmadd_ps(ymm1, ymm11, ymm7); + ymm8 = _mm256_fnmadd_ps(ymm1, ymm12, ymm8); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*cs_a + 1*rs_a)); + + ymm5 = _mm256_fnmadd_ps(ymm1, ymm11, ymm5); + ymm6 = _mm256_fnmadd_ps(ymm1, ymm12, ymm6); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*cs_a)); + + ymm3 = _mm256_fnmadd_ps(ymm1, ymm11, ymm3); + ymm4 = _mm256_fnmadd_ps(ymm1, ymm12, ymm4); + + ymm9 = STRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + ymm10 = STRSM_SMALL_DIV_OR_SCALE(ymm10, ymm0); + + //extract a22 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*cs_a + 2*rs_a)); + + ymm7 = _mm256_fnmadd_ps(ymm1, ymm9, ymm7); + ymm8 = _mm256_fnmadd_ps(ymm1, ymm10, ymm8); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*cs_a + 1*rs_a)); + + ymm5 = _mm256_fnmadd_ps(ymm1, ymm9, ymm5); 
+ ymm6 = _mm256_fnmadd_ps(ymm1, ymm10, ymm6); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*cs_a)); + + ymm3 = _mm256_fnmadd_ps(ymm1, ymm9, ymm3); + ymm4 = _mm256_fnmadd_ps(ymm1, ymm10, ymm4); + + ymm7 = STRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + ymm8 = STRSM_SMALL_DIV_OR_SCALE(ymm8, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*cs_a + 1*rs_a)); + + ymm5 = _mm256_fnmadd_ps(ymm1, ymm7, ymm5); + ymm6 = _mm256_fnmadd_ps(ymm1, ymm8, ymm6); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*cs_a)); + + ymm3 = _mm256_fnmadd_ps(ymm1, ymm7, ymm3); + ymm4 = _mm256_fnmadd_ps(ymm1, ymm8, ymm4); + + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + ymm6 = STRSM_SMALL_DIV_OR_SCALE(ymm6, ymm0); + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack )); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + cs_a)); + + ymm3 = _mm256_fnmadd_ps(ymm1, ymm5, ymm3); + ymm4 = _mm256_fnmadd_ps(ymm1, ymm6, ymm4); + + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm4 = STRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); + + _mm256_storeu_ps((float *)b11, ymm3); + _mm256_storeu_ps((float *)(b11 + 8), ymm4); + _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); + _mm256_storeu_ps((float *)(b11 + cs_b + 8), ymm6); + _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); + _mm256_storeu_ps((float *)(b11 + cs_b*2 + 8), ymm8); + _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); + _mm256_storeu_ps((float *)(b11 + cs_b*3 + 8), ymm10); + _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); + _mm256_storeu_ps((float *)(b11 + cs_b*4 + 8), ymm12); + _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + _mm256_storeu_ps((float *)(b11 + cs_b*5 + 8), ymm14); + } + + dim_t m_remainder = i + d_mr; + if(m_remainder >= 8) + { + a01 = D_A_pack; + a11 = L + (j*cs_a) + (j*rs_a); + b10 = B + (m_remainder - 8) + (j+d_nr)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 8) + (j*cs_b); + + k_iter = (n-j-d_nr); //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) + + // Load b11 of size 4x6 and multiply with alpha + BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + + ///implement TRSM/// + + //extract a55 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 5)); + ymm13 = STRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); + + //extract a44 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 4)); + + //(row 5):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*cs_a + 4*rs_a)); + ymm11 = _mm256_fnmadd_ps(ymm1, ymm13, ymm11); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*cs_a + 3*rs_a)); + ymm9 = _mm256_fnmadd_ps(ymm1, ymm13, ymm9); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*cs_a + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm13, ymm7); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm13, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm13, ymm3); + + ymm11 = STRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); + + //extract a33 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 3)); + + //(row 4):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*cs_a + 3*rs_a)); + ymm9 = _mm256_fnmadd_ps(ymm1, ymm11, ymm9); + + ymm1 = 
_mm256_broadcast_ss((float const *)(a11 + 4*cs_a + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm11, ymm7); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm11, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm11, ymm3); + + ymm9 = STRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + + //extract a22 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*cs_a + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm9, ymm7); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm9, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm9, ymm3); + + ymm7 = STRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm7, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm7, ymm3); + + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack )); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm5, ymm3); + + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + _mm256_storeu_ps((float *)b11, ymm3); + _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); + _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); + _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); + _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); + _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + + m_remainder -=8; + } + + if(m_remainder) + { + if(7 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (j*cs_a) + (j*rs_a); + b10 = B + (j+d_nr)*cs_b + (m_remainder - 7); //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 7) + (j*cs_b); + + k_iter = (n-j-d_nr); //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) + + // Load b11 of size 4x6 and multiply with alpha + BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + + ///implement TRSM/// + + //extract a55 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 5)); + ymm13 = STRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); + + //extract a44 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 4)); + + //(row 5):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*cs_a + 4*rs_a)); + ymm11 = _mm256_fnmadd_ps(ymm1, ymm13, ymm11); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*cs_a + 3*rs_a)); + ymm9 = _mm256_fnmadd_ps(ymm1, ymm13, ymm9); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*cs_a + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm13, ymm7); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm13, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm13, ymm3); + + ymm11 = STRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); + + //extract a33 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 3)); + + //(row 4):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*cs_a + 3*rs_a)); + ymm9 = _mm256_fnmadd_ps(ymm1, ymm11, 
ymm9); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*cs_a + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm11, ymm7); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm11, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm11, ymm3); + + ymm9 = STRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + + //extract a22 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*cs_a + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm9, ymm7); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm9, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm9, ymm3); + + ymm7 = STRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm7, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm7, ymm3); + + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack )); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm5, ymm3); + + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + ymm0 = _mm256_loadu_ps((float const *)b11); + ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x7F); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x7F); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x7F); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x7F); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x7F); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x7F); + + _mm256_storeu_ps((float *)b11, ymm3); + _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); + _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); + _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); + _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); + _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + + m_remainder -=7; + } + else if(6 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (j*cs_a) + (j*rs_a); + b10 = B + (j+d_nr)*cs_b + (m_remainder - 6); //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 6) + (j*cs_b); + + k_iter = (n-j-d_nr); //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) + + // Load b11 of size 4x6 and multiply with alpha + BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + + ///implement TRSM/// + + //extract a55 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 5)); + ymm13 = STRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); + + //extract a44 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 4)); + + //(row 5):FMA operations + ymm1 = _mm256_broadcast_ss((float const 
*)(a11 + 5*cs_a + 4*rs_a)); + ymm11 = _mm256_fnmadd_ps(ymm1, ymm13, ymm11); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*cs_a + 3*rs_a)); + ymm9 = _mm256_fnmadd_ps(ymm1, ymm13, ymm9); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*cs_a + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm13, ymm7); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm13, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm13, ymm3); + + ymm11 = STRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); + + //extract a33 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 3)); + + //(row 4):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*cs_a + 3*rs_a)); + ymm9 = _mm256_fnmadd_ps(ymm1, ymm11, ymm9); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*cs_a + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm11, ymm7); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm11, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm11, ymm3); + + ymm9 = STRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + + //extract a22 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*cs_a + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm9, ymm7); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm9, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm9, ymm3); + + ymm7 = STRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm7, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm7, ymm3); + + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack )); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm5, ymm3); + + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + ymm0 = _mm256_loadu_ps((float const *)b11); + ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x3F); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x3F); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x3F); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x3F); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x3F); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x3F); + + _mm256_storeu_ps((float *)b11, ymm3); + _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); + _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); + _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); + _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); + _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + + m_remainder -=6; + } + else if(5 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (j*cs_a) + 
(j*rs_a); + b10 = B + (j+d_nr)*cs_b + (m_remainder - 5); //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 5) + (j*cs_b); + + k_iter = (n-j-d_nr); //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) + + // Load b11 of size 4x6 and multiply with alpha + BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + + ///implement TRSM/// + + //extract a55 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 5)); + ymm13 = STRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); + + //extract a44 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 4)); + + //(row 5):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*cs_a + 4*rs_a)); + ymm11 = _mm256_fnmadd_ps(ymm1, ymm13, ymm11); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*cs_a + 3*rs_a)); + ymm9 = _mm256_fnmadd_ps(ymm1, ymm13, ymm9); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*cs_a + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm13, ymm7); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm13, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm13, ymm3); + + ymm11 = STRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); + + //extract a33 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 3)); + + //(row 4):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*cs_a + 3*rs_a)); + ymm9 = _mm256_fnmadd_ps(ymm1, ymm11, ymm9); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*cs_a + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm11, ymm7); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm11, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm11, ymm3); + + ymm9 = STRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + + //extract a22 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*cs_a + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm9, ymm7); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm9, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm9, ymm3); + + ymm7 = STRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm7, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm7, ymm3); + + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack )); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm5, ymm3); + + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + ymm0 = _mm256_loadu_ps((float const *)b11); + ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x1F); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x1F); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x1F); + ymm0 = 
_mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x1F); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x1F); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x1F); + + _mm256_storeu_ps((float *)b11, ymm3); + _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); + _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); + _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); + _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); + _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + + m_remainder -=5; + } + else if(4 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (j*cs_a) + (j*rs_a); + b10 = B + (j+d_nr)*cs_b + (m_remainder - 4); //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 4) + (j*cs_b); + + k_iter = (n-j-d_nr); //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) + + // Load b11 of size 4x6 and multiply with alpha + BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + + ///implement TRSM/// + + //extract a55 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 5)); + ymm13 = STRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); + + //extract a44 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 4)); + + //(row 5):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*cs_a + 4*rs_a)); + ymm11 = _mm256_fnmadd_ps(ymm1, ymm13, ymm11); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*cs_a + 3*rs_a)); + ymm9 = _mm256_fnmadd_ps(ymm1, ymm13, ymm9); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*cs_a + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm13, ymm7); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm13, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm13, ymm3); + + ymm11 = STRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); + + //extract a33 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 3)); + + //(row 4):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*cs_a + 3*rs_a)); + ymm9 = _mm256_fnmadd_ps(ymm1, ymm11, ymm9); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*cs_a + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm11, ymm7); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm11, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm11, ymm3); + + ymm9 = STRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + + //extract a22 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*cs_a + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm9, ymm7); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm9, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm9, ymm3); + + ymm7 = STRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm7, 
ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm7, ymm3); + + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack )); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm5, ymm3); + + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + ymm0 = _mm256_loadu_ps((float const *)b11); + ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x0F); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x0F); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x0F); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x0F); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x0F); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x0F); + + _mm256_storeu_ps((float *)b11, ymm3); + _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); + _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); + _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); + _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); + _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + + m_remainder -=4; + } + else if(3 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (j*cs_a) + (j*rs_a); + b10 = B + (j+d_nr)*cs_b + (m_remainder - 3); //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 3) + (j*cs_b); + + k_iter = (n-j-d_nr); //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) + + // Load b11 of size 4x6 and multiply with alpha + BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + + ///implement TRSM/// + + //extract a55 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 5)); + ymm13 = STRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); + + //extract a44 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 4)); + + //(row 5):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*cs_a + 4*rs_a)); + ymm11 = _mm256_fnmadd_ps(ymm1, ymm13, ymm11); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*cs_a + 3*rs_a)); + ymm9 = _mm256_fnmadd_ps(ymm1, ymm13, ymm9); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*cs_a + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm13, ymm7); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm13, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm13, ymm3); + + ymm11 = STRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); + + //extract a33 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 3)); + + //(row 4):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*cs_a + 3*rs_a)); + ymm9 = _mm256_fnmadd_ps(ymm1, ymm11, ymm9); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*cs_a + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm11, ymm7); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm11, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*cs_a)); + 
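+                    // ymm11 already holds the solved values for column 4 of this
+                    // B tile; the a11 entry broadcast just above (offset 4*cs_a)
+                    // is used below to subtract that contribution from ymm3
+                    // (column 0). This 3-element remainder path repeats the full
+                    // 8-element solve and relies on the 0x07 blend before the
+                    // stores so that only the three valid lanes are updated.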
ymm3 = _mm256_fnmadd_ps(ymm1, ymm11, ymm3); + + ymm9 = STRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + + //extract a22 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*cs_a + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm9, ymm7); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm9, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm9, ymm3); + + ymm7 = STRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm7, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm7, ymm3); + + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack )); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm5, ymm3); + + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + ymm0 = _mm256_loadu_ps((float const *)b11); + ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x07); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x07); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x07); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x07); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x07); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x07); + + _mm256_storeu_ps((float *)b11, ymm3); + _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); + _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); + _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); + _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); + _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + + m_remainder -=3; + } + else if(2 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (j*cs_a) + (j*rs_a); + b10 = B + (j+d_nr)*cs_b + (m_remainder - 2); //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 2) + (j*cs_b); + + k_iter = (n-j-d_nr); //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) + + // Load b11 of size 4x6 and multiply with alpha + BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + + ///implement TRSM/// + + //extract a55 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 5)); + ymm13 = STRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); + + //extract a44 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 4)); + + //(row 5):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*cs_a + 4*rs_a)); + ymm11 = _mm256_fnmadd_ps(ymm1, ymm13, ymm11); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*cs_a + 3*rs_a)); + ymm9 = _mm256_fnmadd_ps(ymm1, ymm13, ymm9); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*cs_a + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm13, 
ymm7); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm13, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm13, ymm3); + + ymm11 = STRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); + + //extract a33 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 3)); + + //(row 4):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*cs_a + 3*rs_a)); + ymm9 = _mm256_fnmadd_ps(ymm1, ymm11, ymm9); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*cs_a + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm11, ymm7); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm11, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm11, ymm3); + + ymm9 = STRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + + //extract a22 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*cs_a + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm9, ymm7); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm9, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm9, ymm3); + + ymm7 = STRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm7, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm7, ymm3); + + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack )); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm5, ymm3); + + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + ymm0 = _mm256_loadu_ps((float const *)b11); + ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x03); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x03); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x03); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x03); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x03); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x03); + + _mm256_storeu_ps((float *)b11, ymm3); + _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); + _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); + _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); + _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); + _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + + m_remainder -=2; + } + else if (1 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (j*cs_a) + (j*rs_a); + b10 = B + (j+d_nr)*cs_b + (m_remainder - 1); //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 1) + (j*cs_b); + + k_iter = (n-j-d_nr); //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + 
BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) + + // Load b11 of size 4x6 and multiply with alpha + BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + + ///implement TRSM/// + + //extract a55 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 5)); + ymm13 = STRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); + + //extract a44 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 4)); + + //(row 5):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*cs_a + 4*rs_a)); + ymm11 = _mm256_fnmadd_ps(ymm1, ymm13, ymm11); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*cs_a + 3*rs_a)); + ymm9 = _mm256_fnmadd_ps(ymm1, ymm13, ymm9); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*cs_a + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm13, ymm7); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm13, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm13, ymm3); + + ymm11 = STRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); + + //extract a33 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 3)); + + //(row 4):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*cs_a + 3*rs_a)); + ymm9 = _mm256_fnmadd_ps(ymm1, ymm11, ymm9); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*cs_a + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm11, ymm7); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm11, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm11, ymm3); + + ymm9 = STRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + + //extract a22 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*cs_a + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm9, ymm7); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm9, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm9, ymm3); + + ymm7 = STRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm7, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm7, ymm3); + + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack )); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm5, ymm3); + + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + ymm0 = _mm256_loadu_ps((float const *)b11); + ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x01); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x01); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x01); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x01); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x01); + ymm0 = 
_mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x01); + + _mm256_storeu_ps((float *)b11, ymm3); + _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); + _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); + _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); + _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); + _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + + m_remainder -=1; + } + } + } + + dim_t n_remainder = j + d_nr; + + /* + Reminder cases starts here: + a. Similar logic and code flow used in computing full block (6x8) + above holds for reminder cases too. + */ + + if(n_remainder >= 4) + { + a01 = L + (n_remainder - 4)*rs_a + n_remainder*cs_a; //pointer to block of A to be used in GEMM + a11 = L + (n_remainder - 4)*cs_a + (n_remainder - 4)*rs_a; //pointer to block of A to be used for TRSM + + float *ptr_a10_dup = D_A_pack; + + dim_t p_lda = (n-n_remainder); // packed leading dimension + // perform copy of A to packed buffer D_A_pack + + if(transa) + { + for(dim_t x =0;x < p_lda;x+=d_nr) + { + __m128 xmm0, xmm1, xmm2, xmm3; + __m128 xmm4, xmm5, xmm6, xmm7; + __m128 xmm8, xmm9; + + xmm0 = _mm_loadu_ps((float const *)(a01)); + xmm1 = _mm_loadu_ps((float const *)(a01 + cs_a)); + xmm2 = _mm_loadu_ps((float const *)(a01 + cs_a * 2)); + xmm3 = _mm_loadu_ps((float const *)(a01 + cs_a * 3)); + + xmm4 = _mm_unpacklo_ps(xmm0, xmm1); + xmm5 = _mm_unpacklo_ps(xmm2, xmm3); + xmm6 = _mm_shuffle_ps(xmm4,xmm5,0x44); + xmm7 = _mm_shuffle_ps(xmm4,xmm5,0xEE); + + xmm0 = _mm_unpackhi_ps(xmm0, xmm1); + xmm1 = _mm_unpackhi_ps(xmm2, xmm3); + xmm8 = _mm_shuffle_ps(xmm0,xmm1,0x44); + xmm9 = _mm_shuffle_ps(xmm0,xmm1,0xEE); + + _mm_storeu_ps((float *)(ptr_a10_dup), xmm6); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda), xmm7); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda*2), xmm8); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda*3), xmm9); + + xmm0 = _mm_loadu_ps((float const *)(a01 + cs_a * 4)); + xmm1 = _mm_loadu_ps((float const *)(a01 + cs_a * 5)); + + xmm4 = _mm_unpacklo_ps(xmm0, xmm1); + xmm5 = _mm_broadcast_ss((float const *)&zero); + xmm6 = _mm_shuffle_ps(xmm4,xmm5,0x44); + xmm7 = _mm_shuffle_ps(xmm4,xmm5,0xEE); + + xmm0 = _mm_unpackhi_ps(xmm0, xmm1); + xmm1 = _mm_broadcast_ss((float const *)&zero); + xmm8 = _mm_shuffle_ps(xmm0,xmm1,0x44); + xmm9 = _mm_shuffle_ps(xmm0,xmm1,0xEE); + + _mm_storel_pi((__m64 *)(ptr_a10_dup + 4), xmm6); + _mm_storel_pi((__m64 *)(ptr_a10_dup + 4 + p_lda), xmm7); + _mm_storel_pi((__m64 *)(ptr_a10_dup + 4 + p_lda*2), xmm8); + _mm_storel_pi((__m64 *)(ptr_a10_dup + 4 + p_lda*3), xmm9); + + a01 += d_nr*cs_a; + ptr_a10_dup += d_nr; + } + } + else + { + __m128 xmm0,xmm1; + dim_t loop_count = (n-n_remainder)/6; + for(dim_t i=0; i < loop_count; i++) + { + xmm1 = _mm_broadcast_ss((float *)&zero); + + xmm0 = _mm_loadu_ps((float *)(a01 + rs_a * 0 + i*6)); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda * 0 + i*6), xmm0); + xmm0 = _mm_loadl_pi(xmm1,(__m64 *)(a01 + rs_a * 0 + 4 + i*6)); + _mm_storel_pi((__m64 *)(ptr_a10_dup + p_lda * 0 + 4 + i*6),xmm0); + + xmm0 = _mm_loadu_ps((float const *)(a01 + rs_a * 1 + i*6)); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda * 1 + i*6), xmm0); + xmm0 = _mm_loadl_pi(xmm1,(__m64 *)(a01 + rs_a * 1 + 4 + i*6)); + _mm_storel_pi((__m64 *)(ptr_a10_dup + p_lda * 1 + 4 + i*6),xmm0); + + xmm0 = _mm_loadu_ps((float const *)(a01 + rs_a * 2 + i*6)); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda * 2 + i*6), xmm0); + xmm0 = _mm_loadl_pi(xmm1,(__m64 *)(a01 + rs_a * 2 + 4 + i*6)); + 
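+                // the storel_pi below writes the 2-float tail of row 2; each pass
+                // of this loop copies rows 0..3 of a01 into D_A_pack in 6-element
+                // chunks (a 4-float vector plus a 2-float tail), keeping packed
+                // leading dimension p_lda for the GEMM macros that read the panel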
_mm_storel_pi((__m64 *)(ptr_a10_dup + p_lda * 2 + 4 + i*6),xmm0); + + xmm0 = _mm_loadu_ps((float const *)(a01 + rs_a * 3 + i*6)); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda * 3 + i*6), xmm0); + xmm0 = _mm_loadl_pi(xmm1,(__m64 *)(a01 + rs_a * 3 + 4 + i*6)); + _mm_storel_pi((__m64 *)(ptr_a10_dup + p_lda * 3 + 4 + i*6),xmm0); + } + } + + ymm4 = _mm256_broadcast_ss((float const *)&ones); + if(!is_unitdiag) + { + if(transa) + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_ss((float const *)(a11)); + ymm1 = _mm256_broadcast_ss((float const *)(a11+ cs_a*1 + 1)); + ymm2 = _mm256_broadcast_ss((float const *)(a11+ cs_a*2 + 2)); + ymm3 = _mm256_broadcast_ss((float const *)(a11+ cs_a*3 + 3)); + } + else + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_ss((float const *)(a11)); + ymm1 = _mm256_broadcast_ss((float const *)(a11+ rs_a*1 + 1)); + ymm2 = _mm256_broadcast_ss((float const *)(a11+ rs_a*2 + 2)); + ymm3 = _mm256_broadcast_ss((float const *)(a11+ rs_a*3 + 3)); + } + + ymm0 = _mm256_unpacklo_ps(ymm0, ymm1); + ymm1 = _mm256_unpacklo_ps(ymm2, ymm3); + + ymm1 = _mm256_blend_ps(ymm0, ymm1, 0x0C); + #ifdef BLIS_DISABLE_TRSM_PREINVERSION + ymm4 = ymm1; + #endif + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + ymm4 = _mm256_div_ps(ymm4, ymm1); + #endif + } + _mm256_storeu_ps((float *)(d11_pack), ymm4); + + for(i = (m-d_mr); (i+1) > 0; i -= d_mr) //loop along 'M' direction + { + a01 = D_A_pack; + a11 = L + (n_remainder - 4)*cs_a + (n_remainder - 4)*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (i) + (n_remainder - 4)*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_4nx16m(a01,b10,cs_b,p_lda,k_iter) + + BLIS_PRE_STRSM_SMALL_4x16(AlphaVal,b11,cs_b) + + ///implement TRSM/// + + //extract a33 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 3)); + + ymm9 = STRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + ymm10 = STRSM_SMALL_DIV_OR_SCALE(ymm10, ymm0); + + //extract a22 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*cs_a + 2*rs_a)); + + ymm7 = _mm256_fnmadd_ps(ymm1, ymm9, ymm7); + ymm8 = _mm256_fnmadd_ps(ymm1, ymm10, ymm8); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*cs_a + 1*rs_a)); + + ymm5 = _mm256_fnmadd_ps(ymm1, ymm9, ymm5); + ymm6 = _mm256_fnmadd_ps(ymm1, ymm10, ymm6); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*cs_a)); + + ymm3 = _mm256_fnmadd_ps(ymm1, ymm9, ymm3); + ymm4 = _mm256_fnmadd_ps(ymm1, ymm10, ymm4); + + ymm7 = STRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + ymm8 = STRSM_SMALL_DIV_OR_SCALE(ymm8, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*cs_a + 1*rs_a)); + + ymm5 = _mm256_fnmadd_ps(ymm1, ymm7, ymm5); + ymm6 = _mm256_fnmadd_ps(ymm1, ymm8, ymm6); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*cs_a)); + + ymm3 = _mm256_fnmadd_ps(ymm1, ymm7, ymm3); + ymm4 = _mm256_fnmadd_ps(ymm1, ymm8, ymm4); + + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + ymm6 = STRSM_SMALL_DIV_OR_SCALE(ymm6, ymm0); + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack )); + + //(Row 1): FMA operations + 
ymm1 = _mm256_broadcast_ss((float const *)(a11 + cs_a)); + + ymm3 = _mm256_fnmadd_ps(ymm1, ymm5, ymm3); + ymm4 = _mm256_fnmadd_ps(ymm1, ymm6, ymm4); + + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm4 = STRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); + + _mm256_storeu_ps((float *)b11, ymm3); + _mm256_storeu_ps((float *)(b11 + 8), ymm4); + _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); + _mm256_storeu_ps((float *)(b11 + cs_b + 8), ymm6); + _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); + _mm256_storeu_ps((float *)(b11 + cs_b*2 + 8), ymm8); + _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); + _mm256_storeu_ps((float *)(b11 + cs_b*3 + 8), ymm10); + } + + dim_t m_remainder = i + d_mr; + if(m_remainder >= 8) + { + a01 = D_A_pack; + a11 = L + (n_remainder - 4)*cs_a + (n_remainder - 4)*rs_a; //pointer to block of A to be used for TRSM + b10 = B + (m_remainder - 8) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 8) + (n_remainder - 4)*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_4nx8m(a01,b10,cs_b,p_lda,k_iter) + + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_ps((float const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm3 = _mm256_fmsub_ps(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_fmsub_ps(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm7 = _mm256_fmsub_ps(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 + + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm9 = _mm256_fmsub_ps(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 + + ///implement TRSM/// + + //extract a33 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 3)); + ymm9 = STRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + + //extract a22 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*cs_a + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm9, ymm7); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm9, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm9, ymm3); + + ymm7 = STRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm7, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm7, ymm3); + + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack )); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm5, ymm3); + + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + _mm256_storeu_ps((float *)b11, ymm3); + _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); + _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); + _mm256_storeu_ps((float *)(b11 + cs_b*3), 
ymm9); + + m_remainder -=8; + } + + if(m_remainder) + { + if(7 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (n_remainder - 4)*cs_a + (n_remainder - 4)*rs_a; //pointer to block of A to be used for TRSM + b10 = B + (m_remainder - 7) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 7) + (n_remainder - 4)*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_4nx8m(a01,b10,cs_b,p_lda,k_iter) + + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); //register to hold alpha + + __m128 xmm0,xmm1; + xmm0 = _mm_loadu_ps((float const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + + xmm0 = _mm_broadcast_ss((float *)(b11 + 6)); + xmm1 = _mm_loadl_pi(xmm0,(__m64 *)(b11 + 4)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 1); + + ymm3 = _mm256_fmsub_ps(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + + xmm0 = _mm_broadcast_ss((float *)(b11 + 6 + cs_b)); + xmm1 = _mm_loadl_pi(xmm0,(__m64 *)(b11 + 4 + cs_b)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 1); + + ymm5 = _mm256_fmsub_ps(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + + xmm0 = _mm_broadcast_ss((float *)(b11 + 6 + cs_b*2)); + xmm1 = _mm_loadl_pi(xmm0,(__m64 *)(b11 + 4 + cs_b*2)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 1); + + ymm7 = _mm256_fmsub_ps(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 + + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b*3)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + + xmm0 = _mm_broadcast_ss((float *)(b11 + 6 + cs_b*3)); + xmm1 = _mm_loadl_pi(xmm0,(__m64 *)(b11 + 4 + cs_b*3)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 1); + + ymm9 = _mm256_fmsub_ps(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 + + ///implement TRSM/// + + //extract a33 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 3)); + ymm9 = STRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + + //extract a22 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*cs_a + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm9, ymm7); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm9, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm9, ymm3); + + ymm7 = STRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm7, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm7, ymm3); + + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack )); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm5, ymm3); + + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + _mm_storeu_ps((float 
*)(b11),_mm256_extractf128_ps(ymm3, 0)); + _mm_storel_pi((__m64 *)(b11 + 4),_mm256_extractf128_ps(ymm3, 1)); + _mm_store_ss((float *)(b11 + 6),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm3,ymm3), 1)); + + _mm_storeu_ps((float *)(b11 + cs_b),_mm256_extractf128_ps(ymm5, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b),_mm256_extractf128_ps(ymm5, 1)); + _mm_store_ss((float *)(b11 + 6 + cs_b),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm5,ymm5), 1)); + + _mm_storeu_ps((float *)(b11 + cs_b*2),_mm256_extractf128_ps(ymm7, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*2),_mm256_extractf128_ps(ymm7, 1)); + _mm_store_ss((float *)(b11 + 6 + cs_b*2),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm7,ymm7), 1)); + + _mm_storeu_ps((float *)(b11 + cs_b*3),_mm256_extractf128_ps(ymm9, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*3),_mm256_extractf128_ps(ymm9, 1)); + _mm_store_ss((float *)(b11 + 6 + cs_b*3),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm9,ymm9), 1)); + + m_remainder -=7; + } + else if(6 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (n_remainder - 4)*cs_a + (n_remainder - 4)*rs_a; //pointer to block of A to be used for TRSM + b10 = B + (m_remainder - 6) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 6) + (n_remainder - 4)*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_4nx8m(a01,b10,cs_b,p_lda,k_iter) + + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); //register to hold alpha + + __m128 xmm0,xmm1; + xmm0 = _mm_loadu_ps((float const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + + xmm1 = _mm_loadl_pi(xmm0,(__m64 *)(b11 + 4)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 1); + + ymm3 = _mm256_fmsub_ps(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + + xmm1 = _mm_loadl_pi(xmm0,(__m64 *)(b11 + 4 + cs_b)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 1); + + ymm5 = _mm256_fmsub_ps(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + + xmm1 = _mm_loadl_pi(xmm0,(__m64 *)(b11 + 4 + cs_b*2)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 1); + + ymm7 = _mm256_fmsub_ps(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 + + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b*3)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + + xmm1 = _mm_loadl_pi(xmm0,(__m64 *)(b11 + 4 + cs_b*3)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 1); + + ymm9 = _mm256_fmsub_ps(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 + + ///implement TRSM/// + + //extract a33 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 3)); + ymm9 = STRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + + //extract a22 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*cs_a + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm9, ymm7); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm9, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm9, ymm3); + + ymm7 
= STRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm7, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm7, ymm3); + + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack )); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm5, ymm3); + + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + _mm_storeu_ps((float *)(b11),_mm256_extractf128_ps(ymm3, 0)); + _mm_storel_pi((__m64 *)(b11 + 4),_mm256_extractf128_ps(ymm3, 1)); + _mm_storeu_ps((float *)(b11 + cs_b),_mm256_extractf128_ps(ymm5, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b),_mm256_extractf128_ps(ymm5, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*2),_mm256_extractf128_ps(ymm7, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*2),_mm256_extractf128_ps(ymm7, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*3),_mm256_extractf128_ps(ymm9, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*3),_mm256_extractf128_ps(ymm9, 1)); + + m_remainder -=6; + } + else if(5 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (n_remainder - 4)*cs_a + (n_remainder - 4)*rs_a; //pointer to block of A to be used for TRSM + b10 = B + (m_remainder - 5) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 5) + (n_remainder - 4)*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_4nx8m(a01,b10,cs_b,p_lda,k_iter) + + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); //register to hold alpha + + __m128 xmm0; + ymm0 = _mm256_broadcast_ss((float const *)(b11 + 4)); + xmm0 = _mm_loadu_ps((float const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + ymm3 = _mm256_fmsub_ps(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + + ymm0 = _mm256_broadcast_ss((float const *)(b11 + 4 + cs_b)); + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + ymm5 = _mm256_fmsub_ps(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + + ymm0 = _mm256_broadcast_ss((float const *)(b11 + 4 + cs_b*2)); + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + ymm7 = _mm256_fmsub_ps(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 + + ymm0 = _mm256_broadcast_ss((float const *)(b11 + 4 + cs_b*3)); + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b*3)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + ymm9 = _mm256_fmsub_ps(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 + + ///implement TRSM/// + + //extract a33 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 3)); + ymm9 = STRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + + //extract a22 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*cs_a + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm9, ymm7); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*cs_a + 1*rs_a)); + ymm5 = 
_mm256_fnmadd_ps(ymm1, ymm9, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm9, ymm3); + + ymm7 = STRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm7, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm7, ymm3); + + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack )); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm5, ymm3); + + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + _mm_storeu_ps((float *)(b11),_mm256_extractf128_ps(ymm3, 0)); + _mm_store_ss((float *)(b11 + 4),_mm256_extractf128_ps(ymm3, 1)); + _mm_storeu_ps((float *)(b11 + cs_b),_mm256_extractf128_ps(ymm5, 0)); + _mm_store_ss((float *)(b11 + 4 + cs_b),_mm256_extractf128_ps(ymm5, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*2),_mm256_extractf128_ps(ymm7, 0)); + _mm_store_ss((float *)(b11 + 4 + cs_b*2),_mm256_extractf128_ps(ymm7, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*3),_mm256_extractf128_ps(ymm9, 0)); + _mm_store_ss((float *)(b11 + 4 + cs_b*3),_mm256_extractf128_ps(ymm9, 1)); + + m_remainder -=5; + } + else if(4 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (n_remainder - 4)*cs_a + (n_remainder - 4)*rs_a; //pointer to block of A to be used for TRSM + b10 = B + (m_remainder - 4) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 4) + (n_remainder - 4)*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_4nx8m(a01,b10,cs_b,p_lda,k_iter) + + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); //register to hold alpha + + __m128 xmm0; + xmm0 = _mm_loadu_ps((float const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + ymm3 = _mm256_fmsub_ps(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + ymm5 = _mm256_fmsub_ps(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + ymm7 = _mm256_fmsub_ps(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 + + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b*3)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + ymm9 = _mm256_fmsub_ps(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 + + ///implement TRSM/// + + //extract a33 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 3)); + ymm9 = STRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + + //extract a22 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*cs_a + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm9, ymm7); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm9, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*cs_a)); + ymm3 = 
_mm256_fnmadd_ps(ymm1, ymm9, ymm3); + + ymm7 = STRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm7, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm7, ymm3); + + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack )); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm5, ymm3); + + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + _mm_storeu_ps((float *)(b11),_mm256_extractf128_ps(ymm3, 0)); + _mm_storeu_ps((float *)(b11 + cs_b),_mm256_extractf128_ps(ymm5, 0)); + _mm_storeu_ps((float *)(b11 + cs_b*2),_mm256_extractf128_ps(ymm7, 0)); + _mm_storeu_ps((float *)(b11 + cs_b*3),_mm256_extractf128_ps(ymm9, 0)); + + m_remainder -=4; + } + else if(3 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (n_remainder - 4)*cs_a + (n_remainder - 4)*rs_a; //pointer to block of A to be used for TRSM + b10 = B + (m_remainder - 3) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 3) + (n_remainder - 4)*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_4nx8m(a01,b10,cs_b,p_lda,k_iter) + + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); //register to hold alpha + + __m128 xmm0,xmm1; + xmm1 = _mm_broadcast_ss((float *)(b11 + 2)); + xmm0 = _mm_loadl_pi(xmm1,(__m64 *)(b11)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + ymm3 = _mm256_fmsub_ps(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + + xmm1 = _mm_broadcast_ss((float *)(b11 + cs_b + 2)); + xmm0 = _mm_loadl_pi(xmm1,(__m64 *)(b11 + cs_b)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + ymm5 = _mm256_fmsub_ps(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + + xmm1 = _mm_broadcast_ss((float *)(b11 + cs_b*2 + 2)); + xmm0 = _mm_loadl_pi(xmm1,(__m64 *)(b11 + cs_b*2)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + ymm7 = _mm256_fmsub_ps(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 + + xmm1 = _mm_broadcast_ss((float *)(b11 + cs_b*3 + 2)); + xmm0 = _mm_loadl_pi(xmm1,(__m64 *)(b11 + cs_b*3)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + ymm9 = _mm256_fmsub_ps(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 + + ///implement TRSM/// + + //extract a33 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 3)); + ymm9 = STRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + + //extract a22 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*cs_a + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm9, ymm7); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm9, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm9, ymm3); + + ymm7 = STRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm7, ymm5); + + ymm1 = 
_mm256_broadcast_ss((float const *)(a11 + 2*cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm7, ymm3); + + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack )); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm5, ymm3); + + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + xmm0 = _mm256_extractf128_ps(ymm3, 0); + _mm_storel_pi((__m64 *)(b11),xmm0); + _mm_store_ss((float *)(b11+2),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm3,ymm3), 0)); + + xmm0 = _mm256_extractf128_ps(ymm5, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b),xmm0); + _mm_store_ss((float *)(b11+ 2 + cs_b),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm5,ymm5), 0)); + + xmm0 = _mm256_extractf128_ps(ymm7, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*2),xmm0); + _mm_store_ss((float *)(b11 + 2 + cs_b*2),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm7,ymm7), 0)); + + xmm0 = _mm256_extractf128_ps(ymm9, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*3),xmm0); + _mm_store_ss((float *)(b11 + 2 + cs_b*3),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm9,ymm9), 0)); + + m_remainder -=3; + } + else if(2 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (n_remainder - 4)*cs_a + (n_remainder - 4)*rs_a; //pointer to block of A to be used for TRSM + b10 = B + (m_remainder - 2) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 2) + (n_remainder - 4)*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_4nx8m(a01,b10,cs_b,p_lda,k_iter) + + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); //register to hold alpha + + __m128 xmm0,xmm1; + xmm1 = _mm_broadcast_ss((float *)&zero); + xmm0 = _mm_loadl_pi(xmm1,(__m64 *)(b11)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + ymm3 = _mm256_fmsub_ps(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + + xmm0 = _mm_loadl_pi(xmm1,(__m64 *)(b11 + cs_b)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + ymm5 = _mm256_fmsub_ps(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + + xmm0 = _mm_loadl_pi(xmm1,(__m64 *)(b11 + cs_b*2)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + ymm7 = _mm256_fmsub_ps(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 + + xmm0 = _mm_loadl_pi(xmm1,(__m64 *)(b11 + cs_b*3)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + ymm9 = _mm256_fmsub_ps(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 + + ///implement TRSM/// + + //extract a33 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 3)); + ymm9 = STRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + + //extract a22 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*cs_a + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm9, ymm7); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm9, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm9, ymm3); + + ymm7 = STRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm7, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 
+ 2*cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm7, ymm3); + + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack )); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm5, ymm3); + + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + xmm0 = _mm256_extractf128_ps(ymm3, 0); + _mm_storel_pi((__m64 *)(b11),xmm0); + + xmm0 = _mm256_extractf128_ps(ymm5, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b),xmm0); + + xmm0 = _mm256_extractf128_ps(ymm7, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*2),xmm0); + + xmm0 = _mm256_extractf128_ps(ymm9, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*3),xmm0); + + m_remainder -=2; + } + else if (1 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (n_remainder - 4)*cs_a + (n_remainder - 4)*rs_a; //pointer to block of A to be used for TRSM + b10 = B + (m_remainder - 1) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 1) + (n_remainder - 4)*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_4nx8m(a01,b10,cs_b,p_lda,k_iter) + + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_broadcast_ss((float const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm3 = _mm256_fmsub_ps(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + + ymm0 = _mm256_broadcast_ss((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_fmsub_ps(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + + ymm0 = _mm256_broadcast_ss((float const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm7 = _mm256_fmsub_ps(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 + + ymm0 = _mm256_broadcast_ss((float const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm9 = _mm256_fmsub_ps(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 + + ///implement TRSM/// + + //extract a33 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 3)); + ymm9 = STRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + + //extract a22 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*cs_a + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm9, ymm7); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm9, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm9, ymm3); + + ymm7 = STRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm7, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm7, ymm3); + + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack )); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm5, ymm3); + + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + _mm_store_ss((b11 + cs_b * 0), _mm256_extractf128_ps(ymm3, 0)); + _mm_store_ss((b11 + cs_b * 1), 
_mm256_extractf128_ps(ymm5, 0)); + _mm_store_ss((b11 + cs_b * 2), _mm256_extractf128_ps(ymm7, 0)); + _mm_store_ss((b11 + cs_b * 3), _mm256_extractf128_ps(ymm9, 0)); + + m_remainder -=1; + } + } + n_remainder -= 4; + } + + if(n_remainder == 3) + { + a01 = L + (n_remainder - 3)*rs_a + n_remainder*cs_a; //pointer to block of A to be used in GEMM + a11 = L + (n_remainder - 3)*cs_a + (n_remainder - 3)*rs_a; //pointer to block of A to be used for TRSM + + float *ptr_a10_dup = D_A_pack; + + dim_t p_lda = (n-n_remainder); // packed leading dimension + // perform copy of A to packed buffer D_A_pack + + if(transa) + { + __m128 xmm0, xmm1, xmm2, xmm3; + __m128 xmm4, xmm5, xmm6, xmm7; + __m128 xmm8, xmm9; + + for(dim_t x =0;x < p_lda;x+=d_nr) + { + xmm0 = _mm_loadu_ps((float const *)(a01)); + xmm1 = _mm_loadu_ps((float const *)(a01 + cs_a)); + xmm2 = _mm_loadu_ps((float const *)(a01 + cs_a * 2)); + xmm3 = _mm_loadu_ps((float const *)(a01 + cs_a * 3)); + + xmm4 = _mm_unpacklo_ps(xmm0, xmm1); + xmm5 = _mm_unpacklo_ps(xmm2, xmm3); + xmm6 = _mm_shuffle_ps(xmm4,xmm5,0x44); + xmm7 = _mm_shuffle_ps(xmm4,xmm5,0xEE); + + xmm0 = _mm_unpackhi_ps(xmm0, xmm1); + xmm1 = _mm_unpackhi_ps(xmm2, xmm3); + xmm8 = _mm_shuffle_ps(xmm0,xmm1,0x44); + xmm9 = _mm_shuffle_ps(xmm0,xmm1,0xEE); + + _mm_storeu_ps((float *)(ptr_a10_dup), xmm6); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda), xmm7); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda*2), xmm8); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda*3), xmm9); + + xmm0 = _mm_loadu_ps((float const *)(a01 + cs_a * 4)); + xmm1 = _mm_loadu_ps((float const *)(a01 + cs_a * 5)); + + xmm4 = _mm_unpacklo_ps(xmm0, xmm1); + xmm5 = _mm_broadcast_ss((float const *)&zero); + xmm6 = _mm_shuffle_ps(xmm4,xmm5,0x44); + xmm7 = _mm_shuffle_ps(xmm4,xmm5,0xEE); + + xmm0 = _mm_unpackhi_ps(xmm0, xmm1); + xmm1 = _mm_broadcast_ss((float const *)&zero); + xmm8 = _mm_shuffle_ps(xmm0,xmm1,0x44); + xmm9 = _mm_shuffle_ps(xmm0,xmm1,0xEE); + + _mm_storel_pi((__m64 *)(ptr_a10_dup + 4), xmm6); + _mm_storel_pi((__m64 *)(ptr_a10_dup + 4 + p_lda), xmm7); + _mm_storel_pi((__m64 *)(ptr_a10_dup + 4 + p_lda*2), xmm8); + _mm_storel_pi((__m64 *)(ptr_a10_dup + 4 + p_lda*3), xmm9); + + a01 += d_nr*cs_a; + ptr_a10_dup += d_nr; + } + + } + else + { + __m128 xmm0,xmm1; + dim_t loop_count = (n-n_remainder)/6; + for(dim_t i=0; i < loop_count; i++) + { + xmm1 = _mm_broadcast_ss((float *)&zero); + + xmm0 = _mm_loadu_ps((float *)(a01 + rs_a * 0 + i*6)); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda * 0 + i*6), xmm0); + xmm0 = _mm_loadl_pi(xmm1,(__m64 *)(a01 + rs_a * 0 + 4 + i*6)); + _mm_storel_pi((__m64 *)(ptr_a10_dup + p_lda * 0 + 4 + i*6),xmm0); + + xmm0 = _mm_loadu_ps((float const *)(a01 + rs_a * 1 + i*6)); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda * 1 + i*6), xmm0); + xmm0 = _mm_loadl_pi(xmm1,(__m64 *)(a01 + rs_a * 1 + 4 + i*6)); + _mm_storel_pi((__m64 *)(ptr_a10_dup + p_lda * 1 + 4 + i*6),xmm0); + + xmm0 = _mm_loadu_ps((float const *)(a01 + rs_a * 2 + i*6)); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda * 2 + i*6), xmm0); + xmm0 = _mm_loadl_pi(xmm1,(__m64 *)(a01 + rs_a * 2 + 4 + i*6)); + _mm_storel_pi((__m64 *)(ptr_a10_dup + p_lda * 2 + 4 + i*6),xmm0); + } + } + + ymm4 = _mm256_broadcast_ss((float const *)&ones); + if(!is_unitdiag) + { + if(transa) + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_ss((float const *)(a11)); + ymm1 = _mm256_broadcast_ss((float const *)(a11+ cs_a*1 + 1)); + ymm2 = _mm256_broadcast_ss((float const *)(a11+ cs_a*2 + 2)); + ymm3 = _mm256_broadcast_ss((float const *)&ones); + } + 
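+            // the transa branch above and the else branch below both pad the
+            // fourth diagonal lane with 1.0 (&ones), since this 3-column tile has
+            // only three diagonal entries; when BLIS_ENABLE_TRSM_PREINVERSION is
+            // defined the div_ps further below stores reciprocals in d11_pack so
+            // STRSM_SMALL_DIV_OR_SCALE can scale instead of divide, and the
+            // padded lane stays a harmless 1.0 either way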
else + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_ss((float const *)(a11)); + ymm1 = _mm256_broadcast_ss((float const *)(a11+ rs_a*1 + 1)); + ymm2 = _mm256_broadcast_ss((float const *)(a11+ rs_a*2 + 2)); + ymm3 = _mm256_broadcast_ss((float const *)&ones); + } + + ymm0 = _mm256_unpacklo_ps(ymm0, ymm1); + ymm1 = _mm256_unpacklo_ps(ymm2, ymm3); + + ymm1 = _mm256_blend_ps(ymm0, ymm1, 0x0C); + #ifdef BLIS_DISABLE_TRSM_PREINVERSION + ymm4 = ymm1; + #endif + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + ymm4 = _mm256_div_ps(ymm4, ymm1); + #endif + } + _mm256_storeu_ps((float *)(d11_pack), ymm4); + + for(i = (m-d_mr); (i+1) > 0; i -= d_mr) //loop along 'M' direction + { + a01 = D_A_pack; + a11 = L + (n_remainder - 3)*cs_a + (n_remainder - 3)*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (i) + (n_remainder - 3)*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_3nx16m(a01,b10,cs_b,p_lda,k_iter) + + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); + + ymm0 = _mm256_loadu_ps((float const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_ps((float const *)(b11 + 8)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] + + ymm3 = _mm256_fmsub_ps(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + ymm4 = _mm256_fmsub_ps(ymm1, ymm15, ymm4); //B11[4-7][0] * alpha-= ymm1 + + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm1 = _mm256_loadu_ps((float const *)(b11 + cs_b + 8)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] + + ymm5 = _mm256_fmsub_ps(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + ymm6 = _mm256_fmsub_ps(ymm1, ymm15, ymm6); //B11[4-7][1] * alpha -= ymm3 + + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm1 = _mm256_loadu_ps((float const *)(b11 + cs_b*2 + 8)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] + + ymm7 = _mm256_fmsub_ps(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 + ymm8 = _mm256_fmsub_ps(ymm1, ymm15, ymm8); //B11[4-7][2] * alpha -= ymm5 + + ///implement TRSM/// + + //extract a22 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); + + ymm7 = STRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + ymm8 = STRSM_SMALL_DIV_OR_SCALE(ymm8, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*cs_a + 1*rs_a)); + + ymm5 = _mm256_fnmadd_ps(ymm1, ymm7, ymm5); + ymm6 = _mm256_fnmadd_ps(ymm1, ymm8, ymm6); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*cs_a)); + + ymm3 = _mm256_fnmadd_ps(ymm1, ymm7, ymm3); + ymm4 = _mm256_fnmadd_ps(ymm1, ymm8, ymm4); + + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + ymm6 = STRSM_SMALL_DIV_OR_SCALE(ymm6, ymm0); + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack )); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + cs_a)); + + ymm3 = _mm256_fnmadd_ps(ymm1, ymm5, ymm3); + ymm4 = _mm256_fnmadd_ps(ymm1, ymm6, ymm4); + + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm4 = STRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); + + _mm256_storeu_ps((float *)b11, ymm3); + _mm256_storeu_ps((float *)(b11 + 8), ymm4); + _mm256_storeu_ps((float *)(b11 + 
cs_b), ymm5); + _mm256_storeu_ps((float *)(b11 + cs_b + 8), ymm6); + _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); + _mm256_storeu_ps((float *)(b11 + cs_b*2 + 8), ymm8); + } + + dim_t m_remainder = i + d_mr; + if(m_remainder >= 8) + { + a01 = D_A_pack; + a11 = L + (n_remainder - 3)*cs_a + (n_remainder - 3)*rs_a; //pointer to block of A to be used for TRSM + b10 = B + (m_remainder - 8) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 8) + (n_remainder - 3)*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_3nx8m(a01,b10,cs_b,p_lda,k_iter) + + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_ps((float const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm3 = _mm256_fmsub_ps(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_fmsub_ps(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm7 = _mm256_fmsub_ps(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 + + ///implement TRSM/// + + //extract a22 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); + ymm7 = STRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm7, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm7, ymm3); + + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack )); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm5, ymm3); + + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + _mm256_storeu_ps((float *)b11, ymm3); + _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); + _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); + + m_remainder -=8; + } + if(m_remainder) + { + if(7 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (n_remainder - 3)*cs_a + (n_remainder - 3)*rs_a; //pointer to block of A to be used for TRSM + b10 = B + (m_remainder - 7) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 7) + (n_remainder - 3)*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_3nx8m(a01,b10,cs_b,p_lda,k_iter) + + BLIS_PRE_STRSM_SMALL_3N_7M(AlphaVal,b11,cs_b) + + ///implement TRSM/// + //extract a22 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); + ymm7 = STRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm7, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm7, 
ymm3); + + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack )); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm5, ymm3); + + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + BLIS_POST_STRSM_SMALL_3N_7M(b11,cs_b) + + m_remainder -=7; + } + else if(6 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (n_remainder - 3)*cs_a + (n_remainder - 3)*rs_a; //pointer to block of A to be used for TRSM + b10 = B + (m_remainder - 6) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 6) + (n_remainder - 3)*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_3nx8m(a01,b10,cs_b,p_lda,k_iter) + + BLIS_PRE_STRSM_SMALL_3N_6M(AlphaVal,b11,cs_b) + + ///implement TRSM/// + //extract a22 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); + ymm7 = STRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm7, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm7, ymm3); + + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack )); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm5, ymm3); + + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + BLIS_POST_STRSM_SMALL_3N_6M(b11,cs_b) + + m_remainder -=6; + } + else if(5 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (n_remainder - 3)*cs_a + (n_remainder - 3)*rs_a; //pointer to block of A to be used for TRSM + b10 = B + (m_remainder - 5) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 5) + (n_remainder - 3)*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_3nx8m(a01,b10,cs_b,p_lda,k_iter) + + BLIS_PRE_STRSM_SMALL_3N_5M(AlphaVal,b11,cs_b) + + ///implement TRSM/// + //extract a22 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); + ymm7 = STRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm7, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm7, ymm3); + + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack )); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm5, ymm3); + + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + BLIS_POST_STRSM_SMALL_3N_5M(b11,cs_b) + + m_remainder -=5; + + } + else if(4 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (n_remainder - 3)*cs_a + (n_remainder - 3)*rs_a; 
//pointer to block of A to be used for TRSM + b10 = B + (m_remainder - 4) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 4) + (n_remainder - 3)*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_3nx8m(a01,b10,cs_b,p_lda,k_iter) + + BLIS_PRE_STRSM_SMALL_3N_4M(AlphaVal,b11,cs_b) + + ///implement TRSM/// + //extract a22 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); + ymm7 = STRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm7, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm7, ymm3); + + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack )); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm5, ymm3); + + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + BLIS_POST_STRSM_SMALL_3N_4M(b11,cs_b) + + m_remainder -=4; + } + else if(3 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (n_remainder - 3)*cs_a + (n_remainder - 3)*rs_a; //pointer to block of A to be used for TRSM + b10 = B + (m_remainder - 3) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 3) + (n_remainder - 3)*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_3nx8m(a01,b10,cs_b,p_lda,k_iter) + + BLIS_PRE_STRSM_SMALL_3N_3M(AlphaVal,b11,cs_b) + + ///implement TRSM/// + //extract a22 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); + ymm7 = STRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm7, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm7, ymm3); + + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack )); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm5, ymm3); + + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + BLIS_POST_STRSM_SMALL_3N_3M(b11,cs_b) + + m_remainder -=3; + } + else if(2 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (n_remainder - 3)*cs_a + (n_remainder - 3)*rs_a; //pointer to block of A to be used for TRSM + b10 = B + (m_remainder - 2) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 2) + (n_remainder - 3)*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + 
BLIS_STRSM_SMALL_GEMM_3nx8m(a01,b10,cs_b,p_lda,k_iter) + + BLIS_PRE_STRSM_SMALL_3N_2M(AlphaVal,b11,cs_b) + + ///implement TRSM/// + + //extract a22 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); + ymm7 = STRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm7, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm7, ymm3); + + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack )); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm5, ymm3); + + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + BLIS_POST_STRSM_SMALL_3N_2M(b11,cs_b) + + m_remainder -=2; + } + else if (1 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (n_remainder - 3)*cs_a + (n_remainder - 3)*rs_a; //pointer to block of A to be used for TRSM + b10 = B + (m_remainder - 1) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 1) + (n_remainder - 3)*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_3nx8m(a01,b10,cs_b,p_lda,k_iter) + + BLIS_PRE_STRSM_SMALL_3N_1M(AlphaVal,b11,cs_b) + + ///implement TRSM/// + + //extract a22 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); + ymm7 = STRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*cs_a + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm7, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm7, ymm3); + + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack )); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm5, ymm3); + + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + BLIS_POST_STRSM_SMALL_3N_1M(b11,cs_b) + + m_remainder -=1; + } + } + n_remainder -= 3; + } + else if(n_remainder == 2) + { + a01 = L + (n_remainder - 2)*rs_a + n_remainder*cs_a; //pointer to block of A to be used in GEMM + a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2)*rs_a; //pointer to block of A to be used for TRSM + + float *ptr_a10_dup = D_A_pack; + + dim_t p_lda = (n-n_remainder); // packed leading dimension + // perform copy of A to packed buffer D_A_pack + + if(transa) + { + __m128 xmm0, xmm1, xmm2, xmm3; + __m128 xmm4, xmm5, xmm6, xmm7; + __m128 xmm8, xmm9; + + for(dim_t x =0;x < p_lda;x+=d_nr) + { + xmm0 = _mm_loadu_ps((float const *)(a01)); + xmm1 = _mm_loadu_ps((float const *)(a01 + cs_a)); + xmm2 = _mm_loadu_ps((float const *)(a01 + cs_a * 2)); + xmm3 = _mm_loadu_ps((float const *)(a01 + cs_a * 3)); + + xmm4 = _mm_unpacklo_ps(xmm0, xmm1); + xmm5 = _mm_unpacklo_ps(xmm2, xmm3); + xmm6 = _mm_shuffle_ps(xmm4,xmm5,0x44); + xmm7 = _mm_shuffle_ps(xmm4,xmm5,0xEE); + + xmm0 = _mm_unpackhi_ps(xmm0, xmm1); + xmm1 = _mm_unpackhi_ps(xmm2, xmm3); + xmm8 = _mm_shuffle_ps(xmm0,xmm1,0x44); + xmm9 = 
_mm_shuffle_ps(xmm0,xmm1,0xEE); + + _mm_storeu_ps((float *)(ptr_a10_dup), xmm6); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda), xmm7); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda*2), xmm8); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda*3), xmm9); + + xmm0 = _mm_loadu_ps((float const *)(a01 + cs_a * 4)); + xmm1 = _mm_loadu_ps((float const *)(a01 + cs_a * 5)); + + xmm4 = _mm_unpacklo_ps(xmm0, xmm1); + xmm5 = _mm_broadcast_ss((float const *)&zero); + xmm6 = _mm_shuffle_ps(xmm4,xmm5,0x44); + xmm7 = _mm_shuffle_ps(xmm4,xmm5,0xEE); + + xmm0 = _mm_unpackhi_ps(xmm0, xmm1); + xmm1 = _mm_broadcast_ss((float const *)&zero); + xmm8 = _mm_shuffle_ps(xmm0,xmm1,0x44); + xmm9 = _mm_shuffle_ps(xmm0,xmm1,0xEE); + + _mm_storel_pi((__m64 *)(ptr_a10_dup + 4), xmm6); + _mm_storel_pi((__m64 *)(ptr_a10_dup + 4 + p_lda), xmm7); + _mm_storel_pi((__m64 *)(ptr_a10_dup + 4 + p_lda*2), xmm8); + _mm_storel_pi((__m64 *)(ptr_a10_dup + 4 + p_lda*3), xmm9); + + a01 += d_nr*cs_a; + ptr_a10_dup += d_nr; + } + } + else + { + __m128 xmm0,xmm1; + dim_t loop_count = (n-n_remainder)/6; + for(dim_t i=0; i < loop_count; i++) + { + xmm1 = _mm_broadcast_ss((float *)&zero); + + xmm0 = _mm_loadu_ps((float *)(a01 + rs_a * 0 + i*6)); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda * 0 + i*6), xmm0); + xmm0 = _mm_loadl_pi(xmm1,(__m64 *)(a01 + rs_a * 0 + 4 + i*6)); + _mm_storel_pi((__m64 *)(ptr_a10_dup + p_lda * 0 + 4 + i*6),xmm0); + + xmm0 = _mm_loadu_ps((float const *)(a01 + rs_a * 1 + i*6)); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda * 1 + i*6), xmm0); + xmm0 = _mm_loadl_pi(xmm1,(__m64 *)(a01 + rs_a * 1 + 4 + i*6)); + _mm_storel_pi((__m64 *)(ptr_a10_dup + p_lda * 1 + 4 + i*6),xmm0); + } + } + + ymm4 = _mm256_broadcast_ss((float const *)&ones); + if(!is_unitdiag) + { + if(transa) + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_ss((float const *)(a11)); + ymm1 = _mm256_broadcast_ss((float const *)(a11+ cs_a*1 + 1)); + } + else + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_ss((float const *)(a11)); + ymm1 = _mm256_broadcast_ss((float const *)(a11+ rs_a*1 + 1)); + } + + ymm0 = _mm256_unpacklo_ps(ymm0, ymm1); + //ymm1 = _mm256_unpacklo_ps(ymm2, ymm3); + + ymm1 = _mm256_blend_ps(ymm0, ymm0, 0x0C); + #ifdef BLIS_DISABLE_TRSM_PREINVERSION + ymm4 = ymm1; + #endif + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + ymm4 = _mm256_div_ps(ymm4, ymm1); + #endif + } + _mm256_storeu_ps((float *)(d11_pack), ymm4); + + for(i = (m-d_mr); (i+1) > 0; i -= d_mr) //loop along 'M' direction + { + a01 = D_A_pack; + a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2)*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (i) + (n_remainder - 2)*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_2nx16m(a01,b10,cs_b,p_lda,k_iter) + + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); + + ymm0 = _mm256_loadu_ps((float const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_ps((float const *)(b11 + 8)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] + + ymm3 = _mm256_fmsub_ps(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + ymm4 = _mm256_fmsub_ps(ymm1, ymm15, ymm4); //B11[4-7][0] * alpha-= ymm1 + + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] 
B11[3][1] + ymm1 = _mm256_loadu_ps((float const *)(b11 + cs_b + 8)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] + + ymm5 = _mm256_fmsub_ps(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + ymm6 = _mm256_fmsub_ps(ymm1, ymm15, ymm6); //B11[4-7][1] * alpha -= ymm3 + + ///implement TRSM/// + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + ymm6 = STRSM_SMALL_DIV_OR_SCALE(ymm6, ymm0); + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack )); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + cs_a)); + + ymm3 = _mm256_fnmadd_ps(ymm1, ymm5, ymm3); + ymm4 = _mm256_fnmadd_ps(ymm1, ymm6, ymm4); + + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm4 = STRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); + + _mm256_storeu_ps((float *)b11, ymm3); + _mm256_storeu_ps((float *)(b11 + 8), ymm4); + _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); + _mm256_storeu_ps((float *)(b11 + cs_b + 8), ymm6); + } + + dim_t m_remainder = i + d_mr; + if(m_remainder >= 8) + { + a01 = D_A_pack; + a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2)*rs_a; //pointer to block of A to be used for TRSM + b10 = B + (m_remainder - 8) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 8) + (n_remainder - 2)*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_2nx8m(a01,b10,cs_b,p_lda,k_iter) + + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_ps((float const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm3 = _mm256_fmsub_ps(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_fmsub_ps(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + + ///implement TRSM/// + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack )); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm5, ymm3); + + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + _mm256_storeu_ps((float *)b11, ymm3); + _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); + + m_remainder -=8; + } + + if(m_remainder) + { + if(7 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2)*rs_a; //pointer to block of A to be used for TRSM + b10 = B + (m_remainder - 7) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 7) + (n_remainder - 2)*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_2nx8m(a01,b10,cs_b,p_lda,k_iter) + + BLIS_PRE_STRSM_SMALL_2N_7M(AlphaVal,b11,cs_b) + + ///implement TRSM/// + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack )); + + //(Row 1): FMA 
operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm5, ymm3); + + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + BLIS_POST_STRSM_SMALL_2N_7M(b11,cs_b) + + m_remainder -=7; + } + else if(6 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2)*rs_a; //pointer to block of A to be used for TRSM + b10 = B + (m_remainder - 6) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 6) + (n_remainder - 2)*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_2nx8m(a01,b10,cs_b,p_lda,k_iter) + + BLIS_PRE_STRSM_SMALL_2N_6M(AlphaVal,b11,cs_b) + + ///implement TRSM/// + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack )); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm5, ymm3); + + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + BLIS_POST_STRSM_SMALL_2N_6M(b11,cs_b) + + m_remainder -=6; + } + else if(5 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2)*rs_a; //pointer to block of A to be used for TRSM + b10 = B + (m_remainder - 5) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 5) + (n_remainder - 2)*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_2nx8m(a01,b10,cs_b,p_lda,k_iter) + + BLIS_PRE_STRSM_SMALL_2N_5M(AlphaVal,b11,cs_b) + + ///implement TRSM/// + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack )); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm5, ymm3); + + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + BLIS_POST_STRSM_SMALL_2N_5M(b11,cs_b) + + m_remainder -=5; + } + else if(4 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2)*rs_a; //pointer to block of A to be used for TRSM + b10 = B + (m_remainder - 4) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 4) + (n_remainder - 2)*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_2nx8m(a01,b10,cs_b,p_lda,k_iter) + + BLIS_PRE_STRSM_SMALL_2N_4M(AlphaVal,b11,cs_b) + + ///implement TRSM/// + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack )); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, 
ymm5, ymm3); + + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + BLIS_POST_STRSM_SMALL_2N_4M(b11,cs_b) + + m_remainder -=4; + } + else if(3 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2)*rs_a; //pointer to block of A to be used for TRSM + b10 = B + (m_remainder - 3) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 3) + (n_remainder - 2)*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_2nx8m(a01,b10,cs_b,p_lda,k_iter) + + BLIS_PRE_STRSM_SMALL_2N_3M(AlphaVal,b11,cs_b) + + ///implement TRSM/// + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack )); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm5, ymm3); + + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + BLIS_POST_STRSM_SMALL_2N_3M(b11,cs_b) + + m_remainder -=3; + } + else if(2 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2)*rs_a; //pointer to block of A to be used for TRSM + b10 = B + (m_remainder - 2) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 2) + (n_remainder - 2)*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_2nx8m(a01,b10,cs_b,p_lda,k_iter) + + BLIS_PRE_STRSM_SMALL_2N_2M(AlphaVal,b11,cs_b) + ///implement TRSM/// + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack )); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm5, ymm3); + + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + BLIS_POST_STRSM_SMALL_2N_2M(b11,cs_b) + + m_remainder -=2; + } + else if (1 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2)*rs_a; //pointer to block of A to be used for TRSM + b10 = B + (m_remainder - 1) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 1) + (n_remainder - 2)*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_2nx8m(a01,b10,cs_b,p_lda,k_iter) + + BLIS_PRE_STRSM_SMALL_2N_1M(AlphaVal,b11,cs_b) + ///implement TRSM/// + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack )); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + cs_a)); + ymm3 = _mm256_fnmadd_ps(ymm1, ymm5, ymm3); + + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + BLIS_POST_STRSM_SMALL_2N_1M(b11,cs_b) 
+ + m_remainder -=1; + } + } + n_remainder -= 2; + } + else if(n_remainder == 1) + { + a01 = L + (n_remainder - 1)*rs_a + n_remainder*cs_a; //pointer to block of A to be used in GEMM + a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1)*rs_a; //pointer to block of A to be used for TRSM + + float *ptr_a10_dup = D_A_pack; + + dim_t p_lda = (n-n_remainder); // packed leading dimension + // perform copy of A to packed buffer D_A_pack + + if(transa) + { + __m128 xmm0, xmm1, xmm2, xmm3; + __m128 xmm4, xmm5, xmm6, xmm7; + __m128 xmm8, xmm9; + for(dim_t x =0;x < p_lda;x+=d_nr) + { + xmm0 = _mm_loadu_ps((float const *)(a01)); + xmm1 = _mm_loadu_ps((float const *)(a01 + cs_a)); + xmm2 = _mm_loadu_ps((float const *)(a01 + cs_a * 2)); + xmm3 = _mm_loadu_ps((float const *)(a01 + cs_a * 3)); + + xmm4 = _mm_unpacklo_ps(xmm0, xmm1); + xmm5 = _mm_unpacklo_ps(xmm2, xmm3); + xmm6 = _mm_shuffle_ps(xmm4,xmm5,0x44); + xmm7 = _mm_shuffle_ps(xmm4,xmm5,0xEE); + + xmm0 = _mm_unpackhi_ps(xmm0, xmm1); + xmm1 = _mm_unpackhi_ps(xmm2, xmm3); + xmm8 = _mm_shuffle_ps(xmm0,xmm1,0x44); + xmm9 = _mm_shuffle_ps(xmm0,xmm1,0xEE); + + _mm_storeu_ps((float *)(ptr_a10_dup), xmm6); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda), xmm7); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda*2), xmm8); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda*3), xmm9); + + xmm0 = _mm_loadu_ps((float const *)(a01 + cs_a * 4)); + xmm1 = _mm_loadu_ps((float const *)(a01 + cs_a * 5)); + + xmm4 = _mm_unpacklo_ps(xmm0, xmm1); + xmm5 = _mm_broadcast_ss((float const *)&zero); + xmm6 = _mm_shuffle_ps(xmm4,xmm5,0x44); + xmm7 = _mm_shuffle_ps(xmm4,xmm5,0xEE); + + xmm0 = _mm_unpackhi_ps(xmm0, xmm1); + xmm1 = _mm_broadcast_ss((float const *)&zero); + xmm8 = _mm_shuffle_ps(xmm0,xmm1,0x44); + xmm9 = _mm_shuffle_ps(xmm0,xmm1,0xEE); + + _mm_storel_pi((__m64 *)(ptr_a10_dup + 4), xmm6); + _mm_storel_pi((__m64 *)(ptr_a10_dup + 4 + p_lda), xmm7); + _mm_storel_pi((__m64 *)(ptr_a10_dup + 4 + p_lda*2), xmm8); + _mm_storel_pi((__m64 *)(ptr_a10_dup + 4 + p_lda*3), xmm9); + + a01 += d_nr*cs_a; + ptr_a10_dup += d_nr; + } + } + else + { + __m128 xmm0,xmm1; + dim_t loop_count = (n-n_remainder)/6; + for(dim_t i=0; i < loop_count; i++) + { + xmm1 = _mm_broadcast_ss((float *)&zero); + + xmm0 = _mm_loadu_ps((float *)(a01 + rs_a * 0 + i*6)); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda * 0 + i*6), xmm0); + xmm0 = _mm_loadl_pi(xmm1,(__m64 *)(a01 + rs_a * 0 + 4 + i*6)); + _mm_storel_pi((__m64 *)(ptr_a10_dup + p_lda * 0 + 4 + i*6),xmm0); + } + + if((n - n_remainder - (loop_count*6))/4 != 0) + { + xmm1 = _mm_broadcast_ss((float *)&zero); + + xmm0 = _mm_loadu_ps((float *)(a01 + rs_a * 0 + loop_count*6)); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda * 0 + loop_count*6), xmm0); + xmm0 = _mm_loadl_pi(xmm1,(__m64 *)(a01 + rs_a * 0 + 4 + loop_count*6)); + _mm_storel_pi((__m64 *)(ptr_a10_dup + p_lda * 0 + 4 + loop_count*6),xmm0); + } + } + + ymm4 = _mm256_broadcast_ss((float const *)&ones); + if(!is_unitdiag) + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_ss((float const *)(a11)); + ymm1 = _mm256_broadcast_ss((float const *)&ones); + ymm2 = _mm256_broadcast_ss((float const *)&ones); + ymm3 = _mm256_broadcast_ss((float const *)&ones); + + ymm0 = _mm256_unpacklo_ps(ymm0, ymm1); + + ymm1 = _mm256_blend_ps(ymm0, ymm0, 0x0C); + #ifdef BLIS_DISABLE_TRSM_PREINVERSION + ymm4 = ymm1; + #endif + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + ymm4 = _mm256_div_ps(ymm4, ymm1); + #endif + } + _mm256_storeu_ps((float *)(d11_pack), ymm4); + + for(i = (m-d_mr); (i+1) > 0; i -= d_mr) //loop along 
'M' direction + { + a01 = D_A_pack; + a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1)*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (i) + (n_remainder - 1)*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_1nx16m(a01,b10,cs_b,p_lda,k_iter) + + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); + + ymm0 = _mm256_loadu_ps((float const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_ps((float const *)(b11 + 8)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] + + ymm3 = _mm256_fmsub_ps(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + ymm4 = _mm256_fmsub_ps(ymm1, ymm15, ymm4); //B11[4-7][0] * alpha-= ymm1 + + ///implement TRSM/// + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack )); + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm4 = STRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); + + _mm256_storeu_ps((float *)b11, ymm3); + _mm256_storeu_ps((float *)(b11 + 8), ymm4); + } + + dim_t m_remainder = i + d_mr; + if(m_remainder >= 8) + { + a01 = D_A_pack; + a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1)*rs_a; //pointer to block of A to be used for TRSM + b10 = B + (m_remainder - 8) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 8) + (n_remainder - 1)*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + + ymm3 = _mm256_setzero_ps(); + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter) + + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_ps((float const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm3 = _mm256_fmsub_ps(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + + ///implement TRSM/// + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack )); + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + _mm256_storeu_ps((float *)b11, ymm3); + + m_remainder -=8; + } + + if(m_remainder) + { + if(7 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1)*rs_a; //pointer to block of A to be used for TRSM + b10 = B + (m_remainder - 7) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 7) + (n_remainder - 1)*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + + ymm3 = _mm256_setzero_ps(); + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter) + + BLIS_PRE_STRSM_SMALL_1N_7M(AlphaVal,b11,cs_b) + + ///implement TRSM/// + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack )); + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + BLIS_POST_STRSM_SMALL_1N_7M(b11,cs_b) + + m_remainder -=7; + + } + else if(6 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1)*rs_a; //pointer to block of A to be used for TRSM + b10 = B + (m_remainder - 6) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 6) + (n_remainder - 1)*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-n_remainder); //number of 
GEMM operations to be done(in blocks of 4x4) + + ymm3 = _mm256_setzero_ps(); + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter) + + BLIS_PRE_STRSM_SMALL_1N_6M(AlphaVal,b11,cs_b) + + ///implement TRSM/// + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack )); + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + BLIS_POST_STRSM_SMALL_1N_6M(b11,cs_b) + + m_remainder -=6; + + } + else if(5 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1)*rs_a; //pointer to block of A to be used for TRSM + b10 = B + (m_remainder - 5) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 5) + (n_remainder - 1)*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + + ymm3 = _mm256_setzero_ps(); + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter) + + BLIS_PRE_STRSM_SMALL_1N_5M(AlphaVal,b11,cs_b) + + ///implement TRSM/// + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack )); + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + BLIS_POST_STRSM_SMALL_1N_5M(b11,cs_b) + + m_remainder -=5; + } + else if(4 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1)*rs_a; //pointer to block of A to be used for TRSM + b10 = B + (m_remainder - 4) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 4) + (n_remainder - 1)*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + + ymm3 = _mm256_setzero_ps(); + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter) + + BLIS_PRE_STRSM_SMALL_1N_4M(AlphaVal,b11,cs_b) + + ///implement TRSM/// + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack )); + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + BLIS_POST_STRSM_SMALL_1N_4M(b11,cs_b) + + m_remainder -=4; + } + else if(3 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1)*rs_a; //pointer to block of A to be used for TRSM + b10 = B + (m_remainder - 3) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 3) + (n_remainder - 1)*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + + ymm3 = _mm256_setzero_ps(); + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter) + + BLIS_PRE_STRSM_SMALL_1N_3M(AlphaVal,b11,cs_b) + + ///implement TRSM/// + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack )); + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + BLIS_POST_STRSM_SMALL_1N_3M(b11,cs_b) + + m_remainder -=3; + } + else if(2 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1)*rs_a; //pointer to block of A to be used for TRSM + b10 = B + (m_remainder - 2) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 2) + (n_remainder - 1)*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + + ymm3 = _mm256_setzero_ps(); + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter) + + BLIS_PRE_STRSM_SMALL_1N_2M(AlphaVal,b11,cs_b) + + ///implement TRSM/// + 
//extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack )); + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + BLIS_POST_STRSM_SMALL_1N_2M(b11,cs_b) + + m_remainder -=2; + } + else if (1 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1)*rs_a; //pointer to block of A to be used for TRSM + b10 = B + (m_remainder - 1) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 1) + (n_remainder - 1)*cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + + ymm3 = _mm256_setzero_ps(); + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter) + + BLIS_PRE_STRSM_SMALL_1N_1M(AlphaVal,b11,cs_b) + + ///implement TRSM/// + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack )); + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + BLIS_POST_STRSM_SMALL_1N_1M(b11,cs_b) + + m_remainder -=1; + } + } + n_remainder -= 1; + } + + if ((required_packing_A == 1) && bli_mem_is_alloc( &local_mem_buf_A_s )) + { + bli_membrk_release(&rntm, + &local_mem_buf_A_s); + } + return BLIS_SUCCESS; +} + +/*implements TRSM for the case XA = alpha * B + *A is lower triangular, non-unit diagonal/unit diagonal, transpose + *dimensions: X:mxn A:nxn B: mxn + * + * b11---> a01 ----> + ***************** *********** + *b01*b11* * * * * * * +b11 * * * * * **a01 * * a11 + | ***************** ********* | + | * * * * * *a11* * | + | * * * * * * * * | + v ***************** ****** v + * * * * * * * + * * * * * * * + ***************** * * + * + *implements TRSM for the case XA = alpha * B + *A is upper triangular, non-unit diagonal/unit diagonal, no transpose + *dimensions: X:mxn A:nxn B: mxn + * + * b11---> a01 ----> + ***************** *********** + *b01*b11* * * * * * * +b11 * * * * * **a01 * * a11 + | ***************** ********* | + | * * * * * *a11* * | + | * * * * * * * * | + v ***************** ****** v + * * * * * * * + * * * * * * * + ***************** * * + * + +*/ + +BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +) +{ + dim_t m = bli_obj_length(b); //number of rows + dim_t n = bli_obj_width(b); //number of columns + dim_t d_mr = 16,d_nr = 6; + + bool transa = bli_obj_has_trans(a); + dim_t cs_a, rs_a; + + // Swap rs_a & cs_a in case of non-tranpose. 
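+    /*
+       Illustrative scalar reference (reader's sketch, not part of the kernel
+       and not compiled): with rs_a and cs_a swapped below for the
+       non-transpose case, the upper-triangular no-transpose solve and the
+       lower-triangular transpose solve share one indexing scheme, and the
+       vector code in this function computes the equivalent of the following
+       column-by-column forward substitution (i, k, p and s are local to this
+       sketch; B is overwritten with X):
+
+           for(k = 0; k < n; k++)
+               for(i = 0; i < m; i++)
+               {
+                   float s = AlphaVal * B[i + k*cs_b];
+                   for(p = 0; p < k; p++)
+                       s -= B[i + p*cs_b] * L[p*cs_a + k*rs_a];
+                   B[i + k*cs_b] = is_unitdiag ? s : s / L[k*cs_a + k*rs_a];
+               }
+    */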
+    if(transa)
+    {
+        cs_a = bli_obj_col_stride(a); // column stride of A
+        rs_a = bli_obj_row_stride(a); // row stride of A
+    }
+    else
+    {
+        cs_a = bli_obj_row_stride(a); // row stride of A
+        rs_a = bli_obj_col_stride(a); // column stride of A
+    }
+    dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B
+
+    dim_t i, j, k;    //loop variables
+    dim_t k_iter;     //determines the number of GEMM operations to be done
+
+    float ones = 1.0;
+    float zero = 0.0;
+    bool is_unitdiag = bli_obj_has_unit_diag(a);
+
+    float AlphaVal = *(float *)AlphaObj->buffer;    //value of Alpha
+    float* restrict L = a->buffer;      //pointer to matrix A
+    float* restrict B = b->buffer;      //pointer to matrix B
+
+    float *a01, *a11, *b10, *b11;   //pointers for GEMM and TRSM blocks
+
+    gint_t required_packing_A = 1;
+    mem_t local_mem_buf_A_s = {0};
+    float *D_A_pack = NULL;
+    float d11_pack[d_mr] __attribute__((aligned(64)));
+    rntm_t rntm;
+
+    bli_rntm_init_from_global( &rntm );
+    bli_rntm_set_num_threads_only( 1, &rntm );
+    bli_membrk_rntm_set_membrk( &rntm );
+
+    siz_t buffer_size = bli_pool_block_size(
+                          bli_membrk_pool(
+                            bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK),
+                            bli_rntm_membrk(&rntm)));
+
+    if( (d_nr * n * sizeof(float)) > buffer_size)
+        return BLIS_NOT_YET_IMPLEMENTED;
+
+    if (required_packing_A == 1)
+    {
+        // Get the buffer from the pool.
+        bli_membrk_acquire_m(&rntm,
+                             buffer_size,
+                             BLIS_BITVAL_BUFFER_FOR_A_BLOCK,
+                             &local_mem_buf_A_s);
+        if(FALSE==bli_mem_is_alloc(&local_mem_buf_A_s)) return BLIS_NULL_POINTER;
+        D_A_pack = bli_mem_buffer(&local_mem_buf_A_s);
+        if(NULL==D_A_pack) return BLIS_NULL_POINTER;
+    }
+
+    //ymm scratch registers
+    __m256 ymm0, ymm1, ymm2, ymm3;
+    __m256 ymm4, ymm5, ymm6, ymm7;
+    __m256 ymm8, ymm9, ymm10, ymm11;
+    __m256 ymm12, ymm13, ymm14, ymm15;
+
+    __m128 xmm5;
+
+    /*
+      Solves TRSM for 6 columns of B at a time, from 0 to n in steps of d_nr
+      a. Load and pack A (a01 block); the packed size grows from 6x6 up to 6x(n-6)
+         For the first iteration there is no GEMM and no packing of a01, since only TRSM is needed
+      b. Perform the GEMM operation using the packed a01 block and the b10 block
+      c. Using the GEMM output, perform the TRSM operation with a11 and b11 and update B
+      d. Repeat b and c for the m rows of B in steps of d_mr
+    */
+
+    for(j = 0; (j+d_nr-1) < n; j += d_nr)     //loop along 'N' direction
+    {
+        a01 = L + j*rs_a;           //pointer to block of A to be used in GEMM
+        a11 = L + j*cs_a + j*rs_a;  //pointer to block of A to be used for TRSM
+
+        //double *ptr_a10_dup = D_A_pack;
+
+        dim_t p_lda = j; // packed leading dimension
+        // perform copy of A to packed buffer D_A_pack
+
+        if(transa)
+        {
+            /*
+              Pack the current A block (a01) into the packed buffer D_A_pack
+              a. This a01 block is used only in the GEMM portion; its size
+                 increases by d_nr on every iteration until it reaches 6x(n-6),
+                 which is the maximum GEMM-only block size in A
+              b. This packed buffer is reused while computing all m rows of the B matrix
+            */
+            bli_strsm_small_pack('R', j, 1, a01, cs_a, D_A_pack, p_lda,d_nr);
+
+            /*
+               Pack the 6 diagonal elements of the A block into an array
+               a. This helps utilize cache lines efficiently in the TRSM operation
+               b. Ones are stored when the input is unit-diagonal
+            */
+            strsm_small_pack_diag_element('R',is_unitdiag,a11,cs_a,d11_pack,d_nr);
+        }
+        else
+        {
+            bli_strsm_small_pack('R', j, 0, a01, rs_a, D_A_pack, p_lda,d_nr);
+            strsm_small_pack_diag_element('R',is_unitdiag,a11,rs_a,d11_pack,d_nr);
+        }
+
+        /*
+        a. Perform GEMM using a01, b10.
+        b. Perform TRSM on a11, b11
+        c. 
This loop GEMM+TRSM loops operates with 16x6 block size + along m dimension for every d_mr columns of B10 where + packed A buffer is reused in computing all m cols of B. + d. Same approach is used in remaining fringe cases. + */ + for(i = 0; (i+d_mr-1) < m; i += d_mr) //loop along 'M' direction + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + /* + Peform GEMM between a01 and b10 blocks + For first itteration there will be no GEMM operation + where k_iter are zero + */ + BLIS_STRSM_SMALL_GEMM_6nx16m(a01,b10,cs_b,p_lda,k_iter) + + /* + Load b11 of size 16x6 and multiply with alpha + Add the GEMM output to b11 + and peform TRSM operation. + */ + + BLIS_PRE_STRSM_SMALL_6x16(AlphaVal,b11,cs_b) + + ///implement TRSM/// + + /* + Compute 6x16 TRSM block by using GEMM block output in register + a. The 6x16 input (gemm outputs) are stored in combinations of ymm registers + b. Towards the end TRSM output will be stored back into b11 + */ + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack)); + + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm4 = STRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 1*rs_a)); + + ymm5 = _mm256_fnmadd_ps(ymm1, ymm3, ymm5); + ymm6 = _mm256_fnmadd_ps(ymm1, ymm4, ymm6); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*rs_a)); + + ymm7 = _mm256_fnmadd_ps(ymm1, ymm3, ymm7); + ymm8 = _mm256_fnmadd_ps(ymm1, ymm4, ymm8); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*rs_a)); + + ymm9 = _mm256_fnmadd_ps(ymm1, ymm3, ymm9); + ymm10 = _mm256_fnmadd_ps(ymm1, ymm4, ymm10); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*rs_a)); + + ymm11 = _mm256_fnmadd_ps(ymm1, ymm3, ymm11); + ymm12 = _mm256_fnmadd_ps(ymm1, ymm4, ymm12); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*rs_a)); + + ymm13 = _mm256_fnmadd_ps(ymm1, ymm3, ymm13); + ymm14 = _mm256_fnmadd_ps(ymm1, ymm4, ymm14); + + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + ymm6 = STRSM_SMALL_DIV_OR_SCALE(ymm6, ymm0); + + a11 += cs_a; + + //extract a22 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*rs_a)); + + ymm7 = _mm256_fnmadd_ps(ymm1, ymm5, ymm7); + ymm8 = _mm256_fnmadd_ps(ymm1, ymm6, ymm8); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*rs_a)); + + ymm9 = _mm256_fnmadd_ps(ymm1, ymm5, ymm9); + ymm10 = _mm256_fnmadd_ps(ymm1, ymm6, ymm10); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*rs_a)); + + ymm11 = _mm256_fnmadd_ps(ymm1, ymm5, ymm11); + ymm12 = _mm256_fnmadd_ps(ymm1, ymm6, ymm12); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*rs_a)); + + ymm13 = _mm256_fnmadd_ps(ymm1, ymm5, ymm13); + ymm14 = _mm256_fnmadd_ps(ymm1, ymm6, ymm14); + + ymm7 = STRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + ymm8 = STRSM_SMALL_DIV_OR_SCALE(ymm8, ymm0); + + a11 += cs_a; + + //extract a33 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 3)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*rs_a)); + + ymm9 = _mm256_fnmadd_ps(ymm1, ymm7, ymm9); + ymm10 = 
_mm256_fnmadd_ps(ymm1, ymm8, ymm10); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*rs_a)); + + ymm11 = _mm256_fnmadd_ps(ymm1, ymm7, ymm11); + ymm12 = _mm256_fnmadd_ps(ymm1, ymm8, ymm12); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*rs_a)); + + ymm13 = _mm256_fnmadd_ps(ymm1, ymm7, ymm13); + ymm14 = _mm256_fnmadd_ps(ymm1, ymm8, ymm14); + + ymm9 = STRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + ymm10 = STRSM_SMALL_DIV_OR_SCALE(ymm10, ymm0); + + a11 += cs_a; + + //extract a44 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 4)); + + //(row 4):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*rs_a)); + + ymm11 = _mm256_fnmadd_ps(ymm1, ymm9, ymm11); + ymm12 = _mm256_fnmadd_ps(ymm1, ymm10, ymm12); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*rs_a)); + + ymm13 = _mm256_fnmadd_ps(ymm1, ymm9, ymm13); + ymm14 = _mm256_fnmadd_ps(ymm1, ymm10, ymm14); + + ymm11 = STRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); + ymm12 = STRSM_SMALL_DIV_OR_SCALE(ymm12, ymm0); + + a11 += cs_a; + + //extract a55 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 5)); + + //(Row 5): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*rs_a)); + + ymm13 = _mm256_fnmadd_ps(ymm1, ymm11, ymm13); + ymm14 = _mm256_fnmadd_ps(ymm1, ymm12, ymm14); + + ymm13 = STRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); + ymm14 = STRSM_SMALL_DIV_OR_SCALE(ymm14, ymm0); + + _mm256_storeu_ps((float *)b11, ymm3); + _mm256_storeu_ps((float *)(b11 + 8), ymm4); + _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); + _mm256_storeu_ps((float *)(b11 + cs_b + 8), ymm6); + _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); + _mm256_storeu_ps((float *)(b11 + cs_b*2 + 8), ymm8); + _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); + _mm256_storeu_ps((float *)(b11 + cs_b*3 + 8), ymm10); + _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); + _mm256_storeu_ps((float *)(b11 + cs_b*4 + 8), ymm12); + _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + _mm256_storeu_ps((float *)(b11 + cs_b*5 + 8), ymm14); + } + + + dim_t m_remainder = m - i; + if(m_remainder >= 8) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) + + // Load b11 of size 8x6 and multiply with alpha + BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + + ///implement TRSM/// + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack)); + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm3, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm3, ymm7); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_ps(ymm1, ymm3, ymm9); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*rs_a)); + ymm11 = _mm256_fnmadd_ps(ymm1, ymm3, ymm11); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_ps(ymm1, ymm3, ymm13); + + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + a11 += cs_a; + + //extract a22 + ymm0 = 
_mm256_broadcast_ss((float const *)(d11_pack + 2)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm5, ymm7); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_ps(ymm1, ymm5, ymm9); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*rs_a)); + ymm11 = _mm256_fnmadd_ps(ymm1, ymm5, ymm11); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_ps(ymm1, ymm5, ymm13); + + ymm7 = STRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + a11 += cs_a; + + //extract a33 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 3)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_ps(ymm1, ymm7, ymm9); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*rs_a)); + ymm11 = _mm256_fnmadd_ps(ymm1, ymm7, ymm11); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_ps(ymm1, ymm7, ymm13); + + ymm9 = STRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + + a11 += cs_a; + + //extract a44 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 4)); + + //(row 4):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*rs_a)); + ymm11 = _mm256_fnmadd_ps(ymm1, ymm9, ymm11); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_ps(ymm1, ymm9, ymm13); + + ymm11 = STRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); + + a11 += cs_a; + + //extract a55 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 5)); + + //(Row 5): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_ps(ymm1, ymm11, ymm13); + + ymm13 = STRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); + + _mm256_storeu_ps((float *)b11, ymm3); + _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); + _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); + _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); + _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); + _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + + m_remainder -= 8; + i += 8; + } + + if(m_remainder == 7) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) + + // Load b11 of size 8x6 and multiply with alpha + BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + + ///implement TRSM/// + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack)); + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm3, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm3, ymm7); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_ps(ymm1, ymm3, ymm9); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*rs_a)); + ymm11 = _mm256_fnmadd_ps(ymm1, ymm3, ymm11); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_ps(ymm1, ymm3, ymm13); + + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + a11 += cs_a; + + 
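+            /* a11 was advanced by cs_a above; the updates below read the
+               off-diagonal multipliers for the remaining rows via a11 + k*rs_a. */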
//extract a22 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm5, ymm7); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_ps(ymm1, ymm5, ymm9); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*rs_a)); + ymm11 = _mm256_fnmadd_ps(ymm1, ymm5, ymm11); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_ps(ymm1, ymm5, ymm13); + + ymm7 = STRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + a11 += cs_a; + + //extract a33 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 3)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_ps(ymm1, ymm7, ymm9); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*rs_a)); + ymm11 = _mm256_fnmadd_ps(ymm1, ymm7, ymm11); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_ps(ymm1, ymm7, ymm13); + + ymm9 = STRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + + a11 += cs_a; + + //extract a44 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 4)); + + //(row 4):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*rs_a)); + ymm11 = _mm256_fnmadd_ps(ymm1, ymm9, ymm11); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_ps(ymm1, ymm9, ymm13); + + ymm11 = STRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); + + a11 += cs_a; + + //extract a55 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 5)); + + //(Row 5): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_ps(ymm1, ymm11, ymm13); + + ymm13 = STRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); + + ymm0 = _mm256_loadu_ps((float const *)b11); + ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x7F); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x7F); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x7F); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x7F); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x7F); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x7F); + + _mm256_storeu_ps((float *)b11, ymm3); + _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); + _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); + _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); + _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); + _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + + m_remainder -= 7; + i += 7; + } + else if(m_remainder == 6) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) + + // Load b11 of size 8x6 and multiply with alpha + BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + + ///implement TRSM/// + + 
//extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack)); + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm3, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm3, ymm7); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_ps(ymm1, ymm3, ymm9); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*rs_a)); + ymm11 = _mm256_fnmadd_ps(ymm1, ymm3, ymm11); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_ps(ymm1, ymm3, ymm13); + + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + a11 += cs_a; + + //extract a22 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm5, ymm7); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_ps(ymm1, ymm5, ymm9); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*rs_a)); + ymm11 = _mm256_fnmadd_ps(ymm1, ymm5, ymm11); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_ps(ymm1, ymm5, ymm13); + + ymm7 = STRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + a11 += cs_a; + + //extract a33 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 3)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_ps(ymm1, ymm7, ymm9); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*rs_a)); + ymm11 = _mm256_fnmadd_ps(ymm1, ymm7, ymm11); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_ps(ymm1, ymm7, ymm13); + + ymm9 = STRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + + a11 += cs_a; + + //extract a44 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 4)); + + //(row 4):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*rs_a)); + ymm11 = _mm256_fnmadd_ps(ymm1, ymm9, ymm11); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_ps(ymm1, ymm9, ymm13); + + ymm11 = STRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); + + a11 += cs_a; + + //extract a55 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 5)); + + //(Row 5): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_ps(ymm1, ymm11, ymm13); + + ymm13 = STRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); + + ymm0 = _mm256_loadu_ps((float const *)b11); + ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x3F); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x3F); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x3F); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x3F); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x3F); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x3F); + + _mm256_storeu_ps((float *)b11, ymm3); + _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); + _mm256_storeu_ps((float *)(b11 + cs_b*2), 
ymm7); + _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); + _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); + _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + + m_remainder -= 6; + i += 6; + } + else if(m_remainder == 5) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) + + // Load b11 of size 8x6 and multiply with alpha + BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + + ///implement TRSM/// + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack)); + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm3, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm3, ymm7); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_ps(ymm1, ymm3, ymm9); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*rs_a)); + ymm11 = _mm256_fnmadd_ps(ymm1, ymm3, ymm11); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_ps(ymm1, ymm3, ymm13); + + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + a11 += cs_a; + + //extract a22 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm5, ymm7); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_ps(ymm1, ymm5, ymm9); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*rs_a)); + ymm11 = _mm256_fnmadd_ps(ymm1, ymm5, ymm11); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_ps(ymm1, ymm5, ymm13); + + ymm7 = STRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + a11 += cs_a; + + //extract a33 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 3)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_ps(ymm1, ymm7, ymm9); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*rs_a)); + ymm11 = _mm256_fnmadd_ps(ymm1, ymm7, ymm11); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_ps(ymm1, ymm7, ymm13); + + ymm9 = STRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + + a11 += cs_a; + + //extract a44 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 4)); + + //(row 4):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*rs_a)); + ymm11 = _mm256_fnmadd_ps(ymm1, ymm9, ymm11); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_ps(ymm1, ymm9, ymm13); + + ymm11 = STRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); + + a11 += cs_a; + + //extract a55 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 5)); + + //(Row 5): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_ps(ymm1, ymm11, ymm13); + + ymm13 = STRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); + + ymm0 = _mm256_loadu_ps((float const *)b11); + ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x1F); + ymm0 = 
_mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x1F); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x1F); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x1F); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x1F); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x1F); + + _mm256_storeu_ps((float *)b11, ymm3); + _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); + _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); + _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); + _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); + _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + + m_remainder -= 5; + i += 5; + } + else if(m_remainder == 4) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) + + // Load b11 of size 8x6 and multiply with alpha + BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + + ///implement TRSM/// + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack)); + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm3, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm3, ymm7); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_ps(ymm1, ymm3, ymm9); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*rs_a)); + ymm11 = _mm256_fnmadd_ps(ymm1, ymm3, ymm11); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_ps(ymm1, ymm3, ymm13); + + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + a11 += cs_a; + + //extract a22 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm5, ymm7); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_ps(ymm1, ymm5, ymm9); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*rs_a)); + ymm11 = _mm256_fnmadd_ps(ymm1, ymm5, ymm11); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_ps(ymm1, ymm5, ymm13); + + ymm7 = STRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + a11 += cs_a; + + //extract a33 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 3)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_ps(ymm1, ymm7, ymm9); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*rs_a)); + ymm11 = _mm256_fnmadd_ps(ymm1, ymm7, ymm11); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_ps(ymm1, ymm7, ymm13); 
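/* Editorial sketch, not part of the patch: the scalar recurrence that the
 * ymm3/ymm5/ymm7/ymm9/ymm11/ymm13 ladder in these blocks vectorises.  Each ymm
 * register holds one column of the B11 tile (tile rows in the vector lanes);
 * solving column k subtracts the already-solved columns scaled by the A11
 * entries the kernel broadcasts, then applies STRSM_SMALL_DIV_OR_SCALE with
 * d11_pack[k].  The kernel is right-looking (solve a column, immediately
 * update all later columns); the left-looking loop below is equivalent.
 * Names mirror the kernel's variables; m_tile stands for 8 or m_remainder,
 * and the scale step assumes a pre-inverted diagonal in d11_pack. */
static void strsm_small_6col_ref(const float *a11, const float *d11_pack,
                                 float *b11, dim_t m_tile,
                                 dim_t rs_a, dim_t cs_a, dim_t cs_b)
{
    for (dim_t k = 0; k < 6; k++)
    {
        for (dim_t t = 0; t < k; t++)
        {
            float a_tk = a11[t*cs_a + k*rs_a];   /* broadcast operand in the kernel */
            for (dim_t r = 0; r < m_tile; r++)
                b11[r + k*cs_b] -= a_tk * b11[r + t*cs_b];
        }
        for (dim_t r = 0; r < m_tile; r++)
            b11[r + k*cs_b] *= d11_pack[k];      /* STRSM_SMALL_DIV_OR_SCALE */
    }
}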
+ + ymm9 = STRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + + a11 += cs_a; + + //extract a44 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 4)); + + //(row 4):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*rs_a)); + ymm11 = _mm256_fnmadd_ps(ymm1, ymm9, ymm11); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_ps(ymm1, ymm9, ymm13); + + ymm11 = STRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); + + a11 += cs_a; + + //extract a55 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 5)); + + //(Row 5): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_ps(ymm1, ymm11, ymm13); + + ymm13 = STRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); + + ymm0 = _mm256_loadu_ps((float const *)b11); + ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x0F); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x0F); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x0F); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x0F); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x0F); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x0F); + + _mm256_storeu_ps((float *)b11, ymm3); + _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); + _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); + _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); + _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); + _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + + m_remainder -= 4; + i += 4; + } + else if(m_remainder == 3) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) + + // Load b11 of size 8x6 and multiply with alpha + BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + + ///implement TRSM/// + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack)); + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm3, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm3, ymm7); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_ps(ymm1, ymm3, ymm9); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*rs_a)); + ymm11 = _mm256_fnmadd_ps(ymm1, ymm3, ymm11); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_ps(ymm1, ymm3, ymm13); + + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + a11 += cs_a; + + //extract a22 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm5, ymm7); 
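/* Editorial sketch, not part of the patch: a plausible model of what
 * STRSM_SMALL_DIV_OR_SCALE expands to.  It is consistent with the d11_pack
 * setup seen later in this kernel, where the packed diagonal holds 1/a_kk when
 * BLIS_ENABLE_TRSM_PREINVERSION is defined and a_kk itself otherwise; the real
 * macro is defined earlier in this file. */
#ifdef BLIS_ENABLE_TRSM_PREINVERSION
  #define STRSM_SMALL_DIV_OR_SCALE_SKETCH(reg, diag) _mm256_mul_ps((reg), (diag))
#else
  #define STRSM_SMALL_DIV_OR_SCALE_SKETCH(reg, diag) _mm256_div_ps((reg), (diag))
#endif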
+ + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_ps(ymm1, ymm5, ymm9); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*rs_a)); + ymm11 = _mm256_fnmadd_ps(ymm1, ymm5, ymm11); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_ps(ymm1, ymm5, ymm13); + + ymm7 = STRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + a11 += cs_a; + + //extract a33 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 3)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_ps(ymm1, ymm7, ymm9); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*rs_a)); + ymm11 = _mm256_fnmadd_ps(ymm1, ymm7, ymm11); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_ps(ymm1, ymm7, ymm13); + + ymm9 = STRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + + a11 += cs_a; + + //extract a44 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 4)); + + //(row 4):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*rs_a)); + ymm11 = _mm256_fnmadd_ps(ymm1, ymm9, ymm11); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_ps(ymm1, ymm9, ymm13); + + ymm11 = STRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); + + a11 += cs_a; + + //extract a55 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 5)); + + //(Row 5): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_ps(ymm1, ymm11, ymm13); + + ymm13 = STRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); + + ymm0 = _mm256_loadu_ps((float const *)b11); + ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x07); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x07); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x07); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x07); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x07); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x07); + + _mm256_storeu_ps((float *)b11, ymm3); + _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); + _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); + _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); + _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); + _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + + m_remainder -= 3; + i += 3; + } + else if(m_remainder == 2) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) + + // Load b11 of size 4x6 and multiply with alpha + BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + + ///implement TRSM/// + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack)); + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + + //(row 1):FMA 
operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm3, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm3, ymm7); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_ps(ymm1, ymm3, ymm9); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*rs_a)); + ymm11 = _mm256_fnmadd_ps(ymm1, ymm3, ymm11); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_ps(ymm1, ymm3, ymm13); + + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + a11 += cs_a; + + //extract a22 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm5, ymm7); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_ps(ymm1, ymm5, ymm9); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*rs_a)); + ymm11 = _mm256_fnmadd_ps(ymm1, ymm5, ymm11); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_ps(ymm1, ymm5, ymm13); + + ymm7 = STRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + a11 += cs_a; + + //extract a33 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 3)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_ps(ymm1, ymm7, ymm9); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*rs_a)); + ymm11 = _mm256_fnmadd_ps(ymm1, ymm7, ymm11); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_ps(ymm1, ymm7, ymm13); + + ymm9 = STRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + + a11 += cs_a; + + //extract a44 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 4)); + + //(row 4):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*rs_a)); + ymm11 = _mm256_fnmadd_ps(ymm1, ymm9, ymm11); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_ps(ymm1, ymm9, ymm13); + + ymm11 = STRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); + + a11 += cs_a; + + //extract a55 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 5)); + + //(Row 5): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_ps(ymm1, ymm11, ymm13); + + ymm13 = STRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); + + ymm0 = _mm256_loadu_ps((float const *)b11); + ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x03); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x03); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x03); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x03); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x03); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x03); + + _mm256_storeu_ps((float *)b11, ymm3); + _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); + _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); + _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); + _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); + _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + + m_remainder -= 2; + i += 2; + } + else if(m_remainder 
== 1) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) + + // Load b11 of size 4x6 and multiply with alpha + BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + + ///implement TRSM/// + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack)); + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm3, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm3, ymm7); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_ps(ymm1, ymm3, ymm9); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*rs_a)); + ymm11 = _mm256_fnmadd_ps(ymm1, ymm3, ymm11); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_ps(ymm1, ymm3, ymm13); + + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + a11 += cs_a; + + //extract a22 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm5, ymm7); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_ps(ymm1, ymm5, ymm9); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*rs_a)); + ymm11 = _mm256_fnmadd_ps(ymm1, ymm5, ymm11); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_ps(ymm1, ymm5, ymm13); + + ymm7 = STRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + a11 += cs_a; + + //extract a33 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 3)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_ps(ymm1, ymm7, ymm9); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*rs_a)); + ymm11 = _mm256_fnmadd_ps(ymm1, ymm7, ymm11); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_ps(ymm1, ymm7, ymm13); + + ymm9 = STRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + + a11 += cs_a; + + //extract a44 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 4)); + + //(row 4):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 4*rs_a)); + ymm11 = _mm256_fnmadd_ps(ymm1, ymm9, ymm11); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_ps(ymm1, ymm9, ymm13); + + ymm11 = STRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); + + a11 += cs_a; + + //extract a55 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 5)); + + //(Row 5): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 5*rs_a)); + ymm13 = _mm256_fnmadd_ps(ymm1, ymm11, ymm13); + + ymm13 = STRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); + + ymm0 = _mm256_loadu_ps((float const *)b11); + ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x01); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x01); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + 
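/* Editorial sketch, not part of the patch: what the BLIS_STRSM_SMALL_GEMM_6nx8m
 * and BLIS_PRE_STRSM_SMALL_6x8 steps jointly leave in the accumulators before
 * the substitution runs.  For the 6-column panel starting at column j
 * (k_iter == j):
 *
 *     B11[r][c] = AlphaVal * B11[r][c]
 *                 - sum_{t=0}^{k_iter-1} B10[r][t] * A01p[t][c]
 *
 * assuming A01p is the k_iter x 6 panel packed into D_A_pack column-major with
 * leading dimension p_lda, which is what the packing loops in this kernel
 * produce.  Scalar model (m_tile stands for 8 or m_remainder): */
static void strsm_small_gemm_pre_ref(const float *a01_pack, const float *b10,
                                     float *b11, float AlphaVal,
                                     dim_t m_tile, dim_t k_iter,
                                     dim_t p_lda, dim_t cs_b)
{
    for (dim_t c = 0; c < 6; c++)
        for (dim_t r = 0; r < m_tile; r++)
        {
            float acc = 0.0f;
            for (dim_t t = 0; t < k_iter; t++)
                acc += b10[r + t*cs_b] * a01_pack[t + c*p_lda];
            b11[r + c*cs_b] = AlphaVal * b11[r + c*cs_b] - acc;
        }
}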
ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x01); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x01); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x01); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x01); + + _mm256_storeu_ps((float *)b11, ymm3); + _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); + _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); + _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); + _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); + _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + + m_remainder -= 1; + i += 1; + } + + } + + dim_t n_remainder = n - j; + + /* + Reminder cases starts here: + a. Similar logic and code flow used in computing full block (6x8) + above holds for reminder cases too. + */ + + if(n_remainder >= 4) + { + a01 = L + j*rs_a; //pointer to block of A to be used in GEMM + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + + float *ptr_a10_dup = D_A_pack; + + dim_t p_lda = j; // packed leading dimension + // perform copy of A to packed buffer D_A_pack + + if(transa) + { + __m128 xmm0, xmm1, xmm2, xmm3; + __m128 xmm4, xmm5, xmm6, xmm7; + __m128 xmm8, xmm9; + + for(dim_t x =0;x < p_lda;x+=d_nr) + { + xmm0 = _mm_loadu_ps((float const *)(a01)); + xmm1 = _mm_loadu_ps((float const *)(a01 + cs_a)); + xmm2 = _mm_loadu_ps((float const *)(a01 + cs_a * 2)); + xmm3 = _mm_loadu_ps((float const *)(a01 + cs_a * 3)); + + xmm4 = _mm_unpacklo_ps(xmm0, xmm1); + xmm5 = _mm_unpacklo_ps(xmm2, xmm3); + xmm6 = _mm_shuffle_ps(xmm4,xmm5,0x44); + xmm7 = _mm_shuffle_ps(xmm4,xmm5,0xEE); + + xmm0 = _mm_unpackhi_ps(xmm0, xmm1); + xmm1 = _mm_unpackhi_ps(xmm2, xmm3); + xmm8 = _mm_shuffle_ps(xmm0,xmm1,0x44); + xmm9 = _mm_shuffle_ps(xmm0,xmm1,0xEE); + + _mm_storeu_ps((float *)(ptr_a10_dup), xmm6); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda), xmm7); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda*2), xmm8); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda*3), xmm9); + + xmm0 = _mm_loadu_ps((float const *)(a01 + cs_a * 4)); + xmm1 = _mm_loadu_ps((float const *)(a01 + cs_a * 5)); + + xmm4 = _mm_unpacklo_ps(xmm0, xmm1); + xmm5 = _mm_broadcast_ss((float const *)&zero); + xmm6 = _mm_shuffle_ps(xmm4,xmm5,0x44); + xmm7 = _mm_shuffle_ps(xmm4,xmm5,0xEE); + + xmm0 = _mm_unpackhi_ps(xmm0, xmm1); + xmm1 = _mm_broadcast_ss((float const *)&zero); + xmm8 = _mm_shuffle_ps(xmm0,xmm1,0x44); + xmm9 = _mm_shuffle_ps(xmm0,xmm1,0xEE); + + _mm_storel_pi((__m64 *)(ptr_a10_dup + 4), xmm6); + _mm_storel_pi((__m64 *)(ptr_a10_dup + 4 + p_lda), xmm7); + _mm_storel_pi((__m64 *)(ptr_a10_dup + 4 + p_lda*2), xmm8); + _mm_storel_pi((__m64 *)(ptr_a10_dup + 4 + p_lda*3), xmm9); + + a01 += d_nr*cs_a; + ptr_a10_dup += d_nr; + } + } + else + { + __m128 xmm0,xmm1; + dim_t loop_count = (n-n_remainder)/6; + for(dim_t i=0; i < loop_count; i++) + { + xmm1 = _mm_broadcast_ss((float *)&zero); + + xmm0 = _mm_loadu_ps((float *)(a01 + rs_a * 0 + i*6)); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda * 0 + i*6), xmm0); + xmm0 = _mm_loadl_pi(xmm1,(__m64 *)(a01 + rs_a * 0 + 4 + i*6)); + _mm_storel_pi((__m64 *)(ptr_a10_dup + p_lda * 0 + 4 + i*6),xmm0); + + xmm0 = _mm_loadu_ps((float const *)(a01 + rs_a * 1 + i*6)); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda * 1 + i*6), xmm0); + xmm0 = _mm_loadl_pi(xmm1,(__m64 *)(a01 + rs_a * 1 + 4 + i*6)); + 
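/* Editorial sketch, not part of the patch: the _mm_unpacklo_ps/_mm_unpackhi_ps
 * plus _mm_shuffle_ps(..., 0x44/0xEE) idiom used in the transa packing above is
 * the standard 4x4 single-precision transpose; the fifth and sixth source
 * columns are handled the same way, paired with a zero vector.  Hypothetical
 * standalone helper: */
static inline void transpose_4x4_ps_sketch(const float *c0, const float *c1,
                                           const float *c2, const float *c3,
                                           float *r0_out, float *r1_out,
                                           float *r2_out, float *r3_out)
{
    __m128 v0 = _mm_loadu_ps(c0);                  /* a0 a1 a2 a3 */
    __m128 v1 = _mm_loadu_ps(c1);                  /* b0 b1 b2 b3 */
    __m128 v2 = _mm_loadu_ps(c2);                  /* c0 c1 c2 c3 */
    __m128 v3 = _mm_loadu_ps(c3);                  /* d0 d1 d2 d3 */

    __m128 lo01 = _mm_unpacklo_ps(v0, v1);         /* a0 b0 a1 b1 */
    __m128 lo23 = _mm_unpacklo_ps(v2, v3);         /* c0 d0 c1 d1 */
    __m128 hi01 = _mm_unpackhi_ps(v0, v1);         /* a2 b2 a3 b3 */
    __m128 hi23 = _mm_unpackhi_ps(v2, v3);         /* c2 d2 c3 d3 */

    _mm_storeu_ps(r0_out, _mm_shuffle_ps(lo01, lo23, 0x44));  /* a0 b0 c0 d0 */
    _mm_storeu_ps(r1_out, _mm_shuffle_ps(lo01, lo23, 0xEE));  /* a1 b1 c1 d1 */
    _mm_storeu_ps(r2_out, _mm_shuffle_ps(hi01, hi23, 0x44));  /* a2 b2 c2 d2 */
    _mm_storeu_ps(r3_out, _mm_shuffle_ps(hi01, hi23, 0xEE));  /* a3 b3 c3 d3 */
}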
_mm_storel_pi((__m64 *)(ptr_a10_dup + p_lda * 1 + 4 + i*6),xmm0); + + xmm0 = _mm_loadu_ps((float const *)(a01 + rs_a * 2 + i*6)); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda * 2 + i*6), xmm0); + xmm0 = _mm_loadl_pi(xmm1,(__m64 *)(a01 + rs_a * 2 + 4 + i*6)); + _mm_storel_pi((__m64 *)(ptr_a10_dup + p_lda * 2 + 4 + i*6),xmm0); + + xmm0 = _mm_loadu_ps((float const *)(a01 + rs_a * 3 + i*6)); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda * 3 + i*6), xmm0); + xmm0 = _mm_loadl_pi(xmm1,(__m64 *)(a01 + rs_a * 3 + 4 + i*6)); + _mm_storel_pi((__m64 *)(ptr_a10_dup + p_lda * 3 + 4 + i*6),xmm0); + } + } + + ymm4 = _mm256_broadcast_ss((float const *)&ones); + if(!is_unitdiag) + { + if(transa) + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_ss((float const *)(a11)); + ymm1 = _mm256_broadcast_ss((float const *)(a11+ cs_a*1 + 1)); + ymm2 = _mm256_broadcast_ss((float const *)(a11+ cs_a*2 + 2)); + ymm3 = _mm256_broadcast_ss((float const *)(a11+ cs_a*3 + 3)); + } + else + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_ss((float const *)(a11)); + ymm1 = _mm256_broadcast_ss((float const *)(a11+ rs_a*1 + 1)); + ymm2 = _mm256_broadcast_ss((float const *)(a11+ rs_a*2 + 2)); + ymm3 = _mm256_broadcast_ss((float const *)(a11+ rs_a*3 + 3)); + } + + ymm0 = _mm256_unpacklo_ps(ymm0, ymm1); + ymm1 = _mm256_unpacklo_ps(ymm2, ymm3); + + ymm1 = _mm256_blend_ps(ymm0, ymm1, 0x0C); + #ifdef BLIS_DISABLE_TRSM_PREINVERSION + ymm4 = ymm1; + #endif + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + ymm4 = _mm256_div_ps(ymm4, ymm1); + #endif + } + _mm256_storeu_ps((float *)(d11_pack), ymm4); + + for(i = 0; (i+d_mr-1) < m; i += d_mr) //loop along 'M' direction + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_4nx16m(a01,b10,cs_b,p_lda,k_iter) + + BLIS_PRE_STRSM_SMALL_4x16(AlphaVal,b11,cs_b) + + ///implement TRSM/// + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack)); + + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm4 = STRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 1*rs_a)); + + ymm5 = _mm256_fnmadd_ps(ymm1, ymm3, ymm5); + ymm6 = _mm256_fnmadd_ps(ymm1, ymm4, ymm6); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*rs_a)); + + ymm7 = _mm256_fnmadd_ps(ymm1, ymm3, ymm7); + ymm8 = _mm256_fnmadd_ps(ymm1, ymm4, ymm8); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*rs_a)); + + ymm9 = _mm256_fnmadd_ps(ymm1, ymm3, ymm9); + ymm10 = _mm256_fnmadd_ps(ymm1, ymm4, ymm10); + + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + ymm6 = STRSM_SMALL_DIV_OR_SCALE(ymm6, ymm0); + + a11 += cs_a; + + //extract a22 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*rs_a)); + + ymm7 = _mm256_fnmadd_ps(ymm1, ymm5, ymm7); + ymm8 = _mm256_fnmadd_ps(ymm1, ymm6, ymm8); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*rs_a)); + + ymm9 = _mm256_fnmadd_ps(ymm1, ymm5, ymm9); + ymm10 = _mm256_fnmadd_ps(ymm1, ymm6, ymm10); + + ymm7 = STRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); 
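/* Editorial sketch, not part of the patch: a scalar model of the d11_pack setup
 * above for this 4-column panel.  The unpacklo/blend(0x0C) pair places the four
 * diagonal entries in lanes 0..3 of the stored vector (the upper lanes are
 * written but never read by this panel), and with BLIS_ENABLE_TRSM_PREINVERSION
 * the pack holds reciprocals so the solve multiplies instead of divides.
 * Variable names follow the kernel. */
static void pack_diag4_ref(const float *a11, float *d11_pack,
                           dim_t rs_a, dim_t cs_a, bool transa, bool is_unitdiag)
{
    for (dim_t t = 0; t < 4; t++)
    {
        float d = 1.0f;
        if (!is_unitdiag)
            d = transa ? a11[t*cs_a + t] : a11[t*rs_a + t];
#ifdef BLIS_ENABLE_TRSM_PREINVERSION
        d = 1.0f / d;
#endif
        d11_pack[t] = d;
    }
}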
+ ymm8 = STRSM_SMALL_DIV_OR_SCALE(ymm8, ymm0); + + a11 += cs_a; + + //extract a33 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 3)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*rs_a)); + + ymm9 = _mm256_fnmadd_ps(ymm1, ymm7, ymm9); + ymm10 = _mm256_fnmadd_ps(ymm1, ymm8, ymm10); + + ymm9 = STRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + ymm10 = STRSM_SMALL_DIV_OR_SCALE(ymm10, ymm0); + + _mm256_storeu_ps((float *)b11, ymm3); + _mm256_storeu_ps((float *)(b11 + 8), ymm4); + _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); + _mm256_storeu_ps((float *)(b11 + cs_b + 8), ymm6); + _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); + _mm256_storeu_ps((float *)(b11 + cs_b*2 + 8), ymm8); + _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); + _mm256_storeu_ps((float *)(b11 + cs_b*3 + 8), ymm10); + } + + dim_t m_remainder = m - i; + if(m_remainder >= 8) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_4nx8m(a01,b10,cs_b,p_lda,k_iter) + + ymm15 = _mm256_broadcast_ss((float const *)(&AlphaVal)); //register to hold alpha + + ymm0 = _mm256_loadu_ps((float const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm3 = _mm256_fmsub_ps(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_fmsub_ps(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm7 = _mm256_fmsub_ps(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 + + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm9 = _mm256_fmsub_ps(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 + + ///implement TRSM/// + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack)); + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm3, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm3, ymm7); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_ps(ymm1, ymm3, ymm9); + + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + a11 += cs_a; + + //extract a22 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm5, ymm7); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_ps(ymm1, ymm5, ymm9); + + ymm7 = STRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + a11 += cs_a; + + //extract a33 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 3)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_ps(ymm1, ymm7, ymm9); + + ymm9 = STRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + + _mm256_storeu_ps((float *)b11, ymm3); + _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); + 
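/* Editorial sketch, not part of the patch: unlike the 6-column panel, which
 * blends full-width stores, the m-tail cases of this 4-column panel assemble
 * exactly the valid elements at load time and write back only the valid
 * elements at store time.  Two hypothetical standalone helpers mirroring the
 * m_remainder == 6 loads and the m_remainder == 3 stores further below (the
 * kernel leaves the unused upper lanes as don't-care values; here they are
 * zeroed): */
static inline __m256 load6_ps_sketch(float const *p)
{
    __m128 lo = _mm_loadu_ps(p);                        /* p[0..3]          */
    __m128 hi = _mm_loadl_pi(_mm_setzero_ps(),
                             (__m64 const *)(p + 4));   /* p[4..5], rest 0  */
    return _mm256_insertf128_ps(_mm256_castps128_ps256(lo), hi, 1);
}

static inline void store3_ps_sketch(float *p, __m128 v)
{
    _mm_storel_pi((__m64 *)p, v);                 /* lanes 0..1 -> p[0..1]  */
    _mm_store_ss(p + 2, _mm_unpackhi_ps(v, v));   /* lane 2     -> p[2]     */
}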
_mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); + _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); + + m_remainder -= 8; + i += 8; + } + + if(m_remainder == 7) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_4nx8m(a01,b10,cs_b,p_lda,k_iter) + + ymm15 = _mm256_broadcast_ss((float const *)(&AlphaVal)); //register to hold alpha + + __m128 xmm0,xmm1; + xmm0 = _mm_loadu_ps((float const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + + xmm0 = _mm_broadcast_ss((float *)(b11 + 6)); + xmm1 = _mm_loadl_pi(xmm0,(__m64 *)(b11 + 4)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 1); + + ymm3 = _mm256_fmsub_ps(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + + xmm0 = _mm_broadcast_ss((float *)(b11 + 6 + cs_b)); + xmm1 = _mm_loadl_pi(xmm0,(__m64 *)(b11 + 4 + cs_b)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 1); + + ymm5 = _mm256_fmsub_ps(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + + xmm0 = _mm_broadcast_ss((float *)(b11 + 6 + cs_b*2)); + xmm1 = _mm_loadl_pi(xmm0,(__m64 *)(b11 + 4 + cs_b*2)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 1); + + ymm7 = _mm256_fmsub_ps(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 + + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b*3)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + + xmm0 = _mm_broadcast_ss((float *)(b11 + 6 + cs_b*3)); + xmm1 = _mm_loadl_pi(xmm0,(__m64 *)(b11 + 4 + cs_b*3)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 1); + + ymm9 = _mm256_fmsub_ps(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 + + ///implement TRSM/// + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack)); + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm3, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm3, ymm7); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_ps(ymm1, ymm3, ymm9); + + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + a11 += cs_a; + + //extract a22 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm5, ymm7); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_ps(ymm1, ymm5, ymm9); + + ymm7 = STRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + a11 += cs_a; + + //extract a33 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 3)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_ps(ymm1, ymm7, ymm9); + + ymm9 = STRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + + _mm_storeu_ps((float 
*)(b11),_mm256_extractf128_ps(ymm3, 0)); + _mm_storel_pi((__m64 *)(b11 + 4),_mm256_extractf128_ps(ymm3, 1)); + _mm_store_ss((float *)(b11 + 6),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm3,ymm3), 1)); + + _mm_storeu_ps((float *)(b11 + cs_b),_mm256_extractf128_ps(ymm5, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b),_mm256_extractf128_ps(ymm5, 1)); + _mm_store_ss((float *)(b11 + 6 + cs_b),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm5,ymm5), 1)); + + _mm_storeu_ps((float *)(b11 + cs_b*2),_mm256_extractf128_ps(ymm7, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*2),_mm256_extractf128_ps(ymm7, 1)); + _mm_store_ss((float *)(b11 + 6 + cs_b*2),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm7,ymm7), 1)); + + _mm_storeu_ps((float *)(b11 + cs_b*3),_mm256_extractf128_ps(ymm9, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*3),_mm256_extractf128_ps(ymm9, 1)); + _mm_store_ss((float *)(b11 + 6 + cs_b*3),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm9,ymm9), 1)); + + m_remainder -= 7; + i += 7; + } + else if(m_remainder == 6) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_4nx8m(a01,b10,cs_b,p_lda,k_iter) + + ymm15 = _mm256_broadcast_ss((float const *)(&AlphaVal)); //register to hold alpha + + __m128 xmm0,xmm1; + xmm0 = _mm_loadu_ps((float const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + + xmm1 = _mm_loadl_pi(xmm0,(__m64 *)(b11 + 4)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 1); + + ymm3 = _mm256_fmsub_ps(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + + xmm1 = _mm_loadl_pi(xmm0,(__m64 *)(b11 + 4 + cs_b)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 1); + + ymm5 = _mm256_fmsub_ps(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + + xmm1 = _mm_loadl_pi(xmm0,(__m64 *)(b11 + 4 + cs_b*2)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 1); + + ymm7 = _mm256_fmsub_ps(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 + + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b*3)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + + xmm1 = _mm_loadl_pi(xmm0,(__m64 *)(b11 + 4 + cs_b*3)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 1); + + ymm9 = _mm256_fmsub_ps(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 + + ///implement TRSM/// + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack)); + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm3, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm3, ymm7); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_ps(ymm1, ymm3, ymm9); + + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + a11 += cs_a; + + //extract a22 + ymm0 = _mm256_broadcast_ss((float const 
*)(d11_pack + 2)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm5, ymm7); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_ps(ymm1, ymm5, ymm9); + + ymm7 = STRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + a11 += cs_a; + + //extract a33 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 3)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_ps(ymm1, ymm7, ymm9); + + ymm9 = STRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + + _mm_storeu_ps((float *)(b11),_mm256_extractf128_ps(ymm3, 0)); + _mm_storel_pi((__m64 *)(b11 + 4),_mm256_extractf128_ps(ymm3, 1)); + _mm_storeu_ps((float *)(b11 + cs_b),_mm256_extractf128_ps(ymm5, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b),_mm256_extractf128_ps(ymm5, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*2),_mm256_extractf128_ps(ymm7, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*2),_mm256_extractf128_ps(ymm7, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*3),_mm256_extractf128_ps(ymm9, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*3),_mm256_extractf128_ps(ymm9, 1)); + + m_remainder -= 6; + i += 6; + } + else if(m_remainder == 5) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_4nx8m(a01,b10,cs_b,p_lda,k_iter) + + ymm15 = _mm256_broadcast_ss((float const *)(&AlphaVal)); //register to hold alpha + + __m128 xmm0; + ymm0 = _mm256_broadcast_ss((float const *)(b11 + 4)); + xmm0 = _mm_loadu_ps((float const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + ymm3 = _mm256_fmsub_ps(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + + ymm0 = _mm256_broadcast_ss((float const *)(b11 + 4 + cs_b)); + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + ymm5 = _mm256_fmsub_ps(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + + ymm0 = _mm256_broadcast_ss((float const *)(b11 + 4 + cs_b*2)); + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + ymm7 = _mm256_fmsub_ps(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 + + ymm0 = _mm256_broadcast_ss((float const *)(b11 + 4 + cs_b*3)); + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b*3)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + ymm9 = _mm256_fmsub_ps(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 + + ///implement TRSM/// + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack)); + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm3, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm3, ymm7); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_ps(ymm1, ymm3, ymm9); + + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + a11 += cs_a; + + //extract a22 + 
ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm5, ymm7); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_ps(ymm1, ymm5, ymm9); + + ymm7 = STRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + a11 += cs_a; + + //extract a33 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 3)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_ps(ymm1, ymm7, ymm9); + + ymm9 = STRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + + _mm_storeu_ps((float *)(b11),_mm256_extractf128_ps(ymm3, 0)); + _mm_store_ss((float *)(b11 + 4),_mm256_extractf128_ps(ymm3, 1)); + _mm_storeu_ps((float *)(b11 + cs_b),_mm256_extractf128_ps(ymm5, 0)); + _mm_store_ss((float *)(b11 + 4 + cs_b),_mm256_extractf128_ps(ymm5, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*2),_mm256_extractf128_ps(ymm7, 0)); + _mm_store_ss((float *)(b11 + 4 + cs_b*2),_mm256_extractf128_ps(ymm7, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*3),_mm256_extractf128_ps(ymm9, 0)); + _mm_store_ss((float *)(b11 + 4 + cs_b*3),_mm256_extractf128_ps(ymm9, 1)); + + m_remainder -= 5; + i += 5; + } + else if(m_remainder == 4) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_4nx8m(a01,b10,cs_b,p_lda,k_iter) + + ymm15 = _mm256_broadcast_ss((float const *)(&AlphaVal)); //register to hold alpha + + __m128 xmm0; + xmm0 = _mm_loadu_ps((float const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + ymm3 = _mm256_fmsub_ps(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + ymm5 = _mm256_fmsub_ps(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + ymm7 = _mm256_fmsub_ps(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 + + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b*3)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + ymm9 = _mm256_fmsub_ps(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 + + ///implement TRSM/// + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack)); + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm3, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm3, ymm7); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_ps(ymm1, ymm3, ymm9); + + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + a11 += cs_a; + + //extract a22 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm5, ymm7); + + ymm1 = 
_mm256_broadcast_ss((float const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_ps(ymm1, ymm5, ymm9); + + ymm7 = STRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + a11 += cs_a; + + //extract a33 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 3)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_ps(ymm1, ymm7, ymm9); + + ymm9 = STRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + + _mm_storeu_ps((float *)(b11),_mm256_extractf128_ps(ymm3, 0)); + _mm_storeu_ps((float *)(b11 + cs_b),_mm256_extractf128_ps(ymm5, 0)); + _mm_storeu_ps((float *)(b11 + cs_b*2),_mm256_extractf128_ps(ymm7, 0)); + _mm_storeu_ps((float *)(b11 + cs_b*3),_mm256_extractf128_ps(ymm9, 0)); + + m_remainder -= 4; + i += 4; + } + else if(m_remainder == 3) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_4nx8m(a01,b10,cs_b,p_lda,k_iter) + + ymm15 = _mm256_broadcast_ss((float const *)(&AlphaVal)); //register to hold alpha + + __m128 xmm0,xmm1; + xmm1 = _mm_broadcast_ss((float *)(b11 + 2)); + xmm0 = _mm_loadl_pi(xmm1,(__m64 *)(b11)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + ymm3 = _mm256_fmsub_ps(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + + xmm1 = _mm_broadcast_ss((float *)(b11 + cs_b + 2)); + xmm0 = _mm_loadl_pi(xmm1,(__m64 *)(b11 + cs_b)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + ymm5 = _mm256_fmsub_ps(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + + xmm1 = _mm_broadcast_ss((float *)(b11 + cs_b*2 + 2)); + xmm0 = _mm_loadl_pi(xmm1,(__m64 *)(b11 + cs_b*2)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + ymm7 = _mm256_fmsub_ps(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 + + xmm1 = _mm_broadcast_ss((float *)(b11 + cs_b*3 + 2)); + xmm0 = _mm_loadl_pi(xmm1,(__m64 *)(b11 + cs_b*3)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + ymm9 = _mm256_fmsub_ps(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 + + ///implement TRSM/// + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack)); + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm3, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm3, ymm7); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_ps(ymm1, ymm3, ymm9); + + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + a11 += cs_a; + + //extract a22 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm5, ymm7); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_ps(ymm1, ymm5, ymm9); + + ymm7 = STRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + a11 += cs_a; + + //extract a33 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 3)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_ps(ymm1, ymm7, ymm9); + + ymm9 = STRSM_SMALL_DIV_OR_SCALE(ymm9, 
ymm0); + + xmm0 = _mm256_extractf128_ps(ymm3, 0); + _mm_storel_pi((__m64 *)(b11),xmm0); + _mm_store_ss((float *)(b11+2),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm3,ymm3), 0)); + + xmm0 = _mm256_extractf128_ps(ymm5, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b),xmm0); + _mm_store_ss((float *)(b11+ 2 + cs_b),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm5,ymm5), 0)); + + xmm0 = _mm256_extractf128_ps(ymm7, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*2),xmm0); + _mm_store_ss((float *)(b11 + 2 + cs_b*2),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm7,ymm7), 0)); + + xmm0 = _mm256_extractf128_ps(ymm9, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*3),xmm0); + _mm_store_ss((float *)(b11 + 2 + cs_b*3),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm9,ymm9), 0)); + + m_remainder -= 3; + i += 3; + } + else if(m_remainder == 2) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_4nx8m(a01,b10,cs_b,p_lda,k_iter) + + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); //register to hold alpha + + __m128 xmm0,xmm1; + xmm1 = _mm_broadcast_ss((float *)&zero); + xmm0 = _mm_loadl_pi(xmm1,(__m64 *)(b11)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + ymm3 = _mm256_fmsub_ps(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + + xmm0 = _mm_loadl_pi(xmm1,(__m64 *)(b11 + cs_b)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + ymm5 = _mm256_fmsub_ps(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + + xmm0 = _mm_loadl_pi(xmm1,(__m64 *)(b11 + cs_b*2)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + ymm7 = _mm256_fmsub_ps(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 + + xmm0 = _mm_loadl_pi(xmm1,(__m64 *)(b11 + cs_b*3)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + ymm9 = _mm256_fmsub_ps(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 + + ///implement TRSM/// + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack)); + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm3, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm3, ymm7); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_ps(ymm1, ymm3, ymm9); + + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + a11 += cs_a; + + //extract a22 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm5, ymm7); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_ps(ymm1, ymm5, ymm9); + + ymm7 = STRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + a11 += cs_a; + + //extract a33 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 3)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_ps(ymm1, ymm7, ymm9); + + ymm9 = STRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + + xmm0 = _mm256_extractf128_ps(ymm3, 0); + _mm_storel_pi((__m64 *)(b11),xmm0); + + xmm0 = _mm256_extractf128_ps(ymm5, 0); + 
_mm_storel_pi((__m64 *)(b11 + cs_b),xmm0); + + xmm0 = _mm256_extractf128_ps(ymm7, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*2),xmm0); + + xmm0 = _mm256_extractf128_ps(ymm9, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*3),xmm0); + + m_remainder -= 2; + i += 2; + } + else if(m_remainder == 1) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_4nx8m(a01,b10,cs_b,p_lda,k_iter) + + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_broadcast_ss((float const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm3 = _mm256_fmsub_ps(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + + ymm0 = _mm256_broadcast_ss((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_fmsub_ps(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + + ymm0 = _mm256_broadcast_ss((float const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm7 = _mm256_fmsub_ps(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 + + ymm0 = _mm256_broadcast_ss((float const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm9 = _mm256_fmsub_ps(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 + + ///implement TRSM/// + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack)); + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm3, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm3, ymm7); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_ps(ymm1, ymm3, ymm9); + + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + a11 += cs_a; + + //extract a22 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm5, ymm7); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_ps(ymm1, ymm5, ymm9); + + ymm7 = STRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + a11 += cs_a; + + //extract a33 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 3)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 3*rs_a)); + ymm9 = _mm256_fnmadd_ps(ymm1, ymm7, ymm9); + + ymm9 = STRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + + _mm_store_ss((b11 + cs_b * 0), _mm256_extractf128_ps(ymm3, 0)); + _mm_store_ss((b11 + cs_b * 1), _mm256_extractf128_ps(ymm5, 0)); + _mm_store_ss((b11 + cs_b * 2), _mm256_extractf128_ps(ymm7, 0)); + _mm_store_ss((b11 + cs_b * 3), _mm256_extractf128_ps(ymm9, 0)); + + m_remainder -= 1; + i += 1; + } + + j += 4; + n_remainder -= 4; + } + + if(n_remainder == 3) + { + a01 = L + j*rs_a; //pointer to block of A to be used in GEMM + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + + float *ptr_a10_dup = D_A_pack; + + dim_t p_lda = j; // packed leading dimension + // perform copy of A to packed buffer D_A_pack + + if(transa) + { + __m128 xmm0, xmm1, xmm2, xmm3; + __m128 
xmm4, xmm5, xmm6, xmm7; + __m128 xmm8, xmm9; + + for(dim_t x =0;x < p_lda;x+=d_nr) + { + xmm0 = _mm_loadu_ps((float const *)(a01)); + xmm1 = _mm_loadu_ps((float const *)(a01 + cs_a)); + xmm2 = _mm_loadu_ps((float const *)(a01 + cs_a * 2)); + xmm3 = _mm_loadu_ps((float const *)(a01 + cs_a * 3)); + + xmm4 = _mm_unpacklo_ps(xmm0, xmm1); + xmm5 = _mm_unpacklo_ps(xmm2, xmm3); + xmm6 = _mm_shuffle_ps(xmm4,xmm5,0x44); + xmm7 = _mm_shuffle_ps(xmm4,xmm5,0xEE); + + xmm0 = _mm_unpackhi_ps(xmm0, xmm1); + xmm1 = _mm_unpackhi_ps(xmm2, xmm3); + xmm8 = _mm_shuffle_ps(xmm0,xmm1,0x44); + xmm9 = _mm_shuffle_ps(xmm0,xmm1,0xEE); + + _mm_storeu_ps((float *)(ptr_a10_dup), xmm6); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda), xmm7); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda*2), xmm8); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda*3), xmm9); + + xmm0 = _mm_loadu_ps((float const *)(a01 + cs_a * 4)); + xmm1 = _mm_loadu_ps((float const *)(a01 + cs_a * 5)); + + xmm4 = _mm_unpacklo_ps(xmm0, xmm1); + xmm5 = _mm_broadcast_ss((float const *)&zero); + xmm6 = _mm_shuffle_ps(xmm4,xmm5,0x44); + xmm7 = _mm_shuffle_ps(xmm4,xmm5,0xEE); + + xmm0 = _mm_unpackhi_ps(xmm0, xmm1); + xmm1 = _mm_broadcast_ss((float const *)&zero); + xmm8 = _mm_shuffle_ps(xmm0,xmm1,0x44); + xmm9 = _mm_shuffle_ps(xmm0,xmm1,0xEE); + + _mm_storel_pi((__m64 *)(ptr_a10_dup + 4), xmm6); + _mm_storel_pi((__m64 *)(ptr_a10_dup + 4 + p_lda), xmm7); + _mm_storel_pi((__m64 *)(ptr_a10_dup + 4 + p_lda*2), xmm8); + _mm_storel_pi((__m64 *)(ptr_a10_dup + 4 + p_lda*3), xmm9); + + a01 += d_nr*cs_a; + ptr_a10_dup += d_nr; + } + } + else + { + __m128 xmm0,xmm1; + dim_t loop_count = (n-n_remainder)/6; + for(dim_t i=0; i < loop_count; i++) + { + xmm1 = _mm_broadcast_ss((float *)&zero); + + xmm0 = _mm_loadu_ps((float *)(a01 + rs_a * 0 + i*6)); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda * 0 + i*6), xmm0); + xmm0 = _mm_loadl_pi(xmm1,(__m64 *)(a01 + rs_a * 0 + 4 + i*6)); + _mm_storel_pi((__m64 *)(ptr_a10_dup + p_lda * 0 + 4 + i*6),xmm0); + + xmm0 = _mm_loadu_ps((float const *)(a01 + rs_a * 1 + i*6)); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda * 1 + i*6), xmm0); + xmm0 = _mm_loadl_pi(xmm1,(__m64 *)(a01 + rs_a * 1 + 4 + i*6)); + _mm_storel_pi((__m64 *)(ptr_a10_dup + p_lda * 1 + 4 + i*6),xmm0); + + xmm0 = _mm_loadu_ps((float const *)(a01 + rs_a * 2 + i*6)); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda * 2 + i*6), xmm0); + xmm0 = _mm_loadl_pi(xmm1,(__m64 *)(a01 + rs_a * 2 + 4 + i*6)); + _mm_storel_pi((__m64 *)(ptr_a10_dup + p_lda * 2 + 4 + i*6),xmm0); + } + } + + ymm4 = _mm256_broadcast_ss((float const *)&ones); + if(!is_unitdiag) + { + if(transa) + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_ss((float const *)(a11)); + ymm1 = _mm256_broadcast_ss((float const *)(a11+ cs_a*1 + 1)); + ymm2 = _mm256_broadcast_ss((float const *)(a11+ cs_a*2 + 2)); + ymm3 = _mm256_broadcast_ss((float const *)&ones); + } + else + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_ss((float const *)(a11)); + ymm1 = _mm256_broadcast_ss((float const *)(a11+ rs_a*1 + 1)); + ymm2 = _mm256_broadcast_ss((float const *)(a11+ rs_a*2 + 2)); + ymm3 = _mm256_broadcast_ss((float const *)&ones); + } + + ymm0 = _mm256_unpacklo_ps(ymm0, ymm1); + ymm1 = _mm256_unpacklo_ps(ymm2, ymm3); + + ymm1 = _mm256_blend_ps(ymm0, ymm1, 0x0C); + #ifdef BLIS_DISABLE_TRSM_PREINVERSION + ymm4 = ymm1; + #endif + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + ymm4 = _mm256_div_ps(ymm4, ymm1); + #endif + } + _mm256_storeu_ps((float *)(d11_pack), ymm4); + + for(i = 0; (i+d_mr-1) < m; i += d_mr) //loop 
along 'M' direction + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_3nx16m(a01,b10,cs_b,p_lda,k_iter) + + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); + + ymm0 = _mm256_loadu_ps((float const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_ps((float const *)(b11 + 8)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] + + ymm3 = _mm256_fmsub_ps(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + ymm4 = _mm256_fmsub_ps(ymm1, ymm15, ymm4); //B11[4-7][0] * alpha-= ymm1 + + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm1 = _mm256_loadu_ps((float const *)(b11 + cs_b + 8)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] + + ymm5 = _mm256_fmsub_ps(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + ymm6 = _mm256_fmsub_ps(ymm1, ymm15, ymm6); //B11[4-7][1] * alpha -= ymm3 + + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm1 = _mm256_loadu_ps((float const *)(b11 + cs_b*2 + 8)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] + + ymm7 = _mm256_fmsub_ps(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 + ymm8 = _mm256_fmsub_ps(ymm1, ymm15, ymm8); //B11[4-7][2] * alpha -= ymm5 + + ///implement TRSM/// + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack)); + + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm4 = STRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 1*rs_a)); + + ymm5 = _mm256_fnmadd_ps(ymm1, ymm3, ymm5); + ymm6 = _mm256_fnmadd_ps(ymm1, ymm4, ymm6); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*rs_a)); + + ymm7 = _mm256_fnmadd_ps(ymm1, ymm3, ymm7); + ymm8 = _mm256_fnmadd_ps(ymm1, ymm4, ymm8); + + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + ymm6 = STRSM_SMALL_DIV_OR_SCALE(ymm6, ymm0); + + a11 += cs_a; + + //extract a22 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*rs_a)); + + ymm7 = _mm256_fnmadd_ps(ymm1, ymm5, ymm7); + ymm8 = _mm256_fnmadd_ps(ymm1, ymm6, ymm8); + + ymm7 = STRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + ymm8 = STRSM_SMALL_DIV_OR_SCALE(ymm8, ymm0); + + _mm256_storeu_ps((float *)b11, ymm3); + _mm256_storeu_ps((float *)(b11 + 8), ymm4); + _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); + _mm256_storeu_ps((float *)(b11 + cs_b + 8), ymm6); + _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); + _mm256_storeu_ps((float *)(b11 + cs_b*2 + 8), ymm8); + } + + dim_t m_remainder = m - i; + if(m_remainder >= 8) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_3nx8m(a01,b10,cs_b,p_lda,k_iter) + + ymm15 = 
_mm256_broadcast_ss((float const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_ps((float const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm3 = _mm256_fmsub_ps(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_fmsub_ps(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm7 = _mm256_fmsub_ps(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 + + ///implement TRSM/// + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack)); + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm3, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm3, ymm7); + + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + a11 += cs_a; + + //extract a22 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm5, ymm7); + + ymm7 = STRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + _mm256_storeu_ps((float *)b11, ymm3); + _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); + _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); + + m_remainder -= 8; + i += 8; + } + + if(m_remainder == 7) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_3nx8m(a01,b10,cs_b,p_lda,k_iter) + + BLIS_PRE_STRSM_SMALL_3N_7M(AlphaVal,b11,cs_b) + + ///implement TRSM/// + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack)); + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm3, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm3, ymm7); + + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + a11 += cs_a; + + //extract a22 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm5, ymm7); + + ymm7 = STRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + BLIS_POST_STRSM_SMALL_3N_7M(b11,cs_b) + + m_remainder -= 7; + i += 7; + } + else if(m_remainder == 6) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_3nx8m(a01,b10,cs_b,p_lda,k_iter) + + BLIS_PRE_STRSM_SMALL_3N_6M(AlphaVal,b11,cs_b) + + 
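+
+                /*
+                   Note (illustrative sketch, not part of the kernel): the TRSM
+                   section below solves a 3x3 triangular system for every packed
+                   element of B11, using the diagonal values (or their
+                   reciprocals) stored in d11_pack. In scalar form the same
+                   update is roughly:
+
+                       b0 = b0 * d11_pack[0];
+                       b1 = (b1 - a11[1*rs_a] * b0) * d11_pack[1];
+                       b2 = (b2 - a11[2*rs_a] * b0
+                                - a11[cs_a + 2*rs_a] * b1) * d11_pack[2];
+
+                   where b0, b1, b2 stand for the three B11 vectors held in
+                   ymm3, ymm5 and ymm7, and the multiply becomes a divide when
+                   pre-inversion is disabled.
+                */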
///implement TRSM/// + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack)); + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm3, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm3, ymm7); + + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + a11 += cs_a; + + //extract a22 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm5, ymm7); + + ymm7 = STRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + BLIS_POST_STRSM_SMALL_3N_6M(b11,cs_b) + + m_remainder -= 6; + i += 6; + } + else if(m_remainder == 5) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_3nx8m(a01,b10,cs_b,p_lda,k_iter) + + BLIS_PRE_STRSM_SMALL_3N_5M(AlphaVal,b11,cs_b) + + ///implement TRSM/// + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack)); + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm3, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm3, ymm7); + + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + a11 += cs_a; + + //extract a22 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm5, ymm7); + + ymm7 = STRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + BLIS_POST_STRSM_SMALL_3N_5M(b11,cs_b) + + m_remainder -= 5; + i += 5; + } + else if(m_remainder == 4) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_3nx8m(a01,b10,cs_b,p_lda,k_iter) + + BLIS_PRE_STRSM_SMALL_3N_4M(AlphaVal,b11,cs_b) + + ///implement TRSM/// + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack)); + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm3, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm3, ymm7); + + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + a11 += cs_a; + + //extract a22 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_ss((float const 
*)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm5, ymm7); + + ymm7 = STRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + BLIS_POST_STRSM_SMALL_3N_4M(b11,cs_b) + + m_remainder -= 4; + i += 4; + } + else if(m_remainder == 3) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_3nx8m(a01,b10,cs_b,p_lda,k_iter) + + BLIS_PRE_STRSM_SMALL_3N_3M(AlphaVal,b11,cs_b) + + ///implement TRSM/// + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack)); + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm3, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm3, ymm7); + + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + a11 += cs_a; + + //extract a22 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm5, ymm7); + + ymm7 = STRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + BLIS_POST_STRSM_SMALL_3N_3M(b11,cs_b) + + m_remainder -= 3; + i += 3; + } + else if(m_remainder == 2) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_3nx8m(a01,b10,cs_b,p_lda,k_iter) + + BLIS_PRE_STRSM_SMALL_3N_2M(AlphaVal,b11,cs_b) + + ///implement TRSM/// + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack)); + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm3, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm3, ymm7); + + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + a11 += cs_a; + + //extract a22 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm5, ymm7); + + ymm7 = STRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + BLIS_POST_STRSM_SMALL_3N_2M(b11,cs_b) + + m_remainder -= 2; + i += 2; + } + else if(m_remainder == 1) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + 
BLIS_STRSM_SMALL_GEMM_3nx8m(a01,b10,cs_b,p_lda,k_iter) + + BLIS_PRE_STRSM_SMALL_3N_1M(AlphaVal,b11,cs_b) + + ///implement TRSM/// + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack)); + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm3, ymm5); + + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm3, ymm7); + + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + a11 += cs_a; + + //extract a22 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 2*rs_a)); + ymm7 = _mm256_fnmadd_ps(ymm1, ymm5, ymm7); + + ymm7 = STRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + BLIS_POST_STRSM_SMALL_3N_1M(b11,cs_b) + + m_remainder -= 1; + i += 1; + } + + j += 3; + n_remainder -= 3; + } + else if(n_remainder == 2) + { + a01 = L + j*rs_a; //pointer to block of A to be used in GEMM + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + + float *ptr_a10_dup = D_A_pack; + + dim_t p_lda = j; // packed leading dimension + // perform copy of A to packed buffer D_A_pack + + if(transa) + { + __m128 xmm0, xmm1, xmm2, xmm3; + __m128 xmm4, xmm5, xmm6, xmm7; + __m128 xmm8, xmm9; + + for(dim_t x =0;x < p_lda;x+=d_nr) + { + xmm0 = _mm_loadu_ps((float const *)(a01)); + xmm1 = _mm_loadu_ps((float const *)(a01 + cs_a)); + xmm2 = _mm_loadu_ps((float const *)(a01 + cs_a * 2)); + xmm3 = _mm_loadu_ps((float const *)(a01 + cs_a * 3)); + + xmm4 = _mm_unpacklo_ps(xmm0, xmm1); + xmm5 = _mm_unpacklo_ps(xmm2, xmm3); + xmm6 = _mm_shuffle_ps(xmm4,xmm5,0x44); + xmm7 = _mm_shuffle_ps(xmm4,xmm5,0xEE); + + xmm0 = _mm_unpackhi_ps(xmm0, xmm1); + xmm1 = _mm_unpackhi_ps(xmm2, xmm3); + xmm8 = _mm_shuffle_ps(xmm0,xmm1,0x44); + xmm9 = _mm_shuffle_ps(xmm0,xmm1,0xEE); + + _mm_storeu_ps((float *)(ptr_a10_dup), xmm6); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda), xmm7); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda*2), xmm8); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda*3), xmm9); + + xmm0 = _mm_loadu_ps((float const *)(a01 + cs_a * 4)); + xmm1 = _mm_loadu_ps((float const *)(a01 + cs_a * 5)); + + xmm4 = _mm_unpacklo_ps(xmm0, xmm1); + xmm5 = _mm_broadcast_ss((float const *)&zero); + xmm6 = _mm_shuffle_ps(xmm4,xmm5,0x44); + xmm7 = _mm_shuffle_ps(xmm4,xmm5,0xEE); + + xmm0 = _mm_unpackhi_ps(xmm0, xmm1); + xmm1 = _mm_broadcast_ss((float const *)&zero); + xmm8 = _mm_shuffle_ps(xmm0,xmm1,0x44); + xmm9 = _mm_shuffle_ps(xmm0,xmm1,0xEE); + + _mm_storel_pi((__m64 *)(ptr_a10_dup + 4), xmm6); + _mm_storel_pi((__m64 *)(ptr_a10_dup + 4 + p_lda), xmm7); + _mm_storel_pi((__m64 *)(ptr_a10_dup + 4 + p_lda*2), xmm8); + _mm_storel_pi((__m64 *)(ptr_a10_dup + 4 + p_lda*3), xmm9); + + a01 += d_nr*cs_a; + ptr_a10_dup += d_nr; + } + } + else + { + __m128 xmm0,xmm1; + dim_t loop_count = (n-n_remainder)/6; + for(dim_t i=0; i < loop_count; i++) + { + xmm1 = _mm_broadcast_ss((float *)&zero); + + xmm0 = _mm_loadu_ps((float *)(a01 + rs_a * 0 + i*6)); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda * 0 + i*6), xmm0); + xmm0 = _mm_loadl_pi(xmm1,(__m64 *)(a01 + rs_a * 0 + 4 + i*6)); + _mm_storel_pi((__m64 *)(ptr_a10_dup + p_lda * 0 + 4 + i*6),xmm0); + + xmm0 = _mm_loadu_ps((float const *)(a01 + rs_a * 1 + i*6)); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda * 1 + i*6), xmm0); + xmm0 = _mm_loadl_pi(xmm1,(__m64 *)(a01 + rs_a * 1 + 4 + 
i*6)); + _mm_storel_pi((__m64 *)(ptr_a10_dup + p_lda * 1 + 4 + i*6),xmm0); + } + } + + ymm4 = _mm256_broadcast_ss((float const *)&ones); + if(!is_unitdiag) + { + if(transa) + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_ss((float const *)(a11)); + ymm1 = _mm256_broadcast_ss((float const *)(a11+ cs_a*1 + 1)); + } + else + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_ss((float const *)(a11)); + ymm1 = _mm256_broadcast_ss((float const *)(a11+ rs_a*1 + 1)); + } + + ymm0 = _mm256_unpacklo_ps(ymm0, ymm1); + //ymm1 = _mm256_unpacklo_ps(ymm2, ymm3); + + ymm1 = _mm256_blend_ps(ymm0, ymm0, 0x0C); + #ifdef BLIS_DISABLE_TRSM_PREINVERSION + ymm4 = ymm1; + #endif + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + ymm4 = _mm256_div_ps(ymm4, ymm1); + #endif + } + _mm256_storeu_ps((float *)(d11_pack), ymm4); + + for(i = 0; (i+d_mr-1) < m; i += d_mr) //loop along 'M' direction + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_2nx16m(a01,b10,cs_b,p_lda,k_iter) + + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); + + ymm0 = _mm256_loadu_ps((float const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_ps((float const *)(b11 + 8)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] + + ymm3 = _mm256_fmsub_ps(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + ymm4 = _mm256_fmsub_ps(ymm1, ymm15, ymm4); //B11[4-7][0] * alpha-= ymm1 + + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm1 = _mm256_loadu_ps((float const *)(b11 + cs_b + 8)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] + + ymm5 = _mm256_fmsub_ps(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + ymm6 = _mm256_fmsub_ps(ymm1, ymm15, ymm6); //B11[4-7][1] * alpha -= ymm3 + + ///implement TRSM/// + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack)); + + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm4 = STRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 1*rs_a)); + + ymm5 = _mm256_fnmadd_ps(ymm1, ymm3, ymm5); + ymm6 = _mm256_fnmadd_ps(ymm1, ymm4, ymm6); + + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + ymm6 = STRSM_SMALL_DIV_OR_SCALE(ymm6, ymm0); + + _mm256_storeu_ps((float *)b11, ymm3); + _mm256_storeu_ps((float *)(b11 + 8), ymm4); + _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); + _mm256_storeu_ps((float *)(b11 + cs_b + 8), ymm6); + } + + dim_t m_remainder = m - i; + if(m_remainder >= 8) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + ymm3 = _mm256_setzero_ps(); + ymm5 = _mm256_setzero_ps(); + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_2nx8m(a01,b10,cs_b,p_lda,k_iter) + + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_ps((float const *)b11); //B11[0][0] B11[1][0] B11[2][0] 
B11[3][0] + ymm3 = _mm256_fmsub_ps(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_fmsub_ps(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + + ///implement TRSM/// + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack)); + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm3, ymm5); + + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + _mm256_storeu_ps((float *)b11, ymm3); + _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); + + m_remainder -= 8; + i += 8; + } + + if(m_remainder == 7) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + ymm3 = _mm256_setzero_ps(); + ymm5 = _mm256_setzero_ps(); + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_2nx8m(a01,b10,cs_b,p_lda,k_iter) + + BLIS_PRE_STRSM_SMALL_2N_7M(AlphaVal,b11,cs_b) + + ///implement TRSM/// + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack)); + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm3, ymm5); + + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + BLIS_POST_STRSM_SMALL_2N_7M(b11,cs_b) + + m_remainder -= 7; + i += 7; + } + else if(m_remainder == 6) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + ymm3 = _mm256_setzero_ps(); + ymm5 = _mm256_setzero_ps(); + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_2nx8m(a01,b10,cs_b,p_lda,k_iter) + + BLIS_PRE_STRSM_SMALL_2N_6M(AlphaVal,b11,cs_b) + + ///implement TRSM/// + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack)); + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm3, ymm5); + + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + BLIS_POST_STRSM_SMALL_2N_6M(b11,cs_b) + + m_remainder -= 6; + i += 6; + } + else if(m_remainder == 5) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + ymm3 = _mm256_setzero_ps(); + ymm5 = _mm256_setzero_ps(); + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_2nx8m(a01,b10,cs_b,p_lda,k_iter) + + BLIS_PRE_STRSM_SMALL_2N_5M(AlphaVal,b11,cs_b) + + ///implement TRSM/// + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack)); + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_ss((float 
const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm3, ymm5); + + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + BLIS_POST_STRSM_SMALL_2N_5M(b11,cs_b) + + m_remainder -= 5; + i += 5; + } + else if(m_remainder == 4) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + ymm3 = _mm256_setzero_ps(); + ymm5 = _mm256_setzero_ps(); + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_2nx8m(a01,b10,cs_b,p_lda,k_iter) + + BLIS_PRE_STRSM_SMALL_2N_4M(AlphaVal,b11,cs_b) + + ///implement TRSM/// + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack)); + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm3, ymm5); + + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + BLIS_POST_STRSM_SMALL_2N_4M(b11,cs_b) + + m_remainder -= 4; + i += 4; + } + else if(m_remainder == 3) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + ymm3 = _mm256_setzero_ps(); + ymm5 = _mm256_setzero_ps(); + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_2nx8m(a01,b10,cs_b,p_lda,k_iter) + + BLIS_PRE_STRSM_SMALL_2N_3M(AlphaVal,b11,cs_b) + + ///implement TRSM/// + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack)); + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm3, ymm5); + + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + BLIS_POST_STRSM_SMALL_2N_3M(b11,cs_b) + + m_remainder -= 3; + i += 3; + } + else if(m_remainder == 2) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + ymm3 = _mm256_setzero_ps(); + ymm5 = _mm256_setzero_ps(); + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_2nx8m(a01,b10,cs_b,p_lda,k_iter) + + BLIS_PRE_STRSM_SMALL_2N_2M(AlphaVal,b11,cs_b) + + ///implement TRSM/// + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack)); + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm3, ymm5); + + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + BLIS_POST_STRSM_SMALL_2N_2M(b11,cs_b) + + m_remainder -= 2; + i += 2; + } + else if(m_remainder == 1) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to 
block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + ymm3 = _mm256_setzero_ps(); + ymm5 = _mm256_setzero_ps(); + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_2nx8m(a01,b10,cs_b,p_lda,k_iter) + + BLIS_PRE_STRSM_SMALL_2N_1M(AlphaVal,b11,cs_b) + + ///implement TRSM/// + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack)); + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + //extract a11 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_ss((float const *)(a11 + 1*rs_a)); + ymm5 = _mm256_fnmadd_ps(ymm1, ymm3, ymm5); + + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + BLIS_POST_STRSM_SMALL_2N_1M(b11,cs_b) + + m_remainder -= 1; + i += 1; + } + + j += 2; + n_remainder -= 2; + } + else if(n_remainder == 1) + { + a01 = L + j*rs_a; //pointer to block of A to be used in GEMM + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + + float *ptr_a10_dup = D_A_pack; + + dim_t p_lda = j; // packed leading dimension + // perform copy of A to packed buffer D_A_pack + + if(transa) + { + __m128 xmm0, xmm1, xmm2, xmm3; + __m128 xmm4, xmm5, xmm6, xmm7; + __m128 xmm8, xmm9; + + for(dim_t x =0;x < p_lda;x+=d_nr) + { + xmm0 = _mm_loadu_ps((float const *)(a01)); + xmm1 = _mm_loadu_ps((float const *)(a01 + cs_a)); + xmm2 = _mm_loadu_ps((float const *)(a01 + cs_a * 2)); + xmm3 = _mm_loadu_ps((float const *)(a01 + cs_a * 3)); + + xmm4 = _mm_unpacklo_ps(xmm0, xmm1); + xmm5 = _mm_unpacklo_ps(xmm2, xmm3); + xmm6 = _mm_shuffle_ps(xmm4,xmm5,0x44); + xmm7 = _mm_shuffle_ps(xmm4,xmm5,0xEE); + + xmm0 = _mm_unpackhi_ps(xmm0, xmm1); + xmm1 = _mm_unpackhi_ps(xmm2, xmm3); + xmm8 = _mm_shuffle_ps(xmm0,xmm1,0x44); + xmm9 = _mm_shuffle_ps(xmm0,xmm1,0xEE); + + _mm_storeu_ps((float *)(ptr_a10_dup), xmm6); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda), xmm7); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda*2), xmm8); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda*3), xmm9); + + xmm0 = _mm_loadu_ps((float const *)(a01 + cs_a * 4)); + xmm1 = _mm_loadu_ps((float const *)(a01 + cs_a * 5)); + + xmm4 = _mm_unpacklo_ps(xmm0, xmm1); + xmm5 = _mm_broadcast_ss((float const *)&zero); + xmm6 = _mm_shuffle_ps(xmm4,xmm5,0x44); + xmm7 = _mm_shuffle_ps(xmm4,xmm5,0xEE); + + xmm0 = _mm_unpackhi_ps(xmm0, xmm1); + xmm1 = _mm_broadcast_ss((float const *)&zero); + xmm8 = _mm_shuffle_ps(xmm0,xmm1,0x44); + xmm9 = _mm_shuffle_ps(xmm0,xmm1,0xEE); + + _mm_storel_pi((__m64 *)(ptr_a10_dup + 4), xmm6); + _mm_storel_pi((__m64 *)(ptr_a10_dup + 4 + p_lda), xmm7); + _mm_storel_pi((__m64 *)(ptr_a10_dup + 4 + p_lda*2), xmm8); + _mm_storel_pi((__m64 *)(ptr_a10_dup + 4 + p_lda*3), xmm9); + + a01 += d_nr*cs_a; + ptr_a10_dup += d_nr; + } + } + else + { + __m128 xmm0,xmm1; + dim_t loop_count = (n-n_remainder)/6; + for(dim_t i=0; i < loop_count; i++) + { + xmm1 = _mm_broadcast_ss((float *)&zero); + + xmm0 = _mm_loadu_ps((float *)(a01 + rs_a * 0 + i*6)); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda * 0 + i*6), xmm0); + xmm0 = _mm_loadl_pi(xmm1,(__m64 *)(a01 + rs_a * 0 + 4 + i*6)); + _mm_storel_pi((__m64 *)(ptr_a10_dup + p_lda * 0 + 4 + i*6),xmm0); + } + + if((n - n_remainder - (loop_count*6))/4 != 0) + { + xmm1 = _mm_broadcast_ss((float *)&zero); + + xmm0 = _mm_loadu_ps((float *)(a01 + rs_a * 0 + loop_count*6)); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda * 0 + loop_count*6), xmm0); + xmm0 = _mm_loadl_pi(xmm1,(__m64 *)(a01 + rs_a * 0 + 4 + loop_count*6)); + _mm_storel_pi((__m64 *)(ptr_a10_dup + p_lda * 0 + 4 
+ loop_count*6),xmm0); + } + } + + ymm4 = _mm256_broadcast_ss((float const *)&ones); + if(!is_unitdiag) + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_ss((float const *)(a11)); + ymm1 = _mm256_broadcast_ss((float const *)&ones); + ymm2 = _mm256_broadcast_ss((float const *)&ones); + ymm3 = _mm256_broadcast_ss((float const *)&ones); + + ymm0 = _mm256_unpacklo_ps(ymm0, ymm1); + //ymm1 = _mm256_unpacklo_ps(ymm2, ymm3); + + ymm1 = _mm256_blend_ps(ymm0, ymm0, 0x0C); + #ifdef BLIS_DISABLE_TRSM_PREINVERSION + ymm4 = ymm1; + #endif + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + ymm4 = _mm256_div_ps(ymm4, ymm1); + #endif + } + _mm256_storeu_ps((float *)(d11_pack), ymm4); + + for(i = 0; (i+d_mr-1) < m; i += d_mr) //loop along 'M' direction + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + ymm3 = _mm256_setzero_ps(); + ymm4 = _mm256_setzero_ps(); + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_1nx16m(a01,b10,cs_b,p_lda,k_iter) + + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); + + ymm0 = _mm256_loadu_ps((float const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_ps((float const *)(b11 + 8)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] + + ymm3 = _mm256_fmsub_ps(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + ymm4 = _mm256_fmsub_ps(ymm1, ymm15, ymm4); //B11[4-7][0] * alpha-= ymm1 + + ///implement TRSM/// + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack)); + + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm4 = STRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); + + _mm256_storeu_ps((float *)b11, ymm3); + _mm256_storeu_ps((float *)(b11 + 8), ymm4); + } + + dim_t m_remainder = m - i; + if(m_remainder >= 8) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + ymm3 = _mm256_setzero_ps(); + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter) + + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); //register to hold alpha + + ymm0 = _mm256_loadu_ps((float const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm3 = _mm256_fmsub_ps(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + + ///implement TRSM/// + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack)); + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + _mm256_storeu_ps((float *)b11, ymm3); + + m_remainder -= 8; + i += 8; + } + + if(m_remainder == 7) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + ymm3 = _mm256_setzero_ps(); + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter) + + BLIS_PRE_STRSM_SMALL_1N_7M(AlphaVal,b11,cs_b) + + ///implement TRSM/// + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack )); + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + BLIS_POST_STRSM_SMALL_1N_7M(b11,cs_b) + + m_remainder -= 7; + i += 7; + } + 
else if(m_remainder == 6) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + ymm3 = _mm256_setzero_ps(); + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter) + + BLIS_PRE_STRSM_SMALL_1N_6M(AlphaVal,b11,cs_b) + + ///implement TRSM/// + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack )); + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + BLIS_POST_STRSM_SMALL_1N_6M(b11,cs_b) + + m_remainder -= 6; + i += 6; + } + else if(m_remainder == 5) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + ymm3 = _mm256_setzero_ps(); + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter) + + BLIS_PRE_STRSM_SMALL_1N_5M(AlphaVal,b11,cs_b) + + ///implement TRSM/// + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack )); + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + BLIS_POST_STRSM_SMALL_1N_5M(b11,cs_b) + + m_remainder -= 5; + i += 5; + } + else if(m_remainder == 4) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + ymm3 = _mm256_setzero_ps(); + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter) + + BLIS_PRE_STRSM_SMALL_1N_4M(AlphaVal,b11,cs_b) + + ///implement TRSM/// + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack )); + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + BLIS_POST_STRSM_SMALL_1N_4M(b11,cs_b) + + m_remainder -= 4; + i += 4; + } + else if(m_remainder == 3) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + ymm3 = _mm256_setzero_ps(); + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter) + + BLIS_PRE_STRSM_SMALL_1N_3M(AlphaVal,b11,cs_b) + + ///implement TRSM/// + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack)); + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + BLIS_POST_STRSM_SMALL_1N_3M(b11,cs_b) + + m_remainder -= 3; + i += 3; + } + else if(m_remainder == 2) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + ymm3 = _mm256_setzero_ps(); + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter) + + BLIS_PRE_STRSM_SMALL_1N_2M(AlphaVal,b11,cs_b) + + ///implement TRSM/// + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack)); + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + 
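+
+            /*
+               Note (illustrative sketch): STRSM_SMALL_DIV_OR_SCALE applies the
+               packed diagonal from d11_pack. Since the pack step stores
+               reciprocals when BLIS_ENABLE_TRSM_PREINVERSION is defined and the
+               raw diagonal otherwise, a plausible definition is simply:
+
+                   #ifdef BLIS_ENABLE_TRSM_PREINVERSION
+                     #define STRSM_SMALL_DIV_OR_SCALE(b, d) _mm256_mul_ps((b), (d))
+                   #else
+                     #define STRSM_SMALL_DIV_OR_SCALE(b, d) _mm256_div_ps((b), (d))
+                   #endif
+
+               The actual definition lives with the other small-TRSM macros;
+               this comment only records the intent.
+            */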
BLIS_POST_STRSM_SMALL_1N_2M(b11,cs_b) + + m_remainder -= 2; + i += 2; + } + else if(m_remainder == 1) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + + ymm3 = _mm256_setzero_ps(); + + ///GEMM implementation starts/// + BLIS_STRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter) + + BLIS_PRE_STRSM_SMALL_1N_1M(AlphaVal,b11,cs_b) + + ///implement TRSM/// + + //extract a00 + ymm0 = _mm256_broadcast_ss((float const *)(d11_pack)); + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + BLIS_POST_STRSM_SMALL_1N_1M(b11,cs_b) + + m_remainder -= 1; + i += 1; + } + + j += 1; + n_remainder -= 1; + } + + if ((required_packing_A == 1) && bli_mem_is_alloc( &local_mem_buf_A_s )) + { + bli_membrk_release(&rntm, + &local_mem_buf_A_s); + } + + return BLIS_SUCCESS; +} + +/* TRSM for the Left Upper case AX = alpha * B, Single precision + * A is Left side, upper-triangular, transpose, non-unit/unit diagonal + * dimensions A: mxm X: mxn B: mxn + a10 ----> b11---> + *********** ***************** + * * * * *b01*b11* * * + **a10 * * a11 b11 * * * * * + ********* | | ***************** + *a11* * | | * * * * * + * * * | | * * * * * + ****** v v ***************** + * * * * * * * + * * * * * * * + * * ***************** + * + a11---> + + * TRSM for the case AX = alpha * B, Single precision + * A is Left side, lower-triangular, no-transpose, non-unit/unit diagonal + * dimensions A: mxm X: mxn B: mxn + + b01---> + * ***************** + ** * * * * * + * * * * * * * + * * *b01* * * * + * * * * * * * +a10 ****** b11 ***************** + | * * * | * * * * * + | * * * | * * * * * + | *a10*a11* | *b11* * * * + v * * * v * * * * * + *********** ***************** + * * * * * * * * * + * * * * * * * * * + * * * * * * * * * + * * * * * * * * * + **************** ***************** + a11---> +*/ +BLIS_INLINE err_t bli_strsm_small_AutXB_AlXB +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +) +{ + dim_t m = bli_obj_length(b); // number of rows of matrix B + dim_t n = bli_obj_width(b); // number of columns of matrix B + + bool transa = bli_obj_has_trans(a); + dim_t cs_a, rs_a; + dim_t d_mr = 16,d_nr = 6; + + // Swap rs_a & cs_a in case of non-tranpose. 
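+
+    /*
+       Note (illustrative): exchanging the two strides is what lets a single
+       kernel cover both the transposed and the non-transposed variant, since
+
+           A (i,j) = a[ i*rs + j*cs ]
+           A'(i,j) = a[ j*rs + i*cs ]
+
+       i.e. reading A' is the same as reading A with rs and cs swapped, so the
+       packing, GEMM and TRSM code below can index a single layout throughout.
+    */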
+    if(transa)
+    {
+        cs_a = bli_obj_col_stride(a); // column stride of A
+        rs_a = bli_obj_row_stride(a); // row stride of A
+    }
+    else
+    {
+        cs_a = bli_obj_row_stride(a); // row stride of A
+        rs_a = bli_obj_col_stride(a); // column stride of A
+    }
+    dim_t cs_b = bli_obj_col_stride(b); // column stride of B
+
+    dim_t i, j, k;    //loop variables
+    dim_t k_iter;     //number of times GEMM to be performed
+
+    float AlphaVal = *(float *)AlphaObj->buffer;    //value of alpha
+    float *L = a->buffer;       //pointer to matrix A
+    float *B = b->buffer;       //pointer to matrix B
+
+    float *a10, *a11, *b01, *b11;    //pointers that point to blocks for GEMM and TRSM
+
+    float ones = 1.0;
+    float zero = 0.0;
+
+    bool is_unitdiag = bli_obj_has_unit_diag(a);
+
+    //scratch registers
+    __m256 ymm0, ymm1, ymm2, ymm3;
+    __m256 ymm4, ymm5, ymm6, ymm7;
+    __m256 ymm8, ymm9, ymm10, ymm11;
+    __m256 ymm12, ymm13, ymm14, ymm15;
+    __m256 ymm16, ymm17, ymm18, ymm19;
+    __m256 ymm20,ymm21,ymm22;
+
+    gint_t required_packing_A = 1;
+    mem_t local_mem_buf_A_s = {0};
+    float *D_A_pack = NULL;
+    float d11_pack[d_mr] __attribute__((aligned(64)));
+    rntm_t rntm;
+
+    bli_rntm_init_from_global( &rntm );
+    bli_rntm_set_num_threads_only( 1, &rntm );
+    bli_membrk_rntm_set_membrk( &rntm );
+
+    siz_t buffer_size = bli_pool_block_size(
+                            bli_membrk_pool(
+                                bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK),
+                                bli_rntm_membrk(&rntm)));
+
+    if ( (d_mr * m * sizeof(float)) > buffer_size)
+        return BLIS_NOT_YET_IMPLEMENTED;
+
+    if (required_packing_A == 1)
+    {
+        // Get the buffer from the pool.
+        bli_membrk_acquire_m(&rntm,
+                             buffer_size,
+                             BLIS_BITVAL_BUFFER_FOR_A_BLOCK,
+                             &local_mem_buf_A_s);
+        if(FALSE==bli_mem_is_alloc(&local_mem_buf_A_s)) return BLIS_NULL_POINTER;
+        D_A_pack = bli_mem_buffer(&local_mem_buf_A_s);
+        if(NULL==D_A_pack) return BLIS_NULL_POINTER;
+    }
+
+    /*
+        Performs solving TRSM for 16 columns at a time from 0 to m/16 in steps of d_mr
+        a. Load, transpose, Pack A (a10 block), the size of packing 16x6 to 16x (m-16)
+           First there will be no GEMM and no packing of a10 because it is only TRSM
+        b. Using packed a10 block and b01 block perform GEMM operation
+        c. Use GEMM outputs, perform TRSM operation using a11, b11 and update B
+        d. Repeat b,c for n rows of B in steps of d_nr
+    */
+    for(i = 0;(i+d_mr-1) < m; i += d_mr)    //loop along 'M' dimension
+    {
+        a10 = L + (i*cs_a);             //pointer to block of A to be used for GEMM
+        a11 = L + (i*rs_a) + (i*cs_a);
+        dim_t p_lda = d_mr;             // packed leading dimension
+
+        if(transa)
+        {
+            /*
+                Load, transpose and pack current A block (a10) into packed buffer memory D_A_pack
+                a. This a10 block is used in GEMM portion only and this
+                   a10 block size will be increasing by d_mr for every next iteration
+                   until it reaches 16x(m-16) which is the maximum GEMM alone block size in A
+                b. This packed buffer is reused to calculate all n rows of B matrix
+            */
+            bli_strsm_small_pack('L', i, 1, a10, cs_a, D_A_pack, p_lda,d_mr);
+
+            /*
+                Pack 16 diagonal elements of A block into an array
+                a. This helps in utilizing cache lines efficiently in TRSM operation
+                b. Store ones when input is unit diagonal
+            */
+            strsm_small_pack_diag_element('L',is_unitdiag,a11,cs_a,d11_pack,d_mr);
+        }
+        else
+        {
+            bli_strsm_small_pack('L', i, 0, a10, rs_a, D_A_pack, p_lda,d_mr);
+            strsm_small_pack_diag_element('L',is_unitdiag,a11,rs_a,d11_pack,d_mr);
+        }
+
+        /*
+            a. Perform GEMM using a10, b01.
+            b. Perform TRSM on a11, b11
+            c. 
This loop GEMM+TRSM loops operates with 16x6 block size + along n dimension for every d_nr rows of b01 where + packed A buffer is reused in computing all n rows of B. + d. Same approch is used in remaining fringe cases. + */ + dim_t temp = n - d_nr + 1; + for(j = 0; j < temp; j += d_nr) //loop along 'N' dimension + { + a10 = D_A_pack; + a11 = L + (i*rs_a) + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + j*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + + k_iter = i; + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + /* + Peform GEMM between a10 and b01 blocks + For first itteration there will be no GEMM operation + where k_iter are zero + */ + BLIS_STRSM_SMALL_GEMM_16mx6n(a10,b01,cs_b,p_lda,k_iter) + + /* + Load b11 of size 6x16 and multiply with alpha + Add the GEMM output and perform inregister transose of b11 + to peform TRSM operation. + */ + BLIS_STRSM_SMALL_NREG_TRANSPOSE_6x16(b11,cs_b,AlphaVal) + + + ////extract a00 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack)); + + //perform mul operation + ymm10 = STRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + + //(ROw1): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm10, ymm11); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm10, ymm17); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm10, ymm18); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*4)); + ymm2 = _mm256_fnmadd_ps(ymm0, ymm10, ymm2); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*5)); + ymm3 = _mm256_fnmadd_ps(ymm0, ymm10, ymm3); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*6)); + ymm19 = _mm256_fnmadd_ps(ymm0, ymm10, ymm19); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*7)); + ymm20 = _mm256_fnmadd_ps(ymm0, ymm10, ymm20); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*8)); + ymm14 = _mm256_fnmadd_ps(ymm0, ymm10, ymm14); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*9)); + ymm15 = _mm256_fnmadd_ps(ymm0, ymm10, ymm15); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*10)); + ymm8 = _mm256_fnmadd_ps(ymm0, ymm10, ymm8 ); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*11)); + ymm9 = _mm256_fnmadd_ps(ymm0, ymm10, ymm9 ); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*12)); + ymm21 = _mm256_fnmadd_ps(ymm0, ymm10, ymm21); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*13)); + ymm22 = _mm256_fnmadd_ps(ymm0, ymm10, ymm22); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*14)); + ymm4 = _mm256_fnmadd_ps(ymm0, ymm10, ymm4 ); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*15)); + ymm5 = _mm256_fnmadd_ps(ymm0, ymm10, ymm5 ); + + //perform mul operation + ymm11 = STRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + + a11 += rs_a; + + //extract a22 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); + + //(ROw2): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm11, ymm17); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm11, ymm18); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*4)); + ymm2 = _mm256_fnmadd_ps(ymm0, ymm11, ymm2); + + ymm0 = _mm256_broadcast_ss((float const 
*)(a11 + cs_a*5)); + ymm3 = _mm256_fnmadd_ps(ymm0, ymm11, ymm3); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*6)); + ymm19 = _mm256_fnmadd_ps(ymm0, ymm11, ymm19); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*7)); + ymm20 = _mm256_fnmadd_ps(ymm0, ymm11, ymm20); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*8)); + ymm14 = _mm256_fnmadd_ps(ymm0, ymm11, ymm14); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*9)); + ymm15 = _mm256_fnmadd_ps(ymm0, ymm11, ymm15); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*10)); + ymm8 = _mm256_fnmadd_ps(ymm0, ymm11, ymm8 ); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*11)); + ymm9 = _mm256_fnmadd_ps(ymm0, ymm11, ymm9 ); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*12)); + ymm21 = _mm256_fnmadd_ps(ymm0, ymm11, ymm21); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*13)); + ymm22 = _mm256_fnmadd_ps(ymm0, ymm11, ymm22); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*14)); + ymm4 = _mm256_fnmadd_ps(ymm0, ymm11, ymm4 ); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*15)); + ymm5 = _mm256_fnmadd_ps(ymm0, ymm11, ymm5 ); + + //perform mul operation + ymm17 = STRSM_SMALL_DIV_OR_SCALE(ymm17, ymm1); + + a11 += rs_a; + + //extract a33 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 3)); + + //(ROw5): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm17, ymm18); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*4)); + ymm2 = _mm256_fnmadd_ps(ymm0, ymm17, ymm2); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*5)); + ymm3 = _mm256_fnmadd_ps(ymm0, ymm17, ymm3); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*6)); + ymm19 = _mm256_fnmadd_ps(ymm0, ymm17, ymm19); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*7)); + ymm20 = _mm256_fnmadd_ps(ymm0, ymm17, ymm20); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*8)); + ymm14 = _mm256_fnmadd_ps(ymm0, ymm17, ymm14); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*9)); + ymm15 = _mm256_fnmadd_ps(ymm0, ymm17, ymm15); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*10)); + ymm8 = _mm256_fnmadd_ps(ymm0, ymm17, ymm8 ); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*11)); + ymm9 = _mm256_fnmadd_ps(ymm0, ymm17, ymm9 ); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*12)); + ymm21 = _mm256_fnmadd_ps(ymm0, ymm17, ymm21); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*13)); + ymm22 = _mm256_fnmadd_ps(ymm0, ymm17, ymm22); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*14)); + ymm4 = _mm256_fnmadd_ps(ymm0, ymm17, ymm4 ); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*15)); + ymm5 = _mm256_fnmadd_ps(ymm0, ymm17, ymm5 ); + + //perform mul operation + ymm18 = STRSM_SMALL_DIV_OR_SCALE(ymm18, ymm1); + + a11 += rs_a; + + //extract a44 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 4)); + //(ROw4): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*4)); + ymm2 = _mm256_fnmadd_ps(ymm0, ymm18, ymm2); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*5)); + ymm3 = _mm256_fnmadd_ps(ymm0, ymm18, ymm3); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*6)); + ymm19 = _mm256_fnmadd_ps(ymm0, ymm18, ymm19); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*7)); + ymm20 = _mm256_fnmadd_ps(ymm0, ymm18, ymm20); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*8)); + ymm14 = _mm256_fnmadd_ps(ymm0, 
ymm18, ymm14); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*9)); + ymm15 = _mm256_fnmadd_ps(ymm0, ymm18, ymm15); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*10)); + ymm8 = _mm256_fnmadd_ps(ymm0, ymm18, ymm8 ); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*11)); + ymm9 = _mm256_fnmadd_ps(ymm0, ymm18, ymm9 ); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*12)); + ymm21 = _mm256_fnmadd_ps(ymm0, ymm18, ymm21); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*13)); + ymm22 = _mm256_fnmadd_ps(ymm0, ymm18, ymm22); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*14)); + ymm4 = _mm256_fnmadd_ps(ymm0, ymm18, ymm4 ); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*15)); + ymm5 = _mm256_fnmadd_ps(ymm0, ymm18, ymm5 ); + + //perform mul operation + ymm2 = STRSM_SMALL_DIV_OR_SCALE(ymm2, ymm1); + + a11 += rs_a; + + //extract a55 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 5)); + + //(ROw5): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*5)); + ymm3 = _mm256_fnmadd_ps(ymm0, ymm2, ymm3); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*6)); + ymm19 = _mm256_fnmadd_ps(ymm0, ymm2, ymm19); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*7)); + ymm20 = _mm256_fnmadd_ps(ymm0, ymm2, ymm20); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*8)); + ymm14 = _mm256_fnmadd_ps(ymm0, ymm2, ymm14); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*9)); + ymm15 = _mm256_fnmadd_ps(ymm0, ymm2, ymm15); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*10)); + ymm8 = _mm256_fnmadd_ps(ymm0, ymm2, ymm8 ); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*11)); + ymm9 = _mm256_fnmadd_ps(ymm0, ymm2, ymm9 ); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*12)); + ymm21 = _mm256_fnmadd_ps(ymm0, ymm2, ymm21); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*13)); + ymm22 = _mm256_fnmadd_ps(ymm0, ymm2, ymm22); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*14)); + ymm4 = _mm256_fnmadd_ps(ymm0, ymm2, ymm4 ); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*15)); + ymm5 = _mm256_fnmadd_ps(ymm0, ymm2, ymm5 ); + + //perform mul operation + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm1); + + a11 += rs_a; + + //extract a66 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 6)); - _mm256_storeu_pd((double *)(pbuff + 4), ymm6); - _mm256_storeu_pd((double *)(pbuff + 4 + p_lda), ymm7); - _mm256_storeu_pd((double *)(pbuff + 4 + p_lda*2), ymm8); - _mm256_storeu_pd((double *)(pbuff + 4 + p_lda*3), ymm9); + //(ROw6): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*6)); + ymm19 = _mm256_fnmadd_ps(ymm0, ymm3, ymm19); - ymm4 = _mm256_unpacklo_pd(ymm10, ymm11); - ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - ymm0 = _mm256_unpackhi_pd(ymm10, ymm11); - ymm1 = _mm256_unpackhi_pd(ymm12, ymm13); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*7)); + ymm20 = _mm256_fnmadd_ps(ymm0, ymm3, ymm20); - _mm256_storeu_pd((double *)(pbuff + 4 + p_lda * 4), ymm6); - _mm256_storeu_pd((double *)(pbuff + 4 + p_lda * 5), ymm7); - _mm256_storeu_pd((double *)(pbuff + 4 + p_lda * 6), ymm8); - _mm256_storeu_pd((double *)(pbuff + 4 + p_lda * 7), ymm9); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*8)); + ymm14 = _mm256_fnmadd_ps(ymm0, ymm3, 
ymm14); - inbuf += mr; - pbuff += mr*mr; - } - }else - { - //Expected multiples of 4 - p_lda = 8; - for(dim_t x = 0; x < size; x++) - { - ymm0 = _mm256_loadu_pd((double const *)(inbuf)); - _mm256_storeu_pd((double *)(pbuff), ymm0); - ymm1 = _mm256_loadu_pd((double const *)(inbuf + 4)); - _mm256_storeu_pd((double *)(pbuff + 4), ymm1); - inbuf+=cs_a; - pbuff+=p_lda; - } + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*9)); + ymm15 = _mm256_fnmadd_ps(ymm0, ymm3, ymm15); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*10)); + ymm8 = _mm256_fnmadd_ps(ymm0, ymm3, ymm8 ); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*11)); + ymm9 = _mm256_fnmadd_ps(ymm0, ymm3, ymm9 ); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*12)); + ymm21 = _mm256_fnmadd_ps(ymm0, ymm3, ymm21); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*13)); + ymm22 = _mm256_fnmadd_ps(ymm0, ymm3, ymm22); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*14)); + ymm4 = _mm256_fnmadd_ps(ymm0, ymm3, ymm4 ); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*15)); + ymm5 = _mm256_fnmadd_ps(ymm0, ymm3, ymm5 ); + + //perform mul operation + ymm19 = STRSM_SMALL_DIV_OR_SCALE(ymm19, ymm1); + + a11 += rs_a; + + //extract a77 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 7)); + + //(ROw7): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*7)); + ymm20 = _mm256_fnmadd_ps(ymm0, ymm19, ymm20); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*8)); + ymm14 = _mm256_fnmadd_ps(ymm0, ymm19, ymm14); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*9)); + ymm15 = _mm256_fnmadd_ps(ymm0, ymm19, ymm15); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*10)); + ymm8 = _mm256_fnmadd_ps(ymm0, ymm19, ymm8 ); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*11)); + ymm9 = _mm256_fnmadd_ps(ymm0, ymm19, ymm9 ); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*12)); + ymm21 = _mm256_fnmadd_ps(ymm0, ymm19, ymm21); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*13)); + ymm22 = _mm256_fnmadd_ps(ymm0, ymm19, ymm22); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*14)); + ymm4 = _mm256_fnmadd_ps(ymm0, ymm19, ymm4 ); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*15)); + ymm5 = _mm256_fnmadd_ps(ymm0, ymm19, ymm5 ); + + //perform mul operation + ymm20 = STRSM_SMALL_DIV_OR_SCALE(ymm20, ymm1); + + a11 += rs_a; + + //extract a88 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 8)); + + //(ROw8): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*8)); + ymm14 = _mm256_fnmadd_ps(ymm0, ymm20, ymm14); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*9)); + ymm15 = _mm256_fnmadd_ps(ymm0, ymm20, ymm15); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*10)); + ymm8 = _mm256_fnmadd_ps(ymm0, ymm20, ymm8 ); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*11)); + ymm9 = _mm256_fnmadd_ps(ymm0, ymm20, ymm9 ); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*12)); + ymm21 = _mm256_fnmadd_ps(ymm0, ymm20, ymm21); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*13)); + ymm22 = _mm256_fnmadd_ps(ymm0, ymm20, ymm22); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*14)); + ymm4 = _mm256_fnmadd_ps(ymm0, ymm20, ymm4 ); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*15)); + ymm5 = _mm256_fnmadd_ps(ymm0, ymm20, ymm5 ); + + //perform mul operation + ymm14 = STRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); + + a11 += rs_a; + + //extract a99 + ymm1 = 
_mm256_broadcast_ss((float const *)(d11_pack + 9)); + + //(ROw9): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*9)); + ymm15 = _mm256_fnmadd_ps(ymm0, ymm14, ymm15); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*10)); + ymm8 = _mm256_fnmadd_ps(ymm0, ymm14, ymm8 ); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*11)); + ymm9 = _mm256_fnmadd_ps(ymm0, ymm14, ymm9 ); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*12)); + ymm21 = _mm256_fnmadd_ps(ymm0, ymm14, ymm21); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*13)); + ymm22 = _mm256_fnmadd_ps(ymm0, ymm14, ymm22); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*14)); + ymm4 = _mm256_fnmadd_ps(ymm0, ymm14, ymm4 ); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*15)); + ymm5 = _mm256_fnmadd_ps(ymm0, ymm14, ymm5 ); + + //perform mul operation + ymm15 = STRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); + + a11 += rs_a; + + //extract a10 10 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 10)); + + //(ROw10): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*10)); + ymm8 = _mm256_fnmadd_ps(ymm0, ymm15, ymm8 ); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*11)); + ymm9 = _mm256_fnmadd_ps(ymm0, ymm15, ymm9 ); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*12)); + ymm21 = _mm256_fnmadd_ps(ymm0, ymm15, ymm21); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*13)); + ymm22 = _mm256_fnmadd_ps(ymm0, ymm15, ymm22); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*14)); + ymm4 = _mm256_fnmadd_ps(ymm0, ymm15, ymm4 ); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*15)); + ymm5 = _mm256_fnmadd_ps(ymm0, ymm15, ymm5 ); + + //perform mul operation + ymm8 = STRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); + + a11 += rs_a; + + //extract a11 11 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 11)); + + //(ROw11): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*11)); + ymm9 = _mm256_fnmadd_ps(ymm0, ymm8, ymm9 ); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*12)); + ymm21 = _mm256_fnmadd_ps(ymm0, ymm8, ymm21); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*13)); + ymm22 = _mm256_fnmadd_ps(ymm0, ymm8, ymm22); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*14)); + ymm4 = _mm256_fnmadd_ps(ymm0, ymm8, ymm4 ); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*15)); + ymm5 = _mm256_fnmadd_ps(ymm0, ymm8, ymm5 ); + + //perform mul operation + ymm9 = STRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); + + a11 += rs_a; + + //extract a12 12 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 12)); + + //(ROw12): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*12)); + ymm21 = _mm256_fnmadd_ps(ymm0, ymm9, ymm21); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*13)); + ymm22 = _mm256_fnmadd_ps(ymm0, ymm9, ymm22); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*14)); + ymm4 = _mm256_fnmadd_ps(ymm0, ymm9, ymm4 ); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*15)); + ymm5 = _mm256_fnmadd_ps(ymm0, ymm9, ymm5 ); + + //perform mul operation + ymm21 = STRSM_SMALL_DIV_OR_SCALE(ymm21, ymm1); + + a11 += rs_a; + + //extract a13 13 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 13)); + + //(ROw13): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*13)); + ymm22 = _mm256_fnmadd_ps(ymm0, ymm21, ymm22); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*14)); + ymm4 = _mm256_fnmadd_ps(ymm0, ymm21, 
ymm4 ); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*15)); + ymm5 = _mm256_fnmadd_ps(ymm0, ymm21, ymm5 ); + + //perform mul operation + ymm22 = STRSM_SMALL_DIV_OR_SCALE(ymm22, ymm1); + + a11 += rs_a; + + //extract a14 14 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 14)); + + //(ROw13): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*14)); + ymm4 = _mm256_fnmadd_ps(ymm0, ymm22, ymm4 ); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*15)); + ymm5 = _mm256_fnmadd_ps(ymm0, ymm22, ymm5 ); + + //perform mul operation + ymm4 = STRSM_SMALL_DIV_OR_SCALE(ymm4, ymm1); + + a11 += rs_a; + + //extract a15 15 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 15)); + + //(ROw15): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*15)); + ymm5 = _mm256_fnmadd_ps(ymm0, ymm4, ymm5 ); + + //perform mul operation + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm1); + + a11 += rs_a; + + BLIS_STRSM_SMALL_NREG_TRANSPOSE_16x6_AND_STORE(b11,cs_b) } - }else if(side=='R'||side=='r') - { - if(trans) + dim_t n_rem = n-j; + if(n_rem >= 4) { - /* - ------------------ ---------- - | | | | | | - | 4x4 | 4x4 | | 4x4 |4x2 | - ------------- ==> ------------- - | | | | | | - | 2x4 | 2x4 | | 2x4 |2x2 | - ------------------- ------------- - */ - for(dim_t x=0; x>2); i++) - { - ymm0 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 0 )); - _mm256_storeu_pd((double *)(pbuff + p_lda * 0), ymm0); - ymm1 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 1 )); - _mm256_storeu_pd((double *)(pbuff + p_lda * 1), ymm1); - ymm2 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 2)); - _mm256_storeu_pd((double *)(pbuff + p_lda * 2), ymm2); - ymm3 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 3 )); - _mm256_storeu_pd((double *)(pbuff + p_lda * 3), ymm3); - ymm0 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 4 )); - _mm256_storeu_pd((double *)(pbuff + p_lda * 4), ymm0); - ymm1 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 5)); - _mm256_storeu_pd((double *)(pbuff + p_lda * 5), ymm1); - inbuf += 4; - pbuff += 4; - } + //perform mul operation + ymm17 = STRSM_SMALL_DIV_OR_SCALE(ymm17, ymm1); - if(size & 0x3) - { - xmm0 = _mm_loadu_pd((double const *)(inbuf + cs_a * 0)); - _mm_storeu_pd((double *)(pbuff + p_lda * 0 ), xmm0); - xmm1 = _mm_loadu_pd((double const *)(inbuf + cs_a * 1)); - _mm_storeu_pd((double *)(pbuff + p_lda * 1), xmm1); - xmm2 = _mm_loadu_pd((double const *)(inbuf + cs_a * 2)); - _mm_storeu_pd((double *)(pbuff + p_lda * 2), xmm2); - xmm3 = _mm_loadu_pd((double const *)(inbuf + cs_a * 3)); - _mm_storeu_pd((double *)(pbuff + p_lda * 3), xmm3); - xmm0 = _mm_loadu_pd((double const *)(inbuf + cs_a * 4)); - _mm_storeu_pd((double *)(pbuff + p_lda * 4), xmm0); - xmm1 = _mm_loadu_pd((double const *)(inbuf + cs_a * 5)); - _mm_storeu_pd((double *)(pbuff + p_lda * 5), xmm1); - } - } - } -} -/* - Pack diagonal elements of A block (8 or 6) into an array - a. This helps in utilze cache line efficiently in TRSM operation - b. store ones when input is unit diagonal -*/ -BLIS_INLINE void dtrsm_small_pack_diag_element -( - bool is_unitdiag, - double *a11, - dim_t cs_a, - double *d11_pack, - dim_t size -) -{ - __m256d ymm0, ymm1, ymm2, ymm3; - __m256d ymm4, ymm5; - double ones = 1.0; - bool is_eight = (size==8) ? 
1 : 0; - ymm4 = ymm5 = _mm256_broadcast_sd((double const *)&ones); - if(!is_unitdiag) - { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_sd((double const *)(a11)); - ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a +1)); - ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a*2 + 2)); - ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a*3 + 3)); + a11 += rs_a; - //Pick one element each column and create a 4 element vector and store - ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); - ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); + //extract a33 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 3)); - #ifdef BLIS_DISABLE_TRSM_PREINVERSION - ymm4 = ymm1; - #endif - #ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm4 = _mm256_div_pd(ymm4, ymm1); - #endif + //(ROw5): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm17, ymm18); - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_sd((double const *)(a11 + 4 + cs_a*4)); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5 + cs_a*5)); - //Pick one element each column and create a 4 element vector and store - ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); - if(is_eight) { - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6 + cs_a*6)); - ymm3 = _mm256_broadcast_sd((double const *)(a11 + 7 + cs_a*7)); - ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); - } - ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*4)); + ymm2 = _mm256_fnmadd_ps(ymm0, ymm17, ymm2); - #ifdef BLIS_DISABLE_TRSM_PREINVERSION - ymm5 = ymm1; - #endif + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*5)); + ymm3 = _mm256_fnmadd_ps(ymm0, ymm17, ymm3); - #ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm5 = _mm256_div_pd(ymm5, ymm1); - #endif - } - _mm256_store_pd((double *)(d11_pack), ymm4); - if(is_eight){ - _mm256_store_pd((double *)(d11_pack + 4), ymm5); - }else{ - _mm_storeu_pd((double *)(d11_pack + 4), _mm256_extractf128_pd(ymm5,0)); - } -} + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*6)); + ymm19 = _mm256_fnmadd_ps(ymm0, ymm17, ymm19); -/* - * Kernels Table -*/ -trsmsmall_ker_ft ker_fps[4][8] = -{ - {bli_strsm_small_AutXB_AlXB, - bli_strsm_small_AltXB_AuXB, - bli_strsm_small_AltXB_AuXB, - bli_strsm_small_AutXB_AlXB, - bli_strsm_small_XAutB_XAlB, - bli_strsm_small_XAltB_XAuB, - bli_strsm_small_XAltB_XAuB, - bli_strsm_small_XAutB_XAlB }, - - {bli_ctrsm_small_AutXB_AlXB, - bli_ctrsm_small_AltXB_AuXB, - bli_ctrsm_small_AltXB_AuXB, - bli_ctrsm_small_AutXB_AlXB, - bli_ctrsm_small_XAutB_XAlB, - bli_ctrsm_small_XAltB_XAuB, - bli_ctrsm_small_XAltB_XAuB, - bli_ctrsm_small_XAutB_XAlB }, - - {bli_dtrsm_small_AutXB_AlXB, - bli_dtrsm_small_AltXB_AuXB, - bli_dtrsm_small_AltXB_AuXB, - bli_dtrsm_small_AutXB_AlXB, - bli_dtrsm_small_XAutB_XAlB, - bli_dtrsm_small_XAltB_XAuB, - bli_dtrsm_small_XAltB_XAuB, - bli_dtrsm_small_XAutB_XAlB }, - - {bli_ztrsm_small_AutXB_AlXB, - bli_ztrsm_small_AltXB_AuXB, - bli_ztrsm_small_AltXB_AuXB, - bli_ztrsm_small_AutXB_AlXB, - bli_ztrsm_small_XAutB_XAlB, - bli_ztrsm_small_XAltB_XAuB, - bli_ztrsm_small_XAltB_XAuB, - bli_ztrsm_small_XAutB_XAlB }, -}; + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*7)); + ymm20 = _mm256_fnmadd_ps(ymm0, ymm17, ymm20); -/* -* The bli_trsm_small implements a version of TRSM where A is packed and reused -* -* Input: A: MxM (triangular matrix) -* B: MxN matrix -* Output: X: MxN matrix such that - AX = alpha*B or XA = alpha*B or A'X = alpha*B or XA' = alpha*B -* Here 
the output X is stored in B -* -* Note: Currently only dtrsm is supported when A & B are column-major -*/ -err_t bli_trsm_small -( - side_t side, - obj_t* alpha, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl -) -{ - err_t err; - dim_t m = bli_obj_length(b); - dim_t n = bli_obj_width(b); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*8)); + ymm14 = _mm256_fnmadd_ps(ymm0, ymm17, ymm14); - if(!(m && n)) { - return BLIS_SUCCESS; - } + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*9)); + ymm15 = _mm256_fnmadd_ps(ymm0, ymm17, ymm15); - bool uplo = bli_obj_is_upper(a); - bool transa = bli_obj_has_trans(a); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*10)); + ymm8 = _mm256_fnmadd_ps(ymm0, ymm17, ymm8 ); - /* ToDo: Temporary threshold condition for trsm single thread. - * It will be updated with arch based threshold function which reads - * tunned thresholds for all 64 (datatype,side,uplo,transa,unit,) trsm - combinations. We arrived to this condition based on performance - comparsion with only available native path - */ - if(m > 1000 || n > 1000) { - return BLIS_NOT_YET_IMPLEMENTED; - } + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*11)); + ymm9 = _mm256_fnmadd_ps(ymm0, ymm17, ymm9 ); - /* If alpha is zero, B matrix will become zero after scaling - hence solution is also zero matrix */ - if (bli_obj_equals(alpha, &BLIS_ZERO)) { - return BLIS_NOT_YET_IMPLEMENTED; // scale B by alpha - } + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*12)); + ymm21 = _mm256_fnmadd_ps(ymm0, ymm17, ymm21); - // Return if inputs are row major as currently - // we are supporing col major only - if ((bli_obj_row_stride(a) != 1) || - (bli_obj_row_stride(b) != 1)) { - return BLIS_INVALID_ROW_STRIDE; - } + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*13)); + ymm22 = _mm256_fnmadd_ps(ymm0, ymm17, ymm22); - //Curretnly optimized for double data type only - num_t dt = bli_obj_dt(a); - if (dt != BLIS_DOUBLE && dt != BLIS_DCOMPLEX) { - return BLIS_NOT_YET_IMPLEMENTED; - } + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*14)); + ymm4 = _mm256_fnmadd_ps(ymm0, ymm17, ymm4 ); - // A is expected to be triangular in trsm - if (!bli_obj_is_upper_or_lower (a)) { - return BLIS_EXPECTED_TRIANGULAR_OBJECT; - } + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*15)); + ymm5 = _mm256_fnmadd_ps(ymm0, ymm17, ymm5 ); - /* - * Compose kernel index based on inputs - */ + //perform mul operation + ymm18 = STRSM_SMALL_DIV_OR_SCALE(ymm18, ymm1); + a11 += rs_a; - dim_t keridx = ( (( side & 0x1) << 2) | - (( uplo & 0x1) << 1) | - ( transa & 0x1) ); + //extract a44 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 4)); + //(ROw4): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*4)); + ymm2 = _mm256_fnmadd_ps(ymm0, ymm18, ymm2); - trsmsmall_ker_ft ker_fp = ker_fps[dt][ keridx ]; + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*5)); + ymm3 = _mm256_fnmadd_ps(ymm0, ymm18, ymm3); - /*Call the kernel*/ - err = ker_fp - ( - alpha, - a, - b, - cntx, - cntl - ); - return err; -}; + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*6)); + ymm19 = _mm256_fnmadd_ps(ymm0, ymm18, ymm19); -/*implements TRSM for the case XA = alpha * B - *A is lower triangular, non-unit diagonal/unit diagonal, transpose - *dimensions: X:mxn A:nxn B: mxn - * - * b11---> a01 ----> - ***************** *********** - *b01*b11* * * * * * * -b11 * * * * * **a01 * * a11 - | ***************** ********* | - | * * * * * *a11* * | - | * * * * * * * * | - v ***************** ****** v 
- * * * * * * * - * * * * * * * - ***************** * * - * - *implements TRSM for the case XA = alpha * B - *A is upper triangular, non-unit diagonal/unit diagonal, no transpose - *dimensions: X:mxn A:nxn B: mxn - * - * b11---> a01 ----> - ***************** *********** - *b01*b11* * * * * * * -b11 * * * * * **a01 * * a11 - | ***************** ********* | - | * * * * * *a11* * | - | * * * * * * * * | - v ***************** ****** v - * * * * * * * - * * * * * * * - ***************** * * - * + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*7)); + ymm20 = _mm256_fnmadd_ps(ymm0, ymm18, ymm20); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*8)); + ymm14 = _mm256_fnmadd_ps(ymm0, ymm18, ymm14); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*9)); + ymm15 = _mm256_fnmadd_ps(ymm0, ymm18, ymm15); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*10)); + ymm8 = _mm256_fnmadd_ps(ymm0, ymm18, ymm8 ); -*/ + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*11)); + ymm9 = _mm256_fnmadd_ps(ymm0, ymm18, ymm9 ); -BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB -( - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl -) -{ - dim_t m = bli_obj_length(b); //number of rows - dim_t n = bli_obj_width(b); //number of columns - dim_t d_mr = 8,d_nr = 6; + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*12)); + ymm21 = _mm256_fnmadd_ps(ymm0, ymm18, ymm21); - bool transa = bli_obj_has_trans(a); - dim_t cs_a, rs_a; + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*13)); + ymm22 = _mm256_fnmadd_ps(ymm0, ymm18, ymm22); - // Swap rs_a & cs_a in case of non-tranpose. - if(transa) - { - cs_a = bli_obj_col_stride(a); // column stride of A - rs_a = bli_obj_row_stride(a); // row stride of A - } - else - { - cs_a = bli_obj_row_stride(a); // row stride of A - rs_a = bli_obj_col_stride(a); // column stride of A - } - dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*14)); + ymm4 = _mm256_fnmadd_ps(ymm0, ymm18, ymm4 ); - dim_t i, j, k; //loop variablse - dim_t k_iter; //determines the number of GEMM operations to be done + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*15)); + ymm5 = _mm256_fnmadd_ps(ymm0, ymm18, ymm5 ); - double ones = 1.0; - double zero = 0.0; - bool is_unitdiag = bli_obj_has_unit_diag(a); + //perform mul operation + ymm2 = STRSM_SMALL_DIV_OR_SCALE(ymm2, ymm1); - double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha - double* restrict L = a->buffer; //pointer to matrix A - double* restrict B = b->buffer; //pointer to matrix B + a11 += rs_a; - double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks + //extract a55 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 5)); - gint_t required_packing_A = 1; - mem_t local_mem_buf_A_s = {0}; - double *D_A_pack = NULL; - double d11_pack[d_mr] __attribute__((aligned(64))); - rntm_t rntm; + //(ROw5): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*5)); + ymm3 = _mm256_fnmadd_ps(ymm0, ymm2, ymm3); - bli_rntm_init_from_global( &rntm ); - bli_rntm_set_num_threads_only( 1, &rntm ); - bli_membrk_rntm_set_membrk( &rntm ); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*6)); + ymm19 = _mm256_fnmadd_ps(ymm0, ymm2, ymm19); - siz_t buffer_size = bli_pool_block_size( - bli_membrk_pool( - bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), - bli_rntm_membrk(&rntm))); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*7)); + ymm20 = _mm256_fnmadd_ps(ymm0, ymm2, ymm20); - if( (d_nr * 
n * sizeof(double)) > buffer_size) - return BLIS_NOT_YET_IMPLEMENTED; + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*8)); + ymm14 = _mm256_fnmadd_ps(ymm0, ymm2, ymm14); - if (required_packing_A == 1) - { - // Get the buffer from the pool. - bli_membrk_acquire_m(&rntm, - buffer_size, - BLIS_BITVAL_BUFFER_FOR_A_BLOCK, - &local_mem_buf_A_s); - if(FALSE==bli_mem_is_alloc(&local_mem_buf_A_s)) return BLIS_NULL_POINTER; - D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); - if(NULL==D_A_pack) return BLIS_NULL_POINTER; - } + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*9)); + ymm15 = _mm256_fnmadd_ps(ymm0, ymm2, ymm15); - //ymm scratch reginsters - __m256d ymm0, ymm1, ymm2, ymm3; - __m256d ymm4, ymm5, ymm6, ymm7; - __m256d ymm8, ymm9, ymm10, ymm11; - __m256d ymm12, ymm13, ymm14, ymm15; + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*10)); + ymm8 = _mm256_fnmadd_ps(ymm0, ymm2, ymm8 ); - __m128d xmm5; + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*11)); + ymm9 = _mm256_fnmadd_ps(ymm0, ymm2, ymm9 ); - /* - Performs solving TRSM for 6 rows at a time from 0 to n/6 in steps of d_nr - a. Load and pack A (a01 block), the size of packing 6x6 to 6x (n-6) - First there will be no GEMM and no packing of a01 because it is only TRSM - b. Using packed a01 block and b10 block perform GEMM operation - c. Use GEMM outputs, perform TRSM operation using a11, b11 and update B - d. Repeat b for m cols of B in steps of d_mr - */ + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*12)); + ymm21 = _mm256_fnmadd_ps(ymm0, ymm2, ymm21); - for(j = 0; (j+d_nr-1) < n; j += d_nr) //loop along 'N' direction - { - a01 = L + j*rs_a; //pointer to block of A to be used in GEMM - a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*13)); + ymm22 = _mm256_fnmadd_ps(ymm0, ymm2, ymm22); - //double *ptr_a10_dup = D_A_pack; + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*14)); + ymm4 = _mm256_fnmadd_ps(ymm0, ymm2, ymm4 ); - dim_t p_lda = j; // packed leading dimension - // perform copy of A to packed buffer D_A_pack + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*15)); + ymm5 = _mm256_fnmadd_ps(ymm0, ymm2, ymm5 ); - if(transa) - { - /* - Pack current A block (a01) into packed buffer memory D_A_pack - a. This a10 block is used in GEMM portion only and this - a01 block size will be increasing by d_nr for every next iteration - until it reaches 6x(n-6) which is the maximum GEMM alone block size in A - b. This packed buffer is reused to calculate all m cols of B matrix - */ - bli_dtrsm_small_pack('R', j, 1, a01, cs_a, D_A_pack, p_lda,d_nr); + //perform mul operation + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm1); - /* - Pack 6 diagonal elements of A block into an array - a. This helps in utilze cache line efficiently in TRSM operation - b. store ones when input is unit diagonal - */ + a11 += rs_a; - dtrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,d_nr); - } - else - { - bli_dtrsm_small_pack('R', j, 0, a01, rs_a, D_A_pack, p_lda,d_nr); - dtrsm_small_pack_diag_element(is_unitdiag,a11,rs_a,d11_pack,d_nr); - } + //extract a66 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 6)); - /* - a. Perform GEMM using a01, b10. - b. Perform TRSM on a11, b11 - c. This loop GEMM+TRSM loops operates with 8x6 block size - along m dimension for every d_mr columns of B10 where - packed A buffer is reused in computing all m cols of B. - d. Same approach is used in remaining fringe cases. 
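     A minimal scalar sketch of the blocking in steps (a)-(d), for the
     XA = alpha*B, upper-triangular, non-transpose, non-unit-diagonal case
     with column-major storage. The function name, the fixed NR constant and
     the malloc'd packing buffer are assumptions made only for this
     illustration and are not the kernel's actual code path; <stdlib.h> is
     assumed to be available.

     // Scalar sketch of the panel-blocked right-side TRSM described above.
     // Hypothetical name and helpers; mirrors steps (a)-(d) only.
     static void xaub_blocked_sketch( const double *A, double *B, double alpha,
                                      dim_t m, dim_t n, dim_t lda, dim_t ldb )
     {
         const dim_t NR = 6;   // columns of B solved per panel, mirroring d_nr
         double *a01_pack = (double *) malloc( (size_t)n * NR * sizeof(double) );

         for ( dim_t j = 0; j < n; j += NR )
         {
             dim_t nb = ( n - j < NR ) ? ( n - j ) : NR;

             // (a) pack the j x nb GEMM panel A[0:j, j:j+nb] into a contiguous buffer
             for ( dim_t jj = 0; jj < nb; jj++ )
                 for ( dim_t p = 0; p < j; p++ )
                     a01_pack[ p + jj*j ] = A[ p + (j + jj)*lda ];

             for ( dim_t jj = 0; jj < nb; jj++ )
             {
                 double *b1 = B + (j + jj)*ldb;
                 double  d  = A[ (j + jj) + (j + jj)*lda ];  // diagonal entry of a11

                 for ( dim_t i = 0; i < m; i++ )
                 {
                     // (b) GEMM update against columns of B solved in earlier panels
                     double acc = 0.0;
                     for ( dim_t p = 0; p < j; p++ )
                         acc += B[ i + p*ldb ] * a01_pack[ p + jj*j ];
                     // (c) within-panel dependencies plus the solve itself (TRSM part)
                     for ( dim_t p = j; p < j + jj; p++ )
                         acc += B[ i + p*ldb ] * A[ p + (j + jj)*lda ];
                     b1[ i ] = ( alpha*b1[ i ] - acc ) / d;
                 }
             }
             // (d) no m-fringe is needed in scalar form; the kernel splits the same
             //     i loop into 8-, 4-, 3-, 2- and 1-row cases so each can be vectorized
         }
         free( a01_pack );
     }

     Packing a01 once per panel pays off because the same panel is then read
     for every block of rows of B, so the GEMM update streams through a
     contiguous buffer instead of strided columns of A.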
- */ - for(i = 0; (i+d_mr-1) < m; i += d_mr) //loop along 'M' direction - { - a01 = D_A_pack; - a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + //(ROw6): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*6)); + ymm19 = _mm256_fnmadd_ps(ymm0, ymm3, ymm19); - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*7)); + ymm20 = _mm256_fnmadd_ps(ymm0, ymm3, ymm20); - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*8)); + ymm14 = _mm256_fnmadd_ps(ymm0, ymm3, ymm14); - /* - Peform GEMM between a01 and b10 blocks - For first itteration there will be no GEMM operation - where k_iter are zero - */ - BLIS_DTRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*9)); + ymm15 = _mm256_fnmadd_ps(ymm0, ymm3, ymm15); - /* - Load b11 of size 8x6 and multiply with alpha - Add the GEMM output to b11 - and peform TRSM operation. - */ + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*10)); + ymm8 = _mm256_fnmadd_ps(ymm0, ymm3, ymm8 ); - BLIS_PRE_DTRSM_SMALL_6x8(AlphaVal,b11,cs_b) + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*11)); + ymm9 = _mm256_fnmadd_ps(ymm0, ymm3, ymm9 ); - ///implement TRSM/// + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*12)); + ymm21 = _mm256_fnmadd_ps(ymm0, ymm3, ymm21); - /* - Compute 6x8 TRSM block by using GEMM block output in register - a. The 6x8 input (gemm outputs) are stored in combinations of ymm registers - 1. ymm3, ymm4 2. ymm5, ymm6 3. ymm7, ymm8, 4. ymm9, ymm10 - 5. ymm11, ymm12 6. ymm13,ymm14 - b. 
Towards the end TRSM output will be stored back into b11 - */ + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*13)); + ymm22 = _mm256_fnmadd_ps(ymm0, ymm3, ymm22); - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*14)); + ymm4 = _mm256_fnmadd_ps(ymm0, ymm3, ymm4 ); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*15)); + ymm5 = _mm256_fnmadd_ps(ymm0, ymm3, ymm5 ); - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + //perform mul operation + ymm19 = STRSM_SMALL_DIV_OR_SCALE(ymm19, ymm1); - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); + a11 += rs_a; - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - ymm6 = _mm256_fnmadd_pd(ymm1, ymm4, ymm6); + //extract a77 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 7)); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + //(ROw7): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*7)); + ymm20 = _mm256_fnmadd_ps(ymm0, ymm19, ymm20); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); - ymm8 = _mm256_fnmadd_pd(ymm1, ymm4, ymm8); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*8)); + ymm14 = _mm256_fnmadd_ps(ymm0, ymm19, ymm14); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*9)); + ymm15 = _mm256_fnmadd_ps(ymm0, ymm19, ymm15); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); - ymm10 = _mm256_fnmadd_pd(ymm1, ymm4, ymm10); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*10)); + ymm8 = _mm256_fnmadd_ps(ymm0, ymm19, ymm8 ); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*11)); + ymm9 = _mm256_fnmadd_ps(ymm0, ymm19, ymm9 ); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm3, ymm11); - ymm12 = _mm256_fnmadd_pd(ymm1, ymm4, ymm12); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*12)); + ymm21 = _mm256_fnmadd_ps(ymm0, ymm19, ymm21); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*13)); + ymm22 = _mm256_fnmadd_ps(ymm0, ymm19, ymm22); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm3, ymm13); - ymm14 = _mm256_fnmadd_pd(ymm1, ymm4, ymm14); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*14)); + ymm4 = _mm256_fnmadd_ps(ymm0, ymm19, ymm4 ); - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm0); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*15)); + ymm5 = _mm256_fnmadd_ps(ymm0, ymm19, ymm5 ); - a11 += cs_a; + //perform mul operation + ymm20 = STRSM_SMALL_DIV_OR_SCALE(ymm20, ymm1); - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + a11 += rs_a; - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + //extract a88 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 8)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); - ymm8 = _mm256_fnmadd_pd(ymm1, ymm6, ymm8); + //(ROw8): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*8)); + ymm14 = _mm256_fnmadd_ps(ymm0, ymm20, ymm14); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*9)); + ymm15 = _mm256_fnmadd_ps(ymm0, ymm20, ymm15); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); - ymm10 = _mm256_fnmadd_pd(ymm1, ymm6, ymm10); + ymm0 = 
_mm256_broadcast_ss((float const *)(a11 + cs_a*10)); + ymm8 = _mm256_fnmadd_ps(ymm0, ymm20, ymm8 ); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*11)); + ymm9 = _mm256_fnmadd_ps(ymm0, ymm20, ymm9 ); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm5, ymm11); - ymm12 = _mm256_fnmadd_pd(ymm1, ymm6, ymm12); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*12)); + ymm21 = _mm256_fnmadd_ps(ymm0, ymm20, ymm21); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*13)); + ymm22 = _mm256_fnmadd_ps(ymm0, ymm20, ymm22); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm5, ymm13); - ymm14 = _mm256_fnmadd_pd(ymm1, ymm6, ymm14); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*14)); + ymm4 = _mm256_fnmadd_ps(ymm0, ymm20, ymm4 ); - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm0); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*15)); + ymm5 = _mm256_fnmadd_ps(ymm0, ymm20, ymm5 ); - a11 += cs_a; + //perform mul operation + ymm14 = STRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + a11 += rs_a; - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + //extract a99 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 9)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); - ymm10 = _mm256_fnmadd_pd(ymm1, ymm8, ymm10); + //(ROw9): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*9)); + ymm15 = _mm256_fnmadd_ps(ymm0, ymm14, ymm15); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*10)); + ymm8 = _mm256_fnmadd_ps(ymm0, ymm14, ymm8 ); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm7, ymm11); - ymm12 = _mm256_fnmadd_pd(ymm1, ymm8, ymm12); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*11)); + ymm9 = _mm256_fnmadd_ps(ymm0, ymm14, ymm9 ); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*12)); + ymm21 = _mm256_fnmadd_ps(ymm0, ymm14, ymm21); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm7, ymm13); - ymm14 = _mm256_fnmadd_pd(ymm1, ymm8, ymm14); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*13)); + ymm22 = _mm256_fnmadd_ps(ymm0, ymm14, ymm22); - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm0); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*14)); + ymm4 = _mm256_fnmadd_ps(ymm0, ymm14, ymm4 ); - a11 += cs_a; + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*15)); + ymm5 = _mm256_fnmadd_ps(ymm0, ymm14, ymm5 ); - //extract a44 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + //perform mul operation + ymm15 = STRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); - //(row 4):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); + a11 += rs_a; - ymm11 = _mm256_fnmadd_pd(ymm1, ymm9, ymm11); - ymm12 = _mm256_fnmadd_pd(ymm1, ymm10, ymm12); + //extract a10 10 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 10)); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); + //(ROw10): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*10)); + ymm8 = _mm256_fnmadd_ps(ymm0, ymm15, ymm8 ); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm9, ymm13); - ymm14 = _mm256_fnmadd_pd(ymm1, ymm10, ymm14); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*11)); + ymm9 = _mm256_fnmadd_ps(ymm0, ymm15, ymm9 ); - ymm11 
= DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); - ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm0); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*12)); + ymm21 = _mm256_fnmadd_ps(ymm0, ymm15, ymm21); - a11 += cs_a; + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*13)); + ymm22 = _mm256_fnmadd_ps(ymm0, ymm15, ymm22); - //extract a55 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*14)); + ymm4 = _mm256_fnmadd_ps(ymm0, ymm15, ymm4 ); - //(Row 5): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*15)); + ymm5 = _mm256_fnmadd_ps(ymm0, ymm15, ymm5 ); + + //perform mul operation + ymm8 = STRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); + + a11 += rs_a; + + //extract a11 11 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 11)); + + //(ROw11): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*11)); + ymm9 = _mm256_fnmadd_ps(ymm0, ymm8, ymm9 ); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*12)); + ymm21 = _mm256_fnmadd_ps(ymm0, ymm8, ymm21); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*13)); + ymm22 = _mm256_fnmadd_ps(ymm0, ymm8, ymm22); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*14)); + ymm4 = _mm256_fnmadd_ps(ymm0, ymm8, ymm4 ); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*15)); + ymm5 = _mm256_fnmadd_ps(ymm0, ymm8, ymm5 ); + + //perform mul operation + ymm9 = STRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); + + a11 += rs_a; + + //extract a12 12 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 12)); + + //(ROw12): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*12)); + ymm21 = _mm256_fnmadd_ps(ymm0, ymm9, ymm21); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*13)); + ymm22 = _mm256_fnmadd_ps(ymm0, ymm9, ymm22); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*14)); + ymm4 = _mm256_fnmadd_ps(ymm0, ymm9, ymm4 ); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*15)); + ymm5 = _mm256_fnmadd_ps(ymm0, ymm9, ymm5 ); + + //perform mul operation + ymm21 = STRSM_SMALL_DIV_OR_SCALE(ymm21, ymm1); + + a11 += rs_a; - ymm13 = _mm256_fnmadd_pd(ymm1, ymm11, ymm13); - ymm14 = _mm256_fnmadd_pd(ymm1, ymm12, ymm14); + //extract a13 13 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 13)); - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm0); + //(ROw13): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*13)); + ymm22 = _mm256_fnmadd_ps(ymm0, ymm21, ymm22); - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + 4), ymm4); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b + 4), ymm6); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*2 + 4), ymm8); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - _mm256_storeu_pd((double *)(b11 + cs_b*3 + 4), ymm10); - _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); - _mm256_storeu_pd((double *)(b11 + cs_b*4 + 4), ymm12); - _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); - _mm256_storeu_pd((double *)(b11 + cs_b*5 + 4), ymm14); - } + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*14)); + ymm4 = _mm256_fnmadd_ps(ymm0, ymm21, ymm4 ); - dim_t m_remainder = m - i; - if(m_remainder >= 4) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be 
used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*15)); + ymm5 = _mm256_fnmadd_ps(ymm0, ymm21, ymm5 ); - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + //perform mul operation + ymm22 = STRSM_SMALL_DIV_OR_SCALE(ymm22, ymm1); - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS + a11 += rs_a; - ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) + //extract a14 14 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 14)); - // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_DTRSM_SMALL_6x4(AlphaVal,b11,cs_b) + //(ROw13): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*14)); + ymm4 = _mm256_fnmadd_ps(ymm0, ymm22, ymm4 ); - ///implement TRSM/// + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*15)); + ymm5 = _mm256_fnmadd_ps(ymm0, ymm22, ymm5 ); - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + //perform mul operation + ymm4 = STRSM_SMALL_DIV_OR_SCALE(ymm4, ymm1); - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + a11 += rs_a; - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); + //extract a15 15 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 15)); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); + //(ROw15): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*15)); + ymm5 = _mm256_fnmadd_ps(ymm0, ymm4, ymm5 ); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); + //perform mul operation + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm1); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm3, ymm11); + a11 += rs_a; - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm3, ymm13); + // N-register tranpose and store - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + ymm0 = _mm256_unpacklo_ps(ymm10, ymm11); + ymm1 = _mm256_unpacklo_ps(ymm17, ymm18); - a11 += cs_a; + ymm6 = _mm256_unpacklo_ps(ymm2, ymm3); + ymm7 = _mm256_unpacklo_ps(ymm19, ymm20); - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b01000100); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b01000100); - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//1 + _mm256_storeu_ps((float *)(b11), ymm16); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b11101110); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b11101110); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm5, ymm11); + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//2 + _mm256_storeu_ps((float *)(b11 + cs_b), ymm16); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm5, ymm13); + ymm0 = _mm256_unpackhi_ps(ymm10, ymm11); + ymm1 = _mm256_unpackhi_ps(ymm17, ymm18); - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + ymm6 = _mm256_unpackhi_ps(ymm2, ymm3); + ymm7 = 
_mm256_unpackhi_ps(ymm19, ymm20); - a11 += cs_a; + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b01000100); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b01000100); - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//3 + _mm256_storeu_ps((float *)(b11 + 2*cs_b), ymm16); - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b11101110); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b11101110); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm7, ymm11); + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//4 + _mm256_storeu_ps((float *)(b11 + 3*cs_b), ymm16); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm7, ymm13); + ymm0 = _mm256_unpacklo_ps(ymm14, ymm15); + ymm1 = _mm256_unpacklo_ps(ymm8, ymm9); - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + ymm6 = _mm256_unpacklo_ps(ymm21, ymm22); + ymm7 = _mm256_unpacklo_ps(ymm4, ymm5); - a11 += cs_a; + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b01000100); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b01000100); - //extract a44 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//1 + _mm256_storeu_ps((float *)(b11 + 8), ymm16); - //(row 4):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm9, ymm11); + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b11101110); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b11101110); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm9, ymm13); + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//2 + _mm256_storeu_ps((float *)(b11 + cs_b + 8), ymm16); - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); + ymm0 = _mm256_unpackhi_ps(ymm14, ymm15); + ymm1 = _mm256_unpackhi_ps(ymm8, ymm9); - a11 += cs_a; + ymm6 = _mm256_unpackhi_ps(ymm21, ymm22); + ymm7 = _mm256_unpackhi_ps(ymm4, ymm5); - //extract a55 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b01000100); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b01000100); - //(Row 5): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm11, ymm13); + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//3 + _mm256_storeu_ps((float *)(b11 + 2*cs_b + 8), ymm16); - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b11101110); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b11101110); - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); - _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//4 + _mm256_storeu_ps((float *)(b11 + 3*cs_b + 8), ymm16); - m_remainder -= 4; - i += 4; + n_rem -=4; + j +=4; } - if(m_remainder == 3) + if(n_rem) { - a01 = D_A_pack; - a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + a10 = D_A_pack; + a11 = L + (i*rs_a) + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + j*cs_b; //pointer to block of B to be used for 
GEMM + b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS + BLIS_SET_S_YMM_REG_ZEROS - ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) + if(3 == n_rem) + { + ///GEMM code begins/// + BLIS_STRSM_SMALL_GEMM_16mx3n(a10,b01,cs_b,p_lda,k_iter) - // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_DTRSM_SMALL_6x4(AlphaVal,b11,cs_b) + /* + Load b11 of size 6x16 and multiply with alpha + Add the GEMM output and perform inregister transose of b11 + to peform TRSM operation. + */ + ymm16 = _mm256_broadcast_ss((float const *)(&AlphaVal)); + ymm0 = _mm256_broadcast_ss((float const *)(&zero)); - ///implement TRSM/// + ymm17 = _mm256_loadu_ps((float const *)(b11)); + ymm18 = _mm256_loadu_ps((float const *)(b11 + cs_b)); + ymm19 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); + ymm20 = _mm256_broadcast_ss((float const *)(&ones)); - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm17 = _mm256_fmsub_ps(ymm17, ymm16, ymm8); + ymm18 = _mm256_fmsub_ps(ymm18, ymm16, ymm9); + ymm19 = _mm256_fmsub_ps(ymm19, ymm16, ymm10); - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + ymm8 = _mm256_unpacklo_ps(ymm17, ymm18); + ymm9 = _mm256_unpacklo_ps(ymm19, ymm20); - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b01000100); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); + ymm10 = _mm256_permute2f128_ps(ymm4,ymm0,0x20);//1 + ymm2 = _mm256_permute2f128_ps(ymm4,ymm0,0x31);//5 - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b11101110); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm3, ymm11); + ymm11 = _mm256_permute2f128_ps(ymm4,ymm0,0x20);//2 + ymm3 = _mm256_permute2f128_ps(ymm4,ymm0,0x31);//6 - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm3, ymm13); + ymm8 = _mm256_unpackhi_ps(ymm17, ymm18); + ymm9 = _mm256_unpackhi_ps(ymm19, ymm20); - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b01000100); - a11 += cs_a; + ymm17 = _mm256_permute2f128_ps(ymm4,ymm0,0x20);//3 + ymm19 = _mm256_permute2f128_ps(ymm4,ymm0,0x31);//7 - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b11101110); - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); + ymm18 = _mm256_permute2f128_ps(ymm4,ymm0,0x20);//4 + ymm20 = _mm256_permute2f128_ps(ymm4,ymm0,0x31);//8 - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); + ymm16 = _mm256_broadcast_ss((float const *)(&AlphaVal)); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm5, ymm11); + ymm8 = _mm256_loadu_ps((float const *)(b11 + 8)); + ymm9 = _mm256_loadu_ps((float const *)(b11 + cs_b + 8)); + ymm4 = _mm256_loadu_ps((float const *)(b11 + cs_b*2 + 8)); + ymm5 = 
_mm256_broadcast_ss((float const *)(&ones)); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm5, ymm13); + ymm8 = _mm256_fmsub_ps(ymm8, ymm16, ymm12); + ymm9 = _mm256_fmsub_ps(ymm9, ymm16, ymm13); + ymm4 = _mm256_fmsub_ps(ymm4, ymm16, ymm14); - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + ymm12 = _mm256_unpacklo_ps(ymm8, ymm9); + ymm13 = _mm256_unpacklo_ps(ymm4, ymm5); - a11 += cs_a; + ymm6 = _mm256_shuffle_ps(ymm12,ymm13,0b01000100); - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + ymm14 = _mm256_permute2f128_ps(ymm6,ymm0,0x20);//1 + ymm21 = _mm256_permute2f128_ps(ymm6,ymm0,0x31);//5 - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); + ymm6 = _mm256_shuffle_ps(ymm12,ymm13,0b11101110); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm7, ymm11); + ymm15 = _mm256_permute2f128_ps(ymm6,ymm0,0x20);//2 + ymm22 = _mm256_permute2f128_ps(ymm6,ymm0,0x31);//6 - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm7, ymm13); + ymm12 = _mm256_unpackhi_ps(ymm8, ymm9); + ymm13 = _mm256_unpackhi_ps(ymm4, ymm5); - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + ymm6 = _mm256_shuffle_ps(ymm12,ymm13,0b01000100); - a11 += cs_a; + ymm8 = _mm256_permute2f128_ps(ymm6,ymm0,0x20);//3 + ymm4 = _mm256_permute2f128_ps(ymm6,ymm0,0x31);//7 - //extract a44 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + ymm6 = _mm256_shuffle_ps(ymm12,ymm13,0b11101110); - //(row 4):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm9, ymm11); + ymm9 = _mm256_permute2f128_ps(ymm6,ymm0,0x20);//4 + ymm5 = _mm256_permute2f128_ps(ymm6,ymm0,0x31);//8 + } + else if(2 == n_rem) + { + ///GEMM code begins/// + BLIS_STRSM_SMALL_GEMM_16mx2n(a10,b01,cs_b,p_lda,k_iter) - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm9, ymm13); + /* + Load b11 of size 6x16 and multiply with alpha + Add the GEMM output and perform inregister transose of b11 + to peform TRSM operation. 
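        The unpacklo/unpackhi, shuffle and permute2f128 sequence used here is a
        partial in-register transpose; a self-contained 8x8 float transpose built
        from exactly the same pattern is shown below for reference. The function
        name and the row-major in/out layout are assumptions made only for this
        illustration.

        // Full 8x8 float transpose using the same unpack/shuffle/permute pattern.
        #include <immintrin.h>

        static void stranspose_8x8_sketch( const float *in, float *out )  // 8x8, row-major
        {
            __m256 r0 = _mm256_loadu_ps( in + 0*8 ), r1 = _mm256_loadu_ps( in + 1*8 );
            __m256 r2 = _mm256_loadu_ps( in + 2*8 ), r3 = _mm256_loadu_ps( in + 3*8 );
            __m256 r4 = _mm256_loadu_ps( in + 4*8 ), r5 = _mm256_loadu_ps( in + 5*8 );
            __m256 r6 = _mm256_loadu_ps( in + 6*8 ), r7 = _mm256_loadu_ps( in + 7*8 );

            // interleave pairs of rows within each 128-bit lane
            __m256 t0 = _mm256_unpacklo_ps( r0, r1 ), t1 = _mm256_unpackhi_ps( r0, r1 );
            __m256 t2 = _mm256_unpacklo_ps( r2, r3 ), t3 = _mm256_unpackhi_ps( r2, r3 );
            __m256 t4 = _mm256_unpacklo_ps( r4, r5 ), t5 = _mm256_unpackhi_ps( r4, r5 );
            __m256 t6 = _mm256_unpacklo_ps( r6, r7 ), t7 = _mm256_unpackhi_ps( r6, r7 );

            // gather 4-element column fragments within each lane
            __m256 s0 = _mm256_shuffle_ps( t0, t2, 0b01000100 );
            __m256 s1 = _mm256_shuffle_ps( t0, t2, 0b11101110 );
            __m256 s2 = _mm256_shuffle_ps( t1, t3, 0b01000100 );
            __m256 s3 = _mm256_shuffle_ps( t1, t3, 0b11101110 );
            __m256 s4 = _mm256_shuffle_ps( t4, t6, 0b01000100 );
            __m256 s5 = _mm256_shuffle_ps( t4, t6, 0b11101110 );
            __m256 s6 = _mm256_shuffle_ps( t5, t7, 0b01000100 );
            __m256 s7 = _mm256_shuffle_ps( t5, t7, 0b11101110 );

            // merge 128-bit lanes to complete the transposed rows
            // (0x20 keeps the low lanes, 0x31 the high lanes)
            _mm256_storeu_ps( out + 0*8, _mm256_permute2f128_ps( s0, s4, 0x20 ) );
            _mm256_storeu_ps( out + 1*8, _mm256_permute2f128_ps( s1, s5, 0x20 ) );
            _mm256_storeu_ps( out + 2*8, _mm256_permute2f128_ps( s2, s6, 0x20 ) );
            _mm256_storeu_ps( out + 3*8, _mm256_permute2f128_ps( s3, s7, 0x20 ) );
            _mm256_storeu_ps( out + 4*8, _mm256_permute2f128_ps( s0, s4, 0x31 ) );
            _mm256_storeu_ps( out + 5*8, _mm256_permute2f128_ps( s1, s5, 0x31 ) );
            _mm256_storeu_ps( out + 6*8, _mm256_permute2f128_ps( s2, s6, 0x31 ) );
            _mm256_storeu_ps( out + 7*8, _mm256_permute2f128_ps( s3, s7, 0x31 ) );
        }

        In this remainder path only n_rem real columns of b11 exist, so the
        missing inputs are filled with the broadcast `ones` vector and the same
        shuffle pattern can be reused unchanged for the remainder sizes.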
+ */ + ymm16 = _mm256_broadcast_ss((float const *)(&AlphaVal)); + ymm0 = _mm256_broadcast_ss((float const *)(&zero)); - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); + ymm17 = _mm256_loadu_ps((float const *)(b11)); + ymm18 = _mm256_loadu_ps((float const *)(b11 + cs_b)); + ymm19 = _mm256_broadcast_ss((float const *)(&ones)); + ymm20 = _mm256_broadcast_ss((float const *)(&ones)); - a11 += cs_a; + ymm17 = _mm256_fmsub_ps(ymm17, ymm16, ymm8); + ymm18 = _mm256_fmsub_ps(ymm18, ymm16, ymm9); - //extract a55 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + ymm8 = _mm256_unpacklo_ps(ymm17, ymm18); + ymm9 = _mm256_unpacklo_ps(ymm19, ymm20); - //(Row 5): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm11, ymm13); + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b01000100); - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); + ymm10 = _mm256_permute2f128_ps(ymm4,ymm0,0x20);//1 + ymm2 = _mm256_permute2f128_ps(ymm4,ymm0,0x31);//5 - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_pd(ymm0, ymm11, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_pd(ymm0, ymm13, 0x07); + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b11101110); - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); - _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); + ymm11 = _mm256_permute2f128_ps(ymm4,ymm0,0x20);//2 + ymm3 = _mm256_permute2f128_ps(ymm4,ymm0,0x31);//6 - m_remainder -= 3; - i += 3; - } - else if(m_remainder == 2) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + ymm8 = _mm256_unpackhi_ps(ymm17, ymm18); + ymm9 = _mm256_unpackhi_ps(ymm19, ymm20); - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b01000100); - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS + ymm17 = _mm256_permute2f128_ps(ymm4,ymm0,0x20);//3 + ymm19 = _mm256_permute2f128_ps(ymm4,ymm0,0x31);//7 - ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b11101110); - // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_DTRSM_SMALL_6x4(AlphaVal,b11,cs_b) + ymm18 = _mm256_permute2f128_ps(ymm4,ymm0,0x20);//4 + ymm20 = _mm256_permute2f128_ps(ymm4,ymm0,0x31);//8 - ///implement TRSM/// + ymm16 = _mm256_broadcast_ss((float const *)(&AlphaVal)); - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm8 = _mm256_loadu_ps((float const *)(b11 + 8)); + 
ymm9 = _mm256_loadu_ps((float const *)(b11 + cs_b + 8)); + ymm4 = _mm256_broadcast_ss((float const *)(&ones)); + ymm5 = _mm256_broadcast_ss((float const *)(&ones)); - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + ymm8 = _mm256_fmsub_ps(ymm8, ymm16, ymm12); + ymm9 = _mm256_fmsub_ps(ymm9, ymm16, ymm13); - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); + ymm12 = _mm256_unpacklo_ps(ymm8, ymm9); + ymm13 = _mm256_unpacklo_ps(ymm4, ymm5); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); + ymm6 = _mm256_shuffle_ps(ymm12,ymm13,0b01000100); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); + ymm14 = _mm256_permute2f128_ps(ymm6,ymm0,0x20);//1 + ymm21 = _mm256_permute2f128_ps(ymm6,ymm0,0x31);//5 - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm3, ymm11); + ymm6 = _mm256_shuffle_ps(ymm12,ymm13,0b11101110); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm3, ymm13); + ymm15 = _mm256_permute2f128_ps(ymm6,ymm0,0x20);//2 + ymm22 = _mm256_permute2f128_ps(ymm6,ymm0,0x31);//6 - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + ymm12 = _mm256_unpackhi_ps(ymm8, ymm9); + ymm13 = _mm256_unpackhi_ps(ymm4, ymm5); - a11 += cs_a; + ymm6 = _mm256_shuffle_ps(ymm12,ymm13,0b01000100); - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + ymm8 = _mm256_permute2f128_ps(ymm6,ymm0,0x20);//3 + ymm4 = _mm256_permute2f128_ps(ymm6,ymm0,0x31);//7 - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); + ymm6 = _mm256_shuffle_ps(ymm12,ymm13,0b11101110); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); + ymm9 = _mm256_permute2f128_ps(ymm6,ymm0,0x20);//4 + ymm5 = _mm256_permute2f128_ps(ymm6,ymm0,0x31);//8 + } + else if(1 == n_rem) + { + ///GEMM code begins/// + BLIS_STRSM_SMALL_GEMM_16mx1n(a10,b01,cs_b,p_lda,k_iter) - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm5, ymm11); + /* + Load b11 of size 6x16 and multiply with alpha + Add the GEMM output and perform inregister transose of b11 + to peform TRSM operation. 
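        The _mm256_fmsub_ps lines above fold the alpha scaling and the GEMM
        accumulation into one step. The tiny helper below, shown only as an
        illustration with assumed names, isolates that update for a single
        8-float column.

        // Illustrative only: forms alpha*B11 - sum(B10*A01) for one column.
        #include <immintrin.h>

        static inline __m256 sapply_alpha_sub_acc( const float *b11_col,
                                                   float alpha, __m256 gemm_acc )
        {
            __m256 alpha_v = _mm256_broadcast_ss( &alpha );
            // fused multiply-subtract: alpha*B11 minus the GEMM accumulator
            return _mm256_fmsub_ps( _mm256_loadu_ps( b11_col ), alpha_v, gemm_acc );
        }

        With 1 == n_rem only one real column of b11 is loaded; the remaining
        transpose inputs are the broadcast `ones` placeholders so the 4-input
        unpack/shuffle pattern can still be applied.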
+ */ + ymm16 = _mm256_broadcast_ss((float const *)(&AlphaVal)); + ymm0 = _mm256_broadcast_ss((float const *)(&zero)); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm5, ymm13); + ymm17 = _mm256_loadu_ps((float const *)(b11)); + ymm18 = _mm256_broadcast_ss((float const *)(&ones)); + ymm19 = _mm256_broadcast_ss((float const *)(&ones)); + ymm20 = _mm256_broadcast_ss((float const *)(&ones)); - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + ymm17 = _mm256_fmsub_ps(ymm17, ymm16, ymm8); - a11 += cs_a; + ymm8 = _mm256_unpacklo_ps(ymm17, ymm18); + ymm9 = _mm256_unpacklo_ps(ymm19, ymm20); - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b01000100); - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); + ymm10 = _mm256_permute2f128_ps(ymm4,ymm0,0x20);//1 + ymm2 = _mm256_permute2f128_ps(ymm4,ymm0,0x31);//5 - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm7, ymm11); + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b11101110); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm7, ymm13); + ymm11 = _mm256_permute2f128_ps(ymm4,ymm0,0x20);//2 + ymm3 = _mm256_permute2f128_ps(ymm4,ymm0,0x31);//6 - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + ymm8 = _mm256_unpackhi_ps(ymm17, ymm18); + ymm9 = _mm256_unpackhi_ps(ymm19, ymm20); - a11 += cs_a; + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b01000100); - //extract a44 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + ymm17 = _mm256_permute2f128_ps(ymm4,ymm0,0x20);//3 + ymm19 = _mm256_permute2f128_ps(ymm4,ymm0,0x31);//7 - //(row 4):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm9, ymm11); + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b11101110); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm9, ymm13); + ymm18 = _mm256_permute2f128_ps(ymm4,ymm0,0x20);//4 + ymm20 = _mm256_permute2f128_ps(ymm4,ymm0,0x31);//8 - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); + ymm16 = _mm256_broadcast_ss((float const *)(&AlphaVal)); - a11 += cs_a; + ymm8 = _mm256_loadu_ps((float const *)(b11 + 8)); + ymm9 = _mm256_broadcast_ss((float const *)(&ones)); + ymm4 = _mm256_broadcast_ss((float const *)(&ones)); + ymm5 = _mm256_broadcast_ss((float const *)(&ones)); - //extract a55 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + ymm8 = _mm256_fmsub_ps(ymm8, ymm16, ymm12); - //(Row 5): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm11, ymm13); + ymm12 = _mm256_unpacklo_ps(ymm8, ymm9); + ymm13 = _mm256_unpacklo_ps(ymm4, ymm5); - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); + ymm6 = _mm256_shuffle_ps(ymm12,ymm13,0b01000100); - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] 
B11[3][1] - ymm11 = _mm256_blend_pd(ymm0, ymm11, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_pd(ymm0, ymm13, 0x03); + ymm14 = _mm256_permute2f128_ps(ymm6,ymm0,0x20);//1 + ymm21 = _mm256_permute2f128_ps(ymm6,ymm0,0x31);//5 - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); - _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); + ymm6 = _mm256_shuffle_ps(ymm12,ymm13,0b11101110); - m_remainder -= 2; - i += 2; - } - else if(m_remainder == 1) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + ymm15 = _mm256_permute2f128_ps(ymm6,ymm0,0x20);//2 + ymm22 = _mm256_permute2f128_ps(ymm6,ymm0,0x31);//6 - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + ymm12 = _mm256_unpackhi_ps(ymm8, ymm9); + ymm13 = _mm256_unpackhi_ps(ymm4, ymm5); - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS + ymm6 = _mm256_shuffle_ps(ymm12,ymm13,0b01000100); - ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) + ymm8 = _mm256_permute2f128_ps(ymm6,ymm0,0x20);//3 + ymm4 = _mm256_permute2f128_ps(ymm6,ymm0,0x31);//7 - // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_DTRSM_SMALL_6x4(AlphaVal,b11,cs_b) + ymm6 = _mm256_shuffle_ps(ymm12,ymm13,0b11101110); - ///implement TRSM/// + ymm9 = _mm256_permute2f128_ps(ymm6,ymm0,0x20);//4 + ymm5 = _mm256_permute2f128_ps(ymm6,ymm0,0x31);//8 + } - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + // TRSM portion - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + ////extract a00 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack)); - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); + //perform mul operation + ymm10 = STRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); + //extract a11 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); + //(ROw1): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm10, ymm11); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm3, ymm11); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm10, ymm17); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm3, ymm13); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm10, ymm18); - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*4)); + ymm2 = _mm256_fnmadd_ps(ymm0, ymm10, ymm2); - a11 += cs_a; + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*5)); + ymm3 = _mm256_fnmadd_ps(ymm0, ymm10, ymm3); - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + ymm0 = 
_mm256_broadcast_ss((float const *)(a11 + cs_a*6)); + ymm19 = _mm256_fnmadd_ps(ymm0, ymm10, ymm19); - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*7)); + ymm20 = _mm256_fnmadd_ps(ymm0, ymm10, ymm20); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*8)); + ymm14 = _mm256_fnmadd_ps(ymm0, ymm10, ymm14); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm5, ymm11); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*9)); + ymm15 = _mm256_fnmadd_ps(ymm0, ymm10, ymm15); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm5, ymm13); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*10)); + ymm8 = _mm256_fnmadd_ps(ymm0, ymm10, ymm8 ); - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*11)); + ymm9 = _mm256_fnmadd_ps(ymm0, ymm10, ymm9 ); - a11 += cs_a; + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*12)); + ymm21 = _mm256_fnmadd_ps(ymm0, ymm10, ymm21); - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*13)); + ymm22 = _mm256_fnmadd_ps(ymm0, ymm10, ymm22); - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*14)); + ymm4 = _mm256_fnmadd_ps(ymm0, ymm10, ymm4 ); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm7, ymm11); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*15)); + ymm5 = _mm256_fnmadd_ps(ymm0, ymm10, ymm5 ); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm7, ymm13); + //perform mul operation + ymm11 = STRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + a11 += rs_a; - a11 += cs_a; + //extract a22 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); - //extract a44 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + //(ROw2): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm11, ymm17); - //(row 4):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm9, ymm11); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm11, ymm18); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm9, ymm13); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*4)); + ymm2 = _mm256_fnmadd_ps(ymm0, ymm11, ymm2); - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*5)); + ymm3 = _mm256_fnmadd_ps(ymm0, ymm11, ymm3); - a11 += cs_a; + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*6)); + ymm19 = _mm256_fnmadd_ps(ymm0, ymm11, ymm19); - //extract a55 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*7)); + ymm20 = _mm256_fnmadd_ps(ymm0, ymm11, ymm20); - //(Row 5): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm11, ymm13); + ymm0 = 
_mm256_broadcast_ss((float const *)(a11 + cs_a*8)); + ymm14 = _mm256_fnmadd_ps(ymm0, ymm11, ymm14); - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*9)); + ymm15 = _mm256_fnmadd_ps(ymm0, ymm11, ymm15); - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_pd(ymm0, ymm11, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_pd(ymm0, ymm13, 0x01); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*10)); + ymm8 = _mm256_fnmadd_ps(ymm0, ymm11, ymm8 ); - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); - _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*11)); + ymm9 = _mm256_fnmadd_ps(ymm0, ymm11, ymm9 ); - m_remainder -= 1; - i += 1; - } - } + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*12)); + ymm21 = _mm256_fnmadd_ps(ymm0, ymm11, ymm21); - dim_t n_remainder = n - j; + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*13)); + ymm22 = _mm256_fnmadd_ps(ymm0, ymm11, ymm22); - /* - Reminder cases starts here: - a. Similar logic and code flow used in computing full block (6x8) - above holds for reminder cases too. 
- */ + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*14)); + ymm4 = _mm256_fnmadd_ps(ymm0, ymm11, ymm4 ); - if(n_remainder >= 4) - { - a01 = L + j*rs_a; //pointer to block of A to be used in GEMM - a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*15)); + ymm5 = _mm256_fnmadd_ps(ymm0, ymm11, ymm5 ); - double *ptr_a10_dup = D_A_pack; + //perform mul operation + ymm17 = STRSM_SMALL_DIV_OR_SCALE(ymm17, ymm1); - dim_t p_lda = j; // packed leading dimension - // perform copy of A to packed buffer D_A_pack + a11 += rs_a; - if(transa) - { - for(dim_t x =0;x < p_lda;x+=d_nr) - { - ymm0 = _mm256_loadu_pd((double const *)(a01)); - ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); - ymm2 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2)); - ymm3 = _mm256_loadu_pd((double const *)(a01 + cs_a * 3)); + //extract a33 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 3)); - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + //(ROw5): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm17, ymm18); - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*4)); + ymm2 = _mm256_fnmadd_ps(ymm0, ymm17, ymm2); - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*5)); + ymm3 = _mm256_fnmadd_ps(ymm0, ymm17, ymm3); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*6)); + ymm19 = _mm256_fnmadd_ps(ymm0, ymm17, ymm19); - _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*7)); + ymm20 = _mm256_fnmadd_ps(ymm0, ymm17, ymm20); - ymm0 = _mm256_loadu_pd((double const *)(a01 + cs_a * 4)); - ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a * 5)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*8)); + ymm14 = _mm256_fnmadd_ps(ymm0, ymm17, ymm14); - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_broadcast_sd((double const *)&zero); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*9)); + ymm15 = _mm256_fnmadd_ps(ymm0, ymm17, ymm15); - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*10)); + ymm8 = _mm256_fnmadd_ps(ymm0, ymm17, ymm8 ); - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_broadcast_sd((double const *)&zero); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*11)); + ymm9 = _mm256_fnmadd_ps(ymm0, ymm17, ymm9 ); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*12)); + ymm21 = _mm256_fnmadd_ps(ymm0, ymm17, ymm21); - _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_extractf128_pd(ymm6,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_extractf128_pd(ymm7,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm8,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm9,0)); + ymm0 = _mm256_broadcast_ss((float const 
*)(a11 + cs_a*13)); + ymm22 = _mm256_fnmadd_ps(ymm0, ymm17, ymm22); - a01 += d_nr*cs_a; - ptr_a10_dup += d_nr; - } - } - else - { - dim_t loop_count = p_lda/4; + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*14)); + ymm4 = _mm256_fnmadd_ps(ymm0, ymm17, ymm4 ); - for(dim_t x =0;x < loop_count;x++) - { - ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 0 + x*4)); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + x*4), ymm15); - ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 1 + x*4)); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + x*4), ymm15); - ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 2 + x*4)); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 2 + x*4), ymm15); - ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 3 + x*4)); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 3 + x*4), ymm15); - } + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*15)); + ymm5 = _mm256_fnmadd_ps(ymm0, ymm17, ymm5 ); - dim_t remainder_loop_count = p_lda - loop_count*4; + //perform mul operation + ymm18 = STRSM_SMALL_DIV_OR_SCALE(ymm18, ymm1); - __m128d xmm0; - if(remainder_loop_count != 0) - { - xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 0 + loop_count*4)); - _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + loop_count*4), xmm0); - xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 1 + loop_count*4)); - _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + loop_count*4), xmm0); - xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 2 + loop_count*4)); - _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 2 + loop_count*4), xmm0); - xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 3 + loop_count*4)); - _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 3 + loop_count*4), xmm0); - } - } + a11 += rs_a; - ymm4 = _mm256_broadcast_sd((double const *)&ones); - if(!is_unitdiag) - { - if(transa) - { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_sd((double const *)(a11)); - ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a*1 + 1)); - ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a*2 + 2)); - ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a*3 + 3)); - } - else - { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_sd((double const *)(a11)); - ymm1 = _mm256_broadcast_sd((double const *)(a11+ rs_a*1 + 1)); - ymm2 = _mm256_broadcast_sd((double const *)(a11+ rs_a*2 + 2)); - ymm3 = _mm256_broadcast_sd((double const *)(a11+ rs_a*3 + 3)); - } + //extract a44 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 4)); + //(ROw4): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*4)); + ymm2 = _mm256_fnmadd_ps(ymm0, ymm18, ymm2); - ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*5)); + ymm3 = _mm256_fnmadd_ps(ymm0, ymm18, ymm3); - ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); - #ifdef BLIS_DISABLE_TRSM_PREINVERSION - ymm4 = ymm1; - #endif - #ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm4 = _mm256_div_pd(ymm4, ymm1); - #endif - } - _mm256_storeu_pd((double *)(d11_pack), ymm4); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*6)); + ymm19 = _mm256_fnmadd_ps(ymm0, ymm18, ymm19); - for(i = 0; (i+d_mr-1) < m; i += d_mr) //loop along 'M' direction - { - a01 = D_A_pack; - a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*7)); + 
ymm20 = _mm256_fnmadd_ps(ymm0, ymm18, ymm20); - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*8)); + ymm14 = _mm256_fnmadd_ps(ymm0, ymm18, ymm14); - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*9)); + ymm15 = _mm256_fnmadd_ps(ymm0, ymm18, ymm15); - ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_4nx8m(a01,b10,cs_b,p_lda,k_iter) + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*10)); + ymm8 = _mm256_fnmadd_ps(ymm0, ymm18, ymm8 ); - BLIS_PRE_DTRSM_SMALL_4x8(AlphaVal,b11,cs_b) + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*11)); + ymm9 = _mm256_fnmadd_ps(ymm0, ymm18, ymm9 ); - ///implement TRSM/// + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*12)); + ymm21 = _mm256_fnmadd_ps(ymm0, ymm18, ymm21); - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*13)); + ymm22 = _mm256_fnmadd_ps(ymm0, ymm18, ymm22); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*14)); + ymm4 = _mm256_fnmadd_ps(ymm0, ymm18, ymm4 ); - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*15)); + ymm5 = _mm256_fnmadd_ps(ymm0, ymm18, ymm5 ); - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); + //perform mul operation + ymm2 = STRSM_SMALL_DIV_OR_SCALE(ymm2, ymm1); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - ymm6 = _mm256_fnmadd_pd(ymm1, ymm4, ymm6); + a11 += rs_a; - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + //extract a55 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 5)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); - ymm8 = _mm256_fnmadd_pd(ymm1, ymm4, ymm8); + //(ROw5): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*5)); + ymm3 = _mm256_fnmadd_ps(ymm0, ymm2, ymm3); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*6)); + ymm19 = _mm256_fnmadd_ps(ymm0, ymm2, ymm19); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); - ymm10 = _mm256_fnmadd_pd(ymm1, ymm4, ymm10); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*7)); + ymm20 = _mm256_fnmadd_ps(ymm0, ymm2, ymm20); - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm0); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*8)); + ymm14 = _mm256_fnmadd_ps(ymm0, ymm2, ymm14); - a11 += cs_a; + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*9)); + ymm15 = _mm256_fnmadd_ps(ymm0, ymm2, ymm15); - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*10)); + ymm8 = _mm256_fnmadd_ps(ymm0, ymm2, ymm8 ); - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*11)); + ymm9 = _mm256_fnmadd_ps(ymm0, ymm2, ymm9 ); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); - ymm8 = _mm256_fnmadd_pd(ymm1, ymm6, ymm8); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*12)); + ymm21 = _mm256_fnmadd_ps(ymm0, ymm2, ymm21); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*13)); + ymm22 = _mm256_fnmadd_ps(ymm0, 
ymm2, ymm22); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); - ymm10 = _mm256_fnmadd_pd(ymm1, ymm6, ymm10); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*14)); + ymm4 = _mm256_fnmadd_ps(ymm0, ymm2, ymm4 ); - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm0); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*15)); + ymm5 = _mm256_fnmadd_ps(ymm0, ymm2, ymm5 ); - a11 += cs_a; + //perform mul operation + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm1); - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + a11 += rs_a; - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + //extract a66 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 6)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); - ymm10 = _mm256_fnmadd_pd(ymm1, ymm8, ymm10); + //(ROw6): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*6)); + ymm19 = _mm256_fnmadd_ps(ymm0, ymm3, ymm19); - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm0); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*7)); + ymm20 = _mm256_fnmadd_ps(ymm0, ymm3, ymm20); - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + 4), ymm4); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b + 4), ymm6); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*2 + 4), ymm8); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - _mm256_storeu_pd((double *)(b11 + cs_b*3 + 4), ymm10); - } + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*8)); + ymm14 = _mm256_fnmadd_ps(ymm0, ymm3, ymm14); - dim_t m_remainder = m - i; - if(m_remainder >= 4) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*9)); + ymm15 = _mm256_fnmadd_ps(ymm0, ymm3, ymm15); - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*10)); + ymm8 = _mm256_fnmadd_ps(ymm0, ymm3, ymm8 ); - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*11)); + ymm9 = _mm256_fnmadd_ps(ymm0, ymm3, ymm9 ); - ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_4nx4m(a01,b10,cs_b,p_lda,k_iter) + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*12)); + ymm21 = _mm256_fnmadd_ps(ymm0, ymm3, ymm21); - ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*13)); + ymm22 = _mm256_fnmadd_ps(ymm0, ymm3, ymm22); - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*14)); + ymm4 = _mm256_fnmadd_ps(ymm0, ymm3, ymm4 ); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*15)); + ymm5 = _mm256_fnmadd_ps(ymm0, ymm3, ymm5 ); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, 
ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 + //perform mul operation + ymm19 = STRSM_SMALL_DIV_OR_SCALE(ymm19, ymm1); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 + a11 += rs_a; - ///implement TRSM/// + //extract a77 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 7)); - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + //(ROw7): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*7)); + ymm20 = _mm256_fnmadd_ps(ymm0, ymm19, ymm20); - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*8)); + ymm14 = _mm256_fnmadd_ps(ymm0, ymm19, ymm14); - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*9)); + ymm15 = _mm256_fnmadd_ps(ymm0, ymm19, ymm15); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*10)); + ymm8 = _mm256_fnmadd_ps(ymm0, ymm19, ymm8 ); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*11)); + ymm9 = _mm256_fnmadd_ps(ymm0, ymm19, ymm9 ); - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*12)); + ymm21 = _mm256_fnmadd_ps(ymm0, ymm19, ymm21); - a11 += cs_a; + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*13)); + ymm22 = _mm256_fnmadd_ps(ymm0, ymm19, ymm22); - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*14)); + ymm4 = _mm256_fnmadd_ps(ymm0, ymm19, ymm4 ); - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*15)); + ymm5 = _mm256_fnmadd_ps(ymm0, ymm19, ymm5 ); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); + //perform mul operation + ymm20 = STRSM_SMALL_DIV_OR_SCALE(ymm20, ymm1); - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + a11 += rs_a; - a11 += cs_a; + //extract a88 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 8)); - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + //(ROw8): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*8)); + ymm14 = _mm256_fnmadd_ps(ymm0, ymm20, ymm14); - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*9)); + ymm15 = _mm256_fnmadd_ps(ymm0, ymm20, ymm15); - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*10)); + ymm8 = _mm256_fnmadd_ps(ymm0, ymm20, ymm8 ); - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*11)); + ymm9 = _mm256_fnmadd_ps(ymm0, ymm20, ymm9 ); - m_remainder -= 4; - i += 4; - } + ymm0 = 
_mm256_broadcast_ss((float const *)(a11 + cs_a*12)); + ymm21 = _mm256_fnmadd_ps(ymm0, ymm20, ymm21); - if(m_remainder == 3) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*13)); + ymm22 = _mm256_fnmadd_ps(ymm0, ymm20, ymm22); - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*14)); + ymm4 = _mm256_fnmadd_ps(ymm0, ymm20, ymm4 ); - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*15)); + ymm5 = _mm256_fnmadd_ps(ymm0, ymm20, ymm5 ); - ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_4nx4m(a01,b10,cs_b,p_lda,k_iter) + //perform mul operation + ymm14 = STRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); - ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + a11 += rs_a; - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + //extract a99 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 9)); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + //(ROw9): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*9)); + ymm15 = _mm256_fnmadd_ps(ymm0, ymm14, ymm15); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*10)); + ymm8 = _mm256_fnmadd_ps(ymm0, ymm14, ymm8 ); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3 + 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*11)); + ymm9 = _mm256_fnmadd_ps(ymm0, ymm14, ymm9 ); - ///implement TRSM/// + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*12)); + ymm21 = _mm256_fnmadd_ps(ymm0, ymm14, ymm21); - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*13)); + ymm22 = _mm256_fnmadd_ps(ymm0, ymm14, ymm22); - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*14)); + ymm4 = _mm256_fnmadd_ps(ymm0, ymm14, ymm4 ); - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*15)); + ymm5 = _mm256_fnmadd_ps(ymm0, ymm14, ymm5 ); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); + //perform mul operation + ymm15 = STRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); + a11 += rs_a; - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + //extract a10 10 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 10)); - a11 += cs_a; + //(ROw10): FMA operations 
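For reference, the register-level sequence above is a vectorized forward substitution: each solved row of the B11 tile is scaled by its (optionally pre-inverted) diagonal entry from d11_pack and then eliminated from the rows still to be solved with a fused negative multiply-add. The scalar sketch below is illustrative only, not the BLIS kernel: the function name solve_lower_panel and the preinverted flag are assumptions, and d11_pack is taken to hold either the diagonal or its reciprocal, matching what STRSM_SMALL_DIV_OR_SCALE consumes.

#include <stddef.h>

/* Illustrative scalar model of the intrinsic block above; names, the
   'preinverted' flag and the layout arguments are assumptions.          */
static void solve_lower_panel(const float *a11, const float *d11_pack,
                              float *b11, size_t mm, size_t nn,
                              size_t rs_a, size_t cs_a, size_t cs_b,
                              int preinverted)
{
    for (size_t k = 0; k < mm; k++)
    {
        /* scale row k of the B tile by the packed diagonal entry:
           multiply by the stored reciprocal when pre-inversion is
           enabled, divide otherwise                                     */
        for (size_t j = 0; j < nn; j++)
        {
            if (preinverted) b11[k + j*cs_b] *= d11_pack[k];
            else             b11[k + j*cs_b] /= d11_pack[k];
        }

        /* eliminate the solved row from every remaining row; this is
           the broadcast + _mm256_fnmadd_ps chain in the code above      */
        for (size_t i = k + 1; i < mm; i++)
        {
            float a_ik = a11[k*rs_a + i*cs_a];
            for (size_t j = 0; j < nn; j++)
                b11[i + j*cs_b] -= a_ik * b11[k + j*cs_b];
        }
    }
}

Broadcasting one A element and applying it across a whole row at a time is what lets the vectorized version keep a full ymm register of right-hand-side entries per row while walking down the triangle.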
+ ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*10)); + ymm8 = _mm256_fnmadd_ps(ymm0, ymm15, ymm8 ); - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*11)); + ymm9 = _mm256_fnmadd_ps(ymm0, ymm15, ymm9 ); - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*12)); + ymm21 = _mm256_fnmadd_ps(ymm0, ymm15, ymm21); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*13)); + ymm22 = _mm256_fnmadd_ps(ymm0, ymm15, ymm22); - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*14)); + ymm4 = _mm256_fnmadd_ps(ymm0, ymm15, ymm4 ); - a11 += cs_a; + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*15)); + ymm5 = _mm256_fnmadd_ps(ymm0, ymm15, ymm5 ); - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + //perform mul operation + ymm8 = STRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); + a11 += rs_a; - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + //extract a11 11 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 11)); - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x07); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3 + 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x07); + //(ROw11): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*11)); + ymm9 = _mm256_fnmadd_ps(ymm0, ymm8, ymm9 ); - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - xmm5 = _mm256_extractf128_pd(ymm9, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 3),xmm5); - _mm_storel_pd((b11 + cs_b * 3 + 2), _mm256_extractf128_pd(ymm9, 1)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*12)); + ymm21 = _mm256_fnmadd_ps(ymm0, ymm8, ymm21); - m_remainder -= 3; - i += 3; - } - else if(m_remainder == 2) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*13)); + ymm22 = _mm256_fnmadd_ps(ymm0, ymm8, ymm22); - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*14)); + ymm4 = _mm256_fnmadd_ps(ymm0, ymm8, ymm4 ); - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*15)); + ymm5 = _mm256_fnmadd_ps(ymm0, ymm8, ymm5 ); - ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_4nx4m(a01,b10,cs_b,p_lda,k_iter) + //perform mul operation + ymm9 = 
STRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + a11 += rs_a; - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + //extract a12 12 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 12)); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + //(ROw12): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*12)); + ymm21 = _mm256_fnmadd_ps(ymm0, ymm9, ymm21); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*13)); + ymm22 = _mm256_fnmadd_ps(ymm0, ymm9, ymm22); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*14)); + ymm4 = _mm256_fnmadd_ps(ymm0, ymm9, ymm4 ); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*15)); + ymm5 = _mm256_fnmadd_ps(ymm0, ymm9, ymm5 ); + + //perform mul operation + ymm21 = STRSM_SMALL_DIV_OR_SCALE(ymm21, ymm1); + + a11 += rs_a; + + //extract a13 13 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 13)); + + //(ROw13): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*13)); + ymm22 = _mm256_fnmadd_ps(ymm0, ymm21, ymm22); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*14)); + ymm4 = _mm256_fnmadd_ps(ymm0, ymm21, ymm4 ); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*15)); + ymm5 = _mm256_fnmadd_ps(ymm0, ymm21, ymm5 ); + + //perform mul operation + ymm22 = STRSM_SMALL_DIV_OR_SCALE(ymm22, ymm1); + + a11 += rs_a; - ///implement TRSM/// + //extract a14 14 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 14)); - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + //(ROw13): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*14)); + ymm4 = _mm256_fnmadd_ps(ymm0, ymm22, ymm4 ); - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*15)); + ymm5 = _mm256_fnmadd_ps(ymm0, ymm22, ymm5 ); - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); + //perform mul operation + ymm4 = STRSM_SMALL_DIV_OR_SCALE(ymm4, ymm1); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); + a11 += rs_a; - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); + //extract a15 15 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 15)); - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + //(ROw15): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*15)); + ymm5 = _mm256_fnmadd_ps(ymm0, ymm4, ymm5 ); - a11 += cs_a; + //perform mul operation + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm1); - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + a11 += rs_a; - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); - ymm7 = _mm256_fnmadd_pd(ymm1, 
ymm5, ymm7); + if(n_rem == 3) + { + ymm0 = _mm256_unpacklo_ps(ymm10, ymm11); + ymm1 = _mm256_unpacklo_ps(ymm17, ymm18); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); + ymm6 = _mm256_unpacklo_ps(ymm2, ymm3); + ymm7 = _mm256_unpacklo_ps(ymm19, ymm20); - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b01000100); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b01000100); - a11 += cs_a; + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//1 + _mm256_storeu_ps((float *)(b11), ymm16); - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b11101110); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b11101110); - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//2 + _mm256_storeu_ps((float *)(b11 + cs_b), ymm16); - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + ymm0 = _mm256_unpackhi_ps(ymm10, ymm11); + ymm1 = _mm256_unpackhi_ps(ymm17, ymm18); - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x03); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x03); + ymm6 = _mm256_unpackhi_ps(ymm2, ymm3); + ymm7 = _mm256_unpackhi_ps(ymm19, ymm20); - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - xmm5 = _mm256_extractf128_pd(ymm9, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 3),xmm5); + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b01000100); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b01000100); - m_remainder -= 2; - i += 2; - } - else if(m_remainder == 1) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//3 + _mm256_storeu_ps((float *)(b11 + 2*cs_b), ymm16); - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + ymm0 = _mm256_unpacklo_ps(ymm14, ymm15); + ymm1 = _mm256_unpacklo_ps(ymm8, ymm9); - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS + ymm6 = _mm256_unpacklo_ps(ymm21, ymm22); + ymm7 = _mm256_unpacklo_ps(ymm4, ymm5); - ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_4nx4m(a01,b10,cs_b,p_lda,k_iter) + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b01000100); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b01000100); - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//1 + _mm256_storeu_ps((float *)(b11 + 8), ymm16); - ymm0 = _mm256_broadcast_sd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b11101110); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b11101110); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = 
_mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//2 + _mm256_storeu_ps((float *)(b11 + cs_b + 8), ymm16); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 + ymm0 = _mm256_unpackhi_ps(ymm14, ymm15); + ymm1 = _mm256_unpackhi_ps(ymm8, ymm9); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 + ymm6 = _mm256_unpackhi_ps(ymm21, ymm22); + ymm7 = _mm256_unpackhi_ps(ymm4, ymm5); - ///implement TRSM/// + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b01000100); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b01000100); - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//3 + _mm256_storeu_ps((float *)(b11 + 2*cs_b + 8), ymm16); + } + else if(n_rem == 2) + { + ymm0 = _mm256_unpacklo_ps(ymm10, ymm11); + ymm1 = _mm256_unpacklo_ps(ymm17, ymm18); - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + ymm6 = _mm256_unpacklo_ps(ymm2, ymm3); + ymm7 = _mm256_unpacklo_ps(ymm19, ymm20); - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b01000100); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b01000100); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//1 + _mm256_storeu_ps((float *)(b11), ymm16); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b11101110); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b11101110); - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//2 + _mm256_storeu_ps((float *)(b11 + cs_b), ymm16); - a11 += cs_a; + ymm0 = _mm256_unpacklo_ps(ymm14, ymm15); + ymm1 = _mm256_unpacklo_ps(ymm8, ymm9); - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + ymm6 = _mm256_unpacklo_ps(ymm21, ymm22); + ymm7 = _mm256_unpacklo_ps(ymm4, ymm5); - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b01000100); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b01000100); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//1 + _mm256_storeu_ps((float *)(b11 + 8), ymm16); - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b11101110); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b11101110); - a11 += cs_a; + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//2 + _mm256_storeu_ps((float *)(b11 + cs_b + 8), ymm16); + } + else if(n_rem == 1) + { + ymm0 = _mm256_unpacklo_ps(ymm10, ymm11); + ymm1 = _mm256_unpacklo_ps(ymm17, ymm18); - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + ymm6 = _mm256_unpacklo_ps(ymm2, ymm3); + ymm7 = _mm256_unpacklo_ps(ymm19, ymm20); - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); + ymm12 = 
_mm256_shuffle_ps(ymm0,ymm1,0b01000100); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b01000100); - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//1 + _mm256_storeu_ps((float *)(b11), ymm16); - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x01); + ymm0 = _mm256_unpacklo_ps(ymm14, ymm15); + ymm1 = _mm256_unpacklo_ps(ymm8, ymm9); - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm3, 0)); - _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm5, 0)); - _mm_storel_pd((b11 + cs_b * 2), _mm256_extractf128_pd(ymm7, 0)); - _mm_storel_pd((b11 + cs_b * 3), _mm256_extractf128_pd(ymm9, 0)); + ymm6 = _mm256_unpacklo_ps(ymm21, ymm22); + ymm7 = _mm256_unpacklo_ps(ymm4, ymm5); - m_remainder -= 1; - i += 1; + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b01000100); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b01000100); + + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//1 + _mm256_storeu_ps((float *)(b11 + 8), ymm16); + } } - j += 4; - n_remainder -= 4; } - if(n_remainder == 3) + //======================M remainder cases================================ + dim_t m_rem = m-i; + if(m_rem>=8) //implementation for reamainder rows(when 'M' is not a multiple of d_mr) { - a01 = L + j*rs_a; //pointer to block of A to be used in GEMM - a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM - - double *ptr_a10_dup = D_A_pack; - - dim_t p_lda = j; // packed leading dimension - // perform copy of A to packed buffer D_A_pack + a10 = L + (i*cs_a); //pointer to block of A to be used for GEMM + a11 = L + (i*rs_a) + (i*cs_a); + float *ptr_a10_dup = D_A_pack; + dim_t p_lda = 8; // packed leading dimension if(transa) { - for(dim_t x =0;x < p_lda;x+=d_nr) + for(dim_t x =0;x < i;x+=p_lda) { - ymm0 = _mm256_loadu_pd((double const *)(a01)); - ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); - ymm2 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2)); - ymm3 = _mm256_loadu_pd((double const *)(a01 + cs_a * 3)); + ymm0 = _mm256_loadu_ps((float const *)(a10)); + ymm1 = _mm256_loadu_ps((float const *)(a10 + cs_a)); + ymm2 = _mm256_loadu_ps((float const *)(a10 + cs_a*2)); + ymm3 = _mm256_loadu_ps((float const *)(a10 + cs_a*3)); + ymm4 = _mm256_loadu_ps((float const *)(a10 + cs_a*4)); + ymm5 = _mm256_loadu_ps((float const *)(a10 + cs_a*5)); + ymm6 = _mm256_loadu_ps((float const *)(a10 + cs_a*6)); + ymm7 = _mm256_loadu_ps((float const *)(a10 + cs_a*7)); - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm8 = _mm256_unpacklo_ps(ymm0, ymm1); + ymm9 = _mm256_unpacklo_ps(ymm2, ymm3); - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + ymm10 = _mm256_unpacklo_ps(ymm4, ymm5); + ymm11 = _mm256_unpacklo_ps(ymm6, ymm7); - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + ymm12 = _mm256_shuffle_ps(ymm8,ymm9,0b01000100); + ymm13 = _mm256_shuffle_ps(ymm10,ymm11,0b01000100); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + ymm14 = 
_mm256_permute2f128_ps(ymm12,ymm13,0x20);//1 + ymm15 = _mm256_permute2f128_ps(ymm12,ymm13,0x31);//5 - _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + _mm256_storeu_ps((float *)(ptr_a10_dup), ymm14); + _mm256_storeu_ps((float *)(ptr_a10_dup + 4*p_lda), ymm15); - ymm0 = _mm256_loadu_pd((double const *)(a01 + cs_a * 4)); - ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a * 5)); + ymm12 = _mm256_shuffle_ps(ymm8,ymm9,0b11101110); + ymm13 = _mm256_shuffle_ps(ymm10,ymm11,0b11101110); - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_broadcast_sd((double const *)&zero); + ymm14 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//2 + ymm15 = _mm256_permute2f128_ps(ymm12,ymm13,0x31);//6 + _mm256_storeu_ps((float *)(ptr_a10_dup + p_lda), ymm14); + _mm256_storeu_ps((float *)(ptr_a10_dup + 5*p_lda), ymm15); - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + ymm8 = _mm256_unpackhi_ps(ymm0, ymm1); + ymm9 = _mm256_unpackhi_ps(ymm2, ymm3); - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_broadcast_sd((double const *)&zero); + ymm10 = _mm256_unpackhi_ps(ymm4, ymm5); + ymm11 = _mm256_unpackhi_ps(ymm6, ymm7); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + ymm12 = _mm256_shuffle_ps(ymm8,ymm9,0b01000100); + ymm13 = _mm256_shuffle_ps(ymm10,ymm11,0b01000100); - _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_extractf128_pd(ymm6,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_extractf128_pd(ymm7,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm8,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm9,0)); + ymm14 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//3 + ymm15 = _mm256_permute2f128_ps(ymm12,ymm13,0x31);//7 + _mm256_storeu_ps((float *)(ptr_a10_dup + 2*p_lda), ymm14); + _mm256_storeu_ps((float *)(ptr_a10_dup + 6*p_lda), ymm15); - a01 += d_nr*cs_a; - ptr_a10_dup += d_nr; + ymm12 = _mm256_shuffle_ps(ymm8,ymm9,0b11101110); + ymm13 = _mm256_shuffle_ps(ymm10,ymm11,0b11101110); + + ymm14 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//4 + ymm15 = _mm256_permute2f128_ps(ymm12,ymm13,0x31);//8 + _mm256_storeu_ps((float *)(ptr_a10_dup + 3*p_lda), ymm14); + _mm256_storeu_ps((float *)(ptr_a10_dup + 7*p_lda), ymm15); + + a10 += p_lda; + ptr_a10_dup += p_lda*p_lda; } } else { - dim_t loop_count = p_lda/4; - - for(dim_t x =0;x < loop_count;x++) - { - ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 0 + x*4)); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + x*4), ymm15); - ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 1 + x*4)); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + x*4), ymm15); - ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 2 + x*4)); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 2 + x*4), ymm15); - } - - dim_t remainder_loop_count = p_lda - loop_count*4; - - __m128d xmm0; - if(remainder_loop_count != 0) + for(dim_t x =0;x < i;x++) { - xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 0 + loop_count*4)); - _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + loop_count*4), xmm0); - xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 1 + loop_count*4)); - _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + loop_count*4), xmm0); - xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 2 + loop_count*4)); - 
_mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 2 + loop_count*4), xmm0); + ymm0 = _mm256_loadu_ps((float const *)(a10 + rs_a * x)); + _mm256_storeu_ps((float *)(ptr_a10_dup + p_lda * x), ymm0); } } - ymm4 = _mm256_broadcast_sd((double const *)&ones); + ymm13 = ymm14 = _mm256_broadcast_ss((float const *)&ones); if(!is_unitdiag) { if(transa) { //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_sd((double const *)(a11)); - ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a*1 + 1)); - ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a*2 + 2)); + ymm0 = _mm256_broadcast_ss((float const *)(a11)); + ymm1 = _mm256_broadcast_ss((float const *)(a11+cs_a*1 + 1)); + ymm2 = _mm256_broadcast_ss((float const *)(a11+cs_a*2 + 2)); + ymm3 = _mm256_broadcast_ss((float const *)(a11+cs_a*3 + 3)); + ymm4 = _mm256_broadcast_ss((float const *)(a11+cs_a*4 + 4)); + ymm5 = _mm256_broadcast_ss((float const *)(a11+cs_a*5 + 5)); + ymm6 = _mm256_broadcast_ss((float const *)(a11+cs_a*6 + 6)); + ymm7 = _mm256_broadcast_ss((float const *)(a11+cs_a*7 + 7)); } else { //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_sd((double const *)(a11)); - ymm1 = _mm256_broadcast_sd((double const *)(a11+ rs_a*1 + 1)); - ymm2 = _mm256_broadcast_sd((double const *)(a11+ rs_a*2 + 2)); + ymm0 = _mm256_broadcast_ss((float const *)(a11)); + ymm1 = _mm256_broadcast_ss((float const *)(a11+rs_a*1 + 1)); + ymm2 = _mm256_broadcast_ss((float const *)(a11+rs_a*2 + 2)); + ymm3 = _mm256_broadcast_ss((float const *)(a11+rs_a*3 + 3)); + ymm4 = _mm256_broadcast_ss((float const *)(a11+rs_a*4 + 4)); + ymm5 = _mm256_broadcast_ss((float const *)(a11+rs_a*5 + 5)); + ymm6 = _mm256_broadcast_ss((float const *)(a11+rs_a*6 + 6)); + ymm7 = _mm256_broadcast_ss((float const *)(a11+rs_a*7 + 7)); } - ymm3 = _mm256_broadcast_sd((double const *)&ones); - ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm8 = _mm256_unpacklo_ps(ymm0, ymm1); + ymm9 = _mm256_unpacklo_ps(ymm2, ymm3); + ymm10 = _mm256_blend_ps(ymm8, ymm9, 0xCC); + + ymm8 = _mm256_unpacklo_ps(ymm4, ymm5); + ymm9 = _mm256_unpacklo_ps(ymm6, ymm7); + ymm11 = _mm256_blend_ps(ymm8, ymm9, 0xCC); + + ymm12 = _mm256_blend_ps(ymm10, ymm11, 0xF0); + + #ifdef BLIS_DISABLE_TRSM_PREINVERSION + ymm14 = ymm12; + #endif + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + ymm14 = _mm256_div_ps(ymm13, ymm12); + #endif + } + _mm256_storeu_ps((float *)(d11_pack), ymm14); + + for(j = 0; (j+d_nr-1) < n; j += d_nr) //loop along 'N' dimension + { + a10 = D_A_pack; //pointer to block of A to be used for GEMM + a11 = L + (i*rs_a) + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM + b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM + + k_iter = i; //number of times GEMM operation to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM code begins/// + BLIS_STRSM_SMALL_GEMM_8mx6n(a10,b01,cs_b,p_lda,k_iter) + + ymm16 = _mm256_broadcast_ss((float const *)(&AlphaVal)); + + ymm17 = _mm256_loadu_ps((float const *)(b11)); + ymm18 = _mm256_loadu_ps((float const *)(b11 + cs_b)); + ymm19 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); + ymm20 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); + ymm1 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); + + ymm17 = _mm256_fmsub_ps(ymm17, ymm16, ymm8); + ymm18 = _mm256_fmsub_ps(ymm18, ymm16, ymm9); + ymm19 = 
_mm256_fmsub_ps(ymm19, ymm16, ymm10); + ymm20 = _mm256_fmsub_ps(ymm20, ymm16, ymm11); + ymm0 = _mm256_fmsub_ps(ymm0 , ymm16, ymm4); + ymm1 = _mm256_fmsub_ps(ymm1 , ymm16, ymm5); + + ymm8 = _mm256_unpacklo_ps(ymm17, ymm18); + ymm9 = _mm256_unpacklo_ps(ymm19, ymm20); + + ymm16 = _mm256_unpacklo_ps(ymm0, ymm1); + + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b01000100); + ymm5 = _mm256_shuffle_ps(ymm16,ymm16,0b01000100); + + ymm10 = _mm256_permute2f128_ps(ymm4,ymm5,0x20);//1 + ymm2 = _mm256_permute2f128_ps(ymm4,ymm5,0x31);//5 + + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b11101110); + ymm5 = _mm256_shuffle_ps(ymm16,ymm16,0b11101110); + + ymm11 = _mm256_permute2f128_ps(ymm4,ymm5,0x20);//2 + ymm3 = _mm256_permute2f128_ps(ymm4,ymm5,0x31);//6 + + ymm8 = _mm256_unpackhi_ps(ymm17, ymm18); + ymm9 = _mm256_unpackhi_ps(ymm19, ymm20); + + ymm16 = _mm256_unpackhi_ps(ymm0, ymm1); + + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b01000100); + ymm5 = _mm256_shuffle_ps(ymm16,ymm16,0b01000100); + + ymm17 = _mm256_permute2f128_ps(ymm4,ymm5,0x20);//3 + ymm19 = _mm256_permute2f128_ps(ymm4,ymm5,0x31);//7 + + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b11101110); + ymm5 = _mm256_shuffle_ps(ymm16,ymm16,0b11101110); + + ymm18 = _mm256_permute2f128_ps(ymm4,ymm5,0x20);//4 + ymm20 = _mm256_permute2f128_ps(ymm4,ymm5,0x31);//8 + + // TRSM portion + + ////extract a00 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack)); + + //perform mul operation + ymm10 = STRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); + + //(ROw1): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm10, ymm11); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm10, ymm17); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm10, ymm18); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*4)); + ymm2 = _mm256_fnmadd_ps(ymm0, ymm10, ymm2); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*5)); + ymm3 = _mm256_fnmadd_ps(ymm0, ymm10, ymm3); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*6)); + ymm19 = _mm256_fnmadd_ps(ymm0, ymm10, ymm19); - ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); - #ifdef BLIS_DISABLE_TRSM_PREINVERSION - ymm4 = ymm1; - #endif - #ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm4 = _mm256_div_pd(ymm4, ymm1); - #endif - } - _mm256_storeu_pd((double *)(d11_pack), ymm4); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*7)); + ymm20 = _mm256_fnmadd_ps(ymm0, ymm10, ymm20); - for(i = 0; (i+d_mr-1) < m; i += d_mr) //loop along 'M' direction - { - a01 = D_A_pack; - a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + //perform mul operation + ymm11 = STRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + a11 += rs_a; - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS + //extract a22 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); - ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_3nx8m(a01,b10,cs_b,p_lda,k_iter) + //(ROw2): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm11, ymm17); - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*3)); + ymm18 = 
_mm256_fnmadd_ps(ymm0, ymm11, ymm18); - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + 4)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*4)); + ymm2 = _mm256_fnmadd_ps(ymm0, ymm11, ymm2); - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - ymm4 = _mm256_fmsub_pd(ymm1, ymm15, ymm4); //B11[4-7][0] * alpha-= ymm1 + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*5)); + ymm3 = _mm256_fnmadd_ps(ymm0, ymm11, ymm3); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b + 4)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*6)); + ymm19 = _mm256_fnmadd_ps(ymm0, ymm11, ymm19); - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - ymm6 = _mm256_fmsub_pd(ymm1, ymm15, ymm6); //B11[4-7][1] * alpha -= ymm3 + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*7)); + ymm20 = _mm256_fnmadd_ps(ymm0, ymm11, ymm20); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*2 + 4)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] + //perform mul operation + ymm17 = STRSM_SMALL_DIV_OR_SCALE(ymm17, ymm1); - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - ymm8 = _mm256_fmsub_pd(ymm1, ymm15, ymm8); //B11[4-7][2] * alpha -= ymm5 + a11 += rs_a; - ///implement TRSM/// + //extract a33 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 3)); - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + //(ROw5): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm17, ymm18); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*4)); + ymm2 = _mm256_fnmadd_ps(ymm0, ymm17, ymm2); - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*5)); + ymm3 = _mm256_fnmadd_ps(ymm0, ymm17, ymm3); - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*6)); + ymm19 = _mm256_fnmadd_ps(ymm0, ymm17, ymm19); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - ymm6 = _mm256_fnmadd_pd(ymm1, ymm4, ymm6); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*7)); + ymm20 = _mm256_fnmadd_ps(ymm0, ymm17, ymm20); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + //perform mul operation + ymm18 = STRSM_SMALL_DIV_OR_SCALE(ymm18, ymm1); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); - ymm8 = _mm256_fnmadd_pd(ymm1, ymm4, ymm8); + a11 += rs_a; - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm0); + //extract a44 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 4)); - a11 += cs_a; + //(ROw4): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*4)); + ymm2 = _mm256_fnmadd_ps(ymm0, ymm18, ymm2); - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*5)); + ymm3 = _mm256_fnmadd_ps(ymm0, ymm18, ymm3); - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + ymm0 = 
_mm256_broadcast_ss((float const *)(a11 + cs_a*6)); + ymm19 = _mm256_fnmadd_ps(ymm0, ymm18, ymm19); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); - ymm8 = _mm256_fnmadd_pd(ymm1, ymm6, ymm8); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*7)); + ymm20 = _mm256_fnmadd_ps(ymm0, ymm18, ymm20); - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm0); + //perform mul operation + ymm2 = STRSM_SMALL_DIV_OR_SCALE(ymm2, ymm1); - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + 4), ymm4); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b + 4), ymm6); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*2 + 4), ymm8); - } + a11 += rs_a; - dim_t m_remainder = m - i; - if(m_remainder >= 4) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + //extract a55 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 5)); - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + //(ROw5): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*5)); + ymm3 = _mm256_fnmadd_ps(ymm0, ymm2, ymm3); - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*6)); + ymm19 = _mm256_fnmadd_ps(ymm0, ymm2, ymm19); - ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*7)); + ymm20 = _mm256_fnmadd_ps(ymm0, ymm2, ymm20); - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + //perform mul operation + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm1); - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + a11 += rs_a; - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + //extract a66 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 6)); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 + //(ROw6): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*6)); + ymm19 = _mm256_fnmadd_ps(ymm0, ymm3, ymm19); - ///implement TRSM/// - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*7)); + ymm20 = _mm256_fnmadd_ps(ymm0, ymm3, ymm20); - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + //perform mul operation + ymm19 = STRSM_SMALL_DIV_OR_SCALE(ymm19, ymm1); - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); + a11 += rs_a; - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); + //extract a77 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 7)); - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + //(ROw7): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*7)); + ymm20 = _mm256_fnmadd_ps(ymm0, ymm19, ymm20); - 
a11 += cs_a; + //perform mul operation + ymm20 = STRSM_SMALL_DIV_OR_SCALE(ymm20, ymm1); - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + a11 += rs_a; - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); + ymm0 = _mm256_unpacklo_ps(ymm10, ymm11); + ymm1 = _mm256_unpacklo_ps(ymm17, ymm18); - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + ymm6 = _mm256_unpacklo_ps(ymm2, ymm3); + ymm7 = _mm256_unpacklo_ps(ymm19, ymm20); - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b01000100); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b01000100); - m_remainder -= 4; - i += 4; + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//1 + _mm256_storeu_ps((float *)(b11), ymm16); + + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x31);//5 + _mm256_storeu_ps((float *)(b11 + 4*cs_b), ymm16); + + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b11101110); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b11101110); + + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//2 + _mm256_storeu_ps((float *)(b11 + cs_b), ymm16); + + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x31);//6 + _mm256_storeu_ps((float *)(b11 + 5*cs_b), ymm16); + + ymm0 = _mm256_unpackhi_ps(ymm10, ymm11); + ymm1 = _mm256_unpackhi_ps(ymm17, ymm18); + + ymm6 = _mm256_unpackhi_ps(ymm2, ymm3); + ymm7 = _mm256_unpackhi_ps(ymm19, ymm20); + + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b01000100); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b01000100); + + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//3 + _mm256_storeu_ps((float *)(b11 + 2*cs_b), ymm16); + + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b11101110); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b11101110); + + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//4 + _mm256_storeu_ps((float *)(b11 + 3*cs_b), ymm16); } - if(m_remainder == 3) + dim_t n_rem = n-j; + if(n_rem >= 4) { - a01 = D_A_pack; - a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + a10 = D_A_pack; + a11 = L + (i*rs_a) + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + j*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS + BLIS_SET_S_YMM_REG_ZEROS - ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_8mx4n(a10,b01,cs_b,p_lda,k_iter) - BLIS_PRE_DTRSM_SMALL_3N_3M(AlphaVal,b11,cs_b) + ymm16 = _mm256_broadcast_ss((float const *)(&AlphaVal)); //register to hold alpha - ///implement TRSM/// + ymm17 = _mm256_loadu_ps((float const *)(b11)); + ymm18 = _mm256_loadu_ps((float const *)(b11 + cs_b)); + ymm19 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); + ymm20 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm17 = _mm256_fmsub_ps(ymm17, ymm16, ymm8); + ymm18 = _mm256_fmsub_ps(ymm18, ymm16, ymm9); + ymm19 = _mm256_fmsub_ps(ymm19, ymm16, ymm10); + ymm20 = _mm256_fmsub_ps(ymm20, ymm16, ymm11); - //extract a11 - 
ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + ymm8 = _mm256_unpacklo_ps(ymm17, ymm18); + ymm9 = _mm256_unpacklo_ps(ymm19, ymm20); - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b01000100); + ymm5 = _mm256_broadcast_ss((float const *)(&zero)); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); + ymm10 = _mm256_permute2f128_ps(ymm4,ymm5,0x20);//1 + ymm2 = _mm256_permute2f128_ps(ymm4,ymm5,0x31);//5 - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b11101110); - a11 += cs_a; + ymm11 = _mm256_permute2f128_ps(ymm4,ymm5,0x20);//2 + ymm3 = _mm256_permute2f128_ps(ymm4,ymm5,0x31);//6 - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + ymm8 = _mm256_unpackhi_ps(ymm17, ymm18); + ymm9 = _mm256_unpackhi_ps(ymm19, ymm20); - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b01000100); - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + ymm17 = _mm256_permute2f128_ps(ymm4,ymm5,0x20);//3 + ymm19 = _mm256_permute2f128_ps(ymm4,ymm5,0x31);//7 - BLIS_POST_DTRSM_SMALL_3N_3M(b11,cs_b) + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b11101110); - m_remainder -= 3; - i += 3; - } - else if(m_remainder == 2) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + ymm18 = _mm256_permute2f128_ps(ymm4,ymm5,0x20);//4 + ymm20 = _mm256_permute2f128_ps(ymm4,ymm5,0x31);//8 - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + // TRSM portion - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS + ////extract a00 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack)); - ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) + //perform mul operation + ymm10 = STRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); - BLIS_PRE_DTRSM_SMALL_3N_2M(AlphaVal,b11,cs_b) + //extract a11 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); - ///implement TRSM/// + //(ROw1): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm10, ymm11); - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm10, ymm17); - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm10, ymm18); - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*4)); + ymm2 = _mm256_fnmadd_ps(ymm0, ymm10, ymm2); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*5)); + ymm3 = _mm256_fnmadd_ps(ymm0, ymm10, ymm3); - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*6)); + ymm19 = _mm256_fnmadd_ps(ymm0, ymm10, ymm19); - a11 += cs_a; + ymm0 = 
_mm256_broadcast_ss((float const *)(a11 + cs_a*7)); + ymm20 = _mm256_fnmadd_ps(ymm0, ymm10, ymm20); + + //perform mul operation + ymm11 = STRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + + a11 += rs_a; //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); + //(ROw2): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm11, ymm17); - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm11, ymm18); - BLIS_POST_DTRSM_SMALL_3N_2M(b11,cs_b) + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*4)); + ymm2 = _mm256_fnmadd_ps(ymm0, ymm11, ymm2); - m_remainder -= 2; - i += 2; - } - else if(m_remainder == 1) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*5)); + ymm3 = _mm256_fnmadd_ps(ymm0, ymm11, ymm3); - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*6)); + ymm19 = _mm256_fnmadd_ps(ymm0, ymm11, ymm19); - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*7)); + ymm20 = _mm256_fnmadd_ps(ymm0, ymm11, ymm20); - ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) + //perform mul operation + ymm17 = STRSM_SMALL_DIV_OR_SCALE(ymm17, ymm1); + + a11 += rs_a; + + //extract a33 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 3)); + + //(ROw5): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm17, ymm18); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*4)); + ymm2 = _mm256_fnmadd_ps(ymm0, ymm17, ymm2); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*5)); + ymm3 = _mm256_fnmadd_ps(ymm0, ymm17, ymm3); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*6)); + ymm19 = _mm256_fnmadd_ps(ymm0, ymm17, ymm19); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*7)); + ymm20 = _mm256_fnmadd_ps(ymm0, ymm17, ymm20); + + //perform mul operation + ymm18 = STRSM_SMALL_DIV_OR_SCALE(ymm18, ymm1); + + a11 += rs_a; + + //extract a44 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 4)); - BLIS_PRE_DTRSM_SMALL_3N_1M(AlphaVal,b11,cs_b) + //(ROw4): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*4)); + ymm2 = _mm256_fnmadd_ps(ymm0, ymm18, ymm2); - ///implement TRSM/// + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*5)); + ymm3 = _mm256_fnmadd_ps(ymm0, ymm18, ymm3); - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*6)); + ymm19 = _mm256_fnmadd_ps(ymm0, ymm18, ymm19); - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*7)); + ymm20 = _mm256_fnmadd_ps(ymm0, ymm18, ymm20); - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); + 
//perform mul operation + ymm2 = STRSM_SMALL_DIV_OR_SCALE(ymm2, ymm1); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); + a11 += rs_a; - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + //extract a55 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 5)); - a11 += cs_a; + //(ROw5): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*5)); + ymm3 = _mm256_fnmadd_ps(ymm0, ymm2, ymm3); - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*6)); + ymm19 = _mm256_fnmadd_ps(ymm0, ymm2, ymm19); - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*7)); + ymm20 = _mm256_fnmadd_ps(ymm0, ymm2, ymm20); - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + //perform mul operation + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm1); - BLIS_POST_DTRSM_SMALL_3N_1M(b11,cs_b) + a11 += rs_a; - m_remainder -= 1; - i += 1; - } - j += 3; - n_remainder -= 3; - } - else if(n_remainder == 2) - { - a01 = L + j*rs_a; //pointer to block of A to be used in GEMM - a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + //extract a66 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 6)); - double *ptr_a10_dup = D_A_pack; + //(ROw6): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*6)); + ymm19 = _mm256_fnmadd_ps(ymm0, ymm3, ymm19); - dim_t p_lda = j; // packed leading dimension - // perform copy of A to packed buffer D_A_pack + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*7)); + ymm20 = _mm256_fnmadd_ps(ymm0, ymm3, ymm20); - if(transa) - { - for(dim_t x =0;x < p_lda;x+=d_nr) - { - ymm0 = _mm256_loadu_pd((double const *)(a01)); - ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); - ymm2 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2)); - ymm3 = _mm256_loadu_pd((double const *)(a01 + cs_a * 3)); + //perform mul operation + ymm19 = STRSM_SMALL_DIV_OR_SCALE(ymm19, ymm1); - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + a11 += rs_a; - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + //extract a77 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 7)); - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + //(ROw7): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*7)); + ymm20 = _mm256_fnmadd_ps(ymm0, ymm19, ymm20); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + //perform mul operation + ymm20 = STRSM_SMALL_DIV_OR_SCALE(ymm20, ymm1); - _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + a11 += rs_a; - ymm0 = _mm256_loadu_pd((double const *)(a01 + cs_a * 4)); - ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a * 5)); + ymm0 = _mm256_unpacklo_ps(ymm10, ymm11); + ymm1 = _mm256_unpacklo_ps(ymm17, ymm18); - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_broadcast_sd((double const *)&zero); + ymm6 = _mm256_unpacklo_ps(ymm2, ymm3); + ymm7 = _mm256_unpacklo_ps(ymm19, ymm20); - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b01000100); 
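+          // Note (explanatory comment): the unpack/shuffle/permute2f128 sequence here
+          // is the inverse of the load-side transpose. Each ymm register currently holds
+          // one solved row of the 8x4 B11 block; recombining the 128-bit lanes restores
+          // column-major order so the results can be stored back to b11 one column
+          // (cs_b stride) at a time.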
+ ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b01000100); - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_broadcast_sd((double const *)&zero); + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//1 + _mm256_storeu_ps((float *)(b11), ymm16); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b11101110); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b11101110); - _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_extractf128_pd(ymm6,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_extractf128_pd(ymm7,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm8,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm9,0)); + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//2 + _mm256_storeu_ps((float *)(b11 + cs_b), ymm16); - a01 += d_nr*cs_a; - ptr_a10_dup += d_nr; - } - } - else - { - dim_t loop_count = p_lda/4; + ymm0 = _mm256_unpackhi_ps(ymm10, ymm11); + ymm1 = _mm256_unpackhi_ps(ymm17, ymm18); - for(dim_t x =0;x < loop_count;x++) - { - ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 0 + x*4)); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + x*4), ymm15); - ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 1 + x*4)); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + x*4), ymm15); - } + ymm6 = _mm256_unpackhi_ps(ymm2, ymm3); + ymm7 = _mm256_unpackhi_ps(ymm19, ymm20); - dim_t remainder_loop_count = p_lda - loop_count*4; + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b01000100); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b01000100); - __m128d xmm0; - if(remainder_loop_count != 0) - { - xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 0 + loop_count*4)); - _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + loop_count*4), xmm0); - xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 1 + loop_count*4)); - _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + loop_count*4), xmm0); - } - } + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//3 + _mm256_storeu_ps((float *)(b11 + 2*cs_b), ymm16); - ymm4 = _mm256_broadcast_sd((double const *)&ones); - if(!is_unitdiag) - { - if(transa) - { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_sd((double const *)(a11)); - ymm1 = _mm256_broadcast_sd((double const *)(a11+cs_a*1 + 1)); - } - else - { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_sd((double const *)(a11)); - ymm1 = _mm256_broadcast_sd((double const *)(a11+rs_a*1 + 1)); - } - ymm2 = _mm256_broadcast_sd((double const *)&ones); - ymm3 = _mm256_broadcast_sd((double const *)&ones); + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b11101110); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b11101110); - ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//4 + _mm256_storeu_ps((float *)(b11 + 3*cs_b), ymm16); - ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); - #ifdef BLIS_DISABLE_TRSM_PREINVERSION - ymm4 = ymm1; - #endif - #ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm4 = _mm256_div_pd(ymm4, ymm1); - #endif + n_rem -= 4; + j += 4; } - _mm256_storeu_pd((double *)(d11_pack), ymm4); - - for(i = 0; (i+d_mr-1) < m; i += d_mr) //loop along 'M' direction + if(n_rem) { - a01 = D_A_pack; - a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + a10 = D_A_pack; + a11 = L + (i*rs_a) + (i*cs_a); //pointer to block of A to be used 
for TRSM + b01 = B + j*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS + ymm8 = _mm256_setzero_ps(); + ymm9 = _mm256_setzero_ps(); + ymm10 = _mm256_setzero_ps(); - ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_2nx8m(a01,b10,cs_b,p_lda,k_iter) + if(3 == n_rem) + { + ///GEMM code begins/// + BLIS_STRSM_SMALL_GEMM_8mx3n(a10,b01,cs_b,p_lda,k_iter) - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); + ymm16 = _mm256_broadcast_ss((float const *)(&AlphaVal)); //register to hold alpha - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + 4)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] + ymm17 = _mm256_loadu_ps((float const *)(b11)); + ymm18 = _mm256_loadu_ps((float const *)(b11 + cs_b)); + ymm19 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); + ymm20 = _mm256_broadcast_ss((float const *)(&zero)); - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - ymm4 = _mm256_fmsub_pd(ymm1, ymm15, ymm4); //B11[4-7][0] * alpha-= ymm1 + ymm17 = _mm256_fmsub_ps(ymm17, ymm16, ymm8); + ymm18 = _mm256_fmsub_ps(ymm18, ymm16, ymm9); + ymm19 = _mm256_fmsub_ps(ymm19, ymm16, ymm10); + } + else if(2 == n_rem) + { + ///GEMM code begins/// + BLIS_STRSM_SMALL_GEMM_8mx2n(a10,b01,cs_b,p_lda,k_iter) - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b + 4)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] + ymm16 = _mm256_broadcast_ss((float const *)(&AlphaVal)); //register to hold alpha - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - ymm6 = _mm256_fmsub_pd(ymm1, ymm15, ymm6); //B11[4-7][1] * alpha -= ymm3 + ymm17 = _mm256_loadu_ps((float const *)(b11)); + ymm18 = _mm256_loadu_ps((float const *)(b11 + cs_b)); + ymm19 = _mm256_broadcast_ss((float const *)(&zero)); + ymm20 = _mm256_broadcast_ss((float const *)(&zero)); - ///implement TRSM/// + ymm17 = _mm256_fmsub_ps(ymm17, ymm16, ymm8); + ymm18 = _mm256_fmsub_ps(ymm18, ymm16, ymm9); + } + else if(1 == n_rem) + { + ///GEMM code begins/// + BLIS_STRSM_SMALL_GEMM_8mx1n(a10,b01,cs_b,p_lda,k_iter) - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm16 = _mm256_broadcast_ss((float const *)(&AlphaVal)); //register to hold alpha - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); + ymm17 = _mm256_loadu_ps((float const *)(b11)); + ymm18 = _mm256_broadcast_ss((float const *)(&zero)); + ymm19 = _mm256_broadcast_ss((float const *)(&zero)); + ymm20 = _mm256_broadcast_ss((float const *)(&zero)); - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + ymm17 = _mm256_fmsub_ps(ymm17, ymm16, ymm8); + } - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); + ymm8 = _mm256_unpacklo_ps(ymm17, ymm18); + ymm9 = _mm256_unpacklo_ps(ymm19, ymm20); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - ymm6 = _mm256_fnmadd_pd(ymm1, ymm4, ymm6); + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b01000100); + ymm5 = _mm256_broadcast_ss((float const *)(&zero)); - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm0); + ymm10 = 
_mm256_permute2f128_ps(ymm4,ymm5,0x20);//1 + ymm2 = _mm256_permute2f128_ps(ymm4,ymm5,0x31);//5 - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + 4), ymm4); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b + 4), ymm6); - } + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b11101110); - dim_t m_remainder = m - i; - if(m_remainder >= 4) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + ymm11 = _mm256_permute2f128_ps(ymm4,ymm5,0x20);//2 + ymm3 = _mm256_permute2f128_ps(ymm4,ymm5,0x31);//6 - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + ymm8 = _mm256_unpackhi_ps(ymm17, ymm18); + ymm9 = _mm256_unpackhi_ps(ymm19, ymm20); - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b01000100); - ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) + ymm17 = _mm256_permute2f128_ps(ymm4,ymm5,0x20);//3 + ymm19 = _mm256_permute2f128_ps(ymm4,ymm5,0x31);//7 - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b11101110); - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + ymm18 = _mm256_permute2f128_ps(ymm4,ymm5,0x20);//4 + ymm20 = _mm256_permute2f128_ps(ymm4,ymm5,0x31);//8 - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + // TRSM portion - ///implement TRSM/// + ////extract a00 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack)); - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + //perform mul operation + ymm10 = STRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - - m_remainder -= 4; - i += 4; - } + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); - if(m_remainder == 3) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + //(ROw1): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm10, ymm11); - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm10, ymm17); - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm10, ymm18); - ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*4)); + ymm2 = _mm256_fnmadd_ps(ymm0, ymm10, ymm2); - BLIS_PRE_DTRSM_SMALL_2N_3M(AlphaVal,b11,cs_b) + ymm0 = _mm256_broadcast_ss((float const *)(a11 + 
cs_a*5)); + ymm3 = _mm256_fnmadd_ps(ymm0, ymm10, ymm3); - ///implement TRSM/// + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*6)); + ymm19 = _mm256_fnmadd_ps(ymm0, ymm10, ymm19); - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*7)); + ymm20 = _mm256_fnmadd_ps(ymm0, ymm10, ymm20); - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + //perform mul operation + ymm11 = STRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); + a11 += rs_a; - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + //extract a22 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); - BLIS_POST_DTRSM_SMALL_2N_3M(b11,cs_b) + //(ROw2): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm11, ymm17); - m_remainder -= 3; - i += 3; - } - else if(m_remainder == 2) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm11, ymm18); - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*4)); + ymm2 = _mm256_fnmadd_ps(ymm0, ymm11, ymm2); - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*5)); + ymm3 = _mm256_fnmadd_ps(ymm0, ymm11, ymm3); - ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*6)); + ymm19 = _mm256_fnmadd_ps(ymm0, ymm11, ymm19); - BLIS_PRE_DTRSM_SMALL_2N_2M(AlphaVal,b11,cs_b) + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*7)); + ymm20 = _mm256_fnmadd_ps(ymm0, ymm11, ymm20); - ///implement TRSM/// - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + //perform mul operation + ymm17 = STRSM_SMALL_DIV_OR_SCALE(ymm17, ymm1); - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + a11 += rs_a; - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); + //extract a33 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 3)); - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + //(ROw5): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm17, ymm18); - BLIS_POST_DTRSM_SMALL_2N_2M(b11,cs_b) + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*4)); + ymm2 = _mm256_fnmadd_ps(ymm0, ymm17, ymm2); - m_remainder -= 2; - i += 2; - } - else if(m_remainder == 1) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*5)); + ymm3 = _mm256_fnmadd_ps(ymm0, ymm17, ymm3); - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*6)); + ymm19 = _mm256_fnmadd_ps(ymm0, ymm17, ymm19); - ymm3 = 
_mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*7)); + ymm20 = _mm256_fnmadd_ps(ymm0, ymm17, ymm20); - ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) + //perform mul operation + ymm18 = STRSM_SMALL_DIV_OR_SCALE(ymm18, ymm1); - BLIS_PRE_DTRSM_SMALL_2N_1M(AlphaVal,b11,cs_b) + a11 += rs_a; - ///implement TRSM/// + //extract a44 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 4)); - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + //(ROw4): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*4)); + ymm2 = _mm256_fnmadd_ps(ymm0, ymm18, ymm2); - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*5)); + ymm3 = _mm256_fnmadd_ps(ymm0, ymm18, ymm3); - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*6)); + ymm19 = _mm256_fnmadd_ps(ymm0, ymm18, ymm19); - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*7)); + ymm20 = _mm256_fnmadd_ps(ymm0, ymm18, ymm20); - BLIS_POST_DTRSM_SMALL_2N_1M(b11,cs_b) + //perform mul operation + ymm2 = STRSM_SMALL_DIV_OR_SCALE(ymm2, ymm1); - m_remainder -= 1; - i += 1; - } - j += 2; - n_remainder -= 2; - } - else if(n_remainder == 1) - { - a01 = L + j*rs_a; //pointer to block of A to be used in GEMM - a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + a11 += rs_a; - double *ptr_a10_dup = D_A_pack; + //extract a55 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 5)); - dim_t p_lda = j; // packed leading dimension - // perform copy of A to packed buffer D_A_pack + //(ROw5): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*5)); + ymm3 = _mm256_fnmadd_ps(ymm0, ymm2, ymm3); - if(transa) - { - for(dim_t x =0;x < p_lda;x+=d_nr) - { - ymm0 = _mm256_loadu_pd((double const *)(a01)); - ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); - ymm2 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2)); - ymm3 = _mm256_loadu_pd((double const *)(a01 + cs_a * 3)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*6)); + ymm19 = _mm256_fnmadd_ps(ymm0, ymm2, ymm19); - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*7)); + ymm20 = _mm256_fnmadd_ps(ymm0, ymm2, ymm20); - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + //perform mul operation + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm1); - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + a11 += rs_a; - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + //extract a66 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 6)); - _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + //(ROw6): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*6)); + ymm19 = _mm256_fnmadd_ps(ymm0, ymm3, ymm19); - ymm0 = _mm256_loadu_pd((double const *)(a01 + cs_a * 4)); - ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a * 5)); + ymm0 = 
_mm256_broadcast_ss((float const *)(a11 + cs_a*7)); + ymm20 = _mm256_fnmadd_ps(ymm0, ymm3, ymm20); - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_broadcast_sd((double const *)&zero); + //perform mul operation + ymm19 = STRSM_SMALL_DIV_OR_SCALE(ymm19, ymm1); - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + a11 += rs_a; - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_broadcast_sd((double const *)&zero); + //extract a77 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 7)); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + //(ROw7): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*7)); + ymm20 = _mm256_fnmadd_ps(ymm0, ymm19, ymm20); - _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_extractf128_pd(ymm6,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_extractf128_pd(ymm7,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm8,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm9,0)); + //perform mul operation + ymm20 = STRSM_SMALL_DIV_OR_SCALE(ymm20, ymm1); - a01 += d_nr*cs_a; - ptr_a10_dup += d_nr; - } - } - else - { - dim_t loop_count = p_lda/4; + a11 += rs_a; - for(dim_t x =0;x < loop_count;x++) + if(3 == n_rem) { - ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 0 + x*4)); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + x*4), ymm15); - } + ymm0 = _mm256_unpacklo_ps(ymm10, ymm11); + ymm1 = _mm256_unpacklo_ps(ymm17, ymm18); - dim_t remainder_loop_count = p_lda - loop_count*4; + ymm6 = _mm256_unpacklo_ps(ymm2, ymm3); + ymm7 = _mm256_unpacklo_ps(ymm19, ymm20); - __m128d xmm0; - if(remainder_loop_count != 0) - { - xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 0 + loop_count*4)); - _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + loop_count*4), xmm0); - } - } + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b01000100); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b01000100); - ymm4 = _mm256_broadcast_sd((double const *)&ones); - if(!is_unitdiag) - { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_sd((double const *)(a11)); - ymm1 = _mm256_broadcast_sd((double const *)&ones); - ymm2 = _mm256_broadcast_sd((double const *)&ones); - ymm3 = _mm256_broadcast_sd((double const *)&ones); + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//1 + _mm256_storeu_ps((float *)(b11), ymm16); - ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b11101110); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b11101110); - ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); - #ifdef BLIS_DISABLE_TRSM_PREINVERSION - ymm4 = ymm1; - #endif - #ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm4 = _mm256_div_pd(ymm4, ymm1); - #endif - } - _mm256_storeu_pd((double *)(d11_pack), ymm4); + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//2 + _mm256_storeu_ps((float *)(b11 + cs_b), ymm16); - for(i = 0; (i+d_mr-1) < m; i += d_mr) //loop along 'M' direction - { - a01 = D_A_pack; - a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + ymm0 = _mm256_unpackhi_ps(ymm10, ymm11); + ymm1 = _mm256_unpackhi_ps(ymm17, ymm18); - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + ymm6 = _mm256_unpackhi_ps(ymm2, ymm3); + ymm7 = _mm256_unpackhi_ps(ymm19, ymm20); - ymm3 = _mm256_setzero_pd(); - 
ymm4 = _mm256_setzero_pd(); - ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter) + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b01000100); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b01000100); - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//3 + _mm256_storeu_ps((float *)(b11 + 2*cs_b), ymm16); + } + else if(2 == n_rem) + { + ymm0 = _mm256_unpacklo_ps(ymm10, ymm11); + ymm1 = _mm256_unpacklo_ps(ymm17, ymm18); - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + 4)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] + ymm6 = _mm256_unpacklo_ps(ymm2, ymm3); + ymm7 = _mm256_unpacklo_ps(ymm19, ymm20); - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - ymm4 = _mm256_fmsub_pd(ymm1, ymm15, ymm4); //B11[4-7][0] * alpha-= ymm1 + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b01000100); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b01000100); - ///implement TRSM/// + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//1 + _mm256_storeu_ps((float *)(b11), ymm16); - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b11101110); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b11101110); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//2 + _mm256_storeu_ps((float *)(b11 + cs_b), ymm16); + } + else if(1 == n_rem) + { + ymm0 = _mm256_unpacklo_ps(ymm10, ymm11); + ymm1 = _mm256_unpacklo_ps(ymm17, ymm18); - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + 4), ymm4); - } + ymm6 = _mm256_unpacklo_ps(ymm2, ymm3); + ymm7 = _mm256_unpacklo_ps(ymm19, ymm20); - dim_t m_remainder = m - i; - if(m_remainder >= 4) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b01000100); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b01000100); - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//1 + _mm256_storeu_ps((float *)(b11), ymm16); + } + } + m_rem -=8; + i +=8; + } - ymm3 = _mm256_setzero_pd(); - ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) + if(m_rem>=4) //implementation for reamainder rows(when 'M' is not a multiple of d_mr) + { + a10 = L + (i*cs_a); //pointer to block of A to be used for GEMM + a11 = L + (i*rs_a) + (i*cs_a); + float *ptr_a10_dup = D_A_pack; + dim_t p_lda = 4; // packed leading dimension - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + __m128 xmm0,xmm1,xmm2,xmm3; + __m128 xmm4,xmm5; + __m128 xmm6,xmm7,xmm8,xmm9; - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + if(transa) + { + for(dim_t x =0;x < i;x+=p_lda) + { + xmm0 = _mm_loadu_ps((float const *)(a10)); + xmm1 = _mm_loadu_ps((float const *)(a10 + cs_a)); + xmm2 = _mm_loadu_ps((float const *)(a10 + cs_a * 2)); + xmm3 = _mm_loadu_ps((float const *)(a10 + cs_a * 3)); + + xmm4 = _mm_unpacklo_ps(xmm0, xmm1); + xmm5 = _mm_unpacklo_ps(xmm2, xmm3); + xmm6 = _mm_shuffle_ps(xmm4,xmm5,0x44); + xmm7 = _mm_shuffle_ps(xmm4,xmm5,0xEE); + + xmm0 = 
_mm_unpackhi_ps(xmm0, xmm1); + xmm1 = _mm_unpackhi_ps(xmm2, xmm3); + xmm8 = _mm_shuffle_ps(xmm0,xmm1,0x44); + xmm9 = _mm_shuffle_ps(xmm0,xmm1,0xEE); + + _mm_storeu_ps((float *)(ptr_a10_dup), xmm6); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda), xmm7); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda*2), xmm8); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda*3), xmm9); - ///implement TRSM/// + a10 += p_lda; + ptr_a10_dup += p_lda*p_lda; + } + } + else + { + for(dim_t x =0;x < i;x++) + { + xmm4 = _mm_loadu_ps((float const *)(a10 + rs_a * x)); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda * x), xmm4); + } + } - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + xmm5 = xmm4 = _mm_broadcast_ss((float const *)&ones); + if(!is_unitdiag) + { + if(transa) + { + //broadcast diagonal elements of A11 + xmm0 = _mm_broadcast_ss((float const *)(a11)); + xmm1 = _mm_broadcast_ss((float const *)(a11+cs_a*1 + 1)); + xmm2 = _mm_broadcast_ss((float const *)(a11+cs_a*2 + 2)); + xmm3 = _mm_broadcast_ss((float const *)(a11+cs_a*3 + 3)); + } + else + { + //broadcast diagonal elements of A11 + xmm0 = _mm_broadcast_ss((float const *)(a11)); + xmm1 = _mm_broadcast_ss((float const *)(a11+rs_a*1 + 1)); + xmm2 = _mm_broadcast_ss((float const *)(a11+rs_a*2 + 2)); + xmm3 = _mm_broadcast_ss((float const *)(a11+rs_a*3 + 3)); + } - _mm256_storeu_pd((double *)b11, ymm3); + xmm0 = _mm_unpacklo_ps(xmm0, xmm1); + xmm1 = _mm_unpacklo_ps(xmm2, xmm3); + xmm2 = _mm_blend_ps(xmm0, xmm1, 0x0C); - m_remainder -= 4; - i += 4; + #ifdef BLIS_DISABLE_TRSM_PREINVERSION + xmm4 = xmm2; + #endif + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + xmm4 = _mm_div_ps(xmm5, xmm2); + #endif } + _mm_storeu_ps((float *)(d11_pack), xmm4); + + for(j = 0; (j+d_nr-1) < n; j += d_nr) //loop along 'N' dimension + { + a10 = D_A_pack; //pointer to block of A to be used for GEMM + a11 = L + (i*rs_a) + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM + b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM - if(m_remainder == 3) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + k_iter = i; //number of times GEMM operation to be done(in blocks of 4x4) - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS - ymm3 = _mm256_setzero_pd(); + ///GEMM code begins/// + BLIS_STRSM_SMALL_GEMM_4mx6n(a10,b01,cs_b,p_lda,k_iter) - ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) + ymm16 = _mm256_broadcast_ss((float const *)(&AlphaVal)); - BLIS_PRE_DTRSM_SMALL_1N_3M(AlphaVal,b11,cs_b) + ymm17 = _mm256_insertf128_ps(ymm1, _mm_loadu_ps((float const*)(b11)), 0); + ymm18 = _mm256_insertf128_ps(ymm1, _mm_loadu_ps((float const*)(b11 + cs_b)), 0); + ymm19 = _mm256_insertf128_ps(ymm1, _mm_loadu_ps((float const*)(b11 + cs_b*2)), 0); + ymm20 = _mm256_insertf128_ps(ymm1, _mm_loadu_ps((float const*)(b11 + cs_b*3)), 0); + ymm0 = _mm256_insertf128_ps(ymm1, _mm_loadu_ps((float const*)(b11 + cs_b*4)), 0); + ymm1 = _mm256_insertf128_ps(ymm1, _mm_loadu_ps((float const*)(b11 + cs_b*5)), 0); - ///implement TRSM/// + ymm17 = _mm256_fmsub_ps(ymm17, ymm16, ymm8); + ymm18 = _mm256_fmsub_ps(ymm18, ymm16, ymm9); + ymm19 = _mm256_fmsub_ps(ymm19, 
ymm16, ymm10); + ymm20 = _mm256_fmsub_ps(ymm20, ymm16, ymm11); + ymm0 = _mm256_fmsub_ps(ymm0 , ymm16, ymm4); + ymm1 = _mm256_fmsub_ps(ymm1 , ymm16, ymm5); - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm8 = _mm256_unpacklo_ps(ymm17, ymm18); + ymm9 = _mm256_unpacklo_ps(ymm19, ymm20); - BLIS_POST_DTRSM_SMALL_1N_3M(b11,cs_b) + ymm16 = _mm256_unpacklo_ps(ymm0, ymm1); - m_remainder -= 3; - i += 3; - } - else if(m_remainder == 2) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b01000100); + ymm5 = _mm256_shuffle_ps(ymm16,ymm16,0b01000100); - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + ymm10 = _mm256_permute2f128_ps(ymm4,ymm5,0x20);//1 - ymm3 = _mm256_setzero_pd(); + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b11101110); + ymm5 = _mm256_shuffle_ps(ymm16,ymm16,0b11101110); - ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) + ymm11 = _mm256_permute2f128_ps(ymm4,ymm5,0x20);//2 - BLIS_PRE_DTRSM_SMALL_1N_2M(AlphaVal,b11,cs_b) + ymm8 = _mm256_unpackhi_ps(ymm17, ymm18); + ymm9 = _mm256_unpackhi_ps(ymm19, ymm20); - ///implement TRSM/// + ymm16 = _mm256_unpackhi_ps(ymm0, ymm1); - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b01000100); + ymm5 = _mm256_shuffle_ps(ymm16,ymm16,0b01000100); - BLIS_POST_DTRSM_SMALL_1N_2M(b11,cs_b) + ymm17 = _mm256_permute2f128_ps(ymm4,ymm5,0x20);//3 - m_remainder -= 2; - i += 2; - } - else if(m_remainder == 1) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b11101110); + ymm5 = _mm256_shuffle_ps(ymm16,ymm16,0b11101110); - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) + ymm18 = _mm256_permute2f128_ps(ymm4,ymm5,0x20);//4 - ymm3 = _mm256_setzero_pd(); + // TRSM portion - ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) + ////extract a00 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack)); - BLIS_PRE_DTRSM_SMALL_1N_1M(AlphaVal,b11,cs_b) + //perform mul operation + ymm10 = STRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); - ///implement TRSM/// + //extract a11 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + //(ROw1): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm10, ymm11); - BLIS_POST_DTRSM_SMALL_1N_1M(b11,cs_b) + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm10, ymm17); - m_remainder -= 1; - i += 1; - } - j += 1; - n_remainder -= 1; - } + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm10, ymm18); - if ((required_packing_A == 1) && bli_mem_is_alloc( &local_mem_buf_A_s )) - { - bli_membrk_release(&rntm, - &local_mem_buf_A_s); - } + //perform mul operation + ymm11 = STRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - return BLIS_SUCCESS; -} + a11 += rs_a; -/*implements TRSM for the 
case XA = alpha * B - *A is upper triangular, non-unit diagonal/unit diagonal, transpose - *dimensions: X:mxn A:nxn B: mxn - * - * <---b11 <---a11 - ***************** * - *b01*b11* * * * * - ^ * * * * * ^ * * - | ***************** | ******* - | * * * * * | * * * - | * * * * * a01* * * -b10 ***************** ************* - * * * * * * * * * - * * * * * * * * * - ***************** ******************* + //extract a22 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); - *implements TRSM for the case XA = alpha * B - *A is lower triangular, non-unit diagonal/unit diagonal, no transpose - *dimensions: X:mxn A:nxn B: mxn - * - * <---b11 <---a11 - ***************** * - *b01*b11* * * * * - ^ * * * * * ^ * * - | ***************** | ******* - | * * * * * | * * * - | * * * * * a01* * * -b10 ***************** ************* - * * * * * * * * * - * * * * * * * * * - ***************** ******************* + //(ROw2): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm11, ymm17); -*/ -BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB -( - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl -) -{ - dim_t m = bli_obj_length(b); //number of rows - dim_t n = bli_obj_width(b); //number of columns + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm11, ymm18); - bool transa = bli_obj_has_trans(a); - dim_t cs_a, rs_a; - dim_t d_mr = 8,d_nr = 6; + //perform mul operation + ymm17 = STRSM_SMALL_DIV_OR_SCALE(ymm17, ymm1); - // Swap rs_a & cs_a in case of non-tranpose. - if(transa) - { - cs_a = bli_obj_col_stride(a); // column stride of A - rs_a = bli_obj_row_stride(a); // row stride of A - } - else - { - cs_a = bli_obj_row_stride(a); // row stride of A - rs_a = bli_obj_col_stride(a); // column stride of A - } - dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B + a11 += rs_a; - dim_t i, j, k; //loop variablse - dim_t k_iter; //determines the number of GEMM operations to be done + //extract a33 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 3)); - double ones = 1.0; - double zero = 0.0; - bool is_unitdiag = bli_obj_has_unit_diag(a); + //(ROw5): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm17, ymm18); - double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha - double* restrict L = a->buffer; //pointer to matrix A - double* restrict B = b->buffer; //pointer to matrix B + //perform mul operation + ymm18 = STRSM_SMALL_DIV_OR_SCALE(ymm18, ymm1); - double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks + a11 += rs_a; - gint_t required_packing_A = 1; - mem_t local_mem_buf_A_s = {0}; - double *D_A_pack = NULL; - double d11_pack[d_mr] __attribute__((aligned(64))); - rntm_t rntm; + ymm0 = _mm256_unpacklo_ps(ymm10, ymm11); + ymm1 = _mm256_unpacklo_ps(ymm17, ymm18); - bli_rntm_init_from_global( &rntm ); - bli_rntm_set_num_threads_only( 1, &rntm ); - bli_membrk_rntm_set_membrk( &rntm ); + ymm6 = _mm256_unpacklo_ps(ymm2, ymm3); + ymm7 = _mm256_unpacklo_ps(ymm19, ymm20); - siz_t buffer_size = bli_pool_block_size( - bli_membrk_pool( - bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), - bli_rntm_membrk(&rntm))); + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b01000100); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b01000100); - if( (d_nr * n * sizeof(double)) > buffer_size) - return BLIS_NOT_YET_IMPLEMENTED; + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//1 + _mm_storeu_ps((float *)(b11), 
_mm256_extractf128_ps(ymm16, 0)); - if (required_packing_A == 1) - { - // Get the buffer from the pool. - bli_membrk_acquire_m(&rntm, - buffer_size, - BLIS_BITVAL_BUFFER_FOR_A_BLOCK, - &local_mem_buf_A_s); - if(FALSE==bli_mem_is_alloc(&local_mem_buf_A_s)) return BLIS_NULL_POINTER; - D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); - if(NULL==D_A_pack) return BLIS_NULL_POINTER; - } + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x31);//5 + _mm_storeu_ps((float *)(b11 + 4*cs_b), _mm256_extractf128_ps(ymm16, 0)); - //ymm scratch reginsters - __m256d ymm0, ymm1, ymm2, ymm3; - __m256d ymm4, ymm5, ymm6, ymm7; - __m256d ymm8, ymm9, ymm10, ymm11; - __m256d ymm12, ymm13, ymm14, ymm15; + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b11101110); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b11101110); - __m128d xmm5; + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//2 + _mm_storeu_ps((float *)(b11 + cs_b), _mm256_extractf128_ps(ymm16, 0)); - /* - Performs solving TRSM for 6 rows at a time from 0 to n/6 in steps of d_nr - a. Load and pack A (a01 block), the size of packing 6x6 to 6x (n-6) - First there will be no GEMM and no packing of a01 because it is only TRSM - b. Using packed a01 block and b10 block perform GEMM operation - c. Use GEMM outputs, perform TRSM operation using a11, b11 and update B - d. Repeat b for m cols of B in steps of d_mr - */ + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x31);//6 + _mm_storeu_ps((float *)(b11 + 5*cs_b), _mm256_extractf128_ps(ymm16, 0)); - for(j = (n-d_nr); (j+1) > 0; j -= d_nr) //loop along 'N' direction - { - a01 = L + (j*rs_a) + (j+d_nr)*cs_a; //pointer to block of A to be used in GEMM - a11 = L + (j*cs_a) + (j*rs_a); //pointer to block of A to be used for TRSM + ymm0 = _mm256_unpackhi_ps(ymm10, ymm11); + ymm1 = _mm256_unpackhi_ps(ymm17, ymm18); - //double *ptr_a10_dup = D_A_pack; + ymm6 = _mm256_unpackhi_ps(ymm2, ymm3); + ymm7 = _mm256_unpackhi_ps(ymm19, ymm20); - dim_t p_lda = (n-j-d_nr); // packed leading dimension - // perform copy of A to packed buffer D_A_pack + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b01000100); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b01000100); - if(transa) - { - /* - Pack current A block (a01) into packed buffer memory D_A_pack - a. This a10 block is used in GEMM portion only and this - a01 block size will be increasing by d_nr for every next iteration - until it reaches 6x(n-6) which is the maximum GEMM alone block size in A - b. This packed buffer is reused to calculate all m cols of B matrix - */ - bli_dtrsm_small_pack('R', p_lda, 1, a01, cs_a, D_A_pack, p_lda,d_nr); + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//3 + _mm_storeu_ps((float *)(b11 + 2*cs_b), _mm256_extractf128_ps(ymm16, 0)); - /* - Pack 6 diagonal elements of A block into an array - a. This helps in utilze cache line efficiently in TRSM operation - b. store ones when input is unit diagonal - */ - dtrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,d_nr); - } - else - { - bli_dtrsm_small_pack('R', p_lda, 0, a01, rs_a, D_A_pack, p_lda,d_nr); - dtrsm_small_pack_diag_element(is_unitdiag,a11,rs_a,d11_pack,d_nr); + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b11101110); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b11101110); + + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//4 + _mm_storeu_ps((float *)(b11 + 3*cs_b), _mm256_extractf128_ps(ymm16, 0)); } - /* - a. Perform GEMM using a01, b10. - b. Perform TRSM on a11, b11 - c. 
This loop GEMM+TRSM loops operates with 8x6 block size - along m dimension for every d_mr columns of B10 where - packed A buffer is reused in computing all m cols of B. - d. Same approach is used in remaining fringe cases. - */ - for(i = (m-d_mr); (i+1) > 0; i -= d_mr) //loop along 'M' direction + dim_t n_rem = n-j; + if(n_rem >= 4) { - a01 = D_A_pack; - a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM - b10 = B + i + (j+d_nr)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (i) + (j)*cs_b; //pointer to block of B to be used for TRSM + a10 = D_A_pack; + a11 = L + (i*rs_a) + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + j*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM - k_iter = (n-j-d_nr); //number of GEMM operations to be done(in blocks of 4x4) + k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS + BLIS_SET_S_YMM_REG_ZEROS - /* - Peform GEMM between a01 and b10 blocks - For first itteration there will be no GEMM operation - where k_iter are zero - */ - - BLIS_DTRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) - - /* - Load b11 of size 8x6 and multiply with alpha - Add the GEMM output to b11 - and peform TRSM operation. - */ - - BLIS_PRE_DTRSM_SMALL_6x8(AlphaVal,b11,cs_b) + ///GEMM code begins/// + BLIS_STRSM_SMALL_GEMM_4mx4n(a10,b01,cs_b,p_lda,k_iter) - ///implement TRSM/// + ymm16 = _mm256_broadcast_ss((float const *)(&AlphaVal)); - /* - Compute 6x8 TRSM block by using GEMM block output in register - a. The 6x8 input (gemm outputs) are stored in combinations of ymm registers - 1. ymm3, ymm4 2. ymm5, ymm6 3. ymm7, ymm8, 4. ymm9, ymm10 - 5. ymm11, ymm12 6. ymm13,ymm14 - b. 
Towards the end TRSM output will be stored back into b11 - */ + ymm17 = _mm256_insertf128_ps(ymm1, _mm_loadu_ps((float const*)(b11)), 0); + ymm18 = _mm256_insertf128_ps(ymm1, _mm_loadu_ps((float const*)(b11 + cs_b)), 0); + ymm19 = _mm256_insertf128_ps(ymm1, _mm_loadu_ps((float const*)(b11 + cs_b*2)), 0); + ymm20 = _mm256_insertf128_ps(ymm1, _mm_loadu_ps((float const*)(b11 + cs_b*3)), 0); - //extract a55 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + ymm17 = _mm256_fmsub_ps(ymm17, ymm16, ymm8); + ymm18 = _mm256_fmsub_ps(ymm18, ymm16, ymm9); + ymm19 = _mm256_fmsub_ps(ymm19, ymm16, ymm10); + ymm20 = _mm256_fmsub_ps(ymm20, ymm16, ymm11); - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm0); + ymm8 = _mm256_unpacklo_ps(ymm17, ymm18); + ymm9 = _mm256_unpacklo_ps(ymm19, ymm20); - //extract a44 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + ymm5 = _mm256_broadcast_ss((float const *)(&zero)); - //(row 5):FMA operations - //ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 4*rs_a)); + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b01000100); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm13, ymm11); - ymm12 = _mm256_fnmadd_pd(ymm1, ymm14, ymm12); + ymm10 = _mm256_permute2f128_ps(ymm4,ymm5,0x20);//1 - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 3*rs_a)); + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b11101110); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm13, ymm9); - ymm10 = _mm256_fnmadd_pd(ymm1, ymm14, ymm10); + ymm11 = _mm256_permute2f128_ps(ymm4,ymm5,0x20);//2 - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 2*rs_a)); + ymm8 = _mm256_unpackhi_ps(ymm17, ymm18); + ymm9 = _mm256_unpackhi_ps(ymm19, ymm20); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm13, ymm7); - ymm8 = _mm256_fnmadd_pd(ymm1, ymm14, ymm8); + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b01000100); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 1*rs_a)); + ymm17 = _mm256_permute2f128_ps(ymm4,ymm5,0x20);//3 - ymm5 = _mm256_fnmadd_pd(ymm1, ymm13, ymm5); - ymm6 = _mm256_fnmadd_pd(ymm1, ymm14, ymm6); + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b11101110); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a)); + ymm18 = _mm256_permute2f128_ps(ymm4,ymm5,0x20);//4 - ymm3 = _mm256_fnmadd_pd(ymm1, ymm13, ymm3); - ymm4 = _mm256_fnmadd_pd(ymm1, ymm14, ymm4); + // TRSM portion - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); - ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm0); + ////extract a00 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack)); - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + //perform mul operation + ymm10 = STRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); - //(row 4):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 3*rs_a)); + //extract a11 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm11, ymm9); - ymm10 = _mm256_fnmadd_pd(ymm1, ymm12, ymm10); + //(ROw1): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm10, ymm11); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 2*rs_a)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm10, ymm17); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm11, ymm7); - ymm8 = _mm256_fnmadd_pd(ymm1, ymm12, ymm8); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm10, ymm18); - ymm1 = _mm256_broadcast_sd((double const *)(a11 
+ 4*cs_a + 1*rs_a)); + //perform mul operation + ymm11 = STRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm11, ymm5); - ymm6 = _mm256_fnmadd_pd(ymm1, ymm12, ymm6); + a11 += rs_a; - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a)); + //extract a22 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm11, ymm3); - ymm4 = _mm256_fnmadd_pd(ymm1, ymm12, ymm4); + //(ROw2): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm11, ymm17); - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm0); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm11, ymm18); - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + //perform mul operation + ymm17 = STRSM_SMALL_DIV_OR_SCALE(ymm17, ymm1); - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2*rs_a)); + a11 += rs_a; - ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); - ymm8 = _mm256_fnmadd_pd(ymm1, ymm10, ymm8); + //extract a33 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 3)); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1*rs_a)); + //(ROw3): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm17, ymm18); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); - ymm6 = _mm256_fnmadd_pd(ymm1, ymm10, ymm6); + //perform mul operation + ymm18 = STRSM_SMALL_DIV_OR_SCALE(ymm18, ymm1); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); + a11 += rs_a; - ymm3 = _mm256_fnmadd_pd(ymm1, ymm9, ymm3); - ymm4 = _mm256_fnmadd_pd(ymm1, ymm10, ymm4); + ymm0 = _mm256_unpacklo_ps(ymm10, ymm11); + ymm1 = _mm256_unpacklo_ps(ymm17, ymm18); - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm0); + ymm6 = _mm256_unpacklo_ps(ymm2, ymm3); + ymm7 = _mm256_unpacklo_ps(ymm19, ymm20); - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b01000100); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b01000100); - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1*rs_a)); + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//1 + _mm_storeu_ps((float *)(b11), _mm256_extractf128_ps(ymm16, 0)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); - ymm6 = _mm256_fnmadd_pd(ymm1, ymm8, ymm6); + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b11101110); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b11101110); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//2 + _mm_storeu_ps((float *)(b11 + cs_b), _mm256_extractf128_ps(ymm16, 0)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); - ymm4 = _mm256_fnmadd_pd(ymm1, ymm8, ymm4); + ymm0 = _mm256_unpackhi_ps(ymm10, ymm11); + ymm1 = _mm256_unpackhi_ps(ymm17, ymm18); - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm0); + ymm6 = _mm256_unpackhi_ps(ymm2, ymm3); + ymm7 = _mm256_unpackhi_ps(ymm19, ymm20); - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b01000100); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b01000100); - //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//3 + _mm_storeu_ps((float *)(b11 + 2*cs_b), _mm256_extractf128_ps(ymm16, 0)); - ymm3 = 
_mm256_fnmadd_pd(ymm1, ymm5, ymm3); - ymm4 = _mm256_fnmadd_pd(ymm1, ymm6, ymm4); + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b11101110); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b11101110); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//4 + _mm_storeu_ps((float *)(b11 + 3*cs_b), _mm256_extractf128_ps(ymm16, 0)); - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + 4), ymm4); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b + 4), ymm6); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*2 + 4), ymm8); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - _mm256_storeu_pd((double *)(b11 + cs_b*3 + 4), ymm10); - _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); - _mm256_storeu_pd((double *)(b11 + cs_b*4 + 4), ymm12); - _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); - _mm256_storeu_pd((double *)(b11 + cs_b*5 + 4), ymm14); + n_rem -= 4; + j += 4; } - - dim_t m_remainder = i + d_mr; - if(m_remainder >= 4) + if(n_rem) { - a01 = D_A_pack; - a11 = L + (j*cs_a) + (j*rs_a); - b10 = B + (m_remainder - 4) + (j+d_nr)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (m_remainder - 4) + (j*cs_b); + a10 = D_A_pack; + a11 = L + (i*rs_a) + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + j*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM - k_iter = (n-j-d_nr); //number of GEMM operations to be done(in blocks of 4x4) + k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS + ymm8 = _mm256_setzero_ps(); + ymm9 = _mm256_setzero_ps(); + ymm10 = _mm256_setzero_ps(); - ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) + if(3 == n_rem) + { + ///GEMM code begins/// + BLIS_STRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) - // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_DTRSM_SMALL_6x4(AlphaVal,b11,cs_b) + ymm16 = _mm256_broadcast_ss((float const *)(&AlphaVal)); //register to hold alpha - ///implement TRSM/// + ymm17 = _mm256_insertf128_ps(ymm1, _mm_loadu_ps((float const*)(b11)), 0); + ymm18 = _mm256_insertf128_ps(ymm1, _mm_loadu_ps((float const*)(b11 + cs_b)), 0); + ymm19 = _mm256_insertf128_ps(ymm1, _mm_loadu_ps((float const*)(b11 + cs_b*2)), 0); + ymm20 = _mm256_broadcast_ss((float const *)(&zero)); - //extract a55 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); + ymm17 = _mm256_fmsub_ps(ymm17, ymm16, ymm8); + ymm18 = _mm256_fmsub_ps(ymm18, ymm16, ymm9); + ymm19 = _mm256_fmsub_ps(ymm19, ymm16, ymm10); + } + else if(2 == n_rem) + { + ///GEMM code begins/// + BLIS_STRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b,p_lda,k_iter) - //extract a44 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + ymm16 = _mm256_broadcast_ss((float const *)(&AlphaVal)); //register to hold alpha - //(row 5):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 4*rs_a)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm13, ymm11); + ymm17 = _mm256_insertf128_ps(ymm1, _mm_loadu_ps((float const*)(b11)), 0); + ymm18 = _mm256_insertf128_ps(ymm1, _mm_loadu_ps((float const*)(b11 + cs_b)), 0); + ymm19 = _mm256_broadcast_ss((float const *)(&zero)); + ymm20 = _mm256_broadcast_ss((float const *)(&zero)); - ymm1 = 
_mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 3*rs_a)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm13, ymm9); + ymm17 = _mm256_fmsub_ps(ymm17, ymm16, ymm8); + ymm18 = _mm256_fmsub_ps(ymm18, ymm16, ymm9); + } + else if(1 == n_rem) + { + ///GEMM code begins/// + BLIS_STRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b,p_lda,k_iter) - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 2*rs_a)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm13, ymm7); + ymm16 = _mm256_broadcast_ss((float const *)(&AlphaVal)); //register to hold alpha - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 1*rs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm13, ymm5); + ymm17 = _mm256_insertf128_ps(ymm1, _mm_loadu_ps((float const*)(b11)), 0); + ymm18 = _mm256_broadcast_ss((float const *)(&zero)); + ymm19 = _mm256_broadcast_ss((float const *)(&zero)); + ymm20 = _mm256_broadcast_ss((float const *)(&zero)); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm13, ymm3); + ymm17 = _mm256_fmsub_ps(ymm17, ymm16, ymm8); + } - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); + ymm8 = _mm256_unpacklo_ps(ymm17, ymm18); + ymm9 = _mm256_unpacklo_ps(ymm19, ymm20); - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + ymm5 = _mm256_broadcast_ss((float const *)(&zero)); - //(row 4):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 3*rs_a)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm11, ymm9); + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b01000100); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 2*rs_a)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm11, ymm7); + ymm10 = _mm256_permute2f128_ps(ymm4,ymm5,0x20);//1 - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 1*rs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm11, ymm5); + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b11101110); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm11, ymm3); + ymm11 = _mm256_permute2f128_ps(ymm4,ymm5,0x20);//2 - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + ymm8 = _mm256_unpackhi_ps(ymm17, ymm18); + ymm9 = _mm256_unpackhi_ps(ymm19, ymm20); - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b01000100); - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2*rs_a)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); + ymm17 = _mm256_permute2f128_ps(ymm4,ymm5,0x20);//3 - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1*rs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b11101110); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm9, ymm3); + ymm18 = _mm256_permute2f128_ps(ymm4,ymm5,0x20);//4 - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + // TRSM portion - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + ////extract a00 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack)); - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1*rs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); + //perform mul operation + ymm10 = STRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); + //extract a11 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + //(ROw1): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*1)); + ymm11 = 
_mm256_fnmadd_ps(ymm0, ymm10, ymm11); - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm10, ymm17); - //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm10, ymm18); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + //perform mul operation + ymm11 = STRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); - _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); + a11 += rs_a; - m_remainder -=4; - } + //extract a22 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); - if(m_remainder) - { - if(3 == m_remainder) - { - a01 = D_A_pack; - a11 = L + (j*cs_a) + (j*rs_a); - b10 = B + (j+d_nr)*cs_b + (m_remainder - 3); //pointer to block of B to be used in GEMM - b11 = B + (m_remainder - 3) + (j*cs_b); + //(ROw2): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm11, ymm17); - k_iter = (n-j-d_nr); //number of GEMM operations to be done(in blocks of 4x4) + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm11, ymm18); - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS + //perform mul operation + ymm17 = STRSM_SMALL_DIV_OR_SCALE(ymm17, ymm1); - ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) + a11 += rs_a; - // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_DTRSM_SMALL_6x4(AlphaVal,b11,cs_b) + //extract a33 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 3)); - ///implement TRSM/// + //(ROw3): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm17, ymm18); - //extract a55 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); + //perform mul operation + ymm18 = STRSM_SMALL_DIV_OR_SCALE(ymm18, ymm1); - //extract a44 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + a11 += rs_a; - //(row 5):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 4*rs_a)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm13, ymm11); + if(3 == n_rem) + { + ymm0 = _mm256_unpacklo_ps(ymm10, ymm11); + ymm1 = _mm256_unpacklo_ps(ymm17, ymm18); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 3*rs_a)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm13, ymm9); + ymm6 = _mm256_unpacklo_ps(ymm2, ymm3); + ymm7 = _mm256_unpacklo_ps(ymm19, ymm20); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 2*rs_a)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm13, ymm7); + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b01000100); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b01000100); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 1*rs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm13, ymm5); + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//1 + _mm_storeu_ps((float *)(b11), _mm256_extractf128_ps(ymm16, 0)); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm13, ymm3); + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b11101110); + ymm13 = 
_mm256_shuffle_ps(ymm6,ymm7,0b11101110); - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//2 + _mm_storeu_ps((float *)(b11 + cs_b), _mm256_extractf128_ps(ymm16, 0)); - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + ymm0 = _mm256_unpackhi_ps(ymm10, ymm11); + ymm1 = _mm256_unpackhi_ps(ymm17, ymm18); - //(row 4):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 3*rs_a)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm11, ymm9); + ymm6 = _mm256_unpackhi_ps(ymm2, ymm3); + ymm7 = _mm256_unpackhi_ps(ymm19, ymm20); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 2*rs_a)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm11, ymm7); + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b01000100); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b01000100); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 1*rs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm11, ymm5); + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//3 + _mm_storeu_ps((float *)(b11 + 2*cs_b), _mm256_extractf128_ps(ymm16, 0)); + } + else if(2 == n_rem) + { + ymm0 = _mm256_unpacklo_ps(ymm10, ymm11); + ymm1 = _mm256_unpacklo_ps(ymm17, ymm18); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm11, ymm3); + ymm6 = _mm256_unpacklo_ps(ymm2, ymm3); + ymm7 = _mm256_unpacklo_ps(ymm19, ymm20); - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b01000100); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b01000100); - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//1 + _mm_storeu_ps((float *)(b11), _mm256_extractf128_ps(ymm16, 0)); - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2*rs_a)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b11101110); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b11101110); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1*rs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//2 + _mm_storeu_ps((float *)(b11 + cs_b), _mm256_extractf128_ps(ymm16, 0)); + } + else if(1 == n_rem) + { + ymm0 = _mm256_unpacklo_ps(ymm10, ymm11); + ymm1 = _mm256_unpacklo_ps(ymm17, ymm18); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm9, ymm3); + ymm6 = _mm256_unpacklo_ps(ymm2, ymm3); + ymm7 = _mm256_unpacklo_ps(ymm19, ymm20); - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b01000100); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b01000100); - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//1 + _mm_storeu_ps((float *)(b11), _mm256_extractf128_ps(ymm16, 0)); + } + } + m_rem -=4; + i +=4; + } - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1*rs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); + if(m_rem) + { + a10 = L + (i*cs_a); //pointer to block of A to be used for GEMM + // Do transpose for a10 & store in D_A_pack + float *ptr_a10_dup = D_A_pack; - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); + if(3 == m_rem) // Repetative A blocks will be 3*3 + { + __m128 xmm0,xmm1,xmm2,xmm3; + __m128 xmm4,xmm5,xmm6,xmm7; + __m128 xmm8,xmm9; + dim_t p_lda = 4; // packed leading dimension + if(transa) + { + for(dim_t x=0;x= 4)) { - a01 = 
D_A_pack; - a11 = L + (j*cs_a) + (j*rs_a); - b10 = B + (j+d_nr)*cs_b + (m_remainder - 2); //pointer to block of B to be used in GEMM - b11 = B + (m_remainder - 2) + (j*cs_b); + a10 = D_A_pack; //pointer to block of A to be used for GEMM + a11 = L + (i*rs_a) + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM + b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM - k_iter = (n-j-d_nr); //number of GEMM operations to be done(in blocks of 4x4) + k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS + BLIS_SET_S_YMM_REG_ZEROS - ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) + ///GEMM code begins/// + BLIS_STRSM_SMALL_GEMM_3mx4n(a10,b01,cs_b,p_lda,k_iter) - // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_DTRSM_SMALL_6x4(AlphaVal,b11,cs_b) + ymm16 = _mm256_broadcast_ss((float const *)(&AlphaVal)); //register to hold alpha ///implement TRSM/// + xmm4 = _mm_broadcast_ss((float const *)(b11 + 2)); + ymm17 = _mm256_insertf128_ps(ymm1, _mm_loadl_pi(xmm4,(__m64 *)(b11)), 0); + xmm4 = _mm_broadcast_ss((float const *)(b11 + cs_b + 2)); + ymm18 = _mm256_insertf128_ps(ymm1, _mm_loadl_pi(xmm4,(__m64 *)(b11 + cs_b)), 0); + xmm4 = _mm_broadcast_ss((float const *)(b11 + cs_b*2 + 2)); + ymm19 = _mm256_insertf128_ps(ymm1, _mm_loadl_pi(xmm4,(__m64 *)(b11 + cs_b*2)), 0); + xmm4 = _mm_broadcast_ss((float const *)(b11 + cs_b*3 + 2)); + ymm20 = _mm256_insertf128_ps(ymm1, _mm_loadl_pi(xmm4,(__m64 *)(b11 + cs_b*3)), 0); - //extract a55 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); + ymm17 = _mm256_fmsub_ps(ymm17, ymm16, ymm8); + ymm18 = _mm256_fmsub_ps(ymm18, ymm16, ymm9); + ymm19 = _mm256_fmsub_ps(ymm19, ymm16, ymm10); + ymm20 = _mm256_fmsub_ps(ymm20, ymm16, ymm11); - //extract a44 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + _mm_storel_pi((__m64 *)(b11), _mm256_extractf128_ps(ymm17, 0)); + _mm_store_ss((float *)(b11 + 2), _mm_permute_ps(_mm256_extractf128_ps(ymm17, 0),0x02)); - //(row 5):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 4*rs_a)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm13, ymm11); + _mm_storel_pi((__m64 *)(b11 + cs_b), _mm256_extractf128_ps(ymm18, 0)); + _mm_store_ss((float *)(b11 + cs_b + 2), _mm_permute_ps(_mm256_extractf128_ps(ymm18, 0),0x02)); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 3*rs_a)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm13, ymm9); + _mm_storel_pi((__m64 *)(b11 + cs_b*2), _mm256_extractf128_ps(ymm19, 0)); + _mm_store_ss((float *)(b11 + cs_b*2 + 2), _mm_permute_ps(_mm256_extractf128_ps(ymm19, 0),0x02)); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 2*rs_a)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm13, ymm7); + _mm_storel_pi((__m64 *)(b11 + cs_b*3), _mm256_extractf128_ps(ymm20, 0)); + _mm_store_ss((float *)(b11 + cs_b*3 + 2), _mm_permute_ps(_mm256_extractf128_ps(ymm20, 0),0x02)); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 1*rs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm13, ymm5); + if(transa) + strsm_AutXB_ref(a11, b11, m_rem, 4, cs_a, cs_b,is_unitdiag); + else + strsm_AlXB_ref(a11, b11, m_rem, 4, rs_a, cs_b, is_unitdiag); + n_rem -= 4; + j +=4; + } - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm13, ymm3); + if(n_rem) + { + a10 = D_A_pack; //pointer to 
block of A to be used for GEMM + a11 = L + (i*rs_a) + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM + b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); + k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS - //(row 4):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 3*rs_a)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm11, ymm9); + if(3 == n_rem) + { + ///GEMM code begins/// + BLIS_STRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 2*rs_a)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm11, ymm7); + BLIS_PRE_STRSM_SMALL_3M_3N(AlphaVal,b11,cs_b) - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 1*rs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm11, ymm5); + if(transa) + strsm_AutXB_ref(a11, b11, m_rem, 3, cs_a, cs_b,is_unitdiag); + else + strsm_AlXB_ref(a11, b11, m_rem, 3, rs_a, cs_b, is_unitdiag); + } + else if(2 == n_rem) + { + ///GEMM code begins/// + BLIS_STRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b,p_lda,k_iter) - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm11, ymm3); + BLIS_PRE_STRSM_SMALL_3M_2N(AlphaVal,b11,cs_b) - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + if(transa) + strsm_AutXB_ref(a11, b11, m_rem, 2, cs_a, cs_b,is_unitdiag); + else + strsm_AlXB_ref(a11, b11, m_rem, 2, rs_a, cs_b, is_unitdiag); + } + else if(1 == n_rem) + { + ///GEMM code begins/// + BLIS_STRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b,p_lda,k_iter) - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + BLIS_PRE_STRSM_SMALL_3M_1N(AlphaVal,b11,cs_b) - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2*rs_a)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); + if(transa) + strsm_AutXB_ref(a11, b11, m_rem, 1, cs_a, cs_b, is_unitdiag); + else + strsm_AlXB_ref(a11, b11, m_rem, 1, rs_a, cs_b, is_unitdiag); + } + } + } + else if(2 == m_rem) // Repetative A blocks will be 2*2 + { + __m128 xmm0,xmm1,xmm2,xmm3; + __m128 xmm4,xmm5,xmm6,xmm7; + __m128 xmm8,xmm9; + dim_t p_lda = 4; // packed leading dimension + if(transa) + { + for(dim_t x=0;x= 4)) + { + a10 = D_A_pack; //pointer to block of A to be used for GEMM + a11 = L + (i*rs_a) + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM + b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); + k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) - //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ///GEMM code begins/// + BLIS_STRSM_SMALL_GEMM_4mx4n(a10,b01,cs_b,p_lda,k_iter) - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - 
ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_pd(ymm0, ymm11, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_pd(ymm0, ymm13, 0x03); + ymm16 = _mm256_broadcast_ss((float const *)(&AlphaVal)); //register to hold alpha + + ///implement TRSM/// - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); - _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); + xmm4 = _mm_broadcast_ss((float const *)(&zero)); + ymm17 = _mm256_insertf128_ps(ymm1, _mm_loadl_pi(xmm4,(__m64 *)(b11)), 0); + ymm18 = _mm256_insertf128_ps(ymm1, _mm_loadl_pi(xmm4,(__m64 *)(b11 + cs_b)), 0); + ymm19 = _mm256_insertf128_ps(ymm1, _mm_loadl_pi(xmm4,(__m64 *)(b11 + cs_b*2)), 0); + ymm20 = _mm256_insertf128_ps(ymm1, _mm_loadl_pi(xmm4,(__m64 *)(b11 + cs_b*3)), 0); - m_remainder -=2; + ymm17 = _mm256_fmsub_ps(ymm17, ymm16, ymm8); + ymm18 = _mm256_fmsub_ps(ymm18, ymm16, ymm9); + ymm19 = _mm256_fmsub_ps(ymm19, ymm16, ymm10); + ymm20 = _mm256_fmsub_ps(ymm20, ymm16, ymm11); + + _mm_storel_pi((__m64 *)(b11), _mm256_extractf128_ps(ymm17, 0)); + _mm_storel_pi((__m64 *)(b11 + cs_b), _mm256_extractf128_ps(ymm18, 0)); + _mm_storel_pi((__m64 *)(b11 + cs_b*2), _mm256_extractf128_ps(ymm19, 0)); + _mm_storel_pi((__m64 *)(b11 + cs_b*3), _mm256_extractf128_ps(ymm20, 0)); + + if(transa) + strsm_AutXB_ref(a11, b11, m_rem, 4, cs_a, cs_b, is_unitdiag); + else + strsm_AlXB_ref(a11, b11, m_rem, 4, rs_a, cs_b, is_unitdiag); + n_rem -= 4; + j +=4; } - else if (1 == m_remainder) + if(n_rem) { - a01 = D_A_pack; - a11 = L + (j*cs_a) + (j*rs_a); - b10 = B + (j+d_nr)*cs_b + (m_remainder - 1); //pointer to block of B to be used in GEMM - b11 = B + (m_remainder - 1) + (j*cs_b); + a10 = D_A_pack; //pointer to block of A to be used for GEMM + a11 = L + (i*rs_a) + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM + b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM - k_iter = (n-j-d_nr); //number of GEMM operations to be done(in blocks of 4x4) + k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS - - ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_SET_S_YMM_REG_ZEROS - // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_DTRSM_SMALL_6x4(AlphaVal,b11,cs_b) + if(3 == n_rem) + { + ///GEMM code begins/// + BLIS_STRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) - ///implement TRSM/// + BLIS_PRE_STRSM_SMALL_2M_3N(AlphaVal,b11,cs_b) - //extract a55 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); + if(transa) + strsm_AutXB_ref(a11, b11, m_rem, 3, cs_a, cs_b, is_unitdiag); + else + strsm_AlXB_ref(a11, b11, m_rem, 3, rs_a, cs_b, is_unitdiag); + } + else if(2 == n_rem) + { + ///GEMM code begins/// + BLIS_STRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b,p_lda,k_iter) - //extract a44 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + 
BLIS_PRE_STRSM_SMALL_2M_2N(AlphaVal,b11,cs_b) - //(row 5):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 4*rs_a)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm13, ymm11); + if(transa) + strsm_AutXB_ref(a11, b11, m_rem, 2, cs_a, cs_b, is_unitdiag); + else + strsm_AlXB_ref(a11, b11, m_rem, 2, rs_a, cs_b, is_unitdiag); + } + else if(1 == n_rem) + { + ///GEMM code begins/// + BLIS_STRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b,p_lda,k_iter) - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 3*rs_a)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm13, ymm9); + BLIS_PRE_STRSM_SMALL_2M_1N(AlphaVal,b11,cs_b) - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 2*rs_a)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm13, ymm7); + if(transa) + strsm_AutXB_ref(a11, b11, m_rem, 1, cs_a, cs_b, is_unitdiag); + else + strsm_AlXB_ref(a11, b11, m_rem, 1, rs_a, cs_b, is_unitdiag); + } + } + m_rem -=2; + i+=2; + } + else if(1 == m_rem) // Repetative A blocks will be 1*1 + { + __m128 xmm0,xmm1,xmm2,xmm3; + __m128 xmm4,xmm5,xmm6,xmm7; + __m128 xmm8,xmm9; + dim_t p_lda = 4; // packed leading dimension + if(transa) + { + for(dim_t x=0;x= 4)) + { + a10 = D_A_pack; //pointer to block of A to be used for GEMM + a11 = L + (i*rs_a) + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM + b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 1*rs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm11, ymm5); + k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm11, ymm3); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + ///GEMM code begins/// + BLIS_STRSM_SMALL_GEMM_4mx4n(a10,b01,cs_b,p_lda,k_iter) - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + ymm16 = _mm256_broadcast_ss((float const *)(&AlphaVal)); //register to hold alpha - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2*rs_a)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); + ///implement TRSM/// + ymm17 = _mm256_broadcast_ss((float const*)(b11)); + ymm18 = _mm256_broadcast_ss((float const*)(b11 + cs_b)); + ymm19 = _mm256_broadcast_ss((float const*)(b11 + cs_b*2)); + ymm20 = _mm256_broadcast_ss((float const*)(b11 + cs_b*3)); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1*rs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); + ymm17 = _mm256_fmsub_ps(ymm17, ymm16, ymm8); + ymm18 = _mm256_fmsub_ps(ymm18, ymm16, ymm9); + ymm19 = _mm256_fmsub_ps(ymm19, ymm16, ymm10); + ymm20 = _mm256_fmsub_ps(ymm20, ymm16, ymm11); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm9, ymm3); + _mm_store_ss((float *)(b11), _mm256_extractf128_ps(ymm17,0)); + _mm_store_ss((float *)(b11 + cs_b), _mm256_extractf128_ps(ymm18,0)); + _mm_store_ss((float *)(b11 + cs_b*2), _mm256_extractf128_ps(ymm19,0)); + _mm_store_ss((float *)(b11 + cs_b*3), _mm256_extractf128_ps(ymm20,0)); - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + if(transa) + strsm_AutXB_ref(a11, b11, m_rem, 4, cs_a, cs_b, is_unitdiag); + else + strsm_AlXB_ref(a11, b11, m_rem, 4, rs_a, cs_b, is_unitdiag); + n_rem -= 4; + j+=4; + } - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + if(n_rem) + { + a10 = D_A_pack; //pointer to block of A to be 
used for GEMM + a11 = L + (i*rs_a) + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM + b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1*rs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); + k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + if(3 == n_rem) + { + ///GEMM code begins/// + BLIS_STRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); + BLIS_PRE_STRSM_SMALL_1M_3N(AlphaVal,b11,cs_b) - //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); + if(transa) + strsm_AutXB_ref(a11, b11, m_rem, 3, cs_a, cs_b, is_unitdiag); + else + strsm_AlXB_ref(a11, b11, m_rem, 3, rs_a, cs_b, is_unitdiag); + } + else if(2 == n_rem) + { + ///GEMM code begins/// + BLIS_STRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b,p_lda,k_iter) - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + BLIS_PRE_STRSM_SMALL_1M_2N(AlphaVal,b11,cs_b) - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_pd(ymm0, ymm11, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_pd(ymm0, ymm13, 0x01); + if(transa) + strsm_AutXB_ref(a11, b11, m_rem, 2, cs_a, cs_b, is_unitdiag); + else + strsm_AlXB_ref(a11, b11, m_rem, 2, rs_a, cs_b, is_unitdiag); + } + else if(1 == n_rem) + { + ///GEMM code begins/// + BLIS_STRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b,p_lda,k_iter) - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); - _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); + BLIS_PRE_STRSM_SMALL_1M_1N(AlphaVal,b11,cs_b) - m_remainder -=1; + if(transa) + strsm_AutXB_ref(a11, b11, m_rem, 1, cs_a, cs_b, is_unitdiag); + else + strsm_AlXB_ref(a11, b11, m_rem, 1, rs_a, cs_b, is_unitdiag); + } } + m_rem -=1; + i+=1; } } - dim_t n_remainder = j + d_nr; - - /* - Reminder cases starts here: - a. Similar logic and code flow used in computing full block (6x8) - above holds for reminder cases too. 
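/*
 * The fringe m_rem tiles above hand off to the scalar reference solves
 * (strsm_AutXB_ref / strsm_AlXB_ref).  A minimal sketch of the AlXB form
 * those calls are expected to implement -- the single-precision analogue of
 * the dtrsm_*_ref routines declared earlier in this file -- assuming
 * column-major storage with strides lda/ldb and the DIAG_ELE_* macros from
 * the top of the file; the real helper bodies are not part of this hunk.
 */
BLIS_INLINE err_t strsm_AlXB_ref_sketch
     (
       float   *A,
       float   *B,
       dim_t    M,
       dim_t    N,
       dim_t    lda,
       dim_t    ldb,
       bool     is_unitdiag
     )
{
    dim_t i, j, k;
    for (k = 0; k < M; k++)
    {
        float lkk_inv = 1.0f;
        if (!is_unitdiag) lkk_inv = DIAG_ELE_INV_OPS(lkk_inv, A[k + k*lda]);
        for (j = 0; j < N; j++)
        {
            /* scale (or divide) the pivot row, then eliminate it from the rows below */
            B[k + j*ldb] = DIAG_ELE_EVAL_OPS(B[k + j*ldb], lkk_inv);
            for (i = k + 1; i < M; i++)
            {
                B[i + j*ldb] -= A[i + k*lda] * B[k + j*ldb];
            }
        }
    }
    return BLIS_SUCCESS;
}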
- */ - - if(n_remainder >= 4) + if ((required_packing_A == 1) && + bli_mem_is_alloc( &local_mem_buf_A_s )) { - a01 = L + (n_remainder - 4)*rs_a + n_remainder*cs_a; //pointer to block of A to be used in GEMM - a11 = L + (n_remainder - 4)*cs_a + (n_remainder - 4)*rs_a; //pointer to block of A to be used for TRSM + bli_membrk_release(&rntm, &local_mem_buf_A_s); + } + return BLIS_SUCCESS; +} - double *ptr_a10_dup = D_A_pack; +/* TRSM for the case AX = alpha * B, Single precision + * A is lower-triangular, transpose, non-unit diagonal + * dimensions A: mxm X: mxn B: mxn +*/ +BLIS_INLINE err_t bli_strsm_small_AltXB_AuXB +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +) +{ + dim_t m = bli_obj_length(b); // number of rows of matrix B + dim_t n = bli_obj_width(b); // number of columns of matrix B - dim_t p_lda = (n-n_remainder); // packed leading dimension - // perform copy of A to packed buffer D_A_pack + bool transa = bli_obj_has_trans(a); + dim_t cs_a, rs_a; + dim_t d_mr = 16,d_nr = 6; - if(transa) - { - for(dim_t x =0;x < p_lda;x+=d_nr) - { - ymm0 = _mm256_loadu_pd((double const *)(a01)); - ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); - ymm2 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2)); - ymm3 = _mm256_loadu_pd((double const *)(a01 + cs_a * 3)); + // Swap rs_a & cs_a in case of non-tranpose. + if(transa) + { + cs_a = bli_obj_col_stride(a); // column stride of A + rs_a = bli_obj_row_stride(a); // row stride of A + } + else + { + cs_a = bli_obj_row_stride(a); // row stride of A + rs_a = bli_obj_col_stride(a); // column stride of A + } + dim_t cs_b = bli_obj_col_stride(b); // column stride of B - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + dim_t i, j, k; //loop variables + dim_t k_iter; //number of times GEMM to be performed - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + float AlphaVal = *(float *)AlphaObj->buffer; //value of alpha + float *L = a->buffer; //pointer to matrix A + float *B = b->buffer; //pointer to matrix B - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + //pointers that point to blocks for GEMM and TRSM + float *a10, *a11, *b01, *b11; - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + float ones = 1.0; + float zero = 0.0; + bool is_unitdiag = bli_obj_has_unit_diag(a); - _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + //scratch registers + __m256 ymm0, ymm1, ymm2, ymm3; + __m256 ymm4, ymm5, ymm6, ymm7; + __m256 ymm8, ymm9, ymm10, ymm11; + __m256 ymm12, ymm13, ymm14, ymm15; + __m256 ymm16, ymm17, ymm18, ymm19; + __m256 ymm20, ymm21, ymm22; - ymm0 = _mm256_loadu_pd((double const *)(a01 + cs_a * 4)); - ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a * 5)); + gint_t required_packing_A = 1; + mem_t local_mem_buf_A_s = {0}; + float *D_A_pack = NULL; + float d11_pack[d_mr] __attribute__((aligned(64))); + rntm_t rntm; - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_broadcast_sd((double const *)&zero); + bli_rntm_init_from_global( &rntm ); + bli_rntm_set_num_threads_only( 1, &rntm ); + bli_membrk_rntm_set_membrk( &rntm ); - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + siz_t buffer_size = bli_pool_block_size( + bli_membrk_pool( + 
bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), + bli_rntm_membrk(&rntm))); - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_broadcast_sd((double const *)&zero); + if((d_mr * m * sizeof(float)) > buffer_size) + return BLIS_NOT_YET_IMPLEMENTED; - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + if(required_packing_A == 1) + { + // Get the buffer from the pool. + bli_membrk_acquire_m(&rntm, + buffer_size, + BLIS_BITVAL_BUFFER_FOR_A_BLOCK, + &local_mem_buf_A_s); + if(FALSE==bli_mem_is_alloc(&local_mem_buf_A_s)) return BLIS_NULL_POINTER; + D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); + if(NULL==D_A_pack) return BLIS_NULL_POINTER; + } - _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_extractf128_pd(ymm6,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_extractf128_pd(ymm7,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm8,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm9,0)); + /* + Performs solving TRSM for 16 columns at a time from 0 to m/d_mr in steps of d_mr + a. Load, transpose, Pack A (a10 block), the size of packing 16x6 to 16 x (m-d_mr) + First there will be no GEMM and no packing of a10 because it is only TRSM + b. Using packed a10 block and b01 block perform GEMM operation + c. Use GEMM outputs, perform TRSM operaton using a11, b11 and update B + d. Repeat b,c for n rows of B in steps of d_nr + */ + for(i = (m - d_mr); (i + 1) > 0; i -= d_mr) + { + a10 = L + (i*cs_a) + (i + d_mr)*rs_a; //pointer to block of A to be used for GEMM + a11 = L + (i*cs_a) + (i*rs_a); //pointer to block of A to be used for TRSM - a01 += d_nr*cs_a; - ptr_a10_dup += d_nr; - } + // Do transpose for a10 & store in D_A_pack + //ptr_a10_dup = D_A_pack; + + dim_t p_lda = d_mr; // packed leading dimension + + if(transa) + { + /* + Load, transpose and pack current A block (a10) into packed buffer memory D_A_pack + a. This a10 block is used in GEMM portion only and this + a10 block size will be increasing by d_mr for every next itteration + untill it reaches 16x(m-16) which is the maximum GEMM alone block size in A + b. This packed buffer is reused to calculate all n rows of B matrix + */ + bli_strsm_small_pack('L', (m-i-d_mr), 1, a10, cs_a, D_A_pack,p_lda,d_mr); + + /* + Pack 8 diagonal elements of A block into an array + a. This helps in utilze cache line efficiently in TRSM operation + b. store ones when input is unit diagonal + */ + strsm_small_pack_diag_element('L',is_unitdiag,a11,cs_a,d11_pack,d_mr); } else { - dim_t loop_count = (n-n_remainder)/4; + bli_strsm_small_pack('L', (m-i-d_mr), 0, a10, rs_a, D_A_pack,p_lda,d_mr); + strsm_small_pack_diag_element('L',is_unitdiag,a11,rs_a,d11_pack,d_mr); + } - for(dim_t x =0;x < loop_count;x++) - { - ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 0 + x*4)); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + x*4), ymm15); - ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 1 + x*4)); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + x*4), ymm15); - ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 2 + x*4)); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 2 + x*4), ymm15); - ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 3 + x*4)); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 3 + x*4), ymm15); - } + /* + a. Perform GEMM using a10, b01. + b. Perform TRSM on a11, b11 + c. 
This loop GEMM+TRSM loops operates with 16x6 block size + along n dimension for every d_nr rows of b01 where + packed A buffer is reused in computing all n rows of B. + d. Same approach is used in remaining fringe cases. + */ + for(j = (n - d_nr); (j + 1) > 0; j -= d_nr) + { + a10 = D_A_pack; + b01 = B + (j * cs_b) + i + d_mr; //pointer to block of B to be used for GEMM + b11 = B + (j * cs_b) + i; //pointer to block of B to be used for TRSM - dim_t remainder_loop_count = p_lda - loop_count*4; + k_iter = (m - i - d_mr); - __m128d xmm0; - if(remainder_loop_count != 0) - { - xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 0 + loop_count*4)); - _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + loop_count*4), xmm0); - xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 1 + loop_count*4)); - _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + loop_count*4), xmm0); - xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 2 + loop_count*4)); - _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 2 + loop_count*4), xmm0); - xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 3 + loop_count*4)); - _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 3 + loop_count*4), xmm0); - } - } + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS - ymm4 = _mm256_broadcast_sd((double const *)&ones); - if(!is_unitdiag) - { - if(transa) - { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_sd((double const *)(a11)); - ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a*1 + 1)); - ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a*2 + 2)); - ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a*3 + 3)); - } - else - { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_sd((double const *)(a11)); - ymm1 = _mm256_broadcast_sd((double const *)(a11+ rs_a*1 + 1)); - ymm2 = _mm256_broadcast_sd((double const *)(a11+ rs_a*2 + 2)); - ymm3 = _mm256_broadcast_sd((double const *)(a11+ rs_a*3 + 3)); - } + /* + Peform GEMM between a10 and b01 blocks + For first itteration there will be no GEMM operation + where k_iter are zero + */ + BLIS_STRSM_SMALL_GEMM_16mx6n(a10,b01,cs_b,p_lda,k_iter) - ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); + /* + Load b11 of size 6x16 and multiply with alpha + Add the GEMM output and perform inregister transose of b11 + to peform TRSM operation. 
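			A scalar view of that step, per element of the 6x16 tile
			(illustrative only -- the kernel below does this on whole ymm
			rows with _mm256_fmsub_ps, STRSM_SMALL_DIV_OR_SCALE and
			_mm256_fnmadd_ps; t[][] and gemm_acc[][] are just names for
			this sketch, not real identifiers):
			    t[r][c]  = AlphaVal * b11[r + c*cs_b] - gemm_acc[r][c];
			    t[k][c]  = t[k][c] / A[k][k];   (or * 1/A[k][k] when pre-inverted)
			    t[r][c] -= A[r][k] * t[k][c];   eliminating the solved row k
			                                    from the rows still to be solved.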
+ */ + BLIS_STRSM_SMALL_NREG_TRANSPOSE_6x16(b11,cs_b,AlphaVal) - ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); - #ifdef BLIS_DISABLE_TRSM_PREINVERSION - ymm4 = ymm1; - #endif - #ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm4 = _mm256_div_pd(ymm4, ymm1); - #endif - } - _mm256_storeu_pd((double *)(d11_pack), ymm4); + // TRSM Operation - for(i = (m-d_mr); (i+1) > 0; i -= d_mr) //loop along 'M' direction - { - a01 = D_A_pack; - a11 = L + (n_remainder - 4)*cs_a + (n_remainder - 4)*rs_a; //pointer to block of A to be used for TRSM - b10 = B + i + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (i) + (n_remainder - 4)*cs_b; //pointer to block of B to be used for TRSM + ////extract a00 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 15)); - k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + //perform mul operation + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm1); - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS + //extract a11 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 14)); - ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_4nx8m(a01,b10,cs_b,p_lda,k_iter) + //(ROw1): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*15 + cs_a*14)); + ymm4 = _mm256_fnmadd_ps(ymm0, ymm5, ymm4); - BLIS_PRE_DTRSM_SMALL_4x8(AlphaVal,b11,cs_b) + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*15 + cs_a*13)); + ymm22 = _mm256_fnmadd_ps(ymm0, ymm5, ymm22); - ///implement TRSM/// + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*15 + cs_a*12)); + ymm21 = _mm256_fnmadd_ps(ymm0, ymm5, ymm21); - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*15 + cs_a*11)); + ymm9 = _mm256_fnmadd_ps(ymm0, ymm5, ymm9); - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm0); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*15 + cs_a*10)); + ymm8 = _mm256_fnmadd_ps(ymm0, ymm5, ymm8); - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*15 + cs_a*9)); + ymm15 = _mm256_fnmadd_ps(ymm0, ymm5, ymm15); - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2*rs_a)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*15 + cs_a*8)); + ymm14 = _mm256_fnmadd_ps(ymm0, ymm5, ymm14); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); - ymm8 = _mm256_fnmadd_pd(ymm1, ymm10, ymm8); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*15 + cs_a*7)); + ymm20 = _mm256_fnmadd_ps(ymm0, ymm5, ymm20); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1*rs_a)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*15 + cs_a*6)); + ymm19 = _mm256_fnmadd_ps(ymm0, ymm5, ymm19); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); - ymm6 = _mm256_fnmadd_pd(ymm1, ymm10, ymm6); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*15 + cs_a*5)); + ymm3 = _mm256_fnmadd_ps(ymm0, ymm5, ymm3 ); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*15 + cs_a*4)); + ymm2 = _mm256_fnmadd_ps(ymm0, ymm5, ymm2 ); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm9, ymm3); - ymm4 = _mm256_fnmadd_pd(ymm1, ymm10, ymm4); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*15 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm5, ymm18); - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm0); + ymm0 = 
_mm256_broadcast_ss((float const *)(a11 + rs_a*15 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm5, ymm17); - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*15 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm5, ymm11 ); - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1*rs_a)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*15)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm5, ymm10 ); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); - ymm6 = _mm256_fnmadd_pd(ymm1, ymm8, ymm6); + //perform mul operation + ymm4 = STRSM_SMALL_DIV_OR_SCALE(ymm4, ymm1); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); + //extract a22 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 13)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); - ymm4 = _mm256_fnmadd_pd(ymm1, ymm8, ymm4); + //(ROw2): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*14 + cs_a*13)); + ymm22 = _mm256_fnmadd_ps(ymm0, ymm4, ymm22); - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm0); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*14 + cs_a*12)); + ymm21 = _mm256_fnmadd_ps(ymm0, ymm4, ymm21); - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*14 + cs_a*11)); + ymm9 = _mm256_fnmadd_ps(ymm0, ymm4, ymm9); - //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*14 + cs_a*10)); + ymm8 = _mm256_fnmadd_ps(ymm0, ymm4, ymm8); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); - ymm4 = _mm256_fnmadd_pd(ymm1, ymm6, ymm4); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*14 + cs_a*9)); + ymm15 = _mm256_fnmadd_ps(ymm0, ymm4, ymm15); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*14 + cs_a*8)); + ymm14 = _mm256_fnmadd_ps(ymm0, ymm4, ymm14); - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + 4), ymm4); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b + 4), ymm6); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*2 + 4), ymm8); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - _mm256_storeu_pd((double *)(b11 + cs_b*3 + 4), ymm10); - } + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*14 + cs_a*7)); + ymm20 = _mm256_fnmadd_ps(ymm0, ymm4, ymm20); - dim_t m_remainder = i + d_mr; - if(m_remainder >= 4) - { - a01 = D_A_pack; - a11 = L + (n_remainder - 4)*cs_a + (n_remainder - 4)*rs_a; //pointer to block of A to be used for TRSM - b10 = B + (m_remainder - 4) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (m_remainder - 4) + (n_remainder - 4)*cs_b; //pointer to block of B to be used for TRSM + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*14 + cs_a*6)); + ymm19 = _mm256_fnmadd_ps(ymm0, ymm4, ymm19); - k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*14 + cs_a*5)); + ymm3 = _mm256_fnmadd_ps(ymm0, ymm4, ymm3 ); - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*14 + cs_a*4)); + ymm2 = _mm256_fnmadd_ps(ymm0, ymm4, ymm2 ); - ///GEMM implementation starts/// - 
BLIS_DTRSM_SMALL_GEMM_4nx4m(a01,b10,cs_b,p_lda,k_iter) + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*14 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm4, ymm18); - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*14 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm4, ymm17); - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*14 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm4, ymm11 ); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*14)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm4, ymm10 ); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 + //perform mul operation + ymm22 = STRSM_SMALL_DIV_OR_SCALE(ymm22, ymm1); + + //extract a33 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 12)); + + //(ROw3): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*13 + cs_a*12)); + ymm21 = _mm256_fnmadd_ps(ymm0, ymm22, ymm21); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*13 + cs_a*11)); + ymm9 = _mm256_fnmadd_ps(ymm0, ymm22, ymm9); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*13 + cs_a*10)); + ymm8 = _mm256_fnmadd_ps(ymm0, ymm22, ymm8); - ///implement TRSM/// + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*13 + cs_a*9)); + ymm15 = _mm256_fnmadd_ps(ymm0, ymm22, ymm15); - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*13 + cs_a*8)); + ymm14 = _mm256_fnmadd_ps(ymm0, ymm22, ymm14); - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*13 + cs_a*7)); + ymm20 = _mm256_fnmadd_ps(ymm0, ymm22, ymm20); - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2*rs_a)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*13 + cs_a*6)); + ymm19 = _mm256_fnmadd_ps(ymm0, ymm22, ymm19); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1*rs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*13 + cs_a*5)); + ymm3 = _mm256_fnmadd_ps(ymm0, ymm22, ymm3 ); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm9, ymm3); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*13 + cs_a*4)); + ymm2 = _mm256_fnmadd_ps(ymm0, ymm22, ymm2 ); - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*13 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm22, ymm18); - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*13 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm22, ymm17); - //(row 2):FMA operations - ymm1 = 
_mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1*rs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*13 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm22, ymm11 ); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*13)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm22, ymm10 ); - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + //perform mul operation + ymm21 = STRSM_SMALL_DIV_OR_SCALE(ymm21, ymm1); - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); + //extract a44 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 11)); + //(ROw4): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*12 + cs_a*11)); + ymm9 = _mm256_fnmadd_ps(ymm0, ymm21, ymm9); - //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*12 + cs_a*10)); + ymm8 = _mm256_fnmadd_ps(ymm0, ymm21, ymm8); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*12 + cs_a*9)); + ymm15 = _mm256_fnmadd_ps(ymm0, ymm21, ymm15); - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*12 + cs_a*8)); + ymm14 = _mm256_fnmadd_ps(ymm0, ymm21, ymm14); - m_remainder -=4; - } + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*12 + cs_a*7)); + ymm20 = _mm256_fnmadd_ps(ymm0, ymm21, ymm20); - if(m_remainder) - { - if(3 == m_remainder) - { - a01 = D_A_pack; - a11 = L + (n_remainder - 4)*cs_a + (n_remainder - 4)*rs_a; //pointer to block of A to be used for TRSM - b10 = B + (m_remainder - 3) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (m_remainder - 3) + (n_remainder - 4)*cs_b; //pointer to block of B to be used for TRSM + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*12 + cs_a*6)); + ymm19 = _mm256_fnmadd_ps(ymm0, ymm21, ymm19); - k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*12 + cs_a*5)); + ymm3 = _mm256_fnmadd_ps(ymm0, ymm21, ymm3 ); - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*12 + cs_a*4)); + ymm2 = _mm256_fnmadd_ps(ymm0, ymm21, ymm2 ); - ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_4nx4m(a01,b10,cs_b,p_lda,k_iter) + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*12 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm21, ymm18); - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*12 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm21, ymm17); - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*12 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm21, ymm11 ); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + ymm0 = _mm256_broadcast_ss((float const *)(a11 + 
rs_a*12)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm21, ymm10 ); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 + //perform mul operation + ymm9 = STRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3 + 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 + //extract a55 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 10)); - ///implement TRSM/// + //(ROw5): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*11 + cs_a*10)); + ymm8 = _mm256_fnmadd_ps(ymm0, ymm9, ymm8); - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*11 + cs_a*9)); + ymm15 = _mm256_fnmadd_ps(ymm0, ymm9, ymm15); - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*11 + cs_a*8)); + ymm14 = _mm256_fnmadd_ps(ymm0, ymm9, ymm14); - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2*rs_a)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*11 + cs_a*7)); + ymm20 = _mm256_fnmadd_ps(ymm0, ymm9, ymm20); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1*rs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*11 + cs_a*6)); + ymm19 = _mm256_fnmadd_ps(ymm0, ymm9, ymm19); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm9, ymm3); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*11 + cs_a*5)); + ymm3 = _mm256_fnmadd_ps(ymm0, ymm9, ymm3 ); - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*11 + cs_a*4)); + ymm2 = _mm256_fnmadd_ps(ymm0, ymm9, ymm2 ); - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*11 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm9, ymm18); - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1*rs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*11 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm9, ymm17); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*11 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm9, ymm11 ); - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*11)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm9, ymm10 ); - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); + //perform mul operation + ymm8 = STRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); - //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); + //extract a66 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 9)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + //(ROw6): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*10 + cs_a*9)); + ymm15 = _mm256_fnmadd_ps(ymm0, ymm8, ymm15); - ymm0 = _mm256_loadu_pd((double const 
*)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x07); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3 + 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x07); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*10 + cs_a*8)); + ymm14 = _mm256_fnmadd_ps(ymm0, ymm8, ymm14); - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - xmm5 = _mm256_extractf128_pd(ymm9, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 3),xmm5); - _mm_storel_pd((b11 + cs_b * 3 + 2), _mm256_extractf128_pd(ymm9, 1)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*10 + cs_a*7)); + ymm20 = _mm256_fnmadd_ps(ymm0, ymm8, ymm20); - m_remainder -=3; - } - else if(2 == m_remainder) - { - a01 = D_A_pack; - a11 = L + (n_remainder - 4)*cs_a + (n_remainder - 4)*rs_a; //pointer to block of A to be used for TRSM - b10 = B + (m_remainder - 2) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (m_remainder - 2) + (n_remainder - 4)*cs_b; //pointer to block of B to be used for TRSM + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*10 + cs_a*6)); + ymm19 = _mm256_fnmadd_ps(ymm0, ymm8, ymm19); - k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*10 + cs_a*5)); + ymm3 = _mm256_fnmadd_ps(ymm0, ymm8, ymm3 ); - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*10 + cs_a*4)); + ymm2 = _mm256_fnmadd_ps(ymm0, ymm8, ymm2 ); - ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_4nx4m(a01,b10,cs_b,p_lda,k_iter) + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*10 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm8, ymm18); - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*10 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm8, ymm17); - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*10 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm8, ymm11 ); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*10)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm8, ymm10 ); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 + //perform mul operation + ymm15 = STRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 + //extract a77 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 8)); - ///implement TRSM/// + //(ROw7): FMA 
operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*9 + cs_a*8)); + ymm14 = _mm256_fnmadd_ps(ymm0, ymm15, ymm14); - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*9 + cs_a*7)); + ymm20 = _mm256_fnmadd_ps(ymm0, ymm15, ymm20); - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*9 + cs_a*6)); + ymm19 = _mm256_fnmadd_ps(ymm0, ymm15, ymm19); - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2*rs_a)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*9 + cs_a*5)); + ymm3 = _mm256_fnmadd_ps(ymm0, ymm15, ymm3 ); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1*rs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*9 + cs_a*4)); + ymm2 = _mm256_fnmadd_ps(ymm0, ymm15, ymm2 ); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm9, ymm3); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*9 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm15, ymm18); - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*9 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm15, ymm17); - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*9 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm15, ymm11 ); - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1*rs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*9)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm15, ymm10 ); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); + //perform mul operation + ymm14 = STRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + //extract a88 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 7)); - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); + //(ROw8): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*8 + cs_a*7)); + ymm20 = _mm256_fnmadd_ps(ymm0, ymm14, ymm20); - //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*8 + cs_a*6)); + ymm19 = _mm256_fnmadd_ps(ymm0, ymm14, ymm19); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*8 + cs_a*5)); + ymm3 = _mm256_fnmadd_ps(ymm0, ymm14, ymm3 ); - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x03); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x03); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*8 + cs_a*4)); + ymm2 = _mm256_fnmadd_ps(ymm0, ymm14, ymm2 ); - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double 
*)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - xmm5 = _mm256_extractf128_pd(ymm9, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 3),xmm5); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*8 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm14, ymm18); - m_remainder -=2; - } - else if (1 == m_remainder) - { - a01 = D_A_pack; - a11 = L + (n_remainder - 4)*cs_a + (n_remainder - 4)*rs_a; //pointer to block of A to be used for TRSM - b10 = B + (m_remainder - 1) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (m_remainder - 1) + (n_remainder - 4)*cs_b; //pointer to block of B to be used for TRSM + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*8 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm14, ymm17); - k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*8 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm14, ymm11 ); - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*8)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm14, ymm10 ); - ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_4nx4m(a01,b10,cs_b,p_lda,k_iter) + //perform mul operation + ymm20 = STRSM_SMALL_DIV_OR_SCALE(ymm20, ymm1); - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + //extract a99 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 6)); - ymm0 = _mm256_broadcast_sd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + //(ROw9): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*7 + cs_a*6)); + ymm19 = _mm256_fnmadd_ps(ymm0, ymm20, ymm19); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*7 + cs_a*5)); + ymm3 = _mm256_fnmadd_ps(ymm0, ymm20, ymm3 ); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*7 + cs_a*4)); + ymm2 = _mm256_fnmadd_ps(ymm0, ymm20, ymm2 ); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*7 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm20, ymm18); - ///implement TRSM/// + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*7 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm20, ymm17); - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*7 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm20, ymm11 ); - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*7)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm20, ymm10 ); - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2*rs_a)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); + //perform mul operation + ymm19 = STRSM_SMALL_DIV_OR_SCALE(ymm19, ymm1); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 
3*cs_a + 1*rs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); + //extract a10 10 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 5)); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm9, ymm3); + //(ROw10): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*6 + cs_a*5)); + ymm3 = _mm256_fnmadd_ps(ymm0, ymm19, ymm3 ); - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*6 + cs_a*4)); + ymm2 = _mm256_fnmadd_ps(ymm0, ymm19, ymm2 ); - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*6 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm19, ymm18); - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1*rs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*6 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm19, ymm17); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*6 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm19, ymm11 ); - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*6)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm19, ymm10 ); - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); + //perform mul operation + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm1); - //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); + //extract a11 11 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 4)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_broadcast_sd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x01); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x01); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x01); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x01); + //(ROw11): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*5 + cs_a*4)); + ymm2 = _mm256_fnmadd_ps(ymm0, ymm3, ymm2 ); - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm3, 0)); - _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm5, 0)); - _mm_storel_pd((b11 + cs_b * 2), _mm256_extractf128_pd(ymm7, 0)); - _mm_storel_pd((b11 + cs_b * 3), _mm256_extractf128_pd(ymm9, 0)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*5 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm3, ymm18); - m_remainder -=1; - } - } - n_remainder -= 4; - } + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*5 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm3, ymm17); - if(n_remainder == 3) - { - a01 = L + (n_remainder - 3)*rs_a + n_remainder*cs_a; //pointer to block of A to be used in GEMM - a11 = L + (n_remainder - 3)*cs_a + (n_remainder - 3)*rs_a; //pointer to block of A to be used for TRSM + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*5 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm3, ymm11 ); - double *ptr_a10_dup = D_A_pack; + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*5)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm3, ymm10 ); - dim_t p_lda = 
(n-n_remainder); // packed leading dimension - // perform copy of A to packed buffer D_A_pack + //perform mul operation + ymm2 = STRSM_SMALL_DIV_OR_SCALE(ymm2, ymm1); - if(transa) - { - for(dim_t x =0;x < p_lda;x+=d_nr) - { - ymm0 = _mm256_loadu_pd((double const *)(a01)); - ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); - ymm2 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2)); - ymm3 = _mm256_loadu_pd((double const *)(a01 + cs_a * 3)); + //extract a12 12 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 3)); - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + //(ROw12): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*4 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm2, ymm18); - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*4 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm2, ymm17); - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*4 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm2, ymm11 ); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*4)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm2, ymm10 ); - _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + //perform mul operation + ymm18 = STRSM_SMALL_DIV_OR_SCALE(ymm18, ymm1); - ymm0 = _mm256_loadu_pd((double const *)(a01 + cs_a * 4)); - ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a * 5)); + //extract a13 13 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_broadcast_sd((double const *)&zero); + //(ROw13): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*3 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm18, ymm17); - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*3 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm18, ymm11 ); - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_broadcast_sd((double const *)&zero); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*3)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm18, ymm10 ); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + //perform mul operation + ymm17 = STRSM_SMALL_DIV_OR_SCALE(ymm17, ymm1); - _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_extractf128_pd(ymm6,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_extractf128_pd(ymm7,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm8,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm9,0)); + //extract a14 14 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); - a01 += d_nr*cs_a; - ptr_a10_dup += d_nr; - } - } - else - { - dim_t loop_count = (n-n_remainder)/4; + //(ROw13): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*2 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm17, ymm11 ); - for(dim_t x =0;x < loop_count;x++) - { - ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 0 + x*4)); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + x*4), 
ymm15); - ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 1 + x*4)); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + x*4), ymm15); - ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 2 + x*4)); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 2 + x*4), ymm15); - } + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*2)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm17, ymm10 ); - dim_t remainder_loop_count = p_lda - loop_count*4; + //perform mul operation + ymm11 = STRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - __m128d xmm0; - if(remainder_loop_count != 0) - { - xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 0 + loop_count*4)); - _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + loop_count*4), xmm0); - xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 1 + loop_count*4)); - _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + loop_count*4), xmm0); - xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 2 + loop_count*4)); - _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 2 + loop_count*4), xmm0); - } - } + //extract a15 15 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 0)); - ymm4 = _mm256_broadcast_sd((double const *)&ones); - if(!is_unitdiag) - { - if(transa) - { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_sd((double const *)(a11)); - ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a*1 + 1)); - ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a*2 + 2)); - ymm3 = _mm256_broadcast_sd((double const *)&ones); - } - else - { - ymm0 = _mm256_broadcast_sd((double const *)(a11)); - ymm1 = _mm256_broadcast_sd((double const *)(a11+ rs_a*1 + 1)); - ymm2 = _mm256_broadcast_sd((double const *)(a11+ rs_a*2 + 2)); - ymm3 = _mm256_broadcast_sd((double const *)&ones); - } + //(ROw15): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*1)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm11, ymm10 ); - ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); + //perform mul operation + ymm10 = STRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); - ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); - #ifdef BLIS_DISABLE_TRSM_PREINVERSION - ymm4 = ymm1; - #endif - #ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm4 = _mm256_div_pd(ymm4, ymm1); - #endif + BLIS_STRSM_SMALL_NREG_TRANSPOSE_16x6_AND_STORE(b11,cs_b) } - _mm256_storeu_pd((double *)(d11_pack), ymm4); - for(i = (m-d_mr); (i+1) > 0; i -= d_mr) //loop along 'M' direction + dim_t n_remainder = j + d_nr; + if(n_remainder >= 4) { - a01 = D_A_pack; - a11 = L + (n_remainder - 3)*cs_a + (n_remainder - 3)*rs_a; //pointer to block of A to be used for TRSM - b10 = B + i + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (i) + (n_remainder - 3)*cs_b; //pointer to block of B to be used for TRSM + a10 = D_A_pack; + a11 = L + (i*cs_a) + (i*rs_a); + b01 = B + ((n_remainder - 4)* cs_b) + i + d_mr; + b11 = B + ((n_remainder - 4)* cs_b) + i; - k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + k_iter = (m - i - d_mr); /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS + BLIS_SET_S_YMM_REG_ZEROS - ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_3nx8m(a01,b10,cs_b,p_lda,k_iter) + ///GEMM code begins/// + BLIS_STRSM_SMALL_GEMM_16mx4n(a10,b01,cs_b,p_lda,k_iter) - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); + ymm16 = _mm256_broadcast_ss((float const *)(&AlphaVal)); - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + 4)); //B11[4][0] 
B11[5][0] B11[6][0] B11[7][0] + ymm17 = _mm256_loadu_ps((float const *)(b11)); + ymm18 = _mm256_loadu_ps((float const *)(b11 + cs_b)); + ymm19 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); + ymm20 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - ymm4 = _mm256_fmsub_pd(ymm1, ymm15, ymm4); //B11[4-7][0] * alpha-= ymm1 + ymm17 = _mm256_fmsub_ps(ymm17, ymm16, ymm8); + ymm18 = _mm256_fmsub_ps(ymm18, ymm16, ymm9); + ymm19 = _mm256_fmsub_ps(ymm19, ymm16, ymm10); + ymm20 = _mm256_fmsub_ps(ymm20, ymm16, ymm11); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b + 4)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] + ymm8 = _mm256_unpacklo_ps(ymm17, ymm18); + ymm9 = _mm256_unpacklo_ps(ymm19, ymm20); - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - ymm6 = _mm256_fmsub_pd(ymm1, ymm15, ymm6); //B11[4-7][1] * alpha -= ymm3 + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b01000100); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*2 + 4)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] + ymm10 = _mm256_permute2f128_ps(ymm4,ymm4,0x20);//1 + ymm2 = _mm256_permute2f128_ps(ymm4,ymm4,0x31);//5 - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - ymm8 = _mm256_fmsub_pd(ymm1, ymm15, ymm8); //B11[4-7][2] * alpha -= ymm5 + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b11101110); - ///implement TRSM/// + ymm11 = _mm256_permute2f128_ps(ymm4,ymm4,0x20);//2 + ymm3 = _mm256_permute2f128_ps(ymm4,ymm4,0x31);//6 - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + ymm8 = _mm256_unpackhi_ps(ymm17, ymm18); + ymm9 = _mm256_unpackhi_ps(ymm19, ymm20); - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm0); + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b01000100); - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + ymm17 = _mm256_permute2f128_ps(ymm4,ymm4,0x20);//3 + ymm19 = _mm256_permute2f128_ps(ymm4,ymm4,0x31);//7 - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1*rs_a)); + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b11101110); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); - ymm6 = _mm256_fnmadd_pd(ymm1, ymm8, ymm6); + ymm18 = _mm256_permute2f128_ps(ymm4,ymm4,0x20);//4 + ymm20 = _mm256_permute2f128_ps(ymm4,ymm4,0x31);//8 - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); + ymm16 = _mm256_broadcast_ss((float const *)(&AlphaVal)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); - ymm4 = _mm256_fnmadd_pd(ymm1, ymm8, ymm4); + ymm8 = _mm256_loadu_ps((float const *)(b11 + 8)); + ymm9 = _mm256_loadu_ps((float const *)(b11 + cs_b + 8)); + ymm4 = _mm256_loadu_ps((float const *)(b11 + cs_b*2 + 8)); + ymm5 = _mm256_loadu_ps((float const *)(b11 + cs_b*3 + 8)); - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm0); + ymm8 = _mm256_fmsub_ps(ymm8, ymm16, ymm12); + ymm9 = _mm256_fmsub_ps(ymm9, ymm16, ymm13); + ymm4 = _mm256_fmsub_ps(ymm4, ymm16, ymm14); + ymm5 = _mm256_fmsub_ps(ymm5, ymm16, ymm15); - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); + ymm12 = _mm256_unpacklo_ps(ymm8, ymm9); + ymm13 = _mm256_unpacklo_ps(ymm4, ymm5); - //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); + ymm6 = 
_mm256_shuffle_ps(ymm12,ymm13,0b01000100); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); - ymm4 = _mm256_fnmadd_pd(ymm1, ymm6, ymm4); + ymm14 = _mm256_permute2f128_ps(ymm6,ymm6,0x20);//1 + ymm21 = _mm256_permute2f128_ps(ymm6,ymm6,0x31);//5 - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); + ymm6 = _mm256_shuffle_ps(ymm12,ymm13,0b11101110); - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + 4), ymm4); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b + 4), ymm6); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*2 + 4), ymm8); - } + ymm15 = _mm256_permute2f128_ps(ymm6,ymm6,0x20);//2 + ymm22 = _mm256_permute2f128_ps(ymm6,ymm6,0x31);//6 - dim_t m_remainder = i + d_mr; - if(m_remainder >= 4) - { - a01 = D_A_pack; - a11 = L + (n_remainder - 3)*cs_a + (n_remainder - 3)*rs_a; //pointer to block of A to be used for TRSM - b10 = B + (m_remainder - 4) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (m_remainder - 4) + (n_remainder - 3)*cs_b; //pointer to block of B to be used for TRSM + ymm12 = _mm256_unpackhi_ps(ymm8, ymm9); + ymm13 = _mm256_unpackhi_ps(ymm4, ymm5); - k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + ymm6 = _mm256_shuffle_ps(ymm12,ymm13,0b01000100); - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS + ymm8 = _mm256_permute2f128_ps(ymm6,ymm6,0x20);//3 + ymm4 = _mm256_permute2f128_ps(ymm6,ymm6,0x31);//7 - ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) + ymm6 = _mm256_shuffle_ps(ymm12,ymm13,0b11101110); - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + ymm9 = _mm256_permute2f128_ps(ymm6,ymm6,0x20);//4 + ymm5 = _mm256_permute2f128_ps(ymm6,ymm6,0x31);//8 - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + // TRSM Operation - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + ////extract a00 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 15)); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 + //perform mul operation + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm1); - ///implement TRSM/// + //extract a11 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 14)); - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + //(ROw1): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*15 + cs_a*14)); + ymm4 = _mm256_fnmadd_ps(ymm0, ymm5, ymm4); - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*15 + cs_a*13)); + ymm22 = _mm256_fnmadd_ps(ymm0, ymm5, ymm22); - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1*rs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*15 + cs_a*12)); + ymm21 = _mm256_fnmadd_ps(ymm0, ymm5, ymm21); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); + ymm0 = 
_mm256_broadcast_ss((float const *)(a11 + rs_a*15 + cs_a*11)); + ymm9 = _mm256_fnmadd_ps(ymm0, ymm5, ymm9); - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*15 + cs_a*10)); + ymm8 = _mm256_fnmadd_ps(ymm0, ymm5, ymm8); - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*15 + cs_a*9)); + ymm15 = _mm256_fnmadd_ps(ymm0, ymm5, ymm15); - //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*15 + cs_a*8)); + ymm14 = _mm256_fnmadd_ps(ymm0, ymm5, ymm14); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*15 + cs_a*7)); + ymm20 = _mm256_fnmadd_ps(ymm0, ymm5, ymm20); - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*15 + cs_a*6)); + ymm19 = _mm256_fnmadd_ps(ymm0, ymm5, ymm19); - m_remainder -=4; - } + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*15 + cs_a*5)); + ymm3 = _mm256_fnmadd_ps(ymm0, ymm5, ymm3 ); - if(m_remainder) - { - if(3 == m_remainder) - { - a01 = D_A_pack; - a11 = L + (n_remainder - 3)*cs_a + (n_remainder - 3)*rs_a; //pointer to block of A to be used for TRSM - b10 = B + (m_remainder - 3) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (m_remainder - 3) + (n_remainder - 3)*cs_b; //pointer to block of B to be used for TRSM + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*15 + cs_a*4)); + ymm2 = _mm256_fnmadd_ps(ymm0, ymm5, ymm2 ); - k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*15 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm5, ymm18); - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*15 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm5, ymm17); - ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*15 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm5, ymm11 ); - BLIS_PRE_DTRSM_SMALL_3N_3M(AlphaVal,b11,cs_b) + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*15)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm5, ymm10 ); - ///implement TRSM/// - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + //perform mul operation + ymm4 = STRSM_SMALL_DIV_OR_SCALE(ymm4, ymm1); - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + //extract a22 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 13)); - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1*rs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); + //(ROw2): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*14 + cs_a*13)); + ymm22 = _mm256_fnmadd_ps(ymm0, ymm4, ymm22); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*14 + cs_a*12)); + ymm21 = _mm256_fnmadd_ps(ymm0, ymm4, ymm21); - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*14 + 
cs_a*11)); + ymm9 = _mm256_fnmadd_ps(ymm0, ymm4, ymm9); - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*14 + cs_a*10)); + ymm8 = _mm256_fnmadd_ps(ymm0, ymm4, ymm8); - //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*14 + cs_a*9)); + ymm15 = _mm256_fnmadd_ps(ymm0, ymm4, ymm15); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*14 + cs_a*8)); + ymm14 = _mm256_fnmadd_ps(ymm0, ymm4, ymm14); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*14 + cs_a*7)); + ymm20 = _mm256_fnmadd_ps(ymm0, ymm4, ymm20); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*14 + cs_a*6)); + ymm19 = _mm256_fnmadd_ps(ymm0, ymm4, ymm19); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*14 + cs_a*5)); + ymm3 = _mm256_fnmadd_ps(ymm0, ymm4, ymm3 ); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*14 + cs_a*4)); + ymm2 = _mm256_fnmadd_ps(ymm0, ymm4, ymm2 ); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*14 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm4, ymm18); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*14 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm4, ymm17); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*14 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm4, ymm11 ); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*14)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm4, ymm10 ); + + //perform mul operation + ymm22 = STRSM_SMALL_DIV_OR_SCALE(ymm22, ymm1); + + //extract a33 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 12)); + + //(ROw3): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*13 + cs_a*12)); + ymm21 = _mm256_fnmadd_ps(ymm0, ymm22, ymm21); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*13 + cs_a*11)); + ymm9 = _mm256_fnmadd_ps(ymm0, ymm22, ymm9); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*13 + cs_a*10)); + ymm8 = _mm256_fnmadd_ps(ymm0, ymm22, ymm8); - BLIS_POST_DTRSM_SMALL_3N_3M(b11,cs_b) + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*13 + cs_a*9)); + ymm15 = _mm256_fnmadd_ps(ymm0, ymm22, ymm15); - m_remainder -=3; - } - else if(2 == m_remainder) - { - a01 = D_A_pack; - a11 = L + (n_remainder - 3)*cs_a + (n_remainder - 3)*rs_a; //pointer to block of A to be used for TRSM - b10 = B + (m_remainder - 2) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (m_remainder - 2) + (n_remainder - 3)*cs_b; //pointer to block of B to be used for TRSM + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*13 + cs_a*8)); + ymm14 = _mm256_fnmadd_ps(ymm0, ymm22, ymm14); - k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*13 + cs_a*7)); + ymm20 = _mm256_fnmadd_ps(ymm0, ymm22, ymm20); - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*13 + cs_a*6)); + ymm19 = _mm256_fnmadd_ps(ymm0, ymm22, ymm19); - ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*13 + cs_a*5)); + ymm3 = _mm256_fnmadd_ps(ymm0, ymm22, ymm3 ); - BLIS_PRE_DTRSM_SMALL_3N_2M(AlphaVal,b11,cs_b) + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*13 + cs_a*4)); + 
ymm2 = _mm256_fnmadd_ps(ymm0, ymm22, ymm2 ); - ///implement TRSM/// + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*13 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm22, ymm18); - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*13 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm22, ymm17); - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*13 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm22, ymm11 ); - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1*rs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*13)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm22, ymm10 ); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); + //perform mul operation + ymm21 = STRSM_SMALL_DIV_OR_SCALE(ymm21, ymm1); - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + //extract a44 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 11)); + //(ROw4): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*12 + cs_a*11)); + ymm9 = _mm256_fnmadd_ps(ymm0, ymm21, ymm9); - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*12 + cs_a*10)); + ymm8 = _mm256_fnmadd_ps(ymm0, ymm21, ymm8); - //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*12 + cs_a*9)); + ymm15 = _mm256_fnmadd_ps(ymm0, ymm21, ymm15); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*12 + cs_a*8)); + ymm14 = _mm256_fnmadd_ps(ymm0, ymm21, ymm14); - BLIS_POST_DTRSM_SMALL_3N_2M(b11,cs_b) + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*12 + cs_a*7)); + ymm20 = _mm256_fnmadd_ps(ymm0, ymm21, ymm20); - m_remainder -=2; - } - else if (1 == m_remainder) - { - a01 = D_A_pack; - a11 = L + (n_remainder - 3)*cs_a + (n_remainder - 3)*rs_a; //pointer to block of A to be used for TRSM - b10 = B + (m_remainder - 1) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (m_remainder - 1) + (n_remainder - 3)*cs_b; //pointer to block of B to be used for TRSM + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*12 + cs_a*6)); + ymm19 = _mm256_fnmadd_ps(ymm0, ymm21, ymm19); - k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*12 + cs_a*5)); + ymm3 = _mm256_fnmadd_ps(ymm0, ymm21, ymm3 ); - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*12 + cs_a*4)); + ymm2 = _mm256_fnmadd_ps(ymm0, ymm21, ymm2 ); - ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*12 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm21, ymm18); - BLIS_PRE_DTRSM_SMALL_3N_1M(AlphaVal,b11,cs_b) + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*12 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm21, ymm17); - ///implement TRSM/// + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*12 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm21, ymm11 ); - //extract a22 - ymm0 = 
_mm256_broadcast_sd((double const *)(d11_pack + 2)); - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*12)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm21, ymm10 ); - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + //perform mul operation + ymm9 = STRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1*rs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); + //extract a55 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 10)); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); + //(ROw5): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*11 + cs_a*10)); + ymm8 = _mm256_fnmadd_ps(ymm0, ymm9, ymm8); - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*11 + cs_a*9)); + ymm15 = _mm256_fnmadd_ps(ymm0, ymm9, ymm15); - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*11 + cs_a*8)); + ymm14 = _mm256_fnmadd_ps(ymm0, ymm9, ymm14); - //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*11 + cs_a*7)); + ymm20 = _mm256_fnmadd_ps(ymm0, ymm9, ymm20); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*11 + cs_a*6)); + ymm19 = _mm256_fnmadd_ps(ymm0, ymm9, ymm19); - BLIS_POST_DTRSM_SMALL_3N_1M(b11,cs_b) + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*11 + cs_a*5)); + ymm3 = _mm256_fnmadd_ps(ymm0, ymm9, ymm3 ); - m_remainder -=1; - } - } - n_remainder -= 3; - } - else if(n_remainder == 2) - { - a01 = L + (n_remainder - 2)*rs_a + n_remainder*cs_a; //pointer to block of A to be used in GEMM - a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2)*rs_a; //pointer to block of A to be used for TRSM + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*11 + cs_a*4)); + ymm2 = _mm256_fnmadd_ps(ymm0, ymm9, ymm2 ); - double *ptr_a10_dup = D_A_pack; + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*11 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm9, ymm18); - dim_t p_lda = (n-n_remainder); // packed leading dimension - // perform copy of A to packed buffer D_A_pack + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*11 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm9, ymm17); - if(transa) - { - for(dim_t x =0;x < p_lda;x+=d_nr) - { - ymm0 = _mm256_loadu_pd((double const *)(a01)); - ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); - ymm2 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2)); - ymm3 = _mm256_loadu_pd((double const *)(a01 + cs_a * 3)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*11 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm9, ymm11 ); - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*11)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm9, ymm10 ); - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + //perform mul operation + ymm8 = STRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + //extract a66 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 9)); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = 
_mm256_permute2f128_pd(ymm0,ymm1,0x31); + //(ROw6): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*10 + cs_a*9)); + ymm15 = _mm256_fnmadd_ps(ymm0, ymm8, ymm15); - _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*10 + cs_a*8)); + ymm14 = _mm256_fnmadd_ps(ymm0, ymm8, ymm14); - ymm0 = _mm256_loadu_pd((double const *)(a01 + cs_a * 4)); - ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a * 5)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*10 + cs_a*7)); + ymm20 = _mm256_fnmadd_ps(ymm0, ymm8, ymm20); - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_broadcast_sd((double const *)&zero); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*10 + cs_a*6)); + ymm19 = _mm256_fnmadd_ps(ymm0, ymm8, ymm19); - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*10 + cs_a*5)); + ymm3 = _mm256_fnmadd_ps(ymm0, ymm8, ymm3 ); - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_broadcast_sd((double const *)&zero); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*10 + cs_a*4)); + ymm2 = _mm256_fnmadd_ps(ymm0, ymm8, ymm2 ); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*10 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm8, ymm18); - _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_extractf128_pd(ymm6,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_extractf128_pd(ymm7,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm8,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm9,0)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*10 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm8, ymm17); - a01 += d_nr*cs_a; - ptr_a10_dup += d_nr; - } - } - else - { - dim_t loop_count = (n-n_remainder)/4; + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*10 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm8, ymm11 ); - for(dim_t x =0;x < loop_count;x++) - { - ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 0 + x*4)); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + x*4), ymm15); - ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 1 + x*4)); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + x*4), ymm15); - } + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*10)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm8, ymm10 ); - dim_t remainder_loop_count = p_lda - loop_count*4; + //perform mul operation + ymm15 = STRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); - __m128d xmm0; - if(remainder_loop_count != 0) - { - xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 0 + loop_count*4)); - _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + loop_count*4), xmm0); - xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 1 + loop_count*4)); - _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + loop_count*4), xmm0); - } - } + //extract a77 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 8)); - ymm4 = _mm256_broadcast_sd((double const *)&ones); - if(!is_unitdiag) - { - if(transa) - { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_sd((double const *)(a11)); - ymm1 = _mm256_broadcast_sd((double const *)(a11+cs_a*1 + 1)); - } - else - { 
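As a rough scalar sketch of what the d11_pack setup in this hunk amounts to: the buffer starts out as ones, and only for a non-unit diagonal are the diagonal entries of A11 gathered; with BLIS_ENABLE_TRSM_PREINVERSION defined they are inverted once up front, so the later *_SMALL_DIV_OR_SCALE step becomes a multiply rather than a divide. The sketch below is illustrative only and is not part of the patch; pack_diag_sketch, div_or_scale_sketch, lda and k are assumed names, not identifiers from the kernel.

#include <stddef.h>

/* Gather (and optionally pre-invert) the diagonal of a k x k block of A11. */
static void pack_diag_sketch(const double *a11, size_t lda, size_t k,
                             int is_unitdiag, double *d11_pack)
{
    for (size_t i = 0; i < k; i++)
    {
        double d = is_unitdiag ? 1.0 : a11[i + i * lda];
#ifdef BLIS_ENABLE_TRSM_PREINVERSION
        d11_pack[i] = 1.0 / d;   /* store the reciprocal: the solve multiplies */
#else
        d11_pack[i] = d;         /* store the diagonal: the solve divides */
#endif
    }
}

/* Scalar analogue of DTRSM_SMALL_DIV_OR_SCALE(x, d11_pack[i]). */
static double div_or_scale_sketch(double x, double packed)
{
#ifdef BLIS_ENABLE_TRSM_PREINVERSION
    return x * packed;
#else
    return x / packed;
#endif
}

Packing reciprocals up front trades one reciprocal computation per diagonal block for turning every per-column divide in the solve loop into a multiply.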
- //broadcast diagonal elements of A11
- ymm0 = _mm256_broadcast_sd((double const *)(a11));
- ymm1 = _mm256_broadcast_sd((double const *)(a11+rs_a*1 + 1));
- }
- ymm2 = _mm256_broadcast_sd((double const *)&ones);
- ymm3 = _mm256_broadcast_sd((double const *)&ones);
+ //(ROw7): FMA operations
+ ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*9 + cs_a*8));
+ ymm14 = _mm256_fnmadd_ps(ymm0, ymm15, ymm14);
- ymm0 = _mm256_unpacklo_pd(ymm0, ymm1);
- ymm1 = _mm256_unpacklo_pd(ymm2, ymm3);
+ ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*9 + cs_a*7));
+ ymm20 = _mm256_fnmadd_ps(ymm0, ymm15, ymm20);
- ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C);
- #ifdef BLIS_DISABLE_TRSM_PREINVERSION
- ymm4 = ymm1;
- #endif
- #ifdef BLIS_ENABLE_TRSM_PREINVERSION
- ymm4 = _mm256_div_pd(ymm4, ymm1);
- #endif
- }
- _mm256_storeu_pd((double *)(d11_pack), ymm4);
+ ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*9 + cs_a*6));
+ ymm19 = _mm256_fnmadd_ps(ymm0, ymm15, ymm19);
- for(i = (m-d_mr); (i+1) > 0; i -= d_mr) //loop along 'M' direction
- {
- a01 = D_A_pack;
- a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2)*rs_a; //pointer to block of A to be used for TRSM
- b10 = B + i + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM
- b11 = B + (i) + (n_remainder - 2)*cs_b; //pointer to block of B to be used for TRSM
+ ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*9 + cs_a*5));
+ ymm3 = _mm256_fnmadd_ps(ymm0, ymm15, ymm3 );
- k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4)
+ ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*9 + cs_a*4));
+ ymm2 = _mm256_fnmadd_ps(ymm0, ymm15, ymm2 );
- /*Fill zeros into ymm registers used in gemm accumulations */
- BLIS_SET_YMM_REG_ZEROS
+ ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*9 + cs_a*3));
+ ymm18 = _mm256_fnmadd_ps(ymm0, ymm15, ymm18);
- ///GEMM implementation starts///
- BLIS_DTRSM_SMALL_GEMM_2nx8m(a01,b10,cs_b,p_lda,k_iter)
+ ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*9 + cs_a*2));
+ ymm17 = _mm256_fnmadd_ps(ymm0, ymm15, ymm17);
- ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal);
+ ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*9 + cs_a*1));
+ ymm11 = _mm256_fnmadd_ps(ymm0, ymm15, ymm11 );
- ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
- ymm1 = _mm256_loadu_pd((double const *)(b11 + 4)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0]
+ ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*9));
+ ymm10 = _mm256_fnmadd_ps(ymm0, ymm15, ymm10 );
- ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0
- ymm4 = _mm256_fmsub_pd(ymm1, ymm15, ymm4); //B11[4-7][0] * alpha-= ymm1
+ //perform mul operation
+ ymm14 = STRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1);
- ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
- ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b + 4)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1]
+ //extract a88
+ ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 7));
- ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2
- ymm6 = _mm256_fmsub_pd(ymm1, ymm15, ymm6); //B11[4-7][1] * alpha -= ymm3
+ //(ROw8): FMA operations
+ ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*8 + cs_a*7));
+ ymm20 = _mm256_fnmadd_ps(ymm0, ymm14, ymm20);
- ///implement TRSM///
+ ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*8 + cs_a*6));
+ ymm19 = _mm256_fnmadd_ps(ymm0, ymm14, ymm19);
- //extract a11
- ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1));
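For the n_remainder == 2 path removed here, each iteration of the M loop does three things: a small GEMM that accumulates the contribution of the already-solved columns (b10 against the packed a01 panel), an alpha scaling fused into the subtraction via fmsub, and a two-column back-substitution against the 2x2 a11 block using the packed diagonal. The scalar sketch below mirrors that sequence for one block under the same preinversion convention; it is not part of the patch, and block_solve_2col_sketch, DIV_OR_SCALE_SKETCH, mb, kb and r are assumed names introduced only for illustration.

#include <stddef.h>

#ifdef BLIS_ENABLE_TRSM_PREINVERSION
  #define DIV_OR_SCALE_SKETCH(x, d) ((x) * (d))  /* d11_pack holds reciprocals */
#else
  #define DIV_OR_SCALE_SKETCH(x, d) ((x) / (d))  /* d11_pack holds the diagonal */
#endif

static void block_solve_2col_sketch(double alpha,
                                    const double *a01, size_t p_lda, size_t kb,
                                    const double *a11, size_t cs_a,
                                    const double *d11_pack,
                                    const double *b10, double *b11,
                                    size_t cs_b, size_t mb)
{
    for (size_t r = 0; r < mb; r++)
    {
        /* GEMM accumulation over the kb already-solved columns of B */
        double acc0 = 0.0, acc1 = 0.0;
        for (size_t k = 0; k < kb; k++)
        {
            acc0 += b10[r + k * cs_b] * a01[k];
            acc1 += b10[r + k * cs_b] * a01[k + p_lda];
        }

        /* alpha*B11 minus the accumulated update (the fmsub step) */
        double c0 = alpha * b11[r]        - acc0;
        double c1 = alpha * b11[r + cs_b] - acc1;

        /* back-substitution over the 2x2 triangular a11, last column first */
        c1  = DIV_OR_SCALE_SKETCH(c1, d11_pack[1]);
        c0 -= a11[cs_a] * c1;            /* the fnmadd with the off-diagonal */
        c0  = DIV_OR_SCALE_SKETCH(c0, d11_pack[0]);

        b11[r]        = c0;
        b11[r + cs_b] = c1;
    }
}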
+ ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*8 + cs_a*5)); + ymm3 = _mm256_fnmadd_ps(ymm0, ymm14, ymm3 ); - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm0); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*8 + cs_a*4)); + ymm2 = _mm256_fnmadd_ps(ymm0, ymm14, ymm2 ); - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*8 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm14, ymm18); - //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*8 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm14, ymm17); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); - ymm4 = _mm256_fnmadd_pd(ymm1, ymm6, ymm4); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*8 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm14, ymm11 ); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*8)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm14, ymm10 ); - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + 4), ymm4); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b + 4), ymm6); - } + //perform mul operation + ymm20 = STRSM_SMALL_DIV_OR_SCALE(ymm20, ymm1); - dim_t m_remainder = i + d_mr; - if(m_remainder >= 4) - { - a01 = D_A_pack; - a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2)*rs_a; //pointer to block of A to be used for TRSM - b10 = B + (m_remainder - 4) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (m_remainder - 4) + (n_remainder - 2)*cs_b; //pointer to block of B to be used for TRSM + //extract a99 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 6)); - k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + //(ROw9): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*7 + cs_a*6)); + ymm19 = _mm256_fnmadd_ps(ymm0, ymm20, ymm19); - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*7 + cs_a*5)); + ymm3 = _mm256_fnmadd_ps(ymm0, ymm20, ymm3 ); - ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*7 + cs_a*4)); + ymm2 = _mm256_fnmadd_ps(ymm0, ymm20, ymm2 ); - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*7 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm20, ymm18); - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*7 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm20, ymm17); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*7 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm20, ymm11 ); - ///implement TRSM/// + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*7)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm20, ymm10 ); - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + //perform 
mul operation + ymm19 = STRSM_SMALL_DIV_OR_SCALE(ymm19, ymm1); - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); + //extract a10 10 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 5)); - //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); + //(ROw10): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*6 + cs_a*5)); + ymm3 = _mm256_fnmadd_ps(ymm0, ymm19, ymm3 ); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*6 + cs_a*4)); + ymm2 = _mm256_fnmadd_ps(ymm0, ymm19, ymm2 ); - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*6 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm19, ymm18); - m_remainder -=4; - } + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*6 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm19, ymm17); - if(m_remainder) - { - if(3 == m_remainder) - { - a01 = D_A_pack; - a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2)*rs_a; //pointer to block of A to be used for TRSM - b10 = B + (m_remainder - 3) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (m_remainder - 3) + (n_remainder - 2)*cs_b; //pointer to block of B to be used for TRSM + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*6 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm19, ymm11 ); - k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*6)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm19, ymm10 ); - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS + //perform mul operation + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm1); - ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) + //extract a11 11 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 4)); - BLIS_PRE_DTRSM_SMALL_2N_3M(AlphaVal,b11,cs_b) + //(ROw11): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*5 + cs_a*4)); + ymm2 = _mm256_fnmadd_ps(ymm0, ymm3, ymm2 ); - ///implement TRSM/// + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*5 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm3, ymm18); - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*5 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm3, ymm17); - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*5 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm3, ymm11 ); - //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*5)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm3, ymm10 ); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + //perform mul operation + ymm2 = STRSM_SMALL_DIV_OR_SCALE(ymm2, ymm1); - BLIS_POST_DTRSM_SMALL_2N_3M(b11,cs_b) + //extract a12 12 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 3)); - m_remainder -=3; - } - else if(2 == m_remainder) - { - a01 = D_A_pack; - a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2)*rs_a; //pointer to block of A to be used for TRSM - b10 = B + (m_remainder - 2) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM - b11 = 
B + (m_remainder - 2) + (n_remainder - 2)*cs_b; //pointer to block of B to be used for TRSM + //(ROw12): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*4 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm2, ymm18); - k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*4 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm2, ymm17); - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*4 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm2, ymm11 ); - ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*4)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm2, ymm10 ); - BLIS_PRE_DTRSM_SMALL_2N_2M(AlphaVal,b11,cs_b) - ///implement TRSM/// + //perform mul operation + ymm18 = STRSM_SMALL_DIV_OR_SCALE(ymm18, ymm1); - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + //extract a13 13 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); + //(ROw13): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*3 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm18, ymm17); - //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*3 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm18, ymm11 ); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*3)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm18, ymm10 ); - BLIS_POST_DTRSM_SMALL_2N_2M(b11,cs_b) + //perform mul operation + ymm17 = STRSM_SMALL_DIV_OR_SCALE(ymm17, ymm1); - m_remainder -=2; - } - else if (1 == m_remainder) - { - a01 = D_A_pack; - a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2)*rs_a; //pointer to block of A to be used for TRSM - b10 = B + (m_remainder - 1) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (m_remainder - 1) + (n_remainder - 2)*cs_b; //pointer to block of B to be used for TRSM + //extract a14 14 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); - k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + //(ROw13): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*2 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm17, ymm11 ); - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*2)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm17, ymm10 ); - ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) + //perform mul operation + ymm11 = STRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - BLIS_PRE_DTRSM_SMALL_2N_1M(AlphaVal,b11,cs_b) - ///implement TRSM/// + //extract a15 15 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 0)); - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + //(ROw15): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*1)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm11, ymm10 ); - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); + //perform mul operation + ymm10 = STRSM_SMALL_DIV_OR_SCALE(ymm10, 
ymm1); - //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); + ymm0 = _mm256_unpacklo_ps(ymm10, ymm11); + ymm1 = _mm256_unpacklo_ps(ymm17, ymm18); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm6 = _mm256_unpacklo_ps(ymm2, ymm3); + ymm7 = _mm256_unpacklo_ps(ymm19, ymm20); - BLIS_POST_DTRSM_SMALL_2N_1M(b11,cs_b) + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b01000100); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b01000100); - m_remainder -=1; - } - } - n_remainder -= 2; - } - else if(n_remainder == 1) - { - a01 = L + (n_remainder - 1)*rs_a + n_remainder*cs_a; //pointer to block of A to be used in GEMM - a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1)*rs_a; //pointer to block of A to be used for TRSM + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//1 + _mm256_storeu_ps((float *)(b11), ymm16); - double *ptr_a10_dup = D_A_pack; + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b11101110); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b11101110); - dim_t p_lda = (n-n_remainder); // packed leading dimension - // perform copy of A to packed buffer D_A_pack + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//2 + _mm256_storeu_ps((float *)(b11 + cs_b), ymm16); - if(transa) - { - for(dim_t x =0;x < p_lda;x+=d_nr) - { - ymm0 = _mm256_loadu_pd((double const *)(a01)); - ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); - ymm2 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2)); - ymm3 = _mm256_loadu_pd((double const *)(a01 + cs_a * 3)); + ymm0 = _mm256_unpackhi_ps(ymm10, ymm11); + ymm1 = _mm256_unpackhi_ps(ymm17, ymm18); - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm6 = _mm256_unpackhi_ps(ymm2, ymm3); + ymm7 = _mm256_unpackhi_ps(ymm19, ymm20); - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b01000100); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b01000100); - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//3 + _mm256_storeu_ps((float *)(b11 + 2*cs_b), ymm16); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b11101110); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b11101110); - _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//4 + _mm256_storeu_ps((float *)(b11 + 3*cs_b), ymm16); - ymm0 = _mm256_loadu_pd((double const *)(a01 + cs_a * 4)); - ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a * 5)); + ymm0 = _mm256_unpacklo_ps(ymm14, ymm15); + ymm1 = _mm256_unpacklo_ps(ymm8, ymm9); - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_broadcast_sd((double const *)&zero); + ymm6 = _mm256_unpacklo_ps(ymm21, ymm22); + ymm7 = _mm256_unpacklo_ps(ymm4, ymm5); - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b01000100); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b01000100); - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_broadcast_sd((double const *)&zero); + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//1 + _mm256_storeu_ps((float *)(b11 + 8), ymm16); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = 
_mm256_permute2f128_pd(ymm0,ymm1,0x31); + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b11101110); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b11101110); - _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_extractf128_pd(ymm6,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_extractf128_pd(ymm7,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm8,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm9,0)); + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//2 + _mm256_storeu_ps((float *)(b11 + cs_b + 8), ymm16); - a01 += d_nr*cs_a; - ptr_a10_dup += d_nr; - } - } - else - { - dim_t loop_count = (n-n_remainder)/4; + ymm0 = _mm256_unpackhi_ps(ymm14, ymm15); + ymm1 = _mm256_unpackhi_ps(ymm8, ymm9); - for(dim_t x =0;x < loop_count;x++) - { - ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 0 + x*4)); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + x*4), ymm15); - } + ymm6 = _mm256_unpackhi_ps(ymm21, ymm22); + ymm7 = _mm256_unpackhi_ps(ymm4, ymm5); - dim_t remainder_loop_count = p_lda - loop_count*4; + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b01000100); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b01000100); - __m128d xmm0; - if(remainder_loop_count != 0) - { - xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 0 + loop_count*4)); - _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + loop_count*4), xmm0); - } - } + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//3 + _mm256_storeu_ps((float *)(b11 + 2*cs_b + 8), ymm16); - ymm4 = _mm256_broadcast_sd((double const *)&ones); - if(!is_unitdiag) - { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_sd((double const *)(a11)); - ymm1 = _mm256_broadcast_sd((double const *)&ones); - ymm2 = _mm256_broadcast_sd((double const *)&ones); - ymm3 = _mm256_broadcast_sd((double const *)&ones); + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b11101110); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b11101110); - ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//4 + _mm256_storeu_ps((float *)(b11 + 3*cs_b + 8), ymm16); - ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); - #ifdef BLIS_DISABLE_TRSM_PREINVERSION - ymm4 = ymm1; - #endif - #ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm4 = _mm256_div_pd(ymm4, ymm1); - #endif + n_remainder -=4; } - _mm256_storeu_pd((double *)(d11_pack), ymm4); - for(i = (m-d_mr); (i+1) > 0; i -= d_mr) //loop along 'M' direction + if(n_remainder) //implementation fo remaining columns(when 'N' is not a multiple of d_nr)() n = 3 { - a01 = D_A_pack; - a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1)*rs_a; //pointer to block of A to be used for TRSM - b10 = B + i + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (i) + (n_remainder - 1)*cs_b; //pointer to block of B to be used for TRSM + a10 = D_A_pack; + a11 = L + (i*cs_a) + (i*rs_a); + b01 = B + i + d_mr; + b11 = B + i; - k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + k_iter = (m - i - d_mr) ; /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS + BLIS_SET_S_YMM_REG_ZEROS - ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter) + if(3 == n_remainder) + { + ///GEMM code begins/// + BLIS_STRSM_SMALL_GEMM_16mx3n(a10,b01,cs_b,p_lda,k_iter) - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); + ymm16 = _mm256_broadcast_ss((float const *)(&AlphaVal)); //register to hold alpha - ymm0 = _mm256_loadu_pd((double const 
*)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + 4)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] + ymm17 = _mm256_loadu_ps((float const *)(b11)); + ymm18 = _mm256_loadu_ps((float const *)(b11 + cs_b)); + ymm19 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - ymm4 = _mm256_fmsub_pd(ymm1, ymm15, ymm4); //B11[4-7][0] * alpha-= ymm1 + ymm17 = _mm256_fmsub_ps(ymm17, ymm16, ymm8); + ymm18 = _mm256_fmsub_ps(ymm18, ymm16, ymm9); + ymm19 = _mm256_fmsub_ps(ymm19, ymm16, ymm10); + ymm20 = _mm256_broadcast_ss((float const *)(&zero)); - ///implement TRSM/// - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); + ymm8 = _mm256_unpacklo_ps(ymm17, ymm18); + ymm9 = _mm256_unpacklo_ps(ymm19, ymm20); - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + 4), ymm4); - } + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b01000100); - dim_t m_remainder = i + d_mr; - if(m_remainder >= 4) - { - a01 = D_A_pack; - a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1)*rs_a; //pointer to block of A to be used for TRSM - b10 = B + (m_remainder - 4) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (m_remainder - 4) + (n_remainder - 1)*cs_b; //pointer to block of B to be used for TRSM + ymm10 = _mm256_permute2f128_ps(ymm4,ymm4,0x20);//1 + ymm2 = _mm256_permute2f128_ps(ymm4,ymm4,0x31);//5 - k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b11101110); - ymm3 = _mm256_setzero_pd(); + ymm11 = _mm256_permute2f128_ps(ymm4,ymm4,0x20);//2 + ymm3 = _mm256_permute2f128_ps(ymm4,ymm4,0x31);//6 - ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) + ymm8 = _mm256_unpackhi_ps(ymm17, ymm18); + ymm9 = _mm256_unpackhi_ps(ymm19, ymm20); - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b01000100); - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + ymm17 = _mm256_permute2f128_ps(ymm4,ymm4,0x20);//3 + ymm19 = _mm256_permute2f128_ps(ymm4,ymm4,0x31);//7 - ///implement TRSM/// - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b11101110); - _mm256_storeu_pd((double *)b11, ymm3); + ymm18 = _mm256_permute2f128_ps(ymm4,ymm4,0x20);//4 + ymm20 = _mm256_permute2f128_ps(ymm4,ymm4,0x31);//8 - m_remainder -=4; - } + ymm16 = _mm256_broadcast_ss((float const *)(&AlphaVal)); - if(m_remainder) - { - if(3 == m_remainder) - { - a01 = D_A_pack; - a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1)*rs_a; //pointer to block of A to be used for TRSM - b10 = B + (m_remainder - 3) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (m_remainder - 3) + (n_remainder - 1)*cs_b; //pointer to block of B to be used for TRSM + ymm8 = _mm256_loadu_ps((float const *)(b11 + 8)); + ymm9 = _mm256_loadu_ps((float const *)(b11 + cs_b + 8)); + ymm4 = _mm256_loadu_ps((float const *)(b11 + cs_b*2 + 8)); - k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + ymm8 = _mm256_fmsub_ps(ymm8, ymm16, ymm12); + ymm9 = _mm256_fmsub_ps(ymm9, ymm16, ymm13); + ymm4 
= _mm256_fmsub_ps(ymm4, ymm16, ymm14); + ymm5 = _mm256_broadcast_ss((float const *)(&zero)); - ymm3 = _mm256_setzero_pd(); + ymm12 = _mm256_unpacklo_ps(ymm8, ymm9); + ymm13 = _mm256_unpacklo_ps(ymm4, ymm5); - ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) + ymm6 = _mm256_shuffle_ps(ymm12,ymm13,0b01000100); - BLIS_PRE_DTRSM_SMALL_1N_3M(AlphaVal,b11,cs_b) + ymm14 = _mm256_permute2f128_ps(ymm6,ymm6,0x20);//1 + ymm21 = _mm256_permute2f128_ps(ymm6,ymm6,0x31);//5 - ///implement TRSM/// - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm6 = _mm256_shuffle_ps(ymm12,ymm13,0b11101110); - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm6, ymm3, 0x07); + ymm15 = _mm256_permute2f128_ps(ymm6,ymm6,0x20);//2 + ymm22 = _mm256_permute2f128_ps(ymm6,ymm6,0x31);//6 - BLIS_POST_DTRSM_SMALL_1N_3M(b11,cs_b) + ymm12 = _mm256_unpackhi_ps(ymm8, ymm9); + ymm13 = _mm256_unpackhi_ps(ymm4, ymm5); - m_remainder -=3; + ymm6 = _mm256_shuffle_ps(ymm12,ymm13,0b01000100); + + ymm8 = _mm256_permute2f128_ps(ymm6,ymm6,0x20);//3 + ymm4 = _mm256_permute2f128_ps(ymm6,ymm6,0x31);//7 + + ymm6 = _mm256_shuffle_ps(ymm12,ymm13,0b11101110); + + ymm9 = _mm256_permute2f128_ps(ymm6,ymm6,0x20);//4 + ymm5 = _mm256_permute2f128_ps(ymm6,ymm6,0x31);//8 } - else if(2 == m_remainder) + else if(2 == n_remainder) { - a01 = D_A_pack; - a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1)*rs_a; //pointer to block of A to be used for TRSM - b10 = B + (m_remainder - 2) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (m_remainder - 2) + (n_remainder - 1)*cs_b; //pointer to block of B to be used for TRSM + ///GEMM code begins/// + BLIS_STRSM_SMALL_GEMM_16mx2n(a10,b01,cs_b,p_lda,k_iter) - k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + ymm16 = _mm256_broadcast_ss((float const *)(&AlphaVal)); //register to hold alpha - ymm3 = _mm256_setzero_pd(); + ymm17 = _mm256_loadu_ps((float const *)(b11)); + ymm18 = _mm256_loadu_ps((float const *)(b11 + cs_b)); - ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) + ymm17 = _mm256_fmsub_ps(ymm17, ymm16, ymm8); + ymm18 = _mm256_fmsub_ps(ymm18, ymm16, ymm9); + ymm19 = _mm256_broadcast_ss((float const *)(&zero)); + ymm20 = _mm256_broadcast_ss((float const *)(&zero)); - BLIS_PRE_DTRSM_SMALL_1N_2M(AlphaVal,b11,cs_b) + ymm8 = _mm256_unpacklo_ps(ymm17, ymm18); + ymm9 = _mm256_unpacklo_ps(ymm19, ymm20); - ///implement TRSM/// - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b01000100); - BLIS_POST_DTRSM_SMALL_1N_2M(b11,cs_b) + ymm10 = _mm256_permute2f128_ps(ymm4,ymm4,0x20);//1 + ymm2 = _mm256_permute2f128_ps(ymm4,ymm4,0x31);//5 + + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b11101110); + + ymm11 = _mm256_permute2f128_ps(ymm4,ymm4,0x20);//2 + ymm3 = _mm256_permute2f128_ps(ymm4,ymm4,0x31);//6 + + ymm8 = _mm256_unpackhi_ps(ymm17, ymm18); + ymm9 = _mm256_unpackhi_ps(ymm19, ymm20); + + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b01000100); + + ymm17 = _mm256_permute2f128_ps(ymm4,ymm4,0x20);//3 + ymm19 = _mm256_permute2f128_ps(ymm4,ymm4,0x31);//7 - m_remainder -=2; - } - else if (1 == m_remainder) - { - a01 = D_A_pack; - a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1)*rs_a; //pointer to block of A to be used for TRSM - b10 = B + (m_remainder - 1) + (n_remainder)*cs_b; //pointer to block of 
B to be used in GEMM - b11 = B + (m_remainder - 1) + (n_remainder - 1)*cs_b; //pointer to block of B to be used for TRSM + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b11101110); - k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) + ymm18 = _mm256_permute2f128_ps(ymm4,ymm4,0x20);//4 + ymm20 = _mm256_permute2f128_ps(ymm4,ymm4,0x31);//8 - ymm3 = _mm256_setzero_pd(); + ymm16 = _mm256_broadcast_ss((float const *)(&AlphaVal)); - ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) + ymm8 = _mm256_loadu_ps((float const *)(b11 + 8)); + ymm9 = _mm256_loadu_ps((float const *)(b11 + cs_b + 8)); - BLIS_PRE_DTRSM_SMALL_1N_1M(AlphaVal,b11,cs_b) + ymm8 = _mm256_fmsub_ps(ymm8, ymm16, ymm12); + ymm9 = _mm256_fmsub_ps(ymm9, ymm16, ymm13); + ymm4 = _mm256_broadcast_ss((float const *)(&zero)); + ymm5 = _mm256_broadcast_ss((float const *)(&zero)); - ///implement TRSM/// - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm12 = _mm256_unpacklo_ps(ymm8, ymm9); + ymm13 = _mm256_unpacklo_ps(ymm4, ymm5); - BLIS_POST_DTRSM_SMALL_1N_1M(b11,cs_b) + ymm6 = _mm256_shuffle_ps(ymm12,ymm13,0b01000100); - m_remainder -=1; - } - } - n_remainder -= 1; - } + ymm14 = _mm256_permute2f128_ps(ymm6,ymm6,0x20);//1 + ymm21 = _mm256_permute2f128_ps(ymm6,ymm6,0x31);//5 - if ((required_packing_A == 1) && bli_mem_is_alloc( &local_mem_buf_A_s )) - { - bli_membrk_release(&rntm, - &local_mem_buf_A_s); - } - return BLIS_SUCCESS; -} + ymm6 = _mm256_shuffle_ps(ymm12,ymm13,0b11101110); -/* TRSM for the case AX = alpha * B, Double precision - * A is lower-triangular, transpose, non-unit diagonal - * dimensions A: mxm X: mxn B: mxn -*/ -BLIS_INLINE err_t bli_dtrsm_small_AltXB_AuXB -( - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl -) -{ - dim_t m = bli_obj_length(b); // number of rows of matrix B - dim_t n = bli_obj_width(b); // number of columns of matrix B + ymm15 = _mm256_permute2f128_ps(ymm6,ymm6,0x20);//2 + ymm22 = _mm256_permute2f128_ps(ymm6,ymm6,0x31);//6 - bool transa = bli_obj_has_trans(a); - dim_t cs_a, rs_a; - dim_t d_mr = 8,d_nr = 6; + ymm12 = _mm256_unpackhi_ps(ymm8, ymm9); + ymm13 = _mm256_unpackhi_ps(ymm4, ymm5); - // Swap rs_a & cs_a in case of non-tranpose. 
- if(transa) - { - cs_a = bli_obj_col_stride(a); // column stride of A - rs_a = bli_obj_row_stride(a); // row stride of A - } - else - { - cs_a = bli_obj_row_stride(a); // row stride of A - rs_a = bli_obj_col_stride(a); // column stride of A - } - dim_t cs_b = bli_obj_col_stride(b); // column stride of B + ymm6 = _mm256_shuffle_ps(ymm12,ymm13,0b01000100); - dim_t i, j, k; //loop variables - dim_t k_iter; //number of times GEMM to be performed + ymm8 = _mm256_permute2f128_ps(ymm6,ymm6,0x20);//3 + ymm4 = _mm256_permute2f128_ps(ymm6,ymm6,0x31);//7 - double AlphaVal = *(double *)AlphaObj->buffer; //value of alpha - double *L = a->buffer; //pointer to matrix A - double *B = b->buffer; //pointer to matrix B + ymm6 = _mm256_shuffle_ps(ymm12,ymm13,0b11101110); - //pointers that point to blocks for GEMM and TRSM - double *a10, *a11, *b01, *b11; + ymm9 = _mm256_permute2f128_ps(ymm6,ymm6,0x20);//4 + ymm5 = _mm256_permute2f128_ps(ymm6,ymm6,0x31);//8 + } + else if(1 == n_remainder) + { + ///GEMM code begins/// + BLIS_STRSM_SMALL_GEMM_16mx1n(a10,b01,cs_b,p_lda,k_iter) - double ones = 1.0; - bool is_unitdiag = bli_obj_has_unit_diag(a); + ymm16 = _mm256_broadcast_ss((float const *)(&AlphaVal)); //register to hold alpha - //scratch registers - __m256d ymm0, ymm1, ymm2, ymm3; - __m256d ymm4, ymm5, ymm6, ymm7; - __m256d ymm8, ymm9, ymm10, ymm11; - __m256d ymm12, ymm13, ymm14, ymm15; - __m256d ymm16, ymm17, ymm18, ymm19; - __m256d ymm20; + ymm17 = _mm256_loadu_ps((float const *)(b11)); - __m128d xmm5; + ymm17 = _mm256_fmsub_ps(ymm17, ymm16, ymm8); + ymm18 = _mm256_broadcast_ss((float const *)(&zero)); + ymm19 = _mm256_broadcast_ss((float const *)(&zero)); + ymm20 = _mm256_broadcast_ss((float const *)(&zero)); - gint_t required_packing_A = 1; - mem_t local_mem_buf_A_s = {0}; - double *D_A_pack = NULL; - double d11_pack[d_mr] __attribute__((aligned(64))); - rntm_t rntm; + ymm8 = _mm256_unpacklo_ps(ymm17, ymm18); + ymm9 = _mm256_unpacklo_ps(ymm19, ymm20); - bli_rntm_init_from_global( &rntm ); - bli_rntm_set_num_threads_only( 1, &rntm ); - bli_membrk_rntm_set_membrk( &rntm ); + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b01000100); - siz_t buffer_size = bli_pool_block_size( - bli_membrk_pool( - bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), - bli_rntm_membrk(&rntm))); + ymm10 = _mm256_permute2f128_ps(ymm4,ymm4,0x20);//1 + ymm2 = _mm256_permute2f128_ps(ymm4,ymm4,0x31);//5 - if((d_mr * m * sizeof(double)) > buffer_size) - return BLIS_NOT_YET_IMPLEMENTED; + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b11101110); - if(required_packing_A == 1) - { - // Get the buffer from the pool. - bli_membrk_acquire_m(&rntm, - buffer_size, - BLIS_BITVAL_BUFFER_FOR_A_BLOCK, - &local_mem_buf_A_s); - if(FALSE==bli_mem_is_alloc(&local_mem_buf_A_s)) return BLIS_NULL_POINTER; - D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); - if(NULL==D_A_pack) return BLIS_NULL_POINTER; - } + ymm11 = _mm256_permute2f128_ps(ymm4,ymm4,0x20);//2 + ymm3 = _mm256_permute2f128_ps(ymm4,ymm4,0x31);//6 - /* - Performs solving TRSM for 8 colmns at a time from 0 to m/d_mr in steps of d_mr - a. Load, transpose, Pack A (a10 block), the size of packing 8x6 to 8x (m-d_mr) - First there will be no GEMM and no packing of a10 because it is only TRSM - b. Using packed a10 block and b01 block perform GEMM operation - c. Use GEMM outputs, perform TRSM operaton using a11, b11 and update B - d. 
Repeat b,c for n rows of B in steps of d_nr - */ - for(i = (m - d_mr); (i + 1) > 0; i -= d_mr) - { - a10 = L + (i*cs_a) + (i + d_mr)*rs_a; //pointer to block of A to be used for GEMM - a11 = L + (i*cs_a) + (i*rs_a); //pointer to block of A to be used for TRSM + ymm8 = _mm256_unpackhi_ps(ymm17, ymm18); + ymm9 = _mm256_unpackhi_ps(ymm19, ymm20); - // Do transpose for a10 & store in D_A_pack - //ptr_a10_dup = D_A_pack; + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b01000100); - dim_t p_lda = d_mr; // packed leading dimension + ymm17 = _mm256_permute2f128_ps(ymm4,ymm4,0x20);//3 + ymm19 = _mm256_permute2f128_ps(ymm4,ymm4,0x31);//7 - if(transa) - { - /* - Load, transpose and pack current A block (a10) into packed buffer memory D_A_pack - a. This a10 block is used in GEMM portion only and this - a10 block size will be increasing by d_mr for every next itteration - untill it reaches 8x(m-8) which is the maximum GEMM alone block size in A - b. This packed buffer is reused to calculate all n rows of B matrix - */ - bli_dtrsm_small_pack('L', (m-i-d_mr), 1, a10, cs_a, D_A_pack,p_lda,d_mr); + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b11101110); - /* - Pack 8 diagonal elements of A block into an array - a. This helps in utilze cache line efficiently in TRSM operation - b. store ones when input is unit diagonal - */ - dtrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,d_mr); - } - else - { - bli_dtrsm_small_pack('L', (m-i-d_mr), 0, a10, rs_a, D_A_pack,p_lda,d_mr); - dtrsm_small_pack_diag_element(is_unitdiag,a11,rs_a,d11_pack,d_mr); - } + ymm18 = _mm256_permute2f128_ps(ymm4,ymm4,0x20);//4 + ymm20 = _mm256_permute2f128_ps(ymm4,ymm4,0x31);//8 - /* - a. Perform GEMM using a10, b01. - b. Perform TRSM on a11, b11 - c. This loop GEMM+TRSM loops operates with 8x6 block size - along n dimension for every d_nr rows of b01 where - packed A buffer is reused in computing all n rows of B. - d. Same approch is used in remaining fringe cases. - */ - for(j = (n - d_nr); (j + 1) > 0; j -= d_nr) - { - a10 = D_A_pack; - b01 = B + (j * cs_b) + i + d_mr; //pointer to block of B to be used for GEMM - b11 = B + (j * cs_b) + i; //pointer to block of B to be used for TRSM + ymm16 = _mm256_broadcast_ss((float const *)(&AlphaVal)); - k_iter = (m - i - d_mr); + ymm8 = _mm256_loadu_ps((float const *)(b11 + 8)); - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS + ymm8 = _mm256_fmsub_ps(ymm8, ymm16, ymm12); + ymm9 = _mm256_broadcast_ss((float const *)(&zero)); + ymm4 = _mm256_broadcast_ss((float const *)(&zero)); + ymm5 = _mm256_broadcast_ss((float const *)(&zero)); - /* - Peform GEMM between a10 and b01 blocks - For first itteration there will be no GEMM operation - where k_iter are zero - */ - BLIS_DTRSM_SMALL_GEMM_8mx6n(a10,b01,cs_b,p_lda,k_iter) + ymm12 = _mm256_unpacklo_ps(ymm8, ymm9); + ymm13 = _mm256_unpacklo_ps(ymm4, ymm5); - /* - Load b11 of size 6x8 and multiply with alpha - Add the GEMM output and perform inregister transose of b11 - to peform TRSM operation. - */ - BLIS_DTRSM_SMALL_NREG_TRANSPOSE_6x8(b11,cs_b,AlphaVal) + ymm6 = _mm256_shuffle_ps(ymm12,ymm13,0b01000100); - /* - Compute 8x6 TRSM block by using GEMM block output in register - a. The 8x6 input (gemm outputs) are stored in combinations of ymm registers - 1. ymm15, ymm20 2. ymm14, ymm19 3. ymm13, ymm18 , 4. ymm12, ymm17 - 5. ymm11, ymm7 6. ymm10, ymm6, 7.ymm9, ymm5 8. ymm8, ymm4 - where ymm15-ymm8 holds 8x4 data and reaming 8x2 will be hold by - other registers - b. 
Towards the end do in regiser transpose of TRSM output and store in b11 - */ - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); + ymm14 = _mm256_permute2f128_ps(ymm6,ymm6,0x20);//1 + ymm21 = _mm256_permute2f128_ps(ymm6,ymm6,0x31);//5 - //perform mul operation - ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); - ymm20 = DTRSM_SMALL_DIV_OR_SCALE(ymm20, ymm1); + ymm6 = _mm256_shuffle_ps(ymm12,ymm13,0b11101110); - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); + ymm15 = _mm256_permute2f128_ps(ymm6,ymm6,0x20);//2 + ymm22 = _mm256_permute2f128_ps(ymm6,ymm6,0x31);//6 - //(ROw7): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6*cs_a + 7*rs_a)); - ymm14 = _mm256_fnmadd_pd(ymm2, ymm15, ymm14); - ymm19 = _mm256_fnmadd_pd(ymm2, ymm20, ymm19); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 7*rs_a)); - ymm13 = _mm256_fnmadd_pd(ymm2, ymm15, ymm13); - ymm18 = _mm256_fnmadd_pd(ymm2, ymm20, ymm18); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 7*rs_a)); - ymm12 = _mm256_fnmadd_pd(ymm2, ymm15, ymm12); - ymm17 = _mm256_fnmadd_pd(ymm2, ymm20, ymm17); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 7*rs_a)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm15, ymm11); - ymm7 = _mm256_fnmadd_pd(ymm2, ymm20, ymm7); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 7*rs_a)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm15, ymm10); - ymm6 = _mm256_fnmadd_pd(ymm2, ymm20, ymm6); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 7*rs_a)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm15, ymm9); - ymm5 = _mm256_fnmadd_pd(ymm2, ymm20, ymm5); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7*rs_a)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm15, ymm8); - ymm4 = _mm256_fnmadd_pd(ymm2, ymm20, ymm4); + ymm12 = _mm256_unpackhi_ps(ymm8, ymm9); + ymm13 = _mm256_unpackhi_ps(ymm4, ymm5); - //perform mul operation - ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); - ymm19 = DTRSM_SMALL_DIV_OR_SCALE(ymm19, ymm1); + ymm6 = _mm256_shuffle_ps(ymm12,ymm13,0b01000100); - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + ymm8 = _mm256_permute2f128_ps(ymm6,ymm6,0x20);//3 + ymm4 = _mm256_permute2f128_ps(ymm6,ymm6,0x31);//7 - //(ROw6): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 6*rs_a)); - ymm13 = _mm256_fnmadd_pd(ymm2, ymm14, ymm13); - ymm18 = _mm256_fnmadd_pd(ymm2, ymm19, ymm18); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 6*rs_a)); - ymm12 = _mm256_fnmadd_pd(ymm2, ymm14, ymm12); - ymm17 = _mm256_fnmadd_pd(ymm2, ymm19, ymm17); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 6*rs_a)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm14, ymm11); - ymm7 = _mm256_fnmadd_pd(ymm2, ymm19, ymm7); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 6*rs_a)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm14, ymm10); - ymm6 = _mm256_fnmadd_pd(ymm2, ymm19, ymm6); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 6*rs_a)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm14, ymm9); - ymm5 = _mm256_fnmadd_pd(ymm2, ymm19, ymm5); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6*rs_a)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm14, ymm8); - ymm4 = _mm256_fnmadd_pd(ymm2, ymm19, ymm4); + ymm6 = _mm256_shuffle_ps(ymm12,ymm13,0b11101110); - //perform mul operation - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1); - ymm18 = DTRSM_SMALL_DIV_OR_SCALE(ymm18, ymm1); + ymm9 = _mm256_permute2f128_ps(ymm6,ymm6,0x20);//4 + ymm5 = _mm256_permute2f128_ps(ymm6,ymm6,0x31);//8 + } - //extract a00 - ymm1 = 
_mm256_broadcast_sd((double const *)(d11_pack + 4)); + ///implement TRSM/// - //(ROw5): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 5*rs_a)); - ymm12 = _mm256_fnmadd_pd(ymm2, ymm13, ymm12); - ymm17 = _mm256_fnmadd_pd(ymm2, ymm18, ymm17); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 5*rs_a)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm13, ymm11); - ymm7 = _mm256_fnmadd_pd(ymm2, ymm18, ymm7); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 5*rs_a)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm13, ymm10); - ymm6 = _mm256_fnmadd_pd(ymm2, ymm18, ymm6); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 5*rs_a)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm13, ymm9); - ymm5 = _mm256_fnmadd_pd(ymm2, ymm18, ymm5); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm13, ymm8); - ymm4 = _mm256_fnmadd_pd(ymm2, ymm18, ymm4); + ////extract a00 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 15)); //perform mul operation - ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); - ymm17 = DTRSM_SMALL_DIV_OR_SCALE(ymm17, ymm1); + ymm5 = STRSM_SMALL_DIV_OR_SCALE(ymm5, ymm1); - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + //extract a11 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 14)); - //(ROw4): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 4*rs_a)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm12, ymm11); - ymm7 = _mm256_fnmadd_pd(ymm2, ymm17, ymm7); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 4*rs_a)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm12, ymm10); - ymm6 = _mm256_fnmadd_pd(ymm2, ymm17, ymm6); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 4*rs_a)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm12, ymm9); - ymm5 = _mm256_fnmadd_pd(ymm2, ymm17, ymm5); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm12, ymm8); - ymm4 = _mm256_fnmadd_pd(ymm2, ymm17, ymm4); + //(ROw1): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*15 + cs_a*14)); + ymm4 = _mm256_fnmadd_ps(ymm0, ymm5, ymm4); - //perform mul operation - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm1); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*15 + cs_a*13)); + ymm22 = _mm256_fnmadd_ps(ymm0, ymm5, ymm22); - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*15 + cs_a*12)); + ymm21 = _mm256_fnmadd_ps(ymm0, ymm5, ymm21); - //(ROw3): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3*rs_a)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm11, ymm10); - ymm6 = _mm256_fnmadd_pd(ymm2, ymm7, ymm6); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3*rs_a)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm11, ymm9); - ymm5 = _mm256_fnmadd_pd(ymm2, ymm7, ymm5); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm11, ymm8); - ymm4 = _mm256_fnmadd_pd(ymm2, ymm7, ymm4); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*15 + cs_a*11)); + ymm9 = _mm256_fnmadd_ps(ymm0, ymm5, ymm9); - //perform mul operation - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); - ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm1); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*15 + cs_a*10)); + ymm8 = _mm256_fnmadd_ps(ymm0, ymm5, ymm8); - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + ymm0 = _mm256_broadcast_ss((float const 
*)(a11 + rs_a*15 + cs_a*9)); + ymm15 = _mm256_fnmadd_ps(ymm0, ymm5, ymm15); - //(ROw2): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2*rs_a)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm10, ymm9); - ymm5 = _mm256_fnmadd_pd(ymm2, ymm6, ymm5); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm10, ymm8); - ymm4 = _mm256_fnmadd_pd(ymm2, ymm6, ymm4); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*15 + cs_a*8)); + ymm14 = _mm256_fnmadd_ps(ymm0, ymm5, ymm14); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*15 + cs_a*7)); + ymm20 = _mm256_fnmadd_ps(ymm0, ymm5, ymm20); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*15 + cs_a*6)); + ymm19 = _mm256_fnmadd_ps(ymm0, ymm5, ymm19); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*15 + cs_a*5)); + ymm3 = _mm256_fnmadd_ps(ymm0, ymm5, ymm3 ); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*15 + cs_a*4)); + ymm2 = _mm256_fnmadd_ps(ymm0, ymm5, ymm2 ); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*15 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm5, ymm18); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*15 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm5, ymm17); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*15 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm5, ymm11 ); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*15)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm5, ymm10 ); //perform mul operation - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm1); + ymm4 = STRSM_SMALL_DIV_OR_SCALE(ymm4, ymm1); - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + //extract a22 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 13)); //(ROw2): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm9, ymm8); - ymm4 = _mm256_fnmadd_pd(ymm2, ymm5, ymm4); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*14 + cs_a*13)); + ymm22 = _mm256_fnmadd_ps(ymm0, ymm4, ymm22); - //perform mul operation - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); - ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm1); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*14 + cs_a*12)); + ymm21 = _mm256_fnmadd_ps(ymm0, ymm4, ymm21); - BLIS_DTRSM_SMALL_NREG_TRANSPOSE_8x6_AND_STORE(b11,cs_b) - } + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*14 + cs_a*11)); + ymm9 = _mm256_fnmadd_ps(ymm0, ymm4, ymm9); - dim_t n_remainder = j + d_nr; - if(n_remainder >= 4) - { - a10 = D_A_pack; - a11 = L + (i*cs_a) + (i*rs_a); - b01 = B + ((n_remainder - 4)* cs_b) + i + d_mr; - b11 = B + ((n_remainder - 4)* cs_b) + i; + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*14 + cs_a*10)); + ymm8 = _mm256_fnmadd_ps(ymm0, ymm4, ymm8); - k_iter = (m - i - d_mr); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*14 + cs_a*9)); + ymm15 = _mm256_fnmadd_ps(ymm0, ymm4, ymm15); - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*14 + cs_a*8)); + ymm14 = _mm256_fnmadd_ps(ymm0, ymm4, ymm14); - ///GEMM code begins/// - BLIS_DTRSM_SMALL_GEMM_8mx4n(a10,b01,cs_b,p_lda,k_iter) + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*14 + cs_a*7)); + ymm20 = _mm256_fnmadd_ps(ymm0, ymm4, ymm20); - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*14 + cs_a*6)); + ymm19 = 
_mm256_fnmadd_ps(ymm0, ymm4, ymm19); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*14 + cs_a*5)); + ymm3 = _mm256_fnmadd_ps(ymm0, ymm4, ymm3 ); - ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] - ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] - ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 4)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] - ymm7 = _mm256_loadu_pd((double const *)(b11 + cs_b *3 + 4)); //B11[0][7] B11[1][7] B11[2][7] B11[3][7] + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*14 + cs_a*4)); + ymm2 = _mm256_fnmadd_ps(ymm0, ymm4, ymm2 ); - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] * alpha -= B01[0-3][3] - ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] - ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] * alpha -= B01[0-3][5] - ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); //B11[0-3][6] * alpha -= B01[0-3][6] - ymm7 = _mm256_fmsub_pd(ymm7, ymm16, ymm15); //B11[0-3][7] * alpha -= B01[0-3][7] + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*14 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm4, ymm18); - ///implement TRSM/// + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*14 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm4, ymm17); - ///transpose of B11// - ///unpacklow/// - ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*14 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm4, ymm11 ); - ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[2][4] B11[2][5] - ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[2][6] B11[2][7] + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*14)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm4, ymm10 ); - //rearrange low elements - ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + //perform mul operation + ymm22 = STRSM_SMALL_DIV_OR_SCALE(ymm22, ymm1); - ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3] - ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3] + //extract a33 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 12)); - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + //(ROw3): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*13 + cs_a*12)); + ymm21 = _mm256_fnmadd_ps(ymm0, ymm22, ymm21); - ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); 
//B11[1][4] B11[1][5] B11[3][4] B11[3][5] - ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[1][6] B11[1][7] B11[3][6] B11[3][7] + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*13 + cs_a*11)); + ymm9 = _mm256_fnmadd_ps(ymm0, ymm22, ymm9); - //rearrange high elements - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*13 + cs_a*10)); + ymm8 = _mm256_fnmadd_ps(ymm0, ymm22, ymm8); - ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3] - ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3] + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*13 + cs_a*9)); + ymm15 = _mm256_fnmadd_ps(ymm0, ymm22, ymm15); - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*13 + cs_a*8)); + ymm14 = _mm256_fnmadd_ps(ymm0, ymm22, ymm14); - //perform mul operation - ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*13 + cs_a*7)); + ymm20 = _mm256_fnmadd_ps(ymm0, ymm22, ymm20); - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*13 + cs_a*6)); + ymm19 = _mm256_fnmadd_ps(ymm0, ymm22, ymm19); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6*cs_a + 7*rs_a)); - ymm3 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 7*rs_a)); - ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 7*rs_a)); - ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 7*rs_a)); - ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 7*rs_a)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 7*rs_a)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + 7*rs_a)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*13 + cs_a*5)); + ymm3 = _mm256_fnmadd_ps(ymm0, ymm22, ymm3 ); - //(ROw7): FMA operations - ymm14 = _mm256_fnmadd_pd(ymm2, ymm15, ymm14); - ymm13 = _mm256_fnmadd_pd(ymm3, ymm15, ymm13); - ymm12 = _mm256_fnmadd_pd(ymm4, ymm15, ymm12); - ymm11 = _mm256_fnmadd_pd(ymm5, ymm15, ymm11); - ymm10 = _mm256_fnmadd_pd(ymm6, ymm15, ymm10); - ymm9 = _mm256_fnmadd_pd(ymm7, ymm15, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm16, ymm15, ymm8); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*13 + cs_a*4)); + ymm2 = _mm256_fnmadd_ps(ymm0, ymm22, ymm2 ); - //perform mul operation - ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*13 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm22, ymm18); - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*13 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm22, ymm17); - ymm3 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 6*rs_a)); - ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 6*rs_a)); - ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 6*rs_a)); - ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 6*rs_a)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 6*rs_a)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + 6*rs_a)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*13 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm22, ymm11 ); - //(ROw6): FMA operations - ymm13 = _mm256_fnmadd_pd(ymm3, ymm14, ymm13); - 
ymm12 = _mm256_fnmadd_pd(ymm4, ymm14, ymm12); - ymm11 = _mm256_fnmadd_pd(ymm5, ymm14, ymm11); - ymm10 = _mm256_fnmadd_pd(ymm6, ymm14, ymm10); - ymm9 = _mm256_fnmadd_pd(ymm7, ymm14, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm16, ymm14, ymm8); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*13)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm22, ymm10 ); //perform mul operation - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1); + ymm21 = STRSM_SMALL_DIV_OR_SCALE(ymm21, ymm1); - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + //extract a44 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 11)); + //(ROw4): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*12 + cs_a*11)); + ymm9 = _mm256_fnmadd_ps(ymm0, ymm21, ymm9); - ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 5*rs_a)); - ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 5*rs_a)); - ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 5*rs_a)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 5*rs_a)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*12 + cs_a*10)); + ymm8 = _mm256_fnmadd_ps(ymm0, ymm21, ymm8); - //(ROw5): FMA operations - ymm12 = _mm256_fnmadd_pd(ymm4, ymm13, ymm12); - ymm11 = _mm256_fnmadd_pd(ymm5, ymm13, ymm11); - ymm10 = _mm256_fnmadd_pd(ymm6, ymm13, ymm10); - ymm9 = _mm256_fnmadd_pd(ymm7, ymm13, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm16, ymm13, ymm8); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*12 + cs_a*9)); + ymm15 = _mm256_fnmadd_ps(ymm0, ymm21, ymm15); - //perform mul operation - ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*12 + cs_a*8)); + ymm14 = _mm256_fnmadd_ps(ymm0, ymm21, ymm14); - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*12 + cs_a*7)); + ymm20 = _mm256_fnmadd_ps(ymm0, ymm21, ymm20); - ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 4*rs_a)); - ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 4*rs_a)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 4*rs_a)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*12 + cs_a*6)); + ymm19 = _mm256_fnmadd_ps(ymm0, ymm21, ymm19); - //(ROw4): FMA operations - ymm11 = _mm256_fnmadd_pd(ymm5, ymm12, ymm11); - ymm10 = _mm256_fnmadd_pd(ymm6, ymm12, ymm10); - ymm9 = _mm256_fnmadd_pd(ymm7, ymm12, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm16, ymm12, ymm8); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*12 + cs_a*5)); + ymm3 = _mm256_fnmadd_ps(ymm0, ymm21, ymm3 ); - //perform mul operation - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*12 + cs_a*4)); + ymm2 = _mm256_fnmadd_ps(ymm0, ymm21, ymm2 ); - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*12 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm21, ymm18); - ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3*rs_a)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3*rs_a)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*12 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm21, ymm17); - //(ROw3): FMA operations - ymm10 = _mm256_fnmadd_pd(ymm6, ymm11, ymm10); - ymm9 = _mm256_fnmadd_pd(ymm7, 
ymm11, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm16, ymm11, ymm8); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*12 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm21, ymm11 ); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*12)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm21, ymm10 ); //perform mul operation - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + ymm9 = STRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + //extract a55 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 10)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2*rs_a)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + //(ROw5): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*11 + cs_a*10)); + ymm8 = _mm256_fnmadd_ps(ymm0, ymm9, ymm8); - //(ROw2): FMA operations - ymm9 = _mm256_fnmadd_pd(ymm7, ymm10, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm16, ymm10, ymm8); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*11 + cs_a*9)); + ymm15 = _mm256_fnmadd_ps(ymm0, ymm9, ymm15); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*11 + cs_a*8)); + ymm14 = _mm256_fnmadd_ps(ymm0, ymm9, ymm14); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*11 + cs_a*7)); + ymm20 = _mm256_fnmadd_ps(ymm0, ymm9, ymm20); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*11 + cs_a*6)); + ymm19 = _mm256_fnmadd_ps(ymm0, ymm9, ymm19); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*11 + cs_a*5)); + ymm3 = _mm256_fnmadd_ps(ymm0, ymm9, ymm3 ); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*11 + cs_a*4)); + ymm2 = _mm256_fnmadd_ps(ymm0, ymm9, ymm2 ); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*11 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm9, ymm18); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*11 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm9, ymm17); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*11 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm9, ymm11 ); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*11)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm9, ymm10 ); //perform mul operation - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); + ymm8 = STRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + //extract a66 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 9)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); + //(ROw6): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*10 + cs_a*9)); + ymm15 = _mm256_fnmadd_ps(ymm0, ymm8, ymm15); - //(ROw2): FMA operations - ymm8 = _mm256_fnmadd_pd(ymm16, ymm9, ymm8); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*10 + cs_a*8)); + ymm14 = _mm256_fnmadd_ps(ymm0, ymm8, ymm14); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*10 + cs_a*7)); + ymm20 = _mm256_fnmadd_ps(ymm0, ymm8, ymm20); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*10 + cs_a*6)); + ymm19 = _mm256_fnmadd_ps(ymm0, ymm8, ymm19); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*10 + cs_a*5)); + ymm3 = _mm256_fnmadd_ps(ymm0, ymm8, ymm3 ); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*10 + cs_a*4)); + ymm2 = _mm256_fnmadd_ps(ymm0, ymm8, ymm2 ); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*10 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm8, ymm18); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*10 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm8, ymm17); + + ymm0 = 
_mm256_broadcast_ss((float const *)(a11 + rs_a*10 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm8, ymm11 ); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*10)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm8, ymm10 ); //perform mul operation - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); + ymm15 = STRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + //extract a77 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 8)); - ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] - ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] + //(ROw7): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*9 + cs_a*8)); + ymm14 = _mm256_fnmadd_ps(ymm0, ymm15, ymm14); - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*9 + cs_a*7)); + ymm20 = _mm256_fnmadd_ps(ymm0, ymm15, ymm20); - ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*9 + cs_a*6)); + ymm19 = _mm256_fnmadd_ps(ymm0, ymm15, ymm19); - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*9 + cs_a*5)); + ymm3 = _mm256_fnmadd_ps(ymm0, ymm15, ymm3 ); - ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[4][1] B11[5][1] B11[4][3] B11[5][3] - ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[6][1] B11[7][1] B11[6][3] B11[7][3] + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*9 + cs_a*4)); + ymm2 = _mm256_fnmadd_ps(ymm0, ymm15, ymm2 ); - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*9 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm15, ymm18); - ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] - ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*9 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm15, ymm17); - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); //store B11[4][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); //store B11[5][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm6); //store B11[6][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 3 + 4), ymm7); //store B11[7][0-3] - n_remainder -=4; - } + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*9 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm15, ymm11 ); - if(n_remainder) 
//implementation fo remaining columns(when 'N' is not a multiple of d_nr)() n = 3 - { - a10 = D_A_pack; - a11 = L + (i*cs_a) + (i*rs_a); - b01 = B + i + d_mr; - b11 = B + i; + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*9)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm15, ymm10 ); - k_iter = (m - i - d_mr) ; + //perform mul operation + ymm14 = STRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS + //extract a88 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 7)); - if(3 == n_remainder) - { - ///GEMM code begins/// - BLIS_DTRSM_SMALL_GEMM_8mx3n(a10,b01,cs_b,p_lda,k_iter) + //(ROw8): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*8 + cs_a*7)); + ymm20 = _mm256_fnmadd_ps(ymm0, ymm14, ymm20); - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*8 + cs_a*6)); + ymm19 = _mm256_fnmadd_ps(ymm0, ymm14, ymm19); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*8 + cs_a*5)); + ymm3 = _mm256_fnmadd_ps(ymm0, ymm14, ymm3 ); - ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] - ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] - ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 4)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*8 + cs_a*4)); + ymm2 = _mm256_fnmadd_ps(ymm0, ymm14, ymm2 ); - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*8 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm14, ymm18); - ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] - ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] * alpha -= B01[0-3][5] - ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); //B11[0-3][6] * alpha -= B01[0-3][6] - ymm7 = _mm256_broadcast_sd((double const *)(&ones)); - } - else if(2 == n_remainder) - { - ///GEMM code begins/// - BLIS_DTRSM_SMALL_GEMM_8mx2n(a10,b01,cs_b,p_lda,k_iter) + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*8 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm14, ymm17); - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*8 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm14, ymm11 ); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*8)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm14, ymm10 ); - ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] - ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] + 
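/* Block geometry for this bottom-up (AltXB / AuXB) sweep: b11 points at the
 * rows of B currently being solved, b01 at the rows below it that have
 * already been solved, and k_iter counts those solved rows. The GEMM macros
 * subtract a10 * b01 (the contribution of the already-solved part) before
 * the diagonal block a11 is applied to b11.
 */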
//perform mul operation + ymm20 = STRSM_SMALL_DIV_OR_SCALE(ymm20, ymm1); - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] - ymm2 = _mm256_broadcast_sd((double const *)(&ones)); - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + //extract a99 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 6)); - ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] - ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] * alpha -= B01[0-3][5] - ymm6 = _mm256_broadcast_sd((double const *)(&ones)); - ymm7 = _mm256_broadcast_sd((double const *)(&ones)); - } - else if(1 == n_remainder) - { - ///GEMM code begins/// - BLIS_DTRSM_SMALL_GEMM_8mx1n(a10,b01,cs_b,p_lda,k_iter) + //(ROw9): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*7 + cs_a*6)); + ymm19 = _mm256_fnmadd_ps(ymm0, ymm20, ymm19); - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*7 + cs_a*5)); + ymm3 = _mm256_fnmadd_ps(ymm0, ymm20, ymm3 ); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*7 + cs_a*4)); + ymm2 = _mm256_fnmadd_ps(ymm0, ymm20, ymm2 ); - ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*7 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm20, ymm18); - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_broadcast_sd((double const *)(&ones)); - ymm2 = _mm256_broadcast_sd((double const *)(&ones)); - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*7 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm20, ymm17); - ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] - ymm5 = _mm256_broadcast_sd((double const *)(&ones)); - ymm6 = _mm256_broadcast_sd((double const *)(&ones)); - ymm7 = _mm256_broadcast_sd((double const *)(&ones)); - } - ///implement TRSM/// + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*7 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm20, ymm11 ); - ///transpose of B11// - ///unpacklow/// - ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*7)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm20, ymm10 ); - ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[2][4] B11[2][5] - ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[2][6] B11[2][7] + //perform mul operation + ymm19 = STRSM_SMALL_DIV_OR_SCALE(ymm19, ymm1); - //rearrange low elements - ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + //extract a10 10 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 5)); - ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3] - ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3] + //(ROw10): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*6 + cs_a*5)); + ymm3 = _mm256_fnmadd_ps(ymm0, ymm19, ymm3 ); - 
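/* FMA semantics used throughout these kernels (AVX2 intrinsics):
 *   _mm256_fnmadd_ps(a, b, c)  computes  c - a*b, i.e. it subtracts the
 *   solved row (b), scaled by a broadcast A11 element (a), from a
 *   not-yet-solved row (c) of the B tile held in registers;
 *   _mm256_fmsub_ps(a, b, c)   computes  a*b - c, used when refreshing
 *   B11 as alpha*B11 minus the GEMM accumulator.
 */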
////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*6 + cs_a*4)); + ymm2 = _mm256_fnmadd_ps(ymm0, ymm19, ymm2 ); - ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[1][4] B11[1][5] B11[3][4] B11[3][5] - ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[1][6] B11[1][7] B11[3][6] B11[3][7] + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*6 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm19, ymm18); - //rearrange high elements - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*6 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm19, ymm17); - ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3] - ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3] + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*6 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm19, ymm11 ); - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*6)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm19, ymm10 ); //perform mul operation - ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm1); - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); + //extract a11 11 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 4)); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6*cs_a + 7*rs_a)); - ymm3 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 7*rs_a)); - ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 7*rs_a)); - ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 7*rs_a)); - ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 7*rs_a)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 7*rs_a)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + 7*rs_a)); + //(ROw11): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*5 + cs_a*4)); + ymm2 = _mm256_fnmadd_ps(ymm0, ymm3, ymm2 ); - //(ROw7): FMA operations - ymm14 = _mm256_fnmadd_pd(ymm2, ymm15, ymm14); - ymm13 = _mm256_fnmadd_pd(ymm3, ymm15, ymm13); - ymm12 = _mm256_fnmadd_pd(ymm4, ymm15, ymm12); - ymm11 = _mm256_fnmadd_pd(ymm5, ymm15, ymm11); - ymm10 = _mm256_fnmadd_pd(ymm6, ymm15, ymm10); - ymm9 = _mm256_fnmadd_pd(ymm7, ymm15, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm16, ymm15, ymm8); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*5 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm3, ymm18); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*5 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm3, ymm17); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*5 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm3, ymm11 ); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*5)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm3, ymm10 ); //perform mul operation - ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); + ymm2 = STRSM_SMALL_DIV_OR_SCALE(ymm2, ymm1); - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + //extract a12 12 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 3)); - ymm3 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 6*rs_a)); - ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 
6*rs_a)); - ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 6*rs_a)); - ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 6*rs_a)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 6*rs_a)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + 6*rs_a)); + //(ROw12): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*4 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm2, ymm18); - //(ROw6): FMA operations - ymm13 = _mm256_fnmadd_pd(ymm3, ymm14, ymm13); - ymm12 = _mm256_fnmadd_pd(ymm4, ymm14, ymm12); - ymm11 = _mm256_fnmadd_pd(ymm5, ymm14, ymm11); - ymm10 = _mm256_fnmadd_pd(ymm6, ymm14, ymm10); - ymm9 = _mm256_fnmadd_pd(ymm7, ymm14, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm16, ymm14, ymm8); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*4 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm2, ymm17); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*4 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm2, ymm11 ); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*4)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm2, ymm10 ); //perform mul operation - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1); + ymm18 = STRSM_SMALL_DIV_OR_SCALE(ymm18, ymm1); - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + //extract a13 13 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); - ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 5*rs_a)); - ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 5*rs_a)); - ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 5*rs_a)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 5*rs_a)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); + //(ROw13): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*3 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm18, ymm17); - //(ROw5): FMA operations - ymm12 = _mm256_fnmadd_pd(ymm4, ymm13, ymm12); - ymm11 = _mm256_fnmadd_pd(ymm5, ymm13, ymm11); - ymm10 = _mm256_fnmadd_pd(ymm6, ymm13, ymm10); - ymm9 = _mm256_fnmadd_pd(ymm7, ymm13, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm16, ymm13, ymm8); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*3 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm18, ymm11 ); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*3)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm18, ymm10 ); //perform mul operation - ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); + ymm17 = STRSM_SMALL_DIV_OR_SCALE(ymm17, ymm1); - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + //extract a14 14 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); - ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 4*rs_a)); - ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 4*rs_a)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 4*rs_a)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); + //(ROw13): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*2 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm17, ymm11 ); - //(ROw4): FMA operations - ymm11 = _mm256_fnmadd_pd(ymm5, ymm12, ymm11); - ymm10 = _mm256_fnmadd_pd(ymm6, ymm12, ymm10); - ymm9 = _mm256_fnmadd_pd(ymm7, ymm12, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm16, ymm12, ymm8); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*2)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm17, ymm10 ); //perform mul operation - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + ymm11 = 
STRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3*rs_a)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3*rs_a)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + //extract a15 15 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 0)); - //(ROw3): FMA operations - ymm10 = _mm256_fnmadd_pd(ymm6, ymm11, ymm10); - ymm9 = _mm256_fnmadd_pd(ymm7, ymm11, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm16, ymm11, ymm8); + //(ROw15): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*1)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm11, ymm10 ); //perform mul operation - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + ymm10 = STRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + if(3 == n_remainder) + { + ymm0 = _mm256_unpacklo_ps(ymm10, ymm11); + ymm1 = _mm256_unpacklo_ps(ymm17, ymm18); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2*rs_a)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + ymm6 = _mm256_unpacklo_ps(ymm2, ymm3); + ymm7 = _mm256_unpacklo_ps(ymm19, ymm20); - //(ROw2): FMA operations - ymm9 = _mm256_fnmadd_pd(ymm7, ymm10, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm16, ymm10, ymm8); + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b01000100); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b01000100); - //perform mul operation - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//1 + _mm256_storeu_ps((float *)(b11), ymm16); - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b11101110); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b11101110); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//2 + _mm256_storeu_ps((float *)(b11 + cs_b), ymm16); - //(ROw2): FMA operations - ymm8 = _mm256_fnmadd_pd(ymm16, ymm9, ymm8); + ymm0 = _mm256_unpackhi_ps(ymm10, ymm11); + ymm1 = _mm256_unpackhi_ps(ymm17, ymm18); - //perform mul operation - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); + ymm6 = _mm256_unpackhi_ps(ymm2, ymm3); + ymm7 = _mm256_unpackhi_ps(ymm19, ymm20); - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b01000100); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b01000100); - ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] - ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//3 + _mm256_storeu_ps((float *)(b11 + 2*cs_b), ymm16); - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm0 = _mm256_unpacklo_ps(ymm14, ymm15); + ymm1 = _mm256_unpacklo_ps(ymm8, ymm9); - ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] + ymm6 = _mm256_unpacklo_ps(ymm21, ymm22); + ymm7 = _mm256_unpacklo_ps(ymm4, ymm5); - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + ymm12 = 
_mm256_shuffle_ps(ymm0,ymm1,0b01000100); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b01000100); - ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[4][1] B11[5][1] B11[4][3] B11[5][3] - ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[6][1] B11[7][1] B11[6][3] B11[7][3] + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//1 + _mm256_storeu_ps((float *)(b11 + 8), ymm16); - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b11101110); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b11101110); - ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] - ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//2 + _mm256_storeu_ps((float *)(b11 + cs_b + 8), ymm16); - if(3 == n_remainder) - { - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); //store B11[4][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); //store B11[5][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm6); //store B11[6][0-3] + ymm0 = _mm256_unpackhi_ps(ymm14, ymm15); + ymm1 = _mm256_unpackhi_ps(ymm8, ymm9); + + ymm6 = _mm256_unpackhi_ps(ymm21, ymm22); + ymm7 = _mm256_unpackhi_ps(ymm4, ymm5); + + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b01000100); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b01000100); + + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//3 + _mm256_storeu_ps((float *)(b11 + 2*cs_b + 8), ymm16); } else if(2 == n_remainder) { - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); //store B11[4][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); //store B11[5][0-3] + ymm0 = _mm256_unpacklo_ps(ymm10, ymm11); + ymm1 = _mm256_unpacklo_ps(ymm17, ymm18); + + ymm6 = _mm256_unpacklo_ps(ymm2, ymm3); + ymm7 = _mm256_unpacklo_ps(ymm19, ymm20); + + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b01000100); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b01000100); + + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//1 + _mm256_storeu_ps((float *)(b11), ymm16); + + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b11101110); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b11101110); + + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//2 + _mm256_storeu_ps((float *)(b11 + cs_b), ymm16); + + ymm0 = _mm256_unpacklo_ps(ymm14, ymm15); + ymm1 = _mm256_unpacklo_ps(ymm8, ymm9); + + ymm6 = _mm256_unpacklo_ps(ymm21, ymm22); + ymm7 = _mm256_unpacklo_ps(ymm4, ymm5); + + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b01000100); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b01000100); + + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//1 + _mm256_storeu_ps((float *)(b11 + 8), ymm16); + + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b11101110); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b11101110); + + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//2 + _mm256_storeu_ps((float *)(b11 + cs_b + 8), ymm16); } else if(1 == n_remainder) { - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); //store B11[4][0-3] + 
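/* The unpacklo/unpackhi + shuffle(0b01000100, 0b11101110) + permute2f128
 * sequences above implement the usual AVX 8x8 float transpose
 * (0b01000100 == _MM_SHUFFLE(1,0,1,0), 0b11101110 == _MM_SHUFFLE(3,2,3,2)).
 * A standalone sketch of that pattern, using hypothetical unit-stride row
 * pointers rather than the kernel's register naming:
 */
#include <immintrin.h>
#include <stddef.h>

static void transpose_8x8_ps_sketch(const float *in,  ptrdiff_t in_stride,
                                    float       *out, ptrdiff_t out_stride)
{
    __m256 r0 = _mm256_loadu_ps(in + 0*in_stride), r1 = _mm256_loadu_ps(in + 1*in_stride);
    __m256 r2 = _mm256_loadu_ps(in + 2*in_stride), r3 = _mm256_loadu_ps(in + 3*in_stride);
    __m256 r4 = _mm256_loadu_ps(in + 4*in_stride), r5 = _mm256_loadu_ps(in + 5*in_stride);
    __m256 r6 = _mm256_loadu_ps(in + 6*in_stride), r7 = _mm256_loadu_ps(in + 7*in_stride);

    /* interleave pairs of rows within each 128-bit lane */
    __m256 t0 = _mm256_unpacklo_ps(r0, r1), t1 = _mm256_unpackhi_ps(r0, r1);
    __m256 t2 = _mm256_unpacklo_ps(r2, r3), t3 = _mm256_unpackhi_ps(r2, r3);
    __m256 t4 = _mm256_unpacklo_ps(r4, r5), t5 = _mm256_unpackhi_ps(r4, r5);
    __m256 t6 = _mm256_unpacklo_ps(r6, r7), t7 = _mm256_unpackhi_ps(r6, r7);

    /* gather 4-element sub-columns inside each lane */
    __m256 s0 = _mm256_shuffle_ps(t0, t2, 0b01000100), s1 = _mm256_shuffle_ps(t0, t2, 0b11101110);
    __m256 s2 = _mm256_shuffle_ps(t1, t3, 0b01000100), s3 = _mm256_shuffle_ps(t1, t3, 0b11101110);
    __m256 s4 = _mm256_shuffle_ps(t4, t6, 0b01000100), s5 = _mm256_shuffle_ps(t4, t6, 0b11101110);
    __m256 s6 = _mm256_shuffle_ps(t5, t7, 0b01000100), s7 = _mm256_shuffle_ps(t5, t7, 0b11101110);

    /* swap 128-bit lanes to finish: 0x20 joins the low halves, 0x31 the high halves */
    _mm256_storeu_ps(out + 0*out_stride, _mm256_permute2f128_ps(s0, s4, 0x20));
    _mm256_storeu_ps(out + 1*out_stride, _mm256_permute2f128_ps(s1, s5, 0x20));
    _mm256_storeu_ps(out + 2*out_stride, _mm256_permute2f128_ps(s2, s6, 0x20));
    _mm256_storeu_ps(out + 3*out_stride, _mm256_permute2f128_ps(s3, s7, 0x20));
    _mm256_storeu_ps(out + 4*out_stride, _mm256_permute2f128_ps(s0, s4, 0x31));
    _mm256_storeu_ps(out + 5*out_stride, _mm256_permute2f128_ps(s1, s5, 0x31));
    _mm256_storeu_ps(out + 6*out_stride, _mm256_permute2f128_ps(s2, s6, 0x31));
    _mm256_storeu_ps(out + 7*out_stride, _mm256_permute2f128_ps(s3, s7, 0x31));
}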
ymm0 = _mm256_unpacklo_ps(ymm10, ymm11); + ymm1 = _mm256_unpacklo_ps(ymm17, ymm18); + + ymm6 = _mm256_unpacklo_ps(ymm2, ymm3); + ymm7 = _mm256_unpacklo_ps(ymm19, ymm20); + + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b01000100); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b01000100); + + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//1 + _mm256_storeu_ps((float *)(b11), ymm16); + + ymm0 = _mm256_unpacklo_ps(ymm14, ymm15); + ymm1 = _mm256_unpacklo_ps(ymm8, ymm9); + + ymm6 = _mm256_unpacklo_ps(ymm21, ymm22); + ymm7 = _mm256_unpacklo_ps(ymm4, ymm5); + + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b01000100); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b01000100); + + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//1 + _mm256_storeu_ps((float *)(b11 + 8), ymm16); } } + }// End of multiples of d_mr blocks in m-dimension - // Repetative A blocks will be 4*4 + // Repetative A blocks will be 8*8 dim_t m_remainder = i + d_mr; - if(m_remainder >= 4) + if(m_remainder >= 8) { - i = m_remainder - 4; - a10 = L + (i*cs_a) + (i + 4)*rs_a; //pointer to block of A to be used for GEMM + i = m_remainder - 8; + a10 = L + (i*cs_a) + (i + 8)*rs_a; //pointer to block of A to be used for GEMM a11 = L + (i*cs_a) + (i*rs_a); //pointer to block of A to be used for TRSM // Do transpose for a10 & store in D_A_pack - double *ptr_a10_dup = D_A_pack; - dim_t p_lda = 4; // packed leading dimension + float *ptr_a10_dup = D_A_pack; + dim_t p_lda = 8; // packed leading dimension if(transa) { - for(dim_t x =0;x < m-i+4;x+=p_lda) + for(dim_t x =0;x < m-i+8;x+=p_lda) { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm1 = _mm256_loadu_pd((double const *)(a10 + cs_a)); - ymm2 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); - ymm3 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); + ymm0 = _mm256_loadu_ps((float const *)(a10)); + ymm1 = _mm256_loadu_ps((float const *)(a10 + cs_a)); + ymm2 = _mm256_loadu_ps((float const *)(a10 + cs_a*2)); + ymm3 = _mm256_loadu_ps((float const *)(a10 + cs_a*3)); + ymm4 = _mm256_loadu_ps((float const *)(a10 + cs_a*4)); + ymm5 = _mm256_loadu_ps((float const *)(a10 + cs_a*5)); + ymm6 = _mm256_loadu_ps((float const *)(a10 + cs_a*6)); + ymm7 = _mm256_loadu_ps((float const *)(a10 + cs_a*7)); - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm8 = _mm256_unpacklo_ps(ymm0, ymm1); + ymm9 = _mm256_unpacklo_ps(ymm2, ymm3); - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + ymm10 = _mm256_unpacklo_ps(ymm4, ymm5); + ymm11 = _mm256_unpacklo_ps(ymm6, ymm7); - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + ymm12 = _mm256_shuffle_ps(ymm8,ymm9,0b01000100); + ymm13 = _mm256_shuffle_ps(ymm10,ymm11,0b01000100); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + ymm14 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//1 + ymm15 = _mm256_permute2f128_ps(ymm12,ymm13,0x31);//5 - _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + _mm256_storeu_ps((float *)(ptr_a10_dup), ymm14); + _mm256_storeu_ps((float *)(ptr_a10_dup + 4*p_lda), ymm15); + + ymm12 = _mm256_shuffle_ps(ymm8,ymm9,0b11101110); + ymm13 = _mm256_shuffle_ps(ymm10,ymm11,0b11101110); + + ymm14 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//2 + ymm15 = _mm256_permute2f128_ps(ymm12,ymm13,0x31);//6 + _mm256_storeu_ps((float *)(ptr_a10_dup + 
p_lda), ymm14); + _mm256_storeu_ps((float *)(ptr_a10_dup + 5*p_lda), ymm15); + + ymm8 = _mm256_unpackhi_ps(ymm0, ymm1); + ymm9 = _mm256_unpackhi_ps(ymm2, ymm3); + + ymm10 = _mm256_unpackhi_ps(ymm4, ymm5); + ymm11 = _mm256_unpackhi_ps(ymm6, ymm7); + + ymm12 = _mm256_shuffle_ps(ymm8,ymm9,0b01000100); + ymm13 = _mm256_shuffle_ps(ymm10,ymm11,0b01000100); + + ymm14 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//3 + ymm15 = _mm256_permute2f128_ps(ymm12,ymm13,0x31);//7 + _mm256_storeu_ps((float *)(ptr_a10_dup + 2*p_lda), ymm14); + _mm256_storeu_ps((float *)(ptr_a10_dup + 6*p_lda), ymm15); + + ymm12 = _mm256_shuffle_ps(ymm8,ymm9,0b11101110); + ymm13 = _mm256_shuffle_ps(ymm10,ymm11,0b11101110); + + ymm14 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//4 + ymm15 = _mm256_permute2f128_ps(ymm12,ymm13,0x31);//8 + _mm256_storeu_ps((float *)(ptr_a10_dup + 3*p_lda), ymm14); + _mm256_storeu_ps((float *)(ptr_a10_dup + 7*p_lda), ymm15); a10 += p_lda; ptr_a10_dup += p_lda*p_lda; @@ -7737,3116 +29344,3079 @@ BLIS_INLINE err_t bli_dtrsm_small_AltXB_AuXB } else { - for(dim_t x =0;x < m-i-4;x++) + for(dim_t x =0;x < m-i-8;x++) { - ymm0 = _mm256_loadu_pd((double const *)(a10 + x*rs_a)); - _mm256_storeu_pd((double *)(ptr_a10_dup + x*p_lda), ymm0); + ymm0 = _mm256_loadu_ps((float const *)(a10 + x*rs_a)); + _mm256_storeu_ps((float *)(ptr_a10_dup + x*p_lda), ymm0); } } - ymm4 = _mm256_broadcast_sd((double const *)&ones); + ymm13 = ymm14 = _mm256_broadcast_ss((float const *)&ones); if(!is_unitdiag) { if(transa) { //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_sd((double const *)(a11)); - ymm1 = _mm256_broadcast_sd((double const *)(a11+cs_a*1 + 1)); - ymm2 = _mm256_broadcast_sd((double const *)(a11+cs_a*2 + 2)); - ymm3 = _mm256_broadcast_sd((double const *)(a11+cs_a*3 + 3)); + ymm0 = _mm256_broadcast_ss((float const *)(a11)); + ymm1 = _mm256_broadcast_ss((float const *)(a11+cs_a*1 + 1)); + ymm2 = _mm256_broadcast_ss((float const *)(a11+cs_a*2 + 2)); + ymm3 = _mm256_broadcast_ss((float const *)(a11+cs_a*3 + 3)); + ymm4 = _mm256_broadcast_ss((float const *)(a11+cs_a*4 + 4)); + ymm5 = _mm256_broadcast_ss((float const *)(a11+cs_a*5 + 5)); + ymm6 = _mm256_broadcast_ss((float const *)(a11+cs_a*6 + 6)); + ymm7 = _mm256_broadcast_ss((float const *)(a11+cs_a*7 + 7)); } else { //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_sd((double const *)(a11)); - ymm1 = _mm256_broadcast_sd((double const *)(a11+rs_a*1 + 1)); - ymm2 = _mm256_broadcast_sd((double const *)(a11+rs_a*2 + 2)); - ymm3 = _mm256_broadcast_sd((double const *)(a11+rs_a*3 + 3)); + ymm0 = _mm256_broadcast_ss((float const *)(a11)); + ymm1 = _mm256_broadcast_ss((float const *)(a11+rs_a*1 + 1)); + ymm2 = _mm256_broadcast_ss((float const *)(a11+rs_a*2 + 2)); + ymm3 = _mm256_broadcast_ss((float const *)(a11+rs_a*3 + 3)); + ymm4 = _mm256_broadcast_ss((float const *)(a11+rs_a*4 + 4)); + ymm5 = _mm256_broadcast_ss((float const *)(a11+rs_a*5 + 5)); + ymm6 = _mm256_broadcast_ss((float const *)(a11+rs_a*6 + 6)); + ymm7 = _mm256_broadcast_ss((float const *)(a11+rs_a*7 + 7)); } - ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); - ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); + ymm8 = _mm256_unpacklo_ps(ymm0, ymm1); + ymm9 = _mm256_unpacklo_ps(ymm2, ymm3); + ymm10 = _mm256_blend_ps(ymm8, ymm9, 0xCC); + + ymm8 = _mm256_unpacklo_ps(ymm4, ymm5); + ymm9 = _mm256_unpacklo_ps(ymm6, ymm7); + ymm11 = _mm256_blend_ps(ymm8, ymm9, 0xCC); + + ymm12 = _mm256_blend_ps(ymm10, ymm11, 0xF0); + #ifdef BLIS_DISABLE_TRSM_PREINVERSION - ymm4 = ymm1; + 
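/* Sketch of the packing idea above (hypothetical helper, not the kernel's
 * own routine): the A panel feeding the GEMM macros is copied into the
 * contiguous scratch buffer D_A_pack with a fixed packed leading dimension
 * p_lda, transposed on the fly when A has to be read transposed, so the
 * GEMM inner loop only ever streams unit-stride packed data. The strides
 * rs/cs below are generic placeholders for the kernel's rs_a/cs_a, and the
 * packed layout is an illustrative choice, not necessarily the kernel's.
 */
#include <stdbool.h>

static void pack_panel_sketch(const float *a, dim_t rs, dim_t cs, bool trans,
                              float *pack, dim_t m_panel, dim_t p_lda)
{
    for (dim_t r = 0; r < m_panel; ++r)
        for (dim_t c = 0; c < p_lda; ++c)
            pack[r*p_lda + c] = trans ? a[c*rs + r*cs]   /* transposed read */
                                      : a[r*rs + c*cs];  /* direct read     */
}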
ymm14 = ymm12; #endif #ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm4 = _mm256_div_pd(ymm4, ymm1); + ymm14 = _mm256_div_ps(ymm13, ymm12); #endif } - _mm256_storeu_pd((double *)(d11_pack), ymm4); + _mm256_storeu_ps((float *)(d11_pack), ymm14); //cols for(j = (n - d_nr); (j + 1) > 0; j -= d_nr) //loop along 'N' dimension { a10 = D_A_pack; a11 = L + (i*cs_a) + (i*rs_a); //pointer to block of A to be used for TRSM - b01 = B + (j*cs_b) + i + 4; //pointer to block of B to be used for GEMM + b01 = B + (j*cs_b) + i + 8; //pointer to block of B to be used for GEMM b11 = B + (j* cs_b) + i; //pointer to block of B to be used for TRSM - k_iter = (m - i - 4); //number of times GEMM to be performed(in blocks of 4x4) + k_iter = (m - i - 8); //number of times GEMM to be performed(in blocks of 4x4) /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS + BLIS_SET_S_YMM_REG_ZEROS ///GEMM code begins/// - BLIS_DTRSM_SMALL_GEMM_4mx6n(a10,b01,cs_b,p_lda,k_iter) - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - - ///transpose of B11// - ///unpacklow/// - ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - - //rearrange low elements - ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] - - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - - //rearrange high elements - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); - - ymm16 = _mm256_broadcast_sd((double const *)(&ones)); - - ////unpacklow//// - ymm7 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - - //rearrange low elements - ymm4 = _mm256_permute2f128_pd(ymm7,ymm16,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm6 = _mm256_permute2f128_pd(ymm7,ymm16,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] - - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - - //rearrange high elements - ymm5 = _mm256_permute2f128_pd(ymm0,ymm16,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm7 = _mm256_permute2f128_pd(ymm0,ymm16,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //perform mul operation - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm1); - - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(ROw3): FMA 
operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3*rs_a)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm11, ymm10); - ymm6 = _mm256_fnmadd_pd(ymm2, ymm7, ymm6); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3*rs_a)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm11, ymm9); - ymm5 = _mm256_fnmadd_pd(ymm2, ymm7, ymm5); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm11, ymm8); - ymm4 = _mm256_fnmadd_pd(ymm2, ymm7, ymm4); - - //perform mul operation - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); - ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm1); - - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(ROw2): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2*rs_a)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm10, ymm9); - ymm5 = _mm256_fnmadd_pd(ymm2, ymm6, ymm5); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm10, ymm8); - ymm4 = _mm256_fnmadd_pd(ymm2, ymm6, ymm4); - - //perform mul operation - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm1); - - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); - - //(ROw2): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm9, ymm8); - ymm4 = _mm256_fnmadd_pd(ymm2, ymm5, ymm4); - - //perform mul operation - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); - ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm1); - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] + BLIS_STRSM_SMALL_GEMM_8mx6n(a10,b01,cs_b,p_lda,k_iter) - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + /* + Load b11 of size 6x8 and multiply with alpha + Add the GEMM output and perform inregister transose of b11 + to peform TRSM operation. 
+ */ + ymm16 = _mm256_broadcast_ss((float const *)(&AlphaVal)); - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm17 = _mm256_loadu_ps((float const *)(b11)); + ymm18 = _mm256_loadu_ps((float const *)(b11 + cs_b)); + ymm19 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); + ymm20 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); + ymm1 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); - ///unpack high/// - ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + ymm17 = _mm256_fmsub_ps(ymm17, ymm16, ymm8); + ymm18 = _mm256_fmsub_ps(ymm18, ymm16, ymm9); + ymm19 = _mm256_fmsub_ps(ymm19, ymm16, ymm10); + ymm20 = _mm256_fmsub_ps(ymm20, ymm16, ymm11); + ymm0 = _mm256_fmsub_ps(ymm0 , ymm16, ymm4); + ymm1 = _mm256_fmsub_ps(ymm1 , ymm16, ymm5); - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm8 = _mm256_unpacklo_ps(ymm17, ymm18); + ymm9 = _mm256_unpacklo_ps(ymm19, ymm20); - _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store B11[1][0-3] - } - dim_t n_remainder = j + d_nr; - if((n_remainder >= 4)) - { - a10 = D_A_pack; - a11 = L + (i*cs_a) + (i*rs_a); //pointer to block of A to be used for TRSM - b01 = B + ((n_remainder - 4)* cs_b) + i + 4; //pointer to block of B to be used for GEMM - b11 = B + ((n_remainder - 4)* cs_b) + i; //pointer to block of B to be used for TRSM + ymm16 = _mm256_unpacklo_ps(ymm0, ymm1); - k_iter = (m - i - 4); //number of times GEMM to be performed(in blocks of 4x4) + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b01000100); + ymm5 = _mm256_shuffle_ps(ymm16,ymm16,0b01000100); - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS + ymm10 = _mm256_permute2f128_ps(ymm4,ymm5,0x20);//1 + ymm2 = _mm256_permute2f128_ps(ymm4,ymm5,0x31);//5 - ///GEMM code begins/// - BLIS_DTRSM_SMALL_GEMM_4mx4n(a10,b01,cs_b,p_lda,k_iter) + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b11101110); + ymm5 = _mm256_shuffle_ps(ymm16,ymm16,0b11101110); - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + ymm11 = _mm256_permute2f128_ps(ymm4,ymm5,0x20);//2 + ymm3 = _mm256_permute2f128_ps(ymm4,ymm5,0x31);//6 - ///implement TRSM/// + ymm8 = _mm256_unpackhi_ps(ymm17, ymm18); + ymm9 = _mm256_unpackhi_ps(ymm19, ymm20); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + ymm16 = _mm256_unpackhi_ps(ymm0, ymm1); - ///transpose of B11// - ///unpacklow/// - ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b01000100); + ymm5 = _mm256_shuffle_ps(ymm16,ymm16,0b01000100); - //rearrange low elements - ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] 
B11[2][3] + ymm17 = _mm256_permute2f128_ps(ymm4,ymm5,0x20);//3 + ymm19 = _mm256_permute2f128_ps(ymm4,ymm5,0x31);//7 - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b11101110); + ymm5 = _mm256_shuffle_ps(ymm16,ymm16,0b11101110); - //rearrange high elements - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + ymm18 = _mm256_permute2f128_ps(ymm4,ymm5,0x20);//4 + ymm20 = _mm256_permute2f128_ps(ymm4,ymm5,0x31);//8 - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + //extract a88 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 7)); //perform mul operation - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + ymm20 = STRSM_SMALL_DIV_OR_SCALE(ymm20, ymm1); - //(ROw3): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3*rs_a)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm11, ymm10); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3*rs_a)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm11, ymm9); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm11, ymm8); + //extract a99 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 6)); - //perform mul operation - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + //(ROw9): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*7 + cs_a*6)); + ymm19 = _mm256_fnmadd_ps(ymm0, ymm20, ymm19); - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*7 + cs_a*5)); + ymm3 = _mm256_fnmadd_ps(ymm0, ymm20, ymm3 ); - //(ROw2): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2*rs_a)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm10, ymm9); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm10, ymm8); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*7 + cs_a*4)); + ymm2 = _mm256_fnmadd_ps(ymm0, ymm20, ymm2 ); - //perform mul operation - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*7 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm20, ymm18); - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*7 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm20, ymm17); - //(ROw2): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm9, ymm8); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*7 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm20, ymm11 ); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*7)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm20, ymm10 ); //perform mul operation - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); + ymm19 = STRSM_SMALL_DIV_OR_SCALE(ymm19, ymm1); - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + //extract a10 10 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 5)); - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] 
B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + //(ROw10): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*6 + cs_a*5)); + ymm3 = _mm256_fnmadd_ps(ymm0, ymm19, ymm3 ); - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*6 + cs_a*4)); + ymm2 = _mm256_fnmadd_ps(ymm0, ymm19, ymm2 ); - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*6 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm19, ymm18); - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] - n_remainder = n_remainder - 4; - } + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*6 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm19, ymm17); - if(n_remainder) //implementation fo remaining columns(when 'N' is not a multiple of d_nr)() n = 3 - { - a10 = D_A_pack; - a11 = L + (i*cs_a) + (i*rs_a); - b01 = B + i + 4; - b11 = B + i; + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*6 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm19, ymm11 ); - k_iter = (m - i - 4); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*6)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm19, ymm10 ); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); + //perform mul operation + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm1); - if(3 == n_remainder) - { - BLIS_DTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) + //extract a11 11 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 4)); - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + //(ROw11): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*5 + cs_a*4)); + ymm2 = _mm256_fnmadd_ps(ymm0, ymm3, ymm2 ); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*5 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm3, ymm18); - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); - } - else if(2 == n_remainder) - { - BLIS_DTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b,p_lda,k_iter) + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*5 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm3, ymm17); - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*5 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm3, ymm11 ); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] 
B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*5)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm3, ymm10 ); - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] - ymm2 = _mm256_broadcast_sd((double const *)(&ones)); - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); - } - else if(1 == n_remainder) - { - BLIS_DTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b,p_lda,k_iter) + //perform mul operation + ymm2 = STRSM_SMALL_DIV_OR_SCALE(ymm2, ymm1); - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + //extract a12 12 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 3)); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + //(ROw12): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*4 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm2, ymm18); - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_broadcast_sd((double const *)(&ones)); - ymm2 = _mm256_broadcast_sd((double const *)(&ones)); - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); - } + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*4 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm2, ymm17); - ///implement TRSM/// + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*4 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm2, ymm11 ); - ///transpose of B11// - ///unpacklow/// - ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*4)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm2, ymm10 ); - //rearrange low elements - ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + //perform mul operation + ymm18 = STRSM_SMALL_DIV_OR_SCALE(ymm18, ymm1); - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + //extract a13 13 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); - //rearrange high elements - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + //(ROw13): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*3 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm18, ymm17); - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*3 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm18, ymm11 ); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*3)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm18, ymm10 ); //perform mul operation - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + ymm17 = STRSM_SMALL_DIV_OR_SCALE(ymm17, ymm1); - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + //extract a14 14 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); - ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3*rs_a)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3*rs_a)); - 
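/* Within the diagonal block the tile of B held in ymm registers is solved
 * row by row starting from the last row: that row is scaled by its packed
 * diagonal entry from d11_pack, then removed from every row above it via
 * one broadcast of the corresponding A11 element and one FNMADD per row,
 * after which the next row up is scaled in turn.
 */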
ymm16 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + //(ROw13): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*2 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm17, ymm11 ); - //(ROw3): FMA operations - ymm10 = _mm256_fnmadd_pd(ymm6, ymm11, ymm10); - ymm9 = _mm256_fnmadd_pd(ymm7, ymm11, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm16, ymm11, ymm8); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*2)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm17, ymm10 ); //perform mul operation - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); - - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + ymm11 = STRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2*rs_a)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + //extract a15 15 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 0)); - //(ROw2): FMA operations - ymm9 = _mm256_fnmadd_pd(ymm7, ymm10, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm16, ymm10, ymm8); + //(ROw15): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*1)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm11, ymm10 ); //perform mul operation - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); + ymm10 = STRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); + ymm0 = _mm256_unpacklo_ps(ymm10, ymm11); + ymm1 = _mm256_unpacklo_ps(ymm17, ymm18); - //(ROw2): FMA operations - ymm8 = _mm256_fnmadd_pd(ymm16, ymm9, ymm8); + ymm6 = _mm256_unpacklo_ps(ymm2, ymm3); + ymm7 = _mm256_unpacklo_ps(ymm19, ymm20); - //perform mul operation - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b01000100); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b01000100); - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//1 + _mm256_storeu_ps((float *)(b11), ymm16); - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x31);//5 + _mm256_storeu_ps((float *)(b11 + 4*cs_b), ymm16); - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b11101110); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b11101110); - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//2 + _mm256_storeu_ps((float *)(b11 + cs_b), ymm16); - if(3 == n_remainder) - { - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] - } - else if(2 == n_remainder) - { - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - } - else if(1 == n_remainder) - { - _mm256_storeu_pd((double *)(b11 + cs_b 
* 0), ymm0); //store B11[0][0-3] - } - } - m_remainder -= 4; - } + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x31);//6 + _mm256_storeu_ps((float *)(b11 + 5*cs_b), ymm16); - a10 = L + m_remainder*rs_a; + ymm0 = _mm256_unpackhi_ps(ymm10, ymm11); + ymm1 = _mm256_unpackhi_ps(ymm17, ymm18); - // Do transpose for a10 & store in D_A_pack - double *ptr_a10_dup = D_A_pack; - if(3 == m_remainder) // Repetative A blocks will be 3*3 + ymm6 = _mm256_unpackhi_ps(ymm2, ymm3); + ymm7 = _mm256_unpackhi_ps(ymm19, ymm20); + + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b01000100); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b01000100); + + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//3 + _mm256_storeu_ps((float *)(b11 + 2*cs_b), ymm16); + + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b11101110); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b11101110); + + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//4 + _mm256_storeu_ps((float *)(b11 + 3*cs_b), ymm16); + } + + dim_t n_remainder = j + d_nr; + if((n_remainder >= 4)) { - dim_t p_lda = 4; // packed leading dimension - if(transa) - { - for(dim_t x =0;x < m-m_remainder;x+=p_lda) - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm1 = _mm256_loadu_pd((double const *)(a10 + cs_a)); - ymm2 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); - ymm3 = _mm256_broadcast_sd((double const *)&ones); + a10 = D_A_pack; + a11 = L + (i*cs_a) + (i*rs_a); //pointer to block of A to be used for TRSM + b01 = B + ((n_remainder - 4)* cs_b) + i + 8; //pointer to block of B to be used for GEMM + b11 = B + ((n_remainder - 4)* cs_b) + i; //pointer to block of B to be used for TRSM - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + k_iter = (m - i - 8); //number of times GEMM to be performed(in blocks of 4x4) - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + ///GEMM code begins/// + BLIS_STRSM_SMALL_GEMM_8mx4n(a10,b01,cs_b,p_lda,k_iter) - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + ymm16 = _mm256_broadcast_ss((float const *)(&AlphaVal)); //register to hold alpha - _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + ///implement TRSM/// + ymm17 = _mm256_loadu_ps((float const *)(b11)); + ymm18 = _mm256_loadu_ps((float const *)(b11 + cs_b)); + ymm19 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); + ymm20 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); - a10 += p_lda; - ptr_a10_dup += p_lda*p_lda; - } - } - else - { - for(dim_t x =0;x < m-m_remainder;x++) - { - ymm0 = _mm256_loadu_pd((double const *)(a10 + x*rs_a)); - _mm256_storeu_pd((double *)(ptr_a10_dup + x*p_lda), ymm0); - } - } + ymm17 = _mm256_fmsub_ps(ymm17, ymm16, ymm8); + ymm18 = _mm256_fmsub_ps(ymm18, ymm16, ymm9); + ymm19 = _mm256_fmsub_ps(ymm19, ymm16, ymm10); + ymm20 = _mm256_fmsub_ps(ymm20, ymm16, ymm11); - //cols - for(j = (n - d_nr); (j + 1) > 0; j -= d_nr) //loop along 'N' dimension - { - a10 = D_A_pack; - a11 = L; //pointer to block of A to be used for TRSM - b01 = B + (j* cs_b) + m_remainder; //pointer to block of B to be used for GEMM - b11 = B + (j* cs_b); //pointer to block of B to be used for TRSM + ymm8 = _mm256_unpacklo_ps(ymm17, ymm18); 
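/* Scalar model of the load/fmsub step above (sketch only, generic strides):
 * before the triangular solve, each element of the current B tile is
 * refreshed as alpha*B11 - (A10*B01), where the product term is the value
 * accumulated by the preceding GEMM macro. gemm_acc and its layout are
 * placeholders standing in for the accumulator registers.
 */
static void b11_refresh_sketch(float alpha, float *b11, const float *gemm_acc,
                               dim_t mb, dim_t nb, dim_t cs_b)
{
    for (dim_t j = 0; j < nb; ++j)
        for (dim_t i = 0; i < mb; ++i)
            b11[i + j*cs_b] = alpha * b11[i + j*cs_b] - gemm_acc[i + j*mb];
}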
+ ymm9 = _mm256_unpacklo_ps(ymm19, ymm20); - k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b01000100); - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS + ymm10 = _mm256_permute2f128_ps(ymm4,ymm4,0x20);//1 + ymm2 = _mm256_permute2f128_ps(ymm4,ymm4,0x31);//5 - ///GEMM code begins/// - BLIS_DTRSM_SMALL_GEMM_4mx6n(a10,b01,cs_b,p_lda,k_iter) + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b11101110); - ///GEMM code ends/// - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to store alpha value - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + ymm11 = _mm256_permute2f128_ps(ymm4,ymm4,0x20);//2 + ymm3 = _mm256_permute2f128_ps(ymm4,ymm4,0x31);//6 - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x08); + ymm8 = _mm256_unpackhi_ps(ymm17, ymm18); + ymm9 = _mm256_unpackhi_ps(ymm19, ymm20); - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b01000100); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); + ymm17 = _mm256_permute2f128_ps(ymm4,ymm4,0x20);//3 + ymm19 = _mm256_permute2f128_ps(ymm4,ymm4,0x31);//7 - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b11101110); - _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store(B11[0-3][3]) + ymm18 = _mm256_permute2f128_ps(ymm4,ymm4,0x20);//4 + ymm20 = _mm256_permute2f128_ps(ymm4,ymm4,0x31);//8 - if(transa) - dtrsm_AltXB_ref(a11, b11, m_remainder, 6, cs_a, cs_b, is_unitdiag); - else - dtrsm_AuXB_ref(a11, b11, m_remainder, 6, rs_a, cs_b, is_unitdiag); - } + //Implement TRSM - dim_t n_remainder = j + d_nr; - if((n_remainder >= 4)) - { - a10 = D_A_pack; - a11 = L; //pointer to block of A to be used for TRSM - b01 = B + ((n_remainder - 4)* cs_b) + m_remainder; //pointer to block of B to be used for GEMM - b11 = B + ((n_remainder - 4)* cs_b); //pointer to block of B to be used for TRSM + //extract a77 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 7)); - k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) + //perform mul operation + ymm20 = STRSM_SMALL_DIV_OR_SCALE(ymm20, ymm1); - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS + //extract a66 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 6)); - ///GEMM code begins/// - BLIS_DTRSM_SMALL_GEMM_4mx4n(a10,b01,cs_b,p_lda,k_iter) + //(ROw6): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*7 + cs_a*6)); + ymm19 = 
_mm256_fnmadd_ps(ymm0, ymm20, ymm19); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*7 + cs_a*5)); + ymm3 = _mm256_fnmadd_ps(ymm0, ymm20, ymm3 ); - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*7 + cs_a*4)); + ymm2 = _mm256_fnmadd_ps(ymm0, ymm20, ymm2 ); - ///implement TRSM/// + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*7 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm20, ymm18); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); - ymm3 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3 + 2)); - ymm3 = _mm256_insertf128_pd(ymm3, xmm5, 0); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*7 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm20, ymm17); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*7 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm20, ymm11 ); - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x08); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*7)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm20, ymm10 ); - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - xmm5 = _mm256_extractf128_pd(ymm3, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 3),xmm5); - _mm_storel_pd((b11 + cs_b * 3 + 2), _mm256_extractf128_pd(ymm3, 1)); + //perform mul operation + ymm19 = STRSM_SMALL_DIV_OR_SCALE(ymm19, ymm1); - if(transa) - dtrsm_AltXB_ref(a11, b11, m_remainder, 4, cs_a, cs_b, is_unitdiag); - else - dtrsm_AuXB_ref(a11, b11, m_remainder, 4, rs_a, cs_b, is_unitdiag); - n_remainder -= 4; - } + //extract a55 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 5)); - if(n_remainder) - { - a10 = D_A_pack; - a11 = L; //pointer to block of A to be used for TRSM - b01 = B + m_remainder; //pointer to block of B to be used for GEMM - b11 = B; //pointer to block of B to be used for TRSM + //(ROw5): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*6 + cs_a*5)); + ymm3 = _mm256_fnmadd_ps(ymm0, ymm19, ymm3 ); - k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*6 + cs_a*4)); + ymm2 = _mm256_fnmadd_ps(ymm0, ymm19, ymm2 ); - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*6 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm19, ymm18); - if(3 == n_remainder) - { - ///GEMM code begins/// - BLIS_DTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*6 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm19, ymm17); - BLIS_PRE_DTRSM_SMALL_3M_3N(AlphaVal,b11,cs_b) + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*6 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm19, ymm11 ); - if(transa) - dtrsm_AltXB_ref(a11, b11, m_remainder, 3, cs_a, cs_b, is_unitdiag); - else - dtrsm_AuXB_ref(a11, b11, 
m_remainder, 3, rs_a, cs_b, is_unitdiag); - } - else if(2 == n_remainder) - { - ///GEMM code begins/// - BLIS_DTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b,p_lda,k_iter) + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*6)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm19, ymm10 ); - BLIS_PRE_DTRSM_SMALL_3M_2N(AlphaVal,b11,cs_b) + //perform mul operation + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm1); - if(transa) - dtrsm_AltXB_ref(a11, b11, m_remainder, 2, cs_a, cs_b, is_unitdiag); - else - dtrsm_AuXB_ref(a11, b11, m_remainder, 2, rs_a, cs_b, is_unitdiag); - } - else if(1 == n_remainder) - { - ///GEMM code begins/// - BLIS_DTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b,p_lda,k_iter) + //extract a44 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 4)); - BLIS_PRE_DTRSM_SMALL_3M_1N(AlphaVal,b11,cs_b) + //(ROw4): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*5 + cs_a*4)); + ymm2 = _mm256_fnmadd_ps(ymm0, ymm3, ymm2 ); - if(transa) - dtrsm_AltXB_ref(a11, b11, m_remainder, 1, cs_a, cs_b, is_unitdiag); - else - dtrsm_AuXB_ref(a11, b11, m_remainder, 1, rs_a, cs_b, is_unitdiag); - } - } - } - else if(2 == m_remainder) // Repetative A blocks will be 2*2 - { - dim_t p_lda = 4; // packed leading dimension - if(transa) - { - for(dim_t x =0;x < m-m_remainder;x+=p_lda) - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm1 = _mm256_loadu_pd((double const *)(a10 + cs_a)); - ymm2 = _mm256_broadcast_sd((double const *)&ones); - ymm3 = _mm256_broadcast_sd((double const *)&ones); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*5 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm3, ymm18); - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*5 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm3, ymm17); - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*5 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm3, ymm11 ); - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*5)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm3, ymm10 ); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + //perform mul operation + ymm2 = STRSM_SMALL_DIV_OR_SCALE(ymm2, ymm1); - _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + //extract a33 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 3)); - a10 += p_lda; - ptr_a10_dup += p_lda*p_lda; - } - } - else - { - for(dim_t x =0;x < m-m_remainder;x++) - { - ymm0 = _mm256_loadu_pd((double const *)(a10 + x*rs_a)); - _mm256_storeu_pd((double *)(ptr_a10_dup + x*p_lda), ymm0); - } - } - //cols - for(j = (n - d_nr); (j + 1) > 0; j -= d_nr) //loop along 'N' dimension - { - a10 = D_A_pack; - a11 = L; //pointer to block of A to be used for TRSM - b01 = B + (j* cs_b) + m_remainder; //pointer to block of B to be used for GEMM - b11 = B + (j* cs_b); //pointer to block of B to be used for TRSM + //(ROw3): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*4 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm2, ymm18); - k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*4 + cs_a*2)); + ymm17 = 
_mm256_fnmadd_ps(ymm0, ymm2, ymm17); - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*4 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm2, ymm11 ); - ///GEMM code begins/// - BLIS_DTRSM_SMALL_GEMM_4mx6n(a10,b01,cs_b,p_lda,k_iter) + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*4)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm2, ymm10 ); - ///GEMM code ends/// - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to store alpha value - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + //perform mul operation + ymm18 = STRSM_SMALL_DIV_OR_SCALE(ymm18, ymm1); - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0C); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0C); + //extract a22 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) + //(ROw2): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*3 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm18, ymm17); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*3 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm18, ymm11 ); - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*3)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm18, ymm10 ); - _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store(B11[0-3][3]) + //perform mul operation + ymm17 = STRSM_SMALL_DIV_OR_SCALE(ymm17, ymm1); - if(transa) - dtrsm_AltXB_ref(a11, b11, m_remainder, 6, cs_a, cs_b, is_unitdiag); - else - dtrsm_AuXB_ref(a11, b11, m_remainder, 6, rs_a, cs_b, is_unitdiag); - } - dim_t n_remainder = j + d_nr; - if((n_remainder >= 4)) - { - a10 = D_A_pack; - a11 = L; //pointer to block of A to be used for TRSM - b01 = B + ((n_remainder - 4)* cs_b) + m_remainder; //pointer to block of B to be used for GEMM - b11 = B + ((n_remainder - 4)* cs_b); //pointer to block of B to be used for TRSM + //extract a11 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); - k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) + //(ROw1): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*2 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm17, ymm11 ); - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*2)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm17, ymm10 ); - ///GEMM code begins/// - BLIS_DTRSM_SMALL_GEMM_4mx4n(a10,b01,cs_b,p_lda,k_iter) + 
//perform mul operation + ymm11 = STRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + //extract a00 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 0)); - ///implement TRSM/// + //(ROw 0): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*1)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm11, ymm10 ); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); - ymm3 = _mm256_insertf128_pd(ymm3, xmm5, 0); + //perform mul operation + ymm10 = STRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + ymm0 = _mm256_unpacklo_ps(ymm10, ymm11); + ymm1 = _mm256_unpacklo_ps(ymm17, ymm18); - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0C); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0C); + ymm6 = _mm256_unpacklo_ps(ymm2, ymm3); + ymm7 = _mm256_unpacklo_ps(ymm19, ymm20); - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - xmm5 = _mm256_extractf128_pd(ymm3, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 3), xmm5); + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b01000100); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b01000100); - if(transa) - dtrsm_AltXB_ref(a11, b11, m_remainder, 4, cs_a, cs_b, is_unitdiag); - else - dtrsm_AuXB_ref(a11, b11, m_remainder, 4, rs_a, cs_b, is_unitdiag); - n_remainder -= 4; - } - if(n_remainder) - { - a10 = D_A_pack; - a11 = L; //pointer to block of A to be used for TRSM - b01 = B + m_remainder; //pointer to block of B to be used for GEMM - b11 = B; //pointer to block of B to be used for TRSM + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//1 + _mm256_storeu_ps((float *)(b11), ymm16); - k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b11101110); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b11101110); - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//2 + _mm256_storeu_ps((float *)(b11 + cs_b), ymm16); - if(3 == n_remainder) - { - ///GEMM code begins/// - BLIS_DTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) + ymm0 = _mm256_unpackhi_ps(ymm10, ymm11); + ymm1 = _mm256_unpackhi_ps(ymm17, ymm18); - BLIS_PRE_DTRSM_SMALL_2M_3N(AlphaVal,b11,cs_b) + ymm6 = _mm256_unpackhi_ps(ymm2, ymm3); + ymm7 = _mm256_unpackhi_ps(ymm19, ymm20); - if(transa) - dtrsm_AltXB_ref(a11, b11, m_remainder, 3, cs_a, cs_b, is_unitdiag); - else - dtrsm_AuXB_ref(a11, b11, m_remainder, 3, rs_a, cs_b, is_unitdiag); - } - else if(2 == n_remainder) - { - ///GEMM code begins/// - BLIS_DTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b,p_lda,k_iter) + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b01000100); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b01000100); - BLIS_PRE_DTRSM_SMALL_2M_2N(AlphaVal,b11,cs_b) + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//3 + _mm256_storeu_ps((float *)(b11 + 2*cs_b), ymm16); - if(transa) - dtrsm_AltXB_ref(a11, b11, m_remainder, 2, cs_a, cs_b, is_unitdiag); - else - dtrsm_AuXB_ref(a11, b11, 
m_remainder, 2, rs_a, cs_b, is_unitdiag); - } - else if(1 == n_remainder) - { - ///GEMM code begins/// - BLIS_DTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b,p_lda,k_iter) + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b11101110); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b11101110); - BLIS_PRE_DTRSM_SMALL_2M_1N(AlphaVal,b11,cs_b) - if(transa) - dtrsm_AltXB_ref(a11, b11, m_remainder, 1, cs_a, cs_b, is_unitdiag); - else - dtrsm_AuXB_ref(a11, b11, m_remainder, 1, rs_a, cs_b, is_unitdiag); - } - } + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//4 + _mm256_storeu_ps((float *)(b11 + 3*cs_b), ymm16); + n_remainder = n_remainder - 4; } - else if(1 == m_remainder) // Repetative A blocks will be 1*1 + + if(n_remainder) //implementation fo remaining columns(when 'N' is not a multiple of d_nr)() n = 3 { - dim_t p_lda = 4; // packed leading dimension - if(transa) - { - for(dim_t x =0;x < m-m_remainder;x+=p_lda) - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm1 = _mm256_broadcast_sd((double const *)&ones); - ymm2 = _mm256_broadcast_sd((double const *)&ones); - ymm3 = _mm256_broadcast_sd((double const *)&ones); + a10 = D_A_pack; + a11 = L + (i*cs_a) + (i*rs_a); + b01 = B + i + 8; + b11 = B + i; - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + k_iter = (m - i - 8); - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + BLIS_SET_S_YMM_REG_ZEROS - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + if(3 == n_remainder) + { + BLIS_STRSM_SMALL_GEMM_8mx3n(a10,b01,cs_b,p_lda,k_iter) - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + ymm16 = _mm256_broadcast_ss((float const *)(&AlphaVal)); //register to hold alpha - _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + ymm17 = _mm256_loadu_ps((float const *)(b11)); + ymm18 = _mm256_loadu_ps((float const *)(b11 + cs_b)); + ymm19 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); - a10 += p_lda; - ptr_a10_dup += p_lda*p_lda; - } + ymm17 = _mm256_fmsub_ps(ymm17, ymm16, ymm8); + ymm18 = _mm256_fmsub_ps(ymm18, ymm16, ymm9); + ymm19 = _mm256_fmsub_ps(ymm19, ymm16, ymm10); + ymm20 = _mm256_broadcast_ss((float const *)(&ones)); } - else + else if(2 == n_remainder) { - for(dim_t x =0;x < m-m_remainder;x++) - { - ymm0 = _mm256_loadu_pd((double const *)(a10 + x*rs_a)); - _mm256_storeu_pd((double *)(ptr_a10_dup + x*p_lda), ymm0); - } + BLIS_STRSM_SMALL_GEMM_8mx2n(a10,b01,cs_b,p_lda,k_iter) + + ymm16 = _mm256_broadcast_ss((float const *)(&AlphaVal)); //register to hold alpha + + ymm17 = _mm256_loadu_ps((float const *)(b11)); + ymm18 = _mm256_loadu_ps((float const *)(b11 + cs_b)); + + ymm17 = _mm256_fmsub_ps(ymm17, ymm16, ymm8); + ymm18 = _mm256_fmsub_ps(ymm18, ymm16, ymm9); + ymm19 = _mm256_broadcast_ss((float const *)(&ones)); + ymm20 = _mm256_broadcast_ss((float const *)(&ones)); } - //cols - for(j = (n - d_nr); (j + 1) > 0; j -= d_nr) //loop along 'N' dimension + else if(1 == n_remainder) { - a10 = D_A_pack; - a11 = L; //pointer to block of A to be used for TRSM - b01 = B + (j* cs_b) + m_remainder; //pointer to block of B to be used for GEMM - b11 = B + (j* cs_b); //pointer to block of B to be used for TRSM + BLIS_STRSM_SMALL_GEMM_8mx1n(a10,b01,cs_b,p_lda,k_iter) - k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) 
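The 3/2/1-column branches in this fringe path load only the B11 columns that exist and fill the unused accumulators with broadcast 1.0f, so the shared transpose and solve code further down can run unmodified. For orientation, the standalone sketch below (the value of n is invented for the example, and it is not part of the patch) walks the same column decomposition these kernels use: full d_nr-wide panels taken from the right, then a 4-wide panel, then the padded 3/2/1-column fringe.

    #include <stdio.h>
    typedef long dim_t;

    int main(void)
    {
        dim_t n = 23, d_nr = 6, j;          /* n chosen only for illustration        */

        /* full 6-wide panels, walked from the right as in the loops above */
        for (j = (n - d_nr); (j + 1) > 0; j -= d_nr)
            printf("6-wide panel at columns %ld..%ld\n", j, j + d_nr - 1);

        dim_t n_remainder = j + d_nr;       /* 0..5 columns left over                */
        if (n_remainder >= 4)
        {
            printf("4-wide panel at columns %ld..%ld\n",
                   n_remainder - 4, n_remainder - 1);
            n_remainder -= 4;
        }

        /* 3, 2 or 1 columns remain; the kernel pads the unused accumulators
           with 1.0f so that one transpose/solve path can be reused */
        if (n_remainder)
            printf("%ld-wide fringe panel at columns 0..%ld\n",
                   n_remainder, n_remainder - 1);
        return 0;
    }

With n = 23 this reports three 6-wide panels, one 4-wide panel and a 1-wide fringe, which together cover all 23 columns exactly once.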
+ ymm16 = _mm256_broadcast_ss((float const *)(&AlphaVal)); //register to hold alpha - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS + ymm17 = _mm256_loadu_ps((float const *)(b11)); - ///GEMM code begins/// - BLIS_DTRSM_SMALL_GEMM_4mx6n(a10,b01,cs_b,p_lda,k_iter) + ymm17 = _mm256_fmsub_ps(ymm17, ymm16, ymm8); + ymm18 = _mm256_broadcast_ss((float const *)(&ones)); + ymm19 = _mm256_broadcast_ss((float const *)(&ones)); + ymm20 = _mm256_broadcast_ss((float const *)(&ones)); + } - ///GEMM code ends/// - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to store alpha value - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + ymm8 = _mm256_unpacklo_ps(ymm17, ymm18); + ymm9 = _mm256_unpacklo_ps(ymm19, ymm20); - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0E); + ymm16 = _mm256_unpacklo_ps(ymm0, ymm1); - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b01000100); + ymm5 = _mm256_shuffle_ps(ymm16,ymm16,0b01000100); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); + ymm10 = _mm256_permute2f128_ps(ymm4,ymm5,0x20);//1 + ymm2 = _mm256_permute2f128_ps(ymm4,ymm5,0x31);//5 - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b11101110); + ymm5 = _mm256_shuffle_ps(ymm16,ymm16,0b11101110); - _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store(B11[0-3][3]) + ymm11 = _mm256_permute2f128_ps(ymm4,ymm5,0x20);//2 + ymm3 = _mm256_permute2f128_ps(ymm4,ymm5,0x31);//6 - if(transa) - dtrsm_AltXB_ref(a11, b11, m_remainder, 6, cs_a, cs_b, is_unitdiag); - else - dtrsm_AuXB_ref(a11, b11, m_remainder, 6, rs_a, cs_b, is_unitdiag); - } - dim_t n_remainder = j + d_nr; - if((n_remainder >= 4)) - { - a10 = D_A_pack; - a11 = L; //pointer to block of A to be used for TRSM - b01 = B + ((n_remainder - 4)* cs_b) + m_remainder; //pointer to block of B to be used for GEMM - b11 = B + ((n_remainder - 4)* cs_b); //pointer to block of B to be used for TRSM + ymm8 = _mm256_unpackhi_ps(ymm17, ymm18); + ymm9 = _mm256_unpackhi_ps(ymm19, ymm20); - k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) + ymm16 = _mm256_unpackhi_ps(ymm0, ymm1); - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b01000100); + ymm5 = _mm256_shuffle_ps(ymm16,ymm16,0b01000100); - ///GEMM code begins/// - BLIS_DTRSM_SMALL_GEMM_4mx4n(a10,b01,cs_b,p_lda,k_iter) + ymm17 = _mm256_permute2f128_ps(ymm4,ymm5,0x20);//3 + ymm19 = 
_mm256_permute2f128_ps(ymm4,ymm5,0x31);//7 - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b11101110); + ymm5 = _mm256_shuffle_ps(ymm16,ymm16,0b11101110); - ///implement TRSM/// + ymm18 = _mm256_permute2f128_ps(ymm4,ymm5,0x20);//4 + ymm20 = _mm256_permute2f128_ps(ymm4,ymm5,0x31);//8 - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + ///implement TRSM/// - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0E); + //extract a77 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 7)); - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) + //perform mul operation + ymm20 = STRSM_SMALL_DIV_OR_SCALE(ymm20, ymm1); - if(transa) - dtrsm_AltXB_ref(a11, b11, m_remainder, 4, cs_a, cs_b, is_unitdiag); - else - dtrsm_AuXB_ref(a11, b11, m_remainder, 4, rs_a, cs_b, is_unitdiag); - n_remainder -= 4; - } - if(n_remainder) - { - a10 = D_A_pack; - a11 = L; //pointer to block of A to be used for TRSM - b01 = B + m_remainder; //pointer to block of B to be used for GEMM - b11 = B; //pointer to block of B to be used for TRSM + //extract a66 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 6)); - k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) + //(ROw6): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*7 + cs_a*6)); + ymm19 = _mm256_fnmadd_ps(ymm0, ymm20, ymm19); - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*7 + cs_a*5)); + ymm3 = _mm256_fnmadd_ps(ymm0, ymm20, ymm3 ); - if(3 == n_remainder) - { - ///GEMM code begins/// - BLIS_DTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*7 + cs_a*4)); + ymm2 = _mm256_fnmadd_ps(ymm0, ymm20, ymm2 ); - BLIS_PRE_DTRSM_SMALL_1M_3N(AlphaVal,b11,cs_b) + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*7 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm20, ymm18); - if(transa) - dtrsm_AltXB_ref(a11, b11, m_remainder, 3, cs_a, cs_b, is_unitdiag); - else - dtrsm_AuXB_ref(a11, b11, m_remainder, 3, rs_a, cs_b, is_unitdiag); - } - else if(2 == n_remainder) - { - ///GEMM code begins/// - BLIS_DTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b,p_lda,k_iter) + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*7 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm20, ymm17); - BLIS_PRE_DTRSM_SMALL_1M_2N(AlphaVal,b11,cs_b) + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*7 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm20, ymm11 ); - if(transa) - dtrsm_AltXB_ref(a11, b11, m_remainder, 2, cs_a, cs_b, is_unitdiag); - else - dtrsm_AuXB_ref(a11, b11, m_remainder, 2, rs_a, cs_b, is_unitdiag); - } - else if(1 == n_remainder) - { - ///GEMM code begins/// - 
BLIS_DTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b,p_lda,k_iter) + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*7)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm20, ymm10 ); - BLIS_PRE_DTRSM_SMALL_1M_1N(AlphaVal,b11,cs_b) + //perform mul operation + ymm19 = STRSM_SMALL_DIV_OR_SCALE(ymm19, ymm1); - if(transa) - dtrsm_AltXB_ref(a11, b11, m_remainder, 1, cs_a, cs_b, is_unitdiag); - else - dtrsm_AuXB_ref(a11, b11, m_remainder, 1, rs_a, cs_b, is_unitdiag); - } - } - } + //extract a55 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 5)); - if ((required_packing_A == 1) && - bli_mem_is_alloc( &local_mem_buf_A_s )) - { - bli_membrk_release(&rntm,&local_mem_buf_A_s); - } - return BLIS_SUCCESS; -} + //(ROw5): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*6 + cs_a*5)); + ymm3 = _mm256_fnmadd_ps(ymm0, ymm19, ymm3 ); -/* TRSM for the Left Upper case AX = alpha * B, Double precision - * A is Left side, upper-triangular, transpose, non-unit/unit diagonal - * dimensions A: mxm X: mxn B: mxn - a10 ----> b11---> - *********** ***************** - * * * * *b01*b11* * * - **a10 * * a11 b11 * * * * * - ********* | | ***************** - *a11* * | | * * * * * - * * * | | * * * * * - ****** v v ***************** - * * * * * * * - * * * * * * * - * * ***************** - * - a11---> + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*6 + cs_a*4)); + ymm2 = _mm256_fnmadd_ps(ymm0, ymm19, ymm2 ); - * TRSM for the case AX = alpha * B, Double precision - * A is Left side, lower-triangular, no-transpose, non-unit/unit diagonal - * dimensions A: mxm X: mxn B: mxn + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*6 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm19, ymm18); - b01---> - * ***************** - ** * * * * * - * * * * * * * - * * *b01* * * * - * * * * * * * -a10 ****** b11 ***************** - | * * * | * * * * * - | * * * | * * * * * - | *a10*a11* | *b11* * * * - v * * * v * * * * * - *********** ***************** - * * * * * * * * * - * * * * * * * * * - * * * * * * * * * - * * * * * * * * * - **************** ***************** - a11---> -*/ -BLIS_INLINE err_t bli_dtrsm_small_AutXB_AlXB -( - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl -) -{ - dim_t m = bli_obj_length(b); // number of rows of matrix B - dim_t n = bli_obj_width(b); // number of columns of matrix B + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*6 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm19, ymm17); - bool transa = bli_obj_has_trans(a); - dim_t cs_a, rs_a; - dim_t d_mr = 8,d_nr = 6; + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*6 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm19, ymm11 ); - // Swap rs_a & cs_a in case of non-tranpose. 
- if(transa) - { - cs_a = bli_obj_col_stride(a); // column stride of A - rs_a = bli_obj_row_stride(a); // row stride of A - } - else - { - cs_a = bli_obj_row_stride(a); // row stride of A - rs_a = bli_obj_col_stride(a); // column stride of A - } - dim_t cs_b = bli_obj_col_stride(b); // column stride of B + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*6)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm19, ymm10 ); + + //perform mul operation + ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm1); + + //extract a44 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 4)); + + //(ROw4): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*5 + cs_a*4)); + ymm2 = _mm256_fnmadd_ps(ymm0, ymm3, ymm2 ); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*5 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm3, ymm18); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*5 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm3, ymm17); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*5 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm3, ymm11 ); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*5)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm3, ymm10 ); + + //perform mul operation + ymm2 = STRSM_SMALL_DIV_OR_SCALE(ymm2, ymm1); + + //extract a33 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 3)); + + //(ROw3): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*4 + cs_a*3)); + ymm18 = _mm256_fnmadd_ps(ymm0, ymm2, ymm18); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*4 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm2, ymm17); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*4 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm2, ymm11 ); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*4)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm2, ymm10 ); + + //perform mul operation + ymm18 = STRSM_SMALL_DIV_OR_SCALE(ymm18, ymm1); + + //extract a22 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); + + //(ROw2): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*3 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm18, ymm17); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*3 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm18, ymm11 ); - dim_t i, j, k; //loop variables - dim_t k_iter; //number of times GEMM to be performed + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*3)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm18, ymm10 ); - double AlphaVal = *(double *)AlphaObj->buffer; //value of alpha - double *L = a->buffer; //pointer to matrix A - double *B = b->buffer; //pointer to matrix B + //perform mul operation + ymm17 = STRSM_SMALL_DIV_OR_SCALE(ymm17, ymm1); - double *a10, *a11, *b01, *b11; //pointers that point to blocks for GEMM and TRSM + //extract a11 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); - double ones = 1.0; - bool is_unitdiag = bli_obj_has_unit_diag(a); + //(ROw1): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*2 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm17, ymm11 ); - //scratch registers - __m256d ymm0, ymm1, ymm2, ymm3; - __m256d ymm4, ymm5, ymm6, ymm7; - __m256d ymm8, ymm9, ymm10, ymm11; - __m256d ymm12, ymm13, ymm14, ymm15; - __m256d ymm16, ymm17, ymm18, ymm19; - __m256d ymm20; + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*2)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm17, ymm10 ); - __m128d xmm5; + //perform mul operation + ymm11 = STRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - gint_t required_packing_A = 1; - mem_t local_mem_buf_A_s = {0}; - 
double *D_A_pack = NULL; - double d11_pack[d_mr] __attribute__((aligned(64))); - rntm_t rntm; + //extract a00 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 0)); - bli_rntm_init_from_global( &rntm ); - bli_rntm_set_num_threads_only( 1, &rntm ); - bli_membrk_rntm_set_membrk( &rntm ); + //(ROw 0): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*1)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm11, ymm10 ); - siz_t buffer_size = bli_pool_block_size( - bli_membrk_pool( - bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), - bli_rntm_membrk(&rntm))); + //perform mul operation + ymm10 = STRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); - if ( (d_mr * m * sizeof(double)) > buffer_size) - return BLIS_NOT_YET_IMPLEMENTED; + if(3 == n_remainder) + { + ymm0 = _mm256_unpacklo_ps(ymm10, ymm11); + ymm1 = _mm256_unpacklo_ps(ymm17, ymm18); - if (required_packing_A == 1) - { - // Get the buffer from the pool. - bli_membrk_acquire_m(&rntm, - buffer_size, - BLIS_BITVAL_BUFFER_FOR_A_BLOCK, - &local_mem_buf_A_s); - if(FALSE==bli_mem_is_alloc(&local_mem_buf_A_s)) return BLIS_NULL_POINTER; - D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); - if(NULL==D_A_pack) return BLIS_NULL_POINTER; - } + ymm6 = _mm256_unpacklo_ps(ymm2, ymm3); + ymm7 = _mm256_unpacklo_ps(ymm19, ymm20); - /* - Performs solving TRSM for 8 colmns at a time from 0 to m/8 in steps of d_mr - a. Load, transpose, Pack A (a10 block), the size of packing 8x6 to 8x (m-8) - First there will be no GEMM and no packing of a10 because it is only TRSM - b. Using packed a10 block and b01 block perform GEMM operation - c. Use GEMM outputs, perform TRSM operaton using a11, b11 and update B - d. Repeat b,c for n rows of B in steps of d_nr - */ - for(i = 0;(i+d_mr-1) < m; i += d_mr) //loop along 'M' dimension - { - a10 = L + (i*cs_a); //pointer to block of A to be used for GEMM - a11 = L + (i*rs_a) + (i*cs_a); - dim_t p_lda = d_mr; // packed leading dimension + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b01000100); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b01000100); - if(transa) - { - /* - Load, tranpose and pack current A block (a10) into packed buffer memory D_A_pack - a. This a10 block is used in GEMM portion only and this - a10 block size will be increasing by d_mr for every next itteration - untill it reaches 8x(m-8) which is the maximum GEMM alone block size in A - b. This packed buffer is reused to calculate all n rows of B matrix - */ - bli_dtrsm_small_pack('L', i, 1, a10, cs_a, D_A_pack, p_lda,d_mr); + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//1 + _mm256_storeu_ps((float *)(b11), ymm16); - /* - Pack 8 diagonal elements of A block into an array - a. This helps in utilze cache line efficiently in TRSM operation - b. store ones when input is unit diagonal - */ - dtrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,d_mr); - } - else - { - bli_dtrsm_small_pack('L', i, 0, a10, rs_a, D_A_pack, p_lda,d_mr); - dtrsm_small_pack_diag_element(is_unitdiag,a11,rs_a,d11_pack,d_mr); - } + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b11101110); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b11101110); - /* - a. Perform GEMM using a10, b01. - b. Perform TRSM on a11, b11 - c. This loop GEMM+TRSM loops operates with 8x6 block size - along n dimension for every d_nr rows of b01 where - packed A buffer is reused in computing all n rows of B. - d. Same approch is used in remaining fringe cases. 
- */ - dim_t temp = n - d_nr + 1; - for(j = 0; j < temp; j += d_nr) //loop along 'N' dimension - { - a10 = D_A_pack; - a11 = L + (i*rs_a) + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//2 + _mm256_storeu_ps((float *)(b11 + cs_b), ymm16); - k_iter = i; + ymm0 = _mm256_unpackhi_ps(ymm10, ymm11); + ymm1 = _mm256_unpackhi_ps(ymm17, ymm18); - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS + ymm6 = _mm256_unpackhi_ps(ymm2, ymm3); + ymm7 = _mm256_unpackhi_ps(ymm19, ymm20); - /* - Peform GEMM between a10 and b01 blocks - For first itteration there will be no GEMM operation - where k_iter are zero - */ - BLIS_DTRSM_SMALL_GEMM_8mx6n(a10,b01,cs_b,p_lda,k_iter) + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b01000100); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b01000100); - /* - Load b11 of size 6x8 and multiply with alpha - Add the GEMM output and perform inregister transose of b11 - to peform TRSM operation. - */ - BLIS_DTRSM_SMALL_NREG_TRANSPOSE_6x8(b11,cs_b,AlphaVal) + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//3 + _mm256_storeu_ps((float *)(b11 + 2*cs_b), ymm16); + } + else if(2 == n_remainder) + { + ymm0 = _mm256_unpacklo_ps(ymm10, ymm11); + ymm1 = _mm256_unpacklo_ps(ymm17, ymm18); - /* - Compute 8x6 TRSM block by using GEMM block output in register - a. The 8x6 input (gemm outputs) are stored in combinations of ymm registers - 1. ymm8, ymm4 2. ymm9, ymm5 3. ymm10, ymm6, 4. ymm11, ymm7 - 5. ymm12, ymm17 6. ymm13,ymm18, 7. ymm14,ymm19 8. ymm15, ymm20 - where ymm8-ymm15 holds 8x4 data and reaming 8x2 will be hold by - other registers - b. 
Towards the end do in regiser transpose of TRSM output and store in b11 - */ - ////extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm6 = _mm256_unpacklo_ps(ymm2, ymm3); + ymm7 = _mm256_unpacklo_ps(ymm19, ymm20); - //perform mul operation - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); - ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm1); + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b01000100); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b01000100); - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//1 + _mm256_storeu_ps((float *)(b11), ymm16); - //(ROw1): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*1)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); - ymm5 = _mm256_fnmadd_pd(ymm2, ymm4, ymm5); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm8, ymm10); - ymm6 = _mm256_fnmadd_pd(ymm2, ymm4, ymm6); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); - ymm7 = _mm256_fnmadd_pd(ymm2, ymm4, ymm7); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); - ymm12 = _mm256_fnmadd_pd(ymm2, ymm8, ymm12); - ymm17 = _mm256_fnmadd_pd(ymm2, ymm4, ymm17); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); - ymm13 = _mm256_fnmadd_pd(ymm2, ymm8, ymm13); - ymm18 = _mm256_fnmadd_pd(ymm2, ymm4, ymm18); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); - ymm14 = _mm256_fnmadd_pd(ymm2, ymm8, ymm14); - ymm19 = _mm256_fnmadd_pd(ymm2, ymm4, ymm19); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); - ymm15 = _mm256_fnmadd_pd(ymm2, ymm8, ymm15); - ymm20 = _mm256_fnmadd_pd(ymm2, ymm4, ymm20); + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b11101110); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b11101110); + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//2 + _mm256_storeu_ps((float *)(b11 + cs_b), ymm16); + } + else if(1 == n_remainder) + { + ymm0 = _mm256_unpacklo_ps(ymm10, ymm11); + ymm1 = _mm256_unpacklo_ps(ymm17, ymm18); - //perform mul operation - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm1); + ymm6 = _mm256_unpacklo_ps(ymm2, ymm3); + ymm7 = _mm256_unpacklo_ps(ymm19, ymm20); - a11 += rs_a; + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b01000100); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b01000100); - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//1 + _mm256_storeu_ps((float *)(b11), ymm16); + } + } + m_remainder -= 8; + } - //(ROw2): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm9, ymm10); - ymm6 = _mm256_fnmadd_pd(ymm2, ymm5, ymm6); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm9, ymm11); - ymm7 = _mm256_fnmadd_pd(ymm2, ymm5, ymm7); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); - ymm12 = _mm256_fnmadd_pd(ymm2, ymm9, ymm12); - ymm17 = _mm256_fnmadd_pd(ymm2, ymm5, ymm17); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); - ymm13 = _mm256_fnmadd_pd(ymm2, ymm9, ymm13); - ymm18 = _mm256_fnmadd_pd(ymm2, ymm5, ymm18); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); - ymm14 = _mm256_fnmadd_pd(ymm2, ymm9, ymm14); - ymm19 = _mm256_fnmadd_pd(ymm2, ymm5, ymm19); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); - ymm15 = _mm256_fnmadd_pd(ymm2, ymm9, ymm15); - ymm20 = _mm256_fnmadd_pd(ymm2, ymm5, 
ymm20); + if(m_remainder >= 4) + { + i = m_remainder - 4; + a10 = L + (i*cs_a) + (i + 4)*rs_a; //pointer to block of A to be used for GEMM + a11 = L + (i*cs_a) + (i*rs_a); //pointer to block of A to be used for TRSM - //perform mul operation - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); - ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm1); + // Do transpose for a10 & store in D_A_pack + float *ptr_a10_dup = D_A_pack; + dim_t p_lda = 4; // packed leading dimension + __m128 xmm0,xmm1,xmm2,xmm3; + __m128 xmm4,xmm5; + __m128 xmm6,xmm7,xmm8,xmm9; + if(transa) + { + for(dim_t x =0;x < m-i+4;x+=p_lda) + { + xmm0 = _mm_loadu_ps((float const *)(a10)); + xmm1 = _mm_loadu_ps((float const *)(a10 + cs_a)); + xmm2 = _mm_loadu_ps((float const *)(a10 + cs_a * 2)); + xmm3 = _mm_loadu_ps((float const *)(a10 + cs_a * 3)); + + xmm4 = _mm_unpacklo_ps(xmm0, xmm1); + xmm5 = _mm_unpacklo_ps(xmm2, xmm3); + xmm6 = _mm_shuffle_ps(xmm4,xmm5,0x44); + xmm7 = _mm_shuffle_ps(xmm4,xmm5,0xEE); + + xmm0 = _mm_unpackhi_ps(xmm0, xmm1); + xmm1 = _mm_unpackhi_ps(xmm2, xmm3); + xmm8 = _mm_shuffle_ps(xmm0,xmm1,0x44); + xmm9 = _mm_shuffle_ps(xmm0,xmm1,0xEE); + + _mm_storeu_ps((float *)(ptr_a10_dup), xmm6); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda), xmm7); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda*2), xmm8); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda*3), xmm9); - a11 += rs_a; + a10 += p_lda; + ptr_a10_dup += p_lda*p_lda; + } + } + else + { + for(dim_t x =0;x < m-i-4;x++) + { + xmm4 = _mm_loadu_ps((float const *)(a10 + x*rs_a)); + _mm_storeu_ps((float *)(ptr_a10_dup + x*p_lda), xmm4); + } + } - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + xmm5 = xmm4 = _mm_broadcast_ss((float const *)&ones); + if(!is_unitdiag) + { + if(transa) + { + //broadcast diagonal elements of A11 + xmm0 = _mm_broadcast_ss((float const *)(a11)); + xmm1 = _mm_broadcast_ss((float const *)(a11+cs_a*1 + 1)); + xmm2 = _mm_broadcast_ss((float const *)(a11+cs_a*2 + 2)); + xmm3 = _mm_broadcast_ss((float const *)(a11+cs_a*3 + 3)); + } + else + { + //broadcast diagonal elements of A11 + xmm0 = _mm_broadcast_ss((float const *)(a11)); + xmm1 = _mm_broadcast_ss((float const *)(a11+rs_a*1 + 1)); + xmm2 = _mm_broadcast_ss((float const *)(a11+rs_a*2 + 2)); + xmm3 = _mm_broadcast_ss((float const *)(a11+rs_a*3 + 3)); + } - //(ROw5): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm10, ymm11); - ymm7 = _mm256_fnmadd_pd(ymm2, ymm6, ymm7); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); - ymm12 = _mm256_fnmadd_pd(ymm2, ymm10, ymm12); - ymm17 = _mm256_fnmadd_pd(ymm2, ymm6, ymm17); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); - ymm13 = _mm256_fnmadd_pd(ymm2, ymm10, ymm13); - ymm18 = _mm256_fnmadd_pd(ymm2, ymm6, ymm18); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); - ymm14 = _mm256_fnmadd_pd(ymm2, ymm10, ymm14); - ymm19 = _mm256_fnmadd_pd(ymm2, ymm6, ymm19); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); - ymm15 = _mm256_fnmadd_pd(ymm2, ymm10, ymm15); - ymm20 = _mm256_fnmadd_pd(ymm2, ymm6, ymm20); + xmm0 = _mm_unpacklo_ps(xmm0, xmm1); + xmm1 = _mm_unpacklo_ps(xmm2, xmm3); + xmm2 = _mm_blend_ps(xmm0, xmm1, 0x0C); + + #ifdef BLIS_DISABLE_TRSM_PREINVERSION + xmm4 = xmm2; + #endif + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + xmm4 = _mm_div_ps(xmm5, xmm2); + #endif + } + _mm_storeu_ps((float *)(d11_pack), xmm4); + + //cols + for(j = (n - d_nr); (j + 1) > 0; j -= d_nr) //loop along 'N' dimension + { + a10 = D_A_pack; + a11 = L + 
(i*cs_a) + (i*rs_a); //pointer to block of A to be used for TRSM + b01 = B + (j*cs_b) + i + 4; //pointer to block of B to be used for GEMM + b11 = B + (j* cs_b) + i; //pointer to block of B to be used for TRSM + + k_iter = (m - i - 4); //number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS - //perform mul operation - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm1); + ///GEMM code begins/// + BLIS_STRSM_SMALL_GEMM_4mx6n(a10,b01,cs_b,p_lda,k_iter) - a11 += rs_a; + ymm16 = _mm256_broadcast_ss((float const *)(&AlphaVal)); - //extract a44 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); - //(ROw4): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); - ymm12 = _mm256_fnmadd_pd(ymm2, ymm11, ymm12); - ymm17 = _mm256_fnmadd_pd(ymm2, ymm7, ymm17); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); - ymm13 = _mm256_fnmadd_pd(ymm2, ymm11, ymm13); - ymm18 = _mm256_fnmadd_pd(ymm2, ymm7, ymm18); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); - ymm14 = _mm256_fnmadd_pd(ymm2, ymm11, ymm14); - ymm19 = _mm256_fnmadd_pd(ymm2, ymm7, ymm19); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); - ymm15 = _mm256_fnmadd_pd(ymm2, ymm11, ymm15); - ymm20 = _mm256_fnmadd_pd(ymm2, ymm7, ymm20); + ymm17 = _mm256_insertf128_ps(ymm16, _mm_loadu_ps((float const*)(b11)), 0); + ymm18 = _mm256_insertf128_ps(ymm16, _mm_loadu_ps((float const*)(b11 + cs_b)), 0); + ymm19 = _mm256_insertf128_ps(ymm16, _mm_loadu_ps((float const*)(b11 + cs_b*2)), 0); + ymm20 = _mm256_insertf128_ps(ymm16, _mm_loadu_ps((float const*)(b11 + cs_b*3)), 0); + ymm0 = _mm256_insertf128_ps(ymm16, _mm_loadu_ps((float const*)(b11 + cs_b*4)), 0); + ymm1 = _mm256_insertf128_ps(ymm16, _mm_loadu_ps((float const*)(b11 + cs_b*5)), 0); - //perform mul operation - ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); - ymm17 = DTRSM_SMALL_DIV_OR_SCALE(ymm17, ymm1); + ymm17 = _mm256_fmsub_ps(ymm17, ymm16, ymm8); + ymm18 = _mm256_fmsub_ps(ymm18, ymm16, ymm9); + ymm19 = _mm256_fmsub_ps(ymm19, ymm16, ymm10); + ymm20 = _mm256_fmsub_ps(ymm20, ymm16, ymm11); + ymm0 = _mm256_fmsub_ps(ymm0 , ymm16, ymm4); + ymm1 = _mm256_fmsub_ps(ymm1 , ymm16, ymm5); - a11 += rs_a; + ymm8 = _mm256_unpacklo_ps(ymm17, ymm18); + ymm9 = _mm256_unpacklo_ps(ymm19, ymm20); - //extract a55 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + ymm16 = _mm256_unpacklo_ps(ymm0, ymm1); - //(ROw5): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); - ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); - ymm18 = _mm256_fnmadd_pd(ymm2, ymm17, ymm18); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); - ymm14 = _mm256_fnmadd_pd(ymm2, ymm12, ymm14); - ymm19 = _mm256_fnmadd_pd(ymm2, ymm17, ymm19); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); - ymm15 = _mm256_fnmadd_pd(ymm2, ymm12, ymm15); - ymm20 = _mm256_fnmadd_pd(ymm2, ymm17, ymm20); + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b01000100); + ymm5 = _mm256_shuffle_ps(ymm16,ymm16,0b01000100); - //perform mul operation - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1); - ymm18 = DTRSM_SMALL_DIV_OR_SCALE(ymm18, ymm1); + ymm10 = _mm256_permute2f128_ps(ymm4,ymm5,0x20);//1 - a11 += rs_a; + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b11101110); + ymm5 = _mm256_shuffle_ps(ymm16,ymm16,0b11101110); - //extract a66 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); + ymm11 = _mm256_permute2f128_ps(ymm4,ymm5,0x20);//2 
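The d11_pack buffer prepared just before this column loop holds the diagonal of a11 either unchanged (BLIS_DISABLE_TRSM_PREINVERSION) or already inverted via _mm_div_ps (BLIS_ENABLE_TRSM_PREINVERSION), and the STRSM_SMALL_DIV_OR_SCALE steps that follow apply it to each solved row. The macro body is not shown in this hunk, so the scalar sketch below only illustrates the presumed intent, under the assumption that the scale flavour multiplies by the preinverted value while the div flavour divides by the packed diagonal; unit-diagonal inputs pack 1.0f and are unaffected either way. The file name, toggle and data values are invented for the example.

    #include <stdio.h>

    /* toggle to mimic BLIS_ENABLE_TRSM_PREINVERSION / BLIS_DISABLE_TRSM_PREINVERSION */
    #define EXAMPLE_PREINVERSION 1

    int main(void)
    {
        float diag[4]     = { 2.0f, 4.0f, 8.0f, 16.0f };  /* diagonal of a 4x4 A11 (example) */
        float b[4]        = { 6.0f, 8.0f, 24.0f, 32.0f }; /* one solved row of B (example)   */
        float d11_pack[4];

        /* packing step, mirroring the #ifdef block above */
        for (int k = 0; k < 4; k++)
        {
    #if EXAMPLE_PREINVERSION
            d11_pack[k] = 1.0f / diag[k];   /* store reciprocals once            */
    #else
            d11_pack[k] = diag[k];          /* store the diagonal unchanged      */
    #endif
        }

        /* presumed DIV_OR_SCALE step: multiply when preinverted, divide otherwise */
        for (int k = 0; k < 4; k++)
        {
    #if EXAMPLE_PREINVERSION
            b[k] = b[k] * d11_pack[k];
    #else
            b[k] = b[k] / d11_pack[k];
    #endif
        }

        printf("%g %g %g %g\n", b[0], b[1], b[2], b[3]);  /* 3 2 3 2 in both configurations */
        return 0;
    }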
- //(ROw6): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); - ymm14 = _mm256_fnmadd_pd(ymm2, ymm13, ymm14); - ymm19 = _mm256_fnmadd_pd(ymm2, ymm18, ymm19); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); - ymm15 = _mm256_fnmadd_pd(ymm2, ymm13, ymm15); - ymm20 = _mm256_fnmadd_pd(ymm2, ymm18, ymm20); + ymm8 = _mm256_unpackhi_ps(ymm17, ymm18); + ymm9 = _mm256_unpackhi_ps(ymm19, ymm20); - //perform mul operation - ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); - ymm19 = DTRSM_SMALL_DIV_OR_SCALE(ymm19, ymm1); + ymm16 = _mm256_unpackhi_ps(ymm0, ymm1); - a11 += rs_a; + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b01000100); + ymm5 = _mm256_shuffle_ps(ymm16,ymm16,0b01000100); - //extract a77 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); + ymm17 = _mm256_permute2f128_ps(ymm4,ymm5,0x20);//3 - //(ROw7): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); - ymm15 = _mm256_fnmadd_pd(ymm2, ymm14, ymm15); - ymm20 = _mm256_fnmadd_pd(ymm2, ymm19, ymm20); + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b11101110); + ymm5 = _mm256_shuffle_ps(ymm16,ymm16,0b11101110); - //perform mul operation - ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); - ymm20 = DTRSM_SMALL_DIV_OR_SCALE(ymm20, ymm1); + ymm18 = _mm256_permute2f128_ps(ymm4,ymm5,0x20);//4 - a11 += rs_a; + // Implement TRSM - BLIS_DTRSM_SMALL_NREG_TRANSPOSE_8x6_AND_STORE(b11,cs_b) - } + //extract a33 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 3)); - dim_t n_rem = n-j; - if(n_rem >= 4) - { - a10 = D_A_pack; - a11 = L + (i*rs_a) + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + //perform mul operation + ymm18 = STRSM_SMALL_DIV_OR_SCALE(ymm18, ymm1); - k_iter = i ; //number of times GEMM to be performed(in blocks of 4x4) + //extract a22 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS + //(ROw2): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*3 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm18, ymm17); - ///GEMM code begins/// - BLIS_DTRSM_SMALL_GEMM_8mx4n(a10,b01,cs_b,p_lda,k_iter) + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*3 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm18, ymm11 ); - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*3)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm18, ymm10 ); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] - ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] - ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 4)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] - ymm7 = _mm256_loadu_pd((double const *)(b11 + cs_b *3 + 4)); //B11[0][7] B11[1][7] B11[2][7] B11[3][7] + //perform mul operation + ymm17 = STRSM_SMALL_DIV_OR_SCALE(ymm17, ymm1); - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); 
//B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] * alpha -= B01[0-3][3] - ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] - ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] * alpha -= B01[0-3][5] - ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); //B11[0-3][6] * alpha -= B01[0-3][6] - ymm7 = _mm256_fmsub_pd(ymm7, ymm16, ymm15); //B11[0-3][7] * alpha -= B01[0-3][7] + //extract a11 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); - ///implement TRSM/// + //(ROw1): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*2 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm17, ymm11 ); - ///transpose of B11// - ///unpacklow/// - ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*2)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm17, ymm10 ); - ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[2][4] B11[2][5] - ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[2][6] B11[2][7] + //perform mul operation + ymm11 = STRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - //rearrange low elements - ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + //extract a00 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 0)); - ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3] - ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3] + //(ROw0): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*1)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm11, ymm10 ); - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + //perform mul operation + ymm10 = STRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); - ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[1][4] B11[1][5] B11[3][4] B11[3][5] - ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[1][6] B11[1][7] B11[3][6] B11[3][7] + ymm0 = _mm256_unpacklo_ps(ymm10, ymm11); + ymm1 = _mm256_unpacklo_ps(ymm17, ymm18); - //rearrange high elements - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b01000100); - ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3] - ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3] + ymm16 = _mm256_permute2f128_ps(ymm12,ymm12,0x20);//1 + _mm_storeu_ps((float *)(b11), _mm256_extractf128_ps(ymm16, 0)); - ymm0 = _mm256_broadcast_sd((double const *)&ones); + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x31);//5 + _mm_storeu_ps((float *)(b11 + 4*cs_b), _mm256_extractf128_ps(ymm16, 0)); - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b11101110); + ymm13 = _mm256_shuffle_ps(ymm6,ymm7,0b11101110); - //perform mul operation - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); + 
ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x20);//2 + _mm_storeu_ps((float *)(b11 + cs_b), _mm256_extractf128_ps(ymm16, 0)); - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + ymm16 = _mm256_permute2f128_ps(ymm12,ymm13,0x31);//6 + _mm_storeu_ps((float *)(b11 + 5*cs_b), _mm256_extractf128_ps(ymm16, 0)); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*1)); - ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); - ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); - ymm5 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); - ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + ymm0 = _mm256_unpackhi_ps(ymm10, ymm11); + ymm1 = _mm256_unpackhi_ps(ymm17, ymm18); - a11 += rs_a; + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b01000100); - //(ROw1): FMA operations - ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); - ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); - ymm11 = _mm256_fnmadd_pd(ymm4, ymm8, ymm11); - ymm12 = _mm256_fnmadd_pd(ymm5, ymm8, ymm12); - ymm13 = _mm256_fnmadd_pd(ymm6, ymm8, ymm13); - ymm14 = _mm256_fnmadd_pd(ymm7, ymm8, ymm14); - ymm15 = _mm256_fnmadd_pd(ymm16, ymm8, ymm15); + ymm16 = _mm256_permute2f128_ps(ymm12,ymm12,0x20);//3 + _mm_storeu_ps((float *)(b11 + 2*cs_b), _mm256_extractf128_ps(ymm16, 0)); - //perform mul operation - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b11101110); - ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); - ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); - ymm5 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); - ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + ymm16 = _mm256_permute2f128_ps(ymm12,ymm12,0x20);//4 + _mm_storeu_ps((float *)(b11 + 3*cs_b), _mm256_extractf128_ps(ymm16, 0)); + } - a11 += rs_a; + dim_t n_remainder = j + d_nr; + if((n_remainder >= 4)) + { + a10 = D_A_pack; + a11 = L + (i*cs_a) + (i*rs_a); //pointer to block of A to be used for TRSM + b01 = B + ((n_remainder - 4)* cs_b) + i + 4; //pointer to block of B to be used for GEMM + b11 = B + ((n_remainder - 4)* cs_b) + i; //pointer to block of B to be used for TRSM - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + k_iter = (m - i - 4); //number of times GEMM to be performed(in blocks of 4x4) - //(ROw2): FMA operations - ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); - ymm11 = _mm256_fnmadd_pd(ymm4, ymm9, ymm11); - ymm12 = _mm256_fnmadd_pd(ymm5, ymm9, ymm12); - ymm13 = _mm256_fnmadd_pd(ymm6, ymm9, ymm13); - ymm14 = _mm256_fnmadd_pd(ymm7, ymm9, ymm14); - ymm15 = _mm256_fnmadd_pd(ymm16, ymm9, ymm15); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS - //perform mul operation - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + ///GEMM code begins/// + BLIS_STRSM_SMALL_GEMM_4mx4n(a10,b01,cs_b,p_lda,k_iter) - ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); - ymm5 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); - ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + ymm16 = _mm256_broadcast_ss((float const *)(&AlphaVal)); - a11 += rs_a; + ymm17 = _mm256_insertf128_ps(ymm16, _mm_loadu_ps((float 
const*)(b11)), 0); + ymm18 = _mm256_insertf128_ps(ymm16, _mm_loadu_ps((float const*)(b11 + cs_b)), 0); + ymm19 = _mm256_insertf128_ps(ymm16, _mm_loadu_ps((float const*)(b11 + cs_b*2)), 0); + ymm20 = _mm256_insertf128_ps(ymm16, _mm_loadu_ps((float const*)(b11 + cs_b*3)), 0); - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + ymm17 = _mm256_fmsub_ps(ymm17, ymm16, ymm8); + ymm18 = _mm256_fmsub_ps(ymm18, ymm16, ymm9); + ymm19 = _mm256_fmsub_ps(ymm19, ymm16, ymm10); + ymm20 = _mm256_fmsub_ps(ymm20, ymm16, ymm11); - //(ROw5): FMA operations - ymm11 = _mm256_fnmadd_pd(ymm4, ymm10, ymm11); - ymm12 = _mm256_fnmadd_pd(ymm5, ymm10, ymm12); - ymm13 = _mm256_fnmadd_pd(ymm6, ymm10, ymm13); - ymm14 = _mm256_fnmadd_pd(ymm7, ymm10, ymm14); - ymm15 = _mm256_fnmadd_pd(ymm16, ymm10, ymm15); + ymm8 = _mm256_unpacklo_ps(ymm17, ymm18); + ymm9 = _mm256_unpacklo_ps(ymm19, ymm20); - //perform mul operation - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b01000100); - ymm0 = _mm256_broadcast_sd((double const *)&ones); + ymm10 = _mm256_permute2f128_ps(ymm4,ymm4,0x20);//1 - //extract a44 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b11101110); - ymm5 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); - ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + ymm11 = _mm256_permute2f128_ps(ymm4,ymm4,0x20);//2 - a11 += rs_a; + ymm8 = _mm256_unpackhi_ps(ymm17, ymm18); + ymm9 = _mm256_unpackhi_ps(ymm19, ymm20); - //(ROw4): FMA operations - ymm12 = _mm256_fnmadd_pd(ymm5, ymm11, ymm12); - ymm13 = _mm256_fnmadd_pd(ymm6, ymm11, ymm13); - ymm14 = _mm256_fnmadd_pd(ymm7, ymm11, ymm14); - ymm15 = _mm256_fnmadd_pd(ymm16, ymm11, ymm15); + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b01000100); - //perform mul operation - ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); + ymm17 = _mm256_permute2f128_ps(ymm4,ymm4,0x20);//3 - ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b11101110); - a11 += rs_a; + ymm18 = _mm256_permute2f128_ps(ymm4,ymm4,0x20);//4 - //extract a55 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + // Implement TRSM - //(ROw5): FMA operations - ymm13 = _mm256_fnmadd_pd(ymm6, ymm12, ymm13); - ymm14 = _mm256_fnmadd_pd(ymm7, ymm12, ymm14); - ymm15 = _mm256_fnmadd_pd(ymm16, ymm12, ymm15); + //extract a33 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 3)); //perform mul operation - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1); + ymm18 = STRSM_SMALL_DIV_OR_SCALE(ymm18, ymm1); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 +cs_a*7)); + //extract a22 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); - a11 += rs_a; + //(ROw2): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*3 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm18, ymm17); - //extract a66 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*3 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm18, ymm11 ); - //(ROw6): FMA operations - ymm14 = _mm256_fnmadd_pd(ymm7, ymm13, ymm14); - ymm15 = _mm256_fnmadd_pd(ymm16, ymm13, ymm15); + ymm0 = _mm256_broadcast_ss((float const 
*)(a11 + rs_a*3)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm18, ymm10 ); //perform mul operation - ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); + ymm17 = STRSM_SMALL_DIV_OR_SCALE(ymm17, ymm1); - //extract a77 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); + //extract a11 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + //(ROw1): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*2 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm17, ymm11 ); - a11 += rs_a; - //(ROw7): FMA operations - ymm15 = _mm256_fnmadd_pd(ymm16, ymm14, ymm15); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*2)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm17, ymm10 ); //perform mul operation - ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); + ymm11 = STRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + //extract a00 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 0)); - ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] - ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] + //(ROw0): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*1)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm11, ymm10 ); - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + //perform mul operation + ymm10 = STRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + + ymm0 = _mm256_unpacklo_ps(ymm10, ymm11); + ymm1 = _mm256_unpacklo_ps(ymm17, ymm18); + + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b01000100); - ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] + ymm16 = _mm256_permute2f128_ps(ymm12,ymm12,0x20);//1 + _mm_storeu_ps((float *)(b11), _mm256_extractf128_ps(ymm16, 0)); - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b11101110); - ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[4][1] B11[5][1] B11[4][3] B11[5][3] - ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[6][1] B11[7][1] B11[6][3] B11[7][3] + ymm16 = _mm256_permute2f128_ps(ymm12,ymm12,0x20);//2 + _mm_storeu_ps((float *)(b11 + cs_b), _mm256_extractf128_ps(ymm16, 0)); - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm0 = _mm256_unpackhi_ps(ymm10, ymm11); + ymm1 = _mm256_unpackhi_ps(ymm17, ymm18); - ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] - ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b01000100); - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store 
B11[3][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); //store B11[4][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); //store B11[5][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm6); //store B11[6][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 3 + 4), ymm7); //store B11[7][0-3] + ymm16 = _mm256_permute2f128_ps(ymm12,ymm12,0x20);//3 + _mm_storeu_ps((float *)(b11 + 2*cs_b), _mm256_extractf128_ps(ymm16, 0)); - n_rem -=4; - j +=4; + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b11101110); + + ymm16 = _mm256_permute2f128_ps(ymm12,ymm12,0x20);//4 + _mm_storeu_ps((float *)(b11 + 3*cs_b), _mm256_extractf128_ps(ymm16, 0)); + n_remainder = n_remainder - 4; } - if(n_rem) + if(n_remainder) //implementation fo remaining columns(when 'N' is not a multiple of d_nr)() n = 3 { a10 = D_A_pack; - a11 = L + (i*rs_a) + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + a11 = L + (i*cs_a) + (i*rs_a); + b01 = B + i + 4; + b11 = B + i; - k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) + k_iter = (m - i - 4); - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS + ymm8 = _mm256_setzero_ps(); + ymm9 = _mm256_setzero_ps(); + ymm10 = _mm256_setzero_ps(); - if(3 == n_rem) + if(3 == n_remainder) { - ///GEMM code begins/// - BLIS_DTRSM_SMALL_GEMM_8mx3n(a10,b01,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + ymm16 = _mm256_broadcast_ss((float const *)(&AlphaVal)); //register to hold alpha - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm17 = _mm256_insertf128_ps(ymm16, _mm_loadu_ps((float const*)(b11)), 0); + ymm18 = _mm256_insertf128_ps(ymm16, _mm_loadu_ps((float const*)(b11 + cs_b)), 0); + ymm19 = _mm256_insertf128_ps(ymm16, _mm_loadu_ps((float const*)(b11 + cs_b*2)), 0); - ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] - ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] - ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 4)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] + ymm17 = _mm256_fmsub_ps(ymm17, ymm16, ymm8); + ymm18 = _mm256_fmsub_ps(ymm18, ymm16, ymm9); + ymm19 = _mm256_fmsub_ps(ymm19, ymm16, ymm10); + ymm20 = _mm256_broadcast_ss((float const *)(&ones)); + } + else if(2 == n_remainder) + { + BLIS_STRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b,p_lda,k_iter) - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + ymm16 = _mm256_broadcast_ss((float const *)(&AlphaVal)); //register to hold alpha - ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] - ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] * alpha -= B01[0-3][5] - ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); //B11[0-3][6] * alpha -= B01[0-3][6] - ymm7 = _mm256_broadcast_sd((double 
const *)(&ones)); + ymm17 = _mm256_insertf128_ps(ymm16, _mm_loadu_ps((float const*)(b11)), 0); + ymm18 = _mm256_insertf128_ps(ymm16, _mm_loadu_ps((float const*)(b11 + cs_b)), 0); + + ymm17 = _mm256_fmsub_ps(ymm17, ymm16, ymm8); + ymm18 = _mm256_fmsub_ps(ymm18, ymm16, ymm9); + ymm19 = _mm256_broadcast_ss((float const *)(&ones)); + ymm20 = _mm256_broadcast_ss((float const *)(&ones)); } - else if(2 == n_rem) + else if(1 == n_remainder) { - ///GEMM code begins/// - BLIS_DTRSM_SMALL_GEMM_8mx2n(a10,b01,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b,p_lda,k_iter) - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + ymm16 = _mm256_broadcast_ss((float const *)(&AlphaVal)); //register to hold alpha - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm17 = _mm256_insertf128_ps(ymm16, _mm_loadu_ps((float const*)(b11)), 0); - ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] - ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] + ymm17 = _mm256_fmsub_ps(ymm17, ymm16, ymm8); + ymm18 = _mm256_broadcast_ss((float const *)(&ones)); + ymm19 = _mm256_broadcast_ss((float const *)(&ones)); + ymm20 = _mm256_broadcast_ss((float const *)(&ones)); + } - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] - ymm2 = _mm256_broadcast_sd((double const *)(&ones)); - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + ///implement TRSM/// - ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] - ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] * alpha -= B01[0-3][5] - ymm6 = _mm256_broadcast_sd((double const *)(&ones)); - ymm7 = _mm256_broadcast_sd((double const *)(&ones)); - } - else if(1 == n_rem) - { - ///GEMM code begins/// - BLIS_DTRSM_SMALL_GEMM_8mx1n(a10,b01,cs_b,p_lda,k_iter) + ymm8 = _mm256_unpacklo_ps(ymm17, ymm18); + ymm9 = _mm256_unpacklo_ps(ymm19, ymm20); - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b01000100); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm10 = _mm256_permute2f128_ps(ymm4,ymm4,0x20);//1 - ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b11101110); - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_broadcast_sd((double const *)(&ones)); - ymm2 = _mm256_broadcast_sd((double const *)(&ones)); - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + ymm11 = _mm256_permute2f128_ps(ymm4,ymm4,0x20);//2 - ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] - ymm5 = _mm256_broadcast_sd((double const *)(&ones)); - ymm6 = _mm256_broadcast_sd((double const *)(&ones)); - ymm7 = _mm256_broadcast_sd((double const *)(&ones)); - } - ///implement TRSM/// + ymm8 = _mm256_unpackhi_ps(ymm17, ymm18); + ymm9 = _mm256_unpackhi_ps(ymm19, ymm20); - ///transpose of B11// - ///unpacklow/// - ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + ymm4 = 
_mm256_shuffle_ps(ymm8,ymm9,0b01000100); - ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[2][4] B11[2][5] - ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[2][6] B11[2][7] + ymm17 = _mm256_permute2f128_ps(ymm4,ymm4,0x20);//3 - //rearrange low elements - ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + ymm4 = _mm256_shuffle_ps(ymm8,ymm9,0b11101110); - ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3] - ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3] + ymm18 = _mm256_permute2f128_ps(ymm4,ymm4,0x20);//4 - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + //Implement TRSM - ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[1][4] B11[1][5] B11[3][4] B11[3][5] - ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[1][6] B11[1][7] B11[3][6] B11[3][7] + //extract a33 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 3)); - //rearrange high elements - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + //perform mul operation + ymm18 = STRSM_SMALL_DIV_OR_SCALE(ymm18, ymm1); - ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3] - ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3] + //extract a22 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 2)); - ymm0 = _mm256_broadcast_sd((double const *)&ones); + //(ROw2): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*3 + cs_a*2)); + ymm17 = _mm256_fnmadd_ps(ymm0, ymm18, ymm17); - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*3 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm18, ymm11 ); + + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*3)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm18, ymm10 ); //perform mul operation - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); + ymm17 = STRSM_SMALL_DIV_OR_SCALE(ymm17, ymm1); //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 1)); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*1)); - ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); - ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); - ymm5 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); - ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + //(ROw1): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*2 + cs_a*1)); + ymm11 = _mm256_fnmadd_ps(ymm0, ymm17, ymm11 ); - a11 += rs_a; + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*2)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm17, ymm10 ); - //(ROw1): FMA operations - ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); - ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); - ymm11 = _mm256_fnmadd_pd(ymm4, ymm8, ymm11); - ymm12 = _mm256_fnmadd_pd(ymm5, ymm8, ymm12); - ymm13 = _mm256_fnmadd_pd(ymm6, ymm8, ymm13); - ymm14 = _mm256_fnmadd_pd(ymm7, ymm8, ymm14); - ymm15 = 
_mm256_fnmadd_pd(ymm16, ymm8, ymm15); + //perform mul operation + ymm11 = STRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + + //extract a00 + ymm1 = _mm256_broadcast_ss((float const *)(d11_pack + 0)); + + //(ROw0): FMA operations + ymm0 = _mm256_broadcast_ss((float const *)(a11 + rs_a*1)); + ymm10 = _mm256_fnmadd_ps(ymm0, ymm11, ymm10 ); //perform mul operation - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); + ymm10 = STRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); - ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); - ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); - ymm5 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); - ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + if(3 == n_remainder) + { + ymm0 = _mm256_unpacklo_ps(ymm10, ymm11); + ymm1 = _mm256_unpacklo_ps(ymm17, ymm18); - a11 += rs_a; + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b01000100); - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + ymm16 = _mm256_permute2f128_ps(ymm12,ymm12,0x20);//1 + _mm_storeu_ps((float *)(b11), _mm256_extractf128_ps(ymm16, 0)); - //(ROw2): FMA operations - ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); - ymm11 = _mm256_fnmadd_pd(ymm4, ymm9, ymm11); - ymm12 = _mm256_fnmadd_pd(ymm5, ymm9, ymm12); - ymm13 = _mm256_fnmadd_pd(ymm6, ymm9, ymm13); - ymm14 = _mm256_fnmadd_pd(ymm7, ymm9, ymm14); - ymm15 = _mm256_fnmadd_pd(ymm16, ymm9, ymm15); + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b11101110); - //perform mul operation - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + ymm16 = _mm256_permute2f128_ps(ymm12,ymm12,0x20);//2 + _mm_storeu_ps((float *)(b11 + cs_b), _mm256_extractf128_ps(ymm16, 0)); - ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); - ymm5 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); - ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + ymm0 = _mm256_unpackhi_ps(ymm10, ymm11); + ymm1 = _mm256_unpackhi_ps(ymm17, ymm18); - a11 += rs_a; + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b01000100); - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + ymm16 = _mm256_permute2f128_ps(ymm12,ymm12,0x20);//3 + _mm_storeu_ps((float *)(b11 + 2*cs_b), _mm256_extractf128_ps(ymm16, 0)); + } + else if(2 == n_remainder) + { + ymm0 = _mm256_unpacklo_ps(ymm10, ymm11); + ymm1 = _mm256_unpacklo_ps(ymm17, ymm18); - //(ROw5): FMA operations - ymm11 = _mm256_fnmadd_pd(ymm4, ymm10, ymm11); - ymm12 = _mm256_fnmadd_pd(ymm5, ymm10, ymm12); - ymm13 = _mm256_fnmadd_pd(ymm6, ymm10, ymm13); - ymm14 = _mm256_fnmadd_pd(ymm7, ymm10, ymm14); - ymm15 = _mm256_fnmadd_pd(ymm16, ymm10, ymm15); + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b01000100); - //perform mul operation - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + ymm16 = _mm256_permute2f128_ps(ymm12,ymm12,0x20);//1 + _mm_storeu_ps((float *)(b11), _mm256_extractf128_ps(ymm16, 0)); - ymm0 = _mm256_broadcast_sd((double const *)&ones); + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b11101110); - //extract a44 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + ymm16 = _mm256_permute2f128_ps(ymm12,ymm12,0x20);//2 + _mm_storeu_ps((float *)(b11 + cs_b), _mm256_extractf128_ps(ymm16, 0)); + } + else if(1 == n_remainder) + { + ymm0 = _mm256_unpacklo_ps(ymm10, ymm11); + ymm1 = _mm256_unpacklo_ps(ymm17, ymm18); - ymm5 = _mm256_broadcast_sd((double const 
*)(a11 + cs_a*4)); - ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + ymm12 = _mm256_shuffle_ps(ymm0,ymm1,0b01000100); - a11 += rs_a; + ymm16 = _mm256_permute2f128_ps(ymm12,ymm12,0x20);//1 + _mm_storeu_ps((float *)(b11), _mm256_extractf128_ps(ymm16, 0)); + } + } + m_remainder -= 4; + } - //(ROw4): FMA operations - ymm12 = _mm256_fnmadd_pd(ymm5, ymm11, ymm12); - ymm13 = _mm256_fnmadd_pd(ymm6, ymm11, ymm13); - ymm14 = _mm256_fnmadd_pd(ymm7, ymm11, ymm14); - ymm15 = _mm256_fnmadd_pd(ymm16, ymm11, ymm15); + a10 = L + m_remainder*rs_a; - //perform mul operation - ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); + // Do transpose for a10 & store in D_A_pack + float *ptr_a10_dup = D_A_pack; + if(3 == m_remainder) // Repetative A blocks will be 3*3 + { + __m128 xmm0,xmm1,xmm2,xmm3; + __m128 xmm4,xmm5; + __m128 xmm6,xmm7,xmm8,xmm9; + dim_t p_lda = 4; // packed leading dimension + if(transa) + { + for(dim_t x =0;x < m-m_remainder;x+=p_lda) + { + xmm0 = _mm_loadu_ps((float const *)(a10)); + xmm1 = _mm_loadu_ps((float const *)(a10 + cs_a)); + xmm2 = _mm_loadu_ps((float const *)(a10 + cs_a * 2)); + xmm3 = _mm_broadcast_ss((float const *)&ones); + + xmm4 = _mm_unpacklo_ps(xmm0, xmm1); + xmm5 = _mm_unpacklo_ps(xmm2, xmm3); + xmm6 = _mm_shuffle_ps(xmm4,xmm5,0x44); + xmm7 = _mm_shuffle_ps(xmm4,xmm5,0xEE); + + xmm0 = _mm_unpackhi_ps(xmm0, xmm1); + xmm1 = _mm_unpackhi_ps(xmm2, xmm3); + xmm8 = _mm_shuffle_ps(xmm0,xmm1,0x44); + xmm9 = _mm_shuffle_ps(xmm0,xmm1,0xEE); + + _mm_storeu_ps((float *)(ptr_a10_dup), xmm6); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda), xmm7); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda*2), xmm8); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda*3), xmm9); - ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + a10 += p_lda; + ptr_a10_dup += p_lda*p_lda; + } + } + else + { + for(dim_t x =0;x < m-m_remainder;x++) + { + xmm5 = _mm_broadcast_ss((float const *)(a10 + rs_a * x + 2)); + xmm4 = _mm_loadl_pi(xmm5,(__m64 *)(a10 + rs_a * x)); + _mm_storel_pi((__m64 *)(ptr_a10_dup + p_lda * x), xmm4); + _mm_store_ss((float *)(ptr_a10_dup + p_lda * x + 2), _mm_permute_ps(xmm4,0x02)); + } + } - a11 += rs_a; + //cols + for(j = (n - d_nr); (j + 1) > 0; j -= d_nr) //loop along 'N' dimension + { + a10 = D_A_pack; + a11 = L; //pointer to block of A to be used for TRSM + b01 = B + (j* cs_b) + m_remainder; //pointer to block of B to be used for GEMM + b11 = B + (j* cs_b); //pointer to block of B to be used for TRSM - //extract a55 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM code begins/// + BLIS_STRSM_SMALL_GEMM_3mx6n(a10,b01,cs_b,p_lda,k_iter) + + ///GEMM code ends/// + ymm16 = _mm256_broadcast_ss((float const *)(&AlphaVal)); + + xmm4 = _mm_broadcast_ss((float const *)(b11 + 2)); + ymm17 = _mm256_insertf128_ps(ymm1, _mm_loadl_pi(xmm4,(__m64 *)(b11)), 0); + xmm4 = _mm_broadcast_ss((float const *)(b11 + cs_b + 2)); + ymm18 = _mm256_insertf128_ps(ymm1, _mm_loadl_pi(xmm4,(__m64 *)(b11 + cs_b)), 0); + xmm4 = _mm_broadcast_ss((float const *)(b11 + cs_b*2 + 2)); + ymm19 = _mm256_insertf128_ps(ymm1, _mm_loadl_pi(xmm4,(__m64 *)(b11 + cs_b*2)), 
0); + xmm4 = _mm_broadcast_ss((float const *)(b11 + cs_b*3 + 2)); + ymm20 = _mm256_insertf128_ps(ymm1, _mm_loadl_pi(xmm4,(__m64 *)(b11 + cs_b*3)), 0); + xmm4 = _mm_broadcast_ss((float const *)(b11 + cs_b*4 + 2)); + ymm0 = _mm256_insertf128_ps(ymm1, _mm_loadl_pi(xmm4,(__m64 *)(b11 + cs_b*4)), 0); + xmm4 = _mm_broadcast_ss((float const *)(b11 + cs_b*5 + 2)); + ymm1 = _mm256_insertf128_ps(ymm1, _mm_loadl_pi(xmm4,(__m64 *)(b11 + cs_b*5)), 0); + + ymm17 = _mm256_fmsub_ps(ymm17, ymm16, ymm8); + ymm18 = _mm256_fmsub_ps(ymm18, ymm16, ymm9); + ymm19 = _mm256_fmsub_ps(ymm19, ymm16, ymm10); + ymm20 = _mm256_fmsub_ps(ymm20, ymm16, ymm11); + ymm0 = _mm256_fmsub_ps(ymm0 , ymm16, ymm4); + ymm1 = _mm256_fmsub_ps(ymm1 , ymm16, ymm5); + + _mm_storel_pi((__m64 *)(b11), _mm256_extractf128_ps(ymm17, 0)); + _mm_store_ss((float *)(b11 + 2), _mm_permute_ps(_mm256_extractf128_ps(ymm17, 0),0x02)); + + _mm_storel_pi((__m64 *)(b11 + cs_b), _mm256_extractf128_ps(ymm18, 0)); + _mm_store_ss((float *)(b11 + cs_b + 2), _mm_permute_ps(_mm256_extractf128_ps(ymm18, 0),0x02)); + + _mm_storel_pi((__m64 *)(b11 + cs_b*2), _mm256_extractf128_ps(ymm19, 0)); + _mm_store_ss((float *)(b11 + cs_b*2 + 2), _mm_permute_ps(_mm256_extractf128_ps(ymm19, 0),0x02)); + + _mm_storel_pi((__m64 *)(b11 + cs_b*3), _mm256_extractf128_ps(ymm20, 0)); + _mm_store_ss((float *)(b11 + cs_b*3 + 2), _mm_permute_ps(_mm256_extractf128_ps(ymm20, 0),0x02)); + + _mm_storel_pi((__m64 *)(b11 + cs_b*4), _mm256_extractf128_ps(ymm0, 0)); + _mm_store_ss((float *)(b11 + cs_b*4 + 2), _mm_permute_ps(_mm256_extractf128_ps(ymm0, 0),0x02)); + + _mm_storel_pi((__m64 *)(b11 + cs_b*5), _mm256_extractf128_ps(ymm1, 0)); + _mm_store_ss((float *)(b11 + cs_b*5 + 2), _mm_permute_ps(_mm256_extractf128_ps(ymm1, 0),0x02)); + + if(transa) + strsm_AltXB_ref(a11, b11, m_remainder, 6, cs_a, cs_b, is_unitdiag); + else + strsm_AuXB_ref(a11, b11, m_remainder, 6, rs_a, cs_b, is_unitdiag); + } + dim_t n_remainder = j + d_nr; + if((n_remainder >= 4)) + { + a10 = D_A_pack; + a11 = L; //pointer to block of A to be used for TRSM + b01 = B + ((n_remainder - 4)* cs_b) + m_remainder; //pointer to block of B to be used for GEMM + b11 = B + ((n_remainder - 4)* cs_b); //pointer to block of B to be used for TRSM - //(ROw5): FMA operations - ymm13 = _mm256_fnmadd_pd(ymm6, ymm12, ymm13); - ymm14 = _mm256_fnmadd_pd(ymm7, ymm12, ymm14); - ymm15 = _mm256_fnmadd_pd(ymm16, ymm12, ymm15); + k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) - //perform mul operation - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 +cs_a*7)); + ///GEMM code begins/// + BLIS_STRSM_SMALL_GEMM_3mx4n(a10,b01,cs_b,p_lda,k_iter) - a11 += rs_a; + ymm16 = _mm256_broadcast_ss((float const *)(&AlphaVal)); - //extract a66 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); + xmm4 = _mm_broadcast_ss((float const *)(b11 + 2)); + ymm17 = _mm256_insertf128_ps(ymm1, _mm_loadl_pi(xmm4,(__m64 *)(b11)), 0); + xmm4 = _mm_broadcast_ss((float const *)(b11 + cs_b + 2)); + ymm18 = _mm256_insertf128_ps(ymm1, _mm_loadl_pi(xmm4,(__m64 *)(b11 + cs_b)), 0); + xmm4 = _mm_broadcast_ss((float const *)(b11 + cs_b*2 + 2)); + ymm19 = _mm256_insertf128_ps(ymm1, _mm_loadl_pi(xmm4,(__m64 *)(b11 + cs_b*2)), 0); + xmm4 = _mm_broadcast_ss((float const *)(b11 + cs_b*3 + 2)); + ymm20 = _mm256_insertf128_ps(ymm1, 
_mm_loadl_pi(xmm4,(__m64 *)(b11 + cs_b*3)), 0); + ymm17 = _mm256_fmsub_ps(ymm17, ymm16, ymm8); + ymm18 = _mm256_fmsub_ps(ymm18, ymm16, ymm9); + ymm19 = _mm256_fmsub_ps(ymm19, ymm16, ymm10); + ymm20 = _mm256_fmsub_ps(ymm20, ymm16, ymm11); - //(ROw6): FMA operations - ymm14 = _mm256_fnmadd_pd(ymm7, ymm13, ymm14); - ymm15 = _mm256_fnmadd_pd(ymm16, ymm13, ymm15); + _mm_storel_pi((__m64 *)(b11), _mm256_extractf128_ps(ymm17, 0)); + _mm_store_ss((float *)(b11 + 2), _mm_permute_ps(_mm256_extractf128_ps(ymm17, 0),0x02)); - //perform mul operation - ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); + _mm_storel_pi((__m64 *)(b11 + cs_b), _mm256_extractf128_ps(ymm18, 0)); + _mm_store_ss((float *)(b11 + cs_b + 2), _mm_permute_ps(_mm256_extractf128_ps(ymm18, 0),0x02)); - //extract a77 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); + _mm_storel_pi((__m64 *)(b11 + cs_b*2), _mm256_extractf128_ps(ymm19, 0)); + _mm_store_ss((float *)(b11 + cs_b*2 + 2), _mm_permute_ps(_mm256_extractf128_ps(ymm19, 0),0x02)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + _mm_storel_pi((__m64 *)(b11 + cs_b*3), _mm256_extractf128_ps(ymm20, 0)); + _mm_store_ss((float *)(b11 + cs_b*3 + 2), _mm_permute_ps(_mm256_extractf128_ps(ymm20, 0),0x02)); - a11 += rs_a; - //(ROw7): FMA operations - ymm15 = _mm256_fnmadd_pd(ymm16, ymm14, ymm15); + if(transa) + strsm_AltXB_ref(a11, b11, m_remainder, 4, cs_a, cs_b, is_unitdiag); + else + strsm_AuXB_ref(a11, b11, m_remainder, 4, rs_a, cs_b, is_unitdiag); + n_remainder -= 4; + } - //perform mul operation - ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); + if(n_remainder) + { + a10 = D_A_pack; + a11 = L; //pointer to block of A to be used for TRSM + b01 = B + m_remainder; //pointer to block of B to be used for GEMM + b11 = B; //pointer to block of B to be used for TRSM - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) - ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] - ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + if(3 == n_remainder) + { + ///GEMM code begins/// + BLIS_STRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) - ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] + BLIS_PRE_STRSM_SMALL_3M_3N(AlphaVal,b11,cs_b) - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + if(transa) + strsm_AltXB_ref(a11, b11, m_remainder, 3, cs_a, cs_b, is_unitdiag); + else + strsm_AuXB_ref(a11, b11, m_remainder, 3, rs_a, cs_b, is_unitdiag); + } + else if(2 == n_remainder) + { + ///GEMM code begins/// + BLIS_STRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b,p_lda,k_iter) - ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[4][1] B11[5][1] B11[4][3] B11[5][3] - ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[6][1] B11[7][1] B11[6][3] B11[7][3] + 
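The broadcast_ss + loadl_pi loads and the storel_pi + store_ss(permute 0x02) stores used above move exactly three floats of a b11 column without touching a fourth element. The same idiom, isolated as a pair of helpers (load3_ps/store3_ps are illustrative names, not BLIS API; intrinsics from the already-included immintrin.h):

static inline __m128 load3_ps(const float *p)
{
    __m128 v = _mm_broadcast_ss(p + 2);            /* p2 p2 p2 p2 */
    return _mm_loadl_pi(v, (const __m64 *)p);      /* p0 p1 p2 p2 */
}

static inline void store3_ps(float *p, __m128 v)
{
    _mm_storel_pi((__m64 *)p, v);                  /* writes p0, p1 */
    _mm_store_ss(p + 2, _mm_permute_ps(v, 0x02));  /* writes p2     */
}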
BLIS_PRE_STRSM_SMALL_3M_2N(AlphaVal,b11,cs_b) - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + if(transa) + strsm_AltXB_ref(a11, b11, m_remainder, 2, cs_a, cs_b, is_unitdiag); + else + strsm_AuXB_ref(a11, b11, m_remainder, 2, rs_a, cs_b, is_unitdiag); + } + else if(1 == n_remainder) + { + ///GEMM code begins/// + BLIS_STRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b,p_lda,k_iter) - ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] - ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] + BLIS_PRE_STRSM_SMALL_3M_1N(AlphaVal,b11,cs_b) - if(3 == n_rem) - { - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); //store B11[4][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); //store B11[5][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm6); //store B11[6][0-3] + if(transa) + strsm_AltXB_ref(a11, b11, m_remainder, 1, cs_a, cs_b, is_unitdiag); + else + strsm_AuXB_ref(a11, b11, m_remainder, 1, rs_a, cs_b, is_unitdiag); + } } - else if(2 == n_rem) + } + else if(2 == m_remainder) // Repetative A blocks will be 2*2 + { + __m128 xmm0,xmm1,xmm2,xmm3; + __m128 xmm4,xmm5; + __m128 xmm6,xmm7,xmm8,xmm9; + dim_t p_lda = 4; // packed leading dimension + if(transa) { - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); //store B11[4][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); //store B11[5][0-3] + for(dim_t x =0;x < m-m_remainder;x+=p_lda) + { + xmm0 = _mm_loadu_ps((float const *)(a10)); + xmm1 = _mm_loadu_ps((float const *)(a10 + cs_a)); + xmm2 = _mm_broadcast_ss((float const *)&ones); + xmm3 = _mm_broadcast_ss((float const *)&ones); + + xmm4 = _mm_unpacklo_ps(xmm0, xmm1); + xmm5 = _mm_unpacklo_ps(xmm2, xmm3); + xmm6 = _mm_shuffle_ps(xmm4,xmm5,0x44); + xmm7 = _mm_shuffle_ps(xmm4,xmm5,0xEE); + + xmm0 = _mm_unpackhi_ps(xmm0, xmm1); + xmm1 = _mm_unpackhi_ps(xmm2, xmm3); + xmm8 = _mm_shuffle_ps(xmm0,xmm1,0x44); + xmm9 = _mm_shuffle_ps(xmm0,xmm1,0xEE); + + _mm_storeu_ps((float *)(ptr_a10_dup), xmm6); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda), xmm7); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda*2), xmm8); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda*3), xmm9); + + a10 += p_lda; + ptr_a10_dup += p_lda*p_lda; + } } - else if(1 == n_rem) + else { - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); //store B11[4][0-3] + xmm1 = _mm_broadcast_ss((float const *)(&zero)); + for(dim_t x =0;x < m-m_remainder;x++) + { + xmm0 = _mm_loadl_pi(xmm1,(__m64 *)(a10 + x*rs_a)); + _mm_storel_pi((__m64 *)(ptr_a10_dup + x*p_lda), xmm0); + } } - } - } + //cols + for(j = (n - d_nr); (j + 1) > 0; j -= d_nr) //loop along 'N' dimension + { + a10 = D_A_pack; + a11 = L; //pointer to block of A to be used for TRSM + b01 = B + (j* cs_b) + m_remainder; //pointer to block of B to be used for GEMM + b11 = B + (j* cs_b); //pointer to block of B to be used for TRSM - //======================M remainder 
cases================================ - dim_t m_rem = m-i; - if(m_rem>=4) //implementation for reamainder rows(when 'M' is not a multiple of d_mr) - { - a10 = L + (i*cs_a); //pointer to block of A to be used for GEMM - a11 = L + (i*rs_a) + (i*cs_a); - double *ptr_a10_dup = D_A_pack; - dim_t p_lda = 4; // packed leading dimension + k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) - if(transa) - { - for(dim_t x =0;x < i;x+=p_lda) + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM code begins/// + BLIS_STRSM_SMALL_GEMM_4mx6n(a10,b01,cs_b,p_lda,k_iter) + + ///GEMM code ends/// + ymm16 = _mm256_broadcast_ss((float const *)(&AlphaVal)); + + xmm4 = _mm_broadcast_ss((float const *)(&zero)); + ymm17 = _mm256_insertf128_ps(ymm1, _mm_loadl_pi(xmm4,(__m64 *)(b11)), 0); + ymm18 = _mm256_insertf128_ps(ymm1, _mm_loadl_pi(xmm4,(__m64 *)(b11 + cs_b)), 0); + ymm19 = _mm256_insertf128_ps(ymm1, _mm_loadl_pi(xmm4,(__m64 *)(b11 + cs_b*2)), 0); + ymm20 = _mm256_insertf128_ps(ymm1, _mm_loadl_pi(xmm4,(__m64 *)(b11 + cs_b*3)), 0); + ymm0 = _mm256_insertf128_ps(ymm1, _mm_loadl_pi(xmm4,(__m64 *)(b11 + cs_b*4)), 0); + ymm1 = _mm256_insertf128_ps(ymm1, _mm_loadl_pi(xmm4,(__m64 *)(b11 + cs_b*5)), 0); + + ymm17 = _mm256_fmsub_ps(ymm17, ymm16, ymm8); + ymm18 = _mm256_fmsub_ps(ymm18, ymm16, ymm9); + ymm19 = _mm256_fmsub_ps(ymm19, ymm16, ymm10); + ymm20 = _mm256_fmsub_ps(ymm20, ymm16, ymm11); + ymm0 = _mm256_fmsub_ps(ymm0 , ymm16, ymm4); + ymm1 = _mm256_fmsub_ps(ymm1 , ymm16, ymm5); + + _mm_storel_pi((__m64 *)(b11), _mm256_extractf128_ps(ymm17, 0)); + _mm_storel_pi((__m64 *)(b11 + cs_b), _mm256_extractf128_ps(ymm18, 0)); + _mm_storel_pi((__m64 *)(b11 + cs_b*2), _mm256_extractf128_ps(ymm19, 0)); + _mm_storel_pi((__m64 *)(b11 + cs_b*3), _mm256_extractf128_ps(ymm20, 0)); + _mm_storel_pi((__m64 *)(b11 + cs_b*4), _mm256_extractf128_ps(ymm0, 0)); + _mm_storel_pi((__m64 *)(b11 + cs_b*5), _mm256_extractf128_ps(ymm1, 0)); + + if(transa) + strsm_AltXB_ref(a11, b11, m_remainder, 6, cs_a, cs_b, is_unitdiag); + else + strsm_AuXB_ref(a11, b11, m_remainder, 6, rs_a, cs_b, is_unitdiag); + } + dim_t n_remainder = j + d_nr; + if((n_remainder >= 4)) { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm1 = _mm256_loadu_pd((double const *)(a10 + cs_a)); - ymm2 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); - ymm3 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); + a10 = D_A_pack; + a11 = L; //pointer to block of A to be used for TRSM + b01 = B + ((n_remainder - 4)* cs_b) + m_remainder; //pointer to block of B to be used for GEMM + b11 = B + ((n_remainder - 4)* cs_b); //pointer to block of B to be used for TRSM - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + ///GEMM code begins/// + BLIS_STRSM_SMALL_GEMM_4mx4n(a10,b01,cs_b,p_lda,k_iter) - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + ymm16 = _mm256_broadcast_ss((float const *)(&AlphaVal)); - _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); - 
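In scalar terms, both packing branches copy the a10 panel into D_A_pack with a fixed packed leading dimension p_lda, transposing on the fly when A is transposed, so that the GEMM macros always stream packed rows contiguously (the fringe kernels pad unused packed columns with ones). A rough scalar equivalent, assuming the vectorized direction of a10 is unit-stride as the vector loads above require; k is the number of packed rows (the GEMM depth) and ncols <= p_lda is the panel width:

for (dim_t t = 0; t < k; t++)
    for (dim_t c = 0; c < ncols; c++)
        D_A_pack[t * p_lda + c] = transa ? a10[c * cs_a + t]   /* transpose while packing */
                                         : a10[t * rs_a + c];  /* straight row-wise copy  */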
_mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + xmm4 = _mm_broadcast_ss((float const *)(&zero)); + ymm17 = _mm256_insertf128_ps(ymm1, _mm_loadl_pi(xmm4,(__m64 *)(b11)), 0); + ymm18 = _mm256_insertf128_ps(ymm1, _mm_loadl_pi(xmm4,(__m64 *)(b11 + cs_b)), 0); + ymm19 = _mm256_insertf128_ps(ymm1, _mm_loadl_pi(xmm4,(__m64 *)(b11 + cs_b*2)), 0); + ymm20 = _mm256_insertf128_ps(ymm1, _mm_loadl_pi(xmm4,(__m64 *)(b11 + cs_b*3)), 0); - a10 += p_lda; - ptr_a10_dup += p_lda*p_lda; + ymm17 = _mm256_fmsub_ps(ymm17, ymm16, ymm8); + ymm18 = _mm256_fmsub_ps(ymm18, ymm16, ymm9); + ymm19 = _mm256_fmsub_ps(ymm19, ymm16, ymm10); + ymm20 = _mm256_fmsub_ps(ymm20, ymm16, ymm11); + + _mm_storel_pi((__m64 *)(b11), _mm256_extractf128_ps(ymm17, 0)); + _mm_storel_pi((__m64 *)(b11 + cs_b), _mm256_extractf128_ps(ymm18, 0)); + _mm_storel_pi((__m64 *)(b11 + cs_b*2), _mm256_extractf128_ps(ymm19, 0)); + _mm_storel_pi((__m64 *)(b11 + cs_b*3), _mm256_extractf128_ps(ymm20, 0)); + + if(transa) + strsm_AltXB_ref(a11, b11, m_remainder, 4, cs_a, cs_b, is_unitdiag); + else + strsm_AuXB_ref(a11, b11, m_remainder, 4, rs_a, cs_b, is_unitdiag); + n_remainder -= 4; } - } - else - { - for(dim_t x =0;x < i;x++) + if(n_remainder) { - ymm0 = _mm256_loadu_pd((double const *)(a10 + rs_a * x)); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * x), ymm0); + a10 = D_A_pack; + a11 = L; //pointer to block of A to be used for TRSM + b01 = B + m_remainder; //pointer to block of B to be used for GEMM + b11 = B; //pointer to block of B to be used for TRSM + + k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + if(3 == n_remainder) + { + ///GEMM code begins/// + BLIS_STRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) + + BLIS_PRE_STRSM_SMALL_2M_3N(AlphaVal,b11,cs_b) + + if(transa) + strsm_AltXB_ref(a11, b11, m_remainder, 3, cs_a, cs_b, is_unitdiag); + else + strsm_AuXB_ref(a11, b11, m_remainder, 3, rs_a, cs_b, is_unitdiag); + } + else if(2 == n_remainder) + { + ///GEMM code begins/// + BLIS_STRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b,p_lda,k_iter) + + BLIS_PRE_STRSM_SMALL_2M_2N(AlphaVal,b11,cs_b) + + if(transa) + strsm_AltXB_ref(a11, b11, m_remainder, 2, cs_a, cs_b, is_unitdiag); + else + strsm_AuXB_ref(a11, b11, m_remainder, 2, rs_a, cs_b, is_unitdiag); + } + else if(1 == n_remainder) + { + ///GEMM code begins/// + BLIS_STRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b,p_lda,k_iter) + + BLIS_PRE_STRSM_SMALL_2M_1N(AlphaVal,b11,cs_b) + if(transa) + strsm_AltXB_ref(a11, b11, m_remainder, 1, cs_a, cs_b, is_unitdiag); + else + strsm_AuXB_ref(a11, b11, m_remainder, 1, rs_a, cs_b, is_unitdiag); + } } } - - ymm4 = _mm256_broadcast_sd((double const *)&ones); - if(!is_unitdiag) + else if(1 == m_remainder) // Repetative A blocks will be 1*1 { + __m128 xmm0,xmm1,xmm2,xmm3; + __m128 xmm4,xmm5; + __m128 xmm6,xmm7,xmm8,xmm9; + dim_t p_lda = 4; // packed leading dimension if(transa) { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_sd((double const *)(a11)); - ymm1 = _mm256_broadcast_sd((double const *)(a11+cs_a*1 + 1)); - ymm2 = _mm256_broadcast_sd((double const *)(a11+cs_a*2 + 2)); - ymm3 = _mm256_broadcast_sd((double const *)(a11+cs_a*3 + 3)); + for(dim_t x =0;x < m-m_remainder;x+=p_lda) + { + xmm0 = _mm_loadu_ps((float const *)(a10)); + xmm1 = _mm_broadcast_ss((float const *)&ones); + xmm2 = _mm_broadcast_ss((float const *)&ones); + xmm3 = _mm_broadcast_ss((float const *)&ones); + + xmm4 = _mm_unpacklo_ps(xmm0, xmm1); + xmm5 = 
_mm_unpacklo_ps(xmm2, xmm3); + xmm6 = _mm_shuffle_ps(xmm4,xmm5,0x44); + xmm7 = _mm_shuffle_ps(xmm4,xmm5,0xEE); + + xmm0 = _mm_unpackhi_ps(xmm0, xmm1); + xmm1 = _mm_unpackhi_ps(xmm2, xmm3); + xmm8 = _mm_shuffle_ps(xmm0,xmm1,0x44); + xmm9 = _mm_shuffle_ps(xmm0,xmm1,0xEE); + + _mm_storeu_ps((float *)(ptr_a10_dup), xmm6); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda), xmm7); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda*2), xmm8); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda*3), xmm9); + + a10 += p_lda; + ptr_a10_dup += p_lda*p_lda; + } } else { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_sd((double const *)(a11)); - ymm1 = _mm256_broadcast_sd((double const *)(a11+rs_a*1 + 1)); - ymm2 = _mm256_broadcast_sd((double const *)(a11+rs_a*2 + 2)); - ymm3 = _mm256_broadcast_sd((double const *)(a11+rs_a*3 + 3)); - } - - ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); - ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); - #ifdef BLIS_DISABLE_TRSM_PREINVERSION - ymm4 = ymm1; - #endif - #ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm4 = _mm256_div_pd(ymm4, ymm1); - #endif - } - _mm256_storeu_pd((double *)(d11_pack), ymm4); - - for(j = 0; (j+d_nr-1) < n; j += d_nr) //loop along 'N' dimension - { - a10 = D_A_pack; //pointer to block of A to be used for GEMM - a11 = L + (i*rs_a) + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM - b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM - - k_iter = i; //number of times GEMM operation to be done(in blocks of 4x4) + __m128 xmm0; + for(dim_t x =0;x < m-m_remainder;x++) + { + xmm0 = _mm_broadcast_ss((float const *)(a10 + x*rs_a)); + _mm_store_ss((float *)(ptr_a10_dup + x*p_lda), xmm0); + } + } + //cols + for(j = (n - d_nr); (j + 1) > 0; j -= d_nr) //loop along 'N' dimension + { + a10 = D_A_pack; + a11 = L; //pointer to block of A to be used for TRSM + b01 = B + (j* cs_b) + m_remainder; //pointer to block of B to be used for GEMM + b11 = B + (j* cs_b); //pointer to block of B to be used for TRSM - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS + k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) - ///GEMM code begins/// - BLIS_DTRSM_SMALL_GEMM_4mx6n(a10,b01,cs_b,p_lda,k_iter) + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + ///GEMM code begins/// + BLIS_STRSM_SMALL_GEMM_4mx6n(a10,b01,cs_b,p_lda,k_iter) - ///implement TRSM/// - ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] * alpha -= B01[0-3][3] + ///GEMM code ends/// + ymm16 = _mm256_broadcast_ss((float const *)(&AlphaVal)); + + ymm17 = _mm256_broadcast_ss((float const*)(b11)); + ymm18 = _mm256_broadcast_ss((float const*)(b11 + cs_b)); + ymm19 = 
_mm256_broadcast_ss((float const*)(b11 + cs_b*2)); + ymm20 = _mm256_broadcast_ss((float const*)(b11 + cs_b*3)); + ymm0 = _mm256_broadcast_ss((float const*)(b11 + cs_b*4)); + ymm1 = _mm256_broadcast_ss((float const*)(b11 + cs_b*5)); + + ymm17 = _mm256_fmsub_ps(ymm17, ymm16, ymm8); + ymm18 = _mm256_fmsub_ps(ymm18, ymm16, ymm9); + ymm19 = _mm256_fmsub_ps(ymm19, ymm16, ymm10); + ymm20 = _mm256_fmsub_ps(ymm20, ymm16, ymm11); + ymm0 = _mm256_fmsub_ps(ymm0 , ymm16, ymm4); + ymm1 = _mm256_fmsub_ps(ymm1 , ymm16, ymm5); + + _mm_store_ss((float *)(b11), _mm256_extractf128_ps(ymm17,0)); + _mm_store_ss((float *)(b11 + cs_b), _mm256_extractf128_ps(ymm18,0)); + _mm_store_ss((float *)(b11 + cs_b*2), _mm256_extractf128_ps(ymm19,0)); + _mm_store_ss((float *)(b11 + cs_b*3), _mm256_extractf128_ps(ymm20,0)); + _mm_store_ss((float *)(b11 + cs_b*4), _mm256_extractf128_ps(ymm0,0)); + _mm_store_ss((float *)(b11 + cs_b*5), _mm256_extractf128_ps(ymm1,0)); - ///transpose of B11// - ///unpacklow/// - ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + if(transa) + strsm_AltXB_ref(a11, b11, m_remainder, 6, cs_a, cs_b, is_unitdiag); + else + strsm_AuXB_ref(a11, b11, m_remainder, 6, rs_a, cs_b, is_unitdiag); + } + dim_t n_remainder = j + d_nr; + if((n_remainder >= 4)) + { + a10 = D_A_pack; + a11 = L; //pointer to block of A to be used for TRSM + b01 = B + ((n_remainder - 4)* cs_b) + m_remainder; //pointer to block of B to be used for GEMM + b11 = B + ((n_remainder - 4)* cs_b); //pointer to block of B to be used for TRSM - //rearrange low elements - ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS - //rearrange high elements - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + ///GEMM code begins/// + BLIS_STRSM_SMALL_GEMM_4mx4n(a10,b01,cs_b,p_lda,k_iter) + ymm16 = _mm256_broadcast_ss((float const *)(&AlphaVal)); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); + ymm17 = _mm256_broadcast_ss((float const*)(b11)); + ymm18 = _mm256_broadcast_ss((float const*)(b11 + cs_b)); + ymm19 = _mm256_broadcast_ss((float const*)(b11 + cs_b*2)); + ymm20 = _mm256_broadcast_ss((float const*)(b11 + cs_b*3)); + ymm17 = _mm256_fmsub_ps(ymm17, ymm16, ymm8); + ymm18 = _mm256_fmsub_ps(ymm18, ymm16, ymm9); + ymm19 = _mm256_fmsub_ps(ymm19, ymm16, ymm10); + ymm20 = _mm256_fmsub_ps(ymm20, ymm16, ymm11); - ymm16 = _mm256_broadcast_sd((double const *)(&ones)); + _mm_store_ss((float *)(b11), _mm256_extractf128_ps(ymm17,0)); + _mm_store_ss((float *)(b11 + cs_b), _mm256_extractf128_ps(ymm18,0)); + _mm_store_ss((float *)(b11 + cs_b*2), _mm256_extractf128_ps(ymm19,0)); + _mm_store_ss((float *)(b11 + cs_b*3), _mm256_extractf128_ps(ymm20,0)); - ////unpacklow//// - ymm7 = 
_mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - //ymm16; + if(transa) + strsm_AltXB_ref(a11, b11, m_remainder, 4, cs_a, cs_b, is_unitdiag); + else + strsm_AuXB_ref(a11, b11, m_remainder, 4, rs_a, cs_b, is_unitdiag); + n_remainder -= 4; + } + if(n_remainder) + { + a10 = D_A_pack; + a11 = L; //pointer to block of A to be used for TRSM + b01 = B + m_remainder; //pointer to block of B to be used for GEMM + b11 = B; //pointer to block of B to be used for TRSM - //rearrange low elements - ymm4 = _mm256_permute2f128_pd(ymm7,ymm16,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm6 = _mm256_permute2f128_pd(ymm7,ymm16,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] + k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - //ymm16; + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS - //rearrange high elements - ymm5 = _mm256_permute2f128_pd(ymm0,ymm16,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm7 = _mm256_permute2f128_pd(ymm0,ymm16,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - //b11 transpose end + if(3 == n_remainder) + { + ///GEMM code begins/// + BLIS_STRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) - ////extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + BLIS_PRE_STRSM_SMALL_1M_3N(AlphaVal,b11,cs_b) - //perform mul operation - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); - ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm1); + if(transa) + strsm_AltXB_ref(a11, b11, m_remainder, 3, cs_a, cs_b, is_unitdiag); + else + strsm_AuXB_ref(a11, b11, m_remainder, 3, rs_a, cs_b, is_unitdiag); + } + else if(2 == n_remainder) + { + ///GEMM code begins/// + BLIS_STRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b,p_lda,k_iter) - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + BLIS_PRE_STRSM_SMALL_1M_2N(AlphaVal,b11,cs_b) - //(ROw1): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*1)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); - ymm5 = _mm256_fnmadd_pd(ymm2, ymm4, ymm5); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm8, ymm10); - ymm6 = _mm256_fnmadd_pd(ymm2, ymm4, ymm6); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); - ymm7 = _mm256_fnmadd_pd(ymm2, ymm4, ymm7); + if(transa) + strsm_AltXB_ref(a11, b11, m_remainder, 2, cs_a, cs_b, is_unitdiag); + else + strsm_AuXB_ref(a11, b11, m_remainder, 2, rs_a, cs_b, is_unitdiag); + } + else if(1 == n_remainder) + { + ///GEMM code begins/// + BLIS_STRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b,p_lda,k_iter) + BLIS_PRE_STRSM_SMALL_1M_1N(AlphaVal,b11,cs_b) - //perform mul operation - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm1); + if(transa) + strsm_AltXB_ref(a11, b11, m_remainder, 1, cs_a, cs_b, is_unitdiag); + else + strsm_AuXB_ref(a11, b11, m_remainder, 1, rs_a, cs_b, is_unitdiag); + } + } + } - a11 += rs_a; + if ((required_packing_A == 1) && + bli_mem_is_alloc( &local_mem_buf_A_s )) + { + bli_membrk_release(&rntm,&local_mem_buf_A_s); + } + return BLIS_SUCCESS; +} - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); +BLIS_INLINE err_t bli_ztrsm_small_AutXB_AlXB +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +) +{ + dim_t m = bli_obj_length(b); // number of rows of matrix B + dim_t n = bli_obj_width(b); 
// number of columns of matrix B - //(ROw2): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm9, ymm10); - ymm6 = _mm256_fnmadd_pd(ymm2, ymm5, ymm6); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm9, ymm11); - ymm7 = _mm256_fnmadd_pd(ymm2, ymm5, ymm7); + bool transa = bli_obj_has_trans(a); + bool conjtransa = bli_obj_has_conj(a); - //perform mul operation - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); - ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm1); + dim_t cs_a, rs_a; + dim_t d_mr = 4,d_nr = 3; - a11 += rs_a; + // Swap rs_a & cs_a in case of non-tranpose. + if(transa) + { + cs_a = bli_obj_col_stride(a); // column stride of A + rs_a = bli_obj_row_stride(a); // row stride of A + } + else + { + cs_a = bli_obj_row_stride(a); // row stride of A + rs_a = bli_obj_col_stride(a); // column stride of A + } + dim_t cs_b = bli_obj_col_stride(b); // column stride of B - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + dim_t i, j, k; //loop variables + dim_t k_iter; //number of times GEMM to be performed - //(ROw5): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm10, ymm11); - ymm7 = _mm256_fnmadd_pd(ymm2, ymm6, ymm7); + dcomplex AlphaVal = *(dcomplex *)AlphaObj->buffer; //value of alpha + dcomplex *L = a->buffer; //pointer to matrix A + dcomplex *B = b->buffer; //pointer to matrix B - //perform mul operation - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm1); + dcomplex *a10, *a11, *b01, *b11; //pointers that point to blocks for GEMM and TRSM - a11 += rs_a; + dcomplex ones = {1.0, 1.0}; + bool is_unitdiag = bli_obj_has_unit_diag(a); - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + //scratch registers + __m256d ymm0, ymm1, ymm2, ymm3; + __m256d ymm4, ymm5, ymm6, ymm7; + __m256d ymm8, ymm9, ymm10, ymm11; + __m256d ymm12, ymm13, ymm14, ymm15; + __m256d ymm16, ymm17, ymm18, ymm19; - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + __m128d xmm5, xmm4; - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + gint_t required_packing_A = 1; + mem_t local_mem_buf_A_s = {0}; + dcomplex *D_A_pack = NULL; + dcomplex d11_pack[d_mr] __attribute__((aligned(64))); + rntm_t rntm; - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + bli_rntm_init_from_global( &rntm ); + bli_rntm_set_num_threads_only( 1, &rntm ); + bli_membrk_rntm_set_membrk( &rntm ); - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] + siz_t buffer_size = bli_pool_block_size( + bli_membrk_pool( + bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), + bli_rntm_membrk(&rntm))); - 
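The rs_a/cs_a swap above is what lets one kernel body cover both a transposed and a non-transposed A: with generic strides an element A(i,j) is addressed as a[i*rs_a + j*cs_a], so exchanging the two strides makes the same buffer read as its transpose. A minimal illustration (elem is a hypothetical helper, not part of this patch):

static inline dcomplex elem(const dcomplex *a, dim_t i, dim_t j,
                            dim_t rs_a, dim_t cs_a)
{
    /* elem(a, i, j, cs_a, rs_a) returns A-transpose(i, j) from the same buffer */
    return a[i * rs_a + j * cs_a];
}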
//unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + if ( (d_mr * m * sizeof(dcomplex)) > buffer_size) + return BLIS_NOT_YET_IMPLEMENTED; - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + if (required_packing_A == 1) + { + // Get the buffer from the pool. + bli_membrk_acquire_m(&rntm, + buffer_size, + BLIS_BITVAL_BUFFER_FOR_A_BLOCK, + &local_mem_buf_A_s); + if(FALSE==bli_mem_is_alloc(&local_mem_buf_A_s)) return BLIS_NULL_POINTER; + D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); + if(NULL==D_A_pack) return BLIS_NULL_POINTER; + } - ///unpack high/// - ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + /* + Performs solving TRSM for 4 colmns at a time from 0 to m/4 in steps of d_mr + a. Load, transpose, Pack A (a10 block), the size of packing 4x3 to 4x (m-4) + First there will be no GEMM and no packing of a10 because it is only TRSM + b. Using packed a10 block and b01 block perform GEMM operation + c. Use GEMM outputs, perform TRSM operaton using a11, b11 and update B + d. Repeat b,c for n rows of B in steps of d_nr + */ + for(i = 0;(i+d_mr-1) < m; i += d_mr) //loop along 'M' dimension + { + a10 = L + (i*cs_a); //pointer to block of A to be used for GEMM + a11 = L + (i*rs_a) + (i*cs_a); + dim_t p_lda = d_mr; // packed leading dimension - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + if(transa) + { + /* + Load, tranpose and pack current A block (a10) into packed buffer memory + D_A_pack + a. This a10 block is used in GEMM portion only and this + a10 block size will be increasing by d_mr for every next itteration + untill it reaches 4x(m-4) which is the maximum GEMM alone block size + in A + b. This packed buffer is reused to calculate all n rows of B matrix + */ + bli_ztrsm_small_pack('L', i, 1, a10, cs_a, D_A_pack, p_lda,d_mr); - _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store B11[1][0-3] + /* + Pack 4 diagonal elements of A block into an array + a. This helps in utilze cache line efficiently in TRSM operation + b. store ones when input is unit diagonal + */ + ztrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,d_mr); } - - dim_t n_rem = n-j; - if(n_rem >= 4) + else + { + bli_ztrsm_small_pack('L', i, 0, a10, rs_a, D_A_pack, p_lda,d_mr); + ztrsm_small_pack_diag_element(is_unitdiag,a11,rs_a,d11_pack,d_mr); + } + /* + a. Perform GEMM using a10, b01. + b. Perform TRSM on a11, b11 + c. This loop GEMM+TRSM loops operates with 4x3 block size + along n dimension for every d_nr rows of b01 where + packed A buffer is reused in computing all n rows of B. + d. Same approch is used in remaining fringe cases. 
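      In scalar form, steps a-d amount to ordinary forward substitution;
      the blocking only changes data movement, not the arithmetic.
      Illustrative sketch (real arithmetic for brevity; A, B, alpha, lda,
      ldb, MR, NR are stand-ins, with m and n assumed multiples of the
      block sizes):

      for (dim_t ib = 0; ib < m; ib += MR)             // d_mr rows of B per block
          for (dim_t jb = 0; jb < n; jb += NR)         // d_nr columns of B per block
              for (dim_t jj = jb; jj < jb + NR; jj++)
                  for (dim_t ii = ib; ii < ib + MR; ii++)
                  {
                      double t = alpha * B[ii + jj*ldb];
                      for (dim_t kk = 0; kk < ii; kk++)         // kk < ib is the GEMM update,
                          t -= A[ii + kk*lda] * B[kk + jj*ldb]; // ib <= kk < ii is the TRSM part
                      B[ii + jj*ldb] = t / A[ii + ii*lda];      // or multiply by the packed 1/diag
                  }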
+        */
+        dim_t temp = n - d_nr + 1;
+        for(j = 0; j < temp; j += d_nr)   //loop along 'N' dimension
         {
             a10 = D_A_pack;
-            a11 = L + (i*rs_a) + (i*cs_a);     //pointer to block of A to be used for TRSM
+            a11 = L + (i*rs_a) + (i*cs_a); //pointer to block of A to be used for TRSM
             b01 = B + j*cs_b;          //pointer to block of B to be used for GEMM
             b11 = B + i + j* cs_b;     //pointer to block of B to be used for TRSM

-            k_iter = i;          //number of times GEMM to be performed(in blocks of 4x4)
+            k_iter = i;
             /*Fill zeros into ymm registers used in gemm accumulations */
             BLIS_SET_YMM_REG_ZEROS
-            BLIS_DTRSM_SMALL_GEMM_4mx4n(a10,b01,cs_b,p_lda,k_iter)
-
-            ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal));    //register to hold alpha
-
-            ///implement TRSM///
-
-            ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));
-            ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1));
-            ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2));
-            ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3));
-            ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8);
-            ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9);
-            ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10);
-            ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11);
-
-            ///transpose of B11//
-            ///unpacklow///
-            ymm9 = _mm256_unpacklo_pd(ymm0, ymm1);   //B11[0][0] B11[0][1] B11[2][0] B11[2][1]
-            ymm11 = _mm256_unpacklo_pd(ymm2, ymm3);  //B11[0][2] B11[0][3] B11[2][2] B11[2][3]
+            /*
+                Perform GEMM between the a10 and b01 blocks.
+                For the first iteration there will be no GEMM operation,
+                since k_iter is zero.
+            */
+            BLIS_ZTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter)

-            //rearrange low elements
-            ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20);  //B11[0][0] B11[0][1] B11[0][2] B11[0][3]
-            ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3]
+            /*
+               Load b11 of size 3x4 and multiply it with alpha.
+               Add the GEMM output and perform an in-register transpose of b11
+               to perform the TRSM operation.
+            */
+            BLIS_ZTRSM_SMALL_NREG_TRANSPOSE_3x4(b11,cs_b,AlphaVal)
+            /*
+                Compute the 4x3 TRSM block by using the GEMM block output in registers
+                a. The 4x3 input (GEMM outputs) is stored in combinations of ymm
+                   registers
+                   1. ymm8, ymm4 2. ymm9, ymm5 3. ymm10, ymm6, 4. ymm11, ymm7
+                   where ymm8-ymm11 hold 4x2 data and the remaining 4x1 is held by
+                   the other registers
+                b. 
Towards the end do in regiser transpose of TRSM output and store in + b11 + */ + ////extract a00 + ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm8 and ymm4 with ymm1*/ + BLIS_ZTRSM_TWO_DIV(ymm8,ymm4) +#else + /*performs dcomplex multiplication of ymm8 and ymm4 with ymm1*/ + BLIS_ZTRSM_MUL(ymm8) + BLIS_ZTRSM_MUL(ymm4) +#endif + //extract a11 + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 1)); + //(ROW1): FMA operations + ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a*1)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + /* Step1 dcomplex multiply ymm2, ymm8 + * Step2 negate the result + * Step3 add ymm9*/ + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm8 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm8, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm8, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); - //rearrange high elements - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + //For ymm4 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); - ymm0 = _mm256_broadcast_sd((double const *)&ones); + ymm13 = _mm256_mul_pd(ymm4, ymm2); + ymm14 = _mm256_mul_pd(ymm4, ymm14); + ymm17 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + ymm17 = _mm256_mul_pd(ymm17, ymm15); - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + //Step 3 + ymm9 = _mm256_add_pd(ymm16, ymm9); + ymm5 = _mm256_add_pd(ymm17, ymm5); - //perform mul operation - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); + ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a*2)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm8 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm8, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm8, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //For ymm4 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*1)); - ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); - ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + ymm13 = _mm256_mul_pd(ymm4, ymm2); + ymm14 = _mm256_mul_pd(ymm4, ymm14); + ymm17 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + ymm17 = _mm256_mul_pd(ymm17, ymm15); - a11 += rs_a; + //Step 3 + ymm10 = _mm256_add_pd(ymm16, ymm10); + ymm6 = _mm256_add_pd(ymm17, ymm6); - //(ROw1): FMA operations - ymm9 = 
_mm256_fnmadd_pd(ymm2, ymm8, ymm9); - ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); - ymm11 = _mm256_fnmadd_pd(ymm4, ymm8, ymm11); + ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a*3)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } - //perform mul operation - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm8 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm8, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm8, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //For ymm4 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); - ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); - ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + ymm13 = _mm256_mul_pd(ymm4, ymm2); + ymm14 = _mm256_mul_pd(ymm4, ymm14); + ymm17 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + ymm17 = _mm256_mul_pd(ymm17, ymm15); + //Step 3 + ymm11 = _mm256_add_pd(ymm16, ymm11); + ymm7 = _mm256_add_pd(ymm17, ymm7); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm9 and ymm5 with ymm1*/ + BLIS_ZTRSM_TWO_DIV(ymm9,ymm5) +#else + /*performs dcomplex multiplication of ymm9 and ymm5 with ymm1*/ + BLIS_ZTRSM_MUL(ymm9) + BLIS_ZTRSM_MUL(ymm5) +#endif a11 += rs_a; - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(ROw2): FMA operations - ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); - ymm11 = _mm256_fnmadd_pd(ymm4, ymm9, ymm11); - - //perform mul operation - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 2)); - ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + //(ROW2): FMA operations + ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a*2)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); - a11 += rs_a; + //For ymm9 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm9, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm9, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //For ymm5 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + ymm13 = _mm256_mul_pd(ymm5, ymm2); + ymm14 = _mm256_mul_pd(ymm5, ymm14); + ymm17 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + ymm17 = _mm256_mul_pd(ymm17, ymm15); + //Step 3 + ymm10 = _mm256_add_pd(ymm16, ymm10); + ymm6 = _mm256_add_pd(ymm17, ymm6); - //(ROw5): FMA operations - ymm11 = _mm256_fnmadd_pd(ymm4, ymm10, ymm11); + ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a*3)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); - //perform mul operation - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + //For ymm9 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm9, ymm2); /*vec3*/ 
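+            /*
+               Note (illustrative only, nothing here is executed): the
+               permute/mul/mul/hsub pattern used for vec3/vec4 in every step of
+               this kernel evaluates one dcomplex product per 128-bit lane.
+               With a = the broadcast A element in ymm2 and x = the already
+               solved B entry, each Step1..Step3 group is equivalent, per lane,
+               to the scalar update
+                   pr = a.real*x.real - a.imag*x.imag;   // Re(a*x), from the hsub
+                   pi = a.real*x.imag + a.imag*x.real;   // Im(a*x), from the hsub
+                   b.real -= pr;                         // Step2 negate + Step3 add
+                   b.imag -= pi;
+               i.e. b -= a*x, the usual substitution update.
+            */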
+ /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm9, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //For ymm5 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + ymm13 = _mm256_mul_pd(ymm5, ymm2); + ymm14 = _mm256_mul_pd(ymm5, ymm14); + ymm17 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + ymm17 = _mm256_mul_pd(ymm17, ymm15); + //Step 3 + ymm11 = _mm256_add_pd(ymm16, ymm11); + ymm7 = _mm256_add_pd(ymm17, ymm7); - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm10 and ymm6 with ymm1*/ + BLIS_ZTRSM_TWO_DIV(ymm10,ymm6) +#else + /*performs dcomplex multiplication of ymm10 and ymm6 with ymm1*/ + BLIS_ZTRSM_MUL(ymm10) + BLIS_ZTRSM_MUL(ymm6) +#endif + a11 += rs_a; + //extract a44 + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 3)); + //(ROW3): FMA operations + ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a*3)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + //For ymm10 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm10, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm10, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //For ymm6 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] + ymm13 = _mm256_mul_pd(ymm6, ymm2); + ymm14 = _mm256_mul_pd(ymm6, ymm14); + ymm17 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + ymm17 = _mm256_mul_pd(ymm17, ymm15); + //Step 3 + ymm11 = _mm256_add_pd(ymm16, ymm11); + ymm7 = _mm256_add_pd(ymm17, ymm7); - n_rem -= 4; - j += 4; +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm11 and ymm7 with ymm1*/ + BLIS_ZTRSM_TWO_DIV(ymm11,ymm7) +#else + /*performs dcomplex nultiplication of ymm11 and ymm7 with ymm1*/ + BLIS_ZTRSM_MUL(ymm11) + BLIS_ZTRSM_MUL(ymm7) +#endif + a11 += rs_a; + BLIS_ZTRSM_SMALL_NREG_TRANSPOSE_4x3_AND_STORE(b11,cs_b) } + + dim_t n_rem = n-j; if(n_rem) { a10 = D_A_pack; - a11 = L + (i*rs_a) + (i*cs_a); //pointer to block of A to be used for TRSM 
+ a11 = L + (i*rs_a) + (i*cs_a);//pointer to block of A to be used for TRSM b01 = B + j*cs_b; //pointer to block of B to be used for GEMM b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM - k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - - if(3 == n_rem) - { - ///GEMM code begins/// - BLIS_DTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); - } - else if(2 == n_rem) + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + if(2 == n_rem) { ///GEMM code begins/// - BLIS_DTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b,p_lda,k_iter) - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] - ymm2 = _mm256_broadcast_sd((double const *)(&ones)); - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + BLIS_ZTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b,p_lda,k_iter) + BLIS_ZTRSM_SMALL_NREG_TRANSPOSE_2x4(b11,cs_b,AlphaVal) } else if(1 == n_rem) { ///GEMM code begins/// - BLIS_DTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b,p_lda,k_iter) - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_broadcast_sd((double const *)(&ones)); - ymm2 = _mm256_broadcast_sd((double const *)(&ones)); - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + BLIS_ZTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b,p_lda,k_iter) + BLIS_ZTRSM_SMALL_NREG_TRANSPOSE_1x4(b11,cs_b,AlphaVal) } + ///implement TRSM/// ///transpose of B11// - ///unpacklow/// - ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - - //rearrange low elements - ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] - - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - - //rearrange high elements - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 
= _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - ymm0 = _mm256_broadcast_sd((double const *)&ones); + ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + ymm10 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm11 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); ////extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); - //perform mul operation - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_ZTRSM_DIV(ymm8) +#else + BLIS_ZTRSM_MUL(ymm8) +#endif //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 1)); + //(ROW1): FMA operations + ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a*1)); + ymm3 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a*2)); + ymm4 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a*3)); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*1)); - ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); - ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + if(conjtransa){ + ymm2 = _mm256_mul_pd(ymm2, ymm0); + ymm3 = _mm256_mul_pd(ymm3, ymm0); + ymm4 = _mm256_mul_pd(ymm4, ymm0); + } a11 += rs_a; + /*Step1 dcomplex multiply ymmx, ymmx + * Step2 negate the result + * Step3 add ymmx*/ + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm8 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm8, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm8, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); - //(ROw1): FMA operations - ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); - ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); - ymm11 = _mm256_fnmadd_pd(ymm4, ymm8, ymm11); - - //perform mul operation - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); - - ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); - ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); - a11 += rs_a; + //Step 3 + ymm9 = _mm256_add_pd(ymm16, ymm9); - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + //Step 1 + ymm14 = _mm256_permute_pd(ymm3, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm8 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm8, ymm3); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm8, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); - //(ROw2): FMA operations - ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); - ymm11 = _mm256_fnmadd_pd(ymm4, ymm9, ymm11); + //Step 3 + ymm10 = _mm256_add_pd(ymm16, ymm10); - //perform mul operation - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + //Step 1 + ymm14 = _mm256_permute_pd(ymm4, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm8 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm8, ymm4); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm8, ymm14); /*vec4*/ + /* Horizontally subtract the elements 
in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + //Step 3 + ymm11 = _mm256_add_pd(ymm16, ymm11); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_ZTRSM_DIV(ymm9) +#else + BLIS_ZTRSM_MUL(ymm9) +#endif + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 2)); + ymm3 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a*2)); + ymm4 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a*3)); - ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + if(conjtransa){ + ymm3 = _mm256_mul_pd(ymm3, ymm0); + ymm4 = _mm256_mul_pd(ymm4, ymm0); + } a11 += rs_a; + //Step 1 + ymm14 = _mm256_permute_pd(ymm3, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm9 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm9, ymm3); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm9, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //(ROw5): FMA operations - ymm11 = _mm256_fnmadd_pd(ymm4, ymm10, ymm11); - - //perform mul operation - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + //Step 3 + ymm10 = _mm256_add_pd(ymm16, ymm10); - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + //Step 1 + ymm14 = _mm256_permute_pd(ymm4, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm8 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm9, ymm4); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm9, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + //Step 3 + ymm11 = _mm256_add_pd(ymm16, ymm11); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_ZTRSM_DIV(ymm10) +#else + BLIS_ZTRSM_MUL(ymm10) +#endif - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 3)); + ymm4 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a*3)); - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + if(conjtransa){ + ymm4 = _mm256_mul_pd(ymm4, ymm0); + } - if(3 == n_rem) - { - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] - } - else if(2 == n_rem) - { - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - } - else if(1 == n_rem) - { - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - } + //Step 1 + ymm14 = _mm256_permute_pd(ymm4, 0x5); + /* 
Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm10 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm10, ymm4); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm10, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + //Step 3 + ymm11 = _mm256_add_pd(ymm16, ymm11); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_ZTRSM_DIV(ymm11) +#else + BLIS_ZTRSM_MUL(ymm11) +#endif + if(n_rem == 1) + { + ymm0 = _mm256_permute2f128_pd(ymm8,ymm9,0x20); + ymm4 = _mm256_permute2f128_pd(ymm10,ymm11,0x20); + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 2), ymm4); + } + else if(n_rem == 2) + { + ymm0 = _mm256_permute2f128_pd(ymm8,ymm9,0x20); + ymm4 = _mm256_permute2f128_pd(ymm10,ymm11,0x20); + ymm1 = _mm256_permute2f128_pd(ymm8,ymm9,0x31); + ymm3 = _mm256_permute2f128_pd(ymm10,ymm11,0x31); + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 2), ymm4); + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 2), ymm3); } - m_rem -=4; - i +=4; } - + } + dim_t m_rem = m-i; if(m_rem) { - a10 = L + (i*cs_a); //pointer to block of A to be used for GEMM - // Do transpose for a10 & store in D_A_pack - double *ptr_a10_dup = D_A_pack; - if(3 == m_rem) // Repetative A blocks will be 3*3 + a10 = L + (i*cs_a); + dcomplex *ptr_a10_dup = D_A_pack; + if(m_rem == 3) { - dim_t p_lda = 4; // packed leading dimension + dim_t p_lda = 4; if(transa) { - for(dim_t x=0;x= 4)) - { - a10 = D_A_pack; //pointer to block of A to be used for GEMM - a11 = L + (i*rs_a) + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM - b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM - - k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) - - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS - - ///GEMM code begins/// - BLIS_DTRSM_SMALL_GEMM_4mx4n(a10,b01,cs_b,p_lda,k_iter) - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); - ymm3 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3 + 2)); - ymm3 = _mm256_insertf128_pd(ymm3, xmm5, 0); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x08); - - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - xmm5 = _mm256_extractf128_pd(ymm3, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 3),xmm5); - _mm_storel_pd((b11 + cs_b * 3 + 2), _mm256_extractf128_pd(ymm3, 1)); - - if(transa) - dtrsm_AutXB_ref(a11, b11, m_rem, 4, cs_a, cs_b,is_unitdiag); - else - dtrsm_AlXB_ref(a11, b11, m_rem, 4, rs_a, 
cs_b, is_unitdiag); - n_rem -= 4; - j +=4; - } - if(n_rem) { - a10 = D_A_pack; //pointer to block of A to be used for GEMM - a11 = L + (i*rs_a) + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM - b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM - - k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) - - /*Fill zeros into ymm registers used in gemm accumulations */ + a10 = D_A_pack; + a11 = L + (i*rs_a) + (i*cs_a); + b01 = B + (j*cs_b); + b11 = B + i + (j* cs_b); + k_iter = i; BLIS_SET_YMM_REG_ZEROS - - if(3 == n_rem) + if(2 == n_rem) { ///GEMM code begins/// - BLIS_DTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) - - BLIS_PRE_DTRSM_SMALL_3M_3N(AlphaVal,b11,cs_b) + BLIS_ZTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b, + p_lda,k_iter) + BLIS_PRE_ZTRSM_SMALL_3M_2N(AlphaVal,b11,cs_b) if(transa) - dtrsm_AutXB_ref(a11, b11, m_rem, 3, cs_a, cs_b,is_unitdiag); - else - dtrsm_AlXB_ref(a11, b11, m_rem, 3, rs_a, cs_b, is_unitdiag); - } - else if(2 == n_rem) - { - ///GEMM code begins/// - BLIS_DTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b,p_lda,k_iter) - - BLIS_PRE_DTRSM_SMALL_3M_2N(AlphaVal,b11,cs_b) + ztrsm_AutXB_ref(a11, b11, m_rem, 2, + cs_a, cs_b, is_unitdiag, + conjtransa); - if(transa) - dtrsm_AutXB_ref(a11, b11, m_rem, 2, cs_a, cs_b,is_unitdiag); else - dtrsm_AlXB_ref(a11, b11, m_rem, 2, rs_a, cs_b, is_unitdiag); + ztrsm_AlXB_ref(a11, b11, m_rem, 2, + rs_a, cs_b, is_unitdiag, + conjtransa); } else if(1 == n_rem) { ///GEMM code begins/// - BLIS_DTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b,p_lda,k_iter) - - BLIS_PRE_DTRSM_SMALL_3M_1N(AlphaVal,b11,cs_b) + BLIS_ZTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b, + p_lda,k_iter) + BLIS_PRE_ZTRSM_SMALL_3M_1N(AlphaVal,b11,cs_b) if(transa) - dtrsm_AutXB_ref(a11, b11, m_rem, 1, cs_a, cs_b, is_unitdiag); + ztrsm_AutXB_ref(a11, b11, m_rem, 1, + cs_a, cs_b, is_unitdiag, + conjtransa); else - dtrsm_AlXB_ref(a11, b11, m_rem, 1, rs_a, cs_b, is_unitdiag); + ztrsm_AlXB_ref(a11, b11, m_rem, 1, + rs_a, cs_b, is_unitdiag, + conjtransa); + } } + m_rem -=3; + i+=3; } - else if(2 == m_rem) // Repetative A blocks will be 2*2 + else if(m_rem == 2) { - dim_t p_lda = 4; // packed leading dimension - if(transa) - { - for(dim_t x=0;x= 4)) - { - a10 = D_A_pack; //pointer to block of A to be used for GEMM - a11 = L + (i*rs_a) + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM - b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM - - k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) - - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS - - ///GEMM code begins/// - BLIS_DTRSM_SMALL_GEMM_4mx4n(a10,b01,cs_b,p_lda,k_iter) - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); - ymm3 = _mm256_insertf128_pd(ymm3, xmm5, 0); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0C); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0C); - - 
_mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - xmm5 = _mm256_extractf128_pd(ymm3, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 3), xmm5); - - if(transa) - dtrsm_AutXB_ref(a11, b11, m_rem, 4, cs_a, cs_b, is_unitdiag); - else - dtrsm_AlXB_ref(a11, b11, m_rem, 4, rs_a, cs_b, is_unitdiag); - n_rem -= 4; - j +=4; - } if(n_rem) { - a10 = D_A_pack; //pointer to block of A to be used for GEMM - a11 = L + (i*rs_a) + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM - b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM - - k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) - - /*Fill zeros into ymm registers used in gemm accumulations */ + a10 = D_A_pack; + a11 = L + (i*rs_a) + (i*cs_a); + b01 = B + (j*cs_b); + b11 = B + i + (j* cs_b); + k_iter = i; BLIS_SET_YMM_REG_ZEROS - - if(3 == n_rem) + if(2 == n_rem) { ///GEMM code begins/// - BLIS_DTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) - - BLIS_PRE_DTRSM_SMALL_2M_3N(AlphaVal,b11,cs_b) + BLIS_ZTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b, + p_lda,k_iter) + BLIS_PRE_ZTRSM_SMALL_2M_2N(AlphaVal,b11,cs_b) if(transa) - dtrsm_AutXB_ref(a11, b11, m_rem, 3, cs_a, cs_b, is_unitdiag); - else - dtrsm_AlXB_ref(a11, b11, m_rem, 3, rs_a, cs_b, is_unitdiag); - } - else if(2 == n_rem) - { - ///GEMM code begins/// - BLIS_DTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b,p_lda,k_iter) - - BLIS_PRE_DTRSM_SMALL_2M_2N(AlphaVal,b11,cs_b) + ztrsm_AutXB_ref(a11, b11, m_rem, 2, + cs_a, cs_b, is_unitdiag, + conjtransa); - if(transa) - dtrsm_AutXB_ref(a11, b11, m_rem, 2, cs_a, cs_b, is_unitdiag); else - dtrsm_AlXB_ref(a11, b11, m_rem, 2, rs_a, cs_b, is_unitdiag); + ztrsm_AlXB_ref(a11, b11, m_rem, 2, + rs_a, cs_b, is_unitdiag, + conjtransa); } else if(1 == n_rem) { ///GEMM code begins/// - BLIS_DTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b,p_lda,k_iter) + BLIS_ZTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b, + p_lda,k_iter) - BLIS_PRE_DTRSM_SMALL_2M_1N(AlphaVal,b11,cs_b) + BLIS_PRE_ZTRSM_SMALL_2M_1N(AlphaVal,b11,cs_b) if(transa) - dtrsm_AutXB_ref(a11, b11, m_rem, 1, cs_a, cs_b, is_unitdiag); + ztrsm_AutXB_ref(a11, b11, m_rem, 1, + cs_a, cs_b, is_unitdiag, + conjtransa); + else - dtrsm_AlXB_ref(a11, b11, m_rem, 1, rs_a, cs_b, is_unitdiag); + ztrsm_AlXB_ref(a11, b11, m_rem, 1, + rs_a, cs_b, is_unitdiag, + conjtransa); } } m_rem -=2; i+=2; } - else if(1 == m_rem) // Repetative A blocks will be 1*1 + else if(m_rem == 1) { - dim_t p_lda = 4; // packed leading dimension + dim_t p_lda = 2; // packed leading dimension if(transa) { - for(dim_t x=0;x= 4)) - { - a10 = D_A_pack; //pointer to block of A to be used for GEMM - a11 = L + (i*rs_a) + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM - b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM - - k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) - - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS + ymm14 = _mm256_permute_pd(ymm16, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm18); + ymm17 = _mm256_mul_pd(ymm0, ymm16); + ymm14 = _mm256_mul_pd(ymm0, ymm14); + ymm15 = _mm256_hsub_pd(ymm17, ymm14); - ///GEMM code begins/// - BLIS_DTRSM_SMALL_GEMM_4mx4n(a10,b01,cs_b,p_lda,k_iter) + ymm8 = _mm256_sub_pd(ymm15,ymm8); - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); 
//register to hold alpha + ymm14 = _mm256_permute_pd(ymm16, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm18); + ymm17 = _mm256_mul_pd(ymm1, ymm16); + ymm14 = _mm256_mul_pd(ymm1, ymm14); + ymm15 = _mm256_hsub_pd(ymm17, ymm14); - ///implement TRSM/// - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_broadcast_sd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_broadcast_sd((double const *)(b11 + cs_b *2)); - ymm3 = _mm256_broadcast_sd((double const *)(b11 + cs_b *3)); + ymm9 = _mm256_sub_pd(ymm15,ymm9); + ymm14 = _mm256_permute_pd(ymm16, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm18); + ymm17 = _mm256_mul_pd(ymm2, ymm16); + ymm14 = _mm256_mul_pd(ymm2, ymm14); + ymm15 = _mm256_hsub_pd(ymm17, ymm14); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + ymm10 = _mm256_sub_pd(ymm15,ymm10); - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0E); - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm0, 0)); - _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm1, 0)); - _mm_storel_pd((b11 + cs_b * 2), _mm256_extractf128_pd(ymm2, 0)); - _mm_storel_pd((b11 + cs_b * 3), _mm256_extractf128_pd(ymm3, 0)); + _mm_storeu_pd((double *)(b11), + _mm256_extractf128_pd(ymm8,0)); + _mm_storeu_pd((double *)(b11 + cs_b * 1), + _mm256_extractf128_pd(ymm9,0)); + _mm_storeu_pd((double *)(b11 + cs_b * 2), + _mm256_extractf128_pd(ymm10,0)); if(transa) - dtrsm_AutXB_ref(a11, b11, m_rem, 4, cs_a, cs_b, is_unitdiag); + ztrsm_AutXB_ref(a11, b11, m_rem, 3, + cs_a, cs_b, is_unitdiag, + conjtransa); + else - dtrsm_AlXB_ref(a11, b11, m_rem, 4, rs_a, cs_b, is_unitdiag); - n_rem -= 4; - j+=4; + ztrsm_AlXB_ref(a11, b11, m_rem, 3, rs_a, + cs_b, is_unitdiag, + conjtransa); } - if(n_rem) - { - a10 = D_A_pack; //pointer to block of A to be used for GEMM - a11 = L + (i*rs_a) + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM - b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM - - k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) - - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS - - if(3 == n_rem) - { - ///GEMM code begins/// - BLIS_DTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) - - BLIS_PRE_DTRSM_SMALL_1M_3N(AlphaVal,b11,cs_b) - - if(transa) - dtrsm_AutXB_ref(a11, b11, m_rem, 3, cs_a, cs_b, is_unitdiag); - else - dtrsm_AlXB_ref(a11, b11, m_rem, 3, rs_a, cs_b, is_unitdiag); - } - else if(2 == n_rem) + dim_t n_rem = n-j; + if(n_rem) + { + a10 = D_A_pack; + a11 = L + (i*rs_a) + (i*cs_a); + b01 = B + (j*cs_b); + b11 = B + i + (j* cs_b); + k_iter = i; + BLIS_SET_YMM_REG_ZEROS + if(2 == n_rem) { ///GEMM code begins/// - BLIS_DTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b,p_lda,k_iter) - - BLIS_PRE_DTRSM_SMALL_1M_2N(AlphaVal,b11,cs_b) + BLIS_ZTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b, + p_lda,k_iter) + BLIS_PRE_ZTRSM_SMALL_1M_2N(AlphaVal,b11,cs_b) if(transa) - dtrsm_AutXB_ref(a11, b11, m_rem, 2, cs_a, cs_b, is_unitdiag); + ztrsm_AutXB_ref(a11, b11, m_rem, 2, + cs_a, cs_b, is_unitdiag, + conjtransa); + else - dtrsm_AlXB_ref(a11, b11, m_rem, 2, rs_a, cs_b, is_unitdiag); + ztrsm_AlXB_ref(a11, b11, m_rem, 2, + rs_a, cs_b, is_unitdiag, + conjtransa); } else if(1 == n_rem) { ///GEMM code begins/// - 
BLIS_DTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b,p_lda,k_iter) + BLIS_ZTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b, + p_lda,k_iter) - BLIS_PRE_DTRSM_SMALL_1M_1N(AlphaVal,b11,cs_b) + BLIS_PRE_ZTRSM_SMALL_1M_1N(AlphaVal,b11,cs_b) if(transa) - dtrsm_AutXB_ref(a11, b11, m_rem, 1, cs_a, cs_b, is_unitdiag); + ztrsm_AutXB_ref(a11, b11, m_rem, 1, + cs_a, cs_b, is_unitdiag, + conjtransa); + else - dtrsm_AutXB_ref(a11, b11, m_rem, 1, rs_a, cs_b, is_unitdiag); + ztrsm_AlXB_ref(a11, b11, m_rem, 1, + rs_a, cs_b, is_unitdiag, + conjtransa); } } m_rem -=1; @@ -10859,7218 +32429,4083 @@ BLIS_INLINE err_t bli_dtrsm_small_AutXB_AlXB { bli_membrk_release(&rntm, &local_mem_buf_A_s); } - return BLIS_SUCCESS; + return BLIS_SUCCESS; } -/* - * ZTRSM utilities and kernel functions - */ - -#define DCOMPLEX_INV(a, b) {\ - a.real = b.real;\ - a.imag = (b.imag * -1.0);\ - /*Compute denominator eliminating imaginary component*/\ - double dnm = (b.real * b.real);\ - /*multiply two times with -1 for correct result as - * dcomplex number with positive imaginary part will - * invert the sign if not multiplied twice with -1*/\ - dnm += ((-1.0 * (b.imag * b.imag)) * -1.0);\ - /*Compute the final result by dividing real and imag part by dnm*/\ - a.real /= dnm;\ - a.imag /= dnm;\ -} +BLIS_INLINE err_t bli_ztrsm_small_AltXB_AuXB +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +) +{ + dim_t m = bli_obj_length(b); // number of rows of matrix B + dim_t n = bli_obj_width(b); // number of columns of matrix B -#define DCOMPLEX_MUL(a, b, c) {\ - double real = a.real * b.real;\ - real += ((a.imag * b.imag) * -1.0);\ - double imag = (a.real * b.imag);\ - imag += (a.imag * b.real);\ - c.real = real;\ - c.imag = imag;\ -} + bool transa = bli_obj_has_trans(a); + bool conjtransa = bli_obj_has_conj(a); -#define DCOMPLEX_DIV(a, b){\ - double dnm = b.real * b.real;\ - dnm += (-1.0 * (b.imag * (b.imag * -1.0) ));\ - a.real /= dnm;\ - a.imag /= dnm;\ -} + dim_t cs_a, rs_a; + dim_t d_mr = 4,d_nr = 3; + // Swap rs_a & cs_a in case of non-tranpose. 
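+    /*
+       Illustration (comment only): with row stride rs and column stride cs,
+       element A(p,q) sits at a[p*rs + q*cs], and the transposed element
+       A'(p,q) = A(q,p) sits at a[q*rs + p*cs], i.e. the same address formula
+       with the two strides exchanged.  Assigning the object's strides to
+       rs_a/cs_a in swapped order for one of the two cases therefore lets the
+       transpose and non-transpose paths share the indexing used in the body
+       below.
+    */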
+ if(transa) + { + cs_a = bli_obj_col_stride(a); // column stride of A + rs_a = bli_obj_row_stride(a); // row stride of A + } + else + { + cs_a = bli_obj_row_stride(a); // row stride of A + rs_a = bli_obj_col_stride(a); // column stride of A + } + dim_t cs_b = bli_obj_col_stride(b); // column stride of B -#ifdef BLIS_ENABLE_TRSM_PREINVERSION -#define ZTRSM_DIAG_ELE_INV_OPS(a,b){\ - DCOMPLEX_INV(a, b)\ -} -#endif + dim_t i, j, k; //loop variables + dim_t k_iter; //number of times GEMM to be performed -#ifdef BLIS_DISABLE_TRSM_PREINVERSION -#define ZTRSM_DIAG_ELE_INV_OPS(a,b) {\ - a.real = b.real;\ - a.imag = b.imag;\ -} -#endif + dcomplex AlphaVal = *(dcomplex *)AlphaObj->buffer; //value of alpha + dcomplex *L = a->buffer; //pointer to matrix A + dcomplex *B = b->buffer; //pointer to matrix B + //pointers that point to blocks for GEMM and TRSM + dcomplex *a10, *a11, *b01, *b11; -#ifdef BLIS_ENABLE_TRSM_PREINVERSION -#define ZTRSM_DIAG_ELE_EVAL_OPS(a,b,c){\ - if(!is_unitdiag)\ - DCOMPLEX_MUL(b, c, c)\ -} -#endif + dcomplex ones = {1.0, 1.0}; + bool is_unitdiag = bli_obj_has_unit_diag(a); -#ifdef BLIS_DISABLE_TRSM_PREINVERSION -#define ZTRSM_DIAG_ELE_EVAL_OPS(a,b,c){\ - if(!is_unitdiag)\ - {\ - a.real = b.real;\ - a.imag = (b.imag * -1.0);\ - DCOMPLEX_MUL(c, a, c)\ - DCOMPLEX_DIV(c, b)\ - }\ -} -#endif + //scratch registers + __m256d ymm0, ymm1, ymm2, ymm3; + __m256d ymm4, ymm5, ymm6, ymm7; + __m256d ymm8, ymm9, ymm10, ymm11; + __m256d ymm12, ymm13, ymm14, ymm15; + __m256d ymm16, ymm17, ymm18, ymm19; -BLIS_INLINE err_t ztrsm_AltXB_ref -( - dcomplex *A, - dcomplex *B, - dim_t M, - dim_t N, - dim_t lda, - dim_t ldb, - bool is_unitdiag, - bool conjtransa -) -{ - dim_t i, j, k; - for (k = M-1; k >= 0; k--) - { - dcomplex lkk_inv = {1.0, 1.0}, cur_compute = {0.0, 0.0}, A_trans = {0.0, 0.0}; - if(!is_unitdiag) - { - ZTRSM_DIAG_ELE_INV_OPS(lkk_inv, A[k+k*lda]) - if(conjtransa) - { - lkk_inv.imag *= -1.0; - } - } - for (j = N -1; j >= 0; j--) - { - ZTRSM_DIAG_ELE_EVAL_OPS(cur_compute, lkk_inv, B[k + j*ldb]) - for (i = k-1; i >=0; i--) - { - if(conjtransa) - { - A_trans.real = A[i*lda + k].real; - A_trans.imag = A[i*lda + k].imag * -1.0; - } - else - { - A_trans.real = A[i*lda + k].real; - A_trans.imag = A[i*lda + k].imag; - } + __m128d xmm5, xmm4; + gint_t required_packing_A = 1; + mem_t local_mem_buf_A_s = {0}; + dcomplex *D_A_pack = NULL; + dcomplex d11_pack[d_mr] __attribute__((aligned(64))); + rntm_t rntm; - DCOMPLEX_MUL(A_trans, B[k+j*ldb], cur_compute) - B[i + j*ldb].real -= cur_compute.real; - B[i + j*ldb].imag -= cur_compute.imag; - } - } - } - return BLIS_SUCCESS; -} + bli_rntm_init_from_global( &rntm ); + bli_rntm_set_num_threads_only( 1, &rntm ); + bli_membrk_rntm_set_membrk( &rntm ); -BLIS_INLINE err_t ztrsm_AutXB_ref -( - dcomplex *A, - dcomplex *B, - dim_t M, - dim_t N, - dim_t lda, - dim_t ldb, - bool is_unitdiag, - bool conjtransa -) -{ - dim_t i, j, k; - for (k = 0; k < M; k++) - { - dcomplex lkk_inv = {1.0, 1.0}, cur_compute = {0.0, 0.0}, A_trans = {0.0, 0.0}; - if(!is_unitdiag) - { - ZTRSM_DIAG_ELE_INV_OPS(lkk_inv, A[k+k*lda]) - if(conjtransa) - { - lkk_inv.imag *= -1.0; - } - } - - for (j = 0; j < N; j++) - { - ZTRSM_DIAG_ELE_EVAL_OPS(cur_compute, lkk_inv, B[k + j*ldb]) - for (i = k+1; i < M; i++) - { - if(conjtransa) - { - A_trans.real = A[k+i*lda].real; - A_trans.imag = A[k+i*lda].imag * -1.0; - } - else - { - A_trans.real = A[k+i*lda].real; - A_trans.imag = A[k+i*lda].imag; - } - - DCOMPLEX_MUL(A_trans, B[k+j*ldb], cur_compute) - B[i + j*ldb].real -= cur_compute.real; - B[i + 
j*ldb].imag -= cur_compute.imag; - } - - } + siz_t buffer_size = bli_pool_block_size( + bli_membrk_pool( + bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), + bli_rntm_membrk(&rntm))); - } - return BLIS_SUCCESS; -} + if((d_mr * m * sizeof(dcomplex)) > buffer_size) + return BLIS_NOT_YET_IMPLEMENTED; -BLIS_INLINE err_t ztrsm_AlXB_ref -( - dcomplex *A, - dcomplex *B, - dim_t M, - dim_t N, - dim_t lda, - dim_t ldb, - bool is_unitdiag, - bool conjtransa -) -{ - dim_t i, j, k; - for (k = 0; k < M; k++) + if(required_packing_A == 1) { - dcomplex lkk_inv = {1.0, 1.0}, cur_compute = {0.0, 0.0}, A_trans = {0.0, 0.0}; - if(!is_unitdiag) - { - ZTRSM_DIAG_ELE_INV_OPS(lkk_inv, A[k+k*lda]) - if(conjtransa) - { - lkk_inv.imag *= -1.0; - } - } - for (j = 0; j < N; j++) - { - ZTRSM_DIAG_ELE_EVAL_OPS(cur_compute, lkk_inv, B[k + j*ldb]) - for (i = k+1; i < M; i++) - { - if(conjtransa) - { - A_trans.real = A[i+k*lda].real; - A_trans.imag = A[i+k*lda].imag * -1.0; - } - else - { - A_trans.real = A[i+k*lda].real; - A_trans.imag = A[i+k*lda].imag; - } - DCOMPLEX_MUL(A_trans, B[k+j*ldb], cur_compute) - B[i + j*ldb].real -= cur_compute.real; - B[i + j*ldb].imag -= cur_compute.imag; - } - } + // Get the buffer from the pool. + bli_membrk_acquire_m(&rntm, + buffer_size, + BLIS_BITVAL_BUFFER_FOR_A_BLOCK, + &local_mem_buf_A_s); + if(FALSE==bli_mem_is_alloc(&local_mem_buf_A_s)) return BLIS_NULL_POINTER; + D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); + if(NULL==D_A_pack) return BLIS_NULL_POINTER; } - return BLIS_SUCCESS; -} -BLIS_INLINE err_t ztrsm_AuXB_ref -( - dcomplex *A, - dcomplex *B, - dim_t M, - dim_t N, - dim_t lda, - dim_t ldb, - bool is_unitdiag, - bool conjtransa -) -{ - dim_t i, j, k; - for (k = M-1; k >= 0; k--) + /* + Performs solving TRSM for 4 colmns at a time from 0 to m/d_mr in steps of d_mr + a. Load, transpose, Pack A (a10 block), the size of packing 8x6 to 8x (m-d_mr) + First there will be no GEMM and no packing of a10 because it is only TRSM + b. Using packed a10 block and b01 block perform GEMM operation + c. Use GEMM outputs, perform TRSM operaton using a11, b11 and update B + d. Repeat b,c for n rows of B in steps of d_nr + */ + for(i = (m - d_mr); (i + 1) > 0; i -= d_mr) { - dcomplex lkk_inv = {1.0, 1.0}, cur_compute = {0.0, 0.0}, A_trans = {0.0, 0.0}; - if(!is_unitdiag) - { - ZTRSM_DIAG_ELE_INV_OPS(lkk_inv, A[k+k*lda]) - if(conjtransa) - { - lkk_inv.imag *= -1.0; - } - - } - for (j = N -1; j >= 0; j--) + a10 = L + (i*cs_a) + (i + d_mr)*rs_a;//pointer to block of A to be used for GEMM + a11 = L + (i*cs_a) + (i*rs_a);//pointer to block of A to be used for TRSM + + // Do transpose for a10 & store in D_A_pack + //ptr_a10_dup = D_A_pack; + + dim_t p_lda = d_mr; // packed leading dimension + + if(transa) { - ZTRSM_DIAG_ELE_EVAL_OPS(cur_compute, lkk_inv, B[k + j*ldb]) - for (i = k-1; i >=0; i--) - { - if(conjtransa) - { - A_trans.real = A[i+k*lda].real; - A_trans.imag = A[i+k*lda].imag * -1.0; - } - else - { - A_trans.real = A[i+k*lda].real; - A_trans.imag = A[i+k*lda].imag; - } + /* + Load, transpose and pack current A block (a10) into packed buffer memory + D_A_pack + a. This a10 block is used in GEMM portion only and this + a10 block size will be increasing by d_mr for every next itteration + untill it reaches 4x(m-4) which is the maximum GEMM alone block size + in A + b. 
This packed buffer is reused to calculate all n rows of B matrix + */ + bli_ztrsm_small_pack('L', (m-i-d_mr), 1, a10, cs_a, D_A_pack,p_lda,d_mr); - DCOMPLEX_MUL(A_trans, B[k+j*ldb], cur_compute) - B[i + j*ldb].real -= cur_compute.real; - B[i + j*ldb].imag -= cur_compute.imag; - } + /* + Pack 8 diagonal elements of A block into an array + a. This helps in utilze cache line efficiently in TRSM operation + b. store ones when input is unit diagonal + */ + ztrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,d_mr); } - } - return BLIS_SUCCESS; -} + else + { + bli_ztrsm_small_pack('L', (m-i-d_mr), 0, a10, rs_a, D_A_pack,p_lda,d_mr); + ztrsm_small_pack_diag_element(is_unitdiag,a11,rs_a,d11_pack,d_mr); + } + + /* + a. Perform GEMM using a10, b01. + b. Perform TRSM on a11, b11 + c. This loop GEMM+TRSM loops operates with 8x6 block size + along n dimension for every d_nr rows of b01 where + packed A buffer is reused in computing all n rows of B. + d. Same approch is used in remaining fringe cases. + */ + for(j = (n - d_nr); (j + 1) > 0; j -= d_nr) + { + a10 = D_A_pack; + b01 = B + (j * cs_b) + i + d_mr;//pointer to block of B to be used for GEMM + b11 = B + (j * cs_b) + i;//pointer to block of B to be used for TRSM + + k_iter = (m - i - d_mr); + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + /* + Peform GEMM between a10 and b01 blocks + For first itteration there will be no GEMM operation + where k_iter are zero + */ + BLIS_ZTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) -/** - * Multiplies Alpha with one dcomplex - * element of one column. - * One xmm register holds one dcomplex - * element only(real(64 bit) + imaginary(64 bit)) - */ -#define BLIS_PRE_ZTRSM_SMALL_1M_1N(AlphaVal,b11,cs_b) {\ - /*register to hold alpha*/\ - ymm16 = _mm256_broadcast_pd(( __m128d const *)(&AlphaVal));\ - \ - /*load dcomplex elements*/\ - xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 0));\ - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ - /*to negate the real part of complex number*/\ - ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ - /*dcomplex multiplication and substraction*/\ - /*swaps position of real and imag components of complex number*/\ - ymm14 = _mm256_permute_pd(ymm16, 0x5);\ - /*multiply with modified vec2 */\ - ymm14 = _mm256_mul_pd(ymm14, ymm18);\ - ymm17 = _mm256_mul_pd(ymm0, ymm16);\ - /*multiply with vec2 */\ - ymm14 = _mm256_mul_pd(ymm0, ymm14);\ - /*get the dcomplex mul answer into register*/\ - ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ - ymm8 = _mm256_sub_pd(ymm15,ymm8);\ - xmm5 = _mm256_extractf128_pd(ymm8, 0);\ - /*store dcomplex elements*/\ - _mm_storeu_pd((double *)(b11 + cs_b * 0), xmm5);\ -} + /* + Load b11 of size 6x8 and multiply with alpha + Add the GEMM output and perform inregister transose of b11 + to peform TRSM operation. + */ + BLIS_ZTRSM_SMALL_NREG_TRANSPOSE_3x4(b11,cs_b,AlphaVal) -/** - * Multiplies Alpha with one dcomplex - * element of two columns. 
- */ -#define BLIS_PRE_ZTRSM_SMALL_1M_2N(AlphaVal,b11,cs_b) {\ - /*register to hold alpha*/\ - ymm16 = _mm256_broadcast_pd(( __m128d const*)(&AlphaVal));\ - \ - /*ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));*/\ - xmm4 = _mm_loadu_pd((double const *)(b11 + cs_b * 0));\ - xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 1));\ - ymm0 = _mm256_insertf128_pd(ymm0, xmm4, 0);\ - ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0);\ - /*to negate the real part of complex number*/\ - ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ - /*swaps position of real and imag components of complex number*/\ - ymm14 = _mm256_permute_pd(ymm16, 0x5);\ - /*dcomplex multiplication and substraction*/\ - /*multiply with modified vec2 */\ - ymm14 = _mm256_mul_pd(ymm14, ymm18);\ - ymm17 = _mm256_mul_pd(ymm0, ymm16);\ - /*multiply with vec2 */\ - ymm14 = _mm256_mul_pd(ymm0, ymm14);\ - /*get the dcomplex mul answer into register*/\ - ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ - ymm8 = _mm256_sub_pd(ymm15,ymm8);\ - \ - ymm14 = _mm256_permute_pd(ymm16, 0x5);\ - ymm14 = _mm256_mul_pd(ymm14, ymm18);\ - ymm17 = _mm256_mul_pd(ymm1, ymm16);\ - ymm14 = _mm256_mul_pd(ymm1, ymm14);\ - ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ - ymm9 = _mm256_sub_pd(ymm15,ymm9);\ - xmm4 = _mm256_extractf128_pd(ymm8, 0);\ - _mm_storeu_pd((double *)(b11 + cs_b * 0), xmm4);\ - xmm5 = _mm256_extractf128_pd(ymm9, 0);\ - _mm_storeu_pd((double *)(b11 + cs_b * 1), xmm5);\ -} + /* + Compute 4x3 TRSM block by using GEMM block output in register + a. The 4x3 input (gemm outputs) are stored in combinations of ymm + registers + 1. ymm8, ymm4 2. ymm9, ymm5 3. ymm10, ymm6, 4. ymm11, ymm7 + where ymm8-ymm11 holds 4x2 data and reaming 4x1 will be hold by + other registers + b. Towards the end do in regiser transpose of TRSM output and store in + b11 + */ + ////extract a00 + ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 3)); -#define BLIS_ZTRSM_SMALL_NREG_TRANSPOSE_1x4(b11,cs_b,AlphaVal) {\ - ymm16 = _mm256_broadcast_pd(( __m128d const *)&AlphaVal);\ - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));\ - ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 2));\ - ymm1 = _mm256_broadcast_pd((__m128d const *)(&ones));\ - ymm5 = _mm256_broadcast_pd((__m128d const *)(&ones));\ - \ - ymm14 = _mm256_shuffle_pd(ymm16, ymm16, 5);\ - \ - /*dcomplex multiplication and substraction*/\ - ymm17 = _mm256_shuffle_pd(ymm0, ymm0, 15);\ - ymm18 = _mm256_shuffle_pd(ymm0, ymm0,0);\ - ymm19 = _mm256_mul_pd(ymm17, ymm14);\ - ymm15 = _mm256_fmaddsub_pd(ymm18, ymm16, ymm19);\ - ymm0 = _mm256_sub_pd(ymm15, ymm8);\ - \ - /*dcomplex multiplication and substraction*/\ - ymm17 = _mm256_shuffle_pd(ymm4, ymm4, 15);\ - ymm18 = _mm256_shuffle_pd(ymm4, ymm4,0);\ - ymm19 = _mm256_mul_pd(ymm17, ymm14);\ - ymm15 = _mm256_fmaddsub_pd(ymm18, ymm16, ymm19);\ - ymm4 = _mm256_sub_pd(ymm15, ymm12);\ -} +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm11 and ymm7 with ymm1*/ + BLIS_ZTRSM_TWO_DIV(ymm11,ymm7) +#else + /*performs dcomplex multiplication of ymm11 and ymm7 with ymm1*/ + BLIS_ZTRSM_MUL(ymm11) + BLIS_ZTRSM_MUL(ymm7) +#endif + //extract a11 + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 2)); + //(ROW1): FMA operations + ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a*2 + rs_a*3)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + /* Step1 dcomplex multiply ymm2, ymm8 + * Step2 negate the result + * Step3 add ymm9*/ + //Step 1 + 
ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm11 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm11, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm11, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); -/** - * Multiplies Alpha with two dcomplex - * elements of one column and store it into - * buffer b11. - */ -#define BLIS_PRE_ZTRSM_SMALL_2M_1N(AlphaVal,b11,cs_b) {\ - ymm16 = _mm256_broadcast_pd(( __m128d const*)(&AlphaVal));\ - \ - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0));\ - ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ - /*dcomplex multiplication and substraction*/\ - ymm14 = _mm256_permute_pd(ymm16, 0x5);\ - ymm14 = _mm256_mul_pd(ymm14, ymm18);\ - ymm17 = _mm256_mul_pd(ymm0, ymm16);\ - ymm14 = _mm256_mul_pd(ymm0, ymm14);\ - ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ - ymm8 = _mm256_sub_pd(ymm15,ymm8);\ - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm8);\ -} + //For ymm7 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); -/** - * Multiplies Alpha with two elements of - * two columns and store the result in buffer b11 - * - */ -#define BLIS_PRE_ZTRSM_SMALL_2M_2N(AlphaVal,b11,cs_b){\ - ymm16 = _mm256_broadcast_pd(( __m128d const*)(&AlphaVal));\ - \ - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));\ - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1));\ - ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ - /*dcomplex multiplication and substraction*/\ - ymm14 = _mm256_permute_pd(ymm16, 0x5);\ - ymm14 = _mm256_mul_pd(ymm14, ymm18);\ - ymm17 = _mm256_mul_pd(ymm0, ymm16);\ - ymm14 = _mm256_mul_pd(ymm0, ymm14);\ - ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ - ymm8 = _mm256_sub_pd(ymm15,ymm8);\ - \ - /*dcomplex multiplication and substraction*/\ - ymm14 = _mm256_permute_pd(ymm16, 0x5);\ - ymm14 = _mm256_mul_pd(ymm14, ymm18);\ - ymm17 = _mm256_mul_pd(ymm1, ymm16);\ - ymm14 = _mm256_mul_pd(ymm1, ymm14);\ - ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ - ymm9 = _mm256_sub_pd(ymm15,ymm9);\ - \ - _mm256_storeu_pd((double *)(b11), ymm8);\ - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm9);\ -} + ymm13 = _mm256_mul_pd(ymm7, ymm2); + ymm14 = _mm256_mul_pd(ymm7, ymm14); + ymm17 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + ymm17 = _mm256_mul_pd(ymm17, ymm15); -/** - * Performs GEMM operation. - * Two elements of column in ymm0 - * ymm1, ymm2 holds respective broadcasted element. 
- */ -#define BLIS_ZTRSM_SMALL_GEMM_2mx3n(a10,b01,cs_b,p_lda,k_iter){\ - double *tptr = (double *)b01;\ - if(conjtransa) {\ - ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ - for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ - {\ - ymm0 = _mm256_loadu_pd((double const *)(a10));\ - ymm0 = _mm256_mul_pd(ymm0, ymm18);\ - \ - ymm1 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0));\ - ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0 + 1));\ - \ - ymm8 = _mm256_fmadd_pd(ymm0, ymm1, ymm8);\ - ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4);\ - \ - ymm1 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1));\ - ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1 + 1));\ - \ - ymm9 = _mm256_fmadd_pd(ymm0, ymm1, ymm9);\ - ymm5 = _mm256_fmadd_pd(ymm0, ymm2, ymm5);\ - \ - ymm1 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 2));\ - ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 2 + 1));\ - \ - ymm10 = _mm256_fmadd_pd(ymm0, ymm1, ymm10);\ - ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6);\ - \ - tptr += 2; /*move to next row of B*/\ - a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ - }\ - }\ - else {\ - for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ - {\ - ymm0 = _mm256_loadu_pd((double const *)(a10));\ - \ - ymm1 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0));\ - ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0 + 1));\ - \ - ymm8 = _mm256_fmadd_pd(ymm0, ymm1, ymm8);\ - ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4);\ - \ - ymm1 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1));\ - ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1 + 1));\ - \ - ymm9 = _mm256_fmadd_pd(ymm0, ymm1, ymm9);\ - ymm5 = _mm256_fmadd_pd(ymm0, ymm2, ymm5);\ - \ - ymm1 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 2));\ - ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 2 + 1));\ - \ - ymm10 = _mm256_fmadd_pd(ymm0, ymm1, ymm10);\ - ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6);\ - \ - tptr += 2; /*move to next row of B*/\ - a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ - }\ - }\ - ymm4 = _mm256_permute_pd(ymm4, 0x5);\ - ymm5 = _mm256_permute_pd(ymm5, 0x5);\ - ymm6 = _mm256_permute_pd(ymm6, 0x5);\ - ymm8 = _mm256_addsub_pd(ymm8, ymm4);\ - ymm9 = _mm256_addsub_pd(ymm9, ymm5);\ - ymm10 = _mm256_addsub_pd(ymm10, ymm6);\ -} + //Step 3 + ymm10 = _mm256_add_pd(ymm16, ymm10); + ymm6 = _mm256_add_pd(ymm17, ymm6); -/** - * Performs GEMM operation. - * Four elements of column in ymm0, ymm1. - * ymm2, ymm7 holds respective broadcasted element. 
- */ -#define BLIS_ZTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) {\ - double *tptr = (double *)a01;\ - if(conjtransa) {\ - ymm18 = _mm256_set_pd(-1.0, -1.0, -1.0, -1.0);\ - for(k = 0; k < k_iter; k++)\ - {\ - ymm0 = _mm256_loadu_pd((double const *)b10);\ - ymm1 = _mm256_loadu_pd((double const *)(b10 + 2));\ - \ - _mm_prefetch((char*)( b10 + 4*cs_b), _MM_HINT_T0); \ - ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0));\ - ymm7 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1));\ - ymm7 = _mm256_mul_pd(ymm7, ymm18);\ - /*dcomplex multiplication and substraction*/\ - \ - ymm3 = _mm256_fmadd_pd(ymm0, ymm2, ymm3);\ - ymm4 = _mm256_fmadd_pd(ymm1, ymm2, ymm4);\ - ymm5 = _mm256_fmadd_pd(ymm0, ymm7, ymm5);\ - ymm6 = _mm256_fmadd_pd(ymm1, ymm7, ymm6);\ - /*dcomplex multiplication and substraction*/\ - \ - tptr += 2;\ - b10 += cs_b;\ - }\ - }\ - else {\ - for(k = 0; k < k_iter; k++)\ - {\ - ymm0 = _mm256_loadu_pd((double const *)b10);\ - ymm1 = _mm256_loadu_pd((double const *)(b10 + 2));\ - \ - _mm_prefetch((char*)( b10 + 4*cs_b), _MM_HINT_T0); \ - ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0));\ - ymm7 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1));\ - /*dcomplex multiplication and substraction*/\ - \ - ymm3 = _mm256_fmadd_pd(ymm0, ymm2, ymm3);\ - ymm4 = _mm256_fmadd_pd(ymm1, ymm2, ymm4);\ - ymm5 = _mm256_fmadd_pd(ymm0, ymm7, ymm5);\ - ymm6 = _mm256_fmadd_pd(ymm1, ymm7, ymm6);\ - /*ymm3 = _mm256_add_pd(ymm15, ymm3);*/\ - /*dcomplex multiplication and substraction*/\ - \ - tptr += 2;\ - b10 += cs_b;\ - }\ - }\ - ymm5 = _mm256_permute_pd(ymm5, 0x5);\ - ymm6 = _mm256_permute_pd(ymm6, 0x5);\ -\ - ymm3 = _mm256_addsub_pd(ymm3, ymm5);\ - ymm4 = _mm256_addsub_pd(ymm4, ymm6);\ -} + ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a*1 + rs_a*3)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm11 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm11, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm11, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //For ymm7 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); -/** - * Multiplies Alpha with 4 elements of column - */ -#define BLIS_PRE_ZTRSM_SMALL_1x4(b11,cs_b,AlphaVal) {\ - ymm16 = _mm256_broadcast_pd((__m128d const *)&AlphaVal);\ - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));\ - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 2));\ -\ - ymm14 = _mm256_shuffle_pd(ymm16, ymm16, 5);\ -\ - ymm17 = _mm256_shuffle_pd(ymm0, ymm0, 15);\ - ymm18 = _mm256_shuffle_pd(ymm0, ymm0,0);\ - ymm19 = _mm256_mul_pd(ymm17, ymm14);\ - ymm15 = _mm256_fmaddsub_pd(ymm18, ymm16, ymm19);\ - ymm3 = _mm256_sub_pd(ymm15, ymm3);\ -\ - ymm17 = _mm256_shuffle_pd(ymm1, ymm1, 15);\ - ymm18 = _mm256_shuffle_pd(ymm1, ymm1,0);\ - ymm19 = _mm256_mul_pd(ymm17, ymm14);\ - ymm15 = _mm256_fmaddsub_pd(ymm18, ymm16, ymm19);\ - ymm4 = _mm256_sub_pd(ymm15, ymm4);\ -} + ymm13 = _mm256_mul_pd(ymm7, ymm2); + ymm14 = _mm256_mul_pd(ymm7, ymm14); + ymm17 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + ymm17 = _mm256_mul_pd(ymm17, ymm15); -/** - * Multiplies Alpha with 3 elements of column. 
- * ymm0 holds first 2 element and xmm5 holds the - * 3rd one. - */ -#define BLIS_PRE_ZTRSM_SMALL_1x3(b11,cs_b,AlphaVal) {\ - ymm16 = _mm256_broadcast_pd((__m128d const *)&AlphaVal);\ - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));\ - xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 0 + 2));\ - ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0);\ -\ - ymm14 = _mm256_shuffle_pd(ymm16, ymm16, 5);\ -\ - ymm17 = _mm256_shuffle_pd(ymm0, ymm0, 15);\ - ymm18 = _mm256_shuffle_pd(ymm0, ymm0,0);\ - ymm19 = _mm256_mul_pd(ymm17, ymm14);\ - ymm15 = _mm256_fmaddsub_pd(ymm18, ymm16, ymm19);\ - ymm3 = _mm256_sub_pd(ymm15, ymm3);\ -\ - ymm17 = _mm256_shuffle_pd(ymm1, ymm1, 15);\ - ymm18 = _mm256_shuffle_pd(ymm1, ymm1,0);\ - ymm19 = _mm256_mul_pd(ymm17, ymm14);\ - ymm15 = _mm256_fmaddsub_pd(ymm18, ymm16, ymm19);\ - ymm4 = _mm256_sub_pd(ymm15, ymm4);\ -} + //Step 3 + ymm9 = _mm256_add_pd(ymm16, ymm9); + ymm5 = _mm256_add_pd(ymm17, ymm5); -#define BLIS_ZTRSM_SMALL_NREG_TRANSPOSE_2x4(b11,cs_b,AlphaVal) {\ - ymm16 = _mm256_broadcast_pd((__m128d const *)&AlphaVal);\ - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));\ - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1));\ - ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 2));\ - ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 2));\ - ymm14 = _mm256_shuffle_pd(ymm16, ymm16, 5);\ -\ - ymm17 = _mm256_shuffle_pd(ymm0, ymm0, 15);\ - ymm18 = _mm256_shuffle_pd(ymm0, ymm0,0);\ - ymm19 = _mm256_mul_pd(ymm17, ymm14);\ - ymm15 = _mm256_fmaddsub_pd(ymm18, ymm16, ymm19);\ - ymm0 = _mm256_sub_pd(ymm15, ymm8);\ -\ - ymm17 = _mm256_shuffle_pd(ymm1, ymm1, 15);\ - ymm18 = _mm256_shuffle_pd(ymm1, ymm1,0);\ - ymm19 = _mm256_mul_pd(ymm17, ymm14);\ - ymm15 = _mm256_fmaddsub_pd(ymm18, ymm16, ymm19);\ - ymm1 = _mm256_sub_pd(ymm15, ymm9);\ -\ - ymm17 = _mm256_shuffle_pd(ymm4, ymm4, 15);\ - ymm18 = _mm256_shuffle_pd(ymm4, ymm4,0);\ - ymm19 = _mm256_mul_pd(ymm17, ymm14);\ - ymm15 = _mm256_fmaddsub_pd(ymm18, ymm16, ymm19);\ - ymm4 = _mm256_sub_pd(ymm15, ymm12);\ -\ - ymm17 = _mm256_shuffle_pd(ymm5, ymm5, 15);\ - ymm18 = _mm256_shuffle_pd(ymm5, ymm5,0);\ - ymm19 = _mm256_mul_pd(ymm17, ymm14);\ - ymm15 = _mm256_fmaddsub_pd(ymm18, ymm16, ymm19);\ - ymm5 = _mm256_sub_pd(ymm15, ymm13);\ -} + ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + rs_a*3)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm11 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm11, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm11, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //For ymm7 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); -#define BLIS_PRE_ZTRSM_SMALL_3M_1N(AlphaVal,b11,cs_b){\ - ymm16 = _mm256_broadcast_pd(( __m128d const *)(&AlphaVal));\ - \ - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0));\ - xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 0 + 2));\ - ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0);\ - \ - ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ - /*dcomplex multiplication and substraction*/\ - ymm14 = _mm256_permute_pd(ymm16, 0x5);\ - ymm14 = _mm256_mul_pd(ymm14, ymm18);\ - ymm17 = _mm256_mul_pd(ymm0, ymm16);\ - ymm14 = _mm256_mul_pd(ymm0, ymm14);\ - ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ - ymm8 = 
_mm256_sub_pd(ymm15,ymm8);\ - \ - /*dcomplex multiplication and substraction*/\ - ymm14 = _mm256_permute_pd(ymm16, 0x5);\ - ymm14 = _mm256_mul_pd(ymm14, ymm18);\ - ymm17 = _mm256_mul_pd(ymm1, ymm16);\ - ymm14 = _mm256_mul_pd(ymm1, ymm14);\ - ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ - ymm12 = _mm256_sub_pd(ymm15,ymm12);\ - \ - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm8);\ - xmm5 = _mm256_extractf128_pd(ymm12, 0);\ - _mm_storeu_pd((double *)(b11 + cs_b * 0 + 2), xmm5);\ -} + ymm13 = _mm256_mul_pd(ymm7, ymm2); + ymm14 = _mm256_mul_pd(ymm7, ymm14); + ymm17 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + ymm17 = _mm256_mul_pd(ymm17, ymm15); + //Step 3 + ymm8 = _mm256_add_pd(ymm16, ymm8); + ymm4 = _mm256_add_pd(ymm17, ymm4); -/** - * Multiplies Alpha with 3 elements of 2 columns - * and store into buffer b11. - * ymm0 ymm1 holds first 2 elements of 2 columns. - * xmm4 xmm5 holds the 3rd elements of 2 columns. - */ -#define BLIS_PRE_ZTRSM_SMALL_3M_2N(AlphaVal,b11,cs_b){\ - ymm16 = _mm256_broadcast_pd(( __m128d const*)(&AlphaVal));\ - \ - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));\ - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1));\ - xmm4 = _mm_loadu_pd((double const *)(b11 + cs_b * 0 + 2));\ - ymm3 = _mm256_insertf128_pd(ymm3, xmm4, 0);\ -\ - xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 1 + 2));\ - ymm4 = _mm256_insertf128_pd(ymm4, xmm5, 0);\ -\ - ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ - /*dcomplex multiplication and substraction*/\ - ymm14 = _mm256_permute_pd(ymm16, 0x5);\ - ymm14 = _mm256_mul_pd(ymm14, ymm18);\ - ymm17 = _mm256_mul_pd(ymm0, ymm16);\ - ymm14 = _mm256_mul_pd(ymm0, ymm14);\ - ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ - ymm8 = _mm256_sub_pd(ymm15,ymm8);\ - \ - /*dcomplex multiplication and substraction*/\ - ymm14 = _mm256_permute_pd(ymm16, 0x5);\ - ymm14 = _mm256_mul_pd(ymm14, ymm18);\ - ymm17 = _mm256_mul_pd(ymm1, ymm16);\ - ymm14 = _mm256_mul_pd(ymm1, ymm14);\ - ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ - ymm9 = _mm256_sub_pd(ymm15,ymm9);\ - \ - /*dcomplex multiplication and substraction*/\ - ymm14 = _mm256_permute_pd(ymm16, 0x5);\ - ymm14 = _mm256_mul_pd(ymm14, ymm18);\ - ymm17 = _mm256_mul_pd(ymm3, ymm16);\ - ymm14 = _mm256_mul_pd(ymm3, ymm14);\ - ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ - ymm12 = _mm256_sub_pd(ymm15,ymm12);\ - \ - /*dcomplex multiplication and substraction*/\ - ymm14 = _mm256_permute_pd(ymm16, 0x5);\ - ymm14 = _mm256_mul_pd(ymm14, ymm18);\ - ymm17 = _mm256_mul_pd(ymm4, ymm16);\ - ymm14 = _mm256_mul_pd(ymm4, ymm14);\ - ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ - ymm13 = _mm256_sub_pd(ymm15,ymm13);\ - \ - _mm256_storeu_pd((double *)(b11), ymm8);\ - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm9);\ - xmm4 = _mm256_extractf128_pd(ymm12, 0);\ - _mm_storeu_pd((double *)(b11 + cs_b * 0 + 2), xmm4);\ - xmm5 = _mm256_extractf128_pd(ymm13, 0);\ - _mm_storeu_pd((double *)(b11 + cs_b * 1 + 2), xmm5);\ -} +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm10 and ymm6 with ymm1*/ + BLIS_ZTRSM_TWO_DIV(ymm10,ymm6) +#else + /*performs dcomplex multiplication of ymm10 and ymm6 with ymm1*/ + BLIS_ZTRSM_MUL(ymm10) + BLIS_ZTRSM_MUL(ymm6) +#endif + //extract a22 + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 1)); -/** - * Performs GEMM operation - * ymm0 holds 2 elements of column. 
- * ymm4 ymm6 holds broadcasted elements respectively - */ -#define BLIS_ZTRSM_SMALL_GEMM_3nx2m(a01,b10,cs_b,p_lda,k_iter) {\ - double *tptr = (double *)a01;\ - if(conjtransa) {\ - ymm18 = _mm256_set_pd(-1.0, -1.0, -1.0, -1.0);\ - for(k = 0; k< k_iter; k++) \ - {\ - ymm0 = _mm256_loadu_pd((double const *)(b10)); \ - \ - _mm_prefetch((char*)( b10 + 2*cs_b), _MM_HINT_T0); \ - ymm4 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0)); \ - ymm6 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1)); \ - ymm6 = _mm256_mul_pd(ymm6, ymm18);\ - /*dcomplex multiplication and substraction*/\ - \ - ymm3 = _mm256_fmadd_pd(ymm0, ymm4, ymm3);\ - ymm8 = _mm256_fmadd_pd(ymm0, ymm6, ymm8);\ - \ - ymm4 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1)); \ - ymm6 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1 + 1)); \ - ymm6 = _mm256_mul_pd(ymm6, ymm18);\ - \ - /*dcomplex multiplication and substraction*/\ - \ - ymm5 = _mm256_fmadd_pd(ymm0, ymm4, ymm5);\ - ymm9 = _mm256_fmadd_pd(ymm0, ymm6, ymm9);\ - \ - ymm4 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 2)); \ - ymm6 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 2 + 1)); \ - ymm6 = _mm256_mul_pd(ymm6, ymm18);\ - \ - /*dcomplex multiplication and substraction*/\ - \ - ymm7 = _mm256_fmadd_pd(ymm0, ymm4, ymm7);\ - ymm10 = _mm256_fmadd_pd(ymm0, ymm6, ymm10);\ - \ - tptr += 2; \ - b10 += cs_b; \ - }\ - }\ - else {\ - for(k = 0; k< k_iter; k++) \ - {\ - ymm0 = _mm256_loadu_pd((double const *)(b10)); \ - \ - _mm_prefetch((char*)( b10 + 2*cs_b), _MM_HINT_T0); \ - ymm4 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0)); \ - ymm6 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1)); \ - /*dcomplex multiplication and substraction*/\ - \ - ymm3 = _mm256_fmadd_pd(ymm0, ymm4, ymm3);\ - ymm8 = _mm256_fmadd_pd(ymm0, ymm6, ymm8);\ - /*ymm3 = _mm256_add_pd(ymm15, ymm3);*/\ - \ - ymm4 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1)); \ - ymm6 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1 + 1)); \ - \ - /*dcomplex multiplication and substraction*/\ - \ - ymm5 = _mm256_fmadd_pd(ymm0, ymm4, ymm5);\ - ymm9 = _mm256_fmadd_pd(ymm0, ymm6, ymm9);\ - /*ymm5 = _mm256_add_pd(ymm15, ymm5);*/\ - \ - ymm4 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 2)); \ - ymm6 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 2 + 1)); \ - \ - /*dcomplex multiplication and substraction*/\ - \ - ymm7 = _mm256_fmadd_pd(ymm0, ymm4, ymm7);\ - ymm10 = _mm256_fmadd_pd(ymm0, ymm6, ymm10);\ - /*ymm7 = _mm256_add_pd(ymm15, ymm7);*/\ - \ - tptr += 2; \ - b10 += cs_b; \ - }\ - }\ - ymm8 = _mm256_permute_pd(ymm8, 0x5);\ - ymm9 = _mm256_permute_pd(ymm9, 0x5);\ - ymm10 = _mm256_permute_pd(ymm10, 0x5);\ - ymm3 = _mm256_addsub_pd(ymm3, ymm8);\ - ymm5 = _mm256_addsub_pd(ymm5, ymm9);\ - ymm7 = _mm256_addsub_pd(ymm7, ymm10);\ -} + //(ROW2): FMA operations + ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a*1 + rs_a*2)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); -/** - * Multiplies Alpha with 2 elements of 3 columns - * ymm0 holds 2 elements of columns, once computation - * is done, it holds 2 elements of next columns after - * saving computed result into some other register. - * ymm3 ymm5 ymm7. 
- */ -#define BLIS_PRE_ZTRSM_SMALL_3x2(AlphaVal,b11,cs_b) {\ - ymm16 = _mm256_broadcast_pd(( __m128d const*)(&AlphaVal));\ - \ - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));\ - ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ - \ - ymm14 = _mm256_permute_pd(ymm16, 0x5);\ - ymm14 = _mm256_mul_pd(ymm14, ymm18);\ - ymm17 = _mm256_mul_pd(ymm0, ymm16);\ - ymm14 = _mm256_mul_pd(ymm0, ymm14);\ - ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ - ymm3 = _mm256_sub_pd(ymm15,ymm3);\ - \ - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *1));\ -\ - ymm14 = _mm256_permute_pd(ymm16, 0x5);\ - ymm14 = _mm256_mul_pd(ymm14, ymm18);\ - ymm17 = _mm256_mul_pd(ymm0, ymm16);\ - ymm14 = _mm256_mul_pd(ymm0, ymm14);\ - ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ - ymm5 = _mm256_sub_pd(ymm15,ymm5);\ - \ - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *2));\ - \ - ymm14 = _mm256_permute_pd(ymm16, 0x5);\ - ymm14 = _mm256_mul_pd(ymm14, ymm18);\ - ymm17 = _mm256_mul_pd(ymm0, ymm16);\ - ymm14 = _mm256_mul_pd(ymm0, ymm14);\ - ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ - ymm7 = _mm256_sub_pd(ymm15,ymm7);\ - \ -} + //For ymm10 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm10, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm10, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //For ymm6 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); -/** - * Performs GEMM - * ymm0 and ymm1 together holds 4 elements of column. - */ -#define BLIS_ZTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) {\ - double *tptr = (double *)a01;\ - if(conjtransa) {\ - ymm18 = _mm256_set_pd(-1.0, -1.0, -1.0, -1.0);\ - for(k = 0; k< k_iter; k++) \ - { \ - ymm0 = _mm256_loadu_pd((double const *)(b10)); \ - ymm1 = _mm256_loadu_pd((double const *)(b10 + 2)); \ - \ - _mm_prefetch((char*)( b10 + 4*cs_b), _MM_HINT_T0); \ - ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0)); \ - ymm12 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1)); \ - ymm12 = _mm256_mul_pd(ymm12, ymm18);\ - \ - ymm3 = _mm256_fmadd_pd(ymm0, ymm2, ymm3);\ - ymm4 = _mm256_fmadd_pd(ymm1, ymm2, ymm4);\ - ymm8 = _mm256_fmadd_pd(ymm0, ymm12, ymm8);\ - ymm9 = _mm256_fmadd_pd(ymm1, ymm12, ymm9);\ - \ - ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1)); \ - ymm12 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1 + 1)); \ - ymm12 = _mm256_mul_pd(ymm12, ymm18);\ - \ - ymm5 = _mm256_fmadd_pd(ymm0, ymm2, ymm5);\ - ymm6 = _mm256_fmadd_pd(ymm1, ymm2, ymm6);\ - ymm10 = _mm256_fmadd_pd(ymm0, ymm12, ymm10);\ - ymm11 = _mm256_fmadd_pd(ymm1, ymm12, ymm11);\ - \ - tptr += 2; \ - b10 += cs_b; \ - }\ - }\ - else {\ - for(k = 0; k< k_iter; k++) \ - { \ - ymm0 = _mm256_loadu_pd((double const *)(b10)); \ - ymm1 = _mm256_loadu_pd((double const *)(b10 + 2)); \ - \ - _mm_prefetch((char*)( b10 + 4*cs_b), _MM_HINT_T0); \ - ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0)); \ - ymm12 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1)); \ - \ - ymm3 = _mm256_fmadd_pd(ymm0, ymm2, ymm3);\ - ymm4 = _mm256_fmadd_pd(ymm1, ymm2, ymm4);\ - ymm8 = _mm256_fmadd_pd(ymm0, ymm12, ymm8);\ - ymm9 = _mm256_fmadd_pd(ymm1, ymm12, ymm9);\ - \ - ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1)); \ - ymm12 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1 + 1)); \ - \ - ymm5 = _mm256_fmadd_pd(ymm0, ymm2, ymm5);\ - ymm6 = 
_mm256_fmadd_pd(ymm1, ymm2, ymm6);\ - ymm10 = _mm256_fmadd_pd(ymm0, ymm12, ymm10);\ - ymm11 = _mm256_fmadd_pd(ymm1, ymm12, ymm11);\ - \ - tptr += 2; \ - b10 += cs_b; \ - }\ - }\ - ymm8 = _mm256_permute_pd(ymm8, 0x5);\ - ymm9 = _mm256_permute_pd(ymm9, 0x5);\ - ymm10 = _mm256_permute_pd(ymm10, 0x5);\ - ymm11 = _mm256_permute_pd(ymm11, 0x5);\ - ymm3 = _mm256_addsub_pd(ymm3, ymm8);\ - ymm4 = _mm256_addsub_pd(ymm4, ymm9);\ - ymm5 = _mm256_addsub_pd(ymm5, ymm10);\ - ymm6 = _mm256_addsub_pd(ymm6, ymm11);\ -} + ymm13 = _mm256_mul_pd(ymm6, ymm2); + ymm14 = _mm256_mul_pd(ymm6, ymm14); + ymm17 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + ymm17 = _mm256_mul_pd(ymm17, ymm15); + //Step 3 + ymm9 = _mm256_add_pd(ymm16, ymm9); + ymm5 = _mm256_add_pd(ymm17, ymm5); + + ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + rs_a*2)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); -/** - * Performs GEMM operation - * ymm0 holds 2 elements of a column. - */ -#define BLIS_ZTRSM_SMALL_GEMM_2nx2m(a01,b10,cs_b,p_lda,k_iter){\ - double *tptr = (double *)a01;\ - if(conjtransa) {\ - ymm18 = _mm256_set_pd(-1.0, -1.0, -1.0, -1.0);\ - for(k = 0; k< k_iter; k++) \ - { \ - ymm0 = _mm256_loadu_pd((double const *)(b10)); \ - \ - _mm_prefetch((char*)( b10 + 2*cs_b), _MM_HINT_T0); \ - ymm1 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0)); \ - ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1)); \ - ymm2 = _mm256_mul_pd(ymm2, ymm18);\ - \ - ymm3 = _mm256_fmadd_pd(ymm0, ymm1, ymm3);\ - ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4);\ - \ - \ - ymm1 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1)); \ - ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1 + 1)); \ - ymm2 = _mm256_mul_pd(ymm2, ymm18);\ - \ - ymm5 = _mm256_fmadd_pd(ymm0, ymm1, ymm5);\ - ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6);\ - \ - tptr += 2; \ - b10 += cs_b; \ - }\ - }\ - else {\ - for(k = 0; k< k_iter; k++) \ - { \ - ymm0 = _mm256_loadu_pd((double const *)(b10)); \ - \ - _mm_prefetch((char*)( b10 + 2*cs_b), _MM_HINT_T0); \ - ymm1 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0)); \ - ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1)); \ - \ - ymm3 = _mm256_fmadd_pd(ymm0, ymm1, ymm3);\ - ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4);\ - \ - \ - ymm1 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1)); \ - ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1 + 1)); \ - \ - ymm5 = _mm256_fmadd_pd(ymm0, ymm1, ymm5);\ - ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6);\ - \ - tptr += 2; \ - b10 += cs_b; \ - }\ - }\ - ymm4 = _mm256_permute_pd(ymm4, 0x5);\ - ymm6 = _mm256_permute_pd(ymm6, 0x5);\ - ymm3 = _mm256_addsub_pd(ymm3, ymm4);\ - ymm5 = _mm256_addsub_pd(ymm5, ymm6);\ -} + //For ymm10 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm10, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm10, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //For ymm6 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); -/** - * Multiplies Alpha with 2 elements of a column. - * ymm0 holds the 2 element of a column. 
- */ -#define BLIS_PRE_ZTRSM_SMALL_1x1(AlphaVal,b11,cs_b){\ - ymm16 = _mm256_broadcast_pd(( __m128d const*)(&AlphaVal));\ - \ - xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 0));\ - ymm0 = _mm256_insertf128_pd(ymm1, xmm5, 0);\ - ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ - \ - ymm14 = _mm256_permute_pd(ymm16, 0x5);\ - ymm14 = _mm256_mul_pd(ymm14, ymm18);\ - ymm17 = _mm256_mul_pd(ymm0, ymm16);\ - ymm14 = _mm256_mul_pd(ymm0, ymm14);\ - ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ - ymm3 = _mm256_sub_pd(ymm15,ymm3);\ -} + ymm13 = _mm256_mul_pd(ymm6, ymm2); + ymm14 = _mm256_mul_pd(ymm6, ymm14); + ymm17 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + ymm17 = _mm256_mul_pd(ymm17, ymm15); + //Step 3 + ymm8 = _mm256_add_pd(ymm16, ymm8); + ymm4 = _mm256_add_pd(ymm17, ymm4); -/** - * Multiplies Alpha with 2 elements of a column. - * ymm0 holds the 2 element of a column. - */ -#define BLIS_PRE_ZTRSM_SMALL_1x2(AlphaVal,b11,cs_b){\ - ymm16 = _mm256_broadcast_pd(( __m128d const*)(&AlphaVal));\ - \ - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));\ - ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ - \ - ymm14 = _mm256_permute_pd(ymm16, 0x5);\ - ymm14 = _mm256_mul_pd(ymm14, ymm18);\ - ymm17 = _mm256_mul_pd(ymm0, ymm16);\ - ymm14 = _mm256_mul_pd(ymm0, ymm14);\ - ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ - ymm3 = _mm256_sub_pd(ymm15,ymm3);\ -} +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm9 and ymm5 with ymm1*/ + BLIS_ZTRSM_TWO_DIV(ymm9,ymm5) +#else + /*performs dcomplex multiplication of ymm9 and ymm5 with ymm1*/ + BLIS_ZTRSM_MUL(ymm9) + BLIS_ZTRSM_MUL(ymm5) +#endif + //extract a44 + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); + //(ROW3): FMA operations + ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + rs_a)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); -/** - * Multiplies Alpha with 2 elements of 2 columns. - * ymm0 holds 2 elements of a columns respectively, - * once computation is done, gets stored in registers - * ymm3, ymm5 - */ -#define BLIS_PRE_ZTRSM_SMALL_2x2(AlphaVal,b11,cs_b){\ - ymm16 = _mm256_broadcast_pd(( __m128d const*)(&AlphaVal));\ - \ - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));\ - ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ - \ - ymm14 = _mm256_permute_pd(ymm16, 0x5);\ - ymm14 = _mm256_mul_pd(ymm14, ymm18);\ - ymm17 = _mm256_mul_pd(ymm0, ymm16);\ - ymm14 = _mm256_mul_pd(ymm0, ymm14);\ - ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ - ymm3 = _mm256_sub_pd(ymm15,ymm3);\ - \ - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *1));\ -\ - ymm14 = _mm256_permute_pd(ymm16, 0x5);\ - ymm14 = _mm256_mul_pd(ymm14, ymm18);\ - ymm17 = _mm256_mul_pd(ymm0, ymm16);\ - ymm14 = _mm256_mul_pd(ymm0, ymm14);\ - ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ - ymm5 = _mm256_sub_pd(ymm15,ymm5);\ -} + //For ymm9 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm9, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm9, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //For ymm5 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); -/** - * Performs GEMM operation - * 3 elements of a columns get held by ymm0(2 element) - * and xmm5 (1 element). 
- */ -#define BLIS_ZTRSM_SMALL_GEMM_1nx3m(a01,b10,cs_b,p_lda,k_iter) {\ - double *tptr = (double *)a01;\ - if(conjtransa) {\ - ymm18 = _mm256_set_pd(-1.0, -1.0, -1.0, -1.0);\ - for(k = 0; k< k_iter; k++) \ - {\ - ymm0 = _mm256_loadu_pd((double const *)(b10)); \ - /*ymm1 = _mm256_loadu_pd((double const *)(b10 + 2));*/\ - xmm5 = _mm_loadu_pd((double const *)(b10 + 2));\ - ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0);\ - \ - _mm_prefetch((char*)( b10 + 4*cs_b), _MM_HINT_T0); \ - ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0)); \ - ymm5 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1)); \ - ymm5 = _mm256_mul_pd(ymm5, ymm18);\ - \ - ymm3 = _mm256_fmadd_pd(ymm0, ymm2, ymm3);\ - ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6);\ - ymm4 = _mm256_fmadd_pd(ymm1, ymm5, ymm4);\ - ymm7 = _mm256_fmadd_pd(ymm1, ymm5, ymm7);\ - \ - tptr += 2;\ - b10 += cs_b;\ - }\ - }\ - else {\ - for(k = 0; k< k_iter; k++) \ - {\ - ymm0 = _mm256_loadu_pd((double const *)(b10)); \ - /*ymm1 = _mm256_loadu_pd((double const *)(b10 + 2));*/\ - xmm5 = _mm_loadu_pd((double const *)(b10 + 2));\ - ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0);\ - \ - _mm_prefetch((char*)( b10 + 4*cs_b), _MM_HINT_T0); \ - ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0)); \ - ymm5 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1)); \ - \ - ymm3 = _mm256_fmadd_pd(ymm0, ymm2, ymm3);\ - ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6);\ - ymm4 = _mm256_fmadd_pd(ymm1, ymm5, ymm4);\ - ymm7 = _mm256_fmadd_pd(ymm1, ymm5, ymm7);\ - \ - tptr += 2;\ - b10 += cs_b;\ - }\ - }\ - ymm6 = _mm256_permute_pd(ymm6, 0x5);\ - ymm7 = _mm256_permute_pd(ymm7, 0x5);\ - ymm3 = _mm256_addsub_pd(ymm3, ymm6);\ - ymm4 = _mm256_addsub_pd(ymm5, ymm7);\ -} + ymm13 = _mm256_mul_pd(ymm5, ymm2); + ymm14 = _mm256_mul_pd(ymm5, ymm14); + ymm17 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + ymm17 = _mm256_mul_pd(ymm17, ymm15); + //Step 3 + ymm8 = _mm256_add_pd(ymm16, ymm8); + ymm4 = _mm256_add_pd(ymm17, ymm4); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm8 and ymm4 with ymm1*/ + BLIS_ZTRSM_TWO_DIV(ymm8,ymm4) +#else + /*performs dcomplex nultiplication of ymm8 and ymm4 with ymm1*/ + BLIS_ZTRSM_MUL(ymm8) + BLIS_ZTRSM_MUL(ymm4) -/** - * Performs GEMM operation. - * 1 elements of a column are kept in ymm0. 
- */ -#define BLIS_ZTRSM_SMALL_GEMM_1nx1m(a01,b10,cs_b,p_lda,k_iter) {\ - double *tptr = (double *)a01;\ - if(conjtransa) {\ - ymm18 = _mm256_set_pd(-1.0, -1.0, -1.0, -1.0);\ - for(k = 0; k< k_iter; k++) \ - { \ - xmm5 = _mm_loadu_pd((double const *)(b10));\ - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ - \ - _mm_prefetch((char*)( b10 + 2*cs_b), _MM_HINT_T0); \ - ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0)); \ - ymm5 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1)); \ - ymm5 = _mm256_mul_pd(ymm5, ymm18);\ - \ - ymm3 = _mm256_fmadd_pd(ymm0, ymm2, ymm3);\ - ymm4 = _mm256_fmadd_pd(ymm0, ymm5, ymm4);\ - \ - tptr += 2; \ - b10 += cs_b; \ - }\ - }\ - else {\ - for(k = 0; k< k_iter; k++) \ - { \ - xmm5 = _mm_loadu_pd((double const *)(b10));\ - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ - \ - _mm_prefetch((char*)( b10 + 2*cs_b), _MM_HINT_T0); \ - ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0)); \ - ymm5 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1)); \ - \ - ymm3 = _mm256_fmadd_pd(ymm0, ymm2, ymm3);\ - ymm4 = _mm256_fmadd_pd(ymm0, ymm5, ymm4);\ - \ - tptr += 2; \ - b10 += cs_b; \ - }\ - }\ - ymm4 = _mm256_permute_pd(ymm4, 0x5);\ - ymm3 = _mm256_addsub_pd(ymm3, ymm4);\ -} +#endif + BLIS_ZTRSM_SMALL_NREG_TRANSPOSE_4x3_AND_STORE(b11,cs_b) -/** - * Performs GEMM operation. - * 2 elements of a column are kept in ymm0. - */ -#define BLIS_ZTRSM_SMALL_GEMM_1nx2m(a01,b10,cs_b,p_lda,k_iter) {\ - double *tptr = (double *)a01;\ - if(conjtransa) {\ - ymm18 = _mm256_set_pd(-1.0, -1.0, -1.0, -1.0);\ - for(k = 0; k< k_iter; k++) \ - { \ - ymm0 = _mm256_loadu_pd((double const *)(b10)); \ - \ - _mm_prefetch((char*)( b10 + 2*cs_b), _MM_HINT_T0); \ - ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0)); \ - ymm5 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1)); \ - ymm5 = _mm256_mul_pd(ymm5, ymm18);\ - \ - ymm3 = _mm256_fmadd_pd(ymm0, ymm2, ymm3);\ - ymm4 = _mm256_fmadd_pd(ymm0, ymm5, ymm4);\ - \ - tptr += 2; \ - b10 += cs_b; \ - }\ - }\ - else {\ - for(k = 0; k< k_iter; k++) \ - { \ - ymm0 = _mm256_loadu_pd((double const *)(b10)); \ - \ - _mm_prefetch((char*)( b10 + 2*cs_b), _MM_HINT_T0); \ - ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0)); \ - ymm5 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1)); \ - \ - ymm3 = _mm256_fmadd_pd(ymm0, ymm2, ymm3);\ - ymm4 = _mm256_fmadd_pd(ymm0, ymm5, ymm4);\ - \ - tptr += 2; \ - b10 += cs_b; \ - }\ - }\ - ymm4 = _mm256_permute_pd(ymm4, 0x5);\ - ymm3 = _mm256_addsub_pd(ymm3, ymm4);\ -} + } + dim_t n_remainder = j + d_nr; + if(n_remainder) + { + a10 = D_A_pack; + a11 = L + (i*cs_a) + (i*rs_a); + b01 = B + i + d_mr; + b11 = B + i; -/** - * Performs GEMM operation - * 4 elements of columns are kept in ymm0 and ymm1. 
- */ -#define BLIS_ZTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) {\ - double *tptr = (double *)a01;\ - if(conjtransa) {\ - ymm18 = _mm256_set_pd(-1.0, -1.0, -1.0, -1.0);\ - for(k = 0; k< k_iter; k++) \ - { \ - ymm0 = _mm256_loadu_pd((double const *)(b10)); \ - ymm1 = _mm256_loadu_pd((double const *)(b10 + 2)); \ - \ - _mm_prefetch((char*)( b10 + 4*cs_b), _MM_HINT_T0); \ - ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0));\ - ymm9 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1));\ - ymm9 = _mm256_mul_pd(ymm9, ymm18);\ - \ - ymm3 = _mm256_fmadd_pd(ymm0, ymm2, ymm3);\ - ymm4 = _mm256_fmadd_pd(ymm1, ymm2, ymm4);\ - ymm10 = _mm256_fmadd_pd(ymm0, ymm9, ymm10);\ - ymm11 = _mm256_fmadd_pd(ymm1, ymm9, ymm11);\ - \ - ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1)); \ - ymm9 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1 + 1)); \ - ymm9 = _mm256_mul_pd(ymm9, ymm18);\ - \ - ymm5 = _mm256_fmadd_pd(ymm0, ymm2, ymm5);\ - ymm6 = _mm256_fmadd_pd(ymm1, ymm2, ymm6);\ - ymm12 = _mm256_fmadd_pd(ymm0, ymm9, ymm12);\ - ymm13 = _mm256_fmadd_pd(ymm1, ymm9, ymm13);\ - \ - ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 2)); \ - ymm9 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 2 + 1)); \ - ymm9 = _mm256_mul_pd(ymm9, ymm18);\ - \ - ymm7 = _mm256_fmadd_pd(ymm0, ymm2, ymm7);\ - ymm8 = _mm256_fmadd_pd(ymm1, ymm2, ymm8);\ - ymm14 = _mm256_fmadd_pd(ymm0, ymm9, ymm14);\ - ymm15 = _mm256_fmadd_pd(ymm1, ymm9, ymm15);\ - \ - tptr += 2; \ - b10 += cs_b; \ - }\ - }\ - else {\ - for(k = 0; k< k_iter; k++) \ - { \ - ymm0 = _mm256_loadu_pd((double const *)(b10)); \ - ymm1 = _mm256_loadu_pd((double const *)(b10 + 2)); \ - \ - _mm_prefetch((char*)( b10 + 4*cs_b), _MM_HINT_T0); \ - ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0));\ - ymm9 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1));\ - \ - ymm3 = _mm256_fmadd_pd(ymm0, ymm2, ymm3);\ - ymm4 = _mm256_fmadd_pd(ymm1, ymm2, ymm4);\ - ymm10 = _mm256_fmadd_pd(ymm0, ymm9, ymm10);\ - ymm11 = _mm256_fmadd_pd(ymm1, ymm9, ymm11);\ - \ - ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1)); \ - ymm9 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1 + 1)); \ - \ - ymm5 = _mm256_fmadd_pd(ymm0, ymm2, ymm5);\ - ymm6 = _mm256_fmadd_pd(ymm1, ymm2, ymm6);\ - ymm12 = _mm256_fmadd_pd(ymm0, ymm9, ymm12);\ - ymm13 = _mm256_fmadd_pd(ymm1, ymm9, ymm13);\ - \ - ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 2)); \ - ymm9 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 2 + 1)); \ - \ - ymm7 = _mm256_fmadd_pd(ymm0, ymm2, ymm7);\ - ymm8 = _mm256_fmadd_pd(ymm1, ymm2, ymm8);\ - ymm14 = _mm256_fmadd_pd(ymm0, ymm9, ymm14);\ - ymm15 = _mm256_fmadd_pd(ymm1, ymm9, ymm15);\ - \ - tptr += 2; \ - b10 += cs_b; \ - }\ - }\ - ymm10 = _mm256_permute_pd(ymm10, 0x5);\ - ymm11 = _mm256_permute_pd(ymm11, 0x5);\ - ymm12 = _mm256_permute_pd(ymm12, 0x5);\ - ymm13 = _mm256_permute_pd(ymm13, 0x5);\ - ymm14 = _mm256_permute_pd(ymm14, 0x5);\ - ymm15 = _mm256_permute_pd(ymm15, 0x5);\ -\ - ymm3 = _mm256_addsub_pd(ymm3, ymm10);\ - ymm4 = _mm256_addsub_pd(ymm4, ymm11);\ - ymm5 = _mm256_addsub_pd(ymm5, ymm12);\ - ymm6 = _mm256_addsub_pd(ymm6, ymm13);\ - ymm7 = _mm256_addsub_pd(ymm7, ymm14);\ - ymm8 = _mm256_addsub_pd(ymm8, ymm15);\ -} + k_iter = (m - i - d_mr) ; -/** - * Multiplies Alpha with 4 element of 2 columns. - * ymm0 and ymm1 holds 4 elements of a column. 
- */ -#define BLIS_PRE_ZTRSM_SMALL_2x4(AlphaVal,b11,cs_b) {\ - ymm16 = _mm256_broadcast_pd(( __m128d const*)(&AlphaVal));\ - \ - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));\ - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 2));\ - ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ - \ - ymm14 = _mm256_permute_pd(ymm16, 0x5);\ - ymm14 = _mm256_mul_pd(ymm14, ymm18);\ - ymm17 = _mm256_mul_pd(ymm0, ymm16);\ - ymm14 = _mm256_mul_pd(ymm0, ymm14);\ - ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ - ymm3 = _mm256_sub_pd(ymm15,ymm3);\ - \ - ymm14 = _mm256_permute_pd(ymm16, 0x5);\ - ymm14 = _mm256_mul_pd(ymm14, ymm18);\ - ymm17 = _mm256_mul_pd(ymm1, ymm16);\ - ymm14 = _mm256_mul_pd(ymm1, ymm14);\ - ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ - ymm4 = _mm256_sub_pd(ymm15,ymm4);\ - \ - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *1));\ - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 2));\ -\ - ymm14 = _mm256_permute_pd(ymm16, 0x5);\ - ymm14 = _mm256_mul_pd(ymm14, ymm18);\ - ymm17 = _mm256_mul_pd(ymm0, ymm16);\ - ymm14 = _mm256_mul_pd(ymm0, ymm14);\ - ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ - ymm5 = _mm256_sub_pd(ymm15,ymm5);\ - \ - ymm14 = _mm256_permute_pd(ymm16, 0x5);\ - ymm14 = _mm256_mul_pd(ymm14, ymm18);\ - ymm17 = _mm256_mul_pd(ymm1, ymm16);\ - ymm14 = _mm256_mul_pd(ymm1, ymm14);\ - ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ - ymm6 = _mm256_sub_pd(ymm15,ymm6);\ -} + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS -/** - * Multiplies Alpha with 4 element of 3 columns. - * ymm0 and ymm1 holds 4 elements of a column. - */ -#define BLIS_PRE_ZTRSM_SMALL_3x4(AlphaVal,b11,cs_b) {\ - ymm16 = _mm256_broadcast_pd(( __m128d const*)(&AlphaVal));\ - \ - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));\ - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 2));\ - ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ - \ - ymm14 = _mm256_permute_pd(ymm16, 0x5);\ - ymm14 = _mm256_mul_pd(ymm14, ymm18);\ - ymm17 = _mm256_mul_pd(ymm0, ymm16);\ - ymm14 = _mm256_mul_pd(ymm0, ymm14);\ - ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ - ymm3 = _mm256_sub_pd(ymm15,ymm3);\ - \ - ymm14 = _mm256_permute_pd(ymm16, 0x5);\ - ymm14 = _mm256_mul_pd(ymm14, ymm18);\ - ymm17 = _mm256_mul_pd(ymm1, ymm16);\ - ymm14 = _mm256_mul_pd(ymm1, ymm14);\ - ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ - ymm4 = _mm256_sub_pd(ymm15,ymm4);\ - \ - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *1));\ - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 2));\ -\ - ymm14 = _mm256_permute_pd(ymm16, 0x5);\ - ymm14 = _mm256_mul_pd(ymm14, ymm18);\ - ymm17 = _mm256_mul_pd(ymm0, ymm16);\ - ymm14 = _mm256_mul_pd(ymm0, ymm14);\ - ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ - ymm5 = _mm256_sub_pd(ymm15,ymm5);\ - \ - ymm14 = _mm256_permute_pd(ymm16, 0x5);\ - ymm14 = _mm256_mul_pd(ymm14, ymm18);\ - ymm17 = _mm256_mul_pd(ymm1, ymm16);\ - ymm14 = _mm256_mul_pd(ymm1, ymm14);\ - ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ - ymm6 = _mm256_sub_pd(ymm15,ymm6);\ - \ - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *2));\ - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 2));\ - \ - ymm14 = _mm256_permute_pd(ymm16, 0x5);\ - ymm14 = _mm256_mul_pd(ymm14, ymm18);\ - ymm17 = _mm256_mul_pd(ymm0, ymm16);\ - ymm14 = _mm256_mul_pd(ymm0, ymm14);\ - ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ - ymm7 = _mm256_sub_pd(ymm15,ymm7);\ - \ - ymm14 = _mm256_permute_pd(ymm16, 0x5);\ - ymm14 = _mm256_mul_pd(ymm14, ymm18);\ - ymm17 = _mm256_mul_pd(ymm1, ymm16);\ - ymm14 = _mm256_mul_pd(ymm1, ymm14);\ - ymm15 = _mm256_hsub_pd(ymm17, 
ymm14);\ - ymm8 = _mm256_sub_pd(ymm15,ymm8);\ - \ -} + if(2 == n_remainder) + { + ///GEMM code begins/// + BLIS_ZTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b,p_lda,k_iter) -/* - * Pack a block of 4xk or 3xk from input buffer into packed buffer - * directly or after transpose based on input params - */ + ymm16 = _mm256_broadcast_pd((__m128d const *)(&AlphaVal)); + //register to hold alpha + BLIS_ZTRSM_SMALL_NREG_TRANSPOSE_2x4(b11,cs_b,AlphaVal) + } + else if(1 == n_remainder) + { + ///GEMM code begins/// + BLIS_ZTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b,p_lda,k_iter) + BLIS_ZTRSM_SMALL_NREG_TRANSPOSE_1x4(b11,cs_b,AlphaVal) + } + ///implement TRSM/// + ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + ymm10 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm11 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); -/* - * Load b11 of size 3x4 and multiply with alpha - * Add the GEMM output and perform inregister transose of b11 - * to peform ZTRSM operation for left cases. - */ -#define BLIS_ZTRSM_SMALL_NREG_TRANSPOSE_3x4(b11,cs_b,AlphaVal) {\ - ymm16 = _mm256_broadcast_pd(( __m128d const *)(&AlphaVal));\ -\ - ymm0 = _mm256_loadu_pd((double const *)(b11));\ - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1));\ - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2));\ - ymm3 = _mm256_broadcast_pd((__m128d const *)&ones);\ - /*in register transpose - * ymm0,ymm1,ymm2 holds - * two dcomplex elements of b11 cols*/\ - ymm14 = _mm256_shuffle_pd(ymm16, ymm16, 5);\ - ymm5 = _mm256_shuffle_pd(ymm0, ymm0, 15);\ - ymm6 = _mm256_shuffle_pd(ymm0, ymm0,0);\ - ymm7 = _mm256_mul_pd(ymm5, ymm14);\ - ymm15 = _mm256_fmaddsub_pd(ymm6, ymm16, ymm7);\ - ymm0 = _mm256_sub_pd(ymm15, ymm8);\ -\ - ymm5 = _mm256_shuffle_pd(ymm1, ymm1, 15);\ - ymm6 = _mm256_shuffle_pd(ymm1, ymm1,0);\ - ymm7 = _mm256_mul_pd(ymm5, ymm14);\ - ymm15 = _mm256_fmaddsub_pd(ymm6, ymm16, ymm7);\ - ymm1 = _mm256_sub_pd(ymm15, ymm9);\ -\ - ymm5 = _mm256_shuffle_pd(ymm2, ymm2, 15);\ - ymm6 = _mm256_shuffle_pd(ymm2, ymm2,0);\ - ymm7 = _mm256_mul_pd(ymm5, ymm14);\ - ymm15 = _mm256_fmaddsub_pd(ymm6, ymm16, ymm7);\ - ymm2 = _mm256_sub_pd(ymm15, ymm10);\ -\ - /*in register transpose of computed b11 col*/\ - ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); \ - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31);\ - ymm4 = _mm256_permute2f128_pd(ymm2,ymm3,0x20); \ - ymm5 = _mm256_permute2f128_pd(ymm2,ymm3,0x31); \ -\ - /*in register transpose - * ymm0,ymm1,ymm2 holds - * next two dcomplex elements of b11 cols*/\ - ymm0 = _mm256_loadu_pd((double const *)(b11 + 2));\ - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 2));\ - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 2));\ -\ - ymm17 = _mm256_shuffle_pd(ymm0, ymm0, 15);\ - ymm18 = _mm256_shuffle_pd(ymm0, ymm0, 0);\ - ymm19 = _mm256_mul_pd(ymm17, ymm14);\ - ymm15 = _mm256_fmaddsub_pd(ymm18, ymm16, ymm19);\ - ymm0 = _mm256_sub_pd(ymm15, ymm11);\ -\ - ymm17 = _mm256_shuffle_pd(ymm1, ymm1, 15);\ - ymm18 = _mm256_shuffle_pd(ymm1, ymm1, 0);\ - ymm19 = _mm256_mul_pd(ymm17, ymm14);\ - ymm15 = _mm256_fmaddsub_pd(ymm18, ymm16, ymm19);\ - ymm1 = _mm256_sub_pd(ymm15, ymm12);\ -\ - ymm17 = _mm256_shuffle_pd(ymm2, ymm2, 15);\ - ymm18 = _mm256_shuffle_pd(ymm2, ymm2, 0);\ - ymm19 = _mm256_mul_pd(ymm17, ymm14);\ - ymm15 = _mm256_fmaddsub_pd(ymm18, ymm16, ymm19);\ - ymm2 = _mm256_sub_pd(ymm15, ymm13);\ -\ - /*in register transpose of computed b11 col*/\ - ymm10 = _mm256_permute2f128_pd(ymm0,ymm1,0x20);\ - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31);\ - ymm6 = 
_mm256_permute2f128_pd(ymm2,ymm3,0x20);\ - ymm7 = _mm256_permute2f128_pd(ymm2,ymm3,0x31);\ -} + ////extract a00 + ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 3)); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_ZTRSM_DIV(ymm11) +#else + BLIS_ZTRSM_MUL(ymm11) +#endif -/** - * Performs GEMM operation. - * 4 elements of a column are kept inymm0 and ymm1 - */ -#define BLIS_ZTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b,p_lda,k_iter) {\ - double *tptr = (double *)b01;\ - if(conjtransa) {\ - ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ - for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ - {\ - ymm0 = _mm256_loadu_pd((double const *)(a10));\ - ymm1 = _mm256_loadu_pd((double const *)(a10 + 2));\ - ymm0 = _mm256_mul_pd(ymm0, ymm18);\ - ymm1 = _mm256_mul_pd(ymm1, ymm18);\ - \ - ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0));\ - ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0 + 1)); \ - \ - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8);\ - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12);\ - \ - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);\ - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);\ - tptr += 2; /*move to next row of B*/\ - a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ - }\ - }\ - else {\ - for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ - {\ - ymm0 = _mm256_loadu_pd((double const *)(a10));\ - ymm1 = _mm256_loadu_pd((double const *)(a10 + 2));\ - ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0));\ - ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0 + 1)); \ - \ - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8);\ - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12);\ - \ - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);\ - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);\ - tptr += 2; /*move to next row of B*/\ - a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ - }\ - }\ - ymm4 = _mm256_permute_pd(ymm4, 0x5);\ - ymm5 = _mm256_permute_pd(ymm5, 0x5);\ - ymm8 = _mm256_addsub_pd(ymm8, ymm4);\ - ymm12 = _mm256_addsub_pd(ymm12, ymm5);\ -} + //extract a11 + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 2)); + //(ROW1): FMA operations + ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a*2 + rs_a*3)); + ymm3 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a*1 + rs_a*3)); + ymm4 = _mm256_broadcast_pd((__m128d const *)(a11 + rs_a*3)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + ymm3 = _mm256_mul_pd(ymm3, ymm0); + ymm4 = _mm256_mul_pd(ymm4, ymm0); + } + /*Step1 dcomplex multiply ymmx, ymmx + * Step2 negate the result + * Step3 add ymmx*/ + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm8 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm11, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm11, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); -/** - * Performs the GEMM operation. - * 2 elements of a column are kept in ymm0. 
- */ -#define BLIS_ZTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b,p_lda,k_iter) {\ - double *tptr = (double * )b01;\ - if(conjtransa) {\ - ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ - for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ - {\ - ymm0 = _mm256_loadu_pd((double const *)(a10));\ - ymm1 = _mm256_loadu_pd((double const *)(a10 + 2));\ - ymm0 = _mm256_mul_pd(ymm0, ymm18);\ - ymm1 = _mm256_mul_pd(ymm1, ymm18);\ - \ - ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0));\ - ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0 + 1)); \ - \ - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8);\ - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12);\ - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);\ - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);\ - ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1)); \ - ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1 + 1)); \ - \ - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9);\ - ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13);\ - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6);\ - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);\ - tptr += 2; /*move to next row of B*/\ - a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ - }\ - }\ - else {\ - for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ - {\ - ymm0 = _mm256_loadu_pd((double const *)(a10));\ - ymm1 = _mm256_loadu_pd((double const *)(a10 + 2));\ - ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0));\ - ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0 + 1)); \ - \ - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8);\ - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12);\ - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);\ - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);\ - ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1)); \ - ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1 + 1)); \ - \ - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9);\ - ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13);\ - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6);\ - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);\ - tptr += 2; /*move to next row of B*/\ - a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ - }\ - }\ - ymm4 = _mm256_permute_pd(ymm4, 0x5);\ - ymm5 = _mm256_permute_pd(ymm5, 0x5);\ - ymm6 = _mm256_permute_pd(ymm6, 0x5);\ - ymm7 = _mm256_permute_pd(ymm7, 0x5);\ -\ - ymm8 = _mm256_addsub_pd(ymm8, ymm4);\ - ymm12 = _mm256_addsub_pd(ymm12, ymm5);\ - ymm9 = _mm256_addsub_pd(ymm9, ymm6);\ - ymm13 = _mm256_addsub_pd(ymm13, ymm7);\ -} + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); -/*GEMM block used in ztrsm small left cases*/ -#define BLIS_ZTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) {\ - double *tptr = (double *)b01;\ - if(conjtransa) {\ - ymm16 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ - for(k = 0; k< k_iter; k++) \ - { \ - ymm0 = _mm256_loadu_pd((double const *)(a10)); \ - ymm1 = _mm256_loadu_pd((double const *)(a10 + 2)); \ - ymm0 = _mm256_mul_pd(ymm0, ymm16);\ - ymm1 = _mm256_mul_pd(ymm1, ymm16);\ - \ - ymm2 = _mm256_broadcast_sd((double const *)(tptr)); \ - ymm3 = _mm256_broadcast_sd((double const *)(tptr + 1)); \ - \ - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8);\ - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11);\ - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);\ - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);\ - \ - ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 1 * 2)); \ - ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 1 * 2 + 1)); \ - \ - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9);\ - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, 
ymm12);\ - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6);\ - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);\ - \ - ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b *2 * 2)); \ - ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 2 + 1)); \ - \ - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10);\ - ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13);\ - \ - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14);\ - ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15);\ - \ - tptr += 2; \ - a10 += p_lda; \ - }\ - }\ - else {\ - for(k = 0; k< k_iter; k++) \ - { \ - ymm0 = _mm256_loadu_pd((double const *)(a10)); \ - ymm1 = _mm256_loadu_pd((double const *)(a10 + 2)); \ - \ - ymm2 = _mm256_broadcast_sd((double const *)(tptr)); \ - ymm3 = _mm256_broadcast_sd((double const *)(tptr + 1)); \ - \ - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8);\ - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11);\ - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);\ - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);\ - \ - ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 1 * 2)); \ - ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 1 * 2 + 1)); \ - \ - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9);\ - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12);\ - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6);\ - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);\ - \ - ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b *2 * 2)); \ - ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 2 + 1)); \ - \ - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10);\ - ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13);\ - \ - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14);\ - ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15);\ - \ - tptr += 2; \ - a10 += p_lda; \ - }\ - }\ - ymm4 = _mm256_permute_pd(ymm4, 0x5);\ - ymm5 = _mm256_permute_pd(ymm5, 0x5);\ - ymm6 = _mm256_permute_pd(ymm6, 0x5);\ - ymm7 = _mm256_permute_pd(ymm7, 0x5);\ - ymm14 = _mm256_permute_pd(ymm14, 0x5);\ - ymm15 = _mm256_permute_pd(ymm15, 0x5);\ - \ - ymm8 = _mm256_addsub_pd(ymm8, ymm4);\ - ymm11 = _mm256_addsub_pd(ymm11, ymm5);\ - ymm9 = _mm256_addsub_pd(ymm9, ymm6);\ - ymm12 = _mm256_addsub_pd(ymm12, ymm7);\ - ymm10 = _mm256_addsub_pd(ymm10, ymm14);\ - ymm13 = _mm256_addsub_pd(ymm13, ymm15);\ -} + //Step 3 + ymm10 = _mm256_add_pd(ymm16, ymm10); + + //Step 1 + ymm14 = _mm256_permute_pd(ymm3, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm8 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm11, ymm3); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm11, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + //Step 3 + ymm9 = _mm256_add_pd(ymm16, ymm9); -#define BLIS_ZTRSM_SMALL_NREG_TRANSPOSE_4x3_AND_STORE(b11,cs_b){\ - ymm0 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20);\ - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31);\ - ymm2 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20);\ - _mm256_storeu_pd((double *)(b11), ymm0);\ - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1);\ - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2);\ -\ - ymm0 = _mm256_permute2f128_pd(ymm10, ymm11, 0x20);\ - ymm1 = _mm256_permute2f128_pd(ymm10, ymm11, 0x31);\ - ymm2 = _mm256_permute2f128_pd(ymm6, ymm7, 0x20);\ - _mm256_storeu_pd((double *)(b11 + 2), ymm0);\ - _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 2), ymm1);\ - _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 2), ymm2);\ -} + //Step 1 + ymm14 = _mm256_permute_pd(ymm4, 0x5); + /* Negate the imaginary elements of vec2 */ + 
ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm8 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm11, ymm4); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm11, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + //Step 3 + ymm8 = _mm256_add_pd(ymm16, ymm8); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_ZTRSM_DIV(ymm10) +#else + BLIS_ZTRSM_MUL(ymm10) +#endif + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 1)); + ymm3 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a*1 + rs_a*2)); + ymm4 = _mm256_broadcast_pd((__m128d const *)(a11 + rs_a*2)); + if(conjtransa) + { + ymm3 = _mm256_mul_pd(ymm3, ymm0); + ymm4 = _mm256_mul_pd(ymm4, ymm0); + } + //Step 1 + ymm14 = _mm256_permute_pd(ymm3, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm9 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm10, ymm3); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm10, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); -/** - * Performs dcomplex division of vec1 and vec2 with ymm1. - * vec1 and vec2 gets divided by ymm1 which holds - * diagonal element from buffer. - * Function gets called while performing TRSM. - */ -#define BLIS_ZTRSM_TWO_DIV(vec1, vec2) {\ - if(!is_unitdiag) {\ - if(conjtransa){\ - ymm1 = _mm256_mul_pd(ymm1, ymm0);\ - }\ - ymm12 = _mm256_mul_pd(ymm1, ymm0);\ - /*perform decomplex multiplication*/\ - /* Switch the real and imaginary elements of vec2 */\ - ymm14 = _mm256_permute_pd(ymm12, 0x5);\ - /* Negate the imaginary elements of vec2 */\ - ymm14 = _mm256_mul_pd(ymm14, ymm0);\ - /* Multiply vec1 and vec2 */ \ - ymm13 = _mm256_mul_pd(vec1, ymm12); /*vec3*/\ - /* Multiply vec1 and the modified vec2 */\ - ymm14 = _mm256_mul_pd(vec1, ymm14); /*vec4*/\ - /* Horizontally subtract the elements in vec3 and vec4 */\ - vec1 = _mm256_hsub_pd(ymm13, ymm14);\ - \ - ymm14 = _mm256_permute_pd(ymm12, 0x5);\ - /* Negate the imaginary elements of vec2 */\ - ymm14 = _mm256_mul_pd(ymm14, ymm0);\ - ymm13 = _mm256_mul_pd(vec2, ymm12);\ - ymm14 = _mm256_mul_pd(vec2, ymm14);\ - vec2 = _mm256_hsub_pd(ymm13, ymm14);\ - /*dcomplex multiplication is done*/\ - /*Swapping real & imaginary component position for addition with respective - * components*/\ - ymm12 = _mm256_mul_pd(ymm1, ymm1);\ - ymm13 = _mm256_permute4x64_pd(ymm12, 0xb1);\ - ymm14 = _mm256_add_pd(ymm12, ymm13);\ - \ - /*Finally dividing numerator by denominator*/\ - vec1 = _mm256_div_pd(vec1, ymm14);\ - vec2 = _mm256_div_pd(vec2, ymm14);\ - }\ -} + //Step 3 + ymm9 = _mm256_add_pd(ymm16, ymm9); -/** - * Performs dcomplex division of vec1 with ymm1. - * ymm1 holds diagonal element from buffer. - * Function gets called while performing TRSM. 
- */ -#define BLIS_ZTRSM_DIV(vec1) {\ - if(!is_unitdiag){\ - if(conjtransa){\ - ymm1 = _mm256_mul_pd(ymm1, ymm0);\ - }\ - ymm12 = _mm256_mul_pd(ymm1, ymm0); /*vec2 and ymm8 is vec1*/\ - ymm14 = _mm256_permute_pd(ymm12, 0x5);\ - ymm14 = _mm256_mul_pd(ymm14, ymm0);\ - ymm13 = _mm256_mul_pd(vec1, ymm12); /*vec3*/\ - ymm14 = _mm256_mul_pd(vec1, ymm14); /*vec4*/\ - vec1 = _mm256_hsub_pd(ymm13, ymm14);\ - \ - ymm12 = _mm256_mul_pd(ymm1, ymm1);\ - ymm13 = _mm256_permute4x64_pd(ymm12, 0xb1);\ - ymm14 = _mm256_add_pd(ymm12, ymm13);\ - \ - /*Finally dividing numerator by denominator*/\ - vec1 = _mm256_div_pd(vec1, ymm14);\ - }\ -} + //Step 1 + ymm14 = _mm256_permute_pd(ymm4, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm8 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm10, ymm4); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm10, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + //Step 3 + ymm8 = _mm256_add_pd(ymm16, ymm8); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_ZTRSM_DIV(ymm9) +#else + BLIS_ZTRSM_MUL(ymm9) +#endif -/** - * Performs dcomplex multiplication of vec1 with ymm1. - * ymm1 holds diagonal element from buffer. - * Function gets called while performing TRSM. - */ -#define BLIS_ZTRSM_MUL(vec1) {\ - if(!is_unitdiag){\ - if(conjtransa){\ - ymm19 = _mm256_mul_pd(ymm1, ymm0);\ - }\ - else{\ - ymm19 = ymm1;\ - }\ - ymm14 = _mm256_permute_pd(ymm19, 0x5);\ - /* Negate the imaginary elements of vec2 */\ - ymm14 = _mm256_mul_pd(ymm14, ymm0);\ - /* Multiply vec1 and vec2 */\ - ymm13 = _mm256_mul_pd(vec1, ymm19); /*vec3*/\ - /* Multiply vec1 and the modified vec2 */\ - ymm14 = _mm256_mul_pd(vec1, ymm14); /*vec4*/\ - /* Horizontally subtract the elements in vec3 and vec4 */\ - vec1 = _mm256_hsub_pd(ymm13, ymm14);\ - }\ -} + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); + ymm4 = _mm256_broadcast_pd((__m128d const *)(a11 + rs_a)); + if(conjtransa) + { + ymm4 = _mm256_mul_pd(ymm4, ymm0); + } + //Step 1 + ymm14 = _mm256_permute_pd(ymm4, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm10 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm9, ymm4); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm9, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + //Step 3 + ymm8 = _mm256_add_pd(ymm16, ymm8); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_ZTRSM_DIV(ymm8) +#else + BLIS_ZTRSM_MUL(ymm8) +#endif -BLIS_INLINE void bli_ztrsm_small_pack -( - char side, - dim_t size, - bool trans, - dcomplex *inbuf, - dim_t cs_a, - dcomplex *pbuff, - dim_t p_lda, - dim_t mr -) -{ - //scratch registers - __m256d ymm0, ymm1, ymm2; - __m256d ymm5, ymm6, ymm7; - __m256d ymm8, ymm9, ymm10, ymm11; - __m128d xmm0,xmm1,xmm2; - double zero = 0.0; + if(2 == n_remainder) + { + ymm0 = _mm256_permute2f128_pd(ymm8,ymm9,0x20); + ymm4 = _mm256_permute2f128_pd(ymm10,ymm11,0x20); + ymm1 = _mm256_permute2f128_pd(ymm8,ymm9,0x31); + ymm3 = _mm256_permute2f128_pd(ymm10,ymm11,0x31); + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 2), ymm4); + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 2), ymm3); - 
if(side=='L'||side=='l') + } + else if(1 == n_remainder) + { + ymm0 = _mm256_permute2f128_pd(ymm8,ymm9,0x20); + ymm4 = _mm256_permute2f128_pd(ymm10,ymm11,0x20); + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 2), ymm4); + } + } + } + + dim_t m_remainder = i + d_mr; + a10 = L + m_remainder*rs_a; + dcomplex *ptr_a10_dup = D_A_pack; + if(m_remainder == 3) { - /*Left case is 4xk*/ - if(trans) + dim_t p_lda = 4; + if(transa) { - /* - ------------- ------------- - | | | | | - | 2x4 | | | | - ------------- ==> | 4x2 | 4x2 | - | 2x4 | | | | - | | | | | - ------------- ------------- - */ - for(dim_t x = 0; x < size; x += mr) + for(dim_t x = 0; x < m-m_remainder; x += p_lda) { - ymm0 = _mm256_loadu_pd((double const *)(inbuf)); - ymm10 = _mm256_loadu_pd((double const *)(inbuf + 2)); - ymm1 = _mm256_loadu_pd((double const *)(inbuf + cs_a)); - ymm11 = _mm256_loadu_pd((double const *)(inbuf + 2 + cs_a)); + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm10 = _mm256_loadu_pd((double const *) + (a10 + 2)); + ymm1 = _mm256_loadu_pd((double const *) + (a10 + cs_a)); + ymm11 = _mm256_loadu_pd((double const *) + (a10 + 2 + cs_a)); ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); ymm8 = _mm256_permute2f128_pd(ymm10,ymm11,0x20); ymm9 = _mm256_permute2f128_pd(ymm10,ymm11,0x31); - _mm256_storeu_pd((double *)(pbuff), ymm6); - _mm256_storeu_pd((double *)(pbuff + p_lda), ymm7); - _mm256_storeu_pd((double *)(pbuff + p_lda*2), ymm8); - _mm256_storeu_pd((double *)(pbuff + p_lda*3), ymm9); - - ymm0 = _mm256_loadu_pd((double const *)(inbuf + 2 * cs_a)); - ymm10 = _mm256_loadu_pd((double const *)(inbuf + 2 * cs_a + 2)); - ymm1 = _mm256_loadu_pd((double const *)(inbuf + 3 * cs_a)); - ymm11 = _mm256_loadu_pd((double const *)(inbuf + 3 * cs_a + 2)); + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + + p_lda*3), ymm9); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + + 2 * cs_a)); + ymm10 = _mm256_loadu_pd((double const *)(a10 + + 2 * cs_a + 2)); + + ymm1 = _mm256_loadu_pd((double const *)(a10 + + 3 * cs_a)); + ymm11 = _mm256_loadu_pd((double const *)(a10 + + 3 * cs_a + 2)); ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); ymm8 = _mm256_permute2f128_pd(ymm10,ymm11,0x20); ymm9 = _mm256_permute2f128_pd(ymm10,ymm11,0x31); - _mm256_storeu_pd((double *)(pbuff + 2), ymm6); - _mm256_storeu_pd((double *)(pbuff + p_lda + 2), ymm7); - _mm256_storeu_pd((double *)(pbuff + p_lda*2 + 2), ymm8); - _mm256_storeu_pd((double *)(pbuff + p_lda*3 + 2), ymm9); + _mm256_storeu_pd((double *)(ptr_a10_dup + 2), + ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + + p_lda + 2), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + + p_lda*2 + 2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + + p_lda*3 + 2), ymm9); - inbuf += mr; - pbuff += mr*mr; + a10 += p_lda; + ptr_a10_dup += p_lda * p_lda; } - }else + + } + else { - //Expected multiples of 4 - p_lda = 4; - for(dim_t x = 0; x < size; x++) + for(dim_t x=0;x < m-m_remainder;x++) { - ymm0 = _mm256_loadu_pd((double const *)(inbuf)); - _mm256_storeu_pd((double *)(pbuff), ymm0); - ymm1 = _mm256_loadu_pd((double const *)(inbuf + 2)); - _mm256_storeu_pd((double *)(pbuff + 2), ymm1); - inbuf+=cs_a; - pbuff+=p_lda; + ymm0 = _mm256_loadu_pd((double const *) + (a10 + rs_a * x)); + 
_mm256_storeu_pd((double *) + (ptr_a10_dup + p_lda * x), ymm0); + ymm0 = _mm256_loadu_pd((double const *) + (a10 + rs_a * x + 2)); + _mm256_storeu_pd((double *) + (ptr_a10_dup + p_lda * x + 2), + ymm0); } } - }else if(side=='R'||side=='r') - { - - if(trans) + //cols + for(j = (n - d_nr); (j + 1) > 0; j -= d_nr) { - for(dim_t x=0; x>1); i++) - { - ymm0 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 0 )); - _mm256_storeu_pd((double *)(pbuff + p_lda * 0), ymm0); - ymm1 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 1 )); - _mm256_storeu_pd((double *)(pbuff + p_lda * 1), ymm1); - ymm2 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 2)); - _mm256_storeu_pd((double *)(pbuff + p_lda * 2), ymm2); - inbuf += 2; - pbuff += 2; - } - if(size & 0x1) - { - xmm0 = _mm_loadu_pd((double const *)(inbuf + cs_a * 0)); - _mm_storeu_pd((double *)(pbuff + p_lda * 0 ), xmm0); - xmm1 = _mm_loadu_pd((double const *)(inbuf + cs_a * 1)); - _mm_storeu_pd((double *)(pbuff + p_lda * 1), xmm1); - xmm2 = _mm_loadu_pd((double const *)(inbuf + cs_a * 2)); - _mm_storeu_pd((double *)(pbuff + p_lda * 2), xmm2); - } - } - } + else + ztrsm_AuXB_ref(a11, b11, m_remainder, 2, + rs_a, cs_b, is_unitdiag, + conjtransa); + } + else if(1 == n_remainder) + { + ///GEMM code begins/// + BLIS_ZTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b, + p_lda,k_iter) + BLIS_PRE_ZTRSM_SMALL_3M_1N(AlphaVal,b11,cs_b) -} + if(transa) + ztrsm_AltXB_ref(a11, b11, m_remainder, 1, + cs_a, cs_b, is_unitdiag, + conjtransa); + else + ztrsm_AuXB_ref(a11, b11, m_remainder, 1, + rs_a, cs_b, is_unitdiag, + conjtransa); + } + } + } + else if(m_remainder == 2) + { + dim_t p_lda = 2; + if(transa) + { + for(dim_t x = 0; x < m-m_remainder; x += p_lda) + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *) + (a10 + cs_a)); -BLIS_INLINE void ztrsm_small_pack_diag_element -( - bool is_unitdiag, - dcomplex *a11, - dim_t cs_a, - dcomplex *d11_pack, - dim_t size -) -{ - __m256d ymm1, ymm2, ymm3, ymm4, ymm5, ymm6, ymm7, ymm8; - bool is_four = (size == 4) ? 
1 : 0; - dcomplex ones = {1.0, 1.0}; - ymm2 = ymm1 = _mm256_broadcast_pd((__m128d const *)&ones); - ymm7 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - if(!is_unitdiag) - { - //broadcast diagonal elements of A11 - ymm1 = _mm256_broadcast_pd((__m128d const *)a11); - ymm2 = _mm256_broadcast_pd((__m128d const *)a11+ cs_a +1); - /*Pick one element frome each column and create 3 element vector - and store it*/ - ymm1 = _mm256_permute2f128_pd(ymm1, ymm2, 0x20); - ymm2 = _mm256_broadcast_pd((__m128d const *)a11+ cs_a*2 + 2); - - if(is_four) - { - ymm3 = _mm256_broadcast_pd((__m128d const *)a11+ cs_a*2 + 2); - ymm2 = _mm256_broadcast_pd((__m128d const *)a11+ cs_a*3 + 3); - ymm2 = _mm256_permute2f128_pd(ymm3, ymm2, 0x20); - } + ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); -#ifdef BLIS_ENABLE_TRSM_PREINVERSION - /*Taking denomerator multiplication of real & imaginary components*/ - ymm4 = _mm256_mul_pd(ymm1, ymm1); - ymm5 = _mm256_mul_pd(ymm2,ymm2); - /*Swapping real & imaginary component position for addition with - * respective components*/ - ymm6 = _mm256_permute4x64_pd(ymm4, 0xb1); - ymm4 = _mm256_add_pd(ymm4, ymm6); - ymm8 = _mm256_permute4x64_pd(ymm5, 0xb1); - - ymm5 = _mm256_add_pd(ymm5, ymm8); - /*Negating imaginary component of numerator*/ - ymm1 = _mm256_mul_pd(ymm1, ymm7); - ymm2 = _mm256_mul_pd(ymm2, ymm7); - /*Dividing numerator by denominator*/ - ymm1 = _mm256_div_pd(ymm1, ymm4); - ymm2 = _mm256_div_pd(ymm2, ymm5); -#endif + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + + p_lda), ymm7); - } - _mm256_store_pd((double *)d11_pack, ymm1); - if(is_four) - { - _mm256_store_pd((double *)(d11_pack + 2), ymm2); - } - else - { - _mm_store_pd((double *)(d11_pack + 2), - _mm256_extractf128_pd(ymm2,0)); - - } -} + a10 += p_lda; + ptr_a10_dup += p_lda * p_lda; + } -BLIS_INLINE err_t bli_ztrsm_small_AutXB_AlXB -( - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl -) -{ - dim_t m = bli_obj_length(b); // number of rows of matrix B - dim_t n = bli_obj_width(b); // number of columns of matrix B + } + else + { + for(dim_t x=0;x < m-m_remainder;x++) + { + ymm0 = _mm256_loadu_pd((double const *) + (a10 + rs_a * x)); + _mm256_storeu_pd((double *) + (ptr_a10_dup + p_lda * x), ymm0); + } + } + //cols + for(j = (n - d_nr); (j + 1) > 0; j -= d_nr) + { + a10 = D_A_pack; + a11 = L; + b01 = B + (j*cs_b) + m_remainder; + b11 = B + (j* cs_b); + k_iter = (m - m_remainder); - bool transa = bli_obj_has_trans(a); - bool conjtransa = bli_obj_has_conj(a); + BLIS_SET_YMM_REG_ZEROS + ///GEMM code begins/// + BLIS_ZTRSM_SMALL_GEMM_2mx3n(a10,b01,cs_b,p_lda,k_iter) + ///GEMM code ends/// + ymm16 = _mm256_broadcast_pd((__m128d const *) + (&AlphaVal)); - dim_t cs_a, rs_a; - dim_t d_mr = 4,d_nr = 3; + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - // Swap rs_a & cs_a in case of non-tranpose. 
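+ /* Exchanging the two strides presents A and its transpose through the
+ same index expression, so one kernel body can serve both the
+ transposed and the non-transposed variant handled by this routine. */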
- if(transa) - { - cs_a = bli_obj_col_stride(a); // column stride of A - rs_a = bli_obj_row_stride(a); // row stride of A - } - else - { - cs_a = bli_obj_row_stride(a); // row stride of A - rs_a = bli_obj_col_stride(a); // column stride of A - } - dim_t cs_b = bli_obj_col_stride(b); // column stride of B + ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - dim_t i, j, k; //loop variables - dim_t k_iter; //number of times GEMM to be performed + ymm14 = _mm256_permute_pd(ymm16, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm18); + ymm17 = _mm256_mul_pd(ymm0, ymm16); + ymm14 = _mm256_mul_pd(ymm0, ymm14); + ymm15 = _mm256_hsub_pd(ymm17, ymm14); - dcomplex AlphaVal = *(dcomplex *)AlphaObj->buffer; //value of alpha - dcomplex *L = a->buffer; //pointer to matrix A - dcomplex *B = b->buffer; //pointer to matrix B + ymm8 = _mm256_sub_pd(ymm15,ymm8); - dcomplex *a10, *a11, *b01, *b11; //pointers that point to blocks for GEMM and TRSM + ymm14 = _mm256_permute_pd(ymm16, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm18); + ymm17 = _mm256_mul_pd(ymm1, ymm16); + ymm14 = _mm256_mul_pd(ymm1, ymm14); + ymm15 = _mm256_hsub_pd(ymm17, ymm14); - dcomplex ones = {1.0, 1.0}; - bool is_unitdiag = bli_obj_has_unit_diag(a); + ymm9 = _mm256_sub_pd(ymm15,ymm9); - //scratch registers - __m256d ymm0, ymm1, ymm2, ymm3; - __m256d ymm4, ymm5, ymm6, ymm7; - __m256d ymm8, ymm9, ymm10, ymm11; - __m256d ymm12, ymm13, ymm14, ymm15; - __m256d ymm16, ymm17, ymm18, ymm19; + ymm14 = _mm256_permute_pd(ymm16, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm18); + ymm17 = _mm256_mul_pd(ymm2, ymm16); + ymm14 = _mm256_mul_pd(ymm2, ymm14); + ymm15 = _mm256_hsub_pd(ymm17, ymm14); - __m128d xmm5, xmm4, xmm3; + ymm10 = _mm256_sub_pd(ymm15,ymm10); - gint_t required_packing_A = 1; - mem_t local_mem_buf_A_s = {0}; - dcomplex *D_A_pack = NULL; - dcomplex d11_pack[d_mr] __attribute__((aligned(64))); - rntm_t rntm; + _mm256_storeu_pd((double *)(b11), ymm8); + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm9); + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm10); - bli_rntm_init_from_global( &rntm ); - bli_rntm_set_num_threads_only( 1, &rntm ); - bli_membrk_rntm_set_membrk( &rntm ); + if(transa) + ztrsm_AltXB_ref(a11, b11, m_remainder, 3, + cs_a, cs_b, is_unitdiag, + conjtransa); + else + ztrsm_AuXB_ref(a11, b11, m_remainder, 3, + rs_a, cs_b, is_unitdiag, + conjtransa); + } + dim_t n_remainder = j + d_nr; + if(n_remainder) + { + a10 = D_A_pack; + a11 = L; + b01 = B + m_remainder; + b11 = B; + k_iter = (m - m_remainder); + BLIS_SET_YMM_REG_ZEROS + if(2 == n_remainder) + { + ///GEMM code begins/// + BLIS_ZTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b, + p_lda,k_iter) + BLIS_PRE_ZTRSM_SMALL_2M_2N(AlphaVal,b11,cs_b) - siz_t buffer_size = bli_pool_block_size( - bli_membrk_pool( - bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), - bli_rntm_membrk(&rntm))); + if(transa) + ztrsm_AltXB_ref(a11, b11, m_remainder, 2, + cs_a, cs_b, is_unitdiag, + conjtransa); - if ( (d_mr * m * sizeof(dcomplex)) > buffer_size) - return BLIS_NOT_YET_IMPLEMENTED; + else + ztrsm_AuXB_ref(a11, b11, m_remainder, 2, + rs_a, cs_b, is_unitdiag, + conjtransa); + } + else if(1 == n_remainder) + { + ///GEMM code begins/// + BLIS_ZTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b, + p_lda,k_iter) + BLIS_PRE_ZTRSM_SMALL_2M_1N(AlphaVal,b11,cs_b) - if (required_packing_A == 1) - { - // Get the buffer from the pool. 
- bli_membrk_acquire_m(&rntm, - buffer_size, - BLIS_BITVAL_BUFFER_FOR_A_BLOCK, - &local_mem_buf_A_s); - if(FALSE==bli_mem_is_alloc(&local_mem_buf_A_s)) return BLIS_NULL_POINTER; - D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); - if(NULL==D_A_pack) return BLIS_NULL_POINTER; - } + if(transa) + ztrsm_AltXB_ref(a11, b11, m_remainder, 1, + cs_a, cs_b, is_unitdiag, + conjtransa); + else + ztrsm_AuXB_ref(a11, b11, m_remainder, 1, + rs_a, cs_b, is_unitdiag, + conjtransa); - /* - Performs solving TRSM for 4 colmns at a time from 0 to m/4 in steps of d_mr - a. Load, transpose, Pack A (a10 block), the size of packing 4x3 to 4x (m-4) - First there will be no GEMM and no packing of a10 because it is only TRSM - b. Using packed a10 block and b01 block perform GEMM operation - c. Use GEMM outputs, perform TRSM operaton using a11, b11 and update B - d. Repeat b,c for n rows of B in steps of d_nr - */ - for(i = 0;(i+d_mr-1) < m; i += d_mr) //loop along 'M' dimension + } + } + } + else if(m_remainder == 1) { - a10 = L + (i*cs_a); //pointer to block of A to be used for GEMM - a11 = L + (i*rs_a) + (i*cs_a); - dim_t p_lda = d_mr; // packed leading dimension - + dim_t p_lda = 2; // packed leading dimension if(transa) { - /* - Load, tranpose and pack current A block (a10) into packed buffer memory - D_A_pack - a. This a10 block is used in GEMM portion only and this - a10 block size will be increasing by d_mr for every next itteration - untill it reaches 4x(m-4) which is the maximum GEMM alone block size - in A - b. This packed buffer is reused to calculate all n rows of B matrix - */ - bli_ztrsm_small_pack('L', i, 1, a10, cs_a, D_A_pack, p_lda,d_mr); + for(dim_t x = 0; x < m-m_remainder; x += p_lda) + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *) + (a10 + cs_a)); + + ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + + p_lda), ymm7); + + a10 += p_lda; + ptr_a10_dup += p_lda * p_lda; + } - /* - Pack 4 diagonal elements of A block into an array - a. This helps in utilze cache line efficiently in TRSM operation - b. 
store ones when input is unit diagonal - */ - ztrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,d_mr); } else { - bli_ztrsm_small_pack('L', i, 0, a10, rs_a, D_A_pack, p_lda,d_mr); - ztrsm_small_pack_diag_element(is_unitdiag,a11,rs_a,d11_pack,d_mr); + for(dim_t x=0;x 0; j -= d_nr) { a10 = D_A_pack; - a11 = L + (i*rs_a) + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + a11 = L; + b01 = B + (j*cs_b) + m_remainder; + b11 = B + (j* cs_b); + k_iter = (m - m_remainder); - k_iter = i; + BLIS_SET_YMM_REG_ZEROS + ///GEMM code begins/// + BLIS_ZTRSM_SMALL_GEMM_2mx3n(a10,b01,cs_b,p_lda,k_iter) + ///GEMM code ends/// + ymm16 = _mm256_broadcast_pd((__m128d const *) + (&AlphaVal)); - /*Fill zeros into ymm registers used in gemm accumulations */ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm16, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm18); + ymm17 = _mm256_mul_pd(ymm0, ymm16); + ymm14 = _mm256_mul_pd(ymm0, ymm14); + ymm15 = _mm256_hsub_pd(ymm17, ymm14); + + ymm8 = _mm256_sub_pd(ymm15,ymm8); + + ymm14 = _mm256_permute_pd(ymm16, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm18); + ymm17 = _mm256_mul_pd(ymm1, ymm16); + ymm14 = _mm256_mul_pd(ymm1, ymm14); + ymm15 = _mm256_hsub_pd(ymm17, ymm14); + + ymm9 = _mm256_sub_pd(ymm15,ymm9); + ymm14 = _mm256_permute_pd(ymm16, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm18); + ymm17 = _mm256_mul_pd(ymm2, ymm16); + ymm14 = _mm256_mul_pd(ymm2, ymm14); + ymm15 = _mm256_hsub_pd(ymm17, ymm14); + + ymm10 = _mm256_sub_pd(ymm15,ymm10); + + _mm_storeu_pd((double *)(b11), + _mm256_extractf128_pd(ymm8,0)); + _mm_storeu_pd((double *)(b11 + cs_b * 1), + _mm256_extractf128_pd(ymm9,0) ); + _mm_storeu_pd((double *)(b11 + cs_b * 2), + _mm256_extractf128_pd(ymm10,0)); + + if(transa) + ztrsm_AltXB_ref(a11, b11, m_remainder, 3, + cs_a, cs_b, is_unitdiag, + conjtransa); + + else + ztrsm_AuXB_ref(a11, b11, m_remainder, 3, rs_a, + cs_b, is_unitdiag, + conjtransa); + } + dim_t n_remainder = j + d_nr; + if(n_remainder) + { + a10 = D_A_pack; + a11 = L ; + b01 = B + m_remainder; + b11 = B; + k_iter = (m - m_remainder); BLIS_SET_YMM_REG_ZEROS + if(2 == n_remainder) + { - /* - Peform GEMM between a10 and b01 blocks - For first itteration there will be no GEMM operation - where k_iter are zero - */ - BLIS_ZTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) + ///GEMM code begins/// + BLIS_ZTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b, + p_lda,k_iter) + BLIS_PRE_ZTRSM_SMALL_1M_2N(AlphaVal,b11,cs_b) - /* - Load b11 of size 3x4 and multiply with alpha - Add the GEMM output and perform inregister transose of b11 - to peform TRSM operation. - */ - BLIS_ZTRSM_SMALL_NREG_TRANSPOSE_3x4(b11,cs_b,AlphaVal) - /* - Compute 4x3 TRSM block by using GEMM block output in register - a. The 4x3 input (gemm outputs) are stored in combinations of ymm - registers - 1. ymm8, ymm4 2. ymm9, ymm5 3. ymm10, ymm6, 4. ymm11, ymm7 - where ymm8-ymm11 holds 4x2 data and reaming 4x1 will be hold by - other registers - b. 
Towards the end do in regiser transpose of TRSM output and store in - b11 - */ - ////extract a00 - ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); + if(transa) + ztrsm_AltXB_ref(a11, b11, m_remainder, 2, + cs_a, cs_b, is_unitdiag, + conjtransa); -#ifndef BLIS_ENABLE_TRSM_PREINVERSION - /*performs dcomplex divison of ymm8 and ymm4 with ymm1*/ - BLIS_ZTRSM_TWO_DIV(ymm8,ymm4) -#else - /*performs dcomplex multiplication of ymm8 and ymm4 with ymm1*/ - BLIS_ZTRSM_MUL(ymm8) - BLIS_ZTRSM_MUL(ymm4) -#endif - //extract a11 - ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 1)); - //(ROW1): FMA operations - ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a*1)); - if(conjtransa) - { - ymm2 = _mm256_mul_pd(ymm2, ymm0); - } - /* Step1 dcomplex multiply ymm2, ymm8 - * Step2 negate the result - * Step3 add ymm9*/ - //Step 1 - ymm14 = _mm256_permute_pd(ymm2, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); - //For ymm8 - /* Multiply vec1 and vec2 */ - ymm13 = _mm256_mul_pd(ymm8, ymm2); /*vec3*/ - /* Multiply vec1 and the modified vec2 */ - ymm14 = _mm256_mul_pd(ymm8, ymm14); /*vec4*/ - /* Horizontally subtract the elements in vec3 and vec4 */ - ymm16 = _mm256_hsub_pd(ymm13, ymm14); + else + ztrsm_AuXB_ref(a11, b11, m_remainder, 2, + rs_a, cs_b, is_unitdiag, + conjtransa); + } + else if(1 == n_remainder) + { + ///GEMM code begins/// + BLIS_ZTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b, + p_lda,k_iter) - //For ymm4 - ymm14 = _mm256_permute_pd(ymm2, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); + BLIS_PRE_ZTRSM_SMALL_1M_1N(AlphaVal,b11,cs_b) - ymm13 = _mm256_mul_pd(ymm4, ymm2); - ymm14 = _mm256_mul_pd(ymm4, ymm14); - ymm17 = _mm256_hsub_pd(ymm13, ymm14); - //Step 2 - ymm16 = _mm256_mul_pd(ymm16, ymm15); - ymm17 = _mm256_mul_pd(ymm17, ymm15); + if(transa) + ztrsm_AltXB_ref(a11, b11, m_remainder, 1, + cs_a, cs_b, is_unitdiag, + conjtransa); - //Step 3 - ymm9 = _mm256_add_pd(ymm16, ymm9); - ymm5 = _mm256_add_pd(ymm17, ymm5); + else + ztrsm_AuXB_ref(a11, b11, m_remainder, 1, + rs_a, cs_b, is_unitdiag, + conjtransa); + } + } + } - ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a*2)); - if(conjtransa) - { - ymm2 = _mm256_mul_pd(ymm2, ymm0); - } + if ((required_packing_A == 1) && + bli_mem_is_alloc( &local_mem_buf_A_s )) + { + bli_membrk_release(&rntm, &local_mem_buf_A_s); + } - //Step 1 - ymm14 = _mm256_permute_pd(ymm2, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); - //For ymm8 - /* Multiply vec1 and vec2 */ - ymm13 = _mm256_mul_pd(ymm8, ymm2); /*vec3*/ - /* Multiply vec1 and the modified vec2 */ - ymm14 = _mm256_mul_pd(ymm8, ymm14); /*vec4*/ - /* Horizontally subtract the elements in vec3 and vec4 */ - ymm16 = _mm256_hsub_pd(ymm13, ymm14); - //For ymm4 - ymm14 = _mm256_permute_pd(ymm2, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); + return BLIS_SUCCESS; +} - ymm13 = _mm256_mul_pd(ymm4, ymm2); - ymm14 = _mm256_mul_pd(ymm4, ymm14); - ymm17 = _mm256_hsub_pd(ymm13, ymm14); - //Step 2 - ymm16 = _mm256_mul_pd(ymm16, ymm15); - ymm17 = _mm256_mul_pd(ymm17, ymm15); +BLIS_INLINE err_t bli_ztrsm_small_XAutB_XAlB +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +) +{ + dim_t m = bli_obj_length(b); //number of rows + dim_t n = bli_obj_width(b); //number of columns - //Step 3 - ymm10 = _mm256_add_pd(ymm16, ymm10); - 
ymm6 = _mm256_add_pd(ymm17, ymm6); + bool transa = bli_obj_has_trans(a); + bool conjtransa = bli_obj_has_conj(a); - ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a*3)); - if(conjtransa) - { - ymm2 = _mm256_mul_pd(ymm2, ymm0); - } + dim_t cs_a, rs_a; + dim_t d_mr = 4,d_nr = 3; - //Step 1 - ymm14 = _mm256_permute_pd(ymm2, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); - //For ymm8 - /* Multiply vec1 and vec2 */ - ymm13 = _mm256_mul_pd(ymm8, ymm2); /*vec3*/ - /* Multiply vec1 and the modified vec2 */ - ymm14 = _mm256_mul_pd(ymm8, ymm14); /*vec4*/ - /* Horizontally subtract the elements in vec3 and vec4 */ - ymm16 = _mm256_hsub_pd(ymm13, ymm14); - //For ymm4 - ymm14 = _mm256_permute_pd(ymm2, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); + // Swap rs_a & cs_a in case of non-tranpose. + if(transa) + { + cs_a = bli_obj_col_stride(a); // column stride of A + rs_a = bli_obj_row_stride(a); // row stride of A + } + else + { + cs_a = bli_obj_row_stride(a); // row stride of A + rs_a = bli_obj_col_stride(a); // column stride of A + } + dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B - ymm13 = _mm256_mul_pd(ymm4, ymm2); - ymm14 = _mm256_mul_pd(ymm4, ymm14); - ymm17 = _mm256_hsub_pd(ymm13, ymm14); - //Step 2 - ymm16 = _mm256_mul_pd(ymm16, ymm15); - ymm17 = _mm256_mul_pd(ymm17, ymm15); - //Step 3 - ymm11 = _mm256_add_pd(ymm16, ymm11); - ymm7 = _mm256_add_pd(ymm17, ymm7); + dim_t i, j, k; //loop variablse + dim_t k_iter; //determines the number of GEMM operations to be done -#ifndef BLIS_ENABLE_TRSM_PREINVERSION - /*performs dcomplex divison of ymm9 and ymm5 with ymm1*/ - BLIS_ZTRSM_TWO_DIV(ymm9,ymm5) -#else - /*performs dcomplex multiplication of ymm9 and ymm5 with ymm1*/ - BLIS_ZTRSM_MUL(ymm9) - BLIS_ZTRSM_MUL(ymm5) -#endif - a11 += rs_a; - //extract a22 - ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 2)); + dcomplex ones = {1.0, 1.0}; + dcomplex zero = {0.0, 0.0}; + bool is_unitdiag = bli_obj_has_unit_diag(a); - //(ROW2): FMA operations - ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a*2)); - if(conjtransa) - { - ymm2 = _mm256_mul_pd(ymm2, ymm0); - } - //Step 1 - ymm14 = _mm256_permute_pd(ymm2, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); + dcomplex AlphaVal = *(dcomplex *)AlphaObj->buffer; //value of Alpha + dcomplex* restrict L = a->buffer; //pointer to matrix A + dcomplex* restrict B = b->buffer; //pointer to matrix B - //For ymm9 - /* Multiply vec1 and vec2 */ - ymm13 = _mm256_mul_pd(ymm9, ymm2); /*vec3*/ - /* Multiply vec1 and the modified vec2 */ - ymm14 = _mm256_mul_pd(ymm9, ymm14); /*vec4*/ - /* Horizontally subtract the elements in vec3 and vec4 */ - ymm16 = _mm256_hsub_pd(ymm13, ymm14); - //For ymm5 - ymm14 = _mm256_permute_pd(ymm2, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); + dcomplex *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks - ymm13 = _mm256_mul_pd(ymm5, ymm2); - ymm14 = _mm256_mul_pd(ymm5, ymm14); - ymm17 = _mm256_hsub_pd(ymm13, ymm14); - //Step 2 - ymm16 = _mm256_mul_pd(ymm16, ymm15); - ymm17 = _mm256_mul_pd(ymm17, ymm15); - //Step 3 - ymm10 = _mm256_add_pd(ymm16, ymm10); - ymm6 = _mm256_add_pd(ymm17, ymm6); + gint_t required_packing_A = 1; + mem_t local_mem_buf_A_s = {0}; + dcomplex *D_A_pack = NULL; + dcomplex d11_pack[d_mr] __attribute__((aligned(64))); + rntm_t rntm; - ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a*3)); - if(conjtransa) - { - 
ymm2 = _mm256_mul_pd(ymm2, ymm0); - } - //Step 1 - ymm14 = _mm256_permute_pd(ymm2, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); + bli_rntm_init_from_global( &rntm ); + bli_rntm_set_num_threads_only( 1, &rntm ); + bli_membrk_rntm_set_membrk( &rntm ); - //For ymm9 - /* Multiply vec1 and vec2 */ - ymm13 = _mm256_mul_pd(ymm9, ymm2); /*vec3*/ - /* Multiply vec1 and the modified vec2 */ - ymm14 = _mm256_mul_pd(ymm9, ymm14); /*vec4*/ - /* Horizontally subtract the elements in vec3 and vec4 */ - ymm16 = _mm256_hsub_pd(ymm13, ymm14); - //For ymm5 - ymm14 = _mm256_permute_pd(ymm2, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); + siz_t buffer_size = bli_pool_block_size( + bli_membrk_pool( + bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), + bli_rntm_membrk(&rntm))); - ymm13 = _mm256_mul_pd(ymm5, ymm2); - ymm14 = _mm256_mul_pd(ymm5, ymm14); - ymm17 = _mm256_hsub_pd(ymm13, ymm14); - //Step 2 - ymm16 = _mm256_mul_pd(ymm16, ymm15); - ymm17 = _mm256_mul_pd(ymm17, ymm15); - //Step 3 - ymm11 = _mm256_add_pd(ymm16, ymm11); - ymm7 = _mm256_add_pd(ymm17, ymm7); + if( (d_nr * n * sizeof(dcomplex)) > buffer_size) + return BLIS_NOT_YET_IMPLEMENTED; -#ifndef BLIS_ENABLE_TRSM_PREINVERSION - /*performs dcomplex divison of ymm10 and ymm6 with ymm1*/ - BLIS_ZTRSM_TWO_DIV(ymm10,ymm6) -#else - /*performs dcomplex multiplication of ymm10 and ymm6 with ymm1*/ - BLIS_ZTRSM_MUL(ymm10) - BLIS_ZTRSM_MUL(ymm6) -#endif - a11 += rs_a; - //extract a44 - ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 3)); - //(ROW3): FMA operations - ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a*3)); - if(conjtransa) - { - ymm2 = _mm256_mul_pd(ymm2, ymm0); - } + if (required_packing_A == 1) + { + // Get the buffer from the pool. 
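+ // The block acquired here backs D_A_pack, the packed copy of the A01
+ // panel used in the GEMM phase; that panel never exceeds d_nr * n
+ // dcomplex elements, which the size check above has already verified
+ // fits within a single pool block.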
+ bli_membrk_acquire_m(&rntm, + buffer_size, + BLIS_BITVAL_BUFFER_FOR_A_BLOCK, + &local_mem_buf_A_s); + if(FALSE==bli_mem_is_alloc(&local_mem_buf_A_s)) return BLIS_NULL_POINTER; + D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); + if(NULL==D_A_pack) return BLIS_NULL_POINTER; + } - //Step 1 - ymm14 = _mm256_permute_pd(ymm2, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); + //ymm scratch reginsters + __m256d ymm0, ymm1, ymm2, ymm3; + __m256d ymm4, ymm5, ymm6, ymm7; + __m256d ymm8, ymm9, ymm10, ymm11; + __m256d ymm12, ymm13, ymm14, ymm15; + __m256d ymm16, ymm17, ymm18, ymm19; - //For ymm10 - /* Multiply vec1 and vec2 */ - ymm13 = _mm256_mul_pd(ymm10, ymm2); /*vec3*/ - /* Multiply vec1 and the modified vec2 */ - ymm14 = _mm256_mul_pd(ymm10, ymm14); /*vec4*/ - /* Horizontally subtract the elements in vec3 and vec4 */ - ymm16 = _mm256_hsub_pd(ymm13, ymm14); - //For ymm6 - ymm14 = _mm256_permute_pd(ymm2, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); + __m128d xmm5; - ymm13 = _mm256_mul_pd(ymm6, ymm2); - ymm14 = _mm256_mul_pd(ymm6, ymm14); - ymm17 = _mm256_hsub_pd(ymm13, ymm14); - //Step 2 - ymm16 = _mm256_mul_pd(ymm16, ymm15); - ymm17 = _mm256_mul_pd(ymm17, ymm15); - //Step 3 - ymm11 = _mm256_add_pd(ymm16, ymm11); - ymm7 = _mm256_add_pd(ymm17, ymm7); + for(j = (n-d_nr); (j+1) > 0; j -= d_nr) //loop along 'N' direction + { + a01 = L + (j*rs_a) + (j+d_nr)*cs_a; + a11 = L + (j*cs_a) + (j*rs_a); -#ifndef BLIS_ENABLE_TRSM_PREINVERSION - /*performs dcomplex divison of ymm11 and ymm7 with ymm1*/ - BLIS_ZTRSM_TWO_DIV(ymm11,ymm7) -#else - /*performs dcomplex nultiplication of ymm11 and ymm7 with ymm1*/ - BLIS_ZTRSM_MUL(ymm11) - BLIS_ZTRSM_MUL(ymm7) -#endif - a11 += rs_a; - BLIS_ZTRSM_SMALL_NREG_TRANSPOSE_4x3_AND_STORE(b11,cs_b) + dim_t p_lda = (n-j-d_nr); // packed leading dimension + // perform copy of A to packed buffer D_A_pack + + if(transa) + { + /* + Pack current A block (a01) into packed buffer memory D_A_pack + a. This a10 block is used in GEMM portion only and this + a01 block size will be increasing by d_nr for every next + iteration until it reaches 3x(n-3) which is the maximum GEMM + alone block size in A + b. This packed buffer is reused to calculate all m cols of B + matrix + */ + bli_ztrsm_small_pack('R', p_lda, 1, a01, cs_a, D_A_pack, + p_lda,d_nr); + + /* + Pack 3 diagonal elements of A block into an array + a. This helps in utilze cache line efficiently in TRSM + operation + b. store ones when input is unit diagonal + */ + ztrsm_small_pack_diag_element(is_unitdiag,a11,cs_a, + d11_pack,d_nr); + } + else + { + bli_ztrsm_small_pack('R', p_lda, 0, a01, rs_a, D_A_pack, + p_lda,d_nr); + ztrsm_small_pack_diag_element(is_unitdiag,a11,rs_a, + d11_pack,d_nr); } - dim_t n_rem = n-j; - if(n_rem) + /* + a. Perform GEMM using a01, b10. + b. Perform TRSM on a11, b11 + c. This loop GEMM+TRSM loops operates with 8x6 block size + along m dimension for every d_mr columns of B10 where + packed A buffer is reused in computing all m cols of B. + d. Same approach is used in remaining fringe cases. 
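+ e. In scalar terms one iteration amounts to (a sketch of the math,
+ not of the vectorised register flow):
+ B11 := alpha*B11 - B10*A01 (update with the already-solved
+ columns to the right), then solve X11*A11 = B11 by
+ back-substitution on the d_nr x d_nr diagonal block and
+ overwrite B11 with X11.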
+ */ + for(i = (m-d_mr); (i+1) > 0; i -= d_mr) //loop along 'M' direction { - a10 = D_A_pack; - a11 = L + (i*rs_a) + (i*cs_a);//pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; + b10 = B + i + (j+d_nr)*cs_b; + b11 = B + (i) + (j)*cs_b; - k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) + k_iter = (n-j-d_nr); /*Fill zeros into ymm registers used in gemm accumulations */ BLIS_SET_YMM_REG_ZEROS - if(2 == n_rem) - { - ///GEMM code begins/// - BLIS_ZTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b,p_lda,k_iter) - BLIS_ZTRSM_SMALL_NREG_TRANSPOSE_2x4(b11,cs_b,AlphaVal) - } - else if(1 == n_rem) - { - ///GEMM code begins/// - BLIS_ZTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b,p_lda,k_iter) - BLIS_ZTRSM_SMALL_NREG_TRANSPOSE_1x4(b11,cs_b,AlphaVal) - } - ///implement TRSM/// - ///transpose of B11// - ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - ymm10 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm11 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + /* + Peform GEMM between a01 and b10 blocks + For first itteration there will be no GEMM operation + where k_iter are zero + */ + + BLIS_ZTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) + + /* + Load b11 multiply with alpha + Add the GEMM output to b11 + and peform TRSM operation. + */ + BLIS_PRE_ZTRSM_SMALL_3x4(AlphaVal,b11,cs_b) + ///implement TRSM/// + /* + Compute 3x3 TRSM block by using GEMM block output in register + a. The 4x3 input (gemm outputs) are stored in combinations of + ymm registers + 1. ymm7, ymm8 2. ymm5, ymm6 3. ymm3, ymm4 + b. Towards the end do in regiser transpose of TRSM output and + store in b11 + */ ////extract a00 ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); - + ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 2)); #ifndef BLIS_ENABLE_TRSM_PREINVERSION - BLIS_ZTRSM_DIV(ymm8) + /*performs dcomplex divison of ymm7 and ymm8 with ymm1*/ + BLIS_ZTRSM_TWO_DIV(ymm7,ymm8) #else - BLIS_ZTRSM_MUL(ymm8) + /*performs dcomplex multiplication of ymm7 and ymm8 with ymm1*/ + BLIS_ZTRSM_MUL(ymm7) + BLIS_ZTRSM_MUL(ymm8) #endif - //extract a11 ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 1)); //(ROW1): FMA operations - ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a*1)); - ymm3 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a*2)); - ymm4 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a*3)); - - if(conjtransa){ - ymm2 = _mm256_mul_pd(ymm2, ymm0); - ymm3 = _mm256_mul_pd(ymm3, ymm0); - ymm4 = _mm256_mul_pd(ymm4, ymm0); - } - - a11 += rs_a; - /*Step1 dcomplex multiply ymmx, ymmx + ymm2 = _mm256_broadcast_pd((__m128d const *) + (a11 + cs_a*2 + rs_a*1)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + /* Step1 dcomplex multiply ymm2, ymm7 * Step2 negate the result * Step3 add ymmx*/ //Step 1 ymm14 = _mm256_permute_pd(ymm2, 0x5); /* Negate the imaginary elements of vec2 */ ymm14 = _mm256_mul_pd(ymm14, ymm0); - //For ymm8 + //For ymm7 /* Multiply vec1 and vec2 */ - ymm13 = _mm256_mul_pd(ymm8, ymm2); /*vec3*/ + ymm13 = _mm256_mul_pd(ymm7, ymm2); /*vec3*/ /* Multiply vec1 and the modified vec2 */ - ymm14 = _mm256_mul_pd(ymm8, ymm14); /*vec4*/ + ymm14 = _mm256_mul_pd(ymm7, ymm14); /*vec4*/ /* Horizontally subtract the elements in vec3 and vec4 */ 
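+ /* The hsub below completes a complex multiply: with each 128-bit lane
+ of ymm7 holding a + b*i and ymm2 the broadcast A11 element c + d*i
+ (conjugated when conjtransa), ymm16 receives (a*c - b*d) + (a*d + b*c)*i.
+ Step 2 negates the product and Step 3 accumulates it, so the net
+ update is ymm5 -= ymm7*(c + d*i), and likewise ymm6 -= ymm8*(c + d*i). */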
ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //For ymm8 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + ymm13 = _mm256_mul_pd(ymm8, ymm2); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm17 = _mm256_hsub_pd(ymm13, ymm14); //Step 2 ymm16 = _mm256_mul_pd(ymm16, ymm15); + ymm17 = _mm256_mul_pd(ymm17, ymm15); //Step 3 - ymm9 = _mm256_add_pd(ymm16, ymm9); + ymm5 = _mm256_add_pd(ymm16, ymm5); + ymm6 = _mm256_add_pd(ymm17, ymm6); + ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a*2)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } //Step 1 - ymm14 = _mm256_permute_pd(ymm3, 0x5); + ymm14 = _mm256_permute_pd(ymm2, 0x5); /* Negate the imaginary elements of vec2 */ ymm14 = _mm256_mul_pd(ymm14, ymm0); - //For ymm8 + //For ymm7 /* Multiply vec1 and vec2 */ - ymm13 = _mm256_mul_pd(ymm8, ymm3); /*vec3*/ + ymm13 = _mm256_mul_pd(ymm7, ymm2); /*vec3*/ /* Multiply vec1 and the modified vec2 */ - ymm14 = _mm256_mul_pd(ymm8, ymm14); /*vec4*/ + ymm14 = _mm256_mul_pd(ymm7, ymm14); /*vec4*/ /* Horizontally subtract the elements in vec3 and vec4 */ ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //For ymm8 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + ymm13 = _mm256_mul_pd(ymm8, ymm2); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm17 = _mm256_hsub_pd(ymm13, ymm14); //Step 2 ymm16 = _mm256_mul_pd(ymm16, ymm15); + ymm17 = _mm256_mul_pd(ymm17, ymm15); //Step 3 - ymm10 = _mm256_add_pd(ymm16, ymm10); + ymm3 = _mm256_add_pd(ymm16, ymm3); + ymm4 = _mm256_add_pd(ymm17, ymm4); + + +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm5 and ymm6 with ymm1*/ + BLIS_ZTRSM_TWO_DIV(ymm5,ymm6) +#else + /*performs dcomplex multiplication of ymm5 and ymm6 with ymm1*/ + BLIS_ZTRSM_MUL(ymm5) + BLIS_ZTRSM_MUL(ymm6) +#endif + //extract a22 + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); + //(ROW2): FMA operations + ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } //Step 1 - ymm14 = _mm256_permute_pd(ymm4, 0x5); + ymm14 = _mm256_permute_pd(ymm2, 0x5); /* Negate the imaginary elements of vec2 */ ymm14 = _mm256_mul_pd(ymm14, ymm0); - //For ymm8 + + //For ymm5 /* Multiply vec1 and vec2 */ - ymm13 = _mm256_mul_pd(ymm8, ymm4); /*vec3*/ + ymm13 = _mm256_mul_pd(ymm5, ymm2); /*vec3*/ /* Multiply vec1 and the modified vec2 */ - ymm14 = _mm256_mul_pd(ymm8, ymm14); /*vec4*/ + ymm14 = _mm256_mul_pd(ymm5, ymm14); /*vec4*/ /* Horizontally subtract the elements in vec3 and vec4 */ ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //For ymm6 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + ymm13 = _mm256_mul_pd(ymm6, ymm2); + ymm14 = _mm256_mul_pd(ymm6, ymm14); + ymm17 = _mm256_hsub_pd(ymm13, ymm14); //Step 2 ymm16 = _mm256_mul_pd(ymm16, ymm15); + ymm17 = _mm256_mul_pd(ymm17, ymm15); //Step 3 - ymm11 = _mm256_add_pd(ymm16, ymm11); + ymm3 = _mm256_add_pd(ymm16, ymm3); + ymm4 = _mm256_add_pd(ymm17, ymm4); + + #ifndef BLIS_ENABLE_TRSM_PREINVERSION - BLIS_ZTRSM_DIV(ymm9) + /*performs dcomplex divison of ymm3 and ymm4 with ymm1*/ + BLIS_ZTRSM_TWO_DIV(ymm3,ymm4) #else - BLIS_ZTRSM_MUL(ymm9) + /*performs dcomplex multiplication of ymm3 and ymm4 with ymm1*/ + BLIS_ZTRSM_MUL(ymm3) + BLIS_ZTRSM_MUL(ymm4) #endif - ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 2)); - ymm3 = _mm256_broadcast_pd((__m128d const *)(a11 + 
cs_a*2)); - ymm4 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a*3)); + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + 2), ymm4); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm256_storeu_pd((double *)(b11 + cs_b + 2), ymm6); + _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); + _mm256_storeu_pd((double *)(b11 + cs_b*2 + 2), ymm8); - if(conjtransa){ - ymm3 = _mm256_mul_pd(ymm3, ymm0); - ymm4 = _mm256_mul_pd(ymm4, ymm0); - } + } + dim_t m_remainder = i + d_mr; + if(m_remainder) + { + if(3 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (j*cs_a) + (j*rs_a); + b10 = B + (j+d_nr)*cs_b + (m_remainder - 3); + b11 = B + (m_remainder - 3) + (j*cs_b); + k_iter = (n-j-d_nr); + /*Fill zeros into ymm registers used in gemm + * accumulations */ + BLIS_SET_YMM_REG_ZEROS + /* + Peform GEMM between a01 and b10 blocks + For first itteration there will be no GEMM operation + where k_iter are zero + */ - a11 += rs_a; - //Step 1 - ymm14 = _mm256_permute_pd(ymm3, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); - //For ymm9 - /* Multiply vec1 and vec2 */ - ymm13 = _mm256_mul_pd(ymm9, ymm3); /*vec3*/ - /* Multiply vec1 and the modified vec2 */ - ymm14 = _mm256_mul_pd(ymm9, ymm14); /*vec4*/ - /* Horizontally subtract the elements in vec3 and vec4 */ - ymm16 = _mm256_hsub_pd(ymm13, ymm14); - //Step 2 - ymm16 = _mm256_mul_pd(ymm16, ymm15); + BLIS_ZTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) - //Step 3 - ymm10 = _mm256_add_pd(ymm16, ymm10); + /* + Load b11 multiply with alpha + Add the GEMM output to b11 + and peform TRSM operation. + */ - //Step 1 - ymm14 = _mm256_permute_pd(ymm4, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); - //For ymm8 - /* Multiply vec1 and vec2 */ - ymm13 = _mm256_mul_pd(ymm9, ymm4); /*vec3*/ - /* Multiply vec1 and the modified vec2 */ - ymm14 = _mm256_mul_pd(ymm9, ymm14); /*vec4*/ - /* Horizontally subtract the elements in vec3 and vec4 */ - ymm16 = _mm256_hsub_pd(ymm13, ymm14); - //Step 2 - ymm16 = _mm256_mul_pd(ymm16, ymm15); - //Step 3 - ymm11 = _mm256_add_pd(ymm16, ymm11); + BLIS_PRE_ZTRSM_SMALL_3x4(AlphaVal,b11,cs_b) + ///implement TRSM/// + /* + Compute 3x3 TRSM block by using GEMM block output in + register + a. The 4x3 input (gemm outputs) are stored in + combinations of ymm registers + 1. ymm7, ymm8 2. ymm5, ymm6 3. ymm3, ymm4 + b. 
Towards the end do in regiser transpose of TRSM + output and store in b11 + */ + ////extract a00 + ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + ymm1 = _mm256_broadcast_pd((__m128d const *) + (d11_pack + 2)); #ifndef BLIS_ENABLE_TRSM_PREINVERSION - BLIS_ZTRSM_DIV(ymm10) + /*performs dcomplex divison of ymm7 and ymm8 with ymm1*/ + BLIS_ZTRSM_TWO_DIV(ymm7,ymm8) #else - BLIS_ZTRSM_MUL(ymm10) + /*performs dcomplex multiplication of ymm7 and + * ymm8 with ymm1*/ + BLIS_ZTRSM_MUL(ymm7) + BLIS_ZTRSM_MUL(ymm8) #endif + //extract a11 + ymm1 = _mm256_broadcast_pd((__m128d const *) + (d11_pack + 1)); + //(ROW1): FMA operations + ymm2 = _mm256_broadcast_pd((__m128d const *) + (a11 + cs_a*2 + rs_a*1)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + /* Step1 dcomplex multiply ymm2, ymm7 + * Step2 negate the result + * Step3 add ymmx*/ + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm7 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm7, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm7, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + + //For ymm8 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + ymm13 = _mm256_mul_pd(ymm8, ymm2); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm17 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + ymm17 = _mm256_mul_pd(ymm17, ymm15); + + //Step 3 + ymm5 = _mm256_add_pd(ymm16, ymm5); + ymm6 = _mm256_add_pd(ymm17, ymm6); + + ymm2 = _mm256_broadcast_pd((__m128d const *) + (a11 + cs_a*2)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm7 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm7, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm7, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //For ymm8 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + ymm13 = _mm256_mul_pd(ymm8, ymm2); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm17 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + ymm17 = _mm256_mul_pd(ymm17, ymm15); + + //Step 3 + ymm3 = _mm256_add_pd(ymm16, ymm3); + ymm4 = _mm256_add_pd(ymm17, ymm4); - ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 3)); - ymm4 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a*3)); - if(conjtransa){ - ymm4 = _mm256_mul_pd(ymm4, ymm0); - } +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm5 and ymm6 with ymm1*/ + BLIS_ZTRSM_TWO_DIV(ymm5,ymm6) +#else + /*performs dcomplex multiplication of ymm5 and + * ymm6 with ymm1*/ + BLIS_ZTRSM_MUL(ymm5) + BLIS_ZTRSM_MUL(ymm6) +#endif + //extract a22 + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); + + //(ROW2): FMA operations + ymm2 = _mm256_broadcast_pd((__m128d const *) + (a11 + cs_a)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + //For ymm5 + /* 
Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm5, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm5, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //For ymm6 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + ymm13 = _mm256_mul_pd(ymm6, ymm2); + ymm14 = _mm256_mul_pd(ymm6, ymm14); + ymm17 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + ymm17 = _mm256_mul_pd(ymm17, ymm15); + //Step 3 + ymm3 = _mm256_add_pd(ymm16, ymm3); + ymm4 = _mm256_add_pd(ymm17, ymm4); + - //Step 1 - ymm14 = _mm256_permute_pd(ymm4, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); - //For ymm10 - /* Multiply vec1 and vec2 */ - ymm13 = _mm256_mul_pd(ymm10, ymm4); /*vec3*/ - /* Multiply vec1 and the modified vec2 */ - ymm14 = _mm256_mul_pd(ymm10, ymm14); /*vec4*/ - /* Horizontally subtract the elements in vec3 and vec4 */ - ymm16 = _mm256_hsub_pd(ymm13, ymm14); - //Step 2 - ymm16 = _mm256_mul_pd(ymm16, ymm15); - //Step 3 - ymm11 = _mm256_add_pd(ymm16, ymm11); #ifndef BLIS_ENABLE_TRSM_PREINVERSION - BLIS_ZTRSM_DIV(ymm11) + /*performs dcomplex divison of ymm3 and ymm4 with ymm1*/ + BLIS_ZTRSM_TWO_DIV(ymm3,ymm4) #else - BLIS_ZTRSM_MUL(ymm11) + /*performs dcomplex multiplication of ymm3 and + * ymm4 with ymm1*/ + BLIS_ZTRSM_MUL(ymm3) + BLIS_ZTRSM_MUL(ymm4) #endif - if(n_rem == 1) - { - ymm0 = _mm256_permute2f128_pd(ymm8,ymm9,0x20); - ymm4 = _mm256_permute2f128_pd(ymm10,ymm11,0x20); - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); - _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 2), ymm4); - } - else if(n_rem == 2) - { - ymm0 = _mm256_permute2f128_pd(ymm8,ymm9,0x20); - ymm4 = _mm256_permute2f128_pd(ymm10,ymm11,0x20); - ymm1 = _mm256_permute2f128_pd(ymm8,ymm9,0x31); - ymm3 = _mm256_permute2f128_pd(ymm10,ymm11,0x31); - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); - _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 2), ymm4); - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); - _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 2), ymm3); - } - } - } - dim_t m_rem = m-i; - if(m_rem) - { - a10 = L + (i*cs_a); - dcomplex *ptr_a10_dup = D_A_pack; - if(m_rem == 3) - { - dim_t p_lda = 4; - if(transa) - { - for(dim_t x = 0; x < i; x += p_lda) - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm10 = _mm256_loadu_pd((double const *) - (a10 + 2)); - ymm1 = _mm256_loadu_pd((double const *) - (a10 + cs_a)); - ymm11 = _mm256_loadu_pd((double const *) - (a10 + 2 + cs_a)); - - ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - ymm8 = _mm256_permute2f128_pd(ymm10,ymm11,0x20); - ymm9 = _mm256_permute2f128_pd(ymm10,ymm11,0x31); - - _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); - _mm256_storeu_pd((double *)(ptr_a10_dup + - p_lda), ymm7); - _mm256_storeu_pd((double *)(ptr_a10_dup + - p_lda*2), ymm8); - _mm256_storeu_pd((double *)(ptr_a10_dup + - p_lda*3), ymm9); - - ymm0 = _mm256_loadu_pd((double const *)(a10 - + 2 * cs_a)); - ymm10 = _mm256_loadu_pd((double const *)(a10 - + 2 * cs_a + 2)); - - ymm1 = _mm256_loadu_pd((double const *)(a10 - + 3 * cs_a)); - ymm11 = _mm256_loadu_pd((double const *)(a10 - + 3 * cs_a + 2)); - - ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - ymm8 = _mm256_permute2f128_pd(ymm10,ymm11,0x20); - ymm9 = _mm256_permute2f128_pd(ymm10,ymm11,0x31); - - 
_mm256_storeu_pd((double *)(ptr_a10_dup + 2), - ymm6); - _mm256_storeu_pd((double *)(ptr_a10_dup + - p_lda + 2), ymm7); - _mm256_storeu_pd((double *)(ptr_a10_dup + - p_lda*2 + 2), ymm8); - _mm256_storeu_pd((double *)(ptr_a10_dup + - p_lda*3 + 2), ymm9); - - a10 += p_lda; - ptr_a10_dup += p_lda * p_lda; - } - - } - else - { - for(dim_t x=0;xbuffer; //value of alpha - dcomplex *L = a->buffer; //pointer to matrix A - dcomplex *B = b->buffer; //pointer to matrix B +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm5 with ymm1*/ + BLIS_ZTRSM_DIV(ymm5) +#else + /*performs dcomplex multiplication of ymm5 with ymm1*/ + BLIS_ZTRSM_MUL(ymm5) +#endif + //extract a22 + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); - //pointers that point to blocks for GEMM and TRSM - dcomplex *a10, *a11, *b01, *b11; + //(ROW2): FMA operations + ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + //For ymm5 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm5, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm5, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + //Step 3 + ymm3 = _mm256_add_pd(ymm16, ymm3); - dcomplex ones = {1.0, 1.0}; - bool is_unitdiag = bli_obj_has_unit_diag(a); - //scratch registers - __m256d ymm0, ymm1, ymm2, ymm3; - __m256d ymm4, ymm5, ymm6, ymm7; - __m256d ymm8, ymm9, ymm10, ymm11; - __m256d ymm12, ymm13, ymm14, ymm15; - __m256d ymm16, ymm17, ymm18, ymm19; +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm3 with ymm1*/ + BLIS_ZTRSM_DIV(ymm3) +#else + /*performs dcomplex multiplication of ymm3 with ymm1*/ + BLIS_ZTRSM_MUL(ymm3) +#endif + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); + m_remainder -=2; + } + else if(1 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (j*cs_a) + (j*rs_a); + b10 = B + (j+d_nr)*cs_b + (m_remainder - 1); + b11 = B + (m_remainder - 1) + (j*cs_b); + k_iter = (n-j-d_nr); + /*Fill zeros into ymm registers used in gemm + * accumulations */ + BLIS_SET_YMM_REG_ZEROS + /* + Peform GEMM between a01 and b10 blocks + For first itteration there will be no GEMM operation + where k_iter are zero + */ + + BLIS_ZTRSM_SMALL_GEMM_3nx2m(a01,b10,cs_b,p_lda,k_iter) + + /* + Load b11 and multiply with alpha + Add the GEMM output to b11 + and peform TRSM operation. + */ + + BLIS_PRE_ZTRSM_SMALL_3x2(AlphaVal,b11,cs_b) + ///implement TRSM/// + /* + Compute 3x3 TRSM block by using GEMM block output + in register + a. The 4x3 input (gemm outputs) are stored in + combinations of ymm registers + 1. ymm7, ymm8 2. ymm5, ymm6 3. ymm3, ymm4 + b. 
Towards the end do in regiser transpose of TRSM + output and store in + b11 + */ + ////extract a00 + ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + ymm1 = _mm256_broadcast_pd((__m128d const *) + (d11_pack + 2)); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm7 with ymm1*/ + BLIS_ZTRSM_DIV(ymm7) +#else + /*performs dcomplex multiplication of ymm7 with ymm1*/ + BLIS_ZTRSM_MUL(ymm7) +#endif + //extract a11 + ymm1 = _mm256_broadcast_pd((__m128d const *) + (d11_pack + 1)); + //(ROW1): FMA operations + ymm2 = _mm256_broadcast_pd((__m128d const *) + (a11 + cs_a*2 + rs_a*1)); + + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + /* Step1 dcomplex multiply ymm2, ymm7 + * Step2 negate the result + * Step3 add ymmx*/ + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm7 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm7, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm7, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + + //Step 3 + ymm5 = _mm256_add_pd(ymm16, ymm5); + + ymm2 = _mm256_broadcast_pd((__m128d const *) + (a11 + cs_a*2)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm7 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm7, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm7, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + + //Step 3 + ymm3 = _mm256_add_pd(ymm16, ymm3); - __m128d xmm5, xmm4, xmm3; - gint_t required_packing_A = 1; - mem_t local_mem_buf_A_s = {0}; - dcomplex *D_A_pack = NULL; - dcomplex d11_pack[d_mr] __attribute__((aligned(64))); - rntm_t rntm; +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm5 with ymm1*/ + BLIS_ZTRSM_DIV(ymm5) +#else + /*performs dcomplex multiplication of ymm5 with ymm1*/ + BLIS_ZTRSM_MUL(ymm5) +#endif + //extract a22 + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); - bli_rntm_init_from_global( &rntm ); - bli_rntm_set_num_threads_only( 1, &rntm ); - bli_membrk_rntm_set_membrk( &rntm ); + //(ROW2): FMA operations + ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + //For ymm5 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm5, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm5, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + //Step 3 + ymm3 = _mm256_add_pd(ymm16, ymm3); - siz_t buffer_size = bli_pool_block_size( - bli_membrk_pool( - bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), - bli_rntm_membrk(&rntm))); - if((d_mr * m * sizeof(dcomplex)) > buffer_size) - return BLIS_NOT_YET_IMPLEMENTED; +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm3 with ymm1*/ + 
BLIS_ZTRSM_DIV(ymm3) +#else + /*performs dcomplex multiplication of ymm3 and with ymm1*/ + BLIS_ZTRSM_MUL(ymm3) +#endif + _mm_storeu_pd((double *)b11, + _mm256_extractf128_pd(ymm3,0)); + _mm_storeu_pd((double *)(b11 + cs_b), + _mm256_extractf128_pd(ymm5,0)); + _mm_storeu_pd((double *)(b11 + cs_b*2), + _mm256_extractf128_pd(ymm7,0)); + m_remainder -=1; + } + } - if(required_packing_A == 1) - { - // Get the buffer from the pool. - bli_membrk_acquire_m(&rntm, - buffer_size, - BLIS_BITVAL_BUFFER_FOR_A_BLOCK, - &local_mem_buf_A_s); - if(FALSE==bli_mem_is_alloc(&local_mem_buf_A_s)) return BLIS_NULL_POINTER; - D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); - if(NULL==D_A_pack) return BLIS_NULL_POINTER; } - - /* - Performs solving TRSM for 4 colmns at a time from 0 to m/d_mr in steps of d_mr - a. Load, transpose, Pack A (a10 block), the size of packing 8x6 to 8x (m-d_mr) - First there will be no GEMM and no packing of a10 because it is only TRSM - b. Using packed a10 block and b01 block perform GEMM operation - c. Use GEMM outputs, perform TRSM operaton using a11, b11 and update B - d. Repeat b,c for n rows of B in steps of d_nr - */ - for(i = (m - d_mr); (i + 1) > 0; i -= d_mr) + dim_t n_remainder = j + d_nr; + if(n_remainder == 2) { - a10 = L + (i*cs_a) + (i + d_mr)*rs_a;//pointer to block of A to be used for GEMM - a11 = L + (i*cs_a) + (i*rs_a);//pointer to block of A to be used for TRSM + a01 = L + (n_remainder - 2)*rs_a + n_remainder*cs_a; + a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2)*rs_a; - // Do transpose for a10 & store in D_A_pack - //ptr_a10_dup = D_A_pack; + dcomplex *ptr_a10_dup = D_A_pack; - dim_t p_lda = d_mr; // packed leading dimension + dim_t p_lda = (n-n_remainder); if(transa) { - /* - Load, transpose and pack current A block (a10) into packed buffer memory - D_A_pack - a. This a10 block is used in GEMM portion only and this - a10 block size will be increasing by d_mr for every next itteration - untill it reaches 4x(m-4) which is the maximum GEMM alone block size - in A - b. This packed buffer is reused to calculate all n rows of B matrix - */ - bli_ztrsm_small_pack('L', (m-i-d_mr), 1, a10, cs_a, D_A_pack,p_lda,d_mr); + for(dim_t x =0;x < p_lda;x+=d_nr) + { + ymm0 = _mm256_loadu_pd((double const *)(a01)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); + ymm3 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm4 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm3); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm4); + ymm0 = _mm256_loadu_pd((double const *)(a01 + 2)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a + 2)); + ymm3 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 2), + ymm3); + + ymm0 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2)); + ymm1 = _mm256_loadu_pd((double const *) + (a01 + cs_a * 2 + 2)); + ymm5 = _mm256_broadcast_pd((__m128d const *)&zero); + + ymm3 = _mm256_permute2f128_pd(ymm0,ymm5,0x20); + ymm4 = _mm256_permute2f128_pd(ymm0,ymm5,0x31); + ymm5 = _mm256_permute2f128_pd(ymm1,ymm5,0x20); + + _mm_storeu_pd((double *)(ptr_a10_dup + 2), + _mm256_extractf128_pd(ymm3,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda + 2), + _mm256_extractf128_pd(ymm4,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 2 + 2), + _mm256_extractf128_pd(ymm5, 0)); + a01 += d_nr*cs_a; + ptr_a10_dup += d_nr; + } + } + else + { + dim_t loop_count = (n-n_remainder)/2; - /* - Pack 8 diagonal elements of A block into an array - a. 
This helps in utilze cache line efficiently in TRSM operation - b. store ones when input is unit diagonal - */ - ztrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,d_mr); + for(dim_t x =0;x < loop_count;x++) + { + ymm15 = _mm256_loadu_pd((double const *) + (a01 + rs_a * 0 + x*2)); + _mm256_storeu_pd((double *) + (ptr_a10_dup + p_lda * 0 + x*2), ymm15); + ymm15 = _mm256_loadu_pd((double const *) + (a01 + rs_a * 1 + x*2)); + _mm256_storeu_pd((double *) + (ptr_a10_dup + p_lda * 1 + x*2), ymm15); + } + + dim_t remainder_loop_count = p_lda - loop_count*2; + + __m128d xmm0; + if(remainder_loop_count != 0) + { + xmm0 = _mm_loadu_pd((double const *) + (a01 + rs_a * 0 + loop_count*2)); + _mm_storeu_pd((double *) + (ptr_a10_dup + p_lda * 0 + loop_count*2), + xmm0); + xmm0 = _mm_loadu_pd((double const *) + (a01 + rs_a * 1 + loop_count*2)); + _mm_storeu_pd((double *) + (ptr_a10_dup + p_lda * 1 + loop_count*2), + xmm0); + } + } + if(!is_unitdiag) + { + if(transa) + { + ymm0 = _mm256_broadcast_pd((__m128d const *)(a11)); + ymm1 = _mm256_broadcast_pd((__m128d const *) + (a11+cs_a*1 + 1)); + } + else + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_pd((__m128d const *)(a11)); + ymm1 = _mm256_broadcast_pd((__m128d const *) + (a11+rs_a*1 + 1)); + } + ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + ymm7 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + /*Taking denomerator multiplication of real & + * imaginary components*/ + ymm4 = _mm256_mul_pd(ymm1, ymm1); + /*Swapping real & imaginary component position for addition with + * respective components*/ + ymm6 = _mm256_permute4x64_pd(ymm4, 0xb1); + ymm4 = _mm256_add_pd(ymm4, ymm6); + /*Negating imaginary component of numerator*/ + ymm1 = _mm256_mul_pd(ymm1, ymm7); + /*Dividing numerator by denominator*/ + ymm1 = _mm256_div_pd(ymm1, ymm4); +#endif } else { - bli_ztrsm_small_pack('L', (m-i-d_mr), 0, a10, rs_a, D_A_pack,p_lda,d_mr); - ztrsm_small_pack_diag_element(is_unitdiag,a11,rs_a,d11_pack,d_mr); + ymm1 = _mm256_broadcast_pd((__m128d const*)&ones); } - - /* - a. Perform GEMM using a10, b01. - b. Perform TRSM on a11, b11 - c. This loop GEMM+TRSM loops operates with 8x6 block size - along n dimension for every d_nr rows of b01 where - packed A buffer is reused in computing all n rows of B. - d. Same approch is used in remaining fringe cases. - */ - for(j = (n - d_nr); (j + 1) > 0; j -= d_nr) + _mm256_storeu_pd((double *)(d11_pack), ymm1); + for(i = (m-d_mr); (i+1) > 0; i -= d_mr) //loop along 'M' direction { - a10 = D_A_pack; - b01 = B + (j * cs_b) + i + d_mr;//pointer to block of B to be used for GEMM - b11 = B + (j * cs_b) + i;//pointer to block of B to be used for TRSM + a01 = D_A_pack; + a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2)*rs_a; + b10 = B + i + (n_remainder)*cs_b; + b11 = B + (i) + (n_remainder - 2)*cs_b; - k_iter = (m - i - d_mr); + k_iter = (n-n_remainder); /*Fill zeros into ymm registers used in gemm accumulations */ BLIS_SET_YMM_REG_ZEROS - - /* - Peform GEMM between a10 and b01 blocks - For first itteration there will be no GEMM operation - where k_iter are zero - */ - BLIS_ZTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) - - /* - Load b11 of size 6x8 and multiply with alpha - Add the GEMM output and perform inregister transose of b11 - to peform TRSM operation. - */ - BLIS_ZTRSM_SMALL_NREG_TRANSPOSE_3x4(b11,cs_b,AlphaVal) - - /* - Compute 4x3 TRSM block by using GEMM block output in register - a. The 4x3 input (gemm outputs) are stored in combinations of ymm - registers - 1. 
ymm8, ymm4 2. ymm9, ymm5 3. ymm10, ymm6, 4. ymm11, ymm7 - where ymm8-ymm11 holds 4x2 data and reaming 4x1 will be hold by - other registers - b. Towards the end do in regiser transpose of TRSM output and store in - b11 - */ + BLIS_ZTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_PRE_ZTRSM_SMALL_2x4(AlphaVal,b11,cs_b) + ///implement TRSM/// ////extract a00 - ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 3)); + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 1)); #ifndef BLIS_ENABLE_TRSM_PREINVERSION - /*performs dcomplex divison of ymm11 and ymm7 with ymm1*/ - BLIS_ZTRSM_TWO_DIV(ymm11,ymm7) + /*performs dcomplex divison of ymm5 and ymm6 with ymm1*/ + BLIS_ZTRSM_TWO_DIV(ymm5,ymm6) #else - /*performs dcomplex multiplication of ymm11 and ymm7 with ymm1*/ - BLIS_ZTRSM_MUL(ymm11) - BLIS_ZTRSM_MUL(ymm7) + /*performs dcomplex multiplication of ymm5 and ymm6 with ymm1*/ + BLIS_ZTRSM_MUL(ymm5) + BLIS_ZTRSM_MUL(ymm6) #endif - //extract a11 - ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 2)); - //(ROW1): FMA operations - ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a*2 + rs_a*3)); - if(conjtransa) - { - ymm2 = _mm256_mul_pd(ymm2, ymm0); - } - /* Step1 dcomplex multiply ymm2, ymm8 - * Step2 negate the result - * Step3 add ymm9*/ + //extract a22 + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); + + //(ROW2): FMA operations + ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } //Step 1 ymm14 = _mm256_permute_pd(ymm2, 0x5); /* Negate the imaginary elements of vec2 */ ymm14 = _mm256_mul_pd(ymm14, ymm0); - //For ymm11 + + //For ymm5 /* Multiply vec1 and vec2 */ - ymm13 = _mm256_mul_pd(ymm11, ymm2); /*vec3*/ + ymm13 = _mm256_mul_pd(ymm5, ymm2); /*vec3*/ /* Multiply vec1 and the modified vec2 */ - ymm14 = _mm256_mul_pd(ymm11, ymm14); /*vec4*/ + ymm14 = _mm256_mul_pd(ymm5, ymm14); /*vec4*/ /* Horizontally subtract the elements in vec3 and vec4 */ ymm16 = _mm256_hsub_pd(ymm13, ymm14); - - //For ymm7 + //For ymm6 ymm14 = _mm256_permute_pd(ymm2, 0x5); /* Negate the imaginary elements of vec2 */ ymm14 = _mm256_mul_pd(ymm14, ymm0); - ymm13 = _mm256_mul_pd(ymm7, ymm2); - ymm14 = _mm256_mul_pd(ymm7, ymm14); + ymm13 = _mm256_mul_pd(ymm6, ymm2); + ymm14 = _mm256_mul_pd(ymm6, ymm14); ymm17 = _mm256_hsub_pd(ymm13, ymm14); //Step 2 ymm16 = _mm256_mul_pd(ymm16, ymm15); ymm17 = _mm256_mul_pd(ymm17, ymm15); - //Step 3 - ymm10 = _mm256_add_pd(ymm16, ymm10); - ymm6 = _mm256_add_pd(ymm17, ymm6); + ymm3 = _mm256_add_pd(ymm16, ymm3); + ymm4 = _mm256_add_pd(ymm17, ymm4); - ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a*1 + rs_a*3)); - if(conjtransa) - { - ymm2 = _mm256_mul_pd(ymm2, ymm0); - } - //Step 1 - ymm14 = _mm256_permute_pd(ymm2, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); - //For ymm11 - /* Multiply vec1 and vec2 */ - ymm13 = _mm256_mul_pd(ymm11, ymm2); /*vec3*/ - /* Multiply vec1 and the modified vec2 */ - ymm14 = _mm256_mul_pd(ymm11, ymm14); /*vec4*/ - /* Horizontally subtract the elements in vec3 and vec4 */ - ymm16 = _mm256_hsub_pd(ymm13, ymm14); - //For ymm7 - ymm14 = _mm256_permute_pd(ymm2, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); - ymm13 = _mm256_mul_pd(ymm7, ymm2); - ymm14 = _mm256_mul_pd(ymm7, ymm14); - ymm17 = _mm256_hsub_pd(ymm13, ymm14); - //Step 2 - 
ymm16 = _mm256_mul_pd(ymm16, ymm15); - ymm17 = _mm256_mul_pd(ymm17, ymm15); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm3 and ymm4 with ymm1*/ + BLIS_ZTRSM_TWO_DIV(ymm3,ymm4) +#else + /*performs dcomplex multiplication of ymm3 and ymm4 with ymm1*/ + BLIS_ZTRSM_MUL(ymm3) + BLIS_ZTRSM_MUL(ymm4) +#endif + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + 2), ymm4); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm256_storeu_pd((double *)(b11 + cs_b + 2), ymm6); - //Step 3 - ymm9 = _mm256_add_pd(ymm16, ymm9); - ymm5 = _mm256_add_pd(ymm17, ymm5); + } + dim_t m_remainder = i + d_mr; + if(3 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2)*rs_a; + b10 = B + (m_remainder - 3) + (n_remainder)*cs_b; + b11 = B + (m_remainder - 3) + (n_remainder - 2)*cs_b; - ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + rs_a*3)); - if(conjtransa) - { - ymm2 = _mm256_mul_pd(ymm2, ymm0); - } + k_iter = (n-n_remainder); + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + /* + Peform GEMM between a01 and b10 blocks + For first itteration there will be no GEMM operation + where k_iter are zero + */ + BLIS_ZTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) + + // Load b11 and multiply with alpha + BLIS_PRE_ZTRSM_SMALL_3x4(AlphaVal,b11,cs_b) + ////extract a00 + ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 1)); + +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm5 and ymm6 with ymm1*/ + BLIS_ZTRSM_TWO_DIV(ymm5,ymm6) +#else + /*performs dcomplex multiplication of ymm5 and ymm6 with ymm1*/ + BLIS_ZTRSM_MUL(ymm5) + BLIS_ZTRSM_MUL(ymm6) +#endif + //extract a22 + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); + + //(ROW2): FMA operations + ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } //Step 1 ymm14 = _mm256_permute_pd(ymm2, 0x5); /* Negate the imaginary elements of vec2 */ ymm14 = _mm256_mul_pd(ymm14, ymm0); - //For ymm11 + + //For ymm5 /* Multiply vec1 and vec2 */ - ymm13 = _mm256_mul_pd(ymm11, ymm2); /*vec3*/ + ymm13 = _mm256_mul_pd(ymm5, ymm2); /*vec3*/ /* Multiply vec1 and the modified vec2 */ - ymm14 = _mm256_mul_pd(ymm11, ymm14); /*vec4*/ + ymm14 = _mm256_mul_pd(ymm5, ymm14); /*vec4*/ /* Horizontally subtract the elements in vec3 and vec4 */ ymm16 = _mm256_hsub_pd(ymm13, ymm14); - //For ymm7 + //For ymm6 ymm14 = _mm256_permute_pd(ymm2, 0x5); /* Negate the imaginary elements of vec2 */ ymm14 = _mm256_mul_pd(ymm14, ymm0); - ymm13 = _mm256_mul_pd(ymm7, ymm2); - ymm14 = _mm256_mul_pd(ymm7, ymm14); + ymm13 = _mm256_mul_pd(ymm6, ymm2); + ymm14 = _mm256_mul_pd(ymm6, ymm14); ymm17 = _mm256_hsub_pd(ymm13, ymm14); //Step 2 ymm16 = _mm256_mul_pd(ymm16, ymm15); ymm17 = _mm256_mul_pd(ymm17, ymm15); //Step 3 - ymm8 = _mm256_add_pd(ymm16, ymm8); + ymm3 = _mm256_add_pd(ymm16, ymm3); ymm4 = _mm256_add_pd(ymm17, ymm4); + +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm3 and ymm4 with ymm1*/ + BLIS_ZTRSM_TWO_DIV(ymm3,ymm4) +#else + /*performs dcomplex multiplication of ymm3 and ymm4 with ymm1*/ + BLIS_ZTRSM_MUL(ymm3) + BLIS_ZTRSM_MUL(ymm4) +#endif + _mm256_storeu_pd((double *)b11, ymm3); + _mm_storeu_pd((double *)(b11 + 2), + _mm256_extractf128_pd(ymm4,0)); + + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm_storeu_pd((double *)(b11 + cs_b + 2), + 
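/*
 * Plain-C view of what a BLIS_ZTRSM_SMALL_GEMM_3nx4m accumulation such as the
 * one above produces before the TRSM step (a sketch; the fixed 4x3 shape and
 * the packed-a01 indexing are illustrative assumptions, not lifted from the
 * macro).  The accumulator is later subtracted from alpha*B11 when b11 is
 * loaded by the PRE_ZTRSM macro.
 */
#include <complex.h>

static void ztrsm_gemm_accum_sketch(double complex acc[4][3],
                                    const double complex *b10, dim_t cs_b,
                                    const double complex *a01, dim_t p_lda,
                                    dim_t k_iter)
{
    for (dim_t k = 0; k < k_iter; ++k)       /* one rank-1 update per k    */
        for (dim_t c = 0; c < 3; ++c)        /* d_nr columns of the panel  */
            for (dim_t r = 0; r < 4; ++r)    /* d_mr rows of the panel     */
                acc[r][c] += b10[r + k * cs_b] * a01[k + c * p_lda];
}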
_mm256_extractf128_pd(ymm6,0)); + m_remainder -=3; + } + if(2 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2)*rs_a; + b10 = B + (m_remainder - 2) + (n_remainder)*cs_b; + b11 = B + (m_remainder - 2) + (n_remainder - 2)*cs_b; + + k_iter = (n-n_remainder); + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + /* + Peform GEMM between a01 and b10 blocks + For first itteration there will be no GEMM operation + where k_iter are zero + */ + BLIS_ZTRSM_SMALL_GEMM_3nx2m(a01,b10,cs_b,p_lda,k_iter) + + // Load b11 and multiply with alpha + BLIS_PRE_ZTRSM_SMALL_3x2(AlphaVal,b11,cs_b) + ////extract a00 + ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 1)); + #ifndef BLIS_ENABLE_TRSM_PREINVERSION - /*performs dcomplex divison of ymm10 and ymm6 with ymm1*/ - BLIS_ZTRSM_TWO_DIV(ymm10,ymm6) + /*performs dcomplex divison of ymm5 with ymm1*/ + BLIS_ZTRSM_DIV(ymm5) #else - /*performs dcomplex multiplication of ymm10 and ymm6 with ymm1*/ - BLIS_ZTRSM_MUL(ymm10) - BLIS_ZTRSM_MUL(ymm6) + /*performs dcomplex multiplication of ymm5 with ymm1*/ + BLIS_ZTRSM_MUL(ymm5) #endif //extract a22 - ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 1)); + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); //(ROW2): FMA operations - ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a*1 + rs_a*2)); - if(conjtransa) - { - ymm2 = _mm256_mul_pd(ymm2, ymm0); - } - //Step 1 + ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + //Step 1 ymm14 = _mm256_permute_pd(ymm2, 0x5); /* Negate the imaginary elements of vec2 */ ymm14 = _mm256_mul_pd(ymm14, ymm0); - //For ymm10 + //For ymm5 /* Multiply vec1 and vec2 */ - ymm13 = _mm256_mul_pd(ymm10, ymm2); /*vec3*/ + ymm13 = _mm256_mul_pd(ymm5, ymm2); /*vec3*/ /* Multiply vec1 and the modified vec2 */ - ymm14 = _mm256_mul_pd(ymm10, ymm14); /*vec4*/ + ymm14 = _mm256_mul_pd(ymm5, ymm14); /*vec4*/ /* Horizontally subtract the elements in vec3 and vec4 */ ymm16 = _mm256_hsub_pd(ymm13, ymm14); - //For ymm6 - ymm14 = _mm256_permute_pd(ymm2, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); - ymm13 = _mm256_mul_pd(ymm6, ymm2); - ymm14 = _mm256_mul_pd(ymm6, ymm14); - ymm17 = _mm256_hsub_pd(ymm13, ymm14); //Step 2 ymm16 = _mm256_mul_pd(ymm16, ymm15); - ymm17 = _mm256_mul_pd(ymm17, ymm15); //Step 3 - ymm9 = _mm256_add_pd(ymm16, ymm9); - ymm5 = _mm256_add_pd(ymm17, ymm5); + ymm3 = _mm256_add_pd(ymm16, ymm3); - ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + rs_a*2)); - if(conjtransa) - { - ymm2 = _mm256_mul_pd(ymm2, ymm0); - } - //Step 1 + +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm3 with ymm1*/ + BLIS_ZTRSM_DIV(ymm3) +#else + /*performs dcomplex multiplication of ymm3 with ymm1*/ + BLIS_ZTRSM_MUL(ymm3) +#endif + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + m_remainder -=2; + } + if(1 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2)*rs_a; + b10 = B + (m_remainder - 1) + (n_remainder)*cs_b; + b11 = B + (m_remainder - 1) + (n_remainder - 2)*cs_b; + + k_iter = (n-n_remainder); + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + /* + Peform GEMM between a01 and b10 blocks + For first itteration there will be no GEMM operation + where 
k_iter are zero + */ + BLIS_ZTRSM_SMALL_GEMM_3nx2m(a01,b10,cs_b,p_lda,k_iter) + + // Load b11 and multiply with alpha + BLIS_PRE_ZTRSM_SMALL_3x2(AlphaVal,b11,cs_b) + ////extract a00 + ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 1)); + +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm5 with ymm1*/ + BLIS_ZTRSM_DIV(ymm5) +#else + /*performs dcomplex multiplication of ymm5 with ymm1*/ + BLIS_ZTRSM_MUL(ymm5) +#endif + //extract a22 + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); + + //(ROW2): FMA operations + ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + //Step 1 ymm14 = _mm256_permute_pd(ymm2, 0x5); /* Negate the imaginary elements of vec2 */ ymm14 = _mm256_mul_pd(ymm14, ymm0); - //For ymm10 + //For ymm5 /* Multiply vec1 and vec2 */ - ymm13 = _mm256_mul_pd(ymm10, ymm2); /*vec3*/ + ymm13 = _mm256_mul_pd(ymm5, ymm2); /*vec3*/ /* Multiply vec1 and the modified vec2 */ - ymm14 = _mm256_mul_pd(ymm10, ymm14); /*vec4*/ + ymm14 = _mm256_mul_pd(ymm5, ymm14); /*vec4*/ /* Horizontally subtract the elements in vec3 and vec4 */ ymm16 = _mm256_hsub_pd(ymm13, ymm14); - //For ymm6 - ymm14 = _mm256_permute_pd(ymm2, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); - ymm13 = _mm256_mul_pd(ymm6, ymm2); - ymm14 = _mm256_mul_pd(ymm6, ymm14); - ymm17 = _mm256_hsub_pd(ymm13, ymm14); //Step 2 ymm16 = _mm256_mul_pd(ymm16, ymm15); - ymm17 = _mm256_mul_pd(ymm17, ymm15); //Step 3 - ymm8 = _mm256_add_pd(ymm16, ymm8); - ymm4 = _mm256_add_pd(ymm17, ymm4); + ymm3 = _mm256_add_pd(ymm16, ymm3); + + +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm3 with ymm1*/ + BLIS_ZTRSM_DIV(ymm3) +#else + /*performs dcomplex multiplication of ymm3 with ymm1*/ + BLIS_ZTRSM_MUL(ymm3) +#endif + _mm_storeu_pd((double *)b11, + _mm256_extractf128_pd(ymm3,0)); + _mm_storeu_pd((double *)(b11 + cs_b), + _mm256_extractf128_pd(ymm5,0)); + m_remainder -=1; + } + n_remainder -= 2; + } + else if(n_remainder == 1) + { + a01 = L + (n_remainder - 1)*rs_a + n_remainder*cs_a; + a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1)*rs_a; + + dcomplex *ptr_a10_dup = D_A_pack; + + dim_t p_lda = (n-n_remainder); // packed leading dimension + // perform copy of A to packed buffer D_A_pack + if(transa) + { + for(dim_t x =0;x < p_lda;x+=d_nr) + { + ymm0 = _mm256_loadu_pd((double const *)(a01)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); + ymm3 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm4 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm3); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm4); + + ymm0 = _mm256_loadu_pd((double const *)(a01 + 2)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a + 2)); + ymm3 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + _mm256_storeu_pd((double *) + (ptr_a10_dup + p_lda * 2), ymm3); + + ymm0 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2)); + ymm1 = _mm256_loadu_pd((double const *) + (a01 + cs_a * 2 + 2)); + ymm5 = _mm256_broadcast_pd((__m128d const *)&zero); + + ymm3 = _mm256_permute2f128_pd(ymm0,ymm5,0x20); + ymm4 = _mm256_permute2f128_pd(ymm0,ymm5,0x31); + ymm5 = _mm256_permute2f128_pd(ymm1,ymm5,0x20); + + _mm_storeu_pd((double *)(ptr_a10_dup + 2), + _mm256_extractf128_pd(ymm3,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda + 2), + _mm256_extractf128_pd(ymm4,0)); + 
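/*
 * What the packed copy around this point boils down to in scalar form
 * (sketch; function and argument names are illustrative).  For the transpose
 * case the panel is written transposed so the GEMM macros can always walk
 * D_A_pack with leading dimension p_lda; for the non-transpose case it is a
 * straight column copy.
 */
static void ztrsm_pack_a01_sketch(const dcomplex *a01, dim_t rs, dim_t cs,
                                  dcomplex *pack, dim_t p_lda,
                                  dim_t k, dim_t nr, bool do_trans)
{
    for (dim_t kk = 0; kk < k; ++kk)         /* GEMM depth                 */
        for (dim_t c = 0; c < nr; ++c)       /* d_nr columns of the panel  */
            pack[kk + c * p_lda] = do_trans ? a01[c * rs + kk * cs]
                                            : a01[kk * rs + c * cs];
}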
_mm_storeu_pd((double *) + (ptr_a10_dup + p_lda * 2 + 2), + _mm256_extractf128_pd(ymm5, 0)); + a01 += d_nr*cs_a; + ptr_a10_dup += d_nr; + } + + } + else + { + dim_t loop_count = (n-n_remainder)/2; + + for(dim_t x =0;x < loop_count;x++) + { + ymm15 = _mm256_loadu_pd((double const *) + (a01 + rs_a * 0 + x*2)); + _mm256_storeu_pd((double *) + (ptr_a10_dup + p_lda * 0 + x*2), ymm15); + } + + dim_t remainder_loop_count = p_lda - loop_count*2; + + __m128d xmm0; + if(remainder_loop_count != 0) + { + xmm0 = _mm_loadu_pd((double const *) + (a01 + rs_a * 0 + loop_count*2)); + _mm_storeu_pd((double *) + (ptr_a10_dup + p_lda * 0 + loop_count*2), + xmm0); + } + } + if(!is_unitdiag) + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_pd((__m128d const *)(a11)); + ymm1 = _mm256_broadcast_pd((__m128d const *)&ones); + ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + ymm7 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + /*Taking denomerator multiplication of real & + * imaginary components*/ + ymm4 = _mm256_mul_pd(ymm1, ymm1); + /*Swapping real & imaginary component position for addition with + * respective components*/ + ymm6 = _mm256_permute4x64_pd(ymm4, 0xb1); + ymm4 = _mm256_add_pd(ymm4, ymm6); + /*Negating imaginary component of numerator*/ + ymm1 = _mm256_mul_pd(ymm1, ymm7); + /*Dividing numerator by denominator*/ + ymm1 = _mm256_div_pd(ymm1, ymm4); +#endif + } + else + { + ymm1 = _mm256_broadcast_pd((__m128d const*)&ones); + } + _mm256_storeu_pd((double *)(d11_pack), ymm1); + for(i = (m-d_mr); (i+1) > 0; i -= d_mr) //loop along 'M' direction + { + a01 = D_A_pack; + a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1)*rs_a; + b10 = B + i + (n_remainder)*cs_b; + b11 = B + (i) + (n_remainder - 1)*cs_b; + + k_iter = (n-n_remainder); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + ///GEMM implementation starts/// + BLIS_ZTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_PRE_ZTRSM_SMALL_1x4(b11,cs_b,AlphaVal) + ///implement TRSM/// + ////extract a00 + ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm3 and ymm4 with ymm1*/ + BLIS_ZTRSM_TWO_DIV(ymm3,ymm4) +#else + /*performs dcomplex multiplication of ymm3 and ymm4 with ymm1*/ + BLIS_ZTRSM_MUL(ymm3) + BLIS_ZTRSM_MUL(ymm4) +#endif + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + 2),ymm4); + + } + dim_t m_remainder = i + d_mr; + if(3 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1)*rs_a; + b10 = B + (m_remainder - 3) + (n_remainder)*cs_b; + b11 = B + (m_remainder - 3) + (n_remainder - 1)*cs_b; + + k_iter = (n-n_remainder); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_ZTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_PRE_ZTRSM_SMALL_1x3(b11,cs_b,AlphaVal) + + ///implement TRSM/// + ////extract a00 + ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm3 and ymm4 with ymm1*/ + BLIS_ZTRSM_TWO_DIV(ymm3,ymm4) +#else + /*performs dcomplex multiplication of ymm3 and ymm4 with ymm1*/ + BLIS_ZTRSM_MUL(ymm3) + BLIS_ZTRSM_MUL(ymm4) +#endif + + 
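/*
 * Scalar sketch of how the packed diagonal entry is prepared and then used by
 * BLIS_ZTRSM_DIV / BLIS_ZTRSM_MUL (helpers below are illustrative, not part
 * of this patch).  With BLIS_ENABLE_TRSM_PREINVERSION the stored value is the
 * reciprocal conj(d)/|d|^2, so the solve step becomes a complex multiply;
 * without preinversion the original d is stored and the solve step divides.
 */
#include <complex.h>

static inline double complex zpreinvert_sketch(double complex d)
{
    /* mirrors the setr/mul/permute4x64/add/div sequence above */
    return conj(d) / (creal(d) * creal(d) + cimag(d) * cimag(d));
}

static inline double complex ztrsm_diag_solve_sketch(double complex b,
                                                     double complex d_packed,
                                                     int preinverted)
{
    return preinverted ? b * d_packed   /* d_packed already holds 1/d */
                       : b / d_packed;  /* d_packed holds d itself    */
}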
_mm256_storeu_pd((double *)b11, ymm3); + _mm_storeu_pd((double *)(b11 + 2), + _mm256_extractf128_pd(ymm4,0)); + m_remainder -=3; + + } + else if(2 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1)*rs_a; + b10 = B + (m_remainder - 2) + (n_remainder)*cs_b; + b11 = B + (m_remainder - 2) + (n_remainder - 1)*cs_b; + + k_iter = (n-n_remainder); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_ZTRSM_SMALL_GEMM_1nx2m(a01,b10,cs_b,p_lda,k_iter) + + // Load b11 of size 2x1 and multiply with alpha + BLIS_PRE_ZTRSM_SMALL_1x2(AlphaVal,b11,cs_b) + + ///implement TRSM/// + ////extract a00 + ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm3 with ymm1*/ + BLIS_ZTRSM_DIV(ymm3) +#else + /*performs dcomplex multiplication of ymm3 with ymm1*/ + BLIS_ZTRSM_MUL(ymm3) +#endif + + _mm256_storeu_pd((double *)b11, ymm3); + m_remainder -=2; + + } + else if (1 == m_remainder) + { + a01 = D_A_pack; + a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1)*rs_a; + b10 = B + (m_remainder - 1) + (n_remainder)*cs_b; + b11 = B + (m_remainder - 1) + (n_remainder - 1)*cs_b; + + k_iter = (n-n_remainder); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_ZTRSM_SMALL_GEMM_1nx1m(a01,b10,cs_b,p_lda,k_iter) + + // Load b11 of size 4x6 and multiply with alpha + BLIS_PRE_ZTRSM_SMALL_1x1(AlphaVal,b11,cs_b) + + ///implement TRSM/// + ////extract a00 + ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + /*performs dcomplex divison of ymm3 with ymm1*/ + BLIS_ZTRSM_DIV(ymm3) +#else + /*performs dcomplex multiplication of ymm3 with ymm1*/ + BLIS_ZTRSM_MUL(ymm3) +#endif + _mm_storeu_pd((double *)b11, + _mm256_extractf128_pd(ymm3,0)); + m_remainder -=1; + } + n_remainder -= 1; + } + + if ((required_packing_A == 1) && + bli_mem_is_alloc( &local_mem_buf_A_s )) + { + bli_membrk_release(&rntm, &local_mem_buf_A_s); + } + + + return BLIS_SUCCESS; +} + +BLIS_INLINE err_t bli_ztrsm_small_XAltB_XAuB +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +) +{ + dim_t m = bli_obj_length(b); //number of rows + dim_t n = bli_obj_width(b); //number of columns + + bool transa = bli_obj_has_trans(a); + bool conjtransa = bli_obj_has_conj(a); + + dim_t cs_a, rs_a; + dim_t d_mr = 4,d_nr = 3; + + // Swap rs_a & cs_a in case of non-tranpose. 
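/*
 * Note on the stride swap above (sketch): with both strides kept explicit,
 * element (i,j) of A is a[i*rs_a + j*cs_a], and reading A transposed is the
 * same as reading A with the two strides exchanged, since A^T(i,j) = A(j,i).
 * Swapping rs_a/cs_a for the non-transpose case therefore lets the XAltB and
 * XAuB variants share one indexing scheme.  The helper name is illustrative.
 */
static inline dcomplex ztrsm_opa_elem_sketch(const dcomplex *a,
                                             dim_t i, dim_t j,
                                             dim_t rs_a, dim_t cs_a)
{
    return a[i * rs_a + j * cs_a];   /* strides already swapped if needed */
}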
+ if(transa) + { + cs_a = bli_obj_col_stride(a); // column stride of A + rs_a = bli_obj_row_stride(a); // row stride of A + } + else + { + cs_a = bli_obj_row_stride(a); // row stride of A + rs_a = bli_obj_col_stride(a); // column stride of A + } + dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B + + dim_t i, j, k; //loop variablse + dim_t k_iter; //determines the number of GEMM operations to be done + + dcomplex ones = {1.0, 1.0}; + dcomplex zero = {0.0, 0.0}; + bool is_unitdiag = bli_obj_has_unit_diag(a); + + dcomplex AlphaVal = *(dcomplex *)AlphaObj->buffer; //value of Alpha + dcomplex* restrict L = a->buffer; //pointer to matrix A + dcomplex* restrict B = b->buffer; //pointer to matrix B + + dcomplex *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks + + gint_t required_packing_A = 1; + mem_t local_mem_buf_A_s = {0}; + dcomplex *D_A_pack = NULL; + dcomplex d11_pack[d_mr] __attribute__((aligned(64))); + rntm_t rntm; + + bli_rntm_init_from_global( &rntm ); + bli_rntm_set_num_threads_only( 1, &rntm ); + bli_membrk_rntm_set_membrk( &rntm ); + + siz_t buffer_size = bli_pool_block_size( + bli_membrk_pool( + bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), + bli_rntm_membrk(&rntm))); + + if( (d_nr * n * sizeof(dcomplex)) > buffer_size) + return BLIS_NOT_YET_IMPLEMENTED; -#ifndef BLIS_ENABLE_TRSM_PREINVERSION - /*performs dcomplex divison of ymm9 and ymm5 with ymm1*/ - BLIS_ZTRSM_TWO_DIV(ymm9,ymm5) -#else - /*performs dcomplex multiplication of ymm9 and ymm5 with ymm1*/ - BLIS_ZTRSM_MUL(ymm9) - BLIS_ZTRSM_MUL(ymm5) -#endif - //extract a44 - ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); - //(ROW3): FMA operations - ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + rs_a)); - if(conjtransa) - { - ymm2 = _mm256_mul_pd(ymm2, ymm0); - } - //Step 1 - ymm14 = _mm256_permute_pd(ymm2, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); + if (required_packing_A == 1) + { + // Get the buffer from the pool. 
+ bli_membrk_acquire_m(&rntm, + buffer_size, + BLIS_BITVAL_BUFFER_FOR_A_BLOCK, + &local_mem_buf_A_s); + if(FALSE==bli_mem_is_alloc(&local_mem_buf_A_s)) return BLIS_NULL_POINTER; + D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); + if(NULL==D_A_pack) return BLIS_NULL_POINTER; + } - //For ymm9 - /* Multiply vec1 and vec2 */ - ymm13 = _mm256_mul_pd(ymm9, ymm2); /*vec3*/ - /* Multiply vec1 and the modified vec2 */ - ymm14 = _mm256_mul_pd(ymm9, ymm14); /*vec4*/ - /* Horizontally subtract the elements in vec3 and vec4 */ - ymm16 = _mm256_hsub_pd(ymm13, ymm14); - //For ymm5 - ymm14 = _mm256_permute_pd(ymm2, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); + //ymm scratch reginsters + __m256d ymm0, ymm1, ymm2, ymm3; + __m256d ymm4, ymm5, ymm6, ymm7; + __m256d ymm8, ymm9, ymm10, ymm11; + __m256d ymm12, ymm13, ymm14, ymm15; + __m256d ymm16, ymm17, ymm18, ymm19; - ymm13 = _mm256_mul_pd(ymm5, ymm2); - ymm14 = _mm256_mul_pd(ymm5, ymm14); - ymm17 = _mm256_hsub_pd(ymm13, ymm14); - //Step 2 - ymm16 = _mm256_mul_pd(ymm16, ymm15); - ymm17 = _mm256_mul_pd(ymm17, ymm15); - //Step 3 - ymm8 = _mm256_add_pd(ymm16, ymm8); - ymm4 = _mm256_add_pd(ymm17, ymm4); + __m128d xmm5; -#ifndef BLIS_ENABLE_TRSM_PREINVERSION - /*performs dcomplex divison of ymm8 and ymm4 with ymm1*/ - BLIS_ZTRSM_TWO_DIV(ymm8,ymm4) -#else - /*performs dcomplex nultiplication of ymm8 and ymm4 with ymm1*/ - BLIS_ZTRSM_MUL(ymm8) - BLIS_ZTRSM_MUL(ymm4) + for(j = 0; (j+d_nr-1) < n; j += d_nr) //loop along 'N' direction + { + a01 = L + j*rs_a;//pointer to block of A to be used in GEMM + a11 = L + j*cs_a + j*rs_a;//pointer to block of A to be used for TRSM -#endif - BLIS_ZTRSM_SMALL_NREG_TRANSPOSE_4x3_AND_STORE(b11,cs_b) + dim_t p_lda = j; // packed leading dimension + // perform copy of A to packed buffer D_A_pack + if(transa) + { + /* + Pack current A block (a01) into packed buffer memory D_A_pack + a. This a10 block is used in GEMM portion only and this + a01 block size will be increasing by d_nr for every next + iteration until it reaches 3x(n-3) which is the maximum GEMM + alone block size in A + b. This packed buffer is reused to calculate all m cols of + B matrix + */ + bli_ztrsm_small_pack('R', j, 1, a01, cs_a, D_A_pack, p_lda,d_nr); - } - dim_t n_remainder = j + d_nr; - if(n_remainder) + /* + Pack 3 diagonal elements of A block into an array + a. This helps in utilze cache line efficiently in TRSM + operation + b. store ones when input is unit diagonal + */ + ztrsm_small_pack_diag_element(is_unitdiag,a11,cs_a, + d11_pack,d_nr); + } + else { - a10 = D_A_pack; - a11 = L + (i*cs_a) + (i*rs_a); - b01 = B + i + d_mr; - b11 = B + i; + bli_ztrsm_small_pack('R', j, 0, a01, rs_a, D_A_pack, + p_lda,d_nr); + ztrsm_small_pack_diag_element(is_unitdiag,a11,rs_a, + d11_pack,d_nr); + } - k_iter = (m - i - d_mr) ; + /* + a. Perform GEMM using a01, b10. + b. Perform TRSM on a11, b11 + c. This loop GEMM+TRSM loops operates with 8x6 block size + along m dimension for every d_mr columns of B10 where + packed A buffer is reused in computing all m cols of B. + d. Same approach is used in remaining fringe cases. 
+ */ + for(i = 0; (i+d_mr-1) < m; i += d_mr) //loop along 'M' direction + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; + b10 = B + i; + b11 = B + i + j*cs_b; + k_iter = j; /*Fill zeros into ymm registers used in gemm accumulations */ BLIS_SET_YMM_REG_ZEROS - if(2 == n_remainder) - { - ///GEMM code begins/// - BLIS_ZTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b,p_lda,k_iter) + /* + Peform GEMM between a01 and b10 blocks + For first itteration there will be no GEMM operation + where k_iter are zero + */ - ymm16 = _mm256_broadcast_pd((__m128d const *)(&AlphaVal)); - //register to hold alpha - BLIS_ZTRSM_SMALL_NREG_TRANSPOSE_2x4(b11,cs_b,AlphaVal) - } - else if(1 == n_remainder) - { - ///GEMM code begins/// - BLIS_ZTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b,p_lda,k_iter) - BLIS_ZTRSM_SMALL_NREG_TRANSPOSE_1x4(b11,cs_b,AlphaVal) - } - ///implement TRSM/// - ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - ymm10 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm11 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + BLIS_ZTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) + + /* + Load b11 of size 4x3 and multiply with alpha + Add the GEMM output to b11 + and peform TRSM operation. + */ + BLIS_PRE_ZTRSM_SMALL_3x4(AlphaVal,b11,cs_b) + ///implement TRSM/// + /* + Compute 3x3 TRSM block by using GEMM block output in register + a. The 3x4 input (gemm outputs) are stored in combinations of + ymm registers + 1. ymm3, ymm4 2. ymm5, ymm6 3. ymm7, ymm8 + b. Towards the end do in regiser transpose of TRSM output + and store in b11 + */ ////extract a00 ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 3)); + ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); #ifndef BLIS_ENABLE_TRSM_PREINVERSION - BLIS_ZTRSM_DIV(ymm11) + /*performs dcomplex divison of ymm3 and ymm4 with ymm1*/ + BLIS_ZTRSM_TWO_DIV(ymm3,ymm4) #else - BLIS_ZTRSM_MUL(ymm11) + /*performs dcomplex multiplication of ymm3 and ymm4 with ymm1*/ + BLIS_ZTRSM_MUL(ymm3) + BLIS_ZTRSM_MUL(ymm4) #endif - //extract a11 - ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 2)); + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 1)); //(ROW1): FMA operations - ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a*2 + rs_a*3)); - ymm3 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a*1 + rs_a*3)); - ymm4 = _mm256_broadcast_pd((__m128d const *)(a11 + rs_a*3)); - if(conjtransa) - { - ymm2 = _mm256_mul_pd(ymm2, ymm0); - ymm3 = _mm256_mul_pd(ymm3, ymm0); - ymm4 = _mm256_mul_pd(ymm4, ymm0); - } - /*Step1 dcomplex multiply ymmx, ymmx + ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + rs_a*1)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + /* Step1 dcomplex multiply ymm2, ymm3 * Step2 negate the result * Step3 add ymmx*/ //Step 1 ymm14 = _mm256_permute_pd(ymm2, 0x5); /* Negate the imaginary elements of vec2 */ ymm14 = _mm256_mul_pd(ymm14, ymm0); - //For ymm8 + //For ymm3 /* Multiply vec1 and vec2 */ - ymm13 = _mm256_mul_pd(ymm11, ymm2); /*vec3*/ + ymm13 = _mm256_mul_pd(ymm3, ymm2); /*vec3*/ /* Multiply vec1 and the modified vec2 */ - ymm14 = _mm256_mul_pd(ymm11, ymm14); /*vec4*/ + ymm14 = _mm256_mul_pd(ymm3, ymm14); /*vec4*/ /* Horizontally subtract the elements in vec3 and vec4 */ ymm16 = _mm256_hsub_pd(ymm13, ymm14); - //Step 2 - ymm16 = _mm256_mul_pd(ymm16, ymm15); - - //Step 3 - ymm10 = _mm256_add_pd(ymm16, ymm10); - - //Step 1 - ymm14 = 
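/*
 * Scalar reference for the operation the blocked loop above vectorizes
 * (sketch only, not part of this patch): solve X * A = alpha * B for the
 * XAuB case, i.e. A upper-triangular with no transpose, proceeding column by
 * column so that each solved column updates the trailing ones; the XAltB
 * case follows the same recurrence once the strides are swapped.
 */
#include <complex.h>

static void ztrsm_xaub_ref_sketch(const double complex *a, double complex *b,
                                  double complex alpha,
                                  dim_t m, dim_t n, dim_t lda, dim_t ldb,
                                  bool unitdiag)
{
    for (dim_t j = 0; j < n; ++j)                 /* apply alpha once        */
        for (dim_t i = 0; i < m; ++i)
            b[i + j * ldb] *= alpha;

    for (dim_t k = 0; k < n; ++k)                 /* solve column k of X     */
    {
        if (!unitdiag)
            for (dim_t i = 0; i < m; ++i)
                b[i + k * ldb] /= a[k + k * lda];

        for (dim_t j = k + 1; j < n; ++j)         /* update trailing columns */
            for (dim_t i = 0; i < m; ++i)
                b[i + j * ldb] -= b[i + k * ldb] * a[k + j * lda];
    }
}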
_mm256_permute_pd(ymm3, 0x5); + //For ymm4 + ymm14 = _mm256_permute_pd(ymm2, 0x5); /* Negate the imaginary elements of vec2 */ ymm14 = _mm256_mul_pd(ymm14, ymm0); - //For ymm8 - /* Multiply vec1 and vec2 */ - ymm13 = _mm256_mul_pd(ymm11, ymm3); /*vec3*/ - /* Multiply vec1 and the modified vec2 */ - ymm14 = _mm256_mul_pd(ymm11, ymm14); /*vec4*/ - /* Horizontally subtract the elements in vec3 and vec4 */ - ymm16 = _mm256_hsub_pd(ymm13, ymm14); + + ymm13 = _mm256_mul_pd(ymm4, ymm2); + ymm14 = _mm256_mul_pd(ymm4, ymm14); + ymm17 = _mm256_hsub_pd(ymm13, ymm14); //Step 2 ymm16 = _mm256_mul_pd(ymm16, ymm15); + ymm17 = _mm256_mul_pd(ymm17, ymm15); //Step 3 - ymm9 = _mm256_add_pd(ymm16, ymm9); + ymm5 = _mm256_add_pd(ymm16, ymm5); + ymm6 = _mm256_add_pd(ymm17, ymm6); + ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + rs_a*2)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } //Step 1 - ymm14 = _mm256_permute_pd(ymm4, 0x5); + ymm14 = _mm256_permute_pd(ymm2, 0x5); /* Negate the imaginary elements of vec2 */ ymm14 = _mm256_mul_pd(ymm14, ymm0); - //For ymm8 + //For ymm3 /* Multiply vec1 and vec2 */ - ymm13 = _mm256_mul_pd(ymm11, ymm4); /*vec3*/ + ymm13 = _mm256_mul_pd(ymm3, ymm2); /*vec3*/ /* Multiply vec1 and the modified vec2 */ - ymm14 = _mm256_mul_pd(ymm11, ymm14); /*vec4*/ + ymm14 = _mm256_mul_pd(ymm3, ymm14); /*vec4*/ /* Horizontally subtract the elements in vec3 and vec4 */ ymm16 = _mm256_hsub_pd(ymm13, ymm14); - //Step 2 - ymm16 = _mm256_mul_pd(ymm16, ymm15); - //Step 3 - ymm8 = _mm256_add_pd(ymm16, ymm8); -#ifndef BLIS_ENABLE_TRSM_PREINVERSION - BLIS_ZTRSM_DIV(ymm10) -#else - BLIS_ZTRSM_MUL(ymm10) -#endif - ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 1)); - ymm3 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a*1 + rs_a*2)); - ymm4 = _mm256_broadcast_pd((__m128d const *)(a11 + rs_a*2)); - if(conjtransa) - { - ymm3 = _mm256_mul_pd(ymm3, ymm0); - ymm4 = _mm256_mul_pd(ymm4, ymm0); - } - //Step 1 - ymm14 = _mm256_permute_pd(ymm3, 0x5); + //For ymm4 + ymm14 = _mm256_permute_pd(ymm2, 0x5); /* Negate the imaginary elements of vec2 */ ymm14 = _mm256_mul_pd(ymm14, ymm0); - //For ymm9 - /* Multiply vec1 and vec2 */ - ymm13 = _mm256_mul_pd(ymm10, ymm3); /*vec3*/ - /* Multiply vec1 and the modified vec2 */ - ymm14 = _mm256_mul_pd(ymm10, ymm14); /*vec4*/ - /* Horizontally subtract the elements in vec3 and vec4 */ - ymm16 = _mm256_hsub_pd(ymm13, ymm14); + + ymm13 = _mm256_mul_pd(ymm4, ymm2); + ymm14 = _mm256_mul_pd(ymm4, ymm14); + ymm17 = _mm256_hsub_pd(ymm13, ymm14); //Step 2 ymm16 = _mm256_mul_pd(ymm16, ymm15); + ymm17 = _mm256_mul_pd(ymm17, ymm15); //Step 3 - ymm9 = _mm256_add_pd(ymm16, ymm9); + ymm7 = _mm256_add_pd(ymm16, ymm7); + ymm8 = _mm256_add_pd(ymm17, ymm8); + - //Step 1 - ymm14 = _mm256_permute_pd(ymm4, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); - //For ymm8 - /* Multiply vec1 and vec2 */ - ymm13 = _mm256_mul_pd(ymm10, ymm4); /*vec3*/ - /* Multiply vec1 and the modified vec2 */ - ymm14 = _mm256_mul_pd(ymm10, ymm14); /*vec4*/ - /* Horizontally subtract the elements in vec3 and vec4 */ - ymm16 = _mm256_hsub_pd(ymm13, ymm14); - //Step 2 - ymm16 = _mm256_mul_pd(ymm16, ymm15); - //Step 3 - ymm8 = _mm256_add_pd(ymm16, ymm8); #ifndef BLIS_ENABLE_TRSM_PREINVERSION - BLIS_ZTRSM_DIV(ymm9) + /*performs dcomplex divison of ymm5 and ymm6 with ymm1*/ + BLIS_ZTRSM_TWO_DIV(ymm5,ymm6) #else - BLIS_ZTRSM_MUL(ymm9) + /*performs dcomplex multiplication of ymm5 and ymm6 with ymm1*/ + BLIS_ZTRSM_MUL(ymm5) + BLIS_ZTRSM_MUL(ymm6) #endif + a11 += 
cs_a; - ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); - ymm4 = _mm256_broadcast_pd((__m128d const *)(a11 + rs_a)); - if(conjtransa) - { - ymm4 = _mm256_mul_pd(ymm4, ymm0); - } - //Step 1 - ymm14 = _mm256_permute_pd(ymm4, 0x5); + //extract a22 + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 2)); + //(ROW2): FMA operations + ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + rs_a * 2)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); /* Negate the imaginary elements of vec2 */ ymm14 = _mm256_mul_pd(ymm14, ymm0); - //For ymm10 + + //For ymm5 /* Multiply vec1 and vec2 */ - ymm13 = _mm256_mul_pd(ymm9, ymm4); /*vec3*/ + ymm13 = _mm256_mul_pd(ymm5, ymm2); /*vec3*/ /* Multiply vec1 and the modified vec2 */ - ymm14 = _mm256_mul_pd(ymm9, ymm14); /*vec4*/ + ymm14 = _mm256_mul_pd(ymm5, ymm14); /*vec4*/ /* Horizontally subtract the elements in vec3 and vec4 */ ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //For ymm6 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + ymm13 = _mm256_mul_pd(ymm6, ymm2); + ymm14 = _mm256_mul_pd(ymm6, ymm14); + ymm17 = _mm256_hsub_pd(ymm13, ymm14); //Step 2 ymm16 = _mm256_mul_pd(ymm16, ymm15); + ymm17 = _mm256_mul_pd(ymm17, ymm15); //Step 3 - ymm8 = _mm256_add_pd(ymm16, ymm8); + ymm7 = _mm256_add_pd(ymm16, ymm7); + ymm8 = _mm256_add_pd(ymm17, ymm8); + + #ifndef BLIS_ENABLE_TRSM_PREINVERSION - BLIS_ZTRSM_DIV(ymm8) + /*performs dcomplex divison of ymm7 and ymm8 with ymm1*/ + BLIS_ZTRSM_TWO_DIV(ymm7,ymm8) #else - BLIS_ZTRSM_MUL(ymm8) + /*performs dcomplex multiplication of ymm7 and ymm8 with ymm1*/ + BLIS_ZTRSM_MUL(ymm7) + BLIS_ZTRSM_MUL(ymm8) #endif + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + 2), ymm4); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm256_storeu_pd((double *)(b11 + cs_b + 2), ymm6); + _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); + _mm256_storeu_pd((double *)(b11 + cs_b*2 + 2), ymm8); - if(2 == n_remainder) - { - ymm0 = _mm256_permute2f128_pd(ymm8,ymm9,0x20); - ymm4 = _mm256_permute2f128_pd(ymm10,ymm11,0x20); - ymm1 = _mm256_permute2f128_pd(ymm8,ymm9,0x31); - ymm3 = _mm256_permute2f128_pd(ymm10,ymm11,0x31); - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); - _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 2), ymm4); - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); - _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 2), ymm3); + } - } - else if(1 == n_remainder) + dim_t m_remainder = m - i; + if(m_remainder) + { + if(m_remainder == 3) { - ymm0 = _mm256_permute2f128_pd(ymm8,ymm9,0x20); - ymm4 = _mm256_permute2f128_pd(ymm10,ymm11,0x20); - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); - _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 2), ymm4); - } - } - } + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; + b10 = B + i; + b11 = B + i + j*cs_b; - dim_t m_remainder = i + d_mr; - a10 = L + m_remainder*rs_a; - dcomplex *ptr_a10_dup = D_A_pack; - if(m_remainder == 3) - { - dim_t p_lda = 4; - if(transa) - { - for(dim_t x = 0; x < m-m_remainder; x += p_lda) - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm10 = _mm256_loadu_pd((double const *) - (a10 + 2)); - ymm1 = _mm256_loadu_pd((double const *) - (a10 + cs_a)); - ymm11 = _mm256_loadu_pd((double const *) - (a10 + 2 + cs_a)); - - ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - ymm8 = _mm256_permute2f128_pd(ymm10,ymm11,0x20); - ymm9 = 
_mm256_permute2f128_pd(ymm10,ymm11,0x31); - - _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); - _mm256_storeu_pd((double *)(ptr_a10_dup + - p_lda), ymm7); - _mm256_storeu_pd((double *)(ptr_a10_dup + - p_lda*2), ymm8); - _mm256_storeu_pd((double *)(ptr_a10_dup + - p_lda*3), ymm9); - - ymm0 = _mm256_loadu_pd((double const *)(a10 - + 2 * cs_a)); - ymm10 = _mm256_loadu_pd((double const *)(a10 - + 2 * cs_a + 2)); - - ymm1 = _mm256_loadu_pd((double const *)(a10 - + 3 * cs_a)); - ymm11 = _mm256_loadu_pd((double const *)(a10 - + 3 * cs_a + 2)); - - ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - ymm8 = _mm256_permute2f128_pd(ymm10,ymm11,0x20); - ymm9 = _mm256_permute2f128_pd(ymm10,ymm11,0x31); - - _mm256_storeu_pd((double *)(ptr_a10_dup + 2), - ymm6); - _mm256_storeu_pd((double *)(ptr_a10_dup + - p_lda + 2), ymm7); - _mm256_storeu_pd((double *)(ptr_a10_dup + - p_lda*2 + 2), ymm8); - _mm256_storeu_pd((double *)(ptr_a10_dup + - p_lda*3 + 2), ymm9); - - a10 += p_lda; - ptr_a10_dup += p_lda * p_lda; - } - - } - else - { - for(dim_t x=0;x < m-m_remainder;x++) - { - ymm0 = _mm256_loadu_pd((double const *) - (a10 + rs_a * x)); - _mm256_storeu_pd((double *) - (ptr_a10_dup + p_lda * x), ymm0); - ymm0 = _mm256_loadu_pd((double const *) - (a10 + rs_a * x + 2)); - _mm256_storeu_pd((double *) - (ptr_a10_dup + p_lda * x + 2), - ymm0); - } - } - //cols - for(j = (n - d_nr); (j + 1) > 0; j -= d_nr) - { - a10 = D_A_pack; - a11 = L; - b01 = B + (j*cs_b) + m_remainder; - b11 = B + (j* cs_b); - k_iter = (m - m_remainder); - - BLIS_SET_YMM_REG_ZEROS - ///GEMM code begins/// - BLIS_ZTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) - ///GEMM code ends/// - ymm16 = _mm256_broadcast_pd((__m128d const *) - (&AlphaVal)); - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - - ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - - ymm14 = _mm256_permute_pd(ymm16, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm18); - ymm17 = _mm256_mul_pd(ymm0, ymm16); - ymm14 = _mm256_mul_pd(ymm0, ymm14); - ymm15 = _mm256_hsub_pd(ymm17, ymm14); - - ymm8 = _mm256_sub_pd(ymm15,ymm8); - - ymm14 = _mm256_permute_pd(ymm16, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm18); - ymm17 = _mm256_mul_pd(ymm1, ymm16); - ymm14 = _mm256_mul_pd(ymm1, ymm14); - ymm15 = _mm256_hsub_pd(ymm17, ymm14); - - ymm9 = _mm256_sub_pd(ymm15,ymm9); - - ymm14 = _mm256_permute_pd(ymm16, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm18); - ymm17 = _mm256_mul_pd(ymm2, ymm16); - ymm14 = _mm256_mul_pd(ymm2, ymm14); - ymm15 = _mm256_hsub_pd(ymm17, ymm14); - - ymm10 = _mm256_sub_pd(ymm15,ymm10); - - _mm256_storeu_pd((double *)(b11), ymm8); - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm9); - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm10); - - ymm0 = _mm256_loadu_pd((double const *) - (b11 + cs_b *0 + 2)); - ymm1 = _mm256_loadu_pd((double const *) - (b11 + cs_b *1 + 2)); - ymm2 = _mm256_loadu_pd((double const *) - (b11 + cs_b *2 + 2)); - - ymm14 = _mm256_permute_pd(ymm16, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm18); - ymm17 = _mm256_mul_pd(ymm0, ymm16); - ymm14 = _mm256_mul_pd(ymm0, ymm14); - ymm15 = _mm256_hsub_pd(ymm17, ymm14); - - ymm11 = _mm256_sub_pd(ymm15,ymm11); - - ymm14 = _mm256_permute_pd(ymm16, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm18); - ymm17 = _mm256_mul_pd(ymm1, ymm16); - ymm14 = _mm256_mul_pd(ymm1, ymm14); - ymm15 = _mm256_hsub_pd(ymm17, ymm14); - - ymm12 = _mm256_sub_pd(ymm15,ymm12); - ymm14 = 
_mm256_permute_pd(ymm16, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm18); - ymm17 = _mm256_mul_pd(ymm2, ymm16); - ymm14 = _mm256_mul_pd(ymm2, ymm14); - ymm15 = _mm256_hsub_pd(ymm17, ymm14); - - ymm13 = _mm256_sub_pd(ymm15,ymm13); - _mm_storeu_pd((double *)(b11 + 2), - _mm256_extractf128_pd(ymm11,0)); - _mm_storeu_pd((double *)(b11 + cs_b * 1 + 2), - _mm256_extractf128_pd(ymm12,0)); - _mm_storeu_pd((double *)(b11 + cs_b * 2 + 2), - _mm256_extractf128_pd(ymm13,0)); - - if(transa) - ztrsm_AltXB_ref(a11, b11, m_remainder, 3, - cs_a, cs_b, is_unitdiag, - conjtransa); - else - ztrsm_AuXB_ref(a11, b11, m_remainder, 3, - rs_a, cs_b, is_unitdiag, - conjtransa); - } - dim_t n_remainder = j + d_nr; - if(n_remainder) - { - a10 = D_A_pack; - a11 = L; - b01 = B + m_remainder; - b11 = B; - k_iter = (m - m_remainder); - BLIS_SET_YMM_REG_ZEROS - if(2 == n_remainder) - { - ///GEMM code begins/// - BLIS_ZTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b, - p_lda,k_iter) - BLIS_PRE_ZTRSM_SMALL_3M_2N(AlphaVal,b11,cs_b) - - if(transa) - ztrsm_AltXB_ref(a11, b11, m_remainder, 2, - cs_a, cs_b, is_unitdiag, - conjtransa); - - else - ztrsm_AuXB_ref(a11, b11, m_remainder, 2, - rs_a, cs_b, is_unitdiag, - conjtransa); - } - else if(1 == n_remainder) - { - ///GEMM code begins/// - BLIS_ZTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b, - p_lda,k_iter) - BLIS_PRE_ZTRSM_SMALL_3M_1N(AlphaVal,b11,cs_b) - - if(transa) - ztrsm_AltXB_ref(a11, b11, m_remainder, 1, - cs_a, cs_b, is_unitdiag, - conjtransa); - else - ztrsm_AuXB_ref(a11, b11, m_remainder, 1, - rs_a, cs_b, is_unitdiag, - conjtransa); - - } - } - } - else if(m_remainder == 2) - { - dim_t p_lda = 2; - if(transa) - { - for(dim_t x = 0; x < m-m_remainder; x += p_lda) - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm1 = _mm256_loadu_pd((double const *) - (a10 + cs_a)); - - ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - - _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); - _mm256_storeu_pd((double *)(ptr_a10_dup + - p_lda), ymm7); - - a10 += p_lda; - ptr_a10_dup += p_lda * p_lda; - } - - } - else - { - for(dim_t x=0;x < m-m_remainder;x++) - { - ymm0 = _mm256_loadu_pd((double const *) - (a10 + rs_a * x)); - _mm256_storeu_pd((double *) - (ptr_a10_dup + p_lda * x), ymm0); - } - } - //cols - for(j = (n - d_nr); (j + 1) > 0; j -= d_nr) - { - a10 = D_A_pack; - a11 = L; - b01 = B + (j*cs_b) + m_remainder; - b11 = B + (j* cs_b); - k_iter = (m - m_remainder); - - BLIS_SET_YMM_REG_ZEROS - ///GEMM code begins/// - BLIS_ZTRSM_SMALL_GEMM_2mx3n(a10,b01,cs_b,p_lda,k_iter) - ///GEMM code ends/// - ymm16 = _mm256_broadcast_pd((__m128d const *) - (&AlphaVal)); - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - - ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - - ymm14 = _mm256_permute_pd(ymm16, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm18); - ymm17 = _mm256_mul_pd(ymm0, ymm16); - ymm14 = _mm256_mul_pd(ymm0, ymm14); - ymm15 = _mm256_hsub_pd(ymm17, ymm14); - - ymm8 = _mm256_sub_pd(ymm15,ymm8); - - ymm14 = _mm256_permute_pd(ymm16, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm18); - ymm17 = _mm256_mul_pd(ymm1, ymm16); - ymm14 = _mm256_mul_pd(ymm1, ymm14); - ymm15 = _mm256_hsub_pd(ymm17, ymm14); - - ymm9 = _mm256_sub_pd(ymm15,ymm9); - - ymm14 = _mm256_permute_pd(ymm16, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm18); - ymm17 = _mm256_mul_pd(ymm2, ymm16); - ymm14 = _mm256_mul_pd(ymm2, ymm14); - ymm15 = _mm256_hsub_pd(ymm17, ymm14); - - ymm10 = 
_mm256_sub_pd(ymm15,ymm10); - - _mm256_storeu_pd((double *)(b11), ymm8); - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm9); - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm10); - - if(transa) - ztrsm_AltXB_ref(a11, b11, m_remainder, 3, - cs_a, cs_b, is_unitdiag, - conjtransa); - else - ztrsm_AuXB_ref(a11, b11, m_remainder, 3, - rs_a, cs_b, is_unitdiag, - conjtransa); - } - dim_t n_remainder = j + d_nr; - if(n_remainder) - { - a10 = D_A_pack; - a11 = L; - b01 = B + m_remainder; - b11 = B; - k_iter = (m - m_remainder); - BLIS_SET_YMM_REG_ZEROS - if(2 == n_remainder) - { - ///GEMM code begins/// - BLIS_ZTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b, - p_lda,k_iter) - BLIS_PRE_ZTRSM_SMALL_2M_2N(AlphaVal,b11,cs_b) - - if(transa) - ztrsm_AltXB_ref(a11, b11, m_remainder, 2, - cs_a, cs_b, is_unitdiag, - conjtransa); - - else - ztrsm_AuXB_ref(a11, b11, m_remainder, 2, - rs_a, cs_b, is_unitdiag, - conjtransa); - } - else if(1 == n_remainder) - { - ///GEMM code begins/// - BLIS_ZTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b, - p_lda,k_iter) - BLIS_PRE_ZTRSM_SMALL_2M_1N(AlphaVal,b11,cs_b) - - if(transa) - ztrsm_AltXB_ref(a11, b11, m_remainder, 1, - cs_a, cs_b, is_unitdiag, - conjtransa); - else - ztrsm_AuXB_ref(a11, b11, m_remainder, 1, - rs_a, cs_b, is_unitdiag, - conjtransa); - - } - } - } - else if(m_remainder == 1) - { - dim_t p_lda = 2; // packed leading dimension - if(transa) - { - for(dim_t x = 0; x < m-m_remainder; x += p_lda) - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm1 = _mm256_loadu_pd((double const *) - (a10 + cs_a)); - - ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - - _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); - _mm256_storeu_pd((double *)(ptr_a10_dup + - p_lda), ymm7); - - a10 += p_lda; - ptr_a10_dup += p_lda * p_lda; - } - - } - else - { - for(dim_t x=0;x 0; j -= d_nr) - { - a10 = D_A_pack; - a11 = L; - b01 = B + (j*cs_b) + m_remainder; - b11 = B + (j* cs_b); - k_iter = (m - m_remainder); - - BLIS_SET_YMM_REG_ZEROS - ///GEMM code begins/// - BLIS_ZTRSM_SMALL_GEMM_2mx3n(a10,b01,cs_b,p_lda,k_iter) - ///GEMM code ends/// - ymm16 = _mm256_broadcast_pd((__m128d const *) - (&AlphaVal)); - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - - ymm14 = _mm256_permute_pd(ymm16, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm18); - ymm17 = _mm256_mul_pd(ymm0, ymm16); - ymm14 = _mm256_mul_pd(ymm0, ymm14); - ymm15 = _mm256_hsub_pd(ymm17, ymm14); - - ymm8 = _mm256_sub_pd(ymm15,ymm8); - - ymm14 = _mm256_permute_pd(ymm16, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm18); - ymm17 = _mm256_mul_pd(ymm1, ymm16); - ymm14 = _mm256_mul_pd(ymm1, ymm14); - ymm15 = _mm256_hsub_pd(ymm17, ymm14); - - ymm9 = _mm256_sub_pd(ymm15,ymm9); - ymm14 = _mm256_permute_pd(ymm16, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm18); - ymm17 = _mm256_mul_pd(ymm2, ymm16); - ymm14 = _mm256_mul_pd(ymm2, ymm14); - ymm15 = _mm256_hsub_pd(ymm17, ymm14); - - ymm10 = _mm256_sub_pd(ymm15,ymm10); - - _mm_storeu_pd((double *)(b11), - _mm256_extractf128_pd(ymm8,0)); - _mm_storeu_pd((double *)(b11 + cs_b * 1), - _mm256_extractf128_pd(ymm9,0) ); - _mm_storeu_pd((double *)(b11 + cs_b * 2), - _mm256_extractf128_pd(ymm10,0)); - - if(transa) - ztrsm_AltXB_ref(a11, b11, m_remainder, 3, - cs_a, cs_b, is_unitdiag, - conjtransa); - - else - ztrsm_AuXB_ref(a11, b11, m_remainder, 3, rs_a, - cs_b, is_unitdiag, - conjtransa); - } - dim_t 
n_remainder = j + d_nr; - if(n_remainder) - { - a10 = D_A_pack; - a11 = L ; - b01 = B + m_remainder; - b11 = B; - k_iter = (m - m_remainder); - BLIS_SET_YMM_REG_ZEROS - if(2 == n_remainder) - { - - ///GEMM code begins/// - BLIS_ZTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b, - p_lda,k_iter) - BLIS_PRE_ZTRSM_SMALL_1M_2N(AlphaVal,b11,cs_b) - - if(transa) - ztrsm_AltXB_ref(a11, b11, m_remainder, 2, - cs_a, cs_b, is_unitdiag, - conjtransa); - - else - ztrsm_AuXB_ref(a11, b11, m_remainder, 2, - rs_a, cs_b, is_unitdiag, - conjtransa); - } - else if(1 == n_remainder) - { - ///GEMM code begins/// - BLIS_ZTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b, - p_lda,k_iter) - - BLIS_PRE_ZTRSM_SMALL_1M_1N(AlphaVal,b11,cs_b) - - if(transa) - ztrsm_AltXB_ref(a11, b11, m_remainder, 1, - cs_a, cs_b, is_unitdiag, - conjtransa); - - else - ztrsm_AuXB_ref(a11, b11, m_remainder, 1, - rs_a, cs_b, is_unitdiag, - conjtransa); - } - } - } + k_iter = j; - if ((required_packing_A == 1) && - bli_mem_is_alloc( &local_mem_buf_A_s )) - { - bli_membrk_release(&rntm, &local_mem_buf_A_s); - } + /*Fill zeros into ymm registers used in gemm + * accumulations */ + BLIS_SET_YMM_REG_ZEROS - return BLIS_SUCCESS; -} + ///GEMM implementation starts/// + BLIS_ZTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) -BLIS_INLINE err_t bli_ztrsm_small_XAutB_XAlB -( - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl -) -{ - dim_t m = bli_obj_length(b); //number of rows - dim_t n = bli_obj_width(b); //number of columns - - bool transa = bli_obj_has_trans(a); - bool conjtransa = bli_obj_has_conj(a); - - dim_t cs_a, rs_a; - dim_t d_mr = 4,d_nr = 3; - - // Swap rs_a & cs_a in case of non-tranpose. - if(transa) - { - cs_a = bli_obj_col_stride(a); // column stride of A - rs_a = bli_obj_row_stride(a); // row stride of A - } - else - { - cs_a = bli_obj_row_stride(a); // row stride of A - rs_a = bli_obj_col_stride(a); // column stride of A - } - dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B - - dim_t i, j, k; //loop variablse - dim_t k_iter; //determines the number of GEMM operations to be done - - dcomplex ones = {1.0, 1.0}; - dcomplex zero = {0.0, 0.0}; - bool is_unitdiag = bli_obj_has_unit_diag(a); - - dcomplex AlphaVal = *(dcomplex *)AlphaObj->buffer; //value of Alpha - dcomplex* restrict L = a->buffer; //pointer to matrix A - dcomplex* restrict B = b->buffer; //pointer to matrix B - - dcomplex *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks - - gint_t required_packing_A = 1; - mem_t local_mem_buf_A_s = {0}; - dcomplex *D_A_pack = NULL; - dcomplex d11_pack[d_mr] __attribute__((aligned(64))); - rntm_t rntm; - - bli_rntm_init_from_global( &rntm ); - bli_rntm_set_num_threads_only( 1, &rntm ); - bli_membrk_rntm_set_membrk( &rntm ); - - siz_t buffer_size = bli_pool_block_size( - bli_membrk_pool( - bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), - bli_rntm_membrk(&rntm))); - - if( (d_nr * n * sizeof(dcomplex)) > buffer_size) - return BLIS_NOT_YET_IMPLEMENTED; - - if (required_packing_A == 1) - { - // Get the buffer from the pool. 
- bli_membrk_acquire_m(&rntm, - buffer_size, - BLIS_BITVAL_BUFFER_FOR_A_BLOCK, - &local_mem_buf_A_s); - if(FALSE==bli_mem_is_alloc(&local_mem_buf_A_s)) return BLIS_NULL_POINTER; - D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); - if(NULL==D_A_pack) return BLIS_NULL_POINTER; - } - - //ymm scratch reginsters - __m256d ymm0, ymm1, ymm2, ymm3; - __m256d ymm4, ymm5, ymm6, ymm7; - __m256d ymm8, ymm9, ymm10, ymm11; - __m256d ymm12, ymm13, ymm14, ymm15; - __m256d ymm16, ymm17, ymm18, ymm19; - - __m128d xmm5, xmm4, xmm3; - - for(j = (n-d_nr); (j+1) > 0; j -= d_nr) //loop along 'N' direction - { - a01 = L + (j*rs_a) + (j+d_nr)*cs_a; - a11 = L + (j*cs_a) + (j*rs_a); - - dim_t p_lda = (n-j-d_nr); // packed leading dimension - // perform copy of A to packed buffer D_A_pack - - if(transa) - { - /* - Pack current A block (a01) into packed buffer memory D_A_pack - a. This a10 block is used in GEMM portion only and this - a01 block size will be increasing by d_nr for every next - iteration until it reaches 3x(n-3) which is the maximum GEMM - alone block size in A - b. This packed buffer is reused to calculate all m cols of B - matrix - */ - bli_ztrsm_small_pack('R', p_lda, 1, a01, cs_a, D_A_pack, - p_lda,d_nr); - - /* - Pack 3 diagonal elements of A block into an array - a. This helps in utilze cache line efficiently in TRSM - operation - b. store ones when input is unit diagonal - */ - ztrsm_small_pack_diag_element(is_unitdiag,a11,cs_a, - d11_pack,d_nr); - } - else - { - bli_ztrsm_small_pack('R', p_lda, 0, a01, rs_a, D_A_pack, - p_lda,d_nr); - ztrsm_small_pack_diag_element(is_unitdiag,a11,rs_a, - d11_pack,d_nr); - } - - /* - a. Perform GEMM using a01, b10. - b. Perform TRSM on a11, b11 - c. This loop GEMM+TRSM loops operates with 8x6 block size - along m dimension for every d_mr columns of B10 where - packed A buffer is reused in computing all m cols of B. - d. Same approach is used in remaining fringe cases. - */ - for(i = (m-d_mr); (i+1) > 0; i -= d_mr) //loop along 'M' direction - { - a01 = D_A_pack; - a11 = L + j*cs_a + j*rs_a; - b10 = B + i + (j+d_nr)*cs_b; - b11 = B + (i) + (j)*cs_b; - - k_iter = (n-j-d_nr); - - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS - - /* - Peform GEMM between a01 and b10 blocks - For first itteration there will be no GEMM operation - where k_iter are zero - */ - - BLIS_ZTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) - - /* - Load b11 multiply with alpha - Add the GEMM output to b11 - and peform TRSM operation. - */ - - BLIS_PRE_ZTRSM_SMALL_3x4(AlphaVal,b11,cs_b) - ///implement TRSM/// - /* - Compute 3x3 TRSM block by using GEMM block output in register - a. The 4x3 input (gemm outputs) are stored in combinations of - ymm registers - 1. ymm7, ymm8 2. ymm5, ymm6 3. ymm3, ymm4 - b. 
Towards the end do in regiser transpose of TRSM output and - store in b11 - */ - ////extract a00 - ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 2)); + // Load b11 of size 4x6 and multiply with alpha + BLIS_PRE_ZTRSM_SMALL_3x4(AlphaVal,b11,cs_b) + + ///implement TRSM/// + ////extract a00 + ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); #ifndef BLIS_ENABLE_TRSM_PREINVERSION - /*performs dcomplex divison of ymm7 and ymm8 with ymm1*/ - BLIS_ZTRSM_TWO_DIV(ymm7,ymm8) + /*performs dcomplex divison of ymm3 and ymm4 with ymm1*/ + BLIS_ZTRSM_TWO_DIV(ymm3,ymm4) #else - /*performs dcomplex multiplication of ymm7 and ymm8 with ymm1*/ - BLIS_ZTRSM_MUL(ymm7) - BLIS_ZTRSM_MUL(ymm8) + /*performs dcomplex multiplication of ymm3 and ymm4 + * with ymm1*/ + BLIS_ZTRSM_MUL(ymm3) + BLIS_ZTRSM_MUL(ymm4) #endif - //extract a11 - ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 1)); - //(ROW1): FMA operations - ymm2 = _mm256_broadcast_pd((__m128d const *) - (a11 + cs_a*2 + rs_a*1)); - if(conjtransa) - { - ymm2 = _mm256_mul_pd(ymm2, ymm0); - } - /* Step1 dcomplex multiply ymm2, ymm7 - * Step2 negate the result - * Step3 add ymmx*/ - //Step 1 - ymm14 = _mm256_permute_pd(ymm2, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); - //For ymm7 - /* Multiply vec1 and vec2 */ - ymm13 = _mm256_mul_pd(ymm7, ymm2); /*vec3*/ - /* Multiply vec1 and the modified vec2 */ - ymm14 = _mm256_mul_pd(ymm7, ymm14); /*vec4*/ - /* Horizontally subtract the elements in vec3 and vec4 */ - ymm16 = _mm256_hsub_pd(ymm13, ymm14); - - //For ymm8 - ymm14 = _mm256_permute_pd(ymm2, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); - - ymm13 = _mm256_mul_pd(ymm8, ymm2); - ymm14 = _mm256_mul_pd(ymm8, ymm14); - ymm17 = _mm256_hsub_pd(ymm13, ymm14); - //Step 2 - ymm16 = _mm256_mul_pd(ymm16, ymm15); - ymm17 = _mm256_mul_pd(ymm17, ymm15); - - //Step 3 - ymm5 = _mm256_add_pd(ymm16, ymm5); - ymm6 = _mm256_add_pd(ymm17, ymm6); - - ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a*2)); - if(conjtransa) - { - ymm2 = _mm256_mul_pd(ymm2, ymm0); - } - //Step 1 - ymm14 = _mm256_permute_pd(ymm2, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); - //For ymm7 - /* Multiply vec1 and vec2 */ - ymm13 = _mm256_mul_pd(ymm7, ymm2); /*vec3*/ - /* Multiply vec1 and the modified vec2 */ - ymm14 = _mm256_mul_pd(ymm7, ymm14); /*vec4*/ - /* Horizontally subtract the elements in vec3 and vec4 */ - ymm16 = _mm256_hsub_pd(ymm13, ymm14); - //For ymm8 - ymm14 = _mm256_permute_pd(ymm2, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); - - ymm13 = _mm256_mul_pd(ymm8, ymm2); - ymm14 = _mm256_mul_pd(ymm8, ymm14); - ymm17 = _mm256_hsub_pd(ymm13, ymm14); - //Step 2 - ymm16 = _mm256_mul_pd(ymm16, ymm15); - ymm17 = _mm256_mul_pd(ymm17, ymm15); - - //Step 3 - ymm3 = _mm256_add_pd(ymm16, ymm3); - ymm4 = _mm256_add_pd(ymm17, ymm4); + //extract a11 + ymm1 = _mm256_broadcast_pd((__m128d const *) + (d11_pack + 1)); + //(ROW1): FMA operations + ymm2 = _mm256_broadcast_pd((__m128d const *) + (a11 + rs_a*1)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + /* Step1 dcomplex multiply ymm2, ymm3 + * Step2 negate the result + * Step3 add ymmx*/ + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate 
the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm3 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm3, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm3, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + + //For ymm4 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + ymm13 = _mm256_mul_pd(ymm4, ymm2); + ymm14 = _mm256_mul_pd(ymm4, ymm14); + ymm17 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + ymm17 = _mm256_mul_pd(ymm17, ymm15); + + //Step 3 + ymm5 = _mm256_add_pd(ymm16, ymm5); + ymm6 = _mm256_add_pd(ymm17, ymm6); + + ymm2 = _mm256_broadcast_pd((__m128d const *) + (a11 + rs_a*2)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm3 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm3, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm3, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //For ymm4 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + ymm13 = _mm256_mul_pd(ymm4, ymm2); + ymm14 = _mm256_mul_pd(ymm4, ymm14); + ymm17 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + ymm17 = _mm256_mul_pd(ymm17, ymm15); + + //Step 3 + ymm7 = _mm256_add_pd(ymm16, ymm7); + ymm8 = _mm256_add_pd(ymm17, ymm8); #ifndef BLIS_ENABLE_TRSM_PREINVERSION - /*performs dcomplex divison of ymm5 and ymm6 with ymm1*/ - BLIS_ZTRSM_TWO_DIV(ymm5,ymm6) + /*performs dcomplex divison of ymm5 and ymm6 with ymm1*/ + BLIS_ZTRSM_TWO_DIV(ymm5,ymm6) #else - /*performs dcomplex multiplication of ymm5 and ymm6 with ymm1*/ - BLIS_ZTRSM_MUL(ymm5) - BLIS_ZTRSM_MUL(ymm6) + /*performs dcomplex multiplication of ymm5 and ymm6 with + * ymm1*/ + BLIS_ZTRSM_MUL(ymm5) + BLIS_ZTRSM_MUL(ymm6) #endif - //extract a22 - ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); - - //(ROW2): FMA operations - ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a)); - if(conjtransa) - { - ymm2 = _mm256_mul_pd(ymm2, ymm0); - } - //Step 1 - ymm14 = _mm256_permute_pd(ymm2, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); - - //For ymm5 - /* Multiply vec1 and vec2 */ - ymm13 = _mm256_mul_pd(ymm5, ymm2); /*vec3*/ - /* Multiply vec1 and the modified vec2 */ - ymm14 = _mm256_mul_pd(ymm5, ymm14); /*vec4*/ - /* Horizontally subtract the elements in vec3 and vec4 */ - ymm16 = _mm256_hsub_pd(ymm13, ymm14); - //For ymm6 - ymm14 = _mm256_permute_pd(ymm2, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); - - ymm13 = _mm256_mul_pd(ymm6, ymm2); - ymm14 = _mm256_mul_pd(ymm6, ymm14); - ymm17 = _mm256_hsub_pd(ymm13, ymm14); - //Step 2 - ymm16 = _mm256_mul_pd(ymm16, ymm15); - ymm17 = _mm256_mul_pd(ymm17, ymm15); - //Step 3 - ymm3 = _mm256_add_pd(ymm16, ymm3); - ymm4 = _mm256_add_pd(ymm17, ymm4); + a11 += cs_a; + + //extract a22 + ymm1 = _mm256_broadcast_pd((__m128d const *) + (d11_pack + 2)); + //(ROW2): FMA operations + ymm2 = _mm256_broadcast_pd((__m128d const *) + (a11 + rs_a * 2)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + 
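/*
 * What the conjtransa branch above does, in scalar form (sketch): ymm0 holds
 * (1,-1,1,-1), so multiplying the broadcast A element by it negates the
 * imaginary part in both 128-bit lanes, i.e. conjugates the element before it
 * enters the multiply.  Helper name is illustrative, not from this patch.
 */
#include <complex.h>

static inline double complex ztrsm_maybe_conj_sketch(double complex a,
                                                     bool conjtransa)
{
    return conjtransa ? conj(a) : a;   /* negate imaginary part if needed */
}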
//Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + //For ymm5 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm5, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm5, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //For ymm6 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + ymm13 = _mm256_mul_pd(ymm6, ymm2); + ymm14 = _mm256_mul_pd(ymm6, ymm14); + ymm17 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + ymm17 = _mm256_mul_pd(ymm17, ymm15); + //Step 3 + ymm7 = _mm256_add_pd(ymm16, ymm7); + ymm8 = _mm256_add_pd(ymm17, ymm8); #ifndef BLIS_ENABLE_TRSM_PREINVERSION - /*performs dcomplex divison of ymm3 and ymm4 with ymm1*/ - BLIS_ZTRSM_TWO_DIV(ymm3,ymm4) -#else - /*performs dcomplex multiplication of ymm3 and ymm4 with ymm1*/ - BLIS_ZTRSM_MUL(ymm3) - BLIS_ZTRSM_MUL(ymm4) -#endif - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + 2), ymm4); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b + 2), ymm6); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*2 + 2), ymm8); - - } - dim_t m_remainder = i + d_mr; - if(m_remainder) - { - if(3 == m_remainder) - { - a01 = D_A_pack; - a11 = L + (j*cs_a) + (j*rs_a); - b10 = B + (j+d_nr)*cs_b + (m_remainder - 3); - b11 = B + (m_remainder - 3) + (j*cs_b); - k_iter = (n-j-d_nr); - /*Fill zeros into ymm registers used in gemm - * accumulations */ - BLIS_SET_YMM_REG_ZEROS - /* - Peform GEMM between a01 and b10 blocks - For first itteration there will be no GEMM operation - where k_iter are zero - */ - - BLIS_ZTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) - - /* - Load b11 multiply with alpha - Add the GEMM output to b11 - and peform TRSM operation. - */ - - BLIS_PRE_ZTRSM_SMALL_3x4(AlphaVal,b11,cs_b) - ///implement TRSM/// - /* - Compute 3x3 TRSM block by using GEMM block output in - register - a. The 4x3 input (gemm outputs) are stored in - combinations of ymm registers - 1. ymm7, ymm8 2. ymm5, ymm6 3. ymm3, ymm4 - b. 
Towards the end do in regiser transpose of TRSM - output and store in b11 - */ - ////extract a00 - ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - ymm1 = _mm256_broadcast_pd((__m128d const *) - (d11_pack + 2)); -#ifndef BLIS_ENABLE_TRSM_PREINVERSION - /*performs dcomplex divison of ymm7 and ymm8 with ymm1*/ - BLIS_ZTRSM_TWO_DIV(ymm7,ymm8) + /*performs dcomplex divison of ymm7 and ymm8 with ymm1*/ + BLIS_ZTRSM_TWO_DIV(ymm7,ymm8) #else - /*performs dcomplex multiplication of ymm7 and - * ymm8 with ymm1*/ - BLIS_ZTRSM_MUL(ymm7) - BLIS_ZTRSM_MUL(ymm8) + /*performs dcomplex multiplication of ymm7 and ymm8 + * with ymm1*/ + BLIS_ZTRSM_MUL(ymm7) + BLIS_ZTRSM_MUL(ymm8) #endif - //extract a11 - ymm1 = _mm256_broadcast_pd((__m128d const *) - (d11_pack + 1)); - //(ROW1): FMA operations - ymm2 = _mm256_broadcast_pd((__m128d const *) - (a11 + cs_a*2 + rs_a*1)); - if(conjtransa) - { - ymm2 = _mm256_mul_pd(ymm2, ymm0); - } - /* Step1 dcomplex multiply ymm2, ymm7 - * Step2 negate the result - * Step3 add ymmx*/ - //Step 1 - ymm14 = _mm256_permute_pd(ymm2, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); - //For ymm7 - /* Multiply vec1 and vec2 */ - ymm13 = _mm256_mul_pd(ymm7, ymm2); /*vec3*/ - /* Multiply vec1 and the modified vec2 */ - ymm14 = _mm256_mul_pd(ymm7, ymm14); /*vec4*/ - /* Horizontally subtract the elements in vec3 and vec4 */ - ymm16 = _mm256_hsub_pd(ymm13, ymm14); - - //For ymm8 - ymm14 = _mm256_permute_pd(ymm2, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); - - ymm13 = _mm256_mul_pd(ymm8, ymm2); - ymm14 = _mm256_mul_pd(ymm8, ymm14); - ymm17 = _mm256_hsub_pd(ymm13, ymm14); - //Step 2 - ymm16 = _mm256_mul_pd(ymm16, ymm15); - ymm17 = _mm256_mul_pd(ymm17, ymm15); - - //Step 3 - ymm5 = _mm256_add_pd(ymm16, ymm5); - ymm6 = _mm256_add_pd(ymm17, ymm6); - - ymm2 = _mm256_broadcast_pd((__m128d const *) - (a11 + cs_a*2)); - if(conjtransa) - { - ymm2 = _mm256_mul_pd(ymm2, ymm0); - } - //Step 1 - ymm14 = _mm256_permute_pd(ymm2, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); - //For ymm7 - /* Multiply vec1 and vec2 */ - ymm13 = _mm256_mul_pd(ymm7, ymm2); /*vec3*/ - /* Multiply vec1 and the modified vec2 */ - ymm14 = _mm256_mul_pd(ymm7, ymm14); /*vec4*/ - /* Horizontally subtract the elements in vec3 and vec4 */ - ymm16 = _mm256_hsub_pd(ymm13, ymm14); - //For ymm8 - ymm14 = _mm256_permute_pd(ymm2, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); - - ymm13 = _mm256_mul_pd(ymm8, ymm2); - ymm14 = _mm256_mul_pd(ymm8, ymm14); - ymm17 = _mm256_hsub_pd(ymm13, ymm14); - //Step 2 - ymm16 = _mm256_mul_pd(ymm16, ymm15); - ymm17 = _mm256_mul_pd(ymm17, ymm15); - - //Step 3 - ymm3 = _mm256_add_pd(ymm16, ymm3); - ymm4 = _mm256_add_pd(ymm17, ymm4); + _mm256_storeu_pd((double *)b11, ymm3); + _mm_storeu_pd((double *)(b11 + 2), + _mm256_extractf128_pd(ymm4,0)); -#ifndef BLIS_ENABLE_TRSM_PREINVERSION - /*performs dcomplex divison of ymm5 and ymm6 with ymm1*/ - BLIS_ZTRSM_TWO_DIV(ymm5,ymm6) -#else - /*performs dcomplex multiplication of ymm5 and - * ymm6 with ymm1*/ - BLIS_ZTRSM_MUL(ymm5) - BLIS_ZTRSM_MUL(ymm6) -#endif - //extract a22 - ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); - - //(ROW2): FMA operations - ymm2 = _mm256_broadcast_pd((__m128d const *) - (a11 + cs_a)); - if(conjtransa) - { - ymm2 = _mm256_mul_pd(ymm2, ymm0); - } - //Step 1 - ymm14 = _mm256_permute_pd(ymm2, 0x5); - /* Negate the 
imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); - - //For ymm5 - /* Multiply vec1 and vec2 */ - ymm13 = _mm256_mul_pd(ymm5, ymm2); /*vec3*/ - /* Multiply vec1 and the modified vec2 */ - ymm14 = _mm256_mul_pd(ymm5, ymm14); /*vec4*/ - /* Horizontally subtract the elements in vec3 and vec4 */ - ymm16 = _mm256_hsub_pd(ymm13, ymm14); - //For ymm6 - ymm14 = _mm256_permute_pd(ymm2, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); - - ymm13 = _mm256_mul_pd(ymm6, ymm2); - ymm14 = _mm256_mul_pd(ymm6, ymm14); - ymm17 = _mm256_hsub_pd(ymm13, ymm14); - //Step 2 - ymm16 = _mm256_mul_pd(ymm16, ymm15); - ymm17 = _mm256_mul_pd(ymm17, ymm15); - //Step 3 - ymm3 = _mm256_add_pd(ymm16, ymm3); - ymm4 = _mm256_add_pd(ymm17, ymm4); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm_storeu_pd((double *)(b11 + cs_b + 2), + _mm256_extractf128_pd(ymm6,0)); + + _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); + _mm_storeu_pd((double *)(b11 + cs_b*2 + 2), + _mm256_extractf128_pd(ymm8,0)); + + m_remainder -= 3; + i += 3; + } + else if(m_remainder == 2) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; + b10 = B + i; + b11 = B + i + j*cs_b; + + k_iter = j; + + /*Fill zeros into ymm registers used in gemm + * accumulations */ + BLIS_SET_YMM_REG_ZEROS + ///GEMM implementation starts/// + BLIS_ZTRSM_SMALL_GEMM_3nx2m(a01,b10,cs_b,p_lda,k_iter) + // Load b11 of size 4x6 and multiply with alpha + BLIS_PRE_ZTRSM_SMALL_3x2(AlphaVal,b11,cs_b) + ////extract a00 + ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); #ifndef BLIS_ENABLE_TRSM_PREINVERSION - /*performs dcomplex divison of ymm3 and ymm4 with ymm1*/ - BLIS_ZTRSM_TWO_DIV(ymm3,ymm4) + /*performs dcomplex divison of ymm3 with ymm1*/ + BLIS_ZTRSM_DIV(ymm3) #else - /*performs dcomplex multiplication of ymm3 and - * ymm4 with ymm1*/ - BLIS_ZTRSM_MUL(ymm3) - BLIS_ZTRSM_MUL(ymm4) + /*performs dcomplex multiplication of ymm3 + * with ymm1*/ + BLIS_ZTRSM_MUL(ymm3) #endif - _mm256_storeu_pd((double *)b11, ymm3); - _mm_storeu_pd((double *)(b11 + 2), - _mm256_extractf128_pd(ymm4,0)); - - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm_storeu_pd((double *)(b11 + cs_b + 2), - _mm256_extractf128_pd(ymm6,0)); - - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm_storeu_pd((double *)(b11 + cs_b*2 + 2), - _mm256_extractf128_pd(ymm8,0)); - m_remainder -=3; - } - else if(2 == m_remainder) - { - a01 = D_A_pack; - a11 = L + (j*cs_a) + (j*rs_a); - b10 = B + (j+d_nr)*cs_b + (m_remainder - 2); - b11 = B + (m_remainder - 2) + (j*cs_b); - k_iter = (n-j-d_nr); - /*Fill zeros into ymm registers used in gemm - * accumulations */ - BLIS_SET_YMM_REG_ZEROS - /* - Peform GEMM between a01 and b10 blocks - For first itteration there will be no GEMM operation - where k_iter are zero - */ - - BLIS_ZTRSM_SMALL_GEMM_3nx2m(a01,b10,cs_b,p_lda,k_iter) - - /* - Load b11 of size 8x6 and multiply with alpha - Add the GEMM output to b11 - and peform TRSM operation. - */ - - BLIS_PRE_ZTRSM_SMALL_3x2(AlphaVal,b11,cs_b) - ///implement TRSM/// - /* - Compute 3x3 TRSM block by using GEMM block output - in register - a. The 4x3 input (gemm outputs) are stored in - combinations of ymm registers - 1. ymm8, ymm11 2. ymm9, ymm12 3. ymm10, ymm13 - b. 
Towards the end do in regiser transpose of TRSM - output and store in b11 - */ - ////extract a00 - ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - ymm1 = _mm256_broadcast_pd((__m128d const *) - (d11_pack + 2)); + //extract a11 + ymm1 = _mm256_broadcast_pd((__m128d const *) + (d11_pack + 1)); + //(ROW1): FMA operations + ymm2 = _mm256_broadcast_pd((__m128d const *) + (a11 + rs_a*1)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + /* Step1 dcomplex multiply ymm2, ymm3 + * Step2 negate the result + * Step3 add ymmx*/ + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm3 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm3, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm3, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + + //Step 3 + ymm5 = _mm256_add_pd(ymm16, ymm5); + + ymm2 = _mm256_broadcast_pd((__m128d const *) + (a11 + rs_a*2)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm3 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm3, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm3, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + + //Step 3 + ymm7 = _mm256_add_pd(ymm16, ymm7); + #ifndef BLIS_ENABLE_TRSM_PREINVERSION - /*performs dcomplex divison of ymm7 with ymm1*/ - BLIS_ZTRSM_DIV(ymm7) + /*performs dcomplex divison of ymm5 with ymm1*/ + BLIS_ZTRSM_DIV(ymm5) #else - /*performs dcomplex multiplication of ymm7 with ymm1*/ - BLIS_ZTRSM_MUL(ymm7) + /*performs dcomplex multiplication of ymm5 + * with ymm1*/ + BLIS_ZTRSM_MUL(ymm5) #endif - //extract a11 - ymm1 = _mm256_broadcast_pd((__m128d const *) - (d11_pack + 1)); - //(ROW1): FMA operations - ymm2 = _mm256_broadcast_pd((__m128d const *) - (a11 + cs_a*2 + rs_a*1)); - if(conjtransa) - { - ymm2 = _mm256_mul_pd(ymm2, ymm0); - } - /* Step1 dcomplex multiply ymm2, ymm7 - * Step2 negate the result - * Step3 add ymmx*/ - //Step 1 - ymm14 = _mm256_permute_pd(ymm2, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); - //For ymm7 - /* Multiply vec1 and vec2 */ - ymm13 = _mm256_mul_pd(ymm7, ymm2); /*vec3*/ - /* Multiply vec1 and the modified vec2 */ - ymm14 = _mm256_mul_pd(ymm7, ymm14); /*vec4*/ - /* Horizontally subtract the elements in vec3 and vec4 */ - ymm16 = _mm256_hsub_pd(ymm13, ymm14); - - //Step 2 - ymm16 = _mm256_mul_pd(ymm16, ymm15); - - //Step 3 - ymm5 = _mm256_add_pd(ymm16, ymm5); - - ymm2 = _mm256_broadcast_pd((__m128d const *) - (a11 + cs_a*2)); - if(conjtransa) - { - ymm2 = _mm256_mul_pd(ymm2, ymm0); - } - //Step 1 - ymm14 = _mm256_permute_pd(ymm2, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); - //For ymm7 - /* Multiply vec1 and vec2 */ - ymm13 = _mm256_mul_pd(ymm7, ymm2); /*vec3*/ - /* Multiply vec1 and the modified vec2 */ - ymm14 = _mm256_mul_pd(ymm7, ymm14); /*vec4*/ - /* Horizontally subtract the elements in vec3 and vec4 */ - ymm16 = _mm256_hsub_pd(ymm13, ymm14); - //Step 2 - ymm16 = _mm256_mul_pd(ymm16, 
ymm15); - - //Step 3 - ymm3 = _mm256_add_pd(ymm16, ymm3); + a11 += cs_a; + //extract a22 + ymm1 = _mm256_broadcast_pd((__m128d const *) + (d11_pack + 2)); + //(ROW2): FMA operations + ymm2 = _mm256_broadcast_pd((__m128d const *) + (a11 + rs_a * 2)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + //For ymm5 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm5, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm5, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + //Step 3 + ymm7 = _mm256_add_pd(ymm16, ymm7); #ifndef BLIS_ENABLE_TRSM_PREINVERSION - /*performs dcomplex divison of ymm5 with ymm1*/ - BLIS_ZTRSM_DIV(ymm5) + /*performs dcomplex divison of ymm7 with ymm1*/ + BLIS_ZTRSM_DIV(ymm7) #else - /*performs dcomplex multiplication of ymm5 with ymm1*/ - BLIS_ZTRSM_MUL(ymm5) + /*performs dcomplex multiplication of ymm7 + * with ymm1*/ + BLIS_ZTRSM_MUL(ymm7) #endif - //extract a22 - ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); - - //(ROW2): FMA operations - ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a)); - if(conjtransa) - { - ymm2 = _mm256_mul_pd(ymm2, ymm0); - } - //Step 1 - ymm14 = _mm256_permute_pd(ymm2, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); - - //For ymm5 - /* Multiply vec1 and vec2 */ - ymm13 = _mm256_mul_pd(ymm5, ymm2); /*vec3*/ - /* Multiply vec1 and the modified vec2 */ - ymm14 = _mm256_mul_pd(ymm5, ymm14); /*vec4*/ - /* Horizontally subtract the elements in vec3 and vec4 */ - ymm16 = _mm256_hsub_pd(ymm13, ymm14); - //Step 2 - ymm16 = _mm256_mul_pd(ymm16, ymm15); - //Step 3 - ymm3 = _mm256_add_pd(ymm16, ymm3); + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); + m_remainder -= 2; + i += 2; + } + else if(m_remainder == 1) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; + b10 = B + i; + b11 = B + i + j*cs_b; + + k_iter = j; + + /*Fill zeros into ymm registers used in gemm + * accumulations */ + BLIS_SET_YMM_REG_ZEROS + ///GEMM implementation starts/// + BLIS_ZTRSM_SMALL_GEMM_3nx2m(a01,b10,cs_b,p_lda,k_iter) + + // Load b11 of size 2x3 and multiply with alpha + BLIS_PRE_ZTRSM_SMALL_3x2(AlphaVal,b11,cs_b) + ///implement TRSM/// + ////extract a00 + ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); #ifndef BLIS_ENABLE_TRSM_PREINVERSION - /*performs dcomplex divison of ymm3 with ymm1*/ - BLIS_ZTRSM_DIV(ymm3) + /*performs dcomplex divison of ymm3 with ymm1*/ + BLIS_ZTRSM_DIV(ymm3) #else - /*performs dcomplex multiplication of ymm3 with ymm1*/ - BLIS_ZTRSM_MUL(ymm3) + /*performs dcomplex multiplication of ymm3 + * with ymm1*/ + BLIS_ZTRSM_MUL(ymm3) #endif - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - m_remainder -=2; - } - else if(1 == m_remainder) - { - a01 = D_A_pack; - a11 = L + (j*cs_a) + (j*rs_a); - b10 = B + (j+d_nr)*cs_b + (m_remainder - 1); - b11 = B + (m_remainder - 1) + (j*cs_b); - k_iter = (n-j-d_nr); - /*Fill zeros into ymm registers used in gemm - * accumulations */ - BLIS_SET_YMM_REG_ZEROS - /* - Peform 
GEMM between a01 and b10 blocks - For first itteration there will be no GEMM operation - where k_iter are zero - */ - - BLIS_ZTRSM_SMALL_GEMM_3nx2m(a01,b10,cs_b,p_lda,k_iter) - - /* - Load b11 and multiply with alpha - Add the GEMM output to b11 - and peform TRSM operation. - */ - - BLIS_PRE_ZTRSM_SMALL_3x2(AlphaVal,b11,cs_b) - ///implement TRSM/// - /* - Compute 3x3 TRSM block by using GEMM block output - in register - a. The 4x3 input (gemm outputs) are stored in - combinations of ymm registers - 1. ymm7, ymm8 2. ymm5, ymm6 3. ymm3, ymm4 - b. Towards the end do in regiser transpose of TRSM - output and store in - b11 - */ - ////extract a00 - ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - ymm1 = _mm256_broadcast_pd((__m128d const *) - (d11_pack + 2)); + //extract a11 + ymm1 = _mm256_broadcast_pd((__m128d const *) + (d11_pack + 1)); + //(ROW1): FMA operations + ymm2 = _mm256_broadcast_pd((__m128d const *) + (a11 + rs_a*1)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + /* Step1 dcomplex multiply ymm2, ymm3 + * Step2 negate the result + * Step3 add ymmx*/ + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm3 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm3, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm3, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + + //Step 3 + ymm5 = _mm256_add_pd(ymm16, ymm5); + + ymm2 = _mm256_broadcast_pd((__m128d const *) + (a11 + rs_a*2)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm3 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm3, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm3, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + + //Step 3 + ymm7 = _mm256_add_pd(ymm16, ymm7); + #ifndef BLIS_ENABLE_TRSM_PREINVERSION - /*performs dcomplex divison of ymm7 with ymm1*/ - BLIS_ZTRSM_DIV(ymm7) + /*performs dcomplex divison of ymm5 with ymm1*/ + BLIS_ZTRSM_DIV(ymm5) #else - /*performs dcomplex multiplication of ymm7 with ymm1*/ - BLIS_ZTRSM_MUL(ymm7) + /*performs dcomplex multiplication of ymm5 + * with ymm1*/ + BLIS_ZTRSM_MUL(ymm5) #endif - //extract a11 - ymm1 = _mm256_broadcast_pd((__m128d const *) - (d11_pack + 1)); - //(ROW1): FMA operations - ymm2 = _mm256_broadcast_pd((__m128d const *) - (a11 + cs_a*2 + rs_a*1)); - - if(conjtransa) - { - ymm2 = _mm256_mul_pd(ymm2, ymm0); - } - /* Step1 dcomplex multiply ymm2, ymm7 - * Step2 negate the result - * Step3 add ymmx*/ - //Step 1 - ymm14 = _mm256_permute_pd(ymm2, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); - //For ymm7 - /* Multiply vec1 and vec2 */ - ymm13 = _mm256_mul_pd(ymm7, ymm2); /*vec3*/ - /* Multiply vec1 and the modified vec2 */ - ymm14 = _mm256_mul_pd(ymm7, ymm14); /*vec4*/ - /* Horizontally subtract the elements in vec3 and vec4 */ - ymm16 = _mm256_hsub_pd(ymm13, ymm14); - - //Step 2 - ymm16 = _mm256_mul_pd(ymm16, ymm15); - - //Step 3 - ymm5 = _mm256_add_pd(ymm16, ymm5); - - ymm2 = 
_mm256_broadcast_pd((__m128d const *) - (a11 + cs_a*2)); - if(conjtransa) - { - ymm2 = _mm256_mul_pd(ymm2, ymm0); - } - //Step 1 - ymm14 = _mm256_permute_pd(ymm2, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); - //For ymm7 - /* Multiply vec1 and vec2 */ - ymm13 = _mm256_mul_pd(ymm7, ymm2); /*vec3*/ - /* Multiply vec1 and the modified vec2 */ - ymm14 = _mm256_mul_pd(ymm7, ymm14); /*vec4*/ - /* Horizontally subtract the elements in vec3 and vec4 */ - ymm16 = _mm256_hsub_pd(ymm13, ymm14); - //Step 2 - ymm16 = _mm256_mul_pd(ymm16, ymm15); - - //Step 3 - ymm3 = _mm256_add_pd(ymm16, ymm3); + a11 += cs_a; + //extract a22 + ymm1 = _mm256_broadcast_pd((__m128d const *) + (d11_pack + 2)); + //(ROW2): FMA operations + ymm2 = _mm256_broadcast_pd((__m128d const *) + (a11 + rs_a * 2)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + + //For ymm5 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm5, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm5, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + //Step 3 + ymm7 = _mm256_add_pd(ymm16, ymm7); #ifndef BLIS_ENABLE_TRSM_PREINVERSION - /*performs dcomplex divison of ymm5 with ymm1*/ - BLIS_ZTRSM_DIV(ymm5) + /*performs dcomplex divison of ymm7 with ymm1*/ + BLIS_ZTRSM_DIV(ymm7) #else - /*performs dcomplex multiplication of ymm5 with ymm1*/ - BLIS_ZTRSM_MUL(ymm5) + /*performs dcomplex multiplication of ymm7 + * with ymm1*/ + BLIS_ZTRSM_MUL(ymm7) #endif - //extract a22 - ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); - - //(ROW2): FMA operations - ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a)); - if(conjtransa) - { - ymm2 = _mm256_mul_pd(ymm2, ymm0); - } - //Step 1 - ymm14 = _mm256_permute_pd(ymm2, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); - - //For ymm5 - /* Multiply vec1 and vec2 */ - ymm13 = _mm256_mul_pd(ymm5, ymm2); /*vec3*/ - /* Multiply vec1 and the modified vec2 */ - ymm14 = _mm256_mul_pd(ymm5, ymm14); /*vec4*/ - /* Horizontally subtract the elements in vec3 and vec4 */ - ymm16 = _mm256_hsub_pd(ymm13, ymm14); - //Step 2 - ymm16 = _mm256_mul_pd(ymm16, ymm15); - //Step 3 - ymm3 = _mm256_add_pd(ymm16, ymm3); -#ifndef BLIS_ENABLE_TRSM_PREINVERSION - /*performs dcomplex divison of ymm3 with ymm1*/ - BLIS_ZTRSM_DIV(ymm3) -#else - /*performs dcomplex multiplication of ymm3 and with ymm1*/ - BLIS_ZTRSM_MUL(ymm3) -#endif - _mm_storeu_pd((double *)b11, - _mm256_extractf128_pd(ymm3,0)); - _mm_storeu_pd((double *)(b11 + cs_b), - _mm256_extractf128_pd(ymm5,0)); - _mm_storeu_pd((double *)(b11 + cs_b*2), - _mm256_extractf128_pd(ymm7,0)); - m_remainder -=1; - } - } - - } - dim_t n_remainder = j + d_nr; - if(n_remainder == 2) - { - a01 = L + (n_remainder - 2)*rs_a + n_remainder*cs_a; - a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2)*rs_a; - - dcomplex *ptr_a10_dup = D_A_pack; - - dim_t p_lda = (n-n_remainder); - - if(transa) - { - for(dim_t x =0;x < p_lda;x+=d_nr) - { - ymm0 = _mm256_loadu_pd((double const *)(a01)); - ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); - ymm3 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm4 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - - _mm256_storeu_pd((double *)(ptr_a10_dup), ymm3); - _mm256_storeu_pd((double 
*)(ptr_a10_dup + p_lda), ymm4); - ymm0 = _mm256_loadu_pd((double const *)(a01 + 2)); - ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a + 2)); - ymm3 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 2), - ymm3); - - ymm0 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2)); - ymm1 = _mm256_loadu_pd((double const *) - (a01 + cs_a * 2 + 2)); - ymm5 = _mm256_broadcast_pd((__m128d const *)&zero); - - ymm3 = _mm256_permute2f128_pd(ymm0,ymm5,0x20); - ymm4 = _mm256_permute2f128_pd(ymm0,ymm5,0x31); - ymm5 = _mm256_permute2f128_pd(ymm1,ymm5,0x20); - - _mm_storeu_pd((double *)(ptr_a10_dup + 2), - _mm256_extractf128_pd(ymm3,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + p_lda + 2), - _mm256_extractf128_pd(ymm4,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 2 + 2), - _mm256_extractf128_pd(ymm5, 0)); - a01 += d_nr*cs_a; - ptr_a10_dup += d_nr; - } - } - else - { - dim_t loop_count = (n-n_remainder)/2; - - for(dim_t x =0;x < loop_count;x++) - { - ymm15 = _mm256_loadu_pd((double const *) - (a01 + rs_a * 0 + x*2)); - _mm256_storeu_pd((double *) - (ptr_a10_dup + p_lda * 0 + x*2), ymm15); - ymm15 = _mm256_loadu_pd((double const *) - (a01 + rs_a * 1 + x*2)); - _mm256_storeu_pd((double *) - (ptr_a10_dup + p_lda * 1 + x*2), ymm15); - } - - dim_t remainder_loop_count = p_lda - loop_count*2; - - __m128d xmm0; - if(remainder_loop_count != 0) - { - xmm0 = _mm_loadu_pd((double const *) - (a01 + rs_a * 0 + loop_count*2)); - _mm_storeu_pd((double *) - (ptr_a10_dup + p_lda * 0 + loop_count*2), - xmm0); - xmm0 = _mm_loadu_pd((double const *) - (a01 + rs_a * 1 + loop_count*2)); - _mm_storeu_pd((double *) - (ptr_a10_dup + p_lda * 1 + loop_count*2), - xmm0); - } - } - if(!is_unitdiag) - { - if(transa) - { - ymm0 = _mm256_broadcast_pd((__m128d const *)(a11)); - ymm1 = _mm256_broadcast_pd((__m128d const *) - (a11+cs_a*1 + 1)); - } - else - { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_pd((__m128d const *)(a11)); - ymm1 = _mm256_broadcast_pd((__m128d const *) - (a11+rs_a*1 + 1)); - } - ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); + _mm_storeu_pd((double *)b11, + _mm256_extractf128_pd(ymm3,0)); + _mm_storeu_pd((double *)(b11 + cs_b), + _mm256_extractf128_pd(ymm5,0)); + _mm_storeu_pd((double *)(b11 + cs_b*2), + _mm256_extractf128_pd(ymm7,0)); + + m_remainder -= 1; + i += 1; + } + } + + } + dim_t n_remainder = n - j; + if(n_remainder == 2) + { + a01 = L + j*rs_a; + a11 = L + j*cs_a + j*rs_a; + dcomplex *ptr_a10_dup = D_A_pack; + + dim_t p_lda = j; + // perform copy of A to packed buffer D_A_pack + + if(transa) + { + for(dim_t x =0;x < p_lda;x+=d_nr) + { + ymm0 = _mm256_loadu_pd((double const *)(a01)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); + ymm3 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm4 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm3); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm4); + + ymm0 = _mm256_loadu_pd((double const *)(a01 + 2)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a + 2)); + ymm3 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 2), + ymm3); + + ymm0 = _mm256_loadu_pd((double const *) + (a01 + cs_a * 2)); + ymm1 = _mm256_loadu_pd((double const *) + (a01 + cs_a * 2 + 2)); + ymm5 = _mm256_broadcast_pd((__m128d const *)&zero); + + ymm3 = _mm256_permute2f128_pd(ymm0,ymm5,0x20); + ymm4 = _mm256_permute2f128_pd(ymm0,ymm5,0x31); + ymm5 = _mm256_permute2f128_pd(ymm1,ymm5,0x20); + + _mm_storeu_pd((double *)(ptr_a10_dup + 
2), + _mm256_extractf128_pd(ymm3,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda + 2), + _mm256_extractf128_pd(ymm4,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 2 + 2), + _mm256_extractf128_pd(ymm5, 0)); + a01 += d_nr*cs_a; + ptr_a10_dup += d_nr; + } + } + else + { + dim_t loop_count = p_lda/2; + + for(dim_t x =0;x < loop_count;x++) + { + ymm15 = _mm256_loadu_pd((double const *) + (a01 + rs_a * 0 + x*2)); + _mm256_storeu_pd((double *) + (ptr_a10_dup + p_lda * 0 + x*2), ymm15); + ymm15 = _mm256_loadu_pd((double const *) + (a01 + rs_a * 1 + x*2)); + _mm256_storeu_pd((double *) + (ptr_a10_dup + p_lda * 1 + x*2), + ymm15); + } + + dim_t remainder_loop_count = p_lda - loop_count*2; + + __m128d xmm0; + if(remainder_loop_count != 0) + { + xmm0 = _mm_loadu_pd((double const *) + (a01 + rs_a * 0 + loop_count*2)); + _mm_storeu_pd((double *) + (ptr_a10_dup + p_lda * 0 + loop_count*2), + xmm0); + xmm0 = _mm_loadu_pd((double const *) + (a01 + rs_a * 1 + loop_count*2)); + _mm_storeu_pd((double *) + (ptr_a10_dup + p_lda * 1 + loop_count*2), + xmm0); + } + } + if(!is_unitdiag) + { + if(transa) + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_pd((__m128d const *)(a11)); + ymm1 = _mm256_broadcast_pd((__m128d const *) + (a11+cs_a*1 + 1)); + } + else + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_pd((__m128d const *)(a11)); + ymm1 = _mm256_broadcast_pd((__m128d const *) + (a11+rs_a*1 + 1)); + } + ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); #ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm7 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - /*Taking denomerator multiplication of real & - * imaginary components*/ - ymm4 = _mm256_mul_pd(ymm1, ymm1); - /*Swapping real & imaginary component position for addition with - * respective components*/ - ymm6 = _mm256_permute4x64_pd(ymm4, 0xb1); - ymm4 = _mm256_add_pd(ymm4, ymm6); - /*Negating imaginary component of numerator*/ - ymm1 = _mm256_mul_pd(ymm1, ymm7); - /*Dividing numerator by denominator*/ - ymm1 = _mm256_div_pd(ymm1, ymm4); + ymm7 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + /*Taking denomerator multiplication of real & + * imaginary components*/ + ymm4 = _mm256_mul_pd(ymm1, ymm1); + /*Swapping real & imaginary component position for addition with + * respective components*/ + ymm6 = _mm256_permute4x64_pd(ymm4, 0xb1); + ymm4 = _mm256_add_pd(ymm4, ymm6); + /*Negating imaginary component of numerator*/ + ymm1 = _mm256_mul_pd(ymm1, ymm7); + /*Dividing numerator by denominator*/ + ymm1 = _mm256_div_pd(ymm1, ymm4); #endif - } - else - { - ymm1 = _mm256_broadcast_pd((__m128d const*)&ones); - } - _mm256_storeu_pd((double *)(d11_pack), ymm1); - for(i = (m-d_mr); (i+1) > 0; i -= d_mr) //loop along 'M' direction - { - a01 = D_A_pack; - a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2)*rs_a; - b10 = B + i + (n_remainder)*cs_b; - b11 = B + (i) + (n_remainder - 2)*cs_b; - - k_iter = (n-n_remainder); - - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS - BLIS_ZTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) - BLIS_PRE_ZTRSM_SMALL_2x4(AlphaVal,b11,cs_b) - ///implement TRSM/// - ////extract a00 - ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 1)); + } + else + { + ymm1 = _mm256_broadcast_pd((__m128d const *)&ones); + } + _mm256_storeu_pd((double *)(d11_pack), ymm1); -#ifndef BLIS_ENABLE_TRSM_PREINVERSION - /*performs dcomplex divison of ymm5 and ymm6 with ymm1*/ - BLIS_ZTRSM_TWO_DIV(ymm5,ymm6) 
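/*
 * Editorial aside (illustrative sketch only, not part of the patch hunks):
 * the "Step 1 / Step 2 / Step 3" blocks in these kernels are the vectorized
 * form of a dcomplex multiply-and-subtract, b := b - a*x.  Step 1 builds the
 * complex product a*x with _mm256_permute_pd/_mm256_hsub_pd, Step 2 negates
 * it (ymm15 = {-1,-1,-1,-1}), and Step 3 adds it into the accumulated b11
 * value.  The helper names below are new and hypothetical; they assume only
 * the dcomplex type already used in this file.
 */
static inline dcomplex ztrsm_small_fnms_sketch( dcomplex b, dcomplex a, dcomplex x )
{
	/* b - a*x in scalar dcomplex arithmetic */
	dcomplex r;
	r.real = b.real - ( a.real * x.real - a.imag * x.imag );
	r.imag = b.imag - ( a.real * x.imag + a.imag * x.real );
	return r;
}

/*
 * Scalar counterpart of the diagonal handling around BLIS_ZTRSM_TWO_DIV /
 * BLIS_ZTRSM_MUL: when BLIS_ENABLE_TRSM_PREINVERSION is defined, d11_pack
 * holds the reciprocal of the diagonal element, so the solve step is a
 * complex multiply; otherwise it holds the diagonal itself and a complex
 * divide is performed.  Sketch only, same assumptions as above.
 */
static inline dcomplex ztrsm_small_zinvert_sketch( dcomplex a )
{
	/* 1/(ar + i*ai) = (ar - i*ai) / (ar*ar + ai*ai) */
	double   den = a.real * a.real + a.imag * a.imag;
	dcomplex inv = { a.real / den, -a.imag / den };
	return inv;
}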
-#else - /*performs dcomplex multiplication of ymm5 and ymm6 with ymm1*/ - BLIS_ZTRSM_MUL(ymm5) - BLIS_ZTRSM_MUL(ymm6) -#endif - //extract a22 - ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); - - //(ROW2): FMA operations - ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a)); - if(conjtransa) - { - ymm2 = _mm256_mul_pd(ymm2, ymm0); - } - //Step 1 - ymm14 = _mm256_permute_pd(ymm2, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); - - //For ymm5 - /* Multiply vec1 and vec2 */ - ymm13 = _mm256_mul_pd(ymm5, ymm2); /*vec3*/ - /* Multiply vec1 and the modified vec2 */ - ymm14 = _mm256_mul_pd(ymm5, ymm14); /*vec4*/ - /* Horizontally subtract the elements in vec3 and vec4 */ - ymm16 = _mm256_hsub_pd(ymm13, ymm14); - //For ymm6 - ymm14 = _mm256_permute_pd(ymm2, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); - - ymm13 = _mm256_mul_pd(ymm6, ymm2); - ymm14 = _mm256_mul_pd(ymm6, ymm14); - ymm17 = _mm256_hsub_pd(ymm13, ymm14); - //Step 2 - ymm16 = _mm256_mul_pd(ymm16, ymm15); - ymm17 = _mm256_mul_pd(ymm17, ymm15); - //Step 3 - ymm3 = _mm256_add_pd(ymm16, ymm3); - ymm4 = _mm256_add_pd(ymm17, ymm4); + for(i = 0; (i+d_mr-1) < m; i += d_mr) //loop along 'M' direction + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; + b10 = B + i; + b11 = B + i + j*cs_b; + k_iter = j; + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + ///GEMM implementation starts/// + BLIS_ZTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_PRE_ZTRSM_SMALL_2x4(AlphaVal,b11,cs_b) + ///implement TRSM/// + ////extract a00 + ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); #ifndef BLIS_ENABLE_TRSM_PREINVERSION - /*performs dcomplex divison of ymm3 and ymm4 with ymm1*/ - BLIS_ZTRSM_TWO_DIV(ymm3,ymm4) + /*performs dcomplex divison of ymm3 and ymm4 with ymm1*/ + BLIS_ZTRSM_TWO_DIV(ymm3,ymm4) #else - /*performs dcomplex multiplication of ymm3 and ymm4 with ymm1*/ - BLIS_ZTRSM_MUL(ymm3) - BLIS_ZTRSM_MUL(ymm4) + /*performs dcomplex multiplication of ymm3 and ymm4 with ymm1*/ + BLIS_ZTRSM_MUL(ymm3) + BLIS_ZTRSM_MUL(ymm4) #endif - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + 2), ymm4); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b + 2), ymm6); - - } - dim_t m_remainder = i + d_mr; - if(3 == m_remainder) - { - a01 = D_A_pack; - a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2)*rs_a; - b10 = B + (m_remainder - 3) + (n_remainder)*cs_b; - b11 = B + (m_remainder - 3) + (n_remainder - 2)*cs_b; - - k_iter = (n-n_remainder); - - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS - /* - Peform GEMM between a01 and b10 blocks - For first itteration there will be no GEMM operation - where k_iter are zero - */ - BLIS_ZTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) - - // Load b11 and multiply with alpha - BLIS_PRE_ZTRSM_SMALL_3x4(AlphaVal,b11,cs_b) - ////extract a00 - ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 1)); + //extract a11 + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 1)); + //(ROW1): FMA operations + ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + rs_a*1)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + /* Step1 dcomplex multiply ymm2, ymm3 + * Step2 negate 
the result + * Step3 add ymmx*/ + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm3 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm3, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm3, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); -#ifndef BLIS_ENABLE_TRSM_PREINVERSION - /*performs dcomplex divison of ymm5 and ymm6 with ymm1*/ - BLIS_ZTRSM_TWO_DIV(ymm5,ymm6) -#else - /*performs dcomplex multiplication of ymm5 and ymm6 with ymm1*/ - BLIS_ZTRSM_MUL(ymm5) - BLIS_ZTRSM_MUL(ymm6) -#endif - //extract a22 - ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); - - //(ROW2): FMA operations - ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a)); - if(conjtransa) - { - ymm2 = _mm256_mul_pd(ymm2, ymm0); - } - //Step 1 - ymm14 = _mm256_permute_pd(ymm2, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); - - //For ymm5 - /* Multiply vec1 and vec2 */ - ymm13 = _mm256_mul_pd(ymm5, ymm2); /*vec3*/ - /* Multiply vec1 and the modified vec2 */ - ymm14 = _mm256_mul_pd(ymm5, ymm14); /*vec4*/ - /* Horizontally subtract the elements in vec3 and vec4 */ - ymm16 = _mm256_hsub_pd(ymm13, ymm14); - //For ymm6 - ymm14 = _mm256_permute_pd(ymm2, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); - - ymm13 = _mm256_mul_pd(ymm6, ymm2); - ymm14 = _mm256_mul_pd(ymm6, ymm14); - ymm17 = _mm256_hsub_pd(ymm13, ymm14); - //Step 2 - ymm16 = _mm256_mul_pd(ymm16, ymm15); - ymm17 = _mm256_mul_pd(ymm17, ymm15); - //Step 3 - ymm3 = _mm256_add_pd(ymm16, ymm3); - ymm4 = _mm256_add_pd(ymm17, ymm4); + //For ymm4 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + ymm13 = _mm256_mul_pd(ymm4, ymm2); + ymm14 = _mm256_mul_pd(ymm4, ymm14); + ymm17 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + ymm17 = _mm256_mul_pd(ymm17, ymm15); -#ifndef BLIS_ENABLE_TRSM_PREINVERSION - /*performs dcomplex divison of ymm3 and ymm4 with ymm1*/ - BLIS_ZTRSM_TWO_DIV(ymm3,ymm4) -#else - /*performs dcomplex multiplication of ymm3 and ymm4 with ymm1*/ - BLIS_ZTRSM_MUL(ymm3) - BLIS_ZTRSM_MUL(ymm4) -#endif - _mm256_storeu_pd((double *)b11, ymm3); - _mm_storeu_pd((double *)(b11 + 2), - _mm256_extractf128_pd(ymm4,0)); - - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm_storeu_pd((double *)(b11 + cs_b + 2), - _mm256_extractf128_pd(ymm6,0)); - m_remainder -=3; - } - if(2 == m_remainder) - { - a01 = D_A_pack; - a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2)*rs_a; - b10 = B + (m_remainder - 2) + (n_remainder)*cs_b; - b11 = B + (m_remainder - 2) + (n_remainder - 2)*cs_b; - - k_iter = (n-n_remainder); - - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS - /* - Peform GEMM between a01 and b10 blocks - For first itteration there will be no GEMM operation - where k_iter are zero - */ - BLIS_ZTRSM_SMALL_GEMM_3nx2m(a01,b10,cs_b,p_lda,k_iter) - - // Load b11 and multiply with alpha - BLIS_PRE_ZTRSM_SMALL_3x2(AlphaVal,b11,cs_b) - ////extract a00 - ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 1)); + //Step 3 + ymm5 = _mm256_add_pd(ymm16, ymm5); + ymm6 = _mm256_add_pd(ymm17, ymm6); #ifndef BLIS_ENABLE_TRSM_PREINVERSION - 
/*performs dcomplex divison of ymm5 with ymm1*/ - BLIS_ZTRSM_DIV(ymm5) + /*performs dcomplex divison of ymm5 and ymm6 with ymm1*/ + BLIS_ZTRSM_TWO_DIV(ymm5,ymm6) #else - /*performs dcomplex multiplication of ymm5 with ymm1*/ - BLIS_ZTRSM_MUL(ymm5) + /*performs dcomplex multiplication of ymm5 and ymm6 with ymm1*/ + BLIS_ZTRSM_MUL(ymm5) + BLIS_ZTRSM_MUL(ymm6) #endif - //extract a22 - ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); - - //(ROW2): FMA operations - ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a)); - if(conjtransa) - { - ymm2 = _mm256_mul_pd(ymm2, ymm0); - } - //Step 1 - ymm14 = _mm256_permute_pd(ymm2, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); - - //For ymm5 - /* Multiply vec1 and vec2 */ - ymm13 = _mm256_mul_pd(ymm5, ymm2); /*vec3*/ - /* Multiply vec1 and the modified vec2 */ - ymm14 = _mm256_mul_pd(ymm5, ymm14); /*vec4*/ - /* Horizontally subtract the elements in vec3 and vec4 */ - ymm16 = _mm256_hsub_pd(ymm13, ymm14); - - //Step 2 - ymm16 = _mm256_mul_pd(ymm16, ymm15); - //Step 3 - ymm3 = _mm256_add_pd(ymm16, ymm3); + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + 2), ymm4); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm256_storeu_pd((double *)(b11 + cs_b + 2), ymm6); + } + dim_t m_remainder = m - i; + if(m_remainder == 3) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; + b10 = B + i; + b11 = B + i + j*cs_b; + k_iter = j; + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS -#ifndef BLIS_ENABLE_TRSM_PREINVERSION - /*performs dcomplex divison of ymm3 with ymm1*/ - BLIS_ZTRSM_DIV(ymm3) -#else - /*performs dcomplex multiplication of ymm3 with ymm1*/ - BLIS_ZTRSM_MUL(ymm3) -#endif - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - m_remainder -=2; - } - if(1 == m_remainder) - { - a01 = D_A_pack; - a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2)*rs_a; - b10 = B + (m_remainder - 1) + (n_remainder)*cs_b; - b11 = B + (m_remainder - 1) + (n_remainder - 2)*cs_b; - - k_iter = (n-n_remainder); - - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS - /* - Peform GEMM between a01 and b10 blocks - For first itteration there will be no GEMM operation - where k_iter are zero - */ - BLIS_ZTRSM_SMALL_GEMM_3nx2m(a01,b10,cs_b,p_lda,k_iter) - - // Load b11 and multiply with alpha - BLIS_PRE_ZTRSM_SMALL_3x2(AlphaVal,b11,cs_b) - ////extract a00 - ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 1)); + ///GEMM implementation starts/// + BLIS_ZTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) + + // Load b11 of size 4x6 and multiply with alpha + BLIS_PRE_ZTRSM_SMALL_3x4(AlphaVal,b11,cs_b) + ///implement TRSM/// + ////extract a00 + ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); #ifndef BLIS_ENABLE_TRSM_PREINVERSION - /*performs dcomplex divison of ymm5 with ymm1*/ - BLIS_ZTRSM_DIV(ymm5) + /*performs dcomplex divison of ymm3 and ymm4 with ymm1*/ + BLIS_ZTRSM_TWO_DIV(ymm3,ymm4) #else - /*performs dcomplex multiplication of ymm5 with ymm1*/ - BLIS_ZTRSM_MUL(ymm5) + /*performs dcomplex multiplication of ymm3 and ymm4 with ymm1*/ + BLIS_ZTRSM_MUL(ymm3) + BLIS_ZTRSM_MUL(ymm4) #endif - //extract a22 - ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); - - //(ROW2): FMA 
operations - ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + cs_a)); - if(conjtransa) - { - ymm2 = _mm256_mul_pd(ymm2, ymm0); - } - //Step 1 - ymm14 = _mm256_permute_pd(ymm2, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); - - //For ymm5 - /* Multiply vec1 and vec2 */ - ymm13 = _mm256_mul_pd(ymm5, ymm2); /*vec3*/ - /* Multiply vec1 and the modified vec2 */ - ymm14 = _mm256_mul_pd(ymm5, ymm14); /*vec4*/ - /* Horizontally subtract the elements in vec3 and vec4 */ - ymm16 = _mm256_hsub_pd(ymm13, ymm14); - - //Step 2 - ymm16 = _mm256_mul_pd(ymm16, ymm15); - //Step 3 - ymm3 = _mm256_add_pd(ymm16, ymm3); + //extract a11 + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 1)); + //(ROW1): FMA operations + ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + rs_a*1)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + /* Step1 dcomplex multiply ymm2, ymm3 + * Step2 negate the result + * Step3 add ymmx*/ + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm3 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm3, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm3, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); + + //For ymm4 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + ymm13 = _mm256_mul_pd(ymm4, ymm2); + ymm14 = _mm256_mul_pd(ymm4, ymm14); + ymm17 = _mm256_hsub_pd(ymm13, ymm14); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + ymm17 = _mm256_mul_pd(ymm17, ymm15); + //Step 3 + ymm5 = _mm256_add_pd(ymm16, ymm5); + ymm6 = _mm256_add_pd(ymm17, ymm6); #ifndef BLIS_ENABLE_TRSM_PREINVERSION - /*performs dcomplex divison of ymm3 with ymm1*/ - BLIS_ZTRSM_DIV(ymm3) -#else - /*performs dcomplex multiplication of ymm3 with ymm1*/ - BLIS_ZTRSM_MUL(ymm3) -#endif - _mm_storeu_pd((double *)b11, - _mm256_extractf128_pd(ymm3,0)); - _mm_storeu_pd((double *)(b11 + cs_b), - _mm256_extractf128_pd(ymm5,0)); - m_remainder -=1; - } - n_remainder -= 2; - } - else if(n_remainder == 1) - { - a01 = L + (n_remainder - 1)*rs_a + n_remainder*cs_a; - a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1)*rs_a; - - dcomplex *ptr_a10_dup = D_A_pack; - - dim_t p_lda = (n-n_remainder); // packed leading dimension - // perform copy of A to packed buffer D_A_pack - if(transa) - { - for(dim_t x =0;x < p_lda;x+=d_nr) - { - ymm0 = _mm256_loadu_pd((double const *)(a01)); - ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); - ymm3 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm4 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - - _mm256_storeu_pd((double *)(ptr_a10_dup), ymm3); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm4); - - ymm0 = _mm256_loadu_pd((double const *)(a01 + 2)); - ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a + 2)); - ymm3 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - _mm256_storeu_pd((double *) - (ptr_a10_dup + p_lda * 2), ymm3); - - ymm0 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2)); - ymm1 = _mm256_loadu_pd((double const *) - (a01 + cs_a * 2 + 2)); - ymm5 = _mm256_broadcast_pd((__m128d const *)&zero); - - ymm3 = _mm256_permute2f128_pd(ymm0,ymm5,0x20); - ymm4 = _mm256_permute2f128_pd(ymm0,ymm5,0x31); - ymm5 = _mm256_permute2f128_pd(ymm1,ymm5,0x20); - - _mm_storeu_pd((double *)(ptr_a10_dup + 2), - _mm256_extractf128_pd(ymm3,0)); - _mm_storeu_pd((double 
*)(ptr_a10_dup + p_lda + 2), - _mm256_extractf128_pd(ymm4,0)); - _mm_storeu_pd((double *) - (ptr_a10_dup + p_lda * 2 + 2), - _mm256_extractf128_pd(ymm5, 0)); - a01 += d_nr*cs_a; - ptr_a10_dup += d_nr; - } - - } - else - { - dim_t loop_count = (n-n_remainder)/2; - - for(dim_t x =0;x < loop_count;x++) - { - ymm15 = _mm256_loadu_pd((double const *) - (a01 + rs_a * 0 + x*2)); - _mm256_storeu_pd((double *) - (ptr_a10_dup + p_lda * 0 + x*2), ymm15); - } - - dim_t remainder_loop_count = p_lda - loop_count*2; - - __m128d xmm0; - if(remainder_loop_count != 0) - { - xmm0 = _mm_loadu_pd((double const *) - (a01 + rs_a * 0 + loop_count*2)); - _mm_storeu_pd((double *) - (ptr_a10_dup + p_lda * 0 + loop_count*2), - xmm0); - } - } - if(!is_unitdiag) - { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_pd((__m128d const *)(a11)); - ymm1 = _mm256_broadcast_pd((__m128d const *)&ones); - ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); -#ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm7 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - /*Taking denomerator multiplication of real & - * imaginary components*/ - ymm4 = _mm256_mul_pd(ymm1, ymm1); - /*Swapping real & imaginary component position for addition with - * respective components*/ - ymm6 = _mm256_permute4x64_pd(ymm4, 0xb1); - ymm4 = _mm256_add_pd(ymm4, ymm6); - /*Negating imaginary component of numerator*/ - ymm1 = _mm256_mul_pd(ymm1, ymm7); - /*Dividing numerator by denominator*/ - ymm1 = _mm256_div_pd(ymm1, ymm4); -#endif - } - else - { - ymm1 = _mm256_broadcast_pd((__m128d const*)&ones); - } - _mm256_storeu_pd((double *)(d11_pack), ymm1); - for(i = (m-d_mr); (i+1) > 0; i -= d_mr) //loop along 'M' direction - { - a01 = D_A_pack; - a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1)*rs_a; - b10 = B + i + (n_remainder)*cs_b; - b11 = B + (i) + (n_remainder - 1)*cs_b; - - k_iter = (n-n_remainder); - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS - ///GEMM implementation starts/// - BLIS_ZTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) - BLIS_PRE_ZTRSM_SMALL_1x4(b11,cs_b,AlphaVal) - ///implement TRSM/// - ////extract a00 - ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); -#ifndef BLIS_ENABLE_TRSM_PREINVERSION - /*performs dcomplex divison of ymm3 and ymm4 with ymm1*/ - BLIS_ZTRSM_TWO_DIV(ymm3,ymm4) + /*performs dcomplex divison of ymm5 and ymm6 with ymm1*/ + BLIS_ZTRSM_TWO_DIV(ymm5,ymm6) #else - /*performs dcomplex multiplication of ymm3 and ymm4 with ymm1*/ - BLIS_ZTRSM_MUL(ymm3) - BLIS_ZTRSM_MUL(ymm4) + /*performs dcomplex multiplication of ymm5 and ymm6 with ymm1*/ + BLIS_ZTRSM_MUL(ymm5) + BLIS_ZTRSM_MUL(ymm6) #endif - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + 2),ymm4); - - } - dim_t m_remainder = i + d_mr; - if(3 == m_remainder) - { - a01 = D_A_pack; - a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1)*rs_a; - b10 = B + (m_remainder - 3) + (n_remainder)*cs_b; - b11 = B + (m_remainder - 3) + (n_remainder - 1)*cs_b; - - k_iter = (n-n_remainder); - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS - - ///GEMM implementation starts/// - BLIS_ZTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) - BLIS_PRE_ZTRSM_SMALL_1x3(b11,cs_b,AlphaVal) - - ///implement TRSM/// - ////extract a00 - ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); + + 
_mm256_storeu_pd((double *)b11, ymm3); + _mm_storeu_pd((double *)(b11 + 2), + _mm256_extractf128_pd(ymm4,0)); + + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm_storeu_pd((double *)(b11 + cs_b + 2), + _mm256_extractf128_pd(ymm6,0)); + m_remainder -= 3; + i += 3; + } + if(m_remainder == 2) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; + b10 = B + i; + b11 = B + i + j*cs_b; + + k_iter = j; + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_ZTRSM_SMALL_GEMM_2nx2m(a01,b10,cs_b,p_lda,k_iter) + + // Load b11 of size 4x6 and multiply with alpha + BLIS_PRE_ZTRSM_SMALL_2x2(AlphaVal,b11,cs_b) + + ///implement TRSM/// + ////extract a00 + ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); #ifndef BLIS_ENABLE_TRSM_PREINVERSION - /*performs dcomplex divison of ymm3 and ymm4 with ymm1*/ - BLIS_ZTRSM_TWO_DIV(ymm3,ymm4) + /*performs dcomplex divison of ymm3 with ymm1*/ + BLIS_ZTRSM_DIV(ymm3) #else - /*performs dcomplex multiplication of ymm3 and ymm4 with ymm1*/ - BLIS_ZTRSM_MUL(ymm3) - BLIS_ZTRSM_MUL(ymm4) + /*performs dcomplex multiplication of ymm3 with ymm1*/ + BLIS_ZTRSM_MUL(ymm3) #endif + //extract a11 + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 1)); + //(ROW1): FMA operations + ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + rs_a*1)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + /* Step1 dcomplex multiply ymm2, ymm3 + * Step2 negate the result + * Step3 add ymmx*/ + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm3 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm3, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm3, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); - _mm256_storeu_pd((double *)b11, ymm3); - _mm_storeu_pd((double *)(b11 + 2), - _mm256_extractf128_pd(ymm4,0)); - m_remainder -=3; - - } - else if(2 == m_remainder) - { - a01 = D_A_pack; - a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1)*rs_a; - b10 = B + (m_remainder - 2) + (n_remainder)*cs_b; - b11 = B + (m_remainder - 2) + (n_remainder - 1)*cs_b; - - k_iter = (n-n_remainder); - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS - - ///GEMM implementation starts/// - BLIS_ZTRSM_SMALL_GEMM_1nx2m(a01,b10,cs_b,p_lda,k_iter) - - // Load b11 of size 2x1 and multiply with alpha - BLIS_PRE_ZTRSM_SMALL_1x2(AlphaVal,b11,cs_b) - - ///implement TRSM/// - ////extract a00 - ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); + + //Step 3 + ymm5 = _mm256_add_pd(ymm16, ymm5); #ifndef BLIS_ENABLE_TRSM_PREINVERSION - /*performs dcomplex divison of ymm3 with ymm1*/ - BLIS_ZTRSM_DIV(ymm3) + /*performs dcomplex divison of ymm5 with ymm1*/ + BLIS_ZTRSM_DIV(ymm5) #else - /*performs dcomplex multiplication of ymm3 with ymm1*/ - BLIS_ZTRSM_MUL(ymm3) + /*performs dcomplex multiplication of ymm5 with ymm1*/ + BLIS_ZTRSM_MUL(ymm5) #endif - _mm256_storeu_pd((double *)b11, ymm3); - m_remainder -=2; - - } - else if (1 == m_remainder) - { - a01 = D_A_pack; - a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1)*rs_a; - b10 = B + (m_remainder - 
1) + (n_remainder)*cs_b; - b11 = B + (m_remainder - 1) + (n_remainder - 1)*cs_b; + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + m_remainder -= 2; + i += 2; + } + if(m_remainder == 1) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; + b10 = B + i; + b11 = B + i + j*cs_b; - k_iter = (n-n_remainder); - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS + k_iter = j; + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS - ///GEMM implementation starts/// - BLIS_ZTRSM_SMALL_GEMM_1nx1m(a01,b10,cs_b,p_lda,k_iter) + ///GEMM implementation starts/// + BLIS_ZTRSM_SMALL_GEMM_2nx2m(a01,b10,cs_b,p_lda,k_iter) - // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_ZTRSM_SMALL_1x1(AlphaVal,b11,cs_b) + // Load b11 of size 4x6 and multiply with alpha + BLIS_PRE_ZTRSM_SMALL_2x2(AlphaVal,b11,cs_b) - ///implement TRSM/// - ////extract a00 - ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); + ///implement TRSM/// + ////extract a00 + ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); #ifndef BLIS_ENABLE_TRSM_PREINVERSION - /*performs dcomplex divison of ymm3 with ymm1*/ - BLIS_ZTRSM_DIV(ymm3) + /*performs dcomplex divison of ymm3 with ymm1*/ + BLIS_ZTRSM_DIV(ymm3) #else - /*performs dcomplex multiplication of ymm3 with ymm1*/ - BLIS_ZTRSM_MUL(ymm3) + /*performs dcomplex multiplication of ymm3 with ymm1*/ + BLIS_ZTRSM_MUL(ymm3) #endif - _mm_storeu_pd((double *)b11, - _mm256_extractf128_pd(ymm3,0)); - m_remainder -=1; - } - n_remainder -= 1; - } - - if ((required_packing_A == 1) && - bli_mem_is_alloc( &local_mem_buf_A_s )) - { - bli_membrk_release(&rntm, &local_mem_buf_A_s); - } - + //extract a11 + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 1)); + //(ROW1): FMA operations + ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + rs_a*1)); + if(conjtransa) + { + ymm2 = _mm256_mul_pd(ymm2, ymm0); + } + /* Step1 dcomplex multiply ymm2, ymm3 + * Step2 negate the result + * Step3 add ymmx*/ + //Step 1 + ymm14 = _mm256_permute_pd(ymm2, 0x5); + /* Negate the imaginary elements of vec2 */ + ymm14 = _mm256_mul_pd(ymm14, ymm0); + //For ymm3 + /* Multiply vec1 and vec2 */ + ymm13 = _mm256_mul_pd(ymm3, ymm2); /*vec3*/ + /* Multiply vec1 and the modified vec2 */ + ymm14 = _mm256_mul_pd(ymm3, ymm14); /*vec4*/ + /* Horizontally subtract the elements in vec3 and vec4 */ + ymm16 = _mm256_hsub_pd(ymm13, ymm14); - return BLIS_SUCCESS; -} + //Step 2 + ymm16 = _mm256_mul_pd(ymm16, ymm15); -BLIS_INLINE err_t bli_ztrsm_small_XAltB_XAuB -( - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl -) -{ - dim_t m = bli_obj_length(b); //number of rows - dim_t n = bli_obj_width(b); //number of columns - - bool transa = bli_obj_has_trans(a); - bool conjtransa = bli_obj_has_conj(a); - - dim_t cs_a, rs_a; - dim_t d_mr = 4,d_nr = 3; - - // Swap rs_a & cs_a in case of non-tranpose. 
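/*
 * Editorial aside (illustrative sketch only, not part of the patch hunks):
 * the rs_a/cs_a swap noted immediately above appears to be what lets a single
 * combined kernel such as bli_ztrsm_small_XAltB_XAuB serve both a transposed
 * and a non-transposed variant.  Judging from the a11 + rs_a*k / cs_a accesses
 * in the TRSM blocks, element (i,j) of the effective operand op(A) is always
 * addressed through the same stride expression, and exchanging the two strides
 * in the non-transpose case makes that expression walk A in transposed order,
 * so the kernel body never branches on transa.  The accessor below is
 * hypothetical and assumes the dcomplex/dim_t types already used in this file.
 */
static inline dcomplex ztrsm_small_opa_elem_sketch
     (
       const dcomplex* a,    /* pointer to the triangular block (a11) */
       dim_t           i,    /* row index within op(A)                */
       dim_t           j,    /* column index within op(A)             */
       dim_t           rs_a,
       dim_t           cs_a
     )
{
	/* With the swapped strides this addresses op(A)(i,j) for both the
	   transposed and the non-transposed variant of the kernel. */
	return a[ j*rs_a + i*cs_a ];
}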
- if(transa) - { - cs_a = bli_obj_col_stride(a); // column stride of A - rs_a = bli_obj_row_stride(a); // row stride of A - } - else - { - cs_a = bli_obj_row_stride(a); // row stride of A - rs_a = bli_obj_col_stride(a); // column stride of A - } - dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B - - dim_t i, j, k; //loop variablse - dim_t k_iter; //determines the number of GEMM operations to be done - - dcomplex ones = {1.0, 1.0}; - dcomplex zero = {0.0, 0.0}; - bool is_unitdiag = bli_obj_has_unit_diag(a); - - dcomplex AlphaVal = *(dcomplex *)AlphaObj->buffer; //value of Alpha - dcomplex* restrict L = a->buffer; //pointer to matrix A - dcomplex* restrict B = b->buffer; //pointer to matrix B - - dcomplex *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks - - gint_t required_packing_A = 1; - mem_t local_mem_buf_A_s = {0}; - dcomplex *D_A_pack = NULL; - dcomplex d11_pack[d_mr] __attribute__((aligned(64))); - rntm_t rntm; - - bli_rntm_init_from_global( &rntm ); - bli_rntm_set_num_threads_only( 1, &rntm ); - bli_membrk_rntm_set_membrk( &rntm ); - - siz_t buffer_size = bli_pool_block_size( - bli_membrk_pool( - bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), - bli_rntm_membrk(&rntm))); - - if( (d_nr * n * sizeof(dcomplex)) > buffer_size) - return BLIS_NOT_YET_IMPLEMENTED; - - if (required_packing_A == 1) - { - // Get the buffer from the pool. - bli_membrk_acquire_m(&rntm, - buffer_size, - BLIS_BITVAL_BUFFER_FOR_A_BLOCK, - &local_mem_buf_A_s); - if(FALSE==bli_mem_is_alloc(&local_mem_buf_A_s)) return BLIS_NULL_POINTER; - D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); - if(NULL==D_A_pack) return BLIS_NULL_POINTER; - } - - //ymm scratch reginsters - __m256d ymm0, ymm1, ymm2, ymm3; - __m256d ymm4, ymm5, ymm6, ymm7; - __m256d ymm8, ymm9, ymm10, ymm11; - __m256d ymm12, ymm13, ymm14, ymm15; - __m256d ymm16, ymm17, ymm18, ymm19; - - __m128d xmm5, xmm4, xmm3; - - for(j = 0; (j+d_nr-1) < n; j += d_nr) //loop along 'N' direction - { - a01 = L + j*rs_a;//pointer to block of A to be used in GEMM - a11 = L + j*cs_a + j*rs_a;//pointer to block of A to be used for TRSM - - dim_t p_lda = j; // packed leading dimension - // perform copy of A to packed buffer D_A_pack - - if(transa) - { - /* - Pack current A block (a01) into packed buffer memory D_A_pack - a. This a10 block is used in GEMM portion only and this - a01 block size will be increasing by d_nr for every next - iteration until it reaches 3x(n-3) which is the maximum GEMM - alone block size in A - b. This packed buffer is reused to calculate all m cols of - B matrix - */ - bli_ztrsm_small_pack('R', j, 1, a01, cs_a, D_A_pack, p_lda,d_nr); - - /* - Pack 3 diagonal elements of A block into an array - a. This helps in utilze cache line efficiently in TRSM - operation - b. store ones when input is unit diagonal - */ - ztrsm_small_pack_diag_element(is_unitdiag,a11,cs_a, - d11_pack,d_nr); - } - else - { - bli_ztrsm_small_pack('R', j, 0, a01, rs_a, D_A_pack, - p_lda,d_nr); - ztrsm_small_pack_diag_element(is_unitdiag,a11,rs_a, - d11_pack,d_nr); - } - - /* - a. Perform GEMM using a01, b10. - b. Perform TRSM on a11, b11 - c. This loop GEMM+TRSM loops operates with 8x6 block size - along m dimension for every d_mr columns of B10 where - packed A buffer is reused in computing all m cols of B. - d. Same approach is used in remaining fringe cases. 
- */ - for(i = 0; (i+d_mr-1) < m; i += d_mr) //loop along 'M' direction - { - a01 = D_A_pack; - a11 = L + j*cs_a + j*rs_a; - b10 = B + i; - b11 = B + i + j*cs_b; - - k_iter = j; - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS - - /* - Peform GEMM between a01 and b10 blocks - For first itteration there will be no GEMM operation - where k_iter are zero - */ - - BLIS_ZTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) - - /* - Load b11 of size 4x3 and multiply with alpha - Add the GEMM output to b11 - and peform TRSM operation. - */ - - BLIS_PRE_ZTRSM_SMALL_3x4(AlphaVal,b11,cs_b) - ///implement TRSM/// - /* - Compute 3x3 TRSM block by using GEMM block output in register - a. The 3x4 input (gemm outputs) are stored in combinations of - ymm registers - 1. ymm3, ymm4 2. ymm5, ymm6 3. ymm7, ymm8 - b. Towards the end do in regiser transpose of TRSM output - and store in b11 - */ - ////extract a00 - ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); + //Step 3 + ymm5 = _mm256_add_pd(ymm16, ymm5); #ifndef BLIS_ENABLE_TRSM_PREINVERSION - /*performs dcomplex divison of ymm3 and ymm4 with ymm1*/ - BLIS_ZTRSM_TWO_DIV(ymm3,ymm4) + /*performs dcomplex divison of ymm5 with ymm1*/ + BLIS_ZTRSM_DIV(ymm5) #else - /*performs dcomplex multiplication of ymm3 and ymm4 with ymm1*/ - BLIS_ZTRSM_MUL(ymm3) - BLIS_ZTRSM_MUL(ymm4) + /*performs dcomplex multiplication of ymm5 with ymm1*/ + BLIS_ZTRSM_MUL(ymm5) #endif - //extract a11 - ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 1)); - //(ROW1): FMA operations - ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + rs_a*1)); - if(conjtransa) - { - ymm2 = _mm256_mul_pd(ymm2, ymm0); - } - /* Step1 dcomplex multiply ymm2, ymm3 - * Step2 negate the result - * Step3 add ymmx*/ - //Step 1 - ymm14 = _mm256_permute_pd(ymm2, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); - //For ymm3 - /* Multiply vec1 and vec2 */ - ymm13 = _mm256_mul_pd(ymm3, ymm2); /*vec3*/ - /* Multiply vec1 and the modified vec2 */ - ymm14 = _mm256_mul_pd(ymm3, ymm14); /*vec4*/ - /* Horizontally subtract the elements in vec3 and vec4 */ - ymm16 = _mm256_hsub_pd(ymm13, ymm14); - - //For ymm4 - ymm14 = _mm256_permute_pd(ymm2, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); - - ymm13 = _mm256_mul_pd(ymm4, ymm2); - ymm14 = _mm256_mul_pd(ymm4, ymm14); - ymm17 = _mm256_hsub_pd(ymm13, ymm14); - //Step 2 - ymm16 = _mm256_mul_pd(ymm16, ymm15); - ymm17 = _mm256_mul_pd(ymm17, ymm15); - - //Step 3 - ymm5 = _mm256_add_pd(ymm16, ymm5); - ymm6 = _mm256_add_pd(ymm17, ymm6); - - ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + rs_a*2)); - if(conjtransa) - { - ymm2 = _mm256_mul_pd(ymm2, ymm0); - } - //Step 1 - ymm14 = _mm256_permute_pd(ymm2, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); - //For ymm3 - /* Multiply vec1 and vec2 */ - ymm13 = _mm256_mul_pd(ymm3, ymm2); /*vec3*/ - /* Multiply vec1 and the modified vec2 */ - ymm14 = _mm256_mul_pd(ymm3, ymm14); /*vec4*/ - /* Horizontally subtract the elements in vec3 and vec4 */ - ymm16 = _mm256_hsub_pd(ymm13, ymm14); - //For ymm4 - ymm14 = _mm256_permute_pd(ymm2, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); - - ymm13 = _mm256_mul_pd(ymm4, ymm2); - ymm14 = _mm256_mul_pd(ymm4, ymm14); - ymm17 = _mm256_hsub_pd(ymm13, ymm14); - //Step 2 - ymm16 = 
_mm256_mul_pd(ymm16, ymm15); - ymm17 = _mm256_mul_pd(ymm17, ymm15); - - //Step 3 - ymm7 = _mm256_add_pd(ymm16, ymm7); - ymm8 = _mm256_add_pd(ymm17, ymm8); + _mm_storeu_pd((double *)b11, + _mm256_extractf128_pd(ymm3,0)); + _mm_storeu_pd((double *)(b11 + cs_b), + _mm256_extractf128_pd(ymm5,0)); + m_remainder -= 1; + i += 1; + } + j += 2; + n_remainder -= 2; + } + else if(n_remainder == 1) + { + a01 = L + j*rs_a; + a11 = L + j*cs_a + j*rs_a; + dcomplex *ptr_a10_dup = D_A_pack; + dim_t p_lda = j; + // perform copy of A to packed buffer D_A_pack + if(transa) + { + for(dim_t x =0;x < p_lda;x+=d_nr) + { + ymm0 = _mm256_loadu_pd((double const *)(a01)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); + ymm3 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm4 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm3); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm4); + + ymm0 = _mm256_loadu_pd((double const *)(a01 + 2)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a + 2)); + ymm3 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 2), + ymm3); + + ymm0 = _mm256_loadu_pd((double const *) + (a01 + cs_a * 2)); + ymm1 = _mm256_loadu_pd((double const *) + (a01 + cs_a * 2 + 2)); + ymm5 = _mm256_broadcast_pd((__m128d const *)&zero); + + ymm3 = _mm256_permute2f128_pd(ymm0,ymm5,0x20); + ymm4 = _mm256_permute2f128_pd(ymm0,ymm5,0x31); + ymm5 = _mm256_permute2f128_pd(ymm1,ymm5,0x20); + + _mm_storeu_pd((double *)(ptr_a10_dup + 2), + _mm256_extractf128_pd(ymm3,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda + 2), + _mm256_extractf128_pd(ymm4,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 2 + 2), + _mm256_extractf128_pd(ymm5, 0)); + a01 += d_nr*cs_a; + ptr_a10_dup += d_nr; + } -#ifndef BLIS_ENABLE_TRSM_PREINVERSION - /*performs dcomplex divison of ymm5 and ymm6 with ymm1*/ - BLIS_ZTRSM_TWO_DIV(ymm5,ymm6) -#else - /*performs dcomplex multiplication of ymm5 and ymm6 with ymm1*/ - BLIS_ZTRSM_MUL(ymm5) - BLIS_ZTRSM_MUL(ymm6) -#endif - a11 += cs_a; - - //extract a22 - ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 2)); - //(ROW2): FMA operations - ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + rs_a * 2)); - if(conjtransa) - { - ymm2 = _mm256_mul_pd(ymm2, ymm0); - } - //Step 1 - ymm14 = _mm256_permute_pd(ymm2, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); - - //For ymm5 - /* Multiply vec1 and vec2 */ - ymm13 = _mm256_mul_pd(ymm5, ymm2); /*vec3*/ - /* Multiply vec1 and the modified vec2 */ - ymm14 = _mm256_mul_pd(ymm5, ymm14); /*vec4*/ - /* Horizontally subtract the elements in vec3 and vec4 */ - ymm16 = _mm256_hsub_pd(ymm13, ymm14); - //For ymm6 - ymm14 = _mm256_permute_pd(ymm2, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); - - ymm13 = _mm256_mul_pd(ymm6, ymm2); - ymm14 = _mm256_mul_pd(ymm6, ymm14); - ymm17 = _mm256_hsub_pd(ymm13, ymm14); - //Step 2 - ymm16 = _mm256_mul_pd(ymm16, ymm15); - ymm17 = _mm256_mul_pd(ymm17, ymm15); - //Step 3 - ymm7 = _mm256_add_pd(ymm16, ymm7); - ymm8 = _mm256_add_pd(ymm17, ymm8); + } + else + { + dim_t loop_count = p_lda/2; + for(dim_t x =0;x < loop_count;x++) + { + ymm15 = _mm256_loadu_pd((double const *) + (a01 + rs_a * 0 + x*2)); + _mm256_storeu_pd((double *) + (ptr_a10_dup + p_lda * 0 + x*2), ymm15); + } -#ifndef BLIS_ENABLE_TRSM_PREINVERSION - /*performs dcomplex divison of ymm7 and ymm8 with ymm1*/ - BLIS_ZTRSM_TWO_DIV(ymm7,ymm8) -#else - /*performs dcomplex multiplication of 
ymm7 and ymm8 with ymm1*/ - BLIS_ZTRSM_MUL(ymm7) - BLIS_ZTRSM_MUL(ymm8) -#endif - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + 2), ymm4); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b + 2), ymm6); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*2 + 2), ymm8); - - } - - dim_t m_remainder = m - i; - if(m_remainder) - { - if(m_remainder == 3) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j*rs_a; - b10 = B + i; - b11 = B + i + j*cs_b; - - k_iter = j; - - /*Fill zeros into ymm registers used in gemm - * accumulations */ - BLIS_SET_YMM_REG_ZEROS - - ///GEMM implementation starts/// - BLIS_ZTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) - - // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_ZTRSM_SMALL_3x4(AlphaVal,b11,cs_b) - - ///implement TRSM/// - ////extract a00 - ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); -#ifndef BLIS_ENABLE_TRSM_PREINVERSION - /*performs dcomplex divison of ymm3 and ymm4 with ymm1*/ - BLIS_ZTRSM_TWO_DIV(ymm3,ymm4) -#else - /*performs dcomplex multiplication of ymm3 and ymm4 - * with ymm1*/ - BLIS_ZTRSM_MUL(ymm3) - BLIS_ZTRSM_MUL(ymm4) -#endif - //extract a11 - ymm1 = _mm256_broadcast_pd((__m128d const *) - (d11_pack + 1)); - //(ROW1): FMA operations - ymm2 = _mm256_broadcast_pd((__m128d const *) - (a11 + rs_a*1)); - if(conjtransa) - { - ymm2 = _mm256_mul_pd(ymm2, ymm0); - } - /* Step1 dcomplex multiply ymm2, ymm3 - * Step2 negate the result - * Step3 add ymmx*/ - //Step 1 - ymm14 = _mm256_permute_pd(ymm2, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); - //For ymm3 - /* Multiply vec1 and vec2 */ - ymm13 = _mm256_mul_pd(ymm3, ymm2); /*vec3*/ - /* Multiply vec1 and the modified vec2 */ - ymm14 = _mm256_mul_pd(ymm3, ymm14); /*vec4*/ - /* Horizontally subtract the elements in vec3 and vec4 */ - ymm16 = _mm256_hsub_pd(ymm13, ymm14); - - //For ymm4 - ymm14 = _mm256_permute_pd(ymm2, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); - - ymm13 = _mm256_mul_pd(ymm4, ymm2); - ymm14 = _mm256_mul_pd(ymm4, ymm14); - ymm17 = _mm256_hsub_pd(ymm13, ymm14); - //Step 2 - ymm16 = _mm256_mul_pd(ymm16, ymm15); - ymm17 = _mm256_mul_pd(ymm17, ymm15); - - //Step 3 - ymm5 = _mm256_add_pd(ymm16, ymm5); - ymm6 = _mm256_add_pd(ymm17, ymm6); - - ymm2 = _mm256_broadcast_pd((__m128d const *) - (a11 + rs_a*2)); - if(conjtransa) - { - ymm2 = _mm256_mul_pd(ymm2, ymm0); - } - //Step 1 - ymm14 = _mm256_permute_pd(ymm2, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); - //For ymm3 - /* Multiply vec1 and vec2 */ - ymm13 = _mm256_mul_pd(ymm3, ymm2); /*vec3*/ - /* Multiply vec1 and the modified vec2 */ - ymm14 = _mm256_mul_pd(ymm3, ymm14); /*vec4*/ - /* Horizontally subtract the elements in vec3 and vec4 */ - ymm16 = _mm256_hsub_pd(ymm13, ymm14); - //For ymm4 - ymm14 = _mm256_permute_pd(ymm2, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); - - ymm13 = _mm256_mul_pd(ymm4, ymm2); - ymm14 = _mm256_mul_pd(ymm4, ymm14); - ymm17 = _mm256_hsub_pd(ymm13, ymm14); - //Step 2 - ymm16 = _mm256_mul_pd(ymm16, ymm15); - ymm17 = _mm256_mul_pd(ymm17, ymm15); - - //Step 3 - ymm7 = _mm256_add_pd(ymm16, ymm7); - ymm8 = _mm256_add_pd(ymm17, ymm8); + dim_t remainder_loop_count = p_lda - loop_count*2; + __m128d xmm0; + 
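The loop_count/remainder_loop_count copy used when A is not transposed moves two dcomplex elements per 256-bit register and finishes an odd trailing element with a 128-bit load/store (the xmm0 path). A standalone sketch of that pattern with a hypothetical helper name (assumes AVX):

    #include <immintrin.h>
    #include <stdio.h>

    /* copy n dcomplex values: two per ymm, one leftover via xmm */
    static void copy_dcomplex_row(const double *src, double *dst, long n)
    {
        long full = n / 2;                       /* 2 dcomplex = 4 doubles per ymm */
        for (long x = 0; x < full; x++)
        {
            __m256d v = _mm256_loadu_pd(src + 4*x);
            _mm256_storeu_pd(dst + 4*x, v);
        }
        if (n - full*2)                          /* one trailing dcomplex */
        {
            __m128d t = _mm_loadu_pd(src + 4*full);
            _mm_storeu_pd(dst + 4*full, t);
        }
    }

    int main(void)
    {
        double src[10] = {1,2,3,4,5,6,7,8,9,10}, dst[10] = {0};
        copy_dcomplex_row(src, dst, 5);
        printf("last element = %g %+gi\n", dst[8], dst[9]);   /* 9 +10i */
        return 0;
    }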
if(remainder_loop_count != 0) + { + xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 0 + + loop_count*2)); -#ifndef BLIS_ENABLE_TRSM_PREINVERSION - /*performs dcomplex divison of ymm5 and ymm6 with ymm1*/ - BLIS_ZTRSM_TWO_DIV(ymm5,ymm6) -#else - /*performs dcomplex multiplication of ymm5 and ymm6 with - * ymm1*/ - BLIS_ZTRSM_MUL(ymm5) - BLIS_ZTRSM_MUL(ymm6) + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + + loop_count*2), xmm0); + } + } + if(!is_unitdiag) + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_pd((__m128d const *)(a11)); + ymm1 = _mm256_broadcast_pd((__m128d const *)&ones); + ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + ymm7 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + /*Taking denomerator multiplication of real & + * imaginary components*/ + ymm4 = _mm256_mul_pd(ymm1, ymm1); + /*Swapping real & imaginary component position for addition with + * respective components*/ + ymm6 = _mm256_permute4x64_pd(ymm4, 0xb1); + ymm4 = _mm256_add_pd(ymm4, ymm6); + /*Negating imaginary component of numerator*/ + ymm1 = _mm256_mul_pd(ymm1, ymm7); + /*Dividing numerator by denominator*/ + ymm1 = _mm256_div_pd(ymm1, ymm4); #endif - a11 += cs_a; - - //extract a22 - ymm1 = _mm256_broadcast_pd((__m128d const *) - (d11_pack + 2)); - //(ROW2): FMA operations - ymm2 = _mm256_broadcast_pd((__m128d const *) - (a11 + rs_a * 2)); - if(conjtransa) - { - ymm2 = _mm256_mul_pd(ymm2, ymm0); - } - //Step 1 - ymm14 = _mm256_permute_pd(ymm2, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); - - //For ymm5 - /* Multiply vec1 and vec2 */ - ymm13 = _mm256_mul_pd(ymm5, ymm2); /*vec3*/ - /* Multiply vec1 and the modified vec2 */ - ymm14 = _mm256_mul_pd(ymm5, ymm14); /*vec4*/ - /* Horizontally subtract the elements in vec3 and vec4 */ - ymm16 = _mm256_hsub_pd(ymm13, ymm14); - //For ymm6 - ymm14 = _mm256_permute_pd(ymm2, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); - - ymm13 = _mm256_mul_pd(ymm6, ymm2); - ymm14 = _mm256_mul_pd(ymm6, ymm14); - ymm17 = _mm256_hsub_pd(ymm13, ymm14); - //Step 2 - ymm16 = _mm256_mul_pd(ymm16, ymm15); - ymm17 = _mm256_mul_pd(ymm17, ymm15); - //Step 3 - ymm7 = _mm256_add_pd(ymm16, ymm7); - ymm8 = _mm256_add_pd(ymm17, ymm8); + } + else + { + ymm1 = _mm256_broadcast_pd((__m128d const *)&ones); + } + _mm256_storeu_pd((double *)(d11_pack), ymm1); + for(i = 0; (i+d_mr-1) < m; i += d_mr) //loop along 'M' direction + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; + b10 = B + i; + b11 = B + i + j*cs_b; -#ifndef BLIS_ENABLE_TRSM_PREINVERSION - /*performs dcomplex divison of ymm7 and ymm8 with ymm1*/ - BLIS_ZTRSM_TWO_DIV(ymm7,ymm8) -#else - /*performs dcomplex multiplication of ymm7 and ymm8 - * with ymm1*/ - BLIS_ZTRSM_MUL(ymm7) - BLIS_ZTRSM_MUL(ymm8) -#endif + k_iter = j; - _mm256_storeu_pd((double *)b11, ymm3); - _mm_storeu_pd((double *)(b11 + 2), - _mm256_extractf128_pd(ymm4,0)); - - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm_storeu_pd((double *)(b11 + cs_b + 2), - _mm256_extractf128_pd(ymm6,0)); - - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm_storeu_pd((double *)(b11 + cs_b*2 + 2), - _mm256_extractf128_pd(ymm8,0)); - - m_remainder -= 3; - i += 3; - } - else if(m_remainder == 2) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j*rs_a; - b10 = B + i; - b11 = B + i + j*cs_b; - - k_iter = j; - - /*Fill zeros into ymm registers used in gemm - * accumulations */ - BLIS_SET_YMM_REG_ZEROS - - ///GEMM implementation starts/// - 
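With BLIS_ENABLE_TRSM_PREINVERSION the diagonal pack step above stores conj(d)/|d|^2 so the solve can multiply instead of divide. A standalone demo of the same register sequence (square, permute4x64 0xb1, add, sign flip, divide) for two dcomplex values; assumes AVX2 for _mm256_permute4x64_pd:

    #include <immintrin.h>
    #include <stdio.h>

    int main(void)
    {
        double d[4] = {3.0, 4.0, 1.0, 2.0};              /* 3+4i and 1+2i */
        double r[4];

        __m256d v    = _mm256_loadu_pd(d);
        __m256d sign = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);

        __m256d sq   = _mm256_mul_pd(v, v);              /* r^2, i^2            */
        __m256d swp  = _mm256_permute4x64_pd(sq, 0xb1);  /* i^2, r^2            */
        __m256d dnm  = _mm256_add_pd(sq, swp);           /* |d|^2 in every slot */
        __m256d conj = _mm256_mul_pd(v, sign);           /* conj(d)             */
        _mm256_storeu_pd(r, _mm256_div_pd(conj, dnm));   /* conj(d) / |d|^2     */

        printf("1/(3+4i) = %g %+gi\n", r[0], r[1]);      /* 0.12 -0.16i */
        printf("1/(1+2i) = %g %+gi\n", r[2], r[3]);      /* 0.2  -0.4i  */
        return 0;
    }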
BLIS_ZTRSM_SMALL_GEMM_3nx2m(a01,b10,cs_b,p_lda,k_iter) - - // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_ZTRSM_SMALL_3x2(AlphaVal,b11,cs_b) - ////extract a00 - ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + ///GEMM implementation starts/// + BLIS_ZTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_PRE_ZTRSM_SMALL_1x4(b11,cs_b,AlphaVal) + ///implement TRSM/// + ////extract a00 + ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); #ifndef BLIS_ENABLE_TRSM_PREINVERSION - /*performs dcomplex divison of ymm3 with ymm1*/ - BLIS_ZTRSM_DIV(ymm3) + /*performs dcomplex divison of ymm3 and ymm4 with ymm1*/ + BLIS_ZTRSM_TWO_DIV(ymm3,ymm4) #else - /*performs dcomplex multiplication of ymm3 - * with ymm1*/ - BLIS_ZTRSM_MUL(ymm3) + /*performs dcomplex multiplication of ymm3 and ymm4 with ymm1*/ + BLIS_ZTRSM_MUL(ymm3) + BLIS_ZTRSM_MUL(ymm4) #endif - //extract a11 - ymm1 = _mm256_broadcast_pd((__m128d const *) - (d11_pack + 1)); - //(ROW1): FMA operations - ymm2 = _mm256_broadcast_pd((__m128d const *) - (a11 + rs_a*1)); - if(conjtransa) - { - ymm2 = _mm256_mul_pd(ymm2, ymm0); - } - /* Step1 dcomplex multiply ymm2, ymm3 - * Step2 negate the result - * Step3 add ymmx*/ - //Step 1 - ymm14 = _mm256_permute_pd(ymm2, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); - //For ymm3 - /* Multiply vec1 and vec2 */ - ymm13 = _mm256_mul_pd(ymm3, ymm2); /*vec3*/ - /* Multiply vec1 and the modified vec2 */ - ymm14 = _mm256_mul_pd(ymm3, ymm14); /*vec4*/ - /* Horizontally subtract the elements in vec3 and vec4 */ - ymm16 = _mm256_hsub_pd(ymm13, ymm14); - - //Step 2 - ymm16 = _mm256_mul_pd(ymm16, ymm15); - - //Step 3 - ymm5 = _mm256_add_pd(ymm16, ymm5); - - ymm2 = _mm256_broadcast_pd((__m128d const *) - (a11 + rs_a*2)); - if(conjtransa) - { - ymm2 = _mm256_mul_pd(ymm2, ymm0); - } - //Step 1 - ymm14 = _mm256_permute_pd(ymm2, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); - //For ymm3 - /* Multiply vec1 and vec2 */ - ymm13 = _mm256_mul_pd(ymm3, ymm2); /*vec3*/ - /* Multiply vec1 and the modified vec2 */ - ymm14 = _mm256_mul_pd(ymm3, ymm14); /*vec4*/ - /* Horizontally subtract the elements in vec3 and vec4 */ - ymm16 = _mm256_hsub_pd(ymm13, ymm14); - //Step 2 - ymm16 = _mm256_mul_pd(ymm16, ymm15); - - //Step 3 - ymm7 = _mm256_add_pd(ymm16, ymm7); + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + 2),ymm4); + } + dim_t m_remainder = m - i; + if(m_remainder == 3) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; + b10 = B + i; + b11 = B + i + j*cs_b; -#ifndef BLIS_ENABLE_TRSM_PREINVERSION - /*performs dcomplex divison of ymm5 with ymm1*/ - BLIS_ZTRSM_DIV(ymm5) -#else - /*performs dcomplex multiplication of ymm5 - * with ymm1*/ - BLIS_ZTRSM_MUL(ymm5) -#endif - a11 += cs_a; - - //extract a22 - ymm1 = _mm256_broadcast_pd((__m128d const *) - (d11_pack + 2)); - //(ROW2): FMA operations - ymm2 = _mm256_broadcast_pd((__m128d const *) - (a11 + rs_a * 2)); - if(conjtransa) - { - ymm2 = _mm256_mul_pd(ymm2, ymm0); - } - //Step 1 - ymm14 = _mm256_permute_pd(ymm2, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); - - //For ymm5 - /* Multiply vec1 and vec2 */ - ymm13 = 
_mm256_mul_pd(ymm5, ymm2); /*vec3*/ - /* Multiply vec1 and the modified vec2 */ - ymm14 = _mm256_mul_pd(ymm5, ymm14); /*vec4*/ - /* Horizontally subtract the elements in vec3 and vec4 */ - ymm16 = _mm256_hsub_pd(ymm13, ymm14); - //Step 2 - ymm16 = _mm256_mul_pd(ymm16, ymm15); - //Step 3 - ymm7 = _mm256_add_pd(ymm16, ymm7); + k_iter = j; + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS -#ifndef BLIS_ENABLE_TRSM_PREINVERSION - /*performs dcomplex divison of ymm7 with ymm1*/ - BLIS_ZTRSM_DIV(ymm7) -#else - /*performs dcomplex multiplication of ymm7 - * with ymm1*/ - BLIS_ZTRSM_MUL(ymm7) -#endif - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - m_remainder -= 2; - i += 2; - } - else if(m_remainder == 1) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j*rs_a; - b10 = B + i; - b11 = B + i + j*cs_b; - - k_iter = j; - - /*Fill zeros into ymm registers used in gemm - * accumulations */ - BLIS_SET_YMM_REG_ZEROS - - ///GEMM implementation starts/// - BLIS_ZTRSM_SMALL_GEMM_3nx2m(a01,b10,cs_b,p_lda,k_iter) - - // Load b11 of size 2x3 and multiply with alpha - BLIS_PRE_ZTRSM_SMALL_3x2(AlphaVal,b11,cs_b) - - ///implement TRSM/// - ////extract a00 - ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); -#ifndef BLIS_ENABLE_TRSM_PREINVERSION - /*performs dcomplex divison of ymm3 with ymm1*/ - BLIS_ZTRSM_DIV(ymm3) -#else - /*performs dcomplex multiplication of ymm3 - * with ymm1*/ - BLIS_ZTRSM_MUL(ymm3) -#endif - //extract a11 - ymm1 = _mm256_broadcast_pd((__m128d const *) - (d11_pack + 1)); - //(ROW1): FMA operations - ymm2 = _mm256_broadcast_pd((__m128d const *) - (a11 + rs_a*1)); - if(conjtransa) - { - ymm2 = _mm256_mul_pd(ymm2, ymm0); - } - /* Step1 dcomplex multiply ymm2, ymm3 - * Step2 negate the result - * Step3 add ymmx*/ - //Step 1 - ymm14 = _mm256_permute_pd(ymm2, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); - //For ymm3 - /* Multiply vec1 and vec2 */ - ymm13 = _mm256_mul_pd(ymm3, ymm2); /*vec3*/ - /* Multiply vec1 and the modified vec2 */ - ymm14 = _mm256_mul_pd(ymm3, ymm14); /*vec4*/ - /* Horizontally subtract the elements in vec3 and vec4 */ - ymm16 = _mm256_hsub_pd(ymm13, ymm14); - - //Step 2 - ymm16 = _mm256_mul_pd(ymm16, ymm15); - - //Step 3 - ymm5 = _mm256_add_pd(ymm16, ymm5); - - ymm2 = _mm256_broadcast_pd((__m128d const *) - (a11 + rs_a*2)); - if(conjtransa) - { - ymm2 = _mm256_mul_pd(ymm2, ymm0); - } - //Step 1 - ymm14 = _mm256_permute_pd(ymm2, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); - //For ymm3 - /* Multiply vec1 and vec2 */ - ymm13 = _mm256_mul_pd(ymm3, ymm2); /*vec3*/ - /* Multiply vec1 and the modified vec2 */ - ymm14 = _mm256_mul_pd(ymm3, ymm14); /*vec4*/ - /* Horizontally subtract the elements in vec3 and vec4 */ - ymm16 = _mm256_hsub_pd(ymm13, ymm14); - //Step 2 - ymm16 = _mm256_mul_pd(ymm16, ymm15); - - //Step 3 - ymm7 = _mm256_add_pd(ymm16, ymm7); + ///GEMM implementation starts/// + BLIS_ZTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_PRE_ZTRSM_SMALL_1x3(b11,cs_b,AlphaVal) + ///implement TRSM/// + ////extract a00 + ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); #ifndef BLIS_ENABLE_TRSM_PREINVERSION - /*performs dcomplex divison of ymm5 with 
ymm1*/ - BLIS_ZTRSM_DIV(ymm5) + /*performs dcomplex divison of ymm3 and ymm4 with ymm1*/ + BLIS_ZTRSM_TWO_DIV(ymm3,ymm4) #else - /*performs dcomplex multiplication of ymm5 - * with ymm1*/ - BLIS_ZTRSM_MUL(ymm5) + /*performs dcomplex multiplication of ymm3 and ymm4 with ymm1*/ + BLIS_ZTRSM_MUL(ymm3) + BLIS_ZTRSM_MUL(ymm4) #endif - a11 += cs_a; - - //extract a22 - ymm1 = _mm256_broadcast_pd((__m128d const *) - (d11_pack + 2)); - //(ROW2): FMA operations - ymm2 = _mm256_broadcast_pd((__m128d const *) - (a11 + rs_a * 2)); - if(conjtransa) - { - ymm2 = _mm256_mul_pd(ymm2, ymm0); - } - //Step 1 - ymm14 = _mm256_permute_pd(ymm2, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); - - //For ymm5 - /* Multiply vec1 and vec2 */ - ymm13 = _mm256_mul_pd(ymm5, ymm2); /*vec3*/ - /* Multiply vec1 and the modified vec2 */ - ymm14 = _mm256_mul_pd(ymm5, ymm14); /*vec4*/ - /* Horizontally subtract the elements in vec3 and vec4 */ - ymm16 = _mm256_hsub_pd(ymm13, ymm14); - //Step 2 - ymm16 = _mm256_mul_pd(ymm16, ymm15); - //Step 3 - ymm7 = _mm256_add_pd(ymm16, ymm7); -#ifndef BLIS_ENABLE_TRSM_PREINVERSION - /*performs dcomplex divison of ymm7 with ymm1*/ - BLIS_ZTRSM_DIV(ymm7) -#else - /*performs dcomplex multiplication of ymm7 - * with ymm1*/ - BLIS_ZTRSM_MUL(ymm7) -#endif + _mm256_storeu_pd((double *)b11, ymm3); + _mm_storeu_pd((double *)(b11 + 2), + _mm256_extractf128_pd(ymm4,0)); + m_remainder -= 3; + i += 3; + } + if(m_remainder == 2) + { + a01 = D_A_pack; + //pointer to block of A to be used for TRSM + a11 = L + j*cs_a + j*rs_a; + //pointer to block of B to be used in GEMM + b10 = B + i; + //pointer to block of B to be used for TRSM + b11 = B + i + j*cs_b; + //number of GEMM operations to be done + k_iter = j; + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + ///GEMM implementation starts/// + BLIS_ZTRSM_SMALL_GEMM_1nx2m(a01,b10,cs_b,p_lda,k_iter) - _mm_storeu_pd((double *)b11, - _mm256_extractf128_pd(ymm3,0)); - _mm_storeu_pd((double *)(b11 + cs_b), - _mm256_extractf128_pd(ymm5,0)); - _mm_storeu_pd((double *)(b11 + cs_b*2), - _mm256_extractf128_pd(ymm7,0)); - - m_remainder -= 1; - i += 1; - } - } - - } - dim_t n_remainder = n - j; - if(n_remainder == 2) - { - a01 = L + j*rs_a; - a11 = L + j*cs_a + j*rs_a; - dcomplex *ptr_a10_dup = D_A_pack; - - dim_t p_lda = j; - // perform copy of A to packed buffer D_A_pack - - if(transa) - { - for(dim_t x =0;x < p_lda;x+=d_nr) - { - ymm0 = _mm256_loadu_pd((double const *)(a01)); - ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); - ymm3 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm4 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - - _mm256_storeu_pd((double *)(ptr_a10_dup), ymm3); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm4); - - ymm0 = _mm256_loadu_pd((double const *)(a01 + 2)); - ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a + 2)); - ymm3 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 2), - ymm3); - - ymm0 = _mm256_loadu_pd((double const *) - (a01 + cs_a * 2)); - ymm1 = _mm256_loadu_pd((double const *) - (a01 + cs_a * 2 + 2)); - ymm5 = _mm256_broadcast_pd((__m128d const *)&zero); - - ymm3 = _mm256_permute2f128_pd(ymm0,ymm5,0x20); - ymm4 = _mm256_permute2f128_pd(ymm0,ymm5,0x31); - ymm5 = _mm256_permute2f128_pd(ymm1,ymm5,0x20); - - _mm_storeu_pd((double *)(ptr_a10_dup + 2), - _mm256_extractf128_pd(ymm3,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + p_lda + 2), - _mm256_extractf128_pd(ymm4,0)); - _mm_storeu_pd((double 
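The permute2f128 0x20/0x31 pairs in the packing loops here act as a 2x2 transpose of dcomplex tiles: 0x20 gathers the low 128-bit halves of two column registers into one row, 0x31 gathers the high halves into the next row. A small standalone demo (assumes AVX, illustrative data):

    #include <immintrin.h>
    #include <stdio.h>

    int main(void)
    {
        double col0[4] = {1, 2, 3, 4};   /* a00 = 1+2i, a10 = 3+4i */
        double col1[4] = {5, 6, 7, 8};   /* a01 = 5+6i, a11 = 7+8i */
        double row0[4], row1[4];

        __m256d c0 = _mm256_loadu_pd(col0);
        __m256d c1 = _mm256_loadu_pd(col1);

        /* 0x20 keeps the low halves (row 0), 0x31 the high halves (row 1) */
        _mm256_storeu_pd(row0, _mm256_permute2f128_pd(c0, c1, 0x20));
        _mm256_storeu_pd(row1, _mm256_permute2f128_pd(c0, c1, 0x31));

        printf("row0: %g%+gi  %g%+gi\n", row0[0], row0[1], row0[2], row0[3]);
        printf("row1: %g%+gi  %g%+gi\n", row1[0], row1[1], row1[2], row1[3]);
        return 0;
    }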
*)(ptr_a10_dup + p_lda * 2 + 2), - _mm256_extractf128_pd(ymm5, 0)); - a01 += d_nr*cs_a; - ptr_a10_dup += d_nr; - } - } - else - { - dim_t loop_count = p_lda/2; - - for(dim_t x =0;x < loop_count;x++) - { - ymm15 = _mm256_loadu_pd((double const *) - (a01 + rs_a * 0 + x*2)); - _mm256_storeu_pd((double *) - (ptr_a10_dup + p_lda * 0 + x*2), ymm15); - ymm15 = _mm256_loadu_pd((double const *) - (a01 + rs_a * 1 + x*2)); - _mm256_storeu_pd((double *) - (ptr_a10_dup + p_lda * 1 + x*2), - ymm15); - } - - dim_t remainder_loop_count = p_lda - loop_count*2; - - __m128d xmm0; - if(remainder_loop_count != 0) - { - xmm0 = _mm_loadu_pd((double const *) - (a01 + rs_a * 0 + loop_count*2)); - _mm_storeu_pd((double *) - (ptr_a10_dup + p_lda * 0 + loop_count*2), - xmm0); - xmm0 = _mm_loadu_pd((double const *) - (a01 + rs_a * 1 + loop_count*2)); - _mm_storeu_pd((double *) - (ptr_a10_dup + p_lda * 1 + loop_count*2), - xmm0); - } - } - if(!is_unitdiag) - { - if(transa) - { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_pd((__m128d const *)(a11)); - ymm1 = _mm256_broadcast_pd((__m128d const *) - (a11+cs_a*1 + 1)); - } - else - { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_pd((__m128d const *)(a11)); - ymm1 = _mm256_broadcast_pd((__m128d const *) - (a11+rs_a*1 + 1)); - } - ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); -#ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm7 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - /*Taking denomerator multiplication of real & - * imaginary components*/ - ymm4 = _mm256_mul_pd(ymm1, ymm1); - /*Swapping real & imaginary component position for addition with - * respective components*/ - ymm6 = _mm256_permute4x64_pd(ymm4, 0xb1); - ymm4 = _mm256_add_pd(ymm4, ymm6); - /*Negating imaginary component of numerator*/ - ymm1 = _mm256_mul_pd(ymm1, ymm7); - /*Dividing numerator by denominator*/ - ymm1 = _mm256_div_pd(ymm1, ymm4); -#endif - } - else - { - ymm1 = _mm256_broadcast_pd((__m128d const *)&ones); - } - _mm256_storeu_pd((double *)(d11_pack), ymm1); - - for(i = 0; (i+d_mr-1) < m; i += d_mr) //loop along 'M' direction - { - a01 = D_A_pack; - a11 = L + j*cs_a + j*rs_a; - b10 = B + i; - b11 = B + i + j*cs_b; - - k_iter = j; - - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS - ///GEMM implementation starts/// - BLIS_ZTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) - BLIS_PRE_ZTRSM_SMALL_2x4(AlphaVal,b11,cs_b) - ///implement TRSM/// - ////extract a00 - ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); -#ifndef BLIS_ENABLE_TRSM_PREINVERSION - /*performs dcomplex divison of ymm3 and ymm4 with ymm1*/ - BLIS_ZTRSM_TWO_DIV(ymm3,ymm4) -#else - /*performs dcomplex multiplication of ymm3 and ymm4 with ymm1*/ - BLIS_ZTRSM_MUL(ymm3) - BLIS_ZTRSM_MUL(ymm4) -#endif - //extract a11 - ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 1)); - //(ROW1): FMA operations - ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + rs_a*1)); - if(conjtransa) - { - ymm2 = _mm256_mul_pd(ymm2, ymm0); - } - /* Step1 dcomplex multiply ymm2, ymm3 - * Step2 negate the result - * Step3 add ymmx*/ - //Step 1 - ymm14 = _mm256_permute_pd(ymm2, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); - //For ymm3 - /* Multiply vec1 and vec2 */ - ymm13 = _mm256_mul_pd(ymm3, ymm2); /*vec3*/ - /* Multiply vec1 and the modified vec2 */ - ymm14 = _mm256_mul_pd(ymm3, ymm14); /*vec4*/ - /* Horizontally subtract the elements in vec3 and 
vec4 */ - ymm16 = _mm256_hsub_pd(ymm13, ymm14); - - //For ymm4 - ymm14 = _mm256_permute_pd(ymm2, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); - - ymm13 = _mm256_mul_pd(ymm4, ymm2); - ymm14 = _mm256_mul_pd(ymm4, ymm14); - ymm17 = _mm256_hsub_pd(ymm13, ymm14); - //Step 2 - ymm16 = _mm256_mul_pd(ymm16, ymm15); - ymm17 = _mm256_mul_pd(ymm17, ymm15); - - //Step 3 - ymm5 = _mm256_add_pd(ymm16, ymm5); - ymm6 = _mm256_add_pd(ymm17, ymm6); + // Load b11 of size 4x6 and multiply with alpha + BLIS_PRE_ZTRSM_SMALL_1x2(AlphaVal,b11,cs_b) + ///implement TRSM/// + ////extract a00 + ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); #ifndef BLIS_ENABLE_TRSM_PREINVERSION - /*performs dcomplex divison of ymm5 and ymm6 with ymm1*/ - BLIS_ZTRSM_TWO_DIV(ymm5,ymm6) -#else - /*performs dcomplex multiplication of ymm5 and ymm6 with ymm1*/ - BLIS_ZTRSM_MUL(ymm5) - BLIS_ZTRSM_MUL(ymm6) -#endif - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + 2), ymm4); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b + 2), ymm6); - } - dim_t m_remainder = m - i; - if(m_remainder == 3) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j*rs_a; - b10 = B + i; - b11 = B + i + j*cs_b; - - k_iter = j; - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS - - ///GEMM implementation starts/// - BLIS_ZTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) - - // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_ZTRSM_SMALL_3x4(AlphaVal,b11,cs_b) - - ///implement TRSM/// - ////extract a00 - ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); -#ifndef BLIS_ENABLE_TRSM_PREINVERSION - /*performs dcomplex divison of ymm3 and ymm4 with ymm1*/ - BLIS_ZTRSM_TWO_DIV(ymm3,ymm4) -#else - /*performs dcomplex multiplication of ymm3 and ymm4 with ymm1*/ - BLIS_ZTRSM_MUL(ymm3) - BLIS_ZTRSM_MUL(ymm4) -#endif - //extract a11 - ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 1)); - //(ROW1): FMA operations - ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + rs_a*1)); - if(conjtransa) - { - ymm2 = _mm256_mul_pd(ymm2, ymm0); - } - /* Step1 dcomplex multiply ymm2, ymm3 - * Step2 negate the result - * Step3 add ymmx*/ - //Step 1 - ymm14 = _mm256_permute_pd(ymm2, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); - //For ymm3 - /* Multiply vec1 and vec2 */ - ymm13 = _mm256_mul_pd(ymm3, ymm2); /*vec3*/ - /* Multiply vec1 and the modified vec2 */ - ymm14 = _mm256_mul_pd(ymm3, ymm14); /*vec4*/ - /* Horizontally subtract the elements in vec3 and vec4 */ - ymm16 = _mm256_hsub_pd(ymm13, ymm14); - - //For ymm4 - ymm14 = _mm256_permute_pd(ymm2, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); - - ymm13 = _mm256_mul_pd(ymm4, ymm2); - ymm14 = _mm256_mul_pd(ymm4, ymm14); - ymm17 = _mm256_hsub_pd(ymm13, ymm14); - //Step 2 - ymm16 = _mm256_mul_pd(ymm16, ymm15); - ymm17 = _mm256_mul_pd(ymm17, ymm15); - - //Step 3 - ymm5 = _mm256_add_pd(ymm16, ymm5); - ymm6 = _mm256_add_pd(ymm17, ymm6); -#ifndef BLIS_ENABLE_TRSM_PREINVERSION - /*performs dcomplex divison of ymm5 and ymm6 with ymm1*/ - BLIS_ZTRSM_TWO_DIV(ymm5,ymm6) + /*performs dcomplex divison of ymm3 with ymm1*/ + BLIS_ZTRSM_DIV(ymm3) #else - /*performs dcomplex multiplication of ymm5 and ymm6 with 
ymm1*/ - BLIS_ZTRSM_MUL(ymm5) - BLIS_ZTRSM_MUL(ymm6) + /*performs dcomplex multiplication of ymm3 with ymm1*/ + BLIS_ZTRSM_MUL(ymm3) #endif - _mm256_storeu_pd((double *)b11, ymm3); - _mm_storeu_pd((double *)(b11 + 2), - _mm256_extractf128_pd(ymm4,0)); - - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm_storeu_pd((double *)(b11 + cs_b + 2), - _mm256_extractf128_pd(ymm6,0)); - m_remainder -= 3; - i += 3; - } - if(m_remainder == 2) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j*rs_a; - b10 = B + i; - b11 = B + i + j*cs_b; - - k_iter = j; - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS - - ///GEMM implementation starts/// - BLIS_ZTRSM_SMALL_GEMM_2nx2m(a01,b10,cs_b,p_lda,k_iter) - - // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_ZTRSM_SMALL_2x2(AlphaVal,b11,cs_b) - - ///implement TRSM/// - ////extract a00 - ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); -#ifndef BLIS_ENABLE_TRSM_PREINVERSION - /*performs dcomplex divison of ymm3 with ymm1*/ - BLIS_ZTRSM_DIV(ymm3) -#else - /*performs dcomplex multiplication of ymm3 with ymm1*/ - BLIS_ZTRSM_MUL(ymm3) -#endif - //extract a11 - ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 1)); - //(ROW1): FMA operations - ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + rs_a*1)); - if(conjtransa) - { - ymm2 = _mm256_mul_pd(ymm2, ymm0); - } - /* Step1 dcomplex multiply ymm2, ymm3 - * Step2 negate the result - * Step3 add ymmx*/ - //Step 1 - ymm14 = _mm256_permute_pd(ymm2, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); - //For ymm3 - /* Multiply vec1 and vec2 */ - ymm13 = _mm256_mul_pd(ymm3, ymm2); /*vec3*/ - /* Multiply vec1 and the modified vec2 */ - ymm14 = _mm256_mul_pd(ymm3, ymm14); /*vec4*/ - /* Horizontally subtract the elements in vec3 and vec4 */ - ymm16 = _mm256_hsub_pd(ymm13, ymm14); - - //Step 2 - ymm16 = _mm256_mul_pd(ymm16, ymm15); - - //Step 3 - ymm5 = _mm256_add_pd(ymm16, ymm5); -#ifndef BLIS_ENABLE_TRSM_PREINVERSION - /*performs dcomplex divison of ymm5 with ymm1*/ - BLIS_ZTRSM_DIV(ymm5) -#else - /*performs dcomplex multiplication of ymm5 with ymm1*/ - BLIS_ZTRSM_MUL(ymm5) -#endif + _mm256_storeu_pd((double *)b11, ymm3); + m_remainder -= 2; + i += 2; + } + if(m_remainder == 1) + { + a01 = D_A_pack; + //pointer to block of A to be used for TRSM + a11 = L + j*cs_a + j*rs_a; + //pointer to block of B to be used in GEMM + b10 = B + i; + //pointer to block of B to be used for TRSM + b11 = B + i + j*cs_b; - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - m_remainder -= 2; - i += 2; - } - if(m_remainder == 1) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j*rs_a; - b10 = B + i; - b11 = B + i + j*cs_b; - - k_iter = j; - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS - - ///GEMM implementation starts/// - BLIS_ZTRSM_SMALL_GEMM_2nx2m(a01,b10,cs_b,p_lda,k_iter) - - // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_ZTRSM_SMALL_2x2(AlphaVal,b11,cs_b) - - ///implement TRSM/// - ////extract a00 - ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); -#ifndef BLIS_ENABLE_TRSM_PREINVERSION - /*performs dcomplex divison of ymm3 with ymm1*/ - BLIS_ZTRSM_DIV(ymm3) -#else - /*performs dcomplex multiplication of ymm3 with ymm1*/ - BLIS_ZTRSM_MUL(ymm3) -#endif - //extract 
a11 - ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack + 1)); - //(ROW1): FMA operations - ymm2 = _mm256_broadcast_pd((__m128d const *)(a11 + rs_a*1)); - if(conjtransa) - { - ymm2 = _mm256_mul_pd(ymm2, ymm0); - } - /* Step1 dcomplex multiply ymm2, ymm3 - * Step2 negate the result - * Step3 add ymmx*/ - //Step 1 - ymm14 = _mm256_permute_pd(ymm2, 0x5); - /* Negate the imaginary elements of vec2 */ - ymm14 = _mm256_mul_pd(ymm14, ymm0); - //For ymm3 - /* Multiply vec1 and vec2 */ - ymm13 = _mm256_mul_pd(ymm3, ymm2); /*vec3*/ - /* Multiply vec1 and the modified vec2 */ - ymm14 = _mm256_mul_pd(ymm3, ymm14); /*vec4*/ - /* Horizontally subtract the elements in vec3 and vec4 */ - ymm16 = _mm256_hsub_pd(ymm13, ymm14); - - //Step 2 - ymm16 = _mm256_mul_pd(ymm16, ymm15); - - //Step 3 - ymm5 = _mm256_add_pd(ymm16, ymm5); -#ifndef BLIS_ENABLE_TRSM_PREINVERSION - /*performs dcomplex divison of ymm5 with ymm1*/ - BLIS_ZTRSM_DIV(ymm5) -#else - /*performs dcomplex multiplication of ymm5 with ymm1*/ - BLIS_ZTRSM_MUL(ymm5) -#endif - _mm_storeu_pd((double *)b11, - _mm256_extractf128_pd(ymm3,0)); - _mm_storeu_pd((double *)(b11 + cs_b), - _mm256_extractf128_pd(ymm5,0)); - m_remainder -= 1; - i += 1; - } - j += 2; - n_remainder -= 2; - } - else if(n_remainder == 1) - { - a01 = L + j*rs_a; - a11 = L + j*cs_a + j*rs_a; - dcomplex *ptr_a10_dup = D_A_pack; - dim_t p_lda = j; - // perform copy of A to packed buffer D_A_pack - - if(transa) - { - for(dim_t x =0;x < p_lda;x+=d_nr) - { - ymm0 = _mm256_loadu_pd((double const *)(a01)); - ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); - ymm3 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm4 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - - _mm256_storeu_pd((double *)(ptr_a10_dup), ymm3); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm4); - - ymm0 = _mm256_loadu_pd((double const *)(a01 + 2)); - ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a + 2)); - ymm3 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 2), - ymm3); - - ymm0 = _mm256_loadu_pd((double const *) - (a01 + cs_a * 2)); - ymm1 = _mm256_loadu_pd((double const *) - (a01 + cs_a * 2 + 2)); - ymm5 = _mm256_broadcast_pd((__m128d const *)&zero); - - ymm3 = _mm256_permute2f128_pd(ymm0,ymm5,0x20); - ymm4 = _mm256_permute2f128_pd(ymm0,ymm5,0x31); - ymm5 = _mm256_permute2f128_pd(ymm1,ymm5,0x20); - - _mm_storeu_pd((double *)(ptr_a10_dup + 2), - _mm256_extractf128_pd(ymm3,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + p_lda + 2), - _mm256_extractf128_pd(ymm4,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 2 + 2), - _mm256_extractf128_pd(ymm5, 0)); - a01 += d_nr*cs_a; - ptr_a10_dup += d_nr; - } - - } - else - { - dim_t loop_count = p_lda/2; - - for(dim_t x =0;x < loop_count;x++) - { - ymm15 = _mm256_loadu_pd((double const *) - (a01 + rs_a * 0 + x*2)); - _mm256_storeu_pd((double *) - (ptr_a10_dup + p_lda * 0 + x*2), ymm15); - } - - dim_t remainder_loop_count = p_lda - loop_count*2; - - __m128d xmm0; - if(remainder_loop_count != 0) - { - xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 0 + - loop_count*2)); - - _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + - loop_count*2), xmm0); - } - } - if(!is_unitdiag) - { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_pd((__m128d const *)(a11)); - ymm1 = _mm256_broadcast_pd((__m128d const *)&ones); - ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); -#ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm7 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - /*Taking denomerator multiplication of real & - * imaginary components*/ - 
ymm4 = _mm256_mul_pd(ymm1, ymm1); - /*Swapping real & imaginary component position for addition with - * respective components*/ - ymm6 = _mm256_permute4x64_pd(ymm4, 0xb1); - ymm4 = _mm256_add_pd(ymm4, ymm6); - /*Negating imaginary component of numerator*/ - ymm1 = _mm256_mul_pd(ymm1, ymm7); - /*Dividing numerator by denominator*/ - ymm1 = _mm256_div_pd(ymm1, ymm4); -#endif - } - else - { - ymm1 = _mm256_broadcast_pd((__m128d const *)&ones); - } - _mm256_storeu_pd((double *)(d11_pack), ymm1); - - for(i = 0; (i+d_mr-1) < m; i += d_mr) //loop along 'M' direction - { - a01 = D_A_pack; - a11 = L + j*cs_a + j*rs_a; - b10 = B + i; - b11 = B + i + j*cs_b; - - k_iter = j; - - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS - ///GEMM implementation starts/// - BLIS_ZTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) - BLIS_PRE_ZTRSM_SMALL_1x4(b11,cs_b,AlphaVal) - ///implement TRSM/// - ////extract a00 - ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); -#ifndef BLIS_ENABLE_TRSM_PREINVERSION - /*performs dcomplex divison of ymm3 and ymm4 with ymm1*/ - BLIS_ZTRSM_TWO_DIV(ymm3,ymm4) -#else - /*performs dcomplex multiplication of ymm3 and ymm4 with ymm1*/ - BLIS_ZTRSM_MUL(ymm3) - BLIS_ZTRSM_MUL(ymm4) -#endif - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + 2),ymm4); - } - dim_t m_remainder = m - i; - if(m_remainder == 3) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j*rs_a; - b10 = B + i; - b11 = B + i + j*cs_b; - - k_iter = j; - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS - - ///GEMM implementation starts/// - BLIS_ZTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) - BLIS_PRE_ZTRSM_SMALL_1x3(b11,cs_b,AlphaVal) - - ///implement TRSM/// - ////extract a00 - ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); -#ifndef BLIS_ENABLE_TRSM_PREINVERSION - /*performs dcomplex divison of ymm3 and ymm4 with ymm1*/ - BLIS_ZTRSM_TWO_DIV(ymm3,ymm4) -#else - /*performs dcomplex multiplication of ymm3 and ymm4 with ymm1*/ - BLIS_ZTRSM_MUL(ymm3) - BLIS_ZTRSM_MUL(ymm4) -#endif + //number of GEMM operations to be done(in blocks of 4x4) + k_iter = j; - _mm256_storeu_pd((double *)b11, ymm3); - _mm_storeu_pd((double *)(b11 + 2), - _mm256_extractf128_pd(ymm4,0)); - m_remainder -= 3; - i += 3; - } - if(m_remainder == 2) - { - a01 = D_A_pack; - //pointer to block of A to be used for TRSM - a11 = L + j*cs_a + j*rs_a; - //pointer to block of B to be used in GEMM - b10 = B + i; - //pointer to block of B to be used for TRSM - b11 = B + i + j*cs_b; - //number of GEMM operations to be done - k_iter = j; - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS - - ///GEMM implementation starts/// - BLIS_ZTRSM_SMALL_GEMM_1nx2m(a01,b10,cs_b,p_lda,k_iter) - - // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_ZTRSM_SMALL_1x2(AlphaVal,b11,cs_b) - - ///implement TRSM/// - ////extract a00 - ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); -#ifndef BLIS_ENABLE_TRSM_PREINVERSION - /*performs dcomplex divison of ymm3 with ymm1*/ - BLIS_ZTRSM_DIV(ymm3) -#else - /*performs dcomplex multiplication of ymm3 with ymm1*/ - BLIS_ZTRSM_MUL(ymm3) -#endif + /*Fill zeros into ymm registers used in gemm 
accumulations*/ + BLIS_SET_YMM_REG_ZEROS - _mm256_storeu_pd((double *)b11, ymm3); - m_remainder -= 2; - i += 2; - } - if(m_remainder == 1) - { - a01 = D_A_pack; - //pointer to block of A to be used for TRSM - a11 = L + j*cs_a + j*rs_a; - //pointer to block of B to be used in GEMM - b10 = B + i; - //pointer to block of B to be used for TRSM - b11 = B + i + j*cs_b; - - //number of GEMM operations to be done(in blocks of 4x4) - k_iter = j; - - /*Fill zeros into ymm registers used in gemm accumulations*/ - BLIS_SET_YMM_REG_ZEROS - - ///GEMM implementation starts/// - BLIS_ZTRSM_SMALL_GEMM_1nx1m(a01,b10,cs_b,p_lda,k_iter) - - // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_ZTRSM_SMALL_1x1(AlphaVal,b11,cs_b) - - ///implement TRSM/// - ////extract a00 - ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); + ///GEMM implementation starts/// + BLIS_ZTRSM_SMALL_GEMM_1nx1m(a01,b10,cs_b,p_lda,k_iter) + + // Load b11 of size 4x6 and multiply with alpha + BLIS_PRE_ZTRSM_SMALL_1x1(AlphaVal,b11,cs_b) + + ///implement TRSM/// + ////extract a00 + ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + ymm1 = _mm256_broadcast_pd((__m128d const *)(d11_pack)); #ifndef BLIS_ENABLE_TRSM_PREINVERSION - /*performs dcomplex divison of ymm3 with ymm1*/ - BLIS_ZTRSM_DIV(ymm3) + /*performs dcomplex divison of ymm3 with ymm1*/ + BLIS_ZTRSM_DIV(ymm3) #else - /*performs dcomplex multiplication of ymm3 with ymm1*/ - BLIS_ZTRSM_MUL(ymm3) + /*performs dcomplex multiplication of ymm3 with ymm1*/ + BLIS_ZTRSM_MUL(ymm3) #endif - _mm_storeu_pd((double *)b11, - _mm256_extractf128_pd(ymm3,0)); - m_remainder -= 1; - i += 1; - } - j += 1; - n_remainder -= 1; - } - - if ((required_packing_A == 1) && - bli_mem_is_alloc( &local_mem_buf_A_s )) - { - bli_membrk_release(&rntm, &local_mem_buf_A_s); - } - - - return BLIS_SUCCESS; + _mm_storeu_pd((double *)b11, + _mm256_extractf128_pd(ymm3,0)); + m_remainder -= 1; + i += 1; + } + j += 1; + n_remainder -= 1; + } + + if ((required_packing_A == 1) && + bli_mem_is_alloc( &local_mem_buf_A_s )) + { + bli_membrk_release(&rntm, &local_mem_buf_A_s); + } + + + return BLIS_SUCCESS; } BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB @@ -18082,7 +36517,7 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB cntl_t* cntl ) { - return BLIS_SUCCESS; + return BLIS_SUCCESS; } BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB @@ -18094,7 +36529,7 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB cntl_t* cntl ) { - return BLIS_SUCCESS; + return BLIS_SUCCESS; } BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB @@ -18106,7 +36541,7 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB cntl_t* cntl ) { - return BLIS_SUCCESS; + return BLIS_SUCCESS; } BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB @@ -18118,55 +36553,6 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB cntl_t* cntl ) { - return BLIS_SUCCESS; -} - -BLIS_INLINE err_t bli_strsm_small_AutXB_AlXB -( - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl -) -{ - return BLIS_SUCCESS; -} - -BLIS_INLINE err_t bli_strsm_small_AltXB_AuXB -( - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl -) -{ - return BLIS_SUCCESS; -} - -BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB -( - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl -) -{ - return BLIS_SUCCESS; -} - -BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB -( - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl 
-) -{ - return BLIS_SUCCESS; + return BLIS_SUCCESS; } - -#endif //BLIS_ENABLE_SMALL_MATRIX_TRSM +#endif //BLIS_ENABLE_SMALL_MATRIX_TRSM \ No newline at end of file From 590c763e225d6f76f0db02ae7d57e911597b0c7b Mon Sep 17 00:00:00 2001 From: Harsh Dave Date: Sun, 3 Oct 2021 11:44:43 -0500 Subject: [PATCH 034/243] Implemented ctrsm small kernels Details: -- AMD Internal Id: CPUPL-1702 -- Used 8x3 CGEMM kernel with vector fma by utilizing ymm registers efficiently to produce 24 scomplex outputs at a time -- Used packing of matrix A to effectively cache and reuse -- Implemented kernels using macro based modular approach -- Added ctrsm_small for in ctrsm_ BLAS path for single thread when (m,n)<1000 and multithread (m+n)<320 -- Taken care of --disable_pre_inversion configuration -- Achieved 13% average performance improvement for sizes less than 1000 -- modularized all 16 combinations of trsm into 4 kernels Change-Id: I557c5bcd8cb7c034acd99ce0666bc411e9c4fe64 --- frame/compat/bla_trsm.c | 307 +- kernels/zen/3/bli_trsm_small.c | 9547 +++++++++++++++++++++++++++++++- 2 files changed, 9751 insertions(+), 103 deletions(-) diff --git a/frame/compat/bla_trsm.c b/frame/compat/bla_trsm.c index a2703d1cdd..b29219fe56 100644 --- a/frame/compat/bla_trsm.c +++ b/frame/compat/bla_trsm.c @@ -920,7 +920,7 @@ void ztrsm_ trans_t blis_transa; diag_t blis_diaga; dim_t m0, n0; - //conj_t conja = BLIS_NO_CONJUGATE ; + conj_t conja = BLIS_NO_CONJUGATE ; /* Initialize BLIS. */ bli_init_auto(); @@ -997,47 +997,8 @@ void ztrsm_ } else if( ( blis_side == BLIS_RIGHT ) && ( m0 != 1 ) ) { - /** NOTE: Since for RUCN kernel, function seem to - * be having issue with the computation, which is - * causing make check to fail, For time being, letting - * this particular case through small ztrsm for sake - * of make check. - * TODO: code snippet needs to be enabled, once - * fix is done. - */ - - /* b = alpha * b; */ -/* bli_zscalv_ex - ( - conja, - m0, - (dcomplex*)alpha, - (dcomplex*)b, rs_b, - NULL, - NULL - ); - if(blis_diaga == BLIS_NONUNIT_DIAG) - { - dcomplex inva = {0, 0}; - inva.real = a->real; - inva.imag = (a->imag * -1.0); - double dnm = (a->real * a->real); - dnm += ( (-1.0 * (a->imag * a->imag )) * -1.0 ); - inva.real /= dnm; - inva.imag /= dnm; - for(int indx = 0; indx < m0; indx ++) - { - double real = (inva.real * b[indx].real); - real += ((inva.imag * b[indx].imag) * -1.0); - double imag = (inva.real * b[indx].imag); - imag += (inva.imag * b[indx].real); - b[indx].real = real; - b[indx].imag = imag; - } - } - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return;*/ - } + ; + } } else if( m0 == 1 ) { @@ -1088,45 +1049,7 @@ void ztrsm_ } else if(( blis_side == BLIS_LEFT ) && ( n0 != 1 )) { - /** NOTE: Since for LUCN kernel, function seem to - * be having issue with the computation, which is - * causing make check to fail, For time being, letting - * this particular case through small ztrsm for sake - * of make check. - * TODO: code snippet needs to be enabled, once - * fix is done. 
- */ - /* b = alpha * b; */ -/* bli_zscalv_ex - ( - conja, - n0, - (dcomplex*)alpha, - (dcomplex*)b, cs_b, - NULL, - NULL - ); - if(blis_diaga == BLIS_NONUNIT_DIAG) - { - dcomplex inva = {0, 0}; - inva.real = a->real; - inva.imag = (a->imag * -1.0); - double dnm = (a->real * a->real); - dnm += ( (-1.0 * (a->imag * a->imag )) * -1.0 ); - inva.real /= dnm; - inva.imag /= dnm; - for(int indx = 0; indx < n0; indx ++) - { - double real = (inva.real * b[indx*cs_b].real); - real += ((inva.imag * b[indx*cs_b].imag) * -1.0); - double imag = (inva.real * b[indx*cs_b].imag); - imag += (inva.imag * b[indx*cs_b].real); - b[indx*cs_b].real = real; - b[indx*cs_b].imag = imag; - } - } - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return;*/ + ; } } @@ -1157,7 +1080,7 @@ void ztrsm_ * In case of multithread when [m,n]<=128 sinlge thread implemenation * is doing better than native multithread */ bool nt = bli_thread_get_is_parallel(); - if((nt==0 && m0<500 && n0<500) || + if((nt==0 && m0<=500 && n0<=500) || (nt && (m0+n0)<128) ) { err_t status; @@ -1195,7 +1118,225 @@ void ztrsm_ bli_finalize_auto(); } -GENTFUNC( scomplex, c, trsm, trsm ) +void ctrsm_ +( + const f77_char* side, + const f77_char* uploa, + const f77_char* transa, + const f77_char* diaga, + const f77_int* m, + const f77_int* n, + const scomplex* alpha, + const scomplex* a, const f77_int* lda, + scomplex* b, const f77_int* ldb +) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO) + AOCL_DTL_LOG_TRSM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 's', + *side, *uploa,*transa, *diaga, *m, *n, + (void*)alpha,*lda, *ldb); + + side_t blis_side; + uplo_t blis_uploa; + trans_t blis_transa; + diag_t blis_diaga; + dim_t m0, n0; + conj_t conja = BLIS_NO_CONJUGATE ; + + /* Initialize BLIS. */ + bli_init_auto(); + + /* Perform BLAS parameter checking. */ + PASTEBLACHK(trsm) + ( + MKSTR(c), + MKSTR(trsm), + side, + uploa, + transa, + diaga, + m, + n, + lda, + ldb + ); + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + bli_param_map_netlib_to_blis_side( *side, &blis_side ); + bli_param_map_netlib_to_blis_uplo( *uploa, &blis_uploa ); + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); + bli_param_map_netlib_to_blis_diag( *diaga, &blis_diaga ); + + /* Typecast BLAS integers to BLIS integers. */ + bli_convert_blas_dim1( *m, m0 ); + bli_convert_blas_dim1( *n, n0 ); + + /* Set the row and column strides of the matrix operands. 
*/ + const inc_t rs_a = 1; + const inc_t cs_a = *lda; + const inc_t rs_b = 1; + const inc_t cs_b = *ldb; + const num_t dt = BLIS_SCOMPLEX; + + + if( n0 == 1 ) + { + if( blis_side == BLIS_LEFT ) + { + if(bli_is_notrans(blis_transa)) + { + bli_ctrsv_unf_var2 + ( + blis_uploa, + blis_transa, + blis_diaga, + m0, + (scomplex*)alpha, + (scomplex*)a, rs_a, cs_a, + (scomplex*)b, rs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + else if(bli_is_trans(blis_transa)) + { + bli_ctrsv_unf_var1 + ( + blis_uploa, + blis_transa, + blis_diaga, + m0, + (scomplex*)alpha, + (scomplex*)a, rs_a, cs_a, + (scomplex*)b, rs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + } + else if( ( blis_side == BLIS_RIGHT ) && ( m0 != 1 ) ) + { + ; + } + } + else if( m0 == 1 ) + { + if(blis_side == BLIS_RIGHT) + { + if(bli_is_notrans(blis_transa)) + { + if(blis_uploa == BLIS_UPPER) + blis_uploa = BLIS_LOWER; + else + blis_uploa = BLIS_UPPER; + + bli_ctrsv_unf_var1 + ( + blis_uploa, + blis_transa, + blis_diaga, + n0, + (scomplex*)alpha, + (scomplex*)a, cs_a, rs_a, + (scomplex*)b, cs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + else if(bli_is_trans(blis_transa)) + { + if(blis_uploa == BLIS_UPPER) + blis_uploa = BLIS_LOWER; + else + blis_uploa = BLIS_UPPER; + + bli_ctrsv_unf_var2 + ( + blis_uploa, + blis_transa, + blis_diaga, + n0, + (scomplex*)alpha, + (scomplex*)a, cs_a, rs_a, + (scomplex*)b, cs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + } + else if(( blis_side == BLIS_LEFT ) && ( n0 != 1 )) + { + ; + } + } + + const struc_t struca = BLIS_TRIANGULAR; + + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; + obj_t ao = BLIS_OBJECT_INITIALIZER; + obj_t bo = BLIS_OBJECT_INITIALIZER; + + dim_t mn0_a; + + bli_set_dim_with_side( blis_side, m0, n0, &mn0_a ); + + bli_obj_init_finish_1x1( dt, (scomplex*)alpha, &alphao ); + + bli_obj_init_finish( dt, mn0_a, mn0_a, (scomplex*)a, rs_a, cs_a, &ao ); + bli_obj_init_finish( dt, m0, n0, (scomplex*)b, rs_b, cs_b, &bo ); + + bli_obj_set_uplo( blis_uploa, &ao ); + bli_obj_set_diag( blis_diaga, &ao ); + bli_obj_set_conjtrans( blis_transa, &ao ); + + bli_obj_set_struc( struca, &ao ); +#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM + /* bli_ztrsm_small is performing better existing native + * implementations for [m,n]<=1000 for single thread. + * In case of multithread when [m,n]<=128 sinlge thread implemenation + * is doing better than native multithread */ + bool nt = bli_thread_get_is_parallel(); + if((nt==0 && m0<=1000 && n0<=1000) || + (nt && (m0+n0)<320) ) + { + err_t status; + status = bli_trsm_small + ( + blis_side, + &alphao, + &ao, + &bo, + NULL, + NULL + ); + if (status == BLIS_SUCCESS) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + /* Finalize BLIS. */ + bli_finalize_auto(); + return; + } + } +#endif + bli_trsmnat + ( + blis_side, + &alphao, + &ao, + &bo, + NULL, + NULL + ); + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) + /* Finalize BLIS. */ + bli_finalize_auto(); +} + #else INSERT_GENTFUNC_BLAS( trsm, trsm ) #endif diff --git a/kernels/zen/3/bli_trsm_small.c b/kernels/zen/3/bli_trsm_small.c index f4428564c5..576bec0abc 100644 --- a/kernels/zen/3/bli_trsm_small.c +++ b/kernels/zen/3/bli_trsm_small.c @@ -3815,15 +3815,31 @@ err_t bli_trsm_small bool uplo = bli_obj_is_upper(a); bool transa = bli_obj_has_trans(a); - /* ToDo: Temporary threshold condition for trsm single thread. 
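A hedged usage sketch of the new ctrsm_ entry point, assuming 32-bit BLAS integers and the usual trailing-underscore name mangling; the matrix values are illustrative only:

    #include <stdio.h>

    typedef struct { float real, imag; } scplx;   /* matches the BLIS scomplex layout */

    /* prototype of the BLAS interface added in this patch */
    void ctrsm_(const char *side, const char *uploa, const char *transa,
                const char *diaga, const int *m, const int *n,
                const scplx *alpha, const scplx *a, const int *lda,
                scplx *b, const int *ldb);

    int main(void)
    {
        int m = 3, n = 2, lda = 3, ldb = 3;
        scplx alpha = {1.0f, 0.0f};
        /* column-major 3x3 lower-triangular A and a 3x2 right-hand side B */
        scplx a[9] = { {2,0},{1,0},{1,0},  {0,0},{3,0},{1,0},  {0,0},{0,0},{4,0} };
        scplx b[6] = { {2,0},{3,0},{4,0},  {6,0},{3,0},{8,0} };

        ctrsm_("L", "L", "N", "N", &m, &n, &alpha, a, &lda, b, &ldb);   /* A * X = B */

        for (int j = 0; j < n; j++)
            for (int i = 0; i < m; i++)
                printf("X(%d,%d) = %g %+gi\n", i, j,
                       b[i + j*ldb].real, b[i + j*ldb].imag);
        return 0;
    }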
- * It will be updated with arch based threshold function which reads - * tunned thresholds for all 64 (datatype,side,uplo,transa,unit,) trsm - combinations. We arrived to this condition based on performance - comparsion with only available native path - */ - if(m > 1000 || n > 1000) { - return BLIS_NOT_YET_IMPLEMENTED; - } + num_t dt = bli_obj_dt(a); + switch(dt) + { + case BLIS_DOUBLE: + case BLIS_FLOAT: + case BLIS_SCOMPLEX: + { + if(m > 1000 || n > 1000) { + return BLIS_NOT_YET_IMPLEMENTED; + } + break; + } + case BLIS_DCOMPLEX: + { + if(m > 500 || n > 500) { + return BLIS_NOT_YET_IMPLEMENTED; + } + break; + } + default: + { + return BLIS_NOT_YET_IMPLEMENTED; + break; + } + } /* If alpha is zero, B matrix will become zero after scaling hence solution is also zero matrix */ @@ -3838,12 +3854,6 @@ err_t bli_trsm_small return BLIS_INVALID_ROW_STRIDE; } - //Curretnly optimized for double data type only - num_t dt = bli_obj_dt(a); - if (dt != BLIS_DOUBLE && dt != BLIS_FLOAT && dt != BLIS_DCOMPLEX) { - return BLIS_NOT_YET_IMPLEMENTED; - } - // A is expected to be triangular in trsm if (!bli_obj_is_upper_or_lower (a)) { return BLIS_EXPECTED_TRIANGULAR_OBJECT; @@ -36508,6 +36518,2056 @@ BLIS_INLINE err_t bli_ztrsm_small_XAltB_XAuB return BLIS_SUCCESS; } +/* + * CTRSM utilities + */ + +#define SCOMPLEX_INV(a, b) {\ + a.real = b.real;\ + a.imag = (b.imag * -1.0);\ + /*Compute denominator eliminating imaginary component*/\ + float dnm = (b.real * b.real);\ + /*multiply two times with -1 for correct result as + * dcomplex number with positive imaginary part will + * invert the sign if not multiplied twice with -1*/\ + dnm += ((-1.0 * (b.imag * b.imag)) * -1.0);\ + /*Compute the final result by dividing real and imag part by dnm*/\ + a.real /= dnm;\ + a.imag /= dnm;\ +} + +#define SCOMPLEX_MUL(a, b, c) {\ + float real = a.real * b.real;\ + real += ((a.imag * b.imag) * -1.0);\ + float imag = (a.real * b.imag);\ + imag += (a.imag * b.real);\ + c.real = real;\ + c.imag = imag;\ +} + +#define SCOMPLEX_DIV(a, b){\ + float dnm = b.real * b.real;\ + dnm += (-1.0 * (b.imag * (b.imag * -1.0) ));\ + a.real /= dnm;\ + a.imag /= dnm;\ +} + +#ifdef BLIS_ENABLE_TRSM_PREINVERSION +#define CTRSM_DIAG_ELE_INV_OPS(a,b){\ + SCOMPLEX_INV(a, b)\ +} +#endif + +#ifdef BLIS_DISABLE_TRSM_PREINVERSION +#define CTRSM_DIAG_ELE_INV_OPS(a,b) {\ + a.real = b.real;\ + a.imag = b.imag;\ +} +#endif + + +#ifdef BLIS_ENABLE_TRSM_PREINVERSION +#define CTRSM_DIAG_ELE_EVAL_OPS(a,b,c){\ + if(!is_unitdiag)\ + SCOMPLEX_MUL(b, c, c)\ +} +#endif + +#ifdef BLIS_DISABLE_TRSM_PREINVERSION +#define CTRSM_DIAG_ELE_EVAL_OPS(a,b,c){\ + if(!is_unitdiag)\ + {\ + a.real = b.real;\ + a.imag = (b.imag * -1.0);\ + SCOMPLEX_MUL(c, a, c)\ + SCOMPLEX_DIV(c, b)\ + }\ +} +#endif + + +BLIS_INLINE err_t ctrsm_AltXB_ref +( + scomplex *A, + scomplex *B, + dim_t M, + dim_t N, + dim_t lda, + dim_t ldb, + bool is_unitdiag, + bool conjtransa +) +{ + dim_t i, j, k; + for (k = M-1; k >= 0; k--) + { + scomplex lkk_inv = {1.0, 1.0}, cur_compute = {0.0, 0.0}, A_trans = {0.0, 0.0}; + if(!is_unitdiag) + { + CTRSM_DIAG_ELE_INV_OPS(lkk_inv, A[k+k*lda]) + if(conjtransa) + { + lkk_inv.imag *= -1.0; + } + } + + for (j = N-1; j >= 0; j--) + { + CTRSM_DIAG_ELE_EVAL_OPS(cur_compute, lkk_inv, B[k + j*ldb]) + + for (i = k-1; i >= 0; i--) + { + if(conjtransa) + { + A_trans.real = A[i*lda + k].real; + A_trans.imag = A[i*lda + k].imag * -1.0; + } + else + { + A_trans.real = A[i*lda + k].real; + A_trans.imag = A[i*lda + k].imag; + } + SCOMPLEX_MUL(A_trans, B[k+j*ldb], cur_compute) + B[i 
+ j*ldb].real -= cur_compute.real; + B[i + j*ldb].imag -= cur_compute.imag; + } + } + } + return BLIS_SUCCESS; +} + +BLIS_INLINE err_t ctrsm_AuXB_ref +( + scomplex *A, + scomplex *B, + dim_t M, + dim_t N, + dim_t lda, + dim_t ldb, + bool is_unitdiag, + bool conjtransa +) +{ + dim_t i, j, k; + for (k = M-1; k >= 0; k--) + { + scomplex lkk_inv = {1.0, 1.0}, cur_compute = {0.0, 0.0}, A_trans = {0.0, 0.0}; + if(!is_unitdiag) + { + CTRSM_DIAG_ELE_INV_OPS(lkk_inv, A[k+k*lda]) + if(conjtransa) + { + lkk_inv.imag *= -1.0; + } + } + for (j = N-1; j >= 0; j--) + { + CTRSM_DIAG_ELE_EVAL_OPS(cur_compute, lkk_inv, B[k + j*ldb]) + for (i = k-1; i >= 0; i--) + { + if(conjtransa) + { + A_trans.real = A[i+k*lda].real; + A_trans.imag = A[i+k*lda].imag * -1.0; + } + else + { + A_trans.real = A[i+k*lda].real; + A_trans.imag = A[i+k*lda].imag; + } + SCOMPLEX_MUL(A_trans, B[k+j*ldb], cur_compute) + B[i + j*ldb].real -= cur_compute.real; + B[i + j*ldb].imag -= cur_compute.imag; + } + } + + } + return BLIS_SUCCESS; +} + +BLIS_INLINE err_t ctrsm_AutXB_ref +( + scomplex *A, + scomplex *B, + dim_t M, + dim_t N, + dim_t lda, + dim_t ldb, + bool is_unitdiag, + bool conjtransa +) +{ + dim_t i, j, k; + for (k = 0; k < M; k++) + { + scomplex lkk_inv = {1.0, 1.0}, cur_compute = {0.0, 0.0}, A_trans = {0.0, 0.0}; + if(!is_unitdiag) + { + CTRSM_DIAG_ELE_INV_OPS(lkk_inv, A[k+k*lda]) + if(conjtransa) + { + lkk_inv.imag *= -1.0; + } + } + + for (j = 0; j < N; j++) + { + CTRSM_DIAG_ELE_EVAL_OPS(cur_compute, lkk_inv, B[k + j*ldb]) + for (i = k+1; i < M; i++) + { + if(conjtransa) + { + A_trans.real = A[i*lda + k].real; + A_trans.imag = A[i*lda + k].imag * -1.0; + } + else + { + A_trans.real = A[i*lda + k].real; + A_trans.imag = A[i*lda + k].imag; + } + SCOMPLEX_MUL(A_trans, B[k+j*ldb], cur_compute) + B[i + j*ldb].real -= cur_compute.real; + B[i + j*ldb].imag -= cur_compute.imag; + } + } + } + return BLIS_SUCCESS; +} + +BLIS_INLINE err_t ctrsm_AlXB_ref +( + scomplex *A, + scomplex *B, + dim_t M, + dim_t N, + dim_t lda, + dim_t ldb, + bool is_unitdiag, + bool conjtransa +) +{ + dim_t i, j, k; + for (k = 0; k < M; k++) + { + scomplex lkk_inv = {1.0, 1.0}, cur_compute = {0.0, 0.0}, A_trans = {0.0, 0.0}; + if(!is_unitdiag) + { + CTRSM_DIAG_ELE_INV_OPS(lkk_inv, A[k+k*lda]) + if(conjtransa) + { + lkk_inv.imag *= -1.0; + } + } + for (j = 0; j < N; j++) + { + CTRSM_DIAG_ELE_EVAL_OPS(cur_compute, lkk_inv, B[k + j*ldb]) + for (i = k+1; i < M; i++) + { + if(conjtransa) + { + A_trans.real = A[i+k*lda].real; + A_trans.imag = A[i+k*lda].imag * -1.0; + } + else + { + A_trans.real = A[i+k*lda].real; + A_trans.imag = A[i+k*lda].imag; + } + SCOMPLEX_MUL(A_trans, B[k+j*ldb], cur_compute) + B[i + j*ldb].real -= cur_compute.real; + B[i + j*ldb].imag -= cur_compute.imag; + } + } + + } + return BLIS_SUCCESS; +} + +BLIS_INLINE void bli_ctrsm_small_pack +( + char side, + dim_t size, + bool trans, + scomplex *inbuf, + dim_t cs_a, + scomplex *pbuff, + dim_t p_lda, + dim_t mr +) +{ + //scratch registers + __m256 ymm0, ymm1, ymm2,ymm3,ymm4; + __m256 ymm5, ymm6, ymm7; + __m256 ymm8, ymm9, ymm10, ymm11,ymm12,ymm13; + __m128 xmm0,xmm1,xmm2; + + if(side=='L'||side=='l') + { + /*Left case is 4xk*/ + if(trans) + { + for(dim_t x = 0; x < size; x += mr) + { + ymm0 = _mm256_loadu_ps((float const *)(inbuf)); + ymm10 = _mm256_loadu_ps((float const *)(inbuf + 4)); + ymm1 = _mm256_loadu_ps((float const *)(inbuf + cs_a)); + ymm11 = _mm256_loadu_ps((float const *)(inbuf + 4 + cs_a)); + ymm2 = _mm256_loadu_ps((float const *)(inbuf + cs_a * 2)); + ymm12 = 
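The reference kernels above spell out the substitution recurrences with the SCOMPLEX_* macros; an equivalent scalar sketch of the AltXB case (A lower-triangular, transposed, optionally conjugated) using C99 complex arithmetic, with illustrative data:

    #include <complex.h>
    #include <stdio.h>

    static void trsm_altxb_ref(const float complex *A, float complex *B,
                               int M, int N, int lda, int ldb, int do_conj)
    {
        for (int k = M - 1; k >= 0; k--)
        {
            float complex akk = do_conj ? conjf(A[k + k*lda]) : A[k + k*lda];
            for (int j = N - 1; j >= 0; j--)
            {
                B[k + j*ldb] /= akk;                    /* diagonal step              */
                for (int i = k - 1; i >= 0; i--)        /* update rows above the pivot */
                {
                    float complex aik = do_conj ? conjf(A[i*lda + k]) : A[i*lda + k];
                    B[i + j*ldb] -= aik * B[k + j*ldb];
                }
            }
        }
    }

    int main(void)
    {
        float complex A[4] = {2.0f, 1.0f + 1.0f*I, 0.0f, 4.0f};  /* 2x2 lower, column-major */
        float complex B[2] = {4.0f + 2.0f*I, 8.0f};
        trsm_altxb_ref(A, B, 2, 1, 2, 2, 0);                     /* solves A^T x = b */
        printf("x = [%g%+gi, %g%+gi]\n",
               crealf(B[0]), cimagf(B[0]), crealf(B[1]), cimagf(B[1]));
        return 0;
    }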
_mm256_loadu_ps((float const *)(inbuf + 4 + cs_a * 2)); + ymm3 = _mm256_loadu_ps((float const *)(inbuf + cs_a * 3)); + ymm13 = _mm256_loadu_ps((float const *)(inbuf + 4 + cs_a * 3)); + + ymm4 = _mm256_shuffle_ps(ymm0, ymm1, 0x44); + ymm5 = _mm256_shuffle_ps(ymm2, ymm3, 0x44); + ymm6 = _mm256_permute2f128_ps(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_ps(ymm4,ymm5,0x31); + ymm0 = _mm256_shuffle_ps(ymm0, ymm1, 0xEE); + ymm1 = _mm256_shuffle_ps(ymm2, ymm3, 0xEE); + ymm7 = _mm256_permute2f128_ps(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_ps(ymm0,ymm1,0x31); + + _mm256_storeu_ps((float *)(pbuff), ymm6); + _mm256_storeu_ps((float *)(pbuff + p_lda), ymm7); + _mm256_storeu_ps((float *)(pbuff + p_lda*2), ymm8); + _mm256_storeu_ps((float *)(pbuff + p_lda*3), ymm9); + + ymm4 = _mm256_shuffle_ps(ymm10, ymm11, 0x44); + ymm5 = _mm256_shuffle_ps(ymm12, ymm13, 0x44); + + ymm6 = _mm256_permute2f128_ps(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_ps(ymm4,ymm5,0x31); + + ymm0 = _mm256_shuffle_ps(ymm10, ymm11, 0xEE); + ymm1 = _mm256_shuffle_ps(ymm12, ymm13, 0xEE); + + ymm7 = _mm256_permute2f128_ps(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_ps(ymm0,ymm1,0x31); + + _mm256_storeu_ps((float *)(pbuff + p_lda * 4), ymm6); + _mm256_storeu_ps((float *)(pbuff + p_lda * 5), ymm7); + _mm256_storeu_ps((float *)(pbuff + p_lda * 6), ymm8); + _mm256_storeu_ps((float *)(pbuff + p_lda * 7), ymm9); + + ymm0 = _mm256_loadu_ps((float const *)(inbuf + cs_a * 4)); + ymm10 = _mm256_loadu_ps((float const *)(inbuf + cs_a * 4 + 4)); + ymm1 = _mm256_loadu_ps((float const *)(inbuf + cs_a * 5)); + ymm11 = _mm256_loadu_ps((float const *)(inbuf + cs_a * 5 + 4)); + ymm2 = _mm256_loadu_ps((float const *)(inbuf + cs_a * 6)); + ymm12 = _mm256_loadu_ps((float const *)(inbuf + cs_a * 6 + 4)); + ymm3 = _mm256_loadu_ps((float const *)(inbuf + cs_a * 7)); + ymm13 = _mm256_loadu_ps((float const *)(inbuf + cs_a * 7 + 4)); + + ymm4 = _mm256_shuffle_ps(ymm0, ymm1, 0x44); + ymm5 = _mm256_shuffle_ps(ymm2, ymm3, 0x44); + ymm6 = _mm256_permute2f128_ps(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_ps(ymm4,ymm5,0x31); + ymm0 = _mm256_shuffle_ps(ymm0, ymm1, 0xEE); + ymm1 = _mm256_shuffle_ps(ymm2, ymm3, 0xEE); + ymm7 = _mm256_permute2f128_ps(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_ps(ymm0,ymm1,0x31); + + _mm256_storeu_ps((float *)(pbuff + 4), ymm6); + _mm256_storeu_ps((float *)(pbuff + 4 + p_lda), ymm7); + _mm256_storeu_ps((float *)(pbuff + 4 + p_lda*2), ymm8); + _mm256_storeu_ps((float *)(pbuff + 4 + p_lda*3), ymm9); + + ymm4 = _mm256_shuffle_ps(ymm10, ymm11, 0x44); + ymm5 = _mm256_shuffle_ps(ymm12, ymm13, 0x44); + ymm6 = _mm256_permute2f128_ps(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_ps(ymm4,ymm5,0x31); + ymm0 = _mm256_shuffle_ps(ymm10, ymm11, 0xEE); + ymm1 = _mm256_shuffle_ps(ymm12, ymm13, 0xEE); + ymm7 = _mm256_permute2f128_ps(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_ps(ymm0,ymm1,0x31); + + _mm256_storeu_ps((float *)(pbuff + 4 + p_lda * 4), ymm6); + _mm256_storeu_ps((float *)(pbuff + 4 + p_lda * 5), ymm7); + _mm256_storeu_ps((float *)(pbuff + 4 + p_lda * 6), ymm8); + _mm256_storeu_ps((float *)(pbuff + 4 + p_lda * 7), ymm9); + + inbuf += mr; + pbuff += mr*mr; + } + }else + { + //Expected multiples of 8 + p_lda = 8; + for(dim_t x = 0; x < size; x++) + { + ymm0 = _mm256_loadu_ps((float const *)(inbuf)); + _mm256_storeu_ps((float *)(pbuff), ymm0); + ymm1 = _mm256_loadu_ps((float const *)(inbuf + 4)); + _mm256_storeu_ps((float *)(pbuff + 4), ymm1); + inbuf+=cs_a; + pbuff+=p_lda; + } + } + } + else if(side=='R'||side=='r') + { + + 
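+		/*
+		 * Note (illustrative only, not executed): the left-side branch above packs an
+		 * 8-row slice of A10 so the GEMM kernels can always read the panel with unit
+		 * row stride. Conceptually the transposed pack amounts to
+		 *
+		 *     pbuff[c + r*p_lda] = inbuf[r + c*cs_a];   // r: source row, c: source column
+		 *
+		 * realised here with shuffle/permute2f128 pairs so the 4x4 scomplex transposes
+		 * stay in registers, while the non-transpose branch is a straight
+		 * column-by-column copy. The right-side branch below applies the same idea to
+		 * the XA = B kernels.
+		 */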
if(trans) + { + for(dim_t x=0; xbuffer; //value of alpha + scomplex *L = a->buffer; //pointer to matrix A + scomplex *B = b->buffer; //pointer to matrix B + + scomplex *a10, *a11, *b01, *b11; //pointers that point to blocks for GEMM and TRSM + + scomplex ones = {1.0, 1.0}; + bool is_unitdiag = bli_obj_has_unit_diag(a); + + //scratch registers + __m256 ymm0, ymm1, ymm2, ymm3; + __m256 ymm4, ymm5, ymm6, ymm7; + __m256 ymm8, ymm9, ymm10, ymm11; + __m256 ymm12, ymm13, ymm14, ymm15; + __m256 ymm16, ymm17, ymm18, ymm19; + + __m128 xmm0, xmm1, xmm2, xmm3, xmm4; + + gint_t required_packing_A = 1; + mem_t local_mem_buf_A_s = {0}; + scomplex *D_A_pack = NULL; + scomplex d11_pack[d_mr] __attribute__((aligned(64))); + rntm_t rntm; + + bli_rntm_init_from_global( &rntm ); + bli_rntm_set_num_threads_only( 1, &rntm ); + bli_membrk_rntm_set_membrk( &rntm ); + + siz_t buffer_size = bli_pool_block_size( + bli_membrk_pool( + bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), + bli_rntm_membrk(&rntm))); + + if ( (d_mr * m * sizeof(scomplex)) > buffer_size) + return BLIS_NOT_YET_IMPLEMENTED; + + if (required_packing_A == 1) + { + // Get the buffer from the pool. + bli_membrk_acquire_m(&rntm, + buffer_size, + BLIS_BITVAL_BUFFER_FOR_A_BLOCK, + &local_mem_buf_A_s); + if(FALSE==bli_mem_is_alloc(&local_mem_buf_A_s)) return BLIS_NULL_POINTER; + D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); + if(NULL==D_A_pack) return BLIS_NULL_POINTER; + } + + /* + Performs solving TRSM for 4 colmns at a time from 0 to m/4 in steps of d_mr + a. Load, transpose, Pack A (a10 block), the size of packing 4x3 to 4x (m-4) + First there will be no GEMM and no packing of a10 because it is only TRSM + b. Using packed a10 block and b01 block perform GEMM operation + c. Use GEMM outputs, perform TRSM operaton using a11, b11 and update B + d. Repeat b,c for n rows of B in steps of d_nr + */ + for(i = 0;(i+d_mr-1) < m; i += d_mr) //loop along 'M' dimension + { + a10 = L + (i*cs_a); //pointer to block of A to be used for GEMM + a11 = L + (i*rs_a) + (i*cs_a); + dim_t p_lda = d_mr; // packed leading dimension + + if(transa) + { + /* + Load, tranpose and pack current A block (a10) into packed buffer memory + D_A_pack + a. This a10 block is used in GEMM portion only and this + a10 block size will be increasing by d_mr for every next itteration + untill it reaches 4x(m-4) which is the maximum GEMM alone block size + in A + b. This packed buffer is reused to calculate all n rows of B matrix + */ + bli_ctrsm_small_pack('L', i, 1, a10, cs_a, D_A_pack, p_lda,d_mr); + + /* + Pack 4 diagonal elements of A block into an array + a. This helps in utilze cache line efficiently in TRSM operation + b. store ones when input is unit diagonal + */ + ctrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,d_mr); + } + else + { + bli_ctrsm_small_pack('L', i, 0, a10, rs_a, D_A_pack, p_lda,d_mr); + ctrsm_small_pack_diag_element(is_unitdiag,a11,rs_a,d11_pack,d_mr); + } + /* + a. Perform GEMM using a10, b01. + b. Perform TRSM on a11, b11 + c. This loop GEMM+TRSM loops operates with 4x3 block size + along n dimension for every d_nr rows of b01 where + packed A buffer is reused in computing all n rows of B. + d. Same approch is used in remaining fringe cases. 
+ */ + dim_t temp = n - d_nr + 1; + for(j = 0; j < temp; j += d_nr) //loop along 'N' dimension + { + a10 = D_A_pack; + a11 = L + (i*rs_a) + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + j*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + + k_iter = i; + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + /* + Peform GEMM between a10 and b01 blocks + For first itteration there will be no GEMM operation + where k_iter are zero + */ + BLIS_CTRSM_SMALL_GEMM_8mx3n(a10,b01,cs_b,p_lda,k_iter) + + /* + Load b11 of size 3x4 and multiply with alpha + Add the GEMM output and perform inregister transose of b11 + to peform TRSM operation. + */ + BLIS_CTRSM_SMALL_NREG_TRANSPOSE_3x8(b11,cs_b,AlphaVal) + /* + Compute 4x3 TRSM block by using GEMM block output in register + a. The 4x3 input (gemm outputs) are stored in combinations of ymm + registers + 1. ymm8, ymm4 2. ymm9, ymm5 3. ymm10, ymm6, 4. ymm11, ymm7 + where ymm8-ymm11 holds 4x2 data and reaming 4x1 will be hold by + other registers + b. Towards the end do in regiser transpose of TRSM output and store in + b11 + */ + ////extract a00 + ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm8) +#else + BLIS_CTRSM_MUL(ymm8) +#endif + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*1) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //extract a11 + //(ROw1): FMA operations + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm8, ymm8, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm8, ymm8,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm9 = _mm256_sub_ps(ymm9,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm8, ymm8, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm8, ymm8,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm10 = _mm256_sub_ps(ymm10,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*3) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm8, ymm8, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm8, ymm8,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm11 = _mm256_sub_ps(ymm11,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*4) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm8, ymm8, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm8, ymm8,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm12 = _mm256_sub_ps(ymm12,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*5) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 
= _mm256_shuffle_ps(ymm8, ymm8, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm8, ymm8,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm13 = _mm256_sub_ps(ymm13,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*6) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm8, ymm8, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm8, ymm8,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm14 = _mm256_sub_ps(ymm14,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*7) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm8, ymm8, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm8, ymm8,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm15 = _mm256_sub_ps(ymm15,ymm16); + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm9) +#else + BLIS_CTRSM_MUL(ymm9) +#endif + + + a11 += rs_a; + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm9 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm9, ymm9, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm9, ymm9,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm10 = _mm256_sub_ps(ymm10,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*3) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm9, ymm9, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm9, ymm9,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm11 = _mm256_sub_ps(ymm11,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*4) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm9, ymm9, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm9, ymm9,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm12 = _mm256_sub_ps(ymm12,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*5) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm9, ymm9, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm9, ymm9,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm13 = _mm256_sub_ps(ymm13,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*6) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm9, ymm9, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm9, ymm9,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm14 = 
_mm256_sub_ps(ymm14,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*7) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm9, ymm9, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm9, ymm9,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm15 = _mm256_sub_ps(ymm15,ymm16); + + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); + +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm10) +#else + BLIS_CTRSM_MUL(ymm10) +#endif + + + a11 += rs_a; + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*3) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm10 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm10, ymm10, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm10, ymm10,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm11 = _mm256_sub_ps(ymm11,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*4) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm10, ymm10, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm10, ymm10,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm12 = _mm256_sub_ps(ymm12,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*5) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm10, ymm10, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm10, ymm10,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm13 = _mm256_sub_ps(ymm13,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*6) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm10, ymm10, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm10, ymm10,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm14 = _mm256_sub_ps(ymm14,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*7) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm10, ymm10, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm10, ymm10,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm15 = _mm256_sub_ps(ymm15,ymm16); + + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 3)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm11) +#else + BLIS_CTRSM_MUL(ymm11) +#endif + + + a11 += rs_a; + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*4) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm11, ymm11, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm11, ymm11,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 
= _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm12 = _mm256_sub_ps(ymm12,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*5) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm11, ymm11, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm11, ymm11,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm13 = _mm256_sub_ps(ymm13,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*6) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm11, ymm11, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm11, ymm11,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm14 = _mm256_sub_ps(ymm14,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*7) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm11, ymm11, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm11, ymm11,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm15 = _mm256_sub_ps(ymm15,ymm16); + + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 4)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm12) +#else + BLIS_CTRSM_MUL(ymm12) +#endif + a11 += rs_a; + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*5) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm12, ymm12, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm12, ymm12,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm13 = _mm256_sub_ps(ymm13,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*6) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm12, ymm12, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm12, ymm12,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm14 = _mm256_sub_ps(ymm14,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*7) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm12, ymm12, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm12, ymm12,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm15 = _mm256_sub_ps(ymm15,ymm16); + + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 5)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm13) +#else + BLIS_CTRSM_MUL(ymm13) +#endif + + a11 += rs_a; + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*6) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm13, ymm13, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm13, ymm13,0xF5); 
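+			/*
+			 * Note: the shuffles above together with the mul/fmaddsub below form a
+			 * packed complex multiply. With an a11 element a = (ar, ai) replicated in
+			 * ymm2 and solved B values b = (br, bi) in the source register, they
+			 * produce (br*ar - bi*ai, br*ai + bi*ar) = b*a per pair, which is then
+			 * subtracted from the accumulator -- the x_j -= a_kj * x_k step of the
+			 * substitution. The same five-instruction pattern repeats throughout
+			 * these kernels.
+			 */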
+ ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm14 = _mm256_sub_ps(ymm14,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*7) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm13, ymm13, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm13, ymm13,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm15 = _mm256_sub_ps(ymm15,ymm16); + + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 6)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm14) +#else + BLIS_CTRSM_MUL(ymm14) +#endif + + a11 += rs_a; + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*7) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm14, ymm14, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm14, ymm14,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm15 = _mm256_sub_ps(ymm15,ymm16); + + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 7)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm15) +#else + BLIS_CTRSM_MUL(ymm15) +#endif + + BLIS_CTRSM_SMALL_NREG_TRANSPOSE_8x3_AND_STORE(b11,cs_b) + + } + dim_t n_rem = n-j; + if(n_rem) + { + a10 = D_A_pack; + a11 = L + (i*rs_a) + (i*cs_a); + b01 = B + j*cs_b; + b11 = B + i + j* cs_b; + + k_iter = i; + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + if(2 == n_rem) + { + BLIS_CTRSM_SMALL_GEMM_8mx2n(a10,b01,cs_b,p_lda,k_iter) + + float zero = 0.0; + ymm16 = _mm256_broadcast_ss(&AlphaVal.real); + ymm17 = _mm256_broadcast_ss(&AlphaVal.imag); + ymm2 = _mm256_broadcast_ss(&zero); + ymm3 = _mm256_broadcast_ss(&zero); + ymm6 = _mm256_broadcast_ss(&zero); + ymm7 = _mm256_broadcast_ss(&zero); + + ymm0 = _mm256_loadu_ps((float const *)(b11)); + ymm4 = _mm256_loadu_ps((float const *)(b11 + cs_b *1)); + ymm1 = _mm256_loadu_ps((float const *)(b11 + 4)); + ymm5 = _mm256_loadu_ps((float const *)(b11 + cs_b *1 + 4)); + + ymm2 = _mm256_fmadd_ps(ymm0, ymm16, ymm2); + ymm3 = _mm256_fmadd_ps(ymm1, ymm16, ymm3); + ymm6 = _mm256_fmadd_ps(ymm0, ymm17, ymm6); + ymm7 = _mm256_fmadd_ps(ymm1, ymm17, ymm7); + + ymm6 = _mm256_permute_ps(ymm6, 0xb1); + ymm7 = _mm256_permute_ps(ymm7, 0xb1); + + ymm0 = _mm256_addsub_ps(ymm2, ymm6); + ymm1 = _mm256_addsub_ps(ymm3, ymm7); + ymm0 = _mm256_sub_ps(ymm0, ymm8); + ymm1 = _mm256_sub_ps(ymm1, ymm12); + + ymm2 = _mm256_broadcast_ss(&zero); + ymm3 = _mm256_broadcast_ss(&zero); + + ymm6 = _mm256_broadcast_ss(&zero); + ymm7 = _mm256_broadcast_ss(&zero); + + ymm2 = _mm256_fmadd_ps(ymm4, ymm16, ymm2); + ymm3 = _mm256_fmadd_ps(ymm5, ymm16, ymm3); + ymm6 = _mm256_fmadd_ps(ymm4, ymm17, ymm6); + ymm7 = _mm256_fmadd_ps(ymm5, ymm17, ymm7); + + ymm6 = _mm256_permute_ps(ymm6, 0xb1); + ymm7 = _mm256_permute_ps(ymm7, 0xb1); + + ymm4 = _mm256_addsub_ps(ymm2, ymm6); + ymm5 = _mm256_addsub_ps(ymm3, ymm7); + ymm4 = _mm256_sub_ps(ymm4, ymm9); + ymm5 = _mm256_sub_ps(ymm5, ymm13); + ymm2 = _mm256_broadcast_ss((float const *)&ones); + ymm3 = _mm256_broadcast_ss((float const *)&ones); + + } + if(1 == n_rem) + { + BLIS_CTRSM_SMALL_GEMM_8mx1n(a10,b01,cs_b,p_lda,k_iter) + + float zero = 0.0; + ymm16 = 
_mm256_broadcast_ss(&AlphaVal.real); + ymm17 = _mm256_broadcast_ss(&AlphaVal.imag); + ymm2 = _mm256_broadcast_ss(&zero); + ymm3 = _mm256_broadcast_ss(&zero); + ymm6 = _mm256_broadcast_ss(&zero); + ymm7 = _mm256_broadcast_ss(&zero); + + ymm0 = _mm256_loadu_ps((float const *)(b11)); + ymm1 = _mm256_loadu_ps((float const *)(b11 + 4)); + + ymm2 = _mm256_fmadd_ps(ymm0, ymm16, ymm2); + ymm3 = _mm256_fmadd_ps(ymm1, ymm16, ymm3); + ymm6 = _mm256_fmadd_ps(ymm0, ymm17, ymm6); + ymm7 = _mm256_fmadd_ps(ymm1, ymm17, ymm7); + + ymm6 = _mm256_permute_ps(ymm6, 0xb1); + ymm7 = _mm256_permute_ps(ymm7, 0xb1); + + ymm0 = _mm256_addsub_ps(ymm2, ymm6); + ymm1 = _mm256_addsub_ps(ymm3, ymm7); + + ymm0 = _mm256_sub_ps(ymm0, ymm8); + ymm1 = _mm256_sub_ps(ymm1, ymm12); + + ymm2 = _mm256_broadcast_ss((float const *)&ones); + ymm4 = _mm256_broadcast_ss((float const *)&ones); + ymm5 = _mm256_broadcast_ss((float const *)&ones); + + } + ymm18 = _mm256_shuffle_ps(ymm0, ymm4, 0x44); + ymm19 = _mm256_shuffle_ps(ymm2, ymm2, 0x44); + /*BEFORE*/ + /*a[R0][I0] a[R1][I1] a[R2][I2] a[R3][I3] */ + /*b[R0][I0] b[R1][I1] b[R2][I2] b[R3][I3] */ + /*AFTER*/ + /*a[R0][I0] a[R1][I1] b[R0][I0] b[R1][I1]*/ + ymm8 = _mm256_permute2f128_ps(ymm18,ymm19,0x20); + ymm10 = _mm256_permute2f128_ps(ymm18,ymm19,0x31); + /*BEFORE*/ + /*a[R0][I0] a[R1][I1] a[R2][I2] a[R3][I3] */ + /*b[R0][I0] b[R1][I1] b[R2][I2] b[R3][I3] */ + /*AFTER*/ + /*a[R1][I1] b[R1][I1] */ + ymm18 = _mm256_shuffle_ps(ymm0, ymm4, 0xEE); + ymm19 = _mm256_shuffle_ps(ymm2, ymm2, 0xEE); + ymm9 = _mm256_permute2f128_ps(ymm18,ymm19,0x20); + ymm11 = _mm256_permute2f128_ps(ymm18,ymm19,0x31); + + ymm18 = _mm256_shuffle_ps(ymm1, ymm5, 0x44); + ymm19 = _mm256_shuffle_ps(ymm2, ymm2, 0x44); + /*BEFORE*/ + /*a[R0][I0] a[R1][I1] a[R2][I2] a[R3][I3] */ + /*b[R0][I0] b[R1][I1] b[R2][I2] b[R3][I3] */ + /*AFTER*/ + /*a[R0][I0] a[R1][I1] b[R0][I0] b[R1][I1]*/ + ymm12 = _mm256_permute2f128_ps(ymm18,ymm19,0x20); + ymm14 = _mm256_permute2f128_ps(ymm18,ymm19,0x31); + + /*BEFORE*/ + /*a[R0][I0] a[R1][I1] a[R2][I2] a[R3][I3] */ + /*b[R0][I0] b[R1][I1] b[R2][I2] b[R3][I3] */ + /*AFTER*/ + /*a[R1][I1] b[R1][I1] */ + ymm18 = _mm256_shuffle_ps(ymm1, ymm5, 0xEE); + ymm19 = _mm256_shuffle_ps(ymm2, ymm2, 0xEE); + ymm13 = _mm256_permute2f128_ps(ymm18,ymm19,0x20); + ymm15 = _mm256_permute2f128_ps(ymm18,ymm19,0x31); + + ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm8) +#else + BLIS_CTRSM_MUL(ymm8) +#endif + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*1) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //extract a11 + //(ROw1): FMA operations + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm8, ymm8, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm8, ymm8,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm9 = _mm256_sub_ps(ymm9,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm8, ymm8, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm8, ymm8,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm10 = _mm256_sub_ps(ymm10,ymm16); + + ymm2 = 
_mm256_broadcast_ps((__m128 const *) (a11 + cs_a*3) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm8, ymm8, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm8, ymm8,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm11 = _mm256_sub_ps(ymm11,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*4) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm8, ymm8, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm8, ymm8,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm12 = _mm256_sub_ps(ymm12,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*5) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm8, ymm8, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm8, ymm8,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm13 = _mm256_sub_ps(ymm13,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*6) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm8, ymm8, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm8, ymm8,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm14 = _mm256_sub_ps(ymm14,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*7) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm8, ymm8, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm8, ymm8,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm15 = _mm256_sub_ps(ymm15,ymm16); + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm9) +#else + BLIS_CTRSM_MUL(ymm9) +#endif + + + a11 += rs_a; + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm9 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm9, ymm9, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm9, ymm9,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm10 = _mm256_sub_ps(ymm10,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*3) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm9, ymm9, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm9, ymm9,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm11 = _mm256_sub_ps(ymm11,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*4) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, 
ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm9, ymm9, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm9, ymm9,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm12 = _mm256_sub_ps(ymm12,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*5) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm9, ymm9, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm9, ymm9,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm13 = _mm256_sub_ps(ymm13,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*6) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm9, ymm9, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm9, ymm9,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm14 = _mm256_sub_ps(ymm14,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*7) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm9, ymm9, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm9, ymm9,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm15 = _mm256_sub_ps(ymm15,ymm16); + + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); + +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm10) +#else + BLIS_CTRSM_MUL(ymm10) +#endif + + + a11 += rs_a; + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*3) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm10 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm10, ymm10, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm10, ymm10,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm11 = _mm256_sub_ps(ymm11,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*4) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm10, ymm10, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm10, ymm10,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm12 = _mm256_sub_ps(ymm12,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*5) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm10, ymm10, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm10, ymm10,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm13 = _mm256_sub_ps(ymm13,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*6) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm10, ymm10, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm10, ymm10,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, 
ymm16); + ymm14 = _mm256_sub_ps(ymm14,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*7) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm10, ymm10, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm10, ymm10,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm15 = _mm256_sub_ps(ymm15,ymm16); + + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 3)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm11) +#else + BLIS_CTRSM_MUL(ymm11) +#endif + + + a11 += rs_a; + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*4) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm11, ymm11, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm11, ymm11,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm12 = _mm256_sub_ps(ymm12,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*5) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm11, ymm11, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm11, ymm11,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm13 = _mm256_sub_ps(ymm13,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*6) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm11, ymm11, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm11, ymm11,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm14 = _mm256_sub_ps(ymm14,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*7) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm11, ymm11, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm11, ymm11,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm15 = _mm256_sub_ps(ymm15,ymm16); + + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 4)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm12) +#else + BLIS_CTRSM_MUL(ymm12) +#endif + a11 += rs_a; + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*5) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm12, ymm12, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm12, ymm12,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm13 = _mm256_sub_ps(ymm13,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*6) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm12, ymm12, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm12, ymm12,0xF5); + ymm16 = _mm256_mul_ps(ymm16, 
ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm14 = _mm256_sub_ps(ymm14,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*7) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm12, ymm12, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm12, ymm12,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm15 = _mm256_sub_ps(ymm15,ymm16); + + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 5)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm13) +#else + BLIS_CTRSM_MUL(ymm13) +#endif + + a11 += rs_a; + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*6) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm13, ymm13, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm13, ymm13,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm14 = _mm256_sub_ps(ymm14,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*7) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm13, ymm13, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm13, ymm13,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm15 = _mm256_sub_ps(ymm15,ymm16); + + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 6)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm14) +#else + BLIS_CTRSM_MUL(ymm14) +#endif + + a11 += rs_a; + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*7) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm14, ymm14, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm14, ymm14,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm15 = _mm256_sub_ps(ymm15,ymm16); + + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 7)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm15) +#else + BLIS_CTRSM_MUL(ymm15) +#endif + ymm1 = _mm256_shuffle_ps(ymm8, ymm9, 0x44); + ymm3 = _mm256_shuffle_ps(ymm10, ymm11, 0x44); + + /*rearrange low elements*/ + ymm0 = _mm256_permute2f128_ps(ymm1, ymm3, 0x20); + /*unpack high*/ + ymm8 = _mm256_shuffle_ps(ymm8, ymm9, 0xEE); + ymm9 = _mm256_shuffle_ps(ymm10, ymm11, 0xEE); + + /*rearrange high elements*/ + ymm1 = _mm256_permute2f128_ps(ymm8, ymm9, 0x20); + + ymm3 = _mm256_shuffle_ps(ymm12, ymm13, 0x44); + ymm4 = _mm256_shuffle_ps(ymm14, ymm15, 0x44); + + /*rearrange low elements*/ + ymm4 = _mm256_permute2f128_ps(ymm3, ymm4, 0x20); + /*unpack high*/ + ymm8 = _mm256_shuffle_ps(ymm12, ymm13, 0xEE); + ymm9 = _mm256_shuffle_ps(ymm14, ymm15, 0xEE); + + /*rearrange high elements*/ + ymm5 = _mm256_permute2f128_ps(ymm8, ymm9, 0x20); + + if(2 == n_rem) + { + _mm256_storeu_ps((float *)(b11 + cs_b * 0), ymm0); + _mm256_storeu_ps((float *)(b11 + cs_b * 1), ymm1); + _mm256_storeu_ps((float *)(b11 + cs_b * 0 + 4), ymm4); + _mm256_storeu_ps((float *)(b11 + cs_b * 1 + 4), ymm5); + } + 
else if(1 == n_rem) + { + _mm256_storeu_ps((float *)(b11 + cs_b * 0), ymm0); + _mm256_storeu_ps((float *)(b11 + cs_b * 0 + 4), ymm4); + } + } + } + dim_t m_rem = m - i; + if(m_rem >= 4) + { + a10 = L + (i*cs_a); + a11 = L + (i*rs_a) + (i*cs_a); + scomplex *ptr_a10_dup = D_A_pack; + dim_t p_lda = 4; + if(transa) + { + for(dim_t x =0;x < i;x+=p_lda) + { + ymm0 = _mm256_loadu_ps((float const *)(a10)); + ymm1 = _mm256_loadu_ps((float const *)(a10 + cs_a)); + ymm2 = _mm256_loadu_ps((float const *)(a10 + cs_a * 2)); + ymm3 = _mm256_loadu_ps((float const *)(a10 + cs_a * 3)); + + ymm4 = _mm256_shuffle_ps(ymm0, ymm1, 0x44); + ymm5 = _mm256_shuffle_ps(ymm2, ymm3, 0x44); + ymm6 = _mm256_permute2f128_ps(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_ps(ymm4,ymm5,0x31); + ymm0 = _mm256_shuffle_ps(ymm0, ymm1, 0xEE); + ymm1 = _mm256_shuffle_ps(ymm2, ymm3, 0xEE); + ymm7 = _mm256_permute2f128_ps(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_ps(ymm0,ymm1,0x31); + + + _mm256_storeu_ps((float *)(ptr_a10_dup), ymm6); + _mm256_storeu_ps((float *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_ps((float *)(ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_ps((float *)(ptr_a10_dup + p_lda*3), ymm9); + + a10 += p_lda; + ptr_a10_dup += p_lda*p_lda; + } + } + else + { + for(dim_t x =0;x < i;x++) + { + ymm0 = _mm256_loadu_ps((float const *)(a10 + rs_a * x)); + _mm256_storeu_ps((float *)(ptr_a10_dup + p_lda * x), ymm0); + } + } + + if(!is_unitdiag) + { + if(transa) + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_ps((__m128 const *)(a11)); + ymm1 = _mm256_broadcast_ps((__m128 const *)(a11+cs_a*1 + 1)); + ymm2 = _mm256_broadcast_ps((__m128 const *)(a11+cs_a*2 + 2)); + ymm3 = _mm256_broadcast_ps((__m128 const *)(a11+cs_a*3 + 3)); + ymm0 = _mm256_permute_ps(ymm0, 0x44); + ymm1 = _mm256_permute_ps(ymm1, 0x44); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + ymm3 = _mm256_permute_ps(ymm3, 0x44); + } + else + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_ps((__m128 const *)(a11)); + ymm1 = _mm256_broadcast_ps((__m128 const *)(a11+rs_a*1 + 1)); + ymm2 = _mm256_broadcast_ps((__m128 const *)(a11+rs_a*2 + 2)); + ymm3 = _mm256_broadcast_ps((__m128 const *)(a11+rs_a*3 + 3)); + ymm0 = _mm256_permute_ps(ymm0, 0x44); + ymm1 = _mm256_permute_ps(ymm1, 0x44); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + ymm3 = _mm256_permute_ps(ymm3, 0x44); + } + + ymm1 = _mm256_shuffle_ps(ymm0, ymm1, 0x44); + ymm2 = _mm256_shuffle_ps(ymm2, ymm3, 0x44); + ymm1 = _mm256_blend_ps(ymm1, ymm2, 0xF0); + +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + ymm7 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); + ymm4 = _mm256_mul_ps(ymm1, ymm1); + ymm6 = _mm256_permute_ps(ymm4, 0xB1); + ymm4 = _mm256_add_ps(ymm4, ymm6); + ymm1 = _mm256_mul_ps(ymm1, ymm7); + ymm1 = _mm256_div_ps(ymm1, ymm4); +#endif + } + _mm256_storeu_ps((float *)(d11_pack), ymm1); + + for(j = 0; (j+d_nr-1) < n; j += d_nr) + { + a10 = D_A_pack; + a11 = L + (i*rs_a) + (i*cs_a); + b01 = B + (j*cs_b); + b11 = B + i + (j* cs_b); + + k_iter = i; + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + BLIS_CTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) + BLIS_CTRSM_SMALL_NREG_TRANSPOSE_3x4(b11,cs_b,AlphaVal) + + ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm8) +#else + BLIS_CTRSM_MUL(ymm8) +#endif + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + 
cs_a*1) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + //extract a11 + //(ROw1): FMA operations + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm8, ymm8, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm8, ymm8,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm9 = _mm256_sub_ps(ymm9,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm8, ymm8, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm8, ymm8,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm10 = _mm256_sub_ps(ymm10,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*3) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm8, ymm8, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm8, ymm8,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm11 = _mm256_sub_ps(ymm11,ymm16); + + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm9) +#else + BLIS_CTRSM_MUL(ymm9) +#endif + + + a11 += rs_a; + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + //For ymm9 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm9, ymm9, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm9, ymm9,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm10 = _mm256_sub_ps(ymm10,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*3) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm9, ymm9, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm9, ymm9,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm11 = _mm256_sub_ps(ymm11,ymm16); + + + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); + +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm10) +#else + BLIS_CTRSM_MUL(ymm10) +#endif + + + a11 += rs_a; + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*3) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + //For ymm10 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm10, ymm10, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm10, ymm10,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm11 = _mm256_sub_ps(ymm11,ymm16); + + + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 3)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm11) +#else + BLIS_CTRSM_MUL(ymm11) +#endif + + + a11 += rs_a; + BLIS_CTRSM_SMALL_NREG_TRANSPOSE_4x3_AND_STORE(b11,cs_b) + } + dim_t n_rem = n-j; + if(n_rem) + { + a10 = D_A_pack; + a11 = L + (i*rs_a) + (i*cs_a); + b01 = B + j*cs_b; + b11 = B + i + j* cs_b; + + k_iter = i; 
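+			/*
+			 * N-dimension remainder of the 4-row fringe: the same GEMM + TRSM pattern
+			 * as the main loop, applied to 1 or 2 columns of B. In outline
+			 * (illustrative only):
+			 *
+			 *     B11 = alpha*B11 - A10*B01;   // GEMM over the k_iter = i already-solved rows
+			 *     solve A11 * X11 = B11;       // 4x4 triangular solve, X11 overwrites B11
+			 */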
+ BLIS_SET_S_YMM_REG_ZEROS + + if(2 == n_rem) + { + ///GEMM code begins/// + BLIS_CTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b,p_lda,k_iter) + BLIS_CTRSM_SMALL_NREG_TRANSPOSE_2x4(b11,cs_b,AlphaVal) + } + else if(1 == n_rem) + { + ///GEMM code begins/// + BLIS_CTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b,p_lda,k_iter) + BLIS_CTRSM_SMALL_NREG_TRANSPOSE_1x4(b11,cs_b,AlphaVal) + } + + + ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm8) +#else + BLIS_CTRSM_MUL(ymm8) +#endif + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*1) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + //extract a11 + //(ROw1): FMA operations + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm8, ymm8, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm8, ymm8,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm9 = _mm256_sub_ps(ymm9,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm8, ymm8, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm8, ymm8,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm10 = _mm256_sub_ps(ymm10,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*3) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm8, ymm8, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm8, ymm8,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm11 = _mm256_sub_ps(ymm11,ymm16); + + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm9) +#else + BLIS_CTRSM_MUL(ymm9) +#endif + + + a11 += rs_a; + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + //For ymm9 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm9, ymm9, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm9, ymm9,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm10 = _mm256_sub_ps(ymm10,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*3) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm9, ymm9, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm9, ymm9,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm11 = _mm256_sub_ps(ymm11,ymm16); + + + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); + +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm10) +#else + BLIS_CTRSM_MUL(ymm10) +#endif + + + a11 += rs_a; + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*3) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + //For ymm10 + ymm3 = 
_mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm10, ymm10, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm10, ymm10,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm11 = _mm256_sub_ps(ymm11,ymm16); + + + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 3)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm11) +#else + BLIS_CTRSM_MUL(ymm11) +#endif + + ymm1 = _mm256_shuffle_ps(ymm8, ymm9, 0x44); + ymm3 = _mm256_shuffle_ps(ymm10, ymm11, 0x44); + + ymm0 = _mm256_permute2f128_ps(ymm1, ymm3, 0x20); + ymm1 = _mm256_permute2f128_ps(ymm1, ymm3, 0x31); + + ymm2 = _mm256_shuffle_ps(ymm8, ymm9, 0xEE); + ymm3 = _mm256_shuffle_ps(ymm10, ymm11, 0xEE); + + ymm4 = _mm256_permute2f128_ps(ymm2, ymm3, 0x20); + ymm5 = _mm256_permute2f128_ps(ymm2, ymm3, 0x31); + + if(2 == n_rem) + { + _mm256_storeu_ps((float *)(b11 + cs_b * 0), ymm0); + _mm256_storeu_ps((float *)(b11 + cs_b * 1), ymm4); + } + else if(1 == n_rem) + { + _mm256_storeu_ps((float *)(b11 + cs_b * 0), ymm0); + } + } + m_rem -=4; + i +=4; + } + if(m_rem) + { + a10 = L + (i*cs_a); + scomplex *ptr_a10_dup = D_A_pack; + if(3 == m_rem) + { + dim_t p_lda = 4; + if(transa) + { + for(dim_t x =0;x < i;x+=p_lda) + { + ymm0 = _mm256_loadu_ps( + (float const *)(a10)); + ymm1 = _mm256_loadu_ps( + (float const *)(a10 + cs_a)); + ymm2 = _mm256_loadu_ps( + (float const *)(a10 + cs_a * 2)); + ymm3 = _mm256_loadu_ps( + (float const *)(a10 + cs_a * 3)); + + ymm4 = _mm256_shuffle_ps(ymm0, ymm1, 0x44); + ymm5 = _mm256_shuffle_ps(ymm2, ymm3, 0x44); + ymm6 = _mm256_permute2f128_ps(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_ps(ymm4,ymm5,0x31); + ymm0 = _mm256_shuffle_ps(ymm0, ymm1, 0xEE); + ymm1 = _mm256_shuffle_ps(ymm2, ymm3, 0xEE); + ymm7 = _mm256_permute2f128_ps(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_ps(ymm0,ymm1,0x31); + + + _mm256_storeu_ps((float *) + (ptr_a10_dup), ymm6); + _mm256_storeu_ps((float *) + (ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_ps((float *) + (ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_ps((float *) + (ptr_a10_dup + p_lda*3), ymm9); + + a10 += p_lda; + ptr_a10_dup += p_lda*p_lda; + } + } + else + { + for(dim_t x =0;x < i;x++) + { + ymm0 = _mm256_loadu_ps( + (float const *)(a10 + rs_a * x)); + _mm256_storeu_ps( + (float *)(ptr_a10_dup + + p_lda * x), ymm0); + } + } + + for(j = 0; (j+d_nr-1) < n; j += d_nr) + { + a10 = D_A_pack; + a11 = L + (i*rs_a) + (i*cs_a); + b01 = B + (j*cs_b); + b11 = B + i + (j* cs_b); + + k_iter = i; + BLIS_SET_S_YMM_REG_ZEROS + BLIS_CTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) + + ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); + ymm16 = _mm256_permute_ps(ymm16, 0x44); + + ymm0 = _mm256_loadu_ps((float const *)(b11)); + ymm1 = _mm256_loadu_ps((float const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_ps((float const *)(b11 + cs_b *2)); + + ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); + + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5); + ymm19 = _mm256_mul_ps(ymm19, ymm17); + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); + ymm8 = _mm256_sub_ps(ymm19, ymm8); + + ymm18 = _mm256_shuffle_ps(ymm1, ymm1, 0xA0); + ymm19 = _mm256_shuffle_ps(ymm1, ymm1,0xF5); + ymm19 = _mm256_mul_ps(ymm19, ymm17); + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); + ymm9 = _mm256_sub_ps(ymm19, ymm9); + + ymm18 = _mm256_shuffle_ps(ymm2, ymm2, 0xA0); + ymm19 = _mm256_shuffle_ps(ymm2, ymm2,0xF5); + ymm19 = _mm256_mul_ps(ymm19, ymm17); + ymm19 = 
_mm256_fmaddsub_ps(ymm18, ymm16, ymm19); + ymm10 = _mm256_sub_ps(ymm19, ymm10); + + xmm0 = _mm256_extractf128_ps(ymm8, 0); + xmm1 = _mm256_extractf128_ps(ymm9, 0); + xmm2 = _mm256_extractf128_ps(ymm10, 0); + + _mm_storeu_ps((float *)(b11), xmm0); + _mm_storeu_ps((float *)(b11 + cs_b * 1), xmm1); + _mm_storeu_ps((float *)(b11 + cs_b * 2), xmm2); + + xmm0 = _mm256_extractf128_ps(ymm8, 1); + xmm1 = _mm256_extractf128_ps(ymm9, 1); + xmm2 = _mm256_extractf128_ps(ymm10, 1); + + _mm_storel_pi((__m64 *)(b11 + 2), xmm0); + _mm_storel_pi((__m64 *)(b11 + cs_b * 1 + 2), xmm1); + _mm_storel_pi((__m64 *)(b11 + cs_b * 2 + 2), xmm2); + + + + if(transa) + ctrsm_AutXB_ref(a11, b11, m_rem, 3, + cs_a, cs_b,is_unitdiag, conjtransa); + else + { + ctrsm_AlXB_ref(a11, b11, m_rem, 3, + rs_a, cs_b, + is_unitdiag, conjtransa); + } + } + dim_t n_rem = n-j; + if(n_rem) + { + a10 = D_A_pack; + a11 = L + (i*rs_a) + (i*cs_a); + b01 = B + (j*cs_b); + b11 = B + i + (j* cs_b); + + k_iter = i; + BLIS_SET_S_YMM_REG_ZEROS + if(2 == n_rem) + { + ///GEMM code begins/// + BLIS_CTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b, + p_lda,k_iter) + + BLIS_PRE_CTRSM_SMALL_3M_2N(AlphaVal,b11,cs_b) + + if(transa) + ctrsm_AutXB_ref(a11, b11, m_rem, 2, + cs_a, cs_b,is_unitdiag, + conjtransa); + else + ctrsm_AlXB_ref(a11, b11, m_rem, 2, + rs_a, cs_b, is_unitdiag, + conjtransa); + } + else if(1 == n_rem) + { + ///GEMM code begins/// + BLIS_CTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b, + p_lda,k_iter) + + BLIS_PRE_CTRSM_SMALL_3M_1N(AlphaVal,b11,cs_b) + + if(transa) + ctrsm_AutXB_ref(a11, b11, m_rem, 1, + cs_a, cs_b, is_unitdiag, + conjtransa); + else + ctrsm_AlXB_ref(a11, b11, m_rem, 1, + rs_a, cs_b, is_unitdiag, + conjtransa); + } + } + } + if(2 == m_rem) + { + dim_t p_lda = 4; + if(transa) + { + for(dim_t x =0;x < i;x+=p_lda) + { + ymm0 = _mm256_loadu_ps( + (float const *)(a10)); + ymm1 = _mm256_loadu_ps( + (float const *)(a10 + cs_a)); + + ymm2 = _mm256_broadcast_ss((float const *)&ones); + + ymm4 = _mm256_shuffle_ps(ymm0, ymm1, 0x44); + ymm5 = _mm256_shuffle_ps(ymm2, ymm2, 0x44); + ymm6 = _mm256_permute2f128_ps(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_ps(ymm4,ymm5,0x31); + ymm0 = _mm256_shuffle_ps(ymm0, ymm1, 0xEE); + ymm1 = _mm256_shuffle_ps(ymm2, ymm2, 0xEE); + ymm7 = _mm256_permute2f128_ps(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_ps(ymm0,ymm1,0x31); + + + _mm256_storeu_ps((float *) + (ptr_a10_dup), ymm6); + _mm256_storeu_ps((float *) + (ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_ps((float *) + (ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_ps((float *) + (ptr_a10_dup + p_lda*3), ymm9); + + a10 += p_lda; + ptr_a10_dup += p_lda*p_lda; + } + } + else + { + for(dim_t x =0;x < i;x++) + { + ymm0 = _mm256_loadu_ps( + (float const *)(a10 + rs_a * x)); + _mm256_storeu_ps( + (float *)(ptr_a10_dup + + p_lda * x), ymm0); + } + } + + for(j = 0; (j+d_nr-1) < n; j += d_nr) + { + a10 = D_A_pack; + a11 = L + (i*rs_a) + (i*cs_a); + b01 = B + (j*cs_b); + b11 = B + i + (j* cs_b); + + k_iter = i; + BLIS_SET_S_YMM_REG_ZEROS + BLIS_CTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) + + ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); + ymm16 = _mm256_permute_ps(ymm16, 0x44); + + ymm0 = _mm256_loadu_ps((float const *)(b11)); + ymm1 = _mm256_loadu_ps((float const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_ps((float const *)(b11 + cs_b *2)); + + + ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); + + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5); + ymm19 = _mm256_mul_ps(ymm19, ymm17); + ymm19 = _mm256_fmaddsub_ps(ymm18, 
ymm16, ymm19); + ymm8 = _mm256_sub_ps(ymm19, ymm8); + + ymm18 = _mm256_shuffle_ps(ymm1, ymm1, 0xA0); + ymm19 = _mm256_shuffle_ps(ymm1, ymm1,0xF5); + ymm19 = _mm256_mul_ps(ymm19, ymm17); + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); + ymm9 = _mm256_sub_ps(ymm19, ymm9); + + ymm18 = _mm256_shuffle_ps(ymm2, ymm2, 0xA0); + ymm19 = _mm256_shuffle_ps(ymm2, ymm2,0xF5); + ymm19 = _mm256_mul_ps(ymm19, ymm17); + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); + ymm10 = _mm256_sub_ps(ymm19, ymm10); + + xmm0 = _mm256_extractf128_ps(ymm8, 0); + xmm1 = _mm256_extractf128_ps(ymm9, 0); + xmm2 = _mm256_extractf128_ps(ymm10, 0); + + _mm_storeu_ps((float *)(b11), xmm0); + _mm_storeu_ps((float *)(b11 + cs_b * 1), xmm1); + _mm_storeu_ps((float *)(b11 + cs_b * 2), xmm2); + + + + if(transa) + ctrsm_AutXB_ref(a11, b11, m_rem, 3, + cs_a, cs_b,is_unitdiag, conjtransa); + else + { + ctrsm_AlXB_ref(a11, b11, m_rem, 3, + rs_a, cs_b, + is_unitdiag, conjtransa); + } + } + dim_t n_rem = n-j; + if(n_rem) + { + a10 = D_A_pack; + a11 = L + (i*rs_a) + (i*cs_a); + b01 = B + (j*cs_b); + b11 = B + i + (j* cs_b); + + k_iter = i; + BLIS_SET_S_YMM_REG_ZEROS + if(2 == n_rem) + { + ///GEMM code begins/// + BLIS_CTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b, + p_lda,k_iter) + + BLIS_PRE_CTRSM_SMALL_2M_2N(AlphaVal,b11,cs_b) + + if(transa) + ctrsm_AutXB_ref(a11, b11, m_rem, 2, + cs_a, cs_b,is_unitdiag, + conjtransa); + else + ctrsm_AlXB_ref(a11, b11, m_rem, 2, + rs_a, cs_b, is_unitdiag, + conjtransa); + } + else if(1 == n_rem) + { + ///GEMM code begins/// + BLIS_CTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b, + p_lda,k_iter) + + BLIS_PRE_CTRSM_SMALL_2M_1N(AlphaVal,b11,cs_b) + + if(transa) + ctrsm_AutXB_ref(a11, b11, m_rem, 1, + cs_a, cs_b, is_unitdiag + ,conjtransa); + else + ctrsm_AlXB_ref(a11, b11, m_rem, 1, + rs_a, cs_b, is_unitdiag, + conjtransa); + } + } + } + if(1 == m_rem) + { + dim_t p_lda = 4; + if(transa) + { + for(dim_t x =0;x < i;x+=p_lda) + { + ymm0 = _mm256_loadu_ps( + (float const *)(a10)); + + ymm1 = _mm256_broadcast_ss((float const *)&ones); + + ymm4 = _mm256_shuffle_ps(ymm0, ymm1, 0x44); + ymm5 = _mm256_shuffle_ps(ymm1, ymm1, 0x44); + ymm6 = _mm256_permute2f128_ps(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_ps(ymm4,ymm5,0x31); + ymm0 = _mm256_shuffle_ps(ymm0, ymm1, 0xEE); + ymm1 = _mm256_shuffle_ps(ymm1, ymm1, 0xEE); + ymm7 = _mm256_permute2f128_ps(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_ps(ymm0,ymm1,0x31); + + + _mm256_storeu_ps((float *) + (ptr_a10_dup), ymm6); + _mm256_storeu_ps((float *) + (ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_ps((float *) + (ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_ps((float *) + (ptr_a10_dup + p_lda*3), ymm9); + + a10 += p_lda; + ptr_a10_dup += p_lda*p_lda; + } + } + else + { + for(dim_t x =0;x < i;x++) + { + ymm0 = _mm256_loadu_ps( + (float const *)(a10 + rs_a * x)); + _mm256_storeu_ps( + (float *)(ptr_a10_dup + + p_lda * x), ymm0); + } + } + + for(j = 0; (j+d_nr-1) < n; j += d_nr) + { + a10 = D_A_pack; + a11 = L + (i*rs_a) + (i*cs_a); + b01 = B + (j*cs_b); + b11 = B + i + (j* cs_b); + + k_iter = i; + BLIS_SET_S_YMM_REG_ZEROS + BLIS_CTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) + + ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); + ymm16 = _mm256_permute_ps(ymm16, 0x44); + + ymm0 = _mm256_loadu_ps((float const *)(b11)); + ymm1 = _mm256_loadu_ps((float const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_ps((float const *)(b11 + cs_b *2)); + + ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); + + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); + ymm19 = _mm256_shuffle_ps(ymm0, 
ymm0,0xF5); + ymm19 = _mm256_mul_ps(ymm19, ymm17); + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); + ymm8 = _mm256_sub_ps(ymm19, ymm8); + + ymm18 = _mm256_shuffle_ps(ymm1, ymm1, 0xA0); + ymm19 = _mm256_shuffle_ps(ymm1, ymm1,0xF5); + ymm19 = _mm256_mul_ps(ymm19, ymm17); + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); + ymm9 = _mm256_sub_ps(ymm19, ymm9); + + ymm18 = _mm256_shuffle_ps(ymm2, ymm2, 0xA0); + ymm19 = _mm256_shuffle_ps(ymm2, ymm2,0xF5); + ymm19 = _mm256_mul_ps(ymm19, ymm17); + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); + ymm10 = _mm256_sub_ps(ymm19, ymm10); + + xmm0 = _mm256_extractf128_ps(ymm8, 0); + xmm1 = _mm256_extractf128_ps(ymm9, 0); + xmm2 = _mm256_extractf128_ps(ymm10, 0); + + _mm_storel_pi((__m64 *)(b11), xmm0); + _mm_storel_pi((__m64 *)(b11 + cs_b * 1), xmm1); + _mm_storel_pi((__m64 *)(b11 + cs_b * 2), xmm2); + + + + if(transa) + ctrsm_AutXB_ref(a11, b11, m_rem, 3, + cs_a, cs_b,is_unitdiag, conjtransa); + else + { + ctrsm_AlXB_ref(a11, b11, m_rem, 3, + rs_a, cs_b, + is_unitdiag, conjtransa); + } + } + + dim_t n_rem = n-j; + if(n_rem) + { + a10 = D_A_pack; + a11 = L + (i*rs_a) + (i*cs_a); + b01 = B + (j*cs_b); + b11 = B + i + (j* cs_b); + + k_iter = i; + BLIS_SET_S_YMM_REG_ZEROS + if(2 == n_rem) + { + ///GEMM code begins/// + BLIS_CTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b, + p_lda,k_iter) + + BLIS_PRE_CTRSM_SMALL_1M_2N(AlphaVal,b11, + cs_b) + + if(transa) + ctrsm_AutXB_ref(a11, b11, m_rem, + 2, cs_a, cs_b, + is_unitdiag, + conjtransa); + else + ctrsm_AlXB_ref(a11, b11, m_rem, 2, + rs_a, cs_b, + is_unitdiag, + conjtransa); + } + else if(1 == n_rem) + { + ///GEMM code begins/// + BLIS_CTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b, + p_lda,k_iter) + + BLIS_PRE_CTRSM_SMALL_1M_1N(AlphaVal,b11, + cs_b) + + if(transa) + ctrsm_AutXB_ref(a11, b11, m_rem, + 1, cs_a, cs_b, + is_unitdiag, + conjtransa); + else + ctrsm_AlXB_ref(a11, b11, m_rem, 1, + rs_a, cs_b, is_unitdiag, + conjtransa); + } + } + } + } + if ((required_packing_A == 1) && + bli_mem_is_alloc( &local_mem_buf_A_s )) + { + bli_membrk_release(&rntm, &local_mem_buf_A_s); + } + + return BLIS_SUCCESS; } BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB @@ -36529,9 +40837,2235 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB cntl_t* cntl ) { - return BLIS_SUCCESS; + dim_t m = bli_obj_length(b); // number of rows of matrix B + dim_t n = bli_obj_width(b); // number of columns of matrix B + + bool transa = bli_obj_has_trans(a); + bool conjtransa = bli_obj_has_conj(a); + + dim_t cs_a, rs_a; + dim_t d_mr = 8,d_nr = 3; + + // Swap rs_a & cs_a in case of non-tranpose. 
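+ // With BLIS strides, element a(i,j) is addressed as a[i*rs_a + j*cs_a];
+ // swapping the two strides makes the same indexing walk a(j,i), so the
+ // transpose and non-transpose cases can share one traversal of A below,
+ // and only the packing and the fringe reference kernels branch on transa.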
+ if(transa) + { + cs_a = bli_obj_col_stride(a); // column stride of A + rs_a = bli_obj_row_stride(a); // row stride of A + } + else + { + cs_a = bli_obj_row_stride(a); // row stride of A + rs_a = bli_obj_col_stride(a); // column stride of A + } + dim_t cs_b = bli_obj_col_stride(b); // column stride of B + + dim_t i, j, k; //loop variables + dim_t k_iter; //number of times GEMM to be performed + + scomplex AlphaVal = *(scomplex *)AlphaObj->buffer; //value of alpha + scomplex *L = a->buffer; //pointer to matrix A + scomplex *B = b->buffer; //pointer to matrix B + + scomplex *a10, *a11, *b01, *b11; //pointers that point to blocks for GEMM and TRSM + + scomplex ones = {1.0, 1.0}; + bool is_unitdiag = bli_obj_has_unit_diag(a); + + //scratch registers + __m256 ymm0, ymm1, ymm2, ymm3; + __m256 ymm4, ymm5, ymm6, ymm7; + __m256 ymm8, ymm9, ymm10, ymm11; + __m256 ymm12, ymm13, ymm14, ymm15; + __m256 ymm16, ymm17, ymm18, ymm19; + + __m128 xmm0, xmm1, xmm2, xmm3, xmm4; + + gint_t required_packing_A = 1; + mem_t local_mem_buf_A_s = {0}; + scomplex *D_A_pack = NULL; + scomplex d11_pack[d_mr] __attribute__((aligned(64))); + rntm_t rntm; + + bli_rntm_init_from_global( &rntm ); + bli_rntm_set_num_threads_only( 1, &rntm ); + bli_membrk_rntm_set_membrk( &rntm ); + + siz_t buffer_size = bli_pool_block_size( + bli_membrk_pool( + bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), + bli_rntm_membrk(&rntm))); + + if ( (d_mr * m * sizeof(scomplex)) > buffer_size) + return BLIS_NOT_YET_IMPLEMENTED; + + if (required_packing_A == 1) + { + // Get the buffer from the pool. + bli_membrk_acquire_m(&rntm, + buffer_size, + BLIS_BITVAL_BUFFER_FOR_A_BLOCK, + &local_mem_buf_A_s); + if(FALSE==bli_mem_is_alloc(&local_mem_buf_A_s)) return BLIS_NULL_POINTER; + D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); + if(NULL==D_A_pack) return BLIS_NULL_POINTER; + } + + /* + Performs solving TRSM for 4 colmns at a time from 0 to m/4 in steps of d_mr + a. Load, transpose, Pack A (a10 block), the size of packing 4x3 to 4x (m-4) + First there will be no GEMM and no packing of a10 because it is only TRSM + b. Using packed a10 block and b01 block perform GEMM operation + c. Use GEMM outputs, perform TRSM operaton using a11, b11 and update B + d. Repeat b,c for n rows of B in steps of d_nr + */ + for(i = (m - d_mr); (i + 1) > 0; i -= d_mr) //loop along 'M' dimension + { + + a10 = L + (i*cs_a) + (i + d_mr)*rs_a; //pointer to block of A to be used for GEMM + a11 = L + (i*cs_a) + (i*rs_a); //pointer to block of A to be used for TRSM + + dim_t p_lda = d_mr; // packed leading dimension + + if(transa) + { + /* + Load, tranpose and pack current A block (a10) into packed buffer memory + D_A_pack + a. This a10 block is used in GEMM portion only and this + a10 block size will be increasing by d_mr for every next itteration + untill it reaches 4x(m-4) which is the maximum GEMM alone block size + in A + b. This packed buffer is reused to calculate all n rows of B matrix + */ + bli_ctrsm_small_pack('L', (m-i-d_mr), 1, a10, cs_a, D_A_pack, p_lda,d_mr); + + /* + Pack 4 diagonal elements of A block into an array + a. This helps in utilze cache line efficiently in TRSM operation + b. store ones when input is unit diagonal + */ + ctrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,d_mr); + } + else + { + bli_ctrsm_small_pack('L', (m-i-d_mr), 0, a10, rs_a, D_A_pack, p_lda,d_mr); + ctrsm_small_pack_diag_element(is_unitdiag,a11,rs_a,d11_pack,d_mr); + } + /* + a. Perform GEMM using a10, b01. + b. Perform TRSM on a11, b11 + c. 
This loop GEMM+TRSM loops operates with 4x3 block size + along n dimension for every d_nr rows of b01 where + packed A buffer is reused in computing all n rows of B. + d. Same approch is used in remaining fringe cases. + */ + + for(j = (n - d_nr); (j + 1) > 0; j -= d_nr) //loop along 'N' dimension + { + a10 = D_A_pack; + b01 = B + (j * cs_b) + i + d_mr; //pointer to block of B to be used for GEMM + b11 = B + (j * cs_b) + i; //pointer to block of B to be used for TRSM + + k_iter = (m - i - d_mr); + + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + /* + Peform GEMM between a10 and b01 blocks + For first itteration there will be no GEMM operation + where k_iter are zero + */ + BLIS_CTRSM_SMALL_GEMM_8mx3n(a10,b01,cs_b,p_lda,k_iter) + + /* + Load b11 of size 3x4 and multiply with alpha + Add the GEMM output and perform inregister transose of b11 + to peform TRSM operation. + */ + BLIS_CTRSM_SMALL_NREG_TRANSPOSE_3x8(b11,cs_b,AlphaVal) + /* + Compute 4x3 TRSM block by using GEMM block output in register + a. The 4x3 input (gemm outputs) are stored in combinations of ymm + registers + 1. ymm8, ymm4 2. ymm9, ymm5 3. ymm10, ymm6, 4. ymm11, ymm7 + where ymm8-ymm11 holds 4x2 data and reaming 4x1 will be hold by + other registers + b. Towards the end do in regiser transpose of TRSM output and store in + b11 + */ + ////extract a00 + ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 7)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm15) +#else + BLIS_CTRSM_MUL(ymm15) +#endif + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*6 + 7*rs_a)); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + //extract a11 + //(ROw1): FMA operations + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm15, ymm15, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm15, ymm15,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm14 = _mm256_sub_ps(ymm14,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*5 + 7*rs_a)); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm15, ymm15, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm15, ymm15,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm13 = _mm256_sub_ps(ymm13,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*4 + 7*rs_a)); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm15, ymm15, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm15, ymm15,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm12 = _mm256_sub_ps(ymm12,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*3 + 7*rs_a)); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm15, ymm15, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm15, ymm15,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm11 = _mm256_sub_ps(ymm11,ymm16); + + ymm2 = 
_mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2 + 7*rs_a)); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm15, ymm15, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm15, ymm15,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm10 = _mm256_sub_ps(ymm10,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*1 + 7*rs_a)); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm15, ymm15, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm15, ymm15,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm9 = _mm256_sub_ps(ymm9,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*0 + 7*rs_a)); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm15, ymm15, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm15, ymm15,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm8 = _mm256_sub_ps(ymm8,ymm16); + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 6)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm14) +#else + BLIS_CTRSM_MUL(ymm14) +#endif + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*5 + 6*rs_a)); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm9 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm14, ymm14, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm14, ymm14,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm13 = _mm256_sub_ps(ymm13,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*4 + 6*rs_a)); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm14, ymm14, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm14, ymm14,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm12 = _mm256_sub_ps(ymm12,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*3 + 6*rs_a)); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm14, ymm14, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm14, ymm14,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm11 = _mm256_sub_ps(ymm11,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2 + 6*rs_a)); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm14, ymm14, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm14, ymm14,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm10 = _mm256_sub_ps(ymm10,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*1 + 6*rs_a)); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = 
_mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm14, ymm14, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm14, ymm14,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm9 = _mm256_sub_ps(ymm9,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*0 + 6*rs_a)); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm14, ymm14, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm14, ymm14,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm8 = _mm256_sub_ps(ymm8,ymm16); + + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 5)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); + +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm13) +#else + BLIS_CTRSM_MUL(ymm13) +#endif + + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*4 + 5*rs_a)); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm10 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm13, ymm13, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm13, ymm13,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm12 = _mm256_sub_ps(ymm12,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*3 + 5*rs_a)); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm13, ymm13, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm13, ymm13,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm11 = _mm256_sub_ps(ymm11,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2 + 5*rs_a)); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm13, ymm13, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm13, ymm13,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm10 = _mm256_sub_ps(ymm10,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*1 + 5*rs_a)); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm13, ymm13, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm13, ymm13,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm9 = _mm256_sub_ps(ymm9,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*0 + 5*rs_a)); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm13, ymm13, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm13, ymm13,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm8 = _mm256_sub_ps(ymm8,ymm16); + + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 4)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm12) +#else + BLIS_CTRSM_MUL(ymm12) +#endif + + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*3 + 4*rs_a)); + 
ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm12, ymm12, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm12, ymm12,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm11 = _mm256_sub_ps(ymm11,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2 + 4*rs_a)); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm12, ymm12, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm12, ymm12,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm10 = _mm256_sub_ps(ymm10,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*1 + 4*rs_a)); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm12, ymm12, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm12, ymm12,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm9 = _mm256_sub_ps(ymm9,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*0 + 4*rs_a)); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm12, ymm12, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm12, ymm12,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm8 = _mm256_sub_ps(ymm8,ymm16); + + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 3)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm11) +#else + BLIS_CTRSM_MUL(ymm11) +#endif + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2 + 3*rs_a)); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm11, ymm11, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm11, ymm11,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm10 = _mm256_sub_ps(ymm10,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*1 + 3*rs_a)); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm11, ymm11, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm11, ymm11,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm9 = _mm256_sub_ps(ymm9,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*0 + 3*rs_a)); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm11, ymm11, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm11, ymm11,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm8 = _mm256_sub_ps(ymm8,ymm16); + + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm10) +#else + BLIS_CTRSM_MUL(ymm10) +#endif + + + ymm2 = 
_mm256_broadcast_ps((__m128 const *) (a11 + cs_a + 2*rs_a)); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm10, ymm10, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm10, ymm10,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm9 = _mm256_sub_ps(ymm9,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*0 + 2*rs_a)); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm10, ymm10, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm10, ymm10,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm8 = _mm256_sub_ps(ymm8,ymm16); + + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm9) +#else + BLIS_CTRSM_MUL(ymm9) +#endif + + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm9, ymm9, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm9, ymm9,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm8 = _mm256_sub_ps(ymm8,ymm16); + + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm8) +#else + BLIS_CTRSM_MUL(ymm8) +#endif + + BLIS_CTRSM_SMALL_NREG_TRANSPOSE_8x3_AND_STORE(b11,cs_b) + + } + dim_t n_rem = j + d_nr; + if(n_rem) + { + + a10 = D_A_pack; + a11 = L + (i*cs_a) + (i*rs_a); + b01 = B + i + d_mr; + b11 = B + i; + + k_iter = (m - i - d_mr) ; + + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + if(2 == n_rem) + { + BLIS_CTRSM_SMALL_GEMM_8mx2n(a10,b01,cs_b,p_lda,k_iter) + + float zero = 0.0; + ymm16 = _mm256_broadcast_ss(&AlphaVal.real); + ymm17 = _mm256_broadcast_ss(&AlphaVal.imag); + ymm2 = _mm256_broadcast_ss(&zero); + ymm3 = _mm256_broadcast_ss(&zero); + ymm6 = _mm256_broadcast_ss(&zero); + ymm7 = _mm256_broadcast_ss(&zero); + + ymm0 = _mm256_loadu_ps((float const *)(b11)); + ymm4 = _mm256_loadu_ps((float const *)(b11 + cs_b *1)); + ymm1 = _mm256_loadu_ps((float const *)(b11 + 4)); + ymm5 = _mm256_loadu_ps((float const *)(b11 + cs_b *1 + 4)); + + ymm2 = _mm256_fmadd_ps(ymm0, ymm16, ymm2); + ymm3 = _mm256_fmadd_ps(ymm1, ymm16, ymm3); + ymm6 = _mm256_fmadd_ps(ymm0, ymm17, ymm6); + ymm7 = _mm256_fmadd_ps(ymm1, ymm17, ymm7); + + ymm6 = _mm256_permute_ps(ymm6, 0xb1); + ymm7 = _mm256_permute_ps(ymm7, 0xb1); + + ymm0 = _mm256_addsub_ps(ymm2, ymm6); + ymm1 = _mm256_addsub_ps(ymm3, ymm7); + ymm0 = _mm256_sub_ps(ymm0, ymm8); + ymm1 = _mm256_sub_ps(ymm1, ymm12); + + ymm2 = _mm256_broadcast_ss(&zero); + ymm3 = _mm256_broadcast_ss(&zero); + + ymm6 = _mm256_broadcast_ss(&zero); + ymm7 = _mm256_broadcast_ss(&zero); + + ymm2 = _mm256_fmadd_ps(ymm4, ymm16, ymm2); + ymm3 = _mm256_fmadd_ps(ymm5, ymm16, ymm3); + ymm6 = _mm256_fmadd_ps(ymm4, ymm17, ymm6); + ymm7 = _mm256_fmadd_ps(ymm5, ymm17, ymm7); + + ymm6 = _mm256_permute_ps(ymm6, 0xb1); + ymm7 = _mm256_permute_ps(ymm7, 0xb1); + + ymm4 = _mm256_addsub_ps(ymm2, ymm6); + ymm5 = _mm256_addsub_ps(ymm3, ymm7); 
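+ /* ymm4/ymm5 now hold alpha*B for column 1 of b11 (rows 0..3 and 4..7).
+  With alpha = (ar,ai) and b = (br,bi), the fmadd/permute(0xB1)/addsub
+  sequence above yields (ar*br - ai*bi, ar*bi + ai*br) per element.
+  The two subtractions below remove the a10*b01 accumulation from the
+  GEMM above, as was already done for column 0 in ymm0/ymm1, and ymm2 is
+  then filled with ones to pad the unused lanes in the in-register
+  transpose that follows. */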
+ ymm4 = _mm256_sub_ps(ymm4, ymm9); + ymm5 = _mm256_sub_ps(ymm5, ymm13); + ymm2 = _mm256_broadcast_ss((float const *)&ones); + ymm3 = _mm256_broadcast_ss((float const *)&ones); + + } + if(1 == n_rem) + { + BLIS_CTRSM_SMALL_GEMM_8mx1n(a10,b01,cs_b,p_lda,k_iter) + + float zero = 0.0; + ymm16 = _mm256_broadcast_ss(&AlphaVal.real); + ymm17 = _mm256_broadcast_ss(&AlphaVal.imag); + ymm2 = _mm256_broadcast_ss(&zero); + ymm3 = _mm256_broadcast_ss(&zero); + ymm6 = _mm256_broadcast_ss(&zero); + ymm7 = _mm256_broadcast_ss(&zero); + + ymm0 = _mm256_loadu_ps((float const *)(b11)); + ymm1 = _mm256_loadu_ps((float const *)(b11 + 4)); + + ymm2 = _mm256_fmadd_ps(ymm0, ymm16, ymm2); + ymm3 = _mm256_fmadd_ps(ymm1, ymm16, ymm3); + ymm6 = _mm256_fmadd_ps(ymm0, ymm17, ymm6); + ymm7 = _mm256_fmadd_ps(ymm1, ymm17, ymm7); + + ymm6 = _mm256_permute_ps(ymm6, 0xb1); + ymm7 = _mm256_permute_ps(ymm7, 0xb1); + + ymm0 = _mm256_addsub_ps(ymm2, ymm6); + ymm1 = _mm256_addsub_ps(ymm3, ymm7); + + ymm0 = _mm256_sub_ps(ymm0, ymm8); + ymm1 = _mm256_sub_ps(ymm1, ymm12); + + ymm2 = _mm256_broadcast_ss((float const *)&ones); + ymm4 = _mm256_broadcast_ss((float const *)&ones); + ymm5 = _mm256_broadcast_ss((float const *)&ones); + + } + ymm18 = _mm256_shuffle_ps(ymm0, ymm4, 0x44); + ymm19 = _mm256_shuffle_ps(ymm2, ymm2, 0x44); + /*BEFORE*/ + /*a[R0][I0] a[R1][I1] a[R2][I2] a[R3][I3] */ + /*b[R0][I0] b[R1][I1] b[R2][I2] b[R3][I3] */ + /*AFTER*/ + /*a[R0][I0] a[R1][I1] b[R0][I0] b[R1][I1]*/ + ymm8 = _mm256_permute2f128_ps(ymm18,ymm19,0x20); + ymm10 = _mm256_permute2f128_ps(ymm18,ymm19,0x31); + /*BEFORE*/ + /*a[R0][I0] a[R1][I1] a[R2][I2] a[R3][I3] */ + /*b[R0][I0] b[R1][I1] b[R2][I2] b[R3][I3] */ + /*AFTER*/ + /*a[R1][I1] b[R1][I1] */ + ymm18 = _mm256_shuffle_ps(ymm0, ymm4, 0xEE); + ymm19 = _mm256_shuffle_ps(ymm2, ymm2, 0xEE); + ymm9 = _mm256_permute2f128_ps(ymm18,ymm19,0x20); + ymm11 = _mm256_permute2f128_ps(ymm18,ymm19,0x31); + + ymm18 = _mm256_shuffle_ps(ymm1, ymm5, 0x44); + ymm19 = _mm256_shuffle_ps(ymm2, ymm2, 0x44); + /*BEFORE*/ + /*a[R0][I0] a[R1][I1] a[R2][I2] a[R3][I3] */ + /*b[R0][I0] b[R1][I1] b[R2][I2] b[R3][I3] */ + /*AFTER*/ + /*a[R0][I0] a[R1][I1] b[R0][I0] b[R1][I1]*/ + ymm12 = _mm256_permute2f128_ps(ymm18,ymm19,0x20); + ymm14 = _mm256_permute2f128_ps(ymm18,ymm19,0x31); + + /*BEFORE*/ + /*a[R0][I0] a[R1][I1] a[R2][I2] a[R3][I3] */ + /*b[R0][I0] b[R1][I1] b[R2][I2] b[R3][I3] */ + /*AFTER*/ + /*a[R1][I1] b[R1][I1] */ + ymm18 = _mm256_shuffle_ps(ymm1, ymm5, 0xEE); + ymm19 = _mm256_shuffle_ps(ymm2, ymm2, 0xEE); + ymm13 = _mm256_permute2f128_ps(ymm18,ymm19,0x20); + ymm15 = _mm256_permute2f128_ps(ymm18,ymm19,0x31); + + + ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 7)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm15) +#else + BLIS_CTRSM_MUL(ymm15) +#endif + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*6 + 7*rs_a)); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //extract a11 + //(ROw1): FMA operations + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm15, ymm15, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm15, ymm15,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm14 = _mm256_sub_ps(ymm14,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*5 + 7*rs_a)); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = 
_mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm15, ymm15, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm15, ymm15,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm13 = _mm256_sub_ps(ymm13,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*4 + 7*rs_a)); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm15, ymm15, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm15, ymm15,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm12 = _mm256_sub_ps(ymm12,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*3 + 7*rs_a)); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm15, ymm15, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm15, ymm15,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm11 = _mm256_sub_ps(ymm11,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2 + 7*rs_a)); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm15, ymm15, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm15, ymm15,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm10 = _mm256_sub_ps(ymm10,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*1 + 7*rs_a)); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm15, ymm15, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm15, ymm15,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm9 = _mm256_sub_ps(ymm9,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*0 + 7*rs_a)); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm15, ymm15, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm15, ymm15,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm8 = _mm256_sub_ps(ymm8,ymm16); + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 6)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm14) +#else + BLIS_CTRSM_MUL(ymm14) +#endif + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*5 + 6*rs_a)); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm9 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm14, ymm14, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm14, ymm14,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm13 = _mm256_sub_ps(ymm13,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*4 + 6*rs_a)); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm14, ymm14, 0xA0); + 
ymm16 = _mm256_shuffle_ps(ymm14, ymm14,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm12 = _mm256_sub_ps(ymm12,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*3 + 6*rs_a)); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm14, ymm14, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm14, ymm14,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm11 = _mm256_sub_ps(ymm11,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2 + 6*rs_a)); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm14, ymm14, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm14, ymm14,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm10 = _mm256_sub_ps(ymm10,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*1 + 6*rs_a)); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm14, ymm14, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm14, ymm14,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm9 = _mm256_sub_ps(ymm9,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*0 + 6*rs_a)); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm14, ymm14, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm14, ymm14,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm8 = _mm256_sub_ps(ymm8,ymm16); + + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 5)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); + +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm13) +#else + BLIS_CTRSM_MUL(ymm13) +#endif + + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*4 + 5*rs_a)); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm10 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm13, ymm13, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm13, ymm13,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm12 = _mm256_sub_ps(ymm12,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*3 + 5*rs_a)); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm13, ymm13, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm13, ymm13,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm11 = _mm256_sub_ps(ymm11,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2 + 5*rs_a)); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm13, ymm13, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm13, ymm13,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + 
ymm10 = _mm256_sub_ps(ymm10,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*1 + 5*rs_a)); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm13, ymm13, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm13, ymm13,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm9 = _mm256_sub_ps(ymm9,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*0 + 5*rs_a)); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm13, ymm13, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm13, ymm13,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm8 = _mm256_sub_ps(ymm8,ymm16); + + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 4)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm12) +#else + BLIS_CTRSM_MUL(ymm12) +#endif + + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*3 + 4*rs_a)); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm12, ymm12, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm12, ymm12,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm11 = _mm256_sub_ps(ymm11,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2 + 4*rs_a)); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm12, ymm12, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm12, ymm12,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm10 = _mm256_sub_ps(ymm10,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*1 + 4*rs_a)); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm12, ymm12, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm12, ymm12,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm9 = _mm256_sub_ps(ymm9,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*0 + 4*rs_a)); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm12, ymm12, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm12, ymm12,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm8 = _mm256_sub_ps(ymm8,ymm16); + + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 3)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm11) +#else + BLIS_CTRSM_MUL(ymm11) +#endif + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2 + 3*rs_a)); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm11, ymm11, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm11, ymm11,0xF5); + ymm16 = 
_mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm10 = _mm256_sub_ps(ymm10,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*1 + 3*rs_a)); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm11, ymm11, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm11, ymm11,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm9 = _mm256_sub_ps(ymm9,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*0 + 3*rs_a)); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm11, ymm11, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm11, ymm11,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm8 = _mm256_sub_ps(ymm8,ymm16); + + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm10) +#else + BLIS_CTRSM_MUL(ymm10) +#endif + + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a + 2*rs_a)); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm10, ymm10, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm10, ymm10,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm9 = _mm256_sub_ps(ymm9,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*0 + 2*rs_a)); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm10, ymm10, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm10, ymm10,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm8 = _mm256_sub_ps(ymm8,ymm16); + + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm9) +#else + BLIS_CTRSM_MUL(ymm9) +#endif + + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm9, ymm9, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm9, ymm9,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm8 = _mm256_sub_ps(ymm8,ymm16); + + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm8) +#else + BLIS_CTRSM_MUL(ymm8) +#endif + + ymm1 = _mm256_shuffle_ps(ymm8, ymm9, 0x44); + ymm3 = _mm256_shuffle_ps(ymm10, ymm11, 0x44); + + /*rearrange low elements*/ + ymm0 = _mm256_permute2f128_ps(ymm1, ymm3, 0x20); + /*unpack high*/ + ymm8 = _mm256_shuffle_ps(ymm8, ymm9, 0xEE); + ymm9 = _mm256_shuffle_ps(ymm10, ymm11, 0xEE); + + /*rearrange high elements*/ + ymm1 = _mm256_permute2f128_ps(ymm8, ymm9, 0x20); + + ymm3 = _mm256_shuffle_ps(ymm12, ymm13, 0x44); + ymm4 = _mm256_shuffle_ps(ymm14, ymm15, 0x44); + + /*rearrange low elements*/ + ymm4 = _mm256_permute2f128_ps(ymm3, ymm4, 0x20); 
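+ /* Each of ymm8..ymm15 holds one solved row of b11 (real/imag interleaved,
+  padded with ones). The 0x44/0xEE shuffles pick the low/high scomplex
+  within each 128-bit lane of a pair of row registers, and the 0x20 lane
+  merges gather them, so ymm0/ymm1 end up with rows 0..3 of columns 0 and 1
+  while ymm4 and ymm5 (the latter built next) carry rows 4..7 of the same
+  columns, ready for the column-wise stores below. */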
+ /*unpack high*/ + ymm8 = _mm256_shuffle_ps(ymm12, ymm13, 0xEE); + ymm9 = _mm256_shuffle_ps(ymm14, ymm15, 0xEE); + + /*rearrange high elements*/ + ymm5 = _mm256_permute2f128_ps(ymm8, ymm9, 0x20); + + if(2 == n_rem) + { + _mm256_storeu_ps((float *)(b11 + cs_b * 0), ymm0); + _mm256_storeu_ps((float *)(b11 + cs_b * 1), ymm1); + _mm256_storeu_ps((float *)(b11 + cs_b * 0 + 4), ymm4); + _mm256_storeu_ps((float *)(b11 + cs_b * 1 + 4), ymm5); + } + else if(1 == n_rem) + { + _mm256_storeu_ps((float *)(b11 + cs_b * 0), ymm0); + _mm256_storeu_ps((float *)(b11 + cs_b * 0 + 4), ymm4); + } + } + } + dim_t m_rem = i + d_mr; + if(m_rem >= 4) + { + i = m_rem - 4; + a10 = L + (i*cs_a) + (i + 4)*rs_a; + a11 = L + (i*cs_a) + (i*rs_a); + + scomplex *ptr_a10_dup = D_A_pack; + dim_t p_lda = 4; + if(transa) + { + for(dim_t x =0;x < m-i+4;x+=p_lda) + { + ymm0 = _mm256_loadu_ps((float const *)(a10)); + ymm1 = _mm256_loadu_ps((float const *)(a10 + cs_a)); + ymm2 = _mm256_loadu_ps((float const *)(a10 + cs_a * 2)); + ymm3 = _mm256_loadu_ps((float const *)(a10 + cs_a * 3)); + + ymm4 = _mm256_shuffle_ps(ymm0, ymm1, 0x44); + ymm5 = _mm256_shuffle_ps(ymm2, ymm3, 0x44); + ymm6 = _mm256_permute2f128_ps(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_ps(ymm4,ymm5,0x31); + ymm0 = _mm256_shuffle_ps(ymm0, ymm1, 0xEE); + ymm1 = _mm256_shuffle_ps(ymm2, ymm3, 0xEE); + ymm7 = _mm256_permute2f128_ps(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_ps(ymm0,ymm1,0x31); + + + _mm256_storeu_ps((float *)(ptr_a10_dup), ymm6); + _mm256_storeu_ps((float *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_ps((float *)(ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_ps((float *)(ptr_a10_dup + p_lda*3), ymm9); + + a10 += p_lda; + ptr_a10_dup += p_lda*p_lda; + } + } + else + { + for(dim_t x =0;x < m-i-4;x++) + { + ymm0 = _mm256_loadu_ps((float const *)(a10 + rs_a * x)); + _mm256_storeu_ps((float *)(ptr_a10_dup + p_lda * x), ymm0); + } + } + + if(!is_unitdiag) + { + if(transa) + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_ps((__m128 const *)(a11)); + ymm1 = _mm256_broadcast_ps((__m128 const *)(a11+cs_a*1 + 1)); + ymm2 = _mm256_broadcast_ps((__m128 const *)(a11+cs_a*2 + 2)); + ymm3 = _mm256_broadcast_ps((__m128 const *)(a11+cs_a*3 + 3)); + ymm0 = _mm256_permute_ps(ymm0, 0x44); + ymm1 = _mm256_permute_ps(ymm1, 0x44); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + ymm3 = _mm256_permute_ps(ymm3, 0x44); + } + else + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_ps((__m128 const *)(a11)); + ymm1 = _mm256_broadcast_ps((__m128 const *)(a11+rs_a*1 + 1)); + ymm2 = _mm256_broadcast_ps((__m128 const *)(a11+rs_a*2 + 2)); + ymm3 = _mm256_broadcast_ps((__m128 const *)(a11+rs_a*3 + 3)); + ymm0 = _mm256_permute_ps(ymm0, 0x44); + ymm1 = _mm256_permute_ps(ymm1, 0x44); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + ymm3 = _mm256_permute_ps(ymm3, 0x44); + } + + ymm1 = _mm256_shuffle_ps(ymm0, ymm1, 0x44); + ymm2 = _mm256_shuffle_ps(ymm2, ymm3, 0x44); + ymm1 = _mm256_blend_ps(ymm1, ymm2, 0xF0); + +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + ymm7 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); + ymm4 = _mm256_mul_ps(ymm1, ymm1); + ymm6 = _mm256_permute_ps(ymm4, 0xB1); + ymm4 = _mm256_add_ps(ymm4, ymm6); + ymm1 = _mm256_mul_ps(ymm1, ymm7); + ymm1 = _mm256_div_ps(ymm1, ymm4); +#endif + } + _mm256_storeu_ps((float *)(d11_pack), ymm1); + + for(j = (n - d_nr); (j + 1) > 0; j -= d_nr) + { + a10 = D_A_pack; + a11 = L + (i*cs_a) + (i*rs_a); + b01 = B + (j*cs_b) + i + 4; + b11 = B + (j* cs_b) + i; + k_iter = (m - i - 4); + + /*Fill zeros into 
ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + BLIS_CTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) + BLIS_CTRSM_SMALL_NREG_TRANSPOSE_3x4(b11,cs_b,AlphaVal) + + ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 3)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm11) +#else + BLIS_CTRSM_MUL(ymm11) +#endif + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2 + 3*rs_a)); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm11, ymm11, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm11, ymm11,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm10 = _mm256_sub_ps(ymm10,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*1 + 3*rs_a)); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm11, ymm11, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm11, ymm11,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm9 = _mm256_sub_ps(ymm9,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*0 + 3*rs_a)); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm11, ymm11, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm11, ymm11,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm8 = _mm256_sub_ps(ymm8,ymm16); + + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm10) +#else + BLIS_CTRSM_MUL(ymm10) +#endif + + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a + 2*rs_a)); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm10, ymm10, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm10, ymm10,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm9 = _mm256_sub_ps(ymm9,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*0 + 2*rs_a)); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm10, ymm10, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm10, ymm10,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm8 = _mm256_sub_ps(ymm8,ymm16); + + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm9) +#else + BLIS_CTRSM_MUL(ymm9) +#endif + + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm9, ymm9, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm9, ymm9,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm8 = _mm256_sub_ps(ymm8,ymm16); + + + ymm1 = 
_mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm8) +#else + BLIS_CTRSM_MUL(ymm8) +#endif + + BLIS_CTRSM_SMALL_NREG_TRANSPOSE_4x3_AND_STORE(b11,cs_b) + } + dim_t n_rem = j + d_nr;; + if(n_rem) + { + a10 = D_A_pack; + a11 = L + (i*cs_a) + (i*rs_a); + b01 = B + i + 4; + b11 = B + i; + + k_iter = (m - i - 4); + + BLIS_SET_S_YMM_REG_ZEROS + + if(2 == n_rem) + { + ///GEMM code begins/// + BLIS_CTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b,p_lda,k_iter) + BLIS_CTRSM_SMALL_NREG_TRANSPOSE_2x4(b11,cs_b,AlphaVal) + } + else if(1 == n_rem) + { + ///GEMM code begins/// + BLIS_CTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b,p_lda,k_iter) + BLIS_CTRSM_SMALL_NREG_TRANSPOSE_1x4(b11,cs_b,AlphaVal) + } + + ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 3)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm11) +#else + BLIS_CTRSM_MUL(ymm11) +#endif + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2 + 3*rs_a)); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm11, ymm11, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm11, ymm11,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm10 = _mm256_sub_ps(ymm10,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*1 + 3*rs_a)); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm11, ymm11, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm11, ymm11,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm9 = _mm256_sub_ps(ymm9,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*0 + 3*rs_a)); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm11, ymm11, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm11, ymm11,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm8 = _mm256_sub_ps(ymm8,ymm16); + + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm10) +#else + BLIS_CTRSM_MUL(ymm10) +#endif + + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a + 2*rs_a)); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm10, ymm10, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm10, ymm10,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm9 = _mm256_sub_ps(ymm9,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*0 + 2*rs_a)); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm10, ymm10, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm10, ymm10,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm8 = _mm256_sub_ps(ymm8,ymm16); + + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = 
_mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm9) +#else + BLIS_CTRSM_MUL(ymm9) +#endif + + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm9, ymm9, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm9, ymm9,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm8 = _mm256_sub_ps(ymm8,ymm16); + + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm8) +#else + BLIS_CTRSM_MUL(ymm8) +#endif + + ymm1 = _mm256_shuffle_ps(ymm8, ymm9, 0x44); + ymm3 = _mm256_shuffle_ps(ymm10, ymm11, 0x44); + + ymm0 = _mm256_permute2f128_ps(ymm1, ymm3, 0x20); + ymm1 = _mm256_permute2f128_ps(ymm1, ymm3, 0x31); + + ymm2 = _mm256_shuffle_ps(ymm8, ymm9, 0xEE); + ymm3 = _mm256_shuffle_ps(ymm10, ymm11, 0xEE); + + ymm4 = _mm256_permute2f128_ps(ymm2, ymm3, 0x20); + ymm5 = _mm256_permute2f128_ps(ymm2, ymm3, 0x31); + + if(2 == n_rem) + { + _mm256_storeu_ps((float *)(b11 + cs_b * 0), ymm0); + _mm256_storeu_ps((float *)(b11 + cs_b * 1), ymm4); + } + else if(1 == n_rem) + { + _mm256_storeu_ps((float *)(b11 + cs_b * 0), ymm0); + } + + } + m_rem -=4; + } + a10 = L + m_rem*rs_a; + + // Do transpose for a10 & store in D_A_pack + scomplex *ptr_a10_dup = D_A_pack; + if(m_rem) + { + if(3 == m_rem) + { + dim_t p_lda = 4; + if(transa) + { + for(dim_t x =0;x < m-m_rem;x+=p_lda) + { + ymm0 = _mm256_loadu_ps( + (float const *)(a10)); + ymm1 = _mm256_loadu_ps( + (float const *)(a10 + cs_a)); + ymm2 = _mm256_loadu_ps( + (float const *)(a10 + cs_a * 2)); + ymm3 = _mm256_loadu_ps( + (float const *)(a10 + cs_a * 3)); + + ymm4 = _mm256_shuffle_ps(ymm0, ymm1, 0x44); + ymm5 = _mm256_shuffle_ps(ymm2, ymm3, 0x44); + ymm6 = _mm256_permute2f128_ps(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_ps(ymm4,ymm5,0x31); + ymm0 = _mm256_shuffle_ps(ymm0, ymm1, 0xEE); + ymm1 = _mm256_shuffle_ps(ymm2, ymm3, 0xEE); + ymm7 = _mm256_permute2f128_ps(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_ps(ymm0,ymm1,0x31); + + + _mm256_storeu_ps((float *) + (ptr_a10_dup), ymm6); + _mm256_storeu_ps((float *) + (ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_ps((float *) + (ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_ps((float *) + (ptr_a10_dup + p_lda*3), ymm9); + + a10 += p_lda; + ptr_a10_dup += p_lda*p_lda; + } + } + else + { + for(dim_t x =0;x < m-m_rem;x++) + { + ymm0 = _mm256_loadu_ps( + (float const *)(a10 + rs_a * x)); + _mm256_storeu_ps( + (float *)(ptr_a10_dup + + p_lda * x), ymm0); + } + } + + for(j = (n - d_nr); (j + 1) > 0; j -= d_nr) + { + a10 = D_A_pack; + a11 = L; + b01 = B + (j* cs_b) + m_rem; + b11 = B + (j* cs_b); + k_iter = (m - m_rem); + + BLIS_SET_S_YMM_REG_ZEROS + BLIS_CTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) + + ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); + ymm16 = _mm256_permute_ps(ymm16, 0x44); + + ymm0 = _mm256_loadu_ps((float const *)(b11)); + ymm1 = _mm256_loadu_ps((float const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_ps((float const *)(b11 + cs_b *2)); + + ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); + + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5); + ymm19 = _mm256_mul_ps(ymm19, ymm17); + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); + ymm8 = _mm256_sub_ps(ymm19, ymm8); + + ymm18 = _mm256_shuffle_ps(ymm1, 
ymm1, 0xA0); + ymm19 = _mm256_shuffle_ps(ymm1, ymm1,0xF5); + ymm19 = _mm256_mul_ps(ymm19, ymm17); + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); + ymm9 = _mm256_sub_ps(ymm19, ymm9); + + ymm18 = _mm256_shuffle_ps(ymm2, ymm2, 0xA0); + ymm19 = _mm256_shuffle_ps(ymm2, ymm2,0xF5); + ymm19 = _mm256_mul_ps(ymm19, ymm17); + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); + ymm10 = _mm256_sub_ps(ymm19, ymm10); + + xmm0 = _mm256_extractf128_ps(ymm8, 0); + xmm1 = _mm256_extractf128_ps(ymm9, 0); + xmm2 = _mm256_extractf128_ps(ymm10, 0); + + _mm_storeu_ps((float *)(b11), xmm0); + _mm_storeu_ps((float *)(b11 + cs_b * 1), xmm1); + _mm_storeu_ps((float *)(b11 + cs_b * 2), xmm2); + + xmm0 = _mm256_extractf128_ps(ymm8, 1); + xmm1 = _mm256_extractf128_ps(ymm9, 1); + xmm2 = _mm256_extractf128_ps(ymm10, 1); + + _mm_storel_pi((__m64 *)(b11 + 2), xmm0); + _mm_storel_pi((__m64 *)(b11 + cs_b * 1 + 2), xmm1); + _mm_storel_pi((__m64 *)(b11 + cs_b * 2 + 2), xmm2); + + + + if(transa) + ctrsm_AltXB_ref(a11, b11, m_rem, 3, + cs_a, cs_b,is_unitdiag, conjtransa); + else + { + ctrsm_AuXB_ref(a11, b11, m_rem, 3, + rs_a, cs_b, + is_unitdiag, conjtransa); + } + } + dim_t n_rem = j + d_nr; + if(n_rem) + { + a10 = D_A_pack; + a11 = L; + b01 = B + m_rem; + b11 = B; + + k_iter = (m - m_rem); + BLIS_SET_S_YMM_REG_ZEROS + if(2 == n_rem) + { + ///GEMM code begins/// + BLIS_CTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b, + p_lda,k_iter) + + BLIS_PRE_CTRSM_SMALL_3M_2N(AlphaVal,b11,cs_b) + + if(transa) + ctrsm_AltXB_ref(a11, b11, m_rem, 2, + cs_a, cs_b,is_unitdiag, + conjtransa); + else + ctrsm_AuXB_ref(a11, b11, m_rem, 2, + rs_a, cs_b, is_unitdiag, + conjtransa); + } + else if(1 == n_rem) + { + ///GEMM code begins/// + BLIS_CTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b, + p_lda,k_iter) + + BLIS_PRE_CTRSM_SMALL_3M_1N(AlphaVal,b11,cs_b) + + if(transa) + ctrsm_AltXB_ref(a11, b11, m_rem, 1, + cs_a, cs_b, is_unitdiag, + conjtransa); + else + ctrsm_AuXB_ref(a11, b11, m_rem, 1, + rs_a, cs_b, is_unitdiag, + conjtransa); + } + } + } + if(2 == m_rem) + { + dim_t p_lda = 4; + if(transa) + { + for(dim_t x =0;x < m-m_rem;x+=p_lda) + { + ymm0 = _mm256_loadu_ps( + (float const *)(a10)); + ymm1 = _mm256_loadu_ps( + (float const *)(a10 + cs_a)); + + ymm2 = _mm256_broadcast_ss((float const *)&ones); + + ymm4 = _mm256_shuffle_ps(ymm0, ymm1, 0x44); + ymm5 = _mm256_shuffle_ps(ymm2, ymm2, 0x44); + ymm6 = _mm256_permute2f128_ps(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_ps(ymm4,ymm5,0x31); + ymm0 = _mm256_shuffle_ps(ymm0, ymm1, 0xEE); + ymm1 = _mm256_shuffle_ps(ymm2, ymm2, 0xEE); + ymm7 = _mm256_permute2f128_ps(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_ps(ymm0,ymm1,0x31); + + + _mm256_storeu_ps((float *) + (ptr_a10_dup), ymm6); + _mm256_storeu_ps((float *) + (ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_ps((float *) + (ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_ps((float *) + (ptr_a10_dup + p_lda*3), ymm9); + + a10 += p_lda; + ptr_a10_dup += p_lda*p_lda; + } + } + else + { + for(dim_t x =0;x < m-m_rem;x++) + { + ymm0 = _mm256_loadu_ps( + (float const *)(a10 + rs_a * x)); + _mm256_storeu_ps( + (float *)(ptr_a10_dup + + p_lda * x), ymm0); + } + } + + for(j = (n - d_nr); (j + 1) > 0; j -= d_nr) + { + a10 = D_A_pack; + a11 = L; + b01 = B + (j* cs_b) + m_rem; + b11 = B + (j* cs_b); + + k_iter = (m - m_rem); + + BLIS_SET_S_YMM_REG_ZEROS + BLIS_CTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) + + ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); + ymm16 = _mm256_permute_ps(ymm16, 0x44); + + ymm0 = _mm256_loadu_ps((float const *)(b11)); + ymm1 = 
_mm256_loadu_ps((float const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_ps((float const *)(b11 + cs_b *2)); + + + ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); + + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5); + ymm19 = _mm256_mul_ps(ymm19, ymm17); + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); + ymm8 = _mm256_sub_ps(ymm19, ymm8); + + ymm18 = _mm256_shuffle_ps(ymm1, ymm1, 0xA0); + ymm19 = _mm256_shuffle_ps(ymm1, ymm1,0xF5); + ymm19 = _mm256_mul_ps(ymm19, ymm17); + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); + ymm9 = _mm256_sub_ps(ymm19, ymm9); + + ymm18 = _mm256_shuffle_ps(ymm2, ymm2, 0xA0); + ymm19 = _mm256_shuffle_ps(ymm2, ymm2,0xF5); + ymm19 = _mm256_mul_ps(ymm19, ymm17); + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); + ymm10 = _mm256_sub_ps(ymm19, ymm10); + + xmm0 = _mm256_extractf128_ps(ymm8, 0); + xmm1 = _mm256_extractf128_ps(ymm9, 0); + xmm2 = _mm256_extractf128_ps(ymm10, 0); + + _mm_storeu_ps((float *)(b11), xmm0); + _mm_storeu_ps((float *)(b11 + cs_b * 1), xmm1); + _mm_storeu_ps((float *)(b11 + cs_b * 2), xmm2); + + + + if(transa) + ctrsm_AltXB_ref(a11, b11, m_rem, 3, + cs_a, cs_b,is_unitdiag, conjtransa); + else + { + ctrsm_AuXB_ref(a11, b11, m_rem, 3, + rs_a, cs_b, + is_unitdiag, conjtransa); + } + } + dim_t n_rem = j + d_nr; + if(n_rem) + { + a10 = D_A_pack; + a11 = L; + b01 = B + m_rem; + b11 = B; + + k_iter = (m - m_rem); + + BLIS_SET_S_YMM_REG_ZEROS + if(2 == n_rem) + { + ///GEMM code begins/// + BLIS_CTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b, + p_lda,k_iter) + + BLIS_PRE_CTRSM_SMALL_2M_2N(AlphaVal,b11,cs_b) + + if(transa) + ctrsm_AltXB_ref(a11, b11, m_rem, 2, + cs_a, cs_b,is_unitdiag, + conjtransa); + else + ctrsm_AuXB_ref(a11, b11, m_rem, 2, + rs_a, cs_b, is_unitdiag, + conjtransa); + } + else if(1 == n_rem) + { + ///GEMM code begins/// + BLIS_CTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b, + p_lda,k_iter) + + BLIS_PRE_CTRSM_SMALL_2M_1N(AlphaVal,b11,cs_b) + + if(transa) + ctrsm_AltXB_ref(a11, b11, m_rem, 1, + cs_a, cs_b, is_unitdiag + ,conjtransa); + else + ctrsm_AuXB_ref(a11, b11, m_rem, 1, + rs_a, cs_b, is_unitdiag, + conjtransa); + } + } + } + if(1 == m_rem) + { + dim_t p_lda = 4; + if(transa) + { + for(dim_t x =0;x < m - m_rem;x+=p_lda) + { + ymm0 = _mm256_loadu_ps( + (float const *)(a10)); + + ymm1 = _mm256_broadcast_ss((float const *)&ones); + + ymm4 = _mm256_shuffle_ps(ymm0, ymm1, 0x44); + ymm5 = _mm256_shuffle_ps(ymm1, ymm1, 0x44); + ymm6 = _mm256_permute2f128_ps(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_ps(ymm4,ymm5,0x31); + ymm0 = _mm256_shuffle_ps(ymm0, ymm1, 0xEE); + ymm1 = _mm256_shuffle_ps(ymm1, ymm1, 0xEE); + ymm7 = _mm256_permute2f128_ps(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_ps(ymm0,ymm1,0x31); + + + _mm256_storeu_ps((float *) + (ptr_a10_dup), ymm6); + _mm256_storeu_ps((float *) + (ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_ps((float *) + (ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_ps((float *) + (ptr_a10_dup + p_lda*3), ymm9); + + a10 += p_lda; + ptr_a10_dup += p_lda*p_lda; + } + } + else + { + for(dim_t x =0;x < m - m_rem;x++) + { + ymm0 = _mm256_loadu_ps( + (float const *)(a10 + rs_a * x)); + _mm256_storeu_ps( + (float *)(ptr_a10_dup + + p_lda * x), ymm0); + } + } + + for(j = (n - d_nr); (j + 1) > 0; j -= d_nr) + { + a10 = D_A_pack; + a11 = L; + b01 = B + (j* cs_b) + m_rem; + b11 = B + (j* cs_b); + + k_iter = (m - m_rem); + + BLIS_SET_S_YMM_REG_ZEROS + BLIS_CTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) + + ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); + ymm16 = 
_mm256_permute_ps(ymm16, 0x44); + + ymm0 = _mm256_loadu_ps((float const *)(b11)); + ymm1 = _mm256_loadu_ps((float const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_ps((float const *)(b11 + cs_b *2)); + + + ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); + + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5); + ymm19 = _mm256_mul_ps(ymm19, ymm17); + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); + ymm8 = _mm256_sub_ps(ymm19, ymm8); + + ymm18 = _mm256_shuffle_ps(ymm1, ymm1, 0xA0); + ymm19 = _mm256_shuffle_ps(ymm1, ymm1,0xF5); + ymm19 = _mm256_mul_ps(ymm19, ymm17); + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); + ymm9 = _mm256_sub_ps(ymm19, ymm9); + + ymm18 = _mm256_shuffle_ps(ymm2, ymm2, 0xA0); + ymm19 = _mm256_shuffle_ps(ymm2, ymm2,0xF5); + ymm19 = _mm256_mul_ps(ymm19, ymm17); + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); + ymm10 = _mm256_sub_ps(ymm19, ymm10); + + xmm0 = _mm256_extractf128_ps(ymm8, 0); + xmm1 = _mm256_extractf128_ps(ymm9, 0); + xmm2 = _mm256_extractf128_ps(ymm10, 0); + + _mm_storel_pi((__m64 *)(b11), xmm0); + _mm_storel_pi((__m64 *)(b11 + cs_b * 1), xmm1); + _mm_storel_pi((__m64 *)(b11 + cs_b * 2), xmm2); + + + + if(transa) + ctrsm_AutXB_ref(a11, b11, m_rem, 3, + cs_a, cs_b,is_unitdiag, conjtransa); + else + { + ctrsm_AlXB_ref(a11, b11, m_rem, 3, + rs_a, cs_b, + is_unitdiag, conjtransa); + } + } + + dim_t n_rem = j + d_nr; + if(n_rem) + { + a10 = D_A_pack; + a11 = L; + b01 = B + m_rem; + b11 = B; + + k_iter = (m - m_rem); + + BLIS_SET_S_YMM_REG_ZEROS + if(2 == n_rem) + { + ///GEMM code begins/// + BLIS_CTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b, + p_lda,k_iter) + + BLIS_PRE_CTRSM_SMALL_1M_2N(AlphaVal,b11, + cs_b) + + if(transa) + ctrsm_AutXB_ref(a11, b11, m_rem, + 2, cs_a, cs_b, + is_unitdiag, + conjtransa); + else + ctrsm_AlXB_ref(a11, b11, m_rem, 2, + rs_a, cs_b, + is_unitdiag, + conjtransa); + } + else if(1 == n_rem) + { + ///GEMM code begins/// + BLIS_CTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b, + p_lda,k_iter) + + BLIS_PRE_CTRSM_SMALL_1M_1N(AlphaVal,b11, + cs_b) + + if(transa) + ctrsm_AutXB_ref(a11, b11, m_rem, + 1, cs_a, cs_b, + is_unitdiag, + conjtransa); + else + ctrsm_AlXB_ref(a11, b11, m_rem, 1, + rs_a, cs_b, is_unitdiag, + conjtransa); + } + } + } + } + + if ((required_packing_A == 1) && + bli_mem_is_alloc( &local_mem_buf_A_s )) + { + bli_membrk_release(&rntm, &local_mem_buf_A_s); + } + + + return BLIS_SUCCESS; } + BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ( obj_t* AlphaObj, @@ -36541,7 +43075,1454 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB cntl_t* cntl ) { - return BLIS_SUCCESS; + dim_t m = bli_obj_length(b); // number of rows of matrix B + dim_t n = bli_obj_width(b); // number of columns of matrix B + + bool transa = bli_obj_has_trans(a); + bool conjtransa = bli_obj_has_conj(a); + + dim_t cs_a, rs_a; + dim_t d_mr = 8,d_nr = 3; + + // Swap rs_a & cs_a in case of non-tranpose. 
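+	/*
+	   Illustrative note (a sketch only, not part of the optimized path):
+	   swapping the strides lets the rest of this kernel index A as if it
+	   were always the transposed operand, so one code path can serve both
+	   the "A upper-triangular, transposed" and the "A lower-triangular,
+	   non-transposed" variants.  With p and q as hypothetical element
+	   coordinates:
+
+	       // element (p,q) of the operand the solve actually consumes
+	       scomplex apq = *(L + p * rs_a + q * cs_a);
+	       // transa:  rs_a/cs_a are A's own strides -> reads A(p,q)
+	       // !transa: strides are swapped           -> reads A(q,p), i.e. A'(p,q)
+	*/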
+ if(transa) + { + cs_a = bli_obj_col_stride(a); // column stride of A + rs_a = bli_obj_row_stride(a); // row stride of A + } + else + { + cs_a = bli_obj_row_stride(a); // row stride of A + rs_a = bli_obj_col_stride(a); // column stride of A + } + dim_t cs_b = bli_obj_col_stride(b); // column stride of B + + dim_t i, j, k; //loop variables + dim_t k_iter; //number of times GEMM to be performed + + scomplex AlphaVal = *(scomplex *)AlphaObj->buffer; //value of alpha + scomplex *L = a->buffer; //pointer to matrix A + scomplex *B = b->buffer; //pointer to matrix B + + scomplex *a01, *a11, *b10, *b11; //pointers that point to blocks for GEMM and TRSM + + scomplex ones = {1.0, 1.0}; + bool is_unitdiag = bli_obj_has_unit_diag(a); + + //scratch registers + __m256 ymm0, ymm1, ymm2, ymm3; + __m256 ymm4, ymm5, ymm6, ymm7; + __m256 ymm8, ymm9, ymm10, ymm11; + __m256 ymm12, ymm13, ymm14, ymm15; + __m256 ymm16, ymm17, ymm18, ymm19; + + __m128 xmm0, xmm1, xmm2; + + gint_t required_packing_A = 1; + mem_t local_mem_buf_A_s = {0}; + scomplex *D_A_pack = NULL; + scomplex d11_pack[d_mr] __attribute__((aligned(64))); + rntm_t rntm; + + bli_rntm_init_from_global( &rntm ); + bli_rntm_set_num_threads_only( 1, &rntm ); + bli_membrk_rntm_set_membrk( &rntm ); + + siz_t buffer_size = bli_pool_block_size( + bli_membrk_pool( + bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), + bli_rntm_membrk(&rntm))); + + if ( (d_mr * m * sizeof(scomplex)) > buffer_size) + return BLIS_NOT_YET_IMPLEMENTED; + + if (required_packing_A == 1) + { + // Get the buffer from the pool. + bli_membrk_acquire_m(&rntm, + buffer_size, + BLIS_BITVAL_BUFFER_FOR_A_BLOCK, + &local_mem_buf_A_s); + if(FALSE==bli_mem_is_alloc(&local_mem_buf_A_s)) return BLIS_NULL_POINTER; + D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); + if(NULL==D_A_pack) return BLIS_NULL_POINTER; + } + + /* + Performs solving TRSM for 4 colmns at a time from 0 to m/4 in steps of d_mr + a. Load, transpose, Pack A (a10 block), the size of packing 4x3 to 4x (m-4) + First there will be no GEMM and no packing of a10 because it is only TRSM + b. Using packed a10 block and b01 block perform GEMM operation + c. Use GEMM outputs, perform TRSM operaton using a11, b11 and update B + d. Repeat b,c for n rows of B in steps of d_nr + */ + for(j = (n-d_nr); (j+1) > 0; j -= d_nr) //loop along 'N' direction + { + a01 = L + (j*rs_a) + (j+d_nr)*cs_a; + a11 = L + (j*cs_a) + (j*rs_a); + dim_t p_lda = (n-j-d_nr); + + if(transa) + { + /* + Load, tranpose and pack current A block (a10) into packed buffer memory + D_A_pack + a. This a10 block is used in GEMM portion only and this + a10 block size will be increasing by d_mr for every next itteration + untill it reaches 4x(m-4) which is the maximum GEMM alone block size + in A + b. This packed buffer is reused to calculate all n rows of B matrix + */ + bli_ctrsm_small_pack('R', p_lda, 1, a01, cs_a, D_A_pack, p_lda,d_nr); + + /* + Pack 4 diagonal elements of A block into an array + a. This helps in utilze cache line efficiently in TRSM operation + b. store ones when input is unit diagonal + */ + ctrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,d_nr); + } + else + { + bli_ctrsm_small_pack('R', p_lda, 0, a01, rs_a, D_A_pack, p_lda,d_nr); + ctrsm_small_pack_diag_element(is_unitdiag,a11,rs_a,d11_pack,d_nr); + } + /* + a. Perform GEMM using a10, b01. + b. Perform TRSM on a11, b11 + c. This loop GEMM+TRSM loops operates with 4x3 block size + along n dimension for every d_nr rows of b01 where + packed A buffer is reused in computing all n rows of B. + d. 
Same approch is used in remaining fringe cases. + */ + for(i = (m-d_mr); (i+1) > 0; i -= d_mr) //loop along 'M' direction + { + + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; + b10 = B + i + (j+d_nr)*cs_b; + b11 = B + (i) + (j)*cs_b; + + k_iter = (n-j-d_nr); + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + /* + Peform GEMM between a10 and b01 blocks + For first itteration there will be no GEMM operation + where k_iter are zero + */ + BLIS_CTRSM_SMALL_GEMM_3nx8m(a01,b10,cs_b,p_lda,k_iter) + + /* + Load b11 of size 3x4 and multiply with alpha + Add the GEMM output and perform inregister transose of b11 + to peform TRSM operation. + */ + BLIS_PRE_CTRSM_SMALL_3x8(AlphaVal, b11, cs_b) + /* + Compute 4x3 TRSM block by using GEMM block output in register + a. The 4x3 input (gemm outputs) are stored in combinations of ymm + registers + 1. ymm8, ymm4 2. ymm9, ymm5 3. ymm10, ymm6, 4. ymm11, ymm7 + where ymm8-ymm11 holds 4x2 data and reaming 4x1 will be hold by + other registers + b. Towards the end do in regiser transpose of TRSM output and store in + b11 + */ + ////extract a00 + ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_TWO_DIV(ymm12, ymm13) +#else + BLIS_CTRSM_MUL(ymm12) + BLIS_CTRSM_MUL(ymm13) +#endif + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a *2 + rs_a*1) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //extract a11 + //(ROw1): FMA operations + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm12, ymm12, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm12, ymm12,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm10 = _mm256_sub_ps(ymm10,ymm16); + + ymm1 = _mm256_shuffle_ps(ymm13, ymm13, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm13, ymm13,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm11 = _mm256_sub_ps(ymm11,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm12, ymm12, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm12, ymm12,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm8 = _mm256_sub_ps(ymm8,ymm16); + + ymm1 = _mm256_shuffle_ps(ymm13, ymm13, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm13, ymm13,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm9 = _mm256_sub_ps(ymm9,ymm16); + + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_TWO_DIV(ymm10, ymm11) +#else + BLIS_CTRSM_MUL(ymm10) + BLIS_CTRSM_MUL(ymm11) +#endif + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm10, ymm10, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm10, ymm10,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm8 = _mm256_sub_ps(ymm8,ymm16); + + ymm1 = _mm256_shuffle_ps(ymm11, ymm11, 0xA0); + 
ymm16 = _mm256_shuffle_ps(ymm11, ymm11,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm9 = _mm256_sub_ps(ymm9,ymm16); + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); + +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_TWO_DIV(ymm8, ymm9) +#else + BLIS_CTRSM_MUL(ymm8) + BLIS_CTRSM_MUL(ymm9) +#endif + + _mm256_storeu_ps((float *)b11, ymm8); + _mm256_storeu_ps((float *)(b11 + 4), ymm9); + _mm256_storeu_ps((float *)(b11 + cs_b), ymm10); + _mm256_storeu_ps((float *)(b11 + cs_b + 4), ymm11); + _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm12); + _mm256_storeu_ps((float *)(b11 + cs_b*2 + 4), ymm13); + + } + dim_t m_rem = i + d_mr; + if(m_rem >= 4) + { + a01 = D_A_pack; + a11 = L + (j*cs_a) + (j*rs_a); + b10 = B + (m_rem - 4) + (j+d_nr)*cs_b; + b11 = B + (m_rem - 4) + (j*cs_b); + + k_iter = (n-j-d_nr); + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_CTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) + + // Load b11 of size 4x6 and multiply with alpha + BLIS_PRE_CTRSM_SMALL_3x4(AlphaVal,b11,cs_b) + ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm12) +#else + BLIS_CTRSM_MUL(ymm12) +#endif + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a *2 + rs_a*1) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //extract a11 + //(ROw1): FMA operations + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm12, ymm12, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm12, ymm12,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm10 = _mm256_sub_ps(ymm10,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm12, ymm12, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm12, ymm12,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm8 = _mm256_sub_ps(ymm8,ymm16); + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm10) +#else + BLIS_CTRSM_MUL(ymm10) +#endif + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm10, ymm10, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm10, ymm10,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm8 = _mm256_sub_ps(ymm8,ymm16); + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); + +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm8) +#else + BLIS_CTRSM_MUL(ymm8) +#endif + + + _mm256_storeu_ps((float *)b11, ymm8); + _mm256_storeu_ps((float *)(b11 + cs_b), ymm10); + _mm256_storeu_ps((float *)(b11 + cs_b * 2), ymm12); + + m_rem -= 4; + } + if(m_rem == 3) + { + + a01 = D_A_pack; + a11 = L + (j*cs_a) + (j*rs_a); + b10 = B + (j+d_nr)*cs_b + (m_rem - 3); + b11 = B + (m_rem 
- 3) + (j*cs_b); + + k_iter = (n-j-d_nr); + + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_CTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) + + // Load b11 of size 4x6 and multiply with alpha + BLIS_PRE_CTRSM_SMALL_3x4(AlphaVal,b11,cs_b) + + ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm12) +#else + BLIS_CTRSM_MUL(ymm12) +#endif + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a *2 + rs_a*1) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //extract a11 + //(ROw1): FMA operations + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm12, ymm12, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm12, ymm12,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm10 = _mm256_sub_ps(ymm10,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm12, ymm12, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm12, ymm12,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm8 = _mm256_sub_ps(ymm8,ymm16); + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm10) +#else + BLIS_CTRSM_MUL(ymm10) +#endif + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm10, ymm10, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm10, ymm10,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm8 = _mm256_sub_ps(ymm8,ymm16); + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); + +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm8) +#else + BLIS_CTRSM_MUL(ymm8) +#endif + xmm0 = _mm256_extractf128_ps(ymm8, 0); + xmm1 = _mm256_extractf128_ps(ymm10, 0); + xmm2 = _mm256_extractf128_ps(ymm12, 0); + + _mm_storeu_ps((float *)(b11), xmm0); + _mm_storeu_ps((float *)(b11 + cs_b * 1), xmm1); + _mm_storeu_ps((float *)(b11 + cs_b * 2), xmm2); + + xmm0 = _mm256_extractf128_ps(ymm8, 1); + xmm1 = _mm256_extractf128_ps(ymm10, 1); + xmm2 = _mm256_extractf128_ps(ymm12, 1); + + _mm_storel_pi((__m64 *)(b11 + 2), xmm0); + _mm_storel_pi((__m64 *)(b11 + cs_b * 1 + 2), xmm1); + _mm_storel_pi((__m64 *)(b11 + cs_b * 2 + 2), xmm2); + + + m_rem -= 3; + } + if(m_rem == 2) + { + + a01 = D_A_pack; + a11 = L + (j*cs_a) + (j*rs_a); + b10 = B + (j+d_nr)*cs_b + (m_rem - 2); + b11 = B + (m_rem - 2) + (j*cs_b); + + k_iter = (n-j-d_nr); + + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_CTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) + + // Load b11 of size 4x6 and multiply with alpha + BLIS_PRE_CTRSM_SMALL_3x4(AlphaVal,b11,cs_b) + + ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); + 
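+				/* d11_pack holds the packed diagonal of a11 (ones for a unit
+				   diagonal); with BLIS_ENABLE_TRSM_PREINVERSION the packed
+				   values are presumably the reciprocals, which is why the
+				   step below is a multiply there and a divide otherwise. */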
ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm12) +#else + BLIS_CTRSM_MUL(ymm12) +#endif + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a *2 + rs_a*1) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //extract a11 + //(ROw1): FMA operations + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm12, ymm12, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm12, ymm12,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm10 = _mm256_sub_ps(ymm10,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm12, ymm12, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm12, ymm12,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm8 = _mm256_sub_ps(ymm8,ymm16); + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm10) +#else + BLIS_CTRSM_MUL(ymm10) +#endif + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm10, ymm10, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm10, ymm10,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm8 = _mm256_sub_ps(ymm8,ymm16); + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); + +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm8) +#else + BLIS_CTRSM_MUL(ymm8) +#endif + xmm0 = _mm256_extractf128_ps(ymm8, 0); + xmm1 = _mm256_extractf128_ps(ymm10, 0); + xmm2 = _mm256_extractf128_ps(ymm12, 0); + + _mm_storeu_ps((float *)(b11), xmm0); + _mm_storeu_ps((float *)(b11 + cs_b * 1), xmm1); + _mm_storeu_ps((float *)(b11 + cs_b * 2), xmm2); + + m_rem -= 2; + } + if(m_rem == 1) + { + + a01 = D_A_pack; + a11 = L + (j*cs_a) + (j*rs_a); + b10 = B + (j+d_nr)*cs_b + (m_rem - 1); + b11 = B + (m_rem - 1) + (j*cs_b); + + k_iter = (n-j-d_nr); + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_CTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) + + // Load b11 of size 4x6 and multiply with alpha + BLIS_PRE_CTRSM_SMALL_3x4(AlphaVal,b11,cs_b) + + ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm12) +#else + BLIS_CTRSM_MUL(ymm12) +#endif + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a *2 + rs_a*1) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //extract a11 + //(ROw1): FMA operations + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm12, ymm12, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm12, ymm12,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm10 = _mm256_sub_ps(ymm10,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + 
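+				/* a11 + cs_a*2 is the off-diagonal entry that folds the
+				   already-solved third column (ymm12) back into the first
+				   column's right-hand side (ymm8). */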
if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm12, ymm12, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm12, ymm12,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm8 = _mm256_sub_ps(ymm8,ymm16); + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm10) +#else + BLIS_CTRSM_MUL(ymm10) +#endif + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm10, ymm10, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm10, ymm10,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm8 = _mm256_sub_ps(ymm8,ymm16); + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); + +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm8) +#else + BLIS_CTRSM_MUL(ymm8) +#endif + + xmm0 = _mm256_extractf128_ps(ymm8, 0); + xmm1 = _mm256_extractf128_ps(ymm10, 0); + xmm2 = _mm256_extractf128_ps(ymm12, 0); + + _mm_storel_pi((__m64 *)(b11), xmm0); + _mm_storel_pi((__m64 *)(b11 + cs_b * 1), xmm1); + _mm_storel_pi((__m64 *)(b11 + cs_b * 2), xmm2); + + m_rem -= 1; + } + + } + dim_t n_rem = j + d_nr; + if(n_rem == 2) + { + a01 = L + (n_rem - 2)*rs_a + n_rem*cs_a; + a11 = L + (n_rem - 2)*cs_a + (n_rem - 2)*rs_a; + + scomplex *ptr_a10_dup = D_A_pack; + + dim_t p_lda = (n-n_rem); // packed leading dimension + // perform copy of A to packed buffer D_A_pack + + if(transa) + { + for(dim_t x=0; x 0; i -= d_mr) + { + a01 = D_A_pack; + a11 = L + (n_rem - 2)*cs_a + (n_rem - 2)*rs_a; + b10 = B + i + (n_rem)*cs_b; + b11 = B + (i) + (n_rem - 2)*cs_b; + + k_iter = (n-n_rem); + + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_CTRSM_SMALL_GEMM_2nx8m(a01,b10,cs_b,p_lda,k_iter) + ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); + ymm16 = _mm256_permute_ps(ymm16, 0x44); + + ymm0 = _mm256_loadu_ps((float const *)(b11)); + ymm1 = _mm256_loadu_ps((float const *)(b11 + 4)); + + ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5); + ymm19 = _mm256_mul_ps(ymm19, ymm17); + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); + ymm8 = _mm256_sub_ps(ymm19, ymm8); + + ymm18 = _mm256_shuffle_ps(ymm1, ymm1, 0xA0); + ymm19 = _mm256_shuffle_ps(ymm1, ymm1,0xF5); + ymm19 = _mm256_mul_ps(ymm19, ymm17); + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); + ymm9 = _mm256_sub_ps(ymm19, ymm9); + + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b * 1)); + ymm1 = _mm256_loadu_ps((float const *)(b11 + cs_b *1 + 4)); + + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5); + ymm19 = _mm256_mul_ps(ymm19, ymm17); + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); + ymm10 = _mm256_sub_ps(ymm19, ymm10); + + ymm18 = _mm256_shuffle_ps(ymm1, ymm1, 0xA0); + ymm19 = _mm256_shuffle_ps(ymm1, ymm1,0xF5); + ymm19 = _mm256_mul_ps(ymm19, ymm17); + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); + ymm11 = _mm256_sub_ps(ymm19, ymm11); + + ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + 
ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_TWO_DIV(ymm10, ymm11) +#else + BLIS_CTRSM_MUL(ymm10) + BLIS_CTRSM_MUL(ymm11) +#endif + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm10, ymm10, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm10, ymm10,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm8 = _mm256_sub_ps(ymm8,ymm16); + + ymm1 = _mm256_shuffle_ps(ymm11, ymm11, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm11, ymm11,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm9 = _mm256_sub_ps(ymm9,ymm16); + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); + +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_TWO_DIV(ymm8, ymm9) +#else + BLIS_CTRSM_MUL(ymm8) + BLIS_CTRSM_MUL(ymm9) +#endif + + _mm256_storeu_ps((float *)b11, ymm8); + _mm256_storeu_ps((float *)(b11 + 4), ymm9); + _mm256_storeu_ps((float *)(b11 + cs_b), ymm10); + _mm256_storeu_ps((float *)(b11 + cs_b + 4), ymm11); + } + dim_t m_rem = i + d_mr; + if(m_rem >= 4) + { + a01 = D_A_pack; + a11 = L + (n_rem - 2)*cs_a + (n_rem - 2)*rs_a; + b10 = B + (m_rem - 4) + (n_rem)*cs_b; + b11 = B + (m_rem - 4) + (n_rem - 2)*cs_b; + + k_iter = (n-n_rem); + + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_CTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) + ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); + ymm16 = _mm256_permute_ps(ymm16, 0x44); + + ymm0 = _mm256_loadu_ps((float const *)(b11)); + + ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5); + ymm19 = _mm256_mul_ps(ymm19, ymm17); + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); + ymm8 = _mm256_sub_ps(ymm19, ymm8); + + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b * 1)); + + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5); + ymm19 = _mm256_mul_ps(ymm19, ymm17); + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); + ymm10 = _mm256_sub_ps(ymm19, ymm10); + + ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm10) +#else + BLIS_CTRSM_MUL(ymm10) +#endif + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*1) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //extract a11 + //(ROw1): FMA operations + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm10, ymm10, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm10, ymm10,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm8 = _mm256_sub_ps(ymm8,ymm16); + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); + +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm8) +#else + BLIS_CTRSM_MUL(ymm8) +#endif + + _mm256_storeu_ps((float *)b11, ymm8); + _mm256_storeu_ps((float *)(b11 + cs_b), ymm10); + + m_rem -=4; + } + if(m_rem == 3) + { + a01 = D_A_pack; + a11 = L + (n_rem - 2)*cs_a + (n_rem - 2)*rs_a; + b10 = B + (m_rem - 3) + (n_rem)*cs_b; + b11 = B + (m_rem - 3) + (n_rem - 2)*cs_b; 
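+			/* m_rem == 3 fringe: the stores below write the first two scomplex
+			   elements with a 128-bit store and the third with a 64-bit
+			   _mm_storel_pi, so nothing is written past the 3-row tail of B. */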
+ + k_iter = (n-n_rem); + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_CTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) + ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); + ymm16 = _mm256_permute_ps(ymm16, 0x44); + + ymm0 = _mm256_loadu_ps((float const *)(b11)); + + ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5); + ymm19 = _mm256_mul_ps(ymm19, ymm17); + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); + ymm8 = _mm256_sub_ps(ymm19, ymm8); + + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b * 1)); + + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5); + ymm19 = _mm256_mul_ps(ymm19, ymm17); + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); + ymm10 = _mm256_sub_ps(ymm19, ymm10); + + ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm10) +#else + BLIS_CTRSM_MUL(ymm10) +#endif + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*1) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //extract a11 + //(ROw1): FMA operations + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm10, ymm10, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm10, ymm10,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm8 = _mm256_sub_ps(ymm8,ymm16); + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); + +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm8) +#else + BLIS_CTRSM_MUL(ymm8) +#endif + + xmm0 = _mm256_extractf128_ps(ymm8, 0); + xmm1 = _mm256_extractf128_ps(ymm10, 0); + _mm_storeu_ps((float *)(b11),xmm0); + _mm_storeu_ps((float *)(b11 + cs_b * 1),xmm1); + + xmm0 = _mm256_extractf128_ps(ymm8, 1); + xmm1 = _mm256_extractf128_ps(ymm10, 1); + + _mm_storel_pi((__m64 *)(b11 + 2),xmm0); + _mm_storel_pi((__m64 *)(b11 + cs_b * 1 + 2),xmm1); + + m_rem -=3; + } + if(m_rem == 2) + { + a01 = D_A_pack; + a11 = L + (n_rem - 2)*cs_a + (n_rem - 2)*rs_a; + b10 = B + (m_rem - 2) + (n_rem)*cs_b; + b11 = B + (m_rem - 2) + (n_rem - 2)*cs_b; + + k_iter = (n-n_rem); + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_CTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) + ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); + ymm16 = _mm256_permute_ps(ymm16, 0x44); + + ymm0 = _mm256_loadu_ps((float const *)(b11)); + + ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5); + ymm19 = _mm256_mul_ps(ymm19, ymm17); + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); + ymm8 = _mm256_sub_ps(ymm19, ymm8); + + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b * 1)); + + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5); + ymm19 = _mm256_mul_ps(ymm19, ymm17); + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); + ymm10 = _mm256_sub_ps(ymm19, ymm10); + + ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm10) +#else + BLIS_CTRSM_MUL(ymm10) +#endif + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*1) ); + 
ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //extract a11 + //(ROw1): FMA operations + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm10, ymm10, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm10, ymm10,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm8 = _mm256_sub_ps(ymm8,ymm16); + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); + +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm8) +#else + BLIS_CTRSM_MUL(ymm8) +#endif + xmm0 = _mm256_extractf128_ps(ymm8, 0); + xmm1 = _mm256_extractf128_ps(ymm10, 0); + + _mm_storeu_ps((float *)(b11 + cs_b * 0),xmm0); + _mm_storeu_ps((float *)(b11 + cs_b * 1),xmm1); + + m_rem -=2; + } + if(m_rem == 1) + { + a01 = D_A_pack; + a11 = L + (n_rem - 2)*cs_a + (n_rem - 2)*rs_a; + b10 = B + (m_rem - 1) + (n_rem)*cs_b; + b11 = B + (m_rem - 1) + (n_rem - 2)*cs_b; + + k_iter = (n-n_rem); + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_CTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) + ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); + ymm16 = _mm256_permute_ps(ymm16, 0x44); + + ymm0 = _mm256_loadu_ps((float const *)(b11)); + + ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5); + ymm19 = _mm256_mul_ps(ymm19, ymm17); + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); + ymm8 = _mm256_sub_ps(ymm19, ymm8); + + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b * 1)); + + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5); + ymm19 = _mm256_mul_ps(ymm19, ymm17); + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); + ymm10 = _mm256_sub_ps(ymm19, ymm10); + + ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm10) +#else + BLIS_CTRSM_MUL(ymm10) +#endif + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*1) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //extract a11 + //(ROw1): FMA operations + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm10, ymm10, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm10, ymm10,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm8 = _mm256_sub_ps(ymm8,ymm16); + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); + +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm8) +#else + BLIS_CTRSM_MUL(ymm8) +#endif + + xmm0 = _mm256_extractf128_ps(ymm8, 0); + xmm1 = _mm256_extractf128_ps(ymm10, 0); + + _mm_storel_pi((__m64 *)(b11),xmm0); + _mm_storel_pi((__m64 *)(b11 + cs_b),xmm1); + + m_rem -=1; + } + n_rem -= 2; + } + if(n_rem == 1) + { + a01 = L + (n_rem - 1)*rs_a + n_rem*cs_a; + a11 = L + (n_rem - 1)*cs_a + (n_rem - 1)*rs_a; + + scomplex *ptr_a10_dup = D_A_pack; + + dim_t p_lda = (n-n_rem); // packed leading dimension + // perform copy of A to packed buffer D_A_pack + + if(transa) + { + for(dim_t x=0; x 0; i -= d_mr) + { + a01 = D_A_pack; + a11 = L + (n_rem - 1)*cs_a + (n_rem - 1)*rs_a; + b10 = B + i + (n_rem)*cs_b; + b11 = B + (i) + (n_rem - 1)*cs_b; + + k_iter = (n-n_rem); + + + /*Fill zeros into ymm registers used in gemm accumulations */ + 
BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_CTRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter) + ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); + ymm16 = _mm256_permute_ps(ymm16, 0x44); + + ymm0 = _mm256_loadu_ps((float const *)(b11)); + ymm1 = _mm256_loadu_ps((float const *)(b11 + 4)); + + ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5); + ymm19 = _mm256_mul_ps(ymm19, ymm17); + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); + ymm8 = _mm256_sub_ps(ymm19, ymm8); + + ymm18 = _mm256_shuffle_ps(ymm1, ymm1, 0xA0); + ymm19 = _mm256_shuffle_ps(ymm1, ymm1,0xF5); + ymm19 = _mm256_mul_ps(ymm19, ymm17); + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); + ymm9 = _mm256_sub_ps(ymm19, ymm9); + + + ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_TWO_DIV(ymm8, ymm9) +#else + BLIS_CTRSM_MUL(ymm8) + BLIS_CTRSM_MUL(ymm9) +#endif + + _mm256_storeu_ps((float *)b11, ymm8); + _mm256_storeu_ps((float *)(b11 + 4), ymm9); + + } + dim_t m_rem = i + d_mr; + if(m_rem >= 4) + { + a01 = D_A_pack; + a11 = L + (n_rem - 1)*cs_a + (n_rem - 1)*rs_a; + b10 = B + (m_rem - 4) + (n_rem)*cs_b; + b11 = B + (m_rem - 4) + (n_rem - 1)*cs_b; + k_iter = (n-n_rem); + + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_CTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) + ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); + ymm16 = _mm256_permute_ps(ymm16, 0x44); + + ymm0 = _mm256_loadu_ps((float const *)(b11)); + + ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5); + ymm19 = _mm256_mul_ps(ymm19, ymm17); + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); + ymm8 = _mm256_sub_ps(ymm19, ymm8); + + + ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm8) +#else + BLIS_CTRSM_MUL(ymm8) +#endif + + _mm256_storeu_ps((float *)b11, ymm8); + + m_rem -=4; + } + if(m_rem == 3) + { + a01 = D_A_pack; + a11 = L + (n_rem - 1)*cs_a + (n_rem - 1)*rs_a; + b10 = B + (m_rem - 3) + (n_rem)*cs_b; + b11 = B + (m_rem - 3) + (n_rem - 1)*cs_b; + k_iter = (n-n_rem); + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_CTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) + ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); + ymm16 = _mm256_permute_ps(ymm16, 0x44); + + ymm0 = _mm256_loadu_ps((float const *)(b11)); + + ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5); + ymm19 = _mm256_mul_ps(ymm19, ymm17); + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); + ymm8 = _mm256_sub_ps(ymm19, ymm8); + + + ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm8) +#else + BLIS_CTRSM_MUL(ymm8) +#endif + xmm0 = _mm256_extractf128_ps(ymm8, 0); + _mm_storeu_ps((float *)(b11), xmm0); + xmm0 = _mm256_extractf128_ps(ymm8, 1); + _mm_storel_pi((__m64 *)(b11 + 2), xmm0); + + m_rem -=3; + } + if(m_rem == 2) + { + a01 = D_A_pack; + 
a11 = L + (n_rem - 1)*cs_a + (n_rem - 1)*rs_a; + b10 = B + (m_rem - 2) + (n_rem)*cs_b; + b11 = B + (m_rem - 2) + (n_rem - 1)*cs_b; + k_iter = (n-n_rem); + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_CTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) + ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); + ymm16 = _mm256_permute_ps(ymm16, 0x44); + + ymm0 = _mm256_loadu_ps((float const *)(b11)); + + ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5); + ymm19 = _mm256_mul_ps(ymm19, ymm17); + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); + ymm8 = _mm256_sub_ps(ymm19, ymm8); + + + ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm8) +#else + BLIS_CTRSM_MUL(ymm8) +#endif + xmm0 = _mm256_extractf128_ps(ymm8, 0); + _mm_storeu_ps((float *)(b11), xmm0); + + m_rem -=2; + } + if(m_rem == 1) + { + a01 = D_A_pack; + a11 = L + (n_rem - 1)*cs_a + (n_rem - 1)*rs_a; + b10 = B + (m_rem - 1) + (n_rem)*cs_b; + b11 = B + (m_rem - 1) + (n_rem - 1)*cs_b; + k_iter = (n-n_rem); + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_CTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) + ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); + ymm16 = _mm256_permute_ps(ymm16, 0x44); + + ymm0 = _mm256_loadu_ps((float const *)(b11)); + + ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5); + ymm19 = _mm256_mul_ps(ymm19, ymm17); + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); + ymm8 = _mm256_sub_ps(ymm19, ymm8); + + ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm8) +#else + BLIS_CTRSM_MUL(ymm8) +#endif + xmm0 = _mm256_extractf128_ps(ymm8, 0); + _mm_storel_pi((__m64 *)(b11), xmm0); + + m_rem -=1; + } + n_rem -= 1; + } + + if ((required_packing_A == 1) && + bli_mem_is_alloc( &local_mem_buf_A_s )) + { + bli_membrk_release(&rntm, &local_mem_buf_A_s); + } + + + return BLIS_SUCCESS; } BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB @@ -36553,6 +44534,1532 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB cntl_t* cntl ) { - return BLIS_SUCCESS; + dim_t m = bli_obj_length(b); // number of rows of matrix B + dim_t n = bli_obj_width(b); // number of columns of matrix B + + bool transa = bli_obj_has_trans(a); + bool conjtransa = bli_obj_has_conj(a); + + dim_t cs_a, rs_a; + dim_t d_mr = 8,d_nr = 3; + + // Swap rs_a & cs_a in case of non-tranpose. 
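+	/*
+	   Illustrative note on the complex arithmetic used throughout this
+	   kernel (the names x, y, yswap, t, res below are placeholders for
+	   this comment only): a packed multiply x*y, with x = {xr, xi, ...}
+	   in natural interleaved layout and y = {yr, yi, yr, yi, ...} a
+	   broadcast multiplier, is built as
+
+	       yswap = _mm256_shuffle_ps(y, y, 0x11);  // {yi, yr, ...}
+	       xr_d  = _mm256_shuffle_ps(x, x, 0xA0);  // duplicate real parts
+	       xi_d  = _mm256_shuffle_ps(x, x, 0xF5);  // duplicate imag parts
+	       t     = _mm256_mul_ps(xi_d, yswap);     // {xi*yi, xi*yr, ...}
+	       res   = _mm256_fmaddsub_ps(xr_d, y, t); // {xr*yr - xi*yi,
+	                                               //  xr*yi + xi*yr, ...}
+
+	   which is (xr + i*xi)*(yr + i*yi) per scomplex lane.  The alpha
+	   scaling and the a11 update sequences below are instances of this
+	   pattern followed by a subtraction into the accumulator register.
+	*/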
+ if(transa) + { + cs_a = bli_obj_col_stride(a); // column stride of A + rs_a = bli_obj_row_stride(a); // row stride of A + } + else + { + cs_a = bli_obj_row_stride(a); // row stride of A + rs_a = bli_obj_col_stride(a); // column stride of A + } + dim_t cs_b = bli_obj_col_stride(b); // column stride of B + + dim_t i, j, k; //loop variables + dim_t k_iter; //number of times GEMM to be performed + + scomplex AlphaVal = *(scomplex *)AlphaObj->buffer; //value of alpha + scomplex *L = a->buffer; //pointer to matrix A + scomplex *B = b->buffer; //pointer to matrix B + + scomplex *a01, *a11, *b10, *b11; //pointers that point to blocks for GEMM and TRSM + + scomplex ones = {1.0, 1.0}; + bool is_unitdiag = bli_obj_has_unit_diag(a); + + //scratch registers + __m256 ymm0, ymm1, ymm2, ymm3; + __m256 ymm4, ymm5, ymm6, ymm7; + __m256 ymm8, ymm9, ymm10, ymm11; + __m256 ymm12, ymm13, ymm14, ymm15; + __m256 ymm16, ymm17, ymm18, ymm19; + + __m128 xmm0, xmm1, xmm2; + + gint_t required_packing_A = 1; + mem_t local_mem_buf_A_s = {0}; + scomplex *D_A_pack = NULL; + scomplex d11_pack[d_mr] __attribute__((aligned(64))); + rntm_t rntm; + + bli_rntm_init_from_global( &rntm ); + bli_rntm_set_num_threads_only( 1, &rntm ); + bli_membrk_rntm_set_membrk( &rntm ); + + siz_t buffer_size = bli_pool_block_size( + bli_membrk_pool( + bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), + bli_rntm_membrk(&rntm))); + + if ( (d_mr * m * sizeof(scomplex)) > buffer_size) + return BLIS_NOT_YET_IMPLEMENTED; + + if (required_packing_A == 1) + { + // Get the buffer from the pool. + bli_membrk_acquire_m(&rntm, + buffer_size, + BLIS_BITVAL_BUFFER_FOR_A_BLOCK, + &local_mem_buf_A_s); + if(FALSE==bli_mem_is_alloc(&local_mem_buf_A_s)) return BLIS_NULL_POINTER; + D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); + if(NULL==D_A_pack) return BLIS_NULL_POINTER; + } + + /* + Performs solving TRSM for 4 colmns at a time from 0 to m/4 in steps of d_mr + a. Load, transpose, Pack A (a10 block), the size of packing 4x3 to 4x (m-4) + First there will be no GEMM and no packing of a10 because it is only TRSM + b. Using packed a10 block and b01 block perform GEMM operation + c. Use GEMM outputs, perform TRSM operaton using a11, b11 and update B + d. Repeat b,c for n rows of B in steps of d_nr + */ + for(j = 0; (j+d_nr-1) < n; j += d_nr) //loop along 'N' direction + { + a01 = L + j*rs_a; + a11 = L + j*cs_a + j*rs_a; + + dim_t p_lda = j; + + if(transa) + { + /* + Load, tranpose and pack current A block (a10) into packed buffer memory + D_A_pack + a. This a10 block is used in GEMM portion only and this + a10 block size will be increasing by d_mr for every next itteration + untill it reaches 4x(m-4) which is the maximum GEMM alone block size + in A + b. This packed buffer is reused to calculate all n rows of B matrix + */ + bli_ctrsm_small_pack('R', j, 1, a01, cs_a, D_A_pack, p_lda,d_nr); + + /* + Pack 4 diagonal elements of A block into an array + a. This helps in utilze cache line efficiently in TRSM operation + b. store ones when input is unit diagonal + */ + ctrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,d_nr); + } + else + { + bli_ctrsm_small_pack('R', j, 0, a01, rs_a, D_A_pack, p_lda,d_nr); + ctrsm_small_pack_diag_element(is_unitdiag,a11,rs_a,d11_pack,d_nr); + } + /* + a. Perform GEMM using a10, b01. + b. Perform TRSM on a11, b11 + c. This loop GEMM+TRSM loops operates with 4x3 block size + along n dimension for every d_nr rows of b01 where + packed A buffer is reused in computing all n rows of B. + d. 
Same approch is used in remaining fringe cases. + */ + for(i = 0; (i+d_mr-1) < m; i += d_mr) //loop along 'M' direction + { + + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; + b10 = B + i; + b11 = B + i + j*cs_b; + + k_iter = j; + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + /* + Peform GEMM between a10 and b01 blocks + For first itteration there will be no GEMM operation + where k_iter are zero + */ + BLIS_CTRSM_SMALL_GEMM_3nx8m(a01,b10,cs_b,p_lda,k_iter) + + /* + Load b11 of size 3x4 and multiply with alpha + Add the GEMM output and perform inregister transose of b11 + to peform TRSM operation. + */ + BLIS_PRE_CTRSM_SMALL_3x8(AlphaVal, b11, cs_b) + /* + Compute 4x3 TRSM block by using GEMM block output in register + a. The 4x3 input (gemm outputs) are stored in combinations of ymm + registers + 1. ymm8, ymm4 2. ymm9, ymm5 3. ymm10, ymm6, 4. ymm11, ymm7 + where ymm8-ymm11 holds 4x2 data and reaming 4x1 will be hold by + other registers + b. Towards the end do in regiser transpose of TRSM output and store in + b11 + */ + ////extract a00 + ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_TWO_DIV(ymm8, ymm9) +#else + BLIS_CTRSM_MUL(ymm8) + BLIS_CTRSM_MUL(ymm9) +#endif + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*1) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //extract a11 + //(ROw1): FMA operations + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm8, ymm8, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm8, ymm8,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm10 = _mm256_sub_ps(ymm10,ymm16); + + ymm1 = _mm256_shuffle_ps(ymm9, ymm9, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm9, ymm9,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm11 = _mm256_sub_ps(ymm11,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*2) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm8, ymm8, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm8, ymm8,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm12 = _mm256_sub_ps(ymm12,ymm16); + + ymm1 = _mm256_shuffle_ps(ymm9, ymm9, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm9, ymm9,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm13 = _mm256_sub_ps(ymm13,ymm16); + + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_TWO_DIV(ymm10, ymm11) +#else + BLIS_CTRSM_MUL(ymm10) + BLIS_CTRSM_MUL(ymm11) +#endif + + + a11 += cs_a; + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*2) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm9 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm10, ymm10, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm10, ymm10,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm12 = _mm256_sub_ps(ymm12,ymm16); + + ymm1 = _mm256_shuffle_ps(ymm11, ymm11, 0xA0); + ymm16 = 
_mm256_shuffle_ps(ymm11, ymm11,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm13 = _mm256_sub_ps(ymm13,ymm16); + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); + +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_TWO_DIV(ymm12, ymm13) +#else + BLIS_CTRSM_MUL(ymm12) + BLIS_CTRSM_MUL(ymm13) +#endif + + _mm256_storeu_ps((float *)b11, ymm8); + _mm256_storeu_ps((float *)(b11 + 4), ymm9); + _mm256_storeu_ps((float *)(b11 + cs_b), ymm10); + _mm256_storeu_ps((float *)(b11 + cs_b + 4), ymm11); + _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm12); + _mm256_storeu_ps((float *)(b11 + cs_b*2 + 4), ymm13); + + } + dim_t m_rem = m - i; + if(m_rem >= 4) + { + + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; + b10 = B + i; + b11 = B + i + j*cs_b; + + k_iter = j; + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_CTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) + + // Load b11 of size 4x6 and multiply with alpha + BLIS_PRE_CTRSM_SMALL_3x4(AlphaVal,b11,cs_b) + + ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm8) +#else + BLIS_CTRSM_MUL(ymm8) +#endif + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*1) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //extract a11 + //(ROw1): FMA operations + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm8, ymm8, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm8, ymm8,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm10 = _mm256_sub_ps(ymm10,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*2) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm8, ymm8, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm8, ymm8,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm12 = _mm256_sub_ps(ymm12,ymm16); + + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm10) +#else + BLIS_CTRSM_MUL(ymm10) +#endif + + + a11 += cs_a; + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*2) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm9 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm10, ymm10, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm10, ymm10,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm12 = _mm256_sub_ps(ymm12,ymm16); + + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); + +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm12) +#else + BLIS_CTRSM_MUL(ymm12) +#endif + _mm256_storeu_ps((float *)b11, ymm8); + _mm256_storeu_ps((float *)(b11 + cs_b), ymm10); + _mm256_storeu_ps((float *)(b11 + cs_b * 2), ymm12); + + m_rem -= 4; + i += 4; + } + if(m_rem == 3) + { + + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; + b10 = B + i; + b11 = B + i + j*cs_b; + + k_iter = j; + + /*Fill zeros 
into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_CTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) + + // Load b11 of size 4x6 and multiply with alpha + BLIS_PRE_CTRSM_SMALL_3x4(AlphaVal,b11,cs_b) + + ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm8) +#else + BLIS_CTRSM_MUL(ymm8) +#endif + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*1) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //extract a11 + //(ROw1): FMA operations + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm8, ymm8, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm8, ymm8,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm10 = _mm256_sub_ps(ymm10,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*2) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm8, ymm8, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm8, ymm8,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm12 = _mm256_sub_ps(ymm12,ymm16); + + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm10) +#else + BLIS_CTRSM_MUL(ymm10) +#endif + + + a11 += cs_a; + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*2) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm9 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm10, ymm10, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm10, ymm10,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm12 = _mm256_sub_ps(ymm12,ymm16); + + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); + +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm12) +#else + BLIS_CTRSM_MUL(ymm12) +#endif + +/* ymm0 = _mm256_loadu_ps((float const *)b11); + ymm1 = _mm256_loadu_ps((float const *)(b11 + cs_b)); + ymm2 = _mm256_loadu_ps((float const *)(b11 + cs_b * 2)); + ymm8 = _mm256_blend_ps(ymm8, ymm0, 0xC0); + ymm10 = _mm256_blend_ps(ymm10, ymm1, 0xC0); + ymm12 = _mm256_blend_ps(ymm12, ymm2, 0xC0); + _mm256_storeu_ps((float *)b11, ymm8); + _mm256_storeu_ps((float *)(b11 + cs_b), ymm10); + _mm256_storeu_ps((float *)(b11 + cs_b * 2), ymm12);*/ + xmm0 = _mm256_extractf128_ps(ymm8, 0); + xmm1 = _mm256_extractf128_ps(ymm8, 1); + _mm_storeu_ps((float *)(b11), xmm0); + _mm_storel_pi((__m64 *)(b11 + 2), xmm1); + xmm0 = _mm256_extractf128_ps(ymm10, 0); + xmm1 = _mm256_extractf128_ps(ymm10, 1); + _mm_storeu_ps((float *)(b11 + cs_b), xmm0); + _mm_storel_pi((__m64 *)(b11 + cs_b + 2), xmm1); + xmm0 = _mm256_extractf128_ps(ymm12, 0); + xmm1 = _mm256_extractf128_ps(ymm12, 1); + _mm_storeu_ps((float *)(b11 + cs_b * 2), xmm0); + _mm_storel_pi((__m64 *)(b11 + cs_b * 2 + 2), xmm1); + + m_rem -= 3; + i += 3; + + } + if(m_rem == 2) + { + + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; + b10 = B + i; + b11 = B + i + j*cs_b; + + k_iter = j; + + /*Fill zeros into ymm registers used in gemm 
accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_CTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) + + // Load b11 of size 4x6 and multiply with alpha + BLIS_PRE_CTRSM_SMALL_3x4(AlphaVal,b11,cs_b) + + ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm8) +#else + BLIS_CTRSM_MUL(ymm8) +#endif + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*1) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //extract a11 + //(ROw1): FMA operations + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm8, ymm8, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm8, ymm8,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm10 = _mm256_sub_ps(ymm10,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*2) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm8, ymm8, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm8, ymm8,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm12 = _mm256_sub_ps(ymm12,ymm16); + + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm10) +#else + BLIS_CTRSM_MUL(ymm10) +#endif + + + a11 += cs_a; + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*2) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm9 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm10, ymm10, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm10, ymm10,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm12 = _mm256_sub_ps(ymm12,ymm16); + + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); + +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm12) +#else + BLIS_CTRSM_MUL(ymm12) +#endif + +/* ymm0 = _mm256_loadu_ps((float const *)b11); + ymm1 = _mm256_loadu_ps((float const *)(b11 + cs_b)); + ymm2 = _mm256_loadu_ps((float const *)(b11 + cs_b * 2)); + ymm8 = _mm256_blend_ps(ymm8, ymm0, 0xF0); + ymm10 = _mm256_blend_ps(ymm10, ymm1, 0xF0); + ymm12 = _mm256_blend_ps(ymm12, ymm2, 0xF0); + _mm256_storeu_ps((float *)b11, ymm8); + _mm256_storeu_ps((float *)(b11 + cs_b), ymm10); + _mm256_storeu_ps((float *)(b11 + cs_b * 2), ymm12); +*/ + xmm0 = _mm256_extractf128_ps(ymm8, 0); + _mm_storeu_ps((float *)(b11), xmm0); + xmm0 = _mm256_extractf128_ps(ymm10, 0); + _mm_storeu_ps((float *)(b11 + cs_b), xmm0); + xmm0 = _mm256_extractf128_ps(ymm12, 0); + _mm_storeu_ps((float *)(b11 + cs_b * 2), xmm0); + + + m_rem -= 2; + i += 2; + } + if(m_rem == 1) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; + b10 = B + i; + b11 = B + i + j*cs_b; + + k_iter = j; + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_CTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) + + // Load b11 of size 4x6 and multiply with alpha + BLIS_PRE_CTRSM_SMALL_3x4(AlphaVal,b11,cs_b) + + ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); + 
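+            // ymm18 is a {1.0, -1.0, ...} sign mask: multiplying a broadcast element of A by it negates the
+            // imaginary lanes, i.e. conjugates A when conjtransa is set.
+            // The shuffle(0xA0)/shuffle(0xF5)/fmaddsub sequence used below forms the complex product of each
+            // b11 element with the broadcast a11 element (Xr*Ar - Xi*Ai in even lanes, Xr*Ai + Xi*Ar in odd
+            // lanes), which is then subtracted from the accumulators for the subsequent columns of B (ymm10, ymm12).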
ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm8) +#else + BLIS_CTRSM_MUL(ymm8) +#endif + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*1) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //extract a11 + //(ROw1): FMA operations + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm8, ymm8, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm8, ymm8,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm10 = _mm256_sub_ps(ymm10,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*2) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm8, ymm8, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm8, ymm8,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm12 = _mm256_sub_ps(ymm12,ymm16); + + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm10) +#else + BLIS_CTRSM_MUL(ymm10) +#endif + + + a11 += cs_a; + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*2) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm9 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm10, ymm10, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm10, ymm10,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm12 = _mm256_sub_ps(ymm12,ymm16); + + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); + +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm12) +#else + BLIS_CTRSM_MUL(ymm12) +#endif + +/* ymm0 = _mm256_loadu_ps((float const *)b11); + ymm1 = _mm256_loadu_ps((float const *)(b11 + cs_b)); + ymm2 = _mm256_loadu_ps((float const *)(b11 + cs_b * 2)); + ymm8 = _mm256_blend_ps(ymm8, ymm0, 0xFC); + ymm10 = _mm256_blend_ps(ymm10, ymm1, 0xFC); + ymm12 = _mm256_blend_ps(ymm12, ymm2, 0xFC); + _mm256_storeu_ps((float *)b11, ymm8); + _mm256_storeu_ps((float *)(b11 + cs_b), ymm10); + _mm256_storeu_ps((float *)(b11 + cs_b * 2), ymm12); +*/ + xmm0 = _mm256_extractf128_ps(ymm8, 0); + xmm1 = _mm256_extractf128_ps(ymm10, 0); + xmm2 = _mm256_extractf128_ps(ymm12, 0); + _mm_storel_pi((__m64 *)(b11), xmm0); + _mm_storel_pi((__m64 *)(b11 + cs_b), xmm1); + _mm_storel_pi((__m64 *)(b11 + cs_b * 2), xmm2); + + m_rem -= 1; + i += 1; + } + + } + dim_t n_rem = n - j; + if(n_rem == 2) + { + a01 = L + j*rs_a; + a11 = L + j*cs_a + j*rs_a; + + scomplex *ptr_a10_dup = D_A_pack; + + dim_t p_lda = j; // packed leading dimension + // perform copy of A to packed buffer D_A_pack + + if(transa) + { + for(dim_t x=0; x= 4) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; + b10 = B + i; + b11 = B + i + j*cs_b; + + k_iter = j; + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_CTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) + ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); + ymm16 = _mm256_permute_ps(ymm16, 0x44); + + ymm0 = _mm256_loadu_ps((float const *)(b11)); + + ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); + ymm19 = _mm256_shuffle_ps(ymm0, 
ymm0,0xF5); + ymm19 = _mm256_mul_ps(ymm19, ymm17); + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); + ymm8 = _mm256_sub_ps(ymm19, ymm8); + + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b * 1)); + + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5); + ymm19 = _mm256_mul_ps(ymm19, ymm17); + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); + ymm10 = _mm256_sub_ps(ymm19, ymm10); + + ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm8) +#else + BLIS_CTRSM_MUL(ymm8) +#endif + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*1) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //extract a11 + //(ROw1): FMA operations + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm8, ymm8, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm8, ymm8,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm10 = _mm256_sub_ps(ymm10,ymm16); + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); + +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm10) +#else + BLIS_CTRSM_MUL(ymm10) +#endif + + _mm256_storeu_ps((float *)b11, ymm8); + _mm256_storeu_ps((float *)(b11 + cs_b), ymm10); + + m_rem -=4; + i+=4; + } + if(m_rem == 3) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; + b10 = B + i; + b11 = B + i + j*cs_b; + + k_iter = j; + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_CTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) + ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); + ymm16 = _mm256_permute_ps(ymm16, 0x44); + + ymm0 = _mm256_loadu_ps((float const *)(b11)); + + ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5); + ymm19 = _mm256_mul_ps(ymm19, ymm17); + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); + ymm8 = _mm256_sub_ps(ymm19, ymm8); + + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b * 1)); + + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5); + ymm19 = _mm256_mul_ps(ymm19, ymm17); + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); + ymm10 = _mm256_sub_ps(ymm19, ymm10); + + ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm8) +#else + BLIS_CTRSM_MUL(ymm8) +#endif + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*1) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //extract a11 + //(ROw1): FMA operations + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm8, ymm8, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm8, ymm8,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm10 = _mm256_sub_ps(ymm10,ymm16); + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); + +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm10) +#else + BLIS_CTRSM_MUL(ymm10) +#endif + xmm0 = _mm256_extractf128_ps(ymm8, 0); + xmm1 = _mm256_extractf128_ps(ymm10, 0); + _mm_storeu_ps((float *)(b11),xmm0); + 
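+                // Each 128-bit store above/below writes two scomplex elements of a b11 column; the third
+                // element of each column (m_rem == 3) is written from the upper 128 bits via _mm_storel_pi.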
_mm_storeu_ps((float *)(b11 + cs_b * 1),xmm1); + + xmm0 = _mm256_extractf128_ps(ymm8, 1); + xmm1 = _mm256_extractf128_ps(ymm10, 1); + + _mm_storel_pi((__m64 *)(b11 + 2),xmm0); + _mm_storel_pi((__m64 *)(b11 + cs_b * 1 + 2),xmm1); + + + m_rem -=3; + i+=3; + } + if(m_rem == 2) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; + b10 = B + i; + b11 = B + i + j*cs_b; + + k_iter = j; + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_CTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) + ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); + ymm16 = _mm256_permute_ps(ymm16, 0x44); + + ymm0 = _mm256_loadu_ps((float const *)(b11)); + + ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5); + ymm19 = _mm256_mul_ps(ymm19, ymm17); + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); + ymm8 = _mm256_sub_ps(ymm19, ymm8); + + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b * 1)); + + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5); + ymm19 = _mm256_mul_ps(ymm19, ymm17); + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); + ymm10 = _mm256_sub_ps(ymm19, ymm10); + + ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm8) +#else + BLIS_CTRSM_MUL(ymm8) +#endif + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*1) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //extract a11 + //(ROw1): FMA operations + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm8, ymm8, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm8, ymm8,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm10 = _mm256_sub_ps(ymm10,ymm16); + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); + +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm10) +#else + BLIS_CTRSM_MUL(ymm10) +#endif + xmm0 = _mm256_extractf128_ps(ymm8, 0); + xmm1 = _mm256_extractf128_ps(ymm10, 0); + + _mm_storeu_ps((float *)(b11 + cs_b * 0),xmm0); + _mm_storeu_ps((float *)(b11 + cs_b * 1),xmm1); + + + m_rem -=2; + i+=2; + } + if(m_rem == 1) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; + b10 = B + i; + b11 = B + i + j*cs_b; + + k_iter = j; + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_CTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) + ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); + ymm16 = _mm256_permute_ps(ymm16, 0x44); + + ymm0 = _mm256_loadu_ps((float const *)(b11)); + + ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5); + ymm19 = _mm256_mul_ps(ymm19, ymm17); + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); + ymm8 = _mm256_sub_ps(ymm19, ymm8); + + ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b * 1)); + + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5); + ymm19 = _mm256_mul_ps(ymm19, ymm17); + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); + ymm10 = _mm256_sub_ps(ymm19, ymm10); + + ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + 
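+                // d11_pack holds the diagonal of A11 as-is when pre-inversion is disabled (solve by dividing),
+                // and pre-inverted when BLIS_ENABLE_TRSM_PREINVERSION is defined (solve by multiplying).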
BLIS_CTRSM_DIV(ymm8) +#else + BLIS_CTRSM_MUL(ymm8) +#endif + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*1) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //extract a11 + //(ROw1): FMA operations + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm8, ymm8, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm8, ymm8,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm10 = _mm256_sub_ps(ymm10,ymm16); + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); + +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm10) +#else + BLIS_CTRSM_MUL(ymm10) +#endif + xmm0 = _mm256_extractf128_ps(ymm8, 0); + xmm1 = _mm256_extractf128_ps(ymm10, 0); + _mm_storel_pi((__m64 *)(b11),xmm0); + _mm_storel_pi((__m64 *)(b11 + cs_b),xmm1); + + m_rem -=1; + i+=1; + } + j += 2; + n_rem -= 2; + } + if(n_rem == 1) + { + a01 = L + j*rs_a; + a11 = L + j*cs_a + j*rs_a; + + scomplex *ptr_a10_dup = D_A_pack; + + dim_t p_lda = j; // packed leading dimension + // perform copy of A to packed buffer D_A_pack + + if(transa) + { + for(dim_t x=0; x= 4) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; + b10 = B + i; + b11 = B + i + j*cs_b; + + k_iter = j; + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_CTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) + ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); + ymm16 = _mm256_permute_ps(ymm16, 0x44); + + ymm0 = _mm256_loadu_ps((float const *)(b11)); + + ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5); + ymm19 = _mm256_mul_ps(ymm19, ymm17); + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); + ymm8 = _mm256_sub_ps(ymm19, ymm8); + + + ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm8) +#else + BLIS_CTRSM_MUL(ymm8) +#endif + + _mm256_storeu_ps((float *)b11, ymm8); + + m_rem -=4; + i+=4; + } + if(m_rem == 3) + { + + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; + b10 = B + i; + b11 = B + i + j*cs_b; + + k_iter = j; + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_CTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) + ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); + ymm16 = _mm256_permute_ps(ymm16, 0x44); + + ymm0 = _mm256_loadu_ps((float const *)(b11)); + + ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5); + ymm19 = _mm256_mul_ps(ymm19, ymm17); + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); + ymm8 = _mm256_sub_ps(ymm19, ymm8); + + + ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm8) +#else + BLIS_CTRSM_MUL(ymm8) +#endif + xmm0 = _mm256_extractf128_ps(ymm8, 0); + _mm_storeu_ps((float *)(b11), xmm0); + xmm0 = _mm256_extractf128_ps(ymm8, 1); + _mm_storel_pi((__m64 *)(b11 + 2), xmm0); + + m_rem -=3; + i+=3; + } + if(m_rem == 2) + { + + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; + b10 = B + i; + b11 = B + i + j*cs_b; + + k_iter = j; + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + 
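+                // BLIS_CTRSM_SMALL_GEMM_1nx4m accumulates a01 * b10 (the contribution of the j columns of B
+                // already solved) into ymm8; that sum is subtracted from alpha*b11 before the diagonal solve.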
BLIS_CTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) + ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); + ymm16 = _mm256_permute_ps(ymm16, 0x44); + + ymm0 = _mm256_loadu_ps((float const *)(b11)); + + ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5); + ymm19 = _mm256_mul_ps(ymm19, ymm17); + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); + ymm8 = _mm256_sub_ps(ymm19, ymm8); + + + ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm8) +#else + BLIS_CTRSM_MUL(ymm8) +#endif + xmm0 = _mm256_extractf128_ps(ymm8, 0); + _mm_storeu_ps((float *)(b11), xmm0); + + m_rem -=2; + i+=2; + } + if(m_rem == 1) + { + + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; + b10 = B + i; + b11 = B + i + j*cs_b; + + k_iter = j; + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_CTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) + ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); + ymm16 = _mm256_permute_ps(ymm16, 0x44); + + ymm0 = _mm256_loadu_ps((float const *)(b11)); + + ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5); + ymm19 = _mm256_mul_ps(ymm19, ymm17); + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); + ymm8 = _mm256_sub_ps(ymm19, ymm8); + + ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm8) +#else + BLIS_CTRSM_MUL(ymm8) +#endif + xmm0 = _mm256_extractf128_ps(ymm8, 0); + _mm_storel_pi((__m64 *)(b11), xmm0); + + m_rem -=1; + i+=1; + } + j += 1; + n_rem -= 1; + } + + if ((required_packing_A == 1) && + bli_mem_is_alloc( &local_mem_buf_A_s )) + { + bli_membrk_release(&rntm, &local_mem_buf_A_s); + } + + + return BLIS_SUCCESS; } -#endif //BLIS_ENABLE_SMALL_MATRIX_TRSM \ No newline at end of file + +#endif //BLIS_ENABLE_SMALL_MATRIX_TRSM From d6fcfe734517a1a53fb0fa38d9a650841c9f09b0 Mon Sep 17 00:00:00 2001 From: Madan mohan Manokar Date: Fri, 17 Sep 2021 15:32:47 +0530 Subject: [PATCH 035/243] gemmt SUP limitThread count for small sizes 1. Max thread cap added for small dimension based on product(n*k). AMD-Internal: [CPUPL-1388] Change-Id: I34412a1374bb58a9c4b3fd8e40949a69006cf057 --- frame/3/bli_l3_sup.c | 9 ++++++++- frame/base/bli_rntm.c | 16 ++++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/frame/3/bli_l3_sup.c b/frame/3/bli_l3_sup.c index 163a828f86..a7d7a7874a 100644 --- a/frame/3/bli_l3_sup.c +++ b/frame/3/bli_l3_sup.c @@ -158,7 +158,6 @@ printf( "dims: %d %d %d (threshs: %d %d %d)\n", AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2); } - err_t bli_gemmtsup ( obj_t* alpha, @@ -243,6 +242,14 @@ err_t bli_gemmtsup if ( rntm == NULL ) { bli_rntm_init_from_global( &rntm_l ); rntm = &rntm_l; } else { rntm_l = *rntm; rntm = &rntm_l; } +#ifdef AOCL_DYNAMIC + // If dynamic-threading is enabled, calculate optimum number + // of threads and update in rntm + + // Limit the number of thread for smaller sizes. 
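+  // For double-precision GEMMT, bli_nthreads_optimum() drops the thread count to 1 when
+  // (n*k)>>4 is at or below 346 (see the bli_rntm.c change below in this patch).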
+ bli_nthreads_optimum( a, b, c, BLIS_GEMMT, rntm ); +#endif + #if 0 const num_t dt = bli_obj_dt( c ); const dim_t m = bli_obj_length( c ); diff --git a/frame/base/bli_rntm.c b/frame/base/bli_rntm.c index ba878ac6df..6a100bbe8e 100644 --- a/frame/base/bli_rntm.c +++ b/frame/base/bli_rntm.c @@ -605,6 +605,22 @@ void bli_nthreads_optimum( if(m<=512 && n<=512) n_threads_ideal = 4; } + else if( family == BLIS_GEMMT && bli_obj_is_double(c) ) + { + dim_t n = bli_obj_length(c); + dim_t k = bli_obj_width_after_trans(a); + dim_t product = (n*k)>>4; /* product is derived based on n and k */ + // Limit the number thread for smaller sizes: + if(product <= 346) + { + n_threads_ideal = 1; + } + /* finer threshold needs to set for max_thread cap of 2,3,4,5,6..32 */ + else + { + n_threads_ideal = n_threads; + } + } dim_t n_threads_opt = bli_min(n_threads, n_threads_ideal); From cbd9ea76affde0bb362c0a1fa7f8b5ca27c72d00 Mon Sep 17 00:00:00 2001 From: Nageshwar Singh Date: Mon, 20 Sep 2021 14:33:08 +0530 Subject: [PATCH 036/243] Complex single standalone gemv implementation independent of axpyf. Details - For axpyf implementation there are function(axpyf) calling overhead. - New implementations reduces function calling overhead. - This implementation uses kernel of size 8x4. - This implementation gives better performance for smaller sizes when compared to axpyf based implementation AMD-Internal: [CPUPL-1402] Change-Id: Ic9a5e59363290caf26284548638da9065952fd48 --- frame/2/gemv/bli_gemv_unf_var2.c | 75 +++++++--- kernels/zen/2/bli_gemv_zen_int_4.c | 221 ++++++++++++++++++++++++++++- kernels/zen/bli_kernels_zen.h | 1 + 3 files changed, 273 insertions(+), 24 deletions(-) diff --git a/frame/2/gemv/bli_gemv_unf_var2.c b/frame/2/gemv/bli_gemv_unf_var2.c index 34c11f758b..ffebf17bac 100644 --- a/frame/2/gemv/bli_gemv_unf_var2.c +++ b/frame/2/gemv/bli_gemv_unf_var2.c @@ -498,14 +498,14 @@ void bli_cgemv_unf_var2 /* If beta is zero, use setv. Otherwise, scale by beta. */ /* y = beta * y; */ /* beta=0 case is hadled by scalv internally */ - bli_cscalv_ex + bli_cscalv_zen_int10 ( BLIS_NO_CONJUGATE, n_elem, beta, - y, incy, - cntx, - NULL + y, + incy, + cntx ); if( bli_ceq0( *alpha ) ) @@ -513,30 +513,59 @@ void bli_cgemv_unf_var2 AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3) return; } - /* fusing factor. */ - b_fuse = 4; - for ( i = 0; i < n_iter; i += f ) + // for non-unit incx, incy and rs_at and conjugate will be added in the next patch + if( ( (incx == 1) && (incy == 1) && (rs_at == 1) ) && + !bli_is_conj(conja) && !bli_is_conj(conjx) && + !bli_is_trans(transa)) { - f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); - A1 = a + (0 )*rs_at + (i )*cs_at; - x1 = x + (i )*incx; - y1 = y + (0 )*incy; - - /* y = y + alpha * A1 * x1; */ - bli_caxpyf_zen_int_4 + // This gemv code deals with the followint conditions only + // 1. incx, incy, and row stride equal to one + // 2. Non conjugate A matrix and X vector + // 3. No Transpose for A Martix + // Rest is taken care by the else part (axpyf implementation) + bli_cgemv_zen_int_4x4 ( - conja, - conjx, - n_elem, - f, - alpha, - A1, rs_at, cs_at, - x1, incx, - y1, incy, - NULL + conja, + conjx, + m, + n, + alpha, + a, rs_at, cs_at, + x, incx, + beta, + y, incy, + NULL ); } + else + { + /* fusing factor. 
*/ + b_fuse = 4; + + for ( i = 0; i < n_iter; i += f ) + { + f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); + A1 = a + (0 )*rs_at + (i )*cs_at; + x1 = x + (i )*incx; + y1 = y + (0 )*incy; + + /* y = y + alpha * A1 * x1; */ + bli_caxpyf_zen_int_4 + ( + conja, + conjx, + n_elem, + f, + alpha, + A1, rs_at, cs_at, + x1, incx, + y1, incy, + NULL + ); + } + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); } diff --git a/kernels/zen/2/bli_gemv_zen_int_4.c b/kernels/zen/2/bli_gemv_zen_int_4.c index 95060f57e2..b3c92b551c 100644 --- a/kernels/zen/2/bli_gemv_zen_int_4.c +++ b/kernels/zen/2/bli_gemv_zen_int_4.c @@ -257,4 +257,223 @@ void bli_zgemv_zen_int_4x4 cntx ); } -} \ No newline at end of file +} + +/* + This implementation uses 512 bits of cache line efficiently for + column stored matrix and vectors. + To achieve this, at each iteration we use 2 ymm registers + i.e. .512 bits for arithmetic operation. By this we use the + cache efficiently. +*/ +void bli_cgemv_zen_int_4x4 +( + conj_t conja, + conj_t conjx, + dim_t m, + dim_t n, + scomplex* restrict alpha, + scomplex* restrict a, inc_t inca, inc_t lda, + scomplex* restrict x, inc_t incx, + scomplex* restrict beta, + scomplex* restrict y, inc_t incy, + cntx_t* restrict cntx +) +{ + + const dim_t S_MR = 8; // Kernel size , m = 8 + const dim_t S_NR = 4; // Kernel size , n = 4 + + scomplex chi0; + scomplex chi1; + scomplex chi2; + scomplex chi3; + + inc_t lda2 = 2*lda; + inc_t lda3 = 3*lda; + inc_t incy4 = 4*incy; + inc_t incx2 = 2*incx; + inc_t incx3 = 3*incx; + inc_t inca2 = 4*inca; + + scomplex* x0 = x; + scomplex* y0 = y; + scomplex* a0 = a; + + dim_t i,j; + + __m256 ymm0, ymm1, ymm2, ymm3; + __m256 ymm4, ymm5, ymm6, ymm7; + __m256 ymm8, ymm9, ymm10, ymm11; + __m256 ymm12, ymm13, ymm14, ymm15; + + for( i = 0; i+S_NR-1 < n; i+=S_NR ) + { + a0 = a + (i )*lda; + x0 = x + (i )*incx; + y0 = y;// For each kernel, y should start form beginning + + chi0 = *( x0); + chi1 = *( x0 + incx ); + chi2 = *( x0 + incx2 ); + chi3 = *( x0 + incx3 ); + + bli_cscals( *alpha, chi0 ); + bli_cscals( *alpha, chi1 ); + bli_cscals( *alpha, chi2 ); + bli_cscals( *alpha, chi3 ); + + ymm0 = _mm256_broadcast_ss(&chi0.real); // real part of x0 + ymm1 = _mm256_broadcast_ss(&chi0.imag); // imag part of x0 + ymm2 = _mm256_broadcast_ss(&chi1.real); // real part of x1 + ymm3 = _mm256_broadcast_ss(&chi1.imag); // imag part of x1 + ymm4 = _mm256_broadcast_ss(&chi2.real); // real part of x2 + ymm5 = _mm256_broadcast_ss(&chi2.imag); // imag part of x2 + ymm6 = _mm256_broadcast_ss(&chi3.real); // real part of x3 + ymm7 = _mm256_broadcast_ss(&chi3.imag); // imag part of x3 + + for( j = 0 ; j+S_MR-1 < m ; j+=S_MR ) + { + //load columns of A, each ymm reg had 4 elements + ymm8 = _mm256_loadu_ps((float const *)(a0)); + ymm9 = _mm256_loadu_ps((float const *)(a0 + lda)); + ymm10 = _mm256_loadu_ps((float const *)(a0 + lda2)); + ymm11 = _mm256_loadu_ps((float const *)(a0 + lda3)); + + //-------------------- + //Ar*Xr Ai*Xr Ar*Xr Ai*Xr Ar*Xr Ai*Xr Ar*Xr Ai*Xr + ymm14 = _mm256_mul_ps(ymm8, ymm0); + //Ar*Xi Ai*Xi Ar*Xi Ai*Xi Ar*Xi Ai*Xi Ar*Xi Ai*Xi + ymm15 = _mm256_mul_ps(ymm8, ymm1); + + /* Next set of A mult by real and imag, + Add into the previous real and imag results */ + // (Ar*Xr Ai*Xr Ar*Xr Ai*Xr Ar*Xr Ai*Xr Ar*Xr Ai*Xr) + // + (prev iteration real results) + ymm14 = _mm256_fmadd_ps(ymm9, ymm2, ymm14); + // (Ar*Xi Ai*Xi Ar*Xi Ai*Xi Ar*Xi Ai*Xi Ar*Xi Ai*Xi) + // + (prev iteration imag results) + ymm15 = _mm256_fmadd_ps(ymm9, ymm3, ymm15); + // (Ar*Xr Ai*Xr Ar*Xr Ai*Xr Ar*Xr 
Ai*Xr Ar*Xr Ai*Xr) + // + (prev iteration real results) + ymm14 = _mm256_fmadd_ps(ymm10, ymm4, ymm14); + // (Ar*Xi Ai*Xi Ar*Xi Ai*Xi Ar*Xi Ai*Xi Ar*Xi Ai*Xi) + // + (prev iteration imag results) + ymm15 = _mm256_fmadd_ps(ymm10, ymm5, ymm15); + // (Ar*Xr Ai*Xr Ar*Xr Ai*Xr Ar*Xr Ai*Xr Ar*Xr Ai*Xr) + // + (prev iteration real results) + ymm14 = _mm256_fmadd_ps(ymm11, ymm6, ymm14); + // (Ar*Xi Ai*Xi Ar*Xi Ai*Xi Ar*Xi Ai*Xi Ar*Xi Ai*Xi) + // + (prev iteration imag results) + ymm15 = _mm256_fmadd_ps(ymm11, ymm7, ymm15); + /*Permute the imag acc register to addsub to real accu results */ + // (Ar*Xi Ai*Xi Ar*Xi Ai*Xi Ar*Xi Ai*Xi Ar*Xi Ai*Xi) + // => (Ai*Xi Ar*Xi Ai*Xi Ar*Xi Ai*Xi Ar*Xi Ai*Xi Ar*Xi) + ymm15 = _mm256_permute_ps(ymm15, 0xB1); + /*AddSub to get the 2 proper complex multipled value*/ + /* Ar*Xi - Ai*Xi, Ai*Xi + Ar*Xi, Ar*Xi - Ai*Xi, Ai*Xi + Ar*Xi, + Ar*Xi - Ai*Xi, Ai*Xi + Ar*Xi, Ar*Xi - Ai*Xi, Ai*Xi + Ar*Xi*/ + ymm12 = _mm256_addsub_ps(ymm14, ymm15); + //load Y vector + ymm14 = _mm256_loadu_ps((float*)y0); + //Add the results into y + ymm12 = _mm256_add_ps(ymm14, ymm12); + // Store the results back + _mm256_storeu_ps((float*)(y0), ymm12); + +//----------------------- + + // Load Next Set of A matrix elements for the same col + // Ar2 Ai2 Ar3 Ai3 + ymm8 = _mm256_loadu_ps((float const *)(a0 + (inca2))); + ymm9 = _mm256_loadu_ps((float const *)(a0 + (inca2) + lda)); + ymm10 = _mm256_loadu_ps((float const *)(a0 + (inca2) + lda2)); + ymm11 = _mm256_loadu_ps((float const *)(a0 + (inca2) + lda3)); + + //Ar0*Xr Ai0*Xr Ar1*Xr Ai1*Xr + ymm14 = _mm256_mul_ps(ymm8, ymm0); + //Ar0*Xi Ai0*Xi Ar1*Xi Ai1*Xi + ymm15 = _mm256_mul_ps(ymm8, ymm1); + + /* Next set of A mult by real and imag, + Add into the previous real and imag results */ + + // (Ar*Xr Ai*Xr Ar*Xr Ai*Xr) + (prev iteration real results) + ymm14 = _mm256_fmadd_ps(ymm9, ymm2, ymm14); + // (Ar*Xi Ai*Xi Ar*Xi Ai*Xi) + + (prev iteration imag results) + ymm15 = _mm256_fmadd_ps(ymm9, ymm3, ymm15); + + // (Ar*Xr Ai*Xr Ar*Xr Ai*Xr) + (prev iteration real results) + ymm14 = _mm256_fmadd_ps(ymm10, ymm4, ymm14); + // (Ar*Xi Ai*Xi Ar*Xi Ai*Xi) + + (prev iteration imag results) + ymm15 = _mm256_fmadd_ps(ymm10, ymm5, ymm15); + + // (Ar*Xr Ai*Xr Ar*Xr Ai*Xr) + (prev iteration real results) + ymm14 = _mm256_fmadd_ps(ymm11, ymm6, ymm14); + // (Ar*Xi Ai*Xi Ar*Xi Ai*Xi) + + (prev iteration imag results) + ymm15 = _mm256_fmadd_ps(ymm11, ymm7, ymm15); + + /*Permute the imag acc register to addsub to real accu results */ + // (Ar*Xi Ai*Xi Ar*Xi Ai*Xi) => (Ai*Xi Ar*Xi Ai*Xi Ar*Xi) + ymm15 = _mm256_permute_ps(ymm15, 0xB1); + /*AddSub to get the 2 proper complex multipled value*/ + /* Ar*Xi - Ai*Xi, Ai*Xi + Ar*Xi, Ar*Xi - Ai*Xi, Ai*Xi + Ar*Xi*/ + ymm13 = _mm256_addsub_ps(ymm14, ymm15); + + // load Y vector + ymm14 = _mm256_loadu_ps((float *)(y0 + (incy4))); + // Add the results into y + ymm13 = _mm256_add_ps(ymm14, ymm13); + // Store the results back + _mm256_storeu_ps((float*)(y0 + (incy4)), ymm13); + + y0 += S_MR*incy ; // Next Set of y0 vector + a0 += S_MR*inca ; // Next Set of a0 matrix elements in the same col + } + + // For resisual m + for( ; j < m ; ++j ) + { + scomplex y0c = *(scomplex*)y0; + const scomplex a0c = *a0; + const scomplex a1c = *(a0 + lda); + const scomplex a2c = *(a0 + lda2); + const scomplex a3c = *(a0 + lda3); + + y0c.real += chi0.real * a0c.real - chi0.imag * a0c.imag; + y0c.real += chi1.real * a1c.real - chi1.imag * a1c.imag; + y0c.real += chi2.real * a2c.real - chi2.imag * a2c.imag; + y0c.real += chi3.real * a3c.real - chi3.imag * 
a3c.imag; + + y0c.imag += chi0.imag * a0c.real + chi0.real * a0c.imag; + y0c.imag += chi1.imag * a1c.real + chi1.real * a1c.imag; + y0c.imag += chi2.imag * a2c.real + chi2.real * a2c.imag; + y0c.imag += chi3.imag * a3c.real + chi3.real * a3c.imag; + + *(scomplex*)y0 = y0c; + a0 += 1; + y0 += 1; + } + } + + // For resisual n, axpyv is used + for ( ; i < n; ++i ) + { + scomplex* a1 = a + (i )*lda; + scomplex* chi1 = x + (i )*incx; + scomplex* y1 = y; + scomplex alpha_chi1; + bli_ccopycjs( conjx, *chi1, alpha_chi1 ); + bli_cscals( *alpha, alpha_chi1 ); + bli_caxpyv_zen_int5 + ( + conja, + m, + &alpha_chi1, + a1, inca, + y1, incy, + cntx + ); + } +} + diff --git a/kernels/zen/bli_kernels_zen.h b/kernels/zen/bli_kernels_zen.h index b39ccec577..8845996962 100644 --- a/kernels/zen/bli_kernels_zen.h +++ b/kernels/zen/bli_kernels_zen.h @@ -117,6 +117,7 @@ DOTXF_KER_PROT( double, d, dotxf_zen_int_8 ) //gemv(scalar code) GEMV_KER_PROT( double, d, gemv_zen_ref_c ) +GEMV_KER_PROT( scomplex, c, gemv_zen_int_4x4 ) GEMV_KER_PROT( dcomplex, z, gemv_zen_int_4x4 ) // -- level-3 sup -------------------------------------------------------------- From 2b6faf21a19569c03d888f2fd305d80b52f57e54 Mon Sep 17 00:00:00 2001 From: Harsh Dave Date: Mon, 4 Oct 2021 02:56:47 -0500 Subject: [PATCH 037/243] Fixed conjugate transpose issue for zscalv and cscalv Details: AMD Internal Id: CPUPL-1702 - While performing trsm function A's imaginary part needed to be complimented as per conjugate transpose. -So in the case of conjugate transpose A's imaginary part is negated before doing trsm. Change-Id: Ic736733a483eeadf6356952b434128c0af988e36 --- frame/compat/bla_trsm.c | 183 +++++++++++++++++++++++++++++++-- kernels/zen/3/bli_trsm_small.c | 13 ++- 2 files changed, 188 insertions(+), 8 deletions(-) diff --git a/frame/compat/bla_trsm.c b/frame/compat/bla_trsm.c index b29219fe56..cd0a2b8066 100644 --- a/frame/compat/bla_trsm.c +++ b/frame/compat/bla_trsm.c @@ -920,7 +920,7 @@ void ztrsm_ trans_t blis_transa; diag_t blis_diaga; dim_t m0, n0; - conj_t conja = BLIS_NO_CONJUGATE ; + conj_t conja = BLIS_NO_CONJUGATE; /* Initialize BLIS. 
*/ bli_init_auto(); @@ -997,7 +997,50 @@ void ztrsm_ } else if( ( blis_side == BLIS_RIGHT ) && ( m0 != 1 ) ) { - ; + bli_zscalv_ex + ( + conja, + m0, + (dcomplex*)alpha, + (dcomplex*)b, rs_b, + NULL, + NULL + ); + if(blis_diaga == BLIS_NONUNIT_DIAG) + { + dcomplex inva = {1.0, 0.0}; + dcomplex a_dup; + if(*transa == 'C' && *diaga == 'N') + { + a_dup.real = a->real; + a_dup.imag = a->imag * -1.0; + } + else + { + a_dup.real = a->real; + a_dup.imag = a->imag; + } + +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + bli_zinvscals(a_dup, inva); +#else + inva.real = a_dup.real; + inva.imag = a_dup.imag; +#endif + for(int indx = 0; indx < m0; indx ++) + { +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + bli_zscals(inva, b[indx]) +#else + + bli_zinvscals(inva, b[indx]) +#endif + } + + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; } } else if( m0 == 1 ) @@ -1049,7 +1092,50 @@ void ztrsm_ } else if(( blis_side == BLIS_LEFT ) && ( n0 != 1 )) { - ; + bli_zscalv_ex + ( + conja, + n0, + (dcomplex*)alpha, + (dcomplex*)b, cs_b, + NULL, + NULL + ); + if(blis_diaga == BLIS_NONUNIT_DIAG) + { + dcomplex inva = {1.0, 0.0}; + dcomplex a_dup; + if(*transa == 'C' && *diaga == 'N') + { + a_dup.real = a->real; + a_dup.imag = a->imag * -1.0; + } + else + { + a_dup.real = a->real; + a_dup.imag = a->imag; + } + +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + bli_zinvscals(a_dup, inva); +#else + inva.real = a_dup.real; + inva.imag = a_dup.imag; +#endif + for(int indx = 0; indx < n0; indx ++) + { +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + bli_zscals(inva ,b[indx * cs_b]) +#else + + bli_zinvscals(inva ,b[indx * cs_b]) +#endif + } + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } } @@ -1141,7 +1227,7 @@ void ctrsm_ trans_t blis_transa; diag_t blis_diaga; dim_t m0, n0; - conj_t conja = BLIS_NO_CONJUGATE ; + conj_t conja = BLIS_NO_CONJUGATE; /* Initialize BLIS. 
*/ bli_init_auto(); @@ -1218,7 +1304,50 @@ void ctrsm_ } else if( ( blis_side == BLIS_RIGHT ) && ( m0 != 1 ) ) { - ; + bli_cscalv_ex + ( + conja, + m0, + (scomplex*)alpha, + (scomplex*)b, rs_b, + NULL, + NULL + ); + if(blis_diaga == BLIS_NONUNIT_DIAG) + { + scomplex inva = {1.0, 0.0}; + scomplex a_dup; + if(*transa == 'C' && *diaga == 'N') + { + a_dup.real = a->real; + a_dup.imag = a->imag * -1.0; + } + else + { + a_dup.real = a->real; + a_dup.imag = a->imag; + } + +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + bli_cinvscals(a_dup, inva); +#else + inva.real = a_dup.real; + inva.imag = a_dup.imag; +#endif + + for(int indx = 0; indx < m0; indx ++) + { +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + bli_cscals(inva ,b[indx]) +#else + bli_cinvscals(inva, b[indx]) +#endif + } + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } } else if( m0 == 1 ) @@ -1270,7 +1399,49 @@ void ctrsm_ } else if(( blis_side == BLIS_LEFT ) && ( n0 != 1 )) { - ; + bli_cscalv_ex + ( + conja, + n0, + (scomplex*)alpha, + (scomplex*)b, cs_b, + NULL, + NULL + ); + if(blis_diaga == BLIS_NONUNIT_DIAG) + { + scomplex inva = {1.0, 0.0}; + scomplex a_dup; + if(*transa == 'C' && *diaga == 'N') + { + a_dup.real = a->real; + a_dup.imag = a->imag * -1.0; + } + else + { + a_dup.real = a->real; + a_dup.imag = a->imag; + } + +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + bli_cinvscals(a_dup, inva) +#else + inva.real = a_dup.real; + inva.imag = a_dup.imag; +#endif + for(int indx = 0; indx < n0; indx ++) + { +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + bli_cscals(inva ,b[indx * cs_b]) +#else + bli_cinvscals(inva, b[indx * cs_b]) +#endif + + } + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; } } diff --git a/kernels/zen/3/bli_trsm_small.c b/kernels/zen/3/bli_trsm_small.c index 576bec0abc..c782a08a49 100644 --- a/kernels/zen/3/bli_trsm_small.c +++ b/kernels/zen/3/bli_trsm_small.c @@ -5823,11 +5823,15 @@ BLIS_INLINE void ztrsm_small_pack_diag_element dim_t size ) { +#ifdef BLIS_ENABLE_TRSM_PREINVERSION __m256d ymm1, ymm2, ymm3, ymm4, ymm5, ymm6, ymm7, ymm8; + ymm7 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); +#else + __m256d ymm1, ymm2, ymm3; +#endif bool is_four = (size == 4) ? 1 : 0; dcomplex ones = {1.0, 1.0}; ymm2 = ymm1 = _mm256_broadcast_pd((__m128d const *)&ones); - ymm7 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); if(!is_unitdiag) { //broadcast diagonal elements of A11 @@ -36977,11 +36981,15 @@ BLIS_INLINE void ctrsm_small_pack_diag_element dim_t size ) { - __m256 ymm1, ymm2, ymm3, ymm4, ymm5, ymm6, ymm7, ymm8; + __m256 ymm1, ymm2, ymm3, ymm4, ymm5, ymm6, ymm8; bool is_eight = (size == 8) ? 
1 : 0; scomplex ones = {1.0, 1.0}; ymm2 = ymm1 = _mm256_broadcast_ps((__m128 const *)&ones); +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + __m256 ymm7; ymm7 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); +#endif + if(!is_unitdiag) { //broadcast diagonal elements of A11 @@ -37217,6 +37225,7 @@ BLIS_INLINE void ctrsm_small_pack_diag_element ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal));\ ymm16 = _mm256_permute_ps(ymm16, 0x44);\ \ + xmm4 = _mm_setzero_ps();\ xmm0 = _mm_loadu_ps((float const *)(b11));\ xmm1 = _mm_loadl_pi(xmm4, (__m64 const *)(b11 + 2));\ xmm2 = _mm_loadu_ps((float const *)(b11 + cs_b));\ From 53e1d0539f3b5741ca3311fe9a11e0895acab652 Mon Sep 17 00:00:00 2001 From: Harsh Dave Date: Wed, 6 Oct 2021 10:07:47 -0500 Subject: [PATCH 038/243] Fixed conjugate transpose kernel issue Details: AMD Internal Id: CPUPL-1702 - For the cases of A being of 1x1 dimension and of left and right hand side, A's only element is conjugate transposed by negating its imaginary component. Change-Id: I696ae982d9d60e0e702edaba98acbe9a5b0cd44c --- frame/compat/bla_trsm.c | 44 +++++++++++++++++++++++++++++++---------- 1 file changed, 34 insertions(+), 10 deletions(-) diff --git a/frame/compat/bla_trsm.c b/frame/compat/bla_trsm.c index cd0a2b8066..46f79f02f8 100644 --- a/frame/compat/bla_trsm.c +++ b/frame/compat/bla_trsm.c @@ -190,8 +190,8 @@ void PASTEF77(ch,blasname) \ /* ----------------------------------------------------------- */ \ /* TRSM API: AX = B, where X = B */ \ /* CALL TRSV when X & B are vector and when A is Matrix */ \ - /* Case 1: LEFT : TRSM, C(mxn) = A(mxm) * B(mxn) */ \ - /* Case 2: RIGHT : TRSM, C(mxn) = B(mxn) * A(nxn) */ \ + /* Case 1: LEFT : TRSM, B(mxn) = A(mxm) * X(mxn) */ \ + /* Case 2: RIGHT : TRSM, B(mxn) = X(mxn) * A(nxn) */ \ /* |--------|-------|-------|-------|------------------------| */ \ /* | | A | X | B | Implementation | */ \ /* |--------|-------|-------|-------|------------------------| */ \ @@ -494,7 +494,7 @@ void strsm_ if(blis_diaga == BLIS_NONUNIT_DIAG) { float inva = 1.0/ *a; - for(int indx = 0; indx < m0; indx ++) + for(dim_t indx = 0; indx < m0; indx ++) { b[indx] = ( inva * b[indx] ); } @@ -565,7 +565,7 @@ void strsm_ if(blis_diaga == BLIS_NONUNIT_DIAG) { float inva = 1.0/ *a; - for(int indx = 0; indx < n0; indx ++) + for(dim_t indx = 0; indx < n0; indx ++) { b[indx*cs_b] = (inva * b[indx*cs_b] ); } @@ -751,7 +751,7 @@ void dtrsm_ if(blis_diaga == BLIS_NONUNIT_DIAG) { double inva = 1.0/ *a; - for(int indx = 0; indx < m0; indx ++) + for(dim_t indx = 0; indx < m0; indx ++) { b[indx] = ( inva * b[indx] ); } @@ -822,7 +822,7 @@ void dtrsm_ if(blis_diaga == BLIS_NONUNIT_DIAG) { double inva = 1.0/ *a; - for(int indx = 0; indx < n0; indx ++) + for(dim_t indx = 0; indx < n0; indx ++) { b[indx*cs_b] = (inva * b[indx*cs_b] ); } @@ -1010,6 +1010,12 @@ void ztrsm_ { dcomplex inva = {1.0, 0.0}; dcomplex a_dup; + /** + * For conjugate transpose and non-unit diagonal + * kernel, negating imaginary part of A. + * As the dimension of A is 1x1, there's going to + * be only one 1 element of A. 
+ */ if(*transa == 'C' && *diaga == 'N') { a_dup.real = a->real; @@ -1027,7 +1033,7 @@ void ztrsm_ inva.real = a_dup.real; inva.imag = a_dup.imag; #endif - for(int indx = 0; indx < m0; indx ++) + for(dim_t indx = 0; indx < m0; indx ++) { #ifdef BLIS_ENABLE_TRSM_PREINVERSION bli_zscals(inva, b[indx]) @@ -1105,6 +1111,12 @@ void ztrsm_ { dcomplex inva = {1.0, 0.0}; dcomplex a_dup; + /** + * For conjugate transpose and non-unit diagonal + * kernel, negating imaginary part of A. + * As the dimension of A is 1x1, there's going to + * be only one 1 element of A. + */ if(*transa == 'C' && *diaga == 'N') { a_dup.real = a->real; @@ -1122,7 +1134,7 @@ void ztrsm_ inva.real = a_dup.real; inva.imag = a_dup.imag; #endif - for(int indx = 0; indx < n0; indx ++) + for(dim_t indx = 0; indx < n0; indx ++) { #ifdef BLIS_ENABLE_TRSM_PREINVERSION bli_zscals(inva ,b[indx * cs_b]) @@ -1317,6 +1329,12 @@ void ctrsm_ { scomplex inva = {1.0, 0.0}; scomplex a_dup; + /** + * For conjugate transpose and non-unit diagonal + * kernel, negating imaginary part of A. + * As the dimension of A is 1x1, there's going to + * be only one 1 element of A. + */ if(*transa == 'C' && *diaga == 'N') { a_dup.real = a->real; @@ -1335,7 +1353,7 @@ void ctrsm_ inva.imag = a_dup.imag; #endif - for(int indx = 0; indx < m0; indx ++) + for(dim_t indx = 0; indx < m0; indx ++) { #ifdef BLIS_ENABLE_TRSM_PREINVERSION bli_cscals(inva ,b[indx]) @@ -1412,6 +1430,12 @@ void ctrsm_ { scomplex inva = {1.0, 0.0}; scomplex a_dup; + /** + * For conjugate transpose and non-unit diagonal + * kernel, negating imaginary part of A. + * As the dimension of A is 1x1, there's going to + * be only one 1 element of A. + */ if(*transa == 'C' && *diaga == 'N') { a_dup.real = a->real; @@ -1429,7 +1453,7 @@ void ctrsm_ inva.real = a_dup.real; inva.imag = a_dup.imag; #endif - for(int indx = 0; indx < n0; indx ++) + for(dim_t indx = 0; indx < n0; indx ++) { #ifdef BLIS_ENABLE_TRSM_PREINVERSION bli_cscals(inva ,b[indx * cs_b]) From ae844f475cce0b7608b6c007717dbf5129d73f7b Mon Sep 17 00:00:00 2001 From: Dipal M Zambare Date: Thu, 7 Oct 2021 11:35:51 +0530 Subject: [PATCH 039/243] Fixed build issue in DTL when only traces are enabled. AMD-Internal: [CPUPL-1691] Change-Id: Idc273666054529db5a2fb96a7d7ebbf7a3f5b008 --- aocl_dtl/aocldtl.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aocl_dtl/aocldtl.c b/aocl_dtl/aocldtl.c index 148e99d888..6f24788aa0 100644 --- a/aocl_dtl/aocldtl.c +++ b/aocl_dtl/aocldtl.c @@ -438,7 +438,7 @@ void DTL_DumpData( } /* DTL_DumpData */ #endif -#if (AOCL_DTL_TRACE_ENABLE || AOCL_DTL_LOG_ENABLE) +#if (AOCL_DTL_LOG_ENABLE) void AOCL_DTL_start_perf_timer(void) { AOCL_TID current_thread = AOCL_gettid(); From 3364c0e4ebedc1320be9331acc6d343e6fdf6b87 Mon Sep 17 00:00:00 2001 From: Dipal M Zambare Date: Fri, 8 Oct 2021 10:35:35 +0530 Subject: [PATCH 040/243] Binary and dynamic dispatch configuration name change -- Reverted changes made to include lp/ilp info in binary name This reverts commit c5e6f885f00e77e1c67637e6c176c97b679141ae. 
-- Included BLAS int size in 'make showconfig' -- Renamed amdepyc configuration to amdzen Change-Id: Ie87ec1c03e105f606aef1eac397ba0d8338906a6 --- CMakeLists.txt | 2 +- Makefile | 2 ++ common.mk | 12 ++---------- .../bli_family_amdzen.h} | 7 +++---- config/{amdepyc => amdzen}/make_defs.mk | 3 +-- config_registry | 2 +- frame/base/bli_arch.c | 2 +- frame/include/bli_arch_config.h | 4 ++-- 8 files changed, 13 insertions(+), 21 deletions(-) rename config/{amdepyc/bli_family_amdepyc.h => amdzen/bli_family_amdzen.h} (94%) rename config/{amdepyc => amdzen}/make_defs.mk (96%) diff --git a/CMakeLists.txt b/CMakeLists.txt index 3fa559abc6..2e13ef3ab2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -52,7 +52,7 @@ elseif (${AOCL_BLIS_FAMILY} STREQUAL "zen3") add_definitions(-DBLIS_KERNELS_HASWELL) elseif (${AOCL_BLIS_FAMILY} STREQUAL "amd64") set(AOCL_BLIS_ZEN FALSE) - add_definitions(-DBLIS_FAMILY_AMDEPYC) + add_definitions(-DBLIS_FAMILY_AMDZEN) add_definitions(-DBLIS_CONFIG_ZEN3) add_definitions(-DBLIS_CONFIG_ZEN2) add_definitions(-DBLIS_CONFIG_ZEN) diff --git a/Makefile b/Makefile index 38cc8144ab..b248d5781a 100644 --- a/Makefile +++ b/Makefile @@ -5,6 +5,7 @@ # libraries. # # Copyright (C) 2014, The University of Texas at Austin +# Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -1105,6 +1106,7 @@ showconfig: check-env @echo "complex return scheme: $(MK_COMPLEX_RETURN_SCHEME)" @echo "enable trsm preinversion: $(MK_ENABLE_TRSM_PREINVERSION)" @echo "enable AOCL dynamic threads: $(MK_ENABLE_AOCL_DYNAMIC)" + @echo "BLAS Integer size(LP/ILP): $(MK_BLAS_INT_TYPE_SIZE)" diff --git a/common.mk b/common.mk index 26e8627adb..00b1d4354e 100644 --- a/common.mk +++ b/common.mk @@ -411,17 +411,9 @@ BASE_LIB_PATH := $(LIB_PATH) # The base name of the BLIS library that we will build. ifeq ($(THREADING_MODEL),off) -ifeq ($(MK_BLAS_INT_TYPE_SIZE), 64) -LIBBLIS := libblis-ilp64 +LIBBLIS := libblis else -LIBBLIS := libblis-lp64 -endif -else -ifeq ($(MK_BLAS_INT_TYPE_SIZE), 64) -LIBBLIS := libblis-mt-ilp64 -else -LIBBLIS := libblis-mt-lp64 -endif +LIBBLIS := libblis-mt endif # The shared (dynamic) library file suffix is different for Linux and OS X. diff --git a/config/amdepyc/bli_family_amdepyc.h b/config/amdzen/bli_family_amdzen.h similarity index 94% rename from config/amdepyc/bli_family_amdepyc.h rename to config/amdzen/bli_family_amdzen.h index 5ae4460442..c73409673d 100644 --- a/config/amdepyc/bli_family_amdepyc.h +++ b/config/amdzen/bli_family_amdzen.h @@ -4,7 +4,6 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2014, The University of Texas at Austin Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without @@ -33,8 +32,8 @@ */ -#ifndef BLIS_FAMILY_AMD64_H -#define BLIS_FAMILY_AMD64_H +#ifndef BLIS_FAMILY_AMDZEN_H +#define BLIS_FAMILY_AMDZEN_H // By default, it is effective to parallelize the outer loops. // Setting these macros to 1 will force JR and IR inner loops @@ -60,7 +59,7 @@ // BLIS), defining this macro as 1 yields better performance. 
#define AOCL_BLIS_MULTIINSTANCE 0 -#define BLIS_ENABLE_FAST_MATH +//#define BLIS_ENABLE_FAST_MATH #endif diff --git a/config/amdepyc/make_defs.mk b/config/amdzen/make_defs.mk similarity index 96% rename from config/amdepyc/make_defs.mk rename to config/amdzen/make_defs.mk index d7e1b73226..7697e9ff05 100644 --- a/config/amdepyc/make_defs.mk +++ b/config/amdzen/make_defs.mk @@ -4,7 +4,6 @@ # An object-based framework for developing high-performance BLAS-like # libraries. # -# Copyright (C) 2014, The University of Texas at Austin # Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without @@ -36,7 +35,7 @@ # Declare the name of the current configuration and add it to the # running list of configurations included by common.mk. -THIS_CONFIG := amdepyc +THIS_CONFIG := amdzen # For architecture independent files we still need to define # the required flags diff --git a/config_registry b/config_registry index 97dbcf5ae5..558eccc30c 100644 --- a/config_registry +++ b/config_registry @@ -11,7 +11,7 @@ x86_64: intel64 amd64 amd64_legacy intel64: skx knl haswell sandybridge penryn generic amd64_legacy: excavator steamroller piledriver bulldozer generic -amdepyc: zen3 zen2 zen generic +amdzen: zen3 zen2 zen generic # NOTE: ARM families will remain disabled until runtime hardware detection # logic is added to BLIS. diff --git a/frame/base/bli_arch.c b/frame/base/bli_arch.c index 3df2c3688b..153787d3ed 100644 --- a/frame/base/bli_arch.c +++ b/frame/base/bli_arch.c @@ -125,7 +125,7 @@ void bli_arch_set_id( void ) // Architecture families. #if defined BLIS_FAMILY_INTEL64 || \ - defined BLIS_FAMILY_AMDEPYC || \ + defined BLIS_FAMILY_AMDZEN || \ defined BLIS_FAMILY_AMD64_LEGACY || \ defined BLIS_FAMILY_X86_64 || \ defined BLIS_FAMILY_ARM64 || \ diff --git a/frame/include/bli_arch_config.h b/frame/include/bli_arch_config.h index b341eaee3c..a62128dffe 100644 --- a/frame/include/bli_arch_config.h +++ b/frame/include/bli_arch_config.h @@ -136,8 +136,8 @@ CNTX_INIT_PROTS( generic ) #ifdef BLIS_FAMILY_INTEL64 #include "bli_family_intel64.h" #endif -#ifdef BLIS_FAMILY_AMDEPYC -#include "bli_family_amdepyc.h" +#ifdef BLIS_FAMILY_AMDZEN +#include "bli_family_amdzen.h" #endif #ifdef BLIS_FAMILY_AMD64_LEGACY #include "bli_family_amd64_legacy.h" From 366ab661349ad71d3adc423a2c3ac6e860413758 Mon Sep 17 00:00:00 2001 From: mkurumel Date: Fri, 8 Oct 2021 17:31:48 +0530 Subject: [PATCH 041/243] DNRM2 : Disable dnrm2 Fast math implementation. Details : - Accuracy failures observed when fast math and ILP64 are enabled. - Disabling the feature with macro BLIS_ENABLE_FAST_MATH . 
AMD-Internal: [CPUPL-1907] Change-Id: I92c661647fb8cc5f1d0af8f6c4eae0fac1df5f16 --- config/zen/bli_family_zen.h | 2 +- config/zen2/bli_family_zen2.h | 2 +- config/zen3/bli_family_zen3.h | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/config/zen/bli_family_zen.h b/config/zen/bli_family_zen.h index a166125889..23d3d608c7 100644 --- a/config/zen/bli_family_zen.h +++ b/config/zen/bli_family_zen.h @@ -53,6 +53,6 @@ #define BLIS_SMALL_MATRIX_A_THRES_M_SYRK 96 #define BLIS_SMALL_MATRIX_A_THRES_N_SYRK 128 -#define BLIS_ENABLE_FAST_MATH +//#define BLIS_ENABLE_FAST_MATH #endif diff --git a/config/zen2/bli_family_zen2.h b/config/zen2/bli_family_zen2.h index dbaa8f4f73..dbae9752cc 100644 --- a/config/zen2/bli_family_zen2.h +++ b/config/zen2/bli_family_zen2.h @@ -56,6 +56,6 @@ // When running HPL with pure MPI without DGEMM threading (Single-threaded // BLIS), defining this macro as 1 yields better performance. #define AOCL_BLIS_MULTIINSTANCE 0 -#define BLIS_ENABLE_FAST_MATH +//#define BLIS_ENABLE_FAST_MATH #endif diff --git a/config/zen3/bli_family_zen3.h b/config/zen3/bli_family_zen3.h index 78e2c9de97..69def1422d 100644 --- a/config/zen3/bli_family_zen3.h +++ b/config/zen3/bli_family_zen3.h @@ -55,6 +55,6 @@ #define BLIS_SMALL_MATRIX_A_THRES_M_SYRK 96 #define BLIS_SMALL_MATRIX_A_THRES_N_SYRK 128 -#define BLIS_ENABLE_FAST_MATH +//#define BLIS_ENABLE_FAST_MATH #endif From 4af525a31371197b3346431d6fa3eb24baedd495 Mon Sep 17 00:00:00 2001 From: nphaniku Date: Mon, 11 Oct 2021 17:02:04 +0530 Subject: [PATCH 042/243] AOCL Windows BLIS : Windows build for dynamic dispatch library Change-Id: Ie05eafbeacbd5589b514d9353517330515104939 --- CMakeLists.txt | 41 ++++++--- build/blis_ref_kernel_mirror.py | 155 ++++++++++++++++++++++++++++++++ config/CMakeLists.txt | 14 ++- config/generic/CMakeLists.txt | 5 ++ frame/base/CMakeLists.txt | 4 +- ref_kernels/CMakeLists.txt | 9 +- 6 files changed, 209 insertions(+), 19 deletions(-) create mode 100644 build/blis_ref_kernel_mirror.py create mode 100644 config/generic/CMakeLists.txt diff --git a/CMakeLists.txt b/CMakeLists.txt index 2e13ef3ab2..8d892463a7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -17,7 +17,7 @@ set(AOCL_BLIS_ZEN TRUE) set (PYTHON_EXE "python") if ("${AOCL_BLIS_FAMILY}" STREQUAL "") - message(FATAL_ERROR "Machine configuration missing! Select one of zen, zen2, zen3 or amd64") + message(FATAL_ERROR "Machine configuration missing! Select one of zen, zen2, zen3 or amdzen") endif () if (${AOCL_BLIS_FAMILY} STREQUAL "auto") @@ -50,7 +50,7 @@ elseif (${AOCL_BLIS_FAMILY} STREQUAL "zen3") add_definitions(-DBLIS_KERNELS_ZEN2) add_definitions(-DBLIS_KERNELS_ZEN) add_definitions(-DBLIS_KERNELS_HASWELL) -elseif (${AOCL_BLIS_FAMILY} STREQUAL "amd64") +elseif (${AOCL_BLIS_FAMILY} STREQUAL "amdzen") set(AOCL_BLIS_ZEN FALSE) add_definitions(-DBLIS_FAMILY_AMDZEN) add_definitions(-DBLIS_CONFIG_ZEN3) @@ -63,7 +63,7 @@ elseif (${AOCL_BLIS_FAMILY} STREQUAL "amd64") add_definitions(-DBLIS_KERNELS_ZEN) add_definitions(-DBLIS_KERNELS_GENERIC) else () - message(FATAL_ERROR "Wrong machine configuration. Select one of zen, zen2, zen3 or amd64") + message(FATAL_ERROR "Wrong machine configuration. 
Select one of zen, zen2, zen3 or amdzen") endif () set(TARGET_ARCH ${AOCL_BLIS_FAMILY}) @@ -97,6 +97,17 @@ option (ENABLE_COMPLEX_RETURN_INTEL "Enable complex_return_intel" OFF) option (ENABLE_TRSM_PREINVERSION "Enable TRSM preinversion" ON) option (ENABLE_AOCL_DYNAMIC "Enable Dynamic Multi-threading" OFF) +if (${AOCL_BLIS_FAMILY} STREQUAL "amdzen") + set(REF_KERNEL_MIRRORING_PY "${CMAKE_SOURCE_DIR}/build/blis_ref_kernel_mirror.py") + message("ref_kernel mirroring for fat binary") + # Run python script to find the architecture family name + execute_process( + COMMAND ${PYTHON_EXE} ${REF_KERNEL_MIRRORING_PY} ${CMAKE_BINARY_DIR} + RESULT_VARIABLE CMD_RESULT + OUTPUT_VARIABLE CMD_OUTPUT + OUTPUT_STRIP_TRAILING_WHITESPACE) + message( STATUS "Ref Kernel Mirroring :" ${CMD_OUTPUT}) +endif() if(ENABLE_NO_UNDERSCORE_API) add_definitions(-DBLIS_ENABLE_NO_UNDERSCORE_API) endif() @@ -305,8 +316,10 @@ add_definitions(-D_CRT_SECURE_NO_DEPRECATE) #add_definitions(-DBLIS_OS_WINDOWS) add_definitions(-D_MSC_VER) +if (${AOCL_BLIS_FAMILY} STREQUAL "amdzen") +else() add_definitions(-DBLIS_CNAME=${TARGET_ARCH}) - +endif() # Generate the bli_config.h header file configure_file (build/bli_win_config.h.in ${CMAKE_SOURCE_DIR}/bli_config.h @ONLY) @@ -380,6 +393,12 @@ include_directories(${CMAKE_SOURCE_DIR}/config/generic) include_directories(${CMAKE_SOURCE_DIR}/config/zen) include_directories(${CMAKE_SOURCE_DIR}/config/zen2) include_directories(${CMAKE_SOURCE_DIR}/config/zen3) +if(${AOCL_BLIS_FAMILY} STREQUAL "amdzen") + include_directories(${CMAKE_BINARY_DIR}/ref_kernels/generic) + include_directories(${CMAKE_BINARY_DIR}/ref_kernels/zen) + include_directories(${CMAKE_BINARY_DIR}/ref_kernels/zen2) + include_directories(${CMAKE_BINARY_DIR}/ref_kernels/zen3) +endif() include_directories(${CMAKE_SOURCE_DIR}/ref_kernels) include_directories(${CMAKE_SOURCE_DIR}/kernels) include_directories(${CMAKE_SOURCE_DIR}/kernels/haswell) @@ -409,15 +428,13 @@ elseif (${AOCL_BLIS_FAMILY} STREQUAL "zen2") " ${CMAKE_CURRENT_SOURCE_DIR}/config/zen2/" " ${CMAKE_CURRENT_SOURCE_DIR}/kernels/zen/" " ${CMAKE_CURRENT_SOURCE_DIR}/kernels/haswell/" -elseif (${AOCL_BLIS_FAMILY} STREQUAL "amd64") - " ${CMAKE_CURRENT_SOURCE_DIR}/config/amd64/" - " ${CMAKE_CURRENT_SOURCE_DIR}/config/bulldozer/" - " ${CMAKE_CURRENT_SOURCE_DIR}/config/excavator/" + " ${CMAKE_CURRENT_SOURCE_DIR}/config/amdzen/" + " ${CMAKE_CURRENT_SOURCE_DIR}/config/zen/" + " ${CMAKE_CURRENT_SOURCE_DIR}/config/zen2/" + " ${CMAKE_CURRENT_SOURCE_DIR}/config/zen3/" " ${CMAKE_CURRENT_SOURCE_DIR}/config/generic/" - " ${CMAKE_CURRENT_SOURCE_DIR}/config/piledriver/" - " ${CMAKE_CURRENT_SOURCE_DIR}/config/steamroller/" - " ${CMAKE_CURRENT_SOURCE_DIR}/kernels/piledriver/" - " ${CMAKE_CURRENT_SOURCE_DIR}/kernels/bulldozer/" + " ${CMAKE_CURRENT_SOURCE_DIR}/kernels/zen/" + " ${CMAKE_CURRENT_SOURCE_DIR}/kernels/haswell/" endif () " ${CMAKE_CURRENT_SOURCE_DIR}/frame/0/" " ${CMAKE_CURRENT_SOURCE_DIR}/frame/0/copysc/" diff --git a/build/blis_ref_kernel_mirror.py b/build/blis_ref_kernel_mirror.py new file mode 100644 index 0000000000..b756eb30b6 --- /dev/null +++ b/build/blis_ref_kernel_mirror.py @@ -0,0 +1,155 @@ +"""Copyright (C) 2021, Advanced Micro Devices, Inc. All Rights Reserved""" +import os +import shutil +import subprocess +import sys + + +def create_folder(path): + """ Function to create the folder in an given path. + + Args: + path:- Folder path to create. 
+ """ + try: + os.makedirs(path) + except FileExistsError: + pass + + +def remove_folder(path): + """ Function to delete folder in a given path. + + Args: + path:- Folder path to delete. + """ + try: + shutil.rmtree(path) + except FileNotFoundError: + pass + + +def execute_and_check(cmd): + """ Function to run power shell command in windows and bash command + in linux. + + Arg: + cmd:- Power shell/ bash command. + + Return: + Returns command output on success and terminates the execution + on failure. + """ + print('********************************************************') + print('Started execution of {} command...\n'.format(cmd)) + print('********************************************************') + + proc = subprocess.Popen(cmd, shell=True, stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + output, err = proc.communicate() + + if not proc.returncode: + print('********************************************************') + print('Execution of command : {} - was successful'.format(cmd)) + print('command {} output: {}'.format(cmd, + output.decode('ASCII'))) + print('********************************************************') + return output.decode('ASCII') + else: + print('########################################################') + print('Execution of command : {} - was failed'.format(cmd)) + print('command {} output: {}\n{}\n'.format(cmd, output.decode( + 'ASCII'), err.decode('ASCII'))) + exit(1) + + +def remove_lines_in_file(filename): + with open(filename, 'r') as fd: + file_content = fd.read() + file_content = file_content.replace( + 'if(${TARGET_ARCH} STREQUAL amdzen)\nadd_subdirectory(${CMAKE_BINARY_' + 'DIR}/ref_kernels/generic ${CMAKE_BINARY_DIR}/ref_kernels/generic)\n' + 'add_subdirectory(${CMAKE_BINARY_DIR}/ref_kernels/zen ${CMAKE_BINARY_' + 'DIR}/ref_kernels/zen)\nadd_subdirectory(${CMAKE_BINARY_DIR}/' + 'ref_kernels/zen2 ${CMAKE_BINARY_DIR}/ref_kernels/zen2)\n' + 'add_subdirectory(${CMAKE_BINARY_DIR}/ref_kernels/zen3 ' + '${CMAKE_BINARY_DIR}/ref_kernels/zen3)\nelse()', '\n') + data = file_content.replace('endif()', '\n') + with open(filename, 'w') as fd: + fd.write(data + '\n') + + +def write_to_file(filename, data): + with open(filename, 'r') as fd: + file_content = fd.read() + file_content = file_content.split('#include "blis.h"') + data = '\n'.join([file_content[0], '#include "blis.h"', data] + + file_content[1:]) + + with open(filename, 'w') as fd: + fd.write(data + '\n') + + +def add_macro_to_cfiles(cfiles, macro): + for cfile in cfiles: + if os.path.exists(cfile): + write_to_file(cfile, macro) + + +if __name__ == '__main__': + cwd = os.path.abspath(os.path.dirname(os.path.dirname(__file__))) + source_path = os.path.join(cwd, 'ref_kernels') + build_path = sys.argv[1].replace('/', '\\') + dest_path = os.path.join(build_path, 'ref_kernels') + if os.path.exists(dest_path): + remove_folder(dest_path) + + temp = os.path.join(cwd, 'temp') + create_folder(temp) + execute_and_check('XCOPY {} {} /E'.format(source_path, temp)) + create_folder(os.path.join(dest_path, 'zen')) + create_folder(os.path.join(dest_path, 'zen2')) + create_folder(os.path.join(dest_path, 'zen3')) + create_folder(os.path.join(dest_path, 'generic')) + execute_and_check('XCOPY {} {} /E'.format( + temp, os.path.join(dest_path, 'zen'))) + execute_and_check('XCOPY {} {} /E'.format( + temp, os.path.join(dest_path, 'zen2'))) + execute_and_check('XCOPY {} {} /E'.format( + temp, os.path.join(dest_path, 'zen3'))) + execute_and_check('XCOPY {} {} /E'.format( + temp, os.path.join(dest_path, 
'generic'))) + remove_folder(temp) + remove_lines_in_file(os.path.join( + dest_path, 'generic', 'CMakeLists.txt')) + remove_lines_in_file(os.path.join( + dest_path, 'zen', 'CMakeLists.txt')) + remove_lines_in_file(os.path.join( + dest_path, 'zen2', 'CMakeLists.txt')) + remove_lines_in_file(os.path.join( + dest_path, 'zen3', 'CMakeLists.txt')) + cfiles_in_generic = execute_and_check('cd {} && dir / s / b / o: gn *.c' + .format(os.path.join(dest_path, + 'generic'))) + cfiles_in_generic = cfiles_in_generic.split('\r\n') + add_macro_to_cfiles(cfiles_in_generic, + '\n#define BLIS_CNAME_INFIX _generic\n') + cfiles_in_zen = execute_and_check('cd {} && dir / s / b / o: gn *.c' + .format(os.path.join(dest_path, + 'zen'))) + cfiles_in_zen = cfiles_in_zen.split('\r\n') + add_macro_to_cfiles(cfiles_in_zen, + '\n#define BLIS_CNAME_INFIX _zen\n') + cfiles_in_zen2 = execute_and_check('cd {} && dir / s / b / o: gn *.c' + .format(os.path.join(dest_path, + 'zen2'))) + cfiles_in_zen2 = cfiles_in_zen2.split('\r\n') + add_macro_to_cfiles(cfiles_in_zen2, + '\n#define BLIS_CNAME_INFIX _zen2\n') + cfiles_in_zen3 = execute_and_check('cd {} && dir / s / b / o: gn *.c' + .format(os.path.join(dest_path, + 'zen3'))) + cfiles_in_zen3 = cfiles_in_zen3.split('\r\n') + add_macro_to_cfiles(cfiles_in_zen3, + '\n#define BLIS_CNAME_INFIX _zen3\n') diff --git a/config/CMakeLists.txt b/config/CMakeLists.txt index 61f0d61f3f..12568f67f7 100644 --- a/config/CMakeLists.txt +++ b/config/CMakeLists.txt @@ -1,4 +1,4 @@ -##Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. ## +##Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. ## if(${TARGET_ARCH} STREQUAL zen3) message("The configuration is : ${TARGET_ARCH}") @@ -9,10 +9,16 @@ add_subdirectory(zen2) elseif(${TARGET_ARCH} STREQUAL zen) message("The configuration is : ${TARGET_ARCH}") add_subdirectory(zen) -elseif(${TARGET_ARCH} STREQUAL amd64) +elseif(${TARGET_ARCH} STREQUAL amdzen) message("The configuration is : ${TARGET_ARCH}") -add_subdirectory(amd64) -else(${TARGET_ARCH} STREQUAL haswell) +add_subdirectory(generic) +add_subdirectory(zen) +add_subdirectory(zen2) +add_subdirectory(zen3) +elseif(${TARGET_ARCH} STREQUAL haswell) message("The configuration is : ${TARGET_ARCH}") add_subdirectory(haswell) +else(${TARGET_ARCH} STREQUAL generic) +message("The configuration is : ${TARGET_ARCH}") +add_subdirectory(generic) endif() \ No newline at end of file diff --git a/config/generic/CMakeLists.txt b/config/generic/CMakeLists.txt new file mode 100644 index 0000000000..2fd3855574 --- /dev/null +++ b/config/generic/CMakeLists.txt @@ -0,0 +1,5 @@ +##Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. ## + +target_sources("${PROJECT_NAME}" PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/bli_cntx_init_generic.c + ) diff --git a/frame/base/CMakeLists.txt b/frame/base/CMakeLists.txt index 9abfcc6a9c..5727bfe62b 100644 --- a/frame/base/CMakeLists.txt +++ b/frame/base/CMakeLists.txt @@ -1,4 +1,4 @@ -##Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.## +##Copyright (C) 2021, Advanced Micro Devices, Inc. 
All rights reserved.## target_sources("${PROJECT_NAME}" PUBLIC @@ -11,7 +11,7 @@ target_sources("${PROJECT_NAME}" ${CMAKE_CURRENT_SOURCE_DIR}/bli_cntl.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_cntx.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_const.c - #${CMAKE_CURRENT_SOURCE_DIR}/bli_cpuid.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_cpuid.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_env.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_error.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_func.c diff --git a/ref_kernels/CMakeLists.txt b/ref_kernels/CMakeLists.txt index 3ad92e8213..61357c1fec 100644 --- a/ref_kernels/CMakeLists.txt +++ b/ref_kernels/CMakeLists.txt @@ -1,5 +1,11 @@ -##Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.## +##Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved.## +if(${TARGET_ARCH} STREQUAL amdzen) +add_subdirectory(${CMAKE_BINARY_DIR}/ref_kernels/generic ${CMAKE_BINARY_DIR}/ref_kernels/generic) +add_subdirectory(${CMAKE_BINARY_DIR}/ref_kernels/zen ${CMAKE_BINARY_DIR}/ref_kernels/zen) +add_subdirectory(${CMAKE_BINARY_DIR}/ref_kernels/zen2 ${CMAKE_BINARY_DIR}/ref_kernels/zen2) +add_subdirectory(${CMAKE_BINARY_DIR}/ref_kernels/zen3 ${CMAKE_BINARY_DIR}/ref_kernels/zen3) +else() target_sources("${PROJECT_NAME}" PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/bli_cntx_ref.c @@ -11,3 +17,4 @@ set(SUBDIRECTORIES "1" "1f" "1m" "3" "ind") foreach(VAR ${SUBDIRECTORIES}) add_subdirectory(${VAR}) endforeach() +endif() From 30038af896640c9166e9fa2ff1d0d0e72d1ffa11 Mon Sep 17 00:00:00 2001 From: lcpu Date: Wed, 13 Oct 2021 15:47:10 +0530 Subject: [PATCH 043/243] Reverted: To fix accuracy issues for complex datatypes Details: -- reverted cscalv,zscalv,ctrsm,ztrsm changes to address accuracy issues observed by libflame and scalapack application testing. -- AMD-Internal: [CPUPL-1906], [CPUPL-1914] Change-Id: Ic364eacbdf49493dd3a166a66880c12ee84c2204 --- config/zen/bli_cntx_init_zen.c | 4 +- config/zen2/bli_cntx_init_zen2.c | 4 +- config/zen3/bli_cntx_init_zen3.c | 4 +- frame/2/gemv/bli_gemv_unf_var2.c | 28 +- frame/compat/bla_scal.c | 142 +---- frame/compat/bla_trsm.c | 8 +- frame/include/bli_gentfunc_macro_defs.h | 4 +- kernels/zen/1/bli_scalv_zen_int10.c | 728 +----------------------- kernels/zen/bli_kernels_zen.h | 2 - test/Makefile | 2 +- 10 files changed, 40 insertions(+), 886 deletions(-) diff --git a/config/zen/bli_cntx_init_zen.c b/config/zen/bli_cntx_init_zen.c index de4cbfb130..7595849866 100644 --- a/config/zen/bli_cntx_init_zen.c +++ b/config/zen/bli_cntx_init_zen.c @@ -95,7 +95,7 @@ void bli_cntx_init_zen( cntx_t* cntx ) // Update the context with optimized level-1v kernels. bli_cntx_set_l1v_kers ( - 22, + 20, #if 1 // amaxv BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int, @@ -128,8 +128,6 @@ void bli_cntx_init_zen( cntx_t* cntx ) #else BLIS_SCALV_KER, BLIS_FLOAT, bli_sscalv_zen_int10, BLIS_SCALV_KER, BLIS_DOUBLE, bli_dscalv_zen_int10, - BLIS_SCALV_KER, BLIS_SCOMPLEX, bli_cscalv_zen_int10, - BLIS_SCALV_KER, BLIS_DCOMPLEX, bli_zscalv_zen_int10, #endif BLIS_SWAPV_KER, BLIS_FLOAT, bli_sswapv_zen_int8, BLIS_SWAPV_KER, BLIS_DOUBLE, bli_dswapv_zen_int8, diff --git a/config/zen2/bli_cntx_init_zen2.c b/config/zen2/bli_cntx_init_zen2.c index 6f3bbf3da9..4f56316a7a 100644 --- a/config/zen2/bli_cntx_init_zen2.c +++ b/config/zen2/bli_cntx_init_zen2.c @@ -107,7 +107,7 @@ void bli_cntx_init_zen2( cntx_t* cntx ) // Update the context with optimized level-1v kernels. 
bli_cntx_set_l1v_kers ( - 22, + 20, #if 1 // amaxv BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int, @@ -134,8 +134,6 @@ void bli_cntx_init_zen2( cntx_t* cntx ) // scalv BLIS_SCALV_KER, BLIS_FLOAT, bli_sscalv_zen_int10, BLIS_SCALV_KER, BLIS_DOUBLE, bli_dscalv_zen_int10, - BLIS_SCALV_KER, BLIS_SCOMPLEX, bli_cscalv_zen_int10, - BLIS_SCALV_KER, BLIS_DCOMPLEX, bli_zscalv_zen_int10, //swap BLIS_SWAPV_KER, BLIS_FLOAT, bli_sswapv_zen_int8, diff --git a/config/zen3/bli_cntx_init_zen3.c b/config/zen3/bli_cntx_init_zen3.c index 6b97f6bbf2..fc7dbcb808 100644 --- a/config/zen3/bli_cntx_init_zen3.c +++ b/config/zen3/bli_cntx_init_zen3.c @@ -107,7 +107,7 @@ void bli_cntx_init_zen3( cntx_t* cntx ) // Update the context with optimized level-1v kernels. bli_cntx_set_l1v_kers ( - 22, + 20, #if 1 // amaxv BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int, @@ -134,8 +134,6 @@ void bli_cntx_init_zen3( cntx_t* cntx ) // scalv BLIS_SCALV_KER, BLIS_FLOAT, bli_sscalv_zen_int10, BLIS_SCALV_KER, BLIS_DOUBLE, bli_dscalv_zen_int10, - BLIS_SCALV_KER, BLIS_SCOMPLEX, bli_cscalv_zen_int10, - BLIS_SCALV_KER, BLIS_DCOMPLEX, bli_zscalv_zen_int10, //swap BLIS_SWAPV_KER, BLIS_FLOAT, bli_sswapv_zen_int8, diff --git a/frame/2/gemv/bli_gemv_unf_var2.c b/frame/2/gemv/bli_gemv_unf_var2.c index ffebf17bac..ae1356fbea 100644 --- a/frame/2/gemv/bli_gemv_unf_var2.c +++ b/frame/2/gemv/bli_gemv_unf_var2.c @@ -394,7 +394,7 @@ void bli_zgemv_unf_var2 /* If beta is zero, use setv. Otherwise, scale by beta. */ /* y = beta * y; */ /* beta=0 case is hadled by scalv internally */ - bli_zscalv_zen_int10 +/* bli_zscalv_zen_int10 ( BLIS_NO_CONJUGATE, n_elem, @@ -402,7 +402,16 @@ void bli_zgemv_unf_var2 y, incy, cntx - ); + );*/ + bli_zscalv_ex + ( + BLIS_NO_CONJUGATE, + n_elem, + beta, + y, incy, + cntx, + NULL + ); if( bli_zeq0( *alpha ) ) { @@ -498,7 +507,7 @@ void bli_cgemv_unf_var2 /* If beta is zero, use setv. Otherwise, scale by beta. */ /* y = beta * y; */ /* beta=0 case is hadled by scalv internally */ - bli_cscalv_zen_int10 + /*bli_cscalv_zen_int10 ( BLIS_NO_CONJUGATE, n_elem, @@ -506,7 +515,18 @@ void bli_cgemv_unf_var2 y, incy, cntx - ); + );*/ + bli_cscalv_ex + ( + BLIS_NO_CONJUGATE, + n_elem, + beta, + y, incy, + cntx, + NULL + ); + + if( bli_ceq0( *alpha ) ) { diff --git a/frame/compat/bla_scal.c b/frame/compat/bla_scal.c index b08fac87f5..184b14eda0 100644 --- a/frame/compat/bla_scal.c +++ b/frame/compat/bla_scal.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020-21, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -226,146 +226,8 @@ void dscal_ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) } -void cscal_ - ( - const f77_int* n, - const scomplex* alpha, - scomplex* x, - const f77_int* incx - ) -{ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); - AOCL_DTL_LOG_SCAL_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'C', (void *)alpha, *n, *incx ); - dim_t n0; - scomplex* x0; - inc_t incx0; - - /* Initialize BLIS */ - //bli_init_auto(); - - if (*n == 0 || alpha == NULL) { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return; - } - - /* Convert typecast negative values of n to zero. */ - if ( *n < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(*n); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. 
*/ - if ( *incx < 0 ) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. (Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. */ - - x0 = (x) + (n0-1)*(-*incx); - incx0 = ( inc_t )(*incx); - - } - else - { - x0 = (x); - incx0 = ( inc_t )(*incx); - } - - /* Call BLIS kernel */ - bli_cscalv_zen_int10 - ( - BLIS_NO_CONJUGATE, - n0, - (scomplex*) alpha, - x0, incx0, - NULL - ); - - /* Finalize BLIS. */ - // bli_finalize_auto(); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) -} - -void zscal_ - ( - const f77_int* n, - const dcomplex* alpha, - dcomplex* x, - const f77_int* incx - ) -{ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) - AOCL_DTL_LOG_SCAL_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'Z', (void *)alpha, *n, *incx ); - dim_t n0; - dcomplex* x0; - inc_t incx0; - - /* Initialize BLIS */ - //bli_init_auto(); - - if (*n == 0 || alpha == NULL) { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return; - } - - /* Convert typecast negative values of n to zero. */ - if ( *n < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(*n); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - if ( *incx < 0 ) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. (Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. */ - - x0 = (x) + (n0-1)*(-*incx); - incx0 = ( inc_t )(*incx); - - } - else - { - x0 = (x); - incx0 = ( inc_t )(*incx); - } - - /* Call BLIS kernel */ - bli_zscalv_zen_int10 - ( - BLIS_NO_CONJUGATE, - n0, - (dcomplex*) alpha, - x0, incx0, - NULL - ); - - /* Finalize BLIS. */ - // bli_finalize_auto(); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) -} - -INSERT_GENTFUNCSCAL_BLAS_CsZd(scal, scalv) - +INSERT_GENTFUNCSCAL_BLAS_CZ( scal, scalv ) #else INSERT_GENTFUNCSCAL_BLAS( scal, scalv ) #endif #endif - diff --git a/frame/compat/bla_trsm.c b/frame/compat/bla_trsm.c index 46f79f02f8..0ee6e180c2 100644 --- a/frame/compat/bla_trsm.c +++ b/frame/compat/bla_trsm.c @@ -896,7 +896,7 @@ void dtrsm_ /* Finalize BLIS. */ bli_finalize_auto(); } - +#if 0 void ztrsm_ ( const f77_char* side, @@ -1215,7 +1215,8 @@ void ztrsm_ /* Finalize BLIS. */ bli_finalize_auto(); } - +#endif +#if 0 void ctrsm_ ( const f77_char* side, @@ -1531,7 +1532,8 @@ void ctrsm_ /* Finalize BLIS. 
*/ bli_finalize_auto(); } - +#endif +INSERT_GENTFUNC_BLAS_CZ( trsm, trsm ) #else INSERT_GENTFUNC_BLAS( trsm, trsm ) #endif diff --git a/frame/include/bli_gentfunc_macro_defs.h b/frame/include/bli_gentfunc_macro_defs.h index ae0b1f3857..1bac7aa7c4 100644 --- a/frame/include/bli_gentfunc_macro_defs.h +++ b/frame/include/bli_gentfunc_macro_defs.h @@ -151,8 +151,10 @@ GENTFUNCR2( dcomplex, double, z, d, blasname, blisname ) // -- Extended two-operand macro (used only for scal) -- -#define INSERT_GENTFUNCSCAL_BLAS_CsZd( blasname, blisname ) \ +#define INSERT_GENTFUNCSCAL_BLAS_CZ( blasname, blisname ) \ \ +GENTFUNCSCAL( scomplex, scomplex, c, , blasname, blisname ) \ +GENTFUNCSCAL( dcomplex, dcomplex, z, , blasname, blisname ) \ GENTFUNCSCAL( scomplex, float, c, s, blasname, blisname ) \ GENTFUNCSCAL( dcomplex, double, z, d, blasname, blisname ) diff --git a/kernels/zen/1/bli_scalv_zen_int10.c b/kernels/zen/1/bli_scalv_zen_int10.c index fef490196d..6c7f52e161 100644 --- a/kernels/zen/1/bli_scalv_zen_int10.c +++ b/kernels/zen/1/bli_scalv_zen_int10.c @@ -4,8 +4,8 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2017 - 2021, Advanced Micro Devices, Inc. All rights reserved. - Copyright (C) 2018, The University of Texas at Austin. + Copyright (C) 2017 - 2020, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018, The University of Texas at Austin Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -454,727 +454,3 @@ void bli_dscalv_zen_int10 } } -// ----------------------------------------------------------------------------- - -void bli_cscalv_zen_int10 - ( - conj_t conjalpha, - dim_t n, - scomplex* restrict alpha, - scomplex* restrict x, - inc_t incx, - cntx_t* restrict cntx - ) -{ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_4) - - const dim_t n_elem_per_reg = 8; - - dim_t i; - - float* restrict x0; - float* restrict alpha0; - float alphaR, alphaI; - - __m256 alphaRv; - __m256 alphaIv; - __m256 xv[10]; - __m256 x_sufv[10]; - - conj_t conjx_use = conjalpha; - - // If the vector dimension is zero, or if alpha is unit, return early. - if ( bli_zero_dim1( n ) || PASTEMAC(c,eq1)( *alpha ) ) return; - - // If alpha is zero, use setv. - if ( PASTEMAC(c,eq0)( *alpha ) ) - { - scomplex* zero = bli_c0; - if (cntx == NULL) - cntx = bli_gks_query_cntx(); - csetv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_SCOMPLEX, BLIS_SETV_KER, cntx ); - f - ( - BLIS_NO_CONJUGATE, - n, - zero, - x, incx, - cntx - ); - return; - } - - // Initialize local pointers. - x0 = (float*)x; - alpha0 = (float*)alpha; - - alphaR = alpha->real; - alphaI = alpha->imag; - - if ( incx == 1 ) - { - // Broadcast the alpha scalar to all elements of a vector register. 
- if ( !bli_is_conj (conjx_use) ) // If BLIS_NO_CONJUGATE - { - alphaRv = _mm256_broadcast_ss( &alphaR ); - alphaIv = _mm256_set_ps(alphaI, -alphaI, alphaI, -alphaI, alphaI, -alphaI, alphaI, -alphaI); - } - else - { - alphaIv = _mm256_broadcast_ss( &alphaI ); - alphaRv = _mm256_set_ps(-alphaR, alphaR, -alphaR, alphaR, -alphaR, alphaR, -alphaR, alphaR); - } - - /* - = (alpha_r + alpha_i) * (x_r + x_i) - = alpha_r*x_r + alpha_r*x_i + alpha_i*x_r + (-alpha_i*x_i) - = (alpha_r*x_r - alpha_i*x_i) + (alpha_r*x_i + alpha_i*x_r)I - - x = x_r , x_i , x_r , x_i , x_r , x_i , x_r , x_i - x_suf = x_i , x_r , x_i , x_r , x_i , x_r , x_i , x_r - alphaR = ar , ar , ar , ar , ar , ar , ar , ar - alphaI = -ai , ai ,-ai , ai ,-ai , ai ,-ai, ai - - step 1) Load x. - step 2) Shuffle x. - step 3) mul x <= x*alphaR => ar*x_r , ar*x_i - step 4) fma x <= x_suf*alphaI + x => (-ai*x_i , ai*x_r) + (ar*x_r , ar*x_i) - => (ar*x_r - ai*x_i), (ar*x_i + ai*x_r ) - */ - - for ( i = 0; (i + 39) < n; i += 40 ) - { - // Load the input values. - xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); - xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); - xv[2] = _mm256_loadu_ps( x0 + 2*n_elem_per_reg ); - xv[3] = _mm256_loadu_ps( x0 + 3*n_elem_per_reg ); - xv[4] = _mm256_loadu_ps( x0 + 4*n_elem_per_reg ); - xv[5] = _mm256_loadu_ps( x0 + 5*n_elem_per_reg ); - xv[6] = _mm256_loadu_ps( x0 + 6*n_elem_per_reg ); - xv[7] = _mm256_loadu_ps( x0 + 7*n_elem_per_reg ); - xv[8] = _mm256_loadu_ps( x0 + 8*n_elem_per_reg ); - xv[9] = _mm256_loadu_ps( x0 + 9*n_elem_per_reg ); - - // x = xr0 , xi0, xr1, xi1 .... - // x_suf = xi0 , xr0, xi1, xr1 .... - x_sufv[0] = _mm256_permute_ps( xv[0], 0xB1); - x_sufv[1] = _mm256_permute_ps( xv[1], 0xB1); - x_sufv[2] = _mm256_permute_ps( xv[2], 0xB1); - x_sufv[3] = _mm256_permute_ps( xv[3], 0xB1); - x_sufv[4] = _mm256_permute_ps( xv[4], 0xB1); - x_sufv[5] = _mm256_permute_ps( xv[5], 0xB1); - x_sufv[6] = _mm256_permute_ps( xv[6], 0xB1); - x_sufv[7] = _mm256_permute_ps( xv[7], 0xB1); - x_sufv[8] = _mm256_permute_ps( xv[8], 0xB1); - x_sufv[9] = _mm256_permute_ps( xv[9], 0xB1); - - // mul x <= x*alphaR - // aphhaR = ar , ar , ar , ar , .... - // x = xr , xi , xr , xi , .... - // mul = ar*xr, ar*xi , ar*xr , ar*xi, .... - xv[0] = _mm256_mul_ps( alphaRv, xv[0] ); - xv[1] = _mm256_mul_ps( alphaRv, xv[1] ); - xv[2] = _mm256_mul_ps( alphaRv, xv[2] ); - xv[3] = _mm256_mul_ps( alphaRv, xv[3] ); - xv[4] = _mm256_mul_ps( alphaRv, xv[4] ); - xv[5] = _mm256_mul_ps( alphaRv, xv[5] ); - xv[6] = _mm256_mul_ps( alphaRv, xv[6] ); - xv[7] = _mm256_mul_ps( alphaRv, xv[7] ); - xv[8] = _mm256_mul_ps( alphaRv, xv[8] ); - xv[9] = _mm256_mul_ps( alphaRv, xv[9] ); - - // fma x <= x_suf*alphaI + x - // alphaI = -ai , ai , -ai , ai .... - // X suf = xi , xr , xi , xr .... - // mul = -ai*xi, ai*xr , -ai*xi, ai*xi .... - // add x = ar*xr - ai*xi, ar*xi + ai*xr, .... - xv[0] = _mm256_fmadd_ps( alphaIv, x_sufv[0], xv[0] ); - xv[1] = _mm256_fmadd_ps( alphaIv, x_sufv[1], xv[1] ); - xv[2] = _mm256_fmadd_ps( alphaIv, x_sufv[2], xv[2] ); - xv[3] = _mm256_fmadd_ps( alphaIv, x_sufv[3], xv[3] ); - xv[4] = _mm256_fmadd_ps( alphaIv, x_sufv[4], xv[4] ); - xv[5] = _mm256_fmadd_ps( alphaIv, x_sufv[5], xv[5] ); - xv[6] = _mm256_fmadd_ps( alphaIv, x_sufv[6], xv[6] ); - xv[7] = _mm256_fmadd_ps( alphaIv, x_sufv[7], xv[7] ); - xv[8] = _mm256_fmadd_ps( alphaIv, x_sufv[8], xv[8] ); - xv[9] = _mm256_fmadd_ps( alphaIv, x_sufv[9], xv[9] ); - - // Store the output. 
- _mm256_storeu_ps( (x0 + 0*n_elem_per_reg), xv[0] ); - _mm256_storeu_ps( (x0 + 1*n_elem_per_reg), xv[1] ); - _mm256_storeu_ps( (x0 + 2*n_elem_per_reg), xv[2] ); - _mm256_storeu_ps( (x0 + 3*n_elem_per_reg), xv[3] ); - _mm256_storeu_ps( (x0 + 4*n_elem_per_reg), xv[4] ); - _mm256_storeu_ps( (x0 + 5*n_elem_per_reg), xv[5] ); - _mm256_storeu_ps( (x0 + 6*n_elem_per_reg), xv[6] ); - _mm256_storeu_ps( (x0 + 7*n_elem_per_reg), xv[7] ); - _mm256_storeu_ps( (x0 + 8*n_elem_per_reg), xv[8] ); - _mm256_storeu_ps( (x0 + 9*n_elem_per_reg), xv[9] ); - - x0 += 10*n_elem_per_reg; - } - - for ( ; (i + 19) < n; i += 20 ) - { - // Load the input values. - xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); - xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); - xv[2] = _mm256_loadu_ps( x0 + 2*n_elem_per_reg ); - xv[3] = _mm256_loadu_ps( x0 + 3*n_elem_per_reg ); - xv[4] = _mm256_loadu_ps( x0 + 4*n_elem_per_reg ); - - // x = xr0 , xi0, xr1, xi1 .... - // x_suf = xi0 , xr0, xi1, xr1 .... - x_sufv[0] = _mm256_permute_ps( xv[0], 0xB1); - x_sufv[1] = _mm256_permute_ps( xv[1], 0xB1); - x_sufv[2] = _mm256_permute_ps( xv[2], 0xB1); - x_sufv[3] = _mm256_permute_ps( xv[3], 0xB1); - x_sufv[4] = _mm256_permute_ps( xv[4], 0xB1); - - // mul x <= x*alphaR - // aphhaR = ar , ar , ar , ar , .... - // x = xr , xi , xr , xi , .... - // mul = ar*xr, ar*xi , ar*xr , ar*xi, .... - xv[0] = _mm256_mul_ps( alphaRv, xv[0] ); - xv[1] = _mm256_mul_ps( alphaRv, xv[1] ); - xv[2] = _mm256_mul_ps( alphaRv, xv[2] ); - xv[3] = _mm256_mul_ps( alphaRv, xv[3] ); - xv[4] = _mm256_mul_ps( alphaRv, xv[4] ); - - // fma x <= x_suf*alphaI + x - // alphaI = -ai , ai , -ai , ai .... - // X = xi , xr , xi , xr .... - // mul = -ai*xi, ai*xr , -ai*xi, ai*xi .... - // add x = ar*xr - ai*xi, ar*xi + ai*xr, - xv[0] = _mm256_fmadd_ps( alphaIv, x_sufv[0], xv[0] ); - xv[1] = _mm256_fmadd_ps( alphaIv, x_sufv[1], xv[1] ); - xv[2] = _mm256_fmadd_ps( alphaIv, x_sufv[2], xv[2] ); - xv[3] = _mm256_fmadd_ps( alphaIv, x_sufv[3], xv[3] ); - xv[4] = _mm256_fmadd_ps( alphaIv, x_sufv[4], xv[4] ); - - // Store the output. - _mm256_storeu_ps( (x0 + 0*n_elem_per_reg), xv[0] ); - _mm256_storeu_ps( (x0 + 1*n_elem_per_reg), xv[1] ); - _mm256_storeu_ps( (x0 + 2*n_elem_per_reg), xv[2] ); - _mm256_storeu_ps( (x0 + 3*n_elem_per_reg), xv[3] ); - _mm256_storeu_ps( (x0 + 4*n_elem_per_reg), xv[4] ); - - x0 += 5*n_elem_per_reg; - } - - for ( ; (i + 15) < n; i += 16 ) - { - // Load the input values. - xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); - xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); - xv[2] = _mm256_loadu_ps( x0 + 2*n_elem_per_reg ); - xv[3] = _mm256_loadu_ps( x0 + 3*n_elem_per_reg ); - - // x = xr0 , xi0, xr1, xi1 .... - // x_suf = xi0 , xr0, xi1, xr1 .... - x_sufv[0] = _mm256_permute_ps( xv[0], 0xB1); - x_sufv[1] = _mm256_permute_ps( xv[1], 0xB1); - x_sufv[2] = _mm256_permute_ps( xv[2], 0xB1); - x_sufv[3] = _mm256_permute_ps( xv[3], 0xB1); - - // mul x <= x*alphaR - // aphhaR = ar , ar , ar , ar , .... - // x = xr , xi , xr , xi , .... - // mul = ar*xr, ar*xi , ar*xr , ar*xi, .... - xv[0] = _mm256_mul_ps( alphaRv, xv[0] ); - xv[1] = _mm256_mul_ps( alphaRv, xv[1] ); - xv[2] = _mm256_mul_ps( alphaRv, xv[2] ); - xv[3] = _mm256_mul_ps( alphaRv, xv[3] ); - - // fma x <= x_suf*alphaI + x - // alphaI = -ai , ai , -ai , ai .... - // X = xi , xr , xi , xr .... - // mul = -ai*xi, ai*xr , -ai*xi, ai*xi .... 
- // add x = ar*xr - ai*xi, ar*xi + ai*xr, - xv[0] = _mm256_fmadd_ps( alphaIv, x_sufv[0], xv[0] ); - xv[1] = _mm256_fmadd_ps( alphaIv, x_sufv[1], xv[1] ); - xv[2] = _mm256_fmadd_ps( alphaIv, x_sufv[2], xv[2] ); - xv[3] = _mm256_fmadd_ps( alphaIv, x_sufv[3], xv[3] ); - - // Store the output. - _mm256_storeu_ps( (x0 + 0*n_elem_per_reg), xv[0] ); - _mm256_storeu_ps( (x0 + 1*n_elem_per_reg), xv[1] ); - _mm256_storeu_ps( (x0 + 2*n_elem_per_reg), xv[2] ); - _mm256_storeu_ps( (x0 + 3*n_elem_per_reg), xv[3] ); - - x0 += 4*n_elem_per_reg; - } - - for ( ; (i + 7) < n; i += 8 ) - { - // Load the input values. - xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); - xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); - - // x = xr0 , xi0, xr1, xi1 .... - // x_suf = xi0 , xr0, xi1, xr1 .... - x_sufv[0] = _mm256_permute_ps( xv[0], 0xB1); - x_sufv[1] = _mm256_permute_ps( xv[1], 0xB1); - - // mul x <= x*alphaR - // aphhaR = ar , ar , ar , ar , .... - // x = xr , xi , xr , xi , .... - // mul = ar*xr, ar*xi , ar*xr , ar*xi, .... - xv[0] = _mm256_mul_ps( alphaRv, xv[0] ); - xv[1] = _mm256_mul_ps( alphaRv, xv[1] ); - - // fma x <= x_suf*alphaI + x - // alphaI = -ai , ai , -ai , ai .... - // X = xi , xr , xi , xr .... - // mul = -ai*xi, ai*xr , -ai*xi, ai*xi .... - // add x = ar*xr - ai*xi, ar*xi + ai*xr, - xv[0] = _mm256_fmadd_ps( alphaIv, x_sufv[0], xv[0] ); - xv[1] = _mm256_fmadd_ps( alphaIv, x_sufv[1], xv[1] ); - - // Store the output. - _mm256_storeu_ps( (x0 + 0*n_elem_per_reg), xv[0] ); - _mm256_storeu_ps( (x0 + 1*n_elem_per_reg), xv[1] ); - - x0 += 2*n_elem_per_reg; - } - - for ( ; (i + 3) < n; i += 4 ) - { - // Load the input values. - xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); - - // x = xr0 , xi0, xr1, xi1 .... - // x_suf = xi0 , xr0, xi1, xr1 .... - x_sufv[0] = _mm256_permute_ps( xv[0], 0xB1); - - // mul x <= x*alphaR - // aphhaR = ar , ar , ar , ar , .... - // x = xr , xi , xr , xi , .... - // mul = ar*xr, ar*xi , ar*xr , ar*xi, .... - xv[0] = _mm256_mul_ps( alphaRv, xv[0] ); - - // fma x <= x_suf*alphaI + x - // alphaI = -ai , ai , -ai , ai .... - // X = xi , xr , xi , xr .... - // mul = -ai*xi, ai*xr , -ai*xi, ai*xi .... - // add x = ar*xr - ai*xi, ar*xi + ai*xr, - xv[0] = _mm256_fmadd_ps( alphaIv, x_sufv[0], xv[0] ); - - // Store the output. 
- _mm256_storeu_ps( (x0 + 0*n_elem_per_reg), xv[0] ); - - x0 += 1*n_elem_per_reg; - } - - for ( ; (i + 0) < n; i += 1 ) - { - float real; - - // real part: ( aR.xR - aIxI ) - real = *alpha0 * (*x0) - (*(alpha0 + 1)) * (*(x0+1)); - // img part: ( aR.xI + aI.xR ) - *(x0 + 1) = *alpha0 * (*(x0+1)) + (*(alpha0 + 1)) * (*x0); - - *x0 = real; - - x0 += 2; - } - } - else - { - const float alphar = *alpha0; - const float alphai = *(alpha0 + 1); - - if ( !bli_is_conj(conjx_use) ) // BLIS_NO_CONJUGATE - { - for ( i = 0; i < n; ++i ) - { - const float x0c = *x0; - const float x1c = *( x0+1 ); - - *x0 = alphar * x0c - alphai * x1c; - *(x0 + 1) = alphar * x1c + alphai * x0c; - - x0 += incx*2; - } - } - else // BLIS_CONJUGATE - { - for ( i = 0; i < n; ++i ) - { - const float x0c = *x0; - const float x1c = *( x0+1 ); - - *x0 = alphar * x0c + alphai * x1c; - *(x0 + 1) = alphai * x0c - alphar * x1c; - - x0 += incx*2; - } - } - } -} - -// ----------------------------------------------------------------------------- - -void bli_zscalv_zen_int10 - ( - conj_t conjalpha, - dim_t n, - dcomplex* restrict alpha, - dcomplex* restrict x, - inc_t incx, - cntx_t* restrict cntx - ) -{ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_4) - - const dim_t n_elem_per_reg = 4; - - dim_t i; - - double* restrict x0; - double* restrict alpha0; - double alphaR, alphaI; - - __m256d alphaRv; - __m256d alphaIv; - __m256d xv[10]; - __m256d x_sufv[10]; - - conj_t conjx_use = conjalpha; - - // If the vector dimension is zero, or if alpha is unit, return early. - if ( bli_zero_dim1( n ) || PASTEMAC(z,eq1)( *alpha ) ) return; - - // If alpha is zero, use setv. - if ( PASTEMAC(z,eq0)( *alpha ) ) - { - dcomplex* zero = bli_z0; - - if (cntx == NULL) - cntx = bli_gks_query_cntx(); - zsetv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_DCOMPLEX, BLIS_SETV_KER, cntx ); - f - ( - BLIS_NO_CONJUGATE, - n, - zero, - x, incx, - cntx - ); - - return; - } - - // Initialize local pointers. - x0 = (double*)x; - alpha0 = (double*)alpha; - - alphaR = alpha->real; - alphaI = alpha->imag; - - if ( incx == 1 ) - { - // Broadcast the alpha scalar to all elements of a vector register. - if ( !bli_is_conj (conjx_use) ) // If BLIS_NO_CONJUGATE - { - alphaRv = _mm256_broadcast_sd( &alphaR ); - alphaIv = _mm256_set_pd(alphaI, -alphaI, alphaI, -alphaI); - } - else - { - alphaIv = _mm256_broadcast_sd( &alphaI ); - alphaRv = _mm256_set_pd(alphaR, -alphaR, alphaR, -alphaR); - } - - /* - = (alpha_r + alpha_i) * (x_r + x_i) - = alpha_r*x_r + alpha_r*x_i + alpha_i*x_r + (-alpha_i*x_i) - = (alpha_r*x_r - alpha_i*x_i) + (alpha_r*x_i + alpha_i*x_r)I - - x = x_r , x_i , x_r , x_i , x_r , x_i , x_r , x_i - x_suf = x_i , x_r , x_i , x_r , x_i , x_r , x_i , x_r - alphaR = ar , ar , ar , ar , ar , ar , ar , ar - alphaI = -ai , ai ,-ai , ai ,-ai , ai ,-ai, ai - - step 1) Load x. - step 2) Shuffle x. - step 3) mul x <= x*alphaR => ar*x_r , ar*x_i - step 4) fma x <= x_suf*alphaI + x => (-ai*x_i , ai*x_r) + (ar*x_r , ar*x_i) - => (ar*x_r - ai*x_i), (ar*x_i + ai*x_r ) - */ - - for ( i = 0; (i + 19) < n; i += 20 ) - { - // Load the input values. 
- xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); - xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); - xv[2] = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); - xv[3] = _mm256_loadu_pd( x0 + 3*n_elem_per_reg ); - xv[4] = _mm256_loadu_pd( x0 + 4*n_elem_per_reg ); - xv[5] = _mm256_loadu_pd( x0 + 5*n_elem_per_reg ); - xv[6] = _mm256_loadu_pd( x0 + 6*n_elem_per_reg ); - xv[7] = _mm256_loadu_pd( x0 + 7*n_elem_per_reg ); - xv[8] = _mm256_loadu_pd( x0 + 8*n_elem_per_reg ); - xv[9] = _mm256_loadu_pd( x0 + 9*n_elem_per_reg ); - - // x = xr0 , xi0, xr1, xi1 .... - // x_suf = xi0 , xr0, xi1, xr1 .... - x_sufv[0] = _mm256_permute_pd( xv[0], 5); - x_sufv[1] = _mm256_permute_pd( xv[1], 5); - x_sufv[2] = _mm256_permute_pd( xv[2], 5); - x_sufv[3] = _mm256_permute_pd( xv[3], 5); - x_sufv[4] = _mm256_permute_pd( xv[4], 5); - x_sufv[5] = _mm256_permute_pd( xv[5], 5); - x_sufv[6] = _mm256_permute_pd( xv[6], 5); - x_sufv[7] = _mm256_permute_pd( xv[7], 5); - x_sufv[8] = _mm256_permute_pd( xv[8], 5); - x_sufv[9] = _mm256_permute_pd( xv[9], 5); - - // mul x <= x*alphaR - // aphhaR = ar , ar , ar , ar , .... - // x = xr , xi , xr , xi , .... - // mul = ar*xr, ar*xi , ar*xr , ar*xi, .... - xv[0] = _mm256_mul_pd( alphaRv, xv[0] ); - xv[1] = _mm256_mul_pd( alphaRv, xv[1] ); - xv[2] = _mm256_mul_pd( alphaRv, xv[2] ); - xv[3] = _mm256_mul_pd( alphaRv, xv[3] ); - xv[4] = _mm256_mul_pd( alphaRv, xv[4] ); - xv[5] = _mm256_mul_pd( alphaRv, xv[5] ); - xv[6] = _mm256_mul_pd( alphaRv, xv[6] ); - xv[7] = _mm256_mul_pd( alphaRv, xv[7] ); - xv[8] = _mm256_mul_pd( alphaRv, xv[8] ); - xv[9] = _mm256_mul_pd( alphaRv, xv[9] ); - - // fma x <= x_suf*alphaI + x - // alphaI = -ai , ai , -ai , ai .... - // X suf = xi , xr , xi , xr .... - // mul = -ai*xi, ai*xr , -ai*xi, ai*xi .... - // add x = ar*xr - ai*xi, ar*xi + ai*xr, .... - xv[0] = _mm256_fmadd_pd( alphaIv, x_sufv[0], xv[0] ); - xv[1] = _mm256_fmadd_pd( alphaIv, x_sufv[1], xv[1] ); - xv[2] = _mm256_fmadd_pd( alphaIv, x_sufv[2], xv[2] ); - xv[3] = _mm256_fmadd_pd( alphaIv, x_sufv[3], xv[3] ); - xv[4] = _mm256_fmadd_pd( alphaIv, x_sufv[4], xv[4] ); - xv[5] = _mm256_fmadd_pd( alphaIv, x_sufv[5], xv[5] ); - xv[6] = _mm256_fmadd_pd( alphaIv, x_sufv[6], xv[6] ); - xv[7] = _mm256_fmadd_pd( alphaIv, x_sufv[7], xv[7] ); - xv[8] = _mm256_fmadd_pd( alphaIv, x_sufv[8], xv[8] ); - xv[9] = _mm256_fmadd_pd( alphaIv, x_sufv[9], xv[9] ); - - // Store the output. - _mm256_storeu_pd( (x0 + 0*n_elem_per_reg), xv[0] ); - _mm256_storeu_pd( (x0 + 1*n_elem_per_reg), xv[1] ); - _mm256_storeu_pd( (x0 + 2*n_elem_per_reg), xv[2] ); - _mm256_storeu_pd( (x0 + 3*n_elem_per_reg), xv[3] ); - _mm256_storeu_pd( (x0 + 4*n_elem_per_reg), xv[4] ); - _mm256_storeu_pd( (x0 + 5*n_elem_per_reg), xv[5] ); - _mm256_storeu_pd( (x0 + 6*n_elem_per_reg), xv[6] ); - _mm256_storeu_pd( (x0 + 7*n_elem_per_reg), xv[7] ); - _mm256_storeu_pd( (x0 + 8*n_elem_per_reg), xv[8] ); - _mm256_storeu_pd( (x0 + 9*n_elem_per_reg), xv[9] ); - - x0 += 10*n_elem_per_reg; - } - - for ( ; (i + 9) < n; i += 10 ) - { - // Load the input values. 
- xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); - xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); - xv[2] = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); - xv[3] = _mm256_loadu_pd( x0 + 3*n_elem_per_reg ); - xv[4] = _mm256_loadu_pd( x0 + 4*n_elem_per_reg ); - - // x = xr0 , xi0, xr1, xi1 - // x_suf = xi0 , xr0, xi1, xr1 - x_sufv[0] = _mm256_permute_pd( xv[0], 5); - x_sufv[1] = _mm256_permute_pd( xv[1], 5); - x_sufv[2] = _mm256_permute_pd( xv[2], 5); - x_sufv[3] = _mm256_permute_pd( xv[3], 5); - x_sufv[4] = _mm256_permute_pd( xv[4], 5); - - // mul x <= x*alphaR - // aphhaR = ar , ar , ar , ar - // x = xr , xi , xr , xi - // mul = ar*xr, ar*xi , ar*xr , ar*xi - xv[0] = _mm256_mul_pd( alphaRv, xv[0] ); - xv[1] = _mm256_mul_pd( alphaRv, xv[1] ); - xv[2] = _mm256_mul_pd( alphaRv, xv[2] ); - xv[3] = _mm256_mul_pd( alphaRv, xv[3] ); - xv[4] = _mm256_mul_pd( alphaRv, xv[4] ); - - // fma x <= x_suf*alphaI + x - // alphaI = -ai , ai , -ai , ai - // X = xi , xr , xi , xr - // mul = -ai*xi, ai*xr , -ai*xi, ai*xi - // add x = ar*xr - ai*xi, ar*xi + ai*xr, - xv[0] = _mm256_fmadd_pd( alphaIv, x_sufv[0], xv[0] ); - xv[1] = _mm256_fmadd_pd( alphaIv, x_sufv[1], xv[1] ); - xv[2] = _mm256_fmadd_pd( alphaIv, x_sufv[2], xv[2] ); - xv[3] = _mm256_fmadd_pd( alphaIv, x_sufv[3], xv[3] ); - xv[4] = _mm256_fmadd_pd( alphaIv, x_sufv[4], xv[4] ); - - // Store the output. - _mm256_storeu_pd( (x0 + 0*n_elem_per_reg), xv[0] ); - _mm256_storeu_pd( (x0 + 1*n_elem_per_reg), xv[1] ); - _mm256_storeu_pd( (x0 + 2*n_elem_per_reg), xv[2] ); - _mm256_storeu_pd( (x0 + 3*n_elem_per_reg), xv[3] ); - _mm256_storeu_pd( (x0 + 4*n_elem_per_reg), xv[4] ); - - x0 += 5*n_elem_per_reg; - } - - for ( ; (i + 7) < n; i += 8 ) - { - // Load the input values. - xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); - xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); - xv[2] = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); - xv[3] = _mm256_loadu_pd( x0 + 3*n_elem_per_reg ); - - // x = xr0 , xi0, xr1, xi1 .... - // x_suf = xi0 , xr0, xi1, xr1 .... - x_sufv[0] = _mm256_permute_pd( xv[0], 5); - x_sufv[1] = _mm256_permute_pd( xv[1], 5); - x_sufv[2] = _mm256_permute_pd( xv[2], 5); - x_sufv[3] = _mm256_permute_pd( xv[3], 5); - - // mul x <= x*alphaR - // aphhaR = ar , ar , ar , ar , .... - // x = xr , xi , xr , xi , .... - // mul = ar*xr, ar*xi , ar*xr , ar*xi, .... - xv[0] = _mm256_mul_pd( alphaRv, xv[0] ); - xv[1] = _mm256_mul_pd( alphaRv, xv[1] ); - xv[2] = _mm256_mul_pd( alphaRv, xv[2] ); - xv[3] = _mm256_mul_pd( alphaRv, xv[3] ); - - // fma x <= x_suf*alphaI + x - // alphaI = -ai , ai , -ai , ai .... - // X = xi , xr , xi , xr .... - // mul = -ai*xi, ai*xr , -ai*xi, ai*xi .... - // add x = ar*xr - ai*xi, ar*xi + ai*xr, - xv[0] = _mm256_fmadd_pd( alphaIv, x_sufv[0], xv[0] ); - xv[1] = _mm256_fmadd_pd( alphaIv, x_sufv[1], xv[1] ); - xv[2] = _mm256_fmadd_pd( alphaIv, x_sufv[2], xv[2] ); - xv[3] = _mm256_fmadd_pd( alphaIv, x_sufv[3], xv[3] ); - - // Store the output. - _mm256_storeu_pd( (x0 + 0*n_elem_per_reg), xv[0] ); - _mm256_storeu_pd( (x0 + 1*n_elem_per_reg), xv[1] ); - _mm256_storeu_pd( (x0 + 2*n_elem_per_reg), xv[2] ); - _mm256_storeu_pd( (x0 + 3*n_elem_per_reg), xv[3] ); - - x0 += 4*n_elem_per_reg; - } - - - for ( ; (i + 3) < n; i += 4 ) - { - // Load the input values. - xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); - xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); - - // x = xr0 , xi0, xr1, xi1 .... - // x_suf = xi0 , xr0, xi1, xr1 .... 
- x_sufv[0] = _mm256_permute_pd( xv[0], 5); - x_sufv[1] = _mm256_permute_pd( xv[1], 5); - - // mul x <= x*alphaR - // aphhaR = ar , ar , ar , ar , .... - // x = xr , xi , xr , xi , .... - // mul = ar*xr, ar*xi , ar*xr , ar*xi, .... - xv[0] = _mm256_mul_pd( alphaRv, xv[0] ); - xv[1] = _mm256_mul_pd( alphaRv, xv[1] ); - - // fma x <= x_suf*alphaI + x - // alphaI = -ai , ai , -ai , ai .... - // X = xi , xr , xi , xr .... - // mul = -ai*xi, ai*xr , -ai*xi, ai*xi .... - // add x = ar*xr - ai*xi, ar*xi + ai*xr, - xv[0] = _mm256_fmadd_pd( alphaIv, x_sufv[0], xv[0] ); - xv[1] = _mm256_fmadd_pd( alphaIv, x_sufv[1], xv[1] ); - - // Store the output. - _mm256_storeu_pd( (x0 + 0*n_elem_per_reg), xv[0] ); - _mm256_storeu_pd( (x0 + 1*n_elem_per_reg), xv[1] ); - - x0 += 2*n_elem_per_reg; - } - - for ( ; (i + 1) < n; i += 2 ) - { - // Load the input values. - xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); - - // x = xr0 , xi0, xr1, xi1 .... - // x_suf = xi0 , xr0, xi1, xr1 .... - x_sufv[0] = _mm256_permute_pd( xv[0], 5); - - // mul x <= x*alphaR - // aphhaR = ar , ar , ar , ar , .... - // x = xr , xi , xr , xi , .... - // mul = ar*xr, ar*xi , ar*xr , ar*xi, .... - xv[0] = _mm256_mul_pd( alphaRv, xv[0] ); - - // fma x <= x_suf*alphaI + x - // alphaI = -ai , ai , -ai , ai .... - // X = xi , xr , xi , xr .... - // mul = -ai*xi, ai*xr , -ai*xi, ai*xi .... - // add x = ar*xr - ai*xi, ar*xi + ai*xr, - xv[0] = _mm256_fmadd_pd( alphaIv, x_sufv[0], xv[0] ); - - // Store the output. - _mm256_storeu_pd( (x0 + 0*n_elem_per_reg), xv[0] ); - - x0 += 1*n_elem_per_reg; - } - - for ( ; (i + 0) < n; i += 1 ) - { - double real; - - // real part: ( aR.xR - aIxI ) - real = *alpha0 * (*x0) - (*(alpha0 + 1)) * (*(x0+1)); - // img part: ( aR.xI + aI.xR ) - *(x0 + 1) = *alpha0 * (*(x0+1)) + (*(alpha0 + 1)) * (*x0); - - *x0 = real; - - x0 += 2; - } - } - else - { - const double alphar = *alpha0; - const double alphai = *(alpha0 + 1); - - if ( !bli_is_conj(conjx_use) ) // BLIS_NO_CONJUGATE - { - for ( i = 0; i < n; ++i ) - { - const double x0c = *x0; - const double x1c = *( x0 + 1 ); - - *x0 = alphar * x0c - alphai * x1c; - *(x0 + 1) = alphar * x1c + alphai * x0c; - - x0 += incx*2; - } - } - else // BLIS_CONJUGATE - { - for ( i = 0; i < n; ++i ) - { - const double x0c = *x0; - const double x1c = *( x0 + 1 ); - - *x0 = alphar * x0c + alphai * x1c; - *(x0 + 1) = alphai * x0c - alphar * x1c; - - x0 += incx*2; - } - } - } -} diff --git a/kernels/zen/bli_kernels_zen.h b/kernels/zen/bli_kernels_zen.h index 8845996962..914b5d631f 100644 --- a/kernels/zen/bli_kernels_zen.h +++ b/kernels/zen/bli_kernels_zen.h @@ -79,8 +79,6 @@ SCALV_KER_PROT( double, d, scalv_zen_int ) // scalv (intrinsics unrolled x10) SCALV_KER_PROT( float, s, scalv_zen_int10 ) SCALV_KER_PROT( double, d, scalv_zen_int10 ) -SCALV_KER_PROT( scomplex, c, scalv_zen_int10 ) -SCALV_KER_PROT( dcomplex, z, scalv_zen_int10 ) // swapv (intrinsics) SWAPV_KER_PROT(float, s, swapv_zen_int8 ) diff --git a/test/Makefile b/test/Makefile index 7521fb7f13..3370ce7157 100644 --- a/test/Makefile +++ b/test/Makefile @@ -155,7 +155,7 @@ CFLAGS += -I$(TEST_SRC_PATH) # # Define the operations we will test. 
-TEST_OPS := dotv axpyv scalv \ +TEST_OPS := dotv axpyv \ gemv ger hemv her her2 trmv trsv \ gemm hemm herk her2k trmm trsm \
From d683c224e8f5caf24b87fa406191eff35e2b0223 Mon Sep 17 00:00:00 2001 From: mkadavil Date: Mon, 11 Oct 2021 17:51:23 +0530 Subject: [PATCH 044/243] Workaround for perf regression observed for sgemm Details: - A perf regression was observed for certain m,n,k inputs where (m,n,k > 512) and (m > 4 * n) in BLIS 3.1. The root cause was traced to commit 11dfc176a3c422729f453f6c23204cf023e9954d, where BLIS_THREAD_RATIO_M was updated from 2 to 1. This change was not part of BLIS 3.0.6 and hence resulted in the new perf drop in 3.1. - This workaround doubles the m dimension that is passed as an argument to bli_rntm_set_ways_for_op, which is used to determine the ic,jc work split across threads. BLIS_THREAD_RATIO_M itself is not updated (to 2); instead, the same effect is induced via the doubled m dimension. AMD-Internal: [CPUPL-1909] Change-Id: I3b6ec4d4a22154289cb56d8f7db4cb60e5f34afe --- frame/3/gemm/bli_gemm_front.c | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-)
diff --git a/frame/3/gemm/bli_gemm_front.c b/frame/3/gemm/bli_gemm_front.c index 662a6da9bb..a065156bbf 100644 --- a/frame/3/gemm/bli_gemm_front.c +++ b/frame/3/gemm/bli_gemm_front.c @@ -173,7 +173,24 @@ void bli_gemm_front // or the inlined code above. bli_obj_swap_pack_schemas( &a_local, &b_local ); } - + + dim_t m_dim_local = bli_obj_length( &c_local ); + dim_t n_dim_local = bli_obj_width( &c_local ); + dim_t k_dim_local = bli_obj_width( &a_local ); +#ifdef BLIS_CONFIG_EPYC + // Regression observed in sgemm native path in cases where m >= 4 * n + // after BLIS_THREAD_RATIO_M updated from 2 to 1 as part of commit + // 11dfc176a3c422729f453f6c23204cf023e9954d. Temporary workaround for + // the issue. + if( bli_obj_is_float( &c_local ) && + ( n_dim_local >= 1024 ) && + ( k_dim_local >= 1024 ) && + ( m_dim_local >= ( 4 * n_dim_local ) ) ) + { + m_dim_local *= 2; + } +#endif + // Parse and interpret the contents of the rntm_t object to properly // set the ways of parallelism for each loop, and then make any // additional modifications necessary for the current operation. @@ -181,9 +198,9 @@ void bli_gemm_front ( BLIS_GEMM, BLIS_LEFT, // ignored for gemm/hemm/symm - bli_obj_length( &c_local ), - bli_obj_width( &c_local ), - bli_obj_width( &a_local ), + m_dim_local, + n_dim_local, + k_dim_local, rntm );
From ddbdfd0ba40d9b9a2962138271d1939cee3bf132 Mon Sep 17 00:00:00 2001 From: Dipal M Zambare Date: Wed, 27 Oct 2021 15:14:33 +0530 Subject: [PATCH 045/243] Fixed dynamic dispatch crash issue on non-zen architecture. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit fixes the issue for the gemm and copy APIs. The BLIS binary with the dynamic dispatch feature was crashing on non-zen CPUs (specifically CPUs without AVX2 support). The crash was caused by unsupported instructions in zen-optimized kernels. The issue is fixed by calling only reference kernels if the architecture detected at runtime is not zen, zen2 or zen3.
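In essence, each affected wrapper now guards its optimized path with a runtime architecture check of the following shape (a minimal sketch only; the helper name is illustrative and not part of the patch — the diffs below inline this check in each function):

    /* Requires "blis.h" for arch_t, bli_arch_query_id() and the BLIS_ARCH_* ids. */
    /* Illustrative helper; the actual changes inline this check per function.    */
    static bool is_zen_family_arch( void )
    {
        const arch_t id = bli_arch_query_id();
        return ( id == BLIS_ARCH_ZEN3 ) || ( id == BLIS_ARCH_ZEN2 ) || ( id == BLIS_ARCH_ZEN );
    }

    /* Callers branch on the result: AVX2-based zen kernels (e.g. bli_dcopyv_zen_int)
       when true, reference kernels via the framework and context when false.        */
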
AMD-Internal: [CPUPL-1930] Change-Id: Ief57cd457b87542aa1a7bad64dc36c01f0d1a366 --- frame/compat/bla_copy.c | 81 +++++++++++++++++++++++++++++++---------- frame/compat/bla_gemm.c | 58 ++++++++++++++++++++++++++++- 2 files changed, 119 insertions(+), 20 deletions(-) diff --git a/frame/compat/bla_copy.c b/frame/compat/bla_copy.c index cdfa580b66..61df88cf1e 100644 --- a/frame/compat/bla_copy.c +++ b/frame/compat/bla_copy.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -153,16 +153,37 @@ void scopy_ incy0 = (inc_t)(*incy); } + // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. + // This function is invoked on all architectures including ‘generic’. + // Invoke architecture specific kernels only if we are sure that we are running on zen, + // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). + arch_t id = bli_arch_query_id(); + bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - /* Call BLIS kernel */ - bli_scopyv_zen_int - ( - BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - NULL - ); + if (bamdzen) + { + /* Call BLIS kernel */ + bli_scopyv_zen_int + ( + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + NULL + ); + } + else + { + PASTEMAC2(s, copyv, BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + NULL, + NULL + ); + } AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) /* Finalize BLIS. */ @@ -232,16 +253,38 @@ void dcopy_ incy0 = (inc_t)(*incy); } + // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. + // This function is invoked on all architectures including ‘generic’. + // Invoke architecture specific kernels only if we are sure that we are running on zen, + // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). + arch_t id = bli_arch_query_id(); + bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + + if (bamdzen) + { + /* Call BLIS kernel */ + bli_dcopyv_zen_int + ( + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + NULL + ); + } + else + { + PASTEMAC2(d, copyv, BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + NULL, + NULL + ); + } - /* Call BLIS kernel */ - bli_dcopyv_zen_int - ( - BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - NULL - ); AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) /* Finalize BLIS. */ diff --git a/frame/compat/bla_gemm.c b/frame/compat/bla_gemm.c index 1bdb2397b2..50aa931a82 100644 --- a/frame/compat/bla_gemm.c +++ b/frame/compat/bla_gemm.c @@ -362,7 +362,63 @@ void dgemm_ const inc_t rs_c = 1; const inc_t cs_c = *ldc; - if((k0 == 1) && bli_is_notrans(blis_transa) && bli_is_notrans(blis_transb)) + // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. + // This function is invoked on all architectures including ‘generic’. + // Invoke architecture specific kernels only if we are sure that we are running on zen, + // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). 
+ arch_t id = bli_arch_query_id(); + bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + + if (!bamdzen) + { + // This code is duplicated below, however we don't want to move it out of + // this IF block as it will affect the performance on Zen architetures + // Also this is temporary fix which will be replaced later. + const num_t dt = BLIS_DOUBLE; + + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; + obj_t ao = BLIS_OBJECT_INITIALIZER; + obj_t bo = BLIS_OBJECT_INITIALIZER; + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; + obj_t co = BLIS_OBJECT_INITIALIZER; + + dim_t m0_a, n0_a; + dim_t m0_b, n0_b; + + bli_set_dims_with_trans(blis_transa, m0, k0, &m0_a, &n0_a); + bli_set_dims_with_trans(blis_transb, k0, n0, &m0_b, &n0_b); + + bli_obj_init_finish_1x1(dt, (double *)alpha, &alphao); + bli_obj_init_finish_1x1(dt, (double *)beta, &betao); + + bli_obj_init_finish(dt, m0_a, n0_a, (double *)a, rs_a, cs_a, &ao); + bli_obj_init_finish(dt, m0_b, n0_b, (double *)b, rs_b, cs_b, &bo); + bli_obj_init_finish(dt, m0, n0, (double *)c, rs_c, cs_c, &co); + + bli_obj_set_conjtrans(blis_transa, &ao); + bli_obj_set_conjtrans(blis_transb, &bo); + + // Will call parallelized dgemm code - sup & native + PASTEMAC(gemm, BLIS_OAPI_EX_SUF) + ( + &alphao, + &ao, + &bo, + &betao, + &co, + NULL, + NULL + ); + + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS. */ + bli_finalize_auto(); + return; + } + + if((k0 == 1) && bli_is_notrans(blis_transa) && bli_is_notrans(blis_transb)) { bli_dgemm_ref_k1_nn( m0, n0, k0, (double*)alpha, From 61b5b9c4d0f4b65250a407c8da45df5304b010dd Mon Sep 17 00:00:00 2001 From: Dipal M Zambare Date: Wed, 27 Oct 2021 16:49:42 +0530 Subject: [PATCH 046/243] Fixed dynamic dispatch crash issue on non-zen architecture. Removed direct calling of zen kernels in cblas source itself. Similar optimizations are done by the function directly invoked from Cblas layer. The BLIS binary with dynamic dispatch feature was crashing on non-zen CPUs (specifically CPUs without AVX2 support). The crash was caused by un-supported instructions in zen optimized kernels. The issue is fixed by calling only reference kernels if the architecture detected at runtime is not zen, zen2 or zen3. AMD-Internal: [CPUPL-1930] Change-Id: I9178b7a98f2563dee2817064f37fcbb84073eeea --- frame/compat/cblas/src/cblas_daxpy.c | 68 +---------------------- frame/compat/cblas/src/cblas_dcopy.c | 72 +------------------------ frame/compat/cblas/src/cblas_ddot.c | 77 ++------------------------- frame/compat/cblas/src/cblas_dscal.c | 50 +---------------- frame/compat/cblas/src/cblas_dswap.c | 71 +----------------------- frame/compat/cblas/src/cblas_idamax.c | 68 +---------------------- frame/compat/cblas/src/cblas_isamax.c | 68 +---------------------- frame/compat/cblas/src/cblas_saxpy.c | 65 +--------------------- frame/compat/cblas/src/cblas_scopy.c | 74 +------------------------ frame/compat/cblas/src/cblas_sdot.c | 77 ++------------------------- frame/compat/cblas/src/cblas_sscal.c | 57 +------------------- frame/compat/cblas/src/cblas_sswap.c | 72 +------------------------ 12 files changed, 23 insertions(+), 796 deletions(-) diff --git a/frame/compat/cblas/src/cblas_daxpy.c b/frame/compat/cblas/src/cblas_daxpy.c index eb47367676..a42b92ae08 100644 --- a/frame/compat/cblas/src/cblas_daxpy.c +++ b/frame/compat/cblas/src/cblas_daxpy.c @@ -7,7 +7,7 @@ * * Written by Keita Teranishi. 
2/11/1998 * - * Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. */ #include "cblas.h" #include "cblas_f77.h" @@ -22,72 +22,8 @@ void cblas_daxpy( f77_int N, double alpha, const double *X, #define F77_incY incY #endif -#ifdef BLIS_CONFIG_EPYC - dim_t n0; - double* x0; - double* y0; - inc_t incx0; - inc_t incy0; - - /* Initialize BLIS. */ -// bli_init_auto(); - - /* Convert/typecast negative values of n to zero. */ - if ( F77_N < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(F77_N); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - if ( F77_incX < 0 ) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. (Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. */ - x0 = ((double*)X) + (n0-1)*(-F77_incX); - incx0 = ( inc_t )(F77_incX); - } - else - { - x0 = ((double*)X); - incx0 = ( inc_t )(F77_incX); - } - - if ( F77_incY < 0 ) - { - y0 = ((double*)Y) + (n0-1)*(-F77_incY); - incy0 = ( inc_t )(F77_incY); - } - else - { - y0 = ((double*)Y); - incy0 = ( inc_t )(F77_incY); - } - - bli_daxpyv_zen_int10( - BLIS_NO_CONJUGATE, - n0, - (double*)&alpha, - x0, incx0, - y0, incy0, - NULL - ); - - /* Finalize BLIS. */ -// bli_finalize_auto(); - - -#else F77_daxpy( &F77_N, &alpha, X, &F77_incX, Y, &F77_incY); -#endif + } #endif diff --git a/frame/compat/cblas/src/cblas_dcopy.c b/frame/compat/cblas/src/cblas_dcopy.c index c0be6fc0f2..7a5dcaf6b5 100644 --- a/frame/compat/cblas/src/cblas_dcopy.c +++ b/frame/compat/cblas/src/cblas_dcopy.c @@ -7,7 +7,7 @@ * * Written by Keita Teranishi. 2/11/1998 * - * Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. * */ @@ -18,80 +18,12 @@ void cblas_dcopy( f77_int N, const double *X, { #ifdef F77_INT F77_INT F77_N=N, F77_incX=incX, F77_incY=incY; -#else +#else #define F77_N N #define F77_incX incX #define F77_incY incY #endif -#ifdef BLIS_CONFIG_EPYC - dim_t n0; - double* x0; - double* y0; - inc_t incx0; - inc_t incy0; - - /* Initialize BLIS. */ -// bli_init_auto(); - - /* Convert/typecast negative values of n to zero. */ - if ( F77_N < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(F77_N); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - if ( F77_incX < 0 ) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. (Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. 
The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. */ - - x0 = (double*)((X) + (n0-1)*(-F77_incX)); - incx0 = ( inc_t )(F77_incX); - - } - else - { - x0 = (double*)(X); - incx0 = ( inc_t )(F77_incX); - } - if ( F77_incY < 0 ) - { - y0 = (Y) + (n0-1)*(-F77_incY); - incy0 = ( inc_t )(F77_incY); - - } - else - { - y0 = (Y); - incy0 = ( inc_t )(F77_incY); - } - - - /* Call BLIS kernel */ - bli_dcopyv_zen_int - ( - BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - NULL - ); - - /* Finalize BLIS. */ -// bli_finalize_auto(); -#else F77_dcopy( &F77_N, X, &F77_incX, Y, &F77_incY); -#endif - } #endif diff --git a/frame/compat/cblas/src/cblas_ddot.c b/frame/compat/cblas/src/cblas_ddot.c index fd16ad7615..47fc9efb1a 100644 --- a/frame/compat/cblas/src/cblas_ddot.c +++ b/frame/compat/cblas/src/cblas_ddot.c @@ -8,7 +8,7 @@ * * Written by Keita Teranishi. 2/11/1998 * - * Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. * */ #include "cblas.h" @@ -20,85 +20,14 @@ double cblas_ddot( f77_int N, const double *X, double dot; #ifdef F77_INT F77_INT F77_N=N, F77_incX=incX, F77_incY=incY; -#else +#else #define F77_N N #define F77_incX incX #define F77_incY incY #endif -#ifdef BLIS_CONFIG_EPYC - dim_t n0; - double* x0; - double* y0; - inc_t incx0; - inc_t incy0; - - /* Initialize BLIS. */ -// bli_init_auto(); - - /* Convert/typecast negative values of n to zero. */ - if ( F77_N < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(F77_N); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - - if ( F77_incX < 0 ) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. (Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. */ - - x0 = ((double*)X) + (n0-1)*(-F77_incX); - incx0 = ( inc_t )(F77_incX); - } - else - { - x0 = ((double*)X); - incx0 = ( inc_t )(F77_incX); - } - - if ( F77_incY < 0 ) - { - y0 = ((double*)Y) + (n0-1)*(-F77_incY); - incy0 = ( inc_t )(F77_incY); - - } - else - { - y0 = ((double*)Y); - incy0 = ( inc_t )(F77_incY); - } - /* Call BLIS kernel. */ - bli_ddotv_zen_int10 - ( - BLIS_NO_CONJUGATE, - BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - &dot, - NULL - ); - - /* Finalize BLIS. 
*/ -// bli_finalize_auto(); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return dot; - -#else F77_ddot_sub( &F77_N, X, &F77_incX, Y, &F77_incY, &dot); AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); return dot; -#endif -} +} #endif diff --git a/frame/compat/cblas/src/cblas_dscal.c b/frame/compat/cblas/src/cblas_dscal.c index e0e3b29b44..88c5b3fa0a 100644 --- a/frame/compat/cblas/src/cblas_dscal.c +++ b/frame/compat/cblas/src/cblas_dscal.c @@ -8,7 +8,7 @@ * Written by Keita Teranishi. 2/11/1998 * * - * Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. */ #include "cblas.h" #include "cblas_f77.h" @@ -22,56 +22,8 @@ void cblas_dscal( f77_int N, double alpha, double *X, #define F77_incX incX #define F77_incY incY #endif -#ifdef BLIS_CONFIG_EPYC - dim_t n0; - double* x0; - inc_t incx0; - /* Initialize BLIS. */ -// bli_init_auto(); - - if ( F77_N < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(F77_N); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - if ( F77_incX < 0 ) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. (Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. */ - - x0 = (X) + (n0-1)*(-F77_incX); - incx0 = ( inc_t )(F77_incX); - - } - else - { - x0 = (X); - incx0 = ( inc_t )(F77_incX); - } - - /* Call BLIS kernel */ - bli_dscalv_zen_int10 - ( - BLIS_NO_CONJUGATE, - n0, - &alpha, - x0, incx0, - NULL - ); -#else F77_dscal( &F77_N, &alpha, X, &F77_incX); -#endif } #endif diff --git a/frame/compat/cblas/src/cblas_dswap.c b/frame/compat/cblas/src/cblas_dswap.c index 5a5ccbf146..1432d59ae9 100644 --- a/frame/compat/cblas/src/cblas_dswap.c +++ b/frame/compat/cblas/src/cblas_dswap.c @@ -7,7 +7,7 @@ * * Written by Keita Teranishi. 2/11/1998 * - * Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. * */ #include "cblas.h" @@ -17,79 +17,12 @@ void cblas_dswap( f77_int N, double *X, f77_int incX, double *Y, { #ifdef F77_INT F77_INT F77_N=N, F77_incX=incX, F77_incY=incY; -#else +#else #define F77_N N #define F77_incX incX #define F77_incY incY #endif -#ifdef BLIS_CONFIG_EPYC - dim_t n0; - double* x0; - double* y0; - inc_t incx0; - inc_t incy0; - - /* Initialize BLIS. */ -// bli_init_auto(); - - /* Convert/typecast negative values of n to zero. */ - if ( F77_N < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(F77_N); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - if ( F77_incX < 0 ) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. (Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) 
This - is also how BLIS interprets negative strides. The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. */ - - x0 = (X) + (n0-1)*(-F77_incX); - incx0 = ( inc_t )(F77_incX); - - } - else - { - x0 = (X); - incx0 = ( inc_t )(F77_incX); - } - - if ( F77_incY < 0 ) - { - y0 = (Y) + (n0-1)*(-F77_incY); - incy0 = ( inc_t )(F77_incY); - - } - else - { - y0 = (Y); - incy0 = ( inc_t )(F77_incY); - } - - - /* Call BLIS kernel */ - bli_dswapv_zen_int8 - ( - n0, - x0, incx0, - y0, incy0, - NULL - ); - - /* Finalize BLIS. */ -// bli_finalize_auto(); -#else F77_dswap( &F77_N, X, &F77_incX, Y, &F77_incY); -#endif } #endif diff --git a/frame/compat/cblas/src/cblas_idamax.c b/frame/compat/cblas/src/cblas_idamax.c index 071482c364..46d7d93774 100644 --- a/frame/compat/cblas/src/cblas_idamax.c +++ b/frame/compat/cblas/src/cblas_idamax.c @@ -7,7 +7,7 @@ * It calls the fortran wrapper before calling idamax. * * Written by Keita Teranishi. 2/11/1998 - * Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. * */ #include "cblas.h" @@ -17,76 +17,12 @@ f77_int cblas_idamax( f77_int N, const double *X, f77_int incX) f77_int iamax; #ifdef F77_INT F77_INT F77_N=N, F77_incX=incX; -#else +#else #define F77_N N #define F77_incX incX #endif -#ifdef BLIS_CONFIG_EPYC - dim_t n0; - double* x0; - inc_t incx0; - gint_t bli_index; - - /* If the vector is empty, return an index of zero. This early check - is needed to emulate netlib BLAS. Without it, bli_?amaxv() will - return 0, which ends up getting incremented to 1 (below) before - being returned, which is not what we want. */ - if ( F77_N < 1 || F77_incX <= 0 ) return 0; - - /* Initialize BLIS. */ -// bli_init_auto(); - - /* Convert/typecast negative values of n to zero. */ - if ( F77_N < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(F77_N); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - if ( F77_incX < 0 ) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. (Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. */ - - x0 = ((double*)X) + (n0-1)*(-F77_incX); - incx0 = ( inc_t )(F77_incX); - - } - else - { - x0 = ((double*)X); - incx0 = ( inc_t )(F77_incX); - } - - /* Call BLIS kernel. */ - bli_damaxv_zen_int - ( - n0, - x0, incx0, - &bli_index, - NULL - ); - - /* Finalize BLIS. 
*/ -// bli_finalize_auto(); - - iamax = bli_index; - - return iamax; - -#else F77_idamax_sub( &F77_N, X, &F77_incX, &iamax); return iamax ? iamax-1 : 0; -#endif } #endif diff --git a/frame/compat/cblas/src/cblas_isamax.c b/frame/compat/cblas/src/cblas_isamax.c index 81d13d0990..f41e430970 100644 --- a/frame/compat/cblas/src/cblas_isamax.c +++ b/frame/compat/cblas/src/cblas_isamax.c @@ -7,7 +7,7 @@ * It calls the fortran wrapper before calling isamax. * * Written by Keita Teranishi. 2/11/1998 - * Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. * */ #include "cblas.h" @@ -17,76 +17,12 @@ f77_int cblas_isamax( f77_int N, const float *X, f77_int incX) f77_int iamax; #ifdef F77_INT F77_INT F77_N=N, F77_incX=incX; -#else +#else #define F77_N N #define F77_incX incX #endif -#ifdef BLIS_CONFIG_EPYC - dim_t n0; - float* x0; - inc_t incx0; - gint_t bli_index; - - /* If the vector is empty, return an index of zero. This early check - is needed to emulate netlib BLAS. Without it, bli_?amaxv() will - return 0, which ends up getting incremented to 1 (below) before - being returned, which is not what we want. */ - if ( F77_N < 1 || F77_incX <= 0 ) return 0; - - /* Initialize BLIS. */ -// bli_init_auto(); - - /* Convert/typecast negative values of n to zero. */ - if ( F77_N < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(F77_N); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - if ( F77_incX < 0 ) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. (Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. */ - - x0 = ((float*)X) + (n0-1)*(-F77_incX); - incx0 = ( inc_t )(F77_incX); - - } - else - { - x0 = ((float*)X); - incx0 = ( inc_t )(F77_incX); - } - - /* Call BLIS kernel. */ - bli_samaxv_zen_int - ( - n0, - x0, incx0, - &bli_index, - NULL - ); - - /* Finalize BLIS. */ -// bli_finalize_auto(); - - iamax = bli_index; - - return iamax; - -#else F77_isamax_sub( &F77_N, X, &F77_incX, &iamax); return iamax ? iamax-1 : 0; -#endif } #endif diff --git a/frame/compat/cblas/src/cblas_saxpy.c b/frame/compat/cblas/src/cblas_saxpy.c index 8c5ace43fa..db6b21b855 100644 --- a/frame/compat/cblas/src/cblas_saxpy.c +++ b/frame/compat/cblas/src/cblas_saxpy.c @@ -8,7 +8,7 @@ * * Written by Keita Teranishi. 2/11/1998 * - * Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. */ #include "cblas.h" @@ -24,70 +24,7 @@ void cblas_saxpy( f77_int N, float alpha, const float *X, #define F77_incY incY #endif -#ifdef BLIS_CONFIG_EPYC - dim_t n0; - float* x0; - float* y0; - inc_t incx0; - inc_t incy0; - - /* Initialize BLIS. */ -// bli_init_auto(); - - /* Convert/typecast negative values of n to zero. 
*/ - if ( F77_N < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(F77_N); - - if ( F77_incX < 0 ) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. (Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. */ - x0 = ((float*)X) + (n0-1)*(-F77_incX); - incx0 = ( inc_t )(F77_incX); - } - else - { - x0 = ((float*)X); - incx0 = ( inc_t )(F77_incX); - } - - if ( F77_incY < 0 ) - { - y0 = ((float*)Y) + (n0-1)*(-F77_incY); - incy0 = ( inc_t )(F77_incY); - } - else - { - y0 = ((float*)Y); - incy0 = ( inc_t )(F77_incY); - } - - bli_saxpyv_zen_int10( - BLIS_NO_CONJUGATE, - n0, - (float*)&alpha, - x0, incx0, - y0, incy0, - NULL - ); - - /* Finalize BLIS. */ -// bli_finalize_auto(); - -#else F77_saxpy( &F77_N, &alpha, X, &F77_incX, Y, &F77_incY); -#endif - } #endif diff --git a/frame/compat/cblas/src/cblas_scopy.c b/frame/compat/cblas/src/cblas_scopy.c index 518d4f6295..23c78e0dc2 100644 --- a/frame/compat/cblas/src/cblas_scopy.c +++ b/frame/compat/cblas/src/cblas_scopy.c @@ -7,7 +7,7 @@ * * Written by Keita Teranishi. 2/11/1998 * - * Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. * */ #include "cblas.h" @@ -17,82 +17,12 @@ void cblas_scopy( f77_int N, const float *X, { #ifdef F77_INT F77_INT F77_N=N, F77_incX=incX, F77_incY=incY; -#else +#else #define F77_N N #define F77_incX incX #define F77_incY incY #endif -#ifdef BLIS_CONFIG_EPYC - - dim_t n0; - float* x0; - float* y0; - inc_t incx0; - inc_t incy0; - - /* Initialize BLIS. */ -// bli_init_auto(); - - /* Convert/typecast negative values of n to zero. */ - if ( F77_N < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(F77_N); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - if ( F77_incX < 0 ) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. (Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. 
*/ - - x0 = (float*)((X) + (n0-1)*(-F77_incX)); - incx0 = ( inc_t )(F77_incX); - - } - else - { - x0 = (float*)(X); - incx0 = ( inc_t )(F77_incX); - } - - if ( F77_incY < 0 ) - { - y0 = (Y) + (n0-1)*(-F77_incY); - incy0 = ( inc_t )(F77_incY); - } - else - { - y0 = (Y); - incy0 = ( inc_t )(F77_incY); - } - - - /* Call BLIS kernel */ - bli_scopyv_zen_int - ( - BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - NULL - ); - - /* Finalize BLIS. */ -// bli_finalize_auto(); - -#else F77_scopy( &F77_N, X, &F77_incX, Y, &F77_incY); -#endif - } #endif diff --git a/frame/compat/cblas/src/cblas_sdot.c b/frame/compat/cblas/src/cblas_sdot.c index 970eda42d4..7504597470 100644 --- a/frame/compat/cblas/src/cblas_sdot.c +++ b/frame/compat/cblas/src/cblas_sdot.c @@ -8,7 +8,7 @@ * * Written by Keita Teranishi. 2/11/1998 * - * Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. * */ #include "cblas.h" @@ -20,85 +20,14 @@ float cblas_sdot( f77_int N, const float *X, float dot; #ifdef F77_INT F77_INT F77_N=N, F77_incX=incX, F77_incY=incY; -#else +#else #define F77_N N #define F77_incX incX #define F77_incY incY #endif -#ifdef BLIS_CONFIG_EPYC - dim_t n0; - float* x0; - float* y0; - inc_t incx0; - inc_t incy0; - - /* Initialize BLIS. */ -// bli_init_auto(); - - /* Convert/typecast negative values of n to zero. */ - if ( F77_N < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(F77_N); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - - if ( F77_incX < 0 ) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. (Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. */ - - x0 = ((float*)X) + (n0-1)*(-F77_incX); - incx0 = ( inc_t )(F77_incX); - } - else - { - x0 = ((float*)X); - incx0 = ( inc_t )(F77_incX); - } - - if ( F77_incY < 0 ) - { - y0 = ((float*)Y) + (n0-1)*(-F77_incY); - incy0 = ( inc_t )(F77_incY); - - } - else - { - y0 = ((float*)Y); - incy0 = ( inc_t )(F77_incY); - } - - /* Call BLIS kernel. */ - bli_sdotv_zen_int10 - ( - BLIS_NO_CONJUGATE, - BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - &dot, - NULL - ); - - /* Finalize BLIS. */ -// bli_finalize_auto(); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return dot; -#else F77_sdot_sub( &F77_N, X, &F77_incX, Y, &F77_incY, &dot); AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); return dot; -#endif -} +} #endif diff --git a/frame/compat/cblas/src/cblas_sscal.c b/frame/compat/cblas/src/cblas_sscal.c index 6c4de46830..b1b4cb471b 100644 --- a/frame/compat/cblas/src/cblas_sscal.c +++ b/frame/compat/cblas/src/cblas_sscal.c @@ -7,7 +7,7 @@ * * Written by Keita Teranishi. 2/11/1998 * - * Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. 
* */ #include "cblas.h" @@ -22,63 +22,8 @@ void cblas_sscal( f77_int N, float alpha, float *X, #define F77_incX incX #define F77_incY incY #endif -#ifdef BLIS_CONFIG_EPYC - dim_t n0; - float* x0; - inc_t incx0; - - /* Initialize BLIS. */ - //bli_init_auto(); - - /* Convert/typecast negative values of n to zero. */ - if ( F77_N < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(F77_N); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - if ( F77_incX < 0 ) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. (Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. */ - - x0 = (X) + (n0-1)*(-F77_incX); - incx0 = ( inc_t )(F77_incX); - - } - else - { - x0 = (X); - incx0 = ( inc_t )(F77_incX); - } - - - /* Call BLIS kernel */ - bli_sscalv_zen_int10 - ( - BLIS_NO_CONJUGATE, - n0, - &alpha, - x0, incx0, - NULL - ); - - /* Finalize BLIS. */ -// bli_finalize_auto(); - -#else F77_sscal( &F77_N, &alpha, X, &F77_incX); -#endif } #endif diff --git a/frame/compat/cblas/src/cblas_sswap.c b/frame/compat/cblas/src/cblas_sswap.c index c09e154c00..d352ee96a8 100644 --- a/frame/compat/cblas/src/cblas_sswap.c +++ b/frame/compat/cblas/src/cblas_sswap.c @@ -7,7 +7,7 @@ * * Written by Keita Teranishi. 2/11/1998 * - * Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. * */ #include "cblas.h" @@ -17,81 +17,13 @@ void cblas_sswap( f77_int N, float *X, f77_int incX, float *Y, { #ifdef F77_INT F77_INT F77_N=N, F77_incX=incX, F77_incY=incY; -#else +#else #define F77_N N #define F77_incX incX #define F77_incY incY #endif -#ifdef BLIS_CONFIG_EPYC - - dim_t n0; - float* x0; - float* y0; - inc_t incx0; - inc_t incy0; - - /* Initialize BLIS. */ -// bli_init_auto(); - - /* Convert/typecast negative values of n to zero. */ - if ( F77_N < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(F77_N); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - if ( F77_incX < 0 ) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. (Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. 
*/ - - x0 = (X) + (n0-1)*(-F77_incX); - incx0 = ( inc_t )(F77_incX); - - } - else - { - x0 = (X); - incx0 = ( inc_t )(F77_incX); - } - if ( F77_incY < 0 ) - { - y0 = (Y) + (n0-1)*(-F77_incY); - incy0 = ( inc_t )(F77_incY); - - } - else - { - y0 = (Y); - incy0 = ( inc_t )(F77_incY); - } - - - /* Call BLIS kernel */ - bli_sswapv_zen_int8 - ( - n0, - x0, incx0, - y0, incy0, - NULL - ); - - /* Finalize BLIS. */ -// bli_finalize_auto(); - -#else F77_sswap( &F77_N, X, &F77_incX, Y, &F77_incY); -#endif } #endif From 8f297f6267cb955230843b40dfa8e848243dada3 Mon Sep 17 00:00:00 2001 From: Harsh Dave Date: Wed, 27 Oct 2021 05:27:23 -0500 Subject: [PATCH 047/243] Fixed dynamic dispatch crash issue on non-zen architecture. Removed direct calling of zen kernels in blis interface for trsm, scalv, swapv. The BLIS binary with dynamic dispatch feature was crashing on non-zen CPUs (specifically CPUs without AVX2 support). The crash was caused by un-supported instructions in zen optimized kernels. The issue is fixed by calling only reference kernels if the architecture detected at runtime is not zen, zen2 or zen3. AMD-Internal: [CPUPL-1930] Change-Id: I7944d131d376e2c4e778fe441a8b030674952b81 --- frame/2/trsv/bli_trsv_unf_var1.c | 29 ++++++-- frame/2/trsv/bli_trsv_unf_var2.c | 31 +++++++-- frame/compat/bla_scal.c | 64 ++++++++++++----- frame/compat/bla_swap.c | 61 ++++++++++++----- frame/compat/bla_trsm.c | 114 ++++++++++++++++--------------- 5 files changed, 203 insertions(+), 96 deletions(-) diff --git a/frame/2/trsv/bli_trsv_unf_var1.c b/frame/2/trsv/bli_trsv_unf_var1.c index 0d0de9eb75..35b48d8a92 100644 --- a/frame/2/trsv/bli_trsv_unf_var1.c +++ b/frame/2/trsv/bli_trsv_unf_var1.c @@ -297,8 +297,18 @@ void bli_dtrsv_unf_var1 PASTECH(d,dotxf_ker_ft) kfp_df; /* Assign kernel function pointer and fusing factor. */ - kfp_df = bli_ddotxf_zen_int_8; - b_fuse = 8; + arch_t id = bli_arch_query_id(); + bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + if (bamdzen) { + kfp_df = bli_ddotxf_zen_int_8; + b_fuse = 8; + } + else + { + num_t dt = PASTEMAC(d,type); + kfp_df = bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXF_KER, cntx ); + b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_DF, cntx ); + } /* We reduce all of the possible cases down to just lower/upper. */ if ( bli_is_upper( uploa_trans ) ) @@ -488,8 +498,19 @@ void bli_strsv_unf_var1 PASTECH(s,dotxf_ker_ft) kfp_df; /* Assign kernel function pointer and fusing factor. */ - kfp_df = bli_sdotxf_zen_int_8; - b_fuse = 8; + arch_t id = bli_arch_query_id(); + bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + if (bamdzen) { + kfp_df = bli_sdotxf_zen_int_8; + b_fuse = 8; + } + else + { + num_t dt = PASTEMAC(s,type); + kfp_df = bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXF_KER, cntx ); + b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_DF, cntx ); + + } /* We reduce all of the possible cases down to just lower/upper. */ if ( bli_is_upper( uploa_trans ) ) diff --git a/frame/2/trsv/bli_trsv_unf_var2.c b/frame/2/trsv/bli_trsv_unf_var2.c index 7752417c41..58135b0752 100644 --- a/frame/2/trsv/bli_trsv_unf_var2.c +++ b/frame/2/trsv/bli_trsv_unf_var2.c @@ -293,8 +293,18 @@ void bli_dtrsv_unf_var2 PASTECH(d,axpyf_ker_ft) kfp_af; /* Assign kernel function pointer and fusing factor. 
*/ - kfp_af = bli_daxpyf_zen_int_16x4; - b_fuse = 4; + arch_t id = bli_arch_query_id(); + bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + if (bamdzen) { + kfp_af = bli_daxpyf_zen_int_16x4; + b_fuse = 4; + } + else + { + num_t dt = PASTEMAC(d,type); + kfp_af = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPYF_KER, cntx ); + b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_AF, cntx ); + } /* We reduce all of the possible cases down to just lower/upper. */ if ( bli_is_upper( uploa_trans ) ) @@ -479,8 +489,21 @@ void bli_strsv_unf_var2 PASTECH(s, axpyf_ker_ft) kfp_af; /* Assign function pointer and fusing factor. */ - kfp_af = bli_saxpyf_zen_int_5; - b_fuse = 5; +/* kfp_af = bli_saxpyf_zen_int_5; + b_fuse = 5;*/ + arch_t id = bli_arch_query_id(); + bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + if (bamdzen) { + kfp_af = bli_saxpyf_zen_int_5; + b_fuse = 5; + } + else + { + + num_t dt = PASTEMAC(s,type); + kfp_af = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPYF_KER, cntx ); + b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_AF, cntx ); + } /* We reduce all of the possible cases down to just lower/upper. */ if ( bli_is_upper( uploa_trans ) ) diff --git a/frame/compat/bla_scal.c b/frame/compat/bla_scal.c index 184b14eda0..30fd857bc7 100644 --- a/frame/compat/bla_scal.c +++ b/frame/compat/bla_scal.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020-21, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -146,14 +146,29 @@ void sscal_ incx0 = ( inc_t )(*incx); } /* Call BLIS kernel */ - bli_sscalv_zen_int10 - ( - BLIS_NO_CONJUGATE, - n0, - (float *)alpha, - x0, incx0, - NULL - ); + arch_t id = bli_arch_query_id(); + bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + if (bamdzen) { + bli_sscalv_zen_int10 + ( + BLIS_NO_CONJUGATE, + n0, + (float *)alpha, + x0, incx0, + NULL + ); + } + else{ + PASTEMAC2(s,scalv,BLIS_TAPI_EX_SUF) \ + ( \ + BLIS_NO_CONJUGATE,\ + n0, \ + (float *)alpha,\ + x0, incx0,\ + NULL, \ + NULL \ + );\ + } /* Finalize BLIS. */ // bli_finalize_auto(); @@ -212,14 +227,29 @@ void dscal_ incx0 = ( inc_t )(*incx); } /* Call BLIS kernel */ - bli_dscalv_zen_int10 - ( - BLIS_NO_CONJUGATE, - n0, - (double*) alpha, - x0, incx0, - NULL - ); + arch_t id = bli_arch_query_id(); + bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + if (bamdzen){ + bli_dscalv_zen_int10 + ( + BLIS_NO_CONJUGATE, + n0, + (double*) alpha, + x0, incx0, + NULL + ); + } + else{ + PASTEMAC2(d,scalv,BLIS_TAPI_EX_SUF) \ + ( \ + BLIS_NO_CONJUGATE,\ + n0, \ + (double *)alpha,\ + x0, incx0,\ + NULL, \ + NULL \ + );\ + } /* Finalize BLIS. */ // bli_finalize_auto(); diff --git a/frame/compat/bla_swap.c b/frame/compat/bla_swap.c index a48783e849..6ecb360f95 100644 --- a/frame/compat/bla_swap.c +++ b/frame/compat/bla_swap.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020-21, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -146,15 +146,28 @@ void sswap_ incy0 = ( inc_t )(*incy); } - - /* Call BLIS kernel */ - bli_sswapv_zen_int8 - ( - n0, - x0, incx0, - y0, incy0, - NULL - ); + arch_t id = bli_arch_query_id(); + bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + if (bamdzen) { +/* Call BLIS kernel */ + bli_sswapv_zen_int8 + ( + n0, + x0, incx0, + y0, incy0, + NULL + ); + } + else{ + PASTEMAC2(s,swapv,BLIS_TAPI_EX_SUF) \ + ( \ + n0, \ + x0, incx0, \ + y0, incy0, \ + NULL, \ + NULL \ + ); \ + } /* Finalize BLIS. */ // bli_finalize_auto(); @@ -224,13 +237,27 @@ void dswap_ /* Call BLIS kernel */ - bli_dswapv_zen_int8 - ( - n0, - x0, incx0, - y0, incy0, - NULL - ); + arch_t id = bli_arch_query_id(); + bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + if (bamdzen) { + bli_dswapv_zen_int8 + ( + n0, + x0, incx0, + y0, incy0, + NULL + ); + } + else{ + PASTEMAC2(d,swapv,BLIS_TAPI_EX_SUF) \ + ( \ + n0, \ + x0, incx0, \ + y0, incy0, \ + NULL, \ + NULL \ + ); \ + } /* Finalize BLIS. */ // bli_finalize_auto(); diff --git a/frame/compat/bla_trsm.c b/frame/compat/bla_trsm.c index 0ee6e180c2..654d3530d2 100644 --- a/frame/compat/bla_trsm.c +++ b/frame/compat/bla_trsm.c @@ -595,35 +595,38 @@ void strsm_ bli_obj_set_struc( struca, &ao ); + arch_t id = bli_arch_query_id(); + bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + if (bamdzen) { #ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM - /* bli_strsm_small is performing better existing native - * implementations for [m,n]<=1000 for single thread. - * In case of multithread when [m,n]<=128 sinlge thread implemenation - * is doing better than native multithread */ - bool nt = bli_thread_get_is_parallel(); - if((nt==0 && m0<=1000 && n0<=1000) || - (nt && (m0+n0)<320) ) - { - err_t status; - status = bli_trsm_small - ( - blis_side, - &alphao, - &ao, - &bo, - NULL, - NULL - ); - if (status == BLIS_SUCCESS) - { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - /* Finalize BLIS. */ - bli_finalize_auto(); - return; - } - } + /* bli_strsm_small is performing better existing native + * implementations for [m,n]<=1000 for single thread. + * In case of multithread when [m,n]<=128 sinlge thread implemenation + * is doing better than native multithread */ + bool nt = bli_thread_get_is_parallel(); + if((nt==0 && m0<=1000 && n0<=1000) || + (nt && (m0+n0)<320) ) + { + err_t status; + status = bli_trsm_small + ( + blis_side, + &alphao, + &ao, + &bo, + NULL, + NULL + ); + if (status == BLIS_SUCCESS) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + /* Finalize BLIS. */ + bli_finalize_auto(); + return; + } + } #endif - + } bli_trsmnat ( blis_side, @@ -853,35 +856,38 @@ void dtrsm_ bli_obj_set_struc( struca, &ao ); + arch_t id = bli_arch_query_id(); + bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + if (bamdzen) { #ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM - /* bli_dtrsm_small is performing better existing native - * implementations for [m,n]<=1000 for single thread. 
- * In case of multithread when [m,n]<=128 sinlge thread implemenation - * is doing better than native multithread */ - bool nt = bli_thread_get_is_parallel(); - if((nt==0 && m0<=1000 && n0<=1000) || - (nt && (m0+n0)<320) ) - { - err_t status; - status = bli_trsm_small - ( - blis_side, - &alphao, - &ao, - &bo, - NULL, - NULL - ); - if (status == BLIS_SUCCESS) - { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - /* Finalize BLIS. */ - bli_finalize_auto(); - return; - } - } + /* bli_dtrsm_small is performing better existing native + * implementations for [m,n]<=1000 for single thread. + * In case of multithread when [m,n]<=128 sinlge thread implemenation + * is doing better than native multithread */ + bool nt = bli_thread_get_is_parallel(); + if((nt==0 && m0<=1000 && n0<=1000) || + (nt && (m0+n0)<320) ) + { + err_t status; + status = bli_trsm_small + ( + blis_side, + &alphao, + &ao, + &bo, + NULL, + NULL + ); + if (status == BLIS_SUCCESS) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + /* Finalize BLIS. */ + bli_finalize_auto(); + return; + } + } #endif - + } bli_trsmnat ( blis_side, From 4de6f2ca6d6c18f2115ea0e7409c3ff96374a758 Mon Sep 17 00:00:00 2001 From: Harihara Sudhan S Date: Wed, 27 Oct 2021 17:41:44 +0530 Subject: [PATCH 048/243] Fixed dynamic dispatch crash issue on non-zen architecture. Direct calls to zen kernels replaced by architecture dependent calls for dotv and amaxv kernels. For non-zen architecture, generic function is called using the BLIS interface. For zen architecture, direct calls to zen optimized kernels are made. Change-Id: I49fc9abc813434d6a49a23f49e47d16e95b7899f --- frame/compat/bla_amax.c | 78 ++++++++--- frame/compat/bla_dot.c | 289 ++++++++++++++++++++++++++++++---------- 2 files changed, 282 insertions(+), 85 deletions(-) diff --git a/frame/compat/bla_amax.c b/frame/compat/bla_amax.c index 894f1e1fc1..fabed6e72d 100644 --- a/frame/compat/bla_amax.c +++ b/frame/compat/bla_amax.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018-2021, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -35,7 +35,6 @@ #include "blis.h" - // // Define BLAS-to-BLIS interfaces. // @@ -107,6 +106,7 @@ f77_int isamax_ const float* x, const f77_int* incx ) { + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); AOCL_DTL_LOG_AMAX_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'S', *n, *incx); @@ -159,15 +159,36 @@ f77_int isamax_ incx0 = ( inc_t )(*incx); } - /* Call BLIS kernel. */ - bli_samaxv_zen_int - ( - n0, - x0, incx0, - &bli_index, - NULL - ); + // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. + // This function is invoked on all architectures including ‘generic’. + // Invoke architecture specific kernels only if we are sure that we are running on zen, + // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). + arch_t id = bli_arch_query_id(); + bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + if (bamdzen) + { + /* Call BLIS kernel */ + bli_samaxv_zen_int + ( + n0, + x0, incx0, + &bli_index, + NULL + ); + } + else + { + PASTEMAC2(s,amaxv,BLIS_TAPI_EX_SUF) + ( + n0, + x0, incx0, + &bli_index, + NULL, + NULL + ); + } + /* Convert zero-based BLIS (C) index to one-based BLAS (Fortran) index. 
Also, if the BLAS integer size differs from the BLIS integer size, that typecast occurs here. */ @@ -239,14 +260,35 @@ f77_int idamax_ incx0 = ( inc_t )(*incx); } - /* Call BLIS kernel. */ - bli_damaxv_zen_int - ( - n0, - x0, incx0, - &bli_index, - NULL - ); + // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. + // This function is invoked on all architectures including ‘generic’. + // Invoke architecture specific kernels only if we are sure that we are running on zen, + // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). + arch_t id = bli_arch_query_id(); + bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + + if (bamdzen) + { + /* Call BLIS kernel */ + bli_damaxv_zen_int + ( + n0, + x0, incx0, + &bli_index, + NULL + ); + } + else + { + PASTEMAC2(d,amaxv,BLIS_TAPI_EX_SUF) + ( + n0, + x0, incx0, + &bli_index, + NULL, + NULL + ); + } /* Convert zero-based BLIS (C) index to one-based BLAS (Fortran) index. Also, if the BLAS integer size differs from the BLIS diff --git a/frame/compat/bla_dot.c b/frame/compat/bla_dot.c index 7ca039aa93..2a0f815217 100644 --- a/frame/compat/bla_dot.c +++ b/frame/compat/bla_dot.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018-2021, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -154,17 +154,42 @@ float sdot_ incy0 = ( inc_t )(*incy); } - /* Call BLIS kernel. */ - bli_sdotv_zen_int10 - ( - BLIS_NO_CONJUGATE, - BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - &rho, - NULL - ); + // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. + // This function is invoked on all architectures including ‘generic’. + // Invoke architecture specific kernels only if we are sure that we are running on zen, + // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). + arch_t id = bli_arch_query_id(); + bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + + if (bamdzen) + { + /* Call BLIS kernel. */ + bli_sdotv_zen_int10 + ( + BLIS_NO_CONJUGATE, + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + &rho, + NULL + ); + } + else + { + /* Call BLIS interface. */ + PASTEMAC2(s,dotv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + &rho, + NULL, + NULL + ); + } /* Finalize BLIS. */ // bli_finalize_auto(); @@ -235,17 +260,42 @@ double ddot_ incy0 = ( inc_t )(*incy); } - /* Call BLIS kernel. */ - bli_ddotv_zen_int10 - ( - BLIS_NO_CONJUGATE, - BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - &rho, - NULL - ); + // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. + // This function is invoked on all architectures including ‘generic’. + // Invoke architecture specific kernels only if we are sure that we are running on zen, + // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). + arch_t id = bli_arch_query_id(); + bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + + if (bamdzen) + { + /* Call BLIS kernel. 
*/ + bli_ddotv_zen_int10 + ( + BLIS_NO_CONJUGATE, + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + &rho, + NULL + ); + } + else + { + /* Call BLIS interface. */ + PASTEMAC2(d,dotv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + &rho, + NULL, + NULL + ); + } /* Finalize BLIS. */ // bli_finalize_auto(); @@ -322,17 +372,42 @@ scomplex cdotu_ incy0 = ( inc_t )(*incy); } - /* Call BLIS kernel. */ - bli_cdotv_zen_int5 - ( - BLIS_NO_CONJUGATE, - BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - &rho, - NULL - ); + // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. + // This function is invoked on all architectures including ‘generic’. + // Invoke architecture specific kernels only if we are sure that we are running on zen, + // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). + arch_t id = bli_arch_query_id(); + bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + + if (bamdzen) + { + /* Call BLIS kernel. */ + bli_cdotv_zen_int5 + ( + BLIS_NO_CONJUGATE, + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + &rho, + NULL + ); + } + else + { + /* Call BLIS interface. */ + PASTEMAC2(c,dotv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + &rho, + NULL, + NULL + ); + } /* Finalize BLIS. */ // bli_finalize_auto(); @@ -404,18 +479,43 @@ dcomplex zdotu_ incy0 = ( inc_t )(*incy); } - /* Call BLIS kernel. */ - bli_zdotv_zen_int5 - ( - BLIS_NO_CONJUGATE, - BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - &rho, - NULL - ); + // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. + // This function is invoked on all architectures including ‘generic’. + // Invoke architecture specific kernels only if we are sure that we are running on zen, + // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). + arch_t id = bli_arch_query_id(); + bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + if (bamdzen) + { + /* Call BLIS kernel. */ + bli_zdotv_zen_int5 + ( + BLIS_NO_CONJUGATE, + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + &rho, + NULL + ); + } + else + { + /* Call BLIS interface. */ + PASTEMAC2(z,dotv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + &rho, + NULL, + NULL + ); + } + /* Finalize BLIS. */ // bli_finalize_auto(); @@ -489,17 +589,42 @@ scomplex cdotc_ incy0 = ( inc_t )(*incy); } - /* Call BLIS kernel. */ - bli_cdotv_zen_int5 - ( - BLIS_CONJUGATE, - BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - &rho, - NULL - ); + // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. + // This function is invoked on all architectures including ‘generic’. + // Invoke architecture specific kernels only if we are sure that we are running on zen, + // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). + arch_t id = bli_arch_query_id(); + bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + + if (bamdzen) + { + /* Call BLIS kernel. */ + bli_cdotv_zen_int5 + ( + BLIS_CONJUGATE, + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + &rho, + NULL + ); + } + else + { + /* Call BLIS interface. 
*/ + PASTEMAC2(c,dotv,BLIS_TAPI_EX_SUF) + ( + BLIS_CONJUGATE, + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + &rho, + NULL, + NULL + ); + } /* Finalize BLIS. */ // bli_finalize_auto(); @@ -507,6 +632,7 @@ scomplex cdotc_ return rho; } + dcomplex zdotc_ ( const f77_int* n, @@ -570,17 +696,46 @@ dcomplex zdotc_ incy0 = ( inc_t )(*incy); } - /* Call BLIS kernel. */ - bli_zdotv_zen_int5 - ( - BLIS_CONJUGATE, - BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - &rho, - NULL - ); + // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. + // This function is invoked on all architectures including ‘generic’. + // Invoke architecture specific kernels only if we are sure that we are running on zen, + // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). + arch_t id = bli_arch_query_id(); + bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + + if (bamdzen) + { + /* Call BLIS kernel. */ + bli_zdotv_zen_int5 + ( + BLIS_CONJUGATE, + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + &rho, + NULL + ); + } + else + { + /* Call BLIS interface. */ + PASTEMAC2(z,dotv,BLIS_TAPI_EX_SUF) + ( + BLIS_CONJUGATE, + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + &rho, + NULL, + NULL + ); + } + + + + /* Finalize BLIS. */ // bli_finalize_auto(); From 1b0a7e1c89606fcb7b22fb8076863f064b2b54fc Mon Sep 17 00:00:00 2001 From: Harsh Dave Date: Thu, 28 Oct 2021 06:23:15 -0500 Subject: [PATCH 049/243] Fixed dynamic dispatch crash issue on non-zen architecture. Removed direct calling of zen kernels in ctrsv, ztrsv interface. The BLIS binary with dynamic dispatch feature was crashing on non-zen CPUs (specifically CPUs without AVX2 support). The crash was caused by un-supported instructions in zen optimized kernels. AMD-Internal: [CPUPL-1930] Change-Id: I21f25a09cd6ffb013d16c66ea10aa9a42f7cad5b --- frame/2/trsv/bli_trsv_unf_var2.c | 41 ++++++++++++++++++++------------ 1 file changed, 26 insertions(+), 15 deletions(-) diff --git a/frame/2/trsv/bli_trsv_unf_var2.c b/frame/2/trsv/bli_trsv_unf_var2.c index 58135b0752..2d37d5a301 100644 --- a/frame/2/trsv/bli_trsv_unf_var2.c +++ b/frame/2/trsv/bli_trsv_unf_var2.c @@ -301,9 +301,8 @@ void bli_dtrsv_unf_var2 } else { - num_t dt = PASTEMAC(d,type); - kfp_af = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPYF_KER, cntx ); - b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_AF, cntx ); + kfp_af = bli_cntx_get_l1f_ker_dt( BLIS_DOUBLE, BLIS_AXPYF_KER, cntx ); + b_fuse = bli_cntx_get_blksz_def_dt( BLIS_DOUBLE, BLIS_AF, cntx ); } /* We reduce all of the possible cases down to just lower/upper. */ @@ -489,8 +488,6 @@ void bli_strsv_unf_var2 PASTECH(s, axpyf_ker_ft) kfp_af; /* Assign function pointer and fusing factor. */ -/* kfp_af = bli_saxpyf_zen_int_5; - b_fuse = 5;*/ arch_t id = bli_arch_query_id(); bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); if (bamdzen) { @@ -499,10 +496,8 @@ void bli_strsv_unf_var2 } else { - - num_t dt = PASTEMAC(s,type); - kfp_af = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPYF_KER, cntx ); - b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_AF, cntx ); + kfp_af = bli_cntx_get_l1f_ker_dt( BLIS_FLOAT, BLIS_AXPYF_KER, cntx ); + b_fuse = bli_cntx_get_blksz_def_dt( BLIS_FLOAT, BLIS_AF, cntx ); } /* We reduce all of the possible cases down to just lower/upper. */ @@ -688,9 +683,17 @@ void bli_ztrsv_unf_var2 PASTECH(z, axpyf_ker_ft) kfp_af; /* Assign function pointer and fusing factor. 
*/ - kfp_af = bli_zaxpyf_zen_int_5; - b_fuse = 5; - + arch_t id = bli_arch_query_id(); + bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + if (bamdzen) { + kfp_af = bli_zaxpyf_zen_int_5; + b_fuse = 5; + } + else + { + kfp_af = bli_cntx_get_l1f_ker_dt( BLIS_DCOMPLEX, BLIS_AXPYF_KER, cntx ); + b_fuse = bli_cntx_get_blksz_def_dt( BLIS_DCOMPLEX, BLIS_AF, cntx ); + } /* We reduce all of the possible cases down to just lower/upper. */ if ( bli_is_upper( uploa_trans ) ) { @@ -874,9 +877,17 @@ void bli_ctrsv_unf_var2 PASTECH(c, axpyf_ker_ft) kfp_af; /* Assign function pointer and fusing factor. */ - kfp_af = bli_caxpyf_zen_int_5; - b_fuse = 5; - + arch_t id = bli_arch_query_id(); + bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + if (bamdzen) { + kfp_af = bli_caxpyf_zen_int_5; + b_fuse = 5; + } + else + { + kfp_af = bli_cntx_get_l1f_ker_dt( BLIS_SCOMPLEX, BLIS_AXPYF_KER, cntx ); + b_fuse = bli_cntx_get_blksz_def_dt( BLIS_SCOMPLEX, BLIS_AF, cntx ); + } /* We reduce all of the possible cases down to just lower/upper. */ if ( bli_is_upper( uploa_trans ) ) { From 235071690a203a460c3971e858beca924c68314f Mon Sep 17 00:00:00 2001 From: mkurumel Date: Wed, 27 Oct 2021 18:04:53 +0530 Subject: [PATCH 050/243] Fixed dynamic dispatch crash issue on non-zen architecture for gemv and axpy routines. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: 1. This commit fixed issue for gemv and axpy API’s. 2. The BLIS binary with dynamic dispatch feature was crashing on non-zen CPUs (specifically CPUs without AVX2 support). 3. The crash was caused by un-supported instructions in zen optimized kernels.The issue is fixed by calling only reference kernels if the architecture detected at runtime is not zen, zen2 or zen3. Change-Id: Icc6f7fdc80bc58fac1a97b1502b6f269e5e89aa4 --- frame/2/gemv/bli_gemv_unf_var1.c | 92 ++++ frame/2/gemv/bli_gemv_unf_var2.c | 313 +++++++++++- frame/2/trsv/bli_trsv_unf_var1.c | 2 + frame/2/trsv/bli_trsv_unf_var2.c | 4 + frame/compat/bla_axpy.c | 169 ++++-- frame/compat/bla_gemv.c | 848 ++++++++++++++++++------------- 6 files changed, 1043 insertions(+), 385 deletions(-) diff --git a/frame/2/gemv/bli_gemv_unf_var1.c b/frame/2/gemv/bli_gemv_unf_var1.c index 7f65f7168a..4f0054c1f1 100644 --- a/frame/2/gemv/bli_gemv_unf_var1.c +++ b/frame/2/gemv/bli_gemv_unf_var1.c @@ -144,6 +144,53 @@ void bli_dgemv_unf_var1 conja = bli_extract_conj( transa ); + // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. + // This function is invoked on all architectures including ‘generic’. + // Invoke architecture specific kernels only if we are sure that we are running on zen, + // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). + arch_t id = bli_arch_query_id(); + bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + + if (bamdzen == 0) + { + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); + const num_t dt = PASTEMAC(d,type); + double* x1; + double* y1; + PASTECH(d,dotxf_ker_ft) kfp_df; + /* Query the context for the kernel function pointer and fusing factor. 
*/ + kfp_df = bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXF_KER, cntx ); + dim_t b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_DF, cntx ); + + for ( i = 0; i < n_iter; i += f ) + { + f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); + + A1 = a + (i )*rs_at + (0 )*cs_at; + x1 = x + (0 )*incy; + y1 = y + (i )*incy; + + /* y1 = beta * y1 + alpha * A1 * x; */ + kfp_df + ( + conja, + conjx, + n_elem, + f, + alpha, + A1, cs_at, rs_at, + x1, incx, + beta, + y1, incy, + cntx + ); + + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); + return; + } + if (incx > 1) { /* @@ -261,6 +308,51 @@ void bli_sgemv_unf_var1 conja = bli_extract_conj( transa ); + // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. + // This function is invoked on all architectures including ‘generic’. + // Invoke architecture specific kernels only if we are sure that we are running on zen, + // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). + arch_t id = bli_arch_query_id(); + bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + + if (bamdzen == 0) + { + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); + const num_t dt = PASTEMAC(s,type); + float* x1 ; + PASTECH(s,dotxf_ker_ft) kfp_df; + /* Query the context for the kernel function pointer and fusing factor. */ + kfp_df = bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXF_KER, cntx ); + b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_DF, cntx ); + + for ( i = 0; i < n_iter; i += f ) + { + f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); + + A1 = a + (i )*rs_at + (0 )*cs_at; + x1 = x + (0 )*incy; + y1 = y + (i )*incy; + + /* y1 = beta * y1 + alpha * A1 * x; */ + kfp_df + ( + conja, + conjx, + n_elem, + f, + alpha, + A1, cs_at, rs_at, + x1, incx, + beta, + y1, incy, + cntx + ); + + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); + return; + } /* Query the context for the kernel function pointer and fusing factor. */ b_fuse = 8; diff --git a/frame/2/gemv/bli_gemv_unf_var2.c b/frame/2/gemv/bli_gemv_unf_var2.c index ae1356fbea..84a67c3189 100644 --- a/frame/2/gemv/bli_gemv_unf_var2.c +++ b/frame/2/gemv/bli_gemv_unf_var2.c @@ -174,6 +174,80 @@ void bli_dgemv_unf_var2 conja = bli_extract_conj( transa ); + // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. + // This function is invoked on all architectures including ‘generic’. + // Invoke architecture specific kernels only if we are sure that we are running on zen, + // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). + arch_t id = bli_arch_query_id(); + bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + + if (bamdzen == 0) + { + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); + const num_t dt = PASTEMAC(d,type); + double* x1; + double* y1; + /* If beta is zero, use setv. Otherwise, scale by beta. */ + if ( PASTEMAC(d,eq0)( *beta ) ) + { + double* zero = PASTEMAC(d,0); + /* y = 0; */ + PASTEMAC2(d,setv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + n_elem, + zero, + y, incy, + cntx, + NULL + ); + } + else + { + /* y = beta * y; */ + PASTEMAC2(d,scalv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + n_elem, + beta, + y, incy, + cntx, + NULL + ); + } + + PASTECH(d,axpyf_ker_ft) kfp_af; + + /* Query the context for the kernel function pointer and fusing factor. 
*/ + kfp_af = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPYF_KER, cntx ); + dim_t b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_AF, cntx ); + + for ( i = 0; i < n_iter; i += f ) + { + f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); + + A1 = a + (0 )*rs_at + (i )*cs_at; + x1 = x + (i )*incx; + y1 = y + (0 )*incy; + + /* y = y + alpha * A1 * x1; */ + kfp_af + ( + conja, + conjx, + n_elem, + f, + alpha, + A1, rs_at, cs_at, + x1, incx, + y1, incy, + cntx + ); + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); + return; + } + /* If beta is zero, use setv. Otherwise, scale by beta. */ /* y = beta * y; */ /* beta=0 case is hadled by scalv internally */ @@ -312,6 +386,78 @@ void bli_sgemv_unf_var2 conja = bli_extract_conj( transa ); + // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. + // This function is invoked on all architectures including ‘generic’. + // Invoke architecture specific kernels only if we are sure that we are running on zen, + // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). + arch_t id = bli_arch_query_id(); + bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + + if (bamdzen == 0) + { + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); + const num_t dt = PASTEMAC(s,type); + /* If beta is zero, use setv. Otherwise, scale by beta. */ + if ( PASTEMAC(s,eq0)( *beta ) ) + { + float* zero = PASTEMAC(s,0); + /* y = 0; */ + PASTEMAC2(s,setv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + n_elem, + zero, + y, incy, + cntx, + NULL + ); + } + else + { + /* y = beta * y; */ + PASTEMAC2(s,scalv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + n_elem, + beta, + y, incy, + cntx, + NULL + ); + } + + PASTECH(s,axpyf_ker_ft) kfp_af; + + /* Query the context for the kernel function pointer and fusing factor. */ + kfp_af = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPYF_KER, cntx ); + b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_AF, cntx ); + + for ( i = 0; i < n_iter; i += f ) + { + f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); + + A1 = a + (0 )*rs_at + (i )*cs_at; + x1 = x + (i )*incx; + y1 = y + (0 )*incy; + + /* y = y + alpha * A1 * x1; */ + kfp_af + ( + conja, + conjx, + n_elem, + f, + alpha, + A1, rs_at, cs_at, + x1, incx, + y1, incy, + cntx + ); + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); + return; + } + /* If beta is zero, use setv. Otherwise, scale by beta. */ /* y = beta * y; */ /* beta=0 case is hadled by scalv internally */ @@ -392,9 +538,10 @@ void bli_zgemv_unf_var2 conja = bli_extract_conj( transa ); /* If beta is zero, use setv. Otherwise, scale by beta. */ - /* y = beta * y; */ + /* y = beta * y; */ + /* beta=0 case is hadled by scalv internally */ -/* bli_zscalv_zen_int10 + /* bli_zscalv_zen_int10 ( BLIS_NO_CONJUGATE, n_elem, @@ -403,15 +550,88 @@ void bli_zgemv_unf_var2 incy, cntx );*/ + + // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. + // This function is invoked on all architectures including ‘generic’. + // Invoke architecture specific kernels only if we are sure that we are running on zen, + // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). + arch_t id = bli_arch_query_id(); + bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + + if (bamdzen == 0) + { + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); + const num_t dt = PASTEMAC(z,type); + /* If beta is zero, use setv. Otherwise, scale by beta. 
*/ + if ( PASTEMAC(z,eq0)( *beta ) ) + { + dcomplex* zero = PASTEMAC(z,0); + /* y = 0; */ + PASTEMAC2(z,setv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + n_elem, + zero, + y, incy, + cntx, + NULL + ); + } + else + { + /* y = beta * y; */ + PASTEMAC2(z,scalv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + n_elem, + beta, + y, incy, + cntx, + NULL + ); + } + + PASTECH(z,axpyf_ker_ft) kfp_af; + + /* Query the context for the kernel function pointer and fusing factor. */ + kfp_af = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPYF_KER, cntx ); + b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_AF, cntx ); + + for ( i = 0; i < n_iter; i += f ) + { + f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); + + A1 = a + (0 )*rs_at + (i )*cs_at; + x1 = x + (i )*incx; + y1 = y + (0 )*incy; + + /* y = y + alpha * A1 * x1; */ + kfp_af + ( + conja, + conjx, + n_elem, + f, + alpha, + A1, rs_at, cs_at, + x1, incx, + y1, incy, + cntx + ); + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); + return; + } + bli_zscalv_ex - ( - BLIS_NO_CONJUGATE, - n_elem, - beta, - y, incy, - cntx, - NULL - ); + ( + BLIS_NO_CONJUGATE, + n_elem, + beta, + y, incy, + cntx, + NULL + ); if( bli_zeq0( *alpha ) ) { @@ -516,6 +736,79 @@ void bli_cgemv_unf_var2 incy, cntx );*/ + + // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. + // This function is invoked on all architectures including ‘generic’. + // Invoke architecture specific kernels only if we are sure that we are running on zen, + // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). + arch_t id = bli_arch_query_id(); + bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + + if (bamdzen == 0) + { + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); + const num_t dt = PASTEMAC(c,type); + /* If beta is zero, use setv. Otherwise, scale by beta. */ + if ( PASTEMAC(c,eq0)( *beta ) ) + { + scomplex* zero = PASTEMAC(c,0); + /* y = 0; */ + PASTEMAC2(c,setv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + n_elem, + zero, + y, incy, + cntx, + NULL + ); + } + else + { + /* y = beta * y; */ + PASTEMAC2(c,scalv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + n_elem, + beta, + y, incy, + cntx, + NULL + ); + } + + PASTECH(c,axpyf_ker_ft) kfp_af; + + /* Query the context for the kernel function pointer and fusing factor. 
*/ + kfp_af = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPYF_KER, cntx ); + b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_AF, cntx ); + + for ( i = 0; i < n_iter; i += f ) + { + f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); + + A1 = a + (0 )*rs_at + (i )*cs_at; + x1 = x + (i )*incx; + y1 = y + (0 )*incy; + + /* y = y + alpha * A1 * x1; */ + kfp_af + ( + conja, + conjx, + n_elem, + f, + alpha, + A1, rs_at, cs_at, + x1, incx, + y1, incy, + cntx + ); + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); + return; + } + bli_cscalv_ex ( BLIS_NO_CONJUGATE, diff --git a/frame/2/trsv/bli_trsv_unf_var1.c b/frame/2/trsv/bli_trsv_unf_var1.c index 35b48d8a92..4f19e1ac5e 100644 --- a/frame/2/trsv/bli_trsv_unf_var1.c +++ b/frame/2/trsv/bli_trsv_unf_var1.c @@ -305,6 +305,7 @@ void bli_dtrsv_unf_var1 } else { + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); num_t dt = PASTEMAC(d,type); kfp_df = bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXF_KER, cntx ); b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_DF, cntx ); @@ -506,6 +507,7 @@ void bli_strsv_unf_var1 } else { + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); num_t dt = PASTEMAC(s,type); kfp_df = bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXF_KER, cntx ); b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_DF, cntx ); diff --git a/frame/2/trsv/bli_trsv_unf_var2.c b/frame/2/trsv/bli_trsv_unf_var2.c index 2d37d5a301..7ece8f8470 100644 --- a/frame/2/trsv/bli_trsv_unf_var2.c +++ b/frame/2/trsv/bli_trsv_unf_var2.c @@ -301,6 +301,7 @@ void bli_dtrsv_unf_var2 } else { + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); kfp_af = bli_cntx_get_l1f_ker_dt( BLIS_DOUBLE, BLIS_AXPYF_KER, cntx ); b_fuse = bli_cntx_get_blksz_def_dt( BLIS_DOUBLE, BLIS_AF, cntx ); } @@ -496,6 +497,7 @@ void bli_strsv_unf_var2 } else { + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); kfp_af = bli_cntx_get_l1f_ker_dt( BLIS_FLOAT, BLIS_AXPYF_KER, cntx ); b_fuse = bli_cntx_get_blksz_def_dt( BLIS_FLOAT, BLIS_AF, cntx ); } @@ -691,6 +693,7 @@ void bli_ztrsv_unf_var2 } else { + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); kfp_af = bli_cntx_get_l1f_ker_dt( BLIS_DCOMPLEX, BLIS_AXPYF_KER, cntx ); b_fuse = bli_cntx_get_blksz_def_dt( BLIS_DCOMPLEX, BLIS_AF, cntx ); } @@ -885,6 +888,7 @@ void bli_ctrsv_unf_var2 } else { + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); kfp_af = bli_cntx_get_l1f_ker_dt( BLIS_SCOMPLEX, BLIS_AXPYF_KER, cntx ); b_fuse = bli_cntx_get_blksz_def_dt( BLIS_SCOMPLEX, BLIS_AF, cntx ); } diff --git a/frame/compat/bla_axpy.c b/frame/compat/bla_axpy.c index 60a0d48746..41885e95d6 100644 --- a/frame/compat/bla_axpy.c +++ b/frame/compat/bla_axpy.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 21, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -146,15 +146,40 @@ void saxpy_ incy0 = ( inc_t )(*incy); } - bli_saxpyv_zen_int10( - BLIS_NO_CONJUGATE, - n0, - (float*)alpha, - x0, incx0, - y0, incy0, - NULL - ); - + // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. + // This function is invoked on all architectures including ‘generic’. + // Invoke architecture specific kernels only if we are sure that we are running on zen, + // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). 
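+    // 'bamdzen' is true only for the Zen cores listed below; on any other CPU,
+    // including the 'generic' member of an 'amdzen' build, the axpyv expert-API
+    // fallback in the else branch is taken instead of the hand-written kernel.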
+ arch_t id = bli_arch_query_id(); + bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + + if (bamdzen) + { + bli_saxpyv_zen_int10 + ( + BLIS_NO_CONJUGATE, + n0, + (float*)alpha, + x0, incx0, + y0, incy0, + NULL + ); + + } + else + { + PASTEMAC2(s,axpyv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + n0, + (float*)alpha, + x0, incx0, + y0, incy0, + NULL, + NULL + ); + + } /* Finalize BLIS. */ // bli_finalize_auto(); AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); @@ -218,14 +243,40 @@ void daxpy_ incy0 = ( inc_t )(*incy); } - bli_daxpyv_zen_int10( - BLIS_NO_CONJUGATE, - n0, - (double*)alpha, - x0, incx0, - y0, incy0, - NULL - ); + // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. + // This function is invoked on all architectures including ‘generic’. + // Invoke architecture specific kernels only if we are sure that we are running on zen, + // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). + arch_t id = bli_arch_query_id(); + bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + + if (bamdzen) + { + bli_daxpyv_zen_int10 + ( + BLIS_NO_CONJUGATE, + n0, + (double*)alpha, + x0, incx0, + y0, incy0, + NULL + ); + + } + else + { + PASTEMAC2(d,axpyv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + n0, + (double*)alpha, + x0, incx0, + y0, incy0, + NULL, + NULL + ); + + } AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); /* Finalize BLIS. */ @@ -290,14 +341,39 @@ void caxpy_ incy0 = ( inc_t )(*incy); } - bli_caxpyv_zen_int5( - BLIS_NO_CONJUGATE, - n0, - (scomplex*)alpha, - x0, incx0, - y0, incy0, - NULL - ); + // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. + // This function is invoked on all architectures including ‘generic’. + // Invoke architecture specific kernels only if we are sure that we are running on zen, + // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). + arch_t id = bli_arch_query_id(); + bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + + if (bamdzen) + { + bli_caxpyv_zen_int5 + ( + BLIS_NO_CONJUGATE, + n0, + (scomplex*)alpha, + x0, incx0, + y0, incy0, + NULL + ); + + } + else + { + PASTEMAC2(c,axpyv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + n0, + (scomplex*)alpha, + x0, incx0, + y0, incy0, + NULL, + NULL + ); + } AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); /* Finalize BLIS. */ @@ -363,14 +439,39 @@ void zaxpy_ incy0 = ( inc_t )(*incy); } - bli_zaxpyv_zen_int5( - BLIS_NO_CONJUGATE, - n0, - (dcomplex*)alpha, - x0, incx0, - y0, incy0, - NULL - ); + // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. + // This function is invoked on all architectures including ‘generic’. + // Invoke architecture specific kernels only if we are sure that we are running on zen, + // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). + arch_t id = bli_arch_query_id(); + bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + + if (bamdzen) + { + bli_zaxpyv_zen_int5 + ( + BLIS_NO_CONJUGATE, + n0, + (dcomplex*)alpha, + x0, incx0, + y0, incy0, + NULL + ); + + } + else + { + PASTEMAC2(z,axpyv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + n0, + (dcomplex*)alpha, + x0, incx0, + y0, incy0, + NULL, + NULL + ); + } AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); /* Finalize BLIS. 
*/ diff --git a/frame/compat/bla_gemv.c b/frame/compat/bla_gemv.c index 84917dc71f..e9b210bbc1 100644 --- a/frame/compat/bla_gemv.c +++ b/frame/compat/bla_gemv.c @@ -160,146 +160,182 @@ void dgemv_ double* y, const f77_int* incy ) { - trans_t blis_transa; - dim_t m0, n0; - dim_t m_y, n_x; - double* x0; - double* y0; - inc_t incx0; - inc_t incy0; - inc_t rs_a, cs_a; - - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); - AOCL_DTL_LOG_GEMV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'D', *transa, *m, *n, (void*)alpha, *lda, *incx, (void*)beta, *incy); - - /* Perform BLAS parameter checking. */ - PASTEBLACHK(gemv) + trans_t blis_transa; + dim_t m0, n0; + dim_t m_y, n_x; + double* x0; + double* y0; + inc_t incx0; + inc_t incy0; + inc_t rs_a, cs_a; + + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); + AOCL_DTL_LOG_GEMV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'D', *transa, *m, *n, (void*)alpha, *lda, *incx, (void*)beta, *incy); + + /* Perform BLAS parameter checking. */ + PASTEBLACHK(gemv) ( - MKSTR(d), - MKSTR(gemv), - transa, - m, - n, - lda, - incx, - incy - ); - - if (*m == 0 || *n == 0) { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return; - } - - /* Map BLAS chars to their corresponding BLIS enumerated type value. */ - if ( *transa == 'n' || *transa == 'N' ) blis_transa = BLIS_NO_TRANSPOSE; - else if ( *transa == 't' || *transa == 'T' ) blis_transa = BLIS_TRANSPOSE; - else if ( *transa == 'c' || *transa == 'C' ) blis_transa = BLIS_CONJ_TRANSPOSE; - else - { - // See comment for bli_param_map_netlib_to_blis_side() above. - //bli_check_error_code( BLIS_INVALID_TRANS ); - blis_transa = BLIS_NO_TRANSPOSE; - } - - /* Convert/typecast negative values of m and n to zero. */ - if ( *m < 0 ) m0 = ( dim_t )0; - else m0 = ( dim_t )(*m); - - if ( *n < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(*n); - - /* Determine the dimensions of x and y so we can adjust the increments, - if necessary.*/ - if ( bli_does_notrans( blis_transa ) ) { m_y = m0; n_x = n0; } - else { m_y = n0; n_x = m0; } - - /* BLAS handles cases where trans(A) has no columns, and x has no elements, - in a peculiar way. In these situations, BLAS returns without performing - any action, even though most sane interpretations of gemv would have the - the operation reduce to y := beta * y. Here, we catch those cases that - BLAS would normally mishandle and emulate the BLAS exactly so as to - provide "bug-for-bug" compatibility. Note that this extreme level of - compatibility would not be as much of an issue if it weren't for the - fact that some BLAS test suites actually test for these cases. Also, it - should be emphasized that BLIS, if called natively, does NOT exhibit - this quirky behavior; it will scale y by beta, as one would expect. */ - if ( m_y > 0 && n_x == 0 ) - { - /* Finalize BLIS. */ - // bli_finalize_auto(); + MKSTR(d), + MKSTR(gemv), + transa, + m, + n, + lda, + incx, + incy + ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return; + if (*m == 0 || *n == 0) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return; } - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - if ( *incx < 0 ) + /* Map BLAS chars to their corresponding BLIS enumerated type value. 
*/ + if ( *transa == 'n' || *transa == 'N' ) blis_transa = BLIS_NO_TRANSPOSE; + else if ( *transa == 't' || *transa == 'T' ) blis_transa = BLIS_TRANSPOSE; + else if ( *transa == 'c' || *transa == 'C' ) blis_transa = BLIS_CONJ_TRANSPOSE; + else { - x0 = ((double*)x) + (n_x-1)*(-*incx); - incx0 = ( inc_t )(*incx); + // See comment for bli_param_map_netlib_to_blis_side() above. + //bli_check_error_code( BLIS_INVALID_TRANS ); + blis_transa = BLIS_NO_TRANSPOSE; + } + + /* Convert/typecast negative values of m and n to zero. */ + if ( *m < 0 ) m0 = ( dim_t )0; + else m0 = ( dim_t )(*m); + + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* Determine the dimensions of x and y so we can adjust the increments, + if necessary.*/ + if ( bli_does_notrans( blis_transa ) ) + { + m_y = m0; + n_x = n0; + } + else + { + m_y = n0; + n_x = m0; } - else + + /* BLAS handles cases where trans(A) has no columns, and x has no elements, + in a peculiar way. In these situations, BLAS returns without performing + any action, even though most sane interpretations of gemv would have the + the operation reduce to y := beta * y. Here, we catch those cases that + BLAS would normally mishandle and emulate the BLAS exactly so as to + provide "bug-for-bug" compatibility. Note that this extreme level of + compatibility would not be as much of an issue if it weren't for the + fact that some BLAS test suites actually test for these cases. Also, it + should be emphasized that BLIS, if called natively, does NOT exhibit + this quirky behavior; it will scale y by beta, as one would expect. */ + if ( m_y > 0 && n_x == 0 ) { - x0 = ((double*)x); - incx0 = ( inc_t )(*incx); + /* Finalize BLIS. */ + // bli_finalize_auto(); + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return; } - if ( *incy < 0 ) + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + if ( *incx < 0 ) { - y0 = ((double*)y) + (m_y-1)*(-*incy); - incy0 = ( inc_t )(*incy); + x0 = ((double*)x) + (n_x-1)*(-*incx); + incx0 = ( inc_t )(*incx); } - else + else { - y0 = ((double*)y); - incy0 = ( inc_t )(*incy); + x0 = ((double*)x); + incx0 = ( inc_t )(*incx); } - /* Set the row and column strides of A. */ - rs_a = 1; - cs_a = *lda; + if ( *incy < 0 ) + { + y0 = ((double*)y) + (m_y-1)*(-*incy); + incy0 = ( inc_t )(*incy); + } + else + { + y0 = ((double*)y); + incy0 = ( inc_t )(*incy); + } + + /* Set the row and column strides of A. */ + rs_a = 1; + cs_a = *lda; - /* Call variants based on transpose value. */ - if(bli_does_notrans(blis_transa)) + // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. + // This function is invoked on all architectures including ‘generic’. + // Invoke architecture specific kernels only if we are sure that we are running on zen, + // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). + arch_t id = bli_arch_query_id(); + bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + + if (bamdzen == 0) { - //variant_2 is chosen for column-storage - // and uses axpyf-based implementation - bli_dgemv_unf_var2 - ( - blis_transa, - BLIS_NO_CONJUGATE, - m0, - n0, - (double*)alpha, - (double*)a, rs_a, cs_a, - x0, incx0, - (double*)beta, - y0, incy0, - NULL - ); - } - else - { - //var_1 is chosen for row-storage - //and uses dotxf-based implementation - bli_dgemv_unf_var1 + /* Call BLIS interface. 
*/ + PASTEMAC2(d,gemv,BLIS_TAPI_EX_SUF) ( - blis_transa, - BLIS_NO_CONJUGATE, - m0, - n0, - (double*)alpha, - (double*)a, rs_a, cs_a, - x0, incx0, - (double*)beta, - y0, incy0, - NULL - ); + blis_transa, + BLIS_NO_CONJUGATE, + m0, + n0, + (double*)alpha, + (double*)a, rs_a, cs_a, + x0, incx0, + (double*)beta, + y0, incy0, + NULL, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return; + } + /* Call variants based on transpose value. */ + if(bli_does_notrans(blis_transa)) + { + //variant_2 is chosen for column-storage + // and uses axpyf-based implementation + bli_dgemv_unf_var2 + ( + blis_transa, + BLIS_NO_CONJUGATE, + m0, + n0, + (double*)alpha, + (double*)a, rs_a, cs_a, + x0, incx0, + (double*)beta, + y0, incy0, + NULL + ); + } + else + { + //var_1 is chosen for row-storage + //and uses dotxf-based implementation + bli_dgemv_unf_var1 + ( + blis_transa, + BLIS_NO_CONJUGATE, + m0, + n0, + (double*)alpha, + (double*)a, rs_a, cs_a, + x0, incx0, + (double*)beta, + y0, incy0, + NULL + ); } - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); } void sgemv_ @@ -314,141 +350,176 @@ void sgemv_ float* y, const f77_int* incy ) { - trans_t blis_transa; - dim_t m0, n0; - dim_t m_y, n_x; - float* x0; - float* y0; - inc_t incx0; - inc_t incy0; - inc_t rs_a, cs_a; - - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); - AOCL_DTL_LOG_GEMV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'S', *transa, *m, *n, (void*)alpha, *lda, *incx, (void*)beta, *incy); - /* Perform BLAS parameter checking. */ - PASTEBLACHK(gemv) + trans_t blis_transa; + dim_t m0, n0; + dim_t m_y, n_x; + float* x0; + float* y0; + inc_t incx0; + inc_t incy0; + inc_t rs_a, cs_a; + + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); + AOCL_DTL_LOG_GEMV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'S', *transa, *m, *n, (void*)alpha, *lda, *incx, (void*)beta, *incy); + /* Perform BLAS parameter checking. */ + PASTEBLACHK(gemv) ( - MKSTR(s), - MKSTR(gemv), - transa, - m, - n, - lda, - incx, - incy - ); - - if (*m == 0 || *n == 0) { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return; - } - - /* Map BLAS chars to their corresponding BLIS enumerated type value. */ - if ( *transa == 'n' || *transa == 'N' ) blis_transa = BLIS_NO_TRANSPOSE; - else if ( *transa == 't' || *transa == 'T' ) blis_transa = BLIS_TRANSPOSE; - else if ( *transa == 'c' || *transa == 'C' ) blis_transa = BLIS_CONJ_TRANSPOSE; - else - { - // See comment for bli_param_map_netlib_to_blis_side() above. - //bli_check_error_code( BLIS_INVALID_TRANS ); - blis_transa = BLIS_NO_TRANSPOSE; - } - - /* Convert/typecast negative values of m and n to zero. */ - if ( *m < 0 ) m0 = ( dim_t )0; - else m0 = ( dim_t )(*m); - - if ( *n < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(*n); - - /* Determine the dimensions of x and y so we can adjust the increments, - if necessary.*/ - if ( bli_does_notrans( blis_transa ) ) { m_y = m0; n_x = n0; } - else { m_y = n0; n_x = m0; } - - /* BLAS handles cases where trans(A) has no columns, and x has no elements, - in a peculiar way. In these situations, BLAS returns without performing - any action, even though most sane interpretations of gemv would have the - the operation reduce to y := beta * y. Here, we catch those cases that - BLAS would normally mishandle and emulate the BLAS exactly so as to - provide "bug-for-bug" compatibility. Note that this extreme level of - compatibility would not be as much of an issue if it weren't for the - fact that some BLAS test suites actually test for these cases. 
Also, it - should be emphasized that BLIS, if called natively, does NOT exhibit - this quirky behavior; it will scale y by beta, as one would expect. */ - if ( m_y > 0 && n_x == 0 ) - { - /* Finalize BLIS. */ - // bli_finalize_auto(); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return; + MKSTR(s), + MKSTR(gemv), + transa, + m, + n, + lda, + incx, + incy + ); + + if (*m == 0 || *n == 0) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return; + } + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + if ( *transa == 'n' || *transa == 'N' ) blis_transa = BLIS_NO_TRANSPOSE; + else if ( *transa == 't' || *transa == 'T' ) blis_transa = BLIS_TRANSPOSE; + else if ( *transa == 'c' || *transa == 'C' ) blis_transa = BLIS_CONJ_TRANSPOSE; + else + { + // See comment for bli_param_map_netlib_to_blis_side() above. + //bli_check_error_code( BLIS_INVALID_TRANS ); + blis_transa = BLIS_NO_TRANSPOSE; + } + + /* Convert/typecast negative values of m and n to zero. */ + if ( *m < 0 ) m0 = ( dim_t )0; + else m0 = ( dim_t )(*m); + + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* Determine the dimensions of x and y so we can adjust the increments, + if necessary.*/ + if ( bli_does_notrans( blis_transa ) ) + { + m_y = m0; + n_x = n0; + } + else + { + m_y = n0; + n_x = m0; + } + + /* BLAS handles cases where trans(A) has no columns, and x has no elements, + in a peculiar way. In these situations, BLAS returns without performing + any action, even though most sane interpretations of gemv would have the + the operation reduce to y := beta * y. Here, we catch those cases that + BLAS would normally mishandle and emulate the BLAS exactly so as to + provide "bug-for-bug" compatibility. Note that this extreme level of + compatibility would not be as much of an issue if it weren't for the + fact that some BLAS test suites actually test for these cases. Also, it + should be emphasized that BLIS, if called natively, does NOT exhibit + this quirky behavior; it will scale y by beta, as one would expect. */ + if ( m_y > 0 && n_x == 0 ) + { + /* Finalize BLIS. */ + // bli_finalize_auto(); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return; } - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - if ( *incx < 0 ) + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + if ( *incx < 0 ) { - x0 = ((float*)x) + (n_x-1)*(-*incx); - incx0 = ( inc_t )(*incx); + x0 = ((float*)x) + (n_x-1)*(-*incx); + incx0 = ( inc_t )(*incx); } - else + else { - x0 = ((float*)x); - incx0 = ( inc_t )(*incx); + x0 = ((float*)x); + incx0 = ( inc_t )(*incx); } - if ( *incy < 0 ) + if ( *incy < 0 ) { - y0 = ((float*)y) + (m_y-1)*(-*incy); - incy0 = ( inc_t )(*incy); + y0 = ((float*)y) + (m_y-1)*(-*incy); + incy0 = ( inc_t )(*incy); } - else + else { - y0 = ((float*)y); - incy0 = ( inc_t )(*incy); + y0 = ((float*)y); + incy0 = ( inc_t )(*incy); } - /* Set the row and column strides of A. */ - rs_a = 1; - cs_a = *lda; + /* Set the row and column strides of A. */ + rs_a = 1; + cs_a = *lda; - /* Call variants based on transpose value. */ - if(bli_does_notrans(blis_transa)) + // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. + // This function is invoked on all architectures including ‘generic’. 
+ // Invoke architecture specific kernels only if we are sure that we are running on zen, + // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). + arch_t id = bli_arch_query_id(); + bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + + if (bamdzen == 0) + { + /* Call BLIS interface. */ + PASTEMAC2(s,gemv,BLIS_TAPI_EX_SUF) + ( + blis_transa, + BLIS_NO_CONJUGATE, + m0, + n0, + (float*)alpha, + (float*)a, rs_a, cs_a, + x0, incx0, + (float*)beta, + y0, incy0, + NULL, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return; + } + + /* Call variants based on transpose value. */ + if(bli_does_notrans(blis_transa)) { - bli_sgemv_unf_var2 + bli_sgemv_unf_var2 ( - blis_transa, - BLIS_NO_CONJUGATE, - m0, - n0, - (float*)alpha, - (float*)a, rs_a, cs_a, - x0, incx0, - (float*)beta, - y0, incy0, - NULL - ); - - } - else - { - bli_sgemv_unf_var1 + blis_transa, + BLIS_NO_CONJUGATE, + m0, + n0, + (float*)alpha, + (float*)a, rs_a, cs_a, + x0, incx0, + (float*)beta, + y0, incy0, + NULL + ); + } + else + { + bli_sgemv_unf_var1 ( - blis_transa, - BLIS_NO_CONJUGATE, - m0, - n0, - (float*)alpha, - (float*)a, rs_a, cs_a, - x0, incx0, - (float*)beta, - y0, incy0, - NULL - ); - + blis_transa, + BLIS_NO_CONJUGATE, + m0, + n0, + (float*)alpha, + (float*)a, rs_a, cs_a, + x0, incx0, + (float*)beta, + y0, incy0, + NULL + ); } - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); } @@ -489,8 +560,9 @@ void cgemv_ incy ); - if (*m == 0 || *n == 0) { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + if (*m == 0 || *n == 0) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); return; } @@ -513,20 +585,20 @@ void cgemv_ else n0 = (dim_t)(*n); /* Determine the dimensions of x and y so we can adjust the increments, - if necessary.*/ + if necessary.*/ if( bli_does_notrans( blis_transa ) ) { m_y = m0, n_x = n0; } else { m_y = n0; n_x = m0; } /* BLAS handles cases where trans(A) has no columns, and x has no elements, - in a peculiar way. In these situations, BLAS returns without performing - any action, even though most sane interpretations of gemv would have the - the operation reduce to y := beta * y. Here, we catch those cases that - BLAS would normally mishandle and emulate the BLAS exactly so as to - provide "bug-for-bug" compatibility. Note that this extreme level of - compatibility would not be as much of an issue if it weren't for the - fact that some BLAS test suites actually test for these cases. Also, it - should be emphasized that BLIS, if called natively, does NOT exhibit - this quirky behavior; it will scale y by beta, as one would expect. */ + in a peculiar way. In these situations, BLAS returns without performing + any action, even though most sane interpretations of gemv would have the + the operation reduce to y := beta * y. Here, we catch those cases that + BLAS would normally mishandle and emulate the BLAS exactly so as to + provide "bug-for-bug" compatibility. Note that this extreme level of + compatibility would not be as much of an issue if it weren't for the + fact that some BLAS test suites actually test for these cases. Also, it + should be emphasized that BLIS, if called natively, does NOT exhibit + this quirky behavior; it will scale y by beta, as one would expect. */ if ( m_y > 0 && n_x == 0 ) { @@ -535,7 +607,7 @@ void cgemv_ } /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ + use positive increments instead. 
*/ if( *incx < 0 ) { x0 = ((scomplex*)x) + (n_x-1)*(-*incx); @@ -562,71 +634,118 @@ void cgemv_ rs_a = 1; cs_a = *lda; + // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. + // This function is invoked on all architectures including ‘generic’. + // Invoke architecture specific kernels only if we are sure that we are running on zen, + // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). + arch_t id = bli_arch_query_id(); + bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + if( m_y == 1 ) { conj_t conja = bli_extract_conj(blis_transa); scomplex rho; - bli_cdotv_zen_int5 - ( - conja, - BLIS_NO_CONJUGATE, - n_x, - (scomplex*)a, bli_is_notrans(blis_transa)?cs_a:rs_a, - x0, incx0, - &rho, - NULL - ); - scomplex yval = *y0; - if(!bli_ceq0(*beta)) - { - bli_cscals( *beta, yval ); - } - else - { - bli_csetsc( 0.0, 0.0, &yval); - } - if(!bli_ceq0(*alpha)) - { + if (bamdzen) + { + bli_cdotv_zen_int5 + ( + conja, + BLIS_NO_CONJUGATE, + n_x, + (scomplex*)a, bli_is_notrans(blis_transa)?cs_a:rs_a, + x0, incx0, + &rho, + NULL + ); + } + else + { + /* Call BLIS interface. */ + PASTEMAC2(c,dotv,BLIS_TAPI_EX_SUF) + ( + conja, + BLIS_NO_CONJUGATE, + n_x, + (scomplex*)a, bli_is_notrans(blis_transa)?cs_a:rs_a, + x0, incx0, + &rho, + NULL, + NULL + ); + } + + scomplex yval = *y0; + if(!bli_ceq0(*beta)) + { + bli_cscals( *beta, yval ); + } + else + { + bli_csetsc( 0.0, 0.0, &yval); + } + if(!bli_ceq0(*alpha)) + { bli_caxpys( *alpha, rho, yval); - } - y0->real = yval.real; + } + y0->real = yval.real; y0->imag = yval.imag; AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); return; } - + + if (bamdzen == 0) + { + /* Call BLIS interface. */ + PASTEMAC2(c,gemv,BLIS_TAPI_EX_SUF) + ( + blis_transa, + BLIS_NO_CONJUGATE, + m0, + n0, + (scomplex*)alpha, + (scomplex*)a, rs_a, cs_a, + x0, incx0, + (scomplex*)beta, + y0, incy0, + NULL, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return; + } + /* call variants based on transpose value */ if( bli_does_notrans( blis_transa ) ) { bli_cgemv_unf_var2 ( - blis_transa, - BLIS_NO_CONJUGATE, - m0, - n0, - (scomplex*)alpha, - (scomplex*)a, rs_a, cs_a, - x0, incx0, - (scomplex*)beta, - y0, incy0, - NULL + blis_transa, + BLIS_NO_CONJUGATE, + m0, + n0, + (scomplex*)alpha, + (scomplex*)a, rs_a, cs_a, + x0, incx0, + (scomplex*)beta, + y0, incy0, + NULL ); } else { bli_cgemv_unf_var1 - ( - blis_transa, - BLIS_NO_CONJUGATE, - m0, - n0, - (scomplex*)alpha, - (scomplex*)a, rs_a, cs_a, - x0, incx0, - (scomplex*)beta, - y0, incy0, - NULL + ( + blis_transa, + BLIS_NO_CONJUGATE, + m0, + n0, + (scomplex*)alpha, + (scomplex*)a, rs_a, cs_a, + x0, incx0, + (scomplex*)beta, + y0, incy0, + NULL ); } @@ -671,8 +790,9 @@ void zgemv_ incy ); - if (*m == 0 || *n == 0) { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + if (*m == 0 || *n == 0) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); return; } @@ -744,73 +864,119 @@ void zgemv_ rs_a = 1; cs_a = *lda; + // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. + // This function is invoked on all architectures including ‘generic’. + // Invoke architecture specific kernels only if we are sure that we are running on zen, + // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). 
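+    // The flag computed here is used twice in this routine: by the m_y == 1
+    // shortcut below, which reduces the operation to a single dot product, and
+    // again to choose between the unf_var1/var2 kernels and the gemv expert API.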
+ arch_t id = bli_arch_query_id(); + bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + if( m_y == 1 ) { conj_t conja = bli_extract_conj(blis_transa); dcomplex rho; - - bli_zdotv_zen_int5 - ( - conja, - BLIS_NO_CONJUGATE, - n_x, - (dcomplex*)a, bli_is_notrans(blis_transa)?cs_a:rs_a, - x0, incx0, - &rho, - NULL - ); - - dcomplex yval = *y0; - if(!bli_zeq0(*beta)) - { - bli_zscals( *beta, yval ); - } - else - { - bli_zsetsc( 0.0, 0.0, &yval); - } - if(!bli_zeq0(*alpha)) - { + + if (bamdzen) + { + bli_zdotv_zen_int5 + ( + conja, + BLIS_NO_CONJUGATE, + n_x, + (dcomplex*)a, bli_is_notrans(blis_transa)?cs_a:rs_a, + x0, incx0, + &rho, + NULL + ); + } + else + { + /* Call BLIS interface. */ + PASTEMAC2(z,dotv,BLIS_TAPI_EX_SUF) + ( + conja, + BLIS_NO_CONJUGATE, + n_x, + (dcomplex*)a, bli_is_notrans(blis_transa)?cs_a:rs_a, + x0, incx0, + &rho, + NULL, + NULL + ); + } + + dcomplex yval = *y0; + if(!bli_zeq0(*beta)) + { + bli_zscals( *beta, yval ); + } + else + { + bli_zsetsc( 0.0, 0.0, &yval); + } + if(!bli_zeq0(*alpha)) + { bli_zaxpys( *alpha, rho, yval); - } - y0->real = yval.real; + } + y0->real = yval.real; y0->imag = yval.imag; AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); return; } + if (bamdzen == 0) + { + /* Call BLIS interface. */ + PASTEMAC2(z,gemv,BLIS_TAPI_EX_SUF) + ( + blis_transa, + BLIS_NO_CONJUGATE, + m0, + n0, + (dcomplex*)alpha, + (dcomplex*)a, rs_a, cs_a, + x0, incx0, + (dcomplex*)beta, + y0, incy0, + NULL, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return; + } + /* call variants based on transpose value */ if( bli_does_notrans( blis_transa ) ) { bli_zgemv_unf_var2 ( - blis_transa, - BLIS_NO_CONJUGATE, - m0, - n0, - (dcomplex*)alpha, - (dcomplex*)a, rs_a, cs_a, - x0, incx0, - (dcomplex*)beta, - y0, incy0, - NULL + blis_transa, + BLIS_NO_CONJUGATE, + m0, + n0, + (dcomplex*)alpha, + (dcomplex*)a, rs_a, cs_a, + x0, incx0, + (dcomplex*)beta, + y0, incy0, + NULL ); } else { bli_zgemv_unf_var1 ( - blis_transa, - BLIS_NO_CONJUGATE, - m0, - n0, - (dcomplex*)alpha, - (dcomplex*)a, rs_a, cs_a, - x0, incx0, - (dcomplex*)beta, - y0, incy0, - NULL + blis_transa, + BLIS_NO_CONJUGATE, + m0, + n0, + (dcomplex*)alpha, + (dcomplex*)a, rs_a, cs_a, + x0, incx0, + (dcomplex*)beta, + y0, incy0, + NULL ); } From 7a15aa9c87e5bd83cf14193752211c54f1c35cc9 Mon Sep 17 00:00:00 2001 From: Dipal M Zambare Date: Tue, 2 Nov 2021 13:25:47 +0530 Subject: [PATCH 051/243] Fixed xGEMM dynamic dispatch crash on ST library. Small gemm implemenation is called from gemmnat path when library is built as multi-threaded small gemm is completely disabled. For single threaded the crash is fixed by disabling small gemm on generic architecture. AMD-Internal: [CPUPL-1930] Change-Id: If718870d89909cef908a1c23918b7ef6f7d80f7a --- kernels/zen/3/bli_gemm_small.c | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/kernels/zen/3/bli_gemm_small.c b/kernels/zen/3/bli_gemm_small.c index 0b9f9c79ef..d9c4047ec4 100644 --- a/kernels/zen/3/bli_gemm_small.c +++ b/kernels/zen/3/bli_gemm_small.c @@ -113,7 +113,19 @@ err_t bli_gemm_small #ifdef BLIS_ENABLE_MULTITHREADING AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); return BLIS_NOT_YET_IMPLEMENTED; +#else + // When dynamic dispatch is enabled i.e. library is built for 'amdzen' configuration. + // Invoke architecture specific kernels only if we are sure that we are running on zen, + // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). 
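+    // On anything other than the Zen cores checked below, returning
+    // BLIS_NOT_YET_IMPLEMENTED sends the caller back to the native gemm path;
+    // this is what fixes the single-threaded dynamic-dispatch crash described
+    // in the commit message above.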
+ arch_t id = bli_arch_query_id(); + bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + + if (0 == bamdzen) + { + return BLIS_NOT_YET_IMPLEMENTED; + } #endif + // If alpha is zero, scale by beta and return. if (bli_obj_equals(alpha, &BLIS_ZERO)) { From fd8a3aace9f23deb8cc73214687244f41d23fcfb Mon Sep 17 00:00:00 2001 From: Dipal M Zambare Date: Tue, 16 Nov 2021 14:35:25 +0530 Subject: [PATCH 052/243] Added support for zen4 architecture - Added configuration option for zen4 architecture - Added auto-detection of zen4 architecture - Added zen4 configuration for all checks related to AMD specific optimizations AMD-Internal: [CPUPL-1937] Change-Id: I1a1a45de04653f725aa53c30dffb6c0f7cc6e39a --- CMakeLists.txt | 15 +- build/blis_ref_kernel_mirror.py | 16 +- config/CMakeLists.txt | 6 +- config/zen4/CMakeLists.txt | 7 + config/zen4/bli_cntx_init_zen4.c | 285 +++++++++++++++++++++++++++++++ config/zen4/bli_family_zen4.h | 59 +++++++ config/zen4/make_defs.mk | 141 +++++++++++++++ config_registry | 3 +- frame/2/gemv/bli_gemv_unf_var1.c | 10 +- frame/2/gemv/bli_gemv_unf_var2.c | 20 ++- frame/2/trsv/bli_trsv_unf_var1.c | 12 +- frame/2/trsv/bli_trsv_unf_var2.c | 24 ++- frame/base/bli_arch.c | 4 + frame/base/bli_cpuid.c | 42 +++++ frame/base/bli_cpuid.h | 3 +- frame/base/bli_gks.c | 15 +- frame/compat/bla_amax.c | 16 +- frame/compat/bla_axpy.c | 20 ++- frame/compat/bla_copy.c | 10 +- frame/compat/bla_dot.c | 84 +++++---- frame/compat/bla_gemm.c | 5 +- frame/compat/bla_gemv.c | 36 ++-- frame/compat/bla_scal.c | 12 +- frame/compat/bla_swap.c | 12 +- frame/compat/bla_trsm.c | 12 +- frame/include/bli_arch_config.h | 3 + frame/include/bli_type_defs.h | 13 +- kernels/zen/3/bli_gemm_small.c | 39 +++-- kernels/zen4/README | 1 + ref_kernels/CMakeLists.txt | 1 + 30 files changed, 812 insertions(+), 114 deletions(-) create mode 100644 config/zen4/CMakeLists.txt create mode 100644 config/zen4/bli_cntx_init_zen4.c create mode 100644 config/zen4/bli_family_zen4.h create mode 100644 config/zen4/make_defs.mk create mode 100644 kernels/zen4/README diff --git a/CMakeLists.txt b/CMakeLists.txt index 8d892463a7..aebb509d73 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -50,20 +50,30 @@ elseif (${AOCL_BLIS_FAMILY} STREQUAL "zen3") add_definitions(-DBLIS_KERNELS_ZEN2) add_definitions(-DBLIS_KERNELS_ZEN) add_definitions(-DBLIS_KERNELS_HASWELL) +elseif (${AOCL_BLIS_FAMILY} STREQUAL "zen4") + add_definitions(-DBLIS_FAMILY_ZEN4) + add_definitions(-DBLIS_CONFIG_ZEN4) + add_definitions(-DBLIS_KERNELS_ZEN4) + add_definitions(-DBLIS_KERNELS_ZEN3) + add_definitions(-DBLIS_KERNELS_ZEN2) + add_definitions(-DBLIS_KERNELS_ZEN) + add_definitions(-DBLIS_KERNELS_HASWELL) elseif (${AOCL_BLIS_FAMILY} STREQUAL "amdzen") set(AOCL_BLIS_ZEN FALSE) add_definitions(-DBLIS_FAMILY_AMDZEN) + add_definitions(-DBLIS_CONFIG_ZEN4) add_definitions(-DBLIS_CONFIG_ZEN3) add_definitions(-DBLIS_CONFIG_ZEN2) add_definitions(-DBLIS_CONFIG_ZEN) add_definitions(-DBLIS_CONFIG_GENERIC) + add_definitions(-DBLIS_KERNELS_ZEN4) add_definitions(-DBLIS_KERNELS_ZEN3) add_definitions(-DBLIS_KERNELS_ZEN2) add_definitions(-DBLIS_KERNELS_HASWELL) add_definitions(-DBLIS_KERNELS_ZEN) add_definitions(-DBLIS_KERNELS_GENERIC) else () - message(FATAL_ERROR "Wrong machine configuration. Select one of zen, zen2, zen3 or amdzen") + message(FATAL_ERROR "Wrong machine configuration. 
Select one of zen, zen2, zen3, zen4 or amdzen") endif () set(TARGET_ARCH ${AOCL_BLIS_FAMILY}) @@ -393,11 +403,13 @@ include_directories(${CMAKE_SOURCE_DIR}/config/generic) include_directories(${CMAKE_SOURCE_DIR}/config/zen) include_directories(${CMAKE_SOURCE_DIR}/config/zen2) include_directories(${CMAKE_SOURCE_DIR}/config/zen3) +include_directories(${CMAKE_SOURCE_DIR}/config/zen4) if(${AOCL_BLIS_FAMILY} STREQUAL "amdzen") include_directories(${CMAKE_BINARY_DIR}/ref_kernels/generic) include_directories(${CMAKE_BINARY_DIR}/ref_kernels/zen) include_directories(${CMAKE_BINARY_DIR}/ref_kernels/zen2) include_directories(${CMAKE_BINARY_DIR}/ref_kernels/zen3) + include_directories(${CMAKE_BINARY_DIR}/ref_kernels/zen4) endif() include_directories(${CMAKE_SOURCE_DIR}/ref_kernels) include_directories(${CMAKE_SOURCE_DIR}/kernels) @@ -432,6 +444,7 @@ elseif (${AOCL_BLIS_FAMILY} STREQUAL "zen2") " ${CMAKE_CURRENT_SOURCE_DIR}/config/zen/" " ${CMAKE_CURRENT_SOURCE_DIR}/config/zen2/" " ${CMAKE_CURRENT_SOURCE_DIR}/config/zen3/" + " ${CMAKE_CURRENT_SOURCE_DIR}/config/zen4/" " ${CMAKE_CURRENT_SOURCE_DIR}/config/generic/" " ${CMAKE_CURRENT_SOURCE_DIR}/kernels/zen/" " ${CMAKE_CURRENT_SOURCE_DIR}/kernels/haswell/" diff --git a/build/blis_ref_kernel_mirror.py b/build/blis_ref_kernel_mirror.py index b756eb30b6..0dec5d66fa 100644 --- a/build/blis_ref_kernel_mirror.py +++ b/build/blis_ref_kernel_mirror.py @@ -68,13 +68,17 @@ def remove_lines_in_file(filename): with open(filename, 'r') as fd: file_content = fd.read() file_content = file_content.replace( - 'if(${TARGET_ARCH} STREQUAL amdzen)\nadd_subdirectory(${CMAKE_BINARY_' - 'DIR}/ref_kernels/generic ${CMAKE_BINARY_DIR}/ref_kernels/generic)\n' - 'add_subdirectory(${CMAKE_BINARY_DIR}/ref_kernels/zen ${CMAKE_BINARY_' - 'DIR}/ref_kernels/zen)\nadd_subdirectory(${CMAKE_BINARY_DIR}/' - 'ref_kernels/zen2 ${CMAKE_BINARY_DIR}/ref_kernels/zen2)\n' + 'if(${TARGET_ARCH} STREQUAL amdzen)\n' + 'add_subdirectory(${CMAKE_BINARY_DIR}/ref_kernels/generic' + '${CMAKE_BINARY_DIR}/ref_kernels/generic)\n' + 'add_subdirectory(${CMAKE_BINARY_DIR}/ref_kernels/zen' + '${CMAKE_BINARY_DIR}/ref_kernels/zen)\n' + 'add_subdirectory(${CMAKE_BINARY_DIR}/ref_kernels/zen2' + '${CMAKE_BINARY_DIR}/ref_kernels/zen2)\n' 'add_subdirectory(${CMAKE_BINARY_DIR}/ref_kernels/zen3 ' - '${CMAKE_BINARY_DIR}/ref_kernels/zen3)\nelse()', '\n') + '${CMAKE_BINARY_DIR}/ref_kernels/zen3)\n' + 'add_subdirectory(${CMAKE_BINARY_DIR}/ref_kernels/zen4 ' + '${CMAKE_BINARY_DIR}/ref_kernels/zen4)\nelse()', '\n') data = file_content.replace('endif()', '\n') with open(filename, 'w') as fd: fd.write(data + '\n') diff --git a/config/CMakeLists.txt b/config/CMakeLists.txt index 12568f67f7..dd8305c371 100644 --- a/config/CMakeLists.txt +++ b/config/CMakeLists.txt @@ -1,6 +1,9 @@ ##Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. 
## -if(${TARGET_ARCH} STREQUAL zen3) +if(${TARGET_ARCH} STREQUAL zen4) +message("The configuration is : ${TARGET_ARCH}") +add_subdirectory(zen4) +elseif(${TARGET_ARCH} STREQUAL zen3) message("The configuration is : ${TARGET_ARCH}") add_subdirectory(zen3) elseif(${TARGET_ARCH} STREQUAL zen2) @@ -15,6 +18,7 @@ add_subdirectory(generic) add_subdirectory(zen) add_subdirectory(zen2) add_subdirectory(zen3) +add_subdirectory(zen4) elseif(${TARGET_ARCH} STREQUAL haswell) message("The configuration is : ${TARGET_ARCH}") add_subdirectory(haswell) diff --git a/config/zen4/CMakeLists.txt b/config/zen4/CMakeLists.txt new file mode 100644 index 0000000000..1c18083516 --- /dev/null +++ b/config/zen4/CMakeLists.txt @@ -0,0 +1,7 @@ +##Copyright (C) 2021, Advanced Micro Devices, Inc ## + +target_sources("${PROJECT_NAME}" + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/bli_cntx_init_zen4.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_family_zen4.h + ) diff --git a/config/zen4/bli_cntx_init_zen4.c b/config/zen4/bli_cntx_init_zen4.c new file mode 100644 index 0000000000..806c268a0f --- /dev/null +++ b/config/zen4/bli_cntx_init_zen4.c @@ -0,0 +1,285 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +void bli_cntx_init_zen4( cntx_t* cntx ) +{ + blksz_t blkszs[ BLIS_NUM_BLKSZS ]; + blksz_t thresh[ BLIS_NUM_THRESH ]; + // Set default kernel blocksizes and functions. + bli_cntx_init_zen4_ref( cntx ); + + // ------------------------------------------------------------------------- + + // Update the context with optimized native gemm micro-kernels and + // their storage preferences. 
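+    // Note: this initial zen4 context reuses the existing AVX2 (haswell/zen)
+    // kernels rather than any zen4-specific ones; the kernels/zen4 directory
+    // added by this patch holds only a README.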
+ bli_cntx_set_l3_nat_ukrs + ( + 8, + // gemm + BLIS_GEMM_UKR, BLIS_FLOAT, bli_sgemm_haswell_asm_6x16, TRUE, + BLIS_GEMM_UKR, BLIS_DOUBLE, bli_dgemm_haswell_asm_6x8, TRUE, + BLIS_GEMM_UKR, BLIS_SCOMPLEX, bli_cgemm_haswell_asm_3x8, TRUE, + BLIS_GEMM_UKR, BLIS_DCOMPLEX, bli_zgemm_haswell_asm_3x4, TRUE, + // gemmtrsm_l + BLIS_GEMMTRSM_L_UKR, BLIS_FLOAT, bli_sgemmtrsm_l_haswell_asm_6x16, TRUE, + BLIS_GEMMTRSM_L_UKR, BLIS_DOUBLE, bli_dgemmtrsm_l_haswell_asm_6x8, TRUE, + // gemmtrsm_u + BLIS_GEMMTRSM_U_UKR, BLIS_FLOAT, bli_sgemmtrsm_u_haswell_asm_6x16, TRUE, + BLIS_GEMMTRSM_U_UKR, BLIS_DOUBLE, bli_dgemmtrsm_u_haswell_asm_6x8, TRUE, + cntx + ); + + // Update the context with architecture specific threshold functions + bli_cntx_set_l3_thresh_funcs + ( + 2, + // GEMMT + BLIS_GEMMT, bli_cntx_gemmtsup_thresh_is_met_zen, + // SYRK + BLIS_SYRK, bli_cntx_syrksup_thresh_is_met_zen, + cntx + ); + + // packm kernels + bli_cntx_set_packm_kers + ( + 8, + BLIS_PACKM_6XK_KER, BLIS_FLOAT, bli_spackm_haswell_asm_6xk, + BLIS_PACKM_16XK_KER, BLIS_FLOAT, bli_spackm_haswell_asm_16xk, + BLIS_PACKM_6XK_KER, BLIS_DOUBLE, bli_dpackm_haswell_asm_6xk, + BLIS_PACKM_8XK_KER, BLIS_DOUBLE, bli_dpackm_haswell_asm_8xk, + BLIS_PACKM_3XK_KER, BLIS_SCOMPLEX, bli_cpackm_haswell_asm_3xk, + BLIS_PACKM_8XK_KER, BLIS_SCOMPLEX, bli_cpackm_haswell_asm_8xk, + BLIS_PACKM_3XK_KER, BLIS_DCOMPLEX, bli_zpackm_haswell_asm_3xk, + BLIS_PACKM_4XK_KER, BLIS_DCOMPLEX, bli_zpackm_haswell_asm_4xk, + cntx + ); + + // Update the context with optimized level-1f kernels. + bli_cntx_set_l1f_kers + ( + 6, + // axpyf + BLIS_AXPYF_KER, BLIS_FLOAT, bli_saxpyf_zen_int_5, + BLIS_AXPYF_KER, BLIS_DOUBLE, bli_daxpyf_zen_int_5, + BLIS_AXPYF_KER, BLIS_SCOMPLEX, bli_caxpyf_zen_int_5, + BLIS_AXPYF_KER, BLIS_DCOMPLEX, bli_zaxpyf_zen_int_5, + // dotxf + BLIS_DOTXF_KER, BLIS_FLOAT, bli_sdotxf_zen_int_8, + BLIS_DOTXF_KER, BLIS_DOUBLE, bli_ddotxf_zen_int_8, + cntx + ); + + // Update the context with optimized level-1v kernels. + bli_cntx_set_l1v_kers + ( + 20, + + // amaxv + BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int, + BLIS_AMAXV_KER, BLIS_DOUBLE, bli_damaxv_zen_int, + // axpyv + + // axpyv + BLIS_AXPYV_KER, BLIS_FLOAT, bli_saxpyv_zen_int10, + BLIS_AXPYV_KER, BLIS_DOUBLE, bli_daxpyv_zen_int10, + BLIS_AXPYV_KER, BLIS_SCOMPLEX, bli_caxpyv_zen_int5, + BLIS_AXPYV_KER, BLIS_DCOMPLEX, bli_zaxpyv_zen_int5, + + // dotv + BLIS_DOTV_KER, BLIS_FLOAT, bli_sdotv_zen_int10, + BLIS_DOTV_KER, BLIS_DOUBLE, bli_ddotv_zen_int10, + BLIS_DOTV_KER, BLIS_SCOMPLEX, bli_cdotv_zen_int5, + BLIS_DOTV_KER, BLIS_DCOMPLEX, bli_zdotv_zen_int5, + + // dotxv + BLIS_DOTXV_KER, BLIS_FLOAT, bli_sdotxv_zen_int, + BLIS_DOTXV_KER, BLIS_DOUBLE, bli_ddotxv_zen_int, + + // scalv + BLIS_SCALV_KER, BLIS_FLOAT, bli_sscalv_zen_int10, + BLIS_SCALV_KER, BLIS_DOUBLE, bli_dscalv_zen_int10, + + //swap + BLIS_SWAPV_KER, BLIS_FLOAT, bli_sswapv_zen_int8, + BLIS_SWAPV_KER, BLIS_DOUBLE, bli_dswapv_zen_int8, + + //copy + BLIS_COPYV_KER, BLIS_FLOAT, bli_scopyv_zen_int, + BLIS_COPYV_KER, BLIS_DOUBLE, bli_dcopyv_zen_int, + + //set + BLIS_SETV_KER, BLIS_FLOAT, bli_ssetv_zen_int, + BLIS_SETV_KER, BLIS_DOUBLE, bli_dsetv_zen_int, + cntx + ); + + // Initialize level-3 blocksize objects with architecture-specific values. + // + // These are reference block sizes and may be overridden based on + // number of threads used at runtime. 
+ // s d c z + bli_blksz_init_easy( &blkszs[ BLIS_MR ], 6, 6, 3, 3 ); + bli_blksz_init_easy( &blkszs[ BLIS_NR ], 16, 8, 8, 4 ); + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 72, 144, 18 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], 256, 256, 256, 566 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 4080, 4080, 4080, 256 ); + + bli_blksz_init_easy( &blkszs[ BLIS_AF ], 5, 5, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_DF ], 8, 8, -1, -1 ); + + // Update the context with the current architecture's register and cache + // blocksizes (and multiples) for native execution. + bli_cntx_set_blkszs + ( + BLIS_NAT, 7, + // level-3 + BLIS_NC, &blkszs[ BLIS_NC ], BLIS_NR, + BLIS_KC, &blkszs[ BLIS_KC ], BLIS_KR, + BLIS_MC, &blkszs[ BLIS_MC ], BLIS_MR, + BLIS_NR, &blkszs[ BLIS_NR ], BLIS_NR, + BLIS_MR, &blkszs[ BLIS_MR ], BLIS_MR, + // level-1f + BLIS_AF, &blkszs[ BLIS_AF ], BLIS_AF, + BLIS_DF, &blkszs[ BLIS_DF ], BLIS_DF, + cntx + ); + // ------------------------------------------------------------------------- + + //Initialize TRSM blocksize objects with architecture-specific values. + //Using different cache block sizes for TRSM instead of common level-3 block sizes. + //Tuning is done for double-precision only. + // s d c z + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 72, 144, 72 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], 256, 492, 256, 256 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 4080, 1600, 4080, 4080 ); + + // Update the context with the current architecture's register and cache + // blocksizes for level-3 TRSM problems. + bli_cntx_set_trsm_blkszs + ( + 5, + BLIS_NC, &blkszs[ BLIS_NC ], + BLIS_KC, &blkszs[ BLIS_KC ], + BLIS_MC, &blkszs[ BLIS_MC ], + BLIS_NR, &blkszs[ BLIS_NR ], + BLIS_MR, &blkszs[ BLIS_MR ], + cntx + ); + + // Initialize sup thresholds with architecture-appropriate values. s d c z + bli_blksz_init_easy( &thresh[ BLIS_MT ], 512, 256, 380, 110 ); + bli_blksz_init_easy( &thresh[ BLIS_NT ], 200, 256, 256, 128 ); + bli_blksz_init_easy( &thresh[ BLIS_KT ], 240, 220, 220, 110 ); + + // Initialize the context with the sup thresholds. + bli_cntx_set_l3_sup_thresh + ( + 3, + BLIS_MT, &thresh[ BLIS_MT ], + BLIS_NT, &thresh[ BLIS_NT ], + BLIS_KT, &thresh[ BLIS_KT ], + cntx + ); + + // Initialize the context with the sup handlers. + bli_cntx_set_l3_sup_handlers + ( + 2, + BLIS_GEMM, bli_gemmsup_ref, + BLIS_GEMMT, bli_gemmtsup_ref, + cntx + ); + + // Update the context with optimized small/unpacked gemm kernels. 
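+    // Only six of the eight storage combinations are overridden for each
+    // complex datatype below; BLIS_RRC and BLIS_CRC keep the defaults
+    // installed by bli_cntx_init_zen4_ref() above.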
+ bli_cntx_set_l3_sup_kers + ( + 28, + //BLIS_RCR, BLIS_DOUBLE, bli_dgemmsup_r_haswell_ref, + BLIS_RRR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, + BLIS_RRC, BLIS_DOUBLE, bli_dgemmsup_rd_haswell_asm_6x8m, TRUE, + BLIS_RCR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, + BLIS_RCC, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, + BLIS_CRR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8m, TRUE, + BLIS_CRC, BLIS_DOUBLE, bli_dgemmsup_rd_haswell_asm_6x8n, TRUE, + BLIS_CCR, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, + BLIS_CCC, BLIS_DOUBLE, bli_dgemmsup_rv_haswell_asm_6x8n, TRUE, + BLIS_RRR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16m, TRUE, + BLIS_RRC, BLIS_FLOAT, bli_sgemmsup_rd_zen_asm_6x16m, TRUE, + BLIS_RCR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16m, TRUE, + BLIS_RCC, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16n, TRUE, + BLIS_CRR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16m, TRUE, + BLIS_CRC, BLIS_FLOAT, bli_sgemmsup_rd_zen_asm_6x16n, TRUE, + BLIS_CCR, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16n, TRUE, + BLIS_CCC, BLIS_FLOAT, bli_sgemmsup_rv_zen_asm_6x16n, TRUE, + BLIS_RRR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, + BLIS_RCR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, + BLIS_CRR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8m, TRUE, + BLIS_RCC, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, + BLIS_CCR, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, + BLIS_CCC, BLIS_SCOMPLEX, bli_cgemmsup_rv_zen_asm_3x8n, TRUE, + BLIS_RRR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, + BLIS_RCR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, + BLIS_CRR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4m, TRUE, + BLIS_RCC, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE, + BLIS_CCR, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE, + BLIS_CCC, BLIS_DCOMPLEX, bli_zgemmsup_rv_zen_asm_3x4n, TRUE, + cntx + ); + + // Initialize level-3 sup blocksize objects with architecture-specific + // values. + // s d c z + bli_blksz_init ( &blkszs[ BLIS_MR ], 6, 6, 3, 3, + 9, 9, 3, 3 ); + bli_blksz_init_easy( &blkszs[ BLIS_NR ], 16, 8, 8, 4 ); + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 72, 72, 36 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], 512, 256, 128, 64 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 8160, 4080, 2040, 1020 ); + + // Update the context with the current architecture's register and cache + // blocksizes for small/unpacked level-3 problems. + bli_cntx_set_l3_sup_blkszs + ( + 5, + BLIS_NC, &blkszs[ BLIS_NC ], + BLIS_KC, &blkszs[ BLIS_KC ], + BLIS_MC, &blkszs[ BLIS_MC ], + BLIS_NR, &blkszs[ BLIS_NR ], + BLIS_MR, &blkszs[ BLIS_MR ], + cntx + ); +} diff --git a/config/zen4/bli_family_zen4.h b/config/zen4/bli_family_zen4.h new file mode 100644 index 0000000000..9c70fcef83 --- /dev/null +++ b/config/zen4/bli_family_zen4.h @@ -0,0 +1,59 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. 
+ - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLI_FAMILY_ZEN4_ +#define BLI_FAMILY_ZEN4_ + +// By default, it is effective to parallelize the outer loops. +// Setting these macros to 1 will force JR and IR inner loops +// to be not paralleized. +// + +#define BLIS_THREAD_MAX_IR 1 +#define BLIS_THREAD_MAX_JR 1 + +#define BLIS_ENABLE_SMALL_MATRIX +#define BLIS_ENABLE_SMALL_MATRIX_TRSM + +// This will select the threshold below which small matrix code will be called. +#define BLIS_SMALL_MATRIX_THRES 700 +#define BLIS_SMALL_M_RECT_MATRIX_THRES 160 +#define BLIS_SMALL_K_RECT_MATRIX_THRES 128 + +#define BLIS_SMALL_MATRIX_A_THRES_M_SYRK 96 +#define BLIS_SMALL_MATRIX_A_THRES_N_SYRK 128 + +//#define BLIS_ENABLE_FAST_MATH + +#endif diff --git a/config/zen4/make_defs.mk b/config/zen4/make_defs.mk new file mode 100644 index 0000000000..352bd29c4e --- /dev/null +++ b/config/zen4/make_defs.mk @@ -0,0 +1,141 @@ +# +# +# BLIS +# An object-based framework for developing high-performance BLAS-like +# libraries. +# +# Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# - Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# - Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# - Neither the name(s) of the copyright holder(s) nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +# + +# FLAGS that are specific to the 'zen3' architecture are added here. +# FLAGS that are common for all the AMD architectures are present in +# config/zen/amd_config.mk. + +# Declare the name of the current configuration and add it to the +# running list of configurations included by common.mk. +THIS_CONFIG := zen4 +#CONFIGS_INCL += $(THIS_CONFIG) + +# +# --- Determine the C compiler and related flags --- +# + +# NOTE: The build system will append these variables with various +# general-purpose/configuration-agnostic flags in common.mk. You +# may specify additional flags here as needed. + +# Since we removed BLIS_CONFIG_EPYC from header file, we need to +# add it here at two places, +# CPPROCFLAGS = This will enable it for framework code +# This flag is used when configure is invoked with specific architecture +# CKOPTFLAGS = This will enable it for architecture specific kernels +# This flag is used for kernels assocaited with this architecture +# irrespective of the configuration it is built for. + +CPPROCFLAGS := -DBLIS_CONFIG_EPYC +CMISCFLAGS := +CPICFLAGS := +CWARNFLAGS := + +ifneq ($(DEBUG_TYPE),off) +CDBGFLAGS := -g +endif + +ifeq ($(DEBUG_TYPE),noopt) +COPTFLAGS := -O0 +else +COPTFLAGS := -O3 +endif + +# Flags specific to optimized kernels. +# NOTE: The -fomit-frame-pointer option is needed for some kernels because +# they make explicit use of the rbp register. +CKOPTFLAGS := $(COPTFLAGS) -fomit-frame-pointer +ifeq ($(CC_VENDOR),gcc) +GCC_VERSION := $(strip $(shell $(CC) -dumpversion | cut -d. -f1)) +# gcc or clang version must be atleast 4.0 +# gcc 9.0 or later: +ifeq ($(shell test $(GCC_VERSION) -ge 11; echo $$?),0) +CKVECFLAGS += -march=znver3 +else +ifeq ($(shell test $(GCC_VERSION) -ge 9; echo $$?),0) +CKVECFLAGS += -march=znver2 +else +# If gcc is older than 9.1.0 but at least 6.1.0, then we can use -march=znver1 +# as the fallback option. +CRVECFLAGS += -march=znver1 -mno-avx256-split-unaligned-store +CKVECFLAGS += -march=znver1 -mno-avx256-split-unaligned-store +endif # GCC 9 +endif # GCC 11 +else +ifeq ($(CC_VENDOR),clang) + +# AOCC clang has various formats for the version line + +# AOCC.LLVM.2.0.0.B191.2019_07_19 clang version 8.0.0 (CLANG: Jenkins AOCC_2_0_0-Build#191) (based on LLVM AOCC.LLVM.2.0.0.B191.2019_07_19) +# AOCC.LLVM.2.1.0.B1030.2019_11_12 clang version 9.0.0 (CLANG: Build#1030) (based on LLVM AOCC.LLVM.2.1.0.B1030.2019_11_12) +# AMD clang version 10.0.0 (CLANG: AOCC_2.2.0-Build#93 2020_06_25) (based on LLVM Mirror.Version.10.0.0) +# AMD clang version 11.0.0 (CLANG: AOCC_2.3.0-Build#85 2020_11_10) (based on LLVM Mirror.Version.11.0.0) +# AMD clang version 12.0.0 (CLANG: AOCC_3.0.0-Build#2 2020_11_05) (based on LLVM Mirror.Version.12.0.0) + +# For our prupose we just want to know if it version 2x or 3x + +# for version 3x we will enable znver3 +ifeq ($(strip $(shell $(CC) -v |&head -1 |grep -c 'AOCC_3')),1) +CKVECFLAGS += -march=znver3 +else +# for version 2x we will enable znver2 +ifeq ($(strip $(shell $(CC) -v |&head -1 |grep -c 'AOCC.LLVM.2\|AOCC_2')),1) +CKVECFLAGS += -march=znver2 +else +#if compiling with clang +VENDOR_STRING := $(strip $(shell ${CC_VENDOR} --version | egrep -o '[0-9]+\.[0-9]+\.?[0-9]*')) +CC_MAJOR := $(shell (echo ${VENDOR_STRING} | cut -d. -f1)) +#clang 9.0 or later: +ifeq ($(shell test $(CC_MAJOR) -ge 9; echo $$?),0) +CKVECFLAGS += -march=znver2 +else +CKVECFLAGS += -march=znver1 +endif # ge 9 +endif # aocc 2 +endif # aocc 3 +endif # clang +endif # gcc + +# Flags specific to reference kernels. 
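+# The reference kernels for this sub-configuration are built with the same
+# optimization and vector flags as the optimized kernels (CROPTFLAGS and
+# CRVECFLAGS below simply copy CKOPTFLAGS and CKVECFLAGS); only the
+# -DBLIS_CONFIG_EPYC define added afterwards is kept off the reference build.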
+CROPTFLAGS := $(CKOPTFLAGS) +CRVECFLAGS := $(CKVECFLAGS) + +# Add this after updating variables for reference kernels +# we don't want this defined for them +CKOPTFLAGS += -DBLIS_CONFIG_EPYC + +# Store all of the variables here to new variables containing the +# configuration name. +$(eval $(call store-make-defs,$(THIS_CONFIG))) + diff --git a/config_registry b/config_registry index 558eccc30c..822b133f5c 100644 --- a/config_registry +++ b/config_registry @@ -11,7 +11,7 @@ x86_64: intel64 amd64 amd64_legacy intel64: skx knl haswell sandybridge penryn generic amd64_legacy: excavator steamroller piledriver bulldozer generic -amdzen: zen3 zen2 zen generic +amdzen: zen4 zen3 zen2 zen generic # NOTE: ARM families will remain disabled until runtime hardware detection # logic is added to BLIS. @@ -26,6 +26,7 @@ sandybridge: sandybridge penryn: penryn # AMD architectures. +zen4: zen4/zen4/zen3/zen2/zen/haswell zen3: zen3/zen3/zen2/zen/haswell zen2: zen2/zen2/zen/haswell zen: zen/zen/haswell diff --git a/frame/2/gemv/bli_gemv_unf_var1.c b/frame/2/gemv/bli_gemv_unf_var1.c index 4f0054c1f1..e468587d4b 100644 --- a/frame/2/gemv/bli_gemv_unf_var1.c +++ b/frame/2/gemv/bli_gemv_unf_var1.c @@ -149,7 +149,10 @@ void bli_dgemv_unf_var1 // Invoke architecture specific kernels only if we are sure that we are running on zen, // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + bool bamdzen = (id == BLIS_ARCH_ZEN4) || + (id == BLIS_ARCH_ZEN3) || + (id == BLIS_ARCH_ZEN2) || + (id == BLIS_ARCH_ZEN); if (bamdzen == 0) { @@ -313,7 +316,10 @@ void bli_sgemv_unf_var1 // Invoke architecture specific kernels only if we are sure that we are running on zen, // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + bool bamdzen = (id == BLIS_ARCH_ZEN4) || + (id == BLIS_ARCH_ZEN3) || + (id == BLIS_ARCH_ZEN2) || + (id == BLIS_ARCH_ZEN); if (bamdzen == 0) { diff --git a/frame/2/gemv/bli_gemv_unf_var2.c b/frame/2/gemv/bli_gemv_unf_var2.c index 84a67c3189..093b615a7d 100644 --- a/frame/2/gemv/bli_gemv_unf_var2.c +++ b/frame/2/gemv/bli_gemv_unf_var2.c @@ -179,7 +179,10 @@ void bli_dgemv_unf_var2 // Invoke architecture specific kernels only if we are sure that we are running on zen, // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + bool bamdzen = (id == BLIS_ARCH_ZEN4) || + (id == BLIS_ARCH_ZEN3) || + (id == BLIS_ARCH_ZEN2) || + (id == BLIS_ARCH_ZEN); if (bamdzen == 0) { @@ -391,7 +394,10 @@ void bli_sgemv_unf_var2 // Invoke architecture specific kernels only if we are sure that we are running on zen, // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). 
arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + bool bamdzen = (id == BLIS_ARCH_ZEN4) || + (id == BLIS_ARCH_ZEN3) || + (id == BLIS_ARCH_ZEN2) || + (id == BLIS_ARCH_ZEN); if (bamdzen == 0) { @@ -556,7 +562,10 @@ void bli_zgemv_unf_var2 // Invoke architecture specific kernels only if we are sure that we are running on zen, // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + bool bamdzen = (id == BLIS_ARCH_ZEN4) || + (id == BLIS_ARCH_ZEN3) || + (id == BLIS_ARCH_ZEN2) || + (id == BLIS_ARCH_ZEN); if (bamdzen == 0) { @@ -742,7 +751,10 @@ void bli_cgemv_unf_var2 // Invoke architecture specific kernels only if we are sure that we are running on zen, // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + bool bamdzen = (id == BLIS_ARCH_ZEN4) || + (id == BLIS_ARCH_ZEN3) || + (id == BLIS_ARCH_ZEN2) || + (id == BLIS_ARCH_ZEN); if (bamdzen == 0) { diff --git a/frame/2/trsv/bli_trsv_unf_var1.c b/frame/2/trsv/bli_trsv_unf_var1.c index 4f19e1ac5e..f2f9ea6a6d 100644 --- a/frame/2/trsv/bli_trsv_unf_var1.c +++ b/frame/2/trsv/bli_trsv_unf_var1.c @@ -298,7 +298,11 @@ void bli_dtrsv_unf_var1 /* Assign kernel function pointer and fusing factor. */ arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + bool bamdzen = (id == BLIS_ARCH_ZEN4) || + (id == BLIS_ARCH_ZEN3) || + (id == BLIS_ARCH_ZEN2) || + (id == BLIS_ARCH_ZEN); + if (bamdzen) { kfp_df = bli_ddotxf_zen_int_8; b_fuse = 8; @@ -500,7 +504,11 @@ void bli_strsv_unf_var1 /* Assign kernel function pointer and fusing factor. */ arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + bool bamdzen = (id == BLIS_ARCH_ZEN4) || + (id == BLIS_ARCH_ZEN3) || + (id == BLIS_ARCH_ZEN2) || + (id == BLIS_ARCH_ZEN); + if (bamdzen) { kfp_df = bli_sdotxf_zen_int_8; b_fuse = 8; diff --git a/frame/2/trsv/bli_trsv_unf_var2.c b/frame/2/trsv/bli_trsv_unf_var2.c index 7ece8f8470..2fd89dacf5 100644 --- a/frame/2/trsv/bli_trsv_unf_var2.c +++ b/frame/2/trsv/bli_trsv_unf_var2.c @@ -294,7 +294,11 @@ void bli_dtrsv_unf_var2 /* Assign kernel function pointer and fusing factor. */ arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + bool bamdzen = (id == BLIS_ARCH_ZEN4) || + (id == BLIS_ARCH_ZEN3) || + (id == BLIS_ARCH_ZEN2) || + (id == BLIS_ARCH_ZEN); + if (bamdzen) { kfp_af = bli_daxpyf_zen_int_16x4; b_fuse = 4; @@ -490,7 +494,11 @@ void bli_strsv_unf_var2 /* Assign function pointer and fusing factor. */ arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + bool bamdzen = (id == BLIS_ARCH_ZEN4) || + (id == BLIS_ARCH_ZEN3) || + (id == BLIS_ARCH_ZEN2) || + (id == BLIS_ARCH_ZEN); + if (bamdzen) { kfp_af = bli_saxpyf_zen_int_5; b_fuse = 5; @@ -686,7 +694,11 @@ void bli_ztrsv_unf_var2 /* Assign function pointer and fusing factor. 
*/ arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + bool bamdzen = (id == BLIS_ARCH_ZEN4) || + (id == BLIS_ARCH_ZEN3) || + (id == BLIS_ARCH_ZEN2) || + (id == BLIS_ARCH_ZEN); + if (bamdzen) { kfp_af = bli_zaxpyf_zen_int_5; b_fuse = 5; @@ -881,7 +893,11 @@ void bli_ctrsv_unf_var2 /* Assign function pointer and fusing factor. */ arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + bool bamdzen = (id == BLIS_ARCH_ZEN4) || + (id == BLIS_ARCH_ZEN3) || + (id == BLIS_ARCH_ZEN2) || + (id == BLIS_ARCH_ZEN); + if (bamdzen) { kfp_af = bli_caxpyf_zen_int_5; b_fuse = 5; diff --git a/frame/base/bli_arch.c b/frame/base/bli_arch.c index 153787d3ed..2696236717 100644 --- a/frame/base/bli_arch.c +++ b/frame/base/bli_arch.c @@ -154,6 +154,9 @@ void bli_arch_set_id( void ) #endif // AMD microarchitectures. + #ifdef BLIS_FAMILY_ZEN4 + id = BLIS_ARCH_ZEN4; + #endif #ifdef BLIS_FAMILY_ZEN3 id = BLIS_ARCH_ZEN3; #endif @@ -236,6 +239,7 @@ static char* config_name[ BLIS_NUM_ARCHS ] = "sandybridge", "penryn", + "zen4", "zen3", "zen2", "zen", diff --git a/frame/base/bli_cpuid.c b/frame/base/bli_cpuid.c index 4b3837544f..d5d8315543 100644 --- a/frame/base/bli_cpuid.c +++ b/frame/base/bli_cpuid.c @@ -114,6 +114,10 @@ arch_t bli_cpuid_query_id( void ) // Check for each AMD configuration that is enabled, check for that // microarchitecture. We check from most recent to most dated. +#ifdef BLIS_CONFIG_ZEN4 + if ( bli_cpuid_is_zen4( family, model, features ) ) + return BLIS_ARCH_ZEN4; +#endif #ifdef BLIS_CONFIG_ZEN3 if ( bli_cpuid_is_zen3( family, model, features ) ) return BLIS_ARCH_ZEN3; @@ -264,6 +268,44 @@ bool bli_cpuid_is_penryn } // ----------------------------------------------------------------------------- +bool bli_cpuid_is_zen4 + ( + uint32_t family, + uint32_t model, + uint32_t features + ) +{ + // Check for expected CPU features. + const uint32_t expected = FEATURE_SSE3 | + FEATURE_SSSE3 | + FEATURE_SSE41 | + FEATURE_SSE42 | + FEATURE_AVX | + FEATURE_AVX2 | + FEATURE_FMA3 | + FEATURE_AVX512F | + FEATURE_AVX512DQ | + FEATURE_AVX512CD | + FEATURE_AVX512BW | + FEATURE_AVX512VL ; + + if ( !bli_cpuid_has_features( features, expected ) ) return FALSE; + + // For zen4 the family id is 0x19 + if ( family != 0x19 ) return FALSE; + + // Finally, check for specific models: + // Zen 4 maps to couple of different model number ranges + // we check for all of them. + const bool is_arch + = + (0x10 <= model && model <= 0x1f ); + + if ( !is_arch ) return FALSE; + + return TRUE; +} + bool bli_cpuid_is_zen3 ( uint32_t family, diff --git a/frame/base/bli_cpuid.h b/frame/base/bli_cpuid.h index 62c05ad5ca..a9f960847a 100644 --- a/frame/base/bli_cpuid.h +++ b/frame/base/bli_cpuid.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018-2020, Advanced Micro Devices, Inc. + Copyright (C) 2018-2021, Advanced Micro Devices, Inc. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -61,6 +61,7 @@ bool bli_cpuid_is_sandybridge( uint32_t family, uint32_t model, uint32_t feature bool bli_cpuid_is_penryn( uint32_t family, uint32_t model, uint32_t features ); // AMD +bool bli_cpuid_is_zen4( uint32_t family, uint32_t model, uint32_t features ); bool bli_cpuid_is_zen3( uint32_t family, uint32_t model, uint32_t features ); bool bli_cpuid_is_zen2( uint32_t family, uint32_t model, uint32_t features ); bool bli_cpuid_is_zen( uint32_t family, uint32_t model, uint32_t features ); diff --git a/frame/base/bli_gks.c b/frame/base/bli_gks.c index 746b141a93..acb36d306f 100644 --- a/frame/base/bli_gks.c +++ b/frame/base/bli_gks.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018-2020, Advanced Micro Devices, Inc. + Copyright (C) 2018-2021, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -97,6 +97,11 @@ void bli_gks_init( void ) #endif // AMD architectures +#ifdef BLIS_CONFIG_ZEN4 + bli_gks_register_cntx( BLIS_ARCH_ZEN4, bli_cntx_init_zen4, + bli_cntx_init_zen4_ref, + bli_cntx_init_zen4_ind ); +#endif #ifdef BLIS_CONFIG_ZEN3 bli_gks_register_cntx( BLIS_ARCH_ZEN3, bli_cntx_init_zen3, bli_cntx_init_zen3_ref, @@ -165,7 +170,7 @@ void bli_gks_init( void ) bli_gks_register_cntx( BLIS_ARCH_POWER10, bli_cntx_init_power10, bli_cntx_init_power10_ref, bli_cntx_init_power10_ind ); -#endif +#endif #ifdef BLIS_CONFIG_POWER9 bli_gks_register_cntx( BLIS_ARCH_POWER9, bli_cntx_init_power9, bli_cntx_init_power9_ref, @@ -247,7 +252,7 @@ void bli_gks_finalize( void ) void bli_gks_init_index( void ) { // This function is called by bli_gks_init(). It simply initializes all - // architecture id elements of the internal arrays to NULL. + // architecture id elements of the internal arrays to NULL. const size_t gks_size = sizeof( cntx_t* ) * BLIS_NUM_ARCHS; const size_t fpa_size = sizeof( void_fp ) * BLIS_NUM_ARCHS; @@ -360,7 +365,7 @@ void bli_gks_register_cntx // functions for reference kernels and induced method execution. The // former will be used whenever we need to obtain reference kernels and // latter will be used later on if the user calls a level-3 function - // with induced execution enabled. + // with induced execution enabled. cntx_ref_init[ id ] = ref_fp; cntx_ind_init[ id ] = ind_fp; @@ -554,7 +559,7 @@ cntx_t* bli_gks_query_ind_cntx // function on the newly allocated structure, we must first copy // over the contents of the native context. *gks_id_ind = *gks_id_nat; - + // Use the architecture id to look up the function pointer to the // context initialization function for induced methods. ind_cntx_init_ft f = cntx_ind_init[ id ]; diff --git a/frame/compat/bla_amax.c b/frame/compat/bla_amax.c index fabed6e72d..214dfe67aa 100644 --- a/frame/compat/bla_amax.c +++ b/frame/compat/bla_amax.c @@ -164,7 +164,10 @@ f77_int isamax_ // Invoke architecture specific kernels only if we are sure that we are running on zen, // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). 
arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + bool bamdzen = (id == BLIS_ARCH_ZEN4) || + (id == BLIS_ARCH_ZEN3) || + (id == BLIS_ARCH_ZEN2) || + (id == BLIS_ARCH_ZEN); if (bamdzen) { @@ -180,7 +183,7 @@ f77_int isamax_ else { PASTEMAC2(s,amaxv,BLIS_TAPI_EX_SUF) - ( + ( n0, x0, incx0, &bli_index, @@ -188,7 +191,7 @@ f77_int isamax_ NULL ); } - + /* Convert zero-based BLIS (C) index to one-based BLAS (Fortran) index. Also, if the BLAS integer size differs from the BLIS integer size, that typecast occurs here. */ @@ -265,7 +268,10 @@ f77_int idamax_ // Invoke architecture specific kernels only if we are sure that we are running on zen, // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + bool bamdzen = (id == BLIS_ARCH_ZEN4) || + (id == BLIS_ARCH_ZEN3) || + (id == BLIS_ARCH_ZEN2) || + (id == BLIS_ARCH_ZEN); if (bamdzen) { @@ -281,7 +287,7 @@ f77_int idamax_ else { PASTEMAC2(d,amaxv,BLIS_TAPI_EX_SUF) - ( + ( n0, x0, incx0, &bli_index, diff --git a/frame/compat/bla_axpy.c b/frame/compat/bla_axpy.c index 41885e95d6..93f30e1e55 100644 --- a/frame/compat/bla_axpy.c +++ b/frame/compat/bla_axpy.c @@ -151,7 +151,10 @@ void saxpy_ // Invoke architecture specific kernels only if we are sure that we are running on zen, // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + bool bamdzen = (id == BLIS_ARCH_ZEN4) || + (id == BLIS_ARCH_ZEN3) || + (id == BLIS_ARCH_ZEN2) || + (id == BLIS_ARCH_ZEN); if (bamdzen) { @@ -248,7 +251,10 @@ void daxpy_ // Invoke architecture specific kernels only if we are sure that we are running on zen, // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + bool bamdzen = (id == BLIS_ARCH_ZEN4) || + (id == BLIS_ARCH_ZEN3) || + (id == BLIS_ARCH_ZEN2) || + (id == BLIS_ARCH_ZEN); if (bamdzen) { @@ -346,7 +352,10 @@ void caxpy_ // Invoke architecture specific kernels only if we are sure that we are running on zen, // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + bool bamdzen = (id == BLIS_ARCH_ZEN4) || + (id == BLIS_ARCH_ZEN3) || + (id == BLIS_ARCH_ZEN2) || + (id == BLIS_ARCH_ZEN); if (bamdzen) { @@ -444,7 +453,10 @@ void zaxpy_ // Invoke architecture specific kernels only if we are sure that we are running on zen, // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). 
arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + bool bamdzen = (id == BLIS_ARCH_ZEN4) || + (id == BLIS_ARCH_ZEN3) || + (id == BLIS_ARCH_ZEN2) || + (id == BLIS_ARCH_ZEN); if (bamdzen) { diff --git a/frame/compat/bla_copy.c b/frame/compat/bla_copy.c index 61df88cf1e..f4aa3ee83b 100644 --- a/frame/compat/bla_copy.c +++ b/frame/compat/bla_copy.c @@ -158,7 +158,10 @@ void scopy_ // Invoke architecture specific kernels only if we are sure that we are running on zen, // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + bool bamdzen = (id == BLIS_ARCH_ZEN4) || + (id == BLIS_ARCH_ZEN3) || + (id == BLIS_ARCH_ZEN2) || + (id == BLIS_ARCH_ZEN); if (bamdzen) { @@ -258,7 +261,10 @@ void dcopy_ // Invoke architecture specific kernels only if we are sure that we are running on zen, // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + bool bamdzen = (id == BLIS_ARCH_ZEN4) || + (id == BLIS_ARCH_ZEN3) || + (id == BLIS_ARCH_ZEN2) || + (id == BLIS_ARCH_ZEN); if (bamdzen) { diff --git a/frame/compat/bla_dot.c b/frame/compat/bla_dot.c index 2a0f815217..419f8c7dce 100644 --- a/frame/compat/bla_dot.c +++ b/frame/compat/bla_dot.c @@ -159,7 +159,10 @@ float sdot_ // Invoke architecture specific kernels only if we are sure that we are running on zen, // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + bool bamdzen = (id == BLIS_ARCH_ZEN4) || + (id == BLIS_ARCH_ZEN3) || + (id == BLIS_ARCH_ZEN2) || + (id == BLIS_ARCH_ZEN); if (bamdzen) { @@ -177,10 +180,10 @@ float sdot_ } else { - /* Call BLIS interface. */ - PASTEMAC2(s,dotv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, + /* Call BLIS interface. */ + PASTEMAC2(s,dotv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, BLIS_NO_CONJUGATE, n0, x0, incx0, @@ -265,7 +268,10 @@ double ddot_ // Invoke architecture specific kernels only if we are sure that we are running on zen, // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + bool bamdzen = (id == BLIS_ARCH_ZEN4) || + (id == BLIS_ARCH_ZEN3) || + (id == BLIS_ARCH_ZEN2) || + (id == BLIS_ARCH_ZEN); if (bamdzen) { @@ -283,10 +289,10 @@ double ddot_ } else { - /* Call BLIS interface. */ - PASTEMAC2(d,dotv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, + /* Call BLIS interface. */ + PASTEMAC2(d,dotv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, BLIS_NO_CONJUGATE, n0, x0, incx0, @@ -377,7 +383,10 @@ scomplex cdotu_ // Invoke architecture specific kernels only if we are sure that we are running on zen, // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + bool bamdzen = (id == BLIS_ARCH_ZEN4) || + (id == BLIS_ARCH_ZEN3) || + (id == BLIS_ARCH_ZEN2) || + (id == BLIS_ARCH_ZEN); if (bamdzen) { @@ -395,10 +404,10 @@ scomplex cdotu_ } else { - /* Call BLIS interface. 
*/ - PASTEMAC2(c,dotv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, + /* Call BLIS interface. */ + PASTEMAC2(c,dotv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, BLIS_NO_CONJUGATE, n0, x0, incx0, @@ -484,7 +493,10 @@ dcomplex zdotu_ // Invoke architecture specific kernels only if we are sure that we are running on zen, // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + bool bamdzen = (id == BLIS_ARCH_ZEN4) || + (id == BLIS_ARCH_ZEN3) || + (id == BLIS_ARCH_ZEN2) || + (id == BLIS_ARCH_ZEN); if (bamdzen) { @@ -502,10 +514,10 @@ dcomplex zdotu_ } else { - /* Call BLIS interface. */ - PASTEMAC2(z,dotv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, + /* Call BLIS interface. */ + PASTEMAC2(z,dotv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, BLIS_NO_CONJUGATE, n0, x0, incx0, @@ -515,7 +527,7 @@ dcomplex zdotu_ NULL ); } - + /* Finalize BLIS. */ // bli_finalize_auto(); @@ -594,7 +606,10 @@ scomplex cdotc_ // Invoke architecture specific kernels only if we are sure that we are running on zen, // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + bool bamdzen = (id == BLIS_ARCH_ZEN4) || + (id == BLIS_ARCH_ZEN3) || + (id == BLIS_ARCH_ZEN2) || + (id == BLIS_ARCH_ZEN); if (bamdzen) { @@ -612,10 +627,10 @@ scomplex cdotc_ } else { - /* Call BLIS interface. */ - PASTEMAC2(c,dotv,BLIS_TAPI_EX_SUF) - ( - BLIS_CONJUGATE, + /* Call BLIS interface. */ + PASTEMAC2(c,dotv,BLIS_TAPI_EX_SUF) + ( + BLIS_CONJUGATE, BLIS_NO_CONJUGATE, n0, x0, incx0, @@ -701,7 +716,10 @@ dcomplex zdotc_ // Invoke architecture specific kernels only if we are sure that we are running on zen, // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + bool bamdzen = (id == BLIS_ARCH_ZEN4) || + (id == BLIS_ARCH_ZEN3) || + (id == BLIS_ARCH_ZEN2) || + (id == BLIS_ARCH_ZEN); if (bamdzen) { @@ -719,10 +737,10 @@ dcomplex zdotc_ } else { - /* Call BLIS interface. */ - PASTEMAC2(z,dotv,BLIS_TAPI_EX_SUF) - ( - BLIS_CONJUGATE, + /* Call BLIS interface. */ + PASTEMAC2(z,dotv,BLIS_TAPI_EX_SUF) + ( + BLIS_CONJUGATE, BLIS_NO_CONJUGATE, n0, x0, incx0, @@ -733,9 +751,9 @@ dcomplex zdotc_ ); } - - + + /* Finalize BLIS. */ // bli_finalize_auto(); diff --git a/frame/compat/bla_gemm.c b/frame/compat/bla_gemm.c index 50aa931a82..80ad197c68 100644 --- a/frame/compat/bla_gemm.c +++ b/frame/compat/bla_gemm.c @@ -367,7 +367,10 @@ void dgemm_ // Invoke architecture specific kernels only if we are sure that we are running on zen, // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). 
arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + bool bamdzen = (id == BLIS_ARCH_ZEN4) || + (id == BLIS_ARCH_ZEN3) || + (id == BLIS_ARCH_ZEN2) || + (id == BLIS_ARCH_ZEN); if (!bamdzen) { diff --git a/frame/compat/bla_gemv.c b/frame/compat/bla_gemv.c index e9b210bbc1..af2745ca98 100644 --- a/frame/compat/bla_gemv.c +++ b/frame/compat/bla_gemv.c @@ -274,7 +274,10 @@ void dgemv_ // Invoke architecture specific kernels only if we are sure that we are running on zen, // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + bool bamdzen = (id == BLIS_ARCH_ZEN4) || + (id == BLIS_ARCH_ZEN3) || + (id == BLIS_ARCH_ZEN2) || + (id == BLIS_ARCH_ZEN); if (bamdzen == 0) { @@ -462,7 +465,10 @@ void sgemv_ // Invoke architecture specific kernels only if we are sure that we are running on zen, // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + bool bamdzen = (id == BLIS_ARCH_ZEN4) || + (id == BLIS_ARCH_ZEN3) || + (id == BLIS_ARCH_ZEN2) || + (id == BLIS_ARCH_ZEN); if (bamdzen == 0) { @@ -639,7 +645,10 @@ void cgemv_ // Invoke architecture specific kernels only if we are sure that we are running on zen, // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + bool bamdzen = (id == BLIS_ARCH_ZEN4) || + (id == BLIS_ARCH_ZEN3) || + (id == BLIS_ARCH_ZEN2) || + (id == BLIS_ARCH_ZEN); if( m_y == 1 ) { @@ -660,10 +669,10 @@ void cgemv_ } else { - /* Call BLIS interface. */ - PASTEMAC2(c,dotv,BLIS_TAPI_EX_SUF) - ( - conja, + /* Call BLIS interface. */ + PASTEMAC2(c,dotv,BLIS_TAPI_EX_SUF) + ( + conja, BLIS_NO_CONJUGATE, n_x, (scomplex*)a, bli_is_notrans(blis_transa)?cs_a:rs_a, @@ -869,7 +878,10 @@ void zgemv_ // Invoke architecture specific kernels only if we are sure that we are running on zen, // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + bool bamdzen = (id == BLIS_ARCH_ZEN4) || + (id == BLIS_ARCH_ZEN3) || + (id == BLIS_ARCH_ZEN2) || + (id == BLIS_ARCH_ZEN); if( m_y == 1 ) { @@ -891,10 +903,10 @@ void zgemv_ } else { - /* Call BLIS interface. */ - PASTEMAC2(z,dotv,BLIS_TAPI_EX_SUF) - ( - conja, + /* Call BLIS interface. 
*/ + PASTEMAC2(z,dotv,BLIS_TAPI_EX_SUF) + ( + conja, BLIS_NO_CONJUGATE, n_x, (dcomplex*)a, bli_is_notrans(blis_transa)?cs_a:rs_a, diff --git a/frame/compat/bla_scal.c b/frame/compat/bla_scal.c index 30fd857bc7..ab63a34592 100644 --- a/frame/compat/bla_scal.c +++ b/frame/compat/bla_scal.c @@ -147,7 +147,11 @@ void sscal_ } /* Call BLIS kernel */ arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + bool bamdzen = (id == BLIS_ARCH_ZEN4) || + (id == BLIS_ARCH_ZEN3) || + (id == BLIS_ARCH_ZEN2) || + (id == BLIS_ARCH_ZEN); + if (bamdzen) { bli_sscalv_zen_int10 ( @@ -228,7 +232,11 @@ void dscal_ } /* Call BLIS kernel */ arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + bool bamdzen = (id == BLIS_ARCH_ZEN4) || + (id == BLIS_ARCH_ZEN3) || + (id == BLIS_ARCH_ZEN2) || + (id == BLIS_ARCH_ZEN); + if (bamdzen){ bli_dscalv_zen_int10 ( diff --git a/frame/compat/bla_swap.c b/frame/compat/bla_swap.c index 6ecb360f95..526414f332 100644 --- a/frame/compat/bla_swap.c +++ b/frame/compat/bla_swap.c @@ -147,7 +147,11 @@ void sswap_ } arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + bool bamdzen = (id == BLIS_ARCH_ZEN4) || + (id == BLIS_ARCH_ZEN3) || + (id == BLIS_ARCH_ZEN2) || + (id == BLIS_ARCH_ZEN); + if (bamdzen) { /* Call BLIS kernel */ bli_sswapv_zen_int8 @@ -238,7 +242,11 @@ void dswap_ /* Call BLIS kernel */ arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + bool bamdzen = (id == BLIS_ARCH_ZEN4) || + (id == BLIS_ARCH_ZEN3) || + (id == BLIS_ARCH_ZEN2) || + (id == BLIS_ARCH_ZEN); + if (bamdzen) { bli_dswapv_zen_int8 ( diff --git a/frame/compat/bla_trsm.c b/frame/compat/bla_trsm.c index 654d3530d2..fa8f0dacd1 100644 --- a/frame/compat/bla_trsm.c +++ b/frame/compat/bla_trsm.c @@ -596,7 +596,11 @@ void strsm_ bli_obj_set_struc( struca, &ao ); arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + bool bamdzen = (id == BLIS_ARCH_ZEN4) || + (id == BLIS_ARCH_ZEN3) || + (id == BLIS_ARCH_ZEN2) || + (id == BLIS_ARCH_ZEN); + if (bamdzen) { #ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM /* bli_strsm_small is performing better existing native @@ -857,7 +861,11 @@ void dtrsm_ bli_obj_set_struc( struca, &ao ); arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + bool bamdzen = (id == BLIS_ARCH_ZEN4) || + (id == BLIS_ARCH_ZEN3) || + (id == BLIS_ARCH_ZEN2) || + (id == BLIS_ARCH_ZEN); + if (bamdzen) { #ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM /* bli_dtrsm_small is performing better existing native diff --git a/frame/include/bli_arch_config.h b/frame/include/bli_arch_config.h index a62128dffe..3e2e0b022b 100644 --- a/frame/include/bli_arch_config.h +++ b/frame/include/bli_arch_config.h @@ -62,6 +62,9 @@ CNTX_INIT_PROTS( penryn ) #endif // -- AMD64 architectures -- +#ifdef BLIS_CONFIG_ZEN4 +CNTX_INIT_PROTS( zen4 ) +#endif #ifdef BLIS_CONFIG_ZEN3 CNTX_INIT_PROTS( zen3 ) #endif diff --git a/frame/include/bli_type_defs.h b/frame/include/bli_type_defs.h index 770f5c5378..c0e0095061 100644 --- a/frame/include/bli_type_defs.h +++ b/frame/include/bli_type_defs.h @@ -384,7 +384,7 @@ typedef void* void_fp; #define BLIS_BITVAL_SINGLE_PREC 0x0 #define BLIS_BITVAL_DOUBLE_PREC BLIS_PRECISION_BIT #define 
BLIS_BITVAL_FLOAT_TYPE 0x0 -#define BLIS_BITVAL_SCOMPLEX_TYPE BLIS_DOMAIN_BIT +#define BLIS_BITVAL_SCOMPLEX_TYPE BLIS_DOMAIN_BIT #define BLIS_BITVAL_DOUBLE_TYPE BLIS_PRECISION_BIT #define BLIS_BITVAL_DCOMPLEX_TYPE ( BLIS_DOMAIN_BIT | BLIS_PRECISION_BIT ) #define BLIS_BITVAL_INT_TYPE 0x04 @@ -394,10 +394,10 @@ typedef void* void_fp; #define BLIS_BITVAL_NO_CONJ 0x0 #define BLIS_BITVAL_CONJ BLIS_CONJ_BIT #define BLIS_BITVAL_CONJ_TRANS ( BLIS_CONJ_BIT | BLIS_TRANS_BIT ) -#define BLIS_BITVAL_ZEROS 0x0 +#define BLIS_BITVAL_ZEROS 0x0 #define BLIS_BITVAL_UPPER ( BLIS_UPPER_BIT | BLIS_DIAG_BIT ) #define BLIS_BITVAL_LOWER ( BLIS_LOWER_BIT | BLIS_DIAG_BIT ) -#define BLIS_BITVAL_DENSE BLIS_UPLO_BITS +#define BLIS_BITVAL_DENSE BLIS_UPLO_BITS #define BLIS_BITVAL_NONUNIT_DIAG 0x0 #define BLIS_BITVAL_UNIT_DIAG BLIS_UNIT_DIAG_BIT #define BLIS_BITVAL_INVERT_DIAG BLIS_INVERT_DIAG_BIT @@ -999,6 +999,7 @@ typedef enum BLIS_ARCH_PENRYN, // AMD + BLIS_ARCH_ZEN4, BLIS_ARCH_ZEN3, BLIS_ARCH_ZEN2, BLIS_ARCH_ZEN, @@ -1541,13 +1542,13 @@ typedef enum BLIS_INVALID_COL_STRIDE = ( -51), BLIS_INVALID_DIM_STRIDE_COMBINATION = ( -52), - // Structure-specific errors + // Structure-specific errors BLIS_EXPECTED_GENERAL_OBJECT = ( -60), BLIS_EXPECTED_HERMITIAN_OBJECT = ( -61), BLIS_EXPECTED_SYMMETRIC_OBJECT = ( -62), BLIS_EXPECTED_TRIANGULAR_OBJECT = ( -63), - // Storage-specific errors + // Storage-specific errors BLIS_EXPECTED_UPPER_OR_LOWER_OBJECT = ( -70), // Partitioning-specific errors @@ -1561,7 +1562,7 @@ typedef enum // Packing-specific errors BLIS_PACK_SCHEMA_NOT_SUPPORTED_FOR_UNPACK = (-100), - // Buffer-specific errors + // Buffer-specific errors BLIS_EXPECTED_NONNULL_OBJECT_BUFFER = (-110), // Memory errors diff --git a/kernels/zen/3/bli_gemm_small.c b/kernels/zen/3/bli_gemm_small.c index d9c4047ec4..bf6c9c29cd 100644 --- a/kernels/zen/3/bli_gemm_small.c +++ b/kernels/zen/3/bli_gemm_small.c @@ -47,7 +47,7 @@ #define D_BLIS_SMALL_MATRIX_THRES (BLIS_SMALL_MATRIX_THRES / 2 ) #define D_BLIS_SMALL_M_RECT_MATRIX_THRES (BLIS_SMALL_M_RECT_MATRIX_THRES / 2) #define D_BLIS_SMALL_K_RECT_MATRIX_THRES (BLIS_SMALL_K_RECT_MATRIX_THRES / 2) -#define BLIS_ATBN_M_THRES 40 // Threshold value of M for/below which small matrix code is called. +#define BLIS_ATBN_M_THRES 40 // Threshold value of M for/below which small matrix code is called. #define AT_MR 4 // The kernel dimension of the A transpose GEMM kernel.(AT_MR * NR). static err_t bli_sgemm_small ( @@ -109,7 +109,7 @@ err_t bli_gemm_small ) { AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); - + #ifdef BLIS_ENABLE_MULTITHREADING AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); return BLIS_NOT_YET_IMPLEMENTED; @@ -118,7 +118,10 @@ err_t bli_gemm_small // Invoke architecture specific kernels only if we are sure that we are running on zen, // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). 
arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); + bool bamdzen = (id == BLIS_ARCH_ZEN4) || + (id == BLIS_ARCH_ZEN3) || + (id == BLIS_ARCH_ZEN2) || + (id == BLIS_ARCH_ZEN); if (0 == bamdzen) { @@ -262,7 +265,7 @@ static err_t bli_sgemm_small const num_t dt_exec = bli_obj_dt( c ); float* restrict alpha_cast = bli_obj_buffer_for_1x1( dt_exec, alpha ); - float* restrict beta_cast = bli_obj_buffer_for_1x1( dt_exec, beta ); + float* restrict beta_cast = bli_obj_buffer_for_1x1( dt_exec, beta ); /*Beta Zero Check*/ bool is_beta_non_zero=0; @@ -299,11 +302,11 @@ static err_t bli_sgemm_small bli_membrk_rntm_set_membrk( &rntm ); // Get the current size of the buffer pool for A block packing. - // We will use the same size to avoid pool re-initialization + // We will use the same size to avoid pool re-initialization siz_t buffer_size = bli_pool_block_size(bli_membrk_pool(bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), bli_rntm_membrk(&rntm))); - // Based on the available memory in the buffer we will decide if + // Based on the available memory in the buffer we will decide if // we want to do packing or not. // // This kernel assumes that "A" will be un-packged if N <= 3. @@ -315,18 +318,18 @@ static err_t bli_sgemm_small // If this check is removed it will result in the crash as // reported in CPUPL-587. // - + if ((N <= 3) || (((MR * K) << 2) > buffer_size)) { required_packing_A = 0; } - else + else { #ifdef BLIS_ENABLE_MEM_TRACING printf( "bli_sgemm_small: Requesting mem pool block of size %lu\n", buffer_size); #endif // Get the buffer from the pool, if there is no pool with - // required size, it will be created. + // required size, it will be created. bli_membrk_acquire_m(&rntm, buffer_size, BLIS_BITVAL_BUFFER_FOR_A_BLOCK, @@ -1730,7 +1733,7 @@ static err_t bli_sgemm_small bli_membrk_release(&rntm, &local_mem_buf_A_s); } - + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); return BLIS_SUCCESS; } @@ -1757,7 +1760,7 @@ static err_t bli_sgemm_small ) { AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO); - + gint_t M = bli_obj_length( c ); // number of rows of Matrix C gint_t N = bli_obj_width( c ); // number of columns of Matrix C gint_t K = bli_obj_width( a ); // number of columns of OP(A), will be updated if OP(A) is Transpose(A) . @@ -1771,7 +1774,7 @@ static err_t bli_sgemm_small /* ); */ /* return BLIS_NOT_YET_IMPLEMENTED; VK */ /* } */ - + if(L && K ) // Non-zero dimensions will be handled by either sup or native kernels { @@ -1844,7 +1847,7 @@ static err_t bli_sgemm_small bli_membrk_rntm_set_membrk( &rntm ); // Get the current size of the buffer pool for A block packing. 
- // We will use the same size to avoid pool re-initliazaton + // We will use the same size to avoid pool re-initliazaton siz_t buffer_size = bli_pool_block_size( bli_membrk_pool(bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), bli_rntm_membrk(&rntm))); @@ -1865,7 +1868,7 @@ static err_t bli_sgemm_small { required_packing_A = 0; } - + if (required_packing_A == 1) { #ifdef BLIS_ENABLE_MEM_TRACING @@ -3345,7 +3348,7 @@ static err_t bli_sgemm_small_atbn ) { AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO); - + gint_t M = bli_obj_length( c ); // number of rows of Matrix C gint_t N = bli_obj_width( c ); // number of columns of Matrix C gint_t K = bli_obj_length( b ); // number of rows of Matrix B @@ -3371,7 +3374,7 @@ static err_t bli_sgemm_small_atbn float scratch[8] = {0.0}; const num_t dt_exec = bli_obj_dt( c ); float* restrict alpha_cast = bli_obj_buffer_for_1x1( dt_exec, alpha ); - float* restrict beta_cast = bli_obj_buffer_for_1x1( dt_exec, beta ); + float* restrict beta_cast = bli_obj_buffer_for_1x1( dt_exec, beta ); /*Beta Zero Check*/ bool is_beta_non_zero=0; @@ -3822,7 +3825,7 @@ static err_t bli_dgemm_small_atbn ) { AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO); - + gint_t M = bli_obj_length( c ); // number of rows of Matrix C gint_t N = bli_obj_width( c ); // number of columns of Matrix C gint_t K = bli_obj_length( b ); // number of rows of Matrix B @@ -4358,7 +4361,7 @@ err_t bli_dgemm_small_At bli_membrk_rntm_set_membrk( &rntm ); // Get the current size of the buffer pool for A block packing. - // We will use the same size to avoid pool re-initliazaton + // We will use the same size to avoid pool re-initliazaton siz_t buffer_size = bli_pool_block_size( bli_membrk_pool(bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), bli_rntm_membrk(&rntm))); diff --git a/kernels/zen4/README b/kernels/zen4/README new file mode 100644 index 0000000000..c9e16c2735 --- /dev/null +++ b/kernels/zen4/README @@ -0,0 +1 @@ +Currently there are no zen4 specific kernels, however, this folder is required for the the build system. diff --git a/ref_kernels/CMakeLists.txt b/ref_kernels/CMakeLists.txt index 61357c1fec..d26bce06a5 100644 --- a/ref_kernels/CMakeLists.txt +++ b/ref_kernels/CMakeLists.txt @@ -5,6 +5,7 @@ add_subdirectory(${CMAKE_BINARY_DIR}/ref_kernels/generic ${CMAKE_BINARY_DIR}/ref add_subdirectory(${CMAKE_BINARY_DIR}/ref_kernels/zen ${CMAKE_BINARY_DIR}/ref_kernels/zen) add_subdirectory(${CMAKE_BINARY_DIR}/ref_kernels/zen2 ${CMAKE_BINARY_DIR}/ref_kernels/zen2) add_subdirectory(${CMAKE_BINARY_DIR}/ref_kernels/zen3 ${CMAKE_BINARY_DIR}/ref_kernels/zen3) +add_subdirectory(${CMAKE_BINARY_DIR}/ref_kernels/zen4 ${CMAKE_BINARY_DIR}/ref_kernels/zen4) else() target_sources("${PROJECT_NAME}" PRIVATE From 0f43db8347383533bac1e02d2d23613e73f7202d Mon Sep 17 00:00:00 2001 From: Harsh Dave Date: Tue, 23 Nov 2021 08:33:27 -0600 Subject: [PATCH 053/243] Optimized dsymv implementation -Implemented hemv framework calls for lower and upper kernel variants. -hemv computation is implemented in two parts. One part operate on triangular part of matrix and the remaining part is computed by dotxfaxpyf kernel. -First part performs dotxf and axpyf operation on triangular part of matrix in chunk of 8x8. Two separate helper function for doing so are implemented for lower and upper kernels respectively. -Second part is ddotxaxpyf fused kernel, which performs dotxf and axpyf operation alltogether on non-triangular part of matrix in chunk of 4x8. -Implementation efficiently uses cache memory while computing for optimal performance. 
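For reference, the per-block computation described above amounts to the
following scalar schematic for the lower-triangular, unit-stride case. This is
a minimal sketch only: the function name is illustrative and does not exist in
BLIS. In the patch itself, bli_pre_hemv_8x8/bli_post_hemv_8x8 cover this 8x8
diagonal block with AVX2 intrinsics, while bli_ddotxaxpyf_zen_int_8 covers the
off-diagonal panels in 4x8 chunks.

/* One 8x8 diagonal step of y += alpha*A*x for a symmetric matrix whose
   lower triangle is stored column-major with column stride lda. */
static void ref_dsymv_diag_block( const double* a, dim_t lda,
                                  const double* x, double alpha,
                                  double* y, dim_t nb )
{
    for ( dim_t j = 0; j < nb; ++j )
    {
        /* The diagonal element contributes once. */
        y[j] += alpha * a[j + j*lda] * x[j];

        /* Each strictly-lower entry A(i,j) is used twice: directly, and as
           A(j,i) by symmetry. This is the dotxf + axpyf pairing the
           description above refers to. */
        for ( dim_t i = j + 1; i < nb; ++i )
        {
            y[i] += alpha * a[i + j*lda] * x[j];
            y[j] += alpha * a[i + j*lda] * x[i];
        }
    }
}
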
Change-Id: Id603031b4578e87a92c6b77f710c647acc195c8e --- frame/2/hemv/bli_hemv_unf_var1.c | 197 +++++++ frame/2/hemv/bli_hemv_unf_var3.c | 195 +++++++ kernels/zen/1f/bli_dotxaxpyf_int_8.c | 735 +++++++++++++++++++++++++++ kernels/zen/bli_kernels_zen.h | 11 + 4 files changed, 1138 insertions(+) create mode 100644 kernels/zen/1f/bli_dotxaxpyf_int_8.c diff --git a/frame/2/hemv/bli_hemv_unf_var1.c b/frame/2/hemv/bli_hemv_unf_var1.c index d36dc00988..ccb39b3485 100644 --- a/frame/2/hemv/bli_hemv_unf_var1.c +++ b/frame/2/hemv/bli_hemv_unf_var1.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -215,5 +216,201 @@ void PASTEMAC(ch,varname) \ } \ } +#ifdef BLIS_CONFIG_EPYC + +void post_hemv_8x8(double *a, double *x, + double *y, double *alpha, + dim_t cs_a, dim_t rs_a); + +void bli_dhemv_unf_var1 + ( + uplo_t uplo, + conj_t conja, + conj_t conjx, + conj_t conjh, + dim_t m, + double* alpha, + double* a, inc_t rs_a, inc_t cs_a, + double* x, inc_t incx, + double* beta, + double* y, inc_t incy, + cntx_t* cntx + ) +{ + const num_t dt = PASTEMAC(d,type); + + double* one = PASTEMAC(d,1); + double* zero = PASTEMAC(d,0); + double* A10; + double* A11; + double* a10t; + double* alpha11; + double* a21; + double* x0; + double* x1; + double* chi11; + double* y0; + double* y1; + double* y01; + double* psi11; + double* y21; + double conjx_chi11; + double alpha_chi11; + double alpha11_temp; + dim_t i, k, j; + dim_t b_fuse, f; + dim_t n_behind; + dim_t f_ahead, f_behind; + inc_t rs_at, cs_at; + conj_t conj0 = 0, conj1 = 0; + + /* The algorithm will be expressed in terms of the lower triangular + * case;the upper triangular case is supported by swapping the row + * and column strides of A and toggling some conj parameters. */ + if ( bli_is_lower( uplo ) ) + { + rs_at = rs_a; + cs_at = cs_a; + } + else /* if ( bli_is_upper( uplo ) ) */ + { + rs_at = cs_a; + cs_at = rs_a; + } + + /* If beta is zero, use setv. Otherwise, scale by beta. */ + if ( PASTEMAC(d,eq0)( *beta ) ) + { + /* y = 0; */ + PASTEMAC2(d,setv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + m, + zero, + y, incy, + cntx, + NULL + ); + } + else + { + /* y = beta * y; */ + PASTEMAC2(d,scalv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + m, + beta, + y, incy, + cntx, + NULL + ); + } + + PASTECH(d,dotxaxpyf_ker_ft) kfp_dotxaxpyf_ker; + + /* Query the context for the kernel function pointer and fusing + * factor. */ + /* Assign kernel function pointer and fusing factor. 
*/ + arch_t id = bli_arch_query_id(); + bool bamdzen = ((id == BLIS_ARCH_ZEN4) ||(id == BLIS_ARCH_ZEN3) + || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN)); + if (bamdzen) + { + kfp_dotxaxpyf_ker = bli_ddotxaxpyf_zen_int_8; + b_fuse = 8; + } + else + { + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); + kfp_dotxaxpyf_ker = + bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXAXPYF_KER, cntx); + b_fuse = + bli_cntx_get_blksz_def_dt( dt, BLIS_XF, cntx ); + } + + for ( i = 0; i < m; i += f ) + { + f = bli_determine_blocksize_dim_f( i, m, b_fuse ); + n_behind = i; + A10 = a + (i )*rs_at + (0 )*cs_at; + A11 = a + (i )*rs_at + (i )*cs_at; + x0 = x + (0 )*incx; + x1 = x + (i )*incx; + y0 = y + (0 )*incy; + y1 = y + (i )*incy; + + /* y1 = y1 + alpha * A10 * x0; (dotxf) */ + /* y0 = y0 + alpha * A10' * x1; (axpyf) */ + kfp_dotxaxpyf_ker + ( + conj0, + conj1, + conjx, + conjx, + n_behind, + f, + alpha, + A10, cs_at, rs_at, + x0, incx, + x1, incx, + one, + y1, incy, + y0, incy, + cntx + ); + + /* y1 = y1 + alpha * A11 * x1; (variant 4) */ + if((f == 8) && (incx == 1) && (incy == 1) && (cs_at == 1)) + { + /*this helper function handles unit stride only*/ + bli_post_hemv_8x8(A11, x1, y1, alpha, rs_at, cs_at); + } + else + { + for ( k = 0; k < f; ++k ) + { + f_behind = k; + f_ahead = f - k - 1; + a10t = A11 + (k )*rs_at + (0 )*cs_at; + alpha11 = A11 + (k )*rs_at + (k )*cs_at; + a21 = A11 + (k+1)*rs_at + (k )*cs_at; + chi11 = x1 + (k )*incx; + y01 = y1 + (0 )*incy; + psi11 = y1 + (k )*incy; + y21 = y1 + (k+1)*incy; + + /* y01 = y01 + alpha * a10t' * chi11; */ + PASTEMAC(d,copycjs)( conjx, *chi11, + conjx_chi11 ); + PASTEMAC(d,scal2s)( *alpha, conjx_chi11, + alpha_chi11 ); + for ( j = 0; j < f_behind; ++j ) + PASTEMAC(d,axpys)( alpha_chi11, + *(a10t + j*cs_at), + *(y01 + j*incy) ); + + PASTEMAC(d,copycjs)( conja, *alpha11, + alpha11_temp ); + + /* psi11 = psi11 + alpha * alpha11 * chi11; */ + PASTEMAC(d,axpys)( alpha_chi11, alpha11_temp, + *psi11 ); + + /* y21 = y21 + alpha * a21 * chi11; */ + for ( j = 0; j < f_ahead; ++j ) + { + PASTEMAC(d,axpys)( alpha_chi11, + *(a21 + j*rs_at), + *(y21 + j*incy) ); + } + } + } + } +} +GENTFUNC(float, s, hemv_unf_var1) +GENTFUNC(scomplex, c, hemv_unf_var1) +GENTFUNC(dcomplex, z, hemv_unf_var1) +#else INSERT_GENTFUNC_BASIC0( hemv_unf_var1 ) +#endif diff --git a/frame/2/hemv/bli_hemv_unf_var3.c b/frame/2/hemv/bli_hemv_unf_var3.c index d8db9bc78a..6ed18efea4 100644 --- a/frame/2/hemv/bli_hemv_unf_var3.c +++ b/frame/2/hemv/bli_hemv_unf_var3.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -215,5 +216,199 @@ void PASTEMAC(ch,varname) \ } \ } +#ifdef BLIS_CONFIG_EPYC +void bli_dhemv_unf_var3 + ( + uplo_t uplo, + conj_t conja, + conj_t conjx, + conj_t conjh, + dim_t m, + double* alpha, + double* a, inc_t rs_a, inc_t cs_a, + double* x, inc_t incx, + double* beta, + double* y, inc_t incy, + cntx_t* cntx + ) +{ + const num_t dt = PASTEMAC(d,type); + + double* one = PASTEMAC(d,1); + double* zero = PASTEMAC(d,0); + double* A11; + double* A21; + double* a10t; + double* alpha11; + double* a21; + double* x1; + double* x2; + double* chi11; + double* y1; + double* y2; + double* y01; + double* psi11; + double* y21; + double conjx_chi11; + double alpha_chi11; + double alpha11_temp; + dim_t i, k, j; + dim_t b_fuse, f; + dim_t n_ahead; + dim_t f_ahead, f_behind; + inc_t rs_at, cs_at; + conj_t conj0 = 0, conj1 = 0; + + /* The algorithm will be expressed in terms of the lower triangular + * case; the upper triangular case is supported by swapping the row + * and column strides of A and toggling some conj parameters. */ + if ( bli_is_lower( uplo ) ) + { + rs_at = rs_a; + cs_at = cs_a; + } + else /* if ( bli_is_upper( uplo ) ) */ + { + rs_at = cs_a; + cs_at = rs_a; + } + + /* If beta is zero, use setv. Otherwise, scale by beta. */ + if ( PASTEMAC(d,eq0)( *beta ) ) + { + /* y = 0; */ + PASTEMAC2(d,setv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + m, + zero, + y, incy, + cntx, + NULL + ); + } + else + { + /* y = beta * y; */ + PASTEMAC2(d,scalv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + m, + beta, + y, incy, + cntx, + NULL + ); + } + + PASTECH(d,dotxaxpyf_ker_ft) kfp_dotxaxpyf_ker; + + arch_t id = bli_arch_query_id(); + bool bamdzen = ((id == BLIS_ARCH_ZEN4) || (id == BLIS_ARCH_ZEN3) + || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN)); + if (bamdzen) + { + kfp_dotxaxpyf_ker = bli_ddotxaxpyf_zen_int_8; + b_fuse = 8; + } + else + { + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); + kfp_dotxaxpyf_ker = + bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXAXPYF_KER, cntx); + b_fuse = + bli_cntx_get_blksz_def_dt( dt, BLIS_XF, cntx ); + } + + for ( i = 0; i < m; i += f ) + { + f = bli_determine_blocksize_dim_f( i, m, b_fuse ); + n_ahead = m - i - f; + A11 = a + (i )*rs_at + (i )*cs_at; + A21 = a + (i+f)*rs_at + (i )*cs_at; + x1 = x + (i )*incx; + x2 = x + (i+f)*incx; + y1 = y + (i )*incy; + y2 = y + (i+f)*incy; + + /* y1 = y1 + alpha * A11 * x1; (variant 4) */ + if((f == 8) && (incx == 1) && (incy == 1) && (rs_at == 1)) + { + /*this helper function handles unit stride only*/ + bli_pre_hemv_8x8(A11, x1, y1, alpha, cs_at, rs_at); + } + else + { + for ( k = 0; k < f; ++k ) + { + f_behind = k; + f_ahead = f - k - 1; + a10t = A11 + (k )*rs_at + (0 )*cs_at; + alpha11 = A11 + (k )*rs_at + (k )*cs_at; + a21 = A11 + (k+1)*rs_at + (k )*cs_at; + chi11 = x1 + (k )*incx; + y01 = y1 + (0 )*incy; + psi11 = y1 + (k )*incy; + y21 = y1 + (k+1)*incy; + + /* y01 = y01 + alpha * a10t' * chi11; */ + PASTEMAC(d,copycjs)( conjx, + *chi11, conjx_chi11 ); + PASTEMAC(d,scal2s)( *alpha, conjx_chi11, + alpha_chi11 ); + { + for ( j = 0; j < f_behind; ++j ) + { + PASTEMAC(d,axpys) + ( alpha_chi11, + *(a10t + j*cs_at), + *(y01 + j*incy) ); + } + } + + PASTEMAC(d,copycjs)( conja, *alpha11, + alpha11_temp ); + + /* psi11 = psi11 + alpha * alpha11 * chi11; */ + PASTEMAC(d,axpys)( alpha_chi11, alpha11_temp, + *psi11 ); + + /* y21 = y21 + alpha * a21 * chi11; */ + for ( j = 0; j < f_ahead; ++j ) + { + 
PASTEMAC(d,axpys)( alpha_chi11, + *(a21 + j*rs_at), + *(y21 + j*incy) ); + } + } + } + + /* y1 = y1 + alpha * A21' * x2; (dotxf) */ + /* y2 = y2 + alpha * A21 * x1; (axpyf) */ + kfp_dotxaxpyf_ker + ( + conj0, + conj1, + conjx, + conjx, + n_ahead, + f, + alpha, + A21, rs_at, cs_at, + x2, incx, + x1, incx, + one, + y1, incy, + y2, incy, + cntx + ); + } +} + +GENTFUNC(float, s, hemv_unf_var3) +GENTFUNC(scomplex, c, hemv_unf_var3) +GENTFUNC(dcomplex, z, hemv_unf_var3) +#else INSERT_GENTFUNC_BASIC0( hemv_unf_var3 ) +#endif diff --git a/kernels/zen/1f/bli_dotxaxpyf_int_8.c b/kernels/zen/1f/bli_dotxaxpyf_int_8.c new file mode 100644 index 0000000000..b24aab7571 --- /dev/null +++ b/kernels/zen/1f/bli_dotxaxpyf_int_8.c @@ -0,0 +1,735 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "immintrin.h" + +typedef union{ + __m256d v; + double d[4] __attribute__((aligned(64))); +}vec; + +/** + * bli_pre_hemv_lower_8x8 is a helper function which computes + * "y = y + alpha * a * x" + * dotxf and axpyf of triangular matrix with vector + * for lower triangular matrix cases. + * Computes 8 elements of Y vector by dot product + * of 8 elements of x vector with 8x8 tile of A matrix + * and axpy computation of each x vector elements with + * each column of 8x8 A matrix tile. + +*/ +void bli_pre_hemv_8x8(double *a, double *x, double *y, double *alpha, + dim_t cs_a, dim_t rs_a) +{ + __m256d ymm0, ymm1, ymm2, ymm3, ymm4, ymm5, ymm6, ymm7, ymm8, ymm9; + __m256d ymm10, ymm11, ymm12; + double alpha_chi[8] = {0}; + /*Broadcast alpha*/ + ymm9 = _mm256_broadcast_sd(alpha); + + /** + * Scaling vector x with alpha + * to gather alpha_chi elements + * arranged in one buffer. 
+ */ + ymm10 = _mm256_loadu_pd(x); + ymm11 = _mm256_loadu_pd(x + 4); + ymm10 = _mm256_mul_pd(ymm9, ymm10); + ymm11 = _mm256_mul_pd(ymm9, ymm11); + _mm256_storeu_pd(alpha_chi, ymm10); + _mm256_storeu_pd(alpha_chi + 4, ymm11); + + /*Load y vector*/ + ymm10 = _mm256_loadu_pd(y); + ymm11 = _mm256_loadu_pd(y + 4); + + //Col 0 computation + /*Broadcasts chi and multiplies with alpha to get alpha chi*/ + ymm12 = _mm256_broadcast_sd(alpha_chi); + /*Load first column of A matrix*/ + ymm0 = _mm256_loadu_pd(a); + ymm1 = _mm256_loadu_pd(a + 4); + ymm10 = _mm256_fmadd_pd(ymm12, ymm0, ymm10); + ymm11 = _mm256_fmadd_pd(ymm12, ymm1, ymm11); + + //Col 1 computation + ymm12 = _mm256_broadcast_sd(alpha_chi + 1); + /** + * pack the data in following manner into ymm register + * Since it is computing 2nd column, packing to be done + * as shown below for ymm0: + * col-0 col-1 + * --- --- + x x + --- x + --- x + */ + ymm3 = _mm256_broadcast_sd(a + 1); + ymm0 = _mm256_loadu_pd(a + cs_a * 1); + ymm0 = _mm256_blend_pd(ymm0, ymm3, 0x1); + ymm1 = _mm256_loadu_pd(a + 4 + cs_a * 1); + ymm10 = _mm256_fmadd_pd(ymm12, ymm0, ymm10); + ymm11 = _mm256_fmadd_pd(ymm12, ymm1, ymm11); + + //Col 2 computation + ymm12 = _mm256_broadcast_sd(alpha_chi + 2); + /** + * pack the data in following manner into ymm register + * Since it is computing 3rd column, packing to be done + * as shown below for ymm0: + * col-0 col-1 col-2 + * --- --- --- + x x --- + --- --- x + --- --- x + */ + ymm3 = _mm256_broadcast_sd(a + 2); + ymm4 = _mm256_broadcast_sd(a + 2 + cs_a); + ymm0 = _mm256_loadu_pd(a + cs_a * 2); + ymm0 = _mm256_blend_pd(ymm0, ymm3, 0x1); + ymm0 = _mm256_blend_pd(ymm0, ymm4, 0x2); + ymm1 = _mm256_loadu_pd(a + 4 + cs_a * 2); + ymm10 = _mm256_fmadd_pd(ymm12, ymm0, ymm10); + ymm11 = _mm256_fmadd_pd(ymm12, ymm1, ymm11); + + //Col 3 computation + ymm12 = _mm256_broadcast_sd(alpha_chi + 3); + /** + * pack the data in following manner into ymm register + * Since it is computing 4rd column, packing to be done + * as shown below for ymm0: + * col-0 col-1 col-2 col-3 + * --- --- --- --- + x x x --- + --- --- --- x + */ + ymm3 = _mm256_broadcast_sd(a + 3); + ymm4 = _mm256_broadcast_sd(a + 3 + cs_a); + ymm5 = _mm256_broadcast_sd(a + 3 + cs_a * 2); + ymm0 = _mm256_loadu_pd(a + cs_a * 3); + ymm0 = _mm256_blend_pd(ymm0, ymm3, 0x1); + ymm0 = _mm256_blend_pd(ymm0, ymm4, 0x2); + ymm0 = _mm256_blend_pd(ymm0, ymm5, 0x4); + ymm1 = _mm256_loadu_pd(a + 4 + cs_a * 3); + ymm10 = _mm256_fmadd_pd(ymm12, ymm0, ymm10); + ymm11 = _mm256_fmadd_pd(ymm12, ymm1, ymm11); + + /** + * Transpose 4x4 tile of matrix A, + * for remainder column computation. 
+ */ + ymm0 = _mm256_loadu_pd(a+4 + cs_a * 0); + ymm1 = _mm256_loadu_pd(a+4 + cs_a * 1); + ymm2 = _mm256_loadu_pd(a+4 + cs_a * 2); + ymm3 = _mm256_loadu_pd(a+4 + cs_a * 3); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //Transposed col 1 + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //Transposed col 3 + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //Transposed col 2 + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //Transposed col 4 + + //Col 4 computation + ymm12 = _mm256_broadcast_sd(alpha_chi + 4); + /** + * pack the data in following manner into ymm register + * Since it is computing 4rd column, packing to be done + * as shown below for ymm0: + * col-0 col-1 col-2 col-3 col-4 + * --- --- --- --- --- + x x x x --- + --- --- --- --- --- + --- --- --- --- --- + */ + ymm1 = _mm256_loadu_pd(a + 4 + cs_a * 4); + ymm10 = _mm256_fmadd_pd(ymm12, ymm6, ymm10); + ymm11 = _mm256_fmadd_pd(ymm12, ymm1, ymm11); + + //Col 5 computation + /** + * Packs the data in similar manner as shown + * for col 0-4 computation, along with + * packing all 5th elements from col 0 - 4 + * in other ymm register. + * col-4 col-5 + * --- --- + x x + --- x + --- x + + */ + ymm12 = _mm256_broadcast_sd(alpha_chi + 5); + ymm3 = _mm256_broadcast_sd(a + 5 + cs_a * 4); + ymm1 = _mm256_loadu_pd(a + 4 + cs_a * 5); + ymm1 = _mm256_blend_pd(ymm1, ymm3, 0x1); + ymm10 = _mm256_fmadd_pd(ymm12, ymm7, ymm10); + ymm11 = _mm256_fmadd_pd(ymm12, ymm1, ymm11); + + //Col 6 computation + /** + * Packs the data in similar manner as shown + * for col 0-4 computation, along with + * packing all 6th elements from col 0 - 4 + * in other ymm register. + * col-4 col-5 col-6 + * --- --- --- + x x --- + --- --- x + --- --- x + */ + ymm12 = _mm256_broadcast_sd(alpha_chi + 6); + ymm1 = _mm256_loadu_pd(a + 4 + cs_a * 6); + ymm3 = _mm256_broadcast_sd(a + 6 + cs_a * 4); + ymm4 = _mm256_broadcast_sd(a + 6 + cs_a * 5); + ymm1 = _mm256_blend_pd(ymm1, ymm3, 0x1); + ymm1 = _mm256_blend_pd(ymm1, ymm4, 0x2); + ymm10 = _mm256_fmadd_pd(ymm12, ymm8, ymm10); + ymm11 = _mm256_fmadd_pd(ymm12, ymm1, ymm11); + + //Col 7 computation + /** + * Packs the data in similar manner as shown + * for col 0-4 computation, along with + * packing all 7th elements from col 0 - 4 + * in other ymm register. + * col-4 col-5 col-6 col-7 + * --- --- --- --- + x x x --- + --- --- --- x + */ + ymm12 = _mm256_broadcast_sd(alpha_chi + 7); + ymm1 = _mm256_loadu_pd(a + 4 + cs_a * 7); + ymm3 = _mm256_broadcast_sd(a + 7 + cs_a * 4); + ymm4 = _mm256_broadcast_sd(a + 7 + cs_a * 5); + ymm5 = _mm256_broadcast_sd(a + 7 + cs_a * 6); + ymm1 = _mm256_blend_pd(ymm1, ymm3, 0x1); + ymm1 = _mm256_blend_pd(ymm1, ymm4, 0x2); + ymm1 = _mm256_blend_pd(ymm1, ymm5, 0x4); + ymm10 = _mm256_fmadd_pd(ymm12, ymm9, ymm10); + ymm11 = _mm256_fmadd_pd(ymm12, ymm1, ymm11); + + /** + * Computed result of vector y is available in ymm10, ymm11. + * Storing the result back from ymm register into y vector for + * further computaion. + */ + _mm256_storeu_pd(y, ymm10); + _mm256_storeu_pd(y + 4, ymm11); +} + + +/** + * bli_post_hemv_lower_8x8 is a helper function which computes + * "y = y + alpha * a * x" + * dotxf and axpyf of triangular matrix with vector + * for upper triangular matrix cases. 
+ * Computes 8 elements of Y vector by dot product + * of 8 elements of x vector with 8x8 tile of A matrix + * and axpy computation of each x vector elements with + * each column of 8x8 A matrix tile. +*/ +void bli_post_hemv_8x8(double *a, double *x, double *y, double *alpha, + dim_t cs_a, dim_t rs_a) +{ + __m256d ymm0, ymm1, ymm2, ymm3, ymm4, ymm5, ymm6, ymm7, ymm8, ymm9; + __m256d ymm10, ymm11, ymm12; + double alpha_chi[8] = {0}; + + ymm9 = _mm256_broadcast_sd(alpha); + + ymm10 = _mm256_loadu_pd(x); + ymm11 = _mm256_loadu_pd(x + 4); + ymm10 = _mm256_mul_pd(ymm9, ymm10); + ymm11 = _mm256_mul_pd(ymm9, ymm11); + _mm256_storeu_pd(alpha_chi, ymm10); + _mm256_storeu_pd(alpha_chi + 4, ymm11); + + ymm10 = _mm256_loadu_pd(y); + ymm11 = _mm256_loadu_pd(y + 4); + + ymm0 = _mm256_loadu_pd(a + cs_a * 4); + ymm1 = _mm256_loadu_pd(a + cs_a * 5); + ymm2 = _mm256_loadu_pd(a + cs_a * 6); + ymm3 = _mm256_loadu_pd(a + cs_a * 7); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + //Col 0 computation + /** + * pack the data in following manner into ymm register + * Since it is computing 4rd column, packing to be done + * as shown below for ymm0: + * col-0 col-1 col-2 col-3 + * x x x x + --- + --- + --- + */ + ymm12 = _mm256_broadcast_sd(alpha_chi); + ymm0 = _mm256_loadu_pd(a); + ymm1 = _mm256_broadcast_sd(a + cs_a * 1); + ymm2 = _mm256_broadcast_sd(a + cs_a * 2); + ymm3 = _mm256_broadcast_sd(a + cs_a * 3); + ymm0 = _mm256_blend_pd(ymm0, ymm1, 0x2); + ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x4); + ymm0 = _mm256_blend_pd(ymm0, ymm3, 0x8); + ymm10 = _mm256_fmadd_pd(ymm12, ymm0, ymm10); + ymm11 = _mm256_fmadd_pd(ymm12, ymm6, ymm11); + + //Col 1 computation + /** + * pack the data in following manner into ymm register + * Since it is computing 4rd column, packing to be done + * as shown below for ymm0: + * col-1 col-2 col-3 + * x x x + x + --- + --- + */ + ymm12 = _mm256_broadcast_sd(alpha_chi + 1); + ymm0 = _mm256_loadu_pd(a + cs_a * 1); + ymm2 = _mm256_broadcast_sd(a + cs_a * 2 + 1); + ymm3 = _mm256_broadcast_sd(a + cs_a * 3 + 1); + ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x4); + ymm0 = _mm256_blend_pd(ymm0, ymm3, 0x8); + ymm10 = _mm256_fmadd_pd(ymm12, ymm0, ymm10); + ymm11 = _mm256_fmadd_pd(ymm12, ymm7, ymm11); + + //Col 2 computation + /** + * pack the data in following manner into ymm register + * Since it is computing 4rd column, packing to be done + * as shown below for ymm0: + * col-2 col-3 + * x x + x + x + --- + */ + ymm12 = _mm256_broadcast_sd(alpha_chi + 2); + ymm0 = _mm256_loadu_pd(a + cs_a * 2); + ymm2 = _mm256_broadcast_sd(a + cs_a * 3 + 2); + ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x8); + ymm10 = _mm256_fmadd_pd(ymm12, ymm0, ymm10); + ymm11 = _mm256_fmadd_pd(ymm12, ymm8, ymm11); + + //Col 3 computation + /** + * pack the data in following manner into ymm register + * Since it is computing 4rd column, packing to be done + * as shown below for ymm0: + * col-3 + * x + x + x + x + */ + ymm12 = _mm256_broadcast_sd(alpha_chi + 3); + ymm0 = _mm256_loadu_pd(a + cs_a * 3); + ymm10 = _mm256_fmadd_pd(ymm12, ymm0, ymm10); + ymm11 = _mm256_fmadd_pd(ymm12, ymm9, ymm11); + + //Col 4 computation + ymm12 = _mm256_broadcast_sd(alpha_chi + 4); + ymm0 = _mm256_loadu_pd(a + cs_a * 4); + ymm1 = _mm256_loadu_pd(a + cs_a * 4 
+ 4); + ymm4 = _mm256_broadcast_sd(a + cs_a * 5 + 4); + ymm5 = _mm256_broadcast_sd(a + cs_a * 6 + 4); + ymm6 = _mm256_broadcast_sd(a + cs_a * 7 + 4); + ymm1 = _mm256_blend_pd(ymm1, ymm4, 0x2); + ymm1 = _mm256_blend_pd(ymm1, ymm5, 0x4); + ymm1 = _mm256_blend_pd(ymm1, ymm6, 0x8); + ymm10 = _mm256_fmadd_pd(ymm12, ymm0, ymm10); + ymm11 = _mm256_fmadd_pd(ymm12, ymm1, ymm11); + + //Col 5 computation + ymm12 = _mm256_broadcast_sd(alpha_chi + 5); + ymm0 = _mm256_loadu_pd(a + cs_a * 5); + ymm1 = _mm256_loadu_pd(a + cs_a * 5 + 4); + ymm5 = _mm256_broadcast_sd(a + cs_a * 6 + 5); + ymm6 = _mm256_broadcast_sd(a + cs_a * 7 + 5); + ymm1 = _mm256_blend_pd(ymm1, ymm5, 0x4); + ymm1 = _mm256_blend_pd(ymm1, ymm6, 0x8); + ymm10 = _mm256_fmadd_pd(ymm12, ymm0, ymm10); + ymm11 = _mm256_fmadd_pd(ymm12, ymm1, ymm11); + + //Col 6 computation + ymm12 = _mm256_broadcast_sd(alpha_chi + 6); + ymm0 = _mm256_loadu_pd(a + cs_a * 6); + ymm1 = _mm256_loadu_pd(a + cs_a * 6 + 4); + ymm6 = _mm256_broadcast_sd(a + cs_a * 7 + 6); + ymm1 = _mm256_blend_pd(ymm1, ymm6, 0x8); + ymm10 = _mm256_fmadd_pd(ymm12, ymm0, ymm10); + ymm11 = _mm256_fmadd_pd(ymm12, ymm1, ymm11); + + //Col 7 computation + ymm12 = _mm256_broadcast_sd(alpha_chi + 7); + ymm0 = _mm256_loadu_pd(a + cs_a * 7); + ymm1 = _mm256_loadu_pd(a + cs_a * 7 + 4); + ymm10 = _mm256_fmadd_pd(ymm12, ymm0, ymm10); + ymm11 = _mm256_fmadd_pd(ymm12, ymm1, ymm11); + + /** + * Computed result of vector y is available in ymm10, ymm11. + * Storing the result back from ymm register into y vector for + * further computaion. + */ + _mm256_storeu_pd(y, ymm10); + _mm256_storeu_pd(y + 4, ymm11); +} + + +/** + * ddotxaxpyf kernel performs dot and apxy function all togather + * on a tile of 4x8 size. + * x_trsv holds 4 elements of vector x, a_tile[0-7] holds + * 4x8 tile of A matrix. + * Following equations are solved in a way represented + * y1 = y1 + alpha * A21' * x2; (dotxf) + y2 = y2 + alpha * A21 * x1; (axpyf) + + * B1 B2 B3 B4 B5 B6 B7 B8 + * (broadcast elements of [x*alpha] vector) + * tile 0 1 2 3 4 5 6 7 + * x_trsv[0] A00 A01 A02 A03 => rho0 | A04 A05 A06 A07 => rho4 + * x_trsv[1] A10 A11 A12 A13 => rho1 | A14 A15 A16 A17 => rho5 + * x_trsv[2] A20 A21 A22 A23 => rho2 | A24 A25 A26 A27 => rho6 + * x_trsv[3] A30 A31 A32 A33 => rho3 | A34 A35 A36 A37 => rho7 + || || || || || || || || + \/ \/ \/ \/ \/ \/ \/ \/ + += += += += += += += += + z_vec z_vec z_vec z_vec z_vec z_vec z_vec z_vec + * + * + */ +void bli_ddotxaxpyf_zen_int_8 +( + conj_t conjat, + conj_t conja, + conj_t conjw, + conj_t conjx, + dim_t m, + dim_t b_n, + double* restrict alpha, + double* restrict a, inc_t inca, inc_t lda, + double* restrict w, inc_t incw, + double* restrict x, inc_t incx, + double* restrict beta, + double* restrict y, inc_t incy, + double* restrict z, inc_t incz, + cntx_t* restrict cntx + ) +{ + /* A is m x n. */ + /* y = beta * y + alpha * A^T w; */ + /* z = z + alpha * A x; */ + if ((inca == 1) && (incw == 1) && (incx == 1) + && (incy == 1) && (incz == 1) && (b_n == 8)) + { + __m256d r0, r1; + r0 = _mm256_setzero_pd(); + r1 = _mm256_setzero_pd(); + + /* If beta is zero, clear y. Otherwise, scale by beta. 
*/ + if ( PASTEMAC(d,eq0)( *beta ) ) + { + for ( dim_t i = 0; i < 8; ++i ) + { + PASTEMAC(d,set0s)( y[i] ); + } + } + else + { + for ( dim_t i = 0; i < 8; ++i ) + { + PASTEMAC(d,scals)( *beta, y[i] ); + } + } + + /* If the vectors are empty or if alpha is zero, return early*/ + if ( bli_zero_dim1( m ) || PASTEMAC(d,eq0)( *alpha ) ) return; + + dim_t row = 0; + dim_t iter = m/4; + dim_t rem = m%4; + if(iter) + { + vec x_trsv, x_hemvB1, x_hemvB2, x_hemvB3, x_hemvB4; + vec x_hemvB5, x_hemvB6, x_hemvB7, x_hemvB8; + + vec a_tile0, a_tile1, a_tile2, a_tile3; + vec a_tile4, a_tile5, a_tile6, a_tile7; + + vec rho0, rho1, rho2, rho3; + vec rho4, rho5, rho6, rho7; + + __m256d z_vec; + + /** + * Load [x vector * alpha], broadcast each element into + * different ymm registers. To perform axpyf operation + * with 4x8 tile of A matrix. + */ + + x_hemvB1.v = _mm256_set1_pd(x[0*incx] * (*alpha)); + x_hemvB2.v = _mm256_set1_pd(x[1*incx] * (*alpha)); + x_hemvB3.v = _mm256_set1_pd(x[2*incx] * (*alpha)); + x_hemvB4.v = _mm256_set1_pd(x[3*incx] * (*alpha)); + + x_hemvB5.v = _mm256_set1_pd(x[4*incx] * (*alpha)); + x_hemvB6.v = _mm256_set1_pd(x[5*incx] * (*alpha)); + x_hemvB7.v = _mm256_set1_pd(x[6*incx] * (*alpha)); + x_hemvB8.v = _mm256_set1_pd(x[7*incx] * (*alpha)); + + /** + * clear rho register which holds result of + * fmadds for dotxf operation. + * Once micro tile is computed, horizontal addition + * of all rho's will provide us with the result of + * dotxf opereation. + */ + rho0.v = _mm256_setzero_pd(); + rho1.v = _mm256_setzero_pd(); + rho2.v = _mm256_setzero_pd(); + rho3.v = _mm256_setzero_pd(); + rho4.v = _mm256_setzero_pd(); + rho5.v = _mm256_setzero_pd(); + rho6.v = _mm256_setzero_pd(); + rho7.v = _mm256_setzero_pd(); + + for(; (row + 3) < m; row+= 4) + { + a_tile0.v = _mm256_loadu_pd((double *) + &a[row + 0 * lda] ); + a_tile1.v = _mm256_loadu_pd((double *) + &a[row + 1 * lda] ); + a_tile2.v = _mm256_loadu_pd((double *) + &a[row + 2 * lda] ); + a_tile3.v = _mm256_loadu_pd((double *) + &a[row + 3 * lda] ); + a_tile4.v = _mm256_loadu_pd((double *) + &a[row + 4 * lda] ); + a_tile5.v = _mm256_loadu_pd((double *) + &a[row + 5 * lda] ); + a_tile6.v = _mm256_loadu_pd((double *) + &a[row + 6 * lda] ); + a_tile7.v = _mm256_loadu_pd((double *) + &a[row + 7 * lda] ); + + x_trsv.v = _mm256_loadu_pd((double *) &w[row]); + z_vec = _mm256_loadu_pd((double *) &z[row] ); + + //dot product operation + rho0.v = _mm256_fmadd_pd(a_tile0.v, + x_trsv.v, rho0.v); + rho4.v = _mm256_fmadd_pd(a_tile4.v, + x_trsv.v, rho4.v); + + rho1.v = _mm256_fmadd_pd(a_tile1.v, + x_trsv.v, rho1.v); + rho5.v = _mm256_fmadd_pd(a_tile5.v, + x_trsv.v, rho5.v); + + rho2.v = _mm256_fmadd_pd(a_tile2.v, + x_trsv.v, rho2.v); + rho6.v = _mm256_fmadd_pd(a_tile6.v, + x_trsv.v, rho6.v); + + rho3.v = _mm256_fmadd_pd(a_tile3.v, + x_trsv.v, rho3.v); + rho7.v = _mm256_fmadd_pd(a_tile7.v, + x_trsv.v, rho7.v); + + //axpy operation + z_vec = _mm256_fmadd_pd(a_tile0.v, + x_hemvB1.v, z_vec); + z_vec = _mm256_fmadd_pd(a_tile1.v, + x_hemvB2.v, z_vec); + z_vec = _mm256_fmadd_pd(a_tile2.v, + x_hemvB3.v, z_vec); + z_vec = _mm256_fmadd_pd(a_tile3.v, + x_hemvB4.v, z_vec); + + z_vec = _mm256_fmadd_pd(a_tile4.v, + x_hemvB5.v, z_vec); + z_vec = _mm256_fmadd_pd(a_tile5.v, + x_hemvB6.v, z_vec); + z_vec = _mm256_fmadd_pd(a_tile6.v, + x_hemvB7.v, z_vec); + z_vec = _mm256_fmadd_pd(a_tile7.v, + x_hemvB8.v, z_vec); + + _mm256_storeu_pd((double *)&z[row], z_vec); + } + /*Horizontal addition of rho's elements to compute + * the final dotxf result. 
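+			 * For one pair of accumulators, _mm256_hadd_pd(rho0, rho1) yields
+			 *   { rho0[0]+rho0[1], rho1[0]+rho1[1], rho0[2]+rho0[3], rho1[2]+rho1[3] },
+			 * so adding its low and high 128-bit halves leaves the two scalar
+			 * sums { sum(rho0), sum(rho1) } in one xmm register. A sketch of
+			 * that reduction (the names t, lo, hi, dots are illustrative only):
+			 *   t    = _mm256_hadd_pd( rho0.v, rho1.v );
+			 *   lo   = _mm256_extractf128_pd( t, 0 );
+			 *   hi   = _mm256_extractf128_pd( t, 1 );
+			 *   dots = _mm_add_pd( lo, hi );   // holds { sum(rho0), sum(rho1) }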
+ */ + rho0.v = _mm256_hadd_pd( rho0.v, rho1.v ); + rho2.v = _mm256_hadd_pd( rho2.v, rho3.v ); + rho4.v = _mm256_hadd_pd( rho4.v, rho5.v ); + rho6.v = _mm256_hadd_pd( rho6.v, rho7.v ); + + { + __m128d xmm0, xmm1; + + xmm0 = _mm256_extractf128_pd(rho0.v, 0); + xmm1 = _mm256_extractf128_pd(rho0.v, 1); + xmm0 = _mm_add_pd(xmm0, xmm1); + r0 = _mm256_insertf128_pd(r0, xmm0, 0); + + xmm0 = _mm256_extractf128_pd(rho2.v, 0); + xmm1 = _mm256_extractf128_pd(rho2.v, 1); + xmm0 = _mm_add_pd(xmm0, xmm1); + r0 = _mm256_insertf128_pd(r0, xmm0, 1); + + + xmm0 = _mm256_extractf128_pd(rho4.v, 0); + xmm1 = _mm256_extractf128_pd(rho4.v, 1); + xmm0 = _mm_add_pd(xmm0, xmm1); + r1 = _mm256_insertf128_pd(r1, xmm0, 0); + + xmm0 = _mm256_extractf128_pd(rho6.v, 0); + xmm1 = _mm256_extractf128_pd(rho6.v, 1); + xmm0 = _mm_add_pd(xmm0, xmm1); + r1 = _mm256_insertf128_pd(r1, xmm0, 1); + } + } + if(rem) + { + double r[ 8 ]; + double ax[ 8 ]; + /** + * Computed dot product computation needs + * to be brought into the r buffer for + * corner cases, so that remainder computation + * can be updated in r buffer. + */ + _mm256_storeu_pd((double *)r, r0); + _mm256_storeu_pd( (double *)(r + 4), r1); + + PRAGMA_SIMD + for ( dim_t i = 0; i < 8; ++i ) + { + PASTEMAC(d,scal2s) + ( *alpha, x[i], ax[i] ); + } + + PRAGMA_SIMD + for ( dim_t p = row; p < m; ++p ) + { + for ( dim_t i = 0; i < 8; ++i ) + { + PASTEMAC(d,axpys) + ( a[p + i*lda], + w[p], r[i] ); + PASTEMAC(d,axpyjs) + ( ax[i], + a[p + i*lda], z[p] ); + } + } + /** + * Final dot product computation needs be + * loaded into registers, for getting + * scaled by Alpha and finally be stored + * back into output vector. + */ + r0 = _mm256_loadu_pd((double const *)r); + r1 = _mm256_loadu_pd((double const *)(r + 4)); + } + + /** + * Storing the computed result after being + * scaled by Alpha into output vector. + */ + { + __m256d y0, y1, Alpha; + y0 = _mm256_loadu_pd(y); + y1 = _mm256_loadu_pd(y + 4); + Alpha = _mm256_broadcast_sd(alpha); + y0 = _mm256_fmadd_pd(Alpha, r0, y0); + y1 = _mm256_fmadd_pd(Alpha, r1, y1); + _mm256_storeu_pd(y, y0); + _mm256_storeu_pd(y+4, y1); + } + } + else + { + /* Query the context for the kernel function pointer. */ + const num_t dt = PASTEMAC(d,type); + PASTECH(d,dotxf_ker_ft) kfp_df = + bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXF_KER, cntx ); + PASTECH(d,axpyf_ker_ft) kfp_af = + bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPYF_KER, cntx ); + + kfp_df + ( + conjat, + conjw, + m, + b_n, + alpha, + a, inca, lda, + w, incw, + beta, + y, incy, + cntx + ); + + kfp_af + ( + conja, + conjx, + m, + b_n, + alpha, + a, inca, lda, + x, incx, + z, incz, + cntx + ); + } +} diff --git a/kernels/zen/bli_kernels_zen.h b/kernels/zen/bli_kernels_zen.h index 914b5d631f..f3a939b0b7 100644 --- a/kernels/zen/bli_kernels_zen.h +++ b/kernels/zen/bli_kernels_zen.h @@ -32,6 +32,14 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
*/ +// hemv helper function +void bli_pre_hemv_8x8(double *a, double *x, + double *y, double *alpha, + dim_t cs_a, dim_t rs_a); + +void bli_post_hemv_8x8(double *a, double *x, + double *y, double *alpha, + dim_t cs_a, dim_t rs_a); // -- level-1m -- PACKM_KER_PROT(double, d, packm_8xk_gen_zen) @@ -111,6 +119,9 @@ AXPYF_KER_PROT( dcomplex, z, axpyf_zen_int_4 ) DOTXF_KER_PROT( float, s, dotxf_zen_int_8 ) DOTXF_KER_PROT( double, d, dotxf_zen_int_8 ) +// dotxaxpyf (intrinsics) +DOTXAXPYF_KER_PROT( double, d, dotxaxpyf_zen_int_8 ) + // -- level-2 ---------------------------------------------------------------- //gemv(scalar code) From 8201bcfdaf96ff12d99f33da70d3f2d95d3653f9 Mon Sep 17 00:00:00 2001 From: Harihara Sudhan S Date: Tue, 14 Dec 2021 12:01:12 +0530 Subject: [PATCH 054/243] Improved DGEMV performance for smaller sizes - Introduced two new ddotxf functions with lower fuse factor. - Changed the DGEMV framework to use new kernels to improve problem decomposition. Change-Id: I523e158fd33260d06224118fbf74f2314e03a617 --- frame/2/gemv/bli_gemv_unf_var1.c | 211 +++-- kernels/zen/1f/bli_dotxf_zen_int_8.c | 1105 +++++++++++++++++++++----- kernels/zen/bli_kernels_zen.h | 2 + 3 files changed, 1066 insertions(+), 252 deletions(-) diff --git a/frame/2/gemv/bli_gemv_unf_var1.c b/frame/2/gemv/bli_gemv_unf_var1.c index e468587d4b..838ea577bc 100644 --- a/frame/2/gemv/bli_gemv_unf_var1.c +++ b/frame/2/gemv/bli_gemv_unf_var1.c @@ -34,7 +34,6 @@ */ #include "blis.h" -#define BLIS_DGEMV_VAR1_FUSE 8 #undef GENTFUNC #define GENTFUNC( ctype, ch, varname ) \ @@ -121,30 +120,30 @@ void bli_dgemv_unf_var1 ) { - double* A1; - double* y1; - dim_t i; - dim_t f; - dim_t n_elem, n_iter; - inc_t rs_at, cs_at; - conj_t conja; + double *A1; + double *y1; + dim_t i; + dim_t f; + dim_t n_elem, n_iter; + inc_t rs_at, cs_at; + conj_t conja; //memory pool declarations for packing vector X. - mem_t mem_bufX; - rntm_t rntm; - double *x_buf = x; - inc_t buf_incx = incx; + mem_t mem_bufX; + rntm_t rntm; + double *x_buf = x; + inc_t buf_incx = incx; bli_init_once(); - if( cntx == NULL ) cntx = bli_gks_query_cntx(); + if (cntx == NULL) + cntx = bli_gks_query_cntx(); - bli_set_dims_incs_with_trans( transa, - m, n, rs_a, cs_a, - &n_iter, &n_elem, &rs_at, &cs_at ); + bli_set_dims_incs_with_trans(transa, + m, n, rs_a, cs_a, + &n_iter, &n_elem, &rs_at, &cs_at); - conja = bli_extract_conj( transa ); + conja = bli_extract_conj(transa); - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. // This function is invoked on all architectures including ‘generic’. // Invoke architecture specific kernels only if we are sure that we are running on zen, // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). @@ -193,88 +192,154 @@ void bli_dgemv_unf_var1 AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); return; } - + if (incx > 1) { - /* + /* Initialize mem pool buffer to NULL and size to 0 "buf" and "size" fields are assigned once memory is allocated from the pool in bli_membrk_acquire_m(). This will ensure bli_mem_is_alloc() will be passed on an allocated memory if created or a NULL . 
- */ - mem_bufX.pblk.buf = NULL; mem_bufX.pblk.block_size = 0; - mem_bufX.buf_type = 0; mem_bufX.size = 0; - mem_bufX.pool = NULL; + */ - /* In order to get the buffer from pool via rntm access to memory broker + mem_bufX.pblk.buf = NULL; + mem_bufX.pblk.block_size = 0; + mem_bufX.buf_type = 0; + mem_bufX.size = 0; + mem_bufX.pool = NULL; + + /* In order to get the buffer from pool via rntm access to memory broker is needed.Following are initializations for rntm */ - bli_rntm_init_from_global( &rntm ); - bli_rntm_set_num_threads_only( 1, &rntm ); - bli_membrk_rntm_set_membrk( &rntm ); + bli_rntm_init_from_global(&rntm); + bli_rntm_set_num_threads_only(1, &rntm); + bli_membrk_rntm_set_membrk(&rntm); - //calculate the size required for n_elem double elements in vector X. - size_t buffer_size = n_elem * sizeof(double); + //calculate the size required for n_elem double elements in vector X. + size_t buffer_size = n_elem * sizeof(double); - #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_dgemv_unf_var1(): get mem pool block\n" ); - #endif +#ifdef BLIS_ENABLE_MEM_TRACING + printf("bli_dgemv_unf_var1(): get mem pool block\n"); +#endif - /*acquire a Buffer(n_elem*size(double)) from the memory broker - and save the associated mem_t entry to mem_bufX.*/ - bli_membrk_acquire_m(&rntm, - buffer_size, - BLIS_BUFFER_FOR_B_PANEL, - &mem_bufX); + /*acquire a Buffer(n_elem*size(double)) from the memory broker + and save the associated mem_t entry to mem_bufX.*/ + bli_membrk_acquire_m(&rntm, + buffer_size, + BLIS_BUFFER_FOR_B_PANEL, + &mem_bufX); - /*Continue packing X if buffer memory is allocated*/ - if ((bli_mem_is_alloc( &mem_bufX ))) - { - x_buf = bli_mem_buffer(&mem_bufX); - - //pack X vector with non-unit stride to a temp buffer x_buf with unit stride - for(dim_t x_index = 0 ; x_index < n_elem ; x_index++) - { - *(x_buf + x_index) = *(x + (x_index * incx)) ; - } - // stride of vector x_buf =1 - buf_incx = 1; - } + /*Continue packing X if buffer memory is allocated*/ + if ((bli_mem_is_alloc(&mem_bufX))) + { + x_buf = bli_mem_buffer(&mem_bufX); + + //pack X vector with non-unit stride to a temp buffer x_buf with unit stride + for (dim_t x_index = 0; x_index < n_elem; x_index++) + { + *(x_buf + x_index) = *(x + (x_index * incx)); + } + // stride of vector x_buf =1 + buf_incx = 1; } - - for ( i = 0; i < n_iter; i += f ) + } + + dim_t fuse_factor = 8; + dim_t f_temp =0; + + if (n < 4) + { + fuse_factor = 2; + } else if (n < 8) + { + fuse_factor = 4; + } + + + for (i = 0; i < n_iter; i += f) + { + f = bli_determine_blocksize_dim_f(i, n_iter, fuse_factor); + + //A = a + i * row_increment + 0 * column_increment + A1 = a + (i)*rs_at; + y1 = y + (i)*incy; + + /* y1 = beta * y1 + alpha * A1 * x; */ + switch (f) { - f = bli_determine_blocksize_dim_f( i, n_iter, BLIS_DGEMV_VAR1_FUSE ); + case 8: - A1 = a + (i )*rs_at + (0 )*cs_at; - y1 = y + (i )*incy; - - /* y1 = beta * y1 + alpha * A1 * x; */ - bli_ddotxf_zen_int_8 - ( + bli_ddotxf_zen_int_8( conja, conjx, n_elem, f, alpha, - A1, cs_at, rs_at, - x_buf, buf_incx, + A1, cs_at, rs_at, + x_buf, buf_incx, beta, - y1, incy, - cntx - ); + y1, incy, + cntx); + + break; + default: + if (f < 4) + { + bli_ddotxf_zen_int_2( + conja, + conjx, + n_elem, + f, + alpha, + A1, cs_at, rs_at, + x_buf, buf_incx, + beta, + y1, incy, + cntx); + } + else + { + bli_ddotxf_zen_int_4( + conja, + conjx, + n_elem, + f, + alpha, + A1, cs_at, rs_at, + x_buf, buf_incx, + beta, + y1, incy, + cntx); + } } - if ((incx > 1) && bli_mem_is_alloc( &mem_bufX )) + + f_temp = 
bli_determine_blocksize_dim_f(i + f, n_iter, fuse_factor); + + if (f_temp < fuse_factor) { - #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_dgemv_unf_var1(): releasing mem pool block\n" ); - #endif - // Return the buffer to pool - bli_membrk_release(&rntm , &mem_bufX); + switch (fuse_factor) + { + case 8: + fuse_factor = 4; + break; + case 4: + fuse_factor = 2; + break; + } } - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); + } + + if ((incx > 1) && bli_mem_is_alloc(&mem_bufX)) + { +#ifdef BLIS_ENABLE_MEM_TRACING + printf("bli_dgemv_unf_var1(): releasing mem pool block\n"); +#endif + // Return the buffer to pool + bli_membrk_release(&rntm, &mem_bufX); + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); } void bli_sgemv_unf_var1 diff --git a/kernels/zen/1f/bli_dotxf_zen_int_8.c b/kernels/zen/1f/bli_dotxf_zen_int_8.c index 531a389b50..e25910fb4e 100644 --- a/kernels/zen/1f/bli_dotxf_zen_int_8.c +++ b/kernels/zen/1f/bli_dotxf_zen_int_8.c @@ -52,6 +52,14 @@ typedef union double d[4] __attribute__((aligned(64))); } v4df_t; +/* Union data structure to access AVX registers +* One 128-bit AVX register holds 2 DP elements. */ +typedef union +{ + __m128d v; + double d[2] __attribute__((aligned(64))); +} v2df_t; + // ----------------------------------------------------------------------------- void bli_sdotxf_zen_int_8 @@ -430,49 +438,46 @@ void bli_ddotxf_zen_int_8 cntx_t* restrict cntx ) { - const dim_t fuse_fac = 8; - const dim_t n_elem_per_reg = 4; + const dim_t fuse_fac = 8; + const dim_t n_elem_per_reg = 4; // If the b_n dimension is zero, y is empty and there is no computation. - if ( bli_zero_dim1( b_n ) ) return; + if (bli_zero_dim1(b_n)) + return; // If the m dimension is zero, or if alpha is zero, the computation // simplifies to updating y. - if ( bli_zero_dim1( m ) || PASTEMAC(d,eq0)( *alpha ) ) + if (bli_zero_dim1(m) || PASTEMAC(d, eq0)(*alpha)) { - bli_dscalv_zen_int10 - ( - BLIS_NO_CONJUGATE, - b_n, - beta, - y, incy, - cntx - ); + bli_dscalv_zen_int10( + BLIS_NO_CONJUGATE, + b_n, + beta, + y, incy, + cntx); return; } // If b_n is not equal to the fusing factor, then perform the entire // operation as a loop over dotxv. - if ( b_n != fuse_fac ) + if (b_n != fuse_fac) { - for ( dim_t i = 0; i < b_n; ++i ) + for (dim_t i = 0; i < b_n; ++i) { - double* a1 = a + (0 )*inca + (i )*lda; - double* x1 = x + (0 )*incx; - double* psi1 = y + (i )*incy; - - bli_ddotxv_zen_int - ( - conjat, - conjx, - m, - alpha, - a1, inca, - x1, incx, - beta, - psi1, - cntx - ); + double *a1 = a + (0) * inca + (i)*lda; + double *x1 = x + (0) * incx; + double *psi1 = y + (i)*incy; + + bli_ddotxv_zen_int( + conjat, + conjx, + m, + alpha, + a1, inca, + x1, incx, + beta, + psi1, + cntx); } return; } @@ -493,115 +498,113 @@ void bli_ddotxf_zen_int_8 // distinguishes between (1) and (2). // Intermediate variables to hold the completed dot products - double rho0 = 0, rho1 = 0, rho2 = 0, rho3 = 0, - rho4 = 0, rho5 = 0, rho6 = 0, rho7 = 0; + double rho0 = 0, rho1 = 0, rho2 = 0, rho3 = 0; + double rho4 = 0, rho5 = 0, rho6 = 0, rho7 = 0; - if ( inca == 1 && incx == 1 ) + if (inca == 1 && incx == 1) { const dim_t n_iter_unroll = 1; // Use the unrolling factor and the number of elements per register // to compute the number of vectorized and leftover iterations. - dim_t m_viter = ( m ) / ( n_elem_per_reg * n_iter_unroll ); + dim_t m_viter; + + // Calculate the number of vector iterations that can occur + // for the given unroll factors. 
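+		// For example, with n_elem_per_reg = 4 and n_iter_unroll = 1,
+		// m = 13 gives m_viter = 3 vector iterations covering 12 rows;
+		// the one remaining row is handled by the scalar cleanup loop
+		// further below (illustrative numbers only).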
+ m_viter = (m) / (n_elem_per_reg * n_iter_unroll); // Set up pointers for x and the b_n columns of A (rows of A^T). - double* restrict x0 = x; - double* restrict a0 = a + 0*lda; - double* restrict a1 = a + 1*lda; - double* restrict a2 = a + 2*lda; - double* restrict a3 = a + 3*lda; - double* restrict a4 = a + 4*lda; - double* restrict a5 = a + 5*lda; - double* restrict a6 = a + 6*lda; - double* restrict a7 = a + 7*lda; + double *restrict x0 = x; + double *restrict av[8]; + + av[0] = a + 0 * lda; + av[1] = a + 1 * lda; + av[2] = a + 2 * lda; + av[3] = a + 3 * lda; + av[4] = a + 4 * lda; + av[5] = a + 5 * lda; + av[6] = a + 6 * lda; + av[7] = a + 7 * lda; // Initialize b_n rho vector accumulators to zero. - v4df_t rho0v; rho0v.v = _mm256_setzero_pd(); - v4df_t rho1v; rho1v.v = _mm256_setzero_pd(); - v4df_t rho2v; rho2v.v = _mm256_setzero_pd(); - v4df_t rho3v; rho3v.v = _mm256_setzero_pd(); - v4df_t rho4v; rho4v.v = _mm256_setzero_pd(); - v4df_t rho5v; rho5v.v = _mm256_setzero_pd(); - v4df_t rho6v; rho6v.v = _mm256_setzero_pd(); - v4df_t rho7v; rho7v.v = _mm256_setzero_pd(); + v4df_t rhov[8]; - v4df_t x0v; - v4df_t a0v, a1v, a2v, a3v, a4v, a5v, a6v, a7v; + rhov[0].v = _mm256_setzero_pd(); + rhov[1].v = _mm256_setzero_pd(); + rhov[2].v = _mm256_setzero_pd(); + rhov[3].v = _mm256_setzero_pd(); + rhov[4].v = _mm256_setzero_pd(); + rhov[5].v = _mm256_setzero_pd(); + rhov[6].v = _mm256_setzero_pd(); + rhov[7].v = _mm256_setzero_pd(); - // If there are vectorized iterations, perform them with vector - // instructions. - for ( dim_t i = 0; i < m_viter; ++i ) + v4df_t xv; + v4df_t avec[8]; + + for (dim_t i = 0; i < m_viter; ++i) { // Load the input values. - x0v.v = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + xv.v = _mm256_loadu_pd(x0 + 0 * n_elem_per_reg); - a0v.v = _mm256_loadu_pd( a0 + 0*n_elem_per_reg ); - a1v.v = _mm256_loadu_pd( a1 + 0*n_elem_per_reg ); - a2v.v = _mm256_loadu_pd( a2 + 0*n_elem_per_reg ); - a3v.v = _mm256_loadu_pd( a3 + 0*n_elem_per_reg ); - a4v.v = _mm256_loadu_pd( a4 + 0*n_elem_per_reg ); - a5v.v = _mm256_loadu_pd( a5 + 0*n_elem_per_reg ); - a6v.v = _mm256_loadu_pd( a6 + 0*n_elem_per_reg ); - a7v.v = _mm256_loadu_pd( a7 + 0*n_elem_per_reg ); + avec[0].v = _mm256_loadu_pd(av[0] + 0 * n_elem_per_reg); + avec[1].v = _mm256_loadu_pd(av[1] + 0 * n_elem_per_reg); + avec[2].v = _mm256_loadu_pd(av[2] + 0 * n_elem_per_reg); + avec[3].v = _mm256_loadu_pd(av[3] + 0 * n_elem_per_reg); // perform: rho?v += a?v * x0v; - rho0v.v = _mm256_fmadd_pd( a0v.v, x0v.v, rho0v.v ); - rho1v.v = _mm256_fmadd_pd( a1v.v, x0v.v, rho1v.v ); - rho2v.v = _mm256_fmadd_pd( a2v.v, x0v.v, rho2v.v ); - rho3v.v = _mm256_fmadd_pd( a3v.v, x0v.v, rho3v.v ); - rho4v.v = _mm256_fmadd_pd( a4v.v, x0v.v, rho4v.v ); - rho5v.v = _mm256_fmadd_pd( a5v.v, x0v.v, rho5v.v ); - rho6v.v = _mm256_fmadd_pd( a6v.v, x0v.v, rho6v.v ); - rho7v.v = _mm256_fmadd_pd( a7v.v, x0v.v, rho7v.v ); + rhov[0].v = _mm256_fmadd_pd(avec[0].v, xv.v, rhov[0].v); + rhov[1].v = _mm256_fmadd_pd(avec[1].v, xv.v, rhov[1].v); + rhov[2].v = _mm256_fmadd_pd(avec[2].v, xv.v, rhov[2].v); + rhov[3].v = _mm256_fmadd_pd(avec[3].v, xv.v, rhov[3].v); + + avec[4].v = _mm256_loadu_pd(av[4] + 0 * n_elem_per_reg); + avec[5].v = _mm256_loadu_pd(av[5] + 0 * n_elem_per_reg); + avec[6].v = _mm256_loadu_pd(av[6] + 0 * n_elem_per_reg); + avec[7].v = _mm256_loadu_pd(av[7] + 0 * n_elem_per_reg); + + rhov[4].v = _mm256_fmadd_pd(avec[4].v, xv.v, rhov[4].v); + rhov[5].v = _mm256_fmadd_pd(avec[5].v, xv.v, rhov[5].v); + rhov[6].v = _mm256_fmadd_pd(avec[6].v, xv.v, rhov[6].v); + 
rhov[7].v = _mm256_fmadd_pd(avec[7].v, xv.v, rhov[7].v); x0 += n_elem_per_reg * n_iter_unroll; - a0 += n_elem_per_reg * n_iter_unroll; - a1 += n_elem_per_reg * n_iter_unroll; - a2 += n_elem_per_reg * n_iter_unroll; - a3 += n_elem_per_reg * n_iter_unroll; - a4 += n_elem_per_reg * n_iter_unroll; - a5 += n_elem_per_reg * n_iter_unroll; - a6 += n_elem_per_reg * n_iter_unroll; - a7 += n_elem_per_reg * n_iter_unroll; + av[0] += n_elem_per_reg * n_iter_unroll; + av[1] += n_elem_per_reg * n_iter_unroll; + av[2] += n_elem_per_reg * n_iter_unroll; + av[3] += n_elem_per_reg * n_iter_unroll; + av[4] += n_elem_per_reg * n_iter_unroll; + av[5] += n_elem_per_reg * n_iter_unroll; + av[6] += n_elem_per_reg * n_iter_unroll; + av[7] += n_elem_per_reg * n_iter_unroll; } -#if 0 - rho0 += rho0v.d[0] + rho0v.d[1] + rho0v.d[2] + rho0v.d[3]; - rho1 += rho1v.d[0] + rho1v.d[1] + rho1v.d[2] + rho1v.d[3]; - rho2 += rho2v.d[0] + rho2v.d[1] + rho2v.d[2] + rho2v.d[3]; - rho3 += rho3v.d[0] + rho3v.d[1] + rho3v.d[2] + rho3v.d[3]; - rho4 += rho4v.d[0] + rho4v.d[1] + rho4v.d[2] + rho4v.d[3]; - rho5 += rho5v.d[0] + rho5v.d[1] + rho5v.d[2] + rho5v.d[3]; - rho6 += rho6v.d[0] + rho6v.d[1] + rho6v.d[2] + rho6v.d[3]; - rho7 += rho7v.d[0] + rho7v.d[1] + rho7v.d[2] + rho7v.d[3]; -#else // Sum the elements of a given rho?v. This computes the sum of // elements within lanes and stores the sum to both elements. - rho0v.v = _mm256_hadd_pd( rho0v.v, rho0v.v ); - rho1v.v = _mm256_hadd_pd( rho1v.v, rho1v.v ); - rho2v.v = _mm256_hadd_pd( rho2v.v, rho2v.v ); - rho3v.v = _mm256_hadd_pd( rho3v.v, rho3v.v ); - rho4v.v = _mm256_hadd_pd( rho4v.v, rho4v.v ); - rho5v.v = _mm256_hadd_pd( rho5v.v, rho5v.v ); - rho6v.v = _mm256_hadd_pd( rho6v.v, rho6v.v ); - rho7v.v = _mm256_hadd_pd( rho7v.v, rho7v.v ); + rhov[0].v = _mm256_hadd_pd(rhov[0].v, rhov[0].v); + rhov[1].v = _mm256_hadd_pd(rhov[1].v, rhov[1].v); + rhov[2].v = _mm256_hadd_pd(rhov[2].v, rhov[2].v); + rhov[3].v = _mm256_hadd_pd(rhov[3].v, rhov[3].v); + rhov[4].v = _mm256_hadd_pd(rhov[4].v, rhov[4].v); + rhov[5].v = _mm256_hadd_pd(rhov[5].v, rhov[5].v); + rhov[6].v = _mm256_hadd_pd(rhov[6].v, rhov[6].v); + rhov[7].v = _mm256_hadd_pd(rhov[7].v, rhov[7].v); // Manually add the results from above to finish the sum. - rho0 = rho0v.d[0] + rho0v.d[2]; - rho1 = rho1v.d[0] + rho1v.d[2]; - rho2 = rho2v.d[0] + rho2v.d[2]; - rho3 = rho3v.d[0] + rho3v.d[2]; - rho4 = rho4v.d[0] + rho4v.d[2]; - rho5 = rho5v.d[0] + rho5v.d[2]; - rho6 = rho6v.d[0] + rho6v.d[2]; - rho7 = rho7v.d[0] + rho7v.d[2]; -#endif + rho0 = rhov[0].d[0] + rhov[0].d[2]; + rho1 = rhov[1].d[0] + rhov[1].d[2]; + rho2 = rhov[2].d[0] + rhov[2].d[2]; + rho3 = rhov[3].d[0] + rhov[3].d[2]; + rho4 = rhov[4].d[0] + rhov[4].d[2]; + rho5 = rhov[5].d[0] + rhov[5].d[2]; + rho6 = rhov[6].d[0] + rhov[6].d[2]; + rho7 = rhov[7].d[0] + rhov[7].d[2]; + // Adjust for scalar subproblem. m -= n_elem_per_reg * n_iter_unroll * m_viter; a += n_elem_per_reg * n_iter_unroll * m_viter /* * inca */; x += n_elem_per_reg * n_iter_unroll * m_viter /* * incx */; - } - else if ( lda == 1 ) + + }else if (lda == 1) { const dim_t n_iter_unroll = 3; const dim_t n_reg_per_row = 2; // fuse_fac / n_elem_per_reg; @@ -672,127 +675,871 @@ void bli_ddotxf_zen_int_8 a += n_iter_unroll * m_viter * inca; x += n_iter_unroll * m_viter * incx; } + + // Initialize pointers for x and the b_n columns of A (rows of A^T). 
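+	// When one of the vector paths above was taken, m, a and x have already
+	// been advanced past the vectorized region, so the loop below only
+	// touches the leftover rows; e.g. for inca == 1 and incx == 1, m = 10
+	// runs two 4-row vector iterations above and leaves m = 2 here
+	// (illustrative numbers only).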
+ double *restrict x0 = x; + double *restrict a0 = a + 0 * lda; + double *restrict a1 = a + 1 * lda; + double *restrict a2 = a + 2 * lda; + double *restrict a3 = a + 3 * lda; + double *restrict a4 = a + 4 * lda; + double *restrict a5 = a + 5 * lda; + double *restrict a6 = a + 6 * lda; + double *restrict a7 = a + 7 * lda; + + // If there are leftover iterations, perform them with scalar code. + for (dim_t i = 0; i < m; ++i) + { + const double x0c = *x0; + + const double a0c = *a0; + const double a1c = *a1; + const double a2c = *a2; + const double a3c = *a3; + const double a4c = *a4; + const double a5c = *a5; + const double a6c = *a6; + const double a7c = *a7; + + rho0 += a0c * x0c; + rho1 += a1c * x0c; + rho2 += a2c * x0c; + rho3 += a3c * x0c; + rho4 += a4c * x0c; + rho5 += a5c * x0c; + rho6 += a6c * x0c; + rho7 += a7c * x0c; + + x0 += incx; + a0 += inca; + a1 += inca; + a2 += inca; + a3 += inca; + a4 += inca; + a5 += inca; + a6 += inca; + a7 += inca; + } + + // Now prepare the final rho values to output/accumulate back into + // the y vector. + + v4df_t rho0v, rho1v, y0v, y1v; + + // Insert the scalar rho values into a single vector. + rho0v.d[0] = rho0; + rho0v.d[1] = rho1; + rho0v.d[2] = rho2; + rho0v.d[3] = rho3; + rho1v.d[0] = rho4; + rho1v.d[1] = rho5; + rho1v.d[2] = rho6; + rho1v.d[3] = rho7; + + // Broadcast the alpha scalar. + v4df_t alphav; + alphav.v = _mm256_broadcast_sd(alpha); + + // We know at this point that alpha is nonzero; however, beta may still + // be zero. If beta is indeed zero, we must overwrite y rather than scale + // by beta (in case y contains NaN or Inf). + if (PASTEMAC(d, eq0)(*beta)) + { + // Apply alpha to the accumulated dot product in rho: + // y := alpha * rho + y0v.v = _mm256_mul_pd(alphav.v, rho0v.v); + y1v.v = _mm256_mul_pd(alphav.v, rho1v.v); + } else { - // No vectorization possible; use scalar iterations for the entire - // problem. + // Broadcast the beta scalar. + v4df_t betav; + betav.v = _mm256_broadcast_sd(beta); + + // Load y. + if (incy == 1) + { + y0v.v = _mm256_loadu_pd(y + 0 * n_elem_per_reg); + y1v.v = _mm256_loadu_pd(y + 1 * n_elem_per_reg); + } + else + { + y0v.d[0] = *(y + 0 * incy); + y0v.d[1] = *(y + 1 * incy); + y0v.d[2] = *(y + 2 * incy); + y0v.d[3] = *(y + 3 * incy); + y1v.d[0] = *(y + 4 * incy); + y1v.d[1] = *(y + 5 * incy); + y1v.d[2] = *(y + 6 * incy); + y1v.d[3] = *(y + 7 * incy); + } + + // Apply beta to y and alpha to the accumulated dot product in rho: + // y := beta * y + alpha * rho + y0v.v = _mm256_mul_pd(betav.v, y0v.v); + y1v.v = _mm256_mul_pd(betav.v, y1v.v); + y0v.v = _mm256_fmadd_pd(alphav.v, rho0v.v, y0v.v); + y1v.v = _mm256_fmadd_pd(alphav.v, rho1v.v, y1v.v); } - // Scalar edge case. + if (incy == 1) { - // Initialize pointers for x and the b_n columns of A (rows of A^T). - double* restrict x0 = x; - double* restrict a0 = a + 0*lda; - double* restrict a1 = a + 1*lda; - double* restrict a2 = a + 2*lda; - double* restrict a3 = a + 3*lda; - double* restrict a4 = a + 4*lda; - double* restrict a5 = a + 5*lda; - double* restrict a6 = a + 6*lda; - double* restrict a7 = a + 7*lda; + // Store the output. 
+ _mm256_storeu_pd((y + 0 * n_elem_per_reg), y0v.v); + _mm256_storeu_pd((y + 1 * n_elem_per_reg), y1v.v); + } + else + { + *(y + 0 * incy) = y0v.d[0]; + *(y + 1 * incy) = y0v.d[1]; + *(y + 2 * incy) = y0v.d[2]; + *(y + 3 * incy) = y0v.d[3]; + *(y + 4 * incy) = y1v.d[0]; + *(y + 5 * incy) = y1v.d[1]; + *(y + 6 * incy) = y1v.d[2]; + *(y + 7 * incy) = y1v.d[3]; + } +} - // If there are leftover iterations, perform them with scalar code. - for ( dim_t i = 0; i < m ; ++i ) + +void bli_ddotxf_zen_int_4 + ( + conj_t conjat, + conj_t conjx, + dim_t m, + dim_t b_n, + double *restrict alpha, + double *restrict a, inc_t inca, inc_t lda, + double *restrict x, inc_t incx, + double *restrict beta, + double *restrict y, inc_t incy, + cntx_t *restrict cntx + ) +{ + const dim_t fuse_fac = 4; + const dim_t n_elem_per_reg = 4; + + // If the b_n dimension is zero, y is empty and there is no computation. + if (bli_zero_dim1(b_n)) + return; + + // If the m dimension is zero, or if alpha is zero, the computation + // simplifies to updating y. + if (bli_zero_dim1(m) || PASTEMAC(d, eq0)(*alpha)) + { + bli_dscalv_zen_int10( + BLIS_NO_CONJUGATE, + b_n, + beta, + y, incy, + cntx); + return; + } + + // If b_n is not equal to the fusing factor, then perform the entire + // operation as a loop over dotxv. + if (b_n != fuse_fac) + { + for (dim_t i = 0; i < b_n; ++i) { - const double x0c = *x0; + double *a1 = a + (0) * inca + (i)*lda; + double *x1 = x + (0) * incx; + double *psi1 = y + (i)*incy; + + bli_ddotxv_zen_int( + conjat, + conjx, + m, + alpha, + a1, inca, + x1, incx, + beta, + psi1, + cntx); + } + return; + } - const double a0c = *a0; - const double a1c = *a1; - const double a2c = *a2; - const double a3c = *a3; - const double a4c = *a4; - const double a5c = *a5; - const double a6c = *a6; - const double a7c = *a7; + // At this point, we know that b_n is exactly equal to the fusing factor. + // However, m may not be a multiple of the number of elements per vector. - rho0 += a0c * x0c; - rho1 += a1c * x0c; - rho2 += a2c * x0c; - rho3 += a3c * x0c; - rho4 += a4c * x0c; - rho5 += a5c * x0c; - rho6 += a6c * x0c; - rho7 += a7c * x0c; + // Going forward, we handle two possible storage formats of A explicitly: + // (1) A is stored by columns, or (2) A is stored by rows. Either case is + // further split into two subproblems along the m dimension: + // (a) a vectorized part, starting at m = 0 and ending at any 0 <= m' <= m. + // (b) a scalar part, starting at m' and ending at m. If no vectorization + // is possible then m' == 0 and thus the scalar part is the entire + // problem. If 0 < m', then the a and x pointers and m variable will + // be adjusted accordingly for the second subproblem. + // Note: since parts (b) for both (1) and (2) are so similar, they are + // factored out into one code block after the following conditional, which + // distinguishes between (1) and (2). - x0 += incx; - a0 += inca; - a1 += inca; - a2 += inca; - a3 += inca; - a4 += inca; - a5 += inca; - a6 += inca; - a7 += inca; + // Intermediate variables to hold the completed dot products + double rho0 = 0, rho1 = 0, rho2 = 0, rho3 = 0; + + if (inca == 1 && incx == 1) + { + const dim_t n_iter_unroll[4] = {4, 3, 2, 1}; + + // Use the unrolling factor and the number of elements per register + // to compute the number of vectorized and leftover iterations. + dim_t m_viter[4], m_left = m, i; + + // Calculate the number of vector iterations that can occur for + // various unroll factors. 
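+		// For example, m = 71 decomposes as 71 = 4*16 + 0*12 + 0*8 + 1*4 + 3,
+		// i.e. m_viter = {4, 0, 0, 1} with 3 rows left for the scalar cleanup
+		// loop (n_elem_per_reg = 4 and register unrolls of 4/3/2/1;
+		// illustrative numbers only).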
+ for (i = 0; i < 4; ++i) + { + m_viter[i] = (m_left) / (n_elem_per_reg * n_iter_unroll[i]); + m_left = (m_left) % (n_elem_per_reg * n_iter_unroll[i]); + } + + // Set up pointers for x and the b_n columns of A (rows of A^T). + double *restrict x0 = x; + double *restrict av[4]; + + av[0] = a + 0 * lda; + av[1] = a + 1 * lda; + av[2] = a + 2 * lda; + av[3] = a + 3 * lda; + + // Initialize b_n rho vector accumulators to zero. + v4df_t rhov[8]; + + rhov[0].v = _mm256_setzero_pd(); + rhov[1].v = _mm256_setzero_pd(); + rhov[2].v = _mm256_setzero_pd(); + rhov[3].v = _mm256_setzero_pd(); + rhov[4].v = _mm256_setzero_pd(); + rhov[5].v = _mm256_setzero_pd(); + rhov[6].v = _mm256_setzero_pd(); + rhov[7].v = _mm256_setzero_pd(); + + v4df_t xv[4]; + v4df_t avec[16]; + + // If there are vectorized iterations, perform them with vector + // instructions. + for (i = 0; i < m_viter[0]; ++i) + { + // Load the input values. + xv[0].v = _mm256_loadu_pd(x0 + 0 * n_elem_per_reg); + xv[1].v = _mm256_loadu_pd(x0 + 1 * n_elem_per_reg); + xv[2].v = _mm256_loadu_pd(x0 + 2 * n_elem_per_reg); + xv[3].v = _mm256_loadu_pd(x0 + 3 * n_elem_per_reg); + + avec[0].v = _mm256_loadu_pd(av[0] + 0 * n_elem_per_reg); + avec[1].v = _mm256_loadu_pd(av[1] + 0 * n_elem_per_reg); + avec[2].v = _mm256_loadu_pd(av[2] + 0 * n_elem_per_reg); + avec[3].v = _mm256_loadu_pd(av[3] + 0 * n_elem_per_reg); + + // perform: rho?v += a?v * x0v; + rhov[0].v = _mm256_fmadd_pd(avec[0].v, xv[0].v, rhov[0].v); + rhov[1].v = _mm256_fmadd_pd(avec[1].v, xv[0].v, rhov[1].v); + rhov[2].v = _mm256_fmadd_pd(avec[2].v, xv[0].v, rhov[2].v); + rhov[3].v = _mm256_fmadd_pd(avec[3].v, xv[0].v, rhov[3].v); + + avec[4].v = _mm256_loadu_pd(av[0] + 1 * n_elem_per_reg); + avec[5].v = _mm256_loadu_pd(av[1] + 1 * n_elem_per_reg); + avec[6].v = _mm256_loadu_pd(av[2] + 1 * n_elem_per_reg); + avec[7].v = _mm256_loadu_pd(av[3] + 1 * n_elem_per_reg); + + rhov[4].v = _mm256_fmadd_pd(avec[4].v, xv[1].v, rhov[4].v); + rhov[5].v = _mm256_fmadd_pd(avec[5].v, xv[1].v, rhov[5].v); + rhov[6].v = _mm256_fmadd_pd(avec[6].v, xv[1].v, rhov[6].v); + rhov[7].v = _mm256_fmadd_pd(avec[7].v, xv[1].v, rhov[7].v); + + avec[8].v = _mm256_loadu_pd(av[0] + 2 * n_elem_per_reg); + avec[9].v = _mm256_loadu_pd(av[1] + 2 * n_elem_per_reg); + avec[10].v = _mm256_loadu_pd(av[2] + 2 * n_elem_per_reg); + avec[11].v = _mm256_loadu_pd(av[3] + 2 * n_elem_per_reg); + + rhov[0].v = _mm256_fmadd_pd(avec[8].v, xv[2].v, rhov[0].v); + rhov[1].v = _mm256_fmadd_pd(avec[9].v, xv[2].v, rhov[1].v); + rhov[2].v = _mm256_fmadd_pd(avec[10].v, xv[2].v, rhov[2].v); + rhov[3].v = _mm256_fmadd_pd(avec[11].v, xv[2].v, rhov[3].v); + + avec[12].v = _mm256_loadu_pd(av[0] + 3 * n_elem_per_reg); + avec[13].v = _mm256_loadu_pd(av[1] + 3 * n_elem_per_reg); + avec[14].v = _mm256_loadu_pd(av[2] + 3 * n_elem_per_reg); + avec[15].v = _mm256_loadu_pd(av[3] + 3 * n_elem_per_reg); + + rhov[4].v = _mm256_fmadd_pd(avec[12].v, xv[3].v, rhov[4].v); + rhov[5].v = _mm256_fmadd_pd(avec[13].v, xv[3].v, rhov[5].v); + rhov[6].v = _mm256_fmadd_pd(avec[14].v, xv[3].v, rhov[6].v); + rhov[7].v = _mm256_fmadd_pd(avec[15].v, xv[3].v, rhov[7].v); + + x0 += n_elem_per_reg * n_iter_unroll[0]; + av[0] += n_elem_per_reg * n_iter_unroll[0]; + av[1] += n_elem_per_reg * n_iter_unroll[0]; + av[2] += n_elem_per_reg * n_iter_unroll[0]; + av[3] += n_elem_per_reg * n_iter_unroll[0]; + } + + for (i = 0; i < m_viter[1]; ++i) + { + // Load the input values. 
+ xv[0].v = _mm256_loadu_pd(x0 + 0 * n_elem_per_reg); + xv[1].v = _mm256_loadu_pd(x0 + 1 * n_elem_per_reg); + xv[2].v = _mm256_loadu_pd(x0 + 2 * n_elem_per_reg); + + avec[0].v = _mm256_loadu_pd(av[0] + 0 * n_elem_per_reg); + avec[1].v = _mm256_loadu_pd(av[1] + 0 * n_elem_per_reg); + avec[2].v = _mm256_loadu_pd(av[2] + 0 * n_elem_per_reg); + avec[3].v = _mm256_loadu_pd(av[3] + 0 * n_elem_per_reg); + + // perform: rho?v += a?v * x0v; + rhov[0].v = _mm256_fmadd_pd(avec[0].v, xv[0].v, rhov[0].v); + rhov[1].v = _mm256_fmadd_pd(avec[1].v, xv[0].v, rhov[1].v); + rhov[2].v = _mm256_fmadd_pd(avec[2].v, xv[0].v, rhov[2].v); + rhov[3].v = _mm256_fmadd_pd(avec[3].v, xv[0].v, rhov[3].v); + + avec[4].v = _mm256_loadu_pd(av[0] + 1 * n_elem_per_reg); + avec[5].v = _mm256_loadu_pd(av[1] + 1 * n_elem_per_reg); + avec[6].v = _mm256_loadu_pd(av[2] + 1 * n_elem_per_reg); + avec[7].v = _mm256_loadu_pd(av[3] + 1 * n_elem_per_reg); + + rhov[4].v = _mm256_fmadd_pd(avec[4].v, xv[1].v, rhov[4].v); + rhov[5].v = _mm256_fmadd_pd(avec[5].v, xv[1].v, rhov[5].v); + rhov[6].v = _mm256_fmadd_pd(avec[6].v, xv[1].v, rhov[6].v); + rhov[7].v = _mm256_fmadd_pd(avec[7].v, xv[1].v, rhov[7].v); + + avec[8].v = _mm256_loadu_pd(av[0] + 2 * n_elem_per_reg); + avec[9].v = _mm256_loadu_pd(av[1] + 2 * n_elem_per_reg); + avec[10].v = _mm256_loadu_pd(av[2] + 2 * n_elem_per_reg); + avec[11].v = _mm256_loadu_pd(av[3] + 2 * n_elem_per_reg); + + rhov[0].v = _mm256_fmadd_pd(avec[8].v, xv[2].v, rhov[0].v); + rhov[1].v = _mm256_fmadd_pd(avec[9].v, xv[2].v, rhov[1].v); + rhov[2].v = _mm256_fmadd_pd(avec[10].v, xv[2].v, rhov[2].v); + rhov[3].v = _mm256_fmadd_pd(avec[11].v, xv[2].v, rhov[3].v); + + x0 += n_elem_per_reg * n_iter_unroll[1]; + av[0] += n_elem_per_reg * n_iter_unroll[1]; + av[1] += n_elem_per_reg * n_iter_unroll[1]; + av[2] += n_elem_per_reg * n_iter_unroll[1]; + av[3] += n_elem_per_reg * n_iter_unroll[1]; + } + + for (i = 0; i < m_viter[2]; ++i) + { + // Load the input values. + xv[0].v = _mm256_loadu_pd(x0 + 0 * n_elem_per_reg); + xv[1].v = _mm256_loadu_pd(x0 + 1 * n_elem_per_reg); + + avec[0].v = _mm256_loadu_pd(av[0] + 0 * n_elem_per_reg); + avec[1].v = _mm256_loadu_pd(av[1] + 0 * n_elem_per_reg); + avec[2].v = _mm256_loadu_pd(av[2] + 0 * n_elem_per_reg); + avec[3].v = _mm256_loadu_pd(av[3] + 0 * n_elem_per_reg); + + avec[4].v = _mm256_loadu_pd(av[0] + 1 * n_elem_per_reg); + avec[5].v = _mm256_loadu_pd(av[1] + 1 * n_elem_per_reg); + avec[6].v = _mm256_loadu_pd(av[2] + 1 * n_elem_per_reg); + avec[7].v = _mm256_loadu_pd(av[3] + 1 * n_elem_per_reg); + + // perform: rho?v += a?v * x0v; + rhov[0].v = _mm256_fmadd_pd(avec[0].v, xv[0].v, rhov[0].v); + rhov[1].v = _mm256_fmadd_pd(avec[1].v, xv[0].v, rhov[1].v); + rhov[2].v = _mm256_fmadd_pd(avec[2].v, xv[0].v, rhov[2].v); + rhov[3].v = _mm256_fmadd_pd(avec[3].v, xv[0].v, rhov[3].v); + + rhov[4].v = _mm256_fmadd_pd(avec[4].v, xv[1].v, rhov[4].v); + rhov[5].v = _mm256_fmadd_pd(avec[5].v, xv[1].v, rhov[5].v); + rhov[6].v = _mm256_fmadd_pd(avec[6].v, xv[1].v, rhov[6].v); + rhov[7].v = _mm256_fmadd_pd(avec[7].v, xv[1].v, rhov[7].v); + + x0 += n_elem_per_reg * n_iter_unroll[2]; + av[0] += n_elem_per_reg * n_iter_unroll[2]; + av[1] += n_elem_per_reg * n_iter_unroll[2]; + av[2] += n_elem_per_reg * n_iter_unroll[2]; + av[3] += n_elem_per_reg * n_iter_unroll[2]; + } + + for (i = 0; i < m_viter[3]; ++i) + { + // Load the input values. 
+ xv[0].v = _mm256_loadu_pd(x0 + 0 * n_elem_per_reg); + + avec[0].v = _mm256_loadu_pd(av[0] + 0 * n_elem_per_reg); + avec[1].v = _mm256_loadu_pd(av[1] + 0 * n_elem_per_reg); + avec[2].v = _mm256_loadu_pd(av[2] + 0 * n_elem_per_reg); + avec[3].v = _mm256_loadu_pd(av[3] + 0 * n_elem_per_reg); + + // perform: rho?v += a?v * x0v; + rhov[0].v = _mm256_fmadd_pd(avec[0].v, xv[0].v, rhov[0].v); + rhov[1].v = _mm256_fmadd_pd(avec[1].v, xv[0].v, rhov[1].v); + rhov[2].v = _mm256_fmadd_pd(avec[2].v, xv[0].v, rhov[2].v); + rhov[3].v = _mm256_fmadd_pd(avec[3].v, xv[0].v, rhov[3].v); + + x0 += n_elem_per_reg * n_iter_unroll[3]; + av[0] += n_elem_per_reg * n_iter_unroll[3]; + av[1] += n_elem_per_reg * n_iter_unroll[3]; + av[2] += n_elem_per_reg * n_iter_unroll[3]; + av[3] += n_elem_per_reg * n_iter_unroll[3]; + } + + // Sum the elements of a given rho?v. This computes the sum of + // elements within lanes and stores the sum to both elements. + rhov[0].v = _mm256_add_pd(rhov[0].v, rhov[4].v); + rhov[1].v = _mm256_add_pd(rhov[1].v, rhov[5].v); + rhov[2].v = _mm256_add_pd(rhov[2].v, rhov[6].v); + rhov[3].v = _mm256_add_pd(rhov[3].v, rhov[7].v); + + rhov[0].v = _mm256_hadd_pd(rhov[0].v, rhov[0].v); + rhov[1].v = _mm256_hadd_pd(rhov[1].v, rhov[1].v); + rhov[2].v = _mm256_hadd_pd(rhov[2].v, rhov[2].v); + rhov[3].v = _mm256_hadd_pd(rhov[3].v, rhov[3].v); + + // Manually add the results from above to finish the sum. + rho0 = rhov[0].d[0] + rhov[0].d[2]; + rho1 = rhov[1].d[0] + rhov[1].d[2]; + rho2 = rhov[2].d[0] + rhov[2].d[2]; + rho3 = rhov[3].d[0] + rhov[3].d[2]; + + // Adjust for scalar subproblem. + for (i = 0; i < 4; ++i) + { + m -= n_elem_per_reg * n_iter_unroll[i] * m_viter[i]; + a += n_elem_per_reg * n_iter_unroll[i] * m_viter[i] /* * inca */; + x += n_elem_per_reg * n_iter_unroll[i] * m_viter[i] /* * incx */; } } + // Initialize pointers for x and the b_n columns of A (rows of A^T). + double *restrict x0 = x; + double *restrict a0 = a + 0 * lda; + double *restrict a1 = a + 1 * lda; + double *restrict a2 = a + 2 * lda; + double *restrict a3 = a + 3 * lda; + + // If there are leftover iterations, perform them with scalar code. + for (dim_t i = 0; i < m; ++i) + { + const double x0c = *x0; + + const double a0c = *a0; + const double a1c = *a1; + const double a2c = *a2; + const double a3c = *a3; + + rho0 += a0c * x0c; + rho1 += a1c * x0c; + rho2 += a2c * x0c; + rho3 += a3c * x0c; + + x0 += incx; + a0 += inca; + a1 += inca; + a2 += inca; + a3 += inca; + } + // Now prepare the final rho values to output/accumulate back into // the y vector. - v4df_t rho0v, rho1v, y0v, y1v; + v4df_t rho0v, y0v; // Insert the scalar rho values into a single vector. rho0v.d[0] = rho0; rho0v.d[1] = rho1; rho0v.d[2] = rho2; rho0v.d[3] = rho3; - rho1v.d[0] = rho4; - rho1v.d[1] = rho5; - rho1v.d[2] = rho6; - rho1v.d[3] = rho7; // Broadcast the alpha scalar. - v4df_t alphav; alphav.v = _mm256_broadcast_sd( alpha ); + v4df_t alphav; + alphav.v = _mm256_broadcast_sd(alpha); // We know at this point that alpha is nonzero; however, beta may still // be zero. If beta is indeed zero, we must overwrite y rather than scale // by beta (in case y contains NaN or Inf). - if ( PASTEMAC(d,eq0)( *beta ) ) + if (PASTEMAC(d, eq0)(*beta)) { // Apply alpha to the accumulated dot product in rho: // y := alpha * rho - y0v.v = _mm256_mul_pd( alphav.v, rho0v.v ); - y1v.v = _mm256_mul_pd( alphav.v, rho1v.v ); + y0v.v = _mm256_mul_pd(alphav.v, rho0v.v); } else { // Broadcast the beta scalar. 
- v4df_t betav; betav.v = _mm256_broadcast_sd( beta ); + v4df_t betav; + betav.v = _mm256_broadcast_sd(beta); // Load y. - if ( incy == 1 ) + if (incy == 1) { - y0v.v = _mm256_loadu_pd( y + 0*n_elem_per_reg ); - y1v.v = _mm256_loadu_pd( y + 1*n_elem_per_reg ); + y0v.v = _mm256_loadu_pd(y + 0 * n_elem_per_reg); } else { - y0v.d[0] = *(y + 0*incy); y0v.d[1] = *(y + 1*incy); - y0v.d[2] = *(y + 2*incy); y0v.d[3] = *(y + 3*incy); - y1v.d[0] = *(y + 4*incy); y1v.d[1] = *(y + 5*incy); - y1v.d[2] = *(y + 6*incy); y1v.d[3] = *(y + 7*incy); + y0v.d[0] = *(y + 0 * incy); + y0v.d[1] = *(y + 1 * incy); + y0v.d[2] = *(y + 2 * incy); + y0v.d[3] = *(y + 3 * incy); } // Apply beta to y and alpha to the accumulated dot product in rho: // y := beta * y + alpha * rho - y0v.v = _mm256_mul_pd( betav.v, y0v.v ); - y1v.v = _mm256_mul_pd( betav.v, y1v.v ); - y0v.v = _mm256_fmadd_pd( alphav.v, rho0v.v, y0v.v ); - y1v.v = _mm256_fmadd_pd( alphav.v, rho1v.v, y1v.v ); + y0v.v = _mm256_mul_pd(betav.v, y0v.v); + y0v.v = _mm256_fmadd_pd(alphav.v, rho0v.v, y0v.v); } - if ( incy == 1 ) + if (incy == 1) { // Store the output. - _mm256_storeu_pd( (y + 0*n_elem_per_reg), y0v.v ); - _mm256_storeu_pd( (y + 1*n_elem_per_reg), y1v.v ); + _mm256_storeu_pd((y + 0 * n_elem_per_reg), y0v.v); } else { - *(y + 0*incy) = y0v.d[0]; *(y + 1*incy) = y0v.d[1]; - *(y + 2*incy) = y0v.d[2]; *(y + 3*incy) = y0v.d[3]; - *(y + 4*incy) = y1v.d[0]; *(y + 5*incy) = y1v.d[1]; - *(y + 6*incy) = y1v.d[2]; *(y + 7*incy) = y1v.d[3]; + *(y + 0 * incy) = y0v.d[0]; + *(y + 1 * incy) = y0v.d[1]; + *(y + 2 * incy) = y0v.d[2]; + *(y + 3 * incy) = y0v.d[3]; } } +void bli_ddotxf_zen_int_2 + ( + conj_t conjat, + conj_t conjx, + dim_t m, + dim_t b_n, + double *restrict alpha, + double *restrict a, inc_t inca, inc_t lda, + double *restrict x, inc_t incx, + double *restrict beta, + double *restrict y, inc_t incy, + cntx_t *restrict cntx + ) +{ + const dim_t fuse_fac = 2; + const dim_t n_elem_per_reg = 4; + + // If the b_n dimension is zero, y is empty and there is no computation. + if (bli_zero_dim1(b_n)) + return; + + // If the m dimension is zero, or if alpha is zero, the computation + // simplifies to updating y. + if (bli_zero_dim1(m) || PASTEMAC(d, eq0)(*alpha)) + { + bli_dscalv_zen_int10( + BLIS_NO_CONJUGATE, + b_n, + beta, + y, incy, + cntx); + return; + } + + // If b_n is not equal to the fusing factor, then perform the entire + // operation as a loop over dotxv. + if (b_n != fuse_fac) + { + for (dim_t i = 0; i < b_n; ++i) + { + double *a1 = a + (0) * inca + (i)*lda; + double *x1 = x + (0) * incx; + double *psi1 = y + (i)*incy; + + bli_ddotxv_zen_int( + conjat, + conjx, + m, + alpha, + a1, inca, + x1, incx, + beta, + psi1, + cntx); + } + return; + } + + // At this point, we know that b_n is exactly equal to the fusing factor. + // However, m may not be a multiple of the number of elements per vector. + + // Going forward, we handle two possible storage formats of A explicitly: + // (1) A is stored by columns, or (2) A is stored by rows. Either case is + // further split into two subproblems along the m dimension: + // (a) a vectorized part, starting at m = 0 and ending at any 0 <= m' <= m. + // (b) a scalar part, starting at m' and ending at m. If no vectorization + // is possible then m' == 0 and thus the scalar part is the entire + // problem. If 0 < m', then the a and x pointers and m variable will + // be adjusted accordingly for the second subproblem. 
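+	// For example, with the unroll pattern used below (blocks of 8, 4, 2 and
+	// 1 registers of 4 doubles each), m = 70 splits into m' = 68 rows handled
+	// by vector code (two 32-row blocks plus one 4-row block) and 2 rows
+	// handled by the scalar part (illustrative numbers only).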
+ // Note: since parts (b) for both (1) and (2) are so similar, they are + // factored out into one code block after the following conditional, which + // distinguishes between (1) and (2). + + // Intermediate variables to hold the completed dot products + double rho0 = 0, rho1 = 0; + + if (inca == 1 && incx == 1) + { + const dim_t n_iter_unroll[4] = {8, 4, 2, 1}; + + // Use the unrolling factor and the number of elements per register + // to compute the number of vectorized and leftover iterations. + dim_t m_viter[4], i, m_left = m; + + for (i = 0; i < 4; ++i) + { + m_viter[i] = (m_left) / (n_elem_per_reg * n_iter_unroll[i]); + m_left = (m_left) % (n_elem_per_reg * n_iter_unroll[i]); + } + + // Set up pointers for x and the b_n columns of A (rows of A^T). + double *restrict x0 = x; + double *restrict av[2]; + + av[0] = a + 0 * lda; + av[1] = a + 1 * lda; + + // Initialize b_n rho vector accumulators to zero. + v4df_t rhov[8]; + + rhov[0].v = _mm256_setzero_pd(); + rhov[1].v = _mm256_setzero_pd(); + rhov[2].v = _mm256_setzero_pd(); + rhov[3].v = _mm256_setzero_pd(); + rhov[4].v = _mm256_setzero_pd(); + rhov[5].v = _mm256_setzero_pd(); + rhov[6].v = _mm256_setzero_pd(); + rhov[7].v = _mm256_setzero_pd(); + + v4df_t xv[4]; + v4df_t avec[8]; + + for (i = 0; i < m_viter[0]; ++i) + { + // Load the input values. + xv[0].v = _mm256_loadu_pd(x0 + 0 * n_elem_per_reg); + xv[1].v = _mm256_loadu_pd(x0 + 1 * n_elem_per_reg); + xv[2].v = _mm256_loadu_pd(x0 + 2 * n_elem_per_reg); + xv[3].v = _mm256_loadu_pd(x0 + 3 * n_elem_per_reg); + + avec[0].v = _mm256_loadu_pd(av[0] + 0 * n_elem_per_reg); + avec[1].v = _mm256_loadu_pd(av[1] + 0 * n_elem_per_reg); + avec[2].v = _mm256_loadu_pd(av[0] + 1 * n_elem_per_reg); + avec[3].v = _mm256_loadu_pd(av[1] + 1 * n_elem_per_reg); + avec[4].v = _mm256_loadu_pd(av[0] + 2 * n_elem_per_reg); + avec[5].v = _mm256_loadu_pd(av[1] + 2 * n_elem_per_reg); + avec[6].v = _mm256_loadu_pd(av[0] + 3 * n_elem_per_reg); + avec[7].v = _mm256_loadu_pd(av[1] + 3 * n_elem_per_reg); + + // perform: rho?v += a?v * x0v; + rhov[0].v = _mm256_fmadd_pd(avec[0].v, xv[0].v, rhov[0].v); + rhov[1].v = _mm256_fmadd_pd(avec[1].v, xv[0].v, rhov[1].v); + rhov[2].v = _mm256_fmadd_pd(avec[2].v, xv[1].v, rhov[2].v); + rhov[3].v = _mm256_fmadd_pd(avec[3].v, xv[1].v, rhov[3].v); + rhov[4].v = _mm256_fmadd_pd(avec[4].v, xv[2].v, rhov[4].v); + rhov[5].v = _mm256_fmadd_pd(avec[5].v, xv[2].v, rhov[5].v); + rhov[6].v = _mm256_fmadd_pd(avec[6].v, xv[3].v, rhov[6].v); + rhov[7].v = _mm256_fmadd_pd(avec[7].v, xv[3].v, rhov[7].v); + + // Load the input values. 
+ xv[0].v = _mm256_loadu_pd(x0 + 0 * n_elem_per_reg); + xv[1].v = _mm256_loadu_pd(x0 + 1 * n_elem_per_reg); + xv[2].v = _mm256_loadu_pd(x0 + 2 * n_elem_per_reg); + xv[3].v = _mm256_loadu_pd(x0 + 3 * n_elem_per_reg); + + avec[0].v = _mm256_loadu_pd(av[0] + 0 * n_elem_per_reg); + avec[1].v = _mm256_loadu_pd(av[1] + 0 * n_elem_per_reg); + avec[2].v = _mm256_loadu_pd(av[0] + 1 * n_elem_per_reg); + avec[3].v = _mm256_loadu_pd(av[1] + 1 * n_elem_per_reg); + avec[4].v = _mm256_loadu_pd(av[0] + 2 * n_elem_per_reg); + avec[5].v = _mm256_loadu_pd(av[1] + 2 * n_elem_per_reg); + avec[6].v = _mm256_loadu_pd(av[0] + 3 * n_elem_per_reg); + avec[7].v = _mm256_loadu_pd(av[1] + 3 * n_elem_per_reg); + + // perform: rho?v += a?v * x0v; + rhov[0].v = _mm256_fmadd_pd(avec[0].v, xv[0].v, rhov[0].v); + rhov[1].v = _mm256_fmadd_pd(avec[1].v, xv[0].v, rhov[1].v); + rhov[2].v = _mm256_fmadd_pd(avec[2].v, xv[1].v, rhov[2].v); + rhov[3].v = _mm256_fmadd_pd(avec[3].v, xv[1].v, rhov[3].v); + rhov[4].v = _mm256_fmadd_pd(avec[4].v, xv[2].v, rhov[4].v); + rhov[5].v = _mm256_fmadd_pd(avec[5].v, xv[2].v, rhov[5].v); + rhov[6].v = _mm256_fmadd_pd(avec[6].v, xv[3].v, rhov[6].v); + rhov[7].v = _mm256_fmadd_pd(avec[7].v, xv[3].v, rhov[7].v); + + x0 += n_elem_per_reg * n_iter_unroll[0]; + av[0] += n_elem_per_reg * n_iter_unroll[0]; + av[1] += n_elem_per_reg * n_iter_unroll[0]; + } + + for (i = 0; i < m_viter[1]; ++i) + { + // Load the input values. + xv[0].v = _mm256_loadu_pd(x0 + 0 * n_elem_per_reg); + xv[1].v = _mm256_loadu_pd(x0 + 1 * n_elem_per_reg); + xv[2].v = _mm256_loadu_pd(x0 + 2 * n_elem_per_reg); + xv[3].v = _mm256_loadu_pd(x0 + 3 * n_elem_per_reg); + + avec[0].v = _mm256_loadu_pd(av[0] + 0 * n_elem_per_reg); + avec[1].v = _mm256_loadu_pd(av[1] + 0 * n_elem_per_reg); + avec[2].v = _mm256_loadu_pd(av[0] + 1 * n_elem_per_reg); + avec[3].v = _mm256_loadu_pd(av[1] + 1 * n_elem_per_reg); + avec[4].v = _mm256_loadu_pd(av[0] + 2 * n_elem_per_reg); + avec[5].v = _mm256_loadu_pd(av[1] + 2 * n_elem_per_reg); + avec[6].v = _mm256_loadu_pd(av[0] + 3 * n_elem_per_reg); + avec[7].v = _mm256_loadu_pd(av[1] + 3 * n_elem_per_reg); + + // perform: rho?v += a?v * x0v; + rhov[0].v = _mm256_fmadd_pd(avec[0].v, xv[0].v, rhov[0].v); + rhov[1].v = _mm256_fmadd_pd(avec[1].v, xv[0].v, rhov[1].v); + rhov[2].v = _mm256_fmadd_pd(avec[2].v, xv[1].v, rhov[2].v); + rhov[3].v = _mm256_fmadd_pd(avec[3].v, xv[1].v, rhov[3].v); + rhov[4].v = _mm256_fmadd_pd(avec[4].v, xv[2].v, rhov[4].v); + rhov[5].v = _mm256_fmadd_pd(avec[5].v, xv[2].v, rhov[5].v); + rhov[6].v = _mm256_fmadd_pd(avec[6].v, xv[3].v, rhov[6].v); + rhov[7].v = _mm256_fmadd_pd(avec[7].v, xv[3].v, rhov[7].v); + + x0 += n_elem_per_reg * n_iter_unroll[1]; + av[0] += n_elem_per_reg * n_iter_unroll[1]; + av[1] += n_elem_per_reg * n_iter_unroll[1]; + } + + rhov[0].v = _mm256_add_pd(rhov[0].v, rhov[4].v); + rhov[1].v = _mm256_add_pd(rhov[1].v, rhov[5].v); + rhov[2].v = _mm256_add_pd(rhov[2].v, rhov[6].v); + rhov[3].v = _mm256_add_pd(rhov[3].v, rhov[7].v); + + for (i = 0; i < m_viter[2]; ++i) + { + // Load the input values. 
+ xv[0].v = _mm256_loadu_pd(x0 + 0 * n_elem_per_reg); + xv[1].v = _mm256_loadu_pd(x0 + 1 * n_elem_per_reg); + + avec[0].v = _mm256_loadu_pd(av[0] + 0 * n_elem_per_reg); + avec[1].v = _mm256_loadu_pd(av[1] + 0 * n_elem_per_reg); + avec[2].v = _mm256_loadu_pd(av[0] + 1 * n_elem_per_reg); + avec[3].v = _mm256_loadu_pd(av[1] + 1 * n_elem_per_reg); + + // perform: rho?v += a?v * x0v; + rhov[0].v = _mm256_fmadd_pd(avec[0].v, xv[0].v, rhov[0].v); + rhov[1].v = _mm256_fmadd_pd(avec[1].v, xv[0].v, rhov[1].v); + rhov[2].v = _mm256_fmadd_pd(avec[2].v, xv[1].v, rhov[2].v); + rhov[3].v = _mm256_fmadd_pd(avec[3].v, xv[1].v, rhov[3].v); + + x0 += n_elem_per_reg * n_iter_unroll[2]; + av[0] += n_elem_per_reg * n_iter_unroll[2]; + av[1] += n_elem_per_reg * n_iter_unroll[2]; + } + + rhov[0].v = _mm256_add_pd(rhov[0].v, rhov[2].v); + rhov[1].v = _mm256_add_pd(rhov[1].v, rhov[3].v); + + for (i = 0; i < m_viter[3]; ++i) + { + // Load the input values. + xv[0].v = _mm256_loadu_pd(x0 + 0 * n_elem_per_reg); + + avec[0].v = _mm256_loadu_pd(av[0] + 0 * n_elem_per_reg); + avec[1].v = _mm256_loadu_pd(av[1] + 0 * n_elem_per_reg); + + // perform: rho?v += a?v * x0v; + rhov[0].v = _mm256_fmadd_pd(avec[0].v, xv[0].v, rhov[0].v); + rhov[1].v = _mm256_fmadd_pd(avec[1].v, xv[0].v, rhov[1].v); + + x0 += n_elem_per_reg * n_iter_unroll[3]; + av[0] += n_elem_per_reg * n_iter_unroll[3]; + av[1] += n_elem_per_reg * n_iter_unroll[3]; + } + + // Sum the elements of a given rho?v. This computes the sum of + // elements within lanes and stores the sum to both elements. + rhov[0].v = _mm256_hadd_pd(rhov[0].v, rhov[0].v); + rhov[1].v = _mm256_hadd_pd(rhov[1].v, rhov[1].v); + + // Manually add the results from above to finish the sum. + rho0 = rhov[0].d[0] + rhov[0].d[2]; + rho1 = rhov[1].d[0] + rhov[1].d[2]; + + // Adjust for scalar subproblem. + for (i = 0; i < 4; ++i) + { + m -= n_elem_per_reg * n_iter_unroll[i] * m_viter[i]; + a += n_elem_per_reg * n_iter_unroll[i] * m_viter[i] /* * inca */; + x += n_elem_per_reg * n_iter_unroll[i] * m_viter[i] /* * incx */; + } + } + + // Initialize pointers for x and the b_n columns of A (rows of A^T). + double *restrict x0 = x; + double *restrict a0 = a + 0 * lda; + double *restrict a1 = a + 1 * lda; + + // If there are leftover iterations, perform them with scalar code. + for (dim_t i = 0; i < m; ++i) + { + const double x0c = *x0; + + const double a0c = *a0; + const double a1c = *a1; + + rho0 += a0c * x0c; + rho1 += a1c * x0c; + + x0 += incx; + a0 += inca; + a1 += inca; + } + + // Now prepare the final rho values to output/accumulate back into + // the y vector. + + v2df_t rho0v, y0v; + + // Insert the scalar rho values into a single vector. + rho0v.d[0] = rho0; + rho0v.d[1] = rho1; + + // Broadcast the alpha scalar. + v2df_t alphav; + + alphav.v = _mm_load1_pd(alpha); + + // We know at this point that alpha is nonzero; however, beta may still + // be zero. If beta is indeed zero, we must overwrite y rather than scale + // by beta (in case y contains NaN or Inf). + if (PASTEMAC(d, eq0)(*beta)) + { + // Apply alpha to the accumulated dot product in rho: + // y := alpha * rho + y0v.v = _mm_mul_pd(alphav.v, rho0v.v); + } + else + { + // Broadcast the beta scalar. + v2df_t betav; + betav.v = _mm_load1_pd(beta); + + // Load y. 
+ if (incy == 1) + { + y0v.v = _mm_loadu_pd(y + 0 * 2); + } + else + { + y0v.d[0] = *(y + 0 * incy); + y0v.d[1] = *(y + 1 * incy); + } + + // Apply beta to y and alpha to the accumulated dot product in rho: + // y := beta * y + alpha * rho + y0v.v = _mm_mul_pd(betav.v, y0v.v); + y0v.v = _mm_fmadd_pd(alphav.v, rho0v.v, y0v.v); + } + + if (incy == 1) + { + // Store the output. + _mm_storeu_pd((y + 0 * 2), y0v.v); + } + else + { + *(y + 0 * incy) = y0v.d[0]; + *(y + 1 * incy) = y0v.d[1]; + } +} + + diff --git a/kernels/zen/bli_kernels_zen.h b/kernels/zen/bli_kernels_zen.h index f3a939b0b7..d46164a9c5 100644 --- a/kernels/zen/bli_kernels_zen.h +++ b/kernels/zen/bli_kernels_zen.h @@ -118,6 +118,8 @@ AXPYF_KER_PROT( dcomplex, z, axpyf_zen_int_4 ) // dotxf (intrinsics) DOTXF_KER_PROT( float, s, dotxf_zen_int_8 ) DOTXF_KER_PROT( double, d, dotxf_zen_int_8 ) +DOTXF_KER_PROT( double, d, dotxf_zen_int_4 ) +DOTXF_KER_PROT( double, d, dotxf_zen_int_2 ) // dotxaxpyf (intrinsics) DOTXAXPYF_KER_PROT( double, d, dotxaxpyf_zen_int_8 ) From c2df5eac1c6a9fe4d750129cf3fbe5fc5b332ee1 Mon Sep 17 00:00:00 2001 From: Nallani Bhaskar Date: Wed, 15 Dec 2021 15:11:08 +0530 Subject: [PATCH 055/243] Reduced number of threads in dgemm for small dimensions - Number of threads are reduced to 1 when the dimensions are very low. - Removed uninitialized xmm compilation warning in trsm small Change-Id: I23262fb82729af5b98ded5d36f5eed45d5255d5b --- frame/base/bli_rntm.c | 4 ++++ kernels/zen/3/bli_trsm_small.c | 3 +++ 2 files changed, 7 insertions(+) diff --git a/frame/base/bli_rntm.c b/frame/base/bli_rntm.c index 6a100bbe8e..dc0acf6bf9 100644 --- a/frame/base/bli_rntm.c +++ b/frame/base/bli_rntm.c @@ -574,6 +574,10 @@ void bli_nthreads_optimum( if(n < 15) n_threads_ideal = 1; else n_threads_ideal = 4; } + else if( ( m < 34) && (k < 68) && ( m < 34)) + { + n_threads_ideal = 1; + } else { if(n < 20) n_threads_ideal = 1; diff --git a/kernels/zen/3/bli_trsm_small.c b/kernels/zen/3/bli_trsm_small.c index c782a08a49..0fa8f66d5a 100644 --- a/kernels/zen/3/bli_trsm_small.c +++ b/kernels/zen/3/bli_trsm_small.c @@ -2847,6 +2847,7 @@ BLIS_INLINE err_t dtrsm_XAltB_ref #define BLIS_PRE_STRSM_SMALL_3N_2M(AlphaVal,b11,cs_b)\ ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); /*register to hold alpha*/\ \ + xmm5 = _mm_setzero_ps();\ xmm5 = _mm_loadl_pi(xmm5,(__m64*)(b11));\ ymm6 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ ymm3 = _mm256_fmsub_ps(ymm6, ymm15, ymm3);\ @@ -3009,6 +3010,7 @@ BLIS_INLINE err_t dtrsm_XAltB_ref #define BLIS_PRE_STRSM_SMALL_2N_2M(AlphaVal,b11,cs_b)\ ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); /*register to hold alpha*/\ \ + xmm5 = _mm_setzero_ps();\ xmm5 = _mm_loadl_pi(xmm5,(__m64*)(b11));\ ymm6 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ ymm3 = _mm256_fmsub_ps(ymm6, ymm15, ymm3);\ @@ -3116,6 +3118,7 @@ BLIS_INLINE err_t dtrsm_XAltB_ref #define BLIS_PRE_STRSM_SMALL_1N_2M(AlphaVal,b11,cs_b)\ ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); /*register to hold alpha*/\ \ + xmm5 = _mm_setzero_ps();\ xmm5 = _mm_loadl_pi(xmm5,(__m64*)(b11));\ ymm6 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ ymm3 = _mm256_fmsub_ps(ymm6, ymm15, ymm3); From f72758d80add2c14578f4018a81cba8225a8cf59 Mon Sep 17 00:00:00 2001 From: Harihara Sudhan S Date: Fri, 17 Dec 2021 14:04:13 +0530 Subject: [PATCH 056/243] Fixed DDOTXF Bug - Corrected xv and avec indexing in vector loop of bli_ddotxf_zen_int_2 Change-Id: I4c511236aad09541fe6b1295103a1a8b54ceec39 --- kernels/zen/1f/bli_dotxf_zen_int_8.c | 26 +++++++++++++------------- 1 file changed, 
13 insertions(+), 13 deletions(-) diff --git a/kernels/zen/1f/bli_dotxf_zen_int_8.c b/kernels/zen/1f/bli_dotxf_zen_int_8.c index e25910fb4e..ad27403bdc 100644 --- a/kernels/zen/1f/bli_dotxf_zen_int_8.c +++ b/kernels/zen/1f/bli_dotxf_zen_int_8.c @@ -1337,19 +1337,19 @@ void bli_ddotxf_zen_int_2 rhov[7].v = _mm256_fmadd_pd(avec[7].v, xv[3].v, rhov[7].v); // Load the input values. - xv[0].v = _mm256_loadu_pd(x0 + 0 * n_elem_per_reg); - xv[1].v = _mm256_loadu_pd(x0 + 1 * n_elem_per_reg); - xv[2].v = _mm256_loadu_pd(x0 + 2 * n_elem_per_reg); - xv[3].v = _mm256_loadu_pd(x0 + 3 * n_elem_per_reg); - - avec[0].v = _mm256_loadu_pd(av[0] + 0 * n_elem_per_reg); - avec[1].v = _mm256_loadu_pd(av[1] + 0 * n_elem_per_reg); - avec[2].v = _mm256_loadu_pd(av[0] + 1 * n_elem_per_reg); - avec[3].v = _mm256_loadu_pd(av[1] + 1 * n_elem_per_reg); - avec[4].v = _mm256_loadu_pd(av[0] + 2 * n_elem_per_reg); - avec[5].v = _mm256_loadu_pd(av[1] + 2 * n_elem_per_reg); - avec[6].v = _mm256_loadu_pd(av[0] + 3 * n_elem_per_reg); - avec[7].v = _mm256_loadu_pd(av[1] + 3 * n_elem_per_reg); + xv[0].v = _mm256_loadu_pd(x0 + 4 * n_elem_per_reg); + xv[1].v = _mm256_loadu_pd(x0 + 5 * n_elem_per_reg); + xv[2].v = _mm256_loadu_pd(x0 + 6 * n_elem_per_reg); + xv[3].v = _mm256_loadu_pd(x0 + 7 * n_elem_per_reg); + + avec[0].v = _mm256_loadu_pd(av[0] + 4 * n_elem_per_reg); + avec[1].v = _mm256_loadu_pd(av[1] + 4 * n_elem_per_reg); + avec[2].v = _mm256_loadu_pd(av[0] + 5 * n_elem_per_reg); + avec[3].v = _mm256_loadu_pd(av[1] + 5 * n_elem_per_reg); + avec[4].v = _mm256_loadu_pd(av[0] + 6 * n_elem_per_reg); + avec[5].v = _mm256_loadu_pd(av[1] + 6 * n_elem_per_reg); + avec[6].v = _mm256_loadu_pd(av[0] + 7 * n_elem_per_reg); + avec[7].v = _mm256_loadu_pd(av[1] + 7 * n_elem_per_reg); // perform: rho?v += a?v * x0v; rhov[0].v = _mm256_fmadd_pd(avec[0].v, xv[0].v, rhov[0].v); From 75d5f538d2836f168f0c81046566a33269d6703c Mon Sep 17 00:00:00 2001 From: Harihara Sudhan S Date: Mon, 20 Dec 2021 12:17:05 +0530 Subject: [PATCH 057/243] Improved AXPYV Kernel performance - Increased the unroll factor of the loop by 15 in SAXPYV - Increased the unroll factor of the loop by 12 in DAXPYV - The above changes were made for better register utilization Change-Id: I69ad1fec2fcf958dbd1bfd71378641274b43a6aa --- kernels/zen/1/bli_axpyv_zen_int10.c | 150 ++++++++++++++++++++++++++-- 1 file changed, 142 insertions(+), 8 deletions(-) diff --git a/kernels/zen/1/bli_axpyv_zen_int10.c b/kernels/zen/1/bli_axpyv_zen_int10.c index 6f953e6f4c..4ef6981cd7 100644 --- a/kernels/zen/1/bli_axpyv_zen_int10.c +++ b/kernels/zen/1/bli_axpyv_zen_int10.c @@ -75,9 +75,9 @@ void bli_saxpyv_zen_int10 float* restrict y0; __m256 alphav; - __m256 xv[10]; - __m256 yv[10]; - __m256 zv[10]; + __m256 xv[15]; + __m256 yv[15]; + __m256 zv[15]; // If the vector dimension is zero, or if alpha is zero, return early. if ( bli_zero_dim1( n ) || PASTEMAC(s,eq0)( *alpha ) ) @@ -95,7 +95,78 @@ void bli_saxpyv_zen_int10 // Broadcast the alpha scalar to all elements of a vector register. alphav = _mm256_broadcast_ss( alpha ); - for ( i = 0; (i + 79) < n; i += 80 ) + for (i = 0; (i + 119) < n; i += 120) + { + // 120 elements will be processed per loop; 15 FMAs will run per loop. 
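+		// For example, with n = 250 this loop covers 240 elements (two
+		// iterations of 120) and the remaining 10 fall through to the
+		// smaller-unroll loops below (illustrative numbers only).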
+ xv[0] = _mm256_loadu_ps(x0 + 0 * n_elem_per_reg); + xv[1] = _mm256_loadu_ps(x0 + 1 * n_elem_per_reg); + xv[2] = _mm256_loadu_ps(x0 + 2 * n_elem_per_reg); + xv[3] = _mm256_loadu_ps(x0 + 3 * n_elem_per_reg); + xv[4] = _mm256_loadu_ps(x0 + 4 * n_elem_per_reg); + xv[5] = _mm256_loadu_ps(x0 + 5 * n_elem_per_reg); + xv[6] = _mm256_loadu_ps(x0 + 6 * n_elem_per_reg); + xv[7] = _mm256_loadu_ps(x0 + 7 * n_elem_per_reg); + xv[8] = _mm256_loadu_ps(x0 + 8 * n_elem_per_reg); + xv[9] = _mm256_loadu_ps(x0 + 9 * n_elem_per_reg); + xv[10] = _mm256_loadu_ps(x0 + 10 * n_elem_per_reg); + xv[11] = _mm256_loadu_ps(x0 + 11 * n_elem_per_reg); + xv[12] = _mm256_loadu_ps(x0 + 12 * n_elem_per_reg); + xv[13] = _mm256_loadu_ps(x0 + 13 * n_elem_per_reg); + xv[14] = _mm256_loadu_ps(x0 + 14 * n_elem_per_reg); + + yv[0] = _mm256_loadu_ps(y0 + 0 * n_elem_per_reg); + yv[1] = _mm256_loadu_ps(y0 + 1 * n_elem_per_reg); + yv[2] = _mm256_loadu_ps(y0 + 2 * n_elem_per_reg); + yv[3] = _mm256_loadu_ps(y0 + 3 * n_elem_per_reg); + yv[4] = _mm256_loadu_ps(y0 + 4 * n_elem_per_reg); + yv[5] = _mm256_loadu_ps(y0 + 5 * n_elem_per_reg); + yv[6] = _mm256_loadu_ps(y0 + 6 * n_elem_per_reg); + yv[7] = _mm256_loadu_ps(y0 + 7 * n_elem_per_reg); + yv[8] = _mm256_loadu_ps(y0 + 8 * n_elem_per_reg); + yv[9] = _mm256_loadu_ps(y0 + 9 * n_elem_per_reg); + yv[10] = _mm256_loadu_ps(y0 + 10 * n_elem_per_reg); + yv[11] = _mm256_loadu_ps(y0 + 11 * n_elem_per_reg); + yv[12] = _mm256_loadu_ps(y0 + 12 * n_elem_per_reg); + yv[13] = _mm256_loadu_ps(y0 + 13 * n_elem_per_reg); + yv[14] = _mm256_loadu_ps(y0 + 14 * n_elem_per_reg); + + zv[0] = _mm256_fmadd_ps(xv[0], alphav, yv[0]); + zv[1] = _mm256_fmadd_ps(xv[1], alphav, yv[1]); + zv[2] = _mm256_fmadd_ps(xv[2], alphav, yv[2]); + zv[3] = _mm256_fmadd_ps(xv[3], alphav, yv[3]); + zv[4] = _mm256_fmadd_ps(xv[4], alphav, yv[4]); + zv[5] = _mm256_fmadd_ps(xv[5], alphav, yv[5]); + zv[6] = _mm256_fmadd_ps(xv[6], alphav, yv[6]); + zv[7] = _mm256_fmadd_ps(xv[7], alphav, yv[7]); + zv[8] = _mm256_fmadd_ps(xv[8], alphav, yv[8]); + zv[9] = _mm256_fmadd_ps(xv[9], alphav, yv[9]); + zv[10] = _mm256_fmadd_ps(xv[10], alphav, yv[10]); + zv[11] = _mm256_fmadd_ps(xv[11], alphav, yv[11]); + zv[12] = _mm256_fmadd_ps(xv[12], alphav, yv[12]); + zv[13] = _mm256_fmadd_ps(xv[13], alphav, yv[13]); + zv[14] = _mm256_fmadd_ps(xv[14], alphav, yv[14]); + + _mm256_storeu_ps((y0 + 0 * n_elem_per_reg), zv[0]); + _mm256_storeu_ps((y0 + 1 * n_elem_per_reg), zv[1]); + _mm256_storeu_ps((y0 + 2 * n_elem_per_reg), zv[2]); + _mm256_storeu_ps((y0 + 3 * n_elem_per_reg), zv[3]); + _mm256_storeu_ps((y0 + 4 * n_elem_per_reg), zv[4]); + _mm256_storeu_ps((y0 + 5 * n_elem_per_reg), zv[5]); + _mm256_storeu_ps((y0 + 6 * n_elem_per_reg), zv[6]); + _mm256_storeu_ps((y0 + 7 * n_elem_per_reg), zv[7]); + _mm256_storeu_ps((y0 + 8 * n_elem_per_reg), zv[8]); + _mm256_storeu_ps((y0 + 9 * n_elem_per_reg), zv[9]); + _mm256_storeu_ps((y0 + 10 * n_elem_per_reg), zv[10]); + _mm256_storeu_ps((y0 + 11 * n_elem_per_reg), zv[11]); + _mm256_storeu_ps((y0 + 12 * n_elem_per_reg), zv[12]); + _mm256_storeu_ps((y0 + 13 * n_elem_per_reg), zv[13]); + _mm256_storeu_ps((y0 + 14 * n_elem_per_reg), zv[14]); + + x0 += 15 * n_elem_per_reg; + y0 += 15 * n_elem_per_reg; + } + + for (; (i + 79) < n; i += 80 ) { // 80 elements will be processed per loop; 10 FMAs will run per loop. 
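/* (Editor's note, not part of the original patch.) The 120-element pass above
 * handles the bulk of the vector; this pre-existing 80-element loop and the
 * smaller-unroll loops after it pick up the remainder, which is why their
 * headers no longer re-initialize i. The daxpyv kernel further below is
 * restructured the same way: a new 52-element pass (13 four-wide FMAs per
 * iteration) followed by the original 40-element and smaller loops.
 */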
xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); @@ -288,9 +359,9 @@ void bli_daxpyv_zen_int10 double* restrict y0 = y; __m256d alphav; - __m256d xv[10]; - __m256d yv[10]; - __m256d zv[10]; + __m256d xv[13]; + __m256d yv[13]; + __m256d zv[13]; // If the vector dimension is zero, or if alpha is zero, return early. if ( bli_zero_dim1( n ) || PASTEMAC(d,eq0)( *alpha ) ) @@ -308,7 +379,70 @@ void bli_daxpyv_zen_int10 // Broadcast the alpha scalar to all elements of a vector register. alphav = _mm256_broadcast_sd( alpha ); - for ( i = 0; (i + 39) < n; i += 40 ) + for (i = 0; (i + 51) < n; i += 52) + { + // 52 elements will be processed per loop; 13 FMAs will run per loop. + xv[0] = _mm256_loadu_pd(x0 + 0 * n_elem_per_reg); + xv[1] = _mm256_loadu_pd(x0 + 1 * n_elem_per_reg); + xv[2] = _mm256_loadu_pd(x0 + 2 * n_elem_per_reg); + xv[3] = _mm256_loadu_pd(x0 + 3 * n_elem_per_reg); + xv[4] = _mm256_loadu_pd(x0 + 4 * n_elem_per_reg); + xv[5] = _mm256_loadu_pd(x0 + 5 * n_elem_per_reg); + xv[6] = _mm256_loadu_pd(x0 + 6 * n_elem_per_reg); + xv[7] = _mm256_loadu_pd(x0 + 7 * n_elem_per_reg); + xv[8] = _mm256_loadu_pd(x0 + 8 * n_elem_per_reg); + xv[9] = _mm256_loadu_pd(x0 + 9 * n_elem_per_reg); + xv[10] = _mm256_loadu_pd(x0 + 10 * n_elem_per_reg); + xv[11] = _mm256_loadu_pd(x0 + 11 * n_elem_per_reg); + xv[12] = _mm256_loadu_pd(x0 + 12 * n_elem_per_reg); + + yv[0] = _mm256_loadu_pd(y0 + 0 * n_elem_per_reg); + yv[1] = _mm256_loadu_pd(y0 + 1 * n_elem_per_reg); + yv[2] = _mm256_loadu_pd(y0 + 2 * n_elem_per_reg); + yv[3] = _mm256_loadu_pd(y0 + 3 * n_elem_per_reg); + yv[4] = _mm256_loadu_pd(y0 + 4 * n_elem_per_reg); + yv[5] = _mm256_loadu_pd(y0 + 5 * n_elem_per_reg); + yv[6] = _mm256_loadu_pd(y0 + 6 * n_elem_per_reg); + yv[7] = _mm256_loadu_pd(y0 + 7 * n_elem_per_reg); + yv[8] = _mm256_loadu_pd(y0 + 8 * n_elem_per_reg); + yv[9] = _mm256_loadu_pd(y0 + 9 * n_elem_per_reg); + yv[10] = _mm256_loadu_pd(y0 + 10 * n_elem_per_reg); + yv[11] = _mm256_loadu_pd(y0 + 11 * n_elem_per_reg); + yv[12] = _mm256_loadu_pd(y0 + 12 * n_elem_per_reg); + + zv[0] = _mm256_fmadd_pd(xv[0], alphav, yv[0]); + zv[1] = _mm256_fmadd_pd(xv[1], alphav, yv[1]); + zv[2] = _mm256_fmadd_pd(xv[2], alphav, yv[2]); + zv[3] = _mm256_fmadd_pd(xv[3], alphav, yv[3]); + zv[4] = _mm256_fmadd_pd(xv[4], alphav, yv[4]); + zv[5] = _mm256_fmadd_pd(xv[5], alphav, yv[5]); + zv[6] = _mm256_fmadd_pd(xv[6], alphav, yv[6]); + zv[7] = _mm256_fmadd_pd(xv[7], alphav, yv[7]); + zv[8] = _mm256_fmadd_pd(xv[8], alphav, yv[8]); + zv[9] = _mm256_fmadd_pd(xv[9], alphav, yv[9]); + zv[10] = _mm256_fmadd_pd(xv[10], alphav, yv[10]); + zv[11] = _mm256_fmadd_pd(xv[11], alphav, yv[11]); + zv[12] = _mm256_fmadd_pd(xv[12], alphav, yv[12]); + + _mm256_storeu_pd((y0 + 0 * n_elem_per_reg), zv[0]); + _mm256_storeu_pd((y0 + 1 * n_elem_per_reg), zv[1]); + _mm256_storeu_pd((y0 + 2 * n_elem_per_reg), zv[2]); + _mm256_storeu_pd((y0 + 3 * n_elem_per_reg), zv[3]); + _mm256_storeu_pd((y0 + 4 * n_elem_per_reg), zv[4]); + _mm256_storeu_pd((y0 + 5 * n_elem_per_reg), zv[5]); + _mm256_storeu_pd((y0 + 6 * n_elem_per_reg), zv[6]); + _mm256_storeu_pd((y0 + 7 * n_elem_per_reg), zv[7]); + _mm256_storeu_pd((y0 + 8 * n_elem_per_reg), zv[8]); + _mm256_storeu_pd((y0 + 9 * n_elem_per_reg), zv[9]); + _mm256_storeu_pd((y0 + 10 * n_elem_per_reg), zv[10]); + _mm256_storeu_pd((y0 + 11 * n_elem_per_reg), zv[11]); + _mm256_storeu_pd((y0 + 12 * n_elem_per_reg), zv[12]); + + x0 += 13 * n_elem_per_reg; + y0 += 13 * n_elem_per_reg; + } + + for ( ; (i + 39) < n; i += 40 ) { // 40 elements will be processed per loop; 10 FMAs 
will run per loop. xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); From b3553c08fa8a3f5007df0c5428c3cc6f2eb20ba4 Mon Sep 17 00:00:00 2001 From: Chandrashekara K R Date: Wed, 22 Dec 2021 14:47:15 +0530 Subject: [PATCH 058/243] AOCL-Windows: Updating the blis windows build system. 1. Removed the libomp.lib hardcoded from cmake scripts and made it user configurable. By default libomp.lib is used as an omp library. 2. Added the STATIC_LIBRARY_OPTIONS property in set_target_properties cmake command to link omp library to build static-mt blis library. 3. Updated the blis_ref_kernel_mirror.py to give support for zen4 architecture. AMD-Internal: CPUPL-1630 Change-Id: I54b04cde2fa6a1ddc4b4303f1da808c1efe0484a --- CMakeLists.txt | 13 +++-- build/blis_ref_kernel_mirror.py | 25 +++++---- test/CMakeLists.txt | 96 ++++++++++++++++----------------- testsuite/CMakeLists.txt | 4 +- 4 files changed, 72 insertions(+), 66 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index aebb509d73..3affe7b40c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -10,8 +10,7 @@ set(CMAKE_RUNTIME_OUTPUT_DIRECTORY "${CMAKE_SOURCE_DIR}/bin") SET(AOCL_BLIS_FAMILY "zen" CACHE STRING "AOCL BLIS family name") -SET(OPENMP_PATH "C:\\Program Files\\LLVM\\lib" CACHE STRING "openmp library -path") +SET(OMP_LIB "C:\\Program Files\\LLVM\\lib\\libomp.lib" CACHE STRING "openmp library path") set(TARGET_ARCH ${AOCL_BLIS_FAMILY}) set(AOCL_BLIS_ZEN TRUE) set (PYTHON_EXE "python") @@ -542,15 +541,14 @@ file (STRINGS "version" BLIS_VERSION) set(BLIS_VERSION_STRING ${BLIS_VERSION}) add_definitions(-DBLIS_VERSION_STRING="AOCL BLIS ${BLIS_VERSION_STRING}") -message( STATUS "OPENMP PATH:" ${OPENMP_PATH}) -link_directories("${OPENMP_PATH}") +message( STATUS "OPENMP Library:" ${OMP_LIB}) if(BUILD_SHARED_LIBS) add_library("${PROJECT_NAME}" SHARED ${CMAKE_SOURCE_DIR}/bli_config.h ${CMAKE_SOURCE_DIR}/include/${TARGET_ARCH}/blis.h ${headers}) if(ENABLE_OPENMP) - target_link_libraries("${PROJECT_NAME}" PUBLIC "${OPENMP_PATH}/libomp.lib") + target_link_libraries("${PROJECT_NAME}" PUBLIC "${OMP_LIB}") endif() target_compile_definitions("${PROJECT_NAME}" PUBLIC -DBLIS_IS_BUILDING_LIBRARY) set_target_properties("${PROJECT_NAME}" PROPERTIES LINKER_LANGUAGE C OUTPUT_NAME "${LIB_NAME}") @@ -560,9 +558,10 @@ if(NOT BUILD_SHARED_LIBS) ${CMAKE_SOURCE_DIR}/include/${TARGET_ARCH}/blis.h ${headers}) if(ENABLE_OPENMP) - target_link_libraries("${PROJECT_NAME}" PUBLIC "${OPENMP_PATH}/libomp.lib") + set_target_properties("${PROJECT_NAME}" PROPERTIES LINKER_LANGUAGE C OUTPUT_NAME "${LIB_NAME}" STATIC_LIBRARY_OPTIONS "${OMP_LIB}") + else() + set_target_properties("${PROJECT_NAME}" PROPERTIES LINKER_LANGUAGE C OUTPUT_NAME "${LIB_NAME}") endif() - set_target_properties("${PROJECT_NAME}" PROPERTIES LINKER_LANGUAGE C OUTPUT_NAME "${LIB_NAME}") endif() link_directories(${CMAKE_LIBRARY_OUTPUT_DIRECTORY}) diff --git a/build/blis_ref_kernel_mirror.py b/build/blis_ref_kernel_mirror.py index 0dec5d66fa..834de1cee9 100644 --- a/build/blis_ref_kernel_mirror.py +++ b/build/blis_ref_kernel_mirror.py @@ -69,11 +69,11 @@ def remove_lines_in_file(filename): file_content = fd.read() file_content = file_content.replace( 'if(${TARGET_ARCH} STREQUAL amdzen)\n' - 'add_subdirectory(${CMAKE_BINARY_DIR}/ref_kernels/generic' + 'add_subdirectory(${CMAKE_BINARY_DIR}/ref_kernels/generic ' '${CMAKE_BINARY_DIR}/ref_kernels/generic)\n' - 'add_subdirectory(${CMAKE_BINARY_DIR}/ref_kernels/zen' + 'add_subdirectory(${CMAKE_BINARY_DIR}/ref_kernels/zen ' '${CMAKE_BINARY_DIR}/ref_kernels/zen)\n' - 
'add_subdirectory(${CMAKE_BINARY_DIR}/ref_kernels/zen2' + 'add_subdirectory(${CMAKE_BINARY_DIR}/ref_kernels/zen2 ' '${CMAKE_BINARY_DIR}/ref_kernels/zen2)\n' 'add_subdirectory(${CMAKE_BINARY_DIR}/ref_kernels/zen3 ' '${CMAKE_BINARY_DIR}/ref_kernels/zen3)\n' @@ -115,6 +115,7 @@ def add_macro_to_cfiles(cfiles, macro): create_folder(os.path.join(dest_path, 'zen')) create_folder(os.path.join(dest_path, 'zen2')) create_folder(os.path.join(dest_path, 'zen3')) + create_folder(os.path.join(dest_path, 'zen4')) create_folder(os.path.join(dest_path, 'generic')) execute_and_check('XCOPY {} {} /E'.format( temp, os.path.join(dest_path, 'zen'))) @@ -122,6 +123,8 @@ def add_macro_to_cfiles(cfiles, macro): temp, os.path.join(dest_path, 'zen2'))) execute_and_check('XCOPY {} {} /E'.format( temp, os.path.join(dest_path, 'zen3'))) + execute_and_check('XCOPY {} {} /E'.format( + temp, os.path.join(dest_path, 'zen4'))) execute_and_check('XCOPY {} {} /E'.format( temp, os.path.join(dest_path, 'generic'))) remove_folder(temp) @@ -133,6 +136,8 @@ def add_macro_to_cfiles(cfiles, macro): dest_path, 'zen2', 'CMakeLists.txt')) remove_lines_in_file(os.path.join( dest_path, 'zen3', 'CMakeLists.txt')) + remove_lines_in_file(os.path.join( + dest_path, 'zen4', 'CMakeLists.txt')) cfiles_in_generic = execute_and_check('cd {} && dir / s / b / o: gn *.c' .format(os.path.join(dest_path, 'generic'))) @@ -140,20 +145,22 @@ def add_macro_to_cfiles(cfiles, macro): add_macro_to_cfiles(cfiles_in_generic, '\n#define BLIS_CNAME_INFIX _generic\n') cfiles_in_zen = execute_and_check('cd {} && dir / s / b / o: gn *.c' - .format(os.path.join(dest_path, - 'zen'))) + .format(os.path.join(dest_path, 'zen'))) cfiles_in_zen = cfiles_in_zen.split('\r\n') add_macro_to_cfiles(cfiles_in_zen, '\n#define BLIS_CNAME_INFIX _zen\n') cfiles_in_zen2 = execute_and_check('cd {} && dir / s / b / o: gn *.c' - .format(os.path.join(dest_path, - 'zen2'))) + .format(os.path.join(dest_path, 'zen2'))) cfiles_in_zen2 = cfiles_in_zen2.split('\r\n') add_macro_to_cfiles(cfiles_in_zen2, '\n#define BLIS_CNAME_INFIX _zen2\n') cfiles_in_zen3 = execute_and_check('cd {} && dir / s / b / o: gn *.c' - .format(os.path.join(dest_path, - 'zen3'))) + .format(os.path.join(dest_path, 'zen3'))) cfiles_in_zen3 = cfiles_in_zen3.split('\r\n') add_macro_to_cfiles(cfiles_in_zen3, '\n#define BLIS_CNAME_INFIX _zen3\n') + cfiles_in_zen4 = execute_and_check('cd {} && dir / s / b / o: gn *.c' + .format(os.path.join(dest_path, 'zen4'))) + cfiles_in_zen4 = cfiles_in_zen4.split('\r\n') + add_macro_to_cfiles(cfiles_in_zen4, + '\n#define BLIS_CNAME_INFIX _zen4\n') diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 3b0315c9ae..fe8f7bac98 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -4,169 +4,169 @@ add_definitions(-DBLAS="AOCL") add_executable(TestAminv test_aminv.c) target_link_libraries(TestAminv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestAminv "${OPENMP_PATH}/libomp.lib") +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestAminv "${OMP_LIB}") endif() target_link_libraries(TestAminv optimized "${LIB_NAME}.lib") add_executable(TestAxpyv test_axpyv.c) target_link_libraries(TestAxpyv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestAxpyv "${OPENMP_PATH}/libomp.lib") +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestAxpyv "${OMP_LIB}") endif() target_link_libraries(TestAxpyv optimized "${LIB_NAME}.lib") add_executable(TestAxpbyv test_axpbyv.c) target_link_libraries(TestAxpbyv debug 
"${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestAxpbyv "${OPENMP_PATH}/libomp.lib") +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestAxpbyv "${OMP_LIB}") endif() target_link_libraries(TestAxpbyv optimized "${LIB_NAME}.lib") add_executable(TestCopyv test_copyv.c) target_link_libraries(TestCopyv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestCopyv "${OPENMP_PATH}/libomp.lib") +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestCopyv "${OMP_LIB}") endif() target_link_libraries(TestCopyv optimized "${LIB_NAME}.lib") add_executable(TestCabs1 test_cabs1.c) target_link_libraries(TestCabs1 debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestCabs1 "${OPENMP_PATH}/libomp.lib") +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestCabs1 "${OMP_LIB}") endif() target_link_libraries(TestCabs1 optimized "${LIB_NAME}.lib") add_executable(TestDotv test_dotv.c) target_link_libraries(TestDotv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestDotv "${OPENMP_PATH}/libomp.lib") +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestDotv "${OMP_LIB}") endif() target_link_libraries(TestDotv optimized "${LIB_NAME}.lib") add_executable(TestGemm test_gemm.c) target_link_libraries(TestGemm debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestGemm "${OPENMP_PATH}/libomp.lib") +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestGemm "${OMP_LIB}") endif() target_link_libraries(TestGemm optimized "${LIB_NAME}.lib") add_executable(TestGemmBatch test_gemm_batch.c) target_link_libraries(TestGemmBatch debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestGemmBatch "${OPENMP_PATH}/libomp.lib") +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestGemmBatch "${OMP_LIB}") endif() target_link_libraries(TestGemmBatch optimized "${LIB_NAME}.lib") add_executable(TestGemm3m test_gemm3m.c) target_link_libraries(TestGemm3m debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestGemm3m "${OPENMP_PATH}/libomp.lib") +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestGemm3m "${OMP_LIB}") endif() target_link_libraries(TestGemm3m optimized "${LIB_NAME}.lib") add_executable(TestGemmt test_gemmt.c) target_link_libraries(TestGemmt debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestGemmt "${OPENMP_PATH}/libomp.lib") +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestGemmt "${OMP_LIB}") endif() target_link_libraries(TestGemmt optimized "${LIB_NAME}.lib") add_executable(TestGemv test_gemv.c) target_link_libraries(TestGemv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestGemv "${OPENMP_PATH}/libomp.lib") +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestGemv "${OMP_LIB}") endif() target_link_libraries(TestGemv optimized "${LIB_NAME}.lib") add_executable(TestGer test_ger.c) target_link_libraries(TestGer debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestGer "${OPENMP_PATH}/libomp.lib") +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestGer "${OMP_LIB}") endif() target_link_libraries(TestGer optimized "${LIB_NAME}.lib") add_executable(TestHemm test_hemm.c) target_link_libraries(TestHemm debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestHemm "${OPENMP_PATH}/libomp.lib") +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestHemm "${OMP_LIB}") 
endif() target_link_libraries(TestHemm optimized "${LIB_NAME}.lib") add_executable(TestHemv test_hemv.c) target_link_libraries(TestHemv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestHemv "${OPENMP_PATH}/libomp.lib") +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestHemv "${OMP_LIB}") endif() target_link_libraries(TestHemv optimized "${LIB_NAME}.lib") add_executable(TestHer test_her.c) target_link_libraries(TestHer debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestHer "${OPENMP_PATH}/libomp.lib") +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestHer "${OMP_LIB}") endif() target_link_libraries(TestHer optimized "${LIB_NAME}.lib") add_executable(TestHer2 test_her2.c) target_link_libraries(TestHer2 debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestHer2 "${OPENMP_PATH}/libomp.lib") +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestHer2 "${OMP_LIB}") endif() target_link_libraries(TestHer2 optimized "${LIB_NAME}.lib") add_executable(TestHer2k test_her2k.c) target_link_libraries(TestHer2k debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestHer2k "${OPENMP_PATH}/libomp.lib") +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestHer2k "${OMP_LIB}") endif() target_link_libraries(TestHer2k optimized "${LIB_NAME}.lib") add_executable(TestHerk test_herk.c) target_link_libraries(TestHerk debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestHerk "${OPENMP_PATH}/libomp.lib") +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestHerk "${OMP_LIB}") endif() target_link_libraries(TestHerk optimized "${LIB_NAME}.lib") add_executable(TestScalv test_scalv.c) target_link_libraries(TestScalv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestScalv "${OPENMP_PATH}/libomp.lib") +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestScalv "${OMP_LIB}") endif() target_link_libraries(TestScalv optimized "${LIB_NAME}.lib") add_executable(TestSwapv test_swapv.c) target_link_libraries(TestSwapv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestSwapv "${OPENMP_PATH}/libomp.lib") +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestSwapv "${OMP_LIB}") endif() target_link_libraries(TestSwapv optimized "${LIB_NAME}.lib") add_executable(TestTrmm test_trmm.c) target_link_libraries(TestTrmm debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestTrmm "${OPENMP_PATH}/libomp.lib") +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestTrmm "${OMP_LIB}") endif() target_link_libraries(TestTrmm optimized "${LIB_NAME}.lib") add_executable(TestTrmv test_trmv.c) target_link_libraries(TestTrmv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestTrmv "${OPENMP_PATH}/libomp.lib") +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestTrmv "${OMP_LIB}") endif() target_link_libraries(TestTrmv optimized "${LIB_NAME}.lib") add_executable(TestTrsm test_trsm.c) target_link_libraries(TestTrsm debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestTrsm "${OPENMP_PATH}/libomp.lib") +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestTrsm "${OMP_LIB}") endif() target_link_libraries(TestTrsm optimized "${LIB_NAME}.lib") add_executable(TestTrsv test_trsv.c) target_link_libraries(TestTrsv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestTrsv "${OPENMP_PATH}/libomp.lib") 
+if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestTrsv "${OMP_LIB}") endif() target_link_libraries(TestTrsv optimized "${LIB_NAME}.lib") diff --git a/testsuite/CMakeLists.txt b/testsuite/CMakeLists.txt index f03d094782..613f9e3861 100644 --- a/testsuite/CMakeLists.txt +++ b/testsuite/CMakeLists.txt @@ -7,8 +7,8 @@ add_executable(test_libblis "") add_subdirectory(src) target_link_libraries(test_libblis debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(test_libblis "${OPENMP_PATH}/libomp.lib") +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(test_libblis "${OMP_LIB}") endif() target_link_libraries(test_libblis optimized "${LIB_NAME}.lib") From b095f1f3a2ff87a9794b64bdf78f0878bdb57c28 Mon Sep 17 00:00:00 2001 From: Harihara Sudhan S Date: Mon, 15 Nov 2021 23:28:33 +0530 Subject: [PATCH 059/243] Improved SCALV kernel performance. - Unrolled the loop by a greater factor. Incorporated switch case to decide unrolling factor according to the input size. - Removed unused structs. AMD-Internal: [CPUPL-1974] Change-Id: Iee9d7defcc8c582ca0420f84c4fb2c202dabe3e7 --- kernels/zen/1/bli_scalv_zen_int10.c | 666 +++++++++++++++++----------- 1 file changed, 404 insertions(+), 262 deletions(-) diff --git a/kernels/zen/1/bli_scalv_zen_int10.c b/kernels/zen/1/bli_scalv_zen_int10.c index 6c7f52e161..de9d8339d3 100644 --- a/kernels/zen/1/bli_scalv_zen_int10.c +++ b/kernels/zen/1/bli_scalv_zen_int10.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2017 - 2020, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2017 - 2021, Advanced Micro Devices, Inc. All rights reserved. Copyright (C) 2018, The University of Texas at Austin Redistribution and use in source and binary forms, with or without @@ -36,23 +36,6 @@ #include "immintrin.h" #include "blis.h" - -/* Union data structure to access AVX registers - One 256-bit AVX register holds 8 SP elements. */ -typedef union -{ - __m256 v; - float f[8] __attribute__((aligned(64))); -} v8sf_t; - -/* Union data structure to access AVX registers -* One 256-bit AVX register holds 4 DP elements. */ -typedef union -{ - __m256d v; - double d[4] __attribute__((aligned(64))); -} v4df_t; - // ----------------------------------------------------------------------------- void bli_sscalv_zen_int10 @@ -66,13 +49,13 @@ void bli_sscalv_zen_int10 { const dim_t n_elem_per_reg = 8; - dim_t i; + dim_t i = 0; float* restrict x0; __m256 alphav; - __m256 xv[10]; - __m256 zv[10]; + __m256 xv[16]; + __m256 zv[16]; // If the vector dimension is zero, or if alpha is unit, return early. if ( bli_zero_dim1( n ) || PASTEMAC(s,eq1)( *alpha ) ) return; @@ -111,140 +94,218 @@ void bli_sscalv_zen_int10 { // Broadcast the alpha scalar to all elements of a vector register. alphav = _mm256_broadcast_ss( alpha ); + dim_t option; - for ( i = 0; (i + 79) < n; i += 80 ) + // Unroll and the loop used is picked based on the input size. + if( n < 300) { - // Load the input values. 
- xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); - xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); - xv[2] = _mm256_loadu_ps( x0 + 2*n_elem_per_reg ); - xv[3] = _mm256_loadu_ps( x0 + 3*n_elem_per_reg ); - xv[4] = _mm256_loadu_ps( x0 + 4*n_elem_per_reg ); - xv[5] = _mm256_loadu_ps( x0 + 5*n_elem_per_reg ); - xv[6] = _mm256_loadu_ps( x0 + 6*n_elem_per_reg ); - xv[7] = _mm256_loadu_ps( x0 + 7*n_elem_per_reg ); - xv[8] = _mm256_loadu_ps( x0 + 8*n_elem_per_reg ); - xv[9] = _mm256_loadu_ps( x0 + 9*n_elem_per_reg ); - - // perform : x := alpha * x; - zv[0] = _mm256_mul_ps( alphav, xv[0] ); - zv[1] = _mm256_mul_ps( alphav, xv[1] ); - zv[2] = _mm256_mul_ps( alphav, xv[2] ); - zv[3] = _mm256_mul_ps( alphav, xv[3] ); - zv[4] = _mm256_mul_ps( alphav, xv[4] ); - zv[5] = _mm256_mul_ps( alphav, xv[5] ); - zv[6] = _mm256_mul_ps( alphav, xv[6] ); - zv[7] = _mm256_mul_ps( alphav, xv[7] ); - zv[8] = _mm256_mul_ps( alphav, xv[8] ); - zv[9] = _mm256_mul_ps( alphav, xv[9] ); - - // Store the output. - _mm256_storeu_ps( (x0 + 0*n_elem_per_reg), zv[0] ); - _mm256_storeu_ps( (x0 + 1*n_elem_per_reg), zv[1] ); - _mm256_storeu_ps( (x0 + 2*n_elem_per_reg), zv[2] ); - _mm256_storeu_ps( (x0 + 3*n_elem_per_reg), zv[3] ); - _mm256_storeu_ps( (x0 + 4*n_elem_per_reg), zv[4] ); - _mm256_storeu_ps( (x0 + 5*n_elem_per_reg), zv[5] ); - _mm256_storeu_ps( (x0 + 6*n_elem_per_reg), zv[6] ); - _mm256_storeu_ps( (x0 + 7*n_elem_per_reg), zv[7] ); - _mm256_storeu_ps( (x0 + 8*n_elem_per_reg), zv[8] ); - _mm256_storeu_ps( (x0 + 9*n_elem_per_reg), zv[9] ); - - x0 += 10*n_elem_per_reg; + option = 2; } - - for ( ; (i + 39) < n; i += 40 ) - { - // Load the input values. - xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); - xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); - xv[2] = _mm256_loadu_ps( x0 + 2*n_elem_per_reg ); - xv[3] = _mm256_loadu_ps( x0 + 3*n_elem_per_reg ); - xv[4] = _mm256_loadu_ps( x0 + 4*n_elem_per_reg ); - - // perform : x := alpha * x; - zv[0] = _mm256_mul_ps( alphav, xv[0] ); - zv[1] = _mm256_mul_ps( alphav, xv[1] ); - zv[2] = _mm256_mul_ps( alphav, xv[2] ); - zv[3] = _mm256_mul_ps( alphav, xv[3] ); - zv[4] = _mm256_mul_ps( alphav, xv[4] ); - - // Store the output. - _mm256_storeu_ps( (x0 + 0*n_elem_per_reg), zv[0] ); - _mm256_storeu_ps( (x0 + 1*n_elem_per_reg), zv[1] ); - _mm256_storeu_ps( (x0 + 2*n_elem_per_reg), zv[2] ); - _mm256_storeu_ps( (x0 + 3*n_elem_per_reg), zv[3] ); - _mm256_storeu_ps( (x0 + 4*n_elem_per_reg), zv[4] ); - - x0 += 5*n_elem_per_reg; - } - - for ( ; (i + 31) < n; i += 32 ) + else if( n < 500) { - // Load the input values. - xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); - xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); - xv[2] = _mm256_loadu_ps( x0 + 2*n_elem_per_reg ); - xv[3] = _mm256_loadu_ps( x0 + 3*n_elem_per_reg ); - - // perform : x := alpha * x; - zv[0] = _mm256_mul_ps( alphav, xv[0] ); - zv[1] = _mm256_mul_ps( alphav, xv[1] ); - zv[2] = _mm256_mul_ps( alphav, xv[2] ); - zv[3] = _mm256_mul_ps( alphav, xv[3] ); - - // Store the output. - _mm256_storeu_ps( (x0 + 0*n_elem_per_reg), zv[0] ); - _mm256_storeu_ps( (x0 + 1*n_elem_per_reg), zv[1] ); - _mm256_storeu_ps( (x0 + 2*n_elem_per_reg), zv[2] ); - _mm256_storeu_ps( (x0 + 3*n_elem_per_reg), zv[3] ); - - x0 += 4*n_elem_per_reg; + option = 1; } - - for ( ; (i + 15) < n; i += 16 ) + else { - // Load the input values. 
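/* (Editor's note, not part of the original patch.) The rewritten sscalv first
 * maps the problem size onto an unroll tier: n < 300 selects option 2,
 * n < 500 selects option 1, and larger n selects option 0. The switch that
 * follows deliberately has no break statements, so the cases fall through:
 *
 *     switch ( option )
 *     {
 *         case 0:  128-element loop                              (falls through)
 *         case 1:  96-element loop                               (falls through)
 *         case 2:  48-, 24- and 8-element loops, then scalar tail
 *     }
 *
 * Each tier therefore reuses the narrower loops to finish whatever remainder
 * its own widest loop leaves behind.
 */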
- xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); - xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); - - // perform : x := alpha * x; - zv[0] = _mm256_mul_ps( alphav, xv[0] ); - zv[1] = _mm256_mul_ps( alphav, xv[1] ); - - // Store the output. - _mm256_storeu_ps( (x0 + 0*n_elem_per_reg), zv[0] ); - _mm256_storeu_ps( (x0 + 1*n_elem_per_reg), zv[1] ); - - x0 += 2*n_elem_per_reg; + option = 0; } - for ( ; (i + 7) < n; i += 8 ) + switch(option) { - // Load the input values. - xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); - - // perform : x := alpha * x; - zv[0] = _mm256_mul_ps( alphav, xv[0] ); - - // Store the output. - _mm256_storeu_ps( (x0 + 0*n_elem_per_reg), zv[0] ); - - x0 += 1*n_elem_per_reg; - } - - for ( ; (i + 0) < n; i += 1 ) - { - *x0 *= *alpha; - - x0 += 1; + case 0: + + for ( ; (i + 127) < n; i += 128 ) + { + //Load the input values + xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); + xv[2] = _mm256_loadu_ps( x0 + 2*n_elem_per_reg ); + xv[3] = _mm256_loadu_ps( x0 + 3*n_elem_per_reg ); + + // Perform : x := alpha * x; + zv[0] = _mm256_mul_ps( alphav, xv[0] ); + zv[1] = _mm256_mul_ps( alphav, xv[1] ); + zv[2] = _mm256_mul_ps( alphav, xv[2] ); + zv[3] = _mm256_mul_ps( alphav, xv[3] ); + + // Store the result + _mm256_storeu_ps( (x0 + 0*n_elem_per_reg), zv[0] ); + _mm256_storeu_ps( (x0 + 1*n_elem_per_reg), zv[1] ); + _mm256_storeu_ps( (x0 + 2*n_elem_per_reg), zv[2] ); + _mm256_storeu_ps( (x0 + 3*n_elem_per_reg), zv[3] ); + + xv[4] = _mm256_loadu_ps( x0 + 4*n_elem_per_reg ); + xv[5] = _mm256_loadu_ps( x0 + 5*n_elem_per_reg ); + xv[6] = _mm256_loadu_ps( x0 + 6*n_elem_per_reg ); + xv[7] = _mm256_loadu_ps( x0 + 7*n_elem_per_reg ); + + zv[4] = _mm256_mul_ps( alphav, xv[4] ); + zv[5] = _mm256_mul_ps( alphav, xv[5] ); + zv[6] = _mm256_mul_ps( alphav, xv[6] ); + zv[7] = _mm256_mul_ps( alphav, xv[7] ); + + _mm256_storeu_ps( (x0 + 4*n_elem_per_reg), zv[4] ); + _mm256_storeu_ps( (x0 + 5*n_elem_per_reg), zv[5] ); + _mm256_storeu_ps( (x0 + 6*n_elem_per_reg), zv[6] ); + _mm256_storeu_ps( (x0 + 7*n_elem_per_reg), zv[7] ); + + xv[8] = _mm256_loadu_ps( x0 + 8*n_elem_per_reg ); + xv[9] = _mm256_loadu_ps( x0 + 9*n_elem_per_reg ); + xv[10] = _mm256_loadu_ps( x0 + 10*n_elem_per_reg ); + xv[11] = _mm256_loadu_ps( x0 + 11*n_elem_per_reg ); + + zv[8] = _mm256_mul_ps( alphav, xv[8] ); + zv[9] = _mm256_mul_ps( alphav, xv[9] ); + zv[10] = _mm256_mul_ps( alphav, xv[10] ); + zv[11] = _mm256_mul_ps( alphav, xv[11] ); + + _mm256_storeu_ps( (x0 + 8*n_elem_per_reg), zv[8] ); + _mm256_storeu_ps( (x0 + 9*n_elem_per_reg), zv[9] ); + _mm256_storeu_ps( (x0 + 10*n_elem_per_reg), zv[10] ); + _mm256_storeu_ps( (x0 + 11*n_elem_per_reg), zv[11] ); + + xv[12] = _mm256_loadu_ps( x0 + 12*n_elem_per_reg ); + xv[13] = _mm256_loadu_ps( x0 + 13*n_elem_per_reg ); + xv[14] = _mm256_loadu_ps( x0 + 14*n_elem_per_reg ); + xv[15] = _mm256_loadu_ps( x0 + 15*n_elem_per_reg ); + + zv[12] = _mm256_mul_ps( alphav, xv[12] ); + zv[13] = _mm256_mul_ps( alphav, xv[13] ); + zv[14] = _mm256_mul_ps( alphav, xv[14] ); + zv[15] = _mm256_mul_ps( alphav, xv[15] ); + + _mm256_storeu_ps( (x0 + 12*n_elem_per_reg), zv[12] ); + _mm256_storeu_ps( (x0 + 13*n_elem_per_reg), zv[13] ); + _mm256_storeu_ps( (x0 + 14*n_elem_per_reg), zv[14] ); + _mm256_storeu_ps( (x0 + 15*n_elem_per_reg), zv[15] ); + + x0 += 16*n_elem_per_reg; + } + + case 1 : + + for ( ; (i + 95) < n; i += 96 ) + { + xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); + xv[2] = 
_mm256_loadu_ps( x0 + 2*n_elem_per_reg ); + xv[3] = _mm256_loadu_ps( x0 + 3*n_elem_per_reg ); + + zv[0] = _mm256_mul_ps( alphav, xv[0] ); + zv[1] = _mm256_mul_ps( alphav, xv[1] ); + zv[2] = _mm256_mul_ps( alphav, xv[2] ); + zv[3] = _mm256_mul_ps( alphav, xv[3] ); + + _mm256_storeu_ps( (x0 + 0*n_elem_per_reg), zv[0] ); + _mm256_storeu_ps( (x0 + 1*n_elem_per_reg), zv[1] ); + _mm256_storeu_ps( (x0 + 2*n_elem_per_reg), zv[2] ); + _mm256_storeu_ps( (x0 + 3*n_elem_per_reg), zv[3] ); + + xv[4] = _mm256_loadu_ps( x0 + 4*n_elem_per_reg ); + xv[5] = _mm256_loadu_ps( x0 + 5*n_elem_per_reg ); + xv[6] = _mm256_loadu_ps( x0 + 6*n_elem_per_reg ); + xv[7] = _mm256_loadu_ps( x0 + 7*n_elem_per_reg ); + + zv[4] = _mm256_mul_ps( alphav, xv[4] ); + zv[5] = _mm256_mul_ps( alphav, xv[5] ); + zv[6] = _mm256_mul_ps( alphav, xv[6] ); + zv[7] = _mm256_mul_ps( alphav, xv[7] ); + + _mm256_storeu_ps( (x0 + 4*n_elem_per_reg), zv[4] ); + _mm256_storeu_ps( (x0 + 5*n_elem_per_reg), zv[5] ); + _mm256_storeu_ps( (x0 + 6*n_elem_per_reg), zv[6] ); + _mm256_storeu_ps( (x0 + 7*n_elem_per_reg), zv[7] ); + + xv[8] = _mm256_loadu_ps( x0 + 8*n_elem_per_reg ); + xv[9] = _mm256_loadu_ps( x0 + 9*n_elem_per_reg ); + xv[10] = _mm256_loadu_ps( x0 + 10*n_elem_per_reg ); + xv[11] = _mm256_loadu_ps( x0 + 11*n_elem_per_reg ); + + zv[8] = _mm256_mul_ps( alphav, xv[8] ); + zv[9] = _mm256_mul_ps( alphav, xv[9] ); + zv[10] = _mm256_mul_ps( alphav, xv[10] ); + zv[11] = _mm256_mul_ps( alphav, xv[11] ); + + _mm256_storeu_ps( (x0 + 8*n_elem_per_reg), zv[8] ); + _mm256_storeu_ps( (x0 + 9*n_elem_per_reg), zv[9] ); + _mm256_storeu_ps( (x0 + 10*n_elem_per_reg), zv[10] ); + _mm256_storeu_ps( (x0 + 11*n_elem_per_reg), zv[11] ); + + x0 += 12*n_elem_per_reg; + } + + case 2: + + for ( ; (i + 47) < n; i += 48 ) + { + xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); + xv[2] = _mm256_loadu_ps( x0 + 2*n_elem_per_reg ); + + zv[0] = _mm256_mul_ps( alphav, xv[0] ); + zv[1] = _mm256_mul_ps( alphav, xv[1] ); + zv[2] = _mm256_mul_ps( alphav, xv[2] ); + + _mm256_storeu_ps( (x0 + 0*n_elem_per_reg), zv[0] ); + _mm256_storeu_ps( (x0 + 1*n_elem_per_reg), zv[1] ); + _mm256_storeu_ps( (x0 + 2*n_elem_per_reg), zv[2] ); + + xv[3] = _mm256_loadu_ps( x0 + 3*n_elem_per_reg ); + xv[4] = _mm256_loadu_ps( x0 + 4*n_elem_per_reg ); + xv[5] = _mm256_loadu_ps( x0 + 5*n_elem_per_reg ); + + zv[3] = _mm256_mul_ps( alphav, xv[3] ); + zv[4] = _mm256_mul_ps( alphav, xv[4] ); + zv[5] = _mm256_mul_ps( alphav, xv[5] ); + + _mm256_storeu_ps( (x0 + 3*n_elem_per_reg), zv[3] ); + _mm256_storeu_ps( (x0 + 4*n_elem_per_reg), zv[4] ); + _mm256_storeu_ps( (x0 + 5*n_elem_per_reg), zv[5] ); + + x0 += 6*n_elem_per_reg; + } + + for ( ; (i + 23) < n; i += 24 ) + { + xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); + xv[2] = _mm256_loadu_ps( x0 + 2*n_elem_per_reg ); + + zv[0] = _mm256_mul_ps( alphav, xv[0] ); + zv[1] = _mm256_mul_ps( alphav, xv[1] ); + zv[2] = _mm256_mul_ps( alphav, xv[2] ); + + _mm256_storeu_ps( (x0 + 0*n_elem_per_reg), zv[0] ); + _mm256_storeu_ps( (x0 + 1*n_elem_per_reg), zv[1] ); + _mm256_storeu_ps( (x0 + 2*n_elem_per_reg), zv[2] ); + + x0 += 3*n_elem_per_reg; + } + + for ( ; (i + 7) < n; i += 8 ) + { + xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); + + zv[0] = _mm256_mul_ps( alphav, xv[0] ); + + _mm256_storeu_ps( (x0 + 0*n_elem_per_reg), zv[0] ); + + x0 += 1*n_elem_per_reg; + } + + for ( ; (i + 0) < n; i += 1 ) + { + *x0 *= *alpha; + + x0 += 1; + } } } else { const float alphac = 
*alpha; - for ( i = 0; i < n; ++i ) + for ( ; i < n; ++i ) { *x0 *= alphac; @@ -266,13 +327,13 @@ void bli_dscalv_zen_int10 { const dim_t n_elem_per_reg = 4; - dim_t i; + dim_t i = 0; double* restrict x0; __m256d alphav; - __m256d xv[10]; - __m256d zv[10]; + __m256d xv[16]; + __m256d zv[16]; // If the vector dimension is zero, or if alpha is unit, return early. if ( bli_zero_dim1( n ) || PASTEMAC(d,eq1)( *alpha ) ) return; @@ -312,140 +373,221 @@ void bli_dscalv_zen_int10 { // Broadcast the alpha scalar to all elements of a vector register. alphav = _mm256_broadcast_sd( alpha ); + dim_t option; - for ( i = 0; (i + 39) < n; i += 40 ) + // Unroll and the loop used is picked based on the input size. + if(n < 200) { - // Load the input values. - xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); - xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); - xv[2] = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); - xv[3] = _mm256_loadu_pd( x0 + 3*n_elem_per_reg ); - xv[4] = _mm256_loadu_pd( x0 + 4*n_elem_per_reg ); - xv[5] = _mm256_loadu_pd( x0 + 5*n_elem_per_reg ); - xv[6] = _mm256_loadu_pd( x0 + 6*n_elem_per_reg ); - xv[7] = _mm256_loadu_pd( x0 + 7*n_elem_per_reg ); - xv[8] = _mm256_loadu_pd( x0 + 8*n_elem_per_reg ); - xv[9] = _mm256_loadu_pd( x0 + 9*n_elem_per_reg ); - - // perform : x := alpha * x; - zv[0] = _mm256_mul_pd( alphav, xv[0] ); - zv[1] = _mm256_mul_pd( alphav, xv[1] ); - zv[2] = _mm256_mul_pd( alphav, xv[2] ); - zv[3] = _mm256_mul_pd( alphav, xv[3] ); - zv[4] = _mm256_mul_pd( alphav, xv[4] ); - zv[5] = _mm256_mul_pd( alphav, xv[5] ); - zv[6] = _mm256_mul_pd( alphav, xv[6] ); - zv[7] = _mm256_mul_pd( alphav, xv[7] ); - zv[8] = _mm256_mul_pd( alphav, xv[8] ); - zv[9] = _mm256_mul_pd( alphav, xv[9] ); - - // Store the output. - _mm256_storeu_pd( (x0 + 0*n_elem_per_reg), zv[0] ); - _mm256_storeu_pd( (x0 + 1*n_elem_per_reg), zv[1] ); - _mm256_storeu_pd( (x0 + 2*n_elem_per_reg), zv[2] ); - _mm256_storeu_pd( (x0 + 3*n_elem_per_reg), zv[3] ); - _mm256_storeu_pd( (x0 + 4*n_elem_per_reg), zv[4] ); - _mm256_storeu_pd( (x0 + 5*n_elem_per_reg), zv[5] ); - _mm256_storeu_pd( (x0 + 6*n_elem_per_reg), zv[6] ); - _mm256_storeu_pd( (x0 + 7*n_elem_per_reg), zv[7] ); - _mm256_storeu_pd( (x0 + 8*n_elem_per_reg), zv[8] ); - _mm256_storeu_pd( (x0 + 9*n_elem_per_reg), zv[9] ); - - x0 += 10*n_elem_per_reg; + option = 2; } - - for ( ; (i + 19) < n; i += 20 ) + else if(n < 500) { - // Load the input values. - xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); - xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); - xv[2] = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); - xv[3] = _mm256_loadu_pd( x0 + 3*n_elem_per_reg ); - xv[4] = _mm256_loadu_pd( x0 + 4*n_elem_per_reg ); - - // perform : x := alpha * x; - zv[0] = _mm256_mul_pd( alphav, xv[0] ); - zv[1] = _mm256_mul_pd( alphav, xv[1] ); - zv[2] = _mm256_mul_pd( alphav, xv[2] ); - zv[3] = _mm256_mul_pd( alphav, xv[3] ); - zv[4] = _mm256_mul_pd( alphav, xv[4] ); - - // Store the output. - _mm256_storeu_pd( (x0 + 0*n_elem_per_reg), zv[0] ); - _mm256_storeu_pd( (x0 + 1*n_elem_per_reg), zv[1] ); - _mm256_storeu_pd( (x0 + 2*n_elem_per_reg), zv[2] ); - _mm256_storeu_pd( (x0 + 3*n_elem_per_reg), zv[3] ); - _mm256_storeu_pd( (x0 + 4*n_elem_per_reg), zv[4] ); - - x0 += 5*n_elem_per_reg; + option = 1; } - - for ( ; (i + 15) < n; i += 16 ) + else { - // Load the input values. 
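/* (Editor's note, not part of the original patch.) dscalv mirrors the sscalv
 * scheme with four doubles per YMM register: n < 200 selects option 2,
 * n < 500 selects option 1, larger n selects option 0, and the fall-through
 * switch runs 64- and 48-element loops (case 0), a 32-element loop (case 1),
 * then 12- and 4-element loops plus a scalar tail (case 2).
 */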
- xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); - xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); - xv[2] = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); - xv[3] = _mm256_loadu_pd( x0 + 3*n_elem_per_reg ); - - // perform : x := alpha * x; - zv[0] = _mm256_mul_pd( alphav, xv[0] ); - zv[1] = _mm256_mul_pd( alphav, xv[1] ); - zv[2] = _mm256_mul_pd( alphav, xv[2] ); - zv[3] = _mm256_mul_pd( alphav, xv[3] ); - - // Store the output. - _mm256_storeu_pd( (x0 + 0*n_elem_per_reg), zv[0] ); - _mm256_storeu_pd( (x0 + 1*n_elem_per_reg), zv[1] ); - _mm256_storeu_pd( (x0 + 2*n_elem_per_reg), zv[2] ); - _mm256_storeu_pd( (x0 + 3*n_elem_per_reg), zv[3] ); - - x0 += 4*n_elem_per_reg; + option = 0; } - for ( ; (i + 7) < n; i += 8 ) + switch(option) { - // Load the input values. - xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); - xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); - - // perform : x := alpha * x; - zv[0] = _mm256_mul_pd( alphav, xv[0] ); - zv[1] = _mm256_mul_pd( alphav, xv[1] ); - - // Store the output. - _mm256_storeu_pd( (x0 + 0*n_elem_per_reg), zv[0] ); - _mm256_storeu_pd( (x0 + 1*n_elem_per_reg), zv[1] ); - - x0 += 2*n_elem_per_reg; - } - - for ( ; (i + 3) < n; i += 4 ) - { - // Load the input values. - xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); - - // perform : x := alpha * x; - zv[0] = _mm256_mul_pd( alphav, xv[0] ); - - // Store the output. - _mm256_storeu_pd( (x0 + 0*n_elem_per_reg), zv[0] ); - - x0 += 1*n_elem_per_reg; - } - - for ( ; (i + 0) < n; i += 1 ) - { - *x0 *= *alpha; - - x0 += 1; + case 0: + + for (; (i + 63) < n; i += 64 ) + { + xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); + xv[2] = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); + xv[3] = _mm256_loadu_pd( x0 + 3*n_elem_per_reg ); + + zv[0] = _mm256_mul_pd( alphav, xv[0] ); + zv[1] = _mm256_mul_pd( alphav, xv[1] ); + zv[2] = _mm256_mul_pd( alphav, xv[2] ); + zv[3] = _mm256_mul_pd( alphav, xv[3] ); + + _mm256_storeu_pd( (x0 + 0*n_elem_per_reg), zv[0] ); + _mm256_storeu_pd( (x0 + 1*n_elem_per_reg), zv[1] ); + _mm256_storeu_pd( (x0 + 2*n_elem_per_reg), zv[2] ); + _mm256_storeu_pd( (x0 + 3*n_elem_per_reg), zv[3] ); + + xv[4] = _mm256_loadu_pd( x0 + 4*n_elem_per_reg ); + xv[5] = _mm256_loadu_pd( x0 + 5*n_elem_per_reg ); + xv[6] = _mm256_loadu_pd( x0 + 6*n_elem_per_reg ); + xv[7] = _mm256_loadu_pd( x0 + 7*n_elem_per_reg ); + + zv[4] = _mm256_mul_pd( alphav, xv[4] ); + zv[5] = _mm256_mul_pd( alphav, xv[5] ); + zv[6] = _mm256_mul_pd( alphav, xv[6] ); + zv[7] = _mm256_mul_pd( alphav, xv[7] ); + + _mm256_storeu_pd( (x0 + 4*n_elem_per_reg), zv[4] ); + _mm256_storeu_pd( (x0 + 5*n_elem_per_reg), zv[5] ); + _mm256_storeu_pd( (x0 + 6*n_elem_per_reg), zv[6] ); + _mm256_storeu_pd( (x0 + 7*n_elem_per_reg), zv[7] ); + + xv[8] = _mm256_loadu_pd( x0 + 8*n_elem_per_reg ); + xv[9] = _mm256_loadu_pd( x0 + 9*n_elem_per_reg ); + xv[10] = _mm256_loadu_pd( x0 + 10*n_elem_per_reg ); + xv[11] = _mm256_loadu_pd( x0 + 11*n_elem_per_reg ); + + zv[8] = _mm256_mul_pd( alphav, xv[8] ); + zv[9] = _mm256_mul_pd( alphav, xv[9] ); + zv[10] = _mm256_mul_pd( alphav, xv[10] ); + zv[11] = _mm256_mul_pd( alphav, xv[11] ); + + _mm256_storeu_pd( (x0 + 8*n_elem_per_reg), zv[8] ); + _mm256_storeu_pd( (x0 + 9*n_elem_per_reg), zv[9] ); + _mm256_storeu_pd( (x0 + 10*n_elem_per_reg), zv[10] ); + _mm256_storeu_pd( (x0 + 11*n_elem_per_reg), zv[11] ); + + xv[12] = _mm256_loadu_pd( x0 + 12*n_elem_per_reg ); + xv[13] = _mm256_loadu_pd( x0 + 13*n_elem_per_reg ); + xv[14] = _mm256_loadu_pd( x0 + 14*n_elem_per_reg ); + 
xv[15] = _mm256_loadu_pd( x0 + 15*n_elem_per_reg ); + + zv[12] = _mm256_mul_pd( alphav, xv[12] ); + zv[13] = _mm256_mul_pd( alphav, xv[13] ); + zv[14] = _mm256_mul_pd( alphav, xv[14] ); + zv[15] = _mm256_mul_pd( alphav, xv[15] ); + + _mm256_storeu_pd( (x0 + 12*n_elem_per_reg), zv[12] ); + _mm256_storeu_pd( (x0 + 13*n_elem_per_reg), zv[13] ); + _mm256_storeu_pd( (x0 + 14*n_elem_per_reg), zv[14] ); + _mm256_storeu_pd( (x0 + 15*n_elem_per_reg), zv[15] ); + + x0 += 16*n_elem_per_reg; + } + + for (; (i + 47) < n; i += 48 ) + { + xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); + xv[2] = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); + xv[3] = _mm256_loadu_pd( x0 + 3*n_elem_per_reg ); + + zv[0] = _mm256_mul_pd( alphav, xv[0] ); + zv[1] = _mm256_mul_pd( alphav, xv[1] ); + zv[2] = _mm256_mul_pd( alphav, xv[2] ); + zv[3] = _mm256_mul_pd( alphav, xv[3] ); + + _mm256_storeu_pd( (x0 + 0*n_elem_per_reg), zv[0] ); + _mm256_storeu_pd( (x0 + 1*n_elem_per_reg), zv[1] ); + _mm256_storeu_pd( (x0 + 2*n_elem_per_reg), zv[2] ); + _mm256_storeu_pd( (x0 + 3*n_elem_per_reg), zv[3] ); + + xv[4] = _mm256_loadu_pd( x0 + 4*n_elem_per_reg ); + xv[5] = _mm256_loadu_pd( x0 + 5*n_elem_per_reg ); + xv[6] = _mm256_loadu_pd( x0 + 6*n_elem_per_reg ); + xv[7] = _mm256_loadu_pd( x0 + 7*n_elem_per_reg ); + + zv[4] = _mm256_mul_pd( alphav, xv[4] ); + zv[5] = _mm256_mul_pd( alphav, xv[5] ); + zv[6] = _mm256_mul_pd( alphav, xv[6] ); + zv[7] = _mm256_mul_pd( alphav, xv[7] ); + + _mm256_storeu_pd( (x0 + 4*n_elem_per_reg), zv[4] ); + _mm256_storeu_pd( (x0 + 5*n_elem_per_reg), zv[5] ); + _mm256_storeu_pd( (x0 + 6*n_elem_per_reg), zv[6] ); + _mm256_storeu_pd( (x0 + 7*n_elem_per_reg), zv[7] ); + + xv[8] = _mm256_loadu_pd( x0 + 8*n_elem_per_reg ); + xv[9] = _mm256_loadu_pd( x0 + 9*n_elem_per_reg ); + xv[10] = _mm256_loadu_pd( x0 + 10*n_elem_per_reg ); + xv[11] = _mm256_loadu_pd( x0 + 11*n_elem_per_reg ); + + zv[8] = _mm256_mul_pd( alphav, xv[8] ); + zv[9] = _mm256_mul_pd( alphav, xv[9] ); + zv[10] = _mm256_mul_pd( alphav, xv[10] ); + zv[11] = _mm256_mul_pd( alphav, xv[11] ); + + _mm256_storeu_pd( (x0 + 8*n_elem_per_reg), zv[8] ); + _mm256_storeu_pd( (x0 + 9*n_elem_per_reg), zv[9] ); + _mm256_storeu_pd( (x0 + 10*n_elem_per_reg), zv[10] ); + _mm256_storeu_pd( (x0 + 11*n_elem_per_reg), zv[11] ); + + x0 += 12*n_elem_per_reg; + } + + case 1: + + for (; (i + 31) < n; i += 32 ) + { + xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); + xv[2] = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); + xv[3] = _mm256_loadu_pd( x0 + 3*n_elem_per_reg ); + + zv[0] = _mm256_mul_pd( alphav, xv[0] ); + zv[1] = _mm256_mul_pd( alphav, xv[1] ); + zv[2] = _mm256_mul_pd( alphav, xv[2] ); + zv[3] = _mm256_mul_pd( alphav, xv[3] ); + + _mm256_storeu_pd( (x0 + 0*n_elem_per_reg), zv[0] ); + _mm256_storeu_pd( (x0 + 1*n_elem_per_reg), zv[1] ); + _mm256_storeu_pd( (x0 + 2*n_elem_per_reg), zv[2] ); + _mm256_storeu_pd( (x0 + 3*n_elem_per_reg), zv[3] ); + + xv[4] = _mm256_loadu_pd( x0 + 4*n_elem_per_reg ); + xv[5] = _mm256_loadu_pd( x0 + 5*n_elem_per_reg ); + xv[6] = _mm256_loadu_pd( x0 + 6*n_elem_per_reg ); + xv[7] = _mm256_loadu_pd( x0 + 7*n_elem_per_reg ); + + zv[4] = _mm256_mul_pd( alphav, xv[4] ); + zv[5] = _mm256_mul_pd( alphav, xv[5] ); + zv[6] = _mm256_mul_pd( alphav, xv[6] ); + zv[7] = _mm256_mul_pd( alphav, xv[7] ); + + _mm256_storeu_pd( (x0 + 4*n_elem_per_reg), zv[4] ); + _mm256_storeu_pd( (x0 + 5*n_elem_per_reg), zv[5] ); + _mm256_storeu_pd( (x0 + 6*n_elem_per_reg), zv[6] 
); + _mm256_storeu_pd( (x0 + 7*n_elem_per_reg), zv[7] ); + + x0 += 8*n_elem_per_reg; + } + + case 2: + + for ( ; (i + 11) < n; i += 12 ) + { + xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); + xv[2] = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); + + zv[0] = _mm256_mul_pd( alphav, xv[0] ); + zv[1] = _mm256_mul_pd( alphav, xv[1] ); + zv[2] = _mm256_mul_pd( alphav, xv[2] ); + + _mm256_storeu_pd( (x0 + 0*n_elem_per_reg), zv[0] ); + _mm256_storeu_pd( (x0 + 1*n_elem_per_reg), zv[1] ); + _mm256_storeu_pd( (x0 + 2*n_elem_per_reg), zv[2] ); + + x0 += 3*n_elem_per_reg; + } + + for ( ; (i + 3) < n; i += 4 ) + { + xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + + zv[0] = _mm256_mul_pd( alphav, xv[0] ); + + _mm256_storeu_pd( (x0 + 0*n_elem_per_reg), zv[0] ); + + x0 += 1*n_elem_per_reg; + } + + for ( ; (i + 0) < n; i += 1 ) + { + *x0 *= *alpha; + + x0 += 1; + } } } else { const double alphac = *alpha; - for ( i = 0; i < n; ++i ) + for ( ; i < n; ++i ) { *x0 *= alphac; From 3190e547b05c2773afd47914231dc95a8e7eaa91 Mon Sep 17 00:00:00 2001 From: Arnav Sharma Date: Tue, 21 Dec 2021 16:49:11 +0530 Subject: [PATCH 060/243] Optimized AXPBYV Kernel using AVX2 Intrinsics Details: - Intrinsic implementation of axpbyv for AVX2 - Bench written for axpbyv - Added definitions in zen contexts AMD-Internal: [CPUPL-1963] Change-Id: I9bc21a6170f5c944eb6e9e9f0e994b9992f8b539 --- bench/Makefile | 12 +- bench/bench_axpbyv.c | 265 +++++++ bench/inputaxpbyv.txt | 40 + config/zen/bli_cntx_init_zen.c | 10 +- config/zen2/bli_cntx_init_zen2.c | 10 +- config/zen3/bli_cntx_init_zen3.c | 10 +- config/zen4/bli_cntx_init_zen4.c | 11 +- kernels/zen/1/CMakeLists.txt | 2 + kernels/zen/1/bli_axpbyv_zen_int.c | 1099 ++++++++++++++++++++++++++ kernels/zen/1/bli_axpbyv_zen_int10.c | 709 +++++++++++++++++ kernels/zen/bli_kernels_zen.h | 10 + 11 files changed, 2163 insertions(+), 15 deletions(-) create mode 100644 bench/bench_axpbyv.c create mode 100644 bench/inputaxpbyv.txt create mode 100644 kernels/zen/1/bli_axpbyv_zen_int.c create mode 100644 kernels/zen/1/bli_axpbyv_zen_int10.c diff --git a/bench/Makefile b/bench/Makefile index 3ee497212d..d47485b2fc 100755 --- a/bench/Makefile +++ b/bench/Makefile @@ -191,7 +191,8 @@ blis: \ bench_trsv_blis.x \ bench_amaxv_blis.x \ bench_copyv_blis.x \ - bench_swapv_blis.x + bench_swapv_blis.x \ + bench_axpbyv_blis.x openblas: \ bench_gemm_openblas.x \ @@ -205,7 +206,8 @@ openblas: \ bench_trsv_openblas.x \ bench_amaxv_openblas.x \ bench_copyv_openblas.x \ - bench_swapv_openblas.x + bench_swapv_openblas.x \ + bench_axpbyv_openblas.x atlas: \ bench_gemm_atlas.x \ @@ -219,7 +221,8 @@ atlas: \ bench_trsv_atlas.x \ bench_amaxv_atlas.x \ bench_copyv_atlas.x \ - bench_swapv_atlas.x + bench_swapv_atlas.x \ + bench_axpbyv_atlax.x mkl: \ bench_gemm_mkl.x \ @@ -233,7 +236,8 @@ mkl: \ bench_trsv_mkl.x \ bench_amaxv_mkl.x \ bench_copyv_mkl.x \ - bench_swapv_mkl.x + bench_swapv_mkl.x \ + bench_axpbyv_mkl.x # --Object file rules -- diff --git a/bench/bench_axpbyv.c b/bench/bench_axpbyv.c new file mode 100644 index 0000000000..36a203f696 --- /dev/null +++ b/bench/bench_axpbyv.c @@ -0,0 +1,265 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. 
+ + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name of The University of Texas nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifdef WIN32 +#include +#else +#include +#endif +#include "blis.h" + +#ifndef DT +#define DT BLIS_DOUBLE +#endif +#define AOCL_MATRIX_INITIALISATION + +int main( int argc, char** argv ) +{ + obj_t x, y, alpha, beta; // BLIS objects + dim_t p_inc = 0; // To keep track of number of inputs + num_t dt; // BLIS datatype + char dt_ch; // {S, D, Z, C} from input + int r, n_repeats; // repetition counter; number of repeats + + double dtime; + double dtime_save; + double gflops; + + FILE* fin = NULL; // Input FILE* + FILE* fout = NULL; // Output FILE* + + n_repeats = N_REPEAT; // Fetched from Makefile + + dt = DT; // Set datatype as BLIS_DOUBLE + + if ( argc < 3 ) + { + printf( "Usage: ./bench_axpbyv_XX.x input.txt output.txt\n" ); + exit( 1 ); + } + + fin = fopen( argv[1], "r" ); // Open input file in read mode + if ( fin == NULL ) + { + printf( "Error opening input file %s\n", argv[1] ); + exit( 1 ); + } + + fout = fopen( argv[2], "w" ); // Open output file in write mode + if ( fout == NULL ) + { + printf( "Error opening output file %s\n", argv[2] ); + exit( 1 ); + } + +#ifdef DEBUG + fprintf( fout, "gflops\n" ); +#else + fprintf(fout, "Dt\t n\t alpha_r\t alpha_i\t beta_r\t beta_i\t gflops\n" ); +#endif + + dim_t n; // dimension + inc_t incx; // stride x + inc_t incy; // stride y + char tmp[256]; // to store function name, line not present in logs + double alpha_r, alpha_i, beta_r, beta_i; + + // {function name} {S, D, C, Z} {n} + // {alpha_r} {alpha_i} {incx} {beta_r} {beta_i} {incy} + while ( fscanf( fin, "%s %c %ld %lf %lf %ld %lf %lf %ld\n", + tmp, &dt_ch, &n, + &alpha_r, &alpha_i, &incx, &beta_r, &beta_i, &incy ) == 9 ) + { + if ( dt_ch == 'D' || dt_ch == 'd' ) dt = BLIS_DOUBLE; + else if ( dt_ch == 'Z' || dt_ch == 'z' ) dt = BLIS_DCOMPLEX; + else if ( dt_ch == 'S' || dt_ch == 's' ) dt = BLIS_FLOAT; + else if ( dt_ch == 'C' || dt_ch == 'c' ) dt = BLIS_SCOMPLEX; + else + { + printf( "Invalid data type %c\n", dt_ch ); + continue; + } + + // Creating BLIS objects + bli_obj_create( dt, n, 1, incx, 1, &x ); // For input vector x + 
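/* (Editor's note, not part of the original patch.) bli_obj_create() takes
 * (datatype, m, n, row stride, column stride, object), so each vector is
 * created here as an n-by-1 object whose row stride is the BLAS increment;
 * bli_obj_vector_inc() later recovers incx/incy from these objects. */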
bli_obj_create( dt, n, 1, incy, 1, &y ); // For input vector y + bli_obj_create( dt, 1, 1, 0, 0, &alpha); // For input vector alpha + bli_obj_create( dt, 1, 1, 0, 0, &beta); // For input vector beta + + #ifdef AOCL_MATRIX_INITIALISATION + bli_randm( &x ); + bli_randm( &y ); + #endif + + bli_setsc( alpha_r, alpha_i, &alpha ); + bli_setsc( beta_r, beta_i, &beta ); + + dtime_save = DBL_MAX; + + for ( r = 0; r < n_repeats; ++r ) + { + dtime = bli_clock(); + +#ifdef BLIS + bli_axpbyv( &alpha, &x, &beta, &y ); +#else + f77_int nn = bli_obj_length( &x ); + f77_int blas_incx = bli_obj_vector_inc( &x ); + f77_int blas_incy = bli_obj_vector_inc( &y ); + + if ( bli_is_float( dt ) ) + { + float* alphap = bli_obj_buffer( &alpha ); + float* xp = bli_obj_buffer( &x ); + float* betap = bli_obj_buffer( &beta ); + float* yp = bli_obj_buffer( &y ); + +#ifdef CBLAS + cblas_saxpby( nn, + *alphap, + xp, + blas_incx, + *betap, + yp, + blas_incy ); +#else + saxpby_( &nn, + alphap, + xp, + &blas_incx, + betap, + yp, + &blas_incy ); +#endif + } + else if ( bli_is_double( dt ) ) + { + double* alphap = bli_obj_buffer( &alpha ); + double* xp = bli_obj_buffer( &x ); + double* betap = bli_obj_buffer( &beta ); + double* yp = bli_obj_buffer( &y ); + +#ifdef CBLAS + cblas_daxpby( nn, + *alphap, + xp, + blas_incx, + *betap, + yp, + blas_incy ); +#else + daxpby_( &nn, + alphap, + xp, + &blas_incx, + betap, + yp, + &blas_incy ); +#endif + } + else if ( bli_is_scomplex( dt ) ) + { + scomplex* alphap = bli_obj_buffer( &alpha ); + scomplex* xp = bli_obj_buffer( &x ); + scomplex* betap = bli_obj_buffer( &beta ); + scomplex* yp = bli_obj_buffer( &y ); + +#ifdef CBLAS + cblas_caxpby( nn, + *alphap, + xp, + blas_incx, + *betap, + yp, + blas_incy ); +#else + caxpby_( &nn, + alphap, + xp, + &blas_incx, + betap, + yp, + &blas_incy ); +#endif + } + else if ( bli_is_dcomplex( dt ) ) + { + dcomplex* alphap = bli_obj_buffer( &alpha ); + dcomplex* xp = bli_obj_buffer( &x ); + dcomplex* betap = bli_obj_buffer( &beta ); + dcomplex* yp = bli_obj_buffer( &y ); + +#ifdef CBLAS + cblas_zaxpby( nn, + *alphap, + xp, + blas_incx, + *betap, + yp, + blas_incy ); +#else + zaxpby_( &nn, + alphap, + xp, + &blas_incx, + betap, + yp, + &blas_incy ); +#endif + } +#endif + + dtime_save = bli_clock_min_diff( dtime_save, dtime ); + } + gflops = ( 3.0 * n ) / ( dtime_save * 1.0e9 ); + if ( bli_is_complex( dt ) ) gflops *= 4.0; + + printf( "data_axpbyv_%s", BLAS ); + + p_inc++; + printf( " %4lu [ %4lu %7.2f ];\n", + (unsigned long)(p_inc), + (unsigned long)n, + gflops ); + + fprintf( fout, "%c\t %ld\t %lf\t %lf\t %lf\t %lf\t %6.3f\n", + dt_ch, n, alpha_r, alpha_i, beta_r, beta_i, gflops ); + fflush( fout ); + + bli_obj_free( &x ); + bli_obj_free( &y ); + } + + return 0; +} \ No newline at end of file diff --git a/bench/inputaxpbyv.txt b/bench/inputaxpbyv.txt new file mode 100644 index 0000000000..3cfc7ae732 --- /dev/null +++ b/bench/inputaxpbyv.txt @@ -0,0 +1,40 @@ +saxpbyv_ S 32 0.900000 0.000000 1 0.900000 0.000000 1 +saxpbyv_ S 64 1.000000 0.000000 1 1.000000 0.000000 1 +saxpbyv_ S 100 -1 0.000000 1 -1 0.000000 1 +saxpbyv_ S 200 -1.100000 0.000000 1 -1.100000 0.000000 1 +saxpbyv_ S 300 1.100000 0.000000 1 1.100000 0.000000 1 +saxpbyv_ S 400 0.900000 0.000000 1 0.900000 0.000000 1 +saxpbyv_ S 500 1.000000 0.000000 1 1.000000 0.000000 1 +saxpbyv_ S 1000 -1 0.000000 1 -1 0.000000 1 +saxpbyv_ S 5000 -1.100000 0.000000 1 -1.100000 0.000000 1 +saxpbyv_ S 10000 1.100000 0.000000 1 1.100000 0.000000 1 +daxpbyv_ D 32 0.900000 0.000000 1 0.900000 0.000000 1 +daxpbyv_ D 
64 1.000000 0.000000 1 1.000000 0.000000 1 +daxpbyv_ D 100 -1 0.000000 1 -1 0.000000 1 +daxpbyv_ D 200 -1.100000 0.000000 1 -1.100000 0.000000 1 +daxpbyv_ D 300 1.100000 0.000000 1 1.100000 0.000000 1 +daxpbyv_ D 400 0.900000 0.000000 1 0.900000 0.000000 1 +daxpbyv_ D 500 1.000000 0.000000 1 1.000000 0.000000 1 +daxpbyv_ D 1000 -1 0.000000 1 -1 0.000000 1 +daxpbyv_ D 5000 -1.100000 0.000000 1 -1.100000 0.000000 1 +daxpbyv_ D 10000 1.100000 0.000000 1 1.100000 0.000000 1 +caxpbyv_ C 32 0.900000 -1.100000 1 0.900000 -1.100000 1 +caxpbyv_ C 64 1.000000 1.100000 1 1.000000 1.100000 1 +caxpbyv_ C 100 -1 1.000000 1 -1 1 1 +caxpbyv_ C 200 -1.100000 0.900000 1 -1.100000 0.900000 1 +caxpbyv_ C 300 1.100000 1.000000 1 1.100000 1 1 +caxpbyv_ C 400 0.900000 -1.100000 1 0.900000 -1.100000 1 +caxpbyv_ C 500 1.000000 1.000000 1 1.000000 1 1 +caxpbyv_ C 1000 -1 0.900000 1 -1 0.900000 1 +caxpbyv_ C 5000 -1.100000 -1 1 -1.100000 -1 1 +caxpbyv_ C 10000 1.100000 -1 1 1.100000 -1 1 +zaxpbyv_ Z 32 0.900000 -1.100000 1 0.900000 -1.100000 1 +zaxpbyv_ Z 64 1.000000 1.100000 1 1.000000 1.100000 1 +zaxpbyv_ Z 100 -1 1.000000 1 -1 1 1 +zaxpbyv_ Z 200 -1.100000 0.900000 1 -1.100000 0.900000 1 +zaxpbyv_ Z 300 1.100000 1.000000 1 1.100000 1 1 +zaxpbyv_ Z 400 0.900000 -1.100000 1 0.900000 -1.100000 1 +zaxpbyv_ Z 500 1.000000 1.000000 1 1.000000 1 1 +zaxpbyv_ Z 1000 -1 0.900000 1 -1 0.900000 1 +zaxpbyv_ Z 5000 -1.100000 -1 1 -1.100000 -1 1 +zaxpbyv_ Z 10000 1.100000 -1 1 1.100000 -1 1 diff --git a/config/zen/bli_cntx_init_zen.c b/config/zen/bli_cntx_init_zen.c index 7595849866..020e7052b9 100644 --- a/config/zen/bli_cntx_init_zen.c +++ b/config/zen/bli_cntx_init_zen.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -95,12 +95,18 @@ void bli_cntx_init_zen( cntx_t* cntx ) // Update the context with optimized level-1v kernels. bli_cntx_set_l1v_kers ( - 20, + 24, #if 1 // amaxv BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int, BLIS_AMAXV_KER, BLIS_DOUBLE, bli_damaxv_zen_int, #endif + // axpbyv + BLIS_AXPBYV_KER, BLIS_FLOAT, bli_saxpbyv_zen_int10, + BLIS_AXPBYV_KER, BLIS_DOUBLE, bli_daxpbyv_zen_int10, + BLIS_AXPBYV_KER, BLIS_SCOMPLEX, bli_caxpbyv_zen_int, + BLIS_AXPBYV_KER, BLIS_DCOMPLEX, bli_zaxpbyv_zen_int, + // axpyv #if 0 BLIS_AXPYV_KER, BLIS_FLOAT, bli_saxpyv_zen_int, diff --git a/config/zen2/bli_cntx_init_zen2.c b/config/zen2/bli_cntx_init_zen2.c index 4f56316a7a..315362067e 100644 --- a/config/zen2/bli_cntx_init_zen2.c +++ b/config/zen2/bli_cntx_init_zen2.c @@ -3,7 +3,7 @@ An object-based framework for developing high-performance BLAS-like libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -107,13 +107,17 @@ void bli_cntx_init_zen2( cntx_t* cntx ) // Update the context with optimized level-1v kernels. 
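/* (Editor's note, not part of the original patch.) In bli_cntx_set_l1v_kers()
 * the leading count is the number of (kernel id, datatype, function) triples
 * that follow, so registering the four new axpbyv kernels (s, d, c, z) bumps
 * it from 20 to 24. The same hunk is applied to the zen, zen2, zen3 and zen4
 * context initializers in this patch.
 */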
bli_cntx_set_l1v_kers ( - 20, + 24, #if 1 // amaxv BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int, BLIS_AMAXV_KER, BLIS_DOUBLE, bli_damaxv_zen_int, #endif - // axpyv + // axpbyv + BLIS_AXPBYV_KER, BLIS_FLOAT, bli_saxpbyv_zen_int10, + BLIS_AXPBYV_KER, BLIS_DOUBLE, bli_daxpbyv_zen_int10, + BLIS_AXPBYV_KER, BLIS_SCOMPLEX, bli_caxpbyv_zen_int, + BLIS_AXPBYV_KER, BLIS_DCOMPLEX, bli_zaxpbyv_zen_int, // axpyv BLIS_AXPYV_KER, BLIS_FLOAT, bli_saxpyv_zen_int10, diff --git a/config/zen3/bli_cntx_init_zen3.c b/config/zen3/bli_cntx_init_zen3.c index fc7dbcb808..ef47987454 100644 --- a/config/zen3/bli_cntx_init_zen3.c +++ b/config/zen3/bli_cntx_init_zen3.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -107,13 +107,17 @@ void bli_cntx_init_zen3( cntx_t* cntx ) // Update the context with optimized level-1v kernels. bli_cntx_set_l1v_kers ( - 20, + 24, #if 1 // amaxv BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int, BLIS_AMAXV_KER, BLIS_DOUBLE, bli_damaxv_zen_int, #endif - // axpyv + // axpbyv + BLIS_AXPBYV_KER, BLIS_FLOAT, bli_saxpbyv_zen_int10, + BLIS_AXPBYV_KER, BLIS_DOUBLE, bli_daxpbyv_zen_int10, + BLIS_AXPBYV_KER, BLIS_SCOMPLEX, bli_caxpbyv_zen_int, + BLIS_AXPBYV_KER, BLIS_DCOMPLEX, bli_zaxpbyv_zen_int, // axpyv BLIS_AXPYV_KER, BLIS_FLOAT, bli_saxpyv_zen_int10, diff --git a/config/zen4/bli_cntx_init_zen4.c b/config/zen4/bli_cntx_init_zen4.c index 806c268a0f..4f4c16d0ae 100644 --- a/config/zen4/bli_cntx_init_zen4.c +++ b/config/zen4/bli_cntx_init_zen4.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -106,12 +106,17 @@ void bli_cntx_init_zen4( cntx_t* cntx ) // Update the context with optimized level-1v kernels. 
bli_cntx_set_l1v_kers ( - 20, + 24, // amaxv BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int, BLIS_AMAXV_KER, BLIS_DOUBLE, bli_damaxv_zen_int, - // axpyv + + // axpbyv + BLIS_AXPBYV_KER, BLIS_FLOAT, bli_saxpbyv_zen_int10, + BLIS_AXPBYV_KER, BLIS_DOUBLE, bli_daxpbyv_zen_int10, + BLIS_AXPBYV_KER, BLIS_SCOMPLEX, bli_caxpbyv_zen_int, + BLIS_AXPBYV_KER, BLIS_DCOMPLEX, bli_zaxpbyv_zen_int, // axpyv BLIS_AXPYV_KER, BLIS_FLOAT, bli_saxpyv_zen_int10, diff --git a/kernels/zen/1/CMakeLists.txt b/kernels/zen/1/CMakeLists.txt index 669a3ba89a..434be490d5 100644 --- a/kernels/zen/1/CMakeLists.txt +++ b/kernels/zen/1/CMakeLists.txt @@ -3,6 +3,8 @@ target_sources("${PROJECT_NAME}" PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/bli_amaxv_zen_int.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_axpbyv_zen_int.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_axpbyv_zen_int10.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_axpyv_zen_int.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_axpyv_zen_int10.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_copyv_zen_int.c diff --git a/kernels/zen/1/bli_axpbyv_zen_int.c b/kernels/zen/1/bli_axpbyv_zen_int.c new file mode 100644 index 0000000000..05ef96175a --- /dev/null +++ b/kernels/zen/1/bli_axpbyv_zen_int.c @@ -0,0 +1,1099 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "immintrin.h" +#include "blis.h" + +/* Union DS to access AVX registers */ +/* One 256-bit AVX register holds 8 SP elements */ +typedef union +{ + __m256 v; + float f[8] __attribute__((aligned(64))); +} v8sf_t; + +/* One 256-bit AVX register holds 4 DP elements */ +typedef union +{ + __m256d v; + double d[4] __attribute__((aligned(64))); +} v4df_t; + +/** + * saxpbyv kernel performs the axpbyv operation. + * y := beta * y + alpha * conjx(x) + * where, + * x & y are single precision vectors of length n. + * alpha & beta are scalers. 
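+ *
+ * For reference, the per-element scalar form of this operation (also what
+ * the cleanup and non-unit-stride paths below compute directly) is:
+ *
+ *     y[i*incy] = (*beta) * y[i*incy] + (*alpha) * x[i*incx];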
+ */ +void bli_saxpbyv_zen_int + ( + conj_t conjx, + dim_t n, + float* restrict alpha, + float* restrict x, inc_t incx, + float* restrict beta, + float* restrict y, inc_t incy, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_4) + const dim_t n_elem_per_reg = 8; // number of elements per register + const dim_t n_iter_unroll = 4; // num of registers per iteration + + dim_t i; // iterator + + float* restrict x0; + float* restrict y0; + + v8sf_t alphav; + v8sf_t betav; + v8sf_t y0v, y1v, y2v, y3v; + + /* if the vector dimension is zero, or if alpha & beta are zero, + return early. */ + if ( bli_zero_dim1( n ) || + ( PASTEMAC( s, eq0 )( *alpha ) && PASTEMAC( s, eq0 )( *beta ) ) ) + return; + + // initialize local pointers + x0 = x; + y0 = y; + + if ( incx == 1 && incy == 1 ) + { + // broadcast alpha & beta to all elements of respective vector registers + alphav.v = _mm256_broadcast_ss( alpha ); + betav.v = _mm256_broadcast_ss( beta ); + + // unrolling and vectorizing + for ( i = 0; ( i + 31 ) < n; i += 32 ) + { + // loading input y + y0v.v = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); + y1v.v = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); + y2v.v = _mm256_loadu_ps( y0 + 2*n_elem_per_reg ); + y3v.v = _mm256_loadu_ps( y0 + 3*n_elem_per_reg ); + + // y' := y := beta * y + y0v.v = _mm256_mul_ps( betav.v, y0v.v ); + y1v.v = _mm256_mul_ps( betav.v, y1v.v ); + y2v.v = _mm256_mul_ps( betav.v, y2v.v ); + y3v.v = _mm256_mul_ps( betav.v, y3v.v ); + + // y := y' + alpha * x + y0v.v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 0*n_elem_per_reg ), + y0v.v + ); + y1v.v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 1*n_elem_per_reg ), + y1v.v + ); + y2v.v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 2*n_elem_per_reg ), + y2v.v + ); + y3v.v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 3*n_elem_per_reg ), + y3v.v + ); + + // storing the output + _mm256_storeu_ps( ( y0 + 0*n_elem_per_reg ), y0v.v ); + _mm256_storeu_ps( ( y0 + 1*n_elem_per_reg ), y1v.v ); + _mm256_storeu_ps( ( y0 + 2*n_elem_per_reg ), y2v.v ); + _mm256_storeu_ps( ( y0 + 3*n_elem_per_reg ), y3v.v ); + + x0 += n_elem_per_reg * n_iter_unroll; + y0 += n_elem_per_reg * n_iter_unroll; + } + + // Issue vzeroupper instruction to clear upper lanes of ymm registers. + // This avoids a performance penalty caused by false dependencies when + // transitioning from from AVX to SSE instructions (which may occur + // as soon as the n_left cleanup loop below if BLIS is compiled with + // -mfpmath=sse). + _mm256_zeroupper(); + + // if there are leftover iterations, perform them with scaler code + for ( ; i < n; ++i ) + { + *y0 = ( (*alpha) * (*x0) ) + ( (*beta) * (*y0) ); + + x0 += incx; + y0 += incy; + } + } + else + { + // for non-unit increments, use scaler code + for ( i = 0; i < n; ++i ) + { + *y0 = ( (*alpha) * (*x0) ) + ( (*beta) * (*y0) ); + + x0 += incx; + y0 += incy; + } + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) +} + +/** + * daxpbyv kernel performs the axpbyv operation. + * y := beta * y + alpha * conjx(x) + * where, + * x & y are double precision vectors of length n. + * alpha & beta are scalers. 
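+ *
+ * The main unrolled loop below handles 16 doubles per iteration (four ymm
+ * registers of four doubles each); the remainder, and the non-unit-stride
+ * case, fall back to scalar code.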
+ */ +void bli_daxpbyv_zen_int + ( + conj_t conjx, + dim_t n, + double* restrict alpha, + double* restrict x, inc_t incx, + double* restrict beta, + double* restrict y, inc_t incy, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_4) + const dim_t n_elem_per_reg = 4; // number of elements per register + const dim_t n_iter_unroll = 4; // number of registers per iteration + + dim_t i; // iterator + + double* restrict x0; + double* restrict y0; + + v4df_t alphav; + v4df_t betav; + v4df_t y0v, y1v, y2v, y3v; + + /* if the vector dimension is zero, or if alpha & beta are zero, + return early. */ + if ( bli_zero_dim1( n ) || + ( PASTEMAC( s, eq0 )( *alpha ) && PASTEMAC( s, eq0 )( *beta ) ) ) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) + return; + } + + // initialize local pointers + x0 = x; + y0 = y; + + if ( incx == 1 && incy == 1 ) + { + // broadcast alpha & beta to all elements of respective vector registers + alphav.v = _mm256_broadcast_sd( alpha ); + betav.v = _mm256_broadcast_sd( beta ); + + // unrolling and vectorizing + for ( i = 0; ( i + 15 ) < n; i += 16 ) + { + // loading input y + y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + y1v.v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + y2v.v = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); + y3v.v = _mm256_loadu_pd( y0 + 3*n_elem_per_reg ); + + // y' := y := beta * y + y0v.v = _mm256_mul_pd( betav.v, y0v.v ); + y1v.v = _mm256_mul_pd( betav.v, y1v.v ); + y2v.v = _mm256_mul_pd( betav.v, y2v.v ); + y3v.v = _mm256_mul_pd( betav.v, y3v.v ); + + // y := y' + alpha * x + // := beta * y + alpha * x + y0v.v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 0*n_elem_per_reg ), + y0v.v + ); + y1v.v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 1*n_elem_per_reg ), + y1v.v + ); + y2v.v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 2*n_elem_per_reg ), + y2v.v + ); + y3v.v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 3*n_elem_per_reg ), + y3v.v + ); + + // storing the output + _mm256_storeu_pd( ( y0 + 0*n_elem_per_reg ), y0v.v ); + _mm256_storeu_pd( ( y0 + 1*n_elem_per_reg ), y1v.v ); + _mm256_storeu_pd( ( y0 + 2*n_elem_per_reg ), y2v.v ); + _mm256_storeu_pd( ( y0 + 3*n_elem_per_reg ), y3v.v ); + + x0 += n_elem_per_reg * n_iter_unroll; + y0 += n_elem_per_reg * n_iter_unroll; + } + + // Issue vzeroupper instruction to clear upper lanes of ymm registers. + // This avoids a performance penalty caused by false dependencies when + // transitioning from from AVX to SSE instructions (which may occur + // as soon as the n_left cleanup loop below if BLIS is compiled with + // -mfpmath=sse). + _mm256_zeroupper(); + + // if there are leftover iterations, perform them with scaler code + for ( ; i < n; ++i ) + { + *y0 = ( (*alpha) * (*x0) ) + ( (*beta) * (*y0) ); + + x0 += incx; + y0 += incy; + } + } + else + { + // for non-unit increments, use scaler code + for ( i = 0; i < n; ++i ) + { + *y0 = ( (*alpha) * (*x0) ) + ( (*beta) * (*y0) ); + + x0 += incx; + y0 += incy; + } + } +} + +/** + * caxpbyv kernel performs the axpbyv operation. + * y := beta * y + alpha * conjx(x) + * where, + * x & y are simple complex vectors of length n. + * alpha & beta are scalers. 
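+ *
+ * x & y are stored as interleaved (real, imag) float pairs, so the code
+ * below works through float pointers that advance by two entries per
+ * complex element.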
+ */ +void bli_caxpbyv_zen_int + ( + conj_t conjx, + dim_t n, + scomplex* restrict alpha, + scomplex* restrict x, inc_t incx, + scomplex* restrict beta, + scomplex* restrict y, inc_t incy, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_4) + const dim_t n_elem_per_reg = 8; // number of elements per register + + dim_t i; // iterator + + float* restrict x0; + float* restrict y0; + + float alphaR, alphaI, betaR, betaI; + + __m256 alphaRv; + __m256 alphaIv; + __m256 betaRv; + __m256 betaIv; + __m256 xv[4]; + __m256 yv[4]; + __m256 iv[4]; // intermediate registers + + conj_t conjx_use = conjx; + + /* if the vector dimension is zero, or if alpha & beta are zero, + return early. */ + if ( bli_zero_dim1( n ) || + ( PASTEMAC( c, eq0 )( *alpha ) && PASTEMAC( c, eq0 )( *beta ) ) ) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) + return; + } + + // initialize local pointers + x0 = ( float* ) x; + y0 = ( float* ) y; + + alphaR = alpha->real; + alphaI = alpha->imag; + betaR = beta->real; + betaI = beta->imag; + + if ( incx == 1 && incy == 1 ) + { + //---------- Scalar algorithm BLIS_NO_CONJUGATE ------------- + // y = beta*y + alpha*x + // y = ( bR + ibI ) * ( yR + iyI ) + ( aR + iaI ) * ( xR + ixI ) + // y = bR.yR + ibR.yI + ibI.yR - ibIyI + aR.xR + iaR.xI + iaI.xR - aI.xI + // y = ( bR.yR - bI.yI + aR.xR - aI.xI ) + + // i ( bR.yI + bI.yR + aR.xI + aI.xR ) + + // SIMD Algorithm BLIS_NO_CONJUGATE + // yv = yR1 yI1 yR2 yI2 yR3 yI3 yR4 yI4 + // yv' = yI1 yR1 yI2 yR2 yI3 yR3 yI4 yR4 + // xv = xR1 xI1 xR2 xI2 xR3 xI3 xR4 xI4 + // xv' = xI1 xR1 xI2 xR2 xI3 xR3 xI4 xR4 + // arv = aR aR aR aR aR aR aR aR + // aiv = -aI aI -aI aI -aI aI -aI aI + // brv = bR bR bR bR bR bR bR bR + // biv = -bI bI -bI bI -bI bI -bI bI + + // step 1: iv = brv * iv + // step 2: shuffle yv -> yv' + // step 3: FMA yv = biv * yv' + iv + // step 4: iv = arv * xv + // step 5: shuffle xv -> xv' + // step 6: FMA yv = aiv * xv' + iv + + //---------- Scalar algorithm BLIS_CONJUGATE ------------- + // y = beta*y + alpha*conj(x) + // y = ( bR + ibI ) * ( yR + iyI ) + ( aR + iaI ) * ( xR - ixI ) + // y = bR.yR + ibR.yI + ibI.yR - bI.yI + aR.xR - iaR.xI + iaI.xR + aI.xI + // y = ( bR.yR - bI.yI + aR.xR + aI.xI ) + + // i ( bR.yI + bI.yR - aR.xI + aI.xR ) + + // SIMD Algorithm BLIS_CONJUGATE + // yv = yR1 yI1 yR2 yI2 yR3 yI3 yR4 yI4 + // yv' = yI1 yR1 yI2 yR2 yI3 yR3 yI4 yR4 + // xv = xR1 xI1 xR2 xI2 xR3 xI3 xR4 xI4 + // xv' = xI1 xR1 xI2 xR2 xI3 xR3 xI4 xR4 + // arv = aR -aR aR -aR aR -aR aR -aR + // aiv = aI aI aI aI aI aI aI aI + // brv = bR bR bR bR bR bR bR bR + // biv = -bI bI -bI bI -bI bI -bI bI + // + // step 1: iv = brv * iv + // step 2: shuffle yv -> yv' + // step 3: FMA yv = biv * yv' + iv + // step 4: iv = arv * xv + // step 5: shuffle xv -> xv' + // step 6: FMA yv = aiv * xv' + iv + + // broadcast alpha & beta to all elements of respective vector registers + if ( !bli_is_conj( conjx ) ) // If BLIS_NO_CONJUGATE + { + // alphaRv = aR aR aR aR aR aR aR aR + // alphaIv = -aI aI -aI aI -aI aI -aI aI + // betaRv = bR bR bR bR bR bR bR bR + // betaIv = -bI bI -bI bI -bI bI -bI bI + alphaRv = _mm256_broadcast_ss( &alphaR ); + alphaIv = _mm256_set_ps + ( + alphaI, -alphaI, alphaI, -alphaI, + alphaI, -alphaI, alphaI, -alphaI + ); + betaRv = _mm256_broadcast_ss( &betaR ); + betaIv = _mm256_set_ps + ( + betaI, -betaI, betaI, -betaI, + betaI, -betaI, betaI, -betaI + ); + } + else + { + // alphaRv = aR -aR aR -aR aR -aR aR -aR + // alphaIv = aI aI aI aI aI aI aI aI + // betaRv = bR bR bR bR bR bR bR bR + // betaIv 
= -bI bI -bI bI -bI bI -bI bI + alphaRv = _mm256_set_ps + ( + -alphaR, alphaR, -alphaR, alphaR, + -alphaR, alphaR, -alphaR, alphaR + ); + alphaIv = _mm256_broadcast_ss( &alphaI ); + betaRv = _mm256_broadcast_ss( &betaR ); + betaIv = _mm256_set_ps + ( + betaI, -betaI, betaI, -betaI, + betaI, -betaI, betaI, -betaI + ); + } + + // Processing 16 elements per loop, 8 FMAs + for ( i = 0; ( i + 15 ) < n; i += 16 ) + { + // xv = xR1 xI1 xR2 xI2 xR3 xI3 xR4 xI4 + xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); + xv[2] = _mm256_loadu_ps( x0 + 2*n_elem_per_reg ); + xv[3] = _mm256_loadu_ps( x0 + 3*n_elem_per_reg ); + + // yv = yR1 yI1 yR2 yI2 yR3 yI3 yR4 yI4 + yv[0] = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); + yv[1] = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); + yv[2] = _mm256_loadu_ps( y0 + 2*n_elem_per_reg ); + yv[3] = _mm256_loadu_ps( y0 + 3*n_elem_per_reg ); + + // iv = betaRv * yv + // = yR1.bR, yI1.bR, yR2.bR, yI2.bR, ... + iv[0] = _mm256_mul_ps( betaRv, yv[0] ); + iv[1] = _mm256_mul_ps( betaRv, yv[1] ); + iv[2] = _mm256_mul_ps( betaRv, yv[2] ); + iv[3] = _mm256_mul_ps( betaRv, yv[3] ); + + // yv' = yI1 yR1 yI2 yR2 yI3 yR3 yI4 yR4 + yv[0] = _mm256_permute_ps( yv[0], 0xB1); + yv[1] = _mm256_permute_ps( yv[1], 0xB1); + yv[2] = _mm256_permute_ps( yv[2], 0xB1); + yv[3] = _mm256_permute_ps( yv[3], 0xB1); + + // yv = betaIv * yv' + iv + // = yR1.bR - yI1.bI, yI1.bR + yR1.bI, ... + yv[0] = _mm256_fmadd_ps( betaIv, yv[0], iv[0] ); + yv[1] = _mm256_fmadd_ps( betaIv, yv[1], iv[1] ); + yv[2] = _mm256_fmadd_ps( betaIv, yv[2], iv[2] ); + yv[3] = _mm256_fmadd_ps( betaIv, yv[3], iv[3] ); + + // iv = alphaRv * xv + // = xR1.aR, xI1.aR, xR2.aR, xI2.aR, ... + iv[0] = _mm256_mul_ps( alphaRv, xv[0] ); + iv[1] = _mm256_mul_ps( alphaRv, xv[1] ); + iv[2] = _mm256_mul_ps( alphaRv, xv[2] ); + iv[3] = _mm256_mul_ps( alphaRv, xv[3] ); + + // xv' = xI1 xR1 xI2 xR2 xI3 xR3 xI4 xR4 + xv[0] = _mm256_permute_ps( xv[0], 0xB1); + xv[1] = _mm256_permute_ps( xv[1], 0xB1); + xv[2] = _mm256_permute_ps( xv[2], 0xB1); + xv[3] = _mm256_permute_ps( xv[3], 0xB1); + + // yv = alphaIv * xv + yv + // = yR1.bR - yR1.bI - xR1.aI, yI1.bR + yI1.bI + xI1.aI, ... + yv[0] = _mm256_fmadd_ps( alphaIv, xv[0], yv[0] ); + yv[1] = _mm256_fmadd_ps( alphaIv, xv[1], yv[1] ); + yv[2] = _mm256_fmadd_ps( alphaIv, xv[2], yv[2] ); + yv[3] = _mm256_fmadd_ps( alphaIv, xv[3], yv[3] ); + + _mm256_storeu_ps( (y0 + 0*n_elem_per_reg), yv[0] ); + _mm256_storeu_ps( (y0 + 1*n_elem_per_reg), yv[1] ); + _mm256_storeu_ps( (y0 + 2*n_elem_per_reg), yv[2] ); + _mm256_storeu_ps( (y0 + 3*n_elem_per_reg), yv[3] ); + + y0 += 4*n_elem_per_reg; + x0 += 4*n_elem_per_reg; + } + + // Processing 12 elements per loop, 6 FMAs + for ( ; ( i + 11 ) < n; i += 12 ) + { + // xv = xR1 xI1 xR2 xI2 xR3 xI3 xR4 xI4 + xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); + xv[2] = _mm256_loadu_ps( x0 + 2*n_elem_per_reg ); + + // yv = yR1 yI1 yR2 yI2 yR3 yI3 yR4 yI4 + yv[0] = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); + yv[1] = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); + yv[2] = _mm256_loadu_ps( y0 + 2*n_elem_per_reg ); + + // iv = betaRv * yv + // = yR1.bR, yI1.bR, yR2.bR, yI2.bR, ... 
+ iv[0] = _mm256_mul_ps( betaRv, yv[0] ); + iv[1] = _mm256_mul_ps( betaRv, yv[1] ); + iv[2] = _mm256_mul_ps( betaRv, yv[2] ); + + // yv' = yI1 yR1 yI2 yR2 yI3 yR3 yI4 yR4 + yv[0] = _mm256_permute_ps( yv[0], 0xB1); + yv[1] = _mm256_permute_ps( yv[1], 0xB1); + yv[2] = _mm256_permute_ps( yv[2], 0xB1); + + // yv = betaIv * yv' + iv + // = yR1.bR - yI1.bI, yI1.bR + yR1.bI, ... + yv[0] = _mm256_fmadd_ps( betaIv, yv[0], iv[0] ); + yv[1] = _mm256_fmadd_ps( betaIv, yv[1], iv[1] ); + yv[2] = _mm256_fmadd_ps( betaIv, yv[2], iv[2] ); + + // iv = alphaRv * xv + // = xR1.aR, xI1.aR, xR2.aR, xI2.aR, ... + iv[0] = _mm256_mul_ps( alphaRv, xv[0] ); + iv[1] = _mm256_mul_ps( alphaRv, xv[1] ); + iv[2] = _mm256_mul_ps( alphaRv, xv[2] ); + + // xv' = xI1 xR1 xI2 xR2 xI3 xR3 xI4 xR4 + xv[0] = _mm256_permute_ps( xv[0], 0xB1); + xv[1] = _mm256_permute_ps( xv[1], 0xB1); + xv[2] = _mm256_permute_ps( xv[2], 0xB1); + + // yv = alphaIv * xv + yv + // = yR1.bR - yR1.bI - xR1.aI, yI1.bR + yI1.bI + xI1.aI, ... + yv[0] = _mm256_fmadd_ps( alphaIv, xv[0], yv[0] ); + yv[1] = _mm256_fmadd_ps( alphaIv, xv[1], yv[1] ); + yv[2] = _mm256_fmadd_ps( alphaIv, xv[2], yv[2] ); + + _mm256_storeu_ps( (y0 + 0*n_elem_per_reg), yv[0] ); + _mm256_storeu_ps( (y0 + 1*n_elem_per_reg), yv[1] ); + _mm256_storeu_ps( (y0 + 2*n_elem_per_reg), yv[2] ); + + y0 += 3*n_elem_per_reg; + x0 += 3*n_elem_per_reg; + } + + // Processing 16 elements per loop, 8 FMAs + for ( ; ( i + 7 ) < n; i += 8 ) + { + // xv = xR1 xI1 xR2 xI2 xR3 xI3 xR4 xI4 + xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); + + // yv = yR1 yI1 yR2 yI2 yR3 yI3 yR4 yI4 + yv[0] = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); + yv[1] = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); + + // iv = betaRv * yv + // = yR1.bR, yI1.bR, yR2.bR, yI2.bR, ... + iv[0] = _mm256_mul_ps( betaRv, yv[0] ); + iv[1] = _mm256_mul_ps( betaRv, yv[1] ); + + // yv' = yI1 yR1 yI2 yR2 yI3 yR3 yI4 yR4 + yv[0] = _mm256_permute_ps( yv[0], 0xB1); + yv[1] = _mm256_permute_ps( yv[1], 0xB1); + + // yv = betaIv * yv' + iv + // = yR1.bR - yI1.bI, yI1.bR + yR1.bI, ... + yv[0] = _mm256_fmadd_ps( betaIv, yv[0], iv[0] ); + yv[1] = _mm256_fmadd_ps( betaIv, yv[1], iv[1] ); + + // iv = alphaRv * xv + // = xR1.aR, xI1.aR, xR2.aR, xI2.aR, ... + iv[0] = _mm256_mul_ps( alphaRv, xv[0] ); + iv[1] = _mm256_mul_ps( alphaRv, xv[1] ); + + // xv' = xI1 xR1 xI2 xR2 xI3 xR3 xI4 xR4 + xv[0] = _mm256_permute_ps( xv[0], 0xB1); + xv[1] = _mm256_permute_ps( xv[1], 0xB1); + + // yv = alphaIv * xv + yv + // = yR1.bR - yR1.bI - xR1.aI, yI1.bR + yI1.bI + xI1.aI, ... + yv[0] = _mm256_fmadd_ps( alphaIv, xv[0], yv[0] ); + yv[1] = _mm256_fmadd_ps( alphaIv, xv[1], yv[1] ); + + _mm256_storeu_ps( (y0 + 0*n_elem_per_reg), yv[0] ); + _mm256_storeu_ps( (y0 + 1*n_elem_per_reg), yv[1] ); + + y0 += 2*n_elem_per_reg; + x0 += 2*n_elem_per_reg; + } + + // Issue vzeroupper instruction to clear upper lanes of ymm registers. + // This avoids a performance penalty caused by false dependencies when + // transitioning from from AVX to SSE instructions (which may occur + // as soon as the n_left cleanup loop below if BLIS is compiled with + // -mfpmath=sse). 
+ _mm256_zeroupper(); + + if ( !bli_is_conj( conjx_use ) ) + { + for ( ; i < n ; ++i ) + { + *y0 = ( betaR * (*y0) ) - ( betaI * (*(y0 + 1)) ) + + ( alphaR * (*x0) ) - ( alphaI * (*(x0 + 1)) ); + *(y0 + 1) = ( betaR * (*(y0 + 1)) ) + ( betaI * (*y0) ) + + ( alphaR * (*(x0 + 1)) ) + ( alphaI * (*x0) ); + + x0 += 2; + y0 += 2; + } + } + else + { + for ( ; i < n ; ++i ) + { + *y0 = ( betaR * (*y0) ) - ( betaI * (*(y0 + 1)) ) + + ( alphaR * (*x0) ) + ( alphaI * (*(x0 + 1)) ); + *(y0 + 1) = ( betaR * (*(y0 + 1)) ) + ( betaI * (*y0) ) - + ( alphaR * (*(x0 + 1)) ) + ( alphaI * (*x0) ); + + x0 += 2; + y0 += 2; + } + } + } + else + { + // for non-unit increments, use scaler code + if ( !bli_is_conj( conjx_use ) ) + { + for ( i = 0; i < n ; ++i ) + { + // yReal = ( bR.yR - bI.yI + aR.xR - aI.xI ) + *y0 = ( betaR * (*y0) ) - ( betaI * (*(y0 + 1)) ) + + ( alphaR * (*x0) ) - ( alphaI * (*(x0 + 1)) ); + // yImag = ( bR.yI + bI.yR + aR.xI + aI.xR ) + *(y0 + 1) = ( betaR * (*(y0 + 1)) ) + ( betaI * (*y0) ) + + ( alphaR * (*(x0 + 1)) ) + ( alphaI * (*x0) ); + + x0 += incx * 2; + y0 += incy * 2; + } + } + else + { + for ( i = 0; i < n ; ++i ) + { + // yReal = ( bR.yR - bI.yI + aR.xR - aI.xI ) + *y0 = ( betaR * (*y0) ) - ( betaI * (*(y0 + 1)) ) + + ( alphaR * (*x0) ) + ( alphaI * (*(x0 + 1)) ); + // yImag = ( bR.yI + bI.yR + aR.xI + aI.xR ) + *(y0 + 1) = ( betaR * (*(y0 + 1)) ) + ( betaI * (*y0) ) - + ( alphaR * (*(x0 + 1)) ) + ( alphaI * (*x0) ); + + x0 += incx * 2; + y0 += incy * 2; + } + } + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) +} + +/** + * zaxpbyv kernel performs the axpbyv operation. + * y := beta * y + alpha * conjx(x) + * where, + * x & y are double complex vectors of length n. + * alpha & beta are scalers. + */ +void bli_zaxpbyv_zen_int + ( + conj_t conjx, + dim_t n, + dcomplex* restrict alpha, + dcomplex* restrict x, inc_t incx, + dcomplex* restrict beta, + dcomplex* restrict y, inc_t incy, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_4) + const dim_t n_elem_per_reg = 4; // number of elements per register + + dim_t i; // iterator + + double* restrict x0; + double* restrict y0; + + double alphaR, alphaI, betaR, betaI; + + __m256d alphaRv; + __m256d alphaIv; + __m256d betaRv; + __m256d betaIv; + __m256d xv[4]; + __m256d yv[4]; + __m256d iv[4]; // intermediate registers + + conj_t conjx_use = conjx; + + /* if the vector dimension is zero, or if alpha & beta are zero, + return early. 
*/ + if ( bli_zero_dim1( n ) || + ( PASTEMAC( c, eq0 )( *alpha ) && PASTEMAC( c, eq0 )( *beta ) ) ) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) + return; + } + + // initialize local pointers + x0 = ( double* ) x; + y0 = ( double* ) y; + + alphaR = alpha->real; + alphaI = alpha->imag; + betaR = beta->real; + betaI = beta->imag; + + if ( incx == 1 && incy == 1 ) + { + //---------- Scalar algorithm BLIS_NO_CONJUGATE ------------- + // y = beta*y + alpha*x + // y = ( bR + ibI ) * ( yR + iyI ) + ( aR + iaI ) * ( xR + ixI ) + // y = bR.yR + ibR.yI + ibI.yR - ibIyI + aR.xR + iaR.xI + iaI.xR - aI.xI + // y = ( bR.yR - bI.yI + aR.xR - aI.xI ) + + // i ( bR.yI + bI.yR + aR.xI + aI.xR ) + + // SIMD Algorithm BLIS_NO_CONJUGATE + // yv = yR1 yI1 yR2 yI2 + // yv' = yI1 yR1 yI2 yR2 + // xv = xR1 xI1 xR2 xI2 + // xv' = xI1 xR1 xI2 xR2 + // arv = aR aR aR aR + // aiv = -aI aI -aI aI + // brv = bR bR bR bR + // biv = -bI bI -bI bI + // + // step 1: iv = brv * iv + // step 2: shuffle yv -> yv' + // step 3: FMA yv = biv * yv' + iv + // step 4: iv = arv * xv + // step 5: shuffle xv -> xv' + // step 6: FMA yv = aiv * xv' + iv + + //---------- Scalar algorithm BLIS_CONJUGATE ------------- + // y = beta*y + alpha*conj(x) + // y = ( bR + ibI ) * ( yR + iyI ) + ( aR + iaI ) * ( xR - ixI ) + // y = bR.yR + ibR.yI + ibI.yR - bI.yI + aR.xR - iaR.xI + iaI.xR + aI.xI + // y = ( bR.yR - bI.yI + aR.xR + aI.xI ) + + // i ( bR.yI + bI.yR - aR.xI + aI.xR ) + + // SIMD Algorithm BLIS_CONJUGATE + // yv = yR1 yI1 yR2 yI2 + // yv' = yI1 yR1 yI2 yR2 + // xv = xR1 xI1 xR2 xI2 + // xv' = xI1 xR1 xI2 xR2 + // arv = aR -aR aR -aR + // aiv = aI aI aI aI + // brv = bR bR bR bR + // biv = -bI bI -bI bI + // + // step 1: iv = brv * iv + // step 2: shuffle yv -> yv' + // step 3: FMA yv = biv * yv' + iv + // step 4: iv = arv * xv + // step 5: shuffle xv -> xv' + // step 6: FMA yv = aiv * xv' + iv + + // broadcast alpha & beta to all elements of respective vector registers + if ( !bli_is_conj( conjx ) ) + { + // alphaRv = aR aR aR aR + // alphaIv = -aI aI -aI aI + // betaRv = bR bR bR bR + // betaIv = -bI bI -bI bI + alphaRv = _mm256_broadcast_sd( &alphaR ); + alphaIv = _mm256_set_pd( alphaI, -alphaI, alphaI, -alphaI ); + betaRv = _mm256_broadcast_sd( &betaR ); + betaIv = _mm256_set_pd( betaI, -betaI, betaI, -betaI ); + } + else + { + // alphaRv = aR -aR aR -aR + // alphaIv = aI aI aI aI + // betaRv = bR bR bR bR + // betaIv = -bI bI -bI bI + alphaRv = _mm256_set_pd( -alphaR, alphaR, -alphaR, alphaR ); + alphaIv = _mm256_broadcast_sd( &alphaI ); + betaRv = _mm256_broadcast_sd( &betaR ); + betaIv = _mm256_set_pd( betaI, -betaI, betaI, -betaI ); + } + + // Processing 8 elements per loop, 8 FMAs + for ( i = 0; ( i + 7 ) < n; i += 8 ) + { + // xv = xR1 xI1 xR2 xI2 + xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); + xv[2] = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); + xv[3] = _mm256_loadu_pd( x0 + 3*n_elem_per_reg ); + + // yv = yR1 yI1 yR2 yI2 + yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + yv[1] = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + yv[2] = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); + yv[3] = _mm256_loadu_pd( y0 + 3*n_elem_per_reg ); + + // iv = betaRv * yv + // = yR1.bR, yI1.bR, yR2.bR, yI2.bR, ... 
+ iv[0] = _mm256_mul_pd( betaRv, yv[0] ); + iv[1] = _mm256_mul_pd( betaRv, yv[1] ); + iv[2] = _mm256_mul_pd( betaRv, yv[2] ); + iv[3] = _mm256_mul_pd( betaRv, yv[3] ); + + // yv' = yI1 yR1 yI2 yR2 + yv[0] = _mm256_permute_pd( yv[0], 5); + yv[1] = _mm256_permute_pd( yv[1], 5); + yv[2] = _mm256_permute_pd( yv[2], 5); + yv[3] = _mm256_permute_pd( yv[3], 5); + + // yv = betaIv * yv' + iv + // = yR1.bR - yI1.bI, yI1.bR + yR1.bI, ... + yv[0] = _mm256_fmadd_pd( betaIv, yv[0], iv[0] ); + yv[1] = _mm256_fmadd_pd( betaIv, yv[1], iv[1] ); + yv[2] = _mm256_fmadd_pd( betaIv, yv[2], iv[2] ); + yv[3] = _mm256_fmadd_pd( betaIv, yv[3], iv[3] ); + + // iv = alphaRv * xv + // = xR1.aR, xI1.aR, xR2.aR, xI2.aR, ... + iv[0] = _mm256_mul_pd( alphaRv, xv[0] ); + iv[1] = _mm256_mul_pd( alphaRv, xv[1] ); + iv[2] = _mm256_mul_pd( alphaRv, xv[2] ); + iv[3] = _mm256_mul_pd( alphaRv, xv[3] ); + + // xv' = xI1 xR1 xI2 xR2 + xv[0] = _mm256_permute_pd( xv[0], 5); + xv[1] = _mm256_permute_pd( xv[1], 5); + xv[2] = _mm256_permute_pd( xv[2], 5); + xv[3] = _mm256_permute_pd( xv[3], 5); + + // yv = alphaIv * xv + yv + // = yR1.bR - yR1.bI - xR1.aI, yI1.bR + yI1.bI + xI1.aI, ... + yv[0] = _mm256_fmadd_pd( alphaIv, xv[0], yv[0] ); + yv[1] = _mm256_fmadd_pd( alphaIv, xv[1], yv[1] ); + yv[2] = _mm256_fmadd_pd( alphaIv, xv[2], yv[2] ); + yv[3] = _mm256_fmadd_pd( alphaIv, xv[3], yv[3] ); + + _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), yv[0] ); + _mm256_storeu_pd( (y0 + 1*n_elem_per_reg), yv[1] ); + _mm256_storeu_pd( (y0 + 2*n_elem_per_reg), yv[2] ); + _mm256_storeu_pd( (y0 + 3*n_elem_per_reg), yv[3] ); + + y0 += 4*n_elem_per_reg; + x0 += 4*n_elem_per_reg; + } + + // Processing 6 elements per loop, 6 FMAs + for ( ; ( i + 5 ) < n; i += 6 ) + { + // xv = xR1 xI1 xR2 xI2 + xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); + xv[2] = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); + + // yv = yR1 yI1 yR2 yI2 + yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + yv[1] = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + yv[2] = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); + + // iv = betaRv * yv + // = yR1.bR, yI1.bR, yR2.bR, yI2.bR, ... + iv[0] = _mm256_mul_pd( betaRv, yv[0] ); + iv[1] = _mm256_mul_pd( betaRv, yv[1] ); + iv[2] = _mm256_mul_pd( betaRv, yv[2] ); + + // yv' = yI1 yR1 yI2 yR2 + yv[0] = _mm256_permute_pd( yv[0], 5); + yv[1] = _mm256_permute_pd( yv[1], 5); + yv[2] = _mm256_permute_pd( yv[2], 5); + + // yv = betaIv * yv' + iv + // = yR1.bR - yI1.bI, yI1.bR + yR1.bI, ... + yv[0] = _mm256_fmadd_pd( betaIv, yv[0], iv[0] ); + yv[1] = _mm256_fmadd_pd( betaIv, yv[1], iv[1] ); + yv[2] = _mm256_fmadd_pd( betaIv, yv[2], iv[2] ); + + // iv = alphaRv * xv + // = xR1.aR, xI1.aR, xR2.aR, xI2.aR, ... + iv[0] = _mm256_mul_pd( alphaRv, xv[0] ); + iv[1] = _mm256_mul_pd( alphaRv, xv[1] ); + iv[2] = _mm256_mul_pd( alphaRv, xv[2] ); + + // xv' = xI1 xR1 xI2 xR2 + xv[0] = _mm256_permute_pd( xv[0], 5); + xv[1] = _mm256_permute_pd( xv[1], 5); + xv[2] = _mm256_permute_pd( xv[2], 5); + + // yv = alphaIv * xv + yv + // = yR1.bR - yR1.bI - xR1.aI, yI1.bR + yI1.bI + xI1.aI, ... 
+ yv[0] = _mm256_fmadd_pd( alphaIv, xv[0], yv[0] ); + yv[1] = _mm256_fmadd_pd( alphaIv, xv[1], yv[1] ); + yv[2] = _mm256_fmadd_pd( alphaIv, xv[2], yv[2] ); + + _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), yv[0] ); + _mm256_storeu_pd( (y0 + 1*n_elem_per_reg), yv[1] ); + _mm256_storeu_pd( (y0 + 2*n_elem_per_reg), yv[2] ); + + y0 += 3*n_elem_per_reg; + x0 += 3*n_elem_per_reg; + } + + // Processing 4 elements per loop, 4 FMAs + for ( ; ( i + 3 ) < n; i += 4 ) + { + // xv = xR1 xI1 xR2 xI2 + xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); + + // yv = yR1 yI1 yR2 yI2 + yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + yv[1] = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + + // iv = betaRv * yv + // = yR1.bR, yI1.bR, yR2.bR, yI2.bR, ... + iv[0] = _mm256_mul_pd( betaRv, yv[0] ); + iv[1] = _mm256_mul_pd( betaRv, yv[1] ); + + // yv' = yI1 yR1 yI2 yR2 + yv[0] = _mm256_permute_pd( yv[0], 5); + yv[1] = _mm256_permute_pd( yv[1], 5); + + // yv = betaIv * yv' + iv + // = yR1.bR - yI1.bI, yI1.bR + yR1.bI, ... + yv[0] = _mm256_fmadd_pd( betaIv, yv[0], iv[0] ); + yv[1] = _mm256_fmadd_pd( betaIv, yv[1], iv[1] ); + + // iv = alphaRv * xv + // = xR1.aR, xI1.aR, xR2.aR, xI2.aR, ... + iv[0] = _mm256_mul_pd( alphaRv, xv[0] ); + iv[1] = _mm256_mul_pd( alphaRv, xv[1] ); + + // xv' = xI1 xR1 xI2 xR2 + xv[0] = _mm256_permute_pd( xv[0], 5); + xv[1] = _mm256_permute_pd( xv[1], 5); + + // yv = alphaIv * xv + yv + // = yR1.bR - yR1.bI - xR1.aI, yI1.bR + yI1.bI + xI1.aI, ... + yv[0] = _mm256_fmadd_pd( alphaIv, xv[0], yv[0] ); + yv[1] = _mm256_fmadd_pd( alphaIv, xv[1], yv[1] ); + + _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), yv[0] ); + _mm256_storeu_pd( (y0 + 1*n_elem_per_reg), yv[1] ); + + y0 += 2*n_elem_per_reg; + x0 += 2*n_elem_per_reg; + } + + // Processing 2 elements per loop, 3 FMAs + for ( ; ( i + 1 ) < n; i += 2 ) + { + // xv = xR1 xI1 xR2 xI2 + xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + + // yv = yR1 yI1 yR2 yI2 + yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + + // iv = betaRv * yv + // = yR1.bR, yI1.bR, yR2.bR, yI2.bR, ... + iv[0] = _mm256_mul_pd( betaRv, yv[0] ); + + // yv' = yI1 yR1 yI2 yR2 + yv[0] = _mm256_permute_pd( yv[0], 5); + + // yv = betaIv * yv' + iv + // = yR1.bR - yI1.bI, yI1.bR + yR1.bI, ... + yv[0] = _mm256_fmadd_pd( betaIv, yv[0], iv[0] ); + + // iv = alphaRv * xv + // = xR1.aR, xI1.aR, xR2.aR, xI2.aR, ... + iv[0] = _mm256_mul_pd( alphaRv, xv[0] ); + + // xv' = xI1 xR1 xI2 xR2 + xv[0] = _mm256_permute_pd( xv[0], 5); + + // yv = alphaIv * xv + yv + // = yR1.bR - yR1.bI - xR1.aI, yI1.bR + yI1.bI + xI1.aI, ... + yv[0] = _mm256_fmadd_pd( alphaIv, xv[0], yv[0] ); + + _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), yv[0] ); + + y0 += 1*n_elem_per_reg; + x0 += 1*n_elem_per_reg; + } + + // Issue vzeroupper instruction to clear upper lanes of ymm registers. + // This avoids a performance penalty caused by false dependencies when + // transitioning from from AVX to SSE instructions (which may occur + // as soon as the n_left cleanup loop below if BLIS is compiled with + // -mfpmath=sse). 
+ _mm256_zeroupper(); + + if ( !bli_is_conj( conjx_use ) ) + { + for ( ; i < n ; ++i ) + { + // yReal = ( bR.yR - bI.yI + aR.xR - aI.xI ) + *y0 = ( betaR * (*y0) ) - ( betaI * (*(y0 + 1)) ) + + ( alphaR * (*x0) ) - ( alphaI * (*(x0 + 1)) ); + // yImag = ( bR.yI + bI.yR + aR.xI + aI.xR ) + *(y0 + 1) = ( betaR * (*(y0 + 1)) ) + ( betaI * (*y0) ) + + ( alphaR * (*(x0 + 1)) ) + ( alphaI * (*x0) ); + + x0 += 2; + y0 += 2; + } + } + else + { + for ( ; i < n ; ++i ) + { + // yReal = ( bR.yR - bI.yI + aR.xR - aI.xI ) + *y0 = ( betaR * (*y0) ) - ( betaI * (*(y0 + 1)) ) + + ( alphaR * (*x0) ) + ( alphaI * (*(x0 + 1)) ); + // yImag = ( bR.yI + bI.yR + aR.xI + aI.xR ) + *(y0 + 1) = ( betaR * (*(y0 + 1)) ) + ( betaI * (*y0) ) - + ( alphaR * (*(x0 + 1)) ) + ( alphaI * (*x0) ); + + x0 += 2; + y0 += 2; + } + } + } + else + { + // for non-unit increments, use scaler code + if ( !bli_is_conj( conjx_use ) ) + { + for ( i = 0; i < n ; ++i ) + { + // yReal = ( bR.yR - bI.yI + aR.xR - aI.xI ) + *y0 = ( betaR * (*y0) ) - ( betaI * (*(y0 + 1)) ) + + ( alphaR * (*x0) ) - ( alphaI * (*(x0 + 1)) ); + // yImag = ( bR.yI + bI.yR + aR.xI + aI.xR ) + *(y0 + 1) = ( betaR * (*(y0 + 1)) ) + ( betaI * (*y0) ) + + ( alphaR * (*(x0 + 1)) ) + ( alphaI * (*x0) ); + + x0 += incx * 2; + y0 += incy * 2; + } + } + else + { + for ( i = 0; i < n ; ++i ) + { + // yReal = ( bR.yR - bI.yI + aR.xR - aI.xI ) + *y0 = ( betaR * (*y0) ) - ( betaI * (*(y0 + 1)) ) + + ( alphaR * (*x0) ) + ( alphaI * (*(x0 + 1)) ); + // yImag = ( bR.yI + bI.yR + aR.xI + aI.xR ) + *(y0 + 1) = ( betaR * (*(y0 + 1)) ) + ( betaI * (*y0) ) - + ( alphaR * (*(x0 + 1)) ) + ( alphaI * (*x0) ); + + x0 += incx * 2; + y0 += incy * 2; + } + } + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) +} \ No newline at end of file diff --git a/kernels/zen/1/bli_axpbyv_zen_int10.c b/kernels/zen/1/bli_axpbyv_zen_int10.c new file mode 100644 index 0000000000..bbfdaf0d6a --- /dev/null +++ b/kernels/zen/1/bli_axpbyv_zen_int10.c @@ -0,0 +1,709 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "immintrin.h" +#include "blis.h" + +/* Union DS to access AVX registers */ +/* One 256-bit AVX register holds 8 SP elements */ +typedef union +{ + __m256 v; + float f[8] __attribute__((aligned(64))); +} v8sf_t; + +/* One 256-bit AVX register holds 4 DP elements */ +typedef union +{ + __m256d v; + double d[4] __attribute__((aligned(64))); +} v4df_t; + +/** + * saxpbyv kernel performs the axpbyv operation. + * y := beta * y + alpha * conjx(x) + * where, + * x & y are single precision vectors of length n. + * alpha & beta are scalers. + */ +void bli_saxpbyv_zen_int10 + ( + conj_t conjx, + dim_t n, + float* restrict alpha, + float* restrict x, inc_t incx, + float* restrict beta, + float* restrict y, inc_t incy, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_4) + const dim_t n_elem_per_reg = 8; // number of elements per register + + dim_t i; // iterator + + float* restrict x0; + float* restrict y0; + + v8sf_t alphav; + v8sf_t betav; + v8sf_t yv[10]; + + /* if the vector dimension is zero, or if alpha & beta are zero, + return early. */ + if ( bli_zero_dim1( n ) || + ( PASTEMAC( s, eq0 )( *alpha ) && PASTEMAC( s, eq0 )( *beta ) ) ) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) + return; + } + + // initialize local pointers + x0 = x; + y0 = y; + + if ( incx == 1 && incy == 1 ) + { + // broadcast alpha & beta to all elements of respective vector registers + alphav.v = _mm256_broadcast_ss( alpha ); + betav.v = _mm256_broadcast_ss( beta ); + + // Processing 80 elements per loop, 10 FMAs + for ( i = 0; ( i + 79 ) < n; i += 80 ) + { + // loading input values + yv[0].v = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); + yv[1].v = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); + yv[2].v = _mm256_loadu_ps( y0 + 2*n_elem_per_reg ); + yv[3].v = _mm256_loadu_ps( y0 + 3*n_elem_per_reg ); + yv[4].v = _mm256_loadu_ps( y0 + 4*n_elem_per_reg ); + yv[5].v = _mm256_loadu_ps( y0 + 5*n_elem_per_reg ); + yv[6].v = _mm256_loadu_ps( y0 + 6*n_elem_per_reg ); + yv[7].v = _mm256_loadu_ps( y0 + 7*n_elem_per_reg ); + yv[8].v = _mm256_loadu_ps( y0 + 8*n_elem_per_reg ); + yv[9].v = _mm256_loadu_ps( y0 + 9*n_elem_per_reg ); + + // y' := y := beta * y + yv[0].v = _mm256_mul_ps( betav.v, yv[0].v ); + yv[1].v = _mm256_mul_ps( betav.v, yv[1].v ); + yv[2].v = _mm256_mul_ps( betav.v, yv[2].v ); + yv[3].v = _mm256_mul_ps( betav.v, yv[3].v ); + yv[4].v = _mm256_mul_ps( betav.v, yv[4].v ); + yv[5].v = _mm256_mul_ps( betav.v, yv[5].v ); + yv[6].v = _mm256_mul_ps( betav.v, yv[6].v ); + yv[7].v = _mm256_mul_ps( betav.v, yv[7].v ); + yv[8].v = _mm256_mul_ps( betav.v, yv[8].v ); + yv[9].v = _mm256_mul_ps( betav.v, yv[9].v ); + + // y := y' + alpha * x + yv[0].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 0*n_elem_per_reg ), + yv[0].v + ); + yv[1].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 1*n_elem_per_reg ), + yv[1].v + ); + yv[2].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 2*n_elem_per_reg ), + yv[2].v + ); + yv[3].v = _mm256_fmadd_ps + 
( + alphav.v, + _mm256_loadu_ps( x0 + 3*n_elem_per_reg ), + yv[3].v + ); + yv[4].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 4*n_elem_per_reg ), + yv[4].v + ); + yv[5].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 5*n_elem_per_reg ), + yv[5].v + ); + yv[6].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 6*n_elem_per_reg ), + yv[6].v + ); + yv[7].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 7*n_elem_per_reg ), + yv[7].v + ); + yv[8].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 8*n_elem_per_reg ), + yv[8].v + ); + yv[9].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 9*n_elem_per_reg ), + yv[9].v + ); + + // storing the output + _mm256_storeu_ps( ( y0 + 0*n_elem_per_reg ), yv[0].v ); + _mm256_storeu_ps( ( y0 + 1*n_elem_per_reg ), yv[1].v ); + _mm256_storeu_ps( ( y0 + 2*n_elem_per_reg ), yv[2].v ); + _mm256_storeu_ps( ( y0 + 3*n_elem_per_reg ), yv[3].v ); + _mm256_storeu_ps( ( y0 + 4*n_elem_per_reg ), yv[4].v ); + _mm256_storeu_ps( ( y0 + 5*n_elem_per_reg ), yv[5].v ); + _mm256_storeu_ps( ( y0 + 6*n_elem_per_reg ), yv[6].v ); + _mm256_storeu_ps( ( y0 + 7*n_elem_per_reg ), yv[7].v ); + _mm256_storeu_ps( ( y0 + 8*n_elem_per_reg ), yv[8].v ); + _mm256_storeu_ps( ( y0 + 9*n_elem_per_reg ), yv[9].v ); + + x0 += 10 * n_elem_per_reg; + y0 += 10 * n_elem_per_reg; + } + + // Processing 40 elements per loop, 5 FMAs + for ( ; ( i + 39 ) < n; i += 40 ) + { + // loading input values + yv[0].v = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); + yv[1].v = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); + yv[2].v = _mm256_loadu_ps( y0 + 2*n_elem_per_reg ); + yv[3].v = _mm256_loadu_ps( y0 + 3*n_elem_per_reg ); + yv[4].v = _mm256_loadu_ps( y0 + 4*n_elem_per_reg ); + + // y' := y := beta * y + yv[0].v = _mm256_mul_ps( betav.v, yv[0].v ); + yv[1].v = _mm256_mul_ps( betav.v, yv[1].v ); + yv[2].v = _mm256_mul_ps( betav.v, yv[2].v ); + yv[3].v = _mm256_mul_ps( betav.v, yv[3].v ); + yv[4].v = _mm256_mul_ps( betav.v, yv[4].v ); + + // y := y' + alpha * x + yv[0].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 0*n_elem_per_reg ), + yv[0].v + ); + yv[1].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 1*n_elem_per_reg ), + yv[1].v + ); + yv[2].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 2*n_elem_per_reg ), + yv[2].v + ); + yv[3].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 3*n_elem_per_reg ), + yv[3].v + ); + yv[4].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 4*n_elem_per_reg ), + yv[4].v + ); + + // storing the output + _mm256_storeu_ps( ( y0 + 0*n_elem_per_reg ), yv[0].v ); + _mm256_storeu_ps( ( y0 + 1*n_elem_per_reg ), yv[1].v ); + _mm256_storeu_ps( ( y0 + 2*n_elem_per_reg ), yv[2].v ); + _mm256_storeu_ps( ( y0 + 3*n_elem_per_reg ), yv[3].v ); + _mm256_storeu_ps( ( y0 + 4*n_elem_per_reg ), yv[4].v ); + + x0 += 5 * n_elem_per_reg; + y0 += 5 * n_elem_per_reg; + } + + // Processing 32 elements per loop, 4 FMAs + for ( ; ( i + 31 ) < n; i += 32 ) + { + // loading input values + yv[0].v = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); + yv[1].v = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); + yv[2].v = _mm256_loadu_ps( y0 + 2*n_elem_per_reg ); + yv[3].v = _mm256_loadu_ps( y0 + 3*n_elem_per_reg ); + + // y' := y := beta * y + yv[0].v = _mm256_mul_ps( betav.v, yv[0].v ); + yv[1].v = _mm256_mul_ps( betav.v, yv[1].v ); + yv[2].v = _mm256_mul_ps( betav.v, yv[2].v ); + yv[3].v = _mm256_mul_ps( betav.v, yv[3].v ); + + // y := y' + alpha * x + yv[0].v = _mm256_fmadd_ps + ( + alphav.v, + 
_mm256_loadu_ps( x0 + 0*n_elem_per_reg ), + yv[0].v + ); + yv[1].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 1*n_elem_per_reg ), + yv[1].v + ); + yv[2].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 2*n_elem_per_reg ), + yv[2].v + ); + yv[3].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 3*n_elem_per_reg ), + yv[3].v + ); + + // storing the output + _mm256_storeu_ps( ( y0 + 0*n_elem_per_reg ), yv[0].v ); + _mm256_storeu_ps( ( y0 + 1*n_elem_per_reg ), yv[1].v ); + _mm256_storeu_ps( ( y0 + 2*n_elem_per_reg ), yv[2].v ); + _mm256_storeu_ps( ( y0 + 3*n_elem_per_reg ), yv[3].v ); + + x0 += 4 * n_elem_per_reg; + y0 += 4 * n_elem_per_reg; + } + + // Processing 16 elements per loop, 2 FMAs + for ( ; ( i + 15 ) < n; i += 16 ) + { + // loading input values + yv[0].v = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); + yv[1].v = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); + + // y' := y := beta * y + yv[0].v = _mm256_mul_ps( betav.v, yv[0].v ); + yv[1].v = _mm256_mul_ps( betav.v, yv[1].v ); + + // y := y' + alpha * x + yv[0].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 0*n_elem_per_reg ), + yv[0].v + ); + yv[1].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 1*n_elem_per_reg ), + yv[1].v + ); + + // storing the output + _mm256_storeu_ps( ( y0 + 0*n_elem_per_reg ), yv[0].v ); + _mm256_storeu_ps( ( y0 + 1*n_elem_per_reg ), yv[1].v ); + + x0 += 2 * n_elem_per_reg; + y0 += 2 * n_elem_per_reg; + } + + // Processing 8 elements per loop, 1 FMA + for ( ; ( i + 7 ) < n; i += 8 ) + { + // loading input values + yv[0].v = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); + + // y' := y := beta * y + yv[0].v = _mm256_mul_ps( betav.v, yv[0].v ); + + // y := y' + alpha * x + yv[0].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 0*n_elem_per_reg ), + yv[0].v + ); + + // storing the output + _mm256_storeu_ps( ( y0 + 0*n_elem_per_reg ), yv[0].v ); + + x0 += 1 * n_elem_per_reg; + y0 += 1 * n_elem_per_reg; + } + + // Issue vzeroupper instruction to clear upper lanes of ymm registers. + // This avoids a performance penalty caused by false dependencies when + // transitioning from from AVX to SSE instructions (which may occur + // as soon as the n_left cleanup loop below if BLIS is compiled with + // -mfpmath=sse). + _mm256_zeroupper(); + + // if there are leftover iterations, perform them with scaler code + for ( ; i < n; i++ ) + { + *y0 = ( (*alpha) * (*x0) ) + ( (*beta) * (*y0) ); + + x0 += incx; + y0 += incy; + } + } + else + { + // for non-unit increments, use scaler code + for ( i = 0; i < n; ++i ) + { + *y0 = ( (*alpha) * (*x0) ) + ( (*beta) * (*y0) ); + + x0 += incx; + y0 += incy; + } + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) +} + +/** + * daxpbyv kernel performs the axpbyv operation. + * y := beta * y + alpha * conjx(x) + * where, + * x & y are double precision vectors of length n. + * alpha & beta are scalers. 
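+ *
+ * The unit-stride path below peels the vector work in decreasing chunks of
+ * 10, 5, 2 and 1 ymm registers per iteration (40, 20, 8 and 4 doubles),
+ * with a scalar loop handling any remaining elements and the
+ * non-unit-stride case.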
+ */ +void bli_daxpbyv_zen_int10 + ( + conj_t conjx, + dim_t n, + double* restrict alpha, + double* restrict x, inc_t incx, + double* restrict beta, + double* restrict y, inc_t incy, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_4) + const dim_t n_elem_per_reg = 4; // number of elements per register + const dim_t n_iter_unroll = 10; // number of registers per iteration + + dim_t i; // iterator + + double* restrict x0; + double* restrict y0; + + v4df_t alphav; + v4df_t betav; + v4df_t y0v, y1v, y2v, y3v, y4v, y5v, y6v, y7v, y8v, y9v; + + /* if the vector dimension is zero, or if alpha & beta are zero, + return early. */ + if ( bli_zero_dim1( n ) || + ( PASTEMAC( s, eq0 )( *alpha ) && PASTEMAC( s, eq0 )( *beta ) ) ) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) + return; + } + + // initialize local pointers + x0 = x; + y0 = y; + + if ( incx == 1 && incy == 1 ) + { + // broadcast alpha & beta to all elements of respective vector registers + alphav.v = _mm256_broadcast_sd( alpha ); + betav.v = _mm256_broadcast_sd( beta ); + + // Using 10 FMAs per loop + for ( i = 0; ( i + 39 ) < n; i += 40 ) + { + // loading input y + y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + y1v.v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + y2v.v = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); + y3v.v = _mm256_loadu_pd( y0 + 3*n_elem_per_reg ); + y4v.v = _mm256_loadu_pd( y0 + 4*n_elem_per_reg ); + y5v.v = _mm256_loadu_pd( y0 + 5*n_elem_per_reg ); + y6v.v = _mm256_loadu_pd( y0 + 6*n_elem_per_reg ); + y7v.v = _mm256_loadu_pd( y0 + 7*n_elem_per_reg ); + y8v.v = _mm256_loadu_pd( y0 + 8*n_elem_per_reg ); + y9v.v = _mm256_loadu_pd( y0 + 9*n_elem_per_reg ); + + // y' := y := beta * y + y0v.v = _mm256_mul_pd( betav.v, y0v.v ); + y1v.v = _mm256_mul_pd( betav.v, y1v.v ); + y2v.v = _mm256_mul_pd( betav.v, y2v.v ); + y3v.v = _mm256_mul_pd( betav.v, y3v.v ); + y4v.v = _mm256_mul_pd( betav.v, y4v.v ); + y5v.v = _mm256_mul_pd( betav.v, y5v.v ); + y6v.v = _mm256_mul_pd( betav.v, y6v.v ); + y7v.v = _mm256_mul_pd( betav.v, y7v.v ); + y8v.v = _mm256_mul_pd( betav.v, y8v.v ); + y9v.v = _mm256_mul_pd( betav.v, y9v.v ); + + // y := y' + alpha * x + // := beta * y + alpha * x + y0v.v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 0*n_elem_per_reg ), + y0v.v + ); + y1v.v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 1*n_elem_per_reg ), + y1v.v + ); + y2v.v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 2*n_elem_per_reg ), + y2v.v + ); + y3v.v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 3*n_elem_per_reg ), + y3v.v + ); + y4v.v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 4*n_elem_per_reg ), + y4v.v + ); + y5v.v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 5*n_elem_per_reg ), + y5v.v + ); + y6v.v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 6*n_elem_per_reg ), + y6v.v + ); + y7v.v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 7*n_elem_per_reg ), + y7v.v + ); + y8v.v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 8*n_elem_per_reg ), + y8v.v + ); + y9v.v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 9*n_elem_per_reg ), + y9v.v + ); + + // storing the output + _mm256_storeu_pd( ( y0 + 0*n_elem_per_reg ), y0v.v ); + _mm256_storeu_pd( ( y0 + 1*n_elem_per_reg ), y1v.v ); + _mm256_storeu_pd( ( y0 + 2*n_elem_per_reg ), y2v.v ); + _mm256_storeu_pd( ( y0 + 3*n_elem_per_reg ), y3v.v ); + _mm256_storeu_pd( ( y0 + 4*n_elem_per_reg ), y4v.v ); + _mm256_storeu_pd( ( y0 + 5*n_elem_per_reg ), y5v.v ); + 
_mm256_storeu_pd( ( y0 + 6*n_elem_per_reg ), y6v.v ); + _mm256_storeu_pd( ( y0 + 7*n_elem_per_reg ), y7v.v ); + _mm256_storeu_pd( ( y0 + 8*n_elem_per_reg ), y8v.v ); + _mm256_storeu_pd( ( y0 + 9*n_elem_per_reg ), y9v.v ); + + x0 += n_elem_per_reg * n_iter_unroll; + y0 += n_elem_per_reg * n_iter_unroll; + } + + // Using 5 FMAs per loop + for ( ; ( i + 19 ) < n; i += 20 ) + { + // loading input y + y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + y1v.v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + y2v.v = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); + y3v.v = _mm256_loadu_pd( y0 + 3*n_elem_per_reg ); + y4v.v = _mm256_loadu_pd( y0 + 4*n_elem_per_reg ); + + // y' := y := beta * y + y0v.v = _mm256_mul_pd( betav.v, y0v.v ); + y1v.v = _mm256_mul_pd( betav.v, y1v.v ); + y2v.v = _mm256_mul_pd( betav.v, y2v.v ); + y3v.v = _mm256_mul_pd( betav.v, y3v.v ); + y4v.v = _mm256_mul_pd( betav.v, y4v.v ); + + // y := y' + alpha * x + // := beta * y + alpha * x + y0v.v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 0*n_elem_per_reg ), + y0v.v + ); + y1v.v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 1*n_elem_per_reg ), + y1v.v + ); + y2v.v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 2*n_elem_per_reg ), + y2v.v + ); + y3v.v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 3*n_elem_per_reg ), + y3v.v + ); + y4v.v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 4*n_elem_per_reg ), + y4v.v + ); + + // storing the output + _mm256_storeu_pd( ( y0 + 0*n_elem_per_reg ), y0v.v ); + _mm256_storeu_pd( ( y0 + 1*n_elem_per_reg ), y1v.v ); + _mm256_storeu_pd( ( y0 + 2*n_elem_per_reg ), y2v.v ); + _mm256_storeu_pd( ( y0 + 3*n_elem_per_reg ), y3v.v ); + _mm256_storeu_pd( ( y0 + 4*n_elem_per_reg ), y4v.v ); + + x0 += n_elem_per_reg * 5; + y0 += n_elem_per_reg * 5; + } + + // Using 2 FMAs per loop + for ( ; ( i + 7 ) < n; i += 8 ) + { + // loading input y + y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + y1v.v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + + // y' := y := beta * y + y0v.v = _mm256_mul_pd( betav.v, y0v.v ); + y1v.v = _mm256_mul_pd( betav.v, y1v.v ); + + // y := y' + alpha * x + // := beta * y + alpha * x + y0v.v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 0*n_elem_per_reg ), + y0v.v + ); + y1v.v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 1*n_elem_per_reg ), + y1v.v + ); + + // storing the output + _mm256_storeu_pd( ( y0 + 0*n_elem_per_reg ), y0v.v ); + _mm256_storeu_pd( ( y0 + 1*n_elem_per_reg ), y1v.v ); + + x0 += n_elem_per_reg * 2; + y0 += n_elem_per_reg * 2; + } + + // Using 1 FMAs per loop + for ( ; ( i + 3 ) < n; i += 4 ) + { + // loading input y + y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + + // y' := y := beta * y + y0v.v = _mm256_mul_pd( betav.v, y0v.v ); + + // y := y' + alpha * x + // := beta * y + alpha * x + y0v.v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 0*n_elem_per_reg ), + y0v.v + ); + + // storing the output + _mm256_storeu_pd( ( y0 + 0*n_elem_per_reg ), y0v.v ); + + x0 += n_elem_per_reg * 1; + y0 += n_elem_per_reg * 1; + } + + // Issue vzeroupper instruction to clear upper lanes of ymm registers. + // This avoids a performance penalty caused by false dependencies when + // transitioning from from AVX to SSE instructions (which may occur + // as soon as the n_left cleanup loop below if BLIS is compiled with + // -mfpmath=sse). 
+	_mm256_zeroupper();
+
+	// if there are leftover iterations, perform them with scalar code
+	for ( ; i < n; ++i )
+	{
+		*y0 = ( (*alpha) * (*x0) ) + ( (*beta) * (*y0) );
+
+		x0 += incx;
+		y0 += incy;
+	}
+	}
+	else
+	{
+		// for non-unit increments, use scalar code
+		for ( i = 0; i < n; ++i )
+		{
+			*y0 = ( (*alpha) * (*x0) ) + ( (*beta) * (*y0) );
+
+			x0 += incx;
+			y0 += incy;
+		}
+	}
+	AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4)
+}
\ No newline at end of file
diff --git a/kernels/zen/bli_kernels_zen.h b/kernels/zen/bli_kernels_zen.h
index d46164a9c5..86640dff53 100644
--- a/kernels/zen/bli_kernels_zen.h
+++ b/kernels/zen/bli_kernels_zen.h
@@ -56,6 +56,16 @@ AMAXV_KER_PROT( float, s, amaxv_zen_int_avx512 )
 AMAXV_KER_PROT( double, d, amaxv_zen_int )
 AMAXV_KER_PROT( double, d, amaxv_zen_int_avx512 )
 
+// axpbyv (intrinsics)
+AXPBYV_KER_PROT( float, s, axpbyv_zen_int )
+AXPBYV_KER_PROT( double, d, axpbyv_zen_int )
+AXPBYV_KER_PROT( scomplex, c, axpbyv_zen_int )
+AXPBYV_KER_PROT( dcomplex, z, axpbyv_zen_int )
+
+// axpbyv (intrinsics, unrolled x10)
+AXPBYV_KER_PROT( float, s, axpbyv_zen_int10 )
+AXPBYV_KER_PROT( double, d, axpbyv_zen_int10 )
+
 // axpyv (intrinsics)
 AXPYV_KER_PROT( float, s, axpyv_zen_int )
 AXPYV_KER_PROT( double, d, axpyv_zen_int )

From 0e7073a60027e39af0c0d6c626d0a5d792b92c76 Mon Sep 17 00:00:00 2001
From: Harsh Dave
Date: Tue, 7 Dec 2021 00:56:16 -0600
Subject: [PATCH 061/243] Optimized ztrsv implementation

- Implemented an alternate method of performing multiplication and
  addition on the double-precision complex datatype by separating out
  the real and imaginary parts of the complex numbers.
- Optimal use and reuse of vector registers for faster computation.

AMD-Internal: [CPUPL-1969]
Change-Id: Ib181f193c05740d5f6b9de3930e1995dea4a50f2
---
 kernels/zen/1f/bli_axpyf_zen_int_5.c | 891 ++++++++++++++++-----------
 1 file changed, 528 insertions(+), 363 deletions(-)

diff --git a/kernels/zen/1f/bli_axpyf_zen_int_5.c b/kernels/zen/1f/bli_axpyf_zen_int_5.c
index f770389196..1125197775 100644
--- a/kernels/zen/1f/bli_axpyf_zen_int_5.c
+++ b/kernels/zen/1f/bli_axpyf_zen_int_5.c
@@ -4,7 +4,7 @@
    An object-based framework for developing high-performance BLAS-like
    libraries.
 
-   Copyright (C) 2020 - 21, Advanced Micro Devices, Inc. All rights reserved.
+   Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved.
 
    Redistribution and use in source and binary forms, with or without
    modification, are permitted provided that the following conditions are
@@ -1747,8 +1747,17 @@ void bli_caxpyf_zen_int_5
 }
 
-// -----------------------------------------------------------------------------
-
+//------------------------------------------------------------------------------
+/**
+ * The following kernel performs the axpyf operation on dcomplex data.
+ * It operates over 5 columns of a matrix at a time and marches through
+ * the rows in steps of 4 or 2.
+ * For optimal performance, it separates out the imaginary and real
+ * components of the chis and broadcasts them into separate ymm vector
+ * registers.
+ * By doing so it avoids the permute operation that would otherwise be
+ * needed to assemble the final result of the dcomplex multiplication.
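+ *
+ * The split relies on the complex-multiply identity
+ *   (ar + i*ai) * (xr + i*xi) = (ar*xr - ai*xi) + i*(ar*xi + ai*xr),
+ * so with ar and ai broadcast into separate registers each half can be
+ * accumulated with plain FMAs on the interleaved (real, imag) data.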
+ */ void bli_zaxpyf_zen_int_5 ( conj_t conja, @@ -1762,391 +1771,547 @@ void bli_zaxpyf_zen_int_5 cntx_t* restrict cntx ) { - const dim_t fuse_fac = 5; + const dim_t fuse_fac = 5; - const dim_t n_elem_per_reg = 2; - const dim_t n_iter_unroll = 2; + const dim_t n_elem_per_reg = 2; + const dim_t n_iter_unroll = 2; - dim_t i = 0; - dim_t setPlusOne = 1; + dim_t i = 0; + dim_t setPlusOne = 1; - v4df_t chi0v, chi1v, chi2v, chi3v, chi4v; - v4df_t chi5v, chi6v, chi7v, chi8v, chi9v; + v4df_t chi0v, chi1v, chi2v, chi3v, chi4v; + v4df_t chi5v, chi6v, chi7v, chi8v, chi9v; - v4df_t a00v, a01v, a02v, a03v, a04v; - v4df_t a05v, a06v, a07v, a08v, a09v; + v4df_t a00v, a01v, a02v, a03v, a04v; - v4df_t a10v, a11v, a12v, a13v, a14v; - v4df_t a15v, a16v, a17v, a18v, a19v; + v4df_t a10v, a11v, a12v, a13v, a14v; - v4df_t y0v, y1v; - v4df_t setMinus, setPlus; + v4df_t y0v, y1v, y2v, y3v; + v4df_t r0v, r1v, conjv; - dcomplex chi0, chi1, chi2, chi3, chi4; - dcomplex* restrict a0; - dcomplex* restrict a1; - dcomplex* restrict a2; - dcomplex* restrict a3; - dcomplex* restrict a4; + dcomplex chi0, chi1, chi2, chi3, chi4; + dcomplex* restrict a0; + dcomplex* restrict a1; + dcomplex* restrict a2; + dcomplex* restrict a3; + dcomplex* restrict a4; - dcomplex* restrict y0; + dcomplex* restrict y0; - if ( bli_is_conj(conja) ){ - setPlusOne = -1; - } + if ( bli_is_conj(conja) ){ + setPlusOne = -1; + } - // If either dimension is zero, or if alpha is zero, return early. - if ( bli_zero_dim2( m, b_n ) || bli_zeq0( *alpha ) ) return; + // If either dimension is zero, or if alpha is zero, return early. + if ( bli_zero_dim2( m, b_n ) || bli_zeq0( *alpha ) ) return; - // If b_n is not equal to the fusing factor, then perform the entire - // operation as a loop over axpyv. - if ( b_n != fuse_fac ) - { + // If b_n is not equal to the fusing factor, then perform the entire + // operation as a loop over axpyv. 
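	/*
	 * Roughly, the fallback below amounts to b_n independent axpyv
	 * updates, one per column of A (a sketch of the loop that follows,
	 * not extra code):
	 *
	 *     for j = 0 .. b_n-1:
	 *         alpha_chi = alpha * conjx( x[j*incx] )
	 *         y := y + alpha_chi * conja( A(:,j) )
	 */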
+ if ( b_n != fuse_fac ) + { #ifdef BLIS_CONFIG_EPYC - for ( i = 0; i < b_n; ++i ) - { - dcomplex* a1 = a + (0 )*inca + (i )*lda; - dcomplex* chi1 = x + (i )*incx; - dcomplex* y1 = y + (0 )*incy; - dcomplex alpha_chi1; - - bli_zcopycjs( conjx, *chi1, alpha_chi1 ); - bli_zscals( *alpha, alpha_chi1 ); - - bli_zaxpyv_zen_int5 - ( - conja, - m, - &alpha_chi1, - a1, inca, - y1, incy, - cntx - ); - } + for ( i = 0; i < b_n; ++i ) + { + dcomplex* a1 = a + (0 )*inca + (i )*lda; + dcomplex* chi1 = x + (i )*incx; + dcomplex* y1 = y + (0 )*incy; + dcomplex alpha_chi1; + + bli_zcopycjs( conjx, *chi1, alpha_chi1 ); + bli_zscals( *alpha, alpha_chi1 ); + + bli_zaxpyv_zen_int5 + ( + conja, + m, + &alpha_chi1, + a1, inca, + y1, incy, + cntx + ); + } #else - zaxpyv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_DCOMPLEX, BLIS_AXPYV_KER, cntx ); - - for ( i = 0; i < b_n; ++i ) - { - dcomplex* a1 = a + (0 )*inca + (i )*lda; - dcomplex* chi1 = x + (i )*incx; - dcomplex* y1 = y + (0 )*incy; - dcomplex alpha_chi1; - - bli_zcopycjs( conjx, *chi1, alpha_chi1 ); - bli_zscals( *alpha, alpha_chi1 ); - - f - ( - conja, - m, - &alpha_chi1, - a1, inca, - y1, incy, - cntx - ); - } + zaxpyv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_DCOMPLEX, BLIS_AXPYV_KER, cntx ); + + for ( i = 0; i < b_n; ++i ) + { + dcomplex* a1 = a + (0 )*inca + (i )*lda; + dcomplex* chi1 = x + (i )*incx; + dcomplex* y1 = y + (0 )*incy; + dcomplex alpha_chi1; + + bli_zcopycjs( conjx, *chi1, alpha_chi1 ); + bli_zscals( *alpha, alpha_chi1 ); + + f + ( + conja, + m, + &alpha_chi1, + a1, inca, + y1, incy, + cntx + ); + } #endif - return; - } - - - // At this point, we know that b_n is exactly equal to the fusing factor. - - a0 = a + 0*lda; - a1 = a + 1*lda; - a2 = a + 2*lda; - a3 = a + 3*lda; - a4 = a + 4*lda; - y0 = y; - - chi0 = *( x + 0*incx ); - chi1 = *( x + 1*incx ); - chi2 = *( x + 2*incx ); - chi3 = *( x + 3*incx ); - chi4 = *( x + 4*incx ); - - dcomplex *pchi0 = x + 0*incx ; - dcomplex *pchi1 = x + 1*incx ; - dcomplex *pchi2 = x + 2*incx ; - dcomplex *pchi3 = x + 3*incx ; - dcomplex *pchi4 = x + 4*incx ; - - bli_zcopycjs( conjx, *pchi0, chi0 ); - bli_zcopycjs( conjx, *pchi1, chi1 ); - bli_zcopycjs( conjx, *pchi2, chi2 ); - bli_zcopycjs( conjx, *pchi3, chi3 ); - bli_zcopycjs( conjx, *pchi4, chi4 ); - - // Scale each chi scalar by alpha. - bli_zscals( *alpha, chi0 ); - bli_zscals( *alpha, chi1 ); - bli_zscals( *alpha, chi2 ); - bli_zscals( *alpha, chi3 ); - bli_zscals( *alpha, chi4 ); - - // Broadcast the (alpha*chi?) scalars to all elements of vector registers. - chi0v.v = _mm256_broadcast_sd( &chi0.real ); - chi1v.v = _mm256_broadcast_sd( &chi1.real ); - chi2v.v = _mm256_broadcast_sd( &chi2.real ); - chi3v.v = _mm256_broadcast_sd( &chi3.real ); - chi4v.v = _mm256_broadcast_sd( &chi4.real ); - - chi5v.v = _mm256_broadcast_sd( &chi0.imag ); - chi6v.v = _mm256_broadcast_sd( &chi1.imag ); - chi7v.v = _mm256_broadcast_sd( &chi2.imag ); - chi8v.v = _mm256_broadcast_sd( &chi3.imag ); - chi9v.v = _mm256_broadcast_sd( &chi4.imag ); - - // If there are vectorized iterations, perform them with vector - // instructions. 
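	/*
	 * The 256-bit loads in the vectorized path assume contiguous
	 * storage of A's columns and of y; when either inca or incy is
	 * non-unit, control falls through to the scalar loop at the end
	 * of this function.
	 */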
- if ( inca == 1 && incy == 1 ) - { - setMinus.v = _mm256_set_pd( -1, 1, -1, 1 ); - - setPlus.v = _mm256_set1_pd( 1 ); - if ( bli_is_conj(conja) ){ - setPlus.v = _mm256_set_pd( -1, 1, -1, 1 ); - } - - /* - y := y + alpha * conja(A) * conjx(x) - - nn - (ar + ai) (xr + xi) - ar * xr - ai * xi - ar * xi + ai * xr - - cc : (ar - ai) (xr - xi) - ar * xr - ai * xi - -(ar * xi + ai * xr) - - nc : (ar + ai) (xr - xi) - ar * xr + ai * xi - -(ar * xi - ai * xr) - - cn : (ar - ai) (xr + xi) - ar * xr + ai * xi - ar * xi - ai * xr - - */ - - for( i = 0; (i + 3) < m; i += 4 ) - { - // Load the input values. - y0v.v = _mm256_loadu_pd( (double*) (y0 + 0*n_elem_per_reg )); - y1v.v = _mm256_loadu_pd( (double*) (y0 + 1*n_elem_per_reg )); - - a00v.v = _mm256_loadu_pd( (double*) (a0 + 0*n_elem_per_reg )); - a10v.v = _mm256_loadu_pd( (double*) (a0 + 1*n_elem_per_reg )); - - a01v.v = _mm256_loadu_pd( (double*) (a1 + 0*n_elem_per_reg )); - a11v.v = _mm256_loadu_pd( (double*) (a1 + 1*n_elem_per_reg )); - - a02v.v = _mm256_loadu_pd( (double*) (a2 + 0*n_elem_per_reg )); - a12v.v = _mm256_loadu_pd( (double*) (a2 + 1*n_elem_per_reg )); - - a03v.v = _mm256_loadu_pd( (double*) (a3 + 0*n_elem_per_reg )); - a13v.v = _mm256_loadu_pd( (double*) (a3 + 1*n_elem_per_reg )); - - a04v.v = _mm256_loadu_pd( (double*) (a4 + 0*n_elem_per_reg )); - a14v.v = _mm256_loadu_pd( (double*) (a4 + 1*n_elem_per_reg )); - - a00v.v = _mm256_mul_pd( a00v.v, setPlus.v ); - a01v.v = _mm256_mul_pd( a01v.v, setPlus.v ); - a02v.v = _mm256_mul_pd( a02v.v, setPlus.v ); - a03v.v = _mm256_mul_pd( a03v.v, setPlus.v ); - a04v.v = _mm256_mul_pd( a04v.v, setPlus.v ); - - a05v.v = _mm256_mul_pd( a00v.v, setMinus.v ); - a06v.v = _mm256_mul_pd( a01v.v, setMinus.v ); - a07v.v = _mm256_mul_pd( a02v.v, setMinus.v ); - a08v.v = _mm256_mul_pd( a03v.v, setMinus.v ); - a09v.v = _mm256_mul_pd( a04v.v, setMinus.v ); - - a05v.v = _mm256_permute_pd( a05v.v, 5 ); - a06v.v = _mm256_permute_pd( a06v.v, 5 ); - a07v.v = _mm256_permute_pd( a07v.v, 5 ); - a08v.v = _mm256_permute_pd( a08v.v, 5 ); - a09v.v = _mm256_permute_pd( a09v.v, 5 ); - - a10v.v = _mm256_mul_pd( a10v.v, setPlus.v ); - a11v.v = _mm256_mul_pd( a11v.v, setPlus.v ); - a12v.v = _mm256_mul_pd( a12v.v, setPlus.v ); - a13v.v = _mm256_mul_pd( a13v.v, setPlus.v ); - a14v.v = _mm256_mul_pd( a14v.v, setPlus.v ); - - a15v.v = _mm256_mul_pd( a10v.v, setMinus.v ); - a16v.v = _mm256_mul_pd( a11v.v, setMinus.v ); - a17v.v = _mm256_mul_pd( a12v.v, setMinus.v ); - a18v.v = _mm256_mul_pd( a13v.v, setMinus.v ); - a19v.v = _mm256_mul_pd( a14v.v, setMinus.v ); - - a15v.v = _mm256_permute_pd( a15v.v, 5 ); - a16v.v = _mm256_permute_pd( a16v.v, 5 ); - a17v.v = _mm256_permute_pd( a17v.v, 5 ); - a18v.v = _mm256_permute_pd( a18v.v, 5 ); - a19v.v = _mm256_permute_pd( a19v.v, 5 ); - - // perform : y += alpha * x; - y0v.v = _mm256_fmadd_pd( a00v.v, chi0v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a01v.v, chi1v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a02v.v, chi2v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a03v.v, chi3v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a04v.v, chi4v.v, y0v.v ); - - y0v.v = _mm256_fmadd_pd( a05v.v, chi5v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a06v.v, chi6v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a07v.v, chi7v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a08v.v, chi8v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a09v.v, chi9v.v, y0v.v ); + return; + } + + + // At this point, we know that b_n is exactly equal to the fusing factor. 
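	/*
	 * The vectorized loop below splits each alpha*chi scalar into real
	 * and imaginary parts broadcast into separate registers. For one
	 * element a = (ar, ai) of A and chi = (cr, ci), the lanes combine
	 * as (a sketch of the identity being used):
	 *
	 *     t0 = { ar*cr, ai*cr }                         multiply by cr
	 *     t1 = permute( { ar*ci, ai*ci } ) = { ai*ci, ar*ci }
	 *     y += addsub( t0, t1 ) = { ar*cr - ai*ci, ai*cr + ar*ci }
	 *
	 * which is chi * a; conjugating A only negates ai before the
	 * multiplies.
	 */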
+ + a0 = a + 0*lda; + a1 = a + 1*lda; + a2 = a + 2*lda; + a3 = a + 3*lda; + a4 = a + 4*lda; + y0 = y; + + chi0 = *( x + 0*incx ); + chi1 = *( x + 1*incx ); + chi2 = *( x + 2*incx ); + chi3 = *( x + 3*incx ); + chi4 = *( x + 4*incx ); + + dcomplex *pchi0 = x + 0*incx ; + dcomplex *pchi1 = x + 1*incx ; + dcomplex *pchi2 = x + 2*incx ; + dcomplex *pchi3 = x + 3*incx ; + dcomplex *pchi4 = x + 4*incx ; + + bli_zcopycjs( conjx, *pchi0, chi0 ); + bli_zcopycjs( conjx, *pchi1, chi1 ); + bli_zcopycjs( conjx, *pchi2, chi2 ); + bli_zcopycjs( conjx, *pchi3, chi3 ); + bli_zcopycjs( conjx, *pchi4, chi4 ); + + // Scale each chi scalar by alpha. + bli_zscals( *alpha, chi0 ); + bli_zscals( *alpha, chi1 ); + bli_zscals( *alpha, chi2 ); + bli_zscals( *alpha, chi3 ); + bli_zscals( *alpha, chi4 ); + + // Broadcast the (alpha*chi?) scalars to all elements of vector registers. + chi0v.v = _mm256_broadcast_sd( &chi0.real ); + chi1v.v = _mm256_broadcast_sd( &chi1.real ); + chi2v.v = _mm256_broadcast_sd( &chi2.real ); + chi3v.v = _mm256_broadcast_sd( &chi3.real ); + chi4v.v = _mm256_broadcast_sd( &chi4.real ); + + chi5v.v = _mm256_broadcast_sd( &chi0.imag ); + chi6v.v = _mm256_broadcast_sd( &chi1.imag ); + chi7v.v = _mm256_broadcast_sd( &chi2.imag ); + chi8v.v = _mm256_broadcast_sd( &chi3.imag ); + chi9v.v = _mm256_broadcast_sd( &chi4.imag ); + + // If there are vectorized iterations, perform them with vector + // instructions. + if ( inca == 1 && incy == 1 ) + { + // March through vectors in multiple of 4. + for( i = 0; (i + 3) < m; i += 4 ) + { + // Load the input values. + r0v.v = _mm256_loadu_pd( (double*) (y0 + 0*n_elem_per_reg )); + r1v.v = _mm256_loadu_pd( (double*) (y0 + 1*n_elem_per_reg )); + + y0v.v = _mm256_setzero_pd(); + y1v.v = _mm256_setzero_pd(); + y2v.v = _mm256_setzero_pd(); + y3v.v = _mm256_setzero_pd(); + + if ( bli_is_conj(conja) ){ + /** + * For conjugate cases imaginary part + * is negated. 
+ */ + conjv.v = _mm256_set_pd( -1, 1, -1, 1 ); + a00v.v = _mm256_loadu_pd( (double*) (a0 + 0*n_elem_per_reg )); + a10v.v = _mm256_loadu_pd( (double*) (a0 + 1*n_elem_per_reg )); + + a01v.v = _mm256_loadu_pd( (double*) (a1 + 0*n_elem_per_reg )); + a11v.v = _mm256_loadu_pd( (double*) (a1 + 1*n_elem_per_reg )); + + a02v.v = _mm256_loadu_pd( (double*) (a2 + 0*n_elem_per_reg )); + a12v.v = _mm256_loadu_pd( (double*) (a2 + 1*n_elem_per_reg )); + + a03v.v = _mm256_loadu_pd( (double*) (a3 + 0*n_elem_per_reg )); + a13v.v = _mm256_loadu_pd( (double*) (a3 + 1*n_elem_per_reg )); + + a04v.v = _mm256_loadu_pd( (double*) (a4 + 0*n_elem_per_reg )); + a14v.v = _mm256_loadu_pd( (double*) (a4 + 1*n_elem_per_reg )); + + a00v.v = _mm256_mul_pd(a00v.v, conjv.v); + a10v.v = _mm256_mul_pd(a10v.v, conjv.v); + a01v.v = _mm256_mul_pd(a01v.v, conjv.v); + a11v.v = _mm256_mul_pd(a11v.v, conjv.v); + a02v.v = _mm256_mul_pd(a02v.v, conjv.v); + a12v.v = _mm256_mul_pd(a12v.v, conjv.v); + a03v.v = _mm256_mul_pd(a03v.v, conjv.v); + a13v.v = _mm256_mul_pd(a13v.v, conjv.v); + a04v.v = _mm256_mul_pd(a04v.v, conjv.v); + a14v.v = _mm256_mul_pd(a14v.v, conjv.v); + } + else + { + a00v.v = _mm256_loadu_pd( (double*) (a0 + 0*n_elem_per_reg )); + a10v.v = _mm256_loadu_pd( (double*) (a0 + 1*n_elem_per_reg )); + + a01v.v = _mm256_loadu_pd( (double*) (a1 + 0*n_elem_per_reg )); + a11v.v = _mm256_loadu_pd( (double*) (a1 + 1*n_elem_per_reg )); + + a02v.v = _mm256_loadu_pd( (double*) (a2 + 0*n_elem_per_reg )); + a12v.v = _mm256_loadu_pd( (double*) (a2 + 1*n_elem_per_reg )); + + a03v.v = _mm256_loadu_pd( (double*) (a3 + 0*n_elem_per_reg )); + a13v.v = _mm256_loadu_pd( (double*) (a3 + 1*n_elem_per_reg )); + + a04v.v = _mm256_loadu_pd( (double*) (a4 + 0*n_elem_per_reg )); + a14v.v = _mm256_loadu_pd( (double*) (a4 + 1*n_elem_per_reg )); + + } + + // perform : y += alpha * x; + /** + * chi[x]v.v holds real part of chi. + * chi[x]v.v holds imag part of chi. + * ys holds following computation: + * + * a[xx]v.v R1 I1 R2 I2 + * chi[x]v.v chi_R chi_R chi_R chi_R + * chi[x]v.v chi_I chi_I chi_I chi_I + * y[x]v.v R1*chi_R I1*chi_R R2*chi_R I2*chiR (compute with chi-real part) + * y[x]v.v R1*chi_I I1*chi_I R2*chi_I I2*chiI (compute with chi-imag part) + * + */ + y0v.v = _mm256_mul_pd( a00v.v, chi0v.v); + y1v.v = _mm256_mul_pd( a10v.v, chi0v.v); + + y2v.v = _mm256_mul_pd( a00v.v, chi5v.v); + y3v.v = _mm256_mul_pd( a10v.v, chi5v.v); + + /** + * y0v.v & y1v.v holds computation with real part of chi. + * y2v.v & y3v.v holds computaion with imag part of chi. + * Permute will swap the positions of elements in y2v.v & y3v.v + * as we need to perform: [ R*R + I*I & R*I + I*R]. + * Once dcomplex multiplication is done add the result into r0v.v + * r1v.v which holds axpy result of current tile which is being + * computed. + */ + y2v.v = _mm256_permute_pd(y2v.v, 0x5); + y3v.v = _mm256_permute_pd(y3v.v, 0x5); + y0v.v = _mm256_addsub_pd(y0v.v, y2v.v); + y1v.v = _mm256_addsub_pd(y1v.v, y3v.v); + + r0v.v = _mm256_add_pd(y0v.v, r0v.v); + r1v.v = _mm256_add_pd(y1v.v, r1v.v); + + y0v.v = _mm256_setzero_pd(); + y1v.v = _mm256_setzero_pd(); + y2v.v = _mm256_setzero_pd(); + y3v.v = _mm256_setzero_pd(); + + /** + * Repeat the same computation as above + * for remaining tile. 
+ */ + y0v.v = _mm256_mul_pd( a01v.v, chi1v.v ); + y1v.v = _mm256_mul_pd( a11v.v, chi1v.v ); + + y2v.v = _mm256_mul_pd( a01v.v, chi6v.v ); + y3v.v = _mm256_mul_pd( a11v.v, chi6v.v ); + + y2v.v = _mm256_permute_pd(y2v.v, 0x5); + y3v.v = _mm256_permute_pd(y3v.v, 0x5); + y0v.v = _mm256_addsub_pd(y0v.v, y2v.v); + y1v.v = _mm256_addsub_pd(y1v.v, y3v.v); + + r0v.v = _mm256_add_pd(y0v.v, r0v.v); + r1v.v = _mm256_add_pd(y1v.v, r1v.v); + + y0v.v = _mm256_setzero_pd(); + y1v.v = _mm256_setzero_pd(); + y2v.v = _mm256_setzero_pd(); + y3v.v = _mm256_setzero_pd(); + + + y0v.v = _mm256_mul_pd( a02v.v, chi2v.v); + y1v.v = _mm256_mul_pd( a12v.v, chi2v.v); + + y2v.v = _mm256_mul_pd( a02v.v, chi7v.v ); + y3v.v = _mm256_mul_pd( a12v.v, chi7v.v ); + + y2v.v = _mm256_permute_pd(y2v.v, 0x5); + y3v.v = _mm256_permute_pd(y3v.v, 0x5); + y0v.v = _mm256_addsub_pd(y0v.v, y2v.v); + y1v.v = _mm256_addsub_pd(y1v.v, y3v.v); + + r0v.v = _mm256_add_pd(y0v.v, r0v.v); + r1v.v = _mm256_add_pd(y1v.v, r1v.v); + + y0v.v = _mm256_setzero_pd(); + y1v.v = _mm256_setzero_pd(); + y2v.v = _mm256_setzero_pd(); + y3v.v = _mm256_setzero_pd(); + + + y0v.v = _mm256_mul_pd( a03v.v, chi3v.v ); + y1v.v = _mm256_mul_pd( a13v.v, chi3v.v ); + + y2v.v = _mm256_mul_pd( a03v.v, chi8v.v ); + y3v.v = _mm256_mul_pd( a13v.v, chi8v.v ); + + y2v.v = _mm256_permute_pd(y2v.v, 0x5); + y3v.v = _mm256_permute_pd(y3v.v, 0x5); + y0v.v = _mm256_addsub_pd(y0v.v, y2v.v); + y1v.v = _mm256_addsub_pd(y1v.v, y3v.v); + + r0v.v = _mm256_add_pd(y0v.v, r0v.v); + r1v.v = _mm256_add_pd(y1v.v, r1v.v); + + y0v.v = _mm256_setzero_pd(); + y1v.v = _mm256_setzero_pd(); + y2v.v = _mm256_setzero_pd(); + y3v.v = _mm256_setzero_pd(); + + + y0v.v = _mm256_mul_pd( a04v.v, chi4v.v ); + y1v.v = _mm256_mul_pd( a14v.v, chi4v.v ); + + y2v.v = _mm256_mul_pd( a04v.v, chi9v.v ); + y3v.v = _mm256_mul_pd( a14v.v, chi9v.v ); + + y2v.v = _mm256_permute_pd(y2v.v, 0x5); + y3v.v = _mm256_permute_pd(y3v.v, 0x5); + y0v.v = _mm256_addsub_pd(y0v.v, y2v.v); + y1v.v = _mm256_addsub_pd(y1v.v, y3v.v); + + r0v.v = _mm256_add_pd(y0v.v, r0v.v); + r1v.v = _mm256_add_pd(y1v.v, r1v.v); + + /** + * Final axpy compuation is available in r0v.v + * and r1v.v registers. + * Store it back into y vector. + */ + _mm256_storeu_pd( (double*) (y0 + 0*n_elem_per_reg), r0v.v ); + _mm256_storeu_pd( (double*) (y0 + 1*n_elem_per_reg), r1v.v ); + + /** + * Set the pointers next vectors elements to be + * computed based on unroll factor. + */ + y0 += n_elem_per_reg * n_iter_unroll; + a0 += n_elem_per_reg * n_iter_unroll; + a1 += n_elem_per_reg * n_iter_unroll; + a2 += n_elem_per_reg * n_iter_unroll; + a3 += n_elem_per_reg * n_iter_unroll; + a4 += n_elem_per_reg * n_iter_unroll; + } + // March through vectors in multiple of 2. 
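	/*
	 * The tail below consumes one 256-bit register (two dcomplex
	 * elements) per column per pass, applying the same
	 * mul/permute/addsub pattern, so each pointer advances by two
	 * elements per iteration.
	 */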
+ for( ; (i + 1) < m; i += 2 ) + { + r0v.v = _mm256_loadu_pd( (double*) (y0 + 0*n_elem_per_reg )); + + y0v.v = _mm256_setzero_pd(); + y2v.v = _mm256_setzero_pd(); - // For next 4 elements perform : y += alpha * x; - y1v.v = _mm256_fmadd_pd( a10v.v, chi0v.v, y1v.v ); - y1v.v = _mm256_fmadd_pd( a11v.v, chi1v.v, y1v.v ); - y1v.v = _mm256_fmadd_pd( a12v.v, chi2v.v, y1v.v ); - y1v.v = _mm256_fmadd_pd( a13v.v, chi3v.v, y1v.v ); - y1v.v = _mm256_fmadd_pd( a14v.v, chi4v.v, y1v.v ); + if ( bli_is_conj(conja) ){ + conjv.v = _mm256_set_pd( -1, 1, -1, 1 ); + a00v.v = _mm256_loadu_pd( (double*) (a0 + 0*n_elem_per_reg )); - y1v.v = _mm256_fmadd_pd( a15v.v, chi5v.v, y1v.v ); - y1v.v = _mm256_fmadd_pd( a16v.v, chi6v.v, y1v.v ); - y1v.v = _mm256_fmadd_pd( a17v.v, chi7v.v, y1v.v ); - y1v.v = _mm256_fmadd_pd( a18v.v, chi8v.v, y1v.v ); - y1v.v = _mm256_fmadd_pd( a19v.v, chi9v.v, y1v.v ); + a01v.v = _mm256_loadu_pd( (double*) (a1 + 0*n_elem_per_reg )); - // Store the output. - _mm256_storeu_pd( (double*) (y0 + 0*n_elem_per_reg), y0v.v ); - _mm256_storeu_pd( (double*) (y0 + 1*n_elem_per_reg), y1v.v ); + a02v.v = _mm256_loadu_pd( (double*) (a2 + 0*n_elem_per_reg )); - y0 += n_elem_per_reg * n_iter_unroll; - a0 += n_elem_per_reg * n_iter_unroll; - a1 += n_elem_per_reg * n_iter_unroll; - a2 += n_elem_per_reg * n_iter_unroll; - a3 += n_elem_per_reg * n_iter_unroll; - a4 += n_elem_per_reg * n_iter_unroll; - } - for( ; (i + 1) < m; i += 2 ) - { - // Load the input values. - y0v.v = _mm256_loadu_pd( (double*) (y0 + 0*n_elem_per_reg )); - - a00v.v = _mm256_loadu_pd( (double*)(a0 + 0*n_elem_per_reg) ); - a01v.v = _mm256_loadu_pd( (double*)(a1 + 0*n_elem_per_reg) ); - a02v.v = _mm256_loadu_pd( (double*)(a2 + 0*n_elem_per_reg) ); - a03v.v = _mm256_loadu_pd( (double*)(a3 + 0*n_elem_per_reg) ); - a04v.v = _mm256_loadu_pd( (double*)(a4 + 0*n_elem_per_reg) ); - - a00v.v = _mm256_mul_pd( a00v.v, setPlus.v ); - a01v.v = _mm256_mul_pd( a01v.v, setPlus.v ); - a02v.v = _mm256_mul_pd( a02v.v, setPlus.v ); - a03v.v = _mm256_mul_pd( a03v.v, setPlus.v ); - a04v.v = _mm256_mul_pd( a04v.v, setPlus.v ); - - a05v.v = _mm256_mul_pd( a00v.v, setMinus.v ); - a06v.v = _mm256_mul_pd( a01v.v, setMinus.v ); - a07v.v = _mm256_mul_pd( a02v.v, setMinus.v ); - a08v.v = _mm256_mul_pd( a03v.v, setMinus.v ); - a09v.v = _mm256_mul_pd( a04v.v, setMinus.v ); - - a05v.v = _mm256_permute_pd( a05v.v, 5 ); - a06v.v = _mm256_permute_pd( a06v.v, 5 ); - a07v.v = _mm256_permute_pd( a07v.v, 5 ); - a08v.v = _mm256_permute_pd( a08v.v, 5 ); - a09v.v = _mm256_permute_pd( a09v.v, 5 ); + a03v.v = _mm256_loadu_pd( (double*) (a3 + 0*n_elem_per_reg )); - // perform : y += alpha * x; - y0v.v = _mm256_fmadd_pd( a00v.v, chi0v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a01v.v, chi1v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a02v.v, chi2v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a03v.v, chi3v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a04v.v, chi4v.v, y0v.v ); + a04v.v = _mm256_loadu_pd( (double*) (a4 + 0*n_elem_per_reg )); - y0v.v = _mm256_fmadd_pd( a05v.v, chi5v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a06v.v, chi6v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a07v.v, chi7v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a08v.v, chi8v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a09v.v, chi9v.v, y0v.v ); + a00v.v = _mm256_mul_pd(a00v.v, conjv.v); + a01v.v = _mm256_mul_pd(a01v.v, conjv.v); + a02v.v = _mm256_mul_pd(a02v.v, conjv.v); + a03v.v = _mm256_mul_pd(a03v.v, conjv.v); + a04v.v = _mm256_mul_pd(a04v.v, conjv.v); + } + else + { + a00v.v = _mm256_loadu_pd( (double*) (a0 + 0*n_elem_per_reg )); - // Store the 
output. - _mm256_storeu_pd( (double *)(y0 + 0*n_elem_per_reg), y0v.v ); + a01v.v = _mm256_loadu_pd( (double*) (a1 + 0*n_elem_per_reg )); - y0 += n_elem_per_reg ; - a0 += n_elem_per_reg ; - a1 += n_elem_per_reg ; - a2 += n_elem_per_reg ; - a3 += n_elem_per_reg ; - a4 += n_elem_per_reg ; - } - // If there are leftover iterations, perform them with scalar code. - for ( ; (i + 0) < m ; ++i ) - { - dcomplex y0c = *y0; - - const dcomplex a0c = *a0; - const dcomplex a1c = *a1; - const dcomplex a2c = *a2; - const dcomplex a3c = *a3; - const dcomplex a4c = *a4; - - y0c.real += chi0.real * a0c.real - chi0.imag * a0c.imag * setPlusOne; - y0c.real += chi1.real * a1c.real - chi1.imag * a1c.imag * setPlusOne; - y0c.real += chi2.real * a2c.real - chi2.imag * a2c.imag * setPlusOne; - y0c.real += chi3.real * a3c.real - chi3.imag * a3c.imag * setPlusOne; - y0c.real += chi4.real * a4c.real - chi4.imag * a4c.imag * setPlusOne; + a02v.v = _mm256_loadu_pd( (double*) (a2 + 0*n_elem_per_reg )); + + a03v.v = _mm256_loadu_pd( (double*) (a3 + 0*n_elem_per_reg )); + + a04v.v = _mm256_loadu_pd( (double*) (a4 + 0*n_elem_per_reg )); - y0c.imag += chi0.imag * a0c.real + chi0.real * a0c.imag * setPlusOne; - y0c.imag += chi1.imag * a1c.real + chi1.real * a1c.imag * setPlusOne; - y0c.imag += chi2.imag * a2c.real + chi2.real * a2c.imag * setPlusOne; - y0c.imag += chi3.imag * a3c.real + chi3.real * a3c.imag * setPlusOne; - y0c.imag += chi4.imag * a4c.real + chi4.real * a4c.imag * setPlusOne; - - *y0 = y0c; - - a0 += 1; - a1 += 1; - a2 += 1; - a3 += 1; - a4 += 1; - y0 += 1; - } - } - else - { - for ( ; (i + 0) < m ; ++i ) - { - dcomplex y0c = *y0; - - const dcomplex a0c = *a0; - const dcomplex a1c = *a1; - const dcomplex a2c = *a2; - const dcomplex a3c = *a3; - const dcomplex a4c = *a4; - - y0c.real += chi0.real * a0c.real - chi0.imag * a0c.imag * setPlusOne; - y0c.real += chi1.real * a1c.real - chi1.imag * a1c.imag * setPlusOne; - y0c.real += chi2.real * a2c.real - chi2.imag * a2c.imag * setPlusOne; - y0c.real += chi3.real * a3c.real - chi3.imag * a3c.imag * setPlusOne; - y0c.real += chi4.real * a4c.real - chi4.imag * a4c.imag * setPlusOne; - - y0c.imag += chi0.imag * a0c.real + chi0.real * a0c.imag * setPlusOne; - y0c.imag += chi1.imag * a1c.real + chi1.real * a1c.imag * setPlusOne; - y0c.imag += chi2.imag * a2c.real + chi2.real * a2c.imag * setPlusOne; - y0c.imag += chi3.imag * a3c.real + chi3.real * a3c.imag * setPlusOne; - y0c.imag += chi4.imag * a4c.real + chi4.real * a4c.imag * setPlusOne; - - *y0 = y0c; - - a0 += inca; - a1 += inca; - a2 += inca; - a3 += inca; - a4 += inca; - y0 += incy; - } - - } + } + + // perform : y += alpha * x; + /** + * chi[x]v.v holds real part of chi. + * chi[x]v.v holds imag part of chi. + * ys holds following computation: + * + * a[xx]v.v R1 I1 R2 I2 + * chi[x]v.v chi_R chi_R chi_R chi_R + * chi[x]v.v chi_I chi_I chi_I chi_I + * y[x]v.v R1*chi_R I1*chi_R R2*chi_R I2*chiR (compute with chi-real part) + * y[x]v.v R1*chi_I I1*chi_I R2*chi_I I2*chiI (compute with chi-imag part) + * + */ + y0v.v = _mm256_mul_pd( a00v.v, chi0v.v ); + y2v.v = _mm256_mul_pd( a00v.v, chi5v.v ); + + /** + * y0v.v holds computation with real part of chi. + * y2v.v holds computaion with imag part of chi. + * Permute will swap the positions of elements in y2v.v. + * as we need to perform: [ R*R + I*I & R*I + I*R]. + * Once dcomplex multiplication is done add the result into r0v.v + * which holds axpy result of current tile which is being + * computed. 
+ */ + y2v.v = _mm256_permute_pd(y2v.v, 0x5); + y0v.v = _mm256_addsub_pd(y0v.v, y2v.v); + r0v.v = _mm256_add_pd(y0v.v, r0v.v); + + y0v.v = _mm256_setzero_pd(); + y2v.v = _mm256_setzero_pd(); + + /** + * Repeat the same computation as above + * for remaining tile. + */ + y0v.v = _mm256_mul_pd( a01v.v, chi1v.v ); + y2v.v = _mm256_mul_pd( a01v.v, chi6v.v ); + + y2v.v = _mm256_permute_pd(y2v.v, 0x5); + y0v.v = _mm256_addsub_pd(y0v.v, y2v.v); + r0v.v = _mm256_add_pd(y0v.v, r0v.v); + + y0v.v = _mm256_setzero_pd(); + y2v.v = _mm256_setzero_pd(); + + + y0v.v = _mm256_mul_pd( a02v.v, chi2v.v ); + y2v.v = _mm256_mul_pd( a02v.v, chi7v.v ); + + y2v.v = _mm256_permute_pd(y2v.v, 0x5); + y0v.v = _mm256_addsub_pd(y0v.v, y2v.v); + r0v.v = _mm256_add_pd(y0v.v, r0v.v); + + y0v.v = _mm256_setzero_pd(); + y2v.v = _mm256_setzero_pd(); + + + y0v.v = _mm256_mul_pd( a03v.v, chi3v.v ); + y2v.v = _mm256_mul_pd( a03v.v, chi8v.v ); + + y2v.v = _mm256_permute_pd(y2v.v, 0x5); + y0v.v = _mm256_addsub_pd(y0v.v, y2v.v); + r0v.v = _mm256_add_pd(y0v.v, r0v.v); + + y0v.v = _mm256_setzero_pd(); + y2v.v = _mm256_setzero_pd(); + + + y0v.v = _mm256_mul_pd( a04v.v, chi4v.v ); + y2v.v = _mm256_mul_pd( a04v.v, chi9v.v ); + + + y2v.v = _mm256_permute_pd(y2v.v, 0x5); + y0v.v = _mm256_addsub_pd(y0v.v, y2v.v); + r0v.v = _mm256_add_pd(y0v.v, r0v.v); + + /** + * Final axpy compuation is available in r0v.v + * Store it back into y vector. + */ + _mm256_storeu_pd( (double*) (y0 + 0*n_elem_per_reg), r0v.v ); + + y0 += n_iter_unroll; + a0 += n_iter_unroll; + a1 += n_iter_unroll; + a2 += n_iter_unroll; + a3 += n_iter_unroll; + a4 += n_iter_unroll; + + } + + // If there are leftover iterations, perform them with scalar code. + for ( ; (i + 0) < m ; ++i ) + { + dcomplex y0c = *y0; + + const dcomplex a0c = *a0; + const dcomplex a1c = *a1; + const dcomplex a2c = *a2; + const dcomplex a3c = *a3; + const dcomplex a4c = *a4; + + y0c.real += chi0.real * a0c.real - chi0.imag * a0c.imag * setPlusOne; + y0c.real += chi1.real * a1c.real - chi1.imag * a1c.imag * setPlusOne; + y0c.real += chi2.real * a2c.real - chi2.imag * a2c.imag * setPlusOne; + y0c.real += chi3.real * a3c.real - chi3.imag * a3c.imag * setPlusOne; + y0c.real += chi4.real * a4c.real - chi4.imag * a4c.imag * setPlusOne; + + y0c.imag += chi0.imag * a0c.real + chi0.real * a0c.imag * setPlusOne; + y0c.imag += chi1.imag * a1c.real + chi1.real * a1c.imag * setPlusOne; + y0c.imag += chi2.imag * a2c.real + chi2.real * a2c.imag * setPlusOne; + y0c.imag += chi3.imag * a3c.real + chi3.real * a3c.imag * setPlusOne; + y0c.imag += chi4.imag * a4c.real + chi4.real * a4c.imag * setPlusOne; + + *y0 = y0c; + + a0 += 1; + a1 += 1; + a2 += 1; + a3 += 1; + a4 += 1; + y0 += 1; + } + } + else + { + for ( ; (i + 0) < m ; ++i ) + { + dcomplex y0c = *y0; + + const dcomplex a0c = *a0; + const dcomplex a1c = *a1; + const dcomplex a2c = *a2; + const dcomplex a3c = *a3; + const dcomplex a4c = *a4; + + y0c.real += chi0.real * a0c.real - chi0.imag * a0c.imag * setPlusOne; + y0c.real += chi1.real * a1c.real - chi1.imag * a1c.imag * setPlusOne; + y0c.real += chi2.real * a2c.real - chi2.imag * a2c.imag * setPlusOne; + y0c.real += chi3.real * a3c.real - chi3.imag * a3c.imag * setPlusOne; + y0c.real += chi4.real * a4c.real - chi4.imag * a4c.imag * setPlusOne; + + y0c.imag += chi0.imag * a0c.real + chi0.real * a0c.imag * setPlusOne; + y0c.imag += chi1.imag * a1c.real + chi1.real * a1c.imag * setPlusOne; + y0c.imag += chi2.imag * a2c.real + chi2.real * a2c.imag * setPlusOne; + y0c.imag += chi3.imag * a3c.real + chi3.real * 
a3c.imag * setPlusOne; + y0c.imag += chi4.imag * a4c.real + chi4.real * a4c.imag * setPlusOne; + + *y0 = y0c; + + a0 += inca; + a1 += inca; + a2 += inca; + a3 += inca; + a4 += inca; + y0 += incy; + } + + } } From 8b5b2707c1ccf23fe75090e748a1b95b86b3ac59 Mon Sep 17 00:00:00 2001 From: Harsh Dave Date: Thu, 23 Dec 2021 04:44:24 -0600 Subject: [PATCH 062/243] Optimized daxpy2v implementation - Optimized axpy2v implementation for double datatype by handling rows in mulitple of 4 and store the final computed result at the end of computation, preventing unnecessary stores for improving the performance. - Optimal and reuse of vector registers for faster computation. AMD-Internal: [CPUPL-1973] Change-Id: I7b8ef94d0f67c1c666fdce26e9b2b7291365d2e9 --- config/zen/bli_cntx_init_zen.c | 4 +- config/zen2/bli_cntx_init_zen2.c | 8 +- config/zen3/bli_cntx_init_zen3.c | 4 +- config/zen4/bli_cntx_init_zen4.c | 4 +- kernels/zen/1f/CMakeLists.txt | 1 + kernels/zen/1f/bli_axpy2v_zen_int.c | 188 ++++++++++++++++++++++++++++ kernels/zen/bli_kernels_zen.h | 11 +- 7 files changed, 210 insertions(+), 10 deletions(-) create mode 100644 kernels/zen/1f/bli_axpy2v_zen_int.c diff --git a/config/zen/bli_cntx_init_zen.c b/config/zen/bli_cntx_init_zen.c index 020e7052b9..ec356fd231 100644 --- a/config/zen/bli_cntx_init_zen.c +++ b/config/zen/bli_cntx_init_zen.c @@ -80,7 +80,7 @@ void bli_cntx_init_zen( cntx_t* cntx ) // Update the context with optimized level-1f kernels. bli_cntx_set_l1f_kers ( - 6, + 7, // axpyf BLIS_AXPYF_KER, BLIS_FLOAT, bli_saxpyf_zen_int_8, BLIS_AXPYF_KER, BLIS_DOUBLE, bli_daxpyf_zen_int_8, @@ -89,6 +89,8 @@ void bli_cntx_init_zen( cntx_t* cntx ) // dotxf BLIS_DOTXF_KER, BLIS_FLOAT, bli_sdotxf_zen_int_8, BLIS_DOTXF_KER, BLIS_DOUBLE, bli_ddotxf_zen_int_8, + //axpy2v + BLIS_AXPY2V_KER, BLIS_DOUBLE, bli_daxpy2v_zen_int, cntx ); diff --git a/config/zen2/bli_cntx_init_zen2.c b/config/zen2/bli_cntx_init_zen2.c index 315362067e..47846ef22d 100644 --- a/config/zen2/bli_cntx_init_zen2.c +++ b/config/zen2/bli_cntx_init_zen2.c @@ -92,15 +92,17 @@ void bli_cntx_init_zen2( cntx_t* cntx ) // Update the context with optimized level-1f kernels. bli_cntx_set_l1f_kers ( - 6, + 7, // axpyf BLIS_AXPYF_KER, BLIS_FLOAT, bli_saxpyf_zen_int_5, BLIS_AXPYF_KER, BLIS_DOUBLE, bli_daxpyf_zen_int_5, - BLIS_AXPYF_KER, BLIS_SCOMPLEX, bli_caxpyf_zen_int_5, - BLIS_AXPYF_KER, BLIS_DCOMPLEX, bli_zaxpyf_zen_int_5, + BLIS_AXPYF_KER, BLIS_SCOMPLEX, bli_caxpyf_zen_int_5, + BLIS_AXPYF_KER, BLIS_DCOMPLEX, bli_zaxpyf_zen_int_5, // dotxf BLIS_DOTXF_KER, BLIS_FLOAT, bli_sdotxf_zen_int_8, BLIS_DOTXF_KER, BLIS_DOUBLE, bli_ddotxf_zen_int_8, + // axpy2v + BLIS_AXPY2V_KER, BLIS_DOUBLE, bli_daxpy2v_zen_int, cntx ); diff --git a/config/zen3/bli_cntx_init_zen3.c b/config/zen3/bli_cntx_init_zen3.c index ef47987454..7e7b120832 100644 --- a/config/zen3/bli_cntx_init_zen3.c +++ b/config/zen3/bli_cntx_init_zen3.c @@ -92,7 +92,7 @@ void bli_cntx_init_zen3( cntx_t* cntx ) // Update the context with optimized level-1f kernels. 
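	// Registering BLIS_AXPY2V_KER here is what lets level-2 code, such
	// as the her2 variants later in this series, pick up the fused
	// kernel via bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPY2V_KER, cntx ).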
bli_cntx_set_l1f_kers ( - 6, + 7, // axpyf BLIS_AXPYF_KER, BLIS_FLOAT, bli_saxpyf_zen_int_5, BLIS_AXPYF_KER, BLIS_DOUBLE, bli_daxpyf_zen_int_5, @@ -101,6 +101,8 @@ void bli_cntx_init_zen3( cntx_t* cntx ) // dotxf BLIS_DOTXF_KER, BLIS_FLOAT, bli_sdotxf_zen_int_8, BLIS_DOTXF_KER, BLIS_DOUBLE, bli_ddotxf_zen_int_8, + // axpy2v + BLIS_AXPY2V_KER, BLIS_DOUBLE, bli_daxpy2v_zen_int, cntx ); diff --git a/config/zen4/bli_cntx_init_zen4.c b/config/zen4/bli_cntx_init_zen4.c index 4f4c16d0ae..98c15796ef 100644 --- a/config/zen4/bli_cntx_init_zen4.c +++ b/config/zen4/bli_cntx_init_zen4.c @@ -91,7 +91,7 @@ void bli_cntx_init_zen4( cntx_t* cntx ) // Update the context with optimized level-1f kernels. bli_cntx_set_l1f_kers ( - 6, + 7, // axpyf BLIS_AXPYF_KER, BLIS_FLOAT, bli_saxpyf_zen_int_5, BLIS_AXPYF_KER, BLIS_DOUBLE, bli_daxpyf_zen_int_5, @@ -100,6 +100,8 @@ void bli_cntx_init_zen4( cntx_t* cntx ) // dotxf BLIS_DOTXF_KER, BLIS_FLOAT, bli_sdotxf_zen_int_8, BLIS_DOTXF_KER, BLIS_DOUBLE, bli_ddotxf_zen_int_8, + // axpy2v + BLIS_AXPY2V_KER, BLIS_DOUBLE, bli_daxpy2v_zen_int, cntx ); diff --git a/kernels/zen/1f/CMakeLists.txt b/kernels/zen/1f/CMakeLists.txt index d2bf13822d..4b9caa40b6 100644 --- a/kernels/zen/1f/CMakeLists.txt +++ b/kernels/zen/1f/CMakeLists.txt @@ -7,4 +7,5 @@ target_sources("${PROJECT_NAME}" ${CMAKE_CURRENT_SOURCE_DIR}/bli_axpyf_zen_int_5.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_axpyf_zen_int_4.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_axpyf_zen_int_6.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_axpy2v_zen_int.c ) diff --git a/kernels/zen/1f/bli_axpy2v_zen_int.c b/kernels/zen/1f/bli_axpy2v_zen_int.c new file mode 100644 index 0000000000..4ddca52162 --- /dev/null +++ b/kernels/zen/1f/bli_axpy2v_zen_int.c @@ -0,0 +1,188 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2018, The University of Texas at Austin + Copyright (C) 2022, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ +#include "blis.h" +#include "immintrin.h" + + +/** + * daxpy2v kernel performs axpy2v operation. 
+ * z := y + alphax * conjx(x) + alphay * conjy(y) + * where x, y, and z are vectors of length n. + */ +void bli_daxpy2v_zen_int + ( + conj_t conjx, + conj_t conjy, + dim_t n, + double* restrict alphax, + double* restrict alphay, + double* restrict x, inc_t incx, + double* restrict y, inc_t incy, + double* restrict z, inc_t incz, + cntx_t* restrict cntx + ) +{ + if ( bli_zero_dim1( n ) ) return; + + if ( incz == 1 && incx == 1 && incy == 1 ) + { + dim_t i = 0; + dim_t rem = n%4; + const dim_t n_elem_per_reg = 4; + __m256d xv[4], yv[4], zv[4]; + __m256d alphaxv, alphayv; + + alphaxv = _mm256_broadcast_sd((double const*) alphax); + alphayv = _mm256_broadcast_sd((double const*) alphay); + + for( ; (i + 15) < n; i+= 16 ) + { + xv[0] = _mm256_loadu_pd( x + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_pd( x + 1*n_elem_per_reg ); + xv[2] = _mm256_loadu_pd( x + 2*n_elem_per_reg ); + xv[3] = _mm256_loadu_pd( x + 3*n_elem_per_reg ); + + yv[0] = _mm256_loadu_pd( y + 0*n_elem_per_reg ); + yv[1] = _mm256_loadu_pd( y + 1*n_elem_per_reg ); + yv[2] = _mm256_loadu_pd( y + 2*n_elem_per_reg ); + yv[3] = _mm256_loadu_pd( y + 3*n_elem_per_reg ); + + zv[0] = _mm256_loadu_pd( z + 0*n_elem_per_reg ); + zv[1] = _mm256_loadu_pd( z + 1*n_elem_per_reg ); + zv[2] = _mm256_loadu_pd( z + 2*n_elem_per_reg ); + zv[3] = _mm256_loadu_pd( z + 3*n_elem_per_reg ); + + zv[0] = _mm256_fmadd_pd(xv[0], alphaxv, zv[0]); + zv[1] = _mm256_fmadd_pd(xv[1], alphaxv, zv[1]); + zv[2] = _mm256_fmadd_pd(xv[2], alphaxv, zv[2]); + zv[3] = _mm256_fmadd_pd(xv[3], alphaxv, zv[3]); + + zv[0] = _mm256_fmadd_pd(yv[0], alphayv, zv[0]); + zv[1] = _mm256_fmadd_pd(yv[1], alphayv, zv[1]); + zv[2] = _mm256_fmadd_pd(yv[2], alphayv, zv[2]); + zv[3] = _mm256_fmadd_pd(yv[3], alphayv, zv[3]); + + _mm256_storeu_pd((z + 0*n_elem_per_reg), zv[0]); + _mm256_storeu_pd((z + 1*n_elem_per_reg), zv[1]); + _mm256_storeu_pd((z + 2*n_elem_per_reg), zv[2]); + _mm256_storeu_pd((z + 3*n_elem_per_reg), zv[3]); + + z += 4*n_elem_per_reg; + x += 4*n_elem_per_reg; + y += 4*n_elem_per_reg; + } + + for( ; (i + 7) < n; i+= 8 ) + { + xv[0] = _mm256_loadu_pd( x + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_pd( x + 1*n_elem_per_reg ); + + yv[0] = _mm256_loadu_pd( y + 0*n_elem_per_reg ); + yv[1] = _mm256_loadu_pd( y + 1*n_elem_per_reg ); + + zv[0] = _mm256_loadu_pd( z + 0*n_elem_per_reg ); + zv[1] = _mm256_loadu_pd( z + 1*n_elem_per_reg ); + + zv[0] = _mm256_fmadd_pd(xv[0], alphaxv, zv[0]); + zv[1] = _mm256_fmadd_pd(xv[1], alphaxv, zv[1]); + + zv[0] = _mm256_fmadd_pd(yv[0], alphayv, zv[0]); + zv[1] = _mm256_fmadd_pd(yv[1], alphayv, zv[1]); + + _mm256_storeu_pd((z + 0*n_elem_per_reg), zv[0]); + _mm256_storeu_pd((z + 1*n_elem_per_reg), zv[1]); + + z += 2*n_elem_per_reg; + x += 2*n_elem_per_reg; + y += 2*n_elem_per_reg; + } + + for( ; (i + 3) < n; i+= 4 ) + { + xv[0] = _mm256_loadu_pd( x + 0*n_elem_per_reg ); + + yv[0] = _mm256_loadu_pd( y + 0*n_elem_per_reg ); + + zv[0] = _mm256_loadu_pd( z + 0*n_elem_per_reg ); + + zv[0] = _mm256_fmadd_pd(xv[0], alphaxv, zv[0]); + + zv[0] = _mm256_fmadd_pd(yv[0], alphayv, zv[0]); + + _mm256_storeu_pd((z + 0*n_elem_per_reg), zv[0]); + + z += n_elem_per_reg; + x += n_elem_per_reg; + y += n_elem_per_reg; + } + if(rem) + { + PRAGMA_SIMD + for ( i = 0; i < rem; ++i ) + { + PASTEMAC(d,axpys)( *alphax, x[i], z[i] ); + PASTEMAC(d,axpys)( *alphay, y[i], z[i] ); + } + } + } + else + { + /* Query the context for the kernel function pointer. 
*/ + const num_t dt = PASTEMAC(d,type); + PASTECH(d,axpyv_ker_ft) kfp_av + = + bli_cntx_get_l1v_ker_dt( dt, BLIS_AXPYV_KER, cntx ); + + kfp_av + ( + conjx, + n, + alphax, + x, incx, + z, incz, + cntx + ); + + kfp_av + ( + conjy, + n, + alphay, + y, incy, + z, incz, + cntx + ); + } +} diff --git a/kernels/zen/bli_kernels_zen.h b/kernels/zen/bli_kernels_zen.h index 86640dff53..e87995d2ce 100644 --- a/kernels/zen/bli_kernels_zen.h +++ b/kernels/zen/bli_kernels_zen.h @@ -34,12 +34,13 @@ */ // hemv helper function void bli_pre_hemv_8x8(double *a, double *x, - double *y, double *alpha, - dim_t cs_a, dim_t rs_a); + double *y, double *alpha, + dim_t cs_a, dim_t rs_a); void bli_post_hemv_8x8(double *a, double *x, - double *y, double *alpha, - dim_t cs_a, dim_t rs_a); + double *y, double *alpha, + dim_t cs_a, dim_t rs_a); + // -- level-1m -- PACKM_KER_PROT(double, d, packm_8xk_gen_zen) @@ -124,6 +125,8 @@ AXPYF_KER_PROT( scomplex, c, axpyf_zen_int_5 ) AXPYF_KER_PROT( scomplex, c, axpyf_zen_int_4 ) AXPYF_KER_PROT( dcomplex, z, axpyf_zen_int_5 ) AXPYF_KER_PROT( dcomplex, z, axpyf_zen_int_4 ) +// axpy2v (intrinsics) +AXPY2V_KER_PROT(double, d, axpy2v_zen_int ) // dotxf (intrinsics) DOTXF_KER_PROT( float, s, dotxf_zen_int_8 ) From 351269219f31b20177cc13514dbae356489fe497 Mon Sep 17 00:00:00 2001 From: Harsh Dave Date: Fri, 17 Dec 2021 02:34:52 -0600 Subject: [PATCH 063/243] Optimized dher2 implementation - Impplemented her2 framework calls for transposed and non transposed kernel variants. - dher2 kernel operate over 4 columns at a time. It computes 4x4 triangular part of matrix first and remainder part is computed in chunk of 4x4 tile upto m rows. - remainder cases(m < 4) are handled serially. AMD-Internal: [CPUPL-1968] Change-Id: I12ae97b2ad673a7fd9b733c607f27b1089142313 --- frame/2/hemv/bli_hemv_unf_var1.c | 12 +- frame/2/hemv/bli_hemv_unf_var3.c | 11 + frame/2/her2/bli_her2_unf_var1.c | 212 +++++++++++++++ frame/2/her2/bli_her2_unf_var4.c | 187 ++++++++++++++ kernels/zen/2/CMakeLists.txt | 1 + kernels/zen/2/bli_her2_zen_int_4.c | 396 +++++++++++++++++++++++++++++ kernels/zen/bli_kernels_zen.h | 10 - 7 files changed, 816 insertions(+), 13 deletions(-) create mode 100644 kernels/zen/2/bli_her2_zen_int_4.c diff --git a/frame/2/hemv/bli_hemv_unf_var1.c b/frame/2/hemv/bli_hemv_unf_var1.c index ccb39b3485..6790e5bd08 100644 --- a/frame/2/hemv/bli_hemv_unf_var1.c +++ b/frame/2/hemv/bli_hemv_unf_var1.c @@ -218,9 +218,15 @@ void PASTEMAC(ch,varname) \ #ifdef BLIS_CONFIG_EPYC -void post_hemv_8x8(double *a, double *x, - double *y, double *alpha, - dim_t cs_a, dim_t rs_a); +void bli_post_hemv_8x8 + ( + double *a, + double *x, + double *y, + double *alpha, + dim_t cs_a, + dim_t rs_a + ); void bli_dhemv_unf_var1 ( diff --git a/frame/2/hemv/bli_hemv_unf_var3.c b/frame/2/hemv/bli_hemv_unf_var3.c index 6ed18efea4..abf08dfdaf 100644 --- a/frame/2/hemv/bli_hemv_unf_var3.c +++ b/frame/2/hemv/bli_hemv_unf_var3.c @@ -217,6 +217,17 @@ void PASTEMAC(ch,varname) \ } #ifdef BLIS_CONFIG_EPYC + +void bli_pre_hemv_8x8 + ( + double *a, + double *x, + double *y, + double *alpha, + dim_t cs_a, + dim_t rs_a + ); + void bli_dhemv_unf_var3 ( uplo_t uplo, diff --git a/frame/2/her2/bli_her2_unf_var1.c b/frame/2/her2/bli_her2_unf_var1.c index a0aec48f71..299e3d161d 100644 --- a/frame/2/her2/bli_her2_unf_var1.c +++ b/frame/2/her2/bli_her2_unf_var1.c @@ -158,5 +158,217 @@ void PASTEMAC(ch,varname) \ } \ } + +#ifdef BLIS_CONFIG_EPYC + +/** + * Following is function declaration + * that computes her2 for transposed case. 
+ * It handles triangular part of matrix and + * remaining computation in optimal way to + * gain performance improvement. + * a is triangular matrix, x and y are vectors + */ +void bli_dher2_trans_zen_int_4 + ( + double *a, + double *x, + double *y, + double *alpha, + dim_t m, + dim_t lda + ); + +void bli_dher2_unf_var1 + ( + uplo_t uplo, + conj_t conjx, + conj_t conjy, + conj_t conjh, + dim_t m, + double* alpha, + double* x, inc_t incx, + double* y, inc_t incy, + double* c, inc_t rs_c, inc_t cs_c, + cntx_t* cntx + ) +{ + const num_t dt = PASTEMAC(d,type); + + double* x0; + double* chi1; + double* y0; + double* psi1; + double* c10t; + double* gamma11; + double alpha0; + double alpha1; + double alpha0_chi1; + double alpha1_psi1; + double alpha0_chi1_psi1; + double conjx0_chi1; + double conjy1_psi1; + double conjy0_psi1; + dim_t i; + dim_t n_behind; + inc_t rs_ct, cs_ct; + conj_t conj0, conj1; + + /* The algorithm will be expressed in terms of the lower triangular + * case;the upper triangular case is supported by swapping the row + * and column strides of A and toggling some conj parameters. + */ + if ( bli_is_lower( uplo ) ) + { + rs_ct = rs_c; + cs_ct = cs_c; + + PASTEMAC(d,copys)( *alpha, alpha0 ); + PASTEMAC(d,copycjs)( conjh, *alpha, alpha1 ); + } + else /* if ( bli_is_upper( uplo ) ) */ + { + rs_ct = cs_c; + cs_ct = rs_c; + + /* Toggle conjugation of conjx/conjy, but only if we are being + * invoked as her2; for syr2, conjx/conjy are unchanged. + */ + conjx = bli_apply_conj( conjh, conjx ); + conjy = bli_apply_conj( conjh, conjy ); + + PASTEMAC(d,copycjs)( conjh, *alpha, alpha0 ); + PASTEMAC(d,copys)( *alpha, alpha1 ); + } + + /* Apply conjh (which carries the conjugation component of the + * Hermitian transpose, if applicable) to conjx and/or conjy as + * needed to arrive at the effective conjugation for the vector + * subproblems. + */ + conj0 = bli_apply_conj( conjh, conjy ); + conj1 = bli_apply_conj( conjh, conjx ); + + PASTECH(d,axpy2v_ker_ft) kfp_2v; + + /* Query the context for the kernel function pointer. */ + kfp_2v = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPY2V_KER, cntx ); + + if( (incx == 1) && (incy == 1) && (rs_ct == 1)) + { + for ( i = 0; i < m; ) + { + n_behind = i; + x0 = x + (0 )*incx; + chi1 = x + (i )*incx; + y0 = y + (0 )*incy; + psi1 = y + (i )*incy; + c10t = c + (i )*rs_ct + (0 )*cs_ct; + gamma11 = c + (i )*rs_ct + (i )*cs_ct; + + if((n_behind >= 3)) + { + bli_dher2_trans_zen_int_4(c10t, x0, y0, &alpha0, n_behind + 1, cs_ct); + i+=4; + } + else + { + /* Apply conjx and/or conjy to chi1 and/or psi1. */ + PASTEMAC(d,copycjs)( conjx, *chi1, conjx0_chi1 ); + PASTEMAC(d,copycjs)( conjy, *psi1, conjy1_psi1 ); + PASTEMAC(d,copycjs)( conj0, *psi1, conjy0_psi1 ); + + /* Compute scalars for vector subproblems. */ + PASTEMAC(d,scal2s)( alpha0, conjx0_chi1, alpha0_chi1 ); + PASTEMAC(d,scal2s)( alpha1, conjy1_psi1, alpha1_psi1 ); + + /* Compute alpha * chi1 * conj(psi1) after both chi1 + * and psi1 have already been conjugated, if needed, + * by conjx and conjy. 
+ */ + PASTEMAC(d,scal2s)( alpha0_chi1, conjy0_psi1, + alpha0_chi1_psi1 ); + + /* c10t = c10t + alpha * chi1 * y0'; */ + /* c10t = c10t + conj(alpha) * psi1 * x0'; */ + kfp_2v + ( + conj0, + conj1, + n_behind, + &alpha0_chi1, + &alpha1_psi1, + y0, incy, + x0, incx, + c10t, cs_ct, + cntx + ); + + /* gamma11 = gamma11 + alpha * chi1 * conj(psi1) + + conj(alpha) * psi1 * conj(chi1); */ + PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); + PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); + + i+=1; + } + } + } + else + { + for ( i = 0; i < m; ++i ) + { + n_behind = i; + x0 = x + (0 )*incx; + chi1 = x + (i )*incx; + y0 = y + (0 )*incy; + psi1 = y + (i )*incy; + c10t = c + (i )*rs_ct + (0 )*cs_ct; + gamma11 = c + (i )*rs_ct + (i )*cs_ct; + + /* Apply conjx and/or conjy to chi1 and/or psi1. */ + PASTEMAC(d,copycjs)( conjx, *chi1, conjx0_chi1 ); + PASTEMAC(d,copycjs)( conjy, *psi1, conjy1_psi1 ); + PASTEMAC(d,copycjs)( conj0, *psi1, conjy0_psi1 ); + + /* Compute scalars for vector subproblems. */ + PASTEMAC(d,scal2s)( alpha0, conjx0_chi1, alpha0_chi1 ); + PASTEMAC(d,scal2s)( alpha1, conjy1_psi1, alpha1_psi1 ); + + /* Compute alpha * chi1 * conj(psi1) after both chi1 + * and psi1 have already been conjugated, if needed, + * by conjx and conjy. + */ + PASTEMAC(d,scal2s)( alpha0_chi1, conjy0_psi1, + alpha0_chi1_psi1 ); + + /* c10t = c10t + alpha * chi1 * y0'; */ + /* c10t = c10t + conj(alpha) * psi1 * x0'; */ + kfp_2v + ( + conj0, + conj1, + n_behind, + &alpha0_chi1, + &alpha1_psi1, + y0, incy, + x0, incx, + c10t, cs_ct, + cntx + ); + + /* gamma11 = gamma11 + alpha * chi1 * conj(psi1) + + conj(alpha) * psi1 * conj(chi1); */ + PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); + PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); + + } + } +} + +GENTFUNC(float, s, her2_unf_var1) +GENTFUNC(scomplex, c, her2_unf_var1) +GENTFUNC(dcomplex, z,her2_unf_var1) +#else INSERT_GENTFUNC_BASIC0( her2_unf_var1 ) +#endif diff --git a/frame/2/her2/bli_her2_unf_var4.c b/frame/2/her2/bli_her2_unf_var4.c index 3dea31d53e..e39c7224c4 100644 --- a/frame/2/her2/bli_her2_unf_var4.c +++ b/frame/2/her2/bli_her2_unf_var4.c @@ -166,5 +166,192 @@ void PASTEMAC(ch,varname) \ } \ } +#ifdef BLIS_CONFIG_EPYC +/** + * Following is function declaration + * that computes her2 for transposed case. + * It handles triangular part of matrix and + * remaining computation in optimal way to + * gain performance improvement. + * a is triangular matrix, x and y are vectors + */ +void bli_dher2_zen_int_4 + ( + double *a, + double *x, + double *y, + double *alpha, + dim_t m, + dim_t lda + ); + +void bli_dher2_unf_var4 + ( + uplo_t uplo, + conj_t conjx, + conj_t conjy, + conj_t conjh, + dim_t m, + double* alpha, + double* x, inc_t incx, + double* y, inc_t incy, + double* c, inc_t rs_c, inc_t cs_c, + cntx_t* cntx + ) +{ + + double* chi1; + double* x2; + double* psi1; + double* y2; + double* gamma11; + double* c21; + double alpha0; + double alpha0_psi1; + double alpha1_chi1; + double alpha0_chi1_psi1; + dim_t i; + dim_t n_ahead; + inc_t rs_ct, cs_ct; + + const num_t dt = PASTEMAC(d,type); + + /* The algorithm will be expressed in terms of the lower triangular + * case; the upper triangular case is supported by swapping the row + * and column strides of A and toggling some conj parameters. 
+ */ + if ( bli_is_lower( uplo ) ) + { + rs_ct = rs_c; + cs_ct = cs_c; + + PASTEMAC(d,copys)( *alpha, alpha0 ); + } + else /* if ( bli_is_upper( uplo ) ) */ + { + rs_ct = cs_c; + cs_ct = rs_c; + + /* Toggle conjugation of conjx/conjy, but only if we are being + * invoked as her2; for syr2, conjx/conjy are unchanged. + */ + + PASTEMAC(d,copys)( *alpha, alpha0 ); + } + /* Apply conjh (which carries the conjugation component of the + * Hermitian transpose, if applicable) to conjx and/or conjy as + * needed to arrive at the effective conjugation for the vector + * subproblems. + */ + + PASTECH(d,axpy2v_ker_ft) kfp_2v; + + /* Query the context for the kernel function pointer. */ + kfp_2v = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPY2V_KER, cntx ); + + if((incx == 1) && (incy == 1) && (rs_ct == 1)) + { + for ( i = 0; i < m; ) + { + n_ahead = m - i - 1; + chi1 = x + (i ) * incx; + x2 = x + (i+1) * incx; + psi1 = y + (i ) * incy; + y2 = y + (i+1) * incy; + gamma11 = c + (i ) + (i )*cs_ct; + c21 = c + (i+1) + (i )*cs_ct; + + if((n_ahead >= 3)) + { + bli_dher2_zen_int_4(gamma11, chi1, psi1, &alpha0, n_ahead + 1, cs_ct); + i+= 4; + } + else + { + /* Compute scalars for vector subproblems. */ + PASTEMAC(d,scal2s)( alpha0, *psi1, alpha0_psi1 ); + PASTEMAC(d,scal2s)( alpha0, *chi1, alpha1_chi1 ); + + /* Compute alpha * chi1 * conj(psi1) after both chi1 + * and psi1 have + already been conjugated, if needed, by conjx and + conjy. */ + PASTEMAC(d,scal2s)( alpha0_psi1, *chi1, + alpha0_chi1_psi1 ); + + /* c21 = c21 + alpha * x2 * conj(psi1); */ + /* c21 = c21 + conj(alpha) * y2 * conj(chi1); */ + + kfp_2v + ( + conjx, + conjy, + n_ahead, + &alpha0_psi1, + &alpha1_chi1, + x2, incx, + y2, incy, + c21, rs_ct, + cntx + ); + + + PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); + PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); + i+=1; + } + } + } + else + { + for ( i = 0; i < m; ++i) + { + n_ahead = m - i - 1; + chi1 = x + (i ) * incx; + x2 = x + (i+1) * incx; + psi1 = y + (i ) * incy; + y2 = y + (i+1) * incy; + gamma11 = c + (i ) + (i )*cs_ct; + c21 = c + (i+1) + (i )*cs_ct; + + /* Compute scalars for vector subproblems. */ + PASTEMAC(d,scal2s)( alpha0, *psi1, alpha0_psi1 ); + PASTEMAC(d,scal2s)( alpha0, *chi1, alpha1_chi1 ); + + /* Compute alpha * chi1 * conj(psi1) after both chi1 + * and psi1 have + already been conjugated, if needed, by conjx and + conjy. 
*/ + PASTEMAC(d,scal2s)( alpha0_psi1, *chi1, + alpha0_chi1_psi1 ); + + /* c21 = c21 + alpha * x2 * conj(psi1); */ + /* c21 = c21 + conj(alpha) * y2 * conj(chi1); */ + + kfp_2v + ( + conjx, + conjy, + n_ahead, + &alpha0_psi1, + &alpha1_chi1, + x2, incx, + y2, incy, + c21, rs_ct, + cntx + ); + + + PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); + PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); + } + } +} + +GENTFUNC(float, s, her2_unf_var4) +GENTFUNC(scomplex, c, her2_unf_var4) +GENTFUNC(dcomplex, z,her2_unf_var4) +#else INSERT_GENTFUNC_BASIC0( her2_unf_var4 ) +#endif diff --git a/kernels/zen/2/CMakeLists.txt b/kernels/zen/2/CMakeLists.txt index 480837c023..72176895c5 100644 --- a/kernels/zen/2/CMakeLists.txt +++ b/kernels/zen/2/CMakeLists.txt @@ -3,6 +3,7 @@ target_sources("${PROJECT_NAME}" PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemv_zen_ref.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_her2_zen_int_4.c ) diff --git a/kernels/zen/2/bli_her2_zen_int_4.c b/kernels/zen/2/bli_her2_zen_int_4.c new file mode 100644 index 0000000000..9b181aa278 --- /dev/null +++ b/kernels/zen/2/bli_her2_zen_int_4.c @@ -0,0 +1,396 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +*/ + +#include "immintrin.h" +#include "blis.h" + +void bli_dher2_trans_zen_int_4 + ( + double *a, + double *x, + double *y, + double *alpha, + dim_t m, + dim_t lda + ) +{ + dim_t row = 0; + dim_t rem = m % 4; + + /*holds 4 diagonal elements of triangular part of 4x4 tile*/ + double a_diag[4] = {0}; + /*alpha_chi holds x*alpha and alpha_psi holds y*alpha*/ + double alpha_chi[4] = {0}; + double alpha_psi[4] = {0}; + /*Extracts diagonal element and store into a_diag buffer*/ + PRAGMA_SIMD + for(dim_t i = 0; i < 4; i++) + { + a_diag[i] = *(a + m + i + (i * lda)); + } + + __m256d x0, x1, x2, x3; + __m256d y0, y1, y2, y3; + + __m256d xr, yr, zero, gamma; + __m256d a0, a1, a2, a3; + + zero = _mm256_setzero_pd(); + + /*Loading elements of x & y vectors*/ + x0 = _mm256_loadu_pd(x + m); + y0 = _mm256_loadu_pd(y + m); + /*Broadcasting alpha to compute alpha_psi and alpha_chi*/ + x1 = _mm256_broadcast_sd(alpha); + + x2 = _mm256_mul_pd(x0, x1); + y0 = _mm256_mul_pd(y0, x1); + + /*Storing alpha_chi and alpha_psi for later usage in computation loop*/ + _mm256_storeu_pd(alpha_chi, x2); + _mm256_storeu_pd(alpha_psi, y0); + + x0 = _mm256_mul_pd(x0, y0); + gamma = _mm256_loadu_pd(a_diag); + gamma = _mm256_add_pd(gamma, x0); + gamma = _mm256_add_pd(gamma, x0); + _mm256_storeu_pd(a_diag, gamma); + + /* Broadcasting 4 alpha_psis and alpha_chis which + * are to be used througout the computation of 4x4 tile + * upto m rows. + */ + x0 = _mm256_broadcast_sd(&alpha_chi[0]); + x1 = _mm256_broadcast_sd(&alpha_chi[1]); + x2 = _mm256_broadcast_sd(&alpha_chi[2]); + x3 = _mm256_broadcast_sd(&alpha_chi[3]); + + y0 = _mm256_broadcast_sd(&alpha_psi[0]); + y1 = _mm256_broadcast_sd(&alpha_psi[1]); + y2 = _mm256_broadcast_sd(&alpha_psi[2]); + y3 = _mm256_broadcast_sd(&alpha_psi[3]); + + /* Loading 4x4 tile of A matrix for + * triangular part computation + */ + a0 = _mm256_loadu_pd(a + (0 * lda) + m); + a1 = _mm256_loadu_pd(a + (1 * lda) + m); + a2 = _mm256_loadu_pd(a + (2 * lda) + m); + a3 = _mm256_loadu_pd(a + (3 * lda) + m); + + yr = _mm256_loadu_pd(y); + xr = _mm256_loadu_pd(x); + + /*Setting first element of x & y vectors to zero + * to eliminate diagonal element of 1st column + * from computation + */ + xr = _mm256_blend_pd(xr, zero, 0x1); + yr = _mm256_blend_pd(yr, zero, 0x1); + a0 = _mm256_blend_pd(a0, zero, 0x1); + + a1 = _mm256_blend_pd(a1, zero, 0x3); + a2 = _mm256_blend_pd(a2, zero, 0x7); + a3 = _mm256_blend_pd(a3, zero, 0xF); + + a0 = _mm256_fmadd_pd(xr, y0, a0); + a0 = _mm256_fmadd_pd(yr, x0, a0); + + /*Setting two elements of x & y vectors to zero + * to eliminate diagonal element of 2nd column + * from computation + */ + xr = _mm256_blend_pd(xr, zero, 0x3); + yr = _mm256_blend_pd(yr, zero, 0x3); + a1 = _mm256_fmadd_pd(xr, y1, a1); + a1 = _mm256_fmadd_pd(yr, x1, a1); + + /*Setting three elements of x & y vectors to zero + * to eliminate diagonal element of 3rd column + * from computation + */ + xr = _mm256_blend_pd(xr, zero, 0x7); + yr = _mm256_blend_pd(yr, zero, 0x7); + a2 = _mm256_fmadd_pd(xr, y2, a2); + a2 = _mm256_fmadd_pd(yr, x2, a2); + + _mm256_storeu_pd(a + (0 * lda) + m, a0 ); + + /* Loading data from memory location first + * so it could be blend with and finally + * gets stored at same location to prevent + * unnecessary data overwriting at nearby + * memory locations + */ + a3 = _mm256_loadu_pd(a + (1 * lda) + m ); + a1 = _mm256_blend_pd(a1, a3, 0x1); + _mm256_storeu_pd(a + (1 * lda) + m, a1 ); + + a3 = _mm256_loadu_pd(a + (2 * lda) + m ); + a2 = _mm256_blend_pd(a2, a3, 0x3); + _mm256_storeu_pd(a + 
(2 * lda) + m, a2 ); + + /* Triangular part of matrix is computed, remaining + * part is computed in below loop upto m rows. + */ + for(; (row + 4) <= m; row+=4) + { + /* Loading elements of x and y vector */ + xr = _mm256_loadu_pd(x + row); + yr = _mm256_loadu_pd(y + row); + /* Loading tile of A matrix of size 4x4 */ + a0 = _mm256_loadu_pd(a + row + (0 * lda) ); + a1 = _mm256_loadu_pd(a + row + (1 * lda) ); + a2 = _mm256_loadu_pd(a + row + (2 * lda) ); + a3 = _mm256_loadu_pd(a + row + (3 * lda) ); + + a0 = _mm256_fmadd_pd(xr, y0, a0); + a1 = _mm256_fmadd_pd(xr, y1, a1); + a2 = _mm256_fmadd_pd(xr, y2, a2); + a3 = _mm256_fmadd_pd(xr, y3, a3); + + a0 = _mm256_fmadd_pd(yr, x0, a0); + a1 = _mm256_fmadd_pd(yr, x1, a1); + a2 = _mm256_fmadd_pd(yr, x2, a2); + a3 = _mm256_fmadd_pd(yr, x3, a3); + + _mm256_storeu_pd(a + row + (0 * lda), a0); + _mm256_storeu_pd(a + row + (1 * lda), a1); + _mm256_storeu_pd(a + row + (2 * lda), a2); + _mm256_storeu_pd(a + row + (3 * lda), a3); + } + + /* Computes remainder cases where m is less than 4 */ + if(rem) + { + PRAGMA_SIMD + for(dim_t i = 0; i < 4; i++) + { + for(dim_t j = row; j < m; j++) + { + a[ j + (i * lda)] += x[j] * (y[i] * (*alpha)); + a[ j + (i * lda)] += y[j] * (x[i] * (*alpha)); + } + } + } + + /* Computing 4 diagonal elements of triangular part of matrix + * and storing result back at corresponding location in matrix A + */ + PRAGMA_SIMD + for(dim_t i = 0; i < 4; i++) + { + *(a + m + i + (i * lda)) = a_diag[i]; + } +} + + +void bli_dher2_zen_int_4 + ( + double *a, + double *x, + double *y, + double *alpha, + dim_t m, + dim_t lda + ) +{ + dim_t row = 4; + dim_t rem = m % 4; + + /*holds 4 diagonal elements of triangular part of 4x4 tile*/ + double a_diag[4] = {0}; + /*alpha_chi holds x*alpha and alpha_psi holds y*alpha*/ + double alpha_chi[4] = {0}; + double alpha_psi[4] = {0}; + /*Extracts diagonal element and store into a_diag buffer*/ + PRAGMA_SIMD + for(dim_t i = 0; i < 4; i++) + { + a_diag[i] = *(a + i + (i * lda)); + } + + __m256d x0, x1, x2, x3; + __m256d y0, y1, y2, y3; + + __m256d xr, yr, zero, gamma; + __m256d a0, a1, a2, a3; + + zero = _mm256_setzero_pd(); + + /*Loading elements of x & y vectors*/ + x0 = _mm256_loadu_pd(x); + y0 = _mm256_loadu_pd(y); + /*Broadcasting alpha to compute alpha_psi and alpha_chi*/ + x1 = _mm256_broadcast_sd(alpha); + + x2 = _mm256_mul_pd(x0, x1); + y0 = _mm256_mul_pd(y0, x1); + + /*Storing alpha_chi and alpha_psi for later usage in computation loop*/ + _mm256_storeu_pd(alpha_chi, x2); + _mm256_storeu_pd(alpha_psi, y0); + + x0 = _mm256_mul_pd(x0, y0); + gamma = _mm256_loadu_pd(a_diag); + gamma = _mm256_add_pd(gamma, x0); + gamma = _mm256_add_pd(gamma, x0); + _mm256_storeu_pd(a_diag, gamma); + + /* Broadcasting 4 alpha_psis and alpha_chis which + * are to be used througout the computation of 4x4 tile + * upto m rows. 
+ */ + x0 = _mm256_broadcast_sd(&alpha_chi[0]); + x1 = _mm256_broadcast_sd(&alpha_chi[1]); + x2 = _mm256_broadcast_sd(&alpha_chi[2]); + x3 = _mm256_broadcast_sd(&alpha_chi[3]); + + y0 = _mm256_broadcast_sd(&alpha_psi[0]); + y1 = _mm256_broadcast_sd(&alpha_psi[1]); + y2 = _mm256_broadcast_sd(&alpha_psi[2]); + y3 = _mm256_broadcast_sd(&alpha_psi[3]); + + /* Loading 4x4 tile of A matrix for + * triangular part computation + */ + a0 = _mm256_loadu_pd(a + (0 * lda) ); + a1 = _mm256_loadu_pd(a + (1 * lda) ); + a2 = _mm256_loadu_pd(a + (2 * lda) ); + a3 = _mm256_loadu_pd(a + (3 * lda) ); + + yr = _mm256_loadu_pd(y); + xr = _mm256_loadu_pd(x); + + /*Setting first element of x & y vectors to zero + * to eliminate diagonal element of 1st column + * from computation + */ + xr = _mm256_blend_pd(xr, zero, 0x1); + yr = _mm256_blend_pd(yr, zero, 0x1); + a0 = _mm256_blend_pd(a0, zero, 0x1); + a1 = _mm256_blend_pd(a1, zero, 0x3); + a2 = _mm256_blend_pd(a2, zero, 0x7); + a3 = _mm256_blend_pd(a3, zero, 0xF); + + a0 = _mm256_fmadd_pd(xr, y0, a0); + a0 = _mm256_fmadd_pd(yr, x0, a0); + + /*Setting two elements of x & y vectors to zero + * to eliminate diagonal element of 2nd column + * from computation + */ + xr = _mm256_blend_pd(xr, zero, 0x3); + yr = _mm256_blend_pd(yr, zero, 0x3); + a1 = _mm256_fmadd_pd(xr, y1, a1); + a1 = _mm256_fmadd_pd(yr, x1, a1); + + /*Setting three elements of x & y vectors to zero + * to eliminate diagonal element of 3rd column + * from computation + */ + xr = _mm256_blend_pd(xr, zero, 0x7); + yr = _mm256_blend_pd(yr, zero, 0x7); + a2 = _mm256_fmadd_pd(xr, y2, a2); + a2 = _mm256_fmadd_pd(yr, x2, a2); + + _mm256_storeu_pd(a + (0 * lda), a0 ); + + /* Loading data from memory location first + * so it could be blend with and finally + * gets stored at same location to prevent + * unnecessary data overwriting at nearby + * memory locations + */ + a3 = _mm256_loadu_pd(a + (1 * lda) ); + a1 = _mm256_blend_pd(a1, a3, 0x1); + _mm256_storeu_pd(a + (1 * lda), a1 ); + + a3 = _mm256_loadu_pd(a + (2 * lda) ); + a2 = _mm256_blend_pd(a2, a3, 0x3); + _mm256_storeu_pd(a + (2 * lda), a2 ); + + /* Triangular part of matrix is computed, remaining + * part is computed in below loop upto m rows. 
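+		 * In scalar terms, each iteration of the loop performs
+		 *     a[i + j*lda] += alpha*x[i]*y[j] + alpha*y[i]*x[j]
+		 * for the four columns j of this block and four rows i at a time.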
+ */ + for(; (row + 4) <= m; row+=4) + { + /* Loading elements of x and y vector */ + xr = _mm256_loadu_pd(x + row); + yr = _mm256_loadu_pd(y + row); + /* Loading tile of A matrix of size 4x4 */ + a0 = _mm256_loadu_pd(a + row + (0 * lda) ); + a1 = _mm256_loadu_pd(a + row + (1 * lda) ); + a2 = _mm256_loadu_pd(a + row + (2 * lda) ); + a3 = _mm256_loadu_pd(a + row + (3 * lda) ); + + a0 = _mm256_fmadd_pd(xr, y0, a0); + a1 = _mm256_fmadd_pd(xr, y1, a1); + a2 = _mm256_fmadd_pd(xr, y2, a2); + a3 = _mm256_fmadd_pd(xr, y3, a3); + + a0 = _mm256_fmadd_pd(yr, x0, a0); + a1 = _mm256_fmadd_pd(yr, x1, a1); + a2 = _mm256_fmadd_pd(yr, x2, a2); + a3 = _mm256_fmadd_pd(yr, x3, a3); + + _mm256_storeu_pd(a + row + (0 * lda), a0); + _mm256_storeu_pd(a + row + (1 * lda), a1); + _mm256_storeu_pd(a + row + (2 * lda), a2); + _mm256_storeu_pd(a + row + (3 * lda), a3); + } + + /* Computes remainder cases where m is less than 4 */ + if(rem) + { + PRAGMA_SIMD + for(dim_t i = 0; i < 4; i++) + { + for(dim_t j = row; j < m; j++) + { + a[ j + (i * lda)] += x[j] * (y[i] * (*alpha)); + a[ j + (i * lda)] += y[j] * (x[i] * (*alpha)); + } + } + } + + /* Computing 4 diagonal elements of triangular part of matrix + * and storing result back at corresponding location in matrix A + */ + PRAGMA_SIMD + for(dim_t i = 0; i < 4; i++) + { + *(a + i + (i * lda)) = a_diag[i]; + } +} diff --git a/kernels/zen/bli_kernels_zen.h b/kernels/zen/bli_kernels_zen.h index e87995d2ce..62b19e9a2d 100644 --- a/kernels/zen/bli_kernels_zen.h +++ b/kernels/zen/bli_kernels_zen.h @@ -32,16 +32,6 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. */ -// hemv helper function -void bli_pre_hemv_8x8(double *a, double *x, - double *y, double *alpha, - dim_t cs_a, dim_t rs_a); - -void bli_post_hemv_8x8(double *a, double *x, - double *y, double *alpha, - dim_t cs_a, dim_t rs_a); - - // -- level-1m -- PACKM_KER_PROT(double, d, packm_8xk_gen_zen) PACKM_KER_PROT(double, d, packm_6xk_gen_zen) From 457c33a6017b6f10b115cca75e01d579831ed262 Mon Sep 17 00:00:00 2001 From: mkadavil Date: Tue, 21 Dec 2021 15:08:16 +0530 Subject: [PATCH 064/243] Eliminating barriers in SUP path when matrices are not packed. -Current gemm SUP path uses bli_thrinfo_sup_grow, bli_thread_range_sub to generate per thread data ranges at each loop of gemm algorithm. bli_thrinfo_sup_grow involves usage of multiple barriers for cross thread synchronization. These barriers are necessary in cases where either the A or B matrix are packed for centralized pack buffer allocation/deallocation (bli_thread_am_ochief thread). -However for cases where both A and B matrices are unpacked, these barrier are resulting in overhead for smaller dimensions. Here creation of unnecessary communicators are avoided and subsequently the requirement for barriers are eliminated when packing is disabled for both the input matrices in SUP path. Change-Id: Ic373dfd2d6b08b8f577dc98399a83bb08f794afa --- frame/thread/bli_thrinfo_sup.c | 126 ++++++++++++++++++++++----------- 1 file changed, 83 insertions(+), 43 deletions(-) diff --git a/frame/thread/bli_thrinfo_sup.c b/frame/thread/bli_thrinfo_sup.c index e67e8b6426..8ce714547c 100644 --- a/frame/thread/bli_thrinfo_sup.c +++ b/frame/thread/bli_thrinfo_sup.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -167,8 +167,23 @@ thrinfo_t* bli_thrinfo_sup_create_for_cntl thrcomm_t* static_comms[ BLIS_NUM_STATIC_COMMS ]; thrcomm_t** new_comms = NULL; + + const bool packa = bli_rntm_pack_a( rntm ); + const bool packb = bli_rntm_pack_b( rntm ); + dim_t parent_nt_in = 0; + + // thrinfo ocomm is not created when neither packa nor packb is + // enabled. Need to derive parent_nt_in without depending on ocomm in + // those cases. + if ( packa || packb ) + { + parent_nt_in = bli_thread_num_threads( thread_par ); + } + else + { + parent_nt_in = bli_rntm_calc_num_threads_in( bszid_par, rntm ); + } - const dim_t parent_nt_in = bli_thread_num_threads( thread_par ); const dim_t parent_n_way = bli_thread_n_way( thread_par ); const dim_t parent_comm_id = bli_thread_ocomm_id( thread_par ); const dim_t parent_work_id = bli_thread_work_id( thread_par ); @@ -193,50 +208,75 @@ thrinfo_t* bli_thrinfo_sup_create_for_cntl //printf( "thread %d: child_n_way = %d child_nt_in = %d parent_n_way = %d (bszid = %d->%d)\n", (int)child_comm_id, (int)child_nt_in, (int)child_n_way, (int)parent_n_way, (int)bli_cntl_bszid( cntl_par ), (int)bszid_chl ); - // The parent's chief thread creates a temporary array of thrcomm_t - // pointers. - if ( bli_thread_am_ochief( thread_par ) ) + thrinfo_t* thread_chl = NULL; + + // The communicators are only used when either packa or packb is + // enabled. This means that the communicator creation along with the + // overhead from the barriers (required for synchronizing comm across + // threads) are not required when both packa and packb are disabled. + if ( packa || packb ) { - if ( parent_n_way > BLIS_NUM_STATIC_COMMS ) - new_comms = bli_malloc_intl( parent_n_way * sizeof( thrcomm_t* ) ); - else - new_comms = static_comms; - } + // The parent's chief thread creates a temporary array of thrcomm_t + // pointers. + if ( bli_thread_am_ochief( thread_par ) ) + { + if ( parent_n_way > BLIS_NUM_STATIC_COMMS ) + new_comms = bli_malloc_intl( parent_n_way * sizeof( thrcomm_t* ) ); + else + new_comms = static_comms; + } + + // Broadcast the temporary array to all threads in the parent's + // communicator. + new_comms = bli_thread_broadcast( thread_par, new_comms ); + + // Chiefs in the child communicator allocate the communicator + // object and store it in the array element corresponding to the + // parent's work id. + if ( child_comm_id == 0 ) + new_comms[ parent_work_id ] = bli_thrcomm_create( rntm, child_nt_in ); + + bli_thread_barrier( thread_par ); + + // All threads create a new thrinfo_t node using the communicator + // that was created by their chief, as identified by parent_work_id. + thread_chl = bli_thrinfo_create + ( + rntm, // rntm + new_comms[ parent_work_id ], // ocomm + child_comm_id, // ocomm_id + child_n_way, // n_way + child_work_id, // work_id + TRUE, // free_comm + *bszid_chl, // bszid + NULL // sub_node + ); + + bli_thread_barrier( thread_par ); - // Broadcast the temporary array to all threads in the parent's - // communicator. - new_comms = bli_thread_broadcast( thread_par, new_comms ); - - // Chiefs in the child communicator allocate the communicator - // object and store it in the array element corresponding to the - // parent's work id. 
- if ( child_comm_id == 0 ) - new_comms[ parent_work_id ] = bli_thrcomm_create( rntm, child_nt_in ); - - bli_thread_barrier( thread_par ); - - // All threads create a new thrinfo_t node using the communicator - // that was created by their chief, as identified by parent_work_id. - thrinfo_t* thread_chl = bli_thrinfo_create - ( - rntm, // rntm - new_comms[ parent_work_id ], // ocomm - child_comm_id, // ocomm_id - child_n_way, // n_way - child_work_id, // work_id - TRUE, // free_comm - *bszid_chl, // bszid - NULL // sub_node - ); - - bli_thread_barrier( thread_par ); - - // The parent's chief thread frees the temporary array of thrcomm_t - // pointers. - if ( bli_thread_am_ochief( thread_par ) ) + // The parent's chief thread frees the temporary array of thrcomm_t + // pointers. + if ( bli_thread_am_ochief( thread_par ) ) + { + if ( parent_n_way > BLIS_NUM_STATIC_COMMS ) + bli_free_intl( new_comms ); + } + } + else { - if ( parent_n_way > BLIS_NUM_STATIC_COMMS ) - bli_free_intl( new_comms ); + // No communicator is reqiured in cases where neither packa nor + // packb is enabled. + thread_chl = bli_thrinfo_create + ( + rntm, // rntm + NULL, // ocomm + child_comm_id, // ocomm_id + child_n_way, // n_way + child_work_id, // work_id + FALSE, // free_comm + *bszid_chl, // bszid + NULL // sub_node + ); } return thread_chl; From 79c6aa56432e989164d9d1a39ee7a945b8d51a94 Mon Sep 17 00:00:00 2001 From: Harsh Dave Date: Thu, 30 Dec 2021 22:59:39 -0600 Subject: [PATCH 065/243] Implemented optimal S/DCOMPLEX dotxf kernel - Optimized dotxf implementation for double and single precision complex datatype by handling dot product computation in tile 2x6 and 4x6 handling 6 columns at a time, and rows in multiple of 2 and 4. - Dot product computation is arranged such a way that multiple rho vector register will hold the temporary result till the end of loop and finally does horizontal addition to get final dot product result. - Corner cases are handled serially. - Optimal and reuse of vector registers for faster computation. AMD-Internal: [CPUPL-1975] Change-Id: I7dd305e73adf54100d54661769c7d5aada9b0098 --- config/zen/bli_cntx_init_zen.c | 4 +- config/zen2/bli_cntx_init_zen2.c | 4 +- config/zen3/bli_cntx_init_zen3.c | 4 +- config/zen4/bli_cntx_init_zen4.c | 4 +- kernels/zen/1f/bli_dotxf_zen_int_8.c | 902 ++++++++++++++++++++++++++- kernels/zen/bli_kernels_zen.h | 3 +- 6 files changed, 915 insertions(+), 6 deletions(-) diff --git a/config/zen/bli_cntx_init_zen.c b/config/zen/bli_cntx_init_zen.c index ec356fd231..eed39b3149 100644 --- a/config/zen/bli_cntx_init_zen.c +++ b/config/zen/bli_cntx_init_zen.c @@ -80,7 +80,7 @@ void bli_cntx_init_zen( cntx_t* cntx ) // Update the context with optimized level-1f kernels. bli_cntx_set_l1f_kers ( - 7, + 9, // axpyf BLIS_AXPYF_KER, BLIS_FLOAT, bli_saxpyf_zen_int_8, BLIS_AXPYF_KER, BLIS_DOUBLE, bli_daxpyf_zen_int_8, @@ -89,6 +89,8 @@ void bli_cntx_init_zen( cntx_t* cntx ) // dotxf BLIS_DOTXF_KER, BLIS_FLOAT, bli_sdotxf_zen_int_8, BLIS_DOTXF_KER, BLIS_DOUBLE, bli_ddotxf_zen_int_8, + BLIS_DOTXF_KER, BLIS_DCOMPLEX, bli_zdotxf_zen_int_6, + BLIS_DOTXF_KER, BLIS_SCOMPLEX, bli_cdotxf_zen_int_6, //axpy2v BLIS_AXPY2V_KER, BLIS_DOUBLE, bli_daxpy2v_zen_int, cntx diff --git a/config/zen2/bli_cntx_init_zen2.c b/config/zen2/bli_cntx_init_zen2.c index 47846ef22d..f6b8eef1e4 100644 --- a/config/zen2/bli_cntx_init_zen2.c +++ b/config/zen2/bli_cntx_init_zen2.c @@ -92,7 +92,7 @@ void bli_cntx_init_zen2( cntx_t* cntx ) // Update the context with optimized level-1f kernels. 
bli_cntx_set_l1f_kers ( - 7, + 9, // axpyf BLIS_AXPYF_KER, BLIS_FLOAT, bli_saxpyf_zen_int_5, BLIS_AXPYF_KER, BLIS_DOUBLE, bli_daxpyf_zen_int_5, @@ -101,6 +101,8 @@ void bli_cntx_init_zen2( cntx_t* cntx ) // dotxf BLIS_DOTXF_KER, BLIS_FLOAT, bli_sdotxf_zen_int_8, BLIS_DOTXF_KER, BLIS_DOUBLE, bli_ddotxf_zen_int_8, + BLIS_DOTXF_KER, BLIS_DCOMPLEX, bli_zdotxf_zen_int_6, + BLIS_DOTXF_KER, BLIS_SCOMPLEX, bli_cdotxf_zen_int_6, // axpy2v BLIS_AXPY2V_KER, BLIS_DOUBLE, bli_daxpy2v_zen_int, cntx diff --git a/config/zen3/bli_cntx_init_zen3.c b/config/zen3/bli_cntx_init_zen3.c index 7e7b120832..a043d5ad22 100644 --- a/config/zen3/bli_cntx_init_zen3.c +++ b/config/zen3/bli_cntx_init_zen3.c @@ -92,7 +92,7 @@ void bli_cntx_init_zen3( cntx_t* cntx ) // Update the context with optimized level-1f kernels. bli_cntx_set_l1f_kers ( - 7, + 9, // axpyf BLIS_AXPYF_KER, BLIS_FLOAT, bli_saxpyf_zen_int_5, BLIS_AXPYF_KER, BLIS_DOUBLE, bli_daxpyf_zen_int_5, @@ -101,6 +101,8 @@ void bli_cntx_init_zen3( cntx_t* cntx ) // dotxf BLIS_DOTXF_KER, BLIS_FLOAT, bli_sdotxf_zen_int_8, BLIS_DOTXF_KER, BLIS_DOUBLE, bli_ddotxf_zen_int_8, + BLIS_DOTXF_KER, BLIS_DCOMPLEX, bli_zdotxf_zen_int_6, + BLIS_DOTXF_KER, BLIS_SCOMPLEX, bli_cdotxf_zen_int_6, // axpy2v BLIS_AXPY2V_KER, BLIS_DOUBLE, bli_daxpy2v_zen_int, cntx diff --git a/config/zen4/bli_cntx_init_zen4.c b/config/zen4/bli_cntx_init_zen4.c index 98c15796ef..c340fa9087 100644 --- a/config/zen4/bli_cntx_init_zen4.c +++ b/config/zen4/bli_cntx_init_zen4.c @@ -91,7 +91,7 @@ void bli_cntx_init_zen4( cntx_t* cntx ) // Update the context with optimized level-1f kernels. bli_cntx_set_l1f_kers ( - 7, + 9, // axpyf BLIS_AXPYF_KER, BLIS_FLOAT, bli_saxpyf_zen_int_5, BLIS_AXPYF_KER, BLIS_DOUBLE, bli_daxpyf_zen_int_5, @@ -100,6 +100,8 @@ void bli_cntx_init_zen4( cntx_t* cntx ) // dotxf BLIS_DOTXF_KER, BLIS_FLOAT, bli_sdotxf_zen_int_8, BLIS_DOTXF_KER, BLIS_DOUBLE, bli_ddotxf_zen_int_8, + BLIS_DOTXF_KER, BLIS_DCOMPLEX, bli_zdotxf_zen_int_6, + BLIS_DOTXF_KER, BLIS_SCOMPLEX, bli_cdotxf_zen_int_6, // axpy2v BLIS_AXPY2V_KER, BLIS_DOUBLE, bli_daxpy2v_zen_int, cntx diff --git a/kernels/zen/1f/bli_dotxf_zen_int_8.c b/kernels/zen/1f/bli_dotxf_zen_int_8.c index ad27403bdc..815e388f21 100644 --- a/kernels/zen/1f/bli_dotxf_zen_int_8.c +++ b/kernels/zen/1f/bli_dotxf_zen_int_8.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2018, The University of Texas at Austin - Copyright (C) 2017 - 21, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2017 - 22, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -1542,4 +1542,904 @@ void bli_ddotxf_zen_int_2 } } +/** + * Performs dotxf operation on dcomplex. + * x and y are vectors and a is the matrix. + * Computation is done on 6 columns at a time + * Marches through vectors in multiple of 2. + */ +void bli_zdotxf_zen_int_6 + ( + conj_t conjat, + conj_t conjx, + dim_t m, + dim_t b_n, + dcomplex* restrict alpha, + dcomplex* restrict a, inc_t inca, inc_t lda, + dcomplex* restrict x, inc_t incx, + dcomplex* restrict beta, + dcomplex* restrict y, inc_t incy, + cntx_t* restrict cntx + ) +{ + /** + * Handles only unit stride cases and 6 column at a time + * b_n check for columns to be 6. + */ + if ( (inca == 1) && (incx == 1) && (incy == 1) && (b_n == 6) ) + { + /* Temporary rho buffer holds computed dot product result */ + dcomplex r[ 6 ]; + + /* If beta is zero, clear y. Otherwise, scale by beta. 
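+		   The six fused dot products computed below are then accumulated
+		   on top of this as y[j] += alpha * r[j].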
*/ + if ( PASTEMAC(z,eq0)( *beta ) ) + { + for ( dim_t i = 0; i < 6; ++i ) + { + PASTEMAC(z,set0s)( y[i] ); + } + } + else + { + for ( dim_t i = 0; i < 6; ++i ) + { + PASTEMAC(z,scals)( *beta, y[i] ); + } + } + + /* If the vectors are empty or if alpha is zero, return early*/ + if ( bli_zero_dim1( m ) || PASTEMAC(z,eq0)( *alpha ) ) return; + + /* Initialize r vector to 0. */ + for ( dim_t i = 0; i < 6; ++i ) PASTEMAC(z,set0s)( r[i] ); + + /* If a must be conjugated, we do so indirectly by first + * toggling the effective conjugation of x and then conjugating + * the resulting do products. + * Rather conjugating each element of a matrix, final computed result + * can be conjugated at the end of loop. This takes off the overhead + * of conjugating each element inside the loop and improves the + * performance. + */ + conj_t conjx_use = conjx; + + if ( bli_is_conj( conjat ) ) + { + bli_toggle_conj( &conjx_use ); + } + + /* Setting rho vectors to 0 */ + v4df_t rho0v; rho0v.v = _mm256_setzero_pd(); + v4df_t rho1v; rho1v.v = _mm256_setzero_pd(); + v4df_t rho2v; rho2v.v = _mm256_setzero_pd(); + v4df_t rho3v; rho3v.v = _mm256_setzero_pd(); + v4df_t rho4v; rho4v.v = _mm256_setzero_pd(); + v4df_t rho5v; rho5v.v = _mm256_setzero_pd(); + + v4df_t rho6v; rho6v.v = _mm256_setzero_pd(); + v4df_t rho7v; rho7v.v = _mm256_setzero_pd(); + v4df_t rho8v; rho8v.v = _mm256_setzero_pd(); + v4df_t rho9v; rho9v.v = _mm256_setzero_pd(); + v4df_t rho10v; rho10v.v = _mm256_setzero_pd(); + v4df_t rho11v; rho11v.v = _mm256_setzero_pd(); + + /* Holds 2 dcomplex element of x vector + * for computing dot product with A tile + */ + v4df_t x0v, x1v; + /* Holds 2x6 tile of matrix A */ + v4df_t a0v, a1v, a2v, a3v, a4v, a5v; + /** + * Since complex datatype multiplication is + * being held in two sets of rho vectors. + * Where first set holds the computaion with + * real part of vector x and other holds + * imaginary part of vector x. + * For final computation, based on conj sign + * of imaginary component needs to be toggled. + */ + __m256d no_conju = _mm256_setr_pd(-1, 1, -1, 1); + __m256d conju = _mm256_setr_pd(1, -1, 1, -1); + dim_t iter = m / 2; + dim_t rem = m % 2; + dim_t i = 0; + + if ( bli_is_noconj( conjx_use ) ) + { + if(iter) + { + for ( ; (i+1) < m; i+=2) + { + /*Load 2 dcomplex elements from + * vector x + */ + x0v.v = _mm256_loadu_pd( + (double *)(x + i) ); + /* x1v.v holds imaginary part of dcomplex + * elements from vector x + * It will do following operation. + * R0 I0 R1 I1 => I0 I0 I1 I1 + * + */ + x1v.v = _mm256_permute_pd( x0v.v, 15 ); + /* x1v.v holds real part of dcomplex + * elements from vector x + * It will do following operation. 
+ * R0 I0 R1 I1 => R0 R0 R1 R1 + */ + x0v.v = _mm256_permute_pd( x0v.v, 0 ); + + /*Load 2x6 tile of matrix A*/ + a0v.v = _mm256_loadu_pd( (double *) + (a + i + 0 * lda) ); + a1v.v = _mm256_loadu_pd( (double *) + (a + i + 1 * lda) ); + a2v.v = _mm256_loadu_pd( (double *) + (a + i + 2 * lda) ); + a3v.v = _mm256_loadu_pd( (double *) + (a + i + 3 * lda) ); + a4v.v = _mm256_loadu_pd( (double *) + (a + i + 4 * lda) ); + a5v.v = _mm256_loadu_pd( (double *) + (a + i + 5 * lda) ); + + // perform: rho?v += a?v * x0v; + rho0v.v = _mm256_fmadd_pd( a0v.v, + x0v.v, rho0v.v ); + rho6v.v = _mm256_fmadd_pd( a0v.v, + x1v.v, rho6v.v ); + + rho1v.v = _mm256_fmadd_pd( a1v.v, + x0v.v, rho1v.v ); + rho7v.v = _mm256_fmadd_pd( a1v.v, + x1v.v, rho7v.v ); + + rho2v.v = _mm256_fmadd_pd( a2v.v, + x0v.v, rho2v.v ); + rho8v.v = _mm256_fmadd_pd( a2v.v, + x1v.v, rho8v.v ); + + rho3v.v = _mm256_fmadd_pd( a3v.v, + x0v.v, rho3v.v ); + rho9v.v = _mm256_fmadd_pd( a3v.v, + x1v.v, rho9v.v ); + + rho4v.v = _mm256_fmadd_pd( a4v.v, + x0v.v, rho4v.v ); + rho10v.v = _mm256_fmadd_pd( a4v.v, + x1v.v, rho10v.v ); + + rho5v.v = _mm256_fmadd_pd( a5v.v, + x0v.v, rho5v.v ); + rho11v.v = _mm256_fmadd_pd( a5v.v, + x1v.v, rho11v.v ); + } + + /*Swapping position of real and imag component + * for horizontal addition to get the final + * dot product computation + * rho register are holding computation which needs + * to be arranged in following manner. + * Ra0*Ix0 | Ia0*Ix0 | Ra1*Ix1 | Ia1*Ix1 + * || + * \/ + * Ia0*Ix0 | Ra0*Ix0 | Ia1*Ix1 | Ra1*Ix1 + */ + rho6v.v = _mm256_permute_pd(rho6v.v, 0x05); + rho7v.v = _mm256_permute_pd(rho7v.v, 0x05); + rho8v.v = _mm256_permute_pd(rho8v.v, 0x05); + rho9v.v = _mm256_permute_pd(rho9v.v, 0x05); + rho10v.v = _mm256_permute_pd(rho10v.v, 0x05); + rho11v.v = _mm256_permute_pd(rho11v.v, 0x05); + + /*Negating imaginary part for computing + * the final result of dcomplex multiplication + */ + rho6v.v = _mm256_mul_pd(rho6v.v, no_conju); + rho7v.v = _mm256_mul_pd(rho7v.v, no_conju); + rho8v.v = _mm256_mul_pd(rho8v.v, no_conju); + rho9v.v = _mm256_mul_pd(rho9v.v, no_conju); + rho10v.v = _mm256_mul_pd(rho10v.v, no_conju); + rho11v.v = _mm256_mul_pd(rho11v.v, no_conju); + + rho0v.v = _mm256_add_pd(rho0v.v, rho6v.v); + rho1v.v = _mm256_add_pd(rho1v.v, rho7v.v); + rho2v.v = _mm256_add_pd(rho2v.v, rho8v.v); + rho3v.v = _mm256_add_pd(rho3v.v, rho9v.v); + rho4v.v = _mm256_add_pd(rho4v.v, rho10v.v); + rho5v.v = _mm256_add_pd(rho5v.v, rho11v.v); + + /*rho0, rho1, rho2 holds final dot product + * result of 6 dcomplex elements. + */ + rho0v.d[0] += rho0v.d[2]; + rho0v.d[1] += rho0v.d[3]; + + rho0v.d[2] = rho1v.d[0] + rho1v.d[2]; + rho0v.d[3] = rho1v.d[1] + rho1v.d[3]; + + rho1v.d[0] = rho2v.d[0] + rho2v.d[2]; + rho1v.d[1] = rho2v.d[1] + rho2v.d[3]; + + rho1v.d[2] = rho3v.d[0] + rho3v.d[2]; + rho1v.d[3] = rho3v.d[1] + rho3v.d[3]; + + rho2v.d[0] = rho4v.d[0] + rho4v.d[2]; + rho2v.d[1] = rho4v.d[1] + rho4v.d[3]; + + rho2v.d[2] = rho5v.d[0] + rho5v.d[2]; + rho2v.d[3] = rho5v.d[1] + rho5v.d[3]; + + /*Computed dot product result is being stored + * in temp buffer r for further computation. 
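+				 * (r is conjugated afterwards if conjat requires it, and
+				 * finally scaled by alpha into y.)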
+ */ + _mm256_storeu_pd((double *)r, rho0v.v); + _mm256_storeu_pd((double *)(r+2) , rho1v.v); + _mm256_storeu_pd((double *)(r+4) , rho2v.v); + + } + /*handles remainder cases*/ + if(rem) + { + PRAGMA_SIMD + for(dim_t p = 0; p < 6 ; p++) + { + PASTEMAC(z,axpys)( a[i + p*lda] + , x[i], r[p] ); + } + } + } + else + { + if(iter) + { + for ( ; (i+1) < m; i+=2) + { + /*Load 2 dcomplex elements from + * vector x + */ + x0v.v = _mm256_loadu_pd( (double *) + (x + i) ); + /* x1v.v holds imaginary part of dcomplex + * elements from vector x + */ + x1v.v = _mm256_permute_pd( x0v.v, 15 ); + /* x1v.v holds real part of dcomplex + * elements from vector x + */ + x0v.v = _mm256_permute_pd( x0v.v, 0 ); + + /*Load 2x6 tile of matrix A*/ + a0v.v = _mm256_loadu_pd( (double *) + (a + i + 0 * lda)); + a1v.v = _mm256_loadu_pd( (double *) + (a + i + 1 * lda)); + a2v.v = _mm256_loadu_pd( (double *) + (a + i + 2 * lda)); + a3v.v = _mm256_loadu_pd( (double *) + (a + i + 3 * lda)); + a4v.v = _mm256_loadu_pd( (double *) + (a + i + 4 * lda)); + a5v.v = _mm256_loadu_pd( (double *) + (a + i + 5 * lda)); + + // perform: rho?v += a?v * x0v; + rho0v.v = _mm256_fmadd_pd( a0v.v, + x0v.v, rho0v.v ); + rho6v.v = _mm256_fmadd_pd( a0v.v, + x1v.v, rho6v.v ); + + rho1v.v = _mm256_fmadd_pd( a1v.v, + x0v.v, rho1v.v ); + rho7v.v = _mm256_fmadd_pd( a1v.v, + x1v.v, rho7v.v ); + + rho2v.v = _mm256_fmadd_pd( a2v.v, + x0v.v, rho2v.v ); + rho8v.v = _mm256_fmadd_pd( a2v.v, + x1v.v, rho8v.v ); + + rho3v.v = _mm256_fmadd_pd( a3v.v, + x0v.v, rho3v.v ); + rho9v.v = _mm256_fmadd_pd( a3v.v, + x1v.v, rho9v.v ); + + rho4v.v = _mm256_fmadd_pd( a4v.v, + x0v.v, rho4v.v ); + rho10v.v = _mm256_fmadd_pd( a4v.v, + x1v.v, rho10v.v ); + + rho5v.v = _mm256_fmadd_pd( a5v.v, + x0v.v, rho5v.v ); + rho11v.v = _mm256_fmadd_pd( a5v.v, + x1v.v, rho11v.v ); + } + + /*Swapping position of real and imag component + * for horizontal addition to get the final + * dot product computation + * rho register are holding computation which needs + * to be arranged in following manner. + * Ra0*Ix0 | Ia0*Ix0 | Ra1*Ix1 | Ia1*Ix1 + * || + * \/ + * Ia0*Ix0 | Ra0*Ix0 | Ia1*Ix1 | Ra1*Ix1 + */ + rho6v.v = _mm256_permute_pd(rho6v.v, 0x05); + rho7v.v = _mm256_permute_pd(rho7v.v, 0x05); + rho8v.v = _mm256_permute_pd(rho8v.v, 0x05); + rho9v.v = _mm256_permute_pd(rho9v.v, 0x05); + rho10v.v = _mm256_permute_pd(rho10v.v, 0x05); + rho11v.v = _mm256_permute_pd(rho11v.v, 0x05); + + /*Negating imaginary part for computing + * the final result of dcomplex multiplication + */ + rho6v.v = _mm256_mul_pd(rho6v.v, conju); + rho7v.v = _mm256_mul_pd(rho7v.v, conju); + rho8v.v = _mm256_mul_pd(rho8v.v, conju); + rho9v.v = _mm256_mul_pd(rho9v.v, conju); + rho10v.v = _mm256_mul_pd(rho10v.v, conju); + rho11v.v = _mm256_mul_pd(rho11v.v, conju); + + rho0v.v = _mm256_add_pd(rho0v.v, rho6v.v); + rho1v.v = _mm256_add_pd(rho1v.v, rho7v.v); + rho2v.v = _mm256_add_pd(rho2v.v, rho8v.v); + rho3v.v = _mm256_add_pd(rho3v.v, rho9v.v); + rho4v.v = _mm256_add_pd(rho4v.v, rho10v.v); + rho5v.v = _mm256_add_pd(rho5v.v, rho11v.v); + + /*rho0, rho1, rho2 holds final dot product + * result of 6 dcomplex elements. 
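+				 * (Each ymm register packs two dcomplex values, so three
+				 * registers are enough to carry all six column results.)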
+ */ + rho0v.d[0] += rho0v.d[2]; + rho0v.d[1] += rho0v.d[3]; + + rho0v.d[2] = rho1v.d[0] + rho1v.d[2]; + rho0v.d[3] = rho1v.d[1] + rho1v.d[3]; + + rho1v.d[0] = rho2v.d[0] + rho2v.d[2]; + rho1v.d[1] = rho2v.d[1] + rho2v.d[3]; + + rho1v.d[2] = rho3v.d[0] + rho3v.d[2]; + rho1v.d[3] = rho3v.d[1] + rho3v.d[3]; + + rho2v.d[0] = rho4v.d[0] + rho4v.d[2]; + rho2v.d[1] = rho4v.d[1] + rho4v.d[3]; + + rho2v.d[2] = rho5v.d[0] + rho5v.d[2]; + rho2v.d[3] = rho5v.d[1] + rho5v.d[3]; + + /*Computed dot product result is being stored + * in temp buffer r for further computation. + */ + _mm256_storeu_pd((double *)r, rho0v.v); + _mm256_storeu_pd((double *)(r+2) , rho1v.v); + _mm256_storeu_pd((double *)(r+4) , rho2v.v); + + } + if(rem) + { + PRAGMA_SIMD + for(dim_t p = 0; p < 6 ; p++) + { + PASTEMAC(z,axpyjs)(a[i + p*lda] + , x[i], r[p] ); + } + } + } + + if ( bli_is_conj( conjat ) ) + for ( dim_t i = 0; i < 6; ++i ) + { + PASTEMAC(z,conjs)( r[i] ); + } + + /*scaling dot product result with alpha and + * adding the result to vector + */ + for ( dim_t i = 0; i < 6; ++i ) + { + PASTEMAC(z,axpys)( *alpha, r[i], y[i] ); + } + } + else + { + /* Query the context for the kernel function pointer. */ + const num_t dt = PASTEMAC(z,type); + PASTECH(z,dotxv_ker_ft) kfp_dv + = + bli_cntx_get_l1v_ker_dt( dt, BLIS_DOTXV_KER, cntx ); + + for ( dim_t i = 0; i < b_n; ++i ) + { + dcomplex* restrict a1 = a + (0 )*inca + (i )*lda; + dcomplex* restrict x1 = x + (0 )*incx; + dcomplex* restrict psi1 = y + (i )*incy; + + kfp_dv + ( + conjat, + conjx, + m, + alpha, + a1, inca, + x1, incx, + beta, + psi1, + cntx + ); + } + } + +} + + +/** + * Performs dotxf operation on scomplex. + * x and y are vectors and a is the matrix. + * Computation is done on 6 columns at a time + * Marches through vectors in multiple of 4 and 2. + */ +void bli_cdotxf_zen_int_6 + ( + conj_t conjat, + conj_t conjx, + dim_t m, + dim_t b_n, + scomplex* restrict alpha, + scomplex* restrict a, inc_t inca, inc_t lda, + scomplex* restrict x, inc_t incx, + scomplex* restrict beta, + scomplex* restrict y, inc_t incy, + cntx_t* restrict cntx + ) +{ + if ( (inca == 1) && (incx == 1) && (incy == 1) && (b_n == 6) ) + { + /* Temporary rho buffer holds computed dot product result */ + scomplex r[ 6 ]; + + /* If beta is zero, clear y. Otherwise, scale by beta. */ + if ( PASTEMAC(c,eq0)( *beta ) ) + { + for ( dim_t i = 0; i < 6; ++i ) + { + PASTEMAC(c,set0s)( y[i] ); + } + } + else + { + for ( dim_t i = 0; i < 6; ++i ) + { + PASTEMAC(c,scals)( *beta, y[i] ); + } + } + + /* If the vectors are empty or if alpha is zero, return early. */ + if ( bli_zero_dim1( m ) || PASTEMAC(c,eq0)( *alpha ) ) return; + + /* Initialize r vector to 0. */ + for ( dim_t i = 0; i < 6; ++i ) PASTEMAC(c,set0s)( r[i] ); + + /* If a must be conjugated, we do so indirectly by first toggling the + effective conjugation of x and then conjugating the resulting do + products. 
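+		   This relies on the identity conj(a) * chi == conj( a * conj(chi) ):
+		   the main loops run with the toggled conjugation on x, and r is
+		   conjugated once at the end.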
*/ + conj_t conjx_use = conjx; + + if ( bli_is_conj( conjat ) ) + bli_toggle_conj( &conjx_use ); + + dim_t iter = m / 2; + dim_t iter4 = m / 4; + dim_t rem = m % 2; + dim_t i = 0; + if(iter) + { + if(iter4) + { + /* Setting rho vectors to 0 */ + __m256 rho0v; rho0v = _mm256_setzero_ps(); + __m256 rho1v; rho1v = _mm256_setzero_ps(); + __m256 rho2v; rho2v = _mm256_setzero_ps(); + __m256 rho3v; rho3v = _mm256_setzero_ps(); + __m256 rho4v; rho4v = _mm256_setzero_ps(); + __m256 rho5v; rho5v = _mm256_setzero_ps(); + + __m256 rho6v; rho6v = _mm256_setzero_ps(); + __m256 rho7v; rho7v = _mm256_setzero_ps(); + __m256 rho8v; rho8v = _mm256_setzero_ps(); + __m256 rho9v; rho9v = _mm256_setzero_ps(); + __m256 rho10v; rho10v = _mm256_setzero_ps(); + __m256 rho11v; rho11v = _mm256_setzero_ps(); + /* Holds 2 dcomplex element of x vector + * for computing dot product with A tile + */ + __m256 x0v, x1v; + /* Holds 2x6 tile of matrix A */ + __m256 a0v, a1v, a2v, a3v, a4v, a5v; + /** + * Since complex datatype multiplication is + * being held in two sets of rho vectors. + * Where first set holds the computaion with + * real part of vector x and other holds + * imaginary part of vector x. + * For final computation, based on conj sign + * of imaginary component needs to be toggled. + */ + __m256 no_conju = _mm256_setr_ps(-1, 1, -1, 1, -1, 1, -1, 1); + __m256 conju = _mm256_setr_ps(1, -1, 1, -1, 1, -1, 1, -1); + + // March through vectos in multiple of 4. + for ( ; (i+3) < m; i+=4) + { + /*Load 4 scomplex elements from vector x*/ + x0v = _mm256_loadu_ps( (float *) (x + i) ); + /* x1v.v holds imaginary part of dcomplex + * elements from vector x + */ + x1v = _mm256_permute_ps( x0v, 0xf5 ); + /* x1v.v holds real part of dcomplex + * elements from vector x + */ + x0v = _mm256_permute_ps( x0v, 0xa0); + /* x1v.v holds imag part of dcomplex + Load 4x6 tile of matrix A*/ + a0v = _mm256_loadu_ps( (float *)(a + i + 0 * lda)); + a1v = _mm256_loadu_ps( (float *)(a + i + 1 * lda)); + a2v = _mm256_loadu_ps( (float *)(a + i + 2 * lda)); + a3v = _mm256_loadu_ps( (float *)(a + i + 3 * lda)); + a4v = _mm256_loadu_ps( (float *)(a + i + 4 * lda)); + a5v = _mm256_loadu_ps( (float *)(a + i + 5 * lda)); + + // perform: rho?v += a?v * x0v; + + rho0v = _mm256_fmadd_ps( a0v, x0v, rho0v ); + rho6v = _mm256_fmadd_ps( a0v, x1v, rho6v ); + + rho1v = _mm256_fmadd_ps( a1v, x0v, rho1v ); + rho7v = _mm256_fmadd_ps( a1v, x1v, rho7v ); + + rho2v = _mm256_fmadd_ps( a2v, x0v, rho2v ); + rho8v = _mm256_fmadd_ps( a2v, x1v, rho8v ); + + rho3v = _mm256_fmadd_ps( a3v, x0v, rho3v ); + rho9v = _mm256_fmadd_ps( a3v, x1v, rho9v ); + + rho4v = _mm256_fmadd_ps( a4v, x0v, rho4v ); + rho10v = _mm256_fmadd_ps( a4v, x1v, rho10v ); + + rho5v = _mm256_fmadd_ps( a5v, x0v, rho5v ); + rho11v = _mm256_fmadd_ps( a5v, x1v, rho11v ); + } + + + /*Swapping position of real and imag component + * for horizontal addition to get the final + * dot product computation + * rho register are holding computation which needs + * to be arranged in following manner. 
+ * Ra0*Ix0 | Ia0*Ix0 | Ra1*Ix1 | Ia1*Ix1 + * || + * \/ + * Ia0*Ix0 | Ra0*Ix0 | Ia1*Ix1 | Ra1*Ix1 + */ + + rho6v = _mm256_permute_ps(rho6v, 0xb1); + rho7v = _mm256_permute_ps(rho7v, 0xb1); + rho8v = _mm256_permute_ps(rho8v, 0xb1); + rho9v = _mm256_permute_ps(rho9v, 0xb1); + rho10v = _mm256_permute_ps(rho10v, 0xb1); + rho11v = _mm256_permute_ps(rho11v, 0xb1); + + /*Negating imaginary part for computing + * the final result of dcomplex multiplication*/ + if ( bli_is_noconj( conjx_use ) ) + { + rho6v = _mm256_mul_ps(rho6v, no_conju); + rho7v = _mm256_mul_ps(rho7v, no_conju); + rho8v = _mm256_mul_ps(rho8v, no_conju); + rho9v = _mm256_mul_ps(rho9v, no_conju); + rho10v = _mm256_mul_ps(rho10v, no_conju); + rho11v = _mm256_mul_ps(rho11v, no_conju); + } + else + { + + rho6v = _mm256_mul_ps(rho6v, conju); + rho7v = _mm256_mul_ps(rho7v, conju); + rho8v = _mm256_mul_ps(rho8v, conju); + rho9v = _mm256_mul_ps(rho9v, conju); + rho10v = _mm256_mul_ps(rho10v, conju); + rho11v = _mm256_mul_ps(rho11v, conju); + + } + + rho0v = _mm256_add_ps(rho0v, rho6v); + rho1v = _mm256_add_ps(rho1v, rho7v); + rho2v = _mm256_add_ps(rho2v, rho8v); + rho3v = _mm256_add_ps(rho3v, rho9v); + rho4v = _mm256_add_ps(rho4v, rho10v); + rho5v = _mm256_add_ps(rho5v, rho11v); + + /** + * Horizontal addition of rho elements + * for computing final dotxf result. + * ptr pointer addresses all 6 rho + * register one by one and store the + * computed result into r buffer. + */ + scomplex *ptr = (scomplex *)&rho0v; + for(dim_t i = 0; i < 4; i++) + { + r[0].real += ptr[i].real; + r[0].imag += ptr[i].imag; + } + ptr = (scomplex *)&rho1v; + for(dim_t i = 0; i < 4; i++) + { + r[1].real += ptr[i].real; + r[1].imag += ptr[i].imag; + } + ptr = (scomplex *)&rho2v; + for(dim_t i = 0; i < 4; i++) + { + r[2].real += ptr[i].real; + r[2].imag += ptr[i].imag; + } + ptr = (scomplex *)&rho3v; + for(dim_t i = 0; i < 4; i++) + { + r[3].real += ptr[i].real; + r[3].imag += ptr[i].imag; + } + ptr = (scomplex *)&rho4v; + for(dim_t i = 0; i < 4; i++) + { + r[4].real += ptr[i].real; + r[4].imag += ptr[i].imag; + } + ptr = (scomplex *)&rho5v; + for(dim_t i = 0; i < 4; i++) + { + r[5].real += ptr[i].real; + r[5].imag += ptr[i].imag; + } + } + // March through vectos in multiple of 2. + if(i+1 < m) + { + /* Setting rho vectors to 0 */ + __m128 rho0v; rho0v = _mm_setzero_ps(); + __m128 rho1v; rho1v = _mm_setzero_ps(); + __m128 rho2v; rho2v = _mm_setzero_ps(); + __m128 rho3v; rho3v = _mm_setzero_ps(); + __m128 rho4v; rho4v = _mm_setzero_ps(); + __m128 rho5v; rho5v = _mm_setzero_ps(); + + __m128 rho6v; rho6v = _mm_setzero_ps(); + __m128 rho7v; rho7v = _mm_setzero_ps(); + __m128 rho8v; rho8v = _mm_setzero_ps(); + __m128 rho9v; rho9v = _mm_setzero_ps(); + __m128 rho10v; rho10v = _mm_setzero_ps(); + __m128 rho11v; rho11v = _mm_setzero_ps(); + /* Holds 2 dcomplex element of x vector + * for computing dot product with A tile + */ + __m128 x0v, x1v; + /* Holds 2x6 tile of matrix A */ + __m128 a0v, a1v, a2v, a3v, a4v, a5v; + /** + * Since complex datatype multiplication is + * being held in two sets of rho vectors. + * Where first set holds the computaion with + * real part of vector x and other holds + * imaginary part of vector x. + * For final computation, based on conj sign + * of imaginary component needs to be toggled. 
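+			 * Concretely, the first set holds (Ra*Rx, Ia*Rx) pairs and the
+			 * second set holds (Ra*Ix, Ia*Ix); after the swap below,
+			 * multiplying by no_conju negates the Ia*Ix term so the sum is
+			 * (Ra*Rx - Ia*Ix, Ia*Rx + Ra*Ix) = a*x, while conju negates
+			 * Ra*Ix instead, yielding a*conj(x).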
+ */ + __m128 no_conju = _mm_setr_ps(-1, 1, -1, 1); + __m128 conju = _mm_setr_ps(1, -1, 1, -1); + + for ( ; (i+1) < m; i+=2) + { + /*Load 4 scomplex elements from vector x*/ + x0v = _mm_loadu_ps( (float *)(x + i) ); + /* x1v.v holds imaginary part of dcomplex + * elements from vector x + */ + x1v = _mm_permute_ps( x0v, 0xf5 ); + /* x1v.v holds real part of dcomplex + * elements from vector x + */ + x0v = _mm_permute_ps( x0v, 0xa0); + /* x1v.v holds imag part of dcomplex + Load 4x6 tile of matrix A*/ + + a0v = _mm_loadu_ps( (float *)(a + i + 0 * lda)); + a1v = _mm_loadu_ps( (float *)(a + i + 1 * lda)); + a2v = _mm_loadu_ps( (float *)(a + i + 2 * lda)); + a3v = _mm_loadu_ps( (float *)(a + i + 3 * lda)); + a4v = _mm_loadu_ps( (float *)(a + i + 4 * lda)); + a5v = _mm_loadu_ps( (float *)(a + i + 5 * lda)); + + // perform: rho?v += a?v * x0v; + + rho0v = _mm_fmadd_ps( a0v, x0v, rho0v ); + rho6v = _mm_fmadd_ps( a0v, x1v, rho6v ); + + rho1v = _mm_fmadd_ps( a1v, x0v, rho1v ); + rho7v = _mm_fmadd_ps( a1v, x1v, rho7v ); + + rho2v = _mm_fmadd_ps( a2v, x0v, rho2v ); + rho8v = _mm_fmadd_ps( a2v, x1v, rho8v ); + + rho3v = _mm_fmadd_ps( a3v, x0v, rho3v ); + rho9v = _mm_fmadd_ps( a3v, x1v, rho9v ); + + rho4v = _mm_fmadd_ps( a4v, x0v, rho4v ); + rho10v = _mm_fmadd_ps( a4v, x1v, rho10v ); + + rho5v = _mm_fmadd_ps( a5v, x0v, rho5v ); + rho11v = _mm_fmadd_ps( a5v, x1v, rho11v ); + } + /*Swapping position of real and imag component + * for horizontal addition to get the final + * dot product computation + * rho register are holding computation which needs + * to be arranged in following manner. + * Ra0*Ix0 | Ia0*Ix0 | Ra1*Ix1 | Ia1*Ix1 + * || + * \/ + * Ia0*Ix0 | Ra0*Ix0 | Ia1*Ix1 | Ra1*Ix1 + */ + rho6v = _mm_permute_ps(rho6v, 0xb1); + rho7v = _mm_permute_ps(rho7v, 0xb1); + rho8v = _mm_permute_ps(rho8v, 0xb1); + rho9v = _mm_permute_ps(rho9v, 0xb1); + rho10v = _mm_permute_ps(rho10v, 0xb1); + rho11v = _mm_permute_ps(rho11v, 0xb1); + + /*Negating imaginary part for computing + * the final result of dcomplex multiplication*/ + if ( bli_is_noconj( conjx_use ) ) + { + + rho6v = _mm_mul_ps(rho6v, no_conju); + rho7v = _mm_mul_ps(rho7v, no_conju); + rho8v = _mm_mul_ps(rho8v, no_conju); + rho9v = _mm_mul_ps(rho9v, no_conju); + rho10v = _mm_mul_ps(rho10v, no_conju); + rho11v = _mm_mul_ps(rho11v, no_conju); + } + else + { + rho6v = _mm_mul_ps(rho6v, conju); + rho7v = _mm_mul_ps(rho7v, conju); + rho8v = _mm_mul_ps(rho8v, conju); + rho9v = _mm_mul_ps(rho9v, conju); + rho10v = _mm_mul_ps(rho10v, conju); + rho11v = _mm_mul_ps(rho11v, conju); + } + + rho0v = _mm_add_ps(rho0v, rho6v); + rho1v = _mm_add_ps(rho1v, rho7v); + rho2v = _mm_add_ps(rho2v, rho8v); + rho3v = _mm_add_ps(rho3v, rho9v); + rho4v = _mm_add_ps(rho4v, rho10v); + rho5v = _mm_add_ps(rho5v, rho11v); + + /** + * Horizontal addition of rho elements + * for computing final dotxf result. + * ptr pointer addresses all 6 rho + * register one by one and store the + * computed result into r buffer. 
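+			 * (Each __m128 register carries two partial scomplex sums here,
+			 * so only two elements are reduced per column.)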
+ */ + scomplex *ptr = (scomplex *)&rho0v; + for(dim_t i = 0; i < 2; i++) + { + r[0].real += ptr[i].real; + r[0].imag += ptr[i].imag; + } + ptr = (scomplex *)&rho1v; + for(dim_t i = 0; i < 2; i++) + { + r[1].real += ptr[i].real; + r[1].imag += ptr[i].imag; + } + ptr = (scomplex *)&rho2v; + for(dim_t i = 0; i < 2; i++) + { + r[2].real += ptr[i].real; + r[2].imag += ptr[i].imag; + } + ptr = (scomplex *)&rho3v; + for(dim_t i = 0; i < 2; i++) + { + r[3].real += ptr[i].real; + r[3].imag += ptr[i].imag; + } + ptr = (scomplex *)&rho4v; + for(dim_t i = 0; i < 2; i++) + { + r[4].real += ptr[i].real; + r[4].imag += ptr[i].imag; + } + ptr = (scomplex *)&rho5v; + for(dim_t i = 0; i < 2; i++) + { + r[5].real += ptr[i].real; + r[5].imag += ptr[i].imag; + } + } + } + /*handles remainder cases*/ + if(rem) + { + if ( bli_is_noconj( conjx_use ) ) + { + + PRAGMA_SIMD + for(dim_t p = 0; p < 6 ; p++) + { + PASTEMAC(c,axpys)( a[i + p*lda], x[i], r[p] ); + } + } + else + { + PRAGMA_SIMD + for(dim_t p = 0; p < 6 ; p++) + { + PASTEMAC(c,axpyjs)( a[i + p*lda], x[i], r[p] ); + } + + } + } + + if ( bli_is_conj( conjat ) ) + { + for ( dim_t i = 0; i < 6; ++i ) + { + PASTEMAC(c,conjs)( r[i] ); + } + } + + /*scaling dot product result with alpha and + * adding the result to vector + */ + for ( dim_t i = 0; i < 6; ++i ) + { + PASTEMAC(c,axpys)( *alpha, r[i], y[i] ); + } + } + else + { + /* Query the context for the kernel function pointer. */ + const num_t dt = PASTEMAC(c,type); + PASTECH(c,dotxv_ker_ft) kfp_dv + = + bli_cntx_get_l1v_ker_dt( dt, BLIS_DOTXV_KER, cntx ); + + for ( dim_t i = 0; i < b_n; ++i ) + { + scomplex* restrict a1 = a + (0 )*inca + (i )*lda; + scomplex* restrict x1 = x + (0 )*incx; + scomplex* restrict psi1 = y + (i )*incy; + + kfp_dv + ( + conjat, + conjx, + m, + alpha, + a1, inca, + x1, incx, + beta, + psi1, + cntx + ); + } + } +} diff --git a/kernels/zen/bli_kernels_zen.h b/kernels/zen/bli_kernels_zen.h index 62b19e9a2d..fea59796a8 100644 --- a/kernels/zen/bli_kernels_zen.h +++ b/kernels/zen/bli_kernels_zen.h @@ -123,7 +123,8 @@ DOTXF_KER_PROT( float, s, dotxf_zen_int_8 ) DOTXF_KER_PROT( double, d, dotxf_zen_int_8 ) DOTXF_KER_PROT( double, d, dotxf_zen_int_4 ) DOTXF_KER_PROT( double, d, dotxf_zen_int_2 ) - +DOTXF_KER_PROT( dcomplex, z, dotxf_zen_int_6 ) +DOTXF_KER_PROT( scomplex, c, dotxf_zen_int_6 ) // dotxaxpyf (intrinsics) DOTXAXPYF_KER_PROT( double, d, dotxaxpyf_zen_int_8 ) From f63f78d783accc5a86b3e6469e647698518e4e2c Mon Sep 17 00:00:00 2001 From: Dipal M Zambare Date: Mon, 20 Dec 2021 09:43:13 +0530 Subject: [PATCH 066/243] Removed Arch specific code from BLIS framework. - Removed BLIS_CONFIG_EPYC macro - The code dependent on this macro is handled in one of the three ways -- It is updated to work across platforms. -- Added in architecture/feature specific runtime checks. -- Duplicated in AMD specific files. 
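For the runtime-check option, the dispatch kept in the AMD-specific files
follows the pattern already visible in the gemv variants touched below
(illustrative sketch only; the exact query used can differ per file):

    arch_t id = bli_arch_query_id();
    bool bamdzen = ( id == BLIS_ARCH_ZEN4 ) || ( id == BLIS_ARCH_ZEN3 ) ||
                   ( id == BLIS_ARCH_ZEN2 ) || ( id == BLIS_ARCH_ZEN  );
    if ( !bamdzen )
    {
        /* Fall back to the reference kernel queried from the context. */
    }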
Build system is updated to pick AMD specific files when library is built for any of the zen architecture AMD-Internal: [CPUPL-1960] Change-Id: I6f9f8018e41fa48eb43ae4245c9c2c361857f43b --- Makefile | 24 +- build/config.mk.in | 4 +- config/amdzen/make_defs.mk | 12 +- config/zen/make_defs.mk | 19 +- config/zen2/make_defs.mk | 16 +- config/zen3/make_defs.mk | 16 +- config/zen4/make_defs.mk | 16 +- configure | 3 +- frame/2/gemv/bli_gemv_unf_var1.c | 356 +----- frame/2/gemv/bli_gemv_unf_var1_amd.c | 440 ++++++++ frame/2/gemv/bli_gemv_unf_var2.c | 764 +------------ frame/2/gemv/bli_gemv_unf_var2_amd.c | 879 +++++++++++++++ frame/2/hemv/bli_hemv_unf_var1.c | 204 +--- frame/2/hemv/bli_hemv_unf_var1_amd.c | 418 +++++++ frame/2/hemv/bli_hemv_unf_var3.c | 208 +--- frame/2/hemv/bli_hemv_unf_var3_amd.c | 420 +++++++ frame/2/her2/bli_her2_unf_var1.c | 212 ---- frame/2/her2/bli_her2_unf_var1_amd.c | 369 ++++++ frame/2/her2/bli_her2_unf_var4.c | 187 ---- frame/2/her2/bli_her2_unf_var4_amd.c | 354 ++++++ frame/2/trsv/bli_trsv_unf_var1.c | 419 +------ frame/2/trsv/bli_trsv_unf_var1_amd.c | 638 +++++++++++ frame/2/trsv/bli_trsv_unf_var2.c | 804 +------------- frame/2/trsv/bli_trsv_unf_var2_amd.c | 1024 +++++++++++++++++ frame/3/bli_l3_sup_int.c | 128 +-- frame/3/bli_l3_sup_int_amd.c | 352 ++++++ frame/3/gemm/bli_gemm_front.c | 13 - frame/3/gemm/bli_gemm_front_amd.c | 413 +++++++ frame/base/bli_cpuid.c | 19 + frame/base/bli_cpuid.h | 4 +- frame/compat/bla_amax.c | 214 +--- frame/compat/bla_amax_amd.c | 295 +++++ frame/compat/bla_axpy.c | 407 +------ frame/compat/bla_axpy_amd.c | 462 ++++++++ frame/compat/bla_copy.c | 214 +--- frame/compat/bla_copy_amd.c | 285 +++++ frame/compat/bla_dot.c | 678 +---------- frame/compat/bla_dot_amd.c | 841 ++++++++++++++ frame/compat/bla_gemm.c | 507 +-------- frame/compat/bla_gemm_amd.c | 894 +++++++++++++++ frame/compat/bla_gemv.c | 853 +------------- frame/compat/bla_gemv_amd.c | 963 ++++++++++++++++ frame/compat/bla_scal.c | 176 +-- frame/compat/bla_scal_amd.c | 260 +++++ frame/compat/bla_swap.c | 195 +--- frame/compat/bla_swap_amd.c | 268 +++++ frame/compat/bla_trsm.c | 1172 +------------------ frame/compat/bla_trsm_amd.c | 1544 ++++++++++++++++++++++++++ kernels/zen/1/bli_scalv_zen_int10.c | 28 +- kernels/zen/1f/bli_axpyf_zen_int_4.c | 49 +- kernels/zen/1f/bli_axpyf_zen_int_5.c | 173 +-- kernels/zen/1f/bli_axpyf_zen_int_6.c | 26 +- kernels/zen/3/bli_gemm_small.c | 15 +- 53 files changed, 11226 insertions(+), 8028 deletions(-) create mode 100644 frame/2/gemv/bli_gemv_unf_var1_amd.c create mode 100644 frame/2/gemv/bli_gemv_unf_var2_amd.c create mode 100644 frame/2/hemv/bli_hemv_unf_var1_amd.c create mode 100644 frame/2/hemv/bli_hemv_unf_var3_amd.c create mode 100644 frame/2/her2/bli_her2_unf_var1_amd.c create mode 100644 frame/2/her2/bli_her2_unf_var4_amd.c create mode 100644 frame/2/trsv/bli_trsv_unf_var1_amd.c create mode 100644 frame/2/trsv/bli_trsv_unf_var2_amd.c create mode 100644 frame/3/bli_l3_sup_int_amd.c create mode 100644 frame/3/gemm/bli_gemm_front_amd.c create mode 100644 frame/compat/bla_amax_amd.c create mode 100644 frame/compat/bla_axpy_amd.c create mode 100644 frame/compat/bla_copy_amd.c create mode 100644 frame/compat/bla_dot_amd.c create mode 100644 frame/compat/bla_gemm_amd.c create mode 100644 frame/compat/bla_gemv_amd.c create mode 100644 frame/compat/bla_scal_amd.c create mode 100644 frame/compat/bla_swap_amd.c create mode 100644 frame/compat/bla_trsm_amd.c diff --git a/Makefile b/Makefile index b248d5781a..1658e16de2 100644 --- a/Makefile +++ 
b/Makefile @@ -5,7 +5,7 @@ # libraries. # # Copyright (C) 2014, The University of Texas at Austin -# Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -212,6 +212,27 @@ MK_REFKERN_OBJS := $(foreach arch, $(CONFIG_LIST), \ # Generate object file paths for all of the portable framework source code. MK_FRAME_OBJS := $(call gen-obj-paths-from-src,$(FRAME_SRC_SUFS),$(MK_FRAME_SRC),$(FRAME_PATH),$(BASE_OBJ_FRAME_PATH)) +# AMD has optimized some of the framework files, these optimizations +# may not be compatible with other platforms. +# +# In order to keep main framework code independent of AMD changes, +# AMD has duplicated the files and updated them for example +# frame/compact/bla_gemm.c : generic framework file +# frame/compact/bla_gemm_amd.c : AMD optimized framework file +# Based on the archiecture we choose correct files + +ifeq ($(MK_IS_ARCH_ZEN),yes) +# Build is being done for AMD platforms, remove the objects which +# don't have amd suffix (for which exists AMD specific implementation). +MK_FRAME_AMD_OBJS := $(filter $(BASE_OBJ_FRAME_PATH)/%amd.o, $(MK_FRAME_OBJS)) +FILES_TO_REMOVE := $(subst _amd.o,.o, $(MK_FRAME_AMD_OBJS)) +MK_FRAME_OBJS := $(filter-out $(FILES_TO_REMOVE), $(MK_FRAME_OBJS)) +else +# Build is done for non AMD platforms, remove the amd specific objects +MK_FRAME_AMD_OBJS := $(filter $(BASE_OBJ_FRAME_PATH)/%amd.o, $(MK_FRAME_OBJS)) +MK_FRAME_OBJS := $(filter-out $(MK_FRAME_AMD_OBJS), $(MK_FRAME_OBJS)) +endif + # Generate object file paths for all of the debgu and trace logger. MK_AOCLDTL_OBJS := $(call gen-obj-paths-from-src,$(AOCLDTL_SRC_SUFS),$(MK_AOCLDTL_SRC),$(AOCLDTL_PATH),$(BASE_OBJ_AOCLDTL_PATH)) @@ -1338,4 +1359,3 @@ else @echo "Uninstalling $(@F) from $(@D)/" @- $(RM_F) $@ endif - diff --git a/build/config.mk.in b/build/config.mk.in index 709e0f543c..a880074e8f 100644 --- a/build/config.mk.in +++ b/build/config.mk.in @@ -5,7 +5,7 @@ # libraries. # # Copyright (C) 2014, The University of Texas at Austin -# Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -204,5 +204,7 @@ MK_ENABLE_AOCL_DYNAMIC := @enable_aocl_dynamic@ # BLAS int size MK_BLAS_INT_TYPE_SIZE := @blas_int_type_size@ +MK_IS_ARCH_ZEN := @enable_aocl_zen@ + # end of ifndef CONFIG_MK_INCLUDED conditional block endif diff --git a/config/amdzen/make_defs.mk b/config/amdzen/make_defs.mk index 7697e9ff05..e467461601 100644 --- a/config/amdzen/make_defs.mk +++ b/config/amdzen/make_defs.mk @@ -4,7 +4,7 @@ # An object-based framework for developing high-performance BLAS-like # libraries. # -# Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -49,16 +49,6 @@ else COPTFLAGS := -O3 endif -# This will add BLIS_CONFIG_EPYC for all framework files -# FIXME: framework files should not have architecture specific -# checks at least at compile time. 
Once the macro -# is defined it is applicable to every build in the -# Family including any non AMD configuration. -# However, it is still better to define it in makefiles -# instead of headers so we can have slighly more -# control on this. -COPTFLAGS += -DBLIS_CONFIG_EPYC - # Store all of the variables here to new variables containing the # configuration name. $(eval $(call store-make-defs,$(THIS_CONFIG))) diff --git a/config/zen/make_defs.mk b/config/zen/make_defs.mk index be1086a1de..08d8628bec 100644 --- a/config/zen/make_defs.mk +++ b/config/zen/make_defs.mk @@ -5,7 +5,7 @@ # libraries. # # Copyright (C) 2014, The University of Texas at Austin -# Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -46,25 +46,12 @@ AMD_CONFIG_FILE := amd_config.mk AMD_CONFIG_PATH := $(BASE_SHARE_PATH)/config/zen -include $(AMD_CONFIG_PATH)/$(AMD_CONFIG_FILE) - -# Since we removed BLIS_CONFIG_EPYC from header file, we need to -# add it here at two places, -# CPPROCFLAGS = This will enable it for framework code -# This flag is used when configure is invoked with specific architecture -# CKOPTFLAGS = This will enable it for architecture specific kernels -# This flag is used for kernels assocaited with this architecture -# irrespective of the configuration it is built for. - -CPPROCFLAGS := -DBLIS_CONFIG_EPYC - - ifeq ($(DEBUG_TYPE),noopt) COPTFLAGS := -O0 else COPTFLAGS := -O3 endif - # # --- Enable ETRACE across the library if enabled ETRACE_ENABLE=[0,1] ----------------------- # @@ -86,10 +73,6 @@ else CRVECFLAGS := $(CKVECFLAGS) endif -# Add this after updating variables for reference kernels -# we don't want this defined for them -CKOPTFLAGS += -DBLIS_CONFIG_EPYC - # Store all of the variables here to new variables containing the # configuration name. $(eval $(call store-make-defs,$(THIS_CONFIG))) diff --git a/config/zen2/make_defs.mk b/config/zen2/make_defs.mk index ba91f722ab..3b87d35b00 100644 --- a/config/zen2/make_defs.mk +++ b/config/zen2/make_defs.mk @@ -5,7 +5,7 @@ # libraries. # # Copyright (C) 2014, The University of Texas at Austin -# Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -50,15 +50,7 @@ THIS_CONFIG := zen2 # general-purpose/configuration-agnostic flags in common.mk. You # may specify additional flags here as needed. -# Since we removed BLIS_CONFIG_EPYC from header file, we need to -# add it here at two places, -# CPPROCFLAGS = This will enable it for framework code -# This flag is used when configure is invoked with specific architecture -# CKOPTFLAGS = This will enable it for architecture specific kernels -# This flag is used for kernels assocaited with this architecture -# irrespective of the configuration it is built for. 
- -CPPROCFLAGS := -DBLIS_CONFIG_EPYC +CPPROCFLAGS := CMISCFLAGS := CPICFLAGS := CWARNFLAGS := @@ -111,10 +103,6 @@ endif CROPTFLAGS := $(CKOPTFLAGS) CRVECFLAGS := $(CKVECFLAGS) -# Add this after updating variables for reference kernels -# we don't want this defined for them -CKOPTFLAGS += -DBLIS_CONFIG_EPYC - # Store all of the variables here to new variables containing the # configuration name. $(eval $(call store-make-defs,$(THIS_CONFIG))) diff --git a/config/zen3/make_defs.mk b/config/zen3/make_defs.mk index a479acf8a5..8522a1e956 100644 --- a/config/zen3/make_defs.mk +++ b/config/zen3/make_defs.mk @@ -5,7 +5,7 @@ # libraries. # # Copyright (C) 2014, The University of Texas at Austin -# Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -50,15 +50,7 @@ THIS_CONFIG := zen3 # general-purpose/configuration-agnostic flags in common.mk. You # may specify additional flags here as needed. -# Since we removed BLIS_CONFIG_EPYC from header file, we need to -# add it here at two places, -# CPPROCFLAGS = This will enable it for framework code -# This flag is used when configure is invoked with specific architecture -# CKOPTFLAGS = This will enable it for architecture specific kernels -# This flag is used for kernels assocaited with this architecture -# irrespective of the configuration it is built for. - -CPPROCFLAGS := -DBLIS_CONFIG_EPYC +CPPROCFLAGS := CMISCFLAGS := CPICFLAGS := CWARNFLAGS := @@ -132,10 +124,6 @@ endif # gcc CROPTFLAGS := $(CKOPTFLAGS) CRVECFLAGS := $(CKVECFLAGS) -# Add this after updating variables for reference kernels -# we don't want this defined for them -CKOPTFLAGS += -DBLIS_CONFIG_EPYC - # Store all of the variables here to new variables containing the # configuration name. $(eval $(call store-make-defs,$(THIS_CONFIG))) diff --git a/config/zen4/make_defs.mk b/config/zen4/make_defs.mk index 352bd29c4e..44e96bb0c7 100644 --- a/config/zen4/make_defs.mk +++ b/config/zen4/make_defs.mk @@ -4,7 +4,7 @@ # An object-based framework for developing high-performance BLAS-like # libraries. # -# Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -49,15 +49,7 @@ THIS_CONFIG := zen4 # general-purpose/configuration-agnostic flags in common.mk. You # may specify additional flags here as needed. -# Since we removed BLIS_CONFIG_EPYC from header file, we need to -# add it here at two places, -# CPPROCFLAGS = This will enable it for framework code -# This flag is used when configure is invoked with specific architecture -# CKOPTFLAGS = This will enable it for architecture specific kernels -# This flag is used for kernels assocaited with this architecture -# irrespective of the configuration it is built for. - -CPPROCFLAGS := -DBLIS_CONFIG_EPYC +CPPROCFLAGS := CMISCFLAGS := CPICFLAGS := CWARNFLAGS := @@ -131,10 +123,6 @@ endif # gcc CROPTFLAGS := $(CKOPTFLAGS) CRVECFLAGS := $(CKVECFLAGS) -# Add this after updating variables for reference kernels -# we don't want this defined for them -CKOPTFLAGS += -DBLIS_CONFIG_EPYC - # Store all of the variables here to new variables containing the # configuration name. 
$(eval $(call store-make-defs,$(THIS_CONFIG))) diff --git a/configure b/configure index bec498d3cf..f49ea19e5e 100755 --- a/configure +++ b/configure @@ -5,7 +5,7 @@ # libraries. # # Copyright (C) 2014, The University of Texas at Austin -# Copyright (C) 2020-2021, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -3370,6 +3370,7 @@ main() | sed -e "s/@enable_aocl_dynamic@/${enable_aocl_dynamic}/g" \ | sed -e "s/@complex_return@/${complex_return}/g" \ | sed -e "s/@blas_int_type_size@/${blas_int_type_size}/g" \ + | sed -e "s/\@enable_aocl_zen\@/${enable_aocl_zen}/g" \ > "${config_mk_out_path}" diff --git a/frame/2/gemv/bli_gemv_unf_var1.c b/frame/2/gemv/bli_gemv_unf_var1.c index 838ea577bc..8162613c18 100644 --- a/frame/2/gemv/bli_gemv_unf_var1.c +++ b/frame/2/gemv/bli_gemv_unf_var1.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 21, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 22, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -104,357 +104,5 @@ void PASTEMAC(ch,varname) \ } \ } -#ifdef BLIS_CONFIG_EPYC -void bli_dgemv_unf_var1 - ( - trans_t transa, - conj_t conjx, - dim_t m, - dim_t n, - double* alpha, - double* a, inc_t rs_a, inc_t cs_a, - double* x, inc_t incx, - double* beta, - double* y, inc_t incy, - cntx_t* cntx - ) -{ - - double *A1; - double *y1; - dim_t i; - dim_t f; - dim_t n_elem, n_iter; - inc_t rs_at, cs_at; - conj_t conja; - //memory pool declarations for packing vector X. - mem_t mem_bufX; - rntm_t rntm; - double *x_buf = x; - inc_t buf_incx = incx; - - bli_init_once(); - - if (cntx == NULL) - cntx = bli_gks_query_cntx(); - - bli_set_dims_incs_with_trans(transa, - m, n, rs_a, cs_a, - &n_iter, &n_elem, &rs_at, &cs_at); - - conja = bli_extract_conj(transa); - - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN4) || - (id == BLIS_ARCH_ZEN3) || - (id == BLIS_ARCH_ZEN2) || - (id == BLIS_ARCH_ZEN); - - if (bamdzen == 0) - { - if ( cntx == NULL ) cntx = bli_gks_query_cntx(); - const num_t dt = PASTEMAC(d,type); - double* x1; - double* y1; - PASTECH(d,dotxf_ker_ft) kfp_df; - /* Query the context for the kernel function pointer and fusing factor. */ - kfp_df = bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXF_KER, cntx ); - dim_t b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_DF, cntx ); - - for ( i = 0; i < n_iter; i += f ) - { - f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); - - A1 = a + (i )*rs_at + (0 )*cs_at; - x1 = x + (0 )*incy; - y1 = y + (i )*incy; - - /* y1 = beta * y1 + alpha * A1 * x; */ - kfp_df - ( - conja, - conjx, - n_elem, - f, - alpha, - A1, cs_at, rs_at, - x1, incx, - beta, - y1, incy, - cntx - ); - - } - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); - return; - } - - if (incx > 1) - { - /* - Initialize mem pool buffer to NULL and size to 0 - "buf" and "size" fields are assigned once memory - is allocated from the pool in bli_membrk_acquire_m(). 
- This will ensure bli_mem_is_alloc() will be passed on - an allocated memory if created or a NULL . - */ - - mem_bufX.pblk.buf = NULL; - mem_bufX.pblk.block_size = 0; - mem_bufX.buf_type = 0; - mem_bufX.size = 0; - mem_bufX.pool = NULL; - - /* In order to get the buffer from pool via rntm access to memory broker - is needed.Following are initializations for rntm */ - - bli_rntm_init_from_global(&rntm); - bli_rntm_set_num_threads_only(1, &rntm); - bli_membrk_rntm_set_membrk(&rntm); - - //calculate the size required for n_elem double elements in vector X. - size_t buffer_size = n_elem * sizeof(double); - -#ifdef BLIS_ENABLE_MEM_TRACING - printf("bli_dgemv_unf_var1(): get mem pool block\n"); -#endif - - /*acquire a Buffer(n_elem*size(double)) from the memory broker - and save the associated mem_t entry to mem_bufX.*/ - bli_membrk_acquire_m(&rntm, - buffer_size, - BLIS_BUFFER_FOR_B_PANEL, - &mem_bufX); - - /*Continue packing X if buffer memory is allocated*/ - if ((bli_mem_is_alloc(&mem_bufX))) - { - x_buf = bli_mem_buffer(&mem_bufX); - - //pack X vector with non-unit stride to a temp buffer x_buf with unit stride - for (dim_t x_index = 0; x_index < n_elem; x_index++) - { - *(x_buf + x_index) = *(x + (x_index * incx)); - } - // stride of vector x_buf =1 - buf_incx = 1; - } - } - - dim_t fuse_factor = 8; - dim_t f_temp =0; - - if (n < 4) - { - fuse_factor = 2; - } else if (n < 8) - { - fuse_factor = 4; - } - - - for (i = 0; i < n_iter; i += f) - { - f = bli_determine_blocksize_dim_f(i, n_iter, fuse_factor); - - //A = a + i * row_increment + 0 * column_increment - A1 = a + (i)*rs_at; - y1 = y + (i)*incy; - - /* y1 = beta * y1 + alpha * A1 * x; */ - switch (f) - { - case 8: - - bli_ddotxf_zen_int_8( - conja, - conjx, - n_elem, - f, - alpha, - A1, cs_at, rs_at, - x_buf, buf_incx, - beta, - y1, incy, - cntx); - - break; - default: - - if (f < 4) - { - bli_ddotxf_zen_int_2( - conja, - conjx, - n_elem, - f, - alpha, - A1, cs_at, rs_at, - x_buf, buf_incx, - beta, - y1, incy, - cntx); - } - else - { - bli_ddotxf_zen_int_4( - conja, - conjx, - n_elem, - f, - alpha, - A1, cs_at, rs_at, - x_buf, buf_incx, - beta, - y1, incy, - cntx); - } - } - - f_temp = bli_determine_blocksize_dim_f(i + f, n_iter, fuse_factor); - - if (f_temp < fuse_factor) - { - switch (fuse_factor) - { - case 8: - fuse_factor = 4; - break; - case 4: - fuse_factor = 2; - break; - } - } - } - - if ((incx > 1) && bli_mem_is_alloc(&mem_bufX)) - { -#ifdef BLIS_ENABLE_MEM_TRACING - printf("bli_dgemv_unf_var1(): releasing mem pool block\n"); -#endif - // Return the buffer to pool - bli_membrk_release(&rntm, &mem_bufX); - } - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); -} - -void bli_sgemv_unf_var1 - ( - trans_t transa, - conj_t conjx, - dim_t m, - dim_t n, - float* alpha, - float* a, inc_t rs_a, inc_t cs_a, - float* x, inc_t incx, - float* beta, - float* y, inc_t incy, - cntx_t* cntx - ) -{ - - float* A1; - float* x1; - float* y1; - dim_t i; - dim_t b_fuse, f; - dim_t n_elem, n_iter; - inc_t rs_at, cs_at; - conj_t conja; - - bli_init_once(); - - if( cntx == NULL ) cntx = bli_gks_query_cntx(); - - bli_set_dims_incs_with_trans( transa, - m, n, rs_a, cs_a, - &n_iter, &n_elem, &rs_at, &cs_at ); - - conja = bli_extract_conj( transa ); - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. 
- // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN4) || - (id == BLIS_ARCH_ZEN3) || - (id == BLIS_ARCH_ZEN2) || - (id == BLIS_ARCH_ZEN); - - if (bamdzen == 0) - { - if ( cntx == NULL ) cntx = bli_gks_query_cntx(); - const num_t dt = PASTEMAC(s,type); - float* x1 ; - PASTECH(s,dotxf_ker_ft) kfp_df; - /* Query the context for the kernel function pointer and fusing factor. */ - kfp_df = bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXF_KER, cntx ); - b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_DF, cntx ); - - for ( i = 0; i < n_iter; i += f ) - { - f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); - - A1 = a + (i )*rs_at + (0 )*cs_at; - x1 = x + (0 )*incy; - y1 = y + (i )*incy; - - /* y1 = beta * y1 + alpha * A1 * x; */ - kfp_df - ( - conja, - conjx, - n_elem, - f, - alpha, - A1, cs_at, rs_at, - x1, incx, - beta, - y1, incy, - cntx - ); - - } - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); - return; - } - - /* Query the context for the kernel function pointer and fusing factor. */ - b_fuse = 8; - - for ( i = 0; i < n_iter; i += f ) - { - f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); - - A1 = a + (i )*rs_at + (0 )*cs_at; - x1 = x + (0 )*incy; - y1 = y + (i )*incy; - - /* y1 = beta * y1 + alpha * A1 * x; */ - bli_sdotxf_zen_int_8 - ( - conja, - conjx, - n_elem, - f, - alpha, - A1, cs_at, rs_at, - x1, incx, - beta, - y1, incy, - cntx - ); - - } -} - -INSERT_GENTFUNC_BASIC0_CZ( gemv_unf_var1 ) -#else INSERT_GENTFUNC_BASIC0( gemv_unf_var1 ) -#endif + diff --git a/frame/2/gemv/bli_gemv_unf_var1_amd.c b/frame/2/gemv/bli_gemv_unf_var1_amd.c new file mode 100644 index 0000000000..7228c12f75 --- /dev/null +++ b/frame/2/gemv/bli_gemv_unf_var1_amd.c @@ -0,0 +1,440 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020 - 22, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTEMAC(ch,varname) \ + ( \ + trans_t transa, \ + conj_t conjx, \ + dim_t m, \ + dim_t n, \ + ctype* alpha, \ + ctype* a, inc_t rs_a, inc_t cs_a, \ + ctype* x, inc_t incx, \ + ctype* beta, \ + ctype* y, inc_t incy, \ + cntx_t* cntx \ + ) \ +{ \ +\ + if(cntx == NULL) cntx = bli_gks_query_cntx(); \ +\ + const num_t dt = PASTEMAC(ch,type); \ +\ + ctype* A1; \ + ctype* x1; \ + ctype* y1; \ + dim_t i; \ + dim_t b_fuse, f; \ + dim_t n_elem, n_iter; \ + inc_t rs_at, cs_at; \ + conj_t conja; \ +\ + bli_set_dims_incs_with_trans( transa, \ + m, n, rs_a, cs_a, \ + &n_iter, &n_elem, &rs_at, &cs_at ); \ +\ + conja = bli_extract_conj( transa ); \ +\ + PASTECH(ch,dotxf_ker_ft) kfp_df; \ +\ + /* Query the context for the kernel function pointer and fusing factor. */ \ + kfp_df = bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXF_KER, cntx ); \ + b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_DF, cntx ); \ +\ + for ( i = 0; i < n_iter; i += f ) \ + { \ + f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); \ +\ + A1 = a + (i )*rs_at + (0 )*cs_at; \ + x1 = x + (0 )*incy; \ + y1 = y + (i )*incy; \ +\ + /* y1 = beta * y1 + alpha * A1 * x; */ \ + kfp_df \ + ( \ + conja, \ + conjx, \ + n_elem, \ + f, \ + alpha, \ + A1, cs_at, rs_at, \ + x1, incx, \ + beta, \ + y1, incy, \ + cntx \ + ); \ +\ + } \ +} + +void bli_dgemv_unf_var1 + ( + trans_t transa, + conj_t conjx, + dim_t m, + dim_t n, + double* alpha, + double* a, inc_t rs_a, inc_t cs_a, + double* x, inc_t incx, + double* beta, + double* y, inc_t incy, + cntx_t* cntx + ) +{ + + double *A1; + double *y1; + dim_t i; + dim_t f; + dim_t n_elem, n_iter; + inc_t rs_at, cs_at; + conj_t conja; + //memory pool declarations for packing vector X. + mem_t mem_bufX; + rntm_t rntm; + double *x_buf = x; + inc_t buf_incx = incx; + + bli_init_once(); + + if (cntx == NULL) + cntx = bli_gks_query_cntx(); + + bli_set_dims_incs_with_trans(transa, + m, n, rs_a, cs_a, + &n_iter, &n_elem, &rs_at, &cs_at); + + conja = bli_extract_conj(transa); + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == FALSE) + { + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); + const num_t dt = PASTEMAC(d,type); + double* x1; + double* y1; + PASTECH(d,dotxf_ker_ft) kfp_df; + /* Query the context for the kernel function pointer and fusing factor. 
*/ + kfp_df = bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXF_KER, cntx ); + dim_t b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_DF, cntx ); + + for ( i = 0; i < n_iter; i += f ) + { + f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); + + A1 = a + (i )*rs_at + (0 )*cs_at; + x1 = x + (0 )*incy; + y1 = y + (i )*incy; + + /* y1 = beta * y1 + alpha * A1 * x; */ + kfp_df + ( + conja, + conjx, + n_elem, + f, + alpha, + A1, cs_at, rs_at, + x1, incx, + beta, + y1, incy, + cntx + ); + + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); + return; + } + if (incx > 1) + { + /* + Initialize mem pool buffer to NULL and size to 0 + "buf" and "size" fields are assigned once memory + is allocated from the pool in bli_membrk_acquire_m(). + This will ensure bli_mem_is_alloc() will be passed on + an allocated memory if created or a NULL . + */ + + mem_bufX.pblk.buf = NULL; + mem_bufX.pblk.block_size = 0; + mem_bufX.buf_type = 0; + mem_bufX.size = 0; + mem_bufX.pool = NULL; + + /* In order to get the buffer from pool via rntm access to memory broker + is needed.Following are initializations for rntm */ + + bli_rntm_init_from_global(&rntm); + bli_rntm_set_num_threads_only(1, &rntm); + bli_membrk_rntm_set_membrk(&rntm); + + //calculate the size required for n_elem double elements in vector X. + size_t buffer_size = n_elem * sizeof(double); + +#ifdef BLIS_ENABLE_MEM_TRACING + printf("bli_dgemv_unf_var1(): get mem pool block\n"); +#endif + + /*acquire a Buffer(n_elem*size(double)) from the memory broker + and save the associated mem_t entry to mem_bufX.*/ + bli_membrk_acquire_m(&rntm, + buffer_size, + BLIS_BUFFER_FOR_B_PANEL, + &mem_bufX); + + /*Continue packing X if buffer memory is allocated*/ + if ((bli_mem_is_alloc(&mem_bufX))) + { + x_buf = bli_mem_buffer(&mem_bufX); + + //pack X vector with non-unit stride to a temp buffer x_buf with unit stride + for (dim_t x_index = 0; x_index < n_elem; x_index++) + { + *(x_buf + x_index) = *(x + (x_index * incx)); + } + // stride of vector x_buf =1 + buf_incx = 1; + } + } + + dim_t fuse_factor = 8; + dim_t f_temp =0; + + if (n < 4) + { + fuse_factor = 2; + } else if (n < 8) + { + fuse_factor = 4; + } + + for (i = 0; i < n_iter; i += f) + { + f = bli_determine_blocksize_dim_f(i, n_iter, fuse_factor); + + //A = a + i * row_increment + 0 * column_increment + A1 = a + (i)*rs_at; + y1 = y + (i)*incy; + + /* y1 = beta * y1 + alpha * A1 * x; */ + switch (f) + { + case 8: + + bli_ddotxf_zen_int_8( + conja, + conjx, + n_elem, + f, + alpha, + A1, cs_at, rs_at, + x_buf, buf_incx, + beta, + y1, incy, + cntx); + + break; + default: + + if (f < 4) + { + bli_ddotxf_zen_int_2( + conja, + conjx, + n_elem, + f, + alpha, + A1, cs_at, rs_at, + x_buf, buf_incx, + beta, + y1, incy, + cntx); + } + else + { + bli_ddotxf_zen_int_4( + conja, + conjx, + n_elem, + f, + alpha, + A1, cs_at, rs_at, + x_buf, buf_incx, + beta, + y1, incy, + cntx); + } + } + + f_temp = bli_determine_blocksize_dim_f(i + f, n_iter, fuse_factor); + + if (f_temp < fuse_factor) + { + switch (fuse_factor) + { + case 8: + fuse_factor = 4; + break; + case 4: + fuse_factor = 2; + break; + } + } + } + + if ((incx > 1) && bli_mem_is_alloc(&mem_bufX)) + { +#ifdef BLIS_ENABLE_MEM_TRACING + printf("bli_dgemv_unf_var1(): releasing mem pool block\n"); +#endif + // Return the buffer to pool + bli_membrk_release(&rntm, &mem_bufX); + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); +} + +void bli_sgemv_unf_var1 + ( + trans_t transa, + conj_t conjx, + dim_t m, + dim_t n, + float* alpha, + float* a, inc_t rs_a, inc_t cs_a, + float* x, inc_t 
incx, + float* beta, + float* y, inc_t incy, + cntx_t* cntx + ) +{ + + float* A1; + float* x1; + float* y1; + dim_t i; + dim_t b_fuse, f; + dim_t n_elem, n_iter; + inc_t rs_at, cs_at; + conj_t conja; + + bli_init_once(); + + if( cntx == NULL ) cntx = bli_gks_query_cntx(); + + bli_set_dims_incs_with_trans( transa, + m, n, rs_a, cs_a, + &n_iter, &n_elem, &rs_at, &cs_at ); + + conja = bli_extract_conj( transa ); + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == FALSE) + { + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); + const num_t dt = PASTEMAC(s,type); + float* x1 ; + PASTECH(s,dotxf_ker_ft) kfp_df; + /* Query the context for the kernel function pointer and fusing factor. */ + kfp_df = bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXF_KER, cntx ); + b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_DF, cntx ); + + for ( i = 0; i < n_iter; i += f ) + { + f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); + + A1 = a + (i )*rs_at + (0 )*cs_at; + x1 = x + (0 )*incy; + y1 = y + (i )*incy; + + /* y1 = beta * y1 + alpha * A1 * x; */ + kfp_df + ( + conja, + conjx, + n_elem, + f, + alpha, + A1, cs_at, rs_at, + x1, incx, + beta, + y1, incy, + cntx + ); + + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); + return; + } + + /* Query the context for the kernel function pointer and fusing factor. */ + b_fuse = 8; + + for ( i = 0; i < n_iter; i += f ) + { + f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); + + A1 = a + (i )*rs_at + (0 )*cs_at; + x1 = x + (0 )*incy; + y1 = y + (i )*incy; + + /* y1 = beta * y1 + alpha * A1 * x; */ + bli_sdotxf_zen_int_8 + ( + conja, + conjx, + n_elem, + f, + alpha, + A1, cs_at, rs_at, + x1, incx, + beta, + y1, incy, + cntx + ); + + } +} + +INSERT_GENTFUNC_BASIC0_CZ( gemv_unf_var1 ) + diff --git a/frame/2/gemv/bli_gemv_unf_var2.c b/frame/2/gemv/bli_gemv_unf_var2.c index 093b615a7d..227e43ad01 100644 --- a/frame/2/gemv/bli_gemv_unf_var2.c +++ b/frame/2/gemv/bli_gemv_unf_var2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020-21, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020-22, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -137,764 +137,4 @@ void PASTEMAC(ch,varname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); \ } -#ifdef BLIS_CONFIG_EPYC - -void bli_dgemv_unf_var2 - ( - trans_t transa, - conj_t conjx, - dim_t m, - dim_t n, - double* alpha, - double* a, inc_t rs_a, inc_t cs_a, - double* x, inc_t incx, - double* beta, - double* y, inc_t incy, - cntx_t* cntx - ) -{ - - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_3); - double* A1; - double* x1; - dim_t i; - dim_t f; - dim_t n_elem, n_iter; - inc_t rs_at, cs_at; - conj_t conja; - //memory pool declarations for packing vector Y. - mem_t mem_bufY; - rntm_t rntm; - double *y_buf = y; - inc_t buf_incy = incy; - - bli_set_dims_incs_with_trans( transa, - m, n, rs_a, cs_a, - &n_elem, &n_iter, &rs_at, &cs_at ); - - conja = bli_extract_conj( transa ); - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). 
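    /* The arch_t whitelist below (BLIS_ARCH_ZEN through BLIS_ARCH_ZEN4) is
       the dispatch gate this patch removes; the *_amd.c variants added
       further down in this diff gate the optimized path on a runtime
       bli_cpuid_is_avx_supported() check instead of an explicit
       architecture-ID list. */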
- arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN4) || - (id == BLIS_ARCH_ZEN3) || - (id == BLIS_ARCH_ZEN2) || - (id == BLIS_ARCH_ZEN); - - if (bamdzen == 0) - { - if ( cntx == NULL ) cntx = bli_gks_query_cntx(); - const num_t dt = PASTEMAC(d,type); - double* x1; - double* y1; - /* If beta is zero, use setv. Otherwise, scale by beta. */ - if ( PASTEMAC(d,eq0)( *beta ) ) - { - double* zero = PASTEMAC(d,0); - /* y = 0; */ - PASTEMAC2(d,setv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - n_elem, - zero, - y, incy, - cntx, - NULL - ); - } - else - { - /* y = beta * y; */ - PASTEMAC2(d,scalv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - n_elem, - beta, - y, incy, - cntx, - NULL - ); - } - - PASTECH(d,axpyf_ker_ft) kfp_af; - - /* Query the context for the kernel function pointer and fusing factor. */ - kfp_af = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPYF_KER, cntx ); - dim_t b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_AF, cntx ); - - for ( i = 0; i < n_iter; i += f ) - { - f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); - - A1 = a + (0 )*rs_at + (i )*cs_at; - x1 = x + (i )*incx; - y1 = y + (0 )*incy; - - /* y = y + alpha * A1 * x1; */ - kfp_af - ( - conja, - conjx, - n_elem, - f, - alpha, - A1, rs_at, cs_at, - x1, incx, - y1, incy, - cntx - ); - } - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); - return; - } - - /* If beta is zero, use setv. Otherwise, scale by beta. */ - /* y = beta * y; */ - /* beta=0 case is hadled by scalv internally */ - - bli_dscalv_zen_int10 - ( - BLIS_NO_CONJUGATE, - n_elem, - beta, - y, incy, - NULL - ); - - if( bli_deq0( *alpha ) ) - { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3) - return; - } - - if (incy > 1) - { - /* - Initialize mem pool buffer to NULL and size to 0 - "buf" and "size" fields are assigned once memory - is allocated from the pool in bli_membrk_acquire_m(). - This will ensure bli_mem_is_alloc() will be passed on - an allocated memory if created or a NULL . - */ - mem_bufY.pblk.buf = NULL; mem_bufY.pblk.block_size = 0; - mem_bufY.buf_type = 0; mem_bufY.size = 0; - mem_bufY.pool = NULL; - - /* In order to get the buffer from pool via rntm access to memory broker - is needed.Following are initializations for rntm */ - - bli_rntm_init_from_global( &rntm ); - bli_rntm_set_num_threads_only( 1, &rntm ); - bli_membrk_rntm_set_membrk( &rntm ); - - //calculate the size required for n_elem double elements in vector Y. 
- size_t buffer_size = n_elem * sizeof(double); - - #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_dgemv_unf_var2(): get mem pool block\n" ); - #endif - - /*acquire a Buffer(n_elem*size(double)) from the memory broker - and save the associated mem_t entry to mem_bufY.*/ - bli_membrk_acquire_m(&rntm, - buffer_size, - BLIS_BUFFER_FOR_B_PANEL, - &mem_bufY); - - /*Continue packing Y if buffer memory is allocated*/ - if ((bli_mem_is_alloc( &mem_bufY ))) - { - y_buf = bli_mem_buffer(&mem_bufY); - - //pack Y vector with non-unit stride to a temp buffer y_buf with unit stride - for(dim_t y_index = 0 ; y_index < n_elem ; y_index++) - { - *(y_buf + y_index) = *(y + (y_index * incy)) ; - } - // stride of vector y_buf =1 - buf_incy = 1; - } - } - - for ( i = 0; i < n_iter; i += f ) - { - f = bli_determine_blocksize_dim_f( i, n_iter, BLIS_DGEMV_VAR2_FUSE ); - - A1 = a + (0 )*rs_at + (i )*cs_at; - x1 = x + (i )*incx; - - /* y = y + alpha * A1 * x1; */ - bli_daxpyf_zen_int_16x4 - ( - conja, - conjx, - n_elem, - f, - alpha, - A1, rs_at, cs_at, - x1, incx, - y_buf, buf_incy, - NULL - ); - } - if ((incy > 1) && bli_mem_is_alloc( &mem_bufY )) - { - //store the result from unit strided y_buf to non-unit strided Y - for(dim_t y_index = 0 ; y_index < n_elem ; y_index++) - { - *(y + (y_index * incy)) = *(y_buf + y_index) ; - } - - #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_dgemv_unf_var2(): releasing mem pool block\n" ); - #endif - // Return the buffer to pool - bli_membrk_release(&rntm , &mem_bufY); - } - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); -} - -void bli_sgemv_unf_var2 - ( - trans_t transa, - conj_t conjx, - dim_t m, - dim_t n, - float* alpha, - float* a, inc_t rs_a, inc_t cs_a, - float* x, inc_t incx, - float* beta, - float* y, inc_t incy, - cntx_t* cntx - ) -{ - - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_3); - float* A1; - float* x1; - float* y1; - dim_t i; - dim_t b_fuse, f; - dim_t n_elem, n_iter; - inc_t rs_at, cs_at; - conj_t conja; - - bli_set_dims_incs_with_trans( transa, - m, n, rs_a, cs_a, - &n_elem, &n_iter, &rs_at, &cs_at ); - - conja = bli_extract_conj( transa ); - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN4) || - (id == BLIS_ARCH_ZEN3) || - (id == BLIS_ARCH_ZEN2) || - (id == BLIS_ARCH_ZEN); - - if (bamdzen == 0) - { - if ( cntx == NULL ) cntx = bli_gks_query_cntx(); - const num_t dt = PASTEMAC(s,type); - /* If beta is zero, use setv. Otherwise, scale by beta. */ - if ( PASTEMAC(s,eq0)( *beta ) ) - { - float* zero = PASTEMAC(s,0); - /* y = 0; */ - PASTEMAC2(s,setv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - n_elem, - zero, - y, incy, - cntx, - NULL - ); - } - else - { - /* y = beta * y; */ - PASTEMAC2(s,scalv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - n_elem, - beta, - y, incy, - cntx, - NULL - ); - } - - PASTECH(s,axpyf_ker_ft) kfp_af; - - /* Query the context for the kernel function pointer and fusing factor. 
*/ - kfp_af = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPYF_KER, cntx ); - b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_AF, cntx ); - - for ( i = 0; i < n_iter; i += f ) - { - f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); - - A1 = a + (0 )*rs_at + (i )*cs_at; - x1 = x + (i )*incx; - y1 = y + (0 )*incy; - - /* y = y + alpha * A1 * x1; */ - kfp_af - ( - conja, - conjx, - n_elem, - f, - alpha, - A1, rs_at, cs_at, - x1, incx, - y1, incy, - cntx - ); - } - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); - return; - } - - /* If beta is zero, use setv. Otherwise, scale by beta. */ - /* y = beta * y; */ - /* beta=0 case is hadled by scalv internally */ - - bli_sscalv_zen_int10 - ( - BLIS_NO_CONJUGATE, - n_elem, - beta, - y, incy, - NULL - ); - - if( bli_seq0( *alpha ) ) - { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3) - return; - } - - /* Query the context for the kernel function pointer and fusing factor. */ - b_fuse = 6; - - for ( i = 0; i < n_iter; i += f ) - { - f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); - - A1 = a + (0 )*rs_at + (i )*cs_at; - x1 = x + (i )*incx; - y1 = y + (0 )*incy; - - /* y = y + alpha * A1 * x1; */ - bli_saxpyf_zen_int_6 - ( - conja, - conjx, - n_elem, - f, - alpha, - A1, rs_at, cs_at, - x1, incx, - y1, incy, - NULL - ); - } - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); -} - - -void bli_zgemv_unf_var2 - ( - trans_t transa, - conj_t conjx, - dim_t m, - dim_t n, - dcomplex* alpha, - dcomplex* a, inc_t rs_a, inc_t cs_a, - dcomplex* x, inc_t incx, - dcomplex* beta, - dcomplex* y, inc_t incy, - cntx_t* cntx - ) -{ - - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_3); - dcomplex* A1; - dcomplex* x1; - dcomplex* y1; - dim_t i; - dim_t b_fuse, f; - dim_t n_elem, n_iter; - inc_t rs_at, cs_at; - conj_t conja; - - bli_set_dims_incs_with_trans( transa, - m, n, rs_a, cs_a, - &n_elem, &n_iter, &rs_at, &cs_at ); - - conja = bli_extract_conj( transa ); - - /* If beta is zero, use setv. Otherwise, scale by beta. */ - /* y = beta * y; */ - - /* beta=0 case is hadled by scalv internally */ - /* bli_zscalv_zen_int10 - ( - BLIS_NO_CONJUGATE, - n_elem, - beta, - y, - incy, - cntx - );*/ - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN4) || - (id == BLIS_ARCH_ZEN3) || - (id == BLIS_ARCH_ZEN2) || - (id == BLIS_ARCH_ZEN); - - if (bamdzen == 0) - { - if ( cntx == NULL ) cntx = bli_gks_query_cntx(); - const num_t dt = PASTEMAC(z,type); - /* If beta is zero, use setv. Otherwise, scale by beta. */ - if ( PASTEMAC(z,eq0)( *beta ) ) - { - dcomplex* zero = PASTEMAC(z,0); - /* y = 0; */ - PASTEMAC2(z,setv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - n_elem, - zero, - y, incy, - cntx, - NULL - ); - } - else - { - /* y = beta * y; */ - PASTEMAC2(z,scalv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - n_elem, - beta, - y, incy, - cntx, - NULL - ); - } - - PASTECH(z,axpyf_ker_ft) kfp_af; - - /* Query the context for the kernel function pointer and fusing factor. 
*/ - kfp_af = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPYF_KER, cntx ); - b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_AF, cntx ); - - for ( i = 0; i < n_iter; i += f ) - { - f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); - - A1 = a + (0 )*rs_at + (i )*cs_at; - x1 = x + (i )*incx; - y1 = y + (0 )*incy; - - /* y = y + alpha * A1 * x1; */ - kfp_af - ( - conja, - conjx, - n_elem, - f, - alpha, - A1, rs_at, cs_at, - x1, incx, - y1, incy, - cntx - ); - } - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); - return; - } - - bli_zscalv_ex - ( - BLIS_NO_CONJUGATE, - n_elem, - beta, - y, incy, - cntx, - NULL - ); - - if( bli_zeq0( *alpha ) ) - { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); - return; - } - - // for non-unit incx, incy and rs_at and conjugate will be added in the next patch - if( (incx == 1 && incy == 1 && rs_at == 1 ) && - !bli_is_conj(conja) && !bli_is_conj(conjx) && !bli_is_trans(transa)) - { - // This gemv code deals with the followint conditions only - // 1. incx, incy, and row stride equal to one - // 2. Non conjugate A matrix and X vector - // 3. No Transpose for A Martix - // Rest is taken care by the else part (axpyf implementation) - bli_zgemv_zen_int_4x4 - ( - conja, - conjx, - m, - n, - alpha, - a, rs_at, cs_at, - x, incx, - beta, - y, incy, - NULL - ); - } - else - { - /* fusing factor */ - b_fuse = 4; - - for ( i = 0; i < n_iter; i += f ) - { - f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); - A1 = a + (0 )*rs_at + (i )*cs_at; - x1 = x + (i )*incx; - y1 = y + (0 )*incy; - - /* y = y + alpha * A1 * x1; */ - bli_zaxpyf_zen_int_4 - ( - conja, - conjx, - n_elem, - f, - alpha, - A1, rs_at, cs_at, - x1, incx, - y1, incy, - NULL - ); - } - } - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); -} - -void bli_cgemv_unf_var2 - ( - trans_t transa, - conj_t conjx, - dim_t m, - dim_t n, - scomplex* alpha, - scomplex* a, inc_t rs_a, inc_t cs_a, - scomplex* x, inc_t incx, - scomplex* beta, - scomplex* y, inc_t incy, - cntx_t* cntx - ) -{ - - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_3); - scomplex* A1; - scomplex* x1; - scomplex* y1; - dim_t i; - dim_t b_fuse, f; - dim_t n_elem, n_iter; - inc_t rs_at, cs_at; - conj_t conja; - - bli_set_dims_incs_with_trans( transa, - m, n, rs_a, cs_a, - &n_elem, &n_iter, &rs_at, &cs_at ); - - conja = bli_extract_conj( transa ); - - /* If beta is zero, use setv. Otherwise, scale by beta. */ - /* y = beta * y; */ - /* beta=0 case is hadled by scalv internally */ - /*bli_cscalv_zen_int10 - ( - BLIS_NO_CONJUGATE, - n_elem, - beta, - y, - incy, - cntx - );*/ - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN4) || - (id == BLIS_ARCH_ZEN3) || - (id == BLIS_ARCH_ZEN2) || - (id == BLIS_ARCH_ZEN); - - if (bamdzen == 0) - { - if ( cntx == NULL ) cntx = bli_gks_query_cntx(); - const num_t dt = PASTEMAC(c,type); - /* If beta is zero, use setv. Otherwise, scale by beta. 
*/ - if ( PASTEMAC(c,eq0)( *beta ) ) - { - scomplex* zero = PASTEMAC(c,0); - /* y = 0; */ - PASTEMAC2(c,setv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - n_elem, - zero, - y, incy, - cntx, - NULL - ); - } - else - { - /* y = beta * y; */ - PASTEMAC2(c,scalv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - n_elem, - beta, - y, incy, - cntx, - NULL - ); - } - - PASTECH(c,axpyf_ker_ft) kfp_af; - - /* Query the context for the kernel function pointer and fusing factor. */ - kfp_af = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPYF_KER, cntx ); - b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_AF, cntx ); - - for ( i = 0; i < n_iter; i += f ) - { - f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); - - A1 = a + (0 )*rs_at + (i )*cs_at; - x1 = x + (i )*incx; - y1 = y + (0 )*incy; - - /* y = y + alpha * A1 * x1; */ - kfp_af - ( - conja, - conjx, - n_elem, - f, - alpha, - A1, rs_at, cs_at, - x1, incx, - y1, incy, - cntx - ); - } - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); - return; - } - - bli_cscalv_ex - ( - BLIS_NO_CONJUGATE, - n_elem, - beta, - y, incy, - cntx, - NULL - ); - - - - if( bli_ceq0( *alpha ) ) - { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3) - return; - } - - // for non-unit incx, incy and rs_at and conjugate will be added in the next patch - if( ( (incx == 1) && (incy == 1) && (rs_at == 1) ) && - !bli_is_conj(conja) && !bli_is_conj(conjx) && - !bli_is_trans(transa)) - { - // This gemv code deals with the followint conditions only - // 1. incx, incy, and row stride equal to one - // 2. Non conjugate A matrix and X vector - // 3. No Transpose for A Martix - // Rest is taken care by the else part (axpyf implementation) - bli_cgemv_zen_int_4x4 - ( - conja, - conjx, - m, - n, - alpha, - a, rs_at, cs_at, - x, incx, - beta, - y, incy, - NULL - ); - } - else - { - /* fusing factor. */ - b_fuse = 4; - - for ( i = 0; i < n_iter; i += f ) - { - f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); - A1 = a + (0 )*rs_at + (i )*cs_at; - x1 = x + (i )*incx; - y1 = y + (0 )*incy; - - /* y = y + alpha * A1 * x1; */ - bli_caxpyf_zen_int_4 - ( - conja, - conjx, - n_elem, - f, - alpha, - A1, rs_at, cs_at, - x1, incx, - y1, incy, - NULL - ); - } - } - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); -} - - -#else -INSERT_GENTFUNC_BASIC0( gemv_unf_var2 ) -#endif +INSERT_GENTFUNC_BASIC0( gemv_unf_var2 ) \ No newline at end of file diff --git a/frame/2/gemv/bli_gemv_unf_var2_amd.c b/frame/2/gemv/bli_gemv_unf_var2_amd.c new file mode 100644 index 0000000000..d7f5145e31 --- /dev/null +++ b/frame/2/gemv/bli_gemv_unf_var2_amd.c @@ -0,0 +1,879 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020-22, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. 
+ + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#define BLIS_DGEMV_VAR2_FUSE 4 + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTEMAC(ch,varname) \ + ( \ + trans_t transa, \ + conj_t conjx, \ + dim_t m, \ + dim_t n, \ + ctype* alpha, \ + ctype* a, inc_t rs_a, inc_t cs_a, \ + ctype* x, inc_t incx, \ + ctype* beta, \ + ctype* y, inc_t incy, \ + cntx_t* cntx \ + ) \ +{ \ +\ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_3); \ +\ + bli_init_once(); \ +\ + if(cntx == NULL) cntx = bli_gks_query_cntx(); \ +\ + const num_t dt = PASTEMAC(ch,type); \ +\ + ctype* zero = PASTEMAC(ch,0); \ + ctype* A1; \ + ctype* x1; \ + ctype* y1; \ + dim_t i; \ + dim_t b_fuse, f; \ + dim_t n_elem, n_iter; \ + inc_t rs_at, cs_at; \ + conj_t conja; \ +\ + bli_set_dims_incs_with_trans( transa, \ + m, n, rs_a, cs_a, \ + &n_elem, &n_iter, &rs_at, &cs_at ); \ +\ + conja = bli_extract_conj( transa ); \ +\ + /* If beta is zero, use setv. Otherwise, scale by beta. */ \ + if ( PASTEMAC(ch,eq0)( *beta ) ) \ + { \ + /* y = 0; */ \ + PASTEMAC2(ch,setv,BLIS_TAPI_EX_SUF) \ + ( \ + BLIS_NO_CONJUGATE, \ + n_elem, \ + zero, \ + y, incy, \ + cntx, \ + NULL \ + ); \ + } \ + else \ + { \ + /* y = beta * y; */ \ + PASTEMAC2(ch,scalv,BLIS_TAPI_EX_SUF) \ + ( \ + BLIS_NO_CONJUGATE, \ + n_elem, \ + beta, \ + y, incy, \ + cntx, \ + NULL \ + ); \ + } \ +\ + PASTECH(ch,axpyf_ker_ft) kfp_af; \ +\ + /* Query the context for the kernel function pointer and fusing factor. */ \ + kfp_af = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPYF_KER, cntx ); \ + b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_AF, cntx ); \ +\ + for ( i = 0; i < n_iter; i += f ) \ + { \ + f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); \ +\ + A1 = a + (0 )*rs_at + (i )*cs_at; \ + x1 = x + (i )*incx; \ + y1 = y + (0 )*incy; \ +\ + /* y = y + alpha * A1 * x1; */ \ + kfp_af \ + ( \ + conja, \ + conjx, \ + n_elem, \ + f, \ + alpha, \ + A1, rs_at, cs_at, \ + x1, incx, \ + y1, incy, \ + cntx \ + ); \ + } \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); \ +} + +void bli_dgemv_unf_var2 + ( + trans_t transa, + conj_t conjx, + dim_t m, + dim_t n, + double* alpha, + double* a, inc_t rs_a, inc_t cs_a, + double* x, inc_t incx, + double* beta, + double* y, inc_t incy, + cntx_t* cntx + ) +{ + + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_3); + double* A1; + double* x1; + dim_t i; + dim_t f; + dim_t n_elem, n_iter; + inc_t rs_at, cs_at; + conj_t conja; + //memory pool declarations for packing vector Y. + mem_t mem_bufY; + rntm_t rntm; + double *y_buf = y; + inc_t buf_incy = incy; + + // For AMD these APIS are invoked skipping intermediate framework layers + // Hence we need to ensure that cntx is set here. 
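+    // The context is needed on both paths that follow: the non-AVX fallback
+    // queries it for the axpyf kernel pointer and fusing factor, and the
+    // zen-specific calls below (bli_dscalv_zen_int10, bli_daxpyf_zen_int_16x4)
+    // also take cntx as an argument.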
+ bli_init_once(); + if(cntx == NULL) cntx = bli_gks_query_cntx(); + + bli_set_dims_incs_with_trans( transa, + m, n, rs_a, cs_a, + &n_elem, &n_iter, &rs_at, &cs_at ); + + conja = bli_extract_conj( transa ); + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == FALSE) + { + const num_t dt = PASTEMAC(d,type); + double* x1; + double* y1; + /* If beta is zero, use setv. Otherwise, scale by beta. */ + if ( PASTEMAC(d,eq0)( *beta ) ) + { + double* zero = PASTEMAC(d,0); + /* y = 0; */ + PASTEMAC2(d,setv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + n_elem, + zero, + y, incy, + cntx, + NULL + ); + } + else + { + /* y = beta * y; */ + PASTEMAC2(d,scalv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + n_elem, + beta, + y, incy, + cntx, + NULL + ); + } + + PASTECH(d,axpyf_ker_ft) kfp_af; + + /* Query the context for the kernel function pointer and fusing factor. */ + kfp_af = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPYF_KER, cntx ); + dim_t b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_AF, cntx ); + + for ( i = 0; i < n_iter; i += f ) + { + f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); + + A1 = a + (0 )*rs_at + (i )*cs_at; + x1 = x + (i )*incx; + y1 = y + (0 )*incy; + + /* y = y + alpha * A1 * x1; */ + kfp_af + ( + conja, + conjx, + n_elem, + f, + alpha, + A1, rs_at, cs_at, + x1, incx, + y1, incy, + cntx + ); + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); + return; + } + + /* If beta is zero, use setv. Otherwise, scale by beta. */ + /* y = beta * y; */ + /* beta=0 case is hadled by scalv internally */ + + bli_dscalv_zen_int10 + ( + BLIS_NO_CONJUGATE, + n_elem, + beta, + y, incy, + cntx + ); + + if( bli_deq0( *alpha ) ) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3) + return; + } + + if (incy > 1) + { + /* + Initialize mem pool buffer to NULL and size to 0 + "buf" and "size" fields are assigned once memory + is allocated from the pool in bli_membrk_acquire_m(). + This will ensure bli_mem_is_alloc() will be passed on + an allocated memory if created or a NULL . + */ + mem_bufY.pblk.buf = NULL; mem_bufY.pblk.block_size = 0; + mem_bufY.buf_type = 0; mem_bufY.size = 0; + mem_bufY.pool = NULL; + + /* In order to get the buffer from pool via rntm access to memory broker + is needed.Following are initializations for rntm */ + + bli_rntm_init_from_global( &rntm ); + bli_rntm_set_num_threads_only( 1, &rntm ); + bli_membrk_rntm_set_membrk( &rntm ); + + //calculate the size required for n_elem double elements in vector Y. 
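+    // (Overall flow for non-unit incy: acquire this buffer from the memory
+    // pool, gather y into it at unit stride, let bli_daxpyf_zen_int_16x4
+    // accumulate into the contiguous copy, then scatter the result back
+    // into y and release the buffer at the end of the function.)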
+ size_t buffer_size = n_elem * sizeof(double); + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_dgemv_unf_var2(): get mem pool block\n" ); + #endif + + /*acquire a Buffer(n_elem*size(double)) from the memory broker + and save the associated mem_t entry to mem_bufY.*/ + bli_membrk_acquire_m(&rntm, + buffer_size, + BLIS_BUFFER_FOR_B_PANEL, + &mem_bufY); + + /*Continue packing Y if buffer memory is allocated*/ + if ((bli_mem_is_alloc( &mem_bufY ))) + { + y_buf = bli_mem_buffer(&mem_bufY); + + //pack Y vector with non-unit stride to a temp buffer y_buf with unit stride + for(dim_t y_index = 0 ; y_index < n_elem ; y_index++) + { + *(y_buf + y_index) = *(y + (y_index * incy)) ; + } + // stride of vector y_buf =1 + buf_incy = 1; + } + } + + for ( i = 0; i < n_iter; i += f ) + { + f = bli_determine_blocksize_dim_f( i, n_iter, BLIS_DGEMV_VAR2_FUSE ); + + A1 = a + (0 )*rs_at + (i )*cs_at; + x1 = x + (i )*incx; + + /* y = y + alpha * A1 * x1; */ + bli_daxpyf_zen_int_16x4 + ( + conja, + conjx, + n_elem, + f, + alpha, + A1, rs_at, cs_at, + x1, incx, + y_buf, buf_incy, + cntx + ); + } + if ((incy > 1) && bli_mem_is_alloc( &mem_bufY )) + { + //store the result from unit strided y_buf to non-unit strided Y + for(dim_t y_index = 0 ; y_index < n_elem ; y_index++) + { + *(y + (y_index * incy)) = *(y_buf + y_index) ; + } + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_dgemv_unf_var2(): releasing mem pool block\n" ); + #endif + // Return the buffer to pool + bli_membrk_release(&rntm , &mem_bufY); + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); +} + +void bli_sgemv_unf_var2 + ( + trans_t transa, + conj_t conjx, + dim_t m, + dim_t n, + float* alpha, + float* a, inc_t rs_a, inc_t cs_a, + float* x, inc_t incx, + float* beta, + float* y, inc_t incy, + cntx_t* cntx + ) +{ + + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_3); + float* A1; + float* x1; + float* y1; + dim_t i; + dim_t b_fuse, f; + dim_t n_elem, n_iter; + inc_t rs_at, cs_at; + conj_t conja; + + // For AMD these APIS are invoked skipping intermediate framework layers + // Hence we need to ensure that cntx is set here. + bli_init_once(); + if(cntx == NULL) cntx = bli_gks_query_cntx(); + + bli_set_dims_incs_with_trans( transa, + m, n, rs_a, cs_a, + &n_elem, &n_iter, &rs_at, &cs_at ); + + conja = bli_extract_conj( transa ); + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == FALSE) + { + const num_t dt = PASTEMAC(s,type); + /* If beta is zero, use setv. Otherwise, scale by beta. */ + if ( PASTEMAC(s,eq0)( *beta ) ) + { + float* zero = PASTEMAC(s,0); + /* y = 0; */ + PASTEMAC2(s,setv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + n_elem, + zero, + y, incy, + cntx, + NULL + ); + } + else + { + /* y = beta * y; */ + PASTEMAC2(s,scalv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + n_elem, + beta, + y, incy, + cntx, + NULL + ); + } + + PASTECH(s,axpyf_ker_ft) kfp_af; + + /* Query the context for the kernel function pointer and fusing factor. 
*/ + kfp_af = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPYF_KER, cntx ); + b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_AF, cntx ); + + for ( i = 0; i < n_iter; i += f ) + { + f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); + + A1 = a + (0 )*rs_at + (i )*cs_at; + x1 = x + (i )*incx; + y1 = y + (0 )*incy; + + /* y = y + alpha * A1 * x1; */ + kfp_af + ( + conja, + conjx, + n_elem, + f, + alpha, + A1, rs_at, cs_at, + x1, incx, + y1, incy, + cntx + ); + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); + return; + } + + /* If beta is zero, use setv. Otherwise, scale by beta. */ + /* y = beta * y; */ + /* beta=0 case is hadled by scalv internally */ + bli_sscalv_zen_int10 + ( + BLIS_NO_CONJUGATE, + n_elem, + beta, + y, incy, + cntx + ); + + if( bli_seq0( *alpha ) ) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3) + return; + } + + /* Query the context for the kernel function pointer and fusing factor. */ + b_fuse = 6; + + for ( i = 0; i < n_iter; i += f ) + { + f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); + + A1 = a + (0 )*rs_at + (i )*cs_at; + x1 = x + (i )*incx; + y1 = y + (0 )*incy; + + /* y = y + alpha * A1 * x1; */ + bli_saxpyf_zen_int_6 + ( + conja, + conjx, + n_elem, + f, + alpha, + A1, rs_at, cs_at, + x1, incx, + y1, incy, + cntx + ); + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); +} + + +void bli_zgemv_unf_var2 + ( + trans_t transa, + conj_t conjx, + dim_t m, + dim_t n, + dcomplex* alpha, + dcomplex* a, inc_t rs_a, inc_t cs_a, + dcomplex* x, inc_t incx, + dcomplex* beta, + dcomplex* y, inc_t incy, + cntx_t* cntx + ) +{ + + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_3); + dcomplex* A1; + dcomplex* x1; + dcomplex* y1; + dim_t i; + dim_t b_fuse, f; + dim_t n_elem, n_iter; + inc_t rs_at, cs_at; + conj_t conja; + + // For AMD these APIS are invoked skipping intermediate framework layers + // Hence we need to ensure that cntx is set here. + bli_init_once(); + if(cntx == NULL) cntx = bli_gks_query_cntx(); + + bli_set_dims_incs_with_trans( transa, + m, n, rs_a, cs_a, + &n_elem, &n_iter, &rs_at, &cs_at ); + + conja = bli_extract_conj( transa ); + + /* If beta is zero, use setv. Otherwise, scale by beta. */ + /* y = beta * y; */ + + /* beta=0 case is hadled by scalv internally */ + /* bli_zscalv_zen_int10 + ( + BLIS_NO_CONJUGATE, + n_elem, + beta, + y, + incy, + cntx + );*/ + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == FALSE) + { + const num_t dt = PASTEMAC(z,type); + /* If beta is zero, use setv. Otherwise, scale by beta. */ + if ( PASTEMAC(z,eq0)( *beta ) ) + { + dcomplex* zero = PASTEMAC(z,0); + /* y = 0; */ + PASTEMAC2(z,setv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + n_elem, + zero, + y, incy, + cntx, + NULL + ); + } + else + { + /* y = beta * y; */ + PASTEMAC2(z,scalv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + n_elem, + beta, + y, incy, + cntx, + NULL + ); + } + + PASTECH(z,axpyf_ker_ft) kfp_af; + + /* Query the context for the kernel function pointer and fusing factor. 
*/ + kfp_af = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPYF_KER, cntx ); + b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_AF, cntx ); + + for ( i = 0; i < n_iter; i += f ) + { + f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); + + A1 = a + (0 )*rs_at + (i )*cs_at; + x1 = x + (i )*incx; + y1 = y + (0 )*incy; + + /* y = y + alpha * A1 * x1; */ + kfp_af + ( + conja, + conjx, + n_elem, + f, + alpha, + A1, rs_at, cs_at, + x1, incx, + y1, incy, + cntx + ); + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); + return; + } + + bli_zscalv_ex + ( + BLIS_NO_CONJUGATE, + n_elem, + beta, + y, incy, + cntx, + NULL + ); + + if( bli_zeq0( *alpha ) ) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); + return; + } + + // for non-unit incx, incy and rs_at and conjugate will be added in the next patch + if( (incx == 1 && incy == 1 && rs_at == 1 ) && + !bli_is_conj(conja) && !bli_is_conj(conjx) && !bli_is_trans(transa)) + { + // This gemv code deals with the followint conditions only + // 1. incx, incy, and row stride equal to one + // 2. Non conjugate A matrix and X vector + // 3. No Transpose for A Martix + // Rest is taken care by the else part (axpyf implementation) + bli_zgemv_zen_int_4x4 + ( + conja, + conjx, + m, + n, + alpha, + a, rs_at, cs_at, + x, incx, + beta, + y, incy, + cntx + ); + } + else + { + /* fusing factor */ + b_fuse = 4; + + for ( i = 0; i < n_iter; i += f ) + { + f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); + A1 = a + (0 )*rs_at + (i )*cs_at; + x1 = x + (i )*incx; + y1 = y + (0 )*incy; + + /* y = y + alpha * A1 * x1; */ + bli_zaxpyf_zen_int_4 + ( + conja, + conjx, + n_elem, + f, + alpha, + A1, rs_at, cs_at, + x1, incx, + y1, incy, + cntx + ); + } + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); +} + +void bli_cgemv_unf_var2 + ( + trans_t transa, + conj_t conjx, + dim_t m, + dim_t n, + scomplex* alpha, + scomplex* a, inc_t rs_a, inc_t cs_a, + scomplex* x, inc_t incx, + scomplex* beta, + scomplex* y, inc_t incy, + cntx_t* cntx + ) +{ + + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_3); + scomplex* A1; + scomplex* x1; + scomplex* y1; + dim_t i; + dim_t b_fuse, f; + dim_t n_elem, n_iter; + inc_t rs_at, cs_at; + conj_t conja; + + // For AMD these APIS are invoked skipping intermediate framework layers + // Hence we need to ensure that cntx is set here. + bli_init_once(); + if(cntx == NULL) cntx = bli_gks_query_cntx(); + + bli_set_dims_incs_with_trans( transa, + m, n, rs_a, cs_a, + &n_elem, &n_iter, &rs_at, &cs_at ); + + conja = bli_extract_conj( transa ); + + /* If beta is zero, use setv. Otherwise, scale by beta. */ + /* y = beta * y; */ + /* beta=0 case is hadled by scalv internally */ + /*bli_cscalv_zen_int10 + ( + BLIS_NO_CONJUGATE, + n_elem, + beta, + y, + incy, + cntx + );*/ + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == FALSE) + { + const num_t dt = PASTEMAC(c,type); + /* If beta is zero, use setv. Otherwise, scale by beta. */ + if ( PASTEMAC(c,eq0)( *beta ) ) + { + scomplex* zero = PASTEMAC(c,0); + /* y = 0; */ + PASTEMAC2(c,setv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + n_elem, + zero, + y, incy, + cntx, + NULL + ); + } + else + { + /* y = beta * y; */ + PASTEMAC2(c,scalv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + n_elem, + beta, + y, incy, + cntx, + NULL + ); + } + + PASTECH(c,axpyf_ker_ft) kfp_af; + + /* Query the context for the kernel function pointer and fusing factor. 
*/ + kfp_af = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPYF_KER, cntx ); + b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_AF, cntx ); + + for ( i = 0; i < n_iter; i += f ) + { + f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); + + A1 = a + (0 )*rs_at + (i )*cs_at; + x1 = x + (i )*incx; + y1 = y + (0 )*incy; + + /* y = y + alpha * A1 * x1; */ + kfp_af + ( + conja, + conjx, + n_elem, + f, + alpha, + A1, rs_at, cs_at, + x1, incx, + y1, incy, + cntx + ); + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); + return; + } + + bli_cscalv_ex + ( + BLIS_NO_CONJUGATE, + n_elem, + beta, + y, incy, + cntx, + NULL + ); + + + + if( bli_ceq0( *alpha ) ) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3) + return; + } + + // for non-unit incx, incy and rs_at and conjugate will be added in the next patch + if( ( (incx == 1) && (incy == 1) && (rs_at == 1) ) && + !bli_is_conj(conja) && !bli_is_conj(conjx) && + !bli_is_trans(transa)) + { + // This gemv code deals with the followint conditions only + // 1. incx, incy, and row stride equal to one + // 2. Non conjugate A matrix and X vector + // 3. No Transpose for A Martix + // Rest is taken care by the else part (axpyf implementation) + bli_cgemv_zen_int_4x4 + ( + conja, + conjx, + m, + n, + alpha, + a, rs_at, cs_at, + x, incx, + beta, + y, incy, + cntx + ); + } + else + { + /* fusing factor. */ + b_fuse = 4; + + for ( i = 0; i < n_iter; i += f ) + { + f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); + A1 = a + (0 )*rs_at + (i )*cs_at; + x1 = x + (i )*incx; + y1 = y + (0 )*incy; + + /* y = y + alpha * A1 * x1; */ + bli_caxpyf_zen_int_4 + ( + conja, + conjx, + n_elem, + f, + alpha, + A1, rs_at, cs_at, + x1, incx, + y1, incy, + cntx + ); + } + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); +} + + + diff --git a/frame/2/hemv/bli_hemv_unf_var1.c b/frame/2/hemv/bli_hemv_unf_var1.c index 6790e5bd08..e3229543c0 100644 --- a/frame/2/hemv/bli_hemv_unf_var1.c +++ b/frame/2/hemv/bli_hemv_unf_var1.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021-22, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -216,207 +216,5 @@ void PASTEMAC(ch,varname) \ } \ } -#ifdef BLIS_CONFIG_EPYC - -void bli_post_hemv_8x8 - ( - double *a, - double *x, - double *y, - double *alpha, - dim_t cs_a, - dim_t rs_a - ); - -void bli_dhemv_unf_var1 - ( - uplo_t uplo, - conj_t conja, - conj_t conjx, - conj_t conjh, - dim_t m, - double* alpha, - double* a, inc_t rs_a, inc_t cs_a, - double* x, inc_t incx, - double* beta, - double* y, inc_t incy, - cntx_t* cntx - ) -{ - const num_t dt = PASTEMAC(d,type); - - double* one = PASTEMAC(d,1); - double* zero = PASTEMAC(d,0); - double* A10; - double* A11; - double* a10t; - double* alpha11; - double* a21; - double* x0; - double* x1; - double* chi11; - double* y0; - double* y1; - double* y01; - double* psi11; - double* y21; - double conjx_chi11; - double alpha_chi11; - double alpha11_temp; - dim_t i, k, j; - dim_t b_fuse, f; - dim_t n_behind; - dim_t f_ahead, f_behind; - inc_t rs_at, cs_at; - conj_t conj0 = 0, conj1 = 0; - - /* The algorithm will be expressed in terms of the lower triangular - * case;the upper triangular case is supported by swapping the row - * and column strides of A and toggling some conj parameters. 
*/ - if ( bli_is_lower( uplo ) ) - { - rs_at = rs_a; - cs_at = cs_a; - } - else /* if ( bli_is_upper( uplo ) ) */ - { - rs_at = cs_a; - cs_at = rs_a; - } - - /* If beta is zero, use setv. Otherwise, scale by beta. */ - if ( PASTEMAC(d,eq0)( *beta ) ) - { - /* y = 0; */ - PASTEMAC2(d,setv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - m, - zero, - y, incy, - cntx, - NULL - ); - } - else - { - /* y = beta * y; */ - PASTEMAC2(d,scalv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - m, - beta, - y, incy, - cntx, - NULL - ); - } - - PASTECH(d,dotxaxpyf_ker_ft) kfp_dotxaxpyf_ker; - - /* Query the context for the kernel function pointer and fusing - * factor. */ - /* Assign kernel function pointer and fusing factor. */ - arch_t id = bli_arch_query_id(); - bool bamdzen = ((id == BLIS_ARCH_ZEN4) ||(id == BLIS_ARCH_ZEN3) - || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN)); - if (bamdzen) - { - kfp_dotxaxpyf_ker = bli_ddotxaxpyf_zen_int_8; - b_fuse = 8; - } - else - { - if ( cntx == NULL ) cntx = bli_gks_query_cntx(); - kfp_dotxaxpyf_ker = - bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXAXPYF_KER, cntx); - b_fuse = - bli_cntx_get_blksz_def_dt( dt, BLIS_XF, cntx ); - } - - for ( i = 0; i < m; i += f ) - { - f = bli_determine_blocksize_dim_f( i, m, b_fuse ); - n_behind = i; - A10 = a + (i )*rs_at + (0 )*cs_at; - A11 = a + (i )*rs_at + (i )*cs_at; - x0 = x + (0 )*incx; - x1 = x + (i )*incx; - y0 = y + (0 )*incy; - y1 = y + (i )*incy; - - /* y1 = y1 + alpha * A10 * x0; (dotxf) */ - /* y0 = y0 + alpha * A10' * x1; (axpyf) */ - kfp_dotxaxpyf_ker - ( - conj0, - conj1, - conjx, - conjx, - n_behind, - f, - alpha, - A10, cs_at, rs_at, - x0, incx, - x1, incx, - one, - y1, incy, - y0, incy, - cntx - ); - - /* y1 = y1 + alpha * A11 * x1; (variant 4) */ - if((f == 8) && (incx == 1) && (incy == 1) && (cs_at == 1)) - { - /*this helper function handles unit stride only*/ - bli_post_hemv_8x8(A11, x1, y1, alpha, rs_at, cs_at); - } - else - { - for ( k = 0; k < f; ++k ) - { - f_behind = k; - f_ahead = f - k - 1; - a10t = A11 + (k )*rs_at + (0 )*cs_at; - alpha11 = A11 + (k )*rs_at + (k )*cs_at; - a21 = A11 + (k+1)*rs_at + (k )*cs_at; - chi11 = x1 + (k )*incx; - y01 = y1 + (0 )*incy; - psi11 = y1 + (k )*incy; - y21 = y1 + (k+1)*incy; - - /* y01 = y01 + alpha * a10t' * chi11; */ - PASTEMAC(d,copycjs)( conjx, *chi11, - conjx_chi11 ); - PASTEMAC(d,scal2s)( *alpha, conjx_chi11, - alpha_chi11 ); - for ( j = 0; j < f_behind; ++j ) - PASTEMAC(d,axpys)( alpha_chi11, - *(a10t + j*cs_at), - *(y01 + j*incy) ); - - PASTEMAC(d,copycjs)( conja, *alpha11, - alpha11_temp ); - - /* psi11 = psi11 + alpha * alpha11 * chi11; */ - PASTEMAC(d,axpys)( alpha_chi11, alpha11_temp, - *psi11 ); - - /* y21 = y21 + alpha * a21 * chi11; */ - for ( j = 0; j < f_ahead; ++j ) - { - PASTEMAC(d,axpys)( alpha_chi11, - *(a21 + j*rs_at), - *(y21 + j*incy) ); - } - } - } - } -} -GENTFUNC(float, s, hemv_unf_var1) -GENTFUNC(scomplex, c, hemv_unf_var1) -GENTFUNC(dcomplex, z, hemv_unf_var1) -#else INSERT_GENTFUNC_BASIC0( hemv_unf_var1 ) -#endif diff --git a/frame/2/hemv/bli_hemv_unf_var1_amd.c b/frame/2/hemv/bli_hemv_unf_var1_amd.c new file mode 100644 index 0000000000..6532323d11 --- /dev/null +++ b/frame/2/hemv/bli_hemv_unf_var1_amd.c @@ -0,0 +1,418 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2021-22, Advanced Micro Devices, Inc. All rights reserved. 
+ + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTEMAC(ch,varname) \ + ( \ + uplo_t uplo, \ + conj_t conja, \ + conj_t conjx, \ + conj_t conjh, \ + dim_t m, \ + ctype* alpha, \ + ctype* a, inc_t rs_a, inc_t cs_a, \ + ctype* x, inc_t incx, \ + ctype* beta, \ + ctype* y, inc_t incy, \ + cntx_t* cntx \ + ) \ +{ \ + const num_t dt = PASTEMAC(ch,type); \ +\ + ctype* one = PASTEMAC(ch,1); \ + ctype* zero = PASTEMAC(ch,0); \ + ctype* A10; \ + ctype* A11; \ + ctype* a10t; \ + ctype* alpha11; \ + ctype* a21; \ + ctype* x0; \ + ctype* x1; \ + ctype* chi11; \ + ctype* y0; \ + ctype* y1; \ + ctype* y01; \ + ctype* psi11; \ + ctype* y21; \ + ctype conjx_chi11; \ + ctype alpha_chi11; \ + ctype alpha11_temp; \ + dim_t i, k, j; \ + dim_t b_fuse, f; \ + dim_t n_behind; \ + dim_t f_ahead, f_behind; \ + inc_t rs_at, cs_at; \ + conj_t conj0, conj1; \ +\ + /* The algorithm will be expressed in terms of the lower triangular case; + the upper triangular case is supported by swapping the row and column + strides of A and toggling some conj parameters. */ \ + if ( bli_is_lower( uplo ) ) \ + { \ + rs_at = rs_a; \ + cs_at = cs_a; \ +\ + conj0 = conja; \ + conj1 = bli_apply_conj( conjh, conja ); \ + } \ + else /* if ( bli_is_upper( uplo ) ) */ \ + { \ + rs_at = cs_a; \ + cs_at = rs_a; \ +\ + conj0 = bli_apply_conj( conjh, conja ); \ + conj1 = conja; \ + } \ +\ + /* If beta is zero, use setv. Otherwise, scale by beta. */ \ + if ( PASTEMAC(ch,eq0)( *beta ) ) \ + { \ + /* y = 0; */ \ + PASTEMAC2(ch,setv,BLIS_TAPI_EX_SUF) \ + ( \ + BLIS_NO_CONJUGATE, \ + m, \ + zero, \ + y, incy, \ + cntx, \ + NULL \ + ); \ + } \ + else \ + { \ + /* y = beta * y; */ \ + PASTEMAC2(ch,scalv,BLIS_TAPI_EX_SUF) \ + ( \ + BLIS_NO_CONJUGATE, \ + m, \ + beta, \ + y, incy, \ + cntx, \ + NULL \ + ); \ + } \ +\ + PASTECH(ch,dotxaxpyf_ker_ft) kfp_xf; \ +\ + /* Query the context for the kernel function pointer and fusing factor. 
*/ \ + kfp_xf = bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXAXPYF_KER, cntx ); \ + b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_XF, cntx ); \ +\ + for ( i = 0; i < m; i += f ) \ + { \ + f = bli_determine_blocksize_dim_f( i, m, b_fuse ); \ + n_behind = i; \ + A10 = a + (i )*rs_at + (0 )*cs_at; \ + A11 = a + (i )*rs_at + (i )*cs_at; \ + x0 = x + (0 )*incx; \ + x1 = x + (i )*incx; \ + y0 = y + (0 )*incy; \ + y1 = y + (i )*incy; \ +\ + /* y1 = y1 + alpha * A10 * x0; (dotxf) */ \ + /* y0 = y0 + alpha * A10' * x1; (axpyf) */ \ + kfp_xf \ + ( \ + conj0, \ + conj1, \ + conjx, \ + conjx, \ + n_behind, \ + f, \ + alpha, \ + A10, cs_at, rs_at, \ + x0, incx, \ + x1, incx, \ + one, \ + y1, incy, \ + y0, incy, \ + cntx \ + ); \ +\ + /* y1 = y1 + alpha * A11 * x1; (variant 4) */ \ + for ( k = 0; k < f; ++k ) \ + { \ + f_behind = k; \ + f_ahead = f - k - 1; \ + a10t = A11 + (k )*rs_at + (0 )*cs_at; \ + alpha11 = A11 + (k )*rs_at + (k )*cs_at; \ + a21 = A11 + (k+1)*rs_at + (k )*cs_at; \ + chi11 = x1 + (k )*incx; \ + y01 = y1 + (0 )*incy; \ + psi11 = y1 + (k )*incy; \ + y21 = y1 + (k+1)*incy; \ +\ + /* y01 = y01 + alpha * a10t' * chi11; */ \ + PASTEMAC(ch,copycjs)( conjx, *chi11, conjx_chi11 ); \ + PASTEMAC(ch,scal2s)( *alpha, conjx_chi11, alpha_chi11 ); \ + if ( bli_is_conj( conj1 ) ) \ + { \ + for ( j = 0; j < f_behind; ++j ) \ + PASTEMAC(ch,axpyjs)( alpha_chi11, *(a10t + j*cs_at), *(y01 + j*incy) ); \ + } \ + else \ + { \ + for ( j = 0; j < f_behind; ++j ) \ + PASTEMAC(ch,axpys)( alpha_chi11, *(a10t + j*cs_at), *(y01 + j*incy) ); \ + } \ +\ + /* For hemv, explicitly set the imaginary component of alpha11 to + zero. */ \ + PASTEMAC(ch,copycjs)( conja, *alpha11, alpha11_temp ); \ + if ( bli_is_conj( conjh ) ) \ + PASTEMAC(ch,seti0s)( alpha11_temp ); \ +\ + /* psi11 = psi11 + alpha * alpha11 * chi11; */ \ + PASTEMAC(ch,axpys)( alpha_chi11, alpha11_temp, *psi11 ); \ +\ + /* y21 = y21 + alpha * a21 * chi11; */ \ + if ( bli_is_conj( conj0 ) ) \ + { \ + for ( j = 0; j < f_ahead; ++j ) \ + PASTEMAC(ch,axpyjs)( alpha_chi11, *(a21 + j*rs_at), *(y21 + j*incy) ); \ + } \ + else \ + { \ + for ( j = 0; j < f_ahead; ++j ) \ + PASTEMAC(ch,axpys)( alpha_chi11, *(a21 + j*rs_at), *(y21 + j*incy) ); \ + } \ + } \ + } \ +} + +void bli_post_hemv_8x8 + ( + double *a, + double *x, + double *y, + double *alpha, + dim_t cs_a, + dim_t rs_a + ); + +void bli_dhemv_unf_var1 + ( + uplo_t uplo, + conj_t conja, + conj_t conjx, + conj_t conjh, + dim_t m, + double* alpha, + double* a, inc_t rs_a, inc_t cs_a, + double* x, inc_t incx, + double* beta, + double* y, inc_t incy, + cntx_t* cntx + ) +{ + const num_t dt = PASTEMAC(d,type); + + double* one = PASTEMAC(d,1); + double* zero = PASTEMAC(d,0); + double* A10; + double* A11; + double* a10t; + double* alpha11; + double* a21; + double* x0; + double* x1; + double* chi11; + double* y0; + double* y1; + double* y01; + double* psi11; + double* y21; + double conjx_chi11; + double alpha_chi11; + double alpha11_temp; + dim_t i, k, j; + dim_t b_fuse, f; + dim_t n_behind; + dim_t f_ahead, f_behind; + inc_t rs_at, cs_at; + conj_t conj0 = 0, conj1 = 0; + + /* The algorithm will be expressed in terms of the lower triangular + * case;the upper triangular case is supported by swapping the row + * and column strides of A and toggling some conj parameters. */ + if ( bli_is_lower( uplo ) ) + { + rs_at = rs_a; + cs_at = cs_a; + } + else /* if ( bli_is_upper( uplo ) ) */ + { + rs_at = cs_a; + cs_at = rs_a; + } + + /* If beta is zero, use setv. Otherwise, scale by beta. 
*/ + if ( PASTEMAC(d,eq0)( *beta ) ) + { + /* y = 0; */ + PASTEMAC2(d,setv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + m, + zero, + y, incy, + cntx, + NULL + ); + } + else + { + /* y = beta * y; */ + PASTEMAC2(d,scalv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + m, + beta, + y, incy, + cntx, + NULL + ); + } + + PASTECH(d,dotxaxpyf_ker_ft) kfp_dotxaxpyf_ker; + + /* Query the context for the kernel function pointer and fusing + * factor. */ + /* Assign kernel function pointer and fusing factor. */ + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) + { + kfp_dotxaxpyf_ker = bli_ddotxaxpyf_zen_int_8; + b_fuse = 8; + } + else + { + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); + kfp_dotxaxpyf_ker = + bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXAXPYF_KER, cntx); + b_fuse = + bli_cntx_get_blksz_def_dt( dt, BLIS_XF, cntx ); + } + + for ( i = 0; i < m; i += f ) + { + f = bli_determine_blocksize_dim_f( i, m, b_fuse ); + n_behind = i; + A10 = a + (i )*rs_at + (0 )*cs_at; + A11 = a + (i )*rs_at + (i )*cs_at; + x0 = x + (0 )*incx; + x1 = x + (i )*incx; + y0 = y + (0 )*incy; + y1 = y + (i )*incy; + + /* y1 = y1 + alpha * A10 * x0; (dotxf) */ + /* y0 = y0 + alpha * A10' * x1; (axpyf) */ + kfp_dotxaxpyf_ker + ( + conj0, + conj1, + conjx, + conjx, + n_behind, + f, + alpha, + A10, cs_at, rs_at, + x0, incx, + x1, incx, + one, + y1, incy, + y0, incy, + cntx + ); + + /* y1 = y1 + alpha * A11 * x1; (variant 4) */ + if((f == 8) && (incx == 1) && (incy == 1) && (cs_at == 1)) + { + /*this helper function handles unit stride only*/ + bli_post_hemv_8x8(A11, x1, y1, alpha, rs_at, cs_at); + } + else + { + for ( k = 0; k < f; ++k ) + { + f_behind = k; + f_ahead = f - k - 1; + a10t = A11 + (k )*rs_at + (0 )*cs_at; + alpha11 = A11 + (k )*rs_at + (k )*cs_at; + a21 = A11 + (k+1)*rs_at + (k )*cs_at; + chi11 = x1 + (k )*incx; + y01 = y1 + (0 )*incy; + psi11 = y1 + (k )*incy; + y21 = y1 + (k+1)*incy; + + /* y01 = y01 + alpha * a10t' * chi11; */ + PASTEMAC(d,copycjs)( conjx, *chi11, + conjx_chi11 ); + PASTEMAC(d,scal2s)( *alpha, conjx_chi11, + alpha_chi11 ); + for ( j = 0; j < f_behind; ++j ) + PASTEMAC(d,axpys)( alpha_chi11, + *(a10t + j*cs_at), + *(y01 + j*incy) ); + + PASTEMAC(d,copycjs)( conja, *alpha11, + alpha11_temp ); + + /* psi11 = psi11 + alpha * alpha11 * chi11; */ + PASTEMAC(d,axpys)( alpha_chi11, alpha11_temp, + *psi11 ); + + /* y21 = y21 + alpha * a21 * chi11; */ + for ( j = 0; j < f_ahead; ++j ) + { + PASTEMAC(d,axpys)( alpha_chi11, + *(a21 + j*rs_at), + *(y21 + j*incy) ); + } + } + } + } +} +GENTFUNC(float, s, hemv_unf_var1) +GENTFUNC(scomplex, c, hemv_unf_var1) +GENTFUNC(dcomplex, z, hemv_unf_var1) + + diff --git a/frame/2/hemv/bli_hemv_unf_var3.c b/frame/2/hemv/bli_hemv_unf_var3.c index abf08dfdaf..b8e26cbcb6 100644 --- a/frame/2/hemv/bli_hemv_unf_var3.c +++ b/frame/2/hemv/bli_hemv_unf_var3.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. 
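Note: the variant-1 loop above relies on the fused dotxaxpyf kernel (bli_ddotxaxpyf_zen_int_8 on AVX-capable parts, the context-derived kernel otherwise) to apply the off-diagonal panel A10 in both directions in one pass, exactly as the call-site comments state: y1 = y1 + alpha * A10 * x0 (dotxf) and y0 = y0 + alpha * A10' * x1 (axpyf). The sketch below is only a scalar model of that fused update for the real double case, written to make the access pattern explicit; the function name and argument layout are illustrative and are not the BLIS kernel interface.

/* Scalar model of the fused panel update used by hemv variant 1
   (real double).  A10 is the f x n_behind block at rows [i, i+f),
   columns [0, i), addressed with row stride rs and column stride cs.
   Illustrative only; not the BLIS dotxaxpyf interface. */
static void ref_hemv_var1_panel_update
     (
       const double *a10,      /* f x n_behind panel                 */
       long          rs,       /* row stride of A                    */
       long          cs,       /* column stride of A                 */
       long          f,        /* rows of the panel (block size)     */
       long          n_behind, /* columns of the panel               */
       double        alpha,
       const double *x0,       /* length n_behind, stride incx       */
       const double *x1,       /* length f, stride incx              */
       long          incx,
       double       *y0,       /* length n_behind, stride incy       */
       double       *y1,       /* length f, stride incy              */
       long          incy
     )
{
    for ( long p = 0; p < f; ++p )
    {
        double rho = 0.0;
        for ( long j = 0; j < n_behind; ++j )
        {
            const double a_pj = a10[ p*rs + j*cs ];
            rho           += a_pj * x0[ j*incx ];             /* dotxf part */
            y0[ j*incy ]  += alpha * a_pj * x1[ p*incx ];     /* axpyf part */
        }
        y1[ p*incy ] += alpha * rho;
    }
}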
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -216,210 +216,6 @@ void PASTEMAC(ch,varname) \ } \ } -#ifdef BLIS_CONFIG_EPYC - -void bli_pre_hemv_8x8 - ( - double *a, - double *x, - double *y, - double *alpha, - dim_t cs_a, - dim_t rs_a - ); - -void bli_dhemv_unf_var3 - ( - uplo_t uplo, - conj_t conja, - conj_t conjx, - conj_t conjh, - dim_t m, - double* alpha, - double* a, inc_t rs_a, inc_t cs_a, - double* x, inc_t incx, - double* beta, - double* y, inc_t incy, - cntx_t* cntx - ) -{ - const num_t dt = PASTEMAC(d,type); - - double* one = PASTEMAC(d,1); - double* zero = PASTEMAC(d,0); - double* A11; - double* A21; - double* a10t; - double* alpha11; - double* a21; - double* x1; - double* x2; - double* chi11; - double* y1; - double* y2; - double* y01; - double* psi11; - double* y21; - double conjx_chi11; - double alpha_chi11; - double alpha11_temp; - dim_t i, k, j; - dim_t b_fuse, f; - dim_t n_ahead; - dim_t f_ahead, f_behind; - inc_t rs_at, cs_at; - conj_t conj0 = 0, conj1 = 0; - - /* The algorithm will be expressed in terms of the lower triangular - * case; the upper triangular case is supported by swapping the row - * and column strides of A and toggling some conj parameters. */ - if ( bli_is_lower( uplo ) ) - { - rs_at = rs_a; - cs_at = cs_a; - } - else /* if ( bli_is_upper( uplo ) ) */ - { - rs_at = cs_a; - cs_at = rs_a; - } - - /* If beta is zero, use setv. Otherwise, scale by beta. */ - if ( PASTEMAC(d,eq0)( *beta ) ) - { - /* y = 0; */ - PASTEMAC2(d,setv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - m, - zero, - y, incy, - cntx, - NULL - ); - } - else - { - /* y = beta * y; */ - PASTEMAC2(d,scalv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - m, - beta, - y, incy, - cntx, - NULL - ); - } - - PASTECH(d,dotxaxpyf_ker_ft) kfp_dotxaxpyf_ker; - - arch_t id = bli_arch_query_id(); - bool bamdzen = ((id == BLIS_ARCH_ZEN4) || (id == BLIS_ARCH_ZEN3) - || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN)); - if (bamdzen) - { - kfp_dotxaxpyf_ker = bli_ddotxaxpyf_zen_int_8; - b_fuse = 8; - } - else - { - if ( cntx == NULL ) cntx = bli_gks_query_cntx(); - kfp_dotxaxpyf_ker = - bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXAXPYF_KER, cntx); - b_fuse = - bli_cntx_get_blksz_def_dt( dt, BLIS_XF, cntx ); - } - - for ( i = 0; i < m; i += f ) - { - f = bli_determine_blocksize_dim_f( i, m, b_fuse ); - n_ahead = m - i - f; - A11 = a + (i )*rs_at + (i )*cs_at; - A21 = a + (i+f)*rs_at + (i )*cs_at; - x1 = x + (i )*incx; - x2 = x + (i+f)*incx; - y1 = y + (i )*incy; - y2 = y + (i+f)*incy; - - /* y1 = y1 + alpha * A11 * x1; (variant 4) */ - if((f == 8) && (incx == 1) && (incy == 1) && (rs_at == 1)) - { - /*this helper function handles unit stride only*/ - bli_pre_hemv_8x8(A11, x1, y1, alpha, cs_at, rs_at); - } - else - { - for ( k = 0; k < f; ++k ) - { - f_behind = k; - f_ahead = f - k - 1; - a10t = A11 + (k )*rs_at + (0 )*cs_at; - alpha11 = A11 + (k )*rs_at + (k )*cs_at; - a21 = A11 + (k+1)*rs_at + (k )*cs_at; - chi11 = x1 + (k )*incx; - y01 = y1 + (0 )*incy; - psi11 = y1 + (k )*incy; - y21 = y1 + (k+1)*incy; - - /* y01 = y01 + alpha * a10t' * chi11; */ - PASTEMAC(d,copycjs)( conjx, - *chi11, conjx_chi11 ); - PASTEMAC(d,scal2s)( *alpha, conjx_chi11, - alpha_chi11 ); - { - for ( j = 0; j < f_behind; ++j ) - { - PASTEMAC(d,axpys) - ( alpha_chi11, - *(a10t + j*cs_at), - *(y01 + j*incy) ); - } - } - - PASTEMAC(d,copycjs)( conja, *alpha11, - alpha11_temp ); - - /* psi11 = psi11 + alpha * alpha11 * chi11; */ - PASTEMAC(d,axpys)( 
alpha_chi11, alpha11_temp, - *psi11 ); - - /* y21 = y21 + alpha * a21 * chi11; */ - for ( j = 0; j < f_ahead; ++j ) - { - PASTEMAC(d,axpys)( alpha_chi11, - *(a21 + j*rs_at), - *(y21 + j*incy) ); - } - } - } - - /* y1 = y1 + alpha * A21' * x2; (dotxf) */ - /* y2 = y2 + alpha * A21 * x1; (axpyf) */ - kfp_dotxaxpyf_ker - ( - conj0, - conj1, - conjx, - conjx, - n_ahead, - f, - alpha, - A21, rs_at, cs_at, - x2, incx, - x1, incx, - one, - y1, incy, - y2, incy, - cntx - ); - } -} - -GENTFUNC(float, s, hemv_unf_var3) -GENTFUNC(scomplex, c, hemv_unf_var3) -GENTFUNC(dcomplex, z, hemv_unf_var3) -#else INSERT_GENTFUNC_BASIC0( hemv_unf_var3 ) -#endif + diff --git a/frame/2/hemv/bli_hemv_unf_var3_amd.c b/frame/2/hemv/bli_hemv_unf_var3_amd.c new file mode 100644 index 0000000000..34d40cf5cc --- /dev/null +++ b/frame/2/hemv/bli_hemv_unf_var3_amd.c @@ -0,0 +1,420 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +*/ + +#include "blis.h" + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTEMAC(ch,varname) \ + ( \ + uplo_t uplo, \ + conj_t conja, \ + conj_t conjx, \ + conj_t conjh, \ + dim_t m, \ + ctype* alpha, \ + ctype* a, inc_t rs_a, inc_t cs_a, \ + ctype* x, inc_t incx, \ + ctype* beta, \ + ctype* y, inc_t incy, \ + cntx_t* cntx \ + ) \ +{ \ + const num_t dt = PASTEMAC(ch,type); \ +\ + ctype* one = PASTEMAC(ch,1); \ + ctype* zero = PASTEMAC(ch,0); \ + ctype* A11; \ + ctype* A21; \ + ctype* a10t; \ + ctype* alpha11; \ + ctype* a21; \ + ctype* x1; \ + ctype* x2; \ + ctype* chi11; \ + ctype* y1; \ + ctype* y2; \ + ctype* y01; \ + ctype* psi11; \ + ctype* y21; \ + ctype conjx_chi11; \ + ctype alpha_chi11; \ + ctype alpha11_temp; \ + dim_t i, k, j; \ + dim_t b_fuse, f; \ + dim_t n_ahead; \ + dim_t f_ahead, f_behind; \ + inc_t rs_at, cs_at; \ + conj_t conj0, conj1; \ +\ + /* The algorithm will be expressed in terms of the lower triangular case; + the upper triangular case is supported by swapping the row and column + strides of A and toggling some conj parameters. */ \ + if ( bli_is_lower( uplo ) ) \ + { \ + rs_at = rs_a; \ + cs_at = cs_a; \ +\ + conj0 = bli_apply_conj( conjh, conja ); \ + conj1 = conja; \ + } \ + else /* if ( bli_is_upper( uplo ) ) */ \ + { \ + rs_at = cs_a; \ + cs_at = rs_a; \ +\ + conj0 = conja; \ + conj1 = bli_apply_conj( conjh, conja ); \ + } \ +\ + /* If beta is zero, use setv. Otherwise, scale by beta. */ \ + if ( PASTEMAC(ch,eq0)( *beta ) ) \ + { \ + /* y = 0; */ \ + PASTEMAC2(ch,setv,BLIS_TAPI_EX_SUF) \ + ( \ + BLIS_NO_CONJUGATE, \ + m, \ + zero, \ + y, incy, \ + cntx, \ + NULL \ + ); \ + } \ + else \ + { \ + /* y = beta * y; */ \ + PASTEMAC2(ch,scalv,BLIS_TAPI_EX_SUF) \ + ( \ + BLIS_NO_CONJUGATE, \ + m, \ + beta, \ + y, incy, \ + cntx, \ + NULL \ + ); \ + } \ +\ + PASTECH(ch,dotxaxpyf_ker_ft) kfp_xf; \ +\ + /* Query the context for the kernel function pointer and fusing factor. */ \ + kfp_xf = bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXAXPYF_KER, cntx ); \ + b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_XF, cntx ); \ +\ + for ( i = 0; i < m; i += f ) \ + { \ + f = bli_determine_blocksize_dim_f( i, m, b_fuse ); \ + n_ahead = m - i - f; \ + A11 = a + (i )*rs_at + (i )*cs_at; \ + A21 = a + (i+f)*rs_at + (i )*cs_at; \ + x1 = x + (i )*incx; \ + x2 = x + (i+f)*incx; \ + y1 = y + (i )*incy; \ + y2 = y + (i+f)*incy; \ +\ + /* y1 = y1 + alpha * A11 * x1; (variant 4) */ \ + for ( k = 0; k < f; ++k ) \ + { \ + f_behind = k; \ + f_ahead = f - k - 1; \ + a10t = A11 + (k )*rs_at + (0 )*cs_at; \ + alpha11 = A11 + (k )*rs_at + (k )*cs_at; \ + a21 = A11 + (k+1)*rs_at + (k )*cs_at; \ + chi11 = x1 + (k )*incx; \ + y01 = y1 + (0 )*incy; \ + psi11 = y1 + (k )*incy; \ + y21 = y1 + (k+1)*incy; \ +\ + /* y01 = y01 + alpha * a10t' * chi11; */ \ + PASTEMAC(ch,copycjs)( conjx, *chi11, conjx_chi11 ); \ + PASTEMAC(ch,scal2s)( *alpha, conjx_chi11, alpha_chi11 ); \ + if ( bli_is_conj( conj0 ) ) \ + { \ + for ( j = 0; j < f_behind; ++j ) \ + PASTEMAC(ch,axpyjs)( alpha_chi11, *(a10t + j*cs_at), *(y01 + j*incy) ); \ + } \ + else \ + { \ + for ( j = 0; j < f_behind; ++j ) \ + PASTEMAC(ch,axpys)( alpha_chi11, *(a10t + j*cs_at), *(y01 + j*incy) ); \ + } \ +\ + /* For hemv, explicitly set the imaginary component of alpha11 to + zero. 
*/ \ + PASTEMAC(ch,copycjs)( conja, *alpha11, alpha11_temp ); \ + if ( bli_is_conj( conjh ) ) \ + PASTEMAC(ch,seti0s)( alpha11_temp ); \ +\ + /* psi11 = psi11 + alpha * alpha11 * chi11; */ \ + PASTEMAC(ch,axpys)( alpha_chi11, alpha11_temp, *psi11 ); \ +\ + /* y21 = y21 + alpha * a21 * chi11; */ \ + if ( bli_is_conj( conj1 ) ) \ + { \ + for ( j = 0; j < f_ahead; ++j ) \ + PASTEMAC(ch,axpyjs)( alpha_chi11, *(a21 + j*rs_at), *(y21 + j*incy) ); \ + } \ + else \ + { \ + for ( j = 0; j < f_ahead; ++j ) \ + PASTEMAC(ch,axpys)( alpha_chi11, *(a21 + j*rs_at), *(y21 + j*incy) ); \ + } \ + } \ +\ + /* y1 = y1 + alpha * A21' * x2; (dotxf) */ \ + /* y2 = y2 + alpha * A21 * x1; (axpyf) */ \ + kfp_xf \ + ( \ + conj0, \ + conj1, \ + conjx, \ + conjx, \ + n_ahead, \ + f, \ + alpha, \ + A21, rs_at, cs_at, \ + x2, incx, \ + x1, incx, \ + one, \ + y1, incy, \ + y2, incy, \ + cntx \ + ); \ + } \ +} + +void bli_pre_hemv_8x8 + ( + double *a, + double *x, + double *y, + double *alpha, + dim_t cs_a, + dim_t rs_a + ); + +void bli_dhemv_unf_var3 + ( + uplo_t uplo, + conj_t conja, + conj_t conjx, + conj_t conjh, + dim_t m, + double* alpha, + double* a, inc_t rs_a, inc_t cs_a, + double* x, inc_t incx, + double* beta, + double* y, inc_t incy, + cntx_t* cntx + ) +{ + const num_t dt = PASTEMAC(d,type); + + double* one = PASTEMAC(d,1); + double* zero = PASTEMAC(d,0); + double* A11; + double* A21; + double* a10t; + double* alpha11; + double* a21; + double* x1; + double* x2; + double* chi11; + double* y1; + double* y2; + double* y01; + double* psi11; + double* y21; + double conjx_chi11; + double alpha_chi11; + double alpha11_temp; + dim_t i, k, j; + dim_t b_fuse, f; + dim_t n_ahead; + dim_t f_ahead, f_behind; + inc_t rs_at, cs_at; + conj_t conj0 = 0, conj1 = 0; + + /* The algorithm will be expressed in terms of the lower triangular + * case; the upper triangular case is supported by swapping the row + * and column strides of A and toggling some conj parameters. */ + if ( bli_is_lower( uplo ) ) + { + rs_at = rs_a; + cs_at = cs_a; + } + else /* if ( bli_is_upper( uplo ) ) */ + { + rs_at = cs_a; + cs_at = rs_a; + } + + /* If beta is zero, use setv. Otherwise, scale by beta. */ + if ( PASTEMAC(d,eq0)( *beta ) ) + { + /* y = 0; */ + PASTEMAC2(d,setv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + m, + zero, + y, incy, + cntx, + NULL + ); + } + else + { + /* y = beta * y; */ + PASTEMAC2(d,scalv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + m, + beta, + y, incy, + cntx, + NULL + ); + } + + PASTECH(d,dotxaxpyf_ker_ft) kfp_dotxaxpyf_ker; + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. 
+ if (bli_cpuid_is_avx_supported() == TRUE) + { + kfp_dotxaxpyf_ker = bli_ddotxaxpyf_zen_int_8; + b_fuse = 8; + } + else + { + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); + kfp_dotxaxpyf_ker = + bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXAXPYF_KER, cntx); + b_fuse = + bli_cntx_get_blksz_def_dt( dt, BLIS_XF, cntx ); + } + + for ( i = 0; i < m; i += f ) + { + f = bli_determine_blocksize_dim_f( i, m, b_fuse ); + n_ahead = m - i - f; + A11 = a + (i )*rs_at + (i )*cs_at; + A21 = a + (i+f)*rs_at + (i )*cs_at; + x1 = x + (i )*incx; + x2 = x + (i+f)*incx; + y1 = y + (i )*incy; + y2 = y + (i+f)*incy; + + /* y1 = y1 + alpha * A11 * x1; (variant 4) */ + if((f == 8) && (incx == 1) && (incy == 1) && (rs_at == 1)) + { + /*this helper function handles unit stride only*/ + bli_pre_hemv_8x8(A11, x1, y1, alpha, cs_at, rs_at); + } + else + { + for ( k = 0; k < f; ++k ) + { + f_behind = k; + f_ahead = f - k - 1; + a10t = A11 + (k )*rs_at + (0 )*cs_at; + alpha11 = A11 + (k )*rs_at + (k )*cs_at; + a21 = A11 + (k+1)*rs_at + (k )*cs_at; + chi11 = x1 + (k )*incx; + y01 = y1 + (0 )*incy; + psi11 = y1 + (k )*incy; + y21 = y1 + (k+1)*incy; + + /* y01 = y01 + alpha * a10t' * chi11; */ + PASTEMAC(d,copycjs)( conjx, + *chi11, conjx_chi11 ); + PASTEMAC(d,scal2s)( *alpha, conjx_chi11, + alpha_chi11 ); + { + for ( j = 0; j < f_behind; ++j ) + { + PASTEMAC(d,axpys) + ( alpha_chi11, + *(a10t + j*cs_at), + *(y01 + j*incy) ); + } + } + + PASTEMAC(d,copycjs)( conja, *alpha11, + alpha11_temp ); + + /* psi11 = psi11 + alpha * alpha11 * chi11; */ + PASTEMAC(d,axpys)( alpha_chi11, alpha11_temp, + *psi11 ); + + /* y21 = y21 + alpha * a21 * chi11; */ + for ( j = 0; j < f_ahead; ++j ) + { + PASTEMAC(d,axpys)( alpha_chi11, + *(a21 + j*rs_at), + *(y21 + j*incy) ); + } + } + } + + /* y1 = y1 + alpha * A21' * x2; (dotxf) */ + /* y2 = y2 + alpha * A21 * x1; (axpyf) */ + kfp_dotxaxpyf_ker + ( + conj0, + conj1, + conjx, + conjx, + n_ahead, + f, + alpha, + A21, rs_at, cs_at, + x2, incx, + x1, incx, + one, + y1, incy, + y2, incy, + cntx + ); + } +} + +GENTFUNC(float, s, hemv_unf_var3) +GENTFUNC(scomplex, c, hemv_unf_var3) +GENTFUNC(dcomplex, z, hemv_unf_var3) + + diff --git a/frame/2/her2/bli_her2_unf_var1.c b/frame/2/her2/bli_her2_unf_var1.c index 299e3d161d..a0aec48f71 100644 --- a/frame/2/her2/bli_her2_unf_var1.c +++ b/frame/2/her2/bli_her2_unf_var1.c @@ -158,217 +158,5 @@ void PASTEMAC(ch,varname) \ } \ } - -#ifdef BLIS_CONFIG_EPYC - -/** - * Following is function declaration - * that computes her2 for transposed case. - * It handles triangular part of matrix and - * remaining computation in optimal way to - * gain performance improvement. 
- * a is triangular matrix, x and y are vectors - */ -void bli_dher2_trans_zen_int_4 - ( - double *a, - double *x, - double *y, - double *alpha, - dim_t m, - dim_t lda - ); - -void bli_dher2_unf_var1 - ( - uplo_t uplo, - conj_t conjx, - conj_t conjy, - conj_t conjh, - dim_t m, - double* alpha, - double* x, inc_t incx, - double* y, inc_t incy, - double* c, inc_t rs_c, inc_t cs_c, - cntx_t* cntx - ) -{ - const num_t dt = PASTEMAC(d,type); - - double* x0; - double* chi1; - double* y0; - double* psi1; - double* c10t; - double* gamma11; - double alpha0; - double alpha1; - double alpha0_chi1; - double alpha1_psi1; - double alpha0_chi1_psi1; - double conjx0_chi1; - double conjy1_psi1; - double conjy0_psi1; - dim_t i; - dim_t n_behind; - inc_t rs_ct, cs_ct; - conj_t conj0, conj1; - - /* The algorithm will be expressed in terms of the lower triangular - * case;the upper triangular case is supported by swapping the row - * and column strides of A and toggling some conj parameters. - */ - if ( bli_is_lower( uplo ) ) - { - rs_ct = rs_c; - cs_ct = cs_c; - - PASTEMAC(d,copys)( *alpha, alpha0 ); - PASTEMAC(d,copycjs)( conjh, *alpha, alpha1 ); - } - else /* if ( bli_is_upper( uplo ) ) */ - { - rs_ct = cs_c; - cs_ct = rs_c; - - /* Toggle conjugation of conjx/conjy, but only if we are being - * invoked as her2; for syr2, conjx/conjy are unchanged. - */ - conjx = bli_apply_conj( conjh, conjx ); - conjy = bli_apply_conj( conjh, conjy ); - - PASTEMAC(d,copycjs)( conjh, *alpha, alpha0 ); - PASTEMAC(d,copys)( *alpha, alpha1 ); - } - - /* Apply conjh (which carries the conjugation component of the - * Hermitian transpose, if applicable) to conjx and/or conjy as - * needed to arrive at the effective conjugation for the vector - * subproblems. - */ - conj0 = bli_apply_conj( conjh, conjy ); - conj1 = bli_apply_conj( conjh, conjx ); - - PASTECH(d,axpy2v_ker_ft) kfp_2v; - - /* Query the context for the kernel function pointer. */ - kfp_2v = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPY2V_KER, cntx ); - - if( (incx == 1) && (incy == 1) && (rs_ct == 1)) - { - for ( i = 0; i < m; ) - { - n_behind = i; - x0 = x + (0 )*incx; - chi1 = x + (i )*incx; - y0 = y + (0 )*incy; - psi1 = y + (i )*incy; - c10t = c + (i )*rs_ct + (0 )*cs_ct; - gamma11 = c + (i )*rs_ct + (i )*cs_ct; - - if((n_behind >= 3)) - { - bli_dher2_trans_zen_int_4(c10t, x0, y0, &alpha0, n_behind + 1, cs_ct); - i+=4; - } - else - { - /* Apply conjx and/or conjy to chi1 and/or psi1. */ - PASTEMAC(d,copycjs)( conjx, *chi1, conjx0_chi1 ); - PASTEMAC(d,copycjs)( conjy, *psi1, conjy1_psi1 ); - PASTEMAC(d,copycjs)( conj0, *psi1, conjy0_psi1 ); - - /* Compute scalars for vector subproblems. */ - PASTEMAC(d,scal2s)( alpha0, conjx0_chi1, alpha0_chi1 ); - PASTEMAC(d,scal2s)( alpha1, conjy1_psi1, alpha1_psi1 ); - - /* Compute alpha * chi1 * conj(psi1) after both chi1 - * and psi1 have already been conjugated, if needed, - * by conjx and conjy. 
- */ - PASTEMAC(d,scal2s)( alpha0_chi1, conjy0_psi1, - alpha0_chi1_psi1 ); - - /* c10t = c10t + alpha * chi1 * y0'; */ - /* c10t = c10t + conj(alpha) * psi1 * x0'; */ - kfp_2v - ( - conj0, - conj1, - n_behind, - &alpha0_chi1, - &alpha1_psi1, - y0, incy, - x0, incx, - c10t, cs_ct, - cntx - ); - - /* gamma11 = gamma11 + alpha * chi1 * conj(psi1) - + conj(alpha) * psi1 * conj(chi1); */ - PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); - PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); - - i+=1; - } - } - } - else - { - for ( i = 0; i < m; ++i ) - { - n_behind = i; - x0 = x + (0 )*incx; - chi1 = x + (i )*incx; - y0 = y + (0 )*incy; - psi1 = y + (i )*incy; - c10t = c + (i )*rs_ct + (0 )*cs_ct; - gamma11 = c + (i )*rs_ct + (i )*cs_ct; - - /* Apply conjx and/or conjy to chi1 and/or psi1. */ - PASTEMAC(d,copycjs)( conjx, *chi1, conjx0_chi1 ); - PASTEMAC(d,copycjs)( conjy, *psi1, conjy1_psi1 ); - PASTEMAC(d,copycjs)( conj0, *psi1, conjy0_psi1 ); - - /* Compute scalars for vector subproblems. */ - PASTEMAC(d,scal2s)( alpha0, conjx0_chi1, alpha0_chi1 ); - PASTEMAC(d,scal2s)( alpha1, conjy1_psi1, alpha1_psi1 ); - - /* Compute alpha * chi1 * conj(psi1) after both chi1 - * and psi1 have already been conjugated, if needed, - * by conjx and conjy. - */ - PASTEMAC(d,scal2s)( alpha0_chi1, conjy0_psi1, - alpha0_chi1_psi1 ); - - /* c10t = c10t + alpha * chi1 * y0'; */ - /* c10t = c10t + conj(alpha) * psi1 * x0'; */ - kfp_2v - ( - conj0, - conj1, - n_behind, - &alpha0_chi1, - &alpha1_psi1, - y0, incy, - x0, incx, - c10t, cs_ct, - cntx - ); - - /* gamma11 = gamma11 + alpha * chi1 * conj(psi1) - + conj(alpha) * psi1 * conj(chi1); */ - PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); - PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); - - } - } -} - -GENTFUNC(float, s, her2_unf_var1) -GENTFUNC(scomplex, c, her2_unf_var1) -GENTFUNC(dcomplex, z,her2_unf_var1) -#else INSERT_GENTFUNC_BASIC0( her2_unf_var1 ) -#endif diff --git a/frame/2/her2/bli_her2_unf_var1_amd.c b/frame/2/her2/bli_her2_unf_var1_amd.c new file mode 100644 index 0000000000..43a74f49cd --- /dev/null +++ b/frame/2/her2/bli_her2_unf_var1_amd.c @@ -0,0 +1,369 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTEMAC(ch,varname) \ + ( \ + uplo_t uplo, \ + conj_t conjx, \ + conj_t conjy, \ + conj_t conjh, \ + dim_t m, \ + ctype* alpha, \ + ctype* x, inc_t incx, \ + ctype* y, inc_t incy, \ + ctype* c, inc_t rs_c, inc_t cs_c, \ + cntx_t* cntx \ + ) \ +{ \ + const num_t dt = PASTEMAC(ch,type); \ +\ + ctype* x0; \ + ctype* chi1; \ + ctype* y0; \ + ctype* psi1; \ + ctype* c10t; \ + ctype* gamma11; \ + ctype alpha0; \ + ctype alpha1; \ + ctype alpha0_chi1; \ + ctype alpha1_psi1; \ + ctype alpha0_chi1_psi1; \ + ctype conjx0_chi1; \ + ctype conjy1_psi1; \ + ctype conjy0_psi1; \ + dim_t i; \ + dim_t n_behind; \ + inc_t rs_ct, cs_ct; \ + conj_t conj0, conj1; \ +\ + /* The algorithm will be expressed in terms of the lower triangular case; + the upper triangular case is supported by swapping the row and column + strides of A and toggling some conj parameters. */ \ + if ( bli_is_lower( uplo ) ) \ + { \ + rs_ct = rs_c; \ + cs_ct = cs_c; \ +\ + PASTEMAC(ch,copys)( *alpha, alpha0 ); \ + PASTEMAC(ch,copycjs)( conjh, *alpha, alpha1 ); \ + } \ + else /* if ( bli_is_upper( uplo ) ) */ \ + { \ + rs_ct = cs_c; \ + cs_ct = rs_c; \ +\ + /* Toggle conjugation of conjx/conjy, but only if we are being invoked + as her2; for syr2, conjx/conjy are unchanged. */ \ + conjx = bli_apply_conj( conjh, conjx ); \ + conjy = bli_apply_conj( conjh, conjy ); \ +\ + PASTEMAC(ch,copycjs)( conjh, *alpha, alpha0 ); \ + PASTEMAC(ch,copys)( *alpha, alpha1 ); \ + } \ +\ + /* Apply conjh (which carries the conjugation component of the Hermitian + transpose, if applicable) to conjx and/or conjy as needed to arrive at + the effective conjugation for the vector subproblems. */ \ + conj0 = bli_apply_conj( conjh, conjy ); \ + conj1 = bli_apply_conj( conjh, conjx ); \ +\ + PASTECH(ch,axpy2v_ker_ft) kfp_2v; \ +\ + /* Query the context for the kernel function pointer. */ \ + kfp_2v = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPY2V_KER, cntx ); \ +\ + for ( i = 0; i < m; ++i ) \ + { \ + n_behind = i; \ + x0 = x + (0 )*incx; \ + chi1 = x + (i )*incx; \ + y0 = y + (0 )*incy; \ + psi1 = y + (i )*incy; \ + c10t = c + (i )*rs_ct + (0 )*cs_ct; \ + gamma11 = c + (i )*rs_ct + (i )*cs_ct; \ +\ + /* Apply conjx and/or conjy to chi1 and/or psi1. */ \ + PASTEMAC(ch,copycjs)( conjx, *chi1, conjx0_chi1 ); \ + PASTEMAC(ch,copycjs)( conjy, *psi1, conjy1_psi1 ); \ + PASTEMAC(ch,copycjs)( conj0, *psi1, conjy0_psi1 ); \ +\ + /* Compute scalars for vector subproblems. */ \ + PASTEMAC(ch,scal2s)( alpha0, conjx0_chi1, alpha0_chi1 ); \ + PASTEMAC(ch,scal2s)( alpha1, conjy1_psi1, alpha1_psi1 ); \ +\ + /* Compute alpha * chi1 * conj(psi1) after both chi1 and psi1 have + already been conjugated, if needed, by conjx and conjy. 
*/ \ + PASTEMAC(ch,scal2s)( alpha0_chi1, conjy0_psi1, alpha0_chi1_psi1 ); \ +\ + /* c10t = c10t + alpha * chi1 * y0'; */ \ + /* c10t = c10t + conj(alpha) * psi1 * x0'; */ \ + kfp_2v \ + ( \ + conj0, \ + conj1, \ + n_behind, \ + &alpha0_chi1, \ + &alpha1_psi1, \ + y0, incy, \ + x0, incx, \ + c10t, cs_ct, \ + cntx \ + ); \ +\ + /* gamma11 = gamma11 + alpha * chi1 * conj(psi1) \ + + conj(alpha) * psi1 * conj(chi1); */ \ + PASTEMAC(ch,adds)( alpha0_chi1_psi1, *gamma11 ); \ + PASTEMAC(ch,adds)( alpha0_chi1_psi1, *gamma11 ); \ +\ + /* For her2, explicitly set the imaginary component of gamma11 to + zero. */ \ + if ( bli_is_conj( conjh ) ) \ + PASTEMAC(ch,seti0s)( *gamma11 ); \ + } \ +} + +/** + * Following is function declaration + * that computes her2 for transposed case. + * It handles triangular part of matrix and + * remaining computation in optimal way to + * gain performance improvement. + * a is triangular matrix, x and y are vectors + */ +void bli_dher2_trans_zen_int_4 + ( + double *a, + double *x, + double *y, + double *alpha, + dim_t m, + dim_t lda + ); + +void bli_dher2_unf_var1 + ( + uplo_t uplo, + conj_t conjx, + conj_t conjy, + conj_t conjh, + dim_t m, + double* alpha, + double* x, inc_t incx, + double* y, inc_t incy, + double* c, inc_t rs_c, inc_t cs_c, + cntx_t* cntx + ) +{ + const num_t dt = PASTEMAC(d,type); + + double* x0; + double* chi1; + double* y0; + double* psi1; + double* c10t; + double* gamma11; + double alpha0; + double alpha1; + double alpha0_chi1; + double alpha1_psi1; + double alpha0_chi1_psi1; + double conjx0_chi1; + double conjy1_psi1; + double conjy0_psi1; + dim_t i; + dim_t n_behind; + inc_t rs_ct, cs_ct; + conj_t conj0, conj1; + + /* The algorithm will be expressed in terms of the lower triangular + * case;the upper triangular case is supported by swapping the row + * and column strides of A and toggling some conj parameters. + */ + if ( bli_is_lower( uplo ) ) + { + rs_ct = rs_c; + cs_ct = cs_c; + + PASTEMAC(d,copys)( *alpha, alpha0 ); + PASTEMAC(d,copycjs)( conjh, *alpha, alpha1 ); + } + else /* if ( bli_is_upper( uplo ) ) */ + { + rs_ct = cs_c; + cs_ct = rs_c; + + /* Toggle conjugation of conjx/conjy, but only if we are being + * invoked as her2; for syr2, conjx/conjy are unchanged. + */ + conjx = bli_apply_conj( conjh, conjx ); + conjy = bli_apply_conj( conjh, conjy ); + + PASTEMAC(d,copycjs)( conjh, *alpha, alpha0 ); + PASTEMAC(d,copys)( *alpha, alpha1 ); + } + + /* Apply conjh (which carries the conjugation component of the + * Hermitian transpose, if applicable) to conjx and/or conjy as + * needed to arrive at the effective conjugation for the vector + * subproblems. + */ + conj0 = bli_apply_conj( conjh, conjy ); + conj1 = bli_apply_conj( conjh, conjx ); + + PASTECH(d,axpy2v_ker_ft) kfp_2v; + + /* Query the context for the kernel function pointer. */ + kfp_2v = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPY2V_KER, cntx ); + + if( (incx == 1) && (incy == 1) && (rs_ct == 1)) + { + for ( i = 0; i < m; ) + { + n_behind = i; + x0 = x + (0 )*incx; + chi1 = x + (i )*incx; + y0 = y + (0 )*incy; + psi1 = y + (i )*incy; + c10t = c + (i )*rs_ct + (0 )*cs_ct; + gamma11 = c + (i )*rs_ct + (i )*cs_ct; + + if((n_behind >= 3)) + { + bli_dher2_trans_zen_int_4(c10t, x0, y0, &alpha0, n_behind + 1, cs_ct); + i+=4; + } + else + { + /* Apply conjx and/or conjy to chi1 and/or psi1. */ + PASTEMAC(d,copycjs)( conjx, *chi1, conjx0_chi1 ); + PASTEMAC(d,copycjs)( conjy, *psi1, conjy1_psi1 ); + PASTEMAC(d,copycjs)( conj0, *psi1, conjy0_psi1 ); + + /* Compute scalars for vector subproblems. 
*/ + PASTEMAC(d,scal2s)( alpha0, conjx0_chi1, alpha0_chi1 ); + PASTEMAC(d,scal2s)( alpha1, conjy1_psi1, alpha1_psi1 ); + + /* Compute alpha * chi1 * conj(psi1) after both chi1 + * and psi1 have already been conjugated, if needed, + * by conjx and conjy. + */ + PASTEMAC(d,scal2s)( alpha0_chi1, conjy0_psi1, + alpha0_chi1_psi1 ); + + /* c10t = c10t + alpha * chi1 * y0'; */ + /* c10t = c10t + conj(alpha) * psi1 * x0'; */ + kfp_2v + ( + conj0, + conj1, + n_behind, + &alpha0_chi1, + &alpha1_psi1, + y0, incy, + x0, incx, + c10t, cs_ct, + cntx + ); + + /* gamma11 = gamma11 + alpha * chi1 * conj(psi1) + + conj(alpha) * psi1 * conj(chi1); */ + PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); + PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); + + i+=1; + } + } + } + else + { + for ( i = 0; i < m; ++i ) + { + n_behind = i; + x0 = x + (0 )*incx; + chi1 = x + (i )*incx; + y0 = y + (0 )*incy; + psi1 = y + (i )*incy; + c10t = c + (i )*rs_ct + (0 )*cs_ct; + gamma11 = c + (i )*rs_ct + (i )*cs_ct; + + /* Apply conjx and/or conjy to chi1 and/or psi1. */ + PASTEMAC(d,copycjs)( conjx, *chi1, conjx0_chi1 ); + PASTEMAC(d,copycjs)( conjy, *psi1, conjy1_psi1 ); + PASTEMAC(d,copycjs)( conj0, *psi1, conjy0_psi1 ); + + /* Compute scalars for vector subproblems. */ + PASTEMAC(d,scal2s)( alpha0, conjx0_chi1, alpha0_chi1 ); + PASTEMAC(d,scal2s)( alpha1, conjy1_psi1, alpha1_psi1 ); + + /* Compute alpha * chi1 * conj(psi1) after both chi1 + * and psi1 have already been conjugated, if needed, + * by conjx and conjy. + */ + PASTEMAC(d,scal2s)( alpha0_chi1, conjy0_psi1, + alpha0_chi1_psi1 ); + + /* c10t = c10t + alpha * chi1 * y0'; */ + /* c10t = c10t + conj(alpha) * psi1 * x0'; */ + kfp_2v + ( + conj0, + conj1, + n_behind, + &alpha0_chi1, + &alpha1_psi1, + y0, incy, + x0, incx, + c10t, cs_ct, + cntx + ); + + /* gamma11 = gamma11 + alpha * chi1 * conj(psi1) + + conj(alpha) * psi1 * conj(chi1); */ + PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); + PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); + + } + } +} + +GENTFUNC(float, s, her2_unf_var1) +GENTFUNC(scomplex, c, her2_unf_var1) +GENTFUNC(dcomplex, z,her2_unf_var1) + + diff --git a/frame/2/her2/bli_her2_unf_var4.c b/frame/2/her2/bli_her2_unf_var4.c index e39c7224c4..3dea31d53e 100644 --- a/frame/2/her2/bli_her2_unf_var4.c +++ b/frame/2/her2/bli_her2_unf_var4.c @@ -166,192 +166,5 @@ void PASTEMAC(ch,varname) \ } \ } -#ifdef BLIS_CONFIG_EPYC -/** - * Following is function declaration - * that computes her2 for transposed case. - * It handles triangular part of matrix and - * remaining computation in optimal way to - * gain performance improvement. - * a is triangular matrix, x and y are vectors - */ -void bli_dher2_zen_int_4 - ( - double *a, - double *x, - double *y, - double *alpha, - dim_t m, - dim_t lda - ); - -void bli_dher2_unf_var4 - ( - uplo_t uplo, - conj_t conjx, - conj_t conjy, - conj_t conjh, - dim_t m, - double* alpha, - double* x, inc_t incx, - double* y, inc_t incy, - double* c, inc_t rs_c, inc_t cs_c, - cntx_t* cntx - ) -{ - - double* chi1; - double* x2; - double* psi1; - double* y2; - double* gamma11; - double* c21; - double alpha0; - double alpha0_psi1; - double alpha1_chi1; - double alpha0_chi1_psi1; - dim_t i; - dim_t n_ahead; - inc_t rs_ct, cs_ct; - - const num_t dt = PASTEMAC(d,type); - - /* The algorithm will be expressed in terms of the lower triangular - * case; the upper triangular case is supported by swapping the row - * and column strides of A and toggling some conj parameters. 
- */ - if ( bli_is_lower( uplo ) ) - { - rs_ct = rs_c; - cs_ct = cs_c; - - PASTEMAC(d,copys)( *alpha, alpha0 ); - } - else /* if ( bli_is_upper( uplo ) ) */ - { - rs_ct = cs_c; - cs_ct = rs_c; - - /* Toggle conjugation of conjx/conjy, but only if we are being - * invoked as her2; for syr2, conjx/conjy are unchanged. - */ - - PASTEMAC(d,copys)( *alpha, alpha0 ); - } - /* Apply conjh (which carries the conjugation component of the - * Hermitian transpose, if applicable) to conjx and/or conjy as - * needed to arrive at the effective conjugation for the vector - * subproblems. - */ - - PASTECH(d,axpy2v_ker_ft) kfp_2v; - - /* Query the context for the kernel function pointer. */ - kfp_2v = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPY2V_KER, cntx ); - - if((incx == 1) && (incy == 1) && (rs_ct == 1)) - { - for ( i = 0; i < m; ) - { - n_ahead = m - i - 1; - chi1 = x + (i ) * incx; - x2 = x + (i+1) * incx; - psi1 = y + (i ) * incy; - y2 = y + (i+1) * incy; - gamma11 = c + (i ) + (i )*cs_ct; - c21 = c + (i+1) + (i )*cs_ct; - - if((n_ahead >= 3)) - { - bli_dher2_zen_int_4(gamma11, chi1, psi1, &alpha0, n_ahead + 1, cs_ct); - i+= 4; - } - else - { - /* Compute scalars for vector subproblems. */ - PASTEMAC(d,scal2s)( alpha0, *psi1, alpha0_psi1 ); - PASTEMAC(d,scal2s)( alpha0, *chi1, alpha1_chi1 ); - - /* Compute alpha * chi1 * conj(psi1) after both chi1 - * and psi1 have - already been conjugated, if needed, by conjx and - conjy. */ - PASTEMAC(d,scal2s)( alpha0_psi1, *chi1, - alpha0_chi1_psi1 ); - - /* c21 = c21 + alpha * x2 * conj(psi1); */ - /* c21 = c21 + conj(alpha) * y2 * conj(chi1); */ - - kfp_2v - ( - conjx, - conjy, - n_ahead, - &alpha0_psi1, - &alpha1_chi1, - x2, incx, - y2, incy, - c21, rs_ct, - cntx - ); - - - PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); - PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); - i+=1; - } - } - } - else - { - for ( i = 0; i < m; ++i) - { - n_ahead = m - i - 1; - chi1 = x + (i ) * incx; - x2 = x + (i+1) * incx; - psi1 = y + (i ) * incy; - y2 = y + (i+1) * incy; - gamma11 = c + (i ) + (i )*cs_ct; - c21 = c + (i+1) + (i )*cs_ct; - - /* Compute scalars for vector subproblems. */ - PASTEMAC(d,scal2s)( alpha0, *psi1, alpha0_psi1 ); - PASTEMAC(d,scal2s)( alpha0, *chi1, alpha1_chi1 ); - - /* Compute alpha * chi1 * conj(psi1) after both chi1 - * and psi1 have - already been conjugated, if needed, by conjx and - conjy. */ - PASTEMAC(d,scal2s)( alpha0_psi1, *chi1, - alpha0_chi1_psi1 ); - - /* c21 = c21 + alpha * x2 * conj(psi1); */ - /* c21 = c21 + conj(alpha) * y2 * conj(chi1); */ - - kfp_2v - ( - conjx, - conjy, - n_ahead, - &alpha0_psi1, - &alpha1_chi1, - x2, incx, - y2, incy, - c21, rs_ct, - cntx - ); - - - PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); - PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); - } - } -} - -GENTFUNC(float, s, her2_unf_var4) -GENTFUNC(scomplex, c, her2_unf_var4) -GENTFUNC(dcomplex, z,her2_unf_var4) -#else INSERT_GENTFUNC_BASIC0( her2_unf_var4 ) -#endif diff --git a/frame/2/her2/bli_her2_unf_var4_amd.c b/frame/2/her2/bli_her2_unf_var4_amd.c new file mode 100644 index 0000000000..4d77397cd2 --- /dev/null +++ b/frame/2/her2/bli_her2_unf_var4_amd.c @@ -0,0 +1,354 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. 
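Note: for real double data the conjugations in the her2 variant-1 step above are no-ops, so the axpy2v call and the two adds on the diagonal reduce to c10t = c10t + alpha*chi1*y0 + alpha*psi1*x0 and gamma11 = gamma11 + 2*alpha*chi1*psi1. The sketch below is a minimal scalar model of one such row step for the lower-stored case; the function name and parameter list are illustrative, not a BLIS interface.

/* Scalar model of one row step of dher2/dsyr2 variant 1 (lower case,
   real double): the axpy2v call updates row i of the strictly lower
   part of C, and the two adds calls update the diagonal element.
   Illustrative only. */
static void ref_dsyr2_var1_row_step
     (
       double       *c10t,     /* row i of C, columns [0, i)        */
       long          cs_ct,    /* column stride of C                */
       const double *x0,       /* x[0 .. i-1], stride incx          */
       long          incx,
       const double *y0,       /* y[0 .. i-1], stride incy          */
       long          incy,
       double        chi1,     /* x[i]                              */
       double        psi1,     /* y[i]                              */
       double        alpha,
       double       *gamma11,  /* C[i][i]                           */
       long          n_behind  /* i                                 */
     )
{
    /* c10t = c10t + alpha*chi1 * y0 + alpha*psi1 * x0;  (axpy2v) */
    for ( long j = 0; j < n_behind; ++j )
        c10t[ j*cs_ct ] += alpha * ( chi1 * y0[ j*incy ]
                                   + psi1 * x0[ j*incx ] );

    /* gamma11 = gamma11 + 2 * alpha * chi1 * psi1; */
    *gamma11 += 2.0 * alpha * chi1 * psi1;
}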
+ + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTEMAC(ch,varname) \ + ( \ + uplo_t uplo, \ + conj_t conjx, \ + conj_t conjy, \ + conj_t conjh, \ + dim_t m, \ + ctype* alpha, \ + ctype* x, inc_t incx, \ + ctype* y, inc_t incy, \ + ctype* c, inc_t rs_c, inc_t cs_c, \ + cntx_t* cntx \ + ) \ +{ \ + const num_t dt = PASTEMAC(ch,type); \ +\ + ctype* chi1; \ + ctype* x2; \ + ctype* psi1; \ + ctype* y2; \ + ctype* gamma11; \ + ctype* c21; \ + ctype alpha0; \ + ctype alpha1; \ + ctype alpha0_psi1; \ + ctype alpha1_chi1; \ + ctype alpha0_chi1_psi1; \ + ctype conjy0_psi1; \ + ctype conjx1_chi1; \ + ctype conjx0_chi1; \ + dim_t i; \ + dim_t n_ahead; \ + inc_t rs_ct, cs_ct; \ + conj_t conj0, conj1; \ + conj_t conjh_conjx; \ + conj_t conjh_conjy; \ +\ + /* Eliminate unused variable warnings. */ \ + ( void )conjh_conjx; \ + ( void )conjh_conjy; \ +\ + /* The algorithm will be expressed in terms of the lower triangular case; + the upper triangular case is supported by swapping the row and column + strides of A and toggling some conj parameters. */ \ + if ( bli_is_lower( uplo ) ) \ + { \ + rs_ct = rs_c; \ + cs_ct = cs_c; \ +\ + PASTEMAC(ch,copys)( *alpha, alpha0 ); \ + PASTEMAC(ch,copycjs)( conjh, *alpha, alpha1 ); \ + } \ + else /* if ( bli_is_upper( uplo ) ) */ \ + { \ + rs_ct = cs_c; \ + cs_ct = rs_c; \ +\ + /* Toggle conjugation of conjx/conjy, but only if we are being invoked + as her2; for syr2, conjx/conjy are unchanged. */ \ + conjx = bli_apply_conj( conjh, conjx ); \ + conjy = bli_apply_conj( conjh, conjy ); \ +\ + PASTEMAC(ch,copycjs)( conjh, *alpha, alpha0 ); \ + PASTEMAC(ch,copys)( *alpha, alpha1 ); \ + } \ +\ + /* Apply conjh (which carries the conjugation component of the Hermitian + transpose, if applicable) to conjx and/or conjy as needed to arrive at + the effective conjugation for the vector subproblems. 
*/ \ + conj0 = conjx; \ + conj1 = conjy; \ + conjh_conjx = bli_apply_conj( conjh, conjx ); \ + conjh_conjy = bli_apply_conj( conjh, conjy ); \ +\ + PASTECH(ch,axpy2v_ker_ft) kfp_2v; \ +\ + /* Query the context for the kernel function pointer. */ \ + kfp_2v = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPY2V_KER, cntx ); \ +\ + for ( i = 0; i < m; ++i ) \ + { \ + n_ahead = m - i - 1; \ + chi1 = x + (i )*incx; \ + x2 = x + (i+1)*incx; \ + psi1 = y + (i )*incy; \ + y2 = y + (i+1)*incy; \ + gamma11 = c + (i )*rs_ct + (i )*cs_ct; \ + c21 = c + (i+1)*rs_ct + (i )*cs_ct; \ +\ + /* Apply conjx and/or conjy to chi1 and/or psi1. */ \ + PASTEMAC(ch,copycjs)( conjh_conjy, *psi1, conjy0_psi1 ); \ + PASTEMAC(ch,copycjs)( conjh_conjx, *chi1, conjx1_chi1 ); \ + PASTEMAC(ch,copycjs)( conj0, *chi1, conjx0_chi1 ); \ +\ + /* Compute scalars for vector subproblems. */ \ + PASTEMAC(ch,scal2s)( alpha0, conjy0_psi1, alpha0_psi1 ); \ + PASTEMAC(ch,scal2s)( alpha1, conjx1_chi1, alpha1_chi1 ); \ +\ + /* Compute alpha * chi1 * conj(psi1) after both chi1 and psi1 have + already been conjugated, if needed, by conjx and conjy. */ \ + PASTEMAC(ch,scal2s)( alpha0_psi1, conjx0_chi1, alpha0_chi1_psi1 ); \ +\ + /* c21 = c21 + alpha * x2 * conj(psi1); */ \ + /* c21 = c21 + conj(alpha) * y2 * conj(chi1); */ \ + kfp_2v \ + ( \ + conj0, \ + conj1, \ + n_ahead, \ + &alpha0_psi1, \ + &alpha1_chi1, \ + x2, incx, \ + y2, incy, \ + c21, rs_ct, \ + cntx \ + ); \ +\ + /* gamma11 = gamma11 + alpha * chi1 * conj(psi1) \ + + conj(alpha) * psi1 * conj(chi1); */ \ + PASTEMAC(ch,adds)( alpha0_chi1_psi1, *gamma11 ); \ + PASTEMAC(ch,adds)( alpha0_chi1_psi1, *gamma11 ); \ +\ + /* For her2, explicitly set the imaginary component of gamma11 to + zero. */ \ + if ( bli_is_conj( conjh ) ) \ + PASTEMAC(ch,seti0s)( *gamma11 ); \ + } \ +} + +/** + * Following is function declaration + * that computes her2 for transposed case. + * It handles triangular part of matrix and + * remaining computation in optimal way to + * gain performance improvement. + * a is triangular matrix, x and y are vectors + */ +void bli_dher2_zen_int_4 + ( + double *a, + double *x, + double *y, + double *alpha, + dim_t m, + dim_t lda + ); + +void bli_dher2_unf_var4 + ( + uplo_t uplo, + conj_t conjx, + conj_t conjy, + conj_t conjh, + dim_t m, + double* alpha, + double* x, inc_t incx, + double* y, inc_t incy, + double* c, inc_t rs_c, inc_t cs_c, + cntx_t* cntx + ) +{ + + double* chi1; + double* x2; + double* psi1; + double* y2; + double* gamma11; + double* c21; + double alpha0; + double alpha0_psi1; + double alpha1_chi1; + double alpha0_chi1_psi1; + dim_t i; + dim_t n_ahead; + inc_t rs_ct, cs_ct; + + const num_t dt = PASTEMAC(d,type); + + /* The algorithm will be expressed in terms of the lower triangular + * case; the upper triangular case is supported by swapping the row + * and column strides of A and toggling some conj parameters. + */ + if ( bli_is_lower( uplo ) ) + { + rs_ct = rs_c; + cs_ct = cs_c; + + PASTEMAC(d,copys)( *alpha, alpha0 ); + } + else /* if ( bli_is_upper( uplo ) ) */ + { + rs_ct = cs_c; + cs_ct = rs_c; + + /* Toggle conjugation of conjx/conjy, but only if we are being + * invoked as her2; for syr2, conjx/conjy are unchanged. + */ + + PASTEMAC(d,copys)( *alpha, alpha0 ); + } + /* Apply conjh (which carries the conjugation component of the + * Hermitian transpose, if applicable) to conjx and/or conjy as + * needed to arrive at the effective conjugation for the vector + * subproblems. 
+ */ + + PASTECH(d,axpy2v_ker_ft) kfp_2v; + + /* Query the context for the kernel function pointer. */ + kfp_2v = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPY2V_KER, cntx ); + + if((incx == 1) && (incy == 1) && (rs_ct == 1)) + { + for ( i = 0; i < m; ) + { + n_ahead = m - i - 1; + chi1 = x + (i ) * incx; + x2 = x + (i+1) * incx; + psi1 = y + (i ) * incy; + y2 = y + (i+1) * incy; + gamma11 = c + (i ) + (i )*cs_ct; + c21 = c + (i+1) + (i )*cs_ct; + + if((n_ahead >= 3)) + { + bli_dher2_zen_int_4(gamma11, chi1, psi1, &alpha0, n_ahead + 1, cs_ct); + i+= 4; + } + else + { + /* Compute scalars for vector subproblems. */ + PASTEMAC(d,scal2s)( alpha0, *psi1, alpha0_psi1 ); + PASTEMAC(d,scal2s)( alpha0, *chi1, alpha1_chi1 ); + + /* Compute alpha * chi1 * conj(psi1) after both chi1 + * and psi1 have + already been conjugated, if needed, by conjx and + conjy. */ + PASTEMAC(d,scal2s)( alpha0_psi1, *chi1, + alpha0_chi1_psi1 ); + + /* c21 = c21 + alpha * x2 * conj(psi1); */ + /* c21 = c21 + conj(alpha) * y2 * conj(chi1); */ + + kfp_2v + ( + conjx, + conjy, + n_ahead, + &alpha0_psi1, + &alpha1_chi1, + x2, incx, + y2, incy, + c21, rs_ct, + cntx + ); + + + PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); + PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); + i+=1; + } + } + } + else + { + for ( i = 0; i < m; ++i) + { + n_ahead = m - i - 1; + chi1 = x + (i ) * incx; + x2 = x + (i+1) * incx; + psi1 = y + (i ) * incy; + y2 = y + (i+1) * incy; + gamma11 = c + (i ) + (i )*cs_ct; + c21 = c + (i+1) + (i )*cs_ct; + + /* Compute scalars for vector subproblems. */ + PASTEMAC(d,scal2s)( alpha0, *psi1, alpha0_psi1 ); + PASTEMAC(d,scal2s)( alpha0, *chi1, alpha1_chi1 ); + + /* Compute alpha * chi1 * conj(psi1) after both chi1 + * and psi1 have + already been conjugated, if needed, by conjx and + conjy. */ + PASTEMAC(d,scal2s)( alpha0_psi1, *chi1, + alpha0_chi1_psi1 ); + + /* c21 = c21 + alpha * x2 * conj(psi1); */ + /* c21 = c21 + conj(alpha) * y2 * conj(chi1); */ + + kfp_2v + ( + conjx, + conjy, + n_ahead, + &alpha0_psi1, + &alpha1_chi1, + x2, incx, + y2, incy, + c21, rs_ct, + cntx + ); + + + PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); + PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); + } + } +} + +GENTFUNC(float, s, her2_unf_var4) +GENTFUNC(scomplex, c, her2_unf_var4) +GENTFUNC(dcomplex, z,her2_unf_var4) + + diff --git a/frame/2/trsv/bli_trsv_unf_var1.c b/frame/2/trsv/bli_trsv_unf_var1.c index f2f9ea6a6d..55e28a4417 100644 --- a/frame/2/trsv/bli_trsv_unf_var1.c +++ b/frame/2/trsv/bli_trsv_unf_var1.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019 - 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2022, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -231,421 +231,4 @@ void PASTEMAC(ch,varname) \ } \ } -#ifdef BLIS_CONFIG_EPYC -void bli_dtrsv_unf_var1 - ( - uplo_t uploa, - trans_t transa, - diag_t diaga, - dim_t m, - double* alpha, - double* a, inc_t rs_a, inc_t cs_a, - double* x, inc_t incx, - cntx_t* cntx - ) -{ - - double* one = PASTEMAC(d,1); - double* minus_one = PASTEMAC(d,m1); - double* A10; - double* A11; - double* A12; - double* a10t; - double* alpha11; - double* a12t; - double* x0; - double* x1; - double* x2; - double* x01; - double* chi11; - double* x21; - double alpha11_conj; - double rho1; - dim_t iter, i, k, j, l; - dim_t b_fuse, f; - dim_t n_behind, f_behind; - inc_t rs_at, cs_at; - uplo_t uploa_trans; - conj_t conja; - - /* x = alpha * x; */ - PASTEMAC2(d,scalv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - m, - alpha, - x, incx, - cntx, - NULL - ); - - if( bli_does_notrans( transa ) ) - { - rs_at = rs_a; - cs_at = cs_a; - uploa_trans = uploa; - } - else /* if ( bli_does_trans( transa ) ) */ - { - rs_at = cs_a; - cs_at = rs_a; - uploa_trans = bli_uplo_toggled( uploa ); - } - - conja = bli_extract_conj( transa ); - - PASTECH(d,dotxf_ker_ft) kfp_df; - - /* Assign kernel function pointer and fusing factor. */ - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN4) || - (id == BLIS_ARCH_ZEN3) || - (id == BLIS_ARCH_ZEN2) || - (id == BLIS_ARCH_ZEN); - - if (bamdzen) { - kfp_df = bli_ddotxf_zen_int_8; - b_fuse = 8; - } - else - { - if ( cntx == NULL ) cntx = bli_gks_query_cntx(); - num_t dt = PASTEMAC(d,type); - kfp_df = bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXF_KER, cntx ); - b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_DF, cntx ); - } - - /* We reduce all of the possible cases down to just lower/upper. 
*/ - if ( bli_is_upper( uploa_trans ) ) - { - for ( iter = 0; iter < m; iter += f ) - { - f = bli_determine_blocksize_dim_b( iter, m, b_fuse ); - i = m - iter - f; - n_behind = iter; - A11 = a + (i )*rs_at + (i )*cs_at; - A12 = a + (i )*rs_at + (i+f)*cs_at; - x1 = x + (i )*incx; - x2 = x + (i+f)*incx; - - /* x1 = x1 - A12 * x2; */ - kfp_df - ( - conja, - BLIS_NO_CONJUGATE, - n_behind, - f, - minus_one, - A12, cs_at, rs_at, - x2, incx, - one, - x1, incx, - cntx - ); - - /* x1 = x1 / triu( A11 ); */ - for ( k = 0; k < f; ++k ) - { - l = f - k - 1; - f_behind = k; - alpha11 = A11 + (l )*rs_at + (l )*cs_at; - a12t = A11 + (l )*rs_at + (l+1)*cs_at; - chi11 = x1 + (l )*incx; - x21 = x1 + (l+1)*incx; - - /* chi11 = chi11 - a12t * x21; */ - PASTEMAC(d,set0s)( rho1 ); - if ( bli_is_conj( conja ) ) - { - for ( j = 0; j < f_behind; ++j ) - PASTEMAC(d,dotjs)( *(a12t + j*cs_at), *(x21 + j*incx), rho1 ); - } - else - { - for ( j = 0; j < f_behind; ++j ) - PASTEMAC(d,dots)( *(a12t + j*cs_at), *(x21 + j*incx), rho1 ); - } - PASTEMAC(d,subs)( rho1, *chi11 ); - - /* chi11 = chi11 / alpha11; */ - if ( bli_is_nonunit_diag( diaga ) ) - { - PASTEMAC(d,copycjs)( conja, *alpha11, alpha11_conj ); - PASTEMAC(d,invscals)( alpha11_conj, *chi11 ); - } - } - } - } - else /* if ( bli_is_lower( uploa_trans ) ) */ - { - for ( iter = 0; iter < m; iter += f ) - { - f = bli_determine_blocksize_dim_f( iter, m, b_fuse ); - i = iter; - n_behind = i; - A11 = a + (i )*rs_at + (i )*cs_at; - A10 = a + (i )*rs_at + (0 )*cs_at; - x1 = x + (i )*incx; - x0 = x + (0 )*incx; - - /* x1 = x1 - A10 * x0; */ - kfp_df - ( - conja, - BLIS_NO_CONJUGATE, - n_behind, - f, - minus_one, - A10, cs_at, rs_at, - x0, incx, - one, - x1, incx, - cntx - ); - - /* x1 = x1 / tril( A11 ); */ - for ( k = 0; k < f; ++k ) - { - l = k; - f_behind = l; - alpha11 = A11 + (l )*rs_at + (l )*cs_at; - a10t = A11 + (l )*rs_at + (0 )*cs_at; - chi11 = x1 + (l )*incx; - x01 = x1 + (0 )*incx; - - /* chi11 = chi11 - a10t * x01; */ - PASTEMAC(d,set0s)( rho1 ); - if ( bli_is_conj( conja ) ) - { - for ( j = 0; j < f_behind; ++j ) - PASTEMAC(d,dotjs)( *(a10t + j*cs_at), *(x01 + j*incx), rho1 ); - } - else - { - for ( j = 0; j < f_behind; ++j ) - PASTEMAC(d,dots)( *(a10t + j*cs_at), *(x01 + j*incx), rho1 ); - } - PASTEMAC(d,subs)( rho1, *chi11 ); - - /* chi11 = chi11 / alpha11; */ - if ( bli_is_nonunit_diag( diaga ) ) - { - PASTEMAC(d,copycjs)( conja, *alpha11, alpha11_conj ); - PASTEMAC(d,invscals)( alpha11_conj, *chi11 ); - } - } - } - } -} - -void bli_strsv_unf_var1 - ( - uplo_t uploa, - trans_t transa, - diag_t diaga, - dim_t m, - float* alpha, - float* a, inc_t rs_a, inc_t cs_a, - float* x, inc_t incx, - cntx_t* cntx - ) -{ - - float* one = PASTEMAC(s,1); - float* minus_one = PASTEMAC(s,m1); - float* A10; - float* A11; - float* A12; - float* a10t; - float* alpha11; - float* a12t; - float* x0; - float* x1; - float* x2; - float* x01; - float* chi11; - float* x21; - float alpha11_conj; - float rho1; - dim_t iter, i, k, j, l; - dim_t b_fuse, f; - dim_t n_behind, f_behind; - inc_t rs_at, cs_at; - uplo_t uploa_trans; - conj_t conja; - - /* x = alpha * x; */ - PASTEMAC2(s,scalv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - m, - alpha, - x, incx, - cntx, - NULL - ); - - if( bli_does_notrans( transa ) ) - { - rs_at = rs_a; - cs_at = cs_a; - uploa_trans = uploa; - } - else /* if ( bli_does_trans( transa ) ) */ - { - rs_at = cs_a; - cs_at = rs_a; - uploa_trans = bli_uplo_toggled( uploa ); - } - - conja = bli_extract_conj( transa ); - - PASTECH(s,dotxf_ker_ft) kfp_df; - - /* Assign 
kernel function pointer and fusing factor. */ - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN4) || - (id == BLIS_ARCH_ZEN3) || - (id == BLIS_ARCH_ZEN2) || - (id == BLIS_ARCH_ZEN); - - if (bamdzen) { - kfp_df = bli_sdotxf_zen_int_8; - b_fuse = 8; - } - else - { - if ( cntx == NULL ) cntx = bli_gks_query_cntx(); - num_t dt = PASTEMAC(s,type); - kfp_df = bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXF_KER, cntx ); - b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_DF, cntx ); - - } - - /* We reduce all of the possible cases down to just lower/upper. */ - if ( bli_is_upper( uploa_trans ) ) - { - for ( iter = 0; iter < m; iter += f ) - { - f = bli_determine_blocksize_dim_b( iter, m, b_fuse ); - i = m - iter - f; - n_behind = iter; - A11 = a + (i )*rs_at + (i )*cs_at; - A12 = a + (i )*rs_at + (i+f)*cs_at; - x1 = x + (i )*incx; - x2 = x + (i+f)*incx; - - /* x1 = x1 - A12 * x2; */ - kfp_df - ( - conja, - BLIS_NO_CONJUGATE, - n_behind, - f, - minus_one, - A12, cs_at, rs_at, - x2, incx, - one, - x1, incx, - cntx - ); - - /* x1 = x1 / triu( A11 ); */ - for ( k = 0; k < f; ++k ) - { - l = f - k - 1; - f_behind = k; - alpha11 = A11 + (l )*rs_at + (l )*cs_at; - a12t = A11 + (l )*rs_at + (l+1)*cs_at; - chi11 = x1 + (l )*incx; - x21 = x1 + (l+1)*incx; - - /* chi11 = chi11 - a12t * x21; */ - PASTEMAC(s,set0s)( rho1 ); - if ( bli_is_conj( conja ) ) - { - for ( j = 0; j < f_behind; ++j ) - PASTEMAC(s,dotjs)( *(a12t + j*cs_at), *(x21 + j*incx), rho1 ); - } - else - { - for ( j = 0; j < f_behind; ++j ) - PASTEMAC(s,dots)( *(a12t + j*cs_at), *(x21 + j*incx), rho1 ); - } - PASTEMAC(s,subs)( rho1, *chi11 ); - - /* chi11 = chi11 / alpha11; */ - if ( bli_is_nonunit_diag( diaga ) ) - { - PASTEMAC(s,copycjs)( conja, *alpha11, alpha11_conj ); - PASTEMAC(s,invscals)( alpha11_conj, *chi11 ); - } - } - } - } - else /* if ( bli_is_lower( uploa_trans ) ) */ - { - for ( iter = 0; iter < m; iter += f ) - { - f = bli_determine_blocksize_dim_f( iter, m, b_fuse ); - i = iter; - n_behind = i; - A11 = a + (i )*rs_at + (i )*cs_at; - A10 = a + (i )*rs_at + (0 )*cs_at; - x1 = x + (i )*incx; - x0 = x + (0 )*incx; - - /* x1 = x1 - A10 * x0; */ - kfp_df - ( - conja, - BLIS_NO_CONJUGATE, - n_behind, - f, - minus_one, - A10, cs_at, rs_at, - x0, incx, - one, - x1, incx, - cntx - ); - - /* x1 = x1 / tril( A11 ); */ - for ( k = 0; k < f; ++k ) - { - l = k; - f_behind = l; - alpha11 = A11 + (l )*rs_at + (l )*cs_at; - a10t = A11 + (l )*rs_at + (0 )*cs_at; - chi11 = x1 + (l )*incx; - x01 = x1 + (0 )*incx; - - /* chi11 = chi11 - a10t * x01; */ - PASTEMAC(s,set0s)( rho1 ); - if ( bli_is_conj( conja ) ) - { - for ( j = 0; j < f_behind; ++j ) - PASTEMAC(s,dotjs)( *(a10t + j*cs_at), *(x01 + j*incx), rho1 ); - } - else - { - for ( j = 0; j < f_behind; ++j ) - PASTEMAC(s,dots)( *(a10t + j*cs_at), *(x01 + j*incx), rho1 ); - } - PASTEMAC(s,subs)( rho1, *chi11 ); - - /* chi11 = chi11 / alpha11; */ - if ( bli_is_nonunit_diag( diaga ) ) - { - PASTEMAC(s,copycjs)( conja, *alpha11, alpha11_conj ); - PASTEMAC(s,invscals)( alpha11_conj, *chi11 ); - } - } - } - } -} - -INSERT_GENTFUNC_BASIC0_CZ( trsv_unf_var1 ) -#else INSERT_GENTFUNC_BASIC0( trsv_unf_var1 ) -#endif diff --git a/frame/2/trsv/bli_trsv_unf_var1_amd.c b/frame/2/trsv/bli_trsv_unf_var1_amd.c new file mode 100644 index 0000000000..4f026f2c6a --- /dev/null +++ b/frame/2/trsv/bli_trsv_unf_var1_amd.c @@ -0,0 +1,638 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. 
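Note: the trsv variant-1 code above advances through x in blocks of b_fuse rows; in the upper-stored case each block step first folds in the already-solved trailing part with a dotxf call (x1 = x1 - A12 * x2) and then solves the small triangular system x1 = x1 / triu(A11) with a scalar loop. The sketch below is a scalar model of one such block step for real double with a non-unit diagonal; the function name and argument layout are illustrative only, not a BLIS interface.

/* Scalar model of one block step of dtrsv variant 1, upper-stored case
   (real double, non-unit diagonal).  A11 is f x f, A12 is f x n_behind,
   both addressed with row stride rs and column stride cs.  The dotxf
   call in the variant performs the first loop; the small triangular
   solve that follows it performs the second.  Illustrative only. */
static void ref_dtrsv_var1_upper_block
     (
       const double *a11,
       const double *a12,
       long          rs,
       long          cs,
       long          f,
       long          n_behind,
       double       *x1,  long incx1,  /* block being solved            */
       const double *x2,  long incx2   /* trailing, already-solved part */
     )
{
    /* x1 = x1 - A12 * x2;  (dotxf) */
    for ( long p = 0; p < f; ++p )
    {
        double rho = 0.0;
        for ( long j = 0; j < n_behind; ++j )
            rho += a12[ p*rs + j*cs ] * x2[ j*incx2 ];
        x1[ p*incx1 ] -= rho;
    }

    /* x1 = x1 / triu( A11 );  solved bottom-up within the block */
    for ( long k = f - 1; k >= 0; --k )
    {
        double rho = 0.0;
        for ( long j = k + 1; j < f; ++j )
            rho += a11[ k*rs + j*cs ] * x1[ j*incx1 ];
        x1[ k*incx1 ] = ( x1[ k*incx1 ] - rho ) / a11[ k*rs + k*cs ];
    }
}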
+ + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019 - 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTEMAC(ch,varname) \ + ( \ + uplo_t uploa, \ + trans_t transa, \ + diag_t diaga, \ + dim_t m, \ + ctype* alpha, \ + ctype* a, inc_t rs_a, inc_t cs_a, \ + ctype* x, inc_t incx, \ + cntx_t* cntx \ + ) \ +{ \ + if(cntx == NULL) cntx = bli_gks_query_cntx(); \ + const num_t dt = PASTEMAC(ch,type); \ +\ + ctype* one = PASTEMAC(ch,1); \ + ctype* minus_one = PASTEMAC(ch,m1); \ + ctype* A10; \ + ctype* A11; \ + ctype* A12; \ + ctype* a10t; \ + ctype* alpha11; \ + ctype* a12t; \ + ctype* x0; \ + ctype* x1; \ + ctype* x2; \ + ctype* x01; \ + ctype* chi11; \ + ctype* x21; \ + ctype alpha11_conj; \ + ctype rho1; \ + dim_t iter, i, k, j, l; \ + dim_t b_fuse, f; \ + dim_t n_behind, f_behind; \ + inc_t rs_at, cs_at; \ + uplo_t uploa_trans; \ + conj_t conja; \ +\ + /* x = alpha * x; */ \ + PASTEMAC2(ch,scalv,BLIS_TAPI_EX_SUF) \ + ( \ + BLIS_NO_CONJUGATE, \ + m, \ + alpha, \ + x, incx, \ + cntx, \ + NULL \ + ); \ +\ + if ( bli_does_notrans( transa ) ) \ + { \ + rs_at = rs_a; \ + cs_at = cs_a; \ + uploa_trans = uploa; \ + } \ + else /* if ( bli_does_trans( transa ) ) */ \ + { \ + rs_at = cs_a; \ + cs_at = rs_a; \ + uploa_trans = bli_uplo_toggled( uploa ); \ + } \ +\ + conja = bli_extract_conj( transa ); \ +\ + PASTECH(ch,dotxf_ker_ft) kfp_df; \ +\ + /* Query the context for the kernel function pointer and fusing factor. */ \ + kfp_df = bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXF_KER, cntx ); \ + b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_DF, cntx ); \ +\ + /* We reduce all of the possible cases down to just lower/upper. 
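+      Variant 1: the fused dotxf kernel first folds in the contribution of the \
+      already-solved part of x, after which the diagonal block is solved one \
+      element at a time. \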
*/ \ + if ( bli_is_upper( uploa_trans ) ) \ + { \ + for ( iter = 0; iter < m; iter += f ) \ + { \ + f = bli_determine_blocksize_dim_b( iter, m, b_fuse ); \ + i = m - iter - f; \ + n_behind = iter; \ + A11 = a + (i )*rs_at + (i )*cs_at; \ + A12 = a + (i )*rs_at + (i+f)*cs_at; \ + x1 = x + (i )*incx; \ + x2 = x + (i+f)*incx; \ +\ + /* x1 = x1 - A12 * x2; */ \ + kfp_df \ + ( \ + conja, \ + BLIS_NO_CONJUGATE, \ + n_behind, \ + f, \ + minus_one, \ + A12, cs_at, rs_at, \ + x2, incx, \ + one, \ + x1, incx, \ + cntx \ + ); \ +\ + /* x1 = x1 / triu( A11 ); */ \ + for ( k = 0; k < f; ++k ) \ + { \ + l = f - k - 1; \ + f_behind = k; \ + alpha11 = A11 + (l )*rs_at + (l )*cs_at; \ + a12t = A11 + (l )*rs_at + (l+1)*cs_at; \ + chi11 = x1 + (l )*incx; \ + x21 = x1 + (l+1)*incx; \ +\ + /* chi11 = chi11 - a12t * x21; */ \ + PASTEMAC(ch,set0s)( rho1 ); \ + if ( bli_is_conj( conja ) ) \ + { \ + for ( j = 0; j < f_behind; ++j ) \ + PASTEMAC(ch,dotjs)( *(a12t + j*cs_at), *(x21 + j*incx), rho1 ); \ + } \ + else \ + { \ + for ( j = 0; j < f_behind; ++j ) \ + PASTEMAC(ch,dots)( *(a12t + j*cs_at), *(x21 + j*incx), rho1 ); \ + } \ + PASTEMAC(ch,subs)( rho1, *chi11 ); \ +\ + /* chi11 = chi11 / alpha11; */ \ + if ( bli_is_nonunit_diag( diaga ) ) \ + { \ + PASTEMAC(ch,copycjs)( conja, *alpha11, alpha11_conj ); \ + PASTEMAC(ch,invscals)( alpha11_conj, *chi11 ); \ + } \ + } \ + } \ + } \ + else /* if ( bli_is_lower( uploa_trans ) ) */ \ + { \ + for ( iter = 0; iter < m; iter += f ) \ + { \ + f = bli_determine_blocksize_dim_f( iter, m, b_fuse ); \ + i = iter; \ + n_behind = i; \ + A11 = a + (i )*rs_at + (i )*cs_at; \ + A10 = a + (i )*rs_at + (0 )*cs_at; \ + x1 = x + (i )*incx; \ + x0 = x + (0 )*incx; \ +\ + /* x1 = x1 - A10 * x0; */ \ + kfp_df \ + ( \ + conja, \ + BLIS_NO_CONJUGATE, \ + n_behind, \ + f, \ + minus_one, \ + A10, cs_at, rs_at, \ + x0, incx, \ + one, \ + x1, incx, \ + cntx \ + ); \ +\ + /* x1 = x1 / tril( A11 ); */ \ + for ( k = 0; k < f; ++k ) \ + { \ + l = k; \ + f_behind = l; \ + alpha11 = A11 + (l )*rs_at + (l )*cs_at; \ + a10t = A11 + (l )*rs_at + (0 )*cs_at; \ + chi11 = x1 + (l )*incx; \ + x01 = x1 + (0 )*incx; \ +\ + /* chi11 = chi11 - a10t * x01; */ \ + PASTEMAC(ch,set0s)( rho1 ); \ + if ( bli_is_conj( conja ) ) \ + { \ + for ( j = 0; j < f_behind; ++j ) \ + PASTEMAC(ch,dotjs)( *(a10t + j*cs_at), *(x01 + j*incx), rho1 ); \ + } \ + else \ + { \ + for ( j = 0; j < f_behind; ++j ) \ + PASTEMAC(ch,dots)( *(a10t + j*cs_at), *(x01 + j*incx), rho1 ); \ + } \ + PASTEMAC(ch,subs)( rho1, *chi11 ); \ +\ + /* chi11 = chi11 / alpha11; */ \ + if ( bli_is_nonunit_diag( diaga ) ) \ + { \ + PASTEMAC(ch,copycjs)( conja, *alpha11, alpha11_conj ); \ + PASTEMAC(ch,invscals)( alpha11_conj, *chi11 ); \ + } \ + } \ + } \ + } \ +} + +void bli_dtrsv_unf_var1 + ( + uplo_t uploa, + trans_t transa, + diag_t diaga, + dim_t m, + double* alpha, + double* a, inc_t rs_a, inc_t cs_a, + double* x, inc_t incx, + cntx_t* cntx + ) +{ + + double* one = PASTEMAC(d,1); + double* minus_one = PASTEMAC(d,m1); + double* A10; + double* A11; + double* A12; + double* a10t; + double* alpha11; + double* a12t; + double* x0; + double* x1; + double* x2; + double* x01; + double* chi11; + double* x21; + double alpha11_conj; + double rho1; + dim_t iter, i, k, j, l; + dim_t b_fuse, f; + dim_t n_behind, f_behind; + inc_t rs_at, cs_at; + uplo_t uploa_trans; + conj_t conja; + + /* x = alpha * x; */ + PASTEMAC2(d,scalv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + m, + alpha, + x, incx, + cntx, + NULL + ); + + if( bli_does_notrans( transa ) ) + { + rs_at = rs_a; 
+ cs_at = cs_a; + uploa_trans = uploa; + } + else /* if ( bli_does_trans( transa ) ) */ + { + rs_at = cs_a; + cs_at = rs_a; + uploa_trans = bli_uplo_toggled( uploa ); + } + + conja = bli_extract_conj( transa ); + + PASTECH(d,dotxf_ker_ft) kfp_df; + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) { + kfp_df = bli_ddotxf_zen_int_8; + b_fuse = 8; + } + else + { + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); + num_t dt = PASTEMAC(d,type); + kfp_df = bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXF_KER, cntx ); + b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_DF, cntx ); + } + + /* We reduce all of the possible cases down to just lower/upper. */ + if ( bli_is_upper( uploa_trans ) ) + { + for ( iter = 0; iter < m; iter += f ) + { + f = bli_determine_blocksize_dim_b( iter, m, b_fuse ); + i = m - iter - f; + n_behind = iter; + A11 = a + (i )*rs_at + (i )*cs_at; + A12 = a + (i )*rs_at + (i+f)*cs_at; + x1 = x + (i )*incx; + x2 = x + (i+f)*incx; + + /* x1 = x1 - A12 * x2; */ + kfp_df + ( + conja, + BLIS_NO_CONJUGATE, + n_behind, + f, + minus_one, + A12, cs_at, rs_at, + x2, incx, + one, + x1, incx, + cntx + ); + + /* x1 = x1 / triu( A11 ); */ + for ( k = 0; k < f; ++k ) + { + l = f - k - 1; + f_behind = k; + alpha11 = A11 + (l )*rs_at + (l )*cs_at; + a12t = A11 + (l )*rs_at + (l+1)*cs_at; + chi11 = x1 + (l )*incx; + x21 = x1 + (l+1)*incx; + + /* chi11 = chi11 - a12t * x21; */ + PASTEMAC(d,set0s)( rho1 ); + if ( bli_is_conj( conja ) ) + { + for ( j = 0; j < f_behind; ++j ) + PASTEMAC(d,dotjs)( *(a12t + j*cs_at), *(x21 + j*incx), rho1 ); + } + else + { + for ( j = 0; j < f_behind; ++j ) + PASTEMAC(d,dots)( *(a12t + j*cs_at), *(x21 + j*incx), rho1 ); + } + PASTEMAC(d,subs)( rho1, *chi11 ); + + /* chi11 = chi11 / alpha11; */ + if ( bli_is_nonunit_diag( diaga ) ) + { + PASTEMAC(d,copycjs)( conja, *alpha11, alpha11_conj ); + PASTEMAC(d,invscals)( alpha11_conj, *chi11 ); + } + } + } + } + else /* if ( bli_is_lower( uploa_trans ) ) */ + { + for ( iter = 0; iter < m; iter += f ) + { + f = bli_determine_blocksize_dim_f( iter, m, b_fuse ); + i = iter; + n_behind = i; + A11 = a + (i )*rs_at + (i )*cs_at; + A10 = a + (i )*rs_at + (0 )*cs_at; + x1 = x + (i )*incx; + x0 = x + (0 )*incx; + + /* x1 = x1 - A10 * x0; */ + kfp_df + ( + conja, + BLIS_NO_CONJUGATE, + n_behind, + f, + minus_one, + A10, cs_at, rs_at, + x0, incx, + one, + x1, incx, + cntx + ); + + /* x1 = x1 / tril( A11 ); */ + for ( k = 0; k < f; ++k ) + { + l = k; + f_behind = l; + alpha11 = A11 + (l )*rs_at + (l )*cs_at; + a10t = A11 + (l )*rs_at + (0 )*cs_at; + chi11 = x1 + (l )*incx; + x01 = x1 + (0 )*incx; + + /* chi11 = chi11 - a10t * x01; */ + PASTEMAC(d,set0s)( rho1 ); + if ( bli_is_conj( conja ) ) + { + for ( j = 0; j < f_behind; ++j ) + PASTEMAC(d,dotjs)( *(a10t + j*cs_at), *(x01 + j*incx), rho1 ); + } + else + { + for ( j = 0; j < f_behind; ++j ) + PASTEMAC(d,dots)( *(a10t + j*cs_at), *(x01 + j*incx), rho1 ); + } + PASTEMAC(d,subs)( rho1, *chi11 ); + + /* chi11 = chi11 / alpha11; */ + if ( bli_is_nonunit_diag( diaga ) ) + { + PASTEMAC(d,copycjs)( conja, *alpha11, alpha11_conj ); + PASTEMAC(d,invscals)( alpha11_conj, *chi11 ); + } + } + } + } +} + +void bli_strsv_unf_var1 + ( + uplo_t uploa, + trans_t transa, + diag_t diaga, + dim_t m, + float* alpha, + float* a, inc_t rs_a, inc_t cs_a, + float* x, inc_t incx, + cntx_t* cntx + ) +{ + + float* one = PASTEMAC(s,1); + float* minus_one = PASTEMAC(s,m1); + float* A10; + 
float* A11; + float* A12; + float* a10t; + float* alpha11; + float* a12t; + float* x0; + float* x1; + float* x2; + float* x01; + float* chi11; + float* x21; + float alpha11_conj; + float rho1; + dim_t iter, i, k, j, l; + dim_t b_fuse, f; + dim_t n_behind, f_behind; + inc_t rs_at, cs_at; + uplo_t uploa_trans; + conj_t conja; + + /* x = alpha * x; */ + PASTEMAC2(s,scalv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + m, + alpha, + x, incx, + cntx, + NULL + ); + + if( bli_does_notrans( transa ) ) + { + rs_at = rs_a; + cs_at = cs_a; + uploa_trans = uploa; + } + else /* if ( bli_does_trans( transa ) ) */ + { + rs_at = cs_a; + cs_at = rs_a; + uploa_trans = bli_uplo_toggled( uploa ); + } + + conja = bli_extract_conj( transa ); + + PASTECH(s,dotxf_ker_ft) kfp_df; + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) { + kfp_df = bli_sdotxf_zen_int_8; + b_fuse = 8; + } + else + { + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); + num_t dt = PASTEMAC(s,type); + kfp_df = bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXF_KER, cntx ); + b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_DF, cntx ); + + } + + /* We reduce all of the possible cases down to just lower/upper. */ + if ( bli_is_upper( uploa_trans ) ) + { + for ( iter = 0; iter < m; iter += f ) + { + f = bli_determine_blocksize_dim_b( iter, m, b_fuse ); + i = m - iter - f; + n_behind = iter; + A11 = a + (i )*rs_at + (i )*cs_at; + A12 = a + (i )*rs_at + (i+f)*cs_at; + x1 = x + (i )*incx; + x2 = x + (i+f)*incx; + + /* x1 = x1 - A12 * x2; */ + kfp_df + ( + conja, + BLIS_NO_CONJUGATE, + n_behind, + f, + minus_one, + A12, cs_at, rs_at, + x2, incx, + one, + x1, incx, + cntx + ); + + /* x1 = x1 / triu( A11 ); */ + for ( k = 0; k < f; ++k ) + { + l = f - k - 1; + f_behind = k; + alpha11 = A11 + (l )*rs_at + (l )*cs_at; + a12t = A11 + (l )*rs_at + (l+1)*cs_at; + chi11 = x1 + (l )*incx; + x21 = x1 + (l+1)*incx; + + /* chi11 = chi11 - a12t * x21; */ + PASTEMAC(s,set0s)( rho1 ); + if ( bli_is_conj( conja ) ) + { + for ( j = 0; j < f_behind; ++j ) + PASTEMAC(s,dotjs)( *(a12t + j*cs_at), *(x21 + j*incx), rho1 ); + } + else + { + for ( j = 0; j < f_behind; ++j ) + PASTEMAC(s,dots)( *(a12t + j*cs_at), *(x21 + j*incx), rho1 ); + } + PASTEMAC(s,subs)( rho1, *chi11 ); + + /* chi11 = chi11 / alpha11; */ + if ( bli_is_nonunit_diag( diaga ) ) + { + PASTEMAC(s,copycjs)( conja, *alpha11, alpha11_conj ); + PASTEMAC(s,invscals)( alpha11_conj, *chi11 ); + } + } + } + } + else /* if ( bli_is_lower( uploa_trans ) ) */ + { + for ( iter = 0; iter < m; iter += f ) + { + f = bli_determine_blocksize_dim_f( iter, m, b_fuse ); + i = iter; + n_behind = i; + A11 = a + (i )*rs_at + (i )*cs_at; + A10 = a + (i )*rs_at + (0 )*cs_at; + x1 = x + (i )*incx; + x0 = x + (0 )*incx; + + /* x1 = x1 - A10 * x0; */ + kfp_df + ( + conja, + BLIS_NO_CONJUGATE, + n_behind, + f, + minus_one, + A10, cs_at, rs_at, + x0, incx, + one, + x1, incx, + cntx + ); + + /* x1 = x1 / tril( A11 ); */ + for ( k = 0; k < f; ++k ) + { + l = k; + f_behind = l; + alpha11 = A11 + (l )*rs_at + (l )*cs_at; + a10t = A11 + (l )*rs_at + (0 )*cs_at; + chi11 = x1 + (l )*incx; + x01 = x1 + (0 )*incx; + + /* chi11 = chi11 - a10t * x01; */ + PASTEMAC(s,set0s)( rho1 ); + if ( bli_is_conj( conja ) ) + { + for ( j = 0; j < f_behind; ++j ) + PASTEMAC(s,dotjs)( *(a10t + j*cs_at), *(x01 + j*incx), rho1 ); + } + else + { + for ( j = 0; j < f_behind; ++j ) + PASTEMAC(s,dots)( *(a10t + j*cs_at), *(x01 + j*incx), rho1 ); + 
} + PASTEMAC(s,subs)( rho1, *chi11 ); + + /* chi11 = chi11 / alpha11; */ + if ( bli_is_nonunit_diag( diaga ) ) + { + PASTEMAC(s,copycjs)( conja, *alpha11, alpha11_conj ); + PASTEMAC(s,invscals)( alpha11_conj, *chi11 ); + } + } + } + } +} + +INSERT_GENTFUNC_BASIC0_CZ( trsv_unf_var1 ) + diff --git a/frame/2/trsv/bli_trsv_unf_var2.c b/frame/2/trsv/bli_trsv_unf_var2.c index 2fd89dacf5..c0ef6abe45 100644 --- a/frame/2/trsv/bli_trsv_unf_var2.c +++ b/frame/2/trsv/bli_trsv_unf_var2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019 - 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -228,805 +228,5 @@ void PASTEMAC(ch,varname) \ } \ } \ } -#ifdef BLIS_CONFIG_EPYC -void bli_dtrsv_unf_var2 - ( - uplo_t uploa, - trans_t transa, - diag_t diaga, - dim_t m, - double* alpha, - double* a, inc_t rs_a, inc_t cs_a, - double* x, inc_t incx, - cntx_t* cntx - ) -{ - double* minus_one = PASTEMAC(d,m1); - double* A01; - double* A11; - double* A21; - double* a01; - double* alpha11; - double* a21; - double* x0; - double* x1; - double* x2; - double* x01; - double* chi11; - double* x21; - double alpha11_conj; - double minus_chi11; - dim_t iter, i, k, j, l; - dim_t b_fuse, f; - dim_t n_ahead, f_ahead; - inc_t rs_at, cs_at; - uplo_t uploa_trans; - conj_t conja; - - /* x = alpha * x; */ - PASTEMAC2(d,scalv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - m, - alpha, - x, incx, - cntx, - NULL - ); - - if ( bli_does_notrans( transa ) ) - { - rs_at = rs_a; - cs_at = cs_a; - uploa_trans = uploa; - } - else /* if ( bli_does_trans( transa ) ) */ - { - rs_at = cs_a; - cs_at = rs_a; - uploa_trans = bli_uplo_toggled( uploa ); - } - - conja = bli_extract_conj( transa ); - - PASTECH(d,axpyf_ker_ft) kfp_af; - - /* Assign kernel function pointer and fusing factor. */ - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN4) || - (id == BLIS_ARCH_ZEN3) || - (id == BLIS_ARCH_ZEN2) || - (id == BLIS_ARCH_ZEN); - - if (bamdzen) { - kfp_af = bli_daxpyf_zen_int_16x4; - b_fuse = 4; - } - else - { - if ( cntx == NULL ) cntx = bli_gks_query_cntx(); - kfp_af = bli_cntx_get_l1f_ker_dt( BLIS_DOUBLE, BLIS_AXPYF_KER, cntx ); - b_fuse = bli_cntx_get_blksz_def_dt( BLIS_DOUBLE, BLIS_AF, cntx ); - } - - /* We reduce all of the possible cases down to just lower/upper. 
*/ - if ( bli_is_upper( uploa_trans ) ) - { - for ( iter = 0; iter < m; iter += f ) - { - f = bli_determine_blocksize_dim_b( iter, m, b_fuse ); - i = m - iter - f; - n_ahead = i; - A11 = a + (i )*rs_at + (i )*cs_at; - A01 = a + (0 )*rs_at + (i )*cs_at; - x1 = x + (i )*incx; - x0 = x + (0 )*incx; - - /* x1 = x1 / triu( A11 ); */ - for ( k = 0; k < f; ++k ) - { - l = f - k - 1; - f_ahead = l; - alpha11 = A11 + (l )*rs_at + (l )*cs_at; - a01 = A11 + (0 )*rs_at + (l )*cs_at; - chi11 = x1 + (l )*incx; - x01 = x1 + (0 )*incx; - - /* chi11 = chi11 / alpha11; */ - if ( bli_is_nonunit_diag( diaga ) ) - { - PASTEMAC(d,copycjs)( conja, *alpha11, alpha11_conj ); - PASTEMAC(d,invscals)( alpha11_conj, *chi11 ); - } - - /* x01 = x01 - chi11 * a01; */ - PASTEMAC(d,neg2s)( *chi11, minus_chi11 ); - if ( bli_is_conj( conja ) ) - { - for ( j = 0; j < f_ahead; ++j ) - PASTEMAC(d,axpyjs)( minus_chi11, *(a01 + j*rs_at), *(x01 + j*incx) ); - } - else - { - for ( j = 0; j < f_ahead; ++j ) - PASTEMAC(d,axpys)( minus_chi11, *(a01 + j*rs_at), *(x01 + j*incx) ); - } - } - - /* x0 = x0 - A01 * x1; */ - kfp_af - ( - conja, - BLIS_NO_CONJUGATE, - n_ahead, - f, - minus_one, - A01, rs_at, cs_at, - x1, incx, - x0, incx, - cntx - ); - } - } - else /* if ( bli_is_lower( uploa_trans ) ) */ - { - for ( iter = 0; iter < m; iter += f ) - { - f = bli_determine_blocksize_dim_f( iter, m, b_fuse ); - i = iter; - n_ahead = m - iter - f; - A11 = a + (i )*rs_at + (i )*cs_at; - A21 = a + (i+f)*rs_at + (i )*cs_at; - x1 = x + (i )*incx; - x2 = x + (i+f)*incx; - - /* x1 = x1 / tril( A11 ); */ - for ( k = 0; k < f; ++k ) - { - l = k; - f_ahead = f - k - 1; - alpha11 = A11 + (l )*rs_at + (l )*cs_at; - a21 = A11 + (l+1)*rs_at + (l )*cs_at; - chi11 = x1 + (l )*incx; - x21 = x1 + (l+1)*incx; - - /* chi11 = chi11 / alpha11; */ - if ( bli_is_nonunit_diag( diaga ) ) - { - PASTEMAC(d,copycjs)( conja, *alpha11, alpha11_conj ); - PASTEMAC(d,invscals)( alpha11_conj, *chi11 ); - } - - /* x21 = x21 - chi11 * a21; */ - PASTEMAC(d,neg2s)( *chi11, minus_chi11 ); - if ( bli_is_conj( conja ) ) - { - for ( j = 0; j < f_ahead; ++j ) - PASTEMAC(d,axpyjs)( minus_chi11, *(a21 + j*rs_at), *(x21 + j*incx) ); - } - else - { - for ( j = 0; j < f_ahead; ++j ) - PASTEMAC(d,axpys)( minus_chi11, *(a21 + j*rs_at), *(x21 + j*incx) ); - } - } - - /* x2 = x2 - A21 * x1; */ - kfp_af - ( - conja, - BLIS_NO_CONJUGATE, - n_ahead, - f, - minus_one, - A21, rs_at, cs_at, - x1, incx, - x2, incx, - cntx - ); - } - } -} - -void bli_strsv_unf_var2 - ( - uplo_t uploa, - trans_t transa, - diag_t diaga, - dim_t m, - float* alpha, - float* a, inc_t rs_a, inc_t cs_a, - float* x, inc_t incx, - cntx_t* cntx - ) -{ - - float* minus_one = PASTEMAC(s, m1); - float* A01; - float* A11; - float* A21; - float* a01; - float* alpha11; - float* a21; - float* x0; - float* x1; - float* x2; - float* x01; - float* chi11; - float* x21; - float alpha11_conj; - float minus_chi11; - dim_t iter, i, k, j, l; - dim_t b_fuse, f; - dim_t n_ahead, f_ahead; - inc_t rs_at, cs_at; - uplo_t uploa_trans; - conj_t conja; - - /* x = alpha * x; */ - PASTEMAC2(s, scalv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - m, - alpha, - x, incx, - cntx, - NULL - ); - - if( bli_does_notrans( transa ) ) - { - rs_at = rs_a; - cs_at = cs_a; - uploa_trans = uploa; - } - else /* if ( bli_does_trans( transa ) ) */ - { - rs_at = cs_a; - cs_at = rs_a; - uploa_trans = bli_uplo_toggled( uploa ); - } - - conja = bli_extract_conj( transa ); - - PASTECH(s, axpyf_ker_ft) kfp_af; - - /* Assign function pointer and fusing factor. 
*/ - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN4) || - (id == BLIS_ARCH_ZEN3) || - (id == BLIS_ARCH_ZEN2) || - (id == BLIS_ARCH_ZEN); - - if (bamdzen) { - kfp_af = bli_saxpyf_zen_int_5; - b_fuse = 5; - } - else - { - if ( cntx == NULL ) cntx = bli_gks_query_cntx(); - kfp_af = bli_cntx_get_l1f_ker_dt( BLIS_FLOAT, BLIS_AXPYF_KER, cntx ); - b_fuse = bli_cntx_get_blksz_def_dt( BLIS_FLOAT, BLIS_AF, cntx ); - } - - /* We reduce all of the possible cases down to just lower/upper. */ - if ( bli_is_upper( uploa_trans ) ) - { - for ( iter = 0; iter < m; iter += f ) - { - f = bli_determine_blocksize_dim_b( iter, m, b_fuse ); - i = m - iter - f; - n_ahead = i; - A11 = a + (i )*rs_at + (i )*cs_at; - A01 = a + (0 )*rs_at + (i )*cs_at; - x1 = x + (i )*incx; - x0 = x + (0 )*incx; - - /* x1 = x1 / triu( A11 ); */ - for ( k = 0; k < f; ++k ) - { - l = f - k - 1; - f_ahead = l; - alpha11 = A11 + (l )*rs_at + (l )*cs_at; - a01 = A11 + (0 )*rs_at + (l )*cs_at; - chi11 = x1 + (l )*incx; - x01 = x1 + (0 )*incx; - - /* chi11 = chi11 / alpha11; */ - if ( bli_is_nonunit_diag( diaga ) ) - { - PASTEMAC(s, copycjs)( conja, *alpha11, alpha11_conj ); - PASTEMAC(s, invscals)( alpha11_conj, *chi11 ); - } - - /* x01 = x01 - chi11 * a01; */ - PASTEMAC(s, neg2s)( *chi11, minus_chi11 ); - if ( bli_is_conj( conja ) ) - { - for ( j = 0; j < f_ahead; ++j ) - PASTEMAC(s, axpyjs)( minus_chi11, *(a01 + j*rs_at), *(x01 + j*incx) ); - } - else - { - for ( j = 0; j < f_ahead; ++j ) - PASTEMAC(s, axpys)( minus_chi11, *(a01 + j*rs_at), *(x01 + j*incx) ); - } - } - - /* x0 = x0 - A01 * x1; */ - kfp_af - ( - conja, - BLIS_NO_CONJUGATE, - n_ahead, - f, - minus_one, - A01, rs_at, cs_at, - x1, incx, - x0, incx, - cntx - ); - } - } - else /* if ( bli_is_lower( uploa_trans ) ) */ - { - for ( iter = 0; iter < m; iter += f ) - { - f = bli_determine_blocksize_dim_f( iter, m, b_fuse ); - i = iter; - n_ahead = m - iter - f; - A11 = a + (i )*rs_at + (i )*cs_at; - A21 = a + (i+f)*rs_at + (i )*cs_at; - x1 = x + (i )*incx; - x2 = x + (i+f)*incx; - - /* x1 = x1 / tril( A11 ); */ - for ( k = 0; k < f; ++k ) - { - l = k; - f_ahead = f - k - 1; - alpha11 = A11 + (l )*rs_at + (l )*cs_at; - a21 = A11 + (l+1)*rs_at + (l )*cs_at; - chi11 = x1 + (l )*incx; - x21 = x1 + (l+1)*incx; - - /* chi11 = chi11 / alpha11; */ - if ( bli_is_nonunit_diag( diaga ) ) - { - PASTEMAC(s, copycjs)( conja, *alpha11, alpha11_conj ); - PASTEMAC(s, invscals)( alpha11_conj, *chi11 ); - } - - /* x21 = x21 - chi11 * a21; */ - PASTEMAC(s, neg2s)( *chi11, minus_chi11 ); - if ( bli_is_conj( conja ) ) - { - for ( j = 0; j < f_ahead; ++j ) - PASTEMAC(s, axpyjs)( minus_chi11, *(a21 + j*rs_at), *(x21 + j*incx) ); - } - else - { - for ( j = 0; j < f_ahead; ++j ) - PASTEMAC(s, axpys)( minus_chi11, *(a21 + j*rs_at), *(x21 + j*incx) ); - } - } - - /* x2 = x2 - A21 * x1; */ - kfp_af - ( - conja, - BLIS_NO_CONJUGATE, - n_ahead, - f, - minus_one, - A21, rs_at, cs_at, - x1, incx, - x2, incx, - cntx - ); - } - } -} - -void bli_ztrsv_unf_var2 - ( - uplo_t uploa, - trans_t transa, - diag_t diaga, - dim_t m, - dcomplex* alpha, - dcomplex* a, inc_t rs_a, inc_t cs_a, - dcomplex* x, inc_t incx, - cntx_t* cntx - ) -{ - - dcomplex* minus_one = PASTEMAC(z, m1); - dcomplex* A01; - dcomplex* A11; - dcomplex* A21; - dcomplex* a01; - dcomplex* alpha11; - dcomplex* a21; - dcomplex* x0; - dcomplex* x1; - dcomplex* x2; - dcomplex* x01; - dcomplex* chi11; - dcomplex* x21; - dcomplex alpha11_conj; - dcomplex minus_chi11; - dim_t iter, i, k, j, l; - dim_t b_fuse, f; - dim_t n_ahead, f_ahead; - 
inc_t rs_at, cs_at; - uplo_t uploa_trans; - conj_t conja; - - /* x = alpha * x; */ - PASTEMAC2(z, scalv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - m, - alpha, - x, incx, - cntx, - NULL - ); - - if( bli_does_notrans( transa ) ) - { - rs_at = rs_a; - cs_at = cs_a; - uploa_trans = uploa; - } - else /* if ( bli_does_trans( transa ) ) */ - { - rs_at = cs_a; - cs_at = rs_a; - uploa_trans = bli_uplo_toggled( uploa ); - } - - conja = bli_extract_conj( transa ); - - PASTECH(z, axpyf_ker_ft) kfp_af; - - /* Assign function pointer and fusing factor. */ - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN4) || - (id == BLIS_ARCH_ZEN3) || - (id == BLIS_ARCH_ZEN2) || - (id == BLIS_ARCH_ZEN); - - if (bamdzen) { - kfp_af = bli_zaxpyf_zen_int_5; - b_fuse = 5; - } - else - { - if ( cntx == NULL ) cntx = bli_gks_query_cntx(); - kfp_af = bli_cntx_get_l1f_ker_dt( BLIS_DCOMPLEX, BLIS_AXPYF_KER, cntx ); - b_fuse = bli_cntx_get_blksz_def_dt( BLIS_DCOMPLEX, BLIS_AF, cntx ); - } - /* We reduce all of the possible cases down to just lower/upper. */ - if ( bli_is_upper( uploa_trans ) ) - { - for ( iter = 0; iter < m; iter += f ) - { - f = bli_determine_blocksize_dim_b( iter, m, b_fuse ); - i = m - iter - f; - n_ahead = i; - A11 = a + (i )*rs_at + (i )*cs_at; - A01 = a + (0 )*rs_at + (i )*cs_at; - x1 = x + (i )*incx; - x0 = x + (0 )*incx; - - /* x1 = x1 / triu( A11 ); */ - for ( k = 0; k < f; ++k ) - { - l = f - k - 1; - f_ahead = l; - alpha11 = A11 + (l )*rs_at + (l )*cs_at; - a01 = A11 + (0 )*rs_at + (l )*cs_at; - chi11 = x1 + (l )*incx; - x01 = x1 + (0 )*incx; - - /* chi11 = chi11 / alpha11; */ - if ( bli_is_nonunit_diag( diaga ) ) - { - PASTEMAC(z, copycjs)( conja, *alpha11, alpha11_conj ); - PASTEMAC(z, invscals)( alpha11_conj, *chi11 ); - } - - /* x01 = x01 - chi11 * a01; */ - PASTEMAC(z, neg2s)( *chi11, minus_chi11 ); - if ( bli_is_conj( conja ) ) - { - for ( j = 0; j < f_ahead; ++j ) - PASTEMAC(z, axpyjs)( minus_chi11, *(a01 + j*rs_at), *(x01 + j*incx) ); - } - else - { - for ( j = 0; j < f_ahead; ++j ) - PASTEMAC(z, axpys)( minus_chi11, *(a01 + j*rs_at), *(x01 + j*incx) ); - } - } - - /* x0 = x0 - A01 * x1; */ - kfp_af - ( - conja, - BLIS_NO_CONJUGATE, - n_ahead, - f, - minus_one, - A01, rs_at, cs_at, - x1, incx, - x0, incx, - cntx - ); - } - } - else /* if ( bli_is_lower( uploa_trans ) ) */ - { - for ( iter = 0; iter < m; iter += f ) - { - f = bli_determine_blocksize_dim_f( iter, m, b_fuse ); - i = iter; - n_ahead = m - iter - f; - A11 = a + (i )*rs_at + (i )*cs_at; - A21 = a + (i+f)*rs_at + (i )*cs_at; - x1 = x + (i )*incx; - x2 = x + (i+f)*incx; - - /* x1 = x1 / tril( A11 ); */ - for ( k = 0; k < f; ++k ) - { - l = k; - f_ahead = f - k - 1; - alpha11 = A11 + (l )*rs_at + (l )*cs_at; - a21 = A11 + (l+1)*rs_at + (l )*cs_at; - chi11 = x1 + (l )*incx; - x21 = x1 + (l+1)*incx; - - /* chi11 = chi11 / alpha11; */ - if ( bli_is_nonunit_diag( diaga ) ) - { - PASTEMAC(z, copycjs)( conja, *alpha11, alpha11_conj ); - PASTEMAC(z, invscals)( alpha11_conj, *chi11 ); - } - - /* x21 = x21 - chi11 * a21; */ - PASTEMAC(z, neg2s)( *chi11, minus_chi11 ); - if ( bli_is_conj( conja ) ) - { - for ( j = 0; j < f_ahead; ++j ) - PASTEMAC(z, axpyjs)( minus_chi11, *(a21 + j*rs_at), *(x21 + j*incx) ); - } - else - { - for ( j = 0; j < f_ahead; ++j ) - PASTEMAC(z, axpys)( minus_chi11, *(a21 + j*rs_at), *(x21 + j*incx) ); - } - } - - /* x2 = x2 - A21 * x1; */ - kfp_af - ( - conja, - BLIS_NO_CONJUGATE, - n_ahead, - f, - minus_one, - A21, rs_at, cs_at, - x1, incx, - x2, incx, - cntx - ); - } - } -} - -void 
bli_ctrsv_unf_var2 - ( - uplo_t uploa, - trans_t transa, - diag_t diaga, - dim_t m, - scomplex* alpha, - scomplex* a, inc_t rs_a, inc_t cs_a, - scomplex* x, inc_t incx, - cntx_t* cntx - ) -{ - - scomplex* minus_one = PASTEMAC(c, m1); - scomplex* A01; - scomplex* A11; - scomplex* A21; - scomplex* a01; - scomplex* alpha11; - scomplex* a21; - scomplex* x0; - scomplex* x1; - scomplex* x2; - scomplex* x01; - scomplex* chi11; - scomplex* x21; - scomplex alpha11_conj; - scomplex minus_chi11; - dim_t iter, i, k, j, l; - dim_t b_fuse, f; - dim_t n_ahead, f_ahead; - inc_t rs_at, cs_at; - uplo_t uploa_trans; - conj_t conja; - - /* x = alpha * x; */ - PASTEMAC2(c, scalv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - m, - alpha, - x, incx, - cntx, - NULL - ); - - if( bli_does_notrans( transa ) ) - { - rs_at = rs_a; - cs_at = cs_a; - uploa_trans = uploa; - } - else /* if ( bli_does_trans( transa ) ) */ - { - rs_at = cs_a; - cs_at = rs_a; - uploa_trans = bli_uplo_toggled( uploa ); - } - - conja = bli_extract_conj( transa ); - - PASTECH(c, axpyf_ker_ft) kfp_af; - - /* Assign function pointer and fusing factor. */ - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN4) || - (id == BLIS_ARCH_ZEN3) || - (id == BLIS_ARCH_ZEN2) || - (id == BLIS_ARCH_ZEN); - - if (bamdzen) { - kfp_af = bli_caxpyf_zen_int_5; - b_fuse = 5; - } - else - { - if ( cntx == NULL ) cntx = bli_gks_query_cntx(); - kfp_af = bli_cntx_get_l1f_ker_dt( BLIS_SCOMPLEX, BLIS_AXPYF_KER, cntx ); - b_fuse = bli_cntx_get_blksz_def_dt( BLIS_SCOMPLEX, BLIS_AF, cntx ); - } - /* We reduce all of the possible cases down to just lower/upper. */ - if ( bli_is_upper( uploa_trans ) ) - { - for ( iter = 0; iter < m; iter += f ) - { - f = bli_determine_blocksize_dim_b( iter, m, b_fuse ); - i = m - iter - f; - n_ahead = i; - A11 = a + (i )*rs_at + (i )*cs_at; - A01 = a + (0 )*rs_at + (i )*cs_at; - x1 = x + (i )*incx; - x0 = x + (0 )*incx; - - /* x1 = x1 / triu( A11 ); */ - for ( k = 0; k < f; ++k ) - { - l = f - k - 1; - f_ahead = l; - alpha11 = A11 + (l )*rs_at + (l )*cs_at; - a01 = A11 + (0 )*rs_at + (l )*cs_at; - chi11 = x1 + (l )*incx; - x01 = x1 + (0 )*incx; - - /* chi11 = chi11 / alpha11; */ - if ( bli_is_nonunit_diag( diaga ) ) - { - PASTEMAC(c, copycjs)( conja, *alpha11, alpha11_conj ); - PASTEMAC(c, invscals)( alpha11_conj, *chi11 ); - } - - /* x01 = x01 - chi11 * a01; */ - PASTEMAC(c, neg2s)( *chi11, minus_chi11 ); - if ( bli_is_conj( conja ) ) - { - for ( j = 0; j < f_ahead; ++j ) - PASTEMAC(c, axpyjs)( minus_chi11, *(a01 + j*rs_at), *(x01 + j*incx) ); - } - else - { - for ( j = 0; j < f_ahead; ++j ) - PASTEMAC(c, axpys)( minus_chi11, *(a01 + j*rs_at), *(x01 + j*incx) ); - } - } - - /* x0 = x0 - A01 * x1; */ - kfp_af - ( - conja, - BLIS_NO_CONJUGATE, - n_ahead, - f, - minus_one, - A01, rs_at, cs_at, - x1, incx, - x0, incx, - cntx - ); - } - } - else /* if ( bli_is_lower( uploa_trans ) ) */ - { - for ( iter = 0; iter < m; iter += f ) - { - f = bli_determine_blocksize_dim_f( iter, m, b_fuse ); - i = iter; - n_ahead = m - iter - f; - A11 = a + (i )*rs_at + (i )*cs_at; - A21 = a + (i+f)*rs_at + (i )*cs_at; - x1 = x + (i )*incx; - x2 = x + (i+f)*incx; - - /* x1 = x1 / tril( A11 ); */ - for ( k = 0; k < f; ++k ) - { - l = k; - f_ahead = f - k - 1; - alpha11 = A11 + (l )*rs_at + (l )*cs_at; - a21 = A11 + (l+1)*rs_at + (l )*cs_at; - chi11 = x1 + (l )*incx; - x21 = x1 + (l+1)*incx; - - /* chi11 = chi11 / alpha11; */ - if ( bli_is_nonunit_diag( diaga ) ) - { - PASTEMAC(c, copycjs)( conja, *alpha11, alpha11_conj ); - PASTEMAC(c, invscals)( 
alpha11_conj, *chi11 ); - } - - /* x21 = x21 - chi11 * a21; */ - PASTEMAC(c, neg2s)( *chi11, minus_chi11 ); - if ( bli_is_conj( conja ) ) - { - for ( j = 0; j < f_ahead; ++j ) - PASTEMAC(c, axpyjs)( minus_chi11, *(a21 + j*rs_at), *(x21 + j*incx) ); - } - else - { - for ( j = 0; j < f_ahead; ++j ) - PASTEMAC(c, axpys)( minus_chi11, *(a21 + j*rs_at), *(x21 + j*incx) ); - } - } - - /* x2 = x2 - A21 * x1; */ - kfp_af - ( - conja, - BLIS_NO_CONJUGATE, - n_ahead, - f, - minus_one, - A21, rs_at, cs_at, - x1, incx, - x2, incx, - cntx - ); - } - } -} - -#else -INSERT_GENTFUNC_BASIC0( trsv_unf_var2 ) -#endif +INSERT_GENTFUNC_BASIC0( trsv_unf_var2 ) \ No newline at end of file diff --git a/frame/2/trsv/bli_trsv_unf_var2_amd.c b/frame/2/trsv/bli_trsv_unf_var2_amd.c new file mode 100644 index 0000000000..51bbcabab7 --- /dev/null +++ b/frame/2/trsv/bli_trsv_unf_var2_amd.c @@ -0,0 +1,1024 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019 - 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +*/ + +#include "blis.h" + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTEMAC(ch,varname) \ + ( \ + uplo_t uploa, \ + trans_t transa, \ + diag_t diaga, \ + dim_t m, \ + ctype* alpha, \ + ctype* a, inc_t rs_a, inc_t cs_a, \ + ctype* x, inc_t incx, \ + cntx_t* cntx \ + ) \ +{ \ + const num_t dt = PASTEMAC(ch,type); \ +\ + bli_init_once(); \ +\ + if( cntx == NULL ) cntx = bli_gks_query_cntx(); \ +\ + ctype* minus_one = PASTEMAC(ch,m1); \ + ctype* A01; \ + ctype* A11; \ + ctype* A21; \ + ctype* a01; \ + ctype* alpha11; \ + ctype* a21; \ + ctype* x0; \ + ctype* x1; \ + ctype* x2; \ + ctype* x01; \ + ctype* chi11; \ + ctype* x21; \ + ctype alpha11_conj; \ + ctype minus_chi11; \ + dim_t iter, i, k, j, l; \ + dim_t b_fuse, f; \ + dim_t n_ahead, f_ahead; \ + inc_t rs_at, cs_at; \ + uplo_t uploa_trans; \ + conj_t conja; \ +\ + /* x = alpha * x; */ \ + PASTEMAC2(ch,scalv,BLIS_TAPI_EX_SUF) \ + ( \ + BLIS_NO_CONJUGATE, \ + m, \ + alpha, \ + x, incx, \ + cntx, \ + NULL \ + ); \ +\ + if ( bli_does_notrans( transa ) ) \ + { \ + rs_at = rs_a; \ + cs_at = cs_a; \ + uploa_trans = uploa; \ + } \ + else /* if ( bli_does_trans( transa ) ) */ \ + { \ + rs_at = cs_a; \ + cs_at = rs_a; \ + uploa_trans = bli_uplo_toggled( uploa ); \ + } \ +\ + conja = bli_extract_conj( transa ); \ +\ + PASTECH(ch,axpyf_ker_ft) kfp_af; \ +\ + /* Query the context for the kernel function pointer and fusing factor. */ \ + kfp_af = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPYF_KER, cntx ); \ + b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_AF, cntx ); \ +\ + /* We reduce all of the possible cases down to just lower/upper. */ \ + if ( bli_is_upper( uploa_trans ) ) \ + { \ + for ( iter = 0; iter < m; iter += f ) \ + { \ + f = bli_determine_blocksize_dim_b( iter, m, b_fuse ); \ + i = m - iter - f; \ + n_ahead = i; \ + A11 = a + (i )*rs_at + (i )*cs_at; \ + A01 = a + (0 )*rs_at + (i )*cs_at; \ + x1 = x + (i )*incx; \ + x0 = x + (0 )*incx; \ +\ + /* x1 = x1 / triu( A11 ); */ \ + for ( k = 0; k < f; ++k ) \ + { \ + l = f - k - 1; \ + f_ahead = l; \ + alpha11 = A11 + (l )*rs_at + (l )*cs_at; \ + a01 = A11 + (0 )*rs_at + (l )*cs_at; \ + chi11 = x1 + (l )*incx; \ + x01 = x1 + (0 )*incx; \ +\ + /* chi11 = chi11 / alpha11; */ \ + if ( bli_is_nonunit_diag( diaga ) ) \ + { \ + PASTEMAC(ch,copycjs)( conja, *alpha11, alpha11_conj ); \ + PASTEMAC(ch,invscals)( alpha11_conj, *chi11 ); \ + } \ +\ + /* x01 = x01 - chi11 * a01; */ \ + PASTEMAC(ch,neg2s)( *chi11, minus_chi11 ); \ + if ( bli_is_conj( conja ) ) \ + { \ + for ( j = 0; j < f_ahead; ++j ) \ + PASTEMAC(ch,axpyjs)( minus_chi11, *(a01 + j*rs_at), *(x01 + j*incx) ); \ + } \ + else \ + { \ + for ( j = 0; j < f_ahead; ++j ) \ + PASTEMAC(ch,axpys)( minus_chi11, *(a01 + j*rs_at), *(x01 + j*incx) ); \ + } \ + } \ +\ + /* x0 = x0 - A01 * x1; */ \ + kfp_af \ + ( \ + conja, \ + BLIS_NO_CONJUGATE, \ + n_ahead, \ + f, \ + minus_one, \ + A01, rs_at, cs_at, \ + x1, incx, \ + x0, incx, \ + cntx \ + ); \ + } \ + } \ + else /* if ( bli_is_lower( uploa_trans ) ) */ \ + { \ + for ( iter = 0; iter < m; iter += f ) \ + { \ + f = bli_determine_blocksize_dim_f( iter, m, b_fuse ); \ + i = iter; \ + n_ahead = m - iter - f; \ + A11 = a + (i )*rs_at + (i )*cs_at; \ + A21 = a + (i+f)*rs_at + (i )*cs_at; \ + x1 = x + (i )*incx; \ + x2 = x + (i+f)*incx; \ +\ + /* x1 = x1 / tril( A11 ); */ \ + for ( k = 0; k < f; ++k ) \ + { \ + l = k; \ + f_ahead = f - k - 1; \ + alpha11 = A11 + (l )*rs_at + (l )*cs_at; \ + a21 = A11 + (l+1)*rs_at + (l )*cs_at; \ + chi11 = x1 + (l )*incx; \ + x21 = x1 + (l+1)*incx; \ +\ + 
/* chi11 = chi11 / alpha11; */ \ + if ( bli_is_nonunit_diag( diaga ) ) \ + { \ + PASTEMAC(ch,copycjs)( conja, *alpha11, alpha11_conj ); \ + PASTEMAC(ch,invscals)( alpha11_conj, *chi11 ); \ + } \ +\ + /* x21 = x21 - chi11 * a21; */ \ + PASTEMAC(ch,neg2s)( *chi11, minus_chi11 ); \ + if ( bli_is_conj( conja ) ) \ + { \ + for ( j = 0; j < f_ahead; ++j ) \ + PASTEMAC(ch,axpyjs)( minus_chi11, *(a21 + j*rs_at), *(x21 + j*incx) ); \ + } \ + else \ + { \ + for ( j = 0; j < f_ahead; ++j ) \ + PASTEMAC(ch,axpys)( minus_chi11, *(a21 + j*rs_at), *(x21 + j*incx) ); \ + } \ + } \ +\ + /* x2 = x2 - A21 * x1; */ \ + kfp_af \ + ( \ + conja, \ + BLIS_NO_CONJUGATE, \ + n_ahead, \ + f, \ + minus_one, \ + A21, rs_at, cs_at, \ + x1, incx, \ + x2, incx, \ + cntx \ + ); \ + } \ + } \ +} + +void bli_dtrsv_unf_var2 + ( + uplo_t uploa, + trans_t transa, + diag_t diaga, + dim_t m, + double* alpha, + double* a, inc_t rs_a, inc_t cs_a, + double* x, inc_t incx, + cntx_t* cntx + ) +{ + + double* minus_one = PASTEMAC(d,m1); + double* A01; + double* A11; + double* A21; + double* a01; + double* alpha11; + double* a21; + double* x0; + double* x1; + double* x2; + double* x01; + double* chi11; + double* x21; + double alpha11_conj; + double minus_chi11; + dim_t iter, i, k, j, l; + dim_t b_fuse, f; + dim_t n_ahead, f_ahead; + inc_t rs_at, cs_at; + uplo_t uploa_trans; + conj_t conja; + + // For AMD these APIS are invoked skipping intermediate framework layers + // Hence we need to ensure that cntx is set here + bli_init_once(); + if( cntx == NULL ) cntx = bli_gks_query_cntx(); + + /* x = alpha * x; */ + PASTEMAC2(d,scalv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + m, + alpha, + x, incx, + cntx, + NULL + ); + + if ( bli_does_notrans( transa ) ) + { + rs_at = rs_a; + cs_at = cs_a; + uploa_trans = uploa; + } + else /* if ( bli_does_trans( transa ) ) */ + { + rs_at = cs_a; + cs_at = rs_a; + uploa_trans = bli_uplo_toggled( uploa ); + } + + conja = bli_extract_conj( transa ); + + PASTECH(d,axpyf_ker_ft) kfp_af; + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) { + kfp_af = bli_daxpyf_zen_int_16x4; + b_fuse = 4; + } + else + { + kfp_af = bli_cntx_get_l1f_ker_dt( BLIS_DOUBLE, BLIS_AXPYF_KER, cntx ); + b_fuse = bli_cntx_get_blksz_def_dt( BLIS_DOUBLE, BLIS_AF, cntx ); + } + + /* We reduce all of the possible cases down to just lower/upper. 
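+      Variant 2: each diagonal block is solved element by element, after which
+      the fused axpyf kernel eagerly applies the resulting update to the part
+      of x that has not been solved yet.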
*/ + if ( bli_is_upper( uploa_trans ) ) + { + for ( iter = 0; iter < m; iter += f ) + { + f = bli_determine_blocksize_dim_b( iter, m, b_fuse ); + i = m - iter - f; + n_ahead = i; + A11 = a + (i )*rs_at + (i )*cs_at; + A01 = a + (0 )*rs_at + (i )*cs_at; + x1 = x + (i )*incx; + x0 = x + (0 )*incx; + + /* x1 = x1 / triu( A11 ); */ + for ( k = 0; k < f; ++k ) + { + l = f - k - 1; + f_ahead = l; + alpha11 = A11 + (l )*rs_at + (l )*cs_at; + a01 = A11 + (0 )*rs_at + (l )*cs_at; + chi11 = x1 + (l )*incx; + x01 = x1 + (0 )*incx; + + /* chi11 = chi11 / alpha11; */ + if ( bli_is_nonunit_diag( diaga ) ) + { + PASTEMAC(d,copycjs)( conja, *alpha11, alpha11_conj ); + PASTEMAC(d,invscals)( alpha11_conj, *chi11 ); + } + + /* x01 = x01 - chi11 * a01; */ + PASTEMAC(d,neg2s)( *chi11, minus_chi11 ); + if ( bli_is_conj( conja ) ) + { + for ( j = 0; j < f_ahead; ++j ) + PASTEMAC(d,axpyjs)( minus_chi11, *(a01 + j*rs_at), *(x01 + j*incx) ); + } + else + { + for ( j = 0; j < f_ahead; ++j ) + PASTEMAC(d,axpys)( minus_chi11, *(a01 + j*rs_at), *(x01 + j*incx) ); + } + } + + /* x0 = x0 - A01 * x1; */ + kfp_af + ( + conja, + BLIS_NO_CONJUGATE, + n_ahead, + f, + minus_one, + A01, rs_at, cs_at, + x1, incx, + x0, incx, + cntx + ); + } + } + else /* if ( bli_is_lower( uploa_trans ) ) */ + { + for ( iter = 0; iter < m; iter += f ) + { + f = bli_determine_blocksize_dim_f( iter, m, b_fuse ); + i = iter; + n_ahead = m - iter - f; + A11 = a + (i )*rs_at + (i )*cs_at; + A21 = a + (i+f)*rs_at + (i )*cs_at; + x1 = x + (i )*incx; + x2 = x + (i+f)*incx; + + /* x1 = x1 / tril( A11 ); */ + for ( k = 0; k < f; ++k ) + { + l = k; + f_ahead = f - k - 1; + alpha11 = A11 + (l )*rs_at + (l )*cs_at; + a21 = A11 + (l+1)*rs_at + (l )*cs_at; + chi11 = x1 + (l )*incx; + x21 = x1 + (l+1)*incx; + + /* chi11 = chi11 / alpha11; */ + if ( bli_is_nonunit_diag( diaga ) ) + { + PASTEMAC(d,copycjs)( conja, *alpha11, alpha11_conj ); + PASTEMAC(d,invscals)( alpha11_conj, *chi11 ); + } + + /* x21 = x21 - chi11 * a21; */ + PASTEMAC(d,neg2s)( *chi11, minus_chi11 ); + if ( bli_is_conj( conja ) ) + { + for ( j = 0; j < f_ahead; ++j ) + PASTEMAC(d,axpyjs)( minus_chi11, *(a21 + j*rs_at), *(x21 + j*incx) ); + } + else + { + for ( j = 0; j < f_ahead; ++j ) + PASTEMAC(d,axpys)( minus_chi11, *(a21 + j*rs_at), *(x21 + j*incx) ); + } + } + + /* x2 = x2 - A21 * x1; */ + kfp_af + ( + conja, + BLIS_NO_CONJUGATE, + n_ahead, + f, + minus_one, + A21, rs_at, cs_at, + x1, incx, + x2, incx, + cntx + ); + } + } +} + +void bli_strsv_unf_var2 + ( + uplo_t uploa, + trans_t transa, + diag_t diaga, + dim_t m, + float* alpha, + float* a, inc_t rs_a, inc_t cs_a, + float* x, inc_t incx, + cntx_t* cntx + ) +{ + + float* minus_one = PASTEMAC(s, m1); + float* A01; + float* A11; + float* A21; + float* a01; + float* alpha11; + float* a21; + float* x0; + float* x1; + float* x2; + float* x01; + float* chi11; + float* x21; + float alpha11_conj; + float minus_chi11; + dim_t iter, i, k, j, l; + dim_t b_fuse, f; + dim_t n_ahead, f_ahead; + inc_t rs_at, cs_at; + uplo_t uploa_trans; + conj_t conja; + + // For AMD these APIS are invoked skipping intermediate framework layers + // Hence we need to ensure that cntx is set here + bli_init_once(); + if( cntx == NULL ) cntx = bli_gks_query_cntx(); + + /* x = alpha * x; */ + PASTEMAC2(s, scalv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + m, + alpha, + x, incx, + cntx, + NULL + ); + + if( bli_does_notrans( transa ) ) + { + rs_at = rs_a; + cs_at = cs_a; + uploa_trans = uploa; + } + else /* if ( bli_does_trans( transa ) ) */ + { + rs_at = cs_a; + cs_at = rs_a; 
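+      // A transpose is handled implicitly: the strides were swapped above and
+      // uploa is toggled below, so only the non-transposed lower/upper cases
+      // need to be handled by the code that follows.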
+ uploa_trans = bli_uplo_toggled( uploa ); + } + + conja = bli_extract_conj( transa ); + + PASTECH(s, axpyf_ker_ft) kfp_af; + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) { + kfp_af = bli_saxpyf_zen_int_5; + b_fuse = 5; + } + else + { + kfp_af = bli_cntx_get_l1f_ker_dt( BLIS_FLOAT, BLIS_AXPYF_KER, cntx ); + b_fuse = bli_cntx_get_blksz_def_dt( BLIS_FLOAT, BLIS_AF, cntx ); + } + + /* We reduce all of the possible cases down to just lower/upper. */ + if ( bli_is_upper( uploa_trans ) ) + { + for ( iter = 0; iter < m; iter += f ) + { + f = bli_determine_blocksize_dim_b( iter, m, b_fuse ); + i = m - iter - f; + n_ahead = i; + A11 = a + (i )*rs_at + (i )*cs_at; + A01 = a + (0 )*rs_at + (i )*cs_at; + x1 = x + (i )*incx; + x0 = x + (0 )*incx; + + /* x1 = x1 / triu( A11 ); */ + for ( k = 0; k < f; ++k ) + { + l = f - k - 1; + f_ahead = l; + alpha11 = A11 + (l )*rs_at + (l )*cs_at; + a01 = A11 + (0 )*rs_at + (l )*cs_at; + chi11 = x1 + (l )*incx; + x01 = x1 + (0 )*incx; + + /* chi11 = chi11 / alpha11; */ + if ( bli_is_nonunit_diag( diaga ) ) + { + PASTEMAC(s, copycjs)( conja, *alpha11, alpha11_conj ); + PASTEMAC(s, invscals)( alpha11_conj, *chi11 ); + } + + /* x01 = x01 - chi11 * a01; */ + PASTEMAC(s, neg2s)( *chi11, minus_chi11 ); + if ( bli_is_conj( conja ) ) + { + for ( j = 0; j < f_ahead; ++j ) + PASTEMAC(s, axpyjs)( minus_chi11, *(a01 + j*rs_at), *(x01 + j*incx) ); + } + else + { + for ( j = 0; j < f_ahead; ++j ) + PASTEMAC(s, axpys)( minus_chi11, *(a01 + j*rs_at), *(x01 + j*incx) ); + } + } + + /* x0 = x0 - A01 * x1; */ + kfp_af + ( + conja, + BLIS_NO_CONJUGATE, + n_ahead, + f, + minus_one, + A01, rs_at, cs_at, + x1, incx, + x0, incx, + cntx + ); + } + } + else /* if ( bli_is_lower( uploa_trans ) ) */ + { + for ( iter = 0; iter < m; iter += f ) + { + f = bli_determine_blocksize_dim_f( iter, m, b_fuse ); + i = iter; + n_ahead = m - iter - f; + A11 = a + (i )*rs_at + (i )*cs_at; + A21 = a + (i+f)*rs_at + (i )*cs_at; + x1 = x + (i )*incx; + x2 = x + (i+f)*incx; + + /* x1 = x1 / tril( A11 ); */ + for ( k = 0; k < f; ++k ) + { + l = k; + f_ahead = f - k - 1; + alpha11 = A11 + (l )*rs_at + (l )*cs_at; + a21 = A11 + (l+1)*rs_at + (l )*cs_at; + chi11 = x1 + (l )*incx; + x21 = x1 + (l+1)*incx; + + /* chi11 = chi11 / alpha11; */ + if ( bli_is_nonunit_diag( diaga ) ) + { + PASTEMAC(s, copycjs)( conja, *alpha11, alpha11_conj ); + PASTEMAC(s, invscals)( alpha11_conj, *chi11 ); + } + + /* x21 = x21 - chi11 * a21; */ + PASTEMAC(s, neg2s)( *chi11, minus_chi11 ); + if ( bli_is_conj( conja ) ) + { + for ( j = 0; j < f_ahead; ++j ) + PASTEMAC(s, axpyjs)( minus_chi11, *(a21 + j*rs_at), *(x21 + j*incx) ); + } + else + { + for ( j = 0; j < f_ahead; ++j ) + PASTEMAC(s, axpys)( minus_chi11, *(a21 + j*rs_at), *(x21 + j*incx) ); + } + } + + /* x2 = x2 - A21 * x1; */ + kfp_af + ( + conja, + BLIS_NO_CONJUGATE, + n_ahead, + f, + minus_one, + A21, rs_at, cs_at, + x1, incx, + x2, incx, + cntx + ); + } + } +} + +void bli_ztrsv_unf_var2 + ( + uplo_t uploa, + trans_t transa, + diag_t diaga, + dim_t m, + dcomplex* alpha, + dcomplex* a, inc_t rs_a, inc_t cs_a, + dcomplex* x, inc_t incx, + cntx_t* cntx + ) +{ + + dcomplex* minus_one = PASTEMAC(z, m1); + dcomplex* A01; + dcomplex* A11; + dcomplex* A21; + dcomplex* a01; + dcomplex* alpha11; + dcomplex* a21; + dcomplex* x0; + dcomplex* x1; + dcomplex* x2; + dcomplex* x01; + dcomplex* chi11; + dcomplex* x21; + dcomplex alpha11_conj; + dcomplex 
minus_chi11; + dim_t iter, i, k, j, l; + dim_t b_fuse, f; + dim_t n_ahead, f_ahead; + inc_t rs_at, cs_at; + uplo_t uploa_trans; + conj_t conja; + + // For AMD these APIS are invoked skipping intermediate framework layers + // Hence we need to ensure that cntx is set here + bli_init_once(); + if( cntx == NULL ) cntx = bli_gks_query_cntx(); + + /* x = alpha * x; */ + PASTEMAC2(z, scalv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + m, + alpha, + x, incx, + cntx, + NULL + ); + + if( bli_does_notrans( transa ) ) + { + rs_at = rs_a; + cs_at = cs_a; + uploa_trans = uploa; + } + else /* if ( bli_does_trans( transa ) ) */ + { + rs_at = cs_a; + cs_at = rs_a; + uploa_trans = bli_uplo_toggled( uploa ); + } + + conja = bli_extract_conj( transa ); + + PASTECH(z, axpyf_ker_ft) kfp_af; + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) { + kfp_af = bli_zaxpyf_zen_int_5; + b_fuse = 5; + } + else + { + kfp_af = bli_cntx_get_l1f_ker_dt( BLIS_DCOMPLEX, BLIS_AXPYF_KER, cntx ); + b_fuse = bli_cntx_get_blksz_def_dt( BLIS_DCOMPLEX, BLIS_AF, cntx ); + } + /* We reduce all of the possible cases down to just lower/upper. */ + if ( bli_is_upper( uploa_trans ) ) + { + for ( iter = 0; iter < m; iter += f ) + { + f = bli_determine_blocksize_dim_b( iter, m, b_fuse ); + i = m - iter - f; + n_ahead = i; + A11 = a + (i )*rs_at + (i )*cs_at; + A01 = a + (0 )*rs_at + (i )*cs_at; + x1 = x + (i )*incx; + x0 = x + (0 )*incx; + + /* x1 = x1 / triu( A11 ); */ + for ( k = 0; k < f; ++k ) + { + l = f - k - 1; + f_ahead = l; + alpha11 = A11 + (l )*rs_at + (l )*cs_at; + a01 = A11 + (0 )*rs_at + (l )*cs_at; + chi11 = x1 + (l )*incx; + x01 = x1 + (0 )*incx; + + /* chi11 = chi11 / alpha11; */ + if ( bli_is_nonunit_diag( diaga ) ) + { + PASTEMAC(z, copycjs)( conja, *alpha11, alpha11_conj ); + PASTEMAC(z, invscals)( alpha11_conj, *chi11 ); + } + + /* x01 = x01 - chi11 * a01; */ + PASTEMAC(z, neg2s)( *chi11, minus_chi11 ); + if ( bli_is_conj( conja ) ) + { + for ( j = 0; j < f_ahead; ++j ) + PASTEMAC(z, axpyjs)( minus_chi11, *(a01 + j*rs_at), *(x01 + j*incx) ); + } + else + { + for ( j = 0; j < f_ahead; ++j ) + PASTEMAC(z, axpys)( minus_chi11, *(a01 + j*rs_at), *(x01 + j*incx) ); + } + } + + /* x0 = x0 - A01 * x1; */ + kfp_af + ( + conja, + BLIS_NO_CONJUGATE, + n_ahead, + f, + minus_one, + A01, rs_at, cs_at, + x1, incx, + x0, incx, + cntx + ); + } + } + else /* if ( bli_is_lower( uploa_trans ) ) */ + { + for ( iter = 0; iter < m; iter += f ) + { + f = bli_determine_blocksize_dim_f( iter, m, b_fuse ); + i = iter; + n_ahead = m - iter - f; + A11 = a + (i )*rs_at + (i )*cs_at; + A21 = a + (i+f)*rs_at + (i )*cs_at; + x1 = x + (i )*incx; + x2 = x + (i+f)*incx; + + /* x1 = x1 / tril( A11 ); */ + for ( k = 0; k < f; ++k ) + { + l = k; + f_ahead = f - k - 1; + alpha11 = A11 + (l )*rs_at + (l )*cs_at; + a21 = A11 + (l+1)*rs_at + (l )*cs_at; + chi11 = x1 + (l )*incx; + x21 = x1 + (l+1)*incx; + + /* chi11 = chi11 / alpha11; */ + if ( bli_is_nonunit_diag( diaga ) ) + { + PASTEMAC(z, copycjs)( conja, *alpha11, alpha11_conj ); + PASTEMAC(z, invscals)( alpha11_conj, *chi11 ); + } + + /* x21 = x21 - chi11 * a21; */ + PASTEMAC(z, neg2s)( *chi11, minus_chi11 ); + if ( bli_is_conj( conja ) ) + { + for ( j = 0; j < f_ahead; ++j ) + PASTEMAC(z, axpyjs)( minus_chi11, *(a21 + j*rs_at), *(x21 + j*incx) ); + } + else + { + for ( j = 0; j < f_ahead; ++j ) + PASTEMAC(z, axpys)( minus_chi11, *(a21 + j*rs_at), *(x21 + j*incx) ); + } + 
} + + /* x2 = x2 - A21 * x1; */ + kfp_af + ( + conja, + BLIS_NO_CONJUGATE, + n_ahead, + f, + minus_one, + A21, rs_at, cs_at, + x1, incx, + x2, incx, + cntx + ); + } + } +} + +void bli_ctrsv_unf_var2 + ( + uplo_t uploa, + trans_t transa, + diag_t diaga, + dim_t m, + scomplex* alpha, + scomplex* a, inc_t rs_a, inc_t cs_a, + scomplex* x, inc_t incx, + cntx_t* cntx + ) +{ + + scomplex* minus_one = PASTEMAC(c, m1); + scomplex* A01; + scomplex* A11; + scomplex* A21; + scomplex* a01; + scomplex* alpha11; + scomplex* a21; + scomplex* x0; + scomplex* x1; + scomplex* x2; + scomplex* x01; + scomplex* chi11; + scomplex* x21; + scomplex alpha11_conj; + scomplex minus_chi11; + dim_t iter, i, k, j, l; + dim_t b_fuse, f; + dim_t n_ahead, f_ahead; + inc_t rs_at, cs_at; + uplo_t uploa_trans; + conj_t conja; + + // For AMD these APIS are invoked skipping intermediate framework layers + // Hence we need to ensure that cntx is set here + bli_init_once(); + if( cntx == NULL ) cntx = bli_gks_query_cntx(); + + /* x = alpha * x; */ + PASTEMAC2(c, scalv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + m, + alpha, + x, incx, + cntx, + NULL + ); + + if( bli_does_notrans( transa ) ) + { + rs_at = rs_a; + cs_at = cs_a; + uploa_trans = uploa; + } + else /* if ( bli_does_trans( transa ) ) */ + { + rs_at = cs_a; + cs_at = rs_a; + uploa_trans = bli_uplo_toggled( uploa ); + } + + conja = bli_extract_conj( transa ); + + PASTECH(c, axpyf_ker_ft) kfp_af; + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) { + kfp_af = bli_caxpyf_zen_int_5; + b_fuse = 5; + } + else + { + kfp_af = bli_cntx_get_l1f_ker_dt( BLIS_SCOMPLEX, BLIS_AXPYF_KER, cntx ); + b_fuse = bli_cntx_get_blksz_def_dt( BLIS_SCOMPLEX, BLIS_AF, cntx ); + } + /* We reduce all of the possible cases down to just lower/upper. 
*/ + if ( bli_is_upper( uploa_trans ) ) + { + for ( iter = 0; iter < m; iter += f ) + { + f = bli_determine_blocksize_dim_b( iter, m, b_fuse ); + i = m - iter - f; + n_ahead = i; + A11 = a + (i )*rs_at + (i )*cs_at; + A01 = a + (0 )*rs_at + (i )*cs_at; + x1 = x + (i )*incx; + x0 = x + (0 )*incx; + + /* x1 = x1 / triu( A11 ); */ + for ( k = 0; k < f; ++k ) + { + l = f - k - 1; + f_ahead = l; + alpha11 = A11 + (l )*rs_at + (l )*cs_at; + a01 = A11 + (0 )*rs_at + (l )*cs_at; + chi11 = x1 + (l )*incx; + x01 = x1 + (0 )*incx; + + /* chi11 = chi11 / alpha11; */ + if ( bli_is_nonunit_diag( diaga ) ) + { + PASTEMAC(c, copycjs)( conja, *alpha11, alpha11_conj ); + PASTEMAC(c, invscals)( alpha11_conj, *chi11 ); + } + + /* x01 = x01 - chi11 * a01; */ + PASTEMAC(c, neg2s)( *chi11, minus_chi11 ); + if ( bli_is_conj( conja ) ) + { + for ( j = 0; j < f_ahead; ++j ) + PASTEMAC(c, axpyjs)( minus_chi11, *(a01 + j*rs_at), *(x01 + j*incx) ); + } + else + { + for ( j = 0; j < f_ahead; ++j ) + PASTEMAC(c, axpys)( minus_chi11, *(a01 + j*rs_at), *(x01 + j*incx) ); + } + } + + /* x0 = x0 - A01 * x1; */ + kfp_af + ( + conja, + BLIS_NO_CONJUGATE, + n_ahead, + f, + minus_one, + A01, rs_at, cs_at, + x1, incx, + x0, incx, + cntx + ); + } + } + else /* if ( bli_is_lower( uploa_trans ) ) */ + { + for ( iter = 0; iter < m; iter += f ) + { + f = bli_determine_blocksize_dim_f( iter, m, b_fuse ); + i = iter; + n_ahead = m - iter - f; + A11 = a + (i )*rs_at + (i )*cs_at; + A21 = a + (i+f)*rs_at + (i )*cs_at; + x1 = x + (i )*incx; + x2 = x + (i+f)*incx; + + /* x1 = x1 / tril( A11 ); */ + for ( k = 0; k < f; ++k ) + { + l = k; + f_ahead = f - k - 1; + alpha11 = A11 + (l )*rs_at + (l )*cs_at; + a21 = A11 + (l+1)*rs_at + (l )*cs_at; + chi11 = x1 + (l )*incx; + x21 = x1 + (l+1)*incx; + + /* chi11 = chi11 / alpha11; */ + if ( bli_is_nonunit_diag( diaga ) ) + { + PASTEMAC(c, copycjs)( conja, *alpha11, alpha11_conj ); + PASTEMAC(c, invscals)( alpha11_conj, *chi11 ); + } + + /* x21 = x21 - chi11 * a21; */ + PASTEMAC(c, neg2s)( *chi11, minus_chi11 ); + if ( bli_is_conj( conja ) ) + { + for ( j = 0; j < f_ahead; ++j ) + PASTEMAC(c, axpyjs)( minus_chi11, *(a21 + j*rs_at), *(x21 + j*incx) ); + } + else + { + for ( j = 0; j < f_ahead; ++j ) + PASTEMAC(c, axpys)( minus_chi11, *(a21 + j*rs_at), *(x21 + j*incx) ); + } + } + + /* x2 = x2 - A21 * x1; */ + kfp_af + ( + conja, + BLIS_NO_CONJUGATE, + n_ahead, + f, + minus_one, + A21, rs_at, cs_at, + x1, incx, + x2, incx, + cntx + ); + } + } +} diff --git a/frame/3/bli_l3_sup_int.c b/frame/3/bli_l3_sup_int.c index 7ef4bdd49f..909f480599 100644 --- a/frame/3/bli_l3_sup_int.c +++ b/frame/3/bli_l3_sup_int.c @@ -48,120 +48,6 @@ err_t bli_gemmsup_int { AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_4); -#ifdef BLIS_CONFIG_EPYC - const num_t dt = bli_obj_dt( c ); - const dim_t m = bli_obj_length( c ); - const dim_t n = bli_obj_width( c ); - const dim_t k = bli_obj_width( a ); - const dim_t MR = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); - const dim_t NR = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); - const bool auto_factor = bli_rntm_auto_factor( rntm ); - const dim_t n_threads = bli_rntm_num_threads( rntm ); - - dim_t jc_new; - dim_t ic_new; - - - //bli_gemmsup_ref_var2 - //bli_gemmsup_ref_var1 - #if 0 - bli_gemmsup_ref_var1n - #else - #endif - const stor3_t stor_id = bli_obj_stor3_from_strides( c, a, b ); - const bool is_rrr_rrc_rcr_crr = ( stor_id == BLIS_RRR || - stor_id == BLIS_RRC || - stor_id == BLIS_RCR || - stor_id == BLIS_CRR ); - #ifdef TRACEVAR - if ( bli_thread_am_ochief( thread ) ) - 
printf( "bli_l3_sup_int(): var2m primary\n" ); - #endif - - // Don't use the small/unpacked implementation if one of the matrices - // uses general stride. - if ( stor_id == BLIS_XXX ) { - AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_4, "SUP doesn't support general stide."); - return BLIS_FAILURE; - } - - if ( is_rrr_rrc_rcr_crr ) - { - // This branch handles: - // - rrr rrc rcr crr for row-preferential kernels - // - rcc crc ccr ccc for column-preferential kernels - // - Currently only row-preferential kernels are only supported. - - // calculate number of micropanels in m and n dimensions and - // recalculate the automatic thread factorization based on these number of micropanels - const dim_t mu = m / MR; - const dim_t nu = n / NR; - - // If the parallel thread factorization was automatic, we update it - // with a new factorization based on the matrix dimensions in units - // of micropanels. - if ( auto_factor ) - { - // In the block-panel algorithm, the m dimension is parallelized - // with ic_nt and the n dimension is parallelized with jc_nt. - bli_thread_partition_2x2( n_threads, mu, nu, &ic_new, &jc_new ); - - // Update the ways of parallelism for the jc and ic loops, and then - // update the current thread's root thrinfo_t node according to the - // new ways of parallelism value for the jc loop. - bli_rntm_set_ways_only( jc_new, 1, ic_new, 1, 1, rntm ); - bli_l3_sup_thrinfo_update_root( rntm, thread ); - } - - /*Enable packing for B matrix for higher sizes*/ - if(bli_is_float(dt) && (n_threads==1)) { - if((m > 240) && (k > 240) && (n > 240)) - bli_rntm_set_pack_b( 1, rntm ); - } - - bli_gemmsup_ref_var2m( BLIS_NO_TRANSPOSE, - alpha, a, b, beta, c, - stor_id, cntx, rntm, thread ); - } - else - { - // This branch handles: - // - rrr rrc rcr crr for column-preferential kernels - // - rcc crc ccr ccc for row-preferential kernels - // - Currently only row-preferential kernels are only supported. - const dim_t mu = n / MR; // the n becomes m after a transposition - const dim_t nu = m / NR; // the m becomes n after a transposition - - if ( auto_factor ) - { - // In the block-panel algorithm, the m dimension is parallelized - // with ic_nt and the n dimension is parallelized with jc_nt. - bli_thread_partition_2x2( n_threads, mu, nu, &ic_new, &jc_new ); - - // Update the ways of parallelism for the jc and ic loops, and then - // update the current thread's root thrinfo_t node according to the - // new ways of parallelism value for the jc loop. - bli_rntm_set_ways_only( jc_new, 1, ic_new, 1, 1, rntm ); - bli_l3_sup_thrinfo_update_root( rntm, thread ); - } - - /* Enable packing for B matrix for higher sizes. Note that pack A - * becomes pack B inside var2m because this is transpose case*/ - if(bli_is_float(dt) && (n_threads==1)) { - if((m > 240) && (k > 240) && (n > 240)) - bli_rntm_set_pack_a( 1, rntm ); - } - - bli_gemmsup_ref_var2m( BLIS_TRANSPOSE, - alpha, a, b, beta, c, - stor_id, cntx, rntm, thread ); - } - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4); - return BLIS_SUCCESS; - -#else // #ifdef BLIS_CONFIG_EPYC - const stor3_t stor_id = bli_obj_stor3_from_strides( c, a, b ); // Don't use the small/unpacked implementation if one of the matrices @@ -335,8 +221,6 @@ err_t bli_gemmsup_int // Return success so that the caller knows that we computed the solution. 
AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) return BLIS_SUCCESS; - -#endif } // ----------------------------------------------------------------------------- @@ -401,15 +285,9 @@ err_t bli_gemmtsup_int // Decide which algorithm to use (block-panel var2m or panel-block // var1n) based on the number of micropanels in the m and n dimensions. // Also, recalculate the automatic thread factorization. -#ifdef BLIS_CONFIG_EPYC - if ( mu >= nu ) use_bp = TRUE; - else /* if ( mu < nu ) */ use_bp = TRUE;// var1n is not implemented for GEMMT - -#else if ( mu >= nu ) use_bp = TRUE; else /* if ( mu < nu ) */ use_bp = FALSE; -#endif // If the parallel thread factorization was automatic, we update it // with a new factorization based on the matrix dimensions in units // of micropanels. @@ -472,14 +350,10 @@ err_t bli_gemmtsup_int // Decide which algorithm to use (block-panel var2m or panel-block // var1n) based on the number of micropanels in the m and n dimensions. // Also, recalculate the automatic thread factorization. -#ifdef BLIS_CONFIG_EPYC - if ( mu >= nu ) use_bp = TRUE; - else /* if ( mu < nu ) */ use_bp = TRUE; //var1n is not implemented for gemmt -#else + if ( mu >= nu ) use_bp = TRUE; else /* if ( mu < nu ) */ use_bp = FALSE; -#endif // If the parallel thread factorization was automatic, we update it // with a new factorization based on the matrix dimensions in units // of micropanels. diff --git a/frame/3/bli_l3_sup_int_amd.c b/frame/3/bli_l3_sup_int_amd.c new file mode 100644 index 0000000000..7bd44266d2 --- /dev/null +++ b/frame/3/bli_l3_sup_int_amd.c @@ -0,0 +1,352 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2019-21, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +*/ + +#include "blis.h" + +err_t bli_gemmsup_int + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm, + thrinfo_t* thread + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_4); + + const num_t dt = bli_obj_dt( c ); + const dim_t m = bli_obj_length( c ); + const dim_t n = bli_obj_width( c ); + const dim_t k = bli_obj_width( a ); + const dim_t MR = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); + const dim_t NR = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); + const bool auto_factor = bli_rntm_auto_factor( rntm ); + const dim_t n_threads = bli_rntm_num_threads( rntm ); + + dim_t jc_new; + dim_t ic_new; + + + //bli_gemmsup_ref_var2 + //bli_gemmsup_ref_var1 + #if 0 + bli_gemmsup_ref_var1n + #else + #endif + const stor3_t stor_id = bli_obj_stor3_from_strides( c, a, b ); + const bool is_rrr_rrc_rcr_crr = ( stor_id == BLIS_RRR || + stor_id == BLIS_RRC || + stor_id == BLIS_RCR || + stor_id == BLIS_CRR ); + #ifdef TRACEVAR + if ( bli_thread_am_ochief( thread ) ) + printf( "bli_l3_sup_int(): var2m primary\n" ); + #endif + + // Don't use the small/unpacked implementation if one of the matrices + // uses general stride. + if ( stor_id == BLIS_XXX ) { + AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_4, "SUP doesn't support general stide."); + return BLIS_FAILURE; + } + + if ( is_rrr_rrc_rcr_crr ) + { + // This branch handles: + // - rrr rrc rcr crr for row-preferential kernels + // - rcc crc ccr ccc for column-preferential kernels + // - Currently only row-preferential kernels are only supported. + + // calculate number of micropanels in m and n dimensions and + // recalculate the automatic thread factorization based on these number of micropanels + const dim_t mu = m / MR; + const dim_t nu = n / NR; + + // If the parallel thread factorization was automatic, we update it + // with a new factorization based on the matrix dimensions in units + // of micropanels. + if ( auto_factor ) + { + // In the block-panel algorithm, the m dimension is parallelized + // with ic_nt and the n dimension is parallelized with jc_nt. + bli_thread_partition_2x2( n_threads, mu, nu, &ic_new, &jc_new ); + + // Update the ways of parallelism for the jc and ic loops, and then + // update the current thread's root thrinfo_t node according to the + // new ways of parallelism value for the jc loop. + bli_rntm_set_ways_only( jc_new, 1, ic_new, 1, 1, rntm ); + bli_l3_sup_thrinfo_update_root( rntm, thread ); + } + + /*Enable packing for B matrix for higher sizes*/ + if(bli_is_float(dt) && (n_threads==1)) { + if((m > 240) && (k > 240) && (n > 240)) + bli_rntm_set_pack_b( 1, rntm ); + } + + bli_gemmsup_ref_var2m( BLIS_NO_TRANSPOSE, + alpha, a, b, beta, c, + stor_id, cntx, rntm, thread ); + } + else + { + // This branch handles: + // - rrr rrc rcr crr for column-preferential kernels + // - rcc crc ccr ccc for row-preferential kernels + // - Currently only row-preferential kernels are only supported. + const dim_t mu = n / MR; // the n becomes m after a transposition + const dim_t nu = m / NR; // the m becomes n after a transposition + + if ( auto_factor ) + { + // In the block-panel algorithm, the m dimension is parallelized + // with ic_nt and the n dimension is parallelized with jc_nt. + bli_thread_partition_2x2( n_threads, mu, nu, &ic_new, &jc_new ); + + // Update the ways of parallelism for the jc and ic loops, and then + // update the current thread's root thrinfo_t node according to the + // new ways of parallelism value for the jc loop. 
+ bli_rntm_set_ways_only( jc_new, 1, ic_new, 1, 1, rntm ); + bli_l3_sup_thrinfo_update_root( rntm, thread ); + } + + /* Enable packing for B matrix for higher sizes. Note that pack A + * becomes pack B inside var2m because this is transpose case*/ + if(bli_is_float(dt) && (n_threads==1)) { + if((m > 240) && (k > 240) && (n > 240)) + bli_rntm_set_pack_a( 1, rntm ); + } + + bli_gemmsup_ref_var2m( BLIS_TRANSPOSE, + alpha, a, b, beta, c, + stor_id, cntx, rntm, thread ); + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4); + return BLIS_SUCCESS; + + +} + +// ----------------------------------------------------------------------------- + +err_t bli_gemmtsup_int + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm, + thrinfo_t* thread + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_4); +// AOCL_DTL_LOG_GEMMT_INPUTS(AOCL_DTL_LEVEL_TRACE_4, alpha, a, b, beta, c); + + + const stor3_t stor_id = bli_obj_stor3_from_strides( c, a, b ); + + // Don't use the small/unpacked implementation if one of the matrices + // uses general stride. + if ( stor_id == BLIS_XXX ) { + AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_4, "SUP doesn't support general stide."); + return BLIS_FAILURE; + } + + const bool is_rrr_rrc_rcr_crr = ( stor_id == BLIS_RRR || + stor_id == BLIS_RRC || + stor_id == BLIS_RCR || + stor_id == BLIS_CRR ); + const bool is_rcc_crc_ccr_ccc = !is_rrr_rrc_rcr_crr; + + const num_t dt = bli_obj_dt( c ); + const bool row_pref = bli_cntx_l3_sup_ker_prefers_rows_dt( dt, stor_id, cntx ); + + const bool is_primary = ( row_pref ? is_rrr_rrc_rcr_crr + : is_rcc_crc_ccr_ccc ); + + const dim_t m = bli_obj_length( c ); + const dim_t n = m; + const dim_t MR = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); + const dim_t NR = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); + const bool auto_factor = bli_rntm_auto_factor( rntm ); + const dim_t n_threads = bli_rntm_num_threads( rntm ); + bool use_bp = TRUE; + dim_t jc_new; + dim_t ic_new; + + + if ( is_primary ) + { + // This branch handles: + // - rrr rrc rcr crr for row-preferential kernels + // - rcc crc ccr ccc for column-preferential kernels + + const dim_t mu = m / MR; + const dim_t nu = n / NR; + + // Decide which algorithm to use (block-panel var2m or panel-block + // var1n) based on the number of micropanels in the m and n dimensions. + // Also, recalculate the automatic thread factorization. + + if ( mu >= nu ) use_bp = TRUE; + else /* if ( mu < nu ) */ use_bp = TRUE;// var1n is not implemented for GEMMT + + // If the parallel thread factorization was automatic, we update it + // with a new factorization based on the matrix dimensions in units + // of micropanels. + if ( auto_factor ) + { + if ( use_bp ) + { + // In the block-panel algorithm, the m dimension is parallelized + // with ic_nt and the n dimension is parallelized with jc_nt. + bli_thread_partition_2x2( n_threads, mu, nu, &ic_new, &jc_new ); + } + else // if ( !use_bp ) + { + // In the panel-block algorithm, the m dimension is parallelized + // with jc_nt and the n dimension is parallelized with ic_nt. + bli_thread_partition_2x2( n_threads, mu, nu, &jc_new, &ic_new ); + } + + // Update the ways of parallelism for the jc and ic loops, and then + // update the current thread's root thrinfo_t node according to the + // new ways of parallelism value for the jc loop. 
+ bli_rntm_set_ways_only( jc_new, 1, ic_new, 1, 1, rntm ); + bli_l3_sup_thrinfo_update_root( rntm, thread ); + } + + + if ( use_bp ) + { + #ifdef TRACEVAR + if ( bli_thread_am_ochief( thread ) ) + printf( "bli_l3_sup_int(): var2m primary\n" ); + #endif + // block-panel macrokernel; m -> mc, mr; n -> nc, nr: var2() + bli_gemmtsup_ref_var2m( BLIS_NO_TRANSPOSE, + alpha, a, b, beta, c, + stor_id, cntx, rntm, thread ); + } + else // use_pb + { + #ifdef TRACEVAR + if ( bli_thread_am_ochief( thread ) ) + printf( "bli_l3_sup_int(): var1n primary\n" ); + #endif + // panel-block macrokernel; m -> nc*,mr; n -> mc*,nr: var1() + bli_gemmtsup_ref_var1n( BLIS_NO_TRANSPOSE, + alpha, a, b, beta, c, + stor_id, cntx, rntm, thread ); + // *requires nudging of nc up to be a multiple of mr. + } + } + else + { + // This branch handles: + // - rrr rrc rcr crr for column-preferential kernels + // - rcc crc ccr ccc for row-preferential kernels + + const dim_t mu = n / MR; // the n becomes m after a transposition + const dim_t nu = m / NR; // the m becomes n after a transposition + + // Decide which algorithm to use (block-panel var2m or panel-block + // var1n) based on the number of micropanels in the m and n dimensions. + // Also, recalculate the automatic thread factorization. + + if ( mu >= nu ) use_bp = TRUE; + else /* if ( mu < nu ) */ use_bp = TRUE; //var1n is not implemented for gemmt + + // If the parallel thread factorization was automatic, we update it + // with a new factorization based on the matrix dimensions in units + // of micropanels. + if ( auto_factor ) + { + if ( use_bp ) + { + // In the block-panel algorithm, the m dimension is parallelized + // with ic_nt and the n dimension is parallelized with jc_nt. + bli_thread_partition_2x2( n_threads, mu, nu, &ic_new, &jc_new ); + } + else // if ( !use_bp ) + { + // In the panel-block algorithm, the m dimension is parallelized + // with jc_nt and the n dimension is parallelized with ic_nt. + bli_thread_partition_2x2( n_threads, mu, nu, &jc_new, &ic_new ); + } + + // Update the ways of parallelism for the jc and ic loops, and then + // update the current thread's root thrinfo_t node according to the + // new ways of parallelism value for the jc loop. + bli_rntm_set_ways_only( jc_new, 1, ic_new, 1, 1, rntm ); + bli_l3_sup_thrinfo_update_root( rntm, thread ); + } + + + if ( use_bp ) + { + #ifdef TRACEVAR + if ( bli_thread_am_ochief( thread ) ) + printf( "bli_l3_sup_int(): var2m non-primary\n" ); + #endif + // panel-block macrokernel; m -> nc, nr; n -> mc, mr: var2() + trans + bli_gemmtsup_ref_var2m( BLIS_TRANSPOSE, + alpha, a, b, beta, c, + stor_id, cntx, rntm, thread ); + } + else // use_pb + { + #ifdef TRACEVAR + if ( bli_thread_am_ochief( thread ) ) + printf( "bli_l3_sup_int(): var1n non-primary\n" ); + #endif + // block-panel macrokernel; m -> mc*,nr; n -> nc*,mr: var1() + trans + bli_gemmtsup_ref_var1n( BLIS_TRANSPOSE, + alpha, a, b, beta, c, + stor_id, cntx, rntm, thread ); + // *requires nudging of mc up to be a multiple of nr. + } + } + + // Return success so that the caller knows that we computed the solution. 
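The comments in bli_gemmsup_int() and bli_gemmtsup_int() above describe recomputing the automatic ic_nt x jc_nt thread factorization from the micropanel counts mu = m/MR and nu = n/NR before calling bli_thread_partition_2x2(). A minimal standalone sketch of that idea follows; partition_2x2_sketch() is a hypothetical helper for illustration only, not the real BLIS routine.

#include <stdio.h>

/* Hypothetical helper: split nt threads into ic x jc so that the
   ic:jc ratio roughly tracks the micropanel counts mu:nu.  This only
   illustrates the idea, not bli_thread_partition_2x2() itself.      */
static void partition_2x2_sketch( int nt, int mu, int nu, int* ic, int* jc )
{
    int    best_ic  = 1, best_jc = nt;
    double best_err = -1.0;

    for ( int i = 1; i <= nt; ++i )
    {
        if ( nt % i != 0 ) continue;          /* exact factorizations only */
        int    j   = nt / i;
        /* i:j should match mu:nu, i.e. i*nu should be close to j*mu. */
        double err = (double)i * nu - (double)j * mu;
        if ( err < 0.0 ) err = -err;
        if ( best_err < 0.0 || err < best_err )
        {
            best_err = err;
            best_ic  = i;
            best_jc  = j;
        }
    }
    *ic = best_ic;
    *jc = best_jc;
}

int main( void )
{
    /* Example: m = 960, n = 240 with MR = 6, NR = 8 gives mu = 160, nu = 30. */
    int ic, jc;
    partition_2x2_sketch( 16, 960 / 6, 240 / 8, &ic, &jc );
    printf( "ic_nt = %d, jc_nt = %d\n", ic, jc );   /* ic_nt = 8, jc_nt = 2 */
    return 0;
}

With 16 threads the sketch favors the ic (m) loop, since mu is far larger than nu; that is the effect the auto-factorization comments above are after.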
+ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) + return BLIS_SUCCESS; +} + diff --git a/frame/3/gemm/bli_gemm_front.c b/frame/3/gemm/bli_gemm_front.c index a065156bbf..972a7a782f 100644 --- a/frame/3/gemm/bli_gemm_front.c +++ b/frame/3/gemm/bli_gemm_front.c @@ -177,19 +177,6 @@ void bli_gemm_front dim_t m_dim_local = bli_obj_length( &c_local ); dim_t n_dim_local = bli_obj_width( &c_local ); dim_t k_dim_local = bli_obj_width( &a_local ); -#ifdef BLIS_CONFIG_EPYC - // Regression observed in sgemm native path in cases where m >= 4 * n - // after BLIS_THREAD_RATIO_M updated from 2 to 1 as part of commit - // 11dfc176a3c422729f453f6c23204cf023e9954d. Temporary workaround for - // the issue. - if( bli_obj_is_float( &c_local ) && - ( n_dim_local >= 1024 ) && - ( k_dim_local >= 1024 ) && - ( m_dim_local >= ( 4 * n_dim_local ) ) ) - { - m_dim_local *= 2; - } -#endif // Parse and interpret the contents of the rntm_t object to properly // set the ways of parallelism for each loop, and then make any diff --git a/frame/3/gemm/bli_gemm_front_amd.c b/frame/3/gemm/bli_gemm_front_amd.c new file mode 100644 index 0000000000..41af62007c --- /dev/null +++ b/frame/3/gemm/bli_gemm_front_amd.c @@ -0,0 +1,413 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +void bli_gemm_front + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm, + cntl_t* cntl + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_3); + bli_init_once(); + + obj_t a_local; + obj_t b_local; + obj_t c_local; + + // Check parameters. + if ( bli_error_checking_is_enabled() ) + bli_gemm_check( alpha, a, b, beta, c, cntx ); + + // If C has a zero dimension, return early. + if ( bli_obj_has_zero_dim( c ) ) + { + return; + } + + // If alpha is zero, or if A or B has a zero dimension, scale C by beta + // and return early. 
+ if ( bli_obj_equals( alpha, &BLIS_ZERO ) || + bli_obj_has_zero_dim( a ) || + bli_obj_has_zero_dim( b ) ) + { + bli_scalm( beta, c ); + return; + } + +#ifdef BLIS_ENABLE_SMALL_MATRIX + // Only handle small problems separately for homogeneous datatypes. + if ( bli_obj_dt( a ) == bli_obj_dt( b ) && + bli_obj_dt( a ) == bli_obj_dt( c ) && + bli_obj_comp_prec( c ) == bli_obj_prec( c ) ) + { + err_t status = bli_gemm_small( alpha, a, b, beta, c, cntx, cntl ); + + if ( status == BLIS_SUCCESS ) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); + return; + } + } +#endif + + // Alias A, B, and C in case we need to apply transformations. + bli_obj_alias_to( a, &a_local ); + bli_obj_alias_to( b, &b_local ); + bli_obj_alias_to( c, &c_local ); + +#ifdef BLIS_ENABLE_GEMM_MD + cntx_t cntx_local; + + // If any of the storage datatypes differ, or if the computation precision + // differs from the storage precision of C, utilize the mixed datatype + // code path. + // NOTE: If we ever want to support the caller setting the computation + // domain explicitly, we will need to check the computation dt against the + // storage dt of C (instead of the computation precision against the + // storage precision of C). + if ( bli_obj_dt( &c_local ) != bli_obj_dt( &a_local ) || + bli_obj_dt( &c_local ) != bli_obj_dt( &b_local ) || + bli_obj_comp_prec( &c_local ) != bli_obj_prec( &c_local ) ) + { + // Handle mixed datatype cases in bli_gemm_md(), which may modify + // the objects or the context. (If the context is modified, cntx + // is adjusted to point to cntx_local.) + bli_gemm_md( &a_local, &b_local, beta, &c_local, &cntx_local, &cntx ); + } + //else // homogeneous datatypes +#endif + + // Load the pack schemas from the context and embed them into the objects + // for A and B. (Native contexts are initialized with the correct pack + // schemas, as are contexts for 1m, and if necessary bli_gemm_md() would + // have made a copy and modified the schemas, so reading them from the + // context should be a safe bet at this point.) This is a sort of hack for + // communicating the desired pack schemas to bli_gemm_cntl_create() (via + // bli_l3_thread_decorator() and bli_l3_cntl_create_if()). This allows us + // to subsequently access the schemas from the control tree, which + // hopefully reduces some confusion, particularly in bli_packm_init(). + const pack_t schema_a = bli_cntx_schema_a_block( cntx ); + const pack_t schema_b = bli_cntx_schema_b_panel( cntx ); + + bli_obj_set_pack_schema( schema_a, &a_local ); + bli_obj_set_pack_schema( schema_b, &b_local ); + + // Next, we handle the possibility of needing to typecast alpha to the + // computation datatype and/or beta to the storage datatype of C. + + // Attach alpha to B, and in the process typecast alpha to the target + // datatype of the matrix (which in this case is equal to the computation + // datatype). + bli_obj_scalar_attach( BLIS_NO_CONJUGATE, alpha, &b_local ); + + // Attach beta to C, and in the process typecast beta to the target + // datatype of the matrix (which in this case is equal to the storage + // datatype of C). + bli_obj_scalar_attach( BLIS_NO_CONJUGATE, beta, &c_local ); + + // Change the alpha and beta pointers to BLIS_ONE since the values have + // now been typecast and attached to the matrices above. 
+ alpha = &BLIS_ONE; + beta = &BLIS_ONE; + +#ifdef BLIS_ENABLE_GEMM_MD + // Don't perform the following optimization for ccr or crc cases, as + // those cases are sensitive to the ukernel storage preference (ie: + // transposing the operation would break them). + if ( !bli_gemm_md_is_ccr( &a_local, &b_local, &c_local ) && + !bli_gemm_md_is_crc( &a_local, &b_local, &c_local ) ) +#endif + // An optimization: If C is stored by rows and the micro-kernel prefers + // contiguous columns, or if C is stored by columns and the micro-kernel + // prefers contiguous rows, transpose the entire operation to allow the + // micro-kernel to access elements of C in its preferred manner. + if ( bli_cntx_l3_vir_ukr_dislikes_storage_of( &c_local, BLIS_GEMM_UKR, cntx ) ) + { + bli_obj_swap( &a_local, &b_local ); + + bli_obj_induce_trans( &a_local ); + bli_obj_induce_trans( &b_local ); + bli_obj_induce_trans( &c_local ); + + // We must also swap the pack schemas, which were set by bli_gemm_md() + // or the inlined code above. + bli_obj_swap_pack_schemas( &a_local, &b_local ); + } + + dim_t m_dim_local = bli_obj_length( &c_local ); + dim_t n_dim_local = bli_obj_width( &c_local ); + dim_t k_dim_local = bli_obj_width( &a_local ); + + // Regression observed in sgemm native path in cases where m >= 4 * n + // after BLIS_THREAD_RATIO_M updated from 2 to 1 as part of commit + // 11dfc176a3c422729f453f6c23204cf023e9954d. Temporary workaround for + // the issue. + if( bli_obj_is_float( &c_local ) && + ( n_dim_local >= 1024 ) && + ( k_dim_local >= 1024 ) && + ( m_dim_local >= ( 4 * n_dim_local ) ) ) + { + m_dim_local *= 2; + } + + // Parse and interpret the contents of the rntm_t object to properly + // set the ways of parallelism for each loop, and then make any + // additional modifications necessary for the current operation. + bli_rntm_set_ways_for_op + ( + BLIS_GEMM, + BLIS_LEFT, // ignored for gemm/hemm/symm + m_dim_local, + n_dim_local, + k_dim_local, + rntm + ); + + obj_t* cp = &c_local; + obj_t* betap = beta; + +#ifdef BLIS_ENABLE_GEMM_MD +#ifdef BLIS_ENABLE_GEMM_MD_EXTRA_MEM + // If any of the following conditions are met, create a temporary matrix + // conformal to C into which we will accumulate the matrix product: + // - the storage precision of C differs from the computation precision; + // - the domains are mixed as crr; + // - the storage format of C does not match the preferred orientation + // of the ccr or crc cases. + // Then, after the computation is complete, this matrix will be copied + // or accumulated back to C. + const bool is_ccr_mismatch = + ( bli_gemm_md_is_ccr( &a_local, &b_local, &c_local ) && + !bli_obj_is_col_stored( &c_local ) ); + const bool is_crc_mismatch = + ( bli_gemm_md_is_crc( &a_local, &b_local, &c_local ) && + !bli_obj_is_row_stored( &c_local ) ); + + obj_t ct; + bool use_ct = FALSE; + + // FGVZ: Consider adding another guard here that only creates and uses a + // temporary matrix for accumulation if k < c * kc, where c is some small + // constant like 2. And don't forget to use the same conditional for the + // castm() and free() at the end. + if ( + bli_obj_prec( &c_local ) != bli_obj_comp_prec( &c_local ) || + bli_gemm_md_is_crr( &a_local, &b_local, &c_local ) || + is_ccr_mismatch || + is_crc_mismatch + ) + { + use_ct = TRUE; + } + + // If we need a temporary matrix conformal to C for whatever reason, + // we create it and prepare to use it now. 
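The storage-preference optimization above rests on the identity C = alpha*A*B + beta*C being equivalent to C^T = alpha*B^T*A^T + beta*C^T, where each transpose is "induced" by exchanging row and column strides rather than moving data. A small self-contained illustration of an induced transpose over a raw buffer (plain C, independent of the BLIS objects used above):

#include <stdio.h>

int main( void )
{
    /* A 2x3 matrix stored row-major: rs = 3, cs = 1. */
    double c[6] = { 1, 2, 3,
                    4, 5, 6 };
    int m = 2, n = 3, rs = 3, cs = 1;

    /* Inducing a transpose swaps the roles of the dimensions and the
       strides; no element is moved.  The same buffer now reads as the
       3x2 matrix C^T with rs = 1, cs = 3.                             */
    int mt = n, nt = m, rst = cs, cst = rs;

    for ( int i = 0; i < mt; ++i )
    {
        for ( int j = 0; j < nt; ++j )
            printf( "%5.1f", c[ i*rst + j*cst ] );
        printf( "\n" );
    }
    /* Prints:
         1.0  4.0
         2.0  5.0
         3.0  6.0   */
    return 0;
}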
+ if ( use_ct ) + { + const dim_t m = bli_obj_length( &c_local ); + const dim_t n = bli_obj_width( &c_local ); + inc_t rs = bli_obj_row_stride( &c_local ); + inc_t cs = bli_obj_col_stride( &c_local ); + + num_t dt_ct = bli_obj_domain( &c_local ) | + bli_obj_comp_prec( &c_local ); + + // When performing the crr case, accumulate to a contiguously-stored + // real matrix so we do not have to repeatedly update C with general + // stride. + if ( bli_gemm_md_is_crr( &a_local, &b_local, &c_local ) ) + dt_ct = BLIS_REAL | bli_obj_comp_prec( &c_local ); + + // When performing the mismatched ccr or crc cases, now is the time + // to specify the appropriate storage so the gemm_md_c2r_ref() virtual + // microkernel can output directly to C (instead of using a temporary + // microtile). + if ( is_ccr_mismatch ) { rs = 1; cs = m; } + else if ( is_crc_mismatch ) { rs = n; cs = 1; } + + bli_obj_create( dt_ct, m, n, rs, cs, &ct ); + + const num_t dt_exec = bli_obj_exec_dt( &c_local ); + const num_t dt_comp = bli_obj_comp_dt( &c_local ); + + bli_obj_set_target_dt( dt_ct, &ct ); + bli_obj_set_exec_dt( dt_exec, &ct ); + bli_obj_set_comp_dt( dt_comp, &ct ); + + // A naive approach would cast C to the comptuation datatype, + // compute with beta, and then cast the result back to the + // user-provided output matrix. However, we employ a different + // approach that halves the number of memops on C (or its + // typecast temporary) by writing the A*B product directly to + // temporary storage, and then using xpbym to scale the + // output matrix by beta and accumulate/cast the A*B product. + //bli_castm( &c_local, &ct ); + betap = &BLIS_ZERO; + + cp = &ct; + } +#endif +#endif + + // Invoke the internal back-end via the thread handler. + bli_l3_thread_decorator + ( + bli_gemm_int, + BLIS_GEMM, // operation family id + alpha, + &a_local, + &b_local, + betap, + cp, + cntx, + rntm, + cntl + ); + +#ifdef BLIS_ENABLE_GEMM_MD +#ifdef BLIS_ENABLE_GEMM_MD_EXTRA_MEM + // If we created a temporary matrix conformal to C for whatever reason, + // we copy/accumulate the result back to C and then release the object. 
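The accumulate-back step described above writes the A*B product into the temporary ct with beta forced to zero, then applies C := beta*C + ct in a single pass so each element of C is read and written only once. A scalar sketch of that xpbym-style update, with plain arrays standing in for the obj_t matrices (the real bli_xpbym operates on whole objects):

#include <stdio.h>

int main( void )
{
    /* Stand-ins for the output C and the temporary ct = A*B. */
    double c[4]  = {  1.0,  1.0,  1.0,  1.0 };
    double ct[4] = { 10.0, 20.0, 30.0, 40.0 };
    double beta  = 0.5;

    /* xpbym-style update: scale C by beta and accumulate the product,
       touching each C element exactly once.                          */
    for ( int i = 0; i < 4; ++i )
        c[i] = beta * c[i] + ct[i];

    for ( int i = 0; i < 4; ++i )
        printf( "%.1f ", c[i] );   /* 10.5 20.5 30.5 40.5 */
    printf( "\n" );
    return 0;
}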
+ if ( use_ct ) + { + obj_t beta_local; + + bli_obj_scalar_detach( &c_local, &beta_local ); + + //bli_castnzm( &ct, &c_local ); + bli_xpbym( &ct, &beta_local, &c_local ); + + bli_obj_free( &ct ); + } +#endif +#endif + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); +} + +// ----------------------------------------------------------------------------- + +#if 0 + if ( bli_obj_dt( a ) != bli_obj_dt( b ) || + bli_obj_dt( a ) != bli_obj_dt( c ) || + bli_obj_comp_prec( c ) != bli_obj_prec( c ) ) + { + const bool a_is_real = bli_obj_is_real( a ); + const bool a_is_comp = bli_obj_is_complex( a ); + const bool b_is_real = bli_obj_is_real( b ); + const bool b_is_comp = bli_obj_is_complex( b ); + const bool c_is_real = bli_obj_is_real( c ); + const bool c_is_comp = bli_obj_is_complex( c ); + + const bool a_is_single = bli_obj_is_single_prec( a ); + const bool a_is_double = bli_obj_is_double_prec( a ); + const bool b_is_single = bli_obj_is_single_prec( b ); + const bool b_is_double = bli_obj_is_double_prec( b ); + const bool c_is_single = bli_obj_is_single_prec( c ); + const bool c_is_double = bli_obj_is_double_prec( c ); + + const bool comp_single = bli_obj_comp_prec( c ) == BLIS_SINGLE_PREC; + const bool comp_double = bli_obj_comp_prec( c ) == BLIS_DOUBLE_PREC; + + const bool mixeddomain = bli_obj_domain( c ) != bli_obj_domain( a ) || + bli_obj_domain( c ) != bli_obj_domain( b ); + + ( void )a_is_real; ( void )a_is_comp; + ( void )b_is_real; ( void )b_is_comp; + ( void )c_is_real; ( void )c_is_comp; + ( void )a_is_single; ( void )a_is_double; + ( void )b_is_single; ( void )b_is_double; + ( void )c_is_single; ( void )c_is_double; + ( void )comp_single; ( void )comp_double; + + if ( + //( c_is_comp && a_is_comp && b_is_real ) || + //( c_is_comp && a_is_real && b_is_comp ) || + //( c_is_real && a_is_comp && b_is_comp ) || + //( c_is_comp && a_is_real && b_is_real ) || + //( c_is_real && a_is_comp && b_is_real ) || + //( c_is_real && a_is_real && b_is_comp ) || + //FALSE + TRUE + ) + { + if ( + ( c_is_single && a_is_single && b_is_single && mixeddomain ) || + ( c_is_single && a_is_single && b_is_single && comp_single ) || + ( c_is_single && a_is_single && b_is_single && comp_double ) || + ( c_is_single && a_is_single && b_is_double ) || + ( c_is_single && a_is_double && b_is_single ) || + ( c_is_double && a_is_single && b_is_single ) || + ( c_is_single && a_is_double && b_is_double ) || + ( c_is_double && a_is_single && b_is_double ) || + ( c_is_double && a_is_double && b_is_single ) || + ( c_is_double && a_is_double && b_is_double && comp_single ) || + ( c_is_double && a_is_double && b_is_double && comp_double ) || + ( c_is_double && a_is_double && b_is_double && mixeddomain ) || + FALSE + ) + bli_gemm_md_front( alpha, a, b, beta, c, cntx, cntl ); + else + bli_gemm_md_zgemm( alpha, a, b, beta, c, cntx, cntl ); + } + else + bli_gemm_md_zgemm( alpha, a, b, beta, c, cntx, cntl ); + return; + } +#else +#if 0 + // If any of the storage datatypes differ, or if the execution precision + // differs from the storage precision of C, utilize the mixed datatype + // code path. + // NOTE: We could check the exec dt against the storage dt of C, but for + // now we don't support the caller setting the execution domain + // explicitly. 
+ if ( bli_obj_dt( a ) != bli_obj_dt( b ) || + bli_obj_dt( a ) != bli_obj_dt( c ) || + bli_obj_comp_prec( c ) != bli_obj_prec( c ) ) + { + bli_gemm_md_front( alpha, a, b, beta, c, cntx, cntl ); + return; + } +#endif +#endif + diff --git a/frame/base/bli_cpuid.c b/frame/base/bli_cpuid.c index d5d8315543..98ea947f3c 100644 --- a/frame/base/bli_cpuid.c +++ b/frame/base/bli_cpuid.c @@ -501,6 +501,25 @@ bool bli_cpuid_is_bulldozer return TRUE; } +bool bli_cpuid_is_avx_supported( void ) +{ + uint32_t family, model, features; + + // Call the CPUID instruction and parse its results into a family id, + // model id, and a feature bit field. The return value encodes the + // vendor. + bli_cpuid_query( &family, &model, &features ); + + // Check for expected CPU features. + const uint32_t expected = FEATURE_AVX | + FEATURE_FMA3 | + FEATURE_AVX2; + + if ( !bli_cpuid_has_features( features, expected ) ) return FALSE; + + return TRUE; +} + #elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) arch_t bli_cpuid_query_id( void ) diff --git a/frame/base/bli_cpuid.h b/frame/base/bli_cpuid.h index a9f960847a..cb4c45ab5d 100644 --- a/frame/base/bli_cpuid.h +++ b/frame/base/bli_cpuid.h @@ -133,7 +133,7 @@ BLIS_INLINE bool bli_cpuid_has_features( uint32_t have, uint32_t want ) void get_cpu_name( char *cpu_name ); int vpu_count( void ); - +bool bli_cpuid_is_avx_supported(void); enum { @@ -160,6 +160,8 @@ enum FEATURE_AVX512VL = 0x4000 }; + + #elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) char* find_string_in( char* target, char* buffer, size_t buf_len, char* filepath ); diff --git a/frame/compat/bla_amax.c b/frame/compat/bla_amax.c index 214dfe67aa..b1cf77e7b8 100644 --- a/frame/compat/bla_amax.c +++ b/frame/compat/bla_amax.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018-2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -98,217 +98,5 @@ f77_int PASTEF772(i,chx,blasname) \ } #ifdef BLIS_ENABLE_BLAS -#ifdef BLIS_CONFIG_EPYC - -f77_int isamax_ - ( - const f77_int* n, - const float* x, const f77_int* incx - ) -{ - - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); - AOCL_DTL_LOG_AMAX_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'S', *n, *incx); - - dim_t n0; - float* x0; - inc_t incx0; - gint_t bli_index; - f77_int f77_index; - - /* If the vector is empty, return an index of zero. This early check - is needed to emulate netlib BLAS. Without it, bli_?amaxv() will - return 0, which ends up getting incremented to 1 (below) before - being returned, which is not what we want. */ - if ( *n < 1 || *incx <= 0 ) { - AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, "isamax_: vector empty"); - return 0; - } - - /* Initialize BLIS. */ -// bli_init_auto(); - - /* Convert/typecast negative values of n to zero. */ - if ( *n < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(*n); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - if ( *incx < 0 ) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. (Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. 
The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. */ - - x0 = ((float*)x) + (n0-1)*(-*incx); - incx0 = ( inc_t )(*incx); - - } - else - { - x0 = ((float*)x); - incx0 = ( inc_t )(*incx); - } - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN4) || - (id == BLIS_ARCH_ZEN3) || - (id == BLIS_ARCH_ZEN2) || - (id == BLIS_ARCH_ZEN); - - if (bamdzen) - { - /* Call BLIS kernel */ - bli_samaxv_zen_int - ( - n0, - x0, incx0, - &bli_index, - NULL - ); - } - else - { - PASTEMAC2(s,amaxv,BLIS_TAPI_EX_SUF) - ( - n0, - x0, incx0, - &bli_index, - NULL, - NULL - ); - } - - /* Convert zero-based BLIS (C) index to one-based BLAS (Fortran) - index. Also, if the BLAS integer size differs from the BLIS - integer size, that typecast occurs here. */ - f77_index = bli_index + 1; - - /* Finalize BLIS. */ -// bli_finalize_auto(); - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - - return f77_index; -} - -f77_int idamax_ - ( - const f77_int* n, - const double* x, const f77_int* incx - ) -{ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); - AOCL_DTL_LOG_AMAX_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'D', *n, *incx); - - dim_t n0; - double* x0; - inc_t incx0; - gint_t bli_index; - f77_int f77_index; - - /* If the vector is empty, return an index of zero. This early check - is needed to emulate netlib BLAS. Without it, bli_?amaxv() will - return 0, which ends up getting incremented to 1 (below) before - being returned, which is not what we want. */ - if ( *n < 1 || *incx <= 0 ) { - AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, "idamax_: vector empty"); - return 0; - } - - /* Initialize BLIS. */ -// bli_init_auto(); - - /* Convert/typecast negative values of n to zero. */ - if ( *n < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(*n); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - if ( *incx < 0 ) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. (Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. 
*/ - - x0 = ((double*)x) + (n0-1)*(-*incx); - incx0 = ( inc_t )(*incx); - - } - else - { - x0 = ((double*)x); - incx0 = ( inc_t )(*incx); - } - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN4) || - (id == BLIS_ARCH_ZEN3) || - (id == BLIS_ARCH_ZEN2) || - (id == BLIS_ARCH_ZEN); - - if (bamdzen) - { - /* Call BLIS kernel */ - bli_damaxv_zen_int - ( - n0, - x0, incx0, - &bli_index, - NULL - ); - } - else - { - PASTEMAC2(d,amaxv,BLIS_TAPI_EX_SUF) - ( - n0, - x0, incx0, - &bli_index, - NULL, - NULL - ); - } - - /* Convert zero-based BLIS (C) index to one-based BLAS (Fortran) - index. Also, if the BLAS integer size differs from the BLIS - integer size, that typecast occurs here. */ - f77_index = bli_index + 1; - - /* Finalize BLIS. */ -// bli_finalize_auto(); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return f77_index; -} - -INSERT_GENTFUNC_BLAS_CZ( amax, amaxv ) -#else INSERT_GENTFUNC_BLAS( amax, amaxv ) #endif -#endif diff --git a/frame/compat/bla_amax_amd.c b/frame/compat/bla_amax_amd.c new file mode 100644 index 0000000000..7f1a771f7c --- /dev/null +++ b/frame/compat/bla_amax_amd.c @@ -0,0 +1,295 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +// +// Define BLAS-to-BLIS interfaces. 
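The bli_cpuid_is_avx_supported() routine added in the bli_cpuid.c hunk above requires FEATURE_AVX, FEATURE_FMA3 and FEATURE_AVX2 to all be present in the detected feature mask before the hand-written AVX kernels are taken. A small sketch of that kind of mask test, assuming the usual (have & want) == want semantics for a check like bli_cpuid_has_features() and using made-up bit values:

#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

/* Illustrative feature bits; the real values live in bli_cpuid.h. */
#define FEAT_AVX  0x0100u
#define FEAT_FMA3 0x0200u
#define FEAT_AVX2 0x0400u

/* Presumed shape of the mask test: every wanted bit must be set. */
static bool has_features( uint32_t have, uint32_t want )
{
    return ( have & want ) == want;
}

int main( void )
{
    uint32_t detected = FEAT_AVX | FEAT_FMA3;             /* no AVX2 yet */
    uint32_t expected = FEAT_AVX | FEAT_FMA3 | FEAT_AVX2;

    printf( "avx path ok: %d\n", has_features( detected, expected ) ); /* 0 */
    detected |= FEAT_AVX2;
    printf( "avx path ok: %d\n", has_features( detected, expected ) ); /* 1 */
    return 0;
}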
+// +#undef GENTFUNC +#define GENTFUNC( ftype_x, chx, blasname, blisname ) \ +\ +f77_int PASTEF772(i,chx,blasname) \ + ( \ + const f77_int* n, \ + const ftype_x* x, const f77_int* incx \ + ) \ +{ \ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) \ + AOCL_DTL_LOG_AMAX_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(chx), *n, *incx) \ +\ + dim_t n0; \ + ftype_x* x0; \ + inc_t incx0; \ + gint_t bli_index; \ + f77_int f77_index; \ +\ + /* If the vector is empty, return an index of zero. This early check + is needed to emulate netlib BLAS. Without it, bli_?amaxv() will + return 0, which ends up getting incremented to 1 (below) before + being returned, which is not what we want. */ \ + if ( *n < 1 || *incx <= 0 ) { \ + AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, "iamax_: vector empty") \ + return 0; \ + }\ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + /* Convert/typecast negative values of n to zero. */ \ + bli_convert_blas_dim1( *n, n0 ); \ +\ + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ \ + bli_convert_blas_incv( n0, (ftype_x*)x, *incx, x0, incx0 ); \ +\ + /* Call BLIS interface. */ \ + PASTEMAC2(chx,blisname,BLIS_TAPI_EX_SUF) \ + ( \ + n0, \ + x0, incx0, \ + &bli_index, \ + NULL, \ + NULL \ + ); \ +\ + /* Convert zero-based BLIS (C) index to one-based BLAS (Fortran) + index. Also, if the BLAS integer size differs from the BLIS + integer size, that typecast occurs here. */ \ + f77_index = bli_index + 1; \ +\ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ +\ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ + return f77_index; \ +} + +#ifdef BLIS_ENABLE_BLAS + +f77_int isamax_ + ( + const f77_int* n, + const float* x, const f77_int* incx + ) +{ + + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); + AOCL_DTL_LOG_AMAX_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'S', *n, *incx); + + dim_t n0; + float* x0; + inc_t incx0; + gint_t bli_index; + f77_int f77_index; + + /* If the vector is empty, return an index of zero. This early check + is needed to emulate netlib BLAS. Without it, bli_?amaxv() will + return 0, which ends up getting incremented to 1 (below) before + being returned, which is not what we want. */ + if ( *n < 1 || *incx <= 0 ) { + AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, "isamax_: vector empty"); + return 0; + } + + /* Initialize BLIS. */ +// bli_init_auto(); + + /* Convert/typecast negative values of n to zero. */ + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + if ( *incx < 0 ) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. 
*/ + + x0 = ((float*)x) + (n0-1)*(-*incx); + incx0 = ( inc_t )(*incx); + + } + else + { + x0 = ((float*)x); + incx0 = ( inc_t )(*incx); + } + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) + { + /* Call BLIS kernel */ + bli_samaxv_zen_int + ( + n0, + x0, incx0, + &bli_index, + NULL + ); + } + else + { + PASTEMAC2(s,amaxv,BLIS_TAPI_EX_SUF) + ( + n0, + x0, incx0, + &bli_index, + NULL, + NULL + ); + } + + /* Convert zero-based BLIS (C) index to one-based BLAS (Fortran) + index. Also, if the BLAS integer size differs from the BLIS + integer size, that typecast occurs here. */ + f77_index = bli_index + 1; + + /* Finalize BLIS. */ +// bli_finalize_auto(); + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + + return f77_index; +} + +f77_int idamax_ + ( + const f77_int* n, + const double* x, const f77_int* incx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); + AOCL_DTL_LOG_AMAX_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'D', *n, *incx); + + dim_t n0; + double* x0; + inc_t incx0; + gint_t bli_index; + f77_int f77_index; + + /* If the vector is empty, return an index of zero. This early check + is needed to emulate netlib BLAS. Without it, bli_?amaxv() will + return 0, which ends up getting incremented to 1 (below) before + being returned, which is not what we want. */ + if ( *n < 1 || *incx <= 0 ) { + AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, "idamax_: vector empty"); + return 0; + } + + /* Initialize BLIS. */ +// bli_init_auto(); + + /* Convert/typecast negative values of n to zero. */ + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + if ( *incx < 0 ) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. */ + + x0 = ((double*)x) + (n0-1)*(-*incx); + incx0 = ( inc_t )(*incx); + + } + else + { + x0 = ((double*)x); + incx0 = ( inc_t )(*incx); + } + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) + { + /* Call BLIS kernel */ + bli_damaxv_zen_int + ( + n0, + x0, incx0, + &bli_index, + NULL + ); + } + else + { + PASTEMAC2(d,amaxv,BLIS_TAPI_EX_SUF) + ( + n0, + x0, incx0, + &bli_index, + NULL, + NULL + ); + } + + /* Convert zero-based BLIS (C) index to one-based BLAS (Fortran) + index. Also, if the BLAS integer size differs from the BLIS + integer size, that typecast occurs here. */ + f77_index = bli_index + 1; + + /* Finalize BLIS. 
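The negative-increment handling repeated in these wrappers rebases the pointer with x0 = x + (n0-1)*(-incx), so that BLAS's convention of always passing the first element maps onto BLIS's convention of a base pointer plus a possibly negative stride. A tiny worked instance of that rebasing, as a standalone illustration rather than wrapper code:

#include <stdio.h>

int main( void )
{
    double x[4] = { 10.0, 20.0, 30.0, 40.0 };
    int    n    = 4;
    int    incx = -1;           /* BLAS-style negative stride */

    /* BLAS callers pass &x[0] even for negative incx; BLIS expects the
       address of the last touched element plus the negative stride, so
       the wrapper rebases the pointer: x0 = x + (n-1)*(-incx).         */
    double* x0    = x + ( n - 1 ) * ( -incx );
    int     incx0 = incx;

    for ( int i = 0; i < n; ++i )
        printf( "%.1f ", x0[ i * incx0 ] );   /* 40.0 30.0 20.0 10.0 */
    printf( "\n" );
    return 0;
}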
*/ +// bli_finalize_auto(); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return f77_index; +} + +INSERT_GENTFUNC_BLAS_CZ( amax, amaxv ) + +#endif diff --git a/frame/compat/bla_axpy.c b/frame/compat/bla_axpy.c index 93f30e1e55..1a30f417b3 100644 --- a/frame/compat/bla_axpy.c +++ b/frame/compat/bla_axpy.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 21, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 22, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -87,411 +87,6 @@ void PASTEF77(ch,blasname) \ #ifdef BLIS_ENABLE_BLAS -#ifdef BLIS_CONFIG_EPYC -void saxpy_ -( - const f77_int* n, - const float* alpha, - const float* x, const f77_int* incx, - float* y, const f77_int* incy - ) -{ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) - AOCL_DTL_LOG_AXPY_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'S', *n, (float*)alpha, *incx, *incy) - dim_t n0; - float* x0; - float* y0; - inc_t incx0; - inc_t incy0; - - /* Initialize BLIS. */ - // bli_init_auto(); - - /* Convert/typecast negative values of n to zero. */ - if ( *n < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(*n); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - if ( *incx < 0 ) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. (Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. */ - x0 = ((float*)x) + (n0-1)*(-*incx); - incx0 = ( inc_t )(*incx); - } - else - { - x0 = ((float*)x); - incx0 = ( inc_t )(*incx); - } - if ( *incy < 0 ) - { - y0 = ((float*)y) + (n0-1)*(-*incy); - incy0 = ( inc_t )(*incy); - } - else - { - y0 = ((float*)y); - incy0 = ( inc_t )(*incy); - } - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN4) || - (id == BLIS_ARCH_ZEN3) || - (id == BLIS_ARCH_ZEN2) || - (id == BLIS_ARCH_ZEN); - - if (bamdzen) - { - bli_saxpyv_zen_int10 - ( - BLIS_NO_CONJUGATE, - n0, - (float*)alpha, - x0, incx0, - y0, incy0, - NULL - ); - - } - else - { - PASTEMAC2(s,axpyv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - n0, - (float*)alpha, - x0, incx0, - y0, incy0, - NULL, - NULL - ); - - } - /* Finalize BLIS. 
*/ - // bli_finalize_auto(); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); -} - -void daxpy_ -( - const f77_int* n, - const double* alpha, - const double* x, const f77_int* incx, - double* y, const f77_int* incy - ) -{ - dim_t n0; - double* x0; - double* y0; - inc_t incx0; - inc_t incy0; - - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) - AOCL_DTL_LOG_AXPY_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'D', *n, (double*)alpha, *incx, *incy) - /* Initialize BLIS. */ - // bli_init_auto(); - - /* Convert/typecast negative values of n to zero. */ - if ( *n < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(*n); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - if ( *incx < 0 ) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. (Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. */ - x0 = ((double*)x) + (n0-1)*(-*incx); - incx0 = ( inc_t )(*incx); - } - else - { - x0 = ((double*)x); - incx0 = ( inc_t )(*incx); - } - if ( *incy < 0 ) - { - y0 = ((double*)y) + (n0-1)*(-*incy); - incy0 = ( inc_t )(*incy); - } - else - { - y0 = ((double*)y); - incy0 = ( inc_t )(*incy); - } - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN4) || - (id == BLIS_ARCH_ZEN3) || - (id == BLIS_ARCH_ZEN2) || - (id == BLIS_ARCH_ZEN); - - if (bamdzen) - { - bli_daxpyv_zen_int10 - ( - BLIS_NO_CONJUGATE, - n0, - (double*)alpha, - x0, incx0, - y0, incy0, - NULL - ); - - } - else - { - PASTEMAC2(d,axpyv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - n0, - (double*)alpha, - x0, incx0, - y0, incy0, - NULL, - NULL - ); - - } - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - /* Finalize BLIS. */ - // bli_finalize_auto(); -} - -void caxpy_ -( - const f77_int* n, - const scomplex* alpha, - const scomplex* x, const f77_int* incx, - scomplex* y, const f77_int* incy - ) -{ - dim_t n0; - scomplex* x0; - scomplex* y0; - inc_t incx0; - inc_t incy0; - - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) - AOCL_DTL_LOG_AXPY_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'C', *n, (scomplex*)alpha, *incx, *incy) - - /* Initialize BLIS. */ - // bli_init_auto(); - /* Convert/typecast negative values of n to zero. */ - if ( *n < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(*n); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - if ( *incx < 0 ) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. 
(Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. */ - x0 = ((scomplex*)x) + (n0-1)*(-*incx); - incx0 = ( inc_t )(*incx); - } - else - { - x0 = ((scomplex*)x); - incx0 = ( inc_t )(*incx); - } - if ( *incy < 0 ) - { - y0 = ((scomplex*)y) + (n0-1)*(-*incy); - incy0 = ( inc_t )(*incy); - } - else - { - y0 = ((scomplex*)y); - incy0 = ( inc_t )(*incy); - } - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN4) || - (id == BLIS_ARCH_ZEN3) || - (id == BLIS_ARCH_ZEN2) || - (id == BLIS_ARCH_ZEN); - - if (bamdzen) - { - bli_caxpyv_zen_int5 - ( - BLIS_NO_CONJUGATE, - n0, - (scomplex*)alpha, - x0, incx0, - y0, incy0, - NULL - ); - - } - else - { - PASTEMAC2(c,axpyv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - n0, - (scomplex*)alpha, - x0, incx0, - y0, incy0, - NULL, - NULL - ); - } - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - /* Finalize BLIS. */ - // bli_finalize_auto(); -} - -void zaxpy_ -( - const f77_int* n, - const dcomplex* alpha, - const dcomplex* x, const f77_int* incx, - dcomplex* y, const f77_int* incy - ) -{ - dim_t n0; - dcomplex* x0; - dcomplex* y0; - inc_t incx0; - inc_t incy0; - - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) - AOCL_DTL_LOG_AXPY_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'Z', *n, (dcomplex*)alpha, *incx, *incy) - - /* Initialize BLIS. */ - // bli_init_auto(); - - /* Convert/typecast negative values of n to zero. */ - if ( *n < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(*n); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - if ( *incx < 0 ) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. (Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. 
*/ - x0 = ((dcomplex*)x) + (n0-1)*(-*incx); - incx0 = ( inc_t )(*incx); - } - else - { - x0 = ((dcomplex*)x); - incx0 = ( inc_t )(*incx); - } - if ( *incy < 0 ) - { - y0 = ((dcomplex*)y) + (n0-1)*(-*incy); - incy0 = ( inc_t )(*incy); - } - else - { - y0 = ((dcomplex*)y); - incy0 = ( inc_t )(*incy); - } - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN4) || - (id == BLIS_ARCH_ZEN3) || - (id == BLIS_ARCH_ZEN2) || - (id == BLIS_ARCH_ZEN); - - if (bamdzen) - { - bli_zaxpyv_zen_int5 - ( - BLIS_NO_CONJUGATE, - n0, - (dcomplex*)alpha, - x0, incx0, - y0, incy0, - NULL - ); - - } - else - { - PASTEMAC2(z,axpyv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - n0, - (dcomplex*)alpha, - x0, incx0, - y0, incy0, - NULL, - NULL - ); - } - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - /* Finalize BLIS. */ - // bli_finalize_auto(); -} - -#else INSERT_GENTFUNC_BLAS( axpy, axpyv ) -#endif #endif diff --git a/frame/compat/bla_axpy_amd.c b/frame/compat/bla_axpy_amd.c new file mode 100644 index 0000000000..8a9f0280c6 --- /dev/null +++ b/frame/compat/bla_axpy_amd.c @@ -0,0 +1,462 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020 - 22, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + + +// +// Define BLAS-to-BLIS interfaces. 
+// +#undef GENTFUNC +#define GENTFUNC( ftype, ch, blasname, blisname ) \ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_int* n, \ + const ftype* alpha, \ + const ftype* x, const f77_int* incx, \ + ftype* y, const f77_int* incy \ + ) \ +{ \ + dim_t n0; \ + ftype* x0; \ + ftype* y0; \ + inc_t incx0; \ + inc_t incy0; \ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) \ + AOCL_DTL_LOG_AXPY_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(ch), *n, (void*)alpha, *incx, *incy) \ + /* Convert/typecast negative values of n to zero. */ \ + bli_convert_blas_dim1( *n, n0 ); \ +\ + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ \ + bli_convert_blas_incv( n0, (ftype*)x, *incx, x0, incx0 ); \ + bli_convert_blas_incv( n0, (ftype*)y, *incy, y0, incy0 ); \ +\ + /* Call BLIS interface. */ \ + PASTEMAC2(ch,blisname,BLIS_TAPI_EX_SUF) \ + ( \ + BLIS_NO_CONJUGATE, \ + n0, \ + (ftype*)alpha, \ + x0, incx0, \ + y0, incy0, \ + NULL, \ + NULL \ + ); \ +\ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ +} + +#ifdef BLIS_ENABLE_BLAS + +void saxpy_ +( + const f77_int* n, + const float* alpha, + const float* x, const f77_int* incx, + float* y, const f77_int* incy + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) + AOCL_DTL_LOG_AXPY_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'S', *n, (float*)alpha, *incx, *incy) + dim_t n0; + float* x0; + float* y0; + inc_t incx0; + inc_t incy0; + + /* Initialize BLIS. */ + // bli_init_auto(); + + /* Convert/typecast negative values of n to zero. */ + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + if ( *incx < 0 ) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. */ + x0 = ((float*)x) + (n0-1)*(-*incx); + incx0 = ( inc_t )(*incx); + } + else + { + x0 = ((float*)x); + incx0 = ( inc_t )(*incx); + } + if ( *incy < 0 ) + { + y0 = ((float*)y) + (n0-1)*(-*incy); + incy0 = ( inc_t )(*incy); + } + else + { + y0 = ((float*)y); + incy0 = ( inc_t )(*incy); + } + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) + { + bli_saxpyv_zen_int10 + ( + BLIS_NO_CONJUGATE, + n0, + (float*)alpha, + x0, incx0, + y0, incy0, + NULL + ); + + } + else + { + PASTEMAC2(s,axpyv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + n0, + (float*)alpha, + x0, incx0, + y0, incy0, + NULL, + NULL + ); + + } + /* Finalize BLIS. 
*/ + // bli_finalize_auto(); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); +} + +void daxpy_ +( + const f77_int* n, + const double* alpha, + const double* x, const f77_int* incx, + double* y, const f77_int* incy + ) +{ + dim_t n0; + double* x0; + double* y0; + inc_t incx0; + inc_t incy0; + + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) + AOCL_DTL_LOG_AXPY_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'D', *n, (double*)alpha, *incx, *incy) + /* Initialize BLIS. */ + // bli_init_auto(); + + /* Convert/typecast negative values of n to zero. */ + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + if ( *incx < 0 ) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. */ + x0 = ((double*)x) + (n0-1)*(-*incx); + incx0 = ( inc_t )(*incx); + } + else + { + x0 = ((double*)x); + incx0 = ( inc_t )(*incx); + } + if ( *incy < 0 ) + { + y0 = ((double*)y) + (n0-1)*(-*incy); + incy0 = ( inc_t )(*incy); + } + else + { + y0 = ((double*)y); + incy0 = ( inc_t )(*incy); + } + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) + { + bli_daxpyv_zen_int10 + ( + BLIS_NO_CONJUGATE, + n0, + (double*)alpha, + x0, incx0, + y0, incy0, + NULL + ); + + } + else + { + PASTEMAC2(d,axpyv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + n0, + (double*)alpha, + x0, incx0, + y0, incy0, + NULL, + NULL + ); + + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS. */ + // bli_finalize_auto(); +} + +void caxpy_ +( + const f77_int* n, + const scomplex* alpha, + const scomplex* x, const f77_int* incx, + scomplex* y, const f77_int* incy + ) +{ + dim_t n0; + scomplex* x0; + scomplex* y0; + inc_t incx0; + inc_t incy0; + + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) + AOCL_DTL_LOG_AXPY_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'C', *n, (scomplex*)alpha, *incx, *incy) + + /* Initialize BLIS. */ + // bli_init_auto(); + /* Convert/typecast negative values of n to zero. */ + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + if ( *incx < 0 ) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. 
By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. */ + x0 = ((scomplex*)x) + (n0-1)*(-*incx); + incx0 = ( inc_t )(*incx); + } + else + { + x0 = ((scomplex*)x); + incx0 = ( inc_t )(*incx); + } + if ( *incy < 0 ) + { + y0 = ((scomplex*)y) + (n0-1)*(-*incy); + incy0 = ( inc_t )(*incy); + } + else + { + y0 = ((scomplex*)y); + incy0 = ( inc_t )(*incy); + } + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) + { + bli_caxpyv_zen_int5 + ( + BLIS_NO_CONJUGATE, + n0, + (scomplex*)alpha, + x0, incx0, + y0, incy0, + NULL + ); + + } + else + { + PASTEMAC2(c,axpyv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + n0, + (scomplex*)alpha, + x0, incx0, + y0, incy0, + NULL, + NULL + ); + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS. */ + // bli_finalize_auto(); +} + +void zaxpy_ +( + const f77_int* n, + const dcomplex* alpha, + const dcomplex* x, const f77_int* incx, + dcomplex* y, const f77_int* incy + ) +{ + dim_t n0; + dcomplex* x0; + dcomplex* y0; + inc_t incx0; + inc_t incy0; + + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) + AOCL_DTL_LOG_AXPY_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'Z', *n, (dcomplex*)alpha, *incx, *incy) + + /* Initialize BLIS. */ + // bli_init_auto(); + + /* Convert/typecast negative values of n to zero. */ + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + if ( *incx < 0 ) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. */ + x0 = ((dcomplex*)x) + (n0-1)*(-*incx); + incx0 = ( inc_t )(*incx); + } + else + { + x0 = ((dcomplex*)x); + incx0 = ( inc_t )(*incx); + } + if ( *incy < 0 ) + { + y0 = ((dcomplex*)y) + (n0-1)*(-*incy); + incy0 = ( inc_t )(*incy); + } + else + { + y0 = ((dcomplex*)y); + incy0 = ( inc_t )(*incy); + } + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) + { + bli_zaxpyv_zen_int5 + ( + BLIS_NO_CONJUGATE, + n0, + (dcomplex*)alpha, + x0, incx0, + y0, incy0, + NULL + ); + + } + else + { + PASTEMAC2(z,axpyv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + n0, + (dcomplex*)alpha, + x0, incx0, + y0, incy0, + NULL, + NULL + ); + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS. 
*/ + // bli_finalize_auto(); +} + + + +#endif diff --git a/frame/compat/bla_copy.c b/frame/compat/bla_copy.c index f4aa3ee83b..74baba689c 100644 --- a/frame/compat/bla_copy.c +++ b/frame/compat/bla_copy.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -88,217 +88,5 @@ void PASTEF77(ch,blasname) \ } #ifdef BLIS_ENABLE_BLAS -#ifdef BLIS_CONFIG_EPYC - -void scopy_ -( - const f77_int* n, - const float* x, const f77_int* incx, - float* y, const f77_int* incy -) -{ - dim_t n0; - float* x0; - float* y0; - inc_t incx0; - inc_t incy0; - - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) - AOCL_DTL_LOG_COPY_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'S', *n, *incx, *incy) - /* Initialize BLIS. */ -// bli_init_auto(); - - /* Convert/typecast negative values of n to zero. */ - if (*n < 0) - n0 = (dim_t)0; - else - n0 = (dim_t)(*n); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - if (*incx < 0) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. (Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. */ - - x0 = (float*)((x)+(n0 - 1)*(-*incx)); - incx0 = (inc_t)(*incx); - - } - else - { - x0 = (float*)(x); - incx0 = (inc_t)(*incx); - } - - if (*incy < 0) - { - y0 = (y)+(n0 - 1)*(-*incy); - incy0 = (inc_t)(*incy); - - } - else - { - y0 = (y); - incy0 = (inc_t)(*incy); - } - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN4) || - (id == BLIS_ARCH_ZEN3) || - (id == BLIS_ARCH_ZEN2) || - (id == BLIS_ARCH_ZEN); - - if (bamdzen) - { - /* Call BLIS kernel */ - bli_scopyv_zen_int - ( - BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - NULL - ); - } - else - { - PASTEMAC2(s, copyv, BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - NULL, - NULL - ); - } - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) - /* Finalize BLIS. */ -// bli_finalize_auto(); -} - -void dcopy_ -( - const f77_int* n, - const double* x, const f77_int* incx, - double* y, const f77_int* incy -) -{ - dim_t n0; - double* x0; - double* y0; - inc_t incx0; - inc_t incy0; - - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); - AOCL_DTL_LOG_COPY_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'D', *n, *incx, *incy) - /* Initialize BLIS. 
*/ -// bli_init_auto(); - - /* Convert/typecast negative values of n to zero. */ - if (*n < 0) - n0 = (dim_t)0; - else - n0 = (dim_t)(*n); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - if (*incx < 0) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. (Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. */ - - x0 = (double*)((x)+(n0 - 1)*(-*incx)); - incx0 = (inc_t)(*incx); - - } - else - { - x0 = (double*)(x); - incx0 = (inc_t)(*incx); - } - - if (*incy < 0) - { - y0 = (y)+(n0 - 1)*(-*incy); - incy0 = (inc_t)(*incy); - - } - else - { - y0 = (y); - incy0 = (inc_t)(*incy); - } - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN4) || - (id == BLIS_ARCH_ZEN3) || - (id == BLIS_ARCH_ZEN2) || - (id == BLIS_ARCH_ZEN); - - if (bamdzen) - { - /* Call BLIS kernel */ - bli_dcopyv_zen_int - ( - BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - NULL - ); - } - else - { - PASTEMAC2(d, copyv, BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - NULL, - NULL - ); - } - - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) - /* Finalize BLIS. */ -// bli_finalize_auto(); -} - -INSERT_GENTFUNC_BLAS_CZ(copy, copyv) -#else INSERT_GENTFUNC_BLAS(copy, copyv) #endif -#endif diff --git a/frame/compat/bla_copy_amd.c b/frame/compat/bla_copy_amd.c new file mode 100644 index 0000000000..8dc4d5287c --- /dev/null +++ b/frame/compat/bla_copy_amd.c @@ -0,0 +1,285 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. 
+ + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + + +// +// Define BLAS-to-BLIS interfaces. +// +#undef GENTFUNC +#define GENTFUNC( ftype, ch, blasname, blisname ) \ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_int* n, \ + const ftype* x, const f77_int* incx, \ + ftype* y, const f77_int* incy \ + ) \ +{ \ + dim_t n0; \ + ftype* x0; \ + ftype* y0; \ + inc_t incx0; \ + inc_t incy0; \ +\ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); \ + AOCL_DTL_LOG_COPY_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(ch), *n, *incx, *incy) \ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + /* Convert/typecast negative values of n to zero. */ \ + bli_convert_blas_dim1( *n, n0 ); \ +\ + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ \ + bli_convert_blas_incv(n0, (ftype*)x, *incx, x0, incx0); \ + bli_convert_blas_incv(n0, (ftype*)y, *incy, y0, incy0); \ + \ + /* Call BLIS interface. */ \ + PASTEMAC2(ch, blisname, BLIS_TAPI_EX_SUF) \ + (\ + BLIS_NO_CONJUGATE, \ + n0, \ + x0, incx0, \ + y0, incy0, \ + NULL, \ + NULL \ + ); \ + \ +\ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ +\ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ +} + +#ifdef BLIS_ENABLE_BLAS + +void scopy_ +( + const f77_int* n, + const float* x, const f77_int* incx, + float* y, const f77_int* incy +) +{ + dim_t n0; + float* x0; + float* y0; + inc_t incx0; + inc_t incy0; + + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) + AOCL_DTL_LOG_COPY_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'S', *n, *incx, *incy) + /* Initialize BLIS. */ +// bli_init_auto(); + + /* Convert/typecast negative values of n to zero. */ + if (*n < 0) + n0 = (dim_t)0; + else + n0 = (dim_t)(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + if (*incx < 0) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. 
*/ + + x0 = (float*)((x)+(n0 - 1)*(-*incx)); + incx0 = (inc_t)(*incx); + + } + else + { + x0 = (float*)(x); + incx0 = (inc_t)(*incx); + } + + if (*incy < 0) + { + y0 = (y)+(n0 - 1)*(-*incy); + incy0 = (inc_t)(*incy); + + } + else + { + y0 = (y); + incy0 = (inc_t)(*incy); + } + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) + { + /* Call BLIS kernel */ + bli_scopyv_zen_int + ( + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + NULL + ); + } + else + { + PASTEMAC2(s, copyv, BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + NULL, + NULL + ); + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) + /* Finalize BLIS. */ +// bli_finalize_auto(); +} + +void dcopy_ +( + const f77_int* n, + const double* x, const f77_int* incx, + double* y, const f77_int* incy +) +{ + dim_t n0; + double* x0; + double* y0; + inc_t incx0; + inc_t incy0; + + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); + AOCL_DTL_LOG_COPY_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'D', *n, *incx, *incy) + /* Initialize BLIS. */ +// bli_init_auto(); + + /* Convert/typecast negative values of n to zero. */ + if (*n < 0) + n0 = (dim_t)0; + else + n0 = (dim_t)(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + if (*incx < 0) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. */ + + x0 = (double*)((x)+(n0 - 1)*(-*incx)); + incx0 = (inc_t)(*incx); + + } + else + { + x0 = (double*)(x); + incx0 = (inc_t)(*incx); + } + + if (*incy < 0) + { + y0 = (y)+(n0 - 1)*(-*incy); + incy0 = (inc_t)(*incy); + + } + else + { + y0 = (y); + incy0 = (inc_t)(*incy); + } + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) + { + /* Call BLIS kernel */ + bli_dcopyv_zen_int + ( + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + NULL + ); + } + else + { + PASTEMAC2(d, copyv, BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + NULL, + NULL + ); + } + + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) + /* Finalize BLIS. */ +// bli_finalize_auto(); +} + +INSERT_GENTFUNC_BLAS_CZ(copy, copyv) + +#endif diff --git a/frame/compat/bla_dot.c b/frame/compat/bla_dot.c index 419f8c7dce..3c4d8c538f 100644 --- a/frame/compat/bla_dot.c +++ b/frame/compat/bla_dot.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018-2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -90,681 +90,11 @@ ftype PASTEF772(ch,blasname,chc) \ } #ifdef BLIS_ENABLE_BLAS -#ifdef BLIS_CONFIG_EPYC -float sdot_ - ( - const f77_int* n, - const float* x, const f77_int* incx, - const float* y, const f77_int* incy - ) -{ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); - AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'S', *n, *incx, *incy); - dim_t n0; - float* x0; - float* y0; - inc_t incx0; - inc_t incy0; - float rho; - - /* Initialize BLIS. */ -// bli_init_auto(); - - /* Convert/typecast negative values of n to zero. */ - if ( *n < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(*n); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - - if ( *incx < 0 ) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. (Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. */ - - x0 = ((float*)x) + (n0-1)*(-*incx); - incx0 = ( inc_t )(*incx); - - } - else - { - x0 = ((float*)x); - incx0 = ( inc_t )(*incx); - } - - if ( *incy < 0 ) - { - y0 = ((float*)y) + (n0-1)*(-*incy); - incy0 = ( inc_t )(*incy); - - } - else - { - y0 = ((float*)y); - incy0 = ( inc_t )(*incy); - } - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN4) || - (id == BLIS_ARCH_ZEN3) || - (id == BLIS_ARCH_ZEN2) || - (id == BLIS_ARCH_ZEN); - - if (bamdzen) - { - /* Call BLIS kernel. */ - bli_sdotv_zen_int10 - ( - BLIS_NO_CONJUGATE, - BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - &rho, - NULL - ); - } - else - { - /* Call BLIS interface. */ - PASTEMAC2(s,dotv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - &rho, - NULL, - NULL - ); - } - - /* Finalize BLIS. */ -// bli_finalize_auto(); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return rho; -} - -double ddot_ - ( - const f77_int* n, - const double* x, const f77_int* incx, - const double* y, const f77_int* incy - ) -{ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); - AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'D', *n, *incx, *incy); - dim_t n0; - double* x0; - double* y0; - inc_t incx0; - inc_t incy0; - double rho; - - /* Initialize BLIS. */ -// bli_init_auto(); - - /* Convert/typecast negative values of n to zero. */ - if ( *n < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(*n); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. 
*/ - - if ( *incx < 0 ) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. (Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. */ - - x0 = ((double*)x) + (n0-1)*(-*incx); - incx0 = ( inc_t )(*incx); - - } - else - { - x0 = ((double*)x); - incx0 = ( inc_t )(*incx); - } - - if ( *incy < 0 ) - { - y0 = ((double*)y) + (n0-1)*(-*incy); - incy0 = ( inc_t )(*incy); - - } - else - { - y0 = ((double*)y); - incy0 = ( inc_t )(*incy); - } - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN4) || - (id == BLIS_ARCH_ZEN3) || - (id == BLIS_ARCH_ZEN2) || - (id == BLIS_ARCH_ZEN); - - if (bamdzen) - { - /* Call BLIS kernel. */ - bli_ddotv_zen_int10 - ( - BLIS_NO_CONJUGATE, - BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - &rho, - NULL - ); - } - else - { - /* Call BLIS interface. */ - PASTEMAC2(d,dotv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - &rho, - NULL, - NULL - ); - } - - /* Finalize BLIS. */ -// bli_finalize_auto(); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return rho; -} -#else INSERT_GENTFUNCDOTR_BLAS( dot, dotv ) -#endif #ifdef BLIS_ENABLE_BLAS #ifdef BLIS_DISABLE_COMPLEX_RETURN_INTEL -#ifdef BLIS_CONFIG_EPYC -scomplex cdotu_ - ( - const f77_int* n, - const scomplex* x, const f77_int* incx, - const scomplex* y, const f77_int* incy - ) -{ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); - AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'C', *n, *incx, *incy); - dim_t n0; - scomplex* x0; - scomplex* y0; - inc_t incx0; - inc_t incy0; - scomplex rho; - - /* Initialize BLIS. */ -// bli_init_auto(); - - /* Convert/typecast negative values of n to zero. */ - if ( *n < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(*n); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - - if ( *incx < 0 ) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. (Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. 
Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. */ - - x0 = ((scomplex*)x) + (n0-1)*(-*incx); - incx0 = ( inc_t )(*incx); - - } - else - { - x0 = ((scomplex*)x); - incx0 = ( inc_t )(*incx); - } - - if ( *incy < 0 ) - { - y0 = ((scomplex*)y) + (n0-1)*(-*incy); - incy0 = ( inc_t )(*incy); - - } - else - { - y0 = ((scomplex*)y); - incy0 = ( inc_t )(*incy); - } - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN4) || - (id == BLIS_ARCH_ZEN3) || - (id == BLIS_ARCH_ZEN2) || - (id == BLIS_ARCH_ZEN); - - if (bamdzen) - { - /* Call BLIS kernel. */ - bli_cdotv_zen_int5 - ( - BLIS_NO_CONJUGATE, - BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - &rho, - NULL - ); - } - else - { - /* Call BLIS interface. */ - PASTEMAC2(c,dotv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - &rho, - NULL, - NULL - ); - } - - /* Finalize BLIS. */ -// bli_finalize_auto(); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return rho; -} - -dcomplex zdotu_ - ( - const f77_int* n, - const dcomplex* x, const f77_int* incx, - const dcomplex* y, const f77_int* incy - ) -{ - dim_t n0; - dcomplex* x0; - dcomplex* y0; - inc_t incx0; - inc_t incy0; - dcomplex rho; - - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); - AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'Z', *n, *incx, *incy); - - /* Initialize BLIS. */ -// bli_init_auto(); - - /* Convert/typecast negative values of n to zero. */ - if ( *n < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(*n); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - - if ( *incx < 0 ) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. (Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. */ - - x0 = ((dcomplex*)x) + (n0-1)*(-*incx); - incx0 = ( inc_t )(*incx); - - } - else - { - x0 = ((dcomplex*)x); - incx0 = ( inc_t )(*incx); - } - - if ( *incy < 0 ) - { - y0 = ((dcomplex*)y) + (n0-1)*(-*incy); - incy0 = ( inc_t )(*incy); - - } - else - { - y0 = ((dcomplex*)y); - incy0 = ( inc_t )(*incy); - } - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). 
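/*
   A minimal sketch, for illustration only, of the dispatch pattern the
   comment above describes: query the architecture id once and branch
   between a hand-written zen kernel and the reference path reached through
   the framework and context. The helper name is_zen_family() is
   hypothetical; arch_t, bli_arch_query_id() and the BLIS_ARCH_ZEN* ids are
   the ones already used in the surrounding code.

   static bool is_zen_family( void )
   {
       // Query the runtime architecture id once.
       const arch_t id = bli_arch_query_id();

       // Accept any member of the zen family.
       return ( id == BLIS_ARCH_ZEN4 ) ||
              ( id == BLIS_ARCH_ZEN3 ) ||
              ( id == BLIS_ARCH_ZEN2 ) ||
              ( id == BLIS_ARCH_ZEN  );
   }

   Usage: if is_zen_family() returns true, call the optimized kernel
   (here bli_zdotv_zen_int5()); otherwise fall back to the
   PASTEMAC2(z,dotv,BLIS_TAPI_EX_SUF) interface so non-zen builds still run
   the reference kernels. Note that the code below also accepts
   BLIS_ARCH_ZEN4, even though the comment above mentions only zen, zen2
   and zen3.
*/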
- arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN4) || - (id == BLIS_ARCH_ZEN3) || - (id == BLIS_ARCH_ZEN2) || - (id == BLIS_ARCH_ZEN); - - if (bamdzen) - { - /* Call BLIS kernel. */ - bli_zdotv_zen_int5 - ( - BLIS_NO_CONJUGATE, - BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - &rho, - NULL - ); - } - else - { - /* Call BLIS interface. */ - PASTEMAC2(z,dotv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - &rho, - NULL, - NULL - ); - } - - /* Finalize BLIS. */ -// bli_finalize_auto(); - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - - return rho; -} - - -scomplex cdotc_ - ( - const f77_int* n, - const scomplex* x, const f77_int* incx, - const scomplex* y, const f77_int* incy - ) -{ - dim_t n0; - scomplex* x0; - scomplex* y0; - inc_t incx0; - inc_t incy0; - scomplex rho; - - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); - AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'C', *n, *incx, *incy); - - /* Initialize BLIS. */ -// bli_init_auto(); - - /* Convert/typecast negative values of n to zero. */ - if ( *n < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(*n); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - - if ( *incx < 0 ) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. (Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. */ - - x0 = ((scomplex*)x) + (n0-1)*(-*incx); - incx0 = ( inc_t )(*incx); - - } - else - { - x0 = ((scomplex*)x); - incx0 = ( inc_t )(*incx); - } - - if ( *incy < 0 ) - { - y0 = ((scomplex*)y) + (n0-1)*(-*incy); - incy0 = ( inc_t )(*incy); - - } - else - { - y0 = ((scomplex*)y); - incy0 = ( inc_t )(*incy); - } - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN4) || - (id == BLIS_ARCH_ZEN3) || - (id == BLIS_ARCH_ZEN2) || - (id == BLIS_ARCH_ZEN); - - if (bamdzen) - { - /* Call BLIS kernel. */ - bli_cdotv_zen_int5 - ( - BLIS_CONJUGATE, - BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - &rho, - NULL - ); - } - else - { - /* Call BLIS interface. */ - PASTEMAC2(c,dotv,BLIS_TAPI_EX_SUF) - ( - BLIS_CONJUGATE, - BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - &rho, - NULL, - NULL - ); - } - - /* Finalize BLIS. 
*/ -// bli_finalize_auto(); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - - return rho; -} - -dcomplex zdotc_ - ( - const f77_int* n, - const dcomplex* x, const f77_int* incx, - const dcomplex* y, const f77_int* incy - ) -{ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); - AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'Z', *n, *incx, *incy); - dim_t n0; - dcomplex* x0; - dcomplex* y0; - inc_t incx0; - inc_t incy0; - dcomplex rho; - - /* Initialize BLIS. */ -// bli_init_auto(); - - /* Convert/typecast negative values of n to zero. */ - if ( *n < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(*n); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - - if ( *incx < 0 ) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. (Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. */ - - x0 = ((dcomplex*)x) + (n0-1)*(-*incx); - incx0 = ( inc_t )(*incx); - - } - else - { - x0 = ((dcomplex*)x); - incx0 = ( inc_t )(*incx); - } - - if ( *incy < 0 ) - { - y0 = ((dcomplex*)y) + (n0-1)*(-*incy); - incy0 = ( inc_t )(*incy); - - } - else - { - y0 = ((dcomplex*)y); - incy0 = ( inc_t )(*incy); - } - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN4) || - (id == BLIS_ARCH_ZEN3) || - (id == BLIS_ARCH_ZEN2) || - (id == BLIS_ARCH_ZEN); - - if (bamdzen) - { - /* Call BLIS kernel. */ - bli_zdotv_zen_int5 - ( - BLIS_CONJUGATE, - BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - &rho, - NULL - ); - } - else - { - /* Call BLIS interface. */ - PASTEMAC2(z,dotv,BLIS_TAPI_EX_SUF) - ( - BLIS_CONJUGATE, - BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - &rho, - NULL, - NULL - ); - } - - - - - - /* Finalize BLIS. 
*/ -// bli_finalize_auto(); - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - - return rho; -} -#else INSERT_GENTFUNCDOTC_BLAS( dot, dotv ) -#endif #else // For the "intel" complex return type, use a hidden parameter to return the result #undef GENTFUNCDOT @@ -819,8 +149,8 @@ void PASTEF772(ch,blasname,chc) \ } INSERT_GENTFUNCDOTC_BLAS( dot, dotv ) -#endif -#endif +#endif // BLIS_DISABLE_COMPLEX_RETURN_INTEL +#endif // BLIS_ENABLE_BLAS // -- "Black sheep" dot product function definitions -- @@ -894,4 +224,4 @@ double PASTEF77(d,sdot) return rho; } -#endif +#endif // BLIS_ENABLE_BLAS diff --git a/frame/compat/bla_dot_amd.c b/frame/compat/bla_dot_amd.c new file mode 100644 index 0000000000..0cdaa6535b --- /dev/null +++ b/frame/compat/bla_dot_amd.c @@ -0,0 +1,841 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + + +// +// Define BLAS-to-BLIS interfaces. +// +#undef GENTFUNCDOT +#define GENTFUNCDOT( ftype, ch, chc, blis_conjx, blasname, blisname ) \ +\ +ftype PASTEF772(ch,blasname,chc) \ + ( \ + const f77_int* n, \ + const ftype* x, const f77_int* incx, \ + const ftype* y, const f77_int* incy \ + ) \ +{ \ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); \ + AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(ch), *n, *incx, *incy); \ + dim_t n0; \ + ftype* x0; \ + ftype* y0; \ + inc_t incx0; \ + inc_t incy0; \ + ftype rho; \ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + /* Convert/typecast negative values of n to zero. */ \ + bli_convert_blas_dim1( *n, n0 ); \ +\ + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ \ + bli_convert_blas_incv( n0, (ftype*)x, *incx, x0, incx0 ); \ + bli_convert_blas_incv( n0, (ftype*)y, *incy, y0, incy0 ); \ +\ + /* Call BLIS interface. 
*/ \ + PASTEMAC2(ch,blisname,BLIS_TAPI_EX_SUF) \ + ( \ + blis_conjx, \ + BLIS_NO_CONJUGATE, \ + n0, \ + x0, incx0, \ + y0, incy0, \ + &rho, \ + NULL, \ + NULL \ + ); \ +\ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ +\ + return rho; \ +} + +#ifdef BLIS_ENABLE_BLAS +float sdot_ + ( + const f77_int* n, + const float* x, const f77_int* incx, + const float* y, const f77_int* incy + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); + AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'S', *n, *incx, *incy); + dim_t n0; + float* x0; + float* y0; + inc_t incx0; + inc_t incy0; + float rho; + + /* Initialize BLIS. */ +// bli_init_auto(); + + /* Convert/typecast negative values of n to zero. */ + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + + if ( *incx < 0 ) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. */ + + x0 = ((float*)x) + (n0-1)*(-*incx); + incx0 = ( inc_t )(*incx); + + } + else + { + x0 = ((float*)x); + incx0 = ( inc_t )(*incx); + } + + if ( *incy < 0 ) + { + y0 = ((float*)y) + (n0-1)*(-*incy); + incy0 = ( inc_t )(*incy); + + } + else + { + y0 = ((float*)y); + incy0 = ( inc_t )(*incy); + } + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) + { + /* Call BLIS kernel. */ + bli_sdotv_zen_int10 + ( + BLIS_NO_CONJUGATE, + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + &rho, + NULL + ); + } + else + { + /* Call BLIS interface. */ + PASTEMAC2(s,dotv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + &rho, + NULL, + NULL + ); + } + + /* Finalize BLIS. */ +// bli_finalize_auto(); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return rho; +} + +double ddot_ + ( + const f77_int* n, + const double* x, const f77_int* incx, + const double* y, const f77_int* incy + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); + AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'D', *n, *incx, *incy); + dim_t n0; + double* x0; + double* y0; + inc_t incx0; + inc_t incy0; + double rho; + + /* Initialize BLIS. */ +// bli_init_auto(); + + /* Convert/typecast negative values of n to zero. */ + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + + if ( *incx < 0 ) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) 
This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. */ + + x0 = ((double*)x) + (n0-1)*(-*incx); + incx0 = ( inc_t )(*incx); + + } + else + { + x0 = ((double*)x); + incx0 = ( inc_t )(*incx); + } + + if ( *incy < 0 ) + { + y0 = ((double*)y) + (n0-1)*(-*incy); + incy0 = ( inc_t )(*incy); + + } + else + { + y0 = ((double*)y); + incy0 = ( inc_t )(*incy); + } + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) + { + /* Call BLIS kernel. */ + bli_ddotv_zen_int10 + ( + BLIS_NO_CONJUGATE, + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + &rho, + NULL + ); + } + else + { + /* Call BLIS interface. */ + PASTEMAC2(d,dotv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + &rho, + NULL, + NULL + ); + } + + /* Finalize BLIS. */ +// bli_finalize_auto(); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return rho; +} + +#ifdef BLIS_DISABLE_COMPLEX_RETURN_INTEL +scomplex cdotu_ + ( + const f77_int* n, + const scomplex* x, const f77_int* incx, + const scomplex* y, const f77_int* incy + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); + AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'C', *n, *incx, *incy); + dim_t n0; + scomplex* x0; + scomplex* y0; + inc_t incx0; + inc_t incy0; + scomplex rho; + + /* Initialize BLIS. */ +// bli_init_auto(); + + /* Convert/typecast negative values of n to zero. */ + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + + if ( *incx < 0 ) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. */ + + x0 = ((scomplex*)x) + (n0-1)*(-*incx); + incx0 = ( inc_t )(*incx); + + } + else + { + x0 = ((scomplex*)x); + incx0 = ( inc_t )(*incx); + } + + if ( *incy < 0 ) + { + y0 = ((scomplex*)y) + (n0-1)*(-*incy); + incy0 = ( inc_t )(*incy); + + } + else + { + y0 = ((scomplex*)y); + incy0 = ( inc_t )(*incy); + } + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) + { + /* Call BLIS kernel. 
*/ + bli_cdotv_zen_int5 + ( + BLIS_NO_CONJUGATE, + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + &rho, + NULL + ); + } + else + { + /* Call BLIS interface. */ + PASTEMAC2(c,dotv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + &rho, + NULL, + NULL + ); + } + + /* Finalize BLIS. */ +// bli_finalize_auto(); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return rho; +} + +dcomplex zdotu_ + ( + const f77_int* n, + const dcomplex* x, const f77_int* incx, + const dcomplex* y, const f77_int* incy + ) +{ + dim_t n0; + dcomplex* x0; + dcomplex* y0; + inc_t incx0; + inc_t incy0; + dcomplex rho; + + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); + AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'Z', *n, *incx, *incy); + + /* Initialize BLIS. */ +// bli_init_auto(); + + /* Convert/typecast negative values of n to zero. */ + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + + if ( *incx < 0 ) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. */ + + x0 = ((dcomplex*)x) + (n0-1)*(-*incx); + incx0 = ( inc_t )(*incx); + + } + else + { + x0 = ((dcomplex*)x); + incx0 = ( inc_t )(*incx); + } + + if ( *incy < 0 ) + { + y0 = ((dcomplex*)y) + (n0-1)*(-*incy); + incy0 = ( inc_t )(*incy); + + } + else + { + y0 = ((dcomplex*)y); + incy0 = ( inc_t )(*incy); + } + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) + { + /* Call BLIS kernel. */ + bli_zdotv_zen_int5 + ( + BLIS_NO_CONJUGATE, + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + &rho, + NULL + ); + } + else + { + /* Call BLIS interface. */ + PASTEMAC2(z,dotv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + &rho, + NULL, + NULL + ); + } + + /* Finalize BLIS. */ +// bli_finalize_auto(); + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + + return rho; +} + + +scomplex cdotc_ + ( + const f77_int* n, + const scomplex* x, const f77_int* incx, + const scomplex* y, const f77_int* incy + ) +{ + dim_t n0; + scomplex* x0; + scomplex* y0; + inc_t incx0; + inc_t incy0; + scomplex rho; + + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); + AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'C', *n, *incx, *incy); + + /* Initialize BLIS. */ +// bli_init_auto(); + + /* Convert/typecast negative values of n to zero. */ + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. 
*/ + + if ( *incx < 0 ) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. */ + + x0 = ((scomplex*)x) + (n0-1)*(-*incx); + incx0 = ( inc_t )(*incx); + + } + else + { + x0 = ((scomplex*)x); + incx0 = ( inc_t )(*incx); + } + + if ( *incy < 0 ) + { + y0 = ((scomplex*)y) + (n0-1)*(-*incy); + incy0 = ( inc_t )(*incy); + + } + else + { + y0 = ((scomplex*)y); + incy0 = ( inc_t )(*incy); + } + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) + { + /* Call BLIS kernel. */ + bli_cdotv_zen_int5 + ( + BLIS_CONJUGATE, + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + &rho, + NULL + ); + } + else + { + /* Call BLIS interface. */ + PASTEMAC2(c,dotv,BLIS_TAPI_EX_SUF) + ( + BLIS_CONJUGATE, + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + &rho, + NULL, + NULL + ); + } + + /* Finalize BLIS. */ +// bli_finalize_auto(); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + + return rho; +} + +dcomplex zdotc_ + ( + const f77_int* n, + const dcomplex* x, const f77_int* incx, + const dcomplex* y, const f77_int* incy + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); + AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'Z', *n, *incx, *incy); + dim_t n0; + dcomplex* x0; + dcomplex* y0; + inc_t incx0; + inc_t incy0; + dcomplex rho; + + /* Initialize BLIS. */ +// bli_init_auto(); + + /* Convert/typecast negative values of n to zero. */ + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + + if ( *incx < 0 ) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. */ + + x0 = ((dcomplex*)x) + (n0-1)*(-*incx); + incx0 = ( inc_t )(*incx); + + } + else + { + x0 = ((dcomplex*)x); + incx0 = ( inc_t )(*incx); + } + + if ( *incy < 0 ) + { + y0 = ((dcomplex*)y) + (n0-1)*(-*incy); + incy0 = ( inc_t )(*incy); + + } + else + { + y0 = ((dcomplex*)y); + incy0 = ( inc_t )(*incy); + } + + // This function is invoked on all architectures including ‘generic’. 
+ // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) + { + /* Call BLIS kernel. */ + bli_zdotv_zen_int5 + ( + BLIS_CONJUGATE, + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + &rho, + NULL + ); + } + else + { + /* Call BLIS interface. */ + PASTEMAC2(z,dotv,BLIS_TAPI_EX_SUF) + ( + BLIS_CONJUGATE, + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + &rho, + NULL, + NULL + ); + } + + + + + + /* Finalize BLIS. */ +// bli_finalize_auto(); + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + + return rho; +} + +#else // BLIS_DISABLE_COMPLEX_RETURN_INTEL +// For the "intel" complex return type, use a hidden parameter to return the result +#undef GENTFUNCDOT +#define GENTFUNCDOT( ftype, ch, chc, blis_conjx, blasname, blisname ) \ +\ +void PASTEF772(ch,blasname,chc) \ + ( \ + ftype* rhop, \ + const f77_int* n, \ + const ftype* x, const f77_int* incx, \ + const ftype* y, const f77_int* incy \ + ) \ +{ \ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); \ + AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(ch), *n, *incx, *incy); \ + dim_t n0; \ + ftype* x0; \ + ftype* y0; \ + inc_t incx0; \ + inc_t incy0; \ + ftype rho; \ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + /* Convert/typecast negative values of n to zero. */ \ + bli_convert_blas_dim1( *n, n0 ); \ +\ + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ \ + bli_convert_blas_incv( n0, (ftype*)x, *incx, x0, incx0 ); \ + bli_convert_blas_incv( n0, (ftype*)y, *incy, y0, incy0 ); \ +\ + /* Call BLIS interface. */ \ + PASTEMAC2(ch,blisname,BLIS_TAPI_EX_SUF) \ + ( \ + blis_conjx, \ + BLIS_NO_CONJUGATE, \ + n0, \ + x0, incx0, \ + y0, incy0, \ + &rho, \ + NULL, \ + NULL \ + ); \ +\ + /* Finalize BLIS. */ \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ + bli_finalize_auto(); \ +\ + *rhop = rho; \ +} + +INSERT_GENTFUNCDOTC_BLAS( dot, dotv ) +#endif // BLIS_DISABLE_COMPLEX_RETURN_INTEL + + + +// -- "Black sheep" dot product function definitions -- + +// Input vectors stored in single precision, computed in double precision, +// with result returned in single precision. +float PASTEF77(sd,sdot) + ( + const f77_int* n, + const float* sb, + const float* x, const f77_int* incx, + const float* y, const f77_int* incy + ) +{ + return ( float ) + ( + ( double )(*sb) + + PASTEF77(d,sdot) + ( + n, + x, incx, + y, incy + ) + ); +} + +// Input vectors stored in single precision, computed in double precision, +// with result returned in double precision. +double PASTEF77(d,sdot) + ( + const f77_int* n, + const float* x, const f77_int* incx, + const float* y, const f77_int* incy + ) +{ + dim_t n0; + float* x0; + float* y0; + inc_t incx0; + inc_t incy0; + double rho; + dim_t i; + + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); + AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'D', *n, *incx, *incy); + /* Initialization of BLIS is not required. */ + + /* Convert/typecast negative values of n to zero. */ + bli_convert_blas_dim1( *n, n0 ); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. 
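+      For reference, the accumulation carried out by the loop below is, in plain
+      C and assuming unit strides (a hedged sketch rather than the literal code):
+
+          double acc = 0.0;
+          for ( dim_t i = 0; i < n0; i++ )
+              acc += ( double )x0[i] * ( double )y0[i];  // promote, then accumulate
+
+      dsdot_ returns acc directly; sdsdot_ adds the scalar *sb in double precision
+      and truncates the sum back to float.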
*/ + bli_convert_blas_incv( n0, (float*)x, *incx, x0, incx0 ); + bli_convert_blas_incv( n0, (float*)y, *incy, y0, incy0 ); + + rho = 0.0; + + for ( i = 0; i < n0; i++ ) + { + float* chi1 = x0 + (i )*incx0; + float* psi1 = y0 + (i )*incy0; + + bli_ddots( (( double )(*chi1)), + (( double )(*psi1)), rho ); + } + + /* Finalization of BLIS is not required, because initialization was + not required. */ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + + return rho; +} + +#endif diff --git a/frame/compat/bla_gemm.c b/frame/compat/bla_gemm.c index 80ad197c68..8d08a9e010 100644 --- a/frame/compat/bla_gemm.c +++ b/frame/compat/bla_gemm.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019 - 21, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 22, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -300,512 +300,7 @@ void PASTEF77(ch,blasname) \ #endif #ifdef BLIS_ENABLE_BLAS -#ifdef BLIS_CONFIG_EPYC -void dgemm_ -( - const f77_char* transa, - const f77_char* transb, - const f77_int* m, - const f77_int* n, - const f77_int* k, - const double* alpha, - const double* a, const f77_int* lda, - const double* b, const f77_int* ldb, - const double* beta, - double* c, const f77_int* ldc -) -{ - - - - trans_t blis_transa; - trans_t blis_transb; - dim_t m0, n0, k0; - - /* Initialize BLIS. */ - bli_init_auto(); - - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) - AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(d), *transa, *transb, *m, *n, *k, \ - (void*)alpha, *lda, *ldb, (void*)beta, *ldc); - - /* Perform BLAS parameter checking. */ - PASTEBLACHK(gemm) - ( - MKSTR(d), - MKSTR(gemm), - transa, - transb, - m, - n, - k, - lda, - ldb, - ldc - ); - - /* Map BLAS chars to their corresponding BLIS enumerated type value. */ - bli_param_map_netlib_to_blis_trans(*transa, &blis_transa); - bli_param_map_netlib_to_blis_trans(*transb, &blis_transb); - - /* Typecast BLAS integers to BLIS integers. */ - bli_convert_blas_dim1(*m, m0); - bli_convert_blas_dim1(*n, n0); - bli_convert_blas_dim1(*k, k0); - - - /* Set the row and column strides of the matrix operands. */ - const inc_t rs_a = 1; - const inc_t cs_a = *lda; - const inc_t rs_b = 1; - const inc_t cs_b = *ldb; - const inc_t rs_c = 1; - const inc_t cs_c = *ldc; - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN4) || - (id == BLIS_ARCH_ZEN3) || - (id == BLIS_ARCH_ZEN2) || - (id == BLIS_ARCH_ZEN); - - if (!bamdzen) - { - // This code is duplicated below, however we don't want to move it out of - // this IF block as it will affect the performance on Zen architetures - // Also this is temporary fix which will be replaced later. 
- const num_t dt = BLIS_DOUBLE; - - obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; - obj_t ao = BLIS_OBJECT_INITIALIZER; - obj_t bo = BLIS_OBJECT_INITIALIZER; - obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; - obj_t co = BLIS_OBJECT_INITIALIZER; - - dim_t m0_a, n0_a; - dim_t m0_b, n0_b; - - bli_set_dims_with_trans(blis_transa, m0, k0, &m0_a, &n0_a); - bli_set_dims_with_trans(blis_transb, k0, n0, &m0_b, &n0_b); - - bli_obj_init_finish_1x1(dt, (double *)alpha, &alphao); - bli_obj_init_finish_1x1(dt, (double *)beta, &betao); - - bli_obj_init_finish(dt, m0_a, n0_a, (double *)a, rs_a, cs_a, &ao); - bli_obj_init_finish(dt, m0_b, n0_b, (double *)b, rs_b, cs_b, &bo); - bli_obj_init_finish(dt, m0, n0, (double *)c, rs_c, cs_c, &co); - - bli_obj_set_conjtrans(blis_transa, &ao); - bli_obj_set_conjtrans(blis_transb, &bo); - - // Will call parallelized dgemm code - sup & native - PASTEMAC(gemm, BLIS_OAPI_EX_SUF) - ( - &alphao, - &ao, - &bo, - &betao, - &co, - NULL, - NULL - ); - - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - /* Finalize BLIS. */ - bli_finalize_auto(); - return; - } - - if((k0 == 1) && bli_is_notrans(blis_transa) && bli_is_notrans(blis_transb)) - { - bli_dgemm_ref_k1_nn( m0, n0, k0, - (double*)alpha, - (double*)a, *lda, - (double*)b, *ldb, - (double*)beta, - c, *ldc - ); - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - /* Finalize BLIS */ - bli_finalize_auto(); - - return; - } - - if (n0 == 1) - { - if (bli_is_notrans(blis_transa)) - { - bli_dgemv_unf_var2( - BLIS_NO_TRANSPOSE, - bli_extract_conj(blis_transb), - m0, k0, - (double*)alpha, - (double*)a, rs_a, cs_a, - (double*)b, bli_is_notrans(blis_transb) ? rs_b : cs_b, - (double*)beta, - c, rs_c, - ((void*)0) - ); - } - else - { - bli_dgemv_unf_var1( - blis_transa, - bli_extract_conj(blis_transb), - k0, m0, - (double*)alpha, - (double*)a, rs_a, cs_a, - (double*)b, bli_is_notrans(blis_transb) ? rs_b : cs_b, - (double*)beta, - c, rs_c, - ((void*)0) - ); - } - - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - - return; - } - else if (m0 == 1) - { - if (bli_is_notrans(blis_transb)) - { - bli_dgemv_unf_var1( - blis_transb, - bli_extract_conj(blis_transa), - n0, k0, - (double*)alpha, - (double*)b, cs_b, rs_b, - (double*)a, bli_is_notrans(blis_transa) ? cs_a : rs_a, - (double*)beta, - c, cs_c, - ((void*)0) - ); - } - else - { - bli_dgemv_unf_var2( - blis_transb, - bli_extract_conj(blis_transa), - k0, n0, - (double*)alpha, - (double*)b, cs_b, rs_b, - (double*)a, bli_is_notrans(blis_transa) ? 
cs_a : rs_a, - (double*)beta, - c, cs_c, - ((void*)0) - ); - } - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - return; - } - - const num_t dt = BLIS_DOUBLE; - - obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; - obj_t ao = BLIS_OBJECT_INITIALIZER; - obj_t bo = BLIS_OBJECT_INITIALIZER; - obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; - obj_t co = BLIS_OBJECT_INITIALIZER; - - dim_t m0_a, n0_a; - dim_t m0_b, n0_b; - - bli_set_dims_with_trans(blis_transa, m0, k0, &m0_a, &n0_a); - bli_set_dims_with_trans(blis_transb, k0, n0, &m0_b, &n0_b); - - bli_obj_init_finish_1x1(dt, (double*)alpha, &alphao); - bli_obj_init_finish_1x1(dt, (double*)beta, &betao); - - bli_obj_init_finish(dt, m0_a, n0_a, (double*)a, rs_a, cs_a, &ao); - bli_obj_init_finish(dt, m0_b, n0_b, (double*)b, rs_b, cs_b, &bo); - bli_obj_init_finish(dt, m0, n0, (double*)c, rs_c, cs_c, &co); - - bli_obj_set_conjtrans(blis_transa, &ao); - bli_obj_set_conjtrans(blis_transb, &bo); - - //cntx_t* cntx = bli_gks_query_cntx(); - //dim_t nt = bli_thread_get_num_threads(); // get number of threads - bool nt = bli_thread_get_is_parallel(); // Check if parallel dgemm is invoked. - - // if m0 is large and (n0 & k0) < 10 - SMALL GEMM - ST is better - // - -#ifdef AOCL_DYNAMIC - if (nt && ((n0 > 10 ) || (k0 > 10)) ) -#else - if (nt) -#endif - { - // Will call parallelized dgemm code - sup & native - PASTEMAC(gemm, BLIS_OAPI_EX_SUF) - ( - &alphao, - &ao, - &bo, - &betao, - &co, - NULL, - NULL - ); - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - /* Finalize BLIS. */ - bli_finalize_auto(); - return; - } - - // The code below will be called when number of threads = 1. - -#ifdef BLIS_ENABLE_SMALL_MATRIX - - //if( ((m0 + n0 -k0) < 2000) && ((m0 + k0-n0) < 2000) && ((n0 + k0-m0) < 2000) && (n0 > 2)) - if( ( ( (m0 + n0 -k0) < 2000) && ((m0 + k0-n0) < 2000) && ((n0 + k0-m0) < 2000) ) || - ((n0 <= 10) && (k0 <=10)) ) - { - err_t status; - if (bli_is_notrans(blis_transa)) - { - status = bli_dgemm_small( &alphao, - &ao, - &bo, - &betao, - &co, - NULL, //cntx, - NULL - ); - } - else - { - status = bli_dgemm_small_At ( &alphao, - &ao, - &bo, - &betao, - &co, - NULL, //cntx, - NULL - ); - } - - if (status == BLIS_SUCCESS) - { - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - /* Finalize BLIS. */ - bli_finalize_auto(); - - return; - } - } - -#endif //#ifdef BLIS_ENABLE_SMALL_MATRIX - - err_t status = bli_gemmsup(&alphao, &ao, &bo, &betao, &co, NULL, NULL); - if (status == BLIS_SUCCESS) - { - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - return; - } - - // fall back on native path when dgemm is not handled in sup path. - bli_gemmnat(&alphao, &ao, &bo, &betao, &co, NULL, NULL); - - - /* PASTEMAC(gemm, BLIS_OAPI_EX_SUF) */ - /* ( */ - /* &alphao, */ - /* &ao, */ - /* &bo, */ - /* &betao, */ - /* &co, */ - /* NULL, */ - /* NULL */ - /* ); */ - - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - /* Finalize BLIS. */ - bli_finalize_auto(); -} // end of dgemm_ - -void zgemm_ - ( - const f77_char* transa, - const f77_char* transb, - const f77_int* m, - const f77_int* n, - const f77_int* k, - const dcomplex* alpha, - const dcomplex* a, const f77_int* lda, - const dcomplex* b, const f77_int* ldb, - const dcomplex* beta, - dcomplex* c, const f77_int* ldc - ) -{ - trans_t blis_transa; - trans_t blis_transb; - dim_t m0, n0, k0; - - /* Initialize BLIS. 
*/ - bli_init_auto(); - - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) - AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(z), *transa, *transb, *m, *n, *k, - (void*)alpha, *lda, *ldb, (void*)beta, *ldc); - - /* Perform BLAS parameter checking. */ - PASTEBLACHK(gemm) - ( - MKSTR(z), - MKSTR(gemm), - transa, - transb, - m, - n, - k, - lda, - ldb, - ldc - ); - - /* Map BLAS chars to their corresponding BLIS enumerated type value. */ - bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); - bli_param_map_netlib_to_blis_trans( *transb, &blis_transb ); - - /* Typecast BLAS integers to BLIS integers. */ - bli_convert_blas_dim1( *m, m0 ); - bli_convert_blas_dim1( *n, n0 ); - bli_convert_blas_dim1( *k, k0 ); - - /* Set the row and column strides of the matrix operands. */ - const inc_t rs_a = 1; - const inc_t cs_a = *lda; - const inc_t rs_b = 1; - const inc_t cs_b = *ldb; - const inc_t rs_c = 1; - const inc_t cs_c = *ldc; - - const num_t dt = BLIS_DCOMPLEX; - - obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; - obj_t ao = BLIS_OBJECT_INITIALIZER; - obj_t bo = BLIS_OBJECT_INITIALIZER; - obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; - obj_t co = BLIS_OBJECT_INITIALIZER; - - dim_t m0_a, n0_a; - dim_t m0_b, n0_b; - - bli_set_dims_with_trans( blis_transa, m0, k0, &m0_a, &n0_a ); - bli_set_dims_with_trans( blis_transb, k0, n0, &m0_b, &n0_b ); - - bli_obj_init_finish_1x1( dt, (dcomplex*)alpha, &alphao ); - bli_obj_init_finish_1x1( dt, (dcomplex*)beta, &betao ); - - bli_obj_init_finish( dt, m0_a, n0_a, (dcomplex*)a, rs_a, cs_a, &ao ); - bli_obj_init_finish( dt, m0_b, n0_b, (dcomplex*)b, rs_b, cs_b, &bo ); - bli_obj_init_finish( dt, m0, n0, (dcomplex*)c, rs_c, cs_c, &co ); - - bli_obj_set_conjtrans( blis_transa, &ao ); - bli_obj_set_conjtrans( blis_transb, &bo ); - - // default instance peformance tuning is done in zgemm. - // Single instance tuning is done based on env set. - dim_t single_instance = bli_env_get_var( "BLIS_SINGLE_INSTANCE", -1 ); - - //dim_t nt = bli_thread_get_num_threads(); // get number of threads - bool nt = bli_thread_get_is_parallel(); // Check if parallel zgemm is invoked. - if ( nt ) - { - // Will call parallelized zgemm code - sup & native - PASTEMAC(gemm, BLIS_OAPI_EX_SUF) - ( - &alphao, - &ao, - &bo, - &betao, - &co, - NULL, - NULL - ); - - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - /* Finalize BLIS. */ - bli_finalize_auto(); - return; - } - - // The code below will be called when number of threads = 1. -#if ENABLE_INDUCED_METHOD - /* 3m_sqp is optimal for certain matrix shapes. - Initial study that it works well for square sizes and sizes closer to square shape. - - * Usage of 3m_sqp is restricted to sizes, where it is found efficient compared to native, sup and other induced method. - * Further investigation is necessary to make the usage choices more generic. */ - bool sqp_on = false; - if( (m0 == n0 ) && ( n0 == k0 ) && ( m0 == 128 ) ) - { - sqp_on = true; - } - - // current range of sizes used for 3m_sqp to be expaned after evaluation. - if( ( m0 >= 4200) && ( m0 <= 4600 ) && ( ( n0 >= 326 ) || (n0 <= 1600 ) ) - && ( k0 == 1120 ) ) //to be tuned further. 
- { - sqp_on = true; - } - - if( ( blis_transb == BLIS_NO_TRANSPOSE) && ( sqp_on == true ) ) - { - //sqp algo is found better for n > 40 - if(bli_gemm_sqp(&alphao, &ao, &bo, &betao, &co, NULL, NULL)==BLIS_SUCCESS) - { - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) - return; - } - } -#endif//ENABLE_INDUCED_METHOD - -// native tuning resulted in better numbers compared to sup in constrained multi-instance -// sup has been enabled for single instance cases. - if(single_instance==1) - { - err_t status = bli_gemmsup(&alphao, &ao, &bo, &betao, &co, NULL, NULL); - if(status==BLIS_SUCCESS) - { - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) - return; - } - - } - // fall back on native path when zgemm is not handled in sup path. - bli_gemmnat(&alphao, &ao, &bo, &betao, &co, NULL, NULL); - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) - return; - - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) - /* Finalize BLIS. */ - bli_finalize_auto(); -}// end of zgemm_ - - -INSERT_GENTFUNC_BLAS_SC( gemm, gemm ) -#else INSERT_GENTFUNC_BLAS( gemm,gemm ) -#endif // Observed a regression in dgemm with this function addition. // Disabling temporarily. diff --git a/frame/compat/bla_gemm_amd.c b/frame/compat/bla_gemm_amd.c new file mode 100644 index 0000000000..7ef58bfb35 --- /dev/null +++ b/frame/compat/bla_gemm_amd.c @@ -0,0 +1,894 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019 - 22, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +// +// Define BLAS-to-BLIS interfaces. 
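+// As an illustrative sketch (hedged; the exact expansion depends on the macros
+// defined below), a column-major call such as
+//
+//     dgemm_( "N", "T", &m, &n, &k, &alpha, a, &lda, b, &ldb, &beta, c, &ldc );
+//
+// reaches the typed BLIS API roughly as
+//
+//     bli_dgemm_ex( BLIS_NO_TRANSPOSE, BLIS_TRANSPOSE, m, n, k,
+//                   &alpha, a, 1, lda, b, 1, ldb,
+//                   &beta,  c, 1, ldc, NULL, NULL );
+//
+// i.e. each operand is described with a row stride of 1 and a column stride
+// equal to its BLAS leading dimension.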
+// +#define ENABLE_INDUCED_METHOD 0 +#ifdef BLIS_BLAS3_CALLS_TAPI + +#undef GENTFUNC +#define GENTFUNC( ftype, ch, blasname, blisname ) \ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* transa, \ + const f77_char* transb, \ + const f77_int* m, \ + const f77_int* n, \ + const f77_int* k, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype* b, const f77_int* ldb, \ + const ftype* beta, \ + ftype* c, const f77_int* ldc \ + ) \ +{ \ + trans_t blis_transa; \ + trans_t blis_transb; \ + dim_t m0, n0, k0; \ + inc_t rs_a, cs_a; \ + inc_t rs_b, cs_b; \ + inc_t rs_c, cs_c; \ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); \ + AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(ch), *transa, *transb, *m, *n, *k, \ + (void*)alpha, *lda, *ldb, (void*)beta, *ldc); \ +\ + /* Perform BLAS parameter checking. */ \ + PASTEBLACHK(blasname) \ + ( \ + MKSTR(ch), \ + MKSTR(blasname), \ + transa, \ + transb, \ + m, \ + n, \ + k, \ + lda, \ + ldb, \ + ldc \ + ); \ +\ + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ + bli_param_map_netlib_to_blis_trans( *transb, &blis_transb ); \ +\ + /* Typecast BLAS integers to BLIS integers. */ \ + bli_convert_blas_dim1( *m, m0 ); \ + bli_convert_blas_dim1( *n, n0 ); \ + bli_convert_blas_dim1( *k, k0 ); \ +\ + /* Set the row and column strides of the matrix operands. */ \ + rs_a = 1; \ + cs_a = *lda; \ + rs_b = 1; \ + cs_b = *ldb; \ + rs_c = 1; \ + cs_c = *ldc; \ +\ + /* Call BLIS interface. */ \ + PASTEMAC2(ch,blisname,BLIS_TAPI_EX_SUF) \ + ( \ + blis_transa, \ + blis_transb, \ + m0, \ + n0, \ + k0, \ + (ftype*)alpha, \ + (ftype*)a, rs_a, cs_a, \ + (ftype*)b, rs_b, cs_b, \ + (ftype*)beta, \ + (ftype*)c, rs_c, cs_c, \ + NULL, \ + NULL \ + ); \ +\ + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ +} + +#else + +#undef GENTFUNC +#define GENTFUNC( ftype, ch, blasname, blisname ) \ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* transa, \ + const f77_char* transb, \ + const f77_int* m, \ + const f77_int* n, \ + const f77_int* k, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype* b, const f77_int* ldb, \ + const ftype* beta, \ + ftype* c, const f77_int* ldc \ + ) \ +{ \ +\ + trans_t blis_transa; \ + trans_t blis_transb; \ + dim_t m0, n0, k0; \ +\ + dim_t m0_a, n0_a; \ + dim_t m0_b, n0_b; \ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); \ + AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(ch), *transa, *transb, *m, *n, *k, \ + (void*)alpha, *lda, *ldb, (void*)beta, *ldc); \ +\ + /* Perform BLAS parameter checking. */ \ + PASTEBLACHK(blasname) \ + ( \ + MKSTR(ch), \ + MKSTR(blasname), \ + transa, \ + transb, \ + m, \ + n, \ + k, \ + lda, \ + ldb, \ + ldc \ + ); \ +\ + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ + bli_param_map_netlib_to_blis_trans( *transb, &blis_transb ); \ +\ + /* Typecast BLAS integers to BLIS integers. */ \ + bli_convert_blas_dim1( *m, m0 ); \ + bli_convert_blas_dim1( *n, n0 ); \ + bli_convert_blas_dim1( *k, k0 ); \ +\ + /* Set the row and column strides of the matrix operands. 
*/ \ + const inc_t rs_a = 1; \ + const inc_t cs_a = *lda; \ + const inc_t rs_b = 1; \ + const inc_t cs_b = *ldb; \ + const inc_t rs_c = 1; \ + const inc_t cs_c = *ldc; \ +\ + if( n0 == 1 ) \ + { \ + if(bli_is_notrans(blis_transa)) \ + { \ + PASTEMAC(ch,gemv_unf_var2)( \ + BLIS_NO_TRANSPOSE, \ + bli_extract_conj(blis_transb), \ + m0, k0, \ + (ftype*)alpha, \ + (ftype*)a, rs_a, cs_a,\ + (ftype*)b, bli_is_notrans(blis_transb)?rs_b:cs_b, \ + (ftype*) beta, \ + c, rs_c, \ + NULL \ + ); \ + } \ + else \ + { \ + PASTEMAC(ch,gemv_unf_var1)( \ + blis_transa, \ + bli_extract_conj(blis_transb), \ + k0, m0, \ + (ftype*)alpha, \ + (ftype*)a, rs_a, cs_a, \ + (ftype*)b, bli_is_notrans(blis_transb)?rs_b:cs_b, \ + (ftype*)beta, \ + c, rs_c, \ + NULL \ + ); \ + } \ + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); \ + return; \ + } \ + else if( m0 == 1 ) \ + { \ + if(bli_is_notrans(blis_transb)) \ + { \ + PASTEMAC(ch,gemv_unf_var1)( \ + blis_transb, \ + bli_extract_conj(blis_transa), \ + n0, k0, \ + (ftype*)alpha, \ + (ftype*)b, cs_b, rs_b, \ + (ftype*)a, bli_is_notrans(blis_transa)?cs_a:rs_a, \ + (ftype*)beta, \ + c, cs_c, \ + NULL \ + ); \ + } \ + else \ + { \ + PASTEMAC(ch,gemv_unf_var2)( \ + blis_transb, \ + bli_extract_conj(blis_transa), \ + k0, n0, \ + (ftype*)alpha, \ + (ftype*)b, cs_b, rs_b, \ + (ftype*)a, bli_is_notrans(blis_transa)?cs_a:rs_a, \ + (ftype*)beta, \ + c, cs_c, \ + NULL \ + ); \ + } \ + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); \ + return; \ + } \ +\ + const num_t dt = PASTEMAC(ch,type); \ +\ + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t ao = BLIS_OBJECT_INITIALIZER; \ + obj_t bo = BLIS_OBJECT_INITIALIZER; \ + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t co = BLIS_OBJECT_INITIALIZER; \ +\ + bli_set_dims_with_trans( blis_transa, m0, k0, &m0_a, &n0_a ); \ + bli_set_dims_with_trans( blis_transb, k0, n0, &m0_b, &n0_b ); \ +\ + bli_obj_init_finish_1x1( dt, (ftype*)alpha, &alphao ); \ + bli_obj_init_finish_1x1( dt, (ftype*)beta, &betao ); \ +\ + bli_obj_init_finish( dt, m0_a, n0_a, (ftype*)a, rs_a, cs_a, &ao ); \ + bli_obj_init_finish( dt, m0_b, n0_b, (ftype*)b, rs_b, cs_b, &bo ); \ + bli_obj_init_finish( dt, m0, n0, (ftype*)c, rs_c, cs_c, &co ); \ +\ + bli_obj_set_conjtrans( blis_transa, &ao ); \ + bli_obj_set_conjtrans( blis_transb, &bo ); \ +\ + PASTEMAC(blisname,BLIS_OAPI_EX_SUF) \ + ( \ + &alphao, \ + &ao, \ + &bo, \ + &betao, \ + &co, \ + NULL, \ + NULL \ + ); \ +\ + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ +} +#endif + +#ifdef BLIS_ENABLE_BLAS +void dgemm_ +( + const f77_char* transa, + const f77_char* transb, + const f77_int* m, + const f77_int* n, + const f77_int* k, + const double* alpha, + const double* a, const f77_int* lda, + const double* b, const f77_int* ldb, + const double* beta, + double* c, const f77_int* ldc +) +{ + + + + trans_t blis_transa; + trans_t blis_transb; + dim_t m0, n0, k0; + + /* Initialize BLIS. */ + bli_init_auto(); + + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) + AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(d), *transa, *transb, *m, *n, *k, \ + (void*)alpha, *lda, *ldb, (void*)beta, *ldc); + + /* Perform BLAS parameter checking. */ + PASTEBLACHK(gemm) + ( + MKSTR(d), + MKSTR(gemm), + transa, + transb, + m, + n, + k, + lda, + ldb, + ldc + ); + + /* Map BLAS chars to their corresponding BLIS enumerated type value. 
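+      For orientation, the remainder of this routine dispatches roughly as
+      follows (a hedged sketch; dgemm_path_t, dgemm_path_sketch and the enum
+      values are made-up names, not BLIS symbols):
+
+          typedef enum { OBJ_API, K1_KERNEL, GEMV_PATH, MT_GEMM,
+                         SMALL_GEMM, SUP_THEN_NATIVE } dgemm_path_t;
+
+          static dgemm_path_t dgemm_path_sketch
+          ( bool avx, bool nta, bool ntb, dim_t m, dim_t n, dim_t k, bool mt )
+          {
+              if ( !avx )                     return OBJ_API;          // generic object API
+              if ( k == 1 && nta && ntb )     return K1_KERNEL;        // bli_dgemm_ref_k1_nn
+              if ( n == 1 || m == 1 )         return GEMV_PATH;        // dgemv variants
+              if ( mt )                       return MT_GEMM;          // parallel sup + native
+              if ( ( m + n - k < 2000 && m + k - n < 2000 && n + k - m < 2000 ) ||
+                   ( n <= 10 && k <= 10 ) )   return SMALL_GEMM;       // bli_dgemm_small / _At
+              return SUP_THEN_NATIVE;                                  // bli_gemmsup, then native
+          }
+
+      With AOCL_DYNAMIC defined, the MT_GEMM branch is additionally gated on
+      n > 10 or k > 10.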
*/ + bli_param_map_netlib_to_blis_trans(*transa, &blis_transa); + bli_param_map_netlib_to_blis_trans(*transb, &blis_transb); + + /* Typecast BLAS integers to BLIS integers. */ + bli_convert_blas_dim1(*m, m0); + bli_convert_blas_dim1(*n, n0); + bli_convert_blas_dim1(*k, k0); + + + /* Set the row and column strides of the matrix operands. */ + const inc_t rs_a = 1; + const inc_t cs_a = *lda; + const inc_t rs_b = 1; + const inc_t cs_b = *ldb; + const inc_t rs_c = 1; + const inc_t cs_c = *ldc; + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == FALSE) + { + // This code is duplicated below, however we don't want to move it out of + // this IF block as it will affect the performance on Zen architetures + // Also this is temporary fix which will be replaced later. + const num_t dt = BLIS_DOUBLE; + + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; + obj_t ao = BLIS_OBJECT_INITIALIZER; + obj_t bo = BLIS_OBJECT_INITIALIZER; + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; + obj_t co = BLIS_OBJECT_INITIALIZER; + + dim_t m0_a, n0_a; + dim_t m0_b, n0_b; + + bli_set_dims_with_trans(blis_transa, m0, k0, &m0_a, &n0_a); + bli_set_dims_with_trans(blis_transb, k0, n0, &m0_b, &n0_b); + + bli_obj_init_finish_1x1(dt, (double *)alpha, &alphao); + bli_obj_init_finish_1x1(dt, (double *)beta, &betao); + + bli_obj_init_finish(dt, m0_a, n0_a, (double *)a, rs_a, cs_a, &ao); + bli_obj_init_finish(dt, m0_b, n0_b, (double *)b, rs_b, cs_b, &bo); + bli_obj_init_finish(dt, m0, n0, (double *)c, rs_c, cs_c, &co); + + bli_obj_set_conjtrans(blis_transa, &ao); + bli_obj_set_conjtrans(blis_transb, &bo); + + // Will call parallelized dgemm code - sup & native + PASTEMAC(gemm, BLIS_OAPI_EX_SUF) + ( + &alphao, + &ao, + &bo, + &betao, + &co, + NULL, + NULL + ); + + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS. */ + bli_finalize_auto(); + return; + } + + if((k0 == 1) && bli_is_notrans(blis_transa) && bli_is_notrans(blis_transb)) + { + bli_dgemm_ref_k1_nn( m0, n0, k0, + (double*)alpha, + (double*)a, *lda, + (double*)b, *ldb, + (double*)beta, + c, *ldc + ); + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS */ + bli_finalize_auto(); + + return; + } + + if (n0 == 1) + { + if (bli_is_notrans(blis_transa)) + { + bli_dgemv_unf_var2( + BLIS_NO_TRANSPOSE, + bli_extract_conj(blis_transb), + m0, k0, + (double*)alpha, + (double*)a, rs_a, cs_a, + (double*)b, bli_is_notrans(blis_transb) ? rs_b : cs_b, + (double*)beta, + c, rs_c, + ((void*)0) + ); + } + else + { + bli_dgemv_unf_var1( + blis_transa, + bli_extract_conj(blis_transb), + k0, m0, + (double*)alpha, + (double*)a, rs_a, cs_a, + (double*)b, bli_is_notrans(blis_transb) ? rs_b : cs_b, + (double*)beta, + c, rs_c, + ((void*)0) + ); + } + + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + + return; + } + else if (m0 == 1) + { + if (bli_is_notrans(blis_transb)) + { + bli_dgemv_unf_var1( + blis_transb, + bli_extract_conj(blis_transa), + n0, k0, + (double*)alpha, + (double*)b, cs_b, rs_b, + (double*)a, bli_is_notrans(blis_transa) ? cs_a : rs_a, + (double*)beta, + c, cs_c, + ((void*)0) + ); + } + else + { + bli_dgemv_unf_var2( + blis_transb, + bli_extract_conj(blis_transa), + k0, n0, + (double*)alpha, + (double*)b, cs_b, rs_b, + (double*)a, bli_is_notrans(blis_transa) ? 
cs_a : rs_a, + (double*)beta, + c, cs_c, + ((void*)0) + ); + } + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + return; + } + + const num_t dt = BLIS_DOUBLE; + + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; + obj_t ao = BLIS_OBJECT_INITIALIZER; + obj_t bo = BLIS_OBJECT_INITIALIZER; + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; + obj_t co = BLIS_OBJECT_INITIALIZER; + + dim_t m0_a, n0_a; + dim_t m0_b, n0_b; + + bli_set_dims_with_trans(blis_transa, m0, k0, &m0_a, &n0_a); + bli_set_dims_with_trans(blis_transb, k0, n0, &m0_b, &n0_b); + + bli_obj_init_finish_1x1(dt, (double*)alpha, &alphao); + bli_obj_init_finish_1x1(dt, (double*)beta, &betao); + + bli_obj_init_finish(dt, m0_a, n0_a, (double*)a, rs_a, cs_a, &ao); + bli_obj_init_finish(dt, m0_b, n0_b, (double*)b, rs_b, cs_b, &bo); + bli_obj_init_finish(dt, m0, n0, (double*)c, rs_c, cs_c, &co); + + bli_obj_set_conjtrans(blis_transa, &ao); + bli_obj_set_conjtrans(blis_transb, &bo); + + //cntx_t* cntx = bli_gks_query_cntx(); + //dim_t nt = bli_thread_get_num_threads(); // get number of threads + bool nt = bli_thread_get_is_parallel(); // Check if parallel dgemm is invoked. + + // if m0 is large and (n0 & k0) < 10 - SMALL GEMM - ST is better + // + +#ifdef AOCL_DYNAMIC + if (nt && ((n0 > 10 ) || (k0 > 10)) ) +#else + if (nt) +#endif + { + // Will call parallelized dgemm code - sup & native + PASTEMAC(gemm, BLIS_OAPI_EX_SUF) + ( + &alphao, + &ao, + &bo, + &betao, + &co, + NULL, + NULL + ); + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS. */ + bli_finalize_auto(); + return; + } + + // The code below will be called when number of threads = 1. + +#ifdef BLIS_ENABLE_SMALL_MATRIX + + //if( ((m0 + n0 -k0) < 2000) && ((m0 + k0-n0) < 2000) && ((n0 + k0-m0) < 2000) && (n0 > 2)) + if( ( ( (m0 + n0 -k0) < 2000) && ((m0 + k0-n0) < 2000) && ((n0 + k0-m0) < 2000) ) || + ((n0 <= 10) && (k0 <=10)) ) + { + err_t status; + if (bli_is_notrans(blis_transa)) + { + status = bli_dgemm_small( &alphao, + &ao, + &bo, + &betao, + &co, + NULL, //cntx, + NULL + ); + } + else + { + status = bli_dgemm_small_At ( &alphao, + &ao, + &bo, + &betao, + &co, + NULL, //cntx, + NULL + ); + } + + if (status == BLIS_SUCCESS) + { + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS. */ + bli_finalize_auto(); + + return; + } + } + +#endif //#ifdef BLIS_ENABLE_SMALL_MATRIX + + err_t status = bli_gemmsup(&alphao, &ao, &bo, &betao, &co, NULL, NULL); + if (status == BLIS_SUCCESS) + { + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + return; + } + + // fall back on native path when dgemm is not handled in sup path. + bli_gemmnat(&alphao, &ao, &bo, &betao, &co, NULL, NULL); + + + /* PASTEMAC(gemm, BLIS_OAPI_EX_SUF) */ + /* ( */ + /* &alphao, */ + /* &ao, */ + /* &bo, */ + /* &betao, */ + /* &co, */ + /* NULL, */ + /* NULL */ + /* ); */ + + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS. */ + bli_finalize_auto(); +} // end of dgemm_ + +void zgemm_ + ( + const f77_char* transa, + const f77_char* transb, + const f77_int* m, + const f77_int* n, + const f77_int* k, + const dcomplex* alpha, + const dcomplex* a, const f77_int* lda, + const dcomplex* b, const f77_int* ldb, + const dcomplex* beta, + dcomplex* c, const f77_int* ldc + ) +{ + trans_t blis_transa; + trans_t blis_transb; + dim_t m0, n0, k0; + + /* Initialize BLIS. 
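+      Further down, this wrapper chooses a path roughly as follows (a hedged
+      summary, not the literal code; argument lists are elided):
+
+          bool  mt     = bli_thread_get_is_parallel();
+          dim_t single = bli_env_get_var( "BLIS_SINGLE_INSTANCE", -1 );
+
+          if      ( mt )          PASTEMAC(gemm, BLIS_OAPI_EX_SUF)( ... );  // parallel sup + native
+          else if ( single == 1 ) bli_gemmsup( ... );                       // try sup, fall back to native
+          else                    bli_gemmnat( ... );                       // default: native path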
*/ + bli_init_auto(); + + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) + AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(z), *transa, *transb, *m, *n, *k, + (void*)alpha, *lda, *ldb, (void*)beta, *ldc); + + /* Perform BLAS parameter checking. */ + PASTEBLACHK(gemm) + ( + MKSTR(z), + MKSTR(gemm), + transa, + transb, + m, + n, + k, + lda, + ldb, + ldc + ); + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); + bli_param_map_netlib_to_blis_trans( *transb, &blis_transb ); + + /* Typecast BLAS integers to BLIS integers. */ + bli_convert_blas_dim1( *m, m0 ); + bli_convert_blas_dim1( *n, n0 ); + bli_convert_blas_dim1( *k, k0 ); + + /* Set the row and column strides of the matrix operands. */ + const inc_t rs_a = 1; + const inc_t cs_a = *lda; + const inc_t rs_b = 1; + const inc_t cs_b = *ldb; + const inc_t rs_c = 1; + const inc_t cs_c = *ldc; + + const num_t dt = BLIS_DCOMPLEX; + + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; + obj_t ao = BLIS_OBJECT_INITIALIZER; + obj_t bo = BLIS_OBJECT_INITIALIZER; + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; + obj_t co = BLIS_OBJECT_INITIALIZER; + + dim_t m0_a, n0_a; + dim_t m0_b, n0_b; + + bli_set_dims_with_trans( blis_transa, m0, k0, &m0_a, &n0_a ); + bli_set_dims_with_trans( blis_transb, k0, n0, &m0_b, &n0_b ); + + bli_obj_init_finish_1x1( dt, (dcomplex*)alpha, &alphao ); + bli_obj_init_finish_1x1( dt, (dcomplex*)beta, &betao ); + + bli_obj_init_finish( dt, m0_a, n0_a, (dcomplex*)a, rs_a, cs_a, &ao ); + bli_obj_init_finish( dt, m0_b, n0_b, (dcomplex*)b, rs_b, cs_b, &bo ); + bli_obj_init_finish( dt, m0, n0, (dcomplex*)c, rs_c, cs_c, &co ); + + bli_obj_set_conjtrans( blis_transa, &ao ); + bli_obj_set_conjtrans( blis_transb, &bo ); + + // default instance peformance tuning is done in zgemm. + // Single instance tuning is done based on env set. + dim_t single_instance = bli_env_get_var( "BLIS_SINGLE_INSTANCE", -1 ); + + //dim_t nt = bli_thread_get_num_threads(); // get number of threads + bool nt = bli_thread_get_is_parallel(); // Check if parallel zgemm is invoked. + if ( nt ) + { + // Will call parallelized zgemm code - sup & native + PASTEMAC(gemm, BLIS_OAPI_EX_SUF) + ( + &alphao, + &ao, + &bo, + &betao, + &co, + NULL, + NULL + ); + + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS. */ + bli_finalize_auto(); + return; + } + + // The code below will be called when number of threads = 1. +#if ENABLE_INDUCED_METHOD + /* 3m_sqp is optimal for certain matrix shapes. + Initial study that it works well for square sizes and sizes closer to square shape. + + * Usage of 3m_sqp is restricted to sizes, where it is found efficient compared to native, sup and other induced method. + * Further investigation is necessary to make the usage choices more generic. */ + bool sqp_on = false; + if( (m0 == n0 ) && ( n0 == k0 ) && ( m0 == 128 ) ) + { + sqp_on = true; + } + + // current range of sizes used for 3m_sqp to be expaned after evaluation. + if( ( m0 >= 4200) && ( m0 <= 4600 ) && ( ( n0 >= 326 ) || (n0 <= 1600 ) ) + && ( k0 == 1120 ) ) //to be tuned further. 
+ { + sqp_on = true; + } + + if( ( blis_transb == BLIS_NO_TRANSPOSE) && ( sqp_on == true ) ) + { + //sqp algo is found better for n > 40 + if(bli_gemm_sqp(&alphao, &ao, &bo, &betao, &co, NULL, NULL)==BLIS_SUCCESS) + { + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) + return; + } + } +#endif//ENABLE_INDUCED_METHOD + +// native tuning resulted in better numbers compared to sup in constrained multi-instance +// sup has been enabled for single instance cases. + if(single_instance==1) + { + err_t status = bli_gemmsup(&alphao, &ao, &bo, &betao, &co, NULL, NULL); + if(status==BLIS_SUCCESS) + { + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) + return; + } + + } + // fall back on native path when zgemm is not handled in sup path. + bli_gemmnat(&alphao, &ao, &bo, &betao, &co, NULL, NULL); + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) + return; + + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) + /* Finalize BLIS. */ + bli_finalize_auto(); +}// end of zgemm_ + + +INSERT_GENTFUNC_BLAS_SC( gemm, gemm ) + + +// Observed a regression in dgemm with this function addition. +// Disabling temporarily. +#if 0 +void dzgemm_ + ( + const f77_char* transa, + const f77_char* transb, + const f77_int* m, + const f77_int* n, + const f77_int* k, + const dcomplex* alpha, + const double* a, const f77_int* lda, + const dcomplex* b, const f77_int* ldb, + const dcomplex* beta, + dcomplex* c, const f77_int* ldc + ) +{ + + trans_t blis_transa; + trans_t blis_transb; + dim_t m0, n0, k0; + + /* Initialize BLIS. */ + bli_init_auto(); + + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) + AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(z), *transa, *transb, *m, *n, *k, + (void*)alpha, *lda, *ldb, (void*)beta, *ldc); + + /* Perform BLAS parameter checking. */ + PASTEBLACHK(gemm) + ( + MKSTR(z), + MKSTR(gemm), + transa, + transb, + m, + n, + k, + lda, + ldb, + ldc + ); + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); + bli_param_map_netlib_to_blis_trans( *transb, &blis_transb ); + + /* Typecast BLAS integers to BLIS integers. */ + bli_convert_blas_dim1( *m, m0 ); + bli_convert_blas_dim1( *n, n0 ); + bli_convert_blas_dim1( *k, k0 ); + + /* Set the row and column strides of the matrix operands. 
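+      For context (hedged, since this routine is currently compiled out):
+      dzgemm_ computes C := alpha * op(A) * op(B) + beta * C where A is stored
+      as real double precision and B, C, alpha and beta are double complex;
+      accordingly, A is initialized below with dt_a = BLIS_DOUBLE and the
+      mixed-domain product is routed to the native path via bli_gemmnat.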
*/ + const inc_t rs_a = 1; + const inc_t cs_a = *lda; + const inc_t rs_b = 1; + const inc_t cs_b = *ldb; + const inc_t rs_c = 1; + const inc_t cs_c = *ldc; + + const num_t dt = BLIS_DCOMPLEX; + const num_t dt_a = BLIS_DOUBLE; + + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; + obj_t ao = BLIS_OBJECT_INITIALIZER; + obj_t bo = BLIS_OBJECT_INITIALIZER; + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; + obj_t co = BLIS_OBJECT_INITIALIZER; + + dim_t m0_a, n0_a; + dim_t m0_b, n0_b; + + bli_set_dims_with_trans( blis_transa, m0, k0, &m0_a, &n0_a ); + bli_set_dims_with_trans( blis_transb, k0, n0, &m0_b, &n0_b ); + + bli_obj_init_finish_1x1( dt, (dcomplex*)alpha, &alphao ); + bli_obj_init_finish_1x1( dt, (dcomplex*)beta, &betao ); + + bli_obj_init_finish( dt_a, m0_a, n0_a, (dcomplex*)a, rs_a, cs_a, &ao ); + bli_obj_init_finish( dt, m0_b, n0_b, (dcomplex*)b, rs_b, cs_b, &bo ); + bli_obj_init_finish( dt, m0, n0, (dcomplex*)c, rs_c, cs_c, &co ); + + bli_obj_set_conjtrans( blis_transa, &ao ); + bli_obj_set_conjtrans( blis_transb, &bo ); + + // fall back on native path when zgemm is not handled in sup path. + bli_gemmnat(&alphao, &ao, &bo, &betao, &co, NULL, NULL); + + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) + /* Finalize BLIS. */ + bli_finalize_auto(); +}// end of dzgemm_ +#endif +#endif diff --git a/frame/compat/bla_gemv.c b/frame/compat/bla_gemv.c index af2745ca98..9dba1b43c4 100644 --- a/frame/compat/bla_gemv.c +++ b/frame/compat/bla_gemv.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 21, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 22, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -147,856 +147,5 @@ void PASTEF77(ch,blasname) \ #ifdef BLIS_ENABLE_BLAS -#ifdef BLIS_CONFIG_EPYC -void dgemv_ - ( - const f77_char* transa, - const f77_int* m, - const f77_int* n, - const double* alpha, - const double* a, const f77_int* lda, - const double* x, const f77_int* incx, - const double* beta, - double* y, const f77_int* incy - ) -{ - trans_t blis_transa; - dim_t m0, n0; - dim_t m_y, n_x; - double* x0; - double* y0; - inc_t incx0; - inc_t incy0; - inc_t rs_a, cs_a; - - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); - AOCL_DTL_LOG_GEMV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'D', *transa, *m, *n, (void*)alpha, *lda, *incx, (void*)beta, *incy); - - /* Perform BLAS parameter checking. */ - PASTEBLACHK(gemv) - ( - MKSTR(d), - MKSTR(gemv), - transa, - m, - n, - lda, - incx, - incy - ); - - if (*m == 0 || *n == 0) - { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return; - } - - /* Map BLAS chars to their corresponding BLIS enumerated type value. */ - if ( *transa == 'n' || *transa == 'N' ) blis_transa = BLIS_NO_TRANSPOSE; - else if ( *transa == 't' || *transa == 'T' ) blis_transa = BLIS_TRANSPOSE; - else if ( *transa == 'c' || *transa == 'C' ) blis_transa = BLIS_CONJ_TRANSPOSE; - else - { - // See comment for bli_param_map_netlib_to_blis_side() above. - //bli_check_error_code( BLIS_INVALID_TRANS ); - blis_transa = BLIS_NO_TRANSPOSE; - } - - /* Convert/typecast negative values of m and n to zero. 
*/ - if ( *m < 0 ) m0 = ( dim_t )0; - else m0 = ( dim_t )(*m); - - if ( *n < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(*n); - - /* Determine the dimensions of x and y so we can adjust the increments, - if necessary.*/ - if ( bli_does_notrans( blis_transa ) ) - { - m_y = m0; - n_x = n0; - } - else - { - m_y = n0; - n_x = m0; - } - - /* BLAS handles cases where trans(A) has no columns, and x has no elements, - in a peculiar way. In these situations, BLAS returns without performing - any action, even though most sane interpretations of gemv would have the - the operation reduce to y := beta * y. Here, we catch those cases that - BLAS would normally mishandle and emulate the BLAS exactly so as to - provide "bug-for-bug" compatibility. Note that this extreme level of - compatibility would not be as much of an issue if it weren't for the - fact that some BLAS test suites actually test for these cases. Also, it - should be emphasized that BLIS, if called natively, does NOT exhibit - this quirky behavior; it will scale y by beta, as one would expect. */ - if ( m_y > 0 && n_x == 0 ) - { - /* Finalize BLIS. */ - // bli_finalize_auto(); - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return; - } - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - if ( *incx < 0 ) - { - x0 = ((double*)x) + (n_x-1)*(-*incx); - incx0 = ( inc_t )(*incx); - } - else - { - x0 = ((double*)x); - incx0 = ( inc_t )(*incx); - } - - if ( *incy < 0 ) - { - y0 = ((double*)y) + (m_y-1)*(-*incy); - incy0 = ( inc_t )(*incy); - } - else - { - y0 = ((double*)y); - incy0 = ( inc_t )(*incy); - } - - /* Set the row and column strides of A. */ - rs_a = 1; - cs_a = *lda; - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN4) || - (id == BLIS_ARCH_ZEN3) || - (id == BLIS_ARCH_ZEN2) || - (id == BLIS_ARCH_ZEN); - - if (bamdzen == 0) - { - /* Call BLIS interface. */ - PASTEMAC2(d,gemv,BLIS_TAPI_EX_SUF) - ( - blis_transa, - BLIS_NO_CONJUGATE, - m0, - n0, - (double*)alpha, - (double*)a, rs_a, cs_a, - x0, incx0, - (double*)beta, - y0, incy0, - NULL, - NULL - ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return; - } - - /* Call variants based on transpose value. 
*/ - if(bli_does_notrans(blis_transa)) - { - //variant_2 is chosen for column-storage - // and uses axpyf-based implementation - bli_dgemv_unf_var2 - ( - blis_transa, - BLIS_NO_CONJUGATE, - m0, - n0, - (double*)alpha, - (double*)a, rs_a, cs_a, - x0, incx0, - (double*)beta, - y0, incy0, - NULL - ); - } - else - { - //var_1 is chosen for row-storage - //and uses dotxf-based implementation - bli_dgemv_unf_var1 - ( - blis_transa, - BLIS_NO_CONJUGATE, - m0, - n0, - (double*)alpha, - (double*)a, rs_a, cs_a, - x0, incx0, - (double*)beta, - y0, incy0, - NULL - ); - } - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); -} - -void sgemv_ - ( - const f77_char* transa, - const f77_int* m, - const f77_int* n, - const float* alpha, - const float* a, const f77_int* lda, - const float* x, const f77_int* incx, - const float* beta, - float* y, const f77_int* incy - ) -{ - trans_t blis_transa; - dim_t m0, n0; - dim_t m_y, n_x; - float* x0; - float* y0; - inc_t incx0; - inc_t incy0; - inc_t rs_a, cs_a; - - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); - AOCL_DTL_LOG_GEMV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'S', *transa, *m, *n, (void*)alpha, *lda, *incx, (void*)beta, *incy); - /* Perform BLAS parameter checking. */ - PASTEBLACHK(gemv) - ( - MKSTR(s), - MKSTR(gemv), - transa, - m, - n, - lda, - incx, - incy - ); - - if (*m == 0 || *n == 0) - { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return; - } - - /* Map BLAS chars to their corresponding BLIS enumerated type value. */ - if ( *transa == 'n' || *transa == 'N' ) blis_transa = BLIS_NO_TRANSPOSE; - else if ( *transa == 't' || *transa == 'T' ) blis_transa = BLIS_TRANSPOSE; - else if ( *transa == 'c' || *transa == 'C' ) blis_transa = BLIS_CONJ_TRANSPOSE; - else - { - // See comment for bli_param_map_netlib_to_blis_side() above. - //bli_check_error_code( BLIS_INVALID_TRANS ); - blis_transa = BLIS_NO_TRANSPOSE; - } - - /* Convert/typecast negative values of m and n to zero. */ - if ( *m < 0 ) m0 = ( dim_t )0; - else m0 = ( dim_t )(*m); - - if ( *n < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(*n); - - /* Determine the dimensions of x and y so we can adjust the increments, - if necessary.*/ - if ( bli_does_notrans( blis_transa ) ) - { - m_y = m0; - n_x = n0; - } - else - { - m_y = n0; - n_x = m0; - } - - /* BLAS handles cases where trans(A) has no columns, and x has no elements, - in a peculiar way. In these situations, BLAS returns without performing - any action, even though most sane interpretations of gemv would have the - the operation reduce to y := beta * y. Here, we catch those cases that - BLAS would normally mishandle and emulate the BLAS exactly so as to - provide "bug-for-bug" compatibility. Note that this extreme level of - compatibility would not be as much of an issue if it weren't for the - fact that some BLAS test suites actually test for these cases. Also, it - should be emphasized that BLIS, if called natively, does NOT exhibit - this quirky behavior; it will scale y by beta, as one would expect. */ - if ( m_y > 0 && n_x == 0 ) - { - /* Finalize BLIS. */ - // bli_finalize_auto(); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return; - } - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. 
*/ - if ( *incx < 0 ) - { - x0 = ((float*)x) + (n_x-1)*(-*incx); - incx0 = ( inc_t )(*incx); - } - else - { - x0 = ((float*)x); - incx0 = ( inc_t )(*incx); - } - - if ( *incy < 0 ) - { - y0 = ((float*)y) + (m_y-1)*(-*incy); - incy0 = ( inc_t )(*incy); - } - else - { - y0 = ((float*)y); - incy0 = ( inc_t )(*incy); - } - - /* Set the row and column strides of A. */ - rs_a = 1; - cs_a = *lda; - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN4) || - (id == BLIS_ARCH_ZEN3) || - (id == BLIS_ARCH_ZEN2) || - (id == BLIS_ARCH_ZEN); - - if (bamdzen == 0) - { - /* Call BLIS interface. */ - PASTEMAC2(s,gemv,BLIS_TAPI_EX_SUF) - ( - blis_transa, - BLIS_NO_CONJUGATE, - m0, - n0, - (float*)alpha, - (float*)a, rs_a, cs_a, - x0, incx0, - (float*)beta, - y0, incy0, - NULL, - NULL - ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return; - } - - /* Call variants based on transpose value. */ - if(bli_does_notrans(blis_transa)) - { - bli_sgemv_unf_var2 - ( - blis_transa, - BLIS_NO_CONJUGATE, - m0, - n0, - (float*)alpha, - (float*)a, rs_a, cs_a, - x0, incx0, - (float*)beta, - y0, incy0, - NULL - ); - } - else - { - bli_sgemv_unf_var1 - ( - blis_transa, - BLIS_NO_CONJUGATE, - m0, - n0, - (float*)alpha, - (float*)a, rs_a, cs_a, - x0, incx0, - (float*)beta, - y0, incy0, - NULL - ); - } - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); -} - - -void cgemv_ - ( - const f77_char* transa, - const f77_int* m, - const f77_int* n, - const scomplex* alpha, - const scomplex* a, const f77_int* lda, - const scomplex* x, const f77_int* incx, - const scomplex* beta, - scomplex* y, const f77_int* incy - ) -{ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); - AOCL_DTL_LOG_GEMV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'C', *transa, *m, *n, (void*)alpha, *lda, *incx, (void*)beta, *incy); - - trans_t blis_transa; - dim_t m0, n0; - dim_t m_y, n_x; - scomplex* x0; - scomplex* y0; - inc_t incx0; - inc_t incy0; - inc_t rs_a, cs_a; - - /* Perform BLAS parameter checking. */ - PASTEBLACHK(gemv) - ( - MKSTR(c), - MKSTR(gemv), - transa, - m, - n, - lda, - incx, - incy - ); - - if (*m == 0 || *n == 0) - { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return; - } - - /* Map BLAS chars to their corresponding BLIS enumerated type value. */ - if( *transa == 'n' || *transa == 'N' ) blis_transa = BLIS_NO_TRANSPOSE; - else if( *transa == 't' || *transa == 'T' ) blis_transa = BLIS_TRANSPOSE; - else if( * transa == 'c' || *transa == 'C' ) blis_transa = BLIS_CONJ_TRANSPOSE; - else - { - // See comment for bli_param_map_netlib_to_blis_side() above. - // bli_check_error_code( BLIS_INVALID_TRANS ); - blis_transa = BLIS_NO_TRANSPOSE; - } - - /* Convert/typecast negative values of m and n to zero. */ - if( *m < 0 ) m0 = (dim_t)0; - else m0 = (dim_t)(*m); - - if( *n < 0 ) n0 = (dim_t)0; - else n0 = (dim_t)(*n); - - /* Determine the dimensions of x and y so we can adjust the increments, - if necessary.*/ - if( bli_does_notrans( blis_transa ) ) { m_y = m0, n_x = n0; } - else { m_y = n0; n_x = m0; } - - /* BLAS handles cases where trans(A) has no columns, and x has no elements, - in a peculiar way. 
In these situations, BLAS returns without performing - any action, even though most sane interpretations of gemv would have the - the operation reduce to y := beta * y. Here, we catch those cases that - BLAS would normally mishandle and emulate the BLAS exactly so as to - provide "bug-for-bug" compatibility. Note that this extreme level of - compatibility would not be as much of an issue if it weren't for the - fact that some BLAS test suites actually test for these cases. Also, it - should be emphasized that BLIS, if called natively, does NOT exhibit - this quirky behavior; it will scale y by beta, as one would expect. */ - - if ( m_y > 0 && n_x == 0 ) - { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return; - } - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - if( *incx < 0 ) - { - x0 = ((scomplex*)x) + (n_x-1)*(-*incx); - incx0 = ( inc_t )(*incx); - } - else - { - x0 = ((scomplex*)x); - incx0 = (inc_t)(*incx); - } - - if ( *incy < 0 ) - { - y0 = ((scomplex*)y) + (m_y-1)*(-*incy); - incy0 = ( inc_t )(*incy); - } - else - { - y0 = ((scomplex*)y); - incy0 = ( inc_t )(*incy); - } - - /* Set the row and column strides of A. */ - rs_a = 1; - cs_a = *lda; - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN4) || - (id == BLIS_ARCH_ZEN3) || - (id == BLIS_ARCH_ZEN2) || - (id == BLIS_ARCH_ZEN); - - if( m_y == 1 ) - { - conj_t conja = bli_extract_conj(blis_transa); - scomplex rho; - if (bamdzen) - { - bli_cdotv_zen_int5 - ( - conja, - BLIS_NO_CONJUGATE, - n_x, - (scomplex*)a, bli_is_notrans(blis_transa)?cs_a:rs_a, - x0, incx0, - &rho, - NULL - ); - } - else - { - /* Call BLIS interface. */ - PASTEMAC2(c,dotv,BLIS_TAPI_EX_SUF) - ( - conja, - BLIS_NO_CONJUGATE, - n_x, - (scomplex*)a, bli_is_notrans(blis_transa)?cs_a:rs_a, - x0, incx0, - &rho, - NULL, - NULL - ); - } - - scomplex yval = *y0; - if(!bli_ceq0(*beta)) - { - bli_cscals( *beta, yval ); - } - else - { - bli_csetsc( 0.0, 0.0, &yval); - } - if(!bli_ceq0(*alpha)) - { - bli_caxpys( *alpha, rho, yval); - } - y0->real = yval.real; - y0->imag = yval.imag; - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return; - } - - if (bamdzen == 0) - { - /* Call BLIS interface. 
*/ - PASTEMAC2(c,gemv,BLIS_TAPI_EX_SUF) - ( - blis_transa, - BLIS_NO_CONJUGATE, - m0, - n0, - (scomplex*)alpha, - (scomplex*)a, rs_a, cs_a, - x0, incx0, - (scomplex*)beta, - y0, incy0, - NULL, - NULL - ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return; - } - - /* call variants based on transpose value */ - if( bli_does_notrans( blis_transa ) ) - { - bli_cgemv_unf_var2 - ( - blis_transa, - BLIS_NO_CONJUGATE, - m0, - n0, - (scomplex*)alpha, - (scomplex*)a, rs_a, cs_a, - x0, incx0, - (scomplex*)beta, - y0, incy0, - NULL - ); - } - else - { - bli_cgemv_unf_var1 - ( - blis_transa, - BLIS_NO_CONJUGATE, - m0, - n0, - (scomplex*)alpha, - (scomplex*)a, rs_a, cs_a, - x0, incx0, - (scomplex*)beta, - y0, incy0, - NULL - ); - } - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); -} - - -void zgemv_ - ( - const f77_char* transa, - const f77_int* m, - const f77_int* n, - const dcomplex* alpha, - const dcomplex* a, const f77_int* lda, - const dcomplex* x, const f77_int* incx, - const dcomplex* beta, - dcomplex* y, const f77_int* incy - ) -{ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); - AOCL_DTL_LOG_GEMV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'Z', *transa, *m, *n, (void*)alpha, *lda, *incx, (void*)beta, *incy); - - trans_t blis_transa; - dim_t m0, n0; - dim_t m_y, n_x; - dcomplex* x0; - dcomplex* y0; - inc_t incx0; - inc_t incy0; - inc_t rs_a, cs_a; - - /* Perform BLAS parameter checking. */ - PASTEBLACHK(gemv) - ( - MKSTR(z), - MKSTR(gemv), - transa, - m, - n, - lda, - incx, - incy - ); - - if (*m == 0 || *n == 0) - { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return; - } - - /* Map BLAS chars to their corresponding BLIS enumerated type value. */ - if( *transa == 'n' || *transa == 'N' ) blis_transa = BLIS_NO_TRANSPOSE; - else if( *transa == 't' || *transa == 'T' ) blis_transa = BLIS_TRANSPOSE; - else if( * transa == 'c' || *transa == 'C' ) blis_transa = BLIS_CONJ_TRANSPOSE; - else - { - // See comment for bli_param_map_netlib_to_blis_side() above. - // bli_check_error_code( BLIS_INVALID_TRANS ); - blis_transa = BLIS_NO_TRANSPOSE; - } - - /* Convert/typecast negative values of m and n to zero. */ - if( *m < 0 ) m0 = (dim_t)0; - else m0 = (dim_t)(*m); - - if( *n < 0 ) n0 = (dim_t)0; - else n0 = (dim_t)(*n); - - /* Determine the dimensions of x and y so we can adjust the increments, - if necessary.*/ - if( bli_does_notrans( blis_transa ) ) { m_y = m0, n_x = n0; } - else { m_y = n0; n_x = m0; } - - /* BLAS handles cases where trans(A) has no columns, and x has no elements, - in a peculiar way. In these situations, BLAS returns without performing - any action, even though most sane interpretations of gemv would have the - the operation reduce to y := beta * y. Here, we catch those cases that - BLAS would normally mishandle and emulate the BLAS exactly so as to - provide "bug-for-bug" compatibility. Note that this extreme level of - compatibility would not be as much of an issue if it weren't for the - fact that some BLAS test suites actually test for these cases. Also, it - should be emphasized that BLIS, if called natively, does NOT exhibit - this quirky behavior; it will scale y by beta, as one would expect. */ - - if ( m_y > 0 && n_x == 0 ) - { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return; - } - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. 
*/ - if( *incx < 0 ) - { - x0 = ((dcomplex*)x) + (n_x-1)*(-*incx); - incx0 = ( inc_t )(*incx); - } - else - { - x0 = ((dcomplex*)x); - incx0 = (inc_t)(*incx); - } - - if ( *incy < 0 ) - { - y0 = ((dcomplex*)y) + (m_y-1)*(-*incy); - incy0 = ( inc_t )(*incy); - } - else - { - y0 = ((dcomplex*)y); - incy0 = ( inc_t )(*incy); - } - - /* Set the row and column strides of A. */ - rs_a = 1; - cs_a = *lda; - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN4) || - (id == BLIS_ARCH_ZEN3) || - (id == BLIS_ARCH_ZEN2) || - (id == BLIS_ARCH_ZEN); - - if( m_y == 1 ) - { - conj_t conja = bli_extract_conj(blis_transa); - dcomplex rho; - - if (bamdzen) - { - bli_zdotv_zen_int5 - ( - conja, - BLIS_NO_CONJUGATE, - n_x, - (dcomplex*)a, bli_is_notrans(blis_transa)?cs_a:rs_a, - x0, incx0, - &rho, - NULL - ); - } - else - { - /* Call BLIS interface. */ - PASTEMAC2(z,dotv,BLIS_TAPI_EX_SUF) - ( - conja, - BLIS_NO_CONJUGATE, - n_x, - (dcomplex*)a, bli_is_notrans(blis_transa)?cs_a:rs_a, - x0, incx0, - &rho, - NULL, - NULL - ); - } - - dcomplex yval = *y0; - if(!bli_zeq0(*beta)) - { - bli_zscals( *beta, yval ); - } - else - { - bli_zsetsc( 0.0, 0.0, &yval); - } - if(!bli_zeq0(*alpha)) - { - bli_zaxpys( *alpha, rho, yval); - } - y0->real = yval.real; - y0->imag = yval.imag; - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return; - } - - if (bamdzen == 0) - { - /* Call BLIS interface. */ - PASTEMAC2(z,gemv,BLIS_TAPI_EX_SUF) - ( - blis_transa, - BLIS_NO_CONJUGATE, - m0, - n0, - (dcomplex*)alpha, - (dcomplex*)a, rs_a, cs_a, - x0, incx0, - (dcomplex*)beta, - y0, incy0, - NULL, - NULL - ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return; - } - - /* call variants based on transpose value */ - if( bli_does_notrans( blis_transa ) ) - { - bli_zgemv_unf_var2 - ( - blis_transa, - BLIS_NO_CONJUGATE, - m0, - n0, - (dcomplex*)alpha, - (dcomplex*)a, rs_a, cs_a, - x0, incx0, - (dcomplex*)beta, - y0, incy0, - NULL - ); - } - else - { - bli_zgemv_unf_var1 - ( - blis_transa, - BLIS_NO_CONJUGATE, - m0, - n0, - (dcomplex*)alpha, - (dcomplex*)a, rs_a, cs_a, - x0, incx0, - (dcomplex*)beta, - y0, incy0, - NULL - ); - } - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); -} - - -#else INSERT_GENTFUNC_BLAS( gemv, gemv ) #endif -#endif diff --git a/frame/compat/bla_gemv_amd.c b/frame/compat/bla_gemv_amd.c new file mode 100644 index 0000000000..354f45fe1b --- /dev/null +++ b/frame/compat/bla_gemv_amd.c @@ -0,0 +1,963 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020 - 22, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. 
+ - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + + +// +// Define BLAS-to-BLIS interfaces. +// +#undef GENTFUNC +#define GENTFUNC( ftype, ch, blasname, blisname ) \ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* transa, \ + const f77_int* m, \ + const f77_int* n, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype* x, const f77_int* incx, \ + const ftype* beta, \ + ftype* y, const f77_int* incy \ + ) \ +{ \ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); \ + AOCL_DTL_LOG_GEMV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(ch), *transa, *m, *n, (void*)alpha, *lda, *incx, (void*)beta, *incy); \ + trans_t blis_transa; \ + dim_t m0, n0; \ + dim_t m_y, n_x; \ + ftype* x0; \ + ftype* y0; \ + inc_t incx0; \ + inc_t incy0; \ + inc_t rs_a, cs_a; \ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + /* Perform BLAS parameter checking. */ \ + PASTEBLACHK(blasname) \ + ( \ + MKSTR(ch), \ + MKSTR(blasname), \ + transa, \ + m, \ + n, \ + lda, \ + incx, \ + incy \ + ); \ +\ + if (*m == 0 || *n == 0) { \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ + return; \ + } \ +\ + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ +\ + /* Convert/typecast negative values of m and n to zero. */ \ + bli_convert_blas_dim1( *m, m0 ); \ + bli_convert_blas_dim1( *n, n0 ); \ +\ + /* Determine the dimensions of x and y so we can adjust the increments, + if necessary.*/ \ + bli_set_dims_with_trans( blis_transa, m0, n0, &m_y, &n_x ); \ +\ + /* BLAS handles cases where trans(A) has no columns, and x has no elements, + in a peculiar way. In these situations, BLAS returns without performing + any action, even though most sane interpretations of gemv would have the + the operation reduce to y := beta * y. Here, we catch those cases that + BLAS would normally mishandle and emulate the BLAS exactly so as to + provide "bug-for-bug" compatibility. Note that this extreme level of + compatibility would not be as much of an issue if it weren't for the + fact that some BLAS test suites actually test for these cases. Also, it + should be emphasized that BLIS, if called natively, does NOT exhibit + this quirky behavior; it will scale y by beta, as one would expect. */ \ + if ( m_y > 0 && n_x == 0 ) \ + { \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ +\ + return; \ + } \ +\ + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. 
*/ \ + bli_convert_blas_incv( n_x, (ftype*)x, *incx, x0, incx0 ); \ + bli_convert_blas_incv( m_y, (ftype*)y, *incy, y0, incy0 ); \ +\ + /* Set the row and column strides of A. */ \ + rs_a = 1; \ + cs_a = *lda; \ +\ + /* Call BLIS interface. */ \ + PASTEMAC2(ch,blisname,BLIS_TAPI_EX_SUF) \ + ( \ + blis_transa, \ + BLIS_NO_CONJUGATE, \ + m0, \ + n0, \ + (ftype*)alpha, \ + (ftype*)a, rs_a, cs_a, \ + x0, incx0, \ + (ftype*)beta, \ + y0, incy0, \ + NULL, \ + NULL \ + ); \ +\ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ +} + + +#ifdef BLIS_ENABLE_BLAS +void dgemv_ + ( + const f77_char* transa, + const f77_int* m, + const f77_int* n, + const double* alpha, + const double* a, const f77_int* lda, + const double* x, const f77_int* incx, + const double* beta, + double* y, const f77_int* incy + ) +{ + trans_t blis_transa; + dim_t m0, n0; + dim_t m_y, n_x; + double* x0; + double* y0; + inc_t incx0; + inc_t incy0; + inc_t rs_a, cs_a; + + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); + AOCL_DTL_LOG_GEMV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'D', *transa, *m, *n, (void*)alpha, *lda, *incx, (void*)beta, *incy); + + /* Perform BLAS parameter checking. */ + PASTEBLACHK(gemv) + ( + MKSTR(d), + MKSTR(gemv), + transa, + m, + n, + lda, + incx, + incy + ); + + if (*m == 0 || *n == 0) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return; + } + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + if ( *transa == 'n' || *transa == 'N' ) blis_transa = BLIS_NO_TRANSPOSE; + else if ( *transa == 't' || *transa == 'T' ) blis_transa = BLIS_TRANSPOSE; + else if ( *transa == 'c' || *transa == 'C' ) blis_transa = BLIS_CONJ_TRANSPOSE; + else + { + // See comment for bli_param_map_netlib_to_blis_side() above. + //bli_check_error_code( BLIS_INVALID_TRANS ); + blis_transa = BLIS_NO_TRANSPOSE; + } + + /* Convert/typecast negative values of m and n to zero. */ + if ( *m < 0 ) m0 = ( dim_t )0; + else m0 = ( dim_t )(*m); + + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* Determine the dimensions of x and y so we can adjust the increments, + if necessary.*/ + if ( bli_does_notrans( blis_transa ) ) + { + m_y = m0; + n_x = n0; + } + else + { + m_y = n0; + n_x = m0; + } + + /* BLAS handles cases where trans(A) has no columns, and x has no elements, + in a peculiar way. In these situations, BLAS returns without performing + any action, even though most sane interpretations of gemv would have the + the operation reduce to y := beta * y. Here, we catch those cases that + BLAS would normally mishandle and emulate the BLAS exactly so as to + provide "bug-for-bug" compatibility. Note that this extreme level of + compatibility would not be as much of an issue if it weren't for the + fact that some BLAS test suites actually test for these cases. Also, it + should be emphasized that BLIS, if called natively, does NOT exhibit + this quirky behavior; it will scale y by beta, as one would expect. */ + if ( m_y > 0 && n_x == 0 ) + { + /* Finalize BLIS. */ + // bli_finalize_auto(); + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return; + } + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. 
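The GENTFUNC template above generates one BLAS wrapper per datatype character; the hand-written dgemv_/sgemv_/cgemv_/zgemv_ definitions in this file then replace those generated instances with architecture-aware versions. As a rough illustration of the shape the template produces for ch = d (this relies on the assumption that PASTEF77(d,gemv) pastes to dgemv_ and that BLIS_TAPI_EX_SUF expands to _ex; the helper name below is hypothetical):

    /* Approximate shape of the generated wrapper body for ch = d,
       with parameter checking and increment conversion omitted. */
    static void dgemv_template_shape( trans_t transa,
                                      dim_t m0, dim_t n0,
                                      const double* alpha,
                                      const double* a, inc_t rs_a, inc_t cs_a,
                                      double* x0, inc_t incx0,
                                      const double* beta,
                                      double* y0, inc_t incy0 )
    {
        bli_dgemv_ex( transa, BLIS_NO_CONJUGATE, m0, n0,
                      (double*)alpha, (double*)a, rs_a, cs_a,
                      x0, incx0, (double*)beta, y0, incy0,
                      NULL, NULL );
    }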
*/ + if ( *incx < 0 ) + { + x0 = ((double*)x) + (n_x-1)*(-*incx); + incx0 = ( inc_t )(*incx); + } + else + { + x0 = ((double*)x); + incx0 = ( inc_t )(*incx); + } + + if ( *incy < 0 ) + { + y0 = ((double*)y) + (m_y-1)*(-*incy); + incy0 = ( inc_t )(*incy); + } + else + { + y0 = ((double*)y); + incy0 = ( inc_t )(*incy); + } + + /* Set the row and column strides of A. */ + rs_a = 1; + cs_a = *lda; + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == FALSE) + { + /* Call BLIS interface. */ + PASTEMAC2(d,gemv,BLIS_TAPI_EX_SUF) + ( + blis_transa, + BLIS_NO_CONJUGATE, + m0, + n0, + (double*)alpha, + (double*)a, rs_a, cs_a, + x0, incx0, + (double*)beta, + y0, incy0, + NULL, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return; + } + + /* Call variants based on transpose value. */ + if(bli_does_notrans(blis_transa)) + { + //variant_2 is chosen for column-storage + // and uses axpyf-based implementation + bli_dgemv_unf_var2 + ( + blis_transa, + BLIS_NO_CONJUGATE, + m0, + n0, + (double*)alpha, + (double*)a, rs_a, cs_a, + x0, incx0, + (double*)beta, + y0, incy0, + NULL + ); + } + else + { + //var_1 is chosen for row-storage + //and uses dotxf-based implementation + bli_dgemv_unf_var1 + ( + blis_transa, + BLIS_NO_CONJUGATE, + m0, + n0, + (double*)alpha, + (double*)a, rs_a, cs_a, + x0, incx0, + (double*)beta, + y0, incy0, + NULL + ); + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); +} + +void sgemv_ + ( + const f77_char* transa, + const f77_int* m, + const f77_int* n, + const float* alpha, + const float* a, const f77_int* lda, + const float* x, const f77_int* incx, + const float* beta, + float* y, const f77_int* incy + ) +{ + trans_t blis_transa; + dim_t m0, n0; + dim_t m_y, n_x; + float* x0; + float* y0; + inc_t incx0; + inc_t incy0; + inc_t rs_a, cs_a; + + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); + AOCL_DTL_LOG_GEMV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'S', *transa, *m, *n, (void*)alpha, *lda, *incx, (void*)beta, *incy); + /* Perform BLAS parameter checking. */ + PASTEBLACHK(gemv) + ( + MKSTR(s), + MKSTR(gemv), + transa, + m, + n, + lda, + incx, + incy + ); + + if (*m == 0 || *n == 0) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return; + } + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + if ( *transa == 'n' || *transa == 'N' ) blis_transa = BLIS_NO_TRANSPOSE; + else if ( *transa == 't' || *transa == 'T' ) blis_transa = BLIS_TRANSPOSE; + else if ( *transa == 'c' || *transa == 'C' ) blis_transa = BLIS_CONJ_TRANSPOSE; + else + { + // See comment for bli_param_map_netlib_to_blis_side() above. + //bli_check_error_code( BLIS_INVALID_TRANS ); + blis_transa = BLIS_NO_TRANSPOSE; + } + + /* Convert/typecast negative values of m and n to zero. */ + if ( *m < 0 ) m0 = ( dim_t )0; + else m0 = ( dim_t )(*m); + + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* Determine the dimensions of x and y so we can adjust the increments, + if necessary.*/ + if ( bli_does_notrans( blis_transa ) ) + { + m_y = m0; + n_x = n0; + } + else + { + m_y = n0; + n_x = m0; + } + + /* BLAS handles cases where trans(A) has no columns, and x has no elements, + in a peculiar way. In these situations, BLAS returns without performing + any action, even though most sane interpretations of gemv would have the + the operation reduce to y := beta * y. 
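The comments above state the dispatch rule: bli_dgemv_unf_var2 (axpyf-flavoured, traversing A by columns) handles the no-transpose case, while bli_dgemv_unf_var1 (dotxf-flavoured, forming dot products along columns of A, i.e. rows of A transposed) handles the transposed case. The loops below are an illustrative reference for the two access patterns on a column-major A; they ignore blocking, fusing factors, conjugation and beta scaling (y is assumed already scaled by beta), and the helper names are hypothetical.

    /* var2-style (axpy-based): y += alpha * A(:,j) * x[j];
       each column of A is read with unit stride. */
    static void ref_gemv_notrans( dim_t m, dim_t n, double alpha,
                                  const double* a, inc_t cs_a,
                                  const double* x, double* y )
    {
        for ( dim_t j = 0; j < n; j++ )
        {
            const double chi = alpha * x[ j ];
            for ( dim_t i = 0; i < m; i++ )
                y[ i ] += chi * a[ i + j * cs_a ];
        }
    }

    /* var1-style (dot-based): y[j] += alpha * dot( A(:,j), x ), i.e.
       row j of A transposed; columns of A are still read contiguously. */
    static void ref_gemv_trans( dim_t m, dim_t n, double alpha,
                                const double* a, inc_t cs_a,
                                const double* x, double* y )
    {
        for ( dim_t j = 0; j < n; j++ )
        {
            double rho = 0.0;
            for ( dim_t i = 0; i < m; i++ )
                rho += a[ i + j * cs_a ] * x[ i ];
            y[ j ] += alpha * rho;
        }
    }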
Here, we catch those cases that + BLAS would normally mishandle and emulate the BLAS exactly so as to + provide "bug-for-bug" compatibility. Note that this extreme level of + compatibility would not be as much of an issue if it weren't for the + fact that some BLAS test suites actually test for these cases. Also, it + should be emphasized that BLIS, if called natively, does NOT exhibit + this quirky behavior; it will scale y by beta, as one would expect. */ + if ( m_y > 0 && n_x == 0 ) + { + /* Finalize BLIS. */ + // bli_finalize_auto(); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return; + } + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + if ( *incx < 0 ) + { + x0 = ((float*)x) + (n_x-1)*(-*incx); + incx0 = ( inc_t )(*incx); + } + else + { + x0 = ((float*)x); + incx0 = ( inc_t )(*incx); + } + + if ( *incy < 0 ) + { + y0 = ((float*)y) + (m_y-1)*(-*incy); + incy0 = ( inc_t )(*incy); + } + else + { + y0 = ((float*)y); + incy0 = ( inc_t )(*incy); + } + + /* Set the row and column strides of A. */ + rs_a = 1; + cs_a = *lda; + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == FALSE) + { + /* Call BLIS interface. */ + PASTEMAC2(s,gemv,BLIS_TAPI_EX_SUF) + ( + blis_transa, + BLIS_NO_CONJUGATE, + m0, + n0, + (float*)alpha, + (float*)a, rs_a, cs_a, + x0, incx0, + (float*)beta, + y0, incy0, + NULL, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return; + } + + /* Call variants based on transpose value. */ + if(bli_does_notrans(blis_transa)) + { + bli_sgemv_unf_var2 + ( + blis_transa, + BLIS_NO_CONJUGATE, + m0, + n0, + (float*)alpha, + (float*)a, rs_a, cs_a, + x0, incx0, + (float*)beta, + y0, incy0, + NULL + ); + } + else + { + bli_sgemv_unf_var1 + ( + blis_transa, + BLIS_NO_CONJUGATE, + m0, + n0, + (float*)alpha, + (float*)a, rs_a, cs_a, + x0, incx0, + (float*)beta, + y0, incy0, + NULL + ); + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); +} + + +void cgemv_ + ( + const f77_char* transa, + const f77_int* m, + const f77_int* n, + const scomplex* alpha, + const scomplex* a, const f77_int* lda, + const scomplex* x, const f77_int* incx, + const scomplex* beta, + scomplex* y, const f77_int* incy + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); + AOCL_DTL_LOG_GEMV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'C', *transa, *m, *n, (void*)alpha, *lda, *incx, (void*)beta, *incy); + + trans_t blis_transa; + dim_t m0, n0; + dim_t m_y, n_x; + scomplex* x0; + scomplex* y0; + inc_t incx0; + inc_t incy0; + inc_t rs_a, cs_a; + + /* Perform BLAS parameter checking. */ + PASTEBLACHK(gemv) + ( + MKSTR(c), + MKSTR(gemv), + transa, + m, + n, + lda, + incx, + incy + ); + + if (*m == 0 || *n == 0) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return; + } + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + if( *transa == 'n' || *transa == 'N' ) blis_transa = BLIS_NO_TRANSPOSE; + else if( *transa == 't' || *transa == 'T' ) blis_transa = BLIS_TRANSPOSE; + else if( * transa == 'c' || *transa == 'C' ) blis_transa = BLIS_CONJ_TRANSPOSE; + else + { + // See comment for bli_param_map_netlib_to_blis_side() above. + // bli_check_error_code( BLIS_INVALID_TRANS ); + blis_transa = BLIS_NO_TRANSPOSE; + } + + /* Convert/typecast negative values of m and n to zero. 
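A minimal caller of the wrappers defined in this file, useful for sanity-checking either dispatch path, might look like the sketch below. The file name and build setup are illustrative; it assumes blis.h declares dgemv_ when BLIS_ENABLE_BLAS is defined and uses a column-major 2x2 matrix.

    /* demo_gemv.c: y := alpha*A*x + beta*y through the BLAS interface. */
    #include <stdio.h>
    #include "blis.h"

    int main( void )
    {
        f77_char trans = 'N';
        f77_int  m = 2, n = 2, lda = 2, incx = 1, incy = 1;
        double   alpha = 1.0, beta = 0.0;
        double   A[] = { 1.0, 3.0,     /* column 0 */
                         2.0, 4.0 };   /* column 1 */
        double   x[] = { 1.0, 1.0 };
        double   y[] = { 0.0, 0.0 };

        dgemv_( &trans, &m, &n, &alpha, A, &lda, x, &incx, &beta, y, &incy );

        printf( "y = [ %g %g ]\n", y[0], y[1] );  /* expect [ 3 7 ] */
        return 0;
    }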
*/ + if( *m < 0 ) m0 = (dim_t)0; + else m0 = (dim_t)(*m); + + if( *n < 0 ) n0 = (dim_t)0; + else n0 = (dim_t)(*n); + + /* Determine the dimensions of x and y so we can adjust the increments, + if necessary.*/ + if( bli_does_notrans( blis_transa ) ) { m_y = m0, n_x = n0; } + else { m_y = n0; n_x = m0; } + + /* BLAS handles cases where trans(A) has no columns, and x has no elements, + in a peculiar way. In these situations, BLAS returns without performing + any action, even though most sane interpretations of gemv would have the + the operation reduce to y := beta * y. Here, we catch those cases that + BLAS would normally mishandle and emulate the BLAS exactly so as to + provide "bug-for-bug" compatibility. Note that this extreme level of + compatibility would not be as much of an issue if it weren't for the + fact that some BLAS test suites actually test for these cases. Also, it + should be emphasized that BLIS, if called natively, does NOT exhibit + this quirky behavior; it will scale y by beta, as one would expect. */ + + if ( m_y > 0 && n_x == 0 ) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return; + } + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + if( *incx < 0 ) + { + x0 = ((scomplex*)x) + (n_x-1)*(-*incx); + incx0 = ( inc_t )(*incx); + } + else + { + x0 = ((scomplex*)x); + incx0 = (inc_t)(*incx); + } + + if ( *incy < 0 ) + { + y0 = ((scomplex*)y) + (m_y-1)*(-*incy); + incy0 = ( inc_t )(*incy); + } + else + { + y0 = ((scomplex*)y); + incy0 = ( inc_t )(*incy); + } + + /* Set the row and column strides of A. */ + rs_a = 1; + cs_a = *lda; + + if( m_y == 1 ) + { + conj_t conja = bli_extract_conj(blis_transa); + scomplex rho; + if (bli_cpuid_is_avx_supported() == TRUE) + { + bli_cdotv_zen_int5 + ( + conja, + BLIS_NO_CONJUGATE, + n_x, + (scomplex*)a, bli_is_notrans(blis_transa)?cs_a:rs_a, + x0, incx0, + &rho, + NULL + ); + } + else + { + /* Call BLIS interface. */ + PASTEMAC2(c,dotv,BLIS_TAPI_EX_SUF) + ( + conja, + BLIS_NO_CONJUGATE, + n_x, + (scomplex*)a, bli_is_notrans(blis_transa)?cs_a:rs_a, + x0, incx0, + &rho, + NULL, + NULL + ); + } + + scomplex yval = *y0; + if(!bli_ceq0(*beta)) + { + bli_cscals( *beta, yval ); + } + else + { + bli_csetsc( 0.0, 0.0, &yval); + } + if(!bli_ceq0(*alpha)) + { + bli_caxpys( *alpha, rho, yval); + } + y0->real = yval.real; + y0->imag = yval.imag; + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return; + } + + if (bli_cpuid_is_avx_supported() == FALSE) + { + /* Call BLIS interface. 
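In the m_y == 1 branch above, the operand handed to the dot kernel is the single row of A: its elements are cs_a apart in the no-transpose case and rs_a apart otherwise, while conja carries whatever conjugation transa implies. The sketch below is a plain-C reference for the conjugated dot product that the cdotv kernel is asked to perform (illustrative helper name, not the vectorized kernel; assumes scomplex exposes real/imag fields).

    /* rho := conj?(a)^T * x over n elements, with 'a' walking by inca
       (cs_a or rs_a above) and x by incx. */
    static void ref_cdot( bool conj_a, dim_t n,
                          const scomplex* a, inc_t inca,
                          const scomplex* x, inc_t incx,
                          scomplex* rho )
    {
        scomplex acc = { 0.0f, 0.0f };

        for ( dim_t i = 0; i < n; i++ )
        {
            scomplex ai = a[ i * inca ];
            scomplex xi = x[ i * incx ];
            if ( conj_a ) ai.imag = -ai.imag;

            acc.real += ai.real * xi.real - ai.imag * xi.imag;
            acc.imag += ai.real * xi.imag + ai.imag * xi.real;
        }
        *rho = acc;
    }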
*/ + PASTEMAC2(c,gemv,BLIS_TAPI_EX_SUF) + ( + blis_transa, + BLIS_NO_CONJUGATE, + m0, + n0, + (scomplex*)alpha, + (scomplex*)a, rs_a, cs_a, + x0, incx0, + (scomplex*)beta, + y0, incy0, + NULL, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return; + } + + /* call variants based on transpose value */ + if( bli_does_notrans( blis_transa ) ) + { + bli_cgemv_unf_var2 + ( + blis_transa, + BLIS_NO_CONJUGATE, + m0, + n0, + (scomplex*)alpha, + (scomplex*)a, rs_a, cs_a, + x0, incx0, + (scomplex*)beta, + y0, incy0, + NULL + ); + } + else + { + bli_cgemv_unf_var1 + ( + blis_transa, + BLIS_NO_CONJUGATE, + m0, + n0, + (scomplex*)alpha, + (scomplex*)a, rs_a, cs_a, + x0, incx0, + (scomplex*)beta, + y0, incy0, + NULL + ); + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); +} + + +void zgemv_ + ( + const f77_char* transa, + const f77_int* m, + const f77_int* n, + const dcomplex* alpha, + const dcomplex* a, const f77_int* lda, + const dcomplex* x, const f77_int* incx, + const dcomplex* beta, + dcomplex* y, const f77_int* incy + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); + AOCL_DTL_LOG_GEMV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'Z', *transa, *m, *n, (void*)alpha, *lda, *incx, (void*)beta, *incy); + + trans_t blis_transa; + dim_t m0, n0; + dim_t m_y, n_x; + dcomplex* x0; + dcomplex* y0; + inc_t incx0; + inc_t incy0; + inc_t rs_a, cs_a; + + /* Perform BLAS parameter checking. */ + PASTEBLACHK(gemv) + ( + MKSTR(z), + MKSTR(gemv), + transa, + m, + n, + lda, + incx, + incy + ); + + if (*m == 0 || *n == 0) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return; + } + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + if( *transa == 'n' || *transa == 'N' ) blis_transa = BLIS_NO_TRANSPOSE; + else if( *transa == 't' || *transa == 'T' ) blis_transa = BLIS_TRANSPOSE; + else if( * transa == 'c' || *transa == 'C' ) blis_transa = BLIS_CONJ_TRANSPOSE; + else + { + // See comment for bli_param_map_netlib_to_blis_side() above. + // bli_check_error_code( BLIS_INVALID_TRANS ); + blis_transa = BLIS_NO_TRANSPOSE; + } + + /* Convert/typecast negative values of m and n to zero. */ + if( *m < 0 ) m0 = (dim_t)0; + else m0 = (dim_t)(*m); + + if( *n < 0 ) n0 = (dim_t)0; + else n0 = (dim_t)(*n); + + /* Determine the dimensions of x and y so we can adjust the increments, + if necessary.*/ + if( bli_does_notrans( blis_transa ) ) { m_y = m0, n_x = n0; } + else { m_y = n0; n_x = m0; } + + /* BLAS handles cases where trans(A) has no columns, and x has no elements, + in a peculiar way. In these situations, BLAS returns without performing + any action, even though most sane interpretations of gemv would have the + the operation reduce to y := beta * y. Here, we catch those cases that + BLAS would normally mishandle and emulate the BLAS exactly so as to + provide "bug-for-bug" compatibility. Note that this extreme level of + compatibility would not be as much of an issue if it weren't for the + fact that some BLAS test suites actually test for these cases. Also, it + should be emphasized that BLIS, if called natively, does NOT exhibit + this quirky behavior; it will scale y by beta, as one would expect. */ + + if ( m_y > 0 && n_x == 0 ) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return; + } + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. 
*/ + if( *incx < 0 ) + { + x0 = ((dcomplex*)x) + (n_x-1)*(-*incx); + incx0 = ( inc_t )(*incx); + } + else + { + x0 = ((dcomplex*)x); + incx0 = (inc_t)(*incx); + } + + if ( *incy < 0 ) + { + y0 = ((dcomplex*)y) + (m_y-1)*(-*incy); + incy0 = ( inc_t )(*incy); + } + else + { + y0 = ((dcomplex*)y); + incy0 = ( inc_t )(*incy); + } + + /* Set the row and column strides of A. */ + rs_a = 1; + cs_a = *lda; + + if( m_y == 1 ) + { + conj_t conja = bli_extract_conj(blis_transa); + dcomplex rho; + + if (bli_cpuid_is_avx_supported() == TRUE) + { + bli_zdotv_zen_int5 + ( + conja, + BLIS_NO_CONJUGATE, + n_x, + (dcomplex*)a, bli_is_notrans(blis_transa)?cs_a:rs_a, + x0, incx0, + &rho, + NULL + ); + } + else + { + /* Call BLIS interface. */ + PASTEMAC2(z,dotv,BLIS_TAPI_EX_SUF) + ( + conja, + BLIS_NO_CONJUGATE, + n_x, + (dcomplex*)a, bli_is_notrans(blis_transa)?cs_a:rs_a, + x0, incx0, + &rho, + NULL, + NULL + ); + } + + dcomplex yval = *y0; + if(!bli_zeq0(*beta)) + { + bli_zscals( *beta, yval ); + } + else + { + bli_zsetsc( 0.0, 0.0, &yval); + } + if(!bli_zeq0(*alpha)) + { + bli_zaxpys( *alpha, rho, yval); + } + y0->real = yval.real; + y0->imag = yval.imag; + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return; + } + + if (bli_cpuid_is_avx_supported() == FALSE) + { + /* Call BLIS interface. */ + PASTEMAC2(z,gemv,BLIS_TAPI_EX_SUF) + ( + blis_transa, + BLIS_NO_CONJUGATE, + m0, + n0, + (dcomplex*)alpha, + (dcomplex*)a, rs_a, cs_a, + x0, incx0, + (dcomplex*)beta, + y0, incy0, + NULL, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return; + } + + /* call variants based on transpose value */ + if( bli_does_notrans( blis_transa ) ) + { + bli_zgemv_unf_var2 + ( + blis_transa, + BLIS_NO_CONJUGATE, + m0, + n0, + (dcomplex*)alpha, + (dcomplex*)a, rs_a, cs_a, + x0, incx0, + (dcomplex*)beta, + y0, incy0, + NULL + ); + } + else + { + bli_zgemv_unf_var1 + ( + blis_transa, + BLIS_NO_CONJUGATE, + m0, + n0, + (dcomplex*)alpha, + (dcomplex*)a, rs_a, cs_a, + x0, incx0, + (dcomplex*)beta, + y0, incy0, + NULL + ); + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); +} + + + +#endif diff --git a/frame/compat/bla_scal.c b/frame/compat/bla_scal.c index ab63a34592..b9651577eb 100644 --- a/frame/compat/bla_scal.c +++ b/frame/compat/bla_scal.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020-21, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020-22, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -93,179 +93,5 @@ void PASTEF772(chx,cha,blasname) \ } #ifdef BLIS_ENABLE_BLAS -#ifdef BLIS_CONFIG_EPYC - -void sscal_ - ( - const f77_int* n, - const float* alpha, - float* x, const f77_int* incx - ) -{ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) - AOCL_DTL_LOG_SCAL_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'S', (void *) alpha, *n, *incx ); - dim_t n0; - float* x0; - inc_t incx0; - /* Initialize BLIS. */ - //bli_init_auto(); - - if (*n == 0 || alpha == NULL) { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return; - } - - /* Convert/typecast negative values of n to zero. */ - if ( *n < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(*n); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - if ( *incx < 0 ) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. 
(Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. */ - - x0 = (x) + (n0-1)*(-*incx); - incx0 = ( inc_t )(*incx); - - } - else - { - x0 = (x); - incx0 = ( inc_t )(*incx); - } - /* Call BLIS kernel */ - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN4) || - (id == BLIS_ARCH_ZEN3) || - (id == BLIS_ARCH_ZEN2) || - (id == BLIS_ARCH_ZEN); - - if (bamdzen) { - bli_sscalv_zen_int10 - ( - BLIS_NO_CONJUGATE, - n0, - (float *)alpha, - x0, incx0, - NULL - ); - } - else{ - PASTEMAC2(s,scalv,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE,\ - n0, \ - (float *)alpha,\ - x0, incx0,\ - NULL, \ - NULL \ - );\ - } - - /* Finalize BLIS. */ -// bli_finalize_auto(); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) -} - -void dscal_ - ( - const f77_int* n, - const double* alpha, - double* x, const f77_int* incx - ) -{ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) - AOCL_DTL_LOG_SCAL_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'D', (void *)alpha, *n, *incx ); - dim_t n0; - double* x0; - inc_t incx0; - - /* Initialize BLIS */ - //bli_init_auto(); - - if (*n == 0 || alpha == NULL) { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return; - } - - /* Convert typecast negative values of n to zero. */ - if ( *n < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(*n); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - if ( *incx < 0 ) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. (Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. */ - - x0 = (x) + (n0-1)*(-*incx); - incx0 = ( inc_t )(*incx); - - } - else - { - x0 = (x); - incx0 = ( inc_t )(*incx); - } - /* Call BLIS kernel */ - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN4) || - (id == BLIS_ARCH_ZEN3) || - (id == BLIS_ARCH_ZEN2) || - (id == BLIS_ARCH_ZEN); - - if (bamdzen){ - bli_dscalv_zen_int10 - ( - BLIS_NO_CONJUGATE, - n0, - (double*) alpha, - x0, incx0, - NULL - ); - } - else{ - PASTEMAC2(d,scalv,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE,\ - n0, \ - (double *)alpha,\ - x0, incx0,\ - NULL, \ - NULL \ - );\ - } - - /* Finalize BLIS. 
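The comment above describes the BLAS-versus-BLIS negative-stride convention in words; a small worked example makes the pointer rebasing concrete. For n = 4 and incx = -2, BLAS hands over &x[0] and expects the traversal x[6], x[4], x[2], x[0], so the wrapper rebases the pointer to the last logical element and keeps the negative stride. The sketch below uses hypothetical values and a hypothetical helper name, with the BLIS typedefs assumed in scope.

    static void ref_negative_incx_example( void )
    {
        dim_t   n0     = 4;
        f77_int incx   = -2;
        double  x[ 7 ] = { 0 };   /* the caller passes &x[0] */
        double* x0;
        inc_t   incx0;

        if ( incx < 0 ) { x0 = x + ( n0 - 1 ) * ( -incx ); incx0 = ( inc_t )incx; }
        else            { x0 = x;                          incx0 = ( inc_t )incx; }

        /* A BLIS kernel now reads x0[0], x0[incx0], x0[2*incx0], x0[3*incx0]
           = x[6], x[4], x[2], x[0]: reverse traversal with no copying. */
        ( void )x0; ( void )incx0;
    }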
*/ -// bli_finalize_auto(); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) -} - -INSERT_GENTFUNCSCAL_BLAS_CZ( scal, scalv ) -#else INSERT_GENTFUNCSCAL_BLAS( scal, scalv ) #endif -#endif diff --git a/frame/compat/bla_scal_amd.c b/frame/compat/bla_scal_amd.c new file mode 100644 index 0000000000..178776a149 --- /dev/null +++ b/frame/compat/bla_scal_amd.c @@ -0,0 +1,260 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020-22, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + + +// +// Define BLAS-to-BLIS interfaces. +// +#undef GENTFUNCSCAL +#define GENTFUNCSCAL( ftype_x, ftype_a, chx, cha, blasname, blisname ) \ +\ +void PASTEF772(chx,cha,blasname) \ + ( \ + const f77_int* n, \ + const ftype_a* alpha, \ + ftype_x* x, const f77_int* incx \ + ) \ +{ \ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) \ + dim_t n0; \ + ftype_x* x0; \ + inc_t incx0; \ + ftype_x alpha_cast; \ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + if (*n == 0 || alpha == NULL) { \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ + return ; \ + } \ +\ + /* Convert/typecast negative values of n to zero. */ \ + bli_convert_blas_dim1( *n, n0 ); \ +\ + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ \ + bli_convert_blas_incv( n0, (ftype_x*)x, *incx, x0, incx0 ); \ +\ + /* NOTE: We do not natively implement BLAS's csscal/zdscal in BLIS. + that is, we just always sub-optimally implement those cases + by casting alpha to ctype_x (potentially the complex domain) and + using the homogeneous datatype instance according to that type. */ \ + PASTEMAC2(cha,chx,copys)( *alpha, alpha_cast ); \ +\ + /* Call BLIS interface. */ \ + PASTEMAC2(chx,blisname,BLIS_TAPI_EX_SUF) \ + ( \ + BLIS_NO_CONJUGATE, \ + n0, \ + &alpha_cast, \ + x0, incx0, \ + NULL, \ + NULL \ + ); \ +\ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ + /* Finalize BLIS. 
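The NOTE in the template above about csscal/zdscal deserves a concrete illustration: those two BLAS routines scale a complex vector by a real alpha, and the compatibility layer simply widens alpha to the complex domain and reuses the complex scalv path. The reference below sketches what that cast-and-scale amounts to (hypothetical helper name, not the PASTEMAC expansion; assumes scomplex has real/imag fields).

    /* csscal-style reference: x := alpha * x with real alpha applied to a
       single-precision complex vector; alpha is widened to {alpha, 0}. */
    static void ref_csscal( dim_t n, float alpha, scomplex* x, inc_t incx )
    {
        scomplex alpha_cast = { alpha, 0.0f };

        for ( dim_t i = 0; i < n; i++ )
        {
            scomplex xi = x[ i * incx ];
            x[ i * incx ].real = alpha_cast.real * xi.real - alpha_cast.imag * xi.imag;
            x[ i * incx ].imag = alpha_cast.real * xi.imag + alpha_cast.imag * xi.real;
        }
    }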
*/ \ + bli_finalize_auto(); \ +} + +#ifdef BLIS_ENABLE_BLAS + +void sscal_ + ( + const f77_int* n, + const float* alpha, + float* x, const f77_int* incx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) + AOCL_DTL_LOG_SCAL_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'S', (void *) alpha, *n, *incx ); + dim_t n0; + float* x0; + inc_t incx0; + /* Initialize BLIS. */ + //bli_init_auto(); + + if (*n == 0 || alpha == NULL) { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return; + } + + /* Convert/typecast negative values of n to zero. */ + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + if ( *incx < 0 ) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. */ + + x0 = (x) + (n0-1)*(-*incx); + incx0 = ( inc_t )(*incx); + + } + else + { + x0 = (x); + incx0 = ( inc_t )(*incx); + } + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) { + bli_sscalv_zen_int10 + ( + BLIS_NO_CONJUGATE, + n0, + (float *)alpha, + x0, incx0, + NULL + ); + } + else{ + PASTEMAC2(s,scalv,BLIS_TAPI_EX_SUF) \ + ( \ + BLIS_NO_CONJUGATE,\ + n0, \ + (float *)alpha,\ + x0, incx0,\ + NULL, \ + NULL \ + );\ + } + + /* Finalize BLIS. */ +// bli_finalize_auto(); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) +} + +void dscal_ + ( + const f77_int* n, + const double* alpha, + double* x, const f77_int* incx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) + AOCL_DTL_LOG_SCAL_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'D', (void *)alpha, *n, *incx ); + dim_t n0; + double* x0; + inc_t incx0; + + /* Initialize BLIS */ + //bli_init_auto(); + + if (*n == 0 || alpha == NULL) { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return; + } + + /* Convert typecast negative values of n to zero. */ + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + if ( *incx < 0 ) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. 
Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. */ + + x0 = (x) + (n0-1)*(-*incx); + incx0 = ( inc_t )(*incx); + + } + else + { + x0 = (x); + incx0 = ( inc_t )(*incx); + } + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE){ + bli_dscalv_zen_int10 + ( + BLIS_NO_CONJUGATE, + n0, + (double*) alpha, + x0, incx0, + NULL + ); + } + else{ + PASTEMAC2(d,scalv,BLIS_TAPI_EX_SUF) \ + ( \ + BLIS_NO_CONJUGATE,\ + n0, \ + (double *)alpha,\ + x0, incx0,\ + NULL, \ + NULL \ + );\ + } + + /* Finalize BLIS. */ +// bli_finalize_auto(); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) +} + +INSERT_GENTFUNCSCAL_BLAS_CZ( scal, scalv ) + +#endif diff --git a/frame/compat/bla_swap.c b/frame/compat/bla_swap.c index 526414f332..d653426478 100644 --- a/frame/compat/bla_swap.c +++ b/frame/compat/bla_swap.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020-21, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020-22, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -83,198 +83,5 @@ void PASTEF77(ch,blasname) \ } #ifdef BLIS_ENABLE_BLAS -#ifdef BLIS_CONFIG_EPYC - -void sswap_ - ( - const f77_int* n, - float* x, const f77_int* incx, - float* y, const f77_int* incy - ) -{ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) - AOCL_DTL_LOG_SWAP_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'S', *n, *incx, *incy); - dim_t n0; - float* x0; - float* y0; - inc_t incx0; - inc_t incy0; - - /* Initialize BLIS. */ -// bli_init_auto(); - - /* Convert/typecast negative values of n to zero. */ - if ( *n < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(*n); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - if ( *incx < 0 ) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. (Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. 
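The sswap_/dswap_ wrappers being reworked in this file ultimately reduce to an element-wise exchange of two strided vectors. As a reference for the contract the dispatched swapv kernels are expected to satisfy (an illustrative scalar sketch with a hypothetical helper name; the real bli_?swapv_zen_int8 kernels are vectorized):

    /* Reference swap: exchange n elements of x and y, with increments
       already rebased as described in the comments above. */
    static void ref_dswap( dim_t n, double* x, inc_t incx,
                           double* y, inc_t incy )
    {
        for ( dim_t i = 0; i < n; i++ )
        {
            double tmp    = x[ i * incx ];
            x[ i * incx ] = y[ i * incy ];
            y[ i * incy ] = tmp;
        }
    }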
*/ - - x0 = (x) + (n0-1)*(-*incx); - incx0 = ( inc_t )(*incx); - - } - else - { - x0 = (x); - incx0 = ( inc_t )(*incx); - } - - if ( *incy < 0 ) - { - y0 = (y) + (n0-1)*(-*incy); - incy0 = ( inc_t )(*incy); - - } - else - { - y0 = (y); - incy0 = ( inc_t )(*incy); - } - - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN4) || - (id == BLIS_ARCH_ZEN3) || - (id == BLIS_ARCH_ZEN2) || - (id == BLIS_ARCH_ZEN); - - if (bamdzen) { -/* Call BLIS kernel */ - bli_sswapv_zen_int8 - ( - n0, - x0, incx0, - y0, incy0, - NULL - ); - } - else{ - PASTEMAC2(s,swapv,BLIS_TAPI_EX_SUF) \ - ( \ - n0, \ - x0, incx0, \ - y0, incy0, \ - NULL, \ - NULL \ - ); \ - } - - /* Finalize BLIS. */ -// bli_finalize_auto(); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) -} - -void dswap_ - ( - const f77_int* n, - double* x, const f77_int* incx, - double* y, const f77_int* incy - ) -{ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) - AOCL_DTL_LOG_SWAP_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'D', *n, *incx, *incy); - dim_t n0; - double* x0; - double* y0; - inc_t incx0; - inc_t incy0; - - /* Initialize BLIS. */ -// bli_init_auto(); - - /* Convert/typecast negative values of n to zero. */ - if ( *n < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(*n); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - if ( *incx < 0 ) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. (Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. */ - - x0 = (x) + (n0-1)*(-*incx); - incx0 = ( inc_t )(*incx); - - } - else - { - x0 = (x); - incx0 = ( inc_t )(*incx); - } - - if ( *incy < 0 ) - { - y0 = (y) + (n0-1)*(-*incy); - incy0 = ( inc_t )(*incy); - - } - else - { - y0 = (y); - incy0 = ( inc_t )(*incy); - } - - - /* Call BLIS kernel */ - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN4) || - (id == BLIS_ARCH_ZEN3) || - (id == BLIS_ARCH_ZEN2) || - (id == BLIS_ARCH_ZEN); - - if (bamdzen) { - bli_dswapv_zen_int8 - ( - n0, - x0, incx0, - y0, incy0, - NULL - ); - } - else{ - PASTEMAC2(d,swapv,BLIS_TAPI_EX_SUF) \ - ( \ - n0, \ - x0, incx0, \ - y0, incy0, \ - NULL, \ - NULL \ - ); \ - } - - /* Finalize BLIS. */ -// bli_finalize_auto(); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) -} - -INSERT_GENTFUNC_BLAS_CZ( swap, swapv ) - -#else INSERT_GENTFUNC_BLAS( swap, swapv ) #endif -#endif diff --git a/frame/compat/bla_swap_amd.c b/frame/compat/bla_swap_amd.c new file mode 100644 index 0000000000..617c78a4aa --- /dev/null +++ b/frame/compat/bla_swap_amd.c @@ -0,0 +1,268 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020-22, Advanced Micro Devices, Inc. All rights reserved. 
+ + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + + +// +// Define BLAS-to-BLIS interfaces. +// +#undef GENTFUNC +#define GENTFUNC( ftype, ch, blasname, blisname ) \ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_int* n, \ + ftype* x, const f77_int* incx, \ + ftype* y, const f77_int* incy \ + ) \ +{ \ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) \ + dim_t n0; \ + ftype* x0; \ + ftype* y0; \ + inc_t incx0; \ + inc_t incy0; \ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + /* Convert/typecast negative values of n to zero. */ \ + bli_convert_blas_dim1( *n, n0 ); \ +\ + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ \ + bli_convert_blas_incv( n0, (ftype*)x, *incx, x0, incx0 ); \ + bli_convert_blas_incv( n0, (ftype*)y, *incy, y0, incy0 ); \ +\ + /* Call BLIS interface. */ \ + PASTEMAC2(ch,blisname,BLIS_TAPI_EX_SUF) \ + ( \ + n0, \ + x0, incx0, \ + y0, incy0, \ + NULL, \ + NULL \ + ); \ +\ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ +} + +#ifdef BLIS_ENABLE_BLAS + +void sswap_ + ( + const f77_int* n, + float* x, const f77_int* incx, + float* y, const f77_int* incy + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) + AOCL_DTL_LOG_SWAP_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'S', *n, *incx, *incy); + dim_t n0; + float* x0; + float* y0; + inc_t incx0; + inc_t incy0; + + /* Initialize BLIS. */ +// bli_init_auto(); + + /* Convert/typecast negative values of n to zero. */ + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + if ( *incx < 0 ) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. 
The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. */ + + x0 = (x) + (n0-1)*(-*incx); + incx0 = ( inc_t )(*incx); + + } + else + { + x0 = (x); + incx0 = ( inc_t )(*incx); + } + + if ( *incy < 0 ) + { + y0 = (y) + (n0-1)*(-*incy); + incy0 = ( inc_t )(*incy); + + } + else + { + y0 = (y); + incy0 = ( inc_t )(*incy); + } + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) { + /* Call BLIS kernel */ + bli_sswapv_zen_int8 + ( + n0, + x0, incx0, + y0, incy0, + NULL + ); + } + else{ + PASTEMAC2(s,swapv,BLIS_TAPI_EX_SUF) \ + ( \ + n0, \ + x0, incx0, \ + y0, incy0, \ + NULL, \ + NULL \ + ); \ + } + + /* Finalize BLIS. */ +// bli_finalize_auto(); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) +} + +void dswap_ + ( + const f77_int* n, + double* x, const f77_int* incx, + double* y, const f77_int* incy + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) + AOCL_DTL_LOG_SWAP_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'D', *n, *incx, *incy); + dim_t n0; + double* x0; + double* y0; + inc_t incx0; + inc_t incy0; + + /* Initialize BLIS. */ +// bli_init_auto(); + + /* Convert/typecast negative values of n to zero. */ + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + if ( *incx < 0 ) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. */ + + x0 = (x) + (n0-1)*(-*incx); + incx0 = ( inc_t )(*incx); + + } + else + { + x0 = (x); + incx0 = ( inc_t )(*incx); + } + + if ( *incy < 0 ) + { + y0 = (y) + (n0-1)*(-*incy); + incy0 = ( inc_t )(*incy); + + } + else + { + y0 = (y); + incy0 = ( inc_t )(*incy); + } + + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) { + bli_dswapv_zen_int8 + ( + n0, + x0, incx0, + y0, incy0, + NULL + ); + } + else{ + PASTEMAC2(d,swapv,BLIS_TAPI_EX_SUF) \ + ( \ + n0, \ + x0, incx0, \ + y0, incy0, \ + NULL, \ + NULL \ + ); \ + } + + /* Finalize BLIS. 
*/ +// bli_finalize_auto(); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) +} + +INSERT_GENTFUNC_BLAS_CZ( swap, swapv ) + + +#endif diff --git a/frame/compat/bla_trsm.c b/frame/compat/bla_trsm.c index fa8f0dacd1..fea7ba6f17 100644 --- a/frame/compat/bla_trsm.c +++ b/frame/compat/bla_trsm.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019 - 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -380,1175 +380,5 @@ void PASTEF77(ch,blasname) \ #endif #ifdef BLIS_ENABLE_BLAS -#ifdef BLIS_CONFIG_EPYC - -void strsm_ -( - const f77_char* side, - const f77_char* uploa, - const f77_char* transa, - const f77_char* diaga, - const f77_int* m, - const f77_int* n, - const float* alpha, - const float* a, const f77_int* lda, - float* b, const f77_int* ldb -) -{ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO) - AOCL_DTL_LOG_TRSM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'd', - *side, *uploa,*transa, *diaga, *m, *n, - (void*)alpha,*lda, *ldb); - - side_t blis_side; - uplo_t blis_uploa; - trans_t blis_transa; - diag_t blis_diaga; - dim_t m0, n0; - conj_t conja = BLIS_NO_CONJUGATE ; - - /* Initialize BLIS. */ - bli_init_auto(); - - /* Perform BLAS parameter checking. */ - PASTEBLACHK(trsm) - ( - MKSTR(s), - MKSTR(trsm), - side, - uploa, - transa, - diaga, - m, - n, - lda, - ldb - ); - - /* Map BLAS chars to their corresponding BLIS enumerated type value. */ - bli_param_map_netlib_to_blis_side( *side, &blis_side ); - bli_param_map_netlib_to_blis_uplo( *uploa, &blis_uploa ); - bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); - bli_param_map_netlib_to_blis_diag( *diaga, &blis_diaga ); - - /* Typecast BLAS integers to BLIS integers. */ - bli_convert_blas_dim1( *m, m0 ); - bli_convert_blas_dim1( *n, n0 ); - - /* Set the row and column strides of the matrix operands. 
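The n0 == 1 and m0 == 1 branches that follow are worth spelling out: when B has a single column (left-side case), the triangular system op(A) X = alpha B is just a triangular solve of one vector, which is why the code falls through to the trsv variants; and when the triangular factor is effectively 1x1, the whole solve collapses to a scaling by alpha/a. The sketch below is a reference forward substitution for the lower, non-transposed, non-unit-diagonal, single-right-hand-side case, with the alpha pre-scaling folded in (hypothetical helper name, column-major storage assumed).

    /* Solve L * b = alpha * b in place; L is lower triangular with
       leading dimension lda, b is a single column of length m. */
    static void ref_ltrsv( dim_t m, double alpha,
                           const double* L, inc_t lda, double* b )
    {
        for ( dim_t i = 0; i < m; i++ ) b[ i ] *= alpha;

        for ( dim_t k = 0; k < m; k++ )
        {
            b[ k ] /= L[ k + k * lda ];
            for ( dim_t i = k + 1; i < m; i++ )
                b[ i ] -= L[ i + k * lda ] * b[ k ];
        }
    }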
*/ - const inc_t rs_a = 1; - const inc_t cs_a = *lda; - const inc_t rs_b = 1; - const inc_t cs_b = *ldb; - const num_t dt = BLIS_FLOAT; - - if( n0 == 1 ) - { - if( blis_side == BLIS_LEFT ) - { - if(bli_is_notrans(blis_transa)) - { - bli_strsv_unf_var2 - ( - blis_uploa, - blis_transa, - blis_diaga, - m0, - (float*)alpha, - (float*)a, rs_a, cs_a, - (float*)b, rs_b, - NULL - ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - } - else if(bli_is_trans(blis_transa)) - { - bli_strsv_unf_var1 - ( - blis_uploa, - blis_transa, - blis_diaga, - m0, - (float*)alpha, - (float*)a, rs_a, cs_a, - (float*)b, rs_b, - NULL - ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - } - } - else if( ( blis_side == BLIS_RIGHT ) && ( m0 != 1 ) ) - { - /* b = alpha * b; */ - bli_sscalv_ex - ( - conja, - m0, - (float*)alpha, - b, rs_b, - NULL, - NULL - ); - if(blis_diaga == BLIS_NONUNIT_DIAG) - { - float inva = 1.0/ *a; - for(dim_t indx = 0; indx < m0; indx ++) - { - b[indx] = ( inva * b[indx] ); - } - } - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - } - } - else if( m0 == 1 ) - { - if(blis_side == BLIS_RIGHT) - { - if(bli_is_notrans(blis_transa)) - { - if(blis_uploa == BLIS_UPPER) - blis_uploa = BLIS_LOWER; - else - blis_uploa = BLIS_UPPER; - - bli_strsv_unf_var1 - ( - blis_uploa, - blis_transa, - blis_diaga, - n0, - (float*)alpha, - (float*)a, cs_a, rs_a, - (float*)b, cs_b, - NULL - ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - } - else if(bli_is_trans(blis_transa)) - { - if(blis_uploa == BLIS_UPPER) - blis_uploa = BLIS_LOWER; - else - blis_uploa = BLIS_UPPER; - - bli_strsv_unf_var2 - ( - blis_uploa, - blis_transa, - blis_diaga, - n0, - (float*)alpha, - (float*)a, cs_a, rs_a, - (float*)b, cs_b, - NULL - ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - } - } - else if(( blis_side == BLIS_LEFT ) && ( n0 != 1 )) - { - /* b = alpha * b; */ - bli_sscalv_ex - ( - conja, - n0, - (float*)alpha, - b, cs_b, - NULL, - NULL - ); - if(blis_diaga == BLIS_NONUNIT_DIAG) - { - float inva = 1.0/ *a; - for(dim_t indx = 0; indx < n0; indx ++) - { - b[indx*cs_b] = (inva * b[indx*cs_b] ); - } - } - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - } - } - const struc_t struca = BLIS_TRIANGULAR; - - obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; - obj_t ao = BLIS_OBJECT_INITIALIZER; - obj_t bo = BLIS_OBJECT_INITIALIZER; - - dim_t mn0_a; - - bli_set_dim_with_side( blis_side, m0, n0, &mn0_a ); - - bli_obj_init_finish_1x1( dt, (float*)alpha, &alphao ); - - bli_obj_init_finish( dt, mn0_a, mn0_a, (float*)a, rs_a, cs_a, &ao ); - bli_obj_init_finish( dt, m0, n0, (float*)b, rs_b, cs_b, &bo ); - - bli_obj_set_uplo( blis_uploa, &ao ); - bli_obj_set_diag( blis_diaga, &ao ); - bli_obj_set_conjtrans( blis_transa, &ao ); - - bli_obj_set_struc( struca, &ao ); - - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN4) || - (id == BLIS_ARCH_ZEN3) || - (id == BLIS_ARCH_ZEN2) || - (id == BLIS_ARCH_ZEN); - - if (bamdzen) { -#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM - /* bli_strsm_small is performing better existing native - * implementations for [m,n]<=1000 for single thread. - * In case of multithread when [m,n]<=128 sinlge thread implemenation - * is doing better than native multithread */ - bool nt = bli_thread_get_is_parallel(); - if((nt==0 && m0<=1000 && n0<=1000) || - (nt && (m0+n0)<320) ) - { - err_t status; - status = bli_trsm_small - ( - blis_side, - &alphao, - &ao, - &bo, - NULL, - NULL - ); - if (status == BLIS_SUCCESS) - { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - /* Finalize BLIS. 
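The comment above encodes the heuristic that gates the small-TRSM path: the hand-tuned bli_trsm_small kernels win over the native blocked implementation up to roughly 1000x1000 when running single-threaded, but once multithreading is requested they are only worthwhile for much smaller problems, since they do not parallelize. Isolating the decision as a predicate (a sketch of the logic shown, with the thresholds taken from the code rather than the comment's "[m,n]<=128" figure; the helper name is hypothetical):

    /* True when the single-threaded small-TRSM kernels are expected to
       beat the native (possibly multithreaded) implementation. */
    static bool use_trsm_small( bool is_parallel, dim_t m0, dim_t n0 )
    {
        if ( !is_parallel )
            return ( m0 <= 1000 && n0 <= 1000 );
        else
            return ( ( m0 + n0 ) < 320 );
    }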
*/ - bli_finalize_auto(); - return; - } - } -#endif - } - bli_trsmnat - ( - blis_side, - &alphao, - &ao, - &bo, - NULL, - NULL - ); - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) - /* Finalize BLIS. */ - bli_finalize_auto(); -} - -void dtrsm_ -( - const f77_char* side, - const f77_char* uploa, - const f77_char* transa, - const f77_char* diaga, - const f77_int* m, - const f77_int* n, - const double* alpha, - const double* a, const f77_int* lda, - double* b, const f77_int* ldb -) -{ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO) - AOCL_DTL_LOG_TRSM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'd', - *side, *uploa,*transa, *diaga, *m, *n, - (void*)alpha,*lda, *ldb); - - side_t blis_side; - uplo_t blis_uploa; - trans_t blis_transa; - diag_t blis_diaga; - dim_t m0, n0; - conj_t conja = BLIS_NO_CONJUGATE ; - - /* Initialize BLIS. */ - bli_init_auto(); - - /* Perform BLAS parameter checking. */ - PASTEBLACHK(trsm) - ( - MKSTR(d), - MKSTR(trsm), - side, - uploa, - transa, - diaga, - m, - n, - lda, - ldb - ); - - /* Map BLAS chars to their corresponding BLIS enumerated type value. */ - bli_param_map_netlib_to_blis_side( *side, &blis_side ); - bli_param_map_netlib_to_blis_uplo( *uploa, &blis_uploa ); - bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); - bli_param_map_netlib_to_blis_diag( *diaga, &blis_diaga ); - - /* Typecast BLAS integers to BLIS integers. */ - bli_convert_blas_dim1( *m, m0 ); - bli_convert_blas_dim1( *n, n0 ); - - /* Set the row and column strides of the matrix operands. */ - const inc_t rs_a = 1; - const inc_t cs_a = *lda; - const inc_t rs_b = 1; - const inc_t cs_b = *ldb; - const num_t dt = BLIS_DOUBLE; - - if( n0 == 1 ) - { - if( blis_side == BLIS_LEFT ) - { - if(bli_is_notrans(blis_transa)) - { - bli_dtrsv_unf_var2 - ( - blis_uploa, - blis_transa, - blis_diaga, - m0, - (double*)alpha, - (double*)a, rs_a, cs_a, - (double*)b, rs_b, - NULL - ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - } - else if(bli_is_trans(blis_transa)) - { - bli_dtrsv_unf_var1 - ( - blis_uploa, - blis_transa, - blis_diaga, - m0, - (double*)alpha, - (double*)a, rs_a, cs_a, - (double*)b, rs_b, - NULL - ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - } - } - else if( ( blis_side == BLIS_RIGHT ) && ( m0 != 1 ) ) - { - /* b = alpha * b; */ - bli_dscalv_ex - ( - conja, - m0, - (double*)alpha, - b, rs_b, - NULL, - NULL - ); - if(blis_diaga == BLIS_NONUNIT_DIAG) - { - double inva = 1.0/ *a; - for(dim_t indx = 0; indx < m0; indx ++) - { - b[indx] = ( inva * b[indx] ); - } - } - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - } - } - else if( m0 == 1 ) - { - if(blis_side == BLIS_RIGHT) - { - if(bli_is_notrans(blis_transa)) - { - if(blis_uploa == BLIS_UPPER) - blis_uploa = BLIS_LOWER; - else - blis_uploa = BLIS_UPPER; - - bli_dtrsv_unf_var1 - ( - blis_uploa, - blis_transa, - blis_diaga, - n0, - (double*)alpha, - (double*)a, cs_a, rs_a, - (double*)b, cs_b, - NULL - ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - } - else if(bli_is_trans(blis_transa)) - { - if(blis_uploa == BLIS_UPPER) - blis_uploa = BLIS_LOWER; - else - blis_uploa = BLIS_UPPER; - - bli_dtrsv_unf_var2 - ( - blis_uploa, - blis_transa, - blis_diaga, - n0, - (double*)alpha, - (double*)a, cs_a, rs_a, - (double*)b, cs_b, - NULL - ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - } - } - else if(( blis_side == BLIS_LEFT ) && ( n0 != 1 )) - { - /* b = alpha * b; */ - bli_dscalv_ex - ( - conja, - n0, - (double*)alpha, - b, cs_b, - NULL, - NULL - ); - if(blis_diaga == BLIS_NONUNIT_DIAG) - { - double inva = 1.0/ *a; - 
for(dim_t indx = 0; indx < n0; indx ++) - { - b[indx*cs_b] = (inva * b[indx*cs_b] ); - } - } - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - } - } - - const struc_t struca = BLIS_TRIANGULAR; - - obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; - obj_t ao = BLIS_OBJECT_INITIALIZER; - obj_t bo = BLIS_OBJECT_INITIALIZER; - - dim_t mn0_a; - - bli_set_dim_with_side( blis_side, m0, n0, &mn0_a ); - - bli_obj_init_finish_1x1( dt, (double*)alpha, &alphao ); - - bli_obj_init_finish( dt, mn0_a, mn0_a, (double*)a, rs_a, cs_a, &ao ); - bli_obj_init_finish( dt, m0, n0, (double*)b, rs_b, cs_b, &bo ); - - bli_obj_set_uplo( blis_uploa, &ao ); - bli_obj_set_diag( blis_diaga, &ao ); - bli_obj_set_conjtrans( blis_transa, &ao ); - - bli_obj_set_struc( struca, &ao ); - - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN4) || - (id == BLIS_ARCH_ZEN3) || - (id == BLIS_ARCH_ZEN2) || - (id == BLIS_ARCH_ZEN); - - if (bamdzen) { -#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM - /* bli_dtrsm_small is performing better existing native - * implementations for [m,n]<=1000 for single thread. - * In case of multithread when [m,n]<=128 sinlge thread implemenation - * is doing better than native multithread */ - bool nt = bli_thread_get_is_parallel(); - if((nt==0 && m0<=1000 && n0<=1000) || - (nt && (m0+n0)<320) ) - { - err_t status; - status = bli_trsm_small - ( - blis_side, - &alphao, - &ao, - &bo, - NULL, - NULL - ); - if (status == BLIS_SUCCESS) - { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - /* Finalize BLIS. */ - bli_finalize_auto(); - return; - } - } -#endif - } - bli_trsmnat - ( - blis_side, - &alphao, - &ao, - &bo, - NULL, - NULL - ); - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) - /* Finalize BLIS. */ - bli_finalize_auto(); -} -#if 0 -void ztrsm_ -( - const f77_char* side, - const f77_char* uploa, - const f77_char* transa, - const f77_char* diaga, - const f77_int* m, - const f77_int* n, - const dcomplex* alpha, - const dcomplex* a, const f77_int* lda, - dcomplex* b, const f77_int* ldb -) -{ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO) - AOCL_DTL_LOG_TRSM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'z', - *side, *uploa,*transa, *diaga, *m, *n, - (void*)alpha,*lda, *ldb); - - side_t blis_side; - uplo_t blis_uploa; - trans_t blis_transa; - diag_t blis_diaga; - dim_t m0, n0; - conj_t conja = BLIS_NO_CONJUGATE; - - /* Initialize BLIS. */ - bli_init_auto(); - - /* Perform BLAS parameter checking. */ - PASTEBLACHK(trsm) - ( - MKSTR(z), - MKSTR(trsm), - side, - uploa, - transa, - diaga, - m, - n, - lda, - ldb - ); - - /* Map BLAS chars to their corresponding BLIS enumerated type value. */ - bli_param_map_netlib_to_blis_side( *side, &blis_side ); - bli_param_map_netlib_to_blis_uplo( *uploa, &blis_uploa ); - bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); - bli_param_map_netlib_to_blis_diag( *diaga, &blis_diaga ); - - /* Typecast BLAS integers to BLIS integers. */ - bli_convert_blas_dim1( *m, m0 ); - bli_convert_blas_dim1( *n, n0 ); - - /* Set the row and column strides of the matrix operands. 
*/ - const inc_t rs_a = 1; - const inc_t cs_a = *lda; - const inc_t rs_b = 1; - const inc_t cs_b = *ldb; - const num_t dt = BLIS_DCOMPLEX; - - - if( n0 == 1 ) - { - if( blis_side == BLIS_LEFT ) - { - if(bli_is_notrans(blis_transa)) - { - bli_ztrsv_unf_var2 - ( - blis_uploa, - blis_transa, - blis_diaga, - m0, - (dcomplex*)alpha, - (dcomplex*)a, rs_a, cs_a, - (dcomplex*)b, rs_b, - NULL - ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - } - else if(bli_is_trans(blis_transa)) - { - bli_ztrsv_unf_var1 - ( - blis_uploa, - blis_transa, - blis_diaga, - m0, - (dcomplex*)alpha, - (dcomplex*)a, rs_a, cs_a, - (dcomplex*)b, rs_b, - NULL - ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - } - } - else if( ( blis_side == BLIS_RIGHT ) && ( m0 != 1 ) ) - { - bli_zscalv_ex - ( - conja, - m0, - (dcomplex*)alpha, - (dcomplex*)b, rs_b, - NULL, - NULL - ); - if(blis_diaga == BLIS_NONUNIT_DIAG) - { - dcomplex inva = {1.0, 0.0}; - dcomplex a_dup; - /** - * For conjugate transpose and non-unit diagonal - * kernel, negating imaginary part of A. - * As the dimension of A is 1x1, there's going to - * be only one 1 element of A. - */ - if(*transa == 'C' && *diaga == 'N') - { - a_dup.real = a->real; - a_dup.imag = a->imag * -1.0; - } - else - { - a_dup.real = a->real; - a_dup.imag = a->imag; - } - -#ifdef BLIS_ENABLE_TRSM_PREINVERSION - bli_zinvscals(a_dup, inva); -#else - inva.real = a_dup.real; - inva.imag = a_dup.imag; -#endif - for(dim_t indx = 0; indx < m0; indx ++) - { -#ifdef BLIS_ENABLE_TRSM_PREINVERSION - bli_zscals(inva, b[indx]) -#else - - bli_zinvscals(inva, b[indx]) -#endif - } - - } - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - } - } - else if( m0 == 1 ) - { - if(blis_side == BLIS_RIGHT) - { - if(bli_is_notrans(blis_transa)) - { - if(blis_uploa == BLIS_UPPER) - blis_uploa = BLIS_LOWER; - else - blis_uploa = BLIS_UPPER; - - bli_ztrsv_unf_var1 - ( - blis_uploa, - blis_transa, - blis_diaga, - n0, - (dcomplex*)alpha, - (dcomplex*)a, cs_a, rs_a, - (dcomplex*)b, cs_b, - NULL - ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - } - else if(bli_is_trans(blis_transa)) - { - if(blis_uploa == BLIS_UPPER) - blis_uploa = BLIS_LOWER; - else - blis_uploa = BLIS_UPPER; - - bli_ztrsv_unf_var2 - ( - blis_uploa, - blis_transa, - blis_diaga, - n0, - (dcomplex*)alpha, - (dcomplex*)a, cs_a, rs_a, - (dcomplex*)b, cs_b, - NULL - ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - } - } - else if(( blis_side == BLIS_LEFT ) && ( n0 != 1 )) - { - bli_zscalv_ex - ( - conja, - n0, - (dcomplex*)alpha, - (dcomplex*)b, cs_b, - NULL, - NULL - ); - if(blis_diaga == BLIS_NONUNIT_DIAG) - { - dcomplex inva = {1.0, 0.0}; - dcomplex a_dup; - /** - * For conjugate transpose and non-unit diagonal - * kernel, negating imaginary part of A. - * As the dimension of A is 1x1, there's going to - * be only one 1 element of A. 
- */ - if(*transa == 'C' && *diaga == 'N') - { - a_dup.real = a->real; - a_dup.imag = a->imag * -1.0; - } - else - { - a_dup.real = a->real; - a_dup.imag = a->imag; - } - -#ifdef BLIS_ENABLE_TRSM_PREINVERSION - bli_zinvscals(a_dup, inva); -#else - inva.real = a_dup.real; - inva.imag = a_dup.imag; -#endif - for(dim_t indx = 0; indx < n0; indx ++) - { -#ifdef BLIS_ENABLE_TRSM_PREINVERSION - bli_zscals(inva ,b[indx * cs_b]) -#else - - bli_zinvscals(inva ,b[indx * cs_b]) -#endif - } - } - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - - } - } - - const struc_t struca = BLIS_TRIANGULAR; - - obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; - obj_t ao = BLIS_OBJECT_INITIALIZER; - obj_t bo = BLIS_OBJECT_INITIALIZER; - - dim_t mn0_a; - - bli_set_dim_with_side( blis_side, m0, n0, &mn0_a ); - - bli_obj_init_finish_1x1( dt, (dcomplex*)alpha, &alphao ); - - bli_obj_init_finish( dt, mn0_a, mn0_a, (dcomplex*)a, rs_a, cs_a, &ao ); - bli_obj_init_finish( dt, m0, n0, (dcomplex*)b, rs_b, cs_b, &bo ); - - bli_obj_set_uplo( blis_uploa, &ao ); - bli_obj_set_diag( blis_diaga, &ao ); - bli_obj_set_conjtrans( blis_transa, &ao ); - - bli_obj_set_struc( struca, &ao ); - -#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM - /* bli_ztrsm_small is performing better existing native - * implementations for [m,n]<=1000 for single thread. - * In case of multithread when [m,n]<=128 sinlge thread implemenation - * is doing better than native multithread */ - bool nt = bli_thread_get_is_parallel(); - if((nt==0 && m0<=500 && n0<=500) || - (nt && (m0+n0)<128) ) - { - err_t status; - status = bli_trsm_small - ( - blis_side, - &alphao, - &ao, - &bo, - NULL, - NULL - ); - if (status == BLIS_SUCCESS) - { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - /* Finalize BLIS. */ - bli_finalize_auto(); - return; - } - } -#endif - - bli_trsmnat - ( - blis_side, - &alphao, - &ao, - &bo, - NULL, - NULL - ); - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) - /* Finalize BLIS. */ - bli_finalize_auto(); -} -#endif -#if 0 -void ctrsm_ -( - const f77_char* side, - const f77_char* uploa, - const f77_char* transa, - const f77_char* diaga, - const f77_int* m, - const f77_int* n, - const scomplex* alpha, - const scomplex* a, const f77_int* lda, - scomplex* b, const f77_int* ldb -) -{ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO) - AOCL_DTL_LOG_TRSM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 's', - *side, *uploa,*transa, *diaga, *m, *n, - (void*)alpha,*lda, *ldb); - - side_t blis_side; - uplo_t blis_uploa; - trans_t blis_transa; - diag_t blis_diaga; - dim_t m0, n0; - conj_t conja = BLIS_NO_CONJUGATE; - - /* Initialize BLIS. */ - bli_init_auto(); - - /* Perform BLAS parameter checking. */ - PASTEBLACHK(trsm) - ( - MKSTR(c), - MKSTR(trsm), - side, - uploa, - transa, - diaga, - m, - n, - lda, - ldb - ); - - /* Map BLAS chars to their corresponding BLIS enumerated type value. */ - bli_param_map_netlib_to_blis_side( *side, &blis_side ); - bli_param_map_netlib_to_blis_uplo( *uploa, &blis_uploa ); - bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); - bli_param_map_netlib_to_blis_diag( *diaga, &blis_diaga ); - - /* Typecast BLAS integers to BLIS integers. */ - bli_convert_blas_dim1( *m, m0 ); - bli_convert_blas_dim1( *n, n0 ); - - /* Set the row and column strides of the matrix operands. 
*/ - const inc_t rs_a = 1; - const inc_t cs_a = *lda; - const inc_t rs_b = 1; - const inc_t cs_b = *ldb; - const num_t dt = BLIS_SCOMPLEX; - - - if( n0 == 1 ) - { - if( blis_side == BLIS_LEFT ) - { - if(bli_is_notrans(blis_transa)) - { - bli_ctrsv_unf_var2 - ( - blis_uploa, - blis_transa, - blis_diaga, - m0, - (scomplex*)alpha, - (scomplex*)a, rs_a, cs_a, - (scomplex*)b, rs_b, - NULL - ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - } - else if(bli_is_trans(blis_transa)) - { - bli_ctrsv_unf_var1 - ( - blis_uploa, - blis_transa, - blis_diaga, - m0, - (scomplex*)alpha, - (scomplex*)a, rs_a, cs_a, - (scomplex*)b, rs_b, - NULL - ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - } - } - else if( ( blis_side == BLIS_RIGHT ) && ( m0 != 1 ) ) - { - bli_cscalv_ex - ( - conja, - m0, - (scomplex*)alpha, - (scomplex*)b, rs_b, - NULL, - NULL - ); - if(blis_diaga == BLIS_NONUNIT_DIAG) - { - scomplex inva = {1.0, 0.0}; - scomplex a_dup; - /** - * For conjugate transpose and non-unit diagonal - * kernel, negating imaginary part of A. - * As the dimension of A is 1x1, there's going to - * be only one 1 element of A. - */ - if(*transa == 'C' && *diaga == 'N') - { - a_dup.real = a->real; - a_dup.imag = a->imag * -1.0; - } - else - { - a_dup.real = a->real; - a_dup.imag = a->imag; - } - -#ifdef BLIS_ENABLE_TRSM_PREINVERSION - bli_cinvscals(a_dup, inva); -#else - inva.real = a_dup.real; - inva.imag = a_dup.imag; -#endif - - for(dim_t indx = 0; indx < m0; indx ++) - { -#ifdef BLIS_ENABLE_TRSM_PREINVERSION - bli_cscals(inva ,b[indx]) -#else - bli_cinvscals(inva, b[indx]) -#endif - } - } - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - - } - } - else if( m0 == 1 ) - { - if(blis_side == BLIS_RIGHT) - { - if(bli_is_notrans(blis_transa)) - { - if(blis_uploa == BLIS_UPPER) - blis_uploa = BLIS_LOWER; - else - blis_uploa = BLIS_UPPER; - - bli_ctrsv_unf_var1 - ( - blis_uploa, - blis_transa, - blis_diaga, - n0, - (scomplex*)alpha, - (scomplex*)a, cs_a, rs_a, - (scomplex*)b, cs_b, - NULL - ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - } - else if(bli_is_trans(blis_transa)) - { - if(blis_uploa == BLIS_UPPER) - blis_uploa = BLIS_LOWER; - else - blis_uploa = BLIS_UPPER; - - bli_ctrsv_unf_var2 - ( - blis_uploa, - blis_transa, - blis_diaga, - n0, - (scomplex*)alpha, - (scomplex*)a, cs_a, rs_a, - (scomplex*)b, cs_b, - NULL - ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - } - } - else if(( blis_side == BLIS_LEFT ) && ( n0 != 1 )) - { - bli_cscalv_ex - ( - conja, - n0, - (scomplex*)alpha, - (scomplex*)b, cs_b, - NULL, - NULL - ); - if(blis_diaga == BLIS_NONUNIT_DIAG) - { - scomplex inva = {1.0, 0.0}; - scomplex a_dup; - /** - * For conjugate transpose and non-unit diagonal - * kernel, negating imaginary part of A. - * As the dimension of A is 1x1, there's going to - * be only one 1 element of A. 
- */ - if(*transa == 'C' && *diaga == 'N') - { - a_dup.real = a->real; - a_dup.imag = a->imag * -1.0; - } - else - { - a_dup.real = a->real; - a_dup.imag = a->imag; - } - -#ifdef BLIS_ENABLE_TRSM_PREINVERSION - bli_cinvscals(a_dup, inva) -#else - inva.real = a_dup.real; - inva.imag = a_dup.imag; -#endif - for(dim_t indx = 0; indx < n0; indx ++) - { -#ifdef BLIS_ENABLE_TRSM_PREINVERSION - bli_cscals(inva ,b[indx * cs_b]) -#else - bli_cinvscals(inva, b[indx * cs_b]) -#endif - - } - } - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - } - } - - const struc_t struca = BLIS_TRIANGULAR; - - obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; - obj_t ao = BLIS_OBJECT_INITIALIZER; - obj_t bo = BLIS_OBJECT_INITIALIZER; - - dim_t mn0_a; - - bli_set_dim_with_side( blis_side, m0, n0, &mn0_a ); - - bli_obj_init_finish_1x1( dt, (scomplex*)alpha, &alphao ); - - bli_obj_init_finish( dt, mn0_a, mn0_a, (scomplex*)a, rs_a, cs_a, &ao ); - bli_obj_init_finish( dt, m0, n0, (scomplex*)b, rs_b, cs_b, &bo ); - - bli_obj_set_uplo( blis_uploa, &ao ); - bli_obj_set_diag( blis_diaga, &ao ); - bli_obj_set_conjtrans( blis_transa, &ao ); - - bli_obj_set_struc( struca, &ao ); -#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM - /* bli_ztrsm_small is performing better existing native - * implementations for [m,n]<=1000 for single thread. - * In case of multithread when [m,n]<=128 sinlge thread implemenation - * is doing better than native multithread */ - bool nt = bli_thread_get_is_parallel(); - if((nt==0 && m0<=1000 && n0<=1000) || - (nt && (m0+n0)<320) ) - { - err_t status; - status = bli_trsm_small - ( - blis_side, - &alphao, - &ao, - &bo, - NULL, - NULL - ); - if (status == BLIS_SUCCESS) - { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - /* Finalize BLIS. */ - bli_finalize_auto(); - return; - } - } -#endif - bli_trsmnat - ( - blis_side, - &alphao, - &ao, - &bo, - NULL, - NULL - ); - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) - /* Finalize BLIS. */ - bli_finalize_auto(); -} -#endif -INSERT_GENTFUNC_BLAS_CZ( trsm, trsm ) -#else INSERT_GENTFUNC_BLAS( trsm, trsm ) #endif -#endif diff --git a/frame/compat/bla_trsm_amd.c b/frame/compat/bla_trsm_amd.c new file mode 100644 index 0000000000..21b2a1598d --- /dev/null +++ b/frame/compat/bla_trsm_amd.c @@ -0,0 +1,1544 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019 - 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + + +// +// Define BLAS-to-BLIS interfaces. +// + +#ifdef BLIS_BLAS3_CALLS_TAPI + +#undef GENTFUNC +#define GENTFUNC( ftype, ch, blasname, blisname ) \ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* side, \ + const f77_char* uploa, \ + const f77_char* transa, \ + const f77_char* diaga, \ + const f77_int* m, \ + const f77_int* n, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + ftype* b, const f77_int* ldb \ + ) \ +{ \ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO) \ +\ + side_t blis_side; \ + uplo_t blis_uploa; \ + trans_t blis_transa; \ + diag_t blis_diaga; \ + dim_t m0, n0; \ + inc_t rs_a, cs_a; \ + inc_t rs_b, cs_b; \ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + /* Perform BLAS parameter checking. */ \ + PASTEBLACHK(blasname) \ + ( \ + MKSTR(ch), \ + MKSTR(blasname), \ + side, \ + uploa, \ + transa, \ + diaga, \ + m, \ + n, \ + lda, \ + ldb \ + ); \ +\ + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ + bli_param_map_netlib_to_blis_side( *side, &blis_side ); \ + bli_param_map_netlib_to_blis_uplo( *uploa, &blis_uploa ); \ + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ + bli_param_map_netlib_to_blis_diag( *diaga, &blis_diaga ); \ +\ + /* Typecast BLAS integers to BLIS integers. */ \ + bli_convert_blas_dim1( *m, m0 ); \ + bli_convert_blas_dim1( *n, n0 ); \ +\ + /* Set the row and column strides of the matrix operands. */ \ + rs_a = 1; \ + cs_a = *lda; \ + rs_b = 1; \ + cs_b = *ldb; \ +\ + /* Call BLIS interface. */ \ + PASTEMAC2(ch,blisname,BLIS_TAPI_EX_SUF) \ + ( \ + blis_side, \ + blis_uploa, \ + blis_transa, \ + blis_diaga, \ + m0, \ + n0, \ + (ftype*)alpha, \ + (ftype*)a, rs_a, cs_a, \ + (ftype*)b, rs_b, cs_b, \ + NULL, \ + NULL \ + ); \ +\ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ +} + +#else + +#undef GENTFUNC +#define GENTFUNC( ftype, ch, blasname, blisname ) \ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* side, \ + const f77_char* uploa, \ + const f77_char* transa, \ + const f77_char* diaga, \ + const f77_int* m, \ + const f77_int* n, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + ftype* b, const f77_int* ldb \ + ) \ +{ \ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO) \ + AOCL_DTL_LOG_TRSM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(ch), *side, *uploa, \ + *transa, *diaga, *m, *n, (void*)alpha, *lda, *ldb); \ + side_t blis_side; \ + uplo_t blis_uploa; \ + trans_t blis_transa; \ + diag_t blis_diaga; \ + dim_t m0, n0; \ + ftype a_conj; \ + conj_t conja = BLIS_NO_CONJUGATE ; \ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + /* Perform BLAS parameter checking. */ \ + PASTEBLACHK(blasname) \ + ( \ + MKSTR(ch), \ + MKSTR(blasname), \ + side, \ + uploa, \ + transa, \ + diaga, \ + m, \ + n, \ + lda, \ + ldb \ + ); \ +\ + /* Map BLAS chars to their corresponding BLIS enumerated type value. 
*/ \ + bli_param_map_netlib_to_blis_side( *side, &blis_side ); \ + bli_param_map_netlib_to_blis_uplo( *uploa, &blis_uploa ); \ + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ + bli_param_map_netlib_to_blis_diag( *diaga, &blis_diaga ); \ +\ + /* Typecast BLAS integers to BLIS integers. */ \ + bli_convert_blas_dim1( *m, m0 ); \ + bli_convert_blas_dim1( *n, n0 ); \ +\ + /* Set the row and column strides of the matrix operands. */ \ + const inc_t rs_a = 1; \ + const inc_t cs_a = *lda; \ + const inc_t rs_b = 1; \ + const inc_t cs_b = *ldb; \ + const num_t dt = PASTEMAC(ch,type); \ +\ + /* ----------------------------------------------------------- */ \ + /* TRSM API: AX = B, where X = B */ \ + /* CALL TRSV when X & B are vector and when A is Matrix */ \ + /* Case 1: LEFT : TRSM, B(mxn) = A(mxm) * X(mxn) */ \ + /* Case 2: RIGHT : TRSM, B(mxn) = X(mxn) * A(nxn) */ \ + /* |--------|-------|-------|-------|------------------------| */ \ + /* | | A | X | B | Implementation | */ \ + /* |--------|-------|-------|-------|------------------------| */ \ + /* | LEFT | mxm | mxn | mxn | | */ \ + /* |--------|-------|-------|-------|------------------------| */ \ + /* | n = 1 | mxm | mx1 | mx1 | TRSV | */ \ + /* | m = 1 | 1x1 | 1xn | 1xn | INVSCALS | */ \ + /* |--------|-------|-------|-------|------------------------| */ \ + /* |--------|-------|-------|-------|------------------------| */ \ + /* | | X | A | B | Implementation | */ \ + /* |--------|-------|-------|-------|------------------------| */ \ + /* | RIGHT | mxn | nxn | mxn | | */ \ + /* |--------|-------|-------|-------|------------------------| */ \ + /* | n = 1 | mx1 | 1x1 | mx1 | Transpose and INVSCALS| */ \ + /* | m = 1 | 1xn | nxn | 1xn | Transpose and TRSV | */ \ + /* |--------|-------|-------|-------|------------------------| */ \ + /* If Transpose(A) uplo = lower then uplo = higher */ \ + /* If Transpose(A) uplo = higher then uplo = lower */ \ + /* ----------------------------------------------------------- */ \ +\ + if( n0 == 1 ) \ + { \ + if( blis_side == BLIS_LEFT ) \ + { \ + if(bli_is_notrans(blis_transa)) \ + { \ + PASTEMAC(ch, trsv_unf_var2) \ + ( \ + blis_uploa, \ + blis_transa, \ + blis_diaga, \ + m0, \ + (ftype*)alpha, \ + (ftype*)a, rs_a, cs_a, \ + (ftype*)b, rs_b, \ + NULL \ + ); \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) \ + return; \ + } \ + else if(bli_is_trans(blis_transa)) \ + { \ + PASTEMAC(ch, trsv_unf_var1) \ + ( \ + blis_uploa, \ + blis_transa, \ + blis_diaga, \ + m0, \ + (ftype*)alpha, \ + (ftype*)a, rs_a, cs_a, \ + (ftype*)b, rs_b, \ + NULL \ + ); \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) \ + return; \ + } \ + } \ + else if( ( blis_side == BLIS_RIGHT ) && ( m0 != 1 ) ) \ + { \ + /* b = alpha * b; */ \ + PASTEMAC2(ch,scalv,BLIS_TAPI_EX_SUF) \ + ( \ + conja, \ + m0, \ + (ftype*)alpha, \ + b, rs_b, \ + NULL, \ + NULL \ + ); \ + if(blis_diaga == BLIS_NONUNIT_DIAG) \ + { \ + conja = bli_extract_conj( blis_transa ); \ + PASTEMAC(ch,copycjs)( conja, *a, a_conj ); \ + for(int indx = 0; indx < m0; indx ++) \ + { \ + PASTEMAC(ch,invscals)( a_conj, b[indx] ); \ + } \ + }\ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) \ + return; \ + } \ + } \ + else if( m0 == 1 ) \ + { \ + if(blis_side == BLIS_RIGHT) \ + { \ + if(bli_is_notrans(blis_transa)) \ + { \ + if(blis_uploa == BLIS_UPPER) \ + blis_uploa = BLIS_LOWER; \ + else \ + blis_uploa = BLIS_UPPER; \ + PASTEMAC(ch, trsv_unf_var1)( \ + blis_uploa, \ + blis_transa, \ + blis_diaga, \ + n0, \ + (ftype*)alpha, \ + (ftype*)a, cs_a, rs_a, \ + (ftype*)b, cs_b, \ + 
NULL); \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) \ + return; \ + } \ + else if(bli_is_trans(blis_transa)) \ + { \ + if(blis_uploa == BLIS_UPPER) \ + blis_uploa = BLIS_LOWER; \ + else \ + blis_uploa = BLIS_UPPER; \ + PASTEMAC(ch, trsv_unf_var2)( \ + blis_uploa, \ + blis_transa, \ + blis_diaga, \ + n0, \ + (ftype*)alpha, \ + (ftype*)a, cs_a, rs_a, \ + (ftype*)b, cs_b, \ + NULL); \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) \ + return; \ + } \ + } \ + else if(( blis_side == BLIS_LEFT ) && ( n0 != 1 )) \ + { \ + /* b = alpha * b; */ \ + PASTEMAC2(ch,scalv,BLIS_TAPI_EX_SUF) \ + ( \ + conja, \ + n0, \ + (ftype*)alpha, \ + b, cs_b, \ + NULL, \ + NULL \ + ); \ + if(blis_diaga == BLIS_NONUNIT_DIAG) \ + { \ + conja = bli_extract_conj( blis_transa ); \ + PASTEMAC(ch,copycjs)( conja, *a, a_conj ); \ + for(int indx = 0; indx < n0; indx ++) \ + { \ + PASTEMAC(ch,invscals)( a_conj, b[indx*cs_b] ); \ + }\ + } \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) \ + return; \ + } \ + } \ +\ + const struc_t struca = BLIS_TRIANGULAR; \ +\ + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t ao = BLIS_OBJECT_INITIALIZER; \ + obj_t bo = BLIS_OBJECT_INITIALIZER; \ +\ + dim_t mn0_a; \ +\ + bli_set_dim_with_side( blis_side, m0, n0, &mn0_a ); \ +\ + bli_obj_init_finish_1x1( dt, (ftype*)alpha, &alphao ); \ +\ + bli_obj_init_finish( dt, mn0_a, mn0_a, (ftype*)a, rs_a, cs_a, &ao ); \ + bli_obj_init_finish( dt, m0, n0, (ftype*)b, rs_b, cs_b, &bo ); \ +\ + bli_obj_set_uplo( blis_uploa, &ao ); \ + bli_obj_set_diag( blis_diaga, &ao ); \ + bli_obj_set_conjtrans( blis_transa, &ao ); \ +\ + bli_obj_set_struc( struca, &ao ); \ +\ + PASTEMAC(blisname,BLIS_OAPI_EX_SUF) \ + ( \ + blis_side, \ + &alphao, \ + &ao, \ + &bo, \ + NULL, \ + NULL \ + ); \ +\ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ +} + +#endif + +#ifdef BLIS_ENABLE_BLAS + +void strsm_ +( + const f77_char* side, + const f77_char* uploa, + const f77_char* transa, + const f77_char* diaga, + const f77_int* m, + const f77_int* n, + const float* alpha, + const float* a, const f77_int* lda, + float* b, const f77_int* ldb +) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO) + AOCL_DTL_LOG_TRSM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'd', + *side, *uploa,*transa, *diaga, *m, *n, + (void*)alpha,*lda, *ldb); + + side_t blis_side; + uplo_t blis_uploa; + trans_t blis_transa; + diag_t blis_diaga; + dim_t m0, n0; + conj_t conja = BLIS_NO_CONJUGATE ; + + /* Initialize BLIS. */ + bli_init_auto(); + + /* Perform BLAS parameter checking. */ + PASTEBLACHK(trsm) + ( + MKSTR(s), + MKSTR(trsm), + side, + uploa, + transa, + diaga, + m, + n, + lda, + ldb + ); + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + bli_param_map_netlib_to_blis_side( *side, &blis_side ); + bli_param_map_netlib_to_blis_uplo( *uploa, &blis_uploa ); + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); + bli_param_map_netlib_to_blis_diag( *diaga, &blis_diaga ); + + /* Typecast BLAS integers to BLIS integers. */ + bli_convert_blas_dim1( *m, m0 ); + bli_convert_blas_dim1( *n, n0 ); + + /* Set the row and column strides of the matrix operands. 
*/ + const inc_t rs_a = 1; + const inc_t cs_a = *lda; + const inc_t rs_b = 1; + const inc_t cs_b = *ldb; + const num_t dt = BLIS_FLOAT; + + if( n0 == 1 ) + { + if( blis_side == BLIS_LEFT ) + { + if(bli_is_notrans(blis_transa)) + { + bli_strsv_unf_var2 + ( + blis_uploa, + blis_transa, + blis_diaga, + m0, + (float*)alpha, + (float*)a, rs_a, cs_a, + (float*)b, rs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + else if(bli_is_trans(blis_transa)) + { + bli_strsv_unf_var1 + ( + blis_uploa, + blis_transa, + blis_diaga, + m0, + (float*)alpha, + (float*)a, rs_a, cs_a, + (float*)b, rs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + } + else if( ( blis_side == BLIS_RIGHT ) && ( m0 != 1 ) ) + { + /* b = alpha * b; */ + bli_sscalv_ex + ( + conja, + m0, + (float*)alpha, + b, rs_b, + NULL, + NULL + ); + if(blis_diaga == BLIS_NONUNIT_DIAG) + { + float inva = 1.0/ *a; + for(dim_t indx = 0; indx < m0; indx ++) + { + b[indx] = ( inva * b[indx] ); + } + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + } + else if( m0 == 1 ) + { + if(blis_side == BLIS_RIGHT) + { + if(bli_is_notrans(blis_transa)) + { + if(blis_uploa == BLIS_UPPER) + blis_uploa = BLIS_LOWER; + else + blis_uploa = BLIS_UPPER; + + bli_strsv_unf_var1 + ( + blis_uploa, + blis_transa, + blis_diaga, + n0, + (float*)alpha, + (float*)a, cs_a, rs_a, + (float*)b, cs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + else if(bli_is_trans(blis_transa)) + { + if(blis_uploa == BLIS_UPPER) + blis_uploa = BLIS_LOWER; + else + blis_uploa = BLIS_UPPER; + + bli_strsv_unf_var2 + ( + blis_uploa, + blis_transa, + blis_diaga, + n0, + (float*)alpha, + (float*)a, cs_a, rs_a, + (float*)b, cs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + } + else if(( blis_side == BLIS_LEFT ) && ( n0 != 1 )) + { + /* b = alpha * b; */ + bli_sscalv_ex + ( + conja, + n0, + (float*)alpha, + b, cs_b, + NULL, + NULL + ); + if(blis_diaga == BLIS_NONUNIT_DIAG) + { + float inva = 1.0/ *a; + for(dim_t indx = 0; indx < n0; indx ++) + { + b[indx*cs_b] = (inva * b[indx*cs_b] ); + } + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + } + const struc_t struca = BLIS_TRIANGULAR; + + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; + obj_t ao = BLIS_OBJECT_INITIALIZER; + obj_t bo = BLIS_OBJECT_INITIALIZER; + + dim_t mn0_a; + + bli_set_dim_with_side( blis_side, m0, n0, &mn0_a ); + + bli_obj_init_finish_1x1( dt, (float*)alpha, &alphao ); + + bli_obj_init_finish( dt, mn0_a, mn0_a, (float*)a, rs_a, cs_a, &ao ); + bli_obj_init_finish( dt, m0, n0, (float*)b, rs_b, cs_b, &bo ); + + bli_obj_set_uplo( blis_uploa, &ao ); + bli_obj_set_diag( blis_diaga, &ao ); + bli_obj_set_conjtrans( blis_transa, &ao ); + + bli_obj_set_struc( struca, &ao ); + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) { +#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM + /* bli_strsm_small is performing better existing native + * implementations for [m,n]<=1000 for single thread. 
+ * In case of multithread when [m,n]<=128 sinlge thread implemenation + * is doing better than native multithread */ + bool nt = bli_thread_get_is_parallel(); + if((nt==0 && m0<=1000 && n0<=1000) || + (nt && (m0+n0)<320) ) + { + err_t status; + status = bli_trsm_small + ( + blis_side, + &alphao, + &ao, + &bo, + NULL, + NULL + ); + if (status == BLIS_SUCCESS) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + /* Finalize BLIS. */ + bli_finalize_auto(); + return; + } + } +#endif + } + bli_trsmnat + ( + blis_side, + &alphao, + &ao, + &bo, + NULL, + NULL + ); + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) + /* Finalize BLIS. */ + bli_finalize_auto(); +} + +void dtrsm_ +( + const f77_char* side, + const f77_char* uploa, + const f77_char* transa, + const f77_char* diaga, + const f77_int* m, + const f77_int* n, + const double* alpha, + const double* a, const f77_int* lda, + double* b, const f77_int* ldb +) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO) + AOCL_DTL_LOG_TRSM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'd', + *side, *uploa,*transa, *diaga, *m, *n, + (void*)alpha,*lda, *ldb); + + side_t blis_side; + uplo_t blis_uploa; + trans_t blis_transa; + diag_t blis_diaga; + dim_t m0, n0; + conj_t conja = BLIS_NO_CONJUGATE ; + + /* Initialize BLIS. */ + bli_init_auto(); + + /* Perform BLAS parameter checking. */ + PASTEBLACHK(trsm) + ( + MKSTR(d), + MKSTR(trsm), + side, + uploa, + transa, + diaga, + m, + n, + lda, + ldb + ); + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + bli_param_map_netlib_to_blis_side( *side, &blis_side ); + bli_param_map_netlib_to_blis_uplo( *uploa, &blis_uploa ); + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); + bli_param_map_netlib_to_blis_diag( *diaga, &blis_diaga ); + + /* Typecast BLAS integers to BLIS integers. */ + bli_convert_blas_dim1( *m, m0 ); + bli_convert_blas_dim1( *n, n0 ); + + /* Set the row and column strides of the matrix operands. 
*/ + const inc_t rs_a = 1; + const inc_t cs_a = *lda; + const inc_t rs_b = 1; + const inc_t cs_b = *ldb; + const num_t dt = BLIS_DOUBLE; + + if( n0 == 1 ) + { + if( blis_side == BLIS_LEFT ) + { + if(bli_is_notrans(blis_transa)) + { + bli_dtrsv_unf_var2 + ( + blis_uploa, + blis_transa, + blis_diaga, + m0, + (double*)alpha, + (double*)a, rs_a, cs_a, + (double*)b, rs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + else if(bli_is_trans(blis_transa)) + { + bli_dtrsv_unf_var1 + ( + blis_uploa, + blis_transa, + blis_diaga, + m0, + (double*)alpha, + (double*)a, rs_a, cs_a, + (double*)b, rs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + } + else if( ( blis_side == BLIS_RIGHT ) && ( m0 != 1 ) ) + { + /* b = alpha * b; */ + bli_dscalv_ex + ( + conja, + m0, + (double*)alpha, + b, rs_b, + NULL, + NULL + ); + if(blis_diaga == BLIS_NONUNIT_DIAG) + { + double inva = 1.0/ *a; + for(dim_t indx = 0; indx < m0; indx ++) + { + b[indx] = ( inva * b[indx] ); + } + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + } + else if( m0 == 1 ) + { + if(blis_side == BLIS_RIGHT) + { + if(bli_is_notrans(blis_transa)) + { + if(blis_uploa == BLIS_UPPER) + blis_uploa = BLIS_LOWER; + else + blis_uploa = BLIS_UPPER; + + bli_dtrsv_unf_var1 + ( + blis_uploa, + blis_transa, + blis_diaga, + n0, + (double*)alpha, + (double*)a, cs_a, rs_a, + (double*)b, cs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + else if(bli_is_trans(blis_transa)) + { + if(blis_uploa == BLIS_UPPER) + blis_uploa = BLIS_LOWER; + else + blis_uploa = BLIS_UPPER; + + bli_dtrsv_unf_var2 + ( + blis_uploa, + blis_transa, + blis_diaga, + n0, + (double*)alpha, + (double*)a, cs_a, rs_a, + (double*)b, cs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + } + else if(( blis_side == BLIS_LEFT ) && ( n0 != 1 )) + { + /* b = alpha * b; */ + bli_dscalv_ex + ( + conja, + n0, + (double*)alpha, + b, cs_b, + NULL, + NULL + ); + if(blis_diaga == BLIS_NONUNIT_DIAG) + { + double inva = 1.0/ *a; + for(dim_t indx = 0; indx < n0; indx ++) + { + b[indx*cs_b] = (inva * b[indx*cs_b] ); + } + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + } + + const struc_t struca = BLIS_TRIANGULAR; + + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; + obj_t ao = BLIS_OBJECT_INITIALIZER; + obj_t bo = BLIS_OBJECT_INITIALIZER; + + dim_t mn0_a; + + bli_set_dim_with_side( blis_side, m0, n0, &mn0_a ); + + bli_obj_init_finish_1x1( dt, (double*)alpha, &alphao ); + + bli_obj_init_finish( dt, mn0_a, mn0_a, (double*)a, rs_a, cs_a, &ao ); + bli_obj_init_finish( dt, m0, n0, (double*)b, rs_b, cs_b, &bo ); + + bli_obj_set_uplo( blis_uploa, &ao ); + bli_obj_set_diag( blis_diaga, &ao ); + bli_obj_set_conjtrans( blis_transa, &ao ); + + bli_obj_set_struc( struca, &ao ); + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) { + +#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM + /* bli_dtrsm_small is performing better existing native + * implementations for [m,n]<=1000 for single thread. 
+ * In case of multithread when [m,n]<=128 sinlge thread implemenation + * is doing better than native multithread */ + bool nt = bli_thread_get_is_parallel(); + if((nt==0 && m0<=1000 && n0<=1000) || + (nt && (m0+n0)<320) ) + { + err_t status; + status = bli_trsm_small + ( + blis_side, + &alphao, + &ao, + &bo, + NULL, + NULL + ); + if (status == BLIS_SUCCESS) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + /* Finalize BLIS. */ + bli_finalize_auto(); + return; + } + } +#endif + } + bli_trsmnat + ( + blis_side, + &alphao, + &ao, + &bo, + NULL, + NULL + ); + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) + /* Finalize BLIS. */ + bli_finalize_auto(); +} +#if 0 +void ztrsm_ +( + const f77_char* side, + const f77_char* uploa, + const f77_char* transa, + const f77_char* diaga, + const f77_int* m, + const f77_int* n, + const dcomplex* alpha, + const dcomplex* a, const f77_int* lda, + dcomplex* b, const f77_int* ldb +) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO) + AOCL_DTL_LOG_TRSM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'z', + *side, *uploa,*transa, *diaga, *m, *n, + (void*)alpha,*lda, *ldb); + + side_t blis_side; + uplo_t blis_uploa; + trans_t blis_transa; + diag_t blis_diaga; + dim_t m0, n0; + conj_t conja = BLIS_NO_CONJUGATE; + + /* Initialize BLIS. */ + bli_init_auto(); + + /* Perform BLAS parameter checking. */ + PASTEBLACHK(trsm) + ( + MKSTR(z), + MKSTR(trsm), + side, + uploa, + transa, + diaga, + m, + n, + lda, + ldb + ); + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + bli_param_map_netlib_to_blis_side( *side, &blis_side ); + bli_param_map_netlib_to_blis_uplo( *uploa, &blis_uploa ); + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); + bli_param_map_netlib_to_blis_diag( *diaga, &blis_diaga ); + + /* Typecast BLAS integers to BLIS integers. */ + bli_convert_blas_dim1( *m, m0 ); + bli_convert_blas_dim1( *n, n0 ); + + /* Set the row and column strides of the matrix operands. */ + const inc_t rs_a = 1; + const inc_t cs_a = *lda; + const inc_t rs_b = 1; + const inc_t cs_b = *ldb; + const num_t dt = BLIS_DCOMPLEX; + + + if( n0 == 1 ) + { + if( blis_side == BLIS_LEFT ) + { + if(bli_is_notrans(blis_transa)) + { + bli_ztrsv_unf_var2 + ( + blis_uploa, + blis_transa, + blis_diaga, + m0, + (dcomplex*)alpha, + (dcomplex*)a, rs_a, cs_a, + (dcomplex*)b, rs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + else if(bli_is_trans(blis_transa)) + { + bli_ztrsv_unf_var1 + ( + blis_uploa, + blis_transa, + blis_diaga, + m0, + (dcomplex*)alpha, + (dcomplex*)a, rs_a, cs_a, + (dcomplex*)b, rs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + } + else if( ( blis_side == BLIS_RIGHT ) && ( m0 != 1 ) ) + { + bli_zscalv_ex + ( + conja, + m0, + (dcomplex*)alpha, + (dcomplex*)b, rs_b, + NULL, + NULL + ); + if(blis_diaga == BLIS_NONUNIT_DIAG) + { + dcomplex inva = {1.0, 0.0}; + dcomplex a_dup; + /** + * For conjugate transpose and non-unit diagonal + * kernel, negating imaginary part of A. + * As the dimension of A is 1x1, there's going to + * be only one 1 element of A. 
+ */ + if(*transa == 'C' && *diaga == 'N') + { + a_dup.real = a->real; + a_dup.imag = a->imag * -1.0; + } + else + { + a_dup.real = a->real; + a_dup.imag = a->imag; + } + +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + bli_zinvscals(a_dup, inva); +#else + inva.real = a_dup.real; + inva.imag = a_dup.imag; +#endif + for(dim_t indx = 0; indx < m0; indx ++) + { +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + bli_zscals(inva, b[indx]) +#else + + bli_zinvscals(inva, b[indx]) +#endif + } + + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + } + else if( m0 == 1 ) + { + if(blis_side == BLIS_RIGHT) + { + if(bli_is_notrans(blis_transa)) + { + if(blis_uploa == BLIS_UPPER) + blis_uploa = BLIS_LOWER; + else + blis_uploa = BLIS_UPPER; + + bli_ztrsv_unf_var1 + ( + blis_uploa, + blis_transa, + blis_diaga, + n0, + (dcomplex*)alpha, + (dcomplex*)a, cs_a, rs_a, + (dcomplex*)b, cs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + else if(bli_is_trans(blis_transa)) + { + if(blis_uploa == BLIS_UPPER) + blis_uploa = BLIS_LOWER; + else + blis_uploa = BLIS_UPPER; + + bli_ztrsv_unf_var2 + ( + blis_uploa, + blis_transa, + blis_diaga, + n0, + (dcomplex*)alpha, + (dcomplex*)a, cs_a, rs_a, + (dcomplex*)b, cs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + } + else if(( blis_side == BLIS_LEFT ) && ( n0 != 1 )) + { + bli_zscalv_ex + ( + conja, + n0, + (dcomplex*)alpha, + (dcomplex*)b, cs_b, + NULL, + NULL + ); + if(blis_diaga == BLIS_NONUNIT_DIAG) + { + dcomplex inva = {1.0, 0.0}; + dcomplex a_dup; + /** + * For conjugate transpose and non-unit diagonal + * kernel, negating imaginary part of A. + * As the dimension of A is 1x1, there's going to + * be only one 1 element of A. + */ + if(*transa == 'C' && *diaga == 'N') + { + a_dup.real = a->real; + a_dup.imag = a->imag * -1.0; + } + else + { + a_dup.real = a->real; + a_dup.imag = a->imag; + } + +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + bli_zinvscals(a_dup, inva); +#else + inva.real = a_dup.real; + inva.imag = a_dup.imag; +#endif + for(dim_t indx = 0; indx < n0; indx ++) + { +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + bli_zscals(inva ,b[indx * cs_b]) +#else + + bli_zinvscals(inva ,b[indx * cs_b]) +#endif + } + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + + } + } + + const struc_t struca = BLIS_TRIANGULAR; + + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; + obj_t ao = BLIS_OBJECT_INITIALIZER; + obj_t bo = BLIS_OBJECT_INITIALIZER; + + dim_t mn0_a; + + bli_set_dim_with_side( blis_side, m0, n0, &mn0_a ); + + bli_obj_init_finish_1x1( dt, (dcomplex*)alpha, &alphao ); + + bli_obj_init_finish( dt, mn0_a, mn0_a, (dcomplex*)a, rs_a, cs_a, &ao ); + bli_obj_init_finish( dt, m0, n0, (dcomplex*)b, rs_b, cs_b, &bo ); + + bli_obj_set_uplo( blis_uploa, &ao ); + bli_obj_set_diag( blis_diaga, &ao ); + bli_obj_set_conjtrans( blis_transa, &ao ); + + bli_obj_set_struc( struca, &ao ); + +#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM + /* bli_ztrsm_small is performing better existing native + * implementations for [m,n]<=1000 for single thread. + * In case of multithread when [m,n]<=128 sinlge thread implemenation + * is doing better than native multithread */ + bool nt = bli_thread_get_is_parallel(); + if((nt==0 && m0<=500 && n0<=500) || + (nt && (m0+n0)<128) ) + { + err_t status; + status = bli_trsm_small + ( + blis_side, + &alphao, + &ao, + &bo, + NULL, + NULL + ); + if (status == BLIS_SUCCESS) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + /* Finalize BLIS. 
*/ + bli_finalize_auto(); + return; + } + } +#endif + + bli_trsmnat + ( + blis_side, + &alphao, + &ao, + &bo, + NULL, + NULL + ); + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) + /* Finalize BLIS. */ + bli_finalize_auto(); +} +#endif +#if 0 +void ctrsm_ +( + const f77_char* side, + const f77_char* uploa, + const f77_char* transa, + const f77_char* diaga, + const f77_int* m, + const f77_int* n, + const scomplex* alpha, + const scomplex* a, const f77_int* lda, + scomplex* b, const f77_int* ldb +) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO) + AOCL_DTL_LOG_TRSM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 's', + *side, *uploa,*transa, *diaga, *m, *n, + (void*)alpha,*lda, *ldb); + + side_t blis_side; + uplo_t blis_uploa; + trans_t blis_transa; + diag_t blis_diaga; + dim_t m0, n0; + conj_t conja = BLIS_NO_CONJUGATE; + + /* Initialize BLIS. */ + bli_init_auto(); + + /* Perform BLAS parameter checking. */ + PASTEBLACHK(trsm) + ( + MKSTR(c), + MKSTR(trsm), + side, + uploa, + transa, + diaga, + m, + n, + lda, + ldb + ); + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + bli_param_map_netlib_to_blis_side( *side, &blis_side ); + bli_param_map_netlib_to_blis_uplo( *uploa, &blis_uploa ); + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); + bli_param_map_netlib_to_blis_diag( *diaga, &blis_diaga ); + + /* Typecast BLAS integers to BLIS integers. */ + bli_convert_blas_dim1( *m, m0 ); + bli_convert_blas_dim1( *n, n0 ); + + /* Set the row and column strides of the matrix operands. */ + const inc_t rs_a = 1; + const inc_t cs_a = *lda; + const inc_t rs_b = 1; + const inc_t cs_b = *ldb; + const num_t dt = BLIS_SCOMPLEX; + + + if( n0 == 1 ) + { + if( blis_side == BLIS_LEFT ) + { + if(bli_is_notrans(blis_transa)) + { + bli_ctrsv_unf_var2 + ( + blis_uploa, + blis_transa, + blis_diaga, + m0, + (scomplex*)alpha, + (scomplex*)a, rs_a, cs_a, + (scomplex*)b, rs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + else if(bli_is_trans(blis_transa)) + { + bli_ctrsv_unf_var1 + ( + blis_uploa, + blis_transa, + blis_diaga, + m0, + (scomplex*)alpha, + (scomplex*)a, rs_a, cs_a, + (scomplex*)b, rs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + } + else if( ( blis_side == BLIS_RIGHT ) && ( m0 != 1 ) ) + { + bli_cscalv_ex + ( + conja, + m0, + (scomplex*)alpha, + (scomplex*)b, rs_b, + NULL, + NULL + ); + if(blis_diaga == BLIS_NONUNIT_DIAG) + { + scomplex inva = {1.0, 0.0}; + scomplex a_dup; + /** + * For conjugate transpose and non-unit diagonal + * kernel, negating imaginary part of A. + * As the dimension of A is 1x1, there's going to + * be only one 1 element of A. 
+ */ + if(*transa == 'C' && *diaga == 'N') + { + a_dup.real = a->real; + a_dup.imag = a->imag * -1.0; + } + else + { + a_dup.real = a->real; + a_dup.imag = a->imag; + } + +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + bli_cinvscals(a_dup, inva); +#else + inva.real = a_dup.real; + inva.imag = a_dup.imag; +#endif + + for(dim_t indx = 0; indx < m0; indx ++) + { +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + bli_cscals(inva ,b[indx]) +#else + bli_cinvscals(inva, b[indx]) +#endif + } + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + + } + } + else if( m0 == 1 ) + { + if(blis_side == BLIS_RIGHT) + { + if(bli_is_notrans(blis_transa)) + { + if(blis_uploa == BLIS_UPPER) + blis_uploa = BLIS_LOWER; + else + blis_uploa = BLIS_UPPER; + + bli_ctrsv_unf_var1 + ( + blis_uploa, + blis_transa, + blis_diaga, + n0, + (scomplex*)alpha, + (scomplex*)a, cs_a, rs_a, + (scomplex*)b, cs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + else if(bli_is_trans(blis_transa)) + { + if(blis_uploa == BLIS_UPPER) + blis_uploa = BLIS_LOWER; + else + blis_uploa = BLIS_UPPER; + + bli_ctrsv_unf_var2 + ( + blis_uploa, + blis_transa, + blis_diaga, + n0, + (scomplex*)alpha, + (scomplex*)a, cs_a, rs_a, + (scomplex*)b, cs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + } + else if(( blis_side == BLIS_LEFT ) && ( n0 != 1 )) + { + bli_cscalv_ex + ( + conja, + n0, + (scomplex*)alpha, + (scomplex*)b, cs_b, + NULL, + NULL + ); + if(blis_diaga == BLIS_NONUNIT_DIAG) + { + scomplex inva = {1.0, 0.0}; + scomplex a_dup; + /** + * For conjugate transpose and non-unit diagonal + * kernel, negating imaginary part of A. + * As the dimension of A is 1x1, there's going to + * be only one 1 element of A. + */ + if(*transa == 'C' && *diaga == 'N') + { + a_dup.real = a->real; + a_dup.imag = a->imag * -1.0; + } + else + { + a_dup.real = a->real; + a_dup.imag = a->imag; + } + +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + bli_cinvscals(a_dup, inva) +#else + inva.real = a_dup.real; + inva.imag = a_dup.imag; +#endif + for(dim_t indx = 0; indx < n0; indx ++) + { +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + bli_cscals(inva ,b[indx * cs_b]) +#else + bli_cinvscals(inva, b[indx * cs_b]) +#endif + + } + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + } + + const struc_t struca = BLIS_TRIANGULAR; + + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; + obj_t ao = BLIS_OBJECT_INITIALIZER; + obj_t bo = BLIS_OBJECT_INITIALIZER; + + dim_t mn0_a; + + bli_set_dim_with_side( blis_side, m0, n0, &mn0_a ); + + bli_obj_init_finish_1x1( dt, (scomplex*)alpha, &alphao ); + + bli_obj_init_finish( dt, mn0_a, mn0_a, (scomplex*)a, rs_a, cs_a, &ao ); + bli_obj_init_finish( dt, m0, n0, (scomplex*)b, rs_b, cs_b, &bo ); + + bli_obj_set_uplo( blis_uploa, &ao ); + bli_obj_set_diag( blis_diaga, &ao ); + bli_obj_set_conjtrans( blis_transa, &ao ); + + bli_obj_set_struc( struca, &ao ); +#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM + /* bli_ztrsm_small is performing better existing native + * implementations for [m,n]<=1000 for single thread. + * In case of multithread when [m,n]<=128 sinlge thread implemenation + * is doing better than native multithread */ + bool nt = bli_thread_get_is_parallel(); + if((nt==0 && m0<=1000 && n0<=1000) || + (nt && (m0+n0)<320) ) + { + err_t status; + status = bli_trsm_small + ( + blis_side, + &alphao, + &ao, + &bo, + NULL, + NULL + ); + if (status == BLIS_SUCCESS) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + /* Finalize BLIS. 
*/ + bli_finalize_auto(); + return; + } + } +#endif + bli_trsmnat + ( + blis_side, + &alphao, + &ao, + &bo, + NULL, + NULL + ); + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) + /* Finalize BLIS. */ + bli_finalize_auto(); +} +#endif +INSERT_GENTFUNC_BLAS_CZ( trsm, trsm ) + +#endif diff --git a/kernels/zen/1/bli_scalv_zen_int10.c b/kernels/zen/1/bli_scalv_zen_int10.c index de9d8339d3..7146e86879 100644 --- a/kernels/zen/1/bli_scalv_zen_int10.c +++ b/kernels/zen/1/bli_scalv_zen_int10.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2017 - 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2017 - 2022, Advanced Micro Devices, Inc. All rights reserved. Copyright (C) 2018, The University of Texas at Austin Redistribution and use in source and binary forms, with or without @@ -64,16 +64,7 @@ void bli_sscalv_zen_int10 if ( PASTEMAC(s,eq0)( *alpha ) ) { float* zero = bli_s0; -#ifdef BLIS_CONFIG_EPYC - bli_ssetv_zen_int - ( - BLIS_NO_CONJUGATE, - n, - zero, - x, incx, - cntx - ); -#else + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); ssetv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_FLOAT, BLIS_SETV_KER, cntx ); f ( @@ -83,7 +74,7 @@ void bli_sscalv_zen_int10 x, incx, cntx ); -#endif + return; } @@ -342,16 +333,7 @@ void bli_dscalv_zen_int10 if ( PASTEMAC(d,eq0)( *alpha ) ) { double* zero = bli_d0; -#ifdef BLIS_CONFIG_EPYC - bli_dsetv_zen_int - ( - BLIS_NO_CONJUGATE, - n, - zero, - x, incx, - cntx - ); -#else + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); dsetv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_DOUBLE, BLIS_SETV_KER, cntx ); f @@ -362,7 +344,7 @@ void bli_dscalv_zen_int10 x, incx, cntx ); -#endif + return; } diff --git a/kernels/zen/1f/bli_axpyf_zen_int_4.c b/kernels/zen/1f/bli_axpyf_zen_int_4.c index f5a043db84..bb24e6c52f 100644 --- a/kernels/zen/1f/bli_axpyf_zen_int_4.c +++ b/kernels/zen/1f/bli_axpyf_zen_int_4.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -95,29 +95,6 @@ void bli_caxpyf_zen_int_4 // operation as a loop over axpyv. if ( b_n != fuse_fac ) { -#ifdef BLIS_CONFIG_EPYC - for ( i = 0; i < b_n; ++i ) - { - scomplex* a1 = a + (0 )*inca + (i )*lda; - scomplex* chi1 = x + (i )*incx; - scomplex* y1 = y + (0 )*incy; - scomplex alpha_chi1; - - bli_ccopycjs( conjx, *chi1, alpha_chi1 ); - bli_cscals( *alpha, alpha_chi1 ); - - bli_caxpyv_zen_int5 - ( - conja, - m, - &alpha_chi1, - a1, inca, - y1, incy, - cntx - ); - } - -#else caxpyv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_SCOMPLEX, BLIS_AXPYV_KER, cntx ); for ( i = 0; i < b_n; ++i ) @@ -141,7 +118,6 @@ void bli_caxpyf_zen_int_4 ); } -#endif return; } @@ -357,28 +333,6 @@ void bli_zaxpyf_zen_int_4 // operation as a loop over axpyv. 
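/*
 * A minimal sketch of what the "loop over axpyv" fallback computes when
 * b_n does not match the fuse factor (the name axpyf_as_axpyv_ref is
 * illustrative only, and conjugation is omitted for the real case):
 *
 *   // y := y + alpha * A * x, with A of size m x b_n, one column at a time
 *   void axpyf_as_axpyv_ref( dim_t m, dim_t b_n, double alpha,
 *                            const double* a, inc_t inca, inc_t lda,
 *                            const double* x, inc_t incx,
 *                            double* y, inc_t incy )
 *   {
 *       for ( dim_t j = 0; j < b_n; ++j )
 *       {
 *           double alpha_chi = alpha * x[ j*incx ];   // scale chi_j once
 *           for ( dim_t i = 0; i < m; ++i )           // one axpyv per column
 *               y[ i*incy ] += alpha_chi * a[ i*inca + j*lda ];
 *       }
 *   }
 *
 * The fallback paths below implement exactly this decomposition, but
 * delegate the inner column update to the axpyv kernel queried from the
 * context instead of hard-coding an architecture-specific routine.
 */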
if ( b_n != fuse_fac ) { -#ifdef BLIS_CONFIG_EPYC - for ( i = 0; i < b_n; ++i ) - { - dcomplex* a1 = a + (0 )*inca + (i )*lda; - dcomplex* chi1 = x + (i )*incx; - dcomplex* y1 = y + (0 )*incy; - dcomplex alpha_chi1; - - bli_zcopycjs( conjx, *chi1, alpha_chi1 ); - bli_zscals( *alpha, alpha_chi1 ); - - bli_zaxpyv_zen_int5 - ( - conja, - m, - &alpha_chi1, - a1, inca, - y1, incy, - cntx - ); - } -#else zaxpyv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_DCOMPLEX, BLIS_AXPYV_KER, cntx ); for ( i = 0; i < b_n; ++i ) @@ -402,7 +356,6 @@ void bli_zaxpyf_zen_int_4 ); } -#endif return; } diff --git a/kernels/zen/1f/bli_axpyf_zen_int_5.c b/kernels/zen/1f/bli_axpyf_zen_int_5.c index 1125197775..d09a85f57f 100644 --- a/kernels/zen/1f/bli_axpyf_zen_int_5.c +++ b/kernels/zen/1f/bli_axpyf_zen_int_5.c @@ -108,29 +108,6 @@ void bli_saxpyf_zen_int_5 // operation as a loop over axpyv. if ( b_n != fuse_fac ) { -#ifdef BLIS_CONFIG_EPYC - for ( i = 0; i < b_n; ++i ) - { - float* a1 = a + (0 )*inca + (i )*lda; - float* chi1 = x + (i )*incx; - float* y1 = y + (0 )*incy; - float alpha_chi1; - - bli_scopycjs( conjx, *chi1, alpha_chi1 ); - bli_sscals( *alpha, alpha_chi1 ); - - bli_saxpyv_zen_int10 - ( - conja, - m, - &alpha_chi1, - a1, inca, - y1, incy, - cntx - ); - } - -#else saxpyv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_FLOAT, BLIS_AXPYV_KER, cntx ); for ( i = 0; i < b_n; ++i ) @@ -154,7 +131,6 @@ void bli_saxpyf_zen_int_5 ); } -#endif return; } @@ -382,29 +358,6 @@ void bli_daxpyf_zen_int_5 // operation as a loop over axpyv. if ( b_n != fuse_fac ) { -#ifdef BLIS_CONFIG_EPYC - for ( i = 0; i < b_n; ++i ) - { - double* a1 = a + (0 )*inca + (i )*lda; - double* chi1 = x + (i )*incx; - double* y1 = y + (0 )*incy; - double alpha_chi1; - - bli_dcopycjs( conjx, *chi1, alpha_chi1 ); - bli_dscals( *alpha, alpha_chi1 ); - - bli_daxpyv_zen_int10 - ( - conja, - m, - &alpha_chi1, - a1, inca, - y1, incy, - cntx - ); - } - -#else daxpyv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_DOUBLE, BLIS_AXPYV_KER, cntx ); for ( i = 0; i < b_n; ++i ) @@ -428,7 +381,6 @@ void bli_daxpyf_zen_int_5 ); } -#endif return; } @@ -655,29 +607,6 @@ static void bli_daxpyf_zen_int_16x2 // operation as a loop over axpyv. if ( b_n != fuse_fac ) { -#ifdef BLIS_CONFIG_EPYC - for ( i = 0; i < b_n; ++i ) - { - double* a1 = a + (0 )*inca + (i )*lda; - double* chi1 = x + (i )*incx; - double* y1 = y + (0 )*incy; - double alpha_chi1; - - bli_dcopycjs( conjx, *chi1, alpha_chi1 ); - bli_dscals( *alpha, alpha_chi1 ); - - bli_daxpyv_zen_int10 - ( - conja, - m, - &alpha_chi1, - a1, inca, - y1, incy, - cntx - ); - } - -#else daxpyv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_DOUBLE, BLIS_AXPYV_KER, cntx ); for ( i = 0; i < b_n; ++i ) @@ -701,7 +630,6 @@ static void bli_daxpyf_zen_int_16x2 ); } -#endif return; } @@ -966,43 +894,21 @@ void bli_daxpyf_zen_int_16x4 // operation as a loop over axpyv. 
if ( b_n != fuse_fac ) { -#ifdef BLIS_CONFIG_EPYC - if(b_n & 2) - { - bli_daxpyf_zen_int_16x2( conja, - conjx, - m, 2, - alpha, a, inca, lda, - x, incx, - y, incy, - cntx - ); - b_n -= 2; - a += 2*lda; - x += 2 * incx; - } - for ( i = 0; i < b_n; ++i ) - { - double* a1 = a + (0 )*inca + (i )*lda; - double* chi1 = x + (i )*incx; - double* y1 = y + (0 )*incy; - double alpha_chi1; - - bli_dcopycjs( conjx, *chi1, alpha_chi1 ); - bli_dscals( *alpha, alpha_chi1 ); - - bli_daxpyv_zen_int10 - ( - conja, - m, - &alpha_chi1, - a1, inca, - y1, incy, - cntx - ); - } + if (b_n & 2) + { + bli_daxpyf_zen_int_16x2( conja, + conjx, + m, 2, + alpha, a, inca, lda, + x, incx, + y, incy, + cntx + ); + b_n -= 2; + a += 2*lda; + x += 2 * incx; + } -#else daxpyv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_DOUBLE, BLIS_AXPYV_KER, cntx ); for ( i = 0; i < b_n; ++i ) @@ -1026,7 +932,6 @@ void bli_daxpyf_zen_int_16x4 ); } -#endif return; } @@ -1396,29 +1301,6 @@ void bli_caxpyf_zen_int_5 // operation as a loop over axpyv. if ( b_n != fuse_fac ) { -#ifdef BLIS_CONFIG_EPYC - for ( i = 0; i < b_n; ++i ) - { - scomplex* a1 = a + (0 )*inca + (i )*lda; - scomplex* chi1 = x + (i )*incx; - scomplex* y1 = y + (0 )*incy; - scomplex alpha_chi1; - - bli_ccopycjs( conjx, *chi1, alpha_chi1 ); - bli_cscals( *alpha, alpha_chi1 ); - - bli_caxpyv_zen_int5 - ( - conja, - m, - &alpha_chi1, - a1, inca, - y1, incy, - cntx - ); - } - -#else caxpyv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_SCOMPLEX, BLIS_AXPYV_KER, cntx ); for ( i = 0; i < b_n; ++i ) @@ -1442,7 +1324,6 @@ void bli_caxpyf_zen_int_5 ); } -#endif return; } @@ -1810,29 +1691,6 @@ void bli_zaxpyf_zen_int_5 // operation as a loop over axpyv. if ( b_n != fuse_fac ) { -#ifdef BLIS_CONFIG_EPYC - for ( i = 0; i < b_n; ++i ) - { - dcomplex* a1 = a + (0 )*inca + (i )*lda; - dcomplex* chi1 = x + (i )*incx; - dcomplex* y1 = y + (0 )*incy; - dcomplex alpha_chi1; - - bli_zcopycjs( conjx, *chi1, alpha_chi1 ); - bli_zscals( *alpha, alpha_chi1 ); - - bli_zaxpyv_zen_int5 - ( - conja, - m, - &alpha_chi1, - a1, inca, - y1, incy, - cntx - ); - } - -#else zaxpyv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_DCOMPLEX, BLIS_AXPYV_KER, cntx ); for ( i = 0; i < b_n; ++i ) @@ -1855,8 +1713,7 @@ void bli_zaxpyf_zen_int_5 cntx ); } - -#endif + return; } diff --git a/kernels/zen/1f/bli_axpyf_zen_int_6.c b/kernels/zen/1f/bli_axpyf_zen_int_6.c index 99b544db15..cf7dbd1732 100644 --- a/kernels/zen/1f/bli_axpyf_zen_int_6.c +++ b/kernels/zen/1f/bli_axpyf_zen_int_6.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -97,28 +97,6 @@ void bli_saxpyf_zen_int_6 // operation as a loop over axpyv. 
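/*
 * Note on the remainder handling retained in bli_daxpyf_zen_int_16x4 above
 * (a worked example, not additional code): when b_n is not the fuse factor,
 * any leftover pair of columns is peeled off through the fused 16x2 kernel
 * before the per-column axpyv loop runs. For b_n = 3, for instance,
 * (b_n & 2) is nonzero, so
 *
 *   bli_daxpyf_zen_int_16x2( conja, conjx, m, 2, alpha, a, inca, lda,
 *                            x, incx, y, incy, cntx );
 *   b_n -= 2;  a += 2*lda;  x += 2*incx;
 *
 * handles columns 0 and 1, and the remaining single column is then updated
 * by one call to the context-supplied axpyv kernel.
 */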
if ( b_n != fuse_fac ) { -#ifdef BLIS_CONFIG_EPYC - for ( i = 0; i < b_n; ++i ) - { - float* a1 = a + (0 )*inca + (i )*lda; - float* chi1 = x + (i )*incx; - float* y1 = y + (0 )*incy; - float alpha_chi1; - - bli_scopycjs( conjx, *chi1, alpha_chi1 ); - bli_sscals( *alpha, alpha_chi1 ); - - bli_saxpyv_zen_int10 - ( - conja, - m, - &alpha_chi1, - a1, inca, - y1, incy, - cntx - ); - } -#else saxpyv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_FLOAT, BLIS_AXPYV_KER, cntx ); for ( i = 0; i < b_n; ++i ) @@ -141,7 +119,7 @@ void bli_saxpyf_zen_int_6 cntx ); } -#endif + return; } diff --git a/kernels/zen/3/bli_gemm_small.c b/kernels/zen/3/bli_gemm_small.c index bf6c9c29cd..3e9463fabc 100644 --- a/kernels/zen/3/bli_gemm_small.c +++ b/kernels/zen/3/bli_gemm_small.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2017-2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2017-2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -114,16 +114,9 @@ err_t bli_gemm_small AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); return BLIS_NOT_YET_IMPLEMENTED; #else - // When dynamic dispatch is enabled i.e. library is built for 'amdzen' configuration. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN4) || - (id == BLIS_ARCH_ZEN3) || - (id == BLIS_ARCH_ZEN2) || - (id == BLIS_ARCH_ZEN); - - if (0 == bamdzen) + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == FALSE) { return BLIS_NOT_YET_IMPLEMENTED; } From e783ea10db2636f2c4f828917fab9c8deced1662 Mon Sep 17 00:00:00 2001 From: Saitharun Date: Wed, 19 Jan 2022 11:38:45 +0530 Subject: [PATCH 067/243] Enable wrapper code by default details: Changes Made for 4.0 branch to enable wrapper code by default and also removed ENABLE_API_WRAPPER macro. 
Change-Id: I5c9ede7ae959d811bc009073a266e66cbf07ef1a --- CMakeLists.txt | 7 +------ frame/util/bli_util_api_wrap.c | 4 +++- frame/util/bli_util_api_wrap.h | 4 +++- 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 3affe7b40c..3f7cc6ad94 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -99,9 +99,8 @@ option(BLIS_ENABLE_ILP64 "ENABLE BLIS ILP64" OFF) option(ENABLE_INT_TYPE_SIZE " Internal BLIS integers ,used in native BLIS interfaces based on architecture dependent " ON) option(ENABLE_BLASTEST "Enable the blastest" OFF) option(ENABLE_TESTCPP_TESTING "Enabling testcpp" OFF) -option (ENABLE_NO_UNDERSCORE_API "export APIs without underscore" ON) +option (ENABLE_NO_UNDERSCORE_API "export APIs without underscore" OFF) option (ENABLE_UPPERCASE_API "export APIs with uppercase" OFF) -option (ENABLE_API_WRAPPER "Enable wrapper code" OFF) option (ENABLE_COMPLEX_RETURN_INTEL "Enable complex_return_intel" OFF) option (ENABLE_TRSM_PREINVERSION "Enable TRSM preinversion" ON) option (ENABLE_AOCL_DYNAMIC "Enable Dynamic Multi-threading" OFF) @@ -131,10 +130,6 @@ if(ENABLE_UPPERCASE_API) add_definitions(-DBLIS_ENABLE_UPPERCASE_API) endif() -if(ENABLE_API_WRAPPER) - add_definitions(-DBLIS_ENABLE_API_WRAPPER) -endif() - if(ENABLE_AOCL_DYNAMIC) set(AOCL_DYNAMIC TRUE) endif() diff --git a/frame/util/bli_util_api_wrap.c b/frame/util/bli_util_api_wrap.c index 128fba8b87..81300761fb 100644 --- a/frame/util/bli_util_api_wrap.c +++ b/frame/util/bli_util_api_wrap.c @@ -39,7 +39,8 @@ #include "bli_util_api_wrap.h" // wrapper functions to support additional symbols -#ifdef BLIS_ENABLE_API_WRAPPER +#ifndef BLIS_ENABLE_NO_UNDERSCORE_API +#ifndef BLIS_ENABLE_UPPERCASE_API void CAXPY(const f77_int *n,const scomplex *ca,const scomplex *cx,const f77_int *incx,scomplex *cy,const f77_int *incy) { caxpy_( n, ca, cx, incx, cy, incy); @@ -3221,3 +3222,4 @@ void CAXPBY_( const f77_int* n, const scomplex* alpha, const scomplex *x, con } #endif +#endif diff --git a/frame/util/bli_util_api_wrap.h b/frame/util/bli_util_api_wrap.h index f0aff49ff2..78f088e28e 100644 --- a/frame/util/bli_util_api_wrap.h +++ b/frame/util/bli_util_api_wrap.h @@ -35,7 +35,8 @@ // file define different formats of BLAS APIs- uppercase with // and without underscore, lowercase without underscore. -#ifdef BLIS_ENABLE_API_WRAPPER +#ifndef BLIS_ENABLE_NO_UNDERSCORE_API +#ifndef BLIS_ENABLE_UPPERCASE_API //Level 1 APIs BLIS_EXPORT_BLIS void SROTG(float *sa, float *sb, float *c, float *s); @@ -1729,3 +1730,4 @@ BLIS_EXPORT_BLIS void ZOMATCOPY_(f77_char* trans, f77_int* rows, f77_int* cols #endif +#endif From 14fb31c0d5091fffdf401e04982adff306a17481 Mon Sep 17 00:00:00 2001 From: Harihara Sudhan Date: Fri, 7 Jan 2022 14:10:56 +0530 Subject: [PATCH 068/243] Improved performance of DOTXV kernel for float and double - Vectorized sections of code that were not vectorized AMD Internal: [CPUPL-1980] Change-Id: I08528d054442a5e728f631142f244f1624170136 --- kernels/zen/1/bli_dotxv_zen_int.c | 131 ++++++++++++++++++------------ 1 file changed, 78 insertions(+), 53 deletions(-) diff --git a/kernels/zen/1/bli_dotxv_zen_int.c b/kernels/zen/1/bli_dotxv_zen_int.c index 99ea517104..8ba1d1bba4 100644 --- a/kernels/zen/1/bli_dotxv_zen_int.c +++ b/kernels/zen/1/bli_dotxv_zen_int.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2016 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2016 - 2022, Advanced Micro Devices, Inc. 
Copyright (C) 2018, The University of Texas at Austin Redistribution and use in source and binary forms, with or without @@ -36,6 +36,14 @@ #include "immintrin.h" #include "blis.h" +/* Union data structure to access AVX registers + One 128-bit AVX register holds 8 SP elements. */ +typedef union +{ + __m128 v; + float f[4] __attribute__((aligned(64))); +} v4sf_t; + /* Union data structure to access AVX registers One 256-bit AVX register holds 8 SP elements. */ typedef union @@ -44,6 +52,14 @@ typedef union float f[8] __attribute__((aligned(64))); } v8sf_t; +/* Union data structure to access AVX registers +* One 128-bit AVX register holds 4 DP elements. */ +typedef union +{ + __m128d v; + double d[2] __attribute__((aligned(64))); +} v2df_t; + /* Union data structure to access AVX registers * One 256-bit AVX register holds 4 DP elements. */ typedef union @@ -78,11 +94,7 @@ void bli_sdotxv_zen_int float* restrict y0; float rho0; - v8sf_t rho0v, rho1v, rho2v, rho3v; - v8sf_t x0v, y0v; - v8sf_t x1v, y1v; - v8sf_t x2v, y2v; - v8sf_t x3v, y3v; + v8sf_t rhov[4], xv[4], yv[4]; // If beta is zero, initialize rho1 to zero instead of scaling // rho by beta (in case rho contains NaN or Inf). @@ -117,45 +129,55 @@ void bli_sdotxv_zen_int y0 = y; // Initialize the unrolled iterations' rho vectors to zero. - rho0v.v = _mm256_setzero_ps(); - rho1v.v = _mm256_setzero_ps(); - rho2v.v = _mm256_setzero_ps(); - rho3v.v = _mm256_setzero_ps(); + rhov[0].v = _mm256_setzero_ps(); + rhov[1].v = _mm256_setzero_ps(); + rhov[2].v = _mm256_setzero_ps(); + rhov[3].v = _mm256_setzero_ps(); for ( i = 0; i < n_viter; ++i ) { // Load the x and y input vector elements. - x0v.v = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); - y0v.v = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); + xv[0].v = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); + yv[0].v = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); - x1v.v = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); - y1v.v = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); + xv[1].v = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); + yv[1].v = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); - x2v.v = _mm256_loadu_ps( x0 + 2*n_elem_per_reg ); - y2v.v = _mm256_loadu_ps( y0 + 2*n_elem_per_reg ); + xv[2].v = _mm256_loadu_ps( x0 + 2*n_elem_per_reg ); + yv[2].v = _mm256_loadu_ps( y0 + 2*n_elem_per_reg ); - x3v.v = _mm256_loadu_ps( x0 + 3*n_elem_per_reg ); - y3v.v = _mm256_loadu_ps( y0 + 3*n_elem_per_reg ); + xv[3].v = _mm256_loadu_ps( x0 + 3*n_elem_per_reg ); + yv[3].v = _mm256_loadu_ps( y0 + 3*n_elem_per_reg ); // Compute the element-wise product of the x and y vectors, // storing in the corresponding rho vectors. - rho0v.v = _mm256_fmadd_ps( x0v.v, y0v.v, rho0v.v ); - rho1v.v = _mm256_fmadd_ps( x1v.v, y1v.v, rho1v.v ); - rho2v.v = _mm256_fmadd_ps( x2v.v, y2v.v, rho2v.v ); - rho3v.v = _mm256_fmadd_ps( x3v.v, y3v.v, rho3v.v ); + rhov[0].v = _mm256_fmadd_ps( xv[0].v, yv[0].v, rhov[0].v ); + rhov[1].v = _mm256_fmadd_ps( xv[1].v, yv[1].v, rhov[1].v ); + rhov[2].v = _mm256_fmadd_ps( xv[2].v, yv[2].v, rhov[2].v ); + rhov[3].v = _mm256_fmadd_ps( xv[3].v, yv[3].v, rhov[3].v ); x0 += ( n_elem_per_reg * n_iter_unroll ); y0 += ( n_elem_per_reg * n_iter_unroll ); } // Accumulate the unrolled rho vectors into a single vector. 
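/* Editor's note on the hunk that follows: the old code summed the eight
   single-precision partial sums through the union's scalar members; the new
   code reduces them in vector registers.  The two 128-bit halves of rhov[0]
   are extracted and added, then _mm_permute_ps( inter0.v, 14 ) (14 ==
   _MM_SHUFFLE(0,0,3,2)) moves elements 2 and 3 down into lanes 0 and 1 so a
   second _mm_add_ps leaves the running totals in the bottom two lanes; a
   single scalar add then finishes the reduction. */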
- rho0v.v += rho1v.v; - rho0v.v += rho2v.v; - rho0v.v += rho3v.v; + rhov[0].v = _mm256_add_ps(rhov[0].v,rhov[1].v); + rhov[0].v = _mm256_add_ps(rhov[0].v,rhov[2].v); + rhov[0].v = _mm256_add_ps(rhov[0].v,rhov[3].v); + + v4sf_t inter0, inter1; + + inter0.v = _mm256_extractf128_ps(rhov[0].v,0); + inter1.v = _mm256_extractf128_ps(rhov[0].v,1); + + inter0.v = _mm_add_ps(inter0.v, inter1.v); + + inter1.v = _mm_permute_ps(inter0.v, 14); + + inter0.v = _mm_add_ps(inter0.v,inter1.v); // Accumulate the final rho vector into a single scalar result. - rho0 = rho0v.f[0] + rho0v.f[1] + rho0v.f[2] + rho0v.f[3] + - rho0v.f[4] + rho0v.f[5] + rho0v.f[6] + rho0v.f[7]; + rho0 = inter0.f[0] + inter0.f[1]; // Issue vzeroupper instruction to clear upper lanes of ymm registers. // This avoids a performance penalty caused by false dependencies when @@ -206,12 +228,8 @@ void bli_ddotxv_zen_int double* restrict y0; double rho0; - v4df_t rho0v, rho1v, rho2v, rho3v; - v4df_t x0v, y0v; - v4df_t x1v, y1v; - v4df_t x2v, y2v; - v4df_t x3v, y3v; - + v4df_t rhov[4], xv[4], yv[4]; + // If beta is zero, initialize rho1 to zero instead of scaling // rho by beta (in case rho contains NaN or Inf). if ( PASTEMAC(d,eq0)( *beta ) ) @@ -245,44 +263,51 @@ void bli_ddotxv_zen_int y0 = y; // Initialize the unrolled iterations' rho vectors to zero. - rho0v.v = _mm256_setzero_pd(); - rho1v.v = _mm256_setzero_pd(); - rho2v.v = _mm256_setzero_pd(); - rho3v.v = _mm256_setzero_pd(); + rhov[0].v = _mm256_setzero_pd(); + rhov[1].v = _mm256_setzero_pd(); + rhov[2].v = _mm256_setzero_pd(); + rhov[3].v = _mm256_setzero_pd(); for ( i = 0; i < n_viter; ++i ) { // Load the x and y input vector elements. - x0v.v = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); - y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + xv[0].v = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + yv[0].v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); - x1v.v = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); - y1v.v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + xv[1].v = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); + yv[1].v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); - x2v.v = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); - y2v.v = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); + xv[2].v = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); + yv[2].v = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); - x3v.v = _mm256_loadu_pd( x0 + 3*n_elem_per_reg ); - y3v.v = _mm256_loadu_pd( y0 + 3*n_elem_per_reg ); + xv[3].v = _mm256_loadu_pd( x0 + 3*n_elem_per_reg ); + yv[3].v = _mm256_loadu_pd( y0 + 3*n_elem_per_reg ); // Compute the element-wise product of the x and y vectors, // storing in the corresponding rho vectors. - rho0v.v = _mm256_fmadd_pd( x0v.v, y0v.v, rho0v.v ); - rho1v.v = _mm256_fmadd_pd( x1v.v, y1v.v, rho1v.v ); - rho2v.v = _mm256_fmadd_pd( x2v.v, y2v.v, rho2v.v ); - rho3v.v = _mm256_fmadd_pd( x3v.v, y3v.v, rho3v.v ); + rhov[0].v = _mm256_fmadd_pd( xv[0].v, yv[0].v, rhov[0].v ); + rhov[1].v = _mm256_fmadd_pd( xv[1].v, yv[1].v, rhov[1].v ); + rhov[2].v = _mm256_fmadd_pd( xv[2].v, yv[2].v, rhov[2].v ); + rhov[3].v = _mm256_fmadd_pd( xv[3].v, yv[3].v, rhov[3].v ); x0 += ( n_elem_per_reg * n_iter_unroll ); y0 += ( n_elem_per_reg * n_iter_unroll ); } // Accumulate the unrolled rho vectors into a single vector. 
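/* Editor's sketch (not part of the patch): a self-contained equivalent of
   the double-precision reduction in the hunk below, assuming immintrin.h is
   available; the helper name is illustrative only.  The patch itself keeps
   the intermediate in a v2df_t union rather than a helper function, but the
   arithmetic is the same. */
static inline double bli_horizontal_sum_pd_sketch( __m256d v )
{
    __m128d hi  = _mm256_extractf128_pd( v, 1 );  // upper two doubles
    __m128d lo  = _mm256_extractf128_pd( v, 0 );  // lower two doubles
    __m128d sum = _mm_add_pd( hi, lo );           // two pairwise partial sums
    // add the high lane into the low lane and return the scalar total
    return _mm_cvtsd_f64( sum ) + _mm_cvtsd_f64( _mm_unpackhi_pd( sum, sum ) );
}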
- rho0v.v += rho1v.v; - rho0v.v += rho2v.v; - rho0v.v += rho3v.v; + rhov[0].v = _mm256_add_pd(rhov[1].v,rhov[0].v); + rhov[0].v = _mm256_add_pd(rhov[2].v,rhov[0].v); + rhov[0].v = _mm256_add_pd(rhov[3].v,rhov[0].v); + + v2df_t inter1, inter2; + + inter1.v = _mm256_extractf128_pd(rhov[0].v,1); + inter2.v = _mm256_extractf128_pd(rhov[0].v,0); + + inter1.v = _mm_add_pd(inter1.v, inter2.v); // Accumulate the final rho vector into a single scalar result. - rho0 = rho0v.d[0] + rho0v.d[1] + rho0v.d[2] + rho0v.d[3]; + rho0 = inter1.d[0] + inter1.d[1]; // Issue vzeroupper instruction to clear upper lanes of ymm registers. // This avoids a performance penalty caused by false dependencies when From 6d1edca727fb1b91a6f46a7a863474286be30da0 Mon Sep 17 00:00:00 2001 From: Dipal M Zambare Date: Tue, 1 Feb 2022 10:22:58 +0530 Subject: [PATCH 069/243] Optimized CPU feature determination. We added new API to check if the CPU architecture has support for AVX instruction. This API was calling CPUID instruction every time it is invoked. However, since this information does not change at runtime, it is sufficient to determine it once and use the cached results for subsequent calls. This optimization is needed to improve performance for small size matrix vector operations. AMD-Internal: [CPUPL-2009] Change-Id: If6697e1da6dd6b7f28fbfed45215ea3fdd569c5f --- frame/base/bli_cpuid.c | 45 ++++++++++++++++++++++++++++++++++++------ 1 file changed, 39 insertions(+), 6 deletions(-) diff --git a/frame/base/bli_cpuid.c b/frame/base/bli_cpuid.c index 98ea947f3c..dfac510440 100644 --- a/frame/base/bli_cpuid.c +++ b/frame/base/bli_cpuid.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018-2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. Copyright (C) 2019, Dave Love, University of Manchester Redistribution and use in source and binary forms, with or without @@ -501,13 +501,23 @@ bool bli_cpuid_is_bulldozer return TRUE; } -bool bli_cpuid_is_avx_supported( void ) +// Check (at runtime) if AVX is supported on the current platform, this is to +// ensure that AVX kernels are not used on legacy platforms which results in crash + +// The support for AVX is checked only once (when this API is called first time) +// On subsequent calls the cached value is returned. This is achieved using +// pthread_once mechanism since this information does not change once the library +// is loaded. +static bool is_avx_supported = FALSE; + + +// Determine if the CPU has support for AVX. +void bli_cpuid_check_avx_support( void ) { uint32_t family, model, features; // Call the CPUID instruction and parse its results into a family id, - // model id, and a feature bit field. The return value encodes the - // vendor. + // model id, and a feature bit field. bli_cpuid_query( &family, &model, &features ); // Check for expected CPU features. 
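/* Editor's note: the hunk below adds the caching described in the commit
   message -- the CPUID feature query runs once under bli_pthread_once() and
   every subsequent call returns the cached flag.  A typical call site,
   taken from the bli_gemm_small() change earlier in this series, simply
   gates the optimized path:

       if ( bli_cpuid_is_avx_supported() == FALSE )
           return BLIS_NOT_YET_IMPLEMENTED;  // fall back to the reference path
*/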
@@ -515,9 +525,32 @@ bool bli_cpuid_is_avx_supported( void ) FEATURE_FMA3 | FEATURE_AVX2; - if ( !bli_cpuid_has_features( features, expected ) ) return FALSE; + if ( !bli_cpuid_has_features( features, expected ) ) + { + is_avx_supported = FALSE; + } + else + { + is_avx_supported = TRUE; + } +} - return TRUE; +static bli_pthread_once_t once_check_avx_support = BLIS_PTHREAD_ONCE_INIT; + +// Ensure that actual support determincation happens only once +void bli_cpuid_check_avx_support_once( void ) +{ +#ifndef BLIS_CONFIGURETIME_CPUID + bli_pthread_once( &once_check_avx_support, bli_cpuid_check_avx_support ); +#endif +} + +// API to check if AVX is supported or not on the current platform. +bool bli_cpuid_is_avx_supported( void ) +{ + bli_cpuid_check_avx_support_once(); + + return is_avx_supported; } #elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) From 6696f91f416ad2a1d7e05cb731a41adf86ea08cf Mon Sep 17 00:00:00 2001 From: Harihara Sudhan S Date: Fri, 28 Jan 2022 11:44:38 +0530 Subject: [PATCH 070/243] Improved DGEMV performance for column-major cases - Altered the framework to use 2 more fused kernels for better problem decomposition - Increased unroll factor in AXPYF5 and AXPYF8 kernels to improve register usage AMD-Internal: [CPUPL-1970] Change-Id: I79750235d9554466def5ff93898f832834990343 --- frame/2/gemv/bli_gemv_unf_var2_amd.c | 94 +++++- kernels/zen/1f/bli_axpyf_zen_int_5.c | 356 +++++++++++++-------- kernels/zen/1f/bli_axpyf_zen_int_8.c | 450 ++++++++++++++++++++------- 3 files changed, 653 insertions(+), 247 deletions(-) diff --git a/frame/2/gemv/bli_gemv_unf_var2_amd.c b/frame/2/gemv/bli_gemv_unf_var2_amd.c index d7f5145e31..831d906ca4 100644 --- a/frame/2/gemv/bli_gemv_unf_var2_amd.c +++ b/frame/2/gemv/bli_gemv_unf_var2_amd.c @@ -313,27 +313,87 @@ void bli_dgemv_unf_var2 } } - for ( i = 0; i < n_iter; i += f ) + dim_t fuse_factor = 8; + dim_t f_temp = 0; + + // Change the fuse factor based on + // Input size and available kernels + // This ensures that fusing is possible when the number of + // left over colums is less (better problem decomposition) + if (n < 5) fuse_factor = 4; + else if (n < 8) fuse_factor = 5; + + for (i = 0; i < n_iter; i += f) { - f = bli_determine_blocksize_dim_f( i, n_iter, BLIS_DGEMV_VAR2_FUSE ); + f = bli_determine_blocksize_dim_f(i, n_iter, fuse_factor); - A1 = a + (0 )*rs_at + (i )*cs_at; - x1 = x + (i )*incx; + A1 = a + (i)*cs_at; + x1 = x + (i)*incx; - /* y = y + alpha * A1 * x1; */ - bli_daxpyf_zen_int_16x4 - ( - conja, - conjx, - n_elem, - f, - alpha, - A1, rs_at, cs_at, - x1, incx, - y_buf, buf_incy, - cntx - ); + // Pick kernel based on problem size + switch (f) + { + case 8: + + bli_daxpyf_zen_int_8( + conja, + conjx, + n_elem, + f, + alpha, + A1, rs_at, cs_at, + x1, incx, + y_buf, buf_incy, + cntx); + + break; + default: + + if (f < 5) + { + bli_daxpyf_zen_int_16x4( + conja, + conjx, + n_elem, + f, + alpha, + A1, rs_at, cs_at, + x1, incx, + y_buf, buf_incy, + cntx); + } + else + { + bli_daxpyf_zen_int_5( + conja, + conjx, + n_elem, + f, + alpha, + A1, rs_at, cs_at, + x1, incx, + y_buf, buf_incy, + cntx); + } + } + + // Calculate the next problem size + f_temp = bli_determine_blocksize_dim_f(i + f, n_iter, fuse_factor); + + // Change fuse factor based on the next problem size + if (f_temp < fuse_factor) + { + if (f_temp < 5) + { + fuse_factor = 4; + } + else + { + fuse_factor = 5; + } + } } + if ((incy > 1) && bli_mem_is_alloc( &mem_bufY )) { //store the result from unit strided y_buf to non-unit strided Y diff --git 
a/kernels/zen/1f/bli_axpyf_zen_int_5.c b/kernels/zen/1f/bli_axpyf_zen_int_5.c index d09a85f57f..8b1f697cec 100644 --- a/kernels/zen/1f/bli_axpyf_zen_int_5.c +++ b/kernels/zen/1f/bli_axpyf_zen_int_5.c @@ -329,27 +329,13 @@ void bli_daxpyf_zen_int_5 dim_t i; - double* restrict a0; - double* restrict a1; - double* restrict a2; - double* restrict a3; - double* restrict a4; + double* restrict av[5] __attribute__((aligned(64))); double* restrict y0; - v4df_t chi0v, chi1v, chi2v, chi3v; - v4df_t chi4v; - - v4df_t a00v, a01v, a02v, a03v; - v4df_t a04v; - - v4df_t a10v, a11v, a12v, a13v; - v4df_t a14v; - - v4df_t y0v, y1v; - - double chi0, chi1, chi2, chi3; - double chi4; + v4df_t chiv[5], a_vec[20], yv[4]; + + double chi[5]; // If either dimension is zero, or if alpha is zero, return early. if ( bli_zero_dim2( m, b_n ) || bli_deq0( *alpha ) ) return; @@ -385,117 +371,241 @@ void bli_daxpyf_zen_int_5 } // At this point, we know that b_n is exactly equal to the fusing factor. - - a0 = a + 0*lda; - a1 = a + 1*lda; - a2 = a + 2*lda; - a3 = a + 3*lda; - a4 = a + 4*lda; + // av points to the 5 columns under consideration + av[0] = a + 0*lda; + av[1] = a + 1*lda; + av[2] = a + 2*lda; + av[3] = a + 3*lda; + av[4] = a + 4*lda; y0 = y; - chi0 = *( x + 0*incx ); - chi1 = *( x + 1*incx ); - chi2 = *( x + 2*incx ); - chi3 = *( x + 3*incx ); - chi4 = *( x + 4*incx ); + chi[0] = *( x + 0*incx ); + chi[1] = *( x + 1*incx ); + chi[2] = *( x + 2*incx ); + chi[3] = *( x + 3*incx ); + chi[4] = *( x + 4*incx ); // Scale each chi scalar by alpha. - bli_dscals( *alpha, chi0 ); - bli_dscals( *alpha, chi1 ); - bli_dscals( *alpha, chi2 ); - bli_dscals( *alpha, chi3 ); - bli_dscals( *alpha, chi4 ); + bli_dscals( *alpha, chi[0] ); + bli_dscals( *alpha, chi[1] ); + bli_dscals( *alpha, chi[2] ); + bli_dscals( *alpha, chi[3] ); + bli_dscals( *alpha, chi[4] ); // Broadcast the (alpha*chi?) scalars to all elements of vector registers. - chi0v.v = _mm256_broadcast_sd( &chi0 ); - chi1v.v = _mm256_broadcast_sd( &chi1 ); - chi2v.v = _mm256_broadcast_sd( &chi2 ); - chi3v.v = _mm256_broadcast_sd( &chi3 ); - chi4v.v = _mm256_broadcast_sd( &chi4 ); + chiv[0].v = _mm256_broadcast_sd( &chi[0] ); + chiv[1].v = _mm256_broadcast_sd( &chi[1] ); + chiv[2].v = _mm256_broadcast_sd( &chi[2] ); + chiv[3].v = _mm256_broadcast_sd( &chi[3] ); + chiv[4].v = _mm256_broadcast_sd( &chi[4] ); // If there are vectorized iterations, perform them with vector // instructions. if ( inca == 1 && incy == 1 ) { - for ( i = 0; (i + 7) < m; i += 8 ) + // 16 elements of the result are computed per iteration + for ( i = 0; (i + 15) < m; i += 16 ) { // Load the input values. 
- y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); - y1v.v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + yv[0].v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + yv[1].v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + yv[2].v = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); + yv[3].v = _mm256_loadu_pd( y0 + 3*n_elem_per_reg ); + + a_vec[0].v = _mm256_loadu_pd( av[0] + 0*n_elem_per_reg ); + a_vec[1].v = _mm256_loadu_pd( av[1] + 0*n_elem_per_reg ); + a_vec[2].v = _mm256_loadu_pd( av[2] + 0*n_elem_per_reg ); + a_vec[3].v = _mm256_loadu_pd( av[3] + 0*n_elem_per_reg ); + a_vec[4].v = _mm256_loadu_pd( av[4] + 0*n_elem_per_reg ); + + a_vec[5].v = _mm256_loadu_pd( av[0] + 1*n_elem_per_reg ); + a_vec[6].v = _mm256_loadu_pd( av[1] + 1*n_elem_per_reg ); + a_vec[7].v = _mm256_loadu_pd( av[2] + 1*n_elem_per_reg ); + a_vec[8].v = _mm256_loadu_pd( av[3] + 1*n_elem_per_reg ); + a_vec[9].v = _mm256_loadu_pd( av[4] + 1*n_elem_per_reg ); + + a_vec[10].v = _mm256_loadu_pd( av[0] + 2*n_elem_per_reg ); + a_vec[11].v = _mm256_loadu_pd( av[1] + 2*n_elem_per_reg ); + a_vec[12].v = _mm256_loadu_pd( av[2] + 2*n_elem_per_reg ); + a_vec[13].v = _mm256_loadu_pd( av[3] + 2*n_elem_per_reg ); + a_vec[14].v = _mm256_loadu_pd( av[4] + 2*n_elem_per_reg ); + + a_vec[15].v = _mm256_loadu_pd( av[0] + 3*n_elem_per_reg ); + a_vec[16].v = _mm256_loadu_pd( av[1] + 3*n_elem_per_reg ); + a_vec[17].v = _mm256_loadu_pd( av[2] + 3*n_elem_per_reg ); + a_vec[18].v = _mm256_loadu_pd( av[3] + 3*n_elem_per_reg ); + a_vec[19].v = _mm256_loadu_pd( av[4] + 3*n_elem_per_reg ); - a00v.v = _mm256_loadu_pd( a0 + 0*n_elem_per_reg ); - a10v.v = _mm256_loadu_pd( a0 + 1*n_elem_per_reg ); - - a01v.v = _mm256_loadu_pd( a1 + 0*n_elem_per_reg ); - a11v.v = _mm256_loadu_pd( a1 + 1*n_elem_per_reg ); - - a02v.v = _mm256_loadu_pd( a2 + 0*n_elem_per_reg ); - a12v.v = _mm256_loadu_pd( a2 + 1*n_elem_per_reg ); + // perform : y += alpha * x; + yv[0].v = _mm256_fmadd_pd( a_vec[0].v, chiv[0].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[1].v, chiv[1].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[2].v, chiv[2].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[3].v, chiv[3].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[4].v, chiv[4].v, yv[0].v ); + + yv[1].v = _mm256_fmadd_pd( a_vec[5].v, chiv[0].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[6].v, chiv[1].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[7].v, chiv[2].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[8].v, chiv[3].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[9].v, chiv[4].v, yv[1].v ); + + yv[2].v = _mm256_fmadd_pd( a_vec[10].v, chiv[0].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[11].v, chiv[1].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[12].v, chiv[2].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[13].v, chiv[3].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[14].v, chiv[4].v, yv[2].v ); + + yv[3].v = _mm256_fmadd_pd( a_vec[15].v, chiv[0].v, yv[3].v ); + yv[3].v = _mm256_fmadd_pd( a_vec[16].v, chiv[1].v, yv[3].v ); + yv[3].v = _mm256_fmadd_pd( a_vec[17].v, chiv[2].v, yv[3].v ); + yv[3].v = _mm256_fmadd_pd( a_vec[18].v, chiv[3].v, yv[3].v ); + yv[3].v = _mm256_fmadd_pd( a_vec[19].v, chiv[4].v, yv[3].v ); - a03v.v = _mm256_loadu_pd( a3 + 0*n_elem_per_reg ); - a13v.v = _mm256_loadu_pd( a3 + 1*n_elem_per_reg ); + // Store the output. 
+ _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), yv[0].v ); + _mm256_storeu_pd( (y0 + 1*n_elem_per_reg), yv[1].v ); + _mm256_storeu_pd( (y0 + 2*n_elem_per_reg), yv[2].v ); + _mm256_storeu_pd( (y0 + 3*n_elem_per_reg), yv[3].v ); + + y0 += n_elem_per_reg * 4; + av[0] += n_elem_per_reg * 4; + av[1] += n_elem_per_reg * 4; + av[2] += n_elem_per_reg * 4; + av[3] += n_elem_per_reg * 4; + av[4] += n_elem_per_reg * 4; + } - a04v.v = _mm256_loadu_pd( a4 + 0*n_elem_per_reg ); - a14v.v = _mm256_loadu_pd( a4 + 1*n_elem_per_reg ); + // 12 elements of the result are computed per iteration + for ( ; (i + 11) < m; i += 12 ) + { + // Load the input values. + yv[0].v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + yv[1].v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + yv[2].v = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); + + a_vec[0].v = _mm256_loadu_pd( av[0] + 0*n_elem_per_reg ); + a_vec[1].v = _mm256_loadu_pd( av[1] + 0*n_elem_per_reg ); + a_vec[2].v = _mm256_loadu_pd( av[2] + 0*n_elem_per_reg ); + a_vec[3].v = _mm256_loadu_pd( av[3] + 0*n_elem_per_reg ); + a_vec[4].v = _mm256_loadu_pd( av[4] + 0*n_elem_per_reg ); + + a_vec[5].v = _mm256_loadu_pd( av[0] + 1*n_elem_per_reg ); + a_vec[6].v = _mm256_loadu_pd( av[1] + 1*n_elem_per_reg ); + a_vec[7].v = _mm256_loadu_pd( av[2] + 1*n_elem_per_reg ); + a_vec[8].v = _mm256_loadu_pd( av[3] + 1*n_elem_per_reg ); + a_vec[9].v = _mm256_loadu_pd( av[4] + 1*n_elem_per_reg ); + + a_vec[10].v = _mm256_loadu_pd( av[0] + 2*n_elem_per_reg ); + a_vec[11].v = _mm256_loadu_pd( av[1] + 2*n_elem_per_reg ); + a_vec[12].v = _mm256_loadu_pd( av[2] + 2*n_elem_per_reg ); + a_vec[13].v = _mm256_loadu_pd( av[3] + 2*n_elem_per_reg ); + a_vec[14].v = _mm256_loadu_pd( av[4] + 2*n_elem_per_reg ); // perform : y += alpha * x; - y0v.v = _mm256_fmadd_pd( a00v.v, chi0v.v, y0v.v ); - y1v.v = _mm256_fmadd_pd( a10v.v, chi0v.v, y1v.v ); + yv[0].v = _mm256_fmadd_pd( a_vec[0].v, chiv[0].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[1].v, chiv[1].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[2].v, chiv[2].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[3].v, chiv[3].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[4].v, chiv[4].v, yv[0].v ); + + yv[1].v = _mm256_fmadd_pd( a_vec[5].v, chiv[0].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[6].v, chiv[1].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[7].v, chiv[2].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[8].v, chiv[3].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[9].v, chiv[4].v, yv[1].v ); + + yv[2].v = _mm256_fmadd_pd( a_vec[10].v, chiv[0].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[11].v, chiv[1].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[12].v, chiv[2].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[13].v, chiv[3].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[14].v, chiv[4].v, yv[2].v ); - y0v.v = _mm256_fmadd_pd( a01v.v, chi1v.v, y0v.v ); - y1v.v = _mm256_fmadd_pd( a11v.v, chi1v.v, y1v.v ); + // Store the output. + _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), yv[0].v ); + _mm256_storeu_pd( (y0 + 1*n_elem_per_reg), yv[1].v ); + _mm256_storeu_pd( (y0 + 2*n_elem_per_reg), yv[2].v ); + + y0 += n_elem_per_reg * 3; + av[0] += n_elem_per_reg * 3; + av[1] += n_elem_per_reg * 3; + av[2] += n_elem_per_reg * 3; + av[3] += n_elem_per_reg * 3; + av[4] += n_elem_per_reg * 3; + } - y0v.v = _mm256_fmadd_pd( a02v.v, chi2v.v, y0v.v ); - y1v.v = _mm256_fmadd_pd( a12v.v, chi2v.v, y1v.v ); + // 8 elements of the result are computed per iteration + for (; (i + 7) < m; i += 8 ) + { + // Load the input values. 
+ yv[0].v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + yv[1].v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); - y0v.v = _mm256_fmadd_pd( a03v.v, chi3v.v, y0v.v ); - y1v.v = _mm256_fmadd_pd( a13v.v, chi3v.v, y1v.v ); + a_vec[0].v = _mm256_loadu_pd( av[0] + 0*n_elem_per_reg ); + a_vec[1].v = _mm256_loadu_pd( av[1] + 0*n_elem_per_reg ); + a_vec[2].v = _mm256_loadu_pd( av[2] + 0*n_elem_per_reg ); + a_vec[3].v = _mm256_loadu_pd( av[3] + 0*n_elem_per_reg ); + a_vec[4].v = _mm256_loadu_pd( av[4] + 0*n_elem_per_reg ); - y0v.v = _mm256_fmadd_pd( a04v.v, chi4v.v, y0v.v ); - y1v.v = _mm256_fmadd_pd( a14v.v, chi4v.v, y1v.v ); + a_vec[5].v = _mm256_loadu_pd( av[0] + 1*n_elem_per_reg ); + a_vec[6].v = _mm256_loadu_pd( av[1] + 1*n_elem_per_reg ); + a_vec[7].v = _mm256_loadu_pd( av[2] + 1*n_elem_per_reg ); + a_vec[8].v = _mm256_loadu_pd( av[3] + 1*n_elem_per_reg ); + a_vec[9].v = _mm256_loadu_pd( av[4] + 1*n_elem_per_reg ); + // perform : y += alpha * x; + yv[0].v = _mm256_fmadd_pd( a_vec[0].v, chiv[0].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[1].v, chiv[1].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[2].v, chiv[2].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[3].v, chiv[3].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[4].v, chiv[4].v, yv[0].v ); + + yv[1].v = _mm256_fmadd_pd( a_vec[5].v, chiv[0].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[6].v, chiv[1].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[7].v, chiv[2].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[8].v, chiv[3].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[9].v, chiv[4].v, yv[1].v ); // Store the output. - _mm256_storeu_pd( (double *)(y0 + 0*n_elem_per_reg), y0v.v ); - _mm256_storeu_pd( (double *)(y0 + 1*n_elem_per_reg), y1v.v ); - - y0 += n_iter_unroll * n_elem_per_reg; - a0 += n_iter_unroll * n_elem_per_reg; - a1 += n_iter_unroll * n_elem_per_reg; - a2 += n_iter_unroll * n_elem_per_reg; - a3 += n_iter_unroll * n_elem_per_reg; - a4 += n_iter_unroll * n_elem_per_reg; + _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), yv[0].v ); + _mm256_storeu_pd( (y0 + 1*n_elem_per_reg), yv[1].v ); + + y0 += n_elem_per_reg * 2; + av[0] += n_elem_per_reg * 2; + av[1] += n_elem_per_reg * 2; + av[2] += n_elem_per_reg * 2; + av[3] += n_elem_per_reg * 2; + av[4] += n_elem_per_reg * 2; } + // 4 elements of the result are computed per iteration for( ; (i + 3) < m; i += 4 ) { // Load the input values. 
- y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); - - a00v.v = _mm256_loadu_pd( a0 + 0*n_elem_per_reg ); - a01v.v = _mm256_loadu_pd( a1 + 0*n_elem_per_reg ); - a02v.v = _mm256_loadu_pd( a2 + 0*n_elem_per_reg ); - a03v.v = _mm256_loadu_pd( a3 + 0*n_elem_per_reg ); - a04v.v = _mm256_loadu_pd( a4 + 0*n_elem_per_reg ); + yv[0].v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + a_vec[0].v = _mm256_loadu_pd( av[0] + 0*n_elem_per_reg ); + a_vec[1].v = _mm256_loadu_pd( av[1] + 0*n_elem_per_reg ); + a_vec[2].v = _mm256_loadu_pd( av[2] + 0*n_elem_per_reg ); + a_vec[3].v = _mm256_loadu_pd( av[3] + 0*n_elem_per_reg ); + a_vec[4].v = _mm256_loadu_pd( av[4] + 0*n_elem_per_reg ); // perform : y += alpha * x; - y0v.v = _mm256_fmadd_pd( a00v.v, chi0v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a01v.v, chi1v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a02v.v, chi2v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a03v.v, chi3v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a04v.v, chi4v.v, y0v.v ); + yv[0].v = _mm256_fmadd_pd( a_vec[0].v, chiv[0].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[1].v, chiv[1].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[2].v, chiv[2].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[3].v, chiv[3].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[4].v, chiv[4].v, yv[0].v ); // Store the output. - _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), y0v.v ); + _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), yv[0].v ); y0 += n_elem_per_reg; - a0 += n_elem_per_reg; - a1 += n_elem_per_reg; - a2 += n_elem_per_reg; - a3 += n_elem_per_reg; - a4 += n_elem_per_reg; + av[0] += n_elem_per_reg; + av[1] += n_elem_per_reg; + av[2] += n_elem_per_reg; + av[3] += n_elem_per_reg; + av[4] += n_elem_per_reg; } // If there are leftover iterations, perform them with scalar code. @@ -503,25 +613,25 @@ void bli_daxpyf_zen_int_5 { double y0c = *y0; - const double a0c = *a0; - const double a1c = *a1; - const double a2c = *a2; - const double a3c = *a3; - const double a4c = *a4; + const double a0c = *av[0]; + const double a1c = *av[1]; + const double a2c = *av[2]; + const double a3c = *av[3]; + const double a4c = *av[4]; - y0c += chi0 * a0c; - y0c += chi1 * a1c; - y0c += chi2 * a2c; - y0c += chi3 * a3c; - y0c += chi4 * a4c; + y0c += chi[0] * a0c; + y0c += chi[1] * a1c; + y0c += chi[2] * a2c; + y0c += chi[3] * a3c; + y0c += chi[4] * a4c; *y0 = y0c; - a0 += 1; - a1 += 1; - a2 += 1; - a3 += 1; - a4 += 1; + av[0] += 1; + av[1] += 1; + av[2] += 1; + av[3] += 1; + av[4] += 1; y0 += 1; } } @@ -531,25 +641,25 @@ void bli_daxpyf_zen_int_5 { double y0c = *y0; - const double a0c = *a0; - const double a1c = *a1; - const double a2c = *a2; - const double a3c = *a3; - const double a4c = *a4; + const double a0c = *av[0]; + const double a1c = *av[1]; + const double a2c = *av[2]; + const double a3c = *av[3]; + const double a4c = *av[4]; - y0c += chi0 * a0c; - y0c += chi1 * a1c; - y0c += chi2 * a2c; - y0c += chi3 * a3c; - y0c += chi4 * a4c; + y0c += chi[0] * a0c; + y0c += chi[1] * a1c; + y0c += chi[2] * a2c; + y0c += chi[3] * a3c; + y0c += chi[4] * a4c; *y0 = y0c; - a0 += inca; - a1 += inca; - a2 += inca; - a3 += inca; - a4 += inca; + av[0] += inca; + av[1] += inca; + av[2] += inca; + av[3] += inca; + av[4] += inca; y0 += incy; } @@ -1153,7 +1263,7 @@ void bli_daxpyf_zen_int_16x4 a2 += n_elem_per_reg; a3 += n_elem_per_reg; } -#if 1 + for ( ; (i + 1) < m; i += 2) { @@ -1186,7 +1296,7 @@ void bli_daxpyf_zen_int_16x4 a2 += 2; a3 += 2; } -#endif + // If there are leftover iterations, perform them with scalar code. 
for ( ; (i + 0) < m ; ++i ) { diff --git a/kernels/zen/1f/bli_axpyf_zen_int_8.c b/kernels/zen/1f/bli_axpyf_zen_int_8.c index b958600ce6..27dafb28fc 100644 --- a/kernels/zen/1f/bli_axpyf_zen_int_8.c +++ b/kernels/zen/1f/bli_axpyf_zen_int_8.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2018, The University of Texas at Austin - Copyright (C) 2016 - 2018, Advanced Micro Devices, Inc. + Copyright (C) 2016 - 2022, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -279,32 +279,19 @@ void bli_daxpyf_zen_int_8 const dim_t fuse_fac = 8; const dim_t n_elem_per_reg = 4; - const dim_t n_iter_unroll = 1; + const dim_t n_iter_unroll[4] = {4, 3, 2, 1}; dim_t i; - dim_t m_viter; - dim_t m_left; + dim_t m_viter[4]; + dim_t m_left = m; - double* restrict a0; - double* restrict a1; - double* restrict a2; - double* restrict a3; - double* restrict a4; - double* restrict a5; - double* restrict a6; - double* restrict a7; + double* restrict av[8] __attribute__((aligned(64))); double* restrict y0; - v4df_t chi0v, chi1v, chi2v, chi3v; - v4df_t chi4v, chi5v, chi6v, chi7v; + v4df_t chiv[8], a_vec[32], yv[4]; - v4df_t a0v, a1v, a2v, a3v; - v4df_t a4v, a5v, a6v, a7v; - v4df_t y0v; - - double chi0, chi1, chi2, chi3; - double chi4, chi5, chi6, chi7; + double chi[8] __attribute__((aligned(64))); // If either dimension is zero, or if alpha is zero, return early. if ( bli_zero_dim2( m, b_n ) || PASTEMAC(d,eq0)( *alpha ) ) return; @@ -343,94 +330,343 @@ void bli_daxpyf_zen_int_8 // Use the unrolling factor and the number of elements per register // to compute the number of vectorized and leftover iterations. - m_viter = ( m ) / ( n_elem_per_reg * n_iter_unroll ); - m_left = ( m ) % ( n_elem_per_reg * n_iter_unroll ); + m_viter[0] = ( m_left ) / ( n_elem_per_reg * n_iter_unroll[0] ); + m_left = ( m_left ) % ( n_elem_per_reg * n_iter_unroll[0] ); + + m_viter[1] = ( m_left ) / ( n_elem_per_reg * n_iter_unroll[1] ); + m_left = ( m_left ) % ( n_elem_per_reg * n_iter_unroll[1] ); + + m_viter[2] = ( m_left ) / ( n_elem_per_reg * n_iter_unroll[2] ); + m_left = ( m_left ) % ( n_elem_per_reg * n_iter_unroll[2] ); + + m_viter[3] = ( m_left ) / ( n_elem_per_reg * n_iter_unroll[3] ); + m_left = ( m_left ) % ( n_elem_per_reg * n_iter_unroll[3] ); // If there is anything that would interfere with our use of contiguous // vector loads/stores, override m_viter and m_left to use scalar code // for all iterations. if ( inca != 1 || incy != 1 ) { - m_viter = 0; + m_viter[0] = m_viter[1] = m_viter[2] = m_viter[3] = 0; m_left = m; } - a0 = a + 0*lda; - a1 = a + 1*lda; - a2 = a + 2*lda; - a3 = a + 3*lda; - a4 = a + 4*lda; - a5 = a + 5*lda; - a6 = a + 6*lda; - a7 = a + 7*lda; + // av points to the 8 columns under consideration + av[0] = a + 0*lda; + av[1] = a + 1*lda; + av[2] = a + 2*lda; + av[3] = a + 3*lda; + av[4] = a + 4*lda; + av[5] = a + 5*lda; + av[6] = a + 6*lda; + av[7] = a + 7*lda; y0 = y; - chi0 = *( x + 0*incx ); - chi1 = *( x + 1*incx ); - chi2 = *( x + 2*incx ); - chi3 = *( x + 3*incx ); - chi4 = *( x + 4*incx ); - chi5 = *( x + 5*incx ); - chi6 = *( x + 6*incx ); - chi7 = *( x + 7*incx ); + chi[0] = *( x + 0*incx ); + chi[1] = *( x + 1*incx ); + chi[2] = *( x + 2*incx ); + chi[3] = *( x + 3*incx ); + chi[4] = *( x + 4*incx ); + chi[5] = *( x + 5*incx ); + chi[6] = *( x + 6*incx ); + chi[7] = *( x + 7*incx ); // Scale each chi scalar by alpha. 
- PASTEMAC(d,scals)( *alpha, chi0 ); - PASTEMAC(d,scals)( *alpha, chi1 ); - PASTEMAC(d,scals)( *alpha, chi2 ); - PASTEMAC(d,scals)( *alpha, chi3 ); - PASTEMAC(d,scals)( *alpha, chi4 ); - PASTEMAC(d,scals)( *alpha, chi5 ); - PASTEMAC(d,scals)( *alpha, chi6 ); - PASTEMAC(d,scals)( *alpha, chi7 ); + PASTEMAC(d,scals)( *alpha, chi[0] ); + PASTEMAC(d,scals)( *alpha, chi[1] ); + PASTEMAC(d,scals)( *alpha, chi[2] ); + PASTEMAC(d,scals)( *alpha, chi[3] ); + PASTEMAC(d,scals)( *alpha, chi[4] ); + PASTEMAC(d,scals)( *alpha, chi[5] ); + PASTEMAC(d,scals)( *alpha, chi[6] ); + PASTEMAC(d,scals)( *alpha, chi[7] ); // Broadcast the (alpha*chi?) scalars to all elements of vector registers. - chi0v.v = _mm256_broadcast_sd( &chi0 ); - chi1v.v = _mm256_broadcast_sd( &chi1 ); - chi2v.v = _mm256_broadcast_sd( &chi2 ); - chi3v.v = _mm256_broadcast_sd( &chi3 ); - chi4v.v = _mm256_broadcast_sd( &chi4 ); - chi5v.v = _mm256_broadcast_sd( &chi5 ); - chi6v.v = _mm256_broadcast_sd( &chi6 ); - chi7v.v = _mm256_broadcast_sd( &chi7 ); + chiv[0].v = _mm256_broadcast_sd( &chi[0] ); + chiv[1].v = _mm256_broadcast_sd( &chi[1] ); + chiv[2].v = _mm256_broadcast_sd( &chi[2] ); + chiv[3].v = _mm256_broadcast_sd( &chi[3] ); + chiv[4].v = _mm256_broadcast_sd( &chi[4] ); + chiv[5].v = _mm256_broadcast_sd( &chi[5] ); + chiv[6].v = _mm256_broadcast_sd( &chi[6] ); + chiv[7].v = _mm256_broadcast_sd( &chi[7] ); // If there are vectorized iterations, perform them with vector // instructions. - for ( i = 0; i < m_viter; ++i ) + // 16 elements of the result are computed per iteration + for ( i = 0; i < m_viter[0]; ++i ) { // Load the input values. - y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); - a0v.v = _mm256_loadu_pd( a0 + 0*n_elem_per_reg ); - a1v.v = _mm256_loadu_pd( a1 + 0*n_elem_per_reg ); - a2v.v = _mm256_loadu_pd( a2 + 0*n_elem_per_reg ); - a3v.v = _mm256_loadu_pd( a3 + 0*n_elem_per_reg ); - a4v.v = _mm256_loadu_pd( a4 + 0*n_elem_per_reg ); - a5v.v = _mm256_loadu_pd( a5 + 0*n_elem_per_reg ); - a6v.v = _mm256_loadu_pd( a6 + 0*n_elem_per_reg ); - a7v.v = _mm256_loadu_pd( a7 + 0*n_elem_per_reg ); + yv[0].v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + yv[1].v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + yv[2].v = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); + yv[3].v = _mm256_loadu_pd( y0 + 3*n_elem_per_reg ); + + a_vec[0].v = _mm256_loadu_pd( av[0] + 0*n_elem_per_reg ); + a_vec[1].v = _mm256_loadu_pd( av[1] + 0*n_elem_per_reg ); + a_vec[2].v = _mm256_loadu_pd( av[2] + 0*n_elem_per_reg ); + a_vec[3].v = _mm256_loadu_pd( av[3] + 0*n_elem_per_reg ); + a_vec[4].v = _mm256_loadu_pd( av[4] + 0*n_elem_per_reg ); + a_vec[5].v = _mm256_loadu_pd( av[5] + 0*n_elem_per_reg ); + a_vec[6].v = _mm256_loadu_pd( av[6] + 0*n_elem_per_reg ); + a_vec[7].v = _mm256_loadu_pd( av[7] + 0*n_elem_per_reg ); + + a_vec[8].v = _mm256_loadu_pd( av[0] + 1*n_elem_per_reg ); + a_vec[9].v = _mm256_loadu_pd( av[1] + 1*n_elem_per_reg ); + a_vec[10].v = _mm256_loadu_pd( av[2] + 1*n_elem_per_reg ); + a_vec[11].v = _mm256_loadu_pd( av[3] + 1*n_elem_per_reg ); + a_vec[12].v = _mm256_loadu_pd( av[4] + 1*n_elem_per_reg ); + a_vec[13].v = _mm256_loadu_pd( av[5] + 1*n_elem_per_reg ); + a_vec[14].v = _mm256_loadu_pd( av[6] + 1*n_elem_per_reg ); + a_vec[15].v = _mm256_loadu_pd( av[7] + 1*n_elem_per_reg ); + + a_vec[16].v = _mm256_loadu_pd( av[0] + 2*n_elem_per_reg ); + a_vec[17].v = _mm256_loadu_pd( av[1] + 2*n_elem_per_reg ); + a_vec[18].v = _mm256_loadu_pd( av[2] + 2*n_elem_per_reg ); + a_vec[19].v = _mm256_loadu_pd( av[3] + 2*n_elem_per_reg ); + a_vec[20].v = 
_mm256_loadu_pd( av[4] + 2*n_elem_per_reg ); + a_vec[21].v = _mm256_loadu_pd( av[5] + 2*n_elem_per_reg ); + a_vec[22].v = _mm256_loadu_pd( av[6] + 2*n_elem_per_reg ); + a_vec[23].v = _mm256_loadu_pd( av[7] + 2*n_elem_per_reg ); + + a_vec[24].v = _mm256_loadu_pd( av[0] + 3*n_elem_per_reg ); + a_vec[25].v = _mm256_loadu_pd( av[1] + 3*n_elem_per_reg ); + a_vec[26].v = _mm256_loadu_pd( av[2] + 3*n_elem_per_reg ); + a_vec[27].v = _mm256_loadu_pd( av[3] + 3*n_elem_per_reg ); + a_vec[28].v = _mm256_loadu_pd( av[4] + 3*n_elem_per_reg ); + a_vec[29].v = _mm256_loadu_pd( av[5] + 3*n_elem_per_reg ); + a_vec[30].v = _mm256_loadu_pd( av[6] + 3*n_elem_per_reg ); + a_vec[31].v = _mm256_loadu_pd( av[7] + 3*n_elem_per_reg ); // perform : y += alpha * x; - y0v.v = _mm256_fmadd_pd( a0v.v, chi0v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a1v.v, chi1v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a2v.v, chi2v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a3v.v, chi3v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a4v.v, chi4v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a5v.v, chi5v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a6v.v, chi6v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a7v.v, chi7v.v, y0v.v ); + yv[0].v = _mm256_fmadd_pd( a_vec[0].v, chiv[0].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[1].v, chiv[1].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[2].v, chiv[2].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[3].v, chiv[3].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[4].v, chiv[4].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[5].v, chiv[5].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[6].v, chiv[6].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[7].v, chiv[7].v, yv[0].v ); + + yv[1].v = _mm256_fmadd_pd( a_vec[8].v, chiv[0].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[9].v, chiv[1].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[10].v, chiv[2].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[11].v, chiv[3].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[12].v, chiv[4].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[13].v, chiv[5].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[14].v, chiv[6].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[15].v, chiv[7].v, yv[1].v ); + + yv[2].v = _mm256_fmadd_pd( a_vec[16].v, chiv[0].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[17].v, chiv[1].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[18].v, chiv[2].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[19].v, chiv[3].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[20].v, chiv[4].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[21].v, chiv[5].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[22].v, chiv[6].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[23].v, chiv[7].v, yv[2].v ); + + yv[3].v = _mm256_fmadd_pd( a_vec[24].v, chiv[0].v, yv[3].v ); + yv[3].v = _mm256_fmadd_pd( a_vec[25].v, chiv[1].v, yv[3].v ); + yv[3].v = _mm256_fmadd_pd( a_vec[26].v, chiv[2].v, yv[3].v ); + yv[3].v = _mm256_fmadd_pd( a_vec[27].v, chiv[3].v, yv[3].v ); + yv[3].v = _mm256_fmadd_pd( a_vec[28].v, chiv[4].v, yv[3].v ); + yv[3].v = _mm256_fmadd_pd( a_vec[29].v, chiv[5].v, yv[3].v ); + yv[3].v = _mm256_fmadd_pd( a_vec[30].v, chiv[6].v, yv[3].v ); + yv[3].v = _mm256_fmadd_pd( a_vec[31].v, chiv[7].v, yv[3].v ); // Store the output. 
- _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), y0v.v ); + _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), yv[0].v ); + _mm256_storeu_pd( (y0 + 1*n_elem_per_reg), yv[1].v ); + _mm256_storeu_pd( (y0 + 2*n_elem_per_reg), yv[2].v ); + _mm256_storeu_pd( (y0 + 3*n_elem_per_reg), yv[3].v ); + + y0 += n_elem_per_reg * n_iter_unroll[0]; + av[0] += n_elem_per_reg * n_iter_unroll[0]; + av[1] += n_elem_per_reg * n_iter_unroll[0]; + av[2] += n_elem_per_reg * n_iter_unroll[0]; + av[3] += n_elem_per_reg * n_iter_unroll[0]; + av[4] += n_elem_per_reg * n_iter_unroll[0]; + av[5] += n_elem_per_reg * n_iter_unroll[0]; + av[6] += n_elem_per_reg * n_iter_unroll[0]; + av[7] += n_elem_per_reg * n_iter_unroll[0]; + } + + // 12 elements of the result are computed per iteration + for ( i = 0; i < m_viter[1]; ++i ) + { + // Load the input values. + yv[0].v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + yv[1].v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + yv[2].v = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); + + a_vec[0].v = _mm256_loadu_pd( av[0] + 0*n_elem_per_reg ); + a_vec[1].v = _mm256_loadu_pd( av[1] + 0*n_elem_per_reg ); + a_vec[2].v = _mm256_loadu_pd( av[2] + 0*n_elem_per_reg ); + a_vec[3].v = _mm256_loadu_pd( av[3] + 0*n_elem_per_reg ); + a_vec[4].v = _mm256_loadu_pd( av[4] + 0*n_elem_per_reg ); + a_vec[5].v = _mm256_loadu_pd( av[5] + 0*n_elem_per_reg ); + a_vec[6].v = _mm256_loadu_pd( av[6] + 0*n_elem_per_reg ); + a_vec[7].v = _mm256_loadu_pd( av[7] + 0*n_elem_per_reg ); + + a_vec[8].v = _mm256_loadu_pd( av[0] + 1*n_elem_per_reg ); + a_vec[9].v = _mm256_loadu_pd( av[1] + 1*n_elem_per_reg ); + a_vec[10].v = _mm256_loadu_pd( av[2] + 1*n_elem_per_reg ); + a_vec[11].v = _mm256_loadu_pd( av[3] + 1*n_elem_per_reg ); + a_vec[12].v = _mm256_loadu_pd( av[4] + 1*n_elem_per_reg ); + a_vec[13].v = _mm256_loadu_pd( av[5] + 1*n_elem_per_reg ); + a_vec[14].v = _mm256_loadu_pd( av[6] + 1*n_elem_per_reg ); + a_vec[15].v = _mm256_loadu_pd( av[7] + 1*n_elem_per_reg ); + + a_vec[16].v = _mm256_loadu_pd( av[0] + 2*n_elem_per_reg ); + a_vec[17].v = _mm256_loadu_pd( av[1] + 2*n_elem_per_reg ); + a_vec[18].v = _mm256_loadu_pd( av[2] + 2*n_elem_per_reg ); + a_vec[19].v = _mm256_loadu_pd( av[3] + 2*n_elem_per_reg ); + a_vec[20].v = _mm256_loadu_pd( av[4] + 2*n_elem_per_reg ); + a_vec[21].v = _mm256_loadu_pd( av[5] + 2*n_elem_per_reg ); + a_vec[22].v = _mm256_loadu_pd( av[6] + 2*n_elem_per_reg ); + a_vec[23].v = _mm256_loadu_pd( av[7] + 2*n_elem_per_reg ); + + // perform : y += alpha * x; + yv[0].v = _mm256_fmadd_pd( a_vec[0].v, chiv[0].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[1].v, chiv[1].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[2].v, chiv[2].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[3].v, chiv[3].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[4].v, chiv[4].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[5].v, chiv[5].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[6].v, chiv[6].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[7].v, chiv[7].v, yv[0].v ); + + yv[1].v = _mm256_fmadd_pd( a_vec[8].v, chiv[0].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[9].v, chiv[1].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[10].v, chiv[2].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[11].v, chiv[3].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[12].v, chiv[4].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[13].v, chiv[5].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[14].v, chiv[6].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[15].v, chiv[7].v, yv[1].v ); + + yv[2].v = _mm256_fmadd_pd( a_vec[16].v, chiv[0].v, 
yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[17].v, chiv[1].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[18].v, chiv[2].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[19].v, chiv[3].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[20].v, chiv[4].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[21].v, chiv[5].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[22].v, chiv[6].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[23].v, chiv[7].v, yv[2].v ); + + // Store the output. + _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), yv[0].v ); + _mm256_storeu_pd( (y0 + 1*n_elem_per_reg), yv[1].v ); + _mm256_storeu_pd( (y0 + 2*n_elem_per_reg), yv[2].v ); + + y0 += n_elem_per_reg * n_iter_unroll[1]; + av[0] += n_elem_per_reg * n_iter_unroll[1]; + av[1] += n_elem_per_reg * n_iter_unroll[1]; + av[2] += n_elem_per_reg * n_iter_unroll[1]; + av[3] += n_elem_per_reg * n_iter_unroll[1]; + av[4] += n_elem_per_reg * n_iter_unroll[1]; + av[5] += n_elem_per_reg * n_iter_unroll[1]; + av[6] += n_elem_per_reg * n_iter_unroll[1]; + av[7] += n_elem_per_reg * n_iter_unroll[1]; + } + + // 8 elements of the result are computed per iteration + for ( i = 0; i < m_viter[2]; ++i ) + { + // Load the input values. + yv[0].v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + yv[1].v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + + a_vec[0].v = _mm256_loadu_pd( av[0] + 0*n_elem_per_reg ); + a_vec[1].v = _mm256_loadu_pd( av[1] + 0*n_elem_per_reg ); + a_vec[2].v = _mm256_loadu_pd( av[2] + 0*n_elem_per_reg ); + a_vec[3].v = _mm256_loadu_pd( av[3] + 0*n_elem_per_reg ); + a_vec[4].v = _mm256_loadu_pd( av[4] + 0*n_elem_per_reg ); + a_vec[5].v = _mm256_loadu_pd( av[5] + 0*n_elem_per_reg ); + a_vec[6].v = _mm256_loadu_pd( av[6] + 0*n_elem_per_reg ); + a_vec[7].v = _mm256_loadu_pd( av[7] + 0*n_elem_per_reg ); + + a_vec[8].v = _mm256_loadu_pd( av[0] + 1*n_elem_per_reg ); + a_vec[9].v = _mm256_loadu_pd( av[1] + 1*n_elem_per_reg ); + a_vec[10].v = _mm256_loadu_pd( av[2] + 1*n_elem_per_reg ); + a_vec[11].v = _mm256_loadu_pd( av[3] + 1*n_elem_per_reg ); + a_vec[12].v = _mm256_loadu_pd( av[4] + 1*n_elem_per_reg ); + a_vec[13].v = _mm256_loadu_pd( av[5] + 1*n_elem_per_reg ); + a_vec[14].v = _mm256_loadu_pd( av[6] + 1*n_elem_per_reg ); + a_vec[15].v = _mm256_loadu_pd( av[7] + 1*n_elem_per_reg ); + + // perform : y += alpha * x; + yv[0].v = _mm256_fmadd_pd( a_vec[0].v, chiv[0].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[1].v, chiv[1].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[2].v, chiv[2].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[3].v, chiv[3].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[4].v, chiv[4].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[5].v, chiv[5].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[6].v, chiv[6].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[7].v, chiv[7].v, yv[0].v ); + + yv[1].v = _mm256_fmadd_pd( a_vec[8].v, chiv[0].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[9].v, chiv[1].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[10].v, chiv[2].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[11].v, chiv[3].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[12].v, chiv[4].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[13].v, chiv[5].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[14].v, chiv[6].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[15].v, chiv[7].v, yv[1].v ); + + // Store the output. 
+ _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), yv[0].v ); + _mm256_storeu_pd( (y0 + 1*n_elem_per_reg), yv[1].v ); + + y0 += n_elem_per_reg * n_iter_unroll[2]; + av[0] += n_elem_per_reg * n_iter_unroll[2]; + av[1] += n_elem_per_reg * n_iter_unroll[2]; + av[2] += n_elem_per_reg * n_iter_unroll[2]; + av[3] += n_elem_per_reg * n_iter_unroll[2]; + av[4] += n_elem_per_reg * n_iter_unroll[2]; + av[5] += n_elem_per_reg * n_iter_unroll[2]; + av[6] += n_elem_per_reg * n_iter_unroll[2]; + av[7] += n_elem_per_reg * n_iter_unroll[2]; + } + + // 4 elements of the result are computed per iteration + for ( i = 0; i < m_viter[3]; ++i ) + { + // Load the input values. + yv[0].v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + + a_vec[0].v = _mm256_loadu_pd( av[0] + 0*n_elem_per_reg ); + a_vec[1].v = _mm256_loadu_pd( av[1] + 0*n_elem_per_reg ); + a_vec[2].v = _mm256_loadu_pd( av[2] + 0*n_elem_per_reg ); + a_vec[3].v = _mm256_loadu_pd( av[3] + 0*n_elem_per_reg ); + a_vec[4].v = _mm256_loadu_pd( av[4] + 0*n_elem_per_reg ); + a_vec[5].v = _mm256_loadu_pd( av[5] + 0*n_elem_per_reg ); + a_vec[6].v = _mm256_loadu_pd( av[6] + 0*n_elem_per_reg ); + a_vec[7].v = _mm256_loadu_pd( av[7] + 0*n_elem_per_reg ); + + // perform : y += alpha * x; + yv[0].v = _mm256_fmadd_pd( a_vec[0].v, chiv[0].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[1].v, chiv[1].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[2].v, chiv[2].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[3].v, chiv[3].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[4].v, chiv[4].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[5].v, chiv[5].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[6].v, chiv[6].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[7].v, chiv[7].v, yv[0].v ); + + // Store the output. + _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), yv[0].v ); y0 += n_elem_per_reg; - a0 += n_elem_per_reg; - a1 += n_elem_per_reg; - a2 += n_elem_per_reg; - a3 += n_elem_per_reg; - a4 += n_elem_per_reg; - a5 += n_elem_per_reg; - a6 += n_elem_per_reg; - a7 += n_elem_per_reg; + av[0] += n_elem_per_reg; + av[1] += n_elem_per_reg; + av[2] += n_elem_per_reg; + av[3] += n_elem_per_reg; + av[4] += n_elem_per_reg; + av[5] += n_elem_per_reg; + av[6] += n_elem_per_reg; + av[7] += n_elem_per_reg; } // If there are leftover iterations, perform them with scalar code. 
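/* Editor's note: with n_elem_per_reg = 4 and n_iter_unroll = {4, 3, 2, 1},
   the m_viter[] computation earlier in this function tiles m into 16-, 12-,
   8- and 4-element vector passes, and the hunk below handles the scalar
   remainder.  Worked example: m = 27 gives m_viter = {1, 0, 1, 0} and
   m_left = 3, i.e. one 16-element pass, one 8-element pass and three scalar
   iterations. */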
@@ -438,34 +674,34 @@ void bli_daxpyf_zen_int_8 { double y0c = *y0; - const double a0c = *a0; - const double a1c = *a1; - const double a2c = *a2; - const double a3c = *a3; - const double a4c = *a4; - const double a5c = *a5; - const double a6c = *a6; - const double a7c = *a7; - - y0c += chi0 * a0c; - y0c += chi1 * a1c; - y0c += chi2 * a2c; - y0c += chi3 * a3c; - y0c += chi4 * a4c; - y0c += chi5 * a5c; - y0c += chi6 * a6c; - y0c += chi7 * a7c; + const double a0c = *av[0]; + const double a1c = *av[1]; + const double a2c = *av[2]; + const double a3c = *av[3]; + const double a4c = *av[4]; + const double a5c = *av[5]; + const double a6c = *av[6]; + const double a7c = *av[7]; + + y0c += chi[0] * a0c; + y0c += chi[1] * a1c; + y0c += chi[2] * a2c; + y0c += chi[3] * a3c; + y0c += chi[4] * a4c; + y0c += chi[5] * a5c; + y0c += chi[6] * a6c; + y0c += chi[7] * a7c; *y0 = y0c; - a0 += inca; - a1 += inca; - a2 += inca; - a3 += inca; - a4 += inca; - a5 += inca; - a6 += inca; - a7 += inca; + av[0] += inca; + av[1] += inca; + av[2] += inca; + av[3] += inca; + av[4] += inca; + av[5] += inca; + av[6] += inca; + av[7] += inca; y0 += incy; } } From 0792eb86082d0886708609eb4689ba45dc57f601 Mon Sep 17 00:00:00 2001 From: Meghana Vankadari Date: Wed, 2 Feb 2022 15:28:09 +0530 Subject: [PATCH 071/243] Fixed a bug in deriving dimensions from objects in gemm_front files Change-Id: I1f796c3a7ce6efacb6ef64651a7818b7ee38c6bb --- frame/3/gemm/bli_gemm_front.c | 12 ++++-------- frame/3/gemm/bli_gemm_front_amd.c | 2 +- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/frame/3/gemm/bli_gemm_front.c b/frame/3/gemm/bli_gemm_front.c index 972a7a782f..46e163c026 100644 --- a/frame/3/gemm/bli_gemm_front.c +++ b/frame/3/gemm/bli_gemm_front.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -174,10 +174,6 @@ void bli_gemm_front bli_obj_swap_pack_schemas( &a_local, &b_local ); } - dim_t m_dim_local = bli_obj_length( &c_local ); - dim_t n_dim_local = bli_obj_width( &c_local ); - dim_t k_dim_local = bli_obj_width( &a_local ); - // Parse and interpret the contents of the rntm_t object to properly // set the ways of parallelism for each loop, and then make any // additional modifications necessary for the current operation. @@ -185,9 +181,9 @@ void bli_gemm_front ( BLIS_GEMM, BLIS_LEFT, // ignored for gemm/hemm/symm - m_dim_local, - n_dim_local, - k_dim_local, + bli_obj_length( &c_local ), + bli_obj_width( &c_local ), + bli_obj_width_after_trans( &a_local ), rntm ); diff --git a/frame/3/gemm/bli_gemm_front_amd.c b/frame/3/gemm/bli_gemm_front_amd.c index 41af62007c..a29a0bb85b 100644 --- a/frame/3/gemm/bli_gemm_front_amd.c +++ b/frame/3/gemm/bli_gemm_front_amd.c @@ -176,7 +176,7 @@ void bli_gemm_front dim_t m_dim_local = bli_obj_length( &c_local ); dim_t n_dim_local = bli_obj_width( &c_local ); - dim_t k_dim_local = bli_obj_width( &a_local ); + dim_t k_dim_local = bli_obj_width_after_trans( &a_local ); // Regression observed in sgemm native path in cases where m >= 4 * n // after BLIS_THREAD_RATIO_M updated from 2 to 1 as part of commit From 7a0ba4194f847b5487ed6bfcf535347af22941a2 Mon Sep 17 00:00:00 2001 From: "Field G. 
Van Zee" Date: Sat, 13 Nov 2021 16:39:37 -0600 Subject: [PATCH 072/243] Added support for addons. Details: - Implemented a new feature called addons, which are similar to sandboxes except that there is no requirement to define gemm or any other particular operation. - Updated configure to accept --enable-addon= or -a syntax for requesting an addon be included within a BLIS build. configure now outputs the list of enabled addons into config.mk. It also outputs the corresponding #include directives for the addons' headers to a new companion to the bli_config.h header file named bli_addon.h. Because addons may wish to make use of existing BLIS types within their own definitions, the addons' headers must be included sometime after that of bli_config.h (which currently is #included before bli_type_defs.h). This is why the #include directives needed to go into a new top-level header file rather than the existing bli_config.h file. - Added a markdown document, docs/Addons.md, to explain addons, how to build with them, and what assumptions their authors should keep in mind as they create them. - Added a gemmlike-like implementation of sandwich gemm called 'gemmd' as an addon in addon/gemmd. The code uses a 'bao_' prefix for local functions, including the user-level object and typed APIs. - Updated .gitignore so that git ignores bli_addon.h files. Change-Id: Ie7efdea366481ce25075cb2459bdbcfd52309717 --- .gitignore | 1 + Makefile | 47 ++ addon/gemmd/attic/bli_gemm_ex.c | 88 +++ addon/gemmd/bao_gemmd.c | 305 +++++++++++ addon/gemmd/bao_gemmd.h | 105 ++++ addon/gemmd/bao_gemmd_bp_var1.c | 530 ++++++++++++++++++ addon/gemmd/bao_gemmd_bp_var2.c | 602 +++++++++++++++++++++ addon/gemmd/bao_gemmd_check.c | 131 +++++ addon/gemmd/bao_gemmd_check.h | 50 ++ addon/gemmd/bao_gemmd_var.h | 126 +++++ addon/gemmd/bao_l3_packm_a.c | 330 +++++++++++ addon/gemmd/bao_l3_packm_a.h | 123 +++++ addon/gemmd/bao_l3_packm_b.c | 330 +++++++++++ addon/gemmd/bao_l3_packm_b.h | 123 +++++ addon/gemmd/bao_l3_packm_var.h | 69 +++ addon/gemmd/bao_l3_packm_var1.c | 195 +++++++ addon/gemmd/bao_l3_packm_var2.c | 245 +++++++++ addon/gemmd/bao_packm_cxk.c | 199 +++++++ addon/gemmd/bao_packm_cxk.h | 59 ++ addon/gemmd/gemmd.h | 54 ++ addon/gemmd/thread/bao_l3_decor.h | 75 +++ addon/gemmd/thread/bao_l3_decor_openmp.c | 140 +++++ addon/gemmd/thread/bao_l3_decor_openmp.h | 44 ++ addon/gemmd/thread/bao_l3_decor_pthreads.c | 220 ++++++++ addon/gemmd/thread/bao_l3_decor_pthreads.h | 47 ++ addon/gemmd/thread/bao_l3_decor_single.c | 143 +++++ addon/gemmd/thread/bao_l3_decor_single.h | 44 ++ build/bli_addon.h.in | 47 ++ build/config.mk.in | 4 + common.mk | 126 ++++- configure | 151 +++++- docs/Addons.md | 231 ++++++++ frame/base/bli_info.c | 2 +- frame/include/bli_config_macro_defs.h | 5 +- frame/include/blis.h | 8 + 35 files changed, 4965 insertions(+), 34 deletions(-) create mode 100644 addon/gemmd/attic/bli_gemm_ex.c create mode 100644 addon/gemmd/bao_gemmd.c create mode 100644 addon/gemmd/bao_gemmd.h create mode 100644 addon/gemmd/bao_gemmd_bp_var1.c create mode 100644 addon/gemmd/bao_gemmd_bp_var2.c create mode 100644 addon/gemmd/bao_gemmd_check.c create mode 100644 addon/gemmd/bao_gemmd_check.h create mode 100644 addon/gemmd/bao_gemmd_var.h create mode 100644 addon/gemmd/bao_l3_packm_a.c create mode 100644 addon/gemmd/bao_l3_packm_a.h create mode 100644 addon/gemmd/bao_l3_packm_b.c create mode 100644 addon/gemmd/bao_l3_packm_b.h create mode 100644 addon/gemmd/bao_l3_packm_var.h create mode 100644 addon/gemmd/bao_l3_packm_var1.c create mode 
100644 addon/gemmd/bao_l3_packm_var2.c create mode 100644 addon/gemmd/bao_packm_cxk.c create mode 100644 addon/gemmd/bao_packm_cxk.h create mode 100644 addon/gemmd/gemmd.h create mode 100644 addon/gemmd/thread/bao_l3_decor.h create mode 100644 addon/gemmd/thread/bao_l3_decor_openmp.c create mode 100644 addon/gemmd/thread/bao_l3_decor_openmp.h create mode 100644 addon/gemmd/thread/bao_l3_decor_pthreads.c create mode 100644 addon/gemmd/thread/bao_l3_decor_pthreads.h create mode 100644 addon/gemmd/thread/bao_l3_decor_single.c create mode 100644 addon/gemmd/thread/bao_l3_decor_single.h create mode 100644 build/bli_addon.h.in create mode 100644 docs/Addons.md diff --git a/.gitignore b/.gitignore index b3b811654a..d0de225b5c 100644 --- a/.gitignore +++ b/.gitignore @@ -31,6 +31,7 @@ config.mk bli_config.h +bli_addon.h # -- monolithic headers -- diff --git a/Makefile b/Makefile index 1658e16de2..820954e3e5 100644 --- a/Makefile +++ b/Makefile @@ -116,6 +116,7 @@ BASE_OBJ_FRAME_PATH := $(BASE_OBJ_PATH)/$(FRAME_DIR) BASE_OBJ_AOCLDTL_PATH := $(BASE_OBJ_PATH)/$(AOCLDTL_DIR) BASE_OBJ_REFKERN_PATH := $(BASE_OBJ_PATH)/$(REFKERN_DIR) BASE_OBJ_KERNELS_PATH := $(BASE_OBJ_PATH)/$(KERNELS_DIR) +BASE_OBJ_ADDON_PATH := $(BASE_OBJ_PATH)/$(ADDON_DIR) BASE_OBJ_SANDBOX_PATH := $(BASE_OBJ_PATH)/$(SANDBOX_DIR) # --- Define install target names for static libraries --- @@ -237,6 +238,9 @@ endif MK_AOCLDTL_OBJS := $(call gen-obj-paths-from-src,$(AOCLDTL_SRC_SUFS),$(MK_AOCLDTL_SRC),$(AOCLDTL_PATH),$(BASE_OBJ_AOCLDTL_PATH)) +# Generate object file paths for the addon source code. If one or more addons +# were not enabled a configure-time, this variable will we empty. +MK_ADDON_OBJS := $(call gen-obj-paths-from-src,$(ADDON_SRC_SUFS),$(MK_ADDON_SRC),$(ADDON_PATH),$(BASE_OBJ_ADDON_PATH)) # Generate object file paths for the sandbox source code. If a sandbox was not # enabled a configure-time, this variable will we empty. @@ -248,6 +252,7 @@ MK_BLIS_OBJS := $(MK_CONFIG_OBJS) \ $(MK_REFKERN_OBJS) \ $(MK_FRAME_OBJS) \ $(MK_AOCLDTL_OBJS) \ + $(MK_ADDON_OBJS) \ $(MK_SANDBOX_OBJS) # Optionally filter out the BLAS and CBLAS compatibility layer object files. @@ -588,6 +593,28 @@ else endif endef +# first argument: a configuration name from the union of config_list and +# config_name, used to look up the CFLAGS to use during compilation. +define make-c99-addon-rule +$(BASE_OBJ_ADDON_PATH)/%.o: $(ADDON_PATH)/%.$(2) $(BLIS_H_FLAT) $(ADDON_H99_FILES) $(MAKE_DEFS_MK_PATHS) +ifeq ($(ENABLE_VERBOSE),yes) + $(CC) $(call get-addon-c99flags-for,$(1)) -c $$< -o $$@ +else + @echo "Compiling $$@" $(call get-addon-c99text-for,$(1)) + @$(CC) $(call get-addon-c99flags-for,$(1)) -c $$< -o $$@ +endif +endef + +define make-cxx-addon-rule +$(BASE_OBJ_ADDON_PATH)/%.o: $(ADDON_PATH)/%.$(2) $(BLIS_H_FLAT) $(ADDON_HXX_FILES) $(MAKE_DEFS_MK_PATHS) +ifeq ($(ENABLE_VERBOSE),yes) + $(CXX) $(call get-addon-cxxflags-for,$(1)) -c $$< -o $$@ +else + @echo "Compiling $$@" $(call get-addon-cxxtext-for,$(1)) + @$(CXX) $(call get-addon-cxxflags-for,$(1)) -c $$< -o $$@ +endif +endef + # first argument: a configuration name from the union of config_list and # config_name, used to look up the CFLAGS to use during compilation. define make-c99-sandbox-rule @@ -648,6 +675,16 @@ $(foreach conf, $(CONFIG_LIST), $(eval $(call make-refkern-rule,$(conf)))) $(foreach suf, $(KERNELS_SRC_SUFS), \ $(foreach kset, $(KERNEL_LIST), $(eval $(call make-kernels-rule,$(kset),$(call get-config-for-kset,$(kset)),$(suf))))) +# Instantiate the build rule for C addon files. 
Use the CFLAGS for the +# configuration family. +$(foreach suf, $(ADDON_C99_SUFS), \ +$(foreach conf, $(CONFIG_NAME), $(eval $(call make-c99-addon-rule,$(conf),$(suf))))) + +# Instantiate the build rule for C++ addon files. Use the CFLAGS for the +# configuration family. +$(foreach suf, $(ADDON_CXX_SUFS), \ +$(foreach conf, $(CONFIG_NAME), $(eval $(call make-cxx-addon-rule,$(conf),$(suf))))) + # Instantiate the build rule for C sandbox files. Use the CFLAGS for the # configuration family. $(foreach suf, $(SANDBOX_C99_SUFS), \ @@ -1141,6 +1178,9 @@ ifeq ($(ENABLE_VERBOSE),yes) - $(FIND) $(AOCLDTL_FRAG_PATH) -name "$(FRAGMENT_MK)" | $(XARGS) $(RM_F) - $(FIND) $(REFKERN_FRAG_PATH) -name "$(FRAGMENT_MK)" | $(XARGS) $(RM_F) - $(FIND) $(KERNELS_FRAG_PATH) -name "$(FRAGMENT_MK)" | $(XARGS) $(RM_F) +ifneq ($(ADDON_LIST),) + - $(FIND) $(ADDON_FRAG_PATH) -name "$(FRAGMENT_MK)" | $(XARGS) $(RM_F) +endif ifneq ($(SANDBOX),) - $(FIND) $(SANDBOX_FRAG_PATH) -name "$(FRAGMENT_MK)" | $(XARGS) $(RM_F) endif @@ -1155,6 +1195,10 @@ else @- $(FIND) $(REFKERN_FRAG_PATH) -name "$(FRAGMENT_MK)" | $(XARGS) $(RM_F) @echo "Removing makefile fragments from $(KERNELS_FRAG_PATH)" @- $(FIND) $(KERNELS_FRAG_PATH) -name "$(FRAGMENT_MK)" | $(XARGS) $(RM_F) +ifneq ($(ADDON_LIST),) + @echo "Removing makefile fragments from $(ADDON_FRAG_PATH)" + @- $(FIND) $(ADDON_FRAG_PATH) -name "$(FRAGMENT_MK)" | $(XARGS) $(RM_F) +endif ifneq ($(SANDBOX),) @echo "Removing makefile fragments from $(SANDBOX_FRAG_PATH)" @- $(FIND) $(SANDBOX_FRAG_PATH) -name "$(FRAGMENT_MK)" | $(XARGS) $(RM_F) @@ -1275,6 +1319,7 @@ endif # IS_CONFIGURED distclean: cleanmk cleanh cleanlib cleantest ifeq ($(IS_CONFIGURED),yes) ifeq ($(ENABLE_VERBOSE),yes) + - $(RM_F) $(BLIS_ADDON_H) - $(RM_F) $(BLIS_CONFIG_H) - $(RM_F) $(CONFIG_MK_FILE) - $(RM_F) $(PC_OUT_FILE) @@ -1282,6 +1327,8 @@ ifeq ($(ENABLE_VERBOSE),yes) - $(RM_RF) $(LIB_DIR) - $(RM_RF) $(INCLUDE_DIR) else + @echo "Removing $(BLIS_ADDON_H)" + @$(RM_F) $(BLIS_ADDON_H) @echo "Removing $(BLIS_CONFIG_H)" @$(RM_F) $(BLIS_CONFIG_H) @echo "Removing $(CONFIG_MK_FILE)" diff --git a/addon/gemmd/attic/bli_gemm_ex.c b/addon/gemmd/attic/bli_gemm_ex.c new file mode 100644 index 0000000000..0f40d1cb39 --- /dev/null +++ b/addon/gemmd/attic/bli_gemm_ex.c @@ -0,0 +1,88 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +void bli_gemm_ex + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm + ) +{ + bli_init_once(); + + // A switch to easily toggle whether we use the addon implementation + // of bao_gemmd() as the implementation for bli_gemm(). (This allows for + // easy testing of bao_gemmd() via the testsuite.) + if ( 1 ) + { + const dim_t k = bli_obj_width_after_trans( a ); + const num_t dt = bli_obj_dt( c ); + obj_t d; + + bli_obj_create( dt, k, 1, 1, k, &d ); + bli_setv( &BLIS_ONE, &d ); + //bli_randv( &d ); + + bao_gemmd_ex( alpha, a, &d, b, beta, c, cntx, rntm ); + + bli_obj_free( &d ); + return; + } + + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. + rntm_t rntm_l; + if ( rntm == NULL ) { bli_rntm_init_from_global( &rntm_l ); rntm = &rntm_l; } + else { rntm_l = *rntm; rntm = &rntm_l; } + + // Obtain a valid (native) context from the gks if necessary. + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); + + // Check the operands. + if ( bli_error_checking_is_enabled() ) + bli_gemm_check( alpha, a, b, beta, c, cntx ); + + // Invoke the operation's front end. + bli_gemm_front + ( + alpha, a, b, beta, c, cntx, rntm, NULL + ); +} + diff --git a/addon/gemmd/bao_gemmd.c b/addon/gemmd/bao_gemmd.c new file mode 100644 index 0000000000..71d49806ba --- /dev/null +++ b/addon/gemmd/bao_gemmd.c @@ -0,0 +1,305 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +// +// -- Define the gemmd operation's object API ---------------------------------- +// + +void bao_gemmd + ( + obj_t* alpha, + obj_t* a, + obj_t* d, + obj_t* b, + obj_t* beta, + obj_t* c + ) +{ + bao_gemmd_ex + ( + alpha, + a, + d, + b, + beta, + c, + NULL, + NULL + ); +} + +void bao_gemmd_ex + ( + obj_t* alpha, + obj_t* a, + obj_t* d, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm + ) +{ + bli_init_once(); + + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. + rntm_t rntm_l; + if ( rntm == NULL ) { bli_rntm_init_from_global( &rntm_l ); rntm = &rntm_l; } + else { rntm_l = *rntm; rntm = &rntm_l; } + + // Obtain a valid (native) context from the gks if necessary. + // NOTE: This must be done before calling the _check() function, since + // that function assumes the context pointer is valid. + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); + + // Check parameters. + if ( bli_error_checking_is_enabled() ) + bao_gemmd_check( alpha, a, d, b, beta, c, cntx ); + + // -- bli_gemmd_front() ---------------------------------------------------- + + obj_t a_local; + obj_t b_local; + obj_t c_local; + + // If C has a zero dimension, return early. + if ( bli_obj_has_zero_dim( c ) ) + { + return; + } + + // If alpha is zero, or if A or B has a zero dimension, scale C by beta + // and return early. + if ( bli_obj_equals( alpha, &BLIS_ZERO ) || + bli_obj_has_zero_dim( a ) || + bli_obj_has_zero_dim( b ) ) + { + bli_scalm( beta, c ); + return; + } + + // Alias A, B, and C in case we need to apply transformations. + bli_obj_alias_to( a, &a_local ); + bli_obj_alias_to( b, &b_local ); + bli_obj_alias_to( c, &c_local ); + + // Induce a transposition of A if it has its transposition property set. + // Then clear the transposition bit in the object. + if ( bli_obj_has_trans( &a_local ) ) + { + bli_obj_induce_trans( &a_local ); + bli_obj_set_onlytrans( BLIS_NO_TRANSPOSE, &a_local ); + } + + // Induce a transposition of B if it has its transposition property set. + // Then clear the transposition bit in the object. + if ( bli_obj_has_trans( &b_local ) ) + { + bli_obj_induce_trans( &b_local ); + bli_obj_set_onlytrans( BLIS_NO_TRANSPOSE, &b_local ); + } + + // An optimization: If C is stored by rows and the micro-kernel prefers + // contiguous columns, or if C is stored by columns and the micro-kernel + // prefers contiguous rows, transpose the entire operation to allow the + // micro-kernel to access elements of C in its preferred manner. 
+ if ( bli_cntx_l3_vir_ukr_dislikes_storage_of( &c_local, BLIS_GEMM_UKR, cntx ) ) + { + bli_obj_swap( &a_local, &b_local ); + + bli_obj_induce_trans( &a_local ); + bli_obj_induce_trans( &b_local ); + bli_obj_induce_trans( &c_local ); + } + + // Parse and interpret the contents of the rntm_t object to properly + // set the ways of parallelism for each loop, and then make any + // additional modifications necessary for the current operation. + bli_rntm_set_ways_for_op + ( + BLIS_GEMM, + BLIS_LEFT, // ignored for gemm/hemm/symm + bli_obj_length( &c_local ), + bli_obj_width( &c_local ), + bli_obj_width( &a_local ), + rntm + ); + + // Spawn threads (if applicable), where bao_gemmd_int() is the thread entry + // point function for each thread. This also begins the process of creating + // the thrinfo_t tree, which contains thread communicators. + bao_l3_thread_decorator + ( + bao_gemmd_int, + BLIS_GEMM, // operation family id + alpha, + &a_local, + d, + &b_local, + beta, + &c_local, + cntx, + rntm + ); +} + +// +// -- Define the gemmd operation's thread entry point -------------------------- +// + +void bao_gemmd_int + ( + obj_t* alpha, + obj_t* a, + obj_t* d, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm, + thrinfo_t* thread + ) +{ + // In this function, we choose the gemmd implementation that is executed + // on each thread. + +#if 1 + // Call the block-panel algorithm that calls the kernel directly, which + // exposes edge-case handling. + bao_gemmd_bp_var1 + ( + alpha, + a, + d, + b, + beta, + c, + cntx, + rntm, + thread + ); +#else + // Call the block-panel algorithm that calls the kernel indirectly via a + // wrapper function, which hides edge-case handling. + bao_gemmd_bp_var2 + ( + alpha, + a, + d, + b, + beta, + c, + cntx, + rntm, + thread + ); +#endif +} + +// +// -- Define the gemmd operation's typed API ----------------------------------- +// + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTECH2(bao_,ch,opname) \ + ( \ + trans_t transa, \ + trans_t transb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + ctype* alpha, \ + ctype* a, inc_t rs_a, inc_t cs_a, \ + ctype* d, inc_t incd, \ + ctype* b, inc_t rs_b, inc_t cs_b, \ + ctype* beta, \ + ctype* c, inc_t rs_c, inc_t cs_c \ + ) \ +{ \ + bli_init_once(); \ +\ + /* Determine the datatype (e.g. BLIS_FLOAT, BLIS_DOUBLE, etc.) based on + the macro parameter 'ch' (e.g. s, d, etc). */ \ + const num_t dt = PASTEMAC(ch,type); \ +\ + obj_t alphao, ao, dd, bo, betao, co; \ +\ + dim_t m_a, n_a; \ + dim_t m_b, n_b; \ +\ + /* Adjust the dimensions of matrices A and B according to the transa and + transb parameters. */ \ + bli_set_dims_with_trans( transa, m, k, &m_a, &n_a ); \ + bli_set_dims_with_trans( transb, k, n, &m_b, &n_b ); \ +\ + /* Create bufferless scalar objects and attach the provided scalar pointers + to those scalar objects. */ \ + bli_obj_create_1x1_with_attached_buffer( dt, alpha, &alphao ); \ + bli_obj_create_1x1_with_attached_buffer( dt, beta, &betao ); \ +\ + /* Create bufferless matrix objects and attach the provided matrix pointers + to those matrix objects. */ \ + bli_obj_create_with_attached_buffer( dt, m_a, n_a, a, rs_a, cs_a, &ao ); \ + bli_obj_create_with_attached_buffer( dt, k, 1, d, incd, k, &dd ); \ + bli_obj_create_with_attached_buffer( dt, m_b, n_b, b, rs_b, cs_b, &bo ); \ + bli_obj_create_with_attached_buffer( dt, m, n, c, rs_c, cs_c, &co ); \ +\ + /* Set the transposition/conjugation properties of the objects for matrices + A and B. 
*/ \ + bli_obj_set_conjtrans( transa, &ao ); \ + bli_obj_set_conjtrans( transb, &bo ); \ +\ + /* Call the object interface. */ \ + PASTECH(bao_,opname) \ + ( \ + &alphao, \ + &ao, \ + &dd, \ + &bo, \ + &betao, \ + &co \ + ); \ +} + +//INSERT_GENTFUNC_BASIC0( gemmd ) +GENTFUNC( float, s, gemmd ) +GENTFUNC( double, d, gemmd ) +GENTFUNC( scomplex, c, gemmd ) +GENTFUNC( dcomplex, z, gemmd ) + diff --git a/addon/gemmd/bao_gemmd.h b/addon/gemmd/bao_gemmd.h new file mode 100644 index 0000000000..7c7466494d --- /dev/null +++ b/addon/gemmd/bao_gemmd.h @@ -0,0 +1,105 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +*/ + +// +// -- Prototype the gemmd operation's object API ------------------------------- +// + +BLIS_EXPORT_ADDON void bao_gemmd + ( + obj_t* alpha, + obj_t* a, + obj_t* d, + obj_t* b, + obj_t* beta, + obj_t* c + ); + +BLIS_EXPORT_ADDON void bao_gemmd_ex + ( + obj_t* alpha, + obj_t* a, + obj_t* d, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm + ); + +// +// -- Prototype the gemmd operation's thread entry point ----------------------- +// + +void bao_gemmd_int + ( + obj_t* alpha, + obj_t* a, + obj_t* d, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm, + thrinfo_t* thread + ); + +// +// -- Prototype the gemmd operation's typed API -------------------------------- +// + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +BLIS_EXPORT_ADDON void PASTECH2(bao_,ch,opname) \ + ( \ + trans_t transa, \ + trans_t transb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + ctype* alpha, \ + ctype* a, inc_t rs_a, inc_t cs_a, \ + ctype* d, inc_t incd, \ + ctype* b, inc_t rs_b, inc_t cs_b, \ + ctype* beta, \ + ctype* c, inc_t rs_c, inc_t cs_c \ + ); + +//INSERT_GENTPROT_BASIC0( gemmd ) +GENTPROT( float, s, gemmd ) +GENTPROT( double, d, gemmd ) +GENTPROT( scomplex, c, gemmd ) +GENTPROT( dcomplex, z, gemmd ) + diff --git a/addon/gemmd/bao_gemmd_bp_var1.c b/addon/gemmd/bao_gemmd_bp_var1.c new file mode 100644 index 0000000000..e042f1fd81 --- /dev/null +++ b/addon/gemmd/bao_gemmd_bp_var1.c @@ -0,0 +1,530 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +*/ + +#include "blis.h" + +#define FUNCPTR_T gemmd_fp + +typedef void (*FUNCPTR_T) + ( + conj_t conja, + conj_t conjb, + dim_t m, + dim_t n, + dim_t k, + void* restrict alpha, + void* restrict a, inc_t rs_a, inc_t cs_a, + void* restrict d, inc_t incd, + void* restrict b, inc_t rs_b, inc_t cs_b, + void* restrict beta, + void* restrict c, inc_t rs_c, inc_t cs_c, + cntx_t* restrict cntx, + rntm_t* restrict rntm, + thrinfo_t* restrict thread + ); + +// +// -- gemmd-like block-panel algorithm (object interface) ---------------------- +// + +// Define a function pointer array named ftypes and initialize its contents with +// the addresses of the typed functions defined below, bao_?gemmd_bp_var1(). +static FUNCPTR_T GENARRAY_PREF(ftypes,bao_,gemmd_bp_var1); + +void bao_gemmd_bp_var1 + ( + obj_t* alpha, + obj_t* a, + obj_t* d, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm, + thrinfo_t* thread + ) +{ + const num_t dt = bli_obj_dt( c ); + + const conj_t conja = bli_obj_conj_status( a ); + const conj_t conjb = bli_obj_conj_status( b ); + + const dim_t m = bli_obj_length( c ); + const dim_t n = bli_obj_width( c ); + const dim_t k = bli_obj_width( a ); + + void* restrict buf_a = bli_obj_buffer_at_off( a ); + const inc_t rs_a = bli_obj_row_stride( a ); + const inc_t cs_a = bli_obj_col_stride( a ); + + void* restrict buf_d = bli_obj_buffer_at_off( d ); + const inc_t incd = bli_obj_vector_inc( d ); + + void* restrict buf_b = bli_obj_buffer_at_off( b ); + const inc_t rs_b = bli_obj_row_stride( b ); + const inc_t cs_b = bli_obj_col_stride( b ); + + void* restrict buf_c = bli_obj_buffer_at_off( c ); + const inc_t rs_c = bli_obj_row_stride( c ); + const inc_t cs_c = bli_obj_col_stride( c ); + + void* restrict buf_alpha = bli_obj_buffer_for_1x1( dt, alpha ); + void* restrict buf_beta = bli_obj_buffer_for_1x1( dt, beta ); + + // Index into the function pointer array to extract the correct + // typed function pointer based on the chosen datatype. + FUNCPTR_T f = ftypes[dt]; + + // Invoke the function. + f + ( + conja, + conjb, + m, + n, + k, + buf_alpha, + buf_a, rs_a, cs_a, + buf_d, incd, + buf_b, rs_b, cs_b, + buf_beta, + buf_c, rs_c, cs_c, + cntx, + rntm, + thread + ); +} + +// +// -- gemmd-like block-panel algorithm (typed interface) ----------------------- +// + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTECH2(bao_,ch,varname) \ + ( \ + conj_t conja, \ + conj_t conjb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + void* restrict alpha, \ + void* restrict a, inc_t rs_a, inc_t cs_a, \ + void* restrict d, inc_t incd, \ + void* restrict b, inc_t rs_b, inc_t cs_b, \ + void* restrict beta, \ + void* restrict c, inc_t rs_c, inc_t cs_c, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + const num_t dt = PASTEMAC(ch,type); \ +\ + /* Query the context for various blocksizes. */ \ + const dim_t NR = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); \ + const dim_t MR = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); \ + const dim_t NC = bli_cntx_get_blksz_def_dt( dt, BLIS_NC, cntx ); \ + const dim_t MC = bli_cntx_get_blksz_def_dt( dt, BLIS_MC, cntx ); \ + const dim_t KC = bli_cntx_get_blksz_def_dt( dt, BLIS_KC, cntx ); \ +\ + /* Query the context for the microkernel address and cast it to its + function pointer type. */ \ + PASTECH(ch,gemm_ukr_ft) \ + gemm_ukr = bli_cntx_get_l3_nat_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \ +\ + /* Temporary C buffer for edge cases. 
Note that the strides of this + temporary buffer are set so that they match the storage of the + original C matrix. For example, if C is column-stored, ct will be + column-stored as well. */ \ + ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ + / sizeof( ctype ) ] \ + __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ + const bool col_pref = bli_cntx_l3_nat_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ + const inc_t rs_ct = ( col_pref ? 1 : NR ); \ + const inc_t cs_ct = ( col_pref ? MR : 1 ); \ +\ + /* Compute partitioning step values for each matrix of each loop. */ \ + const inc_t jcstep_c = cs_c; \ + const inc_t jcstep_b = cs_b; \ +\ + const inc_t pcstep_a = cs_a; \ + const inc_t pcstep_d = incd; \ + const inc_t pcstep_b = rs_b; \ +\ + const inc_t icstep_c = rs_c; \ + const inc_t icstep_a = rs_a; \ +\ + const inc_t jrstep_c = cs_c * NR; \ +\ + const inc_t irstep_c = rs_c * MR; \ +\ + ctype* restrict a_00 = a; \ + ctype* restrict d_00 = d; \ + ctype* restrict b_00 = b; \ + ctype* restrict c_00 = c; \ + ctype* restrict alpha_cast = alpha; \ + ctype* restrict beta_cast = beta; \ +\ + /* Make local copies of the scalars to prevent any unnecessary sharing of + cache lines between the cores' caches. */ \ + ctype alpha_local = *alpha_cast; \ + ctype beta_local = *beta_cast; \ + ctype one_local = *PASTEMAC(ch,1); \ + ctype zero_local = *PASTEMAC(ch,0); \ +\ + auxinfo_t aux; \ +\ + /* Initialize a mem_t entry for A and B. Strictly speaking, this is only + needed for the matrix we will be packing (if any), but we do it + unconditionally to be safe. */ \ + mem_t mem_a = BLIS_MEM_INITIALIZER; \ + mem_t mem_b = BLIS_MEM_INITIALIZER; \ +\ + /* Define an array of bszid_t ids, which will act as our substitute for + the cntl_t tree. */ \ + bszid_t bszids[8] = { BLIS_NC, /* 5th loop */ \ + BLIS_KC, /* 4th loop */ \ + BLIS_NO_PART, /* pack B */ \ + BLIS_MC, /* 3rd loop */ \ + BLIS_NO_PART, /* pack A */ \ + BLIS_NR, /* 2nd loop */ \ + BLIS_MR, /* 1st loop */ \ + BLIS_KR }; /* microkernel loop */ \ +\ + bszid_t* restrict bszids_jc = &bszids[0]; \ + bszid_t* restrict bszids_pc = &bszids[1]; \ + /*bszid_t* restrict bszids_pb = &bszids[2];*/ \ + bszid_t* restrict bszids_ic = &bszids[3]; \ + /*bszid_t* restrict bszids_pa = &bszids[4];*/ \ + bszid_t* restrict bszids_jr = &bszids[5]; \ + /*bszid_t* restrict bszids_ir = &bszids[6];*/ \ +\ + thrinfo_t* restrict thread_jc = NULL; \ + thrinfo_t* restrict thread_pc = NULL; \ + thrinfo_t* restrict thread_pb = NULL; \ + thrinfo_t* restrict thread_ic = NULL; \ + thrinfo_t* restrict thread_pa = NULL; \ + thrinfo_t* restrict thread_jr = NULL; \ + thrinfo_t* restrict thread_ir = NULL; \ +\ + /* Identify the current thrinfo_t node and then grow the tree. */ \ + thread_jc = thread; \ + bli_thrinfo_sup_grow( rntm, bszids_jc, thread_jc ); \ +\ + /* Compute the JC loop thread range for the current thread. */ \ + dim_t jc_start, jc_end; \ + bli_thread_range_sub( thread_jc, n, NR, FALSE, &jc_start, &jc_end ); \ + const dim_t n_local = jc_end - jc_start; \ +\ + /* Compute number of primary and leftover components of the JC loop. */ \ + /*const dim_t jc_iter = ( n_local + NC - 1 ) / NC;*/ \ + const dim_t jc_left = n_local % NC; \ +\ + /* Loop over the n dimension (NC rows/columns at a time). */ \ + for ( dim_t jj = jc_start; jj < jc_end; jj += NC ) \ + { \ + /* Calculate the thread's current JC block dimension. */ \ + const dim_t nc_cur = ( NC <= jc_end - jj ? 
NC : jc_left ); \ +\ + ctype* restrict b_jc = b_00 + jj * jcstep_b; \ + ctype* restrict c_jc = c_00 + jj * jcstep_c; \ +\ + /* Identify the current thrinfo_t node and then grow the tree. */ \ + thread_pc = bli_thrinfo_sub_node( thread_jc ); \ + bli_thrinfo_sup_grow( rntm, bszids_pc, thread_pc ); \ +\ + /* Compute the PC loop thread range for the current thread. */ \ + const dim_t pc_start = 0, pc_end = k; \ + const dim_t k_local = k; \ +\ + /* Compute number of primary and leftover components of the PC loop. */ \ + /*const dim_t pc_iter = ( k_local + KC - 1 ) / KC;*/ \ + const dim_t pc_left = k_local % KC; \ +\ + /* Loop over the k dimension (KC rows/columns at a time). */ \ + for ( dim_t pp = pc_start; pp < pc_end; pp += KC ) \ + { \ + /* Calculate the thread's current PC block dimension. */ \ + const dim_t kc_cur = ( KC <= pc_end - pp ? KC : pc_left ); \ +\ + ctype* restrict a_pc = a_00 + pp * pcstep_a; \ + ctype* restrict d_pc = d_00 + pp * pcstep_d; \ + ctype* restrict b_pc = b_jc + pp * pcstep_b; \ +\ + /* Only apply beta to the first iteration of the pc loop. */ \ + ctype* restrict beta_use = ( pp == 0 ? &beta_local : &one_local ); \ +\ + ctype* b_use; \ + inc_t rs_b_use, cs_b_use, ps_b_use; \ +\ + /* Identify the current thrinfo_t node. Note that the thrinfo_t + node will have already been created by a previous call to + bli_thrinfo_sup_grow() since bszid_t values of BLIS_NO_PART + cause the tree to grow by two (e.g. to the next bszid that is + a normal bszid_t value). */ \ + thread_pb = bli_thrinfo_sub_node( thread_pc ); \ + /*bli_thrinfo_sup_grow( rntm, bszids_pb, thread_pb );*/ \ +\ + /* Determine the packing buffer and related parameters for matrix + B. Then call the packm implementation. */ \ + PASTECH2(bao_,ch,packm_b) \ + ( \ + conjb, \ + KC, NC, \ + kc_cur, nc_cur, NR, \ + &one_local, \ + d_pc, incd, \ + b_pc, rs_b, cs_b, \ + &b_use, &rs_b_use, &cs_b_use, \ + &ps_b_use, \ + cntx, \ + rntm, \ + &mem_b, \ + thread_pb \ + ); \ +\ + /* Alias b_use so that it's clear this is our current block of + matrix B. */ \ + ctype* restrict b_pc_use = b_use; \ +\ + /* Identify the current thrinfo_t node and then grow the tree. */ \ + thread_ic = bli_thrinfo_sub_node( thread_pb ); \ + bli_thrinfo_sup_grow( rntm, bszids_ic, thread_ic ); \ +\ + /* Compute the IC loop thread range for the current thread. */ \ + dim_t ic_start, ic_end; \ + bli_thread_range_sub( thread_ic, m, MR, FALSE, &ic_start, &ic_end ); \ + const dim_t m_local = ic_end - ic_start; \ +\ + /* Compute number of primary and leftover components of the IC loop. */ \ + /*const dim_t ic_iter = ( m_local + MC - 1 ) / MC;*/ \ + const dim_t ic_left = m_local % MC; \ +\ + /* Loop over the m dimension (MC rows at a time). */ \ + for ( dim_t ii = ic_start; ii < ic_end; ii += MC ) \ + { \ + /* Calculate the thread's current IC block dimension. */ \ + const dim_t mc_cur = ( MC <= ic_end - ii ? MC : ic_left ); \ +\ + ctype* restrict a_ic = a_pc + ii * icstep_a; \ + ctype* restrict c_ic = c_jc + ii * icstep_c; \ +\ + ctype* a_use; \ + inc_t rs_a_use, cs_a_use, ps_a_use; \ +\ + /* Identify the current thrinfo_t node. Note that the thrinfo_t + node will have already been created by a previous call to + bli_thrinfo_sup_grow() since bszid_t values of BLIS_NO_PART + cause the tree to grow by two (e.g. to the next bszid that is + a normal bszid_t value). */ \ + thread_pa = bli_thrinfo_sub_node( thread_ic ); \ + /*bli_thrinfo_sup_grow( rntm, bszids_pa, thread_pa );*/ \ +\ + /* Determine the packing buffer and related parameters for matrix + A. 
Then call the packm implementation. */ \ + PASTECH2(bao_,ch,packm_a) \ + ( \ + conja, \ + MC, KC, \ + mc_cur, kc_cur, MR, \ + &one_local, \ + d_pc, incd, \ + a_ic, rs_a, cs_a, \ + &a_use, &rs_a_use, &cs_a_use, \ + &ps_a_use, \ + cntx, \ + rntm, \ + &mem_a, \ + thread_pa \ + ); \ +\ + /* Alias a_use so that it's clear this is our current block of + matrix A. */ \ + ctype* restrict a_ic_use = a_use; \ +\ + /* Identify the current thrinfo_t node and then grow the tree. */ \ + thread_jr = bli_thrinfo_sub_node( thread_pa ); \ + bli_thrinfo_sup_grow( rntm, bszids_jr, thread_jr ); \ +\ + /* Query the number of threads and thread ids for the JR loop. + NOTE: These values are only needed when computing the next + micropanel of B. */ \ + const dim_t jr_nt = bli_thread_n_way( thread_jr ); \ + const dim_t jr_tid = bli_thread_work_id( thread_jr ); \ +\ + /* Compute number of primary and leftover components of the JR loop. */ \ + dim_t jr_iter = ( nc_cur + NR - 1 ) / NR; \ + dim_t jr_left = nc_cur % NR; \ +\ + /* Compute the JR loop thread range for the current thread. */ \ + dim_t jr_start, jr_end; \ + bli_thread_range_sub( thread_jr, jr_iter, 1, FALSE, &jr_start, &jr_end ); \ +\ + /* Loop over the n dimension (NR columns at a time). */ \ + for ( dim_t j = jr_start; j < jr_end; j += 1 ) \ + { \ + const dim_t nr_cur \ + = ( bli_is_not_edge_f( j, jr_iter, jr_left ) ? NR : jr_left ); \ +\ + ctype* restrict b_jr = b_pc_use + j * ps_b_use; \ + ctype* restrict c_jr = c_ic + j * jrstep_c; \ +\ + /* Assume for now that our next panel of B to be the current panel + of B. */ \ + ctype* restrict b2 = b_jr; \ +\ + /* Identify the current thrinfo_t node. */ \ + thread_ir = bli_thrinfo_sub_node( thread_jr ); \ +\ + /* Query the number of threads and thread ids for the IR loop. + NOTE: These values are only needed when computing the next + micropanel of A. */ \ + const dim_t ir_nt = bli_thread_n_way( thread_ir ); \ + const dim_t ir_tid = bli_thread_work_id( thread_ir ); \ +\ + /* Compute number of primary and leftover components of the IR loop. */ \ + dim_t ir_iter = ( mc_cur + MR - 1 ) / MR; \ + dim_t ir_left = mc_cur % MR; \ +\ + /* Compute the IR loop thread range for the current thread. */ \ + dim_t ir_start, ir_end; \ + bli_thread_range_sub( thread_ir, ir_iter, 1, FALSE, &ir_start, &ir_end ); \ +\ + /* Loop over the m dimension (MR rows at a time). */ \ + for ( dim_t i = ir_start; i < ir_end; i += 1 ) \ + { \ + const dim_t mr_cur \ + = ( bli_is_not_edge_f( i, ir_iter, ir_left ) ? MR : ir_left ); \ +\ + ctype* restrict a_ir = a_ic_use + i * ps_a_use; \ + ctype* restrict c_ir = c_jr + i * irstep_c; \ +\ + ctype* restrict a2; \ +\ + /* Compute the addresses of the next micropanels of A and B. */ \ + a2 = bli_gemm_get_next_a_upanel( a_ir, ps_a_use, 1 ); \ + if ( bli_is_last_iter( i, ir_end, ir_tid, ir_nt ) ) \ + { \ + a2 = a_ic_use; \ + b2 = bli_gemm_get_next_b_upanel( b_jr, ps_b_use, 1 ); \ + if ( bli_is_last_iter( j, jr_end, jr_tid, jr_nt ) ) \ + b2 = b_pc_use; \ + } \ +\ + /* Save the addresses of next micropanels of A and B to the + auxinfo_t object. */ \ + bli_auxinfo_set_next_a( a2, &aux ); \ + bli_auxinfo_set_next_b( b2, &aux ); \ +\ + /* Handle interior and edge cases separately. */ \ + if ( mr_cur == MR && nr_cur == NR ) \ + { \ + /* Invoke the gemm microkernel. */ \ + gemm_ukr \ + ( \ + kc_cur, \ + &alpha_local, \ + a_ir, \ + b_jr, \ + beta_use, \ + c_ir, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ + } \ + else \ + { \ + /* Invoke the gemm microkernel. 
*/ \ + gemm_ukr \ + ( \ + kc_cur, \ + &alpha_local, \ + a_ir, \ + b_jr, \ + &zero_local, \ + ct, rs_ct, cs_ct, \ + &aux, \ + cntx \ + ); \ +\ + /* Scale the bottom edge of C and add the result from above. */ \ + PASTEMAC(ch,xpbys_mxn) \ + ( \ + mr_cur, \ + nr_cur, \ + ct, rs_ct, cs_ct, \ + beta_use, \ + c_ir, rs_c, cs_c \ + ); \ + } \ + } \ + } \ + } \ +\ + /* This barrier is needed to prevent threads from starting to pack + the next row panel of B before the current row panel is fully + computed upon. */ \ + bli_thread_barrier( thread_pb ); \ + } \ + } \ +\ + /* Release any memory that was acquired for packing matrices A and B. */ \ + PASTECH2(bao_,ch,packm_finalize_mem_a) \ + ( \ + rntm, \ + &mem_a, \ + thread_pa \ + ); \ + PASTECH2(bao_,ch,packm_finalize_mem_b) \ + ( \ + rntm, \ + &mem_b, \ + thread_pb \ + ); \ +\ +/* +PASTEMAC(ch,fprintm)( stdout, "gemmd_bp_var1: a1_packed", mr_cur, kc_cur, a_ir, rs_a_use, cs_a_use, "%5.2f", "" ); \ +PASTEMAC(ch,fprintm)( stdout, "gemmd_bp_var1: b1_packed", kc_cur, nr_cur, b_jr, rs_b_use, cs_b_use, "%5.2f", "" ); \ +PASTEMAC(ch,fprintm)( stdout, "gemmd_bp_var1: c ", mr_cur, nr_cur, c_ir, rs_c, cs_c, "%5.2f", "" ); \ +*/ \ +} + +//INSERT_GENTFUNC_BASIC0( gemmd_bp_var1 ) +GENTFUNC( float, s, gemmd_bp_var1 ) +GENTFUNC( double, d, gemmd_bp_var1 ) +GENTFUNC( scomplex, c, gemmd_bp_var1 ) +GENTFUNC( dcomplex, z, gemmd_bp_var1 ) + diff --git a/addon/gemmd/bao_gemmd_bp_var2.c b/addon/gemmd/bao_gemmd_bp_var2.c new file mode 100644 index 0000000000..a0040fec06 --- /dev/null +++ b/addon/gemmd/bao_gemmd_bp_var2.c @@ -0,0 +1,602 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +*/ + +#include "blis.h" + +#define FUNCPTR_T gemmd_fp + +typedef void (*FUNCPTR_T) + ( + conj_t conja, + conj_t conjb, + dim_t m, + dim_t n, + dim_t k, + void* restrict alpha, + void* restrict a, inc_t rs_a, inc_t cs_a, + void* restrict d, inc_t incd, + void* restrict b, inc_t rs_b, inc_t cs_b, + void* restrict beta, + void* restrict c, inc_t rs_c, inc_t cs_c, + cntx_t* restrict cntx, + rntm_t* restrict rntm, + thrinfo_t* restrict thread + ); + +// +// -- gemmd-like block-panel algorithm (object interface) ---------------------- +// + +// Define a function pointer array named ftypes and initialize its contents with +// the addresses of the typed functions defined below, bao_?gemmd_bp_var2(). +static FUNCPTR_T GENARRAY_PREF(ftypes,bao_,gemmd_bp_var2); + +void bao_gemmd_bp_var2 + ( + obj_t* alpha, + obj_t* a, + obj_t* d, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm, + thrinfo_t* thread + ) +{ + const num_t dt = bli_obj_dt( c ); + + const conj_t conja = bli_obj_conj_status( a ); + const conj_t conjb = bli_obj_conj_status( b ); + + const dim_t m = bli_obj_length( c ); + const dim_t n = bli_obj_width( c ); + const dim_t k = bli_obj_width( a ); + + void* restrict buf_a = bli_obj_buffer_at_off( a ); + const inc_t rs_a = bli_obj_row_stride( a ); + const inc_t cs_a = bli_obj_col_stride( a ); + + void* restrict buf_d = bli_obj_buffer_at_off( d ); + const inc_t incd = bli_obj_vector_inc( d ); + + void* restrict buf_b = bli_obj_buffer_at_off( b ); + const inc_t rs_b = bli_obj_row_stride( b ); + const inc_t cs_b = bli_obj_col_stride( b ); + + void* restrict buf_c = bli_obj_buffer_at_off( c ); + const inc_t rs_c = bli_obj_row_stride( c ); + const inc_t cs_c = bli_obj_col_stride( c ); + + void* restrict buf_alpha = bli_obj_buffer_for_1x1( dt, alpha ); + void* restrict buf_beta = bli_obj_buffer_for_1x1( dt, beta ); + + // Index into the function pointer array to extract the correct + // typed function pointer based on the chosen datatype. + FUNCPTR_T f = ftypes[dt]; + + // Invoke the function. + f + ( + conja, + conjb, + m, + n, + k, + buf_alpha, + buf_a, rs_a, cs_a, + buf_d, incd, + buf_b, rs_b, cs_b, + buf_beta, + buf_c, rs_c, cs_c, + cntx, + rntm, + thread + ); +} + +// +// -- gemmd-like block-panel algorithm (typed interface) ----------------------- +// + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTECH2(bao_,ch,varname) \ + ( \ + conj_t conja, \ + conj_t conjb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + void* restrict alpha, \ + void* restrict a, inc_t rs_a, inc_t cs_a, \ + void* restrict d, inc_t incd, \ + void* restrict b, inc_t rs_b, inc_t cs_b, \ + void* restrict beta, \ + void* restrict c, inc_t rs_c, inc_t cs_c, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + const num_t dt = PASTEMAC(ch,type); \ +\ + /* Query the context for various blocksizes. */ \ + const dim_t NR = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); \ + const dim_t MR = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); \ + const dim_t NC = bli_cntx_get_blksz_def_dt( dt, BLIS_NC, cntx ); \ + const dim_t MC = bli_cntx_get_blksz_def_dt( dt, BLIS_MC, cntx ); \ + const dim_t KC = bli_cntx_get_blksz_def_dt( dt, BLIS_KC, cntx ); \ +\ + /* Query the context for the microkernel address and cast it to its + function pointer type. */ \ + /* + PASTECH(ch,gemm_ukr_ft) \ + gemm_ukr = bli_cntx_get_l3_nat_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \ + */ \ +\ + /* Temporary C buffer for edge cases. 
Note that the strides of this + temporary buffer are set so that they match the storage of the + original C matrix. For example, if C is column-stored, ct will be + column-stored as well. */ \ + /* + ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ + / sizeof( ctype ) ] \ + __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ + const bool col_pref = bli_cntx_l3_nat_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ + const inc_t rs_ct = ( col_pref ? 1 : NR ); \ + const inc_t cs_ct = ( col_pref ? MR : 1 ); \ + */ \ +\ + /* Compute partitioning step values for each matrix of each loop. */ \ + const inc_t jcstep_c = cs_c; \ + const inc_t jcstep_b = cs_b; \ +\ + const inc_t pcstep_a = cs_a; \ + const inc_t pcstep_d = incd; \ + const inc_t pcstep_b = rs_b; \ +\ + const inc_t icstep_c = rs_c; \ + const inc_t icstep_a = rs_a; \ +\ + const inc_t jrstep_c = cs_c * NR; \ +\ + const inc_t irstep_c = rs_c * MR; \ +\ + ctype* restrict a_00 = a; \ + ctype* restrict d_00 = d; \ + ctype* restrict b_00 = b; \ + ctype* restrict c_00 = c; \ + ctype* restrict alpha_cast = alpha; \ + ctype* restrict beta_cast = beta; \ +\ + /* Make local copies of the scalars to prevent any unnecessary sharing of + cache lines between the cores' caches. */ \ + ctype alpha_local = *alpha_cast; \ + ctype beta_local = *beta_cast; \ + ctype one_local = *PASTEMAC(ch,1); \ + /*ctype zero_local = *PASTEMAC(ch,0);*/ \ +\ + auxinfo_t aux; \ +\ + /* Initialize a mem_t entry for A and B. Strictly speaking, this is only + needed for the matrix we will be packing (if any), but we do it + unconditionally to be safe. */ \ + mem_t mem_a = BLIS_MEM_INITIALIZER; \ + mem_t mem_b = BLIS_MEM_INITIALIZER; \ +\ + /* Define an array of bszid_t ids, which will act as our substitute for + the cntl_t tree. */ \ + bszid_t bszids[8] = { BLIS_NC, /* 5th loop */ \ + BLIS_KC, /* 4th loop */ \ + BLIS_NO_PART, /* pack B */ \ + BLIS_MC, /* 3rd loop */ \ + BLIS_NO_PART, /* pack A */ \ + BLIS_NR, /* 2nd loop */ \ + BLIS_MR, /* 1st loop */ \ + BLIS_KR }; /* microkernel loop */ \ +\ + bszid_t* restrict bszids_jc = &bszids[0]; \ + bszid_t* restrict bszids_pc = &bszids[1]; \ + /*bszid_t* restrict bszids_pb = &bszids[2];*/ \ + bszid_t* restrict bszids_ic = &bszids[3]; \ + /*bszid_t* restrict bszids_pa = &bszids[4];*/ \ + bszid_t* restrict bszids_jr = &bszids[5]; \ + /*bszid_t* restrict bszids_ir = &bszids[6];*/ \ +\ + thrinfo_t* restrict thread_jc = NULL; \ + thrinfo_t* restrict thread_pc = NULL; \ + thrinfo_t* restrict thread_pb = NULL; \ + thrinfo_t* restrict thread_ic = NULL; \ + thrinfo_t* restrict thread_pa = NULL; \ + thrinfo_t* restrict thread_jr = NULL; \ + thrinfo_t* restrict thread_ir = NULL; \ +\ + /* Identify the current thrinfo_t node and then grow the tree. */ \ + thread_jc = thread; \ + bli_thrinfo_sup_grow( rntm, bszids_jc, thread_jc ); \ +\ + /* Compute the JC loop thread range for the current thread. */ \ + dim_t jc_start, jc_end; \ + bli_thread_range_sub( thread_jc, n, NR, FALSE, &jc_start, &jc_end ); \ + const dim_t n_local = jc_end - jc_start; \ +\ + /* Compute number of primary and leftover components of the JC loop. */ \ + /*const dim_t jc_iter = ( n_local + NC - 1 ) / NC;*/ \ + const dim_t jc_left = n_local % NC; \ +\ + /* Loop over the n dimension (NC rows/columns at a time). */ \ + for ( dim_t jj = jc_start; jj < jc_end; jj += NC ) \ + { \ + /* Calculate the thread's current JC block dimension. */ \ + const dim_t nc_cur = ( NC <= jc_end - jj ? 
NC : jc_left ); \ +\ + ctype* restrict b_jc = b_00 + jj * jcstep_b; \ + ctype* restrict c_jc = c_00 + jj * jcstep_c; \ +\ + /* Identify the current thrinfo_t node and then grow the tree. */ \ + thread_pc = bli_thrinfo_sub_node( thread_jc ); \ + bli_thrinfo_sup_grow( rntm, bszids_pc, thread_pc ); \ +\ + /* Compute the PC loop thread range for the current thread. */ \ + const dim_t pc_start = 0, pc_end = k; \ + const dim_t k_local = k; \ +\ + /* Compute number of primary and leftover components of the PC loop. */ \ + /*const dim_t pc_iter = ( k_local + KC - 1 ) / KC;*/ \ + const dim_t pc_left = k_local % KC; \ +\ + /* Loop over the k dimension (KC rows/columns at a time). */ \ + for ( dim_t pp = pc_start; pp < pc_end; pp += KC ) \ + { \ + /* Calculate the thread's current PC block dimension. */ \ + const dim_t kc_cur = ( KC <= pc_end - pp ? KC : pc_left ); \ +\ + ctype* restrict a_pc = a_00 + pp * pcstep_a; \ + ctype* restrict d_pc = d_00 + pp * pcstep_d; \ + ctype* restrict b_pc = b_jc + pp * pcstep_b; \ +\ + /* Only apply beta to the first iteration of the pc loop. */ \ + ctype* restrict beta_use = ( pp == 0 ? &beta_local : &one_local ); \ +\ + ctype* b_use; \ + inc_t rs_b_use, cs_b_use, ps_b_use; \ +\ + /* Identify the current thrinfo_t node. Note that the thrinfo_t + node will have already been created by a previous call to + bli_thrinfo_sup_grow() since bszid_t values of BLIS_NO_PART + cause the tree to grow by two (e.g. to the next bszid that is + a normal bszid_t value). */ \ + thread_pb = bli_thrinfo_sub_node( thread_pc ); \ + /*bli_thrinfo_sup_grow( rntm, bszids_pb, thread_pb );*/ \ +\ + /* Determine the packing buffer and related parameters for matrix + B. Then call the packm implementation. */ \ + PASTECH2(bao_,ch,packm_b) \ + ( \ + conjb, \ + KC, NC, \ + kc_cur, nc_cur, NR, \ + &one_local, \ + d_pc, incd, \ + b_pc, rs_b, cs_b, \ + &b_use, &rs_b_use, &cs_b_use, \ + &ps_b_use, \ + cntx, \ + rntm, \ + &mem_b, \ + thread_pb \ + ); \ +\ + /* Alias b_use so that it's clear this is our current block of + matrix B. */ \ + ctype* restrict b_pc_use = b_use; \ +\ + /* Identify the current thrinfo_t node and then grow the tree. */ \ + thread_ic = bli_thrinfo_sub_node( thread_pb ); \ + bli_thrinfo_sup_grow( rntm, bszids_ic, thread_ic ); \ +\ + /* Compute the IC loop thread range for the current thread. */ \ + dim_t ic_start, ic_end; \ + bli_thread_range_sub( thread_ic, m, MR, FALSE, &ic_start, &ic_end ); \ + const dim_t m_local = ic_end - ic_start; \ +\ + /* Compute number of primary and leftover components of the IC loop. */ \ + /*const dim_t ic_iter = ( m_local + MC - 1 ) / MC;*/ \ + const dim_t ic_left = m_local % MC; \ +\ + /* Loop over the m dimension (MC rows at a time). */ \ + for ( dim_t ii = ic_start; ii < ic_end; ii += MC ) \ + { \ + /* Calculate the thread's current IC block dimension. */ \ + const dim_t mc_cur = ( MC <= ic_end - ii ? MC : ic_left ); \ +\ + ctype* restrict a_ic = a_pc + ii * icstep_a; \ + ctype* restrict c_ic = c_jc + ii * icstep_c; \ +\ + ctype* a_use; \ + inc_t rs_a_use, cs_a_use, ps_a_use; \ +\ + /* Identify the current thrinfo_t node. Note that the thrinfo_t + node will have already been created by a previous call to + bli_thrinfo_sup_grow() since bszid_t values of BLIS_NO_PART + cause the tree to grow by two (e.g. to the next bszid that is + a normal bszid_t value). */ \ + thread_pa = bli_thrinfo_sub_node( thread_ic ); \ + /*bli_thrinfo_sup_grow( rntm, bszids_pa, thread_pa );*/ \ +\ + /* Determine the packing buffer and related parameters for matrix + A. 
Then call the packm implementation. */ \ + PASTECH2(bao_,ch,packm_a) \ + ( \ + conja, \ + MC, KC, \ + mc_cur, kc_cur, MR, \ + &one_local, \ + d_pc, incd, \ + a_ic, rs_a, cs_a, \ + &a_use, &rs_a_use, &cs_a_use, \ + &ps_a_use, \ + cntx, \ + rntm, \ + &mem_a, \ + thread_pa \ + ); \ +\ + /* Alias a_use so that it's clear this is our current block of + matrix A. */ \ + ctype* restrict a_ic_use = a_use; \ +\ + /* Identify the current thrinfo_t node and then grow the tree. */ \ + thread_jr = bli_thrinfo_sub_node( thread_pa ); \ + bli_thrinfo_sup_grow( rntm, bszids_jr, thread_jr ); \ +\ + /* Query the number of threads and thread ids for the JR loop. + NOTE: These values are only needed when computing the next + micropanel of B. */ \ + const dim_t jr_nt = bli_thread_n_way( thread_jr ); \ + const dim_t jr_tid = bli_thread_work_id( thread_jr ); \ +\ + /* Compute number of primary and leftover components of the JR loop. */ \ + dim_t jr_iter = ( nc_cur + NR - 1 ) / NR; \ + dim_t jr_left = nc_cur % NR; \ +\ + /* Compute the JR loop thread range for the current thread. */ \ + dim_t jr_start, jr_end; \ + bli_thread_range_sub( thread_jr, jr_iter, 1, FALSE, &jr_start, &jr_end ); \ +\ + /* Loop over the n dimension (NR columns at a time). */ \ + for ( dim_t j = jr_start; j < jr_end; j += 1 ) \ + { \ + const dim_t nr_cur \ + = ( bli_is_not_edge_f( j, jr_iter, jr_left ) ? NR : jr_left ); \ +\ + ctype* restrict b_jr = b_pc_use + j * ps_b_use; \ + ctype* restrict c_jr = c_ic + j * jrstep_c; \ +\ + /* Assume for now that our next panel of B to be the current panel + of B. */ \ + ctype* restrict b2 = b_jr; \ +\ + /* Identify the current thrinfo_t node. */ \ + thread_ir = bli_thrinfo_sub_node( thread_jr ); \ +\ + /* Query the number of threads and thread ids for the IR loop. + NOTE: These values are only needed when computing the next + micropanel of A. */ \ + const dim_t ir_nt = bli_thread_n_way( thread_ir ); \ + const dim_t ir_tid = bli_thread_work_id( thread_ir ); \ +\ + /* Compute number of primary and leftover components of the IR loop. */ \ + dim_t ir_iter = ( mc_cur + MR - 1 ) / MR; \ + dim_t ir_left = mc_cur % MR; \ +\ + /* Compute the IR loop thread range for the current thread. */ \ + dim_t ir_start, ir_end; \ + bli_thread_range_sub( thread_ir, ir_iter, 1, FALSE, &ir_start, &ir_end ); \ +\ + /* Loop over the m dimension (MR rows at a time). */ \ + for ( dim_t i = ir_start; i < ir_end; i += 1 ) \ + { \ + const dim_t mr_cur \ + = ( bli_is_not_edge_f( i, ir_iter, ir_left ) ? MR : ir_left ); \ +\ + ctype* restrict a_ir = a_ic_use + i * ps_a_use; \ + ctype* restrict c_ir = c_jr + i * irstep_c; \ +\ + ctype* restrict a2; \ +\ + /* Compute the addresses of the next micropanels of A and B. */ \ + a2 = bli_gemm_get_next_a_upanel( a_ir, ps_a_use, 1 ); \ + if ( bli_is_last_iter( i, ir_end, ir_tid, ir_nt ) ) \ + { \ + a2 = a_ic_use; \ + b2 = bli_gemm_get_next_b_upanel( b_jr, ps_b_use, 1 ); \ + if ( bli_is_last_iter( j, jr_end, jr_tid, jr_nt ) ) \ + b2 = b_pc_use; \ + } \ +\ + /* Save the addresses of next micropanels of A and B to the + auxinfo_t object. */ \ + bli_auxinfo_set_next_a( a2, &aux ); \ + bli_auxinfo_set_next_b( b2, &aux ); \ +\ + /* Call a wrapper to the kernel (which handles edge cases). 
*/ \ + PASTECH2(bao_,ch,gemm_kernel) \ + ( \ + MR, \ + NR, \ + mr_cur, \ + nr_cur, \ + kc_cur, \ + &alpha_local, \ + a_ir, rs_a_use, cs_a_use, \ + b_jr, rs_b_use, cs_b_use, \ + beta_use, \ + c_ir, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ + } \ + } \ + } \ +\ + /* This barrier is needed to prevent threads from starting to pack + the next row panel of B before the current row panel is fully + computed upon. */ \ + bli_thread_barrier( thread_pb ); \ + } \ + } \ +\ + /* Release any memory that was acquired for packing matrices A and B. */ \ + PASTECH2(bao_,ch,packm_finalize_mem_a) \ + ( \ + rntm, \ + &mem_a, \ + thread_pa \ + ); \ + PASTECH2(bao_,ch,packm_finalize_mem_b) \ + ( \ + rntm, \ + &mem_b, \ + thread_pb \ + ); \ +\ +/* +PASTEMAC(ch,fprintm)( stdout, "gemmd_bp_var2: a1_packed", mr_cur, kc_cur, a_ir, rs_a_use, cs_a_use, "%5.2f", "" ); \ +PASTEMAC(ch,fprintm)( stdout, "gemmd_bp_var2: b1_packed", kc_cur, nr_cur, b_jr, rs_b_use, cs_b_use, "%5.2f", "" ); \ +PASTEMAC(ch,fprintm)( stdout, "gemmd_bp_var2: c ", mr_cur, nr_cur, c_ir, rs_c, cs_c, "%5.2f", "" ); \ +*/ \ +} + +//INSERT_GENTFUNC_BASIC0( gemmd_bp_var2 ) +GENTFUNC( float, s, gemmd_bp_var2 ) +GENTFUNC( double, d, gemmd_bp_var2 ) +GENTFUNC( scomplex, c, gemmd_bp_var2 ) +GENTFUNC( dcomplex, z, gemmd_bp_var2 ) + +// +// -- gemm-like microkernel wrapper -------------------------------------------- +// + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTECH2(bao_,ch,varname) \ + ( \ + const dim_t MR, \ + const dim_t NR, \ + dim_t mr_cur, \ + dim_t nr_cur, \ + dim_t kc_cur, \ + ctype* restrict alpha, \ + ctype* restrict a, inc_t rs_a, inc_t cs_a, \ + ctype* restrict b, inc_t rs_b, inc_t cs_b, \ + ctype* restrict beta, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + auxinfo_t* restrict aux, \ + cntx_t* restrict cntx \ + ) \ +{ \ + /* Infer the datatype from the ctype. */ \ + const num_t dt = PASTEMAC(ch,type); \ +\ + /* Query the context for the microkernel address and cast it to its + function pointer type. */ \ + PASTECH(ch,gemm_ukr_ft) \ + gemm_ukr = bli_cntx_get_l3_nat_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \ +\ + /* Temporary C buffer for edge cases. Note that the strides of this + temporary buffer are set so that they match the storage of the + original C matrix. For example, if C is column-stored, ct will be + column-stored as well. */ \ + ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ + / sizeof( ctype ) ] \ + __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ + const bool col_pref = bli_cntx_l3_nat_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ + const inc_t rs_ct = ( col_pref ? 1 : NR ); \ + const inc_t cs_ct = ( col_pref ? MR : 1 ); \ +\ + ctype zero = *PASTEMAC(ch,0); \ +\ + /* Handle interior and edge cases separately. */ \ + if ( mr_cur == MR && nr_cur == NR ) \ + { \ + /* Invoke the gemm microkernel. */ \ + gemm_ukr \ + ( \ + kc_cur, \ + alpha, \ + a, \ + b, \ + beta, \ + c, rs_c, cs_c, \ + aux, \ + cntx \ + ); \ + } \ + else \ + { \ + /* Invoke the gemm microkernel. */ \ + gemm_ukr \ + ( \ + kc_cur, \ + alpha, \ + a, \ + b, \ + &zero, \ + ct, rs_ct, cs_ct, \ + aux, \ + cntx \ + ); \ +\ + /* Scale the bottom edge of C and add the result from above. 
*/ \ + PASTEMAC(ch,xpbys_mxn) \ + ( \ + mr_cur, \ + nr_cur, \ + ct, rs_ct, cs_ct, \ + beta, \ + c, rs_c, cs_c \ + ); \ + } \ +} + +//INSERT_GENTFUNC_BASIC0( gemm_kernel ) +GENTFUNC( float, s, gemm_kernel ) +GENTFUNC( double, d, gemm_kernel ) +GENTFUNC( scomplex, c, gemm_kernel ) +GENTFUNC( dcomplex, z, gemm_kernel ) + diff --git a/addon/gemmd/bao_gemmd_check.c b/addon/gemmd/bao_gemmd_check.c new file mode 100644 index 0000000000..864e9a1acb --- /dev/null +++ b/addon/gemmd/bao_gemmd_check.c @@ -0,0 +1,131 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +void bao_gemmd_check + ( + obj_t* alpha, + obj_t* a, + obj_t* d, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx + ) +{ + err_t e_val; + + // Check object datatypes. + + e_val = bli_check_noninteger_object( alpha ); + bli_check_error_code( e_val ); + + e_val = bli_check_noninteger_object( beta ); + bli_check_error_code( e_val ); + + e_val = bli_check_floating_object( a ); + bli_check_error_code( e_val ); + + e_val = bli_check_floating_object( d ); + bli_check_error_code( e_val ); + + e_val = bli_check_floating_object( b ); + bli_check_error_code( e_val ); + + e_val = bli_check_floating_object( c ); + bli_check_error_code( e_val ); + + // Check scalar/vector/matrix type. + + e_val = bli_check_scalar_object( alpha ); + bli_check_error_code( e_val ); + + e_val = bli_check_scalar_object( beta ); + bli_check_error_code( e_val ); + + e_val = bli_check_matrix_object( a ); + bli_check_error_code( e_val ); + + e_val = bli_check_vector_object( d ); + bli_check_error_code( e_val ); + + e_val = bli_check_matrix_object( b ); + bli_check_error_code( e_val ); + + e_val = bli_check_matrix_object( c ); + bli_check_error_code( e_val ); + + // Check object buffers (for non-NULLness). 
+ + e_val = bli_check_object_buffer( alpha ); + bli_check_error_code( e_val ); + + e_val = bli_check_object_buffer( a ); + bli_check_error_code( e_val ); + + e_val = bli_check_object_buffer( d ); + bli_check_error_code( e_val ); + + e_val = bli_check_object_buffer( b ); + bli_check_error_code( e_val ); + + e_val = bli_check_object_buffer( beta ); + bli_check_error_code( e_val ); + + e_val = bli_check_object_buffer( c ); + bli_check_error_code( e_val ); + + // Check object dimensions. + + e_val = bli_check_level3_dims( a, b, c ); + bli_check_error_code( e_val ); + + e_val = bli_check_vector_dim_equals( d, bli_obj_width_after_trans( a ) ); + bli_check_error_code( e_val ); + + // Check for consistent datatypes. + // NOTE: We only perform these tests when mixed datatype support is + // disabled. + + e_val = bli_check_consistent_object_datatypes( c, a ); + bli_check_error_code( e_val ); + + e_val = bli_check_consistent_object_datatypes( c, d ); + bli_check_error_code( e_val ); + + e_val = bli_check_consistent_object_datatypes( c, b ); + bli_check_error_code( e_val ); +} + diff --git a/addon/gemmd/bao_gemmd_check.h b/addon/gemmd/bao_gemmd_check.h new file mode 100644 index 0000000000..243ec70c8c --- /dev/null +++ b/addon/gemmd/bao_gemmd_check.h @@ -0,0 +1,50 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + + +// +// Prototype object-based check functions. +// + +void bao_gemmd_check + ( + obj_t* alpha, + obj_t* a, + obj_t* d, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx + ); + diff --git a/addon/gemmd/bao_gemmd_var.h b/addon/gemmd/bao_gemmd_var.h new file mode 100644 index 0000000000..5c66747275 --- /dev/null +++ b/addon/gemmd/bao_gemmd_var.h @@ -0,0 +1,126 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. 
+ + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + + +// +// Prototype the object-based variant interfaces. +// + +#undef GENPROT +#define GENPROT( opname ) \ +\ +void PASTECH(bao_,opname) \ + ( \ + obj_t* alpha, \ + obj_t* a, \ + obj_t* d, \ + obj_t* b, \ + obj_t* beta, \ + obj_t* c, \ + cntx_t* cntx, \ + rntm_t* rntm, \ + thrinfo_t* thread \ + ); + +GENPROT( gemmd_bp_var1 ) +GENPROT( gemmd_bp_var2 ) + + +// +// Prototype the typed variant interfaces. +// + +#undef GENTPROT +#define GENTPROT( ctype, ch, varname ) \ +\ +void PASTECH2(bao_,ch,varname) \ + ( \ + conj_t conja, \ + conj_t conjb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + void* restrict alpha, \ + void* restrict a, inc_t rs_a, inc_t cs_a, \ + void* restrict d, inc_t incd, \ + void* restrict b, inc_t rs_b, inc_t cs_b, \ + void* restrict beta, \ + void* restrict c, inc_t rs_c, inc_t cs_c, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + thrinfo_t* restrict thread \ + ); + +//INSERT_GENTPROT_BASIC0( gemmd_bp_var1 ) +GENTPROT( float, s, gemmd_bp_var1 ) +GENTPROT( double, d, gemmd_bp_var1 ) +GENTPROT( scomplex, c, gemmd_bp_var1 ) +GENTPROT( dcomplex, z, gemmd_bp_var1 ) + +//INSERT_GENTPROT_BASIC0( gemmd_bp_var2 ) +GENTPROT( float, s, gemmd_bp_var2 ) +GENTPROT( double, d, gemmd_bp_var2 ) +GENTPROT( scomplex, c, gemmd_bp_var2 ) +GENTPROT( dcomplex, z, gemmd_bp_var2 ) + + +// +// Prototype the typed kernel interfaces. 
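+// (For reference: these bao_?gemm_kernel wrappers query the context for the
+// native gemm microkernel and invoke it directly on full MR x NR tiles;
+// partial tiles are computed into a temporary MR x NR buffer and then
+// accumulated into C via xpbys_mxn.)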
+// + +#undef GENTPROT +#define GENTPROT( ctype, ch, varname ) \ +\ +void PASTECH2(bao_,ch,varname) \ + ( \ + const dim_t MR, \ + const dim_t NR, \ + dim_t mr_cur, \ + dim_t nr_cur, \ + dim_t k, \ + ctype* restrict alpha, \ + ctype* restrict a, inc_t rs_a, inc_t cs_a, \ + ctype* restrict b, inc_t rs_b, inc_t cs_b, \ + ctype* restrict beta, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + auxinfo_t* restrict aux, \ + cntx_t* restrict cntx \ + ); + +//INSERT_GENTPROT_BASIC0( gemm_kernel ) +GENTPROT( float, s, gemm_kernel ) +GENTPROT( double, d, gemm_kernel ) +GENTPROT( scomplex, c, gemm_kernel ) +GENTPROT( dcomplex, z, gemm_kernel ) + diff --git a/addon/gemmd/bao_l3_packm_a.c b/addon/gemmd/bao_l3_packm_a.c new file mode 100644 index 0000000000..49bb34664c --- /dev/null +++ b/addon/gemmd/bao_l3_packm_a.c @@ -0,0 +1,330 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTECH2(bao_,ch,opname) \ + ( \ + dim_t m, \ + dim_t k, \ + dim_t mr, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + /* Set the pack buffer type so that we are obtaining memory blocks from + the pool dedicated to blocks of A. */ \ + const packbuf_t pack_buf_type = BLIS_BUFFER_FOR_A_BLOCK; \ +\ + /* NOTE: This "rounding up" of the last upanel is absolutely necessary since + we NEED that last micropanel to have the same ldim (cs_p) as the other + micropanels. Why? Because the microkernel assumes that the register (MR, + NR) AND storage (PACKMR, PACKNR) blocksizes do not change. */ \ + const dim_t m_pack = ( m / mr + ( m % mr ? 1 : 0 ) ) * mr; \ + const dim_t k_pack = k; \ +\ + /* Barrier to make sure all threads are caught up and ready to begin the + packm stage. */ \ + bli_thread_barrier( thread ); \ +\ + /* Compute the size of the memory block eneded. 
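+       For illustration: with m = 13, mr = 6 and k = 100, the rounding above
+       gives m_pack = ((13/6) + 1) * 6 = 18, so the request below is for
+       sizeof( ctype ) * 18 * 100 bytes; the last of the three micropanels
+       carries five rows of padding.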
*/ \ + siz_t size_needed = sizeof( ctype ) * m_pack * k_pack; \ +\ + /* Check the mem_t entry provided by the caller. If it is unallocated, + then we need to acquire a block from the packed block allocator. */ \ + if ( bli_mem_is_unalloc( mem ) ) \ + { \ + if ( bli_thread_am_ochief( thread ) ) \ + { \ + /* Acquire directly to the chief thread's mem_t that was passed in. + It needs to be that mem_t struct, and not a local (temporary) + mem_t, since there is no barrier until after packing is finished, + which could allow a race condition whereby the chief thread exits + the current function before the other threads have a chance to + copy from it. (A barrier would fix that race condition, but then + again, I prefer to keep barriers to a minimum.) */ \ + bli_pba_acquire_m \ + ( \ + rntm, \ + size_needed, \ + pack_buf_type, \ + mem \ + ); \ + } \ +\ + /* Broadcast the address of the chief thread's passed-in mem_t to all + threads. */ \ + mem_t* mem_p = bli_thread_broadcast( thread, mem ); \ +\ + /* Non-chief threads: Copy the contents of the chief thread's + passed-in mem_t to the passed-in mem_t for this thread. (The + chief thread already has the mem_t, so it does not need to + perform any copy.) */ \ + if ( !bli_thread_am_ochief( thread ) ) \ + { \ + *mem = *mem_p; \ + } \ + } \ + else /* if ( bli_mem_is_alloc( mem ) ) */ \ + { \ + /* If the mem_t entry provided by the caller does NOT contain a NULL + buffer, then a block has already been acquired from the packed + block allocator and cached by the caller. */ \ +\ + /* As a sanity check, we should make sure that the mem_t object isn't + associated with a block that is too small compared to the size of + the packed matrix buffer that is needed, according to the value + computed above. */ \ + siz_t mem_size = bli_mem_size( mem ); \ +\ + if ( mem_size < size_needed ) \ + { \ + if ( bli_thread_am_ochief( thread ) ) \ + { \ + /* The chief thread releases the existing block associated + with the mem_t, and then re-acquires a new block, saving + the associated mem_t to its passed-in mem_t. (See coment + above for why the acquisition needs to be directly to + the chief thread's passed-in mem_t and not a local + (temporary) mem_t. */ \ + bli_pba_release \ + ( \ + rntm, \ + mem \ + ); \ + bli_pba_acquire_m \ + ( \ + rntm, \ + size_needed, \ + pack_buf_type, \ + mem \ + ); \ + } \ +\ + /* Broadcast the address of the chief thread's passed-in mem_t + to all threads. */ \ + mem_t* mem_p = bli_thread_broadcast( thread, mem ); \ +\ + /* Non-chief threads: Copy the contents of the chief thread's + passed-in mem_t to the passed-in mem_t for this thread. (The + chief thread already has the mem_t, so it does not need to + perform any copy.) */ \ + if ( !bli_thread_am_ochief( thread ) ) \ + { \ + *mem = *mem_p; \ + } \ + } \ + else \ + { \ + /* If the mem_t entry is already allocated and sufficiently large, + then we use it as-is. No action is needed. */ \ + } \ + } \ +} + +//INSERT_GENTFUNC_BASIC0( packm_init_mem_a ) +GENTFUNC( float, s, packm_init_mem_a ) +GENTFUNC( double, d, packm_init_mem_a ) +GENTFUNC( scomplex, c, packm_init_mem_a ) +GENTFUNC( dcomplex, z, packm_init_mem_a ) + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTECH2(bao_,ch,opname) \ + ( \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + if ( thread != NULL ) \ + if ( bli_thread_am_ochief( thread ) ) \ + { \ + /* Check the mem_t entry provided by the caller. Only proceed if it + is allocated, which it should be. 
*/ \ + if ( bli_mem_is_alloc( mem ) ) \ + { \ + bli_pba_release \ + ( \ + rntm, \ + mem \ + ); \ + } \ + } \ +} + +//INSERT_GENTFUNC_BASIC0( packm_finalize_mem_a ) +GENTFUNC( float, s, packm_finalize_mem_a ) +GENTFUNC( double, d, packm_finalize_mem_a ) +GENTFUNC( scomplex, c, packm_finalize_mem_a ) +GENTFUNC( dcomplex, z, packm_finalize_mem_a ) + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTECH2(bao_,ch,opname) \ + ( \ + pack_t* restrict schema, \ + dim_t m, \ + dim_t k, \ + dim_t mr, \ + dim_t* restrict m_max, \ + dim_t* restrict k_max, \ + ctype** p, inc_t* restrict rs_p, inc_t* restrict cs_p, \ + dim_t* restrict pd_p, inc_t* restrict ps_p, \ + mem_t* restrict mem \ + ) \ +{ \ + /* NOTE: This "rounding up" of the last upanel is absolutely necessary since + we NEED that last micropanel to have the same ldim (cs_p) as the other + micropanels. Why? Because the microkernel assumes that the register (MR, + NR) AND storage (PACKMR, PACKNR) blocksizes do not change. */ \ + *m_max = ( m / mr + ( m % mr ? 1 : 0 ) ) * mr; \ + *k_max = k; \ +\ + /* Determine the dimensions and strides for the packed matrix A. */ \ + { \ + /* Pack A to column-stored row-panels. */ \ + *rs_p = 1; \ + *cs_p = mr; \ +\ + *pd_p = mr; \ + *ps_p = mr * k; \ +\ + /* Set the schema to "packed row panels" to indicate packing to + conventional column-stored row panels. */ \ + *schema = BLIS_PACKED_ROW_PANELS; \ + } \ +\ + /* Set the buffer address provided by the caller to point to the memory + associated with the mem_t entry acquired from the memory pool. */ \ + *p = bli_mem_buffer( mem ); \ +} + +//INSERT_GENTFUNC_BASIC0( packm_init_a ) +GENTFUNC( float, s, packm_init_a ) +GENTFUNC( double, d, packm_init_a ) +GENTFUNC( scomplex, c, packm_init_a ) +GENTFUNC( dcomplex, z, packm_init_a ) + + +// +// Define BLAS-like interfaces to the variant chooser. +// + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTECH2(bao_,ch,opname) \ + ( \ + conj_t conj, \ + dim_t m_alloc, \ + dim_t k_alloc, \ + dim_t m, \ + dim_t k, \ + dim_t mr, \ + ctype* restrict kappa, \ + ctype* restrict d, inc_t incd, \ + ctype* restrict a, inc_t rs_a, inc_t cs_a, \ + ctype** restrict p, inc_t* restrict rs_p, inc_t* restrict cs_p, \ + inc_t* restrict ps_p, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + pack_t schema; \ + dim_t m_max; \ + dim_t k_max; \ + dim_t pd_p; \ +\ + /* Prepare the packing destination buffer. */ \ + PASTECH2(bao_,ch,packm_init_mem_a) \ + ( \ + m_alloc, k_alloc, mr, \ + cntx, \ + rntm, \ + mem, \ + thread \ + ); \ +\ + /* Determine the packing buffer and related parameters for matrix A. */ \ + PASTECH2(bao_,ch,packm_init_a) \ + ( \ + &schema, \ + m, k, mr, \ + &m_max, &k_max, \ + p, rs_p, cs_p, \ + &pd_p, ps_p, \ + mem \ + ); \ +\ + /* Pack matrix A to the destination buffer chosen above. Here, the packed + matrix is stored to column-stored MR x k micropanels. */ \ + PASTECH2(bao_,ch,packm_var1) \ + ( \ + conj, \ + schema, \ + m, \ + k, \ + m_max, \ + k_max, \ + kappa, \ + d, incd, \ + a, rs_a, cs_a, \ + *p, *rs_p, *cs_p, \ + pd_p, *ps_p, \ + cntx, \ + thread \ + ); \ +\ + /* Barrier so that packing is done before computation. 
*/ \ + bli_thread_barrier( thread ); \ +} + +//INSERT_GENTFUNC_BASIC0( packm_a ) +GENTFUNC( float, s, packm_a ) +GENTFUNC( double, d, packm_a ) +GENTFUNC( scomplex, c, packm_a ) +GENTFUNC( dcomplex, z, packm_a ) + diff --git a/addon/gemmd/bao_l3_packm_a.h b/addon/gemmd/bao_l3_packm_a.h new file mode 100644 index 0000000000..b683b79d4a --- /dev/null +++ b/addon/gemmd/bao_l3_packm_a.h @@ -0,0 +1,123 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +*/ + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +void PASTECH2(bao_,ch,opname) \ + ( \ + dim_t m, \ + dim_t k, \ + dim_t mr, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ); \ + +//INSERT_GENTPROT_BASIC0( packm_init_mem_a ) +GENTPROT( float, s, packm_init_mem_a ) +GENTPROT( double, d, packm_init_mem_a ) +GENTPROT( scomplex, c, packm_init_mem_a ) +GENTPROT( dcomplex, z, packm_init_mem_a ) + + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +void PASTECH2(bao_,ch,opname) \ + ( \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ); \ + +//INSERT_GENTPROT_BASIC0( packm_finalize_mem_a ) +GENTPROT( float, s, packm_finalize_mem_a ) +GENTPROT( double, d, packm_finalize_mem_a ) +GENTPROT( scomplex, c, packm_finalize_mem_a ) +GENTPROT( dcomplex, z, packm_finalize_mem_a ) + + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +void PASTECH2(bao_,ch,opname) \ + ( \ + pack_t* restrict schema, \ + dim_t m, \ + dim_t k, \ + dim_t mr, \ + dim_t* restrict m_max, \ + dim_t* restrict k_max, \ + ctype** p, inc_t* restrict rs_p, inc_t* restrict cs_p, \ + dim_t* restrict pd_p, inc_t* restrict ps_p, \ + mem_t* restrict mem \ + ); \ + +//INSERT_GENTPROT_BASIC0( packm_init_a ) +GENTPROT( float, s, packm_init_a ) +GENTPROT( double, d, packm_init_a ) +GENTPROT( scomplex, c, packm_init_a ) +GENTPROT( dcomplex, z, packm_init_a ) + + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +void PASTECH2(bao_,ch,opname) \ + ( \ + conj_t conj, \ + dim_t m_alloc, \ + dim_t k_alloc, \ + dim_t m, \ + dim_t k, \ + dim_t mr, \ + ctype* restrict kappa, \ + ctype* restrict d, inc_t incd, \ + ctype* restrict a, inc_t rs_a, inc_t cs_a, \ + ctype** restrict p, inc_t* restrict rs_p, inc_t* restrict cs_p, \ + inc_t* restrict ps_p, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ); \ + +//INSERT_GENTPROT_BASIC0( packm_a ) +GENTPROT( float, s, packm_a ) +GENTPROT( double, d, packm_a ) +GENTPROT( scomplex, c, packm_a ) +GENTPROT( dcomplex, z, packm_a ) + diff --git a/addon/gemmd/bao_l3_packm_b.c b/addon/gemmd/bao_l3_packm_b.c new file mode 100644 index 0000000000..c41b062b6e --- /dev/null +++ b/addon/gemmd/bao_l3_packm_b.c @@ -0,0 +1,330 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTECH2(bao_,ch,opname) \ + ( \ + dim_t k, \ + dim_t n, \ + dim_t nr, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + /* Set the pack buffer type so that we are obtaining memory blocks from + the pool dedicated to panels of B. */ \ + const packbuf_t pack_buf_type = BLIS_BUFFER_FOR_B_PANEL; \ +\ + /* NOTE: This "rounding up" of the last upanel is absolutely necessary since + we NEED that last micropanel to have the same ldim (cs_p) as the other + micropanels. Why? Because the microkernel assumes that the register (MR, + NR) AND storage (PACKMR, PACKNR) blocksizes do not change. */ \ + const dim_t k_pack = k; \ + const dim_t n_pack = ( n / nr + ( n % nr ? 1 : 0 ) ) * nr; \ +\ + /* Barrier to make sure all threads are caught up and ready to begin the + packm stage. */ \ + bli_thread_barrier( thread ); \ +\ + /* Compute the size of the memory block eneded. */ \ + siz_t size_needed = sizeof( ctype ) * k_pack * n_pack; \ +\ + /* Check the mem_t entry provided by the caller. If it is unallocated, + then we need to acquire a block from the packed block allocator. */ \ + if ( bli_mem_is_unalloc( mem ) ) \ + { \ + if ( bli_thread_am_ochief( thread ) ) \ + { \ + /* Acquire directly to the chief thread's mem_t that was passed in. + It needs to be that mem_t struct, and not a local (temporary) + mem_t, since there is no barrier until after packing is finished, + which could allow a race condition whereby the chief thread exits + the current function before the other threads have a chance to + copy from it. (A barrier would fix that race condition, but then + again, I prefer to keep barriers to a minimum.) */ \ + bli_pba_acquire_m \ + ( \ + rntm, \ + size_needed, \ + pack_buf_type, \ + mem \ + ); \ + } \ +\ + /* Broadcast the address of the chief thread's passed-in mem_t to all + threads. */ \ + mem_t* mem_p = bli_thread_broadcast( thread, mem ); \ +\ + /* Non-chief threads: Copy the contents of the chief thread's + passed-in mem_t to the passed-in mem_t for this thread. (The + chief thread already has the mem_t, so it does not need to + perform any copy.) */ \ + if ( !bli_thread_am_ochief( thread ) ) \ + { \ + *mem = *mem_p; \ + } \ + } \ + else /* if ( bli_mem_is_alloc( mem ) ) */ \ + { \ + /* If the mem_t entry provided by the caller does NOT contain a NULL + buffer, then a block has already been acquired from the packed + block allocator and cached by the caller. */ \ +\ + /* As a sanity check, we should make sure that the mem_t object isn't + associated with a block that is too small compared to the size of + the packed matrix buffer that is needed, according to the value + computed above. 
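+       (This can arise, for instance, when a block cached from an earlier,
+       smaller problem is reused after k or n has grown.)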
*/ \ + siz_t mem_size = bli_mem_size( mem ); \ +\ + if ( mem_size < size_needed ) \ + { \ + if ( bli_thread_am_ochief( thread ) ) \ + { \ + /* The chief thread releases the existing block associated + with the mem_t, and then re-acquires a new block, saving + the associated mem_t to its passed-in mem_t. (See coment + above for why the acquisition needs to be directly to + the chief thread's passed-in mem_t and not a local + (temporary) mem_t. */ \ + bli_pba_release \ + ( \ + rntm, \ + mem \ + ); \ + bli_pba_acquire_m \ + ( \ + rntm, \ + size_needed, \ + pack_buf_type, \ + mem \ + ); \ + } \ +\ + /* Broadcast the address of the chief thread's passed-in mem_t + to all threads. */ \ + mem_t* mem_p = bli_thread_broadcast( thread, mem ); \ +\ + /* Non-chief threads: Copy the contents of the chief thread's + passed-in mem_t to the passed-in mem_t for this thread. (The + chief thread already has the mem_t, so it does not need to + perform any copy.) */ \ + if ( !bli_thread_am_ochief( thread ) ) \ + { \ + *mem = *mem_p; \ + } \ + } \ + else \ + { \ + /* If the mem_t entry is already allocated and sufficiently large, + then we use it as-is. No action is needed. */ \ + } \ + } \ +} + +//INSERT_GENTFUNC_BASIC0( packm_init_mem_b ) +GENTFUNC( float, s, packm_init_mem_b ) +GENTFUNC( double, d, packm_init_mem_b ) +GENTFUNC( scomplex, c, packm_init_mem_b ) +GENTFUNC( dcomplex, z, packm_init_mem_b ) + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTECH2(bao_,ch,opname) \ + ( \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + if ( thread != NULL ) \ + if ( bli_thread_am_ochief( thread ) ) \ + { \ + /* Check the mem_t entry provided by the caller. Only proceed if it + is allocated, which it should be. */ \ + if ( bli_mem_is_alloc( mem ) ) \ + { \ + bli_pba_release \ + ( \ + rntm, \ + mem \ + ); \ + } \ + } \ +} + +//INSERT_GENTFUNC_BASIC0( packm_finalize_mem_b ) +GENTFUNC( float, s, packm_finalize_mem_b ) +GENTFUNC( double, d, packm_finalize_mem_b ) +GENTFUNC( scomplex, c, packm_finalize_mem_b ) +GENTFUNC( dcomplex, z, packm_finalize_mem_b ) + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTECH2(bao_,ch,opname) \ + ( \ + pack_t* restrict schema, \ + dim_t k, \ + dim_t n, \ + dim_t nr, \ + dim_t* restrict k_max, \ + dim_t* restrict n_max, \ + ctype** p, inc_t* restrict rs_p, inc_t* restrict cs_p, \ + dim_t* restrict pd_p, inc_t* restrict ps_p, \ + mem_t* restrict mem \ + ) \ +{ \ + /* NOTE: This "rounding up" of the last upanel is absolutely necessary since + we NEED that last micropanel to have the same ldim (cs_p) as the other + micropanels. Why? Because the microkernel assumes that the register (MR, + NR) AND storage (PACKMR, PACKNR) blocksizes do not change. */ \ + *k_max = k; \ + *n_max = ( n / nr + ( n % nr ? 1 : 0 ) ) * nr; \ +\ + /* Determine the dimensions and strides for the packed matrix B. */ \ + { \ + /* Pack B to row-stored column-panels. */ \ + *rs_p = nr; \ + *cs_p = 1; \ +\ + *pd_p = nr; \ + *ps_p = k * nr; \ +\ + /* Set the schema to "packed column panels" to indicate packing to + conventional row-stored column panels. */ \ + *schema = BLIS_PACKED_COL_PANELS; \ + } \ +\ + /* Set the buffer address provided by the caller to point to the memory + associated with the mem_t entry acquired from the memory pool. 
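+       As a concrete reading of the strides chosen above: element (i,j) of a
+       packed micropanel lives at offset i*nr + j from the start of that panel
+       (rs_p = nr, cs_p = 1), and successive micropanels are ps_p = k*nr
+       elements apart; e.g. k = 100 and nr = 8 gives a panel stride of 800.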
*/ \ + *p = bli_mem_buffer( mem ); \ +} + +//INSERT_GENTFUNC_BASIC0( packm_init_b ) +GENTFUNC( float, s, packm_init_b ) +GENTFUNC( double, d, packm_init_b ) +GENTFUNC( scomplex, c, packm_init_b ) +GENTFUNC( dcomplex, z, packm_init_b ) + + +// +// Define BLAS-like interfaces to the variant chooser. +// + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTECH2(bao_,ch,opname) \ + ( \ + conj_t conj, \ + dim_t k_alloc, \ + dim_t n_alloc, \ + dim_t k, \ + dim_t n, \ + dim_t nr, \ + ctype* restrict kappa, \ + ctype* restrict d, inc_t incd, \ + ctype* restrict b, inc_t rs_b, inc_t cs_b, \ + ctype** restrict p, inc_t* restrict rs_p, inc_t* restrict cs_p, \ + inc_t* restrict ps_p, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + pack_t schema; \ + dim_t k_max; \ + dim_t n_max; \ + dim_t pd_p; \ +\ + /* Prepare the packing destination buffer. */ \ + PASTECH2(bao_,ch,packm_init_mem_b) \ + ( \ + k_alloc, n_alloc, nr, \ + cntx, \ + rntm, \ + mem, \ + thread \ + ); \ +\ + /* Determine the packing buffer and related parameters for matrix B. */ \ + PASTECH2(bao_,ch,packm_init_b) \ + ( \ + &schema, \ + k, n, nr, \ + &k_max, &n_max, \ + p, rs_p, cs_p, \ + &pd_p, ps_p, \ + mem \ + ); \ +\ + /* Pack matrix B to the destination buffer chosen above. Here, the packed + matrix is stored to row-stored k x NR micropanels. */ \ + PASTECH2(bao_,ch,packm_var1) \ + ( \ + conj, \ + schema, \ + k, \ + n, \ + k_max, \ + n_max, \ + kappa, \ + d, incd, \ + b, rs_b, cs_b, \ + *p, *rs_p, *cs_p, \ + pd_p, *ps_p, \ + cntx, \ + thread \ + ); \ +\ + /* Barrier so that packing is done before computation. */ \ + bli_thread_barrier( thread ); \ +} + +//INSERT_GENTFUNC_BASIC0( packm_b ) +GENTFUNC( float, s, packm_b ) +GENTFUNC( double, d, packm_b ) +GENTFUNC( scomplex, c, packm_b ) +GENTFUNC( dcomplex, z, packm_b ) + diff --git a/addon/gemmd/bao_l3_packm_b.h b/addon/gemmd/bao_l3_packm_b.h new file mode 100644 index 0000000000..9161604ce9 --- /dev/null +++ b/addon/gemmd/bao_l3_packm_b.h @@ -0,0 +1,123 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +void PASTECH2(bao_,ch,opname) \ + ( \ + dim_t k, \ + dim_t n, \ + dim_t nr, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ); \ + +//INSERT_GENTPROT_BASIC0( packm_init_mem_b ) +GENTPROT( float, s, packm_init_mem_b ) +GENTPROT( double, d, packm_init_mem_b ) +GENTPROT( scomplex, c, packm_init_mem_b ) +GENTPROT( dcomplex, z, packm_init_mem_b ) + + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +void PASTECH2(bao_,ch,opname) \ + ( \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ); \ + +//INSERT_GENTPROT_BASIC0( packm_finalize_mem_b ) +GENTPROT( float, s, packm_finalize_mem_b ) +GENTPROT( double, d, packm_finalize_mem_b ) +GENTPROT( scomplex, c, packm_finalize_mem_b ) +GENTPROT( dcomplex, z, packm_finalize_mem_b ) + + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +void PASTECH2(bao_,ch,opname) \ + ( \ + pack_t* restrict schema, \ + dim_t k, \ + dim_t n, \ + dim_t nr, \ + dim_t* restrict k_max, \ + dim_t* restrict n_max, \ + ctype** p, inc_t* restrict rs_p, inc_t* restrict cs_p, \ + dim_t* restrict pd_p, inc_t* restrict ps_p, \ + mem_t* restrict mem \ + ); \ + +//INSERT_GENTPROT_BASIC0( packm_init_b ) +GENTPROT( float, s, packm_init_b ) +GENTPROT( double, d, packm_init_b ) +GENTPROT( scomplex, c, packm_init_b ) +GENTPROT( dcomplex, z, packm_init_b ) + + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +void PASTECH2(bao_,ch,opname) \ + ( \ + conj_t conj, \ + dim_t k_alloc, \ + dim_t n_alloc, \ + dim_t k, \ + dim_t n, \ + dim_t nr, \ + ctype* restrict kappa, \ + ctype* restrict d, inc_t incd, \ + ctype* restrict b, inc_t rs_b, inc_t cs_b, \ + ctype** restrict p, inc_t* restrict rs_p, inc_t* restrict cs_p, \ + inc_t* restrict ps_p, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ); \ + +//INSERT_GENTPROT_BASIC0( packm_b ) +GENTPROT( float, s, packm_b ) +GENTPROT( double, d, packm_b ) +GENTPROT( scomplex, c, packm_b ) +GENTPROT( dcomplex, z, packm_b ) + diff --git a/addon/gemmd/bao_l3_packm_var.h b/addon/gemmd/bao_l3_packm_var.h new file mode 100644 index 0000000000..063e59e5f8 --- /dev/null +++ b/addon/gemmd/bao_l3_packm_var.h @@ -0,0 +1,69 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. 
+ - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +// +// Prototype BLAS-like interfaces to the variants. +// + +#undef GENTPROT +#define GENTPROT( ctype, ch, varname ) \ +\ +void PASTECH2(bao_,ch,varname) \ + ( \ + trans_t transc, \ + pack_t schema, \ + dim_t m, \ + dim_t n, \ + dim_t m_max, \ + dim_t n_max, \ + ctype* restrict kappa, \ + ctype* restrict d, inc_t incd, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + ctype* restrict p, inc_t rs_p, inc_t cs_p, \ + dim_t pd_p, inc_t ps_p, \ + cntx_t* restrict cntx, \ + thrinfo_t* restrict thread \ + ); + +//INSERT_GENTPROT_BASIC0( packm_var1 ) +GENTPROT( float, s, packm_var1 ) +GENTPROT( double, d, packm_var1 ) +GENTPROT( scomplex, c, packm_var1 ) +GENTPROT( dcomplex, z, packm_var1 ) + +//INSERT_GENTPROT_BASIC0( packm_var2 ) +GENTPROT( float, s, packm_var2 ) +GENTPROT( double, d, packm_var2 ) +GENTPROT( scomplex, c, packm_var2 ) +GENTPROT( dcomplex, z, packm_var2 ) diff --git a/addon/gemmd/bao_l3_packm_var1.c b/addon/gemmd/bao_l3_packm_var1.c new file mode 100644 index 0000000000..24c0a2cc13 --- /dev/null +++ b/addon/gemmd/bao_l3_packm_var1.c @@ -0,0 +1,195 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +// +// Variant 1 provides basic support for packing by calling packm_cxk(). +// + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTECH2(bao_,ch,varname) \ + ( \ + trans_t transc, \ + pack_t schema, \ + dim_t m, \ + dim_t n, \ + dim_t m_max, \ + dim_t n_max, \ + ctype* restrict kappa, \ + ctype* restrict d, inc_t incd, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + ctype* restrict p, inc_t rs_p, inc_t cs_p, \ + dim_t pd_p, inc_t ps_p, \ + cntx_t* restrict cntx, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + ctype* restrict kappa_cast = kappa; \ + ctype* restrict c_cast = c; \ + ctype* restrict p_cast = p; \ +\ + dim_t iter_dim; \ + dim_t n_iter; \ + dim_t it, ic; \ + dim_t ic0; \ + doff_t ic_inc; \ + dim_t panel_len; \ + dim_t panel_len_max; \ + dim_t panel_dim; \ + dim_t panel_dim_max; \ + inc_t incc; \ + inc_t ldc; \ + inc_t ldp; \ + conj_t conjc; \ +\ +\ + /* Extract the conjugation bit from the transposition argument. */ \ + conjc = bli_extract_conj( transc ); \ +\ + /* Create flags to incidate row or column storage. Note that the + schema bit that encodes row or column is describing the form of + micro-panel, not the storage in the micro-panel. Hence the + mismatch in "row" and "column" semantics. */ \ + bool row_stored = bli_is_col_packed( schema ); \ + /*bool col_stored = bli_is_row_packed( schema );*/ \ +\ + /* If the row storage flag indicates row storage, then we are packing + to column panels; otherwise, if the strides indicate column storage, + we are packing to row panels. */ \ + if ( row_stored ) \ + { \ + /* Prepare to pack to row-stored column panels. */ \ + iter_dim = n; \ + panel_len = m; \ + panel_len_max = m_max; \ + panel_dim_max = pd_p; \ + incc = cs_c; \ + ldc = rs_c; \ + ldp = rs_p; \ + } \ + else /* if ( col_stored ) */ \ + { \ + /* Prepare to pack to column-stored row panels. */ \ + iter_dim = m; \ + panel_len = n; \ + panel_len_max = n_max; \ + panel_dim_max = pd_p; \ + incc = rs_c; \ + ldc = cs_c; \ + ldp = cs_p; \ + } \ +\ + /* Compute the total number of iterations we'll need. */ \ + n_iter = iter_dim / panel_dim_max + ( iter_dim % panel_dim_max ? 1 : 0 ); \ +\ + /* Set the initial values and increments for indices related to C and P + based on whether reverse iteration was requested. */ \ + { \ + ic0 = 0; \ + ic_inc = panel_dim_max; \ + } \ +\ + ctype* restrict p_begin = p_cast; \ +\ + /* Query the number of threads and thread ids from the current thread's + packm thrinfo_t node. */ \ + const dim_t nt = bli_thread_n_way( thread ); \ + const dim_t tid = bli_thread_work_id( thread ); \ +\ + /* Suppress warnings in case tid isn't used (ie: as in slab partitioning). */ \ + ( void )nt; \ + ( void )tid; \ +\ + dim_t it_start, it_end, it_inc; \ +\ + /* Determine the thread range and increment using the current thread's + packm thrinfo_t node. NOTE: The definition of bli_thread_range_jrir() + will depend on whether slab or round-robin partitioning was requested + at configure-time. 
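+       For example, with n_iter = 8 micropanels and nt = 2 threads, slab
+       partitioning would typically assign iterations 0..3 to thread 0 and
+       4..7 to thread 1, while round-robin interleaves them (0,2,4,6 versus
+       1,3,5,7).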
*/ \ + bli_thread_range_jrir( thread, n_iter, 1, FALSE, &it_start, &it_end, &it_inc ); \ +\ + /* Iterate over every logical micropanel in the source matrix. */ \ + for ( ic = ic0, it = 0; it < n_iter; \ + ic += ic_inc, it += 1 ) \ + { \ + panel_dim = bli_min( panel_dim_max, iter_dim - ic ); \ +\ + ctype* restrict c_begin = c_cast + (ic )*incc; \ +\ + ctype* restrict c_use = c_begin; \ + ctype* restrict p_use = p_begin; \ +\ + /* The definition of bli_packm_my_iter() will depend on whether slab + or round-robin partitioning was requested at configure-time. (The + default is slab.) */ \ + if ( bli_packm_my_iter( it, it_start, it_end, tid, nt ) ) \ + { \ + PASTECH2(bao_,ch,packm_cxk) \ + ( \ + conjc, \ + schema, \ + panel_dim, \ + panel_dim_max, \ + panel_len, \ + panel_len_max, \ + kappa_cast, \ + d, incd, \ + c_use, incc, ldc, \ + p_use, ldp, \ + cntx \ + ); \ + } \ +\ +/* +if ( !row_stored ) \ +PASTEMAC(ch,fprintm)( stdout, "packm_var1: a packed", panel_dim_max, panel_len_max, \ + p_use, rs_p, cs_p, "%5.2f", "" ); \ +else \ +PASTEMAC(ch,fprintm)( stdout, "packm_var1: b packed", panel_len_max, panel_dim_max, \ + p_use, rs_p, cs_p, "%5.2f", "" ); \ +*/ \ +\ + p_begin += ps_p; \ + } \ +} + +//INSERT_GENTFUNC_BASIC0( packm_var1 ) +GENTFUNC( float, s, packm_var1 ) +GENTFUNC( double, d, packm_var1 ) +GENTFUNC( scomplex, c, packm_var1 ) +GENTFUNC( dcomplex, z, packm_var1 ) + diff --git a/addon/gemmd/bao_l3_packm_var2.c b/addon/gemmd/bao_l3_packm_var2.c new file mode 100644 index 0000000000..830e499b31 --- /dev/null +++ b/addon/gemmd/bao_l3_packm_var2.c @@ -0,0 +1,245 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +// +// Variant 2 is similar to variant 1, but inlines the contents of packm_cxk(). 
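+// Note: packm_cxk() applies the diagonal d along the panel length (via
+// scal2s/scal2js) when d is non-NULL, whereas the loops inlined below perform
+// a plain copy of the source micropanel (copys/copyjs).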
+// + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTECH2(bao_,ch,varname) \ + ( \ + trans_t transc, \ + pack_t schema, \ + dim_t m, \ + dim_t n, \ + dim_t m_max, \ + dim_t n_max, \ + ctype* restrict kappa, \ + ctype* restrict d, inc_t incd, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + ctype* restrict p, inc_t rs_p, inc_t cs_p, \ + dim_t pd_p, inc_t ps_p, \ + cntx_t* restrict cntx, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + ctype* restrict kappa_cast = kappa; \ + ctype* restrict c_cast = c; \ + ctype* restrict p_cast = p; \ +\ + dim_t iter_dim; \ + dim_t n_iter; \ + dim_t it, ic; \ + dim_t ic0; \ + doff_t ic_inc; \ + dim_t panel_len; \ + dim_t panel_len_max; \ + dim_t panel_dim; \ + dim_t panel_dim_max; \ + inc_t incc; \ + inc_t ldc; \ + inc_t ldp; \ + conj_t conjc; \ +\ +\ + /* Extract the conjugation bit from the transposition argument. */ \ + conjc = bli_extract_conj( transc ); \ +\ + /* Create flags to incidate row or column storage. Note that the + schema bit that encodes row or column is describing the form of + micro-panel, not the storage in the micro-panel. Hence the + mismatch in "row" and "column" semantics. */ \ + bool row_stored = bli_is_col_packed( schema ); \ + /*bool col_stored = bli_is_row_packed( schema );*/ \ +\ + /* If the row storage flag indicates row storage, then we are packing + to column panels; otherwise, if the strides indicate column storage, + we are packing to row panels. */ \ + if ( row_stored ) \ + { \ + /* Prepare to pack to row-stored column panels. */ \ + iter_dim = n; \ + panel_len = m; \ + panel_len_max = m_max; \ + panel_dim_max = pd_p; \ + incc = cs_c; \ + ldc = rs_c; \ + ldp = rs_p; \ + } \ + else /* if ( col_stored ) */ \ + { \ + /* Prepare to pack to column-stored row panels. */ \ + iter_dim = m; \ + panel_len = n; \ + panel_len_max = n_max; \ + panel_dim_max = pd_p; \ + incc = rs_c; \ + ldc = cs_c; \ + ldp = cs_p; \ + } \ +\ + /* Compute the total number of iterations we'll need. */ \ + n_iter = iter_dim / panel_dim_max + ( iter_dim % panel_dim_max ? 1 : 0 ); \ +\ + /* Set the initial values and increments for indices related to C and P + based on whether reverse iteration was requested. */ \ + { \ + ic0 = 0; \ + ic_inc = panel_dim_max; \ + } \ +\ + ctype* restrict p_begin = p_cast; \ +\ + /* Query the number of threads and thread ids from the current thread's + packm thrinfo_t node. */ \ + const dim_t nt = bli_thread_n_way( thread ); \ + const dim_t tid = bli_thread_work_id( thread ); \ +\ + /* Suppress warnings in case tid isn't used (ie: as in slab partitioning). */ \ + ( void )nt; \ + ( void )tid; \ +\ + dim_t it_start, it_end, it_inc; \ +\ + /* Determine the thread range and increment using the current thread's + packm thrinfo_t node. NOTE: The definition of bli_thread_range_jrir() + will depend on whether slab or round-robin partitioning was requested + at configure-time. */ \ + bli_thread_range_jrir( thread, n_iter, 1, FALSE, &it_start, &it_end, &it_inc ); \ +\ + /* Iterate over every logical micropanel in the source matrix. */ \ + for ( ic = ic0, it = 0; it < n_iter; \ + ic += ic_inc, it += 1 ) \ + { \ + panel_dim = bli_min( panel_dim_max, iter_dim - ic ); \ +\ + ctype* restrict c_begin = c_cast + (ic )*incc; \ +\ + ctype* restrict c_use = c_begin; \ + ctype* restrict p_use = p_begin; \ +\ + /* The definition of bli_packm_my_iter() will depend on whether slab + or round-robin partitioning was requested at configure-time. (The + default is slab.) 
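+       With slab partitioning this predicate essentially reduces to the range
+       test it_start <= it && it < it_end; with round-robin, to it % nt == tid.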
*/ \ + if ( bli_packm_my_iter( it, it_start, it_end, tid, nt ) ) \ + { \ + /* NOTE: We assume here that kappa = 1 and therefore ignore it. If + we're wrong, this will get someone's attention. */ \ + if ( !PASTEMAC(ch,eq1)( *kappa_cast ) ) \ + bli_abort(); \ +\ + /* Perform the packing, taking conjc into account. */ \ + if ( bli_is_conj( conjc ) ) \ + { \ + for ( dim_t l = 0; l < panel_len; ++l ) \ + { \ + for ( dim_t d = 0; d < panel_dim; ++d ) \ + { \ + ctype* cld = c_use + (l )*ldc + (d )*incc; \ + ctype* pld = p_use + (l )*ldp + (d )*1; \ +\ + PASTEMAC(ch,copyjs)( *cld, *pld ); \ + } \ + } \ + } \ + else \ + { \ + for ( dim_t l = 0; l < panel_len; ++l ) \ + { \ + for ( dim_t d = 0; d < panel_dim; ++d ) \ + { \ + ctype* cld = c_use + (l )*ldc + (d )*incc; \ + ctype* pld = p_use + (l )*ldp + (d )*1; \ +\ + PASTEMAC(ch,copys)( *cld, *pld ); \ + } \ + } \ + } \ +\ + /* If panel_dim < panel_dim_max, then we zero those unused rows. */ \ + if ( panel_dim < panel_dim_max ) \ + { \ + const dim_t i = panel_dim; \ + const dim_t m_edge = panel_dim_max - panel_dim; \ + const dim_t n_edge = panel_len_max; \ + ctype* restrict p_edge = p_use + (i )*1; \ +\ + PASTEMAC(ch,set0s_mxn) \ + ( \ + m_edge, \ + n_edge, \ + p_edge, 1, ldp \ + ); \ + } \ +\ + /* If panel_len < panel_len_max, then we zero those unused columns. */ \ + if ( panel_len < panel_len_max ) \ + { \ + const dim_t j = panel_len; \ + const dim_t m_edge = panel_dim_max; \ + const dim_t n_edge = panel_len_max - panel_len; \ + ctype* restrict p_edge = p_use + (j )*ldp; \ +\ + PASTEMAC(ch,set0s_mxn) \ + ( \ + m_edge, \ + n_edge, \ + p_edge, 1, ldp \ + ); \ + } \ + } \ +\ +/* +if ( !row_stored ) \ +PASTEMAC(ch,fprintm)( stdout, "packm_var1: a packed", panel_dim_max, panel_len_max, \ + p_use, rs_p, cs_p, "%5.2f", "" ); \ +else \ +PASTEMAC(ch,fprintm)( stdout, "packm_var1: b packed", panel_len_max, panel_dim_max, \ + p_use, rs_p, cs_p, "%5.2f", "" ); \ +*/ \ +\ + p_begin += ps_p; \ + } \ +} + +//INSERT_GENTFUNC_BASIC0( packm_var1 ) +GENTFUNC( float, s, packm_var2 ) +GENTFUNC( double, d, packm_var2 ) +GENTFUNC( scomplex, c, packm_var2 ) +GENTFUNC( dcomplex, z, packm_var2 ) + diff --git a/addon/gemmd/bao_packm_cxk.c b/addon/gemmd/bao_packm_cxk.c new file mode 100644 index 0000000000..645f09d798 --- /dev/null +++ b/addon/gemmd/bao_packm_cxk.c @@ -0,0 +1,199 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTECH2(bao_,ch,opname) \ + ( \ + conj_t conja, \ + pack_t schema, \ + dim_t panel_dim, \ + dim_t panel_dim_max, \ + dim_t panel_len, \ + dim_t panel_len_max, \ + ctype* kappa, \ + ctype* d, inc_t incd, \ + ctype* a, inc_t inca, inc_t lda, \ + ctype* p, inc_t ldp, \ + cntx_t* cntx \ + ) \ +{ \ + /* Note that we use panel_dim_max, not panel_dim, to query the packm + kernel function pointer. This means that we always use the same + kernel, even for edge cases. */ \ + num_t dt = PASTEMAC(ch,type); \ + l1mkr_t ker_id = panel_dim_max; \ +\ + PASTECH2(ch,opname,_ker_ft) f; \ +\ + /* Query the context for the packm kernel corresponding to the current + panel dimension, or kernel id. If the id is invalid, the function will + return NULL. */ \ + f = bli_cntx_get_packm_ker_dt( dt, ker_id, cntx ); \ +\ + /* If there exists a kernel implementation for the micro-panel dimension + provided, we invoke the implementation. Otherwise, we use scal2m. */ \ + /* NOTE: We've disabled calling packm micro-kernels from the context for + this implementation. To re-enable, change FALSE to TRUE in the + conditional below. */ \ + if ( f != NULL && FALSE ) \ + { \ + f \ + ( \ + conja, \ + schema, \ + panel_dim, \ + panel_len, \ + panel_len_max, \ + kappa, \ + a, inca, lda, \ + p, ldp, \ + cntx \ + ); \ + } \ + else \ + { \ + /* NOTE: We assume here that kappa = 1 and therefore ignore it. If + we're wrong, this will get someone's attention. */ \ + if ( !PASTEMAC(ch,eq1)( *kappa ) ) \ + bli_abort(); \ +\ + if ( d == NULL ) \ + { \ + /* Perform the packing, taking conja into account. */ \ + if ( bli_is_conj( conja ) ) \ + { \ + for ( dim_t l = 0; l < panel_len; ++l ) \ + { \ + for ( dim_t i = 0; i < panel_dim; ++i ) \ + { \ + ctype* ali = a + (l )*lda + (i )*inca; \ + ctype* pli = p + (l )*ldp + (i )*1; \ +\ + PASTEMAC(ch,copyjs)( *ali, *pli ); \ + } \ + } \ + } \ + else \ + { \ + for ( dim_t l = 0; l < panel_len; ++l ) \ + { \ + for ( dim_t i = 0; i < panel_dim; ++i ) \ + { \ + ctype* ali = a + (l )*lda + (i )*inca; \ + ctype* pli = p + (l )*ldp + (i )*1; \ +\ + PASTEMAC(ch,copys)( *ali, *pli ); \ + } \ + } \ + } \ + } \ + else /* if ( d != NULL ) */ \ + { \ + /* Perform the packing, taking conja into account. */ \ + if ( bli_is_conj( conja ) ) \ + { \ + for ( dim_t l = 0; l < panel_len; ++l ) \ + { \ + for ( dim_t i = 0; i < panel_dim; ++i ) \ + { \ + ctype* ali = a + (l )*lda + (i )*inca; \ + ctype* dl = d + (l )*incd; \ + ctype* pli = p + (l )*ldp + (i )*1; \ +\ + /* Note that ali must be the second operand here since + that is what is conjugated by scal2js. 
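+               That is, this branch computes pli = dl * conj(ali), mirroring
+               the pli = ali * dl assignment made by scal2s in the
+               non-conjugating branch below.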
*/ \ + PASTEMAC(ch,scal2js)( *dl, *ali, *pli ); \ + } \ + } \ + } \ + else \ + { \ + for ( dim_t l = 0; l < panel_len; ++l ) \ + { \ + for ( dim_t i = 0; i < panel_dim; ++i ) \ + { \ + ctype* ali = a + (l )*lda + (i )*inca; \ + ctype* dl = d + (l )*incd; \ + ctype* pli = p + (l )*ldp + (i )*1; \ +\ + PASTEMAC(ch,scal2s)( *ali, *dl, *pli ); \ + } \ + } \ + } \ + } \ +\ + /* If panel_dim < panel_dim_max, then we zero those unused rows. */ \ + if ( panel_dim < panel_dim_max ) \ + { \ + const dim_t i = panel_dim; \ + const dim_t m_edge = panel_dim_max - panel_dim; \ + const dim_t n_edge = panel_len_max; \ + ctype* restrict p_edge = p + (i )*1; \ +\ + PASTEMAC(ch,set0s_mxn) \ + ( \ + m_edge, \ + n_edge, \ + p_edge, 1, ldp \ + ); \ + } \ +\ + /* If panel_len < panel_len_max, then we zero those unused columns. */ \ + if ( panel_len < panel_len_max ) \ + { \ + const dim_t j = panel_len; \ + const dim_t m_edge = panel_dim_max; \ + const dim_t n_edge = panel_len_max - panel_len; \ + ctype* restrict p_edge = p + (j )*ldp; \ +\ + PASTEMAC(ch,set0s_mxn) \ + ( \ + m_edge, \ + n_edge, \ + p_edge, 1, ldp \ + ); \ + } \ + } \ +} + +//INSERT_GENTFUNC_BASIC0( packm_cxk ) +GENTFUNC( float, s, packm_cxk ) +GENTFUNC( double, d, packm_cxk ) +GENTFUNC( scomplex, c, packm_cxk ) +GENTFUNC( dcomplex, z, packm_cxk ) + diff --git a/addon/gemmd/bao_packm_cxk.h b/addon/gemmd/bao_packm_cxk.h new file mode 100644 index 0000000000..3e977a7cc2 --- /dev/null +++ b/addon/gemmd/bao_packm_cxk.h @@ -0,0 +1,59 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +*/ + + +#undef GENTPROT +#define GENTPROT( ctype, ch, varname ) \ +\ +void PASTECH2(bao_,ch,varname) \ + ( \ + conj_t conja, \ + pack_t schema, \ + dim_t panel_dim, \ + dim_t panel_dim_max, \ + dim_t panel_len, \ + dim_t panel_len_max, \ + ctype* kappa, \ + ctype* d, inc_t incd, \ + ctype* a, inc_t inca, inc_t lda, \ + ctype* p, inc_t ldp, \ + cntx_t* cntx \ + ); + +//INSERT_GENTPROT_BASIC0( packm_cxk ) +GENTPROT( float, s, packm_cxk ) +GENTPROT( double, d, packm_cxk ) +GENTPROT( scomplex, c, packm_cxk ) +GENTPROT( dcomplex, z, packm_cxk ) + diff --git a/addon/gemmd/gemmd.h b/addon/gemmd/gemmd.h new file mode 100644 index 0000000000..cab61bd181 --- /dev/null +++ b/addon/gemmd/gemmd.h @@ -0,0 +1,54 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name of copyright holder(s) nor the names + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef GEMMD_H +#define GEMMD_H + +// This header should contain (or #include) any definitions that must be +// folded into blis.h. + +#include "bao_gemmd.h" +#include "bao_gemmd_check.h" +#include "bao_gemmd_var.h" + +#include "bao_l3_packm_a.h" +#include "bao_l3_packm_b.h" +#include "bao_l3_packm_var.h" + +#include "bao_packm_cxk.h" + +#include "bao_l3_decor.h" + + +#endif diff --git a/addon/gemmd/thread/bao_l3_decor.h b/addon/gemmd/thread/bao_l3_decor.h new file mode 100644 index 0000000000..b4fd2b9b76 --- /dev/null +++ b/addon/gemmd/thread/bao_l3_decor.h @@ -0,0 +1,75 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. 
+ - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_SBX_L3_DECOR_H +#define BLIS_SBX_L3_DECOR_H + +// -- sup definitions ---------------------------------------------------------- + +// Level-3 sup internal function type. +typedef void (*l3sbxint_t) + ( + obj_t* alpha, + obj_t* a, + obj_t* d, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm, + thrinfo_t* thread + ); + +// Level-3 sup thread decorator prototype. +void bao_l3_thread_decorator + ( + l3sbxint_t func, + opid_t family, + obj_t* alpha, + obj_t* a, + obj_t* d, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm + ); + +// Include definitions specific to the method of multithreading. +#include "bao_l3_decor_single.h" +#include "bao_l3_decor_openmp.h" +#include "bao_l3_decor_pthreads.h" + +#endif + diff --git a/addon/gemmd/thread/bao_l3_decor_openmp.c b/addon/gemmd/thread/bao_l3_decor_openmp.c new file mode 100644 index 0000000000..1aca8de275 --- /dev/null +++ b/addon/gemmd/thread/bao_l3_decor_openmp.c @@ -0,0 +1,140 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#ifdef BLIS_ENABLE_OPENMP + +// Define a dummy thread entry function, which is needed in the pthreads +// version, so that when building Windows DLLs (with OpenMP enabled or with +// no multithreading) we don't risk having an unresolved symbol. +void* bao_l3_thread_entry( void* data_void ) { return NULL; } + +//#define PRINT_THRINFO + +void bao_l3_thread_decorator + ( + l3sbxint_t func, + opid_t family, + obj_t* alpha, + obj_t* a, + obj_t* d, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm + ) +{ + // Query the total number of threads from the rntm_t object. + const dim_t n_threads = bli_rntm_num_threads( rntm ); + + // NOTE: The sba was initialized in bli_init(). + + // Check out an array_t from the small block allocator. This is done + // with an internal lock to ensure only one application thread accesses + // the sba at a time. bli_sba_checkout_array() will also automatically + // resize the array_t, if necessary. + array_t* restrict array = bli_sba_checkout_array( n_threads ); + + // Access the pool_t* for thread 0 and embed it into the rntm. We do + // this up-front only so that we have the rntm_t.sba_pool field + // initialized and ready for the global communicator creation below. + bli_sba_rntm_set_pool( 0, array, rntm ); + + // Set the packing block allocator field of the rntm. This will be + // inherited by all of the child threads when they make local copies of + // the rntm below. + bli_pba_rntm_set_pba( rntm ); + + // Allcoate a global communicator for the root thrinfo_t structures. + thrcomm_t* restrict gl_comm = bli_thrcomm_create( rntm, n_threads ); + + + _Pragma( "omp parallel num_threads(n_threads)" ) + { + // Create a thread-local copy of the master thread's rntm_t. This is + // necessary since we want each thread to be able to track its own + // small block pool_t as it executes down the function stack. + rntm_t rntm_l = *rntm; + rntm_t* restrict rntm_p = &rntm_l; + + // Query the thread's id from OpenMP. + const dim_t tid = omp_get_thread_num(); + + // Check for a somewhat obscure OpenMP thread-mistmatch issue. + // NOTE: This calls the same function used for the conventional/large + // code path. + bli_l3_thread_decorator_thread_check( n_threads, tid, gl_comm, rntm_p ); + + // Use the thread id to access the appropriate pool_t* within the + // array_t, and use it to set the sba_pool field within the rntm_t. + // If the pool_t* element within the array_t is NULL, it will first + // be allocated/initialized. + bli_sba_rntm_set_pool( tid, array, rntm_p ); + + thrinfo_t* thread = NULL; + + // Create the root node of the thread's thrinfo_t structure. + bli_l3_sup_thrinfo_create_root( tid, gl_comm, rntm_p, &thread ); + + func + ( + alpha, + a, + d, + b, + beta, + c, + cntx, + rntm_p, + thread + ); + + // Free the current thread's thrinfo_t structure. 
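+		// (The thread-local rntm_t copy, rntm_l, lives on this thread's
+		// stack and needs no explicit cleanup; it disappears when the
+		// parallel region ends.)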
+ bli_l3_sup_thrinfo_free( rntm_p, thread ); + } + + // We shouldn't free the global communicator since it was already freed + // by the global communicator's chief thread in bli_l3_thrinfo_free() + // (called from the thread entry function). + + // Check the array_t back into the small block allocator. Similar to the + // check-out, this is done using a lock embedded within the sba to ensure + // mutual exclusion. + bli_sba_checkin_array( array ); +} + +#endif + diff --git a/addon/gemmd/thread/bao_l3_decor_openmp.h b/addon/gemmd/thread/bao_l3_decor_openmp.h new file mode 100644 index 0000000000..9c956d7c36 --- /dev/null +++ b/addon/gemmd/thread/bao_l3_decor_openmp.h @@ -0,0 +1,44 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_SBX_L3_DECOR_OPENMP_H +#define BLIS_SBX_L3_DECOR_OPENMP_H + +// Definitions specific to situations when OpenMP multithreading is enabled. +#ifdef BLIS_ENABLE_OPENMP + +#endif + +#endif + diff --git a/addon/gemmd/thread/bao_l3_decor_pthreads.c b/addon/gemmd/thread/bao_l3_decor_pthreads.c new file mode 100644 index 0000000000..587b8400f1 --- /dev/null +++ b/addon/gemmd/thread/bao_l3_decor_pthreads.c @@ -0,0 +1,220 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. 
+ - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#ifdef BLIS_ENABLE_PTHREADS + +// A data structure to assist in passing operands to additional threads. +typedef struct thread_data +{ + l3sbxint_t func; + opid_t family; + obj_t* alpha; + obj_t* a; + obj_t* d; + obj_t* b; + obj_t* beta; + obj_t* c; + cntx_t* cntx; + rntm_t* rntm; + dim_t tid; + thrcomm_t* gl_comm; + array_t* array; +} thread_data_t; + +// Entry point function for additional threads. +void* bao_l3_thread_entry( void* data_void ) +{ + thread_data_t* data = data_void; + + l3sbxint_t func = data->func; + opid_t family = data->family; + obj_t* alpha = data->alpha; + obj_t* a = data->a; + obj_t* d = data->d; + obj_t* b = data->b; + obj_t* beta = data->beta; + obj_t* c = data->c; + cntx_t* cntx = data->cntx; + rntm_t* rntm = data->rntm; + dim_t tid = data->tid; + array_t* array = data->array; + thrcomm_t* gl_comm = data->gl_comm; + + ( void )family; + + // Create a thread-local copy of the master thread's rntm_t. This is + // necessary since we want each thread to be able to track its own + // small block pool_t as it executes down the function stack. + rntm_t rntm_l = *rntm; + rntm_t* restrict rntm_p = &rntm_l; + + // Use the thread id to access the appropriate pool_t* within the + // array_t, and use it to set the sba_pool field within the rntm_t. + // If the pool_t* element within the array_t is NULL, it will first + // be allocated/initialized. + bli_sba_rntm_set_pool( tid, array, rntm_p ); + + thrinfo_t* thread = NULL; + + // Create the root node of the current thread's thrinfo_t structure. + bli_l3_sup_thrinfo_create_root( tid, gl_comm, rntm_p, &thread ); + + func + ( + alpha, + a, + d, + b, + beta, + c, + cntx, + rntm_p, + thread + ); + + // Free the current thread's thrinfo_t structure. + bli_l3_sup_thrinfo_free( rntm_p, thread ); + + return NULL; +} + +void bao_l3_thread_decorator + ( + l3sbxint_t func, + opid_t family, + obj_t* alpha, + obj_t* a, + obj_t* d, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm + ) +{ + err_t r_val; + + // Query the total number of threads from the context. + const dim_t n_threads = bli_rntm_num_threads( rntm ); + + // NOTE: The sba was initialized in bli_init(). + + // Check out an array_t from the small block allocator. This is done + // with an internal lock to ensure only one application thread accesses + // the sba at a time. bli_sba_checkout_array() will also automatically + // resize the array_t, if necessary. 
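+	// The same array_t is later handed to each spawned thread via its
+	// thread_data_t entry so that every thread can look up (and, if
+	// necessary, allocate) the pool_t associated with its thread id.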
+ array_t* restrict array = bli_sba_checkout_array( n_threads ); + + // Access the pool_t* for thread 0 and embed it into the rntm. We do + // this up-front only so that we have the rntm_t.sba_pool field + // initialized and ready for the global communicator creation below. + bli_sba_rntm_set_pool( 0, array, rntm ); + + // Set the packing block allocator field of the rntm. This will be + // inherited by all of the child threads when they make local copies of + // the rntm below. + bli_pba_rntm_set_pba( rntm ); + + // Allocate a global communicator for the root thrinfo_t structures. + thrcomm_t* restrict gl_comm = bli_thrcomm_create( rntm, n_threads ); + + // Allocate an array of pthread objects and auxiliary data structs to pass + // to the thread entry functions. + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_l3_thread_decorator().pth: " ); + #endif + bli_pthread_t* pthreads = bli_malloc_intl( sizeof( bli_pthread_t ) * n_threads, &r_val ); + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_l3_thread_decorator().pth: " ); + #endif + thread_data_t* datas = bli_malloc_intl( sizeof( thread_data_t ) * n_threads, &r_val ); + + // NOTE: We must iterate backwards so that the chief thread (thread id 0) + // can spawn all other threads before proceeding with its own computation. + for ( dim_t tid = n_threads - 1; 0 <= tid; tid-- ) + { + // Set up thread data for additional threads (beyond thread 0). + datas[tid].func = func; + datas[tid].family = family; + datas[tid].alpha = alpha; + datas[tid].a = a; + datas[tid].d = d; + datas[tid].b = b; + datas[tid].beta = beta; + datas[tid].c = c; + datas[tid].cntx = cntx; + datas[tid].rntm = rntm; + datas[tid].tid = tid; + datas[tid].gl_comm = gl_comm; + datas[tid].array = array; + + // Spawn additional threads for ids greater than 1. + if ( tid != 0 ) + bli_pthread_create( &pthreads[tid], NULL, &bao_l3_thread_entry, &datas[tid] ); + else + bao_l3_thread_entry( ( void* )(&datas[0]) ); + } + + // We shouldn't free the global communicator since it was already freed + // by the global communicator's chief thread in bli_l3_thrinfo_free() + // (called from the thread entry function). + + // Thread 0 waits for additional threads to finish. + for ( dim_t tid = 1; tid < n_threads; tid++ ) + { + bli_pthread_join( pthreads[tid], NULL ); + } + + // Check the array_t back into the small block allocator. Similar to the + // check-out, this is done using a lock embedded within the sba to ensure + // mutual exclusion. + bli_sba_checkin_array( array ); + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_l3_thread_decorator().pth: " ); + #endif + bli_free_intl( pthreads ); + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_l3_thread_decorator().pth: " ); + #endif + bli_free_intl( datas ); +} + +#endif + diff --git a/addon/gemmd/thread/bao_l3_decor_pthreads.h b/addon/gemmd/thread/bao_l3_decor_pthreads.h new file mode 100644 index 0000000000..69adec45ee --- /dev/null +++ b/addon/gemmd/thread/bao_l3_decor_pthreads.h @@ -0,0 +1,47 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. 
+ - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_SBX_L3_DECOR_PTHREADS_H +#define BLIS_SBX_L3_DECOR_PTHREADS_H + +// Definitions specific to situations when POSIX multithreading is enabled. +#ifdef BLIS_ENABLE_PTHREADS + +// Thread entry point prototype. +void* bao_l3_thread_entry( void* data_void ); + +#endif + +#endif + diff --git a/addon/gemmd/thread/bao_l3_decor_single.c b/addon/gemmd/thread/bao_l3_decor_single.c new file mode 100644 index 0000000000..d60891d65b --- /dev/null +++ b/addon/gemmd/thread/bao_l3_decor_single.c @@ -0,0 +1,143 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +*/ + +#include "blis.h" + +#ifndef BLIS_ENABLE_MULTITHREADING + +#define SKIP_THRINFO_TREE + +void bao_l3_thread_decorator + ( + l3sbxint_t func, + opid_t family, + //pack_t schema_a, + //pack_t schema_b, + obj_t* alpha, + obj_t* a, + obj_t* d, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm + ) +{ + // For sequential execution, we use only one thread. + const dim_t n_threads = 1; + + // NOTE: The sba was initialized in bli_init(). + + // Check out an array_t from the small block allocator. This is done + // with an internal lock to ensure only one application thread accesses + // the sba at a time. bli_sba_checkout_array() will also automatically + // resize the array_t, if necessary. + array_t* restrict array = bli_sba_checkout_array( n_threads ); + + // Access the pool_t* for thread 0 and embed it into the rntm. + bli_sba_rntm_set_pool( 0, array, rntm ); + + // Set the packing block allocator field of the rntm. + bli_pba_rntm_set_pba( rntm ); + +#ifndef SKIP_THRINFO_TREE + // Allcoate a global communicator for the root thrinfo_t structures. + thrcomm_t* restrict gl_comm = bli_thrcomm_create( rntm, n_threads ); +#endif + + + { + // NOTE: We don't need to create another copy of the rntm_t since + // it was already copied in one of the high-level oapi functions. + rntm_t* restrict rntm_p = rntm; + + // There is only one thread id (for the thief thread). + const dim_t tid = 0; + + // Use the thread id to access the appropriate pool_t* within the + // array_t, and use it to set the sba_pool field within the rntm_t. + // If the pool_t* element within the array_t is NULL, it will first + // be allocated/initialized. + // NOTE: This is commented out because, in the single-threaded case, + // this is redundant since it's already been done above. + //bli_sba_rntm_set_pool( tid, array, rntm_p ); + +#ifndef SKIP_THRINFO_TREE + thrinfo_t* thread = NULL; + + // Create the root node of the thread's thrinfo_t structure. + bli_l3_sup_thrinfo_create_root( tid, gl_comm, rntm_p, &thread ); +#else + // This optimization allows us to use one of the global thrinfo_t + // objects for single-threaded execution rather than grow one from + // scratch. The key is that bli_thrinfo_sup_grow(), which is called + // from within the variants, will immediately return if it detects + // that the thrinfo_t* passed into it is either + // &BLIS_GEMM_SINGLE_THREADED or &BLIS_PACKM_SINGLE_THREADED. + thrinfo_t* thread = &BLIS_GEMM_SINGLE_THREADED; + + ( void )tid; +#endif + + func + ( + alpha, + a, + d, + b, + beta, + c, + cntx, + rntm_p, + thread + ); + +#ifndef SKIP_THRINFO_TREE + // Free the current thread's thrinfo_t structure. + bli_l3_sup_thrinfo_free( rntm_p, thread ); +#endif + } + + // We shouldn't free the global communicator since it was already freed + // by the global communicator's chief thread in bli_l3_thrinfo_free() + // (called above). + + // Check the array_t back into the small block allocator. Similar to the + // check-out, this is done using a lock embedded within the sba to ensure + // mutual exclusion. + bli_sba_checkin_array( array ); +} + +#endif + diff --git a/addon/gemmd/thread/bao_l3_decor_single.h b/addon/gemmd/thread/bao_l3_decor_single.h new file mode 100644 index 0000000000..211a43a894 --- /dev/null +++ b/addon/gemmd/thread/bao_l3_decor_single.h @@ -0,0 +1,44 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. 
+ + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_SBX_L3_DECOR_SINGLE_H +#define BLIS_SBX_L3_DECOR_SINGLE_H + +// Definitions specific to situations when multithreading is disabled. +#ifndef BLIS_ENABLE_MULTITHREADING + +#endif + +#endif + diff --git a/build/bli_addon.h.in b/build/bli_addon.h.in new file mode 100644 index 0000000000..36a8e29bd1 --- /dev/null +++ b/build/bli_addon.h.in @@ -0,0 +1,47 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +*/ + +#ifndef BLIS_ADDON_H +#define BLIS_ADDON_H + +#if @enable_addons@ +#define BLIS_ENABLE_ADDONS +#else +#define BLIS_DISABLE_ADDONS +#endif + +// Enabled addons +@addon_list_includes@ + +#endif diff --git a/build/config.mk.in b/build/config.mk.in index a880074e8f..eddb69f705 100644 --- a/build/config.mk.in +++ b/build/config.mk.in @@ -183,6 +183,10 @@ MK_ENABLE_CBLAS := @enable_cblas@ # Whether libblis will depend on libmemkind for certain memory allocations. MK_ENABLE_MEMKIND := @enable_memkind@ +# The names of the addons to include when building BLIS. If empty, no addons +# will be included. +ADDON_LIST := @addon_list@ + # The name of a sandbox defining an alternative gemm implementation. If empty, # no sandbox will be used and the conventional gemm implementation will remain # enabled. diff --git a/common.mk b/common.mk index 00b1d4354e..02f34360af 100644 --- a/common.mk +++ b/common.mk @@ -161,18 +161,35 @@ get-kernel-cflags-for = $(strip $(call load-var-for,CKOPTFLAGS,$(1)) \ # When compiling sandboxes, we use flags similar to those of general framework # source. This ensures that the same code can be linked and run across various -# sub-configurations. (If we switch to using refkern/kernel flags, we should -# prevent enabling sandboxes for umbrella families by verifying that -# config_list == config_name if --enable-sandbox is given.) +# sub-configurations. +get-addon-c99flags-for = $(strip $(call load-var-for,COPTFLAGS,$(1)) \ + $(call get-noopt-cflags-for,$(1)) \ + $(CADDONINCFLAGS) \ + $(BUILD_CPPFLAGS) \ + $(BUILD_SYMFLAGS) \ + ) +get-addon-cxxflags-for = $(strip $(call load-var-for,COPTFLAGS,$(1)) \ + $(call get-noopt-cxxflags-for,$(1)) \ + $(CADDONINCFLAGS) \ + $(BUILD_CPPFLAGS) \ + $(BUILD_SYMFLAGS) \ + ) + +# When compiling sandboxes, we use flags similar to those of general framework +# source. This ensures that the same code can be linked and run across various +# sub-configurations. (NOTE: If we ever switch to using refkernel or kernel +# flags, we should prevent enabling sandboxes for umbrella families by verifying +# that config_list == config_name if --enable-sandbox is given. THIS ALSO +# APPLIES TO ADDONS ABOVE.) get-sandbox-c99flags-for = $(strip $(call load-var-for,COPTFLAGS,$(1)) \ $(call get-noopt-cflags-for,$(1)) \ - $(CSBOXINCFLAGS) \ + $(CSANDINCFLAGS) \ $(BUILD_CPPFLAGS) \ $(BUILD_SYMFLAGS) \ ) get-sandbox-cxxflags-for = $(strip $(call load-var-for,COPTFLAGS,$(1)) \ $(call get-noopt-cxxflags-for,$(1)) \ - $(CSBOXINCFLAGS) \ + $(CSANDINCFLAGS) \ $(BUILD_CPPFLAGS) \ $(BUILD_SYMFLAGS) \ ) @@ -198,6 +215,8 @@ get-config-text-for = "('$(1)' CFLAGS for config code)" get-frame-text-for = "('$(1)' CFLAGS for framework code)" get-aocldtl-text-for = "('$(1)' CFLAGS for AOCL debug and trace code)" get-kernel-text-for = "('$(1)' CFLAGS for kernels)" +get-addon-c99text-for = "('$(1)' CFLAGS for addons)" +get-addon-cxxtext-for = "('$(1)' CXXFLAGS for addons)" get-sandbox-c99text-for = "('$(1)' CFLAGS for sandboxes)" get-sandbox-cxxtext-for = "('$(1)' CXXFLAGS for sandboxes)" @@ -212,6 +231,10 @@ get-sandbox-cxxtext-for = "('$(1)' CXXFLAGS for sandboxes)" files-that-contain = $(strip $(foreach f, $(1), $(if $(findstring $(2),$(f)),$(f),))) files-that-dont-contain = $(strip $(foreach f, $(1), $(if $(findstring $(2),$(f)),,$(f)))) +# Define a function that removes duplicate strings *without* using the sort +# function. 
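+# (Unlike $(sort), which also alphabetizes its arguments, this keeps the first
+# occurrence of each word in its original position; for example,
+# $(call rm-dups,b a b c) expands to "b a c".)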
+rm-dups = $(if $1,$(firstword $1) $(call rm-dups,$(filter-out $(firstword $1),$1))) + # # --- Include makefile configuration file -------------------------------------- @@ -297,6 +320,7 @@ FRAME_DIR := frame AOCLDTL_DIR := aocl_dtl REFKERN_DIR := ref_kernels KERNELS_DIR := kernels +ADDON_DIR := addon SANDBOX_DIR := sandbox OBJ_DIR := obj LIB_DIR := lib @@ -313,12 +337,13 @@ REFNM := ref # Source suffixes. CONFIG_SRC_SUFS := c - KERNELS_SRC_SUFS := c s S - FRAME_SRC_SUFS := c AOCLDTL_SRC_SUFS := c +ADDON_C99_SUFS := c +ADDON_CXX_SUFS := cc cpp cxx +ADDON_SRC_SUFS := $(ADDON_C99_SUFS) $(ADDON_CXX_SUFS) SANDBOX_C99_SUFS := c SANDBOX_CXX_SUFS := cc cpp cxx @@ -328,6 +353,9 @@ SANDBOX_SRC_SUFS := $(SANDBOX_C99_SUFS) $(SANDBOX_CXX_SUFS) FRAME_HDR_SUFS := h AOCLDTL_HDR_SUFS := h +ADDON_H99_SUFS := h +ADDON_HXX_SUFS := hh hpp hxx +ADDON_HDR_SUFS := $(ADDON_H99_SUFS) $(ADDON_HXX_SUFS) SANDBOX_H99_SUFS := h SANDBOX_HXX_SUFS := hh hpp hxx @@ -335,10 +363,12 @@ SANDBOX_HDR_SUFS := $(SANDBOX_H99_SUFS) $(SANDBOX_HXX_SUFS) # Combine all header suffixes and remove duplicates via sort(). ALL_HDR_SUFS := $(sort $(FRAME_HDR_SUFS) \ + $(ADDON_HDR_SUFS) \ $(SANDBOX_HDR_SUFS) \ $(AOCLDTL_HDR_SUFS)) ALL_H99_SUFS := $(sort $(FRAME_HDR_SUFS) \ + $(ADDON_HDR_SUFS) \ $(SANDBOX_H99_SUFS) \ $(AOCLDTL_HDR_SUFS)) @@ -366,12 +396,14 @@ SHELL := bash # Construct paths to the four primary directories of source code: # the config directory, general framework code, reference kernel code, -# and optimized kernel code. +# and optimized kernel code. Also process paths for addon and sandbox +# directories. CONFIG_PATH := $(DIST_PATH)/$(CONFIG_DIR) FRAME_PATH := $(DIST_PATH)/$(FRAME_DIR) AOCLDTL_PATH := $(DIST_PATH)/$(AOCLDTL_DIR) REFKERN_PATH := $(DIST_PATH)/$(REFKERN_DIR) KERNELS_PATH := $(DIST_PATH)/$(KERNELS_DIR) +ADDON_PATH := $(DIST_PATH)/$(ADDON_DIR) SANDBOX_PATH := $(DIST_PATH)/$(SANDBOX_DIR) # Construct paths to some optional C++ template headers contributed by AMD. @@ -386,6 +418,7 @@ FRAME_FRAG_PATH := ./obj/$(CONFIG_NAME)/$(FRAME_DIR) AOCLDTL_FRAG_PATH := ./obj/$(CONFIG_NAME)/$(AOCLDTL_DIR) REFKERN_FRAG_PATH := ./obj/$(CONFIG_NAME)/$(REFKERN_DIR) KERNELS_FRAG_PATH := ./obj/$(CONFIG_NAME)/$(KERNELS_DIR) +ADDON_FRAG_PATH := ./obj/$(CONFIG_NAME)/$(ADDON_DIR) SANDBOX_FRAG_PATH := ./obj/$(CONFIG_NAME)/$(SANDBOX_DIR) @@ -863,6 +896,7 @@ MK_KERNELS_SRC := MK_REFKERN_SRC := MK_FRAME_SRC := MK_AOCLDTL_SRC := +MK_ADDON_SRC := MK_SANDBOX_SRC := # -- config -- @@ -914,6 +948,24 @@ PARENT_PATH := $(OBJ_DIR)/$(CONFIG_NAME) -include $(addsuffix /$(FRAGMENT_MK), $(FRAME_FRAG_PATH)) -include $(addsuffix /$(FRAGMENT_MK), $(AOCLDTL_FRAG_PATH)) +# -- addon -- + +# Construct paths to each addon. +# NOTE: If $(ADDON_LIST) is empty (because no addon was enabled at configure- +# time) then $(ADDON_PATHS) will also be empty, which will cause no fragments +# to be included. +ADDON_PATHS := $(addprefix $(ADDON_FRAG_PATH)/, $(ADDON_LIST)) + +# This variable is used by the include statements as they recursively include +# one another. For the 'addons' directory, we initialize it to that directory +# in preparation to include the fragments in the configuration sub-directory. +PARENT_SRC_PATH := $(ADDON_PATH) +PARENT_PATH := $(ADDON_FRAG_PATH) + +# Recursively include the makefile fragments in each of the addons sub- +# directories. +-include $(addsuffix /$(FRAGMENT_MK), $(ADDON_PATHS)) + # -- sandbox -- # Construct paths to each sandbox. (At present, there can be only one.) 
@@ -931,6 +983,8 @@ PARENT_PATH := $(SANDBOX_FRAG_PATH) # Recursively include the makefile fragments in the sandbox sub-directory. -include $(addsuffix /$(FRAGMENT_MK), $(SANDBOX_PATHS)) +# -- post-processing -- + # Create a list of the makefile fragments using the variable into which each # of the above include statements accumulated their directory paths. MAKEFILE_FRAGMENTS := $(addsuffix /$(FRAGMENT_MK), $(FRAGMENT_DIR_PATHS)) @@ -949,14 +1003,14 @@ endif # # Define a function that will expand all of the directory paths given in $(1) -# to actual filepaths using the list of suffixes provided $(2). +# to actual filepaths using the list of suffixes provided in $(2). get-filepaths = $(strip $(foreach path, $(1), \ $(foreach suf, $(2), \ $(wildcard $(path)/*.$(suf)) \ ) ) ) # Define a function that will expand all of the directory paths given in $(1) -# to actual filepaths using the list of suffixes provided $(2), taking only +# to actual filepaths using the list of suffixes provided in $(2), taking only # the first expansion from each directory with at least one file matching # the current suffix. Finally, strip the filenames from all resulting files, # returning only the directory paths. @@ -966,20 +1020,29 @@ get-dirpaths = $(dir $(foreach path, $(1), \ $(wildcard $(path)/*.$(suf)) \ ) ) ) ) -# We'll use two directory lists. The first is a list of all of the directories -# in which makefile fragments were generated (plus the current directory). The -# second is the subset of the first that begins with the sandbox root path. +# We'll use three directory lists. The first is a list of all of the directories +# in which makefile fragments were generated, plus the current directory. (The +# current directory is needed so we include bli_config.h and bli_addon.h in the +# processing of header files.) The second and third are subsets of the first +# that begins with the addon and sandbox root paths, respectively. ALLFRAG_DIR_PATHS := . $(FRAGMENT_DIR_PATHS) +ADDON_DIR_PATHS := $(filter $(ADDON_PATH)/%,$(ALLFRAG_DIR_PATHS)) SANDBOX_DIR_PATHS := $(filter $(SANDBOX_PATH)/%,$(ALLFRAG_DIR_PATHS)) ALL_H99_FILES := $(call get-filepaths,$(ALLFRAG_DIR_PATHS),$(ALL_H99_SUFS)) -FRAME_H99_FILES := $(filter-out $(SANDBOX_PATH)/%,$(ALL_H99_FILES)) +FRAME_H99_FILES := $(filter-out $(ADDON_PATH)/%, \ + $(filter-out $(SANDBOX_PATH)/%, \ + $(ALL_H99_FILES) \ + ) ) -ALL_H99_DIRPATHS := $(call get-dirpaths,$(ALLFRAG_DIR_PATHS),$(ALL_H99_SUFS)) +ALL_H99_DIRPATHS := $(call get-dirpaths,$(ALLFRAG_DIR_PATHS),$(ALL_H99_SUFS)) -SANDBOX_H99_FILES := $(call get-filepaths,$(SANDBOX_DIR_PATHS),$(SANDBOX_H99_SUFS)) -SANDBOX_HXX_FILES := $(call get-filepaths,$(SANDBOX_DIR_PATHS),$(SANDBOX_HXX_SUFS)) +ADDON_H99_FILES := $(call get-filepaths,$(ADDON_DIR_PATHS),$(ADDON_H99_SUFS)) +ADDON_HXX_FILES := $(call get-filepaths,$(ADDON_DIR_PATHS),$(ADDON_HXX_SUFS)) +ADDON_HDR_DIRPATHS := $(call get-dirpaths,$(ADDON_DIR_PATHS),$(ALL_HDR_SUFS)) +SANDBOX_H99_FILES := $(call get-filepaths,$(SANDBOX_DIR_PATHS),$(SANDBOX_H99_SUFS)) +SANDBOX_HXX_FILES := $(call get-filepaths,$(SANDBOX_DIR_PATHS),$(SANDBOX_HXX_SUFS)) SANDBOX_HDR_DIRPATHS := $(call get-dirpaths,$(SANDBOX_DIR_PATHS),$(ALL_HDR_SUFS)) @@ -1032,8 +1095,8 @@ CBLAS_H_FLAT := $(BASE_INC_PATH)/$(CBLAS_H) # # Obtain a list of header files #included inside of the bli_cntx_ref.c file. -# Paths to these files will be needed when compiling with the monolithic -# header. 
+# Due to the way that bli_cntx_ref.c uses headers and macros, paths to these +# files will be needed when compiling bli_cntx_ref.c with the monolithic header. ifeq ($(strip $(SHARE_PATH)),.) REF_KER_SRC := $(DIST_PATH)/$(REFKERN_DIR)/bli_cntx_ref.c REF_KER_HEADERS := $(shell $(GREP) "\#include" $(REF_KER_SRC) | sed -e "s/\#include [\"<]\([a-zA-Z0-9\_\.\/\-]*\)[\">].*/\1/g" | $(GREP) -v $(BLIS_H)) @@ -1041,9 +1104,10 @@ endif # Match each header found above with the path to that header, and then strip # leading, trailing, and internal whitespace. -REF_KER_H_PATHS := $(strip $(foreach header, $(REF_KER_HEADERS), \ - $(dir $(filter %/$(header), \ - $(FRAME_H99_FILES))))) +REF_KER_H_PATHS := $(call rm-dups,$(strip \ + $(foreach header, $(REF_KER_HEADERS), \ + $(dir $(filter %/$(header), \ + $(FRAME_H99_FILES)))))) # Add -I to each header path so we can specify our include search paths to the # C compiler. Then add frame/include since it's needed for bli_oapi_w[o]_cntx.h. @@ -1055,17 +1119,29 @@ REF_KER_I_PATHS += -I$(DIST_PATH)/frame/include # now #include the monolithic/flattened blis.h instead. CINCFLAGS := -I$(BASE_INC_PATH) $(REF_KER_I_PATHS) +# If CBLAS is enabled, we also include the path to the cblas.h directory so +# that the compiler will be able to find cblas.h as the CBLAS source code is +# being compiled. +ifeq ($(MK_ENABLE_CBLAS),yes) +CINCFLAGS += -I$(CBLAS_H_DIRPATH) +endif + +# Obtain a list of header paths in the configured addons. Then add -I to each +# header path. +CADDONINCFLAGS := $(strip $(patsubst %, -I%, $(ADDON_HDR_DIRPATHS))) + # Obtain a list of header paths in the configured sandbox. Then add -I to each # header path. -CSBOXINCFLAGS := $(strip $(patsubst %, -I%, $(SANDBOX_HDR_DIRPATHS))) +CSANDINCFLAGS := $(strip $(patsubst %, -I%, $(SANDBOX_HDR_DIRPATHS))) # # --- BLIS configuration header definitions ------------------------------------ # -# This file was created by configure, but we need to define it here so we can -# remove it as part of the clean targets. +# These files were created by configure, but we need to define them here so we +# can remove them as part of the clean targets. +BLIS_ADDON_H := ./bli_addon.h BLIS_CONFIG_H := ./bli_config.h diff --git a/configure b/configure index f49ea19e5e..b35348abdc 100755 --- a/configure +++ b/configure @@ -264,6 +264,15 @@ print_usage() echo " \"small\" depends on thresholds that may vary by sub-" echo " configuration." echo " " + echo " -a NAME --enable-addon=NAME" + echo " " + echo " Enable the code provided by an addon. An addon consists" + echo " of a separate directory of code that provides additional" + echo " APIs, implementations, and/or operations that would" + echo " otherwise not be present within a build of BLIS. This" + echo " option may be used multiple times to specify the inclusion" + echo " of multiple addons. By default, no addons are enabled." + echo " " echo " -s NAME --enable-sandbox=NAME" echo " " echo " Enable a separate sandbox implementation of gemm. This" @@ -940,6 +949,18 @@ canonicalize_ws() echo "${str}" } +rm_duplicate_words_simple() +{ + local str revstr revres res + + str="$1" + + # Remote duplicates, keeping the first occurrence. 
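+	# (awk prints each word only the first time it is seen, so, for example,
+	# "thing1 foobar thing1" becomes "thing1 foobar"; word order is preserved.)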
+ res=$(echo "${str}" | awk '{for (i=1;i<=NF;i++) if (!a[$i]++) printf("%s%s",$i,FS)}{printf("\n")}') + + echo "${res}" +} + rm_duplicate_words() { local str revstr revres res @@ -1915,6 +1936,13 @@ main() bli_config_h_in_path="${build_dirpath}/${bli_config_h_in}" bli_config_h_out_path="${cur_dirpath}/${bli_config_h_out}" + # The names/paths for the template bli_addon.h.in and its instantiated + # counterpart. + bli_addon_h_in='bli_addon.h.in' + bli_addon_h_out='bli_addon.h' + bli_addon_h_in_path="${build_dirpath}/${bli_addon_h_in}" + bli_addon_h_out_path="${cur_dirpath}/${bli_addon_h_out}" + # Path to 'mirror-tree.sh' script. mirror_tree_sh="${build_dirpath}/mirror-tree.sh" @@ -1941,6 +1969,9 @@ main() # The root directory of the BLIS framework. aocldtl_dir='aocl_dtl' aocldtl_dirpath="${dist_path}/${aocldtl_dir}" + # The names of the addons. + addon_dir='addon' + addon_dirpath="${dist_path}/${addon_dir}" # The name of the sandbox directory. sandbox_dir='sandbox' @@ -2049,6 +2080,10 @@ main() force_version='no' complex_return='default' + # The addon flag and names. + addon_flag='' + addon_list='' + # The sandbox flag and name. sandbox_flag='' sandbox='' @@ -2093,7 +2128,7 @@ main() # Process our command line options. unset OPTIND - while getopts ":hp:d:e:s:t:r:qci:b:-:" opt; do + while getopts ":hp:d:e:a:s:t:r:qci:b:-:" opt; do case $opt in -) case "$OPTARG" in @@ -2194,12 +2229,21 @@ main() disable-mem-tracing) enable_mem_tracing='no' ;; + enable-addon=*) + addon_flag=1 + addon_name=${OPTARG#*=} + # Append the addon name to the list. + addon_list="${addon_list} ${addon_name}" + ;; + disable-addon) + addon_flag='' + ;; enable-sandbox=*) sandbox_flag=1 sandbox=${OPTARG#*=} ;; disable-sandbox) - sandbox_flag=0 + sandbox_flag='' ;; int-size=*) int_type_size=${OPTARG#*=} @@ -2282,6 +2326,12 @@ main() e) export_shared=$OPTARG ;; + a) + addon_flag=1 + addon_name=$OPTARG + # Append the addon name to the list. + addon_list="${addon_list} ${addon_name}" + ;; s) sandbox_flag=1 sandbox=$OPTARG @@ -3141,6 +3191,34 @@ main() exit 1 fi + # Check if addons were given. + if [ -n "${addon_flag}" ]; then + + # Remove duplicates in the addon list, if they exist. + addon_list=$(rm_duplicate_words_simple "${addon_list}") + + echo "${script_name}: configuring with addons:" + + for addon in ${addon_list}; do + + echo "${script_name}: ${addon_dir}/${addon}" + + addon_fullpath="${addon_dirpath}/${addon}" + + if [ ! -d "${addon_fullpath}" ]; then + echo "${script_name}: requested addon sub-directory does not exist! Cannot continue." + echo "${script_name}: *** Please verify addon existence and name." + exit 1 + fi + done + + enable_addons_01=1 + else + echo "${script_name}: configuring with no addons." + + enable_addons_01=0 + fi + # Check if a sandbox was given. if [ -n "${sandbox_flag}" ]; then @@ -3292,6 +3370,15 @@ main() kernel_list_defines="${kernel_list_defines}#define ${kernel_define}\n" done + # Create a list of #includes, one for each addon in addon_list. + addon_list_includes="" + for addon in ${addon_list}; do + + # Create a #define and add it to the running list. 
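+		# (For an addon named gemmd, the line appended here is
+		#   #include "gemmd.h"
+		# which configure later substitutes into bli_addon.h in place of the
+		# @addon_list_includes@ placeholder.)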
+ addon_header="\"${addon}.h\"" + addon_list_includes="${addon_list_includes}#include ${addon_header}\n" + done + # -- Determine whether we are performing an out-of-tree build -------------- @@ -3319,7 +3406,7 @@ main() fi - # -- Instantiate config.mk, bli_config.h files from templates -------------- + # -- Instantiate config.mk file from template ------------------------------ # Begin substituting information into the config_mk_in file, outputting # to config_mk_out. @@ -3365,6 +3452,7 @@ main() | sed -e "s/@enable_cblas@/${enable_cblas}/g" \ | sed -e "s/@enable_memkind@/${enable_memkind}/g" \ | sed -e "s/@pragma_omp_simd@/${pragma_omp_simd}/g" \ + | sed -e "s/@addon_list@/${addon_list}/g" \ | sed -e "s/@sandbox@/${sandbox}/g" \ | sed -e "s/@enable_trsm_preinversion@/${enable_trsm_preinversion}/g" \ | sed -e "s/@enable_aocl_dynamic@/${enable_aocl_dynamic}/g" \ @@ -3373,6 +3461,7 @@ main() | sed -e "s/\@enable_aocl_zen\@/${enable_aocl_zen}/g" \ > "${config_mk_out_path}" + # -- Instantiate bli_config.h file from template --------------------------- # Begin substituting information into the bli_config_h_in file, outputting # to bli_config_h_out. NOTE: We use perl instead of sed because the version @@ -3409,6 +3498,17 @@ main() | sed -e "s/@complex_return_intel@/${complex_return_intel01}/g" \ > "${bli_config_h_out_path}" + # -- Instantiate bli_addon.h file from template ---------------------------- + + # Begin substituting information into the bli_addon_h_in file, outputting + # to bli_addon_h_out. NOTE: We use perl instead of sed because the version + # of sed used on OS X is old and does not handle the '\n' character + # intuitively, which was used when constructing ${addon_list_includes}. + echo "${script_name}: creating ${bli_addon_h_out_path} from ${bli_addon_h_in_path}" + cat "${bli_addon_h_in_path}" \ + | perl -pe "s/\@addon_list_includes\@/${addon_list_includes}/g" \ + | sed -e "s/@enable_addons@/${enable_addons_01}/g" \ + > "${bli_addon_h_out_path}" # -- Create top-level object directories ----------------------------------- @@ -3421,7 +3521,6 @@ main() obj_config_dirpath="${base_obj_dirpath}/${config_dir}" - #echo "${script_name}: creating ${obj_config_dirpath}" mkdir -p ${obj_config_dirpath} for conf in ${config_list}; do echo "${script_name}: creating ${obj_config_dirpath}/${conf}" @@ -3431,7 +3530,6 @@ main() obj_kernels_dirpath="${base_obj_dirpath}/${kernels_dir}" - #echo "${script_name}: creating ${obj_kernels_dirpath}" mkdir -p ${obj_kernels_dirpath} for kern in ${kernel_list}; do echo "${script_name}: creating ${obj_kernels_dirpath}/${kern}" @@ -3441,7 +3539,6 @@ main() obj_refkern_dirpath="${base_obj_dirpath}/${refkern_dir}" - #echo "${script_name}: creating ${obj_refkern_dirpath}" mkdir -p ${obj_refkern_dirpath} for conf in ${config_list}; do echo "${script_name}: creating ${obj_refkern_dirpath}/${conf}" @@ -3460,6 +3557,18 @@ main() echo "${script_name}: creating ${obj_frame_dirpath}" mkdir -p ${obj_frame_dirpath} + + if [ -n "${addon_flag}" ]; then + + obj_addon_dirpath="${base_obj_dirpath}/${addon_dir}" + + for addon in ${addon_list}; do + echo "${script_name}: creating ${obj_addon_dirpath}/${addon}" + mkdir -p ${obj_addon_dirpath}/${addon} + done + fi + + if [ -n "${sandbox_flag}" ]; then obj_sandbox_dirpath="${base_obj_dirpath}/${sandbox_dir}" @@ -3487,6 +3596,7 @@ main() echo "${script_name}: creating ${base_lib_dirpath}" mkdir -p ${base_lib_dirpath} + # Create include directory (if it does not already exist). 
base_include_dirpath="${include_dirpath}/${config_name}" @@ -3545,6 +3655,16 @@ main() echo "${script_name}: mirroring ${aocldtl_dirpath} to ${obj_aocldtl_dirpath}" ${mirror_tree_sh} ${aocldtl_dirpath} ${obj_aocldtl_dirpath} + # Mirror the chosen addon source tree to its object sub-directory. + if [ -n "${addon_flag}" ]; then + + for addon in ${addon_list}; do + + echo "${script_name}: mirroring ${addon_dirpath}/${addon} to ${obj_addon_dirpath}/${addon}" + ${mirror_tree_sh} "${addon_dirpath}/${addon}" "${obj_addon_dirpath}/${addon}" + done + fi + # Mirror the chosen sandbox source tree to its object sub-directory. if [ -n "${sandbox_flag}" ]; then @@ -3643,6 +3763,25 @@ main() ${gen_make_frags_dirpath}/suffix_list \ ${gen_make_frags_dirpath}/ignore_list + # Generate makefile fragments in the addon sub-directory. + if [ -n "${addon_flag}" ]; then + + for addon in ${addon_list}; do + + echo "${script_name}: creating makefile fragments in ${obj_addon_dirpath}/${addon}" + ${gen_make_frags_sh} \ + -h -r -v0 \ + -o ${script_name} \ + -p 'ADDON' \ + ${addon_dirpath}/${addon} \ + ${obj_addon_dirpath}/${addon} \ + ${gen_make_frags_dirpath}/fragment.mk \ + ${gen_make_frags_dirpath}/suffix_list \ + ${gen_make_frags_dirpath}/ignore_list + done + fi + + # Generate makefile fragments in the sandbox sub-directory. if [ -n "${sandbox_flag}" ]; then diff --git a/docs/Addons.md b/docs/Addons.md new file mode 100644 index 0000000000..595cebfa4b --- /dev/null +++ b/docs/Addons.md @@ -0,0 +1,231 @@ +## Contents + +* **[Introduction](Addons.md#introduction)** +* **[Enabling addons](Addons.md#enabling-addons)** +* **[Addon rules](Addons.md#addon-rules)** +* **[Caveats](Addons.md#caveats)** +* **[Known issues](Addons.md#known-issues)** +* **[Conclusion](Addons.md#conclusion)** + + +## Introduction + +This file briefly describes the requirements for building a custom BLIS +*addon*. + +Simply put, an addon in BLIS provides additional APIs, operations, and/or +implementations that may be useful to certain users. An addon can be +thought of as a standalone extension of BLIS that does not depend on any +other addon, although addons may utilize existing functionality or kernels +within the core framework. + +By definition, an addon should *never* provide APIs that conflict with +the interfaces that belong to either the [typed API](BLISTypedAPI.md) or the +[object API](BLISObjectAPI.md). Thus, you'll never have to worry about a +properly constructed (and properly functioning) addon interfering with or +otherwise changing core BLIS functionality. + +How does an addon differ from a [sandbox](Sandboxes.md)? Great question! +Sometimes you want to include additional BLIS-like functionality that does +not relate directly to `gemm` or any other BLIS operation. +(By contrast, a sandbox requires you to implement `gemm` whether you want +to or not.) +Furthermore, you may wish to enable multiple addons simultaneously. +(By contrast, only one sandbox may be enabled at a time.) +Thus, the addon feature provides additional flexibility to some +users in a way that sandboxes cannot, while still providing many of the +conveniences of sandboxes. + +## Enabling an addon + +To enable an existing addon at configure-time, you simply specify it as an +option to `configure`. Either of the following usages are accepted: +``` +$ ./configure --enable-addon=foobar auto +$ ./configure -a foobar auto +``` +Here, we tell `configure` that we want to use the `foobar` addon, which +corresponds to a subdirectory of the `addon` directory named `foobar`. 
+(Reminder: the `auto` argument is the configuration target and +unrelated to addons.) + +You may also enable multiple addons within the same build of BLIS: +``` +$ ./configure -a foobar -a thing1 -a thing2 auto +``` +Note that the default behavior of `configure` is that no addons are enabled. + +As `configure` runs, you should get output that includes lines +similar to: +``` +configure: configuring with addons: +configure: addon/foobar +configure: addon/thing1 +configure: addon/thing2 +``` +And when you build BLIS, the addon source code will be among the last files to +be compiled: +``` +Compiling obj/haswell/addon/foobar/foobar.o ('haswell' CFLAGS for addons) +Compiling obj/haswell/addon/thing1/thing1.o ('haswell' CFLAGS for addons) +Compiling obj/haswell/addon/thing1/thing1_api.o ('haswell' CFLAGS for addons) +Compiling obj/haswell/addon/thing2/thing2_api.o ('haswell' CFLAGS for addons) +... +``` +That's it! After the BLIS library is built, it will contain your chosen +addons. You can always confirm this by using `nm` to confirm the presence +of your API symbols: +``` +$ nm lib/haswell/libblis.a | grep foobar +foobar.o: +0000000000000000 T foobar +``` + +## Addon rules + +Please follow these guidelines for the best developer experience when +creating addons. + +1. As with sandboxes, you don't need to worry about creating makefiles. The +BLIS build system will take care of this for you. :) By configuring BLIS with +an addon enabled, `make` will scan your addon subdirectory and compile +all of its source code using similar compilation rules as were used for the rest +of the framework. In addition, the compilation command line will automatically +contain one `-I` option for every subdirectory in your addon, +so it doesn't matter where in your addon directory hierarchy you place your +header files -- they will be found! + +2. We recommend that you write your addon in C99. While you *may* use C++11 +to implement your addon, you should provide a C99 wrapper API to your +implementation so that others can interface with it. There is no guarantee +that the end-user will be using a C++11 compiler, and therefore you should +limit the definitions in your addon header to those that are C99 compliant. +If you write your addon in C++11, you must use one of the BLIS-approved file +extensions for your source files (`.cc`, `.cpp`, `.cxx`) and your local +header files (`.hh`, `.hpp`, `.hxx`). +Note that `blis.h` already contains all of its definitions inside of an +`extern "C"` block, so you should be able to `#include "blis.h"` from your +C++11 source code without any issues. + +3. All of your code related to the addon should reside within the named +addon directory, or some subdirectory therein. If your addon requires +new kernels, you should add kernel source code to an appropriate +microarchitecture-specific subdirectory within the top-level `kernels` +directory so that they are compiled with the correct +microarchitecture-specific optimization flags. + +4. If your addon is named `foobar`, the BLIS build system will expect to +find a header called `foobar.h` somewhere in the `addon/foobar` directory +(or one of its subdirectories). This `foobar.h` header will automatically +be inlined into the monolithic `blis.h` header that is produced by the +BLIS build system. `foobar.h` may `#include` other local headers, each of +which will also (recursively) get inlined into `blis.h`. 
However, you may +choose to omit some local addon headers from `foobar.h.` You might do this, +for example, because those headers define things that are not needed in +order for the end user to call your addon code. + +5. Your addon APIs will always be available within static library builds of +BLIS, but if you want your addon APIs to be exported as public APIs within +*shared* library builds of BLIS, you'll need to annotate the prototypes +accordingly. (BLIS makes its shared library symbols private by default; this +allows us to export only those functions that we consider to be part of the +public APIs.) This annotation can be done by prefixing function prototypes +with the `BLIS_EXPORT_ADDON` macro as follows: +```c +BLIS_EXPORT_ADDON void foobar_calc( void* a, void* b ); +``` + +6. Do not define any symbols in your addon that conflict with any symbols within +the core framework. For example, don't define a function called `bli_copym()` +in your addon since that function is already defined within BLIS. + +7. Do not define any symbols in your addon that conflict with any symbols within +the C99 standard libraries/headers. For example, don't define a function called +`printf()` since that function is already defined within the C99 standard library. + +8. *Try* to not define any symbols in your addon that conflict with symbols in any +other addon, unless your addon is meant to serve as an alternative to the +conflicting addon, in which case conflicting symbol names is okay (since you +will presumably never build with both addons enabled). + +9. When choosing names for your addon files, avoid source filenames that already +exist within BLIS. For example, don't name one of your files `bli_obj.c` +since that file would compile into `bli_obj.o`, which will have already been +placed into the library by the build system. + +10. Similarly, avoid header filenames that already exist within BLIS or C99. +For example, don't name one of your header files `bli_obj.h` since that file +already exists in BLIS. Also, don't name one of your header files `math.h` +since that name would conflict with the `math.h` defined by C99. (This also +means you shouldn't name your addon `math` since normally that name would +require that you provide a `math.h` header inside the addon directory.) + +If you follow these rules, you will be much more likely to have a pleasant +experience integrating your BLIS addon into the larger framework. + +## Caveats + +Notice that the BLIS addons are limited in what they can accomplish. Generally +speaking, addons cannot change existing implementations within BLIS. Instead, +addons aim to provide a way to quickly augment BLIS with additional bundles of +code that extend BLIS's set of functionality in some interesting way. If you +want to define new BLAS-like functions, but don't know where to start, creating +a new addon is an appropriate place to start experimenting. If you want to +change or refactor existing BLIS code, an addon is probably not suited for your +needs. + +Another important limitation is the fact that the build system currently uses +"framework `CFLAGS`" when compiling the addon source files. 
These are the same
+`CFLAGS` used when compiling general framework source code,
+```
+# Example framework CFLAGS used by 'haswell' sub-configuration
+-O2 -Wall -Wno-unused-function -Wfatal-errors -fPIC -std=c99
+-D_POSIX_C_SOURCE=200112L -Iinclude/haswell -I./frame/3/
+-I./frame/1m/ -I./frame/1f/ -I./frame/1/ -I./frame/include
+-DBLIS_VERSION_STRING=\"0.8.1-195\" -fvisibility=hidden
+```
+which are likely more general-purpose than the `CFLAGS` used for, say,
+optimized kernels or even reference kernels:
+```
+# Example optimized kernel CFLAGS used by 'haswell' sub-configuration
+-O3 -fomit-frame-pointer -mavx2 -mfma -mfpmath=sse -march=haswell -Wall
+-Wno-unused-function -Wfatal-errors -fPIC -std=c99 -D_POSIX_C_SOURCE=200112L
+-Iinclude/haswell -I./frame/3/ -I./frame/1m/ -I./frame/1f/ -I./frame/1/
+-I./frame/include -DBLIS_VERSION_STRING=\"0.8.1-195\" -fvisibility=hidden
+```
+(To see precisely which flags are being employed for any given file, enable
+verbosity at compile-time via `make V=1`.) Compiling addons with these more
+versatile `CFLAGS` compiler options means that we only need to compile one
+instance of each addon source file, even when targeting multiple
+configurations (for example, via `./configure x86_64`). However, it also means
+that addons are not ideal for microkernels, as they sometimes need additional
+compiler flags in order to yield the highest performance. If you have a new
+microkernel you would like to use within an addon, you can always develop it
+within that addon. However, once it is stable and ready for use by others,
+it's best to move the kernel(s) to the appropriate microarchitecture-specific
+subdirectory of the `kernels` directory. This will allow the kernel(s) to be
+compiled with the appropriate microarchitecture-specific compiler flags.
+Please see the
+[Configuration Guide](ConfigurationHowTo)
+for more details, and when in doubt, please don't be shy about seeking
+guidance from BLIS developers by opening a
+[new issue](https://github.com/flame/blis/issues) or sending a message to the
+[blis-devel](http://groups.google.com/d/forum/blis-devel) mailing list.
+
+Notwithstanding these limitations, hopefully you still find BLIS addons
+useful!
+
+## Known issues
+
+* None yet.
+
+## Conclusion
+
+If you encounter any problems, please open
+a new [issue on GitHub](https://github.com/flame/blis/issues).
+
+If you are unsure about how something works, you can still open an issue. Or,
+you can send a message to the
+[blis-devel](https://groups.google.com/d/forum/blis-devel) mailing list.
+
diff --git a/frame/base/bli_info.c b/frame/base/bli_info.c
index bfd6f6fcc8..a3e9cb2ec5 100644
--- a/frame/base/bli_info.c
+++ b/frame/base/bli_info.c
@@ -40,7 +40,7 @@
 // This string gets defined via -D on the command line when BLIS is compiled.
 // This string is (or rather, should be) only used here.
-static char* bli_version_str = BLIS_VERSION_STRING; +static char* bli_version_str = "4.0"; //BLIS_VERSION_STRING; static char* bli_int_type_size_str = STRINGIFY_INT( BLIS_INT_TYPE_SIZE ); char* bli_info_get_version_str( void ) { return bli_version_str; } diff --git a/frame/include/bli_config_macro_defs.h b/frame/include/bli_config_macro_defs.h index d00df2f0be..cfdc9652fc 100644 --- a/frame/include/bli_config_macro_defs.h +++ b/frame/include/bli_config_macro_defs.h @@ -241,8 +241,9 @@ #endif #endif -#define BLIS_EXPORT_BLIS BLIS_EXPORT -#define BLIS_EXPORT_BLAS BLIS_EXPORT +#define BLIS_EXPORT_BLIS BLIS_EXPORT +#define BLIS_EXPORT_BLAS BLIS_EXPORT +#define BLIS_EXPORT_ADDON BLIS_EXPORT // -- STATIC INLINE FUNCTIONS -------------------------------------------------- diff --git a/frame/include/blis.h b/frame/include/blis.h index 783b5de0eb..b335fc59d4 100644 --- a/frame/include/blis.h +++ b/frame/include/blis.h @@ -186,6 +186,14 @@ extern "C" { #include "bli_util.h" +// -- addon definitions -- + +// NOTE: These definitions should not be included much earlier since an addon +// may wish to utilize other types and definitions provided by BLIS. + +#include "bli_addon.h" + + // -- sandbox implementation -- #include "bli_sbox.h" From a4abb1083145150988524a6f24e8e4128dff8b12 Mon Sep 17 00:00:00 2001 From: "Field G. Van Zee" Date: Fri, 28 May 2021 14:49:57 -0500 Subject: [PATCH 073/243] Added a new 'gemmlike' sandbox. Details: - Added a new sandbox called 'gemmlike', which implements sequential and multithreaded gemm in the style of gemmsup but also unconditionally employs packing. The purpose of this sandbox is to (1) avoid select abstractions, such as objects and control trees, in order to allow readers to better understand how a real-world implementation of high-performance gemm can be constructed; (2) provide a starting point for expert users who wish to build something that is gemm-like without "reinventing the wheel." Thanks to Jeff Diamond, Tze Meng Low, Nicholai Tukanov, and Devangi Parikh for requesting and inspiring this work. - The functions defined in this sandbox currently use the "bls_" prefix instead of "bli_" in order to avoid any symbol collisions in the main library. - The sandbox contains two variants, each of which implements gemm via a block-panel algorithm. The only difference between the two is that variant 1 calls the microkernel directly while variant 2 calls the microkernel indirectly, via a function wrapper, which allows the edge case handling to be abstracted away from the classic five loops. - This sandbox implementation utilizes the conventional gemm microkernel (not the skinny/unpacked gemmsup kernels). - Updated some typos in the comments of a few files in the main framework. 
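- For reference, a minimal sketch of how an application might call the
  sandbox's typed API (the matrix sizes and the column-major strides used
  below are purely illustrative):

      // Hypothetical caller of the gemmlike sandbox's typed API.
      #include "blis.h"

      int main( void )
      {
          dim_t  m = 4, n = 4, k = 4;
          double alpha = 1.0, beta = 0.0;
          double a[ 16 ], b[ 16 ], c[ 16 ];

          // Fill A and B with ones; C is overwritten since beta = 0.
          for ( dim_t i = 0; i < m * k; ++i ) a[ i ] = 1.0;
          for ( dim_t i = 0; i < k * n; ++i ) b[ i ] = 1.0;

          // Column-major storage: unit row stride, column stride = #rows.
          bls_dgemm( BLIS_NO_TRANSPOSE, BLIS_NO_TRANSPOSE, m, n, k,
                     &alpha, a, 1, m,
                             b, 1, k,
                     &beta,  c, 1, m );

          return 0;
      }
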
Change-Id: Ifc3c50e9fd0072aada38eace50c57552c88cc6cf --- frame/3/bli_l3_sup_packm_a.c | 2 +- frame/3/bli_l3_sup_packm_b.c | 8 +- frame/include/bli_genarray_macro_defs.h | 14 + frame/include/bli_obj_macro_defs.h | 4 +- frame/include/blis.h | 5 +- sandbox/gemmlike/bli_gemmnat.c | 88 +++ sandbox/gemmlike/bli_sandbox.h | 56 ++ sandbox/gemmlike/bls_gemm.c | 304 +++++++++ sandbox/gemmlike/bls_gemm.h | 101 +++ sandbox/gemmlike/bls_gemm_bp_var1.c | 518 +++++++++++++++ sandbox/gemmlike/bls_gemm_bp_var2.c | 590 ++++++++++++++++++ sandbox/gemmlike/bls_gemm_var.h | 124 ++++ sandbox/gemmlike/bls_l3_packm_a.c | 328 ++++++++++ sandbox/gemmlike/bls_l3_packm_a.h | 122 ++++ sandbox/gemmlike/bls_l3_packm_b.c | 328 ++++++++++ sandbox/gemmlike/bls_l3_packm_b.h | 122 ++++ sandbox/gemmlike/bls_l3_packm_var.c | 198 ++++++ sandbox/gemmlike/bls_l3_packm_var.h | 63 ++ sandbox/gemmlike/thread/bls_l3_decor.h | 73 +++ sandbox/gemmlike/thread/bls_l3_decor_openmp.c | 138 ++++ sandbox/gemmlike/thread/bls_l3_decor_openmp.h | 44 ++ .../gemmlike/thread/bls_l3_decor_pthreads.c | 213 +++++++ .../gemmlike/thread/bls_l3_decor_pthreads.h | 47 ++ sandbox/gemmlike/thread/bls_l3_decor_single.c | 141 +++++ sandbox/gemmlike/thread/bls_l3_decor_single.h | 44 ++ sandbox/power10/bli_gemmnat.c | 9 +- 26 files changed, 3675 insertions(+), 9 deletions(-) create mode 100644 sandbox/gemmlike/bli_gemmnat.c create mode 100644 sandbox/gemmlike/bli_sandbox.h create mode 100644 sandbox/gemmlike/bls_gemm.c create mode 100644 sandbox/gemmlike/bls_gemm.h create mode 100644 sandbox/gemmlike/bls_gemm_bp_var1.c create mode 100644 sandbox/gemmlike/bls_gemm_bp_var2.c create mode 100644 sandbox/gemmlike/bls_gemm_var.h create mode 100644 sandbox/gemmlike/bls_l3_packm_a.c create mode 100644 sandbox/gemmlike/bls_l3_packm_a.h create mode 100644 sandbox/gemmlike/bls_l3_packm_b.c create mode 100644 sandbox/gemmlike/bls_l3_packm_b.h create mode 100644 sandbox/gemmlike/bls_l3_packm_var.c create mode 100644 sandbox/gemmlike/bls_l3_packm_var.h create mode 100644 sandbox/gemmlike/thread/bls_l3_decor.h create mode 100644 sandbox/gemmlike/thread/bls_l3_decor_openmp.c create mode 100644 sandbox/gemmlike/thread/bls_l3_decor_openmp.h create mode 100644 sandbox/gemmlike/thread/bls_l3_decor_pthreads.c create mode 100644 sandbox/gemmlike/thread/bls_l3_decor_pthreads.h create mode 100644 sandbox/gemmlike/thread/bls_l3_decor_single.c create mode 100644 sandbox/gemmlike/thread/bls_l3_decor_single.h diff --git a/frame/3/bli_l3_sup_packm_a.c b/frame/3/bli_l3_sup_packm_a.c index 6933b6906f..4e5f0e4444 100644 --- a/frame/3/bli_l3_sup_packm_a.c +++ b/frame/3/bli_l3_sup_packm_a.c @@ -58,7 +58,7 @@ void PASTEMAC(ch,opname) \ } \ else /* if ( will_pack == TRUE ) */ \ { \ - /* NOTE: This is "rounding up" of the last upanel is actually optional + /* NOTE: This "rounding up" of the last upanel is actually optional for the rrc/crc cases, but absolutely necessary for the other cases since we NEED that last micropanel to have the same ldim (cs_p) as the other micropanels. Why? 
So that millikernels can use the same diff --git a/frame/3/bli_l3_sup_packm_b.c b/frame/3/bli_l3_sup_packm_b.c index 20c41b6b0b..7d7c8815ab 100644 --- a/frame/3/bli_l3_sup_packm_b.c +++ b/frame/3/bli_l3_sup_packm_b.c @@ -58,7 +58,7 @@ void PASTEMAC(ch,opname) \ } \ else /* if ( will_pack == TRUE ) */ \ { \ - /* NOTE: This is "rounding up" of the last upanel is actually optional + /* NOTE: This "rounding up" of the last upanel is actually optional for the rrc/crc cases, but absolutely necessary for the other cases since we NEED that last micropanel to have the same ldim (cs_p) as the other micropanels. Why? So that millikernels can use the same @@ -285,15 +285,15 @@ void PASTEMAC(ch,opname) \ } \ else \ { \ - /* All other stor3_t ids: pack A to column-stored row-panels. */ \ + /* All other stor3_t ids: pack B to row-stored column-panels. */ \ *rs_p = nr; \ *cs_p = 1; \ \ *pd_p = nr; \ *ps_p = k * nr; \ \ - /* Set the schema to "packed row panels" to indicate packing to - conventional column-stored row panels. */ \ + /* Set the schema to "packed column panels" to indicate packing to + conventional row-stored column panels. */ \ *schema = BLIS_PACKED_COL_PANELS; \ } \ \ diff --git a/frame/include/bli_genarray_macro_defs.h b/frame/include/bli_genarray_macro_defs.h index 1e9c772fa6..a63af52a76 100644 --- a/frame/include/bli_genarray_macro_defs.h +++ b/frame/include/bli_genarray_macro_defs.h @@ -140,6 +140,20 @@ arrayname[BLIS_NUM_FP_TYPES][BLIS_NUM_FP_TYPES] = \ +// -- One-operand macro (with custom prefix) -- + +#define GENARRAY_PREF(arrayname,prefix,op) \ +\ +arrayname[BLIS_NUM_FP_TYPES] = \ +{ \ + PASTECH2(prefix,s,op), \ + PASTECH2(prefix,c,op), \ + PASTECH2(prefix,d,op), \ + PASTECH2(prefix,z,op) \ +} + + + // -- Two-operand macros -- diff --git a/frame/include/bli_obj_macro_defs.h b/frame/include/bli_obj_macro_defs.h index 2b3ac35ae0..855384425e 100644 --- a/frame/include/bli_obj_macro_defs.h +++ b/frame/include/bli_obj_macro_defs.h @@ -1190,7 +1190,7 @@ BLIS_INLINE stor3_t bli_obj_stor3_from_strides( obj_t* c, obj_t* a, obj_t* b ) // -- Initialization-related macros -- // Finish the initialization started by the matrix-specific static initializer -// (e.g. BLIS_OBJECT_PREINITIALIZER) +// (e.g. BLIS_OBJECT_INITIALIZER) // NOTE: This is intended only for use in the BLAS compatibility API and typed // BLIS API. @@ -1223,7 +1223,7 @@ BLIS_INLINE void bli_obj_init_finish( num_t dt, dim_t m, dim_t n, void* p, inc_t } // Finish the initialization started by the 1x1-specific static initializer -// (e.g. BLIS_OBJECT_PREINITIALIZER_1X1) +// (e.g. BLIS_OBJECT_INITIALIZER_1X1) // NOTE: This is intended only for use in the BLAS compatibility API and typed // BLIS API. diff --git a/frame/include/blis.h b/frame/include/blis.h index b335fc59d4..b0b0ba5476 100644 --- a/frame/include/blis.h +++ b/frame/include/blis.h @@ -190,8 +190,11 @@ extern "C" { // NOTE: These definitions should not be included much earlier since an addon // may wish to utilize other types and definitions provided by BLIS. - +// TODO: Disable addon header file inclusion for windows since configure +// script is not executed, and subsequently the header file ie not generated. 
+#if !defined(_WIN32) && !defined(__CYGWIN__) #include "bli_addon.h" +#endif // -- sandbox implementation -- diff --git a/sandbox/gemmlike/bli_gemmnat.c b/sandbox/gemmlike/bli_gemmnat.c new file mode 100644 index 0000000000..37fb701859 --- /dev/null +++ b/sandbox/gemmlike/bli_gemmnat.c @@ -0,0 +1,88 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +// Given the current architecture of BLIS sandboxes, bli_gemmnat() is the +// entry point to any sandbox implementation. + +// NOTE: This function is implemented identically to the function that it +// overrides in frame/ind/oapi/bli_l3_nat_oapi.c. This means that we are +// forgoing the option of customizing the implementations that underlie +// bli_gemm() and bli_?gemm(). Any new code defined in this sandbox +// directory, however, will be included in the BLIS. + +#include "blis.h" + +#undef GENFRONT +#define GENFRONT( opname, cname, imeth ) \ +\ +void PASTEMAC(opname,imeth) \ + ( \ + obj_t* alpha, \ + obj_t* a, \ + obj_t* b, \ + obj_t* beta, \ + obj_t* c, \ + cntx_t* cntx, \ + rntm_t* rntm \ + ) \ +{ \ +\ + /* A switch to easily toggle whether we use the sandbox implementation + of bls_gemm() as the implementation for bli_gemm(). (This allows for + easy testing of bls_gemm() via the testsuite.) */ \ + if ( 1 ) \ + { \ + bls_gemm_ex( alpha, a, b, beta, c, cntx, rntm ); \ + return; \ + } \ +\ + bli_init_once(); \ +\ + /* Obtain a valid (native) context from the gks if necessary. */ \ + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); \ +\ + /* Initialize a local runtime with global settings if necessary. Note + that in the case that a runtime is passed in, we make a local copy. */ \ + rntm_t rntm_l; \ + if ( rntm == NULL ) { bli_rntm_init_from_global( &rntm_l ); rntm = &rntm_l; } \ + else { rntm_l = *rntm; rntm = &rntm_l; } \ +\ + /* Invoke the operation's front end. 
*/ \ + PASTEMAC(opname,_front) \ + ( \ + alpha, a, b, beta, c, cntx, rntm, NULL \ + ); \ +} + +GENFRONT( gemm, gemm, nat ) diff --git a/sandbox/gemmlike/bli_sandbox.h b/sandbox/gemmlike/bli_sandbox.h new file mode 100644 index 0000000000..d6e6522e8c --- /dev/null +++ b/sandbox/gemmlike/bli_sandbox.h @@ -0,0 +1,56 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name of copyright holder(s) nor the names + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_SANDBOX_H +#define BLIS_SANDBOX_H + +// NOTE: This header is the only header required to be present in the sandbox +// implementation directory. + +// This header should contain (or #include) any definitions that must be +// folded into blis.h. Typically, it will remain empty since any header +// definitions specific to the sandbox implementation will not need to be +// made available to applications (or the framework) during compilation. + +#include "bls_gemm.h" +#include "bls_gemm_var.h" + +#include "bls_l3_packm_a.h" +#include "bls_l3_packm_b.h" +#include "bls_l3_packm_var.h" + +#include "bls_l3_decor.h" + + +#endif diff --git a/sandbox/gemmlike/bls_gemm.c b/sandbox/gemmlike/bls_gemm.c new file mode 100644 index 0000000000..3e4c9b2a33 --- /dev/null +++ b/sandbox/gemmlike/bls_gemm.c @@ -0,0 +1,304 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. 
+ - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +// +// -- Define the gemm-like operation's object API ------------------------------ +// + +void bls_gemm + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c + ) +{ + bls_gemm_ex + ( + alpha, + a, + b, + beta, + c, + NULL, + NULL + ); +} + +void bls_gemm_ex + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm + ) +{ + bli_init_once(); + + // -- bli_gemmnat() -------------------------------------------------------- + + // Obtain a valid (native) context from the gks if necessary. + // NOTE: This must be done before calling the _check() function, since + // that function assumes the context pointer is valid. + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); + + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. + rntm_t rntm_l; + if ( rntm == NULL ) { bli_rntm_init_from_global( &rntm_l ); rntm = &rntm_l; } + else { rntm_l = *rntm; rntm = &rntm_l; } + + // -- bli_gemm_front() ----------------------------------------------------- + + obj_t a_local; + obj_t b_local; + obj_t c_local; + + // Check parameters. + if ( bli_error_checking_is_enabled() ) + { + bli_gemm_check( alpha, a, b, beta, c, cntx ); + } + + // If C has a zero dimension, return early. + if ( bli_obj_has_zero_dim( c ) ) + { + return; + } + + // If alpha is zero, or if A or B has a zero dimension, scale C by beta + // and return early. + if ( bli_obj_equals( alpha, &BLIS_ZERO ) || + bli_obj_has_zero_dim( a ) || + bli_obj_has_zero_dim( b ) ) + { + bli_scalm( beta, c ); + return; + } + + // Alias A, B, and C in case we need to apply transformations. + bli_obj_alias_to( a, &a_local ); + bli_obj_alias_to( b, &b_local ); + bli_obj_alias_to( c, &c_local ); + + // Induce a transposition of A if it has its transposition property set. + // Then clear the transposition bit in the object. + if ( bli_obj_has_trans( &a_local ) ) + { + bli_obj_induce_trans( &a_local ); + bli_obj_set_onlytrans( BLIS_NO_TRANSPOSE, &a_local ); + } + + // Induce a transposition of B if it has its transposition property set. + // Then clear the transposition bit in the object. 
+ if ( bli_obj_has_trans( &b_local ) ) + { + bli_obj_induce_trans( &b_local ); + bli_obj_set_onlytrans( BLIS_NO_TRANSPOSE, &b_local ); + } + + // An optimization: If C is stored by rows and the micro-kernel prefers + // contiguous columns, or if C is stored by columns and the micro-kernel + // prefers contiguous rows, transpose the entire operation to allow the + // micro-kernel to access elements of C in its preferred manner. + if ( bli_cntx_l3_vir_ukr_dislikes_storage_of( &c_local, BLIS_GEMM_UKR, cntx ) ) + { + bli_obj_swap( &a_local, &b_local ); + + bli_obj_induce_trans( &a_local ); + bli_obj_induce_trans( &b_local ); + bli_obj_induce_trans( &c_local ); + + // NOTE: This is probably not needed within the sandbox. + // We must also swap the pack schemas, which were set by bli_gemm_md() + // or the inlined code above. + //bli_obj_swap_pack_schemas( &a_local, &b_local ); + } + + // Parse and interpret the contents of the rntm_t object to properly + // set the ways of parallelism for each loop, and then make any + // additional modifications necessary for the current operation. + bli_rntm_set_ways_for_op + ( + BLIS_GEMM, + BLIS_LEFT, // ignored for gemm/hemm/symm + bli_obj_length( &c_local ), + bli_obj_width( &c_local ), + bli_obj_width( &a_local ), + rntm + ); + + // Spawn threads (if applicable), where bls_gemm_int() is the thread entry + // point function for each thread. This also begins the process of creating + // the thrinfo_t tree, which contains thread communicators. + bls_l3_thread_decorator + ( + bls_gemm_int, + BLIS_GEMM, // operation family id + alpha, + &a_local, + &b_local, + beta, + &c_local, + cntx, + rntm + ); +} + +// +// -- Define the gemm-like operation's thread entry point ---------------------- +// + +void bls_gemm_int + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm, + thrinfo_t* thread + ) +{ + // In this function, we choose the gemm implementation that is executed + // on each thread. + +#if 1 + // Call the block-panel algorithm that calls the kernel directly, which + // exposes edge-case handling. + bls_gemm_bp_var1 + ( + alpha, + a, + b, + beta, + c, + cntx, + rntm, + thread + ); +#else + // Call the block-panel algorithm that calls the kernel indirectly via a + // wrapper function, which hides edge-case handling. + bls_gemm_bp_var2 + ( + alpha, + a, + b, + beta, + c, + cntx, + rntm, + thread + ); +#endif +} + +// +// -- Define the gemm-like operation's typed API ------------------------------- +// + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTECH2(bls_,ch,opname) \ + ( \ + trans_t transa, \ + trans_t transb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + ctype* alpha, \ + ctype* a, inc_t rs_a, inc_t cs_a, \ + ctype* b, inc_t rs_b, inc_t cs_b, \ + ctype* beta, \ + ctype* c, inc_t rs_c, inc_t cs_c \ + ) \ +{ \ + bli_init_once(); \ +\ + /* Determine the datatype (e.g. BLIS_FLOAT, BLIS_DOUBLE, etc.) based on + the macro parameter 'ch' (e.g. s, d, etc). */ \ + const num_t dt = PASTEMAC(ch,type); \ +\ + obj_t alphao, ao, bo, betao, co; \ +\ + dim_t m_a, n_a; \ + dim_t m_b, n_b; \ +\ + /* Adjust the dimensions of matrices A and B according to the transa and + transb parameters. */ \ + bli_set_dims_with_trans( transa, m, k, &m_a, &n_a ); \ + bli_set_dims_with_trans( transb, k, n, &m_b, &n_b ); \ +\ + /* Create bufferless scalar objects and attach the provided scalar pointers + to those scalar objects. 
*/ \ + bli_obj_create_1x1_with_attached_buffer( dt, alpha, &alphao ); \ + bli_obj_create_1x1_with_attached_buffer( dt, beta, &betao ); \ +\ + /* Create bufferless matrix objects and attach the provided matrix pointers + to those matrix objects. */ \ + bli_obj_create_with_attached_buffer( dt, m_a, n_a, a, rs_a, cs_a, &ao ); \ + bli_obj_create_with_attached_buffer( dt, m_b, n_b, b, rs_b, cs_b, &bo ); \ + bli_obj_create_with_attached_buffer( dt, m, n, c, rs_c, cs_c, &co ); \ +\ + /* Set the transposition/conjugation properties of the objects for matrices + A and B. */ \ + bli_obj_set_conjtrans( transa, &ao ); \ + bli_obj_set_conjtrans( transb, &bo ); \ +\ + /* Call the object interface. */ \ + PASTECH(bls_,opname) \ + ( \ + &alphao, \ + &ao, \ + &bo, \ + &betao, \ + &co \ + ); \ +} + +//INSERT_GENTFUNC_BASIC0( gemm ) +GENTFUNC( float, s, gemm ) +GENTFUNC( double, d, gemm ) +GENTFUNC( scomplex, c, gemm ) +GENTFUNC( dcomplex, z, gemm ) + diff --git a/sandbox/gemmlike/bls_gemm.h b/sandbox/gemmlike/bls_gemm.h new file mode 100644 index 0000000000..b296ac1c0f --- /dev/null +++ b/sandbox/gemmlike/bls_gemm.h @@ -0,0 +1,101 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +*/ + +// +// -- Prototype the gemm-like operation's object API --------------------------- +// + +void bls_gemm + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c + ); + +void bls_gemm_ex + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm + ); + +// +// -- Prototype the gemm-like operation's thread entry point ------------------- +// + +void bls_gemm_int + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm, + thrinfo_t* thread + ); + +// +// -- Prototype the gemm-like operation's typed API ---------------------------- +// + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +void PASTECH2(bls_,ch,opname) \ + ( \ + trans_t transa, \ + trans_t transb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + ctype* alpha, \ + ctype* a, inc_t rs_a, inc_t cs_a, \ + ctype* b, inc_t rs_b, inc_t cs_b, \ + ctype* beta, \ + ctype* c, inc_t rs_c, inc_t cs_c \ + ); + +//INSERT_GENTPROT_BASIC0( gemm ) +GENTPROT( float, s, gemm ) +GENTPROT( double, d, gemm ) +GENTPROT( scomplex, c, gemm ) +GENTPROT( dcomplex, z, gemm ) + diff --git a/sandbox/gemmlike/bls_gemm_bp_var1.c b/sandbox/gemmlike/bls_gemm_bp_var1.c new file mode 100644 index 0000000000..ae695ce34f --- /dev/null +++ b/sandbox/gemmlike/bls_gemm_bp_var1.c @@ -0,0 +1,518 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +*/ + +#include "blis.h" + +#define FUNCPTR_T gemm_fp + +typedef void (*FUNCPTR_T) + ( + conj_t conja, + conj_t conjb, + dim_t m, + dim_t n, + dim_t k, + void* restrict alpha, + void* restrict a, inc_t rs_a, inc_t cs_a, + void* restrict b, inc_t rs_b, inc_t cs_b, + void* restrict beta, + void* restrict c, inc_t rs_c, inc_t cs_c, + cntx_t* restrict cntx, + rntm_t* restrict rntm, + thrinfo_t* restrict thread + ); + +// +// -- gemm-like block-panel algorithm (object interface) ----------------------- +// + +// Define a function pointer array named ftypes and initialize its contents with +// the addresses of the typed functions defined below, bls_?gemm_bp_var1(). +static FUNCPTR_T GENARRAY_PREF(ftypes,bls_,gemm_bp_var1); + +void bls_gemm_bp_var1 + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm, + thrinfo_t* thread + ) +{ + const num_t dt = bli_obj_dt( c ); + + const conj_t conja = bli_obj_conj_status( a ); + const conj_t conjb = bli_obj_conj_status( b ); + + const dim_t m = bli_obj_length( c ); + const dim_t n = bli_obj_width( c ); + const dim_t k = bli_obj_width( a ); + + void* restrict buf_a = bli_obj_buffer_at_off( a ); + const inc_t rs_a = bli_obj_row_stride( a ); + const inc_t cs_a = bli_obj_col_stride( a ); + + void* restrict buf_b = bli_obj_buffer_at_off( b ); + const inc_t rs_b = bli_obj_row_stride( b ); + const inc_t cs_b = bli_obj_col_stride( b ); + + void* restrict buf_c = bli_obj_buffer_at_off( c ); + const inc_t rs_c = bli_obj_row_stride( c ); + const inc_t cs_c = bli_obj_col_stride( c ); + + void* restrict buf_alpha = bli_obj_buffer_for_1x1( dt, alpha ); + void* restrict buf_beta = bli_obj_buffer_for_1x1( dt, beta ); + + // Index into the function pointer array to extract the correct + // typed function pointer based on the chosen datatype. + FUNCPTR_T f = ftypes[dt]; + + // Invoke the function. + f + ( + conja, + conjb, + m, + n, + k, + buf_alpha, + buf_a, rs_a, cs_a, + buf_b, rs_b, cs_b, + buf_beta, + buf_c, rs_c, cs_c, + cntx, + rntm, + thread + ); +} + +// +// -- gemm-like block-panel algorithm (typed interface) ------------------------ +// + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTECH2(bls_,ch,varname) \ + ( \ + conj_t conja, \ + conj_t conjb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + void* restrict alpha, \ + void* restrict a, inc_t rs_a, inc_t cs_a, \ + void* restrict b, inc_t rs_b, inc_t cs_b, \ + void* restrict beta, \ + void* restrict c, inc_t rs_c, inc_t cs_c, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + const num_t dt = PASTEMAC(ch,type); \ +\ + /* Query the context for various blocksizes. */ \ + const dim_t NR = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); \ + const dim_t MR = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); \ + const dim_t NC = bli_cntx_get_blksz_def_dt( dt, BLIS_NC, cntx ); \ + const dim_t MC = bli_cntx_get_blksz_def_dt( dt, BLIS_MC, cntx ); \ + const dim_t KC = bli_cntx_get_blksz_def_dt( dt, BLIS_KC, cntx ); \ +\ + /* Query the context for the microkernel address and cast it to its + function pointer type. */ \ + PASTECH(ch,gemm_ukr_ft) \ + gemm_ukr = bli_cntx_get_l3_nat_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \ +\ + /* Temporary C buffer for edge cases. Note that the strides of this + temporary buffer are set so that they match the storage of the + original C matrix. For example, if C is column-stored, ct will be + column-stored as well. 
*/ \ + ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ + / sizeof( ctype ) ] \ + __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ + const bool col_pref = bli_cntx_l3_nat_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ + const inc_t rs_ct = ( col_pref ? 1 : NR ); \ + const inc_t cs_ct = ( col_pref ? MR : 1 ); \ +\ + /* Compute partitioning step values for each matrix of each loop. */ \ + const inc_t jcstep_c = cs_c; \ + const inc_t jcstep_b = cs_b; \ +\ + const inc_t pcstep_a = cs_a; \ + const inc_t pcstep_b = rs_b; \ +\ + const inc_t icstep_c = rs_c; \ + const inc_t icstep_a = rs_a; \ +\ + const inc_t jrstep_c = cs_c * NR; \ +\ + const inc_t irstep_c = rs_c * MR; \ +\ + ctype* restrict a_00 = a; \ + ctype* restrict b_00 = b; \ + ctype* restrict c_00 = c; \ + ctype* restrict alpha_cast = alpha; \ + ctype* restrict beta_cast = beta; \ +\ + /* Make local copies of the scalars to prevent any unnecessary sharing of + cache lines between the cores' caches. */ \ + ctype alpha_local = *alpha_cast; \ + ctype beta_local = *beta_cast; \ + ctype one_local = *PASTEMAC(ch,1); \ + ctype zero_local = *PASTEMAC(ch,0); \ +\ + auxinfo_t aux; \ +\ + /* Initialize a mem_t entry for A and B. Strictly speaking, this is only + needed for the matrix we will be packing (if any), but we do it + unconditionally to be safe. */ \ + mem_t mem_a = BLIS_MEM_INITIALIZER; \ + mem_t mem_b = BLIS_MEM_INITIALIZER; \ +\ + /* Define an array of bszid_t ids, which will act as our substitute for + the cntl_t tree. */ \ + bszid_t bszids[8] = { BLIS_NC, /* 5th loop */ \ + BLIS_KC, /* 4th loop */ \ + BLIS_NO_PART, /* pack B */ \ + BLIS_MC, /* 3rd loop */ \ + BLIS_NO_PART, /* pack A */ \ + BLIS_NR, /* 2nd loop */ \ + BLIS_MR, /* 1st loop */ \ + BLIS_KR }; /* microkernel loop */ \ +\ + bszid_t* restrict bszids_jc = &bszids[0]; \ + bszid_t* restrict bszids_pc = &bszids[1]; \ + /*bszid_t* restrict bszids_pb = &bszids[2];*/ \ + bszid_t* restrict bszids_ic = &bszids[3]; \ + /*bszid_t* restrict bszids_pa = &bszids[4];*/ \ + bszid_t* restrict bszids_jr = &bszids[5]; \ + /*bszid_t* restrict bszids_ir = &bszids[6];*/ \ +\ + thrinfo_t* restrict thread_jc = NULL; \ + thrinfo_t* restrict thread_pc = NULL; \ + thrinfo_t* restrict thread_pb = NULL; \ + thrinfo_t* restrict thread_ic = NULL; \ + thrinfo_t* restrict thread_pa = NULL; \ + thrinfo_t* restrict thread_jr = NULL; \ + thrinfo_t* restrict thread_ir = NULL; \ +\ + /* Identify the current thrinfo_t node and then grow the tree. */ \ + thread_jc = thread; \ + bli_thrinfo_sup_grow( rntm, bszids_jc, thread_jc ); \ +\ + /* Compute the JC loop thread range for the current thread. */ \ + dim_t jc_start, jc_end; \ + bli_thread_range_sub( thread_jc, n, NR, FALSE, &jc_start, &jc_end ); \ + const dim_t n_local = jc_end - jc_start; \ +\ + /* Compute number of primary and leftover components of the JC loop. */ \ + /*const dim_t jc_iter = ( n_local + NC - 1 ) / NC;*/ \ + const dim_t jc_left = n_local % NC; \ +\ + /* Loop over the n dimension (NC rows/columns at a time). */ \ + for ( dim_t jj = jc_start; jj < jc_end; jj += NC ) \ + { \ + /* Calculate the thread's current JC block dimension. */ \ + const dim_t nc_cur = ( NC <= jc_end - jj ? NC : jc_left ); \ +\ + ctype* restrict b_jc = b_00 + jj * jcstep_b; \ + ctype* restrict c_jc = c_00 + jj * jcstep_c; \ +\ + /* Identify the current thrinfo_t node and then grow the tree. */ \ + thread_pc = bli_thrinfo_sub_node( thread_jc ); \ + bli_thrinfo_sup_grow( rntm, bszids_pc, thread_pc ); \ +\ + /* Compute the PC loop thread range for the current thread. 
*/ \ + const dim_t pc_start = 0, pc_end = k; \ + const dim_t k_local = k; \ +\ + /* Compute number of primary and leftover components of the PC loop. */ \ + /*const dim_t pc_iter = ( k_local + KC - 1 ) / KC;*/ \ + const dim_t pc_left = k_local % KC; \ +\ + /* Loop over the k dimension (KC rows/columns at a time). */ \ + for ( dim_t pp = pc_start; pp < pc_end; pp += KC ) \ + { \ + /* Calculate the thread's current PC block dimension. */ \ + const dim_t kc_cur = ( KC <= pc_end - pp ? KC : pc_left ); \ +\ + ctype* restrict a_pc = a_00 + pp * pcstep_a; \ + ctype* restrict b_pc = b_jc + pp * pcstep_b; \ +\ + /* Only apply beta to the first iteration of the pc loop. */ \ + ctype* restrict beta_use = ( pp == 0 ? &beta_local : &one_local ); \ +\ + ctype* b_use; \ + inc_t rs_b_use, cs_b_use, ps_b_use; \ +\ + /* Identify the current thrinfo_t node. Note that the thrinfo_t + node will have already been created by a previous call to + bli_thrinfo_sup_grow() since bszid_t values of BLIS_NO_PART + cause the tree to grow by two (e.g. to the next bszid that is + a normal bszid_t value). */ \ + thread_pb = bli_thrinfo_sub_node( thread_pc ); \ + /*bli_thrinfo_sup_grow( rntm, bszids_pb, thread_pb );*/ \ +\ + /* Determine the packing buffer and related parameters for matrix + B. Then call the packm implementation. */ \ + PASTECH2(bls_,ch,packm_b) \ + ( \ + conjb, \ + KC, NC, \ + kc_cur, nc_cur, NR, \ + &one_local, \ + b_pc, rs_b, cs_b, \ + &b_use, &rs_b_use, &cs_b_use, \ + &ps_b_use, \ + cntx, \ + rntm, \ + &mem_b, \ + thread_pb \ + ); \ +\ + /* Alias b_use so that it's clear this is our current block of + matrix B. */ \ + ctype* restrict b_pc_use = b_use; \ +\ + /* Identify the current thrinfo_t node and then grow the tree. */ \ + thread_ic = bli_thrinfo_sub_node( thread_pb ); \ + bli_thrinfo_sup_grow( rntm, bszids_ic, thread_ic ); \ +\ + /* Compute the IC loop thread range for the current thread. */ \ + dim_t ic_start, ic_end; \ + bli_thread_range_sub( thread_ic, m, MR, FALSE, &ic_start, &ic_end ); \ + const dim_t m_local = ic_end - ic_start; \ +\ + /* Compute number of primary and leftover components of the IC loop. */ \ + /*const dim_t ic_iter = ( m_local + MC - 1 ) / MC;*/ \ + const dim_t ic_left = m_local % MC; \ +\ + /* Loop over the m dimension (MC rows at a time). */ \ + for ( dim_t ii = ic_start; ii < ic_end; ii += MC ) \ + { \ + /* Calculate the thread's current IC block dimension. */ \ + const dim_t mc_cur = ( MC <= ic_end - ii ? MC : ic_left ); \ +\ + ctype* restrict a_ic = a_pc + ii * icstep_a; \ + ctype* restrict c_ic = c_jc + ii * icstep_c; \ +\ + ctype* a_use; \ + inc_t rs_a_use, cs_a_use, ps_a_use; \ +\ + /* Identify the current thrinfo_t node. Note that the thrinfo_t + node will have already been created by a previous call to + bli_thrinfo_sup_grow() since bszid_t values of BLIS_NO_PART + cause the tree to grow by two (e.g. to the next bszid that is + a normal bszid_t value). */ \ + thread_pa = bli_thrinfo_sub_node( thread_ic ); \ + /*bli_thrinfo_sup_grow( rntm, bszids_pa, thread_pa );*/ \ +\ + /* Determine the packing buffer and related parameters for matrix + A. Then call the packm implementation. */ \ + PASTECH2(bls_,ch,packm_a) \ + ( \ + conja, \ + MC, KC, \ + mc_cur, kc_cur, MR, \ + &one_local, \ + a_ic, rs_a, cs_a, \ + &a_use, &rs_a_use, &cs_a_use, \ + &ps_a_use, \ + cntx, \ + rntm, \ + &mem_a, \ + thread_pa \ + ); \ +\ + /* Alias a_use so that it's clear this is our current block of + matrix A. 
*/ \ + ctype* restrict a_ic_use = a_use; \ +\ + /* Identify the current thrinfo_t node and then grow the tree. */ \ + thread_jr = bli_thrinfo_sub_node( thread_pa ); \ + bli_thrinfo_sup_grow( rntm, bszids_jr, thread_jr ); \ +\ + /* Query the number of threads and thread ids for the JR loop. + NOTE: These values are only needed when computing the next + micropanel of B. */ \ + const dim_t jr_nt = bli_thread_n_way( thread_jr ); \ + const dim_t jr_tid = bli_thread_work_id( thread_jr ); \ +\ + /* Compute number of primary and leftover components of the JR loop. */ \ + dim_t jr_iter = ( nc_cur + NR - 1 ) / NR; \ + dim_t jr_left = nc_cur % NR; \ +\ + /* Compute the JR loop thread range for the current thread. */ \ + dim_t jr_start, jr_end; \ + bli_thread_range_sub( thread_jr, jr_iter, 1, FALSE, &jr_start, &jr_end ); \ +\ + /* Loop over the n dimension (NR columns at a time). */ \ + for ( dim_t j = jr_start; j < jr_end; j += 1 ) \ + { \ + const dim_t nr_cur \ + = ( bli_is_not_edge_f( j, jr_iter, jr_left ) ? NR : jr_left ); \ +\ + ctype* restrict b_jr = b_pc_use + j * ps_b_use; \ + ctype* restrict c_jr = c_ic + j * jrstep_c; \ +\ + /* Assume for now that our next panel of B to be the current panel + of B. */ \ + ctype* restrict b2 = b_jr; \ +\ + /* Identify the current thrinfo_t node. */ \ + thread_ir = bli_thrinfo_sub_node( thread_jr ); \ +\ + /* Query the number of threads and thread ids for the IR loop. + NOTE: These values are only needed when computing the next + micropanel of A. */ \ + const dim_t ir_nt = bli_thread_n_way( thread_ir ); \ + const dim_t ir_tid = bli_thread_work_id( thread_ir ); \ +\ + /* Compute number of primary and leftover components of the IR loop. */ \ + dim_t ir_iter = ( mc_cur + MR - 1 ) / MR; \ + dim_t ir_left = mc_cur % MR; \ +\ + /* Compute the IR loop thread range for the current thread. */ \ + dim_t ir_start, ir_end; \ + bli_thread_range_sub( thread_ir, ir_iter, 1, FALSE, &ir_start, &ir_end ); \ +\ + /* Loop over the m dimension (MR rows at a time). */ \ + for ( dim_t i = ir_start; i < ir_end; i += 1 ) \ + { \ + const dim_t mr_cur \ + = ( bli_is_not_edge_f( i, ir_iter, ir_left ) ? MR : ir_left ); \ +\ + ctype* restrict a_ir = a_ic_use + i * ps_a_use; \ + ctype* restrict c_ir = c_jr + i * irstep_c; \ +\ + ctype* restrict a2; \ +\ + /* Compute the addresses of the next micropanels of A and B. */ \ + a2 = bli_gemm_get_next_a_upanel( a_ir, ps_a_use, 1 ); \ + if ( bli_is_last_iter( i, ir_end, ir_tid, ir_nt ) ) \ + { \ + a2 = a_ic_use; \ + b2 = bli_gemm_get_next_b_upanel( b_jr, ps_b_use, 1 ); \ + if ( bli_is_last_iter( j, jr_end, jr_tid, jr_nt ) ) \ + b2 = b_pc_use; \ + } \ +\ + /* Save the addresses of next micropanels of A and B to the + auxinfo_t object. */ \ + bli_auxinfo_set_next_a( a2, &aux ); \ + bli_auxinfo_set_next_b( b2, &aux ); \ +\ + /* Handle interior and edge cases separately. */ \ + if ( mr_cur == MR && nr_cur == NR ) \ + { \ + /* Invoke the gemm microkernel. */ \ + gemm_ukr \ + ( \ + kc_cur, \ + &alpha_local, \ + a_ir, \ + b_jr, \ + beta_use, \ + c_ir, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ + } \ + else \ + { \ + /* Invoke the gemm microkernel. */ \ + gemm_ukr \ + ( \ + kc_cur, \ + &alpha_local, \ + a_ir, \ + b_jr, \ + &zero_local, \ + ct, rs_ct, cs_ct, \ + &aux, \ + cntx \ + ); \ +\ + /* Scale the bottom edge of C and add the result from above. 
*/ \ + PASTEMAC(ch,xpbys_mxn) \ + ( \ + mr_cur, \ + nr_cur, \ + ct, rs_ct, cs_ct, \ + beta_use, \ + c_ir, rs_c, cs_c \ + ); \ + } \ + } \ + } \ + } \ +\ + /* This barrier is needed to prevent threads from starting to pack + the next row panel of B before the current row panel is fully + computed upon. */ \ + bli_thread_barrier( thread_pb ); \ + } \ + } \ +\ + /* Release any memory that was acquired for packing matrices A and B. */ \ + PASTECH2(bls_,ch,packm_finalize_mem_a) \ + ( \ + rntm, \ + &mem_a, \ + thread_pa \ + ); \ + PASTECH2(bls_,ch,packm_finalize_mem_b) \ + ( \ + rntm, \ + &mem_b, \ + thread_pb \ + ); \ +\ +/* +PASTEMAC(ch,fprintm)( stdout, "gemm_bp_var1: a1_packed", mr_cur, kc_cur, a_ir, rs_a_use, cs_a_use, "%5.2f", "" ); \ +PASTEMAC(ch,fprintm)( stdout, "gemm_bp_var1: b1_packed", kc_cur, nr_cur, b_jr, rs_b_use, cs_b_use, "%5.2f", "" ); \ +PASTEMAC(ch,fprintm)( stdout, "gemm_bp_var1: c ", mr_cur, nr_cur, c_ir, rs_c, cs_c, "%5.2f", "" ); \ +*/ \ +} + +//INSERT_GENTFUNC_BASIC0( gemm_bp_var1 ) +GENTFUNC( float, s, gemm_bp_var1 ) +GENTFUNC( double, d, gemm_bp_var1 ) +GENTFUNC( scomplex, c, gemm_bp_var1 ) +GENTFUNC( dcomplex, z, gemm_bp_var1 ) + diff --git a/sandbox/gemmlike/bls_gemm_bp_var2.c b/sandbox/gemmlike/bls_gemm_bp_var2.c new file mode 100644 index 0000000000..957cd57944 --- /dev/null +++ b/sandbox/gemmlike/bls_gemm_bp_var2.c @@ -0,0 +1,590 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +*/ + +#include "blis.h" + +#define FUNCPTR_T gemm_fp + +typedef void (*FUNCPTR_T) + ( + conj_t conja, + conj_t conjb, + dim_t m, + dim_t n, + dim_t k, + void* restrict alpha, + void* restrict a, inc_t rs_a, inc_t cs_a, + void* restrict b, inc_t rs_b, inc_t cs_b, + void* restrict beta, + void* restrict c, inc_t rs_c, inc_t cs_c, + cntx_t* restrict cntx, + rntm_t* restrict rntm, + thrinfo_t* restrict thread + ); + +// +// -- gemm-like block-panel algorithm (object interface) ----------------------- +// + +// Define a function pointer array named ftypes and initialize its contents with +// the addresses of the typed functions defined below, bls_?gemm_bp_var2(). +static FUNCPTR_T GENARRAY_PREF(ftypes,bls_,gemm_bp_var2); + +void bls_gemm_bp_var2 + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm, + thrinfo_t* thread + ) +{ + const num_t dt = bli_obj_dt( c ); + + const conj_t conja = bli_obj_conj_status( a ); + const conj_t conjb = bli_obj_conj_status( b ); + + const dim_t m = bli_obj_length( c ); + const dim_t n = bli_obj_width( c ); + const dim_t k = bli_obj_width( a ); + + void* restrict buf_a = bli_obj_buffer_at_off( a ); + const inc_t rs_a = bli_obj_row_stride( a ); + const inc_t cs_a = bli_obj_col_stride( a ); + + void* restrict buf_b = bli_obj_buffer_at_off( b ); + const inc_t rs_b = bli_obj_row_stride( b ); + const inc_t cs_b = bli_obj_col_stride( b ); + + void* restrict buf_c = bli_obj_buffer_at_off( c ); + const inc_t rs_c = bli_obj_row_stride( c ); + const inc_t cs_c = bli_obj_col_stride( c ); + + void* restrict buf_alpha = bli_obj_buffer_for_1x1( dt, alpha ); + void* restrict buf_beta = bli_obj_buffer_for_1x1( dt, beta ); + + // Index into the function pointer array to extract the correct + // typed function pointer based on the chosen datatype. + FUNCPTR_T f = ftypes[dt]; + + // Invoke the function. + f + ( + conja, + conjb, + m, + n, + k, + buf_alpha, + buf_a, rs_a, cs_a, + buf_b, rs_b, cs_b, + buf_beta, + buf_c, rs_c, cs_c, + cntx, + rntm, + thread + ); +} + +// +// -- gemm-like block-panel algorithm (typed interface) ------------------------ +// + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTECH2(bls_,ch,varname) \ + ( \ + conj_t conja, \ + conj_t conjb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + void* restrict alpha, \ + void* restrict a, inc_t rs_a, inc_t cs_a, \ + void* restrict b, inc_t rs_b, inc_t cs_b, \ + void* restrict beta, \ + void* restrict c, inc_t rs_c, inc_t cs_c, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + const num_t dt = PASTEMAC(ch,type); \ +\ + /* Query the context for various blocksizes. */ \ + const dim_t NR = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); \ + const dim_t MR = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); \ + const dim_t NC = bli_cntx_get_blksz_def_dt( dt, BLIS_NC, cntx ); \ + const dim_t MC = bli_cntx_get_blksz_def_dt( dt, BLIS_MC, cntx ); \ + const dim_t KC = bli_cntx_get_blksz_def_dt( dt, BLIS_KC, cntx ); \ +\ + /* Query the context for the microkernel address and cast it to its + function pointer type. */ \ + /* + PASTECH(ch,gemm_ukr_ft) \ + gemm_ukr = bli_cntx_get_l3_nat_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \ + */ \ +\ + /* Temporary C buffer for edge cases. Note that the strides of this + temporary buffer are set so that they match the storage of the + original C matrix. For example, if C is column-stored, ct will be + column-stored as well. 
*/ \ + /* + ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ + / sizeof( ctype ) ] \ + __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ + const bool col_pref = bli_cntx_l3_nat_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ + const inc_t rs_ct = ( col_pref ? 1 : NR ); \ + const inc_t cs_ct = ( col_pref ? MR : 1 ); \ + */ \ +\ + /* Compute partitioning step values for each matrix of each loop. */ \ + const inc_t jcstep_c = cs_c; \ + const inc_t jcstep_b = cs_b; \ +\ + const inc_t pcstep_a = cs_a; \ + const inc_t pcstep_b = rs_b; \ +\ + const inc_t icstep_c = rs_c; \ + const inc_t icstep_a = rs_a; \ +\ + const inc_t jrstep_c = cs_c * NR; \ +\ + const inc_t irstep_c = rs_c * MR; \ +\ + ctype* restrict a_00 = a; \ + ctype* restrict b_00 = b; \ + ctype* restrict c_00 = c; \ + ctype* restrict alpha_cast = alpha; \ + ctype* restrict beta_cast = beta; \ +\ + /* Make local copies of the scalars to prevent any unnecessary sharing of + cache lines between the cores' caches. */ \ + ctype alpha_local = *alpha_cast; \ + ctype beta_local = *beta_cast; \ + ctype one_local = *PASTEMAC(ch,1); \ + /*ctype zero_local = *PASTEMAC(ch,0);*/ \ +\ + auxinfo_t aux; \ +\ + /* Initialize a mem_t entry for A and B. Strictly speaking, this is only + needed for the matrix we will be packing (if any), but we do it + unconditionally to be safe. */ \ + mem_t mem_a = BLIS_MEM_INITIALIZER; \ + mem_t mem_b = BLIS_MEM_INITIALIZER; \ +\ + /* Define an array of bszid_t ids, which will act as our substitute for + the cntl_t tree. */ \ + bszid_t bszids[8] = { BLIS_NC, /* 5th loop */ \ + BLIS_KC, /* 4th loop */ \ + BLIS_NO_PART, /* pack B */ \ + BLIS_MC, /* 3rd loop */ \ + BLIS_NO_PART, /* pack A */ \ + BLIS_NR, /* 2nd loop */ \ + BLIS_MR, /* 1st loop */ \ + BLIS_KR }; /* microkernel loop */ \ +\ + bszid_t* restrict bszids_jc = &bszids[0]; \ + bszid_t* restrict bszids_pc = &bszids[1]; \ + /*bszid_t* restrict bszids_pb = &bszids[2];*/ \ + bszid_t* restrict bszids_ic = &bszids[3]; \ + /*bszid_t* restrict bszids_pa = &bszids[4];*/ \ + bszid_t* restrict bszids_jr = &bszids[5]; \ + /*bszid_t* restrict bszids_ir = &bszids[6];*/ \ +\ + thrinfo_t* restrict thread_jc = NULL; \ + thrinfo_t* restrict thread_pc = NULL; \ + thrinfo_t* restrict thread_pb = NULL; \ + thrinfo_t* restrict thread_ic = NULL; \ + thrinfo_t* restrict thread_pa = NULL; \ + thrinfo_t* restrict thread_jr = NULL; \ + thrinfo_t* restrict thread_ir = NULL; \ +\ + /* Identify the current thrinfo_t node and then grow the tree. */ \ + thread_jc = thread; \ + bli_thrinfo_sup_grow( rntm, bszids_jc, thread_jc ); \ +\ + /* Compute the JC loop thread range for the current thread. */ \ + dim_t jc_start, jc_end; \ + bli_thread_range_sub( thread_jc, n, NR, FALSE, &jc_start, &jc_end ); \ + const dim_t n_local = jc_end - jc_start; \ +\ + /* Compute number of primary and leftover components of the JC loop. */ \ + /*const dim_t jc_iter = ( n_local + NC - 1 ) / NC;*/ \ + const dim_t jc_left = n_local % NC; \ +\ + /* Loop over the n dimension (NC rows/columns at a time). */ \ + for ( dim_t jj = jc_start; jj < jc_end; jj += NC ) \ + { \ + /* Calculate the thread's current JC block dimension. */ \ + const dim_t nc_cur = ( NC <= jc_end - jj ? NC : jc_left ); \ +\ + ctype* restrict b_jc = b_00 + jj * jcstep_b; \ + ctype* restrict c_jc = c_00 + jj * jcstep_c; \ +\ + /* Identify the current thrinfo_t node and then grow the tree. */ \ + thread_pc = bli_thrinfo_sub_node( thread_jc ); \ + bli_thrinfo_sup_grow( rntm, bszids_pc, thread_pc ); \ +\ + /* Compute the PC loop thread range for the current thread. 
*/ \ + const dim_t pc_start = 0, pc_end = k; \ + const dim_t k_local = k; \ +\ + /* Compute number of primary and leftover components of the PC loop. */ \ + /*const dim_t pc_iter = ( k_local + KC - 1 ) / KC;*/ \ + const dim_t pc_left = k_local % KC; \ +\ + /* Loop over the k dimension (KC rows/columns at a time). */ \ + for ( dim_t pp = pc_start; pp < pc_end; pp += KC ) \ + { \ + /* Calculate the thread's current PC block dimension. */ \ + const dim_t kc_cur = ( KC <= pc_end - pp ? KC : pc_left ); \ +\ + ctype* restrict a_pc = a_00 + pp * pcstep_a; \ + ctype* restrict b_pc = b_jc + pp * pcstep_b; \ +\ + /* Only apply beta to the first iteration of the pc loop. */ \ + ctype* restrict beta_use = ( pp == 0 ? &beta_local : &one_local ); \ +\ + ctype* b_use; \ + inc_t rs_b_use, cs_b_use, ps_b_use; \ +\ + /* Identify the current thrinfo_t node. Note that the thrinfo_t + node will have already been created by a previous call to + bli_thrinfo_sup_grow() since bszid_t values of BLIS_NO_PART + cause the tree to grow by two (e.g. to the next bszid that is + a normal bszid_t value). */ \ + thread_pb = bli_thrinfo_sub_node( thread_pc ); \ + /*bli_thrinfo_sup_grow( rntm, bszids_pb, thread_pb );*/ \ +\ + /* Determine the packing buffer and related parameters for matrix + B. Then call the packm implementation. */ \ + PASTECH2(bls_,ch,packm_b) \ + ( \ + conjb, \ + KC, NC, \ + kc_cur, nc_cur, NR, \ + &one_local, \ + b_pc, rs_b, cs_b, \ + &b_use, &rs_b_use, &cs_b_use, \ + &ps_b_use, \ + cntx, \ + rntm, \ + &mem_b, \ + thread_pb \ + ); \ +\ + /* Alias b_use so that it's clear this is our current block of + matrix B. */ \ + ctype* restrict b_pc_use = b_use; \ +\ + /* Identify the current thrinfo_t node and then grow the tree. */ \ + thread_ic = bli_thrinfo_sub_node( thread_pb ); \ + bli_thrinfo_sup_grow( rntm, bszids_ic, thread_ic ); \ +\ + /* Compute the IC loop thread range for the current thread. */ \ + dim_t ic_start, ic_end; \ + bli_thread_range_sub( thread_ic, m, MR, FALSE, &ic_start, &ic_end ); \ + const dim_t m_local = ic_end - ic_start; \ +\ + /* Compute number of primary and leftover components of the IC loop. */ \ + /*const dim_t ic_iter = ( m_local + MC - 1 ) / MC;*/ \ + const dim_t ic_left = m_local % MC; \ +\ + /* Loop over the m dimension (MC rows at a time). */ \ + for ( dim_t ii = ic_start; ii < ic_end; ii += MC ) \ + { \ + /* Calculate the thread's current IC block dimension. */ \ + const dim_t mc_cur = ( MC <= ic_end - ii ? MC : ic_left ); \ +\ + ctype* restrict a_ic = a_pc + ii * icstep_a; \ + ctype* restrict c_ic = c_jc + ii * icstep_c; \ +\ + ctype* a_use; \ + inc_t rs_a_use, cs_a_use, ps_a_use; \ +\ + /* Identify the current thrinfo_t node. Note that the thrinfo_t + node will have already been created by a previous call to + bli_thrinfo_sup_grow() since bszid_t values of BLIS_NO_PART + cause the tree to grow by two (e.g. to the next bszid that is + a normal bszid_t value). */ \ + thread_pa = bli_thrinfo_sub_node( thread_ic ); \ + /*bli_thrinfo_sup_grow( rntm, bszids_pa, thread_pa );*/ \ +\ + /* Determine the packing buffer and related parameters for matrix + A. Then call the packm implementation. */ \ + PASTECH2(bls_,ch,packm_a) \ + ( \ + conja, \ + MC, KC, \ + mc_cur, kc_cur, MR, \ + &one_local, \ + a_ic, rs_a, cs_a, \ + &a_use, &rs_a_use, &cs_a_use, \ + &ps_a_use, \ + cntx, \ + rntm, \ + &mem_a, \ + thread_pa \ + ); \ +\ + /* Alias a_use so that it's clear this is our current block of + matrix A. 
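Editorial note on the beta_use logic above: C accumulates one partial product per KC-sized slice of the k dimension, so beta must scale C exactly once, on the first slice, while every later slice accumulates with a scalar of one. A toy scalar analogue (an illustration only, not part of this patch) makes the bookkeeping concrete:

#include <stdio.h>

/* Toy scalar analogue of the beta_use logic in the KC loop above:
   accumulate a dot product over k blocks, applying beta only on the
   first block and 1.0 thereafter. */
int main( void )
{
	double a[4] = { 1, 2, 3, 4 }, b[4] = { 5, 6, 7, 8 };
	double alpha = 2.0, beta = 3.0, c = 10.0;
	const int KC = 2, k = 4;

	for ( int pp = 0; pp < k; pp += KC )
	{
		double beta_use = ( pp == 0 ? beta : 1.0 );
		double partial  = 0.0;

		for ( int i = pp; i < pp + KC; i++ ) partial += a[i] * b[i];

		c = beta_use * c + alpha * partial; /* beta applied exactly once */
	}

	/* c == 3.0*10.0 + 2.0*(5 + 12 + 21 + 32) = 30 + 140 = 170 */
	printf( "c = %g\n", c );
	return 0;
}

Re-applying beta on later KC iterations would rescale partial results already stored in C, which is why the loop substitutes one_local after the first pass.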
*/ \ + ctype* restrict a_ic_use = a_use; \ +\ + /* Identify the current thrinfo_t node and then grow the tree. */ \ + thread_jr = bli_thrinfo_sub_node( thread_pa ); \ + bli_thrinfo_sup_grow( rntm, bszids_jr, thread_jr ); \ +\ + /* Query the number of threads and thread ids for the JR loop. + NOTE: These values are only needed when computing the next + micropanel of B. */ \ + const dim_t jr_nt = bli_thread_n_way( thread_jr ); \ + const dim_t jr_tid = bli_thread_work_id( thread_jr ); \ +\ + /* Compute number of primary and leftover components of the JR loop. */ \ + dim_t jr_iter = ( nc_cur + NR - 1 ) / NR; \ + dim_t jr_left = nc_cur % NR; \ +\ + /* Compute the JR loop thread range for the current thread. */ \ + dim_t jr_start, jr_end; \ + bli_thread_range_sub( thread_jr, jr_iter, 1, FALSE, &jr_start, &jr_end ); \ +\ + /* Loop over the n dimension (NR columns at a time). */ \ + for ( dim_t j = jr_start; j < jr_end; j += 1 ) \ + { \ + const dim_t nr_cur \ + = ( bli_is_not_edge_f( j, jr_iter, jr_left ) ? NR : jr_left ); \ +\ + ctype* restrict b_jr = b_pc_use + j * ps_b_use; \ + ctype* restrict c_jr = c_ic + j * jrstep_c; \ +\ + /* Assume for now that our next panel of B to be the current panel + of B. */ \ + ctype* restrict b2 = b_jr; \ +\ + /* Identify the current thrinfo_t node. */ \ + thread_ir = bli_thrinfo_sub_node( thread_jr ); \ +\ + /* Query the number of threads and thread ids for the IR loop. + NOTE: These values are only needed when computing the next + micropanel of A. */ \ + const dim_t ir_nt = bli_thread_n_way( thread_ir ); \ + const dim_t ir_tid = bli_thread_work_id( thread_ir ); \ +\ + /* Compute number of primary and leftover components of the IR loop. */ \ + dim_t ir_iter = ( mc_cur + MR - 1 ) / MR; \ + dim_t ir_left = mc_cur % MR; \ +\ + /* Compute the IR loop thread range for the current thread. */ \ + dim_t ir_start, ir_end; \ + bli_thread_range_sub( thread_ir, ir_iter, 1, FALSE, &ir_start, &ir_end ); \ +\ + /* Loop over the m dimension (MR rows at a time). */ \ + for ( dim_t i = ir_start; i < ir_end; i += 1 ) \ + { \ + const dim_t mr_cur \ + = ( bli_is_not_edge_f( i, ir_iter, ir_left ) ? MR : ir_left ); \ +\ + ctype* restrict a_ir = a_ic_use + i * ps_a_use; \ + ctype* restrict c_ir = c_jr + i * irstep_c; \ +\ + ctype* restrict a2; \ +\ + /* Compute the addresses of the next micropanels of A and B. */ \ + a2 = bli_gemm_get_next_a_upanel( a_ir, ps_a_use, 1 ); \ + if ( bli_is_last_iter( i, ir_end, ir_tid, ir_nt ) ) \ + { \ + a2 = a_ic_use; \ + b2 = bli_gemm_get_next_b_upanel( b_jr, ps_b_use, 1 ); \ + if ( bli_is_last_iter( j, jr_end, jr_tid, jr_nt ) ) \ + b2 = b_pc_use; \ + } \ +\ + /* Save the addresses of next micropanels of A and B to the + auxinfo_t object. */ \ + bli_auxinfo_set_next_a( a2, &aux ); \ + bli_auxinfo_set_next_b( b2, &aux ); \ +\ + /* Call a wrapper to the kernel (which handles edge cases). */ \ + PASTECH2(bls_,ch,gemm_kernel) \ + ( \ + MR, \ + NR, \ + mr_cur, \ + nr_cur, \ + kc_cur, \ + &alpha_local, \ + a_ir, rs_a_use, cs_a_use, \ + b_jr, rs_b_use, cs_b_use, \ + beta_use, \ + c_ir, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ + } \ + } \ + } \ +\ + /* This barrier is needed to prevent threads from starting to pack + the next row panel of B before the current row panel is fully + computed upon. */ \ + bli_thread_barrier( thread_pb ); \ + } \ + } \ +\ + /* Release any memory that was acquired for packing matrices A and B. 
*/ \ + PASTECH2(bls_,ch,packm_finalize_mem_a) \ + ( \ + rntm, \ + &mem_a, \ + thread_pa \ + ); \ + PASTECH2(bls_,ch,packm_finalize_mem_b) \ + ( \ + rntm, \ + &mem_b, \ + thread_pb \ + ); \ +\ +/* +PASTEMAC(ch,fprintm)( stdout, "gemm_bp_var2: a1_packed", mr_cur, kc_cur, a_ir, rs_a_use, cs_a_use, "%5.2f", "" ); \ +PASTEMAC(ch,fprintm)( stdout, "gemm_bp_var2: b1_packed", kc_cur, nr_cur, b_jr, rs_b_use, cs_b_use, "%5.2f", "" ); \ +PASTEMAC(ch,fprintm)( stdout, "gemm_bp_var2: c ", mr_cur, nr_cur, c_ir, rs_c, cs_c, "%5.2f", "" ); \ +*/ \ +} + +//INSERT_GENTFUNC_BASIC0( gemm_bp_var2 ) +GENTFUNC( float, s, gemm_bp_var2 ) +GENTFUNC( double, d, gemm_bp_var2 ) +GENTFUNC( scomplex, c, gemm_bp_var2 ) +GENTFUNC( dcomplex, z, gemm_bp_var2 ) + +// +// -- gemm-like microkernel wrapper -------------------------------------------- +// + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTECH2(bls_,ch,varname) \ + ( \ + const dim_t MR, \ + const dim_t NR, \ + dim_t mr_cur, \ + dim_t nr_cur, \ + dim_t kc_cur, \ + ctype* restrict alpha, \ + ctype* restrict a, inc_t rs_a, inc_t cs_a, \ + ctype* restrict b, inc_t rs_b, inc_t cs_b, \ + ctype* restrict beta, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + auxinfo_t* restrict aux, \ + cntx_t* restrict cntx \ + ) \ +{ \ + /* Infer the datatype from the ctype. */ \ + const num_t dt = PASTEMAC(ch,type); \ +\ + /* Query the context for the microkernel address and cast it to its + function pointer type. */ \ + PASTECH(ch,gemm_ukr_ft) \ + gemm_ukr = bli_cntx_get_l3_nat_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \ +\ + /* Temporary C buffer for edge cases. Note that the strides of this + temporary buffer are set so that they match the storage of the + original C matrix. For example, if C is column-stored, ct will be + column-stored as well. */ \ + ctype ct[ BLIS_STACK_BUF_MAX_SIZE \ + / sizeof( ctype ) ] \ + __attribute__((aligned(BLIS_STACK_BUF_ALIGN_SIZE))); \ + const bool col_pref = bli_cntx_l3_nat_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, cntx ); \ + const inc_t rs_ct = ( col_pref ? 1 : NR ); \ + const inc_t cs_ct = ( col_pref ? MR : 1 ); \ +\ + ctype zero = *PASTEMAC(ch,0); \ +\ + /* Handle interior and edge cases separately. */ \ + if ( mr_cur == MR && nr_cur == NR ) \ + { \ + /* Invoke the gemm microkernel. */ \ + gemm_ukr \ + ( \ + kc_cur, \ + alpha, \ + a, \ + b, \ + beta, \ + c, rs_c, cs_c, \ + aux, \ + cntx \ + ); \ + } \ + else \ + { \ + /* Invoke the gemm microkernel. */ \ + gemm_ukr \ + ( \ + kc_cur, \ + alpha, \ + a, \ + b, \ + &zero, \ + ct, rs_ct, cs_ct, \ + aux, \ + cntx \ + ); \ +\ + /* Scale the bottom edge of C and add the result from above. */ \ + PASTEMAC(ch,xpbys_mxn) \ + ( \ + mr_cur, \ + nr_cur, \ + ct, rs_ct, cs_ct, \ + beta, \ + c, rs_c, cs_c \ + ); \ + } \ +} + +//INSERT_GENTFUNC_BASIC0( gemm_kernel ) +GENTFUNC( float, s, gemm_kernel ) +GENTFUNC( double, d, gemm_kernel ) +GENTFUNC( scomplex, c, gemm_kernel ) +GENTFUNC( dcomplex, z, gemm_kernel ) + diff --git a/sandbox/gemmlike/bls_gemm_var.h b/sandbox/gemmlike/bls_gemm_var.h new file mode 100644 index 0000000000..025b54a06f --- /dev/null +++ b/sandbox/gemmlike/bls_gemm_var.h @@ -0,0 +1,124 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. 
+ + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + + +// +// Prototype the object-based variant interfaces. +// + +#undef GENPROT +#define GENPROT( opname ) \ +\ +void PASTECH(bls_,opname) \ + ( \ + obj_t* alpha, \ + obj_t* a, \ + obj_t* b, \ + obj_t* beta, \ + obj_t* c, \ + cntx_t* cntx, \ + rntm_t* rntm, \ + thrinfo_t* thread \ + ); + +GENPROT( gemm_bp_var1 ) +GENPROT( gemm_bp_var2 ) + + +// +// Prototype the typed variant interfaces. +// + +#undef GENTPROT +#define GENTPROT( ctype, ch, varname ) \ +\ +void PASTECH2(bls_,ch,varname) \ + ( \ + conj_t conja, \ + conj_t conjb, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + void* restrict alpha, \ + void* restrict a, inc_t rs_a, inc_t cs_a, \ + void* restrict b, inc_t rs_b, inc_t cs_b, \ + void* restrict beta, \ + void* restrict c, inc_t rs_c, inc_t cs_c, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + thrinfo_t* restrict thread \ + ); + +//INSERT_GENTPROT_BASIC0( gemm_bp_var1 ) +GENTPROT( float, s, gemm_bp_var1 ) +GENTPROT( double, d, gemm_bp_var1 ) +GENTPROT( scomplex, c, gemm_bp_var1 ) +GENTPROT( dcomplex, z, gemm_bp_var1 ) + +//INSERT_GENTPROT_BASIC0( gemm_bp_var2 ) +GENTPROT( float, s, gemm_bp_var2 ) +GENTPROT( double, d, gemm_bp_var2 ) +GENTPROT( scomplex, c, gemm_bp_var2 ) +GENTPROT( dcomplex, z, gemm_bp_var2 ) + + +// +// Prototype the typed kernel interfaces. 
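A note for readers new to the BLIS macro conventions used throughout these sandbox files: each GENTPROT (and GENTFUNC) invocation stamps out one datatype-specific symbol by token-pasting the bls_ prefix, the type character, and the operation name via PASTECH2. As a point of reference only (this snippet is not part of the patch, and whitespace is approximate), the double-precision instantiation of the typed variant prototype above expands to:

void bls_dgemm_bp_var2
     (
       conj_t              conja,
       conj_t              conjb,
       dim_t               m,
       dim_t               n,
       dim_t               k,
       void*      restrict alpha,
       void*      restrict a, inc_t rs_a, inc_t cs_a,
       void*      restrict b, inc_t rs_b, inc_t cs_b,
       void*      restrict beta,
       void*      restrict c, inc_t rs_c, inc_t cs_c,
       cntx_t*    restrict cntx,
       rntm_t*    restrict rntm,
       thrinfo_t* restrict thread
     );

The object-based prototypes expand analogously (e.g. bls_gemm_bp_var2()), and their signature matches the l3sbxint_t function type used by the thread decorator added later in this patch.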
+// + +#undef GENTPROT +#define GENTPROT( ctype, ch, varname ) \ +\ +void PASTECH2(bls_,ch,varname) \ + ( \ + const dim_t MR, \ + const dim_t NR, \ + dim_t mr_cur, \ + dim_t nr_cur, \ + dim_t k, \ + ctype* restrict alpha, \ + ctype* restrict a, inc_t rs_a, inc_t cs_a, \ + ctype* restrict b, inc_t rs_b, inc_t cs_b, \ + ctype* restrict beta, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + auxinfo_t* restrict aux, \ + cntx_t* restrict cntx \ + ); + +//INSERT_GENTPROT_BASIC0( gemm_kernel ) +GENTPROT( float, s, gemm_kernel ) +GENTPROT( double, d, gemm_kernel ) +GENTPROT( scomplex, c, gemm_kernel ) +GENTPROT( dcomplex, z, gemm_kernel ) + diff --git a/sandbox/gemmlike/bls_l3_packm_a.c b/sandbox/gemmlike/bls_l3_packm_a.c new file mode 100644 index 0000000000..c55a19c7b7 --- /dev/null +++ b/sandbox/gemmlike/bls_l3_packm_a.c @@ -0,0 +1,328 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTECH2(bls_,ch,opname) \ + ( \ + dim_t m, \ + dim_t k, \ + dim_t mr, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + /* Set the pack buffer type so that we are obtaining memory blocks from + the pool dedicated to blocks of A. */ \ + const packbuf_t pack_buf_type = BLIS_BUFFER_FOR_A_BLOCK; \ +\ + /* NOTE: This "rounding up" of the last upanel is absolutely necessary since + we NEED that last micropanel to have the same ldim (cs_p) as the other + micropanels. Why? Because the microkernel assumes that the register (MR, + NR) AND storage (PACKMR, PACKNR) blocksizes do not change. */ \ + const dim_t m_pack = ( m / mr + ( m % mr ? 1 : 0 ) ) * mr; \ + const dim_t k_pack = k; \ +\ + /* Barrier to make sure all threads are caught up and ready to begin the + packm stage. */ \ + bli_thread_barrier( thread ); \ +\ + /* Compute the size of the memory block eneded. 
*/ \ + siz_t size_needed = sizeof( ctype ) * m_pack * k_pack; \ +\ + /* Check the mem_t entry provided by the caller. If it is unallocated, + then we need to acquire a block from the memory broker. */ \ + if ( bli_mem_is_unalloc( mem ) ) \ + { \ + if ( bli_thread_am_ochief( thread ) ) \ + { \ + /* Acquire directly to the chief thread's mem_t that was passed in. + It needs to be that mem_t struct, and not a local (temporary) + mem_t, since there is no barrier until after packing is finished, + which could allow a race condition whereby the chief thread exits + the current function before the other threads have a chance to + copy from it. (A barrier would fix that race condition, but then + again, I prefer to keep barriers to a minimum.) */ \ + bli_membrk_acquire_m \ + ( \ + rntm, \ + size_needed, \ + pack_buf_type, \ + mem \ + ); \ + } \ +\ + /* Broadcast the address of the chief thread's passed-in mem_t to all + threads. */ \ + mem_t* mem_p = bli_thread_broadcast( thread, mem ); \ +\ + /* Non-chief threads: Copy the contents of the chief thread's + passed-in mem_t to the passed-in mem_t for this thread. (The + chief thread already has the mem_t, so it does not need to + perform any copy.) */ \ + if ( !bli_thread_am_ochief( thread ) ) \ + { \ + *mem = *mem_p; \ + } \ + } \ + else /* if ( bli_mem_is_alloc( mem ) ) */ \ + { \ + /* If the mem_t entry provided by the caller does NOT contain a NULL + buffer, then a block has already been acquired from the memory + broker and cached by the caller. */ \ +\ + /* As a sanity check, we should make sure that the mem_t object isn't + associated with a block that is too small compared to the size of + the packed matrix buffer that is needed, according to the value + computed above. */ \ + siz_t mem_size = bli_mem_size( mem ); \ +\ + if ( mem_size < size_needed ) \ + { \ + if ( bli_thread_am_ochief( thread ) ) \ + { \ + /* The chief thread releases the existing block associated + with the mem_t, and then re-acquires a new block, saving + the associated mem_t to its passed-in mem_t. (See coment + above for why the acquisition needs to be directly to + the chief thread's passed-in mem_t and not a local + (temporary) mem_t. */ \ + bli_membrk_release \ + ( \ + rntm, \ + mem \ + ); \ + bli_membrk_acquire_m \ + ( \ + rntm, \ + size_needed, \ + pack_buf_type, \ + mem \ + ); \ + } \ +\ + /* Broadcast the address of the chief thread's passed-in mem_t + to all threads. */ \ + mem_t* mem_p = bli_thread_broadcast( thread, mem ); \ +\ + /* Non-chief threads: Copy the contents of the chief thread's + passed-in mem_t to the passed-in mem_t for this thread. (The + chief thread already has the mem_t, so it does not need to + perform any copy.) */ \ + if ( !bli_thread_am_ochief( thread ) ) \ + { \ + *mem = *mem_p; \ + } \ + } \ + else \ + { \ + /* If the mem_t entry is already allocated and sufficiently large, + then we use it as-is. No action is needed. */ \ + } \ + } \ +} + +//INSERT_GENTFUNC_BASIC0( packm_init_mem_a ) +GENTFUNC( float, s, packm_init_mem_a ) +GENTFUNC( double, d, packm_init_mem_a ) +GENTFUNC( scomplex, c, packm_init_mem_a ) +GENTFUNC( dcomplex, z, packm_init_mem_a ) + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTECH2(bls_,ch,opname) \ + ( \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + if ( thread != NULL ) \ + if ( bli_thread_am_ochief( thread ) ) \ + { \ + /* Check the mem_t entry provided by the caller. Only proceed if it + is allocated, which it should be. 
*/ \ + if ( bli_mem_is_alloc( mem ) ) \ + { \ + bli_membrk_release \ + ( \ + rntm, \ + mem \ + ); \ + } \ + } \ +} + +//INSERT_GENTFUNC_BASIC0( packm_finalize_mem_a ) +GENTFUNC( float, s, packm_finalize_mem_a ) +GENTFUNC( double, d, packm_finalize_mem_a ) +GENTFUNC( scomplex, c, packm_finalize_mem_a ) +GENTFUNC( dcomplex, z, packm_finalize_mem_a ) + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTECH2(bls_,ch,opname) \ + ( \ + pack_t* restrict schema, \ + dim_t m, \ + dim_t k, \ + dim_t mr, \ + dim_t* restrict m_max, \ + dim_t* restrict k_max, \ + ctype** p, inc_t* restrict rs_p, inc_t* restrict cs_p, \ + dim_t* restrict pd_p, inc_t* restrict ps_p, \ + mem_t* restrict mem \ + ) \ +{ \ + /* NOTE: This "rounding up" of the last upanel is absolutely necessary since + we NEED that last micropanel to have the same ldim (cs_p) as the other + micropanels. Why? Because the microkernel assumes that the register (MR, + NR) AND storage (PACKMR, PACKNR) blocksizes do not change. */ \ + *m_max = ( m / mr + ( m % mr ? 1 : 0 ) ) * mr; \ + *k_max = k; \ +\ + /* Determine the dimensions and strides for the packed matrix A. */ \ + { \ + /* Pack A to column-stored row-panels. */ \ + *rs_p = 1; \ + *cs_p = mr; \ +\ + *pd_p = mr; \ + *ps_p = mr * k; \ +\ + /* Set the schema to "packed row panels" to indicate packing to + conventional column-stored row panels. */ \ + *schema = BLIS_PACKED_ROW_PANELS; \ + } \ +\ + /* Set the buffer address provided by the caller to point to the memory + associated with the mem_t entry acquired from the memory pool. */ \ + *p = bli_mem_buffer( mem ); \ +} + +//INSERT_GENTFUNC_BASIC0( packm_init_a ) +GENTFUNC( float, s, packm_init_a ) +GENTFUNC( double, d, packm_init_a ) +GENTFUNC( scomplex, c, packm_init_a ) +GENTFUNC( dcomplex, z, packm_init_a ) + + +// +// Define BLAS-like interfaces to the variant chooser. +// + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTECH2(bls_,ch,opname) \ + ( \ + conj_t conj, \ + dim_t m_alloc, \ + dim_t k_alloc, \ + dim_t m, \ + dim_t k, \ + dim_t mr, \ + ctype* restrict kappa, \ + ctype* restrict a, inc_t rs_a, inc_t cs_a, \ + ctype** restrict p, inc_t* restrict rs_p, inc_t* restrict cs_p, \ + inc_t* restrict ps_p, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + pack_t schema; \ + dim_t m_max; \ + dim_t k_max; \ + dim_t pd_p; \ +\ + /* Prepare the packing destination buffer. */ \ + PASTECH2(bls_,ch,packm_init_mem_a) \ + ( \ + m_alloc, k_alloc, mr, \ + cntx, \ + rntm, \ + mem, \ + thread \ + ); \ +\ + /* Determine the packing buffer and related parameters for matrix A. */ \ + PASTECH2(bls_,ch,packm_init_a) \ + ( \ + &schema, \ + m, k, mr, \ + &m_max, &k_max, \ + p, rs_p, cs_p, \ + &pd_p, ps_p, \ + mem \ + ); \ +\ + /* Pack matrix A to the destination buffer chosen above. Here, the packed + matrix is stored to column-stored MR x k micropanels. */ \ + PASTECH2(bls_,ch,packm_var1) \ + ( \ + conj, \ + schema, \ + m, \ + k, \ + m_max, \ + k_max, \ + kappa, \ + a, rs_a, cs_a, \ + *p, *rs_p, *cs_p, \ + pd_p, *ps_p, \ + cntx, \ + thread \ + ); \ +\ + /* Barrier so that packing is done before computation. 
*/ \ + bli_thread_barrier( thread ); \ +} + +//INSERT_GENTFUNC_BASIC0( packm_a ) +GENTFUNC( float, s, packm_a ) +GENTFUNC( double, d, packm_a ) +GENTFUNC( scomplex, c, packm_a ) +GENTFUNC( dcomplex, z, packm_a ) + diff --git a/sandbox/gemmlike/bls_l3_packm_a.h b/sandbox/gemmlike/bls_l3_packm_a.h new file mode 100644 index 0000000000..201a24efae --- /dev/null +++ b/sandbox/gemmlike/bls_l3_packm_a.h @@ -0,0 +1,122 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
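The stride choices made in packm_init_a above (rs_p = 1, cs_p = mr, pd_p = mr, ps_p = mr * k, with schema BLIS_PACKED_ROW_PANELS) lay the packed block of A out as a sequence of column-stored MR x k micropanels. The sketch below is editorial, not part of the patch, and the helper name is hypothetical; it only shows how an element (i, p) of the packed block would be addressed under those strides. The B packing routines added later in the patch mirror this with row-stored k x NR panels (rs_p = nr, cs_p = 1, ps_p = k * nr).

#include "blis.h"

/* Hypothetical helper: locate element (i,p) of a packed block of A laid
   out as column-stored MR x k micropanels, using the strides chosen by
   bls_?packm_init_a() (rs_p = 1, cs_p = mr, ps_p = mr * k). */
static double* packed_a_elem( double* p_buf, dim_t i, dim_t p,
                              dim_t mr, inc_t ps_p )
{
	const dim_t panel  = i / mr; /* which MR x k micropanel           */
	const dim_t within = i % mr; /* row offset inside that micropanel */

	/* With rs_p = 1 and cs_p = mr, element (within, p) of the panel
	   lives at offset within*1 + p*mr from the start of the panel. */
	return p_buf + panel * ps_p + p * mr + within;
}

The "rounding up" of the final micropanel to a full multiple of mr, noted in the comments above, is what keeps ps_p and cs_p uniform across panels, so this arithmetic never needs a special case for the edge panel.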
+ +*/ + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +void PASTECH2(bls_,ch,opname) \ + ( \ + dim_t m, \ + dim_t k, \ + dim_t mr, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ); \ + +//INSERT_GENTPROT_BASIC0( packm_init_mem_a ) +GENTPROT( float, s, packm_init_mem_a ) +GENTPROT( double, d, packm_init_mem_a ) +GENTPROT( scomplex, c, packm_init_mem_a ) +GENTPROT( dcomplex, z, packm_init_mem_a ) + + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +void PASTECH2(bls_,ch,opname) \ + ( \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ); \ + +//INSERT_GENTPROT_BASIC0( packm_finalize_mem_a ) +GENTPROT( float, s, packm_finalize_mem_a ) +GENTPROT( double, d, packm_finalize_mem_a ) +GENTPROT( scomplex, c, packm_finalize_mem_a ) +GENTPROT( dcomplex, z, packm_finalize_mem_a ) + + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +void PASTECH2(bls_,ch,opname) \ + ( \ + pack_t* restrict schema, \ + dim_t m, \ + dim_t k, \ + dim_t mr, \ + dim_t* restrict m_max, \ + dim_t* restrict k_max, \ + ctype** p, inc_t* restrict rs_p, inc_t* restrict cs_p, \ + dim_t* restrict pd_p, inc_t* restrict ps_p, \ + mem_t* restrict mem \ + ); \ + +//INSERT_GENTPROT_BASIC0( packm_init_a ) +GENTPROT( float, s, packm_init_a ) +GENTPROT( double, d, packm_init_a ) +GENTPROT( scomplex, c, packm_init_a ) +GENTPROT( dcomplex, z, packm_init_a ) + + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +void PASTECH2(bls_,ch,opname) \ + ( \ + conj_t conj, \ + dim_t m_alloc, \ + dim_t k_alloc, \ + dim_t m, \ + dim_t k, \ + dim_t mr, \ + ctype* restrict kappa, \ + ctype* restrict a, inc_t rs_a, inc_t cs_a, \ + ctype** restrict p, inc_t* restrict rs_p, inc_t* restrict cs_p, \ + inc_t* restrict ps_p, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ); \ + +//INSERT_GENTPROT_BASIC0( packm_a ) +GENTPROT( float, s, packm_a ) +GENTPROT( double, d, packm_a ) +GENTPROT( scomplex, c, packm_a ) +GENTPROT( dcomplex, z, packm_a ) + diff --git a/sandbox/gemmlike/bls_l3_packm_b.c b/sandbox/gemmlike/bls_l3_packm_b.c new file mode 100644 index 0000000000..cae93df012 --- /dev/null +++ b/sandbox/gemmlike/bls_l3_packm_b.c @@ -0,0 +1,328 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTECH2(bls_,ch,opname) \ + ( \ + dim_t k, \ + dim_t n, \ + dim_t nr, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + /* Set the pack buffer type so that we are obtaining memory blocks from + the pool dedicated to panels of B. */ \ + const packbuf_t pack_buf_type = BLIS_BUFFER_FOR_B_PANEL; \ +\ + /* NOTE: This "rounding up" of the last upanel is absolutely necessary since + we NEED that last micropanel to have the same ldim (cs_p) as the other + micropanels. Why? Because the microkernel assumes that the register (MR, + NR) AND storage (PACKMR, PACKNR) blocksizes do not change. */ \ + const dim_t k_pack = k; \ + const dim_t n_pack = ( n / nr + ( n % nr ? 1 : 0 ) ) * nr; \ +\ + /* Barrier to make sure all threads are caught up and ready to begin the + packm stage. */ \ + bli_thread_barrier( thread ); \ +\ + /* Compute the size of the memory block eneded. */ \ + siz_t size_needed = sizeof( ctype ) * k_pack * n_pack; \ +\ + /* Check the mem_t entry provided by the caller. If it is unallocated, + then we need to acquire a block from the memory broker. */ \ + if ( bli_mem_is_unalloc( mem ) ) \ + { \ + if ( bli_thread_am_ochief( thread ) ) \ + { \ + /* Acquire directly to the chief thread's mem_t that was passed in. + It needs to be that mem_t struct, and not a local (temporary) + mem_t, since there is no barrier until after packing is finished, + which could allow a race condition whereby the chief thread exits + the current function before the other threads have a chance to + copy from it. (A barrier would fix that race condition, but then + again, I prefer to keep barriers to a minimum.) */ \ + bli_membrk_acquire_m \ + ( \ + rntm, \ + size_needed, \ + pack_buf_type, \ + mem \ + ); \ + } \ +\ + /* Broadcast the address of the chief thread's passed-in mem_t to all + threads. */ \ + mem_t* mem_p = bli_thread_broadcast( thread, mem ); \ +\ + /* Non-chief threads: Copy the contents of the chief thread's + passed-in mem_t to the passed-in mem_t for this thread. (The + chief thread already has the mem_t, so it does not need to + perform any copy.) */ \ + if ( !bli_thread_am_ochief( thread ) ) \ + { \ + *mem = *mem_p; \ + } \ + } \ + else /* if ( bli_mem_is_alloc( mem ) ) */ \ + { \ + /* If the mem_t entry provided by the caller does NOT contain a NULL + buffer, then a block has already been acquired from the memory + broker and cached by the caller. */ \ +\ + /* As a sanity check, we should make sure that the mem_t object isn't + associated with a block that is too small compared to the size of + the packed matrix buffer that is needed, according to the value + computed above. 
*/ \ + siz_t mem_size = bli_mem_size( mem ); \ +\ + if ( mem_size < size_needed ) \ + { \ + if ( bli_thread_am_ochief( thread ) ) \ + { \ + /* The chief thread releases the existing block associated + with the mem_t, and then re-acquires a new block, saving + the associated mem_t to its passed-in mem_t. (See coment + above for why the acquisition needs to be directly to + the chief thread's passed-in mem_t and not a local + (temporary) mem_t. */ \ + bli_membrk_release \ + ( \ + rntm, \ + mem \ + ); \ + bli_membrk_acquire_m \ + ( \ + rntm, \ + size_needed, \ + pack_buf_type, \ + mem \ + ); \ + } \ +\ + /* Broadcast the address of the chief thread's passed-in mem_t + to all threads. */ \ + mem_t* mem_p = bli_thread_broadcast( thread, mem ); \ +\ + /* Non-chief threads: Copy the contents of the chief thread's + passed-in mem_t to the passed-in mem_t for this thread. (The + chief thread already has the mem_t, so it does not need to + perform any copy.) */ \ + if ( !bli_thread_am_ochief( thread ) ) \ + { \ + *mem = *mem_p; \ + } \ + } \ + else \ + { \ + /* If the mem_t entry is already allocated and sufficiently large, + then we use it as-is. No action is needed. */ \ + } \ + } \ +} + +//INSERT_GENTFUNC_BASIC0( packm_init_mem_b ) +GENTFUNC( float, s, packm_init_mem_b ) +GENTFUNC( double, d, packm_init_mem_b ) +GENTFUNC( scomplex, c, packm_init_mem_b ) +GENTFUNC( dcomplex, z, packm_init_mem_b ) + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTECH2(bls_,ch,opname) \ + ( \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + if ( thread != NULL ) \ + if ( bli_thread_am_ochief( thread ) ) \ + { \ + /* Check the mem_t entry provided by the caller. Only proceed if it + is allocated, which it should be. */ \ + if ( bli_mem_is_alloc( mem ) ) \ + { \ + bli_membrk_release \ + ( \ + rntm, \ + mem \ + ); \ + } \ + } \ +} + +//INSERT_GENTFUNC_BASIC0( packm_finalize_mem_b ) +GENTFUNC( float, s, packm_finalize_mem_b ) +GENTFUNC( double, d, packm_finalize_mem_b ) +GENTFUNC( scomplex, c, packm_finalize_mem_b ) +GENTFUNC( dcomplex, z, packm_finalize_mem_b ) + + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTECH2(bls_,ch,opname) \ + ( \ + pack_t* restrict schema, \ + dim_t k, \ + dim_t n, \ + dim_t nr, \ + dim_t* restrict k_max, \ + dim_t* restrict n_max, \ + ctype** p, inc_t* restrict rs_p, inc_t* restrict cs_p, \ + dim_t* restrict pd_p, inc_t* restrict ps_p, \ + mem_t* restrict mem \ + ) \ +{ \ + /* NOTE: This "rounding up" of the last upanel is absolutely necessary since + we NEED that last micropanel to have the same ldim (cs_p) as the other + micropanels. Why? Because the microkernel assumes that the register (MR, + NR) AND storage (PACKMR, PACKNR) blocksizes do not change. */ \ + *k_max = k; \ + *n_max = ( n / nr + ( n % nr ? 1 : 0 ) ) * nr; \ +\ + /* Determine the dimensions and strides for the packed matrix B. */ \ + { \ + /* Pack B to row-stored column-panels. */ \ + *rs_p = nr; \ + *cs_p = 1; \ +\ + *pd_p = nr; \ + *ps_p = k * nr; \ +\ + /* Set the schema to "packed column panels" to indicate packing to + conventional row-stored column panels. */ \ + *schema = BLIS_PACKED_COL_PANELS; \ + } \ +\ + /* Set the buffer address provided by the caller to point to the memory + associated with the mem_t entry acquired from the memory pool. 
*/ \ + *p = bli_mem_buffer( mem ); \ +} + +//INSERT_GENTFUNC_BASIC0( packm_init_b ) +GENTFUNC( float, s, packm_init_b ) +GENTFUNC( double, d, packm_init_b ) +GENTFUNC( scomplex, c, packm_init_b ) +GENTFUNC( dcomplex, z, packm_init_b ) + + +// +// Define BLAS-like interfaces to the variant chooser. +// + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, opname ) \ +\ +void PASTECH2(bls_,ch,opname) \ + ( \ + conj_t conj, \ + dim_t k_alloc, \ + dim_t n_alloc, \ + dim_t k, \ + dim_t n, \ + dim_t nr, \ + ctype* restrict kappa, \ + ctype* restrict b, inc_t rs_b, inc_t cs_b, \ + ctype** restrict p, inc_t* restrict rs_p, inc_t* restrict cs_p, \ + inc_t* restrict ps_p, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + pack_t schema; \ + dim_t k_max; \ + dim_t n_max; \ + dim_t pd_p; \ +\ + /* Prepare the packing destination buffer. */ \ + PASTECH2(bls_,ch,packm_init_mem_b) \ + ( \ + k_alloc, n_alloc, nr, \ + cntx, \ + rntm, \ + mem, \ + thread \ + ); \ +\ + /* Determine the packing buffer and related parameters for matrix B. */ \ + PASTECH2(bls_,ch,packm_init_b) \ + ( \ + &schema, \ + k, n, nr, \ + &k_max, &n_max, \ + p, rs_p, cs_p, \ + &pd_p, ps_p, \ + mem \ + ); \ +\ + /* Pack matrix B to the destination buffer chosen above. Here, the packed + matrix is stored to row-stored k x NR micropanels. */ \ + PASTECH2(bls_,ch,packm_var1) \ + ( \ + conj, \ + schema, \ + k, \ + n, \ + k_max, \ + n_max, \ + kappa, \ + b, rs_b, cs_b, \ + *p, *rs_p, *cs_p, \ + pd_p, *ps_p, \ + cntx, \ + thread \ + ); \ +\ + /* Barrier so that packing is done before computation. */ \ + bli_thread_barrier( thread ); \ +} + +//INSERT_GENTFUNC_BASIC0( packm_b ) +GENTFUNC( float, s, packm_b ) +GENTFUNC( double, d, packm_b ) +GENTFUNC( scomplex, c, packm_b ) +GENTFUNC( dcomplex, z, packm_b ) + diff --git a/sandbox/gemmlike/bls_l3_packm_b.h b/sandbox/gemmlike/bls_l3_packm_b.h new file mode 100644 index 0000000000..728d21aed5 --- /dev/null +++ b/sandbox/gemmlike/bls_l3_packm_b.h @@ -0,0 +1,122 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +void PASTECH2(bls_,ch,opname) \ + ( \ + dim_t k, \ + dim_t n, \ + dim_t nr, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ); \ + +//INSERT_GENTPROT_BASIC0( packm_init_mem_b ) +GENTPROT( float, s, packm_init_mem_b ) +GENTPROT( double, d, packm_init_mem_b ) +GENTPROT( scomplex, c, packm_init_mem_b ) +GENTPROT( dcomplex, z, packm_init_mem_b ) + + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +void PASTECH2(bls_,ch,opname) \ + ( \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ); \ + +//INSERT_GENTPROT_BASIC0( packm_finalize_mem_b ) +GENTPROT( float, s, packm_finalize_mem_b ) +GENTPROT( double, d, packm_finalize_mem_b ) +GENTPROT( scomplex, c, packm_finalize_mem_b ) +GENTPROT( dcomplex, z, packm_finalize_mem_b ) + + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +void PASTECH2(bls_,ch,opname) \ + ( \ + pack_t* restrict schema, \ + dim_t k, \ + dim_t n, \ + dim_t nr, \ + dim_t* restrict k_max, \ + dim_t* restrict n_max, \ + ctype** p, inc_t* restrict rs_p, inc_t* restrict cs_p, \ + dim_t* restrict pd_p, inc_t* restrict ps_p, \ + mem_t* restrict mem \ + ); \ + +//INSERT_GENTPROT_BASIC0( packm_init_b ) +GENTPROT( float, s, packm_init_b ) +GENTPROT( double, d, packm_init_b ) +GENTPROT( scomplex, c, packm_init_b ) +GENTPROT( dcomplex, z, packm_init_b ) + + +#undef GENTPROT +#define GENTPROT( ctype, ch, opname ) \ +\ +void PASTECH2(bls_,ch,opname) \ + ( \ + conj_t conj, \ + dim_t k_alloc, \ + dim_t n_alloc, \ + dim_t k, \ + dim_t n, \ + dim_t nr, \ + ctype* restrict kappa, \ + ctype* restrict b, inc_t rs_b, inc_t cs_b, \ + ctype** restrict p, inc_t* restrict rs_p, inc_t* restrict cs_p, \ + inc_t* restrict ps_p, \ + cntx_t* restrict cntx, \ + rntm_t* restrict rntm, \ + mem_t* restrict mem, \ + thrinfo_t* restrict thread \ + ); \ + +//INSERT_GENTPROT_BASIC0( packm_b ) +GENTPROT( float, s, packm_b ) +GENTPROT( double, d, packm_b ) +GENTPROT( scomplex, c, packm_b ) +GENTPROT( dcomplex, z, packm_b ) + diff --git a/sandbox/gemmlike/bls_l3_packm_var.c b/sandbox/gemmlike/bls_l3_packm_var.c new file mode 100644 index 0000000000..8a4c1d0206 --- /dev/null +++ b/sandbox/gemmlike/bls_l3_packm_var.c @@ -0,0 +1,198 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. 
+ - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +// +// Define BLAS-like interfaces to the variants. +// + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTECH2(bls_,ch,varname) \ + ( \ + trans_t transc, \ + pack_t schema, \ + dim_t m, \ + dim_t n, \ + dim_t m_max, \ + dim_t n_max, \ + ctype* restrict kappa, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + ctype* restrict p, inc_t rs_p, inc_t cs_p, \ + dim_t pd_p, inc_t ps_p, \ + cntx_t* restrict cntx, \ + thrinfo_t* restrict thread \ + ) \ +{ \ + ctype* restrict kappa_cast = kappa; \ + ctype* restrict c_cast = c; \ + ctype* restrict p_cast = p; \ +\ + dim_t iter_dim; \ + dim_t n_iter; \ + dim_t it, ic; \ + dim_t ic0; \ + doff_t ic_inc; \ + dim_t panel_len_full; \ + dim_t panel_len_i; \ + dim_t panel_len_max; \ + dim_t panel_len_max_i; \ + dim_t panel_dim_i; \ + dim_t panel_dim_max; \ + inc_t vs_c; \ + inc_t ldc; \ + inc_t ldp; \ + conj_t conjc; \ +\ +\ + /* Extract the conjugation bit from the transposition argument. */ \ + conjc = bli_extract_conj( transc ); \ +\ + /* Create flags to incidate row or column storage. Note that the + schema bit that encodes row or column is describing the form of + micro-panel, not the storage in the micro-panel. Hence the + mismatch in "row" and "column" semantics. */ \ + bool row_stored = bli_is_col_packed( schema ); \ + /*bool col_stored = bli_is_row_packed( schema );*/ \ +\ + /* If the row storage flag indicates row storage, then we are packing + to column panels; otherwise, if the strides indicate column storage, + we are packing to row panels. */ \ + if ( row_stored ) \ + { \ + /* Prepare to pack to row-stored column panels. */ \ + iter_dim = n; \ + panel_len_full = m; \ + panel_len_max = m_max; \ + panel_dim_max = pd_p; \ + vs_c = cs_c; \ + ldc = rs_c; \ + ldp = rs_p; \ + } \ + else /* if ( col_stored ) */ \ + { \ + /* Prepare to pack to column-stored row panels. */ \ + iter_dim = m; \ + panel_len_full = n; \ + panel_len_max = n_max; \ + panel_dim_max = pd_p; \ + vs_c = rs_c; \ + ldc = cs_c; \ + ldp = cs_p; \ + } \ +\ + /* Compute the total number of iterations we'll need. */ \ + n_iter = iter_dim / panel_dim_max + ( iter_dim % panel_dim_max ? 1 : 0 ); \ +\ + /* Set the initial values and increments for indices related to C and P + based on whether reverse iteration was requested. */ \ + { \ + ic0 = 0; \ + ic_inc = panel_dim_max; \ + } \ +\ + ctype* restrict p_begin = p_cast; \ +\ + /* Query the number of threads and thread ids from the current thread's + packm thrinfo_t node. 
*/ \ + const dim_t nt = bli_thread_n_way( thread ); \ + const dim_t tid = bli_thread_work_id( thread ); \ +\ + /* Suppress warnings in case tid isn't used (ie: as in slab partitioning). */ \ + ( void )nt; \ + ( void )tid; \ +\ + dim_t it_start, it_end, it_inc; \ +\ + /* Determine the thread range and increment using the current thread's + packm thrinfo_t node. NOTE: The definition of bli_thread_range_jrir() + will depend on whether slab or round-robin partitioning was requested + at configure-time. */ \ + bli_thread_range_jrir( thread, n_iter, 1, FALSE, &it_start, &it_end, &it_inc ); \ +\ + /* Iterate over every logical micropanel in the source matrix. */ \ + for ( ic = ic0, it = 0; it < n_iter; \ + ic += ic_inc, it += 1 ) \ + { \ + panel_dim_i = bli_min( panel_dim_max, iter_dim - ic ); \ +\ + ctype* restrict c_begin = c_cast + (ic )*vs_c; \ +\ + ctype* restrict c_use = c_begin; \ + ctype* restrict p_use = p_begin; \ +\ + panel_len_i = panel_len_full; \ + panel_len_max_i = panel_len_max; \ +\ + /* The definition of bli_packm_my_iter() will depend on whether slab + or round-robin partitioning was requested at configure-time. (The + default is slab.) */ \ + if ( bli_packm_my_iter( it, it_start, it_end, tid, nt ) ) \ + { \ + PASTEMAC(ch,packm_cxk) \ + ( \ + conjc, \ + schema, \ + panel_dim_i, \ + panel_dim_max, \ + panel_len_i, \ + panel_len_max_i, \ + kappa_cast, \ + c_use, vs_c, ldc, \ + p_use, ldp, \ + cntx \ + ); \ + } \ +\ + p_begin += ps_p; \ +\ +/* +if ( row_stored ) \ +PASTEMAC(ch,fprintm)( stdout, "packm_sup_var1: b packed", panel_len_max, panel_dim_max, \ + p_use, rs_p, cs_p, "%5.2f", "" ); \ +if ( !row_stored ) \ +PASTEMAC(ch,fprintm)( stdout, "packm_sup_var1: a packed", panel_dim_max, panel_len_max, \ + p_use, rs_p, cs_p, "%5.2f", "" ); \ +*/ \ + } \ +} + +//INSERT_GENTFUNC_BASIC0( packm_var1 ) +GENTFUNC( float, s, packm_var1 ) +GENTFUNC( double, d, packm_var1 ) +GENTFUNC( scomplex, c, packm_var1 ) +GENTFUNC( dcomplex, z, packm_var1 ) + diff --git a/sandbox/gemmlike/bls_l3_packm_var.h b/sandbox/gemmlike/bls_l3_packm_var.h new file mode 100644 index 0000000000..0e8eb9ee8a --- /dev/null +++ b/sandbox/gemmlike/bls_l3_packm_var.h @@ -0,0 +1,63 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +// +// Prototype BLAS-like interfaces to the variants. +// + +#undef GENTPROT +#define GENTPROT( ctype, ch, varname ) \ +\ +void PASTECH2(bls_,ch,varname) \ + ( \ + trans_t transc, \ + pack_t schema, \ + dim_t m, \ + dim_t n, \ + dim_t m_max, \ + dim_t n_max, \ + ctype* restrict kappa, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + ctype* restrict p, inc_t rs_p, inc_t cs_p, \ + dim_t pd_p, inc_t ps_p, \ + cntx_t* restrict cntx, \ + thrinfo_t* restrict thread \ + ); + +//INSERT_GENTPROT_BASIC0( packm_var1 ) +GENTPROT( float, s, packm_var1 ) +GENTPROT( double, d, packm_var1 ) +GENTPROT( scomplex, c, packm_var1 ) +GENTPROT( dcomplex, z, packm_var1 ) + diff --git a/sandbox/gemmlike/thread/bls_l3_decor.h b/sandbox/gemmlike/thread/bls_l3_decor.h new file mode 100644 index 0000000000..bb8a95bb46 --- /dev/null +++ b/sandbox/gemmlike/thread/bls_l3_decor.h @@ -0,0 +1,73 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_SBX_L3_DECOR_H +#define BLIS_SBX_L3_DECOR_H + +// -- sup definitions ---------------------------------------------------------- + +// Level-3 sup internal function type. +typedef void (*l3sbxint_t) + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm, + thrinfo_t* thread + ); + +// Level-3 sup thread decorator prototype. 
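In bls_?packm_var1() above, micropanel iterations are divided among threads by bli_thread_range_jrir() and bli_packm_my_iter(), whose behavior (slab versus round-robin) is fixed at configure time, with slab as the default. Purely as an illustration of the slab case (this sketch is not part of the patch and is not the actual BLIS implementation), a contiguous split of n_iter iterations across nt threads could be computed like this:

#include "blis.h"

/* Hypothetical slab split: thread tid gets a contiguous range of the
   n_iter micropanel iterations, with the first (n_iter % nt) threads
   receiving one extra iteration. Under round-robin partitioning,
   bls_?packm_var1() instead filters iterations with a test of the
   form ( it % nt == tid ). */
static void slab_range( dim_t n_iter, dim_t nt, dim_t tid,
                        dim_t* start, dim_t* end )
{
	const dim_t len = n_iter / nt; /* base iterations per thread */
	const dim_t rem = n_iter % nt; /* leftover iterations        */

	*start = tid * len + ( tid < rem ? tid : rem );
	*end   = *start + len + ( tid < rem ? 1 : 0 );
}

Either way, each micropanel is packed by exactly one thread, and the barrier at the end of bls_?packm_a()/bls_?packm_b() ensures the full packed block is visible to all threads before the compute loops read from it.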
+void bls_l3_thread_decorator + ( + l3sbxint_t func, + opid_t family, + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm + ); + +// Include definitions specific to the method of multithreading. +#include "bls_l3_decor_single.h" +#include "bls_l3_decor_openmp.h" +#include "bls_l3_decor_pthreads.h" + +#endif + diff --git a/sandbox/gemmlike/thread/bls_l3_decor_openmp.c b/sandbox/gemmlike/thread/bls_l3_decor_openmp.c new file mode 100644 index 0000000000..851a29e52b --- /dev/null +++ b/sandbox/gemmlike/thread/bls_l3_decor_openmp.c @@ -0,0 +1,138 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#ifdef BLIS_ENABLE_OPENMP + +// Define a dummy thread entry function, which is needed in the pthreads +// version, so that when building Windows DLLs (with OpenMP enabled or with +// no multithreading) we don't risk having an unresolved symbol. +void* bls_l3_thread_entry( void* data_void ) { return NULL; } + +//#define PRINT_THRINFO + +void bls_l3_thread_decorator + ( + l3sbxint_t func, + opid_t family, + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm + ) +{ + // Query the total number of threads from the rntm_t object. + const dim_t n_threads = bli_rntm_num_threads( rntm ); + + // NOTE: The sba was initialized in bli_init(). + + // Check out an array_t from the small block allocator. This is done + // with an internal lock to ensure only one application thread accesses + // the sba at a time. bli_sba_checkout_array() will also automatically + // resize the array_t, if necessary. + array_t* restrict array = bli_sba_checkout_array( n_threads ); + + // Access the pool_t* for thread 0 and embed it into the rntm. We do + // this up-front only so that we have the rntm_t.sba_pool field + // initialized and ready for the global communicator creation below. 
+	bli_sba_rntm_set_pool( 0, array, rntm );
+
+	// Set the packing block allocator field of the rntm. This will be
+	// inherited by all of the child threads when they make local copies of
+	// the rntm below.
+	bli_membrk_rntm_set_membrk( rntm );
+
+	// Allocate a global communicator for the root thrinfo_t structures.
+	thrcomm_t* restrict gl_comm = bli_thrcomm_create( rntm, n_threads );
+
+
+	_Pragma( "omp parallel num_threads(n_threads)" )
+	{
+		// Create a thread-local copy of the master thread's rntm_t. This is
+		// necessary since we want each thread to be able to track its own
+		// small block pool_t as it executes down the function stack.
+		rntm_t rntm_l = *rntm;
+		rntm_t* restrict rntm_p = &rntm_l;
+
+		// Query the thread's id from OpenMP.
+		const dim_t tid = omp_get_thread_num();
+
+		// Check for a somewhat obscure OpenMP thread-mismatch issue.
+		// NOTE: This calls the same function used for the conventional/large
+		// code path.
+		bli_l3_thread_decorator_thread_check( n_threads, tid, gl_comm, rntm_p );
+
+		// Use the thread id to access the appropriate pool_t* within the
+		// array_t, and use it to set the sba_pool field within the rntm_t.
+		// If the pool_t* element within the array_t is NULL, it will first
+		// be allocated/initialized.
+		bli_sba_rntm_set_pool( tid, array, rntm_p );
+
+		thrinfo_t* thread = NULL;
+
+		// Create the root node of the thread's thrinfo_t structure.
+		bli_l3_sup_thrinfo_create_root( tid, gl_comm, rntm_p, &thread );
+
+		func
+		(
+		  alpha,
+		  a,
+		  b,
+		  beta,
+		  c,
+		  cntx,
+		  rntm_p,
+		  thread
+		);
+
+		// Free the current thread's thrinfo_t structure.
+		bli_l3_sup_thrinfo_free( rntm_p, thread );
+	}
+
+	// We shouldn't free the global communicator since it was already freed
+	// by the global communicator's chief thread in bli_l3_thrinfo_free()
+	// (called from the thread entry function).
+
+	// Check the array_t back into the small block allocator. Similar to the
+	// check-out, this is done using a lock embedded within the sba to ensure
+	// mutual exclusion.
+	bli_sba_checkin_array( array );
+}
+
+#endif
+
diff --git a/sandbox/gemmlike/thread/bls_l3_decor_openmp.h b/sandbox/gemmlike/thread/bls_l3_decor_openmp.h
new file mode 100644
index 0000000000..9c956d7c36
--- /dev/null
+++ b/sandbox/gemmlike/thread/bls_l3_decor_openmp.h
@@ -0,0 +1,44 @@
+/*
+
+   BLIS
+   An object-based framework for developing high-performance BLAS-like
+   libraries.
+
+   Copyright (C) 2021, The University of Texas at Austin
+
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions are
+   met:
+    - Redistributions of source code must retain the above copyright
+      notice, this list of conditions and the following disclaimer.
+    - Redistributions in binary form must reproduce the above copyright
+      notice, this list of conditions and the following disclaimer in the
+      documentation and/or other materials provided with the distribution.
+    - Neither the name(s) of the copyright holder(s) nor the names of its
+      contributors may be used to endorse or promote products derived
+      from this software without specific prior written permission.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_SBX_L3_DECOR_OPENMP_H +#define BLIS_SBX_L3_DECOR_OPENMP_H + +// Definitions specific to situations when OpenMP multithreading is enabled. +#ifdef BLIS_ENABLE_OPENMP + +#endif + +#endif + diff --git a/sandbox/gemmlike/thread/bls_l3_decor_pthreads.c b/sandbox/gemmlike/thread/bls_l3_decor_pthreads.c new file mode 100644 index 0000000000..f87d79fd6c --- /dev/null +++ b/sandbox/gemmlike/thread/bls_l3_decor_pthreads.c @@ -0,0 +1,213 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#ifdef BLIS_ENABLE_PTHREADS + +// A data structure to assist in passing operands to additional threads. +typedef struct thread_data +{ + l3sbxint_t func; + opid_t family; + obj_t* alpha; + obj_t* a; + obj_t* b; + obj_t* beta; + obj_t* c; + cntx_t* cntx; + rntm_t* rntm; + dim_t tid; + thrcomm_t* gl_comm; + array_t* array; +} thread_data_t; + +// Entry point function for additional threads. +void* bls_l3_thread_entry( void* data_void ) +{ + thread_data_t* data = data_void; + + l3sbxint_t func = data->func; + opid_t family = data->family; + obj_t* alpha = data->alpha; + obj_t* a = data->a; + obj_t* b = data->b; + obj_t* beta = data->beta; + obj_t* c = data->c; + cntx_t* cntx = data->cntx; + rntm_t* rntm = data->rntm; + dim_t tid = data->tid; + array_t* array = data->array; + thrcomm_t* gl_comm = data->gl_comm; + + ( void )family; + + // Create a thread-local copy of the master thread's rntm_t. 
This is + // necessary since we want each thread to be able to track its own + // small block pool_t as it executes down the function stack. + rntm_t rntm_l = *rntm; + rntm_t* restrict rntm_p = &rntm_l; + + // Use the thread id to access the appropriate pool_t* within the + // array_t, and use it to set the sba_pool field within the rntm_t. + // If the pool_t* element within the array_t is NULL, it will first + // be allocated/initialized. + bli_sba_rntm_set_pool( tid, array, rntm_p ); + + thrinfo_t* thread = NULL; + + // Create the root node of the current thread's thrinfo_t structure. + bli_l3_sup_thrinfo_create_root( tid, gl_comm, rntm_p, &thread ); + + func + ( + alpha, + a, + b, + beta, + c, + cntx, + rntm_p, + thread + ); + + // Free the current thread's thrinfo_t structure. + bli_l3_sup_thrinfo_free( rntm_p, thread ); + + return NULL; +} + +void bls_l3_thread_decorator + ( + l3sbxint_t func, + opid_t family, + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm + ) +{ + // Query the total number of threads from the context. + const dim_t n_threads = bli_rntm_num_threads( rntm ); + + // NOTE: The sba was initialized in bli_init(). + + // Check out an array_t from the small block allocator. This is done + // with an internal lock to ensure only one application thread accesses + // the sba at a time. bli_sba_checkout_array() will also automatically + // resize the array_t, if necessary. + array_t* restrict array = bli_sba_checkout_array( n_threads ); + + // Access the pool_t* for thread 0 and embed it into the rntm. We do + // this up-front only so that we have the rntm_t.sba_pool field + // initialized and ready for the global communicator creation below. + bli_sba_rntm_set_pool( 0, array, rntm ); + + // Set the packing block allocator field of the rntm. This will be + // inherited by all of the child threads when they make local copies of + // the rntm below. + bli_membrk_rntm_set_membrk( rntm ); + + // Allocate a global communicator for the root thrinfo_t structures. + thrcomm_t* restrict gl_comm = bli_thrcomm_create( rntm, n_threads ); + + // Allocate an array of pthread objects and auxiliary data structs to pass + // to the thread entry functions. + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_l3_thread_decorator().pth: " ); + #endif + bli_pthread_t* pthreads = bli_malloc_intl( sizeof( bli_pthread_t ) * n_threads ); + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_l3_thread_decorator().pth: " ); + #endif + thread_data_t* datas = bli_malloc_intl( sizeof( thread_data_t ) * n_threads ); + + // NOTE: We must iterate backwards so that the chief thread (thread id 0) + // can spawn all other threads before proceeding with its own computation. + for ( dim_t tid = n_threads - 1; 0 <= tid; tid-- ) + { + // Set up thread data for additional threads (beyond thread 0). + datas[tid].func = func; + datas[tid].family = family; + datas[tid].alpha = alpha; + datas[tid].a = a; + datas[tid].b = b; + datas[tid].beta = beta; + datas[tid].c = c; + datas[tid].cntx = cntx; + datas[tid].rntm = rntm; + datas[tid].tid = tid; + datas[tid].gl_comm = gl_comm; + datas[tid].array = array; + + // Spawn additional threads for ids greater than 1. 
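+ // The chief thread (tid == 0) does not spawn a pthread for itself;
+ // it falls through to the else branch below and calls the entry
+ // function directly so that it participates in the computation.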
+ if ( tid != 0 ) + bli_pthread_create( &pthreads[tid], NULL, &bls_l3_thread_entry, &datas[tid] ); + else + bls_l3_thread_entry( ( void* )(&datas[0]) ); + } + + // We shouldn't free the global communicator since it was already freed + // by the global communicator's chief thread in bli_l3_thrinfo_free() + // (called from the thread entry function). + + // Thread 0 waits for additional threads to finish. + for ( dim_t tid = 1; tid < n_threads; tid++ ) + { + bli_pthread_join( pthreads[tid], NULL ); + } + + // Check the array_t back into the small block allocator. Similar to the + // check-out, this is done using a lock embedded within the sba to ensure + // mutual exclusion. + bli_sba_checkin_array( array ); + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_l3_thread_decorator().pth: " ); + #endif + bli_free_intl( pthreads ); + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_l3_thread_decorator().pth: " ); + #endif + bli_free_intl( datas ); +} + +#endif + diff --git a/sandbox/gemmlike/thread/bls_l3_decor_pthreads.h b/sandbox/gemmlike/thread/bls_l3_decor_pthreads.h new file mode 100644 index 0000000000..ef5c3bad45 --- /dev/null +++ b/sandbox/gemmlike/thread/bls_l3_decor_pthreads.h @@ -0,0 +1,47 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_SBX_L3_DECOR_PTHREADS_H +#define BLIS_SBX_L3_DECOR_PTHREADS_H + +// Definitions specific to situations when POSIX multithreading is enabled. +#ifdef BLIS_ENABLE_PTHREADS + +// Thread entry point prototype. +void* bls_l3_thread_entry( void* data_void ); + +#endif + +#endif + diff --git a/sandbox/gemmlike/thread/bls_l3_decor_single.c b/sandbox/gemmlike/thread/bls_l3_decor_single.c new file mode 100644 index 0000000000..7d9017dcd5 --- /dev/null +++ b/sandbox/gemmlike/thread/bls_l3_decor_single.c @@ -0,0 +1,141 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. 
+ + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#ifndef BLIS_ENABLE_MULTITHREADING + +#define SKIP_THRINFO_TREE + +void bls_l3_thread_decorator + ( + l3sbxint_t func, + opid_t family, + //pack_t schema_a, + //pack_t schema_b, + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm + ) +{ + // For sequential execution, we use only one thread. + const dim_t n_threads = 1; + + // NOTE: The sba was initialized in bli_init(). + + // Check out an array_t from the small block allocator. This is done + // with an internal lock to ensure only one application thread accesses + // the sba at a time. bli_sba_checkout_array() will also automatically + // resize the array_t, if necessary. + array_t* restrict array = bli_sba_checkout_array( n_threads ); + + // Access the pool_t* for thread 0 and embed it into the rntm. + bli_sba_rntm_set_pool( 0, array, rntm ); + + // Set the packing block allocator field of the rntm. + bli_membrk_rntm_set_membrk( rntm ); + +#ifndef SKIP_THRINFO_TREE + // Allcoate a global communicator for the root thrinfo_t structures. + thrcomm_t* restrict gl_comm = bli_thrcomm_create( rntm, n_threads ); +#endif + + + { + // NOTE: We don't need to create another copy of the rntm_t since + // it was already copied in one of the high-level oapi functions. + rntm_t* restrict rntm_p = rntm; + + // There is only one thread id (for the thief thread). + const dim_t tid = 0; + + // Use the thread id to access the appropriate pool_t* within the + // array_t, and use it to set the sba_pool field within the rntm_t. + // If the pool_t* element within the array_t is NULL, it will first + // be allocated/initialized. + // NOTE: This is commented out because, in the single-threaded case, + // this is redundant since it's already been done above. + //bli_sba_rntm_set_pool( tid, array, rntm_p ); + +#ifndef SKIP_THRINFO_TREE + thrinfo_t* thread = NULL; + + // Create the root node of the thread's thrinfo_t structure. 
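+ // (Since SKIP_THRINFO_TREE is defined at the top of this file, this
+ // branch is compiled out and the #else path below is used instead.)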
+ bli_l3_sup_thrinfo_create_root( tid, gl_comm, rntm_p, &thread ); +#else + // This optimization allows us to use one of the global thrinfo_t + // objects for single-threaded execution rather than grow one from + // scratch. The key is that bli_thrinfo_sup_grow(), which is called + // from within the variants, will immediately return if it detects + // that the thrinfo_t* passed into it is either + // &BLIS_GEMM_SINGLE_THREADED or &BLIS_PACKM_SINGLE_THREADED. + thrinfo_t* thread = &BLIS_GEMM_SINGLE_THREADED; + + ( void )tid; +#endif + + func + ( + alpha, + a, + b, + beta, + c, + cntx, + rntm_p, + thread + ); + +#ifndef SKIP_THRINFO_TREE + // Free the current thread's thrinfo_t structure. + bli_l3_sup_thrinfo_free( rntm_p, thread ); +#endif + } + + // We shouldn't free the global communicator since it was already freed + // by the global communicator's chief thread in bli_l3_thrinfo_free() + // (called above). + + // Check the array_t back into the small block allocator. Similar to the + // check-out, this is done using a lock embedded within the sba to ensure + // mutual exclusion. + bli_sba_checkin_array( array ); +} + +#endif + diff --git a/sandbox/gemmlike/thread/bls_l3_decor_single.h b/sandbox/gemmlike/thread/bls_l3_decor_single.h new file mode 100644 index 0000000000..211a43a894 --- /dev/null +++ b/sandbox/gemmlike/thread/bls_l3_decor_single.h @@ -0,0 +1,44 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, The University of Texas at Austin + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_SBX_L3_DECOR_SINGLE_H +#define BLIS_SBX_L3_DECOR_SINGLE_H + +// Definitions specific to situations when multithreading is disabled. +#ifndef BLIS_ENABLE_MULTITHREADING + +#endif + +#endif + diff --git a/sandbox/power10/bli_gemmnat.c b/sandbox/power10/bli_gemmnat.c index b2dabd29aa..846ccd35a8 100644 --- a/sandbox/power10/bli_gemmnat.c +++ b/sandbox/power10/bli_gemmnat.c @@ -32,7 +32,14 @@ */ -// This file is needed for the BLIS build system. 
+// Given the current architecture of BLIS sandboxes, bli_gemmnat() is the +// entry point to any sandbox implementation. + +// NOTE: This function is implemented identically to the function that it +// overrides in frame/ind/oapi/bli_l3_nat_oapi.c. This means that we are +// forgoing the option of customizing the implementations that underlie +// bli_gemm() and bli_?gemm(). Any new code defined in this sandbox +// directory, however, will be included in the BLIS. #include "blis.h" From 62c96a419067ffcb3074372b5a8f7d0a169582cc Mon Sep 17 00:00:00 2001 From: "Dipal M. Zambare" Date: Thu, 21 Apr 2022 06:28:29 +0000 Subject: [PATCH 074/243] Enabled AVX-512 kernels for Zen4 config Enabled AVX-512 skylake kernels in zen4 configuration. AVX-512 kernels are added for float and double types. AMD-Internal: [CPUPL-2108] --- config/zen4/bli_cntx_init_zen4.c | 20 +++++++++++--------- config/zen4/bli_family_zen4.h | 10 ++++++++-- config/zen4/make_defs.mk | 22 +++++++++++++++------- config_registry | 2 +- frame/include/bli_arch_config.h | 5 ++++- 5 files changed, 39 insertions(+), 20 deletions(-) diff --git a/config/zen4/bli_cntx_init_zen4.c b/config/zen4/bli_cntx_init_zen4.c index c340fa9087..9a37f61c3b 100644 --- a/config/zen4/bli_cntx_init_zen4.c +++ b/config/zen4/bli_cntx_init_zen4.c @@ -49,8 +49,8 @@ void bli_cntx_init_zen4( cntx_t* cntx ) ( 8, // gemm - BLIS_GEMM_UKR, BLIS_FLOAT, bli_sgemm_haswell_asm_6x16, TRUE, - BLIS_GEMM_UKR, BLIS_DOUBLE, bli_dgemm_haswell_asm_6x8, TRUE, + BLIS_GEMM_UKR, BLIS_FLOAT , bli_sgemm_skx_asm_32x12_l2, FALSE, + BLIS_GEMM_UKR, BLIS_DOUBLE, bli_dgemm_skx_asm_16x14, FALSE, BLIS_GEMM_UKR, BLIS_SCOMPLEX, bli_cgemm_haswell_asm_3x8, TRUE, BLIS_GEMM_UKR, BLIS_DCOMPLEX, bli_zgemm_haswell_asm_3x4, TRUE, // gemmtrsm_l @@ -160,14 +160,16 @@ void bli_cntx_init_zen4( cntx_t* cntx ) // // These are reference block sizes and may be overridden based on // number of threads used at runtime. + // s d c z - bli_blksz_init_easy( &blkszs[ BLIS_MR ], 6, 6, 3, 3 ); - bli_blksz_init_easy( &blkszs[ BLIS_NR ], 16, 8, 8, 4 ); - bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 72, 144, 18 ); - bli_blksz_init_easy( &blkszs[ BLIS_KC ], 256, 256, 256, 566 ); - bli_blksz_init_easy( &blkszs[ BLIS_NC ], 4080, 4080, 4080, 256 ); - - bli_blksz_init_easy( &blkszs[ BLIS_AF ], 5, 5, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_MR ], 32, 16, 3, 3 ); + bli_blksz_init_easy( &blkszs[ BLIS_NR ], 12, 14, 8, 4 ); + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 480, 240, 144, 18 ); + bli_blksz_init ( &blkszs[ BLIS_KC ], 384, 256, 256, 566, + 480, 320, 256, 566 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 3072, 3752, 4080, 256 ); + + bli_blksz_init_easy( &blkszs[ BLIS_AF ], 8, 8, -1, -1 ); bli_blksz_init_easy( &blkszs[ BLIS_DF ], 8, 8, -1, -1 ); // Update the context with the current architecture's register and cache diff --git a/config/zen4/bli_family_zen4.h b/config/zen4/bli_family_zen4.h index 9c70fcef83..71929cdac4 100644 --- a/config/zen4/bli_family_zen4.h +++ b/config/zen4/bli_family_zen4.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021-2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -39,7 +39,6 @@ // Setting these macros to 1 will force JR and IR inner loops // to be not paralleized. 
// - #define BLIS_THREAD_MAX_IR 1 #define BLIS_THREAD_MAX_JR 1 @@ -56,4 +55,11 @@ //#define BLIS_ENABLE_FAST_MATH +// -- SIMD config -------------------------------------------------------- + +#define BLIS_SIMD_ALIGN_SIZE 64 + +#define BLIS_SIMD_SIZE 64 +#define BLIS_SIMD_NUM_REGISTERS 32 + #endif diff --git a/config/zen4/make_defs.mk b/config/zen4/make_defs.mk index 44e96bb0c7..b20c0ef476 100644 --- a/config/zen4/make_defs.mk +++ b/config/zen4/make_defs.mk @@ -32,7 +32,7 @@ # # -# FLAGS that are specific to the 'zen3' architecture are added here. +# FLAGS that are specific to the 'zen4' architecture are added here. # FLAGS that are common for all the AMD architectures are present in # config/zen/amd_config.mk. @@ -73,15 +73,17 @@ GCC_VERSION := $(strip $(shell $(CC) -dumpversion | cut -d. -f1)) # gcc or clang version must be atleast 4.0 # gcc 9.0 or later: ifeq ($(shell test $(GCC_VERSION) -ge 11; echo $$?),0) -CKVECFLAGS += -march=znver3 +CKVECFLAGS += -march=znver3 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mfpmath=sse +CRVECFLAGS += -march=znver3 else ifeq ($(shell test $(GCC_VERSION) -ge 9; echo $$?),0) -CKVECFLAGS += -march=znver2 +CKVECFLAGS += -march=znver2 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mfpmath=sse +CRVECFLAGS += -march=znver2 else # If gcc is older than 9.1.0 but at least 6.1.0, then we can use -march=znver1 # as the fallback option. -CRVECFLAGS += -march=znver1 -mno-avx256-split-unaligned-store CKVECFLAGS += -march=znver1 -mno-avx256-split-unaligned-store +CRVECFLAGS += -march=znver1 -mno-avx256-split-unaligned-store endif # GCC 9 endif # GCC 11 else @@ -99,11 +101,13 @@ ifeq ($(CC_VENDOR),clang) # for version 3x we will enable znver3 ifeq ($(strip $(shell $(CC) -v |&head -1 |grep -c 'AOCC_3')),1) -CKVECFLAGS += -march=znver3 +CKVECFLAGS += -march=znver3 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mfpmath=sse +CRVECFLAGS += -march=znver3 else # for version 2x we will enable znver2 ifeq ($(strip $(shell $(CC) -v |&head -1 |grep -c 'AOCC.LLVM.2\|AOCC_2')),1) -CKVECFLAGS += -march=znver2 +CKVECFLAGS += -march=znver2 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mfpmath=sse +CRVECFLAGS += -march=znver2 else #if compiling with clang VENDOR_STRING := $(strip $(shell ${CC_VENDOR} --version | egrep -o '[0-9]+\.[0-9]+\.?[0-9]*')) @@ -111,8 +115,10 @@ CC_MAJOR := $(shell (echo ${VENDOR_STRING} | cut -d. -f1)) #clang 9.0 or later: ifeq ($(shell test $(CC_MAJOR) -ge 9; echo $$?),0) CKVECFLAGS += -march=znver2 +CRVECFLAGS += -march=znver2 else CKVECFLAGS += -march=znver1 +CRVECFLAGS += -march=znver1 endif # ge 9 endif # aocc 2 endif # aocc 3 @@ -121,7 +127,9 @@ endif # gcc # Flags specific to reference kernels. CROPTFLAGS := $(CKOPTFLAGS) -CRVECFLAGS := $(CKVECFLAGS) + +# Disable AVX-512 for reference kernels +CRVECFLAGS += -mno-avx512f -mno-avx512vl -mno-avx512bw -mno-avx512dq -mno-avx512cd -funsafe-math-optimizations -ffp-contract=fast # Store all of the variables here to new variables containing the # configuration name. diff --git a/config_registry b/config_registry index 822b133f5c..4e6716dfa1 100644 --- a/config_registry +++ b/config_registry @@ -26,7 +26,7 @@ sandybridge: sandybridge penryn: penryn # AMD architectures. 
-zen4: zen4/zen4/zen3/zen2/zen/haswell +zen4: zen4/zen4/skx/zen3/zen2/zen/haswell zen3: zen3/zen3/zen2/zen/haswell zen2: zen2/zen2/zen/haswell zen: zen/zen/haswell diff --git a/frame/include/bli_arch_config.h b/frame/include/bli_arch_config.h index 3e2e0b022b..6343c6ba89 100644 --- a/frame/include/bli_arch_config.h +++ b/frame/include/bli_arch_config.h @@ -6,7 +6,7 @@ Copyright (C) 2014, The University of Texas at Austin Copyright (C) 2016, Hewlett Packard Enterprise Development LP - Copyright (C) 2019 - 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -171,6 +171,9 @@ CNTX_INIT_PROTS( generic ) // -- AMD64 architectures -- +#ifdef BLIS_FAMILY_ZEN4 +#include "bli_family_zen4.h" +#endif #ifdef BLIS_FAMILY_ZEN3 #include "bli_family_zen3.h" #endif From f816cf059f8c916c9c3a73a2d0ba0c2b4adb6a28 Mon Sep 17 00:00:00 2001 From: "Dipal M. Zambare" Date: Thu, 21 Apr 2022 06:28:29 +0000 Subject: [PATCH 075/243] Enabled AVX-512 kernels for Zen4 config Enabled AVX-512 skylake kernels in zen4 configuration. AVX-512 kernels are added for float and double types. AMD-Internal: [CPUPL-2108] Change-Id: Idfe3f64a037db019cbdf43318954db52ad241a51 --- config/zen4/make_defs.mk | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config/zen4/make_defs.mk b/config/zen4/make_defs.mk index b20c0ef476..d95d12bfb4 100644 --- a/config/zen4/make_defs.mk +++ b/config/zen4/make_defs.mk @@ -128,7 +128,7 @@ endif # gcc # Flags specific to reference kernels. CROPTFLAGS := $(CKOPTFLAGS) -# Disable AVX-512 for reference kernels +# Disable AVX-512 for reference kernels. CRVECFLAGS += -mno-avx512f -mno-avx512vl -mno-avx512bw -mno-avx512dq -mno-avx512cd -funsafe-math-optimizations -ffp-contract=fast # Store all of the variables here to new variables containing the From 0adb525f5bc07fdf164753d52d15cd718219a330 Mon Sep 17 00:00:00 2001 From: "Dipal M. Zambare" Date: Thu, 21 Apr 2022 06:45:38 +0000 Subject: [PATCH 076/243] Revert "Enabled AVX-512 kernels for Zen4 config" This reverts commit f816cf059f8c916c9c3a73a2d0ba0c2b4adb6a28. Was committed without review. --- config/zen4/make_defs.mk | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config/zen4/make_defs.mk b/config/zen4/make_defs.mk index d95d12bfb4..b20c0ef476 100644 --- a/config/zen4/make_defs.mk +++ b/config/zen4/make_defs.mk @@ -128,7 +128,7 @@ endif # gcc # Flags specific to reference kernels. CROPTFLAGS := $(CKOPTFLAGS) -# Disable AVX-512 for reference kernels. +# Disable AVX-512 for reference kernels CRVECFLAGS += -mno-avx512f -mno-avx512vl -mno-avx512bw -mno-avx512dq -mno-avx512cd -funsafe-math-optimizations -ffp-contract=fast # Store all of the variables here to new variables containing the From b90420627a13b8927cda68efbb9ccc589ed3d787 Mon Sep 17 00:00:00 2001 From: "Dipal M. Zambare" Date: Thu, 21 Apr 2022 06:46:00 +0000 Subject: [PATCH 077/243] Revert "Enabled AVX-512 kernels for Zen4 config" This reverts commit 62c96a419067ffcb3074372b5a8f7d0a169582cc. Was committed without review. 
--- config/zen4/bli_cntx_init_zen4.c | 20 +++++++++----------- config/zen4/bli_family_zen4.h | 10 ++-------- config/zen4/make_defs.mk | 22 +++++++--------------- config_registry | 2 +- frame/include/bli_arch_config.h | 5 +---- 5 files changed, 20 insertions(+), 39 deletions(-) diff --git a/config/zen4/bli_cntx_init_zen4.c b/config/zen4/bli_cntx_init_zen4.c index 9a37f61c3b..c340fa9087 100644 --- a/config/zen4/bli_cntx_init_zen4.c +++ b/config/zen4/bli_cntx_init_zen4.c @@ -49,8 +49,8 @@ void bli_cntx_init_zen4( cntx_t* cntx ) ( 8, // gemm - BLIS_GEMM_UKR, BLIS_FLOAT , bli_sgemm_skx_asm_32x12_l2, FALSE, - BLIS_GEMM_UKR, BLIS_DOUBLE, bli_dgemm_skx_asm_16x14, FALSE, + BLIS_GEMM_UKR, BLIS_FLOAT, bli_sgemm_haswell_asm_6x16, TRUE, + BLIS_GEMM_UKR, BLIS_DOUBLE, bli_dgemm_haswell_asm_6x8, TRUE, BLIS_GEMM_UKR, BLIS_SCOMPLEX, bli_cgemm_haswell_asm_3x8, TRUE, BLIS_GEMM_UKR, BLIS_DCOMPLEX, bli_zgemm_haswell_asm_3x4, TRUE, // gemmtrsm_l @@ -160,16 +160,14 @@ void bli_cntx_init_zen4( cntx_t* cntx ) // // These are reference block sizes and may be overridden based on // number of threads used at runtime. - // s d c z - bli_blksz_init_easy( &blkszs[ BLIS_MR ], 32, 16, 3, 3 ); - bli_blksz_init_easy( &blkszs[ BLIS_NR ], 12, 14, 8, 4 ); - bli_blksz_init_easy( &blkszs[ BLIS_MC ], 480, 240, 144, 18 ); - bli_blksz_init ( &blkszs[ BLIS_KC ], 384, 256, 256, 566, - 480, 320, 256, 566 ); - bli_blksz_init_easy( &blkszs[ BLIS_NC ], 3072, 3752, 4080, 256 ); - - bli_blksz_init_easy( &blkszs[ BLIS_AF ], 8, 8, -1, -1 ); + bli_blksz_init_easy( &blkszs[ BLIS_MR ], 6, 6, 3, 3 ); + bli_blksz_init_easy( &blkszs[ BLIS_NR ], 16, 8, 8, 4 ); + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 72, 144, 18 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], 256, 256, 256, 566 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 4080, 4080, 4080, 256 ); + + bli_blksz_init_easy( &blkszs[ BLIS_AF ], 5, 5, -1, -1 ); bli_blksz_init_easy( &blkszs[ BLIS_DF ], 8, 8, -1, -1 ); // Update the context with the current architecture's register and cache diff --git a/config/zen4/bli_family_zen4.h b/config/zen4/bli_family_zen4.h index 71929cdac4..9c70fcef83 100644 --- a/config/zen4/bli_family_zen4.h +++ b/config/zen4/bli_family_zen4.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021-2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -39,6 +39,7 @@ // Setting these macros to 1 will force JR and IR inner loops // to be not paralleized. // + #define BLIS_THREAD_MAX_IR 1 #define BLIS_THREAD_MAX_JR 1 @@ -55,11 +56,4 @@ //#define BLIS_ENABLE_FAST_MATH -// -- SIMD config -------------------------------------------------------- - -#define BLIS_SIMD_ALIGN_SIZE 64 - -#define BLIS_SIMD_SIZE 64 -#define BLIS_SIMD_NUM_REGISTERS 32 - #endif diff --git a/config/zen4/make_defs.mk b/config/zen4/make_defs.mk index b20c0ef476..44e96bb0c7 100644 --- a/config/zen4/make_defs.mk +++ b/config/zen4/make_defs.mk @@ -32,7 +32,7 @@ # # -# FLAGS that are specific to the 'zen4' architecture are added here. +# FLAGS that are specific to the 'zen3' architecture are added here. # FLAGS that are common for all the AMD architectures are present in # config/zen/amd_config.mk. @@ -73,17 +73,15 @@ GCC_VERSION := $(strip $(shell $(CC) -dumpversion | cut -d. 
-f1)) # gcc or clang version must be atleast 4.0 # gcc 9.0 or later: ifeq ($(shell test $(GCC_VERSION) -ge 11; echo $$?),0) -CKVECFLAGS += -march=znver3 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mfpmath=sse -CRVECFLAGS += -march=znver3 +CKVECFLAGS += -march=znver3 else ifeq ($(shell test $(GCC_VERSION) -ge 9; echo $$?),0) -CKVECFLAGS += -march=znver2 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mfpmath=sse -CRVECFLAGS += -march=znver2 +CKVECFLAGS += -march=znver2 else # If gcc is older than 9.1.0 but at least 6.1.0, then we can use -march=znver1 # as the fallback option. -CKVECFLAGS += -march=znver1 -mno-avx256-split-unaligned-store CRVECFLAGS += -march=znver1 -mno-avx256-split-unaligned-store +CKVECFLAGS += -march=znver1 -mno-avx256-split-unaligned-store endif # GCC 9 endif # GCC 11 else @@ -101,13 +99,11 @@ ifeq ($(CC_VENDOR),clang) # for version 3x we will enable znver3 ifeq ($(strip $(shell $(CC) -v |&head -1 |grep -c 'AOCC_3')),1) -CKVECFLAGS += -march=znver3 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mfpmath=sse -CRVECFLAGS += -march=znver3 +CKVECFLAGS += -march=znver3 else # for version 2x we will enable znver2 ifeq ($(strip $(shell $(CC) -v |&head -1 |grep -c 'AOCC.LLVM.2\|AOCC_2')),1) -CKVECFLAGS += -march=znver2 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mfpmath=sse -CRVECFLAGS += -march=znver2 +CKVECFLAGS += -march=znver2 else #if compiling with clang VENDOR_STRING := $(strip $(shell ${CC_VENDOR} --version | egrep -o '[0-9]+\.[0-9]+\.?[0-9]*')) @@ -115,10 +111,8 @@ CC_MAJOR := $(shell (echo ${VENDOR_STRING} | cut -d. -f1)) #clang 9.0 or later: ifeq ($(shell test $(CC_MAJOR) -ge 9; echo $$?),0) CKVECFLAGS += -march=znver2 -CRVECFLAGS += -march=znver2 else CKVECFLAGS += -march=znver1 -CRVECFLAGS += -march=znver1 endif # ge 9 endif # aocc 2 endif # aocc 3 @@ -127,9 +121,7 @@ endif # gcc # Flags specific to reference kernels. CROPTFLAGS := $(CKOPTFLAGS) - -# Disable AVX-512 for reference kernels -CRVECFLAGS += -mno-avx512f -mno-avx512vl -mno-avx512bw -mno-avx512dq -mno-avx512cd -funsafe-math-optimizations -ffp-contract=fast +CRVECFLAGS := $(CKVECFLAGS) # Store all of the variables here to new variables containing the # configuration name. diff --git a/config_registry b/config_registry index 4e6716dfa1..822b133f5c 100644 --- a/config_registry +++ b/config_registry @@ -26,7 +26,7 @@ sandybridge: sandybridge penryn: penryn # AMD architectures. -zen4: zen4/zen4/skx/zen3/zen2/zen/haswell +zen4: zen4/zen4/zen3/zen2/zen/haswell zen3: zen3/zen3/zen2/zen/haswell zen2: zen2/zen2/zen/haswell zen: zen/zen/haswell diff --git a/frame/include/bli_arch_config.h b/frame/include/bli_arch_config.h index 6343c6ba89..3e2e0b022b 100644 --- a/frame/include/bli_arch_config.h +++ b/frame/include/bli_arch_config.h @@ -6,7 +6,7 @@ Copyright (C) 2014, The University of Texas at Austin Copyright (C) 2016, Hewlett Packard Enterprise Development LP - Copyright (C) 2019 - 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2021, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -171,9 +171,6 @@ CNTX_INIT_PROTS( generic ) // -- AMD64 architectures -- -#ifdef BLIS_FAMILY_ZEN4 -#include "bli_family_zen4.h" -#endif #ifdef BLIS_FAMILY_ZEN3 #include "bli_family_zen3.h" #endif From c11fd5a8f697b69073f68e26c996e3d94d236811 Mon Sep 17 00:00:00 2001 From: Meghana Vankadari Date: Wed, 19 Jan 2022 09:16:41 +0530 Subject: [PATCH 078/243] Added functionality support for dzgemm AMD-Internal: [SWLCSG-1012] Change-Id: I2eac3131d2dcd534f84491289cbd3fe7fb7de3da --- frame/1m/packm/bli_packm_blk_var1.c | 5 +++-- frame/3/bli_l3_check.c | 7 +++++-- frame/3/gemm/bli_gemm_packab.c | 12 +++++++++++- frame/compat/bla_gemm.c | 6 ++---- frame/compat/bla_gemm.h | 4 ++-- frame/include/bli_type_defs.h | 5 +++-- test/test_gemm.c | 8 +++----- 7 files changed, 29 insertions(+), 18 deletions(-) diff --git a/frame/1m/packm/bli_packm_blk_var1.c b/frame/1m/packm/bli_packm_blk_var1.c index 87f8df4f7d..c720317b96 100644 --- a/frame/1m/packm/bli_packm_blk_var1.c +++ b/frame/1m/packm/bli_packm_blk_var1.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -159,7 +159,8 @@ void bli_packm_blk_var1 // Treatment of kappa (ie: packing during scaling) depends on // whether we are executing an induced method. - if ( bli_is_nat_packed( schema ) ) + // For dzgemm, scale alpha during packing. + if ( bli_is_nat_packed( schema ) && cntl && bli_cntl_family(cntl) != BLIS_GEMM_MD) { // This branch is for native execution, where we assume that // the micro-kernel will always apply the alpha scalar of the diff --git a/frame/3/bli_l3_check.c b/frame/3/bli_l3_check.c index 945b267fda..43ba867283 100644 --- a/frame/3/bli_l3_check.c +++ b/frame/3/bli_l3_check.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -323,8 +324,10 @@ void bli_gemm_basic_check // When mixing datatypes, make sure that alpha does not have a non-zero // imaginary component. - if ( bli_obj_dt( c ) != bli_obj_dt( a ) || - bli_obj_dt( c ) != bli_obj_dt( b ) || + // To support dzgemm, we continue execution when datatypes of C and A + // do not match instead of aborting with an error message. + // Non-zero imaginary component of alpha is handled while packing B. + if ( bli_obj_dt( c ) != bli_obj_dt( b ) || bli_obj_comp_prec( c ) != bli_obj_prec( c ) ) if ( !bli_obj_imag_is_zero( alpha ) ) { diff --git a/frame/3/gemm/bli_gemm_packab.c b/frame/3/gemm/bli_gemm_packab.c index 3dfed88478..6828725546 100644 --- a/frame/3/gemm/bli_gemm_packab.c +++ b/frame/3/gemm/bli_gemm_packab.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -90,9 +91,14 @@ void bli_gemm_packb ) { AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_5); - + obj_t b_pack; + // BY setting family id to BLIS_GEMM_MD, we indicate packing kernels + // to scale alpha while packing. + if(bli_obj_dt(c) != bli_obj_dt(a)) + bli_cntl_set_family(BLIS_GEMM_MD, cntl); + // Pack matrix B according to the control tree node. bli_l3_packm ( @@ -103,6 +109,10 @@ void bli_gemm_packb cntl, thread ); + // Once packing of B matrix is done, fall back to GEMM execution. + if(bli_obj_dt(c) != bli_obj_dt(a)) + bli_cntl_set_family(BLIS_GEMM, cntl); + // Proceed with execution using packed matrix B. bli_gemm_int diff --git a/frame/compat/bla_gemm.c b/frame/compat/bla_gemm.c index 8d08a9e010..406ff69d53 100644 --- a/frame/compat/bla_gemm.c +++ b/frame/compat/bla_gemm.c @@ -302,9 +302,7 @@ void PASTEF77(ch,blasname) \ #ifdef BLIS_ENABLE_BLAS INSERT_GENTFUNC_BLAS( gemm,gemm ) -// Observed a regression in dgemm with this function addition. -// Disabling temporarily. -#if 0 +#if 1 void dzgemm_ ( const f77_char* transa, @@ -381,7 +379,7 @@ void dzgemm_ bli_obj_init_finish_1x1( dt, (dcomplex*)alpha, &alphao ); bli_obj_init_finish_1x1( dt, (dcomplex*)beta, &betao ); - bli_obj_init_finish( dt_a, m0_a, n0_a, (dcomplex*)a, rs_a, cs_a, &ao ); + bli_obj_init_finish( dt_a, m0_a, n0_a, (double*)a, rs_a, cs_a, &ao ); bli_obj_init_finish( dt, m0_b, n0_b, (dcomplex*)b, rs_b, cs_b, &bo ); bli_obj_init_finish( dt, m0, n0, (dcomplex*)c, rs_c, cs_c, &co ); diff --git a/frame/compat/bla_gemm.h b/frame/compat/bla_gemm.h index 25aef8d11f..c9ea83149a 100644 --- a/frame/compat/bla_gemm.h +++ b/frame/compat/bla_gemm.h @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -54,8 +55,7 @@ BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ ); #ifdef BLIS_ENABLE_BLAS -// Disabling temporarily -#if 0 +#if 1 BLIS_EXPORT_BLAS void dzgemm_ ( const f77_char* transa, \ diff --git a/frame/include/bli_type_defs.h b/frame/include/bli_type_defs.h index c0e0095061..cb4e4e4b84 100644 --- a/frame/include/bli_type_defs.h +++ b/frame/include/bli_type_defs.h @@ -6,7 +6,7 @@ Copyright (C) 2014, The University of Texas at Austin Copyright (C) 2016, Hewlett Packard Enterprise Development LP - Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021 - 22, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -931,10 +931,11 @@ typedef enum BLIS_TRMM, BLIS_TRSM, BLIS_GEMMT, + BLIS_GEMM_MD, BLIS_NOID } opid_t; -#define BLIS_NUM_LEVEL3_OPS 11 +#define BLIS_NUM_LEVEL3_OPS 12 // -- Blocksize ID type -- diff --git a/test/test_gemm.c b/test/test_gemm.c index 772d73c7b1..25fc5e3d8d 100644 --- a/test/test_gemm.c +++ b/test/test_gemm.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019-2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019-2022, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -382,8 +382,7 @@ int main( int argc, char** argv ) cp, ldc ); #else -//Disabled dzgemm function temporarily. -#if 0 +#if 1 if( bli_is_double( dt_a ) ) { dzgemm_( @@ -401,7 +400,6 @@ int main( int argc, char** argv ) } else { -#else zgemm_( &f77_transa, &f77_transb, &mm, @@ -412,7 +410,7 @@ int main( int argc, char** argv ) bp, (f77_int*)&ldb, betap, cp, (f77_int*)&ldc ); -// } + } #endif #endif } From 31921b9974396347682a40ee654bfab5777cd4b9 Mon Sep 17 00:00:00 2001 From: Dipal M Zambare Date: Mon, 24 Jan 2022 20:30:25 +0530 Subject: [PATCH 079/243] Updated windows build system to define BLIS_CONFIG_EPYC flag. All AMD specific optimization in BLIS are enclosed in BLIS_CONFIG_EPYC pre-preprocessor, this was not defined in CMake which are resulting in overall lower performance. Updated version number to 3.1.1 Change-Id: I9848b695a599df07da44e77e71a64414b28c75b9 --- CMakeLists.txt | 5 ++++- kernels/zen/2/CMakeLists.txt | 3 ++- so_version | 2 +- version | 2 +- 4 files changed, 8 insertions(+), 4 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 3f7cc6ad94..19cf7e68f3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,4 +1,4 @@ -##Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved.## +##Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.## cmake_minimum_required(VERSION 3.0.0) @@ -33,17 +33,20 @@ endif () if(${AOCL_BLIS_FAMILY} STREQUAL "zen") add_definitions(-DBLIS_FAMILY_ZEN) + add_definitions(-DBLIS_CONFIG_EPYC) add_definitions(-DBLIS_CONFIG_ZEN) add_definitions(-DBLIS_KERNELS_ZEN) add_definitions(-DBLIS_KERNELS_HASWELL) elseif (${AOCL_BLIS_FAMILY} STREQUAL "zen2") add_definitions(-DBLIS_FAMILY_ZEN2) + add_definitions(-DBLIS_CONFIG_EPYC) add_definitions(-DBLIS_CONFIG_ZEN2) add_definitions(-DBLIS_KERNELS_ZEN2) add_definitions(-DBLIS_KERNELS_ZEN) add_definitions(-DBLIS_KERNELS_HASWELL) elseif (${AOCL_BLIS_FAMILY} STREQUAL "zen3") add_definitions(-DBLIS_FAMILY_ZEN3) + add_definitions(-DBLIS_CONFIG_EPYC) add_definitions(-DBLIS_CONFIG_ZEN3) add_definitions(-DBLIS_KERNELS_ZEN3) add_definitions(-DBLIS_KERNELS_ZEN2) diff --git a/kernels/zen/2/CMakeLists.txt b/kernels/zen/2/CMakeLists.txt index 72176895c5..07a9266b0e 100644 --- a/kernels/zen/2/CMakeLists.txt +++ b/kernels/zen/2/CMakeLists.txt @@ -1,9 +1,10 @@ -##Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.## +##Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.## target_sources("${PROJECT_NAME}" PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemv_zen_ref.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_her2_zen_int_4.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemv_zen_int_4.c ) diff --git a/so_version b/so_version index a831c0e579..b1f189286c 100644 --- a/so_version +++ b/so_version @@ -1,2 +1,2 @@ 3 -1.0 +1.1 diff --git a/version b/version index 0c6173b5f1..1795fa298a 100644 --- a/version +++ b/version @@ -1,2 +1,2 @@ -3.1.0 +3.1.1 From ec6e4162bcc7cfe68a81d833102a0f93048701da Mon Sep 17 00:00:00 2001 From: Chandrashekara K R Date: Tue, 25 Jan 2022 13:53:03 +0530 Subject: [PATCH 080/243] Updated windows build system. We were using add_compile_options(-Xclang -fopenmp) statement to set omp library compiler flags on MSVC using cmake. 
Observing there is an performance regression because of the compiler version which is using in MSVC(clang 10), so removing it from the windows build system and configuring the compiler version(clang 13) and compiler options manually on MSVC gui to gain a performance on matlab bench. Change-Id: I37d778abdceb7c1fae9b1caaeea8adb114677dd2 --- CMakeLists.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 19cf7e68f3..39e8a44bce 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -267,7 +267,6 @@ if(ENABLE_MULTITHREADING) find_package(OpenMP) if (OPENMP_FOUND) set(BLIS_ENABLE_OPENMP TRUE) - add_compile_options(-Xclang -fopenmp) else() message (FATAL_ERROR "Openmp Not Found") endif() From d687bd36ea2db3db7752fad2e98edadb847ce01d Mon Sep 17 00:00:00 2001 From: HariharaSudhan S Date: Fri, 24 Dec 2021 00:05:13 -0500 Subject: [PATCH 081/243] Merge "Improved AXPYV Kernel performance" into amd-staging-genoa-4.0 From 86690f9fd302cc6e8e0b647db8d1650dc133ccfd Mon Sep 17 00:00:00 2001 From: Arnav Sharma Date: Tue, 21 Dec 2021 16:49:11 +0530 Subject: [PATCH 082/243] Optimized AXPBYV Kernel using AVX2 Intrinsics Details: - Intrinsic implementation of axpbyv for AVX2 - Bench written for axpbyv - Added definitions in zen contexts AMD-Internal: [CPUPL-1963] Change-Id: I9bc21a6170f5c944eb6e9e9f0e994b9992f8b539 --- kernels/zen/bli_kernels_zen.h | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/kernels/zen/bli_kernels_zen.h b/kernels/zen/bli_kernels_zen.h index fea59796a8..c092ab3ab7 100644 --- a/kernels/zen/bli_kernels_zen.h +++ b/kernels/zen/bli_kernels_zen.h @@ -57,6 +57,16 @@ AXPBYV_KER_PROT( dcomplex, z, axpbyv_zen_int ) AXPBYV_KER_PROT( float, s, axpbyv_zen_int10 ) AXPBYV_KER_PROT( double, d, axpbyv_zen_int10 ) +// axpbyv (intrinsics) +AXPBYV_KER_PROT( float, s, axpbyv_zen_int ) +AXPBYV_KER_PROT( double, d, axpbyv_zen_int ) +AXPBYV_KER_PROT( scomplex, c, axpbyv_zen_int ) +AXPBYV_KER_PROT( dcomplex, z, axpbyv_zen_int ) + +// axpbyv (intrinsics, unrolled x10) +AXPBYV_KER_PROT( float, s, axpbyv_zen_int10 ) +AXPBYV_KER_PROT( double, d, axpbyv_zen_int10 ) + // axpyv (intrinsics) AXPYV_KER_PROT( float, s, axpyv_zen_int ) AXPYV_KER_PROT( double, d, axpyv_zen_int ) From d116780616bc13cc59d7ac2dad9fddb923ebb02b Mon Sep 17 00:00:00 2001 From: Harsh Dave Date: Fri, 17 Dec 2021 02:34:52 -0600 Subject: [PATCH 083/243] Optimized dher2 implementation - Impplemented her2 framework calls for transposed and non transposed kernel variants. - dher2 kernel operate over 4 columns at a time. It computes 4x4 triangular part of matrix first and remainder part is computed in chunk of 4x4 tile upto m rows. - remainder cases(m < 4) are handled serially. AMD-Internal: [CPUPL-1968] Change-Id: I12ae97b2ad673a7fd9b733c607f27b1089142313 --- frame/2/her2/bli_her2_unf_var1.c | 212 +++++++++++++++++++++++++++++++ frame/2/her2/bli_her2_unf_var4.c | 187 +++++++++++++++++++++++++++ 2 files changed, 399 insertions(+) diff --git a/frame/2/her2/bli_her2_unf_var1.c b/frame/2/her2/bli_her2_unf_var1.c index a0aec48f71..299e3d161d 100644 --- a/frame/2/her2/bli_her2_unf_var1.c +++ b/frame/2/her2/bli_her2_unf_var1.c @@ -158,5 +158,217 @@ void PASTEMAC(ch,varname) \ } \ } + +#ifdef BLIS_CONFIG_EPYC + +/** + * Following is function declaration + * that computes her2 for transposed case. + * It handles triangular part of matrix and + * remaining computation in optimal way to + * gain performance improvement. 
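+ * The kernel operates on the matrix four columns at a time: the 4x4
+ * triangular tile is computed first and the remaining rows of each
+ * panel are then processed in 4x4 tiles.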
+ * a is triangular matrix, x and y are vectors + */ +void bli_dher2_trans_zen_int_4 + ( + double *a, + double *x, + double *y, + double *alpha, + dim_t m, + dim_t lda + ); + +void bli_dher2_unf_var1 + ( + uplo_t uplo, + conj_t conjx, + conj_t conjy, + conj_t conjh, + dim_t m, + double* alpha, + double* x, inc_t incx, + double* y, inc_t incy, + double* c, inc_t rs_c, inc_t cs_c, + cntx_t* cntx + ) +{ + const num_t dt = PASTEMAC(d,type); + + double* x0; + double* chi1; + double* y0; + double* psi1; + double* c10t; + double* gamma11; + double alpha0; + double alpha1; + double alpha0_chi1; + double alpha1_psi1; + double alpha0_chi1_psi1; + double conjx0_chi1; + double conjy1_psi1; + double conjy0_psi1; + dim_t i; + dim_t n_behind; + inc_t rs_ct, cs_ct; + conj_t conj0, conj1; + + /* The algorithm will be expressed in terms of the lower triangular + * case;the upper triangular case is supported by swapping the row + * and column strides of A and toggling some conj parameters. + */ + if ( bli_is_lower( uplo ) ) + { + rs_ct = rs_c; + cs_ct = cs_c; + + PASTEMAC(d,copys)( *alpha, alpha0 ); + PASTEMAC(d,copycjs)( conjh, *alpha, alpha1 ); + } + else /* if ( bli_is_upper( uplo ) ) */ + { + rs_ct = cs_c; + cs_ct = rs_c; + + /* Toggle conjugation of conjx/conjy, but only if we are being + * invoked as her2; for syr2, conjx/conjy are unchanged. + */ + conjx = bli_apply_conj( conjh, conjx ); + conjy = bli_apply_conj( conjh, conjy ); + + PASTEMAC(d,copycjs)( conjh, *alpha, alpha0 ); + PASTEMAC(d,copys)( *alpha, alpha1 ); + } + + /* Apply conjh (which carries the conjugation component of the + * Hermitian transpose, if applicable) to conjx and/or conjy as + * needed to arrive at the effective conjugation for the vector + * subproblems. + */ + conj0 = bli_apply_conj( conjh, conjy ); + conj1 = bli_apply_conj( conjh, conjx ); + + PASTECH(d,axpy2v_ker_ft) kfp_2v; + + /* Query the context for the kernel function pointer. */ + kfp_2v = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPY2V_KER, cntx ); + + if( (incx == 1) && (incy == 1) && (rs_ct == 1)) + { + for ( i = 0; i < m; ) + { + n_behind = i; + x0 = x + (0 )*incx; + chi1 = x + (i )*incx; + y0 = y + (0 )*incy; + psi1 = y + (i )*incy; + c10t = c + (i )*rs_ct + (0 )*cs_ct; + gamma11 = c + (i )*rs_ct + (i )*cs_ct; + + if((n_behind >= 3)) + { + bli_dher2_trans_zen_int_4(c10t, x0, y0, &alpha0, n_behind + 1, cs_ct); + i+=4; + } + else + { + /* Apply conjx and/or conjy to chi1 and/or psi1. */ + PASTEMAC(d,copycjs)( conjx, *chi1, conjx0_chi1 ); + PASTEMAC(d,copycjs)( conjy, *psi1, conjy1_psi1 ); + PASTEMAC(d,copycjs)( conj0, *psi1, conjy0_psi1 ); + + /* Compute scalars for vector subproblems. */ + PASTEMAC(d,scal2s)( alpha0, conjx0_chi1, alpha0_chi1 ); + PASTEMAC(d,scal2s)( alpha1, conjy1_psi1, alpha1_psi1 ); + + /* Compute alpha * chi1 * conj(psi1) after both chi1 + * and psi1 have already been conjugated, if needed, + * by conjx and conjy. 
+ */ + PASTEMAC(d,scal2s)( alpha0_chi1, conjy0_psi1, + alpha0_chi1_psi1 ); + + /* c10t = c10t + alpha * chi1 * y0'; */ + /* c10t = c10t + conj(alpha) * psi1 * x0'; */ + kfp_2v + ( + conj0, + conj1, + n_behind, + &alpha0_chi1, + &alpha1_psi1, + y0, incy, + x0, incx, + c10t, cs_ct, + cntx + ); + + /* gamma11 = gamma11 + alpha * chi1 * conj(psi1) + + conj(alpha) * psi1 * conj(chi1); */ + PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); + PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); + + i+=1; + } + } + } + else + { + for ( i = 0; i < m; ++i ) + { + n_behind = i; + x0 = x + (0 )*incx; + chi1 = x + (i )*incx; + y0 = y + (0 )*incy; + psi1 = y + (i )*incy; + c10t = c + (i )*rs_ct + (0 )*cs_ct; + gamma11 = c + (i )*rs_ct + (i )*cs_ct; + + /* Apply conjx and/or conjy to chi1 and/or psi1. */ + PASTEMAC(d,copycjs)( conjx, *chi1, conjx0_chi1 ); + PASTEMAC(d,copycjs)( conjy, *psi1, conjy1_psi1 ); + PASTEMAC(d,copycjs)( conj0, *psi1, conjy0_psi1 ); + + /* Compute scalars for vector subproblems. */ + PASTEMAC(d,scal2s)( alpha0, conjx0_chi1, alpha0_chi1 ); + PASTEMAC(d,scal2s)( alpha1, conjy1_psi1, alpha1_psi1 ); + + /* Compute alpha * chi1 * conj(psi1) after both chi1 + * and psi1 have already been conjugated, if needed, + * by conjx and conjy. + */ + PASTEMAC(d,scal2s)( alpha0_chi1, conjy0_psi1, + alpha0_chi1_psi1 ); + + /* c10t = c10t + alpha * chi1 * y0'; */ + /* c10t = c10t + conj(alpha) * psi1 * x0'; */ + kfp_2v + ( + conj0, + conj1, + n_behind, + &alpha0_chi1, + &alpha1_psi1, + y0, incy, + x0, incx, + c10t, cs_ct, + cntx + ); + + /* gamma11 = gamma11 + alpha * chi1 * conj(psi1) + + conj(alpha) * psi1 * conj(chi1); */ + PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); + PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); + + } + } +} + +GENTFUNC(float, s, her2_unf_var1) +GENTFUNC(scomplex, c, her2_unf_var1) +GENTFUNC(dcomplex, z,her2_unf_var1) +#else INSERT_GENTFUNC_BASIC0( her2_unf_var1 ) +#endif diff --git a/frame/2/her2/bli_her2_unf_var4.c b/frame/2/her2/bli_her2_unf_var4.c index 3dea31d53e..e39c7224c4 100644 --- a/frame/2/her2/bli_her2_unf_var4.c +++ b/frame/2/her2/bli_her2_unf_var4.c @@ -166,5 +166,192 @@ void PASTEMAC(ch,varname) \ } \ } +#ifdef BLIS_CONFIG_EPYC +/** + * Following is function declaration + * that computes her2 for transposed case. + * It handles triangular part of matrix and + * remaining computation in optimal way to + * gain performance improvement. + * a is triangular matrix, x and y are vectors + */ +void bli_dher2_zen_int_4 + ( + double *a, + double *x, + double *y, + double *alpha, + dim_t m, + dim_t lda + ); + +void bli_dher2_unf_var4 + ( + uplo_t uplo, + conj_t conjx, + conj_t conjy, + conj_t conjh, + dim_t m, + double* alpha, + double* x, inc_t incx, + double* y, inc_t incy, + double* c, inc_t rs_c, inc_t cs_c, + cntx_t* cntx + ) +{ + + double* chi1; + double* x2; + double* psi1; + double* y2; + double* gamma11; + double* c21; + double alpha0; + double alpha0_psi1; + double alpha1_chi1; + double alpha0_chi1_psi1; + dim_t i; + dim_t n_ahead; + inc_t rs_ct, cs_ct; + + const num_t dt = PASTEMAC(d,type); + + /* The algorithm will be expressed in terms of the lower triangular + * case; the upper triangular case is supported by swapping the row + * and column strides of A and toggling some conj parameters. 
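+ * For unit-stride x, y and C, most of the work below is handed off
+ * to the optimized bli_dher2_zen_int_4() kernel, which updates four
+ * columns of C per call.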
+ */ + if ( bli_is_lower( uplo ) ) + { + rs_ct = rs_c; + cs_ct = cs_c; + + PASTEMAC(d,copys)( *alpha, alpha0 ); + } + else /* if ( bli_is_upper( uplo ) ) */ + { + rs_ct = cs_c; + cs_ct = rs_c; + + /* Toggle conjugation of conjx/conjy, but only if we are being + * invoked as her2; for syr2, conjx/conjy are unchanged. + */ + + PASTEMAC(d,copys)( *alpha, alpha0 ); + } + /* Apply conjh (which carries the conjugation component of the + * Hermitian transpose, if applicable) to conjx and/or conjy as + * needed to arrive at the effective conjugation for the vector + * subproblems. + */ + + PASTECH(d,axpy2v_ker_ft) kfp_2v; + + /* Query the context for the kernel function pointer. */ + kfp_2v = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPY2V_KER, cntx ); + + if((incx == 1) && (incy == 1) && (rs_ct == 1)) + { + for ( i = 0; i < m; ) + { + n_ahead = m - i - 1; + chi1 = x + (i ) * incx; + x2 = x + (i+1) * incx; + psi1 = y + (i ) * incy; + y2 = y + (i+1) * incy; + gamma11 = c + (i ) + (i )*cs_ct; + c21 = c + (i+1) + (i )*cs_ct; + + if((n_ahead >= 3)) + { + bli_dher2_zen_int_4(gamma11, chi1, psi1, &alpha0, n_ahead + 1, cs_ct); + i+= 4; + } + else + { + /* Compute scalars for vector subproblems. */ + PASTEMAC(d,scal2s)( alpha0, *psi1, alpha0_psi1 ); + PASTEMAC(d,scal2s)( alpha0, *chi1, alpha1_chi1 ); + + /* Compute alpha * chi1 * conj(psi1) after both chi1 + * and psi1 have + already been conjugated, if needed, by conjx and + conjy. */ + PASTEMAC(d,scal2s)( alpha0_psi1, *chi1, + alpha0_chi1_psi1 ); + + /* c21 = c21 + alpha * x2 * conj(psi1); */ + /* c21 = c21 + conj(alpha) * y2 * conj(chi1); */ + + kfp_2v + ( + conjx, + conjy, + n_ahead, + &alpha0_psi1, + &alpha1_chi1, + x2, incx, + y2, incy, + c21, rs_ct, + cntx + ); + + + PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); + PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); + i+=1; + } + } + } + else + { + for ( i = 0; i < m; ++i) + { + n_ahead = m - i - 1; + chi1 = x + (i ) * incx; + x2 = x + (i+1) * incx; + psi1 = y + (i ) * incy; + y2 = y + (i+1) * incy; + gamma11 = c + (i ) + (i )*cs_ct; + c21 = c + (i+1) + (i )*cs_ct; + + /* Compute scalars for vector subproblems. */ + PASTEMAC(d,scal2s)( alpha0, *psi1, alpha0_psi1 ); + PASTEMAC(d,scal2s)( alpha0, *chi1, alpha1_chi1 ); + + /* Compute alpha * chi1 * conj(psi1) after both chi1 + * and psi1 have + already been conjugated, if needed, by conjx and + conjy. */ + PASTEMAC(d,scal2s)( alpha0_psi1, *chi1, + alpha0_chi1_psi1 ); + + /* c21 = c21 + alpha * x2 * conj(psi1); */ + /* c21 = c21 + conj(alpha) * y2 * conj(chi1); */ + + kfp_2v + ( + conjx, + conjy, + n_ahead, + &alpha0_psi1, + &alpha1_chi1, + x2, incx, + y2, incy, + c21, rs_ct, + cntx + ); + + + PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); + PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); + } + } +} + +GENTFUNC(float, s, her2_unf_var4) +GENTFUNC(scomplex, c, her2_unf_var4) +GENTFUNC(dcomplex, z,her2_unf_var4) +#else INSERT_GENTFUNC_BASIC0( her2_unf_var4 ) +#endif From f69f59c32c0839ba67985e058006da59fd22d2bc Mon Sep 17 00:00:00 2001 From: Dipal M Zambare Date: Mon, 20 Dec 2021 09:43:13 +0530 Subject: [PATCH 084/243] Removed Arch specific code from BLIS framework. - Removed BLIS_CONFIG_EPYC macro - The code dependent on this macro is handled in one of the three ways -- It is updated to work across platforms. -- Added in architecture/feature specific runtime checks. -- Duplicated in AMD specific files. 
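As a rough illustration of the runtime-check approach (a sketch only; the
exact call sites and conditions vary per kernel), the dispatch pattern uses
the existing bli_arch_query_id() query:

    // Sketch: select the AMD-optimized path at runtime instead of
    // compiling it in unconditionally via BLIS_CONFIG_EPYC.
    arch_t id = bli_arch_query_id();

    if ( id == BLIS_ARCH_ZEN3 || id == BLIS_ARCH_ZEN2 || id == BLIS_ARCH_ZEN )
    {
        // zen-optimized code path
    }
    else
    {
        // portable fallback
    }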
Build system is updated to pick AMD specific files when library is built for any of the zen architecture AMD-Internal: [CPUPL-1960] Change-Id: I6f9f8018e41fa48eb43ae4245c9c2c361857f43b --- frame/2/gemv/bli_gemv_unf_var2.c | 2 +- frame/2/her2/bli_her2_unf_var1.c | 212 ------------------------------- frame/2/her2/bli_her2_unf_var4.c | 187 --------------------------- frame/2/trsv/bli_trsv_unf_var2.c | 2 +- 4 files changed, 2 insertions(+), 401 deletions(-) diff --git a/frame/2/gemv/bli_gemv_unf_var2.c b/frame/2/gemv/bli_gemv_unf_var2.c index 227e43ad01..d6c21de6df 100644 --- a/frame/2/gemv/bli_gemv_unf_var2.c +++ b/frame/2/gemv/bli_gemv_unf_var2.c @@ -137,4 +137,4 @@ void PASTEMAC(ch,varname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); \ } -INSERT_GENTFUNC_BASIC0( gemv_unf_var2 ) \ No newline at end of file +INSERT_GENTFUNC_BASIC0( gemv_unf_var2 ) diff --git a/frame/2/her2/bli_her2_unf_var1.c b/frame/2/her2/bli_her2_unf_var1.c index 299e3d161d..a0aec48f71 100644 --- a/frame/2/her2/bli_her2_unf_var1.c +++ b/frame/2/her2/bli_her2_unf_var1.c @@ -158,217 +158,5 @@ void PASTEMAC(ch,varname) \ } \ } - -#ifdef BLIS_CONFIG_EPYC - -/** - * Following is function declaration - * that computes her2 for transposed case. - * It handles triangular part of matrix and - * remaining computation in optimal way to - * gain performance improvement. - * a is triangular matrix, x and y are vectors - */ -void bli_dher2_trans_zen_int_4 - ( - double *a, - double *x, - double *y, - double *alpha, - dim_t m, - dim_t lda - ); - -void bli_dher2_unf_var1 - ( - uplo_t uplo, - conj_t conjx, - conj_t conjy, - conj_t conjh, - dim_t m, - double* alpha, - double* x, inc_t incx, - double* y, inc_t incy, - double* c, inc_t rs_c, inc_t cs_c, - cntx_t* cntx - ) -{ - const num_t dt = PASTEMAC(d,type); - - double* x0; - double* chi1; - double* y0; - double* psi1; - double* c10t; - double* gamma11; - double alpha0; - double alpha1; - double alpha0_chi1; - double alpha1_psi1; - double alpha0_chi1_psi1; - double conjx0_chi1; - double conjy1_psi1; - double conjy0_psi1; - dim_t i; - dim_t n_behind; - inc_t rs_ct, cs_ct; - conj_t conj0, conj1; - - /* The algorithm will be expressed in terms of the lower triangular - * case;the upper triangular case is supported by swapping the row - * and column strides of A and toggling some conj parameters. - */ - if ( bli_is_lower( uplo ) ) - { - rs_ct = rs_c; - cs_ct = cs_c; - - PASTEMAC(d,copys)( *alpha, alpha0 ); - PASTEMAC(d,copycjs)( conjh, *alpha, alpha1 ); - } - else /* if ( bli_is_upper( uplo ) ) */ - { - rs_ct = cs_c; - cs_ct = rs_c; - - /* Toggle conjugation of conjx/conjy, but only if we are being - * invoked as her2; for syr2, conjx/conjy are unchanged. - */ - conjx = bli_apply_conj( conjh, conjx ); - conjy = bli_apply_conj( conjh, conjy ); - - PASTEMAC(d,copycjs)( conjh, *alpha, alpha0 ); - PASTEMAC(d,copys)( *alpha, alpha1 ); - } - - /* Apply conjh (which carries the conjugation component of the - * Hermitian transpose, if applicable) to conjx and/or conjy as - * needed to arrive at the effective conjugation for the vector - * subproblems. - */ - conj0 = bli_apply_conj( conjh, conjy ); - conj1 = bli_apply_conj( conjh, conjx ); - - PASTECH(d,axpy2v_ker_ft) kfp_2v; - - /* Query the context for the kernel function pointer. 
*/ - kfp_2v = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPY2V_KER, cntx ); - - if( (incx == 1) && (incy == 1) && (rs_ct == 1)) - { - for ( i = 0; i < m; ) - { - n_behind = i; - x0 = x + (0 )*incx; - chi1 = x + (i )*incx; - y0 = y + (0 )*incy; - psi1 = y + (i )*incy; - c10t = c + (i )*rs_ct + (0 )*cs_ct; - gamma11 = c + (i )*rs_ct + (i )*cs_ct; - - if((n_behind >= 3)) - { - bli_dher2_trans_zen_int_4(c10t, x0, y0, &alpha0, n_behind + 1, cs_ct); - i+=4; - } - else - { - /* Apply conjx and/or conjy to chi1 and/or psi1. */ - PASTEMAC(d,copycjs)( conjx, *chi1, conjx0_chi1 ); - PASTEMAC(d,copycjs)( conjy, *psi1, conjy1_psi1 ); - PASTEMAC(d,copycjs)( conj0, *psi1, conjy0_psi1 ); - - /* Compute scalars for vector subproblems. */ - PASTEMAC(d,scal2s)( alpha0, conjx0_chi1, alpha0_chi1 ); - PASTEMAC(d,scal2s)( alpha1, conjy1_psi1, alpha1_psi1 ); - - /* Compute alpha * chi1 * conj(psi1) after both chi1 - * and psi1 have already been conjugated, if needed, - * by conjx and conjy. - */ - PASTEMAC(d,scal2s)( alpha0_chi1, conjy0_psi1, - alpha0_chi1_psi1 ); - - /* c10t = c10t + alpha * chi1 * y0'; */ - /* c10t = c10t + conj(alpha) * psi1 * x0'; */ - kfp_2v - ( - conj0, - conj1, - n_behind, - &alpha0_chi1, - &alpha1_psi1, - y0, incy, - x0, incx, - c10t, cs_ct, - cntx - ); - - /* gamma11 = gamma11 + alpha * chi1 * conj(psi1) - + conj(alpha) * psi1 * conj(chi1); */ - PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); - PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); - - i+=1; - } - } - } - else - { - for ( i = 0; i < m; ++i ) - { - n_behind = i; - x0 = x + (0 )*incx; - chi1 = x + (i )*incx; - y0 = y + (0 )*incy; - psi1 = y + (i )*incy; - c10t = c + (i )*rs_ct + (0 )*cs_ct; - gamma11 = c + (i )*rs_ct + (i )*cs_ct; - - /* Apply conjx and/or conjy to chi1 and/or psi1. */ - PASTEMAC(d,copycjs)( conjx, *chi1, conjx0_chi1 ); - PASTEMAC(d,copycjs)( conjy, *psi1, conjy1_psi1 ); - PASTEMAC(d,copycjs)( conj0, *psi1, conjy0_psi1 ); - - /* Compute scalars for vector subproblems. */ - PASTEMAC(d,scal2s)( alpha0, conjx0_chi1, alpha0_chi1 ); - PASTEMAC(d,scal2s)( alpha1, conjy1_psi1, alpha1_psi1 ); - - /* Compute alpha * chi1 * conj(psi1) after both chi1 - * and psi1 have already been conjugated, if needed, - * by conjx and conjy. - */ - PASTEMAC(d,scal2s)( alpha0_chi1, conjy0_psi1, - alpha0_chi1_psi1 ); - - /* c10t = c10t + alpha * chi1 * y0'; */ - /* c10t = c10t + conj(alpha) * psi1 * x0'; */ - kfp_2v - ( - conj0, - conj1, - n_behind, - &alpha0_chi1, - &alpha1_psi1, - y0, incy, - x0, incx, - c10t, cs_ct, - cntx - ); - - /* gamma11 = gamma11 + alpha * chi1 * conj(psi1) - + conj(alpha) * psi1 * conj(chi1); */ - PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); - PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); - - } - } -} - -GENTFUNC(float, s, her2_unf_var1) -GENTFUNC(scomplex, c, her2_unf_var1) -GENTFUNC(dcomplex, z,her2_unf_var1) -#else INSERT_GENTFUNC_BASIC0( her2_unf_var1 ) -#endif diff --git a/frame/2/her2/bli_her2_unf_var4.c b/frame/2/her2/bli_her2_unf_var4.c index e39c7224c4..3dea31d53e 100644 --- a/frame/2/her2/bli_her2_unf_var4.c +++ b/frame/2/her2/bli_her2_unf_var4.c @@ -166,192 +166,5 @@ void PASTEMAC(ch,varname) \ } \ } -#ifdef BLIS_CONFIG_EPYC -/** - * Following is function declaration - * that computes her2 for transposed case. - * It handles triangular part of matrix and - * remaining computation in optimal way to - * gain performance improvement. 
- * a is triangular matrix, x and y are vectors - */ -void bli_dher2_zen_int_4 - ( - double *a, - double *x, - double *y, - double *alpha, - dim_t m, - dim_t lda - ); - -void bli_dher2_unf_var4 - ( - uplo_t uplo, - conj_t conjx, - conj_t conjy, - conj_t conjh, - dim_t m, - double* alpha, - double* x, inc_t incx, - double* y, inc_t incy, - double* c, inc_t rs_c, inc_t cs_c, - cntx_t* cntx - ) -{ - - double* chi1; - double* x2; - double* psi1; - double* y2; - double* gamma11; - double* c21; - double alpha0; - double alpha0_psi1; - double alpha1_chi1; - double alpha0_chi1_psi1; - dim_t i; - dim_t n_ahead; - inc_t rs_ct, cs_ct; - - const num_t dt = PASTEMAC(d,type); - - /* The algorithm will be expressed in terms of the lower triangular - * case; the upper triangular case is supported by swapping the row - * and column strides of A and toggling some conj parameters. - */ - if ( bli_is_lower( uplo ) ) - { - rs_ct = rs_c; - cs_ct = cs_c; - - PASTEMAC(d,copys)( *alpha, alpha0 ); - } - else /* if ( bli_is_upper( uplo ) ) */ - { - rs_ct = cs_c; - cs_ct = rs_c; - - /* Toggle conjugation of conjx/conjy, but only if we are being - * invoked as her2; for syr2, conjx/conjy are unchanged. - */ - - PASTEMAC(d,copys)( *alpha, alpha0 ); - } - /* Apply conjh (which carries the conjugation component of the - * Hermitian transpose, if applicable) to conjx and/or conjy as - * needed to arrive at the effective conjugation for the vector - * subproblems. - */ - - PASTECH(d,axpy2v_ker_ft) kfp_2v; - - /* Query the context for the kernel function pointer. */ - kfp_2v = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPY2V_KER, cntx ); - - if((incx == 1) && (incy == 1) && (rs_ct == 1)) - { - for ( i = 0; i < m; ) - { - n_ahead = m - i - 1; - chi1 = x + (i ) * incx; - x2 = x + (i+1) * incx; - psi1 = y + (i ) * incy; - y2 = y + (i+1) * incy; - gamma11 = c + (i ) + (i )*cs_ct; - c21 = c + (i+1) + (i )*cs_ct; - - if((n_ahead >= 3)) - { - bli_dher2_zen_int_4(gamma11, chi1, psi1, &alpha0, n_ahead + 1, cs_ct); - i+= 4; - } - else - { - /* Compute scalars for vector subproblems. */ - PASTEMAC(d,scal2s)( alpha0, *psi1, alpha0_psi1 ); - PASTEMAC(d,scal2s)( alpha0, *chi1, alpha1_chi1 ); - - /* Compute alpha * chi1 * conj(psi1) after both chi1 - * and psi1 have - already been conjugated, if needed, by conjx and - conjy. */ - PASTEMAC(d,scal2s)( alpha0_psi1, *chi1, - alpha0_chi1_psi1 ); - - /* c21 = c21 + alpha * x2 * conj(psi1); */ - /* c21 = c21 + conj(alpha) * y2 * conj(chi1); */ - - kfp_2v - ( - conjx, - conjy, - n_ahead, - &alpha0_psi1, - &alpha1_chi1, - x2, incx, - y2, incy, - c21, rs_ct, - cntx - ); - - - PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); - PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); - i+=1; - } - } - } - else - { - for ( i = 0; i < m; ++i) - { - n_ahead = m - i - 1; - chi1 = x + (i ) * incx; - x2 = x + (i+1) * incx; - psi1 = y + (i ) * incy; - y2 = y + (i+1) * incy; - gamma11 = c + (i ) + (i )*cs_ct; - c21 = c + (i+1) + (i )*cs_ct; - - /* Compute scalars for vector subproblems. */ - PASTEMAC(d,scal2s)( alpha0, *psi1, alpha0_psi1 ); - PASTEMAC(d,scal2s)( alpha0, *chi1, alpha1_chi1 ); - - /* Compute alpha * chi1 * conj(psi1) after both chi1 - * and psi1 have - already been conjugated, if needed, by conjx and - conjy. 
*/ - PASTEMAC(d,scal2s)( alpha0_psi1, *chi1, - alpha0_chi1_psi1 ); - - /* c21 = c21 + alpha * x2 * conj(psi1); */ - /* c21 = c21 + conj(alpha) * y2 * conj(chi1); */ - - kfp_2v - ( - conjx, - conjy, - n_ahead, - &alpha0_psi1, - &alpha1_chi1, - x2, incx, - y2, incy, - c21, rs_ct, - cntx - ); - - - PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); - PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); - } - } -} - -GENTFUNC(float, s, her2_unf_var4) -GENTFUNC(scomplex, c, her2_unf_var4) -GENTFUNC(dcomplex, z,her2_unf_var4) -#else INSERT_GENTFUNC_BASIC0( her2_unf_var4 ) -#endif diff --git a/frame/2/trsv/bli_trsv_unf_var2.c b/frame/2/trsv/bli_trsv_unf_var2.c index c0ef6abe45..9eb02781a4 100644 --- a/frame/2/trsv/bli_trsv_unf_var2.c +++ b/frame/2/trsv/bli_trsv_unf_var2.c @@ -229,4 +229,4 @@ void PASTEMAC(ch,varname) \ } \ } -INSERT_GENTFUNC_BASIC0( trsv_unf_var2 ) \ No newline at end of file +INSERT_GENTFUNC_BASIC0( trsv_unf_var2 ) From d3b22f590fe34f64322c42196127c116099c19de Mon Sep 17 00:00:00 2001 From: "Dipal M. Zambare" Date: Fri, 11 Feb 2022 12:12:01 +0530 Subject: [PATCH 085/243] Updated version number to 3.2 Change-Id: Iea5712d8cb854d4eaffea510e0fe2d9657e4d21f --- so_version | 2 +- version | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/so_version b/so_version index b1f189286c..8efd5969fe 100644 --- a/so_version +++ b/so_version @@ -1,2 +1,2 @@ 3 -1.1 +2.0 diff --git a/version b/version index 1795fa298a..252fb77212 100644 --- a/version +++ b/version @@ -1,2 +1,2 @@ -3.1.1 +3.2.0 From 97fbff4b65753cc94ee2bfc83dfac04b88222d59 Mon Sep 17 00:00:00 2001 From: Chandrashekara K R Date: Mon, 14 Feb 2022 17:43:41 +0530 Subject: [PATCH 086/243] AOCL_Windows: Updated windows build system. Updated the windows build system to link the user given openmp library using -DOpenMP_libomp_LIBRARY= option using command line or through cmake-gui application to build blis library and its test applications. If user not given any openmp library then by default openmp library will be C:/Program Files/LLVM/lib/libomp.lib. 
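For illustration, assuming the default path noted above, the option would be supplied on the configure command line as, e.g., cmake .. -DOpenMP_libomp_LIBRARY="C:/Program Files/LLVM/lib/libomp.lib", substituting whatever libomp.lib location the user prefers; the same value can also be set through the cmake-gui application.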
Change-Id: I07542c79454496f88e65e26327ad76a7f49c7a8c --- CMakeLists.txt | 13 ++++-- test/CMakeLists.txt | 98 ++++++++++++++++++++-------------------- testsuite/CMakeLists.txt | 6 +-- 3 files changed, 60 insertions(+), 57 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 39e8a44bce..601b2dce6f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -10,7 +10,8 @@ set(CMAKE_RUNTIME_OUTPUT_DIRECTORY "${CMAKE_SOURCE_DIR}/bin") SET(AOCL_BLIS_FAMILY "zen" CACHE STRING "AOCL BLIS family name") -SET(OMP_LIB "C:\\Program Files\\LLVM\\lib\\libomp.lib" CACHE STRING "openmp library path") +SET(OpenMP_libomp_LIBRARY "C:/Program Files/LLVM/lib/libomp.lib" CACHE STRING "openmp library +path") set(TARGET_ARCH ${AOCL_BLIS_FAMILY}) set(AOCL_BLIS_ZEN TRUE) set (PYTHON_EXE "python") @@ -267,6 +268,9 @@ if(ENABLE_MULTITHREADING) find_package(OpenMP) if (OPENMP_FOUND) set(BLIS_ENABLE_OPENMP TRUE) + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") + set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${OpenMP_EXE_LINKER_FLAGS}") else() message (FATAL_ERROR "Openmp Not Found") endif() @@ -538,14 +542,12 @@ file (STRINGS "version" BLIS_VERSION) set(BLIS_VERSION_STRING ${BLIS_VERSION}) add_definitions(-DBLIS_VERSION_STRING="AOCL BLIS ${BLIS_VERSION_STRING}") -message( STATUS "OPENMP Library:" ${OMP_LIB}) - if(BUILD_SHARED_LIBS) add_library("${PROJECT_NAME}" SHARED ${CMAKE_SOURCE_DIR}/bli_config.h ${CMAKE_SOURCE_DIR}/include/${TARGET_ARCH}/blis.h ${headers}) if(ENABLE_OPENMP) - target_link_libraries("${PROJECT_NAME}" PUBLIC "${OMP_LIB}") + target_link_libraries("${PROJECT_NAME}" PRIVATE OpenMP::OpenMP_CXX) endif() target_compile_definitions("${PROJECT_NAME}" PUBLIC -DBLIS_IS_BUILDING_LIBRARY) set_target_properties("${PROJECT_NAME}" PROPERTIES LINKER_LANGUAGE C OUTPUT_NAME "${LIB_NAME}") @@ -555,9 +557,10 @@ if(NOT BUILD_SHARED_LIBS) ${CMAKE_SOURCE_DIR}/include/${TARGET_ARCH}/blis.h ${headers}) if(ENABLE_OPENMP) - set_target_properties("${PROJECT_NAME}" PROPERTIES LINKER_LANGUAGE C OUTPUT_NAME "${LIB_NAME}" STATIC_LIBRARY_OPTIONS "${OMP_LIB}") + set_target_properties("${PROJECT_NAME}" PROPERTIES LINKER_LANGUAGE C OUTPUT_NAME "${LIB_NAME}" STATIC_LIBRARY_OPTIONS "${OpenMP_libomp_LIBRARY}") else() set_target_properties("${PROJECT_NAME}" PROPERTIES LINKER_LANGUAGE C OUTPUT_NAME "${LIB_NAME}") + target_link_libraries("${PROJECT_NAME}" PRIVATE OpenMP::OpenMP_CXX) endif() endif() diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index fe8f7bac98..d116e942d0 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -1,172 +1,172 @@ -##Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved.## +##Copyright (C) 2022, Advanced Micro Devices, Inc. 
All rights reserved.## add_definitions(-DBLAS="AOCL") add_executable(TestAminv test_aminv.c) target_link_libraries(TestAminv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) - target_link_libraries(TestAminv "${OMP_LIB}") +if(ENABLE_OPENMP) + target_link_libraries(TestAminv OpenMP::OpenMP_CXX) endif() target_link_libraries(TestAminv optimized "${LIB_NAME}.lib") add_executable(TestAxpyv test_axpyv.c) target_link_libraries(TestAxpyv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) - target_link_libraries(TestAxpyv "${OMP_LIB}") +if(ENABLE_OPENMP) + target_link_libraries(TestAxpyv OpenMP::OpenMP_CXX) endif() target_link_libraries(TestAxpyv optimized "${LIB_NAME}.lib") add_executable(TestAxpbyv test_axpbyv.c) target_link_libraries(TestAxpbyv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) - target_link_libraries(TestAxpbyv "${OMP_LIB}") +if(ENABLE_OPENMP) + target_link_libraries(TestAxpbyv OpenMP::OpenMP_CXX) endif() target_link_libraries(TestAxpbyv optimized "${LIB_NAME}.lib") add_executable(TestCopyv test_copyv.c) target_link_libraries(TestCopyv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) - target_link_libraries(TestCopyv "${OMP_LIB}") +if(ENABLE_OPENMP) + target_link_libraries(TestCopyv OpenMP::OpenMP_CXX) endif() target_link_libraries(TestCopyv optimized "${LIB_NAME}.lib") add_executable(TestCabs1 test_cabs1.c) target_link_libraries(TestCabs1 debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) - target_link_libraries(TestCabs1 "${OMP_LIB}") +if(ENABLE_OPENMP) + target_link_libraries(TestCabs1 OpenMP::OpenMP_CXX) endif() target_link_libraries(TestCabs1 optimized "${LIB_NAME}.lib") add_executable(TestDotv test_dotv.c) target_link_libraries(TestDotv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) - target_link_libraries(TestDotv "${OMP_LIB}") +if(ENABLE_OPENMP) + target_link_libraries(TestDotv OpenMP::OpenMP_CXX) endif() target_link_libraries(TestDotv optimized "${LIB_NAME}.lib") add_executable(TestGemm test_gemm.c) target_link_libraries(TestGemm debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) - target_link_libraries(TestGemm "${OMP_LIB}") +if(ENABLE_OPENMP) + target_link_libraries(TestGemm OpenMP::OpenMP_CXX) endif() target_link_libraries(TestGemm optimized "${LIB_NAME}.lib") add_executable(TestGemmBatch test_gemm_batch.c) target_link_libraries(TestGemmBatch debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) - target_link_libraries(TestGemmBatch "${OMP_LIB}") +if(ENABLE_OPENMP) + target_link_libraries(TestGemmBatch OpenMP::OpenMP_CXX) endif() target_link_libraries(TestGemmBatch optimized "${LIB_NAME}.lib") add_executable(TestGemm3m test_gemm3m.c) target_link_libraries(TestGemm3m debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) - target_link_libraries(TestGemm3m "${OMP_LIB}") +if(ENABLE_OPENMP) + target_link_libraries(TestGemm3m OpenMP::OpenMP_CXX) endif() target_link_libraries(TestGemm3m optimized "${LIB_NAME}.lib") add_executable(TestGemmt test_gemmt.c) target_link_libraries(TestGemmt debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) - target_link_libraries(TestGemmt "${OMP_LIB}") +if(ENABLE_OPENMP) + target_link_libraries(TestGemmt OpenMP::OpenMP_CXX) endif() target_link_libraries(TestGemmt optimized "${LIB_NAME}.lib") add_executable(TestGemv test_gemv.c) target_link_libraries(TestGemv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) - target_link_libraries(TestGemv "${OMP_LIB}") +if(ENABLE_OPENMP) + 
target_link_libraries(TestGemv OpenMP::OpenMP_CXX) endif() target_link_libraries(TestGemv optimized "${LIB_NAME}.lib") add_executable(TestGer test_ger.c) target_link_libraries(TestGer debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) - target_link_libraries(TestGer "${OMP_LIB}") +if(ENABLE_OPENMP) + target_link_libraries(TestGer OpenMP::OpenMP_CXX) endif() target_link_libraries(TestGer optimized "${LIB_NAME}.lib") add_executable(TestHemm test_hemm.c) target_link_libraries(TestHemm debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) - target_link_libraries(TestHemm "${OMP_LIB}") +if(ENABLE_OPENMP) + target_link_libraries(TestHemm OpenMP::OpenMP_CXX) endif() target_link_libraries(TestHemm optimized "${LIB_NAME}.lib") add_executable(TestHemv test_hemv.c) target_link_libraries(TestHemv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) - target_link_libraries(TestHemv "${OMP_LIB}") +if(ENABLE_OPENMP) + target_link_libraries(TestHemv OpenMP::OpenMP_CXX) endif() target_link_libraries(TestHemv optimized "${LIB_NAME}.lib") add_executable(TestHer test_her.c) target_link_libraries(TestHer debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) - target_link_libraries(TestHer "${OMP_LIB}") +if(ENABLE_OPENMP) + target_link_libraries(TestHer OpenMP::OpenMP_CXX) endif() target_link_libraries(TestHer optimized "${LIB_NAME}.lib") add_executable(TestHer2 test_her2.c) target_link_libraries(TestHer2 debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) - target_link_libraries(TestHer2 "${OMP_LIB}") +if(ENABLE_OPENMP) + target_link_libraries(TestHer2 OpenMP::OpenMP_CXX) endif() target_link_libraries(TestHer2 optimized "${LIB_NAME}.lib") add_executable(TestHer2k test_her2k.c) target_link_libraries(TestHer2k debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) - target_link_libraries(TestHer2k "${OMP_LIB}") +if(ENABLE_OPENMP) + target_link_libraries(TestHer2k OpenMP::OpenMP_CXX) endif() target_link_libraries(TestHer2k optimized "${LIB_NAME}.lib") add_executable(TestHerk test_herk.c) target_link_libraries(TestHerk debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) - target_link_libraries(TestHerk "${OMP_LIB}") +if(ENABLE_OPENMP) + target_link_libraries(TestHerk OpenMP::OpenMP_CXX) endif() target_link_libraries(TestHerk optimized "${LIB_NAME}.lib") add_executable(TestScalv test_scalv.c) target_link_libraries(TestScalv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) - target_link_libraries(TestScalv "${OMP_LIB}") +if(ENABLE_OPENMP) + target_link_libraries(TestScalv OpenMP::OpenMP_CXX) endif() target_link_libraries(TestScalv optimized "${LIB_NAME}.lib") add_executable(TestSwapv test_swapv.c) target_link_libraries(TestSwapv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) - target_link_libraries(TestSwapv "${OMP_LIB}") +if(ENABLE_OPENMP) + target_link_libraries(TestSwapv OpenMP::OpenMP_CXX) endif() target_link_libraries(TestSwapv optimized "${LIB_NAME}.lib") add_executable(TestTrmm test_trmm.c) target_link_libraries(TestTrmm debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) - target_link_libraries(TestTrmm "${OMP_LIB}") +if(ENABLE_OPENMP) + target_link_libraries(TestTrmm OpenMP::OpenMP_CXX) endif() target_link_libraries(TestTrmm optimized "${LIB_NAME}.lib") add_executable(TestTrmv test_trmv.c) target_link_libraries(TestTrmv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) - target_link_libraries(TestTrmv "${OMP_LIB}") +if(ENABLE_OPENMP) + 
target_link_libraries(TestTrmv OpenMP::OpenMP_CXX) endif() target_link_libraries(TestTrmv optimized "${LIB_NAME}.lib") add_executable(TestTrsm test_trsm.c) target_link_libraries(TestTrsm debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) - target_link_libraries(TestTrsm "${OMP_LIB}") +if(ENABLE_OPENMP) + target_link_libraries(TestTrsm OpenMP::OpenMP_CXX) endif() target_link_libraries(TestTrsm optimized "${LIB_NAME}.lib") add_executable(TestTrsv test_trsv.c) target_link_libraries(TestTrsv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) - target_link_libraries(TestTrsv "${OMP_LIB}") +if(ENABLE_OPENMP) + target_link_libraries(TestTrsv OpenMP::OpenMP_CXX) endif() target_link_libraries(TestTrsv optimized "${LIB_NAME}.lib") diff --git a/testsuite/CMakeLists.txt b/testsuite/CMakeLists.txt index 613f9e3861..85866926dd 100644 --- a/testsuite/CMakeLists.txt +++ b/testsuite/CMakeLists.txt @@ -1,4 +1,4 @@ -##Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved.## +##Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.## include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src) @@ -7,8 +7,8 @@ add_executable(test_libblis "") add_subdirectory(src) target_link_libraries(test_libblis debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) - target_link_libraries(test_libblis "${OMP_LIB}") +if(ENABLE_OPENMP) + target_link_libraries(test_libblis OpenMP::OpenMP_CXX) endif() target_link_libraries(test_libblis optimized "${LIB_NAME}.lib") From 393effbb0c1474aff9e2504258e70b5cd4994424 Mon Sep 17 00:00:00 2001 From: Arnav Sharma Date: Tue, 22 Feb 2022 12:08:53 +0530 Subject: [PATCH 087/243] Optimized ZAXPY2V using AVX2 Intrinsics Details: - Intrinsic implementation of ZAXPY2V fused kernel for AVX2 - Updated definitions in zen contexts AMD-Internal: [CPUPL-2023] Change-Id: I8889ae08c826d26e66ae607c416c4282136937fa --- config/zen/bli_cntx_init_zen.c | 3 +- config/zen2/bli_cntx_init_zen2.c | 3 +- config/zen3/bli_cntx_init_zen3.c | 3 +- kernels/zen/1f/bli_axpy2v_zen_int.c | 533 ++++++++++++++++++++++++++++ kernels/zen/bli_kernels_zen.h | 1 + 5 files changed, 540 insertions(+), 3 deletions(-) diff --git a/config/zen/bli_cntx_init_zen.c b/config/zen/bli_cntx_init_zen.c index eed39b3149..1badc24f96 100644 --- a/config/zen/bli_cntx_init_zen.c +++ b/config/zen/bli_cntx_init_zen.c @@ -80,7 +80,7 @@ void bli_cntx_init_zen( cntx_t* cntx ) // Update the context with optimized level-1f kernels. bli_cntx_set_l1f_kers ( - 9, + 10, // axpyf BLIS_AXPYF_KER, BLIS_FLOAT, bli_saxpyf_zen_int_8, BLIS_AXPYF_KER, BLIS_DOUBLE, bli_daxpyf_zen_int_8, @@ -93,6 +93,7 @@ void bli_cntx_init_zen( cntx_t* cntx ) BLIS_DOTXF_KER, BLIS_SCOMPLEX, bli_cdotxf_zen_int_6, //axpy2v BLIS_AXPY2V_KER, BLIS_DOUBLE, bli_daxpy2v_zen_int, + BLIS_AXPY2V_KER, BLIS_DCOMPLEX, bli_zaxpy2v_zen_int, cntx ); diff --git a/config/zen2/bli_cntx_init_zen2.c b/config/zen2/bli_cntx_init_zen2.c index f6b8eef1e4..997ccdba2e 100644 --- a/config/zen2/bli_cntx_init_zen2.c +++ b/config/zen2/bli_cntx_init_zen2.c @@ -92,7 +92,7 @@ void bli_cntx_init_zen2( cntx_t* cntx ) // Update the context with optimized level-1f kernels. 
bli_cntx_set_l1f_kers ( - 9, + 10, // axpyf BLIS_AXPYF_KER, BLIS_FLOAT, bli_saxpyf_zen_int_5, BLIS_AXPYF_KER, BLIS_DOUBLE, bli_daxpyf_zen_int_5, @@ -105,6 +105,7 @@ void bli_cntx_init_zen2( cntx_t* cntx ) BLIS_DOTXF_KER, BLIS_SCOMPLEX, bli_cdotxf_zen_int_6, // axpy2v BLIS_AXPY2V_KER, BLIS_DOUBLE, bli_daxpy2v_zen_int, + BLIS_AXPY2V_KER, BLIS_DCOMPLEX, bli_zaxpy2v_zen_int, cntx ); diff --git a/config/zen3/bli_cntx_init_zen3.c b/config/zen3/bli_cntx_init_zen3.c index a043d5ad22..61fefdbc31 100644 --- a/config/zen3/bli_cntx_init_zen3.c +++ b/config/zen3/bli_cntx_init_zen3.c @@ -92,7 +92,7 @@ void bli_cntx_init_zen3( cntx_t* cntx ) // Update the context with optimized level-1f kernels. bli_cntx_set_l1f_kers ( - 9, + 10, // axpyf BLIS_AXPYF_KER, BLIS_FLOAT, bli_saxpyf_zen_int_5, BLIS_AXPYF_KER, BLIS_DOUBLE, bli_daxpyf_zen_int_5, @@ -105,6 +105,7 @@ void bli_cntx_init_zen3( cntx_t* cntx ) BLIS_DOTXF_KER, BLIS_SCOMPLEX, bli_cdotxf_zen_int_6, // axpy2v BLIS_AXPY2V_KER, BLIS_DOUBLE, bli_daxpy2v_zen_int, + BLIS_AXPY2V_KER, BLIS_DCOMPLEX, bli_zaxpy2v_zen_int, cntx ); diff --git a/kernels/zen/1f/bli_axpy2v_zen_int.c b/kernels/zen/1f/bli_axpy2v_zen_int.c index 4ddca52162..cba0141376 100644 --- a/kernels/zen/1f/bli_axpy2v_zen_int.c +++ b/kernels/zen/1f/bli_axpy2v_zen_int.c @@ -186,3 +186,536 @@ void bli_daxpy2v_zen_int ); } } + +/** + * zaxpy2v kernel performs axpy2v operation. + * z := z + alphax * conjx(x) + alphay * conjy(y) + * where, + * x, y & z are double complex vectors of length n. + * alpha & beta are complex scalers. + */ +void bli_zaxpy2v_zen_int + ( + conj_t conjx, + conj_t conjy, + dim_t n, + dcomplex* restrict alphax, + dcomplex* restrict alphay, + dcomplex* restrict x, inc_t incx, + dcomplex* restrict y, inc_t incy, + dcomplex* restrict z, inc_t incz, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_4) + + // If the vectors are empty or if both alpha are zero, return early + if ( ( bli_zero_dim1( n ) ) || + ( PASTEMAC(z,eq0)( *alphax ) && PASTEMAC(z,eq0)( *alphay ) ) ) { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) + return; + } + + const dim_t n_elem_per_reg = 4; // Number of elements per register + + dim_t i = 0; // Iterator + + double* restrict x0; + double* restrict y0; + double* restrict z0; + double* restrict alphax0; + double* restrict alphay0; + + // Initialize local pointers. 
+ x0 = (double*) x; + y0 = (double*) y; + z0 = (double*) z; + alphax0 = (double*) alphax; + alphay0 = (double*) alphay; + + if ( incx == 1 && incy == 1 && incz == 1 ) + { + //---------- Scalar algorithm BLIS_NO_CONJUGATE ------------- + // + // z = z + alphax * x + alphay * y + // z = ( zR + izI ) + + // ( axR + iaxI ) * ( xR + ixI ) + + // ( ayR + iayI ) * ( yR + iyI ) + // z = ( zR + izI ) + + // ( axR.xR + iaxR.xI + iaxI.xR - axI.xI ) + + // ( xyR.yR + iayR.yI + iayI.yR - ayI.yI ) + // z = ( zR + izI ) + + // ( ( axR.xR - axI.xI ) + i( axR.xI + axI.xR ) ) + + // ( ( ayR.yR - ayI.yI ) + i( ayR.yI + ayI.yR ) ) + // z = ( zR + axR.xR - axI.xI + ayR.yR - ayI.yI ) + + // i( zI + axR.xI + axI.xR + ayR.yI + ayI.yR ) + // + // SIMD Algorithm BLIS_NO_CONJUGATE + // xv = xR0 xI0 xR1 xI1 + // xv' = xI0 xR0 xI1 xR1 + // yv = yR0 yI0 yR1 yI1 + // yv' = yI0 yR0 yI1 yR1 + // zv = zR0 zI0 zR1 zI1 + // zv' = zI0 zR0 zI1 zR1 + // axrv = axR axR axR axR + // axiv = -axI axI -axI axI + // ayrv = ayR ayR ayR ayR + // ayiv = -ayI ayI -ayI ayI + // + // step 1: FMA zv = zv + axrv * xv + // step 2: shuffle xv -> xv' + // step 3: FMA zv = zv + axiv * xv' + // step 4: FMA zv = zv + ayrv * yv + // step 5: shuffle yv -> xyv' + // step 6: FMA zv = zv + ayiv * yv' + + //---------- Scalar algorithm BLIS_CONJUGATE ------------- + // + // z = z + alphax * x + alphay * y + // z = ( zR + izI ) + + // ( axR + iaxI ) * ( xR - ixI ) + + // ( ayR + iayI ) * ( yR - iyI ) + // z = ( zR + izI ) + + // ( axR.xR - iaxR.xI + iaxI.xR + axI.xI ) + + // ( xyR.yR - iayR.yI + iayI.yR + ayI.yI ) + // z = ( zR + izI ) + + // ( ( axR.xR + axI.xI ) + i( -axR.xI + axI.xR ) ) + + // ( ( ayR.yR + ayI.yI ) + i( -ayR.yI + ayI.yR ) ) + // z = ( zR + axR.xR + axI.xI + ayR.yR + ayI.yI ) + + // i( zI - axR.xI + axI.xR - ayR.yI + ayI.yR ) + // + // SIMD Algorithm BLIS_CONJUGATE + // xv = xR0 xI0 xR1 xI1 + // xv' = xI0 xR0 xI1 xR1 + // yv = yR0 yI0 yR1 yI1 + // yv' = yI0 yR0 yI1 yR1 + // zv = zR0 zI0 zR1 zI1 + // zv' = zI0 zR0 zI1 zR1 + // axrv = axR -axR axR -axR + // axiv = axI axI axI axI + // ayrv = ayR -ayR ayR -ayR + // ayiv = ayI ayI ayI ayI + // + // step 1: FMA zv = zv + axrv * xv + // step 2: shuffle xv -> xv' + // step 3: FMA zv = zv + axiv * xv' + // step 4: FMA zv = zv + ayrv * yv + // step 5: shuffle yv -> xyv' + // step 6: FMA zv = zv + ayiv * yv' + + __m256d alphaxRv; + __m256d alphaxIv; + __m256d alphayRv; + __m256d alphayIv; + __m256d xv[4]; + __m256d yv[4]; + __m256d zv[4]; + + double alphaxR, alphaxI; + double alphayR, alphayI; + + alphaxR = alphax->real; + alphaxI = alphax->imag; + alphayR = alphay->real; + alphayI = alphay->imag; + + // Broadcast alphax & alphay to respective vector registers + if ( !bli_is_conj( conjx ) ) // If not x conjugate + { + // alphaxRv = axR axR axR axR + // alphaxIv = -axI axI -axI axI + alphaxRv = _mm256_broadcast_sd( &alphaxR ); + alphaxIv = _mm256_set_pd( alphaxI, -alphaxI, alphaxI, -alphaxI ); + } + else + { + // alphaxRv = axR -axR axR -axR + // alphaxIv = axI axI axI axI + alphaxRv = _mm256_set_pd( -alphaxR, alphaxR, -alphaxR, alphaxR ); + alphaxIv = _mm256_broadcast_sd( &alphaxI ); + } + + if ( !bli_is_conj( conjy ) ) // If not y conjugate + { + // alphayRv = ayR ayR ayR ayR + // alphayIv = -ayI ayI -ayI ayI + alphayRv = _mm256_broadcast_sd( &alphayR ); + alphayIv = _mm256_set_pd( alphayI, -alphayI, alphayI, -alphayI ); + } + else + { + // alphayRv = ayR -ayR ayR -ayR + // alphayIv = ayI ayI ayI ayI + alphayRv = _mm256_set_pd( -alphayR, alphayR, -alphayR, alphayR ); + alphayIv = 
_mm256_broadcast_sd( &alphayI ); + } + + // Processing 8 elements per loop, 16 FMAs + for ( ; ( i + 7 ) < n; i += 8 ) + { + // Loading x vector + // xv = xR0 xI0 xR1 xI1 + xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); + xv[2] = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); + xv[3] = _mm256_loadu_pd( x0 + 3*n_elem_per_reg ); + + // Loading y vector + // yv = yR0 yI0 yR1 yI1 + yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + yv[1] = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + yv[2] = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); + yv[3] = _mm256_loadu_pd( y0 + 3*n_elem_per_reg ); + + // Loading z vector + // zv = zR0 zI0 zR1 zI1 + zv[0] = _mm256_loadu_pd( z0 + 0*n_elem_per_reg ); + zv[1] = _mm256_loadu_pd( z0 + 1*n_elem_per_reg ); + zv[2] = _mm256_loadu_pd( z0 + 2*n_elem_per_reg ); + zv[3] = _mm256_loadu_pd( z0 + 3*n_elem_per_reg ); + + // zv = zv + alphaxRv * xv + // zv = zR0 + axR.xR0, zI0 + axR.xI0, ... + zv[0] = _mm256_fmadd_pd( xv[0], alphaxRv, zv[0] ); + zv[1] = _mm256_fmadd_pd( xv[1], alphaxRv, zv[1] ); + zv[2] = _mm256_fmadd_pd( xv[2], alphaxRv, zv[2] ); + zv[3] = _mm256_fmadd_pd( xv[3], alphaxRv, zv[3] ); + + // Shuffling xv + // xv = xI0 xR0 xI1 xR1 + xv[0] = _mm256_permute_pd( xv[0], 5 ); + xv[1] = _mm256_permute_pd( xv[1], 5 ); + xv[2] = _mm256_permute_pd( xv[2], 5 ); + xv[3] = _mm256_permute_pd( xv[3], 5 ); + + // zv = zv + alphaxIv * xv + // zv = zR0 + axR.xR0 - axI.xI0, zI0 + axR.xI0 + axI.xR0, ... + zv[0] = _mm256_fmadd_pd( xv[0], alphaxIv, zv[0] ); + zv[1] = _mm256_fmadd_pd( xv[1], alphaxIv, zv[1] ); + zv[2] = _mm256_fmadd_pd( xv[2], alphaxIv, zv[2] ); + zv[3] = _mm256_fmadd_pd( xv[3], alphaxIv, zv[3] ); + + // zv = zv + alphayRv * yv + // zv = zR0 + axR.xR0 - axI.xI0 + ayR.yR0, + // zI0 + axR.xI0 + axI.xR0 + ayR.yI0, ... + zv[0] = _mm256_fmadd_pd( yv[0], alphayRv, zv[0] ); + zv[1] = _mm256_fmadd_pd( yv[1], alphayRv, zv[1] ); + zv[2] = _mm256_fmadd_pd( yv[2], alphayRv, zv[2] ); + zv[3] = _mm256_fmadd_pd( yv[3], alphayRv, zv[3] ); + + // Shuffling yv + // yv = yI0 yR0 yI1 yR1 + yv[0] = _mm256_permute_pd( yv[0], 5 ); + yv[1] = _mm256_permute_pd( yv[1], 5 ); + yv[2] = _mm256_permute_pd( yv[2], 5 ); + yv[3] = _mm256_permute_pd( yv[3], 5 ); + + // zv = zv + alphayIv * yv + // zv = zR0 + axR.xR0 - axI.xI0 + ayR.yR0 - ayI.yI0, + // zI0 + axR.xI0 + axI.xR0 + ayR.yI0 + ayI.yR0, ... + zv[0] = _mm256_fmadd_pd( yv[0], alphayIv, zv[0] ); + zv[1] = _mm256_fmadd_pd( yv[1], alphayIv, zv[1] ); + zv[2] = _mm256_fmadd_pd( yv[2], alphayIv, zv[2] ); + zv[3] = _mm256_fmadd_pd( yv[3], alphayIv, zv[3] ); + + // Storing results from zv + _mm256_storeu_pd( (z0 + 0*n_elem_per_reg), zv[0] ); + _mm256_storeu_pd( (z0 + 1*n_elem_per_reg), zv[1] ); + _mm256_storeu_pd( (z0 + 2*n_elem_per_reg), zv[2] ); + _mm256_storeu_pd( (z0 + 3*n_elem_per_reg), zv[3] ); + + x0 += 4*n_elem_per_reg; + y0 += 4*n_elem_per_reg; + z0 += 4*n_elem_per_reg; + } + + // Processing 4 elements per loop, 8 FMAs + for ( ; ( i + 3 ) < n; i += 4 ) + { + // Loading x vector + // xv = xR0 xI0 xR1 xI1 + xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); + + // Loading y vector + // yv = yR0 yI0 yR1 yI1 + yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + yv[1] = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + + // Loading z vector + // zv = zR0 zI0 zR1 zI1 + zv[0] = _mm256_loadu_pd( z0 + 0*n_elem_per_reg ); + zv[1] = _mm256_loadu_pd( z0 + 1*n_elem_per_reg ); + + // zv = zv + alphaxRv * xv + // zv = zR0 + axR.xR0, zI0 + axR.xI0, ... 
+ zv[0] = _mm256_fmadd_pd( xv[0], alphaxRv, zv[0] ); + zv[1] = _mm256_fmadd_pd( xv[1], alphaxRv, zv[1] ); + + // Shuffling xv + // xv = xI0 xR0 xI1 xR1 + xv[0] = _mm256_permute_pd( xv[0], 5 ); + xv[1] = _mm256_permute_pd( xv[1], 5 ); + + // zv = zv + alphaxIv * xv + // zv = zR0 + axR.xR0 - axI.xI0, zI0 + axR.xI0 + axI.xR0, ... + zv[0] = _mm256_fmadd_pd( xv[0], alphaxIv, zv[0] ); + zv[1] = _mm256_fmadd_pd( xv[1], alphaxIv, zv[1] ); + + // zv = zv + alphayRv * yv + // zv = zR0 + axR.xR0 - axI.xI0 + ayR.yR0, + // zI0 + axR.xI0 + axI.xR0 + ayR.yI0, ... + zv[0] = _mm256_fmadd_pd( yv[0], alphayRv, zv[0] ); + zv[1] = _mm256_fmadd_pd( yv[1], alphayRv, zv[1] ); + + // Shuffling yv + // yv = yI0 yR0 yI1 yR1 + yv[0] = _mm256_permute_pd( yv[0], 5 ); + yv[1] = _mm256_permute_pd( yv[1], 5 ); + + // zv = zv + alphayIv * yv + // zv = zR0 + axR.xR0 - axI.xI0 + ayR.yR0 - ayI.yI0, + // zI0 + axR.xI0 + axI.xR0 + ayR.yI0 + ayI.yR0, ... + zv[0] = _mm256_fmadd_pd( yv[0], alphayIv, zv[0] ); + zv[1] = _mm256_fmadd_pd( yv[1], alphayIv, zv[1] ); + + // Storing results from zv + _mm256_storeu_pd( (z0 + 0*n_elem_per_reg), zv[0] ); + _mm256_storeu_pd( (z0 + 1*n_elem_per_reg), zv[1] ); + + x0 += 2*n_elem_per_reg; + y0 += 2*n_elem_per_reg; + z0 += 2*n_elem_per_reg; + } + + // Processing 2 elements per loop, 4FMAs + for ( ; ( i + 1 ) < n; i += 2 ) + { + // Loading x vector + // xv = xR0 xI0 xR1 xI1 + xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + + // Loading y vector + // yv = yR0 yI0 yR1 yI1 + yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + + // Loading z vector + // zv = zR0 zI0 zR1 zI1 + zv[0] = _mm256_loadu_pd( z0 + 0*n_elem_per_reg ); + + // zv = zv + alphaxRv * xv + // zv = zR0 + axR.xR0, zI0 + axR.xI0, ... + zv[0] = _mm256_fmadd_pd( xv[0], alphaxRv, zv[0] ); + + // Shuffling xv + // xv = xI0 xR0 xI1 xR1 + xv[0] = _mm256_permute_pd( xv[0], 5 ); + + // zv = zv + alphaxIv * xv + // zv = zR0 + axR.xR0 - axI.xI0, zI0 + axR.xI0 + axI.xR0, ... + zv[0] = _mm256_fmadd_pd( xv[0], alphaxIv, zv[0] ); + + // zv = zv + alphayRv * yv + // zv = zR0 + axR.xR0 - axI.xI0 + ayR.yR0, + // zI0 + axR.xI0 + axI.xR0 + ayR.yI0, ... + zv[0] = _mm256_fmadd_pd( yv[0], alphayRv, zv[0] ); + + // Shuffling yv + // yv = yI0 yR0 yI1 yR1 + yv[0] = _mm256_permute_pd( yv[0], 5 ); + + // zv = zv + alphayIv * yv + // zv = zR0 + axR.xR0 - axI.xI0 + ayR.yR0 - ayI.yI0, + // zI0 + axR.xI0 + axI.xR0 + ayR.yI0 + ayI.yR0, ... + zv[0] = _mm256_fmadd_pd( yv[0], alphayIv, zv[0] ); + + // Storing results from zv + _mm256_storeu_pd( (z0 + 0*n_elem_per_reg), zv[0] ); + + x0 += 1*n_elem_per_reg; + y0 += 1*n_elem_per_reg; + z0 += 1*n_elem_per_reg; + } + + // Issue vzeroupper instruction to clear upper lanes of ymm registers. + // This avoids a performance penalty caused by false dependencies when + // transitioning from from AVX to SSE instructions (which may occur + // as soon as the n_left cleanup loop below if BLIS is compiled with + // -mfpmath=sse). 
+ _mm256_zeroupper(); + + if ( !bli_is_conj( conjx ) && !bli_is_conj( conjy ) ) + { + for ( ; i < n; i++ ) + { + // zR += ( axR.xR - axI.xI + ayR.yR - ayI.yI ) + *z0 += (*alphax0) * (*x0) - + (*(alphax0 + 1)) * (*(x0 + 1)) + + (*alphay0) * (*y0) - + (*(alphay0 + 1)) * (*(y0 + 1)); + + // zI += ( axR.xI + axI.xR + ayR.yI + ayI.yR ) + *(z0 + 1) += (*alphax0) * (*(x0 + 1)) + + (*(alphax0 + 1)) * (*x0) + + (*alphay0) * (*(y0 + 1)) + + (*(alphay0 + 1)) * (*y0); + + x0 += 2; + y0 += 2; + z0 += 2; + } + } + else if ( !bli_is_conj( conjx ) && bli_is_conj( conjy ) ) + { + for ( ; i < n; i++ ) + { + // zR += ( axR.xR - axI.xI + ayR.yR + ayI.yI ) + *z0 += (*alphax0) * (*x0) - + (*(alphax0 + 1)) * (*(x0 + 1)) + + (*alphay0) * (*y0) + + (*(alphay0 + 1)) * (*(y0 + 1)); + + // zI += ( axR.xI + axI.xR + ayR.yI - ayI.yR ) + *(z0 + 1) += (*alphax0) * (*(x0 + 1)) + + (*(alphax0 + 1)) * (*x0) + + (*(alphay0 + 1)) * (*y0) - + (*alphay0) * (*(y0 + 1)); + + x0 += 2; + y0 += 2; + z0 += 2; + } + } + else if ( bli_is_conj( conjx ) && !bli_is_conj( conjy ) ) + { + for ( ; i < n; i++ ) + { + // zR += ( axR.xR + axI.xI + ayR.yR - ayI.yI ) + *z0 += (*alphax0) * (*x0) + + (*(alphax0 + 1)) * (*(x0 + 1)) + + (*alphay0) * (*y0) - + (*(alphay0 + 1)) * (*(y0 + 1)); + + // zI += ( axR.xI - axI.xR + ayR.yI + ayI.yR ) + *(z0 + 1) += (*(alphax0 + 1)) * (*x0) - + (*alphax0) * (*(x0 + 1)) + + (*alphay0) * (*(y0 + 1)) + + (*(alphay0 + 1)) * (*y0); + + x0 += 2; + y0 += 2; + z0 += 2; + } + } + else + { + for ( ; i < n; i++ ) + { + // zR += ( axR.xR + axI.xI + ayR.yR + ayI.yI ) + *z0 += (*alphax0) * (*x0) + + (*(alphax0 + 1)) * (*(x0 + 1)) + + (*alphay0) * (*y0) + + (*(alphay0 + 1)) * (*(y0 + 1)); + + // zI += ( axR.xI - axI.xR + ayR.yI - ayI.yR ) + *(z0 + 1) += (*(alphax0 + 1)) * (*x0) - + (*alphax0) * (*(x0 + 1)) + + (*(alphay0 + 1)) * (*y0) - + (*alphay0) * (*(y0 + 1)); + + x0 += 2; + y0 += 2; + z0 += 2; + } + } + } + else + { + // Using scalar code for non-unit increments + if ( !bli_is_conj( conjx ) && !bli_is_conj( conjy ) ) + { + for ( ; i < n; i++ ) + { + // zR += ( axR.xR - axI.xI + ayR.yR - ayI.yI ) + *z0 += (*alphax0) * (*x0) - + (*(alphax0 + 1)) * (*(x0 + 1)) + + (*alphay0) * (*y0) - + (*(alphay0 + 1)) * (*(y0 + 1)); + + // zI += ( axR.xI + axI.xR + ayR.yI + ayI.yR ) + *(z0 + 1) += (*alphax0) * (*(x0 + 1)) + + (*(alphax0 + 1)) * (*x0) + + (*alphay0) * (*(y0 + 1)) + + (*(alphay0 + 1)) * (*y0); + + x0 += 2 * incx; + y0 += 2 * incy; + z0 += 2 * incz; + } + } + else if ( !bli_is_conj( conjx ) && bli_is_conj( conjy ) ) + { + for ( ; i < n; i++ ) + { + // zR += ( axR.xR - axI.xI + ayR.yR + ayI.yI ) + *z0 += (*alphax0) * (*x0) - + (*(alphax0 + 1)) * (*(x0 + 1)) + + (*alphay0) * (*y0) + + (*(alphay0 + 1)) * (*(y0 + 1)); + + // zI += ( axR.xI + axI.xR + ayR.yI - ayI.yR ) + *(z0 + 1) += (*alphax0) * (*(x0 + 1)) + + (*(alphax0 + 1)) * (*x0) + + (*(alphay0 + 1)) * (*y0) - + (*alphay0) * (*(y0 + 1)); + + x0 += 2 * incx; + y0 += 2 * incy; + z0 += 2 * incz; + } + } + else if ( bli_is_conj( conjx ) && !bli_is_conj( conjy ) ) + { + for ( ; i < n; i++ ) + { + // zR += ( axR.xR + axI.xI + ayR.yR - ayI.yI ) + *z0 += (*alphax0) * (*x0) + + (*(alphax0 + 1)) * (*(x0 + 1)) + + (*alphay0) * (*y0) - + (*(alphay0 + 1)) * (*(y0 + 1)); + + // zI += ( axR.xI - axI.xR + ayR.yI + ayI.yR ) + *(z0 + 1) += (*(alphax0 + 1)) * (*x0) - + (*alphax0) * (*(x0 + 1)) + + (*alphay0) * (*(y0 + 1)) + + (*(alphay0 + 1)) * (*y0); + + x0 += 2 * incx; + y0 += 2 * incy; + z0 += 2 * incz; + } + } + else + { + for ( ; i < n; i++ ) + { + // zR += ( axR.xR + axI.xI + ayR.yR + 
ayI.yI ) + *z0 += (*alphax0) * (*x0) + + (*(alphax0 + 1)) * (*(x0 + 1)) + + (*alphay0) * (*y0) + + (*(alphay0 + 1)) * (*(y0 + 1)); + + // zI += ( axR.xI - axI.xR + ayR.yI - ayI.yR ) + *(z0 + 1) += (*(alphax0 + 1)) * (*x0) - + (*alphax0) * (*(x0 + 1)) + + (*(alphay0 + 1)) * (*y0) - + (*alphay0) * (*(y0 + 1)); + + x0 += 2 * incx; + y0 += 2 * incy; + z0 += 2 * incz; + } + } + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) +} \ No newline at end of file diff --git a/kernels/zen/bli_kernels_zen.h b/kernels/zen/bli_kernels_zen.h index c092ab3ab7..bb7b57d096 100644 --- a/kernels/zen/bli_kernels_zen.h +++ b/kernels/zen/bli_kernels_zen.h @@ -127,6 +127,7 @@ AXPYF_KER_PROT( dcomplex, z, axpyf_zen_int_5 ) AXPYF_KER_PROT( dcomplex, z, axpyf_zen_int_4 ) // axpy2v (intrinsics) AXPY2V_KER_PROT(double, d, axpy2v_zen_int ) +AXPY2V_KER_PROT(dcomplex, z, axpy2v_zen_int ) // dotxf (intrinsics) DOTXF_KER_PROT( float, s, dotxf_zen_int_8 ) From e12f45033d9c3a9cf05c1f6847d4cb6c288f21bb Mon Sep 17 00:00:00 2001 From: Chandrashekara K R Date: Wed, 23 Feb 2022 13:11:54 +0530 Subject: [PATCH 088/243] AOCL_Windows: Updated windows build system. Removed the "target_link_libraries("${PROJECT_NAME}" PRIVATE OpenMP::OpenMP_CXX)" statement for the static ST library builb. This statement is not needed for static ST library build, mistakenly added. Change-Id: I577a28c75644043fd077d938bf7f51cdea8ee13d --- CMakeLists.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 601b2dce6f..7fc9d272e0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -560,7 +560,6 @@ if(NOT BUILD_SHARED_LIBS) set_target_properties("${PROJECT_NAME}" PROPERTIES LINKER_LANGUAGE C OUTPUT_NAME "${LIB_NAME}" STATIC_LIBRARY_OPTIONS "${OpenMP_libomp_LIBRARY}") else() set_target_properties("${PROJECT_NAME}" PROPERTIES LINKER_LANGUAGE C OUTPUT_NAME "${LIB_NAME}") - target_link_libraries("${PROJECT_NAME}" PRIVATE OpenMP::OpenMP_CXX) endif() endif() From d50d607995afe8d682933478383d31e62e4dc9bb Mon Sep 17 00:00:00 2001 From: Harsh Dave Date: Wed, 2 Mar 2022 04:08:26 -0600 Subject: [PATCH 089/243] dher2 API in blis make check fails on non avx2 platform - dher2 did not have avx check for platform. It was calling avx kernel regardless of platform support. Which resulted in core dump. - Added avx based platform check in both variant of dher2 for fixing the issue. AMD-Internal: [CPUPL-2043] Change-Id: I1fd1dcc9336980bfb7ffa9376f491f107c889c0b --- frame/2/her2/bli_her2_unf_var1_amd.c | 64 ++++++++++++++++++---------- frame/2/her2/bli_her2_unf_var4_amd.c | 39 +++++++++++------ 2 files changed, 67 insertions(+), 36 deletions(-) diff --git a/frame/2/her2/bli_her2_unf_var1_amd.c b/frame/2/her2/bli_her2_unf_var1_amd.c index 43a74f49cd..31667cc3e4 100644 --- a/frame/2/her2/bli_her2_unf_var1_amd.c +++ b/frame/2/her2/bli_her2_unf_var1_amd.c @@ -249,9 +249,13 @@ void bli_dher2_unf_var1 PASTECH(d,axpy2v_ker_ft) kfp_2v; /* Query the context for the kernel function pointer. 
*/ - kfp_2v = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPY2V_KER, cntx ); + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); + kfp_2v = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPY2V_KER, cntx ); - if( (incx == 1) && (incy == 1) && (rs_ct == 1)) + if ( (bli_cpuid_is_avx_supported() == TRUE) + && (incx == 1) + && (incy == 1) + && (rs_ct == 1)) { for ( i = 0; i < m; ) { @@ -265,29 +269,43 @@ void bli_dher2_unf_var1 if((n_behind >= 3)) { - bli_dher2_trans_zen_int_4(c10t, x0, y0, &alpha0, n_behind + 1, cs_ct); + bli_dher2_trans_zen_int_4(c10t, x0, y0, + &alpha0, + n_behind + 1, + cs_ct); i+=4; } else { - /* Apply conjx and/or conjy to chi1 and/or psi1. */ - PASTEMAC(d,copycjs)( conjx, *chi1, conjx0_chi1 ); - PASTEMAC(d,copycjs)( conjy, *psi1, conjy1_psi1 ); - PASTEMAC(d,copycjs)( conj0, *psi1, conjy0_psi1 ); - - /* Compute scalars for vector subproblems. */ - PASTEMAC(d,scal2s)( alpha0, conjx0_chi1, alpha0_chi1 ); - PASTEMAC(d,scal2s)( alpha1, conjy1_psi1, alpha1_psi1 ); - - /* Compute alpha * chi1 * conj(psi1) after both chi1 - * and psi1 have already been conjugated, if needed, + /* Apply conjx and/or conjy to chi1 + * and/or psi1. */ + PASTEMAC(d,copycjs)( conjx, *chi1, + conjx0_chi1 ); + PASTEMAC(d,copycjs)( conjy, *psi1, + conjy1_psi1 ); + PASTEMAC(d,copycjs)( conj0, *psi1, + conjy0_psi1 ); + + /* Compute scalars for vector + * subproblems. */ + PASTEMAC(d,scal2s)( alpha0, + conjx0_chi1, + alpha0_chi1 ); + PASTEMAC(d,scal2s)( alpha1, + conjy1_psi1, + alpha1_psi1 ); + + /* Compute alpha * chi1 * conj(psi1) + * after both chi1 and psi1 have + * already been conjugated, if needed * by conjx and conjy. */ - PASTEMAC(d,scal2s)( alpha0_chi1, conjy0_psi1, - alpha0_chi1_psi1 ); + PASTEMAC(d,scal2s)( alpha0_chi1, + conjy0_psi1, + alpha0_chi1_psi1 ); - /* c10t = c10t + alpha * chi1 * y0'; */ - /* c10t = c10t + conj(alpha) * psi1 * x0'; */ + /* c10t = c10t + alpha * chi1 * y0';*/ + /* c10t = c10t + conj(alpha) * psi1 * x0';*/ kfp_2v ( conj0, @@ -301,10 +319,12 @@ void bli_dher2_unf_var1 cntx ); - /* gamma11 = gamma11 + alpha * chi1 * conj(psi1) - + conj(alpha) * psi1 * conj(chi1); */ - PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); - PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); + /* gamma11 = gamma11 + alpha * chi1 *conj(psi1) + * + conj(alpha) * psi1 * conj(chi1);*/ + PASTEMAC(d,adds)( alpha0_chi1_psi1, + *gamma11 ); + PASTEMAC(d,adds)( alpha0_chi1_psi1, + *gamma11 ); i+=1; } diff --git a/frame/2/her2/bli_her2_unf_var4_amd.c b/frame/2/her2/bli_her2_unf_var4_amd.c index 4d77397cd2..6e999be7d1 100644 --- a/frame/2/her2/bli_her2_unf_var4_amd.c +++ b/frame/2/her2/bli_her2_unf_var4_amd.c @@ -246,9 +246,13 @@ void bli_dher2_unf_var4 PASTECH(d,axpy2v_ker_ft) kfp_2v; /* Query the context for the kernel function pointer. */ + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); kfp_2v = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPY2V_KER, cntx ); - if((incx == 1) && (incy == 1) && (rs_ct == 1)) + if ( (bli_cpuid_is_avx_supported() == TRUE) + && (incx == 1) + && (incy == 1) + && (rs_ct == 1)) { for ( i = 0; i < m; ) { @@ -262,23 +266,28 @@ void bli_dher2_unf_var4 if((n_ahead >= 3)) { - bli_dher2_zen_int_4(gamma11, chi1, psi1, &alpha0, n_ahead + 1, cs_ct); + bli_dher2_zen_int_4(gamma11, chi1, + psi1, &alpha0, + n_ahead + 1, cs_ct); i+= 4; } else { - /* Compute scalars for vector subproblems. 
*/ - PASTEMAC(d,scal2s)( alpha0, *psi1, alpha0_psi1 ); - PASTEMAC(d,scal2s)( alpha0, *chi1, alpha1_chi1 ); - - /* Compute alpha * chi1 * conj(psi1) after both chi1 - * and psi1 have - already been conjugated, if needed, by conjx and - conjy. */ + /* Compute scalars for vector + * subproblems. */ + PASTEMAC(d,scal2s)( alpha0, *psi1, + alpha0_psi1 ); + PASTEMAC(d,scal2s)( alpha0, *chi1, + alpha1_chi1 ); + + /* Compute alpha * chi1 * conj(psi1) + * after both chi1 and psi1 have + * already been conjugated, if needed, + * by conjx and conjy. */ PASTEMAC(d,scal2s)( alpha0_psi1, *chi1, - alpha0_chi1_psi1 ); + alpha0_chi1_psi1 ); - /* c21 = c21 + alpha * x2 * conj(psi1); */ + /* c21 = c21 + alpha * x2 * conj(psi1)*/ /* c21 = c21 + conj(alpha) * y2 * conj(chi1); */ kfp_2v @@ -295,8 +304,10 @@ void bli_dher2_unf_var4 ); - PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); - PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); + PASTEMAC(d,adds)( alpha0_chi1_psi1, + *gamma11 ); + PASTEMAC(d,adds)( alpha0_chi1_psi1, + *gamma11 ); i+=1; } } From 06e386f054e820360d170b1f0e828606c4bc513c Mon Sep 17 00:00:00 2001 From: Dipal M Zambare Date: Mon, 7 Mar 2022 14:38:08 +0530 Subject: [PATCH 090/243] Updated Windows build system to pick AMD specific sources. The framework cleanup was done for linux as part of f63f78d7 Removed Arch specific code from BLIS framework. This commit adds changes needed for windows build. AMD-Internal: [CPUPL-2052] Change-Id: Ibd503a0adeea66850de156fb95657b124e1c4b9d --- .gitignore | 10 +++++ CMakeLists.txt | 3 -- frame/2/gemv/CMakeLists.txt | 20 +++++++-- frame/2/hemv/CMakeLists.txt | 21 +++++++-- frame/2/her2/CMakeLists.txt | 21 +++++++-- frame/2/trsv/CMakeLists.txt | 21 +++++++-- frame/3/CMakeLists.txt | 18 +++++++- frame/3/gemm/CMakeLists.txt | 19 +++++++- frame/compat/CMakeLists.txt | 44 ++++++++++++++----- frame/compat/bla_gemm_amd.c | 4 +- kernels/zen/1f/CMakeLists.txt | 3 +- ...xpyf_int_8.c => bli_dotxaxpyf_zen_int_8.c} | 0 12 files changed, 152 insertions(+), 32 deletions(-) rename kernels/zen/1f/{bli_dotxaxpyf_int_8.c => bli_dotxaxpyf_zen_int_8.c} (100%) diff --git a/.gitignore b/.gitignore index d0de225b5c..f883af441e 100644 --- a/.gitignore +++ b/.gitignore @@ -53,3 +53,13 @@ out.* GPATH GRTAGS GTAGS + +# Windows Build +build/* +bin/* +*.dll +*.lib +*.pdb +*.exe + +.vscode diff --git a/CMakeLists.txt b/CMakeLists.txt index 7fc9d272e0..4c057f8bfa 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -34,20 +34,17 @@ endif () if(${AOCL_BLIS_FAMILY} STREQUAL "zen") add_definitions(-DBLIS_FAMILY_ZEN) - add_definitions(-DBLIS_CONFIG_EPYC) add_definitions(-DBLIS_CONFIG_ZEN) add_definitions(-DBLIS_KERNELS_ZEN) add_definitions(-DBLIS_KERNELS_HASWELL) elseif (${AOCL_BLIS_FAMILY} STREQUAL "zen2") add_definitions(-DBLIS_FAMILY_ZEN2) - add_definitions(-DBLIS_CONFIG_EPYC) add_definitions(-DBLIS_CONFIG_ZEN2) add_definitions(-DBLIS_KERNELS_ZEN2) add_definitions(-DBLIS_KERNELS_ZEN) add_definitions(-DBLIS_KERNELS_HASWELL) elseif (${AOCL_BLIS_FAMILY} STREQUAL "zen3") add_definitions(-DBLIS_FAMILY_ZEN3) - add_definitions(-DBLIS_CONFIG_EPYC) add_definitions(-DBLIS_CONFIG_ZEN3) add_definitions(-DBLIS_KERNELS_ZEN3) add_definitions(-DBLIS_KERNELS_ZEN2) diff --git a/frame/2/gemv/CMakeLists.txt b/frame/2/gemv/CMakeLists.txt index 86be8ddc08..2f75a00f63 100644 --- a/frame/2/gemv/CMakeLists.txt +++ b/frame/2/gemv/CMakeLists.txt @@ -1,11 +1,25 @@ -##Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.## +##Copyright (C) 2022, Advanced Micro Devices, Inc. 
All rights reserved.## target_sources("${PROJECT_NAME}" PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemv_unb_var1.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemv_unb_var2.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemv_unf_var1.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemv_unf_var2.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemv_unf_var2.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemv_var_oapi.c ) +# Select AMD specific sources for AMD configurations. +if(${TARGET_ARCH} STREQUAL zen OR + ${TARGET_ARCH} STREQUAL zen2 OR + ${TARGET_ARCH} STREQUAL zen3 OR + ${TARGET_ARCH} STREQUAL amdzen) + target_sources("${PROJECT_NAME}" + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemv_unf_var1_amd.c + ) +else() + target_sources("${PROJECT_NAME}" + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemv_unf_var1.c + ) +endif() diff --git a/frame/2/hemv/CMakeLists.txt b/frame/2/hemv/CMakeLists.txt index 677c253271..34820c3762 100644 --- a/frame/2/hemv/CMakeLists.txt +++ b/frame/2/hemv/CMakeLists.txt @@ -1,4 +1,4 @@ -##Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.## +##Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.## target_sources("${PROJECT_NAME}" PRIVATE @@ -6,10 +6,25 @@ target_sources("${PROJECT_NAME}" ${CMAKE_CURRENT_SOURCE_DIR}/bli_hemv_unb_var2.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_hemv_unb_var3.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_hemv_unb_var4.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_hemv_unf_var1.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_hemv_unf_var1a.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_hemv_unf_var3.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_hemv_unf_var3a.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_hemv_var_oapi.c ) +# Select AMD specific sources for AMD configurations. +if(${TARGET_ARCH} STREQUAL zen OR + ${TARGET_ARCH} STREQUAL zen2 OR + ${TARGET_ARCH} STREQUAL zen3 OR + ${TARGET_ARCH} STREQUAL amdzen) + target_sources("${PROJECT_NAME}" + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/bli_hemv_unf_var1_amd.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_hemv_unf_var3_amd.c + ) +else() + target_sources("${PROJECT_NAME}" + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/bli_hemv_unf_var1.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_hemv_unf_var3.c + ) +endif() \ No newline at end of file diff --git a/frame/2/her2/CMakeLists.txt b/frame/2/her2/CMakeLists.txt index 1b4c264443..83629df8f5 100644 --- a/frame/2/her2/CMakeLists.txt +++ b/frame/2/her2/CMakeLists.txt @@ -1,4 +1,4 @@ -##Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.## +##Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.## target_sources("${PROJECT_NAME}" PRIVATE @@ -6,8 +6,23 @@ target_sources("${PROJECT_NAME}" ${CMAKE_CURRENT_SOURCE_DIR}/bli_her2_unb_var2.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_her2_unb_var3.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_her2_unb_var4.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_her2_unf_var1.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_her2_unf_var4.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_her2_var_oapi.c ) +# Select AMD specific sources for AMD configurations. 
+if(${TARGET_ARCH} STREQUAL zen OR + ${TARGET_ARCH} STREQUAL zen2 OR + ${TARGET_ARCH} STREQUAL zen3 OR + ${TARGET_ARCH} STREQUAL amdzen) + target_sources("${PROJECT_NAME}" + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/bli_her2_unf_var1_amd.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_her2_unf_var4_amd.c + ) +else() + target_sources("${PROJECT_NAME}" + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/bli_her2_unf_var1.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_her2_unf_var4.c + ) +endif() \ No newline at end of file diff --git a/frame/2/trsv/CMakeLists.txt b/frame/2/trsv/CMakeLists.txt index 1d16769d32..b07389340e 100644 --- a/frame/2/trsv/CMakeLists.txt +++ b/frame/2/trsv/CMakeLists.txt @@ -1,11 +1,26 @@ -##Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.## +##Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.## target_sources("${PROJECT_NAME}" PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/bli_trsv_unb_var1.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_trsv_unb_var2.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_trsv_unf_var1.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_trsv_unf_var2.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_trsv_var_oapi.c ) +# Select AMD specific sources for AMD configurations. +if(${TARGET_ARCH} STREQUAL zen OR + ${TARGET_ARCH} STREQUAL zen2 OR + ${TARGET_ARCH} STREQUAL zen3 OR + ${TARGET_ARCH} STREQUAL amdzen) + target_sources("${PROJECT_NAME}" + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/bli_trsv_unf_var1_amd.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_trsv_unf_var2_amd.c + ) +else() + target_sources("${PROJECT_NAME}" + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/bli_trsv_unf_var1.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_trsv_unf_var2.c + ) +endif() diff --git a/frame/3/CMakeLists.txt b/frame/3/CMakeLists.txt index 4b7711ed4e..b3aaf2c8c8 100644 --- a/frame/3/CMakeLists.txt +++ b/frame/3/CMakeLists.txt @@ -1,4 +1,4 @@ -##Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.## +##Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.## target_sources("${PROJECT_NAME}" PRIVATE @@ -12,7 +12,6 @@ target_sources("${PROJECT_NAME}" ${CMAKE_CURRENT_SOURCE_DIR}/bli_l3_packm.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_l3_prune.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_l3_sup.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_l3_sup_int.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_l3_sup_packm_a.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_l3_sup_packm_b.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_l3_sup_packm_var.c @@ -27,6 +26,21 @@ target_sources("${PROJECT_NAME}" ${CMAKE_CURRENT_SOURCE_DIR}/bli_l3_ukr_oapi.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_l3_ukr_tapi.c ) +# Select AMD specific sources for AMD configurations. +if(${TARGET_ARCH} STREQUAL zen OR + ${TARGET_ARCH} STREQUAL zen2 OR + ${TARGET_ARCH} STREQUAL zen3 OR + ${TARGET_ARCH} STREQUAL amdzen) + target_sources("${PROJECT_NAME}" + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/bli_l3_sup_int_amd.c + ) +else() + target_sources("${PROJECT_NAME}" + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/bli_l3_sup_int_amd.c + ) +endif() set(SUBDIRECTORIES "gemm" "hemm" "her2k" "herk" "symm" "syr2k" "syrk" "trmm" "trmm3" "trsm" "gemmt") diff --git a/frame/3/gemm/CMakeLists.txt b/frame/3/gemm/CMakeLists.txt index 8eb115d1f0..825dd745ca 100644 --- a/frame/3/gemm/CMakeLists.txt +++ b/frame/3/gemm/CMakeLists.txt @@ -1,4 +1,4 @@ -##Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.## +##Copyright (C) 2022, Advanced Micro Devices, Inc. 
All rights reserved.## target_sources("${PROJECT_NAME}" PRIVATE @@ -6,7 +6,6 @@ target_sources("${PROJECT_NAME}" ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemm_blk_var2.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemm_blk_var3.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemm_cntl.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemm_front.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemm_int.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemm_ker_var1.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemm_ker_var2.c @@ -16,4 +15,20 @@ target_sources("${PROJECT_NAME}" ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemm_packab.c ) +# Select AMD specific sources for AMD configurations. +if(${TARGET_ARCH} STREQUAL zen OR +${TARGET_ARCH} STREQUAL zen2 OR +${TARGET_ARCH} STREQUAL zen3 OR +${TARGET_ARCH} STREQUAL amdzen) + target_sources("${PROJECT_NAME}" + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemm_front_amd.c + ) +else() + target_sources("${PROJECT_NAME}" + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemm_front.c + ) +endif() + add_subdirectory(ind) diff --git a/frame/compat/CMakeLists.txt b/frame/compat/CMakeLists.txt index 7c20f5100c..48b66acbcb 100644 --- a/frame/compat/CMakeLists.txt +++ b/frame/compat/CMakeLists.txt @@ -1,17 +1,12 @@ -##Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.## +##Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.## target_sources("${PROJECT_NAME}" PRIVATE -${CMAKE_CURRENT_SOURCE_DIR}/bla_amax.c + ${CMAKE_CURRENT_SOURCE_DIR}/bla_amin.c ${CMAKE_CURRENT_SOURCE_DIR}/bla_asum.c -${CMAKE_CURRENT_SOURCE_DIR}/bla_axpy.c -${CMAKE_CURRENT_SOURCE_DIR}/bla_copy.c -${CMAKE_CURRENT_SOURCE_DIR}/bla_dot.c -${CMAKE_CURRENT_SOURCE_DIR}/bla_gemm.c ${CMAKE_CURRENT_SOURCE_DIR}/bla_gemm3m.c ${CMAKE_CURRENT_SOURCE_DIR}/bla_gemmt.c -${CMAKE_CURRENT_SOURCE_DIR}/bla_gemv.c ${CMAKE_CURRENT_SOURCE_DIR}/bla_ger.c ${CMAKE_CURRENT_SOURCE_DIR}/bla_hemm.c ${CMAKE_CURRENT_SOURCE_DIR}/bla_hemv.c @@ -20,8 +15,6 @@ ${CMAKE_CURRENT_SOURCE_DIR}/bla_her2.c ${CMAKE_CURRENT_SOURCE_DIR}/bla_her2k.c ${CMAKE_CURRENT_SOURCE_DIR}/bla_herk.c ${CMAKE_CURRENT_SOURCE_DIR}/bla_nrm2.c -${CMAKE_CURRENT_SOURCE_DIR}/bla_scal.c -${CMAKE_CURRENT_SOURCE_DIR}/bla_swap.c ${CMAKE_CURRENT_SOURCE_DIR}/bla_symm.c ${CMAKE_CURRENT_SOURCE_DIR}/bla_symv.c ${CMAKE_CURRENT_SOURCE_DIR}/bla_syr.c @@ -30,7 +23,6 @@ ${CMAKE_CURRENT_SOURCE_DIR}/bla_syr2k.c ${CMAKE_CURRENT_SOURCE_DIR}/bla_syrk.c ${CMAKE_CURRENT_SOURCE_DIR}/bla_trmm.c ${CMAKE_CURRENT_SOURCE_DIR}/bla_trmv.c -${CMAKE_CURRENT_SOURCE_DIR}/bla_trsm.c ${CMAKE_CURRENT_SOURCE_DIR}/bla_trsv.c ${CMAKE_CURRENT_SOURCE_DIR}/bla_gemm_batch.c ${CMAKE_CURRENT_SOURCE_DIR}/bla_axpby.c @@ -40,6 +32,38 @@ ${CMAKE_CURRENT_SOURCE_DIR}/bla_omatcopy2.c ${CMAKE_CURRENT_SOURCE_DIR}/bla_omatadd.c ) +# Select AMD specific sources for AMD configurations. 
+if(${TARGET_ARCH} STREQUAL zen OR +${TARGET_ARCH} STREQUAL zen2 OR +${TARGET_ARCH} STREQUAL zen3 OR +${TARGET_ARCH} STREQUAL amdzen) + target_sources("${PROJECT_NAME}" + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/bla_amax_amd.c + ${CMAKE_CURRENT_SOURCE_DIR}/bla_axpy_amd.c + ${CMAKE_CURRENT_SOURCE_DIR}/bla_copy_amd.c + ${CMAKE_CURRENT_SOURCE_DIR}/bla_dot_amd.c + ${CMAKE_CURRENT_SOURCE_DIR}/bla_gemm_amd.c + ${CMAKE_CURRENT_SOURCE_DIR}/bla_gemv_amd.c + ${CMAKE_CURRENT_SOURCE_DIR}/bla_scal_amd.c + ${CMAKE_CURRENT_SOURCE_DIR}/bla_swap_amd.c + ${CMAKE_CURRENT_SOURCE_DIR}/bla_trsm_amd.c + ) +else() + target_sources("${PROJECT_NAME}" + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/bla_amax.c + ${CMAKE_CURRENT_SOURCE_DIR}/bla_axpy.c + ${CMAKE_CURRENT_SOURCE_DIR}/bla_copy.c + ${CMAKE_CURRENT_SOURCE_DIR}/bla_dot.c + ${CMAKE_CURRENT_SOURCE_DIR}/bla_gemm.c + ${CMAKE_CURRENT_SOURCE_DIR}/bla_gemv.c + ${CMAKE_CURRENT_SOURCE_DIR}/bla_scal.c + ${CMAKE_CURRENT_SOURCE_DIR}/bla_swap.c + ${CMAKE_CURRENT_SOURCE_DIR}/bla_trsm.c + ) +endif() + #Add all subdirectories # add_subdirectory(attic) # add_subdirectory(blis) diff --git a/frame/compat/bla_gemm_amd.c b/frame/compat/bla_gemm_amd.c index 7ef58bfb35..197cc3e235 100644 --- a/frame/compat/bla_gemm_amd.c +++ b/frame/compat/bla_gemm_amd.c @@ -798,7 +798,7 @@ INSERT_GENTFUNC_BLAS_SC( gemm, gemm ) // Observed a regression in dgemm with this function addition. // Disabling temporarily. -#if 0 +#if 1 void dzgemm_ ( const f77_char* transa, @@ -875,7 +875,7 @@ void dzgemm_ bli_obj_init_finish_1x1( dt, (dcomplex*)alpha, &alphao ); bli_obj_init_finish_1x1( dt, (dcomplex*)beta, &betao ); - bli_obj_init_finish( dt_a, m0_a, n0_a, (dcomplex*)a, rs_a, cs_a, &ao ); + bli_obj_init_finish( dt_a, m0_a, n0_a, (double*)a, rs_a, cs_a, &ao ); bli_obj_init_finish( dt, m0_b, n0_b, (dcomplex*)b, rs_b, cs_b, &bo ); bli_obj_init_finish( dt, m0, n0, (dcomplex*)c, rs_c, cs_c, &co ); diff --git a/kernels/zen/1f/CMakeLists.txt b/kernels/zen/1f/CMakeLists.txt index 4b9caa40b6..3a77f69ef1 100644 --- a/kernels/zen/1f/CMakeLists.txt +++ b/kernels/zen/1f/CMakeLists.txt @@ -1,4 +1,4 @@ -##Copyright (C) 2020-2021, Advanced Micro Devices, Inc. All rights reserved.## +##Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All rights reserved.## target_sources("${PROJECT_NAME}" PRIVATE @@ -8,4 +8,5 @@ target_sources("${PROJECT_NAME}" ${CMAKE_CURRENT_SOURCE_DIR}/bli_axpyf_zen_int_4.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_axpyf_zen_int_6.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_axpy2v_zen_int.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_dotxaxpyf_zen_int_8.c ) diff --git a/kernels/zen/1f/bli_dotxaxpyf_int_8.c b/kernels/zen/1f/bli_dotxaxpyf_zen_int_8.c similarity index 100% rename from kernels/zen/1f/bli_dotxaxpyf_int_8.c rename to kernels/zen/1f/bli_dotxaxpyf_zen_int_8.c From ab06f17689576ace327e390325dd7d79848d4608 Mon Sep 17 00:00:00 2001 From: mkurumel Date: Fri, 18 Feb 2022 16:00:13 +0530 Subject: [PATCH 091/243] DGEMMT : Tuning SUP threshold to improve ST and MT performance. Details : - SUP Threshold change for native vs SUP - Improved the ST performances for sizes n<800 - Introduce PACKB in SUP to improve ST performance between 320 320) && (k > 50)) + bli_rntm_set_pack_b( 1, rntm ); + } + } @@ -317,6 +326,14 @@ err_t bli_gemmtsup_int // new ways of parallelism value for the jc loop. bli_rntm_set_ways_only( jc_new, 1, ic_new, 1, 1, rntm ); bli_l3_sup_thrinfo_update_root( rntm, thread ); + + /* Enable packing for A matrix for higher sizes. 
Note that pack A + * * becomes pack B inside var2m because this is transpose case*/ + if(bli_is_double(dt) && (n_threads==1)) + { + if((m > 320) && (k > 50)) + bli_rntm_set_pack_a( 1, rntm ); + } } diff --git a/kernels/zen/util/bli_thresh_funcs_zen.c b/kernels/zen/util/bli_thresh_funcs_zen.c index 1b5fc86998..2786f00e43 100644 --- a/kernels/zen/util/bli_thresh_funcs_zen.c +++ b/kernels/zen/util/bli_thresh_funcs_zen.c @@ -37,16 +37,31 @@ // -- gemmt specific function bool bli_cntx_gemmtsup_thresh_is_met_zen( obj_t* a, obj_t* b, obj_t* c, cntx_t* cntx ) { - num_t dt = bli_obj_dt( c ); + num_t dt = bli_obj_dt( c ); + dim_t n = bli_obj_length( c ); + dim_t k = bli_obj_width_after_trans( a ); + rntm_t rntm; - dim_t n = bli_obj_length( c ); - dim_t k = bli_obj_width_after_trans( a ); + bli_rntm_init_from_global( &rntm ); + + // Query the number of threads from rntm object. + const dim_t n_threads = bli_rntm_num_threads( &rntm ); if( bli_is_double( dt )) { - if ( n < 300 ) return TRUE; - if ( (k / n ) > 50 ) return TRUE; - + if( n_threads == 16) + { + /*Push sizes for n<1200 into SUP path*/ + if ( n < 1200 ) return TRUE; + /*For 12005 , With packing , Native path performance is better */ + if ( n < 1600 && (n / k) < 5) return TRUE; + } + else + { + if ( n < 800 ) return TRUE; + if ( (k / n ) > 50 ) return TRUE; + } return FALSE; } else if ( bli_is_dcomplex( dt ) ) From 6a2c4acc666b28e9fe97c08ef16b524421cec47c Mon Sep 17 00:00:00 2001 From: Sireesha Sanga Date: Tue, 15 Mar 2022 16:33:55 +0530 Subject: [PATCH 092/243] Runtime Thread Control using OpenMP API Details: - During runtime, Application can set the desired number of threads using standard OpenMP API omp_set_num_threads(nt). - BLIS Library uses standard OpenMP API omp_get_max_threads() internally, to fetch the latest value set by the application. - This value will be used to decide the number of threads in the subsequent BLAS calls. - At the time of BLIS Initialization, BLIS_NUM_THREADS environment variable will be given precedence, over the OpenMP standard API omp_set_num_threads(nt) and OMP_NUM_THREADS environment variable. - Order of precedence followed during BLIS Initialization is as follows 1. Valid value of BLIS_NUM_THREADS 2. omp_set_num_threads(nt) 3. valid value of OMP_NUM_THREADS 4. Number of cores - After BLIS initialization, if the Application issues omp_set_num_threads(nt) during runtime, number of threads set during BLIS Initialization, is overridden by the latest value set by the Application. - Existing precedence of BLIS_*_NT environment variables and the decision of optimal number of threads over the number of threads derived from the above process remains as it is. AMD-Internal: [CPUPL-2076] Change-Id: I935ba0246b1c256d0fee7d386eac0f5940fabff8 --- frame/base/bli_rntm.c | 14 +++++++++++++ frame/thread/bli_thread.c | 44 +++++++++++++++++++++++++++++++-------- 2 files changed, 49 insertions(+), 9 deletions(-) diff --git a/frame/base/bli_rntm.c b/frame/base/bli_rntm.c index dc0acf6bf9..7176dacc4e 100644 --- a/frame/base/bli_rntm.c +++ b/frame/base/bli_rntm.c @@ -49,9 +49,23 @@ void bli_rntm_init_from_global( rntm_t* rntm ) // We must ensure that global_rntm has been initialized. bli_init_once(); + // Fetch the number of threads based on the order of precedence, + // or the latest value of number of threads, + // if set by the Application using omp_set_num_threads(nt) API. +#ifdef BLIS_ENABLE_OPENMP + dim_t n_threads = omp_get_max_threads(); +#endif + // Acquire the mutex protecting global_rntm. 
bli_pthread_mutex_lock( &global_rntm_mutex ); + // Update the latest value of number of threads into global rntm structure, + // before copying into local rntm structure. This updated value will be + // used in the subsequent parallel regions. +#ifdef BLIS_ENABLE_OPENMP + global_rntm.num_threads = n_threads; +#endif + *rntm = global_rntm; // Release the mutex protecting global_rntm. diff --git a/frame/thread/bli_thread.c b/frame/thread/bli_thread.c index 159a9e802e..f570bcc2d8 100644 --- a/frame/thread/bli_thread.c +++ b/frame/thread/bli_thread.c @@ -1633,20 +1633,46 @@ void bli_thread_init_rntm_from_env // Try to read BLIS_NUM_THREADS first. nt = bli_env_get_var( "BLIS_NUM_THREADS", -1 ); - // If BLIS_NUM_THREADS was not set, try to read OMP_NUM_THREADS. - if ( nt == -1 ) - nt = bli_env_get_var( "OMP_NUM_THREADS", -1 ); #ifdef BLIS_ENABLE_OPENMP - // If both environment variables are not set - - // number of threads can also be set by the application by calling omp_set_num_threads(nt) - // The next parallel region when encountered will run with number of threads set by the above API. - // We can know about the number of threads by using the API "omp_get_max_threads()" - if (nt == -1) nt = omp_get_max_threads(); - // If application is multithreaded and number of threads is set using omp_set_num_threads(nt) + + // Scenarios: + // 1. If BLIS_NUM_THREADS is set with valid value, set the nt using omp_set_num_threads(nt) + // so that this value can be fetched inside BLIS API as well. + // 2. If BLIS_NUM_THREADS is not set, then if Application is multithreaded and issued + // omp_set_num_threads(nt) with desired number of threads, + // omp_get_max_threads() API will fetch the number of threads set earlier. + // 3. If BLIS_NUM_THREADS is not set, omp_set_num_threads(nt) is not called by the application, + // but only OMP_NUM_THREADS is set, + // omp_get_max_threads() API will fetch the value of OMP_NUM_THREADS. + // 4. If both environment variables are not set, or if they are set with invalid values, and + // omp_set_num_threads(nt) is not issued by application, + // omp_get_max_threads() API will return the number of the cores in the current context. + // // BLIS will rntm->num_threads will also get initialized with the same value. // However if omp_set_nested is false - BLIS APIs called from parallel threads will run in sequential. // But if nested parallelism is enabled - Then each application will launch MT BLIS. + // + // Order of precedence used for number of threads: + // 1. valid value set for BLIS_NUM_THREADS environment variable + // 2. omp_set_num_threads(nt) issued by the application + // 3. valid value set for OMP_NUM_THREADS environment variable + // 4. Number of cores + // + // Note: If nt is not a valid value for omp_set_num_threads(nt) API, number of threads would be set to 1. + // omp_get_max_threads() API will return 1. + // + // OMP_NUM_THREADS environment variable is applicable only when OpenMP is enabled. + + if(nt > 0) + { + omp_set_num_threads(nt); + } + else + { + nt = omp_get_max_threads(); + } + #endif // Read the environment variables for the number of threads (ways // of parallelism) for each individual loop. From 34fee0fdbc9111f0e0ec022329460e6e48db6a9d Mon Sep 17 00:00:00 2001 From: Chandrashekara K R Date: Wed, 16 Mar 2022 11:51:05 +0530 Subject: [PATCH 093/243] AOCL-Windows: Added logic in the windows build system to generate cblas.h at configure time. 
AMD-Internal: [CPUPL-2037] Change-Id: Ie4ffd1d655079c895878f96dbb6f811547ad953d --- CMakeLists.txt | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index 4c057f8bfa..bcb67f2ccf 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -534,6 +534,23 @@ execute_process( OUTPUT_VARIABLE CMD_OUTPUT) message( STATUS "Generating monolithic header file :" ${CMD_OUTPUT}) +# Logic to generate the cblas.h in include folder. +set(CBLAS_H "cblas.h") +# Arguements for python script +set(C_COMMENT "-c") +set(VERBOSE "-v1") +set(INPUT "${CMAKE_SOURCE_DIR}/frame/compat/cblas/src/${CBLAS_H}") +set(OUTPUT "${CMAKE_SOURCE_DIR}/include/${TARGET_ARCH}/${CBLAS_H}") +set(TEMP_DIR "${INCLUDE}") +set(DIR_H_PATH "${HEADER_PATH}") + +# Run python script to generate monolithic header at configuration time +execute_process( + COMMAND ${PYTHON_EXE} ${FLATTEN_PY} "${C_COMMENT}" "${VERBOSE}" "${INPUT}" "${OUTPUT}" "${TEMP_DIR}" "${DIR_H_PATH}" + RESULT_VARIABLE CMD_RESULT + OUTPUT_VARIABLE CMD_OUTPUT) +message( STATUS "Generating monolithic cblas header file :" ${CMD_OUTPUT}) + # setting the blis version string file (STRINGS "version" BLIS_VERSION) set(BLIS_VERSION_STRING ${BLIS_VERSION}) From eb0ff018714055102d7d1e76848b235db1e73128 Mon Sep 17 00:00:00 2001 From: Nallani Bhaskar Date: Wed, 23 Mar 2022 10:30:14 +0530 Subject: [PATCH 094/243] Fine-tuning dynamic threading logic of DGEMM for small dimensions Description: 1. For small dimensions single threads dgemm_small performing better than dgemmsup and native paths. 2. Irrespecive of given number of threads we are redirecting into single thread dgemm_small AMD-Internal:[CPUPL-2053] Change-Id: If591152d18282c2544249f70bd2f0a8cd816b94e --- frame/compat/bla_gemm_amd.c | 30 ++++++++++++++---------------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/frame/compat/bla_gemm_amd.c b/frame/compat/bla_gemm_amd.c index 197cc3e235..2bb9126804 100644 --- a/frame/compat/bla_gemm_amd.c +++ b/frame/compat/bla_gemm_amd.c @@ -526,33 +526,31 @@ void dgemm_ //dim_t nt = bli_thread_get_num_threads(); // get number of threads bool nt = bli_thread_get_is_parallel(); // Check if parallel dgemm is invoked. - // if m0 is large and (n0 & k0) < 10 - SMALL GEMM - ST is better - // - #ifdef AOCL_DYNAMIC - if (nt && ((n0 > 10 ) || (k0 > 10)) ) + //For smaller sizes dgemm_small is perfoming better + if (nt && (((m0 >32) || (n0>32) || (k0>32)) && ((m0+n0+k0)>150)) ) #else - if (nt) + if (nt) #endif - { + { // Will call parallelized dgemm code - sup & native PASTEMAC(gemm, BLIS_OAPI_EX_SUF) - ( - &alphao, - &ao, - &bo, - &betao, - &co, - NULL, - NULL - ); + ( + &alphao, + &ao, + &bo, + &betao, + &co, + NULL, + NULL + ); AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); /* Finalize BLIS. */ bli_finalize_auto(); return; - } + } // The code below will be called when number of threads = 1. From 0976ed9ce56189c66a445238ef9523f15ebd6e91 Mon Sep 17 00:00:00 2001 From: Harsh Dave Date: Tue, 22 Mar 2022 06:59:36 -0500 Subject: [PATCH 095/243] Implement zgemm_small kernel Details: - Intrinsic implementation of zgemm_small nn kernel. - Intrinsic implementation of zgemm_small_At kernel. - Added support conjugate and hermitian transpose - Main loop operates in multiple of 4x3 tile. - Edge cases are handles separately. 
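
For reference, the "4x3 tile" above means Z_MR = 4 dcomplex rows of C by NR = 3
columns (Z_MR = MR >> 3 with MR = 32). The scalar sketch below shows what one
such tile computes in the non-transpose, non-conjugate, beta == 0 case; it is
illustrative only and not part of this patch (zgemm_tile_ref is a placeholder
name; dim_t and dcomplex are the BLIS types), assuming column-major A, B and C.

    /* Scalar reference for one Z_MR x NR tile of C = alpha * A * B.
       Placeholder sketch; the kernel itself is written with AVX2 intrinsics. */
    static void zgemm_tile_ref
         (
           dim_t K,
           const dcomplex* A, dim_t lda,
           const dcomplex* B, dim_t ldb,
           dcomplex* C, dim_t ldc,
           dcomplex alpha
         )
    {
        for ( dim_t jj = 0; jj < 3; ++jj )      /* NR = 3 columns         */
        {
            for ( dim_t ii = 0; ii < 4; ++ii )  /* Z_MR = 4 dcomplex rows */
            {
                dcomplex acc = { 0.0, 0.0 };
                for ( dim_t kk = 0; kk < K; ++kk )
                {
                    dcomplex a = A[ ii + kk * lda ];
                    dcomplex b = B[ kk + jj * ldb ];
                    acc.real += a.real * b.real - a.imag * b.imag;
                    acc.imag += a.real * b.imag + a.imag * b.real;
                }
                /* beta == 0 case: C is overwritten with alpha * (A*B). */
                C[ ii + jj * ldc ].real = alpha.real * acc.real - alpha.imag * acc.imag;
                C[ ii + jj * ldc ].imag = alpha.real * acc.imag + alpha.imag * acc.real;
            }
        }
    }

The intrinsic kernel arrives at the same result by splitting each complex
product into a real-part accumulator and an imaginary-part accumulator and
recombining them with permute/addsub after the k loop.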
AMD-Internal: [CPUPL-2084] Change-Id: I512da265e4d4ceec904877544f1d15cddc147a66 --- frame/compat/bla_gemm_amd.c | 46 +- kernels/zen/3/bli_gemm_small.c | 7726 +++++++++++++++++++++++++++++++- 2 files changed, 7759 insertions(+), 13 deletions(-) diff --git a/frame/compat/bla_gemm_amd.c b/frame/compat/bla_gemm_amd.c index 2bb9126804..ff995b5f07 100644 --- a/frame/compat/bla_gemm_amd.c +++ b/frame/compat/bla_gemm_amd.c @@ -712,7 +712,12 @@ void zgemm_ //dim_t nt = bli_thread_get_num_threads(); // get number of threads bool nt = bli_thread_get_is_parallel(); // Check if parallel zgemm is invoked. - if ( nt ) +#ifdef AOCL_DYNAMIC + //For smaller sizes zgemm_small is perfoming better + if (nt && (((m0 >32) || (n0>32) || (k0>32)) && ((m0+n0+k0)>100)) ) +#else + if (nt) +#endif { // Will call parallelized zgemm code - sup & native PASTEMAC(gemm, BLIS_OAPI_EX_SUF) @@ -733,6 +738,31 @@ void zgemm_ return; } +#ifdef BLIS_ENABLE_SMALL_MATRIX + err_t status; + + if((nt == 0) && (m0 <= 512 ) && ( n0 <= 512 ) && ( k0 <= 512 )) + { + status = bli_gemm_small( &alphao, + &ao, + &bo, + &betao, + &co, + NULL, //cntx, + NULL + ); + } + + if (status == BLIS_SUCCESS) + { + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS. */ + bli_finalize_auto(); + + return; + } +#endif // The code below will be called when number of threads = 1. #if ENABLE_INDUCED_METHOD /* 3m_sqp is optimal for certain matrix shapes. @@ -769,13 +799,13 @@ void zgemm_ // sup has been enabled for single instance cases. if(single_instance==1) { - err_t status = bli_gemmsup(&alphao, &ao, &bo, &betao, &co, NULL, NULL); - if(status==BLIS_SUCCESS) - { - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) - return; - } + err_t status = bli_gemmsup(&alphao, &ao, &bo, &betao, &co, NULL, NULL); + if(status==BLIS_SUCCESS) + { + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) + return; + } } // fall back on native path when zgemm is not handled in sup path. diff --git a/kernels/zen/3/bli_gemm_small.c b/kernels/zen/3/bli_gemm_small.c index 3e9463fabc..e0c84eb227 100644 --- a/kernels/zen/3/bli_gemm_small.c +++ b/kernels/zen/3/bli_gemm_small.c @@ -40,6 +40,7 @@ #define MR 32 #define D_MR (MR >> 1) +#define Z_MR (MR >> 3) #define NR 3 #define D_BLIS_SMALL_MATRIX_K_THRES_ROME 256 @@ -70,7 +71,26 @@ err_t bli_dgemm_small cntx_t* cntx, cntl_t* cntl ); - +static err_t bli_zgemm_small + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + cntl_t* cntl + ); +static err_t bli_zgemm_small_At + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + cntl_t* cntl + ); static err_t bli_sgemm_small_atbn ( obj_t* alpha, @@ -112,7 +132,7 @@ err_t bli_gemm_small #ifdef BLIS_ENABLE_MULTITHREADING AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); - return BLIS_NOT_YET_IMPLEMENTED; + return BLIS_NOT_YET_IMPLEMENTED; #else // This function is invoked on all architectures including ‘generic’. // Non-AVX platforms will use the kernels derived from the context. @@ -152,6 +172,18 @@ err_t bli_gemm_small return bli_dgemm_small_At(alpha, a, b, beta, c, cntx, cntl); #endif } + if(dt == BLIS_DCOMPLEX) + { +#ifndef BLIS_ENABLE_MULTITHREADING + // bli_zgemm_small_At is called directly from blas interface for + // sizes within thresholds. + // Avoinding calling of bli_zgemm_small_At from gemm_front + // and directing to native implementation. 
+ return BLIS_NOT_YET_IMPLEMENTED; +#else + return bli_zgemm_small_At(alpha, a, b, beta, c, cntx, cntl); +#endif + } if (bli_obj_has_notrans( b )) { @@ -180,6 +212,19 @@ err_t bli_gemm_small #endif } + if (dt == BLIS_DCOMPLEX) + { +#ifndef BLIS_ENABLE_MULTITHREADING + // bli_zgemm_small is called directly from BLAS interface for sizes within thresholds. + // Avoiding calling bli_zgemm_small from gemm_front and directing to + // native implementation. + return BLIS_NOT_YET_IMPLEMENTED; +#else + return bli_zgemm_small(alpha, a, b, beta, c, cntx, cntl); +#endif + } + + if (dt == BLIS_FLOAT) { return bli_sgemm_small(alpha, a, b, beta, c, cntx, cntl); @@ -189,7 +234,6 @@ err_t bli_gemm_small return BLIS_NOT_YET_IMPLEMENTED; }; - static err_t bli_sgemm_small ( obj_t* alpha, @@ -2865,7 +2909,6 @@ static err_t bli_sgemm_small if (m_remainder >= 4) { - //printf("HERE\n"); m_remainder -= 4; for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) @@ -5377,7 +5420,6 @@ err_t bli_dgemm_small_At if (m_remainder >= 4) { - //printf("HERE\n"); m_remainder -= 4; tA = A + row_idx * lda; @@ -5705,5 +5747,7679 @@ err_t bli_dgemm_small_At return BLIS_NONCONFORMAL_DIMENSIONS; } }; + + +#define BLIS_SET_YMM_REG_ZEROS \ + ymm4 = _mm256_setzero_pd(); \ + ymm5 = _mm256_setzero_pd(); \ + ymm6 = _mm256_setzero_pd(); \ + ymm7 = _mm256_setzero_pd(); \ + ymm14 = _mm256_setzero_pd(); \ + ymm15 = _mm256_setzero_pd(); \ + ymm16 = _mm256_setzero_pd(); \ + ymm17 = _mm256_setzero_pd(); \ + ymm18 = _mm256_setzero_pd(); \ + ymm19 = _mm256_setzero_pd(); \ + ymm20 = _mm256_setzero_pd(); \ + ymm21 = _mm256_setzero_pd(); \ + + +#define BLIS_SET_ALL_YMM_REG_ZEROS \ + ymm4 = _mm256_setzero_pd(); \ + ymm5 = _mm256_setzero_pd(); \ + ymm6 = _mm256_setzero_pd(); \ + ymm7 = _mm256_setzero_pd(); \ + ymm8 = _mm256_setzero_pd(); \ + ymm9 = _mm256_setzero_pd(); \ + ymm10 = _mm256_setzero_pd(); \ + ymm11 = _mm256_setzero_pd(); \ + ymm12 = _mm256_setzero_pd(); \ + ymm13 = _mm256_setzero_pd(); \ + ymm14 = _mm256_setzero_pd(); \ + ymm15 = _mm256_setzero_pd(); \ + + + +static err_t bli_zgemm_small + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + cntl_t* cntl + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO); + + bool conjtransa = bli_obj_has_conj(a); + bool conjtransb = bli_obj_has_conj(b); + + gint_t M = bli_obj_length( c ); // number of rows of Matrix C + gint_t N = bli_obj_width( c ); // number of columns of Matrix C + // number of columns of OP(A), will be updated if OP(A) is Transpose(A) + gint_t K = bli_obj_width( a ); + gint_t L = M * N; + + if(L && K ) + { + guint_t lda = bli_obj_col_stride( a ); // column stride of matrix OP(A). + guint_t ldb = bli_obj_col_stride( b ); // column stride of matrix OP(B). + guint_t ldc = bli_obj_col_stride( c ); // column stride of matrix C + guint_t row_idx, col_idx, k; + dcomplex *A = bli_obj_buffer_at_off(a); //pointer to elements of Matrix A + dcomplex *B = bli_obj_buffer_at_off(b); //pointer to elements of Matrix B + dcomplex *C = bli_obj_buffer_at_off(c); //pointer to elements of Matrix C + + dcomplex *tA = A, *tB = B, *tC = C;//, *tA_pack; + dcomplex *tA_packed; //temprorary pointer to hold packed A memory pointer + guint_t row_idx_packed; //packed A memory row index + guint_t lda_packed; //lda of packed A + guint_t col_idx_start; //starting index after A matrix is packed. 
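+        // The four variables above support the pack-and-compute scheme used
+        // below: during the first NR-column iteration the Z_MR x K panel of A
+        // being streamed is also copied into D_A_pack (obtained from the BLIS
+        // memory pool); the remaining column iterations then read A from that
+        // packed buffer via tA_packed/row_idx_packed with lda_packed = Z_MR.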
+ dim_t tb_inc_row = 1; // row stride of matrix B + dim_t tb_inc_col = ldb; // column stride of matrix B + __m256d ymm4, ymm5, ymm6, ymm7; + __m256d ymm8, ymm9, ymm10, ymm11; + __m256d ymm12, ymm13, ymm14, ymm15; + __m256d ymm16, ymm17, ymm18, ymm19, ymm20, ymm21; + __m256d ymm0, ymm1, ymm2, ymm3; + + gint_t n_remainder; // If the N is non multiple of 3.(N%3) + gint_t m_remainder; // If the M is non multiple of 4.(M%4) + + dcomplex *alpha_cast, *beta_cast; // alpha, beta multiples + alpha_cast = bli_obj_buffer_for_1x1(BLIS_DCOMPLEX, alpha); + beta_cast = bli_obj_buffer_for_1x1(BLIS_DCOMPLEX, beta); + + gint_t required_packing_A = 1; + mem_t local_mem_buf_A_s; + dcomplex *D_A_pack = NULL; + rntm_t rntm; + + //update the pointer math if matrix B needs to be transposed. + if (bli_obj_has_trans( b )) + { + tb_inc_col = 1; //switch row and column strides + tb_inc_row = ldb; + } + + //checking whether beta value is zero. + //if true, we should perform C=alpha * A*B operation + //instead of C = beta * C + alpha * (A * B) + bool is_beta_non_zero = 0; + if(!bli_obj_equals(beta, &BLIS_ZERO)) + is_beta_non_zero = 1; + + /* + * This function was using global array to pack part of A input when + * needed. However, using this global array make the function + * non-reentrant. Instead of using a global array we should allocate + * buffer for each invocation. Since the buffer size is too big or stack + * and doing malloc every time will be too expensive, better approach is + * to get the buffer from the pre-allocated pool and it the pool once we + * are doing. + * + * In order to get the buffer from pool, we need access to memory broker, + * currently this function is not invoked in such a way that it can + * receive the memory broker (via rntm). Following hack will get the + * global memory broker that can be use it to access the pool. + * + * Note there will be memory allocation at least on first innovation + * as there will not be any pool created for this size. + * Subsequent invocations will just reuse the buffer from the pool. + */ + + bli_rntm_init_from_global( &rntm ); + bli_rntm_set_num_threads_only( 1, &rntm ); + bli_membrk_rntm_set_membrk( &rntm ); + + // Get the current size of the buffer pool for A block packing. + // We will use the same size to avoid pool re-initliazaton + siz_t buffer_size = bli_pool_block_size( + bli_membrk_pool(bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), + bli_rntm_membrk(&rntm))); + + // + // This kernel assumes that "A" will be unpackged if N <= 3. + // Usually this range (N <= 3) is handled by SUP, however, + // if SUP is disabled or for any other condition if we do + // enter this kernel with N <= 3, we want to make sure that + // "A" remains unpacked. + // + + if ((N < 3) || ((Z_MR * K) << 3) > buffer_size) + { + required_packing_A = 0; + } + + if (required_packing_A == 1) + { +#ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_zgemm_small: Requesting mem pool block of size %lu\n", + buffer_size); +#endif + // Get the buffer from the pool. + bli_membrk_acquire_m(&rntm, + buffer_size, + BLIS_BITVAL_BUFFER_FOR_A_BLOCK, + &local_mem_buf_A_s); + + D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); + } + + /* + * The computation loop runs for Z_MRxN columns of C matrix, thus + * accessing the Z_MRxK A matrix data and KxNR B matrix data. + * The computation is organized as inner loops of dimension Z_MRxNR. + */ + // Process D_MR rows of C matrix at a time. 
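+        // (For this dcomplex kernel the row blocking is Z_MR = 4 dcomplex
+        // elements, with NR = 3 columns per tile.)  In the k-loops below each
+        // complex product is built in two halves: ymm8..ymm13 accumulate A
+        // scaled by the broadcast real parts of B, while ymm4..ymm7 and
+        // ymm14/ymm15 accumulate A scaled by the broadcast imaginary parts;
+        // a permute_pd(0x5) swap followed by addsub_pd after the k-loop
+        // combines the two halves into the complex products.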
+ for (row_idx = 0; (row_idx + (Z_MR - 1)) < M; row_idx += Z_MR) + { + col_idx_start = 0; + tA_packed = A; + row_idx_packed = row_idx; + lda_packed = lda; + + /** + * This is the part of the pack and compute optimization. + * During the first column iteration, we store the accessed A + * matrix into contiguous static memory. This helps to keep te A + * matrix in Cache and aviods the TLB misses. + */ + if (required_packing_A) + { + col_idx = 0; + + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + tA_packed = D_A_pack; + +#ifdef BLIS_ENABLE_PREFETCH + _mm_prefetch((char*)(tC + 0), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 8), _MM_HINT_T0); + _mm_prefetch((char*)(tC + ldc), _MM_HINT_T0); + _mm_prefetch((char*)(tC + ldc + 8), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 2 * ldc + 8), _MM_HINT_T0); +#endif + // clear scratch registers. + BLIS_SET_ALL_YMM_REG_ZEROS + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B + // matrix i data and multiplies it with + // the A matrix. + // This loop is processing Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied with matrix A columns. + ymm0 = _mm256_loadu_pd( + (double const *)tA); + ymm1 = _mm256_loadu_pd( + (double const *)(tA + 2)); + _mm256_storeu_pd( + (double *)tA_packed, ymm0); + _mm256_storeu_pd( + (double *) + (tA_packed + 2), ymm1); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) * + 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + tA_packed += Z_MR; + } + + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are multiplied + //with matrix A columns. 
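+                        // conj(A): ymm20 = {1,-1,1,-1} flips the sign of the
+                        // imaginary part of each loaded dcomplex element of A;
+                        // the unmodified values are stored to the packed
+                        // buffer before the sign flip is applied.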
+ ymm0 = _mm256_loadu_pd( + (double const *)tA); + ymm1 = _mm256_loadu_pd( + (double const *)(tA + 2)); + _mm256_storeu_pd( + (double *)tA_packed, ymm0); + _mm256_storeu_pd( + (double *)(tA_packed + 2) + , ymm1); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + tA_packed += Z_MR; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and multiplies it with the A + // matrix. This loop is processing + // Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied with matrix A columns. + ymm0 = _mm256_loadu_pd( + (double const *)tA); + ymm1 = _mm256_loadu_pd( + (double const *)(tA + 2)); + _mm256_storeu_pd( + (double *)tA_packed, ymm0); + _mm256_storeu_pd( + (double *)(tA_packed + 2) + , ymm1); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + tA_packed += Z_MR; + } + + } + else //handles non-transpose case + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and multiplies it with the A + // matrix. This loop is processing + // Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. 
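+                        // conj(B): the broadcast imaginary part of B (ymm3)
+                        // has already been negated via ymm21 = {-1,-1,-1,-1};
+                        // A is loaded and packed unmodified in this branch.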
+ ymm0 = _mm256_loadu_pd( + (double const *)tA); + ymm1 = _mm256_loadu_pd( + (double const *)(tA + 2)); + _mm256_storeu_pd( + (double *)tA_packed, ymm0); + _mm256_storeu_pd( + (double *)(tA_packed + 2) + , ymm1); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + tA_packed += Z_MR; + } + } + + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm14 = _mm256_permute_pd(ymm14, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm11 = _mm256_addsub_pd(ymm11, ymm5); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + ymm12 = _mm256_addsub_pd(ymm12, ymm7); + ymm10 = _mm256_addsub_pd(ymm10, ymm14); + ymm13 = _mm256_addsub_pd(ymm13, ymm15); + + // alpha, beta multiplication. + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm10, ymm0); + ymm14 = _mm256_mul_pd(ymm10, ymm14); + ymm10 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm11, ymm0); + ymm14 = _mm256_mul_pd(ymm11, ymm14); + ymm11 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm12, ymm0); + ymm14 = _mm256_mul_pd(ymm12, ymm14); + ymm12 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm13, ymm0); + ymm14 = _mm256_mul_pd(ymm13, ymm14); + ymm13 = _mm256_hsub_pd(ymm15, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + (&beta_cast->imag)); + + + BLIS_SET_YMM_REG_ZEROS + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. 
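+                // Complex beta*C: ymm2/ymm3 hold broadcast beta.real/beta.imag.
+                // For each loaded column of C, one register accumulates
+                // C*beta.real and its partner accumulates C*beta.imag; the
+                // later permute_pd(0x5) + addsub_pd pair turns each pair into
+                // (c.r*b.r - c.i*b.i, c.i*b.r + c.r*b.i), i.e. beta*C, which
+                // is then added to the alpha*(A*B) results.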
+ ymm0 = _mm256_loadu_pd((double const *)tC); + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + ymm0 = _mm256_loadu_pd((double const *)(tC + 2)); + ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6); + ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7); + + // col 2 + ymm0 = _mm256_loadu_pd((double const *)(tC + ldc)); + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc + 2)); + ymm16 = _mm256_fmadd_pd(ymm0, ymm2, ymm16); + ymm17 = _mm256_fmadd_pd(ymm0, ymm3, ymm17); + + // col 3 + ymm0 = _mm256_loadu_pd((double const *) + (tC + (ldc * 2))); + ymm18 = _mm256_fmadd_pd(ymm0, ymm2, ymm18); + ymm19 = _mm256_fmadd_pd(ymm0, ymm3, ymm19); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + (ldc * 2) + 2)); + ymm20 = _mm256_fmadd_pd(ymm0, ymm2, ymm20); + ymm21 = _mm256_fmadd_pd(ymm0, ymm3, ymm21); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + ymm17 = _mm256_permute_pd(ymm17, 0x5); + ymm19 = _mm256_permute_pd(ymm19, 0x5); + ymm21 = _mm256_permute_pd(ymm21, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm6 = _mm256_addsub_pd(ymm6, ymm7); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + ymm16 = _mm256_addsub_pd(ymm16, ymm17); + ymm18 = _mm256_addsub_pd(ymm18, ymm19); + ymm20 = _mm256_addsub_pd(ymm20, ymm21); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm11 = _mm256_add_pd(ymm11, ymm6); + ymm9 = _mm256_add_pd(ymm9, ymm14); + ymm12 = _mm256_add_pd(ymm12, ymm16); + ymm10 = _mm256_add_pd(ymm10, ymm18); + ymm13 = _mm256_add_pd(ymm13, ymm20); + + _mm256_storeu_pd((double *)tC, ymm8); + _mm256_storeu_pd((double *)(tC + 2), ymm11); + + tC += ldc; + + _mm256_storeu_pd((double *)tC, ymm9); + _mm256_storeu_pd((double *)(tC + 2), ymm12); + + tC += ldc; + + _mm256_storeu_pd((double *)tC, ymm10); + _mm256_storeu_pd((double *)(tC + 2), ymm13); + + // modify the pointer arithematic to use packed A matrix. + col_idx_start = NR; + tA_packed = D_A_pack; + row_idx_packed = 0; + lda_packed = Z_MR; + } + // Process NR columns of C matrix at a time. + for (col_idx = col_idx_start; (col_idx + (NR - 1)) < N; + col_idx += NR) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + +#ifdef BLIS_ENABLE_PREFETCH + _mm_prefetch((char*)(tC + 0), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 8), _MM_HINT_T0); + _mm_prefetch((char*)(tC + ldc), _MM_HINT_T0); + _mm_prefetch((char*)(tC + ldc + 8), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 2 * ldc + 8), _MM_HINT_T0); +#endif + // clear scratch registers. + + + BLIS_SET_ALL_YMM_REG_ZEROS + + double *tptr = (double *)tB; + + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. 
+ ymm0 = _mm256_loadu_pd( + (double const *)tA); + ymm1 = _mm256_loadu_pd( + (double const *)(tA + 2)); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and multiplies it with the A + // matrix. This loop is processing + // Z_MR x K The inner loop broadcasts + // the B matrix data and multiplies it + // with the A matrix. This loop is + // processing Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and multiplies it with the A + // matrix. This loop is processing + // Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied with matrix A columns. 
+ ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else //handles non-transpose case + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and multiplies it with the A + // matrix. This loop is processing + // Z_MR x K The inner loop broadcasts the + // B matrix data and multiplies it with + // the A matrix. This loop is processing + // Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm14 = _mm256_permute_pd(ymm14, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm11 = _mm256_addsub_pd(ymm11, ymm5); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + ymm12 = _mm256_addsub_pd(ymm12, ymm7); + ymm10 = _mm256_addsub_pd(ymm10, ymm14); + ymm13 = _mm256_addsub_pd(ymm13, ymm15); + + // alpha, beta multiplication. 
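+                // Complex alpha scaling: ymm0 holds broadcast (alpha.r, alpha.i)
+                // and ymm14 = permute_pd(ymm0, 0x5) * {1,-1,...} gives
+                // (alpha.i, -alpha.r, ...).  For an accumulator x,
+                // hsub(x*ymm0, x*ymm14) yields
+                // (x.r*alpha.r - x.i*alpha.i, x.r*alpha.i + x.i*alpha.r),
+                // i.e. alpha*x, in each dcomplex lane.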
+ ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm10, ymm0); + ymm14 = _mm256_mul_pd(ymm10, ymm14); + ymm10 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm11, ymm0); + ymm14 = _mm256_mul_pd(ymm11, ymm14); + ymm11 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm12, ymm0); + ymm14 = _mm256_mul_pd(ymm12, ymm14); + ymm12 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm13, ymm0); + ymm14 = _mm256_mul_pd(ymm13, ymm14); + ymm13 = _mm256_hsub_pd(ymm15, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + + BLIS_SET_YMM_REG_ZEROS + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + ymm0 = _mm256_loadu_pd((double const *)tC); + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + ymm0 = _mm256_loadu_pd((double const *)(tC + 2)); + ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6); + ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7); + + ymm0 = _mm256_loadu_pd((double const *)(tC + ldc)); + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc + 2)); + ymm16 = _mm256_fmadd_pd(ymm0, ymm2, ymm16); + ymm17 = _mm256_fmadd_pd(ymm0, ymm3, ymm17); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc * 2)); + ymm18 = _mm256_fmadd_pd(ymm0, ymm2, ymm18); + ymm19 = _mm256_fmadd_pd(ymm0, ymm3, ymm19); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc * 2 + 2)); + ymm20 = _mm256_fmadd_pd(ymm0, ymm2, ymm20); + ymm21 = _mm256_fmadd_pd(ymm0, ymm3, ymm21); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + ymm17 = _mm256_permute_pd(ymm17, 0x5); + ymm19 = _mm256_permute_pd(ymm19, 0x5); + ymm21 = _mm256_permute_pd(ymm21, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm6 = _mm256_addsub_pd(ymm6, ymm7); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + ymm16 = _mm256_addsub_pd(ymm16, ymm17); + ymm18 = _mm256_addsub_pd(ymm18, ymm19); + ymm20 = _mm256_addsub_pd(ymm20, ymm21); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm11 = _mm256_add_pd(ymm11, ymm6); + ymm9 = _mm256_add_pd(ymm9, ymm14); + ymm12 = _mm256_add_pd(ymm12, ymm16); + ymm10 = _mm256_add_pd(ymm10, ymm18); + ymm13 = _mm256_add_pd(ymm13, ymm20); + + _mm256_storeu_pd((double *)tC, ymm8); + _mm256_storeu_pd((double *)(tC + 2), ymm11); + + tC += ldc; + + _mm256_storeu_pd((double *)tC, ymm9); + _mm256_storeu_pd((double *)(tC + 2), ymm12); + + tC += ldc; + + _mm256_storeu_pd((double *)tC, ymm10); + _mm256_storeu_pd((double *)(tC + 2), ymm13); + } + n_remainder = N - col_idx; + if (n_remainder == 2) + { + //pointer math to point to proper memory + tC = C + 
ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + // clear scratch registers. + + + BLIS_SET_ALL_YMM_REG_ZEROS + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and multiplies it with the A + // matrix. This loop is processing + // Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + + tptr += (tb_inc_row * 2); + tA += lda; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and multiplies it with the A + // matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied with matrix A columns. + ymm0 = _mm256_loadu_pd((double const*)tA); + ymm1 = _mm256_loadu_pd((double const*) + (tA + 2)); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and multiplies it with the A + // matrix. This loop is processing + // Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied with matrix A columns. 
+ ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + + tptr += (tb_inc_row * 2); + tA += lda; + } + + } + else //handles non-transpose case + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and multiplies it with the A + // matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied with matrix A columns. + ymm0 = _mm256_loadu_pd((double const*)tA); + ymm1 = _mm256_loadu_pd((double const*) + (tA + 2)); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda; + } + + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm11 = _mm256_addsub_pd(ymm11, ymm5); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + ymm12 = _mm256_addsub_pd(ymm12, ymm7); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm11, ymm0); + ymm14 = _mm256_mul_pd(ymm11, ymm14); + ymm11 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm12, ymm0); + ymm14 = _mm256_mul_pd(ymm12, ymm14); + ymm12 = _mm256_hsub_pd(ymm15, ymm14); + + + BLIS_SET_YMM_REG_ZEROS + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. 
+ ymm0 = _mm256_loadu_pd((double const *)tC); + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + ymm0 = _mm256_loadu_pd((double const *)(tC + 2)); + ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6); + ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7); + + ymm0 = _mm256_loadu_pd((double const *)(tC + ldc)); + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc + 2)); + ymm16 = _mm256_fmadd_pd(ymm0, ymm2, ymm16); + ymm17 = _mm256_fmadd_pd(ymm0, ymm3, ymm17); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + ymm17 = _mm256_permute_pd(ymm17, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm6 = _mm256_addsub_pd(ymm6, ymm7); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + ymm16 = _mm256_addsub_pd(ymm16, ymm17); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm11 = _mm256_add_pd(ymm11, ymm6); + ymm9 = _mm256_add_pd(ymm9, ymm14); + ymm12 = _mm256_add_pd(ymm12, ymm16); + + _mm256_storeu_pd((double *)(tC + 0), ymm8); + _mm256_storeu_pd((double *)(tC + 2), ymm11); + tC += ldc; + _mm256_storeu_pd((double *)tC, ymm9); + _mm256_storeu_pd((double *)(tC + 2), ymm12); + } + + if (n_remainder == 1) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + // clear scratch registers. + + + BLIS_SET_ALL_YMM_REG_ZEROS + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and multiplies it with the A + // matrix. This loop is processing + // Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tptr += (tb_inc_row * 2); + tA += lda; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + tptr += tb_inc_row*2; + + //broadcasted matrix B elements are + //multiplied with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tA += lda; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. 
+ // This loop is processing Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tptr += (tb_inc_row * 2); + tA += lda; + } + } + else //handles non-transpose case + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + tptr += tb_inc_row*2; + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tA += lda; + } + + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm5 = _mm256_permute_pd(ymm5, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm11 = _mm256_addsub_pd(ymm11, ymm5); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm11, ymm0); + ymm14 = _mm256_mul_pd(ymm11, ymm14); + ymm11 = _mm256_hsub_pd(ymm15, ymm14); + + + BLIS_SET_YMM_REG_ZEROS + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. 
+ ymm0 = _mm256_loadu_pd((double const *)tC); + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + ymm0 = _mm256_loadu_pd((double const *)(tC + 2)); + ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6); + ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7); + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm6 = _mm256_addsub_pd(ymm6, ymm7); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm11 = _mm256_add_pd(ymm11, ymm6); + + _mm256_storeu_pd((double *)tC, ymm8); + _mm256_storeu_pd((double *)(tC + 2), ymm11); + } + } + m_remainder = M - row_idx; + + if ((m_remainder == 3)) + { + m_remainder -= 3; + __m128d xmm0; + + for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + + BLIS_SET_ALL_YMM_REG_ZEROS + + xmm0 = _mm_setzero_pd(); + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *)(tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. 
+ // This loop is processing Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *)(tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + } + + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. 
+ // This loop is processing Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + } + + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm14 = _mm256_permute_pd(ymm14, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm11 = _mm256_addsub_pd(ymm11, ymm5); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + ymm12 = _mm256_addsub_pd(ymm12, ymm7); + ymm10 = _mm256_addsub_pd(ymm10, ymm14); + ymm13 = _mm256_addsub_pd(ymm13, ymm15); + // alpha, beta multiplication. + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm10, ymm0); + ymm14 = _mm256_mul_pd(ymm10, ymm14); + ymm10 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm11, ymm0); + ymm14 = _mm256_mul_pd(ymm11, ymm14); + ymm11 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm12, ymm0); + ymm14 = _mm256_mul_pd(ymm12, ymm14); + ymm12 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm13, ymm0); + ymm14 = _mm256_mul_pd(ymm13, ymm14); + ymm13 = _mm256_hsub_pd(ymm15, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + + + BLIS_SET_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. 
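+				/*
+				 * Note: in this m_remainder == 3 path each column holds three
+				 * dcomplex elements (six doubles).  They are read as one full
+				 * ymm (elements 0-1) plus an xmm holding element 2, inserted
+				 * into the low lane of a second ymm; only that low lane is
+				 * written back later via _mm256_extractf128_pd(.., 0) +
+				 * _mm_storeu_pd.
+				 */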
+ ymm0 = _mm256_loadu_pd((double const *)tC); + xmm0 = _mm_loadu_pd((double const *)(tC + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + ymm6 = _mm256_fmadd_pd(ymm1, ymm2, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc)); + xmm0 = _mm_loadu_pd((double const *) + (tC + ldc + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + ymm16 = _mm256_fmadd_pd(ymm1, ymm2, ymm16); + ymm17 = _mm256_fmadd_pd(ymm1, ymm3, ymm17); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc * 2)); + xmm0 = _mm_loadu_pd((double const *) + (tC + ldc * 2 + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm18 = _mm256_fmadd_pd(ymm0, ymm2, ymm18); + ymm19 = _mm256_fmadd_pd(ymm0, ymm3, ymm19); + ymm20 = _mm256_fmadd_pd(ymm1, ymm2, ymm20); + ymm21 = _mm256_fmadd_pd(ymm1, ymm3, ymm21); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + ymm17 = _mm256_permute_pd(ymm17, 0x5); + ymm19 = _mm256_permute_pd(ymm19, 0x5); + ymm21 = _mm256_permute_pd(ymm21, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm6 = _mm256_addsub_pd(ymm6, ymm7); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + ymm16 = _mm256_addsub_pd(ymm16, ymm17); + ymm18 = _mm256_addsub_pd(ymm18, ymm19); + ymm20 = _mm256_addsub_pd(ymm20, ymm21); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm11 = _mm256_add_pd(ymm11, ymm6); + ymm9 = _mm256_add_pd(ymm9, ymm14); + ymm12 = _mm256_add_pd(ymm12, ymm16); + ymm10 = _mm256_add_pd(ymm10, ymm18); + ymm13 = _mm256_add_pd(ymm13, ymm20); + + _mm256_storeu_pd((double *)tC, ymm8); + xmm0 = _mm256_extractf128_pd(ymm11, 0); + _mm_storeu_pd((double *)(tC + 2), xmm0); + + tC += ldc; + + _mm256_storeu_pd((double *)tC, ymm9); + xmm0 = _mm256_extractf128_pd(ymm12, 0); + _mm_storeu_pd((double *)(tC + 2), xmm0); + + tC += ldc; + + _mm256_storeu_pd((double *)tC, ymm10); + xmm0 = _mm256_extractf128_pd(ymm13, 0); + _mm_storeu_pd((double *)(tC + 2), xmm0); + } + n_remainder = N - col_idx; + if (n_remainder == 2) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + // clear scratch registers. + + BLIS_SET_ALL_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd((tptr + + tb_inc_col + * 0)); + ymm3 = _mm256_broadcast_sd((tptr + + tb_inc_col + * 0 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. 
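+					/*
+					 * Note: conjugation is handled with sign vectors rather
+					 * than separate kernels.  ymm20 = {1,-1,1,-1} negates the
+					 * imaginary parts of the loaded A vector, i.e. conj(A);
+					 * ymm21 = {-1,-1,-1,-1} negates the broadcast imaginary
+					 * part of B (ymm3), which is the only place B's imaginary
+					 * part enters the products, i.e. conj(B).
+					 */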
+ ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd((tptr + + tb_inc_col + * 0)); + ymm3 = _mm256_broadcast_sd((tptr + + tb_inc_col + * 0 + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd((tptr + + tb_inc_col + * 0)); + ymm3 = _mm256_broadcast_sd((tptr + + tb_inc_col + * 0 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd((tptr + + tb_inc_col + * 0)); + ymm3 = _mm256_broadcast_sd((tptr + + tb_inc_col + * 0 + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. 
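+					/*
+					 * Note: tptr views B as double*, so one dcomplex spans two
+					 * doubles.  That is why column offsets in the broadcasts
+					 * are multiplied by 2 (e.g. tb_inc_col * 2) and the k loop
+					 * advances with tptr += tb_inc_row * 2; the "+ 1" selects
+					 * the imaginary part of the same element (assuming
+					 * tb_inc_row/tb_inc_col count dcomplex elements, as set up
+					 * from ldb).
+					 */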
+ ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda; + } + + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm11 = _mm256_addsub_pd(ymm11, ymm5); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + ymm12 = _mm256_addsub_pd(ymm12, ymm7); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm11, ymm0); + ymm14 = _mm256_mul_pd(ymm11, ymm14); + ymm11 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm12, ymm0); + ymm14 = _mm256_mul_pd(ymm12, ymm14); + ymm12 = _mm256_hsub_pd(ymm15, ymm14); + + + + BLIS_SET_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. 
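+				/*
+				 * Note: the permute/mul/hsub sequences a few lines above scale
+				 * each accumulator (ymm8, ymm11, ymm9, ymm12 here) by alpha
+				 * before beta*C is added below.  With alpha = (ar, ai) and an
+				 * accumulator element x = (xr, xi), per 128-bit lane:
+				 *
+				 *   ymm15 = { xr*ar,  xi*ai }
+				 *   ymm14 = { xr*ai, -xi*ar }
+				 *   hsub  -> { xr*ar - xi*ai, xr*ai + xi*ar } = x*alpha
+				 */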
+ ymm0 = _mm256_loadu_pd((double const *)tC); + xmm0 = _mm_loadu_pd((double const *)(tC + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + ymm6 = _mm256_fmadd_pd(ymm1, ymm2, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm0 = _mm256_loadu_pd((double const *)(tC + ldc)); + xmm0 = _mm_loadu_pd((double const *)(tC + ldc + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + ymm16 = _mm256_fmadd_pd(ymm1, ymm2, ymm16); + ymm17 = _mm256_fmadd_pd(ymm1, ymm3, ymm17); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + ymm17 = _mm256_permute_pd(ymm17, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm6 = _mm256_addsub_pd(ymm6, ymm7); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + ymm16 = _mm256_addsub_pd(ymm16, ymm17); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm11 = _mm256_add_pd(ymm11, ymm6); + ymm9 = _mm256_add_pd(ymm9, ymm14); + ymm12 = _mm256_add_pd(ymm12, ymm16); + + _mm256_storeu_pd((double *)tC, ymm8); + xmm0 = _mm256_extractf128_pd(ymm11, 0); + _mm_storeu_pd((double *)(tC + 2), xmm0); + + tC += ldc; + _mm256_storeu_pd((double *)tC, ymm9); + xmm0 = _mm256_extractf128_pd(ymm12, 0); + _mm_storeu_pd((double *)(tC + 2), xmm0); + } + if (n_remainder == 1) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + // clear scratch registers. + + + BLIS_SET_ALL_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. 
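+					/*
+					 * Note: throughout these k loops ymm8/ymm11 accumulate the
+					 * A vector times the broadcast real part of B, while
+					 * ymm4/ymm5 accumulate it times the imaginary part.  After
+					 * the loop, permute_pd(.., 0x5) swaps each (real, imag)
+					 * pair within its lane and addsub_pd folds the two
+					 * accumulators into the complex product:
+					 *
+					 *   real: ar*br - ai*bi   (even slots, subtracted)
+					 *   imag: ai*br + ar*bi   (odd slots, added)
+					 */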
+ ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tptr += tb_inc_row*2; + tA += lda; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm5 = _mm256_permute_pd(ymm5, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm11 = _mm256_addsub_pd(ymm11, ymm5); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm11, ymm0); + ymm14 = _mm256_mul_pd(ymm11, ymm14); + ymm11 = _mm256_hsub_pd(ymm15, ymm14); + + + + BLIS_SET_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. 
+ ymm0 = _mm256_loadu_pd((double const *)tC); + xmm0 = _mm_loadu_pd((double const *)(tC + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + ymm6 = _mm256_fmadd_pd(ymm1, ymm2, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm6 = _mm256_addsub_pd(ymm6, ymm7); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm11 = _mm256_add_pd(ymm11, ymm6); + + _mm256_storeu_pd((double *)tC, ymm8); + xmm0 = _mm256_extractf128_pd(ymm11, 0); + _mm_storeu_pd((double *)(tC + 2), xmm0); + } + } + if ((m_remainder == 2)) + { + m_remainder -= 2; + + for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + + + BLIS_SET_ALL_YMM_REG_ZEROS + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. 
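+					/*
+					 * Note: with m_remainder == 2 the two remaining dcomplex
+					 * rows exactly fill one ymm register, so only ymm0 is
+					 * loaded per k step and the results are written back with
+					 * full _mm256_storeu_pd stores; no xmm fringe handling is
+					 * needed in this block.
+					 */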
+ ymm0 = _mm256_loadu_pd((double const *)tA); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. 
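+					/*
+					 * Note: the three B columns touched per k step sit at
+					 * double offsets 0, tb_inc_col*2 and (tb_inc_col*2)*2 from
+					 * tptr, i.e. column index times two doubles per dcomplex.
+					 * The matching accumulator pairs are ymm8/ymm4, ymm9/ymm6
+					 * and ymm10/ymm14.
+					 */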
+ ymm0 = _mm256_loadu_pd((double const *)tA); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + ymm14 = _mm256_permute_pd(ymm14, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + ymm10 = _mm256_addsub_pd(ymm10, ymm14); + // alpha, beta multiplication. + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm10, ymm0); + ymm14 = _mm256_mul_pd(ymm10, ymm14); + ymm10 = _mm256_hsub_pd(ymm15, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + + BLIS_SET_YMM_REG_ZEROS + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + ymm0 = _mm256_loadu_pd((double const *)tC); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + ymm0 = _mm256_loadu_pd((double const *)(tC + ldc)); + + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc * 2)); + + ymm18 = _mm256_fmadd_pd(ymm0, ymm2, ymm18); + ymm19 = _mm256_fmadd_pd(ymm0, ymm3, ymm19); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + ymm19 = _mm256_permute_pd(ymm19, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + ymm18 = _mm256_addsub_pd(ymm18, ymm19); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm9 = _mm256_add_pd(ymm9, ymm14); + ymm10 = _mm256_add_pd(ymm10, ymm18); + + _mm256_storeu_pd((double *)tC, ymm8); + + tC += ldc; + + _mm256_storeu_pd((double *)tC, ymm9); + + tC += ldc; + + _mm256_storeu_pd((double *)tC, ymm10); + } + n_remainder = N - col_idx; + if (n_remainder == 2) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + // clear scratch registers. + + BLIS_SET_ALL_YMM_REG_ZEROS + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. 
+ ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. 
+ ymm0 = _mm256_loadu_pd((double const *)tA); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda; + } + + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + + BLIS_SET_YMM_REG_ZEROS + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + ymm0 = _mm256_loadu_pd((double const *)tC); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + ymm0 = _mm256_loadu_pd((double const *)(tC + ldc)); + + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm9 = _mm256_add_pd(ymm9, ymm14); + + _mm256_storeu_pd((double *)tC, ymm8); + tC += ldc; + _mm256_storeu_pd((double *)tC, ymm9); + } + if (n_remainder == 1) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + // clear scratch registers. + + + BLIS_SET_ALL_YMM_REG_ZEROS + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. 
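+					/*
+					 * Note: the main loop above processes N in groups of three
+					 * columns; this n_remainder == 1 block (and the
+					 * n_remainder == 2 block before it) reuses the same k loop
+					 * with fewer B broadcasts and fewer accumulators, producing
+					 * one (or two) result column(s) for this row block.
+					 */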
+ ymm0 = _mm256_loadu_pd((double const *)tA); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda; + } + + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + + + BLIS_SET_YMM_REG_ZEROS + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + ymm0 = _mm256_loadu_pd((double const *)tC); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + + _mm256_storeu_pd((double *)tC, ymm8); + } + } + if ((m_remainder == 1)) + { + m_remainder -= 1; + __m128d xmm0; + + for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + + + BLIS_SET_ALL_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. 
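+					/*
+					 * Note: in this m_remainder == 1 path only one dcomplex
+					 * element of A (and of C) exists per column, so it is
+					 * loaded with _mm_loadu_pd into the low lane of a ymm; the
+					 * upper lane's contents do not matter because only the low
+					 * lane is extracted and stored back at the end.
+					 */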
+ xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are multiplied + //with matrix A columns. 
+ xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + ymm14 = _mm256_permute_pd(ymm14, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + ymm10 = _mm256_addsub_pd(ymm10, ymm14); + // alpha, beta multiplication. + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm10, ymm0); + ymm14 = _mm256_mul_pd(ymm10, ymm14); + ymm10 = _mm256_hsub_pd(ymm15, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + BLIS_SET_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. 
+ xmm0 = _mm_loadu_pd((double const *)(tC)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + xmm0 = _mm_loadu_pd((double const *)(tC + ldc)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + + xmm0 = _mm_loadu_pd((double const *)(tC + ldc * 2)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm18 = _mm256_fmadd_pd(ymm0, ymm2, ymm18); + ymm19 = _mm256_fmadd_pd(ymm0, ymm3, ymm19); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + ymm19 = _mm256_permute_pd(ymm19, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + ymm18 = _mm256_addsub_pd(ymm18, ymm19); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm9 = _mm256_add_pd(ymm9, ymm14); + ymm10 = _mm256_add_pd(ymm10, ymm18); + + xmm0 = _mm256_extractf128_pd(ymm8, 0); + _mm_storeu_pd((double *)tC, xmm0); + + tC += ldc; + + xmm0 = _mm256_extractf128_pd(ymm9, 0); + _mm_storeu_pd((double *)tC, xmm0); + + tC += ldc; + xmm0 = _mm256_extractf128_pd(ymm10, 0); + _mm_storeu_pd((double *)tC, xmm0); + } + n_remainder = N - col_idx; + if (n_remainder == 2) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + // clear scratch registers. + + + BLIS_SET_ALL_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. 
+ xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + + + BLIS_SET_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. 
+ xmm0 = _mm_loadu_pd((double const *)(tC)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + xmm0 = _mm_loadu_pd((double const *)(tC + ldc)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm9 = _mm256_add_pd(ymm9, ymm14); + + xmm0 = _mm256_extractf128_pd(ymm8, 0); + _mm_storeu_pd((double *)tC, xmm0); + tC += ldc; + xmm0 = _mm256_extractf128_pd(ymm9, 0); + _mm_storeu_pd((double *)tC, xmm0); + } + if (n_remainder == 1) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + // clear scratch registers. + + BLIS_SET_ALL_YMM_REG_ZEROS + + xmm0 = _mm_setzero_pd(); + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda; + } + + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. 
+ ymm2 = _mm256_broadcast_sd(
+ (double const *)
+ (tptr + tb_inc_col * 0));
+ ymm3 = _mm256_broadcast_sd(
+ (double const *)
+ (tptr + tb_inc_col * 0 +
+ 1));
+
+ //broadcasted matrix B elements are
+ //multiplied
+ //with matrix A columns.
+ xmm0 = _mm_loadu_pd((double const *)(tA));
+ ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0);
+
+ ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8);
+ ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
+
+ tptr += tb_inc_row*2;
+ tA += lda;
+ }
+
+ }
+ ymm4 = _mm256_permute_pd(ymm4, 0x5);
+
+ ymm8 = _mm256_addsub_pd(ymm8, ymm4);
+
+ // alpha, beta multiplication.
+ ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast);
+ ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);
+
+ ymm14 = _mm256_permute_pd(ymm0, 0x5);
+ ymm14 = _mm256_mul_pd(ymm14, ymm1);
+ ymm15 = _mm256_mul_pd(ymm8, ymm0);
+ ymm14 = _mm256_mul_pd(ymm8, ymm14);
+ ymm8 = _mm256_hsub_pd(ymm15, ymm14);
+
+
+
+ BLIS_SET_YMM_REG_ZEROS
+ xmm0 = _mm_setzero_pd();
+
+ ymm2 = _mm256_broadcast_sd((double const *)
+ &beta_cast->real);
+ ymm3 = _mm256_broadcast_sd((double const *)
+ &beta_cast->imag);
+
+ if(is_beta_non_zero)
+ {
+ // multiply C by beta and accumulate col 1.
+ xmm0 = _mm_loadu_pd((double const *)(tC));
+ ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0);
+
+ ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4);
+ ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
+ }
+ ymm5 = _mm256_permute_pd(ymm5, 0x5);
+
+ ymm4 = _mm256_addsub_pd(ymm4, ymm5);
+
+ ymm8 = _mm256_add_pd(ymm8, ymm4);
+
+ xmm0 = _mm256_extractf128_pd(ymm8, 0);
+ _mm_storeu_pd((double *)tC, xmm0);
+
+ }
+ }
+ // Return the buffer to pool
+ if ((required_packing_A == 1) && bli_mem_is_alloc( &local_mem_buf_A_s )) {
+#ifdef BLIS_ENABLE_MEM_TRACING
+ printf( "bli_zgemm_small(): releasing mem pool block\n" );
+#endif
+ bli_membrk_release(&rntm,
+ &local_mem_buf_A_s);
+ }
+ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO);
+ return BLIS_SUCCESS;
+ }
+ else
+ {
+ AOCL_DTL_TRACE_EXIT_ERR(
+ AOCL_DTL_LEVEL_INFO,
+ "Invalid dimensions for small gemm."
+ );
+ return BLIS_NONCONFORMAL_DIMENSIONS;
+ }
+};
+
+static err_t bli_zgemm_small_At
+ (
+ obj_t* alpha,
+ obj_t* a,
+ obj_t* b,
+ obj_t* beta,
+ obj_t* c,
+ cntx_t* cntx,
+ cntl_t* cntl
+ )
+{
+ AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO);
+
+ bool conjtransa = bli_obj_has_conj(a);
+ bool conjtransb = bli_obj_has_conj(b);
+
+ gint_t M = bli_obj_length( c ); // number of rows of Matrix C
+ gint_t N = bli_obj_width( c ); // number of columns of Matrix C
+ gint_t K = bli_obj_width_after_trans( a ); // number of columns of OP(A)
+
+
+ if (N<3) // Implementation assumes that N is at least 3.
+ {
+ AOCL_DTL_TRACE_EXIT_ERR(
+ AOCL_DTL_LEVEL_INFO,
+ "N < 3, cannot be processed by small gemm"
+ );
+ return BLIS_NOT_YET_IMPLEMENTED;
+ }
+
+ if( M && N && K )
+ {
+ guint_t lda = bli_obj_col_stride( a ); // column stride of matrix OP(A)
+ guint_t ldb = bli_obj_col_stride( b ); // column stride of matrix OP(B)
+ guint_t ldc = bli_obj_col_stride( c ); // column stride of matrix C
+ guint_t row_idx, col_idx, k;
+ dcomplex *A = bli_obj_buffer_at_off(a); //pointer to elements of Matrix A
+ dcomplex *B = bli_obj_buffer_at_off(b); //pointer to elements of Matrix B
+ dcomplex *C = bli_obj_buffer_at_off(c); //pointer to elements of Matrix C
+
+ dcomplex *tA = A, *tB = B, *tC = C;//, *tA_pack;
+ dcomplex *tA_packed; // temporary pointer to hold the packed A memory pointer
+ guint_t row_idx_packed; //packed A memory row index
+ guint_t lda_packed; //lda of packed A
+ dim_t tb_inc_row = 1; // row stride of matrix B
+ dim_t tb_inc_col = ldb; // column stride of matrix B
+
+ dcomplex *alpha_cast, *beta_cast; // alpha, beta scaling factors
+ alpha_cast = bli_obj_buffer_for_1x1(BLIS_DCOMPLEX, alpha);
+ beta_cast = bli_obj_buffer_for_1x1(BLIS_DCOMPLEX, beta);
+
+ gint_t required_packing_A = 1;
+ mem_t local_mem_buf_A_s;
+ dcomplex *D_A_pack = NULL;
+ rntm_t rntm;
+
+ if( bli_obj_has_trans( b ) )
+ {
+ tb_inc_col = 1; // switch row and column strides
+ tb_inc_row = ldb;
+ }
+
+ __m256d ymm4, ymm5, ymm6, ymm7;
+ __m256d ymm8, ymm9, ymm10, ymm11;
+ __m256d ymm12, ymm13, ymm14, ymm15;
+ __m256d ymm16, ymm17, ymm18, ymm19, ymm20, ymm21;
+ __m256d ymm0, ymm1, ymm2, ymm3;
+
+ gint_t n_remainder; // Used when N is not a multiple of 3 (N % 3)
+ gint_t m_remainder; // Used when M is not a multiple of Z_MR (M % Z_MR)
+
+ // Check whether the beta value is zero.
+ // If it is, we perform C = alpha * (A * B)
+ // instead of C = beta * C + alpha * (A * B)
+ bool is_beta_non_zero = 0;
+ if(!bli_obj_equals(beta, &BLIS_ZERO))
+ is_beta_non_zero = 1;
+
+ /*
+ * This function previously used a global array to pack part of the A
+ * input when needed. However, using a global array makes the function
+ * non-reentrant, so instead we allocate a buffer for each invocation.
+ * Since the buffer is too big for the stack and calling malloc on
+ * every invocation would be too expensive, the better approach is to
+ * get the buffer from the pre-allocated pool and return it to the
+ * pool once we are done.
+ *
+ * In order to get the buffer from the pool, we need access to the
+ * memory broker. Currently this function is not invoked in a way that
+ * allows it to receive the memory broker (via rntm). The following
+ * workaround fetches the global memory broker, which can then be used
+ * to access the pool.
+ *
+ * Note that there will be a memory allocation at least on the first
+ * invocation, since no pool will have been created for this size yet.
+ * Subsequent invocations will simply reuse the buffer from the pool.
+ */
+
+ bli_rntm_init_from_global( &rntm );
+ bli_rntm_set_num_threads_only( 1, &rntm );
+ bli_membrk_rntm_set_membrk( &rntm );
+
+ // Get the current size of the buffer pool for A block packing.
+ // We will use the same size to avoid pool re-initialization.
+ siz_t buffer_size = bli_pool_block_size(
+ bli_membrk_pool(bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK),
+ bli_rntm_membrk(&rntm)));
+
+ //
+ // This kernel assumes that "A" will be unpacked if N <= 3.
+ // Usually this range (N <= 3) is handled by SUP, however, + // if SUP is disabled or for any other condition if we do + // enter this kernel with N <= 3, we want to make sure that + // "A" remains unpacked. + // + // If this check is removed it will result in the crash as + // reported in CPUPL-587. + // + + if ((N < 3) || ((Z_MR * K) << 3) > buffer_size) + { + required_packing_A = 0; + return BLIS_NOT_YET_IMPLEMENTED; + } + + if (required_packing_A == 1) + { +#ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_dgemm_small: Requesting mem pool block of size %lu\n", + buffer_size); #endif + // Get the buffer from the pool. + bli_membrk_acquire_m(&rntm, + buffer_size, + BLIS_BITVAL_BUFFER_FOR_A_BLOCK, + &local_mem_buf_A_s); + + D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); + } + + /* + * The computation loop runs for D_MRxN columns of C matrix, thus + * accessing the D_MRxK A matrix data and KxNR B matrix data. + * The computation is organized as inner loops of dimension D_MRxNR. + */ + // Process D_MR rows of C matrix at a time. + for (row_idx = 0; (row_idx + (Z_MR - 1)) < M; row_idx += Z_MR) + { + + tA = A + row_idx * lda; + tA_packed = D_A_pack; + lda_packed = Z_MR; + + // Pack 16xk of matrix A into buffer + // continuous access for A and strided stores to B + for(inc_t x = 0; (x) < 2; x += 1) + { + dcomplex* tA_temp = tA; + + for(k = 0; (k+1) < K; k += 2) + { + ymm0 = _mm256_loadu_pd((double const *) + (tA_temp + 0 * lda)); + ymm2 = _mm256_loadu_pd((double const *) + (tA_temp + 1 * lda)); + + ymm6 = _mm256_permute2f128_pd(ymm0,ymm2,0x20); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm2,0x31); + + _mm256_storeu_pd((double *) + (tA_packed + 0 * lda_packed), + ymm6); + _mm256_storeu_pd((double *) + (tA_packed + 1 * lda_packed), + ymm7); + + tA_temp += 2; + tA_packed += 2 * lda_packed; + } + + for(; k < K; k += 1) + { + tA_packed[0].real = tA_temp[0 * lda].real; + tA_packed[0].imag = tA_temp[0 * lda].imag; + tA_packed[1].real = tA_temp[1 * lda].real; + tA_packed[1].imag = tA_temp[1 * lda].imag; + + tA_temp += 1; + tA_packed += lda_packed; + } + + tA += 2 * lda; + tA_packed = D_A_pack + (x + 1)*2; + } + + tA_packed = D_A_pack; + row_idx_packed = 0; + lda_packed = Z_MR; + + // Process NR columns of C matrix at a time. + for (col_idx = 0; (col_idx + (NR - 1)) < N; col_idx += NR) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; +#ifdef BLIS_ENABLE_PREFETCH + _mm_prefetch((char*)(tC + 0), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 8), _MM_HINT_T0); + _mm_prefetch((char*)(tC + ldc), _MM_HINT_T0); + _mm_prefetch((char*)(tC + ldc + 8), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 2 * ldc + 8), _MM_HINT_T0); +#endif + // clear scratch registers. + + BLIS_SET_ALL_YMM_REG_ZEROS + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. 
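+ // Note: ymm20 = {1,-1,1,-1} flips the sign of the imaginary
+ // parts of the loaded A elements (conjugate of A), while
+ // ymm21 = {-1,-1,-1,-1} negates the broadcast imaginary part
+ // of B (conjugate of B). The same masks are reused in the
+ // single-conjugate branches below.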
+ ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. 
+ ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm14 = _mm256_permute_pd(ymm14, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm11 = _mm256_addsub_pd(ymm11, ymm5); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + ymm12 = _mm256_addsub_pd(ymm12, ymm7); + ymm10 = _mm256_addsub_pd(ymm10, ymm14); + ymm13 = _mm256_addsub_pd(ymm13, ymm15); + + // alpha, beta multiplication. 
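+ // Note: the permute/addsub pair above folds the two partial
+ // accumulators per element of C (one formed with the real part
+ // of B broadcast, one with the imaginary part) into interleaved
+ // complex results: real = ar*br - ai*bi, imag = ai*br + ar*bi.
+ // The permute/mul/hsub sequence below then applies the complex
+ // scalar alpha in the same spirit: for each x in ymm8..ymm13,
+ // hsub(x*{ar,ai}, x*{ai,-ar}) yields {xr*ar - xi*ai, xr*ai + xi*ar}
+ // per 128-bit lane, i.e. alpha * x.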
+ ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm10, ymm0); + ymm14 = _mm256_mul_pd(ymm10, ymm14); + ymm10 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm11, ymm0); + ymm14 = _mm256_mul_pd(ymm11, ymm14); + ymm11 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm12, ymm0); + ymm14 = _mm256_mul_pd(ymm12, ymm14); + ymm12 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm13, ymm0); + ymm14 = _mm256_mul_pd(ymm13, ymm14); + ymm13 = _mm256_hsub_pd(ymm15, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + (&beta_cast->imag)); + + + + BLIS_SET_YMM_REG_ZEROS + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + ymm0 = _mm256_loadu_pd((double const *)tC); + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + ymm0 = _mm256_loadu_pd((double const *)(tC + 2)); + ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6); + ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7); + + // col 2 + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc)); + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc + 2)); + ymm16 = _mm256_fmadd_pd(ymm0, ymm2, ymm16); + ymm17 = _mm256_fmadd_pd(ymm0, ymm3, ymm17); + + // col 3 + ymm0 = _mm256_loadu_pd((double const *) + (tC + (ldc * 2))); + ymm18 = _mm256_fmadd_pd(ymm0, ymm2, ymm18); + ymm19 = _mm256_fmadd_pd(ymm0, ymm3, ymm19); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + (ldc * 2) + 2)); + ymm20 = _mm256_fmadd_pd(ymm0, ymm2, ymm20); + ymm21 = _mm256_fmadd_pd(ymm0, ymm3, ymm21); + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + ymm17 = _mm256_permute_pd(ymm17, 0x5); + ymm19 = _mm256_permute_pd(ymm19, 0x5); + ymm21 = _mm256_permute_pd(ymm21, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm6 = _mm256_addsub_pd(ymm6, ymm7); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + ymm16 = _mm256_addsub_pd(ymm16, ymm17); + ymm18 = _mm256_addsub_pd(ymm18, ymm19); + ymm20 = _mm256_addsub_pd(ymm20, ymm21); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm11 = _mm256_add_pd(ymm11, ymm6); + ymm9 = _mm256_add_pd(ymm9, ymm14); + ymm12 = _mm256_add_pd(ymm12, ymm16); + ymm10 = _mm256_add_pd(ymm10, ymm18); + ymm13 = _mm256_add_pd(ymm13, ymm20); + + _mm256_storeu_pd((double *)tC, ymm8); + _mm256_storeu_pd((double *)(tC + 2), ymm11); + + tC += ldc; + + _mm256_storeu_pd((double *)tC, ymm9); + _mm256_storeu_pd((double *)(tC + 2), ymm12); + + tC += ldc; + + _mm256_storeu_pd((double *)tC, ymm10); + _mm256_storeu_pd((double *)(tC + 2), ymm13); + + } + n_remainder = N - col_idx; + + // if the N is not multiple of 3. 
+ // handling edge case. + if (n_remainder == 2) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + + // clear scratch registers. + + + BLIS_SET_ALL_YMM_REG_ZEROS + double *tptr = (double *)tB; + + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const*)tA); + ymm1 = _mm256_loadu_pd((double const*) + (tA + 2)); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda_packed; + + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const*)tA); + ymm1 = _mm256_loadu_pd((double const*) + (tA + 2)); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. 
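+ // Note: with only two columns of C remaining, the accumulator
+ // pairs (ymm8, ymm11) and (ymm9, ymm12) hold the two columns,
+ // while (ymm4, ymm5) and (ymm6, ymm7) collect the matching
+ // imaginary-broadcast partial products.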
+ ymm0 = _mm256_loadu_pd((double const*)tA); + ymm1 = _mm256_loadu_pd((double const*) + (tA + 2)); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const*)tA); + ymm1 = _mm256_loadu_pd((double const*) + (tA + 2)); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + + + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm11 = _mm256_addsub_pd(ymm11, ymm5); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + ymm12 = _mm256_addsub_pd(ymm12, ymm7); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm11, ymm0); + ymm14 = _mm256_mul_pd(ymm11, ymm14); + ymm11 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm12, ymm0); + ymm14 = _mm256_mul_pd(ymm12, ymm14); + ymm12 = _mm256_hsub_pd(ymm15, ymm14); + + + + BLIS_SET_YMM_REG_ZEROS + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. 
+ ymm0 = _mm256_loadu_pd((double const *)tC); + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + ymm0 = _mm256_loadu_pd((double const *)(tC + 2)); + ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6); + ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7); + + ymm0 = _mm256_loadu_pd((double const *)(tC + ldc)); + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc + 2)); + ymm16 = _mm256_fmadd_pd(ymm0, ymm2, ymm16); + ymm17 = _mm256_fmadd_pd(ymm0, ymm3, ymm17); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + ymm17 = _mm256_permute_pd(ymm17, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm6 = _mm256_addsub_pd(ymm6, ymm7); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + ymm16 = _mm256_addsub_pd(ymm16, ymm17); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm11 = _mm256_add_pd(ymm11, ymm6); + ymm9 = _mm256_add_pd(ymm9, ymm14); + ymm12 = _mm256_add_pd(ymm12, ymm16); + + _mm256_storeu_pd((double *)(tC + 0), ymm8); + _mm256_storeu_pd((double *)(tC + 2), ymm11); + tC += ldc; + _mm256_storeu_pd((double *)tC, ymm9); + _mm256_storeu_pd((double *)(tC + 2), ymm12); + } + // if the N is not multiple of 3. + // handling edge case. + if (n_remainder == 1) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + + // clear scratch registers. + BLIS_SET_ALL_YMM_REG_ZEROS + double *tptr = (double *)tB; + + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + tptr += tb_inc_row*2; + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *)(tA + 2)); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tA += lda_packed; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd((double const *)(tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd((double const *)(tptr + tb_inc_col * 0 + 1)); + tptr += tb_inc_row*2; + + //broadcasted matrix B elements are multiplied + //with matrix A columns. 
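+ // Note: with a single column of C remaining, only one pair of
+ // broadcasts (real/imaginary of one B element) is needed per k
+ // iteration; ymm8/ymm11 accumulate the real-broadcast products
+ // and ymm4/ymm5 the imaginary-broadcast ones.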
+ ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tA += lda_packed; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + tptr += tb_inc_row*2; + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tA += lda_packed; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + tptr += tb_inc_row*2; + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tA += lda_packed; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm5 = _mm256_permute_pd(ymm5, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm11 = _mm256_addsub_pd(ymm11, ymm5); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm11, ymm0); + ymm14 = _mm256_mul_pd(ymm11, ymm14); + ymm11 = _mm256_hsub_pd(ymm15, ymm14); + + + + BLIS_SET_YMM_REG_ZEROS + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. 
+ ymm0 = _mm256_loadu_pd((double const *)tC); + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + ymm0 = _mm256_loadu_pd((double const *)(tC + 2)); + ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6); + ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7); + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm6 = _mm256_addsub_pd(ymm6, ymm7); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm11 = _mm256_add_pd(ymm11, ymm6); + + _mm256_storeu_pd((double *)tC, ymm8); + _mm256_storeu_pd((double *)(tC + 2), ymm11); + } + } + + m_remainder = M - row_idx; + if ((m_remainder == 3)) + { + m_remainder -= 3; + __m128d xmm0; + + tA = A + row_idx * lda; + tA_packed = D_A_pack; + lda_packed = 3; + { + dcomplex* tA_temp = tA; + + for(k = 0; (k+1) < K; k += 2) + { + ymm0 = _mm256_loadu_pd((double const *) + (tA_temp + 0 * lda)); + ymm2 = _mm256_loadu_pd((double const *) + (tA_temp + 1 * lda)); + ymm3 = _mm256_loadu_pd((double const *) + (tA_temp + 2 * lda)); + + ymm6 = _mm256_permute2f128_pd(ymm0,ymm2,0x20); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm2,0x31); + + _mm256_storeu_pd((double *) + (tA_packed + 0 * lda_packed), + ymm6); + xmm0 = _mm256_extractf128_pd(ymm3, 0); + _mm_storeu_pd((double *) + (tA_packed + 0 * lda_packed + 2), + xmm0); + + _mm256_storeu_pd((double *) + (tA_packed + 1 * lda_packed), + ymm7); + xmm0 = _mm256_extractf128_pd(ymm3, 1); + _mm_storeu_pd((double *) + (tA_packed + 1 * lda_packed + 2), + xmm0); + + tA_temp += 2; + tA_packed += 2 * lda_packed; + } + + for(; k < K; k += 1) + { + tA_packed[0].real = tA_temp[0 * lda].real; + tA_packed[0].imag = tA_temp[0 * lda].imag; + tA_packed[1].real = tA_temp[1 * lda].real; + tA_packed[1].imag = tA_temp[1 * lda].imag; + tA_packed[2].real = tA_temp[2 * lda].real; + tA_packed[2].imag = tA_temp[2 * lda].imag; + + tA_temp += 1; + tA_packed += lda_packed; + } + } + + tA_packed = D_A_pack; + row_idx_packed = 0; + lda_packed = 3; + + for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + + + BLIS_SET_ALL_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. 
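+ // Note: for the 3-row tail, the first two dcomplex elements of
+ // the packed A column are loaded into ymm0 and the third into
+ // the lower 128-bit lane of ymm1 via xmm0; only that lane of
+ // the ymm1-based accumulators is written back to C, so the
+ // upper lane is ignored.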
+ ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. 
+ // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. 
+ ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm14 = _mm256_permute_pd(ymm14, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm11 = _mm256_addsub_pd(ymm11, ymm5); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + ymm12 = _mm256_addsub_pd(ymm12, ymm7); + ymm10 = _mm256_addsub_pd(ymm10, ymm14); + ymm13 = _mm256_addsub_pd(ymm13, ymm15); + // alpha, beta multiplication. + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm10, ymm0); + ymm14 = _mm256_mul_pd(ymm10, ymm14); + ymm10 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm11, ymm0); + ymm14 = _mm256_mul_pd(ymm11, ymm14); + ymm11 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm12, ymm0); + ymm14 = _mm256_mul_pd(ymm12, ymm14); + ymm12 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm13, ymm0); + ymm14 = _mm256_mul_pd(ymm13, ymm14); + ymm13 = _mm256_hsub_pd(ymm15, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + + BLIS_SET_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. 
+ ymm0 = _mm256_loadu_pd((double const *)tC); + xmm0 = _mm_loadu_pd((double const *)(tC + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + ymm6 = _mm256_fmadd_pd(ymm1, ymm2, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc)); + xmm0 = _mm_loadu_pd((double const *) + (tC + ldc + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + ymm16 = _mm256_fmadd_pd(ymm1, ymm2, ymm16); + ymm17 = _mm256_fmadd_pd(ymm1, ymm3, ymm17); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc * 2)); + xmm0 = _mm_loadu_pd((double const *) + (tC + ldc * 2 + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm18 = _mm256_fmadd_pd(ymm0, ymm2, ymm18); + ymm19 = _mm256_fmadd_pd(ymm0, ymm3, ymm19); + ymm20 = _mm256_fmadd_pd(ymm1, ymm2, ymm20); + ymm21 = _mm256_fmadd_pd(ymm1, ymm3, ymm21); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + ymm17 = _mm256_permute_pd(ymm17, 0x5); + ymm19 = _mm256_permute_pd(ymm19, 0x5); + ymm21 = _mm256_permute_pd(ymm21, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm6 = _mm256_addsub_pd(ymm6, ymm7); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + ymm16 = _mm256_addsub_pd(ymm16, ymm17); + ymm18 = _mm256_addsub_pd(ymm18, ymm19); + ymm20 = _mm256_addsub_pd(ymm20, ymm21); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm11 = _mm256_add_pd(ymm11, ymm6); + ymm9 = _mm256_add_pd(ymm9, ymm14); + ymm12 = _mm256_add_pd(ymm12, ymm16); + ymm10 = _mm256_add_pd(ymm10, ymm18); + ymm13 = _mm256_add_pd(ymm13, ymm20); + + _mm256_storeu_pd((double *)tC, ymm8); + xmm0 = _mm256_extractf128_pd(ymm11, 0); + _mm_storeu_pd((double *)(tC + 2), xmm0); + + tC += ldc; + + _mm256_storeu_pd((double *)tC, ymm9); + xmm0 = _mm256_extractf128_pd(ymm12, 0); + _mm_storeu_pd((double *)(tC + 2), xmm0); + + tC += ldc; + + _mm256_storeu_pd((double *)tC, ymm10); + xmm0 = _mm256_extractf128_pd(ymm13, 0); + _mm_storeu_pd((double *)(tC + 2), xmm0); + } + n_remainder = N - col_idx; + if (n_remainder == 2) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + + // clear scratch registers. + BLIS_SET_ALL_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd((tptr + + tb_inc_col + * 0)); + ymm3 = _mm256_broadcast_sd((tptr + + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. 
+ ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd((tptr + + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd((tptr + + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd((tptr + + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd((tptr + + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. 
+ ymm2 = _mm256_broadcast_sd((tptr + + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd((tptr + + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm11 = _mm256_addsub_pd(ymm11, ymm5); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + ymm12 = _mm256_addsub_pd(ymm12, ymm7); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm11, ymm0); + ymm14 = _mm256_mul_pd(ymm11, ymm14); + ymm11 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm12, ymm0); + ymm14 = _mm256_mul_pd(ymm12, ymm14); + ymm12 = _mm256_hsub_pd(ymm15, ymm14); + + + BLIS_SET_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. 
+ ymm0 = _mm256_loadu_pd((double const *)tC); + xmm0 = _mm_loadu_pd((double const *)(tC + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + ymm6 = _mm256_fmadd_pd(ymm1, ymm2, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc)); + xmm0 = _mm_loadu_pd((double const *) + (tC + ldc + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + ymm16 = _mm256_fmadd_pd(ymm1, ymm2, ymm16); + ymm17 = _mm256_fmadd_pd(ymm1, ymm3, ymm17); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + ymm17 = _mm256_permute_pd(ymm17, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm6 = _mm256_addsub_pd(ymm6, ymm7); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + ymm16 = _mm256_addsub_pd(ymm16, ymm17); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm11 = _mm256_add_pd(ymm11, ymm6); + ymm9 = _mm256_add_pd(ymm9, ymm14); + ymm12 = _mm256_add_pd(ymm12, ymm16); + + _mm256_storeu_pd((double *)tC, ymm8); + xmm0 = _mm256_extractf128_pd(ymm11, 0); + _mm_storeu_pd((double *)(tC + 2), xmm0); + + tC += ldc; + _mm256_storeu_pd((double *)tC, ymm9); + xmm0 = _mm256_extractf128_pd(ymm12, 0); + _mm_storeu_pd((double *)(tC + 2), xmm0); + } + if (n_remainder == 1) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + + // clear scratch registers. + + BLIS_SET_ALL_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. 
+ ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm5 = _mm256_permute_pd(ymm5, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm11 = _mm256_addsub_pd(ymm11, ymm5); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm11, ymm0); + ymm14 = _mm256_mul_pd(ymm11, ymm14); + ymm11 = _mm256_hsub_pd(ymm15, ymm14); + + + BLIS_SET_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. 
+ ymm0 = _mm256_loadu_pd((double const *)tC); + xmm0 = _mm_loadu_pd((double const *)(tC + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + ymm6 = _mm256_fmadd_pd(ymm1, ymm2, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm6 = _mm256_addsub_pd(ymm6, ymm7); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm11 = _mm256_add_pd(ymm11, ymm6); + + _mm256_storeu_pd((double *)tC, ymm8); + xmm0 = _mm256_extractf128_pd(ymm11, 0); + _mm_storeu_pd((double *)(tC + 2), xmm0); + } + } + if ((m_remainder == 2)) + { + m_remainder -= 2; + + tA = A + row_idx * lda; + tA_packed = D_A_pack; + lda_packed = 2; + + { + dcomplex* tA_temp = tA; + + for(k = 0; (k+1) < K; k += 2) + { + ymm0 = _mm256_loadu_pd((double const *) + (tA_temp + 0 * lda)); + ymm2 = _mm256_loadu_pd((double const *) + (tA_temp + 1 * lda)); + + ymm6 = _mm256_permute2f128_pd(ymm0,ymm2,0x20); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm2,0x31); + + _mm256_storeu_pd((double *) + (tA_packed + 0 * lda_packed), + ymm6); + _mm256_storeu_pd((double *) + (tA_packed + 1 * lda_packed), + ymm7); + + tA_temp += 2; + tA_packed += 2 * lda_packed; + } + + for(; k < K; k += 1) + { + tA_packed[0].real = tA_temp[0 * lda].real; + tA_packed[0].imag = tA_temp[0 * lda].imag; + tA_packed[1].real = tA_temp[1 * lda].real; + tA_packed[1].imag = tA_temp[1 * lda].imag; + + tA_temp += 1; + tA_packed += lda_packed; + } + } + + tA_packed = D_A_pack; + row_idx_packed = 0; + lda_packed = 2; + + for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + + + + BLIS_SET_ALL_YMM_REG_ZEROS + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. 
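Note: the packing loop at the start of this m_remainder == 2 block uses the _mm256_permute2f128_pd pair to interleave the two strided leftover vectors of A into a contiguous panel with lda_packed == 2, so the GEMM k-loops that follow can fetch both values with a single 256-bit load; the m_remainder == 1 variant further down degenerates to a plain contiguous copy with lda_packed == 1. The packing is equivalent to this scalar sketch (illustrative loop only, matching the scalar tail that handles an odd K):

    for ( dim_t k = 0; k < K; ++k )
        for ( dim_t j = 0; j < 2; ++j )
            D_A_pack[ k * 2 + j ] = tA[ k + j * lda ];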
+ ymm0 = _mm256_loadu_pd((double const *)tA); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. 
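Note: in the three-column blocks such as this one, a single 256-bit load of the packed A panel (two dcomplex) is reused against three pairs of broadcast B values per k-iteration; ymm8/ymm9/ymm10 collect the b.real products for the three C columns and ymm4/ymm6/ymm14 the matching b.imag products. Ignoring alpha, beta and conjugation, the tile being produced corresponds to the hypothetical reference helper below (not part of the patch; rs_b/cs_b stand for tb_inc_row/tb_inc_col and cs_c for ldc):

    static void zgemm_2x3_ref( dim_t K, const dcomplex* a_pack,
                               const dcomplex* b, dim_t rs_b, dim_t cs_b,
                               dcomplex* c, dim_t cs_c )
    {
        for ( dim_t k = 0; k < K; ++k )
            for ( dim_t j = 0; j < 3; ++j )
                for ( dim_t i = 0; i < 2; ++i )
                {
                    dcomplex ak = a_pack[ k * 2 + i ];
                    dcomplex bk = b[ k * rs_b + j * cs_b ];
                    c[ i + j * cs_c ].real += ak.real * bk.real - ak.imag * bk.imag;
                    c[ i + j * cs_c ].imag += ak.imag * bk.real + ak.real * bk.imag;
                }
    }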
+ ymm0 = _mm256_loadu_pd((double const *)tA); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + ymm14 = _mm256_permute_pd(ymm14, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + ymm10 = _mm256_addsub_pd(ymm10, ymm14); + // alpha, beta multiplication. + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm10, ymm0); + ymm14 = _mm256_mul_pd(ymm10, ymm14); + ymm10 = _mm256_hsub_pd(ymm15, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + BLIS_SET_YMM_REG_ZEROS + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. 
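Note: in the is_beta_non_zero branches the existing C values are read back, multiplied by beta with the same broadcast-real/broadcast-imag trick (ymm2/ymm3 hold beta_cast->real and beta_cast->imag), and folded into the alpha-scaled accumulators with permute_pd + addsub_pd before the final store. Per element that amounts to the sketch below; the helper name is hypothetical.

    /* c_in is the value read back from C, ab the alpha-scaled A*B result
       already sitting in ymm8/ymm9/..., and the return value is stored. */
    static dcomplex zbeta_accum_ref( dcomplex c_in, dcomplex beta, dcomplex ab )
    {
        dcomplex c_out;
        c_out.real = ab.real + c_in.real * beta.real - c_in.imag * beta.imag;
        c_out.imag = ab.imag + c_in.real * beta.imag + c_in.imag * beta.real;
        return c_out;
    }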
+ ymm0 = _mm256_loadu_pd((double const *)tC); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc)); + + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc * 2)); + + ymm18 = _mm256_fmadd_pd(ymm0, ymm2, ymm18); + ymm19 = _mm256_fmadd_pd(ymm0, ymm3, ymm19); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + ymm19 = _mm256_permute_pd(ymm19, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + ymm18 = _mm256_addsub_pd(ymm18, ymm19); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm9 = _mm256_add_pd(ymm9, ymm14); + ymm10 = _mm256_add_pd(ymm10, ymm18); + + _mm256_storeu_pd((double *)tC, ymm8); + + tC += ldc; + + _mm256_storeu_pd((double *)tC, ymm9); + + tC += ldc; + + _mm256_storeu_pd((double *)tC, ymm10); + } + n_remainder = N - col_idx; + if (n_remainder == 2) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + + + // clear scratch registers. + + BLIS_SET_ALL_YMM_REG_ZEROS + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. 
+ ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + BLIS_SET_YMM_REG_ZEROS + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + ymm0 = _mm256_loadu_pd((double const *)tC); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + ymm0 = _mm256_loadu_pd((double const *)(tC + ldc)); + + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm9 = _mm256_add_pd(ymm9, ymm14); + + _mm256_storeu_pd((double *)tC, ymm8); + tC += ldc; + _mm256_storeu_pd((double *)tC, ymm9); + } + if (n_remainder == 1) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + + // clear scratch registers. 
+ + BLIS_SET_ALL_YMM_REG_ZEROS + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matri + // x data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matri + // x data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + + BLIS_SET_YMM_REG_ZEROS + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. 
+ ymm0 = _mm256_loadu_pd((double const *)tC); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + + _mm256_storeu_pd((double *)tC, ymm8); + } + } + if ((m_remainder == 1)) + { + m_remainder -= 1; + __m128d xmm0; + + tA = A + row_idx * lda; + tA_packed = D_A_pack; + lda_packed = 1; + + { + dcomplex* tA_temp = tA; + + for(k = 0; (k+1) < K; k += 2) + { + ymm0 = _mm256_loadu_pd((double const *) + (tA_temp + 0 * lda)); + + xmm0 = _mm256_extractf128_pd(ymm0, 0); + _mm_storeu_pd((double *) + (tA_packed + 0 * lda_packed), + xmm0); + + xmm0 = _mm256_extractf128_pd(ymm0, 1); + _mm_storeu_pd((double *)(tA_packed + 1 + * lda_packed), xmm0); + + tA_temp += 2; + tA_packed += 2 * lda_packed; + } + + for(; k < K; k += 1) + { + tA_packed[0].real = tA_temp[0 * lda].real; + tA_packed[0].imag = tA_temp[0 * lda].imag; + + tA_temp += 1; + tA_packed += lda_packed; + } + } + + tA_packed = D_A_pack; + row_idx_packed = 0; + lda_packed = 1; + + for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + + + BLIS_SET_ALL_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. 
+ // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. 
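Note: throughout the m_remainder == 1 blocks the single leftover dcomplex of A (and of C in the beta path) is read with a 128-bit _mm_loadu_pd and placed into the low lane of a ymm register via _mm256_insertf128_pd, as in the loads that follow; results are written back through _mm256_extractf128_pd plus _mm_storeu_pd. The upper lane carries don't-care data that is never stored, so neither loads nor stores reach past the last valid element. A minimal sketch of the access pattern (illustrative only; acc stands for the final ymm8-style accumulator):

    __m128d lo  = _mm_loadu_pd( (double const*)tA );   /* exactly one dcomplex  */
    __m256d a   = _mm256_setzero_pd();
    a = _mm256_insertf128_pd( a, lo, 0 );              /* data in low lane only */
    __m256d acc = a;                                   /* ... fmadd/addsub work ... */
    _mm_storeu_pd( (double*)tC, _mm256_extractf128_pd( acc, 0 ) );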
+ xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + ymm14 = _mm256_permute_pd(ymm14, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + ymm10 = _mm256_addsub_pd(ymm10, ymm14); + // alpha, beta multiplication. + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm10, ymm0); + ymm14 = _mm256_mul_pd(ymm10, ymm14); + ymm10 = _mm256_hsub_pd(ymm15, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + BLIS_SET_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + xmm0 = _mm_loadu_pd((double const *)(tC)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + xmm0 = _mm_loadu_pd((double const *)(tC + ldc)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + + xmm0 = _mm_loadu_pd((double const *) + (tC + ldc * 2)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm18 = _mm256_fmadd_pd(ymm0, ymm2, ymm18); + ymm19 = _mm256_fmadd_pd(ymm0, ymm3, ymm19); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + ymm19 = _mm256_permute_pd(ymm19, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + ymm18 = _mm256_addsub_pd(ymm18, ymm19); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm9 = _mm256_add_pd(ymm9, ymm14); + ymm10 = _mm256_add_pd(ymm10, ymm18); + + xmm0 = _mm256_extractf128_pd(ymm8, 0); + _mm_storeu_pd((double *)tC, xmm0); + + tC += ldc; + + xmm0 = _mm256_extractf128_pd(ymm9, 0); + _mm_storeu_pd((double *)tC, xmm0); + + tC += ldc; + xmm0 = _mm256_extractf128_pd(ymm10, 0); + _mm_storeu_pd((double *)tC, xmm0); + } + n_remainder = N - col_idx; + if (n_remainder == 2) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + + // clear scratch registers. 
+ + BLIS_SET_ALL_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. 
+ xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + + BLIS_SET_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + xmm0 = _mm_loadu_pd((double const *)(tC)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + xmm0 = _mm_loadu_pd((double const *)(tC + ldc)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm9 = _mm256_add_pd(ymm9, ymm14); + + xmm0 = _mm256_extractf128_pd(ymm8, 0); + _mm_storeu_pd((double *)tC, xmm0); + tC += ldc; + xmm0 = _mm256_extractf128_pd(ymm9, 0); + _mm_storeu_pd((double *)tC, xmm0); + } + if (n_remainder == 1) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + + // clear scratch registers. + + BLIS_SET_ALL_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. 
+ ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matri + // x data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + + BLIS_SET_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + xmm0 = _mm_loadu_pd((double const *)(tC)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + + xmm0 = _mm256_extractf128_pd(ymm8, 0); + _mm_storeu_pd((double *)tC, xmm0); + + } + } + // Return the buffer to pool + if ((required_packing_A == 1) && bli_mem_is_alloc( &local_mem_buf_A_s )){ +#ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_dgemm_small_At(): releasing mem pool block\n" ); +#endif + bli_membrk_release(&rntm, + &local_mem_buf_A_s); + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return BLIS_SUCCESS; + } + else + { + AOCL_DTL_TRACE_EXIT_ERR( + AOCL_DTL_LEVEL_INFO, + "Invalid dimesions for dgemm_small_At." 
+ ); + return BLIS_NONCONFORMAL_DIMENSIONS; + } +}; +#endif From 015bcb88d4d2db12d2e98dc7db8d6e3a7c25ee1d Mon Sep 17 00:00:00 2001 From: Harsh Dave Date: Wed, 30 Mar 2022 07:16:24 -0500 Subject: [PATCH 096/243] Fixed ztrsm computational failure - Fixed memory access for edge cases such that all load are within memory boundary only. - Corrected ztrsm utility APIs for dcomplex multiplication and division. AMD-Internal: [CPUPL-2093] Change-Id: Ib2c65e7921f6391b530cd20d6ea6b50f24bd705e --- kernels/zen/3/bli_trsm_small.c | 771 ++++++++++++++++++++++++--------- 1 file changed, 567 insertions(+), 204 deletions(-) diff --git a/kernels/zen/3/bli_trsm_small.c b/kernels/zen/3/bli_trsm_small.c index 0fa8f66d5a..32b7647a50 100644 --- a/kernels/zen/3/bli_trsm_small.c +++ b/kernels/zen/3/bli_trsm_small.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2018-2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -3891,33 +3891,20 @@ err_t bli_trsm_small */ #define DCOMPLEX_INV(a, b) {\ - a.real = b.real;\ - a.imag = (b.imag * -1.0);\ - /*Compute denominator eliminating imaginary component*/\ - double dnm = (b.real * b.real);\ - /*multiply two times with -1 for correct result as - * dcomplex number with positive imaginary part will - * invert the sign if not multiplied twice with -1*/\ - dnm += ((-1.0 * (b.imag * b.imag)) * -1.0);\ - /*Compute the final result by dividing real and imag part by dnm*/\ - a.real /= dnm;\ - a.imag /= dnm;\ +/* dcomplex inva = {1.0, 0.0};*/\ + a.real = 1.0;\ + a.imag = 0.0;\ + bli_zinvscals(b, a);\ } #define DCOMPLEX_MUL(a, b, c) {\ - double real = a.real * b.real;\ - real += ((a.imag * b.imag) * -1.0);\ - double imag = (a.real * b.imag);\ - imag += (a.imag * b.real);\ - c.real = real;\ - c.imag = imag;\ + c.real = b.real;\ + c.imag = b.imag;\ + bli_zscals(a,c);\ } #define DCOMPLEX_DIV(a, b){\ - double dnm = b.real * b.real;\ - dnm += (-1.0 * (b.imag * (b.imag * -1.0) ));\ - a.real /= dnm;\ - a.imag /= dnm;\ + bli_zinvscals(b,a); \ } @@ -3946,11 +3933,8 @@ err_t bli_trsm_small #define ZTRSM_DIAG_ELE_EVAL_OPS(a,b,c){\ if(!is_unitdiag)\ {\ - a.real = b.real;\ - a.imag = (b.imag * -1.0);\ - DCOMPLEX_MUL(c, a, c)\ - DCOMPLEX_DIV(c, b)\ - }\ + bli_zinvscals(b, c);\ + }\ } #endif @@ -4299,6 +4283,213 @@ BLIS_INLINE err_t ztrsm_AuXB_ref _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm9);\ } + +#define BLIS_ZTRSM_SMALL_GEMM_3mx3n(a10,b01,cs_b,p_lda,k_iter) {\ + double *tptr = (double *)b01;\ + if(conjtransa) {\ + ymm16 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ + for(k = 0; k< k_iter; k++) \ + { \ + ymm0 = _mm256_loadu_pd((double const *)(a10)); \ + xmm4 = _mm_loadu_pd((double const *)(a10 + 2));\ + ymm1 = _mm256_insertf128_pd(ymm1, xmm4, 0); \ + ymm0 = _mm256_mul_pd(ymm0, ymm16);\ + ymm1 = _mm256_mul_pd(ymm1, ymm16);\ + \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr)); \ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + 1)); \ + \ + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8);\ + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11);\ + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);\ + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);\ + \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 1 * 2)); \ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 1 * 2 + 1)); \ + \ + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, 
ymm9);\ + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12);\ + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6);\ + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);\ + \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b *2 * 2)); \ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 2 + 1)); \ + \ + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10);\ + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13);\ + \ + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14);\ + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15);\ + \ + tptr += 2; \ + a10 += p_lda; \ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) \ + { \ + ymm0 = _mm256_loadu_pd((double const *)(a10)); \ + xmm4 = _mm_loadu_pd((double const *)(a10 + 2));\ + ymm1 = _mm256_insertf128_pd(ymm1, xmm4, 0); \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr)); \ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + 1)); \ + \ + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8);\ + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11);\ + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);\ + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);\ + \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 1 * 2)); \ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 1 * 2 + 1)); \ + \ + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9);\ + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12);\ + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6);\ + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);\ + \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b *2 * 2)); \ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 2 + 1)); \ + \ + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10);\ + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13);\ + \ + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14);\ + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15);\ + \ + tptr += 2; \ + a10 += p_lda; \ + }\ + }\ + ymm4 = _mm256_permute_pd(ymm4, 0x5);\ + ymm5 = _mm256_permute_pd(ymm5, 0x5);\ + ymm6 = _mm256_permute_pd(ymm6, 0x5);\ + ymm7 = _mm256_permute_pd(ymm7, 0x5);\ + ymm14 = _mm256_permute_pd(ymm14, 0x5);\ + ymm15 = _mm256_permute_pd(ymm15, 0x5);\ + \ + ymm8 = _mm256_addsub_pd(ymm8, ymm4);\ + ymm11 = _mm256_addsub_pd(ymm11, ymm5);\ + ymm9 = _mm256_addsub_pd(ymm9, ymm6);\ + ymm12 = _mm256_addsub_pd(ymm12, ymm7);\ + ymm10 = _mm256_addsub_pd(ymm10, ymm14);\ + ymm13 = _mm256_addsub_pd(ymm13, ymm15);\ +} + + +#define BLIS_ZTRSM_SMALL_GEMM_3mx2n(a10,b01,cs_b,p_lda,k_iter) {\ + double *tptr = (double * )b01;\ + if(conjtransa) {\ + ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_pd((double const *)(a10));\ + xmm4 = _mm_loadu_pd((double const *)(a10 + 2));\ + ymm1 = _mm256_insertf128_pd(ymm1, xmm4, 0); \ + ymm0 = _mm256_mul_pd(ymm0, ymm18);\ + ymm1 = _mm256_mul_pd(ymm1, ymm18);\ + \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0));\ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0 + 1)); \ + \ + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8);\ + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12);\ + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);\ + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1)); \ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1 + 1)); \ + \ + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9);\ + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13);\ + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6);\ + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);\ + tptr += 2; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + }\ + }\ + else {\ + for(k = 0; k< 
k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_pd((double const *)(a10));\ + xmm4 = _mm_loadu_pd((double const *)(a10 + 2));\ + ymm1 = _mm256_insertf128_pd(ymm1, xmm4, 0); \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0));\ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0 + 1)); \ + \ + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8);\ + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12);\ + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);\ + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1)); \ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1 + 1)); \ + \ + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9);\ + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13);\ + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6);\ + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);\ + tptr += 2; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + }\ + }\ + ymm4 = _mm256_permute_pd(ymm4, 0x5);\ + ymm5 = _mm256_permute_pd(ymm5, 0x5);\ + ymm6 = _mm256_permute_pd(ymm6, 0x5);\ + ymm7 = _mm256_permute_pd(ymm7, 0x5);\ +\ + ymm8 = _mm256_addsub_pd(ymm8, ymm4);\ + ymm12 = _mm256_addsub_pd(ymm12, ymm5);\ + ymm9 = _mm256_addsub_pd(ymm9, ymm6);\ + ymm13 = _mm256_addsub_pd(ymm13, ymm7);\ +} + +#define BLIS_ZTRSM_SMALL_GEMM_3mx1n(a10,b01,cs_b,p_lda,k_iter) {\ + double *tptr = (double *)b01;\ + if(conjtransa) {\ + ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_pd((double const *)(a10));\ + xmm4 = _mm_loadu_pd((double const *)(a10 + 2));\ + ymm1 = _mm256_insertf128_pd(ymm1, xmm4, 0); \ + ymm0 = _mm256_mul_pd(ymm0, ymm18);\ + ymm1 = _mm256_mul_pd(ymm1, ymm18);\ + \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0));\ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0 + 1)); \ + \ + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8);\ + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12);\ + \ + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);\ + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);\ + tptr += 2; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_pd((double const *)(a10));\ + xmm4 = _mm_loadu_pd((double const *)(a10 + 2));\ + ymm1 = _mm256_insertf128_pd(ymm1, xmm4, 0); \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0));\ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0 + 1)); \ + \ + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8);\ + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12);\ + \ + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);\ + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);\ + tptr += 2; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + }\ + }\ + ymm4 = _mm256_permute_pd(ymm4, 0x5);\ + ymm5 = _mm256_permute_pd(ymm5, 0x5);\ + ymm8 = _mm256_addsub_pd(ymm8, ymm4);\ + ymm12 = _mm256_addsub_pd(ymm12, ymm5);\ +} + + /** * Performs GEMM operation. 
* Two elements of column in ymm0 @@ -31943,75 +32134,160 @@ BLIS_INLINE err_t bli_ztrsm_small_AutXB_AlXB if(m_rem == 3) { dim_t p_lda = 4; - if(transa) - { - for(dim_t x = 0; x < i; x += p_lda) - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm10 = _mm256_loadu_pd((double const *) - (a10 + 2)); - ymm1 = _mm256_loadu_pd((double const *) - (a10 + cs_a)); - ymm11 = _mm256_loadu_pd((double const *) - (a10 + 2 + cs_a)); + if(transa) + { + dim_t x = 0; + for(x = 0; (x+3) < i; x += p_lda) + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm10 = _mm256_loadu_pd((double const *) + (a10 + 2)); + ymm1 = _mm256_loadu_pd((double const *) + (a10 + cs_a)); + ymm11 = _mm256_loadu_pd((double const *) + (a10 + 2 + cs_a)); + + ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + ymm8 = _mm256_permute2f128_pd(ymm10,ymm11,0x20); + ymm9 = _mm256_permute2f128_pd(ymm10,ymm11,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + + p_lda*3), ymm9); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + + 2 * cs_a)); + ymm10 = _mm256_loadu_pd((double const *)(a10 + + 2 * cs_a + 2)); + ymm1 = _mm256_set_pd(1, 1, 1, 1); + + ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + ymm8 = _mm256_permute2f128_pd(ymm10,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm10,ymm1,0x31); + + + _mm256_storeu_pd((double *)(ptr_a10_dup + 2), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + + p_lda + 2), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + + p_lda*2 + 2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + + p_lda*3 + 2), ymm9); + + a10 += p_lda; + ptr_a10_dup += p_lda * p_lda; + } + for(; (x+2) < i; x += 3) + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + xmm4 = _mm_loadu_pd((double const *) + (a10 + 2)); + ymm10 = _mm256_insertf128_pd(ymm10, xmm4, 0); + ymm1 = _mm256_loadu_pd((double const *) + (a10 + cs_a)); + xmm4 = _mm_loadu_pd((double const *) + (a10 + 2 + cs_a)); + ymm11 = _mm256_insertf128_pd(ymm11, xmm4, 0); + + ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + ymm8 = _mm256_permute2f128_pd(ymm10,ymm11,0x20); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + + p_lda*2), ymm8); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + + 2 * cs_a)); + xmm4 = _mm_loadu_pd((double const *)(a10 + + 2 * cs_a + 2)); + ymm10 = _mm256_insertf128_pd(ymm10, xmm4, 0); + ymm1 = _mm256_set_pd(1, 1, 1, 1); + + ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + ymm8 = _mm256_permute2f128_pd(ymm10,ymm1,0x20); + + + _mm256_storeu_pd((double *)(ptr_a10_dup + 2), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + + p_lda + 2), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + + p_lda*2 + 2), ymm8); + + a10 += 3; + ptr_a10_dup += p_lda * p_lda; + } + for(; (x+1) < i; x += 2) + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *) + (a10 + cs_a)); - ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - ymm8 = _mm256_permute2f128_pd(ymm10,ymm11,0x20); - ymm9 = _mm256_permute2f128_pd(ymm10,ymm11,0x31); + ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm7 = 
_mm256_permute2f128_pd(ymm0,ymm1,0x31); - _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); - _mm256_storeu_pd((double *)(ptr_a10_dup + - p_lda), ymm7); - _mm256_storeu_pd((double *)(ptr_a10_dup + - p_lda*2), ymm8); - _mm256_storeu_pd((double *)(ptr_a10_dup + - p_lda*3), ymm9); + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + + p_lda), ymm7); - ymm0 = _mm256_loadu_pd((double const *)(a10 - + 2 * cs_a)); - ymm10 = _mm256_loadu_pd((double const *)(a10 - + 2 * cs_a + 2)); + ymm0 = _mm256_loadu_pd((double const *)(a10 + + 2 * cs_a)); + ymm1 = _mm256_set_pd(1, 1, 1, 1); - ymm1 = _mm256_loadu_pd((double const *)(a10 - + 3 * cs_a)); - ymm11 = _mm256_loadu_pd((double const *)(a10 - + 3 * cs_a + 2)); + ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - ymm8 = _mm256_permute2f128_pd(ymm10,ymm11,0x20); - ymm9 = _mm256_permute2f128_pd(ymm10,ymm11,0x31); - _mm256_storeu_pd((double *)(ptr_a10_dup + 2), - ymm6); - _mm256_storeu_pd((double *)(ptr_a10_dup + - p_lda + 2), ymm7); - _mm256_storeu_pd((double *)(ptr_a10_dup + - p_lda*2 + 2), ymm8); - _mm256_storeu_pd((double *)(ptr_a10_dup + - p_lda*3 + 2), ymm9); + _mm256_storeu_pd((double *)(ptr_a10_dup + 2), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + + p_lda + 2), ymm7); - a10 += p_lda; - ptr_a10_dup += p_lda * p_lda; - } + a10 += 2; + ptr_a10_dup += p_lda * p_lda; + } + for(; x < i; x += 1) + { + xmm4 = _mm_loadu_pd((double const *)(a10)); + xmm5 = _mm_loadu_pd((double const *) + (a10 + cs_a)); - } - else - { - for(dim_t x=0;x 0; j -= d_nr) { @@ -33429,37 +33791,38 @@ BLIS_INLINE err_t bli_ztrsm_small_AltXB_AuXB } else if(m_remainder == 1) { - dim_t p_lda = 2; // packed leading dimension - if(transa) - { - for(dim_t x = 0; x < m-m_remainder; x += p_lda) - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm1 = _mm256_loadu_pd((double const *) - (a10 + cs_a)); - - ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - - _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); - _mm256_storeu_pd((double *)(ptr_a10_dup + - p_lda), ymm7); - - a10 += p_lda; - ptr_a10_dup += p_lda * p_lda; - } - - } - else - { - for(dim_t x=0;x 0; j -= d_nr) { From 9621ef3067d0fb3937e1f6e667999cbdc04b6095 Mon Sep 17 00:00:00 2001 From: Sireesha Sanga Date: Mon, 4 Apr 2022 16:08:18 +0530 Subject: [PATCH 097/243] Performance Improvement for ztrsm small sizes Details: - Enable ztrsm small implementation - For small sizes, Right Variants and Left Unit Diag Variants are using ztrsm_small implementations. - Optimization of Left Non-Unit Diagonal Variants, Work In Progress AMD-Internal: [SWLCSG-1194] Change-Id: Ib3cce6e2e4ac0817ccd4dff4bb0fa4a23e231ca4 --- frame/compat/bla_trsm_amd.c | 11 ++++++----- frame/include/bli_gentfunc_macro_defs.h | 6 +++++- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/frame/compat/bla_trsm_amd.c b/frame/compat/bla_trsm_amd.c index 21b2a1598d..eb5c835ff5 100644 --- a/frame/compat/bla_trsm_amd.c +++ b/frame/compat/bla_trsm_amd.c @@ -902,7 +902,6 @@ void dtrsm_ /* Finalize BLIS. 
*/ bli_finalize_auto(); } -#if 0 void ztrsm_ ( const f77_char* side, @@ -1184,8 +1183,10 @@ void ztrsm_ * In case of multithread when [m,n]<=128 sinlge thread implemenation * is doing better than native multithread */ bool nt = bli_thread_get_is_parallel(); - if((nt==0 && m0<=500 && n0<=500) || - (nt && (m0+n0)<128) ) + + if((blis_side == BLIS_RIGHT) || (blis_diaga == BLIS_UNIT_DIAG)) { + if(((nt==0) && (m0<=500) && (n0<=500)) || + (nt && ((m0+n0)<128))) { err_t status; status = bli_trsm_small @@ -1205,6 +1206,7 @@ void ztrsm_ return; } } + } #endif bli_trsmnat @@ -1221,7 +1223,6 @@ void ztrsm_ /* Finalize BLIS. */ bli_finalize_auto(); } -#endif #if 0 void ctrsm_ ( @@ -1539,6 +1540,6 @@ void ctrsm_ bli_finalize_auto(); } #endif -INSERT_GENTFUNC_BLAS_CZ( trsm, trsm ) +INSERT_GENTFUNC_BLAS_C( trsm, trsm ) #endif diff --git a/frame/include/bli_gentfunc_macro_defs.h b/frame/include/bli_gentfunc_macro_defs.h index 1bac7aa7c4..49c79cb8ae 100644 --- a/frame/include/bli_gentfunc_macro_defs.h +++ b/frame/include/bli_gentfunc_macro_defs.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 21, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 22, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -67,6 +67,10 @@ GENTFUNC( scomplex, c, blasname, blisname ) GENTFUNC( scomplex, c, blasname, blisname ) \ GENTFUNC( dcomplex, z, blasname, blisname ) +#define INSERT_GENTFUNC_BLAS_C( blasname, blisname ) \ +\ +GENTFUNC( scomplex, c, blasname, blisname ) + // -- Basic one-operand macro with real domain only -- From fe7f0a9085856152872cd7e1d10a83b43084118f Mon Sep 17 00:00:00 2001 From: satish kumar nuggu Date: Fri, 1 Apr 2022 09:15:14 +0530 Subject: [PATCH 098/243] Changes to enable zgemm small from BLAS Layer 1. Removed small gemm call from native path to avoid Single threaded calls as a part of MultiThreaded scenarios. 2. SUP and INDUCED Method path disabled. 3. Added AOCL Dynamic for optimum number of threads to achieve higher performance. Change-Id: I3c41641bef4906bdbdb5f05e67c0f61e86025d92 --- frame/3/gemm/bli_gemm_front.c | 16 - frame/3/gemm/bli_gemm_front_amd.c | 26 +- frame/base/bli_rntm.c | 16 + frame/compat/bla_gemm_amd.c | 1254 ++- kernels/zen/3/bli_gemm_small.c | 15357 ++++++++++++++-------------- kernels/zen/bli_kernels_zen.h | 22 + 6 files changed, 8354 insertions(+), 8337 deletions(-) diff --git a/frame/3/gemm/bli_gemm_front.c b/frame/3/gemm/bli_gemm_front.c index 46e163c026..a9bada995d 100644 --- a/frame/3/gemm/bli_gemm_front.c +++ b/frame/3/gemm/bli_gemm_front.c @@ -74,22 +74,6 @@ void bli_gemm_front return; } -#ifdef BLIS_ENABLE_SMALL_MATRIX - // Only handle small problems separately for homogeneous datatypes. - if ( bli_obj_dt( a ) == bli_obj_dt( b ) && - bli_obj_dt( a ) == bli_obj_dt( c ) && - bli_obj_comp_prec( c ) == bli_obj_prec( c ) ) - { - err_t status = bli_gemm_small( alpha, a, b, beta, c, cntx, cntl ); - - if ( status == BLIS_SUCCESS ) - { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); - return; - } - } -#endif - // Alias A, B, and C in case we need to apply transformations. 
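Note: taken together, the dispatch changes in this and the surrounding hunks keep the small-code paths within the cases they currently handle well. In bla_trsm_amd.c above, ztrsm_ now routes to bli_trsm_small only for right-side variants and unit-diagonal left variants, and only below the existing size thresholds; in bli_gemm_front_amd.c, dcomplex GEMM asks bli_nthreads_optimum() for a thread count when AOCL_DYNAMIC is enabled, using the heuristic added to bli_rntm.c just below. Compact restatements of the two predicates follow (sketches only; both helper names are hypothetical):

    /* ztrsm small-path gate, mirroring the condition added above. */
    static bool ztrsm_use_small_path( side_t side, diag_t diag,
                                      dim_t m0, dim_t n0, bool nt )
    {
        if ( side != BLIS_RIGHT && diag != BLIS_UNIT_DIAG ) return false;
        if ( !nt ) return ( m0 <= 500 ) && ( n0 <= 500 );  /* single-threaded */
        return ( m0 + n0 ) < 128;                          /* multithreaded   */
    }

    /* "Ideal" thread count for dcomplex GEMM, mirroring the conditions added
       to bli_nthreads_optimum(); the caller reconciles this value with the
       user-requested thread count. */
    static dim_t zgemm_ideal_nt( dim_t m, dim_t n, dim_t k, dim_t nt_requested )
    {
        if ( ( m <= 128 || n <= 128 || k <= 128 ) && ( m + n + k <= 400 ) )
            return 8;
        if ( ( m <= 256 || n <= 256 || k <= 256 ) && ( m + n + k <= 800 ) )
            return 16;
        return nt_requested;
    }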
bli_obj_alias_to( a, &a_local ); bli_obj_alias_to( b, &b_local ); diff --git a/frame/3/gemm/bli_gemm_front_amd.c b/frame/3/gemm/bli_gemm_front_amd.c index a29a0bb85b..34b41f0568 100644 --- a/frame/3/gemm/bli_gemm_front_amd.c +++ b/frame/3/gemm/bli_gemm_front_amd.c @@ -50,6 +50,16 @@ void bli_gemm_front AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_3); bli_init_once(); + #ifdef AOCL_DYNAMIC + // If dynamic-threading is enabled, calculate optimum number + // of threads. + // rntm will be updated with optimum number of threads. + if( bli_obj_is_dcomplex(c))// This will enable for ZGEMM + { + bli_nthreads_optimum(a, b, c, BLIS_GEMM, rntm); + } + #endif + obj_t a_local; obj_t b_local; obj_t c_local; @@ -74,22 +84,6 @@ void bli_gemm_front return; } -#ifdef BLIS_ENABLE_SMALL_MATRIX - // Only handle small problems separately for homogeneous datatypes. - if ( bli_obj_dt( a ) == bli_obj_dt( b ) && - bli_obj_dt( a ) == bli_obj_dt( c ) && - bli_obj_comp_prec( c ) == bli_obj_prec( c ) ) - { - err_t status = bli_gemm_small( alpha, a, b, beta, c, cntx, cntl ); - - if ( status == BLIS_SUCCESS ) - { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); - return; - } - } -#endif - // Alias A, B, and C in case we need to apply transformations. bli_obj_alias_to( a, &a_local ); bli_obj_alias_to( b, &b_local ); diff --git a/frame/base/bli_rntm.c b/frame/base/bli_rntm.c index 7176dacc4e..c597074f58 100644 --- a/frame/base/bli_rntm.c +++ b/frame/base/bli_rntm.c @@ -600,6 +600,22 @@ void bli_nthreads_optimum( } } + else if( family == BLIS_GEMM && bli_obj_is_dcomplex(c)) + { + + dim_t m = bli_obj_length(c); + dim_t n = bli_obj_width(c); + dim_t k = bli_obj_width_after_trans(a); + + if((m<=128 || n<=128 || k<=128) && (m+n+k <= 400) ) + { + n_threads_ideal = 8; + } + else if((m<=256 || n<=256 || k<=256) && (m+n+k <= 800) ) + { + n_threads_ideal = 16; + } + } else if( family == BLIS_SYRK && bli_obj_is_double(c)) { dim_t n = bli_obj_length(c); diff --git a/frame/compat/bla_gemm_amd.c b/frame/compat/bla_gemm_amd.c index ff995b5f07..7060509de2 100644 --- a/frame/compat/bla_gemm_amd.c +++ b/frame/compat/bla_gemm_amd.c @@ -55,76 +55,76 @@ void PASTEF77(ch,blasname) \ const ftype* a, const f77_int* lda, \ const ftype* b, const f77_int* ldb, \ const ftype* beta, \ - ftype* c, const f77_int* ldc \ + ftype* c, const f77_int* ldc \ ) \ { \ - trans_t blis_transa; \ - trans_t blis_transb; \ - dim_t m0, n0, k0; \ - inc_t rs_a, cs_a; \ - inc_t rs_b, cs_b; \ - inc_t rs_c, cs_c; \ + trans_t blis_transa; \ + trans_t blis_transb; \ + dim_t m0, n0, k0; \ + inc_t rs_a, cs_a; \ + inc_t rs_b, cs_b; \ + inc_t rs_c, cs_c; \ \ - /* Initialize BLIS. */ \ - bli_init_auto(); \ + /* Initialize BLIS. */ \ + bli_init_auto(); \ \ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); \ - AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(ch), *transa, *transb, *m, *n, *k, \ - (void*)alpha, *lda, *ldb, (void*)beta, *ldc); \ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); \ + AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(ch), *transa, *transb, *m, *n, *k, \ + (void*)alpha, *lda, *ldb, (void*)beta, *ldc); \ \ - /* Perform BLAS parameter checking. */ \ - PASTEBLACHK(blasname) \ - ( \ - MKSTR(ch), \ - MKSTR(blasname), \ - transa, \ - transb, \ - m, \ - n, \ - k, \ - lda, \ - ldb, \ - ldc \ - ); \ + /* Perform BLAS parameter checking. */ \ + PASTEBLACHK(blasname) \ + ( \ + MKSTR(ch), \ + MKSTR(blasname), \ + transa, \ + transb, \ + m, \ + n, \ + k, \ + lda, \ + ldb, \ + ldc \ + ); \ \ - /* Map BLAS chars to their corresponding BLIS enumerated type value. 
*/ \ - bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ - bli_param_map_netlib_to_blis_trans( *transb, &blis_transb ); \ + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ + bli_param_map_netlib_to_blis_trans( *transb, &blis_transb ); \ \ - /* Typecast BLAS integers to BLIS integers. */ \ - bli_convert_blas_dim1( *m, m0 ); \ - bli_convert_blas_dim1( *n, n0 ); \ - bli_convert_blas_dim1( *k, k0 ); \ + /* Typecast BLAS integers to BLIS integers. */ \ + bli_convert_blas_dim1( *m, m0 ); \ + bli_convert_blas_dim1( *n, n0 ); \ + bli_convert_blas_dim1( *k, k0 ); \ \ - /* Set the row and column strides of the matrix operands. */ \ - rs_a = 1; \ - cs_a = *lda; \ - rs_b = 1; \ - cs_b = *ldb; \ - rs_c = 1; \ - cs_c = *ldc; \ + /* Set the row and column strides of the matrix operands. */ \ + rs_a = 1; \ + cs_a = *lda; \ + rs_b = 1; \ + cs_b = *ldb; \ + rs_c = 1; \ + cs_c = *ldc; \ \ - /* Call BLIS interface. */ \ - PASTEMAC2(ch,blisname,BLIS_TAPI_EX_SUF) \ - ( \ - blis_transa, \ - blis_transb, \ - m0, \ - n0, \ - k0, \ - (ftype*)alpha, \ - (ftype*)a, rs_a, cs_a, \ - (ftype*)b, rs_b, cs_b, \ - (ftype*)beta, \ - (ftype*)c, rs_c, cs_c, \ - NULL, \ - NULL \ - ); \ + /* Call BLIS interface. */ \ + PASTEMAC2(ch,blisname,BLIS_TAPI_EX_SUF) \ + ( \ + blis_transa, \ + blis_transb, \ + m0, \ + n0, \ + k0, \ + (ftype*)alpha, \ + (ftype*)a, rs_a, cs_a, \ + (ftype*)b, rs_b, cs_b, \ + (ftype*)beta, \ + (ftype*)c, rs_c, cs_c, \ + NULL, \ + NULL \ + ); \ \ - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ - /* Finalize BLIS. */ \ - bli_finalize_auto(); \ + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ } #else @@ -143,175 +143,175 @@ void PASTEF77(ch,blasname) \ const ftype* a, const f77_int* lda, \ const ftype* b, const f77_int* ldb, \ const ftype* beta, \ - ftype* c, const f77_int* ldc \ + ftype* c, const f77_int* ldc \ ) \ { \ \ - trans_t blis_transa; \ - trans_t blis_transb; \ - dim_t m0, n0, k0; \ + trans_t blis_transa; \ + trans_t blis_transb; \ + dim_t m0, n0, k0; \ \ - dim_t m0_a, n0_a; \ - dim_t m0_b, n0_b; \ + dim_t m0_a, n0_a; \ + dim_t m0_b, n0_b; \ \ - /* Initialize BLIS. */ \ - bli_init_auto(); \ + /* Initialize BLIS. */ \ + bli_init_auto(); \ \ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); \ - AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(ch), *transa, *transb, *m, *n, *k, \ - (void*)alpha, *lda, *ldb, (void*)beta, *ldc); \ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); \ + AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(ch), *transa, *transb, *m, *n, *k, \ + (void*)alpha, *lda, *ldb, (void*)beta, *ldc); \ \ - /* Perform BLAS parameter checking. */ \ - PASTEBLACHK(blasname) \ - ( \ - MKSTR(ch), \ - MKSTR(blasname), \ - transa, \ - transb, \ - m, \ - n, \ - k, \ - lda, \ - ldb, \ - ldc \ - ); \ + /* Perform BLAS parameter checking. */ \ + PASTEBLACHK(blasname) \ + ( \ + MKSTR(ch), \ + MKSTR(blasname), \ + transa, \ + transb, \ + m, \ + n, \ + k, \ + lda, \ + ldb, \ + ldc \ + ); \ \ - /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ - bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ - bli_param_map_netlib_to_blis_trans( *transb, &blis_transb ); \ + /* Map BLAS chars to their corresponding BLIS enumerated type value. 
*/ \ + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ + bli_param_map_netlib_to_blis_trans( *transb, &blis_transb ); \ \ - /* Typecast BLAS integers to BLIS integers. */ \ - bli_convert_blas_dim1( *m, m0 ); \ - bli_convert_blas_dim1( *n, n0 ); \ - bli_convert_blas_dim1( *k, k0 ); \ + /* Typecast BLAS integers to BLIS integers. */ \ + bli_convert_blas_dim1( *m, m0 ); \ + bli_convert_blas_dim1( *n, n0 ); \ + bli_convert_blas_dim1( *k, k0 ); \ \ - /* Set the row and column strides of the matrix operands. */ \ - const inc_t rs_a = 1; \ - const inc_t cs_a = *lda; \ - const inc_t rs_b = 1; \ - const inc_t cs_b = *ldb; \ - const inc_t rs_c = 1; \ - const inc_t cs_c = *ldc; \ + /* Set the row and column strides of the matrix operands. */ \ + const inc_t rs_a = 1; \ + const inc_t cs_a = *lda; \ + const inc_t rs_b = 1; \ + const inc_t cs_b = *ldb; \ + const inc_t rs_c = 1; \ + const inc_t cs_c = *ldc; \ \ - if( n0 == 1 ) \ - { \ - if(bli_is_notrans(blis_transa)) \ - { \ - PASTEMAC(ch,gemv_unf_var2)( \ - BLIS_NO_TRANSPOSE, \ - bli_extract_conj(blis_transb), \ - m0, k0, \ - (ftype*)alpha, \ - (ftype*)a, rs_a, cs_a,\ - (ftype*)b, bli_is_notrans(blis_transb)?rs_b:cs_b, \ - (ftype*) beta, \ - c, rs_c, \ - NULL \ - ); \ - } \ - else \ - { \ - PASTEMAC(ch,gemv_unf_var1)( \ - blis_transa, \ - bli_extract_conj(blis_transb), \ - k0, m0, \ - (ftype*)alpha, \ - (ftype*)a, rs_a, cs_a, \ - (ftype*)b, bli_is_notrans(blis_transb)?rs_b:cs_b, \ - (ftype*)beta, \ - c, rs_c, \ - NULL \ - ); \ - } \ - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); \ - return; \ - } \ - else if( m0 == 1 ) \ - { \ - if(bli_is_notrans(blis_transb)) \ - { \ - PASTEMAC(ch,gemv_unf_var1)( \ - blis_transb, \ - bli_extract_conj(blis_transa), \ - n0, k0, \ - (ftype*)alpha, \ - (ftype*)b, cs_b, rs_b, \ - (ftype*)a, bli_is_notrans(blis_transa)?cs_a:rs_a, \ - (ftype*)beta, \ - c, cs_c, \ - NULL \ - ); \ - } \ - else \ - { \ - PASTEMAC(ch,gemv_unf_var2)( \ - blis_transb, \ - bli_extract_conj(blis_transa), \ - k0, n0, \ - (ftype*)alpha, \ - (ftype*)b, cs_b, rs_b, \ - (ftype*)a, bli_is_notrans(blis_transa)?cs_a:rs_a, \ - (ftype*)beta, \ - c, cs_c, \ - NULL \ - ); \ - } \ - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); \ - return; \ - } \ + if( n0 == 1 ) \ + { \ + if(bli_is_notrans(blis_transa)) \ + { \ + PASTEMAC(ch,gemv_unf_var2)( \ + BLIS_NO_TRANSPOSE, \ + bli_extract_conj(blis_transb), \ + m0, k0, \ + (ftype*)alpha, \ + (ftype*)a, rs_a, cs_a,\ + (ftype*)b, bli_is_notrans(blis_transb)?rs_b:cs_b, \ + (ftype*) beta, \ + c, rs_c, \ + NULL \ + ); \ + } \ + else \ + { \ + PASTEMAC(ch,gemv_unf_var1)( \ + blis_transa, \ + bli_extract_conj(blis_transb), \ + k0, m0, \ + (ftype*)alpha, \ + (ftype*)a, rs_a, cs_a, \ + (ftype*)b, bli_is_notrans(blis_transb)?rs_b:cs_b, \ + (ftype*)beta, \ + c, rs_c, \ + NULL \ + ); \ + } \ + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); \ + return; \ + } \ + else if( m0 == 1 ) \ + { \ + if(bli_is_notrans(blis_transb)) \ + { \ + PASTEMAC(ch,gemv_unf_var1)( \ + blis_transb, \ + bli_extract_conj(blis_transa), \ + n0, k0, \ + (ftype*)alpha, \ + (ftype*)b, cs_b, rs_b, \ + (ftype*)a, bli_is_notrans(blis_transa)?cs_a:rs_a, \ + (ftype*)beta, \ + c, cs_c, \ + NULL \ + ); \ + } \ + else \ + { \ + PASTEMAC(ch,gemv_unf_var2)( \ + blis_transb, \ + bli_extract_conj(blis_transa), \ + k0, n0, \ + (ftype*)alpha, \ + (ftype*)b, cs_b, rs_b, \ + (ftype*)a, bli_is_notrans(blis_transa)?cs_a:rs_a, \ + (ftype*)beta, \ + c, cs_c, \ + NULL \ + ); \ + } \ + 
AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); \ + return; \ + } \ \ - const num_t dt = PASTEMAC(ch,type); \ + const num_t dt = PASTEMAC(ch,type); \ \ - obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ - obj_t ao = BLIS_OBJECT_INITIALIZER; \ - obj_t bo = BLIS_OBJECT_INITIALIZER; \ - obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; \ - obj_t co = BLIS_OBJECT_INITIALIZER; \ + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t ao = BLIS_OBJECT_INITIALIZER; \ + obj_t bo = BLIS_OBJECT_INITIALIZER; \ + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t co = BLIS_OBJECT_INITIALIZER; \ \ - bli_set_dims_with_trans( blis_transa, m0, k0, &m0_a, &n0_a ); \ - bli_set_dims_with_trans( blis_transb, k0, n0, &m0_b, &n0_b ); \ + bli_set_dims_with_trans( blis_transa, m0, k0, &m0_a, &n0_a ); \ + bli_set_dims_with_trans( blis_transb, k0, n0, &m0_b, &n0_b ); \ \ - bli_obj_init_finish_1x1( dt, (ftype*)alpha, &alphao ); \ - bli_obj_init_finish_1x1( dt, (ftype*)beta, &betao ); \ + bli_obj_init_finish_1x1( dt, (ftype*)alpha, &alphao ); \ + bli_obj_init_finish_1x1( dt, (ftype*)beta, &betao ); \ \ - bli_obj_init_finish( dt, m0_a, n0_a, (ftype*)a, rs_a, cs_a, &ao ); \ - bli_obj_init_finish( dt, m0_b, n0_b, (ftype*)b, rs_b, cs_b, &bo ); \ - bli_obj_init_finish( dt, m0, n0, (ftype*)c, rs_c, cs_c, &co ); \ + bli_obj_init_finish( dt, m0_a, n0_a, (ftype*)a, rs_a, cs_a, &ao ); \ + bli_obj_init_finish( dt, m0_b, n0_b, (ftype*)b, rs_b, cs_b, &bo ); \ + bli_obj_init_finish( dt, m0, n0, (ftype*)c, rs_c, cs_c, &co ); \ \ - bli_obj_set_conjtrans( blis_transa, &ao ); \ - bli_obj_set_conjtrans( blis_transb, &bo ); \ + bli_obj_set_conjtrans( blis_transa, &ao ); \ + bli_obj_set_conjtrans( blis_transb, &bo ); \ \ - PASTEMAC(blisname,BLIS_OAPI_EX_SUF) \ - ( \ - &alphao, \ - &ao, \ - &bo, \ - &betao, \ - &co, \ - NULL, \ - NULL \ - ); \ + PASTEMAC(blisname,BLIS_OAPI_EX_SUF) \ + ( \ + &alphao, \ + &ao, \ + &bo, \ + &betao, \ + &co, \ + NULL, \ + NULL \ + ); \ \ - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); \ - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ - /* Finalize BLIS. */ \ - bli_finalize_auto(); \ + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ } #endif #ifdef BLIS_ENABLE_BLAS void dgemm_ ( - const f77_char* transa, - const f77_char* transb, - const f77_int* m, - const f77_int* n, - const f77_int* k, - const double* alpha, - const double* a, const f77_int* lda, - const double* b, const f77_int* ldb, - const double* beta, - double* c, const f77_int* ldc + const f77_char* transa, + const f77_char* transb, + const f77_int* m, + const f77_int* n, + const f77_int* k, + const double* alpha, + const double* a, const f77_int* lda, + const double* b, const f77_int* ldb, + const double* beta, + double* c, const f77_int* ldc ) { @@ -343,7 +343,7 @@ void dgemm_ ldc ); - /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ bli_param_map_netlib_to_blis_trans(*transa, &blis_transa); bli_param_map_netlib_to_blis_trans(*transb, &blis_transb); @@ -361,141 +361,141 @@ void dgemm_ const inc_t rs_c = 1; const inc_t cs_c = *ldc; - // This function is invoked on all architectures including ‘generic’. - // Non-AVX platforms will use the kernels derived from the context. 
- if (bli_cpuid_is_avx_supported() == FALSE) - { - // This code is duplicated below, however we don't want to move it out of - // this IF block as it will affect the performance on Zen architetures - // Also this is temporary fix which will be replaced later. - const num_t dt = BLIS_DOUBLE; - - obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; - obj_t ao = BLIS_OBJECT_INITIALIZER; - obj_t bo = BLIS_OBJECT_INITIALIZER; - obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; - obj_t co = BLIS_OBJECT_INITIALIZER; - - dim_t m0_a, n0_a; - dim_t m0_b, n0_b; - - bli_set_dims_with_trans(blis_transa, m0, k0, &m0_a, &n0_a); - bli_set_dims_with_trans(blis_transb, k0, n0, &m0_b, &n0_b); - - bli_obj_init_finish_1x1(dt, (double *)alpha, &alphao); - bli_obj_init_finish_1x1(dt, (double *)beta, &betao); - - bli_obj_init_finish(dt, m0_a, n0_a, (double *)a, rs_a, cs_a, &ao); - bli_obj_init_finish(dt, m0_b, n0_b, (double *)b, rs_b, cs_b, &bo); - bli_obj_init_finish(dt, m0, n0, (double *)c, rs_c, cs_c, &co); - - bli_obj_set_conjtrans(blis_transa, &ao); - bli_obj_set_conjtrans(blis_transb, &bo); - - // Will call parallelized dgemm code - sup & native - PASTEMAC(gemm, BLIS_OAPI_EX_SUF) - ( - &alphao, - &ao, - &bo, - &betao, - &co, - NULL, - NULL - ); - - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - /* Finalize BLIS. */ - bli_finalize_auto(); - return; - } - - if((k0 == 1) && bli_is_notrans(blis_transa) && bli_is_notrans(blis_transb)) + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == FALSE) { - bli_dgemm_ref_k1_nn( m0, n0, k0, - (double*)alpha, - (double*)a, *lda, - (double*)b, *ldb, - (double*)beta, - c, *ldc - ); - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - /* Finalize BLIS */ - bli_finalize_auto(); - - return; + // This code is duplicated below, however we don't want to move it out of + // this IF block as it will affect the performance on Zen architetures + // Also this is temporary fix which will be replaced later. + const num_t dt = BLIS_DOUBLE; + + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; + obj_t ao = BLIS_OBJECT_INITIALIZER; + obj_t bo = BLIS_OBJECT_INITIALIZER; + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; + obj_t co = BLIS_OBJECT_INITIALIZER; + + dim_t m0_a, n0_a; + dim_t m0_b, n0_b; + + bli_set_dims_with_trans(blis_transa, m0, k0, &m0_a, &n0_a); + bli_set_dims_with_trans(blis_transb, k0, n0, &m0_b, &n0_b); + + bli_obj_init_finish_1x1(dt, (double *)alpha, &alphao); + bli_obj_init_finish_1x1(dt, (double *)beta, &betao); + + bli_obj_init_finish(dt, m0_a, n0_a, (double *)a, rs_a, cs_a, &ao); + bli_obj_init_finish(dt, m0_b, n0_b, (double *)b, rs_b, cs_b, &bo); + bli_obj_init_finish(dt, m0, n0, (double *)c, rs_c, cs_c, &co); + + bli_obj_set_conjtrans(blis_transa, &ao); + bli_obj_set_conjtrans(blis_transb, &bo); + + // Will call parallelized dgemm code - sup & native + PASTEMAC(gemm, BLIS_OAPI_EX_SUF) + ( + &alphao, + &ao, + &bo, + &betao, + &co, + NULL, + NULL + ); + + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS. 
*/ + bli_finalize_auto(); + return; + } + + if((k0 == 1) && bli_is_notrans(blis_transa) && bli_is_notrans(blis_transb)) + { + bli_dgemm_ref_k1_nn( m0, n0, k0, + (double*)alpha, + (double*)a, *lda, + (double*)b, *ldb, + (double*)beta, + c, *ldc + ); + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS */ + bli_finalize_auto(); + + return; } if (n0 == 1) { - if (bli_is_notrans(blis_transa)) - { - bli_dgemv_unf_var2( - BLIS_NO_TRANSPOSE, - bli_extract_conj(blis_transb), - m0, k0, - (double*)alpha, - (double*)a, rs_a, cs_a, - (double*)b, bli_is_notrans(blis_transb) ? rs_b : cs_b, - (double*)beta, - c, rs_c, - ((void*)0) - ); - } - else - { - bli_dgemv_unf_var1( - blis_transa, - bli_extract_conj(blis_transb), - k0, m0, - (double*)alpha, - (double*)a, rs_a, cs_a, - (double*)b, bli_is_notrans(blis_transb) ? rs_b : cs_b, - (double*)beta, - c, rs_c, - ((void*)0) - ); - } - - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - - return; + if (bli_is_notrans(blis_transa)) + { + bli_dgemv_unf_var2( + BLIS_NO_TRANSPOSE, + bli_extract_conj(blis_transb), + m0, k0, + (double*)alpha, + (double*)a, rs_a, cs_a, + (double*)b, bli_is_notrans(blis_transb) ? rs_b : cs_b, + (double*)beta, + c, rs_c, + ((void*)0) + ); + } + else + { + bli_dgemv_unf_var1( + blis_transa, + bli_extract_conj(blis_transb), + k0, m0, + (double*)alpha, + (double*)a, rs_a, cs_a, + (double*)b, bli_is_notrans(blis_transb) ? rs_b : cs_b, + (double*)beta, + c, rs_c, + ((void*)0) + ); + } + + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + + return; } else if (m0 == 1) { - if (bli_is_notrans(blis_transb)) - { - bli_dgemv_unf_var1( - blis_transb, - bli_extract_conj(blis_transa), - n0, k0, - (double*)alpha, - (double*)b, cs_b, rs_b, - (double*)a, bli_is_notrans(blis_transa) ? cs_a : rs_a, - (double*)beta, - c, cs_c, - ((void*)0) - ); - } - else - { - bli_dgemv_unf_var2( - blis_transb, - bli_extract_conj(blis_transa), - k0, n0, - (double*)alpha, - (double*)b, cs_b, rs_b, - (double*)a, bli_is_notrans(blis_transa) ? cs_a : rs_a, - (double*)beta, - c, cs_c, - ((void*)0) - ); - } - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - return; + if (bli_is_notrans(blis_transb)) + { + bli_dgemv_unf_var1( + blis_transb, + bli_extract_conj(blis_transa), + n0, k0, + (double*)alpha, + (double*)b, cs_b, rs_b, + (double*)a, bli_is_notrans(blis_transa) ? cs_a : rs_a, + (double*)beta, + c, cs_c, + ((void*)0) + ); + } + else + { + bli_dgemv_unf_var2( + blis_transb, + bli_extract_conj(blis_transa), + k0, n0, + (double*)alpha, + (double*)b, cs_b, rs_b, + (double*)a, bli_is_notrans(blis_transa) ? cs_a : rs_a, + (double*)beta, + c, cs_c, + ((void*)0) + ); + } + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + return; } const num_t dt = BLIS_DOUBLE; @@ -527,29 +527,29 @@ void dgemm_ bool nt = bli_thread_get_is_parallel(); // Check if parallel dgemm is invoked. 
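The AOCL_DYNAMIC branch below only keeps an already-parallel caller on the multithreaded sup/native path when the problem is large enough to amortize threading overhead; everything smaller falls through to the single-threaded dgemm_small kernels. A minimal sketch of that rule, assuming the hypothetical helper name use_parallel_dgemm_path (not a BLIS symbol) and the dim_t/bool types defined by blis.h:

static inline bool use_parallel_dgemm_path( bool nt, dim_t m0, dim_t n0, dim_t k0 )
{
    // Parallel sup/native path only when at least one dimension exceeds 32
    // and the combined size m0 + n0 + k0 exceeds 150; otherwise the
    // single-threaded small-matrix kernels are preferred.
    return nt && ( m0 > 32 || n0 > 32 || k0 > 32 ) && ( m0 + n0 + k0 > 150 );
}

Without AOCL_DYNAMIC the gate degenerates to the plain nt check, as the #else branch shows.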
#ifdef AOCL_DYNAMIC - //For smaller sizes dgemm_small is perfoming better + //For smaller sizes dgemm_small is perfoming better if (nt && (((m0 >32) || (n0>32) || (k0>32)) && ((m0+n0+k0)>150)) ) #else if (nt) #endif { - // Will call parallelized dgemm code - sup & native - PASTEMAC(gemm, BLIS_OAPI_EX_SUF) - ( - &alphao, - &ao, - &bo, - &betao, - &co, - NULL, - NULL + // Will call parallelized dgemm code - sup & native + PASTEMAC(gemm, BLIS_OAPI_EX_SUF) + ( + &alphao, + &ao, + &bo, + &betao, + &co, + NULL, + NULL ); - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - /* Finalize BLIS. */ - bli_finalize_auto(); - return; + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS. */ + bli_finalize_auto(); + return; } // The code below will be called when number of threads = 1. @@ -558,71 +558,71 @@ void dgemm_ //if( ((m0 + n0 -k0) < 2000) && ((m0 + k0-n0) < 2000) && ((n0 + k0-m0) < 2000) && (n0 > 2)) if( ( ( (m0 + n0 -k0) < 2000) && ((m0 + k0-n0) < 2000) && ((n0 + k0-m0) < 2000) ) || - ((n0 <= 10) && (k0 <=10)) ) + ((n0 <= 10) && (k0 <=10)) ) + { + err_t status; + if (bli_is_notrans(blis_transa)) + { + status = bli_dgemm_small( &alphao, + &ao, + &bo, + &betao, + &co, + NULL, //cntx, + NULL + ); + } + else { - err_t status; - if (bli_is_notrans(blis_transa)) - { - status = bli_dgemm_small( &alphao, - &ao, - &bo, - &betao, - &co, - NULL, //cntx, - NULL - ); - } - else - { - status = bli_dgemm_small_At ( &alphao, - &ao, - &bo, - &betao, - &co, - NULL, //cntx, - NULL - ); - } - - if (status == BLIS_SUCCESS) - { - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - /* Finalize BLIS. */ - bli_finalize_auto(); - - return; - } + status = bli_dgemm_small_At ( &alphao, + &ao, + &bo, + &betao, + &co, + NULL, //cntx, + NULL + ); + } + + if (status == BLIS_SUCCESS) + { + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS. */ + bli_finalize_auto(); + + return; + } } #endif //#ifdef BLIS_ENABLE_SMALL_MATRIX err_t status = bli_gemmsup(&alphao, &ao, &bo, &betao, &co, NULL, NULL); - if (status == BLIS_SUCCESS) - { - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - return; - } - - // fall back on native path when dgemm is not handled in sup path. - bli_gemmnat(&alphao, &ao, &bo, &betao, &co, NULL, NULL); - - - /* PASTEMAC(gemm, BLIS_OAPI_EX_SUF) */ - /* ( */ - /* &alphao, */ - /* &ao, */ - /* &bo, */ - /* &betao, */ - /* &co, */ - /* NULL, */ - /* NULL */ - /* ); */ - - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - /* Finalize BLIS. */ - bli_finalize_auto(); + if (status == BLIS_SUCCESS) + { + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + return; + } + + // fall back on native path when dgemm is not handled in sup path. + bli_gemmnat(&alphao, &ao, &bo, &betao, &co, NULL, NULL); + + + /* PASTEMAC(gemm, BLIS_OAPI_EX_SUF) */ + /* ( */ + /* &alphao, */ + /* &ao, */ + /* &bo, */ + /* &betao, */ + /* &co, */ + /* NULL, */ + /* NULL */ + /* ); */ + + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS. 
*/ + bli_finalize_auto(); } // end of dgemm_ void zgemm_ @@ -648,176 +648,166 @@ void zgemm_ AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(z), *transa, *transb, *m, *n, *k, - (void*)alpha, *lda, *ldb, (void*)beta, *ldc); + (void*)alpha, *lda, *ldb, (void*)beta, *ldc); /* Perform BLAS parameter checking. */ - PASTEBLACHK(gemm) - ( - MKSTR(z), - MKSTR(gemm), - transa, - transb, - m, - n, - k, - lda, - ldb, - ldc - ); - - /* Map BLAS chars to their corresponding BLIS enumerated type value. */ - bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); - bli_param_map_netlib_to_blis_trans( *transb, &blis_transb ); - - /* Typecast BLAS integers to BLIS integers. */ - bli_convert_blas_dim1( *m, m0 ); - bli_convert_blas_dim1( *n, n0 ); - bli_convert_blas_dim1( *k, k0 ); - - /* Set the row and column strides of the matrix operands. */ - const inc_t rs_a = 1; - const inc_t cs_a = *lda; - const inc_t rs_b = 1; - const inc_t cs_b = *ldb; - const inc_t rs_c = 1; - const inc_t cs_c = *ldc; - - const num_t dt = BLIS_DCOMPLEX; - - obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; - obj_t ao = BLIS_OBJECT_INITIALIZER; - obj_t bo = BLIS_OBJECT_INITIALIZER; - obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; - obj_t co = BLIS_OBJECT_INITIALIZER; - - dim_t m0_a, n0_a; - dim_t m0_b, n0_b; - - bli_set_dims_with_trans( blis_transa, m0, k0, &m0_a, &n0_a ); - bli_set_dims_with_trans( blis_transb, k0, n0, &m0_b, &n0_b ); - - bli_obj_init_finish_1x1( dt, (dcomplex*)alpha, &alphao ); - bli_obj_init_finish_1x1( dt, (dcomplex*)beta, &betao ); - - bli_obj_init_finish( dt, m0_a, n0_a, (dcomplex*)a, rs_a, cs_a, &ao ); - bli_obj_init_finish( dt, m0_b, n0_b, (dcomplex*)b, rs_b, cs_b, &bo ); - bli_obj_init_finish( dt, m0, n0, (dcomplex*)c, rs_c, cs_c, &co ); - - bli_obj_set_conjtrans( blis_transa, &ao ); - bli_obj_set_conjtrans( blis_transb, &bo ); - - // default instance peformance tuning is done in zgemm. - // Single instance tuning is done based on env set. - dim_t single_instance = bli_env_get_var( "BLIS_SINGLE_INSTANCE", -1 ); - - //dim_t nt = bli_thread_get_num_threads(); // get number of threads - bool nt = bli_thread_get_is_parallel(); // Check if parallel zgemm is invoked. -#ifdef AOCL_DYNAMIC - //For smaller sizes zgemm_small is perfoming better - if (nt && (((m0 >32) || (n0>32) || (k0>32)) && ((m0+n0+k0)>100)) ) -#else - if (nt) -#endif - { - // Will call parallelized zgemm code - sup & native - PASTEMAC(gemm, BLIS_OAPI_EX_SUF) - ( - &alphao, - &ao, - &bo, - &betao, - &co, - NULL, - NULL - ); - - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - /* Finalize BLIS. */ - bli_finalize_auto(); - return; - } + PASTEBLACHK(gemm) + ( + MKSTR(z), + MKSTR(gemm), + transa, + transb, + m, + n, + k, + lda, + ldb, + ldc + ); + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); + bli_param_map_netlib_to_blis_trans( *transb, &blis_transb ); + + /* Typecast BLAS integers to BLIS integers. */ + bli_convert_blas_dim1( *m, m0 ); + bli_convert_blas_dim1( *n, n0 ); + bli_convert_blas_dim1( *k, k0 ); + + /* Set the row and column strides of the matrix operands. 
*/ + const inc_t rs_a = 1; + const inc_t cs_a = *lda; + const inc_t rs_b = 1; + const inc_t cs_b = *ldb; + const inc_t rs_c = 1; + const inc_t cs_c = *ldc; + + const num_t dt = BLIS_DCOMPLEX; + + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; + obj_t ao = BLIS_OBJECT_INITIALIZER; + obj_t bo = BLIS_OBJECT_INITIALIZER; + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; + obj_t co = BLIS_OBJECT_INITIALIZER; + + dim_t m0_a, n0_a; + dim_t m0_b, n0_b; + + bli_set_dims_with_trans( blis_transa, m0, k0, &m0_a, &n0_a ); + bli_set_dims_with_trans( blis_transb, k0, n0, &m0_b, &n0_b ); + + bli_obj_init_finish_1x1( dt, (dcomplex*)alpha, &alphao ); + bli_obj_init_finish_1x1( dt, (dcomplex*)beta, &betao ); + + bli_obj_init_finish( dt, m0_a, n0_a, (dcomplex*)a, rs_a, cs_a, &ao ); + bli_obj_init_finish( dt, m0_b, n0_b, (dcomplex*)b, rs_b, cs_b, &bo ); + bli_obj_init_finish( dt, m0, n0, (dcomplex*)c, rs_c, cs_c, &co ); + + bli_obj_set_conjtrans( blis_transa, &ao ); + bli_obj_set_conjtrans( blis_transb, &bo ); + + // default instance peformance tuning is done in zgemm. + // Single instance tuning is done based on env set. + //dim_t single_instance = bli_env_get_var( "BLIS_SINGLE_INSTANCE", -1 ); + + //dim_t nt = bli_thread_get_num_threads(); // get number of threads + bool nt = bli_thread_get_is_parallel(); // Check if parallel zgemm is invoked. #ifdef BLIS_ENABLE_SMALL_MATRIX - err_t status; - - if((nt == 0) && (m0 <= 512 ) && ( n0 <= 512 ) && ( k0 <= 512 )) - { - status = bli_gemm_small( &alphao, - &ao, - &bo, - &betao, - &co, - NULL, //cntx, - NULL - ); - } - - if (status == BLIS_SUCCESS) - { - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - /* Finalize BLIS. */ - bli_finalize_auto(); - - return; - } + + if( ( (nt == 0) && (m0 <= 512 ) && ( n0 <= 512 ) && ( k0 <= 512 ) ) || + ( (nt == 1) && ((( m0 <= 32)||(n0 <= 32)||(k0 <=32)) && ((m0+n0+k0)<=100)) ) + ) + { + err_t status = BLIS_NOT_YET_IMPLEMENTED; + if (bli_is_notrans(blis_transa)) + { + status = bli_zgemm_small(&alphao, + &ao, + &bo, + &betao, + &co, + NULL, //cntx, + NULL + ); + } + else + { + status = bli_zgemm_small_At(&alphao, + &ao, + &bo, + &betao, + &co, + NULL, //cntx, + NULL + ); + } + + if (status == BLIS_SUCCESS) + { + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS. */ + bli_finalize_auto(); + return; + } + } #endif + // The code below will be called when number of threads = 1. -#if ENABLE_INDUCED_METHOD - /* 3m_sqp is optimal for certain matrix shapes. - Initial study that it works well for square sizes and sizes closer to square shape. - - * Usage of 3m_sqp is restricted to sizes, where it is found efficient compared to native, sup and other induced method. - * Further investigation is necessary to make the usage choices more generic. */ - bool sqp_on = false; - if( (m0 == n0 ) && ( n0 == k0 ) && ( m0 == 128 ) ) - { - sqp_on = true; - } - - // current range of sizes used for 3m_sqp to be expaned after evaluation. - if( ( m0 >= 4200) && ( m0 <= 4600 ) && ( ( n0 >= 326 ) || (n0 <= 1600 ) ) +#if 0//ENABLE_INDUCED_METHOD + /* 3m_sqp is optimal for certain matrix shapes. + Initial study that it works well for square sizes and sizes closer to square shape. + + * Usage of 3m_sqp is restricted to sizes, where it is found efficient compared to native, sup and other induced method. + * Further investigation is necessary to make the usage choices more generic. 
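The size gate added just below implements the dispatch described in this patch's message: a single-threaded call (nt == 0) takes the small-kernel path for anything up to 512x512x512, while a call that is already parallel (nt == 1) only drops to the single-threaded small kernels for genuinely tiny problems, where threading overhead would dominate. A minimal sketch of the same rule, assuming the hypothetical helper name zgemm_prefers_small_path (not a BLIS symbol):

static inline bool zgemm_prefers_small_path( bool nt, dim_t m0, dim_t n0, dim_t k0 )
{
    if ( !nt )  // single-threaded caller: small kernels up to 512 in every dimension
        return ( m0 <= 512 ) && ( n0 <= 512 ) && ( k0 <= 512 );

    // parallel caller: only very small problems stay single-threaded
    return ( m0 <= 32 || n0 <= 32 || k0 <= 32 ) && ( m0 + n0 + k0 <= 100 );
}

When the gate passes, bli_zgemm_small handles the non-transposed A case and bli_zgemm_small_At the (conjugate-)transposed case, matching the bli_is_notrans(blis_transa) test in the code that follows.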
*/ + bool sqp_on = false; + if( (m0 == n0 ) && ( n0 == k0 ) && ( m0 == 128 ) ) + { + sqp_on = true; + } + + // current range of sizes used for 3m_sqp to be expaned after evaluation. + if( ( m0 >= 4200) && ( m0 <= 4600 ) && ( ( n0 >= 326 ) || (n0 <= 1600 ) ) && ( k0 == 1120 ) ) //to be tuned further. - { - sqp_on = true; - } - - if( ( blis_transb == BLIS_NO_TRANSPOSE) && ( sqp_on == true ) ) - { - //sqp algo is found better for n > 40 - if(bli_gemm_sqp(&alphao, &ao, &bo, &betao, &co, NULL, NULL)==BLIS_SUCCESS) - { - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) - return; - } - } + { + sqp_on = true; + } + + if( ( blis_transb == BLIS_NO_TRANSPOSE) && ( sqp_on == true ) ) + { + //sqp algo is found better for n > 40 + if(bli_gemm_sqp(&alphao, &ao, &bo, &betao, &co, NULL, NULL)==BLIS_SUCCESS) + { + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) + return; + } + } #endif//ENABLE_INDUCED_METHOD -// native tuning resulted in better numbers compared to sup in constrained multi-instance -// sup has been enabled for single instance cases. - if(single_instance==1) - { - err_t status = bli_gemmsup(&alphao, &ao, &bo, &betao, &co, NULL, NULL); - if(status==BLIS_SUCCESS) - { - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) - return; - } - - } - // fall back on native path when zgemm is not handled in sup path. - bli_gemmnat(&alphao, &ao, &bo, &betao, &co, NULL, NULL); - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) - return; - - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) - /* Finalize BLIS. */ - bli_finalize_auto(); +// sup has been disabled. + if(0) + { + err_t status = bli_gemmsup(&alphao, &ao, &bo, &betao, &co, NULL, NULL); + if(status==BLIS_SUCCESS) + { + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) + return; + } + + } + // fall back on native path when zgemm is not handled in sup path. + bli_gemmnat(&alphao, &ao, &bo, &betao, &co, NULL, NULL); + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) + return; + + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) + /* Finalize BLIS. */ + bli_finalize_auto(); }// end of zgemm_ @@ -851,72 +841,72 @@ void dzgemm_ AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(z), *transa, *transb, *m, *n, *k, - (void*)alpha, *lda, *ldb, (void*)beta, *ldc); + (void*)alpha, *lda, *ldb, (void*)beta, *ldc); /* Perform BLAS parameter checking. */ - PASTEBLACHK(gemm) - ( - MKSTR(z), - MKSTR(gemm), - transa, - transb, - m, - n, - k, - lda, - ldb, - ldc - ); - - /* Map BLAS chars to their corresponding BLIS enumerated type value. */ - bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); - bli_param_map_netlib_to_blis_trans( *transb, &blis_transb ); - - /* Typecast BLAS integers to BLIS integers. */ - bli_convert_blas_dim1( *m, m0 ); - bli_convert_blas_dim1( *n, n0 ); - bli_convert_blas_dim1( *k, k0 ); - - /* Set the row and column strides of the matrix operands. 
*/ - const inc_t rs_a = 1; - const inc_t cs_a = *lda; - const inc_t rs_b = 1; - const inc_t cs_b = *ldb; - const inc_t rs_c = 1; - const inc_t cs_c = *ldc; - - const num_t dt = BLIS_DCOMPLEX; - const num_t dt_a = BLIS_DOUBLE; - - obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; - obj_t ao = BLIS_OBJECT_INITIALIZER; - obj_t bo = BLIS_OBJECT_INITIALIZER; - obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; - obj_t co = BLIS_OBJECT_INITIALIZER; - - dim_t m0_a, n0_a; - dim_t m0_b, n0_b; - - bli_set_dims_with_trans( blis_transa, m0, k0, &m0_a, &n0_a ); - bli_set_dims_with_trans( blis_transb, k0, n0, &m0_b, &n0_b ); - - bli_obj_init_finish_1x1( dt, (dcomplex*)alpha, &alphao ); - bli_obj_init_finish_1x1( dt, (dcomplex*)beta, &betao ); - - bli_obj_init_finish( dt_a, m0_a, n0_a, (double*)a, rs_a, cs_a, &ao ); - bli_obj_init_finish( dt, m0_b, n0_b, (dcomplex*)b, rs_b, cs_b, &bo ); - bli_obj_init_finish( dt, m0, n0, (dcomplex*)c, rs_c, cs_c, &co ); - - bli_obj_set_conjtrans( blis_transa, &ao ); - bli_obj_set_conjtrans( blis_transb, &bo ); - - // fall back on native path when zgemm is not handled in sup path. - bli_gemmnat(&alphao, &ao, &bo, &betao, &co, NULL, NULL); - - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) - /* Finalize BLIS. */ - bli_finalize_auto(); + PASTEBLACHK(gemm) + ( + MKSTR(z), + MKSTR(gemm), + transa, + transb, + m, + n, + k, + lda, + ldb, + ldc + ); + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); + bli_param_map_netlib_to_blis_trans( *transb, &blis_transb ); + + /* Typecast BLAS integers to BLIS integers. */ + bli_convert_blas_dim1( *m, m0 ); + bli_convert_blas_dim1( *n, n0 ); + bli_convert_blas_dim1( *k, k0 ); + + /* Set the row and column strides of the matrix operands. */ + const inc_t rs_a = 1; + const inc_t cs_a = *lda; + const inc_t rs_b = 1; + const inc_t cs_b = *ldb; + const inc_t rs_c = 1; + const inc_t cs_c = *ldc; + + const num_t dt = BLIS_DCOMPLEX; + const num_t dt_a = BLIS_DOUBLE; + + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; + obj_t ao = BLIS_OBJECT_INITIALIZER; + obj_t bo = BLIS_OBJECT_INITIALIZER; + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; + obj_t co = BLIS_OBJECT_INITIALIZER; + + dim_t m0_a, n0_a; + dim_t m0_b, n0_b; + + bli_set_dims_with_trans( blis_transa, m0, k0, &m0_a, &n0_a ); + bli_set_dims_with_trans( blis_transb, k0, n0, &m0_b, &n0_b ); + + bli_obj_init_finish_1x1( dt, (dcomplex*)alpha, &alphao ); + bli_obj_init_finish_1x1( dt, (dcomplex*)beta, &betao ); + + bli_obj_init_finish( dt_a, m0_a, n0_a, (double*)a, rs_a, cs_a, &ao ); + bli_obj_init_finish( dt, m0_b, n0_b, (dcomplex*)b, rs_b, cs_b, &bo ); + bli_obj_init_finish( dt, m0, n0, (dcomplex*)c, rs_c, cs_c, &co ); + + bli_obj_set_conjtrans( blis_transa, &ao ); + bli_obj_set_conjtrans( blis_transb, &bo ); + + // fall back on native path when zgemm is not handled in sup path. + bli_gemmnat(&alphao, &ao, &bo, &betao, &co, NULL, NULL); + + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) + /* Finalize BLIS. 
*/ + bli_finalize_auto(); }// end of dzgemm_ #endif #endif diff --git a/kernels/zen/3/bli_gemm_small.c b/kernels/zen/3/bli_gemm_small.c index e0c84eb227..0cf5c8c5ce 100644 --- a/kernels/zen/3/bli_gemm_small.c +++ b/kernels/zen/3/bli_gemm_small.c @@ -71,7 +71,7 @@ err_t bli_dgemm_small cntx_t* cntx, cntl_t* cntl ); -static err_t bli_zgemm_small +err_t bli_zgemm_small ( obj_t* alpha, obj_t* a, @@ -81,7 +81,7 @@ static err_t bli_zgemm_small cntx_t* cntx, cntl_t* cntl ); -static err_t bli_zgemm_small_At +err_t bli_zgemm_small_At ( obj_t* alpha, obj_t* a, @@ -128,18 +128,18 @@ err_t bli_gemm_small cntl_t* cntl ) { - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); #ifdef BLIS_ENABLE_MULTITHREADING - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); - return BLIS_NOT_YET_IMPLEMENTED; + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); + return BLIS_NOT_YET_IMPLEMENTED; #else // This function is invoked on all architectures including ‘generic’. // Non-AVX platforms will use the kernels derived from the context. if (bli_cpuid_is_avx_supported() == FALSE) - { - return BLIS_NOT_YET_IMPLEMENTED; - } + { + return BLIS_NOT_YET_IMPLEMENTED; + } #endif // If alpha is zero, scale by beta and return. @@ -172,8 +172,8 @@ err_t bli_gemm_small return bli_dgemm_small_At(alpha, a, b, beta, c, cntx, cntl); #endif } - if(dt == BLIS_DCOMPLEX) - { + if(dt == BLIS_DCOMPLEX) + { #ifndef BLIS_ENABLE_MULTITHREADING // bli_zgemm_small_At is called directly from blas interface for // sizes within thresholds. @@ -181,9 +181,9 @@ err_t bli_gemm_small // and directing to native implementation. return BLIS_NOT_YET_IMPLEMENTED; #else - return bli_zgemm_small_At(alpha, a, b, beta, c, cntx, cntl); + return bli_zgemm_small_At(alpha, a, b, beta, c, cntx, cntl); #endif - } + } if (bli_obj_has_notrans( b )) { @@ -230,7 +230,7 @@ err_t bli_gemm_small return bli_sgemm_small(alpha, a, b, beta, c, cntx, cntl); } - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); return BLIS_NOT_YET_IMPLEMENTED; }; @@ -245,13 +245,13 @@ static err_t bli_sgemm_small cntl_t* cntl ) { - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); gint_t M = bli_obj_length( c ); // number of rows of Matrix C gint_t N = bli_obj_width( c ); // number of columns of Matrix C gint_t K = bli_obj_width( a ); // number of columns of OP(A), will be updated if OP(A) is Transpose(A) . gint_t L = M * N; - // when N is equal to 1 call GEMV instead of GEMM + // when N is equal to 1 call GEMV instead of GEMM if (N == 1) { bli_gemv @@ -262,7 +262,7 @@ static err_t bli_sgemm_small beta, c ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); return BLIS_SUCCESS; } @@ -288,7 +288,7 @@ static err_t bli_sgemm_small dim_t tb_inc_row = 1; // row stride of matrix B dim_t tb_inc_col = ldb; // column stride of matrix B - __m256 ymm4, ymm5, ymm6, ymm7; + __m256 ymm4, ymm5, ymm6, ymm7; __m256 ymm8, ymm9, ymm10, ymm11; __m256 ymm12, ymm13, ymm14, ymm15; __m256 ymm0, ymm1, ymm2, ymm3; @@ -310,7 +310,7 @@ static err_t bli_sgemm_small is_beta_non_zero = 1; } - //update the pointer math if matrix B needs to be transposed. + //update the pointer math if matrix B needs to be transposed. 
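    //   In other words: element (k, j) of op(B) is always read as
    //   tB[ k*tb_inc_row + j*tb_inc_col ], so transposition only swaps the two
    //   increments and the k-loop body never branches on the transpose flag.
    //   Scalar illustration (reference-style, not the vectorized code used in
    //   this kernel; acc, i and j stand for one output element C(i, j)):
    //
    //       no transpose: tb_inc_row = 1,   tb_inc_col = ldb
    //       transpose:    tb_inc_row = ldb, tb_inc_col = 1
    //
    //       for ( dim_t k = 0; k < K; ++k )
    //           acc += tA[ i + k * lda ] * tB[ k * tb_inc_row + j * tb_inc_col ];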
if (bli_obj_has_trans( b )) { tb_inc_col = 1; //switch row and column strides tb_inc_row = ldb; @@ -1668,7 +1668,7 @@ static err_t bli_sgemm_small if(is_beta_non_zero){ ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7); } - _mm256_storeu_ps(f_temp, ymm7); + _mm256_storeu_ps(f_temp, ymm7); for (int i = 0; i < m_remainder; i++) { tC[i] = f_temp[i]; @@ -1771,17 +1771,17 @@ static err_t bli_sgemm_small &local_mem_buf_A_s); } - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); return BLIS_SUCCESS; } else - { - AOCL_DTL_TRACE_EXIT_ERR( - AOCL_DTL_LEVEL_INFO, - "Invalid dimesions for small gemm." - ); - return BLIS_NONCONFORMAL_DIMENSIONS; - } + { + AOCL_DTL_TRACE_EXIT_ERR( + AOCL_DTL_LEVEL_INFO, + "Invalid dimesions for small gemm." + ); + return BLIS_NONCONFORMAL_DIMENSIONS; + } }; @@ -1796,21 +1796,24 @@ static err_t bli_sgemm_small cntl_t* cntl ) { - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO); - + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO); + if (bli_cpuid_is_avx_supported() == FALSE) + { + return BLIS_NOT_YET_IMPLEMENTED; + } gint_t M = bli_obj_length( c ); // number of rows of Matrix C gint_t N = bli_obj_width( c ); // number of columns of Matrix C gint_t K = bli_obj_width( a ); // number of columns of OP(A), will be updated if OP(A) is Transpose(A) . gint_t L = M * N; /* if (N<3) //Implemenation assumes that N is atleast 3. VK */ - /* { */ - /* AOCL_DTL_TRACE_EXIT_ERR( */ - /* AOCL_DTL_LEVEL_INFO, */ + /* { */ + /* AOCL_DTL_TRACE_EXIT_ERR( */ + /* AOCL_DTL_LEVEL_INFO, */ /* "N < 3 cannot be processed by small_gemm" */ - /* ); */ + /* ); */ /* return BLIS_NOT_YET_IMPLEMENTED; VK */ - /* } */ + /* } */ if(L && K ) // Non-zero dimensions will be handled by either sup or native kernels @@ -1900,8 +1903,8 @@ static err_t bli_sgemm_small // reported in CPUPL-587. // - // if ((N <= 3) || ((D_MR * K) << 3) > buffer_size) - if ((N < 3) || ((D_MR * K) << 3) > buffer_size) + // if ((N <= 3) || ((D_MR * K) << 3) > buffer_size) + if ((N < 3) || ((D_MR * K) << 3) > buffer_size) { required_packing_A = 0; } @@ -3359,17 +3362,17 @@ static err_t bli_sgemm_small bli_membrk_release(&rntm, &local_mem_buf_A_s); } - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); return BLIS_SUCCESS; } else - { - AOCL_DTL_TRACE_EXIT_ERR( - AOCL_DTL_LEVEL_INFO, - "Invalid dimesions for small gemm." - ); + { + AOCL_DTL_TRACE_EXIT_ERR( + AOCL_DTL_LEVEL_INFO, + "Invalid dimesions for small gemm." + ); return BLIS_NONCONFORMAL_DIMENSIONS; - } + } }; static err_t bli_sgemm_small_atbn @@ -3383,9 +3386,9 @@ static err_t bli_sgemm_small_atbn cntl_t* cntl ) { - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO); + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO); - gint_t M = bli_obj_length( c ); // number of rows of Matrix C + gint_t M = bli_obj_length( c ); // number of rows of Matrix C gint_t N = bli_obj_width( c ); // number of columns of Matrix C gint_t K = bli_obj_length( b ); // number of rows of Matrix B @@ -3836,17 +3839,17 @@ static err_t bli_sgemm_small_atbn } } } - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); return BLIS_SUCCESS; } else - { - AOCL_DTL_TRACE_EXIT_ERR( - AOCL_DTL_LEVEL_INFO, - "Invalid dimesions for small gemm." - ); + { + AOCL_DTL_TRACE_EXIT_ERR( + AOCL_DTL_LEVEL_INFO, + "Invalid dimesions for small gemm." 
+ ); return BLIS_NONCONFORMAL_DIMENSIONS; - } + } } static err_t bli_dgemm_small_atbn @@ -3860,7 +3863,7 @@ static err_t bli_dgemm_small_atbn cntl_t* cntl ) { - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO); + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO); gint_t M = bli_obj_length( c ); // number of rows of Matrix C gint_t N = bli_obj_width( c ); // number of columns of Matrix C @@ -4276,17 +4279,17 @@ static err_t bli_dgemm_small_atbn } } } - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); return BLIS_SUCCESS; } else - { - AOCL_DTL_TRACE_EXIT_ERR( - AOCL_DTL_LEVEL_INFO, - "Invalid dimesions for small gemm." - ); - return BLIS_NONCONFORMAL_DIMENSIONS; - } + { + AOCL_DTL_TRACE_EXIT_ERR( + AOCL_DTL_LEVEL_INFO, + "Invalid dimesions for small gemm." + ); + return BLIS_NONCONFORMAL_DIMENSIONS; + } } err_t bli_dgemm_small_At @@ -4302,7 +4305,10 @@ err_t bli_dgemm_small_At { AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO); - + if (bli_cpuid_is_avx_supported() == FALSE) + { + return BLIS_NOT_YET_IMPLEMENTED; + } gint_t M = bli_obj_length( c ); // number of rows of Matrix C gint_t N = bli_obj_width( c ); // number of columns of Matrix C gint_t K = bli_obj_width_after_trans( a ); // number of columns of OP(A), will be updated if OP(A) is Transpose(A) . @@ -4352,14 +4358,14 @@ err_t bli_dgemm_small_At if( bli_obj_has_trans( b ) ) { - tb_inc_col = 1; // switch row and column strides + tb_inc_col = 1; // switch row and column strides tb_inc_row = ldb; } __m256d ymm4, ymm5, ymm6, ymm7; __m256d ymm8, ymm9, ymm10, ymm11; __m256d ymm12, ymm13, ymm14, ymm15; - __m256d ymm0, ymm1, ymm2, ymm3; + __m256d ymm0, ymm1, ymm2, ymm3; double result; double scratch[8] = {0.0}; @@ -5780,7 +5786,7 @@ err_t bli_dgemm_small_At -static err_t bli_zgemm_small +err_t bli_zgemm_small ( obj_t* alpha, obj_t* a, @@ -5791,7635 +5797,7640 @@ static err_t bli_zgemm_small cntl_t* cntl ) { - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO); - - bool conjtransa = bli_obj_has_conj(a); - bool conjtransb = bli_obj_has_conj(b); - - gint_t M = bli_obj_length( c ); // number of rows of Matrix C - gint_t N = bli_obj_width( c ); // number of columns of Matrix C - // number of columns of OP(A), will be updated if OP(A) is Transpose(A) - gint_t K = bli_obj_width( a ); - gint_t L = M * N; - - if(L && K ) - { - guint_t lda = bli_obj_col_stride( a ); // column stride of matrix OP(A). - guint_t ldb = bli_obj_col_stride( b ); // column stride of matrix OP(B). - guint_t ldc = bli_obj_col_stride( c ); // column stride of matrix C - guint_t row_idx, col_idx, k; - dcomplex *A = bli_obj_buffer_at_off(a); //pointer to elements of Matrix A - dcomplex *B = bli_obj_buffer_at_off(b); //pointer to elements of Matrix B - dcomplex *C = bli_obj_buffer_at_off(c); //pointer to elements of Matrix C - - dcomplex *tA = A, *tB = B, *tC = C;//, *tA_pack; - dcomplex *tA_packed; //temprorary pointer to hold packed A memory pointer - guint_t row_idx_packed; //packed A memory row index - guint_t lda_packed; //lda of packed A - guint_t col_idx_start; //starting index after A matrix is packed. 
- dim_t tb_inc_row = 1; // row stride of matrix B - dim_t tb_inc_col = ldb; // column stride of matrix B - __m256d ymm4, ymm5, ymm6, ymm7; - __m256d ymm8, ymm9, ymm10, ymm11; - __m256d ymm12, ymm13, ymm14, ymm15; - __m256d ymm16, ymm17, ymm18, ymm19, ymm20, ymm21; - __m256d ymm0, ymm1, ymm2, ymm3; - - gint_t n_remainder; // If the N is non multiple of 3.(N%3) - gint_t m_remainder; // If the M is non multiple of 4.(M%4) - - dcomplex *alpha_cast, *beta_cast; // alpha, beta multiples - alpha_cast = bli_obj_buffer_for_1x1(BLIS_DCOMPLEX, alpha); - beta_cast = bli_obj_buffer_for_1x1(BLIS_DCOMPLEX, beta); - - gint_t required_packing_A = 1; - mem_t local_mem_buf_A_s; - dcomplex *D_A_pack = NULL; - rntm_t rntm; - - //update the pointer math if matrix B needs to be transposed. - if (bli_obj_has_trans( b )) - { - tb_inc_col = 1; //switch row and column strides - tb_inc_row = ldb; - } - - //checking whether beta value is zero. - //if true, we should perform C=alpha * A*B operation - //instead of C = beta * C + alpha * (A * B) - bool is_beta_non_zero = 0; - if(!bli_obj_equals(beta, &BLIS_ZERO)) - is_beta_non_zero = 1; - - /* - * This function was using global array to pack part of A input when - * needed. However, using this global array make the function - * non-reentrant. Instead of using a global array we should allocate - * buffer for each invocation. Since the buffer size is too big or stack - * and doing malloc every time will be too expensive, better approach is - * to get the buffer from the pre-allocated pool and it the pool once we - * are doing. - * - * In order to get the buffer from pool, we need access to memory broker, - * currently this function is not invoked in such a way that it can - * receive the memory broker (via rntm). Following hack will get the - * global memory broker that can be use it to access the pool. - * - * Note there will be memory allocation at least on first innovation - * as there will not be any pool created for this size. - * Subsequent invocations will just reuse the buffer from the pool. - */ - - bli_rntm_init_from_global( &rntm ); - bli_rntm_set_num_threads_only( 1, &rntm ); - bli_membrk_rntm_set_membrk( &rntm ); - - // Get the current size of the buffer pool for A block packing. - // We will use the same size to avoid pool re-initliazaton - siz_t buffer_size = bli_pool_block_size( - bli_membrk_pool(bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), - bli_rntm_membrk(&rntm))); - - // - // This kernel assumes that "A" will be unpackged if N <= 3. - // Usually this range (N <= 3) is handled by SUP, however, - // if SUP is disabled or for any other condition if we do - // enter this kernel with N <= 3, we want to make sure that - // "A" remains unpacked. - // - - if ((N < 3) || ((Z_MR * K) << 3) > buffer_size) - { - required_packing_A = 0; - } - - if (required_packing_A == 1) - { -#ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_zgemm_small: Requesting mem pool block of size %lu\n", - buffer_size); -#endif - // Get the buffer from the pool. - bli_membrk_acquire_m(&rntm, - buffer_size, - BLIS_BITVAL_BUFFER_FOR_A_BLOCK, - &local_mem_buf_A_s); - - D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); - } - - /* - * The computation loop runs for Z_MRxN columns of C matrix, thus - * accessing the Z_MRxK A matrix data and KxNR B matrix data. - * The computation is organized as inner loops of dimension Z_MRxNR. - */ - // Process D_MR rows of C matrix at a time. 
- for (row_idx = 0; (row_idx + (Z_MR - 1)) < M; row_idx += Z_MR) - { - col_idx_start = 0; - tA_packed = A; - row_idx_packed = row_idx; - lda_packed = lda; - - /** - * This is the part of the pack and compute optimization. - * During the first column iteration, we store the accessed A - * matrix into contiguous static memory. This helps to keep te A - * matrix in Cache and aviods the TLB misses. - */ - if (required_packing_A) - { - col_idx = 0; - - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - tA_packed = D_A_pack; + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO); + if (bli_cpuid_is_avx_supported() == FALSE) + { + return BLIS_NOT_YET_IMPLEMENTED; + } + bool conjtransa = bli_obj_has_conj(a); + bool conjtransb = bli_obj_has_conj(b); -#ifdef BLIS_ENABLE_PREFETCH - _mm_prefetch((char*)(tC + 0), _MM_HINT_T0); - _mm_prefetch((char*)(tC + 8), _MM_HINT_T0); - _mm_prefetch((char*)(tC + ldc), _MM_HINT_T0); - _mm_prefetch((char*)(tC + ldc + 8), _MM_HINT_T0); - _mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0); - _mm_prefetch((char*)(tC + 2 * ldc + 8), _MM_HINT_T0); -#endif - // clear scratch registers. - BLIS_SET_ALL_YMM_REG_ZEROS - - double *tptr = (double *)tB; - if(conjtransa && conjtransb) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B - // matrix i data and multiplies it with - // the A matrix. - // This loop is processing Z_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied with matrix A columns. - ymm0 = _mm256_loadu_pd( - (double const *)tA); - ymm1 = _mm256_loadu_pd( - (double const *)(tA + 2)); - _mm256_storeu_pd( - (double *)tA_packed, ymm0); - _mm256_storeu_pd( - (double *) - (tA_packed + 2), ymm1); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - ymm1 = _mm256_mul_pd(ymm1, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) * - 2 + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda; - tA_packed += Z_MR; - } - - } - else if(conjtransa) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix data and - // multiplies it with the A matrix. 
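Loosely, the intrinsic loops below implement the following blocking (an illustrative scalar sketch only: pack_a_panel and zpanel_kernel are hypothetical names, and in the real code the packing is fused into the first NR-column iteration rather than being a separate pass):

for ( dim_t row_idx = 0; row_idx + Z_MR <= M; row_idx += Z_MR )
{
    // Stream the Z_MR x K panel of A into the pool buffer once per row panel
    // so the remaining column iterations hit it in cache.
    pack_a_panel( A + row_idx, lda, D_A_pack, Z_MR, K );

    for ( dim_t col_idx = 0; col_idx + NR <= N; col_idx += NR )
    {
        // Update a Z_MR x NR block of C from the packed Z_MR x K panel of A
        // and a K x NR panel of B.
        zpanel_kernel( D_A_pack, Z_MR,
                       B + col_idx * ldb, ldb,
                       C + row_idx + col_idx * ldc, ldc,
                       K, alpha_cast, beta_cast, is_beta_non_zero );
    }
}
// M % Z_MR rows and N % NR columns are handled by the remainder paths
// further down in this function.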
- // This loop is processing Z_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd( - (double const *)tA); - ymm1 = _mm256_loadu_pd( - (double const *)(tA + 2)); - _mm256_storeu_pd( - (double *)tA_packed, ymm0); - _mm256_storeu_pd( - (double *)(tA_packed + 2) - , ymm1); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - ymm1 = _mm256_mul_pd(ymm1, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda; - tA_packed += Z_MR; - } - } - else if(conjtransb) - { - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and multiplies it with the A - // matrix. This loop is processing - // Z_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied with matrix A columns. - ymm0 = _mm256_loadu_pd( - (double const *)tA); - ymm1 = _mm256_loadu_pd( - (double const *)(tA + 2)); - _mm256_storeu_pd( - (double *)tA_packed, ymm0); - _mm256_storeu_pd( - (double *)(tA_packed + 2) - , ymm1); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda; - tA_packed += Z_MR; - } - - } - else //handles non-transpose case - { - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and multiplies it with the A - // matrix. 
This loop is processing - // Z_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd( - (double const *)tA); - ymm1 = _mm256_loadu_pd( - (double const *)(tA + 2)); - _mm256_storeu_pd( - (double *)tA_packed, ymm0); - _mm256_storeu_pd( - (double *)(tA_packed + 2) - , ymm1); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda; - tA_packed += Z_MR; - } - } - - ymm4 = _mm256_permute_pd(ymm4, 0x5); - ymm5 = _mm256_permute_pd(ymm5, 0x5); - ymm6 = _mm256_permute_pd(ymm6, 0x5); - ymm7 = _mm256_permute_pd(ymm7, 0x5); - ymm14 = _mm256_permute_pd(ymm14, 0x5); - ymm15 = _mm256_permute_pd(ymm15, 0x5); - - ymm8 = _mm256_addsub_pd(ymm8, ymm4); - ymm11 = _mm256_addsub_pd(ymm11, ymm5); - ymm9 = _mm256_addsub_pd(ymm9, ymm6); - ymm12 = _mm256_addsub_pd(ymm12, ymm7); - ymm10 = _mm256_addsub_pd(ymm10, ymm14); - ymm13 = _mm256_addsub_pd(ymm13, ymm15); - - // alpha, beta multiplication. - ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm8, ymm0); - ymm14 = _mm256_mul_pd(ymm8, ymm14); - ymm8 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm9, ymm0); - ymm14 = _mm256_mul_pd(ymm9, ymm14); - ymm9 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm10, ymm0); - ymm14 = _mm256_mul_pd(ymm10, ymm14); - ymm10 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm11, ymm0); - ymm14 = _mm256_mul_pd(ymm11, ymm14); - ymm11 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm12, ymm0); - ymm14 = _mm256_mul_pd(ymm12, ymm14); - ymm12 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm13, ymm0); - ymm14 = _mm256_mul_pd(ymm13, ymm14); - ymm13 = _mm256_hsub_pd(ymm15, ymm14); - - ymm2 = _mm256_broadcast_sd((double const *) - &beta_cast->real); - ymm3 = _mm256_broadcast_sd((double const *) - (&beta_cast->imag)); - - - BLIS_SET_YMM_REG_ZEROS - - if(is_beta_non_zero) - { - // multiply C by beta and accumulate col 1. 
- ymm0 = _mm256_loadu_pd((double const *)tC); - ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - - ymm0 = _mm256_loadu_pd((double const *)(tC + 2)); - ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6); - ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7); - - // col 2 - ymm0 = _mm256_loadu_pd((double const *)(tC + ldc)); - ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); - ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); - - ymm0 = _mm256_loadu_pd((double const *) - (tC + ldc + 2)); - ymm16 = _mm256_fmadd_pd(ymm0, ymm2, ymm16); - ymm17 = _mm256_fmadd_pd(ymm0, ymm3, ymm17); - - // col 3 - ymm0 = _mm256_loadu_pd((double const *) - (tC + (ldc * 2))); - ymm18 = _mm256_fmadd_pd(ymm0, ymm2, ymm18); - ymm19 = _mm256_fmadd_pd(ymm0, ymm3, ymm19); - - ymm0 = _mm256_loadu_pd((double const *) - (tC + (ldc * 2) + 2)); - ymm20 = _mm256_fmadd_pd(ymm0, ymm2, ymm20); - ymm21 = _mm256_fmadd_pd(ymm0, ymm3, ymm21); - - } - ymm5 = _mm256_permute_pd(ymm5, 0x5); - ymm7 = _mm256_permute_pd(ymm7, 0x5); - ymm15 = _mm256_permute_pd(ymm15, 0x5); - ymm17 = _mm256_permute_pd(ymm17, 0x5); - ymm19 = _mm256_permute_pd(ymm19, 0x5); - ymm21 = _mm256_permute_pd(ymm21, 0x5); - - ymm4 = _mm256_addsub_pd(ymm4, ymm5); - ymm6 = _mm256_addsub_pd(ymm6, ymm7); - ymm14 = _mm256_addsub_pd(ymm14, ymm15); - ymm16 = _mm256_addsub_pd(ymm16, ymm17); - ymm18 = _mm256_addsub_pd(ymm18, ymm19); - ymm20 = _mm256_addsub_pd(ymm20, ymm21); - - ymm8 = _mm256_add_pd(ymm8, ymm4); - ymm11 = _mm256_add_pd(ymm11, ymm6); - ymm9 = _mm256_add_pd(ymm9, ymm14); - ymm12 = _mm256_add_pd(ymm12, ymm16); - ymm10 = _mm256_add_pd(ymm10, ymm18); - ymm13 = _mm256_add_pd(ymm13, ymm20); - - _mm256_storeu_pd((double *)tC, ymm8); - _mm256_storeu_pd((double *)(tC + 2), ymm11); - - tC += ldc; - - _mm256_storeu_pd((double *)tC, ymm9); - _mm256_storeu_pd((double *)(tC + 2), ymm12); - - tC += ldc; - - _mm256_storeu_pd((double *)tC, ymm10); - _mm256_storeu_pd((double *)(tC + 2), ymm13); - - // modify the pointer arithematic to use packed A matrix. - col_idx_start = NR; - tA_packed = D_A_pack; - row_idx_packed = 0; - lda_packed = Z_MR; - } - // Process NR columns of C matrix at a time. - for (col_idx = col_idx_start; (col_idx + (NR - 1)) < N; - col_idx += NR) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = tA_packed + row_idx_packed; + gint_t M = bli_obj_length( c ); // number of rows of Matrix C + gint_t N = bli_obj_width( c ); // number of columns of Matrix C + // number of columns of OP(A), will be updated if OP(A) is Transpose(A) + gint_t K = bli_obj_width( a ); + gint_t L = M * N; -#ifdef BLIS_ENABLE_PREFETCH - _mm_prefetch((char*)(tC + 0), _MM_HINT_T0); - _mm_prefetch((char*)(tC + 8), _MM_HINT_T0); - _mm_prefetch((char*)(tC + ldc), _MM_HINT_T0); - _mm_prefetch((char*)(tC + ldc + 8), _MM_HINT_T0); - _mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0); - _mm_prefetch((char*)(tC + 2 * ldc + 8), _MM_HINT_T0); -#endif - // clear scratch registers. - - - BLIS_SET_ALL_YMM_REG_ZEROS - - double *tptr = (double *)tB; - - if(conjtransa && conjtransb) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. 
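/* In the first NR-column pass above every load of A is also stored into
 * D_A_pack, and the later passes then read A from that buffer with
 * lda_packed = Z_MR instead of the original lda.  A minimal sketch of that
 * pack-on-first-use idea; the names are illustrative and it assumes the
 * column height (in doubles) is a multiple of four. */
static void zpack_a_sketch
     (
       const double *a, dim_t lda2,   /* source column stride, in doubles  */
       double *a_packed,              /* contiguous Z_MR x K buffer        */
       dim_t k, dim_t mr2             /* mr2 = 2 * Z_MR doubles per column */
     )
{
    for (dim_t p = 0; p < k; ++p)
    {
        /* copy one Z_MR-high column of interleaved (re, im) pairs          */
        for (dim_t i = 0; i < mr2; i += 4)
            _mm256_storeu_pd(a_packed + i, _mm256_loadu_pd(a + i));
        a        += lda2;   /* possibly large, cache-unfriendly stride      */
        a_packed += mr2;    /* small, contiguous stride for reuse           */
    }
}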
- // This loop is processing Z_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd( - (double const *)tA); - ymm1 = _mm256_loadu_pd( - (double const *)(tA + 2)); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - ymm1 = _mm256_mul_pd(ymm1, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda_packed; - } - } - else if(conjtransa) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and multiplies it with the A - // matrix. This loop is processing - // Z_MR x K The inner loop broadcasts - // the B matrix data and multiplies it - // with the A matrix. This loop is - // processing Z_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - ymm1 = _mm256_loadu_pd((double const *) - (tA + 2)); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - ymm1 = _mm256_mul_pd(ymm1, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda_packed; - } - } - else if(conjtransb) - { - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and multiplies it with the A - // matrix. 
This loop is processing - // Z_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - ymm1 = _mm256_loadu_pd((double const *) - (tA + 2)); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda_packed; - } - } - else //handles non-transpose case - { - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and multiplies it with the A - // matrix. This loop is processing - // Z_MR x K The inner loop broadcasts the - // B matrix data and multiplies it with - // the A matrix. This loop is processing - // Z_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied with matrix A columns. 
- ymm0 = _mm256_loadu_pd((double const *)tA); - ymm1 = _mm256_loadu_pd((double const *) - (tA + 2)); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda_packed; - } - } - ymm4 = _mm256_permute_pd(ymm4, 0x5); - ymm5 = _mm256_permute_pd(ymm5, 0x5); - ymm6 = _mm256_permute_pd(ymm6, 0x5); - ymm7 = _mm256_permute_pd(ymm7, 0x5); - ymm14 = _mm256_permute_pd(ymm14, 0x5); - ymm15 = _mm256_permute_pd(ymm15, 0x5); - - ymm8 = _mm256_addsub_pd(ymm8, ymm4); - ymm11 = _mm256_addsub_pd(ymm11, ymm5); - ymm9 = _mm256_addsub_pd(ymm9, ymm6); - ymm12 = _mm256_addsub_pd(ymm12, ymm7); - ymm10 = _mm256_addsub_pd(ymm10, ymm14); - ymm13 = _mm256_addsub_pd(ymm13, ymm15); - - // alpha, beta multiplication. - ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm8, ymm0); - ymm14 = _mm256_mul_pd(ymm8, ymm14); - ymm8 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm9, ymm0); - ymm14 = _mm256_mul_pd(ymm9, ymm14); - ymm9 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm10, ymm0); - ymm14 = _mm256_mul_pd(ymm10, ymm14); - ymm10 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm11, ymm0); - ymm14 = _mm256_mul_pd(ymm11, ymm14); - ymm11 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm12, ymm0); - ymm14 = _mm256_mul_pd(ymm12, ymm14); - ymm12 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm13, ymm0); - ymm14 = _mm256_mul_pd(ymm13, ymm14); - ymm13 = _mm256_hsub_pd(ymm15, ymm14); - - ymm2 = _mm256_broadcast_sd((double const *) - &beta_cast->real); - ymm3 = _mm256_broadcast_sd((double const *) - &beta_cast->imag); - - - BLIS_SET_YMM_REG_ZEROS - - if(is_beta_non_zero) - { - // multiply C by beta and accumulate col 1. 
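/* The "alpha, beta multiplication" blocks above scale each accumulator by the
 * complex scalar alpha with one permute, two multiplies and a horizontal
 * subtract.  A minimal sketch of that scaling for one register; the names
 * are illustrative, and alpha points at { alpha.re, alpha.im } just as
 * alpha_cast does. */
static inline __m256d zscale_by_alpha_sketch( __m256d x, const double *alpha )
{
    const __m256d sign = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);
    __m256d al   = _mm256_broadcast_pd((__m128d const *)alpha);     /* ar ai ar ai */
    __m256d alsw = _mm256_mul_pd(_mm256_permute_pd(al, 0x5), sign); /* ai -ar ...  */
    __m256d t0   = _mm256_mul_pd(x, al);     /* x.re*ar   x.im*ai                  */
    __m256d t1   = _mm256_mul_pd(x, alsw);   /* x.re*ai  -x.im*ar                  */
    /* each hsub pair gives (x.re*ar - x.im*ai, x.re*ai + x.im*ar) = x * alpha     */
    return _mm256_hsub_pd(t0, t1);
}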
- ymm0 = _mm256_loadu_pd((double const *)tC); - ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - - ymm0 = _mm256_loadu_pd((double const *)(tC + 2)); - ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6); - ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7); - - ymm0 = _mm256_loadu_pd((double const *)(tC + ldc)); - ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); - ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); - - ymm0 = _mm256_loadu_pd((double const *) - (tC + ldc + 2)); - ymm16 = _mm256_fmadd_pd(ymm0, ymm2, ymm16); - ymm17 = _mm256_fmadd_pd(ymm0, ymm3, ymm17); - - ymm0 = _mm256_loadu_pd((double const *) - (tC + ldc * 2)); - ymm18 = _mm256_fmadd_pd(ymm0, ymm2, ymm18); - ymm19 = _mm256_fmadd_pd(ymm0, ymm3, ymm19); - - ymm0 = _mm256_loadu_pd((double const *) - (tC + ldc * 2 + 2)); - ymm20 = _mm256_fmadd_pd(ymm0, ymm2, ymm20); - ymm21 = _mm256_fmadd_pd(ymm0, ymm3, ymm21); - - } - ymm5 = _mm256_permute_pd(ymm5, 0x5); - ymm7 = _mm256_permute_pd(ymm7, 0x5); - ymm15 = _mm256_permute_pd(ymm15, 0x5); - ymm17 = _mm256_permute_pd(ymm17, 0x5); - ymm19 = _mm256_permute_pd(ymm19, 0x5); - ymm21 = _mm256_permute_pd(ymm21, 0x5); - - ymm4 = _mm256_addsub_pd(ymm4, ymm5); - ymm6 = _mm256_addsub_pd(ymm6, ymm7); - ymm14 = _mm256_addsub_pd(ymm14, ymm15); - ymm16 = _mm256_addsub_pd(ymm16, ymm17); - ymm18 = _mm256_addsub_pd(ymm18, ymm19); - ymm20 = _mm256_addsub_pd(ymm20, ymm21); - - ymm8 = _mm256_add_pd(ymm8, ymm4); - ymm11 = _mm256_add_pd(ymm11, ymm6); - ymm9 = _mm256_add_pd(ymm9, ymm14); - ymm12 = _mm256_add_pd(ymm12, ymm16); - ymm10 = _mm256_add_pd(ymm10, ymm18); - ymm13 = _mm256_add_pd(ymm13, ymm20); - - _mm256_storeu_pd((double *)tC, ymm8); - _mm256_storeu_pd((double *)(tC + 2), ymm11); - - tC += ldc; - - _mm256_storeu_pd((double *)tC, ymm9); - _mm256_storeu_pd((double *)(tC + 2), ymm12); - - tC += ldc; - - _mm256_storeu_pd((double *)tC, ymm10); - _mm256_storeu_pd((double *)(tC + 2), ymm13); - } - n_remainder = N - col_idx; - if (n_remainder == 2) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - - // clear scratch registers. - - - BLIS_SET_ALL_YMM_REG_ZEROS - double *tptr = (double *)tB; - if(conjtransa && conjtransb) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and multiplies it with the A - // matrix. This loop is processing - // Z_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied with matrix A columns. 
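/* When is_beta_non_zero, the C tile loaded above is folded in as beta*C with
 * the same permute/addsub trick before the final add and store.  A minimal
 * sketch for one register of two dcomplex elements; names are illustrative,
 * and alpha_ab stands for an accumulator that already holds alpha*(A*B). */
static inline __m256d zbeta_accum_sketch
     (
       __m256d alpha_ab,        /* alpha * (A*B) for two elements          */
       const double *c,         /* two dcomplex elements of C              */
       const double *beta       /* { beta.re, beta.im }                    */
     )
{
    __m256d vc  = _mm256_loadu_pd(c);
    __m256d bre = _mm256_broadcast_sd(beta + 0);
    __m256d bim = _mm256_broadcast_sd(beta + 1);

    __m256d t_re = _mm256_mul_pd(vc, bre);   /* c.re*b.re  c.im*b.re        */
    __m256d t_im = _mm256_mul_pd(vc, bim);   /* c.re*b.im  c.im*b.im        */

    t_im = _mm256_permute_pd(t_im, 0x5);     /* c.im*b.im  c.re*b.im        */
    t_re = _mm256_addsub_pd(t_re, t_im);     /* equals beta * C             */

    return _mm256_add_pd(alpha_ab, t_re);    /* alpha*(A*B) + beta*C        */
}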
- ymm0 = _mm256_loadu_pd((double const *)tA); - ymm1 = _mm256_loadu_pd((double const *) - (tA + 2)); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - ymm1 = _mm256_mul_pd(ymm1, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - - tptr += (tb_inc_row * 2); - tA += lda; - } - } - else if(conjtransa) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and multiplies it with the A - // matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied with matrix A columns. - ymm0 = _mm256_loadu_pd((double const*)tA); - ymm1 = _mm256_loadu_pd((double const*) - (tA + 2)); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - ymm1 = _mm256_mul_pd(ymm1, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - tptr += tb_inc_row*2; - tA += lda; - } - } - else if(conjtransb) - { - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and multiplies it with the A - // matrix. This loop is processing - // Z_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - ymm1 = _mm256_loadu_pd((double const *) - (tA + 2)); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - - tptr += (tb_inc_row * 2); - tA += lda; - } - - } - else //handles non-transpose case - { - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and multiplies it with the A - // matrix. 
- ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied with matrix A columns. - ymm0 = _mm256_loadu_pd((double const*)tA); - ymm1 = _mm256_loadu_pd((double const*) - (tA + 2)); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - tptr += tb_inc_row*2; - tA += lda; - } - - } - ymm4 = _mm256_permute_pd(ymm4, 0x5); - ymm5 = _mm256_permute_pd(ymm5, 0x5); - ymm6 = _mm256_permute_pd(ymm6, 0x5); - ymm7 = _mm256_permute_pd(ymm7, 0x5); - - ymm8 = _mm256_addsub_pd(ymm8, ymm4); - ymm11 = _mm256_addsub_pd(ymm11, ymm5); - ymm9 = _mm256_addsub_pd(ymm9, ymm6); - ymm12 = _mm256_addsub_pd(ymm12, ymm7); - - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); - ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm8, ymm0); - ymm14 = _mm256_mul_pd(ymm8, ymm14); - ymm8 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm11, ymm0); - ymm14 = _mm256_mul_pd(ymm11, ymm14); - ymm11 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm9, ymm0); - ymm14 = _mm256_mul_pd(ymm9, ymm14); - ymm9 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm12, ymm0); - ymm14 = _mm256_mul_pd(ymm12, ymm14); - ymm12 = _mm256_hsub_pd(ymm15, ymm14); - - - BLIS_SET_YMM_REG_ZEROS - ymm2 = _mm256_broadcast_sd((double const *) - &beta_cast->real); - ymm3 = _mm256_broadcast_sd((double const *) - &beta_cast->imag); - - if(is_beta_non_zero) - { - // multiply C by beta and accumulate col 1. 
- ymm0 = _mm256_loadu_pd((double const *)tC); - ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - - ymm0 = _mm256_loadu_pd((double const *)(tC + 2)); - ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6); - ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7); - - ymm0 = _mm256_loadu_pd((double const *)(tC + ldc)); - ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); - ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); - - ymm0 = _mm256_loadu_pd((double const *) - (tC + ldc + 2)); - ymm16 = _mm256_fmadd_pd(ymm0, ymm2, ymm16); - ymm17 = _mm256_fmadd_pd(ymm0, ymm3, ymm17); - - } - ymm5 = _mm256_permute_pd(ymm5, 0x5); - ymm7 = _mm256_permute_pd(ymm7, 0x5); - ymm15 = _mm256_permute_pd(ymm15, 0x5); - ymm17 = _mm256_permute_pd(ymm17, 0x5); - - ymm4 = _mm256_addsub_pd(ymm4, ymm5); - ymm6 = _mm256_addsub_pd(ymm6, ymm7); - ymm14 = _mm256_addsub_pd(ymm14, ymm15); - ymm16 = _mm256_addsub_pd(ymm16, ymm17); - - ymm8 = _mm256_add_pd(ymm8, ymm4); - ymm11 = _mm256_add_pd(ymm11, ymm6); - ymm9 = _mm256_add_pd(ymm9, ymm14); - ymm12 = _mm256_add_pd(ymm12, ymm16); - - _mm256_storeu_pd((double *)(tC + 0), ymm8); - _mm256_storeu_pd((double *)(tC + 2), ymm11); - tC += ldc; - _mm256_storeu_pd((double *)tC, ymm9); - _mm256_storeu_pd((double *)(tC + 2), ymm12); - } - - if (n_remainder == 1) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - - // clear scratch registers. - - - BLIS_SET_ALL_YMM_REG_ZEROS - - double *tptr = (double *)tB; - if(conjtransa && conjtransb) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and multiplies it with the A - // matrix. This loop is processing - // Z_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - ymm1 = _mm256_loadu_pd((double const *) - (tA + 2)); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - ymm1 = _mm256_mul_pd(ymm1, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - tptr += (tb_inc_row * 2); - tA += lda; - } - } - else if(conjtransa) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - for (k = 0; k < K; ++k) - { - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - tptr += tb_inc_row*2; - - //broadcasted matrix B elements are - //multiplied with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - ymm1 = _mm256_loadu_pd((double const *) - (tA + 2)); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - ymm1 = _mm256_mul_pd(ymm1, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - tA += lda; - } - } - else if(conjtransb) - { - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. 
- // This loop is processing Z_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - ymm1 = _mm256_loadu_pd((double const *) - (tA + 2)); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - tptr += (tb_inc_row * 2); - tA += lda; - } - } - else //handles non-transpose case - { - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - tptr += tb_inc_row*2; - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - ymm1 = _mm256_loadu_pd((double const *) - (tA + 2)); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - tA += lda; - } - - } - ymm4 = _mm256_permute_pd(ymm4, 0x5); - ymm5 = _mm256_permute_pd(ymm5, 0x5); - - ymm8 = _mm256_addsub_pd(ymm8, ymm4); - ymm11 = _mm256_addsub_pd(ymm11, ymm5); - - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); - ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm8, ymm0); - ymm14 = _mm256_mul_pd(ymm8, ymm14); - ymm8 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm11, ymm0); - ymm14 = _mm256_mul_pd(ymm11, ymm14); - ymm11 = _mm256_hsub_pd(ymm15, ymm14); - - - BLIS_SET_YMM_REG_ZEROS - ymm2 = _mm256_broadcast_sd((double const *) - &beta_cast->real); - ymm3 = _mm256_broadcast_sd((double const *) - &beta_cast->imag); - - if(is_beta_non_zero) - { - // multiply C by beta and accumulate col 1. 
- ymm0 = _mm256_loadu_pd((double const *)tC); - ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - - ymm0 = _mm256_loadu_pd((double const *)(tC + 2)); - ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6); - ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7); - } - ymm5 = _mm256_permute_pd(ymm5, 0x5); - ymm7 = _mm256_permute_pd(ymm7, 0x5); - - ymm4 = _mm256_addsub_pd(ymm4, ymm5); - ymm6 = _mm256_addsub_pd(ymm6, ymm7); - - ymm8 = _mm256_add_pd(ymm8, ymm4); - ymm11 = _mm256_add_pd(ymm11, ymm6); - - _mm256_storeu_pd((double *)tC, ymm8); - _mm256_storeu_pd((double *)(tC + 2), ymm11); - } - } - m_remainder = M - row_idx; - - if ((m_remainder == 3)) - { - m_remainder -= 3; - __m128d xmm0; - - for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - - - BLIS_SET_ALL_YMM_REG_ZEROS - - xmm0 = _mm_setzero_pd(); - - double *tptr = (double *)tB; - if(conjtransa && conjtransb) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing Z_MR x K - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing Z_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - xmm0 = _mm_loadu_pd((double const *)(tA + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - ymm1 = _mm256_mul_pd(ymm1, ymm20); - - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda; - } - } - else if(conjtransa) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing Z_MR x K - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. 
- // This loop is processing Z_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - xmm0 = _mm_loadu_pd((double const *)(tA + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - ymm1 = _mm256_mul_pd(ymm1, ymm20); - - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda; - } - } - else if(conjtransb) - { - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing Z_MR x K - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing Z_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - xmm0 = _mm_loadu_pd((double const *) - (tA + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda; - } - - } - else - { - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing Z_MR x K - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. 
- // This loop is processing Z_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - xmm0 = _mm_loadu_pd((double const *) - (tA + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda; - } - - } - ymm4 = _mm256_permute_pd(ymm4, 0x5); - ymm5 = _mm256_permute_pd(ymm5, 0x5); - ymm6 = _mm256_permute_pd(ymm6, 0x5); - ymm7 = _mm256_permute_pd(ymm7, 0x5); - ymm14 = _mm256_permute_pd(ymm14, 0x5); - ymm15 = _mm256_permute_pd(ymm15, 0x5); - - ymm8 = _mm256_addsub_pd(ymm8, ymm4); - ymm11 = _mm256_addsub_pd(ymm11, ymm5); - ymm9 = _mm256_addsub_pd(ymm9, ymm6); - ymm12 = _mm256_addsub_pd(ymm12, ymm7); - ymm10 = _mm256_addsub_pd(ymm10, ymm14); - ymm13 = _mm256_addsub_pd(ymm13, ymm15); - // alpha, beta multiplication. - ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm8, ymm0); - ymm14 = _mm256_mul_pd(ymm8, ymm14); - ymm8 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm9, ymm0); - ymm14 = _mm256_mul_pd(ymm9, ymm14); - ymm9 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm10, ymm0); - ymm14 = _mm256_mul_pd(ymm10, ymm14); - ymm10 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm11, ymm0); - ymm14 = _mm256_mul_pd(ymm11, ymm14); - ymm11 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm12, ymm0); - ymm14 = _mm256_mul_pd(ymm12, ymm14); - ymm12 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm13, ymm0); - ymm14 = _mm256_mul_pd(ymm13, ymm14); - ymm13 = _mm256_hsub_pd(ymm15, ymm14); - - ymm2 = _mm256_broadcast_sd((double const *) - &beta_cast->real); - ymm3 = _mm256_broadcast_sd((double const *) - &beta_cast->imag); - - - - BLIS_SET_YMM_REG_ZEROS - xmm0 = _mm_setzero_pd(); - - if(is_beta_non_zero) - { - // multiply C by beta and accumulate col 1. 
- ymm0 = _mm256_loadu_pd((double const *)tC); - xmm0 = _mm_loadu_pd((double const *)(tC + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - - ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - ymm6 = _mm256_fmadd_pd(ymm1, ymm2, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - ymm0 = _mm256_loadu_pd((double const *) - (tC + ldc)); - xmm0 = _mm_loadu_pd((double const *) - (tC + ldc + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - - ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); - ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); - ymm16 = _mm256_fmadd_pd(ymm1, ymm2, ymm16); - ymm17 = _mm256_fmadd_pd(ymm1, ymm3, ymm17); - - ymm0 = _mm256_loadu_pd((double const *) - (tC + ldc * 2)); - xmm0 = _mm_loadu_pd((double const *) - (tC + ldc * 2 + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - - ymm18 = _mm256_fmadd_pd(ymm0, ymm2, ymm18); - ymm19 = _mm256_fmadd_pd(ymm0, ymm3, ymm19); - ymm20 = _mm256_fmadd_pd(ymm1, ymm2, ymm20); - ymm21 = _mm256_fmadd_pd(ymm1, ymm3, ymm21); - - } - ymm5 = _mm256_permute_pd(ymm5, 0x5); - ymm7 = _mm256_permute_pd(ymm7, 0x5); - ymm15 = _mm256_permute_pd(ymm15, 0x5); - ymm17 = _mm256_permute_pd(ymm17, 0x5); - ymm19 = _mm256_permute_pd(ymm19, 0x5); - ymm21 = _mm256_permute_pd(ymm21, 0x5); - - ymm4 = _mm256_addsub_pd(ymm4, ymm5); - ymm6 = _mm256_addsub_pd(ymm6, ymm7); - ymm14 = _mm256_addsub_pd(ymm14, ymm15); - ymm16 = _mm256_addsub_pd(ymm16, ymm17); - ymm18 = _mm256_addsub_pd(ymm18, ymm19); - ymm20 = _mm256_addsub_pd(ymm20, ymm21); - - ymm8 = _mm256_add_pd(ymm8, ymm4); - ymm11 = _mm256_add_pd(ymm11, ymm6); - ymm9 = _mm256_add_pd(ymm9, ymm14); - ymm12 = _mm256_add_pd(ymm12, ymm16); - ymm10 = _mm256_add_pd(ymm10, ymm18); - ymm13 = _mm256_add_pd(ymm13, ymm20); - - _mm256_storeu_pd((double *)tC, ymm8); - xmm0 = _mm256_extractf128_pd(ymm11, 0); - _mm_storeu_pd((double *)(tC + 2), xmm0); - - tC += ldc; - - _mm256_storeu_pd((double *)tC, ymm9); - xmm0 = _mm256_extractf128_pd(ymm12, 0); - _mm_storeu_pd((double *)(tC + 2), xmm0); - - tC += ldc; - - _mm256_storeu_pd((double *)tC, ymm10); - xmm0 = _mm256_extractf128_pd(ymm13, 0); - _mm_storeu_pd((double *)(tC + 2), xmm0); - } - n_remainder = N - col_idx; - if (n_remainder == 2) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - - // clear scratch registers. - - BLIS_SET_ALL_YMM_REG_ZEROS - xmm0 = _mm_setzero_pd(); - - double *tptr = (double *)tB; - if(conjtransa && conjtransb) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd((tptr - + tb_inc_col - * 0)); - ymm3 = _mm256_broadcast_sd((tptr - + tb_inc_col - * 0 + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. 
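/* For the m_remainder == 3 rows above, the third dcomplex element is moved
 * through the low half of a ymm register with 128-bit loads and stores so
 * that no memory beyond the fringe is touched, while the arithmetic can keep
 * using the same 256-bit permute/addsub/fma sequences.  A minimal sketch of
 * that load/store pattern; names are illustrative. */
static inline void zcopy3_sketch( const double *src, double *dst )
{
    /* elements 0 and 1: full 256-bit load and store                        */
    __m256d lo = _mm256_loadu_pd(src);
    /* element 2: 128-bit load placed in the low half of a ymm register
     * (the kernel inserts into an already zeroed register instead)         */
    __m128d x  = _mm_loadu_pd(src + 4);
    __m256d hi = _mm256_insertf128_pd(_mm256_setzero_pd(), x, 0);

    _mm256_storeu_pd(dst, lo);
    /* write back only the valid 128 bits of the third element              */
    _mm_storeu_pd(dst + 4, _mm256_extractf128_pd(hi, 0));
}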
- ymm0 = _mm256_loadu_pd((double const *)tA); - xmm0 = _mm_loadu_pd((double const *) - (tA + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - ymm1 = _mm256_mul_pd(ymm1, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - tptr += tb_inc_row*2; - tA += lda; - } - } - else if(conjtransa) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd((tptr - + tb_inc_col - * 0)); - ymm3 = _mm256_broadcast_sd((tptr - + tb_inc_col - * 0 + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - xmm0 = _mm_loadu_pd((double const *) - (tA + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - ymm1 = _mm256_mul_pd(ymm1, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - tptr += tb_inc_row*2; - tA += lda; - } - } - else if(conjtransb) - { - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd((tptr - + tb_inc_col - * 0)); - ymm3 = _mm256_broadcast_sd((tptr - + tb_inc_col - * 0 + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - xmm0 = _mm_loadu_pd((double const *) - (tA + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - tptr += tb_inc_row*2; - tA += lda; - } - } - else - { - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd((tptr - + tb_inc_col - * 0)); - ymm3 = _mm256_broadcast_sd((tptr - + tb_inc_col - * 0 + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. 
- ymm0 = _mm256_loadu_pd((double const *)tA); - xmm0 = _mm_loadu_pd((double const *) - (tA + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - tptr += tb_inc_row*2; - tA += lda; - } - - } - ymm4 = _mm256_permute_pd(ymm4, 0x5); - ymm5 = _mm256_permute_pd(ymm5, 0x5); - ymm6 = _mm256_permute_pd(ymm6, 0x5); - ymm7 = _mm256_permute_pd(ymm7, 0x5); - - ymm8 = _mm256_addsub_pd(ymm8, ymm4); - ymm11 = _mm256_addsub_pd(ymm11, ymm5); - ymm9 = _mm256_addsub_pd(ymm9, ymm6); - ymm12 = _mm256_addsub_pd(ymm12, ymm7); - - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); - ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm8, ymm0); - ymm14 = _mm256_mul_pd(ymm8, ymm14); - ymm8 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm11, ymm0); - ymm14 = _mm256_mul_pd(ymm11, ymm14); - ymm11 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm9, ymm0); - ymm14 = _mm256_mul_pd(ymm9, ymm14); - ymm9 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm12, ymm0); - ymm14 = _mm256_mul_pd(ymm12, ymm14); - ymm12 = _mm256_hsub_pd(ymm15, ymm14); - - - - BLIS_SET_YMM_REG_ZEROS - xmm0 = _mm_setzero_pd(); - - ymm2 = _mm256_broadcast_sd((double const *) - &beta_cast->real); - ymm3 = _mm256_broadcast_sd((double const *) - &beta_cast->imag); - - if(is_beta_non_zero) - { - // multiply C by beta and accumulate col 1. 
- ymm0 = _mm256_loadu_pd((double const *)tC); - xmm0 = _mm_loadu_pd((double const *)(tC + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - - ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - ymm6 = _mm256_fmadd_pd(ymm1, ymm2, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - ymm0 = _mm256_loadu_pd((double const *)(tC + ldc)); - xmm0 = _mm_loadu_pd((double const *)(tC + ldc + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - - ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); - ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); - ymm16 = _mm256_fmadd_pd(ymm1, ymm2, ymm16); - ymm17 = _mm256_fmadd_pd(ymm1, ymm3, ymm17); - - } - ymm5 = _mm256_permute_pd(ymm5, 0x5); - ymm7 = _mm256_permute_pd(ymm7, 0x5); - ymm15 = _mm256_permute_pd(ymm15, 0x5); - ymm17 = _mm256_permute_pd(ymm17, 0x5); - - ymm4 = _mm256_addsub_pd(ymm4, ymm5); - ymm6 = _mm256_addsub_pd(ymm6, ymm7); - ymm14 = _mm256_addsub_pd(ymm14, ymm15); - ymm16 = _mm256_addsub_pd(ymm16, ymm17); - - ymm8 = _mm256_add_pd(ymm8, ymm4); - ymm11 = _mm256_add_pd(ymm11, ymm6); - ymm9 = _mm256_add_pd(ymm9, ymm14); - ymm12 = _mm256_add_pd(ymm12, ymm16); - - _mm256_storeu_pd((double *)tC, ymm8); - xmm0 = _mm256_extractf128_pd(ymm11, 0); - _mm_storeu_pd((double *)(tC + 2), xmm0); - - tC += ldc; - _mm256_storeu_pd((double *)tC, ymm9); - xmm0 = _mm256_extractf128_pd(ymm12, 0); - _mm_storeu_pd((double *)(tC + 2), xmm0); - } - if (n_remainder == 1) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - - // clear scratch registers. - - - BLIS_SET_ALL_YMM_REG_ZEROS - xmm0 = _mm_setzero_pd(); - - double *tptr = (double *)tB; - if(conjtransa && conjtransb) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - xmm0 = _mm_loadu_pd((double const *) - (tA + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - ymm1 = _mm256_mul_pd(ymm1, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - tptr += tb_inc_row*2; - tA += lda; - } - } - else if(conjtransa) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. 
- ymm0 = _mm256_loadu_pd((double const *)tA); - xmm0 = _mm_loadu_pd((double const *) - (tA + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - ymm1 = _mm256_mul_pd(ymm1, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - tptr += tb_inc_row*2; - tA += lda; - } - } - else if(conjtransb) - { - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - xmm0 = _mm_loadu_pd((double const *) - (tA + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - tptr += tb_inc_row*2; - tA += lda; - } - } - else - { - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - xmm0 = _mm_loadu_pd((double const *) - (tA + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - tptr += tb_inc_row*2; - tA += lda; - } - } - ymm4 = _mm256_permute_pd(ymm4, 0x5); - ymm5 = _mm256_permute_pd(ymm5, 0x5); - - ymm8 = _mm256_addsub_pd(ymm8, ymm4); - ymm11 = _mm256_addsub_pd(ymm11, ymm5); - - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); - ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm8, ymm0); - ymm14 = _mm256_mul_pd(ymm8, ymm14); - ymm8 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm11, ymm0); - ymm14 = _mm256_mul_pd(ymm11, ymm14); - ymm11 = _mm256_hsub_pd(ymm15, ymm14); - - - - BLIS_SET_YMM_REG_ZEROS - xmm0 = _mm_setzero_pd(); - - ymm2 = _mm256_broadcast_sd((double const *) - &beta_cast->real); - ymm3 = _mm256_broadcast_sd((double const *) - &beta_cast->imag); - - if(is_beta_non_zero) - { - // multiply C by beta and accumulate col 1. 
- ymm0 = _mm256_loadu_pd((double const *)tC); - xmm0 = _mm_loadu_pd((double const *)(tC + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - - ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - ymm6 = _mm256_fmadd_pd(ymm1, ymm2, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - } - ymm5 = _mm256_permute_pd(ymm5, 0x5); - ymm7 = _mm256_permute_pd(ymm7, 0x5); - - ymm4 = _mm256_addsub_pd(ymm4, ymm5); - ymm6 = _mm256_addsub_pd(ymm6, ymm7); - - ymm8 = _mm256_add_pd(ymm8, ymm4); - ymm11 = _mm256_add_pd(ymm11, ymm6); - - _mm256_storeu_pd((double *)tC, ymm8); - xmm0 = _mm256_extractf128_pd(ymm11, 0); - _mm_storeu_pd((double *)(tC + 2), xmm0); - } - } - if ((m_remainder == 2)) - { - m_remainder -= 2; - - for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - - - - BLIS_SET_ALL_YMM_REG_ZEROS - double *tptr = (double *)tB; - if(conjtransa && conjtransb) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing Z_MR x K - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing Z_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda; - } - } - else if(conjtransa) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing Z_MR x K - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing Z_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. 
- ymm0 = _mm256_loadu_pd((double const *)tA); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda; - } - } - else if(conjtransb) - { - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing Z_MR x K - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing Z_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda; - } - } - else - { - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing Z_MR x K - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing Z_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. 
- ymm0 = _mm256_loadu_pd((double const *)tA); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda; - } - } - ymm4 = _mm256_permute_pd(ymm4, 0x5); - ymm6 = _mm256_permute_pd(ymm6, 0x5); - ymm14 = _mm256_permute_pd(ymm14, 0x5); - - ymm8 = _mm256_addsub_pd(ymm8, ymm4); - ymm9 = _mm256_addsub_pd(ymm9, ymm6); - ymm10 = _mm256_addsub_pd(ymm10, ymm14); - // alpha, beta multiplication. - ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm8, ymm0); - ymm14 = _mm256_mul_pd(ymm8, ymm14); - ymm8 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm9, ymm0); - ymm14 = _mm256_mul_pd(ymm9, ymm14); - ymm9 = _mm256_hsub_pd(ymm15, ymm14); - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm10, ymm0); - ymm14 = _mm256_mul_pd(ymm10, ymm14); - ymm10 = _mm256_hsub_pd(ymm15, ymm14); - - ymm2 = _mm256_broadcast_sd((double const *) - &beta_cast->real); - ymm3 = _mm256_broadcast_sd((double const *) - &beta_cast->imag); - - - BLIS_SET_YMM_REG_ZEROS - if(is_beta_non_zero) - { - // multiply C by beta and accumulate col 1. - ymm0 = _mm256_loadu_pd((double const *)tC); - - ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - - ymm0 = _mm256_loadu_pd((double const *)(tC + ldc)); - - ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); - ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); - - ymm0 = _mm256_loadu_pd((double const *) - (tC + ldc * 2)); - - ymm18 = _mm256_fmadd_pd(ymm0, ymm2, ymm18); - ymm19 = _mm256_fmadd_pd(ymm0, ymm3, ymm19); - - } - ymm5 = _mm256_permute_pd(ymm5, 0x5); - ymm15 = _mm256_permute_pd(ymm15, 0x5); - ymm19 = _mm256_permute_pd(ymm19, 0x5); - - ymm4 = _mm256_addsub_pd(ymm4, ymm5); - ymm14 = _mm256_addsub_pd(ymm14, ymm15); - ymm18 = _mm256_addsub_pd(ymm18, ymm19); - - ymm8 = _mm256_add_pd(ymm8, ymm4); - ymm9 = _mm256_add_pd(ymm9, ymm14); - ymm10 = _mm256_add_pd(ymm10, ymm18); - - _mm256_storeu_pd((double *)tC, ymm8); - - tC += ldc; - - _mm256_storeu_pd((double *)tC, ymm9); - - tC += ldc; - - _mm256_storeu_pd((double *)tC, ymm10); - } - n_remainder = N - col_idx; - if (n_remainder == 2) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - - // clear scratch registers. - - BLIS_SET_ALL_YMM_REG_ZEROS - - double *tptr = (double *)tB; - if(conjtransa && conjtransb) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix data and - // multiplies it with the A matrix. 
- ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - - tptr += tb_inc_row*2; - tA += lda; - } - } - else if(conjtransa) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - - tptr += tb_inc_row*2; - tA += lda; - } - } - else if(conjtransb) - { - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - - tptr += tb_inc_row*2; - tA += lda; - } - } - else - { - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. 
- ymm0 = _mm256_loadu_pd((double const *)tA); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - - tptr += tb_inc_row*2; - tA += lda; - } - - } - ymm4 = _mm256_permute_pd(ymm4, 0x5); - ymm6 = _mm256_permute_pd(ymm6, 0x5); - - ymm8 = _mm256_addsub_pd(ymm8, ymm4); - ymm9 = _mm256_addsub_pd(ymm9, ymm6); - - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); - ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm8, ymm0); - ymm14 = _mm256_mul_pd(ymm8, ymm14); - ymm8 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm9, ymm0); - ymm14 = _mm256_mul_pd(ymm9, ymm14); - ymm9 = _mm256_hsub_pd(ymm15, ymm14); - - - BLIS_SET_YMM_REG_ZEROS - - ymm2 = _mm256_broadcast_sd((double const *) - &beta_cast->real); - ymm3 = _mm256_broadcast_sd((double const *) - &beta_cast->imag); - - if(is_beta_non_zero) - { - // multiply C by beta and accumulate col 1. - ymm0 = _mm256_loadu_pd((double const *)tC); - - ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - - ymm0 = _mm256_loadu_pd((double const *)(tC + ldc)); - - ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); - ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); - - } - ymm5 = _mm256_permute_pd(ymm5, 0x5); - ymm15 = _mm256_permute_pd(ymm15, 0x5); - - ymm4 = _mm256_addsub_pd(ymm4, ymm5); - ymm14 = _mm256_addsub_pd(ymm14, ymm15); - - ymm8 = _mm256_add_pd(ymm8, ymm4); - ymm9 = _mm256_add_pd(ymm9, ymm14); - - _mm256_storeu_pd((double *)tC, ymm8); - tC += ldc; - _mm256_storeu_pd((double *)tC, ymm9); - } - if (n_remainder == 1) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - - // clear scratch registers. - - - BLIS_SET_ALL_YMM_REG_ZEROS - - double *tptr = (double *)tB; - if(conjtransa && conjtransb) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - tptr += tb_inc_row*2; - tA += lda; - } - } - else if(conjtransa) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. 
- ymm0 = _mm256_loadu_pd((double const *)tA); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - tptr += tb_inc_row*2; - tA += lda; - } - } - else if(conjtransb) - { - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - tptr += tb_inc_row*2; - tA += lda; - } - } - else - { - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - tptr += tb_inc_row*2; - tA += lda; - } - - } - ymm4 = _mm256_permute_pd(ymm4, 0x5); - - ymm8 = _mm256_addsub_pd(ymm8, ymm4); - - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); - ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm8, ymm0); - ymm14 = _mm256_mul_pd(ymm8, ymm14); - ymm8 = _mm256_hsub_pd(ymm15, ymm14); - - - - BLIS_SET_YMM_REG_ZEROS - ymm2 = _mm256_broadcast_sd((double const *) - &beta_cast->real); - ymm3 = _mm256_broadcast_sd((double const *) - &beta_cast->imag); - - if(is_beta_non_zero) - { - // multiply C by beta and accumulate col 1. - ymm0 = _mm256_loadu_pd((double const *)tC); - - ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - } - ymm5 = _mm256_permute_pd(ymm5, 0x5); - - ymm4 = _mm256_addsub_pd(ymm4, ymm5); - - ymm8 = _mm256_add_pd(ymm8, ymm4); - - _mm256_storeu_pd((double *)tC, ymm8); - } - } - if ((m_remainder == 1)) - { - m_remainder -= 1; - __m128d xmm0; - - for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - - - - BLIS_SET_ALL_YMM_REG_ZEROS - xmm0 = _mm_setzero_pd(); - - double *tptr = (double *)tB; - if(conjtransa && conjtransb) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. 
- xmm0 = _mm_loadu_pd((double const *)(tA)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda; - } - } - else if(conjtransa) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - xmm0 = _mm_loadu_pd((double const *)(tA)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda; - } - } - else if(conjtransb) - { - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are multiplied - //with matrix A columns. 
- xmm0 = _mm_loadu_pd((double const *)(tA)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda; - } - } - else - { - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - xmm0 = _mm_loadu_pd((double const *)(tA)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda; - } - } - ymm4 = _mm256_permute_pd(ymm4, 0x5); - ymm6 = _mm256_permute_pd(ymm6, 0x5); - ymm14 = _mm256_permute_pd(ymm14, 0x5); - - ymm8 = _mm256_addsub_pd(ymm8, ymm4); - ymm9 = _mm256_addsub_pd(ymm9, ymm6); - ymm10 = _mm256_addsub_pd(ymm10, ymm14); - // alpha, beta multiplication. - ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm8, ymm0); - ymm14 = _mm256_mul_pd(ymm8, ymm14); - ymm8 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm9, ymm0); - ymm14 = _mm256_mul_pd(ymm9, ymm14); - ymm9 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm10, ymm0); - ymm14 = _mm256_mul_pd(ymm10, ymm14); - ymm10 = _mm256_hsub_pd(ymm15, ymm14); - - ymm2 = _mm256_broadcast_sd((double const *) - &beta_cast->real); - ymm3 = _mm256_broadcast_sd((double const *) - &beta_cast->imag); - - BLIS_SET_YMM_REG_ZEROS - xmm0 = _mm_setzero_pd(); - - if(is_beta_non_zero) - { - // multiply C by beta and accumulate col 1. 
- xmm0 = _mm_loadu_pd((double const *)(tC)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - - ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - - xmm0 = _mm_loadu_pd((double const *)(tC + ldc)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - - ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); - ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); - - xmm0 = _mm_loadu_pd((double const *)(tC + ldc * 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - - ymm18 = _mm256_fmadd_pd(ymm0, ymm2, ymm18); - ymm19 = _mm256_fmadd_pd(ymm0, ymm3, ymm19); - - } - ymm5 = _mm256_permute_pd(ymm5, 0x5); - ymm15 = _mm256_permute_pd(ymm15, 0x5); - ymm19 = _mm256_permute_pd(ymm19, 0x5); - - ymm4 = _mm256_addsub_pd(ymm4, ymm5); - ymm14 = _mm256_addsub_pd(ymm14, ymm15); - ymm18 = _mm256_addsub_pd(ymm18, ymm19); - - ymm8 = _mm256_add_pd(ymm8, ymm4); - ymm9 = _mm256_add_pd(ymm9, ymm14); - ymm10 = _mm256_add_pd(ymm10, ymm18); - - xmm0 = _mm256_extractf128_pd(ymm8, 0); - _mm_storeu_pd((double *)tC, xmm0); - - tC += ldc; - - xmm0 = _mm256_extractf128_pd(ymm9, 0); - _mm_storeu_pd((double *)tC, xmm0); - - tC += ldc; - xmm0 = _mm256_extractf128_pd(ymm10, 0); - _mm_storeu_pd((double *)tC, xmm0); - } - n_remainder = N - col_idx; - if (n_remainder == 2) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - - // clear scratch registers. - - - BLIS_SET_ALL_YMM_REG_ZEROS - xmm0 = _mm_setzero_pd(); - - double *tptr = (double *)tB; - if(conjtransa && conjtransb) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - xmm0 = _mm_loadu_pd((double const *)(tA)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - - tptr += tb_inc_row*2; - tA += lda; - } - } - else if(conjtransa) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. 
- xmm0 = _mm_loadu_pd((double const *)(tA)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - - tptr += tb_inc_row*2; - tA += lda; - } - } - else if(conjtransb) - { - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - xmm0 = _mm_loadu_pd((double const *)(tA)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - - tptr += tb_inc_row*2; - tA += lda; - } - } - else - { - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - xmm0 = _mm_loadu_pd((double const *)(tA)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - - tptr += tb_inc_row*2; - tA += lda; - } - } - ymm4 = _mm256_permute_pd(ymm4, 0x5); - ymm6 = _mm256_permute_pd(ymm6, 0x5); - - ymm8 = _mm256_addsub_pd(ymm8, ymm4); - ymm9 = _mm256_addsub_pd(ymm9, ymm6); - - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); - ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm8, ymm0); - ymm14 = _mm256_mul_pd(ymm8, ymm14); - ymm8 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm9, ymm0); - ymm14 = _mm256_mul_pd(ymm9, ymm14); - ymm9 = _mm256_hsub_pd(ymm15, ymm14); - - - - BLIS_SET_YMM_REG_ZEROS - xmm0 = _mm_setzero_pd(); - - - ymm2 = _mm256_broadcast_sd((double const *) - &beta_cast->real); - ymm3 = _mm256_broadcast_sd((double const *) - &beta_cast->imag); - - if(is_beta_non_zero) - { - // multiply C by beta and accumulate col 1. 
- xmm0 = _mm_loadu_pd((double const *)(tC)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - - ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - - xmm0 = _mm_loadu_pd((double const *)(tC + ldc)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - - ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); - ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); - } - ymm5 = _mm256_permute_pd(ymm5, 0x5); - ymm15 = _mm256_permute_pd(ymm15, 0x5); - - ymm4 = _mm256_addsub_pd(ymm4, ymm5); - ymm14 = _mm256_addsub_pd(ymm14, ymm15); - - ymm8 = _mm256_add_pd(ymm8, ymm4); - ymm9 = _mm256_add_pd(ymm9, ymm14); - - xmm0 = _mm256_extractf128_pd(ymm8, 0); - _mm_storeu_pd((double *)tC, xmm0); - tC += ldc; - xmm0 = _mm256_extractf128_pd(ymm9, 0); - _mm_storeu_pd((double *)tC, xmm0); - } - if (n_remainder == 1) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - - // clear scratch registers. - - BLIS_SET_ALL_YMM_REG_ZEROS - - xmm0 = _mm_setzero_pd(); - - double *tptr = (double *)tB; - if(conjtransa && conjtransb) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - xmm0 = _mm_loadu_pd((double const *)(tA)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - tptr += tb_inc_row*2; - tA += lda; - } - } - else if(conjtransa) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - xmm0 = _mm_loadu_pd((double const *)(tA)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - tptr += tb_inc_row*2; - tA += lda; - } - } - else if(conjtransb) - { - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - xmm0 = _mm_loadu_pd((double const *)(tA)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - tptr += tb_inc_row*2; - tA += lda; - } - - } - else - { - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. 
- ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - xmm0 = _mm_loadu_pd((double const *)(tA)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - tptr += tb_inc_row*2; - tA += lda; - } - - } - ymm4 = _mm256_permute_pd(ymm4, 0x5); - - ymm8 = _mm256_addsub_pd(ymm8, ymm4); - - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); - ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm8, ymm0); - ymm14 = _mm256_mul_pd(ymm8, ymm14); - ymm8 = _mm256_hsub_pd(ymm15, ymm14); - - - - BLIS_SET_YMM_REG_ZEROS - xmm0 = _mm_setzero_pd(); - - ymm2 = _mm256_broadcast_sd((double const *) - &beta_cast->real); - ymm3 = _mm256_broadcast_sd((double const *) - &beta_cast->imag); - - if(is_beta_non_zero) - { - // multiply C by beta and accumulate col 1. - xmm0 = _mm_loadu_pd((double const *)(tC)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - - ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - } - ymm5 = _mm256_permute_pd(ymm5, 0x5); - - ymm4 = _mm256_addsub_pd(ymm4, ymm5); - - ymm8 = _mm256_add_pd(ymm8, ymm4); - - xmm0 = _mm256_extractf128_pd(ymm8, 0); - _mm_storeu_pd((double *)tC, xmm0); - - } - } - // Return the buffer to pool - if ((required_packing_A == 1) && bli_mem_is_alloc( &local_mem_buf_A_s )) { -#ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_zgemm_small(): releasing mem pool block\n" ); -#endif - bli_membrk_release(&rntm, - &local_mem_buf_A_s); - } - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return BLIS_SUCCESS; - } - else - { - AOCL_DTL_TRACE_EXIT_ERR( - AOCL_DTL_LEVEL_INFO, - "Invalid dimesions for small gemm." - ); - return BLIS_NONCONFORMAL_DIMENSIONS; - } -}; + if(L && K ) + { + guint_t lda = bli_obj_col_stride( a ); // column stride of matrix OP(A). + guint_t ldb = bli_obj_col_stride( b ); // column stride of matrix OP(B). + guint_t ldc = bli_obj_col_stride( c ); // column stride of matrix C + guint_t row_idx, col_idx, k; + dcomplex *A = bli_obj_buffer_at_off(a); //pointer to elements of Matrix A + dcomplex *B = bli_obj_buffer_at_off(b); //pointer to elements of Matrix B + dcomplex *C = bli_obj_buffer_at_off(c); //pointer to elements of Matrix C -static err_t bli_zgemm_small_At - ( - obj_t* alpha, - obj_t* a, - obj_t* b, - obj_t* beta, - obj_t* c, - cntx_t* cntx, - cntl_t* cntl - ) -{ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO); - - bool conjtransa = bli_obj_has_conj(a); - bool conjtransb = bli_obj_has_conj(b); - - gint_t M = bli_obj_length( c ); // number of rows of Matrix C - gint_t N = bli_obj_width( c ); // number of columns of Matrix C - gint_t K = bli_obj_width_after_trans( a ); // number of columns of OP(A) - - - if (N<3) //Implemenation assumes that N is atleast 3. 
- { - AOCL_DTL_TRACE_EXIT_ERR( - AOCL_DTL_LEVEL_INFO, - "N < 3, cannot be processed by small gemm" - ); - return BLIS_NOT_YET_IMPLEMENTED; - } - - if( M && N && K ) - { - guint_t lda = bli_obj_col_stride( a ); // column stride of matrix OP(A) - guint_t ldb = bli_obj_col_stride( b ); // column stride of matrix OP(B) - guint_t ldc = bli_obj_col_stride( c ); // column stride of matrix C - guint_t row_idx, col_idx, k; - dcomplex *A = bli_obj_buffer_at_off(a); //pointer to elements of Matrix A - dcomplex *B = bli_obj_buffer_at_off(b); //pointer to elements of Matrix B - dcomplex *C = bli_obj_buffer_at_off(c); //pointer to elements of Matrix C - - dcomplex *tA = A, *tB = B, *tC = C;//, *tA_pack; - dcomplex *tA_packed; // temprorary pointer to hold packed A memory pointer - guint_t row_idx_packed; //packed A memory row index - guint_t lda_packed; //lda of packed A - dim_t tb_inc_row = 1; // row stride of matrix B - dim_t tb_inc_col = ldb; // column stride of matrix B - - dcomplex *alpha_cast, *beta_cast; // alpha, beta multiples - alpha_cast = bli_obj_buffer_for_1x1(BLIS_DCOMPLEX, alpha); - beta_cast = bli_obj_buffer_for_1x1(BLIS_DCOMPLEX, beta); - - gint_t required_packing_A = 1; - mem_t local_mem_buf_A_s; - dcomplex *D_A_pack = NULL; - rntm_t rntm; - - if( bli_obj_has_trans( b ) ) - { - tb_inc_col = 1; // switch row and column strides - tb_inc_row = ldb; - } - - __m256d ymm4, ymm5, ymm6, ymm7; - __m256d ymm8, ymm9, ymm10, ymm11; - __m256d ymm12, ymm13, ymm14, ymm15; - __m256d ymm16, ymm17, ymm18, ymm19, ymm20, ymm21; - __m256d ymm0, ymm1, ymm2, ymm3; - - gint_t n_remainder; // If the N is non multiple of 3.(N%3) - gint_t m_remainder; // If the M is non multiple of 16.(M%16) - - //checking whether beta value is zero. - //if true, we should perform C=alpha * A*B operation - //instead of C = beta * C + alpha * (A * B) - bool is_beta_non_zero = 0; - if(!bli_obj_equals(beta, &BLIS_ZERO)) - is_beta_non_zero = 1; - - /* - * This function was using global array to pack part of A input when - * needed. - * However, using this global array make the function non-reentrant. - * Instead of using a global array we should allocate buffer for each - * invocation. - * Since the buffer size is too big or stack and doing malloc every time - * will be too expensive, - * better approach is to get the buffer from the pre-allocated pool and - * return - * it the pool once we are doing. - * - * In order to get the buffer from pool, we need access to memory broker, - * currently this function is not invoked in such a way that it can - * receive - * the memory broker (via rntm). Following hack will get the global memory - * broker that can be use it to access the pool. - * - * Note there will be memory allocation at least on first innovation - * as there will not be any pool created for this size. - * Subsequent invocations will just reuse the buffer from the pool. - */ - - bli_rntm_init_from_global( &rntm ); - bli_rntm_set_num_threads_only( 1, &rntm ); - bli_membrk_rntm_set_membrk( &rntm ); - - // Get the current size of the buffer pool for A block packing. - // We will use the same size to avoid pool re-initliazaton - siz_t buffer_size = bli_pool_block_size( - bli_membrk_pool(bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), - bli_rntm_membrk(&rntm))); - - // - // This kernel assumes that "A" will be unpackged if N <= 3. 
-    // Usually this range (N <= 3) is handled by SUP, however,
-    // if SUP is disabled or for any other condition if we do
-    // enter this kernel with N <= 3, we want to make sure that
-    // "A" remains unpacked.
-    //
-    // If this check is removed it will result in the crash as
-    // reported in CPUPL-587.
-    //
-
-    if ((N < 3) || ((Z_MR * K) << 3) > buffer_size)
-    {
-      required_packing_A = 0;
-      return BLIS_NOT_YET_IMPLEMENTED;
-    }
-
-    if (required_packing_A == 1)
-    {
+      dcomplex *tA = A, *tB = B, *tC = C;//, *tA_pack;
+      dcomplex *tA_packed; //temporary pointer to hold packed A memory pointer
+      guint_t row_idx_packed; //packed A memory row index
+      guint_t lda_packed; //lda of packed A
+      guint_t col_idx_start; //starting index after A matrix is packed.
+      dim_t tb_inc_row = 1; // row stride of matrix B
+      dim_t tb_inc_col = ldb; // column stride of matrix B
+      __m256d ymm4, ymm5, ymm6, ymm7;
+      __m256d ymm8, ymm9, ymm10, ymm11;
+      __m256d ymm12, ymm13, ymm14, ymm15;
+      __m256d ymm16, ymm17, ymm18, ymm19, ymm20, ymm21;
+      __m256d ymm0, ymm1, ymm2, ymm3;
+
+      gint_t n_remainder; // If N is not a multiple of 3 (N%3)
+      gint_t m_remainder; // If M is not a multiple of 4 (M%4)
+
+      dcomplex *alpha_cast, *beta_cast; // alpha, beta multiples
+      alpha_cast = bli_obj_buffer_for_1x1(BLIS_DCOMPLEX, alpha);
+      beta_cast = bli_obj_buffer_for_1x1(BLIS_DCOMPLEX, beta);
+
+      gint_t required_packing_A = 1;
+      mem_t local_mem_buf_A_s;
+      dcomplex *D_A_pack = NULL;
+      rntm_t rntm;
+
+      //update the pointer math if matrix B needs to be transposed.
+      if (bli_obj_has_trans( b ))
+      {
+        tb_inc_col = 1; //switch row and column strides
+        tb_inc_row = ldb;
+      }
+
+      //checking whether beta value is zero.
+      //if true, we should perform C = alpha * A * B operation
+      //instead of C = beta * C + alpha * (A * B)
+      bool is_beta_non_zero = 0;
+      if(!bli_obj_equals(beta, &BLIS_ZERO))
+        is_beta_non_zero = 1;
+
+      /*
+       * This function was using a global array to pack part of the A input
+       * when needed. However, using this global array makes the function
+       * non-reentrant. Instead of using a global array we should allocate
+       * a buffer for each invocation. Since the buffer size is too big for
+       * the stack and doing a malloc every time would be too expensive, the
+       * better approach is to get the buffer from the pre-allocated pool
+       * and return it to the pool once we are done.
+       *
+       * In order to get the buffer from the pool, we need access to the
+       * memory broker; currently this function is not invoked in such a way
+       * that it can receive the memory broker (via rntm). The following hack
+       * gets the global memory broker, which can then be used to access the
+       * pool.
+       *
+       * Note that there will be a memory allocation at least on the first
+       * invocation, as there will not be any pool created for this size yet.
+       * Subsequent invocations will just reuse the buffer from the pool.
+       */
+
+      bli_rntm_init_from_global( &rntm );
+      bli_rntm_set_num_threads_only( 1, &rntm );
+      bli_membrk_rntm_set_membrk( &rntm );
+
+      // Get the current size of the buffer pool for A block packing.
+      // We will use the same size to avoid pool re-initialization.
+      siz_t buffer_size = bli_pool_block_size(
+          bli_membrk_pool(bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK),
+            bli_rntm_membrk(&rntm)));
+
+      //
+      // This kernel assumes that "A" will be unpacked if N <= 3.
+      // Usually this range (N <= 3) is handled by SUP, however,
+      // if SUP is disabled or for any other condition if we do
+      // enter this kernel with N <= 3, we want to make sure that
+      // "A" remains unpacked.
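/*
 * A minimal sketch of the pool-buffer lifecycle described in the comment
 * above, using only the broker/pool helpers that already appear in this
 * file (the acquire happens just below; the matching release sits at the
 * end of the kernel):
 *
 *   rntm_t rntm;
 *   mem_t  local_mem_buf_A_s;
 *
 *   bli_rntm_init_from_global( &rntm );
 *   bli_rntm_set_num_threads_only( 1, &rntm );
 *   bli_membrk_rntm_set_membrk( &rntm );
 *
 *   siz_t buffer_size = bli_pool_block_size(
 *       bli_membrk_pool( bli_packbuf_index( BLIS_BITVAL_BUFFER_FOR_A_BLOCK ),
 *                        bli_rntm_membrk( &rntm ) ) );
 *
 *   bli_membrk_acquire_m( &rntm, buffer_size,
 *                         BLIS_BITVAL_BUFFER_FOR_A_BLOCK,
 *                         &local_mem_buf_A_s );
 *   dcomplex *D_A_pack = bli_mem_buffer( &local_mem_buf_A_s );
 *
 *   // ... pack panels of A into D_A_pack and run the compute loops ...
 *
 *   if ( bli_mem_is_alloc( &local_mem_buf_A_s ) )
 *     bli_membrk_release( &rntm, &local_mem_buf_A_s );
 */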
+      //
+
+      if ((N < 3) || ((Z_MR * K) << 4) > buffer_size)
+      {
+        required_packing_A = 0;
+      }
+
+      if (required_packing_A == 1)
+      {
 #ifdef BLIS_ENABLE_MEM_TRACING
-        printf( "bli_dgemm_small: Requesting mem pool block of size %lu\n",
-                buffer_size);
+        printf( "bli_zgemm_small: Requesting mem pool block of size %lu\n",
+            buffer_size);
 #endif
-        // Get the buffer from the pool.
-        bli_membrk_acquire_m(&rntm,
-                buffer_size,
-                BLIS_BITVAL_BUFFER_FOR_A_BLOCK,
-                &local_mem_buf_A_s);
-
-        D_A_pack = bli_mem_buffer(&local_mem_buf_A_s);
-    }
-
-    /*
-     * The computation loop runs for D_MRxN columns of C matrix, thus
-     * accessing the D_MRxK A matrix data and KxNR B matrix data.
-     * The computation is organized as inner loops of dimension D_MRxNR.
-     */
-    // Process D_MR rows of C matrix at a time.
-    for (row_idx = 0; (row_idx + (Z_MR - 1)) < M; row_idx += Z_MR)
-    {
-
-      tA = A + row_idx * lda;
-      tA_packed = D_A_pack;
-      lda_packed = Z_MR;
-
-      // Pack 16xk of matrix A into buffer
-      // continuous access for A and strided stores to B
-      for(inc_t x = 0; (x) < 2; x += 1)
-      {
-        dcomplex* tA_temp = tA;
-
-        for(k = 0; (k+1) < K; k += 2)
-        {
-          ymm0 = _mm256_loadu_pd((double const *)
-              (tA_temp + 0 * lda));
-          ymm2 = _mm256_loadu_pd((double const *)
-              (tA_temp + 1 * lda));
-
-          ymm6 = _mm256_permute2f128_pd(ymm0,ymm2,0x20);
-          ymm7 = _mm256_permute2f128_pd(ymm0,ymm2,0x31);
-
-          _mm256_storeu_pd((double *)
-              (tA_packed + 0 * lda_packed),
-              ymm6);
-          _mm256_storeu_pd((double *)
-              (tA_packed + 1 * lda_packed),
-              ymm7);
-
-          tA_temp += 2;
-          tA_packed += 2 * lda_packed;
-        }
-
-        for(; k < K; k += 1)
-        {
-          tA_packed[0].real = tA_temp[0 * lda].real;
-          tA_packed[0].imag = tA_temp[0 * lda].imag;
-          tA_packed[1].real = tA_temp[1 * lda].real;
-          tA_packed[1].imag = tA_temp[1 * lda].imag;
-
-          tA_temp += 1;
-          tA_packed += lda_packed;
-        }
-
-        tA += 2 * lda;
-        tA_packed = D_A_pack + (x + 1)*2;
-      }
-
-      tA_packed = D_A_pack;
-      row_idx_packed = 0;
-      lda_packed = Z_MR;
-
-      // Process NR columns of C matrix at a time.
-      for (col_idx = 0; (col_idx + (NR - 1)) < N; col_idx += NR)
-      {
-        //pointer math to point to proper memory
-        tC = C + ldc * col_idx + row_idx;
-        tB = B + tb_inc_col * col_idx;
-        tA = tA_packed + row_idx_packed;
+        // Get the buffer from the pool.
+        bli_membrk_acquire_m(&rntm,
+            buffer_size,
+            BLIS_BITVAL_BUFFER_FOR_A_BLOCK,
+            &local_mem_buf_A_s);
+
+        D_A_pack = bli_mem_buffer(&local_mem_buf_A_s);
+      }
+
+      /*
+       * The computation loop runs for Z_MRxN columns of C matrix, thus
+       * accessing the Z_MRxK A matrix data and KxNR B matrix data.
+       * The computation is organized as inner loops of dimension Z_MRxNR.
+       */
+      // Process Z_MR rows of C matrix at a time.
+      for (row_idx = 0; (row_idx + (Z_MR - 1)) < M; row_idx += Z_MR)
+      {
+        col_idx_start = 0;
+        tA_packed = A;
+        row_idx_packed = row_idx;
+        lda_packed = lda;
+
+        /**
+         * This is part of the pack and compute optimization.
+         * During the first column iteration, we store the accessed A
+         * matrix into contiguous static memory. This helps to keep the A
+         * matrix in cache and avoids TLB misses.
+ */ + if (required_packing_A) + { + col_idx = 0; + + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + tA_packed = D_A_pack; #ifdef BLIS_ENABLE_PREFETCH - _mm_prefetch((char*)(tC + 0), _MM_HINT_T0); - _mm_prefetch((char*)(tC + 8), _MM_HINT_T0); - _mm_prefetch((char*)(tC + ldc), _MM_HINT_T0); - _mm_prefetch((char*)(tC + ldc + 8), _MM_HINT_T0); - _mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0); - _mm_prefetch((char*)(tC + 2 * ldc + 8), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 0), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 8), _MM_HINT_T0); + _mm_prefetch((char*)(tC + ldc), _MM_HINT_T0); + _mm_prefetch((char*)(tC + ldc + 8), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 2 * ldc + 8), _MM_HINT_T0); #endif - // clear scratch registers. - - BLIS_SET_ALL_YMM_REG_ZEROS - - double *tptr = (double *)tB; - if(conjtransa && conjtransb) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - ymm1 = _mm256_loadu_pd((double const *) - (tA + 2)); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - ymm1 = _mm256_mul_pd(ymm1, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda_packed; - } - } - else if(conjtransa) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. 
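/*
 * A short scalar model of the conjugation handling used in these inner
 * loops, assuming only the sign-mask constants already set up here. A
 * dcomplex occupies a (real, imag) pair of doubles in a YMM lane, so
 * multiplying the loaded A vector by ymm20 = {1.0, -1.0, 1.0, -1.0}
 * negates the imaginary parts, i.e. conjugates the A elements held in the
 * register. B is broadcast as separate real and imaginary parts, so
 * multiplying the broadcast imaginary part (ymm3) by
 * ymm21 = {-1.0, -1.0, -1.0, -1.0} conjugates the B element. In scalar
 * terms, each k-iteration accumulates:
 *
 *   a_eff = conjtransa ? conj(a) : a;   // conj(a) = a.real - i*a.imag
 *   b_eff = conjtransb ? conj(b) : b;   // conj(b) = b.real - i*b.imag
 *   acc  += a_eff * b_eff;
 */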
- ymm0 = _mm256_loadu_pd((double const *)tA); - ymm1 = _mm256_loadu_pd((double const *) - (tA + 2)); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - ymm1 = _mm256_mul_pd(ymm1, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda_packed; - } - } - else if(conjtransb) - { - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - ymm1 = _mm256_loadu_pd((double const *) - (tA + 2)); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda_packed; - } - } - else - { - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. 
- ymm0 = _mm256_loadu_pd((double const *)tA); - ymm1 = _mm256_loadu_pd((double const *) - (tA + 2)); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda_packed; - } - } - - ymm4 = _mm256_permute_pd(ymm4, 0x5); - ymm5 = _mm256_permute_pd(ymm5, 0x5); - ymm6 = _mm256_permute_pd(ymm6, 0x5); - ymm7 = _mm256_permute_pd(ymm7, 0x5); - ymm14 = _mm256_permute_pd(ymm14, 0x5); - ymm15 = _mm256_permute_pd(ymm15, 0x5); - - ymm8 = _mm256_addsub_pd(ymm8, ymm4); - ymm11 = _mm256_addsub_pd(ymm11, ymm5); - ymm9 = _mm256_addsub_pd(ymm9, ymm6); - ymm12 = _mm256_addsub_pd(ymm12, ymm7); - ymm10 = _mm256_addsub_pd(ymm10, ymm14); - ymm13 = _mm256_addsub_pd(ymm13, ymm15); - - // alpha, beta multiplication. - ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm8, ymm0); - ymm14 = _mm256_mul_pd(ymm8, ymm14); - ymm8 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm9, ymm0); - ymm14 = _mm256_mul_pd(ymm9, ymm14); - ymm9 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm10, ymm0); - ymm14 = _mm256_mul_pd(ymm10, ymm14); - ymm10 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm11, ymm0); - ymm14 = _mm256_mul_pd(ymm11, ymm14); - ymm11 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm12, ymm0); - ymm14 = _mm256_mul_pd(ymm12, ymm14); - ymm12 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm13, ymm0); - ymm14 = _mm256_mul_pd(ymm13, ymm14); - ymm13 = _mm256_hsub_pd(ymm15, ymm14); - - ymm2 = _mm256_broadcast_sd((double const *) - &beta_cast->real); - ymm3 = _mm256_broadcast_sd((double const *) - (&beta_cast->imag)); - - - - BLIS_SET_YMM_REG_ZEROS - if(is_beta_non_zero) - { - // multiply C by beta and accumulate col 1. 
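/*
 * Worked form of the alpha/beta scaling performed around here, written in
 * ordinary complex arithmetic (acc is the accumulated A*B value carried in
 * a register, c is the value loaded from C):
 *
 *   alpha*acc = (acc.real*alpha.real - acc.imag*alpha.imag)
 *             + i*(acc.real*alpha.imag + acc.imag*alpha.real)
 *
 *   beta*c    = (c.real*beta.real - c.imag*beta.imag)
 *             + i*(c.imag*beta.real + c.real*beta.imag)
 *
 * The permute(0x5)/mul/hsub sequence produces alpha*acc per 128-bit lane,
 * the fmadd/permute/addsub sequence produces beta*c, and the final add and
 * store write alpha*acc + beta*c (just alpha*acc when is_beta_non_zero is
 * false, since the beta-side registers are cleared by
 * BLIS_SET_YMM_REG_ZEROS).
 */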
- ymm0 = _mm256_loadu_pd((double const *)tC); - ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - - ymm0 = _mm256_loadu_pd((double const *)(tC + 2)); - ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6); - ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7); - - // col 2 - ymm0 = _mm256_loadu_pd((double const *) - (tC + ldc)); - ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); - ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); - - ymm0 = _mm256_loadu_pd((double const *) - (tC + ldc + 2)); - ymm16 = _mm256_fmadd_pd(ymm0, ymm2, ymm16); - ymm17 = _mm256_fmadd_pd(ymm0, ymm3, ymm17); - - // col 3 - ymm0 = _mm256_loadu_pd((double const *) - (tC + (ldc * 2))); - ymm18 = _mm256_fmadd_pd(ymm0, ymm2, ymm18); - ymm19 = _mm256_fmadd_pd(ymm0, ymm3, ymm19); - - ymm0 = _mm256_loadu_pd((double const *) - (tC + (ldc * 2) + 2)); - ymm20 = _mm256_fmadd_pd(ymm0, ymm2, ymm20); - ymm21 = _mm256_fmadd_pd(ymm0, ymm3, ymm21); - } - ymm5 = _mm256_permute_pd(ymm5, 0x5); - ymm7 = _mm256_permute_pd(ymm7, 0x5); - ymm15 = _mm256_permute_pd(ymm15, 0x5); - ymm17 = _mm256_permute_pd(ymm17, 0x5); - ymm19 = _mm256_permute_pd(ymm19, 0x5); - ymm21 = _mm256_permute_pd(ymm21, 0x5); - - ymm4 = _mm256_addsub_pd(ymm4, ymm5); - ymm6 = _mm256_addsub_pd(ymm6, ymm7); - ymm14 = _mm256_addsub_pd(ymm14, ymm15); - ymm16 = _mm256_addsub_pd(ymm16, ymm17); - ymm18 = _mm256_addsub_pd(ymm18, ymm19); - ymm20 = _mm256_addsub_pd(ymm20, ymm21); - - ymm8 = _mm256_add_pd(ymm8, ymm4); - ymm11 = _mm256_add_pd(ymm11, ymm6); - ymm9 = _mm256_add_pd(ymm9, ymm14); - ymm12 = _mm256_add_pd(ymm12, ymm16); - ymm10 = _mm256_add_pd(ymm10, ymm18); - ymm13 = _mm256_add_pd(ymm13, ymm20); - - _mm256_storeu_pd((double *)tC, ymm8); - _mm256_storeu_pd((double *)(tC + 2), ymm11); - - tC += ldc; - - _mm256_storeu_pd((double *)tC, ymm9); - _mm256_storeu_pd((double *)(tC + 2), ymm12); - - tC += ldc; - - _mm256_storeu_pd((double *)tC, ymm10); - _mm256_storeu_pd((double *)(tC + 2), ymm13); - - } - n_remainder = N - col_idx; - - // if the N is not multiple of 3. - // handling edge case. - if (n_remainder == 2) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = tA_packed + row_idx_packed; - - // clear scratch registers. - - - BLIS_SET_ALL_YMM_REG_ZEROS - double *tptr = (double *)tB; - - if(conjtransa && conjtransb) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are multiplied - //with matrix A columns. 
- ymm0 = _mm256_loadu_pd((double const*)tA); - ymm1 = _mm256_loadu_pd((double const*) - (tA + 2)); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - ymm1 = _mm256_mul_pd(ymm1, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - tptr += tb_inc_row*2; - tA += lda_packed; - - } - } - else if(conjtransa) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const*)tA); - ymm1 = _mm256_loadu_pd((double const*) - (tA + 2)); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - ymm1 = _mm256_mul_pd(ymm1, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - tptr += tb_inc_row*2; - tA += lda_packed; - } - } - else if(conjtransb) - { - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const*)tA); - ymm1 = _mm256_loadu_pd((double const*) - (tA + 2)); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - tptr += tb_inc_row*2; - tA += lda_packed; - } - } - else - { - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. 
- ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const*)tA); - ymm1 = _mm256_loadu_pd((double const*) - (tA + 2)); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - tptr += tb_inc_row*2; - tA += lda_packed; - } - } - - - ymm4 = _mm256_permute_pd(ymm4, 0x5); - ymm5 = _mm256_permute_pd(ymm5, 0x5); - ymm6 = _mm256_permute_pd(ymm6, 0x5); - ymm7 = _mm256_permute_pd(ymm7, 0x5); - - ymm8 = _mm256_addsub_pd(ymm8, ymm4); - ymm11 = _mm256_addsub_pd(ymm11, ymm5); - ymm9 = _mm256_addsub_pd(ymm9, ymm6); - ymm12 = _mm256_addsub_pd(ymm12, ymm7); - - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); - ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm8, ymm0); - ymm14 = _mm256_mul_pd(ymm8, ymm14); - ymm8 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm11, ymm0); - ymm14 = _mm256_mul_pd(ymm11, ymm14); - ymm11 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm9, ymm0); - ymm14 = _mm256_mul_pd(ymm9, ymm14); - ymm9 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm12, ymm0); - ymm14 = _mm256_mul_pd(ymm12, ymm14); - ymm12 = _mm256_hsub_pd(ymm15, ymm14); - - - - BLIS_SET_YMM_REG_ZEROS - ymm2 = _mm256_broadcast_sd((double const *) - &beta_cast->real); - ymm3 = _mm256_broadcast_sd((double const *) - &beta_cast->imag); - - if(is_beta_non_zero) - { - // multiply C by beta and accumulate col 1. 
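/*
 * Annotation (editorial sketch, not part of the kernel): C is loaded and
 * scaled by beta only under is_beta_non_zero; when beta is zero the BLAS
 * contract allows C to be uninitialised, so the kernel stores alpha*A*B
 * without ever reading it.  The beta*C product reuses the split-accumulator
 * idea: one fmadd chain multiplies C by beta.real, another by beta.imag,
 * and a permute(0x5) + addsub recombines them before they are added to the
 * alpha*A*B results.  Scalar sketch for one element (illustrative names):
 *
 *     double c_re = 1.0,  c_im = 2.0;     // element of C
 *     double beta_re = 0.5, beta_im = 0.25;
 *     double acc_re = 3.0, acc_im = 4.0;  // alpha*A*B result for this element
 *     double bc_re = c_re * beta_re - c_im * beta_im;   // beta*C, real
 *     double bc_im = c_im * beta_re + c_re * beta_im;   // beta*C, imag
 *     acc_re += bc_re;                    // value finally stored back to C
 *     acc_im += bc_im;
 */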
- ymm0 = _mm256_loadu_pd((double const *)tC); - ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - - ymm0 = _mm256_loadu_pd((double const *)(tC + 2)); - ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6); - ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7); - - ymm0 = _mm256_loadu_pd((double const *)(tC + ldc)); - ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); - ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); - - ymm0 = _mm256_loadu_pd((double const *) - (tC + ldc + 2)); - ymm16 = _mm256_fmadd_pd(ymm0, ymm2, ymm16); - ymm17 = _mm256_fmadd_pd(ymm0, ymm3, ymm17); - - } - ymm5 = _mm256_permute_pd(ymm5, 0x5); - ymm7 = _mm256_permute_pd(ymm7, 0x5); - ymm15 = _mm256_permute_pd(ymm15, 0x5); - ymm17 = _mm256_permute_pd(ymm17, 0x5); - - ymm4 = _mm256_addsub_pd(ymm4, ymm5); - ymm6 = _mm256_addsub_pd(ymm6, ymm7); - ymm14 = _mm256_addsub_pd(ymm14, ymm15); - ymm16 = _mm256_addsub_pd(ymm16, ymm17); - - ymm8 = _mm256_add_pd(ymm8, ymm4); - ymm11 = _mm256_add_pd(ymm11, ymm6); - ymm9 = _mm256_add_pd(ymm9, ymm14); - ymm12 = _mm256_add_pd(ymm12, ymm16); - - _mm256_storeu_pd((double *)(tC + 0), ymm8); - _mm256_storeu_pd((double *)(tC + 2), ymm11); - tC += ldc; - _mm256_storeu_pd((double *)tC, ymm9); - _mm256_storeu_pd((double *)(tC + 2), ymm12); - } - // if the N is not multiple of 3. - // handling edge case. - if (n_remainder == 1) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = tA_packed + row_idx_packed; - - // clear scratch registers. - BLIS_SET_ALL_YMM_REG_ZEROS - double *tptr = (double *)tB; - - if(conjtransa && conjtransb) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - tptr += tb_inc_row*2; - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - ymm1 = _mm256_loadu_pd((double const *)(tA + 2)); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - ymm1 = _mm256_mul_pd(ymm1, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - tA += lda_packed; - } - } - else if(conjtransa) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd((double const *)(tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd((double const *)(tptr + tb_inc_col * 0 + 1)); - tptr += tb_inc_row*2; - - //broadcasted matrix B elements are multiplied - //with matrix A columns. 
- ymm0 = _mm256_loadu_pd((double const *)tA); - ymm1 = _mm256_loadu_pd((double const *) - (tA + 2)); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - ymm1 = _mm256_mul_pd(ymm1, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - tA += lda_packed; - } - } - else if(conjtransb) - { - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - tptr += tb_inc_row*2; - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - ymm1 = _mm256_loadu_pd((double const *) - (tA + 2)); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - tA += lda_packed; - } - } - else - { - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - tptr += tb_inc_row*2; - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - ymm1 = _mm256_loadu_pd((double const *) - (tA + 2)); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - tA += lda_packed; - } - } - ymm4 = _mm256_permute_pd(ymm4, 0x5); - ymm5 = _mm256_permute_pd(ymm5, 0x5); - - ymm8 = _mm256_addsub_pd(ymm8, ymm4); - ymm11 = _mm256_addsub_pd(ymm11, ymm5); - - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); - ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm8, ymm0); - ymm14 = _mm256_mul_pd(ymm8, ymm14); - ymm8 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm11, ymm0); - ymm14 = _mm256_mul_pd(ymm11, ymm14); - ymm11 = _mm256_hsub_pd(ymm15, ymm14); - - - - BLIS_SET_YMM_REG_ZEROS - ymm2 = _mm256_broadcast_sd((double const *) - &beta_cast->real); - ymm3 = _mm256_broadcast_sd((double const *) - &beta_cast->imag); - - if(is_beta_non_zero) - { - // multiply C by beta and accumulate col 1. 
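/*
 * Annotation (editorial sketch, not part of the kernel): the stores just
 * below finish the last column block of this row panel; the code that
 * follows then switches to the leftover rows, with dedicated m_remainder
 * branches (3 rows, then 2, and so on) that repeat the same column
 * blocking of 3 columns at a time followed by n_remainder == 2 and 1.
 * A skeleton of the blocking, assuming a panel height of Z_MR = 4
 * complex rows (loop bounds here are illustrative only):
 *
 *     dim_t row_idx = 0, col_idx, m_remainder, n_remainder;
 *     for (; (row_idx + 3) < M; row_idx += 4) {                // full panels
 *         for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) { // 4x3 tiles
 *         }
 *         n_remainder = N - col_idx;                           // 2, then 1
 *     }
 *     m_remainder = M - row_idx;   // 3-, 2-, 1-row tails, same column loop
 *                                  // but with narrower (xmm) loads/stores
 */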
- ymm0 = _mm256_loadu_pd((double const *)tC); - ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - - ymm0 = _mm256_loadu_pd((double const *)(tC + 2)); - ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6); - ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7); - } - ymm5 = _mm256_permute_pd(ymm5, 0x5); - ymm7 = _mm256_permute_pd(ymm7, 0x5); - - ymm4 = _mm256_addsub_pd(ymm4, ymm5); - ymm6 = _mm256_addsub_pd(ymm6, ymm7); - - ymm8 = _mm256_add_pd(ymm8, ymm4); - ymm11 = _mm256_add_pd(ymm11, ymm6); - - _mm256_storeu_pd((double *)tC, ymm8); - _mm256_storeu_pd((double *)(tC + 2), ymm11); - } - } - - m_remainder = M - row_idx; - if ((m_remainder == 3)) - { - m_remainder -= 3; - __m128d xmm0; - - tA = A + row_idx * lda; - tA_packed = D_A_pack; - lda_packed = 3; - { - dcomplex* tA_temp = tA; - - for(k = 0; (k+1) < K; k += 2) - { - ymm0 = _mm256_loadu_pd((double const *) - (tA_temp + 0 * lda)); - ymm2 = _mm256_loadu_pd((double const *) - (tA_temp + 1 * lda)); - ymm3 = _mm256_loadu_pd((double const *) - (tA_temp + 2 * lda)); - - ymm6 = _mm256_permute2f128_pd(ymm0,ymm2,0x20); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm2,0x31); - - _mm256_storeu_pd((double *) - (tA_packed + 0 * lda_packed), - ymm6); - xmm0 = _mm256_extractf128_pd(ymm3, 0); - _mm_storeu_pd((double *) - (tA_packed + 0 * lda_packed + 2), - xmm0); - - _mm256_storeu_pd((double *) - (tA_packed + 1 * lda_packed), - ymm7); - xmm0 = _mm256_extractf128_pd(ymm3, 1); - _mm_storeu_pd((double *) - (tA_packed + 1 * lda_packed + 2), - xmm0); - - tA_temp += 2; - tA_packed += 2 * lda_packed; - } - - for(; k < K; k += 1) - { - tA_packed[0].real = tA_temp[0 * lda].real; - tA_packed[0].imag = tA_temp[0 * lda].imag; - tA_packed[1].real = tA_temp[1 * lda].real; - tA_packed[1].imag = tA_temp[1 * lda].imag; - tA_packed[2].real = tA_temp[2 * lda].real; - tA_packed[2].imag = tA_temp[2 * lda].imag; - - tA_temp += 1; - tA_packed += lda_packed; - } - } - - tA_packed = D_A_pack; - row_idx_packed = 0; - lda_packed = 3; - - for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = tA_packed + row_idx_packed; - - - BLIS_SET_ALL_YMM_REG_ZEROS - xmm0 = _mm_setzero_pd(); - - double *tptr = (double *)tB; - if(conjtransa && conjtransb) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. 
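/*
 * Annotation (editorial sketch, not part of the kernel): with only three
 * complex rows left in this branch, a single 256-bit load would read past
 * the column, so rows 0-1 come from one ymm load and row 2 from a 128-bit
 * load inserted into the low lane of a second ymm register (whose high
 * lane is unused and just carried along); the stores mirror this with
 * _mm256_extractf128_pd + _mm_storeu_pd.  Standalone sketch of the load
 * pattern (needs immintrin.h and AVX; names are illustrative):
 *
 *     double col[6] = { 1, 2, 3, 4, 5, 6 };     // three dcomplex values
 *     __m256d rows01 = _mm256_loadu_pd(col);        // rows 0 and 1
 *     __m128d row2   = _mm_loadu_pd(col + 4);       // row 2 (re, im)
 *     __m256d rows2x = _mm256_insertf128_pd(_mm256_setzero_pd(), row2, 0);
 *     // the kernel inserts into an existing register instead of a zeroed
 *     // one, which is equivalent for the lanes that are actually used
 */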
- ymm0 = _mm256_loadu_pd((double const *)tA); - xmm0 = _mm_loadu_pd((double const *) - (tA + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - ymm1 = _mm256_mul_pd(ymm1, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda_packed; - } - } - else if(conjtransa) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - xmm0 = _mm_loadu_pd((double const *) - (tA + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - ymm1 = _mm256_mul_pd(ymm1, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda_packed; - } - } - else if(conjtransb) - { - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. 
- // This loop is processing D_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - xmm0 = _mm_loadu_pd((double const *) - (tA + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda_packed; - } - } - else - { - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. 
- ymm0 = _mm256_loadu_pd((double const *)tA); - xmm0 = _mm_loadu_pd((double const *) - (tA + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda_packed; - } - } - ymm4 = _mm256_permute_pd(ymm4, 0x5); - ymm5 = _mm256_permute_pd(ymm5, 0x5); - ymm6 = _mm256_permute_pd(ymm6, 0x5); - ymm7 = _mm256_permute_pd(ymm7, 0x5); - ymm14 = _mm256_permute_pd(ymm14, 0x5); - ymm15 = _mm256_permute_pd(ymm15, 0x5); - - ymm8 = _mm256_addsub_pd(ymm8, ymm4); - ymm11 = _mm256_addsub_pd(ymm11, ymm5); - ymm9 = _mm256_addsub_pd(ymm9, ymm6); - ymm12 = _mm256_addsub_pd(ymm12, ymm7); - ymm10 = _mm256_addsub_pd(ymm10, ymm14); - ymm13 = _mm256_addsub_pd(ymm13, ymm15); - // alpha, beta multiplication. - ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm8, ymm0); - ymm14 = _mm256_mul_pd(ymm8, ymm14); - ymm8 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm9, ymm0); - ymm14 = _mm256_mul_pd(ymm9, ymm14); - ymm9 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm10, ymm0); - ymm14 = _mm256_mul_pd(ymm10, ymm14); - ymm10 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm11, ymm0); - ymm14 = _mm256_mul_pd(ymm11, ymm14); - ymm11 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm12, ymm0); - ymm14 = _mm256_mul_pd(ymm12, ymm14); - ymm12 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm13, ymm0); - ymm14 = _mm256_mul_pd(ymm13, ymm14); - ymm13 = _mm256_hsub_pd(ymm15, ymm14); - - ymm2 = _mm256_broadcast_sd((double const *) - &beta_cast->real); - ymm3 = _mm256_broadcast_sd((double const *) - &beta_cast->imag); - - - BLIS_SET_YMM_REG_ZEROS - xmm0 = _mm_setzero_pd(); - - if(is_beta_non_zero) - { - // multiply C by beta and accumulate col 1. 
- ymm0 = _mm256_loadu_pd((double const *)tC); - xmm0 = _mm_loadu_pd((double const *)(tC + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - - ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - ymm6 = _mm256_fmadd_pd(ymm1, ymm2, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - ymm0 = _mm256_loadu_pd((double const *) - (tC + ldc)); - xmm0 = _mm_loadu_pd((double const *) - (tC + ldc + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - - ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); - ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); - ymm16 = _mm256_fmadd_pd(ymm1, ymm2, ymm16); - ymm17 = _mm256_fmadd_pd(ymm1, ymm3, ymm17); - - ymm0 = _mm256_loadu_pd((double const *) - (tC + ldc * 2)); - xmm0 = _mm_loadu_pd((double const *) - (tC + ldc * 2 + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - - ymm18 = _mm256_fmadd_pd(ymm0, ymm2, ymm18); - ymm19 = _mm256_fmadd_pd(ymm0, ymm3, ymm19); - ymm20 = _mm256_fmadd_pd(ymm1, ymm2, ymm20); - ymm21 = _mm256_fmadd_pd(ymm1, ymm3, ymm21); - - } - ymm5 = _mm256_permute_pd(ymm5, 0x5); - ymm7 = _mm256_permute_pd(ymm7, 0x5); - ymm15 = _mm256_permute_pd(ymm15, 0x5); - ymm17 = _mm256_permute_pd(ymm17, 0x5); - ymm19 = _mm256_permute_pd(ymm19, 0x5); - ymm21 = _mm256_permute_pd(ymm21, 0x5); - - ymm4 = _mm256_addsub_pd(ymm4, ymm5); - ymm6 = _mm256_addsub_pd(ymm6, ymm7); - ymm14 = _mm256_addsub_pd(ymm14, ymm15); - ymm16 = _mm256_addsub_pd(ymm16, ymm17); - ymm18 = _mm256_addsub_pd(ymm18, ymm19); - ymm20 = _mm256_addsub_pd(ymm20, ymm21); - - ymm8 = _mm256_add_pd(ymm8, ymm4); - ymm11 = _mm256_add_pd(ymm11, ymm6); - ymm9 = _mm256_add_pd(ymm9, ymm14); - ymm12 = _mm256_add_pd(ymm12, ymm16); - ymm10 = _mm256_add_pd(ymm10, ymm18); - ymm13 = _mm256_add_pd(ymm13, ymm20); - - _mm256_storeu_pd((double *)tC, ymm8); - xmm0 = _mm256_extractf128_pd(ymm11, 0); - _mm_storeu_pd((double *)(tC + 2), xmm0); - - tC += ldc; - - _mm256_storeu_pd((double *)tC, ymm9); - xmm0 = _mm256_extractf128_pd(ymm12, 0); - _mm_storeu_pd((double *)(tC + 2), xmm0); - - tC += ldc; - - _mm256_storeu_pd((double *)tC, ymm10); - xmm0 = _mm256_extractf128_pd(ymm13, 0); - _mm_storeu_pd((double *)(tC + 2), xmm0); - } - n_remainder = N - col_idx; - if (n_remainder == 2) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = tA_packed + row_idx_packed; - - // clear scratch registers. - BLIS_SET_ALL_YMM_REG_ZEROS - xmm0 = _mm_setzero_pd(); - - double *tptr = (double *)tB; - if(conjtransa && conjtransb) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd((tptr + - tb_inc_col - * 0)); - ymm3 = _mm256_broadcast_sd((tptr + - tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. 
- ymm0 = _mm256_loadu_pd((double const *)tA); - xmm0 = _mm_loadu_pd((double const *) - (tA + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - ymm1 = _mm256_mul_pd(ymm1, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - tptr += tb_inc_row*2; - tA += lda_packed; - } - } - - else if(conjtransa) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd((tptr + - tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd((tptr + - tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - xmm0 = _mm_loadu_pd((double const *) - (tA + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - ymm1 = _mm256_mul_pd(ymm1, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - tptr += tb_inc_row*2; - tA += lda_packed; - } - } - - else if(conjtransb) - { - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd((tptr + - tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd((tptr + - tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - xmm0 = _mm_loadu_pd((double const *) - (tA + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - tptr += tb_inc_row*2; - tA += lda_packed; - } - } - else - { - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. 
- ymm2 = _mm256_broadcast_sd((tptr + - tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd((tptr + - tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - xmm0 = _mm_loadu_pd((double const *) - (tA + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - tptr += tb_inc_row*2; - tA += lda_packed; - } - } - ymm4 = _mm256_permute_pd(ymm4, 0x5); - ymm5 = _mm256_permute_pd(ymm5, 0x5); - ymm6 = _mm256_permute_pd(ymm6, 0x5); - ymm7 = _mm256_permute_pd(ymm7, 0x5); - - ymm8 = _mm256_addsub_pd(ymm8, ymm4); - ymm11 = _mm256_addsub_pd(ymm11, ymm5); - ymm9 = _mm256_addsub_pd(ymm9, ymm6); - ymm12 = _mm256_addsub_pd(ymm12, ymm7); - - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); - ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm8, ymm0); - ymm14 = _mm256_mul_pd(ymm8, ymm14); - ymm8 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm11, ymm0); - ymm14 = _mm256_mul_pd(ymm11, ymm14); - ymm11 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm9, ymm0); - ymm14 = _mm256_mul_pd(ymm9, ymm14); - ymm9 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm12, ymm0); - ymm14 = _mm256_mul_pd(ymm12, ymm14); - ymm12 = _mm256_hsub_pd(ymm15, ymm14); - - - BLIS_SET_YMM_REG_ZEROS - xmm0 = _mm_setzero_pd(); - - ymm2 = _mm256_broadcast_sd((double const *) - &beta_cast->real); - ymm3 = _mm256_broadcast_sd((double const *) - &beta_cast->imag); - - if(is_beta_non_zero) - { - // multiply C by beta and accumulate col 1. 
- ymm0 = _mm256_loadu_pd((double const *)tC); - xmm0 = _mm_loadu_pd((double const *)(tC + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - - ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - ymm6 = _mm256_fmadd_pd(ymm1, ymm2, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - ymm0 = _mm256_loadu_pd((double const *) - (tC + ldc)); - xmm0 = _mm_loadu_pd((double const *) - (tC + ldc + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - - ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); - ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); - ymm16 = _mm256_fmadd_pd(ymm1, ymm2, ymm16); - ymm17 = _mm256_fmadd_pd(ymm1, ymm3, ymm17); - - } - ymm5 = _mm256_permute_pd(ymm5, 0x5); - ymm7 = _mm256_permute_pd(ymm7, 0x5); - ymm15 = _mm256_permute_pd(ymm15, 0x5); - ymm17 = _mm256_permute_pd(ymm17, 0x5); - - ymm4 = _mm256_addsub_pd(ymm4, ymm5); - ymm6 = _mm256_addsub_pd(ymm6, ymm7); - ymm14 = _mm256_addsub_pd(ymm14, ymm15); - ymm16 = _mm256_addsub_pd(ymm16, ymm17); - - ymm8 = _mm256_add_pd(ymm8, ymm4); - ymm11 = _mm256_add_pd(ymm11, ymm6); - ymm9 = _mm256_add_pd(ymm9, ymm14); - ymm12 = _mm256_add_pd(ymm12, ymm16); - - _mm256_storeu_pd((double *)tC, ymm8); - xmm0 = _mm256_extractf128_pd(ymm11, 0); - _mm_storeu_pd((double *)(tC + 2), xmm0); - - tC += ldc; - _mm256_storeu_pd((double *)tC, ymm9); - xmm0 = _mm256_extractf128_pd(ymm12, 0); - _mm_storeu_pd((double *)(tC + 2), xmm0); - } - if (n_remainder == 1) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = tA_packed + row_idx_packed; - - // clear scratch registers. - - BLIS_SET_ALL_YMM_REG_ZEROS - xmm0 = _mm_setzero_pd(); - - double *tptr = (double *)tB; - if(conjtransa && conjtransb) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - xmm0 = _mm_loadu_pd((double const *) - (tA + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - ymm1 = _mm256_mul_pd(ymm1, ymm20); - - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - tptr += tb_inc_row*2; - tA += lda_packed; - } - } - - else if(conjtransa) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. 
- ymm0 = _mm256_loadu_pd((double const *)tA); - xmm0 = _mm_loadu_pd((double const *) - (tA + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - ymm1 = _mm256_mul_pd(ymm1, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - tptr += tb_inc_row*2; - tA += lda_packed; - } - } - else if(conjtransb) - { - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - xmm0 = _mm_loadu_pd((double const *) - (tA + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - tptr += tb_inc_row*2; - tA += lda_packed; - } - } - else - { - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - xmm0 = _mm_loadu_pd((double const *) - (tA + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - tptr += tb_inc_row*2; - tA += lda_packed; - } - } - ymm4 = _mm256_permute_pd(ymm4, 0x5); - ymm5 = _mm256_permute_pd(ymm5, 0x5); - - ymm8 = _mm256_addsub_pd(ymm8, ymm4); - ymm11 = _mm256_addsub_pd(ymm11, ymm5); - - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); - ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm8, ymm0); - ymm14 = _mm256_mul_pd(ymm8, ymm14); - ymm8 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm11, ymm0); - ymm14 = _mm256_mul_pd(ymm11, ymm14); - ymm11 = _mm256_hsub_pd(ymm15, ymm14); - - - BLIS_SET_YMM_REG_ZEROS - xmm0 = _mm_setzero_pd(); - - ymm2 = _mm256_broadcast_sd((double const *) - &beta_cast->real); - ymm3 = _mm256_broadcast_sd((double const *) - &beta_cast->imag); - - if(is_beta_non_zero) - { - // multiply C by beta and accumulate col 1. 
- ymm0 = _mm256_loadu_pd((double const *)tC); - xmm0 = _mm_loadu_pd((double const *)(tC + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - - ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - ymm6 = _mm256_fmadd_pd(ymm1, ymm2, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - } - ymm5 = _mm256_permute_pd(ymm5, 0x5); - ymm7 = _mm256_permute_pd(ymm7, 0x5); - - ymm4 = _mm256_addsub_pd(ymm4, ymm5); - ymm6 = _mm256_addsub_pd(ymm6, ymm7); - - ymm8 = _mm256_add_pd(ymm8, ymm4); - ymm11 = _mm256_add_pd(ymm11, ymm6); - - _mm256_storeu_pd((double *)tC, ymm8); - xmm0 = _mm256_extractf128_pd(ymm11, 0); - _mm_storeu_pd((double *)(tC + 2), xmm0); - } - } - if ((m_remainder == 2)) - { - m_remainder -= 2; + // clear scratch registers. + BLIS_SET_ALL_YMM_REG_ZEROS - tA = A + row_idx * lda; - tA_packed = D_A_pack; - lda_packed = 2; + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B + // matrix i data and multiplies it with + // the A matrix. + // This loop is processing Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied with matrix A columns. + ymm0 = _mm256_loadu_pd( + (double const *)tA); + ymm1 = _mm256_loadu_pd( + (double const *)(tA + 2)); + _mm256_storeu_pd( + (double *)tA_packed, ymm0); + _mm256_storeu_pd( + (double *) + (tA_packed + 2), ymm1); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) * + 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + tA_packed += Z_MR; + } - { - dcomplex* tA_temp = tA; + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are multiplied + //with matrix A columns. 
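/*
 * Annotation (editorial sketch, not part of the kernel): the added loops
 * here fuse packing of A with the first GEMM pass over a panel.  Each k
 * iteration stores the A vectors it just loaded into tA_packed (advancing
 * by Z_MR), so the remaining column blocks of the same panel can stream
 * the contiguous packed copy instead of re-reading A with stride lda.
 * A scalar sketch of the idea, assuming Z_MR = 4 (illustrative names):
 *
 *     // first column block of a panel: compute AND pack
 *     for (dim_t kk = 0; kk < K; ++kk) {
 *         dcomplex a0 = tA[0], a1 = tA[1], a2 = tA[2], a3 = tA[3];
 *         tA_packed[0] = a0;  tA_packed[1] = a1;    // packed copy for reuse
 *         tA_packed[2] = a2;  tA_packed[3] = a3;
 *         // ... multiply a0..a3 with the broadcast B elements ...
 *         tA        += lda;      // strided source
 *         tA_packed += 4;        // contiguous destination (Z_MR)
 *     }
 *     // later column blocks of the same panel read from the packed buffer
 */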
+ ymm0 = _mm256_loadu_pd( + (double const *)tA); + ymm1 = _mm256_loadu_pd( + (double const *)(tA + 2)); + _mm256_storeu_pd( + (double *)tA_packed, ymm0); + _mm256_storeu_pd( + (double *)(tA_packed + 2) + , ymm1); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + tA_packed += Z_MR; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and multiplies it with the A + // matrix. This loop is processing + // Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied with matrix A columns. + ymm0 = _mm256_loadu_pd( + (double const *)tA); + ymm1 = _mm256_loadu_pd( + (double const *)(tA + 2)); + _mm256_storeu_pd( + (double *)tA_packed, ymm0); + _mm256_storeu_pd( + (double *)(tA_packed + 2) + , ymm1); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + tA_packed += Z_MR; + } - for(k = 0; (k+1) < K; k += 2) - { - ymm0 = _mm256_loadu_pd((double const *) - (tA_temp + 0 * lda)); - ymm2 = _mm256_loadu_pd((double const *) - (tA_temp + 1 * lda)); + } + else //handles non-transpose case + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and multiplies it with the A + // matrix. 
This loop is processing + // Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd( + (double const *)tA); + ymm1 = _mm256_loadu_pd( + (double const *)(tA + 2)); + _mm256_storeu_pd( + (double *)tA_packed, ymm0); + _mm256_storeu_pd( + (double *)(tA_packed + 2) + , ymm1); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + tA_packed += Z_MR; + } + } - ymm6 = _mm256_permute2f128_pd(ymm0,ymm2,0x20); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm2,0x31); + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm14 = _mm256_permute_pd(ymm14, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); - _mm256_storeu_pd((double *) - (tA_packed + 0 * lda_packed), - ymm6); - _mm256_storeu_pd((double *) - (tA_packed + 1 * lda_packed), - ymm7); + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm11 = _mm256_addsub_pd(ymm11, ymm5); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + ymm12 = _mm256_addsub_pd(ymm12, ymm7); + ymm10 = _mm256_addsub_pd(ymm10, ymm14); + ymm13 = _mm256_addsub_pd(ymm13, ymm15); - tA_temp += 2; - tA_packed += 2 * lda_packed; - } + // alpha, beta multiplication. 
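/*
 * Annotation (editorial sketch, not part of the kernel): the alpha scaling
 * below uses a horizontal-subtract formulation rather than permute+addsub.
 * With alpha broadcast as (ar, ai, ar, ai), one product holds
 * (x.re*ar, x.im*ai, ...) and a second, formed with the swapped and
 * sign-flipped alpha (ai, -ar, ai, -ar), holds (x.re*ai, -x.im*ar, ...).
 * _mm256_hsub_pd of the two gives (x.re*ar - x.im*ai, x.re*ai + x.im*ar),
 * i.e. alpha*x for each packed complex element.  Scalar check of the
 * identity (illustrative values):
 *
 *     double ar = 2.0, ai = 0.5;             // alpha
 *     double xr = 3.0, xi = -1.0;            // one accumulated element
 *     double p0 = xr * ar,  p1 = xi * ai;    // products with (ar, ai)
 *     double q0 = xr * ai,  q1 = xi * (-ar); // products with (ai, -ar)
 *     double out_re = p0 - p1;               // hsub pair 1: xr*ar - xi*ai
 *     double out_im = q0 - q1;               // hsub pair 2: xr*ai + xi*ar
 */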
+ ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm10, ymm0); + ymm14 = _mm256_mul_pd(ymm10, ymm14); + ymm10 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm11, ymm0); + ymm14 = _mm256_mul_pd(ymm11, ymm14); + ymm11 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm12, ymm0); + ymm14 = _mm256_mul_pd(ymm12, ymm14); + ymm12 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm13, ymm0); + ymm14 = _mm256_mul_pd(ymm13, ymm14); + ymm13 = _mm256_hsub_pd(ymm15, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + (&beta_cast->imag)); + + + BLIS_SET_YMM_REG_ZEROS - for(; k < K; k += 1) - { - tA_packed[0].real = tA_temp[0 * lda].real; - tA_packed[0].imag = tA_temp[0 * lda].imag; - tA_packed[1].real = tA_temp[1 * lda].real; - tA_packed[1].imag = tA_temp[1 * lda].imag; + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + ymm0 = _mm256_loadu_pd((double const *)tC); + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - tA_temp += 1; - tA_packed += lda_packed; - } - } + ymm0 = _mm256_loadu_pd((double const *)(tC + 2)); + ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6); + ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7); - tA_packed = D_A_pack; - row_idx_packed = 0; - lda_packed = 2; + // col 2 + ymm0 = _mm256_loadu_pd((double const *)(tC + ldc)); + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); - for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = tA_packed + row_idx_packed; + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc + 2)); + ymm16 = _mm256_fmadd_pd(ymm0, ymm2, ymm16); + ymm17 = _mm256_fmadd_pd(ymm0, ymm3, ymm17); + + // col 3 + ymm0 = _mm256_loadu_pd((double const *) + (tC + (ldc * 2))); + ymm18 = _mm256_fmadd_pd(ymm0, ymm2, ymm18); + ymm19 = _mm256_fmadd_pd(ymm0, ymm3, ymm19); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + (ldc * 2) + 2)); + ymm20 = _mm256_fmadd_pd(ymm0, ymm2, ymm20); + ymm21 = _mm256_fmadd_pd(ymm0, ymm3, ymm21); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + ymm17 = _mm256_permute_pd(ymm17, 0x5); + ymm19 = _mm256_permute_pd(ymm19, 0x5); + ymm21 = _mm256_permute_pd(ymm21, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm6 = _mm256_addsub_pd(ymm6, ymm7); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + ymm16 = _mm256_addsub_pd(ymm16, ymm17); + ymm18 = _mm256_addsub_pd(ymm18, ymm19); + ymm20 = _mm256_addsub_pd(ymm20, ymm21); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm11 = _mm256_add_pd(ymm11, ymm6); + ymm9 = 
_mm256_add_pd(ymm9, ymm14); + ymm12 = _mm256_add_pd(ymm12, ymm16); + ymm10 = _mm256_add_pd(ymm10, ymm18); + ymm13 = _mm256_add_pd(ymm13, ymm20); + + _mm256_storeu_pd((double *)tC, ymm8); + _mm256_storeu_pd((double *)(tC + 2), ymm11); + tC += ldc; + _mm256_storeu_pd((double *)tC, ymm9); + _mm256_storeu_pd((double *)(tC + 2), ymm12); - BLIS_SET_ALL_YMM_REG_ZEROS - - double *tptr = (double *)tB; - if(conjtransa && conjtransb) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda_packed; - } - } - else if(conjtransa) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda_packed; - } - } - else if(conjtransb) - { - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. 
- // This loop is processing D_MR x K - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda_packed; - } - } - else - { - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda_packed; - } - } - - ymm4 = _mm256_permute_pd(ymm4, 0x5); - ymm6 = _mm256_permute_pd(ymm6, 0x5); - ymm14 = _mm256_permute_pd(ymm14, 0x5); - - ymm8 = _mm256_addsub_pd(ymm8, ymm4); - ymm9 = _mm256_addsub_pd(ymm9, ymm6); - ymm10 = _mm256_addsub_pd(ymm10, ymm14); - // alpha, beta multiplication. 
- ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm8, ymm0); - ymm14 = _mm256_mul_pd(ymm8, ymm14); - ymm8 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm9, ymm0); - ymm14 = _mm256_mul_pd(ymm9, ymm14); - ymm9 = _mm256_hsub_pd(ymm15, ymm14); - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm10, ymm0); - ymm14 = _mm256_mul_pd(ymm10, ymm14); - ymm10 = _mm256_hsub_pd(ymm15, ymm14); - - ymm2 = _mm256_broadcast_sd((double const *) - &beta_cast->real); - ymm3 = _mm256_broadcast_sd((double const *) - &beta_cast->imag); - - BLIS_SET_YMM_REG_ZEROS - - if(is_beta_non_zero) - { - // multiply C by beta and accumulate col 1. - ymm0 = _mm256_loadu_pd((double const *)tC); - - ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - - ymm0 = _mm256_loadu_pd((double const *) - (tC + ldc)); - - ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); - ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); - - ymm0 = _mm256_loadu_pd((double const *) - (tC + ldc * 2)); - - ymm18 = _mm256_fmadd_pd(ymm0, ymm2, ymm18); - ymm19 = _mm256_fmadd_pd(ymm0, ymm3, ymm19); - - } - ymm5 = _mm256_permute_pd(ymm5, 0x5); - ymm15 = _mm256_permute_pd(ymm15, 0x5); - ymm19 = _mm256_permute_pd(ymm19, 0x5); - - ymm4 = _mm256_addsub_pd(ymm4, ymm5); - ymm14 = _mm256_addsub_pd(ymm14, ymm15); - ymm18 = _mm256_addsub_pd(ymm18, ymm19); - - ymm8 = _mm256_add_pd(ymm8, ymm4); - ymm9 = _mm256_add_pd(ymm9, ymm14); - ymm10 = _mm256_add_pd(ymm10, ymm18); - - _mm256_storeu_pd((double *)tC, ymm8); - - tC += ldc; - - _mm256_storeu_pd((double *)tC, ymm9); - - tC += ldc; - - _mm256_storeu_pd((double *)tC, ymm10); - } - n_remainder = N - col_idx; - if (n_remainder == 2) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = tA_packed + row_idx_packed; - - - // clear scratch registers. - - BLIS_SET_ALL_YMM_REG_ZEROS - - double *tptr = (double *)tB; - if(conjtransa && conjtransb) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - - tptr += tb_inc_row*2; - tA += lda_packed; - } - } - else if(conjtransa) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. 
- ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - - tptr += tb_inc_row*2; - tA += lda_packed; - } - } - else if(conjtransb) - { - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - - tptr += tb_inc_row*2; - tA += lda_packed; - } - } - else - { - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - - tptr += tb_inc_row*2; - tA += lda_packed; - } - } - ymm4 = _mm256_permute_pd(ymm4, 0x5); - ymm6 = _mm256_permute_pd(ymm6, 0x5); - - ymm8 = _mm256_addsub_pd(ymm8, ymm4); - ymm9 = _mm256_addsub_pd(ymm9, ymm6); - - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); - ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm8, ymm0); - ymm14 = _mm256_mul_pd(ymm8, ymm14); - ymm8 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm9, ymm0); - ymm14 = _mm256_mul_pd(ymm9, ymm14); - ymm9 = _mm256_hsub_pd(ymm15, ymm14); - - BLIS_SET_YMM_REG_ZEROS - - ymm2 = _mm256_broadcast_sd((double const *) - &beta_cast->real); - ymm3 = _mm256_broadcast_sd((double const *) - &beta_cast->imag); - - if(is_beta_non_zero) - { - // multiply C by beta and accumulate col 1. 
- ymm0 = _mm256_loadu_pd((double const *)tC); - - ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - - ymm0 = _mm256_loadu_pd((double const *)(tC + ldc)); - - ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); - ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); - - } - ymm5 = _mm256_permute_pd(ymm5, 0x5); - ymm15 = _mm256_permute_pd(ymm15, 0x5); - - ymm4 = _mm256_addsub_pd(ymm4, ymm5); - ymm14 = _mm256_addsub_pd(ymm14, ymm15); - - ymm8 = _mm256_add_pd(ymm8, ymm4); - ymm9 = _mm256_add_pd(ymm9, ymm14); - - _mm256_storeu_pd((double *)tC, ymm8); - tC += ldc; - _mm256_storeu_pd((double *)tC, ymm9); - } - if (n_remainder == 1) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = tA_packed + row_idx_packed; + tC += ldc; - // clear scratch registers. - - BLIS_SET_ALL_YMM_REG_ZEROS - - double *tptr = (double *)tB; - if(conjtransa && conjtransb) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - tptr += tb_inc_row*2; - tA += lda_packed; - } - } - else if(conjtransa) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matri - // x data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - tptr += tb_inc_row*2; - tA += lda_packed; - } - } - else if(conjtransb) - { - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matri - // x data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - tptr += tb_inc_row*2; - tA += lda_packed; - } - } - else - { - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. 
- ymm0 = _mm256_loadu_pd((double const *)tA); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - tptr += tb_inc_row*2; - tA += lda_packed; - } - } - ymm4 = _mm256_permute_pd(ymm4, 0x5); - - ymm8 = _mm256_addsub_pd(ymm8, ymm4); - - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); - ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm8, ymm0); - ymm14 = _mm256_mul_pd(ymm8, ymm14); - ymm8 = _mm256_hsub_pd(ymm15, ymm14); - - - BLIS_SET_YMM_REG_ZEROS - - ymm2 = _mm256_broadcast_sd((double const *) - &beta_cast->real); - ymm3 = _mm256_broadcast_sd((double const *) - &beta_cast->imag); - - if(is_beta_non_zero) - { - // multiply C by beta and accumulate col 1. - ymm0 = _mm256_loadu_pd((double const *)tC); - - ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - } - ymm5 = _mm256_permute_pd(ymm5, 0x5); - - ymm4 = _mm256_addsub_pd(ymm4, ymm5); - - ymm8 = _mm256_add_pd(ymm8, ymm4); - - _mm256_storeu_pd((double *)tC, ymm8); - } - } - if ((m_remainder == 1)) - { - m_remainder -= 1; - __m128d xmm0; - - tA = A + row_idx * lda; - tA_packed = D_A_pack; - lda_packed = 1; - - { - dcomplex* tA_temp = tA; - - for(k = 0; (k+1) < K; k += 2) - { - ymm0 = _mm256_loadu_pd((double const *) - (tA_temp + 0 * lda)); - - xmm0 = _mm256_extractf128_pd(ymm0, 0); - _mm_storeu_pd((double *) - (tA_packed + 0 * lda_packed), - xmm0); - - xmm0 = _mm256_extractf128_pd(ymm0, 1); - _mm_storeu_pd((double *)(tA_packed + 1 - * lda_packed), xmm0); - - tA_temp += 2; - tA_packed += 2 * lda_packed; - } - - for(; k < K; k += 1) - { - tA_packed[0].real = tA_temp[0 * lda].real; - tA_packed[0].imag = tA_temp[0 * lda].imag; - - tA_temp += 1; - tA_packed += lda_packed; - } - } - - tA_packed = D_A_pack; - row_idx_packed = 0; - lda_packed = 1; - - for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = tA_packed + row_idx_packed; - - - BLIS_SET_ALL_YMM_REG_ZEROS - xmm0 = _mm_setzero_pd(); - - double *tptr = (double *)tB; - if(conjtransa && conjtransb) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. 
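The single-row remainder paths around here, like the fringe cases in the new kernels further down, touch one dcomplex at a time: a 128-bit load is inserted into the low half of a ymm register, and the result is pulled back out the same way before the store. A minimal sketch with hypothetical helper names; the kernels reuse a ymm register that was already cleared instead of calling _mm256_setzero_pd():

#include <immintrin.h>

/* load one dcomplex (re, im) into the low 128 bits of a ymm register */
static __m256d zload1(const double *p)
{
    return _mm256_insertf128_pd(_mm256_setzero_pd(), _mm_loadu_pd(p), 0);
}

/* store the low dcomplex of a ymm register back to memory */
static void zstore1(double *p, __m256d v)
{
    _mm_storeu_pd(p, _mm256_extractf128_pd(v, 0));
}

int main(void)
{
    double z[2] = { 1.5, -2.5 }, out[2];
    zstore1(out, zload1(z));                 /* round-trips the pair */
    return (out[0] == z[0] && out[1] == z[1]) ? 0 : 1;
}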
- xmm0 = _mm_loadu_pd((double const *)(tA)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda_packed; - } - } - else if(conjtransa) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - xmm0 = _mm_loadu_pd((double const *)(tA)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda_packed; - } - } - else if(conjtransb) - { - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. 
- xmm0 = _mm_loadu_pd((double const *)(tA)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda_packed; - } - } - else - { - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - xmm0 = _mm_loadu_pd((double const *)(tA)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda_packed; - } - } - ymm4 = _mm256_permute_pd(ymm4, 0x5); - ymm6 = _mm256_permute_pd(ymm6, 0x5); - ymm14 = _mm256_permute_pd(ymm14, 0x5); - - ymm8 = _mm256_addsub_pd(ymm8, ymm4); - ymm9 = _mm256_addsub_pd(ymm9, ymm6); - ymm10 = _mm256_addsub_pd(ymm10, ymm14); - // alpha, beta multiplication. - ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm8, ymm0); - ymm14 = _mm256_mul_pd(ymm8, ymm14); - ymm8 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm9, ymm0); - ymm14 = _mm256_mul_pd(ymm9, ymm14); - ymm9 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm10, ymm0); - ymm14 = _mm256_mul_pd(ymm10, ymm14); - ymm10 = _mm256_hsub_pd(ymm15, ymm14); - - ymm2 = _mm256_broadcast_sd((double const *) - &beta_cast->real); - ymm3 = _mm256_broadcast_sd((double const *) - &beta_cast->imag); - - BLIS_SET_YMM_REG_ZEROS - xmm0 = _mm_setzero_pd(); - - if(is_beta_non_zero) - { - // multiply C by beta and accumulate col 1. 
- xmm0 = _mm_loadu_pd((double const *)(tC)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - - ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - - xmm0 = _mm_loadu_pd((double const *)(tC + ldc)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - - ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); - ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); - - xmm0 = _mm_loadu_pd((double const *) - (tC + ldc * 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - - ymm18 = _mm256_fmadd_pd(ymm0, ymm2, ymm18); - ymm19 = _mm256_fmadd_pd(ymm0, ymm3, ymm19); - - } - ymm5 = _mm256_permute_pd(ymm5, 0x5); - ymm15 = _mm256_permute_pd(ymm15, 0x5); - ymm19 = _mm256_permute_pd(ymm19, 0x5); - - ymm4 = _mm256_addsub_pd(ymm4, ymm5); - ymm14 = _mm256_addsub_pd(ymm14, ymm15); - ymm18 = _mm256_addsub_pd(ymm18, ymm19); - - ymm8 = _mm256_add_pd(ymm8, ymm4); - ymm9 = _mm256_add_pd(ymm9, ymm14); - ymm10 = _mm256_add_pd(ymm10, ymm18); - - xmm0 = _mm256_extractf128_pd(ymm8, 0); - _mm_storeu_pd((double *)tC, xmm0); - - tC += ldc; - - xmm0 = _mm256_extractf128_pd(ymm9, 0); - _mm_storeu_pd((double *)tC, xmm0); - - tC += ldc; - xmm0 = _mm256_extractf128_pd(ymm10, 0); - _mm_storeu_pd((double *)tC, xmm0); - } - n_remainder = N - col_idx; - if (n_remainder == 2) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = tA_packed + row_idx_packed; - - // clear scratch registers. - - BLIS_SET_ALL_YMM_REG_ZEROS - xmm0 = _mm_setzero_pd(); - - double *tptr = (double *)tB; - if(conjtransa && conjtransb) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - xmm0 = _mm_loadu_pd((double const *)(tA)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - - tptr += tb_inc_row*2; - tA += lda_packed; - } - } - else if(conjtransa) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. 
- xmm0 = _mm_loadu_pd((double const *)(tA)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - - tptr += tb_inc_row*2; - tA += lda_packed; - } - } - else if(conjtransb) - { - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - xmm0 = _mm_loadu_pd((double const *)(tA)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - - tptr += tb_inc_row*2; - tA += lda_packed; - } - } - else - { - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - xmm0 = _mm_loadu_pd((double const *)(tA)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - - tptr += tb_inc_row*2; - tA += lda_packed; - } - } - ymm4 = _mm256_permute_pd(ymm4, 0x5); - ymm6 = _mm256_permute_pd(ymm6, 0x5); - - ymm8 = _mm256_addsub_pd(ymm8, ymm4); - ymm9 = _mm256_addsub_pd(ymm9, ymm6); - - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); - ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm8, ymm0); - ymm14 = _mm256_mul_pd(ymm8, ymm14); - ymm8 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm9, ymm0); - ymm14 = _mm256_mul_pd(ymm9, ymm14); - ymm9 = _mm256_hsub_pd(ymm15, ymm14); - - - BLIS_SET_YMM_REG_ZEROS - xmm0 = _mm_setzero_pd(); - - - ymm2 = _mm256_broadcast_sd((double const *) - &beta_cast->real); - ymm3 = _mm256_broadcast_sd((double const *) - &beta_cast->imag); - - if(is_beta_non_zero) - { - // multiply C by beta and accumulate col 1. 
- xmm0 = _mm_loadu_pd((double const *)(tC)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - - ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - - xmm0 = _mm_loadu_pd((double const *)(tC + ldc)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - - ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); - ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); - } - ymm5 = _mm256_permute_pd(ymm5, 0x5); - ymm15 = _mm256_permute_pd(ymm15, 0x5); - - ymm4 = _mm256_addsub_pd(ymm4, ymm5); - ymm14 = _mm256_addsub_pd(ymm14, ymm15); - - ymm8 = _mm256_add_pd(ymm8, ymm4); - ymm9 = _mm256_add_pd(ymm9, ymm14); - - xmm0 = _mm256_extractf128_pd(ymm8, 0); - _mm_storeu_pd((double *)tC, xmm0); - tC += ldc; - xmm0 = _mm256_extractf128_pd(ymm9, 0); - _mm_storeu_pd((double *)tC, xmm0); - } - if (n_remainder == 1) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = tA_packed + row_idx_packed; - - // clear scratch registers. - - BLIS_SET_ALL_YMM_REG_ZEROS - xmm0 = _mm_setzero_pd(); - - double *tptr = (double *)tB; - if(conjtransa && conjtransb) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - xmm0 = _mm_loadu_pd((double const *)(tA)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - tptr += tb_inc_row*2; - tA += lda_packed; - } - } - else if(conjtransa) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - xmm0 = _mm_loadu_pd((double const *)(tA)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - tptr += tb_inc_row*2; - tA += lda_packed; - } - } - else if(conjtransb) - { - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - xmm0 = _mm_loadu_pd((double const *)(tA)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - tptr += tb_inc_row*2; - tA += lda_packed; - } - } - else - { - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matri - // x data and - // multiplies it with the A matrix. 
- ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - xmm0 = _mm_loadu_pd((double const *)(tA)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - tptr += tb_inc_row*2; - tA += lda_packed; - } - } - ymm4 = _mm256_permute_pd(ymm4, 0x5); - - ymm8 = _mm256_addsub_pd(ymm8, ymm4); - - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); - ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm8, ymm0); - ymm14 = _mm256_mul_pd(ymm8, ymm14); - ymm8 = _mm256_hsub_pd(ymm15, ymm14); - - - BLIS_SET_YMM_REG_ZEROS - xmm0 = _mm_setzero_pd(); - - ymm2 = _mm256_broadcast_sd((double const *) - &beta_cast->real); - ymm3 = _mm256_broadcast_sd((double const *) - &beta_cast->imag); - - if(is_beta_non_zero) - { - // multiply C by beta and accumulate col 1. - xmm0 = _mm_loadu_pd((double const *)(tC)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - - ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - } - ymm5 = _mm256_permute_pd(ymm5, 0x5); - - ymm4 = _mm256_addsub_pd(ymm4, ymm5); - - ymm8 = _mm256_add_pd(ymm8, ymm4); - - xmm0 = _mm256_extractf128_pd(ymm8, 0); - _mm_storeu_pd((double *)tC, xmm0); - - } - } - // Return the buffer to pool - if ((required_packing_A == 1) && bli_mem_is_alloc( &local_mem_buf_A_s )){ -#ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_dgemm_small_At(): releasing mem pool block\n" ); -#endif - bli_membrk_release(&rntm, - &local_mem_buf_A_s); - } - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return BLIS_SUCCESS; - } - else - { - AOCL_DTL_TRACE_EXIT_ERR( - AOCL_DTL_LEVEL_INFO, - "Invalid dimesions for dgemm_small_At." - ); - return BLIS_NONCONFORMAL_DIMENSIONS; - } + _mm256_storeu_pd((double *)tC, ymm10); + _mm256_storeu_pd((double *)(tC + 2), ymm13); + + // modify the pointer arithematic to use packed A matrix. + col_idx_start = NR; + tA_packed = D_A_pack; + row_idx_packed = 0; + lda_packed = Z_MR; + } + // Process NR columns of C matrix at a time. + for (col_idx = col_idx_start; (col_idx + (NR - 1)) < N; + col_idx += NR) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + +#ifdef BLIS_ENABLE_PREFETCH + _mm_prefetch((char*)(tC + 0), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 8), _MM_HINT_T0); + _mm_prefetch((char*)(tC + ldc), _MM_HINT_T0); + _mm_prefetch((char*)(tC + ldc + 8), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 2 * ldc + 8), _MM_HINT_T0); +#endif + // clear scratch registers. + + + BLIS_SET_ALL_YMM_REG_ZEROS + + double *tptr = (double *)tB; + + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. 
+ // This loop is processing Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd( + (double const *)tA); + ymm1 = _mm256_loadu_pd( + (double const *)(tA + 2)); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and multiplies it with the A + // matrix. This loop is processing + // Z_MR x K The inner loop broadcasts + // the B matrix data and multiplies it + // with the A matrix. This loop is + // processing Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and multiplies it with the A + // matrix. 
This loop is processing + // Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else //handles non-transpose case + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and multiplies it with the A + // matrix. This loop is processing + // Z_MR x K The inner loop broadcasts the + // B matrix data and multiplies it with + // the A matrix. This loop is processing + // Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied with matrix A columns. 
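For reference, this non-conjugate branch accumulates a 4x3 block of complex partial products (Z_MR rows of packed A against three columns of B), stepping A by lda_packed and B by tb_inc_row/tb_inc_col per iteration. A scalar sketch of the same accumulation under those assumptions, with hypothetical parameter names:

#include <complex.h>

/* scalar reference for the 4x3 block this branch computes: for each k,
 * every a[i] is multiplied by the three broadcast b[j] values */
static void zgemm_4x3_ref(const double complex *a, long lda,
                          const double complex *b, long rs_b, long cs_b,
                          long K, double complex acc[4][3])
{
    for (long k = 0; k < K; ++k)
        for (long j = 0; j < 3; ++j)
            for (long i = 0; i < 4; ++i)
                acc[i][j] += a[i + k * lda] * b[k * rs_b + j * cs_b];
}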
+ ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm14 = _mm256_permute_pd(ymm14, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm11 = _mm256_addsub_pd(ymm11, ymm5); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + ymm12 = _mm256_addsub_pd(ymm12, ymm7); + ymm10 = _mm256_addsub_pd(ymm10, ymm14); + ymm13 = _mm256_addsub_pd(ymm13, ymm15); + + // alpha, beta multiplication. + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm10, ymm0); + ymm14 = _mm256_mul_pd(ymm10, ymm14); + ymm10 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm11, ymm0); + ymm14 = _mm256_mul_pd(ymm11, ymm14); + ymm11 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm12, ymm0); + ymm14 = _mm256_mul_pd(ymm12, ymm14); + ymm12 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm13, ymm0); + ymm14 = _mm256_mul_pd(ymm13, ymm14); + ymm13 = _mm256_hsub_pd(ymm15, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + + BLIS_SET_YMM_REG_ZEROS + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. 
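The permute/mul/hsub sequences just above implement the complex scaling by alpha_cast, and the beta path that follows reuses the same broadcast-and-recombine idea on the stored C values. A minimal standalone sketch of the alpha step, with a hypothetical helper name:

#include <stdio.h>
#include <immintrin.h>

/* scale two packed dcomplex values by alpha, mirroring the
 * permute/mul/hsub sequence used after the k loop */
static __m256d zscale(__m256d c, double alpha_r, double alpha_i)
{
    __m256d sign = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);
    __m256d a    = _mm256_setr_pd(alpha_r, alpha_i, alpha_r, alpha_i);
    __m256d a_sw = _mm256_mul_pd(_mm256_permute_pd(a, 0x5), sign); /* [ai, -ar, ...] */
    __m256d t0   = _mm256_mul_pd(c, a);      /* [cr*ar, ci*ai, ...] */
    __m256d t1   = _mm256_mul_pd(c, a_sw);   /* [cr*ai, -ci*ar, ...] */
    return _mm256_hsub_pd(t0, t1);           /* [cr*ar - ci*ai, cr*ai + ci*ar, ...] */
}

int main(void)
{
    __m256d c = _mm256_setr_pd(1.0, 2.0, 3.0, 4.0);   /* 1+2i, 3+4i */
    double out[4];
    _mm256_storeu_pd(out, zscale(c, 2.0, 1.0));       /* alpha = 2+1i */
    printf("%g%+gi  %g%+gi\n", out[0], out[1], out[2], out[3]);  /* 0+5i  2+11i */
    return 0;
}

The {1, -1, 1, -1} sign vector turns the swapped copy of alpha into (Im, -Re), so a single hsub produces both components of each product.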
+ ymm0 = _mm256_loadu_pd((double const *)tC); + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + ymm0 = _mm256_loadu_pd((double const *)(tC + 2)); + ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6); + ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7); + + ymm0 = _mm256_loadu_pd((double const *)(tC + ldc)); + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc + 2)); + ymm16 = _mm256_fmadd_pd(ymm0, ymm2, ymm16); + ymm17 = _mm256_fmadd_pd(ymm0, ymm3, ymm17); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc * 2)); + ymm18 = _mm256_fmadd_pd(ymm0, ymm2, ymm18); + ymm19 = _mm256_fmadd_pd(ymm0, ymm3, ymm19); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc * 2 + 2)); + ymm20 = _mm256_fmadd_pd(ymm0, ymm2, ymm20); + ymm21 = _mm256_fmadd_pd(ymm0, ymm3, ymm21); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + ymm17 = _mm256_permute_pd(ymm17, 0x5); + ymm19 = _mm256_permute_pd(ymm19, 0x5); + ymm21 = _mm256_permute_pd(ymm21, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm6 = _mm256_addsub_pd(ymm6, ymm7); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + ymm16 = _mm256_addsub_pd(ymm16, ymm17); + ymm18 = _mm256_addsub_pd(ymm18, ymm19); + ymm20 = _mm256_addsub_pd(ymm20, ymm21); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm11 = _mm256_add_pd(ymm11, ymm6); + ymm9 = _mm256_add_pd(ymm9, ymm14); + ymm12 = _mm256_add_pd(ymm12, ymm16); + ymm10 = _mm256_add_pd(ymm10, ymm18); + ymm13 = _mm256_add_pd(ymm13, ymm20); + + _mm256_storeu_pd((double *)tC, ymm8); + _mm256_storeu_pd((double *)(tC + 2), ymm11); + + tC += ldc; + + _mm256_storeu_pd((double *)tC, ymm9); + _mm256_storeu_pd((double *)(tC + 2), ymm12); + + tC += ldc; + + _mm256_storeu_pd((double *)tC, ymm10); + _mm256_storeu_pd((double *)(tC + 2), ymm13); + } + n_remainder = N - col_idx; + if (n_remainder == 2) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + // clear scratch registers. + + + BLIS_SET_ALL_YMM_REG_ZEROS + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and multiplies it with the A + // matrix. This loop is processing + // Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied with matrix A columns. 
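Conjugation in these kernels is handled with sign masks rather than separate arithmetic: multiplying the loaded A vectors by ymm20 = {1, -1, 1, -1} negates Im(a), and multiplying only the broadcast Im(b) register by ymm21 = {-1, -1, -1, -1} negates Im(b) while the Re(b) broadcast is left untouched. A minimal sketch of both forms:

#include <immintrin.h>

/* conj(a) for two packed dcomplex values: flip the sign of the imaginary lanes */
static __m256d zconj(__m256d a)
{
    return _mm256_mul_pd(a, _mm256_setr_pd(1.0, -1.0, 1.0, -1.0));
}

/* conj(b) when Re(b) and Im(b) live in separate broadcast registers:
 * only the imaginary broadcast needs its sign flipped */
static __m256d zconj_im_bcast(__m256d b_im)
{
    return _mm256_mul_pd(b_im, _mm256_set1_pd(-1.0));
}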
+ ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + + tptr += (tb_inc_row * 2); + tA += lda; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and multiplies it with the A + // matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied with matrix A columns. + ymm0 = _mm256_loadu_pd((double const*)tA); + ymm1 = _mm256_loadu_pd((double const*) + (tA + 2)); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and multiplies it with the A + // matrix. This loop is processing + // Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + + tptr += (tb_inc_row * 2); + tA += lda; + } + + } + else //handles non-transpose case + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and multiplies it with the A + // matrix. 
+ ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied with matrix A columns. + ymm0 = _mm256_loadu_pd((double const*)tA); + ymm1 = _mm256_loadu_pd((double const*) + (tA + 2)); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda; + } + + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm11 = _mm256_addsub_pd(ymm11, ymm5); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + ymm12 = _mm256_addsub_pd(ymm12, ymm7); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm11, ymm0); + ymm14 = _mm256_mul_pd(ymm11, ymm14); + ymm11 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm12, ymm0); + ymm14 = _mm256_mul_pd(ymm12, ymm14); + ymm12 = _mm256_hsub_pd(ymm15, ymm14); + + + BLIS_SET_YMM_REG_ZEROS + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. 
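Taken together, this epilogue writes C := alpha*(A*B) + beta*C when is_beta_non_zero, and just alpha*(A*B) otherwise, since the beta-path registers keep the zeros set by BLIS_SET_YMM_REG_ZEROS. A one-element scalar sketch, with hypothetical names:

#include <complex.h>

/* scalar reference for the write-back of one C element */
static double complex zepilogue(double complex alpha, double complex acc,
                                double complex beta, double complex c_old,
                                int is_beta_non_zero)
{
    return alpha * acc + (is_beta_non_zero ? beta * c_old : 0.0);
}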
+ ymm0 = _mm256_loadu_pd((double const *)tC); + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + ymm0 = _mm256_loadu_pd((double const *)(tC + 2)); + ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6); + ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7); + + ymm0 = _mm256_loadu_pd((double const *)(tC + ldc)); + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc + 2)); + ymm16 = _mm256_fmadd_pd(ymm0, ymm2, ymm16); + ymm17 = _mm256_fmadd_pd(ymm0, ymm3, ymm17); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + ymm17 = _mm256_permute_pd(ymm17, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm6 = _mm256_addsub_pd(ymm6, ymm7); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + ymm16 = _mm256_addsub_pd(ymm16, ymm17); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm11 = _mm256_add_pd(ymm11, ymm6); + ymm9 = _mm256_add_pd(ymm9, ymm14); + ymm12 = _mm256_add_pd(ymm12, ymm16); + + _mm256_storeu_pd((double *)(tC + 0), ymm8); + _mm256_storeu_pd((double *)(tC + 2), ymm11); + tC += ldc; + _mm256_storeu_pd((double *)tC, ymm9); + _mm256_storeu_pd((double *)(tC + 2), ymm12); + } + + if (n_remainder == 1) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + // clear scratch registers. + + + BLIS_SET_ALL_YMM_REG_ZEROS + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and multiplies it with the A + // matrix. This loop is processing + // Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tptr += (tb_inc_row * 2); + tA += lda; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + tptr += tb_inc_row*2; + + //broadcasted matrix B elements are + //multiplied with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tA += lda; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. 
+ // This loop is processing Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tptr += (tb_inc_row * 2); + tA += lda; + } + } + else //handles non-transpose case + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + tptr += tb_inc_row*2; + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tA += lda; + } + + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm5 = _mm256_permute_pd(ymm5, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm11 = _mm256_addsub_pd(ymm11, ymm5); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm11, ymm0); + ymm14 = _mm256_mul_pd(ymm11, ymm14); + ymm11 = _mm256_hsub_pd(ymm15, ymm14); + + + BLIS_SET_YMM_REG_ZEROS + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. 
+ ymm0 = _mm256_loadu_pd((double const *)tC); + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + ymm0 = _mm256_loadu_pd((double const *)(tC + 2)); + ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6); + ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7); + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm6 = _mm256_addsub_pd(ymm6, ymm7); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm11 = _mm256_add_pd(ymm11, ymm6); + + _mm256_storeu_pd((double *)tC, ymm8); + _mm256_storeu_pd((double *)(tC + 2), ymm11); + } + } + m_remainder = M - row_idx; + + if ((m_remainder == 3)) + { + m_remainder -= 3; + __m128d xmm0; + + for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + + BLIS_SET_ALL_YMM_REG_ZEROS + + xmm0 = _mm_setzero_pd(); + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *)(tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. 
+ // This loop is processing Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *)(tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + } + + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. 
+ // This loop is processing Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + } + + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm14 = _mm256_permute_pd(ymm14, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm11 = _mm256_addsub_pd(ymm11, ymm5); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + ymm12 = _mm256_addsub_pd(ymm12, ymm7); + ymm10 = _mm256_addsub_pd(ymm10, ymm14); + ymm13 = _mm256_addsub_pd(ymm13, ymm15); + // alpha, beta multiplication. + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm10, ymm0); + ymm14 = _mm256_mul_pd(ymm10, ymm14); + ymm10 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm11, ymm0); + ymm14 = _mm256_mul_pd(ymm11, ymm14); + ymm11 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm12, ymm0); + ymm14 = _mm256_mul_pd(ymm12, ymm14); + ymm12 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm13, ymm0); + ymm14 = _mm256_mul_pd(ymm13, ymm14); + ymm13 = _mm256_hsub_pd(ymm15, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + + + BLIS_SET_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. 
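+ // A 256-bit load picks up the first two complex elements of this
+ // C column and a 128-bit load fills the low half of ymm1 with the
+ // third. Beta was broadcast above as separate real (ymm2) and
+ // imaginary (ymm3) parts, so the FMAs accumulate Re(beta)*C and
+ // Im(beta)*C; the permute/addsub pair after this block combines
+ // them into beta*C before it is added to alpha*A*B.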
+ ymm0 = _mm256_loadu_pd((double const *)tC); + xmm0 = _mm_loadu_pd((double const *)(tC + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + ymm6 = _mm256_fmadd_pd(ymm1, ymm2, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc)); + xmm0 = _mm_loadu_pd((double const *) + (tC + ldc + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + ymm16 = _mm256_fmadd_pd(ymm1, ymm2, ymm16); + ymm17 = _mm256_fmadd_pd(ymm1, ymm3, ymm17); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc * 2)); + xmm0 = _mm_loadu_pd((double const *) + (tC + ldc * 2 + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm18 = _mm256_fmadd_pd(ymm0, ymm2, ymm18); + ymm19 = _mm256_fmadd_pd(ymm0, ymm3, ymm19); + ymm20 = _mm256_fmadd_pd(ymm1, ymm2, ymm20); + ymm21 = _mm256_fmadd_pd(ymm1, ymm3, ymm21); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + ymm17 = _mm256_permute_pd(ymm17, 0x5); + ymm19 = _mm256_permute_pd(ymm19, 0x5); + ymm21 = _mm256_permute_pd(ymm21, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm6 = _mm256_addsub_pd(ymm6, ymm7); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + ymm16 = _mm256_addsub_pd(ymm16, ymm17); + ymm18 = _mm256_addsub_pd(ymm18, ymm19); + ymm20 = _mm256_addsub_pd(ymm20, ymm21); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm11 = _mm256_add_pd(ymm11, ymm6); + ymm9 = _mm256_add_pd(ymm9, ymm14); + ymm12 = _mm256_add_pd(ymm12, ymm16); + ymm10 = _mm256_add_pd(ymm10, ymm18); + ymm13 = _mm256_add_pd(ymm13, ymm20); + + _mm256_storeu_pd((double *)tC, ymm8); + xmm0 = _mm256_extractf128_pd(ymm11, 0); + _mm_storeu_pd((double *)(tC + 2), xmm0); + + tC += ldc; + + _mm256_storeu_pd((double *)tC, ymm9); + xmm0 = _mm256_extractf128_pd(ymm12, 0); + _mm_storeu_pd((double *)(tC + 2), xmm0); + + tC += ldc; + + _mm256_storeu_pd((double *)tC, ymm10); + xmm0 = _mm256_extractf128_pd(ymm13, 0); + _mm_storeu_pd((double *)(tC + 2), xmm0); + } + n_remainder = N - col_idx; + if (n_remainder == 2) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + // clear scratch registers. + + BLIS_SET_ALL_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd((tptr + + tb_inc_col + * 0)); + ymm3 = _mm256_broadcast_sd((tptr + + tb_inc_col + * 0 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. 
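+ // Three complex elements of A remain in this panel: two come in
+ // with a 256-bit load and the third with a 128-bit load placed in
+ // the low half of ymm1. Multiplying by ymm20 = (1,-1,1,-1) negates
+ // the imaginary parts, i.e. applies conj(A); the imaginary
+ // broadcast of B was already negated with ymm21 for conj(B).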
+ ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd((tptr + + tb_inc_col + * 0)); + ymm3 = _mm256_broadcast_sd((tptr + + tb_inc_col + * 0 + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd((tptr + + tb_inc_col + * 0)); + ymm3 = _mm256_broadcast_sd((tptr + + tb_inc_col + * 0 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd((tptr + + tb_inc_col + * 0)); + ymm3 = _mm256_broadcast_sd((tptr + + tb_inc_col + * 0 + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. 
+ ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda; + } + + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm11 = _mm256_addsub_pd(ymm11, ymm5); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + ymm12 = _mm256_addsub_pd(ymm12, ymm7); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm11, ymm0); + ymm14 = _mm256_mul_pd(ymm11, ymm14); + ymm11 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm12, ymm0); + ymm14 = _mm256_mul_pd(ymm12, ymm14); + ymm12 = _mm256_hsub_pd(ymm15, ymm14); + + + + BLIS_SET_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. 
+ ymm0 = _mm256_loadu_pd((double const *)tC); + xmm0 = _mm_loadu_pd((double const *)(tC + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + ymm6 = _mm256_fmadd_pd(ymm1, ymm2, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm0 = _mm256_loadu_pd((double const *)(tC + ldc)); + xmm0 = _mm_loadu_pd((double const *)(tC + ldc + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + ymm16 = _mm256_fmadd_pd(ymm1, ymm2, ymm16); + ymm17 = _mm256_fmadd_pd(ymm1, ymm3, ymm17); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + ymm17 = _mm256_permute_pd(ymm17, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm6 = _mm256_addsub_pd(ymm6, ymm7); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + ymm16 = _mm256_addsub_pd(ymm16, ymm17); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm11 = _mm256_add_pd(ymm11, ymm6); + ymm9 = _mm256_add_pd(ymm9, ymm14); + ymm12 = _mm256_add_pd(ymm12, ymm16); + + _mm256_storeu_pd((double *)tC, ymm8); + xmm0 = _mm256_extractf128_pd(ymm11, 0); + _mm_storeu_pd((double *)(tC + 2), xmm0); + + tC += ldc; + _mm256_storeu_pd((double *)tC, ymm9); + xmm0 = _mm256_extractf128_pd(ymm12, 0); + _mm_storeu_pd((double *)(tC + 2), xmm0); + } + if (n_remainder == 1) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + // clear scratch registers. + + + BLIS_SET_ALL_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. 
+ ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tptr += tb_inc_row*2; + tA += lda; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm5 = _mm256_permute_pd(ymm5, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm11 = _mm256_addsub_pd(ymm11, ymm5); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm11, ymm0); + ymm14 = _mm256_mul_pd(ymm11, ymm14); + ymm11 = _mm256_hsub_pd(ymm15, ymm14); + + + + BLIS_SET_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. 
+ ymm0 = _mm256_loadu_pd((double const *)tC); + xmm0 = _mm_loadu_pd((double const *)(tC + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + ymm6 = _mm256_fmadd_pd(ymm1, ymm2, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm6 = _mm256_addsub_pd(ymm6, ymm7); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm11 = _mm256_add_pd(ymm11, ymm6); + + _mm256_storeu_pd((double *)tC, ymm8); + xmm0 = _mm256_extractf128_pd(ymm11, 0); + _mm_storeu_pd((double *)(tC + 2), xmm0); + } + } + if ((m_remainder == 2)) + { + m_remainder -= 2; + + for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + + + BLIS_SET_ALL_YMM_REG_ZEROS + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. 
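+ // Two complex elements remain per column of A, so one 256-bit
+ // load covers them; ymm20 = (1,-1,1,-1) flips the sign of the
+ // imaginary parts to apply conj(A) before the FMAs below.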
+ ymm0 = _mm256_loadu_pd((double const *)tA); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. 
+ ymm0 = _mm256_loadu_pd((double const *)tA); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + ymm14 = _mm256_permute_pd(ymm14, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + ymm10 = _mm256_addsub_pd(ymm10, ymm14); + // alpha, beta multiplication. + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm10, ymm0); + ymm14 = _mm256_mul_pd(ymm10, ymm14); + ymm10 = _mm256_hsub_pd(ymm15, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + + BLIS_SET_YMM_REG_ZEROS + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + ymm0 = _mm256_loadu_pd((double const *)tC); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + ymm0 = _mm256_loadu_pd((double const *)(tC + ldc)); + + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc * 2)); + + ymm18 = _mm256_fmadd_pd(ymm0, ymm2, ymm18); + ymm19 = _mm256_fmadd_pd(ymm0, ymm3, ymm19); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + ymm19 = _mm256_permute_pd(ymm19, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + ymm18 = _mm256_addsub_pd(ymm18, ymm19); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm9 = _mm256_add_pd(ymm9, ymm14); + ymm10 = _mm256_add_pd(ymm10, ymm18); + + _mm256_storeu_pd((double *)tC, ymm8); + + tC += ldc; + + _mm256_storeu_pd((double *)tC, ymm9); + + tC += ldc; + + _mm256_storeu_pd((double *)tC, ymm10); + } + n_remainder = N - col_idx; + if (n_remainder == 2) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + // clear scratch registers. + + BLIS_SET_ALL_YMM_REG_ZEROS + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. 
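+ // Each iteration broadcasts the real part of a B element into
+ // ymm2 and its imaginary part into ymm3; ymm21 = (-1,...,-1)
+ // negates the imaginary broadcast to apply conj(B), while
+ // ymm20 = (1,-1,1,-1) negates the imaginary parts of the A load
+ // to apply conj(A).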
+ ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. 
+ ymm0 = _mm256_loadu_pd((double const *)tA); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda; + } + + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + + BLIS_SET_YMM_REG_ZEROS + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + ymm0 = _mm256_loadu_pd((double const *)tC); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + ymm0 = _mm256_loadu_pd((double const *)(tC + ldc)); + + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm9 = _mm256_add_pd(ymm9, ymm14); + + _mm256_storeu_pd((double *)tC, ymm8); + tC += ldc; + _mm256_storeu_pd((double *)tC, ymm9); + } + if (n_remainder == 1) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + // clear scratch registers. + + + BLIS_SET_ALL_YMM_REG_ZEROS + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. 
+ ymm0 = _mm256_loadu_pd((double const *)tA); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda; + } + + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + + + BLIS_SET_YMM_REG_ZEROS + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + ymm0 = _mm256_loadu_pd((double const *)tC); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + + _mm256_storeu_pd((double *)tC, ymm8); + } + } + if ((m_remainder == 1)) + { + m_remainder -= 1; + __m128d xmm0; + + for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + + + BLIS_SET_ALL_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. 
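+ // Only one complex element of A is left in this row block, so it
+ // is loaded with a 128-bit load into the low half of ymm0 (the
+ // upper half is never stored back); ymm20 negates its imaginary
+ // part for conj(A) and ymm21 has already negated the imaginary
+ // broadcast of B for conj(B).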
+ xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are multiplied + //with matrix A columns. 
+ xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + ymm14 = _mm256_permute_pd(ymm14, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + ymm10 = _mm256_addsub_pd(ymm10, ymm14); + // alpha, beta multiplication. + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm10, ymm0); + ymm14 = _mm256_mul_pd(ymm10, ymm14); + ymm10 = _mm256_hsub_pd(ymm15, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + BLIS_SET_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. 
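+ // Only one row of C is live here, so each of the three columns is
+ // read with a 128-bit load into the low half of ymm0 and scaled by
+ // Re(beta)/Im(beta) through the FMAs below; the permute/addsub
+ // fix-up that follows turns the partial products into beta*C, and
+ // only the low 128 bits of each result are stored back.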
+ xmm0 = _mm_loadu_pd((double const *)(tC)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + xmm0 = _mm_loadu_pd((double const *)(tC + ldc)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + + xmm0 = _mm_loadu_pd((double const *)(tC + ldc * 2)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm18 = _mm256_fmadd_pd(ymm0, ymm2, ymm18); + ymm19 = _mm256_fmadd_pd(ymm0, ymm3, ymm19); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + ymm19 = _mm256_permute_pd(ymm19, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + ymm18 = _mm256_addsub_pd(ymm18, ymm19); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm9 = _mm256_add_pd(ymm9, ymm14); + ymm10 = _mm256_add_pd(ymm10, ymm18); + + xmm0 = _mm256_extractf128_pd(ymm8, 0); + _mm_storeu_pd((double *)tC, xmm0); + + tC += ldc; + + xmm0 = _mm256_extractf128_pd(ymm9, 0); + _mm_storeu_pd((double *)tC, xmm0); + + tC += ldc; + xmm0 = _mm256_extractf128_pd(ymm10, 0); + _mm_storeu_pd((double *)tC, xmm0); + } + n_remainder = N - col_idx; + if (n_remainder == 2) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + // clear scratch registers. + + + BLIS_SET_ALL_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. 
+ xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + + + BLIS_SET_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. 
+ xmm0 = _mm_loadu_pd((double const *)(tC)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + xmm0 = _mm_loadu_pd((double const *)(tC + ldc)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm9 = _mm256_add_pd(ymm9, ymm14); + + xmm0 = _mm256_extractf128_pd(ymm8, 0); + _mm_storeu_pd((double *)tC, xmm0); + tC += ldc; + xmm0 = _mm256_extractf128_pd(ymm9, 0); + _mm_storeu_pd((double *)tC, xmm0); + } + if (n_remainder == 1) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + // clear scratch registers. + + BLIS_SET_ALL_YMM_REG_ZEROS + + xmm0 = _mm_setzero_pd(); + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda; + } + + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. 
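+ // No conjugation in this path: the real and imaginary parts of
+ // the single B element are broadcast into ymm2 and ymm3, the lone
+ // A element is loaded into the low half of ymm0, and the two FMAs
+ // accumulate A*Re(B) into ymm8 and A*Im(B) into ymm4; the
+ // permute/addsub after the loop combines them into the complex
+ // product.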
+ ymm2 = _mm256_broadcast_sd(
+ (double const *)
+ (tptr + tb_inc_col * 0));
+ ymm3 = _mm256_broadcast_sd(
+ (double const *)
+ (tptr + tb_inc_col * 0 +
+ 1));
+
+ //broadcasted matrix B elements are
+ //multiplied
+ //with matrix A columns.
+ xmm0 = _mm_loadu_pd((double const *)(tA));
+ ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0);
+
+ ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8);
+ ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);
+
+ tptr += tb_inc_row*2;
+ tA += lda;
+ }
+
+ }
+ ymm4 = _mm256_permute_pd(ymm4, 0x5);
+
+ ymm8 = _mm256_addsub_pd(ymm8, ymm4);
+
+ // alpha, beta multiplication.
+ ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast);
+ ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);
+
+ ymm14 = _mm256_permute_pd(ymm0, 0x5);
+ ymm14 = _mm256_mul_pd(ymm14, ymm1);
+ ymm15 = _mm256_mul_pd(ymm8, ymm0);
+ ymm14 = _mm256_mul_pd(ymm8, ymm14);
+ ymm8 = _mm256_hsub_pd(ymm15, ymm14);
+
+
+
+ BLIS_SET_YMM_REG_ZEROS
+ xmm0 = _mm_setzero_pd();
+
+ ymm2 = _mm256_broadcast_sd((double const *)
+ &beta_cast->real);
+ ymm3 = _mm256_broadcast_sd((double const *)
+ &beta_cast->imag);
+
+ if(is_beta_non_zero)
+ {
+ // multiply C by beta and accumulate col 1.
+ xmm0 = _mm_loadu_pd((double const *)(tC));
+ ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0);
+
+ ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4);
+ ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5);
+ }
+ ymm5 = _mm256_permute_pd(ymm5, 0x5);
+
+ ymm4 = _mm256_addsub_pd(ymm4, ymm5);
+
+ ymm8 = _mm256_add_pd(ymm8, ymm4);
+
+ xmm0 = _mm256_extractf128_pd(ymm8, 0);
+ _mm_storeu_pd((double *)tC, xmm0);
+
+ }
+ }
+ // Return the buffer to the pool
+ if ((required_packing_A == 1) && bli_mem_is_alloc( &local_mem_buf_A_s )) {
+#ifdef BLIS_ENABLE_MEM_TRACING
+ printf( "bli_zgemm_small(): releasing mem pool block\n" );
+#endif
+ bli_membrk_release(&rntm,
+ &local_mem_buf_A_s);
+ }
+ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO);
+ return BLIS_SUCCESS;
+ }
+ else
+ {
+ AOCL_DTL_TRACE_EXIT_ERR(
+ AOCL_DTL_LEVEL_INFO,
+ "Invalid dimensions for small gemm."
+ );
+ return BLIS_NONCONFORMAL_DIMENSIONS;
+ }
+};
+
+err_t bli_zgemm_small_At
+ (
+ obj_t* alpha,
+ obj_t* a,
+ obj_t* b,
+ obj_t* beta,
+ obj_t* c,
+ cntx_t* cntx,
+ cntl_t* cntl
+ )
+{
+ AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO);
+ if (bli_cpuid_is_avx_supported() == FALSE)
+ {
+ return BLIS_NOT_YET_IMPLEMENTED;
+ }
+ bool conjtransa = bli_obj_has_conj(a);
+ bool conjtransb = bli_obj_has_conj(b);
+
+ gint_t M = bli_obj_length( c ); // number of rows of Matrix C
+ gint_t N = bli_obj_width( c ); // number of columns of Matrix C
+ gint_t K = bli_obj_width_after_trans( a ); // number of columns of OP(A)
+
+ if (N<3) // Implementation assumes that N is at least 3.
+ {
+ AOCL_DTL_TRACE_EXIT_ERR(
+ AOCL_DTL_LEVEL_INFO,
+ "N < 3, cannot be processed by small gemm"
+ );
+ return BLIS_NOT_YET_IMPLEMENTED;
+ }
+
+ if( M && N && K )
+ {
+ guint_t lda = bli_obj_col_stride( a ); // column stride of matrix OP(A)
+ guint_t ldb = bli_obj_col_stride( b ); // column stride of matrix OP(B)
+ guint_t ldc = bli_obj_col_stride( c ); // column stride of matrix C
+ guint_t row_idx, col_idx, k;
+ dcomplex *A = bli_obj_buffer_at_off(a); //pointer to elements of Matrix A
+ dcomplex *B = bli_obj_buffer_at_off(b); //pointer to elements of Matrix B
+ dcomplex *C = bli_obj_buffer_at_off(c); //pointer to elements of Matrix C
+
+ dcomplex *tA = A, *tB = B, *tC = C;//, *tA_pack;
+ dcomplex *tA_packed; // temporary pointer to hold packed A memory pointer
+ guint_t row_idx_packed; //packed A memory row index
+ guint_t lda_packed; //lda of packed A
+ dim_t tb_inc_row = 1; // row stride of matrix B
+ dim_t tb_inc_col = ldb; // column stride of matrix B
+
+ dcomplex *alpha_cast, *beta_cast; // alpha, beta multiples
+ alpha_cast = bli_obj_buffer_for_1x1(BLIS_DCOMPLEX, alpha);
+ beta_cast = bli_obj_buffer_for_1x1(BLIS_DCOMPLEX, beta);
+
+ gint_t required_packing_A = 1;
+ mem_t local_mem_buf_A_s;
+ dcomplex *D_A_pack = NULL;
+ rntm_t rntm;
+
+ if( bli_obj_has_trans( b ) )
+ {
+ tb_inc_col = 1; // switch row and column strides
+ tb_inc_row = ldb;
+ }
+
+ __m256d ymm4, ymm5, ymm6, ymm7;
+ __m256d ymm8, ymm9, ymm10, ymm11;
+ __m256d ymm12, ymm13, ymm14, ymm15;
+ __m256d ymm16, ymm17, ymm18, ymm19, ymm20, ymm21;
+ __m256d ymm0, ymm1, ymm2, ymm3;
+
+ gint_t n_remainder; // Set when N is not a multiple of 3 (N % 3)
+ gint_t m_remainder; // Set when M is not a multiple of Z_MR (M % Z_MR)
+
+ //checking whether beta value is zero.
+ //if true, we should perform C=alpha * A*B operation
+ //instead of C = beta * C + alpha * (A * B)
+ bool is_beta_non_zero = 0;
+ if(!bli_obj_equals(beta, &BLIS_ZERO))
+ is_beta_non_zero = 1;
+
+ /*
+ * This function was using a global array to pack part of the A input
+ * when needed. However, using a global array makes the function
+ * non-reentrant. Instead of using a global array we should allocate a
+ * buffer for each invocation.
+ * Since the buffer is too big for the stack and doing a malloc on every
+ * call would be too expensive, the better approach is to get the buffer
+ * from the pre-allocated pool and return it to the pool once we are
+ * done.
+ *
+ * In order to get the buffer from the pool, we need access to the
+ * memory broker. Currently this function is not invoked in such a way
+ * that it can receive the memory broker (via rntm), so the following
+ * hack gets the global memory broker, which can then be used to access
+ * the pool.
+ *
+ * Note there will be a memory allocation at least on the first
+ * invocation, as there will not yet be any pool created for this size.
+ * Subsequent invocations will just reuse the buffer from the pool.
+ */
+
+ bli_rntm_init_from_global( &rntm );
+ bli_rntm_set_num_threads_only( 1, &rntm );
+ bli_membrk_rntm_set_membrk( &rntm );
+
+ // Get the current size of the buffer pool for A block packing.
+ // We will use the same size to avoid pool re-initialization.
+ siz_t buffer_size = bli_pool_block_size(
+ bli_membrk_pool(bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK),
+ bli_rntm_membrk(&rntm)));
+
+ //
+ // This kernel assumes that "A" will be unpacked if N <= 3.
+ // Usually this range (N <= 3) is handled by SUP, however, + // if SUP is disabled or for any other condition if we do + // enter this kernel with N <= 3, we want to make sure that + // "A" remains unpacked. + // + // If this check is removed it will result in the crash as + // reported in CPUPL-587. + // + + if ((N < 3) || ((Z_MR * K) << 4) > buffer_size) + { + required_packing_A = 0; + return BLIS_NOT_YET_IMPLEMENTED; + } + + if (required_packing_A == 1) + { +#ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_zgemm_small_At: Requesting mem pool block of size %lu\n", + buffer_size); +#endif + // Get the buffer from the pool. + bli_membrk_acquire_m(&rntm, + buffer_size, + BLIS_BITVAL_BUFFER_FOR_A_BLOCK, + &local_mem_buf_A_s); + + D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); + } + + /* + * The computation loop runs for D_MRxN columns of C matrix, thus + * accessing the D_MRxK A matrix data and KxNR B matrix data. + * The computation is organized as inner loops of dimension D_MRxNR. + */ + // Process D_MR rows of C matrix at a time. + for (row_idx = 0; (row_idx + (Z_MR - 1)) < M; row_idx += Z_MR) + { + + tA = A + row_idx * lda; + tA_packed = D_A_pack; + lda_packed = Z_MR; + + // Pack 16xk of matrix A into buffer + // continuous access for A and strided stores to B + for(inc_t x = 0; (x) < 2; x += 1) + { + dcomplex* tA_temp = tA; + + for(k = 0; (k+1) < K; k += 2) + { + ymm0 = _mm256_loadu_pd((double const *) + (tA_temp + 0 * lda)); + ymm2 = _mm256_loadu_pd((double const *) + (tA_temp + 1 * lda)); + + ymm6 = _mm256_permute2f128_pd(ymm0,ymm2,0x20); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm2,0x31); + + _mm256_storeu_pd((double *) + (tA_packed + 0 * lda_packed), + ymm6); + _mm256_storeu_pd((double *) + (tA_packed + 1 * lda_packed), + ymm7); + + tA_temp += 2; + tA_packed += 2 * lda_packed; + } + + for(; k < K; k += 1) + { + tA_packed[0].real = tA_temp[0 * lda].real; + tA_packed[0].imag = tA_temp[0 * lda].imag; + tA_packed[1].real = tA_temp[1 * lda].real; + tA_packed[1].imag = tA_temp[1 * lda].imag; + + tA_temp += 1; + tA_packed += lda_packed; + } + + tA += 2 * lda; + tA_packed = D_A_pack + (x + 1)*2; + } + + tA_packed = D_A_pack; + row_idx_packed = 0; + lda_packed = Z_MR; + + // Process NR columns of C matrix at a time. + for (col_idx = 0; (col_idx + (NR - 1)) < N; col_idx += NR) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + +#ifdef BLIS_ENABLE_PREFETCH + _mm_prefetch((char*)(tC + 0), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 8), _MM_HINT_T0); + _mm_prefetch((char*)(tC + ldc), _MM_HINT_T0); + _mm_prefetch((char*)(tC + ldc + 8), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 2 * ldc + 8), _MM_HINT_T0); +#endif + // clear scratch registers. + + BLIS_SET_ALL_YMM_REG_ZEROS + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. 
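/*
 * Editor's sketch (not applied by this patch): the Z_MR x NR macro-loop
 * above computes, tile by tile, C := beta*C + alpha*op(A)*op(B) for
 * double-complex data, where op(A) is the transposed (and optionally
 * conjugated) A and op(B) is the optionally conjugated B. A minimal
 * scalar reference of that update, assuming the strides used in this
 * function (lda, ldc, tb_inc_row, tb_inc_col) and the BLIS dcomplex,
 * dim_t and bool types; the helper name is illustrative only:
 */
static void zgemm_At_scalar_ref( dim_t M, dim_t N, dim_t K,
                                 dcomplex alpha, const dcomplex* A, dim_t lda,
                                 const dcomplex* B, dim_t tb_inc_row, dim_t tb_inc_col,
                                 dcomplex beta, dcomplex* C, dim_t ldc,
                                 bool conja, bool conjb, bool beta_non_zero )
{
    for ( dim_t j = 0; j < N; j++ )
    for ( dim_t i = 0; i < M; i++ )
    {
        dcomplex acc = { 0.0, 0.0 };
        for ( dim_t k = 0; k < K; k++ )
        {
            dcomplex a = A[ k + i * lda ];                     /* op(A)(i,k) = A(k,i) */
            dcomplex b = B[ k * tb_inc_row + j * tb_inc_col ]; /* op(B)(k,j)          */
            if ( conja ) a.imag = -a.imag;
            if ( conjb ) b.imag = -b.imag;
            acc.real += a.real * b.real - a.imag * b.imag;
            acc.imag += a.real * b.imag + a.imag * b.real;
        }
        dcomplex out;
        out.real = alpha.real * acc.real - alpha.imag * acc.imag;
        out.imag = alpha.real * acc.imag + alpha.imag * acc.real;
        if ( beta_non_zero )
        {
            dcomplex c = C[ i + j * ldc ];
            out.real += beta.real * c.real - beta.imag * c.imag;
            out.imag += beta.real * c.imag + beta.imag * c.real;
        }
        C[ i + j * ldc ] = out;
    }
}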
+ ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. 
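/*
 * Editor's sketch (not applied by this patch): the conjtransa/conjtransb
 * branches above never form an explicit conjugate; they only multiply by
 * sign masks. ymm20 = {1,-1,1,-1} flips the imaginary part of every packed
 * A element (conj(A)), and ymm21 = {-1,-1,-1,-1} negates the broadcast
 * imaginary part of a B element (conj(B)) before it enters the FMA chain.
 * A minimal standalone form of the conj(A) step (helper name is
 * illustrative):
 */
#include <immintrin.h>

static inline __m256d zconj_2x_dcomplex( __m256d a /* {ar0, ai0, ar1, ai1} */ )
{
    const __m256d sign = _mm256_setr_pd( 1.0, -1.0, 1.0, -1.0 );
    return _mm256_mul_pd( a, sign );   /* {ar0, -ai0, ar1, -ai1} */
}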
+ ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm14 = _mm256_permute_pd(ymm14, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm11 = _mm256_addsub_pd(ymm11, ymm5); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + ymm12 = _mm256_addsub_pd(ymm12, ymm7); + ymm10 = _mm256_addsub_pd(ymm10, ymm14); + ymm13 = _mm256_addsub_pd(ymm13, ymm15); + + // alpha, beta multiplication. 
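/*
 * Editor's sketch (not applied by this patch): each accumulator pair above
 * (e.g. ymm8/ymm4) holds sum_k a*b_real and sum_k a*b_imag respectively,
 * with a = {ar, ai, ...}. Swapping the lanes of the imaginary accumulator
 * with _mm256_permute_pd(..., 0x5) and combining via _mm256_addsub_pd
 * produces the accumulated complex products {ar*br - ai*bi, ai*br + ar*bi}
 * in each 128-bit half. Condensed to a single product (helper name is
 * illustrative; assumes <immintrin.h> as above):
 */
static inline __m256d zmul_2x_dcomplex( __m256d a, double br, double bi )
{
    __m256d acc_r = _mm256_mul_pd( a, _mm256_set1_pd( br ) ); /* {ar*br, ai*br, ...}             */
    __m256d acc_i = _mm256_mul_pd( a, _mm256_set1_pd( bi ) ); /* {ar*bi, ai*bi, ...}             */
    acc_i = _mm256_permute_pd( acc_i, 0x5 );                  /* {ai*bi, ar*bi, ...}             */
    return _mm256_addsub_pd( acc_r, acc_i );                  /* {ar*br-ai*bi, ai*br+ar*bi, ...} */
}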
+ ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm10, ymm0); + ymm14 = _mm256_mul_pd(ymm10, ymm14); + ymm10 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm11, ymm0); + ymm14 = _mm256_mul_pd(ymm11, ymm14); + ymm11 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm12, ymm0); + ymm14 = _mm256_mul_pd(ymm12, ymm14); + ymm12 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm13, ymm0); + ymm14 = _mm256_mul_pd(ymm13, ymm14); + ymm13 = _mm256_hsub_pd(ymm15, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + (&beta_cast->imag)); + + + + BLIS_SET_YMM_REG_ZEROS + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + ymm0 = _mm256_loadu_pd((double const *)tC); + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + ymm0 = _mm256_loadu_pd((double const *)(tC + 2)); + ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6); + ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7); + + // col 2 + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc)); + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc + 2)); + ymm16 = _mm256_fmadd_pd(ymm0, ymm2, ymm16); + ymm17 = _mm256_fmadd_pd(ymm0, ymm3, ymm17); + + // col 3 + ymm0 = _mm256_loadu_pd((double const *) + (tC + (ldc * 2))); + ymm18 = _mm256_fmadd_pd(ymm0, ymm2, ymm18); + ymm19 = _mm256_fmadd_pd(ymm0, ymm3, ymm19); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + (ldc * 2) + 2)); + ymm20 = _mm256_fmadd_pd(ymm0, ymm2, ymm20); + ymm21 = _mm256_fmadd_pd(ymm0, ymm3, ymm21); + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + ymm17 = _mm256_permute_pd(ymm17, 0x5); + ymm19 = _mm256_permute_pd(ymm19, 0x5); + ymm21 = _mm256_permute_pd(ymm21, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm6 = _mm256_addsub_pd(ymm6, ymm7); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + ymm16 = _mm256_addsub_pd(ymm16, ymm17); + ymm18 = _mm256_addsub_pd(ymm18, ymm19); + ymm20 = _mm256_addsub_pd(ymm20, ymm21); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm11 = _mm256_add_pd(ymm11, ymm6); + ymm9 = _mm256_add_pd(ymm9, ymm14); + ymm12 = _mm256_add_pd(ymm12, ymm16); + ymm10 = _mm256_add_pd(ymm10, ymm18); + ymm13 = _mm256_add_pd(ymm13, ymm20); + + _mm256_storeu_pd((double *)tC, ymm8); + _mm256_storeu_pd((double *)(tC + 2), ymm11); + + tC += ldc; + + _mm256_storeu_pd((double *)tC, ymm9); + _mm256_storeu_pd((double *)(tC + 2), ymm12); + + tC += ldc; + + _mm256_storeu_pd((double *)tC, ymm10); + _mm256_storeu_pd((double *)(tC + 2), ymm13); + + } + n_remainder = N - col_idx; + + // if the N is not multiple of 3. 
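/*
 * Editor's sketch (not applied by this patch): the alpha scaling above is
 * a broadcast/hsub pattern. With alpha broadcast as {ar, ai, ar, ai} and
 * an accumulator x = {xr0, xi0, xr1, xi1}, it forms
 *     t0 = x * {ar,  ai, ar,  ai}  = {xr*ar,  xi*ai, ...}
 *     t1 = x * {ai, -ar, ai, -ar}  = {xr*ai, -xi*ar, ...}
 * and _mm256_hsub_pd(t0, t1) = {xr*ar - xi*ai, xr*ai + xi*ar, ...},
 * i.e. alpha times each dcomplex element. The N%3 edge cases that follow
 * reuse the same pattern on fewer accumulators. Condensed (helper name is
 * illustrative; assumes <immintrin.h> and the BLIS dcomplex type):
 */
static inline __m256d zscale_2x_dcomplex( __m256d x, const dcomplex* alpha )
{
    __m256d va   = _mm256_broadcast_pd( (__m128d const*)alpha );        /* {ar, ai, ar, ai}   */
    __m256d sign = _mm256_setr_pd( 1.0, -1.0, 1.0, -1.0 );
    __m256d vas  = _mm256_mul_pd( _mm256_permute_pd( va, 0x5 ), sign ); /* {ai, -ar, ai, -ar} */
    return _mm256_hsub_pd( _mm256_mul_pd( x, va ),
                           _mm256_mul_pd( x, vas ) );
}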
+ // handling edge case. + if (n_remainder == 2) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + + // clear scratch registers. + + + BLIS_SET_ALL_YMM_REG_ZEROS + double *tptr = (double *)tB; + + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const*)tA); + ymm1 = _mm256_loadu_pd((double const*) + (tA + 2)); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda_packed; + + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const*)tA); + ymm1 = _mm256_loadu_pd((double const*) + (tA + 2)); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. 
+ ymm0 = _mm256_loadu_pd((double const*)tA); + ymm1 = _mm256_loadu_pd((double const*) + (tA + 2)); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const*)tA); + ymm1 = _mm256_loadu_pd((double const*) + (tA + 2)); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + + + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm11 = _mm256_addsub_pd(ymm11, ymm5); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + ymm12 = _mm256_addsub_pd(ymm12, ymm7); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm11, ymm0); + ymm14 = _mm256_mul_pd(ymm11, ymm14); + ymm11 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm12, ymm0); + ymm14 = _mm256_mul_pd(ymm12, ymm14); + ymm12 = _mm256_hsub_pd(ymm15, ymm14); + + + + BLIS_SET_YMM_REG_ZEROS + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. 
+ ymm0 = _mm256_loadu_pd((double const *)tC); + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + ymm0 = _mm256_loadu_pd((double const *)(tC + 2)); + ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6); + ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7); + + ymm0 = _mm256_loadu_pd((double const *)(tC + ldc)); + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc + 2)); + ymm16 = _mm256_fmadd_pd(ymm0, ymm2, ymm16); + ymm17 = _mm256_fmadd_pd(ymm0, ymm3, ymm17); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + ymm17 = _mm256_permute_pd(ymm17, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm6 = _mm256_addsub_pd(ymm6, ymm7); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + ymm16 = _mm256_addsub_pd(ymm16, ymm17); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm11 = _mm256_add_pd(ymm11, ymm6); + ymm9 = _mm256_add_pd(ymm9, ymm14); + ymm12 = _mm256_add_pd(ymm12, ymm16); + + _mm256_storeu_pd((double *)(tC + 0), ymm8); + _mm256_storeu_pd((double *)(tC + 2), ymm11); + tC += ldc; + _mm256_storeu_pd((double *)tC, ymm9); + _mm256_storeu_pd((double *)(tC + 2), ymm12); + } + // if the N is not multiple of 3. + // handling edge case. + if (n_remainder == 1) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + + // clear scratch registers. + BLIS_SET_ALL_YMM_REG_ZEROS + double *tptr = (double *)tB; + + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + tptr += tb_inc_row*2; + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *)(tA + 2)); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tA += lda_packed; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd((double const *)(tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd((double const *)(tptr + tb_inc_col * 0 + 1)); + tptr += tb_inc_row*2; + + //broadcasted matrix B elements are multiplied + //with matrix A columns. 
+ ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tA += lda_packed; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + tptr += tb_inc_row*2; + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tA += lda_packed; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + tptr += tb_inc_row*2; + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tA += lda_packed; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm5 = _mm256_permute_pd(ymm5, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm11 = _mm256_addsub_pd(ymm11, ymm5); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm11, ymm0); + ymm14 = _mm256_mul_pd(ymm11, ymm14); + ymm11 = _mm256_hsub_pd(ymm15, ymm14); + + + + BLIS_SET_YMM_REG_ZEROS + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. 
+ ymm0 = _mm256_loadu_pd((double const *)tC); + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + ymm0 = _mm256_loadu_pd((double const *)(tC + 2)); + ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6); + ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7); + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm6 = _mm256_addsub_pd(ymm6, ymm7); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm11 = _mm256_add_pd(ymm11, ymm6); + + _mm256_storeu_pd((double *)tC, ymm8); + _mm256_storeu_pd((double *)(tC + 2), ymm11); + } + } + + m_remainder = M - row_idx; + if ((m_remainder == 3)) + { + m_remainder -= 3; + __m128d xmm0; + + tA = A + row_idx * lda; + tA_packed = D_A_pack; + lda_packed = 3; + { + dcomplex* tA_temp = tA; + + for(k = 0; (k+1) < K; k += 2) + { + ymm0 = _mm256_loadu_pd((double const *) + (tA_temp + 0 * lda)); + ymm2 = _mm256_loadu_pd((double const *) + (tA_temp + 1 * lda)); + ymm3 = _mm256_loadu_pd((double const *) + (tA_temp + 2 * lda)); + + ymm6 = _mm256_permute2f128_pd(ymm0,ymm2,0x20); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm2,0x31); + + _mm256_storeu_pd((double *) + (tA_packed + 0 * lda_packed), + ymm6); + xmm0 = _mm256_extractf128_pd(ymm3, 0); + _mm_storeu_pd((double *) + (tA_packed + 0 * lda_packed + 2), + xmm0); + + _mm256_storeu_pd((double *) + (tA_packed + 1 * lda_packed), + ymm7); + xmm0 = _mm256_extractf128_pd(ymm3, 1); + _mm_storeu_pd((double *) + (tA_packed + 1 * lda_packed + 2), + xmm0); + + tA_temp += 2; + tA_packed += 2 * lda_packed; + } + + for(; k < K; k += 1) + { + tA_packed[0].real = tA_temp[0 * lda].real; + tA_packed[0].imag = tA_temp[0 * lda].imag; + tA_packed[1].real = tA_temp[1 * lda].real; + tA_packed[1].imag = tA_temp[1 * lda].imag; + tA_packed[2].real = tA_temp[2 * lda].real; + tA_packed[2].imag = tA_temp[2 * lda].imag; + + tA_temp += 1; + tA_packed += lda_packed; + } + } + + tA_packed = D_A_pack; + row_idx_packed = 0; + lda_packed = 3; + + for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + + + BLIS_SET_ALL_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. 
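/*
 * Editor's sketch (not applied by this patch): the packing loops above
 * (Z_MR rows in the main case, 3 or 2 rows in the remainder cases) gather
 * rows of op(A) so the inner GEMM loop can read them contiguously with
 * stride lda_packed. Since op(A) is the transpose of the stored A, a pair
 * of stored columns is transposed as a 2x2 block of dcomplex values using
 * _mm256_permute2f128_pd. A minimal form of that 2x2 step (helper name is
 * illustrative; assumes <immintrin.h> and the BLIS types):
 */
static inline void zpack_2x2_dcomplex( const dcomplex* a, dim_t lda,
                                       dcomplex* p, dim_t ldp )
{
    /* a(k,i), a(k+1,i) are contiguous in stored column i; column i+1 is lda away. */
    __m256d c0 = _mm256_loadu_pd( (double const*)( a + 0 * lda ) );
    __m256d c1 = _mm256_loadu_pd( (double const*)( a + 1 * lda ) );
    /* interleave so each packed row holds one k for both i and i+1 */
    __m256d r0 = _mm256_permute2f128_pd( c0, c1, 0x20 ); /* a(k,i),   a(k,i+1)   */
    __m256d r1 = _mm256_permute2f128_pd( c0, c1, 0x31 ); /* a(k+1,i), a(k+1,i+1) */
    _mm256_storeu_pd( (double*)( p + 0 * ldp ), r0 );
    _mm256_storeu_pd( (double*)( p + 1 * ldp ), r1 );
}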
+ ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. 
+ // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. 
+ ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm14 = _mm256_permute_pd(ymm14, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm11 = _mm256_addsub_pd(ymm11, ymm5); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + ymm12 = _mm256_addsub_pd(ymm12, ymm7); + ymm10 = _mm256_addsub_pd(ymm10, ymm14); + ymm13 = _mm256_addsub_pd(ymm13, ymm15); + // alpha, beta multiplication. + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm10, ymm0); + ymm14 = _mm256_mul_pd(ymm10, ymm14); + ymm10 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm11, ymm0); + ymm14 = _mm256_mul_pd(ymm11, ymm14); + ymm11 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm12, ymm0); + ymm14 = _mm256_mul_pd(ymm12, ymm14); + ymm12 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm13, ymm0); + ymm14 = _mm256_mul_pd(ymm13, ymm14); + ymm13 = _mm256_hsub_pd(ymm15, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + + BLIS_SET_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. 
+ ymm0 = _mm256_loadu_pd((double const *)tC); + xmm0 = _mm_loadu_pd((double const *)(tC + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + ymm6 = _mm256_fmadd_pd(ymm1, ymm2, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc)); + xmm0 = _mm_loadu_pd((double const *) + (tC + ldc + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + ymm16 = _mm256_fmadd_pd(ymm1, ymm2, ymm16); + ymm17 = _mm256_fmadd_pd(ymm1, ymm3, ymm17); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc * 2)); + xmm0 = _mm_loadu_pd((double const *) + (tC + ldc * 2 + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm18 = _mm256_fmadd_pd(ymm0, ymm2, ymm18); + ymm19 = _mm256_fmadd_pd(ymm0, ymm3, ymm19); + ymm20 = _mm256_fmadd_pd(ymm1, ymm2, ymm20); + ymm21 = _mm256_fmadd_pd(ymm1, ymm3, ymm21); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + ymm17 = _mm256_permute_pd(ymm17, 0x5); + ymm19 = _mm256_permute_pd(ymm19, 0x5); + ymm21 = _mm256_permute_pd(ymm21, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm6 = _mm256_addsub_pd(ymm6, ymm7); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + ymm16 = _mm256_addsub_pd(ymm16, ymm17); + ymm18 = _mm256_addsub_pd(ymm18, ymm19); + ymm20 = _mm256_addsub_pd(ymm20, ymm21); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm11 = _mm256_add_pd(ymm11, ymm6); + ymm9 = _mm256_add_pd(ymm9, ymm14); + ymm12 = _mm256_add_pd(ymm12, ymm16); + ymm10 = _mm256_add_pd(ymm10, ymm18); + ymm13 = _mm256_add_pd(ymm13, ymm20); + + _mm256_storeu_pd((double *)tC, ymm8); + xmm0 = _mm256_extractf128_pd(ymm11, 0); + _mm_storeu_pd((double *)(tC + 2), xmm0); + + tC += ldc; + + _mm256_storeu_pd((double *)tC, ymm9); + xmm0 = _mm256_extractf128_pd(ymm12, 0); + _mm_storeu_pd((double *)(tC + 2), xmm0); + + tC += ldc; + + _mm256_storeu_pd((double *)tC, ymm10); + xmm0 = _mm256_extractf128_pd(ymm13, 0); + _mm_storeu_pd((double *)(tC + 2), xmm0); + } + n_remainder = N - col_idx; + if (n_remainder == 2) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + + // clear scratch registers. + BLIS_SET_ALL_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd((tptr + + tb_inc_col + * 0)); + ymm3 = _mm256_broadcast_sd((tptr + + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. 
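/*
 * Editor's sketch (not applied by this patch): with only three dcomplex
 * rows of C left, every load/store above is split into one full 256-bit
 * access for rows 0..1 plus one 128-bit access (via insertf128/extractf128)
 * for row 2, so no memory beyond the third element is touched. Condensed
 * (helper names are illustrative; assumes <immintrin.h> and the BLIS types):
 */
static inline void zload_3_dcomplex( const dcomplex* c, __m256d* lo, __m256d* hi )
{
    *lo = _mm256_loadu_pd( (double const*)c );                   /* c[0], c[1] */
    __m128d x = _mm_loadu_pd( (double const*)( c + 2 ) );        /* c[2]       */
    *hi = _mm256_insertf128_pd( _mm256_setzero_pd(), x, 0 );     /* c[2], zero */
}

static inline void zstore_3_dcomplex( dcomplex* c, __m256d lo, __m256d hi )
{
    _mm256_storeu_pd( (double*)c, lo );                          /* c[0], c[1] */
    _mm_storeu_pd( (double*)( c + 2 ), _mm256_extractf128_pd( hi, 0 ) );
}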
+ ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd((tptr + + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd((tptr + + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd((tptr + + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd((tptr + + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. 
+ ymm2 = _mm256_broadcast_sd((tptr + + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd((tptr + + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm11 = _mm256_addsub_pd(ymm11, ymm5); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + ymm12 = _mm256_addsub_pd(ymm12, ymm7); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm11, ymm0); + ymm14 = _mm256_mul_pd(ymm11, ymm14); + ymm11 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm12, ymm0); + ymm14 = _mm256_mul_pd(ymm12, ymm14); + ymm12 = _mm256_hsub_pd(ymm15, ymm14); + + + BLIS_SET_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. 
+ ymm0 = _mm256_loadu_pd((double const *)tC); + xmm0 = _mm_loadu_pd((double const *)(tC + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + ymm6 = _mm256_fmadd_pd(ymm1, ymm2, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc)); + xmm0 = _mm_loadu_pd((double const *) + (tC + ldc + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + ymm16 = _mm256_fmadd_pd(ymm1, ymm2, ymm16); + ymm17 = _mm256_fmadd_pd(ymm1, ymm3, ymm17); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + ymm17 = _mm256_permute_pd(ymm17, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm6 = _mm256_addsub_pd(ymm6, ymm7); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + ymm16 = _mm256_addsub_pd(ymm16, ymm17); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm11 = _mm256_add_pd(ymm11, ymm6); + ymm9 = _mm256_add_pd(ymm9, ymm14); + ymm12 = _mm256_add_pd(ymm12, ymm16); + + _mm256_storeu_pd((double *)tC, ymm8); + xmm0 = _mm256_extractf128_pd(ymm11, 0); + _mm_storeu_pd((double *)(tC + 2), xmm0); + + tC += ldc; + _mm256_storeu_pd((double *)tC, ymm9); + xmm0 = _mm256_extractf128_pd(ymm12, 0); + _mm_storeu_pd((double *)(tC + 2), xmm0); + } + if (n_remainder == 1) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + + // clear scratch registers. + + BLIS_SET_ALL_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. 
+ ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm5 = _mm256_permute_pd(ymm5, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm11 = _mm256_addsub_pd(ymm11, ymm5); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm11, ymm0); + ymm14 = _mm256_mul_pd(ymm11, ymm14); + ymm11 = _mm256_hsub_pd(ymm15, ymm14); + + + BLIS_SET_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. 
+ ymm0 = _mm256_loadu_pd((double const *)tC); + xmm0 = _mm_loadu_pd((double const *)(tC + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + ymm6 = _mm256_fmadd_pd(ymm1, ymm2, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm6 = _mm256_addsub_pd(ymm6, ymm7); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm11 = _mm256_add_pd(ymm11, ymm6); + + _mm256_storeu_pd((double *)tC, ymm8); + xmm0 = _mm256_extractf128_pd(ymm11, 0); + _mm_storeu_pd((double *)(tC + 2), xmm0); + } + } + if ((m_remainder == 2)) + { + m_remainder -= 2; + + tA = A + row_idx * lda; + tA_packed = D_A_pack; + lda_packed = 2; + + { + dcomplex* tA_temp = tA; + + for(k = 0; (k+1) < K; k += 2) + { + ymm0 = _mm256_loadu_pd((double const *) + (tA_temp + 0 * lda)); + ymm2 = _mm256_loadu_pd((double const *) + (tA_temp + 1 * lda)); + + ymm6 = _mm256_permute2f128_pd(ymm0,ymm2,0x20); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm2,0x31); + + _mm256_storeu_pd((double *) + (tA_packed + 0 * lda_packed), + ymm6); + _mm256_storeu_pd((double *) + (tA_packed + 1 * lda_packed), + ymm7); + + tA_temp += 2; + tA_packed += 2 * lda_packed; + } + + for(; k < K; k += 1) + { + tA_packed[0].real = tA_temp[0 * lda].real; + tA_packed[0].imag = tA_temp[0 * lda].imag; + tA_packed[1].real = tA_temp[1 * lda].real; + tA_packed[1].imag = tA_temp[1 * lda].imag; + + tA_temp += 1; + tA_packed += lda_packed; + } + } + + tA_packed = D_A_pack; + row_idx_packed = 0; + lda_packed = 2; + + for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + + + + BLIS_SET_ALL_YMM_REG_ZEROS + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. 
+ ymm0 = _mm256_loadu_pd((double const *)tA); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. 
+ ymm0 = _mm256_loadu_pd((double const *)tA); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + ymm14 = _mm256_permute_pd(ymm14, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + ymm10 = _mm256_addsub_pd(ymm10, ymm14); + // alpha, beta multiplication. + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm10, ymm0); + ymm14 = _mm256_mul_pd(ymm10, ymm14); + ymm10 = _mm256_hsub_pd(ymm15, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + BLIS_SET_YMM_REG_ZEROS + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. 
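+ // beta*C is a full complex multiply: the real
+ // and imaginary parts of beta (ymm2, ymm3) are
+ // fmadd'ed with each loaded C vector, and the
+ // imaginary contribution is folded in below
+ // via permute + addsub.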
+ ymm0 = _mm256_loadu_pd((double const *)tC); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc)); + + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc * 2)); + + ymm18 = _mm256_fmadd_pd(ymm0, ymm2, ymm18); + ymm19 = _mm256_fmadd_pd(ymm0, ymm3, ymm19); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + ymm19 = _mm256_permute_pd(ymm19, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + ymm18 = _mm256_addsub_pd(ymm18, ymm19); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm9 = _mm256_add_pd(ymm9, ymm14); + ymm10 = _mm256_add_pd(ymm10, ymm18); + + _mm256_storeu_pd((double *)tC, ymm8); + + tC += ldc; + + _mm256_storeu_pd((double *)tC, ymm9); + + tC += ldc; + + _mm256_storeu_pd((double *)tC, ymm10); + } + n_remainder = N - col_idx; + if (n_remainder == 2) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + + + // clear scratch registers. + + BLIS_SET_ALL_YMM_REG_ZEROS + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. 
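+ // conj(B) case: the broadcast imaginary
+ // parts of B are negated with ymm21
+ // before the fmadds.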
+ ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + BLIS_SET_YMM_REG_ZEROS + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + ymm0 = _mm256_loadu_pd((double const *)tC); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + ymm0 = _mm256_loadu_pd((double const *)(tC + ldc)); + + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm9 = _mm256_add_pd(ymm9, ymm14); + + _mm256_storeu_pd((double *)tC, ymm8); + tC += ldc; + _mm256_storeu_pd((double *)tC, ymm9); + } + if (n_remainder == 1) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + + // clear scratch registers. 
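+ // Single remaining column: one accumulator pair is
+ // enough per k iteration (ymm8 collects products with
+ // the real part of B, ymm4 with the imaginary part).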
+ + BLIS_SET_ALL_YMM_REG_ZEROS + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matri + // x data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matri + // x data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + + BLIS_SET_YMM_REG_ZEROS + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. 
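+ // (If beta is zero this block is skipped; ymm4/ymm5
+ // remain cleared by BLIS_SET_YMM_REG_ZEROS above, so
+ // the addsub/add below leave alpha*A*B unchanged.)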
+ ymm0 = _mm256_loadu_pd((double const *)tC); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + + _mm256_storeu_pd((double *)tC, ymm8); + } + } + if ((m_remainder == 1)) + { + m_remainder -= 1; + __m128d xmm0; + + tA = A + row_idx * lda; + tA_packed = D_A_pack; + lda_packed = 1; + + { + dcomplex* tA_temp = tA; + + for(k = 0; (k+1) < K; k += 2) + { + ymm0 = _mm256_loadu_pd((double const *) + (tA_temp + 0 * lda)); + + xmm0 = _mm256_extractf128_pd(ymm0, 0); + _mm_storeu_pd((double *) + (tA_packed + 0 * lda_packed), + xmm0); + + xmm0 = _mm256_extractf128_pd(ymm0, 1); + _mm_storeu_pd((double *)(tA_packed + 1 + * lda_packed), xmm0); + + tA_temp += 2; + tA_packed += 2 * lda_packed; + } + + for(; k < K; k += 1) + { + tA_packed[0].real = tA_temp[0 * lda].real; + tA_packed[0].imag = tA_temp[0 * lda].imag; + + tA_temp += 1; + tA_packed += lda_packed; + } + } + + tA_packed = D_A_pack; + row_idx_packed = 0; + lda_packed = 1; + + for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + + + BLIS_SET_ALL_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. 
+ // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. 
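+ // m_remainder == 1: a single dcomplex of A is loaded
+ // with a 128-bit load into the lower half of ymm0;
+ // only the lower 128 bits of each result are stored
+ // back to C later.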
+ xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + ymm14 = _mm256_permute_pd(ymm14, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + ymm10 = _mm256_addsub_pd(ymm10, ymm14); + // alpha, beta multiplication. + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm10, ymm0); + ymm14 = _mm256_mul_pd(ymm10, ymm14); + ymm10 = _mm256_hsub_pd(ymm15, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + BLIS_SET_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + xmm0 = _mm_loadu_pd((double const *)(tC)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + xmm0 = _mm_loadu_pd((double const *)(tC + ldc)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + + xmm0 = _mm_loadu_pd((double const *) + (tC + ldc * 2)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm18 = _mm256_fmadd_pd(ymm0, ymm2, ymm18); + ymm19 = _mm256_fmadd_pd(ymm0, ymm3, ymm19); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + ymm19 = _mm256_permute_pd(ymm19, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + ymm18 = _mm256_addsub_pd(ymm18, ymm19); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm9 = _mm256_add_pd(ymm9, ymm14); + ymm10 = _mm256_add_pd(ymm10, ymm18); + + xmm0 = _mm256_extractf128_pd(ymm8, 0); + _mm_storeu_pd((double *)tC, xmm0); + + tC += ldc; + + xmm0 = _mm256_extractf128_pd(ymm9, 0); + _mm_storeu_pd((double *)tC, xmm0); + + tC += ldc; + xmm0 = _mm256_extractf128_pd(ymm10, 0); + _mm_storeu_pd((double *)tC, xmm0); + } + n_remainder = N - col_idx; + if (n_remainder == 2) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + + // clear scratch registers. 
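+ // Two remaining columns, one remaining row: the same
+ // 128-bit A load feeds both column accumulator pairs
+ // (ymm8/ymm4 and ymm9/ymm6).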
+ + BLIS_SET_ALL_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. 
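+ // No conjugation in this path: the loaded A element
+ // and the broadcast B parts feed the fmadds directly.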
+ xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + + BLIS_SET_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + xmm0 = _mm_loadu_pd((double const *)(tC)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + xmm0 = _mm_loadu_pd((double const *)(tC + ldc)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm9 = _mm256_add_pd(ymm9, ymm14); + + xmm0 = _mm256_extractf128_pd(ymm8, 0); + _mm_storeu_pd((double *)tC, xmm0); + tC += ldc; + xmm0 = _mm256_extractf128_pd(ymm9, 0); + _mm_storeu_pd((double *)tC, xmm0); + } + if (n_remainder == 1) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + + // clear scratch registers. + + BLIS_SET_ALL_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. 
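+ // conj(A) case: the loaded A element is multiplied by
+ // ymm20 = {1,-1,1,-1} to negate its imaginary part
+ // before the fmadds.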
+ ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matri + // x data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + + BLIS_SET_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + xmm0 = _mm_loadu_pd((double const *)(tC)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + + xmm0 = _mm256_extractf128_pd(ymm8, 0); + _mm_storeu_pd((double *)tC, xmm0); + + } + } + // Return the buffer to pool + if ((required_packing_A == 1) && bli_mem_is_alloc( &local_mem_buf_A_s )){ +#ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_zgemm_small_At(): releasing mem pool block\n" ); +#endif + bli_membrk_release(&rntm, + &local_mem_buf_A_s); + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return BLIS_SUCCESS; + } + else + { + AOCL_DTL_TRACE_EXIT_ERR( + AOCL_DTL_LEVEL_INFO, + "Invalid dimesions for dgemm_small_At." 
+ ); + return BLIS_NONCONFORMAL_DIMENSIONS; + } }; #endif diff --git a/kernels/zen/bli_kernels_zen.h b/kernels/zen/bli_kernels_zen.h index bb7b57d096..73dd55b994 100644 --- a/kernels/zen/bli_kernels_zen.h +++ b/kernels/zen/bli_kernels_zen.h @@ -271,6 +271,28 @@ err_t bli_dgemm_small_At cntl_t* cntl ); +err_t bli_zgemm_small + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + cntl_t* cntl + ); + +err_t bli_zgemm_small_At + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + cntl_t* cntl + ); + // gemm square matrix size friendly implementation err_t bli_gemm_sqp ( From cc3069fb5e205a78475384c6414f32e8f6cbbac5 Mon Sep 17 00:00:00 2001 From: Sireesha Sanga Date: Wed, 6 Apr 2022 00:53:27 +0530 Subject: [PATCH 099/243] Performance Improvement for ztrsm small sizes Details: - Optimization of ztrsm for Non-unit Diag Variants. - Handled Overflow and Underflow Vulnerabilites in ztrsm small implementations. - Fixed failures observed in libflame testing. - Fine-tuned ztrsm small implementations for specific sizes 64<= m,n <= 256, by keeping the number of threads to the optimum value, under AOCL_DYNAMIC flag. - For small sizes, ztrsm small implementation is used for all variants. AMD-Internal: [SWLCSG-1194] Change-Id: I066491bb03e5cda390cb699182af4350ae60be2d --- frame/base/bli_rntm.c | 10 ++- frame/compat/bla_trsm_amd.c | 2 - kernels/zen/3/bli_trsm_small.c | 155 ++++++++++++--------------------- 3 files changed, 66 insertions(+), 101 deletions(-) diff --git a/frame/base/bli_rntm.c b/frame/base/bli_rntm.c index c597074f58..f8e00c6208 100644 --- a/frame/base/bli_rntm.c +++ b/frame/base/bli_rntm.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -639,6 +639,14 @@ void bli_nthreads_optimum( if(m<=512 && n<=512) n_threads_ideal = 4; } + else if( family == BLIS_TRSM && bli_obj_is_dcomplex(c)) + { + dim_t m = bli_obj_length(c); + dim_t n = bli_obj_width(c); + + if((m>=64) && (m<=256) && (n>=64) && (n<=256)) + n_threads_ideal = 8; + } else if( family == BLIS_GEMMT && bli_obj_is_double(c) ) { dim_t n = bli_obj_length(c); diff --git a/frame/compat/bla_trsm_amd.c b/frame/compat/bla_trsm_amd.c index eb5c835ff5..9ff8073be0 100644 --- a/frame/compat/bla_trsm_amd.c +++ b/frame/compat/bla_trsm_amd.c @@ -1184,7 +1184,6 @@ void ztrsm_ * is doing better than native multithread */ bool nt = bli_thread_get_is_parallel(); - if((blis_side == BLIS_RIGHT) || (blis_diaga == BLIS_UNIT_DIAG)) { if(((nt==0) && (m0<=500) && (n0<=500)) || (nt && ((m0+n0)<128))) { @@ -1206,7 +1205,6 @@ void ztrsm_ return; } } - } #endif bli_trsmnat diff --git a/kernels/zen/3/bli_trsm_small.c b/kernels/zen/3/bli_trsm_small.c index 32b7647a50..bb6d198c78 100644 --- a/kernels/zen/3/bli_trsm_small.c +++ b/kernels/zen/3/bli_trsm_small.c @@ -5771,68 +5771,58 @@ BLIS_INLINE err_t ztrsm_AuXB_ref * Performs dcomplex division of vec1 and vec2 with ymm1. * vec1 and vec2 gets divided by ymm1 which holds * diagonal element from buffer. - * Function gets called while performing TRSM. + * Using bli_zinvscals() to avoid overflow and underflow + * scenarios. Function gets called while performing TRSM. 
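+ * vec1 and vec2 are spilled to small dcomplex buffers,
+ * divided element-wise by the diagonal value with
+ * bli_zinvscals(), and reloaded; this trades a few extra
+ * loads/stores for a division that is safe against
+ * intermediate overflow/underflow.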
*/ #define BLIS_ZTRSM_TWO_DIV(vec1, vec2) {\ if(!is_unitdiag) {\ if(conjtransa){\ ymm1 = _mm256_mul_pd(ymm1, ymm0);\ }\ - ymm12 = _mm256_mul_pd(ymm1, ymm0);\ - /*perform decomplex multiplication*/\ - /* Switch the real and imaginary elements of vec2 */\ - ymm14 = _mm256_permute_pd(ymm12, 0x5);\ - /* Negate the imaginary elements of vec2 */\ - ymm14 = _mm256_mul_pd(ymm14, ymm0);\ - /* Multiply vec1 and vec2 */ \ - ymm13 = _mm256_mul_pd(vec1, ymm12); /*vec3*/\ - /* Multiply vec1 and the modified vec2 */\ - ymm14 = _mm256_mul_pd(vec1, ymm14); /*vec4*/\ - /* Horizontally subtract the elements in vec3 and vec4 */\ - vec1 = _mm256_hsub_pd(ymm13, ymm14);\ - \ - ymm14 = _mm256_permute_pd(ymm12, 0x5);\ - /* Negate the imaginary elements of vec2 */\ - ymm14 = _mm256_mul_pd(ymm14, ymm0);\ - ymm13 = _mm256_mul_pd(vec2, ymm12);\ - ymm14 = _mm256_mul_pd(vec2, ymm14);\ - vec2 = _mm256_hsub_pd(ymm13, ymm14);\ - /*dcomplex multiplication is done*/\ - /*Swapping real & imaginary component position for addition with respective - * components*/\ - ymm12 = _mm256_mul_pd(ymm1, ymm1);\ - ymm13 = _mm256_permute4x64_pd(ymm12, 0xb1);\ - ymm14 = _mm256_add_pd(ymm12, ymm13);\ - \ - /*Finally dividing numerator by denominator*/\ - vec1 = _mm256_div_pd(vec1, ymm14);\ - vec2 = _mm256_div_pd(vec2, ymm14);\ +\ + dcomplex b_data[4];\ + dcomplex d11_data[2];\ +\ + _mm256_storeu_pd((double *)(b_data), vec1);\ + _mm256_storeu_pd((double *)(b_data + 2), vec2);\ + _mm256_storeu_pd((double *)(d11_data), ymm1);\ +\ + for(dim_t i = 0; i < 4; i++)\ + {\ + bli_zinvscals(d11_data[0],b_data[i]);\ + }\ +\ + vec1 = _mm256_loadu_pd((double *)b_data);\ + vec2 = _mm256_loadu_pd((double *)(b_data+2));\ +\ }\ } /** * Performs dcomplex division of vec1 with ymm1. * ymm1 holds diagonal element from buffer. - * Function gets called while performing TRSM. + * Using bli_zinvscals() to avoid overflow and underflow + * scenarios. Function gets called while performing TRSM. */ #define BLIS_ZTRSM_DIV(vec1) {\ if(!is_unitdiag){\ if(conjtransa){\ ymm1 = _mm256_mul_pd(ymm1, ymm0);\ }\ - ymm12 = _mm256_mul_pd(ymm1, ymm0); /*vec2 and ymm8 is vec1*/\ - ymm14 = _mm256_permute_pd(ymm12, 0x5);\ - ymm14 = _mm256_mul_pd(ymm14, ymm0);\ - ymm13 = _mm256_mul_pd(vec1, ymm12); /*vec3*/\ - ymm14 = _mm256_mul_pd(vec1, ymm14); /*vec4*/\ - vec1 = _mm256_hsub_pd(ymm13, ymm14);\ - \ - ymm12 = _mm256_mul_pd(ymm1, ymm1);\ - ymm13 = _mm256_permute4x64_pd(ymm12, 0xb1);\ - ymm14 = _mm256_add_pd(ymm12, ymm13);\ - \ - /*Finally dividing numerator by denominator*/\ - vec1 = _mm256_div_pd(vec1, ymm14);\ +\ + dcomplex b_data[2];\ + dcomplex d11_data[2];\ +\ + _mm256_storeu_pd((double *)(b_data), vec1);\ + _mm256_storeu_pd((double *)(d11_data), ymm1);\ +\ + for(dim_t i = 0; i < 2; i++)\ + {\ + bli_zinvscals(d11_data[0],b_data[i]);\ + }\ +\ + vec1 = _mm256_loadu_pd((double *)b_data);\ +\ }\ } @@ -6007,7 +5997,6 @@ BLIS_INLINE void bli_ztrsm_small_pack } - BLIS_INLINE void ztrsm_small_pack_diag_element ( bool is_unitdiag, @@ -6018,64 +6007,31 @@ BLIS_INLINE void ztrsm_small_pack_diag_element ) { #ifdef BLIS_ENABLE_TRSM_PREINVERSION - __m256d ymm1, ymm2, ymm3, ymm4, ymm5, ymm6, ymm7, ymm8; - ymm7 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); -#else - __m256d ymm1, ymm2, ymm3; -#endif - bool is_four = (size == 4) ? 1 : 0; - dcomplex ones = {1.0, 1.0}; - ymm2 = ymm1 = _mm256_broadcast_pd((__m128d const *)&ones); - if(!is_unitdiag) + // If Preinversion is enabled, inverse the diaganol + // elements from A and pack into diagonal buffer. 
+ // In order to avoid the overflow and underflow scenarios, + // bli_zinvscals is used + for( dim_t i = 0; i < size; i++) { - //broadcast diagonal elements of A11 - ymm1 = _mm256_broadcast_pd((__m128d const *)a11); - ymm2 = _mm256_broadcast_pd((__m128d const *)a11+ cs_a +1); - /*Pick one element frome each column and create 3 element vector - and store it*/ - ymm1 = _mm256_permute2f128_pd(ymm1, ymm2, 0x20); - ymm2 = _mm256_broadcast_pd((__m128d const *)a11+ cs_a*2 + 2); - - if(is_four) - { - ymm3 = _mm256_broadcast_pd((__m128d const *)a11+ cs_a*2 + 2); - ymm2 = _mm256_broadcast_pd((__m128d const *)a11+ cs_a*3 + 3); - ymm2 = _mm256_permute2f128_pd(ymm3, ymm2, 0x20); - } + dim_t d = ((i*cs_a) + i); + dcomplex ones = {1.0, 0.0}; + bli_zinvscals(a11[d], ones) + d11_pack[i].real = ones.real; + d11_pack[i].imag = ones.imag; + } -#ifdef BLIS_ENABLE_TRSM_PREINVERSION - /*Taking denomerator multiplication of real & imaginary components*/ - ymm4 = _mm256_mul_pd(ymm1, ymm1); - ymm5 = _mm256_mul_pd(ymm2,ymm2); - /*Swapping real & imaginary component position for addition with - * respective components*/ - ymm6 = _mm256_permute4x64_pd(ymm4, 0xb1); - ymm4 = _mm256_add_pd(ymm4, ymm6); - ymm8 = _mm256_permute4x64_pd(ymm5, 0xb1); - - ymm5 = _mm256_add_pd(ymm5, ymm8); - /*Negating imaginary component of numerator*/ - ymm1 = _mm256_mul_pd(ymm1, ymm7); - ymm2 = _mm256_mul_pd(ymm2, ymm7); - /*Dividing numerator by denominator*/ - ymm1 = _mm256_div_pd(ymm1, ymm4); - ymm2 = _mm256_div_pd(ymm2, ymm5); -#endif +#else //BLIS_ENABLE_TRSM_PREINVERSION - } - _mm256_store_pd((double *)d11_pack, ymm1); - if(is_four) + // If Preinversion is disabled, pack the diaganol + // elements from A into diagonal buffer. + for( dim_t i = 0; i < size; i++) { - _mm256_store_pd((double *)(d11_pack + 2), ymm2); + dim_t d = ((i*cs_a) + i); + bli_zcopys(a11[d],d11_pack[i]); } - else - { - _mm_store_pd((double *)(d11_pack + 2), - _mm256_extractf128_pd(ymm2,0)); - } +#endif //BLIS_ENABLE_TRSM_PREINVERSION } - /*implements TRSM for the case XA = alpha * B *A is lower triangular, non-unit diagonal/unit diagonal, transpose *dimensions: X:mxn A:nxn B: mxn @@ -14948,9 +14904,12 @@ BLIS_INLINE void strsm_small_pack_diag_element __m256 ymm0, ymm1, ymm2, ymm3; __m256 ymm4, ymm5, ymm6, ymm7; __m256 ymm8, ymm9, ymm10,ymm11; - __m256 ymm14, ymm15, ymm12,ymm13; + __m256 ymm14, ymm15, ymm12; float ones = 1.0; - ymm13 = ymm14 = ymm15 = _mm256_broadcast_ss((float const *)&ones); + ymm14 = ymm15 = _mm256_broadcast_ss((float const *)&ones); +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + __m256 ymm13 = _mm256_broadcast_ss((float const *)&ones); +#endif if(side=='L'||side=='l') { if(!is_unitdiag) From caa5b37005adcb657bbbbc404b55477a4b60d248 Mon Sep 17 00:00:00 2001 From: Arnav Sharma Date: Mon, 21 Mar 2022 12:53:05 +0530 Subject: [PATCH 100/243] Optimized S/DCOMPLEX DOTXAXPYF using AVX2 Intrinsics Details: - Optimized implementation of DOTXAXPYF fused kernel for single and double precision complex datatype using AVX2 Intrinsics - Updated definitions zen context AMD-Internal: [CPUPL-2059] Change-Id: Ic657e4b66172ae459173626222af2756a4125565 --- config/zen/bli_cntx_init_zen.c | 5 +- config/zen2/bli_cntx_init_zen2.c | 5 +- config/zen3/bli_cntx_init_zen3.c | 5 +- kernels/zen/1f/bli_dotxaxpyf_zen_int_8.c | 832 ++++++++++++++++++++++- kernels/zen/bli_kernels_zen.h | 2 + 5 files changed, 843 insertions(+), 6 deletions(-) diff --git a/config/zen/bli_cntx_init_zen.c b/config/zen/bli_cntx_init_zen.c index 1badc24f96..674549d77f 100644 --- a/config/zen/bli_cntx_init_zen.c 
+++ b/config/zen/bli_cntx_init_zen.c @@ -80,12 +80,15 @@ void bli_cntx_init_zen( cntx_t* cntx ) // Update the context with optimized level-1f kernels. bli_cntx_set_l1f_kers ( - 10, + 12, // axpyf BLIS_AXPYF_KER, BLIS_FLOAT, bli_saxpyf_zen_int_8, BLIS_AXPYF_KER, BLIS_DOUBLE, bli_daxpyf_zen_int_8, BLIS_AXPYF_KER, BLIS_SCOMPLEX, bli_caxpyf_zen_int_5, BLIS_AXPYF_KER, BLIS_DCOMPLEX, bli_zaxpyf_zen_int_5, + // dotxaxpyf + BLIS_DOTXAXPYF_KER, BLIS_SCOMPLEX, bli_cdotxaxpyf_zen_int_8, + BLIS_DOTXAXPYF_KER, BLIS_DCOMPLEX, bli_zdotxaxpyf_zen_int_8, // dotxf BLIS_DOTXF_KER, BLIS_FLOAT, bli_sdotxf_zen_int_8, BLIS_DOTXF_KER, BLIS_DOUBLE, bli_ddotxf_zen_int_8, diff --git a/config/zen2/bli_cntx_init_zen2.c b/config/zen2/bli_cntx_init_zen2.c index 997ccdba2e..48cb90a4f8 100644 --- a/config/zen2/bli_cntx_init_zen2.c +++ b/config/zen2/bli_cntx_init_zen2.c @@ -92,12 +92,15 @@ void bli_cntx_init_zen2( cntx_t* cntx ) // Update the context with optimized level-1f kernels. bli_cntx_set_l1f_kers ( - 10, + 12, // axpyf BLIS_AXPYF_KER, BLIS_FLOAT, bli_saxpyf_zen_int_5, BLIS_AXPYF_KER, BLIS_DOUBLE, bli_daxpyf_zen_int_5, BLIS_AXPYF_KER, BLIS_SCOMPLEX, bli_caxpyf_zen_int_5, BLIS_AXPYF_KER, BLIS_DCOMPLEX, bli_zaxpyf_zen_int_5, + // dotxaxpyf + BLIS_DOTXAXPYF_KER, BLIS_SCOMPLEX, bli_cdotxaxpyf_zen_int_8, + BLIS_DOTXAXPYF_KER, BLIS_DCOMPLEX, bli_zdotxaxpyf_zen_int_8, // dotxf BLIS_DOTXF_KER, BLIS_FLOAT, bli_sdotxf_zen_int_8, BLIS_DOTXF_KER, BLIS_DOUBLE, bli_ddotxf_zen_int_8, diff --git a/config/zen3/bli_cntx_init_zen3.c b/config/zen3/bli_cntx_init_zen3.c index 61fefdbc31..e83a12b401 100644 --- a/config/zen3/bli_cntx_init_zen3.c +++ b/config/zen3/bli_cntx_init_zen3.c @@ -92,12 +92,15 @@ void bli_cntx_init_zen3( cntx_t* cntx ) // Update the context with optimized level-1f kernels. bli_cntx_set_l1f_kers ( - 10, + 12, // axpyf BLIS_AXPYF_KER, BLIS_FLOAT, bli_saxpyf_zen_int_5, BLIS_AXPYF_KER, BLIS_DOUBLE, bli_daxpyf_zen_int_5, BLIS_AXPYF_KER, BLIS_SCOMPLEX, bli_caxpyf_zen_int_5, BLIS_AXPYF_KER, BLIS_DCOMPLEX, bli_zaxpyf_zen_int_5, + // dotxaxpyf + BLIS_DOTXAXPYF_KER, BLIS_SCOMPLEX, bli_cdotxaxpyf_zen_int_8, + BLIS_DOTXAXPYF_KER, BLIS_DCOMPLEX, bli_zdotxaxpyf_zen_int_8, // dotxf BLIS_DOTXF_KER, BLIS_FLOAT, bli_sdotxf_zen_int_8, BLIS_DOTXF_KER, BLIS_DOUBLE, bli_ddotxf_zen_int_8, diff --git a/kernels/zen/1f/bli_dotxaxpyf_zen_int_8.c b/kernels/zen/1f/bli_dotxaxpyf_zen_int_8.c index b24aab7571..1be9975ecf 100644 --- a/kernels/zen/1f/bli_dotxaxpyf_zen_int_8.c +++ b/kernels/zen/1f/bli_dotxaxpyf_zen_int_8.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021-2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -40,6 +40,12 @@ typedef union{ double d[4] __attribute__((aligned(64))); }vec; +typedef union +{ + __m256 v; + float f[8] __attribute__((aligned(64))); +} v8sf_t; + /** * bli_pre_hemv_lower_8x8 is a helper function which computes * "y = y + alpha * a * x" @@ -467,8 +473,9 @@ void bli_ddotxaxpyf_zen_int_8 /* A is m x n. 
*/ /* y = beta * y + alpha * A^T w; */ /* z = z + alpha * A x; */ - if ((inca == 1) && (incw == 1) && (incx == 1) - && (incy == 1) && (incz == 1) && (b_n == 8)) + if ( ( bli_cpuid_is_avx_supported() == TRUE ) && + (inca == 1) && (incw == 1) && (incx == 1) + && (incy == 1) && (incz == 1) && (b_n == 8) ) { __m256d r0, r1; r0 = _mm256_setzero_pd(); @@ -733,3 +740,822 @@ void bli_ddotxaxpyf_zen_int_8 ); } } + +/** + * zdotxaxpyf kernel performs dot and apxy function together. + * y := conj(beta) * y + conj(alpha) * conj(A)^t * conj(w) (dotxf) + * z := z + alpha * conj(A) * conj(x) (axpyf) + * where, + * A is an m x b matrix. + * w, z are vectors of length m. + * x, y are vectors of length b. + * alpha, beta are scalars + */ +void bli_zdotxaxpyf_zen_int_8 +( + conj_t conjat, + conj_t conja, + conj_t conjw, + conj_t conjx, + dim_t m, + dim_t b_n, + dcomplex* restrict alpha, + dcomplex* restrict a, inc_t inca, inc_t lda, + dcomplex* restrict w, inc_t incw, + dcomplex* restrict x, inc_t incx, + dcomplex* restrict beta, + dcomplex* restrict y, inc_t incy, + dcomplex* restrict z, inc_t incz, + cntx_t* restrict cntx + ) +{ + // A: m x b + // w, z: m + // x, y: b + // + // y = beta * y + alpha * A^T w; + // z = z + alpha * A x; + if ( ( bli_cpuid_is_avx_supported() == TRUE ) && + ( inca == 1 ) && ( incw == 1 ) && ( incx == 1 ) + && ( incy == 1 ) && ( incz == 1 ) && ( b_n == 4 ) ) + { + // Temporary rho buffer holds computed dot product result + dcomplex rho[ 4 ]; + + // chi? variables to hold scaled scaler values from x vector + dcomplex chi0; + dcomplex chi1; + dcomplex chi2; + dcomplex chi3; + + // If beta is zero, clear y + // Else, scale by beta + if ( PASTEMAC(z,eq0)( *beta ) ) + { + for ( dim_t i = 0; i < 4; ++i ) + { + PASTEMAC(z,set0s)( y[i] ); + } + } + else + { + for ( dim_t i = 0; i < 4; ++i ) + { + PASTEMAC(z,scals)( *beta, y[i] ); + } + } + + // If the vectors are empty or if alpha is zero, return early + if ( bli_zero_dim1( m ) || PASTEMAC(z,eq0)( *alpha ) ) return; + + // Initialize rho vector to 0 + for ( dim_t i = 0; i < 4; ++i ) PASTEMAC(z,set0s)( rho[i] ); + + // Set conj use variable for dot operation + conj_t conjdot_use = conjw; + if ( bli_is_conj( conjat ) ) + { + bli_toggle_conj( &conjdot_use ); + } + + // Set conj use variable for dotxf operation, scalar + dim_t conjdotxf = 1; + if ( bli_is_conj( conjdot_use ) ) + { + conjdotxf = -1; + } + + // Set conj use variable for axpyf operation, scalar + dim_t conjaxpyf = 1; + if ( bli_is_conj( conja ) ) + { + conjaxpyf = -1; + } + + // Store each element of x vector in a scalar and apply conjx + if( bli_is_noconj( conjx ) ) + { + chi0 = *( x + 0*incx ); + chi1 = *( x + 1*incx ); + chi2 = *( x + 2*incx ); + chi3 = *( x + 3*incx ); + } + else + { + bli_zcopycjs( conjx, *( x + 0*incx ), chi0 ); + bli_zcopycjs( conjx, *( x + 1*incx ), chi1 ); + bli_zcopycjs( conjx, *( x + 2*incx ), chi2 ); + bli_zcopycjs( conjx, *( x + 3*incx ), chi3 ); + } + + // Scale each chi scalar by alpha + bli_zscals( *alpha, chi0 ); + bli_zscals( *alpha, chi1 ); + bli_zscals( *alpha, chi2 ); + bli_zscals( *alpha, chi3 ); + + dim_t row = 0; + dim_t iter = m / 2; + dim_t rem = m % 2; + if (iter) + { + vec x0R, x1R, x2R, x3R; // x?R holds real part of x[?] + vec x0I, x1I, x2I, x3I; // x?I hold real part of x[?] + vec a_tile0, a_tile1; // a_tile? holds columns of a + vec temp1, temp2, temp3; // temp? registers for intermediate op + vec wR, wI; // holds real & imag components of w + vec z_vec; // holds the z vector + + // rho? 
registers hold results of fmadds for dotxf operation + vec rho0, rho1, rho2, rho3; + vec rho4, rho5, rho6, rho7; + + // For final computation, based on conjdot_use + // sign of imaginary component needs to be toggled + __m256d no_conju = _mm256_setr_pd( -1, 1, -1, 1 ); + __m256d conju = _mm256_setr_pd( 1, -1, 1, -1 ); + + // Clear the temp registers + temp1.v = _mm256_setzero_pd(); + temp2.v = _mm256_setzero_pd(); + temp3.v = _mm256_setzero_pd(); + + // Clear rho registers + // Once micro tile is computed, horizontal addition + // of all rho's will provide us with the result of + // dotxf opereation + rho0.v = _mm256_setzero_pd(); + rho1.v = _mm256_setzero_pd(); + rho2.v = _mm256_setzero_pd(); + rho3.v = _mm256_setzero_pd(); + rho4.v = _mm256_setzero_pd(); + rho5.v = _mm256_setzero_pd(); + rho6.v = _mm256_setzero_pd(); + rho7.v = _mm256_setzero_pd(); + + // Broadcast real & imag parts of 4 elements of x + // to perform axpyf operation with 4x8 tile of A + x0R.v = _mm256_broadcast_sd( &chi0.real ); // real part of x0 + x0I.v = _mm256_broadcast_sd( &chi0.imag ); // imag part of x0 + x1R.v = _mm256_broadcast_sd( &chi1.real ); // real part of x1 + x1I.v = _mm256_broadcast_sd( &chi1.imag ); // imag part of x1 + x2R.v = _mm256_broadcast_sd( &chi2.real ); // real part of x2 + x2I.v = _mm256_broadcast_sd( &chi2.imag ); // imag part of x2 + x3R.v = _mm256_broadcast_sd( &chi3.real ); // real part of x3 + x3I.v = _mm256_broadcast_sd( &chi3.imag ); // imag part of x3 + + for ( ; ( row + 1 ) < m; row += 2) + { + // Load first two columns of A + // a_tile0.v -> a00R a00I a10R a10I + // a_tile1.v -> a01R a01I a11R a11I + a_tile0.v = _mm256_loadu_pd( (double *)&a[row + 0 * lda] ); + a_tile1.v = _mm256_loadu_pd( (double *)&a[row + 1 * lda] ); + + temp1.v = _mm256_mul_pd( a_tile0.v, x0R.v ); + temp2.v = _mm256_mul_pd( a_tile0.v, x0I.v ); + + temp1.v = _mm256_fmadd_pd( a_tile1.v, x1R.v, temp1.v ); + temp2.v = _mm256_fmadd_pd( a_tile1.v, x1I.v, temp2.v ); + + // Load w vector + // wR.v -> w0R w0I w1R w1I + // wI.v ( shuf wR.v ) -> w0I w0I w1I w1I + // wR.v ( shuf wR.v ) -> w0R w0R w1R w1R + wR.v = _mm256_loadu_pd( (double *)&w[row] ); + wI.v = _mm256_permute_pd( wR.v, 15 ); + wR.v = _mm256_permute_pd( wR.v, 0 ); + + rho0.v = _mm256_fmadd_pd( a_tile0.v, wR.v, rho0.v); + rho4.v = _mm256_fmadd_pd( a_tile0.v, wI.v, rho4.v); + + rho1.v = _mm256_fmadd_pd( a_tile1.v, wR.v, rho1.v); + rho5.v = _mm256_fmadd_pd( a_tile1.v, wI.v, rho5.v); + + // Load 3rd and 4th columns of A + // a_tile0.v -> a20R a20I a30R a30I + // a_tile1.v -> a21R a21I a31R a31I + a_tile0.v = _mm256_loadu_pd( (double *)&a[row + 2 * lda] ); + a_tile1.v = _mm256_loadu_pd( (double *)&a[row + 3 * lda] ); + + temp1.v = _mm256_fmadd_pd( a_tile0.v, x2R.v, temp1.v ); + temp2.v = _mm256_fmadd_pd( a_tile0.v, x2I.v, temp2.v ); + + temp1.v = _mm256_fmadd_pd( a_tile1.v, x3R.v, temp1.v ); + temp2.v = _mm256_fmadd_pd( a_tile1.v, x3I.v, temp2.v ); + + rho2.v = _mm256_fmadd_pd( a_tile0.v, wR.v, rho2.v); + rho6.v = _mm256_fmadd_pd( a_tile0.v, wI.v, rho6.v); + + rho3.v = _mm256_fmadd_pd( a_tile1.v, wR.v, rho3.v); + rho7.v = _mm256_fmadd_pd( a_tile1.v, wI.v, rho7.v); + + // Load z vector + z_vec.v = _mm256_loadu_pd( (double *)&z[row] ); + + // Permute the result and alternatively add-sub final values + if( bli_is_noconj( conja ) ) + { + temp2.v = _mm256_permute_pd(temp2.v, 5); + temp3.v = _mm256_addsub_pd(temp1.v, temp2.v); + } + else + { + temp1.v = _mm256_permute_pd( temp1.v, 5 ); + temp3.v = _mm256_addsub_pd( temp2.v, temp1.v ); + temp3.v = _mm256_permute_pd( 
temp3.v, 5 ); + } + + // Add & store result to z_vec + z_vec.v = _mm256_add_pd( temp3.v, z_vec.v ); + _mm256_storeu_pd( (double *)&z[row], z_vec.v ); + } + + // Swapping position of real and imag component + // for horizontal addition to get the final + // dot product computation + // rho register are holding computation which needs + // to be arranged in following manner. + // a0R * x0I | a0I * x0I | a1R * x1I | a1I * x1R + // || + // \/ + // a0I * x0I | a0R * x0I | a1I * x1R | a1R * x1I + + rho4.v = _mm256_permute_pd(rho4.v, 0x05); + rho5.v = _mm256_permute_pd(rho5.v, 0x05); + rho6.v = _mm256_permute_pd(rho6.v, 0x05); + rho7.v = _mm256_permute_pd(rho7.v, 0x05); + + // Negating imaginary part for computing + // the final result of dcomplex multiplication + if ( bli_is_noconj( conjdot_use ) ) + { + rho4.v = _mm256_mul_pd(rho4.v, no_conju); + rho5.v = _mm256_mul_pd(rho5.v, no_conju); + rho6.v = _mm256_mul_pd(rho6.v, no_conju); + rho7.v = _mm256_mul_pd(rho7.v, no_conju); + } + else + { + rho4.v = _mm256_mul_pd(rho4.v, conju); + rho5.v = _mm256_mul_pd(rho5.v, conju); + rho6.v = _mm256_mul_pd(rho6.v, conju); + rho7.v = _mm256_mul_pd(rho7.v, conju); + } + + rho0.v = _mm256_add_pd(rho0.v, rho4.v); + rho1.v = _mm256_add_pd(rho1.v, rho5.v); + rho2.v = _mm256_add_pd(rho2.v, rho6.v); + rho3.v = _mm256_add_pd(rho3.v, rho7.v); + + // rho0 & rho1 hold final dot product + // result of 4 dcomplex elements + rho0.d[0] += rho0.d[2]; + rho0.d[1] += rho0.d[3]; + + rho0.d[2] = rho1.d[0] + rho1.d[2]; + rho0.d[3] = rho1.d[1] + rho1.d[3]; + + rho1.d[0] = rho2.d[0] + rho2.d[2]; + rho1.d[1] = rho2.d[1] + rho2.d[3]; + + rho1.d[2] = rho3.d[0] + rho3.d[2]; + rho1.d[3] = rho3.d[1] + rho3.d[3]; + + // Storing the computed dot product + // in temp buffer rho for further computation. 
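+ // rho[0..3] now hold the four partial dot products;
+ // after the scalar remainder loop they are conjugated
+ // (if conjat) and axpy'ed into y with alpha.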
+ _mm256_storeu_pd( (double *)rho, rho0.v ); + _mm256_storeu_pd( (double *)(rho+2) , rho1.v ); + } + + // To handle the remaining cases + if ( rem ) + { + PRAGMA_SIMD + for ( dim_t p = row; p < m; ++p ) + { + const dcomplex a0c = a[p + 0 * lda]; + const dcomplex a1c = a[p + 1 * lda]; + const dcomplex a2c = a[p + 2 * lda]; + const dcomplex a3c = a[p + 3 * lda]; + + // dot + dcomplex r0c = rho[0]; + dcomplex r1c = rho[1]; + dcomplex r2c = rho[2]; + dcomplex r3c = rho[3]; + + dcomplex w0c = w[p]; + + r0c.real += a0c.real * w0c.real - a0c.imag * w0c.imag + * conjdotxf; + r0c.imag += a0c.imag * w0c.real + a0c.real * w0c.imag + * conjdotxf; + r1c.real += a1c.real * w0c.real - a1c.imag * w0c.imag + * conjdotxf; + r1c.imag += a1c.imag * w0c.real + a1c.real * w0c.imag + * conjdotxf; + r2c.real += a2c.real * w0c.real - a2c.imag * w0c.imag + * conjdotxf; + r2c.imag += a2c.imag * w0c.real + a2c.real * w0c.imag + * conjdotxf; + r3c.real += a3c.real * w0c.real - a3c.imag * w0c.imag + * conjdotxf; + r3c.imag += a3c.imag * w0c.real + a3c.real * w0c.imag + * conjdotxf; + + rho[0] = r0c; + rho[1] = r1c; + rho[2] = r2c; + rho[3] = r3c; + + // axpy + dcomplex z0c = z[p]; + + z0c.real += chi0.real * a0c.real - chi0.imag * a0c.imag + * conjaxpyf; + z0c.real += chi1.real * a1c.real - chi1.imag * a1c.imag + * conjaxpyf; + z0c.real += chi2.real * a2c.real - chi2.imag * a2c.imag + * conjaxpyf; + z0c.real += chi3.real * a3c.real - chi3.imag * a3c.imag + * conjaxpyf; + z0c.imag += chi0.imag * a0c.real + chi0.real * a0c.imag + * conjaxpyf; + z0c.imag += chi1.imag * a1c.real + chi1.real * a1c.imag + * conjaxpyf; + z0c.imag += chi2.imag * a2c.real + chi2.real * a2c.imag + * conjaxpyf; + z0c.imag += chi3.imag * a3c.real + chi3.real * a3c.imag + * conjaxpyf; + + z[p] = z0c; + } + } + + // Conjugating the final result if conjat + if ( bli_is_conj( conjat ) ) + { + for ( dim_t i = 0; i < 4; ++i ) + { + PASTEMAC(z,conjs)( rho[i] ); + } + } + + // Scaling the dot product result with alpha + // and adding the result to vector y + for ( dim_t i = 0; i < 4; ++i ) + { + PASTEMAC(z,axpys)( *alpha, rho[i], y[i] ); + } + } + else + { + // For non-unit increments + /* Query the context for the kernel function pointer. */ + const num_t dt = PASTEMAC(z,type); + PASTECH(z,dotxf_ker_ft) kfp_df = + bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXF_KER, cntx ); + PASTECH(z,axpyf_ker_ft) kfp_af = + bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPYF_KER, cntx ); + + kfp_df + ( + conjat, + conjw, + m, + b_n, + alpha, + a, inca, lda, + w, incw, + beta, + y, incy, + cntx + ); + + kfp_af + ( + conja, + conjx, + m, + b_n, + alpha, + a, inca, lda, + x, incx, + z, incz, + cntx + ); + } +} + +/** + * cdotxaxpyf kernel performs dot and apxy function together. + * y := conj(beta) * y + conj(alpha) * conj(A)^t * conj(w) (dotxf) + * z := z + alpha * conj(A) * conj(x) (axpyf) + * where, + * A is an m x b matrix. + * w, z are vectors of length m. + * x, y are vectors of length b. 
+ * alpha, beta are scalars + */ +void bli_cdotxaxpyf_zen_int_8 +( + conj_t conjat, + conj_t conja, + conj_t conjw, + conj_t conjx, + dim_t m, + dim_t b_n, + scomplex* restrict alpha, + scomplex* restrict a, inc_t inca, inc_t lda, + scomplex* restrict w, inc_t incw, + scomplex* restrict x, inc_t incx, + scomplex* restrict beta, + scomplex* restrict y, inc_t incy, + scomplex* restrict z, inc_t incz, + cntx_t* restrict cntx + ) +{ + // A: m x b + // w, z: m + // x, y: b + // + // y = beta * y + alpha * A^T w; + // z = z + alpha * A x; + if ( ( bli_cpuid_is_avx_supported() == TRUE ) && + ( inca == 1 ) && ( incw == 1 ) && ( incx == 1 ) + && ( incy == 1 ) && ( incz == 1 ) && ( b_n == 4 ) ) + { + // Temporary rho buffer holds computed dot product result + scomplex rho[ 4 ]; + + // chi? variables to hold scaled scaler values from x vector + scomplex chi0; + scomplex chi1; + scomplex chi2; + scomplex chi3; + + // If beta is zero, clear y + // Else, scale by beta + if ( PASTEMAC(c,eq0)( *beta ) ) + { + for ( dim_t i = 0; i < 4; ++i ) + { + PASTEMAC(c,set0s)( y[i] ); + } + } + else + { + for ( dim_t i = 0; i < 4; ++i ) + { + PASTEMAC(c,scals)( *beta, y[i] ); + } + } + + // If the vectors are empty or if alpha is zero, return early + if ( bli_zero_dim1( m ) || PASTEMAC(c,eq0)( *alpha ) ) return; + + // Initialize rho vector to 0 + for ( dim_t i = 0; i < 4; ++i ) PASTEMAC(c,set0s)( rho[i] ); + + // Set conj use variable for dot operation + conj_t conjdot_use = conjw; + if ( bli_is_conj( conjat ) ) + { + bli_toggle_conj( &conjdot_use ); + } + + // Set conj use variable for dotxf operation, scalar + dim_t conjdotxf = 1; + if ( bli_is_conj( conjdot_use ) ) + { + conjdotxf = -1; + } + + // Set conj use variable for axpyf operation, scalar + dim_t conjaxpyf = 1; + if ( bli_is_conj( conja ) ) + { + conjaxpyf = -1; + } + + // Store each element of x vector in a scalar and apply conjx + if( bli_is_noconj( conjx ) ) + { + chi0 = *( x + 0*incx ); + chi1 = *( x + 1*incx ); + chi2 = *( x + 2*incx ); + chi3 = *( x + 3*incx ); + } + else + { + bli_ccopycjs( conjx, *( x + 0*incx ), chi0 ); + bli_ccopycjs( conjx, *( x + 1*incx ), chi1 ); + bli_ccopycjs( conjx, *( x + 2*incx ), chi2 ); + bli_ccopycjs( conjx, *( x + 3*incx ), chi3 ); + } + + // Scale each chi scalar by alpha + bli_cscals( *alpha, chi0 ); + bli_cscals( *alpha, chi1 ); + bli_cscals( *alpha, chi2 ); + bli_cscals( *alpha, chi3 ); + + dim_t i = 0; + dim_t iter = m / 4; + dim_t rem = m % 4; + if (iter) + { + v8sf_t x0R, x1R, x2R, x3R; // x?R holds real part of x[?] + v8sf_t x0I, x1I, x2I, x3I; // x?I hold real part of x[?] + v8sf_t a_tile0, a_tile1; // a_tile? holds columns of a + v8sf_t temp1, temp2, temp3; // temp? 
registers for intermediate op + v8sf_t wR, wI; // holds real & imag components of w + v8sf_t z_vec; // holds the z vector + + // For final computation, based on conjdot_use + // sign of imaginary component needs to be toggled + __m256 no_conju = _mm256_setr_ps( -1, 1, -1, 1, -1, 1, -1, 1 ); + __m256 conju = _mm256_setr_ps( 1, -1, 1, -1, 1, -1, 1, -1 ); + + // Clear the temp registers + temp1.v = _mm256_setzero_ps(); + temp2.v = _mm256_setzero_ps(); + temp3.v = _mm256_setzero_ps(); + + // Clear rho registers + // Once micro tile is computed, horizontal addition + // of all rho's will provide us with the result of + // dotxf opereation + __m256 rho0v; rho0v = _mm256_setzero_ps(); + __m256 rho1v; rho1v = _mm256_setzero_ps(); + __m256 rho2v; rho2v = _mm256_setzero_ps(); + __m256 rho3v; rho3v = _mm256_setzero_ps(); + + __m256 rho4v; rho4v = _mm256_setzero_ps(); + __m256 rho5v; rho5v = _mm256_setzero_ps(); + __m256 rho6v; rho6v = _mm256_setzero_ps(); + __m256 rho7v; rho7v = _mm256_setzero_ps(); + + // Broadcast real & imag parts of 4 elements of x + // to perform axpyf operation with 4x8 tile of A + x0R.v = _mm256_broadcast_ss( &chi0.real ); // real part of x0 + x0I.v = _mm256_broadcast_ss( &chi0.imag ); // imag part of x0 + x1R.v = _mm256_broadcast_ss( &chi1.real ); // real part of x1 + x1I.v = _mm256_broadcast_ss( &chi1.imag ); // imag part of x1 + x2R.v = _mm256_broadcast_ss( &chi2.real ); // real part of x2 + x2I.v = _mm256_broadcast_ss( &chi2.imag ); // imag part of x2 + x3R.v = _mm256_broadcast_ss( &chi3.real ); // real part of x3 + x3I.v = _mm256_broadcast_ss( &chi3.imag ); // imag part of x3 + + for ( ; ( i + 3 ) < m; i += 4) + { + // Load first two columns of A + // a_tile0.v -> a00R a00I a10R a10I a20R a20I a30R a30I + // a_tile1.v -> a01R a01I a11R a11I a21R a21I a31R a31I + a_tile0.v = _mm256_loadu_ps( (float *)&a[i + 0 * lda] ); + a_tile1.v = _mm256_loadu_ps( (float *)&a[i + 1 * lda] ); + + temp1.v = _mm256_mul_ps( a_tile0.v, x0R.v ); + temp2.v = _mm256_mul_ps( a_tile0.v, x0I.v ); + + temp1.v = _mm256_fmadd_ps( a_tile1.v, x1R.v, temp1.v ); + temp2.v = _mm256_fmadd_ps( a_tile1.v, x1I.v, temp2.v ); + + // Load w vector + // wR.v -> w0R w0I w1R w1I w2R w2I w3R w3I + // wI.v ( shuf wR.v ) -> w0I w0I w1I w1I w2I w2I w3I w3I + // wR.v ( shuf wR.v ) -> w0R w0R w1R w1R w2R w2R w3R w3R + wR.v = _mm256_loadu_ps( (float *) (w + i) ); + wI.v = _mm256_permute_ps( wR.v, 0xf5 ); + wR.v = _mm256_permute_ps( wR.v, 0xa0); + + rho0v = _mm256_fmadd_ps( a_tile0.v, wR.v, rho0v ); + rho4v = _mm256_fmadd_ps( a_tile0.v, wI.v, rho4v ); + + rho1v = _mm256_fmadd_ps( a_tile1.v, wR.v, rho1v ); + rho5v = _mm256_fmadd_ps( a_tile1.v, wI.v, rho5v ); + + // Load 3rd and 4th columns of A + // a_tile0.v -> a20R a20I a30R a30I + // a_tile1.v -> a21R a21I a31R a31I + a_tile0.v = _mm256_loadu_ps( (float *)&a[i + 2 * lda] ); + a_tile1.v = _mm256_loadu_ps( (float *)&a[i + 3 * lda] ); + + temp1.v = _mm256_fmadd_ps( a_tile0.v, x2R.v, temp1.v ); + temp2.v = _mm256_fmadd_ps( a_tile0.v, x2I.v, temp2.v ); + + temp1.v = _mm256_fmadd_ps( a_tile1.v, x3R.v, temp1.v ); + temp2.v = _mm256_fmadd_ps( a_tile1.v, x3I.v, temp2.v ); + + rho2v = _mm256_fmadd_ps( a_tile0.v, wR.v, rho2v ); + rho6v = _mm256_fmadd_ps( a_tile0.v, wI.v, rho6v ); + + rho3v = _mm256_fmadd_ps( a_tile1.v, wR.v, rho3v ); + rho7v = _mm256_fmadd_ps( a_tile1.v, wI.v, rho7v ); + + // Load z vector + z_vec.v = _mm256_loadu_ps( (float *)&z[i] ); + + // Permute the result and alternatively add-sub final values + if( bli_is_noconj( conja ) ) + { + temp2.v = 
_mm256_permute_ps(temp2.v, 0xB1); + temp3.v = _mm256_addsub_ps(temp1.v, temp2.v); + } + else + { + temp1.v = _mm256_permute_ps( temp1.v, 0xB1 ); + temp3.v = _mm256_addsub_ps( temp2.v, temp1.v ); + temp3.v = _mm256_permute_ps( temp3.v, 0xB1 ); + } + + // Add & store result to z_vec + z_vec.v = _mm256_add_ps( temp3.v, z_vec.v ); + _mm256_storeu_ps( (float *)&z[i], z_vec.v ); + } + + // Swapping position of real and imag component + // for horizontal addition to get the final + // dot product computation + // rho register are holding computation which needs + // to be arranged in following manner. + // a0R * x0I | a0I * x0I | a1R * x1I | a1I * x1R | ... + // || + // \/ + // a0I * x0I | a0R * x0I | a1I * x1R | a1R * x1I | ... + + rho4v = _mm256_permute_ps(rho4v, 0xb1); + rho5v = _mm256_permute_ps(rho5v, 0xb1); + rho6v = _mm256_permute_ps(rho6v, 0xb1); + rho7v = _mm256_permute_ps(rho7v, 0xb1); + + // Negating imaginary part for computing + // the final result of dcomplex multiplication + if ( bli_is_noconj( conjdot_use ) ) + { + rho4v = _mm256_mul_ps(rho4v, no_conju); + rho5v = _mm256_mul_ps(rho5v, no_conju); + rho6v = _mm256_mul_ps(rho6v, no_conju); + rho7v = _mm256_mul_ps(rho7v, no_conju); + } + else + { + rho4v = _mm256_mul_ps(rho4v, conju); + rho5v = _mm256_mul_ps(rho5v, conju); + rho6v = _mm256_mul_ps(rho6v, conju); + rho7v = _mm256_mul_ps(rho7v, conju); + } + + rho0v = _mm256_add_ps(rho0v, rho4v); + rho1v = _mm256_add_ps(rho1v, rho5v); + rho2v = _mm256_add_ps(rho2v, rho6v); + rho3v = _mm256_add_ps(rho3v, rho7v); + + // Horizontal addition of rho elements for computing final dotxf + // and storing the results into rho buffer + scomplex *ptr = (scomplex *)&rho0v; + for(dim_t j = 0; j < 4; j++) + { + rho[0].real += ptr[j].real; + rho[0].imag += ptr[j].imag; + } + ptr = (scomplex *)&rho1v; + for(dim_t j = 0; j < 4; j++) + { + rho[1].real += ptr[j].real; + rho[1].imag += ptr[j].imag; + } + ptr = (scomplex *)&rho2v; + for(dim_t j = 0; j < 4; j++) + { + rho[2].real += ptr[j].real; + rho[2].imag += ptr[j].imag; + } + ptr = (scomplex *)&rho3v; + for(dim_t j = 0; j < 4; j++) + { + rho[3].real += ptr[j].real; + rho[3].imag += ptr[j].imag; + } + } + + // To handle the remaining cases + if ( rem ) + { + PRAGMA_SIMD + for ( dim_t p = i; p < m; ++p ) + { + const scomplex a0c = a[p + 0 * lda]; + const scomplex a1c = a[p + 1 * lda]; + const scomplex a2c = a[p + 2 * lda]; + const scomplex a3c = a[p + 3 * lda]; + + // dot + scomplex r0c = rho[0]; + scomplex r1c = rho[1]; + scomplex r2c = rho[2]; + scomplex r3c = rho[3]; + + scomplex w0c = w[p]; + + r0c.real += a0c.real * w0c.real - a0c.imag * w0c.imag + * conjdotxf; + r0c.imag += a0c.imag * w0c.real + a0c.real * w0c.imag + * conjdotxf; + r1c.real += a1c.real * w0c.real - a1c.imag * w0c.imag + * conjdotxf; + r1c.imag += a1c.imag * w0c.real + a1c.real * w0c.imag + * conjdotxf; + r2c.real += a2c.real * w0c.real - a2c.imag * w0c.imag + * conjdotxf; + r2c.imag += a2c.imag * w0c.real + a2c.real * w0c.imag + * conjdotxf; + r3c.real += a3c.real * w0c.real - a3c.imag * w0c.imag + * conjdotxf; + r3c.imag += a3c.imag * w0c.real + a3c.real * w0c.imag + * conjdotxf; + + rho[0] = r0c; + rho[1] = r1c; + rho[2] = r2c; + rho[3] = r3c; + + // axpy + scomplex z0c = z[p]; + + z0c.real += chi0.real * a0c.real - chi0.imag * a0c.imag + * conjaxpyf; + z0c.real += chi1.real * a1c.real - chi1.imag * a1c.imag + * conjaxpyf; + z0c.real += chi2.real * a2c.real - chi2.imag * a2c.imag + * conjaxpyf; + z0c.real += chi3.real * a3c.real - chi3.imag * a3c.imag + * conjaxpyf; + z0c.imag += 
chi0.imag * a0c.real + chi0.real * a0c.imag + * conjaxpyf; + z0c.imag += chi1.imag * a1c.real + chi1.real * a1c.imag + * conjaxpyf; + z0c.imag += chi2.imag * a2c.real + chi2.real * a2c.imag + * conjaxpyf; + z0c.imag += chi3.imag * a3c.real + chi3.real * a3c.imag + * conjaxpyf; + + z[p] = z0c; + } + } + + // Conjugating the final result if conjat + if ( bli_is_conj( conjat ) ) + { + for ( dim_t j = 0; j < 4; ++j ) + { + PASTEMAC(c,conjs)( rho[j] ); + } + } + + // Scaling the dot product result with alpha + // and adding the result to vector y + for ( dim_t j = 0; j < 4; ++j ) + { + PASTEMAC(c,axpys)( *alpha, rho[j], y[j] ); + } + } + else + { + // For non-unit increments + /* Query the context for the kernel function pointer. */ + const num_t dt = PASTEMAC(c,type); + PASTECH(c,dotxf_ker_ft) kfp_df = + bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXF_KER, cntx ); + PASTECH(c,axpyf_ker_ft) kfp_af = + bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPYF_KER, cntx ); + + kfp_df + ( + conjat, + conjw, + m, + b_n, + alpha, + a, inca, lda, + w, incw, + beta, + y, incy, + cntx + ); + + kfp_af + ( + conja, + conjx, + m, + b_n, + alpha, + a, inca, lda, + x, incx, + z, incz, + cntx + ); + } +} \ No newline at end of file diff --git a/kernels/zen/bli_kernels_zen.h b/kernels/zen/bli_kernels_zen.h index 73dd55b994..da7afd00e8 100644 --- a/kernels/zen/bli_kernels_zen.h +++ b/kernels/zen/bli_kernels_zen.h @@ -138,6 +138,8 @@ DOTXF_KER_PROT( dcomplex, z, dotxf_zen_int_6 ) DOTXF_KER_PROT( scomplex, c, dotxf_zen_int_6 ) // dotxaxpyf (intrinsics) DOTXAXPYF_KER_PROT( double, d, dotxaxpyf_zen_int_8 ) +DOTXAXPYF_KER_PROT( scomplex, c, dotxaxpyf_zen_int_8 ) +DOTXAXPYF_KER_PROT( dcomplex, z, dotxaxpyf_zen_int_8 ) // -- level-2 ---------------------------------------------------------------- From 52e4fd0f113c84fee646101791d777f919bd3154 Mon Sep 17 00:00:00 2001 From: Harsh Dave Date: Thu, 7 Apr 2022 00:18:38 -0500 Subject: [PATCH 101/243] Performance Improvement for ctrsm small sizes Details: - Enable ctrsm small implementation - Handled Overflow and Underflow Vulnerabilites in ctrsm small implementations. - Fixed failures observed in libflame testing. - For small sizes, ctrsm small implementation is used for all variants. Change-Id: I17b862dcb794a5af0ec68f585992131fef57b179 --- frame/compat/bla_trsm_amd.c | 9 +- kernels/zen/3/bli_trsm_small.c | 341 +++++++-------------------------- 2 files changed, 73 insertions(+), 277 deletions(-) diff --git a/frame/compat/bla_trsm_amd.c b/frame/compat/bla_trsm_amd.c index 9ff8073be0..e1a2fffafd 100644 --- a/frame/compat/bla_trsm_amd.c +++ b/frame/compat/bla_trsm_amd.c @@ -902,6 +902,8 @@ void dtrsm_ /* Finalize BLIS. */ bli_finalize_auto(); } + + void ztrsm_ ( const f77_char* side, @@ -1221,7 +1223,8 @@ void ztrsm_ /* Finalize BLIS. */ bli_finalize_auto(); } -#if 0 + + void ctrsm_ ( const f77_char* side, @@ -1236,7 +1239,7 @@ void ctrsm_ ) { AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO) - AOCL_DTL_LOG_TRSM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 's', + AOCL_DTL_LOG_TRSM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'c', *side, *uploa,*transa, *diaga, *m, *n, (void*)alpha,*lda, *ldb); @@ -1537,7 +1540,5 @@ void ctrsm_ /* Finalize BLIS. 
*/ bli_finalize_auto(); } -#endif -INSERT_GENTFUNC_BLAS_C( trsm, trsm ) #endif diff --git a/kernels/zen/3/bli_trsm_small.c b/kernels/zen/3/bli_trsm_small.c index bb6d198c78..07077010f2 100644 --- a/kernels/zen/3/bli_trsm_small.c +++ b/kernels/zen/3/bli_trsm_small.c @@ -36852,33 +36852,19 @@ BLIS_INLINE err_t bli_ztrsm_small_XAltB_XAuB */ #define SCOMPLEX_INV(a, b) {\ - a.real = b.real;\ - a.imag = (b.imag * -1.0);\ - /*Compute denominator eliminating imaginary component*/\ - float dnm = (b.real * b.real);\ - /*multiply two times with -1 for correct result as - * dcomplex number with positive imaginary part will - * invert the sign if not multiplied twice with -1*/\ - dnm += ((-1.0 * (b.imag * b.imag)) * -1.0);\ - /*Compute the final result by dividing real and imag part by dnm*/\ - a.real /= dnm;\ - a.imag /= dnm;\ + a.real = 1.0;\ + a.imag = 0.0;\ + bli_cinvscals(b, a);\ } #define SCOMPLEX_MUL(a, b, c) {\ - float real = a.real * b.real;\ - real += ((a.imag * b.imag) * -1.0);\ - float imag = (a.real * b.imag);\ - imag += (a.imag * b.real);\ - c.real = real;\ - c.imag = imag;\ + c.real = b.real;\ + c.imag = b.imag;\ + bli_cscals(a,c);\ } #define SCOMPLEX_DIV(a, b){\ - float dnm = b.real * b.real;\ - dnm += (-1.0 * (b.imag * (b.imag * -1.0) ));\ - a.real /= dnm;\ - a.imag /= dnm;\ + bli_cinvscals(b,a); \ } #ifdef BLIS_ENABLE_TRSM_PREINVERSION @@ -36904,13 +36890,10 @@ BLIS_INLINE err_t bli_ztrsm_small_XAltB_XAuB #ifdef BLIS_DISABLE_TRSM_PREINVERSION #define CTRSM_DIAG_ELE_EVAL_OPS(a,b,c){\ - if(!is_unitdiag)\ - {\ - a.real = b.real;\ - a.imag = (b.imag * -1.0);\ - SCOMPLEX_MUL(c, a, c)\ - SCOMPLEX_DIV(c, b)\ - }\ + if(!is_unitdiag)\ + {\ + bli_cinvscals(b, c);\ + }\ } #endif @@ -37306,72 +37289,30 @@ BLIS_INLINE void ctrsm_small_pack_diag_element dim_t size ) { - __m256 ymm1, ymm2, ymm3, ymm4, ymm5, ymm6, ymm8; - bool is_eight = (size == 8) ? 1 : 0; - scomplex ones = {1.0, 1.0}; - ymm2 = ymm1 = _mm256_broadcast_ps((__m128 const *)&ones); #ifdef BLIS_ENABLE_TRSM_PREINVERSION - __m256 ymm7; - ymm7 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); -#endif - - if(!is_unitdiag) + // If Preinversion is disabled, inverse the diaganol + // elements from A and pack into diagonal buffer. + // In order to avoid the overflow and underflow scenarios, + // bli_cinvscals is used. 
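+    // Each packed entry ends up holding the reciprocal of the
+    // corresponding diagonal element of A, i.e.
+    //   d11_pack[i] = (1 + 0i) / a11[i + i*cs_a]
+    // so the later diagonal updates can scale by the packed value
+    // instead of dividing, with bli_cinvscals performing the complex
+    // division in the overflow/underflow-safe manner noted above.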
+ for( dim_t i = 0; i < size; i++) { - //broadcast diagonal elements of A11 - ymm1 = _mm256_broadcast_ps((__m128 const *)a11); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11+ cs_a +1)); - ymm3 = _mm256_broadcast_ps((__m128 const *) (a11+ cs_a*2 +2)); - - ymm1 = _mm256_shuffle_ps(ymm1, ymm2, 0x44); - - if(is_eight) { - ymm4 = _mm256_broadcast_ps((__m128 const *)(a11 + 4 + cs_a*4)); - ymm5 = _mm256_broadcast_ps((__m128 const *)(a11 + 5 + cs_a*5)); - ymm6 = _mm256_shuffle_ps(ymm4, ymm5, 0x44); - - ymm4 = _mm256_broadcast_ps((__m128 const *)(a11 + 6 + cs_a*6)); - ymm5 = _mm256_broadcast_ps((__m128 const *)(a11 + 7 + cs_a*7)); - ymm8 = _mm256_shuffle_ps(ymm4, ymm5, 0x44); - - ymm2 = _mm256_blend_ps(ymm6, ymm8, 0xF0); - - ymm4 = _mm256_broadcast_ps((__m128 const *)(a11 + 3 + cs_a*3)); - ymm3 = _mm256_shuffle_ps(ymm3, ymm4, 0x44); - } - - ymm1 = _mm256_blend_ps(ymm1, ymm3, 0xF0); - -#ifdef BLIS_ENABLE_TRSM_PREINVERSION - /*Taking denomerator multiplication of real & imaginary components*/ - ymm4 = _mm256_mul_ps(ymm1, ymm1); - ymm5 = _mm256_mul_ps(ymm2, ymm2); - /*Swapping real & imaginary component position for addition with - * respective components*/ - //BEFORE - //a[0] a[1] a[2] a[3] - //AFTER - //a[1] a[0] a[3] a[2] - //MESS - ymm6 = _mm256_permute_ps(ymm4, 0xB1); - ymm8 = _mm256_permute_ps(ymm5, 0xB1); - ymm4 = _mm256_add_ps(ymm4, ymm6); - ymm5 = _mm256_add_ps(ymm5, ymm8); - - /*Negating imaginary component of numerator*/ - ymm1 = _mm256_mul_ps(ymm1, ymm7); - ymm2 = _mm256_mul_ps(ymm2, ymm7); - - /*Dividing numerator by denominator*/ - ymm1 = _mm256_div_ps(ymm1, ymm4); - ymm2 = _mm256_div_ps(ymm2, ymm5); - -#endif + dim_t d = ((i*cs_a) + i); + scomplex ones = {1.0, 0.0}; + bli_cinvscals(a11[d], ones) + d11_pack[i].real = ones.real; + d11_pack[i].imag = ones.imag; } - _mm256_store_ps((float *)d11_pack, ymm1); - if(is_eight) + +#else //BLIS_ENABLE_TRSM_PREINVERSION + // If Preinversion is disabled, pack the diaganol + // elements from A into diagonal buffer. 
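+    // Here the diagonal is packed unmodified; the per-element division
+    // is deferred to solve time, where CTRSM_DIAG_ELE_EVAL_OPS applies
+    // bli_cinvscals to the corresponding B elements.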
+ for( dim_t i = 0; i < size; i++) { - _mm256_store_ps((float *)(d11_pack + 4), ymm2); + dim_t d = ((i*cs_a) + i); + bli_ccopys(a11[d],d11_pack[i]); } + +#endif //BLIS_ENABLE_TRSM_PREINVERSION } /** @@ -37619,26 +37560,19 @@ BLIS_INLINE void ctrsm_small_pack_diag_element ymm2 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0);\ ymm1 = _mm256_mul_ps(ymm1, ymm2);\ }\ - /*Negating imaginary component of numerator*/\ - ymm2 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0);\ - ymm1 = _mm256_mul_ps(ymm1, ymm2);\ - /*BLIS_CTRSM_MUL(vec1)*/\ - /*BLIS_CTRSM_MUL(vec2)*/\ - /*vec1 * ymm1*/\ - ymm3 = _mm256_shuffle_ps(ymm1, ymm1, 0x11);\ - ymm2 = _mm256_shuffle_ps(vec1, vec1, 0xA0);\ - ymm16 = _mm256_shuffle_ps(vec1, vec1,0xF5);\ - ymm16 = _mm256_mul_ps(ymm16, ymm3);\ - vec1 = _mm256_fmaddsub_ps(ymm2, ymm1, ymm16);\ - /*vec1 * ymm1*/\ - ymm2 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0);\ - ymm1 = _mm256_mul_ps(ymm1, ymm2);\ - /*Taking denomerator multiplication of real & imaginary components*/\ - ymm3 = _mm256_mul_ps(ymm1, ymm1);\ - ymm2 = _mm256_permute_ps(ymm3, 0xB1);\ - ymm3 = _mm256_add_ps(ymm2, ymm3);\ - /*Dividing numerator by denominator*/\ - vec1 = _mm256_div_ps(vec1, ymm3);\ + scomplex b_data[4];\ + scomplex d11_data[4];\ + \ + _mm256_storeu_ps((float *)(b_data), vec1);\ + _mm256_storeu_ps((float *)(d11_data), ymm1);\ + \ + for(dim_t i = 0; i < 4; i++)\ + {\ + bli_cinvscals(d11_data[0],b_data[i]);\ + }\ + \ + vec1 = _mm256_loadu_ps((float *)b_data);\ + \ }\ } @@ -37649,32 +37583,21 @@ BLIS_INLINE void ctrsm_small_pack_diag_element ymm2 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0);\ ymm1 = _mm256_mul_ps(ymm1, ymm2);\ }\ - /*Negating imaginary component of numerator*/\ - ymm2 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0);\ - ymm1 = _mm256_mul_ps(ymm1, ymm2);\ - /*BLIS_CTRSM_MUL(vec1)*/\ - /*BLIS_CTRSM_MUL(vec2)*/\ - /*vec1 * ymm1*/\ - ymm3 = _mm256_shuffle_ps(ymm1, ymm1, 0x11);\ - ymm2 = _mm256_shuffle_ps(vec1, vec1, 0xA0);\ - ymm16 = _mm256_shuffle_ps(vec1, vec1,0xF5);\ - ymm16 = _mm256_mul_ps(ymm16, ymm3);\ - vec1 = _mm256_fmaddsub_ps(ymm2, ymm1, ymm16);\ - /*vec1 * ymm1*/\ - ymm2 = _mm256_shuffle_ps(vec2, vec2, 0xA0);\ - ymm16 = _mm256_shuffle_ps(vec2, vec2,0xF5);\ - ymm16 = _mm256_mul_ps(ymm16, ymm3);\ - vec2 = _mm256_fmaddsub_ps(ymm2, ymm1, ymm16);\ - /*done*/\ - ymm2 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0);\ - ymm1 = _mm256_mul_ps(ymm1, ymm2);\ - /*Taking denomerator multiplication of real & imaginary components*/\ - ymm3 = _mm256_mul_ps(ymm1, ymm1);\ - ymm2 = _mm256_permute_ps(ymm3, 0xB1);\ - ymm3 = _mm256_add_ps(ymm2, ymm3);\ - /*Dividing numerator by denominator*/\ - vec1 = _mm256_div_ps(vec1, ymm3);\ - vec2 = _mm256_div_ps(vec2, ymm3);\ + scomplex b_data[8];\ + scomplex d11_data[4];\ + \ + _mm256_storeu_ps((float *)(b_data), vec1);\ + _mm256_storeu_ps((float *)(b_data + 4), vec2);\ + _mm256_storeu_ps((float *)(d11_data), ymm1);\ + \ + for(dim_t i = 0; i < 8; i++)\ + {\ + bli_cinvscals(d11_data[0],b_data[i]);\ + }\ + \ + vec1 = _mm256_loadu_ps((float *)b_data);\ + vec2 = _mm256_loadu_ps((float *)(b_data+4));\ + \ }\ } @@ -40308,43 +40231,13 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB { if(transa) { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_ps((__m128 const *)(a11)); - ymm1 = _mm256_broadcast_ps((__m128 const *)(a11+cs_a*1 + 1)); - ymm2 = _mm256_broadcast_ps((__m128 const *)(a11+cs_a*2 + 2)); - ymm3 = _mm256_broadcast_ps((__m128 const *)(a11+cs_a*3 + 3)); - ymm0 = 
_mm256_permute_ps(ymm0, 0x44); - ymm1 = _mm256_permute_ps(ymm1, 0x44); - ymm2 = _mm256_permute_ps(ymm2, 0x44); - ymm3 = _mm256_permute_ps(ymm3, 0x44); + ctrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,m_rem); } else { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_ps((__m128 const *)(a11)); - ymm1 = _mm256_broadcast_ps((__m128 const *)(a11+rs_a*1 + 1)); - ymm2 = _mm256_broadcast_ps((__m128 const *)(a11+rs_a*2 + 2)); - ymm3 = _mm256_broadcast_ps((__m128 const *)(a11+rs_a*3 + 3)); - ymm0 = _mm256_permute_ps(ymm0, 0x44); - ymm1 = _mm256_permute_ps(ymm1, 0x44); - ymm2 = _mm256_permute_ps(ymm2, 0x44); - ymm3 = _mm256_permute_ps(ymm3, 0x44); + ctrsm_small_pack_diag_element(is_unitdiag,a11,rs_a,d11_pack,m_rem); } - - ymm1 = _mm256_shuffle_ps(ymm0, ymm1, 0x44); - ymm2 = _mm256_shuffle_ps(ymm2, ymm3, 0x44); - ymm1 = _mm256_blend_ps(ymm1, ymm2, 0xF0); - -#ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm7 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm4 = _mm256_mul_ps(ymm1, ymm1); - ymm6 = _mm256_permute_ps(ymm4, 0xB1); - ymm4 = _mm256_add_ps(ymm4, ymm6); - ymm1 = _mm256_mul_ps(ymm1, ymm7); - ymm1 = _mm256_div_ps(ymm1, ymm4); -#endif } - _mm256_storeu_ps((float *)(d11_pack), ymm1); for(j = 0; (j+d_nr-1) < n; j += d_nr) { @@ -42555,43 +42448,13 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB { if(transa) { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_ps((__m128 const *)(a11)); - ymm1 = _mm256_broadcast_ps((__m128 const *)(a11+cs_a*1 + 1)); - ymm2 = _mm256_broadcast_ps((__m128 const *)(a11+cs_a*2 + 2)); - ymm3 = _mm256_broadcast_ps((__m128 const *)(a11+cs_a*3 + 3)); - ymm0 = _mm256_permute_ps(ymm0, 0x44); - ymm1 = _mm256_permute_ps(ymm1, 0x44); - ymm2 = _mm256_permute_ps(ymm2, 0x44); - ymm3 = _mm256_permute_ps(ymm3, 0x44); + ctrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,m_rem); } else { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_ps((__m128 const *)(a11)); - ymm1 = _mm256_broadcast_ps((__m128 const *)(a11+rs_a*1 + 1)); - ymm2 = _mm256_broadcast_ps((__m128 const *)(a11+rs_a*2 + 2)); - ymm3 = _mm256_broadcast_ps((__m128 const *)(a11+rs_a*3 + 3)); - ymm0 = _mm256_permute_ps(ymm0, 0x44); - ymm1 = _mm256_permute_ps(ymm1, 0x44); - ymm2 = _mm256_permute_ps(ymm2, 0x44); - ymm3 = _mm256_permute_ps(ymm3, 0x44); + ctrsm_small_pack_diag_element(is_unitdiag,a11,rs_a,d11_pack,m_rem); } - - ymm1 = _mm256_shuffle_ps(ymm0, ymm1, 0x44); - ymm2 = _mm256_shuffle_ps(ymm2, ymm3, 0x44); - ymm1 = _mm256_blend_ps(ymm1, ymm2, 0xF0); - -#ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm7 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm4 = _mm256_mul_ps(ymm1, ymm1); - ymm6 = _mm256_permute_ps(ymm4, 0xB1); - ymm4 = _mm256_add_ps(ymm4, ymm6); - ymm1 = _mm256_mul_ps(ymm1, ymm7); - ymm1 = _mm256_div_ps(ymm1, ymm4); -#endif } - _mm256_storeu_ps((float *)(d11_pack), ymm1); for(j = (n - d_nr); (j + 1) > 0; j -= d_nr) { @@ -44147,30 +44010,13 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB { if(transa) { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_ps((__m128 const *)(a11)); - ymm1 = _mm256_broadcast_ps((__m128 const *) - (a11+cs_a*1 + 1)); + ctrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,n_rem); } else { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_ps((__m128 const *)(a11)); - ymm1 = _mm256_broadcast_ps((__m128 const *) - (a11+rs_a*1 + 1)); + ctrsm_small_pack_diag_element(is_unitdiag,a11,rs_a,d11_pack,n_rem); } - ymm1 = _mm256_shuffle_ps(ymm0, ymm1, 0x44); -#ifdef 
BLIS_ENABLE_TRSM_PREINVERSION - ymm7 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm4 = _mm256_mul_ps(ymm1, ymm1); - ymm6 = _mm256_permute_ps(ymm4, 0xB1); - ymm4 = _mm256_add_ps(ymm4, ymm6); - ymm1 = _mm256_mul_ps(ymm1, ymm7); - ymm1 = _mm256_div_ps(ymm1, ymm4); -#endif } - _mm_store_ps((float *)(d11_pack), - _mm256_extractf128_ps(ymm1,0)); for(i = (m-d_mr); (i+1) > 0; i -= d_mr) { @@ -44626,25 +44472,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB } - - ymm1 = _mm256_broadcast_ps((__m128 const *)&ones); - ymm1 = _mm256_permute_ps(ymm1, 0x44); if(!is_unitdiag) { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_ps((__m128 const *)(a11)); - ymm1 = _mm256_blend_ps(ymm0, ymm1, 0xC0); -#ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm7 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm4 = _mm256_mul_ps(ymm1, ymm1); - ymm6 = _mm256_permute_ps(ymm4, 0xB1); - ymm4 = _mm256_add_ps(ymm4, ymm6); - ymm1 = _mm256_mul_ps(ymm1, ymm7); - ymm1 = _mm256_div_ps(ymm1, ymm4); -#endif + ctrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,n_rem); } - _mm_store_ps((float *)(d11_pack), - _mm256_extractf128_ps(ymm1,0)); for(i = (m-d_mr); (i+1) > 0; i -= d_mr) { @@ -44899,7 +44730,6 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB scomplex *a01, *a11, *b10, *b11; //pointers that point to blocks for GEMM and TRSM - scomplex ones = {1.0, 1.0}; bool is_unitdiag = bli_obj_has_unit_diag(a); //scratch registers @@ -45658,37 +45488,17 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB } } - - ymm1 = _mm256_broadcast_ps((__m128 const *)&ones); - ymm1 = _mm256_permute_ps(ymm1, 0x44); if(!is_unitdiag) { if(transa) { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_ps((__m128 const *)(a11)); - ymm1 = _mm256_broadcast_ps((__m128 const *) - (a11+cs_a*1 + 1)); + ctrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,n_rem); } else { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_ps((__m128 const *)(a11)); - ymm1 = _mm256_broadcast_ps((__m128 const *) - (a11+rs_a*1 + 1)); + ctrsm_small_pack_diag_element(is_unitdiag,a11,rs_a,d11_pack,n_rem); } - ymm1 = _mm256_shuffle_ps(ymm0, ymm1, 0x44); -#ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm7 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm4 = _mm256_mul_ps(ymm1, ymm1); - ymm6 = _mm256_permute_ps(ymm4, 0xB1); - ymm4 = _mm256_add_ps(ymm4, ymm6); - ymm1 = _mm256_mul_ps(ymm1, ymm7); - ymm1 = _mm256_div_ps(ymm1, ymm4); -#endif } - _mm_store_ps((float *)(d11_pack), - _mm256_extractf128_ps(ymm1,0)); for(i = 0; (i+d_mr-1) < m; i += d_mr) { @@ -46153,25 +45963,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB } } - - ymm1 = _mm256_broadcast_ps((__m128 const *)&ones); - ymm1 = _mm256_permute_ps(ymm1, 0x44); if(!is_unitdiag) { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_ps((__m128 const *)(a11)); - ymm1 = _mm256_blend_ps(ymm0, ymm1, 0xC0); -#ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm7 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm4 = _mm256_mul_ps(ymm1, ymm1); - ymm6 = _mm256_permute_ps(ymm4, 0xB1); - ymm4 = _mm256_add_ps(ymm4, ymm6); - ymm1 = _mm256_mul_ps(ymm1, ymm7); - ymm1 = _mm256_div_ps(ymm1, ymm4); -#endif + ctrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,n_rem); } - _mm_store_ps((float *)(d11_pack), - _mm256_extractf128_ps(ymm1,0)); for(i = 0; (i+d_mr-1) < m; i += d_mr) { From e712ffe1391c27924f122fa1f7c88e9f2c2645d1 Mon Sep 17 00:00:00 2001 From: Dipal M Zambare Date: Tue, 22 Mar 2022 11:48:25 +0530 Subject: [PATCH 102/243] Added 
AOCL progress support for BLIS MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit -- AOCL libraries are used for lengthy computations which can go on for hours or days, once the operation is started, the user doesn’t get any update on current state of the computation. This (AOCL progress) feature enables user to receive a periodic update from the libraries. -- User registers a callback with the library if it is interested in receiving the periodic update. -- The library invokes this callback periodically with information about current state of the operation. -- The update frequency is statically set in the code, it can be modified as needed if the library is built from source. -- These feature is supported for GEMM and TRSM operations. -- Added example for GEMM and TRSM. -- Cleaned up and reformatted test_gemm.c and test_trsm.c to remove warnings and making indentation consistent across the file. AMD-Internal: [CPUPL-2082] Change-Id: I2aacdd8fb76f52e19e3850ee0295df49a8b7a90e --- aocl_dtl/aocldtl.h | 3 +- aocl_dtl/aoclos.c | 13 +- aocl_dtl/aoclos.h | 4 +- frame/3/gemm/bli_gemm_ker_var2.c | 18 +- frame/3/trsm/bli_trsm_xx_ker_var2.c | 55 +- frame/include/bli_config_macro_defs.h | 8 +- frame/thread/bli_l3_decor_openmp.c | 13 +- frame/thread/bli_l3_decor_single.c | 15 +- frame/util/CMakeLists.txt | 3 +- frame/util/bli_util.h | 5 +- frame/util/bli_util_progress.c | 56 ++ frame/util/bli_util_progress.h | 74 +++ test/test_gemm.c | 787 ++++++++++++++------------ test/test_trsm.c | 681 +++++++++++----------- 14 files changed, 1018 insertions(+), 717 deletions(-) create mode 100644 frame/util/bli_util_progress.c create mode 100644 frame/util/bli_util_progress.h diff --git a/aocl_dtl/aocldtl.h b/aocl_dtl/aocldtl.h index 58c1a56079..7ce81561b7 100644 --- a/aocl_dtl/aocldtl.h +++ b/aocl_dtl/aocldtl.h @@ -5,7 +5,7 @@ * It provides defination for all macros to be * used by user to add debug/trace information. * - * Copyright (C) 2020-2021, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All rights reserved. * *==================================================================*/ @@ -15,6 +15,7 @@ #include "aocldtlcf.h" #include "aocltpdef.h" #include "aoclflist.h" +#include "aoclos.h" #define TRACE_TYPE_FENTRY (1) #define TRACE_TYPE_FEXIT (2) diff --git a/aocl_dtl/aoclos.c b/aocl_dtl/aoclos.c index 92a489564e..896b1c89b3 100644 --- a/aocl_dtl/aoclos.c +++ b/aocl_dtl/aoclos.c @@ -3,7 +3,7 @@ * * Description : Abstraction for os services used by DTL. * - * Copyright (C) 2020, Advanced Micro Devices, Inc + * Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. * *==================================================================*/ #include "aocltpdef.h" @@ -85,8 +85,15 @@ uint64 AOCL_getTimestamp(void) #else /* Non linux support */ AOCL_TID AOCL_gettid(void) { - /* stub for other os's */ - return 0; +#ifdef BLIS_ENABLE_OPENMP + return omp_get_thread_num(); +#else +#ifdef BLIS_ENABLE_PTHREADS + return pthread_self(); +#else + return 0; +#endif +#endif } pid_t AOCL_getpid(void) diff --git a/aocl_dtl/aoclos.h b/aocl_dtl/aoclos.h index 3d8e1cddcc..57e0c24902 100644 --- a/aocl_dtl/aoclos.h +++ b/aocl_dtl/aoclos.h @@ -3,7 +3,7 @@ * * Description : Abstraction for os services used by DTL. * - * Copyright (C) 2020, Advanced Micro Devices, Inc + * Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. 
* *==================================================================*/ @@ -19,7 +19,7 @@ #define AOCL_malloc malloc #define AOCL_free free -uint32 AOCL_gettid(void); +AOCL_TID AOCL_gettid(void); pid_t AOCL_getpid(void); uint64 AOCL_getTimestamp(void); diff --git a/frame/3/gemm/bli_gemm_ker_var2.c b/frame/3/gemm/bli_gemm_ker_var2.c index 5e0a4ddb70..dc1c3d14dc 100644 --- a/frame/3/gemm/bli_gemm_ker_var2.c +++ b/frame/3/gemm/bli_gemm_ker_var2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -407,6 +407,22 @@ void PASTEMAC(ch,varname) \ } \ } \ \ +/* Send progress update if the user has enabled it */ \ +if(AOCL_progress_ptr) { \ + /* Running total for current thread */ \ + tls_aoclprogress_counter += m * n * k; \ + /* Send the update only if enough number of elements are processes */ \ + if ((tls_aoclprogress_counter - tls_aoclprogress_last_update) >= AOCL_PROGRESS_FREQUENCY) \ + { \ + tls_aoclprogress_last_update = tls_aoclprogress_counter; \ + AOCL_PROGRESS_DT(*MKSTR(ch), \ + "gemm", \ + tls_aoclprogress_counter, \ + AOCL_gettid(), \ + bli_rntm_num_threads(rntm)); \ + }\ +} \ + \ /* PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: b1", k, NR, b1, NR, 1, "%4.1f", "" ); \ PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: a1", MR, k, a1, 1, MR, "%4.1f", "" ); \ diff --git a/frame/3/trsm/bli_trsm_xx_ker_var2.c b/frame/3/trsm/bli_trsm_xx_ker_var2.c index de8cad065a..8d2f8689a9 100644 --- a/frame/3/trsm/bli_trsm_xx_ker_var2.c +++ b/frame/3/trsm/bli_trsm_xx_ker_var2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2020, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -87,6 +87,59 @@ void bli_trsm_xx_ker_var2 cntl, thread ); + + // Send progress update if enabled + if (AOCL_progress_ptr) + { + + // Get the size of block processed in + // this iteration, add it to the accumulated + // total and send the update. + dim_t m = bli_obj_length(c); + dim_t n = bli_obj_width(c); + dim_t k = bli_obj_width(a); + + num_t dt = bli_obj_dt(c); + char dt_c; + + // Running total for current thread. + tls_aoclprogress_counter += m * n * k; + + // Send the update only if number of elements processes so far + // has exceeded the freqency of reporting. + if ((tls_aoclprogress_counter - tls_aoclprogress_last_update) >= + AOCL_PROGRESS_FREQUENCY) + { + + // reset the last update counter for next iteration. 
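+            // Storing the running total (rather than zero) ensures the
+            // next update is sent only after another
+            // AOCL_PROGRESS_FREQUENCY elements have been processed by
+            // this thread.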
+ tls_aoclprogress_last_update = tls_aoclprogress_counter; + + switch (dt) + { + case BLIS_FLOAT: + dt_c = 's'; + break; + case BLIS_DOUBLE: + dt_c = 'd'; + break; + case BLIS_SCOMPLEX: + dt_c = 'c'; + break; + case BLIS_DCOMPLEX: + dt_c = 'z'; + break; + default: + dt_c = ' '; + } + + AOCL_PROGRESS_DT(dt_c, + "trsm", + tls_aoclprogress_counter, + AOCL_gettid(), + bli_rntm_num_threads(rntm)); + } + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_6); } diff --git a/frame/include/bli_config_macro_defs.h b/frame/include/bli_config_macro_defs.h index cfdc9652fc..dd6e8f6062 100644 --- a/frame/include/bli_config_macro_defs.h +++ b/frame/include/bli_config_macro_defs.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2019-2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -261,5 +261,11 @@ #endif +#ifdef BLIS_OS_WINDOWS + #define BLIS_TLS_TYPE __declspec(thread) +#else + #define BLIS_TLS_TYPE __thread +#endif + #endif diff --git a/frame/thread/bli_l3_decor_openmp.c b/frame/thread/bli_l3_decor_openmp.c index 0bf3ad8547..b01c208a30 100644 --- a/frame/thread/bli_l3_decor_openmp.c +++ b/frame/thread/bli_l3_decor_openmp.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -140,6 +140,17 @@ void bli_l3_thread_decorator bli_l3_thrinfo_create_root( tid, gl_comm, rntm_p, cntl_use, &thread ); #if 1 + // Reset the progress state to 0 as we are starting new operations. + // This counter track running progress in current thread. + tls_aoclprogress_counter = 0; + + // We send the update only after certain threshold is reached, + // The thresold is defined as AOCL_PROGRESS_FREQUENCY. + // This variable stores the counter value when last update was sent. + // It is compared with current counter value to see if it is time to + // send the next update. + tls_aoclprogress_last_update = 0; + func ( alpha, diff --git a/frame/thread/bli_l3_decor_single.c b/frame/thread/bli_l3_decor_single.c index 12f27ad873..444583e73e 100644 --- a/frame/thread/bli_l3_decor_single.c +++ b/frame/thread/bli_l3_decor_single.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -115,7 +115,18 @@ void bli_l3_thread_decorator // Create the root node of the thread's thrinfo_t structure. bli_l3_thrinfo_create_root( tid, gl_comm, rntm_p, cntl_use, &thread ); - + + // Reset the progress state to 0 as we are starting new operations. + // This counter track running progress in current thread. + tls_aoclprogress_counter = 0; + + // We send the update only after certain threshold is reached, + // The thresold is defined as AOCL_PROGRESS_FREQUENCY. + // This variable stores the counter value when last update was sent. + // It is compared with current counter value to see if it is time to + // send the next update. 
+ tls_aoclprogress_last_update = 0; + func ( alpha, diff --git a/frame/util/CMakeLists.txt b/frame/util/CMakeLists.txt index c20d7c525d..13fd53fc52 100644 --- a/frame/util/CMakeLists.txt +++ b/frame/util/CMakeLists.txt @@ -1,4 +1,4 @@ -##Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.## +##Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.## target_sources("${PROJECT_NAME}" PRIVATE @@ -13,4 +13,5 @@ target_sources("${PROJECT_NAME}" ${CMAKE_CURRENT_SOURCE_DIR}/bli_util_unb_var1.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_util_update.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_util_api_wrap.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_util_progress.c ) diff --git a/frame/util/bli_util.h b/frame/util/bli_util.h index 3c4e5722af..f7be273526 100644 --- a/frame/util/bli_util.h +++ b/frame/util/bli_util.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -63,3 +63,6 @@ // Header file define different formats of BLAS APIs- uppercase with // and without underscore, lowercase without underscore. #include "bli_util_api_wrap.h" + +// Public interface for the progress feature +#include "bli_util_progress.h" \ No newline at end of file diff --git a/frame/util/bli_util_progress.c b/frame/util/bli_util_progress.c new file mode 100644 index 0000000000..4097eb1126 --- /dev/null +++ b/frame/util/bli_util_progress.c @@ -0,0 +1,56 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +*/ + +#include "blis.h" + +// The progress feature periodically updates the user with current state +// of the operation, We maintain the progress for each thread separately +// following variables are used to store the elements processed in each +// thread using thread local storage. +BLIS_TLS_TYPE dim_t tls_aoclprogress_counter; + +// Store the counter when last update was sent, this is used to implement +// update freqency. +BLIS_TLS_TYPE dim_t tls_aoclprogress_last_update; + + +// AOCL_progress_ptr contains the pointer to the callback function +// By default it is set to NULL, which effectivly disabled the +// progress feature. +AOCL_progress_callback AOCL_progress_ptr = NULL; + +void AOCL_BLIS_set_progress(AOCL_progress_callback func) +{ + AOCL_progress_ptr = func; +} \ No newline at end of file diff --git a/frame/util/bli_util_progress.h b/frame/util/bli_util_progress.h new file mode 100644 index 0000000000..0e2a63eb1c --- /dev/null +++ b/frame/util/bli_util_progress.h @@ -0,0 +1,74 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLI_UTIL_PROGRESS_H +#define BLI_UTIL_PROGRESS_H + +// Public interface for the end user. + +typedef dim_t (*AOCL_progress_callback)(char *api, + dim_t lapi, + dim_t progress, + dim_t current_thread, + dim_t total_threads); + +BLIS_EXPORT_BLIS void AOCL_BLIS_set_progress(AOCL_progress_callback func); + +// Private interfaces for internal use + +extern AOCL_progress_callback AOCL_progress_ptr; + +extern BLIS_TLS_TYPE dim_t tls_aoclprogress_counter; +extern BLIS_TLS_TYPE dim_t tls_aoclprogress_last_update; + +// Define the frequency of reporting (number of elements). +// Progress update will be sent only after these many +// elements are processed in the current thread. 
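+// The thread-local counters above are advanced by m * n * k for each
+// block processed in the GEMM/TRSM kernel variants, so the default of
+// 1e+9 yields roughly one callback per billion elements per thread.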
+#define AOCL_PROGRESS_FREQUENCY 1e+9 + +#define MAX_API_NAME_LEN 20 + +// Macro to send update using datatype character and the api name +#define AOCL_PROGRESS_DT(dt, api, progress, tid, nt) \ + char buf[MAX_API_NAME_LEN]; \ + snprintf(buf, MAX_API_NAME_LEN, "%c%s", dt, api); \ + (*AOCL_progress_ptr) (buf, strlen(buf), progress, tid, nt); \ + +// Macro to send update using api name alone. +#define AOCL_PROGRESS_NAME(api, progress, tid, nt) \ + char buf[MAX_API_NAME_LEN]; \ + snprintf(buf, MAX_API_NAME_LEN, "%s", dt, api); \ + (*AOCL_progress_ptr) (buf, strlen(buf), progress, tid, nt); \ + +#endif // BLI_UTIL_PROGRESS_H diff --git a/test/test_gemm.c b/test/test_gemm.c index 25fc5e3d8d..81b7e36616 100644 --- a/test/test_gemm.c +++ b/test/test_gemm.c @@ -10,14 +10,14 @@ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name of The University of Texas nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name of The University of Texas nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -47,426 +47,471 @@ // uncomment to enable cblas interface //#define CBLAS -int main( int argc, char** argv ) +// Uncomment to enable progress printing. 
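+// When enabled, the AOCL_progress callback defined below is registered
+// with the library via AOCL_BLIS_set_progress() in main() and is then
+// invoked from the BLIS threads as elements are processed.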
+//#define PROGRESS_ENABLED + +#ifdef PROGRESS_ENABLED +dim_t AOCL_progress(char *api, + dim_t lapi, + dim_t progress, + dim_t current_thread, + dim_t total_threads) +{ + printf("\n%s, len = %ld, nt = %ld, tid = %ld, Processed %ld Elements", + api, lapi, total_threads, current_thread, progress); + + return 0; +} +#endif + +int main(int argc, char **argv) { - obj_t a, b, c; - obj_t c_save; - obj_t alpha, beta; - dim_t m, n, k; - inc_t lda, ldb, ldc; - num_t dt, dt_a; - inc_t r, n_repeats; - trans_t transa; - trans_t transb; - f77_char f77_transa; - f77_char f77_transb; - - double dtime; - double dtime_save; - double gflops; - - //bli_init(); - //bli_error_checking_level_set( BLIS_NO_ERROR_CHECKING ); - - n_repeats = 300; - - //dt = BLIS_FLOAT; - dt = BLIS_DOUBLE; - //dt = BLIS_SCOMPLEX; - //dt = BLIS_DCOMPLEX; - - if( bli_is_real( dt ) || bli_is_scomplex( dt ) ) + obj_t a, b, c; + obj_t c_save; + obj_t alpha, beta; + dim_t m, n, k; + inc_t lda, ldb, ldc; + num_t dt, dt_a; + inc_t r, n_repeats; + trans_t transa; + trans_t transb; + f77_char f77_transa; + f77_char f77_transb; + + double dtime; + double dtime_save; + double gflops; + +#ifdef PROGRESS_ENABLED + AOCL_BLIS_set_progress(AOCL_progress); +#endif + + // bli_init(); + // bli_error_checking_level_set( BLIS_NO_ERROR_CHECKING ); + + n_repeats = 300; + + // dt = BLIS_FLOAT; + dt = BLIS_DOUBLE; + // dt = BLIS_SCOMPLEX; + // dt = BLIS_DCOMPLEX; + + if (bli_is_real(dt) || bli_is_scomplex(dt)) dt_a = dt; else { dt_a = dt; // Enable the following to call - // dzgemm - //dt_a = BLIS_DOUBLE; + // dzgemm + // dt_a = BLIS_DOUBLE; } const char stor_scheme = 'C'; - transa = BLIS_NO_TRANSPOSE; - transb = BLIS_NO_TRANSPOSE; - - bli_param_map_blis_to_netlib_trans( transa, &f77_transa ); - bli_param_map_blis_to_netlib_trans( transb, &f77_transb ); + transa = BLIS_NO_TRANSPOSE; + transb = BLIS_NO_TRANSPOSE; + bli_param_map_blis_to_netlib_trans(transa, &f77_transa); + bli_param_map_blis_to_netlib_trans(transb, &f77_transb); printf("BLIS Library version is : %s\n", bli_info_get_version_str()); #ifdef FILE_IN_OUT - FILE* fin = NULL; - FILE* fout = NULL; - if (argc < 3){ - printf("Usage: ./test_gemm_XX.x input.csv output.csv\n"); - exit(1); - } - fin = fopen(argv[1], "r"); - if (fin == NULL){ - printf("Error opening the file %s\n", argv[1]); - exit(1); - } - fout = fopen(argv[2], "w"); - if (fout == NULL){ - printf("Error opening output file %s\n", argv[2]); - exit(1); - } - fprintf(fout, "m\t k\t n\t cs_a\t cs_b\t cs_c\t gflops\n"); - printf("~~~~~~~~~~_BLAS\t m\t k\t n\t cs_a\t cs_b\t cs_c \t gflops\n"); - - while (fscanf(fin, "%lld %lld %lld %lld %lld %lld\n", &m, &k, &n, &lda, &ldb, &ldc) == 6) - { - // dimensions should not be greater than leading dimensions - // These are valid only when Op(A) = n and op(B) = n - if( (stor_scheme == 'C') || (stor_scheme == 'c') ) { - if ((m > lda) || (k > ldb) || (m > ldc)) continue; - }else if( (stor_scheme == 'R') || (stor_scheme == 'r') ) { - // leading dimension should be greater than number of cols - if ((k > lda) || (n > ldb) || (n > ldc)) continue; - }else { - printf("Invalid Storage type\n"); - continue; - } + FILE *fin = NULL; + FILE *fout = NULL; + if (argc < 3) + { + printf("Usage: ./test_gemm_XX.x input.csv output.csv\n"); + exit(1); + } + fin = fopen(argv[1], "r"); + if (fin == NULL) + { + printf("Error opening the file %s\n", argv[1]); + exit(1); + } + fout = fopen(argv[2], "w"); + if (fout == NULL) + { + printf("Error opening output file %s\n", argv[2]); + exit(1); + } + fprintf(fout, "m\t k\t n\t cs_a\t 
cs_b\t cs_c\t gflops\n"); + printf("~~~~~~~~~~_BLAS\t m\t k\t n\t cs_a\t cs_b\t cs_c \t gflops\n"); + + while (fscanf(fin, "%ld %ld %ld %ld %ld %ld\n", &m, &k, &n, &lda, &ldb, &ldc) == 6) + { + // dimensions should not be greater than leading dimensions + // These are valid only when Op(A) = n and op(B) = n + if ((stor_scheme == 'C') || (stor_scheme == 'c')) + { + if ((m > lda) || (k > ldb) || (m > ldc)) + continue; + } + else if ((stor_scheme == 'R') || (stor_scheme == 'r')) + { + // leading dimension should be greater than number of cols + if ((k > lda) || (n > ldb) || (n > ldc)) + continue; + } + else + { + printf("Invalid Storage type\n"); + continue; + } #else - dim_t p, p_begin, p_end, p_inc; - dim_t m_input, n_input, k_input; - p_begin = 200; - p_end = 2000; - p_inc = 200; - - m_input = n_input = k_input = -1; - for ( p = p_begin; p <= p_end; p += p_inc ) - { - if ( m_input < 0 ) m = p * ( dim_t )abs(m_input); - else m = ( dim_t ) m_input; - if ( n_input < 0 ) n = p * ( dim_t )abs(n_input); - else n = ( dim_t ) n_input; - if ( k_input < 0 ) k = p * ( dim_t )abs(k_input); - else k = ( dim_t ) k_input; - - if( (stor_scheme == 'C') || (stor_scheme == 'c') ) { - lda = m; ldb = k, ldc = m; - }else if( (stor_scheme == 'R') || (stor_scheme == 'r') ) { - lda = k; ldb = n, ldc = n; - } + dim_t p, p_begin, p_end, p_inc; + dim_t m_input, n_input, k_input; + p_begin = 200; + p_end = 2000; + p_inc = 200; + + m_input = n_input = k_input = -1; + for (p = p_begin; p <= p_end; p += p_inc) + { + if (m_input < 0) + m = p * (dim_t)labs(m_input); + else + m = (dim_t)m_input; + if (n_input < 0) + n = p * (dim_t)labs(n_input); + else + n = (dim_t)n_input; + if (k_input < 0) + k = p * (dim_t)labs(k_input); + else + k = (dim_t)k_input; + + if ((stor_scheme == 'C') || (stor_scheme == 'c')) + { + lda = m; + ldb = k, ldc = m; + } + else if ((stor_scheme == 'R') || (stor_scheme == 'r')) + { + lda = k; + ldb = n, ldc = n; + } #endif - bli_obj_create( dt, 1, 1, 0, 0, &alpha); - bli_obj_create( dt, 1, 1, 0, 0, &beta ); - - siz_t elem_size = bli_dt_size( dt ); - - lda = bli_align_dim_to_size( lda, elem_size, BLIS_HEAP_STRIDE_ALIGN_SIZE ); - ldb = bli_align_dim_to_size( ldb, elem_size, BLIS_HEAP_STRIDE_ALIGN_SIZE ); - ldc = bli_align_dim_to_size( ldc, elem_size, BLIS_HEAP_STRIDE_ALIGN_SIZE ); - - // Will verify the leading dimension is powers of 2 and add 64bytes. - inc_t n_bytes = lda*sizeof(dt_a); - - if((n_bytes!=0) && !(n_bytes&(n_bytes-1)))// check whether n_bytes is power of 2. - lda += BLIS_SIMD_ALIGN_SIZE/sizeof(dt_a); - - n_bytes = ldb*sizeof(dt); - if((n_bytes!=0) && !(n_bytes&(n_bytes-1)))// check whether n_bytes is power of 2. - ldb += BLIS_SIMD_ALIGN_SIZE/sizeof(dt); - - n_bytes = ldc*sizeof(dt); - if((n_bytes!=0) && !(n_bytes&(n_bytes-1)))// check whether n_bytes is power of 2. 
- ldc += BLIS_SIMD_ALIGN_SIZE/sizeof(dt); - - if( (stor_scheme == 'C') || (stor_scheme == 'c') ) - { - // Col-major Order - bli_obj_create( dt_a, m, k, 1, lda, &a ); - bli_obj_create( dt, k, n, 1, ldb, &b ); - bli_obj_create( dt, m, n, 1, ldc, &c ); - bli_obj_create( dt, m, n, 1, ldc, &c_save ); - } - else if( (stor_scheme == 'R') || (stor_scheme == 'r') ) - { - // Row-major Order - bli_obj_create( dt_a, m, k, lda, 1, &a ); - bli_obj_create( dt, k, n, ldb, 1, &b ); - bli_obj_create( dt, m, n, ldc, 1, &c ); - bli_obj_create( dt, m, n, ldc, 1, &c_save ); - } - else - { - printf("Invalid Storage type\n"); - continue; - } + bli_obj_create(dt, 1, 1, 0, 0, &alpha); + bli_obj_create(dt, 1, 1, 0, 0, &beta); + + siz_t elem_size = bli_dt_size(dt); + + lda = bli_align_dim_to_size(lda, elem_size, BLIS_HEAP_STRIDE_ALIGN_SIZE); + ldb = bli_align_dim_to_size(ldb, elem_size, BLIS_HEAP_STRIDE_ALIGN_SIZE); + ldc = bli_align_dim_to_size(ldc, elem_size, BLIS_HEAP_STRIDE_ALIGN_SIZE); + + // Will verify the leading dimension is powers of 2 and add 64bytes. + inc_t n_bytes = lda * sizeof(dt_a); + + if ((n_bytes != 0) && !(n_bytes & (n_bytes - 1))) // check whether n_bytes is power of 2. + lda += BLIS_SIMD_ALIGN_SIZE / sizeof(dt_a); + + n_bytes = ldb * sizeof(dt); + if ((n_bytes != 0) && !(n_bytes & (n_bytes - 1))) // check whether n_bytes is power of 2. + ldb += BLIS_SIMD_ALIGN_SIZE / sizeof(dt); + + n_bytes = ldc * sizeof(dt); + if ((n_bytes != 0) && !(n_bytes & (n_bytes - 1))) // check whether n_bytes is power of 2. + ldc += BLIS_SIMD_ALIGN_SIZE / sizeof(dt); + + if ((stor_scheme == 'C') || (stor_scheme == 'c')) + { + // Col-major Order + bli_obj_create(dt_a, m, k, 1, lda, &a); + bli_obj_create(dt, k, n, 1, ldb, &b); + bli_obj_create(dt, m, n, 1, ldc, &c); + bli_obj_create(dt, m, n, 1, ldc, &c_save); + } + else if ((stor_scheme == 'R') || (stor_scheme == 'r')) + { + // Row-major Order + bli_obj_create(dt_a, m, k, lda, 1, &a); + bli_obj_create(dt, k, n, ldb, 1, &b); + bli_obj_create(dt, m, n, ldc, 1, &c); + bli_obj_create(dt, m, n, ldc, 1, &c_save); + } + else + { + printf("Invalid Storage type\n"); + continue; + } #ifdef MATRIX_INITIALISATION - bli_randm( &a ); - bli_randm( &b ); - bli_randm( &c ); + bli_randm(&a); + bli_randm(&b); + bli_randm(&c); #endif - bli_obj_set_conjtrans( transa, &a); - bli_obj_set_conjtrans( transb, &b); - bli_setsc( (0.9/1.0), 0.2, &alpha ); - bli_setsc( -(1.1/1.0), 0.3, &beta ); - - bli_copym( &c, &c_save ); - dtime_save = DBL_MAX; - for ( r = 0; r < n_repeats; ++r ) - { - bli_copym( &c_save, &c ); - dtime = bli_clock(); + bli_obj_set_conjtrans(transa, &a); + bli_obj_set_conjtrans(transb, &b); + bli_setsc((0.9 / 1.0), 0.2, &alpha); + bli_setsc(-(1.1 / 1.0), 0.3, &beta); + + bli_copym(&c, &c_save); + dtime_save = DBL_MAX; + for (r = 0; r < n_repeats; ++r) + { + bli_copym(&c_save, &c); + dtime = bli_clock(); #ifdef BLIS - bli_gemm( &alpha, - &a, - &b, - &beta, - &c ); + bli_gemm(&alpha, + &a, + &b, + &beta, + &c); #else - f77_int lda, ldb, ldc; - f77_int mm = bli_obj_length( &c ); - f77_int kk = bli_obj_width_after_trans( &a ); - f77_int nn = bli_obj_width( &c ); + f77_int lda, ldb, ldc; + f77_int mm = bli_obj_length(&c); + f77_int kk = bli_obj_width_after_trans(&a); + f77_int nn = bli_obj_width(&c); #ifdef CBLAS - enum CBLAS_ORDER cblas_order; - enum CBLAS_TRANSPOSE cblas_transa; - enum CBLAS_TRANSPOSE cblas_transb; - - if ( bli_obj_row_stride( &c ) == 1 ){ - cblas_order = CblasColMajor; - }else{ - cblas_order = CblasRowMajor; - } - - if( bli_is_trans( transa ) ) - cblas_transa = 
CblasTrans; - else if( bli_is_conjtrans( transa ) ) - cblas_transa = CblasConjTrans; - else - cblas_transa = CblasNoTrans; - - if( bli_is_trans( transb ) ) - cblas_transb = CblasTrans; - else if( bli_is_conjtrans( transb ) ) - cblas_transb = CblasConjTrans; - else - cblas_transb = CblasNoTrans; + enum CBLAS_ORDER cblas_order; + enum CBLAS_TRANSPOSE cblas_transa; + enum CBLAS_TRANSPOSE cblas_transb; + + if (bli_obj_row_stride(&c) == 1) + { + cblas_order = CblasColMajor; + } + else + { + cblas_order = CblasRowMajor; + } + + if (bli_is_trans(transa)) + cblas_transa = CblasTrans; + else if (bli_is_conjtrans(transa)) + cblas_transa = CblasConjTrans; + else + cblas_transa = CblasNoTrans; + + if (bli_is_trans(transb)) + cblas_transb = CblasTrans; + else if (bli_is_conjtrans(transb)) + cblas_transb = CblasConjTrans; + else + cblas_transb = CblasNoTrans; #else - f77_char f77_transa; - f77_char f77_transb; - bli_param_map_blis_to_netlib_trans( transa, &f77_transa ); - bli_param_map_blis_to_netlib_trans( transb, &f77_transb ); + f77_char f77_transa; + f77_char f77_transb; + bli_param_map_blis_to_netlib_trans(transa, &f77_transa); + bli_param_map_blis_to_netlib_trans(transb, &f77_transb); #endif - if( (stor_scheme == 'C') || (stor_scheme == 'c') ){ - lda = bli_obj_col_stride( &a ); - ldb = bli_obj_col_stride( &b ); - ldc = bli_obj_col_stride( &c ); - } else { - lda = bli_obj_row_stride( &a ); - ldb = bli_obj_row_stride( &b ); - ldc = bli_obj_row_stride( &c ); - } - - if ( bli_is_float( dt ) ) - { - float* alphap = bli_obj_buffer( &alpha ); - float* ap = bli_obj_buffer( &a ); - float* bp = bli_obj_buffer( &b ); - float* betap = bli_obj_buffer( &beta ); - float* cp = bli_obj_buffer( &c ); + if ((stor_scheme == 'C') || (stor_scheme == 'c')) + { + lda = bli_obj_col_stride(&a); + ldb = bli_obj_col_stride(&b); + ldc = bli_obj_col_stride(&c); + } + else + { + lda = bli_obj_row_stride(&a); + ldb = bli_obj_row_stride(&b); + ldc = bli_obj_row_stride(&c); + } + + if (bli_is_float(dt)) + { + float *alphap = bli_obj_buffer(&alpha); + float *ap = bli_obj_buffer(&a); + float *bp = bli_obj_buffer(&b); + float *betap = bli_obj_buffer(&beta); + float *cp = bli_obj_buffer(&c); #ifdef CBLAS - cblas_sgemm( cblas_order, - cblas_transa, - cblas_transb, - mm, - nn, - kk, - *alphap, - ap, lda, - bp, ldb, - *betap, - cp, ldc - ); + cblas_sgemm(cblas_order, + cblas_transa, + cblas_transb, + mm, + nn, + kk, + *alphap, + ap, lda, + bp, ldb, + *betap, + cp, ldc); #else - sgemm_( &f77_transa, - &f77_transb, - &mm, - &nn, - &kk, - alphap, - ap, (f77_int*)&lda, - bp, (f77_int*)&ldb, - betap, - cp, (f77_int*)&ldc ); + sgemm_(&f77_transa, + &f77_transb, + &mm, + &nn, + &kk, + alphap, + ap, (f77_int *)&lda, + bp, (f77_int *)&ldb, + betap, + cp, (f77_int *)&ldc); #endif - }else if ( bli_is_double( dt ) ) - { - double* alphap = bli_obj_buffer( &alpha ); - double* ap = bli_obj_buffer( &a ); - double* bp = bli_obj_buffer( &b ); - double* betap = bli_obj_buffer( &beta ); - double* cp = bli_obj_buffer( &c ); + } + else if (bli_is_double(dt)) + { + double *alphap = bli_obj_buffer(&alpha); + double *ap = bli_obj_buffer(&a); + double *bp = bli_obj_buffer(&b); + double *betap = bli_obj_buffer(&beta); + double *cp = bli_obj_buffer(&c); #ifdef CBLAS - cblas_dgemm( cblas_order, - cblas_transa, - cblas_transb, - mm, - nn, - kk, - *alphap, - ap, lda, - bp, ldb, - *betap, - cp, ldc - ); + cblas_dgemm(cblas_order, + cblas_transa, + cblas_transb, + mm, + nn, + kk, + *alphap, + ap, lda, + bp, ldb, + *betap, + cp, ldc); #else - dgemm_( &f77_transa, - 
&f77_transb, - &mm, - &nn, - &kk, - alphap, - ap, (f77_int*)&lda, - bp, (f77_int*)&ldb, - betap, - cp, (f77_int*)&ldc ); + dgemm_(&f77_transa, + &f77_transb, + &mm, + &nn, + &kk, + alphap, + ap, (f77_int *)&lda, + bp, (f77_int *)&ldb, + betap, + cp, (f77_int *)&ldc); #endif - }else if ( bli_is_scomplex( dt ) ) - { - scomplex* alphap = bli_obj_buffer( &alpha ); - scomplex* ap = bli_obj_buffer( &a ); - scomplex* bp = bli_obj_buffer( &b ); - scomplex* betap = bli_obj_buffer( &beta ); - scomplex* cp = bli_obj_buffer( &c ); + } + else if (bli_is_scomplex(dt)) + { + scomplex *alphap = bli_obj_buffer(&alpha); + scomplex *ap = bli_obj_buffer(&a); + scomplex *bp = bli_obj_buffer(&b); + scomplex *betap = bli_obj_buffer(&beta); + scomplex *cp = bli_obj_buffer(&c); #ifdef CBLAS - cblas_cgemm( cblas_order, - cblas_transa, - cblas_transb, - mm, - nn, - kk, - alphap, - ap, lda, - bp, ldb, - betap, - cp, ldc - ); + cblas_cgemm(cblas_order, + cblas_transa, + cblas_transb, + mm, + nn, + kk, + alphap, + ap, lda, + bp, ldb, + betap, + cp, ldc); #else - cgemm_( &f77_transa, - &f77_transb, - &mm, - &nn, - &kk, - alphap, - ap, (f77_int*)&lda, - bp, (f77_int*)&ldb, - betap, - cp, (f77_int*)&ldc ); + cgemm_(&f77_transa, + &f77_transb, + &mm, + &nn, + &kk, + alphap, + ap, (f77_int *)&lda, + bp, (f77_int *)&ldb, + betap, + cp, (f77_int *)&ldc); #endif - }else if ( bli_is_dcomplex( dt ) ) - { - dcomplex* alphap = bli_obj_buffer( &alpha ); - dcomplex* ap = bli_obj_buffer( &a ); - dcomplex* bp = bli_obj_buffer( &b ); - dcomplex* betap = bli_obj_buffer( &beta ); - dcomplex* cp = bli_obj_buffer( &c ); + } + else if (bli_is_dcomplex(dt)) + { + dcomplex *alphap = bli_obj_buffer(&alpha); + dcomplex *ap = bli_obj_buffer(&a); + dcomplex *bp = bli_obj_buffer(&b); + dcomplex *betap = bli_obj_buffer(&beta); + dcomplex *cp = bli_obj_buffer(&c); #ifdef CBLAS - cblas_zgemm( cblas_order, - cblas_transa, - cblas_transb, - mm, - nn, - kk, - alphap, - ap, lda, - bp, ldb, - betap, - cp, ldc - ); + cblas_zgemm(cblas_order, + cblas_transa, + cblas_transb, + mm, + nn, + kk, + alphap, + ap, lda, + bp, ldb, + betap, + cp, ldc); #else #if 1 - if( bli_is_double( dt_a ) ) - { - dzgemm_( - &f77_transa, - &f77_transb, - &mm, - &nn, - &kk, - alphap, - (double*)ap, (f77_int*)&lda, - bp, (f77_int*)&ldb, - betap, - cp, (f77_int*)&ldc - ); - } - else - { - zgemm_( &f77_transa, - &f77_transb, - &mm, - &nn, - &kk, - alphap, - ap, (f77_int*)&lda, - bp, (f77_int*)&ldb, - betap, - cp, (f77_int*)&ldc ); - } + if (bli_is_double(dt_a)) + { + dzgemm_( + &f77_transa, + &f77_transb, + &mm, + &nn, + &kk, + alphap, + (double *)ap, (f77_int *)&lda, + bp, (f77_int *)&ldb, + betap, + cp, (f77_int *)&ldc); + } + else + { + zgemm_(&f77_transa, + &f77_transb, + &mm, + &nn, + &kk, + alphap, + ap, (f77_int *)&lda, + bp, (f77_int *)&ldb, + betap, + cp, (f77_int *)&ldc); + } #endif #endif - } + } #endif #ifdef PRINT - bli_printm( "a", &a, "%4.1f", "" ); - bli_printm( "b", &b, "%4.1f", "" ); - bli_printm( "c", &c, "%4.1f", "" ); - bli_printm( "c after", &c, "%4.1f", "" ); - exit(1); + bli_printm("a", &a, "%4.1f", ""); + bli_printm("b", &b, "%4.1f", ""); + bli_printm("c", &c, "%4.1f", ""); + bli_printm("c after", &c, "%4.1f", ""); + exit(1); #endif - dtime_save = bli_clock_min_diff( dtime_save, dtime ); - }//nrepeats + dtime_save = bli_clock_min_diff(dtime_save, dtime); + } // nrepeats - gflops = ( 2.0 * m * k * n ) / ( dtime_save * 1.0e9 ); - if (bli_is_dcomplex(dt) && (bli_is_double(dt_a))) - gflops *= 2.0; - else if ( bli_is_complex( dt ) ) gflops *= 4.0; + gflops = (2.0 * 
m * k * n) / (dtime_save * 1.0e9); + if (bli_is_dcomplex(dt) && (bli_is_double(dt_a))) + gflops *= 2.0; + else if (bli_is_complex(dt)) + gflops *= 4.0; #ifdef BLIS - printf("data_gemm_blis" ); + printf("data_gemm_blis"); #else - printf("data_gemm_%s", BLAS ); + printf("data_gemm_%s", BLAS); #endif - #ifdef FILE_IN_OUT - printf("%6lu \t %4lu \t %4lu \t %4lu \t %4lu \t %4lu \t %6.3f\n", \ - ( unsigned long )m,( unsigned long )k,( unsigned long )n, - (unsigned long)lda,(unsigned long)ldb,(unsigned long)ldc,gflops); + printf("%6lu \t %4lu \t %4lu \t %4lu \t %4lu \t %4lu \t %6.3f\n", + (unsigned long)m, (unsigned long)k, (unsigned long)n, + (unsigned long)lda, (unsigned long)ldb, (unsigned long)ldc, gflops); - fprintf(fout, "%6lu \t %4lu \t %4lu \t %4lu \t %4lu \t %4lu \t %6.3f\n", \ - ( unsigned long )m,( unsigned long )k,( unsigned long )n, - (unsigned long)lda,(unsigned long)ldb,(unsigned long)ldc,gflops); - fflush(fout); + fprintf(fout, "%6lu \t %4lu \t %4lu \t %4lu \t %4lu \t %4lu \t %6.3f\n", + (unsigned long)m, (unsigned long)k, (unsigned long)n, + (unsigned long)lda, (unsigned long)ldb, (unsigned long)ldc, gflops); + fflush(fout); #else - printf( "( %2lu, 1:4 ) = [ %4lu %4lu %4lu %7.2f ];\n", - ( unsigned long )(p - p_begin)/p_inc + 1, - ( unsigned long )m,( unsigned long )k, - ( unsigned long )n, gflops ); + printf("( %2lu, 1:4 ) = [ %4lu %4lu %4lu %7.2f ];\n", + (unsigned long)(p - p_begin) / p_inc + 1, + (unsigned long)m, (unsigned long)k, + (unsigned long)n, gflops); #endif - bli_obj_free( &alpha ); - bli_obj_free( &beta ); + bli_obj_free(&alpha); + bli_obj_free(&beta); - bli_obj_free( &a ); - bli_obj_free( &b ); - bli_obj_free( &c ); - bli_obj_free( &c_save ); - }//while + bli_obj_free(&a); + bli_obj_free(&b); + bli_obj_free(&c); + bli_obj_free(&c_save); + } // while - //bli_finalize(); + // bli_finalize(); #ifdef FILE_IN_OUT - fclose(fin); - fclose(fout); + fclose(fin); + fclose(fout); #endif - return 0; + return 0; } diff --git a/test/test_trsm.c b/test/test_trsm.c index 72156d92fe..f6709f5d7f 100644 --- a/test/test_trsm.c +++ b/test/test_trsm.c @@ -5,19 +5,19 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020-2021, Advanced Micro Devices, Inc. + Copyright (C) 2020-2022, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. 
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -50,14 +50,31 @@ #define CACHE_LINE_SIZE 64 -int main( int argc, char** argv ) +// Uncomment to enable progress printing. +//#define PROGRESS_ENABLED + +#ifdef PROGRESS_ENABLED +dim_t AOCL_progress(char *api, + dim_t lapi, + dim_t progress, + dim_t current_thread, + dim_t total_threads) +{ + printf("\n%s, len = %ld, nt = %ld, tid = %ld, Processed %ld Elements", + api, lapi, total_threads, current_thread, progress); + + return 0; +} +#endif + +int main(int argc, char **argv) { obj_t a, c; obj_t c_save; obj_t alpha; dim_t m, n; num_t dt; - int r, n_repeats; + int r, n_repeats; side_t side; uplo_t uploa; trans_t transa; @@ -72,16 +89,20 @@ int main( int argc, char** argv ) double gflops; #ifdef FILE_IN_OUT - FILE* fin = NULL; - FILE* fout = NULL; + FILE *fin = NULL; + FILE *fout = NULL; #else dim_t p; dim_t p_begin, p_end, p_inc; - int m_input, n_input; + int m_input, n_input; - //bli_init(); +#ifdef PROGRESS_ENABLED + AOCL_BLIS_set_progress(AOCL_progress); +#endif + + // bli_init(); - //bli_error_checking_level_set( BLIS_NO_ERROR_CHECKING ); + // bli_error_checking_level_set( BLIS_NO_ERROR_CHECKING ); #ifndef PRINT p_begin = 200; @@ -102,26 +123,26 @@ int main( int argc, char** argv ) n_repeats = 3; - //dt = BLIS_FLOAT; + // dt = BLIS_FLOAT; dt = BLIS_DOUBLE; - //dt = BLIS_SCOMPLEX; - //dt = BLIS_DCOMPLEX; + // dt = BLIS_SCOMPLEX; + // dt = BLIS_DCOMPLEX; #ifdef FILE_IN_OUT - if(argc < 3) + if (argc < 3) { printf("Usage: ./test_trsm_XX.x input.csv output.csv\n"); exit(1); } fin = fopen(argv[1], "r"); - if(fin == NULL) + if (fin == NULL) { printf("Error opening the file %s\n", argv[1]); exit(1); } fout = fopen(argv[2], "w"); - if(fout == NULL) + if (fout == NULL) { printf("Error opening the file %s\n", argv[2]); exit(1); @@ -130,425 +151,421 @@ int main( int argc, char** argv ) inc_t cs_b; #ifdef READ_ALL_PARAMS_FROM_FILE char side_c, uploa_c, transa_c, diaga_c; - + fprintf(fout, "side, uploa, transa, diaga, m\t n\t cs_a\t cs_b\t gflops\n"); printf("~~~~~~~_BLAS\t side, uploa, transa, diaga, m\t n\t cs_a\t cs_b\t gflops\n"); - while(fscanf(fin, "%c %c %c %c %ld %ld %ld %ld\n", &side_c, &uploa_c, &transa_c, &diaga_c, &m, &n, &cs_a, &cs_b) == 8) + while (fscanf(fin, "%c %c %c %c %ld %ld %ld %ld\n", &side_c, &uploa_c, &transa_c, &diaga_c, &m, &n, &cs_a, &cs_b) == 8) { - if( 'l' == side_c|| 'L' == side_c) - side = BLIS_LEFT; - else if('r' == side_c || 'R' == side_c) - side = BLIS_RIGHT; - else - { - printf("Invalid entry for the argument 'side':%c\n",side_c); - continue; - } + if ('l' == side_c || 'L' == side_c) + side = BLIS_LEFT; + else if ('r' == side_c || 'R' == side_c) + side = BLIS_RIGHT; + else + { + printf("Invalid entry for the argument 'side':%c\n", side_c); + continue; + } - if('l' == uploa_c || 'L' == uploa_c) - uploa = BLIS_LOWER; - else if('u' == uploa_c || 'U' == uploa_c) - uploa = BLIS_UPPER; - else - { - printf("Invalid entry for the argument 'uplo':%c\n",uploa_c); - continue; - } + if ('l' == uploa_c || 'L' == uploa_c) + uploa = BLIS_LOWER; + else if ('u' == uploa_c || 'U' == uploa_c) + uploa = BLIS_UPPER; + else + { + printf("Invalid entry for the argument 'uplo':%c\n", uploa_c); + continue; + } - if('t' == transa_c || 'T' == transa_c) - transa = BLIS_TRANSPOSE; - else if('n' == transa_c || 'N' == transa_c) - transa = BLIS_NO_TRANSPOSE; - else - { - printf("Invalid entry for the argument 'transa':%c\n",transa_c); - continue; - } - - if('u' == 
diaga_c || 'U' == diaga_c) - diaga = BLIS_UNIT_DIAG; - else if('n' == diaga_c || 'N' == diaga_c) - diaga = BLIS_NONUNIT_DIAG; - else - { - printf("Invalid entry for the argument 'diaga':%c\n", diaga_c); - continue; - } + if ('t' == transa_c || 'T' == transa_c) + transa = BLIS_TRANSPOSE; + else if ('n' == transa_c || 'N' == transa_c) + transa = BLIS_NO_TRANSPOSE; + else + { + printf("Invalid entry for the argument 'transa':%c\n", transa_c); + continue; + } + + if ('u' == diaga_c || 'U' == diaga_c) + diaga = BLIS_UNIT_DIAG; + else if ('n' == diaga_c || 'N' == diaga_c) + diaga = BLIS_NONUNIT_DIAG; + else + { + printf("Invalid entry for the argument 'diaga':%c\n", diaga_c); + continue; + } #else - + fprintf(fout, "m\t n\t cs_a\t cs_b\t gflops\n"); printf("~~~~~~~_BLAS\t m\t n\t cs_a\t cs_b\t gflops\n"); - while(fscanf(fin, "%ld %ld %ld %ld\n", &m, &n, &cs_a, &cs_b) == 4) + while (fscanf(fin, "%ld %ld %ld %ld\n", &m, &n, &cs_a, &cs_b) == 4) { - - side = BLIS_LEFT; - //side = BLIS_RIGHT; - uploa = BLIS_LOWER; - //uploa = BLIS_UPPER; + side = BLIS_LEFT; + // side = BLIS_RIGHT; - transa = BLIS_NO_TRANSPOSE; + uploa = BLIS_LOWER; + // uploa = BLIS_UPPER; - diaga = BLIS_NONUNIT_DIAG; + transa = BLIS_NO_TRANSPOSE; + diaga = BLIS_NONUNIT_DIAG; #endif - bli_param_map_blis_to_netlib_side( side, &f77_side ); - bli_param_map_blis_to_netlib_uplo( uploa, &f77_uploa ); - bli_param_map_blis_to_netlib_trans( transa, &f77_transa ); - bli_param_map_blis_to_netlib_diag( diaga, &f77_diaga ); + bli_param_map_blis_to_netlib_side(side, &f77_side); + bli_param_map_blis_to_netlib_uplo(uploa, &f77_uploa); + bli_param_map_blis_to_netlib_trans(transa, &f77_transa); + bli_param_map_blis_to_netlib_diag(diaga, &f77_diaga); + siz_t elem_size = bli_dt_size(dt); - siz_t elem_size = bli_dt_size( dt ); + cs_a = bli_align_dim_to_size(cs_a, elem_size, BLIS_HEAP_STRIDE_ALIGN_SIZE); + cs_b = bli_align_dim_to_size(cs_b, elem_size, BLIS_HEAP_STRIDE_ALIGN_SIZE); - cs_a = bli_align_dim_to_size( cs_a, elem_size, BLIS_HEAP_STRIDE_ALIGN_SIZE ); - cs_b = bli_align_dim_to_size( cs_b, elem_size, BLIS_HEAP_STRIDE_ALIGN_SIZE ); + // Will verify the leading dimension is powers of 2 and add 64bytes. + inc_t n_bytes = cs_a * sizeof(dt); - //Will verify the leading dimension is powers of 2 and add 64bytes. - inc_t n_bytes = cs_a*sizeof(dt); + if ((n_bytes != 0) && !(n_bytes & (n_bytes - 1))) // check whether n_bytes is power of 2. + cs_a += CACHE_LINE_SIZE / sizeof(dt); - if((n_bytes!=0) && !(n_bytes&(n_bytes-1)))// check whether n_bytes is power of 2. - cs_a += CACHE_LINE_SIZE/sizeof(dt); + n_bytes = cs_b * sizeof(dt); + if ((n_bytes != 0) && !(n_bytes & (n_bytes - 1))) // check whether n_bytes is power of 2. + cs_b += CACHE_LINE_SIZE / sizeof(dt); - n_bytes = cs_b*sizeof(dt); - if((n_bytes!=0) && !(n_bytes&(n_bytes-1)))// check whether n_bytes is power of 2. 
- cs_b += CACHE_LINE_SIZE/sizeof(dt); + if (bli_is_left(side) && ((m > cs_a) || (m > cs_b))) + continue; // leading dimension should be greater than number of rows + if (bli_is_right(side) && ((n > cs_a) || (m > cs_b))) + continue; // leading dimension should be greater than number of rows - if(bli_is_left(side) && ((m > cs_a) || (m > cs_b))) continue; //leading dimension should be greater than number of rows - - if(bli_is_right(side) && ((n > cs_a) || (m > cs_b))) continue; //leading dimension should be greater than number of rows - - if ( bli_is_left( side ) ) - bli_obj_create( dt, m, m, 1, m, &a ); + if (bli_is_left(side)) + bli_obj_create(dt, m, m, 1, m, &a); else - bli_obj_create( dt, n, n, 1, n, &a ); - bli_obj_create( dt, m, n, 1, m, &c ); - bli_obj_create( dt, m, n, 1, m, &c_save ); + bli_obj_create(dt, n, n, 1, n, &a); + bli_obj_create(dt, m, n, 1, m, &c); + bli_obj_create(dt, m, n, 1, m, &c_save); #else - for ( p = p_end; p >= p_begin; p -= p_inc ) + for (p = p_end; p >= p_begin; p -= p_inc) { - if ( m_input < 0 ) m = p * ( dim_t )abs(m_input); - else m = ( dim_t ) m_input; - if ( n_input < 0 ) n = p * ( dim_t )abs(n_input); - else n = ( dim_t ) n_input; + if (m_input < 0) + m = p * (dim_t)abs(m_input); + else + m = (dim_t)m_input; + if (n_input < 0) + n = p * (dim_t)abs(n_input); + else + n = (dim_t)n_input; - - side = BLIS_LEFT; - //side = BLIS_RIGHT; + side = BLIS_LEFT; + // side = BLIS_RIGHT; - uploa = BLIS_LOWER; - //uploa = BLIS_UPPER; + uploa = BLIS_LOWER; + // uploa = BLIS_UPPER; - transa = BLIS_NO_TRANSPOSE; + transa = BLIS_NO_TRANSPOSE; - diaga = BLIS_NONUNIT_DIAG; + diaga = BLIS_NONUNIT_DIAG; - bli_param_map_blis_to_netlib_side( side, &f77_side ); - bli_param_map_blis_to_netlib_uplo( uploa, &f77_uploa ); - bli_param_map_blis_to_netlib_trans( transa, &f77_transa ); - bli_param_map_blis_to_netlib_diag( diaga, &f77_diaga ); + bli_param_map_blis_to_netlib_side(side, &f77_side); + bli_param_map_blis_to_netlib_uplo(uploa, &f77_uploa); + bli_param_map_blis_to_netlib_trans(transa, &f77_transa); + bli_param_map_blis_to_netlib_diag(diaga, &f77_diaga); - if ( bli_is_left( side ) ) - bli_obj_create( dt, m, m, 0, 0, &a ); + if (bli_is_left(side)) + bli_obj_create(dt, m, m, 0, 0, &a); else - bli_obj_create( dt, n, n, 0, 0, &a ); - bli_obj_create( dt, m, n, 0, 0, &c ); - bli_obj_create( dt, m, n, 0, 0, &c_save ); + bli_obj_create(dt, n, n, 0, 0, &a); + bli_obj_create(dt, m, n, 0, 0, &c); + bli_obj_create(dt, m, n, 0, 0, &c_save); #endif - bli_randm( &a ); - bli_randm( &c ); + bli_randm(&a); + bli_randm(&c); - bli_obj_set_struc( BLIS_TRIANGULAR, &a ); - bli_obj_set_uplo( uploa, &a ); - bli_obj_set_conjtrans( transa, &a ); - bli_obj_set_diag( diaga, &a ); + bli_obj_set_struc(BLIS_TRIANGULAR, &a); + bli_obj_set_uplo(uploa, &a); + bli_obj_set_conjtrans(transa, &a); + bli_obj_set_diag(diaga, &a); // Randomize A and zero the unstored triangle to ensure the // implementation reads only from the stored region. - bli_randm( &a ); - bli_mktrim( &a ); + bli_randm(&a); + bli_mktrim(&a); // Load the diagonal of A to make it more likely to be invertible. 
- bli_shiftd( &BLIS_TWO, &a ); + bli_shiftd(&BLIS_TWO, &a); - bli_obj_create( dt, 1, 1, 0, 0, &alpha ); - bli_setsc( (2.0/1.0), 1.0, &alpha ); + bli_obj_create(dt, 1, 1, 0, 0, &alpha); + bli_setsc((2.0 / 1.0), 1.0, &alpha); + bli_copym(&c, &c_save); - bli_copym( &c, &c_save ); - dtime_save = DBL_MAX; - for ( r = 0; r < n_repeats; ++r ) + for (r = 0; r < n_repeats; ++r) { - bli_copym( &c_save, &c ); - + bli_copym(&c_save, &c); dtime = bli_clock(); - #ifdef PRINT - bli_invertd( &a ); - bli_printm( "a", &a, "%4.1f", "" ); - bli_invertd( &a ); - bli_printm( "c", &c, "%4.1f", "" ); + bli_invertd(&a); + bli_printm("a", &a, "%4.1f", ""); + bli_invertd(&a); + bli_printm("c", &c, "%4.1f", ""); #endif #ifdef BLIS - bli_trsm( side, - &alpha, - &a, - &c ); + bli_trsm(side, + &alpha, + &a, + &c); #else #ifdef CBLAS - enum CBLAS_ORDER cblas_order; - enum CBLAS_TRANSPOSE cblas_transa; - enum CBLAS_UPLO cblas_uplo; - enum CBLAS_SIDE cblas_side; - enum CBLAS_DIAG cblas_diag; - - if ( bli_obj_row_stride( &c ) == 1 ) - cblas_order = CblasColMajor; - else - cblas_order = CblasRowMajor; - - if( bli_is_trans( transa ) ) - cblas_transa = CblasTrans; - else if( bli_is_conjtrans( transa ) ) - cblas_transa = CblasConjTrans; - else - cblas_transa = CblasNoTrans; - - if(bli_is_upper(uploa)) - cblas_uplo = CblasUpper; - else - cblas_uplo = CblasLower; - - if(bli_is_left(side)) - cblas_side = CblasLeft; - else - cblas_side = CblasRight; - - if(bli_is_unit_diag(diaga)) - cblas_diag = CblasUnit; - else - cblas_diag = CblasNonUnit; + enum CBLAS_ORDER cblas_order; + enum CBLAS_TRANSPOSE cblas_transa; + enum CBLAS_UPLO cblas_uplo; + enum CBLAS_SIDE cblas_side; + enum CBLAS_DIAG cblas_diag; + + if (bli_obj_row_stride(&c) == 1) + cblas_order = CblasColMajor; + else + cblas_order = CblasRowMajor; + + if (bli_is_trans(transa)) + cblas_transa = CblasTrans; + else if (bli_is_conjtrans(transa)) + cblas_transa = CblasConjTrans; + else + cblas_transa = CblasNoTrans; + + if (bli_is_upper(uploa)) + cblas_uplo = CblasUpper; + else + cblas_uplo = CblasLower; + + if (bli_is_left(side)) + cblas_side = CblasLeft; + else + cblas_side = CblasRight; + + if (bli_is_unit_diag(diaga)) + cblas_diag = CblasUnit; + else + cblas_diag = CblasNonUnit; #else - f77_char f77_transa; - bli_param_map_blis_to_netlib_trans( transa, &f77_transa ); + f77_char f77_transa; + bli_param_map_blis_to_netlib_trans(transa, &f77_transa); #endif - if ( bli_is_float( dt ) ) - { - f77_int mm = bli_obj_length( &c ); - f77_int nn = bli_obj_width( &c ); - f77_int lda = bli_obj_col_stride( &a ); - f77_int ldc = bli_obj_col_stride( &c ); + if (bli_is_float(dt)) + { + f77_int mm = bli_obj_length(&c); + f77_int nn = bli_obj_width(&c); + f77_int lda = bli_obj_col_stride(&a); + f77_int ldc = bli_obj_col_stride(&c); - float* alphap = bli_obj_buffer( &alpha ); - float* ap = bli_obj_buffer( &a ); - float* cp = bli_obj_buffer( &c ); + float *alphap = bli_obj_buffer(&alpha); + float *ap = bli_obj_buffer(&a); + float *cp = bli_obj_buffer(&c); #ifdef CBLAS - cblas_strsm( cblas_order, - cblas_side, - cblas_uplo, - cblas_transa, - cblas_diag, - mm, - nn, - *alphap, - ap, lda, - cp, ldc - ); + cblas_strsm(cblas_order, + cblas_side, + cblas_uplo, + cblas_transa, + cblas_diag, + mm, + nn, + *alphap, + ap, lda, + cp, ldc); #else - strsm_( &f77_side, - &f77_uploa, - &f77_transa, - &f77_diaga, - &mm, - &nn, - alphap, - ap, &lda, - cp, &ldc ); + strsm_(&f77_side, + &f77_uploa, + &f77_transa, + &f77_diaga, + &mm, + &nn, + alphap, + ap, &lda, + cp, &ldc); #endif - } - else if ( bli_is_double( dt ) 
) - { - f77_int mm = bli_obj_length( &c ); - f77_int nn = bli_obj_width( &c ); - f77_int lda = bli_obj_col_stride( &a ); - f77_int ldc = bli_obj_col_stride( &c ); - double* alphap = bli_obj_buffer( &alpha ); - double* ap = bli_obj_buffer( &a ); - double* cp = bli_obj_buffer( &c ); + } + else if (bli_is_double(dt)) + { + f77_int mm = bli_obj_length(&c); + f77_int nn = bli_obj_width(&c); + f77_int lda = bli_obj_col_stride(&a); + f77_int ldc = bli_obj_col_stride(&c); + double *alphap = bli_obj_buffer(&alpha); + double *ap = bli_obj_buffer(&a); + double *cp = bli_obj_buffer(&c); #ifdef CBLAS - cblas_dtrsm( cblas_order, - cblas_side, - cblas_uplo, - cblas_transa, - cblas_diag, - mm, - nn, - *alphap, - ap, lda, - cp, ldc - ); -#else - dtrsm_( &f77_side, - &f77_uploa, - &f77_transa, - &f77_diaga, - &mm, - &nn, - alphap, - ap, &lda, - cp, &ldc ); + cblas_dtrsm(cblas_order, + cblas_side, + cblas_uplo, + cblas_transa, + cblas_diag, + mm, + nn, + *alphap, + ap, lda, + cp, ldc); +#else + dtrsm_(&f77_side, + &f77_uploa, + &f77_transa, + &f77_diaga, + &mm, + &nn, + alphap, + ap, &lda, + cp, &ldc); #endif - - } - else if ( bli_is_scomplex( dt ) ) - { - f77_int mm = bli_obj_length( &c ); - f77_int nn = bli_obj_width( &c ); - f77_int lda = bli_obj_col_stride( &a ); - f77_int ldc = bli_obj_col_stride( &c ); - scomplex* alphap = bli_obj_buffer( &alpha ); - scomplex* ap = bli_obj_buffer( &a ); - scomplex* cp = bli_obj_buffer( &c ); + } + else if (bli_is_scomplex(dt)) + { + f77_int mm = bli_obj_length(&c); + f77_int nn = bli_obj_width(&c); + f77_int lda = bli_obj_col_stride(&a); + f77_int ldc = bli_obj_col_stride(&c); + scomplex *alphap = bli_obj_buffer(&alpha); + scomplex *ap = bli_obj_buffer(&a); + scomplex *cp = bli_obj_buffer(&c); #ifdef CBLAS - cblas_ctrsm( cblas_order, - cblas_side, - cblas_uplo, - cblas_transa, - cblas_diag, - mm, - nn, - alphap, - ap, lda, - cp, ldc - ); + cblas_ctrsm(cblas_order, + cblas_side, + cblas_uplo, + cblas_transa, + cblas_diag, + mm, + nn, + alphap, + ap, lda, + cp, ldc); #else - ctrsm_( &f77_side, - &f77_uploa, - &f77_transa, - &f77_diaga, - &mm, - &nn, - alphap, - ap, &lda, - cp, &ldc ); + ctrsm_(&f77_side, + &f77_uploa, + &f77_transa, + &f77_diaga, + &mm, + &nn, + alphap, + ap, &lda, + cp, &ldc); #endif - } - else if ( bli_is_dcomplex( dt ) ) - { - f77_int mm = bli_obj_length( &c ); - f77_int nn = bli_obj_width( &c ); - f77_int lda = bli_obj_col_stride( &a ); - f77_int ldc = bli_obj_col_stride( &c ); - dcomplex* alphap = bli_obj_buffer( &alpha ); - dcomplex* ap = bli_obj_buffer( &a ); - dcomplex* cp = bli_obj_buffer( &c ); + } + else if (bli_is_dcomplex(dt)) + { + f77_int mm = bli_obj_length(&c); + f77_int nn = bli_obj_width(&c); + f77_int lda = bli_obj_col_stride(&a); + f77_int ldc = bli_obj_col_stride(&c); + dcomplex *alphap = bli_obj_buffer(&alpha); + dcomplex *ap = bli_obj_buffer(&a); + dcomplex *cp = bli_obj_buffer(&c); #ifdef CBLAS - cblas_ztrsm( cblas_order, - cblas_side, - cblas_uplo, - cblas_transa, - cblas_diag, - mm, - nn, - alphap, - ap, lda, - cp, ldc - ); + cblas_ztrsm(cblas_order, + cblas_side, + cblas_uplo, + cblas_transa, + cblas_diag, + mm, + nn, + alphap, + ap, lda, + cp, ldc); #else - ztrsm_( &f77_side, - &f77_uploa, - &f77_transa, - &f77_diaga, - &mm, - &nn, - alphap, - ap, &lda, - cp, &ldc ); + ztrsm_(&f77_side, + &f77_uploa, + &f77_transa, + &f77_diaga, + &mm, + &nn, + alphap, + ap, &lda, + cp, &ldc); #endif - }else{ - printf("Invalid data type! Exiting!\n"); - exit(1); - } + } + else + { + printf("Invalid data type! 
Exiting!\n"); + exit(1); + } #endif - dtime_save = bli_clock_min_diff( dtime_save, dtime ); + dtime_save = bli_clock_min_diff(dtime_save, dtime); } - if ( bli_is_left( side ) ) - gflops = ( 1.0 * m * m * n ) / ( dtime_save * 1.0e9 ); + if (bli_is_left(side)) + gflops = (1.0 * m * m * n) / (dtime_save * 1.0e9); else - gflops = ( 1.0 * m * n * n ) / ( dtime_save * 1.0e9 ); + gflops = (1.0 * m * n * n) / (dtime_save * 1.0e9); - if ( bli_is_complex( dt ) ) gflops *= 4.0; + if (bli_is_complex(dt)) + gflops *= 4.0; #ifdef BLIS - printf( "data_trsm_blis" ); + printf("data_trsm_blis"); #else - printf( "data_trsm_%s", BLAS ); + printf("data_trsm_%s", BLAS); #endif #ifdef FILE_IN_OUT #ifdef READ_ALL_PARAMS_FROM_FILE - printf("%c\t %c\t %c\t %c\t %4lu\t %4lu\t %4lu\t %4lu\t %6.3f\n",side_c, uploa_c, transa_c, diaga_c, - (unsigned long )m, (unsigned long ) n, - (unsigned long )cs_a, (unsigned long )cs_b, - gflops); + printf("%c\t %c\t %c\t %c\t %4lu\t %4lu\t %4lu\t %4lu\t %6.3f\n", side_c, uploa_c, transa_c, diaga_c, + (unsigned long)m, (unsigned long)n, + (unsigned long)cs_a, (unsigned long)cs_b, + gflops); - fprintf(fout,"%c\t %c\t %c\t %c\t %4lu\t %4lu\t %4lu\t %4lu\t %6.3f\n", side_c, uploa_c, transa_c, diaga_c, - (unsigned long )m, (unsigned long ) n, - (unsigned long )cs_a, (unsigned long )cs_b, - gflops); + fprintf(fout, "%c\t %c\t %c\t %c\t %4lu\t %4lu\t %4lu\t %4lu\t %6.3f\n", side_c, uploa_c, transa_c, diaga_c, + (unsigned long)m, (unsigned long)n, + (unsigned long)cs_a, (unsigned long)cs_b, + gflops); #else - printf("%4lu\t %4lu\t %4lu\t %4lu\t %6.3f\n", (unsigned long )m, (unsigned long ) n, - (unsigned long )cs_a, (unsigned long )cs_b, - gflops); - fprintf(fout,"%4lu\t %4lu\t %4lu\t %4lu\t %6.3f\n", (unsigned long )m, (unsigned long ) n, - (unsigned long )cs_a, (unsigned long )cs_b, - gflops); + printf("%4lu\t %4lu\t %4lu\t %4lu\t %6.3f\n", (unsigned long)m, (unsigned long)n, + (unsigned long)cs_a, (unsigned long)cs_b, + gflops); + fprintf(fout, "%4lu\t %4lu\t %4lu\t %4lu\t %6.3f\n", (unsigned long)m, (unsigned long)n, + (unsigned long)cs_a, (unsigned long)cs_b, + gflops); #endif -fflush(fout); + fflush(fout); #else - printf( "( %2lu, 1:3 ) = [ %4lu %4lu %7.2f ];\n", - ( unsigned long )(p - p_begin)/p_inc + 1, - ( unsigned long )m, - ( unsigned long )n, gflops ); + printf("( %2lu, 1:3 ) = [ %4lu %4lu %7.2f ];\n", + (unsigned long)(p - p_begin) / p_inc + 1, + (unsigned long)m, + (unsigned long)n, gflops); #endif - bli_obj_free( &alpha ); + bli_obj_free(&alpha); - bli_obj_free( &a ); - bli_obj_free( &c ); - bli_obj_free( &c_save ); + bli_obj_free(&a); + bli_obj_free(&c); + bli_obj_free(&c_save); } #ifdef FILE_IN_OUT - fclose(fin); - fclose(fout); + fclose(fin); + fclose(fout); #endif - //bli_finalize(); + // bli_finalize(); return 0; } - From f17d043e1c47ae825c0e89e0c7b6fc56a5b2a7d1 Mon Sep 17 00:00:00 2001 From: Harsh Dave Date: Fri, 11 Mar 2022 00:12:52 -0600 Subject: [PATCH 103/243] Implemented optimal dotxv kernel Details: - Intrinsic implementation of zdotxv, cdotxv kernel - Unrolling in multiple of 8, remaining corner cases are handled serially for zdotxv kernel - Unrolling in multiple of 16, remainig corner cases are handled serially for cdotxv kernel - Added declaration in zen contexts AMD-Internal: [CPUPL-2050] Change-Id: Id58b0dbfdb7a782eb50eecc7142f051b630d9211 --- config/zen/bli_cntx_init_zen.c | 4 +- config/zen2/bli_cntx_init_zen2.c | 4 +- config/zen3/bli_cntx_init_zen3.c | 4 +- kernels/zen/1/bli_dotxv_zen_int.c | 499 ++++++++++++++++++++++++++++++ kernels/zen/bli_kernels_zen.h 
| 4 +- 5 files changed, 511 insertions(+), 4 deletions(-) diff --git a/config/zen/bli_cntx_init_zen.c b/config/zen/bli_cntx_init_zen.c index 674549d77f..3fea3ea8f9 100644 --- a/config/zen/bli_cntx_init_zen.c +++ b/config/zen/bli_cntx_init_zen.c @@ -103,7 +103,7 @@ void bli_cntx_init_zen( cntx_t* cntx ) // Update the context with optimized level-1v kernels. bli_cntx_set_l1v_kers ( - 24, + 26, #if 1 // amaxv BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int, @@ -135,6 +135,8 @@ void bli_cntx_init_zen( cntx_t* cntx ) // dotxv BLIS_DOTXV_KER, BLIS_FLOAT, bli_sdotxv_zen_int, BLIS_DOTXV_KER, BLIS_DOUBLE, bli_ddotxv_zen_int, + BLIS_DOTXV_KER, BLIS_DCOMPLEX, bli_zdotxv_zen_int, + BLIS_DOTXV_KER, BLIS_SCOMPLEX, bli_cdotxv_zen_int, // scalv #if 0 BLIS_SCALV_KER, BLIS_FLOAT, bli_sscalv_zen_int, diff --git a/config/zen2/bli_cntx_init_zen2.c b/config/zen2/bli_cntx_init_zen2.c index 48cb90a4f8..1ecb62ff52 100644 --- a/config/zen2/bli_cntx_init_zen2.c +++ b/config/zen2/bli_cntx_init_zen2.c @@ -115,7 +115,7 @@ void bli_cntx_init_zen2( cntx_t* cntx ) // Update the context with optimized level-1v kernels. bli_cntx_set_l1v_kers ( - 24, + 26, #if 1 // amaxv BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int, @@ -142,6 +142,8 @@ void bli_cntx_init_zen2( cntx_t* cntx ) // dotxv BLIS_DOTXV_KER, BLIS_FLOAT, bli_sdotxv_zen_int, BLIS_DOTXV_KER, BLIS_DOUBLE, bli_ddotxv_zen_int, + BLIS_DOTXV_KER, BLIS_DCOMPLEX, bli_zdotxv_zen_int, + BLIS_DOTXV_KER, BLIS_SCOMPLEX, bli_cdotxv_zen_int, // scalv BLIS_SCALV_KER, BLIS_FLOAT, bli_sscalv_zen_int10, diff --git a/config/zen3/bli_cntx_init_zen3.c b/config/zen3/bli_cntx_init_zen3.c index e83a12b401..02e264d277 100644 --- a/config/zen3/bli_cntx_init_zen3.c +++ b/config/zen3/bli_cntx_init_zen3.c @@ -115,7 +115,7 @@ void bli_cntx_init_zen3( cntx_t* cntx ) // Update the context with optimized level-1v kernels. bli_cntx_set_l1v_kers ( - 24, + 26, #if 1 // amaxv BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int, @@ -142,6 +142,8 @@ void bli_cntx_init_zen3( cntx_t* cntx ) // dotxv BLIS_DOTXV_KER, BLIS_FLOAT, bli_sdotxv_zen_int, BLIS_DOTXV_KER, BLIS_DOUBLE, bli_ddotxv_zen_int, + BLIS_DOTXV_KER, BLIS_DCOMPLEX, bli_zdotxv_zen_int, + BLIS_DOTXV_KER, BLIS_SCOMPLEX, bli_cdotxv_zen_int, // scalv BLIS_SCALV_KER, BLIS_FLOAT, bli_sscalv_zen_int10, diff --git a/kernels/zen/1/bli_dotxv_zen_int.c b/kernels/zen/1/bli_dotxv_zen_int.c index 8ba1d1bba4..c210eceff5 100644 --- a/kernels/zen/1/bli_dotxv_zen_int.c +++ b/kernels/zen/1/bli_dotxv_zen_int.c @@ -332,3 +332,502 @@ void bli_ddotxv_zen_int PASTEMAC(d,axpys)( *alpha, rho0, *rho ); } + + +void bli_zdotxv_zen_int + ( + conj_t conjx, + conj_t conjy, + dim_t n, + dcomplex* restrict alpha, + dcomplex* restrict x, inc_t incx, + dcomplex* restrict y, inc_t incy, + dcomplex* restrict beta, + dcomplex* restrict rho, + cntx_t* restrict cntx + ) +{ + const dim_t n_elem_per_reg = 2; + const dim_t n_iter_unroll = 4; + + dim_t i; + dim_t n_viter; + dim_t n_left; + + dcomplex* restrict x0; + dcomplex* restrict y0; + dcomplex rho0; + + v4df_t rhov[8], xv[4], yv[8]; + + conj_t conjx_use = conjx; + if ( bli_is_conj( conjy ) ) + { + bli_toggle_conj( &conjx_use ); + } + // If beta is zero, initialize rho1 to zero instead of scaling + // rho by beta (in case rho contains NaN or Inf). + if ( PASTEMAC(z,eq0)( *beta ) ) + { + PASTEMAC(z,set0s)( *rho ); + } + else + { + PASTEMAC(z,scals)( *beta, *rho ); + } + + // If the vector dimension is zero, output rho and return early. 
+ if ( bli_zero_dim1( n ) || PASTEMAC(z,eq0)( *alpha ) ) return; + + // Use the unrolling factor and the number of elements per register + // to compute the number of vectorized and leftover iterations. + n_viter = ( n ) / ( n_elem_per_reg * n_iter_unroll ); + n_left = ( n ) % ( n_elem_per_reg * n_iter_unroll ); + + // If there is anything that would interfere with our use of contiguous + // vector loads/stores, override n_viter and n_left to use scalar code + // for all iterations. + if ( incx != 1 || incy != 1 ) + { + n_viter = 0; + n_left = n; + } + + // Initialize local pointers. + x0 = x; + y0 = y; + + // Initialize the unrolled iterations' rho vectors to zero. + rhov[0].v = _mm256_setzero_pd(); + rhov[1].v = _mm256_setzero_pd(); + rhov[2].v = _mm256_setzero_pd(); + rhov[3].v = _mm256_setzero_pd(); + + rhov[4].v = _mm256_setzero_pd(); + rhov[5].v = _mm256_setzero_pd(); + rhov[6].v = _mm256_setzero_pd(); + rhov[7].v = _mm256_setzero_pd(); + + if ( bli_is_conj( conjx_use ) ) + { + __m256d conju = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for ( i = 0; i < n_viter; ++i ) + { + // Load the x and y input vector elements. + xv[0].v = _mm256_loadu_pd((double *) (x0 + 0*n_elem_per_reg) ); + yv[0].v = _mm256_loadu_pd((double *) (y0 + 0*n_elem_per_reg) ); + + xv[1].v = _mm256_loadu_pd((double *) (x0 + 1*n_elem_per_reg) ); + yv[1].v = _mm256_loadu_pd((double *) (y0 + 1*n_elem_per_reg) ); + + xv[2].v = _mm256_loadu_pd((double *) (x0 + 2*n_elem_per_reg) ); + yv[2].v = _mm256_loadu_pd((double *) (y0 + 2*n_elem_per_reg) ); + + xv[3].v = _mm256_loadu_pd((double *) (x0 + 3*n_elem_per_reg) ); + yv[3].v = _mm256_loadu_pd((double *) (y0 + 3*n_elem_per_reg) ); + + yv[0].v = _mm256_mul_pd(yv[0].v, conju); + yv[1].v = _mm256_mul_pd(yv[1].v, conju); + yv[2].v = _mm256_mul_pd(yv[2].v, conju); + yv[3].v = _mm256_mul_pd(yv[3].v, conju); + //yi0 yi0 yi1 yi1 + //xr0 xi0 xr1 xi1 + //after permute of vector registers + //yi0*xr0 yi0*xi0 yi1*xr1 yi1*xi1 + yv[4].v = _mm256_permute_pd( yv[0].v, 15 ); + yv[5].v = _mm256_permute_pd( yv[1].v, 15 ); + yv[6].v = _mm256_permute_pd( yv[2].v, 15 ); + yv[7].v = _mm256_permute_pd( yv[3].v, 15 ); + + //yr0 yr0 yr1 yr1 + //xr0 xi0 xr1 xi1 + //after permute of vector registers + //yr0*xr0 yr0*xi0 yr1*xr1 yr1*xi1 + yv[0].v = _mm256_permute_pd( yv[0].v, 0 ); + yv[1].v = _mm256_permute_pd( yv[1].v, 0 ); + yv[2].v = _mm256_permute_pd( yv[2].v, 0 ); + yv[3].v = _mm256_permute_pd( yv[3].v, 0 ); + + // Compute the element-wise product of the x and y vectors, + // storing in the corresponding rho vectors. + rhov[0].v = _mm256_fmadd_pd( xv[0].v, yv[0].v, rhov[0].v ); + rhov[1].v = _mm256_fmadd_pd( xv[1].v, yv[1].v, rhov[1].v ); + rhov[2].v = _mm256_fmadd_pd( xv[2].v, yv[2].v, rhov[2].v ); + rhov[3].v = _mm256_fmadd_pd( xv[3].v, yv[3].v, rhov[3].v ); + + rhov[4].v = _mm256_fmadd_pd( xv[0].v, yv[4].v, rhov[4].v ); + rhov[5].v = _mm256_fmadd_pd( xv[1].v, yv[5].v, rhov[5].v ); + rhov[6].v = _mm256_fmadd_pd( xv[2].v, yv[6].v, rhov[6].v ); + rhov[7].v = _mm256_fmadd_pd( xv[3].v, yv[7].v, rhov[7].v ); + + x0 += ( n_elem_per_reg * n_iter_unroll ); + y0 += ( n_elem_per_reg * n_iter_unroll ); + } + } + else + { + for ( i = 0; i < n_viter; ++i ) + { + // Load the x and y input vector elements. 
+ xv[0].v = _mm256_loadu_pd((double *) (x0 + 0*n_elem_per_reg) ); + yv[0].v = _mm256_loadu_pd((double *) (y0 + 0*n_elem_per_reg) ); + + xv[1].v = _mm256_loadu_pd((double *) (x0 + 1*n_elem_per_reg) ); + yv[1].v = _mm256_loadu_pd((double *) (y0 + 1*n_elem_per_reg) ); + + xv[2].v = _mm256_loadu_pd((double *) (x0 + 2*n_elem_per_reg) ); + yv[2].v = _mm256_loadu_pd((double *) (y0 + 2*n_elem_per_reg) ); + + xv[3].v = _mm256_loadu_pd((double *) (x0 + 3*n_elem_per_reg) ); + yv[3].v = _mm256_loadu_pd((double *) (y0 + 3*n_elem_per_reg) ); + + //yi0 yi0 yi1 yi1 + //xr0 xi0 xr1 xi1 + //--------------- + //yi0*xr0 yi0*xi0 yi1*xr1 yi1*xi1 + yv[4].v = _mm256_permute_pd( yv[0].v, 15 ); + yv[5].v = _mm256_permute_pd( yv[1].v, 15 ); + yv[6].v = _mm256_permute_pd( yv[2].v, 15 ); + yv[7].v = _mm256_permute_pd( yv[3].v, 15 ); + + //yr0 yr0 yr1 yr1 + //xr0 xi0 xr1 xi1 + //---------------- + //yr0*xr0 yr0*xi0 yr1*xr1 yr1*xi1 + yv[0].v = _mm256_permute_pd( yv[0].v, 0 ); + yv[1].v = _mm256_permute_pd( yv[1].v, 0 ); + yv[2].v = _mm256_permute_pd( yv[2].v, 0 ); + yv[3].v = _mm256_permute_pd( yv[3].v, 0 ); + + // Compute the element-wise product of the x and y vectors, + // storing in the corresponding rho vectors. + rhov[0].v = _mm256_fmadd_pd( xv[0].v, yv[0].v, rhov[0].v ); + rhov[1].v = _mm256_fmadd_pd( xv[1].v, yv[1].v, rhov[1].v ); + rhov[2].v = _mm256_fmadd_pd( xv[2].v, yv[2].v, rhov[2].v ); + rhov[3].v = _mm256_fmadd_pd( xv[3].v, yv[3].v, rhov[3].v ); + + rhov[4].v = _mm256_fmadd_pd( xv[0].v, yv[4].v, rhov[4].v ); + rhov[5].v = _mm256_fmadd_pd( xv[1].v, yv[5].v, rhov[5].v ); + rhov[6].v = _mm256_fmadd_pd( xv[2].v, yv[6].v, rhov[6].v ); + rhov[7].v = _mm256_fmadd_pd( xv[3].v, yv[7].v, rhov[7].v ); + + x0 += ( n_elem_per_reg * n_iter_unroll ); + y0 += ( n_elem_per_reg * n_iter_unroll ); + } + } + + //yr0*xr0 yr0*xi0 yr1*xr1 yr1*xi1 + // - + - + + //yi0*xi0 yi0*xr0 yi1*xi1 yi1*xr1 + rhov[4].v = _mm256_permute_pd(rhov[4].v, 0x05); + rhov[5].v = _mm256_permute_pd(rhov[5].v, 0x05); + rhov[6].v = _mm256_permute_pd(rhov[6].v, 0x05); + rhov[7].v = _mm256_permute_pd(rhov[7].v, 0x05); + + rhov[0].v = _mm256_addsub_pd(rhov[0].v, rhov[4].v); + rhov[1].v = _mm256_addsub_pd(rhov[1].v, rhov[5].v); + rhov[2].v = _mm256_addsub_pd(rhov[2].v, rhov[6].v); + rhov[3].v = _mm256_addsub_pd(rhov[3].v, rhov[7].v); + + // Accumulate the unrolled rho vectors into a single vector. + rhov[0].v = _mm256_add_pd(rhov[1].v,rhov[0].v); + rhov[0].v = _mm256_add_pd(rhov[2].v,rhov[0].v); + rhov[0].v = _mm256_add_pd(rhov[3].v,rhov[0].v); + + v2df_t inter1, inter2; + + inter1.v = _mm256_extractf128_pd(rhov[0].v,1); + inter2.v = _mm256_extractf128_pd(rhov[0].v,0); + + inter1.v = _mm_add_pd(inter1.v, inter2.v); + + // Accumulate the final rho vector into a single scalar result. + rho0.real = inter1.d[0]; + rho0.imag = inter1.d[1]; + + /* Negate sign of imaginary value when vector y is conjugate */ + if ( bli_is_conj(conjx_use)) + rho0.imag = -rho0.imag; + + // Issue vzeroupper instruction to clear upper lanes of ymm registers. + // This avoids a performance penalty caused by false dependencies when + // transitioning from from AVX to SSE instructions (which may occur + // as soon as the n_left cleanup loop below if BLIS is compiled with + // -mfpmath=sse). + _mm256_zeroupper(); + + // If there are leftover iterations, perform them with scalar code. 
+ if ( bli_is_conj( conjx_use ) ) + { + for ( i = 0; i < n_left; ++i ) + { + PASTEMAC(z,dotjs)( *x0, *y0, rho0 ); + x0 += incx; + y0 += incy; + } + } + else + { + for ( i = 0; i < n_left; ++i ) + { + PASTEMAC(z,dots)( *x0, *y0, rho0 ); + x0 += incx; + y0 += incy; + } + } + + if ( bli_is_conj( conjy ) ) + PASTEMAC(z,conjs)( rho0 ); + + // Accumulate the final result into the output variable. + PASTEMAC(z,axpys)( *alpha, rho0, *rho ); +} + +void bli_cdotxv_zen_int + ( + conj_t conjx, + conj_t conjy, + dim_t n, + scomplex* restrict alpha, + scomplex* restrict x, inc_t incx, + scomplex* restrict y, inc_t incy, + scomplex* restrict beta, + scomplex* restrict rho, + cntx_t* restrict cntx + ) +{ + const dim_t n_elem_per_reg = 4; + const dim_t n_iter_unroll = 4; + + dim_t i; + dim_t n_viter; + dim_t n_left; + + scomplex* restrict x0; + scomplex* restrict y0; + scomplex rho0; + + v8sf_t rhov[8], xv[4], yv[8]; + + conj_t conjx_use = conjx; + if ( bli_is_conj( conjy ) ) + { + bli_toggle_conj( &conjx_use ); + } + // If beta is zero, initialize rho1 to zero instead of scaling + // rho by beta (in case rho contains NaN or Inf). + if ( PASTEMAC(c,eq0)( *beta ) ) + { + PASTEMAC(c,set0s)( *rho ); + } + else + { + PASTEMAC(c,scals)( *beta, *rho ); + } + + // If the vector dimension is zero, output rho and return early. + if ( bli_zero_dim1( n ) || PASTEMAC(c,eq0)( *alpha ) ) return; + + // Use the unrolling factor and the number of elements per register + // to compute the number of vectorized and leftover iterations. + n_viter = ( n ) / ( n_elem_per_reg * n_iter_unroll ); + n_left = ( n ) % ( n_elem_per_reg * n_iter_unroll ); + + // If there is anything that would interfere with our use of contiguous + // vector loads/stores, override n_viter and n_left to use scalar code + // for all iterations. + if ( incx != 1 || incy != 1 ) + { + n_viter = 0; + n_left = n; + } + + // Initialize local pointers. + x0 = x; + y0 = y; + + // Initialize the unrolled iterations' rho vectors to zero. + rhov[0].v = _mm256_setzero_ps(); + rhov[1].v = _mm256_setzero_ps(); + rhov[2].v = _mm256_setzero_ps(); + rhov[3].v = _mm256_setzero_ps(); + + rhov[4].v = _mm256_setzero_ps(); + rhov[5].v = _mm256_setzero_ps(); + rhov[6].v = _mm256_setzero_ps(); + rhov[7].v = _mm256_setzero_ps(); + + if ( bli_is_conj( conjx_use ) ) + { + __m256 conju = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); + for ( i = 0; i < n_viter; ++i ) + { + // Load the x and y input vector elements. 
+ xv[0].v = _mm256_loadu_ps((float *) (x0 + 0*n_elem_per_reg) ); + yv[0].v = _mm256_loadu_ps((float *) (y0 + 0*n_elem_per_reg) ); + + xv[1].v = _mm256_loadu_ps((float *) (x0 + 1*n_elem_per_reg) ); + yv[1].v = _mm256_loadu_ps((float *) (y0 + 1*n_elem_per_reg) ); + + xv[2].v = _mm256_loadu_ps((float *) (x0 + 2*n_elem_per_reg) ); + yv[2].v = _mm256_loadu_ps((float *) (y0 + 2*n_elem_per_reg) ); + + xv[3].v = _mm256_loadu_ps((float *) (x0 + 3*n_elem_per_reg) ); + yv[3].v = _mm256_loadu_ps((float *) (y0 + 3*n_elem_per_reg) ); + + yv[0].v = _mm256_mul_ps(yv[0].v, conju); + yv[1].v = _mm256_mul_ps(yv[1].v, conju); + yv[2].v = _mm256_mul_ps(yv[2].v, conju); + yv[3].v = _mm256_mul_ps(yv[3].v, conju); + //yi0 yi0 yi1 yi1 + //xr0 xi0 xr1 xi1 + //after permute of vector registers + //yi0*xr0 yi0*xi0 yi1*xr1 yi1*xi1 + yv[4].v = _mm256_permute_ps( yv[0].v, 0xf5 ); + yv[5].v = _mm256_permute_ps( yv[1].v, 0xf5 ); + yv[6].v = _mm256_permute_ps( yv[2].v, 0xf5 ); + yv[7].v = _mm256_permute_ps( yv[3].v, 0xf5 ); + + //yr0 yr0 yr1 yr1 + //xr0 xi0 xr1 xi1 + //after permute of vector registers + //yr0*xr0 yr0*xi0 yr1*xr1 yr1*xi1 + yv[0].v = _mm256_permute_ps( yv[0].v, 0xa0 ); + yv[1].v = _mm256_permute_ps( yv[1].v, 0xa0 ); + yv[2].v = _mm256_permute_ps( yv[2].v, 0xa0 ); + yv[3].v = _mm256_permute_ps( yv[3].v, 0xa0 ); + + // Compute the element-wise product of the x and y vectors, + // storing in the corresponding rho vectors. + rhov[0].v = _mm256_fmadd_ps( xv[0].v, yv[0].v, rhov[0].v ); + rhov[1].v = _mm256_fmadd_ps( xv[1].v, yv[1].v, rhov[1].v ); + rhov[2].v = _mm256_fmadd_ps( xv[2].v, yv[2].v, rhov[2].v ); + rhov[3].v = _mm256_fmadd_ps( xv[3].v, yv[3].v, rhov[3].v ); + + rhov[4].v = _mm256_fmadd_ps( xv[0].v, yv[4].v, rhov[4].v ); + rhov[5].v = _mm256_fmadd_ps( xv[1].v, yv[5].v, rhov[5].v ); + rhov[6].v = _mm256_fmadd_ps( xv[2].v, yv[6].v, rhov[6].v ); + rhov[7].v = _mm256_fmadd_ps( xv[3].v, yv[7].v, rhov[7].v ); + + x0 += ( n_elem_per_reg * n_iter_unroll ); + y0 += ( n_elem_per_reg * n_iter_unroll ); + } + } + else + { + for ( i = 0; i < n_viter; ++i ) + { + // Load the x and y input vector elements. + xv[0].v = _mm256_loadu_ps((float *) (x0 + 0*n_elem_per_reg) ); + yv[0].v = _mm256_loadu_ps((float *) (y0 + 0*n_elem_per_reg) ); + + xv[1].v = _mm256_loadu_ps((float *) (x0 + 1*n_elem_per_reg) ); + yv[1].v = _mm256_loadu_ps((float *) (y0 + 1*n_elem_per_reg) ); + + xv[2].v = _mm256_loadu_ps((float *) (x0 + 2*n_elem_per_reg) ); + yv[2].v = _mm256_loadu_ps((float *) (y0 + 2*n_elem_per_reg) ); + + xv[3].v = _mm256_loadu_ps((float *) (x0 + 3*n_elem_per_reg) ); + yv[3].v = _mm256_loadu_ps((float *) (y0 + 3*n_elem_per_reg) ); + + //yi0 yi0 yi1 yi1 + //xr0 xi0 xr1 xi1 + //--------------- + //yi0*xr0 yi0*xi0 yi1*xr1 yi1*xi1 + yv[4].v = _mm256_permute_ps( yv[0].v, 0xf5 ); + yv[5].v = _mm256_permute_ps( yv[1].v, 0xf5 ); + yv[6].v = _mm256_permute_ps( yv[2].v, 0xf5 ); + yv[7].v = _mm256_permute_ps( yv[3].v, 0xf5 ); + + //yr0 yr0 yr1 yr1 + //xr0 xi0 xr1 xi1 + //---------------- + //yr0*xr0 yr0*xi0 yr1*xr1 yr1*xi1 + yv[0].v = _mm256_permute_ps( yv[0].v, 0xa0 ); + yv[1].v = _mm256_permute_ps( yv[1].v, 0xa0 ); + yv[2].v = _mm256_permute_ps( yv[2].v, 0xa0 ); + yv[3].v = _mm256_permute_ps( yv[3].v, 0xa0 ); + + // Compute the element-wise product of the x and y vectors, + // storing in the corresponding rho vectors. 
+ rhov[0].v = _mm256_fmadd_ps( xv[0].v, yv[0].v, rhov[0].v ); + rhov[1].v = _mm256_fmadd_ps( xv[1].v, yv[1].v, rhov[1].v ); + rhov[2].v = _mm256_fmadd_ps( xv[2].v, yv[2].v, rhov[2].v ); + rhov[3].v = _mm256_fmadd_ps( xv[3].v, yv[3].v, rhov[3].v ); + + rhov[4].v = _mm256_fmadd_ps( xv[0].v, yv[4].v, rhov[4].v ); + rhov[5].v = _mm256_fmadd_ps( xv[1].v, yv[5].v, rhov[5].v ); + rhov[6].v = _mm256_fmadd_ps( xv[2].v, yv[6].v, rhov[6].v ); + rhov[7].v = _mm256_fmadd_ps( xv[3].v, yv[7].v, rhov[7].v ); + + x0 += ( n_elem_per_reg * n_iter_unroll ); + y0 += ( n_elem_per_reg * n_iter_unroll ); + } + } + + //yr0*xr0 yr0*xi0 yr1*xr1 yr1*xi1 + // - + - + + //yi0*xi0 yi0*xr0 yi1*xi1 yi1*xr1 + rhov[4].v = _mm256_permute_ps(rhov[4].v, 0xb1); + rhov[5].v = _mm256_permute_ps(rhov[5].v, 0xb1); + rhov[6].v = _mm256_permute_ps(rhov[6].v, 0xb1); + rhov[7].v = _mm256_permute_ps(rhov[7].v, 0xb1); + + rhov[0].v = _mm256_addsub_ps(rhov[0].v, rhov[4].v); + rhov[1].v = _mm256_addsub_ps(rhov[1].v, rhov[5].v); + rhov[2].v = _mm256_addsub_ps(rhov[2].v, rhov[6].v); + rhov[3].v = _mm256_addsub_ps(rhov[3].v, rhov[7].v); + + // Accumulate the unrolled rho vectors into a single vector. + rhov[0].v = _mm256_add_ps(rhov[1].v,rhov[0].v); + rhov[0].v = _mm256_add_ps(rhov[2].v,rhov[0].v); + rhov[0].v = _mm256_add_ps(rhov[3].v,rhov[0].v); + + v4sf_t inter1, inter2; + + inter1.v = _mm256_extractf128_ps(rhov[0].v,1); + inter2.v = _mm256_extractf128_ps(rhov[0].v,0); + + inter1.v = _mm_add_ps(inter1.v, inter2.v); + + // Accumulate the final rho vector into a single scalar result. + rho0.real = inter1.f[0] + inter1.f[2]; + rho0.imag = inter1.f[1] + inter1.f[3]; + + /* Negate sign of imaginary value when vector y is conjugate */ + if ( bli_is_conj(conjx_use)) + rho0.imag = -rho0.imag; + + // Issue vzeroupper instruction to clear upper lanes of ymm registers. + // This avoids a performance penalty caused by false dependencies when + // transitioning from from AVX to SSE instructions (which may occur + // as soon as the n_left cleanup loop below if BLIS is compiled with + // -mfpmath=sse). + _mm256_zeroupper(); + + // If there are leftover iterations, perform them with scalar code. + if ( bli_is_conj( conjx_use ) ) + { + for ( i = 0; i < n_left; ++i ) + { + PASTEMAC(c,dotjs)( *x0, *y0, rho0 ); + x0 += incx; + y0 += incy; + } + } + else + { + for ( i = 0; i < n_left; ++i ) + { + PASTEMAC(c,dots)( *x0, *y0, rho0 ); + x0 += incx; + y0 += incy; + } + } + + if ( bli_is_conj( conjy ) ) + PASTEMAC(c,conjs)( rho0 ); + + // Accumulate the final result into the output variable. + PASTEMAC(c,axpys)( *alpha, rho0, *rho ); +} diff --git a/kernels/zen/bli_kernels_zen.h b/kernels/zen/bli_kernels_zen.h index da7afd00e8..ff97ca9ea2 100644 --- a/kernels/zen/bli_kernels_zen.h +++ b/kernels/zen/bli_kernels_zen.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 21, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. 
   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions are
@@ -90,6 +90,8 @@ DOTV_KER_PROT( dcomplex, z, dotv_zen_int5 )
 // dotxv (intrinsics)
 DOTXV_KER_PROT( float, s, dotxv_zen_int )
 DOTXV_KER_PROT( double, d, dotxv_zen_int )
+DOTXV_KER_PROT( dcomplex, z, dotxv_zen_int )
+DOTXV_KER_PROT( scomplex, c, dotxv_zen_int )
 
 // scalv (intrinsics)
 SCALV_KER_PROT( float, s, scalv_zen_int )

From 8e6da6b844e8c045fb0bd0ae8abde2cd58e9f74a Mon Sep 17 00:00:00 2001
From: Chandrashekara K R
Date: Wed, 13 Apr 2022 10:03:27 +0530
Subject: [PATCH 104/243] Added checks to avoid defining the bool type for C++
 code on Windows, preventing a redefinition build-time error.

AMD-Internal: [CPUPL-2037]
Change-Id: I065da9206ab06f60876324f258ee12fb9fe83f88
---
 frame/include/bli_type_defs.h | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/frame/include/bli_type_defs.h b/frame/include/bli_type_defs.h
index cb4e4e4b84..584c221ba0 100644
--- a/frame/include/bli_type_defs.h
+++ b/frame/include/bli_type_defs.h
@@ -89,10 +89,14 @@ typedef unsigned long int guint_t;
 
 // -- Boolean type --
 // NOTE: bool_t is no longer used and has been replaced with C99's bool type.
+// Not defining the bool type for C++ code in windows platform to avoid
+// duplicate definition build error.
 #ifdef _WIN32
+#ifndef __cplusplus
 #undef bool
 typedef gint_t bool;
 #endif
+#endif
 // BLIS uses TRUE and FALSE macro constants as possible boolean values, but we
 // define these macros in terms of true and false, respectively, which are
 // defined by C99 in stdbool.h.

From 1a3428ddfc106d1925bcd76f02de79d4e84babf3 Mon Sep 17 00:00:00 2001
From: satish kumar nuggu
Date: Fri, 8 Apr 2022 13:19:34 +0530
Subject: [PATCH 105/243] Parallelization of dtrsm_small routine

1. Parallelized dtrsm_small across the m-dimension or n-dimension, chosen by
   side (Left/Right); see the sketch below.
2. Fine-tuning with AOCL_DYNAMIC to achieve better performance.
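Illustrative sketch of the data parallelism described in item 1, shown for the
left-side, lower-triangular, non-transposed, non-unit-diagonal, column-major
case: each thread owns an independent block of columns of B and runs a plain
forward substitution on it. The function name and the simple ceiling split are
hypothetical simplifications; the actual routine added below, bli_trsm_small_mt,
partitions with bli_thread_range_sub/bli_acquire_mpart_* in multiples of the
register block size and calls the single-threaded bli_trsm_small per partition.

    #include <omp.h>

    void trsm_llnn_columns_mt( int m, int n,
                               const double* a, int lda,  /* A: m x m, lower, column-major */
                               double*       b, int ldb,  /* B: m x n, column-major        */
                               int n_threads )
    {
        #pragma omp parallel num_threads( n_threads )
        {
            int tid = omp_get_thread_num();
            int nt  = omp_get_num_threads();
            int jb  = ( n + nt - 1 ) / nt;            /* columns per thread (ceiling split) */
            int j0  = tid * jb;
            int j1  = ( j0 + jb < n ) ? j0 + jb : n;  /* clamp the last block               */

            for ( int j = j0; j < j1; ++j )           /* this thread's independent columns  */
            {
                double* bj = b + ( long )j * ldb;

                for ( int k = 0; k < m; ++k )         /* forward substitution               */
                {
                    bj[ k ] /= a[ k + ( long )k * lda ];
                    for ( int i = k + 1; i < m; ++i )
                        bj[ i ] -= a[ i + ( long )k * lda ] * bj[ k ];
                }
            }
        }
    }

For the right-side case the same scheme splits the m rows of B instead, since
each row block then forms an independent XA = B subproblem.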
AMD-Internal: [CPUPL-2103] Change-Id: I6be6a2b579de7df9a3141e0d68bdf3e8a869a005 --- frame/base/bli_rntm.c | 15 +++- frame/compat/bla_trsm_amd.c | 41 ++++++++- kernels/zen/3/bli_trsm_small.c | 147 ++++++++++++++++++++++++++++++--- kernels/zen/bli_kernels_zen.h | 14 +++- 4 files changed, 201 insertions(+), 16 deletions(-) diff --git a/frame/base/bli_rntm.c b/frame/base/bli_rntm.c index f8e00c6208..c15650e918 100644 --- a/frame/base/bli_rntm.c +++ b/frame/base/bli_rntm.c @@ -631,13 +631,22 @@ void bli_nthreads_optimum( else n_threads_ideal = n_threads; } - else if( family == BLIS_TRSM && bli_obj_is_double(c)) + else if( family == BLIS_TRSM && bli_obj_is_double(c) ) { dim_t m = bli_obj_length(c); dim_t n = bli_obj_width(c); - if(m<=512 && n<=512) - n_threads_ideal = 4; +#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM + if ( (m <= 300) && (n <= 300) ) + n_threads_ideal = 8; + else if ( (m <= 400) && (n <= 400) ) + n_threads_ideal = 16; + else if ( (m <= 900) && (n <= 900) ) + n_threads_ideal = 32; +#else + if ( (m <= 512) && (n <= 512) ) + n_threads_ideal = 4; +#endif } else if( family == BLIS_TRSM && bli_obj_is_dcomplex(c)) { diff --git a/frame/compat/bla_trsm_amd.c b/frame/compat/bla_trsm_amd.c index e1a2fffafd..3b3850928a 100644 --- a/frame/compat/bla_trsm_amd.c +++ b/frame/compat/bla_trsm_amd.c @@ -395,7 +395,7 @@ void strsm_ ) { AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO) - AOCL_DTL_LOG_TRSM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'd', + AOCL_DTL_LOG_TRSM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 's', *side, *uploa,*transa, *diaga, *m, *n, (void*)alpha,*lda, *ldb); @@ -886,8 +886,45 @@ void dtrsm_ return; } } -#endif + + //bli_trsm_small_mt is performing better than native multithread + //for certain sizes of m & n. +#ifdef BLIS_ENABLE_OPENMP + rntm_t rntm; + bli_rntm_init_from_global( &rntm ); + + // Query the total number of threads from the rntm_t object. + dim_t n_threads = bli_rntm_num_threads( &rntm ); + if ( ( (n_threads > 1) && (m0 <= 1500) && (n0 <= 1500) ) || + ( (n_threads == 32) && (m0 <= 2300) && (n0 <= 2300) ) || + ( (n_threads == 16) && (m0 <= 3800) && (n0 <= 3800) ) || + ( (n_threads == 8) && (m0 <= 2800) && (n0 <= 2800) ) || + ( (n_threads == 4) && (m0 <= 2000) && (n0 <= 2000) ) || + ( (n_threads == 2) && (m0 <= 2000) && (n0 <= 2000) ) ) + { + err_t status; + status = bli_trsm_small_mt + ( + blis_side, + &alphao, + &ao, + &bo, + NULL, + NULL + ); + + if ( status == BLIS_SUCCESS ) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + /* Finalize BLIS. 
*/ + bli_finalize_auto(); + return; + } + } +#endif// BLIS_ENABLE_OPENMP +#endif// END of BLIS_ENABLE_SMALL_MATRIX_TRSM } + bli_trsmnat ( blis_side, diff --git a/kernels/zen/3/bli_trsm_small.c b/kernels/zen/3/bli_trsm_small.c index 07077010f2..f8c0ea5911 100644 --- a/kernels/zen/3/bli_trsm_small.c +++ b/kernels/zen/3/bli_trsm_small.c @@ -3821,15 +3821,22 @@ err_t bli_trsm_small num_t dt = bli_obj_dt(a); switch(dt) { - case BLIS_DOUBLE: - case BLIS_FLOAT: - case BLIS_SCOMPLEX: - { - if(m > 1000 || n > 1000) { + case BLIS_DOUBLE: + { + bool nt = bli_thread_get_is_parallel(); + if((nt == 0) && (m > 1000 || n > 1000)) { + return BLIS_NOT_YET_IMPLEMENTED; + } + break; + } + case BLIS_FLOAT: + case BLIS_SCOMPLEX: + { + if(m > 1000 || n > 1000) { return BLIS_NOT_YET_IMPLEMENTED; } break; - } + } case BLIS_DCOMPLEX: { if(m > 500 || n > 500) { @@ -3886,6 +3893,126 @@ err_t bli_trsm_small return err; }; +#ifdef BLIS_ENABLE_OPENMP +/* + * Parallelized dtrsm_small across m-dimension or n-dimension based on side(Left/Right) + */ + +err_t bli_trsm_small_mt +( + side_t side, + obj_t* alpha, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +) +{ + rntm_t rntm; + gint_t m = bli_obj_length( b ); // number of rows of matrix b + gint_t n = bli_obj_width( b ); // number of columns of Matrix b + dim_t d_mr = 8,d_nr = 6; + + num_t dt = bli_obj_dt(a); + switch(dt) + { + case BLIS_DOUBLE: + { + d_mr = 8,d_nr = 6; + break; + } + default: + { + return BLIS_NOT_YET_IMPLEMENTED; + break; + } + } + + #ifdef AOCL_DYNAMIC + // If dynamic-threading is enabled, calculate optimum number + // of threads. + // rntm will be updated with optimum number of threads. + if( bli_obj_is_double(b)) + { + bli_nthreads_optimum(a, b, b, BLIS_TRSM, &rntm); + } + #endif + + bli_rntm_init_from_global( &rntm ); + + // Query the total number of threads from the rntm_t object. + dim_t n_threads = bli_rntm_num_threads( &rntm ); + + if (n_threads < 0 ) n_threads = 1; + + err_t status = BLIS_SUCCESS; + _Pragma( "omp parallel num_threads(n_threads)" ) + { + // Query the thread's id from OpenMP. 
+ const dim_t tid = omp_get_thread_num(); + + obj_t b_t; + dim_t start; // Each thread start Index + dim_t end; // Each thread end Index + thrinfo_t thread; + + thread.n_way = n_threads; + thread.work_id = tid; + thread.ocomm_id = tid; + + + // Compute start and end indexes of matrix partitioning for each thread + if ( bli_is_right( side ) ) + { + bli_thread_range_sub ( &thread, + m, + d_mr,// Need to decide based on type + FALSE, + &start, + &end + ); + // For each thread acquire matrix block on which they operate + // Data-based parallelism + + bli_acquire_mpart_mdim(BLIS_FWD, BLIS_SUBPART1, start, end-start, b, &b_t); + } + else + { + bli_thread_range_sub ( &thread, + n, + d_nr,// Need to decide based on type + FALSE, + &start, + &end + ); + // For each thread acquire matrix block on which they operate + // Data-based parallelism + + bli_acquire_mpart_ndim(BLIS_FWD, BLIS_SUBPART1, start, end-start, b, &b_t); + } + + // Parallelism is only across m-dimension/n-dimension - therefore matrix a is common to + // all threads + err_t status_l = BLIS_SUCCESS; + + status_l = bli_trsm_small + ( + side, + alpha, + a, + &b_t, + NULL, + NULL + ); + // To capture the error populated from any of the threads + _Pragma( "omp critical" ) + status = (status != BLIS_NOT_YET_IMPLEMENTED)?status_l:status; + } + + return status; +}// End of function +#endif + /* * ZTRSM utilities and kernel functions */ @@ -6105,7 +6232,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha double* restrict L = a->buffer; //pointer to matrix A - double* restrict B = b->buffer; //pointer to matrix B + double *B = bli_obj_buffer_at_off(b); //pointer to matrix B double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks @@ -8565,7 +8692,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha double* restrict L = a->buffer; //pointer to matrix A - double* restrict B = b->buffer; //pointer to matrix B + double *B = bli_obj_buffer_at_off(b); //pointer to matrix B double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks @@ -10909,7 +11036,7 @@ BLIS_INLINE err_t bli_dtrsm_small_AltXB_AuXB double AlphaVal = *(double *)AlphaObj->buffer; //value of alpha double *L = a->buffer; //pointer to matrix A - double *B = b->buffer; //pointer to matrix B + double *B = bli_obj_buffer_at_off(b); //pointer to matrix B //pointers that point to blocks for GEMM and TRSM double *a10, *a11, *b01, *b11; @@ -12889,7 +13016,7 @@ BLIS_INLINE err_t bli_dtrsm_small_AutXB_AlXB double AlphaVal = *(double *)AlphaObj->buffer; //value of alpha double *L = a->buffer; //pointer to matrix A - double *B = b->buffer; //pointer to matrix B + double *B = bli_obj_buffer_at_off(b); //pointer to matrix B double *a10, *a11, *b01, *b11; //pointers that point to blocks for GEMM and TRSM diff --git a/kernels/zen/bli_kernels_zen.h b/kernels/zen/bli_kernels_zen.h index ff97ca9ea2..4bba0b22f0 100644 --- a/kernels/zen/bli_kernels_zen.h +++ b/kernels/zen/bli_kernels_zen.h @@ -321,7 +321,7 @@ void bli_dgemm_ref_k1_nn double* c, const inc_t ldc ); - err_t bli_trsm_small +err_t bli_trsm_small ( side_t side, obj_t* alpha, @@ -331,6 +331,18 @@ void bli_dgemm_ref_k1_nn cntl_t* cntl ); +#ifdef BLIS_ENABLE_OPENMP +err_t bli_trsm_small_mt + ( + side_t side, + obj_t* alpha, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl + ); +#endif + // threshold functions bool bli_cntx_gemmtsup_thresh_is_met_zen ( From f23233eb4c0727d24f2996561e617a758d2cd09a Mon Sep 
17 00:00:00 2001 From: Dipal M Zambare Date: Thu, 7 Apr 2022 13:47:39 +0530 Subject: [PATCH 106/243] Added runtime control for DTL logging Feature The logs can be enabled with following two methods: -- Environment variable based control: The feature can be enabled by specifying environment variable AOCL_VERBOSE=1. -- API based control: Two API's will be added to enable/disable logging at runtime 1. AOCL_DTL_Enable_Logs() 2. AOCL_DTL_Disable_Logs() -- The API takes precedence over the environment settings. AMD-Internal: [CPUPL-2101] Change-Id: Ie71c1095496fae89226049c9b9f80b00400350d5 --- aocl_dtl/aocldtl.c | 51 ++++++++++++---- aocl_dtl/aocldtl.h | 25 ++++++++ aocl_dtl/aocldtl_blis.h | 129 +++++++++++++++++++++++++--------------- aocl_dtl/aocldtlcf.h | 20 +++++-- 4 files changed, 163 insertions(+), 62 deletions(-) diff --git a/aocl_dtl/aocldtl.c b/aocl_dtl/aocldtl.c index 6f24788aa0..f3c1658ff8 100644 --- a/aocl_dtl/aocldtl.c +++ b/aocl_dtl/aocldtl.c @@ -5,7 +5,7 @@ * These functions are invoked though macros by * end user. * - * Copyright (C) 2020-2021, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All rights reserved. * *=======================================================================*/ #include "blis.h" @@ -56,6 +56,10 @@ static char *pchDTL_LOG_FILE = AOCL_DTL_LOG_FILE; /* Global file pointer for logging the results */ AOCL_FLIST_Node *gpLogFileList = NULL; + + +/* Global flag to check if logging is enabled or not */ +Bool gbIsLoggingEnabled = FALSE; #endif #if AOCL_DTL_AUTO_TRACE_ENABLE @@ -82,6 +86,23 @@ AOCL_FLIST_Node *gpAutoTraceFileList = NULL; void DTL_Initialize( uint32 ui32CurrentLogLevel) { + /* + * This function can be invoked multiple times either via library + * initialization function (e.g. bli_init()) or when user changes + * logging state using API. However we want it to run only once + * This flag ensure that it is executed only once. + * + * DTL can be used with many libraries hence it needs its own + * method to ensure this. + */ + + static Bool bIsDTLInitDone = FALSE; + + if (bIsDTLInitDone) + { + return; + } + /* If user selects invalid trace log level then the dafault trace log level will be AOCL_DTL_LEVEL_ALL */ if ((ui32CurrentLogLevel < 1) || (ui32CurrentLogLevel > AOCL_DTL_LEVEL_ALL)) @@ -107,15 +128,9 @@ void DTL_Initialize( #endif #if (AOCL_DTL_LOG_ENABLE || AOCL_DTL_DUMP_ENABLE) - /* Create/Open the file to log the log data */ - AOCL_FLIST_AddFile(pchDTL_LOG_FILE, &gpLogFileList, AOCL_gettid()); - - if (NULL == gpLogFileList) - { - /* Unable to open the specified file.*/ - AOCL_DEBUGPRINT("Unable to create the log file %s\n", pchDTL_LOG_FILE); - return; - } + + /* Check if DTL logging is requested via envoronment variable */ + gbIsLoggingEnabled = bli_env_get_var( "AOCL_VERBOSE", FALSE ); #endif #if AOCL_DTL_AUTO_TRACE_ENABLE @@ -133,6 +148,9 @@ void DTL_Initialize( /* Save Id for main thread */ gtidMainThreadID = AOCL_gettid(); + // Ensure that this function is executed only once + bIsDTLInitDone = TRUE; + } /* DTL_Initialize */ #endif @@ -193,6 +211,19 @@ void DTL_Trace( { uint8 i = 0; AOCL_FAL_FILE *pOutFile = NULL; + +#if AOCL_DTL_LOG_ENABLE + /* + * For performance reasons we check the logging state in end user + * macros, this is just an additional check in case the function + * is invoked from any other context. 
+ */ + if (gbIsLoggingEnabled == FALSE && ui8LogType == TRACE_TYPE_LOG) + { + return; + } +#endif + uint64 u64EventTime = AOCL_getTimestamp(); dim_t u64RequestedThreadsCount = AOCL_get_requested_threads_count(); diff --git a/aocl_dtl/aocldtl.h b/aocl_dtl/aocldtl.h index 7ce81561b7..f520518e9c 100644 --- a/aocl_dtl/aocldtl.h +++ b/aocl_dtl/aocldtl.h @@ -109,6 +109,31 @@ void AOCL_DTL_start_perf_timer(void); uint64 AOCL_DTL_get_time_spent(void); +/* + * Logging of inputs can be enabled by two methods: + * + * 1. Using environment variable AOCL_VERBOSE. + * 2. APIs + * + * The API takes precedence over environment variable. + * + * The global flag is maintain in the code to track the final + * state of the logging feature. + */ +extern Bool gbIsLoggingEnabled; + +/* API to enable logging at runtime */ +#define AOCL_DTL_Enable_Logs() \ + /* Initialize DTL if not alredy done so */ \ + AOCL_DTL_INITIALIZE(AOCL_DTL_TRACE_LEVEL); \ + gbIsLoggingEnabled = TRUE; + +/* API to disable logging at runtime */ +#define AOCL_DTL_Disable_Logs() \ + /* Initialize DTL if not alredy done so */ \ + AOCL_DTL_INITIALIZE(AOCL_DTL_TRACE_LEVEL); \ + gbIsLoggingEnabled = FALSE; + /* Macro to log the Data */ #define AOCL_DTL_START_PERF_TIMER() \ AOCL_DTL_start_perf_timer() diff --git a/aocl_dtl/aocldtl_blis.h b/aocl_dtl/aocldtl_blis.h index a9ea3368f9..7b352f9d43 100755 --- a/aocl_dtl/aocldtl_blis.h +++ b/aocl_dtl/aocldtl_blis.h @@ -3,7 +3,7 @@ * * Description : BLIS library specific debug helpes. * - * Copyright (C) 2020-2021, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All rights reserved. * *==================================================================*/ @@ -385,115 +385,148 @@ void AOCL_DTL_log_trmm_sizes(int8 loglevel, #define AOCL_DTL_LOG_GEMM_INPUTS(loglevel, dt, transa, transb, m, n, k, alpha, lda, ldb, beta, ldc) \ - AOCL_DTL_log_gemm_sizes(loglevel, dt, transa, transb, m, n, k, alpha, lda, ldb, beta, ldc, __FILE__, __FUNCTION__, __LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_gemm_sizes(loglevel, dt, transa, transb, m, n, k, alpha, lda, ldb, beta, ldc, \ + __FILE__, __FUNCTION__, __LINE__); #define AOCL_DTL_LOG_GEMM_STATS(loglevel, m, n, k) \ - AOCL_DTL_log_gemm_stats(loglevel, m, n, k); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_gemm_stats(loglevel, m, n, k); #define AOCL_DTL_LOG_TRSM_INPUTS(loglevel, dt, side, uploa, transa, diaga, m, n, alpha, lda, ldb) \ - AOCL_DTL_log_trsm_sizes(loglevel, dt, side, uploa, transa, diaga, m, n, alpha, lda, ldb, __FILE__, __FUNCTION__, __LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_trsm_sizes(loglevel, dt, side, uploa, transa, diaga, m, n, alpha, lda, ldb, \ + __FILE__, __FUNCTION__, __LINE__); #define AOCL_DTL_LOG_GEMMT_INPUTS(loglevel, dt, uplo, transa, transb, n, k, alpha, lda, ldb, beta, ldc) \ - AOCL_DTL_log_gemmt_sizes(loglevel, dt, uplo, transa, transb, n, k, alpha, lda, ldb, beta, ldc, __FILE__,__FUNCTION__,__LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_gemmt_sizes(loglevel, dt, uplo, transa, transb, n, k, alpha, lda, ldb, beta, ldc, \ + __FILE__,__FUNCTION__,__LINE__); #define AOCL_DTL_LOG_HEMM_INPUTS(loglevel, dt_type, side, uplo, m, n, alpha, lda, ldb, beta, ldc) \ - AOCL_DTL_log_hemm_sizes(loglevel, dt_type, side, uplo, m, n, alpha, lda, ldb, beta, ldc, \ - __FILE__, __FUNCTION__, __LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_hemm_sizes(loglevel, dt_type, side, uplo, m, n, alpha, lda, ldb, beta, ldc, \ + __FILE__, __FUNCTION__, __LINE__); // Level-3 Macros #define 
AOCL_DTL_LOG_HERK_INPUTS(loglevel, dt_type, uploc, transa, m, k, alpha, lda, beta, ldc)\ - AOCL_DTL_log_herk_sizes(loglevel, dt_type, transa, uploc, m, k, alpha, lda, beta, ldc, __FILE__,\ - __FUNCTION__, __LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_herk_sizes(loglevel, dt_type, transa, uploc, m, k, alpha, lda, beta, ldc, __FILE__,\ + __FUNCTION__, __LINE__); #define AOCL_DTL_LOG_HER2K_INPUTS(loglevel, dt_type, uploc, transa, m, k, alpha, lda, ldb, beta, ldc)\ - AOCL_DTL_log_her2k_sizes(loglevel, dt_type, uploc, transa, m, k, alpha, lda, ldb, beta, ldc, __FILE__,\ - __FUNCTION__, __LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_her2k_sizes(loglevel, dt_type, uploc, transa, m, k, alpha, lda, ldb, beta, ldc, __FILE__,\ + __FUNCTION__, __LINE__); #define AOCL_DTL_LOG_SYMM_INPUTS(loglevel, dt_type, side, uploa, m, n, alpha, lda, ldb, beta, ldc)\ - AOCL_DTL_log_symm_sizes(loglevel, dt_type, side, uploa, m, n, alpha, lda, ldb, beta, ldc, __FILE__,\ - __FUNCTION__, __LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_symm_sizes(loglevel, dt_type, side, uploa, m, n, alpha, lda, ldb, beta, ldc, __FILE__,\ + __FUNCTION__, __LINE__); // Level-2 Macros #define AOCL_DTL_LOG_GEMV_INPUTS(loglevel, dt_type, transa, m, n, alp, lda, incx, beta, incy) \ - AOCL_DTL_log_gemv_sizes(loglevel, dt_type, transa, m, n, alp, lda, incx, beta, incy, __FILE__,\ - __FUNCTION__, __LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_gemv_sizes(loglevel, dt_type, transa, m, n, alp, lda, incx, beta, incy, __FILE__,\ + __FUNCTION__, __LINE__); #define AOCL_DTL_LOG_GER_INPUTS(loglevel, dt_type, m, n, alpha, incx, incy, lda) \ - AOCL_DTL_log_ger_sizes(loglevel, dt_type, m, n, alpha, incx, incy, lda, __FILE__, __FUNCTION__, __LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_ger_sizes(loglevel, dt_type, m, n, alpha, incx, incy, lda, __FILE__, __FUNCTION__, __LINE__); #define AOCL_DTL_LOG_HER_INPUTS(loglevel, dt_type, uploa, m, alpha, incx, lda )\ - AOCL_DTL_log_her_sizes(loglevel, dt_type, uploa, m, alpha, incx, lda, __FILE__,__FUNCTION__,__LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_her_sizes(loglevel, dt_type, uploa, m, alpha, incx, lda, __FILE__,__FUNCTION__,__LINE__); #define AOCL_DTL_LOG_SYMV_INPUTS(loglevel, dt_type, uploa, m, alpha, lda, incx, beta, incy)\ - AOCL_DTL_log_symv_sizes(loglevel, dt_type, uploa, m, alpha, lda, incx, beta, incy, __FILE__,\ - __FUNCTION__, __LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_symv_sizes(loglevel, dt_type, uploa, m, alpha, lda, incx, beta, incy, __FILE__,\ + __FUNCTION__, __LINE__); // Level-1 Macros #define AOCL_DTL_LOG_COPY_INPUTS(loglevel, dt_type, n, incx, incy) \ - AOCL_DTL_log_copy_sizes(loglevel, dt_type, n, incx, incy, __FILE__, __FUNCTION__, __LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_copy_sizes(loglevel, dt_type, n, incx, incy, __FILE__, __FUNCTION__, __LINE__); #define AOCL_DTL_LOG_SCAL_INPUTS(loglevel, dt_type, alpha, n, incx )\ - AOCL_DTL_log_scal_sizes(loglevel, dt_type, alpha, n, incx, __FILE__,__FUNCTION__,__LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_scal_sizes(loglevel, dt_type, alpha, n, incx, __FILE__,__FUNCTION__,__LINE__); #define AOCL_DTL_LOG_SWAP_INPUTS(loglevel, dt_type, n, incx, incy)\ - AOCL_DTL_log_swap_sizes(loglevel, dt_type, n, incx, incy, __FILE__,__FUNCTION__,__LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_swap_sizes(loglevel, dt_type, n, incx, incy, __FILE__,__FUNCTION__,__LINE__); #define AOCL_DTL_LOG_NRM2_INPUTS(loglevel, dt_type, n, incx)\ - AOCL_DTL_log_nrm2_sizes(loglevel, dt_type, n, incx, 
__FILE__,__FUNCTION__,__LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_nrm2_sizes(loglevel, dt_type, n, incx, __FILE__,__FUNCTION__,__LINE__); #define AOCL_DTL_LOG_HEMV_INPUTS(loglevel, dt_type, uploa, m, alpha, lda, incx, beta, incy) \ - AOCL_DTL_log_hemv_sizes(loglevel, dt_type, uploa, m, alpha, lda, incx, beta, incy, \ - __FILE__, __FUNCTION__, __LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_hemv_sizes(loglevel, dt_type, uploa, m, alpha, lda, incx, beta, incy, \ + __FILE__, __FUNCTION__, __LINE__); #define AOCL_DTL_LOG_HER2_INPUTS(loglevel, dt_type, uploa, m, alpha, incx, incy, lda) \ - AOCL_DTL_log_her2_sizes(loglevel, dt_type, uploa, m, alpha, incx, incy, lda, \ - __FILE__, __FUNCTION__, __LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_her2_sizes(loglevel, dt_type, uploa, m, alpha, incx, incy, lda, \ + __FILE__, __FUNCTION__, __LINE__); // Level-1 Macros #define AOCL_DTL_LOG_AMAX_INPUTS(loglevel, dt_type, n, incx) \ - AOCL_DTL_log_amax_sizes(loglevel, dt_type, n, incx, __FILE__, __FUNCTION__, __LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_amax_sizes(loglevel, dt_type, n, incx, __FILE__, __FUNCTION__, __LINE__); #define AOCL_DTL_LOG_ASUM_INPUTS(loglevel, dt_type, n, incx) \ - AOCL_DTL_log_asum_sizes(loglevel, dt_type, n, incx, __FILE__, __FUNCTION__, __LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_asum_sizes(loglevel, dt_type, n, incx, __FILE__, __FUNCTION__, __LINE__); #define AOCL_DTL_LOG_AXPBY_INPUTS(loglevel, dt_type, n, alpha, incx, beta, incy) \ - AOCL_DTL_log_axpby_sizes(loglevel, dt_type, n, alpha, incx, beta, incy, __FILE__,\ - __FUNCTION__, __LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_axpby_sizes(loglevel, dt_type, n, alpha, incx, beta, incy, __FILE__,\ + __FUNCTION__, __LINE__); #define AOCL_DTL_LOG_AXPY_INPUTS(loglevel, dt_type, n, alpha, incx, incy) \ - AOCL_DTL_log_axpy_sizes(loglevel, dt_type, n, alpha, incx, incy, __FILE__,\ - __FUNCTION__, __LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_axpy_sizes(loglevel, dt_type, n, alpha, incx, incy, __FILE__,\ + __FUNCTION__, __LINE__); #define AOCL_DTL_LOG_DOTV_INPUTS(loglevel, dt_type, n, incx, incy) \ - AOCL_DTL_log_dotv_sizes(loglevel, dt_type, n, incx, incy, __FILE__, __FUNCTION__, __LINE__); \ + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_dotv_sizes(loglevel, dt_type, n, incx, incy, __FILE__, __FUNCTION__, __LINE__); \ #define AOCL_DTL_LOG_SYR2_INPUTS(loglevel, dt_type, uploa, m, alpha, incx, incy, lda) \ - AOCL_DTL_log_syr2_sizes(loglevel, dt_type, uploa, m, alpha, incx, incy, lda, __FILE__,\ - __FUNCTION__,__LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_syr2_sizes(loglevel, dt_type, uploa, m, alpha, incx, incy, lda, __FILE__,\ + __FUNCTION__,__LINE__); #define AOCL_DTL_LOG_SYR2K_INPUTS(loglevel, dt_type, uploc, transa, m, k, alpha, lda, ldb, beta, ldc) \ - AOCL_DTL_log_syr2k_sizes(loglevel, dt_type, uploc, transa, m, k, alpha, lda, ldb, beta,\ - ldc, __FILE__, __FUNCTION__,__LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_syr2k_sizes(loglevel, dt_type, uploc, transa, m, k, alpha, lda, ldb, beta,\ + ldc, __FILE__, __FUNCTION__,__LINE__); #define AOCL_DTL_LOG_SYR_INPUTS(loglevel, dt_type, uploa, m, alpha, incx, lda) \ - AOCL_DTL_log_syr_sizes(loglevel, dt_type, uploa, m, alpha, incx, lda,\ - __FILE__,__FUNCTION__,__LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_syr_sizes(loglevel, dt_type, uploa, m, alpha, incx, lda,\ + __FILE__,__FUNCTION__,__LINE__); #define AOCL_DTL_LOG_SYRK_INPUTS(loglevel, dt_type, uploc, transa, m, k, alpha, lda, beta, ldc) \ - 
AOCL_DTL_log_syrk_sizes(loglevel, dt_type, uploc, transa, m, k, alpha, lda, beta, ldc, __FILE__,\ - __FUNCTION__,__LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_syrk_sizes(loglevel, dt_type, uploc, transa, m, k, alpha, lda, beta, ldc, __FILE__,\ + __FUNCTION__,__LINE__); #define AOCL_DTL_LOG_TRMM_INPUTS(loglevel, dt_type, side, uploa, transa, diaga, m, n, alpha, lda, ldb) \ - AOCL_DTL_log_trmm_sizes(loglevel, dt_type, side, uploa, transa, diaga, m, n, alpha, lda, ldb, __FILE__,\ - __FUNCTION__,__LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_trmm_sizes(loglevel, dt_type, side, uploa, transa, diaga, m, n, alpha, lda, ldb, __FILE__,\ + __FUNCTION__,__LINE__); #define AOCL_DTL_LOG_TRMV_INPUTS(loglevel, dt_type, uploa, transa, diaga, m, lda, incx) \ - AOCL_DTL_log_trmv_sizes(loglevel, dt_type, uploa, transa, diaga, m, lda, incx,\ - __FILE__,__FUNCTION__,__LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_trmv_sizes(loglevel, dt_type, uploa, transa, diaga, m, lda, incx,\ + __FILE__,__FUNCTION__,__LINE__); #define AOCL_DTL_LOG_TRSV_INPUTS(loglevel, dt_type, uploa, transa, diaga, m, lda, incx ) \ - AOCL_DTL_log_trsv_sizes(loglevel, dt_type, uploa, transa, diaga, m, lda, incx,\ - __FILE__,__FUNCTION__,__LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_trsv_sizes(loglevel, dt_type, uploa, transa, diaga, m, lda, incx,\ + __FILE__,__FUNCTION__,__LINE__); #else #define AOCL_DTL_LOG_GEMM_INPUTS(loglevel, dt, transa, transb, m, n, k, alpha, lda, ldb, beta, ldc) diff --git a/aocl_dtl/aocldtlcf.h b/aocl_dtl/aocldtlcf.h index 4f1e923a05..9420e7d364 100644 --- a/aocl_dtl/aocldtlcf.h +++ b/aocl_dtl/aocldtlcf.h @@ -5,7 +5,7 @@ * libaray, all debug features (except auto trace) * can be enabled/disabled in this file. * - * Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All rights reserved. * *==================================================================*/ @@ -20,9 +20,21 @@ enable this macro by making it to 1 else 0 */ #define AOCL_DTL_DUMP_ENABLE 0 -/* Macro for logging the logs If the user wants to enable loging information he - has to enable this macro by making it to 1 else 0 */ -#define AOCL_DTL_LOG_ENABLE 0 +/* + * Logging of inputs can be enabled by two methods: + * + * 1. Using environment variable AOCL_VERBOSE. + * 2. APIs AOCL_DTL_Enable_Logs(), AOCL_DTL_Disable_Logs() + * + * The API takes precedence over environment variable. + * + * The global flag is maintain in the code to track the final + * state of the logging feature. + * + * Setting AOCL_DTL_LOG_ENABLE = 0 will disable this feature + * completely and it is not recommended. + */ +#define AOCL_DTL_LOG_ENABLE 1 /* Select the trace level till which you want to log the data */ /* By default it will log for all levels */ From 963a6aa0997d146d4057fbc2ae617425961b4cab Mon Sep 17 00:00:00 2001 From: Dave Date: Fri, 22 Apr 2022 11:47:00 +0530 Subject: [PATCH 107/243] Enabled zgemm_sup path and removed sqp path - Previously zgemm computation failures were due to status variable did not have pre-defined initial value which resulted in zgemm computation to return without being computed by any kernel. Reflected same change in dgemm_ function as well. - Enabled sup zgemm as the issue is fixed with status variable with bli_zgemm_small call. 
-Removed calling sqp method as it is disabled Change-Id: I0f4edfd619bc4877ebfc5cb6532c26c3888f919d --- frame/compat/bla_gemm_amd.c | 48 +++++-------------------------------- 1 file changed, 6 insertions(+), 42 deletions(-) diff --git a/frame/compat/bla_gemm_amd.c b/frame/compat/bla_gemm_amd.c index 7060509de2..681869c9b8 100644 --- a/frame/compat/bla_gemm_amd.c +++ b/frame/compat/bla_gemm_amd.c @@ -560,7 +560,7 @@ void dgemm_ if( ( ( (m0 + n0 -k0) < 2000) && ((m0 + k0-n0) < 2000) && ((n0 + k0-m0) < 2000) ) || ((n0 <= 10) && (k0 <=10)) ) { - err_t status; + err_t status = BLIS_FAILURE; if (bli_is_notrans(blis_transa)) { status = bli_dgemm_small( &alphao, @@ -754,50 +754,14 @@ void zgemm_ } #endif - // The code below will be called when number of threads = 1. -#if 0//ENABLE_INDUCED_METHOD - /* 3m_sqp is optimal for certain matrix shapes. - Initial study that it works well for square sizes and sizes closer to square shape. - - * Usage of 3m_sqp is restricted to sizes, where it is found efficient compared to native, sup and other induced method. - * Further investigation is necessary to make the usage choices more generic. */ - bool sqp_on = false; - if( (m0 == n0 ) && ( n0 == k0 ) && ( m0 == 128 ) ) - { - sqp_on = true; - } - - // current range of sizes used for 3m_sqp to be expaned after evaluation. - if( ( m0 >= 4200) && ( m0 <= 4600 ) && ( ( n0 >= 326 ) || (n0 <= 1600 ) ) - && ( k0 == 1120 ) ) //to be tuned further. - { - sqp_on = true; - } - - if( ( blis_transb == BLIS_NO_TRANSPOSE) && ( sqp_on == true ) ) + err_t status = bli_gemmsup(&alphao, &ao, &bo, &betao, &co, NULL, NULL); + if(status==BLIS_SUCCESS) { - //sqp algo is found better for n > 40 - if(bli_gemm_sqp(&alphao, &ao, &bo, &betao, &co, NULL, NULL)==BLIS_SUCCESS) - { - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) - return; - } + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) + return; } -#endif//ENABLE_INDUCED_METHOD - -// sup has been disabled. - if(0) - { - err_t status = bli_gemmsup(&alphao, &ao, &bo, &betao, &co, NULL, NULL); - if(status==BLIS_SUCCESS) - { - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) - return; - } - } // fall back on native path when zgemm is not handled in sup path. bli_gemmnat(&alphao, &ao, &bo, &betao, &co, NULL, NULL); AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); From 16de63c81829710008775b2f4ec9fded5021af3b Mon Sep 17 00:00:00 2001 From: Dipal M Zambare Date: Mon, 25 Apr 2022 15:58:10 +0530 Subject: [PATCH 108/243] Updated version and copyright notice. Changed AMD-BLIS version to 3.1.2 AMD-Internal: [CPUPL-2111] Change-Id: Id8fc3fbc112f08bd5e5def646c472047352e65b5 --- LICENSE | 2 +- so_version | 2 +- version | 3 +-- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/LICENSE b/LICENSE index 0e7a6071d2..be24a09734 100644 --- a/LICENSE +++ b/LICENSE @@ -15,7 +15,7 @@ copyright info. All parties provide their portions of the code under the Copyright (C) 2018, The University of Texas at Austin Copyright (C) 2016, Hewlett Packard Enterprise Development LP -Copyright (C) 2018 - 2021, Advanced Micro Devices, Inc. +Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/so_version b/so_version index 8efd5969fe..77605e74c7 100644 --- a/so_version +++ b/so_version @@ -1,2 +1,2 @@ 3 -2.0 +1.2 diff --git a/version b/version index 252fb77212..ef538c2810 100644 --- a/version +++ b/version @@ -1,2 +1 @@ -3.2.0 - +3.1.2 From a8bc55c37322f6cf6d6f7c92ae1607faf660de39 Mon Sep 17 00:00:00 2001 From: "S, HariharaSudhan" Date: Tue, 29 Mar 2022 18:05:59 +0530 Subject: [PATCH 109/243] Multithreaded SGEMV var 1 with smart threading - Implemented an OpenMP based stand alone SGEMV kernel for row-major (var 1) for multithread scenarios - Smart threading is enabled when AOCL DYNAMIC is defined - Number of threads are decided based on the input dims using smart threading AMD-Internal: [CPUPL-1984] Change-Id: I9b191e965ba7468e95aabcce21b35a533017502e --- frame/2/gemv/bli_gemv_unf_var1_amd.c | 128 ++++++++- kernels/zen/2/bli_gemv_zen_int_4.c | 395 +++++++++++++++++++++++++++ 2 files changed, 522 insertions(+), 1 deletion(-) diff --git a/frame/2/gemv/bli_gemv_unf_var1_amd.c b/frame/2/gemv/bli_gemv_unf_var1_amd.c index 7228c12f75..8295f3927e 100644 --- a/frame/2/gemv/bli_gemv_unf_var1_amd.c +++ b/frame/2/gemv/bli_gemv_unf_var1_amd.c @@ -332,6 +332,92 @@ void bli_dgemv_unf_var1 AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); } +// Returns the optimal number of threads for the given input sizes and fuse factor +void bli_sgemv_var1_smart_threading + ( + dim_t m, dim_t n, + dim_t fuse, + dim_t* nt, dim_t nt_max + ) +{ + // Calculate the amount data processed per iteration + dim_t n_per_loop = n / fuse; + double data_per_iter = n_per_loop* m; + double m_n_ratio = m/n; + + // When the input value is less than the fuse factor + if(n_per_loop < 1) + { + *nt = 1; + return; + } + + // Then there are two cases one + // In m < n the thread spawning is less aggressive when compared to m > n and m = n cases + if(m_n_ratio <= 0.6) + { + // Boundary units is the amount of data processed by each iteration + // This is the variable X in the equation + const double lower_boundary = 50000; + const double higher_boundary = 500000; + + if(data_per_iter < lower_boundary) + { + double coeff_x = 0.9148; + double constant = -1.6252; + // Number of threads = 0.9148 * log(x) - 1.6252 + *nt = ceil(coeff_x * log(data_per_iter) + constant); + } + else if(data_per_iter < higher_boundary) + { + float coeff_x = 10.23; + float constant = -82.332; + // Number of threads = 10.23 * log(x) - 82.332 + *nt = ceil(coeff_x * log(data_per_iter) + constant); + } + else + { + // When the amount of data to be processed is above both of the boundaries + // The number of threads spawned will be equal to the max number of threads set + *nt = nt_max; + } + } + else + { + // Boundary units is the amount of data processed by each iteration + // This is the variable X in the equation + const float lower_boundary = 50000; + const float higher_boundary = 360000; + + if(data_per_iter < lower_boundary) + { + float coeff_x2 = -2E-09; + float coeff_x = 0.0002; + float constant = 1.0234; + // Number of threads = -2E-09*x^2 + 0.0002 * x + 1.0234 + *nt = ceil(coeff_x2 * (data_per_iter * data_per_iter) + coeff_x * data_per_iter + constant); + } + else if(data_per_iter < higher_boundary) + { + float coeff_x = 16.917; + float constant = -164.82; + // Number of threads = 16.917 * log(x) - 164.82 + *nt = ceil(coeff_x * log(data_per_iter) + constant); + } + else + { + // When the amount of data to be 
processed is above both of the boundaries + // The number of threads spawned will be equal to the max number of threads set + *nt = nt_max; + } + } + + // When the number of threads calculated is greater than the user provided value + // Choose the user provided value + if(*nt > nt_max) + *nt = nt_max; +} + void bli_sgemv_unf_var1 ( trans_t transa, @@ -407,7 +493,46 @@ void bli_sgemv_unf_var1 return; } - /* Query the context for the kernel function pointer and fusing factor. */ +// If both multithreading and OpenMP are enabled, GEMV will multithread +#if defined(BLIS_ENABLE_MULTITHREADING) && defined(BLIS_ENABLE_OPENMP) + dim_t nt, nt_max; + + rntm_t rnmt_obj; + + b_fuse = 4; + + // Initialize a local runtime with global settings. + bli_rntm_init_from_global( &rnmt_obj ); + + // Query the total number of threads from the rntm_t object. + nt_max = bli_rntm_num_threads( &rnmt_obj ); + + //Setting the thread count to the maximum number of threads provided + nt = nt_max; + + // Enable smart threading when AOCL dynamic is enabled + #ifdef AOCL_DYNAMIC + bli_sgemv_var1_smart_threading(n_elem, n_iter, b_fuse, &nt, nt_max); + #endif + + // Pass the input paramaters along with the number of threads to be used + bli_multi_sgemv_4x2 + ( + conja, + conjx, + n_elem, + n_iter, + alpha, + a, cs_at, rs_at, + x, incx, + beta, + y, incy, + cntx, + nt + ); + +#else + b_fuse = 8; for ( i = 0; i < n_iter; i += f ) @@ -434,6 +559,7 @@ void bli_sgemv_unf_var1 ); } +#endif } INSERT_GENTFUNC_BASIC0_CZ( gemv_unf_var1 ) diff --git a/kernels/zen/2/bli_gemv_zen_int_4.c b/kernels/zen/2/bli_gemv_zen_int_4.c index b3c92b551c..74904605ee 100644 --- a/kernels/zen/2/bli_gemv_zen_int_4.c +++ b/kernels/zen/2/bli_gemv_zen_int_4.c @@ -35,6 +35,24 @@ #include "immintrin.h" #include "blis.h" +/* Union data structure to access AVX registers + One 256-bit AVX register holds 8 SP elements. */ +typedef union +{ + __m256 v; + float f[8] __attribute__((aligned(64))); +} v8sf_t; + + +/* Union data structure to access AVX registers +* One 128-bit AVX register holds 4 SP elements. */ +typedef union +{ + __m128 v; + float f[4] __attribute__((aligned(64))); +} v4sf_t; + + /* This implementation uses 512 bits of cache line efficiently for column stored matrix and vectors. @@ -477,3 +495,380 @@ void bli_cgemv_zen_int_4x4 } } + +/* +Function performs multithreaded GEMV for float datatype +All parameters are similar to single thread GEMV except +n_thread which specifies the number of threads to be used +*/ +void bli_multi_sgemv_4x2 + ( + conj_t conjat, + conj_t conjx, + dim_t m, + dim_t b_n, + float* restrict alpha, + float* restrict a, inc_t inca, inc_t lda, + float* restrict x, inc_t incx, + float* restrict beta, + float* restrict y, inc_t incy, + cntx_t* restrict cntx, + dim_t n_threads + ) +{ + const dim_t b_fuse = 4; + const dim_t n_elem_per_reg = 8; + dim_t total_iteration = 0; + + // If the b_n dimension is zero, y is empty and there is no computation. + if (bli_zero_dim1(b_n)) + return; + + // If the m dimension is zero, or if alpha is zero, the computation + // simplifies to updating y. + if (bli_zero_dim1(m) || PASTEMAC(s, eq0)(*alpha)) + { + + bli_sscalv_zen_int10( + BLIS_NO_CONJUGATE, + b_n, + beta, + y, incy, + cntx); + return; + } + + // If b_n is not equal to the fusing factor, then perform the entire + // operation as a loop over dotxv. 
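+    // (Each of the b_n columns is then computed independently with
+    // bli_sdotxv_zen_int; this path is taken only when b_n is smaller
+    // than the fuse factor, so no multithreading is attempted here.)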
+ if (b_n < b_fuse) + { + for (dim_t i = 0; i < b_n; ++i) + { + float *a1 = a + (0) * inca + (i)*lda; + float *x1 = x + (0) * incx; + float *psi1 = y + (i)*incy; + + bli_sdotxv_zen_int( + conjat, + conjx, + m, + alpha, + a1, inca, + x1, incx, + beta, + psi1, + cntx); + } + return; + } + + // Calculate the total number of multithreaded iteration + total_iteration = b_n / b_fuse; + +#pragma omp parallel for num_threads(n_threads) + for (dim_t j = 0; j < total_iteration; j++) + { + float *A1 = a + (b_fuse * j) * lda; + float *x1 = x; + float *y1 = y + (b_fuse * j) * incy; + + // Intermediate variables to hold the completed dot products + float rho0[4] = {0, 0, 0, 0}; + + // If vectorization is possible, perform them with vector + // instructions. + if (inca == 1 && incx == 1) + { + const dim_t n_iter_unroll = 2; + + // Use the unrolling factor and the number of elements per register + // to compute the number of vectorized and leftover iterations. + dim_t l, unroll_inc, m_viter[2], m_left = m; + + unroll_inc = n_elem_per_reg * n_iter_unroll; + + m_viter[0] = m_left / unroll_inc; + m_left = m_left % unroll_inc; + + m_viter[1] = m_left / n_elem_per_reg ; + m_left = m_left % n_elem_per_reg; + + // Set up pointers for x and the b_n columns of A (rows of A^T). + float *restrict x0 = x1; + float *restrict av[4]; + + av[0] = A1 + 0 * lda; + av[1] = A1 + 1 * lda; + av[2] = A1 + 2 * lda; + av[3] = A1 + 3 * lda; + + // Initialize b_n rho vector accumulators to zero. + v8sf_t rhov[4]; + + rhov[0].v = _mm256_setzero_ps(); + rhov[1].v = _mm256_setzero_ps(); + rhov[2].v = _mm256_setzero_ps(); + rhov[3].v = _mm256_setzero_ps(); + + v8sf_t xv[2]; + v8sf_t a_vec[8]; + + // FMA operation is broken down to mul and add + // to reduce backend stalls + for (l = 0; l < m_viter[0]; ++l) + { + xv[0].v = _mm256_loadu_ps(x0); + x0 += n_elem_per_reg; + xv[1].v = _mm256_loadu_ps(x0); + x0 += n_elem_per_reg; + + a_vec[0].v = _mm256_loadu_ps(av[0]); + a_vec[4].v = _mm256_loadu_ps(av[0] + n_elem_per_reg); + + // perform: rho?v += a?v * x0v; + a_vec[0].v = _mm256_mul_ps(a_vec[0].v, xv[0].v); + rhov[0].v = _mm256_fmadd_ps(a_vec[4].v, xv[1].v, rhov[0].v); + rhov[0].v = _mm256_add_ps(a_vec[0].v, rhov[0].v); + + a_vec[1].v = _mm256_loadu_ps(av[1]); + a_vec[5].v = _mm256_loadu_ps(av[1] + n_elem_per_reg); + + a_vec[1].v = _mm256_mul_ps(a_vec[1].v, xv[0].v); + rhov[1].v = _mm256_fmadd_ps(a_vec[5].v, xv[1].v, rhov[1].v); + rhov[1].v = _mm256_add_ps(a_vec[1].v, rhov[1].v); + + a_vec[2].v = _mm256_loadu_ps(av[2]); + a_vec[6].v = _mm256_loadu_ps(av[2] + n_elem_per_reg); + + a_vec[2].v = _mm256_mul_ps(a_vec[2].v, xv[0].v); + rhov[2].v = _mm256_fmadd_ps(a_vec[6].v, xv[1].v, rhov[2].v); + rhov[2].v = _mm256_add_ps(a_vec[2].v, rhov[2].v); + + a_vec[3].v = _mm256_loadu_ps(av[3]); + a_vec[7].v = _mm256_loadu_ps(av[3] + n_elem_per_reg); + + a_vec[3].v = _mm256_mul_ps(a_vec[3].v, xv[0].v); + rhov[3].v = _mm256_fmadd_ps(a_vec[7].v, xv[1].v, rhov[3].v); + rhov[3].v = _mm256_add_ps(a_vec[3].v, rhov[3].v); + + av[0] += unroll_inc; + av[1] += unroll_inc; + av[2] += unroll_inc; + av[3] += unroll_inc; + } + + for (l = 0; l < m_viter[1]; ++l) + { + // Load the input values. 
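+                // (This pass consumes the full 8-element registers left
+                // over after the 2x-unrolled loop above; the SSE and
+                // scalar tails below handle whatever remains.)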
+ xv[0].v = _mm256_loadu_ps(x0); + x0 += n_elem_per_reg; + + a_vec[0].v = _mm256_loadu_ps(av[0]); + a_vec[1].v = _mm256_loadu_ps(av[1]); + + rhov[0].v = _mm256_fmadd_ps(a_vec[0].v, xv[0].v, rhov[0].v); + rhov[1].v = _mm256_fmadd_ps(a_vec[1].v, xv[0].v, rhov[1].v); + + av[0] += n_elem_per_reg; + av[1] += n_elem_per_reg; + + a_vec[2].v = _mm256_loadu_ps(av[2]); + a_vec[3].v = _mm256_loadu_ps(av[3]); + + rhov[2].v = _mm256_fmadd_ps(a_vec[2].v, xv[0].v, rhov[2].v); + rhov[3].v = _mm256_fmadd_ps(a_vec[3].v, xv[0].v, rhov[3].v); + + av[2] += n_elem_per_reg; + av[3] += n_elem_per_reg; + } + + // Sum the elements within each vector. + // Sum the elements of a given rho?v with hadd. + rhov[0].v = _mm256_hadd_ps(rhov[0].v, rhov[1].v); + rhov[2].v = _mm256_hadd_ps(rhov[2].v, rhov[3].v); + rhov[0].v = _mm256_hadd_ps(rhov[0].v, rhov[0].v); + rhov[2].v = _mm256_hadd_ps(rhov[2].v, rhov[2].v); + + // Manually add the results from above to finish the sum. + rho0[0] = rhov[0].f[0] + rhov[0].f[4]; + rho0[1] = rhov[0].f[1] + rhov[0].f[5]; + rho0[2] = rhov[2].f[0] + rhov[2].f[4]; + rho0[3] = rhov[2].f[1] + rhov[2].f[5]; + + // If leftover elements are more than 4, perform SSE + if (m_left > 4) + { + v4sf_t xv128, a_vec128[4], rhov128[4]; + + rhov128[0].v = _mm_set1_ps(0); + rhov128[1].v = _mm_set1_ps(0); + rhov128[2].v = _mm_set1_ps(0); + rhov128[3].v = _mm_set1_ps(0); + + // Load the input values. + xv128.v = _mm_loadu_ps(x0 + 0 * n_elem_per_reg); + x0 += 4; + m_left -= 4; + + a_vec128[0].v = _mm_loadu_ps(av[0]); + a_vec128[1].v = _mm_loadu_ps(av[1]); + + // perform: rho?v += a?v * x0v; + rhov128[0].v = _mm_fmadd_ps(a_vec128[0].v, xv128.v, rhov128[0].v); + rhov128[1].v = _mm_fmadd_ps(a_vec128[1].v, xv128.v, rhov128[1].v); + rhov128[0].v = _mm_hadd_ps(rhov128[0].v, rhov128[1].v); + rhov128[0].v = _mm_hadd_ps(rhov128[0].v, rhov128[0].v); + + a_vec128[2].v = _mm_loadu_ps(av[2]); + a_vec128[3].v = _mm_loadu_ps(av[3]); + + rhov128[2].v = _mm_fmadd_ps(a_vec128[2].v, xv128.v, rhov128[2].v); + rhov128[3].v = _mm_fmadd_ps(a_vec128[3].v, xv128.v, rhov128[3].v); + rhov128[2].v = _mm_hadd_ps(rhov128[2].v, rhov128[3].v); + rhov128[2].v = _mm_hadd_ps(rhov128[2].v, rhov128[2].v); + + rho0[0] += rhov128[0].f[0]; + rho0[1] += rhov128[0].f[1]; + rho0[2] += rhov128[2].f[0]; + rho0[3] += rhov128[2].f[1]; + + av[0] += 4; + av[1] += 4; + av[2] += 4; + av[3] += 4; + } + + // If there are leftover iterations, perform them with scalar code. + for (l = 0; l < m_left; ++l) + { + rho0[0] += *(av[0]) * (*x0); + rho0[1] += *(av[1]) * (*x0); + rho0[2] += *(av[2]) * (*x0); + rho0[3] += *(av[3]) * (*x0); + + x0 += incx; + av[0] += inca; + av[1] += inca; + av[2] += inca; + av[3] += inca; + } + + } + else + { + // When vectorization is not possible, perform with scalar code + + // Initialize pointers for x and the b_n columns of A (rows of A^T). + float *restrict x0 = x1; + float *restrict a0 = A1 + 0 * lda; + float *restrict a1 = A1 + 1 * lda; + float *restrict a2 = A1 + 2 * lda; + float *restrict a3 = A1 + 3 * lda; + + for (dim_t l = 0; l < m; ++l) + { + const float x0c = *x0; + + const float a0c = *a0; + const float a1c = *a1; + const float a2c = *a2; + const float a3c = *a3; + + rho0[0] += a0c * x0c; + rho0[1] += a1c * x0c; + rho0[2] += a2c * x0c; + rho0[3] += a3c * x0c; + + x0 += incx; + a0 += inca; + a1 += inca; + a2 += inca; + a3 += inca; + } + } + + v4sf_t rho0v, y0v; + + rho0v.v = _mm_loadu_ps(rho0); + + // Broadcast the alpha scalar. 
+ v4sf_t alphav; + alphav.v = _mm_broadcast_ss(alpha); + + // We know at this point that alpha is nonzero; however, beta may still + // be zero. If beta is indeed zero, we must overwrite y rather than scale + // by beta (in case y contains NaN or Inf). + if (PASTEMAC(s, eq0)(*beta)) + { + // Apply alpha to the accumulated dot product in rho: + // y := alpha * rho + y0v.v = _mm_mul_ps(alphav.v, rho0v.v); + } + else + { + // Broadcast the beta scalar. + v4sf_t betav; + betav.v = _mm_broadcast_ss(beta); + + if (incy == 0) + { + // Load y. + y0v.v = _mm_loadu_ps(y1 + 0 * n_elem_per_reg); + } + else + { + // Load y. + y0v.f[0] = *(y1 + 0 * incy); + y0v.f[1] = *(y1 + 1 * incy); + y0v.f[2] = *(y1 + 2 * incy); + y0v.f[3] = *(y1 + 3 * incy); + } + + // Apply beta to y and alpha to the accumulated dot product in rho: + // y := beta * y + alpha * rho + y0v.v = _mm_mul_ps(betav.v, y0v.v); + y0v.v = _mm_fmadd_ps(alphav.v, rho0v.v, y0v.v); + } + + // Store the output. + if (incy == 1) + { + _mm_storeu_ps((y1 + 0 * n_elem_per_reg), y0v.v); + } + else + { + // Store the output. + *(y1 + 0 * incy) = y0v.f[0]; + *(y1 + 1 * incy) = y0v.f[1]; + *(y1 + 2 * incy) = y0v.f[2]; + *(y1 + 3 * incy) = y0v.f[3]; + } + } + + // Performs the complete computation if OpenMP is not enabled + dim_t start = total_iteration * b_fuse; + dim_t new_fuse = 8, f; + + // Left over corner cases completed using fused kernel + for (dim_t i = start; i < b_n; i += f) + { + f = bli_determine_blocksize_dim_f(i, b_n, new_fuse); + + float *A1 = a + (i)*lda + (0) * inca; + float *x1 = x + (0) * incx; + float *y1 = y + (i)*incy; + + /* y1 = beta * y1 + alpha * A1 * x; */ + bli_sdotxf_zen_int_8( + conjat, + conjx, + m, + f, + alpha, + A1, inca, lda, + x1, incx, + beta, + y1, incy, + cntx); + } +} From a3836a560da687dbf86c2a0b7e00a1b5e26a24e6 Mon Sep 17 00:00:00 2001 From: mkadavil Date: Tue, 22 Mar 2022 21:17:43 +0530 Subject: [PATCH 110/243] Smart Threading for GEMM (sgemm) v1. - Cache aware factorization. Experiments shows that ic,jc factorization based on m,n gives better results compared to mu,nu on a generic data set in SUP path. Also slight adjustments in the factorizations w.r.t matrix data loads can help in improving perf further. - Moving native path inputs to SUP path. Experiments shows that in multi-threaded scenarios if the per thread data falls under SUP thresholds, taking SUP path instead of native path results in improved performance. This is the case even if the original matrix dimensions falls in native path. This is not applicable if A matrix transpose is required. - Enabling B matrix packing in SUP path. Performance improvement is observed when B matrix is packed in cases where gemm takes SUP path instead of native path based on per thread matrix dimensions. 
AMD-Internal: [CPUPL-659] Change-Id: I3b8fc238a0ece1ababe5d64aebab63092f7c6914 --- frame/3/CMakeLists.txt | 3 +- frame/3/bli_l3.h | 5 +- frame/3/bli_l3_smart_threading.c | 557 +++++++++++++++++++++++++++++++ frame/3/bli_l3_smart_threading.h | 68 ++++ frame/3/bli_l3_sup.c | 49 +-- frame/3/bli_l3_sup_int_amd.c | 5 +- frame/base/bli_rntm.c | 82 ++++- frame/base/bli_rntm.h | 12 +- 8 files changed, 755 insertions(+), 26 deletions(-) create mode 100644 frame/3/bli_l3_smart_threading.c create mode 100644 frame/3/bli_l3_smart_threading.h diff --git a/frame/3/CMakeLists.txt b/frame/3/CMakeLists.txt index b3aaf2c8c8..e9d7da7b8e 100644 --- a/frame/3/CMakeLists.txt +++ b/frame/3/CMakeLists.txt @@ -25,6 +25,7 @@ target_sources("${PROJECT_NAME}" ${CMAKE_CURRENT_SOURCE_DIR}/bli_l3_ukr_fpa.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_l3_ukr_oapi.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_l3_ukr_tapi.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_l3_smart_threading.c ) # Select AMD specific sources for AMD configurations. if(${TARGET_ARCH} STREQUAL zen OR @@ -38,7 +39,7 @@ if(${TARGET_ARCH} STREQUAL zen OR else() target_sources("${PROJECT_NAME}" PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/bli_l3_sup_int_amd.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_l3_sup_int.c ) endif() diff --git a/frame/3/bli_l3.h b/frame/3/bli_l3.h index b64da054c9..b65edfcaac 100644 --- a/frame/3/bli_l3.h +++ b/frame/3/bli_l3.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020, Advanced Micro Devices, Inc. + Copyright (C) 2020-22, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -98,3 +98,6 @@ #include "bli_trmm3.h" #include "bli_trsm.h" #include "bli_gemmt.h" + +// Smart Threading API's. +#include "bli_l3_smart_threading.h" diff --git a/frame/3/bli_l3_smart_threading.c b/frame/3/bli_l3_smart_threading.c new file mode 100644 index 0000000000..e4b9b43e24 --- /dev/null +++ b/frame/3/bli_l3_smart_threading.c @@ -0,0 +1,557 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "bli_l3_smart_threading.h" + +#ifdef AOCL_DYNAMIC + +// Utility functions. +static inline dim_t next_factor + ( + const dim_t nt, + const dim_t part_nt + ) +{ + if ( part_nt == nt) + { + return part_nt; + } + + dim_t nt_temp = part_nt + 1; + while ( ( nt_temp <= nt ) && ( ( nt % nt_temp ) != 0 ) ) + { + nt_temp++; + } + return nt_temp; +} + +static inline dim_t prev_factor + ( + const dim_t nt, + const dim_t part_nt + ) +{ + if ( part_nt == 1) + { + return part_nt; + } + + dim_t nt_temp = part_nt - 1; + while ((nt_temp >= 1) && ((nt % nt_temp) != 0)) + { + nt_temp--; + } + return nt_temp; +} +// End utility functions. + +static err_t bli_gemm_ic_jc_optimum_sup_arch_dispatcher + ( + num_t dt, + siz_t elem_size, + const bool is_rrr_rrc_rcr_crr, + const dim_t m, + const dim_t n, + const dim_t k, + const dim_t max_available_nt, + cntx_t* cntx, + rntm_t* rntm + ); + +static err_t bli_gemm_ic_jc_optimum_sup_zen3 + ( + num_t dt, + siz_t elem_size, + const bool is_rrr_rrc_rcr_crr, + const dim_t m, + const dim_t n, + const dim_t k, + const dim_t max_available_nt, + cntx_t* cntx, + rntm_t* rntm + ); + +static void bli_gemm_cache_heur_adjust_ic_jc_sup_zen3 + ( + const dim_t m, + const dim_t n, + const dim_t k, + dim_t nt, + dim_t* ic, + dim_t* jc, + const dim_t MR, + const dim_t NR, + const dim_t MC, + const dim_t KC + ); + +err_t bli_check_and_transform_native_to_SUP + ( + num_t dt, + siz_t elem_size, + const bool is_rrr_rrc_rcr_crr, + const dim_t m, + const dim_t n, + const dim_t k, + dim_t ic, + dim_t jc, + const dim_t NR, + const dim_t MC, + const dim_t KC, + cntx_t* cntx, + rntm_t* rntm + ); + +err_t bli_gemm_smart_threading_sup + ( + num_t dt, + siz_t elem_size, + const bool is_rrr_rrc_rcr_crr, + const dim_t m, + const dim_t n, + const dim_t k, + const dim_t max_available_nt, + cntx_t* cntx, + rntm_t* rntm + ) +{ + err_t ret_val = BLIS_FAILURE; + + // Sanity check, max available threads should be atleast 4 for the + // smart threading/factorization to be meaningful. For nt < 4 the + // default ic,jc factorization holds good. + if ( ( m <= 1 ) || ( n <= 1 ) || ( k <= 1 ) || ( max_available_nt < 4 ) ) + { + return ret_val; + } + + if ( bli_is_float( dt ) ) + { + ret_val = bli_gemm_ic_jc_optimum_sup_arch_dispatcher + ( + dt, elem_size, is_rrr_rrc_rcr_crr, m, n, k, + max_available_nt, cntx, rntm + ); + } + else + { + // Other data types not supported for now. + } + + if ( ret_val == BLIS_SUCCESS ) + { + // This is a workaround to ensure that auto_factor attribute of rntm_t + // is not set to TRUE inside bli_rntm_set_ways_from_rntm_sup. Also + // the nt value will be properly set to ic*jc towards the end of + // bli_rntm_set_ways_from_rntm_sup. 
+ bli_rntm_set_num_threads_only( -1, rntm ); + } + + return ret_val; +} + +static err_t bli_gemm_ic_jc_optimum_sup_arch_dispatcher + ( + num_t dt, + siz_t elem_size, + const bool is_rrr_rrc_rcr_crr, + const dim_t m, + const dim_t n, + const dim_t k, + const dim_t max_available_nt, + cntx_t* cntx, + rntm_t* rntm + ) +{ + err_t ret_val = BLIS_FAILURE; + + arch_t id = bli_arch_query_id(); + if ( id == BLIS_ARCH_ZEN3 ) + { + ret_val = bli_gemm_ic_jc_optimum_sup_zen3 + ( + dt, elem_size, is_rrr_rrc_rcr_crr, m, n, k, + max_available_nt, cntx, rntm + ); + } + else + { + // Other architectures not supported for now. + } + + return ret_val; +} + +// open zen3 region. +#define NUM_CORES_PER_CCD_ZEN3 8 + +// Determines the optimal number of threads (nt) and corresponding work split +// (ic,jc factorization of nt) for gemm on zen3 machines. +static err_t bli_gemm_ic_jc_optimum_sup_zen3 + ( + num_t dt, + siz_t elem_size, + const bool is_rrr_rrc_rcr_crr, + const dim_t m, + const dim_t n, + const dim_t k, + const dim_t max_available_nt, + cntx_t* cntx, + rntm_t* rntm + ) +{ + err_t ret_val = BLIS_SUCCESS; + + const dim_t MR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_MR, cntx ); + const dim_t NR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NR, cntx ); + const dim_t MC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_MC, cntx ); + const dim_t NC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NC, cntx ); + const dim_t KC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_KC, cntx ); + + dim_t ic = -1; + dim_t jc = -1; + + bli_thread_partition_2x2( max_available_nt, m, n, &ic, &jc ); + + dim_t jc_per_ccd = ( NUM_CORES_PER_CCD_ZEN3 + ic - 1 ) / ic ; + dim_t b_mat_data_per_ccd = jc_per_ccd * ( n / jc ); + + // All the cores (8) on a CCD share a L3 cache and hence total data + // loaded by the cores on a CCD should be < NC to avoid L3 contention. + // In cases where it is violated, it is better to increase ic and + // reduce B data per CCD, using micro panels mu, nu for thread + // partitioning can help achieve this. Avoiding further ic,jc + // adjustment in this case. + if ( b_mat_data_per_ccd > NC ) + { + const dim_t mu = m / MR; + const dim_t nu = n / NR; + bli_thread_partition_2x2( max_available_nt, mu, nu, &ic, &jc ); + } + else + { + // Adjust the ic,jc in the best match so that m_ic and n_jc + // turns out to be more cache friendly. + bli_gemm_cache_heur_adjust_ic_jc_sup_zen3 + ( + m, n, k, max_available_nt, &ic, &jc, MR, NR, MC, KC + ); + } + + ret_val = bli_check_and_transform_native_to_SUP + ( + dt, elem_size, is_rrr_rrc_rcr_crr, m, n, k, + ic, jc, NR, MC, KC, cntx, rntm + ); + + if ( ret_val == BLIS_SUCCESS ) + { + bli_rntm_set_ic_ways_only( ic, rntm ); + bli_rntm_set_jc_ways_only( jc, rntm ); + } + + return ret_val; +} + +// The factorization of nt into ic,jc is based on m and n values (for simplicity +// it can be assumed to be based on m:n ratio). It does not take into account +// how the matrices are loaded into cache or which matrix goes to the larger +// cache. Depending on the matrix dimensions, increasing the ic can result in +// reduced loads from main memory to L2 cache for A matrix without any impact on +// B matrix load (since B is streamed into L3, which is larger). Similary +// adjusting jc can result in B matrix panels fitting perfectly within the L1 +// cache.This function makes these adjustments on ic,jc. 
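+// For illustration (hypothetical numbers, not taken from this patch): with
+// nt = 16 and an initial ic x jc = 4 x 4 from bli_thread_partition_2x2, the
+// candidate splits examined below are next_ic x prev_jc = 8 x 2 and
+// prev_ic x next_jc = 2 x 8, in addition to the original 4 x 4.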
+static void bli_gemm_cache_heur_adjust_ic_jc_sup_zen3 + ( + const dim_t m, + const dim_t n, + const dim_t k, + dim_t nt, + dim_t* ic, + dim_t* jc, + const dim_t MR, + const dim_t NR, + const dim_t MC, + const dim_t KC + ) +{ + const dim_t m_ic = m / ( *ic ); + const dim_t n_jc = n / ( *jc ); + const int64_t cur_work_per_thread = m_ic + n_jc; + + // The next and prev factors are caluclated with respect to the current + // factor part of nt. In effect + // 1. next ic * prev jc = nt + // 2. prev ic * next jc = nt + // 3. ic * jc = nt + const dim_t next_ic = next_factor( nt, ( *ic ) ); + const dim_t prev_ic = prev_factor( nt, ( *ic ) ); + const dim_t next_jc = next_factor( nt, ( *jc ) ); + const dim_t prev_jc = prev_factor( nt, ( *jc ) ); + + const dim_t m_next_ic = m / next_ic; + const dim_t m_prev_ic = m / prev_ic; + const dim_t n_next_jc = n / next_jc; + const dim_t n_prev_jc = n / prev_jc; + const dim_t n_jc_modulo_NR = n_jc % NR; + const dim_t n_prev_jc_modulo_NR = n_prev_jc % NR; + + const int64_t next_jc_work_per_thread = n_next_jc + m_prev_ic; + const int64_t next_ic_work_per_thread = m_next_ic + n_prev_jc; + + const dim_t MCx2 = MC * 2; + const dim_t NRx4 = NR * 4; + const dim_t NRx8 = NR * 8; + + // MC will be reduced if the following mods are zero. Incrementing jc + // helps in this case. + const dim_t n_mod_256 = n % 256; + const dim_t k_mod_256 = k % 256; + + const dim_t k_factor = k / KC; + + bool can_increase_jc = FALSE; + bool can_increase_ic = FALSE; + + // jc adjustment towards next highest factor if it results in n_jc*KC + // fittting completely within l1d cache. Only done if ic prev factor + // does not move m_prev_ic out of good l2 load zone (MC). + // Performance improvement also observed when n_jc is a multiple of NR. + if ( ( ( *ic ) > 1 ) && ( ( *jc ) < nt ) ) + { + // Check whether m_prev_ic remains in good l2 load zone. + if ( ( ( ( m_ic <= MC ) && ( m_prev_ic <= MC ) ) || + ( m_ic > MC ) ) && + ( ( n_jc > NR ) && ( n_next_jc == NR ) ) ) + { + can_increase_jc = TRUE; + } + // 2x2 factorization doesnt always give equal sum partition. + else if ( next_jc_work_per_thread < cur_work_per_thread ) + { + can_increase_jc = TRUE; + } + } + + // Favor jc if both n and k are multiples of 256 ( high cache line + // replacement ). + if ( ( ( *ic ) < nt ) && ( ( *jc ) > 1) ) + { + // ic adjustment towards next highest factor if it results in + // m_next_ic <= MC. This helps in reducing number of A matrix + // loads per thread to l2 from main memory. + if ( ( m_ic > MC ) && ( m_next_ic <= MC ) && + ( m_next_ic >= MR ) && ( k_factor > 4 ) ) + { + can_increase_ic = TRUE; + } + // ic adjustment towards next highest factor resulted in better + // performance when m is sufficiently larger than n and jc prev + // factor did not result in n_prev_jc moving out of good l2 + // load zone (n_jc < 64). + else if ( ( m > ( 5 * n ) ) && ( m_ic >= MCx2 ) && ( k_factor > 4 ) && + ( ( n_jc > NRx4 ) || + ( ( n_jc <= NRx4 ) && ( n_prev_jc <= NRx4 ) ) ) ) + { + can_increase_ic = TRUE; + } + // Performance improvement also observed when n_jc is a multiple + // of NR. + else if ( ( n_jc_modulo_NR != 0 ) && ( n_prev_jc_modulo_NR == 0 ) && + ( k_factor > 4 ) ) + { + can_increase_ic = TRUE; + } + // 2x2 factorization doesnt always give equal sum partition. + else if ( next_ic_work_per_thread <= cur_work_per_thread ) + { + can_increase_ic = TRUE; + } + } + + // Favor jc if both n and k are multiples of 256 ( high cache line + // replacement ). 
+ if ( ( n_mod_256 == 0 ) && ( k_mod_256 == 0 ) && ( k > KC ) ) + { + if ( can_increase_ic == TRUE ) + { + can_increase_ic = FALSE; + } + else if ( can_increase_jc == FALSE ) + { + can_increase_jc = TRUE; + } + } + // If only one of either n or k is a multiple of 256, favour jc if n per + // thread is within a heuristic factor of NR. + else if ( ( ( n_mod_256 == 0 ) || ( k_mod_256 == 0 ) ) && ( k > KC ) ) + { + if ( ( can_increase_ic == TRUE ) && ( n_jc <= NRx8 ) ) + { + can_increase_ic = FALSE; + } + else if ( ( can_increase_jc == FALSE ) && ( n_next_jc <= NRx8 ) ) + { + can_increase_jc = TRUE; + } + } + + // Increasing ic factor is given a higher priority compared to jc + // since it was observed that the A matrix loads (main memory -> l2) had + // more impact on perf compared to B matrix (main memory -> l3 -> l1) + // for the sizes considered. + if ( can_increase_ic ) + { + // It is expected that the larger dimension (m or n) will be + // allocated a larger share of the thread factorization. + if ( ( ( m >= n ) && ( next_ic >= prev_jc ) ) || + ( ( m <= n ) && ( next_ic <= prev_jc ) ) ) + { + *ic = next_ic; + *jc = prev_jc; + } + } + else if ( can_increase_jc ) + { + // It is expected that the larger dimension (m or n) will be + // allocated a larger share of the thread factorization. + if ( ( ( m >= n ) && ( prev_ic >= next_jc ) ) || + ( ( m <= n ) && ( prev_ic <= next_jc ) ) ) + { + *ic = prev_ic; + *jc = next_jc; + } + } +} + +// It was observed that the SUP thresholds can be lowered and applied on a +// per thread basis in multi threaded scenarios. +err_t bli_check_and_transform_native_to_SUP + ( + num_t dt, + siz_t elem_size, + const bool is_rrr_rrc_rcr_crr, + const dim_t m, + const dim_t n, + const dim_t k, + dim_t ic, + dim_t jc, + const dim_t NR, + const dim_t MC, + const dim_t KC, + cntx_t* cntx, + rntm_t* rntm + ) +{ + err_t ret_val = BLIS_FAILURE; + dim_t m_ic; + dim_t n_jc; + + const dim_t MT = bli_cntx_get_l3_sup_thresh_dt( dt, BLIS_MT, cntx ); + const dim_t NT = bli_cntx_get_l3_sup_thresh_dt( dt, BLIS_NT, cntx ); + const dim_t KT = bli_cntx_get_l3_sup_thresh_dt( dt, BLIS_KT, cntx ); + + const dim_t MT_2 = MT / 2; + const dim_t NTx4 = NT * 4; + const dim_t NRx8 = NR * 8; + + const dim_t page_size = bli_info_get_page_size(); + const dim_t page_size_b_float = page_size / ( dim_t ) elem_size; + const dim_t page_size_b_floatx2 = page_size_b_float * 2; + + // Default SUP check without considering per thread dimensions. + if ( ( k < KT ) || ( m < MT ) || ( n < NT ) ) + { + ret_val = BLIS_SUCCESS; + } + // Per thread SUP limit checking. It was observed that when k is large, + // (twice page size) moving native to SUP did not help even if m_ic or + // n_jc were within SUP limits. + else if ( ( m >= MT ) && ( n >= NT ) && ( k < page_size_b_floatx2 ) ) + { + m_ic = m / ic; + n_jc = n / jc; + // In multi-threaded scenario, it was observed that if the per + // thread m dimension(A matrix) and n dimension(B matrix) is + // within a factor of SUP limits, SUP path without packing + // resulted in gains. Along similar lines, if the B matrix is + // large enough and reuse is good, packing B matrix alone in SUP + // resulted in perf gains. 
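+		// (The branches below encode those observations as per-thread
+		// checks against scaled thresholds: MT_2 = MT/2, NTx4 = 4*NT,
+		// NRx8 = 8*NR, with the k > KC and m_ic >= MC conditions guarding
+		// when packing B or A is expected to pay off.)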
+ if ( ( m_ic <= MT_2 ) && ( n_jc < NTx4 ) ) + { + if ( ( k > KC ) && + ( m_ic >= MC ) && ( n_jc >= NT ) ) + { + if ( is_rrr_rrc_rcr_crr ) + { + bli_rntm_set_pack_b( 1, rntm ); + } + else + { + bli_rntm_set_pack_a( 1, rntm ); + } + } + ret_val = BLIS_SUCCESS; + } + else if ( ( n_jc < NT ) && ( m_ic <= MT ) ) + { + if ( ( k > KC ) && ( m_ic >= MC ) && ( n_jc >= NRx8 ) ) + { + if ( is_rrr_rrc_rcr_crr ) + { + bli_rntm_set_pack_b( 1, rntm ); + } + else + { + bli_rntm_set_pack_a( 1, rntm ); + } + } + ret_val = BLIS_SUCCESS; + } + else + { + ret_val = BLIS_FAILURE; + } + } + else + { + ret_val = BLIS_FAILURE; + } + + return ret_val; +} +// close zen3 region. + +#endif diff --git a/frame/3/bli_l3_smart_threading.h b/frame/3/bli_l3_smart_threading.h new file mode 100644 index 0000000000..48a0a17bb2 --- /dev/null +++ b/frame/3/bli_l3_smart_threading.h @@ -0,0 +1,68 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifdef AOCL_DYNAMIC + +#ifndef BLIS_L3_SMART_THREADING_H +#define BLIS_L3_SMART_THREADING_H + +// Smart threading encompasses the following multi-threading related +// optimizations: +// 1. Selection of optimal number of threads (BLIS_NUM_THREADS) based +// on matrix dimensions. +// 2. Factorization of threads along m and n dimensions (BLIS_IC_NT, +// BLIS_JC_NT) based on matrix dimensions and cache friendliness. +// 3. Transformation of native to SUP path based on the per thread matrix +// dimensions after thread factorization, given that per thread dimensions +// are within SUP limits. +// 4. Enabling packing of B alone in SUP path if native -> SUP path +// transformation happened and depending on per thread matrix dimensions. +// This function captures smart threading logic fine tuned for gemm SUP path. +// Optimal thread selection is not enabled now. 
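+// The intended caller is bli_smart_threading_sup() in frame/base/bli_rntm.c
+// (added later in this patch), which passes the datatype, element size,
+// storage class, m/n/k and the maximum available thread count; on success
+// the chosen ic/jc ways are written into the rntm_t.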
+err_t bli_gemm_smart_threading_sup + ( + num_t dt, + siz_t elem_size, + const bool is_rrr_rrc_rcr_crr, + const dim_t m, + const dim_t n, + const dim_t k, + const dim_t max_available_nt, + cntx_t* cntx, + rntm_t* rntm + ); + +#endif //BLIS_L3_SMART_THREADING_H + +#endif diff --git a/frame/3/bli_l3_sup.c b/frame/3/bli_l3_sup.c index a7d7a7874a..d23df8c1e5 100644 --- a/frame/3/bli_l3_sup.c +++ b/frame/3/bli_l3_sup.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2019-21, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019-22, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -101,6 +101,34 @@ err_t bli_gemmsup // that function assumes the context pointer is valid. if ( cntx == NULL ) cntx = bli_gks_query_cntx(); + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. + rntm_t rntm_l; + if ( rntm == NULL ) { bli_rntm_init_from_global( &rntm_l ); rntm = &rntm_l; } + else { rntm_l = *rntm; rntm = &rntm_l; } + +#ifdef AOCL_DYNAMIC + // Calculating optimal nt and corresponding factorization (ic,jc) here, so + // as to determine the matrix dimensions (A - m, B - n) per thread. This + // can be used to check if dimensions per thread falls under the SUP + // threshold and potentially move some of the native path gemm to SUP path + // in multi-threaded scenario. + err_t smart_threading = bli_smart_threading_sup( a, b, c, BLIS_GEMM, rntm, cntx ); + + if ( smart_threading != BLIS_SUCCESS ) + { + thresh_func_ft func_fp; + func_fp = bli_cntx_get_l3_thresh_func(BLIS_GEMM, cntx); + + // Return early if the sizes are beyond SUP thresholds + if ( !func_fp( a, b, c, cntx ) ) + { + AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_2, + "SUP - Sizes are beyond SUP thresholds."); + return BLIS_FAILURE; + } + } +#else thresh_func_ft func_fp; func_fp = bli_cntx_get_l3_thresh_func(BLIS_GEMM, cntx); @@ -110,26 +138,7 @@ err_t bli_gemmsup AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_2, "SUP - Sizes are beyond SUP thresholds."); return BLIS_FAILURE; } - - // Initialize a local runtime with global settings if necessary. Note - // that in the case that a runtime is passed in, we make a local copy. - rntm_t rntm_l; - if ( rntm == NULL ) { bli_rntm_init_from_global( &rntm_l ); rntm = &rntm_l; } - else { rntm_l = *rntm; rntm = &rntm_l; } - -#if 0 -const num_t dt = bli_obj_dt( c ); -const dim_t m = bli_obj_length( c ); -const dim_t n = bli_obj_width( c ); -const dim_t k = bli_obj_width_after_trans( a ); -const dim_t tm = bli_cntx_get_l3_sup_thresh_dt( dt, BLIS_MT, cntx ); -const dim_t tn = bli_cntx_get_l3_sup_thresh_dt( dt, BLIS_NT, cntx ); -const dim_t tk = bli_cntx_get_l3_sup_thresh_dt( dt, BLIS_KT, cntx ); - -printf( "dims: %d %d %d (threshs: %d %d %d)\n", - (int)m, (int)n, (int)k, (int)tm, (int)tn, (int)tk ); #endif - // We've now ruled out the following two possibilities: // - the ukernel prefers the operation as-is, and the sup thresholds are // unsatisfied. diff --git a/frame/3/bli_l3_sup_int_amd.c b/frame/3/bli_l3_sup_int_amd.c index dc2ce24d2d..e00cc54ad0 100644 --- a/frame/3/bli_l3_sup_int_amd.c +++ b/frame/3/bli_l3_sup_int_amd.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2019-21, Advanced Micro Devices, Inc. All rights reserved. 
+ Copyright (C) 2019-22, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -98,7 +98,8 @@ err_t bli_gemmsup_int // If the parallel thread factorization was automatic, we update it // with a new factorization based on the matrix dimensions in units - // of micropanels. + // of micropanels. However in case smart threading is enabled, + // auto_factor will be false. if ( auto_factor ) { // In the block-panel algorithm, the m dimension is parallelized diff --git a/frame/base/bli_rntm.c b/frame/base/bli_rntm.c index c15650e918..5908471cb2 100644 --- a/frame/base/bli_rntm.c +++ b/frame/base/bli_rntm.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -681,4 +681,84 @@ void bli_nthreads_optimum( return; } + +// Calculates the optimum number of threads along with the factorization +// (ic, jc) using m, n, k dimensions. This function modifies only the local +// copy of rntm with optimum threads. Since global rntm remains unchanged the +// num_threads set by application is available in global_rntm data structure. +err_t bli_smart_threading_sup + ( + obj_t* a, + obj_t* b, + obj_t* c, + opid_t family, + rntm_t* rntm, + cntx_t* cntx + ) +{ + // By default smart threading should be disabled. + err_t ret_val = BLIS_FAILURE; + +#ifndef BLIS_ENABLE_MULTITHREADING + return ret_val; +#endif + + dim_t n_threads = bli_rntm_num_threads( rntm ); + + // For non-openmp based threading, n_threads could be -1. + if ( ( n_threads == -1 ) || ( n_threads == 1 ) ) return ret_val; + + dim_t ic_way = bli_rntm_ic_ways( rntm ); + dim_t jc_way = bli_rntm_jc_ways( rntm ); + + // Dont enable smart threading if the user supplied the factorization. + if( ( ic_way > 0 ) || ( jc_way > 0 ) ) return ret_val; + + // Only supporting sgemm for now. + if ( ( family == BLIS_GEMM ) && bli_obj_is_float( c ) ) + { + dim_t k = bli_obj_width_after_trans(a); + dim_t m = 0; + dim_t n = 0; + + bool trans_A_for_kernel = FALSE; + + const stor3_t stor_id = bli_obj_stor3_from_strides( c, a, b ); + const bool is_rrr_rrc_rcr_crr = ( + stor_id == BLIS_RRR || + stor_id == BLIS_RRC || + stor_id == BLIS_RCR || + stor_id == BLIS_CRR + ); + + // The A and B matrices are swapped based on the storage type in + // var1n2m. Need to account for this when determining ic and jc + // based on m and n dimensions of A and B. + if ( is_rrr_rrc_rcr_crr ) + { + m = bli_obj_length( c ); + n = bli_obj_width( c ); + trans_A_for_kernel = bli_obj_has_trans( a ); + } + else + { + m = bli_obj_width( c ); + n = bli_obj_length( c ); + trans_A_for_kernel = bli_obj_has_trans( b ); + } + + // Take default path if transpose is enabled for A matrix. + if ( trans_A_for_kernel == FALSE ) + { + // A successfull call to smart threading api implies smart + // factorization and possibly native -> SUP path conversion. + // Optimal thread selection is not supported yet. 
+ ret_val = bli_gemm_smart_threading_sup( bli_obj_dt( c ), + bli_obj_elem_size( c ), + is_rrr_rrc_rcr_crr, m, n, k, n_threads, + cntx, rntm ); + } + } + return ret_val; +} #endif // AOCL_DYNAMIC diff --git a/frame/base/bli_rntm.h b/frame/base/bli_rntm.h index 5e8e236af6..e28463c5ab 100644 --- a/frame/base/bli_rntm.h +++ b/frame/base/bli_rntm.h @@ -6,7 +6,7 @@ Copyright (C) 2014, The University of Texas at Austin Copyright (C) 2016, Hewlett Packard Enterprise Development LP - Copyright (C) 2018 - 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -400,6 +400,16 @@ void bli_nthreads_optimum opid_t family, rntm_t* rntm ); + +err_t bli_smart_threading_sup + ( + obj_t* a, + obj_t* b, + obj_t* c, + opid_t family, + rntm_t* rntm, + cntx_t* cntx + ); #endif #endif From 2acb3f6ed000274afab634683dd47bded4959662 Mon Sep 17 00:00:00 2001 From: Nallani Bhaskar Date: Thu, 28 Apr 2022 15:52:06 +0530 Subject: [PATCH 111/243] Tuned aocl dynamic for specific range in dgemm Description: 1. Decision logic to choose optimal number of threads for given input dgemm dimensions under aocl dynamic feature were retuned based on latest code. 2. Updated code in few file to avoid compilation warnings. 3. Added a min check for nt in bli_sgemv_var1_smart_threading function AMD-Internal: [ CPUPL-2100 ] Change-Id: I2bc70cc87c73505dd5d2bdafb06193f664760e02 --- bench/bench_ger.c | 9 +- frame/2/gemv/bli_gemv_unf_var1_amd.c | 7 +- frame/base/bli_rntm.c | 324 ++++++++++++++------------- kernels/zen/1f/bli_axpyf_zen_int_5.c | 1 - kernels/zen/bli_kernels_zen.h | 17 ++ 5 files changed, 195 insertions(+), 163 deletions(-) diff --git a/bench/bench_ger.c b/bench/bench_ger.c index f6e5b27f59..fb50c94265 100644 --- a/bench/bench_ger.c +++ b/bench/bench_ger.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021-22, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -66,7 +66,6 @@ int main( int argc, char** argv ) dim_t p_inc = 0; // to keep track of number of inputs num_t dt; char dt_ch; - char stor_scheme; int r, n_repeats; double dtime; @@ -76,6 +75,10 @@ int main( int argc, char** argv ) FILE* fin = NULL; FILE* fout = NULL; +#ifdef CBLAS + char stor_scheme; +#endif + n_repeats = N_REPEAT; // This macro will get from Makefile. dt = DT; @@ -108,7 +111,9 @@ int main( int argc, char** argv ) inc_t incy; char tmp[256]; // to store function name, line no present in logs. 
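
The warning fixes in this benchmark follow a simple pattern: a variable that only the CBLAS interface needs is declared and used under the same preprocessor guard, so the plain-BLAS build never sees an unused variable. A minimal stand-alone version of the pattern (a hypothetical program, not the bench code itself):

    #include <stdio.h>

    int main( void )
    {
    #ifdef CBLAS
        /* Only the CBLAS benchmark path cares about the storage scheme. */
        char stor_scheme = 'C';
        printf( "CBLAS build, storage scheme: %c\n", stor_scheme );
    #else
        printf( "BLAS build; no storage-scheme flag is needed.\n" );
    #endif
        return 0;
    }
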
+#ifdef CBLAS stor_scheme = 'C'; +#endif // {S,D,C,Z} {transa m n alpha incx incy lda} diff --git a/frame/2/gemv/bli_gemv_unf_var1_amd.c b/frame/2/gemv/bli_gemv_unf_var1_amd.c index 8295f3927e..fd399c6f84 100644 --- a/frame/2/gemv/bli_gemv_unf_var1_amd.c +++ b/frame/2/gemv/bli_gemv_unf_var1_amd.c @@ -412,10 +412,11 @@ void bli_sgemv_var1_smart_threading } } + // When the number of threads calculated is greater than the user provided value // Choose the user provided value - if(*nt > nt_max) - *nt = nt_max; + if(*nt > nt_max ) *nt = nt_max; + if(*nt <=0 ) *nt = 1; } void bli_sgemv_unf_var1 @@ -434,7 +435,6 @@ void bli_sgemv_unf_var1 { float* A1; - float* x1; float* y1; dim_t i; dim_t b_fuse, f; @@ -537,6 +537,7 @@ void bli_sgemv_unf_var1 for ( i = 0; i < n_iter; i += f ) { + float* x1; f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); A1 = a + (i )*rs_at + (0 )*cs_at; diff --git a/frame/base/bli_rntm.c b/frame/base/bli_rntm.c index 5908471cb2..f8d48c4a2e 100644 --- a/frame/base/bli_rntm.c +++ b/frame/base/bli_rntm.c @@ -53,7 +53,7 @@ void bli_rntm_init_from_global( rntm_t* rntm ) // or the latest value of number of threads, // if set by the Application using omp_set_num_threads(nt) API. #ifdef BLIS_ENABLE_OPENMP - dim_t n_threads = omp_get_max_threads(); + dim_t n_threads = omp_get_max_threads(); #endif // Acquire the mutex protecting global_rntm. @@ -63,7 +63,7 @@ void bli_rntm_init_from_global( rntm_t* rntm ) // before copying into local rntm structure. This updated value will be // used in the subsequent parallel regions. #ifdef BLIS_ENABLE_OPENMP - global_rntm.num_threads = n_threads; + global_rntm.num_threads = n_threads; #endif *rntm = global_rntm; @@ -75,14 +75,14 @@ void bli_rntm_init_from_global( rntm_t* rntm ) // ----------------------------------------------------------------------------- void bli_rntm_set_ways_for_op - ( - opid_t l3_op, - side_t side, - dim_t m, - dim_t n, - dim_t k, - rntm_t* rntm - ) + ( + opid_t l3_op, + side_t side, + dim_t m, + dim_t n, + dim_t k, + rntm_t* rntm + ) { // Set the number of ways for each loop, if needed, depending on what // kind of information is already stored in the rntm_t object. @@ -95,7 +95,7 @@ bli_rntm_print( rntm ); // Now modify the number of ways, if necessary, based on the operation. 
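
The per-operation adjustment that follows starts from an ic x jc factorization of the thread count; at its core this is a 2x2 split that tries to keep each thread's tile of C proportioned like the whole matrix (bli_thread_partition_2x2() plays this role in BLIS). Below is a stand-alone sketch of one simple way to choose such a split; partition_2x2_sketch and its scoring rule are illustrative, not the library's exact heuristic.

    typedef long dim_t;   /* stand-in for the BLIS integer type */

    /* Split nt into ic x jc so that (m/ic) : (n/jc) stays close to 1,
       i.e. per-thread tiles of C are as square as possible. */
    static void partition_2x2_sketch
          ( dim_t nt, dim_t m, dim_t n, dim_t* ic, dim_t* jc )
    {
        dim_t  best_ic = 1, best_jc = nt;
        double best_score = -1.0;

        for ( dim_t i = 1; i <= nt; ++i )
        {
            if ( nt % i != 0 ) continue;       /* exact factorizations only */
            const dim_t j = nt / i;

            /* Ratio of per-thread tile sides; closer to 1.0 is better. */
            const double mi    = (double)m / (double)i;
            const double nj    = (double)n / (double)j;
            const double score = ( mi < nj ? mi / nj : nj / mi );

            if ( score > best_score )
            {
                best_score = score;
                best_ic    = i;
                best_jc    = j;
            }
        }
        *ic = best_ic;
        *jc = best_jc;
    }
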
if ( l3_op == BLIS_TRMM || - l3_op == BLIS_TRSM ) + l3_op == BLIS_TRSM ) { dim_t jc = bli_rntm_jc_ways( rntm ); dim_t pc = bli_rntm_pc_ways( rntm ); @@ -169,12 +169,12 @@ bli_rntm_print( rntm ); } void bli_rntm_set_ways_from_rntm - ( - dim_t m, - dim_t n, - dim_t k, - rntm_t* rntm - ) + ( + dim_t m, + dim_t n, + dim_t k, + rntm_t* rntm + ) { dim_t nt = bli_rntm_num_threads( rntm ); @@ -252,7 +252,7 @@ void bli_rntm_set_ways_from_rntm pc = 1; bli_thread_partition_2x2( nt, m*BLIS_THREAD_RATIO_M, - n*BLIS_THREAD_RATIO_N, &ic, &jc ); + n*BLIS_THREAD_RATIO_N, &ic, &jc ); for ( ir = BLIS_THREAD_MAX_IR ; ir > 1 ; ir-- ) { @@ -290,12 +290,12 @@ void bli_rntm_set_ways_from_rntm } void bli_rntm_set_ways_from_rntm_sup - ( - dim_t m, - dim_t n, - dim_t k, - rntm_t* rntm - ) + ( + dim_t m, + dim_t n, + dim_t k, + rntm_t* rntm + ) { dim_t nt = bli_rntm_num_threads( rntm ); @@ -373,9 +373,9 @@ void bli_rntm_set_ways_from_rntm_sup pc = 1; //bli_thread_partition_2x2( nt, m*BLIS_THREAD_SUP_RATIO_M, - // n*BLIS_THREAD_SUP_RATIO_N, &ic, &jc ); + // n*BLIS_THREAD_SUP_RATIO_N, &ic, &jc ); bli_thread_partition_2x2( nt, m, - n, &ic, &jc ); + n, &ic, &jc ); //printf( "bli_rntm_set_ways_from_rntm_sup(): jc = %d ic = %d\n", (int)jc, (int)ic ); #if 0 @@ -420,9 +420,9 @@ void bli_rntm_set_ways_from_rntm_sup } void bli_rntm_print - ( - rntm_t* rntm - ) + ( + rntm_t* rntm + ) { dim_t af = bli_rntm_auto_factor( rntm ); @@ -434,35 +434,35 @@ void bli_rntm_print dim_t jr = bli_rntm_jr_ways( rntm ); dim_t ir = bli_rntm_ir_ways( rntm ); - printf( "rntm contents nt jc pc ic jr ir\n" ); + printf( "rntm contents nt jc pc ic jr ir\n" ); printf( "autofac? %1d | %4d%4d%4d%4d%4d%4d\n", (int)af, - (int)nt, (int)jc, (int)pc, - (int)ic, (int)jr, (int)ir ); + (int)nt, (int)jc, (int)pc, + (int)ic, (int)jr, (int)ir ); } // ----------------------------------------------------------------------------- dim_t bli_rntm_calc_num_threads_in - ( - bszid_t* restrict bszid_cur, - rntm_t* restrict rntm - ) + ( + bszid_t* restrict bszid_cur, + rntm_t* restrict rntm + ) { - /* // bp algorithm: - bszid_t bszids[7] = { BLIS_NC, // level 0: 5th loop - BLIS_KC, // level 1: 4th loop + /* // bp algorithm: + bszid_t bszids[7] = { BLIS_NC, // level 0: 5th loop + BLIS_KC, // level 1: 4th loop BLIS_NO_PART, // level 2: pack B - BLIS_MC, // level 3: 3rd loop + BLIS_MC, // level 3: 3rd loop BLIS_NO_PART, // level 4: pack A - BLIS_NR, // level 5: 2nd loop - BLIS_MR, // level 6: 1st loop - BLIS_KR // level 7: ukr loop - - ... // pb algorithm: - BLIS_NR, // level 5: 2nd loop - BLIS_MR, // level 6: 1st loop - BLIS_KR // level 7: ukr loop - }; */ + BLIS_NR, // level 5: 2nd loop + BLIS_MR, // level 6: 1st loop + BLIS_KR // level 7: ukr loop + + ... // pb algorithm: + BLIS_NR, // level 5: 2nd loop + BLIS_MR, // level 6: 1st loop + BLIS_KR // level 7: ukr loop + }; */ dim_t n_threads_in = 1; // Starting with the current element of the bszids array (pointed @@ -491,7 +491,7 @@ dim_t bli_rntm_calc_num_threads_in for ( ; *bszid_cur != BLIS_KR; bszid_cur++ ) { const bszid_t bszid = *bszid_cur; - dim_t cur_way = 1; + dim_t cur_way = 1; // We assume bszid is in {NC,KC,MC,NR,MR,KR} if it is not // BLIS_NO_PART. @@ -512,12 +512,12 @@ dim_t bli_rntm_calc_num_threads_in //application is available in global_rntm data structure. 
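
The function defined next is consumed by the level-3 front-ends under AOCL_DYNAMIC; the call site looks roughly like the fragment below, which combines the local-rntm pattern from bli_gemmsup() earlier in this series with the hook added to bli_trmm_front() later in it. It is a simplified, non-compilable excerpt: only the local copy of the rntm is tuned, so the thread count the application set in global_rntm is left untouched.

    /* Inside a level-3 front-end: make a local copy of the runtime ... */
    rntm_t rntm_l;
    if ( rntm == NULL ) { bli_rntm_init_from_global( &rntm_l ); rntm = &rntm_l; }
    else                { rntm_l = *rntm;                       rntm = &rntm_l; }

    #ifdef AOCL_DYNAMIC
    /* ... and tune only that copy; global_rntm still holds whatever the
       application requested via BLIS_NUM_THREADS or omp_set_num_threads(). */
    if ( bli_obj_is_double( b ) )
        bli_nthreads_optimum( a, b, b, BLIS_TRMM, rntm );
    #endif
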
void bli_nthreads_optimum( - obj_t* a, - obj_t* b, - obj_t* c, - opid_t family, - rntm_t* rntm - ) + obj_t* a, + obj_t* b, + obj_t* c, + opid_t family, + rntm_t* rntm + ) { #ifndef BLIS_ENABLE_MULTITHREADING return; @@ -531,105 +531,112 @@ void bli_nthreads_optimum( if( family == BLIS_GEMM && bli_obj_is_double(c)) { - dim_t m = bli_obj_length(c); dim_t n = bli_obj_width(c); dim_t k = bli_obj_width_after_trans(a); - if( k >= 128) { - if(n <= 15) n_threads_ideal = 8; - else n_threads_ideal = 16; + if(n <= 15) + { + if(m < 128) n_threads_ideal = 8; + else if(m < 256) n_threads_ideal = 16; + else if(m < 512) n_threads_ideal = 32; + else n_threads_ideal = 64; + }else if (n <= 64) + { + if(m < 128) n_threads_ideal = 16; + else if(m < 256) n_threads_ideal = 32; + else n_threads_ideal = 64; + }else{ + if(m < 256) n_threads_ideal = 32; + else n_threads_ideal = 64; + } } else - { - if(m > 10000) - { - - /* if(n >= 96) n_threads_ideal = 16; */ - /* else n_threads_ideal = 8; */ - - // current logic is only limiting threads to - // less or equal to 64 - limits performance. - - // To deal with larger matrix sizes we need to use - // large number of threads to improve performance - - // Need to derive this upperTH - and - // if matrix -sizes are larger and user wants - // to use higher number of threads - that should be allowed. - - // if (n > UpperTH) n_threads_ideal = n_threads; - if (n > 200 ) n_threads_ideal = 64; - else if ( n > 120 ) n_threads_ideal = 32; - else if ( n > 40 ) n_threads_ideal = 16; - else if ( n > 10 ) n_threads_ideal = 8; - else /* if ( n <= 10) */ n_threads_ideal = 4; - } - else if( m > 1000) - { - if (n <= 10) n_threads_ideal = 4; - else if ( n <= 40 ) n_threads_ideal = 8; - else if ( n <= 120 ) n_threads_ideal = 16; - else if ( n <= 200 ) n_threads_ideal = 32; - else n_threads_ideal = 64; - - /* if(n < 15) n_threads_ideal = 4; */ - /* else n_threads_ideal = 8; */ - } - else if(m > 210) - { - if(n < 10) n_threads_ideal = 1; - else n_threads_ideal = 4; - } - else if(m > 150) - { - if(n < 15) n_threads_ideal = 1; - else n_threads_ideal = 4; - } - else if( ( m < 34) && (k < 68) && ( m < 34)) - { - n_threads_ideal = 1; - } - else - { - if(n < 20) n_threads_ideal = 1; - else n_threads_ideal = 4; - } + { + if(m > 10000) + { + // current logic is only limiting threads to + // less or equal to 64 - limits performance. + // To deal with larger matrix sizes we need to use + // large number of threads to improve performance + // Need to derive this upperTH - and + // if matrix -sizes are larger and user wants + // to use higher number of threads - that should be allowed. 
+ + // if (n > UpperTH) n_threads_ideal = n_threads; + if (n > 200 ) n_threads_ideal = 64; + else if ( n > 120 ) n_threads_ideal = 32; + else if ( n > 40 ) n_threads_ideal = 16; + else if ( n > 10 ) n_threads_ideal = 8; + else n_threads_ideal = 4; + } + else if( m > 1000) + { + if (n <= 10) n_threads_ideal = 4; + else if ( n <= 512 ) n_threads_ideal = 8; + else if ( n <= 1024 ) n_threads_ideal = 16; + else if ( n <= 2048 ) n_threads_ideal = 32; + else n_threads_ideal = 64; + } + else if(m > 210) + { + if(n < 10) n_threads_ideal = 4; + else if(n <= 512) n_threads_ideal = 8; + else if(n <= 1024) n_threads_ideal = 16; + else if(n <= 2048) n_threads_ideal = 32; + else n_threads_ideal = 64; + } + else if(m > 150) + { + if(n < 10) n_threads_ideal = 2; + else if(n <= 512) n_threads_ideal = 8; + else if(n <= 1024) n_threads_ideal = 16; + else if(n <= 2048) n_threads_ideal = 32; + else n_threads_ideal = 64; + } + else if( ( m < 34) && (k < 68) && ( n < 34)) + { + n_threads_ideal = 1; + } + else + { //(m<150 && k<128) + if(n < 20) n_threads_ideal = 1; + if(n < 64) n_threads_ideal = 4; + else n_threads_ideal = 8; + } } - } else if( family == BLIS_GEMM && bli_obj_is_dcomplex(c)) - { - - dim_t m = bli_obj_length(c); - dim_t n = bli_obj_width(c); - dim_t k = bli_obj_width_after_trans(a); - - if((m<=128 || n<=128 || k<=128) && (m+n+k <= 400) ) - { - n_threads_ideal = 8; - } - else if((m<=256 || n<=256 || k<=256) && (m+n+k <= 800) ) - { - n_threads_ideal = 16; - } - } + { + dim_t m = bli_obj_length(c); + dim_t n = bli_obj_width(c); + dim_t k = bli_obj_width_after_trans(a); + + if((m<=128 || n<=128 || k<=128) && ((m+n+k) <= 400) ) + { + n_threads_ideal = 8; + } + else if((m<=256 || n<=256 || k<=256) && ((m+n+k) <= 800) ) + { + n_threads_ideal = 16; + } + } else if( family == BLIS_SYRK && bli_obj_is_double(c)) { - dim_t n = bli_obj_length(c); - dim_t k = bli_obj_width_after_trans(a); - - if( (( n <= 10) && ( k < 700)) || - (( n <= 20) && ( k <= 190)) || - (( n <= 40) && ( k <= 80)) || - (( n <= 50) && ( k <= 40)) || - (( n <= 60) && ( k <= 20)) - ) - n_threads_ideal = 1; - else - n_threads_ideal = n_threads; + dim_t n = bli_obj_length(c); + dim_t k = bli_obj_width_after_trans(a); + + if( (( n <= 10) && ( k < 700)) || + (( n <= 20) && ( k <= 190)) || + (( n <= 40) && ( k <= 80)) || + (( n <= 50) && ( k <= 40)) || + (( n <= 60) && ( k <= 20)) + ) + n_threads_ideal = 1; + else + n_threads_ideal = n_threads; } else if( family == BLIS_TRSM && bli_obj_is_double(c) ) { @@ -637,31 +644,34 @@ void bli_nthreads_optimum( dim_t n = bli_obj_width(c); #ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM - if ( (m <= 300) && (n <= 300) ) - n_threads_ideal = 8; - else if ( (m <= 400) && (n <= 400) ) - n_threads_ideal = 16; - else if ( (m <= 900) && (n <= 900) ) - n_threads_ideal = 32; + if ( (m <= 300) && (n <= 300) ) + n_threads_ideal = 8; + else if ( (m <= 400) && (n <= 400) ) + n_threads_ideal = 16; + else if ( (m <= 900) && (n <= 900) ) + n_threads_ideal = 32; #else - if ( (m <= 512) && (n <= 512) ) - n_threads_ideal = 4; + if ( (m <= 512) && (n <= 512) ) + n_threads_ideal = 4; #endif } else if( family == BLIS_TRSM && bli_obj_is_dcomplex(c)) - { - dim_t m = bli_obj_length(c); - dim_t n = bli_obj_width(c); + { + dim_t m = bli_obj_length(c); + dim_t n = bli_obj_width(c); - if((m>=64) && (m<=256) && (n>=64) && (n<=256)) - n_threads_ideal = 8; - } + if((m>=64) && (m<=256) && (n>=64) && (n<=256)) + { + n_threads_ideal = 8; + } + } else if( family == BLIS_GEMMT && bli_obj_is_double(c) ) { dim_t n = bli_obj_length(c); dim_t k = 
bli_obj_width_after_trans(a); dim_t product = (n*k)>>4; /* product is derived based on n and k */ - // Limit the number thread for smaller sizes: + + //Limit the number thread for smaller sizes: if(product <= 346) { n_threads_ideal = 1; diff --git a/kernels/zen/1f/bli_axpyf_zen_int_5.c b/kernels/zen/1f/bli_axpyf_zen_int_5.c index 8b1f697cec..8fea5f6498 100644 --- a/kernels/zen/1f/bli_axpyf_zen_int_5.c +++ b/kernels/zen/1f/bli_axpyf_zen_int_5.c @@ -325,7 +325,6 @@ void bli_daxpyf_zen_int_5 const dim_t fuse_fac = 5; const dim_t n_elem_per_reg = 4; - const dim_t n_iter_unroll = 2; dim_t i; diff --git a/kernels/zen/bli_kernels_zen.h b/kernels/zen/bli_kernels_zen.h index 4bba0b22f0..5444c90ea8 100644 --- a/kernels/zen/bli_kernels_zen.h +++ b/kernels/zen/bli_kernels_zen.h @@ -341,6 +341,22 @@ err_t bli_trsm_small_mt cntx_t* cntx, cntl_t* cntl ); + +void bli_multi_sgemv_4x2 + ( + conj_t conjat, + conj_t conjx, + dim_t m, + dim_t b_n, + float* restrict alpha, + float* restrict a, inc_t inca, inc_t lda, + float* restrict x, inc_t incx, + float* restrict beta, + float* restrict y, inc_t incy, + cntx_t* restrict cntx, + dim_t n_threads + ); + #endif // threshold functions @@ -369,3 +385,4 @@ void bli_dnorm2fv_unb_var1 cntx_t* cntx ); #endif + From 09b70de6352ea5fa3b621d4b60320e37d9cd08a9 Mon Sep 17 00:00:00 2001 From: satish kumar nuggu Date: Fri, 29 Apr 2022 17:13:29 +0530 Subject: [PATCH 112/243] Performance Improvement for ztrsm small sizes Details: - Handled Overflow and Underflow Vulnerabilites in ztrsm small right implementations. - Fixed failures observed in Scalapack testing. AMD-Internal: [CPUPL-2115] Change-Id: I22c1ba583e0ba14d1a4684a85fa1ca6e152e8439 --- kernels/zen/3/bli_trsm_small.c | 121 ++++++++++----------------------- 1 file changed, 35 insertions(+), 86 deletions(-) diff --git a/kernels/zen/3/bli_trsm_small.c b/kernels/zen/3/bli_trsm_small.c index f8c0ea5911..d7192a062b 100644 --- a/kernels/zen/3/bli_trsm_small.c +++ b/kernels/zen/3/bli_trsm_small.c @@ -34922,38 +34922,21 @@ BLIS_INLINE err_t bli_ztrsm_small_XAutB_XAlB { if(transa) { - ymm0 = _mm256_broadcast_pd((__m128d const *)(a11)); - ymm1 = _mm256_broadcast_pd((__m128d const *) - (a11+cs_a*1 + 1)); + ztrsm_small_pack_diag_element(is_unitdiag,a11,cs_a, + d11_pack,n_remainder); } else { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_pd((__m128d const *)(a11)); - ymm1 = _mm256_broadcast_pd((__m128d const *) - (a11+rs_a*1 + 1)); + ztrsm_small_pack_diag_element(is_unitdiag,a11,rs_a, + d11_pack,n_remainder); } - ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); -#ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm7 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - /*Taking denomerator multiplication of real & - * imaginary components*/ - ymm4 = _mm256_mul_pd(ymm1, ymm1); - /*Swapping real & imaginary component position for addition with - * respective components*/ - ymm6 = _mm256_permute4x64_pd(ymm4, 0xb1); - ymm4 = _mm256_add_pd(ymm4, ymm6); - /*Negating imaginary component of numerator*/ - ymm1 = _mm256_mul_pd(ymm1, ymm7); - /*Dividing numerator by denominator*/ - ymm1 = _mm256_div_pd(ymm1, ymm4); -#endif } else { ymm1 = _mm256_broadcast_pd((__m128d const*)&ones); + _mm256_storeu_pd((double *)(d11_pack), ymm1); } - _mm256_storeu_pd((double *)(d11_pack), ymm1); + for(i = (m-d_mr); (i+1) > 0; i -= d_mr) //loop along 'M' direction { a01 = D_A_pack; @@ -35340,30 +35323,23 @@ BLIS_INLINE err_t bli_ztrsm_small_XAutB_XAlB } if(!is_unitdiag) { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_pd((__m128d const *)(a11)); - ymm1 = 
_mm256_broadcast_pd((__m128d const *)&ones); - ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); -#ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm7 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - /*Taking denomerator multiplication of real & - * imaginary components*/ - ymm4 = _mm256_mul_pd(ymm1, ymm1); - /*Swapping real & imaginary component position for addition with - * respective components*/ - ymm6 = _mm256_permute4x64_pd(ymm4, 0xb1); - ymm4 = _mm256_add_pd(ymm4, ymm6); - /*Negating imaginary component of numerator*/ - ymm1 = _mm256_mul_pd(ymm1, ymm7); - /*Dividing numerator by denominator*/ - ymm1 = _mm256_div_pd(ymm1, ymm4); -#endif + if(transa) + { + ztrsm_small_pack_diag_element(is_unitdiag,a11,cs_a, + d11_pack,n_remainder); + } + else + { + ztrsm_small_pack_diag_element(is_unitdiag,a11,rs_a, + d11_pack,n_remainder); + } } else { ymm1 = _mm256_broadcast_pd((__m128d const*)&ones); + _mm256_storeu_pd((double *)(d11_pack), ymm1); } - _mm256_storeu_pd((double *)(d11_pack), ymm1); + for(i = (m-d_mr); (i+1) > 0; i -= d_mr) //loop along 'M' direction { a01 = D_A_pack; @@ -36374,39 +36350,20 @@ BLIS_INLINE err_t bli_ztrsm_small_XAltB_XAuB { if(transa) { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_pd((__m128d const *)(a11)); - ymm1 = _mm256_broadcast_pd((__m128d const *) - (a11+cs_a*1 + 1)); + ztrsm_small_pack_diag_element(is_unitdiag,a11,cs_a, + d11_pack,n_remainder); } else { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_pd((__m128d const *)(a11)); - ymm1 = _mm256_broadcast_pd((__m128d const *) - (a11+rs_a*1 + 1)); + ztrsm_small_pack_diag_element(is_unitdiag,a11,rs_a, + d11_pack,n_remainder); } - ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); -#ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm7 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - /*Taking denomerator multiplication of real & - * imaginary components*/ - ymm4 = _mm256_mul_pd(ymm1, ymm1); - /*Swapping real & imaginary component position for addition with - * respective components*/ - ymm6 = _mm256_permute4x64_pd(ymm4, 0xb1); - ymm4 = _mm256_add_pd(ymm4, ymm6); - /*Negating imaginary component of numerator*/ - ymm1 = _mm256_mul_pd(ymm1, ymm7); - /*Dividing numerator by denominator*/ - ymm1 = _mm256_div_pd(ymm1, ymm4); -#endif } else { ymm1 = _mm256_broadcast_pd((__m128d const *)&ones); + _mm256_storeu_pd((double *)(d11_pack), ymm1); } - _mm256_storeu_pd((double *)(d11_pack), ymm1); for(i = 0; (i+d_mr-1) < m; i += d_mr) //loop along 'M' direction { @@ -36793,30 +36750,22 @@ BLIS_INLINE err_t bli_ztrsm_small_XAltB_XAuB } if(!is_unitdiag) { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_pd((__m128d const *)(a11)); - ymm1 = _mm256_broadcast_pd((__m128d const *)&ones); - ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); -#ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm7 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - /*Taking denomerator multiplication of real & - * imaginary components*/ - ymm4 = _mm256_mul_pd(ymm1, ymm1); - /*Swapping real & imaginary component position for addition with - * respective components*/ - ymm6 = _mm256_permute4x64_pd(ymm4, 0xb1); - ymm4 = _mm256_add_pd(ymm4, ymm6); - /*Negating imaginary component of numerator*/ - ymm1 = _mm256_mul_pd(ymm1, ymm7); - /*Dividing numerator by denominator*/ - ymm1 = _mm256_div_pd(ymm1, ymm4); -#endif + if(transa) + { + ztrsm_small_pack_diag_element(is_unitdiag,a11,cs_a, + d11_pack,n_remainder); + } + else + { + ztrsm_small_pack_diag_element(is_unitdiag,a11,rs_a, + d11_pack,n_remainder); + } } else { ymm1 = _mm256_broadcast_pd((__m128d const *)&ones); - } - 
_mm256_storeu_pd((double *)(d11_pack), ymm1); + _mm256_storeu_pd((double *)(d11_pack), ymm1); + } for(i = 0; (i+d_mr-1) < m; i += d_mr) //loop along 'M' direction { From 4ccb438c18a33063c7dd9dc9555fb9c0e7ce30a2 Mon Sep 17 00:00:00 2001 From: Dipal M Zambare Date: Fri, 29 Apr 2022 10:51:55 +0530 Subject: [PATCH 113/243] Updated Zen3 architecture detection for Ryzen 5000 - Added support to detect Ryzen 5000 Desktop and APUs AMD-Internal: [CPUPL-2117] Change-Id: I312a7de1a84cf368b74ba20e58192803a9f7dace --- frame/base/bli_cpuid.c | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/frame/base/bli_cpuid.c b/frame/base/bli_cpuid.c index dfac510440..605d4c8089 100644 --- a/frame/base/bli_cpuid.c +++ b/frame/base/bli_cpuid.c @@ -328,8 +328,13 @@ bool bli_cpuid_is_zen3 // we check for all of them. const bool is_arch = - (( model <= 0x0f ) || - (0x30 <= model && model <= 0x3f )); + ( + ( model <= 0x0f ) || // EPYC and ThreadRipper + ( 0x20 <= model && model <= 0x2f ) || // Ryzen 5000 Desktop + ( 0x30 <= model && model <= 0x3f ) || // Trento + ( 0x40 <= model && model <= 0x4f ) || // RMB + ( 0x50 <= model && model <= 0x5f ) // Ryzen 5000 APU + ); if ( !is_arch ) return FALSE; From 7658067107786f9f3deb578a075c39f25caf108a Mon Sep 17 00:00:00 2001 From: Nallani Bhaskar Date: Fri, 29 Apr 2022 23:29:20 +0530 Subject: [PATCH 114/243] Added AOCL Dynamic feature for dtrmm Description: 1. Tuned number of threads to achive better performance for dtrmm AMD-Internal: [ CPUPL-2100 ] Change-Id: Ib2e3df224ba76d86185721bef1837cd7855dd593 --- frame/3/trmm/CMakeLists.txt | 18 ++- frame/3/trmm/bli_trmm_front_amd.c | 206 ++++++++++++++++++++++++++++++ frame/base/bli_rntm.c | 93 ++++++++++++++ 3 files changed, 315 insertions(+), 2 deletions(-) create mode 100644 frame/3/trmm/bli_trmm_front_amd.c diff --git a/frame/3/trmm/CMakeLists.txt b/frame/3/trmm/CMakeLists.txt index 076d7d4a6b..a3845f3858 100644 --- a/frame/3/trmm/CMakeLists.txt +++ b/frame/3/trmm/CMakeLists.txt @@ -1,12 +1,26 @@ -##Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.## +##Copyright (C) 2020-22, Advanced Micro Devices, Inc. All rights reserved.## target_sources("${PROJECT_NAME}" PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/bli_trmm_front.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_trmm_ll_ker_var2.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_trmm_lu_ker_var2.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_trmm_rl_ker_var2.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_trmm_ru_ker_var2.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_trmm_xx_ker_var2.c ) +# Select AMD specific sources for AMD configurations. +if(${TARGET_ARCH} STREQUAL zen OR +${TARGET_ARCH} STREQUAL zen2 OR +${TARGET_ARCH} STREQUAL zen3 OR +${TARGET_ARCH} STREQUAL amdzen) + target_sources("${PROJECT_NAME}" + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/bli_trmm_front_amd.c + ) +else() + target_sources("${PROJECT_NAME}" + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/bli_trmm_front.c + ) +endif() diff --git a/frame/3/trmm/bli_trmm_front_amd.c b/frame/3/trmm/bli_trmm_front_amd.c new file mode 100644 index 0000000000..2301b323a7 --- /dev/null +++ b/frame/3/trmm/bli_trmm_front_amd.c @@ -0,0 +1,206 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2022, Advanced Micro Devices, Inc. 
+ + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +void bli_trmm_front + ( + side_t side, + obj_t* alpha, + obj_t* a, + obj_t* b, + cntx_t* cntx, + rntm_t* rntm, + cntl_t* cntl + ) +{ + bli_init_once(); + + obj_t a_local; + obj_t b_local; + obj_t c_local; + + // Check parameters. + if ( bli_error_checking_is_enabled() ) + bli_trmm_check( side, alpha, a, b, &BLIS_ZERO, b, cntx ); + + // If alpha is zero, scale by beta and return. + if ( bli_obj_equals( alpha, &BLIS_ZERO ) ) + { + bli_scalm( alpha, b ); + return; + } + + // Alias A and B so we can tweak the objects if necessary. + bli_obj_alias_to( a, &a_local ); + bli_obj_alias_to( b, &b_local ); + bli_obj_alias_to( b, &c_local ); + + // We do not explicitly implement the cases where A is transposed. + // However, we can still handle them. Specifically, if A is marked as + // needing a transposition, we simply induce a transposition. This + // allows us to only explicitly implement the no-transpose cases. Once + // the transposition is induced, the correct algorithm will be called, + // since, for example, an algorithm over a transposed lower triangular + // matrix A moves in the same direction (forwards) as a non-transposed + // upper triangular matrix. And with the transposition induced, the + // matrix now appears to be upper triangular, so the upper triangular + // algorithm will grab the correct partitions, as if it were upper + // triangular (with no transpose) all along. + if ( bli_obj_has_trans( &a_local ) ) + { + bli_obj_induce_trans( &a_local ); + bli_obj_set_onlytrans( BLIS_NO_TRANSPOSE, &a_local ); + } + +#ifdef BLIS_DISABLE_TRMM_RIGHT + // NOTE: This case casts right-side trmm in terms of left side. This is + // necessary when the current subconfiguration uses a gemm microkernel + // that assumes that the packing kernel will have already duplicated + // (broadcast) element of B in the packed copy of B. Supporting + // duplication within the logic that packs micropanels from triangular + // matrices would be ugly, and so we simply don't support it. 
As a + // consequence, those subconfigurations need a way to force the triangular + // matrix to be on the left (and thus the general matrix to the on the + // right). So our solution is that in those cases, the subconfigurations + // simply #define BLIS_DISABLE_TRMM_RIGHT. + + // NOTE: This case casts right-side trmm in terms of left side. This can + // lead to the microkernel being executed on an output matrix with the + // microkernel's general stride IO case (unless the microkernel supports + // both both row and column IO cases as well). + + // NOTE: Casting right-side trmm in terms of left side reduces the number + // of macrokernels exercised to two (trmm_ll and trmm_lu). + + // If A is being multiplied from the right, transpose all operands + // so that we can perform the computation as if A were being multiplied + // from the left. + if ( bli_is_right( side ) ) + { + bli_toggle_side( &side ); + bli_obj_induce_trans( &a_local ); + bli_obj_induce_trans( &b_local ); + bli_obj_induce_trans( &c_local ); + } + +#else + // NOTE: This case computes right-side trmm natively with trmm_rl and + // trmm_ru macrokernels. This code path always gives us the opportunity + // to transpose the entire operation so that the effective storage format + // of the output matrix matches the microkernel's output preference. + // Thus, from a performance perspective, this case is preferred. + + // An optimization: If C is stored by rows and the micro-kernel prefers + // contiguous columns, or if C is stored by columns and the micro-kernel + // prefers contiguous rows, transpose the entire operation to allow the + // micro-kernel to access elements of C in its preferred manner. + // NOTE: We disable the optimization for 1x1 matrices since the concept + // of row- vs. column storage breaks down. + //if ( !bli_obj_is_1x1( &c_local ) ) // NOTE: This conditional should NOT + // be enabled. See issue #342 comments. + if ( bli_cntx_l3_vir_ukr_dislikes_storage_of( &c_local, BLIS_GEMM_UKR, cntx ) ) + { + bli_toggle_side( &side ); + bli_obj_induce_trans( &a_local ); + bli_obj_induce_trans( &b_local ); + bli_obj_induce_trans( &c_local ); + } + + // If A is being multiplied from the right, swap A and B so that + // the matrix will actually be on the right. + if ( bli_is_right( side ) ) + { + bli_obj_swap( &a_local, &b_local ); + } + +#endif + + // Set each alias as the root object. + // NOTE: We MUST wait until we are done potentially swapping the objects + // before setting the root fields! + bli_obj_set_as_root( &a_local ); + bli_obj_set_as_root( &b_local ); + bli_obj_set_as_root( &c_local ); + +#ifdef AOCL_DYNAMIC + // If dynamic-threading is enabled, calculate optimum number + // of threads and update in rntm + if(bli_obj_is_double(b)) + { + bli_nthreads_optimum(a, b, b, BLIS_TRMM, rntm ); + } +#endif + + // Parse and interpret the contents of the rntm_t object to properly + // set the ways of parallelism for each loop, and then make any + // additional modifications necessary for the current operation. + bli_rntm_set_ways_for_op + ( + BLIS_TRMM, + side, + bli_obj_length( &c_local ), + bli_obj_width( &c_local ), + bli_obj_width( &a_local ), + rntm + ); + + // A sort of hack for communicating the desired pach schemas for A and B + // to bli_gemm_cntl_create() (via bli_l3_thread_decorator() and + // bli_l3_cntl_create_if()). This allows us to access the schemas from + // the control tree, which hopefully reduces some confusion, particularly + // in bli_packm_init(). 
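
As a usage-level reference for the side handling described in the comments above, a right-side call that reaches this front-end could look like the stand-alone snippet below (column-major data; whether it is ultimately served by the left- or right-side macrokernels depends on the BLIS_DISABLE_TRMM_RIGHT setting discussed earlier). The prototype is shown with plain int for brevity; the real integer width depends on an LP64 or ILP64 build, so real code should include the library header instead.

    #include <stdio.h>

    void dtrmm_( const char* side, const char* uplo, const char* transa,
                 const char* diag, const int* m, const int* n,
                 const double* alpha, const double* a, const int* lda,
                 double* b, const int* ldb );

    int main( void )
    {
        /* B := alpha * B * A, with A a 2x2 lower triangular matrix. */
        int    m = 2, n = 2, lda = 2, ldb = 2;
        double alpha = 1.0;
        double a[4] = { 2.0, 3.0,     /* column 0: a11, a21           */
                        0.0, 4.0 };   /* column 1: a12 (unused), a22  */
        double b[4] = { 1.0, 1.0,     /* column 0 of B                */
                        1.0, 1.0 };   /* column 1 of B                */

        dtrmm_( "R", "L", "N", "N", &m, &n, &alpha, a, &lda, b, &ldb );

        printf( "B = [ %g %g ; %g %g ]\n", b[0], b[2], b[1], b[3] );
        return 0;
    }
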
+ pack_t schema_a = bli_cntx_schema_a_block( cntx ); + pack_t schema_b = bli_cntx_schema_b_panel( cntx ); + + bli_obj_set_pack_schema( schema_a, &a_local ); + bli_obj_set_pack_schema( schema_b, &b_local ); + + // Invoke the internal back-end. + bli_l3_thread_decorator + ( + bli_gemm_int, + BLIS_TRMM, // operation family id + alpha, + &a_local, + &b_local, + &BLIS_ZERO, + &c_local, + cntx, + rntm, + cntl + ); +} + diff --git a/frame/base/bli_rntm.c b/frame/base/bli_rntm.c index f8d48c4a2e..1d6c41528c 100644 --- a/frame/base/bli_rntm.c +++ b/frame/base/bli_rntm.c @@ -682,6 +682,99 @@ void bli_nthreads_optimum( n_threads_ideal = n_threads; } } + else if( family == BLIS_TRMM && bli_obj_is_double(c)) + { + dim_t m = bli_obj_length(c); + dim_t n = bli_obj_width(c); + + if(( n <= 32) && (m <= 32)) + { + n_threads_ideal=1; + /*If Side is Left*/ + }else + { + //Left Side + if(bli_obj_is_triangular(a)) + { + if((m < 300)) + { + if (n < 1000) + { + n_threads_ideal=8; + }else if (n < 2000) + { + n_threads_ideal=16; + }else if (n < 3000) + { + n_threads_ideal=32; + }else + { + n_threads_ideal=64; + } + }else if(m < 600) + { + if (n < 2000) + { + n_threads_ideal=16; + }else if (n < 3000) + { + n_threads_ideal=32; + }else + { + n_threads_ideal=64; + } + }else + { + if(n < 1000) + { + n_threads_ideal=32; + }else + { + n_threads_ideal=64; + } + } + }else//Right Side + { + if((n < 300)) + { + if (m < 1000) + { + n_threads_ideal=8; + }else if (m < 2000) + { + n_threads_ideal=16; + }else if (m < 3000) + { + n_threads_ideal=32; + }else + { + n_threads_ideal=64; + } + }else if(n < 600) + { + if (m < 2000) + { + n_threads_ideal=16; + }else if (m < 3000) + { + n_threads_ideal=32; + }else + { + n_threads_ideal=64; + } + }else + { + if(m < 1000) + { + n_threads_ideal=32; + }else + { + n_threads_ideal=64; + } + } + } + } + } dim_t n_threads_opt = bli_min(n_threads, n_threads_ideal); From 7247e6a150d83311a2554993f937bd4fff98a5fa Mon Sep 17 00:00:00 2001 From: Dipal M Zambare Date: Thu, 5 May 2022 12:05:40 +0530 Subject: [PATCH 115/243] Fixed crash issue in TRSM on non-avx platform. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Ensured that FMA, AVX2 based kernels are called only on platforms supporting these instructions, otherwise standard ‘C’ kernels will be called. - Code cleanup for optimization and consistency AMD-Internal: [CPUPL-2126] Change-Id: I203270892b2fad2ccc9301fb55e2bae75508e050 --- frame/compat/bla_trsm_amd.c | 228 +++++++++++++++++++----------------- 1 file changed, 119 insertions(+), 109 deletions(-) diff --git a/frame/compat/bla_trsm_amd.c b/frame/compat/bla_trsm_amd.c index 3b3850928a..f479b5eac0 100644 --- a/frame/compat/bla_trsm_amd.c +++ b/frame/compat/bla_trsm_amd.c @@ -594,10 +594,11 @@ void strsm_ bli_obj_set_struc( struca, &ao ); +#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM // This function is invoked on all architectures including ‘generic’. // Non-AVX platforms will use the kernels derived from the context. - if (bli_cpuid_is_avx_supported() == TRUE) { -#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM + if (bli_cpuid_is_avx_supported() == TRUE) + { /* bli_strsm_small is performing better existing native * implementations for [m,n]<=1000 for single thread. 
* In case of multithread when [m,n]<=128 sinlge thread implemenation @@ -624,8 +625,9 @@ void strsm_ return; } } -#endif } +#endif + bli_trsmnat ( blis_side, @@ -854,76 +856,72 @@ void dtrsm_ bli_obj_set_conjtrans( blis_transa, &ao ); bli_obj_set_struc( struca, &ao ); - + +#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM // This function is invoked on all architectures including ‘generic’. // Non-AVX platforms will use the kernels derived from the context. - if (bli_cpuid_is_avx_supported() == TRUE) { - -#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM - /* bli_dtrsm_small is performing better existing native - * implementations for [m,n]<=1000 for single thread. - * In case of multithread when [m,n]<=128 sinlge thread implemenation - * is doing better than native multithread */ - bool nt = bli_thread_get_is_parallel(); - if((nt==0 && m0<=1000 && n0<=1000) || - (nt && (m0+n0)<320) ) - { - err_t status; - status = bli_trsm_small - ( - blis_side, - &alphao, - &ao, - &bo, - NULL, - NULL - ); - if (status == BLIS_SUCCESS) - { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - /* Finalize BLIS. */ - bli_finalize_auto(); - return; - } - } - - //bli_trsm_small_mt is performing better than native multithread - //for certain sizes of m & n. -#ifdef BLIS_ENABLE_OPENMP - rntm_t rntm; - bli_rntm_init_from_global( &rntm ); - - // Query the total number of threads from the rntm_t object. - dim_t n_threads = bli_rntm_num_threads( &rntm ); - if ( ( (n_threads > 1) && (m0 <= 1500) && (n0 <= 1500) ) || - ( (n_threads == 32) && (m0 <= 2300) && (n0 <= 2300) ) || - ( (n_threads == 16) && (m0 <= 3800) && (n0 <= 3800) ) || - ( (n_threads == 8) && (m0 <= 2800) && (n0 <= 2800) ) || - ( (n_threads == 4) && (m0 <= 2000) && (n0 <= 2000) ) || - ( (n_threads == 2) && (m0 <= 2000) && (n0 <= 2000) ) ) + if (bli_cpuid_is_avx_supported() == TRUE) { - err_t status; - status = bli_trsm_small_mt - ( - blis_side, - &alphao, - &ao, - &bo, - NULL, - NULL - ); + /* bli_dtrsm_small is performing better existing native + * implementations for [m,n]<=1000 for single thread. + * In case of multithread when [m,n]<=128 sinlge thread implemenation + * is doing better than native multithread */ + bool nt = bli_thread_get_is_parallel(); + if ((nt == 0 && m0 <= 1000 && n0 <= 1000) || + (nt && (m0 + n0) < 320)) + { + err_t status; + status = bli_trsm_small( + blis_side, + &alphao, + &ao, + &bo, + NULL, + NULL); + if (status == BLIS_SUCCESS) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + /* Finalize BLIS. */ + bli_finalize_auto(); + return; + } + } - if ( status == BLIS_SUCCESS ) + // bli_trsm_small_mt is performing better than native multithread + // for certain sizes of m & n. +#ifdef BLIS_ENABLE_OPENMP + rntm_t rntm; + bli_rntm_init_from_global( &rntm ); + + // Query the total number of threads from the rntm_t object. + dim_t n_threads = bli_rntm_num_threads( &rntm ); + if ( ( (n_threads > 1) && (m0 <= 1500) && (n0 <= 1500) ) || + ( (n_threads == 32) && (m0 <= 2300) && (n0 <= 2300) ) || + ( (n_threads == 16) && (m0 <= 3800) && (n0 <= 3800) ) || + ( (n_threads == 8) && (m0 <= 2800) && (n0 <= 2800) ) || + ( (n_threads == 4) && (m0 <= 2000) && (n0 <= 2000) ) || + ( (n_threads == 2) && (m0 <= 2000) && (n0 <= 2000) ) ) { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + err_t status; + status = bli_trsm_small_mt( + blis_side, + &alphao, + &ao, + &bo, + NULL, + NULL); + + if ( status == BLIS_SUCCESS ) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); /* Finalize BLIS. 
*/ bli_finalize_auto(); return; + } } - } #endif// BLIS_ENABLE_OPENMP + } // bli_cpuid_is_avx_supported #endif// END of BLIS_ENABLE_SMALL_MATRIX_TRSM - } bli_trsmnat ( @@ -1217,33 +1215,38 @@ void ztrsm_ bli_obj_set_struc( struca, &ao ); #ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM - /* bli_ztrsm_small is performing better existing native - * implementations for [m,n]<=1000 for single thread. - * In case of multithread when [m,n]<=128 sinlge thread implemenation - * is doing better than native multithread */ - bool nt = bli_thread_get_is_parallel(); - - if(((nt==0) && (m0<=500) && (n0<=500)) || - (nt && ((m0+n0)<128))) + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) { - err_t status; - status = bli_trsm_small - ( - blis_side, - &alphao, - &ao, - &bo, - NULL, - NULL - ); - if (status == BLIS_SUCCESS) + /* bli_ztrsm_small is performing better existing native + * implementations for [m,n]<=1000 for single thread. + * In case of multithread when [m,n]<=128 sinlge thread implemenation + * is doing better than native multithread */ + bool nt = bli_thread_get_is_parallel(); + + if(((nt==0) && (m0<=500) && (n0<=500)) || + (nt && ((m0+n0)<128))) { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - /* Finalize BLIS. */ - bli_finalize_auto(); - return; + err_t status; + status = bli_trsm_small + ( + blis_side, + &alphao, + &ao, + &bo, + NULL, + NULL + ); + if (status == BLIS_SUCCESS) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + /* Finalize BLIS. */ + bli_finalize_auto(); + return; + } } - } + } // bli_cpuid_is_avx_supported} #endif bli_trsmnat @@ -1535,34 +1538,41 @@ void ctrsm_ bli_obj_set_conjtrans( blis_transa, &ao ); bli_obj_set_struc( struca, &ao ); + #ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM - /* bli_ztrsm_small is performing better existing native - * implementations for [m,n]<=1000 for single thread. - * In case of multithread when [m,n]<=128 sinlge thread implemenation - * is doing better than native multithread */ - bool nt = bli_thread_get_is_parallel(); - if((nt==0 && m0<=1000 && n0<=1000) || - (nt && (m0+n0)<320) ) + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) { - err_t status; - status = bli_trsm_small - ( - blis_side, - &alphao, - &ao, - &bo, - NULL, - NULL - ); - if (status == BLIS_SUCCESS) + /* bli_ztrsm_small is performing better existing native + * implementations for [m,n]<=1000 for single thread. + * In case of multithread when [m,n]<=128 sinlge thread implemenation + * is doing better than native multithread */ + bool nt = bli_thread_get_is_parallel(); + if((nt==0 && m0<=1000 && n0<=1000) || + (nt && (m0+n0)<320) ) { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - /* Finalize BLIS. */ - bli_finalize_auto(); - return; + err_t status; + status = bli_trsm_small + ( + blis_side, + &alphao, + &ao, + &bo, + NULL, + NULL + ); + if (status == BLIS_SUCCESS) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + /* Finalize BLIS. */ + bli_finalize_auto(); + return; + } } - } + } // bli_cpuid_is_avx_supported #endif + bli_trsmnat ( blis_side, From 349dcc459a5b5e76a48de76699184aa2504e49fb Mon Sep 17 00:00:00 2001 From: Harsh Dave Date: Fri, 6 May 2022 06:34:15 -0500 Subject: [PATCH 116/243] Fixed scalapack xcsep failer due to cdotxv kernel. 
-Failure was observed in zen configuration as gcc flag safe-math-optimization was being used for reference kernel compilation. - Optmized kernels were being compiled without this gcc flag resulted in computation difference resulting in test case failure. AMD-Internal: [CPUPL-2121] Change-Id: I5d86e589cdea633220aecadbcab84d9b88b31f57 --- config/generic/make_defs.mk | 4 ++-- config/zen/make_defs.mk | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/config/generic/make_defs.mk b/config/generic/make_defs.mk index ee77b6cf0e..4ce2fac758 100644 --- a/config/generic/make_defs.mk +++ b/config/generic/make_defs.mk @@ -79,10 +79,10 @@ endif # Flags specific to reference kernels. CROPTFLAGS := $(CKOPTFLAGS) ifeq ($(CC_VENDOR),gcc) -CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast +CRVECFLAGS := $(CKVECFLAGS) else ifeq ($(CC_VENDOR),clang) -CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast +CRVECFLAGS := $(CKVECFLAGS) else CRVECFLAGS := $(CKVECFLAGS) endif diff --git a/config/zen/make_defs.mk b/config/zen/make_defs.mk index 08d8628bec..b4153fcbfb 100644 --- a/config/zen/make_defs.mk +++ b/config/zen/make_defs.mk @@ -68,7 +68,7 @@ endif # Flags specific to reference kernels. CROPTFLAGS := $(CKOPTFLAGS) ifeq ($(CC_VENDOR),gcc) -CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations +CRVECFLAGS := $(CKVECFLAGS) else CRVECFLAGS := $(CKVECFLAGS) endif From 8670992c3d2087d91835241094e21b0eaa9d0608 Mon Sep 17 00:00:00 2001 From: mkadavil Date: Tue, 10 May 2022 14:46:47 +0530 Subject: [PATCH 117/243] Default sgemv kernel to be used in single-threaded scenarios. - sgemv calls a multi-threading friendly kernel whenever it is compiled with open mp and multi-threading enabled. However it was observed that this kernel is not suited for scenarios where sgemv is invoked in a single-threaded context (eg: sgemv from ST sgemm fringe kernels and with matrix blocking). Falling back to the default single-threaded sgemv kernel resulted in better performance for this scenario. AMD-Internal: [CPUPL-2136] Change-Id: Ic023db4d20b2503ea45e56a839aa35de0337d5a6 --- frame/2/gemv/bli_gemv_unf_var1_amd.c | 97 +++++++++++++++------------- 1 file changed, 51 insertions(+), 46 deletions(-) diff --git a/frame/2/gemv/bli_gemv_unf_var1_amd.c b/frame/2/gemv/bli_gemv_unf_var1_amd.c index fd399c6f84..447f8dbc43 100644 --- a/frame/2/gemv/bli_gemv_unf_var1_amd.c +++ b/frame/2/gemv/bli_gemv_unf_var1_amd.c @@ -495,72 +495,77 @@ void bli_sgemv_unf_var1 // If both multithreading and OpenMP are enabled, GEMV will multithread #if defined(BLIS_ENABLE_MULTITHREADING) && defined(BLIS_ENABLE_OPENMP) - dim_t nt, nt_max; - - rntm_t rnmt_obj; + bool is_omp_mt_enabled = TRUE; +#else + bool is_omp_mt_enabled = FALSE; +#endif - b_fuse = 4; + dim_t nt_max; + rntm_t rnmt_obj; // Initialize a local runtime with global settings. bli_rntm_init_from_global( &rnmt_obj ); // Query the total number of threads from the rntm_t object. 
nt_max = bli_rntm_num_threads( &rnmt_obj ); - - //Setting the thread count to the maximum number of threads provided - nt = nt_max; - - // Enable smart threading when AOCL dynamic is enabled - #ifdef AOCL_DYNAMIC - bli_sgemv_var1_smart_threading(n_elem, n_iter, b_fuse, &nt, nt_max); - #endif - - // Pass the input paramaters along with the number of threads to be used - bli_multi_sgemv_4x2 - ( - conja, - conjx, - n_elem, - n_iter, - alpha, - a, cs_at, rs_at, - x, incx, - beta, - y, incy, - cntx, - nt - ); - -#else - b_fuse = 8; - - for ( i = 0; i < n_iter; i += f ) + if ( ( nt_max > 1 ) & ( is_omp_mt_enabled == TRUE ) ) { - float* x1; - f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); + b_fuse = 4; + + //Setting the thread count to the maximum number of threads provided + dim_t nt = nt_max; - A1 = a + (i )*rs_at + (0 )*cs_at; - x1 = x + (0 )*incy; - y1 = y + (i )*incy; + // Enable smart threading when AOCL dynamic is enabled + #ifdef AOCL_DYNAMIC + bli_sgemv_var1_smart_threading(n_elem, n_iter, b_fuse, &nt, nt_max); + #endif - /* y1 = beta * y1 + alpha * A1 * x; */ - bli_sdotxf_zen_int_8 + // Pass the input paramaters along with the number of threads to be used + bli_multi_sgemv_4x2 ( conja, conjx, n_elem, - f, + n_iter, alpha, - A1, cs_at, rs_at, - x1, incx, + a, cs_at, rs_at, + x, incx, beta, - y1, incy, - cntx + y, incy, + cntx, + nt ); + } + else + { + b_fuse = 8; + for ( i = 0; i < n_iter; i += f ) + { + float* x1; + f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); + + A1 = a + (i )*rs_at + (0 )*cs_at; + x1 = x + (0 )*incy; + y1 = y + (i )*incy; + + /* y1 = beta * y1 + alpha * A1 * x; */ + bli_sdotxf_zen_int_8 + ( + conja, + conjx, + n_elem, + f, + alpha, + A1, cs_at, rs_at, + x1, incx, + beta, + y1, incy, + cntx + ); + } } -#endif } INSERT_GENTFUNC_BASIC0_CZ( gemv_unf_var1 ) From 31f8820bab6e782ff3825dc8c8528e32d1ce69ec Mon Sep 17 00:00:00 2001 From: mkadavil Date: Thu, 12 May 2022 18:02:01 +0530 Subject: [PATCH 118/243] Bug fixes for open mp based multi-threaded GEMM/GEMMT SUP path. - auto_factor to be disabled if BLIS_IC_NT/BLIS_JC_NT is set irrespective of whether num_threads (BLIS_NUM_THREADS) is modified at runtime. Currently the auto_factor is enabled if num_threads > 0 and not reverted if ic/jc/pc/jr/ir ways are set in bli_rntm_set_ways_from_rntm. This results in gemm/gemmt SUP path applying 2x2 factorization of num_threads, and thereby modifying the preset factorization. This issue is not observed in native path since factorization happens without checking auto_factor value. - Setting omp threads to n_threads using omp_set_num_threads after the global_rntm n_threads update in bli_thread_set_num_threads. This ensures that in bli_rntm_init_from_global, omp_get_max_threads returns the same value as set previously. AMD-Internal: [CPUPL-2137] Change-Id: I6c5de0462c5837cfb64793c3e6d49ec3ac2b6426 --- frame/base/bli_rntm.c | 10 ++++++++++ frame/thread/bli_thread.c | 12 ++++++++++++ 2 files changed, 22 insertions(+) diff --git a/frame/base/bli_rntm.c b/frame/base/bli_rntm.c index 1d6c41528c..fbf5654b7a 100644 --- a/frame/base/bli_rntm.c +++ b/frame/base/bli_rntm.c @@ -219,6 +219,11 @@ void bli_rntm_set_ways_from_rntm if ( ic < 1 ) ic = 1; if ( jr < 1 ) jr = 1; if ( ir < 1 ) ir = 1; + + // auto factorization is to be disabled if BLIS_IC_NT/BLIS_JC_NT env + // variables are set irrespective of whether num_threads is modified + // or not. This ensures that preset factorization is prioritized. 
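
The intent of the one-line fix below (and of the matching hunk in bli_rntm_set_ways_from_rntm above) can be condensed into the stand-alone sketch that follows. ways_sketch_t, set_ways_sketch and the trivial fallback split are illustrative only; in BLIS the automatic split is performed by bli_thread_partition_2x2() and the full product of ways also includes the pc/jr/ir loops.

    typedef long dim_t;   /* stand-in for the BLIS integer type */

    typedef struct { dim_t nt, ic, jc; int auto_factor; } ways_sketch_t;

    /* Precedence rule: user-supplied BLIS_IC_NT / BLIS_JC_NT always win and
       switch automatic factorization off; only when neither is set is the
       thread count split automatically across ic x jc. */
    static void set_ways_sketch( ways_sketch_t* w, dim_t m, dim_t n )
    {
        if ( w->ic > 0 || w->jc > 0 )
        {
            if ( w->ic < 1 ) w->ic = 1;
            if ( w->jc < 1 ) w->jc = 1;
            w->auto_factor = 0;                /* honor the user's choice */
            w->nt          = w->ic * w->jc;
        }
        else
        {
            w->auto_factor = 1;
            /* Simplest possible placeholder split, biased toward the
               larger dimension. */
            w->ic = ( m >= n ? w->nt : 1 );
            w->jc = ( m >= n ? 1     : w->nt );
        }
    }
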
+ auto_factor = FALSE; } // Now we use the values of nt_set and ways_set to determine how to @@ -340,6 +345,11 @@ void bli_rntm_set_ways_from_rntm_sup if ( ic < 1 ) ic = 1; if ( jr < 1 ) jr = 1; if ( ir < 1 ) ir = 1; + + // auto factorization is to be disabled if BLIS_IC_NT/BLIS_JC_NT env + // variables are set irrespective of whether num_threads is modified + // or not. This ensures that preset factorization is prioritized. + auto_factor = FALSE; } // Now we use the values of nt_set and ways_set to determine how to diff --git a/frame/thread/bli_thread.c b/frame/thread/bli_thread.c index f570bcc2d8..097d136e7e 100644 --- a/frame/thread/bli_thread.c +++ b/frame/thread/bli_thread.c @@ -1604,11 +1604,23 @@ void bli_thread_set_num_threads( dim_t n_threads ) // We must ensure that global_rntm has been initialized. bli_init_once(); + if ( n_threads <= 0 ) + { + n_threads = 1; + } + // Acquire the mutex protecting global_rntm. bli_pthread_mutex_lock( &global_rntm_mutex ); bli_rntm_set_num_threads_only( n_threads, &global_rntm ); +#ifdef BLIS_ENABLE_OPENMP + // In the function bli_rntm_init_from_global() we extract n_threads + // using the API omp_get_max_threads(). Following step ensures that + // omp_get_max_threads returns the same value as set here. + omp_set_num_threads( n_threads ); +#endif + // Release the mutex protecting global_rntm. bli_pthread_mutex_unlock( &global_rntm_mutex ); } From 8f9be1766b2ebd0cf250093988f7769a32293b25 Mon Sep 17 00:00:00 2001 From: Dipal M Zambare Date: Mon, 16 May 2022 15:57:12 +0530 Subject: [PATCH 119/243] Disable AOCL_VERBOSE feature - AOCL_VERBOSE implementation is causing breakage in libFLAME. Currently DTL code is duplicated in BLIS and libFLAME, Which results in duplicate symbol errors when DTL is enabled in both the libraries. - It will be addressed by making DTL as separate library. - The input logs can still be enabled by setting AOCL_DTL_LOG_ENABLE = 1 in aocldtlcf.h and recompiling the BLIS library. AMD-Internal: [CPUPL-2101] Change-Id: I8e69b68d53940e306a1d16ffbb65019def7e655a --- aocl_dtl/aocldtl.c | 4 ++-- aocl_dtl/aocldtlcf.h | 18 +++--------------- 2 files changed, 5 insertions(+), 17 deletions(-) diff --git a/aocl_dtl/aocldtl.c b/aocl_dtl/aocldtl.c index f3c1658ff8..6e7ee35102 100644 --- a/aocl_dtl/aocldtl.c +++ b/aocl_dtl/aocldtl.c @@ -59,7 +59,7 @@ AOCL_FLIST_Node *gpLogFileList = NULL; /* Global flag to check if logging is enabled or not */ -Bool gbIsLoggingEnabled = FALSE; +Bool gbIsLoggingEnabled = TRUE; #endif #if AOCL_DTL_AUTO_TRACE_ENABLE @@ -130,7 +130,7 @@ void DTL_Initialize( #if (AOCL_DTL_LOG_ENABLE || AOCL_DTL_DUMP_ENABLE) /* Check if DTL logging is requested via envoronment variable */ - gbIsLoggingEnabled = bli_env_get_var( "AOCL_VERBOSE", FALSE ); + gbIsLoggingEnabled = bli_env_get_var( "AOCL_VERBOSE", TRUE ); #endif #if AOCL_DTL_AUTO_TRACE_ENABLE diff --git a/aocl_dtl/aocldtlcf.h b/aocl_dtl/aocldtlcf.h index 9420e7d364..1f44f54405 100644 --- a/aocl_dtl/aocldtlcf.h +++ b/aocl_dtl/aocldtlcf.h @@ -20,21 +20,9 @@ enable this macro by making it to 1 else 0 */ #define AOCL_DTL_DUMP_ENABLE 0 -/* - * Logging of inputs can be enabled by two methods: - * - * 1. Using environment variable AOCL_VERBOSE. - * 2. APIs AOCL_DTL_Enable_Logs(), AOCL_DTL_Disable_Logs() - * - * The API takes precedence over environment variable. - * - * The global flag is maintain in the code to track the final - * state of the logging feature. - * - * Setting AOCL_DTL_LOG_ENABLE = 0 will disable this feature - * completely and it is not recommended. 
- */ -#define AOCL_DTL_LOG_ENABLE 1 +/* Macro for dumping the log If the user wants to enable input logs he has to + enable this macro by making it to 1 else 0 */ +#define AOCL_DTL_LOG_ENABLE 0 /* Select the trace level till which you want to log the data */ /* By default it will log for all levels */ From 0fddd9eb0d4d792b3292172ba26e539298a4976a Mon Sep 17 00:00:00 2001 From: Harihara Sudhan S Date: Thu, 26 Aug 2021 20:24:13 +0530 Subject: [PATCH 120/243] Level 1 Kernel: damaxv AVX512 Details: - Developed damaxv for AVX512 extension - Implemented removeNAN function that converts NAN values to negative values based on the location - Usage COMPARE256/COMPARE128 avoided in AVX512 implementation for better performance - Unrolled the loop by order of 4. Change-Id: Icf2a3606cf311ecc646aeb3db0628b293b9a3326 --- kernels/zen/1/bli_amaxv_zen_int.c | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/kernels/zen/1/bli_amaxv_zen_int.c b/kernels/zen/1/bli_amaxv_zen_int.c index 8487bdce4b..31358421e3 100644 --- a/kernels/zen/1/bli_amaxv_zen_int.c +++ b/kernels/zen/1/bli_amaxv_zen_int.c @@ -56,6 +56,15 @@ typedef union } v16sf_t; #endif +/* Union data structure to access AVX registers + One 512-bit AVX register holds 8 DP elements. */ +typedef union +{ + __m512d v; + double d[8] __attribute__((aligned(64))); +} v8df_t; + + /* Union data structure to access AVX registers One 256-bit AVX register holds 8 SP elements. */ typedef union From 9c6c76613cd18188ee80500801f402971a5f8565 Mon Sep 17 00:00:00 2001 From: Dipal M Zambare Date: Tue, 16 Nov 2021 14:35:25 +0530 Subject: [PATCH 121/243] Added support for zen4 architecture - Added configuration option for zen4 architecture - Added auto-detection of zen4 architecture - Added zen4 configuration for all checks related to AMD specific optimizations AMD-Internal: [CPUPL-1937] Change-Id: I1a1a45de04653f725aa53c30dffb6c0f7cc6e39a --- config/zen4/make_defs.mk | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/config/zen4/make_defs.mk b/config/zen4/make_defs.mk index 44e96bb0c7..1895ec95e4 100644 --- a/config/zen4/make_defs.mk +++ b/config/zen4/make_defs.mk @@ -4,7 +4,11 @@ # An object-based framework for developing high-performance BLAS-like # libraries. # +<<<<<<< HEAD # Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. +======= +# Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. +>>>>>>> 06113811... Added support for zen4 architecture # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -49,7 +53,19 @@ THIS_CONFIG := zen4 # general-purpose/configuration-agnostic flags in common.mk. You # may specify additional flags here as needed. +<<<<<<< HEAD CPPROCFLAGS := +======= +# Since we removed BLIS_CONFIG_EPYC from header file, we need to +# add it here at two places, +# CPPROCFLAGS = This will enable it for framework code +# This flag is used when configure is invoked with specific architecture +# CKOPTFLAGS = This will enable it for architecture specific kernels +# This flag is used for kernels assocaited with this architecture +# irrespective of the configuration it is built for. + +CPPROCFLAGS := -DBLIS_CONFIG_EPYC +>>>>>>> 06113811... 
Added support for zen4 architecture CMISCFLAGS := CPICFLAGS := CWARNFLAGS := @@ -123,6 +139,13 @@ endif # gcc CROPTFLAGS := $(CKOPTFLAGS) CRVECFLAGS := $(CKVECFLAGS) +<<<<<<< HEAD +======= +# Add this after updating variables for reference kernels +# we don't want this defined for them +CKOPTFLAGS += -DBLIS_CONFIG_EPYC + +>>>>>>> 06113811... Added support for zen4 architecture # Store all of the variables here to new variables containing the # configuration name. $(eval $(call store-make-defs,$(THIS_CONFIG))) From 43c16d8e085344a641c907407e74e8c7efbf0938 Mon Sep 17 00:00:00 2001 From: Chandrashekara K R Date: Wed, 22 Dec 2021 14:47:15 +0530 Subject: [PATCH 122/243] AOCL-Windows: Updating the blis windows build system. 1. Removed the libomp.lib hardcoded from cmake scripts and made it user configurable. By default libomp.lib is used as an omp library. 2. Added the STATIC_LIBRARY_OPTIONS property in set_target_properties cmake command to link omp library to build static-mt blis library. 3. Updated the blis_ref_kernel_mirror.py to give support for zen4 architecture. AMD-Internal: CPUPL-1630 Change-Id: I54b04cde2fa6a1ddc4b4303f1da808c1efe0484a --- CMakeLists.txt | 39 ++++++---------- test/CMakeLists.txt | 98 ++++++++++++++++++++-------------------- testsuite/CMakeLists.txt | 6 +-- 3 files changed, 65 insertions(+), 78 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index bcb67f2ccf..3affe7b40c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,4 +1,4 @@ -##Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.## +##Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved.## cmake_minimum_required(VERSION 3.0.0) @@ -10,8 +10,7 @@ set(CMAKE_RUNTIME_OUTPUT_DIRECTORY "${CMAKE_SOURCE_DIR}/bin") SET(AOCL_BLIS_FAMILY "zen" CACHE STRING "AOCL BLIS family name") -SET(OpenMP_libomp_LIBRARY "C:/Program Files/LLVM/lib/libomp.lib" CACHE STRING "openmp library -path") +SET(OMP_LIB "C:\\Program Files\\LLVM\\lib\\libomp.lib" CACHE STRING "openmp library path") set(TARGET_ARCH ${AOCL_BLIS_FAMILY}) set(AOCL_BLIS_ZEN TRUE) set (PYTHON_EXE "python") @@ -100,8 +99,9 @@ option(BLIS_ENABLE_ILP64 "ENABLE BLIS ILP64" OFF) option(ENABLE_INT_TYPE_SIZE " Internal BLIS integers ,used in native BLIS interfaces based on architecture dependent " ON) option(ENABLE_BLASTEST "Enable the blastest" OFF) option(ENABLE_TESTCPP_TESTING "Enabling testcpp" OFF) -option (ENABLE_NO_UNDERSCORE_API "export APIs without underscore" OFF) +option (ENABLE_NO_UNDERSCORE_API "export APIs without underscore" ON) option (ENABLE_UPPERCASE_API "export APIs with uppercase" OFF) +option (ENABLE_API_WRAPPER "Enable wrapper code" OFF) option (ENABLE_COMPLEX_RETURN_INTEL "Enable complex_return_intel" OFF) option (ENABLE_TRSM_PREINVERSION "Enable TRSM preinversion" ON) option (ENABLE_AOCL_DYNAMIC "Enable Dynamic Multi-threading" OFF) @@ -131,6 +131,10 @@ if(ENABLE_UPPERCASE_API) add_definitions(-DBLIS_ENABLE_UPPERCASE_API) endif() +if(ENABLE_API_WRAPPER) + add_definitions(-DBLIS_ENABLE_API_WRAPPER) +endif() + if(ENABLE_AOCL_DYNAMIC) set(AOCL_DYNAMIC TRUE) endif() @@ -265,9 +269,7 @@ if(ENABLE_MULTITHREADING) find_package(OpenMP) if (OPENMP_FOUND) set(BLIS_ENABLE_OPENMP TRUE) - set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") - set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${OpenMP_EXE_LINKER_FLAGS}") + add_compile_options(-Xclang -fopenmp) else() message (FATAL_ERROR "Openmp Not Found") endif() @@ -534,34 +536,19 @@ execute_process( 
OUTPUT_VARIABLE CMD_OUTPUT) message( STATUS "Generating monolithic header file :" ${CMD_OUTPUT}) -# Logic to generate the cblas.h in include folder. -set(CBLAS_H "cblas.h") -# Arguements for python script -set(C_COMMENT "-c") -set(VERBOSE "-v1") -set(INPUT "${CMAKE_SOURCE_DIR}/frame/compat/cblas/src/${CBLAS_H}") -set(OUTPUT "${CMAKE_SOURCE_DIR}/include/${TARGET_ARCH}/${CBLAS_H}") -set(TEMP_DIR "${INCLUDE}") -set(DIR_H_PATH "${HEADER_PATH}") - -# Run python script to generate monolithic header at configuration time -execute_process( - COMMAND ${PYTHON_EXE} ${FLATTEN_PY} "${C_COMMENT}" "${VERBOSE}" "${INPUT}" "${OUTPUT}" "${TEMP_DIR}" "${DIR_H_PATH}" - RESULT_VARIABLE CMD_RESULT - OUTPUT_VARIABLE CMD_OUTPUT) -message( STATUS "Generating monolithic cblas header file :" ${CMD_OUTPUT}) - # setting the blis version string file (STRINGS "version" BLIS_VERSION) set(BLIS_VERSION_STRING ${BLIS_VERSION}) add_definitions(-DBLIS_VERSION_STRING="AOCL BLIS ${BLIS_VERSION_STRING}") +message( STATUS "OPENMP Library:" ${OMP_LIB}) + if(BUILD_SHARED_LIBS) add_library("${PROJECT_NAME}" SHARED ${CMAKE_SOURCE_DIR}/bli_config.h ${CMAKE_SOURCE_DIR}/include/${TARGET_ARCH}/blis.h ${headers}) if(ENABLE_OPENMP) - target_link_libraries("${PROJECT_NAME}" PRIVATE OpenMP::OpenMP_CXX) + target_link_libraries("${PROJECT_NAME}" PUBLIC "${OMP_LIB}") endif() target_compile_definitions("${PROJECT_NAME}" PUBLIC -DBLIS_IS_BUILDING_LIBRARY) set_target_properties("${PROJECT_NAME}" PROPERTIES LINKER_LANGUAGE C OUTPUT_NAME "${LIB_NAME}") @@ -571,7 +558,7 @@ if(NOT BUILD_SHARED_LIBS) ${CMAKE_SOURCE_DIR}/include/${TARGET_ARCH}/blis.h ${headers}) if(ENABLE_OPENMP) - set_target_properties("${PROJECT_NAME}" PROPERTIES LINKER_LANGUAGE C OUTPUT_NAME "${LIB_NAME}" STATIC_LIBRARY_OPTIONS "${OpenMP_libomp_LIBRARY}") + set_target_properties("${PROJECT_NAME}" PROPERTIES LINKER_LANGUAGE C OUTPUT_NAME "${LIB_NAME}" STATIC_LIBRARY_OPTIONS "${OMP_LIB}") else() set_target_properties("${PROJECT_NAME}" PROPERTIES LINKER_LANGUAGE C OUTPUT_NAME "${LIB_NAME}") endif() diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index d116e942d0..fe8f7bac98 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -1,172 +1,172 @@ -##Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.## +##Copyright (C) 2021, Advanced Micro Devices, Inc. 
All rights reserved.## add_definitions(-DBLAS="AOCL") add_executable(TestAminv test_aminv.c) target_link_libraries(TestAminv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestAminv OpenMP::OpenMP_CXX) +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestAminv "${OMP_LIB}") endif() target_link_libraries(TestAminv optimized "${LIB_NAME}.lib") add_executable(TestAxpyv test_axpyv.c) target_link_libraries(TestAxpyv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestAxpyv OpenMP::OpenMP_CXX) +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestAxpyv "${OMP_LIB}") endif() target_link_libraries(TestAxpyv optimized "${LIB_NAME}.lib") add_executable(TestAxpbyv test_axpbyv.c) target_link_libraries(TestAxpbyv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestAxpbyv OpenMP::OpenMP_CXX) +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestAxpbyv "${OMP_LIB}") endif() target_link_libraries(TestAxpbyv optimized "${LIB_NAME}.lib") add_executable(TestCopyv test_copyv.c) target_link_libraries(TestCopyv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestCopyv OpenMP::OpenMP_CXX) +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestCopyv "${OMP_LIB}") endif() target_link_libraries(TestCopyv optimized "${LIB_NAME}.lib") add_executable(TestCabs1 test_cabs1.c) target_link_libraries(TestCabs1 debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestCabs1 OpenMP::OpenMP_CXX) +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestCabs1 "${OMP_LIB}") endif() target_link_libraries(TestCabs1 optimized "${LIB_NAME}.lib") add_executable(TestDotv test_dotv.c) target_link_libraries(TestDotv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestDotv OpenMP::OpenMP_CXX) +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestDotv "${OMP_LIB}") endif() target_link_libraries(TestDotv optimized "${LIB_NAME}.lib") add_executable(TestGemm test_gemm.c) target_link_libraries(TestGemm debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestGemm OpenMP::OpenMP_CXX) +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestGemm "${OMP_LIB}") endif() target_link_libraries(TestGemm optimized "${LIB_NAME}.lib") add_executable(TestGemmBatch test_gemm_batch.c) target_link_libraries(TestGemmBatch debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestGemmBatch OpenMP::OpenMP_CXX) +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestGemmBatch "${OMP_LIB}") endif() target_link_libraries(TestGemmBatch optimized "${LIB_NAME}.lib") add_executable(TestGemm3m test_gemm3m.c) target_link_libraries(TestGemm3m debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestGemm3m OpenMP::OpenMP_CXX) +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestGemm3m "${OMP_LIB}") endif() target_link_libraries(TestGemm3m optimized "${LIB_NAME}.lib") add_executable(TestGemmt test_gemmt.c) target_link_libraries(TestGemmt debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestGemmt OpenMP::OpenMP_CXX) +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestGemmt "${OMP_LIB}") endif() target_link_libraries(TestGemmt optimized "${LIB_NAME}.lib") add_executable(TestGemv test_gemv.c) target_link_libraries(TestGemv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestGemv OpenMP::OpenMP_CXX) +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) 
+ target_link_libraries(TestGemv "${OMP_LIB}") endif() target_link_libraries(TestGemv optimized "${LIB_NAME}.lib") add_executable(TestGer test_ger.c) target_link_libraries(TestGer debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestGer OpenMP::OpenMP_CXX) +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestGer "${OMP_LIB}") endif() target_link_libraries(TestGer optimized "${LIB_NAME}.lib") add_executable(TestHemm test_hemm.c) target_link_libraries(TestHemm debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestHemm OpenMP::OpenMP_CXX) +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestHemm "${OMP_LIB}") endif() target_link_libraries(TestHemm optimized "${LIB_NAME}.lib") add_executable(TestHemv test_hemv.c) target_link_libraries(TestHemv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestHemv OpenMP::OpenMP_CXX) +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestHemv "${OMP_LIB}") endif() target_link_libraries(TestHemv optimized "${LIB_NAME}.lib") add_executable(TestHer test_her.c) target_link_libraries(TestHer debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestHer OpenMP::OpenMP_CXX) +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestHer "${OMP_LIB}") endif() target_link_libraries(TestHer optimized "${LIB_NAME}.lib") add_executable(TestHer2 test_her2.c) target_link_libraries(TestHer2 debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestHer2 OpenMP::OpenMP_CXX) +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestHer2 "${OMP_LIB}") endif() target_link_libraries(TestHer2 optimized "${LIB_NAME}.lib") add_executable(TestHer2k test_her2k.c) target_link_libraries(TestHer2k debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestHer2k OpenMP::OpenMP_CXX) +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestHer2k "${OMP_LIB}") endif() target_link_libraries(TestHer2k optimized "${LIB_NAME}.lib") add_executable(TestHerk test_herk.c) target_link_libraries(TestHerk debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestHerk OpenMP::OpenMP_CXX) +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestHerk "${OMP_LIB}") endif() target_link_libraries(TestHerk optimized "${LIB_NAME}.lib") add_executable(TestScalv test_scalv.c) target_link_libraries(TestScalv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestScalv OpenMP::OpenMP_CXX) +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestScalv "${OMP_LIB}") endif() target_link_libraries(TestScalv optimized "${LIB_NAME}.lib") add_executable(TestSwapv test_swapv.c) target_link_libraries(TestSwapv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestSwapv OpenMP::OpenMP_CXX) +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestSwapv "${OMP_LIB}") endif() target_link_libraries(TestSwapv optimized "${LIB_NAME}.lib") add_executable(TestTrmm test_trmm.c) target_link_libraries(TestTrmm debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestTrmm OpenMP::OpenMP_CXX) +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestTrmm "${OMP_LIB}") endif() target_link_libraries(TestTrmm optimized "${LIB_NAME}.lib") add_executable(TestTrmv test_trmv.c) target_link_libraries(TestTrmv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestTrmv OpenMP::OpenMP_CXX) +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + 
target_link_libraries(TestTrmv "${OMP_LIB}") endif() target_link_libraries(TestTrmv optimized "${LIB_NAME}.lib") add_executable(TestTrsm test_trsm.c) target_link_libraries(TestTrsm debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestTrsm OpenMP::OpenMP_CXX) +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestTrsm "${OMP_LIB}") endif() target_link_libraries(TestTrsm optimized "${LIB_NAME}.lib") add_executable(TestTrsv test_trsv.c) target_link_libraries(TestTrsv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestTrsv OpenMP::OpenMP_CXX) +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestTrsv "${OMP_LIB}") endif() target_link_libraries(TestTrsv optimized "${LIB_NAME}.lib") diff --git a/testsuite/CMakeLists.txt b/testsuite/CMakeLists.txt index 85866926dd..613f9e3861 100644 --- a/testsuite/CMakeLists.txt +++ b/testsuite/CMakeLists.txt @@ -1,4 +1,4 @@ -##Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.## +##Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved.## include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src) @@ -7,8 +7,8 @@ add_executable(test_libblis "") add_subdirectory(src) target_link_libraries(test_libblis debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(test_libblis OpenMP::OpenMP_CXX) +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(test_libblis "${OMP_LIB}") endif() target_link_libraries(test_libblis optimized "${LIB_NAME}.lib") From 718c6bc024d3a6c8f4c6987e098aba41e914db8c Mon Sep 17 00:00:00 2001 From: Harsh Dave Date: Thu, 23 Dec 2021 04:44:24 -0600 Subject: [PATCH 123/243] Optimized daxpy2v implementation - Optimized axpy2v implementation for double datatype by handling rows in mulitple of 4 and store the final computed result at the end of computation, preventing unnecessary stores for improving the performance. - Optimal and reuse of vector registers for faster computation. AMD-Internal: [CPUPL-1973] Change-Id: I7b8ef94d0f67c1c666fdce26e9b2b7291365d2e9 --- kernels/zen/1f/CMakeLists.txt | 3 +++ kernels/zen/1f/bli_axpy2v_zen_int.c | 5 ++++- kernels/zen/bli_kernels_zen.h | 16 ++++++++++++++++ 3 files changed, 23 insertions(+), 1 deletion(-) diff --git a/kernels/zen/1f/CMakeLists.txt b/kernels/zen/1f/CMakeLists.txt index 3a77f69ef1..b020d8c92d 100644 --- a/kernels/zen/1f/CMakeLists.txt +++ b/kernels/zen/1f/CMakeLists.txt @@ -8,5 +8,8 @@ target_sources("${PROJECT_NAME}" ${CMAKE_CURRENT_SOURCE_DIR}/bli_axpyf_zen_int_4.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_axpyf_zen_int_6.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_axpy2v_zen_int.c +<<<<<<< HEAD ${CMAKE_CURRENT_SOURCE_DIR}/bli_dotxaxpyf_zen_int_8.c +======= +>>>>>>> 8b5b2707... Optimized daxpy2v implementation ) diff --git a/kernels/zen/1f/bli_axpy2v_zen_int.c b/kernels/zen/1f/bli_axpy2v_zen_int.c index cba0141376..26d307eda1 100644 --- a/kernels/zen/1f/bli_axpy2v_zen_int.c +++ b/kernels/zen/1f/bli_axpy2v_zen_int.c @@ -186,6 +186,7 @@ void bli_daxpy2v_zen_int ); } } +<<<<<<< HEAD /** * zaxpy2v kernel performs axpy2v operation. @@ -718,4 +719,6 @@ void bli_zaxpy2v_zen_int } } AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) -} \ No newline at end of file +} +======= +>>>>>>> 8b5b2707... Optimized daxpy2v implementation diff --git a/kernels/zen/bli_kernels_zen.h b/kernels/zen/bli_kernels_zen.h index 5444c90ea8..f7083c9150 100644 --- a/kernels/zen/bli_kernels_zen.h +++ b/kernels/zen/bli_kernels_zen.h @@ -32,6 +32,19 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
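As a plain-C sketch of the loop structure the daxpy2v commit message above describes (z := z + alphax*x + alphay*y, processed four rows at a time with each result stored exactly once); the function name is illustrative and this is not the AVX2 kernel added by the patch:

/* Plain-C sketch of the unroll-by-4 / deferred-store structure described
 * in the daxpy2v commit message above; not the intrinsic kernel itself. */
static void daxpy2v_sketch( int n, double alphax, double alphay,
                            const double* x, const double* y, double* z )
{
    int i = 0;

    /* Main loop: four rows per iteration, each result stored exactly once. */
    for ( ; i + 3 < n; i += 4 )
    {
        double z0 = z[i+0] + alphax * x[i+0] + alphay * y[i+0];
        double z1 = z[i+1] + alphax * x[i+1] + alphay * y[i+1];
        double z2 = z[i+2] + alphax * x[i+2] + alphay * y[i+2];
        double z3 = z[i+3] + alphax * x[i+3] + alphay * y[i+3];

        z[i+0] = z0; z[i+1] = z1; z[i+2] = z2; z[i+3] = z3;
    }

    /* Remainder rows are handled one at a time. */
    for ( ; i < n; ++i )
        z[i] += alphax * x[i] + alphay * y[i];
}
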
*/ +<<<<<<< HEAD +======= +// hemv helper function +void bli_pre_hemv_8x8(double *a, double *x, + double *y, double *alpha, + dim_t cs_a, dim_t rs_a); + +void bli_post_hemv_8x8(double *a, double *x, + double *y, double *alpha, + dim_t cs_a, dim_t rs_a); + + +>>>>>>> 8b5b2707... Optimized daxpy2v implementation // -- level-1m -- PACKM_KER_PROT(double, d, packm_8xk_gen_zen) PACKM_KER_PROT(double, d, packm_6xk_gen_zen) @@ -129,7 +142,10 @@ AXPYF_KER_PROT( dcomplex, z, axpyf_zen_int_5 ) AXPYF_KER_PROT( dcomplex, z, axpyf_zen_int_4 ) // axpy2v (intrinsics) AXPY2V_KER_PROT(double, d, axpy2v_zen_int ) +<<<<<<< HEAD AXPY2V_KER_PROT(dcomplex, z, axpy2v_zen_int ) +======= +>>>>>>> 8b5b2707... Optimized daxpy2v implementation // dotxf (intrinsics) DOTXF_KER_PROT( float, s, dotxf_zen_int_8 ) From f48ced0811c66030a52326e4c28abb0f11f43c2e Mon Sep 17 00:00:00 2001 From: Harsh Dave Date: Fri, 17 Dec 2021 02:34:52 -0600 Subject: [PATCH 124/243] Optimized dher2 implementation - Impplemented her2 framework calls for transposed and non transposed kernel variants. - dher2 kernel operate over 4 columns at a time. It computes 4x4 triangular part of matrix first and remainder part is computed in chunk of 4x4 tile upto m rows. - remainder cases(m < 4) are handled serially. AMD-Internal: [CPUPL-1968] Change-Id: I12ae97b2ad673a7fd9b733c607f27b1089142313 --- frame/2/her2/bli_her2_unf_var1.c | 212 +++++++++++++++++++++++++++++++ frame/2/her2/bli_her2_unf_var4.c | 187 +++++++++++++++++++++++++++ kernels/zen/bli_kernels_zen.h | 16 --- 3 files changed, 399 insertions(+), 16 deletions(-) diff --git a/frame/2/her2/bli_her2_unf_var1.c b/frame/2/her2/bli_her2_unf_var1.c index a0aec48f71..299e3d161d 100644 --- a/frame/2/her2/bli_her2_unf_var1.c +++ b/frame/2/her2/bli_her2_unf_var1.c @@ -158,5 +158,217 @@ void PASTEMAC(ch,varname) \ } \ } + +#ifdef BLIS_CONFIG_EPYC + +/** + * Following is function declaration + * that computes her2 for transposed case. + * It handles triangular part of matrix and + * remaining computation in optimal way to + * gain performance improvement. + * a is triangular matrix, x and y are vectors + */ +void bli_dher2_trans_zen_int_4 + ( + double *a, + double *x, + double *y, + double *alpha, + dim_t m, + dim_t lda + ); + +void bli_dher2_unf_var1 + ( + uplo_t uplo, + conj_t conjx, + conj_t conjy, + conj_t conjh, + dim_t m, + double* alpha, + double* x, inc_t incx, + double* y, inc_t incy, + double* c, inc_t rs_c, inc_t cs_c, + cntx_t* cntx + ) +{ + const num_t dt = PASTEMAC(d,type); + + double* x0; + double* chi1; + double* y0; + double* psi1; + double* c10t; + double* gamma11; + double alpha0; + double alpha1; + double alpha0_chi1; + double alpha1_psi1; + double alpha0_chi1_psi1; + double conjx0_chi1; + double conjy1_psi1; + double conjy0_psi1; + dim_t i; + dim_t n_behind; + inc_t rs_ct, cs_ct; + conj_t conj0, conj1; + + /* The algorithm will be expressed in terms of the lower triangular + * case;the upper triangular case is supported by swapping the row + * and column strides of A and toggling some conj parameters. + */ + if ( bli_is_lower( uplo ) ) + { + rs_ct = rs_c; + cs_ct = cs_c; + + PASTEMAC(d,copys)( *alpha, alpha0 ); + PASTEMAC(d,copycjs)( conjh, *alpha, alpha1 ); + } + else /* if ( bli_is_upper( uplo ) ) */ + { + rs_ct = cs_c; + cs_ct = rs_c; + + /* Toggle conjugation of conjx/conjy, but only if we are being + * invoked as her2; for syr2, conjx/conjy are unchanged. 
+ */ + conjx = bli_apply_conj( conjh, conjx ); + conjy = bli_apply_conj( conjh, conjy ); + + PASTEMAC(d,copycjs)( conjh, *alpha, alpha0 ); + PASTEMAC(d,copys)( *alpha, alpha1 ); + } + + /* Apply conjh (which carries the conjugation component of the + * Hermitian transpose, if applicable) to conjx and/or conjy as + * needed to arrive at the effective conjugation for the vector + * subproblems. + */ + conj0 = bli_apply_conj( conjh, conjy ); + conj1 = bli_apply_conj( conjh, conjx ); + + PASTECH(d,axpy2v_ker_ft) kfp_2v; + + /* Query the context for the kernel function pointer. */ + kfp_2v = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPY2V_KER, cntx ); + + if( (incx == 1) && (incy == 1) && (rs_ct == 1)) + { + for ( i = 0; i < m; ) + { + n_behind = i; + x0 = x + (0 )*incx; + chi1 = x + (i )*incx; + y0 = y + (0 )*incy; + psi1 = y + (i )*incy; + c10t = c + (i )*rs_ct + (0 )*cs_ct; + gamma11 = c + (i )*rs_ct + (i )*cs_ct; + + if((n_behind >= 3)) + { + bli_dher2_trans_zen_int_4(c10t, x0, y0, &alpha0, n_behind + 1, cs_ct); + i+=4; + } + else + { + /* Apply conjx and/or conjy to chi1 and/or psi1. */ + PASTEMAC(d,copycjs)( conjx, *chi1, conjx0_chi1 ); + PASTEMAC(d,copycjs)( conjy, *psi1, conjy1_psi1 ); + PASTEMAC(d,copycjs)( conj0, *psi1, conjy0_psi1 ); + + /* Compute scalars for vector subproblems. */ + PASTEMAC(d,scal2s)( alpha0, conjx0_chi1, alpha0_chi1 ); + PASTEMAC(d,scal2s)( alpha1, conjy1_psi1, alpha1_psi1 ); + + /* Compute alpha * chi1 * conj(psi1) after both chi1 + * and psi1 have already been conjugated, if needed, + * by conjx and conjy. + */ + PASTEMAC(d,scal2s)( alpha0_chi1, conjy0_psi1, + alpha0_chi1_psi1 ); + + /* c10t = c10t + alpha * chi1 * y0'; */ + /* c10t = c10t + conj(alpha) * psi1 * x0'; */ + kfp_2v + ( + conj0, + conj1, + n_behind, + &alpha0_chi1, + &alpha1_psi1, + y0, incy, + x0, incx, + c10t, cs_ct, + cntx + ); + + /* gamma11 = gamma11 + alpha * chi1 * conj(psi1) + + conj(alpha) * psi1 * conj(chi1); */ + PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); + PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); + + i+=1; + } + } + } + else + { + for ( i = 0; i < m; ++i ) + { + n_behind = i; + x0 = x + (0 )*incx; + chi1 = x + (i )*incx; + y0 = y + (0 )*incy; + psi1 = y + (i )*incy; + c10t = c + (i )*rs_ct + (0 )*cs_ct; + gamma11 = c + (i )*rs_ct + (i )*cs_ct; + + /* Apply conjx and/or conjy to chi1 and/or psi1. */ + PASTEMAC(d,copycjs)( conjx, *chi1, conjx0_chi1 ); + PASTEMAC(d,copycjs)( conjy, *psi1, conjy1_psi1 ); + PASTEMAC(d,copycjs)( conj0, *psi1, conjy0_psi1 ); + + /* Compute scalars for vector subproblems. */ + PASTEMAC(d,scal2s)( alpha0, conjx0_chi1, alpha0_chi1 ); + PASTEMAC(d,scal2s)( alpha1, conjy1_psi1, alpha1_psi1 ); + + /* Compute alpha * chi1 * conj(psi1) after both chi1 + * and psi1 have already been conjugated, if needed, + * by conjx and conjy. 
+ */ + PASTEMAC(d,scal2s)( alpha0_chi1, conjy0_psi1, + alpha0_chi1_psi1 ); + + /* c10t = c10t + alpha * chi1 * y0'; */ + /* c10t = c10t + conj(alpha) * psi1 * x0'; */ + kfp_2v + ( + conj0, + conj1, + n_behind, + &alpha0_chi1, + &alpha1_psi1, + y0, incy, + x0, incx, + c10t, cs_ct, + cntx + ); + + /* gamma11 = gamma11 + alpha * chi1 * conj(psi1) + + conj(alpha) * psi1 * conj(chi1); */ + PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); + PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); + + } + } +} + +GENTFUNC(float, s, her2_unf_var1) +GENTFUNC(scomplex, c, her2_unf_var1) +GENTFUNC(dcomplex, z,her2_unf_var1) +#else INSERT_GENTFUNC_BASIC0( her2_unf_var1 ) +#endif diff --git a/frame/2/her2/bli_her2_unf_var4.c b/frame/2/her2/bli_her2_unf_var4.c index 3dea31d53e..e39c7224c4 100644 --- a/frame/2/her2/bli_her2_unf_var4.c +++ b/frame/2/her2/bli_her2_unf_var4.c @@ -166,5 +166,192 @@ void PASTEMAC(ch,varname) \ } \ } +#ifdef BLIS_CONFIG_EPYC +/** + * Following is function declaration + * that computes her2 for transposed case. + * It handles triangular part of matrix and + * remaining computation in optimal way to + * gain performance improvement. + * a is triangular matrix, x and y are vectors + */ +void bli_dher2_zen_int_4 + ( + double *a, + double *x, + double *y, + double *alpha, + dim_t m, + dim_t lda + ); + +void bli_dher2_unf_var4 + ( + uplo_t uplo, + conj_t conjx, + conj_t conjy, + conj_t conjh, + dim_t m, + double* alpha, + double* x, inc_t incx, + double* y, inc_t incy, + double* c, inc_t rs_c, inc_t cs_c, + cntx_t* cntx + ) +{ + + double* chi1; + double* x2; + double* psi1; + double* y2; + double* gamma11; + double* c21; + double alpha0; + double alpha0_psi1; + double alpha1_chi1; + double alpha0_chi1_psi1; + dim_t i; + dim_t n_ahead; + inc_t rs_ct, cs_ct; + + const num_t dt = PASTEMAC(d,type); + + /* The algorithm will be expressed in terms of the lower triangular + * case; the upper triangular case is supported by swapping the row + * and column strides of A and toggling some conj parameters. + */ + if ( bli_is_lower( uplo ) ) + { + rs_ct = rs_c; + cs_ct = cs_c; + + PASTEMAC(d,copys)( *alpha, alpha0 ); + } + else /* if ( bli_is_upper( uplo ) ) */ + { + rs_ct = cs_c; + cs_ct = rs_c; + + /* Toggle conjugation of conjx/conjy, but only if we are being + * invoked as her2; for syr2, conjx/conjy are unchanged. + */ + + PASTEMAC(d,copys)( *alpha, alpha0 ); + } + /* Apply conjh (which carries the conjugation component of the + * Hermitian transpose, if applicable) to conjx and/or conjy as + * needed to arrive at the effective conjugation for the vector + * subproblems. + */ + + PASTECH(d,axpy2v_ker_ft) kfp_2v; + + /* Query the context for the kernel function pointer. */ + kfp_2v = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPY2V_KER, cntx ); + + if((incx == 1) && (incy == 1) && (rs_ct == 1)) + { + for ( i = 0; i < m; ) + { + n_ahead = m - i - 1; + chi1 = x + (i ) * incx; + x2 = x + (i+1) * incx; + psi1 = y + (i ) * incy; + y2 = y + (i+1) * incy; + gamma11 = c + (i ) + (i )*cs_ct; + c21 = c + (i+1) + (i )*cs_ct; + + if((n_ahead >= 3)) + { + bli_dher2_zen_int_4(gamma11, chi1, psi1, &alpha0, n_ahead + 1, cs_ct); + i+= 4; + } + else + { + /* Compute scalars for vector subproblems. */ + PASTEMAC(d,scal2s)( alpha0, *psi1, alpha0_psi1 ); + PASTEMAC(d,scal2s)( alpha0, *chi1, alpha1_chi1 ); + + /* Compute alpha * chi1 * conj(psi1) after both chi1 + * and psi1 have + already been conjugated, if needed, by conjx and + conjy. 
*/ + PASTEMAC(d,scal2s)( alpha0_psi1, *chi1, + alpha0_chi1_psi1 ); + + /* c21 = c21 + alpha * x2 * conj(psi1); */ + /* c21 = c21 + conj(alpha) * y2 * conj(chi1); */ + + kfp_2v + ( + conjx, + conjy, + n_ahead, + &alpha0_psi1, + &alpha1_chi1, + x2, incx, + y2, incy, + c21, rs_ct, + cntx + ); + + + PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); + PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); + i+=1; + } + } + } + else + { + for ( i = 0; i < m; ++i) + { + n_ahead = m - i - 1; + chi1 = x + (i ) * incx; + x2 = x + (i+1) * incx; + psi1 = y + (i ) * incy; + y2 = y + (i+1) * incy; + gamma11 = c + (i ) + (i )*cs_ct; + c21 = c + (i+1) + (i )*cs_ct; + + /* Compute scalars for vector subproblems. */ + PASTEMAC(d,scal2s)( alpha0, *psi1, alpha0_psi1 ); + PASTEMAC(d,scal2s)( alpha0, *chi1, alpha1_chi1 ); + + /* Compute alpha * chi1 * conj(psi1) after both chi1 + * and psi1 have + already been conjugated, if needed, by conjx and + conjy. */ + PASTEMAC(d,scal2s)( alpha0_psi1, *chi1, + alpha0_chi1_psi1 ); + + /* c21 = c21 + alpha * x2 * conj(psi1); */ + /* c21 = c21 + conj(alpha) * y2 * conj(chi1); */ + + kfp_2v + ( + conjx, + conjy, + n_ahead, + &alpha0_psi1, + &alpha1_chi1, + x2, incx, + y2, incy, + c21, rs_ct, + cntx + ); + + + PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); + PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); + } + } +} + +GENTFUNC(float, s, her2_unf_var4) +GENTFUNC(scomplex, c, her2_unf_var4) +GENTFUNC(dcomplex, z,her2_unf_var4) +#else INSERT_GENTFUNC_BASIC0( her2_unf_var4 ) +#endif diff --git a/kernels/zen/bli_kernels_zen.h b/kernels/zen/bli_kernels_zen.h index f7083c9150..5444c90ea8 100644 --- a/kernels/zen/bli_kernels_zen.h +++ b/kernels/zen/bli_kernels_zen.h @@ -32,19 +32,6 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. */ -<<<<<<< HEAD -======= -// hemv helper function -void bli_pre_hemv_8x8(double *a, double *x, - double *y, double *alpha, - dim_t cs_a, dim_t rs_a); - -void bli_post_hemv_8x8(double *a, double *x, - double *y, double *alpha, - dim_t cs_a, dim_t rs_a); - - ->>>>>>> 8b5b2707... Optimized daxpy2v implementation // -- level-1m -- PACKM_KER_PROT(double, d, packm_8xk_gen_zen) PACKM_KER_PROT(double, d, packm_6xk_gen_zen) @@ -142,10 +129,7 @@ AXPYF_KER_PROT( dcomplex, z, axpyf_zen_int_5 ) AXPYF_KER_PROT( dcomplex, z, axpyf_zen_int_4 ) // axpy2v (intrinsics) AXPY2V_KER_PROT(double, d, axpy2v_zen_int ) -<<<<<<< HEAD AXPY2V_KER_PROT(dcomplex, z, axpy2v_zen_int ) -======= ->>>>>>> 8b5b2707... Optimized daxpy2v implementation // dotxf (intrinsics) DOTXF_KER_PROT( float, s, dotxf_zen_int_8 ) From 6e2f536590401a45b73d08aa1750203c9937e5a2 Mon Sep 17 00:00:00 2001 From: Dipal M Zambare Date: Mon, 20 Dec 2021 09:43:13 +0530 Subject: [PATCH 125/243] Removed Arch specific code from BLIS framework. - Removed BLIS_CONFIG_EPYC macro - The code dependent on this macro is handled in one of the three ways -- It is updated to work across platforms. -- Added in architecture/feature specific runtime checks. -- Duplicated in AMD specific files. 
Build system is updated to pick AMD specific files when library is built for any of the zen architecture AMD-Internal: [CPUPL-1960] Change-Id: I6f9f8018e41fa48eb43ae4245c9c2c361857f43b --- CMakeLists.txt | 39 +++-- config/zen4/make_defs.mk | 23 --- frame/2/her2/bli_her2_unf_var1.c | 212 ---------------------------- frame/2/her2/bli_her2_unf_var4.c | 187 ------------------------ kernels/zen/1f/CMakeLists.txt | 3 - kernels/zen/1f/bli_axpy2v_zen_int.c | 5 +- kernels/zen/bli_kernels_zen.h | 10 -- 7 files changed, 27 insertions(+), 452 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 3affe7b40c..bcb67f2ccf 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,4 +1,4 @@ -##Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved.## +##Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.## cmake_minimum_required(VERSION 3.0.0) @@ -10,7 +10,8 @@ set(CMAKE_RUNTIME_OUTPUT_DIRECTORY "${CMAKE_SOURCE_DIR}/bin") SET(AOCL_BLIS_FAMILY "zen" CACHE STRING "AOCL BLIS family name") -SET(OMP_LIB "C:\\Program Files\\LLVM\\lib\\libomp.lib" CACHE STRING "openmp library path") +SET(OpenMP_libomp_LIBRARY "C:/Program Files/LLVM/lib/libomp.lib" CACHE STRING "openmp library +path") set(TARGET_ARCH ${AOCL_BLIS_FAMILY}) set(AOCL_BLIS_ZEN TRUE) set (PYTHON_EXE "python") @@ -99,9 +100,8 @@ option(BLIS_ENABLE_ILP64 "ENABLE BLIS ILP64" OFF) option(ENABLE_INT_TYPE_SIZE " Internal BLIS integers ,used in native BLIS interfaces based on architecture dependent " ON) option(ENABLE_BLASTEST "Enable the blastest" OFF) option(ENABLE_TESTCPP_TESTING "Enabling testcpp" OFF) -option (ENABLE_NO_UNDERSCORE_API "export APIs without underscore" ON) +option (ENABLE_NO_UNDERSCORE_API "export APIs without underscore" OFF) option (ENABLE_UPPERCASE_API "export APIs with uppercase" OFF) -option (ENABLE_API_WRAPPER "Enable wrapper code" OFF) option (ENABLE_COMPLEX_RETURN_INTEL "Enable complex_return_intel" OFF) option (ENABLE_TRSM_PREINVERSION "Enable TRSM preinversion" ON) option (ENABLE_AOCL_DYNAMIC "Enable Dynamic Multi-threading" OFF) @@ -131,10 +131,6 @@ if(ENABLE_UPPERCASE_API) add_definitions(-DBLIS_ENABLE_UPPERCASE_API) endif() -if(ENABLE_API_WRAPPER) - add_definitions(-DBLIS_ENABLE_API_WRAPPER) -endif() - if(ENABLE_AOCL_DYNAMIC) set(AOCL_DYNAMIC TRUE) endif() @@ -269,7 +265,9 @@ if(ENABLE_MULTITHREADING) find_package(OpenMP) if (OPENMP_FOUND) set(BLIS_ENABLE_OPENMP TRUE) - add_compile_options(-Xclang -fopenmp) + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") + set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${OpenMP_EXE_LINKER_FLAGS}") else() message (FATAL_ERROR "Openmp Not Found") endif() @@ -536,19 +534,34 @@ execute_process( OUTPUT_VARIABLE CMD_OUTPUT) message( STATUS "Generating monolithic header file :" ${CMD_OUTPUT}) +# Logic to generate the cblas.h in include folder. 
+set(CBLAS_H "cblas.h") +# Arguements for python script +set(C_COMMENT "-c") +set(VERBOSE "-v1") +set(INPUT "${CMAKE_SOURCE_DIR}/frame/compat/cblas/src/${CBLAS_H}") +set(OUTPUT "${CMAKE_SOURCE_DIR}/include/${TARGET_ARCH}/${CBLAS_H}") +set(TEMP_DIR "${INCLUDE}") +set(DIR_H_PATH "${HEADER_PATH}") + +# Run python script to generate monolithic header at configuration time +execute_process( + COMMAND ${PYTHON_EXE} ${FLATTEN_PY} "${C_COMMENT}" "${VERBOSE}" "${INPUT}" "${OUTPUT}" "${TEMP_DIR}" "${DIR_H_PATH}" + RESULT_VARIABLE CMD_RESULT + OUTPUT_VARIABLE CMD_OUTPUT) +message( STATUS "Generating monolithic cblas header file :" ${CMD_OUTPUT}) + # setting the blis version string file (STRINGS "version" BLIS_VERSION) set(BLIS_VERSION_STRING ${BLIS_VERSION}) add_definitions(-DBLIS_VERSION_STRING="AOCL BLIS ${BLIS_VERSION_STRING}") -message( STATUS "OPENMP Library:" ${OMP_LIB}) - if(BUILD_SHARED_LIBS) add_library("${PROJECT_NAME}" SHARED ${CMAKE_SOURCE_DIR}/bli_config.h ${CMAKE_SOURCE_DIR}/include/${TARGET_ARCH}/blis.h ${headers}) if(ENABLE_OPENMP) - target_link_libraries("${PROJECT_NAME}" PUBLIC "${OMP_LIB}") + target_link_libraries("${PROJECT_NAME}" PRIVATE OpenMP::OpenMP_CXX) endif() target_compile_definitions("${PROJECT_NAME}" PUBLIC -DBLIS_IS_BUILDING_LIBRARY) set_target_properties("${PROJECT_NAME}" PROPERTIES LINKER_LANGUAGE C OUTPUT_NAME "${LIB_NAME}") @@ -558,7 +571,7 @@ if(NOT BUILD_SHARED_LIBS) ${CMAKE_SOURCE_DIR}/include/${TARGET_ARCH}/blis.h ${headers}) if(ENABLE_OPENMP) - set_target_properties("${PROJECT_NAME}" PROPERTIES LINKER_LANGUAGE C OUTPUT_NAME "${LIB_NAME}" STATIC_LIBRARY_OPTIONS "${OMP_LIB}") + set_target_properties("${PROJECT_NAME}" PROPERTIES LINKER_LANGUAGE C OUTPUT_NAME "${LIB_NAME}" STATIC_LIBRARY_OPTIONS "${OpenMP_libomp_LIBRARY}") else() set_target_properties("${PROJECT_NAME}" PROPERTIES LINKER_LANGUAGE C OUTPUT_NAME "${LIB_NAME}") endif() diff --git a/config/zen4/make_defs.mk b/config/zen4/make_defs.mk index 1895ec95e4..44e96bb0c7 100644 --- a/config/zen4/make_defs.mk +++ b/config/zen4/make_defs.mk @@ -4,11 +4,7 @@ # An object-based framework for developing high-performance BLAS-like # libraries. # -<<<<<<< HEAD # Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. -======= -# Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. ->>>>>>> 06113811... Added support for zen4 architecture # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -53,19 +49,7 @@ THIS_CONFIG := zen4 # general-purpose/configuration-agnostic flags in common.mk. You # may specify additional flags here as needed. -<<<<<<< HEAD CPPROCFLAGS := -======= -# Since we removed BLIS_CONFIG_EPYC from header file, we need to -# add it here at two places, -# CPPROCFLAGS = This will enable it for framework code -# This flag is used when configure is invoked with specific architecture -# CKOPTFLAGS = This will enable it for architecture specific kernels -# This flag is used for kernels assocaited with this architecture -# irrespective of the configuration it is built for. - -CPPROCFLAGS := -DBLIS_CONFIG_EPYC ->>>>>>> 06113811... Added support for zen4 architecture CMISCFLAGS := CPICFLAGS := CWARNFLAGS := @@ -139,13 +123,6 @@ endif # gcc CROPTFLAGS := $(CKOPTFLAGS) CRVECFLAGS := $(CKVECFLAGS) -<<<<<<< HEAD -======= -# Add this after updating variables for reference kernels -# we don't want this defined for them -CKOPTFLAGS += -DBLIS_CONFIG_EPYC - ->>>>>>> 06113811... 
Added support for zen4 architecture # Store all of the variables here to new variables containing the # configuration name. $(eval $(call store-make-defs,$(THIS_CONFIG))) diff --git a/frame/2/her2/bli_her2_unf_var1.c b/frame/2/her2/bli_her2_unf_var1.c index 299e3d161d..a0aec48f71 100644 --- a/frame/2/her2/bli_her2_unf_var1.c +++ b/frame/2/her2/bli_her2_unf_var1.c @@ -158,217 +158,5 @@ void PASTEMAC(ch,varname) \ } \ } - -#ifdef BLIS_CONFIG_EPYC - -/** - * Following is function declaration - * that computes her2 for transposed case. - * It handles triangular part of matrix and - * remaining computation in optimal way to - * gain performance improvement. - * a is triangular matrix, x and y are vectors - */ -void bli_dher2_trans_zen_int_4 - ( - double *a, - double *x, - double *y, - double *alpha, - dim_t m, - dim_t lda - ); - -void bli_dher2_unf_var1 - ( - uplo_t uplo, - conj_t conjx, - conj_t conjy, - conj_t conjh, - dim_t m, - double* alpha, - double* x, inc_t incx, - double* y, inc_t incy, - double* c, inc_t rs_c, inc_t cs_c, - cntx_t* cntx - ) -{ - const num_t dt = PASTEMAC(d,type); - - double* x0; - double* chi1; - double* y0; - double* psi1; - double* c10t; - double* gamma11; - double alpha0; - double alpha1; - double alpha0_chi1; - double alpha1_psi1; - double alpha0_chi1_psi1; - double conjx0_chi1; - double conjy1_psi1; - double conjy0_psi1; - dim_t i; - dim_t n_behind; - inc_t rs_ct, cs_ct; - conj_t conj0, conj1; - - /* The algorithm will be expressed in terms of the lower triangular - * case;the upper triangular case is supported by swapping the row - * and column strides of A and toggling some conj parameters. - */ - if ( bli_is_lower( uplo ) ) - { - rs_ct = rs_c; - cs_ct = cs_c; - - PASTEMAC(d,copys)( *alpha, alpha0 ); - PASTEMAC(d,copycjs)( conjh, *alpha, alpha1 ); - } - else /* if ( bli_is_upper( uplo ) ) */ - { - rs_ct = cs_c; - cs_ct = rs_c; - - /* Toggle conjugation of conjx/conjy, but only if we are being - * invoked as her2; for syr2, conjx/conjy are unchanged. - */ - conjx = bli_apply_conj( conjh, conjx ); - conjy = bli_apply_conj( conjh, conjy ); - - PASTEMAC(d,copycjs)( conjh, *alpha, alpha0 ); - PASTEMAC(d,copys)( *alpha, alpha1 ); - } - - /* Apply conjh (which carries the conjugation component of the - * Hermitian transpose, if applicable) to conjx and/or conjy as - * needed to arrive at the effective conjugation for the vector - * subproblems. - */ - conj0 = bli_apply_conj( conjh, conjy ); - conj1 = bli_apply_conj( conjh, conjx ); - - PASTECH(d,axpy2v_ker_ft) kfp_2v; - - /* Query the context for the kernel function pointer. */ - kfp_2v = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPY2V_KER, cntx ); - - if( (incx == 1) && (incy == 1) && (rs_ct == 1)) - { - for ( i = 0; i < m; ) - { - n_behind = i; - x0 = x + (0 )*incx; - chi1 = x + (i )*incx; - y0 = y + (0 )*incy; - psi1 = y + (i )*incy; - c10t = c + (i )*rs_ct + (0 )*cs_ct; - gamma11 = c + (i )*rs_ct + (i )*cs_ct; - - if((n_behind >= 3)) - { - bli_dher2_trans_zen_int_4(c10t, x0, y0, &alpha0, n_behind + 1, cs_ct); - i+=4; - } - else - { - /* Apply conjx and/or conjy to chi1 and/or psi1. */ - PASTEMAC(d,copycjs)( conjx, *chi1, conjx0_chi1 ); - PASTEMAC(d,copycjs)( conjy, *psi1, conjy1_psi1 ); - PASTEMAC(d,copycjs)( conj0, *psi1, conjy0_psi1 ); - - /* Compute scalars for vector subproblems. 
*/ - PASTEMAC(d,scal2s)( alpha0, conjx0_chi1, alpha0_chi1 ); - PASTEMAC(d,scal2s)( alpha1, conjy1_psi1, alpha1_psi1 ); - - /* Compute alpha * chi1 * conj(psi1) after both chi1 - * and psi1 have already been conjugated, if needed, - * by conjx and conjy. - */ - PASTEMAC(d,scal2s)( alpha0_chi1, conjy0_psi1, - alpha0_chi1_psi1 ); - - /* c10t = c10t + alpha * chi1 * y0'; */ - /* c10t = c10t + conj(alpha) * psi1 * x0'; */ - kfp_2v - ( - conj0, - conj1, - n_behind, - &alpha0_chi1, - &alpha1_psi1, - y0, incy, - x0, incx, - c10t, cs_ct, - cntx - ); - - /* gamma11 = gamma11 + alpha * chi1 * conj(psi1) - + conj(alpha) * psi1 * conj(chi1); */ - PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); - PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); - - i+=1; - } - } - } - else - { - for ( i = 0; i < m; ++i ) - { - n_behind = i; - x0 = x + (0 )*incx; - chi1 = x + (i )*incx; - y0 = y + (0 )*incy; - psi1 = y + (i )*incy; - c10t = c + (i )*rs_ct + (0 )*cs_ct; - gamma11 = c + (i )*rs_ct + (i )*cs_ct; - - /* Apply conjx and/or conjy to chi1 and/or psi1. */ - PASTEMAC(d,copycjs)( conjx, *chi1, conjx0_chi1 ); - PASTEMAC(d,copycjs)( conjy, *psi1, conjy1_psi1 ); - PASTEMAC(d,copycjs)( conj0, *psi1, conjy0_psi1 ); - - /* Compute scalars for vector subproblems. */ - PASTEMAC(d,scal2s)( alpha0, conjx0_chi1, alpha0_chi1 ); - PASTEMAC(d,scal2s)( alpha1, conjy1_psi1, alpha1_psi1 ); - - /* Compute alpha * chi1 * conj(psi1) after both chi1 - * and psi1 have already been conjugated, if needed, - * by conjx and conjy. - */ - PASTEMAC(d,scal2s)( alpha0_chi1, conjy0_psi1, - alpha0_chi1_psi1 ); - - /* c10t = c10t + alpha * chi1 * y0'; */ - /* c10t = c10t + conj(alpha) * psi1 * x0'; */ - kfp_2v - ( - conj0, - conj1, - n_behind, - &alpha0_chi1, - &alpha1_psi1, - y0, incy, - x0, incx, - c10t, cs_ct, - cntx - ); - - /* gamma11 = gamma11 + alpha * chi1 * conj(psi1) - + conj(alpha) * psi1 * conj(chi1); */ - PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); - PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); - - } - } -} - -GENTFUNC(float, s, her2_unf_var1) -GENTFUNC(scomplex, c, her2_unf_var1) -GENTFUNC(dcomplex, z,her2_unf_var1) -#else INSERT_GENTFUNC_BASIC0( her2_unf_var1 ) -#endif diff --git a/frame/2/her2/bli_her2_unf_var4.c b/frame/2/her2/bli_her2_unf_var4.c index e39c7224c4..3dea31d53e 100644 --- a/frame/2/her2/bli_her2_unf_var4.c +++ b/frame/2/her2/bli_her2_unf_var4.c @@ -166,192 +166,5 @@ void PASTEMAC(ch,varname) \ } \ } -#ifdef BLIS_CONFIG_EPYC -/** - * Following is function declaration - * that computes her2 for transposed case. - * It handles triangular part of matrix and - * remaining computation in optimal way to - * gain performance improvement. - * a is triangular matrix, x and y are vectors - */ -void bli_dher2_zen_int_4 - ( - double *a, - double *x, - double *y, - double *alpha, - dim_t m, - dim_t lda - ); - -void bli_dher2_unf_var4 - ( - uplo_t uplo, - conj_t conjx, - conj_t conjy, - conj_t conjh, - dim_t m, - double* alpha, - double* x, inc_t incx, - double* y, inc_t incy, - double* c, inc_t rs_c, inc_t cs_c, - cntx_t* cntx - ) -{ - - double* chi1; - double* x2; - double* psi1; - double* y2; - double* gamma11; - double* c21; - double alpha0; - double alpha0_psi1; - double alpha1_chi1; - double alpha0_chi1_psi1; - dim_t i; - dim_t n_ahead; - inc_t rs_ct, cs_ct; - - const num_t dt = PASTEMAC(d,type); - - /* The algorithm will be expressed in terms of the lower triangular - * case; the upper triangular case is supported by swapping the row - * and column strides of A and toggling some conj parameters. 
- */ - if ( bli_is_lower( uplo ) ) - { - rs_ct = rs_c; - cs_ct = cs_c; - - PASTEMAC(d,copys)( *alpha, alpha0 ); - } - else /* if ( bli_is_upper( uplo ) ) */ - { - rs_ct = cs_c; - cs_ct = rs_c; - - /* Toggle conjugation of conjx/conjy, but only if we are being - * invoked as her2; for syr2, conjx/conjy are unchanged. - */ - - PASTEMAC(d,copys)( *alpha, alpha0 ); - } - /* Apply conjh (which carries the conjugation component of the - * Hermitian transpose, if applicable) to conjx and/or conjy as - * needed to arrive at the effective conjugation for the vector - * subproblems. - */ - - PASTECH(d,axpy2v_ker_ft) kfp_2v; - - /* Query the context for the kernel function pointer. */ - kfp_2v = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPY2V_KER, cntx ); - - if((incx == 1) && (incy == 1) && (rs_ct == 1)) - { - for ( i = 0; i < m; ) - { - n_ahead = m - i - 1; - chi1 = x + (i ) * incx; - x2 = x + (i+1) * incx; - psi1 = y + (i ) * incy; - y2 = y + (i+1) * incy; - gamma11 = c + (i ) + (i )*cs_ct; - c21 = c + (i+1) + (i )*cs_ct; - - if((n_ahead >= 3)) - { - bli_dher2_zen_int_4(gamma11, chi1, psi1, &alpha0, n_ahead + 1, cs_ct); - i+= 4; - } - else - { - /* Compute scalars for vector subproblems. */ - PASTEMAC(d,scal2s)( alpha0, *psi1, alpha0_psi1 ); - PASTEMAC(d,scal2s)( alpha0, *chi1, alpha1_chi1 ); - - /* Compute alpha * chi1 * conj(psi1) after both chi1 - * and psi1 have - already been conjugated, if needed, by conjx and - conjy. */ - PASTEMAC(d,scal2s)( alpha0_psi1, *chi1, - alpha0_chi1_psi1 ); - - /* c21 = c21 + alpha * x2 * conj(psi1); */ - /* c21 = c21 + conj(alpha) * y2 * conj(chi1); */ - - kfp_2v - ( - conjx, - conjy, - n_ahead, - &alpha0_psi1, - &alpha1_chi1, - x2, incx, - y2, incy, - c21, rs_ct, - cntx - ); - - - PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); - PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); - i+=1; - } - } - } - else - { - for ( i = 0; i < m; ++i) - { - n_ahead = m - i - 1; - chi1 = x + (i ) * incx; - x2 = x + (i+1) * incx; - psi1 = y + (i ) * incy; - y2 = y + (i+1) * incy; - gamma11 = c + (i ) + (i )*cs_ct; - c21 = c + (i+1) + (i )*cs_ct; - - /* Compute scalars for vector subproblems. */ - PASTEMAC(d,scal2s)( alpha0, *psi1, alpha0_psi1 ); - PASTEMAC(d,scal2s)( alpha0, *chi1, alpha1_chi1 ); - - /* Compute alpha * chi1 * conj(psi1) after both chi1 - * and psi1 have - already been conjugated, if needed, by conjx and - conjy. */ - PASTEMAC(d,scal2s)( alpha0_psi1, *chi1, - alpha0_chi1_psi1 ); - - /* c21 = c21 + alpha * x2 * conj(psi1); */ - /* c21 = c21 + conj(alpha) * y2 * conj(chi1); */ - - kfp_2v - ( - conjx, - conjy, - n_ahead, - &alpha0_psi1, - &alpha1_chi1, - x2, incx, - y2, incy, - c21, rs_ct, - cntx - ); - - - PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); - PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); - } - } -} - -GENTFUNC(float, s, her2_unf_var4) -GENTFUNC(scomplex, c, her2_unf_var4) -GENTFUNC(dcomplex, z,her2_unf_var4) -#else INSERT_GENTFUNC_BASIC0( her2_unf_var4 ) -#endif diff --git a/kernels/zen/1f/CMakeLists.txt b/kernels/zen/1f/CMakeLists.txt index b020d8c92d..3a77f69ef1 100644 --- a/kernels/zen/1f/CMakeLists.txt +++ b/kernels/zen/1f/CMakeLists.txt @@ -8,8 +8,5 @@ target_sources("${PROJECT_NAME}" ${CMAKE_CURRENT_SOURCE_DIR}/bli_axpyf_zen_int_4.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_axpyf_zen_int_6.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_axpy2v_zen_int.c -<<<<<<< HEAD ${CMAKE_CURRENT_SOURCE_DIR}/bli_dotxaxpyf_zen_int_8.c -======= ->>>>>>> 8b5b2707... 
Optimized daxpy2v implementation ) diff --git a/kernels/zen/1f/bli_axpy2v_zen_int.c b/kernels/zen/1f/bli_axpy2v_zen_int.c index 26d307eda1..cba0141376 100644 --- a/kernels/zen/1f/bli_axpy2v_zen_int.c +++ b/kernels/zen/1f/bli_axpy2v_zen_int.c @@ -186,7 +186,6 @@ void bli_daxpy2v_zen_int ); } } -<<<<<<< HEAD /** * zaxpy2v kernel performs axpy2v operation. @@ -719,6 +718,4 @@ void bli_zaxpy2v_zen_int } } AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) -} -======= ->>>>>>> 8b5b2707... Optimized daxpy2v implementation +} \ No newline at end of file diff --git a/kernels/zen/bli_kernels_zen.h b/kernels/zen/bli_kernels_zen.h index 5444c90ea8..e29ed61b2b 100644 --- a/kernels/zen/bli_kernels_zen.h +++ b/kernels/zen/bli_kernels_zen.h @@ -57,16 +57,6 @@ AXPBYV_KER_PROT( dcomplex, z, axpbyv_zen_int ) AXPBYV_KER_PROT( float, s, axpbyv_zen_int10 ) AXPBYV_KER_PROT( double, d, axpbyv_zen_int10 ) -// axpbyv (intrinsics) -AXPBYV_KER_PROT( float, s, axpbyv_zen_int ) -AXPBYV_KER_PROT( double, d, axpbyv_zen_int ) -AXPBYV_KER_PROT( scomplex, c, axpbyv_zen_int ) -AXPBYV_KER_PROT( dcomplex, z, axpbyv_zen_int ) - -// axpbyv (intrinsics, unrolled x10) -AXPBYV_KER_PROT( float, s, axpbyv_zen_int10 ) -AXPBYV_KER_PROT( double, d, axpbyv_zen_int10 ) - // axpyv (intrinsics) AXPYV_KER_PROT( float, s, axpyv_zen_int ) AXPYV_KER_PROT( double, d, axpyv_zen_int ) From e5d5a43eab8e2e3ebd375d582c1b51a9e12b93fe Mon Sep 17 00:00:00 2001 From: Arnav Sharma Date: Thu, 19 May 2022 13:24:34 +0530 Subject: [PATCH 126/243] Optimized ZHER Implementation - Implemented optimized her framework calls for double precision complex numbers. - The zher kernel operates over 4 columns at a time. Initially, it computes the diagonal elements of the matrix, then the 4x4 triangular part is computed and finally the remaining part is computed as 4x4 tiles of the matrix upto m rows. AMD-Internal: [CPUPL-2151] Change-Id: I27430ee33ffb901b3ef4bdd97b034e3f748e9cca --- frame/2/bli_l2_ker_prot.h | 15 +- frame/2/her/CMakeLists.txt | 22 +- frame/2/her/bli_her_unb_var1_amd.c | 280 +++++++ frame/2/her/bli_her_unb_var2_amd.c | 280 +++++++ kernels/zen/2/CMakeLists.txt | 13 +- kernels/zen/2/bli_her_zen_int_amd.c | 1046 +++++++++++++++++++++++++++ kernels/zen/bli_kernels_zen.h | 4 + 7 files changed, 1656 insertions(+), 4 deletions(-) create mode 100644 frame/2/her/bli_her_unb_var1_amd.c create mode 100644 frame/2/her/bli_her_unb_var2_amd.c create mode 100644 kernels/zen/2/bli_her_zen_int_amd.c diff --git a/frame/2/bli_l2_ker_prot.h b/frame/2/bli_l2_ker_prot.h index 82febd761f..5182b5d670 100644 --- a/frame/2/bli_l2_ker_prot.h +++ b/frame/2/bli_l2_ker_prot.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020-21, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020-22, Advanced Micro Devices, Inc. All rights reserved. 
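For reference, the ZHER update this patch vectorizes can be written in scalar form as below (lower-triangular, column-major, real alpha); the function name is illustrative and this is only a sketch of the operation, not the kernel added in bli_her_zen_int_amd.c. The optimized kernel covers the same region four columns at a time: the 4x4 block holding the diagonal first, then full 4x4 tiles below it down to row m.

#include <complex.h>

/* Scalar reference for the ZHER update this patch optimizes:
 *   C(i,j) += alpha * x(i) * conj(x(j))  for i >= j,  alpha real,
 * with C stored lower-triangular in column-major order.               */
static void zher_lower_ref( int m, double alpha,
                            const double complex* x,
                            double complex* c, int ldc )
{
    for ( int j = 0; j < m; ++j )
    {
        /* Diagonal element: x(j)*conj(x(j)) is real, so the imaginary
         * part of C(j,j) (assumed zero on input) remains zero.         */
        c[ j + j*ldc ] += alpha * x[j] * conj( x[j] );

        /* Strictly-lower part of column j. */
        for ( int i = j + 1; i < m; ++i )
            c[ i + j*ldc ] += alpha * x[i] * conj( x[j] );
    }
}
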
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -54,3 +54,16 @@ void PASTEMAC(ch,opname) \ cntx_t* restrict cntx \ ); +#define HER_KER_PROT( ctype, ch, opname ) \ +\ +void PASTEMAC(ch,opname) \ + ( \ + uplo_t uplo, \ + conj_t conjx, \ + conj_t conjh, \ + dim_t m, \ + ctype* restrict alpha, \ + ctype* restrict x, inc_t incx, \ + ctype* restrict c, inc_t rs_c, inc_t cs_c, \ + cntx_t* restrict cntx \ + ); \ No newline at end of file diff --git a/frame/2/her/CMakeLists.txt b/frame/2/her/CMakeLists.txt index 37b06d2a7f..b98422feef 100644 --- a/frame/2/her/CMakeLists.txt +++ b/frame/2/her/CMakeLists.txt @@ -2,8 +2,26 @@ target_sources("${PROJECT_NAME}" PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/bli_her_unb_var1.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_her_unb_var2.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_her_var_oapi.c ) +# Select AMD specific sources for AMD configurations. +if(${TARGET_ARCH} STREQUAL zen OR +${TARGET_ARCH} STREQUAL zen2 OR +${TARGET_ARCH} STREQUAL zen3 OR +${TARGET_ARCH} STREQUAL zen4 OR +${TARGET_ARCH} STREQUAL amdzen) + target_sources("${PROJECT_NAME}" + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/bli_her_unb_var1_amd.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_her_unb_var2_amd.c + ) +else() + target_sources("${PROJECT_NAME}" + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/bli_her_unb_var1.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_her_unb_var2.c + ) +endif() + + add_subdirectory(ind) \ No newline at end of file diff --git a/frame/2/her/bli_her_unb_var1_amd.c b/frame/2/her/bli_her_unb_var1_amd.c new file mode 100644 index 0000000000..13334c5418 --- /dev/null +++ b/frame/2/her/bli_her_unb_var1_amd.c @@ -0,0 +1,280 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +*/ + +#include "blis.h" + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTEMAC(ch,varname) \ + ( \ + uplo_t uplo, \ + conj_t conjx, \ + conj_t conjh, \ + dim_t m, \ + ctype* alpha, /* complex alpha allows her variants to also perform syr. */ \ + ctype* x, inc_t incx, \ + ctype* c, inc_t rs_c, inc_t cs_c, \ + cntx_t* cntx \ + ) \ +{ \ + const num_t dt = PASTEMAC(ch,type); \ + ctype* x0; \ + ctype* chi1; \ + ctype* c10t; \ + ctype* gamma11; \ + ctype alpha_local; \ + ctype alpha_chi1; \ + ctype alpha_chi1_chi1; \ + ctype conjx0_chi1; \ + ctype conjx1_chi1; \ + dim_t i; \ + dim_t n_behind; \ + inc_t rs_ct, cs_ct; \ + conj_t conj0, conj1; \ +\ + /* Eliminate unused variable warnings. */ \ + ( void )conj0; \ +\ + /* Make a local copy of alpha and zero out the imaginary component if + we are being invoked as her, since her requires alpha to be real. */ \ + PASTEMAC(ch,copys)( *alpha, alpha_local ); \ + if ( bli_is_conj( conjh ) ) \ + { \ + PASTEMAC(ch,seti0s)( alpha_local ); \ + } \ +\ + /* The algorithm will be expressed in terms of the lower triangular case; + the upper triangular case is supported by swapping the row and column + strides of A and toggling some conj parameters. */ \ + if ( bli_is_lower( uplo ) ) \ + { \ + rs_ct = rs_c; \ + cs_ct = cs_c; \ + } \ + else /* if ( bli_is_upper( uplo ) ) */ \ + { \ + rs_ct = cs_c; \ + cs_ct = rs_c; \ +\ + /* Toggle conjugation of conjx, but only if we are being invoked + as her; for syr, conjx is unchanged. */ \ + conjx = bli_apply_conj( conjh, conjx ); \ + } \ +\ + /* Apply conjh (which carries the conjugation component of the Hermitian + transpose, if applicable) to conjx as needed to arrive at the effective + conjugation for the scalar and vector subproblems. */ \ + conj0 = conjx; \ + conj1 = bli_apply_conj( conjh, conjx ); \ +\ + PASTECH(ch,axpyv_ker_ft) kfp_av; \ +\ + /* Query the context for the kernel function pointer. */ \ + kfp_av = bli_cntx_get_l1v_ker_dt( dt, BLIS_AXPYV_KER, cntx ); \ +\ + for ( i = 0; i < m; ++i ) \ + { \ + n_behind = i; \ + x0 = x + (0 )*incx; \ + chi1 = x + (i )*incx; \ + c10t = c + (i )*rs_ct + (0 )*cs_ct; \ + gamma11 = c + (i )*rs_ct + (i )*cs_ct; \ +\ + /* Apply conjx to chi1. */ \ + PASTEMAC(ch,copycjs)( conj0, *chi1, conjx0_chi1 ); \ + PASTEMAC(ch,copycjs)( conj1, *chi1, conjx1_chi1 ); \ +\ + /* Compute scalar for vector subproblem. */ \ + PASTEMAC(ch,scal2s)( alpha_local, conjx0_chi1, alpha_chi1 ); \ +\ + /* Compute alpha * chi1 * conj(chi1) after chi1 has already been + conjugated, if needed, by conjx. */ \ + PASTEMAC(ch,scal2s)( alpha_chi1, conjx1_chi1, alpha_chi1_chi1 ); \ +\ + /* c10t = c10t + alpha * chi1 * x0'; */ \ + kfp_av \ + ( \ + conj1, \ + n_behind, \ + &alpha_chi1, \ + x0, incx, \ + c10t, cs_ct, \ + cntx \ + ); \ +\ + /* gamma11 = gamma11 + alpha * chi1 * conj(chi1); */ \ + PASTEMAC(ch,adds)( alpha_chi1_chi1, *gamma11 ); \ +\ + /* For her2, explicitly set the imaginary component of gamma11 to + zero. */ \ + if ( bli_is_conj( conjh ) ) \ + PASTEMAC(ch,seti0s)( *gamma11 ); \ + } \ +} + +INSERT_GENTFUNC_BASIC0_SD( her_unb_var1 ) +GENTFUNC( scomplex, c, her_unb_var1 ) + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTEMAC(ch,varname) \ + ( \ + uplo_t uplo, \ + conj_t conjx, \ + conj_t conjh, \ + dim_t m, \ + ctype* alpha, /* complex alpha allows her variants to also perform syr. 
*/ \ + ctype* x, inc_t incx, \ + ctype* c, inc_t rs_c, inc_t cs_c, \ + cntx_t* cntx \ + ) \ +{ \ + const num_t dt = PASTEMAC(ch,type); \ + /* Redirect to intrinsic implementation of HER for dcomplex */ \ + if ( bli_cpuid_is_avx_supported() == TRUE && bli_is_conj(conjh) && incx == 1 ) \ + { \ + bli_zher_zen_int_var1 \ + ( \ + uplo, \ + conjx, \ + conjh, \ + m, \ + alpha, \ + x, \ + incx, \ + c, \ + rs_c, \ + cs_c, \ + cntx \ + ); \ + } \ + else \ + { \ + ctype* x0; \ + ctype* chi1; \ + ctype* c10t; \ + ctype* gamma11; \ + ctype alpha_local; \ + ctype alpha_chi1; \ + ctype alpha_chi1_chi1; \ + ctype conjx0_chi1; \ + ctype conjx1_chi1; \ + dim_t i; \ + dim_t n_behind; \ + inc_t rs_ct, cs_ct; \ + conj_t conj0, conj1; \ +\ + /* Eliminate unused variable warnings. */ \ + ( void )conj0; \ +\ + /* Make a local copy of alpha and zero out the imaginary component if + we are being invoked as her, since her requires alpha to be real. */ \ + PASTEMAC(ch,copys)( *alpha, alpha_local ); \ + if ( bli_is_conj( conjh ) ) \ + { \ + PASTEMAC(ch,seti0s)( alpha_local ); \ + } \ +\ + /* The algorithm will be expressed in terms of the lower triangular case; + the upper triangular case is supported by swapping the row and column + strides of A and toggling some conj parameters. */ \ + if ( bli_is_lower( uplo ) ) \ + { \ + rs_ct = rs_c; \ + cs_ct = cs_c; \ + } \ + else /* if ( bli_is_upper( uplo ) ) */ \ + { \ + rs_ct = cs_c; \ + cs_ct = rs_c; \ +\ + /* Toggle conjugation of conjx, but only if we are being invoked + as her; for syr, conjx is unchanged. */ \ + conjx = bli_apply_conj( conjh, conjx ); \ + } \ +\ + /* Apply conjh (which carries the conjugation component of the Hermitian + transpose, if applicable) to conjx as needed to arrive at the effective + conjugation for the scalar and vector subproblems. */ \ + conj0 = conjx; \ + conj1 = bli_apply_conj( conjh, conjx ); \ +\ + PASTECH(ch,axpyv_ker_ft) kfp_av; \ +\ + /* Query the context for the kernel function pointer. */ \ + kfp_av = bli_cntx_get_l1v_ker_dt( dt, BLIS_AXPYV_KER, cntx ); \ +\ + for ( i = 0; i < m; ++i ) \ + { \ + n_behind = i; \ + x0 = x + (0 )*incx; \ + chi1 = x + (i )*incx; \ + c10t = c + (i )*rs_ct + (0 )*cs_ct; \ + gamma11 = c + (i )*rs_ct + (i )*cs_ct; \ +\ + /* Apply conjx to chi1. */ \ + PASTEMAC(ch,copycjs)( conj0, *chi1, conjx0_chi1 ); \ + PASTEMAC(ch,copycjs)( conj1, *chi1, conjx1_chi1 ); \ +\ + /* Compute scalar for vector subproblem. */ \ + PASTEMAC(ch,scal2s)( alpha_local, conjx0_chi1, alpha_chi1 ); \ +\ + /* Compute alpha * chi1 * conj(chi1) after chi1 has already been + conjugated, if needed, by conjx. */ \ + PASTEMAC(ch,scal2s)( alpha_chi1, conjx1_chi1, alpha_chi1_chi1 ); \ +\ + /* c10t = c10t + alpha * chi1 * x0'; */ \ + kfp_av \ + ( \ + conj1, \ + n_behind, \ + &alpha_chi1, \ + x0, incx, \ + c10t, cs_ct, \ + cntx \ + ); \ +\ + /* gamma11 = gamma11 + alpha * chi1 * conj(chi1); */ \ + PASTEMAC(ch,adds)( alpha_chi1_chi1, *gamma11 ); \ +\ + /* For her2, explicitly set the imaginary component of gamma11 to + zero. */ \ + if ( bli_is_conj( conjh ) ) \ + PASTEMAC(ch,seti0s)( *gamma11 ); \ + } \ + } \ +} +GENTFUNC( dcomplex, z, her_unb_var1 ) \ No newline at end of file diff --git a/frame/2/her/bli_her_unb_var2_amd.c b/frame/2/her/bli_her_unb_var2_amd.c new file mode 100644 index 0000000000..6fb4a5d295 --- /dev/null +++ b/frame/2/her/bli_her_unb_var2_amd.c @@ -0,0 +1,280 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. 
All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTEMAC(ch,varname) \ + ( \ + uplo_t uplo, \ + conj_t conjx, \ + conj_t conjh, \ + dim_t m, \ + ctype* alpha, /* complex alpha allows her variants to also perform syr. */ \ + ctype* x, inc_t incx, \ + ctype* c, inc_t rs_c, inc_t cs_c, \ + cntx_t* cntx \ + ) \ +{ \ + const num_t dt = PASTEMAC(ch,type); \ + ctype* chi1; \ + ctype* x2; \ + ctype* gamma11; \ + ctype* c21; \ + ctype alpha_local; \ + ctype alpha_chi1; \ + ctype alpha_chi1_chi1; \ + ctype conjx0_chi1; \ + ctype conjx1_chi1; \ + dim_t i; \ + dim_t n_ahead; \ + inc_t rs_ct, cs_ct; \ + conj_t conj0, conj1; \ +\ + /* Eliminate unused variable warnings. */ \ + ( void )conj0; \ +\ + /* Make a local copy of alpha and zero out the imaginary component if + we are being invoked as her, since her requires alpha to be real. */ \ + PASTEMAC(ch,copys)( *alpha, alpha_local ); \ + if ( bli_is_conj( conjh ) ) \ + { \ + PASTEMAC(ch,seti0s)( alpha_local ); \ + } \ +\ + /* The algorithm will be expressed in terms of the lower triangular case; + the upper triangular case is supported by swapping the row and column + strides of A and toggling some conj parameters. */ \ + if ( bli_is_lower( uplo ) ) \ + { \ + rs_ct = rs_c; \ + cs_ct = cs_c; \ + } \ + else /* if ( bli_is_upper( uplo ) ) */ \ + { \ + rs_ct = cs_c; \ + cs_ct = rs_c; \ +\ + /* Toggle conjugation of conjx, but only if we are being invoked + as her; for syr, conjx is unchanged. */ \ + conjx = bli_apply_conj( conjh, conjx ); \ + } \ +\ + /* Apply conjh (which carries the conjugation component of the Hermitian + transpose, if applicable) to conjx as needed to arrive at the effective + conjugation for the scalar and vector subproblems. */ \ + conj0 = bli_apply_conj( conjh, conjx ); \ + conj1 = conjx; \ +\ + PASTECH(ch,axpyv_ker_ft) kfp_av; \ +\ + /* Query the context for the kernel function pointer. 
*/ \ + kfp_av = bli_cntx_get_l1v_ker_dt( dt, BLIS_AXPYV_KER, cntx ); \ +\ + for ( i = 0; i < m; ++i ) \ + { \ + n_ahead = m - i - 1; \ + chi1 = x + (i )*incx; \ + x2 = x + (i+1)*incx; \ + gamma11 = c + (i )*rs_ct + (i )*cs_ct; \ + c21 = c + (i+1)*rs_ct + (i )*cs_ct; \ +\ + /* Apply conjx to chi1. */ \ + PASTEMAC(ch,copycjs)( conj0, *chi1, conjx0_chi1 ); \ + PASTEMAC(ch,copycjs)( conj1, *chi1, conjx1_chi1 ); \ +\ + /* Compute scalar for vector subproblem. */ \ + PASTEMAC(ch,scal2s)( alpha_local, conjx0_chi1, alpha_chi1 ); \ +\ + /* Compute alpha * chi1 * conj(chi1) after chi1 has already been + conjugated, if needed, by conjx. */ \ + PASTEMAC(ch,scal2s)( alpha_chi1, conjx1_chi1, alpha_chi1_chi1 ); \ +\ + /* c21 = c21 + alpha * x2 * conj(chi1); */ \ + kfp_av \ + ( \ + conj1, \ + n_ahead, \ + &alpha_chi1, \ + x2, incx, \ + c21, rs_ct, \ + cntx \ + ); \ +\ + /* gamma11 = gamma11 + alpha * chi1 * conj(chi1); */ \ + PASTEMAC(ch,adds)( alpha_chi1_chi1, *gamma11 ); \ +\ + /* For her, explicitly set the imaginary component of gamma11 to + zero. */ \ + if ( bli_is_conj( conjh ) ) \ + PASTEMAC(ch,seti0s)( *gamma11 ); \ + } \ +} + +INSERT_GENTFUNC_BASIC0_SD( her_unb_var2 ) +GENTFUNC( scomplex, c, her_unb_var2 ) + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTEMAC(ch,varname) \ + ( \ + uplo_t uplo, \ + conj_t conjx, \ + conj_t conjh, \ + dim_t m, \ + ctype* alpha, /* complex alpha allows her variants to also perform syr. */ \ + ctype* x, inc_t incx, \ + ctype* c, inc_t rs_c, inc_t cs_c, \ + cntx_t* cntx \ + ) \ +{ \ + const num_t dt = PASTEMAC(ch,type); \ + /* Redirect to intrinsic implementation of HER for unit increment */ \ + if ( bli_cpuid_is_avx_supported() == TRUE && bli_is_conj(conjh) && incx == 1 ) \ + { \ + bli_zher_zen_int_var2 \ + ( \ + uplo, \ + conjx, \ + conjh, \ + m, \ + alpha, \ + x, \ + incx, \ + c, \ + rs_c, \ + cs_c, \ + cntx \ + ); \ + } \ + else \ + { \ + ctype* chi1; \ + ctype* x2; \ + ctype* gamma11; \ + ctype* c21; \ + ctype alpha_local; \ + ctype alpha_chi1; \ + ctype alpha_chi1_chi1; \ + ctype conjx0_chi1; \ + ctype conjx1_chi1; \ + dim_t i; \ + dim_t n_ahead; \ + inc_t rs_ct, cs_ct; \ + conj_t conj0, conj1; \ +\ + /* Eliminate unused variable warnings. */ \ + ( void )conj0; \ +\ + /* Make a local copy of alpha and zero out the imaginary component if + we are being invoked as her, since her requires alpha to be real. */ \ + PASTEMAC(ch,copys)( *alpha, alpha_local ); \ + if ( bli_is_conj( conjh ) ) \ + { \ + PASTEMAC(ch,seti0s)( alpha_local ); \ + } \ +\ + /* The algorithm will be expressed in terms of the lower triangular case; + the upper triangular case is supported by swapping the row and column + strides of A and toggling some conj parameters. */ \ + if ( bli_is_lower( uplo ) ) \ + { \ + rs_ct = rs_c; \ + cs_ct = cs_c; \ + } \ + else /* if ( bli_is_upper( uplo ) ) */ \ + { \ + rs_ct = cs_c; \ + cs_ct = rs_c; \ +\ + /* Toggle conjugation of conjx, but only if we are being invoked + as her; for syr, conjx is unchanged. */ \ + conjx = bli_apply_conj( conjh, conjx ); \ + } \ +\ + /* Apply conjh (which carries the conjugation component of the Hermitian + transpose, if applicable) to conjx as needed to arrive at the effective + conjugation for the scalar and vector subproblems. */ \ + conj0 = bli_apply_conj( conjh, conjx ); \ + conj1 = conjx; \ +\ + PASTECH(ch,axpyv_ker_ft) kfp_av; \ +\ + /* Query the context for the kernel function pointer. 
*/ \ + kfp_av = bli_cntx_get_l1v_ker_dt( dt, BLIS_AXPYV_KER, cntx ); \ +\ + for ( i = 0; i < m; ++i ) \ + { \ + n_ahead = m - i - 1; \ + chi1 = x + (i )*incx; \ + x2 = x + (i+1)*incx; \ + gamma11 = c + (i )*rs_ct + (i )*cs_ct; \ + c21 = c + (i+1)*rs_ct + (i )*cs_ct; \ +\ + /* Apply conjx to chi1. */ \ + PASTEMAC(ch,copycjs)( conj0, *chi1, conjx0_chi1 ); \ + PASTEMAC(ch,copycjs)( conj1, *chi1, conjx1_chi1 ); \ +\ + /* Compute scalar for vector subproblem. */ \ + PASTEMAC(ch,scal2s)( alpha_local, conjx0_chi1, alpha_chi1 ); \ +\ + /* Compute alpha * chi1 * conj(chi1) after chi1 has already been + conjugated, if needed, by conjx. */ \ + PASTEMAC(ch,scal2s)( alpha_chi1, conjx1_chi1, alpha_chi1_chi1 ); \ +\ + /* c21 = c21 + alpha * x2 * conj(chi1); */ \ + kfp_av \ + ( \ + conj1, \ + n_ahead, \ + &alpha_chi1, \ + x2, incx, \ + c21, rs_ct, \ + cntx \ + ); \ +\ + /* gamma11 = gamma11 + alpha * chi1 * conj(chi1); */ \ + PASTEMAC(ch,adds)( alpha_chi1_chi1, *gamma11 ); \ +\ + /* For her, explicitly set the imaginary component of gamma11 to + zero. */ \ + if ( bli_is_conj( conjh ) ) \ + PASTEMAC(ch,seti0s)( *gamma11 ); \ + } \ + } \ +} +GENTFUNC( dcomplex, z, her_unb_var2 ) \ No newline at end of file diff --git a/kernels/zen/2/CMakeLists.txt b/kernels/zen/2/CMakeLists.txt index 07a9266b0e..c1fd431f47 100644 --- a/kernels/zen/2/CMakeLists.txt +++ b/kernels/zen/2/CMakeLists.txt @@ -7,5 +7,16 @@ target_sources("${PROJECT_NAME}" ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemv_zen_int_4.c ) +# Select AMD specific sources for AMD configurations. +if(${TARGET_ARCH} STREQUAL zen OR +${TARGET_ARCH} STREQUAL zen2 OR +${TARGET_ARCH} STREQUAL zen3 OR +${TARGET_ARCH} STREQUAL zen4 OR +${TARGET_ARCH} STREQUAL amdzen) + target_sources("${PROJECT_NAME}" + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/bli_her_zen_int_amd.c + ) +endif() - + add_subdirectory(ind) \ No newline at end of file diff --git a/kernels/zen/2/bli_her_zen_int_amd.c b/kernels/zen/2/bli_her_zen_int_amd.c new file mode 100644 index 0000000000..393797f8cb --- /dev/null +++ b/kernels/zen/2/bli_her_zen_int_amd.c @@ -0,0 +1,1046 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "immintrin.h" +#include "blis.h" + +/** + * Optimized implementation of ZHER for lower triangular row stored & + * upper triangular column stored matrix. + * This kernel performs: + * A := A + conj?(alpha) * conj?(x) * conj?(x)^H + * where, + * A is an m x m hermitian matrix stored in upper/lower triangular + * x is a vector of length m + * alpha is a scalar + */ +void bli_zher_zen_int_var1 +( + uplo_t uplo, + conj_t conjx, + conj_t conjh, + dim_t m, + dcomplex* restrict alpha, + dcomplex* restrict x, inc_t incx, + dcomplex* restrict c, inc_t rs_c, inc_t cs_c, + cntx_t* restrict cntx +) +{ + double xcR, xcI; + double xhermcR, xhermcI; + double alphaR; + + dcomplex* xc; + dcomplex* xhermc; + dcomplex* cc; + + __m256d alphaRv; + __m256d ymm0, ymm1, ymm4, ymm5; + __m256d ymm6, ymm7, ymm8, ymm9, ymm10, ymm11; + __m256d ymm0_shuf, ymm1_shuf; + __m256d conj_mulv; + + dim_t conj_multiplier; + + inc_t rs_ct, cs_ct; + dim_t i = 0; + dim_t j = 0; + + alphaR = alpha->real; + + // The algorithm is expressed in terms of lower triangular case; + // the upper triangular case is supported by swapping the row and column + // strides of A & toggling the conj parameter. + if ( bli_is_lower( uplo ) ) + { + rs_ct = rs_c; + cs_ct = cs_c; + } + else /* if ( bli_is_upper( uplo ) ) */ + { + rs_ct = cs_c; + cs_ct = rs_c; + conjx = bli_apply_conj( conjh, conjx ); + } + + // Enabling conj_multiplier for scalar multiplication based on conjx + if ( !bli_is_conj(conjx) ) conj_multiplier = 1; + else conj_multiplier = -1; + + // Broadcasting real values of alpha based on conjx + // alphaRv = aR aR aR aR + if ( bli_is_conj( conjx ) ) alphaRv = _mm256_broadcast_sd( &alphaR ); + else alphaRv = _mm256_set_pd( -alphaR, alphaR, -alphaR, alphaR ); + + conj_mulv = _mm256_set_pd( conj_multiplier, -1 * conj_multiplier, conj_multiplier, -1 * conj_multiplier ); + + /********* DIAGONAL ELEMENTS *********/ + // Solving for the diagonal elements using a scalar loop + for ( i = 0; i < m; i++ ) + { + xc = x + i*incx; + xcR = xc->real; + xcI = xc->imag; + cc = c + (i)*rs_ct + (i)*cs_ct; + cc->real += alphaR * ((xcR * xcR) + (xcI * xcI)); + cc->imag = 0; + } + + // Vectorized loop + for ( i = 0; ( i + 3 ) < m; i += 4 ) + { + // Loading elements of x to ymm0-1 for computing xherm vector + // ymm0 = x0R x0I x1R x1I + // ymm1 = x2R x2I x3R x3I + ymm0 = _mm256_loadu_pd( (double*)(x + i*incx) ); + ymm1 = _mm256_loadu_pd( (double*)(x + (i + 2)*incx) ); + + // Scaling xherm vector with alpha + // alphaRv = aR aR aR aR + // ymm0 = x0R -x0I x1R -x1I + // ymm1 = x2R -x2I x3R -x3I + // ymm0 * alphaRv = aR.x0R -aR.x0I aR.x1R -aR.x1I + // ymm1 * alphaRv = aR.x2R -aR.x2I aR.x3R -aR.x3I + ymm0 = _mm256_mul_pd( ymm0, alphaRv ); + ymm1 = _mm256_mul_pd( ymm1, alphaRv ); + + // Shuffling xherm vector for multiplication with x vector + // ymm0_shuf = -x0I x0R -x1I x1R + // ymm1_shuf = -x2I x2R -x3I x3R + ymm0_shuf = _mm256_permute_pd( ymm0, 5 ); + ymm1_shuf = _mm256_permute_pd( ymm1, 5 ); + + /********* TRIANGULAR BLOCK 
*********/ + // Solving the corner elements of the triangular block + // using scalar multiplication + xc = x + (i + 1)*incx; + xcR = xc->real; + xcI = conj_multiplier * xc->imag; + + xhermc = x + (i)*incx; + xhermcR = xhermc->real; + xhermcI = -1 * conj_multiplier * xhermc->imag; + + cc = c + (i + 1)*rs_ct + (i + 0)*cs_ct; + cc->real += alphaR * ( (xcR * xhermcR) - (xcI * xhermcI) ); + cc->imag += alphaR * ( (xcR * xhermcI) + (xcI * xhermcR) ); + + xc = x + (i + 3)*incx; + xcR = xc->real; + xcI = conj_multiplier * xc->imag; + + xhermc = x + (i + 2)*incx; + xhermcR = xhermc->real; + xhermcI = -1 * conj_multiplier * xhermc->imag; + + cc = c + (i + 3)*rs_ct + (i + 2)*cs_ct; + cc->real += alphaR * ( (xcR * xhermcR) - (xcI * xhermcI) ); + cc->imag += alphaR * ( (xcR * xhermcI) + (xcI * xhermcR) ); + + // Solving the 2x2 square tile inside the triangular block + // using intrinsics + // Broadcasting elements from x to ymm4-5 + // ymm4 = x2R x2I x2R x2I + // ymm5 = x3R x3I x3R x3I + ymm4 = _mm256_broadcast_pd( (__m128d const*)( x + (i + 2)*incx ) ); + ymm5 = _mm256_broadcast_pd( (__m128d const*)( x + (i + 3)*incx ) ); + + // Loading a tile from matrix + // ymm10 = c20R c20I c21R c21I + // ymm11 = c30R c30I c31R c31I + ymm10 = _mm256_loadu_pd( (double*)( c + (i + 2)*rs_ct + (i)*cs_ct ) ); + ymm11 = _mm256_loadu_pd( (double*)( c + (i + 3)*rs_ct + (i)*cs_ct ) ); + + // Separating the real & imaginary parts of x into ymm4-7 + // ymm6 -> imag of ymm4 + // ymm4 -> real of ymm4 + ymm6 = _mm256_permute_pd( ymm4, 15 ); + ymm4 = _mm256_permute_pd( ymm4, 0 ); + ymm7 = _mm256_permute_pd( ymm5, 15 ); + ymm5 = _mm256_permute_pd( ymm5, 0 ); + + // Applying conjugate to elements of x vector + ymm6 = _mm256_mul_pd( ymm6, conj_mulv ); + ymm7 = _mm256_mul_pd( ymm7, conj_mulv ); + + // Multiplying x vector with x hermitian vector + // and adding the result to the corresponding tile + ymm8 = _mm256_mul_pd( ymm4, ymm0 ); + ymm8 = _mm256_fmadd_pd( ymm6, ymm0_shuf, ymm8 ); + ymm10 = _mm256_add_pd( ymm10, ymm8 ); + + ymm9 = _mm256_mul_pd( ymm5, ymm0 ); + ymm9 = _mm256_fmadd_pd( ymm7, ymm0_shuf, ymm9 ); + ymm11 = _mm256_add_pd( ymm11, ymm9 ); + + // Storing back the results to the matrix + _mm256_storeu_pd( (double*)( c + (i + 2)*rs_ct + (i)*cs_ct ), ymm10 ); + _mm256_storeu_pd( (double*)( c + (i + 3)*rs_ct + (i)*cs_ct ), ymm11 ); + + /********* SQUARE BLOCK *********/ + // Solving a 4x4 square block of matrix using intrinsics + for ( j = (i + 4); (j + 3) < m; j += 4) + { + // Broadcasting elements from x to ymm4-5 + ymm4 = _mm256_broadcast_pd( (__m128d const*)( x + (j )*incx ) ); + ymm5 = _mm256_broadcast_pd( (__m128d const*)( x + (j + 1)*incx ) ); + + // Loading a tile from matrix + ymm10 = _mm256_loadu_pd( (double*)( c + j*rs_ct + (i )*cs_ct ) ); + ymm11 = _mm256_loadu_pd( (double*)( c + j*rs_ct + (i + 2)*cs_ct ) ); + + // Separating the real & imaginary parts of x into ymm4-7 + // ymm6 -> imag of ymm4 + // ymm4 -> real of ymm4 + ymm6 = _mm256_permute_pd( ymm4, 15 ); + ymm4 = _mm256_permute_pd( ymm4, 0 ); + ymm7 = _mm256_permute_pd( ymm5, 15 ); + ymm5 = _mm256_permute_pd( ymm5, 0 ); + + // Applying conjugate to elements of x vector + ymm6 = _mm256_mul_pd( ymm6, conj_mulv ); + ymm7 = _mm256_mul_pd( ymm7, conj_mulv ); + + // Multiplying x vector with x hermitian vector + // and adding the result to the corresponding tile + ymm8 = _mm256_mul_pd( ymm4, ymm0 ); + ymm8 = _mm256_fmadd_pd( ymm6, ymm0_shuf, ymm8 ); + ymm10 = _mm256_add_pd( ymm10, ymm8 ); + + ymm9 = _mm256_mul_pd( ymm4, ymm1 ); + ymm9 = _mm256_fmadd_pd( ymm6, 
ymm1_shuf, ymm9 ); + ymm11 = _mm256_add_pd( ymm11, ymm9 ); + + // Storing back the results to the matrix + _mm256_storeu_pd + ( + (double*)( c + (j)*rs_ct + (i)*cs_ct ), + ymm10 + ); + _mm256_storeu_pd + ( + (double*)( c + (j)*rs_ct + (i + 2)*cs_ct ), + ymm11 + ); + + // Loading a tile from matrix + ymm10 = _mm256_loadu_pd + ( + (double*)( c + (j + 1)*rs_ct + (i)*cs_ct ) + ); + ymm11 = _mm256_loadu_pd + ( + (double*)( c + (j + 1)*rs_ct + (i + 2)*cs_ct ) + ); + + // Multiplying x vector with x hermitian vector + // and adding the result to the corresponding tile + ymm8 = _mm256_mul_pd( ymm5, ymm0 ); + ymm8 = _mm256_fmadd_pd( ymm7, ymm0_shuf, ymm8 ); + ymm10 = _mm256_add_pd( ymm10, ymm8 ); + + ymm9 = _mm256_mul_pd( ymm5, ymm1 ); + ymm9 = _mm256_fmadd_pd( ymm7, ymm1_shuf, ymm9 ); + ymm11 = _mm256_add_pd( ymm11, ymm9 ); + + // Storing back the results to the matrix + _mm256_storeu_pd + ( + (double*)( c + (j + 1)*rs_ct + (i)*cs_ct ), + ymm10 + ); + _mm256_storeu_pd + ( + (double*)( c + (j + 1)*rs_ct + (i + 2)*cs_ct ), + ymm11 + ); + + // Broadcasting elements from x to ymm4-5 + ymm4 = _mm256_broadcast_pd( (__m128d const*)( x + (j + 2)*incx ) ); + ymm5 = _mm256_broadcast_pd( (__m128d const*)( x + (j + 3)*incx ) ); + + // Loading a tile from matrix + ymm10 = _mm256_loadu_pd + ( + (double*)( c + (j + 2)*rs_ct + (i)*cs_ct ) + ); + ymm11 = _mm256_loadu_pd + ( + (double*)( c + (j + 2)*rs_ct + (i + 2)*cs_ct ) + ); + + // Separating the real & imaginary parts of x into ymm4-7 + // ymm6 -> imag of ymm4 + // ymm4 -> real of ymm4 + ymm6 = _mm256_permute_pd( ymm4, 15 ); + ymm4 = _mm256_permute_pd( ymm4, 0 ); + ymm7 = _mm256_permute_pd( ymm5, 15 ); + ymm5 = _mm256_permute_pd( ymm5, 0 ); + + // Applying conjugate to elements of x vector + ymm6 = _mm256_mul_pd( ymm6, conj_mulv ); + ymm7 = _mm256_mul_pd( ymm7, conj_mulv ); + + // Multiplying x vector with x hermitian vector + // and adding the result to the corresponding tile + ymm8 = _mm256_mul_pd( ymm4, ymm0 ); + ymm8 = _mm256_fmadd_pd( ymm6, ymm0_shuf, ymm8 ); + ymm10 = _mm256_add_pd( ymm10, ymm8 ); + + ymm9 = _mm256_mul_pd( ymm4, ymm1 ); + ymm9 = _mm256_fmadd_pd( ymm6, ymm1_shuf, ymm9 ); + ymm11 = _mm256_add_pd( ymm11, ymm9 ); + + // Storing back the results to the matrix + _mm256_storeu_pd + ( + (double*)( c + (j + 2)*rs_ct + (i)*cs_ct ), + ymm10 + ); + _mm256_storeu_pd + ( + (double*)( c + (j + 2)*rs_ct + (i + 2)*cs_ct ), + ymm11 + ); + + // Loading a tile from matrix + ymm10 = _mm256_loadu_pd + ( + (double*)( c + (j + 3)*rs_ct + (i)*cs_ct ) + ); + ymm11 = _mm256_loadu_pd + ( + (double*)( c + (j + 3)*rs_ct + (i + 2)*cs_ct ) + ); + + // Multiplying x vector with x hermitian vector + // and adding the result to the corresponding tile + ymm8 = _mm256_mul_pd( ymm5, ymm0 ); + ymm8 = _mm256_fmadd_pd( ymm7, ymm0_shuf, ymm8 ); + ymm10 = _mm256_add_pd( ymm10, ymm8 ); + + ymm9 = _mm256_mul_pd( ymm5, ymm1 ); + ymm9 = _mm256_fmadd_pd( ymm7, ymm1_shuf, ymm9 ); + ymm11 = _mm256_add_pd( ymm11, ymm9 ); + + // Storing back the results to the matrix + _mm256_storeu_pd + ( + (double*)( c + (j + 3)*rs_ct + (i)*cs_ct ), + ymm10 + ); + _mm256_storeu_pd + ( + (double*)( c + (j + 3)*rs_ct + (i + 2)*cs_ct ), + ymm11 + ); + } + + // Solving a 2x2 square block of matrix using intrinsics + for ( ; (j + 1) < m; j += 2) + { + // Broadcasting elements from x to ymm4-5 + ymm4 = _mm256_broadcast_pd( (__m128d const*)( x + (j)*incx ) ); + ymm5 = _mm256_broadcast_pd( (__m128d const*)( x + (j + 1)*incx ) ); + + // Loading a tile from matrix + ymm10 = _mm256_loadu_pd( (double*)( c + j*rs_ct 
+ (i)*cs_ct ) ); + ymm11 = _mm256_loadu_pd( (double*)( c + j*rs_ct + (i + 2)*cs_ct ) ); + + // Separating the real & imaginary parts of x into ymm4-7 + // ymm6 -> imag of ymm4 + // ymm4 -> real of ymm4 + ymm6 = _mm256_permute_pd( ymm4, 15 ); + ymm4 = _mm256_permute_pd( ymm4, 0 ); + ymm7 = _mm256_permute_pd( ymm5, 15 ); + ymm5 = _mm256_permute_pd( ymm5, 0 ); + + // Applying conjugate to elements of x vector + ymm6 = _mm256_mul_pd( ymm6, conj_mulv ); + ymm7 = _mm256_mul_pd( ymm7, conj_mulv ); + + // Multiplying x vector with x hermitian vector + // and adding the result to the corresponding tile + ymm8 = _mm256_mul_pd( ymm4, ymm0 ); + ymm8 = _mm256_fmadd_pd( ymm6, ymm0_shuf, ymm8 ); + ymm10 = _mm256_add_pd( ymm10, ymm8 ); + + ymm9 = _mm256_mul_pd( ymm4, ymm1 ); + ymm9 = _mm256_fmadd_pd( ymm6, ymm1_shuf, ymm9 ); + ymm11 = _mm256_add_pd( ymm11, ymm9 ); + + // Storing back the results to the matrix + _mm256_storeu_pd + ( + (double*)( c + (j)*rs_ct + (i)*cs_ct ), + ymm10 + ); + _mm256_storeu_pd + ( + (double*)( c + (j)*rs_ct + (i + 2)*cs_ct ), + ymm11 + ); + + // Loading a tile from matrix + ymm10 = _mm256_loadu_pd + ( + (double*)( c + (j + 1)*rs_ct + (i)*cs_ct ) + ); + ymm11 = _mm256_loadu_pd + ( + (double*)( c + (j + 1)*rs_ct + (i + 2)*cs_ct ) + ); + + // Multiplying x vector with x hermitian vector + // and adding the result to the corresponding tile + ymm8 = _mm256_mul_pd( ymm5, ymm0 ); + ymm8 = _mm256_fmadd_pd( ymm7, ymm0_shuf, ymm8 ); + ymm10 = _mm256_add_pd( ymm10, ymm8 ); + + ymm9 = _mm256_mul_pd( ymm5, ymm1 ); + ymm9 = _mm256_fmadd_pd( ymm7, ymm1_shuf, ymm9 ); + ymm11 = _mm256_add_pd( ymm11, ymm9 ); + + // Storing back the results to the matrix + _mm256_storeu_pd + ( + (double*)( c + (j + 1)*rs_ct + (i)*cs_ct ), + ymm10 + ); + _mm256_storeu_pd + ( + (double*)( c + (j + 1)*rs_ct + (i + 2)*cs_ct ), + ymm11 + ); + } + + for ( ; j < m; j++ ) + { + // Broadcasting elements from x to ymm4-5 + ymm4 = _mm256_broadcast_pd( (__m128d const*)( x + (j)*incx ) ); + + // Loading a tile from matrix + ymm10 = _mm256_loadu_pd( (double*)( c + j*rs_ct + (i)*cs_ct ) ); + ymm11 = _mm256_loadu_pd( (double*)( c + j*rs_ct + (i + 2)*cs_ct ) ); + + // Separating the real & imaginary parts of x into ymm4-7 + // ymm6 -> imag of ymm4 + // ymm4 -> real of ymm4 + ymm6 = _mm256_permute_pd( ymm4, 15 ); + ymm4 = _mm256_permute_pd( ymm4, 0 ); + ymm7 = _mm256_permute_pd( ymm5, 15 ); + ymm5 = _mm256_permute_pd( ymm5, 0 ); + + // Applying conjugate to elements of x vector + ymm6 = _mm256_mul_pd( ymm6, conj_mulv ); + ymm7 = _mm256_mul_pd( ymm7, conj_mulv ); + + // Multiplying x vector with x hermitian vector + // and adding the result to the corresponding tile + ymm8 = _mm256_mul_pd( ymm4, ymm0 ); + ymm8 = _mm256_fmadd_pd( ymm6, ymm0_shuf, ymm8 ); + ymm10 = _mm256_add_pd( ymm10, ymm8 ); + + ymm9 = _mm256_mul_pd( ymm4, ymm1 ); + ymm9 = _mm256_fmadd_pd( ymm6, ymm1_shuf, ymm9 ); + ymm11 = _mm256_add_pd( ymm11, ymm9 ); + + // Storing back the results to the matrix + _mm256_storeu_pd + ( + (double*)( c + (j)*rs_ct + (i)*cs_ct ), + ymm10 + ); + _mm256_storeu_pd + ( + (double*)( c + (j)*rs_ct + (i + 2)*cs_ct ), + ymm11 + ); + } + } + + // Solving the remaining blocks of matrix + for ( ; ( i + 1 ) < m; i += 2 ) + { + // Solving the corner elements + xc = x + (i + 1)*incx; + xcR = xc->real; + xcI = conj_multiplier * xc->imag; + + xhermc = x + i*incx; + xhermcR = xhermc->real; + xhermcI = -1 * conj_multiplier * xhermc->imag; + + cc = c + (i + 1)*rs_ct + i*cs_ct; + cc->real += alphaR * ( (xcR * xhermcR) - (xcI * xhermcI) ); + cc->imag 
+= alphaR * ( (xcR * xhermcI) + (xcI * xhermcR) ); + + // Loading elements of x to ymm0 for computing xherm vector + ymm0 = _mm256_loadu_pd( (double*)( x + i*incx ) ); + + // Scaling xherm vector with alpha + ymm0 = _mm256_mul_pd( ymm0, alphaRv ); + + // Shuffling xherm vector for multiplication with x vector + ymm0_shuf = _mm256_permute_pd( ymm0, 5 ); + + /********* SQUARE BLOCK *********/ + // Solving a 2x2 square block of matrix using intrinsics + for ( j = ( i + 2 ); j < m; j++ ) + { + // Broadcasting elements from x to ymm4 + ymm4 = _mm256_broadcast_pd( (__m128d const*)( x + (j)*incx ) ); + + // Loading a tile from matrix + ymm10 = _mm256_loadu_pd( (double*)( c + (j)*rs_ct + (i)*cs_ct ) ); + + // Separating the real & imaginary parts of x into ymm4-7 + // ymm6 -> imag of ymm4 + // ymm4 -> real of ymm4 + ymm6 = _mm256_permute_pd( ymm4, 15 ); + ymm4 = _mm256_permute_pd( ymm4, 0 ); + + // Applying conjugate to elements of x vector + ymm6 = _mm256_mul_pd( ymm6, conj_mulv ); + + // Multiplying x vector with x hermitian vector + // and adding the result to the corresponding tile + ymm8 = _mm256_mul_pd( ymm4, ymm0 ); + ymm8 = _mm256_fmadd_pd( ymm6, ymm0_shuf, ymm8 ); + ymm10 = _mm256_add_pd( ymm10, ymm8 ); + + // Storing back the results to the matrix + _mm256_storeu_pd( (double*)( c + (j)*rs_ct + (i)*cs_ct ), ymm10 ); + } + } +} + +/** + * Optimized implementation of ZHER for lower triangular column stored & + * upper triangular row stored matrix. + * This kernel performs: + * A := A + conj?(alpha) * conj?(x) * conj?(x)^H + * where, + * A is an m x m hermitian matrix stored in upper/lower triangular + * x is a vector of length m + * alpha is a scalar + */ +void bli_zher_zen_int_var2 +( + uplo_t uplo, + conj_t conjx, + conj_t conjh, + dim_t m, + dcomplex* alpha, + dcomplex* x, inc_t incx, + dcomplex* c, inc_t rs_c, inc_t cs_c, + cntx_t* cntx +) +{ + double xcR, xcI; + double xhermcR, xhermcI; + double alphaR; + + dcomplex* xc; + dcomplex* xhermc; + dcomplex* cc; + + __m256d alphaRv; + __m256d ymm0, ymm1, ymm2, ymm3, ymm4, ymm5; + __m256d ymm6, ymm7, ymm8, ymm9, ymm10, ymm11; + __m256d ymm0_shuf, ymm1_shuf, ymm2_shuf, ymm3_shuf; + + dim_t conj_multiplier; + + inc_t rs_ct, cs_ct; + dim_t i = 0; + dim_t j = 0; + + alphaR = alpha->real; + + // The algorithm is expressed in terms of lower triangular case; + // the upper triangular case is supported by swapping the row and column + // strides of A & toggling the conj parameter. 
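+    // Why the swap works (a clarifying note; same reasoning as the comment
+    // above): with swapped strides the code addresses C^T, and since alpha
+    // is real, (alpha * x * x^H)^T = alpha * conj(x) * conj(x)^H, so
+    // toggling the conjugation of x reproduces the required
+    // upper-triangular update of C.
+    // In scalar terms the kernel computes, in the lower-triangular view,
+    //   C(p,q) += alphaR * x(p) * conj(x(q))   for p > q
+    // (with x conjugated up front when conjx requests it), while each
+    // diagonal element is updated as C(p,p) += alphaR * |x(p)|^2 and its
+    // imaginary part is forced to zero.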
+ if ( bli_is_lower( uplo ) ) + { + rs_ct = rs_c; + cs_ct = cs_c; + } + else /* if ( bli_is_upper( uplo ) ) */ + { + rs_ct = cs_c; + cs_ct = rs_c; + conjx = bli_apply_conj( conjh, conjx ); + } + + // Enabling conj_multiplier for scalar multiplication based on conjx + if ( !bli_is_conj(conjx) ) conj_multiplier = 1; + else conj_multiplier = -1; + + // Broadcasting real values of alpha based on conjx + // alphaRv = aR aR aR aR + if ( bli_is_conj( conjx ) ) alphaRv = _mm256_broadcast_sd( &alphaR ); + else alphaRv = _mm256_set_pd( -alphaR, alphaR, -alphaR, alphaR ); + + __m256d conj_mulv = _mm256_set_pd( conj_multiplier, -1 * conj_multiplier, conj_multiplier, -1 * conj_multiplier ); + + /********* DIAGONAL ELEMENTS *********/ + // Solving for the diagonal elements using a scalar loop + for ( i = 0; i < m; i++ ) + { + xc = x + i*incx; + xcR = xc->real; + xcI = xc->imag; + cc = c + (i)*rs_ct + (i)*cs_ct; + cc->real += alphaR * ((xcR * xcR) + (xcI * xcI)); + cc->imag = 0; + } + + // Vectorized loop + for ( i = 0; ( i + 3 ) < m; i += 4 ) + { + // Broadcasting elements of x to ymm0-1 for computing xherm vector + // ymm0 = x0R x0I x1R x1I + ymm0 = _mm256_broadcast_pd( (__m128d const*)( x + i*incx ) ); + ymm1 = _mm256_broadcast_pd( (__m128d const*)( x + (i + 1)*incx ) ); + ymm2 = _mm256_broadcast_pd( (__m128d const*)( x + (i + 2)*incx ) ); + ymm3 = _mm256_broadcast_pd( (__m128d const*)( x + (i + 3)*incx ) ); + + // Scaling xherm vector with alpha + // alphaRv = aR aR aR aR + // ymm0 = x0R -x0I x1R -x1I + // ymm0 * alphaRv = aR.x0R -aR.x0I aR.x1R -aR.x1I + ymm0 = _mm256_mul_pd( ymm0, alphaRv ); + ymm1 = _mm256_mul_pd( ymm1, alphaRv ); + ymm2 = _mm256_mul_pd( ymm2, alphaRv ); + ymm3 = _mm256_mul_pd( ymm3, alphaRv ); + + // Shuffling xherm vector for multiplication with x vector + // ymm0_shuf = -x0I x0R -x1I x1R + ymm0_shuf = _mm256_permute_pd( ymm0, 5 ); + ymm1_shuf = _mm256_permute_pd( ymm1, 5 ); + ymm2_shuf = _mm256_permute_pd( ymm2, 5 ); + ymm3_shuf = _mm256_permute_pd( ymm3, 5 ); + + /********* TRIANGULAR BLOCK *********/ + // Solving the corner elements of the triangular block + // using scalar multiplication + xc = x + (i + 1)*incx; + xcR = xc->real; + xcI = conj_multiplier * xc->imag; + + xhermc = x + (i)*incx; + xhermcR = xhermc->real; + xhermcI = -1 * conj_multiplier * xhermc->imag; + + cc = c + (i + 1)*rs_ct + (i + 0)*cs_ct; + cc->real += alphaR * ( (xcR * xhermcR) - (xcI * xhermcI) ); + cc->imag += alphaR * ( (xcR * xhermcI) + (xcI * xhermcR) ); + + xc = x + (i + 3)*incx; + xcR = xc->real; + xcI = conj_multiplier * xc->imag; + + xhermc = x + (i + 2)*incx; + xhermcR = xhermc->real; + xhermcI = -1 * conj_multiplier * xhermc->imag; + + cc = c + (i + 3)*rs_ct + (i + 2)*cs_ct; + cc->real += alphaR * ( (xcR * xhermcR) - (xcI * xhermcI) ); + cc->imag += alphaR * ( (xcR * xhermcI) + (xcI * xhermcR) ); + + // Solving the 2x2 square tile inside the triangular block + // using intrinsics + // Loading elements from x to ymm4 + // ymm4 = x2R x2I x2R x2I + ymm4 = _mm256_loadu_pd( (double*)( x + (i + 2)*incx ) ); + + // Loading a tile from matrix + // ymm10 = c20R c20I c21R c21I + // ymm11 = c30R c30I c31R c31I + ymm10 = _mm256_loadu_pd + ( + (double*)( c + (i + 2)*rs_ct + (i)*cs_ct ) + ); + ymm11 = _mm256_loadu_pd + ( + (double*)( c + (i + 2)*rs_ct + (i + 1)*cs_ct ) + ); + + // Separating the real & imaginary parts of x into ymm4-7 + // ymm6 -> imag of ymm4 + // ymm4 -> real of ymm4 + ymm6 = _mm256_permute_pd( ymm4, 15 ); + ymm4 = _mm256_permute_pd( ymm4, 0 ); + + // Applying conjugate to elements of x 
vector + ymm6 = _mm256_mul_pd( ymm6, conj_mulv ); + + // Multiplying x vector with x hermitian vector + // and adding the result to the corresponding tile + ymm8 = _mm256_mul_pd( ymm4, ymm0 ); + ymm8 = _mm256_fmadd_pd( ymm6, ymm0_shuf, ymm8 ); + ymm10 = _mm256_add_pd( ymm10, ymm8 ); + + ymm9 = _mm256_mul_pd( ymm4, ymm1 ); + ymm9 = _mm256_fmadd_pd( ymm6, ymm1_shuf, ymm9 ); + ymm11 = _mm256_add_pd( ymm11, ymm9 ); + + // Storing back the results to the matrix + _mm256_storeu_pd + ( + (double*)( c + (i + 2)*rs_ct + (i)*cs_ct ), + ymm10 + ); + _mm256_storeu_pd + ( + (double*)( c + (i + 2)*rs_ct + (i + 1)*cs_ct ), + ymm11 + ); + + /********* SQUARE BLOCK *********/ + // Solving a 4x4 square block of matrix using intrinsics + for ( j = (i + 4); (j + 3) < m; j += 4) + { + // Loading elements from x to ymm4-5 + ymm4 = _mm256_loadu_pd( (double*)( x + j*incx ) ); + ymm5 = _mm256_loadu_pd( (double*)( x + (j + 2)*incx ) ); + + // Separating the real & imaginary parts of x into ymm4-7 + // ymm6 -> imag of ymm4 + // ymm4 -> real of ymm4 + ymm6 = _mm256_permute_pd( ymm4, 15 ); + ymm4 = _mm256_permute_pd( ymm4, 0 ); + ymm7 = _mm256_permute_pd( ymm5, 15 ); + ymm5 = _mm256_permute_pd( ymm5, 0 ); + + // Applying conjugate to elements of x vector + ymm6 = _mm256_mul_pd( ymm6, conj_mulv ); + ymm7 = _mm256_mul_pd( ymm7, conj_mulv ); + + // Loading a tile from matrix + ymm10 = _mm256_loadu_pd + ( + (double*)( c + (j)*rs_ct + (i)*cs_ct ) + ); + ymm11 = _mm256_loadu_pd + ( + (double*)( c + (j + 2)*rs_ct + (i)*cs_ct ) + ); + + // Multiplying x vector with x hermitian vector + // and adding the result to the corresponding tile + ymm8 = _mm256_mul_pd( ymm4, ymm0 ); + ymm9 = _mm256_mul_pd( ymm5, ymm0 ); + ymm8 = _mm256_fmadd_pd( ymm6, ymm0_shuf, ymm8 ); + ymm9 = _mm256_fmadd_pd( ymm7, ymm0_shuf, ymm9 ); + ymm10 = _mm256_add_pd( ymm10, ymm8 ); + ymm11 = _mm256_add_pd( ymm11, ymm9 ); + + // Storing back the results to the matrix + _mm256_storeu_pd + ( + (double*)( c + (j)*rs_ct + (i)*cs_ct ), + ymm10 + ); + _mm256_storeu_pd + ( + (double*)( c + (j + 2)*rs_ct + (i)*cs_ct ), + ymm11 + ); + + // Loading a tile from matrix + ymm10 = _mm256_loadu_pd + ( + (double*)( c + (j)*rs_ct + (i + 1)*cs_ct ) + ); + ymm11 = _mm256_loadu_pd + ( + (double*)( c + (j + 2)*rs_ct + (i + 1)*cs_ct ) + ); + + // Multiplying x vector with x hermitian vector + // and adding the result to the corresponding tile + ymm8 = _mm256_mul_pd( ymm4, ymm1 ); + ymm9 = _mm256_mul_pd( ymm5, ymm1 ); + ymm8 = _mm256_fmadd_pd( ymm6, ymm1_shuf, ymm8 ); + ymm9 = _mm256_fmadd_pd( ymm7, ymm1_shuf, ymm9 ); + ymm10 = _mm256_add_pd( ymm10, ymm8 ); + ymm11 = _mm256_add_pd( ymm11, ymm9 ); + + // Storing back the results to the matrix + _mm256_storeu_pd + ( + (double*)( c + (j)*rs_ct + (i + 1)*cs_ct ), + ymm10 + ); + _mm256_storeu_pd + ( + (double*)( c + (j + 2)*rs_ct + (i + 1)*cs_ct ), + ymm11 + ); + + // Loading a tile from matrix + ymm10 = _mm256_loadu_pd + ( + (double*)( c + (j)*rs_ct + (i + 2)*cs_ct ) + ); + ymm11 = _mm256_loadu_pd + ( + (double*)( c + (j + 2)*rs_ct + (i + 2)*cs_ct ) + ); + + // Multiplying x vector with x hermitian vector + // and adding the result to the corresponding tile + ymm8 = _mm256_mul_pd( ymm4, ymm2 ); + ymm9 = _mm256_mul_pd( ymm5, ymm2 ); + ymm8 = _mm256_fmadd_pd( ymm6, ymm2_shuf, ymm8 ); + ymm9 = _mm256_fmadd_pd( ymm7, ymm2_shuf, ymm9 ); + ymm10 = _mm256_add_pd( ymm10, ymm8 ); + ymm11 = _mm256_add_pd( ymm11, ymm9 ); + + // Storing back the results to the matrix + _mm256_storeu_pd + ( + (double*)( c + (j)*rs_ct + (i + 2)*cs_ct ), + ymm10 + ); 
+ _mm256_storeu_pd + ( + (double*)( c + (j + 2)*rs_ct + (i + 2)*cs_ct ), + ymm11 + ); + + // Loading a tile from matrix + ymm10 = _mm256_loadu_pd + ( + (double*)( c + (j)*rs_ct + (i + 3)*cs_ct ) + ); + ymm11 = _mm256_loadu_pd + ( + (double*)( c + (j + 2)*rs_ct + (i + 3)*cs_ct ) + ); + + // Multiplying x vector with x hermitian vector + // and adding the result to the corresponding tile + ymm8 = _mm256_mul_pd( ymm4, ymm3 ); + ymm9 = _mm256_mul_pd( ymm5, ymm3 ); + ymm8 = _mm256_fmadd_pd( ymm6, ymm3_shuf, ymm8 ); + ymm9 = _mm256_fmadd_pd( ymm7, ymm3_shuf, ymm9 ); + ymm10 = _mm256_add_pd( ymm10, ymm8 ); + ymm11 = _mm256_add_pd( ymm11, ymm9 ); + + // Storing back the results to the matrix + _mm256_storeu_pd + ( + (double*)( c + (j)*rs_ct + (i + 3)*cs_ct ), + ymm10 + ); + _mm256_storeu_pd + ( + (double*)( c + (j + 2)*rs_ct + (i + 3)*cs_ct ), + ymm11 + ); + } + + // Solving a 2x2 square block of matrix using intrinsics + for ( ; (j + 1) < m; j += 2) + { + // Loading elements from x to ymm4 + ymm4 = _mm256_loadu_pd( (double*)( x + j*incx ) ); + + // Separating the real & imaginary parts of x into ymm4-7 + // ymm6 -> imag of ymm4 + // ymm4 -> real of ymm4 + ymm6 = _mm256_permute_pd( ymm4, 15 ); + ymm4 = _mm256_permute_pd( ymm4, 0 ); + + // Applying conjugate to elements of x vector + ymm6 = _mm256_mul_pd( ymm6, conj_mulv ); + + // Loading a tile from matrix + ymm10 = _mm256_loadu_pd( (double*)( c + (j)*rs_ct + (i)*cs_ct ) ); + + // Multiplying x vector with x hermitian vector + // and adding the result to the corresponding tile + ymm8 = _mm256_mul_pd( ymm4, ymm0 ); + ymm8 = _mm256_fmadd_pd( ymm6, ymm0_shuf, ymm8 ); + ymm10 = _mm256_add_pd( ymm10, ymm8 ); + + // Storing back the results to the matrix + _mm256_storeu_pd( (double*)( c + (j)*rs_ct + (i)*cs_ct ), ymm10 ); + + // Loading a tile from matrix + ymm10 = _mm256_loadu_pd( (double*)( c + j*rs_ct + (i + 1)*cs_ct ) ); + + // Multiplying x vector with x hermitian vector + // and adding the result to the corresponding tile + ymm8 = _mm256_mul_pd( ymm4, ymm1 ); + ymm8 = _mm256_fmadd_pd( ymm6, ymm1_shuf, ymm8 ); + ymm10 = _mm256_add_pd( ymm10, ymm8 ); + + // Storing back the results to the matrix + _mm256_storeu_pd( (double*)( c + j*rs_ct + (i + 1)*cs_ct ), ymm10 ); + + // Loading a tile from matrix + ymm10 = _mm256_loadu_pd( (double*)( c + j*rs_ct + (i + 2)*cs_ct ) ); + + // Multiplying x vector with x hermitian vector + // and adding the result to the corresponding tile + ymm8 = _mm256_mul_pd( ymm4, ymm2 ); + ymm8 = _mm256_fmadd_pd( ymm6, ymm2_shuf, ymm8 ); + ymm10 = _mm256_add_pd( ymm10, ymm8 ); + + // Storing back the results to the matrix + _mm256_storeu_pd( (double*)( c + j*rs_ct + (i + 2)*cs_ct ), ymm10 ); + + // Loading a tile from matrix + ymm10 = _mm256_loadu_pd( (double*)( c + j*rs_ct + (i + 3)*cs_ct ) ); + + // Multiplying x vector with x hermitian vector + // and adding the result to the corresponding tile + ymm8 = _mm256_mul_pd( ymm4, ymm3 ); + ymm8 = _mm256_fmadd_pd( ymm6, ymm3_shuf, ymm8 ); + ymm10 = _mm256_add_pd( ymm10, ymm8 ); + + // Storing back the results to the matrix + _mm256_storeu_pd( (double*)( c + j*rs_ct + (i + 3)*cs_ct ), ymm10 ); + } + + // Calculating for the remaining elements using scalar code + for ( ; j < m; j++ ) + { + xc = x + j*incx; + xcR = xc->real; + xcI = conj_multiplier * xc->imag; + + xhermc = x + i*incx; + xhermcR = xhermc->real; + xhermcI = -1 * conj_multiplier * xhermc->imag; + + // c + ((alpha * x) * xherm) + cc = c + (j)*rs_ct + (i)*cs_ct; + cc->real += (alphaR * ((xcR * xhermcR) - (xcI * xhermcI))); 
+ cc->imag += (alphaR * ((xcR * xhermcI) + (xcI * xhermcR))); + + xc = x + j*incx; + xcR = xc->real; + xcI = conj_multiplier * xc->imag; + + xhermc = x + (i + 1)*incx; + xhermcR = xhermc->real; + xhermcI = -1 * conj_multiplier * xhermc->imag; + + // c + ((alpha * x) * xherm) + cc = c + (j)*rs_ct + (i + 1)*cs_ct; + cc->real += (alphaR * ((xcR * xhermcR) - (xcI * xhermcI))); + cc->imag += (alphaR * ((xcR * xhermcI) + (xcI * xhermcR))); + + xc = x + j*incx; + xcR = xc->real; + xcI = conj_multiplier * xc->imag; + + xhermc = x + (i + 2)*incx; + xhermcR = xhermc->real; + xhermcI = -1 * conj_multiplier * xhermc->imag; + + // c + ((alpha * x) * xherm) + cc = c + (j)*rs_ct + (i + 2)*cs_ct; + cc->real += (alphaR * ((xcR * xhermcR) - (xcI * xhermcI))); + cc->imag += (alphaR * ((xcR * xhermcI) + (xcI * xhermcR))); + + xc = x + j*incx; + xcR = xc->real; + xcI = conj_multiplier * xc->imag; + + xhermc = x + (i + 3)*incx; + xhermcR = xhermc->real; + xhermcI = -1 * conj_multiplier * xhermc->imag; + + // c + ((alpha * x) * xherm) + cc = c + (j)*rs_ct + (i + 3)*cs_ct; + cc->real += (alphaR * ((xcR * xhermcR) - (xcI * xhermcI))); + cc->imag += (alphaR * ((xcR * xhermcI) + (xcI * xhermcR))); + } + } + + for ( ; ( i + 1 ) < m; i += 2 ) + { + /********* TRIANGULAR BLOCK *********/ + // Solving the corner elements of the triangular block + // using scalar multiplication + xc = x + (i + 1)*incx; + xcR = xc->real; + xcI = conj_multiplier * xc->imag; + + xhermc = x + i*incx; + xhermcR = xhermc->real; + xhermcI = -1 * conj_multiplier * xhermc->imag; + + cc = c + (i + 1)*rs_ct + i*cs_ct; + cc->real += alphaR * ( (xcR * xhermcR) - (xcI * xhermcI) ); + cc->imag += alphaR * ( (xcR * xhermcI) + (xcI * xhermcR) ); + + // Solving the remaining elements in square block + // using scalar code + for ( j = (i + 2); j < m; j++ ) + { + xc = x + j*incx; + xcR = xc->real; + xcI = conj_multiplier * xc->imag; + + xhermc = x + i*incx; + xhermcR = xhermc->real; + xhermcI = -1 * conj_multiplier * xhermc->imag; + + // c + ((alpha * x) * xherm) + cc = c + (j)*rs_ct + (i)*cs_ct; + cc->real += (alphaR * ((xcR * xhermcR) - (xcI * xhermcI))); + cc->imag += (alphaR * ((xcR * xhermcI) + (xcI * xhermcR))); + + xc = x + j*incx; + xcR = xc->real; + xcI = conj_multiplier * xc->imag; + + xhermc = x + (i + 1)*incx; + xhermcR = xhermc->real; + xhermcI = -1 * conj_multiplier * xhermc->imag; + + // c + ((alpha * x) * xherm) + cc = c + (j)*rs_ct + (i + 1)*cs_ct; + cc->real += (alphaR * ((xcR * xhermcR) - (xcI * xhermcI))); + cc->imag += (alphaR * ((xcR * xhermcI) + (xcI * xhermcR))); + } + } +} \ No newline at end of file diff --git a/kernels/zen/bli_kernels_zen.h b/kernels/zen/bli_kernels_zen.h index e29ed61b2b..bd2704bbc4 100644 --- a/kernels/zen/bli_kernels_zen.h +++ b/kernels/zen/bli_kernels_zen.h @@ -140,6 +140,10 @@ GEMV_KER_PROT( double, d, gemv_zen_ref_c ) GEMV_KER_PROT( scomplex, c, gemv_zen_int_4x4 ) GEMV_KER_PROT( dcomplex, z, gemv_zen_int_4x4 ) +// her (intrinsics) +HER_KER_PROT( dcomplex, z, her_zen_int_var1 ) +HER_KER_PROT( dcomplex, z, her_zen_int_var2 ) + // -- level-3 sup -------------------------------------------------------------- // semmsup_rv From 66b2231b6524b1a91492fb176dcf4b95fa94c2eb Mon Sep 17 00:00:00 2001 From: Arnav Sharma Date: Wed, 1 Jun 2022 12:21:55 +0530 Subject: [PATCH 127/243] Fixed CMake files for HER - Removed subdirectory addition Change-Id: I419085db0b9034777409207a7d79b7ffa91eb8f1 --- frame/2/her/CMakeLists.txt | 4 +--- kernels/zen/2/CMakeLists.txt | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff 
--git a/frame/2/her/CMakeLists.txt b/frame/2/her/CMakeLists.txt index b98422feef..0bcab8ff61 100644 --- a/frame/2/her/CMakeLists.txt +++ b/frame/2/her/CMakeLists.txt @@ -22,6 +22,4 @@ else() ${CMAKE_CURRENT_SOURCE_DIR}/bli_her_unb_var1.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_her_unb_var2.c ) -endif() - - add_subdirectory(ind) \ No newline at end of file +endif() \ No newline at end of file diff --git a/kernels/zen/2/CMakeLists.txt b/kernels/zen/2/CMakeLists.txt index c1fd431f47..85ad4bfd5a 100644 --- a/kernels/zen/2/CMakeLists.txt +++ b/kernels/zen/2/CMakeLists.txt @@ -17,6 +17,4 @@ ${TARGET_ARCH} STREQUAL amdzen) PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/bli_her_zen_int_amd.c ) -endif() - - add_subdirectory(ind) \ No newline at end of file +endif() \ No newline at end of file From e61ec820f9a9d2192d96ee87c179b4fd5dfab255 Mon Sep 17 00:00:00 2001 From: "Dipal M. Zambare" Date: Thu, 2 Jun 2022 14:14:45 +0530 Subject: [PATCH 128/243] Fixed windows build issue for BLIS 4.0 - Removed extra AVX512 typedef from bli_amaxv_zen_int.c. AMD-Internal: [CPUPL-2154] Change-Id: Ieaa827c7d81b8d101f3a827827b99433570219f2 --- kernels/zen/1/bli_amaxv_zen_int.c | 9 --------- 1 file changed, 9 deletions(-) diff --git a/kernels/zen/1/bli_amaxv_zen_int.c b/kernels/zen/1/bli_amaxv_zen_int.c index 31358421e3..8487bdce4b 100644 --- a/kernels/zen/1/bli_amaxv_zen_int.c +++ b/kernels/zen/1/bli_amaxv_zen_int.c @@ -56,15 +56,6 @@ typedef union } v16sf_t; #endif -/* Union data structure to access AVX registers - One 512-bit AVX register holds 8 DP elements. */ -typedef union -{ - __m512d v; - double d[8] __attribute__((aligned(64))); -} v8df_t; - - /* Union data structure to access AVX registers One 256-bit AVX register holds 8 SP elements. */ typedef union From 8cc15107ede971e432718f58ebc8d323a9a6b4e3 Mon Sep 17 00:00:00 2001 From: Dipal M Zambare Date: Wed, 18 May 2022 11:01:41 +0530 Subject: [PATCH 129/243] Enabled AVX-512 kernels for Zen4 config - Enabled AVX-512 skylake kernels in zen4 configuration. AVX-512 kernels are added for GEMM float and double types. 
- Enabled reference kernel for TRSM native path AMD-Internal: [CPUPL-2108] Change-Id: I66f3468346085c17183cbcbf4f2c8cfe07579b6f --- config/skx/bli_cntx_init_skx.c | 4 ++-- config/skx/bli_family_skx.h | 2 -- config/zen4/bli_cntx_init_zen4.c | 26 ++++++++++++++++---------- config/zen4/bli_family_zen4.h | 10 ++++++++-- config/zen4/make_defs.mk | 25 ++++++++++++++++++------- config_registry | 2 +- frame/include/bli_arch_config.h | 5 ++++- 7 files changed, 49 insertions(+), 25 deletions(-) diff --git a/config/skx/bli_cntx_init_skx.c b/config/skx/bli_cntx_init_skx.c index c14311bf21..f18503a7a7 100644 --- a/config/skx/bli_cntx_init_skx.c +++ b/config/skx/bli_cntx_init_skx.c @@ -73,8 +73,8 @@ void bli_cntx_init_skx( cntx_t* cntx ) 10, #if 1 // amaxv - BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int_avx512, - BLIS_AMAXV_KER, BLIS_DOUBLE, bli_damaxv_zen_int_avx512, + BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int, + BLIS_AMAXV_KER, BLIS_DOUBLE, bli_damaxv_zen_int, #endif // axpyv #if 0 diff --git a/config/skx/bli_family_skx.h b/config/skx/bli_family_skx.h index cbba06358e..ac9478f8ba 100644 --- a/config/skx/bli_family_skx.h +++ b/config/skx/bli_family_skx.h @@ -50,8 +50,6 @@ #define BLIS_SIMD_SIZE 64 #define BLIS_SIMD_NUM_REGISTERS 32 -#define AVX512 - //#include //#define BLIS_MALLOC_POOL malloc diff --git a/config/zen4/bli_cntx_init_zen4.c b/config/zen4/bli_cntx_init_zen4.c index c340fa9087..e25ceabc8f 100644 --- a/config/zen4/bli_cntx_init_zen4.c +++ b/config/zen4/bli_cntx_init_zen4.c @@ -47,18 +47,20 @@ void bli_cntx_init_zen4( cntx_t* cntx ) // their storage preferences. bli_cntx_set_l3_nat_ukrs ( - 8, + 4, // gemm - BLIS_GEMM_UKR, BLIS_FLOAT, bli_sgemm_haswell_asm_6x16, TRUE, - BLIS_GEMM_UKR, BLIS_DOUBLE, bli_dgemm_haswell_asm_6x8, TRUE, + BLIS_GEMM_UKR, BLIS_FLOAT , bli_sgemm_skx_asm_32x12_l2, FALSE, + BLIS_GEMM_UKR, BLIS_DOUBLE, bli_dgemm_skx_asm_16x14, FALSE, BLIS_GEMM_UKR, BLIS_SCOMPLEX, bli_cgemm_haswell_asm_3x8, TRUE, BLIS_GEMM_UKR, BLIS_DCOMPLEX, bli_zgemm_haswell_asm_3x4, TRUE, +#if 0 // GENOA TODO: TRSM AVX-512 implementation // gemmtrsm_l BLIS_GEMMTRSM_L_UKR, BLIS_FLOAT, bli_sgemmtrsm_l_haswell_asm_6x16, TRUE, BLIS_GEMMTRSM_L_UKR, BLIS_DOUBLE, bli_dgemmtrsm_l_haswell_asm_6x8, TRUE, // gemmtrsm_u BLIS_GEMMTRSM_U_UKR, BLIS_FLOAT, bli_sgemmtrsm_u_haswell_asm_6x16, TRUE, BLIS_GEMMTRSM_U_UKR, BLIS_DOUBLE, bli_dgemmtrsm_u_haswell_asm_6x8, TRUE, +#endif cntx ); @@ -160,14 +162,16 @@ void bli_cntx_init_zen4( cntx_t* cntx ) // // These are reference block sizes and may be overridden based on // number of threads used at runtime. 
- // s d c z - bli_blksz_init_easy( &blkszs[ BLIS_MR ], 6, 6, 3, 3 ); - bli_blksz_init_easy( &blkszs[ BLIS_NR ], 16, 8, 8, 4 ); - bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 72, 144, 18 ); - bli_blksz_init_easy( &blkszs[ BLIS_KC ], 256, 256, 256, 566 ); - bli_blksz_init_easy( &blkszs[ BLIS_NC ], 4080, 4080, 4080, 256 ); - bli_blksz_init_easy( &blkszs[ BLIS_AF ], 5, 5, -1, -1 ); + // s d c z + bli_blksz_init_easy( &blkszs[ BLIS_MR ], 32, 16, 3, 3 ); + bli_blksz_init_easy( &blkszs[ BLIS_NR ], 12, 14, 8, 4 ); + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 480, 240, 144, 18 ); + bli_blksz_init ( &blkszs[ BLIS_KC ], 384, 256, 256, 566, + 480, 320, 256, 566 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 3072, 3752, 4080, 256 ); + + bli_blksz_init_easy( &blkszs[ BLIS_AF ], 8, 8, -1, -1 ); bli_blksz_init_easy( &blkszs[ BLIS_DF ], 8, 8, -1, -1 ); // Update the context with the current architecture's register and cache @@ -188,6 +192,7 @@ void bli_cntx_init_zen4( cntx_t* cntx ) ); // ------------------------------------------------------------------------- +#if 0 // GENOA TODO: TRSM AVX-512 implementation //Initialize TRSM blocksize objects with architecture-specific values. //Using different cache block sizes for TRSM instead of common level-3 block sizes. //Tuning is done for double-precision only. @@ -208,6 +213,7 @@ void bli_cntx_init_zen4( cntx_t* cntx ) BLIS_MR, &blkszs[ BLIS_MR ], cntx ); +#endif // Initialize sup thresholds with architecture-appropriate values. s d c z bli_blksz_init_easy( &thresh[ BLIS_MT ], 512, 256, 380, 110 ); diff --git a/config/zen4/bli_family_zen4.h b/config/zen4/bli_family_zen4.h index 9c70fcef83..71929cdac4 100644 --- a/config/zen4/bli_family_zen4.h +++ b/config/zen4/bli_family_zen4.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021-2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -39,7 +39,6 @@ // Setting these macros to 1 will force JR and IR inner loops // to be not paralleized. // - #define BLIS_THREAD_MAX_IR 1 #define BLIS_THREAD_MAX_JR 1 @@ -56,4 +55,11 @@ //#define BLIS_ENABLE_FAST_MATH +// -- SIMD config -------------------------------------------------------- + +#define BLIS_SIMD_ALIGN_SIZE 64 + +#define BLIS_SIMD_SIZE 64 +#define BLIS_SIMD_NUM_REGISTERS 32 + #endif diff --git a/config/zen4/make_defs.mk b/config/zen4/make_defs.mk index 44e96bb0c7..85a8a39f62 100644 --- a/config/zen4/make_defs.mk +++ b/config/zen4/make_defs.mk @@ -32,7 +32,7 @@ # # -# FLAGS that are specific to the 'zen3' architecture are added here. +# FLAGS that are specific to the 'zen4' architecture are added here. # FLAGS that are common for all the AMD architectures are present in # config/zen/amd_config.mk. @@ -73,15 +73,17 @@ GCC_VERSION := $(strip $(shell $(CC) -dumpversion | cut -d. 
-f1)) # gcc or clang version must be atleast 4.0 # gcc 9.0 or later: ifeq ($(shell test $(GCC_VERSION) -ge 11; echo $$?),0) -CKVECFLAGS += -march=znver3 +CKVECFLAGS += -march=znver3 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mfpmath=sse +CRVECFLAGS += -march=znver3 else ifeq ($(shell test $(GCC_VERSION) -ge 9; echo $$?),0) -CKVECFLAGS += -march=znver2 +CKVECFLAGS += -march=znver2 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mfpmath=sse +CRVECFLAGS += -march=znver2 else # If gcc is older than 9.1.0 but at least 6.1.0, then we can use -march=znver1 # as the fallback option. -CRVECFLAGS += -march=znver1 -mno-avx256-split-unaligned-store CKVECFLAGS += -march=znver1 -mno-avx256-split-unaligned-store +CRVECFLAGS += -march=znver1 -mno-avx256-split-unaligned-store endif # GCC 9 endif # GCC 11 else @@ -99,11 +101,13 @@ ifeq ($(CC_VENDOR),clang) # for version 3x we will enable znver3 ifeq ($(strip $(shell $(CC) -v |&head -1 |grep -c 'AOCC_3')),1) -CKVECFLAGS += -march=znver3 +CKVECFLAGS += -march=znver3 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mfpmath=sse +CRVECFLAGS += -march=znver3 else # for version 2x we will enable znver2 ifeq ($(strip $(shell $(CC) -v |&head -1 |grep -c 'AOCC.LLVM.2\|AOCC_2')),1) -CKVECFLAGS += -march=znver2 +CKVECFLAGS += -march=znver2 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mfpmath=sse +CRVECFLAGS += -march=znver2 else #if compiling with clang VENDOR_STRING := $(strip $(shell ${CC_VENDOR} --version | egrep -o '[0-9]+\.[0-9]+\.?[0-9]*')) @@ -111,8 +115,10 @@ CC_MAJOR := $(shell (echo ${VENDOR_STRING} | cut -d. -f1)) #clang 9.0 or later: ifeq ($(shell test $(CC_MAJOR) -ge 9; echo $$?),0) CKVECFLAGS += -march=znver2 +CRVECFLAGS += -march=znver2 else CKVECFLAGS += -march=znver1 +CRVECFLAGS += -march=znver1 endif # ge 9 endif # aocc 2 endif # aocc 3 @@ -121,7 +127,12 @@ endif # gcc # Flags specific to reference kernels. CROPTFLAGS := $(CKOPTFLAGS) -CRVECFLAGS := $(CKVECFLAGS) + +# Flags specific to reference kernels. +# Note: We use AVX2 for reference kernels because, as Jeff Hammond says, +# reference kernel code "is not going to achieve high enough SIMD utilization +# to overcome the AVX-512 frequency drop". (Issue #187) +CRVECFLAGS += -mno-avx512f -mno-avx512vl -mno-avx512bw -mno-avx512dq -mno-avx512cd -funsafe-math-optimizations -ffp-contract=fast # Store all of the variables here to new variables containing the # configuration name. diff --git a/config_registry b/config_registry index 822b133f5c..4e6716dfa1 100644 --- a/config_registry +++ b/config_registry @@ -26,7 +26,7 @@ sandybridge: sandybridge penryn: penryn # AMD architectures. -zen4: zen4/zen4/zen3/zen2/zen/haswell +zen4: zen4/zen4/skx/zen3/zen2/zen/haswell zen3: zen3/zen3/zen2/zen/haswell zen2: zen2/zen2/zen/haswell zen: zen/zen/haswell diff --git a/frame/include/bli_arch_config.h b/frame/include/bli_arch_config.h index 3e2e0b022b..6343c6ba89 100644 --- a/frame/include/bli_arch_config.h +++ b/frame/include/bli_arch_config.h @@ -6,7 +6,7 @@ Copyright (C) 2014, The University of Texas at Austin Copyright (C) 2016, Hewlett Packard Enterprise Development LP - Copyright (C) 2019 - 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2022, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -171,6 +171,9 @@ CNTX_INIT_PROTS( generic ) // -- AMD64 architectures -- +#ifdef BLIS_FAMILY_ZEN4 +#include "bli_family_zen4.h" +#endif #ifdef BLIS_FAMILY_ZEN3 #include "bli_family_zen3.h" #endif From c87b9aab75dadb260a84ab8443ba3353d1008fb2 Mon Sep 17 00:00:00 2001 From: Dipal M Zambare Date: Mon, 6 Jun 2022 17:13:18 +0530 Subject: [PATCH 130/243] Added support for AVX512 for Windows and AMAVX - Completed zen4 configuration support on windows - Enabled AVX512 kernels for AMAXV - Added zen4 configuration in amdzen for windows - Moved all zen4 kernels inside kernels/zen4 folder AMD-Internal: [CPUPL-2108] Change-Id: I9d2336998bbcdb8e2c4ca474977b5939bfa578ba --- CMakeLists.txt | 13 +- config/zen4/bli_cntx_init_zen4.c | 4 +- frame/2/gemv/CMakeLists.txt | 1 + frame/2/hemv/CMakeLists.txt | 3 +- frame/2/her2/CMakeLists.txt | 3 +- frame/2/trsv/CMakeLists.txt | 3 +- frame/3/CMakeLists.txt | 3 +- frame/3/gemm/CMakeLists.txt | 3 +- frame/3/trmm/CMakeLists.txt | 1 + frame/compat/CMakeLists.txt | 3 +- frame/compat/bla_amax_amd.c | 24 +- frame/include/bli_arch_config.h | 4 +- kernels/CMakeLists.txt | 8 +- kernels/skx/3/CMakeLists.txt | 7 + kernels/skx/CMakeLists.txt | 4 + kernels/zen/1/bli_amaxv_zen_int.c | 1005 +-------------------- kernels/zen/bli_kernels_zen.h | 2 - kernels/zen4/1/CMakeLists.txt | 6 + kernels/zen4/1/bli_amaxv_zen_int_avx512.c | 975 ++++++++++++++++++++ kernels/zen4/CMakeLists.txt | 5 + kernels/zen4/bli_kernels_zen4.h | 39 + 21 files changed, 1089 insertions(+), 1027 deletions(-) create mode 100644 kernels/skx/3/CMakeLists.txt create mode 100644 kernels/skx/CMakeLists.txt create mode 100644 kernels/zen4/1/CMakeLists.txt create mode 100644 kernels/zen4/1/bli_amaxv_zen_int_avx512.c create mode 100644 kernels/zen4/CMakeLists.txt create mode 100644 kernels/zen4/bli_kernels_zen4.h diff --git a/CMakeLists.txt b/CMakeLists.txt index bcb67f2ccf..ec11b44a6f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -17,7 +17,7 @@ set(AOCL_BLIS_ZEN TRUE) set (PYTHON_EXE "python") if ("${AOCL_BLIS_FAMILY}" STREQUAL "") - message(FATAL_ERROR "Machine configuration missing! Select one of zen, zen2, zen3 or amdzen") + message(FATAL_ERROR "Machine configuration missing! 
Select one of zen, zen2, zen3, zen4 or amdzen") endif () if (${AOCL_BLIS_FAMILY} STREQUAL "auto") @@ -53,6 +53,7 @@ elseif (${AOCL_BLIS_FAMILY} STREQUAL "zen3") elseif (${AOCL_BLIS_FAMILY} STREQUAL "zen4") add_definitions(-DBLIS_FAMILY_ZEN4) add_definitions(-DBLIS_CONFIG_ZEN4) + add_definitions(-DBLIS_KERNELS_SKX) add_definitions(-DBLIS_KERNELS_ZEN4) add_definitions(-DBLIS_KERNELS_ZEN3) add_definitions(-DBLIS_KERNELS_ZEN2) @@ -66,6 +67,7 @@ elseif (${AOCL_BLIS_FAMILY} STREQUAL "amdzen") add_definitions(-DBLIS_CONFIG_ZEN2) add_definitions(-DBLIS_CONFIG_ZEN) add_definitions(-DBLIS_CONFIG_GENERIC) + add_definitions(-DBLIS_KERNELS_SKX) add_definitions(-DBLIS_KERNELS_ZEN4) add_definitions(-DBLIS_KERNELS_ZEN3) add_definitions(-DBLIS_KERNELS_ZEN2) @@ -314,6 +316,11 @@ elseif(${ENABLE_SIMD_FLAGS} MATCHES "SSE2") add_definitions(/arch:SSE2) endif() +if(${TARGET_ARCH} STREQUAL zen4 OR + ${TARGET_ARCH} STREQUAL amdzen) + add_definitions(/arch:AVX512) +endif() + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /W0 ") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /Oi") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /MP") @@ -422,7 +429,9 @@ include_directories(${CMAKE_SOURCE_DIR}/kernels/zen/2) include_directories(${CMAKE_SOURCE_DIR}/kernels/zen/3) include_directories(${CMAKE_SOURCE_DIR}/kernels/zen/3/sup) include_directories(${CMAKE_SOURCE_DIR}/kernels/zen2) - +include_directories(${CMAKE_SOURCE_DIR}/kernels/zen4) +include_directories(${CMAKE_SOURCE_DIR}/kernels/skx) +include_directories(${CMAKE_SOURCE_DIR}/kernels/skx/3) file(GLOB headers ${CMAKE_SOURCE_DIR}/*.h) # Monolithic Header generation diff --git a/config/zen4/bli_cntx_init_zen4.c b/config/zen4/bli_cntx_init_zen4.c index e25ceabc8f..0f589f301c 100644 --- a/config/zen4/bli_cntx_init_zen4.c +++ b/config/zen4/bli_cntx_init_zen4.c @@ -115,8 +115,8 @@ void bli_cntx_init_zen4( cntx_t* cntx ) 24, // amaxv - BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int, - BLIS_AMAXV_KER, BLIS_DOUBLE, bli_damaxv_zen_int, + BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int_avx512, + BLIS_AMAXV_KER, BLIS_DOUBLE, bli_damaxv_zen_int_avx512, // axpbyv BLIS_AXPBYV_KER, BLIS_FLOAT, bli_saxpbyv_zen_int10, diff --git a/frame/2/gemv/CMakeLists.txt b/frame/2/gemv/CMakeLists.txt index 2f75a00f63..633ec9431a 100644 --- a/frame/2/gemv/CMakeLists.txt +++ b/frame/2/gemv/CMakeLists.txt @@ -12,6 +12,7 @@ target_sources("${PROJECT_NAME}" if(${TARGET_ARCH} STREQUAL zen OR ${TARGET_ARCH} STREQUAL zen2 OR ${TARGET_ARCH} STREQUAL zen3 OR + ${TARGET_ARCH} STREQUAL zen4 OR ${TARGET_ARCH} STREQUAL amdzen) target_sources("${PROJECT_NAME}" PRIVATE diff --git a/frame/2/hemv/CMakeLists.txt b/frame/2/hemv/CMakeLists.txt index 34820c3762..10e324b52d 100644 --- a/frame/2/hemv/CMakeLists.txt +++ b/frame/2/hemv/CMakeLists.txt @@ -14,7 +14,8 @@ target_sources("${PROJECT_NAME}" # Select AMD specific sources for AMD configurations. if(${TARGET_ARCH} STREQUAL zen OR ${TARGET_ARCH} STREQUAL zen2 OR - ${TARGET_ARCH} STREQUAL zen3 OR + ${TARGET_ARCH} STREQUAL zen3 OR + ${TARGET_ARCH} STREQUAL zen4 OR ${TARGET_ARCH} STREQUAL amdzen) target_sources("${PROJECT_NAME}" PRIVATE diff --git a/frame/2/her2/CMakeLists.txt b/frame/2/her2/CMakeLists.txt index 83629df8f5..cfdeb2480d 100644 --- a/frame/2/her2/CMakeLists.txt +++ b/frame/2/her2/CMakeLists.txt @@ -12,7 +12,8 @@ target_sources("${PROJECT_NAME}" # Select AMD specific sources for AMD configurations. 
if(${TARGET_ARCH} STREQUAL zen OR ${TARGET_ARCH} STREQUAL zen2 OR - ${TARGET_ARCH} STREQUAL zen3 OR + ${TARGET_ARCH} STREQUAL zen3 OR + ${TARGET_ARCH} STREQUAL zen4 OR ${TARGET_ARCH} STREQUAL amdzen) target_sources("${PROJECT_NAME}" PRIVATE diff --git a/frame/2/trsv/CMakeLists.txt b/frame/2/trsv/CMakeLists.txt index b07389340e..f1aacc745c 100644 --- a/frame/2/trsv/CMakeLists.txt +++ b/frame/2/trsv/CMakeLists.txt @@ -10,7 +10,8 @@ target_sources("${PROJECT_NAME}" # Select AMD specific sources for AMD configurations. if(${TARGET_ARCH} STREQUAL zen OR ${TARGET_ARCH} STREQUAL zen2 OR - ${TARGET_ARCH} STREQUAL zen3 OR + ${TARGET_ARCH} STREQUAL zen3 OR + ${TARGET_ARCH} STREQUAL zen4 OR ${TARGET_ARCH} STREQUAL amdzen) target_sources("${PROJECT_NAME}" PRIVATE diff --git a/frame/3/CMakeLists.txt b/frame/3/CMakeLists.txt index e9d7da7b8e..734622344a 100644 --- a/frame/3/CMakeLists.txt +++ b/frame/3/CMakeLists.txt @@ -30,7 +30,8 @@ target_sources("${PROJECT_NAME}" # Select AMD specific sources for AMD configurations. if(${TARGET_ARCH} STREQUAL zen OR ${TARGET_ARCH} STREQUAL zen2 OR - ${TARGET_ARCH} STREQUAL zen3 OR + ${TARGET_ARCH} STREQUAL zen3 OR + ${TARGET_ARCH} STREQUAL zen4 OR ${TARGET_ARCH} STREQUAL amdzen) target_sources("${PROJECT_NAME}" PRIVATE diff --git a/frame/3/gemm/CMakeLists.txt b/frame/3/gemm/CMakeLists.txt index 825dd745ca..8969680031 100644 --- a/frame/3/gemm/CMakeLists.txt +++ b/frame/3/gemm/CMakeLists.txt @@ -18,7 +18,8 @@ target_sources("${PROJECT_NAME}" # Select AMD specific sources for AMD configurations. if(${TARGET_ARCH} STREQUAL zen OR ${TARGET_ARCH} STREQUAL zen2 OR -${TARGET_ARCH} STREQUAL zen3 OR +${TARGET_ARCH} STREQUAL zen3 OR +${TARGET_ARCH} STREQUAL zen4 OR ${TARGET_ARCH} STREQUAL amdzen) target_sources("${PROJECT_NAME}" PRIVATE diff --git a/frame/3/trmm/CMakeLists.txt b/frame/3/trmm/CMakeLists.txt index a3845f3858..49106e4b10 100644 --- a/frame/3/trmm/CMakeLists.txt +++ b/frame/3/trmm/CMakeLists.txt @@ -12,6 +12,7 @@ target_sources("${PROJECT_NAME}" if(${TARGET_ARCH} STREQUAL zen OR ${TARGET_ARCH} STREQUAL zen2 OR ${TARGET_ARCH} STREQUAL zen3 OR +${TARGET_ARCH} STREQUAL zen4 OR ${TARGET_ARCH} STREQUAL amdzen) target_sources("${PROJECT_NAME}" PRIVATE diff --git a/frame/compat/CMakeLists.txt b/frame/compat/CMakeLists.txt index 48b66acbcb..bfe8e10508 100644 --- a/frame/compat/CMakeLists.txt +++ b/frame/compat/CMakeLists.txt @@ -35,7 +35,8 @@ ${CMAKE_CURRENT_SOURCE_DIR}/bla_omatadd.c # Select AMD specific sources for AMD configurations. if(${TARGET_ARCH} STREQUAL zen OR ${TARGET_ARCH} STREQUAL zen2 OR -${TARGET_ARCH} STREQUAL zen3 OR +${TARGET_ARCH} STREQUAL zen3 OR +${TARGET_ARCH} STREQUAL zen4 OR ${TARGET_ARCH} STREQUAL amdzen) target_sources("${PROJECT_NAME}" PRIVATE diff --git a/frame/compat/bla_amax_amd.c b/frame/compat/bla_amax_amd.c index 7f1a771f7c..2f7c2d2491 100644 --- a/frame/compat/bla_amax_amd.c +++ b/frame/compat/bla_amax_amd.c @@ -162,13 +162,15 @@ f77_int isamax_ // Non-AVX platforms will use the kernels derived from the context. if (bli_cpuid_is_avx_supported() == TRUE) { + cntx_t* cntx = bli_gks_query_cntx(); + samaxv_ker_ft f = bli_cntx_get_l1v_ker_dt(BLIS_FLOAT, BLIS_AMAXV_KER, cntx ); /* Call BLIS kernel */ - bli_samaxv_zen_int + f ( - n0, - x0, incx0, - &bli_index, - NULL + n0, + x0, incx0, + &bli_index, + NULL ); } else @@ -258,13 +260,15 @@ f77_int idamax_ // Non-AVX platforms will use the kernels derived from the context. 
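Both wrappers now resolve the amaxv kernel through the context instead of calling bli_samaxv_zen_int / bli_damaxv_zen_int directly, so zen4 builds automatically pick up the AVX-512 kernels registered in bli_cntx_init_zen4.c; the idamax_ hunk below follows the same pattern. A minimal sketch of the lookup, using only calls that appear in this hunk (the helper itself is illustrative, not part of the patch):

#include "blis.h"

// Illustrative helper: fetch and invoke the double-precision AMAXV kernel
// registered for the detected architecture. On zen4 the context table
// registers bli_damaxv_zen_int_avx512; older zen configurations keep
// bli_damaxv_zen_int.
static void damaxv_via_cntx( dim_t n, double* x, inc_t incx, dim_t* index )
{
    cntx_t*       cntx = bli_gks_query_cntx();
    damaxv_ker_ft f    = bli_cntx_get_l1v_ker_dt( BLIS_DOUBLE,
                                                  BLIS_AMAXV_KER, cntx );
    // The wrappers only take this path after bli_cpuid_is_avx_supported()
    // returns TRUE; the context argument is unused by these kernels.
    f( n, x, incx, index, NULL );
}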
if (bli_cpuid_is_avx_supported() == TRUE) { + cntx_t* cntx = bli_gks_query_cntx(); + damaxv_ker_ft f = bli_cntx_get_l1v_ker_dt(BLIS_DOUBLE, BLIS_AMAXV_KER, cntx ); /* Call BLIS kernel */ - bli_damaxv_zen_int + f ( - n0, - x0, incx0, - &bli_index, - NULL + n0, + x0, incx0, + &bli_index, + NULL ); } else diff --git a/frame/include/bli_arch_config.h b/frame/include/bli_arch_config.h index 6343c6ba89..787e3879b8 100644 --- a/frame/include/bli_arch_config.h +++ b/frame/include/bli_arch_config.h @@ -264,7 +264,9 @@ CNTX_INIT_PROTS( generic ) #endif // -- AMD64 architectures -- - +#ifdef BLIS_KERNELS_ZEN4 +#include "bli_kernels_zen4.h" +#endif #ifdef BLIS_KERNELS_ZEN2 #include "bli_kernels_zen2.h" #endif diff --git a/kernels/CMakeLists.txt b/kernels/CMakeLists.txt index 5cf469ef18..bee82f8685 100644 --- a/kernels/CMakeLists.txt +++ b/kernels/CMakeLists.txt @@ -1,4 +1,10 @@ -##Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.## +##Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.## add_subdirectory(haswell) add_subdirectory(zen) + +if(${TARGET_ARCH} STREQUAL zen4 OR + ${TARGET_ARCH} STREQUAL amdzen) + add_subdirectory(skx) + add_subdirectory(zen4) +endif() \ No newline at end of file diff --git a/kernels/skx/3/CMakeLists.txt b/kernels/skx/3/CMakeLists.txt new file mode 100644 index 0000000000..30857ba975 --- /dev/null +++ b/kernels/skx/3/CMakeLists.txt @@ -0,0 +1,7 @@ +##Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.## + +target_sources("${PROJECT_NAME}" + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/bli_dgemm_skx_asm_16x14.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_sgemm_skx_asm_32x12_l2.c + ) diff --git a/kernels/skx/CMakeLists.txt b/kernels/skx/CMakeLists.txt new file mode 100644 index 0000000000..bc8f1eaab3 --- /dev/null +++ b/kernels/skx/CMakeLists.txt @@ -0,0 +1,4 @@ +##Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.## + +add_subdirectory(3) + diff --git a/kernels/zen/1/bli_amaxv_zen_int.c b/kernels/zen/1/bli_amaxv_zen_int.c index 8487bdce4b..7f799fa628 100644 --- a/kernels/zen/1/bli_amaxv_zen_int.c +++ b/kernels/zen/1/bli_amaxv_zen_int.c @@ -4,8 +4,8 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2016 - 2021, Advanced Micro Devices, Inc. - Copyright (C) 2018, The University of Texas at Austin + Copyright (C) 2016 - 2022, Advanced Micro Devices, Inc. + Copyright (C) 2018, The University of Texas at Austin Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -36,26 +36,6 @@ #include "immintrin.h" #include "blis.h" -// Disable for all context without AVX512 support -// Please define it in bli_family_xxx.h in config directory if there is AVX512 support -#ifdef AVX512 -/* Union data structure to access AVX registers - One 512-bit AVX register holds 8 DP elements. */ -typedef union -{ - __m512d v; - double d[8] __attribute__((aligned(64))); -} v8df_t; - -/* Union data structure to access AVX registers - One 512-bit AVX register holds 16 SP elements. */ -typedef union -{ - __m512 v; - float f[16] __attribute__((aligned(64))); -} v16sf_t; -#endif - /* Union data structure to access AVX registers One 256-bit AVX register holds 8 SP elements. 
*/ typedef union @@ -84,42 +64,6 @@ typedef union double d[2]; }v2dd_t; -// Disable for all context without AVX512 support -// Please define it in bli_family_xxx.h in config directory if there is AVX512 support -#ifdef AVX512 -/* Convert the nan to -ve numbers decrementing with - the times the function is called to ensure that - bigger numbers are assigned for nan which showed - up first.*/ -__m512 remove_NAN_512_s(__m512 vec) -{ - // Sign extraction mask - __m512 sign_mask; - // Temporary place to store vector's sign extracted 16xdouble word - __m512 vec_mask; - // k register to store the mask to do blend operation to remove NAN - __mmask16 vec_mask16; - // Static to preserve accross the function calls - static int iter = -1; - iter -= 1; - - // Extracting sign from the vec into int_mask_vec - // Sign is -0.f in IEEE754 is just signbit set, all others 0 - sign_mask = _mm512_set1_ps(-0.f); - // And with -0.f will keep just signbits, all others will be 0 - vec_mask = _mm512_mul_ps(vec, sign_mask); - // Typecast mask into int type no clock cycle is taken just to - // convince compiler. - __m512i int_mask_vec = _mm512_castps_si512(vec_mask); - // Extract the signbits and put it in a 16bit mask register - vec_mask16 = _mm512_movepi32_mask(int_mask_vec); - - // Swap NAN with -ve number - vec = _mm512_mask_blend_ps(vec_mask16, _mm512_set1_ps(iter), vec); - return vec; -} -#endif - // return a mask which indicates either: // - v1 > v2 // - v1 is NaN and v2 is not @@ -320,511 +264,7 @@ void bli_samaxv_zen_int AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3) } -// Disable for all context without AVX512 support -// Please define it in bli_family_xxx.h in config directory if there is AVX512 support -#ifdef AVX512 -void bli_samaxv_zen_int_avx512( - dim_t n, - float *restrict x, inc_t incx, - dim_t *restrict i_max, - cntx_t *restrict cntx) -{ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_3) - // *minus_one = -1 - float *minus_one = PASTEMAC(s, m1); // bli_sm1() - // *zero_i = 0 - dim_t *zero_i = PASTEMAC(i, 0); // bli_i0() - - float fndMaxVal; // Max value will be stored in this - dim_t fndInd; // Max value's index will be stored in this - // Iterator for loops to keep continuity throughout the loops - dim_t i; - - /* If the vector length is zero, return early. This directly emulates - the behavior of netlib BLAS's i?amax() routines. */ - if (bli_zero_dim1(n)) - { - /* Set i_max to zero if dimension is 0, no need to compute */ - // Copy zero_i, that is 0 to i_max (i_max = 0) - PASTEMAC(i, copys) // bli_icopys - (*zero_i, *i_max); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3) - return; - } - - /* Initialize the index of the maximum absolute value to zero. */ - // Copy zero_i, that is 0 to fndInd (fndInd = 0) - PASTEMAC(i, copys) // bli_icopys - (*zero_i, fndInd); - - /* Initialize the maximum absolute value search candidate with - -1, which is guaranteed to be less than all values we will - compute. */ - // Copy minus_one to fndMaxVal real and imaginary. - PASTEMAC(s, copys) // bli_scopys - (*minus_one, fndMaxVal); - - // For non-unit strides, or very small vector lengths, compute with - // scalar code. - // n is less than the single vector length or non unit stride. 
- if (incx != 1 || n < 16) - { - for (i = 0; i < n; ++i) - { - // Call math.h fabsf to take absolute value of *(x +(i)*incx) - float absval = fabsf(*(x + (i)*incx)); - if (fndMaxVal < absval || (isnan(absval) && !isnan(fndMaxVal))) - { - // If max value is found, set the value and index - fndMaxVal = absval; - fndInd = i; - } - } - } - else - { - dim_t num_iter, num_remain; - dim_t num_vector_elements = 16; - /* Total Registers used is - * xmm0-xmm4 - * ymm5-ymm9 - * zmm10-zmm26 - * There are 6 free registers to use - */ - // zmm register 15x - v16sf_t x_vec_1, x_vec_2, x_vec_3, max_vec_1, max_vec_2, - max_vec_3, maxInd_vec_1, maxInd_vec_2, - maxInd_vec_3, index_vec_1, ind_vec_2, - ind_vec_3, inc_vec, mask, - abs_mask; - // ymm register 5x - v8sf_t max_vec_lo, max_vec_hi, - maxInd_vec_lo, maxInd_vec_hi, - mask_vec_lo; - // xmm register 5x - v4sf_t max_vec_lo_lo, max_vec_lo_hi, - maxInd_vec_lo_lo, maxInd_vec_lo_hi, - mask_vec_lo_lo; - // zmm register 1x - __m512i intMask; - // k register 3x - __mmask16 mask_vec_1, mask_vec_2, - mask_vec_3; - - // Number of iterations for main loop. - num_iter = n / num_vector_elements; - // Number of iterations remaining for residual non vector loop - num_remain = n % num_vector_elements; - // A number with signbit one and others 0 IEEE-754 - abs_mask.v = _mm512_set1_ps(-0.f); - // index_vector after loading max_vector with initial values. - index_vec_1.v = _mm512_setr_ps(16, 17, 18, 19, 20, 21, - 22, 23, 24, 25, 26, 27, - 28, 29, 30, 31); - // Broadcast 16. This is to increment the vector easily - inc_vec.v = _mm512_set1_ps(16); - // Load 16 float values from memory - max_vec_1.v = _mm512_loadu_ps(x); - // max_vector = abs(max_vector) - max_vec_1.v = _mm512_andnot_ps(abs_mask.v, max_vec_1.v); - // Remove nan and replace with -ve values - max_vec_1.v = remove_NAN_512_s(max_vec_1.v); - - // Increment x vector as we have loaded 16 values - x += num_vector_elements; - // indexes for values present in max vector. - maxInd_vec_1.v = _mm512_setr_ps(0, 1, 2, 3, 4, 5, 6, 7, 8, - 9, 10, 11, 12, 13, 14, 15); - - int i = 1; - for (; (i + 4) < num_iter; i += 5) - { - /* - Unrolled to process 5 at a time. It basically works - by taking a master max_vec_1 and a maxInd_vec_1 - holding indexes. Elements are taken from the RAM on a batch - of 5 (1 master max_vec_1 already exists to compare so - 6 elements). Now each 2 of them is compared with each other - and an intermediate result is obtained. This intermediate - result is again with each other and combined until we reach - one vector in max_vector and maxIndex_vector. - */ - - // Load the vector and subs NAN - // Load Value x values - x_vec_1.v = _mm512_loadu_ps(x); - // x_vec_1 = abs(x_vec_1) - x_vec_1.v = _mm512_andnot_ps(abs_mask.v, x_vec_1.v); - // Increment x vector as we have loaded 16 values - x += num_vector_elements; - // Remove nan and replace with -ve values - x_vec_1.v = remove_NAN_512_s(x_vec_1.v); - - // Mask Generation of 1st(can be previous max) and 2nd element - // mask = max_vector - x_vec_1 - mask.v = _mm512_sub_ps(max_vec_1.v, x_vec_1.v); - // Type cast mask from IEEE754 (float) to integer type - // This operation will not need a new register, its just to convince - // the compiler. But its accounted as seperate register in the - // above calculations - intMask = _mm512_castps_si512(mask.v); - // Extract the signbit and build the mask. 
- mask_vec_1 = _mm512_movepi32_mask(intMask); - - // Load 2 elements to 2nd max and x vector, set indexes - // Load Value x values - max_vec_2.v = _mm512_loadu_ps(x); - // max_vec_2 = abs(max_vec_2) - max_vec_2.v = _mm512_andnot_ps(abs_mask.v, max_vec_2.v); - // Remove nan and replace with -ve values - max_vec_2.v = remove_NAN_512_s(max_vec_2.v); - // Increment x vector as we have loaded 16 values - x += num_vector_elements; - // Increment the index vector to point to next indexes. - maxInd_vec_2.v = _mm512_add_ps(index_vec_1.v, inc_vec.v); - - // Load Value x values - x_vec_2.v = _mm512_loadu_ps(x); - // x_vec_2 = abs(x_vec_2) - x_vec_2.v = _mm512_andnot_ps(abs_mask.v, x_vec_2.v); - // Remove nan and replace with -ve values - x_vec_2.v = remove_NAN_512_s(x_vec_2.v); - // Increment x vector as we have loaded 16 values - x += num_vector_elements; - // Increment the index vector to point to next indexes. - ind_vec_2.v = _mm512_add_ps(maxInd_vec_2.v, inc_vec.v); - - // Mask generation for last loaded 2 elements into x and max vectors. - // mask = max_vec_2 - x_vec_2 - mask.v = _mm512_sub_ps(max_vec_2.v, x_vec_2.v); - // Type cast mask from IEEE754 (float) to integer type - // This operation will not need a new register, its just to convince - // the compiler. But its accounted as seperate register in the - // above calculations - intMask = _mm512_castps_si512(mask.v); - // Extract the signbit and build the mask. - mask_vec_2 = _mm512_movepi32_mask(intMask); - - // Load 2 more elements to 3rd max and x vector, set indexes - // Load Value x values - max_vec_3.v = _mm512_loadu_ps(x); - // max_vec_3 = abs(max_vec_3) - max_vec_3.v = _mm512_andnot_ps(abs_mask.v, max_vec_3.v); - // Remove nan and replace with -ve values - max_vec_3.v = remove_NAN_512_s(max_vec_3.v); - // Increment x vector as we have loaded 16 values - x += num_vector_elements; - // Increment the index vector to point to next indexes. - maxInd_vec_3.v = _mm512_add_ps(ind_vec_2.v, inc_vec.v); - // Load Value x values - x_vec_3.v = _mm512_loadu_ps(x); - // x_vec_3 = abs(x_vec_3) - x_vec_3.v = _mm512_andnot_ps(abs_mask.v, x_vec_3.v); - // Remove nan and replace with -ve values - x_vec_3.v = remove_NAN_512_s(x_vec_3.v); - // Increment x vector as we have loaded 16 values - x += num_vector_elements; - // Increment the index vector to point to next indexes. - ind_vec_3.v = _mm512_add_ps(maxInd_vec_3.v, inc_vec.v); - - // Mask generation for last 2 elements loaded into x and max vectors. - // mask = max_vec_3 - x_vec_3 - mask.v = _mm512_sub_ps(max_vec_3.v, x_vec_3.v); - // Type cast mask from IEEE754 (float) to integer type - // This operation will not need a new register, its just to convince - // the compiler. But its accounted as seperate register in the - // above calculations - intMask = _mm512_castps_si512(mask.v); - // Extract the signbit and build the mask. - mask_vec_3 = _mm512_movepi32_mask(intMask); - - // Blend max vector and index vector (3 pairs of elements needs to be blended). 
- /* Take values from max_vector if corresponding bit in mask_vector is 0 - * otherwise take value from x_vector, this is accumulated maximum value - * from max_vector and x_vector to mask_vector */ - max_vec_1.v = _mm512_mask_blend_ps(mask_vec_1, - max_vec_1.v, - x_vec_1.v); - /* Take values from max_vector if corresponding bit in mask_vector is 0 - * otherwise take value from x_vector, this is accumulated maximum value - * from max_vector and x_vector to mask_vector */ - max_vec_2.v = _mm512_mask_blend_ps(mask_vec_2, - max_vec_2.v, - x_vec_2.v); - /* Take values from max_vector if corresponding bit in mask_vector is 0 - * otherwise take value from x_vector, this is accumulated maximum value - * from max_vector and x_vector to mask_vector */ - max_vec_3.v = _mm512_mask_blend_ps(mask_vec_3, - max_vec_3.v, - x_vec_3.v); - /* Take values from maxIndex_vector if corresponding bit in mask_vector - * is 0 otherwise take value from index_vec_1, this is accumulated - * maximum value index from maxIndex_vector and index_vec_1 - * to maxIndex_vector */ - maxInd_vec_1.v = _mm512_mask_blend_ps(mask_vec_1, - maxInd_vec_1.v, - index_vec_1.v); - /* Take values from maxIndex_vector if corresponding bit in mask_vector - * is 0 otherwise take value from index_vec_1, this is accumulated - * maximum value index from maxIndex_vector and index_vec_1 - * to maxIndex_vector */ - maxInd_vec_2.v = _mm512_mask_blend_ps(mask_vec_2, - maxInd_vec_2.v, - ind_vec_2.v); - /* Take values from maxIndex_vector if corresponding bit in mask_vector - * is 0 otherwise take value from index_vec_1, this is accumulated - * maximum value index from maxIndex_vector and index_vec_1 - * to maxIndex_vector */ - maxInd_vec_3.v = _mm512_mask_blend_ps(mask_vec_3, - maxInd_vec_3.v, - ind_vec_3.v); - - // Mask generation for blending max_vec_2 and max_vec_3 to max_vec_2. - // mask = max_vec_2 - max_vec_3 - mask.v = _mm512_sub_ps(max_vec_2.v, max_vec_3.v); - // Type cast mask from IEEE754 (float) to integer type - // This operation will not need a new register, its just to convince - // the compiler. But its accounted as seperate register in the - // above calculations - intMask = _mm512_castps_si512(mask.v); - // Extract the signbit and build the mask. - mask_vec_2 = _mm512_movepi32_mask(intMask); - - // Blend to obtain 1 vector each of max values and index. - /* Take values from max_vec_2 if corresponding bit in mask_vec_2 - * is 0 otherwise take value from max_vec_3, this is accumulated - * maximum value from max_vec_2 and max_vec_3 to mask_vec_2 */ - max_vec_2.v = _mm512_mask_blend_ps(mask_vec_2, - max_vec_2.v, - max_vec_3.v); - /* Take values from maxInd_vec_2 if corresponding bit in mask_vector - * is 0 otherwise take value from maxInd_vec_3, this is accumulated - * maximum value index from maxInd_vec_2 and maxInd_vec_3 - * to maxInd_vec_2 */ - maxInd_vec_2.v = _mm512_mask_blend_ps(mask_vec_2, - maxInd_vec_2.v, - maxInd_vec_3.v); - - // Mask generation for blending max_vec_1 and max_vec_2 into max_vec_1. - // mask = max_vec_1 - max_vec_2 - mask.v = _mm512_sub_ps(max_vec_1.v, max_vec_2.v); - // Type cast mask from IEEE754 (float) to integer type - // This operation will not need a new register, its just to convince - // the compiler. But its accounted as seperate register in the - // above calculations - intMask = _mm512_castps_si512(mask.v); - // Extract the signbit and build the mask. 
- mask_vec_1 = _mm512_movepi32_mask(intMask); - - // Final blend to the master max_vec_1 and maxInd_vec_1 - /* Take values from max_vec_1 if corresponding bit in mask_vec_1 - * is 0 otherwise take value from max_vec_2, this is accumulated - * maximum value from max_vec_1 and max_vec_2 to mask_vec_1 */ - max_vec_1.v = _mm512_mask_blend_ps(mask_vec_1, max_vec_1.v, max_vec_2.v); - /* Take values from maxInd_vec_1 if corresponding bit in mask_vector - * is 0 otherwise take value from maxInd_vec_2, this is accumulated - * maximum value index from maxInd_vec_1 and maxInd_vec_2 - * to maxInd_vec_1 */ - maxInd_vec_1.v = _mm512_mask_blend_ps(mask_vec_1, - maxInd_vec_1.v, - maxInd_vec_2.v); - - // Increment the index vector to point to next indexes. - index_vec_1.v = _mm512_add_ps(ind_vec_3.v, inc_vec.v); - } - - for (; i < num_iter; i++) - { - /* - Take vector one by one, above code makes max_vec_1 - contain the first 16 elements, now with the max vector - as first 16 elements (abs), we need to load next 16 elements - into x_vec_1 (abs). Now with those we can safely removeNan - which will put -ve values as NAN. - - These -ve values of NAN decreases by 1 in each iteration, - this helps us find the first NAN value. - */ - // Load Value x values - x_vec_1.v = _mm512_loadu_ps(x); - // x_vec_1 = abs(x_vec_1) - x_vec_1.v = _mm512_andnot_ps(abs_mask.v, x_vec_1.v); - // Remove nan and replace with -ve values - x_vec_1.v = remove_NAN_512_s(x_vec_1.v); - - // Mask Generation - // mask = max_vec_1 - x_vec_1 - mask.v = _mm512_sub_ps(max_vec_1.v, x_vec_1.v); - // Extract the signbit and build the mask. - mask_vec_1 = _mm512_movepi32_mask(_mm512_castps_si512(mask.v)); - /* Take values from max_vec_1 if corresponding bit in - * mask_vec_1 is 0 otherwise take value from x_vec_1, - * this is accumulated maximum value from max_vec_1 and - * x_vec_1 to mask_vec_1 */ - max_vec_1.v = _mm512_mask_blend_ps(mask_vec_1, - max_vec_1.v, - x_vec_1.v); - /* Take values from maxInd_vec_1 if corresponding bit in - * mask_vector is 0 otherwise take value from index_vec_1, - * this is accumulated maximum value index from maxInd_vec_1 - * and index_vec_1 to maxInd_vec_1 */ - maxInd_vec_1.v = _mm512_mask_blend_ps(mask_vec_1, - maxInd_vec_1.v, - index_vec_1.v); - - // Increment the index vector to point to next indexes. - index_vec_1.v = _mm512_add_ps(index_vec_1.v, inc_vec.v); - - // Increment x vector as we have loaded 16 values - x += num_vector_elements; - } - - num_remain = (n - ((i)*16)); - - /* - Now take the max vector and produce the max value from - the max vector by slicing and comparing with itself, - until we are left with just one index position and max value. 
- */ - // Split max to hi and lo - max_vec_hi.v = _mm512_extractf32x8_ps(max_vec_1.v, 1); - max_vec_lo.v = _mm512_extractf32x8_ps(max_vec_1.v, 0); - - // Split maxIndex to hi and lo - maxInd_vec_hi.v = _mm512_extractf32x8_ps(maxInd_vec_1.v, 1); - maxInd_vec_lo.v = _mm512_extractf32x8_ps(maxInd_vec_1.v, 0); - - // Compare max_vec_hi > max_vec_1 - // mask_vec_lo = max_vec_lo - max_vec_hi - mask_vec_lo.v = _mm256_sub_ps(max_vec_lo.v, max_vec_hi.v); - - /* Take values from max_vec_lo if corresponding bit in mask_vec_lo - * is 0 otherwise take value from max_vec_hi, this is accumulated - * maximum value from max_vec_lo and max_vec_hi to max_vec_lo */ - max_vec_lo.v = _mm256_blendv_ps(max_vec_lo.v, - max_vec_hi.v, - mask_vec_lo.v); - /* Take values from maxInd_vec_lo if corresponding bit - * in mask_vec_lo is 0 otherwise take value from maxInd_vec_hi, - * this is accumulated maximum value from maxInd_vec_lo and - * maxInd_vec_hi to maxInd_vec_lo */ - maxInd_vec_lo.v = _mm256_blendv_ps(maxInd_vec_lo.v, - maxInd_vec_hi.v, - mask_vec_lo.v); - - // Split max_lo to hi and lo - max_vec_lo_hi.v = _mm256_extractf128_ps(max_vec_lo.v, 1); - max_vec_lo_lo.v = _mm256_extractf128_ps(max_vec_lo.v, 0); - - // Split maxIndex_lo to hi and lo - maxInd_vec_lo_hi.v = _mm256_extractf128_ps(maxInd_vec_lo.v, 1); - maxInd_vec_lo_lo.v = _mm256_extractf128_ps(maxInd_vec_lo.v, 0); - - // mask_vec_lo_lo = max_vec_lo_lo - max_vec_lo_hi - mask_vec_lo_lo.v = _mm_sub_ps(max_vec_lo_lo.v, max_vec_lo_hi.v); - /* Take values from max_vec_lo_lo if corresponding bit in - * mask_vec_lo_lo is 0 otherwise take value from max_vec_lo_hi, - * this is accumulated maximum value from max_vec_lo_lo and - * max_vec_lo_hi to max_vec_lo_lo */ - max_vec_lo_lo.v = _mm_blendv_ps(max_vec_lo_lo.v, - max_vec_lo_hi.v, - mask_vec_lo_lo.v); - /* Take values from maxInd_vec_lo if corresponding bit - * in mask_vec_lo_lo is 0 otherwise take value from maxInd_vec_hi, - * this is accumulated maximum value from maxInd_vec_lo and - * maxInd_vec_hi to maxInd_vec_lo */ - maxInd_vec_lo_lo.v = _mm_blendv_ps(maxInd_vec_lo_lo.v, - maxInd_vec_lo_hi.v, - mask_vec_lo_lo.v); - - // Take 64 high bits of max_lo_lo and put it to 64 low bits, rest 1st value - /* Example max_vec_lo_lo is {a, b, x, y} - * After max_vec_lo_hi.v = _mm_permute_ps(max_vec_lo_lo.v, 14); - * max_vec_lo_hi is {x, y, a, a} (essentially folding the vector) - */ - max_vec_lo_hi.v = _mm_permute_ps(max_vec_lo_lo.v, 14); - // Fold the vector same as max_vector - maxInd_vec_lo_hi.v = _mm_permute_ps(maxInd_vec_lo_lo.v, 14); - - // mask_vec_lo_lo = max_vec_lo_lo - max_vec_lo_hi - mask_vec_lo_lo.v = _mm_sub_ps(max_vec_lo_lo.v, max_vec_lo_hi.v); - /* Take values from max_vec_lo_lo if corresponding bit in - * mask_vec_lo_lo is 0 otherwise take value from max_vec_lo_hi, - * this is accumulated maximum value from max_vec_lo_lo and - * max_vec_lo_hi to max_vec_lo_lo */ - max_vec_lo_lo.v = _mm_blendv_ps(max_vec_lo_lo.v, - max_vec_lo_hi.v, - mask_vec_lo_lo.v); - /* Take values from maxInd_vec_lo if corresponding bit - * in mask_vec_lo_lo is 0 otherwise take value from maxInd_vec_hi, - * this is accumulated maximum value from maxInd_vec_lo and - * maxInd_vec_hi to maxInd_vec_lo */ - maxInd_vec_lo_lo.v = _mm_blendv_ps(maxInd_vec_lo_lo.v, - maxInd_vec_lo_hi.v, - mask_vec_lo_lo.v); - - // Take max_vec_lo_lo.f[1] and put it to max_vec_lo_hi.f[0] - /* Example max_vec_lo_lo is {a, b, x, y} - * After max_vec_lo_hi.v = _mm_permute_ps(max_vec_lo_lo.v, 1); - * max_vec_lo_hi is {b, a, a, a} (essentially folding the vector) - */ 
- max_vec_lo_hi.v = _mm_permute_ps(max_vec_lo_lo.v, 1); - // Do the same operation. - maxInd_vec_lo_hi.v = _mm_permute_ps(maxInd_vec_lo_lo.v, 1); - - // mask_vec_lo_lo = max_vec_lo_lo - max_vec_lo_hi - mask_vec_lo_lo.v = _mm_sub_ps(max_vec_lo_lo.v, max_vec_lo_hi.v); - /* Take values from max_vec_lo_lo if corresponding bit in - * mask_vec_lo_lo is 0 otherwise take value from max_vec_lo_hi, - * this is accumulated maximum value from max_vec_lo_lo and - * max_vec_lo_hi to max_vec_lo_lo */ - max_vec_lo_lo.v = _mm_blendv_ps(max_vec_lo_lo.v, - max_vec_lo_hi.v, - mask_vec_lo_lo.v); - /* Take values from maxInd_vec_lo if corresponding bit - * in mask_vec_lo_lo is 0 otherwise take value from maxInd_vec_hi, - * this is accumulated maximum value from maxInd_vec_lo and - * maxInd_vec_hi to maxInd_vec_lo */ - maxInd_vec_lo_lo.v = _mm_blendv_ps(maxInd_vec_lo_lo.v, - maxInd_vec_lo_hi.v, - mask_vec_lo_lo.v); - /* We have kept on folding and comparing until we got one single index - * and max value so that is the final answer so set it as the final - * answer.*/ - fndInd = maxInd_vec_lo_lo.f[0]; - fndMaxVal = max_vec_lo_lo.f[0]; - // Found value is < 0 means it was the max NAN which was accumulated. - if (fndMaxVal < 0) - { - // So just set it as NAN - fndMaxVal = NAN; - } - // Finish off the remaining values using normal instructions - for (int i = n - num_remain; i < n; i++) - { - float absval = fabsf(*(x)); - if (fndMaxVal < absval || (isnan(absval) && !isnan(fndMaxVal))) - { - fndMaxVal = absval; - fndInd = i; - } - x += 1; - } - } - - // Issue vzeroupper instruction to clear upper lanes of ymm registers. - // This avoids a performance penalty caused by false dependencies when - // transitioning from from AVX to SSE instructions (which may occur - // later, especially if BLIS is compiled with -mfpmath=sse). - _mm256_zeroupper(); - - /* Store final index to output variable. */ - *i_max = fndInd; - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3) -} -#endif // ----------------------------------------------------------------------------- - void bli_damaxv_zen_int ( dim_t n, @@ -981,444 +421,3 @@ void bli_damaxv_zen_int *i_max = i_max_l; AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3) } - -// ----------------------------------------------------------------------------- - -#if 0 -#undef GENTFUNCR -#define GENTFUNCR( ctype, ctype_r, ch, chr, varname ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - dim_t n, \ - ctype* x, inc_t incx, \ - dim_t* i_max, \ - cntx_t* cntx \ - ) \ -{ \ - ctype_r* minus_one = PASTEMAC(chr,m1); \ - dim_t* zero_i = PASTEMAC(i,0); \ -\ - ctype_r chi1_r; \ - ctype_r chi1_i; \ - ctype_r abs_chi1; \ - ctype_r abs_chi1_max; \ - dim_t i; \ -\ - /* Initialize the index of the maximum absolute value to zero. */ \ - PASTEMAC(i,copys)( zero_i, *i_max ); \ -\ - /* If the vector length is zero, return early. This directly emulates - the behavior of netlib BLAS's i?amax() routines. */ \ - if ( bli_zero_dim1( n ) ) return; \ -\ - /* Initialize the maximum absolute value search candidate with - -1, which is guaranteed to be less than all values we will - compute. */ \ - PASTEMAC(chr,copys)( *minus_one, abs_chi1_max ); \ -\ - if ( incx == 1 ) \ - { \ - for ( i = 0; i < n; ++i ) \ - { \ - /* Get the real and imaginary components of chi1. */ \ - PASTEMAC2(ch,chr,gets)( x[i], chi1_r, chi1_i ); \ -\ - /* Replace chi1_r and chi1_i with their absolute values. */ \ - PASTEMAC(chr,abval2s)( chi1_r, chi1_r ); \ - PASTEMAC(chr,abval2s)( chi1_i, chi1_i ); \ -\ - /* Add the real and imaginary absolute values together. 
*/ \ - PASTEMAC(chr,set0s)( abs_chi1 ); \ - PASTEMAC(chr,adds)( chi1_r, abs_chi1 ); \ - PASTEMAC(chr,adds)( chi1_i, abs_chi1 ); \ -\ - /* If the absolute value of the current element exceeds that of - the previous largest, save it and its index. If NaN is - encountered, then treat it the same as if it were a valid - value that was smaller than any previously seen. This - behavior mimics that of LAPACK's ?lange(). */ \ - if ( abs_chi1_max < abs_chi1 || bli_isnan( abs_chi1 ) ) \ - { \ - abs_chi1_max = abs_chi1; \ - *i_max = i; \ - } \ - } \ - } \ - else \ - { \ - for ( i = 0; i < n; ++i ) \ - { \ - ctype* chi1 = x + (i )*incx; \ -\ - /* Get the real and imaginary components of chi1. */ \ - PASTEMAC2(ch,chr,gets)( *chi1, chi1_r, chi1_i ); \ -\ - /* Replace chi1_r and chi1_i with their absolute values. */ \ - PASTEMAC(chr,abval2s)( chi1_r, chi1_r ); \ - PASTEMAC(chr,abval2s)( chi1_i, chi1_i ); \ -\ - /* Add the real and imaginary absolute values together. */ \ - PASTEMAC(chr,set0s)( abs_chi1 ); \ - PASTEMAC(chr,adds)( chi1_r, abs_chi1 ); \ - PASTEMAC(chr,adds)( chi1_i, abs_chi1 ); \ -\ - /* If the absolute value of the current element exceeds that of - the previous largest, save it and its index. If NaN is - encountered, then treat it the same as if it were a valid - value that was smaller than any previously seen. This - behavior mimics that of LAPACK's ?lange(). */ \ - if ( abs_chi1_max < abs_chi1 || bli_isnan( abs_chi1 ) ) \ - { \ - abs_chi1_max = abs_chi1; \ - *i_max = i; \ - } \ - } \ - } \ -} -GENTFUNCR( scomplex, float, c, s, amaxv_zen_int ) -GENTFUNCR( dcomplex, double, z, d, amaxv_zen_int ) -#endif - -// Disable for all context without AVX512 support -// Please define it in bli_family_xxx.h in config directory if there is AVX512 support -#ifdef AVX512 -/* Converts all the NAN to a negative number less than previously encountered NANs*/ -__m512d remove_NAN_512d(__m512d vec) -{ - - static int iter; - static __m512d sign_mask; - - __m512d vec_mask; - __m512i int_mask_vec; - __mmask8 vec_mask8; - - iter = iter - 1; - - sign_mask = _mm512_set1_pd(-0.f); - - //numbers other than NAN will become 0 - vec_mask = _mm512_mul_pd(vec, sign_mask); - - //producing an 8-bit mask - int_mask_vec = _mm512_castpd_si512(vec_mask); - vec_mask8 = _mm512_movepi64_mask(int_mask_vec); - - //replacing all the NAN with negative numbers - vec = _mm512_mask_blend_pd(vec_mask8, _mm512_set1_pd(-1 + iter), vec); - - return vec; -} - -#endif -//---------------------------------------------------------------------------------------------------- - -// Disable for all context without AVX512 support -// Please define it in bli_family_xxx.h in config directory if there is AVX512 support -#ifdef AVX512 -void bli_damaxv_zen_int_avx512( - dim_t n, - double *restrict x, inc_t incx, - dim_t *restrict i_max, - cntx_t *restrict cntx) -{ - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3) - double *minus_one = PASTEMAC(d, m1); - dim_t *zero_i = PASTEMAC(i, 0); - - double chi1_r; - //double chi1_i; - double abs_chi1; - double abs_chi1_max; - dim_t i_max_l; - dim_t i; - - /* If the vector length is zero, return early. This directly emulates - the behavior of netlib BLAS's i?amax() routines. */ - if (bli_zero_dim1(n)) - { - PASTEMAC(i, copys) - (*zero_i, *i_max); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3) - return; - } - - /* Initialize the index of the maximum absolute value to zero. 
*/ - PASTEMAC(i, copys) - (*zero_i, i_max_l); - - /* Initialize the maximum absolute value search candidate with - -1, which is guaranteed to be less than all values we will - compute. */ - PASTEMAC(d, copys) - (*minus_one, abs_chi1_max); - - // For non-unit strides, or very small vector lengths, compute with - // scalar code. - if (incx != 1 || n < 8) - { - for (i = 0; i < n; ++i) - { - double *chi1 = x + (i)*incx; - - /* Get the real and imaginary components of chi1. */ - chi1_r = *chi1; - - /* Replace chi1_r and chi1_i with their absolute values. */ - chi1_r = fabs(chi1_r); - - /* Add the real and imaginary absolute values together. */ - abs_chi1 = chi1_r; - - /* If the absolute value of the current element exceeds that of - the previous largest, save it and its index. If NaN is - encountered, then treat it the same as if it were a valid - value that was smaller than any previously seen. This - behavior mimics that of LAPACK's i?amax(). */ - if (abs_chi1_max < abs_chi1 || (isnan(abs_chi1) && !isnan(abs_chi1_max))) - { - abs_chi1_max = abs_chi1; - i_max_l = i; - } - } - } - else - { - - dim_t iterations, n_left, vector_length = 8, unrollCount = 0; - - //mask bits - __mmask8 mask_got_01, mask_got_23; - - //YMM0 - YMM6 registers - v4df_t max_hi, max_lo, max_ind_hi, max_ind_lo, - mask_final, inter_result, inter_ind; - - //XMM0 to XMM4 registers - v2dd_t max_vec_hi, max_vec_lo, max_ind_hi_128, - max_ind_lo_128, mask_vec_lo; - - //ZMM0 to ZMM13 registers - v8df_t zmm0, zmm1, zmm2, zmm3, zmm4_Ind, - zmm5_Ind, zmm6_Ind, zmm7_Ind, max_01, - max_23, final_max, max_array, max_ind, inc_vec; - - //ZMM14 to ZMM16 registers - __m512d mask_01, mask_23, sign_mask; - - //Intermediate int mask values - __m512i int_mask_01, int_mask_23; - - // Initialize sign mask - sign_mask = _mm512_set1_pd(-0.f); - - //Initializing the indexes of the base case of max vector - zmm4_Ind.v = _mm512_set_pd(7, 6, 5, 4, 3, 2, 1, 0); - inc_vec.v = _mm512_set1_pd(8); //Vector for incrementing - - // Initializing the max array as vec [ 0 : 512 ] - max_array.v = _mm512_loadu_pd(x); - - // Taking the absolute value and removing the NAN - max_array.v = _mm512_andnot_pd(sign_mask, max_array.v); - max_array.v = remove_NAN_512d(max_array.v); - - // Initializing the maximumum index - max_ind.v = _mm512_set_pd(7, 6, 5, 4, 3, 2, 1, 0); - x += vector_length; - - //Incrementing to make the vector - //to point to the next 8 elements - zmm4_Ind.v = _mm512_add_pd(zmm4_Ind.v, inc_vec.v); - - /* Loop unrolled by a factor of 4 - At the end of the loop max_array holds the largest element - in each corresponding vector index */ - for (unrollCount = 8; (unrollCount + 31) < n; unrollCount += 32) - { - // Taking 32 elements - // Taking only the absolute values of the registers - // Removing the NAN values and replacing it - // with negative numbers - zmm0.v = _mm512_loadu_pd(x); - zmm0.v = _mm512_andnot_pd(sign_mask, zmm0.v); - zmm0.v = remove_NAN_512d(zmm0.v); - x += vector_length; - - zmm1.v = _mm512_loadu_pd(x); - zmm5_Ind.v = _mm512_add_pd(zmm4_Ind.v, inc_vec.v); - zmm1.v = _mm512_andnot_pd(sign_mask, zmm1.v); - zmm1.v = remove_NAN_512d(zmm1.v); - x += vector_length; - - zmm2.v = _mm512_loadu_pd(x); - zmm6_Ind.v = _mm512_add_pd(zmm5_Ind.v, inc_vec.v); - zmm2.v = _mm512_andnot_pd(sign_mask, zmm2.v); - zmm2.v = remove_NAN_512d(zmm2.v); - x += vector_length; - - zmm3.v = _mm512_loadu_pd(x); - zmm7_Ind.v = _mm512_add_pd(zmm6_Ind.v, inc_vec.v); - zmm3.v = _mm512_andnot_pd(sign_mask, zmm3.v); - zmm3.v = remove_NAN_512d(zmm3.v); - x += vector_length; - 
- /*Using sub function to generating the mask - as a 512d type*/ - mask_01 = _mm512_sub_pd(zmm0.v, zmm1.v); - mask_23 = _mm512_sub_pd(zmm2.v, zmm3.v); - - //Converting the 512d mask to a 512i mask - int_mask_01 = _mm512_castpd_si512(mask_01); - int_mask_23 = _mm512_castpd_si512(mask_23); - - /*Converting the 512i mask - to mmask type to use the mask bits*/ - mask_got_01 = _mm512_movepi64_mask(int_mask_01); - mask_got_23 = _mm512_movepi64_mask(int_mask_23); - - //Storing the largest elements in index % 8 position for - //vector 1 and 2, and the index of the corresponding element - max_01.v = _mm512_mask_blend_pd(mask_got_01, zmm0.v, zmm1.v); - zmm5_Ind.v = _mm512_mask_blend_pd(mask_got_01, zmm4_Ind.v, zmm5_Ind.v); - - //Storing the largest elements in index % 8 position for - //vector 3 and 4, and the index of the corresponding element - max_23.v = _mm512_mask_blend_pd(mask_got_23, zmm2.v, zmm3.v); - zmm6_Ind.v = _mm512_mask_blend_pd(mask_got_23, zmm6_Ind.v, zmm7_Ind.v); - - //Generating mask for the intermediate max vector - mask_01 = _mm512_sub_pd(max_01.v, max_23.v); - int_mask_01 = _mm512_castpd_si512(mask_01); - mask_got_01 = _mm512_movepi64_mask(int_mask_01); - - /*Storing the largest elements in index % 8 position for - the intermediate max vectors, - and the index of the corresponding element*/ - final_max.v = _mm512_mask_blend_pd(mask_got_01, max_01.v, max_23.v); - zmm5_Ind.v = _mm512_mask_blend_pd(mask_got_01, zmm5_Ind.v, zmm6_Ind.v); - - //Generating the mask for final max vector and base max vector - mask_01 = _mm512_sub_pd(max_array.v, final_max.v); - int_mask_01 = _mm512_castpd_si512(mask_01); - mask_got_01 = _mm512_movepi64_mask(int_mask_01); - - // Result is the maximum of all index % 8 locations - max_array.v = _mm512_mask_blend_pd(mask_got_01, max_array.v, final_max.v); - max_ind.v = _mm512_mask_blend_pd(mask_got_01, max_ind.v, zmm5_Ind.v); - - // Incrementing the index to point to the next 8 locations - zmm4_Ind.v = _mm512_add_pd(zmm7_Ind.v, inc_vec.v); - } - - // Calculating the number of iterations left - iterations = (n - unrollCount) / vector_length; - n_left = (n - unrollCount) % vector_length; - - /* At the end of the loop max_array holds the largest element - in each corresponding vector index */ - for (int i = 1; i < iterations; ++i) - { - // Taking 32 elements - // Taking only the absolute values of the registers - // Removing the NAN values and replacing it - // with negative numbers - zmm0.v = _mm512_loadu_pd(x); - zmm0.v = _mm512_abs_pd(zmm0.v); - zmm0.v = remove_NAN_512d(zmm0.v); - - //Generating mask for the intermediate max vector - mask_01 = _mm512_sub_pd(max_array.v, zmm0.v); - int_mask_01 = _mm512_castpd_si512(mask_01); - mask_got_01 = _mm512_movepi64_mask(int_mask_01); - - // Result is the maximum of all index % 8 locations - max_array.v = _mm512_mask_blend_pd(mask_got_01, max_array.v, zmm0.v); - - //Storing the index of the corresponding max array elemets - max_ind.v = _mm512_mask_blend_pd(mask_got_01, max_ind.v, zmm4_Ind.v); - - //Incrementing the vector the point to the next location - //Incrementing the vector indexes - x += vector_length; - zmm4_Ind.v = _mm512_add_pd(zmm4_Ind.v, inc_vec.v); - } - - //Breaking max array into vectors of length 4 - //Taking upper and lower halves - max_hi.v = _mm512_extractf64x4_pd(max_array.v, 1); - max_ind_hi.v = _mm512_extractf64x4_pd(max_ind.v, 1); - max_lo.v = _mm512_extractf64x4_pd(max_array.v, 0); - max_ind_lo.v = _mm512_extractf64x4_pd(max_ind.v, 0); - - //Generating the mask for blending - mask_final.v = 
_mm256_sub_pd(max_hi.v, max_lo.v); - - // Storing the max of max array index % 4 - inter_result.v = _mm256_blendv_pd(max_hi.v, max_lo.v, mask_final.v); - inter_ind.v = _mm256_blendv_pd(max_ind_hi.v, max_ind_lo.v, mask_final.v); - - //Breaking max array into vectors of length 2 - max_vec_lo.v = _mm256_extractf128_pd(inter_result.v, 0); - max_vec_hi.v = _mm256_extractf128_pd(inter_result.v, 1); - max_ind_hi_128.v = _mm256_extractf128_pd(inter_ind.v, 1); - max_ind_lo_128.v = _mm256_extractf128_pd(inter_ind.v, 0); - - //Generating the mask for blending - mask_vec_lo.v = _mm_sub_pd(max_vec_lo.v, max_vec_hi.v); - - // Storing the max of max array index % 2 - max_vec_lo.v = _mm_blendv_pd(max_vec_lo.v, max_vec_hi.v, mask_vec_lo.v); - max_ind_lo_128.v = _mm_blendv_pd(max_ind_lo_128.v, max_ind_hi_128.v, mask_vec_lo.v); - - max_vec_hi.v = _mm_permute_pd(max_vec_lo.v, 1); - max_ind_hi_128.v = _mm_permute_pd(max_ind_lo_128.v, 1); - - //Performing work of CMP128 i.e generating mask - mask_vec_lo.v = _mm_sub_pd(max_vec_lo.v, max_vec_hi.v); - - //Finding the maximum element - max_vec_lo.v = _mm_blendv_pd(max_vec_lo.v, max_vec_hi.v, mask_vec_lo.v); - max_ind_lo_128.v = _mm_blendv_pd(max_ind_lo_128.v, max_ind_hi_128.v, mask_vec_lo.v); - - abs_chi1_max = max_vec_lo.d[0]; - - //If the largest number is negative it is NAN - if (abs_chi1_max < 0) - abs_chi1_max = NAN; - - i_max_l = max_ind_lo_128.d[0]; - - for (i = n - n_left; i < n; i++) - { - double *chi1 = x; - - /* Get the real and imaginary components of chi1. */ - chi1_r = *chi1; - - /* Replace chi1_r and chi1_i with their absolute values. */ - abs_chi1 = fabs(chi1_r); - - /* If the absolute value of the current element exceeds that of - the previous largest, save it and its index. If NaN is - encountered, return the index of the first NaN. This - behavior mimics that of LAPACK's i?amax(). */ - if (abs_chi1_max < abs_chi1 || (isnan(abs_chi1) && !isnan(abs_chi1_max))) - { - abs_chi1_max = abs_chi1; - i_max_l = i; - } - - x += 1; - } - } - - // Return value - *i_max = i_max_l; - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3) -} -#endif - -// --------------------------------------------------------------------------------- diff --git a/kernels/zen/bli_kernels_zen.h b/kernels/zen/bli_kernels_zen.h index bd2704bbc4..53db6e4d22 100644 --- a/kernels/zen/bli_kernels_zen.h +++ b/kernels/zen/bli_kernels_zen.h @@ -43,9 +43,7 @@ PACKM_KER_PROT(double, d, packm_6xk_nn_zen) // amaxv (intrinsics) AMAXV_KER_PROT( float, s, amaxv_zen_int ) -AMAXV_KER_PROT( float, s, amaxv_zen_int_avx512 ) AMAXV_KER_PROT( double, d, amaxv_zen_int ) -AMAXV_KER_PROT( double, d, amaxv_zen_int_avx512 ) // axpbyv (intrinsics) AXPBYV_KER_PROT( float, s, axpbyv_zen_int ) diff --git a/kernels/zen4/1/CMakeLists.txt b/kernels/zen4/1/CMakeLists.txt new file mode 100644 index 0000000000..7bd499efb6 --- /dev/null +++ b/kernels/zen4/1/CMakeLists.txt @@ -0,0 +1,6 @@ +##Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.## + +target_sources("${PROJECT_NAME}" + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/bli_amaxv_zen_int_avx512.c + ) diff --git a/kernels/zen4/1/bli_amaxv_zen_int_avx512.c b/kernels/zen4/1/bli_amaxv_zen_int_avx512.c new file mode 100644 index 0000000000..0e0186c403 --- /dev/null +++ b/kernels/zen4/1/bli_amaxv_zen_int_avx512.c @@ -0,0 +1,975 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. 
+ + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "immintrin.h" +#include "blis.h" +typedef union +{ + __m512d v; + double d[8] __attribute__((aligned(64))); +} v8df_t; + +/* Union data structure to access AVX registers + One 512-bit AVX register holds 16 SP elements. */ +typedef union +{ + __m512 v; + float f[16] __attribute__((aligned(64))); +} v16sf_t; + +/* Union data structure to access AVX registers + One 256-bit AVX register holds 8 SP elements. */ +typedef union +{ + __m256 v; + float f[8] __attribute__((aligned(64))); +} v8sf_t; + +typedef union +{ + __m128 v; + float f[4]; +} v4sf_t; + +/* Union data structure to access AVX registers + One 256-bit AVX register holds 4 DP elements. */ +typedef union +{ + __m256d v; + double d[4] __attribute__((aligned(64))); +}v4df_t; + +typedef union +{ + __m128d v; + double d[2]; +}v2dd_t; + +/* Convert the nan to -ve numbers decrementing with + the times the function is called to ensure that + bigger numbers are assigned for nan which showed + up first.*/ +__m512 remove_NAN_512_s(__m512 vec) +{ + // Sign extraction mask + __m512 sign_mask; + // Temporary place to store vector's sign extracted 16xdouble word + __m512 vec_mask; + // k register to store the mask to do blend operation to remove NAN + __mmask16 vec_mask16; + // Static to preserve accross the function calls + static int iter = -1; + iter -= 1; + + // Extracting sign from the vec into int_mask_vec + // Sign is -0.f in IEEE754 is just signbit set, all others 0 + sign_mask = _mm512_set1_ps(-0.f); + // And with -0.f will keep just signbits, all others will be 0 + vec_mask = _mm512_mul_ps(vec, sign_mask); + // Typecast mask into int type no clock cycle is taken just to + // convince compiler. 
+ __m512i int_mask_vec = _mm512_castps_si512(vec_mask); + // Extract the signbits and put it in a 16bit mask register + vec_mask16 = _mm512_movepi32_mask(int_mask_vec); + + // Swap NAN with -ve number + vec = _mm512_mask_blend_ps(vec_mask16, _mm512_set1_ps(iter), vec); + return vec; +} + +// return a mask which indicates either: +// - v1 > v2 +// - v1 is NaN and v2 is not +// assumes that idx(v1) > idx(v2) +// all "OQ" comparisons false if either operand NaN +#define CMP256( dt, v1, v2 ) \ + _mm256_or_p##dt( _mm256_cmp_p##dt( v1, v2, _CMP_GT_OQ ), /* v1 > v2 || */ \ + _mm256_andnot_p##dt( _mm256_cmp_p##dt( v2, v2, _CMP_UNORD_Q ), /* ( !isnan(v2) && */ \ + _mm256_cmp_p##dt( v1, v1, _CMP_UNORD_Q ) /* isnan(v1) ) */ \ + ) \ + ); + +// return a mask which indicates either: +// - v1 > v2 +// - v1 is NaN and v2 is not +// - v1 == v2 (maybe == NaN) and i1 < i2 +// all "OQ" comparisons false if either operand NaN +#define CMP128( dt, v1, v2, i1, i2 ) \ + _mm_or_p##dt( _mm_or_p##dt( _mm_cmp_p##dt( v1, v2, _CMP_GT_OQ ), /* ( v1 > v2 || */ \ + _mm_andnot_p##dt( _mm_cmp_p##dt( v2, v2, _CMP_UNORD_Q ), /* ( !isnan(v2) && */ \ + _mm_cmp_p##dt( v1, v1, _CMP_UNORD_Q ) /* isnan(v1) ) ) || */ \ + ) \ + ), \ + _mm_and_p##dt( _mm_or_p##dt( _mm_cmp_p##dt( v1, v2, _CMP_EQ_OQ ), /* ( ( v1 == v2 || */ \ + _mm_and_p##dt( _mm_cmp_p##dt( v1, v1, _CMP_UNORD_Q ), /* ( isnan(v1) && */ \ + _mm_cmp_p##dt( v2, v2, _CMP_UNORD_Q ) /* isnan(v2) ) ) && */ \ + ) \ + ), \ + _mm_cmp_p##dt( i1, i2, _CMP_LT_OQ ) /* i1 < i2 ) */ \ + ) \ + ); + +// ---------------------------------------------------------------------------- +void bli_samaxv_zen_int_avx512( + dim_t n, + float *restrict x, inc_t incx, + dim_t *restrict i_max, + cntx_t *restrict cntx) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_3) + // *minus_one = -1 + float *minus_one = PASTEMAC(s, m1); // bli_sm1() + // *zero_i = 0 + dim_t *zero_i = PASTEMAC(i, 0); // bli_i0() + + float fndMaxVal; // Max value will be stored in this + dim_t fndInd; // Max value's index will be stored in this + // Iterator for loops to keep continuity throughout the loops + dim_t i; + + /* If the vector length is zero, return early. This directly emulates + the behavior of netlib BLAS's i?amax() routines. */ + if (bli_zero_dim1(n)) + { + /* Set i_max to zero if dimension is 0, no need to compute */ + // Copy zero_i, that is 0 to i_max (i_max = 0) + PASTEMAC(i, copys) // bli_icopys + (*zero_i, *i_max); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3) + return; + } + + /* Initialize the index of the maximum absolute value to zero. */ + // Copy zero_i, that is 0 to fndInd (fndInd = 0) + PASTEMAC(i, copys) // bli_icopys + (*zero_i, fndInd); + + /* Initialize the maximum absolute value search candidate with + -1, which is guaranteed to be less than all values we will + compute. */ + // Copy minus_one to fndMaxVal real and imaginary. + PASTEMAC(s, copys) // bli_scopys + (*minus_one, fndMaxVal); + + // For non-unit strides, or very small vector lengths, compute with + // scalar code. + // n is less than the single vector length or non unit stride. 
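A note on remove_NAN_512_s above (the small-n / non-unit-stride branch resumes right after this sketch): NaNs are replaced by negative sentinels that grow more negative on every call, so any finite |x| still wins the maximum reduction, while among NaNs the earliest one carries the largest sentinel. A rough scalar analogue of the idea, assuming inputs have already been made non-negative by the abs step (the helper is illustrative, not part of the patch):

#include <math.h>

// Scalar analogue of the sentinel trick: swap NaN for a value that
// decreases on every call. Finite values (>= 0 after abs) always beat a
// sentinel; if every element was NaN, the first one seen holds the
// largest sentinel, and the kernel later maps the negative result back
// to NaN (the "fndMaxVal < 0" check further down).
static float nan_to_sentinel( float v )
{
    static int iter = -1;      // persists across calls, like the SIMD version
    iter -= 1;                 // first call yields -2, then -3, -4, ...
    return isnan( v ) ? ( float )iter : v;
}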
+ if (incx != 1 || n < 16) + { + for (i = 0; i < n; ++i) + { + // Call math.h fabsf to take absolute value of *(x +(i)*incx) + float absval = fabsf(*(x + (i)*incx)); + if (fndMaxVal < absval || (isnan(absval) && !isnan(fndMaxVal))) + { + // If max value is found, set the value and index + fndMaxVal = absval; + fndInd = i; + } + } + } + else + { + dim_t num_iter, num_remain; + dim_t num_vector_elements = 16; + /* Total Registers used is + * xmm0-xmm4 + * ymm5-ymm9 + * zmm10-zmm26 + * There are 6 free registers to use + */ + // zmm register 15x + v16sf_t x_vec_1, x_vec_2, x_vec_3, max_vec_1, max_vec_2, + max_vec_3, maxInd_vec_1, maxInd_vec_2, + maxInd_vec_3, index_vec_1, ind_vec_2, + ind_vec_3, inc_vec, mask, + abs_mask; + // ymm register 5x + v8sf_t max_vec_lo, max_vec_hi, + maxInd_vec_lo, maxInd_vec_hi, + mask_vec_lo; + // xmm register 5x + v4sf_t max_vec_lo_lo, max_vec_lo_hi, + maxInd_vec_lo_lo, maxInd_vec_lo_hi, + mask_vec_lo_lo; + // zmm register 1x + __m512i intMask; + // k register 3x + __mmask16 mask_vec_1, mask_vec_2, + mask_vec_3; + + // Number of iterations for main loop. + num_iter = n / num_vector_elements; + // Number of iterations remaining for residual non vector loop + num_remain = n % num_vector_elements; + // A number with signbit one and others 0 IEEE-754 + abs_mask.v = _mm512_set1_ps(-0.f); + // index_vector after loading max_vector with initial values. + index_vec_1.v = _mm512_setr_ps(16, 17, 18, 19, 20, 21, + 22, 23, 24, 25, 26, 27, + 28, 29, 30, 31); + // Broadcast 16. This is to increment the vector easily + inc_vec.v = _mm512_set1_ps(16); + // Load 16 float values from memory + max_vec_1.v = _mm512_loadu_ps(x); + // max_vector = abs(max_vector) + max_vec_1.v = _mm512_andnot_ps(abs_mask.v, max_vec_1.v); + // Remove nan and replace with -ve values + max_vec_1.v = remove_NAN_512_s(max_vec_1.v); + + // Increment x vector as we have loaded 16 values + x += num_vector_elements; + // indexes for values present in max vector. + maxInd_vec_1.v = _mm512_setr_ps(0, 1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15); + + int i = 1; + for (; (i + 4) < num_iter; i += 5) + { + /* + Unrolled to process 5 at a time. It basically works + by taking a master max_vec_1 and a maxInd_vec_1 + holding indexes. Elements are taken from the RAM on a batch + of 5 (1 master max_vec_1 already exists to compare so + 6 elements). Now each 2 of them is compared with each other + and an intermediate result is obtained. This intermediate + result is again with each other and combined until we reach + one vector in max_vector and maxIndex_vector. + */ + + // Load the vector and subs NAN + // Load Value x values + x_vec_1.v = _mm512_loadu_ps(x); + // x_vec_1 = abs(x_vec_1) + x_vec_1.v = _mm512_andnot_ps(abs_mask.v, x_vec_1.v); + // Increment x vector as we have loaded 16 values + x += num_vector_elements; + // Remove nan and replace with -ve values + x_vec_1.v = remove_NAN_512_s(x_vec_1.v); + + // Mask Generation of 1st(can be previous max) and 2nd element + // mask = max_vector - x_vec_1 + mask.v = _mm512_sub_ps(max_vec_1.v, x_vec_1.v); + // Type cast mask from IEEE754 (float) to integer type + // This operation will not need a new register, its just to convince + // the compiler. But its accounted as seperate register in the + // above calculations + intMask = _mm512_castps_si512(mask.v); + // Extract the signbit and build the mask. 
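Every compare-and-select in the loop below has the same three-step shape: subtract, extract the sign bits into a k-mask, then blend both the value vector and its index vector with that mask. A self-contained sketch of that building block, using only intrinsics already present in this file (the helper is illustrative, not part of the patch; the unrolled loop continues right after it):

#include <immintrin.h>

// Per-lane running maximum with index tracking: where (cur - cand) is
// negative the candidate lane is larger, so its value and index are
// blended in; ties keep the current (earlier) entry.
static void lane_max_update( __m512 *cur_val, __m512 *cur_idx,
                             __m512 cand_val, __m512 cand_idx )
{
    __m512    diff = _mm512_sub_ps( *cur_val, cand_val );
    __mmask16 m    = _mm512_movepi32_mask( _mm512_castps_si512( diff ) );
    *cur_val = _mm512_mask_blend_ps( m, *cur_val, cand_val );
    *cur_idx = _mm512_mask_blend_ps( m, *cur_idx, cand_idx );
}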
+ mask_vec_1 = _mm512_movepi32_mask(intMask); + + // Load 2 elements to 2nd max and x vector, set indexes + // Load Value x values + max_vec_2.v = _mm512_loadu_ps(x); + // max_vec_2 = abs(max_vec_2) + max_vec_2.v = _mm512_andnot_ps(abs_mask.v, max_vec_2.v); + // Remove nan and replace with -ve values + max_vec_2.v = remove_NAN_512_s(max_vec_2.v); + // Increment x vector as we have loaded 16 values + x += num_vector_elements; + // Increment the index vector to point to next indexes. + maxInd_vec_2.v = _mm512_add_ps(index_vec_1.v, inc_vec.v); + + // Load Value x values + x_vec_2.v = _mm512_loadu_ps(x); + // x_vec_2 = abs(x_vec_2) + x_vec_2.v = _mm512_andnot_ps(abs_mask.v, x_vec_2.v); + // Remove nan and replace with -ve values + x_vec_2.v = remove_NAN_512_s(x_vec_2.v); + // Increment x vector as we have loaded 16 values + x += num_vector_elements; + // Increment the index vector to point to next indexes. + ind_vec_2.v = _mm512_add_ps(maxInd_vec_2.v, inc_vec.v); + + // Mask generation for last loaded 2 elements into x and max vectors. + // mask = max_vec_2 - x_vec_2 + mask.v = _mm512_sub_ps(max_vec_2.v, x_vec_2.v); + // Type cast mask from IEEE754 (float) to integer type + // This operation will not need a new register, its just to convince + // the compiler. But its accounted as seperate register in the + // above calculations + intMask = _mm512_castps_si512(mask.v); + // Extract the signbit and build the mask. + mask_vec_2 = _mm512_movepi32_mask(intMask); + + // Load 2 more elements to 3rd max and x vector, set indexes + // Load Value x values + max_vec_3.v = _mm512_loadu_ps(x); + // max_vec_3 = abs(max_vec_3) + max_vec_3.v = _mm512_andnot_ps(abs_mask.v, max_vec_3.v); + // Remove nan and replace with -ve values + max_vec_3.v = remove_NAN_512_s(max_vec_3.v); + // Increment x vector as we have loaded 16 values + x += num_vector_elements; + // Increment the index vector to point to next indexes. + maxInd_vec_3.v = _mm512_add_ps(ind_vec_2.v, inc_vec.v); + // Load Value x values + x_vec_3.v = _mm512_loadu_ps(x); + // x_vec_3 = abs(x_vec_3) + x_vec_3.v = _mm512_andnot_ps(abs_mask.v, x_vec_3.v); + // Remove nan and replace with -ve values + x_vec_3.v = remove_NAN_512_s(x_vec_3.v); + // Increment x vector as we have loaded 16 values + x += num_vector_elements; + // Increment the index vector to point to next indexes. + ind_vec_3.v = _mm512_add_ps(maxInd_vec_3.v, inc_vec.v); + + // Mask generation for last 2 elements loaded into x and max vectors. + // mask = max_vec_3 - x_vec_3 + mask.v = _mm512_sub_ps(max_vec_3.v, x_vec_3.v); + // Type cast mask from IEEE754 (float) to integer type + // This operation will not need a new register, its just to convince + // the compiler. But its accounted as seperate register in the + // above calculations + intMask = _mm512_castps_si512(mask.v); + // Extract the signbit and build the mask. + mask_vec_3 = _mm512_movepi32_mask(intMask); + + // Blend max vector and index vector (3 pairs of elements needs to be blended). 
+ /* Take values from max_vector if corresponding bit in mask_vector is 0 + * otherwise take value from x_vector, this is accumulated maximum value + * from max_vector and x_vector to mask_vector */ + max_vec_1.v = _mm512_mask_blend_ps(mask_vec_1, + max_vec_1.v, + x_vec_1.v); + /* Take values from max_vector if corresponding bit in mask_vector is 0 + * otherwise take value from x_vector, this is accumulated maximum value + * from max_vector and x_vector to mask_vector */ + max_vec_2.v = _mm512_mask_blend_ps(mask_vec_2, + max_vec_2.v, + x_vec_2.v); + /* Take values from max_vector if corresponding bit in mask_vector is 0 + * otherwise take value from x_vector, this is accumulated maximum value + * from max_vector and x_vector to mask_vector */ + max_vec_3.v = _mm512_mask_blend_ps(mask_vec_3, + max_vec_3.v, + x_vec_3.v); + /* Take values from maxIndex_vector if corresponding bit in mask_vector + * is 0 otherwise take value from index_vec_1, this is accumulated + * maximum value index from maxIndex_vector and index_vec_1 + * to maxIndex_vector */ + maxInd_vec_1.v = _mm512_mask_blend_ps(mask_vec_1, + maxInd_vec_1.v, + index_vec_1.v); + /* Take values from maxIndex_vector if corresponding bit in mask_vector + * is 0 otherwise take value from index_vec_1, this is accumulated + * maximum value index from maxIndex_vector and index_vec_1 + * to maxIndex_vector */ + maxInd_vec_2.v = _mm512_mask_blend_ps(mask_vec_2, + maxInd_vec_2.v, + ind_vec_2.v); + /* Take values from maxIndex_vector if corresponding bit in mask_vector + * is 0 otherwise take value from index_vec_1, this is accumulated + * maximum value index from maxIndex_vector and index_vec_1 + * to maxIndex_vector */ + maxInd_vec_3.v = _mm512_mask_blend_ps(mask_vec_3, + maxInd_vec_3.v, + ind_vec_3.v); + + // Mask generation for blending max_vec_2 and max_vec_3 to max_vec_2. + // mask = max_vec_2 - max_vec_3 + mask.v = _mm512_sub_ps(max_vec_2.v, max_vec_3.v); + // Type cast mask from IEEE754 (float) to integer type + // This operation will not need a new register, its just to convince + // the compiler. But its accounted as seperate register in the + // above calculations + intMask = _mm512_castps_si512(mask.v); + // Extract the signbit and build the mask. + mask_vec_2 = _mm512_movepi32_mask(intMask); + + // Blend to obtain 1 vector each of max values and index. + /* Take values from max_vec_2 if corresponding bit in mask_vec_2 + * is 0 otherwise take value from max_vec_3, this is accumulated + * maximum value from max_vec_2 and max_vec_3 to mask_vec_2 */ + max_vec_2.v = _mm512_mask_blend_ps(mask_vec_2, + max_vec_2.v, + max_vec_3.v); + /* Take values from maxInd_vec_2 if corresponding bit in mask_vector + * is 0 otherwise take value from maxInd_vec_3, this is accumulated + * maximum value index from maxInd_vec_2 and maxInd_vec_3 + * to maxInd_vec_2 */ + maxInd_vec_2.v = _mm512_mask_blend_ps(mask_vec_2, + maxInd_vec_2.v, + maxInd_vec_3.v); + + // Mask generation for blending max_vec_1 and max_vec_2 into max_vec_1. + // mask = max_vec_1 - max_vec_2 + mask.v = _mm512_sub_ps(max_vec_1.v, max_vec_2.v); + // Type cast mask from IEEE754 (float) to integer type + // This operation will not need a new register, its just to convince + // the compiler. But its accounted as seperate register in the + // above calculations + intMask = _mm512_castps_si512(mask.v); + // Extract the signbit and build the mask. 
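To summarize the unrolled iteration (its final combine follows right below): five freshly loaded blocks plus the running maximum give six (value, index) candidates, which are reduced pairwise, tournament style, into a single per-lane running maximum. A scalar outline of that combine order, assuming the NaN-sentinel substitution has already been applied (type and helper names are illustrative, not part of the patch):

// Scalar outline of the tournament above; on vectors, each "wins" is one
// subtract / mask / blend step applied to a value vector and its index
// vector together.
typedef struct { float val; float idx; } cand_t;

static cand_t wins( cand_t a, cand_t b )
{
    // Same sign-of-difference rule as the vector code; ties keep 'a'.
    return ( a.val - b.val < 0.0f ) ? b : a;
}

static cand_t combine6( cand_t max1, cand_t x1, cand_t max2, cand_t x2,
                        cand_t max3, cand_t x3 )
{
    cand_t r1 = wins( max1, x1 );      // running max vs. first new block
    cand_t r2 = wins( max2, x2 );      // second pair of new blocks
    cand_t r3 = wins( max3, x3 );      // third pair of new blocks
    return wins( r1, wins( r2, r3 ) ); // intermediate results folded into one
}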
+ mask_vec_1 = _mm512_movepi32_mask(intMask); + + // Final blend to the master max_vec_1 and maxInd_vec_1 + /* Take values from max_vec_1 if corresponding bit in mask_vec_1 + * is 0 otherwise take value from max_vec_2, this is accumulated + * maximum value from max_vec_1 and max_vec_2 to mask_vec_1 */ + max_vec_1.v = _mm512_mask_blend_ps(mask_vec_1, max_vec_1.v, max_vec_2.v); + /* Take values from maxInd_vec_1 if corresponding bit in mask_vector + * is 0 otherwise take value from maxInd_vec_2, this is accumulated + * maximum value index from maxInd_vec_1 and maxInd_vec_2 + * to maxInd_vec_1 */ + maxInd_vec_1.v = _mm512_mask_blend_ps(mask_vec_1, + maxInd_vec_1.v, + maxInd_vec_2.v); + + // Increment the index vector to point to next indexes. + index_vec_1.v = _mm512_add_ps(ind_vec_3.v, inc_vec.v); + } + + for (; i < num_iter; i++) + { + /* + Take vector one by one, above code makes max_vec_1 + contain the first 16 elements, now with the max vector + as first 16 elements (abs), we need to load next 16 elements + into x_vec_1 (abs). Now with those we can safely removeNan + which will put -ve values as NAN. + + These -ve values of NAN decreases by 1 in each iteration, + this helps us find the first NAN value. + */ + // Load Value x values + x_vec_1.v = _mm512_loadu_ps(x); + // x_vec_1 = abs(x_vec_1) + x_vec_1.v = _mm512_andnot_ps(abs_mask.v, x_vec_1.v); + // Remove nan and replace with -ve values + x_vec_1.v = remove_NAN_512_s(x_vec_1.v); + + // Mask Generation + // mask = max_vec_1 - x_vec_1 + mask.v = _mm512_sub_ps(max_vec_1.v, x_vec_1.v); + // Extract the signbit and build the mask. + mask_vec_1 = _mm512_movepi32_mask(_mm512_castps_si512(mask.v)); + /* Take values from max_vec_1 if corresponding bit in + * mask_vec_1 is 0 otherwise take value from x_vec_1, + * this is accumulated maximum value from max_vec_1 and + * x_vec_1 to mask_vec_1 */ + max_vec_1.v = _mm512_mask_blend_ps(mask_vec_1, + max_vec_1.v, + x_vec_1.v); + /* Take values from maxInd_vec_1 if corresponding bit in + * mask_vector is 0 otherwise take value from index_vec_1, + * this is accumulated maximum value index from maxInd_vec_1 + * and index_vec_1 to maxInd_vec_1 */ + maxInd_vec_1.v = _mm512_mask_blend_ps(mask_vec_1, + maxInd_vec_1.v, + index_vec_1.v); + + // Increment the index vector to point to next indexes. + index_vec_1.v = _mm512_add_ps(index_vec_1.v, inc_vec.v); + + // Increment x vector as we have loaded 16 values + x += num_vector_elements; + } + + num_remain = (n - ((i)*16)); + + /* + Now take the max vector and produce the max value from + the max vector by slicing and comparing with itself, + until we are left with just one index position and max value. 
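+ The reduction below halves the active width at every step:
+ 16 floats (512-bit) -> 8 (256-bit) -> 4 (128-bit) -> 2 -> 1,
+ each time applying the same subtract/blend idiom to both the value
+ vector and the index vector so the two stay in sync.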
+ */ + // Split max to hi and lo + max_vec_hi.v = _mm512_extractf32x8_ps(max_vec_1.v, 1); + max_vec_lo.v = _mm512_extractf32x8_ps(max_vec_1.v, 0); + + // Split maxIndex to hi and lo + maxInd_vec_hi.v = _mm512_extractf32x8_ps(maxInd_vec_1.v, 1); + maxInd_vec_lo.v = _mm512_extractf32x8_ps(maxInd_vec_1.v, 0); + + // Compare max_vec_hi > max_vec_1 + // mask_vec_lo = max_vec_lo - max_vec_hi + mask_vec_lo.v = _mm256_sub_ps(max_vec_lo.v, max_vec_hi.v); + + /* Take values from max_vec_lo if corresponding bit in mask_vec_lo + * is 0 otherwise take value from max_vec_hi, this is accumulated + * maximum value from max_vec_lo and max_vec_hi to max_vec_lo */ + max_vec_lo.v = _mm256_blendv_ps(max_vec_lo.v, + max_vec_hi.v, + mask_vec_lo.v); + /* Take values from maxInd_vec_lo if corresponding bit + * in mask_vec_lo is 0 otherwise take value from maxInd_vec_hi, + * this is accumulated maximum value from maxInd_vec_lo and + * maxInd_vec_hi to maxInd_vec_lo */ + maxInd_vec_lo.v = _mm256_blendv_ps(maxInd_vec_lo.v, + maxInd_vec_hi.v, + mask_vec_lo.v); + + // Split max_lo to hi and lo + max_vec_lo_hi.v = _mm256_extractf128_ps(max_vec_lo.v, 1); + max_vec_lo_lo.v = _mm256_extractf128_ps(max_vec_lo.v, 0); + + // Split maxIndex_lo to hi and lo + maxInd_vec_lo_hi.v = _mm256_extractf128_ps(maxInd_vec_lo.v, 1); + maxInd_vec_lo_lo.v = _mm256_extractf128_ps(maxInd_vec_lo.v, 0); + + // mask_vec_lo_lo = max_vec_lo_lo - max_vec_lo_hi + mask_vec_lo_lo.v = _mm_sub_ps(max_vec_lo_lo.v, max_vec_lo_hi.v); + /* Take values from max_vec_lo_lo if corresponding bit in + * mask_vec_lo_lo is 0 otherwise take value from max_vec_lo_hi, + * this is accumulated maximum value from max_vec_lo_lo and + * max_vec_lo_hi to max_vec_lo_lo */ + max_vec_lo_lo.v = _mm_blendv_ps(max_vec_lo_lo.v, + max_vec_lo_hi.v, + mask_vec_lo_lo.v); + /* Take values from maxInd_vec_lo if corresponding bit + * in mask_vec_lo_lo is 0 otherwise take value from maxInd_vec_hi, + * this is accumulated maximum value from maxInd_vec_lo and + * maxInd_vec_hi to maxInd_vec_lo */ + maxInd_vec_lo_lo.v = _mm_blendv_ps(maxInd_vec_lo_lo.v, + maxInd_vec_lo_hi.v, + mask_vec_lo_lo.v); + + // Take 64 high bits of max_lo_lo and put it to 64 low bits, rest 1st value + /* Example max_vec_lo_lo is {a, b, x, y} + * After max_vec_lo_hi.v = _mm_permute_ps(max_vec_lo_lo.v, 14); + * max_vec_lo_hi is {x, y, a, a} (essentially folding the vector) + */ + max_vec_lo_hi.v = _mm_permute_ps(max_vec_lo_lo.v, 14); + // Fold the vector same as max_vector + maxInd_vec_lo_hi.v = _mm_permute_ps(maxInd_vec_lo_lo.v, 14); + + // mask_vec_lo_lo = max_vec_lo_lo - max_vec_lo_hi + mask_vec_lo_lo.v = _mm_sub_ps(max_vec_lo_lo.v, max_vec_lo_hi.v); + /* Take values from max_vec_lo_lo if corresponding bit in + * mask_vec_lo_lo is 0 otherwise take value from max_vec_lo_hi, + * this is accumulated maximum value from max_vec_lo_lo and + * max_vec_lo_hi to max_vec_lo_lo */ + max_vec_lo_lo.v = _mm_blendv_ps(max_vec_lo_lo.v, + max_vec_lo_hi.v, + mask_vec_lo_lo.v); + /* Take values from maxInd_vec_lo if corresponding bit + * in mask_vec_lo_lo is 0 otherwise take value from maxInd_vec_hi, + * this is accumulated maximum value from maxInd_vec_lo and + * maxInd_vec_hi to maxInd_vec_lo */ + maxInd_vec_lo_lo.v = _mm_blendv_ps(maxInd_vec_lo_lo.v, + maxInd_vec_lo_hi.v, + mask_vec_lo_lo.v); + + // Take max_vec_lo_lo.f[1] and put it to max_vec_lo_hi.f[0] + /* Example max_vec_lo_lo is {a, b, x, y} + * After max_vec_lo_hi.v = _mm_permute_ps(max_vec_lo_lo.v, 1); + * max_vec_lo_hi is {b, a, a, a} (essentially folding the vector) + */ 
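+ // (The permute immediate encodes one 2-bit source-lane selector per
+ // destination lane: 14 = 0b00001110 selects lanes 2,3,0,0 and 1 = 0b00000001
+ // selects lanes 1,0,0,0, which produces the folds shown in the examples
+ // above.)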
+ max_vec_lo_hi.v = _mm_permute_ps(max_vec_lo_lo.v, 1);
+ // Do the same operation.
+ maxInd_vec_lo_hi.v = _mm_permute_ps(maxInd_vec_lo_lo.v, 1);
+
+ // mask_vec_lo_lo = max_vec_lo_lo - max_vec_lo_hi
+ mask_vec_lo_lo.v = _mm_sub_ps(max_vec_lo_lo.v, max_vec_lo_hi.v);
+ /* Take values from max_vec_lo_lo if corresponding bit in
+ * mask_vec_lo_lo is 0 otherwise take value from max_vec_lo_hi,
+ * this is accumulated maximum value from max_vec_lo_lo and
+ * max_vec_lo_hi to max_vec_lo_lo */
+ max_vec_lo_lo.v = _mm_blendv_ps(max_vec_lo_lo.v,
+ max_vec_lo_hi.v,
+ mask_vec_lo_lo.v);
+ /* Take values from maxInd_vec_lo_lo if corresponding bit
+ * in mask_vec_lo_lo is 0 otherwise take value from maxInd_vec_lo_hi,
+ * this is accumulated maximum value index from maxInd_vec_lo_lo and
+ * maxInd_vec_lo_hi to maxInd_vec_lo_lo */
+ maxInd_vec_lo_lo.v = _mm_blendv_ps(maxInd_vec_lo_lo.v,
+ maxInd_vec_lo_hi.v,
+ mask_vec_lo_lo.v);
+ /* We have kept folding and comparing until a single maximum value and
+ * its index remain; record them as the final answer. */
+ fndInd = maxInd_vec_lo_lo.f[0];
+ fndMaxVal = max_vec_lo_lo.f[0];
+ // A negative result means the accumulated maximum was a NaN sentinel.
+ if (fndMaxVal < 0)
+ {
+ // So report NAN.
+ fndMaxVal = NAN;
+ }
+ // Finish off the remaining values using normal instructions
+ for (int i = n - num_remain; i < n; i++)
+ {
+ float absval = fabsf(*(x));
+ if (fndMaxVal < absval || (isnan(absval) && !isnan(fndMaxVal)))
+ {
+ fndMaxVal = absval;
+ fndInd = i;
+ }
+ x += 1;
+ }
+ }
+
+ // Issue vzeroupper instruction to clear upper lanes of ymm registers.
+ // This avoids a performance penalty caused by false dependencies when
+ // transitioning from AVX to SSE instructions (which may occur
+ // later, especially if BLIS is compiled with -mfpmath=sse).
+ _mm256_zeroupper();
+
+ /* Store final index to output variable. */
+ *i_max = fndInd;
+ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3)
+}
+
+// -----------------------------------------------------------------------------
+/* Converts all NaNs to a negative number less than any previously encountered NaN. */
+__m512d remove_NAN_512d(__m512d vec)
+{
+
+ static int iter;
+ static __m512d sign_mask;
+
+ __m512d vec_mask;
+ __m512i int_mask_vec;
+ __mmask8 vec_mask8;
+
+ iter = iter - 1;
+
+ sign_mask = _mm512_set1_pd(-0.f);
+
+ //numbers other than NAN will become 0
+ vec_mask = _mm512_mul_pd(vec, sign_mask);
+
+ //producing an 8-bit mask
+ int_mask_vec = _mm512_castpd_si512(vec_mask);
+ vec_mask8 = _mm512_movepi64_mask(int_mask_vec);
+
+ //replacing all the NAN with negative numbers
+ vec = _mm512_mask_blend_pd(vec_mask8, _mm512_set1_pd(-1 + iter), vec);
+
+ return vec;
+}
+
+//----------------------------------------------------------------------------------------------------
+void bli_damaxv_zen_int_avx512(
+ dim_t n,
+ double *restrict x, inc_t incx,
+ dim_t *restrict i_max,
+ cntx_t *restrict cntx)
+{
+ AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_3)
+ double *minus_one = PASTEMAC(d, m1);
+ dim_t *zero_i = PASTEMAC(i, 0);
+
+ double chi1_r;
+ //double chi1_i;
+ double abs_chi1;
+ double abs_chi1_max;
+ dim_t i_max_l;
+ dim_t i;
+
+ /* If the vector length is zero, return early. This directly emulates
+ the behavior of netlib BLAS's i?amax() routines. */
+ if (bli_zero_dim1(n))
+ {
+ PASTEMAC(i, copys)
+ (*zero_i, *i_max);
+ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3)
+ return;
+ }
+
+ /* Initialize the index of the maximum absolute value to zero. 
*/ + PASTEMAC(i, copys) + (*zero_i, i_max_l); + + /* Initialize the maximum absolute value search candidate with + -1, which is guaranteed to be less than all values we will + compute. */ + PASTEMAC(d, copys) + (*minus_one, abs_chi1_max); + + // For non-unit strides, or very small vector lengths, compute with + // scalar code. + if (incx != 1 || n < 8) + { + for (i = 0; i < n; ++i) + { + double *chi1 = x + (i)*incx; + + /* Get the real and imaginary components of chi1. */ + chi1_r = *chi1; + + /* Replace chi1_r and chi1_i with their absolute values. */ + chi1_r = fabs(chi1_r); + + /* Add the real and imaginary absolute values together. */ + abs_chi1 = chi1_r; + + /* If the absolute value of the current element exceeds that of + the previous largest, save it and its index. If NaN is + encountered, then treat it the same as if it were a valid + value that was smaller than any previously seen. This + behavior mimics that of LAPACK's i?amax(). */ + if (abs_chi1_max < abs_chi1 || (isnan(abs_chi1) && !isnan(abs_chi1_max))) + { + abs_chi1_max = abs_chi1; + i_max_l = i; + } + } + } + else + { + + dim_t iterations, n_left, vector_length = 8, unrollCount = 0; + + //mask bits + __mmask8 mask_got_01, mask_got_23; + + //YMM0 - YMM6 registers + v4df_t max_hi, max_lo, max_ind_hi, max_ind_lo, + mask_final, inter_result, inter_ind; + + //XMM0 to XMM4 registers + v2dd_t max_vec_hi, max_vec_lo, max_ind_hi_128, + max_ind_lo_128, mask_vec_lo; + + //ZMM0 to ZMM13 registers + v8df_t zmm0, zmm1, zmm2, zmm3, zmm4_Ind, + zmm5_Ind, zmm6_Ind, zmm7_Ind, max_01, + max_23, final_max, max_array, max_ind, inc_vec; + + //ZMM14 to ZMM16 registers + __m512d mask_01, mask_23, sign_mask; + + //Intermediate int mask values + __m512i int_mask_01, int_mask_23; + + // Initialize sign mask + sign_mask = _mm512_set1_pd(-0.f); + + //Initializing the indexes of the base case of max vector + zmm4_Ind.v = _mm512_set_pd(7, 6, 5, 4, 3, 2, 1, 0); + inc_vec.v = _mm512_set1_pd(8); //Vector for incrementing + + // Initializing the max array as vec [ 0 : 512 ] + max_array.v = _mm512_loadu_pd(x); + + // Taking the absolute value and removing the NAN + max_array.v = _mm512_andnot_pd(sign_mask, max_array.v); + max_array.v = remove_NAN_512d(max_array.v); + + // Initializing the maximumum index + max_ind.v = _mm512_set_pd(7, 6, 5, 4, 3, 2, 1, 0); + x += vector_length; + + //Incrementing to make the vector + //to point to the next 8 elements + zmm4_Ind.v = _mm512_add_pd(zmm4_Ind.v, inc_vec.v); + + /* Loop unrolled by a factor of 4 + At the end of the loop max_array holds the largest element + in each corresponding vector index */ + for (unrollCount = 8; (unrollCount + 31) < n; unrollCount += 32) + { + // Taking 32 elements + // Taking only the absolute values of the registers + // Removing the NAN values and replacing it + // with negative numbers + zmm0.v = _mm512_loadu_pd(x); + zmm0.v = _mm512_andnot_pd(sign_mask, zmm0.v); + zmm0.v = remove_NAN_512d(zmm0.v); + x += vector_length; + + zmm1.v = _mm512_loadu_pd(x); + zmm5_Ind.v = _mm512_add_pd(zmm4_Ind.v, inc_vec.v); + zmm1.v = _mm512_andnot_pd(sign_mask, zmm1.v); + zmm1.v = remove_NAN_512d(zmm1.v); + x += vector_length; + + zmm2.v = _mm512_loadu_pd(x); + zmm6_Ind.v = _mm512_add_pd(zmm5_Ind.v, inc_vec.v); + zmm2.v = _mm512_andnot_pd(sign_mask, zmm2.v); + zmm2.v = remove_NAN_512d(zmm2.v); + x += vector_length; + + zmm3.v = _mm512_loadu_pd(x); + zmm7_Ind.v = _mm512_add_pd(zmm6_Ind.v, inc_vec.v); + zmm3.v = _mm512_andnot_pd(sign_mask, zmm3.v); + zmm3.v = remove_NAN_512d(zmm3.v); + x += vector_length; + 
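+ // The four freshly loaded vectors are now reduced pairwise (a small
+ // tournament): zmm0 vs zmm1 and zmm2 vs zmm3 first, then the two winners
+ // against each other, and finally the result against the running
+ // max_array/max_ind accumulators.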
+ /*Using sub function to generating the mask + as a 512d type*/ + mask_01 = _mm512_sub_pd(zmm0.v, zmm1.v); + mask_23 = _mm512_sub_pd(zmm2.v, zmm3.v); + + //Converting the 512d mask to a 512i mask + int_mask_01 = _mm512_castpd_si512(mask_01); + int_mask_23 = _mm512_castpd_si512(mask_23); + + /*Converting the 512i mask + to mmask type to use the mask bits*/ + mask_got_01 = _mm512_movepi64_mask(int_mask_01); + mask_got_23 = _mm512_movepi64_mask(int_mask_23); + + //Storing the largest elements in index % 8 position for + //vector 1 and 2, and the index of the corresponding element + max_01.v = _mm512_mask_blend_pd(mask_got_01, zmm0.v, zmm1.v); + zmm5_Ind.v = _mm512_mask_blend_pd(mask_got_01, zmm4_Ind.v, zmm5_Ind.v); + + //Storing the largest elements in index % 8 position for + //vector 3 and 4, and the index of the corresponding element + max_23.v = _mm512_mask_blend_pd(mask_got_23, zmm2.v, zmm3.v); + zmm6_Ind.v = _mm512_mask_blend_pd(mask_got_23, zmm6_Ind.v, zmm7_Ind.v); + + //Generating mask for the intermediate max vector + mask_01 = _mm512_sub_pd(max_01.v, max_23.v); + int_mask_01 = _mm512_castpd_si512(mask_01); + mask_got_01 = _mm512_movepi64_mask(int_mask_01); + + /*Storing the largest elements in index % 8 position for + the intermediate max vectors, + and the index of the corresponding element*/ + final_max.v = _mm512_mask_blend_pd(mask_got_01, max_01.v, max_23.v); + zmm5_Ind.v = _mm512_mask_blend_pd(mask_got_01, zmm5_Ind.v, zmm6_Ind.v); + + //Generating the mask for final max vector and base max vector + mask_01 = _mm512_sub_pd(max_array.v, final_max.v); + int_mask_01 = _mm512_castpd_si512(mask_01); + mask_got_01 = _mm512_movepi64_mask(int_mask_01); + + // Result is the maximum of all index % 8 locations + max_array.v = _mm512_mask_blend_pd(mask_got_01, max_array.v, final_max.v); + max_ind.v = _mm512_mask_blend_pd(mask_got_01, max_ind.v, zmm5_Ind.v); + + // Incrementing the index to point to the next 8 locations + zmm4_Ind.v = _mm512_add_pd(zmm7_Ind.v, inc_vec.v); + } + + // Calculating the number of iterations left + iterations = (n - unrollCount) / vector_length; + n_left = (n - unrollCount) % vector_length; + + /* At the end of the loop max_array holds the largest element + in each corresponding vector index */ + for (int i = 1; i < iterations; ++i) + { + // Taking 32 elements + // Taking only the absolute values of the registers + // Removing the NAN values and replacing it + // with negative numbers + zmm0.v = _mm512_loadu_pd(x); + zmm0.v = _mm512_abs_pd(zmm0.v); + zmm0.v = remove_NAN_512d(zmm0.v); + + //Generating mask for the intermediate max vector + mask_01 = _mm512_sub_pd(max_array.v, zmm0.v); + int_mask_01 = _mm512_castpd_si512(mask_01); + mask_got_01 = _mm512_movepi64_mask(int_mask_01); + + // Result is the maximum of all index % 8 locations + max_array.v = _mm512_mask_blend_pd(mask_got_01, max_array.v, zmm0.v); + + //Storing the index of the corresponding max array elemets + max_ind.v = _mm512_mask_blend_pd(mask_got_01, max_ind.v, zmm4_Ind.v); + + //Incrementing the vector the point to the next location + //Incrementing the vector indexes + x += vector_length; + zmm4_Ind.v = _mm512_add_pd(zmm4_Ind.v, inc_vec.v); + } + + //Breaking max array into vectors of length 4 + //Taking upper and lower halves + max_hi.v = _mm512_extractf64x4_pd(max_array.v, 1); + max_ind_hi.v = _mm512_extractf64x4_pd(max_ind.v, 1); + max_lo.v = _mm512_extractf64x4_pd(max_array.v, 0); + max_ind_lo.v = _mm512_extractf64x4_pd(max_ind.v, 0); + + //Generating the mask for blending + mask_final.v = 
_mm256_sub_pd(max_hi.v, max_lo.v); + + // Storing the max of max array index % 4 + inter_result.v = _mm256_blendv_pd(max_hi.v, max_lo.v, mask_final.v); + inter_ind.v = _mm256_blendv_pd(max_ind_hi.v, max_ind_lo.v, mask_final.v); + + //Breaking max array into vectors of length 2 + max_vec_lo.v = _mm256_extractf128_pd(inter_result.v, 0); + max_vec_hi.v = _mm256_extractf128_pd(inter_result.v, 1); + max_ind_hi_128.v = _mm256_extractf128_pd(inter_ind.v, 1); + max_ind_lo_128.v = _mm256_extractf128_pd(inter_ind.v, 0); + + //Generating the mask for blending + mask_vec_lo.v = _mm_sub_pd(max_vec_lo.v, max_vec_hi.v); + + // Storing the max of max array index % 2 + max_vec_lo.v = _mm_blendv_pd(max_vec_lo.v, max_vec_hi.v, mask_vec_lo.v); + max_ind_lo_128.v = _mm_blendv_pd(max_ind_lo_128.v, max_ind_hi_128.v, mask_vec_lo.v); + + max_vec_hi.v = _mm_permute_pd(max_vec_lo.v, 1); + max_ind_hi_128.v = _mm_permute_pd(max_ind_lo_128.v, 1); + + //Performing work of CMP128 i.e generating mask + mask_vec_lo.v = _mm_sub_pd(max_vec_lo.v, max_vec_hi.v); + + //Finding the maximum element + max_vec_lo.v = _mm_blendv_pd(max_vec_lo.v, max_vec_hi.v, mask_vec_lo.v); + max_ind_lo_128.v = _mm_blendv_pd(max_ind_lo_128.v, max_ind_hi_128.v, mask_vec_lo.v); + + abs_chi1_max = max_vec_lo.d[0]; + + //If the largest number is negative it is NAN + if (abs_chi1_max < 0) + abs_chi1_max = NAN; + + i_max_l = max_ind_lo_128.d[0]; + + for (i = n - n_left; i < n; i++) + { + double *chi1 = x; + + /* Get the real and imaginary components of chi1. */ + chi1_r = *chi1; + + /* Replace chi1_r and chi1_i with their absolute values. */ + abs_chi1 = fabs(chi1_r); + + /* If the absolute value of the current element exceeds that of + the previous largest, save it and its index. If NaN is + encountered, return the index of the first NaN. This + behavior mimics that of LAPACK's i?amax(). */ + if (abs_chi1_max < abs_chi1 || (isnan(abs_chi1) && !isnan(abs_chi1_max))) + { + abs_chi1_max = abs_chi1; + i_max_l = i; + } + + x += 1; + } + } + + // Return value + *i_max = i_max_l; + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3) +} diff --git a/kernels/zen4/CMakeLists.txt b/kernels/zen4/CMakeLists.txt new file mode 100644 index 0000000000..827d91bbe2 --- /dev/null +++ b/kernels/zen4/CMakeLists.txt @@ -0,0 +1,5 @@ +##Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.## + +add_subdirectory(1) + + diff --git a/kernels/zen4/bli_kernels_zen4.h b/kernels/zen4/bli_kernels_zen4.h new file mode 100644 index 0000000000..476eeaaeed --- /dev/null +++ b/kernels/zen4/bli_kernels_zen4.h @@ -0,0 +1,39 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. 
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+*/
+
+// -- level-1v --
+
+// amaxv (intrinsics)
+AMAXV_KER_PROT( float, s, amaxv_zen_int_avx512 )
+AMAXV_KER_PROT( double, d, amaxv_zen_int_avx512 )

From c03699b97a26ec11eaf2984df22e583c2d6c797c Mon Sep 17 00:00:00 2001
From: Chandrashekara K R
Date: Thu, 9 Jun 2022 14:32:54 +0530
Subject: [PATCH 131/243] AOCL-Windows: Updating the Windows build system to
 add the /arch:AVX512 compiler flag only to zen4-specific source files.

Change-Id: Ia4fa65a831a00ce37f97075db6812be048bfe0bc
---
 CMakeLists.txt | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index ec11b44a6f..2558494d90 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -318,7 +318,9 @@ endif()
 
 if(${TARGET_ARCH} STREQUAL zen4 OR
 ${TARGET_ARCH} STREQUAL amdzen)
-    add_definitions(/arch:AVX512)
+    set_source_files_properties(${CMAKE_CURRENT_SOURCE_DIR}/kernels/zen4/1/bli_amaxv_zen_int_avx512.c PROPERTIES COMPILE_FLAGS /arch:AVX512)
+    set_source_files_properties(${CMAKE_CURRENT_SOURCE_DIR}/kernels/skx/3/bli_dgemm_skx_asm_16x14.c PROPERTIES COMPILE_FLAGS /arch:AVX512)
+    set_source_files_properties(${CMAKE_CURRENT_SOURCE_DIR}/kernels/skx/3/bli_sgemm_skx_asm_32x12_l2.c PROPERTIES COMPILE_FLAGS /arch:AVX512)
 endif()
 
 set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS}  /W0 ")

From 6c112632a76ff18616c2ddbdb08d3c932a656011 Mon Sep 17 00:00:00 2001
From: mkadavil
Date: Thu, 28 Apr 2022 19:17:57 +0530
Subject: [PATCH 132/243] Low precision gemm integrated as aocl_gemm addon.

 - Multi-threaded int8 GEMM (input: uint8_t, int8_t; output: int32_t) with an
   AVX512_VNNI based micro-kernel. Parallelization is supported along the m
   and n dimensions (an illustrative usage sketch follows this list).
 - Multi-threaded B matrix reorder support for sgemm. Reordering packs the
   entire B matrix upfront, so that sgemm can use a packed B matrix without
   incurring packing costs at run time.
 - Makefile updates to the addon make rules to compile AVX512 code for
   selected files in the addon folder.
 - CPU feature query enhancements to check for the AVX512_VNNI flag.
 - Bench for int8 gemm and sgemm with B matrix reorder. Supports a performance
   mode for benchmarking and an accuracy mode for testing code correctness.
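 A minimal sketch of how the new int8 addon API is intended to be called. The
 function signatures come from the headers added in this patch; the 'B'
 matrix-type character and the 'n'/'r' memory-format characters are assumptions
 about the char mappings, and the helper name is hypothetical:

    #include <stdlib.h>
    #include "blis.h"
    #include "aocl_gemm.h"

    // C = A * B with A (m x k, uint8_t), B (k x n, int8_t), C (m x n, int32_t),
    // all row major with lda = k, ldb = n, ldc = n and no transpose.
    void example_u8s8s32_gemm( dim_t m, dim_t n, dim_t k,
                               const uint8_t* a, const int8_t* b, int32_t* c )
    {
        // Reorder B once up front so repeated GEMM calls skip runtime packing.
        siz_t b_size = aocl_get_reorder_buf_size_u8s8s32os32( 'B', k, n );
        int8_t* b_reord = ( int8_t* )malloc( b_size );
        aocl_reorder_u8s8s32os32( 'B', b, b_reord, k, n, n );

        // alpha = 1, beta = 0; A is unpacked ('n'), B is reordered ('r').
        aocl_gemm_u8s8s32os32( 'n', 'n', m, n, k, 1, a, k, 'n',
                               b_reord, n, 'r', 0, c, n );

        free( b_reord );
    }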
AMD-Internal: [CPUPL-2102] Change-Id: I8fb25f5c2fbd97d756f95b623332cb29e3b8d182 --- Makefile | 6 +- addon/aocl_gemm/aocl_gemm.h | 43 + addon/aocl_gemm/aocl_gemm_f32f32f32of32.c | 170 + addon/aocl_gemm/aocl_gemm_f32f32f32of32.h | 60 + .../aocl_gemm/aocl_gemm_f32f32f32of32_utils.c | 263 ++ .../aocl_gemm/aocl_gemm_f32f32f32of32_utils.h | 59 + addon/aocl_gemm/aocl_gemm_u8s8s32os32.c | 164 + addon/aocl_gemm/aocl_gemm_u8s8s32os32.h | 59 + addon/aocl_gemm/aocl_gemm_u8s8s32os32_utils.c | 139 + addon/aocl_gemm/aocl_gemm_u8s8s32os32_utils.h | 55 + .../frame/f32f32f32/lpgemm_f32f32f32.c | 319 ++ .../frame/f32f32f32/lpgemm_f32f32f32.h | 61 + addon/aocl_gemm/frame/lpgemm_config.c | 88 + addon/aocl_gemm/frame/lpgemm_config.h | 54 + addon/aocl_gemm/frame/lpgemm_types.h | 116 + .../threading/lpgemm_thread_decor_openmp.c | 447 +++ .../threading/lpgemm_thread_decor_openmp.h | 68 + .../frame/threading/lpgemm_thrinfo_utils.h | 78 + .../aocl_gemm/frame/u8s8s32/lpgemm_reorder.c | 214 ++ .../aocl_gemm/frame/u8s8s32/lpgemm_reorder.h | 52 + .../aocl_gemm/frame/u8s8s32/lpgemm_u8s8s32.c | 340 ++ .../aocl_gemm/frame/u8s8s32/lpgemm_u8s8s32.h | 62 + addon/aocl_gemm/frame/u8s8s32/lpgemm_utils.c | 156 + addon/aocl_gemm/frame/u8s8s32/lpgemm_utils.h | 225 ++ .../kernels/u8s8s32/lpgemm_6x64rowmajor.h | 58 + .../u8s8s32/lpgemm_6x64rowmajor_amd512vnni.c | 693 ++++ .../kernels/u8s8s32/lpgemm_m_fringe.h | 118 + .../u8s8s32/lpgemm_m_fringe_amd512vnni.c | 1354 +++++++ .../kernels/u8s8s32/lpgemm_mn_fringe.h | 363 ++ .../u8s8s32/lpgemm_mn_fringe_amd512vnni.c | 3200 +++++++++++++++++ .../kernels/u8s8s32/lpgemm_n_fringe.h | 111 + .../u8s8s32/lpgemm_n_fringe_amd512vnni.c | 1561 ++++++++ .../aocl_gemm/kernels/u8s8s32/lpgemm_packa.h | 55 + .../kernels/u8s8s32/lpgemm_packa_amd512vnni.c | 518 +++ .../aocl_gemm/kernels/u8s8s32/lpgemm_packb.h | 65 + .../kernels/u8s8s32/lpgemm_packb_amd512vnni.c | 792 ++++ bench/bench_aocl_gemm/bench_input.txt | 377 ++ bench/bench_aocl_gemm/bench_lpgemm.c | 475 +++ bench/bench_aocl_gemm/data_gen_lpgemm.py | 45 + frame/base/bli_cpuid.c | 62 +- frame/base/bli_cpuid.h | 6 +- 41 files changed, 13146 insertions(+), 5 deletions(-) create mode 100644 addon/aocl_gemm/aocl_gemm.h create mode 100644 addon/aocl_gemm/aocl_gemm_f32f32f32of32.c create mode 100644 addon/aocl_gemm/aocl_gemm_f32f32f32of32.h create mode 100644 addon/aocl_gemm/aocl_gemm_f32f32f32of32_utils.c create mode 100644 addon/aocl_gemm/aocl_gemm_f32f32f32of32_utils.h create mode 100644 addon/aocl_gemm/aocl_gemm_u8s8s32os32.c create mode 100644 addon/aocl_gemm/aocl_gemm_u8s8s32os32.h create mode 100644 addon/aocl_gemm/aocl_gemm_u8s8s32os32_utils.c create mode 100644 addon/aocl_gemm/aocl_gemm_u8s8s32os32_utils.h create mode 100644 addon/aocl_gemm/frame/f32f32f32/lpgemm_f32f32f32.c create mode 100644 addon/aocl_gemm/frame/f32f32f32/lpgemm_f32f32f32.h create mode 100644 addon/aocl_gemm/frame/lpgemm_config.c create mode 100644 addon/aocl_gemm/frame/lpgemm_config.h create mode 100644 addon/aocl_gemm/frame/lpgemm_types.h create mode 100644 addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.c create mode 100644 addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.h create mode 100644 addon/aocl_gemm/frame/threading/lpgemm_thrinfo_utils.h create mode 100644 addon/aocl_gemm/frame/u8s8s32/lpgemm_reorder.c create mode 100644 addon/aocl_gemm/frame/u8s8s32/lpgemm_reorder.h create mode 100644 addon/aocl_gemm/frame/u8s8s32/lpgemm_u8s8s32.c create mode 100644 addon/aocl_gemm/frame/u8s8s32/lpgemm_u8s8s32.h create mode 100644 
addon/aocl_gemm/frame/u8s8s32/lpgemm_utils.c create mode 100644 addon/aocl_gemm/frame/u8s8s32/lpgemm_utils.h create mode 100644 addon/aocl_gemm/kernels/u8s8s32/lpgemm_6x64rowmajor.h create mode 100644 addon/aocl_gemm/kernels/u8s8s32/lpgemm_6x64rowmajor_amd512vnni.c create mode 100644 addon/aocl_gemm/kernels/u8s8s32/lpgemm_m_fringe.h create mode 100644 addon/aocl_gemm/kernels/u8s8s32/lpgemm_m_fringe_amd512vnni.c create mode 100644 addon/aocl_gemm/kernels/u8s8s32/lpgemm_mn_fringe.h create mode 100644 addon/aocl_gemm/kernels/u8s8s32/lpgemm_mn_fringe_amd512vnni.c create mode 100644 addon/aocl_gemm/kernels/u8s8s32/lpgemm_n_fringe.h create mode 100644 addon/aocl_gemm/kernels/u8s8s32/lpgemm_n_fringe_amd512vnni.c create mode 100644 addon/aocl_gemm/kernels/u8s8s32/lpgemm_packa.h create mode 100644 addon/aocl_gemm/kernels/u8s8s32/lpgemm_packa_amd512vnni.c create mode 100644 addon/aocl_gemm/kernels/u8s8s32/lpgemm_packb.h create mode 100644 addon/aocl_gemm/kernels/u8s8s32/lpgemm_packb_amd512vnni.c create mode 100644 bench/bench_aocl_gemm/bench_input.txt create mode 100644 bench/bench_aocl_gemm/bench_lpgemm.c create mode 100644 bench/bench_aocl_gemm/data_gen_lpgemm.py diff --git a/Makefile b/Makefile index 820954e3e5..f42914024e 100644 --- a/Makefile +++ b/Makefile @@ -598,10 +598,12 @@ endef define make-c99-addon-rule $(BASE_OBJ_ADDON_PATH)/%.o: $(ADDON_PATH)/%.$(2) $(BLIS_H_FLAT) $(ADDON_H99_FILES) $(MAKE_DEFS_MK_PATHS) ifeq ($(ENABLE_VERBOSE),yes) - $(CC) $(call get-addon-c99flags-for,$(1)) -c $$< -o $$@ + $$(if $$(findstring _amd512vnni,$$<),$$(eval LPGEMM_MARCH_VAR=icelake-server),$$(eval LPGEMM_MARCH_VAR=znver3)) + $(CC) -march=$$(LPGEMM_MARCH_VAR) $(call get-addon-c99flags-for,$(1)) -c $$< -o $$@ else @echo "Compiling $$@" $(call get-addon-c99text-for,$(1)) - @$(CC) $(call get-addon-c99flags-for,$(1)) -c $$< -o $$@ + $$(if $$(findstring _amd512vnni,$$<),$$(eval LPGEMM_MARCH_VAR=icelake-server),$$(eval LPGEMM_MARCH_VAR=znver3)) + @$(CC) -march=$$(LPGEMM_MARCH_VAR) $(call get-addon-c99flags-for,$(1)) -c $$< -o $$@ endif endef diff --git a/addon/aocl_gemm/aocl_gemm.h b/addon/aocl_gemm/aocl_gemm.h new file mode 100644 index 0000000000..446bdc11b7 --- /dev/null +++ b/addon/aocl_gemm/aocl_gemm.h @@ -0,0 +1,43 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_ADDON_LPGEMM +#define BLIS_ADDON_LPGEMM + +#include "aocl_gemm_u8s8s32os32.h" +#include "aocl_gemm_f32f32f32of32.h" +#include "aocl_gemm_u8s8s32os32_utils.h" +#include "aocl_gemm_f32f32f32of32_utils.h" + +#endif //BLIS_ADDON_LPGEMM diff --git a/addon/aocl_gemm/aocl_gemm_f32f32f32of32.c b/addon/aocl_gemm/aocl_gemm_f32f32f32of32.c new file mode 100644 index 0000000000..f3055ad3f9 --- /dev/null +++ b/addon/aocl_gemm/aocl_gemm_f32f32f32of32.c @@ -0,0 +1,170 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "aocl_gemm_f32f32f32of32.h" +#include "lpgemm_types.h" +#include "lpgemm_thread_decor_openmp.h" +#include "lpgemm_utils.h" +#include "lpgemm_f32f32f32.h" + +void aocl_gemm_f32f32f32of32 + ( + const char transa, + const char transb, + const dim_t m, + const dim_t n, + const dim_t k, + const float alpha, + const float* a, + const dim_t lda, + const char mem_format_a, + const float* b, + const dim_t ldb, + const char mem_format_b, + const float beta, + float* c, + const dim_t ldc + ) +{ + trans_t blis_transa; + trans_t blis_transb; + + // Check if avx ISA is supported, lpgemm fp32 matmul only works with it. + if ( bli_cpuid_is_avx_supported() == FALSE ) + { + printf(" AVX2 ISA not supported by processor, cannot perform lpgemm.\n"); + return; // Error. + } + + /* Initialize BLIS. 
*/ + bli_init_auto(); + + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); + AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(s), transa, transb, m, n, k,\ + (void*)&alpha, lda, ldb, (void*)&beta, ldc); + + // Null check for pointers. + if ( ( a == NULL ) || ( b == NULL ) || ( c == NULL ) ) + { + AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, \ + "Invalid pointers provided for input parameters."); + return; // Error. + } + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + bli_param_map_netlib_to_blis_trans( transa, &blis_transa ); + bli_param_map_netlib_to_blis_trans( transb, &blis_transb ); + + /* Perform BLAS parameter checking. */ + // Transpose not supported. + if ( ( blis_transa != BLIS_NO_TRANSPOSE ) || + ( blis_transb != BLIS_NO_TRANSPOSE ) ) + { + AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, \ + "Input matrix transpose not supported."); + return; // Error. + } + + // Row major input expected with leading dimensions equal to row stride. + if ( ( lda != k ) || ( ldb != n ) || ( ldc != n ) ) + { + AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, \ + "Column major and general stride not supported."); + return; // Error. + } + + // Check if dimensions are valid. + if ( ( m <= 0) || ( n <= 0 ) || ( k <= 0 ) || + ( lda <= 0 ) || ( ldb <= 0 ) || ( ldc <= 0 ) ) + { + AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, \ + "Invalid matrix dimensions."); + return; // Error. + } + + const inc_t rs_a = lda; + const inc_t cs_a = 1; + const inc_t rs_b = ldb; + const inc_t cs_b = 1; + const inc_t rs_c = ldc; + + AOCL_MEMORY_TAG mtag_a; + AOCL_MEMORY_TAG mtag_b; + + bli_param_map_char_to_lpmtag( mem_format_a, &mtag_a ); + bli_param_map_char_to_lpmtag( mem_format_b, &mtag_b ); + + // Only unreordered A supported now. + if ( mtag_a != UNPACKED ) + { + AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, \ + "A matrix packing/reordering not supported."); + return; // Error. + } + + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. + rntm_t rntm_g; + bli_rntm_init_from_global( &rntm_g ); + bli_membrk_rntm_set_membrk( &rntm_g ); + +#ifdef BLIS_ENABLE_OPENMP + lpgemm_f32f32f32of32_openmp_thread_decorator + ( + m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + c, rs_c, + alpha, beta, + &rntm_g + ); +#else + // Setting pack A by default for non open mp case. + bli_rntm_set_pack_a( 1, &rntm_g ); + + lpgemm_rowvar_f32f32f32of32 + ( + m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + c, rs_c, + alpha, beta, + &rntm_g, + NULL + ); +#endif + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); +} diff --git a/addon/aocl_gemm/aocl_gemm_f32f32f32of32.h b/addon/aocl_gemm/aocl_gemm_f32f32f32of32.h new file mode 100644 index 0000000000..8ce4e001f9 --- /dev/null +++ b/addon/aocl_gemm/aocl_gemm_f32f32f32of32.h @@ -0,0 +1,60 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. 
+ - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef AOCL_GEMM_F32F32F32OF32_H +#define AOCL_GEMM_F32F32F32OF32_H + +// Only supports matrices in row major format. This api can perform gemm with +// both normal as well as reordered B matrix as opposesd to sgemm (only +// supports former). This api can be considered analogous to packed sgemm api. +BLIS_EXPORT_ADDON void aocl_gemm_f32f32f32of32 + ( + const char transa, + const char transb, + const dim_t m, + const dim_t n, + const dim_t k, + const float alpha, + const float* a, + const dim_t lda, + const char mem_format_a, + const float* b, + const dim_t ldb, + const char mem_format_b, + const float beta, + float* c, + const dim_t ldc + ); + +#endif //AOCL_GEMM_F32F32F32OF32_H diff --git a/addon/aocl_gemm/aocl_gemm_f32f32f32of32_utils.c b/addon/aocl_gemm/aocl_gemm_f32f32f32of32_utils.c new file mode 100644 index 0000000000..84a611f605 --- /dev/null +++ b/addon/aocl_gemm/aocl_gemm_f32f32f32of32_utils.c @@ -0,0 +1,263 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "aocl_gemm_f32f32f32of32_utils.h" +#include "lpgemm_utils.h" + +siz_t aocl_get_reorder_buf_size_f32f32f32of32 + ( + const char mat_type, + const dim_t k, + const dim_t n + ) +{ + if ( ( k <= 0 ) || ( n <= 0 ) ) + { + return 0; // Error. + } + + // Check if avx ISA is supported, lpgemm fp32 matmul only works with it. + if ( bli_cpuid_is_avx_supported() == FALSE ) + { + printf(" AVX2 ISA not supported by processor, cannot perform lpgemm.\n"); + return 0; // Error. + } + + /* Initialize BLIS. */ + bli_init_auto(); + + // Query the global cntx. + cntx_t* cntx = bli_gks_query_cntx(); + + num_t dt = BLIS_FLOAT; + + AOCL_MATRIX_TYPE input_mat_type; + bli_param_map_char_to_lpmat_type( mat_type, &input_mat_type ); + + if ( input_mat_type == A_MATRIX ) + { + return 0; // A reorder not supported. + } + + const dim_t NR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NR, cntx ); + + // Extra space since packing does width in multiples of NR. + const dim_t n_reorder = ( ( n + NR - 1 ) / NR ) * NR; + + siz_t size_req = sizeof( float ) * k * n_reorder; + + return size_req; +} + +// Pack B into row stored column panels. +void aocl_reorder_f32f32f32of32 + ( + const char mat_type, + const float* input_buf_addr_b, + float* reorder_buf_addr_b, + const dim_t k, + const dim_t n, + const dim_t ldb + ) +{ + if ( ( input_buf_addr_b == NULL ) || ( reorder_buf_addr_b == NULL ) || + ( k <= 0 ) || ( n <= 0 ) || ( ldb < n ) ) + { + return; // Error. + } + + // Check if avx ISA is supported, lpgemm fp32 matmul only works with it. + if ( bli_cpuid_is_avx_supported() == FALSE ) + { + printf(" AVX2 ISA not supported by processor, cannot perform lpgemm.\n"); + return; // Error. + } + + /* Initialize BLIS. */ + bli_init_auto(); + + // Query the global cntx. + cntx_t* cntx = bli_gks_query_cntx(); + + num_t dt = BLIS_FLOAT; + + AOCL_MATRIX_TYPE input_mat_type; + bli_param_map_char_to_lpmat_type( mat_type, &input_mat_type ); + + if ( input_mat_type == A_MATRIX ) + { + return; // A reorder not supported. + } + + const dim_t NC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NC, cntx ); + const dim_t KC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_KC, cntx ); + const dim_t NR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NR, cntx ); + + // Only supports row major packing now. + inc_t rs_b = ldb; + inc_t cs_b = 1; + + inc_t rs_p = NR; + + float one_local = *PASTEMAC(s,1); + float* restrict kappa_cast = &one_local; + + // Set the schema to "row stored column panels" to indicate packing to + // conventional column-stored row panels. + pack_t schema = BLIS_PACKED_COL_PANELS; + trans_t transc = BLIS_NO_TRANSPOSE; + conj_t conjc = bli_extract_conj( transc ); + + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. + rntm_t rntm_g; + bli_rntm_init_from_global( &rntm_g ); + + dim_t n_threads = bli_rntm_num_threads( &rntm_g ); + n_threads = ( n_threads > 0 ) ? 
n_threads : 1; + +#ifdef BLIS_ENABLE_OPENMP + _Pragma( "omp parallel num_threads(n_threads)" ) + { + // Initialise a local thrinfo obj for work split across threads. + thrinfo_t thread_jc; + bli_thrinfo_set_n_way( n_threads, &thread_jc ); + bli_thrinfo_set_work_id( omp_get_thread_num(), &thread_jc ); +#else + { + // Initialise a local thrinfo obj for work split across threads. + thrinfo_t thread_jc; + bli_thrinfo_set_n_way( 1, &thread_jc ); + bli_thrinfo_set_work_id( 0, &thread_jc ); +#endif + // Compute the JC loop thread range for the current thread. Per thread + // gets multiple of NR columns. + dim_t jc_start, jc_end; + bli_thread_range_sub( &thread_jc, n, NR, FALSE, &jc_start, &jc_end ); + + for ( dim_t jc = jc_start; jc < jc_end; jc += NC ) + { + dim_t nc0 = bli_min( ( jc_end - jc ), NC ); + + dim_t jc_cur_loop = jc; + dim_t jc_cur_loop_rem = 0; + dim_t n_sub_updated; + + get_B_panel_reordered_start_offset_width + ( + jc, n, NC, NR, + &jc_cur_loop, &jc_cur_loop_rem, + &nc0, &n_sub_updated + ); + + // Compute the total number of iterations we'll need. + dim_t n_iter = ( nc0 + NR - 1 ) / NR; + + for ( dim_t pc = 0; pc < k; pc += KC ) + { + dim_t kc0 = bli_min( ( k - pc ), KC ); + inc_t ps_p = kc0 * NR; + + const float* b_temp = input_buf_addr_b + ( jc * cs_b ) + ( pc * rs_b ); + + // The offsets are calculated in such a way that it resembles + // the reorder buffer traversal in single threaded reordering. + // The panel boundaries (KCxNC) remain as it is accessed in + // single thread, and as a consequence a thread with jc_start + // inside the panel cannot consider NC range for reorder. It + // has to work with NC' < NC, and the offset is calulated using + // prev NC panels spanning k dim + cur NC panel spaning pc loop + // cur iteration + (NC - NC') spanning current kc0 (<= KC). + // + //Eg: Consider the following reordered buffer diagram: + // t1 t2 + // | | + // | |..NC..| + // | | | + // |.NC. |.NC. |NC'|NC" + // pc=0-+-----+-----+---+--+ + // KC| | | | | + // | 1 | 3 | 5 | + // pc=KC-+-----+-----+---st-+ + // KC| | | | | + // | 2 | 4 | 6 | 7| + // pc=k=2KC-+-----+-----+---+--+ + // |jc=0 |jc=NC|jc=2NC| + // + // The numbers 1,2..6,7 denotes the order in which reordered + // KCxNC blocks are stored in memory, ie: block 1 followed by 2 + // followed by 3, etc. Given two threads t1 and t2, and t2 needs + // to acces point st in the reorder buffer to write the data: + // The offset calulation logic will be: + // jc_cur_loop = 2NC, jc_cur_loop_rem = NC', pc = KC, + // n_sub_updated = NC, k = 2KC, kc0_updated = KC + // + // st = ( jc_cur_loop * k ) + // + ( n_sub_updated * pc ) + // + ( NC' * kc0_updated) + float* p_temp = reorder_buf_addr_b + ( jc_cur_loop * k ) + + ( n_sub_updated * pc ) + ( jc_cur_loop_rem * kc0 ); + + dim_t jr, it; + // Iterate over every logical micropanel in the source matrix. 
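+ // Each iteration packs one NR-wide micropanel (narrower at the right edge,
+ // where panel_dim_i < NR) of the current KCxNC block into the row-stored
+ // column-panel layout selected above; successive panels are laid out back
+ // to back with stride ps_p = kc0 * NR.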
+ for ( jr = 0, it = 0; it < n_iter; jr += NR, it += 1 ) + { + dim_t panel_dim_i = bli_min( NR, nc0 - jr ); + + const float* b_use = b_temp + ( jr * cs_b ); + float* p_use = p_temp; + + PASTEMAC(s,packm_cxk) + ( + conjc, + schema, + panel_dim_i, + NR, + kc0, + kc0, + kappa_cast, + ( float* )b_use, cs_b, rs_b, + p_use, rs_p, + cntx + ); + + p_temp += ps_p; + } + } + + adjust_B_panel_reordered_jc( &jc, jc_cur_loop ); + } + } +} diff --git a/addon/aocl_gemm/aocl_gemm_f32f32f32of32_utils.h b/addon/aocl_gemm/aocl_gemm_f32f32f32of32_utils.h new file mode 100644 index 0000000000..819e087691 --- /dev/null +++ b/addon/aocl_gemm/aocl_gemm_f32f32f32of32_utils.h @@ -0,0 +1,59 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef AOCL_GEMM_F32F32F32OF32_UTILS_H +#define AOCL_GEMM_F32F32F32OF32_UTILS_H + +// Returns the size of buffer in bytes required for the reordered matrix. +BLIS_EXPORT_ADDON siz_t aocl_get_reorder_buf_size_f32f32f32of32 + ( + const char mat_type, + const dim_t k, + const dim_t n + ); + +// Performs reordering of input matrix. Reordering is the process of packing +// the entire matrix upfront, so that the benefits of packed matrix is obtained +// without incurring the packing costs during matmul computation. +BLIS_EXPORT_ADDON void aocl_reorder_f32f32f32of32 + ( + const char mat_type, + const float* input_buf_addr_b, + float* reorder_buf_addr_b, + const dim_t k, + const dim_t n, + const dim_t ldb + ); + +#endif //AOCL_GEMM_F32F32F32OF32_UTILS_H diff --git a/addon/aocl_gemm/aocl_gemm_u8s8s32os32.c b/addon/aocl_gemm/aocl_gemm_u8s8s32os32.c new file mode 100644 index 0000000000..e1131a3992 --- /dev/null +++ b/addon/aocl_gemm/aocl_gemm_u8s8s32os32.c @@ -0,0 +1,164 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. 
+ + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "aocl_gemm_u8s8s32os32.h" +#include "lpgemm_types.h" +#include "lpgemm_thread_decor_openmp.h" +#include "lpgemm_u8s8s32.h" +#include "lpgemm_config.h" +#include "lpgemm_utils.h" + +void aocl_gemm_u8s8s32os32 + ( + const char transa, + const char transb, + const dim_t m, + const dim_t n, + const dim_t k, + const int32_t alpha, + const uint8_t* a, + const dim_t lda, + const char mem_format_a, + const int8_t* b, + const dim_t ldb, + const char mem_format_b, + const int32_t beta, + int32_t* c, + const dim_t ldc + ) +{ + trans_t blis_transa; + trans_t blis_transb; + + // Check if avx512_vnni ISA is supported, lpgemm matmul only works with it. + if ( bli_cpuid_is_avx512vnni_supported() == FALSE ) + { + printf(" AVX512_VNNI ISA not supported by processor, cannot perform lpgemm.\n"); + return; // Error. + } + + /* Initialize BLIS. */ + bli_init_auto(); + + // Set MC, NC, KC, NR, MR. + aocl_lpgemm_init_global_cntx(); + + // Null check for pointers. + if ( ( a == NULL ) || ( b == NULL ) || ( c == NULL ) ) + { + return; // Error. + } + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + bli_param_map_netlib_to_blis_trans( transa, &blis_transa ); + bli_param_map_netlib_to_blis_trans( transb, &blis_transb ); + + /* Perform BLAS parameter checking. */ + // Transpose not supported. + if ( ( blis_transa != BLIS_NO_TRANSPOSE ) || + ( blis_transb != BLIS_NO_TRANSPOSE ) ) + { + return; // Error. + } + + // Row major input expected with leading dimensions equal to row stride. + if ( ( lda != k ) || ( ldb != n ) || ( ldc != n ) ) + { + return; // Error. + } + + // Check if dimensions are valid. + if ( ( m <= 0) || ( n <= 0 ) || ( k <= 0 ) || + ( lda <= 0 ) || ( ldb <= 0 ) || ( ldc <= 0 ) ) + { + return; // Error. 
+ } + + const inc_t rs_a = lda; + const inc_t cs_a = 1; + const inc_t rs_b = ldb; + const inc_t cs_b = 1; + const inc_t rs_c = ldc; + + AOCL_MEMORY_TAG mtag_a; + AOCL_MEMORY_TAG mtag_b; + + bli_param_map_char_to_lpmtag( mem_format_a, &mtag_a ); + bli_param_map_char_to_lpmtag( mem_format_b, &mtag_b ); + + // B matrix needs to be packed in a certain format in order to be loaded + // and used in VNNI instrution. As such the mtag_b always needs to be either + // packed or reordered. B matrix as it is (unpacked) cannot be used, and + // the mtag_b is set to packed to enable runtime packing. + if ( mtag_b == UNPACKED ) + { + mtag_b = PACK; + } + + // Only unpacked A supported now. + if ( mtag_a != UNPACKED ) + { + return; // Error. + } + + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. + rntm_t rntm_g; + bli_rntm_init_from_global( &rntm_g ); + bli_membrk_rntm_set_membrk( &rntm_g ); + +#ifdef BLIS_ENABLE_OPENMP + lpgemm_u8s8s32o32_openmp_thread_decorator + ( + m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + c, rs_c, + alpha, beta, + &rntm_g + ); +#else + lpgemm_rowvar_u8s8s32o32 + ( + m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + c, rs_c, + alpha, beta, + &rntm_g, + NULL + ); +#endif +} diff --git a/addon/aocl_gemm/aocl_gemm_u8s8s32os32.h b/addon/aocl_gemm/aocl_gemm_u8s8s32os32.h new file mode 100644 index 0000000000..90be30a24b --- /dev/null +++ b/addon/aocl_gemm/aocl_gemm_u8s8s32os32.h @@ -0,0 +1,59 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef AOCL_GEMM_U8S8S32OS32_H +#define AOCL_GEMM_U8S8S32OS32_H + +// Only supports matrices in row major format. Currenlty only mem_format_b +// is configurable to reorder. 
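+// Constraints enforced by the implementation: only the no-transpose case is
+// supported, lda must equal k, ldb must equal n and ldc must equal n (row
+// major storage with unit column stride), and the A matrix must be unpacked.
+// B may be passed unpacked (it is then packed at run time) or pre-reordered
+// with aocl_reorder_u8s8s32os32(). AVX512_VNNI support is required at run
+// time.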
+BLIS_EXPORT_ADDON void aocl_gemm_u8s8s32os32 + ( + const char transa, + const char transb, + const dim_t m, + const dim_t n, + const dim_t k, + const int32_t alpha, + const uint8_t* a, + const dim_t lda, + const char mem_format_a, + const int8_t* b, + const dim_t ldb, + const char mem_format_b, + const int32_t beta, + int32_t* c, + const dim_t ldc + ); + +#endif //AOCL_GEMM_U8S8S32OS32_H diff --git a/addon/aocl_gemm/aocl_gemm_u8s8s32os32_utils.c b/addon/aocl_gemm/aocl_gemm_u8s8s32os32_utils.c new file mode 100644 index 0000000000..31a56ef577 --- /dev/null +++ b/addon/aocl_gemm/aocl_gemm_u8s8s32os32_utils.c @@ -0,0 +1,139 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "aocl_gemm_u8s8s32os32_utils.h" +#include "lpgemm_types.h" +#include "lpgemm_config.h" +#include "lpgemm_utils.h" +#include "lpgemm_reorder.h" + +siz_t aocl_get_reorder_buf_size_u8s8s32os32 + ( + const char mat_type, + const dim_t k, + const dim_t n + ) +{ + if ( ( k <= 0 ) || ( n <= 0 ) ) + { + return 0; // Error. + } + + // Check if avx512_vnni ISA is supported, lpgemm matmul only works with it. + if ( bli_cpuid_is_avx512vnni_supported() == FALSE ) + { + printf(" AVX512_VNNI ISA not supported by processor, cannot perform lpgemm.\n"); + return 0; // Error. + } + + /* Initialize BLIS. */ + bli_init_auto(); + + // Set MC, NC, KC, NR, MR. + aocl_lpgemm_init_global_cntx(); + + AOCL_MATRIX_TYPE input_mat_type; + bli_param_map_char_to_lpmat_type( mat_type, &input_mat_type ); + + if ( input_mat_type == A_MATRIX ) + { + return 0; // A reorder not supported. + } + + // Extra space since packing does width in multiples of 16. 
The vnni + // instruction can be used as long as atleast one zmm register can be fully + // loaded; and since k_dim needs to be atleast 4, having n_dim atleast 16 + // should give 4x16=64 elements, enough for 1 zmm register.The padding is + // not rounded to NR (=64), since that would result in memory wastage. + dim_t n_reorder = make_multiple_of_n( n, 16 ); + + // Extra space since packing does length in multiples of 4. + dim_t k_reorder = make_multiple_of_n( k, 4 ); + + siz_t size_req = sizeof( int8_t ) * k_reorder * n_reorder; + + return size_req; +} + +void aocl_reorder_u8s8s32os32 + ( + const char mat_type, + const int8_t* input_buf_addr, + int8_t* reorder_buf_addr, + const dim_t k, + const dim_t n, + const dim_t ldb + ) +{ + if ( ( input_buf_addr == NULL ) || ( reorder_buf_addr == NULL ) || + ( k <= 0 ) || ( n <= 0 ) || ( ldb < n ) ) + { + return; // Error. + } + + // Check if avx512_vnni ISA is supported, lpgemm matmul only works with it. + if ( bli_cpuid_is_avx512vnni_supported() == FALSE ) + { + printf(" AVX512_VNNI ISA not supported by processor, cannot perform lpgemm.\n"); + return; // Error. + } + + /* Initialize BLIS. */ + bli_init_auto(); + + // Set MC, NC, KC, NR, MR. + aocl_lpgemm_init_global_cntx(); + + AOCL_MATRIX_TYPE input_mat_type; + bli_param_map_char_to_lpmat_type( mat_type, &input_mat_type ); + + if ( input_mat_type == A_MATRIX ) + { + return; // A reorder not supported. + } + + // Create dummy b_reorder obj. + lpgemm_obj_t b_reorder; + b_reorder.storage.aligned_buffer = reorder_buf_addr; + + // Create dummy original b obj; + lpgemm_obj_t b; + b.storage.aligned_buffer = ( void* )input_buf_addr; + b.rs = ldb; + b.width = n; + b.length = k; + + reorderb_nr64_u8s8s32o32( &b, &b_reorder ); +} diff --git a/addon/aocl_gemm/aocl_gemm_u8s8s32os32_utils.h b/addon/aocl_gemm/aocl_gemm_u8s8s32os32_utils.h new file mode 100644 index 0000000000..d23660a1ec --- /dev/null +++ b/addon/aocl_gemm/aocl_gemm_u8s8s32os32_utils.h @@ -0,0 +1,55 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef AOCL_GEMM_U8S8S32OS32_UTILS_H +#define AOCL_GEMM_U8S8S32OS32_UTILS_H + +BLIS_EXPORT_ADDON siz_t aocl_get_reorder_buf_size_u8s8s32os32 + ( + const char mat_type, + const dim_t k, + const dim_t n + ); + +BLIS_EXPORT_ADDON void aocl_reorder_u8s8s32os32 + ( + const char mat_type, + const int8_t* input_buf_addr, + int8_t* reorder_buf_addr, + const dim_t k, + const dim_t n, + const dim_t ldb + ); + +#endif //AOCL_GEMM_U8S8S32OS32_UTILS_H diff --git a/addon/aocl_gemm/frame/f32f32f32/lpgemm_f32f32f32.c b/addon/aocl_gemm/frame/f32f32f32/lpgemm_f32f32f32.c new file mode 100644 index 0000000000..f9bb67803d --- /dev/null +++ b/addon/aocl_gemm/frame/f32f32f32/lpgemm_f32f32f32.c @@ -0,0 +1,319 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +*/ + +#include "blis.h" +#include "lpgemm_f32f32f32.h" +#include "lpgemm_types.h" +#include "lpgemm_utils.h" +#include "lpgemm_thrinfo_utils.h" + +void lpgemm_pack_a_f32f32f32of32 + ( + const float* input_buf_addr_a, + float* reorder_buf_addr_a, + const dim_t m, + const dim_t k, + const dim_t rs_a, + const dim_t cs_a, + const dim_t ps_p, + const dim_t MR, + cntx_t* cntx + ); + +void lpgemm_rowvar_f32f32f32of32 + ( + const dim_t m, + const dim_t n, + const dim_t k, + const float* a, + const dim_t rs_a, + const dim_t cs_a, + const AOCL_MEMORY_TAG mtag_a, + const float* b, + const dim_t rs_b, + const dim_t cs_b, + const AOCL_MEMORY_TAG mtag_b, + float* c, + const dim_t rs_c, + float alpha, + float beta, + rntm_t* rntm, + lpgemm_thrinfo_t* thread + ) +{ + // Query the global cntx. + cntx_t* cntx = bli_gks_query_cntx(); + + num_t dt = BLIS_FLOAT; + + // Query the context for various blocksizes. + const dim_t NR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NR, cntx ); + const dim_t MR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_MR, cntx ); + const dim_t NC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NC, cntx ); + const dim_t MC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_MC, cntx ); + const dim_t KC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_KC, cntx ); + + // Strides are updated based on matrix packing/reordering. + const float* a_use = NULL; + dim_t rs_a_use = rs_a; + dim_t cs_a_use = cs_a; + + const float* b_use = NULL; + dim_t rs_b_use = rs_b; + dim_t cs_b_use = cs_b; + + float* c_use_jc = NULL; + float* c_use_ic = NULL; + + // Only supporting row major with unit column strided C for now. + const dim_t cs_c = 1; + + /* Compute partitioning step values for each matrix of each loop. */ + inc_t ps_a_use; + inc_t ps_b_use; + auxinfo_t aux; + + // Check if packing of A is required. + bool should_pack_A = bli_rntm_pack_a( rntm ); + + // Pack buffer for A. + float* pack_a_buffer_f32f32f32of32; + mem_t mem_a = BLIS_MEM_INITIALIZER; + siz_t mem_a_size_req = 0; + + float one_local = *PASTEMAC(s,1); + + trans_t transc = BLIS_NO_TRANSPOSE; + conj_t conjc = bli_extract_conj( transc ); + + // Generate thrinfo objects for jc and ic loops from lpgemm_thrinfo_t. + thrinfo_t thread_jc; + thrinfo_t thread_ic; + + lpgemm_gen_thrinfo( thread, &thread_jc, &thread_ic ); + + // Compute the JC loop thread range for the current thread. + dim_t jc_start, jc_end; + bli_thread_range_sub( &thread_jc, n, NR, FALSE, &jc_start, &jc_end ); + + for ( dim_t jc = jc_start; jc < jc_end; jc += NC ) + { + dim_t nc0 = bli_min( ( jc_end - jc ), NC ); + c_use_jc = c + jc; + + dim_t jc_cur_loop = jc; + dim_t jc_cur_loop_rem = 0; + dim_t n_sub_updated; + + if ( mtag_b == REORDERED ) + { + get_B_panel_reordered_start_offset_width + ( + jc, n, NC, NR, + &jc_cur_loop, &jc_cur_loop_rem, + &nc0, &n_sub_updated + ); + } + + for ( dim_t pc = 0; pc < k; pc += KC ) + { + float beta0 = ( pc == 0 ) ? beta : one_local; + dim_t kc0 = bli_min( ( k - pc ), KC ); + + if ( mtag_b == REORDERED ) + { + // In multi-threaded scenarios, an extra offset into a given + // packed B panel is required, since the jc loop split can + // result in per thread start offset inside the panel, instead + // of panel boundaries. 
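			// Gloss on the offset arithmetic below (descriptive only): the
			// three terms mirror the reordered-B layout documented (for the
			// int8 path) in lpgemm_reorder.c later in this patch.
			// ( jc_cur_loop * k ) skips the NC-wide panel groups that precede
			// the current one across the full k depth, ( n_sub_updated * pc )
			// skips the KC slabs already consumed inside the current group
			// (n_sub_updated being that group's rounded width), and
			// ( jc_cur_loop_rem * kc0 ) skips the columns of the current slab
			// that precede this thread's jc range.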
+ b_use = b + ( jc_cur_loop * k ) + + ( n_sub_updated * pc ) + ( jc_cur_loop_rem * kc0 ); + + rs_b_use = NR; + cs_b_use = 1; + ps_b_use = kc0; + } + else + { + b_use = b + ( pc * rs_b ) + ( jc * cs_b ); + ps_b_use = 1; + } + + dim_t ic_start, ic_end; + bli_thread_range_sub( &thread_ic, m, MR, FALSE, &ic_start, &ic_end ); + + for ( dim_t ic = ic_start; ic < ic_end; ic += MC ) + { + dim_t mc0 = bli_min( ( ic_end - ic ), MC ); + c_use_ic = c_use_jc + ( rs_c * ic ); + + if ( mtag_a == REORDERED ) + { + // Extra space since packing does width in multiples of MR. + const dim_t m_updated = ( ( m + MR - 1 ) / MR ) * MR; + a_use = a + ( pc * m_updated ) + ( kc0 * ic ); + + rs_a_use = 1; + cs_a_use = MR; + ps_a_use = MR * kc0; + } + else if ( should_pack_A == TRUE ) + { + // Extra space since packing does width in multiples of MR. + const dim_t mc0_updated = ( ( mc0 + MR - 1 ) / MR ) * MR; + mem_a_size_req = sizeof( float ) * mc0_updated * kc0; + + lpgemm_alloc_mem_panel + ( + mem_a_size_req, BLIS_BUFFER_FOR_A_BLOCK, + &mem_a, rntm + ); + pack_a_buffer_f32f32f32of32 = ( float* )bli_mem_buffer( &mem_a ); + + rs_a_use = 1; + cs_a_use = MR; + ps_a_use = MR * kc0; + + lpgemm_pack_a_f32f32f32of32 + ( + ( a + ( rs_a * ic ) + pc ), + pack_a_buffer_f32f32f32of32, + mc0, kc0, + rs_a, cs_a, ps_a_use, MR, + cntx + ); + + a_use = pack_a_buffer_f32f32f32of32; + } + else + { + a_use = a + ( rs_a * ic ) + pc; + ps_a_use = MR * rs_a; + } + + // Embed the panel stride of A within the auxinfo_t object. The + // millikernel will query and use this to iterate through + // micropanels of A (if needed). + bli_auxinfo_set_ps_a( ps_a_use, &aux ); + + for ( dim_t jr = 0; jr < nc0; jr += NR ) + { + dim_t nr0 = bli_min( ( nc0 - jr ), NR ); + + // Reordered/unpacked B, reordered/unpacked A. + bli_sgemmsup_rv_zen_asm_6x16m + ( + conjc, + conjc, + mc0, nr0, kc0, + &alpha, + ( float* )a_use, rs_a_use, cs_a_use, + ( float* )( b_use + ( jr * ps_b_use ) ), rs_b_use, cs_b_use, + &beta0, + ( c_use_ic + jr ), rs_c, cs_c, + &aux, cntx + ); + } + } + } + if ( mtag_b == REORDERED ) + { + adjust_B_panel_reordered_jc( &jc, jc_cur_loop ); + } + } + + // Release pack buffers. + if ( should_pack_A == TRUE ) + { + if ( bli_mem_is_alloc( &mem_a ) ) + { + bli_membrk_release( rntm, &mem_a ); + } + } +} + +void lpgemm_pack_a_f32f32f32of32 + ( + const float* input_buf_addr_a, + float* reorder_buf_addr_a, + const dim_t m, + const dim_t k, + const dim_t rs_a, + const dim_t cs_a, + const dim_t ps_p, + const dim_t MR, + cntx_t* cntx + ) +{ + float one_local = *PASTEMAC(s,1); + float* restrict kappa_cast = &one_local; + + // Set the schema to "column stored row panels" to indicate packing to conventional + // column-stored row panels. + pack_t schema = BLIS_PACKED_ROW_PANELS; + trans_t transc = BLIS_NO_TRANSPOSE; + conj_t conjc = bli_extract_conj( transc ); + + // Compute the total number of iterations we'll need. + dim_t m_iter = ( m + MR - 1 ) / MR; + + inc_t cs_p = MR; + + float* p_temp = reorder_buf_addr_a; + dim_t ir, it; + // Iterate over every logical micropanel in the source matrix. 
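	// A worked illustration with hypothetical sizes (not taken from the code
	// above): for m = 14 and MR = 6, m_iter = ( 14 + 6 - 1 ) / 6 = 3, and the
	// loop below packs micropanels of panel_dim_i = 6, 6 and 2 rows. p_temp
	// advances by the full ps_p (= MR * kc0, as passed by the caller) after
	// every panel, so the trailing 2-row panel still occupies a full MR-high
	// slot in the packed buffer.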
+ for ( ir = 0, it = 0; it < m_iter; ir += MR, it += 1 ) + { + dim_t panel_dim_i = bli_min( MR, m - ir ); + + const float* a_use = input_buf_addr_a + ( ir * rs_a ); + float* p_use = p_temp; + + PASTEMAC(s,packm_cxk) + ( + conjc, + schema, + panel_dim_i, + MR, + k, + k, + kappa_cast, + ( float* )a_use, rs_a, cs_a, + p_use, cs_p, + cntx + ); + + p_temp += ps_p; + } +} diff --git a/addon/aocl_gemm/frame/f32f32f32/lpgemm_f32f32f32.h b/addon/aocl_gemm/frame/f32f32f32/lpgemm_f32f32f32.h new file mode 100644 index 0000000000..03c1146bf0 --- /dev/null +++ b/addon/aocl_gemm/frame/f32f32f32/lpgemm_f32f32f32.h @@ -0,0 +1,61 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef LPGEMM_F32F32F32_H +#define LPGEMM_F32F32F32_H + +#include "lpgemm_types.h" + +void lpgemm_rowvar_f32f32f32of32 + ( + const dim_t m, + const dim_t n, + const dim_t k, + const float* a, + const dim_t rs_a, + const dim_t cs_a, + const AOCL_MEMORY_TAG mtag_a, + const float* b, + const dim_t rs_b, + const dim_t cs_b, + const AOCL_MEMORY_TAG mtag_b, + float* c, + const dim_t rs_c, + float alpha, + float beta, + rntm_t* rntm, + lpgemm_thrinfo_t* thread + ); + +#endif //LPGEMM_F32F32F32_H diff --git a/addon/aocl_gemm/frame/lpgemm_config.c b/addon/aocl_gemm/frame/lpgemm_config.c new file mode 100644 index 0000000000..a16147e4c1 --- /dev/null +++ b/addon/aocl_gemm/frame/lpgemm_config.c @@ -0,0 +1,88 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. 
+ - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "lpgemm_config.h" + +lpgemm_cntx_t global_cntx_t_list[3]; //Only one op type supported now. + +BLIS_INLINE void lpgemm_set_block_sizes_global_cntx + ( + AOCL_OPERATION_TYPE op_type, + dim_t MC, + dim_t NC, + dim_t KC, + dim_t NR, + dim_t MR + ) +{ + global_cntx_t_list[op_type].blksz.MC = MC; + global_cntx_t_list[op_type].blksz.NC = NC; + global_cntx_t_list[op_type].blksz.KC = KC; + global_cntx_t_list[op_type].blksz.NR = NR; + global_cntx_t_list[op_type].blksz.MR = MR; +} + +// Sets default block sizes for lpgemm. Currently only u8s8s32 supported. +// Thread safety is not considered now since the block sizes are not expected +// to be configurable from application. +void aocl_lpgemm_init_global_cntx() +{ + lpgemm_set_block_sizes_global_cntx( U8S8S32OS32, 144, 1024, 2048, 64, 6 ); +} + +dim_t lpgemm_get_block_size_MC_global_cntx( AOCL_OPERATION_TYPE op_type ) +{ + return global_cntx_t_list[op_type].blksz.MC; +} + +dim_t lpgemm_get_block_size_NC_global_cntx( AOCL_OPERATION_TYPE op_type ) +{ + return global_cntx_t_list[op_type].blksz.NC; +} + +dim_t lpgemm_get_block_size_KC_global_cntx( AOCL_OPERATION_TYPE op_type ) +{ + return global_cntx_t_list[op_type].blksz.KC; +} + +dim_t lpgemm_get_block_size_NR_global_cntx( AOCL_OPERATION_TYPE op_type ) +{ + return global_cntx_t_list[op_type].blksz.NR; +} + +dim_t lpgemm_get_block_size_MR_global_cntx( AOCL_OPERATION_TYPE op_type ) +{ + return global_cntx_t_list[op_type].blksz.MR; +} diff --git a/addon/aocl_gemm/frame/lpgemm_config.h b/addon/aocl_gemm/frame/lpgemm_config.h new file mode 100644 index 0000000000..8e25986f6d --- /dev/null +++ b/addon/aocl_gemm/frame/lpgemm_config.h @@ -0,0 +1,54 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. 
+ - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef LPGEMM_CONFIG_H +#define LPGEMM_CONFIG_H + +#include "lpgemm_types.h" + +extern lpgemm_cntx_t lpgemm_global_cntx_t_list[3]; // equals to number of ops in enum AOCL_OPERATION_TYPE. + +void aocl_lpgemm_init_global_cntx(); + +dim_t lpgemm_get_block_size_MC_global_cntx( AOCL_OPERATION_TYPE op_type ); + +dim_t lpgemm_get_block_size_NC_global_cntx( AOCL_OPERATION_TYPE op_type ); + +dim_t lpgemm_get_block_size_KC_global_cntx( AOCL_OPERATION_TYPE op_type ); + +dim_t lpgemm_get_block_size_NR_global_cntx( AOCL_OPERATION_TYPE op_type ); + +dim_t lpgemm_get_block_size_MR_global_cntx( AOCL_OPERATION_TYPE op_type ); + +#endif //LPGEMM_CONFIG_H diff --git a/addon/aocl_gemm/frame/lpgemm_types.h b/addon/aocl_gemm/frame/lpgemm_types.h new file mode 100644 index 0000000000..2d9cca79a4 --- /dev/null +++ b/addon/aocl_gemm/frame/lpgemm_types.h @@ -0,0 +1,116 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +*/ + +#ifndef LPGEMM_TYPES_H +#define LPGEMM_TYPES_H + +typedef enum +{ + INT8 = 0, + INT16 = 1, + INT32 = 2 +} AOCL_ARRAY_TYPE; + +// Enum name template:A_mat_type ## B_mat_type ## Accumulate_type ## C_mat_type. +typedef enum +{ + U8S8S16OS16 = 0, // uint8_t - A, int8_t - B, int16_t - C + U8S8S32OS32 = 1, // uint8_t - A, int8_t - B, int32_t - C + F16F16F16OF16 = 2 // float16 - A, float16 - B, float16 - C +} AOCL_OPERATION_TYPE; + +typedef enum +{ + UNPACKED = 0, + PACK = 1, + REORDERED = 2, +} AOCL_MEMORY_TAG; + +typedef enum +{ + ROW_MAJOR = 0, + COLUMN_MAJOR = 1, +} AOCL_STOR_TAG; + +typedef enum +{ + A_MATRIX = 0, + B_MATRIX = 1, +} AOCL_MATRIX_TYPE; + +typedef struct +{ + void* aligned_buffer; + void* origin_buffer; +} lpgemm_mem_t; + +typedef struct +{ + dim_t length; + dim_t width; + + dim_t elem_size; + + dim_t rs; + dim_t cs; + + AOCL_MEMORY_TAG mtag; + + lpgemm_mem_t storage; +} lpgemm_obj_t; + +typedef struct +{ + dim_t MC; + dim_t NC; + dim_t KC; + dim_t NR; + dim_t MR; +} lpgemm_block_size_t; + +typedef struct +{ + lpgemm_block_size_t blksz; +} lpgemm_cntx_t; + +typedef struct +{ + dim_t n_threads; + dim_t tid; + dim_t ic_ways; + dim_t jc_ways; + thrcomm_t* comm; +} lpgemm_thrinfo_t; + +#endif //LPGEMM_TYPES_H diff --git a/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.c b/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.c new file mode 100644 index 0000000000..59641f5dae --- /dev/null +++ b/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.c @@ -0,0 +1,447 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +*/ + +#include "blis.h" +#include "lpgemm_config.h" +#include "lpgemm_thread_decor_openmp.h" +#include "lpgemm_types.h" +#include "lpgemm_u8s8s32.h" +#include "lpgemm_f32f32f32.h" + +#ifdef BLIS_ENABLE_OPENMP + +#define BLIS_LPGEMM_NUM_STATIC_COMMS 96 + +BLIS_INLINE dim_t next_factor + ( + const dim_t nt, + const dim_t part_nt + ) +{ + if ( part_nt == nt ) + { + return part_nt; + } + + dim_t nt_temp = part_nt + 1; + while ( ( nt_temp <= nt ) && ( ( nt % nt_temp ) != 0 ) ) + { + nt_temp++; + } + return nt_temp; +} + +BLIS_INLINE dim_t prev_factor + ( + const dim_t nt, + const dim_t part_nt + ) +{ + if ( part_nt == 1 ) + { + return part_nt; + } + + dim_t nt_temp = part_nt - 1; + while ( ( nt_temp >= 1 ) && ( ( nt % nt_temp ) != 0 ) ) + { + nt_temp--; + } + return nt_temp; +} + +BLIS_INLINE void lpgemm_pnl_wrk_heur_adjust_ic_jc_ways + ( + dim_t MR, + dim_t NR, + dim_t m, + dim_t n, + dim_t* n_threads, + dim_t* ic_ways, + dim_t* jc_ways + ) +{ + // This function currently only increments ic and subsequently decrements + // jc. Cannot proceed if all threads are allocated to ic. + // The factorization adjustment here is based on improving the B NR panel + // distribution among the jc threads. + dim_t mu = ( m + MR - 1 ) / MR; + dim_t nu = ( n + NR - 1 ) / NR; + + // The next 3 ic factors will be considered to see if it results in better + // NR panel distribution and subsequently reduce the per thread panel work. + dim_t nu_mod_jc_ways = nu % ( *jc_ways ); + if ( ( nu_mod_jc_ways != 0 ) && ( ( *ic_ways ) < ( *n_threads ) ) ) + { + dim_t mu_ic_cur = ( mu + ( *ic_ways ) - 1 ) / ( *ic_ways ); + dim_t nu_jc_cur = ( nu + ( *jc_ways ) - 1 ) / ( *jc_ways ); + dim_t panel_work_cur = mu_ic_cur + nu_jc_cur; + + const dim_t next_ic = next_factor( ( *n_threads ), ( *ic_ways ) ); + const dim_t prev_jc = prev_factor( ( *n_threads ), ( *jc_ways ) ); + dim_t mu_ic_next = ( mu + next_ic - 1 ) / next_ic; + dim_t nu_jc_prev = ( nu + prev_jc - 1 ) / prev_jc; + dim_t panel_work_next = mu_ic_next + nu_jc_prev; + + if ( panel_work_next < panel_work_cur ) + { + panel_work_cur = panel_work_next; + ( *ic_ways ) = next_ic; + ( *jc_ways ) = prev_jc; + } + + nu_mod_jc_ways = nu % ( *jc_ways ); + if ( ( nu_mod_jc_ways != 0 ) && ( next_ic < ( *n_threads ) ) ) + { + const dim_t next_next_ic = next_factor( ( *n_threads ), next_ic ); + const dim_t prev_prev_jc = prev_factor( ( *n_threads ), prev_jc ); + dim_t mu_ic_next_next = ( mu + next_next_ic - 1 ) / next_next_ic; + dim_t nu_jc_prev_prev = ( nu + prev_prev_jc - 1 ) / prev_prev_jc; + dim_t panel_work_next_next = mu_ic_next_next + nu_jc_prev_prev; + + if ( panel_work_next_next < panel_work_cur ) + { + panel_work_cur = panel_work_next_next; + ( *ic_ways ) = next_next_ic; + ( *jc_ways ) = prev_prev_jc; + } + + nu_mod_jc_ways = nu % ( *jc_ways ); + if ( ( nu_mod_jc_ways != 0 ) && ( next_next_ic < ( *n_threads ) ) ) + { + const dim_t next_next_next_ic = + next_factor + ( + ( *n_threads ), next_next_ic + ); + const dim_t prev_prev_prev_jc = + prev_factor + ( + ( *n_threads ), prev_prev_jc + ); + dim_t mu_ic_next_next_next = + ( mu + next_next_next_ic - 1 ) / next_next_next_ic; + dim_t nu_jc_prev_prev_prev = + ( nu + prev_prev_prev_jc - 1 ) / prev_prev_prev_jc; + dim_t panel_work_next_next_next = + mu_ic_next_next_next + nu_jc_prev_prev_prev; + + if ( panel_work_next_next_next < panel_work_cur ) + { + ( *ic_ways ) = next_next_next_ic; + ( *jc_ways ) = prev_prev_prev_jc; + } + } + } + } +} + +BLIS_INLINE void lpgemm_adjust_ic_jc_ways + ( + dim_t m, + dim_t n, + dim_t* n_threads, + 
dim_t* ic_ways, + dim_t* jc_ways + ) +{ + const dim_t m_ic = m / ( *ic_ways ); + const dim_t n_jc = n / ( *jc_ways ); + const int64_t cur_work_per_thread = m_ic + n_jc; + + const dim_t next_ic = next_factor( ( *n_threads ), ( *ic_ways ) ); + const dim_t prev_ic = prev_factor( ( *n_threads ), ( *ic_ways ) ); + const dim_t next_jc = next_factor( ( *n_threads ), ( *jc_ways ) ); + const dim_t prev_jc = prev_factor( ( *n_threads ), ( *jc_ways ) ); + + const dim_t m_next_ic = m / next_ic; + const dim_t m_prev_ic = m / prev_ic; + const dim_t n_next_jc = n / next_jc; + const dim_t n_prev_jc = n / prev_jc; + + const int64_t next_jc_work_per_thread = n_next_jc + m_prev_ic; + const int64_t next_ic_work_per_thread = m_next_ic + n_prev_jc; + + bool can_increase_ic = FALSE; + bool can_increase_jc = FALSE; + + if ( next_ic_work_per_thread <= cur_work_per_thread ) + { + can_increase_ic = TRUE; + } + else if ( next_jc_work_per_thread < cur_work_per_thread ) + { + can_increase_jc = TRUE; + } + + if ( can_increase_ic ) + { + ( *ic_ways ) = next_ic; + ( *jc_ways ) = prev_jc; + } + else if ( can_increase_jc ) + { + // Giving priority to ic and m dimensions, if m >= n, jc must be < ic. + if ( ( ( m >= n ) && ( prev_ic >= next_jc ) ) || + ( ( m < n ) && ( prev_ic <= next_jc ) ) ) + { + ( *ic_ways ) = prev_ic; + ( *jc_ways ) = next_jc; + } + } +} + +BLIS_INLINE void lpgemm_u8s8s32o32_get_threading + ( + dim_t* n_threads, + dim_t* ic_ways, + dim_t* jc_ways, + dim_t m, + dim_t n, + dim_t k, + rntm_t* rntm_g + ) +{ + *n_threads = bli_rntm_num_threads( rntm_g ); + *jc_ways = bli_rntm_jc_ways( rntm_g ); + *ic_ways = bli_rntm_ic_ways( rntm_g ); + + if ( ( ( *ic_ways ) > 0 ) || ( ( *jc_ways ) > 0 ) ) + { + // If BLIS_IC_NT or JC_NT are set. + // Default cases. + *ic_ways = ( ( *ic_ways ) > 0 ) ? ( *ic_ways ) : 1; + *jc_ways = ( ( *jc_ways ) > 0 ) ? ( *jc_ways ) : 1; + + *n_threads = ( *jc_ways ) * ( *ic_ways ); + } + else if ( ( *n_threads ) > 1 ) + { + + dim_t NR = lpgemm_get_block_size_NR_global_cntx( U8S8S32OS32 ); + dim_t MR = lpgemm_get_block_size_MR_global_cntx( U8S8S32OS32 ); + + if ( n <= NR ) + { + // If n is less than micro panel dimension, allocating all threads + // to ic resulted in gains. + ( *ic_ways ) = ( *n_threads ); + ( *jc_ways ) = 1; + } + else + { + // If BLIS_NUM_THREADS are set, generate jc,ic from the same. + bli_thread_partition_2x2( ( *n_threads ), m, n, ic_ways, jc_ways ); + + lpgemm_adjust_ic_jc_ways( m, n, n_threads, ic_ways, jc_ways ); + + lpgemm_pnl_wrk_heur_adjust_ic_jc_ways + ( + MR, NR, m, n, + n_threads, ic_ways, jc_ways + ); + } + } + else + { + // Setting all the values to 1 in case n_threads <= 1. This ensures + // the threading parameters are valid. + *n_threads = 1; + *jc_ways = 1; + *ic_ways = 1; + } +} + +// Some aspects of sgemm smart threading incorporated here. Eventually this +// will be redirected to the sgemm smart threading API. +BLIS_INLINE void lpgemm_f32f32f32of32_get_threading + ( + dim_t* n_threads, + dim_t* ic_ways, + dim_t* jc_ways, + dim_t m, + dim_t n, + dim_t k, + rntm_t* rntm_g + ) +{ + // Query the global cntx. + cntx_t* cntx = bli_gks_query_cntx(); + + num_t dt = BLIS_FLOAT; + + // Query the context for SUP limits. 
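	// Worked illustration (hypothetical sizes) of the ic/jc adjustment that
	// both this routine and the u8s8s32 path above apply via
	// lpgemm_adjust_ic_jc_ways: with n_threads = 12, m = 1200, n = 240 and an
	// initial split of ic_ways = 4, jc_ways = 3, the current per-thread work
	// is 1200/4 + 240/3 = 380. Since next_factor( 12, 4 ) = 6 and
	// prev_factor( 12, 3 ) = 2 give 1200/6 + 240/2 = 320 <= 380, the split
	// is adjusted to ic_ways = 6, jc_ways = 2.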
+ const dim_t MT = bli_cntx_get_l3_sup_thresh_dt( dt, BLIS_MT, cntx ); + const dim_t NT = bli_cntx_get_l3_sup_thresh_dt( dt, BLIS_NT, cntx ); + const dim_t KT = bli_cntx_get_l3_sup_thresh_dt( dt, BLIS_KT, cntx ); + + const dim_t MT_2 = MT / 2; + + *n_threads = bli_rntm_num_threads( rntm_g ); + *jc_ways = bli_rntm_jc_ways( rntm_g ); + *ic_ways = bli_rntm_ic_ways( rntm_g ); + + if ( ( ( *ic_ways ) > 0 ) || ( ( *jc_ways ) > 0 ) ) + { + // If BLIS_IC_NT or JC_NT are set. + // Default cases. + *ic_ways = ( ( *ic_ways ) > 0 ) ? ( *ic_ways ) : 1; + *jc_ways = ( ( *jc_ways ) > 0 ) ? ( *jc_ways ) : 1; + + *n_threads = ( *jc_ways ) * ( *ic_ways ); + } + else if ( ( *n_threads ) > 1 ) + { + // If BLIS_NUM_THREADS are set, generate jc,ic from the same. + bli_thread_partition_2x2( ( *n_threads ), m, n, ic_ways, jc_ways ); + + lpgemm_adjust_ic_jc_ways( m, n, n_threads, ic_ways, jc_ways ); + } + else + { + // Setting all the values to 1 in case n_threads <= 1. This ensures + // the threading parameters are valid. + *n_threads = 1; + *jc_ways = 1; + *ic_ways = 1; + } + + // Native -> SUP path. + const dim_t m_ic = m / ( *ic_ways ); + const dim_t n_jc = n / ( *jc_ways ); + const dim_t page_size = bli_info_get_page_size(); + const dim_t page_size_b_floatx2 = + 2 * ( page_size / sizeof( float ) ); + + if ( ( m >= MT ) && ( n >= NT ) && ( k >= KT ) ) + { + if ( ( k > page_size_b_floatx2 ) || + ( ( k <= page_size_b_floatx2 ) && + ( m_ic > MT_2 ) && ( n_jc >= NT ) ) ) + { + bli_rntm_set_pack_a( 1, rntm_g ); + } + } +} + +#define GEN_LPGEMM_OPENMP_DECORATOR(A_type,B_type,C_type,LPGEMM_SFX) \ +void lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \ + ( \ + const dim_t m, \ + const dim_t n, \ + const dim_t k, \ + const A_type* a, \ + const dim_t rs_a, \ + const dim_t cs_a, \ + const AOCL_MEMORY_TAG mtag_a, \ + const B_type* b, \ + const dim_t rs_b, \ + const dim_t cs_b, \ + const AOCL_MEMORY_TAG mtag_b, \ + C_type* c, \ + const dim_t rs_c, \ + C_type alpha, \ + C_type beta, \ + rntm_t* rntm_g \ + ) \ +{ \ + dim_t n_threads; \ + \ + /* Factorization of threads along m and n dimension respectively.*/ \ + dim_t ic_ways; \ + dim_t jc_ways; \ + \ + lpgemm_ ## LPGEMM_SFX ## _get_threading \ + ( \ + &n_threads, \ + &ic_ways, &jc_ways, \ + m, n, k, rntm_g \ + ); \ + \ + /* Set the packing block allocator field of the rntm. This will be + * inherited by all of the child threads when they make local copies of + * the rntm below.*/ \ + bli_membrk_rntm_set_membrk( rntm_g ); \ + \ + thrcomm_t static_lpgemm_comms[BLIS_LPGEMM_NUM_STATIC_COMMS]; \ + thrcomm_t* cur_lpgemm_comms = static_lpgemm_comms; \ + \ + if ( jc_ways > BLIS_LPGEMM_NUM_STATIC_COMMS ) \ + { \ + cur_lpgemm_comms = bli_malloc_intl( jc_ways * sizeof( thrcomm_t ) ); \ + } \ + for ( dim_t i = 0; i < jc_ways; ++i ) \ + { \ + bli_thrcomm_init( ic_ways, &cur_lpgemm_comms[i] ); \ + } \ + \ + _Pragma( "omp parallel num_threads(n_threads)" ) \ + { \ + /* Create a thread-local copy of the master thread's rntm_t. 
This is + * necessary since we want each thread to be able to track its own + * small block pool_t as it executes down the function stack.*/ \ + rntm_t rntm_l = *rntm_g; \ + \ + /* lpgemm_thrinfo_t object will be used to generate thrinfo_t objects + * for use in blis mt framework inside the respective mat mul driver + * functions.*/ \ + lpgemm_thrinfo_t thread; \ + thread.n_threads = n_threads; \ + thread.tid = omp_get_thread_num(); \ + thread.ic_ways = ic_ways; \ + thread.jc_ways = jc_ways; \ + thread.comm = cur_lpgemm_comms; \ + \ + lpgemm_rowvar_ ## LPGEMM_SFX \ + ( \ + m, n, k, \ + a, rs_a, cs_a, mtag_a, \ + b, rs_b, cs_b, mtag_b, \ + c, rs_c, \ + alpha, \ + beta, \ + &rntm_l, \ + &thread \ + ); \ + } \ + if ( jc_ways > BLIS_LPGEMM_NUM_STATIC_COMMS ) \ + { \ + bli_free_intl( cur_lpgemm_comms ); \ + } \ +} \ + +GEN_LPGEMM_OPENMP_DECORATOR(uint8_t,int8_t,int32_t,u8s8s32o32) +GEN_LPGEMM_OPENMP_DECORATOR(float,float,float,f32f32f32of32) + +#endif diff --git a/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.h b/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.h new file mode 100644 index 0000000000..dba1d71ab6 --- /dev/null +++ b/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.h @@ -0,0 +1,68 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +*/ + +#ifndef LPGEMM_THREAD_DECOR_OPENMP_H +#define LPGEMM_THREAD_DECOR_OPENMP_H + +#ifdef BLIS_ENABLE_OPENMP + +#include "lpgemm_types.h" + +#define GEN_LPGEMM_OPENMP_DECORATOR_FN(A_type,B_type,C_type,LPGEMM_SFX) \ +void lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \ + ( \ + const dim_t m, \ + const dim_t n, \ + const dim_t k, \ + const A_type* a, \ + const dim_t rs_a, \ + const dim_t cs_a, \ + const AOCL_MEMORY_TAG mtag_a, \ + const B_type* b, \ + const dim_t rs_b, \ + const dim_t cs_b, \ + const AOCL_MEMORY_TAG mtag_b, \ + C_type* c, \ + const dim_t rs_c, \ + C_type alpha, \ + C_type beta, \ + rntm_t* rntm_g \ + ); \ + +GEN_LPGEMM_OPENMP_DECORATOR_FN(uint8_t,int8_t,int32_t,u8s8s32o32) +GEN_LPGEMM_OPENMP_DECORATOR_FN(float,float,float,f32f32f32of32) + +#endif + +#endif //LPGEMM_THREAD_DECOR_OPENMP_H diff --git a/addon/aocl_gemm/frame/threading/lpgemm_thrinfo_utils.h b/addon/aocl_gemm/frame/threading/lpgemm_thrinfo_utils.h new file mode 100644 index 0000000000..2ac9b505a6 --- /dev/null +++ b/addon/aocl_gemm/frame/threading/lpgemm_thrinfo_utils.h @@ -0,0 +1,78 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef LPGEMM_THRINFO_UTILS_H +#define LPGEMM_THRINFO_UTILS_H + +// Parallelization only supported along jc and ic loops. Thus not reusing the +// existing thrinfo tree logic, since a light-weight work id generation will +// suffice. However the logic used for thread meta data generation, specific +// to jc and ic loops is borrowed. +BLIS_INLINE void lpgemm_gen_thrinfo + ( + lpgemm_thrinfo_t* thread, + thrinfo_t* thread_jc, + thrinfo_t* thread_ic + ) +{ + if ( thread == NULL ) + { + // Set n_ways=1 to ensure ST behaviour when thread is not initialized. + // This is the case when BLIS_ENABLE_OPENMP is not defined. 
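	// (For the multi-threaded branch further below, a worked example with
	// hypothetical values: n_threads = 12, ic_ways = 4, jc_ways = 3 and
	// tid = 7 give jc work id = 7 / 4 = 1 and ic comm id = 7 % 4 = 3, i.e.
	// threads 4..7 all share jc work id 1 and synchronize through the same
	// thrcomm_t entry, comm[1], in the u8s8s32 driver.)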
+ bli_thrinfo_set_ocomm_id( 0, thread_jc ); + bli_thrinfo_set_n_way( 1, thread_jc ); + bli_thrinfo_set_work_id( 0, thread_jc ); + + bli_thrinfo_set_ocomm_id( 0, thread_ic ); + bli_thrinfo_set_n_way( 1, thread_ic ); + bli_thrinfo_set_work_id( 0, thread_ic ); + } + else + { + // Replicate the logic in bli_l3_sup_thrinfo_create_root for jc thrinfo. + bli_thrinfo_set_ocomm_id( thread->tid, thread_jc ); + bli_thrinfo_set_n_way( thread->jc_ways, thread_jc ); + dim_t jc_work_id = thread->tid / thread->ic_ways; + bli_thrinfo_set_work_id( jc_work_id, thread_jc ); + + // Replicate the sub node creation logic in bli_thrinfo_sup_create_for_cntl + // for ic thrinfo. + dim_t ic_comm_id = thread->tid % thread->ic_ways; + bli_thrinfo_set_ocomm_id( ic_comm_id, thread_ic ); + bli_thrinfo_set_n_way( thread->ic_ways, thread_ic ); + bli_thrinfo_set_work_id( ic_comm_id, thread_ic ); + } +} + +#endif //LPGEMM_THRINFO_UTILS_H diff --git a/addon/aocl_gemm/frame/u8s8s32/lpgemm_reorder.c b/addon/aocl_gemm/frame/u8s8s32/lpgemm_reorder.c new file mode 100644 index 0000000000..8d61c757e8 --- /dev/null +++ b/addon/aocl_gemm/frame/u8s8s32/lpgemm_reorder.c @@ -0,0 +1,214 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "lpgemm_utils.h" +#include "lpgemm_reorder.h" +#include "lpgemm_packa.h" +#include "lpgemm_packb.h" +#include "lpgemm_config.h" + +void reorderb_nr64_u8s8s32o32 + ( + lpgemm_obj_t* b, + lpgemm_obj_t* b_reorder + ) +{ + dim_t NC = lpgemm_get_block_size_NC_global_cntx( U8S8S32OS32 ); + dim_t NR = lpgemm_get_block_size_NR_global_cntx( U8S8S32OS32 ); + dim_t KC = lpgemm_get_block_size_KC_global_cntx( U8S8S32OS32 ); + + dim_t rs_b = b->rs; + dim_t rs_b_reorder; + dim_t cs_b_reorder; + + dim_t n = b->width; + dim_t k = b->length; + + // k needs to be a multiple of 4 so that it can be used with vpdpbusd + // instruction. 
Padding is added in cases this condition is not + // satisfied, and therefore the k offset used for packed/reordered + // buffer needs to be updated. + dim_t k_updated = make_multiple_of_n( k, 4 ); + + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. + rntm_t rntm_g; + bli_rntm_init_from_global( &rntm_g ); + + dim_t n_threads = bli_rntm_num_threads( &rntm_g ); + n_threads = ( n_threads > 0 ) ? n_threads : 1; + +#ifdef BLIS_ENABLE_OPENMP + _Pragma( "omp parallel num_threads(n_threads)" ) + { + // Initialise a local thrinfo obj for work split across threads. + thrinfo_t thread_jc; + bli_thrinfo_set_n_way( n_threads, &thread_jc ); + bli_thrinfo_set_work_id( omp_get_thread_num(), &thread_jc ); +#else + { + // Initialise a local thrinfo obj for work split across threads. + thrinfo_t thread_jc; + bli_thrinfo_set_n_way( 1, &thread_jc ); + bli_thrinfo_set_work_id( 0, &thread_jc ); +#endif + // Compute the JC loop thread range for the current thread. + dim_t jc_start, jc_end; + bli_thread_range_sub( &thread_jc, n, NR, FALSE, &jc_start, &jc_end ); + + for ( dim_t jc = jc_start; jc < jc_end; jc += NC ) + { + dim_t nc0 = bli_min( ( jc_end - jc ), NC ); + + dim_t jc_cur_loop = jc; + dim_t jc_cur_loop_rem = 0; + dim_t n_sub_updated; + + get_B_panel_reordered_start_offset_width + ( + jc, n, NC, get_packb_u8s8s32o32_min_NR(), + &jc_cur_loop, &jc_cur_loop_rem, + &nc0, &n_sub_updated + ); + + for ( dim_t pc = 0; pc < k; pc += KC ) + { + dim_t kc0 = bli_min( ( k - pc ), KC ); + + // kc0 needs to be a multiple of 4 so that it can be used with + // vpdpbusd instruction. Padding is added in cases this + // condition is not satisfied, and therefore the kc0 offsets + // used for packed/reordered buffers needs to be updated. + dim_t kc0_updated = make_multiple_of_n( kc0, 4 ); + + // The offsets are calculated in such a way that it resembles + // the reorder buffer traversal in single threaded reordering. + // The panel boundaries (KCxNC) remain as it is accessed in + // single thread, and as a consequence a thread with jc_start + // inside the panel cannot consider NC range for reorder. It + // has to work with NC' < NC, and the offset is calulated using + // prev NC panels spanning k dim + cur NC panel spaning pc loop + // cur iteration + (NC - NC') spanning current kc0 (<= KC). + // + //Eg: Consider the following reordered buffer diagram: + // t1 t2 + // | | + // | |..NC..| + // | | | + // |.NC. |.NC. |NC'|NC" + // pc=0-+-----+-----+---+--+ + // KC| | | | | + // | 1 | 3 | 5 | + // pc=KC-+-----+-----+---st-+ + // KC| | | | | + // | 2 | 4 | 6 | 7| + // pc=k=2KC-+-----+-----+---+--+ + // |jc=0 |jc=NC|jc=2NC| + // + // The numbers 1,2..6,7 denotes the order in which reordered + // KCxNC blocks are stored in memory, ie: block 1 followed by 2 + // followed by 3, etc. 
Given two threads t1 and t2, and t2 needs + // to acces point st in the reorder buffer to write the data: + // The offset calulation logic will be: + // jc_cur_loop = 2NC, jc_cur_loop_rem = NC', pc = KC, + // n_sub_updated = NC, k = 2KC, kc0_updated = KC + // + // st = ( jc_cur_loop * k ) + // + ( n_sub_updated * pc ) + // + ( NC' * kc0_updated) + packb_nr64_u8s8s32o32 + ( + ( ( ( int8_t* )b_reorder->storage.aligned_buffer ) + + ( jc_cur_loop * k_updated ) + ( n_sub_updated * pc ) + + ( jc_cur_loop_rem * kc0_updated ) ), + ( ( ( int8_t* )b->storage.aligned_buffer ) + + ( rs_b * pc ) + jc ), + rs_b, nc0, kc0, &rs_b_reorder, &cs_b_reorder + ); + } + + adjust_B_panel_reordered_jc( &jc, jc_cur_loop ); + } + } + + b_reorder->rs = rs_b_reorder; + b_reorder->cs = cs_b_reorder; + b_reorder->mtag = REORDERED; +} + +void reordera_mr6_u8s8s32o32 + ( + lpgemm_obj_t* a, + lpgemm_obj_t* a_reorder + ) +{ + dim_t MC = lpgemm_get_block_size_MC_global_cntx( U8S8S32OS32 ); + dim_t KC = lpgemm_get_block_size_KC_global_cntx( U8S8S32OS32 ); + + dim_t rs_a = a->rs; + dim_t rs_a_reorder; + dim_t cs_a_reorder; + + dim_t k = a->width; + dim_t m = a->length; + + for ( dim_t pc = 0; pc < k; pc += KC ) + { + dim_t kc0 = bli_min( ( k - pc ), KC ); + + // kc0 needs to be a multiple of 4 so that it can be used with + // vpdpbusd instruction. Padding is added in cases this + // condition is not satisfied, and therefore the kc0 offsets + // used for packed/reordered buffers needs to be updated. + dim_t kc0_updated = make_multiple_of_n( kc0, 4 ); + + for ( dim_t ic = 0; ic < m; ic += MC ) + { + dim_t mc0 = bli_min( ( m - ic ), MC ); + + packa_k64_u8s8s32o32 + ( + ( ( ( uint8_t* )a_reorder->storage.aligned_buffer ) + ( pc * m ) + + ( ic * kc0_updated ) ), + ( ( ( uint8_t* )a->storage.aligned_buffer ) + ( rs_a * ic ) + pc ), + rs_a, mc0, kc0, &rs_a_reorder, &cs_a_reorder + ); + } + } + + a_reorder->rs = rs_a_reorder; + a_reorder->cs = cs_a_reorder; + a_reorder->mtag = REORDERED; +} diff --git a/addon/aocl_gemm/frame/u8s8s32/lpgemm_reorder.h b/addon/aocl_gemm/frame/u8s8s32/lpgemm_reorder.h new file mode 100644 index 0000000000..eb8dad9cfc --- /dev/null +++ b/addon/aocl_gemm/frame/u8s8s32/lpgemm_reorder.h @@ -0,0 +1,52 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef LPGEMM_REORDER_H +#define LPGEMM_REORDER_H + +#include "lpgemm_types.h" + +void reorderb_nr64_u8s8s32o32 + ( + lpgemm_obj_t* b, + lpgemm_obj_t* b_reorder + ); + +void reordera_mr6_u8s8s32o32 + ( + lpgemm_obj_t* a, + lpgemm_obj_t* a_reorder + ); + +#endif //LPGEMM_REORDER_H diff --git a/addon/aocl_gemm/frame/u8s8s32/lpgemm_u8s8s32.c b/addon/aocl_gemm/frame/u8s8s32/lpgemm_u8s8s32.c new file mode 100644 index 0000000000..5ee25a92a1 --- /dev/null +++ b/addon/aocl_gemm/frame/u8s8s32/lpgemm_u8s8s32.c @@ -0,0 +1,340 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "lpgemm_u8s8s32.h" +#include "lpgemm_packa.h" +#include "lpgemm_packb.h" +#include "lpgemm_6x64rowmajor.h" +#include "lpgemm_utils.h" +#include "lpgemm_thrinfo_utils.h" +#include "lpgemm_config.h" + +// B should always be packed. 
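For orientation, a sketch of the blocking scheme the driver below implements; the block sizes quoted are the defaults installed by aocl_lpgemm_init_global_cntx() in lpgemm_config.c, and the padding rules follow the comments in the packing code.

/*
   Loop structure (illustrative, not normative):

   for jc = 0 .. n step NC (1024)        // thread_jc splits this range
     for pc = 0 .. k step KC (2048)      // user beta applied only when pc == 0
       pack or locate the KC x NC panel of B
       (kc0 padded to a multiple of 4, nc0 to a multiple of 16)
       for ic = 0 .. m step MC (144)     // thread_ic splits this range
         pack or locate the MC x KC block of A
         for jr = 0 .. nc0 step NR (64)
           lpgemm_rowvar_u8s8s32o32_6x64() updates the mc0 x nr0 tile of C
           using MR x NR (6 x 64) register tiles and vpdpbusd accumulation
*/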
+void lpgemm_rowvar_u8s8s32o32 + ( + const dim_t m, + const dim_t n, + const dim_t k, + const uint8_t* a, + const dim_t rs_a, + const dim_t cs_a, + const AOCL_MEMORY_TAG mtag_a, + const int8_t* b, + const dim_t rs_b, + const dim_t cs_b, + const AOCL_MEMORY_TAG mtag_b, + int32_t* c, + const dim_t rs_c, + int32_t alpha, + int32_t beta, + rntm_t* rntm, + lpgemm_thrinfo_t* thread + ) +{ + dim_t NC = lpgemm_get_block_size_NC_global_cntx( U8S8S32OS32 ); + dim_t KC = lpgemm_get_block_size_KC_global_cntx( U8S8S32OS32 ); + dim_t MC = lpgemm_get_block_size_MC_global_cntx( U8S8S32OS32 ); + dim_t NR = lpgemm_get_block_size_NR_global_cntx( U8S8S32OS32 ); + dim_t MR = lpgemm_get_block_size_MR_global_cntx( U8S8S32OS32 ); + + if ( mtag_b == UNPACKED ) + { + //Error: can only work with packed B now. + return; + } + + // Strides are updated based on matrix packing/reordering. + const uint8_t* a_use = NULL; + dim_t rs_a_use = rs_a; + dim_t cs_a_use = cs_a; + dim_t a_block_stride = 0; + + const int8_t* b_use = NULL; + dim_t rs_b_use = rs_b; + dim_t cs_b_use = cs_b; + + int32_t* c_use_jc = NULL; + int32_t* c_use_ic = NULL; + + // Pack buffer for A. + uint8_t* pack_a_buffer_u8s8s32o32; + mem_t mem_a = BLIS_MEM_INITIALIZER; + siz_t mem_a_size_req = 0; + + // Pack buffer for B. + int8_t* pack_b_buffer_u8s8s32o32; + mem_t mem_b = BLIS_MEM_INITIALIZER; + siz_t mem_b_size_req = 0; + dim_t packb_min_NR = get_packb_u8s8s32o32_min_NR(); + + // kc needs to be a multiple of 4 so that it can be used with vpdpbusd + // instruction. Padding is added in cases this condition is not + // satisfied, and therefore the k offset used for packed/reordered + // buffer needs to be updated. + dim_t k_updated = make_multiple_of_n( k, 4 ); + + // Generate thrinfo objects for jc and ic loops from lpgemm_thrinfo_t. + thrinfo_t thread_jc; + thrinfo_t thread_ic; + + lpgemm_gen_thrinfo( thread, &thread_jc, &thread_ic ); + + // Compute the JC loop thread range for the current thread. + dim_t jc_start, jc_end; + bli_thread_range_sub( &thread_jc, n, NR, FALSE, &jc_start, &jc_end ); + + for ( dim_t jc = jc_start; jc < jc_end; jc += NC ) + { + dim_t nc0 = bli_min( ( jc_end - jc ), NC ); + c_use_jc = c + jc; + + dim_t jc_cur_loop = jc; + dim_t jc_cur_loop_rem = 0; + dim_t n_sub_updated; + + if ( mtag_b == REORDERED ) + { + get_B_panel_reordered_start_offset_width + ( + jc, n, NC, packb_min_NR, + &jc_cur_loop, &jc_cur_loop_rem, + &nc0, &n_sub_updated + ); + } + + for ( dim_t pc = 0; pc < k; pc += KC ) + { + int32_t beta0 = ( pc == 0 ) ? beta : 1; + dim_t kc0 = bli_min( ( k - pc ), KC ); + + // kc0 needs to be a multiple of 4 so that it can be + // used with vpdpbusd instruction. Padding is added in + // cases this condition is not satisfied, and therefore + // the kc0 offsets used for packed/reordered buffers + // needs to be updated. + dim_t kc0_updated = make_multiple_of_n( kc0, 4 ); + + if ( mtag_b == PACK ) + { + // Pack B chunks are based on jc work id. + dim_t jc_work_id = bli_thread_work_id( &thread_jc ); + + // Using child thrinfo (thread_ic) tid to decide chief thread + // per B matrix chunk (jc work id group) + if ( bli_thread_am_ochief( &thread_ic ) ) + { + // nc0 needs to be a multiple of 16 since this gives maximum + // vectorization. Packing B always results in buffers with width + // which is a multiple of 16. Subsequently the nc0 offsets used + // for packed/reordered buffers needs to be updated. 
+ dim_t nc0_updated = make_multiple_of_n( nc0, packb_min_NR ); + mem_b_size_req = sizeof( int8_t ) * nc0_updated * kc0_updated; + + lpgemm_alloc_mem_panel + ( + mem_b_size_req, BLIS_BUFFER_FOR_B_PANEL, + &mem_b, rntm + ); + + thread->comm[jc_work_id].sent_object = + bli_mem_buffer( &mem_b ); + } + + // All threads in work group should wait till chief thread has + // finished allocating the packing buffers. + bli_thrcomm_barrier + ( + bli_thread_ocomm_id( &thread_ic ), + &thread->comm[jc_work_id] + ); + + pack_b_buffer_u8s8s32o32 = + ( int8_t* ) thread->comm[jc_work_id].sent_object; + + // Compute the B panel per thread loop range for parallel + // packing using ic_ways number of threads. Since atmost only + // ic_ways threads can be used, the thread_ic attributes are + // used to split the loop range. + dim_t jc_packb_start, jc_packb_end; + bli_thread_range_sub + ( + &thread_ic, nc0, NR, FALSE, + &jc_packb_start, &jc_packb_end + ); + + // Ensure thread ranges are valid, especially cases where no: + // of threads available for parallelization are greater than + // no: of B panel NR chunks. + if ( ( jc_packb_end > jc_packb_start ) && + ( jc_packb_start < ( jc + nc0 ) ) ) + { + packb_nr64_u8s8s32o32 + ( + pack_b_buffer_u8s8s32o32 + ( jc_packb_start * kc0_updated ), + ( b + ( rs_b * pc ) + ( cs_b * jc ) + + ( cs_b * jc_packb_start ) ), rs_b, + ( jc_packb_end - jc_packb_start ), kc0, + &rs_b_use, &cs_b_use + ); + } + else + { + get_packb_nr64_u8s8s32o32_strides( &rs_b_use, &cs_b_use ); + } + + // All threads in work group should wait till B matrix packing + // is completed by the participating threads. + bli_thrcomm_barrier + ( + bli_thread_ocomm_id( &thread_ic ), + &thread->comm[jc_work_id] + ); + b_use = pack_b_buffer_u8s8s32o32; + } + else if ( mtag_b == REORDERED ) + { + // In multi-threaded scenarios, an extra offset into a given + // packed B panel is required, since the jc loop split can + // result in per thread start offset inside the panel, instead + // of panel boundaries. + b_use = b + ( jc_cur_loop * k_updated ) + + ( n_sub_updated * pc ) + + ( jc_cur_loop_rem * kc0_updated ); + + get_packb_nr64_u8s8s32o32_strides( &rs_b_use, &cs_b_use ); + } + else + { + //Unpacked B not supported. + return; + } + + dim_t ic_start, ic_end; + bli_thread_range_sub( &thread_ic, m, MR, FALSE, &ic_start, &ic_end ); + + for ( dim_t ic = ic_start; ic < ic_end; ic += MC ) + { + dim_t mc0 = bli_min( ( ic_end - ic ), MC ); + c_use_ic = c_use_jc + ( rs_c * ic ); + + // Matrix A packed and reordered code path is not triggerred + // currently since we do not support it yet. + if ( mtag_a == PACK ) + { + mem_a_size_req = sizeof( uint8_t ) * mc0 * kc0_updated; + + lpgemm_alloc_mem_panel + ( + mem_a_size_req, BLIS_BUFFER_FOR_A_BLOCK, + &mem_a, rntm + ); + pack_a_buffer_u8s8s32o32 = ( uint8_t* )bli_mem_buffer( &mem_a ); + + packa_k64_u8s8s32o32 + ( + pack_a_buffer_u8s8s32o32, + ( a + ( rs_a * ic ) + pc ), rs_a, + mc0, kc0, + &rs_a_use, &cs_a_use + ); + a_use = pack_a_buffer_u8s8s32o32; + a_block_stride = kc0_updated; + } + else if ( mtag_a == REORDERED ) + { + get_packa_k64_u8s8s32o32_strides( &rs_a_use, &cs_a_use ); + a_use = a + ( pc * m ) + ( kc0_updated * ic ); + a_block_stride = kc0_updated; + } + else + { + a_use = a + ( rs_a * ic ) + ( cs_a * pc ); + + // Int8 kernel reads 4 elements, totalling 4 bytes in a + // single broadcast for use in vnni instruction. + // Non vnni based kernel requires update to this code. 
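+				// With unpacked row-major A, four consecutive uint8 k-elements
+				// of a row form one 32-bit group; using cs_a_use = 4 below, the
+				// kernel offset ( cs_a * kr ) lands on a[i][4*kr .. 4*kr+3],
+				// which is broadcast with _mm512_set1_epi32 and fed to vpdpbusd.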
+ cs_a_use = 4; + a_block_stride = rs_a; + } + + for ( dim_t jr = 0; jr < nc0; jr += NR ) + { + dim_t nr0 = bli_min( ( nc0 - jr ), NR ); + + // Reorder/Packed B, Reorder/Packed/Unpacked A call. + lpgemm_rowvar_u8s8s32o32_6x64 + ( + mc0, nr0, kc0, + a_use, rs_a_use, cs_a_use, a_block_stride, + ( b_use + ( jr * kc0_updated ) ), rs_b_use, cs_b_use, + ( c_use_ic + jr ), rs_c, 1, + alpha, beta0 + ); + } + } + } + if ( mtag_b == REORDERED ) + { + adjust_B_panel_reordered_jc( &jc, jc_cur_loop ); + } + } + + // Release pack buffers. + if ( mtag_b == PACK ) + { + // All threads in work group should wait till B matrix usage is + // completed by the participating threads. + bli_thrcomm_barrier + ( + bli_thread_ocomm_id( &thread_jc ), + &thread->comm[bli_thread_work_id( &thread_jc)] + ); + + if ( bli_thread_am_ochief( &thread_ic ) ) + { + if ( bli_mem_is_alloc( &mem_b ) ) + { + bli_membrk_release( rntm, &mem_b ); + } + } + } + if ( mtag_a == PACK ) + { + if ( bli_mem_is_alloc( &mem_a ) ) + { + bli_membrk_release( rntm, &mem_a ); + } + } +} diff --git a/addon/aocl_gemm/frame/u8s8s32/lpgemm_u8s8s32.h b/addon/aocl_gemm/frame/u8s8s32/lpgemm_u8s8s32.h new file mode 100644 index 0000000000..2da9cc2de7 --- /dev/null +++ b/addon/aocl_gemm/frame/u8s8s32/lpgemm_u8s8s32.h @@ -0,0 +1,62 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef LPGEMM_U8S8S32_H +#define LPGEMM_U8S8S32_H + +#include "lpgemm_types.h" + +// B should always be packed. 
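+// mtag_a/mtag_b carry the AOCL_MEMORY_TAG of each input (UNPACKED, PACK or
+// REORDERED); B must arrive PACK or REORDERED, while A may additionally be
+// supplied UNPACKED. C is stored row-major with leading dimension rs_c and
+// is updated as C = beta*C + alpha*(A*B) using integer alpha and beta.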
+void lpgemm_rowvar_u8s8s32o32 + ( + const dim_t m, + const dim_t n, + const dim_t k, + const uint8_t* a, + const dim_t rs_a, + const dim_t cs_a, + const AOCL_MEMORY_TAG mtag_a, + const int8_t* b, + const dim_t rs_b, + const dim_t cs_b, + const AOCL_MEMORY_TAG mtag_b, + int32_t* c, + const dim_t rs_c, + int32_t alpha, + int32_t beta, + rntm_t* rntm, + lpgemm_thrinfo_t* thread + ); + +#endif //LPGEMM_U8S8S32_H diff --git a/addon/aocl_gemm/frame/u8s8s32/lpgemm_utils.c b/addon/aocl_gemm/frame/u8s8s32/lpgemm_utils.c new file mode 100644 index 0000000000..aa6669469d --- /dev/null +++ b/addon/aocl_gemm/frame/u8s8s32/lpgemm_utils.c @@ -0,0 +1,156 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "lpgemm_utils.h" + +dim_t get_64byte_aligned_memory + ( + void** original_memory, + void** aligned_memory, + int64_t allocate_size + ) +{ + // Get 64 byte aligned memory. + int8_t* t1_original = ( int8_t* ) malloc( allocate_size + 64 ); + if ( t1_original == NULL ) + { + //Error in malloc. + *original_memory = NULL; + *aligned_memory = NULL; + return -1; + } + + int8_t* ta_original = t1_original + 64; + ta_original = ta_original - ( ( int64_t )( ta_original ) % 64 ); + + *original_memory = t1_original; + *aligned_memory = ta_original; + return 0; +} + +static lpgemm_obj_t* alloc_lpgemm_obj_t_u8s8s32 + ( + dim_t length, + dim_t width, + dim_t stride, + dim_t elem_size, + AOCL_STOR_TAG stor_scheme, + AOCL_MEMORY_TAG mtag + ) +{ + lpgemm_obj_t* obj = ( lpgemm_obj_t* ) malloc( sizeof( lpgemm_obj_t ) ); + + if ( obj == NULL ) + { + return NULL; //failure + } + + // Allocate aligned buffers. + get_64byte_aligned_memory( &obj->storage.origin_buffer, + &obj->storage.aligned_buffer, + ( elem_size * length * width ) ); + + if ( obj->storage.origin_buffer == NULL ) + { + // Buffer allocation failed. 
+ free( obj ); + return NULL; + } + + obj->length = length; + obj->width = width; + obj->elem_size = elem_size; + + if ( stor_scheme == ROW_MAJOR ) + { + obj->rs = stride; + obj->cs = 4; // 4 elements read at a time. + } + else if ( stor_scheme == COLUMN_MAJOR ) + { + obj->cs = stride; + obj->rs = 1; + } + obj->mtag = mtag; + + return obj; +} + +lpgemm_obj_t* alloc_unpack_tag_lpgemm_obj_t_u8s8s32 + ( + dim_t length, + dim_t width, + dim_t stride, + dim_t elem_size, + AOCL_STOR_TAG stor_scheme + ) +{ + return alloc_lpgemm_obj_t_u8s8s32( length, width, stride, elem_size, stor_scheme, UNPACKED ); +} + +lpgemm_obj_t* alloc_pack_tag_lpgemm_obj_t_u8s8s32 + ( + dim_t length, + dim_t width, + dim_t stride, + dim_t elem_size, + AOCL_STOR_TAG stor_scheme + ) +{ + return alloc_lpgemm_obj_t_u8s8s32( length, width, stride, elem_size, stor_scheme, PACK ); +} + +lpgemm_obj_t* alloc_reorder_tag_lpgemm_obj_t_u8s8s32 + ( + dim_t length, + dim_t width, + dim_t stride, + dim_t elem_size, + AOCL_STOR_TAG stor_scheme + ) +{ + // Extra space since packing does width in multiples of 16. + dim_t width_reorder = make_multiple_of_n( width, 16 ); + // Extra space since packing does length in multiples of 4. + dim_t length_reorder = make_multiple_of_n( length, 4 ); + + return alloc_lpgemm_obj_t_u8s8s32( length_reorder, width_reorder, stride, elem_size, stor_scheme, REORDERED ); +} + +void dealloc_lpgemm_obj_t_u8s8s32( lpgemm_obj_t* obj ) +{ + free( obj->storage.origin_buffer ); + free( obj ); +} diff --git a/addon/aocl_gemm/frame/u8s8s32/lpgemm_utils.h b/addon/aocl_gemm/frame/u8s8s32/lpgemm_utils.h new file mode 100644 index 0000000000..743af4f3ec --- /dev/null +++ b/addon/aocl_gemm/frame/u8s8s32/lpgemm_utils.h @@ -0,0 +1,225 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef LPGEMM_UTILS_H +#define LPGEMM_UTILS_H + +#include "lpgemm_types.h" + +// Users of this API needs to free the allocated memory on their own. 
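Since callers own the underlying allocation, a minimal usage sketch of
get_64byte_aligned_memory follows (the wrapper function, buffer size and
element type are illustrative, not part of this header): the helper
over-allocates by 64 bytes, returns the raw malloc pointer in
original_memory and a 64-byte-aligned pointer into the same block in
aligned_memory, and it is the raw pointer that must eventually be freed.

    #include <stdint.h>
    #include <stdlib.h>
    #include "lpgemm_utils.h"

    static void aligned_buffer_example( void )
    {
        void* raw = NULL;
        void* aligned = NULL;

        // Request room for 1024 int8_t elements (size chosen for illustration).
        if ( get_64byte_aligned_memory( &raw, &aligned, 1024 * sizeof( int8_t ) ) != 0 )
        {
            return; // malloc failed; both pointers are NULL.
        }

        int8_t* buf = ( int8_t* )aligned; // 64-byte aligned view used for kernel work.
        ( void )buf;                      // ... fill / consume buf here ...

        free( raw );                      // Free the original pointer, never the aligned one.
    }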
+dim_t get_64byte_aligned_memory + ( + void** original_memory, + void** aligned_memory, + int64_t allocate_size + ); + +lpgemm_obj_t* alloc_unpack_tag_lpgemm_obj_t_u8s8s32 + ( + dim_t length, + dim_t width, + dim_t stride, + dim_t elem_size, + AOCL_STOR_TAG stor_scheme + ); + +lpgemm_obj_t* alloc_pack_tag_lpgemm_obj_t_u8s8s32 + ( + dim_t length, + dim_t width, + dim_t stride, + dim_t elem_size, + AOCL_STOR_TAG stor_scheme + ); + +lpgemm_obj_t* alloc_reorder_tag_lpgemm_obj_t_u8s8s32 + ( + dim_t length, + dim_t width, + dim_t stride, + dim_t elem_size, + AOCL_STOR_TAG stor_scheme + ); + +void dealloc_lpgemm_obj_t_u8s8s32( lpgemm_obj_t* obj ); + +BLIS_INLINE void bli_param_map_char_to_lpmtag + ( + char mtag, + AOCL_MEMORY_TAG* lp_mtag + ) +{ + if ( mtag == 'n' || mtag == 'N' ) *lp_mtag = UNPACKED; + else if ( mtag == 'p' || mtag == 'P' ) *lp_mtag = PACK; + else if ( mtag == 'r' || mtag == 'R' ) *lp_mtag = REORDERED; + else + { + *lp_mtag = UNPACKED; + } +} + +BLIS_INLINE void bli_param_map_char_to_lpmat_type + ( + const char mtag, + AOCL_MATRIX_TYPE* lp_mat_type + ) +{ + if ( mtag == 'a' || mtag == 'A' ) *lp_mat_type = A_MATRIX; + else if ( mtag == 'b' || mtag == 'B' ) *lp_mat_type = B_MATRIX; + else + { + *lp_mat_type = B_MATRIX; + } +} + +BLIS_INLINE dim_t make_multiple_of_n( dim_t k, dim_t n ) +{ + if ( n <= 0 ) + { + return 0; + } + + return ( ( ( k + n - 1 ) / n ) * n ); +} + +BLIS_INLINE void lpgemm_alloc_mem_panel + ( + dim_t size_req, + packbuf_t buf_type, + mem_t* mem, + rntm_t* rntm_l + ) +{ + if ( bli_mem_is_unalloc( mem ) ) + { + bli_membrk_acquire_m + ( + rntm_l, + size_req, + buf_type, + mem + ); + } + else + { + siz_t mem_size = bli_mem_size( mem ); + if ( mem_size < size_req ) + { + bli_membrk_release( rntm_l, mem ); + bli_membrk_acquire_m + ( + rntm_l, + size_req, + buf_type, + mem + ); + } + } +} + +BLIS_INLINE dim_t get_Bpanel_width_for_kdim_traversal + ( + dim_t jc, + dim_t n, + dim_t NC, + dim_t NR + ) +{ + dim_t n_mod_NR = n % NR; + dim_t n_sub_updated = NC; + + if ( ( n % NC ) != 0 ) + { + // Only applicable to final NC part of jc loop where jc + remaining + // elements is less than NC; or when n < NC in which case panel width + // is atmost n. + dim_t n_last_loop = ( n / NC ) * NC; + if ( jc >= n_last_loop ) + { + n_sub_updated = n - n_last_loop; + if ( n_mod_NR != 0 ) + { + n_sub_updated += ( NR - n_mod_NR ); + } + } + } + + return n_sub_updated; +} + +BLIS_INLINE void get_B_panel_reordered_start_offset_width + ( + dim_t jc, + dim_t n, + dim_t NC, + dim_t NR, + dim_t* panel_start, + dim_t* panel_offset, + dim_t* panel_width, + dim_t* panel_width_kdim_trav + ) +{ + // Since n dimension is split across threads in units of NR blocks, + // it could happen that B matrix chunk for a thread may be part of + // two separate NCxKC panels. In this case nc0 is updated such that + // the jr loop only accesses the remaining portion of current NCxKC + // panel, with the next jc iteration taking care of the other panel. + // This ensures that jr loop does not cross panel boundaries. + ( *panel_start ) = ( jc / NC ) * NC; + ( *panel_offset ) = jc - ( *panel_start ); + + // Check if jc + current_panel_width (nc0) crosses panel boundaries. 
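+	// Worked example (sizes hypothetical): with NC = 64, NR = 16 and a
+	// thread chunk starting at jc = 48 with nc0 = 32, panel_start = 0 and
+	// panel_offset = 48. Since 48 + 32 exceeds the panel end at 64, the
+	// check below trims the width to NC - 48 = 16; the remaining 16 columns
+	// are handled in the next jc iteration once adjust_B_panel_reordered_jc()
+	// has reset jc to panel_start.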
+ if ( ( jc + ( *panel_width ) ) > ( ( *panel_start ) + NC ) ) + { + ( *panel_width ) = NC - ( *panel_offset ); + } + + ( *panel_width_kdim_trav ) = get_Bpanel_width_for_kdim_traversal + ( + jc, n, NC, NR + ); +} + +BLIS_INLINE void adjust_B_panel_reordered_jc( dim_t* jc, dim_t panel_start ) +{ + // Since n dimension is split across threads in units of NR blocks, + // it could happen that B matrix chunk for a thread may be part of + // two separate NCxKC panels. In this case jc is reset to immediate + // previous panel offset so that in the next iteration, the + // following panel belonging to the B chunk is accessed. This + // ensures that jr loop does not cross panel boundaries. + ( *jc ) = panel_start; +} + +#endif //LPGEMM_UTILS_H diff --git a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_6x64rowmajor.h b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_6x64rowmajor.h new file mode 100644 index 0000000000..9a5b3644d1 --- /dev/null +++ b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_6x64rowmajor.h @@ -0,0 +1,58 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_GEMM_INT8_MNROW +#define BLIS_GEMM_INT8_MNROW + +// 6x64 int8o32 kernel +void lpgemm_rowvar_u8s8s32o32_6x64 + ( + const dim_t m0, + const dim_t n0, + const dim_t k0, + const uint8_t* a, + const dim_t rs_a, + const dim_t cs_a, + const dim_t ps_a, + const int8_t* b, + const dim_t rs_b, + const dim_t cs_b, + int32_t* c, + const dim_t rs_c, + const dim_t cs_c, + const int32_t alpha, + const int32_t beta + ); + +#endif //BLIS_GEMM_INT8_MNROW diff --git a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_6x64rowmajor_amd512vnni.c b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_6x64rowmajor_amd512vnni.c new file mode 100644 index 0000000000..679fc916c2 --- /dev/null +++ b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_6x64rowmajor_amd512vnni.c @@ -0,0 +1,693 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. 
+ + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include + +#include "blis.h" +#include "lpgemm_6x64rowmajor.h" +#include "lpgemm_n_fringe.h" +#include "lpgemm_m_fringe.h" + +// 6x64 int8o32 kernel +void lpgemm_rowvar_u8s8s32o32_6x64 + ( + const dim_t m0, + const dim_t n0, + const dim_t k0, + const uint8_t* a, + const dim_t rs_a, + const dim_t cs_a, + const dim_t ps_a, + const int8_t* b, + const dim_t rs_b, + const dim_t cs_b, + int32_t* c, + const dim_t rs_c, + const dim_t cs_c, + const int32_t alpha, + const int32_t beta + ) +{ + dim_t MR = 6; + dim_t NR = 64; + + dim_t m_full_pieces = m0 / MR; + dim_t m_full_pieces_loop_limit = m_full_pieces * MR; + dim_t m_partial_pieces = m0 % MR; + + dim_t k_full_pieces = k0 / 4; + dim_t k_partial_pieces = k0 % 4; + + uint32_t a_kfringe_buf = 0; + + if ( n0 < NR ) + { + dim_t n0_rem = n0 % 16; + + // Split into multiple smaller fringe kernels, so as to maximize + // vectorization. Any n0 < NR(64) can be expressed as n0 = 48 + n` + // or n0 = 32 + n` or n0 = 16 + n`, where n` < 16. + dim_t n0_48 = n0 / 48; + dim_t n0_32 = n0 / 32; + dim_t n0_16 = n0 / 16; + + // KC when not multiple of 4 will have padding to make it multiple of + // 4 in packed buffer. Also the k0 cannot be passed as the updated + // value since A matrix is not packed and requires original k0. + dim_t k0_updated = k0; + if ( k_partial_pieces > 0 ) + { + k0_updated += ( 4 - k_partial_pieces ); + } + + if ( n0_48 == 1 ) + { + lpgemm_rowvar_u8s8s32o32_6x48 + ( + m0, k0, + a, rs_a, cs_a, ps_a, + b, ( ( rs_b / 4 ) * 3 ), cs_b, + c, rs_c, + alpha, beta + ); + + b = b + ( 48 * k0_updated ); // k0x48 packed contiguosly. + c = c + 48; + } + else if ( n0_32 == 1 ) + { + lpgemm_rowvar_u8s8s32o32_6x32 + ( + m0, k0, + a, rs_a, cs_a, ps_a, + b, ( ( rs_b / 4 ) * 2 ), cs_b, + c, rs_c, + alpha, beta + ); + + b = b + ( 32 * k0_updated ); // k0x32 packed contiguosly. 
+ c = c + 32; + } + else if ( n0_16 == 1 ) + { + lpgemm_rowvar_u8s8s32o32_6x16 + ( + m0, k0, + a, rs_a, cs_a, ps_a, + b, ( ( rs_b / 4 ) * 1 ), cs_b, + c, rs_c, + alpha, beta + ); + + b = b + ( 16 * k0_updated ); // k0x16 packed contiguosly. + c = c + 16; + } + + if ( n0_rem > 0 ) + { + lpgemm_rowvar_u8s8s32o32_6xlt16 + ( + m0, k0, + a, rs_a, cs_a, ps_a, + b, ( ( rs_b / 4 ) * 1 ), cs_b, + c, rs_c, + alpha, beta, n0_rem + ); + + // No leftover fringe after this point. + } + + return; + } + + // B matrix storage. + __m512i b0; + __m512i b1; + __m512i b2; + __m512i b3; + + // A matrix storage. + __m512i a_int32_0; + __m512i a_int32_1; + + for ( dim_t ir = 0; ir < m_full_pieces_loop_limit; ir += MR ) + { + // Registers to use for accumulating C. + __m512i c_int32_0p0 = _mm512_setzero_epi32(); + __m512i c_int32_0p1 = _mm512_setzero_epi32(); + __m512i c_int32_0p2 = _mm512_setzero_epi32(); + __m512i c_int32_0p3 = _mm512_setzero_epi32(); + + __m512i c_int32_1p0 = _mm512_setzero_epi32(); + __m512i c_int32_1p1 = _mm512_setzero_epi32(); + __m512i c_int32_1p2 = _mm512_setzero_epi32(); + __m512i c_int32_1p3 = _mm512_setzero_epi32(); + + __m512i c_int32_2p0 = _mm512_setzero_epi32(); + __m512i c_int32_2p1 = _mm512_setzero_epi32(); + __m512i c_int32_2p2 = _mm512_setzero_epi32(); + __m512i c_int32_2p3 = _mm512_setzero_epi32(); + + __m512i c_int32_3p0 = _mm512_setzero_epi32(); + __m512i c_int32_3p1 = _mm512_setzero_epi32(); + __m512i c_int32_3p2 = _mm512_setzero_epi32(); + __m512i c_int32_3p3 = _mm512_setzero_epi32(); + + __m512i c_int32_4p0 = _mm512_setzero_epi32(); + __m512i c_int32_4p1 = _mm512_setzero_epi32(); + __m512i c_int32_4p2 = _mm512_setzero_epi32(); + __m512i c_int32_4p3 = _mm512_setzero_epi32(); + + __m512i c_int32_5p0 = _mm512_setzero_epi32(); + __m512i c_int32_5p1 = _mm512_setzero_epi32(); + __m512i c_int32_5p2 = _mm512_setzero_epi32(); + __m512i c_int32_5p3 = _mm512_setzero_epi32(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + // The instructions are arranged in a mixed way to reduce data + // chain dependencies. + + // Load 4 rows with 64 elements each from B to 4 ZMM registers. It + // is to be noted that the B matrix is packed for use in vnni + // instructions and each load to ZMM register will have 4 elements + // along k direction and 16 elements across n directions, so 4x16 + // elements to a ZMM register. + b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + b3 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 3 ) ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-63] = a[0,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + + // Broadcast a[1,kr:kr+4]. + a_int32_1 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); + c_int32_0p3 = _mm512_dpbusd_epi32( c_int32_0p3, a_int32_0, b3 ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-63] = a[1,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_1, b0 ); + + // Broadcast a[2,kr:kr+4]. 
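+			// Each vpdpbusd (_mm512_dpbusd_epi32) lane multiplies four
+			// unsigned 8-bit A values with four signed 8-bit B values and
+			// adds their sum into the 32-bit accumulator, which is why A is
+			// broadcast in 4-byte (k = 4) groups here.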
+ a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_1, b1 ); + c_int32_1p2 = _mm512_dpbusd_epi32( c_int32_1p2, a_int32_1, b2 ); + c_int32_1p3 = _mm512_dpbusd_epi32( c_int32_1p3, a_int32_1, b3 ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-63] = a[2,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + + // Broadcast a[3,kr:kr+4]. + a_int32_1 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); + c_int32_2p2 = _mm512_dpbusd_epi32( c_int32_2p2, a_int32_0, b2 ); + c_int32_2p3 = _mm512_dpbusd_epi32( c_int32_2p3, a_int32_0, b3 ); + + // Perform column direction mat-mul with k = 4. + // c[3,0-63] = a[3,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_1, b0 ); + + // Broadcast a[4,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); + + c_int32_3p1 = _mm512_dpbusd_epi32( c_int32_3p1, a_int32_1, b1 ); + c_int32_3p2 = _mm512_dpbusd_epi32( c_int32_3p2, a_int32_1, b2 ); + c_int32_3p3 = _mm512_dpbusd_epi32( c_int32_3p3, a_int32_1, b3 ); + + // Perform column direction mat-mul with k = 4. + // c[4,0-63] = a[4,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_4p0 = _mm512_dpbusd_epi32( c_int32_4p0, a_int32_0, b0 ); + + // Broadcast a[5,kr:kr+4]. + a_int32_1 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 5 ) + ( cs_a * kr ) ) ); + + c_int32_4p1 = _mm512_dpbusd_epi32( c_int32_4p1, a_int32_0, b1 ); + c_int32_4p2 = _mm512_dpbusd_epi32( c_int32_4p2, a_int32_0, b2 ); + c_int32_4p3 = _mm512_dpbusd_epi32( c_int32_4p3, a_int32_0, b3 ); + + // Perform column direction mat-mul with k = 4. + // c[5,0-63] = a[5,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_5p0 = _mm512_dpbusd_epi32( c_int32_5p0, a_int32_1, b0 ); + c_int32_5p1 = _mm512_dpbusd_epi32( c_int32_5p1, a_int32_1, b1 ); + c_int32_5p2 = _mm512_dpbusd_epi32( c_int32_5p2, a_int32_1, b2 ); + c_int32_5p3 = _mm512_dpbusd_epi32( c_int32_5p3, a_int32_1, b3 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + b3 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 3 ) ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-63] = a[0,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + + // Broadcast a[1,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_1 = _mm512_set1_epi32( a_kfringe_buf ); + + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); + c_int32_0p3 = _mm512_dpbusd_epi32( c_int32_0p3, a_int32_0, b3 ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-63] = a[1,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_1, b0 ); + + // Broadcast a[2,kr:kr+4]. 
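+			// As for the rows above, only k_partial_pieces bytes are copied
+			// into a_kfringe_buf; the remaining bytes of the 4-byte buffer
+			// stay zero, so the padded lanes add nothing to the vpdpbusd
+			// accumulation.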
+ memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_1, b1 ); + c_int32_1p2 = _mm512_dpbusd_epi32( c_int32_1p2, a_int32_1, b2 ); + c_int32_1p3 = _mm512_dpbusd_epi32( c_int32_1p3, a_int32_1, b3 ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-63] = a[2,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + + // Broadcast a[3,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_1 = _mm512_set1_epi32( a_kfringe_buf ); + + c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); + c_int32_2p2 = _mm512_dpbusd_epi32( c_int32_2p2, a_int32_0, b2 ); + c_int32_2p3 = _mm512_dpbusd_epi32( c_int32_2p3, a_int32_0, b3 ); + + // Perform column direction mat-mul with k = 4. + // c[3,0-63] = a[3,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_1, b0 ); + + // Broadcast a[4,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 4 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + c_int32_3p1 = _mm512_dpbusd_epi32( c_int32_3p1, a_int32_1, b1 ); + c_int32_3p2 = _mm512_dpbusd_epi32( c_int32_3p2, a_int32_1, b2 ); + c_int32_3p3 = _mm512_dpbusd_epi32( c_int32_3p3, a_int32_1, b3 ); + + // Perform column direction mat-mul with k = 4. + // c[4,0-63] = a[4,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_4p0 = _mm512_dpbusd_epi32( c_int32_4p0, a_int32_0, b0 ); + + // Broadcast a[5,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 5 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_1 = _mm512_set1_epi32( a_kfringe_buf ); + + c_int32_4p1 = _mm512_dpbusd_epi32( c_int32_4p1, a_int32_0, b1 ); + c_int32_4p2 = _mm512_dpbusd_epi32( c_int32_4p2, a_int32_0, b2 ); + c_int32_4p3 = _mm512_dpbusd_epi32( c_int32_4p3, a_int32_0, b3 ); + + // Perform column direction mat-mul with k = 4. 
+ // c[5,0-63] = a[5,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_5p0 = _mm512_dpbusd_epi32( c_int32_5p0, a_int32_1, b0 ); + c_int32_5p1 = _mm512_dpbusd_epi32( c_int32_5p1, a_int32_1, b1 ); + c_int32_5p2 = _mm512_dpbusd_epi32( c_int32_5p2, a_int32_1, b2 ); + c_int32_5p3 = _mm512_dpbusd_epi32( c_int32_5p3, a_int32_1, b3 ); + } + + // Load alpha and beta + __m512i selector1 = _mm512_set1_epi32( alpha ); + __m512i selector2 = _mm512_set1_epi32( beta ); + + // Scale by alpha + c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); + c_int32_0p1 = _mm512_mullo_epi32( selector1, c_int32_0p1 ); + c_int32_0p2 = _mm512_mullo_epi32( selector1, c_int32_0p2 ); + c_int32_0p3 = _mm512_mullo_epi32( selector1, c_int32_0p3 ); + + c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); + c_int32_1p1 = _mm512_mullo_epi32( selector1, c_int32_1p1 ); + c_int32_1p2 = _mm512_mullo_epi32( selector1, c_int32_1p2 ); + c_int32_1p3 = _mm512_mullo_epi32( selector1, c_int32_1p3 ); + + c_int32_2p0 = _mm512_mullo_epi32( selector1, c_int32_2p0 ); + c_int32_2p1 = _mm512_mullo_epi32( selector1, c_int32_2p1 ); + c_int32_2p2 = _mm512_mullo_epi32( selector1, c_int32_2p2 ); + c_int32_2p3 = _mm512_mullo_epi32( selector1, c_int32_2p3 ); + + c_int32_3p0 = _mm512_mullo_epi32( selector1, c_int32_3p0 ); + c_int32_3p1 = _mm512_mullo_epi32( selector1, c_int32_3p1 ); + c_int32_3p2 = _mm512_mullo_epi32( selector1, c_int32_3p2 ); + c_int32_3p3 = _mm512_mullo_epi32( selector1, c_int32_3p3 ); + + c_int32_4p0 = _mm512_mullo_epi32( selector1, c_int32_4p0 ); + c_int32_4p1 = _mm512_mullo_epi32( selector1, c_int32_4p1 ); + c_int32_4p2 = _mm512_mullo_epi32( selector1, c_int32_4p2 ); + c_int32_4p3 = _mm512_mullo_epi32( selector1, c_int32_4p3 ); + + c_int32_5p0 = _mm512_mullo_epi32( selector1, c_int32_5p0 ); + c_int32_5p1 = _mm512_mullo_epi32( selector1, c_int32_5p1 ); + c_int32_5p2 = _mm512_mullo_epi32( selector1, c_int32_5p2 ); + c_int32_5p3 = _mm512_mullo_epi32( selector1, c_int32_5p3 ); + + // Scale C by beta. 
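+		// When beta is zero the existing C tile is neither loaded nor added,
+		// so C may be uninitialized on entry; otherwise each 16-element tile
+		// is loaded, scaled by beta and added to the alpha-scaled
+		// accumulators, giving C = beta*C + alpha*(A*B) for this block.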
+ if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 0 ) ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 0 ) ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p1 = _mm512_add_epi32( selector1, c_int32_0p1 ); + + // c[0,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 0 ) ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p2 = _mm512_add_epi32( selector1, c_int32_0p2 ); + + // c[0,48-63] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 0 ) ) + ( 3*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p3 = _mm512_add_epi32( selector1, c_int32_0p3 ); + + // c[1,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 1 ) ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[1,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 1 ) ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p1 = _mm512_add_epi32( selector1, c_int32_1p1 ); + + // c[1,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 1 ) ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p2 = _mm512_add_epi32( selector1, c_int32_1p2 ); + + // c[1,48-63] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 1 ) ) + ( 3*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p3 = _mm512_add_epi32( selector1, c_int32_1p3 ); + + // c[2,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 2 ) ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[2,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 2 ) ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p1 = _mm512_add_epi32( selector1, c_int32_2p1 ); + + // c[2,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 2 ) ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p2 = _mm512_add_epi32( selector1, c_int32_2p2 ); + + // c[2,48-63] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 2 ) ) + ( 3*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p3 = _mm512_add_epi32( selector1, c_int32_2p3 ); + + // c[3,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 3 ) ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); + + // c[3,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 3 ) ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_3p1 = _mm512_add_epi32( selector1, c_int32_3p1 ); + + // c[3,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 3 ) ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_3p2 = _mm512_add_epi32( selector1, c_int32_3p2 ); + + // c[3,48-63] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 3 ) ) + ( 3*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_3p3 = _mm512_add_epi32( selector1, c_int32_3p3 ); + + // c[4,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 4 ) ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_4p0 = 
_mm512_add_epi32( selector1, c_int32_4p0 ); + + // c[4,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 4 ) ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_4p1 = _mm512_add_epi32( selector1, c_int32_4p1 ); + + // c[4,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 4 ) ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_4p2 = _mm512_add_epi32( selector1, c_int32_4p2 ); + + // c[4,48-63] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 4 ) ) + ( 3*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_4p3 = _mm512_add_epi32( selector1, c_int32_4p3 ); + + // c[5,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 5 ) ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_5p0 = _mm512_add_epi32( selector1, c_int32_5p0 ); + + // c[5,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 5 ) ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_5p1 = _mm512_add_epi32( selector1, c_int32_5p1 ); + + // c[5,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 5 ) ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_5p2 = _mm512_add_epi32( selector1, c_int32_5p2 ); + + // c[5,48-63] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 5 ) ) + ( 3*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_5p3 = _mm512_add_epi32( selector1, c_int32_5p3 ); + } + + // Store the results. + // c[0,0-15] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 0 ) ) + ( 0*16 ), c_int32_0p0 ); + + // c[0, 16-31] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 0 ) ) + ( 1*16 ), c_int32_0p1 ); + + // c[0,32-47] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 0 ) ) + ( 2*16 ), c_int32_0p2 ); + + // c[0,48-63] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 0 ) ) + ( 3*16 ), c_int32_0p3 ); + + // c[1,0-15] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 1 ) ) + ( 0*16 ), c_int32_1p0 ); + + // c[1,16-31] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 1 ) ) + ( 1*16 ), c_int32_1p1 ); + + // c[1,32-47] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 1 ) ) + ( 2*16 ), c_int32_1p2 ); + + // c[1,48-63] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 1 ) ) + ( 3*16 ), c_int32_1p3 ); + + // c[2,0-15] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 2 ) ) + ( 0*16 ), c_int32_2p0 ); + + // c[2,16-31] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 2 ) ) + ( 1*16 ), c_int32_2p1 ); + + // c[2,32-47] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 2 ) ) + ( 2*16 ), c_int32_2p2 ); + + // c[2,48-63] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 2 ) ) + ( 3*16 ), c_int32_2p3 ); + + // c[3,0-15] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 3 ) ) + ( 0*16 ), c_int32_3p0 ); + + // c[3,16-31] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 3 ) ) + ( 1*16 ), c_int32_3p1 ); + + // c[3,32-47] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 3 ) ) + ( 2*16 ), c_int32_3p2 ); + + // c[3,48-63] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 3 ) ) + ( 3*16 ), c_int32_3p3 ); + + // c[4,0-15] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 4 ) ) + ( 0*16 ), c_int32_4p0 ); + + // c[4,16-31] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 4 ) ) + ( 1*16 ), c_int32_4p1 ); + + // c[4,32-47] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 4 ) ) + ( 2*16 ), c_int32_4p2 ); + + // c[4,48-63] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 4 ) ) + ( 3*16 ), c_int32_4p3 ); + + // c[5,0-15] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 5 ) ) + ( 0*16 ), c_int32_5p0 ); + + // c[5,16-31] + _mm512_storeu_epi32( c + 
( rs_c * ( ir + 5 ) ) + ( 1*16 ), c_int32_5p1 ); + + // c[5,32-47] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 5 ) ) + ( 2*16 ), c_int32_5p2 ); + + // c[5,48-63] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 5 ) ) + ( 3*16 ), c_int32_5p3 ); + + a = a + ( MR * ps_a ); + } + + if ( m_partial_pieces > 0 ) + { + if ( m_partial_pieces == 5 ) + { + // In cases where A matrix is packed cs_a is set to 24, since the + // next column in a given row is accessed after 4*6 elements, where + // 6 is MR and 4 elements are broadcasted each time from A (vnni). + // In fringe case, where m < MR, the next column will be after m'*4 + // elements, and subsequently following adjustment of cs_a is + // required before calling m fringe kernels. + dim_t cs_a_use = ( cs_a == 4 ) ? 4 : ( ( cs_a / 6 ) * 5 ); + lpgemm_rowvar_u8s8s32o32_5x64 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta + ); + } + else if ( m_partial_pieces == 4 ) + { + dim_t cs_a_use = ( cs_a == 4 ) ? 4 : ( ( cs_a / 6 ) * 4 ); + lpgemm_rowvar_u8s8s32o32_4x64 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta + ); + } + else if ( m_partial_pieces == 3 ) + { + dim_t cs_a_use = ( cs_a == 4 ) ? 4 : ( ( cs_a / 6 ) * 3 ); + lpgemm_rowvar_u8s8s32o32_3x64 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta + ); + } + else if ( m_partial_pieces == 2 ) + { + dim_t cs_a_use = ( cs_a == 4 ) ? 4 : ( ( cs_a / 6 ) * 2 ); + lpgemm_rowvar_u8s8s32o32_2x64 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta + ); + } + else if ( m_partial_pieces == 1 ) + { + dim_t cs_a_use = ( cs_a == 4 ) ? 4 : ( ( cs_a / 6 ) * 1 ); + lpgemm_rowvar_u8s8s32o32_1x64 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta + ); + } + } +} diff --git a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_m_fringe.h b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_m_fringe.h new file mode 100644 index 0000000000..b0acdbcd64 --- /dev/null +++ b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_m_fringe.h @@ -0,0 +1,118 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_GEMM_INT8_MFRINGE +#define BLIS_GEMM_INT8_MFRINGE + +// 5x64 int8o32 kernel +void lpgemm_rowvar_u8s8s32o32_5x64 + ( + const dim_t k0, + const uint8_t* a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t* b, + const dim_t rs_b, + const dim_t cs_b, + int32_t* c, + const dim_t rs_c, + const int32_t alpha, + const int32_t beta + ); + +// 4x64 int8o32 kernel +void lpgemm_rowvar_u8s8s32o32_4x64 + ( + const dim_t k0, + const uint8_t* a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t* b, + const dim_t rs_b, + const dim_t cs_b, + int32_t* c, + const dim_t rs_c, + const int32_t alpha, + const int32_t beta + ); + +// 3x64 int8o32 kernel +void lpgemm_rowvar_u8s8s32o32_3x64 + ( + const dim_t k0, + const uint8_t* a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t* b, + const dim_t rs_b, + const dim_t cs_b, + int32_t* c, + const dim_t rs_c, + const int32_t alpha, + const int32_t beta + ); + +// 2x64 int8o32 kernel +void lpgemm_rowvar_u8s8s32o32_2x64 + ( + const dim_t k0, + const uint8_t* a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t* b, + const dim_t rs_b, + const dim_t cs_b, + int32_t* c, + const dim_t rs_c, + const int32_t alpha, + const int32_t beta + ); + +// 1x64 int8o32 kernel +void lpgemm_rowvar_u8s8s32o32_1x64 + ( + const dim_t k0, + const uint8_t* a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t* b, + const dim_t rs_b, + const dim_t cs_b, + int32_t* c, + const dim_t rs_c, + const int32_t alpha, + const int32_t beta + ); + +#endif //BLIS_GEMM_INT8_MFRINGE diff --git a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_m_fringe_amd512vnni.c b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_m_fringe_amd512vnni.c new file mode 100644 index 0000000000..e02c2cce89 --- /dev/null +++ b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_m_fringe_amd512vnni.c @@ -0,0 +1,1354 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include + +#include "blis.h" +#include "lpgemm_m_fringe.h" + +// 5x64 int8o32 kernel +void lpgemm_rowvar_u8s8s32o32_5x64 + ( + const dim_t k0, + const uint8_t* a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t* b, + const dim_t rs_b, + const dim_t cs_b, + int32_t* c, + const dim_t rs_c, + const int32_t alpha, + const int32_t beta + ) +{ + dim_t k_full_pieces = k0 / 4; + dim_t k_partial_pieces = k0 % 4; + + uint32_t a_kfringe_buf = 0; + + // B matrix storage. + __m512i b0; + __m512i b1; + __m512i b2; + __m512i b3; + + // A matrix storage. + __m512i a_int32_0; + __m512i a_int32_1; + + // Registers to use for accumulating C. + __m512i c_int32_0p0 = _mm512_setzero_epi32(); + __m512i c_int32_0p1 = _mm512_setzero_epi32(); + __m512i c_int32_0p2 = _mm512_setzero_epi32(); + __m512i c_int32_0p3 = _mm512_setzero_epi32(); + + __m512i c_int32_1p0 = _mm512_setzero_epi32(); + __m512i c_int32_1p1 = _mm512_setzero_epi32(); + __m512i c_int32_1p2 = _mm512_setzero_epi32(); + __m512i c_int32_1p3 = _mm512_setzero_epi32(); + + __m512i c_int32_2p0 = _mm512_setzero_epi32(); + __m512i c_int32_2p1 = _mm512_setzero_epi32(); + __m512i c_int32_2p2 = _mm512_setzero_epi32(); + __m512i c_int32_2p3 = _mm512_setzero_epi32(); + + __m512i c_int32_3p0 = _mm512_setzero_epi32(); + __m512i c_int32_3p1 = _mm512_setzero_epi32(); + __m512i c_int32_3p2 = _mm512_setzero_epi32(); + __m512i c_int32_3p3 = _mm512_setzero_epi32(); + + __m512i c_int32_4p0 = _mm512_setzero_epi32(); + __m512i c_int32_4p1 = _mm512_setzero_epi32(); + __m512i c_int32_4p2 = _mm512_setzero_epi32(); + __m512i c_int32_4p3 = _mm512_setzero_epi32(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + b3 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 3 ) ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-63] = a[0,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + + // Broadcast a[1,kr:kr+4]. + a_int32_1 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); + c_int32_0p3 = _mm512_dpbusd_epi32( c_int32_0p3, a_int32_0, b3 ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-63] = a[1,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_1, b0 ); + + // Broadcast a[2,kr:kr+4]. 
+ a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_1, b1 ); + c_int32_1p2 = _mm512_dpbusd_epi32( c_int32_1p2, a_int32_1, b2 ); + c_int32_1p3 = _mm512_dpbusd_epi32( c_int32_1p3, a_int32_1, b3 ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-63] = a[2,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + + // Broadcast a[3,kr:kr+4]. + a_int32_1 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); + c_int32_2p2 = _mm512_dpbusd_epi32( c_int32_2p2, a_int32_0, b2 ); + c_int32_2p3 = _mm512_dpbusd_epi32( c_int32_2p3, a_int32_0, b3 ); + + // Perform column direction mat-mul with k = 4. + // c[3,0-63] = a[3,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_1, b0 ); + + // Broadcast a[4,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); + + c_int32_3p1 = _mm512_dpbusd_epi32( c_int32_3p1, a_int32_1, b1 ); + c_int32_3p2 = _mm512_dpbusd_epi32( c_int32_3p2, a_int32_1, b2 ); + c_int32_3p3 = _mm512_dpbusd_epi32( c_int32_3p3, a_int32_1, b3 ); + + // Perform column direction mat-mul with k = 4. + // c[4,0-63] = a[4,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_4p0 = _mm512_dpbusd_epi32( c_int32_4p0, a_int32_0, b0 ); + c_int32_4p1 = _mm512_dpbusd_epi32( c_int32_4p1, a_int32_0, b1 ); + c_int32_4p2 = _mm512_dpbusd_epi32( c_int32_4p2, a_int32_0, b2 ); + c_int32_4p3 = _mm512_dpbusd_epi32( c_int32_4p3, a_int32_0, b3 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + b3 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 3 ) ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-63] = a[0,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + + // Broadcast a[1,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_1 = _mm512_set1_epi32( a_kfringe_buf ); + + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); + c_int32_0p3 = _mm512_dpbusd_epi32( c_int32_0p3, a_int32_0, b3 ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-63] = a[1,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_1, b0 ); + + // Broadcast a[2,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_1, b1 ); + c_int32_1p2 = _mm512_dpbusd_epi32( c_int32_1p2, a_int32_1, b2 ); + c_int32_1p3 = _mm512_dpbusd_epi32( c_int32_1p3, a_int32_1, b3 ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-63] = a[2,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + + // Broadcast a[3,kr:kr+4]. 
+ memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_1 = _mm512_set1_epi32( a_kfringe_buf ); + + c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); + c_int32_2p2 = _mm512_dpbusd_epi32( c_int32_2p2, a_int32_0, b2 ); + c_int32_2p3 = _mm512_dpbusd_epi32( c_int32_2p3, a_int32_0, b3 ); + + // Perform column direction mat-mul with k = 4. + // c[3,0-63] = a[3,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_1, b0 ); + + // Broadcast a[4,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 4 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + c_int32_3p1 = _mm512_dpbusd_epi32( c_int32_3p1, a_int32_1, b1 ); + c_int32_3p2 = _mm512_dpbusd_epi32( c_int32_3p2, a_int32_1, b2 ); + c_int32_3p3 = _mm512_dpbusd_epi32( c_int32_3p3, a_int32_1, b3 ); + + // Perform column direction mat-mul with k = 4. + // c[4,0-63] = a[4,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_4p0 = _mm512_dpbusd_epi32( c_int32_4p0, a_int32_0, b0 ); + c_int32_4p1 = _mm512_dpbusd_epi32( c_int32_4p1, a_int32_0, b1 ); + c_int32_4p2 = _mm512_dpbusd_epi32( c_int32_4p2, a_int32_0, b2 ); + c_int32_4p3 = _mm512_dpbusd_epi32( c_int32_4p3, a_int32_0, b3 ); + } + + // Load alpha and beta + __m512i selector1 = _mm512_set1_epi32( alpha ); + __m512i selector2 = _mm512_set1_epi32( beta ); + + // Scale by alpha + c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); + c_int32_0p1 = _mm512_mullo_epi32( selector1, c_int32_0p1 ); + c_int32_0p2 = _mm512_mullo_epi32( selector1, c_int32_0p2 ); + c_int32_0p3 = _mm512_mullo_epi32( selector1, c_int32_0p3 ); + + c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); + c_int32_1p1 = _mm512_mullo_epi32( selector1, c_int32_1p1 ); + c_int32_1p2 = _mm512_mullo_epi32( selector1, c_int32_1p2 ); + c_int32_1p3 = _mm512_mullo_epi32( selector1, c_int32_1p3 ); + + c_int32_2p0 = _mm512_mullo_epi32( selector1, c_int32_2p0 ); + c_int32_2p1 = _mm512_mullo_epi32( selector1, c_int32_2p1 ); + c_int32_2p2 = _mm512_mullo_epi32( selector1, c_int32_2p2 ); + c_int32_2p3 = _mm512_mullo_epi32( selector1, c_int32_2p3 ); + + c_int32_3p0 = _mm512_mullo_epi32( selector1, c_int32_3p0 ); + c_int32_3p1 = _mm512_mullo_epi32( selector1, c_int32_3p1 ); + c_int32_3p2 = _mm512_mullo_epi32( selector1, c_int32_3p2 ); + c_int32_3p3 = _mm512_mullo_epi32( selector1, c_int32_3p3 ); + + c_int32_4p0 = _mm512_mullo_epi32( selector1, c_int32_4p0 ); + c_int32_4p1 = _mm512_mullo_epi32( selector1, c_int32_4p1 ); + c_int32_4p2 = _mm512_mullo_epi32( selector1, c_int32_4p2 ); + c_int32_4p3 = _mm512_mullo_epi32( selector1, c_int32_4p3 ); + + // Scale C by beta. 
+ if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p1 = _mm512_add_epi32( selector1, c_int32_0p1 ); + + // c[0,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p2 = _mm512_add_epi32( selector1, c_int32_0p2 ); + + // c[0,48-63] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 3*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p3 = _mm512_add_epi32( selector1, c_int32_0p3 ); + + // c[1,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[1,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p1 = _mm512_add_epi32( selector1, c_int32_1p1 ); + + // c[1,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p2 = _mm512_add_epi32( selector1, c_int32_1p2 ); + + // c[1,48-63] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 3*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p3 = _mm512_add_epi32( selector1, c_int32_1p3 ); + + // c[2,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[2,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p1 = _mm512_add_epi32( selector1, c_int32_2p1 ); + + // c[2,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p2 = _mm512_add_epi32( selector1, c_int32_2p2 ); + + // c[2,48-63] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 3*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p3 = _mm512_add_epi32( selector1, c_int32_2p3 ); + + // c[3,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 3 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); + + // c[3,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 3 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_3p1 = _mm512_add_epi32( selector1, c_int32_3p1 ); + + // c[3,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 3 ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_3p2 = _mm512_add_epi32( selector1, c_int32_3p2 ); + + // c[3,48-63] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 3 ) + ( 3*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_3p3 = _mm512_add_epi32( selector1, c_int32_3p3 ); + + // c[4,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 4 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_4p0 = _mm512_add_epi32( selector1, c_int32_4p0 ); + + // c[4,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 4 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( 
selector2, selector1 ); + c_int32_4p1 = _mm512_add_epi32( selector1, c_int32_4p1 ); + + // c[4,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 4 ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_4p2 = _mm512_add_epi32( selector1, c_int32_4p2 ); + + // c[4,48-63] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 4 ) + ( 3*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_4p3 = _mm512_add_epi32( selector1, c_int32_4p3 ); + } + + // Store the results. + // c[0,0-15] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); + + // c[0, 16-31] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 1*16 ), c_int32_0p1 ); + + // c[0,32-47] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 2*16 ), c_int32_0p2 ); + + // c[0,48-63] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 3*16 ), c_int32_0p3 ); + + // c[1,0-15] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 0*16 ), c_int32_1p0 ); + + // c[1,16-31] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 1*16 ), c_int32_1p1 ); + + // c[1,32-47] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 2*16 ), c_int32_1p2 ); + + // c[1,48-63] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 3*16 ), c_int32_1p3 ); + + // c[2,0-15] + _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 0*16 ), c_int32_2p0 ); + + // c[2,16-31] + _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 1*16 ), c_int32_2p1 ); + + // c[2,32-47] + _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 2*16 ), c_int32_2p2 ); + + // c[2,48-63] + _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 3*16 ), c_int32_2p3 ); + + // c[3,0-15] + _mm512_storeu_epi32( c + ( rs_c * 3 ) + ( 0*16 ), c_int32_3p0 ); + + // c[3,16-31] + _mm512_storeu_epi32( c + ( rs_c * 3 ) + ( 1*16 ), c_int32_3p1 ); + + // c[3,32-47] + _mm512_storeu_epi32( c + ( rs_c * 3 ) + ( 2*16 ), c_int32_3p2 ); + + // c[3,48-63] + _mm512_storeu_epi32( c + ( rs_c * 3 ) + ( 3*16 ), c_int32_3p3 ); + + // c[4,0-15] + _mm512_storeu_epi32( c + ( rs_c * 4 ) + ( 0*16 ), c_int32_4p0 ); + + // c[4,16-31] + _mm512_storeu_epi32( c + ( rs_c * 4 ) + ( 1*16 ), c_int32_4p1 ); + + // c[4,32-47] + _mm512_storeu_epi32( c + ( rs_c * 4 ) + ( 2*16 ), c_int32_4p2 ); + + // c[4,48-63] + _mm512_storeu_epi32( c + ( rs_c * 4 ) + ( 3*16 ), c_int32_4p3 ); +} + +// 4x64 int8o32 kernel +void lpgemm_rowvar_u8s8s32o32_4x64 + ( + const dim_t k0, + const uint8_t* a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t* b, + const dim_t rs_b, + const dim_t cs_b, + int32_t* c, + const dim_t rs_c, + const int32_t alpha, + const int32_t beta + ) +{ + dim_t k_full_pieces = k0 / 4; + dim_t k_partial_pieces = k0 % 4; + + uint32_t a_kfringe_buf = 0; + + // B matrix storage. + __m512i b0; + __m512i b1; + __m512i b2; + __m512i b3; + + // A matrix storage. + __m512i a_int32_0; + __m512i a_int32_1; + + // Registers to use for accumulating C. 
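+ // Accumulator naming: c_int32_RpP holds the int32 results for output
+ // row R and 16-column panel P, so this 4x64 tile keeps 4x4 = 16 zmm
+ // accumulators live alongside the A broadcasts and B panel loads.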
+ __m512i c_int32_0p0 = _mm512_setzero_epi32(); + __m512i c_int32_0p1 = _mm512_setzero_epi32(); + __m512i c_int32_0p2 = _mm512_setzero_epi32(); + __m512i c_int32_0p3 = _mm512_setzero_epi32(); + + __m512i c_int32_1p0 = _mm512_setzero_epi32(); + __m512i c_int32_1p1 = _mm512_setzero_epi32(); + __m512i c_int32_1p2 = _mm512_setzero_epi32(); + __m512i c_int32_1p3 = _mm512_setzero_epi32(); + + __m512i c_int32_2p0 = _mm512_setzero_epi32(); + __m512i c_int32_2p1 = _mm512_setzero_epi32(); + __m512i c_int32_2p2 = _mm512_setzero_epi32(); + __m512i c_int32_2p3 = _mm512_setzero_epi32(); + + __m512i c_int32_3p0 = _mm512_setzero_epi32(); + __m512i c_int32_3p1 = _mm512_setzero_epi32(); + __m512i c_int32_3p2 = _mm512_setzero_epi32(); + __m512i c_int32_3p3 = _mm512_setzero_epi32(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + b3 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 3 ) ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-63] = a[0,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + + // Broadcast a[1,kr:kr+4]. + a_int32_1 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); + c_int32_0p3 = _mm512_dpbusd_epi32( c_int32_0p3, a_int32_0, b3 ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-63] = a[1,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_1, b0 ); + + // Broadcast a[2,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_1, b1 ); + c_int32_1p2 = _mm512_dpbusd_epi32( c_int32_1p2, a_int32_1, b2 ); + c_int32_1p3 = _mm512_dpbusd_epi32( c_int32_1p3, a_int32_1, b3 ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-63] = a[2,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + + // Broadcast a[3,kr:kr+4]. + a_int32_1 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); + c_int32_2p2 = _mm512_dpbusd_epi32( c_int32_2p2, a_int32_0, b2 ); + c_int32_2p3 = _mm512_dpbusd_epi32( c_int32_2p3, a_int32_0, b3 ); + + // Perform column direction mat-mul with k = 4. + // c[3,0-63] = a[3,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_1, b0 ); + c_int32_3p1 = _mm512_dpbusd_epi32( c_int32_3p1, a_int32_1, b1 ); + c_int32_3p2 = _mm512_dpbusd_epi32( c_int32_3p2, a_int32_1, b2 ); + c_int32_3p3 = _mm512_dpbusd_epi32( c_int32_3p3, a_int32_1, b3 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. 
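+ // Fewer than 4 k-bytes remain, so each row's leftover 1-3 A bytes are
+ // staged through the zero-initialized a_kfringe_buf word before the
+ // broadcast; the zero padding contributes nothing to the dot products.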
+ memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + b3 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 3 ) ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-63] = a[0,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + + // Broadcast a[1,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_1 = _mm512_set1_epi32( a_kfringe_buf ); + + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); + c_int32_0p3 = _mm512_dpbusd_epi32( c_int32_0p3, a_int32_0, b3 ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-63] = a[1,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_1, b0 ); + + // Broadcast a[2,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_1, b1 ); + c_int32_1p2 = _mm512_dpbusd_epi32( c_int32_1p2, a_int32_1, b2 ); + c_int32_1p3 = _mm512_dpbusd_epi32( c_int32_1p3, a_int32_1, b3 ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-63] = a[2,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + + // Broadcast a[3,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_1 = _mm512_set1_epi32( a_kfringe_buf ); + + c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); + c_int32_2p2 = _mm512_dpbusd_epi32( c_int32_2p2, a_int32_0, b2 ); + c_int32_2p3 = _mm512_dpbusd_epi32( c_int32_2p3, a_int32_0, b3 ); + + // Perform column direction mat-mul with k = 4. 
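+ // Each _mm512_dpbusd_epi32 (vpdpbusd) multiplies the 4 unsigned A bytes
+ // broadcast into every 32-bit lane with the 4 corresponding signed B
+ // bytes of that lane and adds the summed products into the matching
+ // int32 accumulator lane.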
+ // c[3,0-63] = a[3,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_1, b0 ); + c_int32_3p1 = _mm512_dpbusd_epi32( c_int32_3p1, a_int32_1, b1 ); + c_int32_3p2 = _mm512_dpbusd_epi32( c_int32_3p2, a_int32_1, b2 ); + c_int32_3p3 = _mm512_dpbusd_epi32( c_int32_3p3, a_int32_1, b3 ); + } + + // Load alpha and beta + __m512i selector1 = _mm512_set1_epi32( alpha ); + __m512i selector2 = _mm512_set1_epi32( beta ); + + // Scale by alpha + c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); + c_int32_0p1 = _mm512_mullo_epi32( selector1, c_int32_0p1 ); + c_int32_0p2 = _mm512_mullo_epi32( selector1, c_int32_0p2 ); + c_int32_0p3 = _mm512_mullo_epi32( selector1, c_int32_0p3 ); + + c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); + c_int32_1p1 = _mm512_mullo_epi32( selector1, c_int32_1p1 ); + c_int32_1p2 = _mm512_mullo_epi32( selector1, c_int32_1p2 ); + c_int32_1p3 = _mm512_mullo_epi32( selector1, c_int32_1p3 ); + + c_int32_2p0 = _mm512_mullo_epi32( selector1, c_int32_2p0 ); + c_int32_2p1 = _mm512_mullo_epi32( selector1, c_int32_2p1 ); + c_int32_2p2 = _mm512_mullo_epi32( selector1, c_int32_2p2 ); + c_int32_2p3 = _mm512_mullo_epi32( selector1, c_int32_2p3 ); + + c_int32_3p0 = _mm512_mullo_epi32( selector1, c_int32_3p0 ); + c_int32_3p1 = _mm512_mullo_epi32( selector1, c_int32_3p1 ); + c_int32_3p2 = _mm512_mullo_epi32( selector1, c_int32_3p2 ); + c_int32_3p3 = _mm512_mullo_epi32( selector1, c_int32_3p3 ); + + // Scale C by beta. + if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p1 = _mm512_add_epi32( selector1, c_int32_0p1 ); + + // c[0,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p2 = _mm512_add_epi32( selector1, c_int32_0p2 ); + + // c[0,48-63] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 3*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p3 = _mm512_add_epi32( selector1, c_int32_0p3 ); + + // c[1,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[1,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p1 = _mm512_add_epi32( selector1, c_int32_1p1 ); + + // c[1,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p2 = _mm512_add_epi32( selector1, c_int32_1p2 ); + + // c[1,48-63] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 3*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p3 = _mm512_add_epi32( selector1, c_int32_1p3 ); + + // c[2,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[2,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p1 = _mm512_add_epi32( selector1, c_int32_2p1 ); + + // c[2,32-47] + selector1 = 
_mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p2 = _mm512_add_epi32( selector1, c_int32_2p2 ); + + // c[2,48-63] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 3*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p3 = _mm512_add_epi32( selector1, c_int32_2p3 ); + + // c[3,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 3 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); + + // c[3,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 3 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_3p1 = _mm512_add_epi32( selector1, c_int32_3p1 ); + + // c[3,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 3 ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_3p2 = _mm512_add_epi32( selector1, c_int32_3p2 ); + + // c[3,48-63] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 3 ) + ( 3*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_3p3 = _mm512_add_epi32( selector1, c_int32_3p3 ); + } + + // Store the results. + // c[0,0-15] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); + + // c[0, 16-31] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 1*16 ), c_int32_0p1 ); + + // c[0,32-47] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 2*16 ), c_int32_0p2 ); + + // c[0,48-63] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 3*16 ), c_int32_0p3 ); + + // c[1,0-15] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 0*16 ), c_int32_1p0 ); + + // c[1,16-31] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 1*16 ), c_int32_1p1 ); + + // c[1,32-47] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 2*16 ), c_int32_1p2 ); + + // c[1,48-63] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 3*16 ), c_int32_1p3 ); + + // c[2,0-15] + _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 0*16 ), c_int32_2p0 ); + + // c[2,16-31] + _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 1*16 ), c_int32_2p1 ); + + // c[2,32-47] + _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 2*16 ), c_int32_2p2 ); + + // c[2,48-63] + _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 3*16 ), c_int32_2p3 ); + + // c[3,0-15] + _mm512_storeu_epi32( c + ( rs_c * 3 ) + ( 0*16 ), c_int32_3p0 ); + + // c[3,16-31] + _mm512_storeu_epi32( c + ( rs_c * 3 ) + ( 1*16 ), c_int32_3p1 ); + + // c[3,32-47] + _mm512_storeu_epi32( c + ( rs_c * 3 ) + ( 2*16 ), c_int32_3p2 ); + + // c[3,48-63] + _mm512_storeu_epi32( c + ( rs_c * 3 ) + ( 3*16 ), c_int32_3p3 ); +} + +// 3x64 int8o32 kernel +void lpgemm_rowvar_u8s8s32o32_3x64 + ( + const dim_t k0, + const uint8_t* a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t* b, + const dim_t rs_b, + const dim_t cs_b, + int32_t* c, + const dim_t rs_c, + const int32_t alpha, + const int32_t beta + ) +{ + dim_t k_full_pieces = k0 / 4; + dim_t k_partial_pieces = k0 % 4; + + uint32_t a_kfringe_buf = 0; + + // Registers to use for accumulating C. 
+ __m512i c_int32_0p0 = _mm512_setzero_epi32(); + __m512i c_int32_0p1 = _mm512_setzero_epi32(); + __m512i c_int32_0p2 = _mm512_setzero_epi32(); + __m512i c_int32_0p3 = _mm512_setzero_epi32(); + + __m512i c_int32_1p0 = _mm512_setzero_epi32(); + __m512i c_int32_1p1 = _mm512_setzero_epi32(); + __m512i c_int32_1p2 = _mm512_setzero_epi32(); + __m512i c_int32_1p3 = _mm512_setzero_epi32(); + + __m512i c_int32_2p0 = _mm512_setzero_epi32(); + __m512i c_int32_2p1 = _mm512_setzero_epi32(); + __m512i c_int32_2p2 = _mm512_setzero_epi32(); + __m512i c_int32_2p3 = _mm512_setzero_epi32(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + __m512i a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + __m512i b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + __m512i b2 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + __m512i b3 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 3 ) ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-63] = a[0,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + + // Broadcast a[1,kr:kr+4]. + __m512i a_int32_1 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); + c_int32_0p3 = _mm512_dpbusd_epi32( c_int32_0p3, a_int32_0, b3 ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-63] = a[1,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_1, b0 ); + + // Broadcast a[2,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_1, b1 ); + c_int32_1p2 = _mm512_dpbusd_epi32( c_int32_1p2, a_int32_1, b2 ); + c_int32_1p3 = _mm512_dpbusd_epi32( c_int32_1p3, a_int32_1, b3 ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-63] = a[2,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); + c_int32_2p2 = _mm512_dpbusd_epi32( c_int32_2p2, a_int32_0, b2 ); + c_int32_2p3 = _mm512_dpbusd_epi32( c_int32_2p3, a_int32_0, b3 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + __m512i a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + __m512i b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + __m512i b2 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + __m512i b3 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 3 ) ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-63] = a[0,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + + // Broadcast a[1,kr:kr+4]. 
+ memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + __m512i a_int32_1 = _mm512_set1_epi32( a_kfringe_buf ); + + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); + c_int32_0p3 = _mm512_dpbusd_epi32( c_int32_0p3, a_int32_0, b3 ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-63] = a[1,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_1, b0 ); + + // Broadcast a[2,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_1, b1 ); + c_int32_1p2 = _mm512_dpbusd_epi32( c_int32_1p2, a_int32_1, b2 ); + c_int32_1p3 = _mm512_dpbusd_epi32( c_int32_1p3, a_int32_1, b3 ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-63] = a[2,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); + c_int32_2p2 = _mm512_dpbusd_epi32( c_int32_2p2, a_int32_0, b2 ); + c_int32_2p3 = _mm512_dpbusd_epi32( c_int32_2p3, a_int32_0, b3 ); + } + + // Load alpha and beta + __m512i selector1 = _mm512_set1_epi32( alpha ); + __m512i selector2 = _mm512_set1_epi32( beta ); + + // Scale by alpha + c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); + c_int32_0p1 = _mm512_mullo_epi32( selector1, c_int32_0p1 ); + c_int32_0p2 = _mm512_mullo_epi32( selector1, c_int32_0p2 ); + c_int32_0p3 = _mm512_mullo_epi32( selector1, c_int32_0p3 ); + + c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); + c_int32_1p1 = _mm512_mullo_epi32( selector1, c_int32_1p1 ); + c_int32_1p2 = _mm512_mullo_epi32( selector1, c_int32_1p2 ); + c_int32_1p3 = _mm512_mullo_epi32( selector1, c_int32_1p3 ); + + c_int32_2p0 = _mm512_mullo_epi32( selector1, c_int32_2p0 ); + c_int32_2p1 = _mm512_mullo_epi32( selector1, c_int32_2p1 ); + c_int32_2p2 = _mm512_mullo_epi32( selector1, c_int32_2p2 ); + c_int32_2p3 = _mm512_mullo_epi32( selector1, c_int32_2p3 ); + + // Scale C by beta. 
+ if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p1 = _mm512_add_epi32( selector1, c_int32_0p1 ); + + // c[0,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p2 = _mm512_add_epi32( selector1, c_int32_0p2 ); + + // c[0,48-63] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 3*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p3 = _mm512_add_epi32( selector1, c_int32_0p3 ); + + // c[1,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[1,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p1 = _mm512_add_epi32( selector1, c_int32_1p1 ); + + // c[1,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p2 = _mm512_add_epi32( selector1, c_int32_1p2 ); + + // c[1,48-63] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 3*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p3 = _mm512_add_epi32( selector1, c_int32_1p3 ); + + // c[2,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[2,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p1 = _mm512_add_epi32( selector1, c_int32_2p1 ); + + // c[2,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p2 = _mm512_add_epi32( selector1, c_int32_2p2 ); + + // c[2,48-63] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 3*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p3 = _mm512_add_epi32( selector1, c_int32_2p3 ); + } + + // Store the results. 
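+ // Unmasked 512-bit stores are safe here: the 64-column kernels in this
+ // file are dispatched only for full 64-column tiles, while narrower
+ // tiles are handled by the fringe kernels declared later in this patch.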
+ // c[0,0-15] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); + + // c[0, 16-31] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 1*16 ), c_int32_0p1 ); + + // c[0,32-47] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 2*16 ), c_int32_0p2 ); + + // c[0,48-63] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 3*16 ), c_int32_0p3 ); + + // c[1,0-15] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 0*16 ), c_int32_1p0 ); + + // c[1,16-31] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 1*16 ), c_int32_1p1 ); + + // c[1,32-47] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 2*16 ), c_int32_1p2 ); + + // c[1,48-63] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 3*16 ), c_int32_1p3 ); + + // c[2,0-15] + _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 0*16 ), c_int32_2p0 ); + + // c[2,16-31] + _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 1*16 ), c_int32_2p1 ); + + // c[2,32-47] + _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 2*16 ), c_int32_2p2 ); + + // c[2,48-63] + _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 3*16 ), c_int32_2p3 ); +} + +// 2x64 int8o32 kernel +void lpgemm_rowvar_u8s8s32o32_2x64 + ( + const dim_t k0, + const uint8_t* a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t* b, + const dim_t rs_b, + const dim_t cs_b, + int32_t* c, + const dim_t rs_c, + const int32_t alpha, + const int32_t beta + ) +{ + dim_t k_full_pieces = k0 / 4; + dim_t k_partial_pieces = k0 % 4; + + uint32_t a_kfringe_buf = 0; + + // Registers to use for accumulating C. + __m512i c_int32_0p0 = _mm512_setzero_epi32(); + __m512i c_int32_0p1 = _mm512_setzero_epi32(); + __m512i c_int32_0p2 = _mm512_setzero_epi32(); + __m512i c_int32_0p3 = _mm512_setzero_epi32(); + + __m512i c_int32_1p0 = _mm512_setzero_epi32(); + __m512i c_int32_1p1 = _mm512_setzero_epi32(); + __m512i c_int32_1p2 = _mm512_setzero_epi32(); + __m512i c_int32_1p3 = _mm512_setzero_epi32(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + __m512i a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + __m512i b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + __m512i b2 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + __m512i b3 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 3 ) ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-63] = a[0,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + + // Broadcast a[1,kr:kr+4]. + __m512i a_int32_1 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); + c_int32_0p3 = _mm512_dpbusd_epi32( c_int32_0p3, a_int32_0, b3 ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-63] = a[1,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_1, b0 ); + c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_1, b1 ); + c_int32_1p2 = _mm512_dpbusd_epi32( c_int32_1p2, a_int32_1, b2 ); + c_int32_1p3 = _mm512_dpbusd_epi32( c_int32_1p3, a_int32_1, b3 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. 
+ memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + __m512i a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + __m512i b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + __m512i b2 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + __m512i b3 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 3 ) ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-63] = a[0,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + + // Broadcast a[1,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + __m512i a_int32_1 = _mm512_set1_epi32( a_kfringe_buf ); + + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); + c_int32_0p3 = _mm512_dpbusd_epi32( c_int32_0p3, a_int32_0, b3 ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-63] = a[1,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_1, b0 ); + c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_1, b1 ); + c_int32_1p2 = _mm512_dpbusd_epi32( c_int32_1p2, a_int32_1, b2 ); + c_int32_1p3 = _mm512_dpbusd_epi32( c_int32_1p3, a_int32_1, b3 ); + } + + // Load alpha and beta + __m512i selector1 = _mm512_set1_epi32( alpha ); + __m512i selector2 = _mm512_set1_epi32( beta ); + + // Scale by alpha + c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); + c_int32_0p1 = _mm512_mullo_epi32( selector1, c_int32_0p1 ); + c_int32_0p2 = _mm512_mullo_epi32( selector1, c_int32_0p2 ); + c_int32_0p3 = _mm512_mullo_epi32( selector1, c_int32_0p3 ); + + c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); + c_int32_1p1 = _mm512_mullo_epi32( selector1, c_int32_1p1 ); + c_int32_1p2 = _mm512_mullo_epi32( selector1, c_int32_1p2 ); + c_int32_1p3 = _mm512_mullo_epi32( selector1, c_int32_1p3 ); + + // Scale C by beta. 
+ if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p1 = _mm512_add_epi32( selector1, c_int32_0p1 ); + + // c[0,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p2 = _mm512_add_epi32( selector1, c_int32_0p2 ); + + // c[0,48-63] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 3*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p3 = _mm512_add_epi32( selector1, c_int32_0p3 ); + + // c[1,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[1,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p1 = _mm512_add_epi32( selector1, c_int32_1p1 ); + + // c[1,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p2 = _mm512_add_epi32( selector1, c_int32_1p2 ); + + // c[1,48-63] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 3*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p3 = _mm512_add_epi32( selector1, c_int32_1p3 ); + } + + // Store the results. + // c[0,0-15] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); + + // c[0, 16-31] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 1*16 ), c_int32_0p1 ); + + // c[0,32-47] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 2*16 ), c_int32_0p2 ); + + // c[0,48-63] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 3*16 ), c_int32_0p3 ); + + // c[1,0-15] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 0*16 ), c_int32_1p0 ); + + // c[1,16-31] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 1*16 ), c_int32_1p1 ); + + // c[1,32-47] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 2*16 ), c_int32_1p2 ); + + // c[1,48-63] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 3*16 ), c_int32_1p3 ); +} + +// 1x64 int8o32 kernel +void lpgemm_rowvar_u8s8s32o32_1x64 + ( + const dim_t k0, + const uint8_t* a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t* b, + const dim_t rs_b, + const dim_t cs_b, + int32_t* c, + const dim_t rs_c, + const int32_t alpha, + const int32_t beta + ) +{ + dim_t k_full_pieces = k0 / 4; + dim_t k_partial_pieces = k0 % 4; + + uint32_t a_kfringe_buf = 0; + + // Registers to use for accumulating C. + __m512i c_int32_0p0 = _mm512_setzero_epi32(); + __m512i c_int32_0p1 = _mm512_setzero_epi32(); + __m512i c_int32_0p2 = _mm512_setzero_epi32(); + __m512i c_int32_0p3 = _mm512_setzero_epi32(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr] + __m512i a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + __m512i b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + __m512i b2 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + __m512i b3 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 3 ) ); + + // Perform column direction mat-mul with k = 4. 
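+ // With a single output row, the one A broadcast held in a_int32_0
+ // feeds all four 16-column B panels of this k-group.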
+ // c[0,0-63] = a[0,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); + c_int32_0p3 = _mm512_dpbusd_epi32( c_int32_0p3, a_int32_0, b3 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + __m512i a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + __m512i b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + __m512i b2 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + __m512i b3 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 3 ) ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-63] = a[0,kr:kr+4]*b[kr:kr+4,0-63] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); + c_int32_0p3 = _mm512_dpbusd_epi32( c_int32_0p3, a_int32_0, b3 ); + } + + // Load alpha and beta + __m512i selector1 = _mm512_set1_epi32( alpha ); + __m512i selector2 = _mm512_set1_epi32( beta ); + + // Scale by alpha + c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); + c_int32_0p1 = _mm512_mullo_epi32( selector1, c_int32_0p1 ); + c_int32_0p2 = _mm512_mullo_epi32( selector1, c_int32_0p2 ); + c_int32_0p3 = _mm512_mullo_epi32( selector1, c_int32_0p3 ); + + // Scale C by beta. + if ( beta != 0) + { + // c[0,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p1 = _mm512_add_epi32( selector1, c_int32_0p1 ); + + // c[0,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p2 = _mm512_add_epi32( selector1, c_int32_0p2 ); + + // c[0,48-63] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 3*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p3 = _mm512_add_epi32( selector1, c_int32_0p3 ); + } + + // Store the accumulated results. + // c[0,0-15] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); + + // c[0, 16-31] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 1*16 ), c_int32_0p1 ); + + // c[0,32-47] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 2*16 ), c_int32_0p2 ); + + // c[0,48-63] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 3*16 ), c_int32_0p3 ); +} diff --git a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_mn_fringe.h b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_mn_fringe.h new file mode 100644 index 0000000000..008f254b11 --- /dev/null +++ b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_mn_fringe.h @@ -0,0 +1,363 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. 
+ + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_GEMM_INT8_MNFRINGE +#define BLIS_GEMM_INT8_MNFRINGE + +// 5xlt16 int8o32 fringe kernel +void lpgemm_rowvar_u8s8s32o32_5xlt16 + ( + const dim_t k0, + const uint8_t* a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t* b, + const dim_t rs_b, + const dim_t cs_b, + int32_t* c, + const dim_t rs_c, + const int32_t alpha, + const int32_t beta, + const dim_t n0_rem + ); + +// 4xlt16 int8o32 fringe kernel +void lpgemm_rowvar_u8s8s32o32_4xlt16 + ( + const dim_t k0, + const uint8_t* a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t* b, + const dim_t rs_b, + const dim_t cs_b, + int32_t* c, + const dim_t rs_c, + const int32_t alpha, + const int32_t beta, + const dim_t n0_rem + ); + +// 3xlt16 int8o32 fringe kernel +void lpgemm_rowvar_u8s8s32o32_3xlt16 + ( + const dim_t k0, + const uint8_t* a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t* b, + const dim_t rs_b, + const dim_t cs_b, + int32_t* c, + const dim_t rs_c, + const int32_t alpha, + const int32_t beta, + const dim_t n0_rem + ); + +// 2xlt16 int8o32 fringe kernel +void lpgemm_rowvar_u8s8s32o32_2xlt16 + ( + const dim_t k0, + const uint8_t* a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t* b, + const dim_t rs_b, + const dim_t cs_b, + int32_t* c, + const dim_t rs_c, + const int32_t alpha, + const int32_t beta, + const dim_t n0_rem + ); + +// 1xlt16 int8o32 fringe kernel +void lpgemm_rowvar_u8s8s32o32_1xlt16 + ( + const dim_t k0, + const uint8_t* a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t* b, + const dim_t rs_b, + const dim_t cs_b, + int32_t* c, + const dim_t rs_c, + const int32_t alpha, + const int32_t beta, + const dim_t n0_rem + ); + +// 5x16 int8o32 kernel +void lpgemm_rowvar_u8s8s32o32_5x16 + ( + const dim_t k0, + const uint8_t* a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t* b, + const dim_t rs_b, + const dim_t cs_b, + int32_t* c, + const dim_t rs_c, + const int32_t alpha, + const int32_t beta + ); + +// 4x16 int8o32 kernel +void lpgemm_rowvar_u8s8s32o32_4x16 + ( + const dim_t k0, + const uint8_t* a, + const dim_t rs_a, 
+ const dim_t cs_a, + const int8_t* b, + const dim_t rs_b, + const dim_t cs_b, + int32_t* c, + const dim_t rs_c, + const int32_t alpha, + const int32_t beta + ); + +// 3x16 int8o32 kernel +void lpgemm_rowvar_u8s8s32o32_3x16 + ( + const dim_t k0, + const uint8_t* a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t* b, + const dim_t rs_b, + const dim_t cs_b, + int32_t* c, + const dim_t rs_c, + const int32_t alpha, + const int32_t beta + ); + +// 2x16 int8o32 kernel +void lpgemm_rowvar_u8s8s32o32_2x16 + ( + const dim_t k0, + const uint8_t* a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t* b, + const dim_t rs_b, + const dim_t cs_b, + int32_t* c, + const dim_t rs_c, + const int32_t alpha, + const int32_t beta + ); + +// 1x16 int8o32 kernel +void lpgemm_rowvar_u8s8s32o32_1x16 + ( + const dim_t k0, + const uint8_t* a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t* b, + const dim_t rs_b, + const dim_t cs_b, + int32_t* c, + const dim_t rs_c, + const int32_t alpha, + const int32_t beta + ); + +// 5x32 int8o32 kernel +void lpgemm_rowvar_u8s8s32o32_5x32 + ( + const dim_t k0, + const uint8_t* a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t* b, + const dim_t rs_b, + const dim_t cs_b, + int32_t* c, + const dim_t rs_c, + const int32_t alpha, + const int32_t beta + ); + +// 4x32 int8o32 kernel +void lpgemm_rowvar_u8s8s32o32_4x32 + ( + const dim_t k0, + const uint8_t* a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t* b, + const dim_t rs_b, + const dim_t cs_b, + int32_t* c, + const dim_t rs_c, + const int32_t alpha, + const int32_t beta + ); + +// 3x32 int8o32 kernel +void lpgemm_rowvar_u8s8s32o32_3x32 + ( + const dim_t k0, + const uint8_t* a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t* b, + const dim_t rs_b, + const dim_t cs_b, + int32_t* c, + const dim_t rs_c, + const int32_t alpha, + const int32_t beta + ); + +// 2x32 int8o32 kernel +void lpgemm_rowvar_u8s8s32o32_2x32 + ( + const dim_t k0, + const uint8_t* a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t* b, + const dim_t rs_b, + const dim_t cs_b, + int32_t* c, + const dim_t rs_c, + const int32_t alpha, + const int32_t beta + ); + +// 1x32 int8o32 kernel +void lpgemm_rowvar_u8s8s32o32_1x32 + ( + const dim_t k0, + const uint8_t* a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t* b, + const dim_t rs_b, + const dim_t cs_b, + int32_t* c, + const dim_t rs_c, + const int32_t alpha, + const int32_t beta + ); + +// 5x48 int8o32 kernel +void lpgemm_rowvar_u8s8s32o32_5x48 + ( + const dim_t k0, + const uint8_t* a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t* b, + const dim_t rs_b, + const dim_t cs_b, + int32_t* c, + const dim_t rs_c, + const int32_t alpha, + const int32_t beta + ); + +// 4x48 int8o32 kernel +void lpgemm_rowvar_u8s8s32o32_4x48 + ( + const dim_t k0, + const uint8_t* a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t* b, + const dim_t rs_b, + const dim_t cs_b, + int32_t* c, + const dim_t rs_c, + const int32_t alpha, + const int32_t beta + ); + +// 3x48 int8o32 kernel +void lpgemm_rowvar_u8s8s32o32_3x48 + ( + const dim_t k0, + const uint8_t* a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t* b, + const dim_t rs_b, + const dim_t cs_b, + int32_t* c, + const dim_t rs_c, + const int32_t alpha, + const int32_t beta + ); + +// 2x48 int8o32 kernel +void lpgemm_rowvar_u8s8s32o32_2x48 + ( + const dim_t k0, + const uint8_t* a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t* b, + const dim_t rs_b, + const dim_t cs_b, + int32_t* c, + const dim_t rs_c, + const 
int32_t alpha,
+ const int32_t beta
+ );
+
+// 1x48 int8o32 kernel
+void lpgemm_rowvar_u8s8s32o32_1x48
+ (
+ const dim_t k0,
+ const uint8_t* a,
+ const dim_t rs_a,
+ const dim_t cs_a,
+ const int8_t* b,
+ const dim_t rs_b,
+ const dim_t cs_b,
+ int32_t* c,
+ const dim_t rs_c,
+ const int32_t alpha,
+ const int32_t beta
+ );
+
+#endif //BLIS_GEMM_INT8_MNFRINGE
diff --git a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_mn_fringe_amd512vnni.c b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_mn_fringe_amd512vnni.c
new file mode 100644
index 0000000000..427205292b
--- /dev/null
+++ b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_mn_fringe_amd512vnni.c
@@ -0,0 +1,3200 @@
+/*
+
+ BLIS
+ An object-based framework for developing high-performance BLAS-like
+ libraries.
+
+ Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.
+
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions are
+ met:
+ - Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+ - Neither the name(s) of the copyright holder(s) nor the names of its
+ contributors may be used to endorse or promote products derived
+ from this software without specific prior written permission.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+*/
+
+#include <immintrin.h>
+#include <string.h>
+
+#include "blis.h"
+#include "lpgemm_mn_fringe.h"
+
+// 5xlt16 int8o32 fringe kernel
+void lpgemm_rowvar_u8s8s32o32_5xlt16
+ (
+ const dim_t k0,
+ const uint8_t* a,
+ const dim_t rs_a,
+ const dim_t cs_a,
+ const int8_t* b,
+ const dim_t rs_b,
+ const dim_t cs_b,
+ int32_t* c,
+ const dim_t rs_c,
+ const int32_t alpha,
+ const int32_t beta,
+ const dim_t n0_rem
+ )
+{
+ dim_t k_full_pieces = k0 / 4;
+ dim_t k_partial_pieces = k0 % 4;
+
+ uint32_t a_kfringe_buf = 0;
+
+ // For corner cases.
+ int32_t buf0[16];
+ int32_t buf1[16];
+ int32_t buf2[16];
+ int32_t buf3[16];
+ int32_t buf4[16];
+
+ {
+ // Registers to use for accumulating C.
+ __m512i c_int32_0p0 = _mm512_setzero_epi32();
+
+ __m512i c_int32_1p0 = _mm512_setzero_epi32();
+
+ __m512i c_int32_2p0 = _mm512_setzero_epi32();
+
+ __m512i c_int32_3p0 = _mm512_setzero_epi32();
+
+ __m512i c_int32_4p0 = _mm512_setzero_epi32();
+
+ for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 )
+ {
+ __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) );
+
+ // Broadcast a[0,kr:kr+4].
+ __m512i a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) );
+
+ // Perform column direction mat-mul with k = 4.
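+ // Unlike the 64-column kernels above, this <16-column fringe kernel
+ // loads a single 16-element B panel (b0) per k-group, so one
+ // accumulator per output row is sufficient.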
+ // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + + // Broadcast a[1,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-15] = a[1,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + + // Broadcast a[2,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-15] = a[2,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + + // Broadcast a[3,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[3,0-15] = a[3,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); + + // Broadcast a[4,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[4,0-15] = a[4,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_4p0 = _mm512_dpbusd_epi32( c_int32_4p0, a_int32_0, b0 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + __m512i a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + + // Broadcast a[1,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-15] = a[1,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + + // Broadcast a[2,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-15] = a[2,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + + // Broadcast a[3,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[3,0-15] = a[3,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); + + // Broadcast a[4,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 4 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. 
+ // c[4,0-15] = a[4,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_4p0 = _mm512_dpbusd_epi32( c_int32_4p0, a_int32_0, b0 ); + } + + // Load alpha and beta + __m512i selector1 = _mm512_set1_epi32( alpha ); + __m512i selector2 = _mm512_set1_epi32( beta ); + + // Scale by alpha + c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); + + c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); + + c_int32_2p0 = _mm512_mullo_epi32( selector1, c_int32_2p0 ); + + c_int32_3p0 = _mm512_mullo_epi32( selector1, c_int32_3p0 ); + + c_int32_4p0 = _mm512_mullo_epi32( selector1, c_int32_4p0 ); + + // Scale C by beta. + if ( beta != 0 ) + { + memcpy( buf0, ( c + ( rs_c * 0 ) ), ( n0_rem * sizeof( int32_t ) ) ); + memcpy( buf1, ( c + ( rs_c * 1 ) ), ( n0_rem * sizeof( int32_t ) ) ); + memcpy( buf2, ( c + ( rs_c * 2 ) ), ( n0_rem * sizeof( int32_t ) ) ); + memcpy( buf3, ( c + ( rs_c * 3 ) ), ( n0_rem * sizeof( int32_t ) ) ); + memcpy( buf4, ( c + ( rs_c * 4 ) ), ( n0_rem * sizeof( int32_t ) ) ); + + // c[0,0-15] + selector1 = _mm512_loadu_epi32( buf0 ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[1,0-15] + selector1 = _mm512_loadu_epi32( buf1 ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[2,0-15] + selector1 = _mm512_loadu_epi32( buf2 ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[3,0-15] + selector1 = _mm512_loadu_epi32( buf3 ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); + + // c[4,0-15] + selector1 = _mm512_loadu_epi32( buf4 ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_4p0 = _mm512_add_epi32( selector1, c_int32_4p0 ); + } + + // Store the results. + // c[0,0-15] + _mm512_storeu_epi32( buf0, c_int32_0p0 ); + + // c[1,0-15] + _mm512_storeu_epi32( buf1, c_int32_1p0 ); + + // c[2,0-15] + _mm512_storeu_epi32( buf2, c_int32_2p0 ); + + // c[3,0-15] + _mm512_storeu_epi32( buf3, c_int32_3p0 ); + + // c[4,0-15] + _mm512_storeu_epi32( buf4, c_int32_4p0 ); + + // Memcpy partial parts. + // c[0,0-15] + memcpy( c + ( rs_c * 0 ) + ( 0*16 ), buf0, ( n0_rem * sizeof( int32_t ) ) ); + + // c[1,0-15] + memcpy( c + ( rs_c * 1 ) + ( 0*16 ), buf1, ( n0_rem * sizeof( int32_t ) ) ); + + // c[2,0-15] + memcpy( c + ( rs_c * 2 ) + ( 0*16 ), buf2, ( n0_rem * sizeof( int32_t ) ) ); + + // c[3,0-15] + memcpy( c + ( rs_c * 3 ) + ( 0*16 ), buf3, ( n0_rem * sizeof( int32_t ) ) ); + + // c[4,0-15] + memcpy( c + ( rs_c * 4 ) + ( 0*16 ), buf4, ( n0_rem * sizeof( int32_t ) ) ); + } +} + +// 4xlt16 int8o32 fringe kernel +void lpgemm_rowvar_u8s8s32o32_4xlt16 + ( + const dim_t k0, + const uint8_t* a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t* b, + const dim_t rs_b, + const dim_t cs_b, + int32_t* c, + const dim_t rs_c, + const int32_t alpha, + const int32_t beta, + const dim_t n0_rem + ) +{ + dim_t k_full_pieces = k0 / 4; + dim_t k_partial_pieces = k0 % 4; + + uint32_t a_kfringe_buf = 0; + + // For corner cases. + int32_t buf0[16]; + int32_t buf1[16]; + int32_t buf2[16]; + int32_t buf3[16]; + + { + // Registers to use for accumulating C. 
+ __m512i c_int32_0p0 = _mm512_setzero_epi32(); + + __m512i c_int32_1p0 = _mm512_setzero_epi32(); + + __m512i c_int32_2p0 = _mm512_setzero_epi32(); + + __m512i c_int32_3p0 = _mm512_setzero_epi32(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + __m512i a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + + // Broadcast a[1,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-15] = a[1,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + + // Broadcast a[2,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-15] = a[2,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + + // Broadcast a[3,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[3,0-15] = a[3,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + __m512i a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + + // Broadcast a[1,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-15] = a[1,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + + // Broadcast a[2,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-15] = a[2,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + + // Broadcast a[3,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[3,0-15] = a[3,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); + } + + // Load alpha and beta + __m512i selector1 = _mm512_set1_epi32( alpha ); + __m512i selector2 = _mm512_set1_epi32( beta ); + + // Scale by alpha + c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); + + c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); + + c_int32_2p0 = _mm512_mullo_epi32( selector1, c_int32_2p0 ); + + c_int32_3p0 = _mm512_mullo_epi32( selector1, c_int32_3p0 ); + + + // Scale C by beta. 
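+ // Only n0_rem (< 16) columns remain, so each C row is first memcpy'd
+ // into a 16-int32 stack buffer; the full 512-bit loads and stores then
+ // operate on the buffers and only the first n0_rem values are copied
+ // back to C, keeping every vector access in bounds.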
+ if ( beta != 0 ) + { + memcpy( buf0, ( c + ( rs_c * 0 ) ), ( n0_rem * sizeof( int32_t ) ) ); + memcpy( buf1, ( c + ( rs_c * 1 ) ), ( n0_rem * sizeof( int32_t ) ) ); + memcpy( buf2, ( c + ( rs_c * 2 ) ), ( n0_rem * sizeof( int32_t ) ) ); + memcpy( buf3, ( c + ( rs_c * 3 ) ), ( n0_rem * sizeof( int32_t ) ) ); + + // c[0,0-15] + selector1 = _mm512_loadu_epi32( buf0 ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[1,0-15] + selector1 = _mm512_loadu_epi32( buf1 ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[2,0-15] + selector1 = _mm512_loadu_epi32( buf2 ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[3,0-15] + selector1 = _mm512_loadu_epi32( buf3 ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); + } + + // Store the results. + // c[0,0-15] + _mm512_storeu_epi32( buf0, c_int32_0p0 ); + + // c[1,0-15] + _mm512_storeu_epi32( buf1, c_int32_1p0 ); + + // c[2,0-15] + _mm512_storeu_epi32( buf2, c_int32_2p0 ); + + // c[3,0-15] + _mm512_storeu_epi32( buf3, c_int32_3p0 ); + + // Memcpy partial parts. + // c[0,0-15] + memcpy( c + ( rs_c * 0 ) + ( 0*16 ), buf0, ( n0_rem * sizeof( int32_t ) ) ); + + // c[1,0-15] + memcpy( c + ( rs_c * 1 ) + ( 0*16 ), buf1, ( n0_rem * sizeof( int32_t ) ) ); + + // c[2,0-15] + memcpy( c + ( rs_c * 2 ) + ( 0*16 ), buf2, ( n0_rem * sizeof( int32_t ) ) ); + + // c[3,0-15] + memcpy( c + ( rs_c * 3 ) + ( 0*16 ), buf3, ( n0_rem * sizeof( int32_t ) ) ); + } +} + +// 3xlt16 int8o32 fringe kernel +void lpgemm_rowvar_u8s8s32o32_3xlt16 + ( + const dim_t k0, + const uint8_t* a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t* b, + const dim_t rs_b, + const dim_t cs_b, + int32_t* c, + const dim_t rs_c, + const int32_t alpha, + const int32_t beta, + const dim_t n0_rem + ) +{ + dim_t k_full_pieces = k0 / 4; + dim_t k_partial_pieces = k0 % 4; + + uint32_t a_kfringe_buf = 0; + + // For corner cases. + int32_t buf0[16]; + int32_t buf1[16]; + int32_t buf2[16]; + + { + // Registers to use for accumulating C. + __m512i c_int32_0p0 = _mm512_setzero_epi32(); + + __m512i c_int32_1p0 = _mm512_setzero_epi32(); + + __m512i c_int32_2p0 = _mm512_setzero_epi32(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + __m512i a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + + // Broadcast a[1,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-15] = a[1,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + + // Broadcast a[2,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-15] = a[2,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + } + // Handle k remainder. 
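+		// (k fringe: the trailing 1-3 A values of each row are memcpy'd into
+		//  a_kfringe_buf, which was zero-initialized above; its untouched upper
+		//  bytes stay zero and therefore add nothing to the vpdpbusd
+		//  accumulation.)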
+ if ( k_partial_pieces > 0 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + __m512i a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + + // Broadcast a[1,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-15] = a[1,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + + // Broadcast a[2,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-15] = a[2,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + } + + // Load alpha and beta + __m512i selector1 = _mm512_set1_epi32( alpha ); + __m512i selector2 = _mm512_set1_epi32( beta ); + + // Scale by alpha + c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); + + c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); + + c_int32_2p0 = _mm512_mullo_epi32( selector1, c_int32_2p0 ); + + // Scale C by beta. + if ( beta != 0 ) + { + memcpy( buf0, ( c + ( rs_c * 0 ) ), ( n0_rem * sizeof( int32_t ) ) ); + memcpy( buf1, ( c + ( rs_c * 1 ) ), ( n0_rem * sizeof( int32_t ) ) ); + memcpy( buf2, ( c + ( rs_c * 2 ) ), ( n0_rem * sizeof( int32_t ) ) ); + + // c[0,0-15] + selector1 = _mm512_loadu_epi32( buf0 ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[1,0-15] + selector1 = _mm512_loadu_epi32( buf1 ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[2,0-15] + selector1 = _mm512_loadu_epi32( buf2 ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + } + + // Store the results. + // c[0,0-15] + _mm512_storeu_epi32( buf0, c_int32_0p0 ); + + // c[1,0-15] + _mm512_storeu_epi32( buf1, c_int32_1p0 ); + + // c[2,0-15] + _mm512_storeu_epi32( buf2, c_int32_2p0 ); + + // Memcpy partial parts. + // c[0,0-15] + memcpy( c + ( rs_c * 0 ) + ( 0*16 ), buf0, ( n0_rem * sizeof( int32_t ) ) ); + + // c[1,0-15] + memcpy( c + ( rs_c * 1 ) + ( 0*16 ), buf1, ( n0_rem * sizeof( int32_t ) ) ); + + // c[2,0-15] + memcpy( c + ( rs_c * 2 ) + ( 0*16 ), buf2, ( n0_rem * sizeof( int32_t ) ) ); + } +} + +// 2xlt16 int8o32 fringe kernel +void lpgemm_rowvar_u8s8s32o32_2xlt16 + ( + const dim_t k0, + const uint8_t* a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t* b, + const dim_t rs_b, + const dim_t cs_b, + int32_t* c, + const dim_t rs_c, + const int32_t alpha, + const int32_t beta, + const dim_t n0_rem + ) +{ + dim_t k_full_pieces = k0 / 4; + dim_t k_partial_pieces = k0 % 4; + + uint32_t a_kfringe_buf = 0; + + // For corner cases. + int32_t buf0[16]; + int32_t buf1[16]; + + { + // Registers to use for accumulating C. 
+ __m512i c_int32_0p0 = _mm512_setzero_epi32(); + + __m512i c_int32_1p0 = _mm512_setzero_epi32(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + __m512i a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + + // Broadcast a[1,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-15] = a[1,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + __m512i a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + + // Broadcast a[1,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-15] = a[1,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + } + + // Load alpha and beta + __m512i selector1 = _mm512_set1_epi32( alpha ); + __m512i selector2 = _mm512_set1_epi32( beta ); + + // Scale by alpha + c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); + + c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); + + // Scale C by beta. + if ( beta != 0 ) + { + memcpy( buf0, ( c + ( rs_c * 0 ) ), ( n0_rem * sizeof( int32_t ) ) ); + memcpy( buf1, ( c + ( rs_c * 1 ) ), ( n0_rem * sizeof( int32_t ) ) ); + + // c[0,0-15] + selector1 = _mm512_loadu_epi32( buf0 ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[1,0-15] + selector1 = _mm512_loadu_epi32( buf1 ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + } + + // Store the results. + // c[0,0-15] + _mm512_storeu_epi32( buf0, c_int32_0p0 ); + + // c[1,0-15] + _mm512_storeu_epi32( buf1, c_int32_1p0 ); + + // Memcpy partial parts. + // c[0,0-15] + memcpy( c + ( rs_c * 0 ) + ( 0*16 ), buf0, ( n0_rem * sizeof( int32_t ) ) ); + + // c[1,0-15] + memcpy( c + ( rs_c * 1 ) + ( 0*16 ), buf1, ( n0_rem * sizeof( int32_t ) ) ); + } +} + +// 1xlt16 int8o32 fringe kernel +void lpgemm_rowvar_u8s8s32o32_1xlt16 + ( + const dim_t k0, + const uint8_t* a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t* b, + const dim_t rs_b, + const dim_t cs_b, + int32_t* c, + const dim_t rs_c, + const int32_t alpha, + const int32_t beta, + const dim_t n0_rem + ) +{ + dim_t k_full_pieces = k0 / 4; + dim_t k_partial_pieces = k0 % 4; + + uint32_t a_kfringe_buf = 0; + + // For corner cases. + int32_t buf0[16]; + + { + // Registers to use for accumulating C. 
+ __m512i c_int32_0p0 = _mm512_setzero_epi32(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + __m512i a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + __m512i a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + } + + // Load alpha and beta + __m512i selector1 = _mm512_set1_epi32( alpha ); + __m512i selector2 = _mm512_set1_epi32( beta ); + + // Scale by alpha + c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); + + // Scale C by beta. + if ( beta != 0 ) + { + memcpy( buf0, ( c + ( rs_c * 0 ) ), ( n0_rem * sizeof( int32_t ) ) ); + + // c[0,0-15] + selector1 = _mm512_loadu_epi32( buf0 ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + } + + // Store the results. + // c[0,0-15] + _mm512_storeu_epi32( buf0, c_int32_0p0 ); + + // Memcpy partial parts. + // c[0,0-15] + memcpy( c + ( rs_c * 0 ) + ( 0*16 ), buf0, ( n0_rem * sizeof( int32_t ) ) ); + } +} + +// 5x16 int8o32 kernel +void lpgemm_rowvar_u8s8s32o32_5x16 + ( + const dim_t k0, + const uint8_t* a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t* b, + const dim_t rs_b, + const dim_t cs_b, + int32_t* c, + const dim_t rs_c, + const int32_t alpha, + const int32_t beta + ) +{ + dim_t k_full_pieces = k0 / 4; + dim_t k_partial_pieces = k0 % 4; + + uint32_t a_kfringe_buf = 0; + + // Registers to use for accumulating C. + __m512i c_int32_0p0 = _mm512_setzero_epi32(); + + __m512i c_int32_1p0 = _mm512_setzero_epi32(); + + __m512i c_int32_2p0 = _mm512_setzero_epi32(); + + __m512i c_int32_3p0 = _mm512_setzero_epi32(); + + __m512i c_int32_4p0 = _mm512_setzero_epi32(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + __m512i a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + + // Broadcast a[1,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-15] = a[1,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + + // Broadcast a[2,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-15] = a[2,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + + // Broadcast a[3,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. 
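+		// (vpdpbusd semantics: for every 32-bit lane j, the accumulator gains
+		//  the dot product of the four broadcast A bytes (unsigned) with the
+		//  four B bytes of that lane (signed), which is why each iteration
+		//  advances k by 4.)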
+ // c[3,0-15] = a[3,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); + + // Broadcast a[4,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[4,0-15] = a[4,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_4p0 = _mm512_dpbusd_epi32( c_int32_4p0, a_int32_0, b0 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + __m512i a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + + // Broadcast a[1,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-15] = a[1,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + + // Broadcast a[2,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-15] = a[2,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + + // Broadcast a[3,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[3,0-15] = a[3,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); + + // Broadcast a[4,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 4 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[4,0-15] = a[4,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_4p0 = _mm512_dpbusd_epi32( c_int32_4p0, a_int32_0, b0 ); + } + + // Load alpha and beta + __m512i selector1 = _mm512_set1_epi32( alpha ); + __m512i selector2 = _mm512_set1_epi32( beta ); + + // Scale by alpha + c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); + + c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); + + c_int32_2p0 = _mm512_mullo_epi32( selector1, c_int32_2p0 ); + + c_int32_3p0 = _mm512_mullo_epi32( selector1, c_int32_3p0 ); + + c_int32_4p0 = _mm512_mullo_epi32( selector1, c_int32_4p0 ); + + // Scale C by beta. 
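+	// (Unlike the lt16 fringe kernels above, n is exactly 16 here, so C rows
+	//  are loaded and stored with full 512-bit accesses directly and no
+	//  staging buffers are needed.)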
+ if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[1,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[2,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[3,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 3 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); + + // c[4,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 4 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_4p0 = _mm512_add_epi32( selector1, c_int32_4p0 ); + } + + // Store the results. + // c[0,0-15] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); + + // c[1,0-15] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 0*16 ), c_int32_1p0 ); + + // c[2,0-15] + _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 0*16 ), c_int32_2p0 ); + + // c[3,0-15] + _mm512_storeu_epi32( c + ( rs_c * 3 ) + ( 0*16 ), c_int32_3p0 ); + + // c[4,0-15] + _mm512_storeu_epi32( c + ( rs_c * 4 ) + ( 0*16 ), c_int32_4p0 ); +} + +// 4x16 int8o32 kernel +void lpgemm_rowvar_u8s8s32o32_4x16 + ( + const dim_t k0, + const uint8_t* a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t* b, + const dim_t rs_b, + const dim_t cs_b, + int32_t* c, + const dim_t rs_c, + const int32_t alpha, + const int32_t beta + ) +{ + dim_t k_full_pieces = k0 / 4; + dim_t k_partial_pieces = k0 % 4; + + uint32_t a_kfringe_buf = 0; + + // Registers to use for accumulating C. + __m512i c_int32_0p0 = _mm512_setzero_epi32(); + + __m512i c_int32_1p0 = _mm512_setzero_epi32(); + + __m512i c_int32_2p0 = _mm512_setzero_epi32(); + + __m512i c_int32_3p0 = _mm512_setzero_epi32(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + __m512i a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + + // Broadcast a[1,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-15] = a[1,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + + // Broadcast a[2,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-15] = a[2,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + + // Broadcast a[3,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[3,0-15] = a[3,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. 
+ memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + __m512i a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + + // Broadcast a[1,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-15] = a[1,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + + // Broadcast a[2,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-15] = a[2,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + + // Broadcast a[3,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[3,0-15] = a[3,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); + } + + // Load alpha and beta + __m512i selector1 = _mm512_set1_epi32( alpha ); + __m512i selector2 = _mm512_set1_epi32( beta ); + + // Scale by alpha + c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); + + c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); + + c_int32_2p0 = _mm512_mullo_epi32( selector1, c_int32_2p0 ); + + c_int32_3p0 = _mm512_mullo_epi32( selector1, c_int32_3p0 ); + + // Scale C by beta. + if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[1,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[2,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[3,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 3 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); + } + + // Store the results. + // c[0,0-15] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); + + // c[1,0-15] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 0*16 ), c_int32_1p0 ); + + // c[2,0-15] + _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 0*16 ), c_int32_2p0 ); + + // c[3,0-15] + _mm512_storeu_epi32( c + ( rs_c * 3 ) + ( 0*16 ), c_int32_3p0 ); +} + +// 3x16 int8o32 kernel +void lpgemm_rowvar_u8s8s32o32_3x16 + ( + const dim_t k0, + const uint8_t* a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t* b, + const dim_t rs_b, + const dim_t cs_b, + int32_t* c, + const dim_t rs_c, + const int32_t alpha, + const int32_t beta + ) +{ + dim_t k_full_pieces = k0 / 4; + dim_t k_partial_pieces = k0 % 4; + + uint32_t a_kfringe_buf = 0; + + // Registers to use for accumulating C. 
+ __m512i c_int32_0p0 = _mm512_setzero_epi32(); + + __m512i c_int32_1p0 = _mm512_setzero_epi32(); + + __m512i c_int32_2p0 = _mm512_setzero_epi32(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + __m512i a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + + // Broadcast a[1,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-15] = a[1,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + + // Broadcast a[2,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-15] = a[2,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + __m512i a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + + // Broadcast a[1,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-15] = a[1,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + + // Broadcast a[2,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-15] = a[2,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + } + + // Load alpha and beta + __m512i selector1 = _mm512_set1_epi32( alpha ); + __m512i selector2 = _mm512_set1_epi32( beta ); + + // Scale by alpha + c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); + + c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); + + c_int32_2p0 = _mm512_mullo_epi32( selector1, c_int32_2p0 ); + + // Scale C by beta. + if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[1,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[2,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + } + + // Store the results. 
+ // c[0,0-15] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); + + // c[1,0-15] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 0*16 ), c_int32_1p0 ); + + // c[2,0-15] + _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 0*16 ), c_int32_2p0 ); +} + +// 2x16 int8o32 kernel +void lpgemm_rowvar_u8s8s32o32_2x16 + ( + const dim_t k0, + const uint8_t* a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t* b, + const dim_t rs_b, + const dim_t cs_b, + int32_t* c, + const dim_t rs_c, + const int32_t alpha, + const int32_t beta + ) +{ + dim_t k_full_pieces = k0 / 4; + dim_t k_partial_pieces = k0 % 4; + + uint32_t a_kfringe_buf = 0; + + // Registers to use for accumulating C. + __m512i c_int32_0p0 = _mm512_setzero_epi32(); + + __m512i c_int32_1p0 = _mm512_setzero_epi32(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + __m512i a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + + // Broadcast a[1,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-15] = a[1,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + __m512i a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + + // Broadcast a[1,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-15] = a[1,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + } + + // Load alpha and beta + __m512i selector1 = _mm512_set1_epi32( alpha ); + __m512i selector2 = _mm512_set1_epi32( beta ); + + // Scale by alpha + c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); + + c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); + + // Scale C by beta. + if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[1,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + } + + // Store the results. 
+ // c[0,0-15] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); + + // c[1,0-15] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 0*16 ), c_int32_1p0 ); +} + +// 1x16 int8o32 kernel +void lpgemm_rowvar_u8s8s32o32_1x16 + ( + const dim_t k0, + const uint8_t* a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t* b, + const dim_t rs_b, + const dim_t cs_b, + int32_t* c, + const dim_t rs_c, + const int32_t alpha, + const int32_t beta + ) +{ + dim_t k_full_pieces = k0 / 4; + dim_t k_partial_pieces = k0 % 4; + + uint32_t a_kfringe_buf = 0; + + // Registers to use for accumulating C. + __m512i c_int32_0p0 = _mm512_setzero_epi32(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + __m512i a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + __m512i a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + } + + // Load alpha and beta + __m512i selector1 = _mm512_set1_epi32( alpha ); + __m512i selector2 = _mm512_set1_epi32( beta ); + + // Scale by alpha + c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); + + // Scale C by beta. + if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + } + + // Store the results. + // c[0,0-15] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); +} + +// 5x32 int8o32 kernel +void lpgemm_rowvar_u8s8s32o32_5x32 + ( + const dim_t k0, + const uint8_t* a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t* b, + const dim_t rs_b, + const dim_t cs_b, + int32_t* c, + const dim_t rs_c, + const int32_t alpha, + const int32_t beta + ) +{ + dim_t k_full_pieces = k0 / 4; + dim_t k_partial_pieces = k0 % 4; + + uint32_t a_kfringe_buf = 0; + + // Registers to use for accumulating C. + __m512i c_int32_0p0 = _mm512_setzero_epi32(); + __m512i c_int32_0p1 = _mm512_setzero_epi32(); + + __m512i c_int32_1p0 = _mm512_setzero_epi32(); + __m512i c_int32_1p1 = _mm512_setzero_epi32(); + + __m512i c_int32_2p0 = _mm512_setzero_epi32(); + __m512i c_int32_2p1 = _mm512_setzero_epi32(); + + __m512i c_int32_3p0 = _mm512_setzero_epi32(); + __m512i c_int32_3p1 = _mm512_setzero_epi32(); + + __m512i c_int32_4p0 = _mm512_setzero_epi32(); + __m512i c_int32_4p1 = _mm512_setzero_epi32(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + __m512i b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + + // Broadcast a[0,kr:kr+4]. + __m512i a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. 
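+		// (Two 512-bit B panels, b0 and b1, cover columns 0-31, so every
+		//  broadcast A row feeds two vpdpbusd accumulations.)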
+ // c[0,0-31] = a[0,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + + // Broadcast a[1,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-31] = a[1,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_0, b1 ); + + // Broadcast a[2,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-31] = a[2,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); + + // Broadcast a[3,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[3,0-31] = a[3,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); + c_int32_3p1 = _mm512_dpbusd_epi32( c_int32_3p1, a_int32_0, b1 ); + + // Broadcast a[4,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[4,0-31] = a[4,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_4p0 = _mm512_dpbusd_epi32( c_int32_4p0, a_int32_0, b0 ); + c_int32_4p1 = _mm512_dpbusd_epi32( c_int32_4p1, a_int32_0, b1 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + __m512i b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + + // Broadcast a[0,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + __m512i a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-31] = a[0,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + + // Broadcast a[1,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-31] = a[1,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_0, b1 ); + + // Broadcast a[2,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-31] = a[2,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); + + // Broadcast a[3,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. 
+ // c[3,0-31] = a[3,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); + c_int32_3p1 = _mm512_dpbusd_epi32( c_int32_3p1, a_int32_0, b1 ); + + // Broadcast a[4,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 4 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[4,0-31] = a[4,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_4p0 = _mm512_dpbusd_epi32( c_int32_4p0, a_int32_0, b0 ); + c_int32_4p1 = _mm512_dpbusd_epi32( c_int32_4p1, a_int32_0, b1 ); + } + + // Load alpha and beta + __m512i selector1 = _mm512_set1_epi32( alpha ); + __m512i selector2 = _mm512_set1_epi32( beta ); + + // Scale by alpha + c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); + c_int32_0p1 = _mm512_mullo_epi32( selector1, c_int32_0p1 ); + + c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); + c_int32_1p1 = _mm512_mullo_epi32( selector1, c_int32_1p1 ); + + c_int32_2p0 = _mm512_mullo_epi32( selector1, c_int32_2p0 ); + c_int32_2p1 = _mm512_mullo_epi32( selector1, c_int32_2p1 ); + + c_int32_3p0 = _mm512_mullo_epi32( selector1, c_int32_3p0 ); + c_int32_3p1 = _mm512_mullo_epi32( selector1, c_int32_3p1 ); + + c_int32_4p0 = _mm512_mullo_epi32( selector1, c_int32_4p0 ); + c_int32_4p1 = _mm512_mullo_epi32( selector1, c_int32_4p1 ); + + // Scale C by beta. + if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p1 = _mm512_add_epi32( selector1, c_int32_0p1 ); + + // c[1,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[1,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p1 = _mm512_add_epi32( selector1, c_int32_1p1 ); + + // c[2,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[2,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p1 = _mm512_add_epi32( selector1, c_int32_2p1 ); + + // c[3,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 3 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); + + // c[3,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 3 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_3p1 = _mm512_add_epi32( selector1, c_int32_3p1 ); + + // c[4,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 4 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_4p0 = _mm512_add_epi32( selector1, c_int32_4p0 ); + + // c[4,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 4 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_4p1 = _mm512_add_epi32( selector1, c_int32_4p1 ); + } + + // Store the results. 
+ // c[0,0-15] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); + + // c[0, 16-31] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 1*16 ), c_int32_0p1 ); + + // c[1,0-15] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 0*16 ), c_int32_1p0 ); + + // c[1,16-31] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 1*16 ), c_int32_1p1 ); + + // c[2,0-15] + _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 0*16 ), c_int32_2p0 ); + + // c[2,16-31] + _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 1*16 ), c_int32_2p1 ); + + // c[3,0-15] + _mm512_storeu_epi32( c + ( rs_c * 3 ) + ( 0*16 ), c_int32_3p0 ); + + // c[3,16-31] + _mm512_storeu_epi32( c + ( rs_c * 3 ) + ( 1*16 ), c_int32_3p1 ); + + // c[4,0-15] + _mm512_storeu_epi32( c + ( rs_c * 4 ) + ( 0*16 ), c_int32_4p0 ); + + // c[4,16-31] + _mm512_storeu_epi32( c + ( rs_c * 4 ) + ( 1*16 ), c_int32_4p1 ); +} + +// 4x32 int8o32 kernel +void lpgemm_rowvar_u8s8s32o32_4x32 + ( + const dim_t k0, + const uint8_t* a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t* b, + const dim_t rs_b, + const dim_t cs_b, + int32_t* c, + const dim_t rs_c, + const int32_t alpha, + const int32_t beta + ) +{ + dim_t k_full_pieces = k0 / 4; + dim_t k_partial_pieces = k0 % 4; + + uint32_t a_kfringe_buf = 0; + + // Registers to use for accumulating C. + __m512i c_int32_0p0 = _mm512_setzero_epi32(); + __m512i c_int32_0p1 = _mm512_setzero_epi32(); + + __m512i c_int32_1p0 = _mm512_setzero_epi32(); + __m512i c_int32_1p1 = _mm512_setzero_epi32(); + + __m512i c_int32_2p0 = _mm512_setzero_epi32(); + __m512i c_int32_2p1 = _mm512_setzero_epi32(); + + __m512i c_int32_3p0 = _mm512_setzero_epi32(); + __m512i c_int32_3p1 = _mm512_setzero_epi32(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + __m512i b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + + // Broadcast a[0,kr:kr+4]. + __m512i a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-31] = a[0,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + + // Broadcast a[1,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-31] = a[1,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_0, b1 ); + + // Broadcast a[2,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-31] = a[2,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); + + // Broadcast a[3,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[3,0-31] = a[3,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); + c_int32_3p1 = _mm512_dpbusd_epi32( c_int32_3p1, a_int32_0, b1 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + __m512i b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + + // Broadcast a[0,kr:kr+4]. 
+ memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + __m512i a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-31] = a[0,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + + // Broadcast a[1,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-31] = a[1,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_0, b1 ); + + // Broadcast a[2,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-31] = a[2,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); + + // Broadcast a[3,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[3,0-31] = a[3,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); + c_int32_3p1 = _mm512_dpbusd_epi32( c_int32_3p1, a_int32_0, b1 ); + } + + // Load alpha and beta + __m512i selector1 = _mm512_set1_epi32( alpha ); + __m512i selector2 = _mm512_set1_epi32( beta ); + + // Scale by alpha + c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); + c_int32_0p1 = _mm512_mullo_epi32( selector1, c_int32_0p1 ); + + c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); + c_int32_1p1 = _mm512_mullo_epi32( selector1, c_int32_1p1 ); + + c_int32_2p0 = _mm512_mullo_epi32( selector1, c_int32_2p0 ); + c_int32_2p1 = _mm512_mullo_epi32( selector1, c_int32_2p1 ); + + c_int32_3p0 = _mm512_mullo_epi32( selector1, c_int32_3p0 ); + c_int32_3p1 = _mm512_mullo_epi32( selector1, c_int32_3p1 ); + + // Scale C by beta. 
+ if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p1 = _mm512_add_epi32( selector1, c_int32_0p1 ); + + // c[1,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[1,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p1 = _mm512_add_epi32( selector1, c_int32_1p1 ); + + // c[2,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[2,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p1 = _mm512_add_epi32( selector1, c_int32_2p1 ); + + // c[3,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 3 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); + + // c[3,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 3 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_3p1 = _mm512_add_epi32( selector1, c_int32_3p1 ); + } + + // Store the results. + // c[0,0-15] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); + + // c[0, 16-31] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 1*16 ), c_int32_0p1 ); + + // c[1,0-15] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 0*16 ), c_int32_1p0 ); + + // c[1,16-31] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 1*16 ), c_int32_1p1 ); + + // c[2,0-15] + _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 0*16 ), c_int32_2p0 ); + + // c[2,16-31] + _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 1*16 ), c_int32_2p1 ); + + // c[3,0-15] + _mm512_storeu_epi32( c + ( rs_c * 3 ) + ( 0*16 ), c_int32_3p0 ); + + // c[3,16-31] + _mm512_storeu_epi32( c + ( rs_c * 3 ) + ( 1*16 ), c_int32_3p1 ); +} + +// 3x32 int8o32 kernel +void lpgemm_rowvar_u8s8s32o32_3x32 + ( + const dim_t k0, + const uint8_t* a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t* b, + const dim_t rs_b, + const dim_t cs_b, + int32_t* c, + const dim_t rs_c, + const int32_t alpha, + const int32_t beta + ) +{ + dim_t k_full_pieces = k0 / 4; + dim_t k_partial_pieces = k0 % 4; + + uint32_t a_kfringe_buf = 0; + + // Registers to use for accumulating C. + __m512i c_int32_0p0 = _mm512_setzero_epi32(); + __m512i c_int32_0p1 = _mm512_setzero_epi32(); + + __m512i c_int32_1p0 = _mm512_setzero_epi32(); + __m512i c_int32_1p1 = _mm512_setzero_epi32(); + + __m512i c_int32_2p0 = _mm512_setzero_epi32(); + __m512i c_int32_2p1 = _mm512_setzero_epi32(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + __m512i b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + + // Broadcast a[0,kr:kr+4]. + __m512i a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. 
+ // c[0,0-31] = a[0,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + + // Broadcast a[1,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-31] = a[1,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_0, b1 ); + + // Broadcast a[2,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-31] = a[2,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + __m512i b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + + // Broadcast a[0,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + __m512i a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-31] = a[0,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + + // Broadcast a[1,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-31] = a[1,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_0, b1 ); + + // Broadcast a[2,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-31] = a[2,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); + } + + // Load alpha and beta + __m512i selector1 = _mm512_set1_epi32( alpha ); + __m512i selector2 = _mm512_set1_epi32( beta ); + + // Scale by alpha + c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); + c_int32_0p1 = _mm512_mullo_epi32( selector1, c_int32_0p1 ); + + c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); + c_int32_1p1 = _mm512_mullo_epi32( selector1, c_int32_1p1 ); + + c_int32_2p0 = _mm512_mullo_epi32( selector1, c_int32_2p0 ); + c_int32_2p1 = _mm512_mullo_epi32( selector1, c_int32_2p1 ); + + // Scale C by beta. 
+ if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p1 = _mm512_add_epi32( selector1, c_int32_0p1 ); + + // c[1,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[1,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p1 = _mm512_add_epi32( selector1, c_int32_1p1 ); + + // c[2,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[2,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p1 = _mm512_add_epi32( selector1, c_int32_2p1 ); + } + + // Store the results. + // c[0,0-15] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); + + // c[0, 16-31] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 1*16 ), c_int32_0p1 ); + + // c[1,0-15] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 0*16 ), c_int32_1p0 ); + + // c[1,16-31] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 1*16 ), c_int32_1p1 ); + + // c[2,0-15] + _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 0*16 ), c_int32_2p0 ); + + // c[2,16-31] + _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 1*16 ), c_int32_2p1 ); +} + +// 2x32 int8o32 kernel +void lpgemm_rowvar_u8s8s32o32_2x32 + ( + const dim_t k0, + const uint8_t* a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t* b, + const dim_t rs_b, + const dim_t cs_b, + int32_t* c, + const dim_t rs_c, + const int32_t alpha, + const int32_t beta + ) +{ + dim_t k_full_pieces = k0 / 4; + dim_t k_partial_pieces = k0 % 4; + + uint32_t a_kfringe_buf = 0; + + // Registers to use for accumulating C. + __m512i c_int32_0p0 = _mm512_setzero_epi32(); + __m512i c_int32_0p1 = _mm512_setzero_epi32(); + + __m512i c_int32_1p0 = _mm512_setzero_epi32(); + __m512i c_int32_1p1 = _mm512_setzero_epi32(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + __m512i b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + + // Broadcast a[0,kr:kr+4]. + __m512i a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-31] = a[0,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + + // Broadcast a[1,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-31] = a[1,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_0, b1 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + __m512i b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + + // Broadcast a[0,kr:kr+4]. 
+ memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + __m512i a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-31] = a[0,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + + // Broadcast a[1,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-31] = a[1,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_0, b1 ); + } + + // Load alpha and beta + __m512i selector1 = _mm512_set1_epi32( alpha ); + __m512i selector2 = _mm512_set1_epi32( beta ); + + // Scale by alpha + c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); + c_int32_0p1 = _mm512_mullo_epi32( selector1, c_int32_0p1 ); + + c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); + c_int32_1p1 = _mm512_mullo_epi32( selector1, c_int32_1p1 ); + + // Scale C by beta. + if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p1 = _mm512_add_epi32( selector1, c_int32_0p1 ); + + // c[1,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[1,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p1 = _mm512_add_epi32( selector1, c_int32_1p1 ); + } + + // Store the results. + // c[0,0-15] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); + + // c[0, 16-31] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 1*16 ), c_int32_0p1 ); + + // c[1,0-15] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 0*16 ), c_int32_1p0 ); + + // c[1,16-31] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 1*16 ), c_int32_1p1 ); +} + +// 1x32 int8o32 kernel +void lpgemm_rowvar_u8s8s32o32_1x32 + ( + const dim_t k0, + const uint8_t* a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t* b, + const dim_t rs_b, + const dim_t cs_b, + int32_t* c, + const dim_t rs_c, + const int32_t alpha, + const int32_t beta + ) +{ + dim_t k_full_pieces = k0 / 4; + dim_t k_partial_pieces = k0 % 4; + + uint32_t a_kfringe_buf = 0; + + // Registers to use for accumulating C. + __m512i c_int32_0p0 = _mm512_setzero_epi32(); + __m512i c_int32_0p1 = _mm512_setzero_epi32(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + __m512i b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + + // Broadcast a[0,kr:kr+4]. + __m512i a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. 
+ // c[0,0-31] = a[0,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + __m512i b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + + // Broadcast a[0,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + __m512i a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-31] = a[0,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + } + + // Load alpha and beta + __m512i selector1 = _mm512_set1_epi32( alpha ); + __m512i selector2 = _mm512_set1_epi32( beta ); + + // Scale by alpha + c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); + c_int32_0p1 = _mm512_mullo_epi32( selector1, c_int32_0p1 ); + + // Scale C by beta. + if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p1 = _mm512_add_epi32( selector1, c_int32_0p1 ); + } + + // Store the results. + // c[0,0-15] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); + + // c[0, 16-31] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 1*16 ), c_int32_0p1 ); +} + +// 5x48 int8o32 kernel +void lpgemm_rowvar_u8s8s32o32_5x48 + ( + const dim_t k0, + const uint8_t* a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t* b, + const dim_t rs_b, + const dim_t cs_b, + int32_t* c, + const dim_t rs_c, + const int32_t alpha, + const int32_t beta + ) +{ + dim_t k_full_pieces = k0 / 4; + dim_t k_partial_pieces = k0 % 4; + + uint32_t a_kfringe_buf = 0; + + // Registers to use for accumulating C. + __m512i c_int32_0p0 = _mm512_setzero_epi32(); + __m512i c_int32_0p1 = _mm512_setzero_epi32(); + __m512i c_int32_0p2 = _mm512_setzero_epi32(); + + __m512i c_int32_1p0 = _mm512_setzero_epi32(); + __m512i c_int32_1p1 = _mm512_setzero_epi32(); + __m512i c_int32_1p2 = _mm512_setzero_epi32(); + + __m512i c_int32_2p0 = _mm512_setzero_epi32(); + __m512i c_int32_2p1 = _mm512_setzero_epi32(); + __m512i c_int32_2p2 = _mm512_setzero_epi32(); + + __m512i c_int32_3p0 = _mm512_setzero_epi32(); + __m512i c_int32_3p1 = _mm512_setzero_epi32(); + __m512i c_int32_3p2 = _mm512_setzero_epi32(); + + __m512i c_int32_4p0 = _mm512_setzero_epi32(); + __m512i c_int32_4p1 = _mm512_setzero_epi32(); + __m512i c_int32_4p2 = _mm512_setzero_epi32(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + __m512i b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + __m512i b2 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + + // Broadcast a[0,kr:kr+4]. + __m512i a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. 
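+ // Each iteration consumes a 4-deep slice of k for all five rows of A
+ // against the same three 64-byte panels of packed B (b0, b1, b2).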
+ // c[0,0-47] = a[0,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); + + // Broadcast a[1,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-47] = a[1,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_0, b1 ); + c_int32_1p2 = _mm512_dpbusd_epi32( c_int32_1p2, a_int32_0, b2 ); + + // Broadcast a[2,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-47] = a[2,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); + c_int32_2p2 = _mm512_dpbusd_epi32( c_int32_2p2, a_int32_0, b2 ); + + // Broadcast a[3,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[3,0-47] = a[3,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); + c_int32_3p1 = _mm512_dpbusd_epi32( c_int32_3p1, a_int32_0, b1 ); + c_int32_3p2 = _mm512_dpbusd_epi32( c_int32_3p2, a_int32_0, b2 ); + + // Broadcast a[4,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[4,0-47] = a[4,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_4p0 = _mm512_dpbusd_epi32( c_int32_4p0, a_int32_0, b0 ); + c_int32_4p1 = _mm512_dpbusd_epi32( c_int32_4p1, a_int32_0, b1 ); + c_int32_4p2 = _mm512_dpbusd_epi32( c_int32_4p2, a_int32_0, b2 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + __m512i b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + __m512i b2 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + + // Broadcast a[0,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + __m512i a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-47] = a[0,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); + + // Broadcast a[1,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-47] = a[1,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_0, b1 ); + c_int32_1p2 = _mm512_dpbusd_epi32( c_int32_1p2, a_int32_0, b2 ); + + // Broadcast a[2,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. 
+ // c[2,0-47] = a[2,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); + c_int32_2p2 = _mm512_dpbusd_epi32( c_int32_2p2, a_int32_0, b2 ); + + // Broadcast a[3,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[3,0-47] = a[3,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); + c_int32_3p1 = _mm512_dpbusd_epi32( c_int32_3p1, a_int32_0, b1 ); + c_int32_3p2 = _mm512_dpbusd_epi32( c_int32_3p2, a_int32_0, b2 ); + + // Broadcast a[4,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 4 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[4,0-47] = a[4,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_4p0 = _mm512_dpbusd_epi32( c_int32_4p0, a_int32_0, b0 ); + c_int32_4p1 = _mm512_dpbusd_epi32( c_int32_4p1, a_int32_0, b1 ); + c_int32_4p2 = _mm512_dpbusd_epi32( c_int32_4p2, a_int32_0, b2 ); + } + + // Load alpha and beta + __m512i selector1 = _mm512_set1_epi32( alpha ); + __m512i selector2 = _mm512_set1_epi32( beta ); + + // Scale by alpha + c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); + c_int32_0p1 = _mm512_mullo_epi32( selector1, c_int32_0p1 ); + c_int32_0p2 = _mm512_mullo_epi32( selector1, c_int32_0p2 ); + + c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); + c_int32_1p1 = _mm512_mullo_epi32( selector1, c_int32_1p1 ); + c_int32_1p2 = _mm512_mullo_epi32( selector1, c_int32_1p2 ); + + c_int32_2p0 = _mm512_mullo_epi32( selector1, c_int32_2p0 ); + c_int32_2p1 = _mm512_mullo_epi32( selector1, c_int32_2p1 ); + c_int32_2p2 = _mm512_mullo_epi32( selector1, c_int32_2p2 ); + + c_int32_3p0 = _mm512_mullo_epi32( selector1, c_int32_3p0 ); + c_int32_3p1 = _mm512_mullo_epi32( selector1, c_int32_3p1 ); + c_int32_3p2 = _mm512_mullo_epi32( selector1, c_int32_3p2 ); + + c_int32_4p0 = _mm512_mullo_epi32( selector1, c_int32_4p0 ); + c_int32_4p1 = _mm512_mullo_epi32( selector1, c_int32_4p1 ); + c_int32_4p2 = _mm512_mullo_epi32( selector1, c_int32_4p2 ); + + // Scale C by beta. 
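+ // When beta == 0 the existing contents of C are never read; the results
+ // are written out unconditionally below, so C may be uninitialized on entry.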
+ if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p1 = _mm512_add_epi32( selector1, c_int32_0p1 ); + + // c[0,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p2 = _mm512_add_epi32( selector1, c_int32_0p2 ); + + // c[1,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[1,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p1 = _mm512_add_epi32( selector1, c_int32_1p1 ); + + // c[1,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p2 = _mm512_add_epi32( selector1, c_int32_1p2 ); + + // c[2,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[2,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p1 = _mm512_add_epi32( selector1, c_int32_2p1 ); + + // c[2,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p2 = _mm512_add_epi32( selector1, c_int32_2p2 ); + + // c[3,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 3 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); + + // c[3,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 3 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_3p1 = _mm512_add_epi32( selector1, c_int32_3p1 ); + + // c[3,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 3 ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_3p2 = _mm512_add_epi32( selector1, c_int32_3p2 ); + + // c[4,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 4 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_4p0 = _mm512_add_epi32( selector1, c_int32_4p0 ); + + // c[4,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 4 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_4p1 = _mm512_add_epi32( selector1, c_int32_4p1 ); + + // c[4,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 4 ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_4p2 = _mm512_add_epi32( selector1, c_int32_4p2 ); + } + + // Store the results. 
+ // c[0,0-15] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); + + // c[0, 16-31] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 1*16 ), c_int32_0p1 ); + + // c[0,32-47] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 2*16 ), c_int32_0p2 ); + + // c[1,0-15] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 0*16 ), c_int32_1p0 ); + + // c[1,16-31] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 1*16 ), c_int32_1p1 ); + + // c[1,32-47] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 2*16 ), c_int32_1p2 ); + + // c[2,0-15] + _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 0*16 ), c_int32_2p0 ); + + // c[2,16-31] + _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 1*16 ), c_int32_2p1 ); + + // c[2,32-47] + _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 2*16 ), c_int32_2p2 ); + + // c[3,0-15] + _mm512_storeu_epi32( c + ( rs_c * 3 ) + ( 0*16 ), c_int32_3p0 ); + + // c[3,16-31] + _mm512_storeu_epi32( c + ( rs_c * 3 ) + ( 1*16 ), c_int32_3p1 ); + + // c[3,32-47] + _mm512_storeu_epi32( c + ( rs_c * 3 ) + ( 2*16 ), c_int32_3p2 ); + + // c[4,0-15] + _mm512_storeu_epi32( c + ( rs_c * 4 ) + ( 0*16 ), c_int32_4p0 ); + + // c[4,16-31] + _mm512_storeu_epi32( c + ( rs_c * 4 ) + ( 1*16 ), c_int32_4p1 ); + + // c[4,32-47] + _mm512_storeu_epi32( c + ( rs_c * 4 ) + ( 2*16 ), c_int32_4p2 ); +} + +// 4x48 int8o32 kernel +void lpgemm_rowvar_u8s8s32o32_4x48 + ( + const dim_t k0, + const uint8_t* a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t* b, + const dim_t rs_b, + const dim_t cs_b, + int32_t* c, + const dim_t rs_c, + const int32_t alpha, + const int32_t beta + ) +{ + dim_t k_full_pieces = k0 / 4; + dim_t k_partial_pieces = k0 % 4; + + uint32_t a_kfringe_buf = 0; + + // Registers to use for accumulating C. + __m512i c_int32_0p0 = _mm512_setzero_epi32(); + __m512i c_int32_0p1 = _mm512_setzero_epi32(); + __m512i c_int32_0p2 = _mm512_setzero_epi32(); + + __m512i c_int32_1p0 = _mm512_setzero_epi32(); + __m512i c_int32_1p1 = _mm512_setzero_epi32(); + __m512i c_int32_1p2 = _mm512_setzero_epi32(); + + __m512i c_int32_2p0 = _mm512_setzero_epi32(); + __m512i c_int32_2p1 = _mm512_setzero_epi32(); + __m512i c_int32_2p2 = _mm512_setzero_epi32(); + + __m512i c_int32_3p0 = _mm512_setzero_epi32(); + __m512i c_int32_3p1 = _mm512_setzero_epi32(); + __m512i c_int32_3p2 = _mm512_setzero_epi32(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + __m512i b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + __m512i b2 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + + // Broadcast a[0,kr:kr+4]. + __m512i a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-47] = a[0,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); + + // Broadcast a[1,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-47] = a[1,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_0, b1 ); + c_int32_1p2 = _mm512_dpbusd_epi32( c_int32_1p2, a_int32_0, b2 ); + + // Broadcast a[2,kr:kr+4]. 
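+ // (The same a_int32_0 broadcast register is reused for each row of A.)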
+ a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-47] = a[2,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); + c_int32_2p2 = _mm512_dpbusd_epi32( c_int32_2p2, a_int32_0, b2 ); + + // Broadcast a[3,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[3,0-47] = a[3,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); + c_int32_3p1 = _mm512_dpbusd_epi32( c_int32_3p1, a_int32_0, b1 ); + c_int32_3p2 = _mm512_dpbusd_epi32( c_int32_3p2, a_int32_0, b2 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + __m512i b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + __m512i b2 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + + // Broadcast a[0,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + __m512i a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-47] = a[0,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); + + // Broadcast a[1,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-47] = a[1,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_0, b1 ); + c_int32_1p2 = _mm512_dpbusd_epi32( c_int32_1p2, a_int32_0, b2 ); + + // Broadcast a[2,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-47] = a[2,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); + c_int32_2p2 = _mm512_dpbusd_epi32( c_int32_2p2, a_int32_0, b2 ); + + // Broadcast a[3,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. 
+ // c[3,0-47] = a[3,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); + c_int32_3p1 = _mm512_dpbusd_epi32( c_int32_3p1, a_int32_0, b1 ); + c_int32_3p2 = _mm512_dpbusd_epi32( c_int32_3p2, a_int32_0, b2 ); + } + + // Load alpha and beta + __m512i selector1 = _mm512_set1_epi32( alpha ); + __m512i selector2 = _mm512_set1_epi32( beta ); + + // Scale by alpha + c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); + c_int32_0p1 = _mm512_mullo_epi32( selector1, c_int32_0p1 ); + c_int32_0p2 = _mm512_mullo_epi32( selector1, c_int32_0p2 ); + + c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); + c_int32_1p1 = _mm512_mullo_epi32( selector1, c_int32_1p1 ); + c_int32_1p2 = _mm512_mullo_epi32( selector1, c_int32_1p2 ); + + c_int32_2p0 = _mm512_mullo_epi32( selector1, c_int32_2p0 ); + c_int32_2p1 = _mm512_mullo_epi32( selector1, c_int32_2p1 ); + c_int32_2p2 = _mm512_mullo_epi32( selector1, c_int32_2p2 ); + + c_int32_3p0 = _mm512_mullo_epi32( selector1, c_int32_3p0 ); + c_int32_3p1 = _mm512_mullo_epi32( selector1, c_int32_3p1 ); + c_int32_3p2 = _mm512_mullo_epi32( selector1, c_int32_3p2 ); + + // Scale C by beta. + if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p1 = _mm512_add_epi32( selector1, c_int32_0p1 ); + + // c[0,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p2 = _mm512_add_epi32( selector1, c_int32_0p2 ); + + // c[1,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[1,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p1 = _mm512_add_epi32( selector1, c_int32_1p1 ); + + // c[1,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p2 = _mm512_add_epi32( selector1, c_int32_1p2 ); + + // c[2,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[2,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p1 = _mm512_add_epi32( selector1, c_int32_2p1 ); + + // c[2,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p2 = _mm512_add_epi32( selector1, c_int32_2p2 ); + + // c[3,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 3 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); + + // c[3,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 3 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_3p1 = _mm512_add_epi32( selector1, c_int32_3p1 ); + + // c[3,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 3 ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + 
c_int32_3p2 = _mm512_add_epi32( selector1, c_int32_3p2 ); + } + + // Store the results. + // c[0,0-15] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); + + // c[0, 16-31] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 1*16 ), c_int32_0p1 ); + + // c[0,32-47] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 2*16 ), c_int32_0p2 ); + + // c[1,0-15] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 0*16 ), c_int32_1p0 ); + + // c[1,16-31] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 1*16 ), c_int32_1p1 ); + + // c[1,32-47] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 2*16 ), c_int32_1p2 ); + + // c[2,0-15] + _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 0*16 ), c_int32_2p0 ); + + // c[2,16-31] + _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 1*16 ), c_int32_2p1 ); + + // c[2,32-47] + _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 2*16 ), c_int32_2p2 ); + + // c[3,0-15] + _mm512_storeu_epi32( c + ( rs_c * 3 ) + ( 0*16 ), c_int32_3p0 ); + + // c[3,16-31] + _mm512_storeu_epi32( c + ( rs_c * 3 ) + ( 1*16 ), c_int32_3p1 ); + + // c[3,32-47] + _mm512_storeu_epi32( c + ( rs_c * 3 ) + ( 2*16 ), c_int32_3p2 ); +} + +// 3x48 int8o32 kernel +void lpgemm_rowvar_u8s8s32o32_3x48 + ( + const dim_t k0, + const uint8_t* a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t* b, + const dim_t rs_b, + const dim_t cs_b, + int32_t* c, + const dim_t rs_c, + const int32_t alpha, + const int32_t beta + ) +{ + dim_t k_full_pieces = k0 / 4; + dim_t k_partial_pieces = k0 % 4; + + uint32_t a_kfringe_buf = 0; + + // Registers to use for accumulating C. + __m512i c_int32_0p0 = _mm512_setzero_epi32(); + __m512i c_int32_0p1 = _mm512_setzero_epi32(); + __m512i c_int32_0p2 = _mm512_setzero_epi32(); + + __m512i c_int32_1p0 = _mm512_setzero_epi32(); + __m512i c_int32_1p1 = _mm512_setzero_epi32(); + __m512i c_int32_1p2 = _mm512_setzero_epi32(); + + __m512i c_int32_2p0 = _mm512_setzero_epi32(); + __m512i c_int32_2p1 = _mm512_setzero_epi32(); + __m512i c_int32_2p2 = _mm512_setzero_epi32(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + __m512i b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + __m512i b2 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + + // Broadcast a[0,kr:kr+4]. + __m512i a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-47] = a[0,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); + + // Broadcast a[1,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-47] = a[1,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_0, b1 ); + c_int32_1p2 = _mm512_dpbusd_epi32( c_int32_1p2, a_int32_0, b2 ); + + // Broadcast a[2,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-47] = a[2,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); + c_int32_2p2 = _mm512_dpbusd_epi32( c_int32_2p2, a_int32_0, b2 ); + } + // Handle k remainder. 
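+ // A trailing group of fewer than 4 k-elements is handled with the same
+ // vpdpbusd flow, using a zero-padded 32-bit broadcast of the A fringe.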
+ if ( k_partial_pieces > 0 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + __m512i b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + __m512i b2 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + + // Broadcast a[0,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + __m512i a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-47] = a[0,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); + + // Broadcast a[1,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-47] = a[1,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_0, b1 ); + c_int32_1p2 = _mm512_dpbusd_epi32( c_int32_1p2, a_int32_0, b2 ); + + // Broadcast a[2,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-47] = a[2,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); + c_int32_2p2 = _mm512_dpbusd_epi32( c_int32_2p2, a_int32_0, b2 ); + } + + // Load alpha and beta + __m512i selector1 = _mm512_set1_epi32( alpha ); + __m512i selector2 = _mm512_set1_epi32( beta ); + + // Scale by alpha + c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); + c_int32_0p1 = _mm512_mullo_epi32( selector1, c_int32_0p1 ); + c_int32_0p2 = _mm512_mullo_epi32( selector1, c_int32_0p2 ); + + c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); + c_int32_1p1 = _mm512_mullo_epi32( selector1, c_int32_1p1 ); + c_int32_1p2 = _mm512_mullo_epi32( selector1, c_int32_1p2 ); + + c_int32_2p0 = _mm512_mullo_epi32( selector1, c_int32_2p0 ); + c_int32_2p1 = _mm512_mullo_epi32( selector1, c_int32_2p1 ); + c_int32_2p2 = _mm512_mullo_epi32( selector1, c_int32_2p2 ); + + // Scale C by beta. 
+ if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p1 = _mm512_add_epi32( selector1, c_int32_0p1 ); + + // c[0,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p2 = _mm512_add_epi32( selector1, c_int32_0p2 ); + + // c[1,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[1,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p1 = _mm512_add_epi32( selector1, c_int32_1p1 ); + + // c[1,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p2 = _mm512_add_epi32( selector1, c_int32_1p2 ); + + // c[2,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[2,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p1 = _mm512_add_epi32( selector1, c_int32_2p1 ); + + // c[2,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 2 ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p2 = _mm512_add_epi32( selector1, c_int32_2p2 ); + } + + // Store the results. + // c[0,0-15] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); + + // c[0, 16-31] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 1*16 ), c_int32_0p1 ); + + // c[0,32-47] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 2*16 ), c_int32_0p2 ); + + // c[1,0-15] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 0*16 ), c_int32_1p0 ); + + // c[1,16-31] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 1*16 ), c_int32_1p1 ); + + // c[1,32-47] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 2*16 ), c_int32_1p2 ); + + // c[2,0-15] + _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 0*16 ), c_int32_2p0 ); + + // c[2,16-31] + _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 1*16 ), c_int32_2p1 ); + + // c[2,32-47] + _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 2*16 ), c_int32_2p2 ); +} + +// 2x48 int8o32 kernel +void lpgemm_rowvar_u8s8s32o32_2x48 + ( + const dim_t k0, + const uint8_t* a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t* b, + const dim_t rs_b, + const dim_t cs_b, + int32_t* c, + const dim_t rs_c, + const int32_t alpha, + const int32_t beta + ) +{ + dim_t k_full_pieces = k0 / 4; + dim_t k_partial_pieces = k0 % 4; + + uint32_t a_kfringe_buf = 0; + + // Registers to use for accumulating C. 
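+ // 2 rows x 48 columns of C are held in 6 zmm accumulators (3 per row).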
+ __m512i c_int32_0p0 = _mm512_setzero_epi32(); + __m512i c_int32_0p1 = _mm512_setzero_epi32(); + __m512i c_int32_0p2 = _mm512_setzero_epi32(); + + __m512i c_int32_1p0 = _mm512_setzero_epi32(); + __m512i c_int32_1p1 = _mm512_setzero_epi32(); + __m512i c_int32_1p2 = _mm512_setzero_epi32(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + __m512i b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + __m512i b2 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + + // Broadcast a[0,kr:kr+4]. + __m512i a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-47] = a[0,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); + + // Broadcast a[1,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-47] = a[1,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_0, b1 ); + c_int32_1p2 = _mm512_dpbusd_epi32( c_int32_1p2, a_int32_0, b2 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + __m512i b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + __m512i b2 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + + // Broadcast a[0,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + __m512i a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-47] = a[0,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); + + // Broadcast a[1,kr:kr+4]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-47] = a[1,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_0, b1 ); + c_int32_1p2 = _mm512_dpbusd_epi32( c_int32_1p2, a_int32_0, b2 ); + } + + // Load alpha and beta + __m512i selector1 = _mm512_set1_epi32( alpha ); + __m512i selector2 = _mm512_set1_epi32( beta ); + + // Scale by alpha + c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); + c_int32_0p1 = _mm512_mullo_epi32( selector1, c_int32_0p1 ); + c_int32_0p2 = _mm512_mullo_epi32( selector1, c_int32_0p2 ); + + c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); + c_int32_1p1 = _mm512_mullo_epi32( selector1, c_int32_1p1 ); + c_int32_1p2 = _mm512_mullo_epi32( selector1, c_int32_1p2 ); + + // Scale C by beta. 
+ if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p1 = _mm512_add_epi32( selector1, c_int32_0p1 ); + + // c[0,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p2 = _mm512_add_epi32( selector1, c_int32_0p2 ); + + // c[1,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[1,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p1 = _mm512_add_epi32( selector1, c_int32_1p1 ); + + // c[1,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 1 ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p2 = _mm512_add_epi32( selector1, c_int32_1p2 ); + } + + // Store the results. + // c[0,0-15] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); + + // c[0, 16-31] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 1*16 ), c_int32_0p1 ); + + // c[0,32-47] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 2*16 ), c_int32_0p2 ); + + // c[1,0-15] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 0*16 ), c_int32_1p0 ); + + // c[1,16-31] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 1*16 ), c_int32_1p1 ); + + // c[1,32-47] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 2*16 ), c_int32_1p2 ); +} + +// 1x48 int8o32 kernel +void lpgemm_rowvar_u8s8s32o32_1x48 + ( + const dim_t k0, + const uint8_t* a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t* b, + const dim_t rs_b, + const dim_t cs_b, + int32_t* c, + const dim_t rs_c, + const int32_t alpha, + const int32_t beta + ) +{ + dim_t k_full_pieces = k0 / 4; + dim_t k_partial_pieces = k0 % 4; + + uint32_t a_kfringe_buf = 0; + + // Registers to use for accumulating C. + __m512i c_int32_0p0 = _mm512_setzero_epi32(); + __m512i c_int32_0p1 = _mm512_setzero_epi32(); + __m512i c_int32_0p2 = _mm512_setzero_epi32(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + __m512i b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + __m512i b2 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + + // Broadcast a[0,kr:kr+4]. + __m512i a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-47] = a[0,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + __m512i b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + __m512i b2 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + + // Broadcast a[0,kr:kr+4]. 
+ memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); + __m512i a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-47] = a[0,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); + } + + // Load alpha and beta + __m512i selector1 = _mm512_set1_epi32( alpha ); + __m512i selector2 = _mm512_set1_epi32( beta ); + + // Scale by alpha + c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); + c_int32_0p1 = _mm512_mullo_epi32( selector1, c_int32_0p1 ); + c_int32_0p2 = _mm512_mullo_epi32( selector1, c_int32_0p2 ); + + // Scale C by beta. + if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p1 = _mm512_add_epi32( selector1, c_int32_0p1 ); + + // c[0,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p2 = _mm512_add_epi32( selector1, c_int32_0p2 ); + } + + // Store the results. + // c[0,0-15] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); + + // c[0, 16-31] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 1*16 ), c_int32_0p1 ); + + // c[0,32-47] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 2*16 ), c_int32_0p2 ); +} diff --git a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_n_fringe.h b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_n_fringe.h new file mode 100644 index 0000000000..f2aee7c832 --- /dev/null +++ b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_n_fringe.h @@ -0,0 +1,111 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_GEMM_INT8_NFRINGE +#define BLIS_GEMM_INT8_NFRINGE + +// 6xlt16 int8o32 fringe kernel +void lpgemm_rowvar_u8s8s32o32_6xlt16 + ( + const dim_t m0, + const dim_t k0, + const uint8_t* a, + const dim_t rs_a, + const dim_t cs_a, + const dim_t ps_a, + const int8_t* b, + const dim_t rs_b, + const dim_t cs_b, + int32_t* c, + const dim_t rs_c, + const int32_t alpha, + const int32_t beta, + const dim_t n0_rem + ); + +// 6x16 int8o32 fringe kernel +void lpgemm_rowvar_u8s8s32o32_6x16 + ( + const dim_t m0, + const dim_t k0, + const uint8_t* a, + const dim_t rs_a, + const dim_t cs_a, + const dim_t ps_a, + const int8_t* b, + const dim_t rs_b, + const dim_t cs_b, + int32_t* c, + const dim_t rs_c, + const int32_t alpha, + const int32_t beta + ); + +// 6x32 int8o32 fringe kernel +void lpgemm_rowvar_u8s8s32o32_6x32 + ( + const dim_t m0, + const dim_t k0, + const uint8_t* a, + const dim_t rs_a, + const dim_t cs_a, + const dim_t ps_a, + const int8_t* b, + const dim_t rs_b, + const dim_t cs_b, + int32_t* c, + const dim_t rs_c, + const int32_t alpha, + const int32_t beta + ); + +// 6x48 int8o32 fringe kernel +void lpgemm_rowvar_u8s8s32o32_6x48 + ( + const dim_t m0, + const dim_t k0, + const uint8_t* a, + const dim_t rs_a, + const dim_t cs_a, + const dim_t ps_a, + const int8_t* b, + const dim_t rs_b, + const dim_t cs_b, + int32_t* c, + const dim_t rs_c, + const int32_t alpha, + const int32_t beta + ); + +#endif //BLIS_GEMM_INT8_NFRINGE diff --git a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_n_fringe_amd512vnni.c b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_n_fringe_amd512vnni.c new file mode 100644 index 0000000000..1ab6182f17 --- /dev/null +++ b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_n_fringe_amd512vnni.c @@ -0,0 +1,1561 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include + +#include "blis.h" +#include "lpgemm_n_fringe.h" +#include "lpgemm_mn_fringe.h" + +// 6xlt16 int8o32 fringe kernel +void lpgemm_rowvar_u8s8s32o32_6xlt16 + ( + const dim_t m0, + const dim_t k0, + const uint8_t* a, + const dim_t rs_a, + const dim_t cs_a, + const dim_t ps_a, + const int8_t* b, + const dim_t rs_b, + const dim_t cs_b, + int32_t* c, + const dim_t rs_c, + const int32_t alpha, + const int32_t beta, + const dim_t n0_rem + ) +{ + dim_t MR = 6; + dim_t m_full_pieces = m0 / MR; + dim_t m_full_pieces_loop_limit = m_full_pieces * MR; + dim_t m_partial_pieces = m0 % MR; + + dim_t k_full_pieces = k0 / 4; + dim_t k_partial_pieces = k0 % 4; + + uint32_t a_kfringe_buf = 0; + + // For corner cases. + int32_t buf0[16]; + int32_t buf1[16]; + int32_t buf2[16]; + int32_t buf3[16]; + int32_t buf4[16]; + int32_t buf5[16]; + + for ( dim_t ir = 0; ir < m_full_pieces_loop_limit; ir += MR ) + { + // Registers to use for accumulating C. + __m512i c_int32_0p0 = _mm512_setzero_epi32(); + + __m512i c_int32_1p0 = _mm512_setzero_epi32(); + + __m512i c_int32_2p0 = _mm512_setzero_epi32(); + + __m512i c_int32_3p0 = _mm512_setzero_epi32(); + + __m512i c_int32_4p0 = _mm512_setzero_epi32(); + + __m512i c_int32_5p0 = _mm512_setzero_epi32(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + // Load 4 rows with 16 extended elements each from B to 1 ZMM + // registers. It is to be noted that the B matrix is packed for use + // in vnni instructions and each load to ZMM register will have 4 + // elements along k direction and 16 elements across n directions, + // so 4x16 elements to a ZMM register. + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + __m512i a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + + // Broadcast a[1,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-15] = a[1,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + + // Broadcast a[2,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-15] = a[2,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + + // Broadcast a[3,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[3,0-15] = a[3,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); + + // Broadcast a[4,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. 
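+ // In scalar terms, each vpdpbusd update computes, for the logical
+ // (unpacked) view of B and every output column j:
+ //   c[r][j] += (int32)a[r][4*kr+0]*b[4*kr+0][j] + ... + (int32)a[r][4*kr+3]*b[4*kr+3][j]
+ // with u8 x s8 products accumulated into s32 (non-saturating).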
+ // c[4,0-15] = a[4,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_4p0 = _mm512_dpbusd_epi32( c_int32_4p0, a_int32_0, b0 ); + + // Broadcast a[5,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 5 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[5,0-15] = a[5,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_5p0 = _mm512_dpbusd_epi32( c_int32_5p0, a_int32_0, b0 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + __m512i a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + + // Broadcast a[1,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-15] = a[1,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + + // Broadcast a[2,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-15] = a[2,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + + // Broadcast a[3,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[3,0-15] = a[3,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); + + // Broadcast a[4,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 4 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[4,0-15] = a[4,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_4p0 = _mm512_dpbusd_epi32( c_int32_4p0, a_int32_0, b0 ); + + // Broadcast a[5,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 5 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[5,0-15] = a[5,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_5p0 = _mm512_dpbusd_epi32( c_int32_5p0, a_int32_0, b0 ); + } + + // Load alpha and beta + __m512i selector1 = _mm512_set1_epi32( alpha ); + __m512i selector2 = _mm512_set1_epi32( beta ); + + // Scale by alpha + c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); + + c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); + + c_int32_2p0 = _mm512_mullo_epi32( selector1, c_int32_2p0 ); + + c_int32_3p0 = _mm512_mullo_epi32( selector1, c_int32_3p0 ); + + c_int32_4p0 = _mm512_mullo_epi32( selector1, c_int32_4p0 ); + + c_int32_5p0 = _mm512_mullo_epi32( selector1, c_int32_5p0 ); + + // Scale C by beta. 
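+ // n0_rem < 16 here, so each C row is staged through a local 16-element
+ // buffer with memcpy rather than a full 64-byte vector load/store.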
+ if ( beta != 0 ) + { + memcpy( buf0, ( c + ( rs_c * ( ir + 0 ) ) ), ( n0_rem * sizeof( int32_t ) ) ); + memcpy( buf1, ( c + ( rs_c * ( ir + 1 ) ) ), ( n0_rem * sizeof( int32_t ) ) ); + memcpy( buf2, ( c + ( rs_c * ( ir + 2 ) ) ), ( n0_rem * sizeof( int32_t ) ) ); + memcpy( buf3, ( c + ( rs_c * ( ir + 3 ) ) ), ( n0_rem * sizeof( int32_t ) ) ); + memcpy( buf4, ( c + ( rs_c * ( ir + 4 ) ) ), ( n0_rem * sizeof( int32_t ) ) ); + memcpy( buf5, ( c + ( rs_c * ( ir + 5 ) ) ), ( n0_rem * sizeof( int32_t ) ) ); + + // c[0,0-15] + selector1 = _mm512_loadu_epi32( buf0 ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[1,0-15] + selector1 = _mm512_loadu_epi32( buf1 ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[2,0-15] + selector1 = _mm512_loadu_epi32( buf2 ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[3,0-15] + selector1 = _mm512_loadu_epi32( buf3 ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); + + // c[4,0-15] + selector1 = _mm512_loadu_epi32( buf4 ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_4p0 = _mm512_add_epi32( selector1, c_int32_4p0 ); + + // c[5,0-15] + selector1 = _mm512_loadu_epi32( buf5 ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_5p0 = _mm512_add_epi32( selector1, c_int32_5p0 ); + } + + // Store the results. + // c[0,0-15] + _mm512_storeu_epi32( buf0, c_int32_0p0 ); + + // c[1,0-15] + _mm512_storeu_epi32( buf1, c_int32_1p0 ); + + // c[2,0-15] + _mm512_storeu_epi32( buf2, c_int32_2p0 ); + + // c[3,0-15] + _mm512_storeu_epi32( buf3, c_int32_3p0 ); + + // c[4,0-15] + _mm512_storeu_epi32( buf4, c_int32_4p0 ); + + // c[5,0-15] + _mm512_storeu_epi32( buf5, c_int32_5p0 ); + + // Memcpy partial parts. + // c[0,0-15] + memcpy( c + ( rs_c * ( ir + 0 ) ) + ( 0*16 ), buf0, ( n0_rem * sizeof( int32_t ) ) ); + + // c[1,0-15] + memcpy( c + ( rs_c * ( ir + 1 ) ) + ( 0*16 ), buf1, ( n0_rem * sizeof( int32_t ) ) ); + + // c[2,0-15] + memcpy( c + ( rs_c * ( ir + 2 ) ) + ( 0*16 ), buf2, ( n0_rem * sizeof( int32_t ) ) ); + + // c[3,0-15] + memcpy( c + ( rs_c * ( ir + 3 ) ) + ( 0*16 ), buf3, ( n0_rem * sizeof( int32_t ) ) ); + + // c[4,0-15] + memcpy( c + ( rs_c * ( ir + 4 ) ) + ( 0*16 ), buf4, ( n0_rem * sizeof( int32_t ) ) ); + + // c[5,0-15] + memcpy( c + ( rs_c * ( ir + 5 ) ) + ( 0*16 ), buf5, ( n0_rem * sizeof( int32_t ) ) ); + + a = a + ( MR * ps_a ); + } + + if ( m_partial_pieces > 0 ) + { + if ( m_partial_pieces == 5 ) + { + dim_t cs_a_use = ( cs_a == 4 ) ? 4 : ( ( cs_a / 6 ) * 5 ); + lpgemm_rowvar_u8s8s32o32_5xlt16 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, n0_rem + ); + } + else if ( m_partial_pieces == 4 ) + { + dim_t cs_a_use = ( cs_a == 4 ) ? 4 : ( ( cs_a / 6 ) * 4 ); + lpgemm_rowvar_u8s8s32o32_4xlt16 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, n0_rem + ); + } + else if ( m_partial_pieces == 3 ) + { + dim_t cs_a_use = ( cs_a == 4 ) ? 
4 : ( ( cs_a / 6 ) * 3 ); + lpgemm_rowvar_u8s8s32o32_3xlt16 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, n0_rem + ); + } + else if ( m_partial_pieces == 2 ) + { + dim_t cs_a_use = ( cs_a == 4 ) ? 4 : ( ( cs_a / 6 ) * 2 ); + lpgemm_rowvar_u8s8s32o32_2xlt16 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, n0_rem + ); + } + else if ( m_partial_pieces == 1 ) + { + dim_t cs_a_use = ( cs_a == 4 ) ? 4 : ( ( cs_a / 6 ) * 1 ); + lpgemm_rowvar_u8s8s32o32_1xlt16 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, n0_rem + ); + } + } +} + +// 6x16 int8o32 fringe kernel +void lpgemm_rowvar_u8s8s32o32_6x16 + ( + const dim_t m0, + const dim_t k0, + const uint8_t* a, + const dim_t rs_a, + const dim_t cs_a, + const dim_t ps_a, + const int8_t* b, + const dim_t rs_b, + const dim_t cs_b, + int32_t* c, + const dim_t rs_c, + const int32_t alpha, + const int32_t beta + ) +{ + dim_t MR = 6; + dim_t m_full_pieces = m0 / MR; + dim_t m_full_pieces_loop_limit = m_full_pieces * MR; + dim_t m_partial_pieces = m0 % MR; + + dim_t k_full_pieces = k0 / 4; + dim_t k_partial_pieces = k0 % 4; + + uint32_t a_kfringe_buf = 0; + + for ( dim_t ir = 0; ir < m_full_pieces_loop_limit; ir += MR ) + { + // Registers to use for accumulating C. + __m512i c_int32_0p0 = _mm512_setzero_epi32(); + + __m512i c_int32_1p0 = _mm512_setzero_epi32(); + + __m512i c_int32_2p0 = _mm512_setzero_epi32(); + + __m512i c_int32_3p0 = _mm512_setzero_epi32(); + + __m512i c_int32_4p0 = _mm512_setzero_epi32(); + + __m512i c_int32_5p0 = _mm512_setzero_epi32(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + // Load 4 rows with 16 elements each from B to 1 ZMM registers. It + // is to be noted that the B matrix is packed for use in vnni + // instructions and each load to ZMM register will have 4 elements + // along k direction and 16 elements across n directions, so 4x16 + // elements to a ZMM register. + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + __m512i a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + + // Broadcast a[1,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-15] = a[1,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + + // Broadcast a[2,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-15] = a[2,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + + // Broadcast a[3,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[3,0-15] = a[3,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); + + // Broadcast a[4,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. 
+ // c[4,0-15] = a[4,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_4p0 = _mm512_dpbusd_epi32( c_int32_4p0, a_int32_0, b0 ); + + // Broadcast a[5,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 5 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[5,0-15] = a[5,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_5p0 = _mm512_dpbusd_epi32( c_int32_5p0, a_int32_0, b0 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + __m512i a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + + // Broadcast a[1,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-15] = a[1,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + + // Broadcast a[2,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-15] = a[2,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + + // Broadcast a[3,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[3,0-15] = a[3,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); + + // Broadcast a[4,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 4 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[4,0-15] = a[4,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_4p0 = _mm512_dpbusd_epi32( c_int32_4p0, a_int32_0, b0 ); + + // Broadcast a[5,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 5 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[5,0-15] = a[5,kr:kr+4]*b[kr:kr+4,0-15] + c_int32_5p0 = _mm512_dpbusd_epi32( c_int32_5p0, a_int32_0, b0 ); + } + + // Load alpha and beta + __m512i selector1 = _mm512_set1_epi32( alpha ); + __m512i selector2 = _mm512_set1_epi32( beta ); + + // Scale by alpha + c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); + + c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); + + c_int32_2p0 = _mm512_mullo_epi32( selector1, c_int32_2p0 ); + + c_int32_3p0 = _mm512_mullo_epi32( selector1, c_int32_3p0 ); + + c_int32_4p0 = _mm512_mullo_epi32( selector1, c_int32_4p0 ); + + c_int32_5p0 = _mm512_mullo_epi32( selector1, c_int32_5p0 ); + + // Scale C by beta. 
+ if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 0 ) ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[1,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 1 ) ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[2,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 2 ) ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[3,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 3 ) ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); + + // c[4,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 4 ) ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_4p0 = _mm512_add_epi32( selector1, c_int32_4p0 ); + + // c[5,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 5 ) ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_5p0 = _mm512_add_epi32( selector1, c_int32_5p0 ); + } + + // Store the results. + // c[0,0-15] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 0 ) ) + ( 0*16 ), c_int32_0p0 ); + + // c[1,0-15] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 1 ) ) + ( 0*16 ), c_int32_1p0 ); + + // c[2,0-15] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 2 ) ) + ( 0*16 ), c_int32_2p0 ); + + // c[3,0-15] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 3 ) ) + ( 0*16 ), c_int32_3p0 ); + + // c[4,0-15] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 4 ) ) + ( 0*16 ), c_int32_4p0 ); + + // c[5,0-15] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 5 ) ) + ( 0*16 ), c_int32_5p0 ); + + a = a + ( MR * ps_a ); + } + + if ( m_partial_pieces > 0 ) + { + if ( m_partial_pieces == 5 ) + { + dim_t cs_a_use = ( cs_a == 4 ) ? 4 : ( ( cs_a / 6 ) * 5 ); + lpgemm_rowvar_u8s8s32o32_5x16 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta + ); + } + else if ( m_partial_pieces == 4 ) + { + dim_t cs_a_use = ( cs_a == 4 ) ? 4 : ( ( cs_a / 6 ) * 4 ); + lpgemm_rowvar_u8s8s32o32_4x16 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta + ); + } + else if ( m_partial_pieces == 3 ) + { + dim_t cs_a_use = ( cs_a == 4 ) ? 4 : ( ( cs_a / 6 ) * 3 ); + lpgemm_rowvar_u8s8s32o32_3x16 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta + ); + } + else if ( m_partial_pieces == 2 ) + { + dim_t cs_a_use = ( cs_a == 4 ) ? 4 : ( ( cs_a / 6 ) * 2 ); + lpgemm_rowvar_u8s8s32o32_2x16 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta + ); + } + else if ( m_partial_pieces == 1 ) + { + dim_t cs_a_use = ( cs_a == 4 ) ? 
4 : ( ( cs_a / 6 ) * 1 ); + lpgemm_rowvar_u8s8s32o32_1x16 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta + ); + } + } +} + +// 6x32 int8o32 fringe kernel +void lpgemm_rowvar_u8s8s32o32_6x32 + ( + const dim_t m0, + const dim_t k0, + const uint8_t* a, + const dim_t rs_a, + const dim_t cs_a, + const dim_t ps_a, + const int8_t* b, + const dim_t rs_b, + const dim_t cs_b, + int32_t* c, + const dim_t rs_c, + const int32_t alpha, + const int32_t beta + ) +{ + dim_t MR = 6; + dim_t m_full_pieces = m0 / MR; + dim_t m_full_pieces_loop_limit = m_full_pieces * MR; + dim_t m_partial_pieces = m0 % MR; + + dim_t k_full_pieces = k0 / 4; + dim_t k_partial_pieces = k0 % 4; + + uint32_t a_kfringe_buf = 0; + + for ( dim_t ir = 0; ir < m_full_pieces_loop_limit; ir += MR ) + { + // Registers to use for accumulating C. + __m512i c_int32_0p0 = _mm512_setzero_epi32(); + __m512i c_int32_0p1 = _mm512_setzero_epi32(); + + __m512i c_int32_1p0 = _mm512_setzero_epi32(); + __m512i c_int32_1p1 = _mm512_setzero_epi32(); + + __m512i c_int32_2p0 = _mm512_setzero_epi32(); + __m512i c_int32_2p1 = _mm512_setzero_epi32(); + + __m512i c_int32_3p0 = _mm512_setzero_epi32(); + __m512i c_int32_3p1 = _mm512_setzero_epi32(); + + __m512i c_int32_4p0 = _mm512_setzero_epi32(); + __m512i c_int32_4p1 = _mm512_setzero_epi32(); + + __m512i c_int32_5p0 = _mm512_setzero_epi32(); + __m512i c_int32_5p1 = _mm512_setzero_epi32(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + // Load 4 rows with 32 elements each from B to 2 ZMM registers. It + // is to be noted that the B matrix is packed for use in vnni + // instructions and each load to ZMM register will have 4 elements + // along k direction and 16 elements across n directions, so 4x16 + // elements to a ZMM register. + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + __m512i b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + + // Broadcast a[0,kr:kr+4]. + __m512i a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-31] = a[0,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + + // Broadcast a[1,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-31] = a[1,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_0, b1 ); + + // Broadcast a[2,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-31] = a[2,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); + + // Broadcast a[3,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[3,0-31] = a[3,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); + c_int32_3p1 = _mm512_dpbusd_epi32( c_int32_3p1, a_int32_0, b1 ); + + // Broadcast a[4,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. 
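+			// The 4-byte (uint32_t) load above picks up exactly a[row, 4*kr .. 4*kr+3],
+			// since the packed A buffer keeps each row's 4-k group contiguous
+			// (rs_a = 4 bytes between rows, cs_a bytes between k-groups); set1 then
+			// replicates those 4 bytes across all 16 lanes so one A group is paired
+			// with 16 B columns at a time.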
+ // c[4,0-31] = a[4,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_4p0 = _mm512_dpbusd_epi32( c_int32_4p0, a_int32_0, b0 ); + c_int32_4p1 = _mm512_dpbusd_epi32( c_int32_4p1, a_int32_0, b1 ); + + // Broadcast a[5,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 5 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[5,0-31] = a[5,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_5p0 = _mm512_dpbusd_epi32( c_int32_5p0, a_int32_0, b0 ); + c_int32_5p1 = _mm512_dpbusd_epi32( c_int32_5p1, a_int32_0, b1 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + __m512i b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + + // Broadcast a[0,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + __m512i a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-31] = a[0,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + + // Broadcast a[1,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-31] = a[1,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_0, b1 ); + + // Broadcast a[2,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-31] = a[2,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); + + // Broadcast a[3,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[3,0-31] = a[3,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); + c_int32_3p1 = _mm512_dpbusd_epi32( c_int32_3p1, a_int32_0, b1 ); + + // Broadcast a[4,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 4 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[4,0-31] = a[4,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_4p0 = _mm512_dpbusd_epi32( c_int32_4p0, a_int32_0, b0 ); + c_int32_4p1 = _mm512_dpbusd_epi32( c_int32_4p1, a_int32_0, b1 ); + + // Broadcast a[5,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 5 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. 
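+			// Only k_partial_pieces (1..3) bytes of each A row are valid in this tail;
+			// the remaining bytes of a_kfringe_buf stay zero (it is zero-initialized
+			// and memcpy only overwrites the low bytes), so the padded k positions
+			// contribute nothing to the vpdpbusd accumulation.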
+ // c[5,0-31] = a[5,kr:kr+4]*b[kr:kr+4,0-31] + c_int32_5p0 = _mm512_dpbusd_epi32( c_int32_5p0, a_int32_0, b0 ); + c_int32_5p1 = _mm512_dpbusd_epi32( c_int32_5p1, a_int32_0, b1 ); + } + + // Load alpha and beta + __m512i selector1 = _mm512_set1_epi32( alpha ); + __m512i selector2 = _mm512_set1_epi32( beta ); + + // Scale by alpha + c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); + c_int32_0p1 = _mm512_mullo_epi32( selector1, c_int32_0p1 ); + + c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); + c_int32_1p1 = _mm512_mullo_epi32( selector1, c_int32_1p1 ); + + c_int32_2p0 = _mm512_mullo_epi32( selector1, c_int32_2p0 ); + c_int32_2p1 = _mm512_mullo_epi32( selector1, c_int32_2p1 ); + + c_int32_3p0 = _mm512_mullo_epi32( selector1, c_int32_3p0 ); + c_int32_3p1 = _mm512_mullo_epi32( selector1, c_int32_3p1 ); + + c_int32_4p0 = _mm512_mullo_epi32( selector1, c_int32_4p0 ); + c_int32_4p1 = _mm512_mullo_epi32( selector1, c_int32_4p1 ); + + c_int32_5p0 = _mm512_mullo_epi32( selector1, c_int32_5p0 ); + c_int32_5p1 = _mm512_mullo_epi32( selector1, c_int32_5p1 ); + + // Scale C by beta. + if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 0 ) ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 0 ) ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p1 = _mm512_add_epi32( selector1, c_int32_0p1 ); + + // c[1,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 1 ) ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[1,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 1 ) ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p1 = _mm512_add_epi32( selector1, c_int32_1p1 ); + + // c[2,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 2 ) ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[2,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 2 ) ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p1 = _mm512_add_epi32( selector1, c_int32_2p1 ); + + // c[3,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 3 ) ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); + + // c[3,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 3 ) ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_3p1 = _mm512_add_epi32( selector1, c_int32_3p1 ); + + // c[4,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 4 ) ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_4p0 = _mm512_add_epi32( selector1, c_int32_4p0 ); + + // c[4,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 4 ) ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_4p1 = _mm512_add_epi32( selector1, c_int32_4p1 ); + + // c[5,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 5 ) ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_5p0 = _mm512_add_epi32( selector1, c_int32_5p0 ); + + // c[5,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 5 ) ) + ( 1*16 ) ); + selector1 = 
_mm512_mullo_epi32( selector2, selector1 ); + c_int32_5p1 = _mm512_add_epi32( selector1, c_int32_5p1 ); + } + + // Store the results. + // c[0,0-15] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 0 ) ) + ( 0*16 ), c_int32_0p0 ); + + // c[0, 16-31] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 0 ) ) + ( 1*16 ), c_int32_0p1 ); + + // c[1,0-15] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 1 ) ) + ( 0*16 ), c_int32_1p0 ); + + // c[1,16-31] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 1 ) ) + ( 1*16 ), c_int32_1p1 ); + + // c[2,0-15] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 2 ) ) + ( 0*16 ), c_int32_2p0 ); + + // c[2,16-31] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 2 ) ) + ( 1*16 ), c_int32_2p1 ); + + // c[3,0-15] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 3 ) ) + ( 0*16 ), c_int32_3p0 ); + + // c[3,16-31] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 3 ) ) + ( 1*16 ), c_int32_3p1 ); + + // c[4,0-15] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 4 ) ) + ( 0*16 ), c_int32_4p0 ); + + // c[4,16-31] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 4 ) ) + ( 1*16 ), c_int32_4p1 ); + + // c[5,0-15] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 5 ) ) + ( 0*16 ), c_int32_5p0 ); + + // c[5,16-31] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 5 ) ) + ( 1*16 ), c_int32_5p1 ); + + a = a + ( MR * ps_a ); + } + + if ( m_partial_pieces > 0 ) + { + if ( m_partial_pieces == 5 ) + { + dim_t cs_a_use = ( cs_a == 4 ) ? 4 : ( ( cs_a / 6 ) * 5 ); + lpgemm_rowvar_u8s8s32o32_5x32 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta + ); + } + else if ( m_partial_pieces == 4 ) + { + dim_t cs_a_use = ( cs_a == 4 ) ? 4 : ( ( cs_a / 6 ) * 4 ); + lpgemm_rowvar_u8s8s32o32_4x32 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta + ); + } + else if ( m_partial_pieces == 3 ) + { + dim_t cs_a_use = ( cs_a == 4 ) ? 4 : ( ( cs_a / 6 ) * 3 ); + lpgemm_rowvar_u8s8s32o32_3x32 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta + ); + } + else if ( m_partial_pieces == 2 ) + { + dim_t cs_a_use = ( cs_a == 4 ) ? 4 : ( ( cs_a / 6 ) * 2 ); + lpgemm_rowvar_u8s8s32o32_2x32 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta + ); + } + else if ( m_partial_pieces == 1 ) + { + dim_t cs_a_use = ( cs_a == 4 ) ? 4 : ( ( cs_a / 6 ) * 1 ); + lpgemm_rowvar_u8s8s32o32_1x32 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta + ); + } + } +} + +// 6x48 int8o32 fringe kernel +void lpgemm_rowvar_u8s8s32o32_6x48 + ( + const dim_t m0, + const dim_t k0, + const uint8_t* a, + const dim_t rs_a, + const dim_t cs_a, + const dim_t ps_a, + const int8_t* b, + const dim_t rs_b, + const dim_t cs_b, + int32_t* c, + const dim_t rs_c, + const int32_t alpha, + const int32_t beta + ) +{ + dim_t MR = 6; + dim_t m_full_pieces = m0 / MR; + dim_t m_full_pieces_loop_limit = m_full_pieces * MR; + dim_t m_partial_pieces = m0 % MR; + + dim_t k_full_pieces = k0 / 4; + dim_t k_partial_pieces = k0 % 4; + + uint32_t a_kfringe_buf = 0; + + for ( dim_t ir = 0; ir < m_full_pieces_loop_limit; ir += MR ) + { + // Registers to use for accumulating C. 
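+		// A 6x48 tile of C needs 6 rows x 3 ZMM registers (16 int32 lanes each),
+		// i.e. the 18 accumulators c_int32_<row>p<16-column block> below.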
+ __m512i c_int32_0p0 = _mm512_setzero_epi32(); + __m512i c_int32_0p1 = _mm512_setzero_epi32(); + __m512i c_int32_0p2 = _mm512_setzero_epi32(); + + __m512i c_int32_1p0 = _mm512_setzero_epi32(); + __m512i c_int32_1p1 = _mm512_setzero_epi32(); + __m512i c_int32_1p2 = _mm512_setzero_epi32(); + + __m512i c_int32_2p0 = _mm512_setzero_epi32(); + __m512i c_int32_2p1 = _mm512_setzero_epi32(); + __m512i c_int32_2p2 = _mm512_setzero_epi32(); + + __m512i c_int32_3p0 = _mm512_setzero_epi32(); + __m512i c_int32_3p1 = _mm512_setzero_epi32(); + __m512i c_int32_3p2 = _mm512_setzero_epi32(); + + __m512i c_int32_4p0 = _mm512_setzero_epi32(); + __m512i c_int32_4p1 = _mm512_setzero_epi32(); + __m512i c_int32_4p2 = _mm512_setzero_epi32(); + + __m512i c_int32_5p0 = _mm512_setzero_epi32(); + __m512i c_int32_5p1 = _mm512_setzero_epi32(); + __m512i c_int32_5p2 = _mm512_setzero_epi32(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + // Load 4 rows with 48 elements each from B to 3 ZMM registers. It + // is to be noted that the B matrix is packed for use in vnni + // instructions and each load to ZMM register will have 4 elements + // along k direction and 16 elements across n directions, so 4x16 + // elements to a ZMM register. + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + __m512i b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + __m512i b2 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + + // Broadcast a[0,kr:kr+4]. + __m512i a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-47] = a[0,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); + + // Broadcast a[1,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-47] = a[1,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_0, b1 ); + c_int32_1p2 = _mm512_dpbusd_epi32( c_int32_1p2, a_int32_0, b2 ); + + // Broadcast a[2,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-47] = a[2,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); + c_int32_2p2 = _mm512_dpbusd_epi32( c_int32_2p2, a_int32_0, b2 ); + + // Broadcast a[3,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[3,0-47] = a[3,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); + c_int32_3p1 = _mm512_dpbusd_epi32( c_int32_3p1, a_int32_0, b1 ); + c_int32_3p2 = _mm512_dpbusd_epi32( c_int32_3p2, a_int32_0, b2 ); + + // Broadcast a[4,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. 
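+			// Note the register reuse pattern: b0, b1 and b2 are loaded once per
+			// k-group and reused by all 6 rows, with only the 32-bit A broadcast
+			// changing per row, so every B load is amortized over 6 vpdpbusd
+			// instructions for its 16-column block.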
+ // c[4,0-47] = a[4,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_4p0 = _mm512_dpbusd_epi32( c_int32_4p0, a_int32_0, b0 ); + c_int32_4p1 = _mm512_dpbusd_epi32( c_int32_4p1, a_int32_0, b1 ); + c_int32_4p2 = _mm512_dpbusd_epi32( c_int32_4p2, a_int32_0, b2 ); + + // Broadcast a[5,kr:kr+4]. + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 5 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 4. + // c[5,0-47] = a[5,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_5p0 = _mm512_dpbusd_epi32( c_int32_5p0, a_int32_0, b0 ); + c_int32_5p1 = _mm512_dpbusd_epi32( c_int32_5p1, a_int32_0, b1 ); + c_int32_5p2 = _mm512_dpbusd_epi32( c_int32_5p2, a_int32_0, b2 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + __m512i b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + __m512i b2 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + + // Broadcast a[0,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + __m512i a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[0,0-47] = a[0,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); + c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); + c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); + + // Broadcast a[1,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[1,0-47] = a[1,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_1p0 = _mm512_dpbusd_epi32( c_int32_1p0, a_int32_0, b0 ); + c_int32_1p1 = _mm512_dpbusd_epi32( c_int32_1p1, a_int32_0, b1 ); + c_int32_1p2 = _mm512_dpbusd_epi32( c_int32_1p2, a_int32_0, b2 ); + + // Broadcast a[2,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[2,0-47] = a[2,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_2p0 = _mm512_dpbusd_epi32( c_int32_2p0, a_int32_0, b0 ); + c_int32_2p1 = _mm512_dpbusd_epi32( c_int32_2p1, a_int32_0, b1 ); + c_int32_2p2 = _mm512_dpbusd_epi32( c_int32_2p2, a_int32_0, b2 ); + + // Broadcast a[3,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[3,0-47] = a[3,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_3p0 = _mm512_dpbusd_epi32( c_int32_3p0, a_int32_0, b0 ); + c_int32_3p1 = _mm512_dpbusd_epi32( c_int32_3p1, a_int32_0, b1 ); + c_int32_3p2 = _mm512_dpbusd_epi32( c_int32_3p2, a_int32_0, b2 ); + + // Broadcast a[4,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 4 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. 
+ // c[4,0-47] = a[4,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_4p0 = _mm512_dpbusd_epi32( c_int32_4p0, a_int32_0, b0 ); + c_int32_4p1 = _mm512_dpbusd_epi32( c_int32_4p1, a_int32_0, b1 ); + c_int32_4p2 = _mm512_dpbusd_epi32( c_int32_4p2, a_int32_0, b2 ); + + // Broadcast a[5,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 5 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( uint8_t ) ) + ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 4. + // c[5,0-47] = a[5,kr:kr+4]*b[kr:kr+4,0-47] + c_int32_5p0 = _mm512_dpbusd_epi32( c_int32_5p0, a_int32_0, b0 ); + c_int32_5p1 = _mm512_dpbusd_epi32( c_int32_5p1, a_int32_0, b1 ); + c_int32_5p2 = _mm512_dpbusd_epi32( c_int32_5p2, a_int32_0, b2 ); + } + + // Load alpha and beta + __m512i selector1 = _mm512_set1_epi32( alpha ); + __m512i selector2 = _mm512_set1_epi32( beta ); + + // Scale by alpha + c_int32_0p0 = _mm512_mullo_epi32( selector1, c_int32_0p0 ); + c_int32_0p1 = _mm512_mullo_epi32( selector1, c_int32_0p1 ); + c_int32_0p2 = _mm512_mullo_epi32( selector1, c_int32_0p2 ); + + c_int32_1p0 = _mm512_mullo_epi32( selector1, c_int32_1p0 ); + c_int32_1p1 = _mm512_mullo_epi32( selector1, c_int32_1p1 ); + c_int32_1p2 = _mm512_mullo_epi32( selector1, c_int32_1p2 ); + + c_int32_2p0 = _mm512_mullo_epi32( selector1, c_int32_2p0 ); + c_int32_2p1 = _mm512_mullo_epi32( selector1, c_int32_2p1 ); + c_int32_2p2 = _mm512_mullo_epi32( selector1, c_int32_2p2 ); + + c_int32_3p0 = _mm512_mullo_epi32( selector1, c_int32_3p0 ); + c_int32_3p1 = _mm512_mullo_epi32( selector1, c_int32_3p1 ); + c_int32_3p2 = _mm512_mullo_epi32( selector1, c_int32_3p2 ); + + c_int32_4p0 = _mm512_mullo_epi32( selector1, c_int32_4p0 ); + c_int32_4p1 = _mm512_mullo_epi32( selector1, c_int32_4p1 ); + c_int32_4p2 = _mm512_mullo_epi32( selector1, c_int32_4p2 ); + + c_int32_5p0 = _mm512_mullo_epi32( selector1, c_int32_5p0 ); + c_int32_5p1 = _mm512_mullo_epi32( selector1, c_int32_5p1 ); + c_int32_5p2 = _mm512_mullo_epi32( selector1, c_int32_5p2 ); + + // Scale C by beta. 
+ if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 0 ) ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 0 ) ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p1 = _mm512_add_epi32( selector1, c_int32_0p1 ); + + // c[0,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 0 ) ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_0p2 = _mm512_add_epi32( selector1, c_int32_0p2 ); + + // c[1,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 1 ) ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[1,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 1 ) ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p1 = _mm512_add_epi32( selector1, c_int32_1p1 ); + + // c[1,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 1 ) ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_1p2 = _mm512_add_epi32( selector1, c_int32_1p2 ); + + // c[2,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 2 ) ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[2,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 2 ) ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p1 = _mm512_add_epi32( selector1, c_int32_2p1 ); + + // c[2,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 2 ) ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_2p2 = _mm512_add_epi32( selector1, c_int32_2p2 ); + + // c[3,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 3 ) ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); + + // c[3,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 3 ) ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_3p1 = _mm512_add_epi32( selector1, c_int32_3p1 ); + + // c[3,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 3 ) ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_3p2 = _mm512_add_epi32( selector1, c_int32_3p2 ); + + // c[4,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 4 ) ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_4p0 = _mm512_add_epi32( selector1, c_int32_4p0 ); + + // c[4,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 4 ) ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_4p1 = _mm512_add_epi32( selector1, c_int32_4p1 ); + + // c[4,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 4 ) ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_4p2 = _mm512_add_epi32( selector1, c_int32_4p2 ); + + // c[5,0-15] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 5 ) ) + ( 0*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_5p0 = _mm512_add_epi32( selector1, c_int32_5p0 ); + + // c[5,16-31] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 5 ) ) + ( 1*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_5p1 = 
_mm512_add_epi32( selector1, c_int32_5p1 ); + + // c[5,32-47] + selector1 = _mm512_loadu_epi32( c + ( rs_c * ( ir + 5 ) ) + ( 2*16 ) ); + selector1 = _mm512_mullo_epi32( selector2, selector1 ); + c_int32_5p2 = _mm512_add_epi32( selector1, c_int32_5p2 ); + } + + // Store the results. + // c[0,0-15] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 0 ) ) + ( 0*16 ), c_int32_0p0 ); + + // c[0, 16-31] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 0 ) ) + ( 1*16 ), c_int32_0p1 ); + + // c[0,32-47] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 0 ) ) + ( 2*16 ), c_int32_0p2 ); + + // c[1,0-15] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 1 ) ) + ( 0*16 ), c_int32_1p0 ); + + // c[1,16-31] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 1 ) ) + ( 1*16 ), c_int32_1p1 ); + + // c[1,32-47] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 1 ) ) + ( 2*16 ), c_int32_1p2 ); + + // c[2,0-15] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 2 ) ) + ( 0*16 ), c_int32_2p0 ); + + // c[2,16-31] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 2 ) ) + ( 1*16 ), c_int32_2p1 ); + + // c[2,32-47] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 2 ) ) + ( 2*16 ), c_int32_2p2 ); + + // c[3,0-15] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 3 ) ) + ( 0*16 ), c_int32_3p0 ); + + // c[3,16-31] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 3 ) ) + ( 1*16 ), c_int32_3p1 ); + + // c[3,32-47] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 3 ) ) + ( 2*16 ), c_int32_3p2 ); + + // c[4,0-15] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 4 ) ) + ( 0*16 ), c_int32_4p0 ); + + // c[4,16-31] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 4 ) ) + ( 1*16 ), c_int32_4p1 ); + + // c[4,32-47] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 4 ) ) + ( 2*16 ), c_int32_4p2 ); + + // c[5,0-15] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 5 ) ) + ( 0*16 ), c_int32_5p0 ); + + // c[5,16-31] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 5 ) ) + ( 1*16 ), c_int32_5p1 ); + + // c[5,32-47] + _mm512_storeu_epi32( c + ( rs_c * ( ir + 5 ) ) + ( 2*16 ), c_int32_5p2 ); + + a = a + ( MR * ps_a ); + } + + if ( m_partial_pieces > 0 ) + { + if ( m_partial_pieces == 5 ) + { + dim_t cs_a_use = ( cs_a == 4 ) ? 4 : ( ( cs_a / 6 ) * 5 ); + lpgemm_rowvar_u8s8s32o32_5x48 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta + ); + } + else if ( m_partial_pieces == 4 ) + { + dim_t cs_a_use = ( cs_a == 4 ) ? 4 : ( ( cs_a / 6 ) * 4 ); + lpgemm_rowvar_u8s8s32o32_4x48 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta + ); + } + else if ( m_partial_pieces == 3 ) + { + dim_t cs_a_use = ( cs_a == 4 ) ? 4 : ( ( cs_a / 6 ) * 3 ); + lpgemm_rowvar_u8s8s32o32_3x48 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta + ); + } + else if ( m_partial_pieces == 2 ) + { + dim_t cs_a_use = ( cs_a == 4 ) ? 4 : ( ( cs_a / 6 ) * 2 ); + lpgemm_rowvar_u8s8s32o32_2x48 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta + ); + } + else if ( m_partial_pieces == 1 ) + { + dim_t cs_a_use = ( cs_a == 4 ) ? 
4 : ( ( cs_a / 6 ) * 1 ); + lpgemm_rowvar_u8s8s32o32_1x48 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta + ); + } + } +} diff --git a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_packa.h b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_packa.h new file mode 100644 index 0000000000..b983b0c617 --- /dev/null +++ b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_packa.h @@ -0,0 +1,55 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_GEMM_INT8_PACKA +#define BLIS_GEMM_INT8_PACKA + +void get_packa_k64_u8s8s32o32_strides + ( + dim_t* rs_a, + dim_t* cs_a + ); + +void packa_k64_u8s8s32o32 + ( + uint8_t* pack_a_buffer_u8s8s32o32, + const uint8_t* a, + const dim_t lda, + const dim_t MC, + const dim_t KC, + dim_t* rs_a, + dim_t* cs_a + ); + +#endif //BLIS_GEMM_INT8_PACKA diff --git a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_packa_amd512vnni.c b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_packa_amd512vnni.c new file mode 100644 index 0000000000..03bc7db03f --- /dev/null +++ b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_packa_amd512vnni.c @@ -0,0 +1,518 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. 
+ - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include + +#include "blis.h" +#include "lpgemm_packa.h" + +#define MR 6 +#define NR 64 + +void packa_m5_k64_u8s8s32o32 + ( + uint8_t* pack_a_buffer_u8s8s32o32, + const uint8_t* a, + const dim_t lda, + const dim_t KC + ); + +void packa_m4_k64_u8s8s32o32 + ( + uint8_t* pack_a_buffer_u8s8s32o32, + const uint8_t* a, + const dim_t lda, + const dim_t KC + ); + +void packa_m3_k64_u8s8s32o32 + ( + uint8_t* pack_a_buffer_u8s8s32o32, + const uint8_t* a, + const dim_t lda, + const dim_t KC + ); + +void packa_m2_k64_u8s8s32o32 + ( + uint8_t* pack_a_buffer_u8s8s32o32, + const uint8_t* a, + const dim_t lda, + const dim_t KC + ); + +void packa_m1_k64_u8s8s32o32 + ( + uint8_t* pack_a_buffer_u8s8s32o32, + const uint8_t* a, + const dim_t lda, + const dim_t KC + ); + +void get_packa_k64_u8s8s32o32_strides + ( + dim_t* rs_a, + dim_t* cs_a + ) +{ + *rs_a = 4; + *cs_a = 24; +} + +// TODO: k fringe till k=4, k%4=0 and padding to make k'%4 = 0 if k%4 != 0 originally. +void packa_k64_u8s8s32o32 + ( + uint8_t* pack_a_buffer_u8s8s32o32, + const uint8_t* a, + const dim_t lda, + const dim_t MC, + const dim_t KC, + dim_t* rs_a, + dim_t* cs_a + ) +{ + // Used for permuting the mm512i elements for use in vpdpbusd instruction. + // These are indexes of the format a0-a1-b0-b1-a2-a3-b2-b3 and a0-a1-a2-a3-b0-b1-b2-b3. + // Adding 4 int32 wise gives format a4-a5-b4-b5-a6-a7-b6-b7 and a4-a5-a6-a7-b4-b5-b6-b7. + __m512i selector1 = _mm512_setr_epi64( 0x0, 0x1, 0x8, 0x9, 0x2, 0x3, 0xA, 0xB ); + __m512i selector1_1 = _mm512_setr_epi64( 0x4, 0x5, 0xC, 0xD, 0x6, 0x7, 0xE, 0xF ); + __m512i selector2 = _mm512_setr_epi64( 0x0, 0x1, 0x2, 0x3, 0x8, 0x9, 0xA, 0xB ); + __m512i selector2_1 = _mm512_setr_epi64( 0x4, 0x5, 0x6, 0x7, 0xC, 0xD, 0xE, 0xF ); + + // First half. + __m512i selector3 = _mm512_setr_epi64( 0x0, 0x1, 0x8, 0x2, 0x3, 0x9, 0x4, 0x5 ); // 64 elems + __m512i selector4 = _mm512_setr_epi64( 0x8, 0x6, 0x7, 0x9, 0x0, 0x0, 0x0, 0x0 ); // 32 elems + __m512i selector5 = _mm512_setr_epi64( 0x0, 0x1, 0xA, 0x2, 0x3, 0xB, 0x4, 0x5 ); // 64 elems + __m512i selector6 = _mm512_setr_epi64( 0xA, 0x6, 0x7, 0xB, 0x0, 0x0, 0x0, 0x0 ); // 32 elems + + // Second half. 
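+	// ("First half" / "Second half" refer to the lower and upper 32 of the 64
+	//  k values read per iteration; each half yields 6 rows x 32 k = 192
+	//  packed bytes, written by the corresponding group of stores below.)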
+ __m512i selector7 = _mm512_setr_epi64( 0x0, 0x1, 0xC, 0x2, 0x3, 0xD, 0x4, 0x5 ); // 64 elems + __m512i selector8 = _mm512_setr_epi64( 0xC, 0x6, 0x7, 0xD, 0x0, 0x0, 0x0, 0x0 ); // 32 elems + __m512i selector9 = _mm512_setr_epi64( 0x0, 0x1, 0xE, 0x2, 0x3, 0xF, 0x4, 0x5 ); // 64 elems + __m512i selector10 = _mm512_setr_epi64( 0xE, 0x6, 0x7, 0xF, 0x0, 0x0, 0x0, 0x0 ); // 32 elems + + dim_t m_full_pieces = MC / MR; + dim_t m_full_pieces_loop_limit = m_full_pieces * MR; + dim_t m_partial_pieces = MC % MR; + + __m512i a0; + __m512i b0; + __m512i c0; + __m512i d0; + __m512i e0; + __m512i f0; + __m512i a01; + __m512i c01; + __m512i e01; + __m256i last_piece; + + for ( dim_t ic = 0; ic < m_full_pieces_loop_limit; ic += MR ) + { + for ( dim_t kr = 0; kr < KC; kr += NR ) + { + // Rearrange for vpdpbusd, read 6 rows from A with 64 elements in each row. + a0 = _mm512_loadu_epi8( a + ( lda * ( ic + 0 ) ) + kr ); + b0 = _mm512_loadu_epi8( a + ( lda * ( ic + 1 ) ) + kr ); + c0 = _mm512_loadu_epi8( a + ( lda * ( ic + 2 ) ) + kr ); + d0 = _mm512_loadu_epi8( a + ( lda * ( ic + 3 ) ) + kr ); + e0 = _mm512_loadu_epi8( a + ( lda * ( ic + 4 ) ) + kr ); + f0 = _mm512_loadu_epi8( a + ( lda * ( ic + 5 ) ) + kr ); + + a01 = _mm512_unpacklo_epi32( a0, b0 ); + a0 = _mm512_unpackhi_epi32( a0, b0 ); + + c01 = _mm512_unpacklo_epi32( c0, d0 ); + c0 = _mm512_unpackhi_epi32( c0, d0 ); + + e01 = _mm512_unpacklo_epi32( e0, f0 ); // Elem 4 + e0 = _mm512_unpackhi_epi32( e0, f0 ); // Elem 5 + + b0 = _mm512_unpacklo_epi64( a01, c01 ); + a01 = _mm512_unpackhi_epi64( a01, c01 ); + + d0 = _mm512_unpacklo_epi64( a0, c0 ); + c01 = _mm512_unpackhi_epi64( a0, c0 ); + + a0 = _mm512_permutex2var_epi64( b0, selector1, a01 ); + c0 = _mm512_permutex2var_epi64( d0, selector1, c01 ); + b0 = _mm512_permutex2var_epi64( b0, selector1_1, a01 ); + d0 = _mm512_permutex2var_epi64( d0, selector1_1, c01 ); + + a01 = _mm512_permutex2var_epi64( a0, selector2, c0 ); // a[0] + c01 = _mm512_permutex2var_epi64( b0, selector2, d0 ); // a[2] + a0 = _mm512_permutex2var_epi64( a0, selector2_1, c0 ); // a[1] + c0 = _mm512_permutex2var_epi64( b0, selector2_1, d0 ); // a[3] + + // First half + b0 = _mm512_permutex2var_epi64( a01, selector3, e01 ); // 1st 64 + a01 = _mm512_permutex2var_epi64( a01, selector4, e0 ); // 1st 32 + d0 = _mm512_permutex2var_epi64( a0, selector5, e01 ); // 2nd 64 + a0 = _mm512_permutex2var_epi64( a0, selector6, e0 ); // 2nd 32 + + _mm512_storeu_epi64( pack_a_buffer_u8s8s32o32 + ( ( ic * KC ) + ( ( kr * MR ) + ( 0 ) ) ), b0 ); + _mm512_storeu_epi64( pack_a_buffer_u8s8s32o32 + ( ( ic * KC ) + ( ( kr * MR ) + ( 64 ) ) ) , a01 ); + _mm512_storeu_epi64( pack_a_buffer_u8s8s32o32 + ( ( ic * KC ) + ( ( kr * MR ) + ( 96 ) ) ), d0 ); + // Last piece + last_piece = _mm512_castsi512_si256( a0 ); + _mm256_storeu_epi64( pack_a_buffer_u8s8s32o32 + ( ( ic * KC ) + ( ( kr * MR ) + ( 160 ) ) ), last_piece ); + + // Second half + b0 = _mm512_permutex2var_epi64( c01, selector7, e01 ); // 3rd 64 + c01 = _mm512_permutex2var_epi64( c01, selector8, e0 ); // 3rd 32 + d0 = _mm512_permutex2var_epi64( c0, selector9, e01 ); // 4th 64 + c0 = _mm512_permutex2var_epi64( c0, selector10, e0 ); // 4th 32 + + _mm512_storeu_epi64( pack_a_buffer_u8s8s32o32 + ( ( ic * KC ) + ( ( kr * MR ) + ( 192 ) ) ), b0 ); + _mm512_storeu_epi64( pack_a_buffer_u8s8s32o32 + ( ( ic * KC ) + ( ( kr * MR ) + ( 256 ) ) ) , c01 ); + _mm512_storeu_epi64( pack_a_buffer_u8s8s32o32 + ( ( ic * KC ) + ( ( kr * MR ) + ( 288 ) ) ), d0 ); + // Last piece + last_piece = _mm512_castsi512_si256( c0 ); + 
_mm256_storeu_epi64( pack_a_buffer_u8s8s32o32 + ( ( ic * KC ) + ( ( kr * MR ) + ( 352 ) ) ), last_piece ); + } + //TODO: Handle kc < 64 case, 48,32,16 + } + + if ( m_partial_pieces > 0 ) + { + if ( m_partial_pieces == 5 ) + { + packa_m5_k64_u8s8s32o32 + ( + pack_a_buffer_u8s8s32o32 + ( m_full_pieces_loop_limit * KC ), + a + ( lda * m_full_pieces_loop_limit ), lda, KC + ); + } + else if ( m_partial_pieces == 4 ) + { + packa_m4_k64_u8s8s32o32 + ( + pack_a_buffer_u8s8s32o32 + ( m_full_pieces_loop_limit * KC ), + a + ( lda * m_full_pieces_loop_limit ), lda, KC + ); + } + else if ( m_partial_pieces == 3 ) + { + packa_m3_k64_u8s8s32o32 + ( + pack_a_buffer_u8s8s32o32 + ( m_full_pieces_loop_limit * KC ), + a + ( lda * m_full_pieces_loop_limit ), lda, KC + ); + } + else if ( m_partial_pieces == 2 ) + { + packa_m2_k64_u8s8s32o32 + ( + pack_a_buffer_u8s8s32o32 + ( m_full_pieces_loop_limit * KC ), + a + ( lda * m_full_pieces_loop_limit ), lda, KC + ); + } + else if ( m_partial_pieces == 1 ) + { + packa_m1_k64_u8s8s32o32 + ( + pack_a_buffer_u8s8s32o32 + ( m_full_pieces_loop_limit * KC ), + a + ( lda * m_full_pieces_loop_limit ), lda, KC + ); + } + } + *rs_a = 4; + *cs_a = 24; +} + +void packa_m5_k64_u8s8s32o32 + ( + uint8_t* pack_a_buffer_u8s8s32o32, + const uint8_t* a, + const dim_t lda, + const dim_t KC + ) +{ + // Used for permuting the mm512i elements for use in vpdpbusd instruction. + // These are indexes of the format a0-a1-b0-b1-a2-a3-b2-b3 and a0-a1-a2-a3-b0-b1-b2-b3. + // Adding 4 int32 wise gives format a4-a5-b4-b5-a6-a7-b6-b7 and a4-a5-a6-a7-b4-b5-b6-b7. + __m512i selector1 = _mm512_setr_epi64( 0x0, 0x1, 0x8, 0x9, 0x2, 0x3, 0xA, 0xB ); + __m512i selector1_1 = _mm512_setr_epi64( 0x4, 0x5, 0xC, 0xD, 0x6, 0x7, 0xE, 0xF ); + __m512i selector2 = _mm512_setr_epi64( 0x0, 0x1, 0x2, 0x3, 0x8, 0x9, 0xA, 0xB ); + __m512i selector2_1 = _mm512_setr_epi64( 0x4, 0x5, 0x6, 0x7, 0xC, 0xD, 0xE, 0xF ); + + // First half. + __m512i selector3 = _mm512_setr_epi32( 0x0, 0x1, 0x2, 0x3, 0x10, 0x4, 0x5, 0x6, 0x7, 0x11, 0x8, 0x9, 0xA, 0xB, 0x12, 0xC); + __m512i selector4 = _mm512_setr_epi32( 0xD, 0xE, 0xF, 0x13, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0); + __m512i selector5 = _mm512_setr_epi32( 0x0, 0x1, 0x2, 0x3, 0x14, 0x4, 0x5, 0x6, 0x7, 0x15, 0x8, 0x9, 0xA, 0xB, 0x16, 0xC); + __m512i selector6 = _mm512_setr_epi32( 0xD, 0xE, 0xF, 0x17, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0); + + // Second half. + __m512i selector7 = _mm512_setr_epi32( 0x0, 0x1, 0x2, 0x3, 0x18, 0x4, 0x5, 0x6, 0x7, 0x19, 0x8, 0x9, 0xA, 0xB, 0x1A, 0xC); + __m512i selector8 = _mm512_setr_epi32( 0xD, 0xE, 0xF, 0x1B, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0); + __m512i selector9 = _mm512_setr_epi32( 0x0, 0x1, 0x2, 0x3, 0x1C, 0x4, 0x5, 0x6, 0x7, 0x1D, 0x8, 0x9, 0xA, 0xB, 0x1E, 0xC); + __m512i selector10 = _mm512_setr_epi32( 0xD, 0xE, 0xF, 0x1F, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0); + + __m512i a0; + __m512i b0; + __m512i c0; + __m512i d0; + __m512i e0; + __m512i a01; + __m512i c01; + __m128i last_piece; + + for ( dim_t kr = 0; kr < KC; kr += NR ) + { + // Rearrange for vpdpbusd, read 5 rows from A with 64 elements in each row. 
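+		// For this 5-row fringe panel each k-group is stored as 5 rows x 4 bytes
+		// = 20 contiguous bytes (vs 24 for a full 6-row panel), which is what the
+		// compute-side dispatch accounts for via cs_a_use = ( cs_a / 6 ) * 5;
+		// rs_a remains 4 bytes between consecutive rows of a group.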
+ a0 = _mm512_loadu_epi8( a + ( lda * 0 ) + kr ); + b0 = _mm512_loadu_epi8( a + ( lda * 1 ) + kr ); + c0 = _mm512_loadu_epi8( a + ( lda * 2 ) + kr ); + d0 = _mm512_loadu_epi8( a + ( lda * 3 ) + kr ); + e0 = _mm512_loadu_epi8( a + ( lda * 4 ) + kr ); + + a01 = _mm512_unpacklo_epi32( a0, b0 ); + a0 = _mm512_unpackhi_epi32( a0, b0 ); + + c01 = _mm512_unpacklo_epi32( c0, d0 ); + c0 = _mm512_unpackhi_epi32( c0, d0 ); + + b0 = _mm512_unpacklo_epi64( a01, c01 ); + a01 = _mm512_unpackhi_epi64( a01, c01 ); + + d0 = _mm512_unpacklo_epi64( a0, c0 ); + c01 = _mm512_unpackhi_epi64( a0, c0 ); + + a0 = _mm512_permutex2var_epi64( b0, selector1, a01 ); + c0 = _mm512_permutex2var_epi64( d0, selector1, c01 ); + b0 = _mm512_permutex2var_epi64( b0, selector1_1, a01 ); + d0 = _mm512_permutex2var_epi64( d0, selector1_1, c01 ); + + a01 = _mm512_permutex2var_epi64( a0, selector2, c0 ); // a[0] + c01 = _mm512_permutex2var_epi64( b0, selector2, d0 ); // a[2] + a0 = _mm512_permutex2var_epi64( a0, selector2_1, c0 ); // a[1] + c0 = _mm512_permutex2var_epi64( b0, selector2_1, d0 ); // a[3] + + // First half + b0 = _mm512_permutex2var_epi32( a01, selector3, e0 ); + a01 = _mm512_permutex2var_epi32( a01, selector4, e0 ); + d0 = _mm512_permutex2var_epi32( a0, selector5, e0 ); + a0 = _mm512_permutex2var_epi32( a0, selector6, e0 ); + + _mm512_storeu_epi64( pack_a_buffer_u8s8s32o32 + ( ( kr * 5 ) + ( 0 ) ), b0 ); + _mm512_storeu_epi64( pack_a_buffer_u8s8s32o32 + ( ( kr * 5 ) + ( 64 ) ) , a01 ); + _mm512_storeu_epi64( pack_a_buffer_u8s8s32o32 + ( ( kr * 5 ) + ( 80 ) ), d0 ); + // Last piece + last_piece = _mm512_castsi512_si128( a0 ); + _mm_storeu_epi64( pack_a_buffer_u8s8s32o32 + ( ( kr * 5 ) + ( 144 ) ), last_piece ); + + // Second half + b0 = _mm512_permutex2var_epi32( c01, selector7, e0 ); + c01 = _mm512_permutex2var_epi32( c01, selector8, e0 ); + d0 = _mm512_permutex2var_epi32( c0, selector9, e0 ); + c0 = _mm512_permutex2var_epi32( c0, selector10, e0 ); + + _mm512_storeu_epi64( pack_a_buffer_u8s8s32o32 + ( ( kr * 5 ) + ( 160 ) ), b0 ); + _mm512_storeu_epi64( pack_a_buffer_u8s8s32o32 + ( ( kr * 5 ) + ( 224 ) ) , c01 ); + _mm512_storeu_epi64( pack_a_buffer_u8s8s32o32 + ( ( kr * 5 ) + ( 240 ) ), d0 ); + // Last piece + last_piece = _mm512_castsi512_si128( c0 ); + _mm_storeu_epi64( pack_a_buffer_u8s8s32o32 + ( ( kr * 5 ) + ( 304 ) ), last_piece ); + } +} + +void packa_m4_k64_u8s8s32o32 + ( + uint8_t* pack_a_buffer_u8s8s32o32, + const uint8_t* a, + const dim_t lda, + const dim_t KC + ) +{ + // Used for permuting the mm512i elements for use in vpdpbusd instruction. + // These are indexes of the format a0-a1-b0-b1-a2-a3-b2-b3 and a0-a1-a2-a3-b0-b1-b2-b3. + // Adding 4 int32 wise gives format a4-a5-b4-b5-a6-a7-b6-b7 and a4-a5-a6-a7-b4-b5-b6-b7. + __m512i selector1 = _mm512_setr_epi64( 0x0, 0x1, 0x8, 0x9, 0x2, 0x3, 0xA, 0xB ); + __m512i selector1_1 = _mm512_setr_epi64( 0x4, 0x5, 0xC, 0xD, 0x6, 0x7, 0xE, 0xF ); + __m512i selector2 = _mm512_setr_epi64( 0x0, 0x1, 0x2, 0x3, 0x8, 0x9, 0xA, 0xB ); + __m512i selector2_1 = _mm512_setr_epi64( 0x4, 0x5, 0x6, 0x7, 0xC, 0xD, 0xE, 0xF ); + + __m512i a0; + __m512i b0; + __m512i c0; + __m512i d0; + __m512i a01; + __m512i c01; + + for ( dim_t kr = 0; kr < KC; kr += NR ) + { + // Rearrange for vpdpbusd, read 4 rows from A with 64 elements in each row. 
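+		// The epi32 unpacks below interleave the 4-byte k-groups of row pairs, and
+		// the epi64 unpacks and permutes that follow place the matching groups of
+		// all 4 rows next to each other, so every k-group ends up as 16 contiguous
+		// bytes (row0..row3) ready for the 32-bit broadcasts in the compute kernels.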
+ a0 = _mm512_loadu_epi8( a + ( lda * 0 ) + kr ); + b0 = _mm512_loadu_epi8( a + ( lda * 1 ) + kr ); + c0 = _mm512_loadu_epi8( a + ( lda * 2 ) + kr ); + d0 = _mm512_loadu_epi8( a + ( lda * 3 ) + kr ); + + a01 = _mm512_unpacklo_epi32( a0, b0 ); + a0 = _mm512_unpackhi_epi32( a0, b0 ); + + c01 = _mm512_unpacklo_epi32( c0, d0 ); + c0 = _mm512_unpackhi_epi32( c0, d0 ); + + b0 = _mm512_unpacklo_epi64( a01, c01 ); + a01 = _mm512_unpackhi_epi64( a01, c01 ); + + d0 = _mm512_unpacklo_epi64( a0, c0 ); + c01 = _mm512_unpackhi_epi64( a0, c0 ); + + a0 = _mm512_permutex2var_epi64( b0, selector1, a01 ); + c0 = _mm512_permutex2var_epi64( d0, selector1, c01 ); + b0 = _mm512_permutex2var_epi64( b0, selector1_1, a01 ); + d0 = _mm512_permutex2var_epi64( d0, selector1_1, c01 ); + + a01 = _mm512_permutex2var_epi64( a0, selector2, c0 ); // a[0] + c01 = _mm512_permutex2var_epi64( b0, selector2, d0 ); // a[2] + a0 = _mm512_permutex2var_epi64( a0, selector2_1, c0 ); // a[1] + c0 = _mm512_permutex2var_epi64( b0, selector2_1, d0 ); // a[3] + + _mm512_storeu_epi64( pack_a_buffer_u8s8s32o32 + ( ( kr * 4 ) + ( 0 ) ), a01 ); + _mm512_storeu_epi64( pack_a_buffer_u8s8s32o32 + ( ( kr * 4 ) + ( 64 ) ) , a0 ); + _mm512_storeu_epi64( pack_a_buffer_u8s8s32o32 + ( ( kr * 4 ) + ( 128 ) ), c01 ); + _mm512_storeu_epi64( pack_a_buffer_u8s8s32o32 + ( ( kr * 4 ) + ( 192 ) ), c0 ); + } +} + +void packa_m3_k64_u8s8s32o32 + ( + uint8_t* pack_a_buffer_u8s8s32o32, + const uint8_t* a, + const dim_t lda, + const dim_t KC + ) +{ + // Used for permuting the mm512i elements for use in vpdpbusd instruction. + // These are indexes of the format a0-a1-b0-b1-a2-a3-b2-b3 and a0-a1-a2-a3-b0-b1-b2-b3. + // Adding 4 int32 wise gives format a4-a5-b4-b5-a6-a7-b6-b7 and a4-a5-a6-a7-b4-b5-b6-b7. + __m512i selector1 = _mm512_setr_epi64( 0x0, 0x1, 0x8, 0x9, 0x2, 0x3, 0xA, 0xB ); + __m512i selector1_1 = _mm512_setr_epi64( 0x4, 0x5, 0xC, 0xD, 0x6, 0x7, 0xE, 0xF ); + + // First half + __m512i selector3 = _mm512_setr_epi32( 0x0, 0x1, 0x10, 0x2, 0x3, 0x11, 0x4, 0x5, 0x12, 0x6, 0x7, 0x13, 0x8, 0x9, 0x14, 0xA ); + __m512i selector4 = _mm512_setr_epi32( 0xB, 0x15, 0xC, 0xD, 0x16, 0xE, 0xF, 0x17, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0 ); + + // Second half + __m512i selector5 = _mm512_setr_epi32( 0x0, 0x1, 0x18, 0x2, 0x3, 0x19, 0x4, 0x5, 0x1A, 0x6, 0x7, 0x1B, 0x8, 0x9, 0x1C, 0xA ); + __m512i selector6 = _mm512_setr_epi32( 0xB, 0x1D, 0xC, 0xD, 0x1E, 0xE, 0xF, 0x1F, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0 ); + + __m512i a0; + __m512i b0; + __m512i c0; + __m512i a01; + __m256i last_piece; + + for ( dim_t kr = 0; kr < KC; kr += NR ) + { + // Rearrange for vpdpbusd, read 3 rows from A with 64 elements in each row. 
+ a0 = _mm512_loadu_epi8( a + ( lda * 0 ) + kr ); + b0 = _mm512_loadu_epi8( a + ( lda * 1 ) + kr ); + c0 = _mm512_loadu_epi8( a + ( lda * 2 ) + kr ); + + a01 = _mm512_unpacklo_epi32( a0, b0 ); + a0 = _mm512_unpackhi_epi32( a0, b0 ); + + b0 = _mm512_permutex2var_epi64( a01, selector1, a0 ); // a[0] + a01 = _mm512_permutex2var_epi64( a01, selector1_1, a0 ); // a[1] + + a0 = _mm512_permutex2var_epi32( b0, selector3, c0 ); + b0 = _mm512_permutex2var_epi32( b0, selector4, c0 ); + + _mm512_storeu_epi64( pack_a_buffer_u8s8s32o32 + ( ( kr * 3 ) + ( 0 ) ), a0 ); + _mm512_storeu_epi64( pack_a_buffer_u8s8s32o32 + ( ( kr * 3 ) + ( 64 ) ) , b0 ); + + a0 = _mm512_permutex2var_epi32( a01, selector5, c0 ); + b0 = _mm512_permutex2var_epi32( a01, selector6, c0 ); + + _mm512_storeu_epi64( pack_a_buffer_u8s8s32o32 + ( ( kr * 3 ) + ( 96 ) ), a0 ); + // Last piece + last_piece = _mm512_castsi512_si256( b0 ); + _mm256_storeu_epi64( pack_a_buffer_u8s8s32o32 + ( ( kr * 3 ) + ( 160 ) ), last_piece ); + } +} + +void packa_m2_k64_u8s8s32o32 + ( + uint8_t* pack_a_buffer_u8s8s32o32, + const uint8_t* a, + const dim_t lda, + const dim_t KC + ) +{ + // Used for permuting the mm512i elements for use in vpdpbusd instruction. + // These are indexes of the format a0-a1-b0-b1-a2-a3-b2-b3 and a0-a1-a2-a3-b0-b1-b2-b3. + // Adding 4 int32 wise gives format a4-a5-b4-b5-a6-a7-b6-b7 and a4-a5-a6-a7-b4-b5-b6-b7. + __m512i selector1 = _mm512_setr_epi64( 0x0, 0x1, 0x8, 0x9, 0x2, 0x3, 0xA, 0xB ); + __m512i selector1_1 = _mm512_setr_epi64( 0x4, 0x5, 0xC, 0xD, 0x6, 0x7, 0xE, 0xF ); + + __m512i a0; + __m512i b0; + __m512i a01; + + for ( dim_t kr = 0; kr < KC; kr += NR ) + { + // Rearrange for vpdpbusd, read 2 rows from A with 64 elements in each row. + a0 = _mm512_loadu_epi8( a + ( lda * 0 ) + kr ); + b0 = _mm512_loadu_epi8( a + ( lda * 1 ) + kr ); + + a01 = _mm512_unpacklo_epi32( a0, b0 ); + a0 = _mm512_unpackhi_epi32( a0, b0 ); + + b0 = _mm512_permutex2var_epi64( a01, selector1, a0 ); // a[0] + a01 = _mm512_permutex2var_epi64( a01, selector1_1, a0 ); // a[1] + + _mm512_storeu_epi64( pack_a_buffer_u8s8s32o32 + ( ( kr * 2 ) + ( 0 ) ), b0 ); + _mm512_storeu_epi64( pack_a_buffer_u8s8s32o32 + ( ( kr * 2 ) + ( 64 ) ) , a01 ); + } +} + +void packa_m1_k64_u8s8s32o32 + ( + uint8_t* pack_a_buffer_u8s8s32o32, + const uint8_t* a, + const dim_t lda, + const dim_t KC + ) +{ + __m512i a0; + + for ( dim_t kr = 0; kr < KC; kr += NR ) + { + // Rearrange for vpdpbusd, read 1 row from A with 64 elements in each row. + a0 = _mm512_loadu_epi8( a + ( lda * 0 ) + kr ); + + _mm512_storeu_epi64( pack_a_buffer_u8s8s32o32 + ( ( kr * 1 ) + ( 0 ) ), a0 ); + } +} diff --git a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_packb.h b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_packb.h new file mode 100644 index 0000000000..3f310c0a48 --- /dev/null +++ b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_packb.h @@ -0,0 +1,65 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. 
+ - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_GEMM_INT8_PACKB +#define BLIS_GEMM_INT8_PACKB + +BLIS_INLINE dim_t get_packb_u8s8s32o32_min_NR() +{ + // This is the minimum NR' required for use in u8s8s32 kernels. The idea + // here is that since k needs to be a multiple of 4 (VNNI instr), NR'=16 + // results in total of 4 * NR' = 64 bytes to be loaded, which fits in 1 ZMM + // register. Thus the smallest n fringe kernel dimension has n=16, and thus + // any rounding for buffer sizes should be to 16. + return 16; +} + +void get_packb_nr64_u8s8s32o32_strides + ( + dim_t* rs_b, + dim_t* cs_b + ); + +void packb_nr64_u8s8s32o32 + ( + int8_t* pack_b_buffer_u8s8s32o32, + const int8_t* b, + const dim_t ldb, + const dim_t NC, + const dim_t KC, + dim_t* rs_b, + dim_t* cs_b + ); + +#endif //BLIS_GEMM_INT8_PACKB diff --git a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_packb_amd512vnni.c b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_packb_amd512vnni.c new file mode 100644 index 0000000000..06a46afb44 --- /dev/null +++ b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_packb_amd512vnni.c @@ -0,0 +1,792 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include +#include + +#include "blis.h" +#include "lpgemm_packb.h" + +#define NR 64 + +void packb_nrlt16_u8s8s32o32 + ( + int8_t* pack_b_buffer_u8s8s32o32, + const int8_t* b, + const dim_t ldb, + const dim_t KC, + const dim_t n0_partial_rem + ); + +void packb_nr16_u8s8s32o32 + ( + int8_t* pack_b_buffer_u8s8s32o32, + const int8_t* b, + const dim_t ldb, + const dim_t KC + ); + +void packb_nr32_u8s8s32o32 + ( + int8_t* pack_b_buffer_u8s8s32o32, + const int8_t* b, + const dim_t ldb, + const dim_t KC + ); + +void packb_nr48_u8s8s32o32 + ( + int8_t* pack_b_buffer_u8s8s32o32, + const int8_t* b, + const dim_t ldb, + const dim_t KC + ); + +void get_packb_nr64_u8s8s32o32_strides + ( + dim_t* rs_b, + dim_t* cs_b + ) +{ + *rs_b = NR * 4; + *cs_b = NR; +} + +void packb_nr64_u8s8s32o32 + ( + int8_t* pack_b_buffer_u8s8s32o32, + const int8_t* b, + const dim_t ldb, + const dim_t NC, + const dim_t KC, + dim_t* rs_b, + dim_t* cs_b + ) +{ + // Used for permuting the mm512i elements for use in vpdpbusd instruction. + // These are indexes of the format a0-a1-b0-b1-a2-a3-b2-b3 and a0-a1-a2-a3-b0-b1-b2-b3. + // Adding int32 wise all4 gives format a4-a5-b4-b5-a6-a7-b6-b7 and a4-a5-a6-a7-b4-b5-b6-b7. + __m512i selector1 = _mm512_setr_epi64( 0x0, 0x1, 0x8, 0x9, 0x2, 0x3, 0xA, 0xB ); + __m512i selector1_1 = _mm512_setr_epi64( 0x4, 0x5, 0xC, 0xD, 0x6, 0x7, 0xE, 0xF ); + + __m512i selector2 = _mm512_setr_epi64( 0x0, 0x1, 0x2, 0x3, 0x8, 0x9, 0xA, 0xB ); + __m512i selector2_1 = _mm512_setr_epi64( 0x4, 0x5, 0x6, 0x7, 0xC, 0xD, 0xE, 0xF ); + + dim_t n_full_pieces = NC / NR; + dim_t n_full_pieces_loop_limit = n_full_pieces * NR; + dim_t n_partial_pieces = NC % NR; + + dim_t k_full_pieces_blks = KC / 4; + dim_t k_full_pieces = k_full_pieces_blks * 4; + dim_t k_partial_pieces = KC % 4; + + // KC when not multiple of 4 will have padding to make it multiple of 4 in packed buffer. + dim_t KC_updated = KC; + if ( k_partial_pieces > 0 ) + { + KC_updated += ( 4 - k_partial_pieces ); + } + + __m512i a0; + __m512i b0; + __m512i c0; + __m512i d0; + __m512i a01; + __m512i c01; + + for ( dim_t jc = 0; jc < n_full_pieces_loop_limit; jc += NR ) + { + for ( dim_t kr = 0; kr < k_full_pieces; kr += 4 ) + { + // Rearrange for vpdpbusd, read 4 rows from B with 64 elements in each row. 
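+			// Descriptive note: the epi8/epi16 unpacks below interleave rows
+			// k..k+3 so that each 32-bit lane ends up holding b[k][j],
+			// b[k+1][j], b[k+2][j], b[k+3][j] for a single column j; the epi64
+			// permutes then arrange the 64 columns across the four 64-byte
+			// stores done per group of 4 k rows (4 * NR = 256 bytes per group).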
+ a0 = _mm512_loadu_epi8( b + ( ldb * ( kr + 0 ) ) + jc ); + b0 = _mm512_loadu_epi8( b + ( ldb * ( kr + 1 ) ) + jc ); + c0 = _mm512_loadu_epi8( b + ( ldb * ( kr + 2 ) ) + jc ); + d0 = _mm512_loadu_epi8( b + ( ldb * ( kr + 3 ) ) + jc ); + + a01 = _mm512_unpacklo_epi8( a0, b0 ); + a0 = _mm512_unpackhi_epi8( a0, b0 ); + + c01 = _mm512_unpacklo_epi8( c0, d0 ); + c0 = _mm512_unpackhi_epi8( c0, d0 ); + + b0 = _mm512_unpacklo_epi16( a01, c01 ); + a01 = _mm512_unpackhi_epi16( a01, c01 ); + + d0 = _mm512_unpacklo_epi16( a0, c0 ); + c01 = _mm512_unpackhi_epi16( a0, c0 ); + + a0 = _mm512_permutex2var_epi64( b0, selector1, a01 ); + c0 = _mm512_permutex2var_epi64( d0, selector1, c01 ); + b0 = _mm512_permutex2var_epi64( b0, selector1_1, a01 ); + d0 = _mm512_permutex2var_epi64( d0, selector1_1, c01 ); + + a01 = _mm512_permutex2var_epi64( a0, selector2, c0 ); // b[0] + c01 = _mm512_permutex2var_epi64( b0, selector2, d0 ); // b[2] + a0 = _mm512_permutex2var_epi64( a0, selector2_1, c0 ); // b[1] + c0 = _mm512_permutex2var_epi64( b0, selector2_1, d0 ); // b[3] + + _mm512_storeu_epi64( pack_b_buffer_u8s8s32o32 + ( ( jc * KC_updated ) + ( ( kr + 0 ) * NR ) ), a01 ); + _mm512_storeu_epi64( pack_b_buffer_u8s8s32o32 + ( ( jc * KC_updated ) + ( ( kr + 1 ) * NR ) ) , a0 ); + _mm512_storeu_epi64( pack_b_buffer_u8s8s32o32 + ( ( jc * KC_updated ) + ( ( kr + 2 ) * NR ) ), c01 ); + _mm512_storeu_epi64( pack_b_buffer_u8s8s32o32 + ( ( jc * KC_updated ) + ( ( kr + 3 ) * NR ) ), c0 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + if ( k_partial_pieces == 3 ) + { + a0 = _mm512_loadu_epi8( b + ( ldb * ( k_full_pieces + 0 ) ) + jc ); + b0 = _mm512_loadu_epi8( b + ( ldb * ( k_full_pieces + 1 ) ) + jc ); + c0 = _mm512_loadu_epi8( b + ( ldb * ( k_full_pieces + 2 ) ) + jc ); + d0 = _mm512_setzero_si512(); + + } + else if( k_partial_pieces == 2 ) + { + a0 = _mm512_loadu_epi8( b + ( ldb * ( k_full_pieces + 0 ) ) + jc ); + b0 = _mm512_loadu_epi8( b + ( ldb * ( k_full_pieces + 1 ) ) + jc ); + c0 = _mm512_setzero_si512(); + d0 = _mm512_setzero_si512(); + } + else //k_partial_pieces == 1 + { + a0 = _mm512_loadu_epi8( b + ( ldb * ( k_full_pieces + 0 ) ) + jc ); + b0 = _mm512_setzero_si512(); + c0 = _mm512_setzero_si512(); + d0 = _mm512_setzero_si512(); + } + + a01 = _mm512_unpacklo_epi8( a0, b0 ); + a0 = _mm512_unpackhi_epi8( a0, b0 ); + + c01 = _mm512_unpacklo_epi8( c0, d0 ); + c0 = _mm512_unpackhi_epi8( c0, d0 ); + + b0 = _mm512_unpacklo_epi16( a01, c01 ); + a01 = _mm512_unpackhi_epi16( a01, c01 ); + + d0 = _mm512_unpacklo_epi16( a0, c0 ); + c01 = _mm512_unpackhi_epi16( a0, c0 ); + + a0 = _mm512_permutex2var_epi64( b0, selector1, a01 ); + c0 = _mm512_permutex2var_epi64( d0, selector1, c01 ); + b0 = _mm512_permutex2var_epi64( b0, selector1_1, a01 ); + d0 = _mm512_permutex2var_epi64( d0, selector1_1, c01 ); + + a01 = _mm512_permutex2var_epi64( a0, selector2, c0 ); // b[0] + c01 = _mm512_permutex2var_epi64( b0, selector2, d0 ); // b[2] + a0 = _mm512_permutex2var_epi64( a0, selector2_1, c0 ); // b[1] + c0 = _mm512_permutex2var_epi64( b0, selector2_1, d0 ); // b[3] + + _mm512_storeu_epi64( pack_b_buffer_u8s8s32o32 + ( ( jc * KC_updated ) + ( ( k_full_pieces + 0 ) * NR ) ), a01 ); + _mm512_storeu_epi64( pack_b_buffer_u8s8s32o32 + ( ( jc * KC_updated ) + ( ( k_full_pieces + 1 ) * NR ) ) , a0 ); + _mm512_storeu_epi64( pack_b_buffer_u8s8s32o32 + ( ( jc * KC_updated ) + ( ( k_full_pieces + 2 ) * NR ) ), c01 ); + _mm512_storeu_epi64( pack_b_buffer_u8s8s32o32 + ( ( jc * KC_updated ) + ( ( k_full_pieces + 3 ) * NR ) ), c0 ); + } 
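+		// Illustrative consumption sketch (comment-only aside, not code used
+		// by this file): each 64-byte row stored above holds 16 columns x 4 k
+		// bytes, which is one vpdpbusd operand per ZMM load. The names
+		// packed_b, a_ptr and c_acc below are placeholders for a hypothetical
+		// micro-kernel, not APIs introduced by this patch.
+		//
+		//   __m512i c_acc = _mm512_setzero_si512();
+		//   // 16 packed columns of B, k..k+3 contiguous per column (signed bytes).
+		//   __m512i b_grp = _mm512_loadu_si512( (const void*)packed_b );
+		//   // 4 consecutive k bytes of one row of A (unsigned), broadcast to all lanes.
+		//   __m512i a_bcast = _mm512_set1_epi32( *(const int32_t*)a_ptr );
+		//   // c_acc[i] += dot( u8 bytes of a_bcast, s8 bytes of b_grp lane i )
+		//   c_acc = _mm512_dpbusd_epi32( c_acc, a_bcast, b_grp );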
+ } + + // Contiguous packing of fringe panel (n` < NR). + if ( n_partial_pieces > 0 ) + { + dim_t n0_partial_rem = n_partial_pieces % 16; + dim_t n0_partial_pack = 0; + + // Split into multiple smaller fringe kernels, so as to maximize + // vectorization after packing. Any n0 < NR(64) can be expressed + // as n0 = 48 + n` / n0 = 32 + n` / n0 = 16 + n`, where n` < 16. + dim_t n0_48 = n_partial_pieces / 48; + dim_t n0_32 = n_partial_pieces / 32; + dim_t n0_16 = n_partial_pieces / 16; + + if ( n0_48 == 1 ) + { + packb_nr48_u8s8s32o32 + ( + ( pack_b_buffer_u8s8s32o32 + ( n_full_pieces_loop_limit * KC_updated ) ), + ( b + n_full_pieces_loop_limit ), ldb, KC + ); + + n0_partial_pack = 48; + } + else if ( n0_32 == 1 ) + { + packb_nr32_u8s8s32o32 + ( + ( pack_b_buffer_u8s8s32o32 + ( n_full_pieces_loop_limit * KC_updated ) ), + ( b + n_full_pieces_loop_limit ), ldb, KC + ); + + n0_partial_pack = 32; + } + else if ( n0_16 == 1 ) + { + packb_nr16_u8s8s32o32 + ( + ( pack_b_buffer_u8s8s32o32 + ( n_full_pieces_loop_limit * KC_updated ) ), + ( b + n_full_pieces_loop_limit ), ldb, KC + ); + + n0_partial_pack = 16; + } + + if ( n0_partial_rem > 0 ) + { + packb_nrlt16_u8s8s32o32 + ( + ( pack_b_buffer_u8s8s32o32 + ( n_full_pieces_loop_limit * KC_updated ) + + ( n0_partial_pack * KC_updated ) ), + ( b + n_full_pieces_loop_limit + n0_partial_pack ), ldb, KC, + n0_partial_rem + ); + } + } + *rs_b = NR * 4; + *cs_b = NR; +} + +void packb_nr48_u8s8s32o32 + ( + int8_t* pack_b_buffer_u8s8s32o32, + const int8_t* b, + const dim_t ldb, + const dim_t KC + ) +{ + dim_t kr_new = 0; + + dim_t k_full_pieces_blks = KC / 4; + dim_t k_full_pieces = k_full_pieces_blks * 4; + dim_t k_partial_pieces = KC % 4; + + __m256i a0_32; + __m256i b0_32; + __m256i c0_32; + __m256i d0_32; + __m256i a01_32; + __m256i c01_32; + __m512i a0_zmm; + __m512i b0_zmm; + __m128i a0_16; + __m128i b0_16; + __m128i c0_16; + __m128i d0_16; + __m128i a01_16; + __m128i c01_16; + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 4 ) + { + // Rearrange for vpdpbusd, read 4 rows from B with 32 elements in each row. + a0_32 = _mm256_loadu_epi8( b + ( ldb * ( kr + 0 ) ) ); + b0_32 = _mm256_loadu_epi8( b + ( ldb * ( kr + 1 ) ) ); + c0_32 = _mm256_loadu_epi8( b + ( ldb * ( kr + 2 ) ) ); + d0_32 = _mm256_loadu_epi8( b + ( ldb * ( kr + 3 ) ) ); + + a01_32 = _mm256_unpacklo_epi8( a0_32, b0_32 ); + a0_32 = _mm256_unpackhi_epi8( a0_32, b0_32 ); + + c01_32 = _mm256_unpacklo_epi8( c0_32, d0_32 ); + c0_32 = _mm256_unpackhi_epi8( c0_32, d0_32 ); + + b0_32 = _mm256_unpacklo_epi16( a01_32, c01_32 ); + a01_32 = _mm256_unpackhi_epi16( a01_32, c01_32 ); + + d0_32 = _mm256_unpacklo_epi16( a0_32, c0_32 ); + c01_32 = _mm256_unpackhi_epi16( a0_32, c0_32 ); + + a0_32 = _mm256_shuffle_i32x4( b0_32, a01_32, 0x0 ); // 0 elem + c0_32 = _mm256_shuffle_i32x4( b0_32, a01_32, 0x3 ); // 2 elem + b0_32 = _mm256_shuffle_i32x4( d0_32, c01_32, 0x0 ); // 1 elem + d0_32 = _mm256_shuffle_i32x4( d0_32, c01_32, 0x3 ); // 3 elem + + a0_zmm = _mm512_castsi256_si512( a0_32 ); + a0_zmm = _mm512_inserti32x8( a0_zmm, b0_32, 0x1 ); + b0_zmm = _mm512_castsi256_si512( c0_32 ); + b0_zmm = _mm512_inserti32x8( b0_zmm, d0_32, 0x1 ); + + // First 4x32 elements. + _mm512_storeu_epi64( pack_b_buffer_u8s8s32o32 + ( ( kr_new + 0 ) * NR ), a0_zmm ); + _mm512_storeu_epi64( pack_b_buffer_u8s8s32o32 + ( ( kr_new + 1 ) * NR ), b0_zmm ); + + // Rearrange for vpdpbusd, read 4 rows from B with next 16 elements in each row. 
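+		// Descriptive note: a 48-wide panel is emitted as a 32-wide group (the
+		// two ZMM stores above) plus this 16-wide group (one ZMM store), so
+		// kr_new advances by 3 per group of 4 k rows instead of 4.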
+ a0_16 = _mm_loadu_epi8( b + ( ldb * ( kr + 0 ) ) + ( 32 ) ); + b0_16 = _mm_loadu_epi8( b + ( ldb * ( kr + 1 ) ) + ( 32 ) ); + c0_16 = _mm_loadu_epi8( b + ( ldb * ( kr + 2 ) ) + ( 32 ) ); + d0_16 = _mm_loadu_epi8( b + ( ldb * ( kr + 3 ) ) + ( 32 ) ); + + a01_16 = _mm_unpacklo_epi8( a0_16, b0_16 ); + a0_16 = _mm_unpackhi_epi8( a0_16, b0_16 ); + + c01_16 = _mm_unpacklo_epi8( c0_16, d0_16 ); + c0_16 = _mm_unpackhi_epi8( c0_16, d0_16 ); + + b0_16 = _mm_unpacklo_epi16( a01_16, c01_16 ); // 0 elem + a01_16 = _mm_unpackhi_epi16( a01_16, c01_16 ); // 1 elem + d0_16 = _mm_unpacklo_epi16( a0_16, c0_16 ); // 2 elem + c01_16 = _mm_unpackhi_epi16( a0_16, c0_16 ); // 3 elem + + a0_zmm = _mm512_castsi128_si512( b0_16 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, a01_16, 0x1 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, d0_16, 0x2 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, c01_16, 0x3 ); + + // Last 4x16 elements. + _mm512_storeu_epi64( pack_b_buffer_u8s8s32o32 + ( ( kr_new + 2 ) * NR ), a0_zmm ); + + // The 4th 16byte chunk will be ignored, since its not part of the original data, + // but is here due to the packing in 4 16byte chunks format. + kr_new += 3; + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + if ( k_partial_pieces == 3 ) + { + a0_32 = _mm256_loadu_epi8( b + ( ldb * ( k_full_pieces + 0 ) ) ); + b0_32 = _mm256_loadu_epi8( b + ( ldb * ( k_full_pieces + 1 ) ) ); + c0_32 = _mm256_loadu_epi8( b + ( ldb * ( k_full_pieces + 2 ) ) ); + d0_32 = _mm256_setzero_si256(); + + a0_16 = _mm_loadu_epi8( b + ( ldb * ( k_full_pieces + 0 ) ) + ( 32 ) ); + b0_16 = _mm_loadu_epi8( b + ( ldb * ( k_full_pieces + 1 ) ) + ( 32 ) ); + c0_16 = _mm_loadu_epi8( b + ( ldb * ( k_full_pieces + 2 ) ) + ( 32 ) ); + d0_16 = _mm_setzero_si128(); + + } + else if( k_partial_pieces == 2 ) + { + a0_32 = _mm256_loadu_epi8( b + ( ldb * ( k_full_pieces + 0 ) ) ); + b0_32 = _mm256_loadu_epi8( b + ( ldb * ( k_full_pieces + 1 ) ) ); + c0_32 = _mm256_setzero_si256(); + d0_32 = _mm256_setzero_si256(); + + a0_16 = _mm_loadu_epi8( b + ( ldb * ( k_full_pieces + 0 ) ) + ( 32 ) ); + b0_16 = _mm_loadu_epi8( b + ( ldb * ( k_full_pieces + 1 ) ) + ( 32 ) ); + c0_16 = _mm_setzero_si128(); + d0_16 = _mm_setzero_si128(); + } + else //k_partial_pieces == 1 + { + a0_32 = _mm256_loadu_epi8( b + ( ldb * ( k_full_pieces + 0 ) ) ); + b0_32 = _mm256_setzero_si256(); + c0_32 = _mm256_setzero_si256(); + d0_32 = _mm256_setzero_si256(); + + a0_16 = _mm_loadu_epi8( b + ( ldb * ( k_full_pieces + 0 ) ) + ( 32 ) ); + b0_16 = _mm_setzero_si128(); + c0_16 = _mm_setzero_si128(); + d0_16 = _mm_setzero_si128(); + } + + a01_32 = _mm256_unpacklo_epi8( a0_32, b0_32 ); + a0_32 = _mm256_unpackhi_epi8( a0_32, b0_32 ); + + c01_32 = _mm256_unpacklo_epi8( c0_32, d0_32 ); + c0_32 = _mm256_unpackhi_epi8( c0_32, d0_32 ); + + b0_32 = _mm256_unpacklo_epi16( a01_32, c01_32 ); + a01_32 = _mm256_unpackhi_epi16( a01_32, c01_32 ); + + d0_32 = _mm256_unpacklo_epi16( a0_32, c0_32 ); + c01_32 = _mm256_unpackhi_epi16( a0_32, c0_32 ); + + a0_32 = _mm256_shuffle_i32x4( b0_32, a01_32, 0x0 ); // 0 elem + c0_32 = _mm256_shuffle_i32x4( b0_32, a01_32, 0x3 ); // 2 elem + b0_32 = _mm256_shuffle_i32x4( d0_32, c01_32, 0x0 ); // 1 elem + d0_32 = _mm256_shuffle_i32x4( d0_32, c01_32, 0x3 ); // 3 elem + + a0_zmm = _mm512_castsi256_si512( a0_32 ); + a0_zmm = _mm512_inserti32x8( a0_zmm, b0_32, 0x1 ); + b0_zmm = _mm512_castsi256_si512( c0_32 ); + b0_zmm = _mm512_inserti32x8( b0_zmm, d0_32, 0x1 ); + + // First 4x32 elements. 
+ _mm512_storeu_epi64( pack_b_buffer_u8s8s32o32 + ( ( kr_new + 0 ) * NR ), a0_zmm ); + _mm512_storeu_epi64( pack_b_buffer_u8s8s32o32 + ( ( kr_new + 1 ) * NR ), b0_zmm ); + + a01_16 = _mm_unpacklo_epi8( a0_16, b0_16 ); + a0_16 = _mm_unpackhi_epi8( a0_16, b0_16 ); + + c01_16 = _mm_unpacklo_epi8( c0_16, d0_16 ); + c0_16 = _mm_unpackhi_epi8( c0_16, d0_16 ); + + b0_16 = _mm_unpacklo_epi16( a01_16, c01_16 ); // 0 elem + a01_16 = _mm_unpackhi_epi16( a01_16, c01_16 ); // 1 elem + d0_16 = _mm_unpacklo_epi16( a0_16, c0_16 ); // 2 elem + c01_16 = _mm_unpackhi_epi16( a0_16, c0_16 ); // 3 elem + + a0_zmm = _mm512_castsi128_si512( b0_16 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, a01_16, 0x1 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, d0_16, 0x2 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, c01_16, 0x3 ); + + // Last 4x16 elements. + _mm512_storeu_epi64( pack_b_buffer_u8s8s32o32 + ( ( kr_new + 2 ) * NR ), a0_zmm ); + } +} + +void packb_nr32_u8s8s32o32 + ( + int8_t* pack_b_buffer_u8s8s32o32, + const int8_t* b, + const dim_t ldb, + const dim_t KC + ) +{ + dim_t kr_new = 0; + + dim_t k_full_pieces_blks = KC / 4; + dim_t k_full_pieces = k_full_pieces_blks * 4; + dim_t k_partial_pieces = KC % 4; + + __m256i a0_32; + __m256i b0_32; + __m256i c0_32; + __m256i d0_32; + __m256i a01_32; + __m256i c01_32; + __m512i a0_zmm; + __m512i b0_zmm; + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 4 ) + { + // Rearrange for vpdpbusd, read 4 rows from B with 32 elements in each row. + a0_32 = _mm256_loadu_epi8( b + ( ldb * ( kr + 0 ) ) ); + b0_32 = _mm256_loadu_epi8( b + ( ldb * ( kr + 1 ) ) ); + c0_32 = _mm256_loadu_epi8( b + ( ldb * ( kr + 2 ) ) ); + d0_32 = _mm256_loadu_epi8( b + ( ldb * ( kr + 3 ) ) ); + + a01_32 = _mm256_unpacklo_epi8( a0_32, b0_32 ); + a0_32 = _mm256_unpackhi_epi8( a0_32, b0_32 ); + + c01_32 = _mm256_unpacklo_epi8( c0_32, d0_32 ); + c0_32 = _mm256_unpackhi_epi8( c0_32, d0_32 ); + + b0_32 = _mm256_unpacklo_epi16( a01_32, c01_32 ); + a01_32 = _mm256_unpackhi_epi16( a01_32, c01_32 ); + + d0_32 = _mm256_unpacklo_epi16( a0_32, c0_32 ); + c01_32 = _mm256_unpackhi_epi16( a0_32, c0_32 ); + + a0_32 = _mm256_shuffle_i32x4( b0_32, a01_32, 0x0 ); // 0 elem + c0_32 = _mm256_shuffle_i32x4( b0_32, a01_32, 0x3 ); // 2 elem + b0_32 = _mm256_shuffle_i32x4( d0_32, c01_32, 0x0 ); // 1 elem + d0_32 = _mm256_shuffle_i32x4( d0_32, c01_32, 0x3 ); // 3 elem + + a0_zmm = _mm512_castsi256_si512( a0_32 ); + a0_zmm = _mm512_inserti32x8( a0_zmm, b0_32, 0x1 ); + b0_zmm = _mm512_castsi256_si512( c0_32 ); + b0_zmm = _mm512_inserti32x8( b0_zmm, d0_32, 0x1 ); + + // First 4x32 elements. + _mm512_storeu_epi64( pack_b_buffer_u8s8s32o32 + ( ( kr_new + 0 ) * NR ), a0_zmm ); + _mm512_storeu_epi64( pack_b_buffer_u8s8s32o32 + ( ( kr_new + 1 ) * NR ), b0_zmm ); + + // The 3rd and 4th 16byte chunk will be ignored, since its not part of the original data, + // but is here due to the packing in 4 16byte chunks format. + kr_new += 2; + } + // Handle k remainder. 
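+	// Descriptive note: any trailing k rows (KC % 4) are zero-filled below so
+	// the packed buffer always holds complete groups of 4 k bytes; the caller
+	// (packb_nr64_u8s8s32o32) accounts for this padding via KC_updated.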
+ if ( k_partial_pieces > 0 ) + { + if ( k_partial_pieces == 3 ) + { + a0_32 = _mm256_loadu_epi8( b + ( ldb * ( k_full_pieces + 0 ) ) ); + b0_32 = _mm256_loadu_epi8( b + ( ldb * ( k_full_pieces + 1 ) ) ); + c0_32 = _mm256_loadu_epi8( b + ( ldb * ( k_full_pieces + 2 ) ) ); + d0_32 = _mm256_setzero_si256(); + + } + else if( k_partial_pieces == 2 ) + { + a0_32 = _mm256_loadu_epi8( b + ( ldb * ( k_full_pieces + 0 ) ) ); + b0_32 = _mm256_loadu_epi8( b + ( ldb * ( k_full_pieces + 1 ) ) ); + c0_32 = _mm256_setzero_si256(); + d0_32 = _mm256_setzero_si256(); + } + else //k_partial_pieces == 1 + { + a0_32 = _mm256_loadu_epi8( b + ( ldb * ( k_full_pieces + 0 ) ) ); + b0_32 = _mm256_setzero_si256(); + c0_32 = _mm256_setzero_si256(); + d0_32 = _mm256_setzero_si256(); + } + + a01_32 = _mm256_unpacklo_epi8( a0_32, b0_32 ); + a0_32 = _mm256_unpackhi_epi8( a0_32, b0_32 ); + + c01_32 = _mm256_unpacklo_epi8( c0_32, d0_32 ); + c0_32 = _mm256_unpackhi_epi8( c0_32, d0_32 ); + + b0_32 = _mm256_unpacklo_epi16( a01_32, c01_32 ); + a01_32 = _mm256_unpackhi_epi16( a01_32, c01_32 ); + + d0_32 = _mm256_unpacklo_epi16( a0_32, c0_32 ); + c01_32 = _mm256_unpackhi_epi16( a0_32, c0_32 ); + + a0_32 = _mm256_shuffle_i32x4( b0_32, a01_32, 0x0 ); // 0 elem + c0_32 = _mm256_shuffle_i32x4( b0_32, a01_32, 0x3 ); // 2 elem + b0_32 = _mm256_shuffle_i32x4( d0_32, c01_32, 0x0 ); // 1 elem + d0_32 = _mm256_shuffle_i32x4( d0_32, c01_32, 0x3 ); // 3 elem + + a0_zmm = _mm512_castsi256_si512( a0_32 ); + a0_zmm = _mm512_inserti32x8( a0_zmm, b0_32, 0x1 ); + b0_zmm = _mm512_castsi256_si512( c0_32 ); + b0_zmm = _mm512_inserti32x8( b0_zmm, d0_32, 0x1 ); + + // First 4x32 elements. + _mm512_storeu_epi64( pack_b_buffer_u8s8s32o32 + ( ( kr_new + 0 ) * NR ), a0_zmm ); + _mm512_storeu_epi64( pack_b_buffer_u8s8s32o32 + ( ( kr_new + 1 ) * NR ), b0_zmm ); + } +} + +void packb_nr16_u8s8s32o32 + ( + int8_t* pack_b_buffer_u8s8s32o32, + const int8_t* b, + const dim_t ldb, + const dim_t KC + ) +{ + dim_t kr_new = 0; + + dim_t k_full_pieces_blks = KC / 4; + dim_t k_full_pieces = k_full_pieces_blks * 4; + dim_t k_partial_pieces = KC % 4; + + __m128i a0_16; + __m128i b0_16; + __m128i c0_16; + __m128i d0_16; + __m128i a01_16; + __m128i c01_16; + __m512i a0_zmm; + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 4 ) + { + // Rearrange for vpdpbusd, read 4 rows from B with next 16 elements in each row. + a0_16 = _mm_loadu_epi8( b + ( ldb * ( kr + 0 ) ) ); + b0_16 = _mm_loadu_epi8( b + ( ldb * ( kr + 1 ) ) ); + c0_16 = _mm_loadu_epi8( b + ( ldb * ( kr + 2 ) ) ); + d0_16 = _mm_loadu_epi8( b + ( ldb * ( kr + 3 ) ) ); + + a01_16 = _mm_unpacklo_epi8( a0_16, b0_16 ); + a0_16 = _mm_unpackhi_epi8( a0_16, b0_16 ); + + c01_16 = _mm_unpacklo_epi8( c0_16, d0_16 ); + c0_16 = _mm_unpackhi_epi8( c0_16, d0_16 ); + + b0_16 = _mm_unpacklo_epi16( a01_16, c01_16 ); // 0 elem + a01_16 = _mm_unpackhi_epi16( a01_16, c01_16 ); // 1 elem + d0_16 = _mm_unpacklo_epi16( a0_16, c0_16 ); // 2 elem + c01_16 = _mm_unpackhi_epi16( a0_16, c0_16 ); // 3 elem + + a0_zmm = _mm512_castsi128_si512( b0_16 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, a01_16, 0x1 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, d0_16, 0x2 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, c01_16, 0x3 ); + + // Last 4x16 elements. + _mm512_storeu_epi64( pack_b_buffer_u8s8s32o32 + ( ( kr_new + 0 ) * NR ), a0_zmm ); + + // The 2nd, 3rd, and 4th 16byte chunk will be ignored, since its not part of the original data, + // but is here due to the packing in 4 16byte chunks format. + kr_new += 1; + } + // Handle k remainder. 
+ if ( k_partial_pieces > 0 ) + { + if ( k_partial_pieces == 3 ) + { + a0_16 = _mm_loadu_epi8( b + ( ldb * ( k_full_pieces + 0 ) ) ); + b0_16 = _mm_loadu_epi8( b + ( ldb * ( k_full_pieces + 1 ) ) ); + c0_16 = _mm_loadu_epi8( b + ( ldb * ( k_full_pieces + 2 ) ) ); + d0_16 = _mm_setzero_si128(); + + } + else if( k_partial_pieces == 2 ) + { + a0_16 = _mm_loadu_epi8( b + ( ldb * ( k_full_pieces + 0 ) ) ); + b0_16 = _mm_loadu_epi8( b + ( ldb * ( k_full_pieces + 1 ) ) ); + c0_16 = _mm_setzero_si128(); + d0_16 = _mm_setzero_si128(); + } + else //k_partial_pieces == 1 + { + a0_16 = _mm_loadu_epi8( b + ( ldb * ( k_full_pieces + 0 ) ) ); + b0_16 = _mm_setzero_si128(); + c0_16 = _mm_setzero_si128(); + d0_16 = _mm_setzero_si128(); + } + + a01_16 = _mm_unpacklo_epi8( a0_16, b0_16 ); + a0_16 = _mm_unpackhi_epi8( a0_16, b0_16 ); + + c01_16 = _mm_unpacklo_epi8( c0_16, d0_16 ); + c0_16 = _mm_unpackhi_epi8( c0_16, d0_16 ); + + b0_16 = _mm_unpacklo_epi16( a01_16, c01_16 ); // 0 elem + a01_16 = _mm_unpackhi_epi16( a01_16, c01_16 ); // 1 elem + d0_16 = _mm_unpacklo_epi16( a0_16, c0_16 ); // 2 elem + c01_16 = _mm_unpackhi_epi16( a0_16, c0_16 ); // 3 elem + + __m512i a0_zmm = _mm512_castsi128_si512( b0_16 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, a01_16, 0x1 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, d0_16, 0x2 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, c01_16, 0x3 ); + + // Last 4x16 elements. + _mm512_storeu_epi64( pack_b_buffer_u8s8s32o32 + ( ( kr_new + 0 ) * NR ), a0_zmm ); + } +} + +void packb_nrlt16_u8s8s32o32 + ( + int8_t* pack_b_buffer_u8s8s32o32, + const int8_t* b, + const dim_t ldb, + const dim_t KC, + const dim_t n0_partial_rem + ) +{ + int8_t buf0[16]; + int8_t buf1[16]; + int8_t buf2[16]; + int8_t buf3[16]; + + dim_t kr_new = 0; + + dim_t k_full_pieces_blks = KC / 4; + dim_t k_full_pieces = k_full_pieces_blks * 4; + dim_t k_partial_pieces = KC % 4; + + __m128i a0_16; + __m128i b0_16; + __m128i c0_16; + __m128i d0_16; + __m128i a01_16; + __m128i c01_16; + __m512i a0_zmm; + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 4 ) + { + memcpy( buf0, ( b + ( ldb * ( kr + 0 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) ); + memcpy( buf1, ( b + ( ldb * ( kr + 1 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) ); + memcpy( buf2, ( b + ( ldb * ( kr + 2 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) ); + memcpy( buf3, ( b + ( ldb * ( kr + 3 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) ); + + // Rearrange for vpdpbusd, read 4 rows from B with next 16 elements in each row. + a0_16 = _mm_loadu_epi8( buf0 ); + b0_16 = _mm_loadu_epi8( buf1 ); + c0_16 = _mm_loadu_epi8( buf2 ); + d0_16 = _mm_loadu_epi8( buf3 ); + + a01_16 = _mm_unpacklo_epi8( a0_16, b0_16 ); + a0_16 = _mm_unpackhi_epi8( a0_16, b0_16 ); + + c01_16 = _mm_unpacklo_epi8( c0_16, d0_16 ); + c0_16 = _mm_unpackhi_epi8( c0_16, d0_16 ); + + b0_16 = _mm_unpacklo_epi16( a01_16, c01_16 ); // 0 elem + a01_16 = _mm_unpackhi_epi16( a01_16, c01_16 ); // 1 elem + d0_16 = _mm_unpacklo_epi16( a0_16, c0_16 ); // 2 elem + c01_16 = _mm_unpackhi_epi16( a0_16, c0_16 ); // 3 elem + + a0_zmm = _mm512_castsi128_si512( b0_16 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, a01_16, 0x1 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, d0_16, 0x2 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, c01_16, 0x3 ); + + // Last 4x16 elements. + _mm512_storeu_epi64( pack_b_buffer_u8s8s32o32 + ( ( kr_new + 0 ) * NR ), a0_zmm ); + + // The 2nd, 3rd, and 4th 16byte chunk will be ignored, since its not part of the original data, + // but is here due to the packing in 4 16byte chunks format. 
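+		// Descriptive note: because fewer than 16 columns are live here, each
+		// row was first memcpy'd into a local 16-byte buffer (buf0..buf3) so
+		// the full-width _mm_loadu_epi8 loads never read past the end of the
+		// source row.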
+ kr_new += 1; + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + if ( k_partial_pieces == 3 ) + { + memcpy( buf0, ( b + ( ldb * ( k_full_pieces + 0 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) ); + memcpy( buf1, ( b + ( ldb * ( k_full_pieces + 1 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) ); + memcpy( buf2, ( b + ( ldb * ( k_full_pieces + 2 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) ); + + a0_16 = _mm_loadu_epi8( buf0 ); + b0_16 = _mm_loadu_epi8( buf1 ); + c0_16 = _mm_loadu_epi8( buf2 ); + d0_16 = _mm_setzero_si128(); + + } + else if( k_partial_pieces == 2 ) + { + memcpy( buf0, ( b + ( ldb * ( k_full_pieces + 0 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) ); + memcpy( buf1, ( b + ( ldb * ( k_full_pieces + 1 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) ); + + a0_16 = _mm_loadu_epi8( buf0 ); + b0_16 = _mm_loadu_epi8( buf1 ); + c0_16 = _mm_setzero_si128(); + d0_16 = _mm_setzero_si128(); + } + else //k_partial_pieces == 1 + { + memcpy( buf0, ( b + ( ldb * ( k_full_pieces + 0 ) ) ), ( n0_partial_rem * sizeof( int8_t ) ) ); + + a0_16 = _mm_loadu_epi8( buf0 ); + b0_16 = _mm_setzero_si128(); + c0_16 = _mm_setzero_si128(); + d0_16 = _mm_setzero_si128(); + } + + a01_16 = _mm_unpacklo_epi8( a0_16, b0_16 ); + a0_16 = _mm_unpackhi_epi8( a0_16, b0_16 ); + + c01_16 = _mm_unpacklo_epi8( c0_16, d0_16 ); + c0_16 = _mm_unpackhi_epi8( c0_16, d0_16 ); + + b0_16 = _mm_unpacklo_epi16( a01_16, c01_16 ); // 0 elem + a01_16 = _mm_unpackhi_epi16( a01_16, c01_16 ); // 1 elem + d0_16 = _mm_unpacklo_epi16( a0_16, c0_16 ); // 2 elem + c01_16 = _mm_unpackhi_epi16( a0_16, c0_16 ); // 3 elem + + __m512i a0_zmm = _mm512_castsi128_si512( b0_16 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, a01_16, 0x1 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, d0_16, 0x2 ); + a0_zmm = _mm512_inserti32x4( a0_zmm, c01_16, 0x3 ); + + // Last 4x16 elements. 
+ _mm512_storeu_epi64( pack_b_buffer_u8s8s32o32 + ( ( kr_new + 0 ) * NR ), a0_zmm ); + } +} diff --git a/bench/bench_aocl_gemm/bench_input.txt b/bench/bench_aocl_gemm/bench_input.txt new file mode 100644 index 0000000000..de76f75168 --- /dev/null +++ b/bench/bench_aocl_gemm/bench_input.txt @@ -0,0 +1,377 @@ +i p 480 20 2050 2050 20 20 +i p 481 20 2050 2050 20 20 +i p 482 20 2050 2050 20 20 +i p 483 20 2050 2050 20 20 +i R 484 20 2050 2050 20 20 +i R 485 20 2050 2050 20 20 +i R 480 39 2050 2050 39 39 +i R 481 39 2050 2050 39 39 +i R 482 39 2050 2050 39 39 +i R 483 39 2050 2050 39 39 +i R 484 39 2050 2050 39 39 +i p 485 39 2050 2050 39 39 +i p 480 50 2050 2050 50 50 +i p 481 50 2050 2050 50 50 +i p 482 50 2050 2050 50 50 +i p 483 50 2050 2050 50 50 +i p 484 50 2050 2050 50 50 +i p 485 50 2050 2050 50 50 +i R 480 1108 2050 2050 1108 1108 +i R 481 1108 2050 2050 1108 1108 +i R 482 1108 2050 2050 1108 1108 +i R 483 1108 2050 2050 1108 1108 +i R 484 1108 2050 2050 1108 1108 +i R 485 1108 2050 2050 1108 1108 +i R 480 1127 2050 2050 1127 1127 +i R 481 1127 2050 2050 1127 1127 +i R 482 1127 2050 2050 1127 1127 +i R 483 1127 2050 2050 1127 1127 +i p 484 1127 2050 2050 1127 1127 +i p 485 1127 2050 2050 1127 1127 +i p 480 1138 2050 2050 1138 1138 +i p 481 1138 2050 2050 1138 1138 +i p 482 1138 2050 2050 1138 1138 +i p 483 1138 2050 2050 1138 1138 +i p 484 1138 2050 2050 1138 1138 +i p 485 1138 2050 2050 1138 1138 +i p 1 1 3 3 1 1 +i p 1 9 3 3 9 9 +i p 1 2048 3 3 2048 2048 +i p 1 2048 5192 5192 2048 2048 +i p 9 1 3 3 1 1 +i p 576 1 3500 3500 1 1 +i p 1 1 1 1 1 1 +i p 102 1088 1024 1024 1088 1088 +i p 102 2048 1024 1024 2048 2048 +i p 485 656 1024 1024 656 656 +i p 483 656 1024 1024 656 656 +i p 81 128 3 3 128 128 +i p 1022 512 515 515 512 512 +i p 74 512 515 515 512 512 +i p 253 2048 515 515 2048 2048 +i p 8192 1040 515 515 1040 1040 +i p 10 1029 515 515 1029 1029 +i p 24 1040 2050 2050 1040 1040 +i p 1024 1029 2050 2050 1029 1029 +i p 480 660 2050 2050 660 660 +i p 481 660 2050 2050 660 660 +i p 482 660 2050 2050 660 660 +i p 483 660 2050 2050 660 660 +i p 484 660 2050 2050 660 660 +i p 485 660 2050 2050 660 660 +i p 480 679 2050 2050 679 679 +i p 481 679 2050 2050 679 679 +i p 482 679 2050 2050 679 679 +i p 483 679 2050 2050 679 679 +i p 484 679 2050 2050 679 679 +i p 485 679 2050 2050 679 679 +i p 480 690 2050 2050 690 690 +i p 481 690 2050 2050 690 690 +i p 482 690 2050 2050 690 690 +i p 483 690 2050 2050 690 690 +i p 484 690 2050 2050 690 690 +i p 485 690 2050 2050 690 690 +i p 480 660 2048 2048 660 660 +i p 481 660 2048 2048 660 660 +i p 482 660 2048 2048 660 660 +i p 483 660 2048 2048 660 660 +i p 484 660 2048 2048 660 660 +i p 485 660 2048 2048 660 660 +i p 480 679 2048 2048 679 679 +i p 481 679 2048 2048 679 679 +i p 482 679 2048 2048 679 679 +i p 483 679 2048 2048 679 679 +i p 484 679 2048 2048 679 679 +i p 485 679 2048 2048 679 679 +i p 480 690 2048 2048 690 690 +i p 481 690 2048 2048 690 690 +i p 482 690 2048 2048 690 690 +i p 483 690 2048 2048 690 690 +i p 484 690 2048 2048 690 690 +i p 485 690 2048 2048 690 690 +i p 480 656 1024 1024 656 656 +i p 480 128 3 3 128 128 +i p 1024 512 515 515 512 512 +i p 1024 2048 1024 1024 2048 2048 +i p 1024 2048 515 515 2048 2048 +i p 1024 1040 515 515 1040 1040 +i p 5 1029 515 515 1029 1029 +i p 1024 1029 515 515 1029 1029 +i p 1024 1040 2050 2050 1040 1040 +i p 1029 1029 2050 2050 1029 1029 +i R 480 646 2050 2050 646 646 +i R 481 646 2050 2050 646 646 +i R 482 646 2050 2050 646 646 +i R 483 646 2050 2050 646 646 +i R 484 646 2050 2050 646 646 +i R 485 
646 2050 2050 646 646 +i R 481 656 2050 2050 656 656 +i R 482 656 2050 2050 656 656 +i R 483 656 2050 2050 656 656 +i R 484 656 2050 2050 656 656 +i p 485 656 2050 2050 656 656 +i p 480 672 2050 2050 672 672 +i p 481 672 2050 2050 672 672 +i p 482 672 2050 2050 672 672 +i p 483 672 2050 2050 672 672 +i p 484 672 2050 2050 672 672 +i p 485 672 2050 2050 672 672 +i p 480 688 2050 2050 688 688 +i p 481 688 2050 2050 688 688 +i r 482 688 2050 2050 688 688 +i r 483 688 2050 2050 688 688 +i r 484 688 2050 2050 688 688 +i r 485 688 2050 2050 688 688 +i r 1024 512 64 64 512 512 +i r 16 256 512 512 256 256 +i r 480 640 512 512 640 640 +i r 64 768 512 512 768 768 +i r 128 128 128 128 128 128 +i r 1024 64 512 512 64 64 +i r 1024 256 32 32 256 256 +i r 1024 512 64 64 512 512 +i r 480 640 512 512 640 640 +i p 1024 32 256 256 32 32 +i P 1024 64 512 512 64 64 +i P 64 800 320 320 800 800 +i P 64 768 512 512 768 768 +i P 16 256 512 512 256 256 +i P 128 128 128 128 128 128 +i P 256 512 256 256 512 512 +i P 1024 1024 1024 1024 1024 1024 +i P 480 640 1024 1024 640 640 +i P 480 640 256 256 640 640 +i P 8 64 32 32 64 64 +i P 9 64 32 32 64 64 +i P 10 128 64 64 128 128 +i P 8 8 8 8 8 8 +i P 12 12 12 12 12 12 +i P 25 25 25 25 25 25 +i P 25 25 20 20 25 25 +f p 480 20 2050 2050 20 20 +f p 481 20 2050 2050 20 20 +f p 482 20 2050 2050 20 20 +f p 483 20 2050 2050 20 20 +f R 484 20 2050 2050 20 20 +f R 485 20 2050 2050 20 20 +f R 480 39 2050 2050 39 39 +f R 481 39 2050 2050 39 39 +f R 482 39 2050 2050 39 39 +f R 483 39 2050 2050 39 39 +f R 484 39 2050 2050 39 39 +f p 485 39 2050 2050 39 39 +f p 480 50 2050 2050 50 50 +f p 481 50 2050 2050 50 50 +f p 482 50 2050 2050 50 50 +f p 483 50 2050 2050 50 50 +f p 484 50 2050 2050 50 50 +f p 485 50 2050 2050 50 50 +f R 480 1108 2050 2050 1108 1108 +f R 481 1108 2050 2050 1108 1108 +f R 482 1108 2050 2050 1108 1108 +f R 483 1108 2050 2050 1108 1108 +f R 484 1108 2050 2050 1108 1108 +f R 485 1108 2050 2050 1108 1108 +f R 480 1127 2050 2050 1127 1127 +f R 481 1127 2050 2050 1127 1127 +f R 482 1127 2050 2050 1127 1127 +f R 483 1127 2050 2050 1127 1127 +f p 484 1127 2050 2050 1127 1127 +f p 485 1127 2050 2050 1127 1127 +f p 480 1138 2050 2050 1138 1138 +f p 481 1138 2050 2050 1138 1138 +f p 482 1138 2050 2050 1138 1138 +f p 483 1138 2050 2050 1138 1138 +f p 484 1138 2050 2050 1138 1138 +f p 485 1138 2050 2050 1138 1138 +f p 1 1 3 3 1 1 +f p 1 9 3 3 9 9 +f p 1 2048 3 3 2048 2048 +f p 1 2048 5192 5192 2048 2048 +f p 9 1 3 3 1 1 +f p 576 1 3500 3500 1 1 +f p 1 1 1 1 1 1 +f p 102 1088 1024 1024 1088 1088 +f p 102 2048 1024 1024 2048 2048 +f p 485 656 1024 1024 656 656 +f p 483 656 1024 1024 656 656 +f p 81 128 3 3 128 128 +f p 1022 512 515 515 512 512 +f p 74 512 515 515 512 512 +f p 253 2048 515 515 2048 2048 +f p 8192 1040 515 515 1040 1040 +f p 10 1029 515 515 1029 1029 +f p 24 1040 2050 2050 1040 1040 +f p 1024 1029 2050 2050 1029 1029 +f p 480 660 2050 2050 660 660 +f p 481 660 2050 2050 660 660 +f p 482 660 2050 2050 660 660 +f p 483 660 2050 2050 660 660 +f p 484 660 2050 2050 660 660 +f p 485 660 2050 2050 660 660 +f p 480 679 2050 2050 679 679 +f p 481 679 2050 2050 679 679 +f p 482 679 2050 2050 679 679 +f p 483 679 2050 2050 679 679 +f p 484 679 2050 2050 679 679 +f p 485 679 2050 2050 679 679 +f p 480 690 2050 2050 690 690 +f p 481 690 2050 2050 690 690 +f p 482 690 2050 2050 690 690 +f p 483 690 2050 2050 690 690 +f p 484 690 2050 2050 690 690 +f p 485 690 2050 2050 690 690 +f p 480 660 2048 2048 660 660 +f p 481 660 2048 2048 660 660 +f p 482 660 2048 2048 660 660 +f p 483 
660 2048 2048 660 660 +f p 484 660 2048 2048 660 660 +f p 485 660 2048 2048 660 660 +f p 480 679 2048 2048 679 679 +f p 481 679 2048 2048 679 679 +f p 482 679 2048 2048 679 679 +f p 483 679 2048 2048 679 679 +f p 484 679 2048 2048 679 679 +f p 485 679 2048 2048 679 679 +f p 480 690 2048 2048 690 690 +f p 481 690 2048 2048 690 690 +f p 482 690 2048 2048 690 690 +f p 483 690 2048 2048 690 690 +f p 484 690 2048 2048 690 690 +f p 485 690 2048 2048 690 690 +f p 480 656 1024 1024 656 656 +f p 480 128 3 3 128 128 +f p 1024 512 515 515 512 512 +f p 1024 2048 1024 1024 2048 2048 +f p 1024 2048 515 515 2048 2048 +f p 1024 1040 515 515 1040 1040 +f p 5 1029 515 515 1029 1029 +f p 1024 1029 515 515 1029 1029 +f p 1024 1040 2050 2050 1040 1040 +f p 1029 1029 2050 2050 1029 1029 +f R 480 646 2050 2050 646 646 +f R 481 646 2050 2050 646 646 +f R 482 646 2050 2050 646 646 +f R 483 646 2050 2050 646 646 +f R 484 646 2050 2050 646 646 +f R 485 646 2050 2050 646 646 +f R 481 656 2050 2050 656 656 +f R 482 656 2050 2050 656 656 +f R 483 656 2050 2050 656 656 +f R 484 656 2050 2050 656 656 +f p 485 656 2050 2050 656 656 +f p 480 672 2050 2050 672 672 +f p 481 672 2050 2050 672 672 +f p 482 672 2050 2050 672 672 +f p 483 672 2050 2050 672 672 +f p 484 672 2050 2050 672 672 +f p 485 672 2050 2050 672 672 +f p 480 688 2050 2050 688 688 +f p 481 688 2050 2050 688 688 +f r 482 688 2050 2050 688 688 +f r 483 688 2050 2050 688 688 +f r 484 688 2050 2050 688 688 +f r 485 688 2050 2050 688 688 +f r 1024 512 64 64 512 512 +f r 16 256 512 512 256 256 +f r 480 640 512 512 640 640 +f r 64 768 512 512 768 768 +f r 128 128 128 128 128 128 +f r 1024 64 512 512 64 64 +f r 1024 256 32 32 256 256 +f r 1024 512 64 64 512 512 +f r 480 640 512 512 640 640 +f p 1024 32 256 256 32 32 +f P 1024 64 512 512 64 64 +f P 64 800 320 320 800 800 +f P 64 768 512 512 768 768 +f P 16 256 512 512 256 256 +f P 128 128 128 128 128 128 +f P 256 512 256 256 512 512 +f P 1024 1024 1024 1024 1024 1024 +f P 480 640 1024 1024 640 640 +f P 480 640 256 256 640 640 +f P 8 64 32 32 64 64 +f P 9 64 32 32 64 64 +f P 10 128 64 64 128 128 +f P 8 8 8 8 8 8 +f P 12 12 12 12 12 12 +f P 25 25 25 25 25 25 +f P 25 25 20 20 25 25 +i r 4096 256 5 5 256 256 +i r 3000 256 128 128 256 256 +i r 4096 1024 512 512 1024 1024 +i r 144 256 5 5 256 256 +i r 144 256 128 128 256 256 +i r 144 1024 512 512 1024 1024 +i r 480 688 256 256 688 688 +i r 480 640 512 512 640 640 +i r 480 640 1024 1024 640 640 +i r 64 800 320 320 800 800 +i r 64 768 512 512 768 768 +i r 16 256 512 512 256 256 +i r 128 128 128 128 128 128 +i r 256 512 256 256 512 512 +i r 1024 1024 1024 1024 1024 1024 +i r 1024 32 256 256 32 32 +i r 1024 64 512 512 64 64 +i r 1024 256 32 32 256 256 +i r 1024 512 64 64 512 512 +i r 512 32 256 256 32 32 +i r 512 768 512 512 768 768 +i r 512 256 32 32 256 256 +i r 512 512 64 64 512 512 +i r 512 256 768 768 256 256 +i r 768 768 1024 1024 768 768 +i r 768 768 768 768 768 768 +i r 2048 2048 2048 2048 2048 2048 +i r 4096 4096 4096 4096 4096 4096 +f r 4096 256 5 5 256 256 +f r 3000 256 128 128 256 256 +f r 4096 1024 512 512 1024 1024 +f r 144 256 5 5 256 256 +f r 144 256 128 128 256 256 +f r 144 1024 512 512 1024 1024 +f r 480 688 256 256 688 688 +f r 480 640 512 512 640 640 +f r 480 640 1024 1024 640 640 +f r 64 800 320 320 800 800 +f r 64 768 512 512 768 768 +f r 16 256 512 512 256 256 +f r 128 128 128 128 128 128 +f r 256 512 256 256 512 512 +f r 1024 1024 1024 1024 1024 1024 +f r 1024 32 256 256 32 32 +f r 1024 64 512 512 64 64 +f r 1024 256 32 32 256 256 +f r 1024 512 64 64 
512 512 +f r 512 32 256 256 32 32 +f r 512 768 512 512 768 768 +f r 512 256 32 32 256 256 +f r 512 512 64 64 512 512 +f r 512 256 768 768 256 256 +f r 768 768 1024 1024 768 768 +f r 768 768 768 768 768 768 +f r 2048 2048 2048 2048 2048 2048 +f r 4096 4096 4096 4096 4096 4096 +f r 2048 1024 1024 1024 1024 1024 +f r 2048 4096 1024 1024 4096 4096 +f r 2048 1024 4096 4096 1024 1024 +f r 2048 1024 2 2 1024 1024 +f r 128 1024 1024 1024 1024 1024 +f r 1536 768 768 768 768 768 +f r 1536 3072 768 768 3072 3072 +f r 1536 768 3072 3072 768 768 +f r 1536 768 2 2 768 768 +f r 128 768 768 768 768 768 +f r 1024 8 13 13 8 8 +f r 1024 4 8 8 4 4 +f r 1024 128 355 355 128 128 +f r 1024 64 128 128 64 64 +f r 1024 1 64 64 1 1 +f r 480 1 256 256 1 1 +f r 480 256 512 512 256 256 +f r 480 1024 845 845 1024 1024 +f r 480 512 1024 1024 512 512 +f r 10 17191 128 128 17191 17191 +f r 10 512 256 256 512 512 diff --git a/bench/bench_aocl_gemm/bench_lpgemm.c b/bench/bench_aocl_gemm/bench_lpgemm.c new file mode 100644 index 0000000000..91dd10b966 --- /dev/null +++ b/bench/bench_aocl_gemm/bench_lpgemm.c @@ -0,0 +1,475 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "blis/blis.h" + +// Mode can be one of the follwoing: +// 1. p - performance, used for benchmarks. +// 2. a - accuracy, used to test accuracy/correctness. +// Default value is p, can be modified by passing command line arg. +char bench_mode = 'p'; + +int32_t global_n_repeat = 0; + +#define _XSTR(str) #str +#define XSTR(str) _XSTR(str) + +#define GEN_FUNC_NAME(prototype,ctype) prototype ## ctype + +#define GEN_FILL_ARRAY_FUNC(ctype) \ +void fill_array_ ## ctype ( void* arr, dim_t size ) \ +{ \ + ctype* temp_arr = ( ctype* ) arr; \ + for ( dim_t i = 0; i < size; ++i ) \ + { \ + temp_arr[i] = ( ctype )( i % 100 ); \ + } \ +} \ + +GEN_FILL_ARRAY_FUNC(uint8_t) +GEN_FILL_ARRAY_FUNC(int8_t) +GEN_FILL_ARRAY_FUNC(float) + +#define GEN_BLIS_MAT_MUL_FUNC(A_type,B_type,C_type,BLAS_SFX) \ +void mat_mul_ ## BLAS_SFX \ + ( \ + char op_t, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + C_type alpha, \ + A_type* a, \ + dim_t lda, \ + B_type* b, \ + dim_t ldb, \ + C_type beta, \ + C_type* c, \ + dim_t ldc \ + ) \ +{ \ + char transa = 'n'; \ + char transb = 'n'; \ + char reordera = 'n'; \ + char reorderb = 'n'; \ + \ + if ( ( op_t == 'p' ) || ( op_t == 'P' ) ) \ + { \ + /* No reordering of B.*/ \ + reordera = 'n'; \ + reorderb = 'n'; \ + } \ + else if ( ( op_t == 'r' ) || ( op_t == 'R' ) ) \ + { \ + /* Reordered B.*/ \ + reordera = 'n'; \ + reorderb = 'r'; \ + } \ + \ + aocl_gemm_ ## BLAS_SFX( transa, transb, m, n, k, \ + alpha, \ + a, lda, reordera, \ + b, ldb, reorderb, \ + beta, \ + c, ldc ); \ +} \ + +GEN_BLIS_MAT_MUL_FUNC(uint8_t,int8_t,int32_t,u8s8s32os32) +GEN_BLIS_MAT_MUL_FUNC(float,float,float,f32f32f32of32) + +double get_gflops + ( + dim_t m, + dim_t n, + dim_t k, + double runtime + ) +{ + return ( ( 2.0 * m * n * k ) / ( runtime * 1.0e9 ) ); +} + +void print_result + ( + const char* msg, + int32_t n_repeats, + dim_t m, + dim_t n, + dim_t k, + dim_t lda, + dim_t ldb, + dim_t ldc, + double runtime + ) +{ + double gflops = get_gflops( m, n, k, runtime ); + printf("%s m: %ld, n: %ld, k: %ld, lda: %ld, ldb: %ld, ldc: %ld," \ + " GFlops: %f, n_repeats: %d\n", + msg, m, n, k, lda, ldb, ldc, gflops, n_repeats); +} + +#define GEN_MAT_MUL_BENCH_DRV_FUNC(A_type,B_type,C_type,BLAS_SFX) \ +void mat_mul_bench_driver_ ## BLAS_SFX \ + ( \ + char op_t, \ + int32_t n_repeats, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + C_type alpha, \ + A_type* a, \ + 
dim_t lda, \ + B_type* b, \ + dim_t ldb, \ + C_type beta, \ + C_type* c, \ + dim_t ldc \ + ) \ +{ \ + double gflops; \ + double min_time_diff = DBL_MAX; \ + for ( int32_t nr = 0; nr < n_repeats; ++nr ) \ + { \ + if ( bench_mode == 'a' ) \ + { \ + memset( ( void* ) c, 0, sizeof( float ) * m * n ); \ + } \ + \ + struct timespec tstart={0,0}, tend={0,0}; \ + clock_gettime(CLOCK_MONOTONIC, &tstart); \ + \ + GEN_FUNC_NAME(mat_mul_,BLAS_SFX) \ + ( \ + op_t, m, n, k, \ + alpha, \ + a, lda, \ + b, ldb, \ + beta, \ + c, ldc \ + ); \ + \ + clock_gettime(CLOCK_MONOTONIC, &tend); \ + \ + double diff = \ + ( ( double ) tend.tv_sec + ( 1.0e-9 * tend.tv_nsec ) ) - \ + ( ( double ) tstart.tv_sec + ( 1.0e-9 * tstart.tv_nsec ) ); \ + min_time_diff = ( diff < min_time_diff ) ? diff : min_time_diff; \ + } \ + \ + print_result( XSTR(BLAS_SFX), n_repeats, m, n, k, lda, ldb, ldc, min_time_diff); \ +} \ + +GEN_MAT_MUL_BENCH_DRV_FUNC(uint8_t,int8_t,int32_t,u8s8s32os32) +GEN_MAT_MUL_BENCH_DRV_FUNC(float,float,float,f32f32f32of32) + +#define GEN_MAT_MUL_ACC_CHK_DRV_FUNC(A_type,B_type,C_type,BLAS_SFX) \ +void mat_mul_accuracy_check_driver_ ## BLAS_SFX \ + ( \ + FILE* fout, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + C_type alpha, \ + A_type* a, \ + dim_t lda, \ + B_type* b, \ + dim_t ldb, \ + C_type beta, \ + C_type* c, \ + dim_t ldc, \ + C_type* c_ref, \ + dim_t ldc_ref \ + ) \ +{ \ + for ( dim_t i = 0; i < m; ++i ) \ + { \ + for ( dim_t j = 0; j < n; ++j ) \ + { \ + C_type temp_accum = 0; \ + \ + for ( dim_t p = 0; p < k; ++p) \ + { \ + temp_accum += ( *( a + ( i * lda ) + p ) * *( b + ( p * ldb ) + j ) ); \ + } \ + \ + temp_accum = ( beta * ( * (c_ref + ( ldc_ref * i ) + j ) ) ) \ + + ( alpha * temp_accum ); \ + if ( *( c + ( ldc * i ) + j ) != temp_accum ) \ + { \ + if ( fout ) \ + { \ + fprintf( fout, "%s Failure input m: %ld, n: %ld, k: %ld," \ + " lda: %ld, ldb: %ld, ldc: %ld\n", \ + XSTR(BLAS_SFX), m, n, k, lda, ldb, ldc ); \ + fflush( fout ); \ + } \ + printf("failure, m: %ld, n: %ld, k: %ld\n", i, j, k); \ + goto cleanup_acc; \ + } \ + } \ + } \ +cleanup_acc: \ + return; \ +} \ + +GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,int32_t,u8s8s32os32) +GEN_MAT_MUL_ACC_CHK_DRV_FUNC(float,float,float,f32f32f32of32) + +#define GEN_MAT_MUL_BENCH_MAIN_FUNC(A_type,B_type,C_type,BLAS_SFX) \ +void mat_mul_bench_main_ ## BLAS_SFX \ + ( \ + FILE* fin, \ + FILE* fout, \ + char op_t, \ + int32_t m, \ + int32_t n, \ + int32_t k, \ + int32_t stride_a, \ + int32_t stride_b, \ + int32_t stride_c \ + ) \ +{ \ + if ( ( op_t != 'p' ) && ( op_t != 'P' ) && ( op_t != 'r' ) && ( op_t != 'R' ) ) \ + { \ + printf("The op_t ( 2nd arg in input.txt) is not valid\n"); \ + return; \ + } \ + \ + int32_t n_repeats = bli_max( 30, bli_min(( 3e10 / ( ( int64_t )m * n * k )), 100 )); \ + if ( global_n_repeat > 0 ) \ + { \ + n_repeats = global_n_repeat; \ + } \ + \ + /* Get 64 byte aligned memory.*/ \ + A_type* a = ( A_type* ) bli_malloc_user( sizeof( A_type ) * m * k ); \ + \ + B_type* b = ( B_type* ) bli_malloc_user( sizeof( B_type ) * n * k ); \ + \ + C_type* c = ( C_type* ) bli_malloc_user( sizeof( C_type ) * m * n ); \ + memset( ( void* ) c, 0, sizeof( C_type ) * m * n ); \ + \ + C_type* c_ref = ( C_type* ) bli_malloc_user( sizeof( C_type ) * m * n ); \ + memset( ( void* ) c_ref, 0, sizeof( C_type ) * m * n ); \ + \ + C_type alpha; \ + C_type beta; \ + if ( bench_mode == 'p' ) \ + { \ + alpha = 1; \ + beta = 0; \ + } \ + else if ( bench_mode == 'a' ) \ + { \ + alpha = 2; \ + beta = 9; \ + } \ + \ + GEN_FUNC_NAME(fill_array_,A_type)( a, ( m * k ) 
); \ + GEN_FUNC_NAME(fill_array_,B_type)( b, ( k * n ) ); \ + \ + if ( ( op_t == 'p' ) || ( op_t == 'P' ) ) \ + { \ + /* No reordering of B.*/ \ + GEN_FUNC_NAME(mat_mul_bench_driver_,BLAS_SFX) \ + ( \ + op_t, n_repeats, m, n, k, \ + alpha, \ + a, stride_a, \ + b, stride_b, \ + beta, \ + c, stride_c \ + ); \ + } \ + else if ( ( op_t == 'r' ) || ( op_t == 'R' ) ) \ + { \ + /* Reorder B.*/ \ + siz_t b_reorder_buf_siz_req = \ + GEN_FUNC_NAME(aocl_get_reorder_buf_size_,BLAS_SFX)( 'B', k, n ); \ + \ + B_type* b_reorder = ( B_type* ) bli_malloc_user( b_reorder_buf_siz_req ); \ + GEN_FUNC_NAME(aocl_reorder_,BLAS_SFX)( 'B', b, b_reorder, k, n, stride_b ); \ + \ + GEN_FUNC_NAME(mat_mul_bench_driver_,BLAS_SFX) \ + ( \ + op_t, n_repeats, m, n, k, \ + alpha, \ + a, stride_a, \ + b_reorder, stride_b, \ + beta, \ + c, stride_c \ + ); \ + \ + bli_free_user( b_reorder ); \ + } \ + \ + if ( bench_mode == 'a' ) \ + { \ + printf(" Running accuracy check.\n"); \ + GEN_FUNC_NAME(mat_mul_accuracy_check_driver_,BLAS_SFX) \ + ( \ + fout, m, n, k, \ + alpha, \ + a, stride_a, \ + b, stride_b, \ + beta, \ + c, stride_c, \ + c_ref, stride_c \ + ); \ + } \ + \ + if ( a != NULL ) \ + { \ + bli_free_user( a ); \ + } \ + if ( b != NULL ) \ + { \ + bli_free_user( b ); \ + } \ + if ( c != NULL ) \ + { \ + bli_free_user( c ); \ + } \ + if ( c_ref != NULL ) \ + { \ + bli_free_user( c_ref ); \ + } \ +} \ + +GEN_MAT_MUL_BENCH_MAIN_FUNC(uint8_t,int8_t,int32_t,u8s8s32os32) +GEN_MAT_MUL_BENCH_MAIN_FUNC(float,float,float,f32f32f32of32) + +int main( int argc, char** argv ) +{ + FILE* fin = NULL; + if ( argc < 5 ) + { + printf( "Usage: ./mat_mul -i input.txt -m mode < -n 1000 >\nMode is either a or p." \ + " a is used for accuracy test, whereas p is used for" \ + " performance benchmarking.\nn_repeats can be set" \ + " optionally using -n argument.\n" ); + exit( 1 ); + } + + char* file_name = NULL; + + // Parse CLI arguments. + opterr = 0; + int opt_val; + while ( ( opt_val = getopt( argc, argv, "i:m:n:" ) ) != -1 ) + { + switch ( opt_val ) + { + case 'i': + file_name = optarg; + break; + case 'm': + bench_mode = ( ( ( *optarg ) == 'a' ) || ( ( *optarg ) == 'p' ) ) ? ( *optarg ) : 'p'; + break; + case 'n': + global_n_repeat = ( atoi( optarg ) > 0 ) ? atoi( optarg ) : 0; + break; + default: + break; + } + } + + if ( bench_mode == 'p' ) + { + printf( "Running bench in performance benchmarking mode.\n" ); + } + else if ( bench_mode == 'a' ) + { + printf( "Running bench in accuracy/correctness testing mode.\n" ); + } + + if ( file_name == NULL ) + { + printf( " File name provided is invalid.\n" ); + exit( 1 ); + } + + fin = fopen( file_name, "r" ); + if (fin == NULL) + { + printf( "Error opening the file %s\n", argv[1] ); + exit( 1 ); + } + + FILE* fout = NULL; + + fout = fopen( "lpgemm_accuracy_test_failures.txt", "w" ); + + char op_type_char; + char op_t; + int32_t m, n, k; + int32_t stride_a, stride_b, stride_c; + + const dim_t len_list_omp_cores_for_testing = 6; + const dim_t list_omp_cores_for_testing[6] = { 100, 80, 64, 24, 8, 1 }; + + dim_t core_index = 0; + bool can_run = TRUE; + while ( ( can_run == TRUE ) && ( fseek( fin, 0L, SEEK_SET ) == 0 ) ) + { + if ( bench_mode == 'p' ) + { + can_run = FALSE; + } + else if ( bench_mode == 'a' ) + { + // For accuracy testing, we test accuracy using multiple different + // number of cores. This helps uncover any bugs related to over + // subscription or varying thread factorizations. + // Set current number of cores. 
+ omp_set_num_threads( list_omp_cores_for_testing[core_index] ); + printf( "Accuracy test using %ld threads.\n", + list_omp_cores_for_testing[core_index] ); + + core_index++; + if ( core_index < len_list_omp_cores_for_testing ) + { + can_run = TRUE; + } + else + { + can_run = FALSE; + } + } + + while ( fscanf( fin, "%c %c %d %d %d %d %d %d\n", + &op_type_char, &op_t, &m, &n, &k, + &stride_a, &stride_b, &stride_c ) == 8 ) + { + if ( ( op_type_char == 'i' ) || ( op_type_char == 'I' ) ) + { + GEN_FUNC_NAME(mat_mul_bench_main_,u8s8s32os32) + ( + fin, fout, op_t, + m, n, k, stride_a, stride_b, stride_c + ); + } + else if ( ( op_type_char == 'f' ) || ( op_type_char == 'F' ) ) + { + GEN_FUNC_NAME(mat_mul_bench_main_,f32f32f32of32) + ( + fin, fout, op_t, + m, n, k, stride_a, stride_b, stride_c + ); + } + } + } + + if ( fin ) + { + fclose( fin ); + } + if ( fout ) + { + fclose( fout ); + } + return 0; +} diff --git a/bench/bench_aocl_gemm/data_gen_lpgemm.py b/bench/bench_aocl_gemm/data_gen_lpgemm.py new file mode 100644 index 0000000000..f9c09689f9 --- /dev/null +++ b/bench/bench_aocl_gemm/data_gen_lpgemm.py @@ -0,0 +1,45 @@ +# Initializing global mnk_array.This array will be used to store all mnk values +mnk_array = [] + +max_elem = 2500; +out_file_name = "accuracy_test_data_lpgemm.txt" +# Important mnk generator function.This will generate all possible combinations +# of m,n,k values using formula m(t+1)=ROUND(m(t)*Base,0)+offset +def mnk_generator(): + k_1 = 1 + incr_k = 20 + while (k_1 <= max_elem): + n_1 = 1 + incr_n = 20 + while (n_1 <= max_elem): + m_1 = 1 + incr_m = 20 + while (m_1 <= max_elem): + mnk_array.append([m_1, n_1, k_1]) + if (m_1 == 1): + m_1 = m_1 + 9 + else: + m_1 = m_1 + incr_m + if (n_1 == 1): + n_1 = n_1 + 9 + else: + n_1 = n_1 + incr_n + if (k_1 == 1): + k_1 = k_1 + 9 + else: + k_1 = k_1 + incr_k + +def data_gen(): + mnk_generator() + + fout = open(out_file_name, "w") + + for ele in mnk_array: + fout.write("i r " + str(ele[0]) + " " + str(ele[1]) + " " + str(ele[2]) + " " +\ + str(ele[2]) + " " + str(ele[1]) + " " + str(ele[1]) + "\n") + + fout.truncate(fout.tell() - 1) + fout.close() + +##__main__ +data_gen() diff --git a/frame/base/bli_cpuid.c b/frame/base/bli_cpuid.c index 605d4c8089..552ab6e7aa 100644 --- a/frame/base/bli_cpuid.c +++ b/frame/base/bli_cpuid.c @@ -287,7 +287,8 @@ bool bli_cpuid_is_zen4 FEATURE_AVX512DQ | FEATURE_AVX512CD | FEATURE_AVX512BW | - FEATURE_AVX512VL ; + FEATURE_AVX512VL | + FEATURE_AVX512VNNI; if ( !bli_cpuid_has_features( features, expected ) ) return FALSE; @@ -558,6 +559,62 @@ bool bli_cpuid_is_avx_supported( void ) return is_avx_supported; } + +// Check (at runtime) if AVX512_VNNI is supported on the current platform, this +// is to ensure that AVX512_VNNI kernels are not used on legacy platforms which +// results in crash. + +// The support for AVX512_VNNI is checked only once (when this API is called +// first time). On subsequent calls the cached value is returned. +static bool is_avx512vnni_supported = FALSE; + +// Determine if the CPU has support for AVX512_VNNI. +void bli_cpuid_check_avx512vnni_support( void ) +{ + uint32_t family, model, features; + + // Call the CPUID instruction and parse its results into a family id, + // model id, and a feature bit field. + bli_cpuid_query( &family, &model, &features ); + + // Check for expected CPU features. 
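+	// Descriptive note: AVX512_VNNI is reported in CPUID leaf 7 (ECX bit 11,
+	// see FEATURE_MASK_AVX512VNNI further down in this file); the AVX512
+	// F/DQ/BW/VL bits are also checked because the VNNI kernels mix vpdpbusd
+	// with other AVX512 instructions.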
+ const uint32_t expected = FEATURE_AVX | + FEATURE_FMA3 | + FEATURE_AVX2 | + FEATURE_AVX512F | + FEATURE_AVX512DQ | + FEATURE_AVX512BW | + FEATURE_AVX512VL | + FEATURE_AVX512VNNI; + + if ( !bli_cpuid_has_features( features, expected ) ) + { + is_avx512vnni_supported = FALSE; + } + else + { + is_avx512vnni_supported = TRUE; + } +} + +static bli_pthread_once_t once_check_avx512vnni_support = BLIS_PTHREAD_ONCE_INIT; + +// Ensure that actual support determination happens only once +void bli_cpuid_check_avx512vnni_support_once( void ) +{ +#ifndef BLIS_CONFIGURETIME_CPUID + bli_pthread_once( &once_check_avx512vnni_support, bli_cpuid_check_avx512vnni_support ); +#endif +} + +// API to check if AVX512_VNNI is supported or not on the current platform. +bool bli_cpuid_is_avx512vnni_supported( void ) +{ + bli_cpuid_check_avx512vnni_support_once(); + + return is_avx512vnni_supported; +} + #elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) arch_t bli_cpuid_query_id( void ) @@ -758,6 +815,7 @@ enum FEATURE_MASK_AVX512CD = (1u<<28), // cpuid[eax=7,ecx=0] :ebx[28] FEATURE_MASK_AVX512BW = (1u<<30), // cpuid[eax=7,ecx=0] :ebx[30] FEATURE_MASK_AVX512VL = (1u<<31), // cpuid[eax=7,ecx=0] :ebx[31] + FEATURE_MASK_AVX512VNNI = (1u<<11), // cpuid[eax=7,ecx=0] :ecx[11] FEATURE_MASK_XGETBV = (1u<<26)| (1u<<27), // cpuid[eax=1] :ecx[27:26] XGETBV_MASK_XMM = 0x02u, // xcr0[1] @@ -824,6 +882,8 @@ uint32_t bli_cpuid_query if ( bli_cpuid_has_features( ebx, FEATURE_MASK_AVX512CD ) ) *features |= FEATURE_AVX512CD; if ( bli_cpuid_has_features( ebx, FEATURE_MASK_AVX512BW ) ) *features |= FEATURE_AVX512BW; if ( bli_cpuid_has_features( ebx, FEATURE_MASK_AVX512VL ) ) *features |= FEATURE_AVX512VL; + + if ( bli_cpuid_has_features( ecx, FEATURE_MASK_AVX512VNNI ) ) *features |= FEATURE_AVX512VNNI; } // Check extended processor info / features bits for AMD-specific features. diff --git a/frame/base/bli_cpuid.h b/frame/base/bli_cpuid.h index cb4c45ab5d..439cef3e41 100644 --- a/frame/base/bli_cpuid.h +++ b/frame/base/bli_cpuid.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018-2021, Advanced Micro Devices, Inc. + Copyright (C) 2018-2022, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -134,6 +134,7 @@ BLIS_INLINE bool bli_cpuid_has_features( uint32_t have, uint32_t want ) void get_cpu_name( char *cpu_name ); int vpu_count( void ); bool bli_cpuid_is_avx_supported(void); +bool bli_cpuid_is_avx512vnni_supported(void); enum { @@ -157,7 +158,8 @@ enum FEATURE_AVX512ER = 0x0800, FEATURE_AVX512CD = 0x1000, FEATURE_AVX512BW = 0x2000, - FEATURE_AVX512VL = 0x4000 + FEATURE_AVX512VL = 0x4000, + FEATURE_AVX512VNNI = 0x8000 }; From e073e8b6697320ab93de5c78078232dbb6239dd6 Mon Sep 17 00:00:00 2001 From: mkadavil Date: Fri, 10 Jun 2022 18:19:08 +0530 Subject: [PATCH 133/243] DAMAXV AXX512 micro kernel bug fix. -DAMAXV AVX512 is giving wrong results when max element is present at index in [n-u, n), where u < 32. This is a fallout of using wrong start offset for the non-loop unrolled code. -Functions for replacing NaN with negative numbers is replaced with MACRO to avoid function call overhead and to remove static variables used for stateful replacement numbers for NaN. 
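To make the replacement scheme concrete, a minimal scalar sketch is given
below. It is illustrative only: the helper name is made up, it assumes the
input already holds absolute values (as in the kernel), and it models only
the sentinel ordering, not the kernel's index bookkeeping or its final NaN
restoration. Each NaN is swapped for a progressively more negative value, so
among several NaNs the one seen first carries the largest sentinel and
survives the max reduction.

    #include <math.h>

    static float amax_with_nan_sentinels( const float* x, int n )
    {
        float nan_repl = -1.0f;   /* per-call state, no statics */
        float max_val  = -1.0f;
        for ( int i = 0; i < n; ++i )
        {
            float v = x[i];
            if ( isnan( v ) ) { v = nan_repl; nan_repl -= 1.0f; }
            if ( v > max_val ) max_val = v;
        }
        return max_val;
    }
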
AMD-Internal: [CPUPL-2190] Change-Id: Ie1435c38b264a271f869782793d0b52bbe6e1b2a --- kernels/zen4/1/bli_amaxv_zen_int_avx512.c | 139 +++++++++++----------- 1 file changed, 67 insertions(+), 72 deletions(-) diff --git a/kernels/zen4/1/bli_amaxv_zen_int_avx512.c b/kernels/zen4/1/bli_amaxv_zen_int_avx512.c index 0e0186c403..9e32f955a8 100644 --- a/kernels/zen4/1/bli_amaxv_zen_int_avx512.c +++ b/kernels/zen4/1/bli_amaxv_zen_int_avx512.c @@ -80,33 +80,24 @@ typedef union the times the function is called to ensure that bigger numbers are assigned for nan which showed up first.*/ -__m512 remove_NAN_512_s(__m512 vec) -{ - // Sign extraction mask - __m512 sign_mask; - // Temporary place to store vector's sign extracted 16xdouble word - __m512 vec_mask; - // k register to store the mask to do blend operation to remove NAN - __mmask16 vec_mask16; - // Static to preserve accross the function calls - static int iter = -1; - iter -= 1; - - // Extracting sign from the vec into int_mask_vec - // Sign is -0.f in IEEE754 is just signbit set, all others 0 - sign_mask = _mm512_set1_ps(-0.f); - // And with -0.f will keep just signbits, all others will be 0 - vec_mask = _mm512_mul_ps(vec, sign_mask); - // Typecast mask into int type no clock cycle is taken just to - // convince compiler. - __m512i int_mask_vec = _mm512_castps_si512(vec_mask); - // Extract the signbits and put it in a 16bit mask register - vec_mask16 = _mm512_movepi32_mask(int_mask_vec); - - // Swap NAN with -ve number - vec = _mm512_mask_blend_ps(vec_mask16, _mm512_set1_ps(iter), vec); - return vec; -} +#define REMOVE_NAN_512S(reg_512) \ + { \ + /*Sign is -0.f in IEEE754 is just signbit set, all others 0*/ \ + __m512 sign_mask = _mm512_set1_ps( -0.0f ); \ + \ + /* Numbers other than NAN will become 0. */ \ + __m512 vec_mask = _mm512_mul_ps( reg_512, sign_mask ); \ + \ + /* Typecast mask into int type no clock cycle is taken just to + * convince compiler. */ \ + __m512i int_mask_vec = _mm512_castps_si512( vec_mask ); \ + /* Extract the signbits and put it in a 16bit mask register. */ \ + __mmask16 vec_mask16 = _mm512_movepi32_mask( int_mask_vec ); \ + \ + /* Swap NAN with -ve number. */ \ + reg_512 = _mm512_mask_blend_ps( vec_mask16, _mm512_set1_ps( nan_repl ), reg_512 ); \ + nan_repl = nan_repl - 1; \ + } // return a mask which indicates either: // - v1 > v2 @@ -151,10 +142,14 @@ void bli_samaxv_zen_int_avx512( // *minus_one = -1 float *minus_one = PASTEMAC(s, m1); // bli_sm1() // *zero_i = 0 - dim_t *zero_i = PASTEMAC(i, 0); // bli_i0() + dim_t *zero_i = PASTEMAC(i, 0); // bli_i0() + + // Used to replace NAN in registers. This value is decremented each time + // remove NAN is applied so as to keep the NAN value replacements unique. 
+ float nan_repl = -1.0; float fndMaxVal; // Max value will be stored in this - dim_t fndInd; // Max value's index will be stored in this + dim_t fndInd; // Max value's index will be stored in this // Iterator for loops to keep continuity throughout the loops dim_t i; @@ -246,7 +241,7 @@ void bli_samaxv_zen_int_avx512( // max_vector = abs(max_vector) max_vec_1.v = _mm512_andnot_ps(abs_mask.v, max_vec_1.v); // Remove nan and replace with -ve values - max_vec_1.v = remove_NAN_512_s(max_vec_1.v); + REMOVE_NAN_512S(max_vec_1.v); // Increment x vector as we have loaded 16 values x += num_vector_elements; @@ -254,7 +249,7 @@ void bli_samaxv_zen_int_avx512( maxInd_vec_1.v = _mm512_setr_ps(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); - int i = 1; + dim_t i = 1; for (; (i + 4) < num_iter; i += 5) { /* @@ -276,7 +271,7 @@ void bli_samaxv_zen_int_avx512( // Increment x vector as we have loaded 16 values x += num_vector_elements; // Remove nan and replace with -ve values - x_vec_1.v = remove_NAN_512_s(x_vec_1.v); + REMOVE_NAN_512S(x_vec_1.v); // Mask Generation of 1st(can be previous max) and 2nd element // mask = max_vector - x_vec_1 @@ -295,7 +290,7 @@ void bli_samaxv_zen_int_avx512( // max_vec_2 = abs(max_vec_2) max_vec_2.v = _mm512_andnot_ps(abs_mask.v, max_vec_2.v); // Remove nan and replace with -ve values - max_vec_2.v = remove_NAN_512_s(max_vec_2.v); + REMOVE_NAN_512S(max_vec_2.v); // Increment x vector as we have loaded 16 values x += num_vector_elements; // Increment the index vector to point to next indexes. @@ -306,7 +301,7 @@ void bli_samaxv_zen_int_avx512( // x_vec_2 = abs(x_vec_2) x_vec_2.v = _mm512_andnot_ps(abs_mask.v, x_vec_2.v); // Remove nan and replace with -ve values - x_vec_2.v = remove_NAN_512_s(x_vec_2.v); + REMOVE_NAN_512S(x_vec_2.v); // Increment x vector as we have loaded 16 values x += num_vector_elements; // Increment the index vector to point to next indexes. @@ -329,7 +324,7 @@ void bli_samaxv_zen_int_avx512( // max_vec_3 = abs(max_vec_3) max_vec_3.v = _mm512_andnot_ps(abs_mask.v, max_vec_3.v); // Remove nan and replace with -ve values - max_vec_3.v = remove_NAN_512_s(max_vec_3.v); + REMOVE_NAN_512S(max_vec_3.v); // Increment x vector as we have loaded 16 values x += num_vector_elements; // Increment the index vector to point to next indexes. @@ -339,7 +334,7 @@ void bli_samaxv_zen_int_avx512( // x_vec_3 = abs(x_vec_3) x_vec_3.v = _mm512_andnot_ps(abs_mask.v, x_vec_3.v); // Remove nan and replace with -ve values - x_vec_3.v = remove_NAN_512_s(x_vec_3.v); + REMOVE_NAN_512S(x_vec_3.v); // Increment x vector as we have loaded 16 values x += num_vector_elements; // Increment the index vector to point to next indexes. 
@@ -468,7 +463,7 @@ void bli_samaxv_zen_int_avx512( // x_vec_1 = abs(x_vec_1) x_vec_1.v = _mm512_andnot_ps(abs_mask.v, x_vec_1.v); // Remove nan and replace with -ve values - x_vec_1.v = remove_NAN_512_s(x_vec_1.v); + REMOVE_NAN_512S(x_vec_1.v); // Mask Generation // mask = max_vec_1 - x_vec_1 @@ -618,7 +613,7 @@ void bli_samaxv_zen_int_avx512( fndMaxVal = NAN; } // Finish off the remaining values using normal instructions - for (int i = n - num_remain; i < n; i++) + for (dim_t i = n - num_remain; i < n; i++) { float absval = fabsf(*(x)); if (fndMaxVal < absval || (isnan(absval) && !isnan(fndMaxVal))) @@ -643,32 +638,21 @@ void bli_samaxv_zen_int_avx512( // ----------------------------------------------------------------------------- /* Converts all the NAN to a negative number less than previously encountered NANs*/ -__m512d remove_NAN_512d(__m512d vec) -{ - - static int iter; - static __m512d sign_mask; - - __m512d vec_mask; - __m512i int_mask_vec; - __mmask8 vec_mask8; - - iter = iter - 1; - - sign_mask = _mm512_set1_pd(-0.f); - - //numbers other than NAN will become 0 - vec_mask = _mm512_mul_pd(vec, sign_mask); - - //producing an 8-bit mask - int_mask_vec = _mm512_castpd_si512(vec_mask); - vec_mask8 = _mm512_movepi64_mask(int_mask_vec); - - //replacing all the NAN with negative numbers - vec = _mm512_mask_blend_pd(vec_mask8, _mm512_set1_pd(-1 + iter), vec); - - return vec; -} +#define REMOVE_NAN_512D(reg_512) \ + { \ + __m512d sign_mask = _mm512_set1_pd( -0.0f ); \ + \ + /* Numbers other than NAN will become 0. */ \ + __m512d vec_mask = _mm512_mul_pd( reg_512, sign_mask ); \ + \ + /* Producing an 8-bit mask. */ \ + __m512i int_mask_vec = _mm512_castpd_si512( vec_mask ); \ + __mmask8 vec_mask8 = _mm512_movepi64_mask( int_mask_vec ); \ + \ + /* Replacing all the NAN with negative numbers. */ \ + reg_512 = _mm512_mask_blend_pd( vec_mask8, _mm512_set1_pd( nan_repl ), reg_512 ); \ + nan_repl = nan_repl - 1; \ + } //---------------------------------------------------------------------------------------------------- void bli_damaxv_zen_int_avx512( @@ -679,6 +663,11 @@ void bli_damaxv_zen_int_avx512( { AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3) double *minus_one = PASTEMAC(d, m1); + + // Used to replace NAN in registers. This value is decremented each time + // remove NAN is applied so as to keep the NAN value replacements unique. 
+ double nan_repl = -1.0; + dim_t *zero_i = PASTEMAC(i, 0); double chi1_r; @@ -776,7 +765,7 @@ void bli_damaxv_zen_int_avx512( // Taking the absolute value and removing the NAN max_array.v = _mm512_andnot_pd(sign_mask, max_array.v); - max_array.v = remove_NAN_512d(max_array.v); + REMOVE_NAN_512D(max_array.v); // Initializing the maximumum index max_ind.v = _mm512_set_pd(7, 6, 5, 4, 3, 2, 1, 0); @@ -786,7 +775,7 @@ void bli_damaxv_zen_int_avx512( //to point to the next 8 elements zmm4_Ind.v = _mm512_add_pd(zmm4_Ind.v, inc_vec.v); - /* Loop unrolled by a factor of 4 + /* Loop unrolled by a factor of 4 At the end of the loop max_array holds the largest element in each corresponding vector index */ for (unrollCount = 8; (unrollCount + 31) < n; unrollCount += 32) @@ -797,25 +786,25 @@ void bli_damaxv_zen_int_avx512( // with negative numbers zmm0.v = _mm512_loadu_pd(x); zmm0.v = _mm512_andnot_pd(sign_mask, zmm0.v); - zmm0.v = remove_NAN_512d(zmm0.v); + REMOVE_NAN_512D(zmm0.v); x += vector_length; zmm1.v = _mm512_loadu_pd(x); zmm5_Ind.v = _mm512_add_pd(zmm4_Ind.v, inc_vec.v); zmm1.v = _mm512_andnot_pd(sign_mask, zmm1.v); - zmm1.v = remove_NAN_512d(zmm1.v); + REMOVE_NAN_512D(zmm1.v); x += vector_length; zmm2.v = _mm512_loadu_pd(x); zmm6_Ind.v = _mm512_add_pd(zmm5_Ind.v, inc_vec.v); zmm2.v = _mm512_andnot_pd(sign_mask, zmm2.v); - zmm2.v = remove_NAN_512d(zmm2.v); + REMOVE_NAN_512D(zmm2.v); x += vector_length; zmm3.v = _mm512_loadu_pd(x); zmm7_Ind.v = _mm512_add_pd(zmm6_Ind.v, inc_vec.v); zmm3.v = _mm512_andnot_pd(sign_mask, zmm3.v); - zmm3.v = remove_NAN_512d(zmm3.v); + REMOVE_NAN_512D(zmm3.v); x += vector_length; /*Using sub function to generating the mask @@ -872,7 +861,7 @@ void bli_damaxv_zen_int_avx512( /* At the end of the loop max_array holds the largest element in each corresponding vector index */ - for (int i = 1; i < iterations; ++i) + for (dim_t i = 0; i < iterations; ++i) { // Taking 32 elements // Taking only the absolute values of the registers @@ -880,7 +869,7 @@ void bli_damaxv_zen_int_avx512( // with negative numbers zmm0.v = _mm512_loadu_pd(x); zmm0.v = _mm512_abs_pd(zmm0.v); - zmm0.v = remove_NAN_512d(zmm0.v); + REMOVE_NAN_512D(zmm0.v); //Generating mask for the intermediate max vector mask_01 = _mm512_sub_pd(max_array.v, zmm0.v); @@ -968,6 +957,12 @@ void bli_damaxv_zen_int_avx512( } } + // Issue vzeroupper instruction to clear upper lanes of ymm registers. + // This avoids a performance penalty caused by false dependencies when + // transitioning from from AVX to SSE instructions (which may occur + // later, especially if BLIS is compiled with -mfpmath=sse). + _mm256_zeroupper(); + // Return value *i_max = i_max_l; From 13c71ca976c2832f1eeeb23fbbd2ec4114fc2ff9 Mon Sep 17 00:00:00 2001 From: satish kumar nuggu Date: Mon, 13 Jun 2022 09:52:45 +0530 Subject: [PATCH 134/243] BugFix of AOCL_DYNAMIC in TRSM multithreaded small. - Added initialization of rntm object before aocl_dynamic. - Bugfixes in dtrsm right-side kernels, avoided accessing extra memory while using store for corner cases. 
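As an illustration of the corner-case store fix (a sketch under assumed names, not the kernel code itself): when only m_remainder = 3 rows of a column remain, the solved 4-lane AVX2 accumulator is written back with one 128-bit store plus one scalar store, so the double just past the end of the B column is never loaded or stored. The other half of the fix is purely an ordering change: bli_rntm_init_from_global() must run before the AOCL_DYNAMIC logic queries the rntm object.

#include <immintrin.h>

/* Hypothetical helper: write back only 3 of the 4 doubles held in acc. */
static void dtrsm_store_m_rem3( double* b11, __m256d acc )
{
    _mm_storeu_pd( b11,     _mm256_extractf128_pd( acc, 0 ) ); /* rows 0 and 1 */
    _mm_storel_pd( b11 + 2, _mm256_extractf128_pd( acc, 1 ) ); /* row 2 only   */
}
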
AMD-Internal: [CPUPL-2193] [CPUPL-2194] Change-Id: I1c9d10edda93621626957d4de2f53d249ad531ba --- kernels/zen/3/bli_trsm_small.c | 271 +++++++++------------------------ 1 file changed, 76 insertions(+), 195 deletions(-) diff --git a/kernels/zen/3/bli_trsm_small.c b/kernels/zen/3/bli_trsm_small.c index d7192a062b..bb8a2e9cc5 100644 --- a/kernels/zen/3/bli_trsm_small.c +++ b/kernels/zen/3/bli_trsm_small.c @@ -3908,7 +3908,6 @@ err_t bli_trsm_small_mt cntl_t* cntl ) { - rntm_t rntm; gint_t m = bli_obj_length( b ); // number of rows of matrix b gint_t n = bli_obj_width( b ); // number of columns of Matrix b dim_t d_mr = 8,d_nr = 6; @@ -3928,6 +3927,9 @@ err_t bli_trsm_small_mt } } + rntm_t rntm; + bli_rntm_init_from_global( &rntm ); + #ifdef AOCL_DYNAMIC // If dynamic-threading is enabled, calculate optimum number // of threads. @@ -3938,8 +3940,6 @@ err_t bli_trsm_small_mt } #endif - bli_rntm_init_from_global( &rntm ); - // Query the total number of threads from the rntm_t object. dim_t n_threads = bli_rntm_num_threads( &rntm ); @@ -6727,25 +6727,19 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_pd(ymm0, ymm11, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_pd(ymm0, ymm13, 0x07); + _mm_storeu_pd((double *)b11, _mm256_extractf128_pd(ymm3,0)); + _mm_storeu_pd((double *)(b11 + cs_b), _mm256_extractf128_pd(ymm5,0)); + _mm_storeu_pd((double *)(b11 + cs_b*2), _mm256_extractf128_pd(ymm7,0)); + _mm_storeu_pd((double *)(b11 + cs_b*3), _mm256_extractf128_pd(ymm9,0)); + _mm_storeu_pd((double *)(b11 + cs_b*4), _mm256_extractf128_pd(ymm11,0)); + _mm_storeu_pd((double *)(b11 + cs_b*5), _mm256_extractf128_pd(ymm13,0)); - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); - _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); + _mm_storel_pd((double *)b11 + 2, _mm256_extractf128_pd(ymm3,1)); + _mm_storel_pd((double *)(b11 + cs_b + 2), _mm256_extractf128_pd(ymm5,1)); + _mm_storel_pd((double *)(b11 + cs_b*2 + 2), _mm256_extractf128_pd(ymm7,1)); + _mm_storel_pd((double *)(b11 + cs_b*3 + 2), _mm256_extractf128_pd(ymm9,1)); + _mm_storel_pd((double *)(b11 + cs_b*4 + 2), _mm256_extractf128_pd(ymm11,1)); + _mm_storel_pd((double *)(b11 + cs_b*5 + 2), _mm256_extractf128_pd(ymm13,1)); m_remainder -= 3; i += 3; @@ -6857,25 +6851,12 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x03); - ymm0 = 
_mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_pd(ymm0, ymm11, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_pd(ymm0, ymm13, 0x03); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); - _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); + _mm_storeu_pd((double *)b11, _mm256_extractf128_pd(ymm3,0)); + _mm_storeu_pd((double *)(b11 + cs_b), _mm256_extractf128_pd(ymm5,0)); + _mm_storeu_pd((double *)(b11 + cs_b*2), _mm256_extractf128_pd(ymm7,0)); + _mm_storeu_pd((double *)(b11 + cs_b*3), _mm256_extractf128_pd(ymm9,0)); + _mm_storeu_pd((double *)(b11 + cs_b*4), _mm256_extractf128_pd(ymm11,0)); + _mm_storeu_pd((double *)(b11 + cs_b*5), _mm256_extractf128_pd(ymm13,0)); m_remainder -= 2; i += 2; @@ -6987,25 +6968,12 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_pd(ymm0, ymm11, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_pd(ymm0, ymm13, 0x01); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); - _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); + _mm_storel_pd((double *)b11, _mm256_extractf128_pd(ymm3,0)); + _mm_storel_pd((double *)(b11 + cs_b), _mm256_extractf128_pd(ymm5,0)); + _mm_storel_pd((double *)(b11 + cs_b*2), _mm256_extractf128_pd(ymm7,0)); + _mm_storel_pd((double *)(b11 + cs_b*3), _mm256_extractf128_pd(ymm9,0)); + _mm_storel_pd((double *)(b11 + cs_b*4), _mm256_extractf128_pd(ymm11,0)); + _mm_storel_pd((double *)(b11 + cs_b*5), _mm256_extractf128_pd(ymm13,0)); m_remainder -= 1; i += 1; @@ -7397,23 +7365,15 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x07); - xmm5 = _mm_loadu_pd((double 
const*)(b11 + cs_b * 3)); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3 + 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x07); + _mm_storeu_pd((double *)b11, _mm256_extractf128_pd(ymm3,0)); + _mm_storeu_pd((double *)(b11 + cs_b), _mm256_extractf128_pd(ymm5,0)); + _mm_storeu_pd((double *)(b11 + cs_b*2), _mm256_extractf128_pd(ymm7,0)); + _mm_storeu_pd((double *)(b11 + cs_b*3), _mm256_extractf128_pd(ymm9,0)); - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - xmm5 = _mm256_extractf128_pd(ymm9, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 3),xmm5); - _mm_storel_pd((b11 + cs_b * 3 + 2), _mm256_extractf128_pd(ymm9, 1)); + _mm_storel_pd((double *)b11 + 2, _mm256_extractf128_pd(ymm3,1)); + _mm_storel_pd((double *)(b11 + cs_b + 2), _mm256_extractf128_pd(ymm5,1)); + _mm_storel_pd((double *)(b11 + cs_b*2 + 2), _mm256_extractf128_pd(ymm7,1)); + _mm_storel_pd((double *)(b11 + cs_b*3 + 2), _mm256_extractf128_pd(ymm9,1)); m_remainder -= 3; i += 3; @@ -7494,21 +7454,10 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x03); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x03); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - xmm5 = _mm256_extractf128_pd(ymm9, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 3),xmm5); + _mm_storeu_pd((double *)b11, _mm256_extractf128_pd(ymm3,0)); + _mm_storeu_pd((double *)(b11 + cs_b), _mm256_extractf128_pd(ymm5,0)); + _mm_storeu_pd((double *)(b11 + cs_b*2), _mm256_extractf128_pd(ymm7,0)); + _mm_storeu_pd((double *)(b11 + cs_b*3), _mm256_extractf128_pd(ymm9,0)); m_remainder -= 2; i += 2; @@ -7588,15 +7537,6 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x01); - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm3, 0)); _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm5, 0)); _mm_storel_pd((b11 + cs_b * 2), _mm256_extractf128_pd(ymm7, 0)); @@ -9165,25 +9105,19 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); 
//B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_pd(ymm0, ymm11, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_pd(ymm0, ymm13, 0x07); + _mm_storeu_pd((double *)b11, _mm256_extractf128_pd(ymm3,0)); + _mm_storeu_pd((double *)(b11 + cs_b), _mm256_extractf128_pd(ymm5,0)); + _mm_storeu_pd((double *)(b11 + cs_b*2), _mm256_extractf128_pd(ymm7,0)); + _mm_storeu_pd((double *)(b11 + cs_b*3), _mm256_extractf128_pd(ymm9,0)); + _mm_storeu_pd((double *)(b11 + cs_b*4), _mm256_extractf128_pd(ymm11,0)); + _mm_storeu_pd((double *)(b11 + cs_b*5), _mm256_extractf128_pd(ymm13,0)); - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); - _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); + _mm_storel_pd((double *)b11 + 2, _mm256_extractf128_pd(ymm3,1)); + _mm_storel_pd((double *)(b11 + cs_b + 2), _mm256_extractf128_pd(ymm5,1)); + _mm_storel_pd((double *)(b11 + cs_b*2 + 2), _mm256_extractf128_pd(ymm7,1)); + _mm_storel_pd((double *)(b11 + cs_b*3 + 2), _mm256_extractf128_pd(ymm9,1)); + _mm_storel_pd((double *)(b11 + cs_b*4 + 2), _mm256_extractf128_pd(ymm11,1)); + _mm_storel_pd((double *)(b11 + cs_b*5 + 2), _mm256_extractf128_pd(ymm13,1)); m_remainder -=3; } @@ -9286,25 +9220,12 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_pd(ymm0, ymm11, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_pd(ymm0, ymm13, 0x03); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); - _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); + _mm_storeu_pd((double *)b11, _mm256_extractf128_pd(ymm3,0)); + _mm_storeu_pd((double *)(b11 + cs_b), _mm256_extractf128_pd(ymm5,0)); + _mm_storeu_pd((double *)(b11 + cs_b*2), _mm256_extractf128_pd(ymm7,0)); + _mm_storeu_pd((double *)(b11 + cs_b*3), _mm256_extractf128_pd(ymm9,0)); + _mm_storeu_pd((double *)(b11 + cs_b*4), _mm256_extractf128_pd(ymm11,0)); + _mm_storeu_pd((double *)(b11 + cs_b*5), _mm256_extractf128_pd(ymm13,0)); m_remainder -=2; } @@ -9407,25 +9328,12 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = 
_mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_pd(ymm0, ymm11, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_pd(ymm0, ymm13, 0x01); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); - _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); + _mm_storel_pd((double *)b11, _mm256_extractf128_pd(ymm3,0)); + _mm_storel_pd((double *)(b11 + cs_b), _mm256_extractf128_pd(ymm5,0)); + _mm_storel_pd((double *)(b11 + cs_b*2), _mm256_extractf128_pd(ymm7,0)); + _mm_storel_pd((double *)(b11 + cs_b*3), _mm256_extractf128_pd(ymm9,0)); + _mm_storel_pd((double *)(b11 + cs_b*4), _mm256_extractf128_pd(ymm11,0)); + _mm_storel_pd((double *)(b11 + cs_b*5), _mm256_extractf128_pd(ymm13,0)); m_remainder -=1; } @@ -9806,23 +9714,15 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x07); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3 + 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x07); + _mm_storeu_pd((double *)b11, _mm256_extractf128_pd(ymm3,0)); + _mm_storeu_pd((double *)(b11 + cs_b), _mm256_extractf128_pd(ymm5,0)); + _mm_storeu_pd((double *)(b11 + cs_b*2), _mm256_extractf128_pd(ymm7,0)); + _mm_storeu_pd((double *)(b11 + cs_b*3), _mm256_extractf128_pd(ymm9,0)); - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - xmm5 = _mm256_extractf128_pd(ymm9, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 3),xmm5); - _mm_storel_pd((b11 + cs_b * 3 + 2), _mm256_extractf128_pd(ymm9, 1)); + _mm_storel_pd((double *)b11 + 2, _mm256_extractf128_pd(ymm3,1)); + _mm_storel_pd((double *)(b11 + cs_b + 2), _mm256_extractf128_pd(ymm5,1)); + _mm_storel_pd((double *)(b11 + cs_b*2 + 2), _mm256_extractf128_pd(ymm7,1)); + _mm_storel_pd((double *)(b11 + cs_b*3 + 2), _mm256_extractf128_pd(ymm9,1)); m_remainder -=3; } @@ -9898,21 +9798,10 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 
0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x03); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x03); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - xmm5 = _mm256_extractf128_pd(ymm9, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 3),xmm5); + _mm_storeu_pd((double *)b11, _mm256_extractf128_pd(ymm3,0)); + _mm_storeu_pd((double *)(b11 + cs_b), _mm256_extractf128_pd(ymm5,0)); + _mm_storeu_pd((double *)(b11 + cs_b*2), _mm256_extractf128_pd(ymm7,0)); + _mm_storeu_pd((double *)(b11 + cs_b*3), _mm256_extractf128_pd(ymm9,0)); m_remainder -=2; } @@ -9985,15 +9874,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_broadcast_sd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x01); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x01); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x01); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x01); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm3, 0)); _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm5, 0)); From aaf840d86eafb32c6674a91ab73964a5277426b4 Mon Sep 17 00:00:00 2001 From: satish kumar nuggu Date: Wed, 18 May 2022 16:53:24 +0530 Subject: [PATCH 135/243] Disabled zgemm SUP path - Need to identify new Thresholds for zgemm SUP path to avoid performance regression. AMD-Internal: [CPUPL-2148] Change-Id: I0baa2b415dc5e296780566ba7450249445b93d43 --- frame/compat/bla_gemm_amd.c | 9 --------- 1 file changed, 9 deletions(-) diff --git a/frame/compat/bla_gemm_amd.c b/frame/compat/bla_gemm_amd.c index 681869c9b8..99d7371778 100644 --- a/frame/compat/bla_gemm_amd.c +++ b/frame/compat/bla_gemm_amd.c @@ -753,15 +753,6 @@ void zgemm_ } } #endif - - err_t status = bli_gemmsup(&alphao, &ao, &bo, &betao, &co, NULL, NULL); - if(status==BLIS_SUCCESS) - { - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) - return; - } - // fall back on native path when zgemm is not handled in sup path. bli_gemmnat(&alphao, &ao, &bo, &betao, &co, NULL, NULL); AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); From 00c5048f65122e399e56b1eeb031c6bc0b72995e Mon Sep 17 00:00:00 2001 From: Dipal M Zambare Date: Tue, 14 Jun 2022 08:30:51 +0530 Subject: [PATCH 136/243] Fixed high impact static analysis issues Initialized ymm and xmm registers to zero to address un-inilizaed variable errors reported in static analsys. 
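A minimal sketch of the pattern being applied (illustrative only; the function and variable names are assumptions): a vector register that is written only inside a conditional loop is cleared up front, so the reduction that follows is well defined even when the loop body never executes, which is the case static analysis flags.

#include <immintrin.h>

static double sum4_tail( const double* x, int k )
{
    /* Zero-initialize the scratch accumulator before any conditional use,
     * mirroring the _mm256_setzero_pd()/_mm_setzero_ps() additions here. */
    __m256d ymm3 = _mm256_setzero_pd();

    for ( int i = 0; i + 4 <= k; i += 4 )
        ymm3 = _mm256_add_pd( ymm3, _mm256_loadu_pd( x + i ) );

    double out[4];
    _mm256_storeu_pd( out, ymm3 );
    return out[0] + out[1] + out[2] + out[3];
}
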
AMD-Internal: [CPUPL-2078] Change-Id: Icfcc008a0f244278efd8145d7feef764ed5fcc04 --- kernels/zen/3/bli_dgemm_ref_k1.c | 6 +++++- kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_c3x8n.c | 8 +++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/kernels/zen/3/bli_dgemm_ref_k1.c b/kernels/zen/3/bli_dgemm_ref_k1.c index 659975cdb7..03a2b789bb 100644 --- a/kernels/zen/3/bli_dgemm_ref_k1.c +++ b/kernels/zen/3/bli_dgemm_ref_k1.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -394,6 +394,7 @@ void bli_dgemm_ref_k1_nn if(m_rem == 1) { + ymm0 = _mm256_setzero_pd(); ymm3 = _mm256_setzero_pd(); ymm5 = _mm256_setzero_pd(); ymm7 = _mm256_setzero_pd(); @@ -690,6 +691,7 @@ void bli_dgemm_ref_k1_nn if(m_rem == 1) { + ymm0 = _mm256_setzero_pd(); ymm3 = _mm256_setzero_pd(); ymm5 = _mm256_setzero_pd(); ymm7 = _mm256_setzero_pd(); @@ -897,6 +899,7 @@ void bli_dgemm_ref_k1_nn if(m_rem == 1) { + ymm0 = _mm256_setzero_pd(); ymm3 = _mm256_setzero_pd(); ymm5 = _mm256_setzero_pd(); ymm15 = _mm256_setzero_pd(); @@ -1052,6 +1055,7 @@ void bli_dgemm_ref_k1_nn if(m_rem == 1) { + ymm0 = _mm256_setzero_pd(); ymm3 = _mm256_setzero_pd(); ymm15 = _mm256_setzero_pd(); diff --git a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_c3x8n.c b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_c3x8n.c index a21c9b5ed1..77f0348561 100644 --- a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_c3x8n.c +++ b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_c3x8n.c @@ -6,7 +6,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020-2021, Advanced Micro Devices, Inc. + Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -138,6 +138,8 @@ void bli_cgemmsup_rv_zen_asm_3x8n for (n_iter = 0; n_iter < n0 / 8; n_iter++) { // clear scratch registers. + xmm0 = _mm_setzero_ps(); + xmm3 = _mm_setzero_ps(); ymm4 = _mm256_setzero_ps(); ymm5 = _mm256_setzero_ps(); ymm6 = _mm256_setzero_ps(); @@ -572,6 +574,8 @@ void bli_cgemmsup_rv_zen_asm_2x8n for (n_iter = 0; n_iter < n0 / 8; n_iter++) { // clear scratch registers. + xmm0 = _mm_setzero_ps(); + xmm3 = _mm_setzero_ps(); ymm4 = _mm256_setzero_ps(); ymm5 = _mm256_setzero_ps(); ymm6 = _mm256_setzero_ps(); @@ -919,6 +923,8 @@ void bli_cgemmsup_rv_zen_asm_1x8n for (n_iter = 0; n_iter < n0 / 8; n_iter++) { // clear scratch registers. 
+ xmm0 = _mm_setzero_ps(); + xmm3 = _mm_setzero_ps(); ymm4 = _mm256_setzero_ps(); ymm5 = _mm256_setzero_ps(); ymm6 = _mm256_setzero_ps(); From 47c344bc2e93d5a3d1127a96aac4edf19b7137b0 Mon Sep 17 00:00:00 2001 From: Kiran Varaganti Date: Mon, 27 Jun 2022 05:34:48 +0000 Subject: [PATCH 137/243] DGEMM Benchmark Optimizations Updated with optimal cache-blocking sizes for MC, KC and NC for AVX512 dgemm kernel Change-Id: I56b3df238b6d85a6f6861448c0c6f907c972146a --- config/zen4/bli_cntx_init_zen4.c | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/config/zen4/bli_cntx_init_zen4.c b/config/zen4/bli_cntx_init_zen4.c index 0f589f301c..b3a6d1030c 100644 --- a/config/zen4/bli_cntx_init_zen4.c +++ b/config/zen4/bli_cntx_init_zen4.c @@ -167,9 +167,9 @@ void bli_cntx_init_zen4( cntx_t* cntx ) bli_blksz_init_easy( &blkszs[ BLIS_MR ], 32, 16, 3, 3 ); bli_blksz_init_easy( &blkszs[ BLIS_NR ], 12, 14, 8, 4 ); bli_blksz_init_easy( &blkszs[ BLIS_MC ], 480, 240, 144, 18 ); - bli_blksz_init ( &blkszs[ BLIS_KC ], 384, 256, 256, 566, + bli_blksz_init ( &blkszs[ BLIS_KC ], 384, 512, 256, 566, 480, 320, 256, 566 ); - bli_blksz_init_easy( &blkszs[ BLIS_NC ], 3072, 3752, 4080, 256 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 3072, 4004, 4080, 256 ); bli_blksz_init_easy( &blkszs[ BLIS_AF ], 8, 8, -1, -1 ); bli_blksz_init_easy( &blkszs[ BLIS_DF ], 8, 8, -1, -1 ); From ff2ee0ae3fce91d0e057b7233346b2b5f4bada4e Mon Sep 17 00:00:00 2001 From: Chandrashekara K R Date: Wed, 22 Jun 2022 17:21:58 +0530 Subject: [PATCH 138/243] AOCL-WINDOWS: Added the windows build system to build bench folder on windows. 1. Added the checks in .c files of the bench folder to read the input parameters from the given input files on windows using fscanf. Change-Id: Ie0497696304d318f345a646ab0ce3ba84debd4e2 --- CMakeLists.txt | 1 + bench/CMakeLists.txt | 97 ++++++++++++++++++++++++++++++++++++++++++++ bench/Makefile | 10 ++--- bench/bench_amaxv.c | 5 +-- bench/bench_axpbyv.c | 2 +- bench/bench_copyv.c | 4 +- bench/bench_dotv.c | 5 +-- bench/bench_gemm.c | 5 +-- bench/bench_gemmt.c | 4 +- bench/bench_gemv.c | 5 +-- bench/bench_ger.c | 5 +-- bench/bench_scalv.c | 5 +-- bench/bench_swapv.c | 4 +- bench/bench_syrk.c | 4 +- bench/bench_trsm.c | 4 +- bench/bench_trsv.c | 5 +-- 16 files changed, 128 insertions(+), 37 deletions(-) create mode 100644 bench/CMakeLists.txt diff --git a/CMakeLists.txt b/CMakeLists.txt index 2558494d90..dce532b07c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -598,6 +598,7 @@ add_subdirectory(frame) add_subdirectory(aocl_dtl) add_subdirectory(test) add_subdirectory(testsuite) +add_subdirectory(bench) if(ENABLE_TESTCPP_TESTING) add_subdirectory(vendor/testcpp) endif() diff --git a/bench/CMakeLists.txt b/bench/CMakeLists.txt new file mode 100644 index 0000000000..00d01fdd21 --- /dev/null +++ b/bench/CMakeLists.txt @@ -0,0 +1,97 @@ +##Copyright (C) 2022, Advanced Micro Devices, Inc. 
All rights reserved.## + +add_definitions(-DBLAS="AOCL") +add_definitions(-DN_REPEAT=1000) +add_definitions(-DINT_FS="%lld") +add_definitions(-DUINT_FS="%llu") + +add_executable(BenchAmaxv bench_amaxv.c) +target_link_libraries(BenchAmaxv debug "${LIB_NAME}.lib") +if(ENABLE_OPENMP) + target_link_libraries(BenchAmaxv OpenMP::OpenMP_CXX) +endif() +target_link_libraries(BenchAmaxv optimized "${LIB_NAME}.lib") + +add_executable(BenchAxpbyv bench_axpbyv.c) +target_link_libraries(BenchAxpbyv debug "${LIB_NAME}.lib") +if(ENABLE_OPENMP) + target_link_libraries(BenchAxpbyv OpenMP::OpenMP_CXX) +endif() +target_link_libraries(BenchAxpbyv optimized "${LIB_NAME}.lib") + +add_executable(BenchCopyv bench_copyv.c) +target_link_libraries(BenchCopyv debug "${LIB_NAME}.lib") +if(ENABLE_OPENMP) + target_link_libraries(BenchCopyv OpenMP::OpenMP_CXX) +endif() +target_link_libraries(BenchCopyv optimized "${LIB_NAME}.lib") + +add_executable(BenchDotv bench_dotv.c) +target_link_libraries(BenchDotv debug "${LIB_NAME}.lib") +if(ENABLE_OPENMP) + target_link_libraries(BenchDotv OpenMP::OpenMP_CXX) +endif() +target_link_libraries(BenchDotv optimized "${LIB_NAME}.lib") + +add_executable(BenchGemm bench_gemm.c) +target_link_libraries(BenchGemm debug "${LIB_NAME}.lib") +if(ENABLE_OPENMP) + target_link_libraries(BenchGemm OpenMP::OpenMP_CXX) +endif() +target_link_libraries(BenchGemm optimized "${LIB_NAME}.lib") + +add_executable(BenchGemmt bench_gemmt.c) +target_link_libraries(BenchGemmt debug "${LIB_NAME}.lib") +if(ENABLE_OPENMP) + target_link_libraries(BenchGemmt OpenMP::OpenMP_CXX) +endif() +target_link_libraries(BenchGemmt optimized "${LIB_NAME}.lib") + +add_executable(BenchGemv bench_gemv.c) +target_link_libraries(BenchGemv debug "${LIB_NAME}.lib") +if(ENABLE_OPENMP) + target_link_libraries(BenchGemv OpenMP::OpenMP_CXX) +endif() +target_link_libraries(BenchGemv optimized "${LIB_NAME}.lib") + +add_executable(BenchGer bench_ger.c) +target_link_libraries(BenchGer debug "${LIB_NAME}.lib") +if(ENABLE_OPENMP) + target_link_libraries(BenchGer OpenMP::OpenMP_CXX) +endif() +target_link_libraries(BenchGer optimized "${LIB_NAME}.lib") + +add_executable(BenchScalv bench_scalv.c) +target_link_libraries(BenchScalv debug "${LIB_NAME}.lib") +if(ENABLE_OPENMP) + target_link_libraries(BenchScalv OpenMP::OpenMP_CXX) +endif() +target_link_libraries(BenchScalv optimized "${LIB_NAME}.lib") + +add_executable(BenchSwapv bench_swapv.c) +target_link_libraries(BenchSwapv debug "${LIB_NAME}.lib") +if(ENABLE_OPENMP) + target_link_libraries(BenchSwapv OpenMP::OpenMP_CXX) +endif() +target_link_libraries(BenchSwapv optimized "${LIB_NAME}.lib") + +add_executable(BenchSyrk bench_syrk.c) +target_link_libraries(BenchSyrk debug "${LIB_NAME}.lib") +if(ENABLE_OPENMP) + target_link_libraries(BenchSyrk OpenMP::OpenMP_CXX) +endif() +target_link_libraries(BenchSyrk optimized "${LIB_NAME}.lib") + +add_executable(BenchTrsm bench_trsm.c) +target_link_libraries(BenchTrsm debug "${LIB_NAME}.lib") +if(ENABLE_OPENMP) + target_link_libraries(BenchTrsm OpenMP::OpenMP_CXX) +endif() +target_link_libraries(BenchTrsm optimized "${LIB_NAME}.lib") + +add_executable(BenchTrsv bench_trsv.c) +target_link_libraries(BenchTrsv debug "${LIB_NAME}.lib") +if(ENABLE_OPENMP) + target_link_libraries(BenchTrsv OpenMP::OpenMP_CXX) +endif() +target_link_libraries(BenchTrsv optimized "${LIB_NAME}.lib") diff --git a/bench/Makefile b/bench/Makefile index d47485b2fc..93cca3298a 100755 --- a/bench/Makefile +++ b/bench/Makefile @@ -6,7 +6,7 @@ # libraries. 
# # Copyright (C) 2014, The University of Texas at Austin -# Copyright (C) 2017 - 2021, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2017 - 2022, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -246,17 +246,17 @@ $(TEST_OBJ_PATH)/%.o: $(TEST_SRC_PATH)/%.c $(CC) $(CFLAGS) -c $< -o $@ bench_%_openblas.o: bench_%.c - $(CC) $(CFLAGS) -DBLAS=\"openblas\" $(NRTS) -c $< -o $@ + $(CC) $(CFLAGS) -DBLAS=\"openblas\" $(NRTS) -DINT_FS=\"%ld\" -DUINT_FS=\"%lu\" -c $< -o $@ bench_%_atlas.o: bench_%.c - $(CC) $(CFLAGS) -DBLAS=\"atlas\" $(NRTS) -c $< -o $@ + $(CC) $(CFLAGS) -DBLAS=\"atlas\" $(NRTS) -DINT_FS=\"%ld\" -DUINT_FS=\"%lu\" -c $< -o $@ bench_%_mkl.o: bench_%.c - $(CC) $(CFLAGS) -DBLAS=\"mkl\" $(NRTS) -c $< -o $@ + $(CC) $(CFLAGS) -DBLAS=\"mkl\" $(NRTS) -DINT_FS=\"%ld\" -DUINT_FS=\"%lu\" -c $< -o $@ bench_%_blis.o: bench_%.c - $(CC) $(CFLAGS) -DBLAS=\"aocl\" $(NRTS) -c $< -o $@ + $(CC) $(CFLAGS) -DBLAS=\"aocl\" $(NRTS) -DINT_FS=\"%ld\" -DUINT_FS=\"%lu\" -c $< -o $@ # -- Executable file rules -- diff --git a/bench/bench_amaxv.c b/bench/bench_amaxv.c index 739bd0f979..2a0e578975 100644 --- a/bench/bench_amaxv.c +++ b/bench/bench_amaxv.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021-2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -101,8 +101,7 @@ int main( int argc, char** argv ) char tmp[256]; // to store function name, line no present in logs. // {S,D,C,Z} {n incx} - - while (fscanf(fin, "%s %c %ld %ld \n", + while (fscanf(fin, "%s %c " INT_FS INT_FS " \n", tmp, &dt_ch, &n, &incx) == 4) { diff --git a/bench/bench_axpbyv.c b/bench/bench_axpbyv.c index 36a203f696..c962079dd6 100644 --- a/bench/bench_axpbyv.c +++ b/bench/bench_axpbyv.c @@ -97,7 +97,7 @@ int main( int argc, char** argv ) // {function name} {S, D, C, Z} {n} // {alpha_r} {alpha_i} {incx} {beta_r} {beta_i} {incy} - while ( fscanf( fin, "%s %c %ld %lf %lf %ld %lf %lf %ld\n", + while ( fscanf( fin, "%s %c " INT_FS " %lf %lf " INT_FS " %lf %lf " INT_FS "\n", tmp, &dt_ch, &n, &alpha_r, &alpha_i, &incx, &beta_r, &beta_i, &incy ) == 9 ) { diff --git a/bench/bench_copyv.c b/bench/bench_copyv.c index c46ffc6093..7be38907ed 100644 --- a/bench/bench_copyv.c +++ b/bench/bench_copyv.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021-2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -101,7 +101,7 @@ int main( int argc, char** argv ) inc_t incx, incy; // {S,D,C,Z} {n incx incy} - while (fscanf(fin, "%s %c %ld %ld %ld\n", + while (fscanf(fin, "%s %c " INT_FS INT_FS INT_FS "\n", tmp, &dt_ch, &n, &incx, &incy) == 5) { diff --git a/bench/bench_dotv.c b/bench/bench_dotv.c index 80dcf8e99d..0d39594f72 100644 --- a/bench/bench_dotv.c +++ b/bench/bench_dotv.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. 
+ Copyright (C) 2021-2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -104,8 +104,7 @@ int main( int argc, char** argv ) // {S,D,C,Z} {n incx incy} - - while (fscanf(fin, "%s %c %ld %ld %ld\n", + while (fscanf(fin, "%s %c " INT_FS INT_FS INT_FS "\n", tmp, &dt_ch, &n, &incx, &incy) == 5) { diff --git a/bench/bench_gemm.c b/bench/bench_gemm.c index 8258b61d18..908ce0fca5 100755 --- a/bench/bench_gemm.c +++ b/bench/bench_gemm.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020-2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -129,8 +129,7 @@ int main( int argc, char** argv ) // beta_real, beta_imag, ldc, // // number of threads, execution time, gflops ---> ignored by bench - - while (fscanf(fin, "%s %c %c %c %ld %ld %ld %lf %lf %ld %ld %lf %lf %ld[^\n]", + while (fscanf(fin, "%s %c %c %c " INT_FS INT_FS INT_FS " %lf %lf " INT_FS INT_FS " %lf %lf " INT_FS"[^\n]", api_name, &dt_ch, &transA_c, &transB_c, &m, &n, &k, &alpha_r, &alpha_i, &lda, &ldb, &beta_r, &beta_i, &ldc) == 14) { diff --git a/bench/bench_gemmt.c b/bench/bench_gemmt.c index 621c9288c7..ad24593747 100644 --- a/bench/bench_gemmt.c +++ b/bench/bench_gemmt.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2020-21, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020-22, Advanced Micro Devices, Inc. All rights reserved. modification, are permitted provided that the following conditions are met: @@ -122,7 +122,7 @@ int main( int argc, char** argv ) stor_scheme = 'C'; // since logs are collected at BLAS APIs // {S,D,C,Z} {triangC : l or u} {n k lda ldb ldc transa transb alpha_real alpha_imaginary beta_real, beta_imaginary} - while (fscanf(fin,"%s %c %c %ld %ld %lu %lu %lu %c %c %lf %lf %lf %lf\n",\ + while (fscanf(fin,"%s %c %c " INT_FS INT_FS UINT_FS UINT_FS UINT_FS " %c %c %lf %lf %lf %lf\n",\ tmp, &dt_ch, &uplo_c, &n, &k,\ &lda, &ldb, &ldc, &transA_c, &transB_c, \ &alpha_r, &alpha_i, &beta_r, &beta_i) == 14) diff --git a/bench/bench_gemv.c b/bench/bench_gemv.c index acc4598000..9f06bf8efb 100755 --- a/bench/bench_gemv.c +++ b/bench/bench_gemv.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021-2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -112,8 +112,7 @@ int main( int argc, char** argv ) // {S,D,C,Z} {transa m n alpha lda, incx, beta, incy} - - while (fscanf(fin, "%s %c %c %ld %ld %lf %lf %ld %ld %lf %lf %ld\n", + while (fscanf(fin, "%s %c %c " INT_FS INT_FS " %lf %lf " INT_FS INT_FS " %lf %lf " INT_FS "\n", tmp, &dt_ch, &transA, &m, &n, &alpha_r, &alpha_i, &lda,\ &incx, &beta_r, &beta_i, &incy) == 12) { diff --git a/bench/bench_ger.c b/bench/bench_ger.c index fb50c94265..2c8981a682 100644 --- a/bench/bench_ger.c +++ b/bench/bench_ger.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. 
- Copyright (C) 2021-22, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021-2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -116,8 +116,7 @@ int main( int argc, char** argv ) #endif // {S,D,C,Z} {transa m n alpha incx incy lda} - - while (fscanf(fin, "%s %c %ld %ld %lf %lf %ld %ld %ld\n", + while (fscanf(fin, "%s %c " INT_FS INT_FS " %lf %lf " INT_FS INT_FS INT_FS "\n", tmp, &dt_ch, &m, &n, &alpha_r, &alpha_i, &incx, &incy, &lda) == 9) { diff --git a/bench/bench_scalv.c b/bench/bench_scalv.c index 404d5078f5..b8cd6241c1 100644 --- a/bench/bench_scalv.c +++ b/bench/bench_scalv.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021-2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -105,8 +105,7 @@ int main( int argc, char** argv ) // {S,D,C,Z} {alpha n incx} - - while (fscanf(fin, "%s %c %lf %lf %ld %ld\n", + while (fscanf(fin, "%s %c %lf %lf " INT_FS INT_FS "\n", tmp, &dt_ch, &alpha_r, &alpha_i, &n, &incx) == 6) { diff --git a/bench/bench_swapv.c b/bench/bench_swapv.c index 16aafdaaed..34af6b7975 100644 --- a/bench/bench_swapv.c +++ b/bench/bench_swapv.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021-2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -103,7 +103,7 @@ int main( int argc, char** argv ) char tmp[256]; // to store function name, line no present in logs. // {S,D,C,Z} {n incx incy} - while (fscanf(fin, "%s %c %ld %ld %ld\n", + while (fscanf(fin, "%s %c " INT_FS INT_FS INT_FS "\n", tmp, &dt_ch, &n, &incx, &incy) == 5) { diff --git a/bench/bench_syrk.c b/bench/bench_syrk.c index 017b010dfc..b65db83aa5 100644 --- a/bench/bench_syrk.c +++ b/bench/bench_syrk.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021-2022, Advanced Micro Devices, Inc. All rights reserved. modification, are permitted provided that the following conditions are met: @@ -120,7 +120,7 @@ int main( int argc, char** argv ) stor_scheme = 'C'; // since logs are collected at BLAS APIs // {S,D,C,Z}{ uploc, transa, n, k, alpha_real, alpha_imag, lda, beta_real, beta_imag, ldc} - while (fscanf(fin, "%s %c %c %c %ld %ld %lf %lf %lu %lf %lf %lu\n",\ + while (fscanf(fin, "%s %c %c %c " INT_FS INT_FS " %lf %lf " UINT_FS " %lf %lf " UINT_FS "\n",\ tmp, &dt_ch, &uplo_c, &transA_c, &n, &k, &alpha_r,\ &alpha_i, &lda, &beta_r, &beta_i, &ldc) == 12) { diff --git a/bench/bench_trsm.c b/bench/bench_trsm.c index a7d62ebecc..b2b7f1af18 100644 --- a/bench/bench_trsm.c +++ b/bench/bench_trsm.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020-2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020-2022, Advanced Micro Devices, Inc. 
All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -101,7 +101,7 @@ int main( int argc, char** argv ) f77_char dt_type_arg, side_arg, uploa_arg, transa_arg, diaga_arg; f77_char logline[255]; // input order: {S,D,C,Z} {side, uplo, transa, diag, m, n, lda, ldb, alphaR, alphaI} - while(fscanf(fin, "%s %c %c %c %c %c %ld %ld %ld %ld %lf %lf\n", + while(fscanf(fin, "%s %c %c %c %c %c " INT_FS INT_FS INT_FS INT_FS " %lf %lf\n", logline, &dt_type_arg, &side_arg, &uploa_arg, &transa_arg, &diaga_arg, &m, &n, &lda, &ldb, &alphaR, &alphaI) == 12) { diff --git a/bench/bench_trsv.c b/bench/bench_trsv.c index ca18d3fdc9..ddf3ea187a 100644 --- a/bench/bench_trsv.c +++ b/bench/bench_trsv.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021-2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -121,8 +121,7 @@ int main( int argc, char** argv ) fprintf(fout, "Dt uploa\t transa\t diaga\t m\t lda\t incx\t gflops\n"); // {S,D,C,Z} {uploa transa diaga m lda, incx} - - while (fscanf(fin, "%s %c %c %c %c %ld %ld %ld\n", + while (fscanf(fin, "%s %c %c %c %c " INT_FS INT_FS INT_FS "\n", tmp, &dt_ch, &uploa_c, &transA, &diaga_c, &m, &lda, &incx) == 8) { From 77e8492cbd12dcaa5a7ca0959ce667508d8d8dff Mon Sep 17 00:00:00 2001 From: Harish Date: Mon, 27 Jun 2022 18:39:01 +0530 Subject: [PATCH 139/243] Added znver4 flag for config builds with AOCC 4.0 compiler version Change-Id: I45f1031ed4c5ea2e3f594713f3821d6bbbecd4df --- config/zen4/make_defs.mk | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/config/zen4/make_defs.mk b/config/zen4/make_defs.mk index 85a8a39f62..f35d4515e0 100644 --- a/config/zen4/make_defs.mk +++ b/config/zen4/make_defs.mk @@ -97,8 +97,13 @@ ifeq ($(CC_VENDOR),clang) # AMD clang version 11.0.0 (CLANG: AOCC_2.3.0-Build#85 2020_11_10) (based on LLVM Mirror.Version.11.0.0) # AMD clang version 12.0.0 (CLANG: AOCC_3.0.0-Build#2 2020_11_05) (based on LLVM Mirror.Version.12.0.0) -# For our prupose we just want to know if it version 2x or 3x +# For our prupose we just want to know if it version 2x or 3x or 4x +# for version 4x we will enable znver4 +ifeq ($(strip $(shell $(CC) -v |&head -1 |grep -c 'AOCC_4')),1) +CKVECFLAGS += -march=znver4 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mfpmath=sse +CRVECFLAGS += -march=znver4 +else # for version 3x we will enable znver3 ifeq ($(strip $(shell $(CC) -v |&head -1 |grep -c 'AOCC_3')),1) CKVECFLAGS += -march=znver3 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mfpmath=sse From 2e4ed37e97cdb42399adc2b7fa55225ff5a16173 Mon Sep 17 00:00:00 2001 From: Harish Date: Tue, 28 Jun 2022 13:00:04 +0530 Subject: [PATCH 140/243] Added missing endif for the AOCC 4.0 verion check Change-Id: I1c77ae795c398aec685152b491b838a75e7ce318 --- config/zen4/make_defs.mk | 2 ++ 1 file changed, 2 insertions(+) diff --git a/config/zen4/make_defs.mk b/config/zen4/make_defs.mk index f35d4515e0..c6a3c545f5 100644 --- a/config/zen4/make_defs.mk +++ b/config/zen4/make_defs.mk @@ -96,6 +96,7 @@ ifeq ($(CC_VENDOR),clang) # AMD clang version 10.0.0 (CLANG: AOCC_2.2.0-Build#93 2020_06_25) (based on LLVM Mirror.Version.10.0.0) # AMD clang version 11.0.0 (CLANG: AOCC_2.3.0-Build#85 2020_11_10) 
(based on LLVM Mirror.Version.11.0.0) # AMD clang version 12.0.0 (CLANG: AOCC_3.0.0-Build#2 2020_11_05) (based on LLVM Mirror.Version.12.0.0) +# AMD clang version 14.0.0 (CLANG: AOCC_4.0.0-Build#98 2022_06_15) (based on LLVM Mirror.Version.14.0.0) # For our prupose we just want to know if it version 2x or 3x or 4x @@ -127,6 +128,7 @@ CRVECFLAGS += -march=znver1 endif # ge 9 endif # aocc 2 endif # aocc 3 +endif # aocc 4 endif # clang endif # gcc From d4bb906094a424050387d1f6a537311cdcb297ac Mon Sep 17 00:00:00 2001 From: Harihara Sudhan S Date: Mon, 27 Jun 2022 12:10:46 +0530 Subject: [PATCH 141/243] Exception handling in GEMV smart-threading - Added condition to check if n or m is 0 in smart threading logic AMD-Internal: [CPUPL2219] Change-Id: Idd58cd13a11aa5bdb4117b4c9262f38ef3c1afc4 --- frame/2/gemv/bli_gemv_unf_var1_amd.c | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/frame/2/gemv/bli_gemv_unf_var1_amd.c b/frame/2/gemv/bli_gemv_unf_var1_amd.c index 447f8dbc43..3347a133aa 100644 --- a/frame/2/gemv/bli_gemv_unf_var1_amd.c +++ b/frame/2/gemv/bli_gemv_unf_var1_amd.c @@ -343,6 +343,14 @@ void bli_sgemv_var1_smart_threading // Calculate the amount data processed per iteration dim_t n_per_loop = n / fuse; double data_per_iter = n_per_loop* m; + + // Exception handling when m-dimenstion or n-dimension is zero + if (bli_zero_dim2(m,n)) + { + *nt = 1; + return; + } + double m_n_ratio = m/n; // When the input value is less than the fuse factor From 2ba2fb2b63409afa58ece47f51e0bb3dec1c22ec Mon Sep 17 00:00:00 2001 From: Dipal M Zambare Date: Wed, 18 May 2022 11:01:41 +0530 Subject: [PATCH 142/243] Add AVX2 path for TRSM+GEMM combination. - Enabled AVX2 TRSM + GEMM kernel path, when GEMM is called from TRSM context it will invoke AVX2 GEMM kernels instead of the default AVX-512 GEMM kernels. - The default context has the block sizes for AVX512 GEMM kernels, however, TRSM uses AVX2 GEMM kernels and they need different block sizes. - Added new API bli_zen4_override_trsm_blkszs(). It overrides default block sizes in context with block sizes needed for AVX2 GEMM kernels. - Added new API bli_zen4_restore_default_blkszs(). It restores The block sizes to there default values (as needed by default AVX512 GEMM kernels). - Updated bli_trsm_front() to override the block sizes in the context needed by TRSM + AVX2 GEMM kernels and restore them to the default values at the end of this function. It is done in bli_trsm_front() so that we override the context before creating different threads. AMD-Internal: [CPUPL-2225] Change-Id: Ie92d0fc40f94a32dfb865fe3771dc14ed7884c55 --- config/amdzen/bli_family_amdzen.h | 26 ++++++- config/zen4/bli_cntx_init_zen4.c | 116 ++++++++++++++++++++++++---- config/zen4/bli_family_zen4.h | 24 ++++++ config/zen4/make_defs.mk | 8 +- frame/3/trsm/bli_trsm_front.c | 35 ++++++++- frame/3/trsm/bli_trsm_ll_ker_var2.c | 21 ++++- frame/3/trsm/bli_trsm_lu_ker_var2.c | 21 ++++- frame/3/trsm/bli_trsm_rl_ker_var2.c | 21 ++++- frame/3/trsm/bli_trsm_ru_ker_var2.c | 21 ++++- frame/base/bli_blksz.c | 8 +- frame/include/bli_type_defs.h | 5 +- testsuite/src/test_gemmtrsm_ukr.c | 39 +++++++++- 12 files changed, 311 insertions(+), 34 deletions(-) diff --git a/config/amdzen/bli_family_amdzen.h b/config/amdzen/bli_family_amdzen.h index c73409673d..1a8c1234a8 100644 --- a/config/amdzen/bli_family_amdzen.h +++ b/config/amdzen/bli_family_amdzen.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021, Advanced Micro Devices, Inc. 
All rights reserved. + Copyright (C) 2021-2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -61,5 +61,29 @@ //#define BLIS_ENABLE_FAST_MATH +/* + * Override the block sizes in the context to the block sizes used + * by AVX2 GEMM+TRSM kernels, this is needed in Zen4 context as default + * GEMM kernels are AVX512 based and uses different block sizes. + * + * This function should be called in TRSM path before performing + * any packing operations. + * + * Also the context must be restored to default values by calling + * bli_zen4_restore_default_blkszs() before exiting TRSM Path + */ +BLIS_EXPORT_BLIS void bli_zen4_override_trsm_blkszs (cntx_t* cntx); + +/* + * Restore the block sizes to default values needed for zen4 context. + * + * This function should be called to restore the block sizes to there + * default values if they where overriden by calling + * bli_zen4_override_trsm_blkszs() to enable AVX2 GEMM kernels in the + * TRSM path. + * + */ +BLIS_EXPORT_BLIS void bli_zen4_restore_default_blkszs (cntx_t* cntx); + #endif diff --git a/config/zen4/bli_cntx_init_zen4.c b/config/zen4/bli_cntx_init_zen4.c index b3a6d1030c..fc900a1e98 100644 --- a/config/zen4/bli_cntx_init_zen4.c +++ b/config/zen4/bli_cntx_init_zen4.c @@ -34,6 +34,24 @@ #include "blis.h" +/* + * List of default block sizes for zen4. + * Converted it to macro as this list is used at multiple places in this file. + */ + +#define BLI_CNTX_DEFAULT_BLKSZ_LIST(blkszs) \ + /* s d c z */ \ + bli_blksz_init_easy( &blkszs[ BLIS_MR ], 32, 16, 3, 3 ); \ + bli_blksz_init_easy( &blkszs[ BLIS_NR ], 12, 14, 8, 4 ); \ + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 480, 240, 144, 18 ); \ + bli_blksz_init ( &blkszs[ BLIS_KC ], 384, 512, 256, 566, \ + 480, 320, 256, 566 ); \ + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 3072, 4004, 4080, 256 ); \ + \ + bli_blksz_init_easy( &blkszs[ BLIS_AF ], 8, 8, -1, -1 ); \ + bli_blksz_init_easy( &blkszs[ BLIS_DF ], 8, 8, -1, -1 ); \ + + void bli_cntx_init_zen4( cntx_t* cntx ) { blksz_t blkszs[ BLIS_NUM_BLKSZS ]; @@ -47,20 +65,23 @@ void bli_cntx_init_zen4( cntx_t* cntx ) // their storage preferences. bli_cntx_set_l3_nat_ukrs ( - 4, + 10, // gemm BLIS_GEMM_UKR, BLIS_FLOAT , bli_sgemm_skx_asm_32x12_l2, FALSE, BLIS_GEMM_UKR, BLIS_DOUBLE, bli_dgemm_skx_asm_16x14, FALSE, BLIS_GEMM_UKR, BLIS_SCOMPLEX, bli_cgemm_haswell_asm_3x8, TRUE, BLIS_GEMM_UKR, BLIS_DCOMPLEX, bli_zgemm_haswell_asm_3x4, TRUE, -#if 0 // GENOA TODO: TRSM AVX-512 implementation + + BLIS_GEMM_AVX2_UKR, BLIS_FLOAT, bli_sgemm_haswell_asm_6x16, TRUE, + BLIS_GEMM_AVX2_UKR, BLIS_DOUBLE, bli_dgemm_haswell_asm_6x8, TRUE, + // gemmtrsm_l BLIS_GEMMTRSM_L_UKR, BLIS_FLOAT, bli_sgemmtrsm_l_haswell_asm_6x16, TRUE, BLIS_GEMMTRSM_L_UKR, BLIS_DOUBLE, bli_dgemmtrsm_l_haswell_asm_6x8, TRUE, // gemmtrsm_u BLIS_GEMMTRSM_U_UKR, BLIS_FLOAT, bli_sgemmtrsm_u_haswell_asm_6x16, TRUE, BLIS_GEMMTRSM_U_UKR, BLIS_DOUBLE, bli_dgemmtrsm_u_haswell_asm_6x8, TRUE, -#endif + cntx ); @@ -115,7 +136,7 @@ void bli_cntx_init_zen4( cntx_t* cntx ) 24, // amaxv - BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int_avx512, + BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int_avx512, BLIS_AMAXV_KER, BLIS_DOUBLE, bli_damaxv_zen_int_avx512, // axpbyv @@ -162,17 +183,8 @@ void bli_cntx_init_zen4( cntx_t* cntx ) // // These are reference block sizes and may be overridden based on // number of threads used at runtime. 
- - // s d c z - bli_blksz_init_easy( &blkszs[ BLIS_MR ], 32, 16, 3, 3 ); - bli_blksz_init_easy( &blkszs[ BLIS_NR ], 12, 14, 8, 4 ); - bli_blksz_init_easy( &blkszs[ BLIS_MC ], 480, 240, 144, 18 ); - bli_blksz_init ( &blkszs[ BLIS_KC ], 384, 512, 256, 566, - 480, 320, 256, 566 ); - bli_blksz_init_easy( &blkszs[ BLIS_NC ], 3072, 4004, 4080, 256 ); - - bli_blksz_init_easy( &blkszs[ BLIS_AF ], 8, 8, -1, -1 ); - bli_blksz_init_easy( &blkszs[ BLIS_DF ], 8, 8, -1, -1 ); + + BLI_CNTX_DEFAULT_BLKSZ_LIST(blkszs); // Update the context with the current architecture's register and cache // blocksizes (and multiples) for native execution. @@ -192,11 +204,14 @@ void bli_cntx_init_zen4( cntx_t* cntx ) ); // ------------------------------------------------------------------------- -#if 0 // GENOA TODO: TRSM AVX-512 implementation +#if 0 // Replaced with runtime blocksize override + //Initialize TRSM blocksize objects with architecture-specific values. //Using different cache block sizes for TRSM instead of common level-3 block sizes. //Tuning is done for double-precision only. // s d c z + bli_blksz_init_easy( &blkszs[ BLIS_MR ], 6, 6, 3, 3 ); + bli_blksz_init_easy( &blkszs[ BLIS_NR ], 16, 8, 8, 4 ); bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 72, 144, 72 ); bli_blksz_init_easy( &blkszs[ BLIS_KC ], 256, 492, 256, 256 ); bli_blksz_init_easy( &blkszs[ BLIS_NC ], 4080, 1600, 4080, 4080 ); @@ -298,3 +313,72 @@ void bli_cntx_init_zen4( cntx_t* cntx ) cntx ); } + +/* + * Override the block sizes in the context to the block sizes used + * by AVX2 GEMM+TRSM kernels, this is needed in Zen4 context as default + * GEMM kernels are AVX512 based and uses different block sizes. + * + * This function should be called in TRSM path before performing + * any packing operations. + * + * Also the context must be restored to default values by calling + * bli_zen4_restore_default_blkszs() before exiting TRSM Path + */ +void bli_zen4_override_trsm_blkszs (cntx_t* cntx) +{ + blksz_t blkszs[ BLIS_NUM_BLKSZS ]; + bli_blksz_init_easy( &blkszs[ BLIS_MR ], 6, 6, 3, 3 ); + bli_blksz_init_easy( &blkszs[ BLIS_NR ], 16, 8, 8, 4 ); + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 72, 144, 72 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], 256, 492, 256, 256 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 4080, 1600, 4080, 4080 ); + + + // Update the context with the current architecture's register and cache + // blocksizes (and multiples) for native execution. + bli_cntx_set_blkszs + ( + BLIS_NAT, 5, + // level-3 + BLIS_NC, &blkszs[ BLIS_NC ], BLIS_NR, + BLIS_KC, &blkszs[ BLIS_KC ], BLIS_KR, + BLIS_MC, &blkszs[ BLIS_MC ], BLIS_MR, + BLIS_NR, &blkszs[ BLIS_NR ], BLIS_NR, + BLIS_MR, &blkszs[ BLIS_MR ], BLIS_MR, + cntx + ); +} + +/* + * Restore the block sizes to default values needed for zen4 context. + * + * This function should be called to restore the block sizes to there + * default values if they where overriden by calling + * bli_zen4_override_trsm_blkszs() to enable AVX2 GEMM kernels in the + * TRSM path. + * + */ +void bli_zen4_restore_default_blkszs (cntx_t* cntx) +{ + blksz_t blkszs[ BLIS_NUM_BLKSZS ]; + + BLI_CNTX_DEFAULT_BLKSZ_LIST(blkszs); + + // Update the context with the current architecture's register and cache + // blocksizes (and multiples) for native execution. 
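+    // Note: seven blocksize entries are restored here (the five level-3
+    // blocksizes plus the level-1f AF/DF fusing factors), whereas
+    // bli_zen4_override_trsm_blkszs() only overrides the five level-3
+    // entries used by the AVX2 GEMMTRSM path.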
+ bli_cntx_set_blkszs + ( + BLIS_NAT, 7, + // level-3 + BLIS_NC, &blkszs[ BLIS_NC ], BLIS_NR, + BLIS_KC, &blkszs[ BLIS_KC ], BLIS_KR, + BLIS_MC, &blkszs[ BLIS_MC ], BLIS_MR, + BLIS_NR, &blkszs[ BLIS_NR ], BLIS_NR, + BLIS_MR, &blkszs[ BLIS_MR ], BLIS_MR, + // level-1f + BLIS_AF, &blkszs[ BLIS_AF ], BLIS_AF, + BLIS_DF, &blkszs[ BLIS_DF ], BLIS_DF, + cntx + ); +} \ No newline at end of file diff --git a/config/zen4/bli_family_zen4.h b/config/zen4/bli_family_zen4.h index 71929cdac4..fad5f16986 100644 --- a/config/zen4/bli_family_zen4.h +++ b/config/zen4/bli_family_zen4.h @@ -62,4 +62,28 @@ #define BLIS_SIMD_SIZE 64 #define BLIS_SIMD_NUM_REGISTERS 32 +/* + * Override the block sizes in the context to the block sizes used + * by AVX2 GEMM+TRSM kernels, this is needed in Zen4 context as default + * GEMM kernels are AVX512 based and uses different block sizes. + * + * This function should be called in TRSM path before performing + * any packing operations. + * + * Also the context must be restored to default values by calling + * bli_zen4_restore_default_blkszs() before exiting TRSM Path + */ +BLIS_EXPORT_BLIS void bli_zen4_override_trsm_blkszs (cntx_t* cntx); + +/* + * Restore the block sizes to default values needed for zen4 context. + * + * This function should be called to restore the block sizes to there + * default values if they where overriden by calling + * bli_zen4_override_trsm_blkszs() to enable AVX2 GEMM kernels in the + * TRSM path. + * + */ +BLIS_EXPORT_BLIS void bli_zen4_restore_default_blkszs (cntx_t* cntx); + #endif diff --git a/config/zen4/make_defs.mk b/config/zen4/make_defs.mk index c6a3c545f5..75bec7018e 100644 --- a/config/zen4/make_defs.mk +++ b/config/zen4/make_defs.mk @@ -73,11 +73,11 @@ GCC_VERSION := $(strip $(shell $(CC) -dumpversion | cut -d. -f1)) # gcc or clang version must be atleast 4.0 # gcc 9.0 or later: ifeq ($(shell test $(GCC_VERSION) -ge 11; echo $$?),0) -CKVECFLAGS += -march=znver3 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mfpmath=sse +CKVECFLAGS += -march=znver3 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mfpmath=sse CRVECFLAGS += -march=znver3 else ifeq ($(shell test $(GCC_VERSION) -ge 9; echo $$?),0) -CKVECFLAGS += -march=znver2 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mfpmath=sse +CKVECFLAGS += -march=znver2 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mfpmath=sse CRVECFLAGS += -march=znver2 else # If gcc is older than 9.1.0 but at least 6.1.0, then we can use -march=znver1 @@ -107,12 +107,12 @@ CRVECFLAGS += -march=znver4 else # for version 3x we will enable znver3 ifeq ($(strip $(shell $(CC) -v |&head -1 |grep -c 'AOCC_3')),1) -CKVECFLAGS += -march=znver3 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mfpmath=sse +CKVECFLAGS += -march=znver3 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mfpmath=sse CRVECFLAGS += -march=znver3 else # for version 2x we will enable znver2 ifeq ($(strip $(shell $(CC) -v |&head -1 |grep -c 'AOCC.LLVM.2\|AOCC_2')),1) -CKVECFLAGS += -march=znver2 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mfpmath=sse +CKVECFLAGS += -march=znver2 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mfpmath=sse CRVECFLAGS += -march=znver2 else #if compiling with clang diff --git a/frame/3/trsm/bli_trsm_front.c b/frame/3/trsm/bli_trsm_front.c index f964faf0dd..9eddd5c42a 100644 --- a/frame/3/trsm/bli_trsm_front.c +++ b/frame/3/trsm/bli_trsm_front.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2020, Advanced Micro Devices, Inc. 
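+     *
+     * Both the override at the top of this function and this restore are
+     * placed in bli_trsm_front() so that the shared context is updated
+     * once, before the different threads are created.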
All rights reserved. + Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -36,6 +36,7 @@ #include "blis.h" //#define PRINT_SMALL_TRSM_INFO + void bli_trsm_front ( side_t side, @@ -151,6 +152,24 @@ void bli_trsm_front // in bli_packm_init(). if ( bli_cntx_method( cntx ) == BLIS_NAT ) { +#if defined(BLIS_FAMILY_AMDZEN) || defined(BLIS_FAMILY_ZEN4) + /* Zen4 TRSM Fixme: + * + * On Zen4 we want to use AVX-512 kernels for GEMM and AVX2 kernels + * for TRSM (Till we implemente TRSM AVX-512 kernels) + * + * The AVX2 kernels use different block sizes then AVX512 kernels + * Here we override the default block sizes in the context with AVX2 + * specific block size used in GEMMTRSM kernerls. + * + * We need to revisit this when TRSM AVX-512 kernels are implemented. + */ + if ( (bli_arch_query_id() == BLIS_ARCH_ZEN4) && + (bli_obj_dt(a) == BLIS_FLOAT || bli_obj_dt(a) == BLIS_DOUBLE) ) + { + bli_zen4_override_trsm_blkszs(cntx); + } +#endif bli_obj_set_pack_schema( BLIS_PACKED_ROW_PANELS, &a_local ); bli_obj_set_pack_schema( BLIS_PACKED_COL_PANELS, &b_local ); } @@ -177,6 +196,20 @@ void bli_trsm_front rntm, cntl ); + +#if defined(BLIS_FAMILY_AMDZEN) || defined(BLIS_FAMILY_ZEN4) + /* Zen4 TRSM Fixme: + * + * We have overrding the block sizes at the start of this function + * Since the context is created only once we need to ensure that the + * default block sizes are restored for the subsequent operations. + */ + if ( (bli_arch_query_id() == BLIS_ARCH_ZEN4) && + (bli_obj_dt(a) == BLIS_FLOAT || bli_obj_dt(a) == BLIS_DOUBLE) ) + { + bli_zen4_restore_default_blkszs(cntx); + } +#endif AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); } diff --git a/frame/3/trsm/bli_trsm_ll_ker_var2.c b/frame/3/trsm/bli_trsm_ll_ker_var2.c index 5426348c83..fe39e6f478 100644 --- a/frame/3/trsm/bli_trsm_ll_ker_var2.c +++ b/frame/3/trsm/bli_trsm_ll_ker_var2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2020, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -174,6 +174,25 @@ void PASTEMAC(ch,varname) \ gemmtrsm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMMTRSM_L_UKR, cntx ); \ PASTECH(ch,gemm_ukr_ft) \ gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \ +\ + /* Zen4 TRSM Fixme: + * + * On Zen4 we want to use AVX-512 kernels for GEMM and AVX2 kernels + * for TRSM (Till we implemente TRSM AVX-512 kernels) + * + * The AVX2 kernels for TRSM are enabled in the context, but they + * are compatible with only AVX2 version of GEMM kernels. + * + * Here we force the GEMM kernels to the AVX2 varients for float and double. + * For scomplex and dcomplex reference path is retained as is. + * + * We need to revisit this when TRSM AVX-512 kernels are implemented. + */ \ + if ((bli_arch_query_id() == BLIS_ARCH_ZEN4) && \ + (dt == BLIS_FLOAT || dt == BLIS_DOUBLE)) \ + { \ + gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_AVX2_UKR, cntx ); \ + } \ \ /* Temporary C buffer for edge cases. 
Note that the strides of this temporary buffer are set so that they match the storage of the diff --git a/frame/3/trsm/bli_trsm_lu_ker_var2.c b/frame/3/trsm/bli_trsm_lu_ker_var2.c index 0d4e2e0ba6..e55b75dff4 100644 --- a/frame/3/trsm/bli_trsm_lu_ker_var2.c +++ b/frame/3/trsm/bli_trsm_lu_ker_var2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2020, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -174,6 +174,25 @@ void PASTEMAC(ch,varname) \ gemmtrsm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMMTRSM_U_UKR, cntx ); \ PASTECH(ch,gemm_ukr_ft) \ gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \ +\ + /* Zen4 TRSM Fixme: + * + * On Zen4 we want to use AVX-512 kernels for GEMM and AVX2 kernels + * for TRSM (Till we implemente TRSM AVX-512 kernels) + * + * The AVX2 kernels for TRSM are enabled in the context, but they + * are compatible with only AVX2 version of GEMM kernels. + * + * Here we force the GEMM kernels to the AVX2 varients for float and double. + * For scomplex and dcomplex reference path is retained as is. + * + * We need to revisit this when TRSM AVX-512 kernels are implemented. + */ \ + if ((bli_arch_query_id() == BLIS_ARCH_ZEN4) && \ + (dt == BLIS_FLOAT || dt == BLIS_DOUBLE)) \ + { \ + gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_AVX2_UKR, cntx ); \ + } \ \ /* Temporary C buffer for edge cases. Note that the strides of this temporary buffer are set so that they match the storage of the diff --git a/frame/3/trsm/bli_trsm_rl_ker_var2.c b/frame/3/trsm/bli_trsm_rl_ker_var2.c index 396fb4af12..5e070a760b 100644 --- a/frame/3/trsm/bli_trsm_rl_ker_var2.c +++ b/frame/3/trsm/bli_trsm_rl_ker_var2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2020, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -180,6 +180,25 @@ void PASTEMAC(ch,varname) \ gemmtrsm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMMTRSM_U_UKR, cntx ); \ PASTECH(ch,gemm_ukr_ft) \ gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \ +\ + /* Zen4 TRSM Fixme: + * + * On Zen4 we want to use AVX-512 kernels for GEMM and AVX2 kernels + * for TRSM (Till we implemente TRSM AVX-512 kernels) + * + * The AVX2 kernels for TRSM are enabled in the context, but they + * are compatible with only AVX2 version of GEMM kernels. + * + * Here we force the GEMM kernels to the AVX2 varients for float and double. + * For scomplex and dcomplex reference path is retained as is. + * + * We need to revisit this when TRSM AVX-512 kernels are implemented. + */ \ + if ((bli_arch_query_id() == BLIS_ARCH_ZEN4) && \ + (dt == BLIS_FLOAT || dt == BLIS_DOUBLE)) \ + { \ + gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_AVX2_UKR, cntx ); \ + } \ \ /* Temporary C buffer for edge cases. 
Note that the strides of this temporary buffer are set so that they match the storage of the diff --git a/frame/3/trsm/bli_trsm_ru_ker_var2.c b/frame/3/trsm/bli_trsm_ru_ker_var2.c index 8b73b702f0..b592c24276 100644 --- a/frame/3/trsm/bli_trsm_ru_ker_var2.c +++ b/frame/3/trsm/bli_trsm_ru_ker_var2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2020, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -179,6 +179,25 @@ void PASTEMAC(ch,varname) \ gemmtrsm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMMTRSM_L_UKR, cntx ); \ PASTECH(ch,gemm_ukr_ft) \ gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_UKR, cntx ); \ +\ + /* Zen4 TRSM Fixme: + * + * On Zen4 we want to use AVX-512 kernels for GEMM and AVX2 kernels + * for TRSM (Till we implemente TRSM AVX-512 kernels) + * + * The AVX2 kernels for TRSM are enabled in the context, but they + * are compatible with only AVX2 version of GEMM kernels. + * + * Here we force the GEMM kernels to the AVX2 varients for float and double. + * For scomplex and dcomplex reference path is retained as is. + * + * We need to revisit this when TRSM AVX-512 kernels are implemented. + */ \ + if ((bli_arch_query_id() == BLIS_ARCH_ZEN4) && \ + (dt == BLIS_FLOAT || dt == BLIS_DOUBLE)) \ + { \ + gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_AVX2_UKR, cntx ); \ + } \ \ /* Temporary C buffer for edge cases. Note that the strides of this temporary buffer are set so that they match the storage of the diff --git a/frame/base/bli_blksz.c b/frame/base/bli_blksz.c index f3891dbbba..a4d937bc2f 100644 --- a/frame/base/bli_blksz.c +++ b/frame/base/bli_blksz.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -272,7 +272,7 @@ dim_t bli_determine_blocksize_f b_alg = bli_blksz_get_def( dt, bsize ); b_max = bli_blksz_get_max( dt, bsize ); - // If b_use != 0, this means that trsm blocksizes are set + // If b_alg != 0, this means that trsm blocksizes are set // and we continue with trsm-specific blocksizes. // Else, we query L3 blocksizes and use them for TRSM execution. if( b_alg > 0 ) return bli_determine_blocksize_f_sub( i, dim, b_alg, b_max); @@ -313,10 +313,10 @@ dim_t bli_determine_blocksize_b b_alg = bli_blksz_get_def( dt, bsize ); b_max = bli_blksz_get_max( dt, bsize ); - // If b_use != 0, this means that trsm blocksizes are set + // If b_alg != 0, this means that trsm blocksizes are set // and we continue with trsm-specific blocksizes. // Else, we query L3 blocksizes and use them for TRSM execution. 
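+    // Without the return below, the result of the TRSM-specific blocksize
+    // computation was discarded and execution fell through to the generic
+    // L3 blocksize query.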
- if( b_alg > 0 ) bli_determine_blocksize_b_sub( i, dim, b_alg, b_max ); + if( b_alg > 0 ) return bli_determine_blocksize_b_sub( i, dim, b_alg, b_max ); } diff --git a/frame/include/bli_type_defs.h b/frame/include/bli_type_defs.h index 584c221ba0..4e28d8b461 100644 --- a/frame/include/bli_type_defs.h +++ b/frame/include/bli_type_defs.h @@ -802,10 +802,11 @@ typedef enum BLIS_GEMMTRSM_L_UKR, BLIS_GEMMTRSM_U_UKR, BLIS_TRSM_L_UKR, - BLIS_TRSM_U_UKR + BLIS_TRSM_U_UKR, + BLIS_GEMM_AVX2_UKR } l3ukr_t; -#define BLIS_NUM_LEVEL3_UKRS 5 +#define BLIS_NUM_LEVEL3_UKRS 6 typedef enum diff --git a/testsuite/src/test_gemmtrsm_ukr.c b/testsuite/src/test_gemmtrsm_ukr.c index b3916db6a1..a0cec45b92 100644 --- a/testsuite/src/test_gemmtrsm_ukr.c +++ b/testsuite/src/test_gemmtrsm_ukr.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -209,13 +209,32 @@ void libblis_test_gemmtrsm_ukr_experiment // Query a context. cntx = bli_gks_query_cntx(); +#if defined(BLIS_FAMILY_AMDZEN) || defined(BLIS_FAMILY_ZEN4) + /* Zen4 TRSM Fixme: + * + * TRSM and GEMM used different values of MR and NR, we need to ensure that + * Values used for packing are as per the MR and NR values expected by the kernels + * For now this issue exists only for zen4 hence override the values here if + * the family is BLIS_TRSM and architecture is zen4 + * + * We need to override the values here as well as the packing and compute + * kernels are invoked directly from here (instead of BLIS/BLAS call.) + * + * We need to revisit this when TRSM AVX-512 kernels are implemented. + */ + if (bli_arch_query_id() == BLIS_ARCH_ZEN4) + { + bli_zen4_override_trsm_blkszs(cntx); + } +#endif + // Use the datatype of the first char in the datatype combination string. bli_param_map_char_to_blis_dt( dc_str[0], &datatype ); // Map the dimension specifier to actual dimensions. k = libblis_test_get_dim_from_prob_size( op->dim_spec[0], p_cur ); - // Fix m and n to MR and NR, respectively. + m = bli_cntx_get_blksz_def_dt( datatype, BLIS_MR, cntx ); n = bli_cntx_get_blksz_def_dt( datatype, BLIS_NR, cntx ); @@ -224,6 +243,7 @@ void libblis_test_gemmtrsm_ukr_experiment ldap = bli_cntx_get_blksz_max_dt( datatype, BLIS_MR, cntx ); ldbp = bli_cntx_get_blksz_max_dt( datatype, BLIS_NR, cntx ); + // Store the register blocksizes so that the driver can retrieve the // values later when printing results. op->dim_aux[0] = m; @@ -433,6 +453,7 @@ bli_printm( "ap", &ap, "%5.2f", "" ); bli_cntl_free( cntl_b, &BLIS_PACKM_SINGLE_THREADED ); #endif + // Free the packed objects. bli_obj_free( &ap ); bli_obj_free( &bp ); @@ -442,6 +463,20 @@ bli_printm( "ap", &ap, "%5.2f", "" ); bli_obj_free( &b ); bli_obj_free( &c11 ); bli_obj_free( &c11_save ); + +#if defined(BLIS_FAMILY_AMDZEN) || defined(BLIS_FAMILY_ZEN4) + /* Zen4 TRSM Fixme: + * + * We have overrding the block sizes at the start of this function + * Since the context is created only once we need to ensure that the + * default block sizes are restored for the subsequent operations. 
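+     *
+     * The same override/restore pairing is used in bli_trsm_front() for
+     * the library TRSM path.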
+ */ + if (bli_arch_query_id() == BLIS_ARCH_ZEN4) + { + bli_zen4_restore_default_blkszs(cntx); + } +#endif + } From 25cf7517abf2095c625a908c1f3fa53eb41cfbf6 Mon Sep 17 00:00:00 2001 From: Arnav Sharma Date: Thu, 23 Jun 2022 12:43:51 +0530 Subject: [PATCH 143/243] AOCL Dynamic Optimization for DGEMMT - Optimized thread allocation for cases with n <= 220 for DGEMMT. AMD-Internal: [CPUPL-2215] Change-Id: Id01edf268a90fd96a41ef947db54f6afc490548f --- frame/base/bli_rntm.c | 54 +++++++++++++++++++++++++++++++++++++------ 1 file changed, 47 insertions(+), 7 deletions(-) diff --git a/frame/base/bli_rntm.c b/frame/base/bli_rntm.c index fbf5654b7a..0db51870eb 100644 --- a/frame/base/bli_rntm.c +++ b/frame/base/bli_rntm.c @@ -679,17 +679,57 @@ void bli_nthreads_optimum( { dim_t n = bli_obj_length(c); dim_t k = bli_obj_width_after_trans(a); - dim_t product = (n*k)>>4; /* product is derived based on n and k */ - //Limit the number thread for smaller sizes: - if(product <= 346) + if ( n < 32 ) { - n_threads_ideal = 1; + if ( k < 128 ) + { + n_threads_ideal = 1; + } + else if ( k == 128 ) + { + n_threads_ideal = 4; + } } - /* finer threshold needs to set for max_thread cap of 2,3,4,5,6..32 */ - else + else if ( n <= 40 ) { - n_threads_ideal = n_threads; + if ( k < 32 ) + { + n_threads_ideal = 2; + } + else if ( k < 128 ) + { + n_threads_ideal = 4; + } + else if ( k <= 256 ) + { + n_threads_ideal = 8; + } + } + else if ( n < 115 ) + { + if ( k < 128 ) + { + n_threads_ideal = 6; + } + else if ( k <= 216 ) + { + n_threads_ideal = 8; + } + } + else if ( n <= 160 ) + { + if ( k <= 132 ) + { + n_threads_ideal = 8; + } + } + else if ( n <= 220 ) + { + if ( k < 128 ) + { + n_threads_ideal = 8; + } } } else if( family == BLIS_TRMM && bli_obj_is_double(c)) From 2ad25a7180f8b34af0a14760ee04a8414e564ab2 Mon Sep 17 00:00:00 2001 From: Vignesh Balasubramanian Date: Mon, 27 Jun 2022 18:24:19 +0530 Subject: [PATCH 144/243] ZGEMM kernel performance improvement for k=1 sizes: The current implementation for handling zgemm exploits SIMD parallelism along the k dimension. This would give great performance in cases of k being large. But for input sizes with k=1, it is better to exploit SIMD parallelism along the m and n dimensions, thereby giving better performance. This commit does the same through loop reordering, by loading column vectors from A. AMD-Internal: [CPUPL-2236] Change-Id: Ibfa29f271395497b6e2d0127c319ecb4b883d19f --- frame/compat/bla_gemm_amd.c | 15 + kernels/zen/3/CMakeLists.txt | 3 +- kernels/zen/3/bli_zgemm_ref_k1.c | 1790 ++++++++++++++++++++++++++++++ kernels/zen/bli_kernels_zen.h | 12 + 4 files changed, 1819 insertions(+), 1 deletion(-) create mode 100644 kernels/zen/3/bli_zgemm_ref_k1.c diff --git a/frame/compat/bla_gemm_amd.c b/frame/compat/bla_gemm_amd.c index 99d7371778..eed041e4cf 100644 --- a/frame/compat/bla_gemm_amd.c +++ b/frame/compat/bla_gemm_amd.c @@ -712,6 +712,21 @@ void zgemm_ //dim_t nt = bli_thread_get_num_threads(); // get number of threads bool nt = bli_thread_get_is_parallel(); // Check if parallel zgemm is invoked. 
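+    // Fast path for k == 1: when zgemm is invoked single-threaded with no
+    // transpose on A or B, bli_zgemm_ref_k1_nn() is used. With k == 1 there
+    // is no accumulation along the k dimension, so that kernel vectorizes
+    // along the m and n dimensions instead.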
+ if((nt==0) && (k0 == 1) && bli_is_notrans(blis_transa) && bli_is_notrans(blis_transb)) + { + bli_zgemm_ref_k1_nn( m0, n0, k0, + alpha, + a, *lda, + b, *ldb, + beta, + c, *ldc); + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS */ + bli_finalize_auto(); + + return; + } #ifdef BLIS_ENABLE_SMALL_MATRIX diff --git a/kernels/zen/3/CMakeLists.txt b/kernels/zen/3/CMakeLists.txt index 80f78b471b..a52740ecaf 100644 --- a/kernels/zen/3/CMakeLists.txt +++ b/kernels/zen/3/CMakeLists.txt @@ -1,4 +1,4 @@ -##Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.## +##Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All rights reserved.## target_sources("${PROJECT_NAME}" PRIVATE @@ -7,6 +7,7 @@ target_sources("${PROJECT_NAME}" ${CMAKE_CURRENT_SOURCE_DIR}/bli_dgemm_ref_k1.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemm_sqp_kernels.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemm_sqp.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_zgemm_ref_k1.c ) add_subdirectory(sup) diff --git a/kernels/zen/3/bli_zgemm_ref_k1.c b/kernels/zen/3/bli_zgemm_ref_k1.c new file mode 100644 index 0000000000..143bf8c84d --- /dev/null +++ b/kernels/zen/3/bli_zgemm_ref_k1.c @@ -0,0 +1,1790 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +*/ + +#include +#include +#include "blis.h" + +#include "immintrin.h" + +#define Z_MR 4 +#define Z_NR 5 + +// Macros for the main loop for M +#define SCALE_ALPHA_REAL_M_LOOP(r0,r1,r2,valr) \ + if(valr != 0.0) \ + { \ + r2 = _mm256_broadcast_sd((double const *)(&valr)); \ + r0 = _mm256_mul_pd(r0,r2); \ + r1 = _mm256_mul_pd(r1,r2); \ + } \ + +#define SCALE_ALPHA_IMAG_M_LOOP(r0,r1,r2,r3,rc1,rc2,vali) \ + if(vali != 0.0) \ + { \ + r3 = _mm256_permute4x64_pd(rc1,0b10110001); \ + r2 = _mm256_set_pd(1.0,-1.0,1.0,-1.0); \ + r3 = _mm256_mul_pd(r3, r2); \ + r2 = _mm256_broadcast_sd((double const *)(&vali)); \ + r0 = _mm256_fmadd_pd(r3,r2,r0); \ + r3 = _mm256_permute4x64_pd(rc2,0b10110001); \ + r2 = _mm256_set_pd(1.0,-1.0,1.0,-1.0); \ + r3 = _mm256_mul_pd(r3, r2); \ + r2 = _mm256_broadcast_sd((double const *)(&vali)); \ + r1 = _mm256_fmadd_pd(r3,r2,r1); \ + } \ + +#define NEG_PERM_M_LOOP(r0,r1,r2) \ + r0 = _mm256_permute4x64_pd(r0,0b10110001); \ + r1 = _mm256_permute4x64_pd(r1,0b10110001); \ + r2 = _mm256_set_pd(1.0,-1.0,1.0,-1.0); \ + r0 = _mm256_mul_pd(r2, r0); \ + r1 = _mm256_mul_pd(r2, r1); \ + +#define FMA_M_LOOP(rin_0,rin_1,rout_0,rout_1,rbc,loc) \ + rbc = _mm256_broadcast_sd(loc); \ + rout_0 = _mm256_fmadd_pd(rbc, rin_0, rout_0); \ + rout_1 = _mm256_fmadd_pd(rbc, rin_1, rout_1); \ + +#define SCALE_BETA_REAL_M_LOOP(rin_0,rin_1,rout_0,rout_1,rbc) \ + rout_0 = _mm256_fmadd_pd(rbc, rin_0, rout_0); \ + rout_1 = _mm256_fmadd_pd(rbc, rin_1, rout_1); \ + +#define SCALE_BETA_IMAG_M_LOOP(rin_0,rin_1,rout_0,rout_1,rbc,rn) \ + NEG_PERM_M_LOOP(rin_0,rin_1,rn); \ + rout_0 = _mm256_fmadd_pd(rbc, rin_0, rout_0); \ + rout_1 = _mm256_fmadd_pd(rbc, rin_1, rout_1); \ + + +// Macros for fringe cases with M +#define SCALE_ALPHA_REAL_M_FRINGE(r0,r2,val) \ + if(val != 0.0) \ + { \ + r2 = _mm256_broadcast_sd((double const *)(&val)); \ + r0 = _mm256_mul_pd(r0,r2); \ + } \ + +#define SCALE_ALPHA_IMAG_M_FRINGE(r0,r2,r3,r4,val) \ + if(val != 0.0) \ + { \ + r3 = _mm256_permute4x64_pd(r4,0b10110001); \ + r2 = _mm256_set_pd(1.0,-1.0,1.0,-1.0); \ + r3 = _mm256_mul_pd(r3, r2); \ + r2 = _mm256_broadcast_sd((double const *)(&val)); \ + r0 = _mm256_fmadd_pd(r3,r2,r0); \ + } \ + +#define NEG_PERM_M_FRINGE(r0,r2) \ + r0 = _mm256_permute4x64_pd(r0,0b10110001); \ + r2 = _mm256_set_pd(1.0,-1.0,1.0,-1.0); \ + r0 = _mm256_mul_pd(r2, r0); \ + +#define FMA_M_FRINGE(r_in,r_out,r_bc,loc) \ + r_bc = _mm256_broadcast_sd(loc); \ + r_out = _mm256_fmadd_pd(r_bc, r_in, r_out); \ + +#define SCALE_BETA_REAL_M_FRINGE(rin_0,rout_0,rbc) \ + rout_0 = _mm256_fmadd_pd(rbc, rin_0, rout_0); \ + +#define SCALE_BETA_IMAG_M_FRINGE(rin_0,rout_0,rbc,rn) \ + NEG_PERM_M_FRINGE(rin_0,rn); \ + rout_0 = _mm256_fmadd_pd(rbc, rin_0, rout_0); \ + + +void bli_zgemm_ref_k1_nn +( + dim_t m, + dim_t n, + dim_t k, + dcomplex* alpha, + dcomplex* a, const inc_t lda, + dcomplex* b, const inc_t ldb, + dcomplex* beta, + dcomplex* c, const inc_t ldc + ) +{ + + double alpha_valr, beta_valr; + double alpha_vali, beta_vali; + + alpha_valr = alpha->real; + beta_valr = beta->real; + alpha_vali = alpha->imag; + beta_vali = beta->imag; + + if((m == 0) || (n == 0) || (((alpha_valr == 0.0 && alpha_vali == 0.0) || (k == 0)) + && (beta_valr == 1.0 && beta_vali == 0.0))) + { + return; + } + dim_t m_remainder = (m % Z_MR); + dim_t n_remainder = (n % Z_NR); + + //scratch registers + __m256d ymm0, ymm1, ymm2, ymm3; + __m256d ymm4, ymm5, ymm6, ymm7; + __m256d ymm8, ymm9, ymm10, ymm11; + __m256d ymm12, ymm13, ymm14, ymm15; + __m128d xmm5; + + /* Form C = alpha*A*B + beta*c */ + for(dim_t j = 
0;j < (n-Z_NR+1);j=j+Z_NR) + { + dcomplex* temp_b = b + j*ldb; + dcomplex* temp_a = a; + dcomplex* temp_c = c + j*ldc; + + for(dim_t i = 0;i < (m-Z_MR+1);i=i+Z_MR) + { + ymm3 = _mm256_setzero_pd(); + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + ymm12 = _mm256_setzero_pd(); + + if(alpha_valr != 0.0 || alpha_vali != 0.0) + { + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_valr, alpha_vali + where alpha_valr and/or alpha_vali is not zero. + b. This loop operates with 4x5 block size + along n dimension for every Z_NR columns of temp_b where + computing all Z_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. + */ + //R(a[0][0]) I(a[0][0]) R(a[1][0]) I(a[1][0]) + ymm0 = _mm256_loadu_pd((double const *)(temp_a)); + //R(a[2][0]) I(a[2][0]) R(a[3][0]) I(a[3][0]) + ymm1 = _mm256_loadu_pd((double const *)(temp_a + 2)); + + ymm13 = ymm0; + ymm14 = ymm1; + + _mm_prefetch((char*)(temp_a) + 64, _MM_HINT_T0); + + SCALE_ALPHA_REAL_M_LOOP(ymm0,ymm1,ymm15,alpha_valr); + SCALE_ALPHA_IMAG_M_LOOP(ymm0,ymm1,ymm15,ymm2,ymm13,ymm14,alpha_vali); + + /* + The result after scaling with alpha_valr and/or alpha_vali is as follows: + For ymm0 : + R(a[0][0]) = alpha_valr*R(a[0][0])-alpha_vali*I(a[0][0]) + I(a[0][0]) = alpha_valr*I(a[0][0])+alpha_vali*R[0][0] + R(a[1][0]) = alpha_valr*R(a[1][0])-alpha_vali*I(a[1][0]) + I(a[1][0]) = alpha_valr*I(a[1][0])+alpha_vali*(R[1][0]) + + For ymm1 : + R(a[2][0]) = alpha_valr*R(a[2][0])-alpha_vali*I(a[2][0]) + I(a[2][0]) = alpha_valr*I(a[2][0])+alpha_vali*R[2][0] + R(a[3][0]) = alpha_valr*R(a[3][0])-alpha_vali*I(a[3][0]) + I(a[3][0]) = alpha_valr*I(a[3][0])+alpha_vali*(R[3][0]) + */ + + //Calculating using real part of complex number in B matrix + //ymm3+=R(b[0][0])*R(a[0][0]) R(b[0][0])*I(a[0][0]) + // R(b[0][0])*R(a[1][0]) R(b[0][0])*I(a[1][0]) + //ymm4+=R(b[0][0])*R(a[2][0]) R(b[0][0])*I(a[2][0]) + // R(b[0][0])*R(a[3][0]) R(b[0][0])*I(a[3][0]) + FMA_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm2,(double const *)(temp_b)); + //ymm5+=R(b[0][1])*R(a[0][0]) R(b[0][1])*I(a[0][0]) + // R(b[0][1])*R(a[1][0]) R(b[0][1])*I(a[1][0]) + //ymm6+=R(b[0][1])*R(a[0][0]) R(b[0][1])*I(a[0][0]) + // R(b[0][1])*R(a[1][0]) R(b[0][1])*I(a[1][0]) + FMA_M_LOOP(ymm0,ymm1,ymm5,ymm6,ymm2,(double const *)(temp_b+ldb)); + //ymm7+=R(b[0][2])*R(a[0][0]) R(b[0][2])*I(a[0][0]) + // R(b[0][2])*R(a[1][0]) R(b[0][2])*I(a[1][0]) + //ymm8+=R(b[0][2])*R(a[0][0]) R(b[0][2])*I(a[0][0]) + // R(b[0][2])*R(a[1][0]) R(b[0][2])*I(a[1][0]) + FMA_M_LOOP(ymm0,ymm1,ymm7,ymm8,ymm2,(double const *)(temp_b+ldb*2)); + //ymm9+=R(b[0][3])*R(a[0][0]) R(b[0][3])*I(a[0][0]) + // R(b[0][3])*R(a[1][0]) R(b[0][3])*I(a[1][0]) + //ymm10+=R(b[0][3])*R(a[0][0]) R(b[0][3])*I(a[0][0]) + // R(b[0][3])*R(a[1][0]) R(b[0][3])*I(a[1][0]) + FMA_M_LOOP(ymm0,ymm1,ymm9,ymm10,ymm2,(double const *)(temp_b+ldb*3)); + //ymm11+=R(b[0][4])*R(a[0][0]) R(b[0][4])*I(a[0][0]) + // R(b[0][4])*R(a[1][0]) R(b[0][4])*I(a[1][0]) + //ymm12+=R(b[0][4])*R(a[0][0]) R(b[0][4])*I(a[0][0]) + // R(b[0][4])*R(a[1][0]) R(b[0][4])*I(a[1][0]) + FMA_M_LOOP(ymm0,ymm1,ymm11,ymm12,ymm2,(double const *)(temp_b+ldb*4)); + + //Calculating using imaginary part of complex numbers in B matrix + //Shuffling ymm0 and ymm1 in accordance to the requirement + NEG_PERM_M_LOOP(ymm0,ymm1,ymm2); + //ymm3+=I(b[0][0])*R(a[0][0]) I(b[0][0])*I(a[0][0]) + // I(b[0][0])*R(a[1][0]) 
I(b[0][0])*I(a[1][0]) + //ymm4+=R(b[0][0])*R(a[2][0]) I(b[0][0])*I(a[2][0]) + // I(b[0][0])*R(a[3][0]) I(b[0][0])*I(a[3][0]) + FMA_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm2,(double const *)(temp_b)+1); + //ymm5+=I(b[0][1])*R(a[0][0]) I(b[0][1])*I(a[0][0]) + // I(b[0][1])*R(a[1][0]) I(b[0][1])*I(a[1][0]) + //ymm6+=R(b[0][1])*R(a[0][0]) I(b[0][1])*I(a[0][0]) + // I(b[0][1])*R(a[1][0]) I(b[0][1])*I(a[1][0]) + FMA_M_LOOP(ymm0,ymm1,ymm5,ymm6,ymm2,(double const *)(temp_b+ldb)+1); + //ymm7+=I(b[0][2])*R(a[0][0]) I(b[0][2])*I(a[0][0]) + // I(b[0][2])*R(a[1][0]) I(b[0][2])*I(a[1][0]) + //ymm8+=I(b[0][2])*R(a[0][0]) I(b[0][2])*I(a[0][0]) + // I(b[0][2])*R(a[1][0]) I(b[0][2])*I(a[1][0]) + FMA_M_LOOP(ymm0,ymm1,ymm7,ymm8,ymm2,(double const *)(temp_b+ldb*2)+1); + //ymm9+=I(b[0][3])*R(a[0][0]) I(b[0][3])*I(a[0][0]) + // I(b[0][3])*R(a[1][0]) I(b[0][3])*I(a[1][0]) + //ymm10+=I(b[0][3])*R(a[0][0]) I(b[0][3])*I(a[0][0]) + // I(b[0][3])*R(a[1][0]) I(b[0][3])*I(a[1][0]) + FMA_M_LOOP(ymm0,ymm1,ymm9,ymm10,ymm2,(double const *)(temp_b+ldb*3)+1); + //ymm11+=I(b[0][4])*R(a[0][0]) I(b[0][4])*I(a[0][0]) + // I(b[0][4])*R(a[1][0]) I(b[0][4])*I(a[1][0]) + //ymm12+=I(b[0][4])*R(a[0][0]) I(b[0][4])*I(a[0][0]) + // I(b[0][4])*R(a[1][0]) I(b[0][4])*I(a[1][0]) + FMA_M_LOOP(ymm0,ymm1,ymm11,ymm12,ymm2,(double const *)(temp_b+ldb*4)+1); + } + if(beta_valr != 0.0) + { + /* + a. Perform beta*C using temp_c, beta_valr, + where beta_valr is not zero. + b. This loop operates with 4x5 block size + along n dimension for every Z_NR columns of temp_c where + computing all Z_MR rows of temp_c. + c. Accumulated alpha*A*B into registers will be added to beta*C + d. Same approach is used in remaining fringe cases. + */ + ymm15 = _mm256_broadcast_sd((double const *)(&beta_valr)); + + //R(c[0][0]) I(c[0][0]) R(c[1][0]) I(c[1][0]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c)); + //R(c[2][0]) I(c[2][0]) R(c[3][0]) I(c[3][0]) + ymm1 = _mm256_loadu_pd((double const *)(temp_c + 2)); + //ymm3+=beta_valr*R(c[0][0]) beta_valr*I(c[0][0]) + // beta_valr*R(c[1][0]) beta_valr*I(c[1][0]) + //ymm4+=beta_valr*R(c[2][0]) beta_valr*I(c[2][0]) + // beta_valr*R(c[3][0]) beta_valr*I(c[3][0]) + SCALE_BETA_REAL_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm15); + + //R(c[0][1]) I(c[0][1]) R(c[1][1]) I(c[1][1]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc)); + //R(c[2][1]) I(c[2][1]) R(c[3][1]) I(c[3][1]) + ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc + 2)); + //ymm5+=beta_valr*R(c[0][1]) beta_valr*I(c[0][1]) + // beta_valr*R(c[1][1]) beta_valr*I(c[1][1]) + //ymm6+=beta_valr*R(c[2][1]) beta_valr*I(c[2][1]) + // beta_valr*R(c[3][1]) beta_valr*I(c[3][1]) + SCALE_BETA_REAL_M_LOOP(ymm0,ymm1,ymm5,ymm6,ymm15); + + //R(c[0][2]) I(c[0][2]) R(c[1][2]) I(c[1][2]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*2)); + //R(c[2][2]) I(c[2][2]) R(c[3][2]) I(c[3][2]) + ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc*2 + 2)); + //ymm7+=beta_valr*R(c[0][2]) beta_valr*I(c[0][2]) + // beta_valr*R(c[1][2]) beta_valr*I(c[1][2]) + //ymm8+=beta_valr*R(c[2][2]) beta_valr*I(c[2][2]) + //beta_valr*R(c[3][2]) beta_valr*I(c[3][2]) + SCALE_BETA_REAL_M_LOOP(ymm0,ymm1,ymm7,ymm8,ymm15); + + //R(c[0][3]) I(c[0][3]) R(c[1][3]) I(c[1][3]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*3)); + //R(c[2][3]) I(c[2][3]) R(c[3][3]) I(c[3][3]) + ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc*3 + 2)); + //ymm9+=beta_valr*R(c[0][3]) beta_valr*I(c[0][3]) + // beta_valr*R(c[1][3]) beta_valr*I(c[1][3]) + //ymm10+=beta_valr*R(c[2][3]) beta_valr*I(c[2][3]) + // beta_valr*R(c[3][3]) 
beta_valr*I(c[3][3]) + SCALE_BETA_REAL_M_LOOP(ymm0,ymm1,ymm9,ymm10,ymm15); + + //R(c[0][4]) I(c[0][4]) R(c[1][4]) I(c[1][4]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*4)); + //R(c[2][4]) I(c[2][4]) R(c[3][4]) I(c[3][4]) + ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc*4 + 2)); + //ymm11+=beta_valr*R(c[0][4]) beta_valr*I(c[0][4]) + // beta_valr*R(c[1][4]) beta_valr*I(c[1][4]) + //ymm12+=beta_valr*R(c[2][4]) beta_valr*I(c[2][4]) + // beta_valr*R(c[3][4]) beta_valr*I(c[3][4]) + SCALE_BETA_REAL_M_LOOP(ymm0,ymm1,ymm11,ymm12,ymm15); + + } + if(beta_vali != 0.0) + { + /* + a. Perform beta*C using temp_c, beta_vali, + where beta_vali is not zero. + b. This loop operates with 4x5 block size + along n dimension for every Z_NR columns of temp_c where + computing all Z_MR rows of temp_c. + c. Accumulated alpha*A*B into registers will be added to beta*C + d. Same approach is used in remaining fringe cases. + */ + + ymm15 = _mm256_broadcast_sd((double const *)(&beta_vali)); + + //R(c[0][0]) I(c[0][0]) R(c[1][0]) I(c[1][0]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c)); + //R(c[2][0]) I(c[2][0]) R(c[3][0]) I(c[3][0]) + ymm1 = _mm256_loadu_pd((double const *)(temp_c + 2)); + //ymm3+=beta_vali*(-I(c[0][0])) beta_vali*R(c[0][0]) + // beta_vali*(-I(c[1][0])) beta_vali*R(c[1][0]) + //ymm4+=beta_vali*(-I(c[2][0])) beta_vali*R(c[2][0]) + // beta_vali*(-I(c[3][0])) beta_vali*R(c[3][0]) + SCALE_BETA_IMAG_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm15,ymm2); + + //R(c[0][1]) I(c[0][1]) R(c[1][1]) I(c[1][1]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc)); + //R(c[2][1]) I(c[2][1]) R(c[3][1]) I(c[3][1]) + ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc + 2)); + //ymm5+=beta_vali*(-I(c[0][1])) beta_vali*R(c[0][1]) + // beta_vali*(-I(c[1][1])) beta_vali*R(c[1][1]) + //ymm6+=beta_vali*(-I(c[2][1])) beta_vali*R(c[2][1]) + // beta_vali*(-I(c[3][1])) beta_vali*R(c[3][1]) + SCALE_BETA_IMAG_M_LOOP(ymm0,ymm1,ymm5,ymm6,ymm15,ymm2); + + //R(c[0][2]) I(c[0][2]) R(c[1][2]) I(c[1][2]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*2)); + //R(c[2][2]) I(c[2][2]) R(c[3][2]) I(c[3][2]) + ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc*2 + 2)); + //ymm7+=beta_vali*(-I(c[0][2])) beta_vali*R(c[0][2]) + // beta_vali*(-I(c[1][2])) beta_vali*R(c[1][2]) + //ymm8+=beta_vali*(-I(c[2][2])) beta_vali*R(c[2][2]) + // beta_vali*(-I(c[3][2])) beta_vali*R(c[3][2]) + SCALE_BETA_IMAG_M_LOOP(ymm0,ymm1,ymm7,ymm8,ymm15,ymm2); + + //R(c[0][3]) I(c[0][3]) R(c[1][3]) I(c[1][3]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*3)); + //R(c[2][3]) I(c[2][3]) R(c[3][3]) I(c[3][3]) + ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc*3 + 2)); + //ymm9+=beta_vali*(-I(c[0][3])) beta_vali*R(c[0][3]) + // beta_vali*(-I(c[1][3])) beta_vali*R(c[1][3]) + //ymm10+=beta_vali*(-I(c[2][3])) beta_vali*R(c[2][3]) + // beta_vali*(-I(c[3][3])) beta_vali*R(c[3][3]) + SCALE_BETA_IMAG_M_LOOP(ymm0,ymm1,ymm9,ymm10,ymm15,ymm2); + + //R(c[0][4]) I(c[0][4]) R(c[1][4]) I(c[1][4]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*4)); + //R(c[2][4]) I(c[2][4]) R(c[3][4]) I(c[3][4]) + ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc*4 + 2)); + //ymm11+=beta_vali*(-I(c[0][4])) beta_vali*R(c[0][4]) + // beta_vali*(-I(c[1][4])) beta_vali*R(c[1][4]) + //ymm12+=beta_vali*(-I(c[2][4])) beta_vali*R(c[2][4]) + // beta_vali*(-I(c[3][4])) beta_vali*R(c[3][4]) + SCALE_BETA_IMAG_M_LOOP(ymm0,ymm1,ymm11,ymm12,ymm15,ymm2); + } + /* + The scaling has been done sequentially as follows: + - If alpha_valr is not 0, it is used for scaling A + - If 
alpha_vali is not 0, it is used for scaling A using permutation + and selective negation, after loading + - If beta_valr is not 0, is is used for scaling C + - If beta_vali is not 0, it is used for scaling C using permutation + and selective negation, after loading + + The results are accumalated in accordance to the non zero scalar values, + and similar approach is followed in fringe cases + */ + + _mm256_storeu_pd((double *)(temp_c), ymm3); + _mm256_storeu_pd((double *)(temp_c + 2), ymm4); + + _mm256_storeu_pd((double *)(temp_c + ldc), ymm5); + _mm256_storeu_pd((double *)(temp_c + ldc + 2), ymm6); + + _mm256_storeu_pd((double *)(temp_c + ldc*2), ymm7); + _mm256_storeu_pd((double *)(temp_c + ldc*2 + 2), ymm8); + + _mm256_storeu_pd((double *)(temp_c + ldc*3), ymm9); + _mm256_storeu_pd((double *)(temp_c + ldc*3 + 2), ymm10); + + _mm256_storeu_pd((double *)(temp_c + ldc*4), ymm11); + _mm256_storeu_pd((double *)(temp_c + ldc*4 + 2), ymm12); + + temp_c+=Z_MR; + temp_a+=Z_MR; + } + + dim_t m_rem=m_remainder; + if(m_rem>=2) + { + ymm3 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + + if(alpha_valr != 0.0 || alpha_vali != 0.0) + { + + //R(a[0][0]) I(a[0][0]) R(a[1][0]) I(a[1][0]) + ymm0 = _mm256_loadu_pd((double const *)(temp_a)); + ymm13 = ymm0; + SCALE_ALPHA_REAL_M_FRINGE(ymm0,ymm15,alpha_valr); + SCALE_ALPHA_IMAG_M_FRINGE(ymm0,ymm15,ymm2,ymm13,alpha_vali); + /* + The result after scaling with alpha_valr and/or alpha_vali is as follows: + For ymm0 : + R(a[0][0]) = alpha_valr*R(a[0][0])-alpha_vali*I(a[0][0]) + I(a[0][0]) = alpha_valr*I(a[0][0])+alpha_vali*R[0][0] + R(a[1][0]) = alpha_valr*R(a[1][0])-alpha_vali*I(a[1][0]) + I(a[1][0]) = alpha_valr*I(a[1][0])+alpha_vali*(R[1][0]) + */ + + //Calculating using real part of complex number in B matrix + //ymm3+=R(b[0][0])*R(a[0][0]) R(b[0][0])*I(a[0][0]) + // R(b[0][0])*R(a[1][0]) R(b[0][0])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)); + //ymm5+=R(b[0][1])*R(a[0][0]) R(b[0][1])*I(a[0][0]) + // R(b[0][1])*R(a[1][0]) R(b[0][1])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm5,ymm2,(double const *)(temp_b+ldb)); + //ymm7+=R(b[0][2])*R(a[0][0]) R(b[0][2])*I(a[0][0]) + // R(b[0][2])*R(a[1][0]) R(b[0][2])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm7,ymm2,(double const *)(temp_b+ldb*2)); + //ymm9+=R(b[0][3])*R(a[0][0]) R(b[0][3])*I(a[0][0]) + // R(b[0][3])*R(a[1][0]) R(b[0][3])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm9,ymm2,(double const *)(temp_b+ldb*3)); + //ymm11+=R(b[0][4])*R(a[0][0]) R(b[0][4])*I(a[0][0]) + // R(b[0][4])*R(a[1][0]) R(b[0][4])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm11,ymm2,(double const *)(temp_b+ldb*4)); + + //Calculating using imaginary part of complex numbers in B matrix + //Shuffling ymm0 in accordance to the requirement + NEG_PERM_M_FRINGE(ymm0,ymm2); + + // ymm3+=I(b[0][0])*R(a[0][0]) I(b[0][0])*I(a[0][0]) + // I(b[0][0])*R(a[1][0]) I(b[0][0])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)+1); + //ymm5+=I(b[0][1])*R(a[0][0]) I(b[0][1])*I(a[0][0]) + // I(b[0][1])*R(a[1][0]) I(b[0][1])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm5,ymm2,(double const *)(temp_b+ldb)+1); + //ymm7+=I(b[0][2])*R(a[0][0]) I(b[0][2])*I(a[0][0]) + // I(b[0][2])*R(a[1][0]) I(b[0][2])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm7,ymm2,(double const *)(temp_b+ldb*2)+1); + //ymm9+=I(b[0][3])*R(a[0][0]) I(b[0][3])*I(a[0][0]) + // I(b[0][3])*R(a[1][0]) I(b[0][3])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm9,ymm2,(double const *)(temp_b+ldb*3)+1); + //ymm11+=I(b[0][4])*R(a[0][0]) 
I(b[0][4])*I(a[0][0]) + // I(b[0][4])*R(a[1][0]) I(b[0][4])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm11,ymm2,(double const *)(temp_b+ldb*4)+1); + + } + + if(beta_valr != 0.0) + { + + ymm15 = _mm256_broadcast_sd((double const *)(&beta_valr)); + + //R(c[0][0]) I(c[0][0]) R(c[1][0]) I(c[1][0]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c)); + //ymm3+=beta_valr*R(c[0][0]) beta_valr*I(c[0][0]) + // beta_valr*R(c[1][0]) beta_valr*I(c[1][0]) + SCALE_BETA_REAL_M_FRINGE(ymm0,ymm3,ymm15); + + //R(c[0][1]) I(c[0][1]) R(c[1][1]) I(c[1][1]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc)); + //ymm5+=beta_valr*R(c[0][1]) beta_valr*I(c[0][1]) + // beta_valr*R(c[1][1]) beta_valr*I(c[1][1]) + SCALE_BETA_REAL_M_FRINGE(ymm0,ymm5,ymm15); + + //R(c[0][2]) I(c[0][2]) R(c[1][2]) I(c[1][2]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*2)); + //ymm7+=beta_valr*R(c[0][2]) beta_valr*I(c[0][2]) + // beta_valr*R(c[1][2]) beta_valr*I(c[1][2]) + SCALE_BETA_REAL_M_FRINGE(ymm0,ymm7,ymm15); + + //R(c[0][3]) I(c[0][3]) R(c[1][3]) I(c[1][3]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*3)); + //ymm9+=beta_valr*R(c[0][3]) beta_valr*I(c[0][3]) + // beta_valr*R(c[1][3]) beta_valr*I(c[1][3]) + SCALE_BETA_REAL_M_FRINGE(ymm0,ymm9,ymm15); + + //R(c[0][4]) I(c[0][4]) R(c[1][4]) I(c[1][4]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*4)); + //ymm11+=beta_valr*R(c[0][4]) beta_valr*I(c[0][4]) + // beta_valr*R(c[1][4]) beta_valr*I(c[1][4]) + SCALE_BETA_REAL_M_FRINGE(ymm0,ymm11,ymm15); + + } + + if(beta_vali != 0.0) + { + + ymm15 = _mm256_broadcast_sd((double const *)(&beta_vali)); + + //R(c[0][0]) I(c[0][0]) R(c[1][0]) I(c[1][0]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c)); + //ymm3+=beta_vali*(-I(c[0][0])) beta_vali*R(c[0][0]) + // beta_vali*(-I(c[1][0])) beta_vali*R(c[1][0]) + SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm3,ymm15,ymm2); + + //R(c[0][1]) I(c[0][1]) R(c[1][1]) I(c[1][1]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc)); + //ymm5+=beta_vali*(-I(c[0][1])) beta_vali*R(c[0][1]) + // beta_vali*(-I(c[1][1])) beta_vali*R(c[1][1]) + SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm5,ymm15,ymm2); + + //R(c[0][2]) I(c[0][2]) R(c[1][2]) I(c[1][2]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*2)); + //ymm7+=beta_vali*(-I(c[0][2])) beta_vali*R(c[0][2]) + // beta_vali*(-I(c[1][2])) beta_vali*R(c[1][2]) + SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm7,ymm15,ymm2); + + //R(c[0][3]) I(c[0][3]) R(c[1][3]) I(c[1][3]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*3)); + //ymm9+=beta_vali*(-I(c[0][3])) beta_vali*R(c[0][3]) + // beta_vali*(-I(c[1][3])) beta_vali*R(c[1][3]) + SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm9,ymm15,ymm2); + + //R(c[0][4]) I(c[0][4]) R(c[1][4]) I(c[1][4]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*4)); + //ymm11+=beta_vali*(-I(c[0][4])) beta_vali*R(c[0][4]) + // beta_vali*(-I(c[1][4])) beta_vali*R(c[1][4]) + SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm11,ymm15,ymm2); + } + + /* + The scaling has been done sequentially as follows: + - If alpha_valr is not 0, it is used for scaling A + - If alpha_vali is not 0, it is used for scaling A using permutation + and selective negation, after loading + - If beta_valr is not 0, is is used for scaling C + - If beta_vali is not 0, it is used for scaling C using permutation + and selective negation, after loading + + The results are accumalated in accordance to the non zero scalar values, + and similar approach is followed in fringe cases + */ + + _mm256_storeu_pd((double *)(temp_c), ymm3); + _mm256_storeu_pd((double *)(temp_c + ldc), ymm5); + 
_mm256_storeu_pd((double *)(temp_c + ldc*2), ymm7); + _mm256_storeu_pd((double *)(temp_c + ldc*3), ymm9); + _mm256_storeu_pd((double *)(temp_c + ldc*4), ymm11); + + temp_c+=2; + temp_a+=2; + + m_rem -= 2; + } + + if(m_rem==1) + { + + xmm5 = _mm_setzero_pd(); + ymm3 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm11 = _mm256_setzero_pd(); + + if(alpha_valr != 0.0 || alpha_vali != 0.0) + { + xmm5 = _mm_loadu_pd((double const*)(temp_a));//R(a[0][0]) I(a[0][0]) + ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(a[0][0]) I(a[0][0]) + ymm13 = ymm0; + + SCALE_ALPHA_REAL_M_FRINGE(ymm0,ymm15,alpha_valr); + SCALE_ALPHA_IMAG_M_FRINGE(ymm0,ymm15,ymm2,ymm13,alpha_vali); + + //Calculating using real part of complex number in B matrix + //ymm3+=R(b[0][0])*R(a[0][0]) R(b[0][0])*I(a[0][0]) + // R(b[0][0])*R(a[1][0]) R(b[0][0])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)); + //ymm5+=R(b[0][1])*R(a[0][0]) R(b[0][1])*I(a[0][0]) + // R(b[0][1])*R(a[1][0]) R(b[0][1])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm5,ymm2,(double const *)(temp_b+ldb)); + //ymm7+=R(b[0][2])*R(a[0][0]) R(b[0][2])*I(a[0][0]) + // R(b[0][2])*R(a[1][0]) R(b[0][2])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm7,ymm2,(double const *)(temp_b+ldb*2)); + //ymm9+=R(b[0][3])*R(a[0][0]) R(b[0][3])*I(a[0][0]) + // R(b[0][3])*R(a[1][0]) R(b[0][3])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm9,ymm2,(double const *)(temp_b+ldb*3)); + //ymm11+=R(b[0][4])*R(a[0][0]) R(b[0][4])*I(a[0][0]) + // R(b[0][4])*R(a[1][0]) R(b[0][4])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm11,ymm2,(double const *)(temp_b+ldb*4)); + + //Calculating using imaginary part of complex numbers in B matrix + //Shuffling ymm0 in accordance to the requirement + NEG_PERM_M_FRINGE(ymm0,ymm2); + + // ymm3+=I(b[0][0])*R(a[0][0]) I(b[0][0])*I(a[0][0]) + // I(b[0][0])*R(a[1][0]) I(b[0][0])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)+1); + //ymm5+=I(b[0][1])*R(a[0][0]) I(b[0][1])*I(a[0][0]) + // I(b[0][1])*R(a[1][0]) I(b[0][1])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm5,ymm2,(double const *)(temp_b+ldb)+1); + //ymm7+=I(b[0][2])*R(a[0][0]) I(b[0][2])*I(a[0][0]) + // I(b[0][2])*R(a[1][0]) I(b[0][2])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm7,ymm2,(double const *)(temp_b+ldb*2)+1); + //ymm9+=I(b[0][3])*R(a[0][0]) I(b[0][3])*I(a[0][0]) + // I(b[0][3])*R(a[1][0]) I(b[0][3])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm9,ymm2,(double const *)(temp_b+ldb*3)+1); + //ymm11+=I(b[0][4])*R(a[0][0]) I(b[0][4])*I(a[0][0]) + // I(b[0][4])*R(a[1][0]) I(b[0][4])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm11,ymm2,(double const *)(temp_b+ldb*4)+1); + + } + if(beta_valr != 0.0) + { + ymm15 = _mm256_broadcast_sd((double const *)(&beta_valr)); + + xmm5 = _mm_loadu_pd((double const*)(temp_c));//R(c[0][0]) I(c[0][0]) + ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][0]) I(c[0][0]) + //ymm3+=beta_valr*R(c[0][0]) beta_valr*I(c[0][0]) + SCALE_BETA_REAL_M_FRINGE(ymm0,ymm3,ymm15); + + xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc));//R(c[0][1]) I(c[0][1]) + ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][1]) I(c[0][1]) + //ymm5+=beta_valr*R(c[0][1]) beta_valr*I(c[0][1]) + SCALE_BETA_REAL_M_FRINGE(ymm0,ymm5,ymm15); + + xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc * 2));//R(c[0][2]) I(c[0][2]) + ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][2]) I(c[0][2]) + //ymm7+=beta_valr*R(c[0][2]) beta_valr*I(c[0][2]) + SCALE_BETA_REAL_M_FRINGE(ymm0,ymm7,ymm15); + + xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc * 3));//R(c[0][3]) I(c[0][3]) + ymm0 = 
_mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][3]) I(c[0][3]) + //ymm9+=beta_valr*R(c[0][3]) beta_valr*I(c[0][3]) + SCALE_BETA_REAL_M_FRINGE(ymm0,ymm9,ymm15); + + xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc * 4));//R(c[0][4]) I(c[0][4]) + ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][4]) I(c[0][4]) + //ymm11+=beta_valr*R(c[0][4]) beta_valr*I(c[0][4]) + SCALE_BETA_REAL_M_FRINGE(ymm0,ymm11,ymm15); + } + if(beta_vali != 0.0) + { + ymm15 = _mm256_broadcast_sd((double const *)(&beta_vali)); + + xmm5 = _mm_loadu_pd((double const*)(temp_c));//R(c[0][0]) I(c[0][0]) + ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][0]) I(c[0][0]) + //ymm3+=beta_vali*(-I(c[0][0])) beta_vali*R(c[0][0]) + SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm3,ymm15,ymm2); + + xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc));//R(c[0][1]) I(c[0][1]) + ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][1]) I(c[0][1]) + //ymm5+=beta_vali*(-I(c[0][1])) beta_vali*R(c[0][1]) + SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm5,ymm15,ymm2); + + xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc * 2));//R(c[0][2]) I(c[0][2]) + ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][2]) I(c[0][2]) + //ymm7+=beta_vali*(-I(c[0][2])) beta_vali*R(c[0][2]) + SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm7,ymm15,ymm2); + + xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc * 3));//R(c[0][3]) I(c[0][3]) + ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][3]) I(c[0][3]) + //ymm9+=beta_vali*(-I(c[0][3])) beta_vali*R(c[0][3]) + SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm9,ymm15,ymm2); + + xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc * 4));//R(c[0][4]) I(c[0][4]) + ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][4]) I(c[0][4]) + //ymm11+=beta_vali*(-I(c[0][4])) beta_vali*R(c[0][4]) + SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm11,ymm15,ymm2); + + } + + xmm5 = _mm256_extractf128_pd(ymm3, 0); + _mm_storeu_pd((double *)(temp_c), xmm5); + + xmm5 = _mm256_extractf128_pd(ymm5, 0); + _mm_storeu_pd((double *)(temp_c + ldc), xmm5); + + xmm5 = _mm256_extractf128_pd(ymm7, 0); + _mm_storeu_pd((double *)(temp_c + ldc*2), xmm5); + + xmm5 = _mm256_extractf128_pd(ymm9, 0); + _mm_storeu_pd((double *)(temp_c + ldc*3), xmm5); + + xmm5 = _mm256_extractf128_pd(ymm11, 0); + _mm_storeu_pd((double *)(temp_c + ldc*4), xmm5); + + } + + } + if(n_remainder==4) + { + dcomplex* temp_b = b + (n - n_remainder)*ldb; + dcomplex* temp_a = a; + dcomplex* temp_c = c + (n - n_remainder)*ldc; + for(dim_t i = 0;i < (m-Z_MR+1);i=i+Z_MR) + { + ymm3 = _mm256_setzero_pd(); + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + + if(alpha_valr != 0.0 || alpha_vali != 0.0) + { + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_valr, alpha_vali + where alpha_valr and/or alpha_vali is not zero. + b. This loop operates with 4x5 block size + along n dimension for every Z_NR columns of temp_b where + computing all Z_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. 
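+                d. Each element x + iy of A is scaled to
+                   (alpha_valr*x - alpha_vali*y) + i*(alpha_valr*y + alpha_vali*x),
+                   using a real-part broadcast plus a permuted, sign-flipped
+                   FMA for the imaginary part.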
+ */ + + //R(a[0][0]) I(a[0][0]) R(a[1][0]) I(a[1][0]) + ymm0 = _mm256_loadu_pd((double const *)(temp_a)); + //R(a[2][0]) I(a[2][0]) R(a[3][0]) I(a[3][0]) + ymm1 = _mm256_loadu_pd((double const *)(temp_a + 2)); + + ymm13 = ymm0; + ymm14 = ymm1; + + _mm_prefetch((char*)(temp_a) + 64, _MM_HINT_T0); + + SCALE_ALPHA_REAL_M_LOOP(ymm0,ymm1,ymm15,alpha_valr); + SCALE_ALPHA_IMAG_M_LOOP(ymm0,ymm1,ymm15,ymm2,ymm13,ymm14,alpha_vali); + + /* + The result after scaling with alpha_valr and/or alpha_vali is as follows: + For ymm0 : + R(a[0][0]) = alpha_valr*R(a[0][0])-alpha_vali*I(a[0][0]) + I(a[0][0]) = alpha_valr*I(a[0][0])+alpha_vali*R[0][0] + R(a[1][0]) = alpha_valr*R(a[1][0])-alpha_vali*I(a[1][0]) + I(a[1][0]) = alpha_valr*I(a[1][0])+alpha_vali*(R[1][0]) + + For ymm1 : + R(a[2][0]) = alpha_valr*R(a[2][0])-alpha_vali*I(a[2][0]) + I(a[2][0]) = alpha_valr*I(a[2][0])+alpha_vali*R[2][0] + R(a[3][0]) = alpha_valr*R(a[3][0])-alpha_vali*I(a[3][0]) + I(a[3][0]) = alpha_valr*I(a[3][0])+alpha_vali*(R[3][0]) + */ + + //Calculating using real part of complex number in B matrix + FMA_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm2,(double const *)(temp_b)); + FMA_M_LOOP(ymm0,ymm1,ymm5,ymm6,ymm2,(double const *)(temp_b+ldb)); + FMA_M_LOOP(ymm0,ymm1,ymm7,ymm8,ymm2,(double const *)(temp_b+ldb*2)); + FMA_M_LOOP(ymm0,ymm1,ymm9,ymm10,ymm2,(double const *)(temp_b+ldb*3)); + + //Calculating using imaginary part of complex numbers in B matrix + //Shuffling ymm0 and ymm1 in accordance to the requirement + NEG_PERM_M_LOOP(ymm0,ymm1,ymm2); + FMA_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm2,(double const *)(temp_b)+1); + FMA_M_LOOP(ymm0,ymm1,ymm5,ymm6,ymm2,(double const *)(temp_b+ldb)+1); + FMA_M_LOOP(ymm0,ymm1,ymm7,ymm8,ymm2,(double const *)(temp_b+ldb*2)+1); + FMA_M_LOOP(ymm0,ymm1,ymm9,ymm10,ymm2,(double const *)(temp_b+ldb*3)+1); + } + if(beta_valr != 0.0) + { + /* + a. Perform beta*C using temp_c, beta_valr, + where beta_valr is not zero. + b. This loop operates with 4x5 block size + along n dimension for every Z_NR columns of temp_c where + computing all Z_MR rows of temp_c. + c. Accumulated alpha*A*B into registers will be added to beta*C + d. Same approach is used in remaining fringe cases. 
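+                e. No permutation is needed for beta_valr: the real scalar
+                   multiplies the real and imaginary parts of C directly.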
+ */ + ymm15 = _mm256_broadcast_sd((double const *)(&beta_valr)); + + //R(c[0][0]) I(c[0][0]) R(c[1][0]) I(c[1][0]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c)); + //R(c[2][0]) I(c[2][0]) R(c[3][0]) I(c[3][0]) + ymm1 = _mm256_loadu_pd((double const *)(temp_c + 2)); + //ymm3+=beta_valr*R(c[0][0]) beta_valr*I(c[0][0]) + // beta_valr*R(c[1][0]) beta_valr*I(c[1][0]) + //ymm4+=beta_valr*R(c[2][0]) beta_valr*I(c[2][0]) + // beta_valr*R(c[3][0]) beta_valr*I(c[3][0]) + SCALE_BETA_REAL_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm15); + + //R(c[0][1]) I(c[0][1]) R(c[1][1]) I(c[1][1]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc)); + //R(c[2][1]) I(c[2][1]) R(c[3][1]) I(c[3][1]) + ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc + 2)); + //ymm5+=beta_valr*R(c[0][1]) beta_valr*I(c[0][1]) + // beta_valr*R(c[1][1]) beta_valr*I(c[1][1]) + //ymm6+=beta_valr*R(c[2][1]) beta_valr*I(c[2][1]) + // beta_valr*R(c[3][1]) beta_valr*I(c[3][1]) + SCALE_BETA_REAL_M_LOOP(ymm0,ymm1,ymm5,ymm6,ymm15); + + //R(c[0][2]) I(c[0][2]) R(c[1][2]) I(c[1][2]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*2)); + //R(c[2][2]) I(c[2][2]) R(c[3][2]) I(c[3][2]) + ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc*2 + 2)); + //ymm7+=beta_valr*R(c[0][2]) beta_valr*I(c[0][2]) + // beta_valr*R(c[1][2]) beta_valr*I(c[1][2]) + //ymm8+=beta_valr*R(c[2][2]) beta_valr*I(c[2][2]) + // beta_valr*R(c[3][2]) beta_valr*I(c[3][2]) + SCALE_BETA_REAL_M_LOOP(ymm0,ymm1,ymm7,ymm8,ymm15); + + //R(c[0][3]) I(c[0][3]) R(c[1][3]) I(c[1][3]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*3)); + //R(c[2][3]) I(c[2][3]) R(c[3][3]) I(c[3][3]) + ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc*3 + 2)); + //ymm9+=beta_valr*R(c[0][3]) beta_valr*I(c[0][3]) + // beta_valr*R(c[1][3]) beta_valr*I(c[1][3]) + //ymm10+=beta_valr*R(c[2][3]) beta_valr*I(c[2][3]) + // beta_valr*R(c[3][3]) beta_valr*I(c[3][3]) + SCALE_BETA_REAL_M_LOOP(ymm0,ymm1,ymm9,ymm10,ymm15); + + } + if(beta_vali != 0.0) + { + /* + a. Perform beta*C using temp_c, beta_vali, + where beta_vali is not zero. + b. This loop operates with 4x5 block size + along n dimension for every Z_NR columns of temp_c where + computing all Z_MR rows of temp_c. + c. Accumulated alpha*A*B into registers will be added to beta*C + d. Same approach is used in remaining fringe cases. 
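+                e. Since beta_vali scales i*(x + iy) = -y + ix, each element
+                   of C is permuted (real/imaginary swapped) and sign-flipped
+                   before the FMA with beta_vali.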
+ */ + + ymm15 = _mm256_broadcast_sd((double const *)(&beta_vali)); + + //R(c[0][0]) I(c[0][0]) R(c[1][0]) I(c[1][0]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c)); + //R(c[2][0]) I(c[2][0]) R(c[3][0]) I(c[3][0]) + ymm1 = _mm256_loadu_pd((double const *)(temp_c + 2)); + //ymm3+=beta_vali*(-I(c[0][0])) beta_vali*R(c[0][0]) + // beta_vali*(-I(c[1][0])) beta_vali*R(c[1][0]) + //ymm4+=beta_vali*(-I(c[2][0])) beta_vali*R(c[2][0]) + // beta_vali*(-I(c[3][0])) beta_vali*R(c[3][0]) + SCALE_BETA_IMAG_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm15,ymm2); + + //R(c[0][1]) I(c[0][1]) R(c[1][1]) I(c[1][1]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc)); + //R(c[2][1]) I(c[2][1]) R(c[3][1]) I(c[3][1]) + ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc + 2)); + //ymm5+=beta_vali*(-I(c[0][1])) beta_vali*R(c[0][1]) + // beta_vali*(-I(c[1][1])) beta_vali*R(c[1][1]) + //ymm6+=beta_vali*(-I(c[2][1])) beta_vali*R(c[2][1]) + // beta_vali*(-I(c[3][1])) beta_vali*R(c[3][1]) + SCALE_BETA_IMAG_M_LOOP(ymm0,ymm1,ymm5,ymm6,ymm15,ymm2); + + //R(c[0][2]) I(c[0][2]) R(c[1][2]) I(c[1][2]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*2)); + //R(c[2][2]) I(c[2][2]) R(c[3][2]) I(c[3][2]) + ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc*2 + 2)); + //ymm7+=beta_vali*(-I(c[0][2])) beta_vali*R(c[0][2]) + // beta_vali*(-I(c[1][2])) beta_vali*R(c[1][2]) + //ymm8+=beta_vali*(-I(c[2][2])) beta_vali*R(c[2][2]) + // beta_vali*(-I(c[3][2])) beta_vali*R(c[3][2]) + SCALE_BETA_IMAG_M_LOOP(ymm0,ymm1,ymm7,ymm8,ymm15,ymm2); + + //R(c[0][3]) I(c[0][3]) R(c[1][3]) I(c[1][3]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*3)); + //R(c[2][3]) I(c[2][3]) R(c[3][3]) I(c[3][3]) + ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc*3 + 2)); + //ymm9+=beta_vali*(-I(c[0][3])) beta_vali*R(c[0][3]) + // beta_vali*(-I(c[1][3])) beta_vali*R(c[1][3]) + //ymm10+=beta_vali*(-I(c[2][3])) beta_vali*R(c[2][3]) + // beta_vali*(-I(c[3][3])) beta_vali*R(c[3][3]) + SCALE_BETA_IMAG_M_LOOP(ymm0,ymm1,ymm9,ymm10,ymm15,ymm2); + } + /* + The scaling has been done sequentially as follows: + - If alpha_valr is not 0, it is used for scaling A + - If alpha_vali is not 0, it is used for scaling A using permutation + and selective negation, after loading + - If beta_valr is not 0, is is used for scaling C + - If beta_vali is not 0, it is used for scaling C using permutation + and selective negation, after loading + + The results are accumalated in accordance to the non zero scalar values, + and similar approach is followed in fringe cases + */ + + _mm256_storeu_pd((double *)(temp_c), ymm3); + _mm256_storeu_pd((double *)(temp_c + 2), ymm4); + + _mm256_storeu_pd((double *)(temp_c + ldc), ymm5); + _mm256_storeu_pd((double *)(temp_c + ldc + 2), ymm6); + + _mm256_storeu_pd((double *)(temp_c + ldc*2), ymm7); + _mm256_storeu_pd((double *)(temp_c + ldc*2 + 2), ymm8); + + _mm256_storeu_pd((double *)(temp_c + ldc*3), ymm9); + _mm256_storeu_pd((double *)(temp_c + ldc*3 + 2), ymm10); + + temp_c+=Z_MR; + temp_a+=Z_MR; + } + + dim_t m_rem=m_remainder; + if(m_rem>=2) + { + ymm3 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + + if(alpha_valr != 0.0 || alpha_vali != 0.0) + { + + //R(a[0][0]) I(a[0][0]) R(a[1][0]) I(a[1][0]) + ymm0 = _mm256_loadu_pd((double const *)(temp_a)); + ymm13 = ymm0; + SCALE_ALPHA_REAL_M_FRINGE(ymm0,ymm15,alpha_valr); + SCALE_ALPHA_IMAG_M_FRINGE(ymm0,ymm15,ymm2,ymm13,alpha_vali); + /* + The result after scaling with alpha_valr and/or alpha_vali is as follows: + For ymm0 : + 
R(a[0][0]) = alpha_valr*R(a[0][0])-alpha_vali*I(a[0][0]) + I(a[0][0]) = alpha_valr*I(a[0][0])+alpha_vali*R[0][0] + R(a[1][0]) = alpha_valr*R(a[1][0])-alpha_vali*I(a[1][0]) + I(a[1][0]) = alpha_valr*I(a[1][0])+alpha_vali*(R[1][0]) + */ + + //Calculating using real part of complex number in B matrix + //ymm3+=R(b[0][0])*R(a[0][0]) R(b[0][0])*I(a[0][0]) + // R(b[0][0])*R(a[1][0]) R(b[0][0])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)); + //ymm5+=R(b[0][1])*R(a[0][0]) R(b[0][1])*I(a[0][0]) + // R(b[0][1])*R(a[1][0]) R(b[0][1])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm5,ymm2,(double const *)(temp_b+ldb)); + //ymm7+=R(b[0][2])*R(a[0][0]) R(b[0][2])*I(a[0][0]) + // R(b[0][2])*R(a[1][0]) R(b[0][2])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm7,ymm2,(double const *)(temp_b+ldb*2)); + //ymm9+=R(b[0][3])*R(a[0][0]) R(b[0][3])*I(a[0][0]) + // R(b[0][3])*R(a[1][0]) R(b[0][3])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm9,ymm2,(double const *)(temp_b+ldb*3)); + + //Calculating using imaginary part of complex numbers in B matrix + //Shuffling ymm0 in accordance to the requirement + NEG_PERM_M_FRINGE(ymm0,ymm2); + + // ymm3+=I(b[0][0])*R(a[0][0]) I(b[0][0])*I(a[0][0]) + // I(b[0][0])*R(a[1][0]) I(b[0][0])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)+1); + //ymm5+=I(b[0][1])*R(a[0][0]) I(b[0][1])*I(a[0][0]) + // I(b[0][1])*R(a[1][0]) I(b[0][1])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm5,ymm2,(double const *)(temp_b+ldb)+1); + //ymm7+=I(b[0][2])*R(a[0][0]) I(b[0][2])*I(a[0][0]) + // I(b[0][2])*R(a[1][0]) I(b[0][2])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm7,ymm2,(double const *)(temp_b+ldb*2)+1); + //ymm9+=I(b[0][3])*R(a[0][0]) I(b[0][3])*I(a[0][0]) + // I(b[0][3])*R(a[1][0]) I(b[0][3])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm9,ymm2,(double const *)(temp_b+ldb*3)+1); + + } + + if(beta_valr != 0.0) + { + + ymm15 = _mm256_broadcast_sd((double const *)(&beta_valr)); + + //R(c[0][0]) I(c[0][0]) R(c[1][0]) I(c[1][0]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c)); + //ymm3+=beta_valr*R(c[0][0]) beta_valr*I(c[0][0]) + // beta_valr*R(c[1][0]) beta_valr*I(c[1][0]) + SCALE_BETA_REAL_M_FRINGE(ymm0,ymm3,ymm15); + + //R(c[0][1]) I(c[0][1]) R(c[1][1]) I(c[1][1]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc)); + //ymm5+=beta_valr*R(c[0][1]) beta_valr*I(c[0][1]) + // beta_valr*R(c[1][1]) beta_valr*I(c[1][1]) + SCALE_BETA_REAL_M_FRINGE(ymm0,ymm5,ymm15); + + //R(c[0][2]) I(c[0][2]) R(c[1][2]) I(c[1][2]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*2)); + //ymm7+=beta_valr*R(c[0][2]) beta_valr*I(c[0][2]) + // beta_valr*R(c[1][2]) beta_valr*I(c[1][2]) + SCALE_BETA_REAL_M_FRINGE(ymm0,ymm7,ymm15); + + //R(c[0][3]) I(c[0][3]) R(c[1][3]) I(c[1][3]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*3)); + //ymm9+=beta_valr*R(c[0][3]) beta_valr*I(c[0][3]) + // beta_valr*R(c[1][3]) beta_valr*I(c[1][3]) + SCALE_BETA_REAL_M_FRINGE(ymm0,ymm9,ymm15); + + } + + if(beta_vali != 0.0) + { + + ymm15 = _mm256_broadcast_sd((double const *)(&beta_vali)); + + //R(c[0][0]) I(c[0][0]) R(c[1][0]) I(c[1][0]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c)); + //ymm3+=beta_vali*(-I(c[0][0])) beta_vali*R(c[0][0]) + // beta_vali*(-I(c[1][0])) beta_vali*R(c[1][0]) + SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm3,ymm15,ymm2); + + //R(c[0][1]) I(c[0][1]) R(c[1][1]) I(c[1][1]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc)); + //ymm5+=beta_vali*(-I(c[0][1])) beta_vali*R(c[0][1]) + // beta_vali*(-I(c[1][1])) beta_vali*R(c[1][1]) + SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm5,ymm15,ymm2); + + //R(c[0][2]) I(c[0][2]) R(c[1][2]) I(c[1][2]) + ymm0 
= _mm256_loadu_pd((double const *)(temp_c + ldc*2)); + //ymm7+=beta_vali*(-I(c[0][2])) beta_vali*R(c[0][2]) + // beta_vali*(-I(c[1][2])) beta_vali*R(c[1][2]) + SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm7,ymm15,ymm2); + + //R(c[0][3]) I(c[0][3]) R(c[1][3]) I(c[1][3]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*3)); + //ymm9+=beta_vali*(-I(c[0][3])) beta_vali*R(c[0][3]) + // beta_vali*(-I(c[1][3])) beta_vali*R(c[1][3]) + SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm9,ymm15,ymm2); + } + + /* + The scaling has been done sequentially as follows: + - If alpha_valr is not 0, it is used for scaling A + - If alpha_vali is not 0, it is used for scaling A using permutation + and selective negation, after loading + - If beta_valr is not 0, is is used for scaling C + - If beta_vali is not 0, it is used for scaling C using permutation + and selective negation, after loading + + The results are accumalated in accordance to the non zero scalar values, + and similar approach is followed in fringe cases + */ + + _mm256_storeu_pd((double *)(temp_c), ymm3); + _mm256_storeu_pd((double *)(temp_c + ldc), ymm5); + _mm256_storeu_pd((double *)(temp_c + ldc*2), ymm7); + _mm256_storeu_pd((double *)(temp_c + ldc*3), ymm9); + + temp_c+=2; + temp_a+=2; + + m_rem -= 2; + } + + if(m_rem==1) + { + + xmm5 = _mm_setzero_pd(); + ymm3 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + + if(alpha_valr != 0.0 || alpha_vali != 0.0) + { + xmm5 = _mm_loadu_pd((double const*)(temp_a));//R(a[0][0]) I(a[0][0]) + ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(a[0][0]) I(a[0][0]) + ymm13 = ymm0; + + SCALE_ALPHA_REAL_M_FRINGE(ymm0,ymm15,alpha_valr); + SCALE_ALPHA_IMAG_M_FRINGE(ymm0,ymm15,ymm2,ymm13,alpha_vali); + + //Calculating using real part of complex number in B matrix + //ymm3+=R(b[0][0])*R(a[0][0]) R(b[0][0])*I(a[0][0]) + // R(b[0][0])*R(a[1][0]) R(b[0][0])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)); + //ymm5+=R(b[0][1])*R(a[0][0]) R(b[0][1])*I(a[0][0]) + // R(b[0][1])*R(a[1][0]) R(b[0][1])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm5,ymm2,(double const *)(temp_b+ldb)); + //ymm7+=R(b[0][2])*R(a[0][0]) R(b[0][2])*I(a[0][0]) + // R(b[0][2])*R(a[1][0]) R(b[0][2])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm7,ymm2,(double const *)(temp_b+ldb*2)); + //ymm9+=R(b[0][3])*R(a[0][0]) R(b[0][3])*I(a[0][0]) + // R(b[0][3])*R(a[1][0]) R(b[0][3])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm9,ymm2,(double const *)(temp_b+ldb*3)); + + //Calculating using imaginary part of complex numbers in B matrix + //Shuffling ymm0 in accordance to the requirement + NEG_PERM_M_FRINGE(ymm0,ymm2); + + // ymm3+=I(b[0][0])*R(a[0][0]) I(b[0][0])*I(a[0][0]) + // I(b[0][0])*R(a[1][0]) I(b[0][0])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)+1); + //ymm5+=I(b[0][1])*R(a[0][0]) I(b[0][1])*I(a[0][0]) + // I(b[0][1])*R(a[1][0]) I(b[0][1])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm5,ymm2,(double const *)(temp_b+ldb)+1); + //ymm7+=I(b[0][2])*R(a[0][0]) I(b[0][2])*I(a[0][0]) + // I(b[0][2])*R(a[1][0]) I(b[0][2])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm7,ymm2,(double const *)(temp_b+ldb*2)+1); + //ymm9+=I(b[0][3])*R(a[0][0]) I(b[0][3])*I(a[0][0]) + // I(b[0][3])*R(a[1][0]) I(b[0][3])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm9,ymm2,(double const *)(temp_b+ldb*3)+1); + + } + if(beta_valr != 0.0) + { + ymm15 = _mm256_broadcast_sd((double const *)(&beta_valr)); + + xmm5 = _mm_loadu_pd((double const*)(temp_c));//R(c[0][0]) I(c[0][0]) + ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][0]) I(c[0][0]) + //ymm3+=beta_valr*R(c[0][0]) 
beta_valr*I(c[0][0]) + SCALE_BETA_REAL_M_FRINGE(ymm0,ymm3,ymm15); + + xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc));//R(c[0][1]) I(c[0][1]) + ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][1]) I(c[0][1]) + //ymm5+=beta_valr*R(c[0][1]) beta_valr*I(c[0][1]) + SCALE_BETA_REAL_M_FRINGE(ymm0,ymm5,ymm15); + + xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc * 2));//R(c[0][2]) I(c[0][2]) + ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][2]) I(c[0][2]) + //ymm7+=beta_valr*R(c[0][2]) beta_valr*I(c[0][2]) + SCALE_BETA_REAL_M_FRINGE(ymm0,ymm7,ymm15); + + xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc * 3));//R(c[0][3]) I(c[0][3]) + ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][3]) I(c[0][3]) + //ymm9+=beta_valr*R(c[0][3]) beta_valr*I(c[0][3]) + SCALE_BETA_REAL_M_FRINGE(ymm0,ymm9,ymm15); + } + if(beta_vali != 0.0) + { + ymm15 = _mm256_broadcast_sd((double const *)(&beta_vali)); + + xmm5 = _mm_loadu_pd((double const*)(temp_c));//R(c[0][0]) I(c[0][0]) + ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][0]) I(c[0][0]) + //ymm3+=beta_vali*(-I(c[0][0])) beta_vali*R(c[0][0]) + SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm3,ymm15,ymm2); + + xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc));//R(c[0][1]) I(c[0][1]) + ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][1]) I(c[0][1]) + //ymm5+=beta_vali*(-I(c[0][1])) beta_vali*R(c[0][1]) + SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm5,ymm15,ymm2); + + xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc * 2));//R(c[0][2]) I(c[0][2]) + ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][2]) I(c[0][2]) + //ymm7+=beta_vali*(-I(c[0][2])) beta_vali*R(c[0][2]) + SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm7,ymm15,ymm2); + + xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc * 3));//R(c[0][3]) I(c[0][3]) + ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][3]) I(c[0][3]) + //ymm9+=beta_vali*(-I(c[0][3])) beta_vali*R(c[0][3]) + SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm9,ymm15,ymm2); + + } + + xmm5 = _mm256_extractf128_pd(ymm3, 0); + _mm_storeu_pd((double *)(temp_c), xmm5); + + xmm5 = _mm256_extractf128_pd(ymm5, 0); + _mm_storeu_pd((double *)(temp_c + ldc), xmm5); + + xmm5 = _mm256_extractf128_pd(ymm7, 0); + _mm_storeu_pd((double *)(temp_c + ldc*2), xmm5); + + xmm5 = _mm256_extractf128_pd(ymm9, 0); + _mm_storeu_pd((double *)(temp_c + ldc*3), xmm5); + + } + n_remainder -= 4; + + } + if(n_remainder>=2) + { + dcomplex* temp_b = b + (n - n_remainder)*ldb; + dcomplex* temp_a = a; + dcomplex* temp_c = c + (n - n_remainder)*ldc; + for(dim_t i = 0;i < (m-Z_MR+1);i=i+Z_MR) + { + ymm3 = _mm256_setzero_pd(); + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + + if(alpha_valr != 0.0 || alpha_vali != 0.0) + { + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_valr, alpha_vali + where alpha_valr and/or alpha_vali is not zero. + b. This loop operates with 4x5 block size + along n dimension for every Z_NR columns of temp_b where + computing all Z_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. 
+ */ + + //R(a[0][0]) I(a[0][0]) R(a[1][0]) I(a[1][0]) + ymm0 = _mm256_loadu_pd((double const *)(temp_a)); + //R(a[2][0]) I(a[2][0]) R(a[3][0]) I(a[3][0]) + ymm1 = _mm256_loadu_pd((double const *)(temp_a + 2)); + + ymm13 = ymm0; + ymm14 = ymm1; + + _mm_prefetch((char*)(temp_a) + 64, _MM_HINT_T0); + + SCALE_ALPHA_REAL_M_LOOP(ymm0,ymm1,ymm15,alpha_valr); + SCALE_ALPHA_IMAG_M_LOOP(ymm0,ymm1,ymm15,ymm2,ymm13,ymm14,alpha_vali); + + /* + The result after scaling with alpha_valr and/or alpha_vali is as follows: + For ymm0 : + R(a[0][0]) = alpha_valr*R(a[0][0])-alpha_vali*I(a[0][0]) + I(a[0][0]) = alpha_valr*I(a[0][0])+alpha_vali*R[0][0] + R(a[1][0]) = alpha_valr*R(a[1][0])-alpha_vali*I(a[1][0]) + I(a[1][0]) = alpha_valr*I(a[1][0])+alpha_vali*(R[1][0]) + + For ymm1 : + R(a[2][0]) = alpha_valr*R(a[2][0])-alpha_vali*I(a[2][0]) + I(a[2][0]) = alpha_valr*I(a[2][0])+alpha_vali*R[2][0] + R(a[3][0]) = alpha_valr*R(a[3][0])-alpha_vali*I(a[3][0]) + I(a[3][0]) = alpha_valr*I(a[3][0])+alpha_vali*(R[3][0]) + */ + + //Calculating using real part of complex number in B matrix + FMA_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm2,(double const *)(temp_b)); + FMA_M_LOOP(ymm0,ymm1,ymm5,ymm6,ymm2,(double const *)(temp_b+ldb)); + + //Calculating using imaginary part of complex numbers in B matrix + //Shuffling ymm0 and ymm1 in accordance to the requirement + NEG_PERM_M_LOOP(ymm0,ymm1,ymm2); + FMA_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm2,(double const *)(temp_b)+1); + FMA_M_LOOP(ymm0,ymm1,ymm5,ymm6,ymm2,(double const *)(temp_b+ldb)+1); + } + if(beta_valr != 0.0) + { + /* + a. Perform beta*C using temp_c, beta_valr, + where beta_valr is not zero. + b. This loop operates with 4x5 block size + along n dimension for every Z_NR columns of temp_c where + computing all Z_MR rows of temp_c. + c. Accumulated alpha*A*B into registers will be added to beta*C + d. Same approach is used in remaining fringe cases. + */ + ymm15 = _mm256_broadcast_sd((double const *)(&beta_valr)); + + //R(c[0][0]) I(c[0][0]) R(c[1][0]) I(c[1][0]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c)); + //R(c[2][0]) I(c[2][0]) R(c[3][0]) I(c[3][0]) + ymm1 = _mm256_loadu_pd((double const *)(temp_c + 2)); + //ymm3+=beta_valr*R(c[0][0]) beta_valr*I(c[0][0]) + // beta_valr*R(c[1][0]) beta_valr*I(c[1][0]) + //ymm4+=beta_valr*R(c[2][0]) beta_valr*I(c[2][0]) + // beta_valr*R(c[3][0]) beta_valr*I(c[3][0]) + SCALE_BETA_REAL_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm15); + + //R(c[0][1]) I(c[0][1]) R(c[1][1]) I(c[1][1]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc)); + //R(c[2][1]) I(c[2][1]) R(c[3][1]) I(c[3][1]) + ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc + 2)); + //ymm5+=beta_valr*R(c[0][1]) beta_valr*I(c[0][1]) + // beta_valr*R(c[1][1]) beta_valr*I(c[1][1]) + //ymm6+=beta_valr*R(c[2][1]) beta_valr*I(c[2][1]) + // beta_valr*R(c[3][1]) beta_valr*I(c[3][1]) + SCALE_BETA_REAL_M_LOOP(ymm0,ymm1,ymm5,ymm6,ymm15); + + } + if(beta_vali != 0.0) + { + /* + a. Perform beta*C using temp_c, beta_vali, + where beta_vali is not zero. + b. This loop operates with 4x5 block size + along n dimension for every Z_NR columns of temp_c where + computing all Z_MR rows of temp_c. + c. Accumulated alpha*A*B into registers will be added to beta*C + d. Same approach is used in remaining fringe cases. 
+ */ + + ymm15 = _mm256_broadcast_sd((double const *)(&beta_vali)); + + //R(c[0][0]) I(c[0][0]) R(c[1][0]) I(c[1][0]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c)); + //R(c[2][0]) I(c[2][0]) R(c[3][0]) I(c[3][0]) + ymm1 = _mm256_loadu_pd((double const *)(temp_c + 2)); + //ymm3+=beta_vali*(-I(c[0][0])) beta_vali*R(c[0][0]) + // beta_vali*(-I(c[1][0])) beta_vali*R(c[1][0]) + //ymm4+=beta_vali*(-I(c[2][0])) beta_vali*R(c[2][0]) + // beta_vali*(-I(c[3][0])) beta_vali*R(c[3][0]) + SCALE_BETA_IMAG_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm15,ymm2); + + //R(c[0][1]) I(c[0][1]) R(c[1][1]) I(c[1][1]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc)); + //R(c[2][1]) I(c[2][1]) R(c[3][1]) I(c[3][1]) + ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc + 2)); + //ymm5+=beta_vali*(-I(c[0][1])) beta_vali*R(c[0][1]) + // beta_vali*(-I(c[1][1])) beta_vali*R(c[1][1]) + //ymm6+=beta_vali*(-I(c[2][1])) beta_vali*R(c[2][1]) + // beta_vali*(-I(c[3][1])) beta_vali*R(c[3][1]) + SCALE_BETA_IMAG_M_LOOP(ymm0,ymm1,ymm5,ymm6,ymm15,ymm2); + } + /* + The scaling has been done sequentially as follows: + - If alpha_valr is not 0, it is used for scaling A + - If alpha_vali is not 0, it is used for scaling A using permutation + and selective negation, after loading + - If beta_valr is not 0, is is used for scaling C + - If beta_vali is not 0, it is used for scaling C using permutation + and selective negation, after loading + + The results are accumalated in accordance to the non zero scalar values, + and similar approach is followed in fringe cases + */ + + _mm256_storeu_pd((double *)(temp_c), ymm3); + _mm256_storeu_pd((double *)(temp_c + 2), ymm4); + + _mm256_storeu_pd((double *)(temp_c + ldc), ymm5); + _mm256_storeu_pd((double *)(temp_c + ldc + 2), ymm6); + + temp_c+=Z_MR; + temp_a+=Z_MR; + } + + dim_t m_rem=m_remainder; + if(m_rem>=2) + { + ymm3 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + + if(alpha_valr != 0.0 || alpha_vali != 0.0) + { + + //R(a[0][0]) I(a[0][0]) R(a[1][0]) I(a[1][0]) + ymm0 = _mm256_loadu_pd((double const *)(temp_a)); + ymm13 = ymm0; + SCALE_ALPHA_REAL_M_FRINGE(ymm0,ymm15,alpha_valr); + SCALE_ALPHA_IMAG_M_FRINGE(ymm0,ymm15,ymm2,ymm13,alpha_vali); + /* + The result after scaling with alpha_valr and/or alpha_vali is as follows: + For ymm0 : + R(a[0][0]) = alpha_valr*R(a[0][0])-alpha_vali*I(a[0][0]) + I(a[0][0]) = alpha_valr*I(a[0][0])+alpha_vali*R[0][0] + R(a[1][0]) = alpha_valr*R(a[1][0])-alpha_vali*I(a[1][0]) + I(a[1][0]) = alpha_valr*I(a[1][0])+alpha_vali*(R[1][0]) + */ + + //Calculating using real part of complex number in B matrix + //ymm3+=R(b[0][0])*R(a[0][0]) R(b[0][0])*I(a[0][0]) + // R(b[0][0])*R(a[1][0]) R(b[0][0])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)); + //ymm5+=R(b[0][1])*R(a[0][0]) R(b[0][1])*I(a[0][0]) + // R(b[0][1])*R(a[1][0]) R(b[0][1])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm5,ymm2,(double const *)(temp_b+ldb)); + + //Calculating using imaginary part of complex numbers in B matrix + //Shuffling ymm0 in accordance to the requirement + NEG_PERM_M_FRINGE(ymm0,ymm2); + + // ymm3+=I(b[0][0])*R(a[0][0]) I(b[0][0])*I(a[0][0]) + // I(b[0][0])*R(a[1][0]) I(b[0][0])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)+1); + //ymm5+=I(b[0][1])*R(a[0][0]) I(b[0][1])*I(a[0][0]) + // I(b[0][1])*R(a[1][0]) I(b[0][1])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm5,ymm2,(double const *)(temp_b+ldb)+1); + + } + + if(beta_valr != 0.0) + { + + ymm15 = _mm256_broadcast_sd((double const *)(&beta_valr)); + + //R(c[0][0]) I(c[0][0]) R(c[1][0]) I(c[1][0]) + ymm0 = 
_mm256_loadu_pd((double const *)(temp_c)); + //ymm3+=beta_valr*R(c[0][0]) beta_valr*I(c[0][0]) + // beta_valr*R(c[1][0]) beta_valr*I(c[1][0]) + SCALE_BETA_REAL_M_FRINGE(ymm0,ymm3,ymm15); + + //R(c[0][1]) I(c[0][1]) R(c[1][1]) I(c[1][1]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc)); + //ymm5+=beta_valr*R(c[0][1]) beta_valr*I(c[0][1]) + // beta_valr*R(c[1][1]) beta_valr*I(c[1][1]) + SCALE_BETA_REAL_M_FRINGE(ymm0,ymm5,ymm15); + + } + + if(beta_vali != 0.0) + { + + ymm15 = _mm256_broadcast_sd((double const *)(&beta_vali)); + + //R(c[0][0]) I(c[0][0]) R(c[1][0]) I(c[1][0]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c)); + //ymm3+=beta_vali*(-I(c[0][0])) beta_vali*R(c[0][0]) + // beta_vali*(-I(c[1][0])) beta_vali*R(c[1][0]) + SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm3,ymm15,ymm2); + + //R(c[0][1]) I(c[0][1]) R(c[1][1]) I(c[1][1]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc)); + //ymm5+=beta_vali*(-I(c[0][1])) beta_vali*R(c[0][1]) + // beta_vali*(-I(c[1][1])) beta_vali*R(c[1][1]) + SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm5,ymm15,ymm2); + } + + /* + The scaling has been done sequentially as follows: + - If alpha_valr is not 0, it is used for scaling A + - If alpha_vali is not 0, it is used for scaling A using permutation + and selective negation, after loading + - If beta_valr is not 0, is is used for scaling C + - If beta_vali is not 0, it is used for scaling C using permutation + and selective negation, after loading + + The results are accumalated in accordance to the non zero scalar values, + and similar approach is followed in fringe cases + */ + + _mm256_storeu_pd((double *)(temp_c), ymm3); + _mm256_storeu_pd((double *)(temp_c + ldc), ymm5); + + temp_c+=2; + temp_a+=2; + + m_rem -= 2; + } + + if(m_rem==1) + { + + xmm5 = _mm_setzero_pd(); + ymm3 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + + if(alpha_valr != 0.0 || alpha_vali != 0.0) + { + xmm5 = _mm_loadu_pd((double const*)(temp_a));//R(a[0][0]) I(a[0][0]) + ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(a[0][0]) I(a[0][0]) + ymm13 = ymm0; + + SCALE_ALPHA_REAL_M_FRINGE(ymm0,ymm15,alpha_valr); + SCALE_ALPHA_IMAG_M_FRINGE(ymm0,ymm15,ymm2,ymm13,alpha_vali); + + //Calculating using real part of complex number in B matrix + //ymm3+=R(b[0][0])*R(a[0][0]) R(b[0][0])*I(a[0][0]) + // R(b[0][0])*R(a[1][0]) R(b[0][0])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)); + //ymm5+=R(b[0][1])*R(a[0][0]) R(b[0][1])*I(a[0][0]) + // R(b[0][1])*R(a[1][0]) R(b[0][1])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm5,ymm2,(double const *)(temp_b+ldb)); + + //Calculating using imaginary part of complex numbers in B matrix + //Shuffling ymm0 in accordance to the requirement + NEG_PERM_M_FRINGE(ymm0,ymm2); + + // ymm3+=I(b[0][0])*R(a[0][0]) I(b[0][0])*I(a[0][0]) + // I(b[0][0])*R(a[1][0]) I(b[0][0])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)+1); + //ymm5+=I(b[0][1])*R(a[0][0]) I(b[0][1])*I(a[0][0]) + // I(b[0][1])*R(a[1][0]) I(b[0][1])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm5,ymm2,(double const *)(temp_b+ldb)+1); + + } + if(beta_valr != 0.0) + { + ymm15 = _mm256_broadcast_sd((double const *)(&beta_valr)); + + xmm5 = _mm_loadu_pd((double const*)(temp_c));//R(c[0][0]) I(c[0][0]) + ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][0]) I(c[0][0]) + //ymm3+=beta_valr*R(c[0][0]) beta_valr*I(c[0][0]) + SCALE_BETA_REAL_M_FRINGE(ymm0,ymm3,ymm15); + + xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc));//R(c[0][1]) I(c[0][1]) + ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][1]) I(c[0][1]) + //ymm5+=beta_valr*R(c[0][1]) 
beta_valr*I(c[0][1]) + SCALE_BETA_REAL_M_FRINGE(ymm0,ymm5,ymm15); + } + if(beta_vali != 0.0) + { + ymm15 = _mm256_broadcast_sd((double const *)(&beta_vali)); + + xmm5 = _mm_loadu_pd((double const*)(temp_c));//R(c[0][0]) I(c[0][0]) + ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][0]) I(c[0][0]) + //ymm3+=beta_vali*(-I(c[0][0])) beta_vali*R(c[0][0]) + SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm3,ymm15,ymm2); + + xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc));//R(c[0][1]) I(c[0][1]) + ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][1]) I(c[0][1]) + //ymm5+=beta_vali*(-I(c[0][1])) beta_vali*R(c[0][1]) + SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm5,ymm15,ymm2); + + } + + xmm5 = _mm256_extractf128_pd(ymm3, 0); + _mm_storeu_pd((double *)(temp_c), xmm5); + + xmm5 = _mm256_extractf128_pd(ymm5, 0); + _mm_storeu_pd((double *)(temp_c + ldc), xmm5); + + } + n_remainder -= 2; + } + if(n_remainder==1) + { + dcomplex* temp_b = b + (n - n_remainder)*ldb; + dcomplex* temp_a = a; + dcomplex* temp_c = c + (n - n_remainder)*ldc; + + for(dim_t i = 0;i < (m-Z_MR+1);i=i+Z_MR) + { + ymm3 = _mm256_setzero_pd(); + ymm4 = _mm256_setzero_pd(); + + if(alpha_valr != 0.0 || alpha_vali != 0.0) + { + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_valr, aplha_vali + where alpha_valr and/or alpha_vali is not zero. + b. This loop operates with 4x5 block size + along n dimension for every Z_NR columns of temp_b where + computing all Z_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. + */ + + //R(a[0][0]) I(a[0][0]) R(a[1][0]) I(a[1][0]) + ymm0 = _mm256_loadu_pd((double const *)(temp_a)); + //R(a[2][0]) I(a[2][0]) R(a[3][0]) I(a[3][0]) + ymm1 = _mm256_loadu_pd((double const *)(temp_a + 2)); + + ymm13 = ymm0; + ymm14 = ymm1; + _mm_prefetch((char*)(temp_a) + 64, _MM_HINT_T0); + + SCALE_ALPHA_REAL_M_LOOP(ymm0,ymm1,ymm15,alpha_valr); + SCALE_ALPHA_IMAG_M_LOOP(ymm0,ymm1,ymm15,ymm2,ymm13,ymm14,alpha_vali); + + /* + The result after scaling with alpha_valr and/or alpha_vali is as follows: + For ymm0 : + R(a[0][0]) = alpha_valr*R(a[0][0])-alpha_vali*I(a[0][0]) + I(a[0][0]) = alpha_valr*I(a[0][0])+alpha_vali*R[0][0] + R(a[1][0]) = alpha_valr*R(a[1][0])-alpha_vali*I(a[1][0]) + I(a[1][0]) = alpha_valr*I(a[1][0])+alpha_vali*(R[1][0]) + + For ymm1 : + R(a[2][0]) = alpha_valr*R(a[2][0])-alpha_vali*I(a[2][0]) + I(a[2][0]) = alpha_valr*I(a[2][0])+alpha_vali*R[2][0] + R(a[3][0]) = alpha_valr*R(a[3][0])-alpha_vali*I(a[3][0]) + I(a[3][0]) = alpha_valr*I(a[3][0])+alpha_vali*(R[3][0]) + */ + + //Calculating using real part of complex number in B matrix + FMA_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm2,(double const *)(temp_b)); + + //Calculating using imaginary part of complex numbers in B matrix + //Shuffling ymm0 and ymm1 in accordance to the requirement + NEG_PERM_M_LOOP(ymm0,ymm1,ymm2); + FMA_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm2,(double const *)(temp_b)+1); + + } + if(beta_valr != 0.0) + { + /* + a. Perform beta*C using temp_c, beta_valr, + where beta_valr is not zero. + b. This loop operates with 4x5 block size + along n dimension for every Z_NR columns of temp_c where + computing all Z_MR rows of temp_c. + c. Accumulated alpha*A*B into registers will be added to beta*C + d. Same approach is used in remaining fringe cases. 
+ */ + ymm15 = _mm256_broadcast_sd((double const *)(&beta_valr)); + + //R(c[0][0]) I(c[0][0]) R(c[1][0]) I(c[1][0]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c)); + //R(c[2][0]) I(c[2][0]) R(c[3][0]) I(c[3][0]) + ymm1 = _mm256_loadu_pd((double const *)(temp_c + 2)); + //ymm3+=beta_valr*R(c[0][0]) beta_valr*I(c[0][0]) + // beta_valr*R(c[1][0]) beta_valr*I(c[1][0]) + //ymm4+=beta_valr*R(c[2][0]) beta_valr*I(c[2][0]) + // beta_valr*R(c[3][0]) beta_valr*I(c[3][0]) + SCALE_BETA_REAL_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm15); + + } + if(beta_vali != 0.0) + { + /* + a. Perform beta*C using temp_c, beta_vali, + where beta_vali is not zero. + b. This loop operates with 4x5 block size + along n dimension for every Z_NR columns of temp_c where + computing all Z_MR rows of temp_c. + c. Accumulated alpha*A*B into registers will be added to beta*C + d. Same approach is used in remaining fringe cases. + */ + + ymm15 = _mm256_broadcast_sd((double const *)(&beta_vali)); + + ymm0 = _mm256_loadu_pd((double const *)(temp_c)); + ymm1 = _mm256_loadu_pd((double const *)(temp_c + 2)); + //ymm3+=beta_vali*(-I(c[0][0])) beta_vali*R(c[0][0]) + // beta_vali*(-I(c[1][0])) beta_vali*R(c[1][0]) + //ymm4+=beta_vali*(-I(c[2][0])) beta_vali*R(c[2][0]) + // beta_vali*(-I(c[3][0])) beta_vali*R(c[3][0]) + SCALE_BETA_IMAG_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm15,ymm2); + } + /* + The scaling has been done sequentially as follows: + - If alpha_valr is not 0, it is used for scaling A + - If alpha_vali is not 0, it is used for scaling A using permutation + and selective negation, after loading + - If beta_valr is not 0, is is used for scaling C + - If beta_vali is not 0, it is used for scaling C using permutation + and selective negation, after loading + + The results are accumalated in accordance to the non zero scalar values, + and similar approach is followed in fringe cases + */ + + //R(c[0][0]) I(c[0][0]) R(c[1][0]) I(c[1][0]) + _mm256_storeu_pd((double *)(temp_c), ymm3); + //R(c[2][0]) I(c[2][0]) R(c[3][0]) I(c[3][0]) + _mm256_storeu_pd((double *)(temp_c + 2), ymm4); + + temp_c+=Z_MR; + temp_a+=Z_MR; + } + + dim_t m_rem=m_remainder; + if(m_rem>=2) + { + ymm3 = _mm256_setzero_pd(); + + if(alpha_valr != 0.0 || alpha_vali != 0.0) + { + + //R(a[0][0]) I(a[0][0]) R(a[1][0]) I(a[1][0]) + ymm0 = _mm256_loadu_pd((double const *)(temp_a)); + ymm13 = ymm0; + + SCALE_ALPHA_REAL_M_FRINGE(ymm0,ymm15,alpha_valr); + SCALE_ALPHA_IMAG_M_FRINGE(ymm0,ymm15,ymm2,ymm13,alpha_vali); + + /* + The result after scaling with alpha_valr and/or alpha_vali is as follows: + For ymm0 : + R(a[0][0]) = alpha_valr*R(a[0][0])-alpha_vali*I(a[0][0]) + I(a[0][0]) = alpha_valr*I(a[0][0])+alpha_vali*R[0][0] + R(a[1][0]) = alpha_valr*R(a[1][0])-alpha_vali*I(a[1][0]) + I(a[1][0]) = alpha_valr*I(a[1][0])+alpha_vali*(R[1][0]) + */ + + //Calculating using real part of complex number in B matrix + //ymm3+=R(b[0][0])*R(a[0][0]) R(b[0][0])*I(a[0][0]) + // R(b[0][0])*R(a[1][0]) R(b[0][0])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)); + + //Calculating using imaginary part of complex numbers in B matrix + //Shuffling ymm0 in accordance to the requirement + NEG_PERM_M_FRINGE(ymm0,ymm2); + + // ymm3+=I(b[0][0])*R(a[0][0]) I(b[0][0])*I(a[0][0]) + // I(b[0][0])*R(a[1][0]) I(b[0][0])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)+1); + } + + if(beta_valr != 0.0) + { + + ymm15 = _mm256_broadcast_sd((double const *)(&beta_valr)); + + ymm0 = _mm256_loadu_pd((double const *)(temp_c)); + //ymm3+=beta_valr*R(c[0][0]) beta_valr*I(c[0][0]) + // 
beta_valr*R(c[1][0]) beta_valr*I(c[1][0]) + SCALE_BETA_REAL_M_FRINGE(ymm0,ymm3,ymm15); + } + + if(beta_vali != 0.0) + { + + ymm15 = _mm256_broadcast_sd((double const *)(&beta_vali)); + + ymm0 = _mm256_loadu_pd((double const *)(temp_c)); + //ymm3+=beta_vali*(-I(c[0][0])) beta_vali*R(c[0][0]) + // beta_vali*(-I(c[1][0])) beta_vali*R(c[1][0]) + SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm3,ymm15,ymm2); + } + + /* + The scaling has been done sequentially as follows: + - If alpha_valr is not 0, it is used for scaling A + - If alpha_vali is not 0, it is used for scaling A using permutation + and selective negation, after loading + - If beta_valr is not 0, is is used for scaling C + - If beta_vali is not 0, it is used for scaling C using permutation + and selective negation, after loading + + The results are accumalated in accordance to the non zero scalar values, + and similar approach is followed in fringe cases + */ + + _mm256_storeu_pd((double *)(temp_c), ymm3); + + temp_c+=2; + temp_a+=2; + + m_rem -= 2; + } + + if(m_rem==1) + { + + xmm5 = _mm_setzero_pd(); + ymm3 = _mm256_setzero_pd(); + + if(alpha_valr != 0.0 || alpha_vali != 0.0) + { + xmm5 = _mm_loadu_pd((double const*)(temp_a)); + ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0); + ymm13 = ymm0; + + SCALE_ALPHA_REAL_M_FRINGE(ymm0,ymm15,alpha_valr); + SCALE_ALPHA_IMAG_M_FRINGE(ymm0,ymm15,ymm2,ymm13,alpha_vali); + + //Calculating using real part of complex number in B matrix + //ymm3+=R(b[0][0])*R(a[0][0]) R(b[0][0])*I(a[0][0]) + // R(b[0][0])*R(a[1][0]) R(b[0][0])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)); + + //Calculating using imaginary part of complex numbers in B matrix + //Shuffling ymm0 in accordance to the requirement + NEG_PERM_M_FRINGE(ymm0,ymm2); + + // ymm3+=I(b[0][0])*R(a[0][0]) I(b[0][0])*I(a[0][0]) + // I(b[0][0])*R(a[1][0]) I(b[0][0])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)+1); + } + if(beta_valr != 0.0) + { + ymm15 = _mm256_broadcast_sd((double const *)(&beta_valr)); + + xmm5 = _mm_loadu_pd((double const*)(temp_c)); + ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0); + //ymm3+=beta_valr*R(c[0][0]) beta_valr*I(c[0][0]) + SCALE_BETA_REAL_M_FRINGE(ymm0,ymm3,ymm15); + } + if(beta_vali != 0.0) + { + ymm15 = _mm256_broadcast_sd((double const *)(&beta_vali)); + + xmm5 = _mm_loadu_pd((double const*)(temp_c)); + ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0); + //ymm3+=beta_vali*(-I(c[0][0])) beta_vali*R(c[0][0]) + SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm3,ymm15,ymm2); + + } + + xmm5 = _mm256_extractf128_pd(ymm3, 0); + _mm_storeu_pd((double *)(temp_c), xmm5); + + } + + } + +} diff --git a/kernels/zen/bli_kernels_zen.h b/kernels/zen/bli_kernels_zen.h index 53db6e4d22..1d18d711e1 100644 --- a/kernels/zen/bli_kernels_zen.h +++ b/kernels/zen/bli_kernels_zen.h @@ -313,6 +313,18 @@ void bli_dgemm_ref_k1_nn double* c, const inc_t ldc ); +void bli_zgemm_ref_k1_nn + ( + dim_t m, + dim_t n, + dim_t k, + dcomplex* alpha, + dcomplex* a, const inc_t lda, + dcomplex* b, const inc_t ldb, + dcomplex* beta, + dcomplex* c, const inc_t ldc + ); + err_t bli_trsm_small ( side_t side, From 4f96bb712e8f9f14e097a353f53c860e59cc60bc Mon Sep 17 00:00:00 2001 From: Arnav Sharma Date: Wed, 13 Jul 2022 11:42:35 +0530 Subject: [PATCH 145/243] AOCL Dynamic Optimization for DGEMMT - Fine-tuned the thread allocation logic for parallelizing DGEMMT for the cases where n <= 220. This results in performance improvement in multi-threaded DGEMMT for small values of n. 
AMD-Internal: [CPUPL-2215] Change-Id: I2654bc64d2dc43c2db911e0c9175755be3aa8ba5 --- frame/base/bli_rntm.c | 34 ++++++++++++++++++++++++++++++++-- 1 file changed, 32 insertions(+), 2 deletions(-) diff --git a/frame/base/bli_rntm.c b/frame/base/bli_rntm.c index 0db51870eb..5efba4f2f0 100644 --- a/frame/base/bli_rntm.c +++ b/frame/base/bli_rntm.c @@ -680,16 +680,35 @@ void bli_nthreads_optimum( dim_t n = bli_obj_length(c); dim_t k = bli_obj_width_after_trans(a); - if ( n < 32 ) + if ( n < 8 ) + { + if ( k <= 512) + { + n_threads_ideal = 1; + } + else if ( k <= 1024 ) + { + n_threads_ideal = 4; + } + } + else if ( n < 32 ) { if ( k < 128 ) { n_threads_ideal = 1; } - else if ( k == 128 ) + else if ( k <= 512 ) { n_threads_ideal = 4; } + else if ( k <= 1024 ) + { + n_threads_ideal = 6; + } + else if ( k <= 1600 ) + { + n_threads_ideal = 10; + } } else if ( n <= 40 ) { @@ -724,6 +743,17 @@ void bli_nthreads_optimum( n_threads_ideal = 8; } } + else if ( n < 176 ) + { + if ( k < 128 ) + { + n_threads_ideal = 8; + } + else if ( k <= 512 ) + { + n_threads_ideal = 14; + } + } else if ( n <= 220 ) { if ( k < 128 ) From 86134c72786cb3e2c138c6d9a0a63bfcca1cc3d3 Mon Sep 17 00:00:00 2001 From: Kiran Varaganti Date: Mon, 18 Jul 2022 06:41:37 +0000 Subject: [PATCH 146/243] Replaced vzeroall Replaced vzeroall instruction with vxorpd and vmovapd for dgemm kernels -both AVX2 and AVX512. vzeroall is expensive instruction and replaced it with faster version of zeroing all registers. vzeroupper() instruction is also added at the end of AVX2 kernels to avoid any AVX2/SSE transition penalities. Kindly note only the main kernels are modified. Change-Id: Ieb9bc629db01f0f94dd0e8e55550940d3d7eb2a4 --- kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c | 18 +++- .../3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c | 56 ++++++------ kernels/skx/3/bli_dgemm_skx_asm_16x14.c | 55 ++++++------ .../zen/3/sup/bli_gemmsup_rv_zen_asm_s6x16.c | 89 ++++++++++--------- 4 files changed, 120 insertions(+), 98 deletions(-) diff --git a/kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c b/kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c index b4ac979e1a..59e239fe14 100644 --- a/kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c +++ b/kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c @@ -950,7 +950,21 @@ void bli_dgemm_haswell_asm_6x8 begin_asm() - vzeroall() // zero all xmm/ymm registers. + //vzeroall() // zero all xmm/ymm registers. + + vxorpd( ymm4, ymm4, ymm4) // vzeroall is expensive + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + mov(var(a), rax) // load address of a. @@ -1610,7 +1624,7 @@ void bli_dgemm_haswell_asm_6x8 label(.DDONE) - + vzeroupper() end_asm( diff --git a/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c b/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c index c7a95d65f1..56227ec4dc 100644 --- a/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c +++ b/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c @@ -274,17 +274,17 @@ void bli_dgemmsup_rv_haswell_asm_6x8m // a latency of 1 cycle, while vzeroall // has a latency of 12 cycles. 
vxorpd(ymm4, ymm4, ymm4) - vxorpd(ymm5, ymm5, ymm5) - vxorpd(ymm6, ymm6, ymm6) - vxorpd(ymm7, ymm7, ymm7) - vxorpd(ymm8, ymm8, ymm8) - vxorpd(ymm9, ymm9, ymm9) - vxorpd(ymm10, ymm10, ymm10) - vxorpd(ymm11, ymm11, ymm11) - vxorpd(ymm12, ymm12, ymm12) - vxorpd(ymm13, ymm13, ymm13) - vxorpd(ymm14, ymm14, ymm14) - vxorpd(ymm15, ymm15, ymm15) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) #endif mov(var(b), rbx) // load address of b. @@ -1077,18 +1077,18 @@ void bli_dgemmsup_rv_haswell_asm_6x6m // a latency of 1 cycle, while vzeroall // has a latency of 12 cycles. vxorpd(ymm1, ymm1, ymm1) // zero ymm1 since we only use the lower - vxorpd(ymm4, ymm4, ymm4) // half (xmm1), and nans/infs may slow us - vxorpd(ymm5, ymm5, ymm5) // down. - vxorpd(ymm6, ymm6, ymm6) - vxorpd(ymm7, ymm7, ymm7) - vxorpd(ymm8, ymm8, ymm8) - vxorpd(ymm9, ymm9, ymm9) - vxorpd(ymm10, ymm10, ymm10) - vxorpd(ymm11, ymm11, ymm11) - vxorpd(ymm12, ymm12, ymm12) - vxorpd(ymm13, ymm13, ymm13) - vxorpd(ymm14, ymm14, ymm14) - vxorpd(ymm15, ymm15, ymm15) + vmovapd(ymm1, ymm4) // half (xmm1), and nans/infs may slow us + vmovapd(ymm1, ymm5) // down. + vmovapd(ymm1, ymm6) + vmovapd(ymm1, ymm7) + vmovapd(ymm1, ymm8) + vmovapd(ymm1, ymm9) + vmovapd(ymm1, ymm10) + vmovapd(ymm1, ymm11) + vmovapd(ymm1, ymm12) + vmovapd(ymm1, ymm13) + vmovapd(ymm1, ymm14) + vmovapd(ymm1, ymm15) #endif mov(var(b), rbx) // load address of b. @@ -1858,11 +1858,11 @@ void bli_dgemmsup_rv_haswell_asm_6x4m // a latency of 1 cycle, while vzeroall // has a latency of 12 cycles. vxorpd(ymm4, ymm4, ymm4) - vxorpd(ymm6, ymm6, ymm6) - vxorpd(ymm8, ymm8, ymm8) - vxorpd(ymm10, ymm10, ymm10) - vxorpd(ymm12, ymm12, ymm12) - vxorpd(ymm14, ymm14, ymm14) + vmovapd(ymm4, ymm6) + vmovapd(ymm4, ymm8) + vmovapd(ymm4, ymm10) + vmovapd(ymm4, ymm12) + vmovapd(ymm4, ymm14) #endif mov(var(b), rbx) // load address of b. diff --git a/kernels/skx/3/bli_dgemm_skx_asm_16x14.c b/kernels/skx/3/bli_dgemm_skx_asm_16x14.c index 136f315323..877e0c9191 100644 --- a/kernels/skx/3/bli_dgemm_skx_asm_16x14.c +++ b/kernels/skx/3/bli_dgemm_skx_asm_16x14.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2022, Advanced Micro Devices, Inc.All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -175,33 +176,33 @@ void bli_dgemm_skx_asm_16x14( BEGIN_ASM() VXORPD(YMM( 4), YMM( 4), YMM( 4)) //clear out registers - VXORPD(YMM( 5), YMM( 5), YMM( 5)) - VXORPD(YMM( 6), YMM( 6), YMM( 6)) - VXORPD(YMM( 7), YMM( 7), YMM( 7)) - VXORPD(YMM( 8), YMM( 8), YMM( 8)) - VXORPD(YMM( 9), YMM( 9), YMM( 9)) - VXORPD(YMM(10), YMM(10), YMM(10)) - VXORPD(YMM(11), YMM(11), YMM(11)) - VXORPD(YMM(12), YMM(12), YMM(12)) - VXORPD(YMM(13), YMM(13), YMM(13)) - VXORPD(YMM(14), YMM(14), YMM(14)) - VXORPD(YMM(15), YMM(15), YMM(15)) - VXORPD(YMM(16), YMM(16), YMM(16)) - VXORPD(YMM(17), YMM(17), YMM(17)) - VXORPD(YMM(18), YMM(18), YMM(18)) - VXORPD(YMM(19), YMM(19), YMM(19)) - VXORPD(YMM(20), YMM(20), YMM(20)) - VXORPD(YMM(21), YMM(21), YMM(21)) - VXORPD(YMM(22), YMM(22), YMM(22)) - VXORPD(YMM(23), YMM(23), YMM(23)) - VXORPD(YMM(24), YMM(24), YMM(24)) - VXORPD(YMM(25), YMM(25), YMM(25)) - VXORPD(YMM(26), YMM(26), YMM(26)) - VXORPD(YMM(27), YMM(27), YMM(27)) - VXORPD(YMM(28), YMM(28), YMM(28)) - VXORPD(YMM(29), YMM(29), YMM(29)) - VXORPD(YMM(30), YMM(30), YMM(30)) - VXORPD(YMM(31), YMM(31), YMM(31)) + VMOVAPD(YMM(4), YMM( 5)) + VMOVAPD(YMM(4), YMM( 6)) + VMOVAPD(YMM(4), YMM( 7)) + VMOVAPD(YMM(4), YMM( 8)) + VMOVAPD(YMM(4), YMM( 9)) + VMOVAPD(YMM(4), YMM(10)) + VMOVAPD(YMM(4), YMM(11)) + VMOVAPD(YMM(4), YMM(12)) + VMOVAPD(YMM(4), YMM(13)) + VMOVAPD(YMM(4), YMM(14)) + VMOVAPD(YMM(4), YMM(15)) + VMOVAPD(YMM(4), YMM(16)) + VMOVAPD(YMM(4), YMM(17)) + VMOVAPD(YMM(4), YMM(18)) + VMOVAPD(YMM(4), YMM(19)) + VMOVAPD(YMM(4), YMM(20)) + VMOVAPD(YMM(4), YMM(21)) + VMOVAPD(YMM(4), YMM(22)) + VMOVAPD(YMM(4), YMM(23)) + VMOVAPD(YMM(4), YMM(24)) + VMOVAPD(YMM(4), YMM(25)) + VMOVAPD(YMM(4), YMM(26)) + VMOVAPD(YMM(4), YMM(27)) + VMOVAPD(YMM(4), YMM(28)) + VMOVAPD(YMM(4), YMM(29)) + VMOVAPD(YMM(4), YMM(30)) + VMOVAPD(YMM(4), YMM(31)) MOV(RSI, VAR(k)) //loop index MOV(RAX, VAR(a)) //load address of a diff --git a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_s6x16.c b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_s6x16.c index 347384aa65..752a0a01c5 100644 --- a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_s6x16.c +++ b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_s6x16.c @@ -113,17 +113,19 @@ void bli_sgemmsup_rv_zen_asm_5x16 begin_asm() vxorps(ymm4, ymm4, ymm4) - vxorps(ymm5, ymm5, ymm5) - vxorps(ymm6, ymm6, ymm6) - vxorps(ymm7, ymm7, ymm7) - vxorps(ymm8, ymm8, ymm8) - vxorps(ymm9, ymm9, ymm9) - vxorps(ymm10, ymm10, ymm10) - vxorps(ymm11, ymm11, ymm11) - vxorps(ymm12, ymm12, ymm12) - vxorps(ymm13, ymm13, ymm13) - vxorps(ymm14, ymm14, ymm14) - vxorps(ymm15, ymm15, ymm15) + vmovaps(ymm4, ymm5) + vmovaps(ymm4, ymm6) + vmovaps(ymm4, ymm7) + vmovaps(ymm4, ymm8) + vmovaps(ymm4, ymm9) + vmovaps(ymm4, ymm10) + vmovaps(ymm4, ymm11) + vmovaps(ymm4, ymm12) + vmovaps(ymm4, ymm13) + vmovaps(ymm4, ymm14) + vmovaps(ymm4, ymm15) + + mov(var(a), rax) // load address of a. 
mov(var(rs_a), r8) // load rs_a mov(var(cs_a), r9) // load cs_a @@ -694,6 +696,7 @@ void bli_sgemmsup_rv_zen_asm_5x16 vmovss(xmm14, mem(rdx, rax, 1)) label(.SDONE) + vzeroupper() end_asm( : // output operands (none) @@ -758,19 +761,20 @@ void bli_sgemmsup_rv_zen_asm_4x16 // ------------------------------------------------------------------------- begin_asm() - - vxorps(ymm4, ymm4, ymm4) - vxorps(ymm5, ymm5, ymm5) - vxorps(ymm6, ymm6, ymm6) - vxorps(ymm7, ymm7, ymm7) - vxorps(ymm8, ymm8, ymm8) - vxorps(ymm9, ymm9, ymm9) - vxorps(ymm10, ymm10, ymm10) - vxorps(ymm11, ymm11, ymm11) - vxorps(ymm12, ymm12, ymm12) - vxorps(ymm13, ymm13, ymm13) - vxorps(ymm14, ymm14, ymm14) - vxorps(ymm15, ymm15, ymm15) + + vxorps(ymm4, ymm4, ymm4) + vmovaps(ymm4, ymm5) + vmovaps(ymm4, ymm6) + vmovaps(ymm4, ymm7) + vmovaps(ymm4, ymm8) + vmovaps(ymm4, ymm9) + vmovaps(ymm4, ymm10) + vmovaps(ymm4, ymm11) + vmovaps(ymm4, ymm12) + vmovaps(ymm4, ymm13) + vmovaps(ymm4, ymm14) + vmovaps(ymm4, ymm15) + mov(var(a), rax) // load address of a. mov(var(rs_a), r8) // load rs_a mov(var(cs_a), r9) // load cs_a @@ -822,14 +826,14 @@ void bli_sgemmsup_rv_zen_asm_4x16 prefetch(0, mem(rdx, rsi, 2, 3*8)) // prefetch c + 7*cs_c label(.SPOSTPFETCH) // done prefetching c - + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.SCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. - + label(.SLOOPKITER) // MAIN LOOP - + // ---------------------------------- iteration 0 vmovups(mem(rbx, 0*32), ymm0) vmovups(mem(rbx, 1*32), ymm1) @@ -1188,7 +1192,8 @@ void bli_sgemmsup_rv_zen_asm_4x16 vmovups(xmm2, mem(rcx, rsi, 4)) // store ( gamma07..gamma37 ) label(.SDONE) - + vzeroupper() + end_asm( : // output operands (none) : // input operands @@ -1252,19 +1257,20 @@ void bli_sgemmsup_rv_zen_asm_3x16 // ------------------------------------------------------------------------- begin_asm() - - vxorps(ymm4, ymm4, ymm4) - vxorps(ymm5, ymm5, ymm5) - vxorps(ymm6, ymm6, ymm6) - vxorps(ymm7, ymm7, ymm7) - vxorps(ymm8, ymm8, ymm8) - vxorps(ymm9, ymm9, ymm9) - vxorps(ymm10, ymm10, ymm10) - vxorps(ymm11, ymm11, ymm11) - vxorps(ymm12, ymm12, ymm12) - vxorps(ymm13, ymm13, ymm13) - vxorps(ymm14, ymm14, ymm14) - vxorps(ymm15, ymm15, ymm15) + + vxorps(ymm4, ymm4, ymm4) + vmovaps(ymm4, ymm5) + vmovaps(ymm4, ymm6) + vmovaps(ymm4, ymm7) + vmovaps(ymm4, ymm8) + vmovaps(ymm4, ymm9) + vmovaps(ymm4, ymm10) + vmovaps(ymm4, ymm11) + vmovaps(ymm4, ymm12) + vmovaps(ymm4, ymm13) + vmovaps(ymm4, ymm14) + vmovaps(ymm4, ymm15) + mov(var(a), rax) // load address of a. mov(var(rs_a), r8) // load rs_a mov(var(cs_a), r9) // load cs_a @@ -1746,6 +1752,7 @@ void bli_sgemmsup_rv_zen_asm_3x16 vmovss(xmm14, mem(rdx, rax, 1)) label(.SDONE) + vzeroupper() end_asm( : // output operands (none) From f63e699c087f1e53d9629d518ba4e43a823e310e Mon Sep 17 00:00:00 2001 From: mkadavil Date: Wed, 20 Jul 2022 19:07:46 +0530 Subject: [PATCH 147/243] Fix for segmentation fault in low precision gemm. - Low precision gemm sets thread meta data (lpgemm_thrinfo_t) to NULL when compiled without open mp threading support. Subsequently the code is executed as if it is single-threaded. However, when B matrix needs to be packed, communicators are required (irrespective of single or multi-threaded), and the code accesses lpgemm_thrinfo_t for the same without NULL check. This results in seg fault. For the fix, a non-open mp thread decorator layer is added, which creates a placeholder lpgemm_thrinfo_t object with a communicator before invoking the 5 loop algorithm. 
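In plain C, the placeholder amounts to the following (a condensed sketch of the decorator added in the diff below, not a drop-in routine from the library; the helper name is invented, while the types and bli_thrcomm_init() come from the framework and aocl_gemm addon headers):

    #include "blis.h"          // thrcomm_t, bli_thrcomm_init(), dim_t
    #include "lpgemm_types.h"  // lpgemm_thrinfo_t (aocl_gemm addon header)

    // Hypothetical helper illustrating the single-threaded fallback.
    static void lpgemm_make_single_threaded_thrinfo
         (
           lpgemm_thrinfo_t* thread,
           thrcomm_t*        comm
         )
    {
        // A one-member communicator so the packing code always has a valid
        // thrcomm_t to synchronize on, even when OpenMP is disabled.
        bli_thrcomm_init( 1, comm );

        thread->n_threads = 1;   // run the 5-loop algorithm single-threaded
        thread->tid       = 0;
        thread->ic_ways   = 1;   // no thread partitioning along m
        thread->jc_ways   = 1;   // no thread partitioning along n
        thread->comm      = comm;
    }
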
This object will be used for packing. - Makefile for compilation of aocl_gemm bench. AMD-Internal: [CPUPL-2304] Change-Id: Id505235c8421792240b84f93942ca62dac78f3dc --- addon/aocl_gemm/aocl_gemm_f32f32f32of32.c | 5 +- addon/aocl_gemm/aocl_gemm_u8s8s32os32.c | 5 +- .../threading/lpgemm_thread_decor_openmp.c | 65 +++++++++ .../threading/lpgemm_thread_decor_openmp.h | 30 +++- bench/bench_aocl_gemm/Makefile | 132 ++++++++++++++++++ bench/bench_aocl_gemm/bench_lpgemm.c | 5 +- 6 files changed, 232 insertions(+), 10 deletions(-) create mode 100755 bench/bench_aocl_gemm/Makefile diff --git a/addon/aocl_gemm/aocl_gemm_f32f32f32of32.c b/addon/aocl_gemm/aocl_gemm_f32f32f32of32.c index f3055ad3f9..973d33c548 100644 --- a/addon/aocl_gemm/aocl_gemm_f32f32f32of32.c +++ b/addon/aocl_gemm/aocl_gemm_f32f32f32of32.c @@ -154,15 +154,14 @@ void aocl_gemm_f32f32f32of32 // Setting pack A by default for non open mp case. bli_rntm_set_pack_a( 1, &rntm_g ); - lpgemm_rowvar_f32f32f32of32 + lpgemm_f32f32f32of32_thread_decorator ( m, n, k, a, rs_a, cs_a, mtag_a, b, rs_b, cs_b, mtag_b, c, rs_c, alpha, beta, - &rntm_g, - NULL + &rntm_g ); #endif diff --git a/addon/aocl_gemm/aocl_gemm_u8s8s32os32.c b/addon/aocl_gemm/aocl_gemm_u8s8s32os32.c index e1131a3992..860b0e420c 100644 --- a/addon/aocl_gemm/aocl_gemm_u8s8s32os32.c +++ b/addon/aocl_gemm/aocl_gemm_u8s8s32os32.c @@ -150,15 +150,14 @@ void aocl_gemm_u8s8s32os32 &rntm_g ); #else - lpgemm_rowvar_u8s8s32o32 + lpgemm_u8s8s32o32_thread_decorator ( m, n, k, a, rs_a, cs_a, mtag_a, b, rs_b, cs_b, mtag_b, c, rs_c, alpha, beta, - &rntm_g, - NULL + &rntm_g ); #endif } diff --git a/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.c b/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.c index 59641f5dae..5659c8286f 100644 --- a/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.c +++ b/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.c @@ -444,4 +444,69 @@ void lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \ GEN_LPGEMM_OPENMP_DECORATOR(uint8_t,int8_t,int32_t,u8s8s32o32) GEN_LPGEMM_OPENMP_DECORATOR(float,float,float,f32f32f32of32) +#else + +#define GEN_LPGEMM_DECORATOR(A_type,B_type,C_type,LPGEMM_SFX) \ +void lpgemm_ ## LPGEMM_SFX ## _thread_decorator \ + ( \ + const dim_t m, \ + const dim_t n, \ + const dim_t k, \ + const A_type* a, \ + const dim_t rs_a, \ + const dim_t cs_a, \ + const AOCL_MEMORY_TAG mtag_a, \ + const B_type* b, \ + const dim_t rs_b, \ + const dim_t cs_b, \ + const AOCL_MEMORY_TAG mtag_b, \ + C_type* c, \ + const dim_t rs_c, \ + C_type alpha, \ + C_type beta, \ + rntm_t* rntm_g \ + ) \ +{ \ + dim_t n_threads = 1; \ + \ + /* Factorization of threads along m and n dimension respectively.*/ \ + dim_t ic_ways = 1; \ + dim_t jc_ways = 1; \ + \ + /* Set the packing block allocator field of the rntm. 
This will be + * inherited by all of the child threads when they make local copies of + * the rntm below.*/ \ + bli_membrk_rntm_set_membrk( rntm_g ); \ + \ + thrcomm_t static_lpgemm_comm; \ + thrcomm_t* cur_lpgemm_comm = &static_lpgemm_comm; \ + \ + bli_thrcomm_init( ic_ways, cur_lpgemm_comm ); \ + \ + /* lpgemm_thrinfo_t object will be used to generate thrinfo_t objects + * for use in blis mt framework inside the respective mat mul driver + * functions.*/ \ + lpgemm_thrinfo_t thread; \ + thread.n_threads = n_threads; \ + thread.tid = 0; \ + thread.ic_ways = ic_ways; \ + thread.jc_ways = jc_ways; \ + thread.comm = cur_lpgemm_comm; \ + \ + lpgemm_rowvar_ ## LPGEMM_SFX \ + ( \ + m, n, k, \ + a, rs_a, cs_a, mtag_a, \ + b, rs_b, cs_b, mtag_b, \ + c, rs_c, \ + alpha, \ + beta, \ + rntm_g, \ + &thread \ + ); \ +} \ + +GEN_LPGEMM_DECORATOR(uint8_t,int8_t,int32_t,u8s8s32o32) +GEN_LPGEMM_DECORATOR(float,float,float,f32f32f32of32) + #endif diff --git a/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.h b/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.h index dba1d71ab6..dd38a02ebd 100644 --- a/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.h +++ b/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.h @@ -35,10 +35,10 @@ #ifndef LPGEMM_THREAD_DECOR_OPENMP_H #define LPGEMM_THREAD_DECOR_OPENMP_H -#ifdef BLIS_ENABLE_OPENMP - #include "lpgemm_types.h" +#ifdef BLIS_ENABLE_OPENMP + #define GEN_LPGEMM_OPENMP_DECORATOR_FN(A_type,B_type,C_type,LPGEMM_SFX) \ void lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \ ( \ @@ -63,6 +63,32 @@ void lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \ GEN_LPGEMM_OPENMP_DECORATOR_FN(uint8_t,int8_t,int32_t,u8s8s32o32) GEN_LPGEMM_OPENMP_DECORATOR_FN(float,float,float,f32f32f32of32) +#else + +#define GEN_LPGEMM_DECORATOR_FN(A_type,B_type,C_type,LPGEMM_SFX) \ +void lpgemm_ ## LPGEMM_SFX ## _thread_decorator \ + ( \ + const dim_t m, \ + const dim_t n, \ + const dim_t k, \ + const A_type* a, \ + const dim_t rs_a, \ + const dim_t cs_a, \ + const AOCL_MEMORY_TAG mtag_a, \ + const B_type* b, \ + const dim_t rs_b, \ + const dim_t cs_b, \ + const AOCL_MEMORY_TAG mtag_b, \ + C_type* c, \ + const dim_t rs_c, \ + C_type alpha, \ + C_type beta, \ + rntm_t* rntm_g \ + ); \ + +GEN_LPGEMM_DECORATOR_FN(uint8_t,int8_t,int32_t,u8s8s32o32) +GEN_LPGEMM_DECORATOR_FN(float,float,float,f32f32f32of32) + #endif #endif //LPGEMM_THREAD_DECOR_OPENMP_H diff --git a/bench/bench_aocl_gemm/Makefile b/bench/bench_aocl_gemm/Makefile new file mode 100755 index 0000000000..6344fe9396 --- /dev/null +++ b/bench/bench_aocl_gemm/Makefile @@ -0,0 +1,132 @@ +# +# +# BLIS +# An object-based framework for developing high-performance BLAS-like +# libraries. +# +# Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# - Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# - Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# - Neither the name(s) of the copyright holder(s) nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# + +# Makefile for lpgemm bench. +# + +# +# --- Makefile PHONY target definitions ---------------------------------------- +# + +.PHONY: all \ + blis \ + check-env check-env-mk check-lib \ + clean cleanx + +# Comments: +# - DIST_PATH is assumed to not exist if BLIS_INSTALL_PATH is given. +# - We must use recursively expanded assignment for LIB_PATH and INC_PATH in +# the second case because CONFIG_NAME is not yet set. +ifneq ($(strip $(BLIS_INSTALL_PATH)),) +LIB_PATH := $(BLIS_INSTALL_PATH)/lib +INC_PATH := $(BLIS_INSTALL_PATH)/include/blis +SHARE_PATH := $(BLIS_INSTALL_PATH)/share/blis +else +DIST_PATH := ../.. +LIB_PATH = ../../lib/$(CONFIG_NAME) +INC_PATH = ../../include/$(CONFIG_NAME) +SHARE_PATH := ../.. +endif + + + +# +# --- Include common makefile definitions -------------------------------------- +# + +# Include the common makefile fragment. +-include $(SHARE_PATH)/common.mk + +# +# --- General build definitions ------------------------------------------------ +# + +TEST_SRC_PATH := . +TEST_OBJ_PATH := . + +# Gather all local object files. +TEST_OBJS := $(patsubst $(TEST_SRC_PATH)/%.c, \ + $(TEST_OBJ_PATH)/%.o, \ + $(wildcard $(TEST_SRC_PATH)/*.c)) + + + +# Override the value of CINCFLAGS so that the value of CFLAGS returned by +# get-user-cflags-for() is not cluttered up with include paths needed only +# while building BLIS. +CINCFLAGS := -I$(INC_PATH) -I$(CBLAS_HEADER_PATH) + +# Use the CFLAGS for the configuration family. +CFLAGS := $(call get-user-cflags-for,$(CONFIG_NAME)) + +# Add local header paths to CFLAGS +CFLAGS += -I$(TEST_SRC_PATH) + +# Locate the libblis library to which we will link. +#LIBBLIS_LINK := $(LIB_PATH)/$(LIBBLIS_L) + +# +# --- Targets/rules ------------------------------------------------------------ +# + +# Complete list of possible targets when defining 'all': +# +# blis openblas atlas mkl mac essl +# +all: blis + +blis: \ + bench_lpgemm_blis.x + + +# --Object file rules -- + +$(TEST_OBJ_PATH)/%.o: $(TEST_SRC_PATH)/%.c + $(CC) $(CFLAGS) -c $< -o $@ + +bench_%_blis.o: bench_%.c + $(CC) $(CFLAGS) -DBLAS=\"aocl\" $(NRTS) -DINT_FS=\"%ld\" -DUINT_FS=\"%lu\" -c $< -o $@ + + +# -- Executable file rules -- + +bench_%_blis.x: bench_%_blis.o $(LIBBLIS_LINK) + $(LINKER) $< $(LIBBLIS_LINK) $(LDFLAGS) -o $@ + + +# -- Clean rules -- + +clean: cleanx + +cleanx: + - $(RM_F) *.o *.x diff --git a/bench/bench_aocl_gemm/bench_lpgemm.c b/bench/bench_aocl_gemm/bench_lpgemm.c index 91dd10b966..0af03cf2f8 100644 --- a/bench/bench_aocl_gemm/bench_lpgemm.c +++ b/bench/bench_aocl_gemm/bench_lpgemm.c @@ -6,7 +6,7 @@ #include #include -#include "blis/blis.h" +#include "blis.h" // Mode can be one of the follwoing: // 1. p - performance, used for benchmarks. 
@@ -129,7 +129,6 @@ void mat_mul_bench_driver_ ## BLAS_SFX \ dim_t ldc \ ) \ { \ - double gflops; \ double min_time_diff = DBL_MAX; \ for ( int32_t nr = 0; nr < n_repeats; ++nr ) \ { \ @@ -425,7 +424,9 @@ int main( int argc, char** argv ) // number of cores. This helps uncover any bugs related to over // subscription or varying thread factorizations. // Set current number of cores. +#ifdef BLIS_ENABLE_OPENMP omp_set_num_threads( list_omp_cores_for_testing[core_index] ); +#endif printf( "Accuracy test using %ld threads.\n", list_omp_cores_for_testing[core_index] ); From 5d617429f48b5028bec3772373de993eebafd1d2 Mon Sep 17 00:00:00 2001 From: Dipal M Zambare Date: Fri, 22 Jul 2022 10:27:53 +0530 Subject: [PATCH 148/243] Enabled znver4 support for GCC version >= 12 - Updated zen4 configuration to add -march=znver4 flag in the compiler options if the gcc version is above or equal to 12 AMD-Internal: [CPUPL-1937] Change-Id: Ic11470b92f71e49ee193a3a5406cf6045d66bd2f --- config/zen4/make_defs.mk | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/config/zen4/make_defs.mk b/config/zen4/make_defs.mk index 75bec7018e..d5d13167aa 100644 --- a/config/zen4/make_defs.mk +++ b/config/zen4/make_defs.mk @@ -71,11 +71,17 @@ CKOPTFLAGS := $(COPTFLAGS) -fomit-frame-pointer ifeq ($(CC_VENDOR),gcc) GCC_VERSION := $(strip $(shell $(CC) -dumpversion | cut -d. -f1)) # gcc or clang version must be atleast 4.0 -# gcc 9.0 or later: +# gcc 12.0 or later: +ifeq ($(shell test $(GCC_VERSION) -ge 12; echo $$?),0) +CKVECFLAGS += -march=znver4 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mfpmath=sse +CRVECFLAGS += -march=znver4 +else +# gcc 11.0 or later: ifeq ($(shell test $(GCC_VERSION) -ge 11; echo $$?),0) CKVECFLAGS += -march=znver3 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mfpmath=sse CRVECFLAGS += -march=znver3 else +# gcc 9.0 or later: ifeq ($(shell test $(GCC_VERSION) -ge 9; echo $$?),0) CKVECFLAGS += -march=znver2 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mfpmath=sse CRVECFLAGS += -march=znver2 @@ -86,6 +92,7 @@ CKVECFLAGS += -march=znver1 -mno-avx256-split-unaligned-store CRVECFLAGS += -march=znver1 -mno-avx256-split-unaligned-store endif # GCC 9 endif # GCC 11 +endif # GCC 12 else ifeq ($(CC_VENDOR),clang) From eff436c6536d715170ca20e3beb3834515cfaa55 Mon Sep 17 00:00:00 2001 From: Kiran Varaganti Date: Fri, 22 Jul 2022 06:21:13 +0000 Subject: [PATCH 149/243] Bug Fix to replace vzeroall Fixed syntax in AVX512 dgemm native kernel. zen4 configuration follows Intel ASM syntax whereas other AMD configs follow AT&T ASM syntax. Bug was introduced due to following AT&T syntax in AVX512 dgemm kernel. In this commit we changed the syntax to Intel ASM format. src and dst operands are interchanged. 
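Concretely, both of the forms below appear in kernels touched by this series and perform the same register copy (ymm5 <- ymm4); only the operand order differs, which is what made the earlier AVX512 change copy in the wrong direction:

    VMOVAPD(YMM(5), YMM(4))   /* skx/zen4 macro layer: Intel order, destination first */
    vmovapd(ymm4, ymm5)       /* haswell/zen AT&T-style macros: source first          */
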
Change-Id: Ie61dc7c5e8309b79437d471331318f3104bcd447 --- kernels/skx/3/bli_dgemm_skx_asm_16x14.c | 54 ++++++++++++------------- 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/kernels/skx/3/bli_dgemm_skx_asm_16x14.c b/kernels/skx/3/bli_dgemm_skx_asm_16x14.c index 877e0c9191..c0ada1eb66 100644 --- a/kernels/skx/3/bli_dgemm_skx_asm_16x14.c +++ b/kernels/skx/3/bli_dgemm_skx_asm_16x14.c @@ -176,33 +176,33 @@ void bli_dgemm_skx_asm_16x14( BEGIN_ASM() VXORPD(YMM( 4), YMM( 4), YMM( 4)) //clear out registers - VMOVAPD(YMM(4), YMM( 5)) - VMOVAPD(YMM(4), YMM( 6)) - VMOVAPD(YMM(4), YMM( 7)) - VMOVAPD(YMM(4), YMM( 8)) - VMOVAPD(YMM(4), YMM( 9)) - VMOVAPD(YMM(4), YMM(10)) - VMOVAPD(YMM(4), YMM(11)) - VMOVAPD(YMM(4), YMM(12)) - VMOVAPD(YMM(4), YMM(13)) - VMOVAPD(YMM(4), YMM(14)) - VMOVAPD(YMM(4), YMM(15)) - VMOVAPD(YMM(4), YMM(16)) - VMOVAPD(YMM(4), YMM(17)) - VMOVAPD(YMM(4), YMM(18)) - VMOVAPD(YMM(4), YMM(19)) - VMOVAPD(YMM(4), YMM(20)) - VMOVAPD(YMM(4), YMM(21)) - VMOVAPD(YMM(4), YMM(22)) - VMOVAPD(YMM(4), YMM(23)) - VMOVAPD(YMM(4), YMM(24)) - VMOVAPD(YMM(4), YMM(25)) - VMOVAPD(YMM(4), YMM(26)) - VMOVAPD(YMM(4), YMM(27)) - VMOVAPD(YMM(4), YMM(28)) - VMOVAPD(YMM(4), YMM(29)) - VMOVAPD(YMM(4), YMM(30)) - VMOVAPD(YMM(4), YMM(31)) + VMOVAPD(YMM(5) , YMM(4)) + VMOVAPD(YMM(6) , YMM(4)) + VMOVAPD(YMM(7) , YMM(4)) + VMOVAPD(YMM(8) , YMM(4)) + VMOVAPD(YMM(9) , YMM(4)) + VMOVAPD(YMM(10), YMM(4)) + VMOVAPD(YMM(11), YMM(4)) + VMOVAPD(YMM(12), YMM(4)) + VMOVAPD(YMM(13), YMM(4)) + VMOVAPD(YMM(14), YMM(4)) + VMOVAPD(YMM(15), YMM(4)) + VMOVAPD(YMM(16), YMM(4)) + VMOVAPD(YMM(17), YMM(4)) + VMOVAPD(YMM(18), YMM(4)) + VMOVAPD(YMM(19), YMM(4)) + VMOVAPD(YMM(20), YMM(4)) + VMOVAPD(YMM(21), YMM(4)) + VMOVAPD(YMM(22), YMM(4)) + VMOVAPD(YMM(23), YMM(4)) + VMOVAPD(YMM(24), YMM(4)) + VMOVAPD(YMM(25), YMM(4)) + VMOVAPD(YMM(26), YMM(4)) + VMOVAPD(YMM(27), YMM(4)) + VMOVAPD(YMM(28), YMM(4)) + VMOVAPD(YMM(29), YMM(4)) + VMOVAPD(YMM(30), YMM(4)) + VMOVAPD(YMM(31), YMM(4)) MOV(RSI, VAR(k)) //loop index MOV(RAX, VAR(a)) //load address of a From 6054b888fb5abb19166eef8981b3747cab81ddee Mon Sep 17 00:00:00 2001 From: Kiran Varaganti Date: Mon, 25 Jul 2022 15:31:23 +0000 Subject: [PATCH 150/243] Fixed Bug in bench_trsm.c When bli_trsm() API is called, we make sure the "side" argument is "side_t" and not f77_char and argument is passed by value and not by its address. Change-Id: I5a616eb054c034be2d67640b8ab3b9615706a8c9 --- bench/bench_trsm.c | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bench/bench_trsm.c b/bench/bench_trsm.c index b2b7f1af18..7014bd4753 100644 --- a/bench/bench_trsm.c +++ b/bench/bench_trsm.c @@ -62,7 +62,7 @@ int main( int argc, char** argv ) dim_t p_inc = 0; // to keep track of number of inputs num_t dt = BLIS_DOUBLE; dim_t r, n_repeats; - f77_char side; + side_t side; uplo_t uploa; trans_t transa; diag_t diaga; @@ -191,7 +191,7 @@ int main( int argc, char** argv ) #endif dtime = bli_clock(); #ifdef BLIS - bli_trsm( &side, + bli_trsm( side, &alpha, &a, &b ); From 808d79a6108cb6878646e6cac6ed96c0e35c0116 Mon Sep 17 00:00:00 2001 From: Vignesh Balasubramanian Date: Thu, 21 Jul 2022 21:49:46 +0530 Subject: [PATCH 151/243] Implemented efficient ZGEMM algorithm when k=1 Problem statement : To improve the performance of the zgemm kernel for dealing with input sizes with k=1 by fine tuning its previous implementation. In the previous implementation, usage of SIMD parallelism along m and n dimensions instead of the k dimension proved to provide a better performance to the zgemm kernel. 
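For reference, the per-element building block of this scheme can be
sketched as follows (a standalone illustration, not part of this patch; the
helper name cmla2, the two-element width, and the test values are
assumptions made only for the demo). It accumulates c += a*b for two
dcomplex values held in one 256-bit register, using one broadcast-FMA with
the real part of b and one permute/negate/broadcast-FMA with the imaginary
part; build with gcc -mavx2 -mfma:

    #include <stdio.h>
    #include <complex.h>
    #include <immintrin.h>

    /* Standalone illustration, not part of this patch. One 256-bit register
       holds two complex doubles laid out as [R0 I0 R1 I1]. */
    static void cmla2(double complex *c, const double complex *a, double complex b)
    {
        __m256d va  = _mm256_loadu_pd((const double *)a);   /* R0 I0 R1 I1 */
        __m256d vc  = _mm256_loadu_pd((const double *)c);
        __m256d br  = _mm256_set1_pd(creal(b));
        __m256d bi  = _mm256_set1_pd(cimag(b));
        __m256d sgn = _mm256_set_pd(1.0, -1.0, 1.0, -1.0);  /* flips sign in the real lanes */

        vc = _mm256_fmadd_pd(br, va, vc);                   /* += bR*R, bR*I  */
        __m256d sw = _mm256_permute4x64_pd(va, 0xB1);       /* I0 R0 I1 R1    */
        sw = _mm256_mul_pd(sw, sgn);                        /* -I0 R0 -I1 R1  */
        vc = _mm256_fmadd_pd(bi, sw, vc);                   /* += -bI*I, bI*R */

        _mm256_storeu_pd((double *)c, vc);
    }

    int main(void)
    {
        double complex a[2] = { 1.0 + 2.0*I, 3.0 - 1.0*I };
        double complex c[2] = { 0.5 + 0.5*I, -1.0 + 0.0*I };
        double complex b    = 2.0 + 3.0*I;

        cmla2(c, a, b);
        /* Expected: -3.5+7.5i and 8+7i */
        printf("%g%+gi  %g%+gi\n", creal(c[0]), cimag(c[0]),
                                   creal(c[1]), cimag(c[1]));
        return 0;
    }

The kernel below applies the same permute-and-negate step once per column
of B, which is how the SIMD width is spent along m rather than k.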
This code was subjected to further improvements along the following lines: - Cases to deal with alpha=0 and beta!=0 (i.e. just scaling of C) were handled at the beginning separately, using the bli_zscalm api. - Register blocking was further improved, resulting in the kernel size to increase from 4x5 to 4x6. - Prefetching was added to the code, by empirically finding out a suitable value to be added to the pointer. Overall, it provided a mild improvement to the performance. - Conditional statements were removed from the kernel loop, and a logic was deduced to allow such removal without affecting the output. The performance improvement of this single threaded implementation also proved to compete with that of the default implementation for multiple threads, as long as m and n are under 128. An improvement to this patch would be to find out a suitable feature which would establish a relationship between the number of threads and the input size constraints, thereby providing a unique size constraint for different number of threads. AMD-Internal: [CPUPL-2236] Change-Id: I3d401c8fd78bec80ce62eef390fa85e6287df847 --- frame/compat/bla_gemm_amd.c | 17 +- kernels/zen/3/bli_zgemm_ref_k1.c | 1920 +++++++++++++++--------------- 2 files changed, 990 insertions(+), 947 deletions(-) diff --git a/frame/compat/bla_gemm_amd.c b/frame/compat/bla_gemm_amd.c index eed041e4cf..8a8d576c46 100644 --- a/frame/compat/bla_gemm_amd.c +++ b/frame/compat/bla_gemm_amd.c @@ -712,13 +712,20 @@ void zgemm_ //dim_t nt = bli_thread_get_num_threads(); // get number of threads bool nt = bli_thread_get_is_parallel(); // Check if parallel zgemm is invoked. - if((nt==0) && (k0 == 1) && bli_is_notrans(blis_transa) && bli_is_notrans(blis_transb)) + /* + Invoking the API for input sizes with k=1. + - For single thread, the API has no constraints before invoking. + - For multiple threads, the constraint is that m and n should individually be less than 128. 
+ */ + if((k0==1) && ((nt==0) || ((nt==1) && (m0 < 128) && (n0 < 128))) + && bli_is_notrans(blis_transa) + && bli_is_notrans(blis_transb)) { bli_zgemm_ref_k1_nn( m0, n0, k0, - alpha, - a, *lda, - b, *ldb, - beta, + (dcomplex*)alpha, + (dcomplex*)a, *lda, + (dcomplex*)b, *ldb, + (dcomplex*)beta, c, *ldc); AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); diff --git a/kernels/zen/3/bli_zgemm_ref_k1.c b/kernels/zen/3/bli_zgemm_ref_k1.c index 143bf8c84d..47de706238 100644 --- a/kernels/zen/3/bli_zgemm_ref_k1.c +++ b/kernels/zen/3/bli_zgemm_ref_k1.c @@ -35,35 +35,28 @@ #include #include #include "blis.h" - #include "immintrin.h" #define Z_MR 4 -#define Z_NR 5 +#define Z_NR 6 // Macros for the main loop for M -#define SCALE_ALPHA_REAL_M_LOOP(r0,r1,r2,valr) \ - if(valr != 0.0) \ - { \ - r2 = _mm256_broadcast_sd((double const *)(&valr)); \ - r0 = _mm256_mul_pd(r0,r2); \ - r1 = _mm256_mul_pd(r1,r2); \ - } \ - -#define SCALE_ALPHA_IMAG_M_LOOP(r0,r1,r2,r3,rc1,rc2,vali) \ - if(vali != 0.0) \ - { \ - r3 = _mm256_permute4x64_pd(rc1,0b10110001); \ - r2 = _mm256_set_pd(1.0,-1.0,1.0,-1.0); \ - r3 = _mm256_mul_pd(r3, r2); \ - r2 = _mm256_broadcast_sd((double const *)(&vali)); \ - r0 = _mm256_fmadd_pd(r3,r2,r0); \ - r3 = _mm256_permute4x64_pd(rc2,0b10110001); \ - r2 = _mm256_set_pd(1.0,-1.0,1.0,-1.0); \ - r3 = _mm256_mul_pd(r3, r2); \ - r2 = _mm256_broadcast_sd((double const *)(&vali)); \ - r1 = _mm256_fmadd_pd(r3,r2,r1); \ - } \ +#define SCALE_ALPHA_REAL_M_LOOP(rin_0,rin_1,r_bcast,real_val) \ + r_bcast = _mm256_broadcast_sd((double const *)(&real_val)); \ + rin_0 = _mm256_mul_pd(rin_0,r_bcast); \ + rin_1 = _mm256_mul_pd(rin_1,r_bcast); \ + +#define SCALE_ALPHA_IMAG_M_LOOP(rout_0,rout_1,rin_0,rin_1,r_bcast,r_perm,imag_val) \ + r_perm = _mm256_permute4x64_pd(rin_0,0b10110001); \ + r_bcast = _mm256_set_pd(1.0,-1.0,1.0,-1.0); \ + r_perm = _mm256_mul_pd(r_bcast, r_perm); \ + r_bcast = _mm256_broadcast_sd((double const *)(&imag_val)); \ + rout_0 = _mm256_fmadd_pd(r_perm,r_bcast,rout_0); \ + r_perm = _mm256_permute4x64_pd(rin_1,0b10110001); \ + r_bcast = _mm256_set_pd(1.0,-1.0,1.0,-1.0); \ + r_perm = _mm256_mul_pd(r_bcast, r_perm); \ + r_bcast = _mm256_broadcast_sd((double const *)(&imag_val)); \ + rout_1 = _mm256_fmadd_pd(r_perm,r_bcast,rout_1); \ #define NEG_PERM_M_LOOP(r0,r1,r2) \ r0 = _mm256_permute4x64_pd(r0,0b10110001); \ @@ -86,24 +79,17 @@ rout_0 = _mm256_fmadd_pd(rbc, rin_0, rout_0); \ rout_1 = _mm256_fmadd_pd(rbc, rin_1, rout_1); \ - // Macros for fringe cases with M -#define SCALE_ALPHA_REAL_M_FRINGE(r0,r2,val) \ - if(val != 0.0) \ - { \ - r2 = _mm256_broadcast_sd((double const *)(&val)); \ - r0 = _mm256_mul_pd(r0,r2); \ - } \ - -#define SCALE_ALPHA_IMAG_M_FRINGE(r0,r2,r3,r4,val) \ - if(val != 0.0) \ - { \ - r3 = _mm256_permute4x64_pd(r4,0b10110001); \ - r2 = _mm256_set_pd(1.0,-1.0,1.0,-1.0); \ - r3 = _mm256_mul_pd(r3, r2); \ - r2 = _mm256_broadcast_sd((double const *)(&val)); \ - r0 = _mm256_fmadd_pd(r3,r2,r0); \ - } \ +#define SCALE_ALPHA_REAL_M_FRINGE(rin_0,r_bcast,real_val) \ + r_bcast = _mm256_broadcast_sd((double const *)(&real_val)); \ + rin_0 = _mm256_mul_pd(rin_0,r_bcast); \ + +#define SCALE_ALPHA_IMAG_M_FRINGE(rout_0,rin_0,r_bcast,r_perm,imag_val) \ + r_perm = _mm256_permute4x64_pd(rin_0,0b10110001); \ + r_bcast = _mm256_set_pd(1.0,-1.0,1.0,-1.0); \ + r_perm = _mm256_mul_pd(r_bcast, r_perm); \ + r_bcast = _mm256_broadcast_sd((double const *)(&imag_val)); \ + rout_0 = _mm256_fmadd_pd(r_perm,r_bcast,rout_0); \ #define 
NEG_PERM_M_FRINGE(r0,r2) \ r0 = _mm256_permute4x64_pd(r0,0b10110001); \ @@ -121,7 +107,6 @@ NEG_PERM_M_FRINGE(rin_0,rn); \ rout_0 = _mm256_fmadd_pd(rbc, rin_0, rout_0); \ - void bli_zgemm_ref_k1_nn ( dim_t m, @@ -135,19 +120,31 @@ void bli_zgemm_ref_k1_nn ) { - double alpha_valr, beta_valr; - double alpha_vali, beta_vali; - - alpha_valr = alpha->real; - beta_valr = beta->real; - alpha_vali = alpha->imag; - beta_vali = beta->imag; + double alpha_real, beta_real; + double alpha_imag, beta_imag; + + alpha_real = alpha->real; + beta_real = beta->real; + alpha_imag = alpha->imag; + beta_imag = beta->imag; + + /* If m or n is zero, return immediately. */ + if ( bli_zero_dim2( m, n ) ) return; + /* If alpha alone is zero, scale by beta and return. */ + if (bli_zeq0(*(alpha))) + { + bli_zscalm( + BLIS_NO_CONJUGATE, + 0, + BLIS_NONUNIT_DIAG, + BLIS_DENSE, + m, n, + beta, + c, 1, ldc + ); + return; + } - if((m == 0) || (n == 0) || (((alpha_valr == 0.0 && alpha_vali == 0.0) || (k == 0)) - && (beta_valr == 1.0 && beta_vali == 0.0))) - { - return; - } dim_t m_remainder = (m % Z_MR); dim_t n_remainder = (n % Z_NR); @@ -159,12 +156,14 @@ void bli_zgemm_ref_k1_nn __m128d xmm5; /* Form C = alpha*A*B + beta*c */ + // Main loop along N dimension for(dim_t j = 0;j < (n-Z_NR+1);j=j+Z_NR) { dcomplex* temp_b = b + j*ldb; dcomplex* temp_a = a; dcomplex* temp_c = c + j*ldc; + //Main loop along M dimension for(dim_t i = 0;i < (m-Z_MR+1);i=i+Z_MR) { ymm3 = _mm256_setzero_pd(); @@ -178,235 +177,265 @@ void bli_zgemm_ref_k1_nn ymm11 = _mm256_setzero_pd(); ymm12 = _mm256_setzero_pd(); - if(alpha_valr != 0.0 || alpha_vali != 0.0) - { - /* - a. Perform alpha*A*B using temp_a, temp_b and alpha_valr, alpha_vali - where alpha_valr and/or alpha_vali is not zero. - b. This loop operates with 4x5 block size - along n dimension for every Z_NR columns of temp_b where - computing all Z_MR rows of temp_a. - c. Same approach is used in remaining fringe cases. 
- */ - //R(a[0][0]) I(a[0][0]) R(a[1][0]) I(a[1][0]) - ymm0 = _mm256_loadu_pd((double const *)(temp_a)); - //R(a[2][0]) I(a[2][0]) R(a[3][0]) I(a[3][0]) - ymm1 = _mm256_loadu_pd((double const *)(temp_a + 2)); - - ymm13 = ymm0; - ymm14 = ymm1; - - _mm_prefetch((char*)(temp_a) + 64, _MM_HINT_T0); - - SCALE_ALPHA_REAL_M_LOOP(ymm0,ymm1,ymm15,alpha_valr); - SCALE_ALPHA_IMAG_M_LOOP(ymm0,ymm1,ymm15,ymm2,ymm13,ymm14,alpha_vali); - - /* - The result after scaling with alpha_valr and/or alpha_vali is as follows: - For ymm0 : - R(a[0][0]) = alpha_valr*R(a[0][0])-alpha_vali*I(a[0][0]) - I(a[0][0]) = alpha_valr*I(a[0][0])+alpha_vali*R[0][0] - R(a[1][0]) = alpha_valr*R(a[1][0])-alpha_vali*I(a[1][0]) - I(a[1][0]) = alpha_valr*I(a[1][0])+alpha_vali*(R[1][0]) - - For ymm1 : - R(a[2][0]) = alpha_valr*R(a[2][0])-alpha_vali*I(a[2][0]) - I(a[2][0]) = alpha_valr*I(a[2][0])+alpha_vali*R[2][0] - R(a[3][0]) = alpha_valr*R(a[3][0])-alpha_vali*I(a[3][0]) - I(a[3][0]) = alpha_valr*I(a[3][0])+alpha_vali*(R[3][0]) - */ - - //Calculating using real part of complex number in B matrix - //ymm3+=R(b[0][0])*R(a[0][0]) R(b[0][0])*I(a[0][0]) - // R(b[0][0])*R(a[1][0]) R(b[0][0])*I(a[1][0]) - //ymm4+=R(b[0][0])*R(a[2][0]) R(b[0][0])*I(a[2][0]) - // R(b[0][0])*R(a[3][0]) R(b[0][0])*I(a[3][0]) - FMA_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm2,(double const *)(temp_b)); - //ymm5+=R(b[0][1])*R(a[0][0]) R(b[0][1])*I(a[0][0]) - // R(b[0][1])*R(a[1][0]) R(b[0][1])*I(a[1][0]) - //ymm6+=R(b[0][1])*R(a[0][0]) R(b[0][1])*I(a[0][0]) - // R(b[0][1])*R(a[1][0]) R(b[0][1])*I(a[1][0]) - FMA_M_LOOP(ymm0,ymm1,ymm5,ymm6,ymm2,(double const *)(temp_b+ldb)); - //ymm7+=R(b[0][2])*R(a[0][0]) R(b[0][2])*I(a[0][0]) - // R(b[0][2])*R(a[1][0]) R(b[0][2])*I(a[1][0]) - //ymm8+=R(b[0][2])*R(a[0][0]) R(b[0][2])*I(a[0][0]) - // R(b[0][2])*R(a[1][0]) R(b[0][2])*I(a[1][0]) - FMA_M_LOOP(ymm0,ymm1,ymm7,ymm8,ymm2,(double const *)(temp_b+ldb*2)); - //ymm9+=R(b[0][3])*R(a[0][0]) R(b[0][3])*I(a[0][0]) - // R(b[0][3])*R(a[1][0]) R(b[0][3])*I(a[1][0]) - //ymm10+=R(b[0][3])*R(a[0][0]) R(b[0][3])*I(a[0][0]) - // R(b[0][3])*R(a[1][0]) R(b[0][3])*I(a[1][0]) - FMA_M_LOOP(ymm0,ymm1,ymm9,ymm10,ymm2,(double const *)(temp_b+ldb*3)); - //ymm11+=R(b[0][4])*R(a[0][0]) R(b[0][4])*I(a[0][0]) - // R(b[0][4])*R(a[1][0]) R(b[0][4])*I(a[1][0]) - //ymm12+=R(b[0][4])*R(a[0][0]) R(b[0][4])*I(a[0][0]) - // R(b[0][4])*R(a[1][0]) R(b[0][4])*I(a[1][0]) - FMA_M_LOOP(ymm0,ymm1,ymm11,ymm12,ymm2,(double const *)(temp_b+ldb*4)); - - //Calculating using imaginary part of complex numbers in B matrix - //Shuffling ymm0 and ymm1 in accordance to the requirement - NEG_PERM_M_LOOP(ymm0,ymm1,ymm2); - //ymm3+=I(b[0][0])*R(a[0][0]) I(b[0][0])*I(a[0][0]) - // I(b[0][0])*R(a[1][0]) I(b[0][0])*I(a[1][0]) - //ymm4+=R(b[0][0])*R(a[2][0]) I(b[0][0])*I(a[2][0]) - // I(b[0][0])*R(a[3][0]) I(b[0][0])*I(a[3][0]) - FMA_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm2,(double const *)(temp_b)+1); - //ymm5+=I(b[0][1])*R(a[0][0]) I(b[0][1])*I(a[0][0]) - // I(b[0][1])*R(a[1][0]) I(b[0][1])*I(a[1][0]) - //ymm6+=R(b[0][1])*R(a[0][0]) I(b[0][1])*I(a[0][0]) - // I(b[0][1])*R(a[1][0]) I(b[0][1])*I(a[1][0]) - FMA_M_LOOP(ymm0,ymm1,ymm5,ymm6,ymm2,(double const *)(temp_b+ldb)+1); - //ymm7+=I(b[0][2])*R(a[0][0]) I(b[0][2])*I(a[0][0]) - // I(b[0][2])*R(a[1][0]) I(b[0][2])*I(a[1][0]) - //ymm8+=I(b[0][2])*R(a[0][0]) I(b[0][2])*I(a[0][0]) - // I(b[0][2])*R(a[1][0]) I(b[0][2])*I(a[1][0]) - FMA_M_LOOP(ymm0,ymm1,ymm7,ymm8,ymm2,(double const *)(temp_b+ldb*2)+1); - //ymm9+=I(b[0][3])*R(a[0][0]) I(b[0][3])*I(a[0][0]) - // I(b[0][3])*R(a[1][0]) I(b[0][3])*I(a[1][0]) - 
//ymm10+=I(b[0][3])*R(a[0][0]) I(b[0][3])*I(a[0][0]) - // I(b[0][3])*R(a[1][0]) I(b[0][3])*I(a[1][0]) - FMA_M_LOOP(ymm0,ymm1,ymm9,ymm10,ymm2,(double const *)(temp_b+ldb*3)+1); - //ymm11+=I(b[0][4])*R(a[0][0]) I(b[0][4])*I(a[0][0]) - // I(b[0][4])*R(a[1][0]) I(b[0][4])*I(a[1][0]) - //ymm12+=I(b[0][4])*R(a[0][0]) I(b[0][4])*I(a[0][0]) - // I(b[0][4])*R(a[1][0]) I(b[0][4])*I(a[1][0]) - FMA_M_LOOP(ymm0,ymm1,ymm11,ymm12,ymm2,(double const *)(temp_b+ldb*4)+1); - } - if(beta_valr != 0.0) + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_real, alpha_imag + where alpha_real and/or alpha_imag is not zero. + b. This loop operates with 4x6 block size + along n dimension for every Z_NR columns of temp_b where + computing all Z_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. + */ + //R(a[0][0]) I(a[0][0]) R(a[1][0]) I(a[1][0]) + ymm0 = _mm256_loadu_pd((double const *)(temp_a)); + //R(a[2][0]) I(a[2][0]) R(a[3][0]) I(a[3][0]) + ymm1 = _mm256_loadu_pd((double const *)(temp_a + 2)); + + ymm13 = ymm0; + ymm14 = ymm1; + _mm_prefetch((char*)(temp_a + 32), _MM_HINT_T0); + + SCALE_ALPHA_REAL_M_LOOP(ymm0,ymm1,ymm15,alpha_real); + SCALE_ALPHA_IMAG_M_LOOP(ymm0,ymm1,ymm13,ymm14,ymm15,ymm2,alpha_imag); + + ymm13 = _mm256_setzero_pd(); + ymm14 = _mm256_setzero_pd(); + + /* + The result after scaling with alpha_real and/or alpha_imag is as follows: + For ymm0 : + R(a[0][0]) = alpha_real*R(a[0][0])-alpha_imag*I(a[0][0]) + I(a[0][0]) = alpha_real*I(a[0][0])+alpha_imag*R[0][0] + R(a[1][0]) = alpha_real*R(a[1][0])-alpha_imag*I(a[1][0]) + I(a[1][0]) = alpha_real*I(a[1][0])+alpha_imag*(R[1][0]) + + For ymm1 : + R(a[2][0]) = alpha_real*R(a[2][0])-alpha_imag*I(a[2][0]) + I(a[2][0]) = alpha_real*I(a[2][0])+alpha_imag*R[2][0] + R(a[3][0]) = alpha_real*R(a[3][0])-alpha_imag*I(a[3][0]) + I(a[3][0]) = alpha_real*I(a[3][0])+alpha_imag*(R[3][0]) + */ + + //Calculating using real part of complex number in B matrix + //ymm3+=R(b[0][0])*R(a[0][0]) R(b[0][0])*I(a[0][0]) + // R(b[0][0])*R(a[1][0]) R(b[0][0])*I(a[1][0]) + //ymm4+=R(b[0][0])*R(a[2][0]) R(b[0][0])*I(a[2][0]) + // R(b[0][0])*R(a[3][0]) R(b[0][0])*I(a[3][0]) + FMA_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm2,(double const *)(temp_b)); + //ymm5+=R(b[0][1])*R(a[0][0]) R(b[0][1])*I(a[0][0]) + // R(b[0][1])*R(a[1][0]) R(b[0][1])*I(a[1][0]) + //ymm6+=R(b[0][1])*R(a[0][0]) R(b[0][1])*I(a[0][0]) + // R(b[0][1])*R(a[1][0]) R(b[0][1])*I(a[1][0]) + FMA_M_LOOP(ymm0,ymm1,ymm5,ymm6,ymm2,(double const *)(temp_b+ldb)); + //ymm7+=R(b[0][2])*R(a[0][0]) R(b[0][2])*I(a[0][0]) + // R(b[0][2])*R(a[1][0]) R(b[0][2])*I(a[1][0]) + //ymm8+=R(b[0][2])*R(a[0][0]) R(b[0][2])*I(a[0][0]) + // R(b[0][2])*R(a[1][0]) R(b[0][2])*I(a[1][0]) + FMA_M_LOOP(ymm0,ymm1,ymm7,ymm8,ymm2,(double const *)(temp_b+ldb*2)); + //ymm9+=R(b[0][3])*R(a[0][0]) R(b[0][3])*I(a[0][0]) + // R(b[0][3])*R(a[1][0]) R(b[0][3])*I(a[1][0]) + //ymm10+=R(b[0][3])*R(a[0][0]) R(b[0][3])*I(a[0][0]) + // R(b[0][3])*R(a[1][0]) R(b[0][3])*I(a[1][0]) + FMA_M_LOOP(ymm0,ymm1,ymm9,ymm10,ymm2,(double const *)(temp_b+ldb*3)); + //ymm11+=R(b[0][4])*R(a[0][0]) R(b[0][4])*I(a[0][0]) + // R(b[0][4])*R(a[1][0]) R(b[0][4])*I(a[1][0]) + //ymm12+=R(b[0][4])*R(a[0][0]) R(b[0][4])*I(a[0][0]) + // R(b[0][4])*R(a[1][0]) R(b[0][4])*I(a[1][0]) + FMA_M_LOOP(ymm0,ymm1,ymm11,ymm12,ymm2,(double const *)(temp_b+ldb*4)); + //ymm11+=R(b[0][5])*R(a[0][0]) R(b[0][5])*I(a[0][0]) + // R(b[0][5])*R(a[1][0]) R(b[0][5])*I(a[1][0]) + //ymm12+=R(b[0][5])*R(a[0][0]) R(b[0][5])*I(a[0][0]) + // R(b[0][5])*R(a[1][0]) R(b[0][5])*I(a[1][0]) + 
FMA_M_LOOP(ymm0,ymm1,ymm13,ymm14,ymm2,(double const *)(temp_b+ldb*5)); + + //Calculating using imaginary part of complex numbers in B matrix + //Shuffling ymm0 and ymm1 in accordance to the requirement + NEG_PERM_M_LOOP(ymm0,ymm1,ymm2); + //ymm3+=I(b[0][0])*R(a[0][0]) I(b[0][0])*I(a[0][0]) + // I(b[0][0])*R(a[1][0]) I(b[0][0])*I(a[1][0]) + //ymm4+=R(b[0][0])*R(a[2][0]) I(b[0][0])*I(a[2][0]) + // I(b[0][0])*R(a[3][0]) I(b[0][0])*I(a[3][0]) + FMA_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm2,(double const *)(temp_b)+1); + //ymm5+=I(b[0][1])*R(a[0][0]) I(b[0][1])*I(a[0][0]) + // I(b[0][1])*R(a[1][0]) I(b[0][1])*I(a[1][0]) + //ymm6+=R(b[0][1])*R(a[0][0]) I(b[0][1])*I(a[0][0]) + // I(b[0][1])*R(a[1][0]) I(b[0][1])*I(a[1][0]) + FMA_M_LOOP(ymm0,ymm1,ymm5,ymm6,ymm2,(double const *)(temp_b+ldb)+1); + //ymm7+=I(b[0][2])*R(a[0][0]) I(b[0][2])*I(a[0][0]) + // I(b[0][2])*R(a[1][0]) I(b[0][2])*I(a[1][0]) + //ymm8+=I(b[0][2])*R(a[0][0]) I(b[0][2])*I(a[0][0]) + // I(b[0][2])*R(a[1][0]) I(b[0][2])*I(a[1][0]) + FMA_M_LOOP(ymm0,ymm1,ymm7,ymm8,ymm2,(double const *)(temp_b+ldb*2)+1); + //ymm9+=I(b[0][3])*R(a[0][0]) I(b[0][3])*I(a[0][0]) + // I(b[0][3])*R(a[1][0]) I(b[0][3])*I(a[1][0]) + //ymm10+=I(b[0][3])*R(a[0][0]) I(b[0][3])*I(a[0][0]) + // I(b[0][3])*R(a[1][0]) I(b[0][3])*I(a[1][0]) + FMA_M_LOOP(ymm0,ymm1,ymm9,ymm10,ymm2,(double const *)(temp_b+ldb*3)+1); + //ymm11+=I(b[0][4])*R(a[0][0]) I(b[0][4])*I(a[0][0]) + // I(b[0][4])*R(a[1][0]) I(b[0][4])*I(a[1][0]) + //ymm12+=I(b[0][4])*R(a[0][0]) I(b[0][4])*I(a[0][0]) + // I(b[0][4])*R(a[1][0]) I(b[0][4])*I(a[1][0]) + FMA_M_LOOP(ymm0,ymm1,ymm11,ymm12,ymm2,(double const *)(temp_b+ldb*4)+1); + //ymm13+=I(b[0][5])*R(a[0][0]) I(b[0][5])*I(a[0][0]) + // I(b[0][5])*R(a[1][0]) I(b[0][5])*I(a[1][0]) + //ymm14+=I(b[0][5])*R(a[0][0]) I(b[0][5])*I(a[0][0]) + // I(b[0][5])*R(a[1][0]) I(b[0][5])*I(a[1][0]) + FMA_M_LOOP(ymm0,ymm1,ymm13,ymm14,ymm2,(double const *)(temp_b+ldb*5)+1); + + /* + a. Perform beta*C using temp_c, beta_real, + where beta_real is not zero. + b. This loop operates with 4x6 block size + along n dimension for every Z_NR columns of temp_c where + computing all Z_MR rows of temp_c. + c. Accumulated alpha*A*B into registers will be added to beta*C + d. Same approach is used in remaining fringe cases. + */ + if(beta_real != 0.0) { - /* - a. Perform beta*C using temp_c, beta_valr, - where beta_valr is not zero. - b. This loop operates with 4x5 block size - along n dimension for every Z_NR columns of temp_c where - computing all Z_MR rows of temp_c. - c. Accumulated alpha*A*B into registers will be added to beta*C - d. Same approach is used in remaining fringe cases. 
- */ - ymm15 = _mm256_broadcast_sd((double const *)(&beta_valr)); + ymm15 = _mm256_broadcast_sd((double const *)(&beta_real)); //R(c[0][0]) I(c[0][0]) R(c[1][0]) I(c[1][0]) ymm0 = _mm256_loadu_pd((double const *)(temp_c)); //R(c[2][0]) I(c[2][0]) R(c[3][0]) I(c[3][0]) ymm1 = _mm256_loadu_pd((double const *)(temp_c + 2)); - //ymm3+=beta_valr*R(c[0][0]) beta_valr*I(c[0][0]) - // beta_valr*R(c[1][0]) beta_valr*I(c[1][0]) - //ymm4+=beta_valr*R(c[2][0]) beta_valr*I(c[2][0]) - // beta_valr*R(c[3][0]) beta_valr*I(c[3][0]) + //ymm3+=beta_real*R(c[0][0]) beta_real*I(c[0][0]) + // beta_real*R(c[1][0]) beta_real*I(c[1][0]) + //ymm4+=beta_real*R(c[2][0]) beta_real*I(c[2][0]) + // beta_real*R(c[3][0]) beta_real*I(c[3][0]) SCALE_BETA_REAL_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm15); //R(c[0][1]) I(c[0][1]) R(c[1][1]) I(c[1][1]) ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc)); //R(c[2][1]) I(c[2][1]) R(c[3][1]) I(c[3][1]) ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc + 2)); - //ymm5+=beta_valr*R(c[0][1]) beta_valr*I(c[0][1]) - // beta_valr*R(c[1][1]) beta_valr*I(c[1][1]) - //ymm6+=beta_valr*R(c[2][1]) beta_valr*I(c[2][1]) - // beta_valr*R(c[3][1]) beta_valr*I(c[3][1]) + //ymm5+=beta_real*R(c[0][1]) beta_real*I(c[0][1]) + // beta_real*R(c[1][1]) beta_real*I(c[1][1]) + //ymm6+=beta_real*R(c[2][1]) beta_real*I(c[2][1]) + // beta_real*R(c[3][1]) beta_real*I(c[3][1]) SCALE_BETA_REAL_M_LOOP(ymm0,ymm1,ymm5,ymm6,ymm15); //R(c[0][2]) I(c[0][2]) R(c[1][2]) I(c[1][2]) ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*2)); //R(c[2][2]) I(c[2][2]) R(c[3][2]) I(c[3][2]) ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc*2 + 2)); - //ymm7+=beta_valr*R(c[0][2]) beta_valr*I(c[0][2]) - // beta_valr*R(c[1][2]) beta_valr*I(c[1][2]) - //ymm8+=beta_valr*R(c[2][2]) beta_valr*I(c[2][2]) - //beta_valr*R(c[3][2]) beta_valr*I(c[3][2]) + //ymm7+=beta_real*R(c[0][2]) beta_real*I(c[0][2]) + // beta_real*R(c[1][2]) beta_real*I(c[1][2]) + //ymm8+=beta_real*R(c[2][2]) beta_real*I(c[2][2]) + //beta_real*R(c[3][2]) beta_real*I(c[3][2]) SCALE_BETA_REAL_M_LOOP(ymm0,ymm1,ymm7,ymm8,ymm15); //R(c[0][3]) I(c[0][3]) R(c[1][3]) I(c[1][3]) ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*3)); //R(c[2][3]) I(c[2][3]) R(c[3][3]) I(c[3][3]) ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc*3 + 2)); - //ymm9+=beta_valr*R(c[0][3]) beta_valr*I(c[0][3]) - // beta_valr*R(c[1][3]) beta_valr*I(c[1][3]) - //ymm10+=beta_valr*R(c[2][3]) beta_valr*I(c[2][3]) - // beta_valr*R(c[3][3]) beta_valr*I(c[3][3]) + //ymm9+=beta_real*R(c[0][3]) beta_real*I(c[0][3]) + // beta_real*R(c[1][3]) beta_real*I(c[1][3]) + //ymm10+=beta_real*R(c[2][3]) beta_real*I(c[2][3]) + // beta_real*R(c[3][3]) beta_real*I(c[3][3]) SCALE_BETA_REAL_M_LOOP(ymm0,ymm1,ymm9,ymm10,ymm15); //R(c[0][4]) I(c[0][4]) R(c[1][4]) I(c[1][4]) ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*4)); //R(c[2][4]) I(c[2][4]) R(c[3][4]) I(c[3][4]) ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc*4 + 2)); - //ymm11+=beta_valr*R(c[0][4]) beta_valr*I(c[0][4]) - // beta_valr*R(c[1][4]) beta_valr*I(c[1][4]) - //ymm12+=beta_valr*R(c[2][4]) beta_valr*I(c[2][4]) - // beta_valr*R(c[3][4]) beta_valr*I(c[3][4]) + //ymm11+=beta_real*R(c[0][4]) beta_real*I(c[0][4]) + // beta_real*R(c[1][4]) beta_real*I(c[1][4]) + //ymm12+=beta_real*R(c[2][4]) beta_real*I(c[2][4]) + // beta_real*R(c[3][4]) beta_real*I(c[3][4]) SCALE_BETA_REAL_M_LOOP(ymm0,ymm1,ymm11,ymm12,ymm15); + //R(c[0][5]) I(c[0][5]) R(c[1][5]) I(c[1][5]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*5)); + //R(c[2][5]) I(c[2][5]) R(c[3][5]) 
I(c[3][5]) + ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc*5 + 2)); + //ymm13+=beta_real*R(c[0][5]) beta_real*I(c[0][5]) + // beta_real*R(c[1][5]) beta_real*I(c[1][5]) + //ymm14+=beta_real*R(c[2][5]) beta_real*I(c[2][5]) + // beta_real*R(c[3][5]) beta_real*I(c[3][5]) + SCALE_BETA_REAL_M_LOOP(ymm0,ymm1,ymm13,ymm14,ymm15); } - if(beta_vali != 0.0) + + /* + a. Perform beta*C using temp_c, beta_imag, + where beta_imag is not zero. + b. This loop operates with 4x6 block size + along n dimension for every Z_NR columns of temp_c where + computing all Z_MR rows of temp_c. + c. Accumulated alpha*A*B into registers will be added to beta*C + d. Same approach is used in remaining fringe cases. + */ + + if(beta_imag != 0.0) { - /* - a. Perform beta*C using temp_c, beta_vali, - where beta_vali is not zero. - b. This loop operates with 4x5 block size - along n dimension for every Z_NR columns of temp_c where - computing all Z_MR rows of temp_c. - c. Accumulated alpha*A*B into registers will be added to beta*C - d. Same approach is used in remaining fringe cases. - */ - - ymm15 = _mm256_broadcast_sd((double const *)(&beta_vali)); + ymm15 = _mm256_broadcast_sd((double const *)(&beta_imag)); //R(c[0][0]) I(c[0][0]) R(c[1][0]) I(c[1][0]) ymm0 = _mm256_loadu_pd((double const *)(temp_c)); //R(c[2][0]) I(c[2][0]) R(c[3][0]) I(c[3][0]) ymm1 = _mm256_loadu_pd((double const *)(temp_c + 2)); - //ymm3+=beta_vali*(-I(c[0][0])) beta_vali*R(c[0][0]) - // beta_vali*(-I(c[1][0])) beta_vali*R(c[1][0]) - //ymm4+=beta_vali*(-I(c[2][0])) beta_vali*R(c[2][0]) - // beta_vali*(-I(c[3][0])) beta_vali*R(c[3][0]) + //ymm3+=beta_imag*(-I(c[0][0])) beta_imag*R(c[0][0]) + // beta_imag*(-I(c[1][0])) beta_imag*R(c[1][0]) + //ymm4+=beta_imag*(-I(c[2][0])) beta_imag*R(c[2][0]) + // beta_imag*(-I(c[3][0])) beta_imag*R(c[3][0]) SCALE_BETA_IMAG_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm15,ymm2); //R(c[0][1]) I(c[0][1]) R(c[1][1]) I(c[1][1]) ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc)); //R(c[2][1]) I(c[2][1]) R(c[3][1]) I(c[3][1]) ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc + 2)); - //ymm5+=beta_vali*(-I(c[0][1])) beta_vali*R(c[0][1]) - // beta_vali*(-I(c[1][1])) beta_vali*R(c[1][1]) - //ymm6+=beta_vali*(-I(c[2][1])) beta_vali*R(c[2][1]) - // beta_vali*(-I(c[3][1])) beta_vali*R(c[3][1]) + //ymm5+=beta_imag*(-I(c[0][1])) beta_imag*R(c[0][1]) + // beta_imag*(-I(c[1][1])) beta_imag*R(c[1][1]) + //ymm6+=beta_imag*(-I(c[2][1])) beta_imag*R(c[2][1]) + // beta_imag*(-I(c[3][1])) beta_imag*R(c[3][1]) SCALE_BETA_IMAG_M_LOOP(ymm0,ymm1,ymm5,ymm6,ymm15,ymm2); //R(c[0][2]) I(c[0][2]) R(c[1][2]) I(c[1][2]) ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*2)); //R(c[2][2]) I(c[2][2]) R(c[3][2]) I(c[3][2]) ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc*2 + 2)); - //ymm7+=beta_vali*(-I(c[0][2])) beta_vali*R(c[0][2]) - // beta_vali*(-I(c[1][2])) beta_vali*R(c[1][2]) - //ymm8+=beta_vali*(-I(c[2][2])) beta_vali*R(c[2][2]) - // beta_vali*(-I(c[3][2])) beta_vali*R(c[3][2]) + //ymm7+=beta_imag*(-I(c[0][2])) beta_imag*R(c[0][2]) + // beta_imag*(-I(c[1][2])) beta_imag*R(c[1][2]) + //ymm8+=beta_imag*(-I(c[2][2])) beta_imag*R(c[2][2]) + // beta_imag*(-I(c[3][2])) beta_imag*R(c[3][2]) SCALE_BETA_IMAG_M_LOOP(ymm0,ymm1,ymm7,ymm8,ymm15,ymm2); //R(c[0][3]) I(c[0][3]) R(c[1][3]) I(c[1][3]) ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*3)); //R(c[2][3]) I(c[2][3]) R(c[3][3]) I(c[3][3]) ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc*3 + 2)); - //ymm9+=beta_vali*(-I(c[0][3])) beta_vali*R(c[0][3]) - // beta_vali*(-I(c[1][3])) 
beta_vali*R(c[1][3]) - //ymm10+=beta_vali*(-I(c[2][3])) beta_vali*R(c[2][3]) - // beta_vali*(-I(c[3][3])) beta_vali*R(c[3][3]) + //ymm9+=beta_imag*(-I(c[0][3])) beta_imag*R(c[0][3]) + // beta_imag*(-I(c[1][3])) beta_imag*R(c[1][3]) + //ymm10+=beta_imag*(-I(c[2][3])) beta_imag*R(c[2][3]) + // beta_imag*(-I(c[3][3])) beta_imag*R(c[3][3]) SCALE_BETA_IMAG_M_LOOP(ymm0,ymm1,ymm9,ymm10,ymm15,ymm2); //R(c[0][4]) I(c[0][4]) R(c[1][4]) I(c[1][4]) ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*4)); //R(c[2][4]) I(c[2][4]) R(c[3][4]) I(c[3][4]) ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc*4 + 2)); - //ymm11+=beta_vali*(-I(c[0][4])) beta_vali*R(c[0][4]) - // beta_vali*(-I(c[1][4])) beta_vali*R(c[1][4]) - //ymm12+=beta_vali*(-I(c[2][4])) beta_vali*R(c[2][4]) - // beta_vali*(-I(c[3][4])) beta_vali*R(c[3][4]) + //ymm11+=beta_imag*(-I(c[0][4])) beta_imag*R(c[0][4]) + // beta_imag*(-I(c[1][4])) beta_imag*R(c[1][4]) + //ymm12+=beta_imag*(-I(c[2][4])) beta_imag*R(c[2][4]) + // beta_imag*(-I(c[3][4])) beta_imag*R(c[3][4]) SCALE_BETA_IMAG_M_LOOP(ymm0,ymm1,ymm11,ymm12,ymm15,ymm2); + + //R(c[0][5]) I(c[0][5]) R(c[1][5]) I(c[1][5]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*5)); + //R(c[2][5]) I(c[2][5]) R(c[3][5]) I(c[3][5]) + ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc*5 + 2)); + //ymm13+=beta_imag*(-I(c[0][5])) beta_imag*R(c[0][5]) + // beta_imag*(-I(c[1][5])) beta_imag*R(c[1][5]) + //ymm14+=beta_imag*(-I(c[2][5])) beta_imag*R(c[2][5]) + // beta_imag*(-I(c[3][5])) beta_imag*R(c[3][5]) + SCALE_BETA_IMAG_M_LOOP(ymm0,ymm1,ymm13,ymm14,ymm15,ymm2); } /* The scaling has been done sequentially as follows: - - If alpha_valr is not 0, it is used for scaling A - - If alpha_vali is not 0, it is used for scaling A using permutation + - If alpha_real is not 0, it is used for scaling A + - If alpha_imag is not 0, it is used for scaling A using permutation and selective negation, after loading - - If beta_valr is not 0, is is used for scaling C - - If beta_vali is not 0, it is used for scaling C using permutation + - If beta_real is not 0, is is used for scaling C + - If beta_imag is not 0, it is used for scaling C using permutation and selective negation, after loading The results are accumalated in accordance to the non zero scalar values, @@ -428,10 +457,14 @@ void bli_zgemm_ref_k1_nn _mm256_storeu_pd((double *)(temp_c + ldc*4), ymm11); _mm256_storeu_pd((double *)(temp_c + ldc*4 + 2), ymm12); + _mm256_storeu_pd((double *)(temp_c + ldc*5), ymm13); + _mm256_storeu_pd((double *)(temp_c + ldc*5 + 2), ymm14); + temp_c+=Z_MR; temp_a+=Z_MR; } + // Fringe cases for M dim_t m_rem=m_remainder; if(m_rem>=2) { @@ -441,146 +474,161 @@ void bli_zgemm_ref_k1_nn ymm9 = _mm256_setzero_pd(); ymm11 = _mm256_setzero_pd(); - if(alpha_valr != 0.0 || alpha_vali != 0.0) - { + //R(a[0][0]) I(a[0][0]) R(a[1][0]) I(a[1][0]) + ymm0 = _mm256_loadu_pd((double const *)(temp_a)); - //R(a[0][0]) I(a[0][0]) R(a[1][0]) I(a[1][0]) - ymm0 = _mm256_loadu_pd((double const *)(temp_a)); - ymm13 = ymm0; - SCALE_ALPHA_REAL_M_FRINGE(ymm0,ymm15,alpha_valr); - SCALE_ALPHA_IMAG_M_FRINGE(ymm0,ymm15,ymm2,ymm13,alpha_vali); - /* - The result after scaling with alpha_valr and/or alpha_vali is as follows: - For ymm0 : - R(a[0][0]) = alpha_valr*R(a[0][0])-alpha_vali*I(a[0][0]) - I(a[0][0]) = alpha_valr*I(a[0][0])+alpha_vali*R[0][0] - R(a[1][0]) = alpha_valr*R(a[1][0])-alpha_vali*I(a[1][0]) - I(a[1][0]) = alpha_valr*I(a[1][0])+alpha_vali*(R[1][0]) - */ - - //Calculating using real part of complex number in B matrix - 
//ymm3+=R(b[0][0])*R(a[0][0]) R(b[0][0])*I(a[0][0]) - // R(b[0][0])*R(a[1][0]) R(b[0][0])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)); - //ymm5+=R(b[0][1])*R(a[0][0]) R(b[0][1])*I(a[0][0]) - // R(b[0][1])*R(a[1][0]) R(b[0][1])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm5,ymm2,(double const *)(temp_b+ldb)); - //ymm7+=R(b[0][2])*R(a[0][0]) R(b[0][2])*I(a[0][0]) - // R(b[0][2])*R(a[1][0]) R(b[0][2])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm7,ymm2,(double const *)(temp_b+ldb*2)); - //ymm9+=R(b[0][3])*R(a[0][0]) R(b[0][3])*I(a[0][0]) - // R(b[0][3])*R(a[1][0]) R(b[0][3])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm9,ymm2,(double const *)(temp_b+ldb*3)); - //ymm11+=R(b[0][4])*R(a[0][0]) R(b[0][4])*I(a[0][0]) - // R(b[0][4])*R(a[1][0]) R(b[0][4])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm11,ymm2,(double const *)(temp_b+ldb*4)); - - //Calculating using imaginary part of complex numbers in B matrix - //Shuffling ymm0 in accordance to the requirement - NEG_PERM_M_FRINGE(ymm0,ymm2); - - // ymm3+=I(b[0][0])*R(a[0][0]) I(b[0][0])*I(a[0][0]) - // I(b[0][0])*R(a[1][0]) I(b[0][0])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)+1); - //ymm5+=I(b[0][1])*R(a[0][0]) I(b[0][1])*I(a[0][0]) - // I(b[0][1])*R(a[1][0]) I(b[0][1])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm5,ymm2,(double const *)(temp_b+ldb)+1); - //ymm7+=I(b[0][2])*R(a[0][0]) I(b[0][2])*I(a[0][0]) - // I(b[0][2])*R(a[1][0]) I(b[0][2])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm7,ymm2,(double const *)(temp_b+ldb*2)+1); - //ymm9+=I(b[0][3])*R(a[0][0]) I(b[0][3])*I(a[0][0]) - // I(b[0][3])*R(a[1][0]) I(b[0][3])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm9,ymm2,(double const *)(temp_b+ldb*3)+1); - //ymm11+=I(b[0][4])*R(a[0][0]) I(b[0][4])*I(a[0][0]) - // I(b[0][4])*R(a[1][0]) I(b[0][4])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm11,ymm2,(double const *)(temp_b+ldb*4)+1); + ymm13 = ymm0; + SCALE_ALPHA_REAL_M_FRINGE(ymm0,ymm15,alpha_real); + SCALE_ALPHA_IMAG_M_FRINGE(ymm0,ymm13,ymm15,ymm2,alpha_imag); - } + ymm13 = _mm256_setzero_pd(); - if(beta_valr != 0.0) - { + /* + The result after scaling with alpha_real and/or alpha_imag is as follows: + For ymm0 : + R(a[0][0]) = alpha_real*R(a[0][0])-alpha_imag*I(a[0][0]) + I(a[0][0]) = alpha_real*I(a[0][0])+alpha_imag*R[0][0] + R(a[1][0]) = alpha_real*R(a[1][0])-alpha_imag*I(a[1][0]) + I(a[1][0]) = alpha_real*I(a[1][0])+alpha_imag*(R[1][0]) + */ - ymm15 = _mm256_broadcast_sd((double const *)(&beta_valr)); + //Calculating using real part of complex number in B matrix + //ymm3+=R(b[0][0])*R(a[0][0]) R(b[0][0])*I(a[0][0]) + // R(b[0][0])*R(a[1][0]) R(b[0][0])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)); + //ymm5+=R(b[0][1])*R(a[0][0]) R(b[0][1])*I(a[0][0]) + // R(b[0][1])*R(a[1][0]) R(b[0][1])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm5,ymm2,(double const *)(temp_b+ldb)); + //ymm7+=R(b[0][2])*R(a[0][0]) R(b[0][2])*I(a[0][0]) + // R(b[0][2])*R(a[1][0]) R(b[0][2])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm7,ymm2,(double const *)(temp_b+ldb*2)); + //ymm9+=R(b[0][3])*R(a[0][0]) R(b[0][3])*I(a[0][0]) + // R(b[0][3])*R(a[1][0]) R(b[0][3])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm9,ymm2,(double const *)(temp_b+ldb*3)); + //ymm11+=R(b[0][4])*R(a[0][0]) R(b[0][4])*I(a[0][0]) + // R(b[0][4])*R(a[1][0]) R(b[0][4])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm11,ymm2,(double const *)(temp_b+ldb*4)); + //ymm13+=R(b[0][5])*R(a[0][0]) R(b[0][5])*I(a[0][0]) + // R(b[0][5])*R(a[1][0]) R(b[0][5])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm13,ymm2,(double const *)(temp_b+ldb*5)); + + //Calculating using imaginary part of complex numbers in B matrix + //Shuffling ymm0 in accordance 
to the requirement + NEG_PERM_M_FRINGE(ymm0,ymm2); + + // ymm3+=I(b[0][0])*R(a[0][0]) I(b[0][0])*I(a[0][0]) + // I(b[0][0])*R(a[1][0]) I(b[0][0])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)+1); + //ymm5+=I(b[0][1])*R(a[0][0]) I(b[0][1])*I(a[0][0]) + // I(b[0][1])*R(a[1][0]) I(b[0][1])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm5,ymm2,(double const *)(temp_b+ldb)+1); + //ymm7+=I(b[0][2])*R(a[0][0]) I(b[0][2])*I(a[0][0]) + // I(b[0][2])*R(a[1][0]) I(b[0][2])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm7,ymm2,(double const *)(temp_b+ldb*2)+1); + //ymm9+=I(b[0][3])*R(a[0][0]) I(b[0][3])*I(a[0][0]) + // I(b[0][3])*R(a[1][0]) I(b[0][3])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm9,ymm2,(double const *)(temp_b+ldb*3)+1); + //ymm11+=I(b[0][4])*R(a[0][0]) I(b[0][4])*I(a[0][0]) + // I(b[0][4])*R(a[1][0]) I(b[0][4])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm11,ymm2,(double const *)(temp_b+ldb*4)+1); + //ymm13+=I(b[0][5])*R(a[0][0]) I(b[0][5])*I(a[0][0]) + // I(b[0][5])*R(a[1][0]) I(b[0][5])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm13,ymm2,(double const *)(temp_b+ldb*5)+1); + + + if(beta_real != 0.0) + { + ymm15 = _mm256_broadcast_sd((double const *)(&beta_real)); //R(c[0][0]) I(c[0][0]) R(c[1][0]) I(c[1][0]) ymm0 = _mm256_loadu_pd((double const *)(temp_c)); - //ymm3+=beta_valr*R(c[0][0]) beta_valr*I(c[0][0]) - // beta_valr*R(c[1][0]) beta_valr*I(c[1][0]) + //ymm3+=beta_real*R(c[0][0]) beta_real*I(c[0][0]) + // beta_real*R(c[1][0]) beta_real*I(c[1][0]) SCALE_BETA_REAL_M_FRINGE(ymm0,ymm3,ymm15); //R(c[0][1]) I(c[0][1]) R(c[1][1]) I(c[1][1]) ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc)); - //ymm5+=beta_valr*R(c[0][1]) beta_valr*I(c[0][1]) - // beta_valr*R(c[1][1]) beta_valr*I(c[1][1]) + //ymm5+=beta_real*R(c[0][1]) beta_real*I(c[0][1]) + // beta_real*R(c[1][1]) beta_real*I(c[1][1]) SCALE_BETA_REAL_M_FRINGE(ymm0,ymm5,ymm15); //R(c[0][2]) I(c[0][2]) R(c[1][2]) I(c[1][2]) ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*2)); - //ymm7+=beta_valr*R(c[0][2]) beta_valr*I(c[0][2]) - // beta_valr*R(c[1][2]) beta_valr*I(c[1][2]) + //ymm7+=beta_real*R(c[0][2]) beta_real*I(c[0][2]) + // beta_real*R(c[1][2]) beta_real*I(c[1][2]) SCALE_BETA_REAL_M_FRINGE(ymm0,ymm7,ymm15); //R(c[0][3]) I(c[0][3]) R(c[1][3]) I(c[1][3]) ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*3)); - //ymm9+=beta_valr*R(c[0][3]) beta_valr*I(c[0][3]) - // beta_valr*R(c[1][3]) beta_valr*I(c[1][3]) + //ymm9+=beta_real*R(c[0][3]) beta_real*I(c[0][3]) + // beta_real*R(c[1][3]) beta_real*I(c[1][3]) SCALE_BETA_REAL_M_FRINGE(ymm0,ymm9,ymm15); //R(c[0][4]) I(c[0][4]) R(c[1][4]) I(c[1][4]) ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*4)); - //ymm11+=beta_valr*R(c[0][4]) beta_valr*I(c[0][4]) - // beta_valr*R(c[1][4]) beta_valr*I(c[1][4]) + //ymm11+=beta_real*R(c[0][4]) beta_real*I(c[0][4]) + // beta_real*R(c[1][4]) beta_real*I(c[1][4]) SCALE_BETA_REAL_M_FRINGE(ymm0,ymm11,ymm15); + //R(c[0][5]) I(c[0][5]) R(c[1][5]) I(c[1][5]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*5)); + //ymm13+=beta_real*R(c[0][5]) beta_real*I(c[0][5]) + // beta_real*R(c[1][5]) beta_real*I(c[1][5]) + SCALE_BETA_REAL_M_FRINGE(ymm0,ymm13,ymm15); } - if(beta_vali != 0.0) - { - ymm15 = _mm256_broadcast_sd((double const *)(&beta_vali)); + if(beta_imag != 0.0) + { + ymm15 = _mm256_broadcast_sd((double const *)(&beta_imag)); //R(c[0][0]) I(c[0][0]) R(c[1][0]) I(c[1][0]) ymm0 = _mm256_loadu_pd((double const *)(temp_c)); - //ymm3+=beta_vali*(-I(c[0][0])) beta_vali*R(c[0][0]) - // beta_vali*(-I(c[1][0])) beta_vali*R(c[1][0]) + //ymm3+=beta_imag*(-I(c[0][0])) 
beta_imag*R(c[0][0]) + // beta_imag*(-I(c[1][0])) beta_imag*R(c[1][0]) SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm3,ymm15,ymm2); //R(c[0][1]) I(c[0][1]) R(c[1][1]) I(c[1][1]) ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc)); - //ymm5+=beta_vali*(-I(c[0][1])) beta_vali*R(c[0][1]) - // beta_vali*(-I(c[1][1])) beta_vali*R(c[1][1]) + //ymm5+=beta_imag*(-I(c[0][1])) beta_imag*R(c[0][1]) + // beta_imag*(-I(c[1][1])) beta_imag*R(c[1][1]) SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm5,ymm15,ymm2); //R(c[0][2]) I(c[0][2]) R(c[1][2]) I(c[1][2]) ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*2)); - //ymm7+=beta_vali*(-I(c[0][2])) beta_vali*R(c[0][2]) - // beta_vali*(-I(c[1][2])) beta_vali*R(c[1][2]) + //ymm7+=beta_imag*(-I(c[0][2])) beta_imag*R(c[0][2]) + // beta_imag*(-I(c[1][2])) beta_imag*R(c[1][2]) SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm7,ymm15,ymm2); //R(c[0][3]) I(c[0][3]) R(c[1][3]) I(c[1][3]) ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*3)); - //ymm9+=beta_vali*(-I(c[0][3])) beta_vali*R(c[0][3]) - // beta_vali*(-I(c[1][3])) beta_vali*R(c[1][3]) + //ymm9+=beta_imag*(-I(c[0][3])) beta_imag*R(c[0][3]) + // beta_imag*(-I(c[1][3])) beta_imag*R(c[1][3]) SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm9,ymm15,ymm2); //R(c[0][4]) I(c[0][4]) R(c[1][4]) I(c[1][4]) ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*4)); - //ymm11+=beta_vali*(-I(c[0][4])) beta_vali*R(c[0][4]) - // beta_vali*(-I(c[1][4])) beta_vali*R(c[1][4]) + //ymm11+=beta_imag*(-I(c[0][4])) beta_imag*R(c[0][4]) + // beta_imag*(-I(c[1][4])) beta_imag*R(c[1][4]) SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm11,ymm15,ymm2); + + //R(c[0][5]) I(c[0][5]) R(c[1][5]) I(c[1][5]) + ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*5)); + //ymm13+=beta_imag*(-I(c[0][5])) beta_imag*R(c[0][5]) + // beta_imag*(-I(c[1][5])) beta_imag*R(c[1][5]) + SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm13,ymm15,ymm2); } /* The scaling has been done sequentially as follows: - - If alpha_valr is not 0, it is used for scaling A - - If alpha_vali is not 0, it is used for scaling A using permutation + - If alpha_real is not 0, it is used for scaling A + - If alpha_imag is not 0, it is used for scaling A using permutation and selective negation, after loading - - If beta_valr is not 0, is is used for scaling C - - If beta_vali is not 0, it is used for scaling C using permutation + - If beta_real is not 0, is is used for scaling C + - If beta_imag is not 0, it is used for scaling C using permutation and selective negation, after loading - The results are accumalated in accordance to the non zero scalar values, - and similar approach is followed in fringe cases + The results are accumalated in accordance to the non zero scalar values. 
*/ _mm256_storeu_pd((double *)(temp_c), ymm3); @@ -588,6 +636,7 @@ void bli_zgemm_ref_k1_nn _mm256_storeu_pd((double *)(temp_c + ldc*2), ymm7); _mm256_storeu_pd((double *)(temp_c + ldc*3), ymm9); _mm256_storeu_pd((double *)(temp_c + ldc*4), ymm11); + _mm256_storeu_pd((double *)(temp_c + ldc*5), ymm13); temp_c+=2; temp_a+=2; @@ -605,111 +654,126 @@ void bli_zgemm_ref_k1_nn ymm9 = _mm256_setzero_pd(); ymm11 = _mm256_setzero_pd(); - if(alpha_valr != 0.0 || alpha_vali != 0.0) + xmm5 = _mm_loadu_pd((double const*)(temp_a));//R(a[0][0]) I(a[0][0]) + ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(a[0][0]) I(a[0][0]) + + ymm13 = ymm0; + SCALE_ALPHA_REAL_M_FRINGE(ymm0,ymm15,alpha_real); + SCALE_ALPHA_IMAG_M_FRINGE(ymm0,ymm13,ymm15,ymm2,alpha_imag); + + ymm13 = _mm256_setzero_pd(); + + //Calculating using real part of complex number in B matrix + //ymm3+=R(b[0][0])*R(a[0][0]) R(b[0][0])*I(a[0][0]) + // R(b[0][0])*R(a[1][0]) R(b[0][0])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)); + //ymm5+=R(b[0][1])*R(a[0][0]) R(b[0][1])*I(a[0][0]) + // R(b[0][1])*R(a[1][0]) R(b[0][1])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm5,ymm2,(double const *)(temp_b+ldb)); + //ymm7+=R(b[0][2])*R(a[0][0]) R(b[0][2])*I(a[0][0]) + // R(b[0][2])*R(a[1][0]) R(b[0][2])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm7,ymm2,(double const *)(temp_b+ldb*2)); + //ymm9+=R(b[0][3])*R(a[0][0]) R(b[0][3])*I(a[0][0]) + // R(b[0][3])*R(a[1][0]) R(b[0][3])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm9,ymm2,(double const *)(temp_b+ldb*3)); + //ymm11+=R(b[0][4])*R(a[0][0]) R(b[0][4])*I(a[0][0]) + // R(b[0][4])*R(a[1][0]) R(b[0][4])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm11,ymm2,(double const *)(temp_b+ldb*4)); + //ymm13+=R(b[0][5])*R(a[0][0]) R(b[0][5])*I(a[0][0]) + // R(b[0][5])*R(a[1][0]) R(b[0][5])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm13,ymm2,(double const *)(temp_b+ldb*5)); + + //Calculating using imaginary part of complex numbers in B matrix + //Shuffling ymm0 in accordance to the requirement + NEG_PERM_M_FRINGE(ymm0,ymm2); + + // ymm3+=I(b[0][0])*R(a[0][0]) I(b[0][0])*I(a[0][0]) + // I(b[0][0])*R(a[1][0]) I(b[0][0])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)+1); + //ymm5+=I(b[0][1])*R(a[0][0]) I(b[0][1])*I(a[0][0]) + // I(b[0][1])*R(a[1][0]) I(b[0][1])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm5,ymm2,(double const *)(temp_b+ldb)+1); + //ymm7+=I(b[0][2])*R(a[0][0]) I(b[0][2])*I(a[0][0]) + // I(b[0][2])*R(a[1][0]) I(b[0][2])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm7,ymm2,(double const *)(temp_b+ldb*2)+1); + //ymm9+=I(b[0][3])*R(a[0][0]) I(b[0][3])*I(a[0][0]) + // I(b[0][3])*R(a[1][0]) I(b[0][3])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm9,ymm2,(double const *)(temp_b+ldb*3)+1); + //ymm11+=I(b[0][4])*R(a[0][0]) I(b[0][4])*I(a[0][0]) + // I(b[0][4])*R(a[1][0]) I(b[0][4])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm11,ymm2,(double const *)(temp_b+ldb*4)+1); + //ymm13+=I(b[0][5])*R(a[0][0]) I(b[0][5])*I(a[0][0]) + // I(b[0][5])*R(a[1][0]) I(b[0][5])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm13,ymm2,(double const *)(temp_b+ldb*5)+1); + + if(beta_real != 0.0) { - xmm5 = _mm_loadu_pd((double const*)(temp_a));//R(a[0][0]) I(a[0][0]) - ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(a[0][0]) I(a[0][0]) - ymm13 = ymm0; - - SCALE_ALPHA_REAL_M_FRINGE(ymm0,ymm15,alpha_valr); - SCALE_ALPHA_IMAG_M_FRINGE(ymm0,ymm15,ymm2,ymm13,alpha_vali); - - //Calculating using real part of complex number in B matrix - //ymm3+=R(b[0][0])*R(a[0][0]) R(b[0][0])*I(a[0][0]) - // R(b[0][0])*R(a[1][0]) R(b[0][0])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)); - //ymm5+=R(b[0][1])*R(a[0][0]) 
R(b[0][1])*I(a[0][0]) - // R(b[0][1])*R(a[1][0]) R(b[0][1])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm5,ymm2,(double const *)(temp_b+ldb)); - //ymm7+=R(b[0][2])*R(a[0][0]) R(b[0][2])*I(a[0][0]) - // R(b[0][2])*R(a[1][0]) R(b[0][2])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm7,ymm2,(double const *)(temp_b+ldb*2)); - //ymm9+=R(b[0][3])*R(a[0][0]) R(b[0][3])*I(a[0][0]) - // R(b[0][3])*R(a[1][0]) R(b[0][3])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm9,ymm2,(double const *)(temp_b+ldb*3)); - //ymm11+=R(b[0][4])*R(a[0][0]) R(b[0][4])*I(a[0][0]) - // R(b[0][4])*R(a[1][0]) R(b[0][4])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm11,ymm2,(double const *)(temp_b+ldb*4)); - - //Calculating using imaginary part of complex numbers in B matrix - //Shuffling ymm0 in accordance to the requirement - NEG_PERM_M_FRINGE(ymm0,ymm2); - - // ymm3+=I(b[0][0])*R(a[0][0]) I(b[0][0])*I(a[0][0]) - // I(b[0][0])*R(a[1][0]) I(b[0][0])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)+1); - //ymm5+=I(b[0][1])*R(a[0][0]) I(b[0][1])*I(a[0][0]) - // I(b[0][1])*R(a[1][0]) I(b[0][1])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm5,ymm2,(double const *)(temp_b+ldb)+1); - //ymm7+=I(b[0][2])*R(a[0][0]) I(b[0][2])*I(a[0][0]) - // I(b[0][2])*R(a[1][0]) I(b[0][2])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm7,ymm2,(double const *)(temp_b+ldb*2)+1); - //ymm9+=I(b[0][3])*R(a[0][0]) I(b[0][3])*I(a[0][0]) - // I(b[0][3])*R(a[1][0]) I(b[0][3])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm9,ymm2,(double const *)(temp_b+ldb*3)+1); - //ymm11+=I(b[0][4])*R(a[0][0]) I(b[0][4])*I(a[0][0]) - // I(b[0][4])*R(a[1][0]) I(b[0][4])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm11,ymm2,(double const *)(temp_b+ldb*4)+1); - - } - if(beta_valr != 0.0) - { - ymm15 = _mm256_broadcast_sd((double const *)(&beta_valr)); + ymm15 = _mm256_broadcast_sd((double const *)(&beta_real)); xmm5 = _mm_loadu_pd((double const*)(temp_c));//R(c[0][0]) I(c[0][0]) ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][0]) I(c[0][0]) - //ymm3+=beta_valr*R(c[0][0]) beta_valr*I(c[0][0]) + //ymm3+=beta_real*R(c[0][0]) beta_real*I(c[0][0]) SCALE_BETA_REAL_M_FRINGE(ymm0,ymm3,ymm15); xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc));//R(c[0][1]) I(c[0][1]) ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][1]) I(c[0][1]) - //ymm5+=beta_valr*R(c[0][1]) beta_valr*I(c[0][1]) + //ymm5+=beta_real*R(c[0][1]) beta_real*I(c[0][1]) SCALE_BETA_REAL_M_FRINGE(ymm0,ymm5,ymm15); xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc * 2));//R(c[0][2]) I(c[0][2]) ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][2]) I(c[0][2]) - //ymm7+=beta_valr*R(c[0][2]) beta_valr*I(c[0][2]) + //ymm7+=beta_real*R(c[0][2]) beta_real*I(c[0][2]) SCALE_BETA_REAL_M_FRINGE(ymm0,ymm7,ymm15); xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc * 3));//R(c[0][3]) I(c[0][3]) ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][3]) I(c[0][3]) - //ymm9+=beta_valr*R(c[0][3]) beta_valr*I(c[0][3]) + //ymm9+=beta_real*R(c[0][3]) beta_real*I(c[0][3]) SCALE_BETA_REAL_M_FRINGE(ymm0,ymm9,ymm15); xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc * 4));//R(c[0][4]) I(c[0][4]) ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][4]) I(c[0][4]) - //ymm11+=beta_valr*R(c[0][4]) beta_valr*I(c[0][4]) + //ymm11+=beta_real*R(c[0][4]) beta_real*I(c[0][4]) SCALE_BETA_REAL_M_FRINGE(ymm0,ymm11,ymm15); + + xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc * 5));//R(c[0][5]) I(c[0][5]) + ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][5]) I(c[0][5]) + //ymm13+=beta_real*R(c[0][5]) beta_real*I(c[0][5]) + SCALE_BETA_REAL_M_FRINGE(ymm0,ymm13,ymm15); } - if(beta_vali != 0.0) + + if(beta_imag != 0.0) { - ymm15 = 
_mm256_broadcast_sd((double const *)(&beta_vali)); + ymm15 = _mm256_broadcast_sd((double const *)(&beta_imag)); xmm5 = _mm_loadu_pd((double const*)(temp_c));//R(c[0][0]) I(c[0][0]) ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][0]) I(c[0][0]) - //ymm3+=beta_vali*(-I(c[0][0])) beta_vali*R(c[0][0]) + //ymm3+=beta_imag*(-I(c[0][0])) beta_imag*R(c[0][0]) SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm3,ymm15,ymm2); xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc));//R(c[0][1]) I(c[0][1]) ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][1]) I(c[0][1]) - //ymm5+=beta_vali*(-I(c[0][1])) beta_vali*R(c[0][1]) + //ymm5+=beta_imag*(-I(c[0][1])) beta_imag*R(c[0][1]) SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm5,ymm15,ymm2); xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc * 2));//R(c[0][2]) I(c[0][2]) ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][2]) I(c[0][2]) - //ymm7+=beta_vali*(-I(c[0][2])) beta_vali*R(c[0][2]) + //ymm7+=beta_imag*(-I(c[0][2])) beta_imag*R(c[0][2]) SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm7,ymm15,ymm2); xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc * 3));//R(c[0][3]) I(c[0][3]) ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][3]) I(c[0][3]) - //ymm9+=beta_vali*(-I(c[0][3])) beta_vali*R(c[0][3]) + //ymm9+=beta_imag*(-I(c[0][3])) beta_imag*R(c[0][3]) SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm9,ymm15,ymm2); xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc * 4));//R(c[0][4]) I(c[0][4]) ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][4]) I(c[0][4]) - //ymm11+=beta_vali*(-I(c[0][4])) beta_vali*R(c[0][4]) + //ymm11+=beta_imag*(-I(c[0][4])) beta_imag*R(c[0][4]) SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm11,ymm15,ymm2); + xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc * 5));//R(c[0][5]) I(c[0][5]) + ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][5]) I(c[0][5]) + //ymm13+=beta_imag*(-I(c[0][5])) beta_imag*R(c[0][5]) + SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm13,ymm15,ymm2); } xmm5 = _mm256_extractf128_pd(ymm3, 0); @@ -727,14 +791,21 @@ void bli_zgemm_ref_k1_nn xmm5 = _mm256_extractf128_pd(ymm11, 0); _mm_storeu_pd((double *)(temp_c + ldc*4), xmm5); + xmm5 = _mm256_extractf128_pd(ymm13, 0); + _mm_storeu_pd((double *)(temp_c + ldc*5), xmm5); + } } - if(n_remainder==4) + + //Fringe case for N + if(n_remainder>=4) { dcomplex* temp_b = b + (n - n_remainder)*ldb; dcomplex* temp_a = a; dcomplex* temp_c = c + (n - n_remainder)*ldc; + + //Main loop for M for(dim_t i = 0;i < (m-Z_MR+1);i=i+Z_MR) { ymm3 = _mm256_setzero_pd(); @@ -746,174 +817,168 @@ void bli_zgemm_ref_k1_nn ymm9 = _mm256_setzero_pd(); ymm10 = _mm256_setzero_pd(); - if(alpha_valr != 0.0 || alpha_vali != 0.0) - { - /* - a. Perform alpha*A*B using temp_a, temp_b and alpha_valr, alpha_vali - where alpha_valr and/or alpha_vali is not zero. - b. This loop operates with 4x5 block size - along n dimension for every Z_NR columns of temp_b where - computing all Z_MR rows of temp_a. - c. Same approach is used in remaining fringe cases. 
- */ - - //R(a[0][0]) I(a[0][0]) R(a[1][0]) I(a[1][0]) - ymm0 = _mm256_loadu_pd((double const *)(temp_a)); - //R(a[2][0]) I(a[2][0]) R(a[3][0]) I(a[3][0]) - ymm1 = _mm256_loadu_pd((double const *)(temp_a + 2)); - - ymm13 = ymm0; - ymm14 = ymm1; - - _mm_prefetch((char*)(temp_a) + 64, _MM_HINT_T0); - - SCALE_ALPHA_REAL_M_LOOP(ymm0,ymm1,ymm15,alpha_valr); - SCALE_ALPHA_IMAG_M_LOOP(ymm0,ymm1,ymm15,ymm2,ymm13,ymm14,alpha_vali); - - /* - The result after scaling with alpha_valr and/or alpha_vali is as follows: - For ymm0 : - R(a[0][0]) = alpha_valr*R(a[0][0])-alpha_vali*I(a[0][0]) - I(a[0][0]) = alpha_valr*I(a[0][0])+alpha_vali*R[0][0] - R(a[1][0]) = alpha_valr*R(a[1][0])-alpha_vali*I(a[1][0]) - I(a[1][0]) = alpha_valr*I(a[1][0])+alpha_vali*(R[1][0]) - - For ymm1 : - R(a[2][0]) = alpha_valr*R(a[2][0])-alpha_vali*I(a[2][0]) - I(a[2][0]) = alpha_valr*I(a[2][0])+alpha_vali*R[2][0] - R(a[3][0]) = alpha_valr*R(a[3][0])-alpha_vali*I(a[3][0]) - I(a[3][0]) = alpha_valr*I(a[3][0])+alpha_vali*(R[3][0]) - */ - - //Calculating using real part of complex number in B matrix - FMA_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm2,(double const *)(temp_b)); - FMA_M_LOOP(ymm0,ymm1,ymm5,ymm6,ymm2,(double const *)(temp_b+ldb)); - FMA_M_LOOP(ymm0,ymm1,ymm7,ymm8,ymm2,(double const *)(temp_b+ldb*2)); - FMA_M_LOOP(ymm0,ymm1,ymm9,ymm10,ymm2,(double const *)(temp_b+ldb*3)); - - //Calculating using imaginary part of complex numbers in B matrix - //Shuffling ymm0 and ymm1 in accordance to the requirement - NEG_PERM_M_LOOP(ymm0,ymm1,ymm2); - FMA_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm2,(double const *)(temp_b)+1); - FMA_M_LOOP(ymm0,ymm1,ymm5,ymm6,ymm2,(double const *)(temp_b+ldb)+1); - FMA_M_LOOP(ymm0,ymm1,ymm7,ymm8,ymm2,(double const *)(temp_b+ldb*2)+1); - FMA_M_LOOP(ymm0,ymm1,ymm9,ymm10,ymm2,(double const *)(temp_b+ldb*3)+1); - } - if(beta_valr != 0.0) + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_real, alpha_imag + where alpha_real and/or alpha_imag is not zero. + b. This loop operates with 4x6 block size + along n dimension for every Z_NR columns of temp_b where + computing all Z_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. 
+ */ + + //R(a[0][0]) I(a[0][0]) R(a[1][0]) I(a[1][0]) + ymm0 = _mm256_loadu_pd((double const *)(temp_a)); + //R(a[2][0]) I(a[2][0]) R(a[3][0]) I(a[3][0]) + ymm1 = _mm256_loadu_pd((double const *)(temp_a + 2)); + + ymm13 = ymm0; + ymm14 = ymm1; + SCALE_ALPHA_REAL_M_LOOP(ymm0,ymm1,ymm15,alpha_real); + SCALE_ALPHA_IMAG_M_LOOP(ymm0,ymm1,ymm13,ymm14,ymm15,ymm2,alpha_imag); + + /* + The result after scaling with alpha_real and/or alpha_imag is as follows: + For ymm0 : + R(a[0][0]) = alpha_real*R(a[0][0])-alpha_imag*I(a[0][0]) + I(a[0][0]) = alpha_real*I(a[0][0])+alpha_imag*R[0][0] + R(a[1][0]) = alpha_real*R(a[1][0])-alpha_imag*I(a[1][0]) + I(a[1][0]) = alpha_real*I(a[1][0])+alpha_imag*(R[1][0]) + + For ymm1 : + R(a[2][0]) = alpha_real*R(a[2][0])-alpha_imag*I(a[2][0]) + I(a[2][0]) = alpha_real*I(a[2][0])+alpha_imag*R[2][0] + R(a[3][0]) = alpha_real*R(a[3][0])-alpha_imag*I(a[3][0]) + I(a[3][0]) = alpha_real*I(a[3][0])+alpha_imag*(R[3][0]) + */ + + //Calculating using real part of complex number in B matrix + FMA_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm2,(double const *)(temp_b)); + FMA_M_LOOP(ymm0,ymm1,ymm5,ymm6,ymm2,(double const *)(temp_b+ldb)); + FMA_M_LOOP(ymm0,ymm1,ymm7,ymm8,ymm2,(double const *)(temp_b+ldb*2)); + FMA_M_LOOP(ymm0,ymm1,ymm9,ymm10,ymm2,(double const *)(temp_b+ldb*3)); + + //Calculating using imaginary part of complex numbers in B matrix + //Shuffling ymm0 and ymm1 in accordance to the requirement + NEG_PERM_M_LOOP(ymm0,ymm1,ymm2); + FMA_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm2,(double const *)(temp_b)+1); + FMA_M_LOOP(ymm0,ymm1,ymm5,ymm6,ymm2,(double const *)(temp_b+ldb)+1); + FMA_M_LOOP(ymm0,ymm1,ymm7,ymm8,ymm2,(double const *)(temp_b+ldb*2)+1); + FMA_M_LOOP(ymm0,ymm1,ymm9,ymm10,ymm2,(double const *)(temp_b+ldb*3)+1); + + /* + a. Perform beta*C using temp_c, beta_real, + where beta_real is not zero. + b. This loop operates with 4x6 block size + along n dimension for every Z_NR columns of temp_c where + computing all Z_MR rows of temp_c. + c. Accumulated alpha*A*B into registers will be added to beta*C + d. Same approach is used in remaining fringe cases. + */ + if(beta_real != 0.0) { - /* - a. Perform beta*C using temp_c, beta_valr, - where beta_valr is not zero. - b. This loop operates with 4x5 block size - along n dimension for every Z_NR columns of temp_c where - computing all Z_MR rows of temp_c. - c. Accumulated alpha*A*B into registers will be added to beta*C - d. Same approach is used in remaining fringe cases. 
- */ - ymm15 = _mm256_broadcast_sd((double const *)(&beta_valr)); + ymm15 = _mm256_broadcast_sd((double const *)(&beta_real)); //R(c[0][0]) I(c[0][0]) R(c[1][0]) I(c[1][0]) ymm0 = _mm256_loadu_pd((double const *)(temp_c)); //R(c[2][0]) I(c[2][0]) R(c[3][0]) I(c[3][0]) ymm1 = _mm256_loadu_pd((double const *)(temp_c + 2)); - //ymm3+=beta_valr*R(c[0][0]) beta_valr*I(c[0][0]) - // beta_valr*R(c[1][0]) beta_valr*I(c[1][0]) - //ymm4+=beta_valr*R(c[2][0]) beta_valr*I(c[2][0]) - // beta_valr*R(c[3][0]) beta_valr*I(c[3][0]) + //ymm3+=beta_real*R(c[0][0]) beta_real*I(c[0][0]) + // beta_real*R(c[1][0]) beta_real*I(c[1][0]) + //ymm4+=beta_real*R(c[2][0]) beta_real*I(c[2][0]) + // beta_real*R(c[3][0]) beta_real*I(c[3][0]) SCALE_BETA_REAL_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm15); //R(c[0][1]) I(c[0][1]) R(c[1][1]) I(c[1][1]) ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc)); //R(c[2][1]) I(c[2][1]) R(c[3][1]) I(c[3][1]) ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc + 2)); - //ymm5+=beta_valr*R(c[0][1]) beta_valr*I(c[0][1]) - // beta_valr*R(c[1][1]) beta_valr*I(c[1][1]) - //ymm6+=beta_valr*R(c[2][1]) beta_valr*I(c[2][1]) - // beta_valr*R(c[3][1]) beta_valr*I(c[3][1]) + //ymm5+=beta_real*R(c[0][1]) beta_real*I(c[0][1]) + // beta_real*R(c[1][1]) beta_real*I(c[1][1]) + //ymm6+=beta_real*R(c[2][1]) beta_real*I(c[2][1]) + // beta_real*R(c[3][1]) beta_real*I(c[3][1]) SCALE_BETA_REAL_M_LOOP(ymm0,ymm1,ymm5,ymm6,ymm15); //R(c[0][2]) I(c[0][2]) R(c[1][2]) I(c[1][2]) ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*2)); //R(c[2][2]) I(c[2][2]) R(c[3][2]) I(c[3][2]) ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc*2 + 2)); - //ymm7+=beta_valr*R(c[0][2]) beta_valr*I(c[0][2]) - // beta_valr*R(c[1][2]) beta_valr*I(c[1][2]) - //ymm8+=beta_valr*R(c[2][2]) beta_valr*I(c[2][2]) - // beta_valr*R(c[3][2]) beta_valr*I(c[3][2]) + //ymm7+=beta_real*R(c[0][2]) beta_real*I(c[0][2]) + // beta_real*R(c[1][2]) beta_real*I(c[1][2]) + //ymm8+=beta_real*R(c[2][2]) beta_real*I(c[2][2]) + // beta_real*R(c[3][2]) beta_real*I(c[3][2]) SCALE_BETA_REAL_M_LOOP(ymm0,ymm1,ymm7,ymm8,ymm15); //R(c[0][3]) I(c[0][3]) R(c[1][3]) I(c[1][3]) ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*3)); //R(c[2][3]) I(c[2][3]) R(c[3][3]) I(c[3][3]) ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc*3 + 2)); - //ymm9+=beta_valr*R(c[0][3]) beta_valr*I(c[0][3]) - // beta_valr*R(c[1][3]) beta_valr*I(c[1][3]) - //ymm10+=beta_valr*R(c[2][3]) beta_valr*I(c[2][3]) - // beta_valr*R(c[3][3]) beta_valr*I(c[3][3]) + //ymm9+=beta_real*R(c[0][3]) beta_real*I(c[0][3]) + // beta_real*R(c[1][3]) beta_real*I(c[1][3]) + //ymm10+=beta_real*R(c[2][3]) beta_real*I(c[2][3]) + // beta_real*R(c[3][3]) beta_real*I(c[3][3]) SCALE_BETA_REAL_M_LOOP(ymm0,ymm1,ymm9,ymm10,ymm15); - } - if(beta_vali != 0.0) + /* + a. Perform beta*C using temp_c, beta_imag, + where beta_imag is not zero. + b. This loop operates with 4x6 block size + along n dimension for every Z_NR columns of temp_c where + computing all Z_MR rows of temp_c. + c. Accumulated alpha*A*B into registers will be added to beta*C + d. Same approach is used in remaining fringe cases. + */ + + if(beta_imag != 0.0) { - /* - a. Perform beta*C using temp_c, beta_vali, - where beta_vali is not zero. - b. This loop operates with 4x5 block size - along n dimension for every Z_NR columns of temp_c where - computing all Z_MR rows of temp_c. - c. Accumulated alpha*A*B into registers will be added to beta*C - d. Same approach is used in remaining fringe cases. 
- */ - - ymm15 = _mm256_broadcast_sd((double const *)(&beta_vali)); + ymm15 = _mm256_broadcast_sd((double const *)(&beta_imag)); //R(c[0][0]) I(c[0][0]) R(c[1][0]) I(c[1][0]) ymm0 = _mm256_loadu_pd((double const *)(temp_c)); //R(c[2][0]) I(c[2][0]) R(c[3][0]) I(c[3][0]) ymm1 = _mm256_loadu_pd((double const *)(temp_c + 2)); - //ymm3+=beta_vali*(-I(c[0][0])) beta_vali*R(c[0][0]) - // beta_vali*(-I(c[1][0])) beta_vali*R(c[1][0]) - //ymm4+=beta_vali*(-I(c[2][0])) beta_vali*R(c[2][0]) - // beta_vali*(-I(c[3][0])) beta_vali*R(c[3][0]) + //ymm3+=beta_imag*(-I(c[0][0])) beta_imag*R(c[0][0]) + // beta_imag*(-I(c[1][0])) beta_imag*R(c[1][0]) + //ymm4+=beta_imag*(-I(c[2][0])) beta_imag*R(c[2][0]) + // beta_imag*(-I(c[3][0])) beta_imag*R(c[3][0]) SCALE_BETA_IMAG_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm15,ymm2); //R(c[0][1]) I(c[0][1]) R(c[1][1]) I(c[1][1]) ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc)); //R(c[2][1]) I(c[2][1]) R(c[3][1]) I(c[3][1]) ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc + 2)); - //ymm5+=beta_vali*(-I(c[0][1])) beta_vali*R(c[0][1]) - // beta_vali*(-I(c[1][1])) beta_vali*R(c[1][1]) - //ymm6+=beta_vali*(-I(c[2][1])) beta_vali*R(c[2][1]) - // beta_vali*(-I(c[3][1])) beta_vali*R(c[3][1]) + //ymm5+=beta_imag*(-I(c[0][1])) beta_imag*R(c[0][1]) + // beta_imag*(-I(c[1][1])) beta_imag*R(c[1][1]) + //ymm6+=beta_imag*(-I(c[2][1])) beta_imag*R(c[2][1]) + // beta_imag*(-I(c[3][1])) beta_imag*R(c[3][1]) SCALE_BETA_IMAG_M_LOOP(ymm0,ymm1,ymm5,ymm6,ymm15,ymm2); //R(c[0][2]) I(c[0][2]) R(c[1][2]) I(c[1][2]) ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*2)); //R(c[2][2]) I(c[2][2]) R(c[3][2]) I(c[3][2]) ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc*2 + 2)); - //ymm7+=beta_vali*(-I(c[0][2])) beta_vali*R(c[0][2]) - // beta_vali*(-I(c[1][2])) beta_vali*R(c[1][2]) - //ymm8+=beta_vali*(-I(c[2][2])) beta_vali*R(c[2][2]) - // beta_vali*(-I(c[3][2])) beta_vali*R(c[3][2]) + //ymm7+=beta_imag*(-I(c[0][2])) beta_imag*R(c[0][2]) + // beta_imag*(-I(c[1][2])) beta_imag*R(c[1][2]) + //ymm8+=beta_imag*(-I(c[2][2])) beta_imag*R(c[2][2]) + // beta_imag*(-I(c[3][2])) beta_imag*R(c[3][2]) SCALE_BETA_IMAG_M_LOOP(ymm0,ymm1,ymm7,ymm8,ymm15,ymm2); //R(c[0][3]) I(c[0][3]) R(c[1][3]) I(c[1][3]) ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*3)); //R(c[2][3]) I(c[2][3]) R(c[3][3]) I(c[3][3]) ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc*3 + 2)); - //ymm9+=beta_vali*(-I(c[0][3])) beta_vali*R(c[0][3]) - // beta_vali*(-I(c[1][3])) beta_vali*R(c[1][3]) - //ymm10+=beta_vali*(-I(c[2][3])) beta_vali*R(c[2][3]) - // beta_vali*(-I(c[3][3])) beta_vali*R(c[3][3]) + //ymm9+=beta_imag*(-I(c[0][3])) beta_imag*R(c[0][3]) + // beta_imag*(-I(c[1][3])) beta_imag*R(c[1][3]) + //ymm10+=beta_imag*(-I(c[2][3])) beta_imag*R(c[2][3]) + // beta_imag*(-I(c[3][3])) beta_imag*R(c[3][3]) SCALE_BETA_IMAG_M_LOOP(ymm0,ymm1,ymm9,ymm10,ymm15,ymm2); } /* The scaling has been done sequentially as follows: - - If alpha_valr is not 0, it is used for scaling A - - If alpha_vali is not 0, it is used for scaling A using permutation + - If alpha_real is not 0, it is used for scaling A + - If alpha_imag is not 0, it is used for scaling A using permutation and selective negation, after loading - - If beta_valr is not 0, is is used for scaling C - - If beta_vali is not 0, it is used for scaling C using permutation + - If beta_real is not 0, is is used for scaling C + - If beta_imag is not 0, it is used for scaling C using permutation and selective negation, after loading The results are accumalated in accordance to the non zero scalar 
values, @@ -936,6 +1001,7 @@ void bli_zgemm_ref_k1_nn temp_a+=Z_MR; } + // Fringe cases for M dim_t m_rem=m_remainder; if(m_rem>=2) { @@ -944,124 +1010,119 @@ void bli_zgemm_ref_k1_nn ymm7 = _mm256_setzero_pd(); ymm9 = _mm256_setzero_pd(); - if(alpha_valr != 0.0 || alpha_vali != 0.0) - { - //R(a[0][0]) I(a[0][0]) R(a[1][0]) I(a[1][0]) - ymm0 = _mm256_loadu_pd((double const *)(temp_a)); - ymm13 = ymm0; - SCALE_ALPHA_REAL_M_FRINGE(ymm0,ymm15,alpha_valr); - SCALE_ALPHA_IMAG_M_FRINGE(ymm0,ymm15,ymm2,ymm13,alpha_vali); - /* - The result after scaling with alpha_valr and/or alpha_vali is as follows: - For ymm0 : - R(a[0][0]) = alpha_valr*R(a[0][0])-alpha_vali*I(a[0][0]) - I(a[0][0]) = alpha_valr*I(a[0][0])+alpha_vali*R[0][0] - R(a[1][0]) = alpha_valr*R(a[1][0])-alpha_vali*I(a[1][0]) - I(a[1][0]) = alpha_valr*I(a[1][0])+alpha_vali*(R[1][0]) - */ - - //Calculating using real part of complex number in B matrix - //ymm3+=R(b[0][0])*R(a[0][0]) R(b[0][0])*I(a[0][0]) - // R(b[0][0])*R(a[1][0]) R(b[0][0])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)); - //ymm5+=R(b[0][1])*R(a[0][0]) R(b[0][1])*I(a[0][0]) - // R(b[0][1])*R(a[1][0]) R(b[0][1])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm5,ymm2,(double const *)(temp_b+ldb)); - //ymm7+=R(b[0][2])*R(a[0][0]) R(b[0][2])*I(a[0][0]) - // R(b[0][2])*R(a[1][0]) R(b[0][2])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm7,ymm2,(double const *)(temp_b+ldb*2)); - //ymm9+=R(b[0][3])*R(a[0][0]) R(b[0][3])*I(a[0][0]) - // R(b[0][3])*R(a[1][0]) R(b[0][3])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm9,ymm2,(double const *)(temp_b+ldb*3)); - - //Calculating using imaginary part of complex numbers in B matrix - //Shuffling ymm0 in accordance to the requirement - NEG_PERM_M_FRINGE(ymm0,ymm2); - - // ymm3+=I(b[0][0])*R(a[0][0]) I(b[0][0])*I(a[0][0]) - // I(b[0][0])*R(a[1][0]) I(b[0][0])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)+1); - //ymm5+=I(b[0][1])*R(a[0][0]) I(b[0][1])*I(a[0][0]) - // I(b[0][1])*R(a[1][0]) I(b[0][1])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm5,ymm2,(double const *)(temp_b+ldb)+1); - //ymm7+=I(b[0][2])*R(a[0][0]) I(b[0][2])*I(a[0][0]) - // I(b[0][2])*R(a[1][0]) I(b[0][2])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm7,ymm2,(double const *)(temp_b+ldb*2)+1); - //ymm9+=I(b[0][3])*R(a[0][0]) I(b[0][3])*I(a[0][0]) - // I(b[0][3])*R(a[1][0]) I(b[0][3])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm9,ymm2,(double const *)(temp_b+ldb*3)+1); + //R(a[0][0]) I(a[0][0]) R(a[1][0]) I(a[1][0]) + ymm0 = _mm256_loadu_pd((double const *)(temp_a)); - } + ymm13 = ymm0; + SCALE_ALPHA_REAL_M_FRINGE(ymm0,ymm15,alpha_real); + SCALE_ALPHA_IMAG_M_FRINGE(ymm0,ymm13,ymm15,ymm2,alpha_imag); + /* + The result after scaling with alpha_real and/or alpha_imag is as follows: + For ymm0 : + R(a[0][0]) = alpha_real*R(a[0][0])-alpha_imag*I(a[0][0]) + I(a[0][0]) = alpha_real*I(a[0][0])+alpha_imag*R[0][0] + R(a[1][0]) = alpha_real*R(a[1][0])-alpha_imag*I(a[1][0]) + I(a[1][0]) = alpha_real*I(a[1][0])+alpha_imag*(R[1][0]) + */ - if(beta_valr != 0.0) + //Calculating using real part of complex number in B matrix + //ymm3+=R(b[0][0])*R(a[0][0]) R(b[0][0])*I(a[0][0]) + // R(b[0][0])*R(a[1][0]) R(b[0][0])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)); + //ymm5+=R(b[0][1])*R(a[0][0]) R(b[0][1])*I(a[0][0]) + // R(b[0][1])*R(a[1][0]) R(b[0][1])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm5,ymm2,(double const *)(temp_b+ldb)); + //ymm7+=R(b[0][2])*R(a[0][0]) R(b[0][2])*I(a[0][0]) + // R(b[0][2])*R(a[1][0]) R(b[0][2])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm7,ymm2,(double const *)(temp_b+ldb*2)); + 
//ymm9+=R(b[0][3])*R(a[0][0]) R(b[0][3])*I(a[0][0]) + // R(b[0][3])*R(a[1][0]) R(b[0][3])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm9,ymm2,(double const *)(temp_b+ldb*3)); + + //Calculating using imaginary part of complex numbers in B matrix + //Shuffling ymm0 in accordance to the requirement + NEG_PERM_M_FRINGE(ymm0,ymm2); + + // ymm3+=I(b[0][0])*R(a[0][0]) I(b[0][0])*I(a[0][0]) + // I(b[0][0])*R(a[1][0]) I(b[0][0])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)+1); + //ymm5+=I(b[0][1])*R(a[0][0]) I(b[0][1])*I(a[0][0]) + // I(b[0][1])*R(a[1][0]) I(b[0][1])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm5,ymm2,(double const *)(temp_b+ldb)+1); + //ymm7+=I(b[0][2])*R(a[0][0]) I(b[0][2])*I(a[0][0]) + // I(b[0][2])*R(a[1][0]) I(b[0][2])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm7,ymm2,(double const *)(temp_b+ldb*2)+1); + //ymm9+=I(b[0][3])*R(a[0][0]) I(b[0][3])*I(a[0][0]) + // I(b[0][3])*R(a[1][0]) I(b[0][3])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm9,ymm2,(double const *)(temp_b+ldb*3)+1); + + + if(beta_real != 0.0) { - - ymm15 = _mm256_broadcast_sd((double const *)(&beta_valr)); + ymm15 = _mm256_broadcast_sd((double const *)(&beta_real)); //R(c[0][0]) I(c[0][0]) R(c[1][0]) I(c[1][0]) ymm0 = _mm256_loadu_pd((double const *)(temp_c)); - //ymm3+=beta_valr*R(c[0][0]) beta_valr*I(c[0][0]) - // beta_valr*R(c[1][0]) beta_valr*I(c[1][0]) + //ymm3+=beta_real*R(c[0][0]) beta_real*I(c[0][0]) + // beta_real*R(c[1][0]) beta_real*I(c[1][0]) SCALE_BETA_REAL_M_FRINGE(ymm0,ymm3,ymm15); //R(c[0][1]) I(c[0][1]) R(c[1][1]) I(c[1][1]) ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc)); - //ymm5+=beta_valr*R(c[0][1]) beta_valr*I(c[0][1]) - // beta_valr*R(c[1][1]) beta_valr*I(c[1][1]) + //ymm5+=beta_real*R(c[0][1]) beta_real*I(c[0][1]) + // beta_real*R(c[1][1]) beta_real*I(c[1][1]) SCALE_BETA_REAL_M_FRINGE(ymm0,ymm5,ymm15); //R(c[0][2]) I(c[0][2]) R(c[1][2]) I(c[1][2]) ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*2)); - //ymm7+=beta_valr*R(c[0][2]) beta_valr*I(c[0][2]) - // beta_valr*R(c[1][2]) beta_valr*I(c[1][2]) + //ymm7+=beta_real*R(c[0][2]) beta_real*I(c[0][2]) + // beta_real*R(c[1][2]) beta_real*I(c[1][2]) SCALE_BETA_REAL_M_FRINGE(ymm0,ymm7,ymm15); //R(c[0][3]) I(c[0][3]) R(c[1][3]) I(c[1][3]) ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*3)); - //ymm9+=beta_valr*R(c[0][3]) beta_valr*I(c[0][3]) - // beta_valr*R(c[1][3]) beta_valr*I(c[1][3]) + //ymm9+=beta_real*R(c[0][3]) beta_real*I(c[0][3]) + // beta_real*R(c[1][3]) beta_real*I(c[1][3]) SCALE_BETA_REAL_M_FRINGE(ymm0,ymm9,ymm15); - } - if(beta_vali != 0.0) + if(beta_imag != 0.0) { - - ymm15 = _mm256_broadcast_sd((double const *)(&beta_vali)); + ymm15 = _mm256_broadcast_sd((double const *)(&beta_imag)); //R(c[0][0]) I(c[0][0]) R(c[1][0]) I(c[1][0]) ymm0 = _mm256_loadu_pd((double const *)(temp_c)); - //ymm3+=beta_vali*(-I(c[0][0])) beta_vali*R(c[0][0]) - // beta_vali*(-I(c[1][0])) beta_vali*R(c[1][0]) + //ymm3+=beta_imag*(-I(c[0][0])) beta_imag*R(c[0][0]) + // beta_imag*(-I(c[1][0])) beta_imag*R(c[1][0]) SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm3,ymm15,ymm2); //R(c[0][1]) I(c[0][1]) R(c[1][1]) I(c[1][1]) ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc)); - //ymm5+=beta_vali*(-I(c[0][1])) beta_vali*R(c[0][1]) - // beta_vali*(-I(c[1][1])) beta_vali*R(c[1][1]) + //ymm5+=beta_imag*(-I(c[0][1])) beta_imag*R(c[0][1]) + // beta_imag*(-I(c[1][1])) beta_imag*R(c[1][1]) SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm5,ymm15,ymm2); //R(c[0][2]) I(c[0][2]) R(c[1][2]) I(c[1][2]) ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*2)); - //ymm7+=beta_vali*(-I(c[0][2])) 
beta_vali*R(c[0][2]) - // beta_vali*(-I(c[1][2])) beta_vali*R(c[1][2]) + //ymm7+=beta_imag*(-I(c[0][2])) beta_imag*R(c[0][2]) + // beta_imag*(-I(c[1][2])) beta_imag*R(c[1][2]) SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm7,ymm15,ymm2); //R(c[0][3]) I(c[0][3]) R(c[1][3]) I(c[1][3]) ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc*3)); - //ymm9+=beta_vali*(-I(c[0][3])) beta_vali*R(c[0][3]) - // beta_vali*(-I(c[1][3])) beta_vali*R(c[1][3]) + //ymm9+=beta_imag*(-I(c[0][3])) beta_imag*R(c[0][3]) + // beta_imag*(-I(c[1][3])) beta_imag*R(c[1][3]) SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm9,ymm15,ymm2); } /* The scaling has been done sequentially as follows: - - If alpha_valr is not 0, it is used for scaling A - - If alpha_vali is not 0, it is used for scaling A using permutation + - If alpha_real is not 0, it is used for scaling A + - If alpha_imag is not 0, it is used for scaling A using permutation and selective negation, after loading - - If beta_valr is not 0, is is used for scaling C - - If beta_vali is not 0, it is used for scaling C using permutation + - If beta_real is not 0, is is used for scaling C + - If beta_imag is not 0, it is used for scaling C using permutation and selective negation, after loading The results are accumalated in accordance to the non zero scalar values, @@ -1088,95 +1149,92 @@ void bli_zgemm_ref_k1_nn ymm7 = _mm256_setzero_pd(); ymm9 = _mm256_setzero_pd(); - if(alpha_valr != 0.0 || alpha_vali != 0.0) + xmm5 = _mm_loadu_pd((double const*)(temp_a));//R(a[0][0]) I(a[0][0]) + ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(a[0][0]) I(a[0][0]) + + ymm13 = ymm0; + SCALE_ALPHA_REAL_M_FRINGE(ymm0,ymm15,alpha_real); + SCALE_ALPHA_IMAG_M_FRINGE(ymm0,ymm13,ymm15,ymm2,alpha_imag); + + //Calculating using real part of complex number in B matrix + //ymm3+=R(b[0][0])*R(a[0][0]) R(b[0][0])*I(a[0][0]) + // R(b[0][0])*R(a[1][0]) R(b[0][0])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)); + //ymm5+=R(b[0][1])*R(a[0][0]) R(b[0][1])*I(a[0][0]) + // R(b[0][1])*R(a[1][0]) R(b[0][1])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm5,ymm2,(double const *)(temp_b+ldb)); + //ymm7+=R(b[0][2])*R(a[0][0]) R(b[0][2])*I(a[0][0]) + // R(b[0][2])*R(a[1][0]) R(b[0][2])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm7,ymm2,(double const *)(temp_b+ldb*2)); + //ymm9+=R(b[0][3])*R(a[0][0]) R(b[0][3])*I(a[0][0]) + // R(b[0][3])*R(a[1][0]) R(b[0][3])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm9,ymm2,(double const *)(temp_b+ldb*3)); + + //Calculating using imaginary part of complex numbers in B matrix + //Shuffling ymm0 in accordance to the requirement + NEG_PERM_M_FRINGE(ymm0,ymm2); + + // ymm3+=I(b[0][0])*R(a[0][0]) I(b[0][0])*I(a[0][0]) + // I(b[0][0])*R(a[1][0]) I(b[0][0])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)+1); + //ymm5+=I(b[0][1])*R(a[0][0]) I(b[0][1])*I(a[0][0]) + // I(b[0][1])*R(a[1][0]) I(b[0][1])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm5,ymm2,(double const *)(temp_b+ldb)+1); + //ymm7+=I(b[0][2])*R(a[0][0]) I(b[0][2])*I(a[0][0]) + // I(b[0][2])*R(a[1][0]) I(b[0][2])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm7,ymm2,(double const *)(temp_b+ldb*2)+1); + //ymm9+=I(b[0][3])*R(a[0][0]) I(b[0][3])*I(a[0][0]) + // I(b[0][3])*R(a[1][0]) I(b[0][3])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm9,ymm2,(double const *)(temp_b+ldb*3)+1); + + if(beta_real != 0.0) { - xmm5 = _mm_loadu_pd((double const*)(temp_a));//R(a[0][0]) I(a[0][0]) - ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(a[0][0]) I(a[0][0]) - ymm13 = ymm0; - - SCALE_ALPHA_REAL_M_FRINGE(ymm0,ymm15,alpha_valr); - SCALE_ALPHA_IMAG_M_FRINGE(ymm0,ymm15,ymm2,ymm13,alpha_vali); - - 
//Calculating using real part of complex number in B matrix - //ymm3+=R(b[0][0])*R(a[0][0]) R(b[0][0])*I(a[0][0]) - // R(b[0][0])*R(a[1][0]) R(b[0][0])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)); - //ymm5+=R(b[0][1])*R(a[0][0]) R(b[0][1])*I(a[0][0]) - // R(b[0][1])*R(a[1][0]) R(b[0][1])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm5,ymm2,(double const *)(temp_b+ldb)); - //ymm7+=R(b[0][2])*R(a[0][0]) R(b[0][2])*I(a[0][0]) - // R(b[0][2])*R(a[1][0]) R(b[0][2])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm7,ymm2,(double const *)(temp_b+ldb*2)); - //ymm9+=R(b[0][3])*R(a[0][0]) R(b[0][3])*I(a[0][0]) - // R(b[0][3])*R(a[1][0]) R(b[0][3])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm9,ymm2,(double const *)(temp_b+ldb*3)); - - //Calculating using imaginary part of complex numbers in B matrix - //Shuffling ymm0 in accordance to the requirement - NEG_PERM_M_FRINGE(ymm0,ymm2); - - // ymm3+=I(b[0][0])*R(a[0][0]) I(b[0][0])*I(a[0][0]) - // I(b[0][0])*R(a[1][0]) I(b[0][0])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)+1); - //ymm5+=I(b[0][1])*R(a[0][0]) I(b[0][1])*I(a[0][0]) - // I(b[0][1])*R(a[1][0]) I(b[0][1])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm5,ymm2,(double const *)(temp_b+ldb)+1); - //ymm7+=I(b[0][2])*R(a[0][0]) I(b[0][2])*I(a[0][0]) - // I(b[0][2])*R(a[1][0]) I(b[0][2])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm7,ymm2,(double const *)(temp_b+ldb*2)+1); - //ymm9+=I(b[0][3])*R(a[0][0]) I(b[0][3])*I(a[0][0]) - // I(b[0][3])*R(a[1][0]) I(b[0][3])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm9,ymm2,(double const *)(temp_b+ldb*3)+1); - - } - if(beta_valr != 0.0) - { - ymm15 = _mm256_broadcast_sd((double const *)(&beta_valr)); + ymm15 = _mm256_broadcast_sd((double const *)(&beta_real)); xmm5 = _mm_loadu_pd((double const*)(temp_c));//R(c[0][0]) I(c[0][0]) ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][0]) I(c[0][0]) - //ymm3+=beta_valr*R(c[0][0]) beta_valr*I(c[0][0]) + //ymm3+=beta_real*R(c[0][0]) beta_real*I(c[0][0]) SCALE_BETA_REAL_M_FRINGE(ymm0,ymm3,ymm15); xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc));//R(c[0][1]) I(c[0][1]) ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][1]) I(c[0][1]) - //ymm5+=beta_valr*R(c[0][1]) beta_valr*I(c[0][1]) + //ymm5+=beta_real*R(c[0][1]) beta_real*I(c[0][1]) SCALE_BETA_REAL_M_FRINGE(ymm0,ymm5,ymm15); xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc * 2));//R(c[0][2]) I(c[0][2]) ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][2]) I(c[0][2]) - //ymm7+=beta_valr*R(c[0][2]) beta_valr*I(c[0][2]) + //ymm7+=beta_real*R(c[0][2]) beta_real*I(c[0][2]) SCALE_BETA_REAL_M_FRINGE(ymm0,ymm7,ymm15); xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc * 3));//R(c[0][3]) I(c[0][3]) ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][3]) I(c[0][3]) - //ymm9+=beta_valr*R(c[0][3]) beta_valr*I(c[0][3]) + //ymm9+=beta_real*R(c[0][3]) beta_real*I(c[0][3]) SCALE_BETA_REAL_M_FRINGE(ymm0,ymm9,ymm15); } - if(beta_vali != 0.0) + + if(beta_imag != 0.0) { - ymm15 = _mm256_broadcast_sd((double const *)(&beta_vali)); + ymm15 = _mm256_broadcast_sd((double const *)(&beta_imag)); xmm5 = _mm_loadu_pd((double const*)(temp_c));//R(c[0][0]) I(c[0][0]) ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][0]) I(c[0][0]) - //ymm3+=beta_vali*(-I(c[0][0])) beta_vali*R(c[0][0]) + //ymm3+=beta_imag*(-I(c[0][0])) beta_imag*R(c[0][0]) SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm3,ymm15,ymm2); xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc));//R(c[0][1]) I(c[0][1]) ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][1]) I(c[0][1]) - //ymm5+=beta_vali*(-I(c[0][1])) beta_vali*R(c[0][1]) + //ymm5+=beta_imag*(-I(c[0][1])) 
beta_imag*R(c[0][1]) SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm5,ymm15,ymm2); xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc * 2));//R(c[0][2]) I(c[0][2]) ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][2]) I(c[0][2]) - //ymm7+=beta_vali*(-I(c[0][2])) beta_vali*R(c[0][2]) + //ymm7+=beta_imag*(-I(c[0][2])) beta_imag*R(c[0][2]) SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm7,ymm15,ymm2); xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc * 3));//R(c[0][3]) I(c[0][3]) ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][3]) I(c[0][3]) - //ymm9+=beta_vali*(-I(c[0][3])) beta_vali*R(c[0][3]) + //ymm9+=beta_imag*(-I(c[0][3])) beta_imag*R(c[0][3]) SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm9,ymm15,ymm2); - } xmm5 = _mm256_extractf128_pd(ymm3, 0); @@ -1207,130 +1265,125 @@ void bli_zgemm_ref_k1_nn ymm5 = _mm256_setzero_pd(); ymm6 = _mm256_setzero_pd(); - if(alpha_valr != 0.0 || alpha_vali != 0.0) - { - /* - a. Perform alpha*A*B using temp_a, temp_b and alpha_valr, alpha_vali - where alpha_valr and/or alpha_vali is not zero. - b. This loop operates with 4x5 block size - along n dimension for every Z_NR columns of temp_b where - computing all Z_MR rows of temp_a. - c. Same approach is used in remaining fringe cases. - */ - - //R(a[0][0]) I(a[0][0]) R(a[1][0]) I(a[1][0]) - ymm0 = _mm256_loadu_pd((double const *)(temp_a)); - //R(a[2][0]) I(a[2][0]) R(a[3][0]) I(a[3][0]) - ymm1 = _mm256_loadu_pd((double const *)(temp_a + 2)); - - ymm13 = ymm0; - ymm14 = ymm1; - - _mm_prefetch((char*)(temp_a) + 64, _MM_HINT_T0); - - SCALE_ALPHA_REAL_M_LOOP(ymm0,ymm1,ymm15,alpha_valr); - SCALE_ALPHA_IMAG_M_LOOP(ymm0,ymm1,ymm15,ymm2,ymm13,ymm14,alpha_vali); - - /* - The result after scaling with alpha_valr and/or alpha_vali is as follows: - For ymm0 : - R(a[0][0]) = alpha_valr*R(a[0][0])-alpha_vali*I(a[0][0]) - I(a[0][0]) = alpha_valr*I(a[0][0])+alpha_vali*R[0][0] - R(a[1][0]) = alpha_valr*R(a[1][0])-alpha_vali*I(a[1][0]) - I(a[1][0]) = alpha_valr*I(a[1][0])+alpha_vali*(R[1][0]) - - For ymm1 : - R(a[2][0]) = alpha_valr*R(a[2][0])-alpha_vali*I(a[2][0]) - I(a[2][0]) = alpha_valr*I(a[2][0])+alpha_vali*R[2][0] - R(a[3][0]) = alpha_valr*R(a[3][0])-alpha_vali*I(a[3][0]) - I(a[3][0]) = alpha_valr*I(a[3][0])+alpha_vali*(R[3][0]) - */ - - //Calculating using real part of complex number in B matrix - FMA_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm2,(double const *)(temp_b)); - FMA_M_LOOP(ymm0,ymm1,ymm5,ymm6,ymm2,(double const *)(temp_b+ldb)); - - //Calculating using imaginary part of complex numbers in B matrix - //Shuffling ymm0 and ymm1 in accordance to the requirement - NEG_PERM_M_LOOP(ymm0,ymm1,ymm2); - FMA_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm2,(double const *)(temp_b)+1); - FMA_M_LOOP(ymm0,ymm1,ymm5,ymm6,ymm2,(double const *)(temp_b+ldb)+1); - } - if(beta_valr != 0.0) + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_real, alpha_imag + where alpha_real and/or alpha_imag is not zero. + b. This loop operates with 4x6 block size + along n dimension for every Z_NR columns of temp_b where + computing all Z_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. 
+ */ + + //R(a[0][0]) I(a[0][0]) R(a[1][0]) I(a[1][0]) + ymm0 = _mm256_loadu_pd((double const *)(temp_a)); + //R(a[2][0]) I(a[2][0]) R(a[3][0]) I(a[3][0]) + ymm1 = _mm256_loadu_pd((double const *)(temp_a + 2)); + + ymm13 = ymm0; + ymm14 = ymm1; + SCALE_ALPHA_REAL_M_LOOP(ymm0,ymm1,ymm15,alpha_real); + SCALE_ALPHA_IMAG_M_LOOP(ymm0,ymm1,ymm13,ymm14,ymm15,ymm2,alpha_imag); + + /* + The result after scaling with alpha_real and/or alpha_imag is as follows: + For ymm0 : + R(a[0][0]) = alpha_real*R(a[0][0])-alpha_imag*I(a[0][0]) + I(a[0][0]) = alpha_real*I(a[0][0])+alpha_imag*R[0][0] + R(a[1][0]) = alpha_real*R(a[1][0])-alpha_imag*I(a[1][0]) + I(a[1][0]) = alpha_real*I(a[1][0])+alpha_imag*(R[1][0]) + + For ymm1 : + R(a[2][0]) = alpha_real*R(a[2][0])-alpha_imag*I(a[2][0]) + I(a[2][0]) = alpha_real*I(a[2][0])+alpha_imag*R[2][0] + R(a[3][0]) = alpha_real*R(a[3][0])-alpha_imag*I(a[3][0]) + I(a[3][0]) = alpha_real*I(a[3][0])+alpha_imag*(R[3][0]) + */ + + //Calculating using real part of complex number in B matrix + FMA_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm2,(double const *)(temp_b)); + FMA_M_LOOP(ymm0,ymm1,ymm5,ymm6,ymm2,(double const *)(temp_b+ldb)); + + //Calculating using imaginary part of complex numbers in B matrix + //Shuffling ymm0 and ymm1 in accordance to the requirement + NEG_PERM_M_LOOP(ymm0,ymm1,ymm2); + FMA_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm2,(double const *)(temp_b)+1); + FMA_M_LOOP(ymm0,ymm1,ymm5,ymm6,ymm2,(double const *)(temp_b+ldb)+1); + + /* + a. Perform beta*C using temp_c, beta_real, + where beta_real is not zero. + b. This loop operates with 4x6 block size + along n dimension for every Z_NR columns of temp_c where + computing all Z_MR rows of temp_c. + c. Accumulated alpha*A*B into registers will be added to beta*C + d. Same approach is used in remaining fringe cases. + */ + if(beta_real != 0.0) { - /* - a. Perform beta*C using temp_c, beta_valr, - where beta_valr is not zero. - b. This loop operates with 4x5 block size - along n dimension for every Z_NR columns of temp_c where - computing all Z_MR rows of temp_c. - c. Accumulated alpha*A*B into registers will be added to beta*C - d. Same approach is used in remaining fringe cases. 
- */ - ymm15 = _mm256_broadcast_sd((double const *)(&beta_valr)); + ymm15 = _mm256_broadcast_sd((double const *)(&beta_real)); //R(c[0][0]) I(c[0][0]) R(c[1][0]) I(c[1][0]) ymm0 = _mm256_loadu_pd((double const *)(temp_c)); //R(c[2][0]) I(c[2][0]) R(c[3][0]) I(c[3][0]) ymm1 = _mm256_loadu_pd((double const *)(temp_c + 2)); - //ymm3+=beta_valr*R(c[0][0]) beta_valr*I(c[0][0]) - // beta_valr*R(c[1][0]) beta_valr*I(c[1][0]) - //ymm4+=beta_valr*R(c[2][0]) beta_valr*I(c[2][0]) - // beta_valr*R(c[3][0]) beta_valr*I(c[3][0]) + //ymm3+=beta_real*R(c[0][0]) beta_real*I(c[0][0]) + // beta_real*R(c[1][0]) beta_real*I(c[1][0]) + //ymm4+=beta_real*R(c[2][0]) beta_real*I(c[2][0]) + // beta_real*R(c[3][0]) beta_real*I(c[3][0]) SCALE_BETA_REAL_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm15); //R(c[0][1]) I(c[0][1]) R(c[1][1]) I(c[1][1]) ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc)); //R(c[2][1]) I(c[2][1]) R(c[3][1]) I(c[3][1]) ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc + 2)); - //ymm5+=beta_valr*R(c[0][1]) beta_valr*I(c[0][1]) - // beta_valr*R(c[1][1]) beta_valr*I(c[1][1]) - //ymm6+=beta_valr*R(c[2][1]) beta_valr*I(c[2][1]) - // beta_valr*R(c[3][1]) beta_valr*I(c[3][1]) + //ymm5+=beta_real*R(c[0][1]) beta_real*I(c[0][1]) + // beta_real*R(c[1][1]) beta_real*I(c[1][1]) + //ymm6+=beta_real*R(c[2][1]) beta_real*I(c[2][1]) + // beta_real*R(c[3][1]) beta_real*I(c[3][1]) SCALE_BETA_REAL_M_LOOP(ymm0,ymm1,ymm5,ymm6,ymm15); - } - if(beta_vali != 0.0) + + /* + a. Perform beta*C using temp_c, beta_imag, + where beta_imag is not zero. + b. This loop operates with 4x6 block size + along n dimension for every Z_NR columns of temp_c where + computing all Z_MR rows of temp_c. + c. Accumulated alpha*A*B into registers will be added to beta*C + d. Same approach is used in remaining fringe cases. + */ + + if(beta_imag != 0.0) { - /* - a. Perform beta*C using temp_c, beta_vali, - where beta_vali is not zero. - b. This loop operates with 4x5 block size - along n dimension for every Z_NR columns of temp_c where - computing all Z_MR rows of temp_c. - c. Accumulated alpha*A*B into registers will be added to beta*C - d. Same approach is used in remaining fringe cases. 
- */ - - ymm15 = _mm256_broadcast_sd((double const *)(&beta_vali)); + ymm15 = _mm256_broadcast_sd((double const *)(&beta_imag)); //R(c[0][0]) I(c[0][0]) R(c[1][0]) I(c[1][0]) ymm0 = _mm256_loadu_pd((double const *)(temp_c)); //R(c[2][0]) I(c[2][0]) R(c[3][0]) I(c[3][0]) ymm1 = _mm256_loadu_pd((double const *)(temp_c + 2)); - //ymm3+=beta_vali*(-I(c[0][0])) beta_vali*R(c[0][0]) - // beta_vali*(-I(c[1][0])) beta_vali*R(c[1][0]) - //ymm4+=beta_vali*(-I(c[2][0])) beta_vali*R(c[2][0]) - // beta_vali*(-I(c[3][0])) beta_vali*R(c[3][0]) + //ymm3+=beta_imag*(-I(c[0][0])) beta_imag*R(c[0][0]) + // beta_imag*(-I(c[1][0])) beta_imag*R(c[1][0]) + //ymm4+=beta_imag*(-I(c[2][0])) beta_imag*R(c[2][0]) + // beta_imag*(-I(c[3][0])) beta_imag*R(c[3][0]) SCALE_BETA_IMAG_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm15,ymm2); //R(c[0][1]) I(c[0][1]) R(c[1][1]) I(c[1][1]) ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc)); //R(c[2][1]) I(c[2][1]) R(c[3][1]) I(c[3][1]) ymm1 = _mm256_loadu_pd((double const *)(temp_c + ldc + 2)); - //ymm5+=beta_vali*(-I(c[0][1])) beta_vali*R(c[0][1]) - // beta_vali*(-I(c[1][1])) beta_vali*R(c[1][1]) - //ymm6+=beta_vali*(-I(c[2][1])) beta_vali*R(c[2][1]) - // beta_vali*(-I(c[3][1])) beta_vali*R(c[3][1]) + //ymm5+=beta_imag*(-I(c[0][1])) beta_imag*R(c[0][1]) + // beta_imag*(-I(c[1][1])) beta_imag*R(c[1][1]) + //ymm6+=beta_imag*(-I(c[2][1])) beta_imag*R(c[2][1]) + // beta_imag*(-I(c[3][1])) beta_imag*R(c[3][1]) SCALE_BETA_IMAG_M_LOOP(ymm0,ymm1,ymm5,ymm6,ymm15,ymm2); } /* The scaling has been done sequentially as follows: - - If alpha_valr is not 0, it is used for scaling A - - If alpha_vali is not 0, it is used for scaling A using permutation + - If alpha_real is not 0, it is used for scaling A + - If alpha_imag is not 0, it is used for scaling A using permutation and selective negation, after loading - - If beta_valr is not 0, is is used for scaling C - - If beta_vali is not 0, it is used for scaling C using permutation + - If beta_real is not 0, is is used for scaling C + - If beta_imag is not 0, it is used for scaling C using permutation and selective negation, after loading The results are accumalated in accordance to the non zero scalar values, @@ -1353,88 +1406,83 @@ void bli_zgemm_ref_k1_nn ymm3 = _mm256_setzero_pd(); ymm5 = _mm256_setzero_pd(); - if(alpha_valr != 0.0 || alpha_vali != 0.0) - { - //R(a[0][0]) I(a[0][0]) R(a[1][0]) I(a[1][0]) - ymm0 = _mm256_loadu_pd((double const *)(temp_a)); - ymm13 = ymm0; - SCALE_ALPHA_REAL_M_FRINGE(ymm0,ymm15,alpha_valr); - SCALE_ALPHA_IMAG_M_FRINGE(ymm0,ymm15,ymm2,ymm13,alpha_vali); - /* - The result after scaling with alpha_valr and/or alpha_vali is as follows: - For ymm0 : - R(a[0][0]) = alpha_valr*R(a[0][0])-alpha_vali*I(a[0][0]) - I(a[0][0]) = alpha_valr*I(a[0][0])+alpha_vali*R[0][0] - R(a[1][0]) = alpha_valr*R(a[1][0])-alpha_vali*I(a[1][0]) - I(a[1][0]) = alpha_valr*I(a[1][0])+alpha_vali*(R[1][0]) - */ - - //Calculating using real part of complex number in B matrix - //ymm3+=R(b[0][0])*R(a[0][0]) R(b[0][0])*I(a[0][0]) - // R(b[0][0])*R(a[1][0]) R(b[0][0])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)); - //ymm5+=R(b[0][1])*R(a[0][0]) R(b[0][1])*I(a[0][0]) - // R(b[0][1])*R(a[1][0]) R(b[0][1])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm5,ymm2,(double const *)(temp_b+ldb)); - - //Calculating using imaginary part of complex numbers in B matrix - //Shuffling ymm0 in accordance to the requirement - NEG_PERM_M_FRINGE(ymm0,ymm2); - - // ymm3+=I(b[0][0])*R(a[0][0]) I(b[0][0])*I(a[0][0]) - // I(b[0][0])*R(a[1][0]) I(b[0][0])*I(a[1][0]) - 
FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)+1); - //ymm5+=I(b[0][1])*R(a[0][0]) I(b[0][1])*I(a[0][0]) - // I(b[0][1])*R(a[1][0]) I(b[0][1])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm5,ymm2,(double const *)(temp_b+ldb)+1); + //R(a[0][0]) I(a[0][0]) R(a[1][0]) I(a[1][0]) + ymm0 = _mm256_loadu_pd((double const *)(temp_a)); - } + ymm13 = ymm0; + SCALE_ALPHA_REAL_M_FRINGE(ymm0,ymm15,alpha_real); + SCALE_ALPHA_IMAG_M_FRINGE(ymm0,ymm13,ymm15,ymm2,alpha_imag); + /* + The result after scaling with alpha_real and/or alpha_imag is as follows: + For ymm0 : + R(a[0][0]) = alpha_real*R(a[0][0])-alpha_imag*I(a[0][0]) + I(a[0][0]) = alpha_real*I(a[0][0])+alpha_imag*R[0][0] + R(a[1][0]) = alpha_real*R(a[1][0])-alpha_imag*I(a[1][0]) + I(a[1][0]) = alpha_real*I(a[1][0])+alpha_imag*(R[1][0]) + */ + + //Calculating using real part of complex number in B matrix + //ymm3+=R(b[0][0])*R(a[0][0]) R(b[0][0])*I(a[0][0]) + // R(b[0][0])*R(a[1][0]) R(b[0][0])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)); + //ymm5+=R(b[0][1])*R(a[0][0]) R(b[0][1])*I(a[0][0]) + // R(b[0][1])*R(a[1][0]) R(b[0][1])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm5,ymm2,(double const *)(temp_b+ldb)); + + //Calculating using imaginary part of complex numbers in B matrix + //Shuffling ymm0 in accordance to the requirement + NEG_PERM_M_FRINGE(ymm0,ymm2); + + // ymm3+=I(b[0][0])*R(a[0][0]) I(b[0][0])*I(a[0][0]) + // I(b[0][0])*R(a[1][0]) I(b[0][0])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)+1); + //ymm5+=I(b[0][1])*R(a[0][0]) I(b[0][1])*I(a[0][0]) + // I(b[0][1])*R(a[1][0]) I(b[0][1])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm5,ymm2,(double const *)(temp_b+ldb)+1); - if(beta_valr != 0.0) - { - ymm15 = _mm256_broadcast_sd((double const *)(&beta_valr)); + if(beta_real != 0.0) + { + ymm15 = _mm256_broadcast_sd((double const *)(&beta_real)); //R(c[0][0]) I(c[0][0]) R(c[1][0]) I(c[1][0]) ymm0 = _mm256_loadu_pd((double const *)(temp_c)); - //ymm3+=beta_valr*R(c[0][0]) beta_valr*I(c[0][0]) - // beta_valr*R(c[1][0]) beta_valr*I(c[1][0]) + //ymm3+=beta_real*R(c[0][0]) beta_real*I(c[0][0]) + // beta_real*R(c[1][0]) beta_real*I(c[1][0]) SCALE_BETA_REAL_M_FRINGE(ymm0,ymm3,ymm15); //R(c[0][1]) I(c[0][1]) R(c[1][1]) I(c[1][1]) ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc)); - //ymm5+=beta_valr*R(c[0][1]) beta_valr*I(c[0][1]) - // beta_valr*R(c[1][1]) beta_valr*I(c[1][1]) + //ymm5+=beta_real*R(c[0][1]) beta_real*I(c[0][1]) + // beta_real*R(c[1][1]) beta_real*I(c[1][1]) SCALE_BETA_REAL_M_FRINGE(ymm0,ymm5,ymm15); - } - if(beta_vali != 0.0) + if(beta_imag != 0.0) { - - ymm15 = _mm256_broadcast_sd((double const *)(&beta_vali)); + ymm15 = _mm256_broadcast_sd((double const *)(&beta_imag)); //R(c[0][0]) I(c[0][0]) R(c[1][0]) I(c[1][0]) ymm0 = _mm256_loadu_pd((double const *)(temp_c)); - //ymm3+=beta_vali*(-I(c[0][0])) beta_vali*R(c[0][0]) - // beta_vali*(-I(c[1][0])) beta_vali*R(c[1][0]) + //ymm3+=beta_imag*(-I(c[0][0])) beta_imag*R(c[0][0]) + // beta_imag*(-I(c[1][0])) beta_imag*R(c[1][0]) SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm3,ymm15,ymm2); //R(c[0][1]) I(c[0][1]) R(c[1][1]) I(c[1][1]) ymm0 = _mm256_loadu_pd((double const *)(temp_c + ldc)); - //ymm5+=beta_vali*(-I(c[0][1])) beta_vali*R(c[0][1]) - // beta_vali*(-I(c[1][1])) beta_vali*R(c[1][1]) + //ymm5+=beta_imag*(-I(c[0][1])) beta_imag*R(c[0][1]) + // beta_imag*(-I(c[1][1])) beta_imag*R(c[1][1]) SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm5,ymm15,ymm2); } /* The scaling has been done sequentially as follows: - - If alpha_valr is not 0, it is used for scaling A - - If alpha_vali is not 0, it is 
used for scaling A using permutation + - If alpha_real is not 0, it is used for scaling A + - If alpha_imag is not 0, it is used for scaling A using permutation and selective negation, after loading - - If beta_valr is not 0, is is used for scaling C - - If beta_vali is not 0, it is used for scaling C using permutation + - If beta_real is not 0, is is used for scaling C + - If beta_imag is not 0, it is used for scaling C using permutation and selective negation, after loading The results are accumalated in accordance to the non zero scalar values, @@ -1457,63 +1505,60 @@ void bli_zgemm_ref_k1_nn ymm3 = _mm256_setzero_pd(); ymm5 = _mm256_setzero_pd(); - if(alpha_valr != 0.0 || alpha_vali != 0.0) + xmm5 = _mm_loadu_pd((double const*)(temp_a));//R(a[0][0]) I(a[0][0]) + ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(a[0][0]) I(a[0][0]) + + ymm13 = ymm0; + SCALE_ALPHA_REAL_M_FRINGE(ymm0,ymm15,alpha_real); + SCALE_ALPHA_IMAG_M_FRINGE(ymm0,ymm13,ymm15,ymm2,alpha_imag); + + //Calculating using real part of complex number in B matrix + //ymm3+=R(b[0][0])*R(a[0][0]) R(b[0][0])*I(a[0][0]) + // R(b[0][0])*R(a[1][0]) R(b[0][0])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)); + //ymm5+=R(b[0][1])*R(a[0][0]) R(b[0][1])*I(a[0][0]) + // R(b[0][1])*R(a[1][0]) R(b[0][1])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm5,ymm2,(double const *)(temp_b+ldb)); + + //Calculating using imaginary part of complex numbers in B matrix + //Shuffling ymm0 in accordance to the requirement + NEG_PERM_M_FRINGE(ymm0,ymm2); + + // ymm3+=I(b[0][0])*R(a[0][0]) I(b[0][0])*I(a[0][0]) + // I(b[0][0])*R(a[1][0]) I(b[0][0])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)+1); + //ymm5+=I(b[0][1])*R(a[0][0]) I(b[0][1])*I(a[0][0]) + // I(b[0][1])*R(a[1][0]) I(b[0][1])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm5,ymm2,(double const *)(temp_b+ldb)+1); + + if(beta_real != 0.0) { - xmm5 = _mm_loadu_pd((double const*)(temp_a));//R(a[0][0]) I(a[0][0]) - ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(a[0][0]) I(a[0][0]) - ymm13 = ymm0; - - SCALE_ALPHA_REAL_M_FRINGE(ymm0,ymm15,alpha_valr); - SCALE_ALPHA_IMAG_M_FRINGE(ymm0,ymm15,ymm2,ymm13,alpha_vali); - - //Calculating using real part of complex number in B matrix - //ymm3+=R(b[0][0])*R(a[0][0]) R(b[0][0])*I(a[0][0]) - // R(b[0][0])*R(a[1][0]) R(b[0][0])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)); - //ymm5+=R(b[0][1])*R(a[0][0]) R(b[0][1])*I(a[0][0]) - // R(b[0][1])*R(a[1][0]) R(b[0][1])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm5,ymm2,(double const *)(temp_b+ldb)); - - //Calculating using imaginary part of complex numbers in B matrix - //Shuffling ymm0 in accordance to the requirement - NEG_PERM_M_FRINGE(ymm0,ymm2); - - // ymm3+=I(b[0][0])*R(a[0][0]) I(b[0][0])*I(a[0][0]) - // I(b[0][0])*R(a[1][0]) I(b[0][0])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)+1); - //ymm5+=I(b[0][1])*R(a[0][0]) I(b[0][1])*I(a[0][0]) - // I(b[0][1])*R(a[1][0]) I(b[0][1])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm5,ymm2,(double const *)(temp_b+ldb)+1); - - } - if(beta_valr != 0.0) - { - ymm15 = _mm256_broadcast_sd((double const *)(&beta_valr)); + ymm15 = _mm256_broadcast_sd((double const *)(&beta_real)); xmm5 = _mm_loadu_pd((double const*)(temp_c));//R(c[0][0]) I(c[0][0]) ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][0]) I(c[0][0]) - //ymm3+=beta_valr*R(c[0][0]) beta_valr*I(c[0][0]) + //ymm3+=beta_real*R(c[0][0]) beta_real*I(c[0][0]) SCALE_BETA_REAL_M_FRINGE(ymm0,ymm3,ymm15); xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc));//R(c[0][1]) I(c[0][1]) ymm0 = 
_mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][1]) I(c[0][1]) - //ymm5+=beta_valr*R(c[0][1]) beta_valr*I(c[0][1]) + //ymm5+=beta_real*R(c[0][1]) beta_real*I(c[0][1]) SCALE_BETA_REAL_M_FRINGE(ymm0,ymm5,ymm15); } - if(beta_vali != 0.0) + + if(beta_imag != 0.0) { - ymm15 = _mm256_broadcast_sd((double const *)(&beta_vali)); + ymm15 = _mm256_broadcast_sd((double const *)(&beta_imag)); xmm5 = _mm_loadu_pd((double const*)(temp_c));//R(c[0][0]) I(c[0][0]) ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][0]) I(c[0][0]) - //ymm3+=beta_vali*(-I(c[0][0])) beta_vali*R(c[0][0]) + //ymm3+=beta_imag*(-I(c[0][0])) beta_imag*R(c[0][0]) SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm3,ymm15,ymm2); xmm5 = _mm_loadu_pd((double const*)(temp_c + ldc));//R(c[0][1]) I(c[0][1]) ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0);//R(c[0][1]) I(c[0][1]) - //ymm5+=beta_vali*(-I(c[0][1])) beta_vali*R(c[0][1]) + //ymm5+=beta_imag*(-I(c[0][1])) beta_imag*R(c[0][1]) SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm5,ymm15,ymm2); - } xmm5 = _mm256_extractf128_pd(ymm3, 0); @@ -1531,111 +1576,108 @@ void bli_zgemm_ref_k1_nn dcomplex* temp_a = a; dcomplex* temp_c = c + (n - n_remainder)*ldc; + // Main loop for M for(dim_t i = 0;i < (m-Z_MR+1);i=i+Z_MR) { ymm3 = _mm256_setzero_pd(); ymm4 = _mm256_setzero_pd(); - if(alpha_valr != 0.0 || alpha_vali != 0.0) - { - /* - a. Perform alpha*A*B using temp_a, temp_b and alpha_valr, aplha_vali - where alpha_valr and/or alpha_vali is not zero. - b. This loop operates with 4x5 block size - along n dimension for every Z_NR columns of temp_b where - computing all Z_MR rows of temp_a. - c. Same approach is used in remaining fringe cases. - */ - - //R(a[0][0]) I(a[0][0]) R(a[1][0]) I(a[1][0]) - ymm0 = _mm256_loadu_pd((double const *)(temp_a)); - //R(a[2][0]) I(a[2][0]) R(a[3][0]) I(a[3][0]) - ymm1 = _mm256_loadu_pd((double const *)(temp_a + 2)); - - ymm13 = ymm0; - ymm14 = ymm1; - _mm_prefetch((char*)(temp_a) + 64, _MM_HINT_T0); - - SCALE_ALPHA_REAL_M_LOOP(ymm0,ymm1,ymm15,alpha_valr); - SCALE_ALPHA_IMAG_M_LOOP(ymm0,ymm1,ymm15,ymm2,ymm13,ymm14,alpha_vali); - - /* - The result after scaling with alpha_valr and/or alpha_vali is as follows: - For ymm0 : - R(a[0][0]) = alpha_valr*R(a[0][0])-alpha_vali*I(a[0][0]) - I(a[0][0]) = alpha_valr*I(a[0][0])+alpha_vali*R[0][0] - R(a[1][0]) = alpha_valr*R(a[1][0])-alpha_vali*I(a[1][0]) - I(a[1][0]) = alpha_valr*I(a[1][0])+alpha_vali*(R[1][0]) - - For ymm1 : - R(a[2][0]) = alpha_valr*R(a[2][0])-alpha_vali*I(a[2][0]) - I(a[2][0]) = alpha_valr*I(a[2][0])+alpha_vali*R[2][0] - R(a[3][0]) = alpha_valr*R(a[3][0])-alpha_vali*I(a[3][0]) - I(a[3][0]) = alpha_valr*I(a[3][0])+alpha_vali*(R[3][0]) - */ - - //Calculating using real part of complex number in B matrix - FMA_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm2,(double const *)(temp_b)); - - //Calculating using imaginary part of complex numbers in B matrix - //Shuffling ymm0 and ymm1 in accordance to the requirement - NEG_PERM_M_LOOP(ymm0,ymm1,ymm2); - FMA_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm2,(double const *)(temp_b)+1); - } - if(beta_valr != 0.0) + /* + a. Perform alpha*A*B using temp_a, temp_b and alpha_real, aplha_vali + where alpha_real and/or alpha_imag is not zero. + b. This loop operates with 4x6 block size + along n dimension for every Z_NR columns of temp_b where + computing all Z_MR rows of temp_a. + c. Same approach is used in remaining fringe cases. 
+ */ + + //R(a[0][0]) I(a[0][0]) R(a[1][0]) I(a[1][0]) + ymm0 = _mm256_loadu_pd((double const *)(temp_a)); + //R(a[2][0]) I(a[2][0]) R(a[3][0]) I(a[3][0]) + ymm1 = _mm256_loadu_pd((double const *)(temp_a + 2)); + + ymm13 = ymm0; + ymm14 = ymm1; + SCALE_ALPHA_REAL_M_LOOP(ymm0,ymm1,ymm15,alpha_real); + SCALE_ALPHA_IMAG_M_LOOP(ymm0,ymm1,ymm13,ymm14,ymm15,ymm2,alpha_imag); + + /* + The result after scaling with alpha_real and/or alpha_imag is as follows: + For ymm0 : + R(a[0][0]) = alpha_real*R(a[0][0])-alpha_imag*I(a[0][0]) + I(a[0][0]) = alpha_real*I(a[0][0])+alpha_imag*R[0][0] + R(a[1][0]) = alpha_real*R(a[1][0])-alpha_imag*I(a[1][0]) + I(a[1][0]) = alpha_real*I(a[1][0])+alpha_imag*(R[1][0]) + + For ymm1 : + R(a[2][0]) = alpha_real*R(a[2][0])-alpha_imag*I(a[2][0]) + I(a[2][0]) = alpha_real*I(a[2][0])+alpha_imag*R[2][0] + R(a[3][0]) = alpha_real*R(a[3][0])-alpha_imag*I(a[3][0]) + I(a[3][0]) = alpha_real*I(a[3][0])+alpha_imag*(R[3][0]) + */ + + //Calculating using real part of complex number in B matrix + FMA_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm2,(double const *)(temp_b)); + + //Calculating using imaginary part of complex numbers in B matrix + //Shuffling ymm0 and ymm1 in accordance to the requirement + NEG_PERM_M_LOOP(ymm0,ymm1,ymm2); + FMA_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm2,(double const *)(temp_b)+1); + + /* + a. Perform beta*C using temp_c, beta_real, + where beta_real is not zero. + b. This loop operates with 4x6 block size + along n dimension for every Z_NR columns of temp_c where + computing all Z_MR rows of temp_c. + c. Accumulated alpha*A*B into registers will be added to beta*C + d. Same approach is used in remaining fringe cases. + */ + if(beta_real != 0.0) { - /* - a. Perform beta*C using temp_c, beta_valr, - where beta_valr is not zero. - b. This loop operates with 4x5 block size - along n dimension for every Z_NR columns of temp_c where - computing all Z_MR rows of temp_c. - c. Accumulated alpha*A*B into registers will be added to beta*C - d. Same approach is used in remaining fringe cases. - */ - ymm15 = _mm256_broadcast_sd((double const *)(&beta_valr)); + ymm15 = _mm256_broadcast_sd((double const *)(&beta_real)); //R(c[0][0]) I(c[0][0]) R(c[1][0]) I(c[1][0]) ymm0 = _mm256_loadu_pd((double const *)(temp_c)); //R(c[2][0]) I(c[2][0]) R(c[3][0]) I(c[3][0]) ymm1 = _mm256_loadu_pd((double const *)(temp_c + 2)); - //ymm3+=beta_valr*R(c[0][0]) beta_valr*I(c[0][0]) - // beta_valr*R(c[1][0]) beta_valr*I(c[1][0]) - //ymm4+=beta_valr*R(c[2][0]) beta_valr*I(c[2][0]) - // beta_valr*R(c[3][0]) beta_valr*I(c[3][0]) + //ymm3+=beta_real*R(c[0][0]) beta_real*I(c[0][0]) + // beta_real*R(c[1][0]) beta_real*I(c[1][0]) + //ymm4+=beta_real*R(c[2][0]) beta_real*I(c[2][0]) + // beta_real*R(c[3][0]) beta_real*I(c[3][0]) SCALE_BETA_REAL_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm15); - } - if(beta_vali != 0.0) + + /* + a. Perform beta*C using temp_c, beta_imag, + where beta_imag is not zero. + b. This loop operates with 4x6 block size + along n dimension for every Z_NR columns of temp_c where + computing all Z_MR rows of temp_c. + c. Accumulated alpha*A*B into registers will be added to beta*C + d. Same approach is used in remaining fringe cases. + */ + + if(beta_imag != 0.0) { - /* - a. Perform beta*C using temp_c, beta_vali, - where beta_vali is not zero. - b. This loop operates with 4x5 block size - along n dimension for every Z_NR columns of temp_c where - computing all Z_MR rows of temp_c. - c. Accumulated alpha*A*B into registers will be added to beta*C - d. Same approach is used in remaining fringe cases. 
- */ - - ymm15 = _mm256_broadcast_sd((double const *)(&beta_vali)); + ymm15 = _mm256_broadcast_sd((double const *)(&beta_imag)); ymm0 = _mm256_loadu_pd((double const *)(temp_c)); ymm1 = _mm256_loadu_pd((double const *)(temp_c + 2)); - //ymm3+=beta_vali*(-I(c[0][0])) beta_vali*R(c[0][0]) - // beta_vali*(-I(c[1][0])) beta_vali*R(c[1][0]) - //ymm4+=beta_vali*(-I(c[2][0])) beta_vali*R(c[2][0]) - // beta_vali*(-I(c[3][0])) beta_vali*R(c[3][0]) + //ymm3+=beta_imag*(-I(c[0][0])) beta_imag*R(c[0][0]) + // beta_imag*(-I(c[1][0])) beta_imag*R(c[1][0]) + //ymm4+=beta_imag*(-I(c[2][0])) beta_imag*R(c[2][0]) + // beta_imag*(-I(c[3][0])) beta_imag*R(c[3][0]) SCALE_BETA_IMAG_M_LOOP(ymm0,ymm1,ymm3,ymm4,ymm15,ymm2); } /* The scaling has been done sequentially as follows: - - If alpha_valr is not 0, it is used for scaling A - - If alpha_vali is not 0, it is used for scaling A using permutation + - If alpha_real is not 0, it is used for scaling A + - If alpha_imag is not 0, it is used for scaling A using permutation and selective negation, after loading - - If beta_valr is not 0, is is used for scaling C - - If beta_vali is not 0, it is used for scaling C using permutation + - If beta_real is not 0, is is used for scaling C + - If beta_imag is not 0, it is used for scaling C using permutation and selective negation, after loading The results are accumalated in accordance to the non zero scalar values, @@ -1651,73 +1693,69 @@ void bli_zgemm_ref_k1_nn temp_a+=Z_MR; } + // Fringe cases for M dim_t m_rem=m_remainder; if(m_rem>=2) { ymm3 = _mm256_setzero_pd(); - if(alpha_valr != 0.0 || alpha_vali != 0.0) - { - //R(a[0][0]) I(a[0][0]) R(a[1][0]) I(a[1][0]) - ymm0 = _mm256_loadu_pd((double const *)(temp_a)); - ymm13 = ymm0; - - SCALE_ALPHA_REAL_M_FRINGE(ymm0,ymm15,alpha_valr); - SCALE_ALPHA_IMAG_M_FRINGE(ymm0,ymm15,ymm2,ymm13,alpha_vali); - - /* - The result after scaling with alpha_valr and/or alpha_vali is as follows: - For ymm0 : - R(a[0][0]) = alpha_valr*R(a[0][0])-alpha_vali*I(a[0][0]) - I(a[0][0]) = alpha_valr*I(a[0][0])+alpha_vali*R[0][0] - R(a[1][0]) = alpha_valr*R(a[1][0])-alpha_vali*I(a[1][0]) - I(a[1][0]) = alpha_valr*I(a[1][0])+alpha_vali*(R[1][0]) - */ - - //Calculating using real part of complex number in B matrix - //ymm3+=R(b[0][0])*R(a[0][0]) R(b[0][0])*I(a[0][0]) - // R(b[0][0])*R(a[1][0]) R(b[0][0])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)); - - //Calculating using imaginary part of complex numbers in B matrix - //Shuffling ymm0 in accordance to the requirement - NEG_PERM_M_FRINGE(ymm0,ymm2); - - // ymm3+=I(b[0][0])*R(a[0][0]) I(b[0][0])*I(a[0][0]) - // I(b[0][0])*R(a[1][0]) I(b[0][0])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)+1); - } + //R(a[0][0]) I(a[0][0]) R(a[1][0]) I(a[1][0]) + ymm0 = _mm256_loadu_pd((double const *)(temp_a)); + + ymm13 = ymm0; + SCALE_ALPHA_REAL_M_FRINGE(ymm0,ymm15,alpha_real); + SCALE_ALPHA_IMAG_M_FRINGE(ymm0,ymm13,ymm15,ymm2,alpha_imag); + + /* + The result after scaling with alpha_real and/or alpha_imag is as follows: + For ymm0 : + R(a[0][0]) = alpha_real*R(a[0][0])-alpha_imag*I(a[0][0]) + I(a[0][0]) = alpha_real*I(a[0][0])+alpha_imag*R[0][0] + R(a[1][0]) = alpha_real*R(a[1][0])-alpha_imag*I(a[1][0]) + I(a[1][0]) = alpha_real*I(a[1][0])+alpha_imag*(R[1][0]) + */ + + //Calculating using real part of complex number in B matrix + //ymm3+=R(b[0][0])*R(a[0][0]) R(b[0][0])*I(a[0][0]) + // R(b[0][0])*R(a[1][0]) R(b[0][0])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)); + + //Calculating using imaginary part 
of complex numbers in B matrix + //Shuffling ymm0 in accordance to the requirement + NEG_PERM_M_FRINGE(ymm0,ymm2); + + // ymm3+=I(b[0][0])*R(a[0][0]) I(b[0][0])*I(a[0][0]) + // I(b[0][0])*R(a[1][0]) I(b[0][0])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)+1); - if(beta_valr != 0.0) - { - ymm15 = _mm256_broadcast_sd((double const *)(&beta_valr)); + if(beta_real != 0.0) + { + ymm15 = _mm256_broadcast_sd((double const *)(&beta_real)); ymm0 = _mm256_loadu_pd((double const *)(temp_c)); - //ymm3+=beta_valr*R(c[0][0]) beta_valr*I(c[0][0]) - // beta_valr*R(c[1][0]) beta_valr*I(c[1][0]) + //ymm3+=beta_real*R(c[0][0]) beta_real*I(c[0][0]) + // beta_real*R(c[1][0]) beta_real*I(c[1][0]) SCALE_BETA_REAL_M_FRINGE(ymm0,ymm3,ymm15); } - if(beta_vali != 0.0) + if(beta_imag != 0.0) { - - ymm15 = _mm256_broadcast_sd((double const *)(&beta_vali)); + ymm15 = _mm256_broadcast_sd((double const *)(&beta_imag)); ymm0 = _mm256_loadu_pd((double const *)(temp_c)); - //ymm3+=beta_vali*(-I(c[0][0])) beta_vali*R(c[0][0]) - // beta_vali*(-I(c[1][0])) beta_vali*R(c[1][0]) + //ymm3+=beta_imag*(-I(c[0][0])) beta_imag*R(c[0][0]) + // beta_imag*(-I(c[1][0])) beta_imag*R(c[1][0]) SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm3,ymm15,ymm2); } - /* The scaling has been done sequentially as follows: - - If alpha_valr is not 0, it is used for scaling A - - If alpha_vali is not 0, it is used for scaling A using permutation + - If alpha_real is not 0, it is used for scaling A + - If alpha_imag is not 0, it is used for scaling A using permutation and selective negation, after loading - - If beta_valr is not 0, is is used for scaling C - - If beta_vali is not 0, it is used for scaling C using permutation + - If beta_real is not 0, is is used for scaling C + - If beta_imag is not 0, it is used for scaling C using permutation and selective negation, after loading The results are accumalated in accordance to the non zero scalar values, @@ -1738,46 +1776,44 @@ void bli_zgemm_ref_k1_nn xmm5 = _mm_setzero_pd(); ymm3 = _mm256_setzero_pd(); - if(alpha_valr != 0.0 || alpha_vali != 0.0) - { - xmm5 = _mm_loadu_pd((double const*)(temp_a)); - ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0); - ymm13 = ymm0; + xmm5 = _mm_loadu_pd((double const*)(temp_a)); + ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0); - SCALE_ALPHA_REAL_M_FRINGE(ymm0,ymm15,alpha_valr); - SCALE_ALPHA_IMAG_M_FRINGE(ymm0,ymm15,ymm2,ymm13,alpha_vali); + ymm13 = ymm0; + SCALE_ALPHA_REAL_M_FRINGE(ymm0,ymm15,alpha_real); + SCALE_ALPHA_IMAG_M_FRINGE(ymm0,ymm13,ymm15,ymm2,alpha_imag); - //Calculating using real part of complex number in B matrix - //ymm3+=R(b[0][0])*R(a[0][0]) R(b[0][0])*I(a[0][0]) - // R(b[0][0])*R(a[1][0]) R(b[0][0])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)); + //Calculating using real part of complex number in B matrix + //ymm3+=R(b[0][0])*R(a[0][0]) R(b[0][0])*I(a[0][0]) + // R(b[0][0])*R(a[1][0]) R(b[0][0])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)); - //Calculating using imaginary part of complex numbers in B matrix - //Shuffling ymm0 in accordance to the requirement - NEG_PERM_M_FRINGE(ymm0,ymm2); + //Calculating using imaginary part of complex numbers in B matrix + //Shuffling ymm0 in accordance to the requirement + NEG_PERM_M_FRINGE(ymm0,ymm2); - // ymm3+=I(b[0][0])*R(a[0][0]) I(b[0][0])*I(a[0][0]) - // I(b[0][0])*R(a[1][0]) I(b[0][0])*I(a[1][0]) - FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)+1); - } - if(beta_valr != 0.0) + // ymm3+=I(b[0][0])*R(a[0][0]) I(b[0][0])*I(a[0][0]) + // I(b[0][0])*R(a[1][0]) 
I(b[0][0])*I(a[1][0]) + FMA_M_FRINGE(ymm0,ymm3,ymm2,(double const *)(temp_b)+1); + + if(beta_real != 0.0) { - ymm15 = _mm256_broadcast_sd((double const *)(&beta_valr)); + ymm15 = _mm256_broadcast_sd((double const *)(&beta_real)); xmm5 = _mm_loadu_pd((double const*)(temp_c)); ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0); - //ymm3+=beta_valr*R(c[0][0]) beta_valr*I(c[0][0]) + //ymm3+=beta_real*R(c[0][0]) beta_real*I(c[0][0]) SCALE_BETA_REAL_M_FRINGE(ymm0,ymm3,ymm15); } - if(beta_vali != 0.0) + + if(beta_imag != 0.0) { - ymm15 = _mm256_broadcast_sd((double const *)(&beta_vali)); + ymm15 = _mm256_broadcast_sd((double const *)(&beta_imag)); xmm5 = _mm_loadu_pd((double const*)(temp_c)); ymm0 = _mm256_insertf128_pd(ymm0,xmm5,0); - //ymm3+=beta_vali*(-I(c[0][0])) beta_vali*R(c[0][0]) + //ymm3+=beta_imag*(-I(c[0][0])) beta_imag*R(c[0][0]) SCALE_BETA_IMAG_M_FRINGE(ymm0,ymm3,ymm15,ymm2); - } xmm5 = _mm256_extractf128_pd(ymm3, 0); @@ -1787,4 +1823,4 @@ void bli_zgemm_ref_k1_nn } -} +} \ No newline at end of file From 6c1acc74c833dc3b0f713b3ca5cd8c1e63b16689 Mon Sep 17 00:00:00 2001 From: Mangala V Date: Wed, 27 Jul 2022 12:16:43 +0530 Subject: [PATCH 152/243] ZGEMM optimizations -- Conditionally packing of B matrix is enabled in zgemmsup path which is performing better when B matrix is large -- Incorporated decision logic to choose between zgemm_small vs zgemm sup based on matrix dimensions "m, n and k". -- Calling of ZGEMV when matrix dimension m or n = 1. Very good performance improvement is observed. Change-Id: I7c64020f4f78a6a51617b184cc88076213b5527d --- frame/3/bli_l3_sup_int_amd.c | 14 ++++++ frame/compat/bla_gemm_amd.c | 88 ++++++++++++++++++++++++++++++++++-- 2 files changed, 99 insertions(+), 3 deletions(-) diff --git a/frame/3/bli_l3_sup_int_amd.c b/frame/3/bli_l3_sup_int_amd.c index e00cc54ad0..bbf5637555 100644 --- a/frame/3/bli_l3_sup_int_amd.c +++ b/frame/3/bli_l3_sup_int_amd.c @@ -119,6 +119,13 @@ err_t bli_gemmsup_int bli_rntm_set_pack_b( 1, rntm ); } + /*Enable packing of B matrix for complex data type*/ + if (bli_is_dcomplex(dt) && (n_threads == 1)) + { + if ((m > 55) && (k > 55) && (n > 55)) + bli_rntm_set_pack_b(1, rntm); + } + bli_gemmsup_ref_var2m( BLIS_NO_TRANSPOSE, alpha, a, b, beta, c, stor_id, cntx, rntm, thread ); @@ -152,6 +159,13 @@ err_t bli_gemmsup_int bli_rntm_set_pack_a( 1, rntm ); } + /*Enable packing of A matrix for complex data type*/ + if (bli_is_dcomplex(dt) && (n_threads == 1)) + { + if ((m > 55) && (k > 55) && (n > 55)) + bli_rntm_set_pack_a(1, rntm); + } + bli_gemmsup_ref_var2m( BLIS_TRANSPOSE, alpha, a, b, beta, c, stor_id, cntx, rntm, thread ); diff --git a/frame/compat/bla_gemm_amd.c b/frame/compat/bla_gemm_amd.c index 8a8d576c46..d46a69f0f8 100644 --- a/frame/compat/bla_gemm_amd.c +++ b/frame/compat/bla_gemm_amd.c @@ -712,6 +712,7 @@ void zgemm_ //dim_t nt = bli_thread_get_num_threads(); // get number of threads bool nt = bli_thread_get_is_parallel(); // Check if parallel zgemm is invoked. + /* Invoking the API for input sizes with k=1. - For single thread, the API has no constraints before invoking. @@ -735,11 +736,83 @@ void zgemm_ return; } + /* Call Gemv when m/n=1 */ + if (n0 == 1) + { + if (bli_is_notrans(blis_transa)) + { + bli_zgemv_unf_var2( + BLIS_NO_TRANSPOSE, + bli_extract_conj(blis_transb), + m0, k0, + (dcomplex *)alpha, + (dcomplex *)a, rs_a, cs_a, + (dcomplex *)b, bli_is_notrans(blis_transb) ? 
rs_b : cs_b, + (dcomplex *)beta, + c, rs_c, + ((void *)0)); + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + + return; + } +#if 0 +/*** Code is disabled as bli_zgemv_unf_var1 not optimised *** + Calling below unoptimised variant causes regression ***/ + else + { + bli_zgemv_unf_var1( + blis_transa, + bli_extract_conj(blis_transb), + k0, m0, + (dcomplex *)alpha, + (dcomplex *)a, rs_a, cs_a, + (dcomplex *)b, bli_is_notrans(blis_transb) ? rs_b : cs_b, + (dcomplex *)beta, + c, rs_c, + ((void *)0)); + } +#endif + } + else if (m0 == 1) + { + if (bli_is_trans(blis_transb)) + { + bli_zgemv_unf_var2( + blis_transb, + bli_extract_conj(blis_transa), + k0, n0, + (dcomplex *)alpha, + (dcomplex *)b, cs_b, rs_b, + (dcomplex *)a, bli_is_notrans(blis_transa) ? cs_a : rs_a, + (dcomplex *)beta, + c, cs_c, + ((void *)0)); + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + return; + } +#if 0 +/*** Code is disabled as bli_zgemv_unf_var1 not optimised *** + Calling below unoptimised variant causes regression ***/ + + else + { + bli_zgemv_unf_var1( + blis_transb, + bli_extract_conj(blis_transa), + n0, k0, + (dcomplex *)alpha, + (dcomplex *)b, cs_b, rs_b, + (dcomplex *)a, bli_is_notrans(blis_transa) ? cs_a : rs_a, + (dcomplex *)beta, + c, cs_c, + ((void *)0)); + } +#endif + } #ifdef BLIS_ENABLE_SMALL_MATRIX - if( ( (nt == 0) && (m0 <= 512 ) && ( n0 <= 512 ) && ( k0 <= 512 ) ) || - ( (nt == 1) && ((( m0 <= 32)||(n0 <= 32)||(k0 <=32)) && ((m0+n0+k0)<=100)) ) - ) + if (((nt == 0) && (m0 <= 40) && (n0 <= 40) && (k0 <= 512)) || + ((nt == 1) && (((m0 <= 32) || (n0 <= 32) || (k0 <= 32)) && ((m0 + n0 + k0) <= 100)))) { err_t status = BLIS_NOT_YET_IMPLEMENTED; if (bli_is_notrans(blis_transa)) @@ -775,6 +848,15 @@ void zgemm_ } } #endif + + err_t status = bli_gemmsup(&alphao, &ao, &bo, &betao, &co, NULL, NULL); + if (status == BLIS_SUCCESS) + { + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) + return; + } + // fall back on native path when zgemm is not handled in sup path. bli_gemmnat(&alphao, &ao, &bo, &betao, &co, NULL, NULL); AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); From fde812015fc04c213eeb00a8fb8d1472eff32bea Mon Sep 17 00:00:00 2001 From: Chandrashekara K R Date: Fri, 29 Jul 2022 15:55:10 +0530 Subject: [PATCH 153/243] Updated blis library version from 4.0 to 3.2.1 AMD-Internal: [CPUPL-2322] Change-Id: I3a6a61543dd2754e2590d7f5f22442c9fdeaee95 --- frame/base/bli_info.c | 2 +- so_version | 2 +- version | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/frame/base/bli_info.c b/frame/base/bli_info.c index a3e9cb2ec5..73ac5b3f57 100644 --- a/frame/base/bli_info.c +++ b/frame/base/bli_info.c @@ -40,7 +40,7 @@ // This string gets defined via -D on the command line when BLIS is compiled. // This string is (or rather, should be) only used here. 
-static char* bli_version_str = "4.0"; //BLIS_VERSION_STRING; +static char* bli_version_str = "3.2.1"; //BLIS_VERSION_STRING; static char* bli_int_type_size_str = STRINGIFY_INT( BLIS_INT_TYPE_SIZE ); char* bli_info_get_version_str( void ) { return bli_version_str; } diff --git a/so_version b/so_version index 77605e74c7..63b08bc9c2 100644 --- a/so_version +++ b/so_version @@ -1,2 +1,2 @@ 3 -1.2 +2.1 diff --git a/version b/version index ef538c2810..e4604e3afd 100644 --- a/version +++ b/version @@ -1 +1 @@ -3.1.2 +3.2.1 From 1d31386c0233355cf27ea9d131967fd1c7cdc12c Mon Sep 17 00:00:00 2001 From: Nallani Bhaskar Date: Fri, 29 Jul 2022 07:31:58 +0530 Subject: [PATCH 154/243] Fixed few out of bound memory reads in sgemmsup kernels Details: Fixed memory access bugs in the bli_sgemmsup_rd_zen_asm_s1x16() kernel. The bugs were caused by loading four single-precision elements of C, via instructions such as: vfmadd231ps(mem(rcx, 0*32), ymm3, ymm4) or vfmadd231ps(mem(rcx, 0*32), xmm3, xmm4) in situations where only two elements are guaranteed to exist. (These bugs may not have manifested in earlier tests due to the leading dimension alignment that BLIS employs by default.) The issue was fixed by replacing lines like the one above with: vmovsd(mem(rcx), xmm0) vfmadd231ps(xmm0, xmm3, xmm4) Thus, we use vmovsd to explicitly load only two elements of C into registers, and then operate on those values using register addressing. AMD_CPUPLID: CPUPL-2279 Change-Id: Ic39290d651f5218b2e548351a87ac5e4b5b79c68 --- .../zen/3/sup/bli_gemmsup_rd_zen_asm_s6x16.c | 5 +++-- .../zen/3/sup/bli_gemmsup_rv_zen_asm_s6x16.c | 20 ++++++++++++------- .../zen/3/sup/bli_gemmsup_rv_zen_asm_s6x16m.c | 20 ++++++++++++------- 3 files changed, 29 insertions(+), 16 deletions(-) diff --git a/kernels/zen/3/sup/bli_gemmsup_rd_zen_asm_s6x16.c b/kernels/zen/3/sup/bli_gemmsup_rd_zen_asm_s6x16.c index 96bc927499..c309c8c0cd 100644 --- a/kernels/zen/3/sup/bli_gemmsup_rd_zen_asm_s6x16.c +++ b/kernels/zen/3/sup/bli_gemmsup_rd_zen_asm_s6x16.c @@ -3,7 +3,7 @@ An object-based framework for developing high-performance BLAS-like libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020, Advanced Micro Devices, Inc. + Copyright (C) 2020 - 2022 , Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -516,7 +516,8 @@ void bli_sgemmsup_rd_zen_asm_1x16 je(.SBETAZERO) // if ZF = 1, jump to beta == 0 case label(.SROWSTORED) - vfmadd231ps(mem(rcx), ymm3, ymm4) + vmovups(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm4) vmovups(xmm4, mem(rcx)) jmp(.SDONE) // jump to end. diff --git a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_s6x16.c b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_s6x16.c index 752a0a01c5..507ff5a717 100644 --- a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_s6x16.c +++ b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_s6x16.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020, Advanced Micro Devices, Inc. + Copyright (C) 2020-2022, Advanced Micro Devices, Inc. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -8048,15 +8048,18 @@ void bli_sgemmsup_rv_zen_asm_3x2 label(.SROWSTORED) - vfmadd231ps(mem(rcx), xmm3, xmm4) + vmovsd(mem(rcx), xmm0)////a0a1 + vfmadd231ps(xmm0, xmm3, xmm4) vmovsd(xmm4, mem(rcx)) add(rdi, rcx) - vfmadd231ps(mem(rcx), xmm3, xmm6) + vmovsd(mem(rcx), xmm0)////a0a1 + vfmadd231ps(xmm0, xmm3, xmm6) vmovsd(xmm6, mem(rcx)) add(rdi, rcx) - vfmadd231ps(mem(rcx), xmm3, xmm8) + vmovsd(mem(rcx), xmm0)////a0a1 + vfmadd231ps(xmm0, xmm3, xmm8) vmovsd(xmm8, mem(rcx)) jmp(.SDONE) // jump to end. @@ -8329,11 +8332,13 @@ void bli_sgemmsup_rv_zen_asm_2x2 label(.SROWSTORED) - vfmadd231ps(mem(rcx), xmm3, xmm4) + vmovsd(mem(rcx), xmm0)////a0a1 + vfmadd231ps(xmm0, xmm3, xmm4) vmovsd(xmm4, mem(rcx)) add(rdi, rcx) - vfmadd231ps(mem(rcx), xmm3, xmm6) + vmovsd(mem(rcx), xmm0)////a0a1 + vfmadd231ps(xmm0, xmm3, xmm6) vmovsd(xmm6, mem(rcx)) jmp(.SDONE) // jump to end. @@ -8577,7 +8582,8 @@ void bli_sgemmsup_rv_zen_asm_1x2 label(.SROWSTORED) - vfmadd231ps(mem(rcx), xmm3, xmm4) + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm4) vmovsd(xmm4, mem(rcx)) jmp(.SDONE) // jump to end. diff --git a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_s6x16m.c b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_s6x16m.c index 41dbbd699e..e6ecd47f47 100644 --- a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_s6x16m.c +++ b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_s6x16m.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020, Advanced Micro Devices, Inc. + Copyright (C) 2020-2022, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -2231,22 +2231,28 @@ void bli_sgemmsup_rv_zen_asm_6x2m label(.SROWSTORED) - vfmadd231ps(mem(rcx), xmm3, xmm4) + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm4) vmovlpd(xmm4, mem(rcx)) add(rdi, rcx) - vfmadd231ps(mem(rcx), xmm3, xmm6) + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm6) vmovlpd(xmm6, mem(rcx)) add(rdi, rcx) - vfmadd231ps(mem(rcx), xmm3, xmm8) + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm8) vmovlpd(xmm8, mem(rcx)) add(rdi, rcx) - vfmadd231ps(mem(rcx), xmm3, xmm10) + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm10) vmovlpd(xmm10, mem(rcx)) add(rdi, rcx) - vfmadd231ps(mem(rcx), xmm3, xmm12) + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm12) vmovlpd(xmm12, mem(rcx)) add(rdi, rcx) - vfmadd231ps(mem(rcx), xmm3, xmm14) + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm14) vmovlpd(xmm14, mem(rcx)) jmp(.SDONE) // jump to end. From 4b1663213cafbfc7f975926c8fce9df8d61a5a59 Mon Sep 17 00:00:00 2001 From: "Field G. Van Zee" Date: Thu, 14 Jul 2022 17:55:34 -0500 Subject: [PATCH 155/243] Fixed out-of-bounds read in haswell gemmsup kernels. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Details: - Fixed memory access bugs in the bli_sgemmsup_rv_haswell_asm_Mx2() kernels, where M = {1,2,3,4,5,6}. The bugs were caused by loading four single-precision elements of C, via instructions such as: vfmadd231ps(mem(rcx, 0*32), xmm3, xmm4) in situations where only two elements are guaranteed to exist. (These bugs may not have manifested in earlier tests due to the leading dimension alignment that BLIS employs by default.) 
The issue was fixed by replacing lines like the one above with: vmovsd(mem(rcx), xmm0) vfmadd231ps(xmm0, xmm3, xmm4) Thus, we use vmovsd to explicitly load only two elements of C into registers, and then operate on those values using register addressing. Thanks to Daniël de Kok for reporting these bugs in #635, and to Bhaskar Nallani for proposing the fix). - CREDITS file update. Change-Id: Ib525c36bcbf20b2bbbe380da3d74d142b338fe9b --- CREDITS | 1 + .../s6x16/bli_gemmsup_rv_haswell_asm_sMx2.c | 141 ++++++++++-------- 2 files changed, 79 insertions(+), 63 deletions(-) diff --git a/CREDITS b/CREDITS index c6d5d7151a..d68bcca014 100644 --- a/CREDITS +++ b/CREDITS @@ -23,6 +23,7 @@ but many others have contributed code and feedback, including Dilyn Corner @dilyn-corner Mat Cross @matcross (NAG) @decandia50 + Daniël de Kok @danieldk (Explosion) Kay Dewhurst @jkd2016 (Max Planck Institute, Halle, Germany) Jeff Diamond (Oracle) Johannes Dieterich @iotamudelta diff --git a/kernels/haswell/3/sup/s6x16/bli_gemmsup_rv_haswell_asm_sMx2.c b/kernels/haswell/3/sup/s6x16/bli_gemmsup_rv_haswell_asm_sMx2.c index 6090f8b0b9..3cbb69a50f 100644 --- a/kernels/haswell/3/sup/s6x16/bli_gemmsup_rv_haswell_asm_sMx2.c +++ b/kernels/haswell/3/sup/s6x16/bli_gemmsup_rv_haswell_asm_sMx2.c @@ -387,34 +387,39 @@ void bli_sgemmsup_rv_haswell_asm_6x2 label(.SROWSTORED) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm4) + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm4) vmovsd(xmm4, mem(rcx, 0*32)) add(rdi, rcx) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm6) + + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm6) vmovsd(xmm6, mem(rcx, 0*32)) add(rdi, rcx) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm8) + + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm8) vmovsd(xmm8, mem(rcx, 0*32)) add(rdi, rcx) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm10) + + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm10) vmovsd(xmm10, mem(rcx, 0*32)) add(rdi, rcx) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm12) + + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm12) vmovsd(xmm12, mem(rcx, 0*32)) add(rdi, rcx) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm14) + + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm14) vmovsd(xmm14, mem(rcx, 0*32)) //add(rdi, rcx) @@ -846,29 +851,33 @@ void bli_sgemmsup_rv_haswell_asm_5x2 label(.SROWSTORED) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm4) + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm4) vmovsd(xmm4, mem(rcx, 0*32)) add(rdi, rcx) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm6) + + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm6) vmovsd(xmm6, mem(rcx, 0*32)) add(rdi, rcx) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm8) + + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm8) vmovsd(xmm8, mem(rcx, 0*32)) add(rdi, rcx) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm10) + + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm10) vmovsd(xmm10, mem(rcx, 0*32)) add(rdi, rcx) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm12) + + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm12) vmovsd(xmm12, mem(rcx, 0*32)) //add(rdi, rcx) @@ -1286,24 +1295,27 @@ void bli_sgemmsup_rv_haswell_asm_4x2 label(.SROWSTORED) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm4) + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm4) vmovsd(xmm4, mem(rcx, 0*32)) add(rdi, rcx) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm6) + + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm6) vmovsd(xmm6, mem(rcx, 0*32)) add(rdi, rcx) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm8) + + + vmovsd(mem(rcx), xmm0) + 
vfmadd231ps(xmm0, xmm3, xmm8) vmovsd(xmm8, mem(rcx, 0*32)) add(rdi, rcx) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm10) + + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm10) vmovsd(xmm10, mem(rcx, 0*32)) //add(rdi, rcx) @@ -1681,19 +1693,21 @@ void bli_sgemmsup_rv_haswell_asm_3x2 label(.SROWSTORED) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm4) + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm4) vmovsd(xmm4, mem(rcx, 0*32)) add(rdi, rcx) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm6) + + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm6) vmovsd(xmm6, mem(rcx, 0*32)) add(rdi, rcx) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm8) + + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm8) vmovsd(xmm8, mem(rcx, 0*32)) //add(rdi, rcx) @@ -2064,14 +2078,15 @@ void bli_sgemmsup_rv_haswell_asm_2x2 label(.SROWSTORED) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm4) + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm4) vmovsd(xmm4, mem(rcx, 0*32)) add(rdi, rcx) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm6) + + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm6) vmovsd(xmm6, mem(rcx, 0*32)) //add(rdi, rcx) @@ -2402,9 +2417,9 @@ void bli_sgemmsup_rv_haswell_asm_1x2 label(.SROWSTORED) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm4) + + vmovsd(mem(rcx), xmm0) + vfmadd231ps(xmm0, xmm3, xmm4) vmovsd(xmm4, mem(rcx, 0*32)) //add(rdi, rcx) From faff30b46a6c3573a33912e73f555a1f625e3394 Mon Sep 17 00:00:00 2001 From: "Field G. Van Zee" Date: Tue, 26 Jul 2022 17:29:32 -0500 Subject: [PATCH 156/243] Fixed out-of-bounds bug in sup s6x16m haswell kernel. Details: - Fixed another out-of-bounds read access bug in the haswell sup assembly kernels. This bug is similar to the one fixed in 17b0caa and affects bli_sgemmsup_rv_haswell_asm_6x2m(). Thanks to Madeesh Kannan for reporting this bug (and a suitable fix) in #635. - CREDITS file update. 
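 For reference, the unsafe and safe row-store sequences used across these sgemmsup fixes differ only in how the two live elements of C reach a register. A minimal sketch of the idiom, written in the same asm-macro style as the kernels (register choices here are illustrative and not tied to any single kernel):

	// unsafe: the packed memory operand reads four floats from C,
	// but only two are guaranteed to exist on this row
	//   vfmadd231ps(mem(rcx), xmm3, xmm4)

	// safe: load exactly 64 bits (two floats) of C, then use
	// register-only operands for the update, and store two floats back
	vmovsd(mem(rcx), xmm0)           // xmm0 = { c0, c1, 0, 0 }
	vfmadd231ps(xmm0, xmm3, xmm4)    // xmm4 += beta (xmm3) * xmm0
	vmovsd(xmm4, mem(rcx))           // write back only c0, c1

 The upper two lanes of xmm0 are zeroed by vmovsd, so they contribute nothing to the FMA and the store touches only the two elements that actually belong to C.
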
Change-Id: I10ccf4d4f471d93e8c8cc4df422c686438fb04e9 --- CREDITS | 1 + .../3/sup/bli_gemmsup_rv_haswell_asm_s6x16m.c | 41 +++++++++++-------- 2 files changed, 24 insertions(+), 18 deletions(-) diff --git a/CREDITS b/CREDITS index d68bcca014..fd0bcb5b32 100644 --- a/CREDITS +++ b/CREDITS @@ -46,6 +46,7 @@ but many others have contributed code and feedback, including Matthew Honnibal @honnibal Stefan Husmann @stefanhusmann Francisco Igual @figual (Universidad Complutense de Madrid) + Madeesh Kannan @shadeMe Tony Kelman @tkelman Lee Killough @leekillough (Cray) Mike Kistler @mkistler (IBM, Austin Research Laboratory) diff --git a/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_s6x16m.c b/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_s6x16m.c index 426e5157e1..877e636b80 100644 --- a/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_s6x16m.c +++ b/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_s6x16m.c @@ -4475,34 +4475,39 @@ void bli_sgemmsup_rv_haswell_asm_6x2m label(.SROWSTORED) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm4) + + vmovsd(mem(rcx, 0*32), xmm0) + vfmadd231ps(xmm0, xmm3, xmm4) vmovsd(xmm4, mem(rcx, 0*32)) add(rdi, rcx) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm6) + + + vmovsd(mem(rcx, 0*32), xmm0) + vfmadd231ps(xmm0, xmm3, xmm6) vmovsd(xmm6, mem(rcx, 0*32)) add(rdi, rcx) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm8) + + + vmovsd(mem(rcx, 0*32), xmm0) + vfmadd231ps(xmm0, xmm3, xmm8) vmovsd(xmm8, mem(rcx, 0*32)) add(rdi, rcx) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm10) + + + vmovsd(mem(rcx, 0*32), xmm0) + vfmadd231ps(xmm0, xmm3, xmm10) vmovsd(xmm10, mem(rcx, 0*32)) add(rdi, rcx) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm12) + + + vmovsd(mem(rcx, 0*32), xmm0) + vfmadd231ps(xmm0, xmm3, xmm12) vmovsd(xmm12, mem(rcx, 0*32)) add(rdi, rcx) - - - vfmadd231ps(mem(rcx, 0*32), xmm3, xmm14) + + + vmovsd(mem(rcx, 0*32), xmm0) + vfmadd231ps(xmm0, xmm3, xmm14) vmovsd(xmm14, mem(rcx, 0*32)) //add(rdi, rcx) From ea163fc23be7647573380db075a94e70979c78d8 Mon Sep 17 00:00:00 2001 From: Devin Matthews Date: Thu, 16 Sep 2021 10:59:37 -0500 Subject: [PATCH 157/243] Fix problem where uninitialized registers are included in vhaddpd in the Mx1 gemmsup kernels for haswell. The fix is to use the same (valid) source register twice in the horizontal addition. Change-Id: I96ed39e289aaeeb44be9117074b32bd8d4c19de6 --- .../d6x8/bli_gemmsup_rd_haswell_asm_dMx1.c | 624 +++++++++--------- .../d6x8/bli_gemmsup_rd_haswell_asm_dMx4.c | 14 - 2 files changed, 312 insertions(+), 326 deletions(-) diff --git a/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx1.c b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx1.c index 6e3c1a0e85..457ef9f22d 100644 --- a/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx1.c +++ b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx1.c @@ -99,9 +99,9 @@ void bli_dgemmsup_rd_haswell_asm_6x1 // ------------------------------------------------------------------------- begin_asm() - + //vzeroall() // zero all xmm/ymm registers. - + mov(var(a), rax) // load address of a. 
mov(var(rs_a), r8) // load rs_a //mov(var(cs_a), r9) // load cs_a @@ -119,7 +119,7 @@ void bli_dgemmsup_rd_haswell_asm_6x1 //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b //lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a - + mov(var(c), rcx) // load address of c mov(var(rs_c), rdi) // load rs_c @@ -163,19 +163,19 @@ void bli_dgemmsup_rd_haswell_asm_6x1 prefetch(0, mem(r10, rdi, 1, 1*8)) // prefetch c + 4*rs_c prefetch(0, mem(r10, rdi, 2, 1*8)) // prefetch c + 5*rs_c #endif - - - + + + mov(var(k_iter16), rsi) // i = k_iter16; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKITER4) // if i == 0, jump to code that // contains the k_iter4 loop. - - + + label(.DLOOPKITER16) // MAIN LOOP - - + + // ---------------------------------- iteration 0 #if 0 @@ -206,7 +206,7 @@ void bli_dgemmsup_rd_haswell_asm_6x1 add(imm(4*8), rax) // a += 4*cs_a = 4*8; vfmadd231pd(ymm0, ymm3, ymm14) - + // ---------------------------------- iteration 1 vmovupd(mem(rbx ), ymm0) @@ -233,7 +233,7 @@ void bli_dgemmsup_rd_haswell_asm_6x1 // ---------------------------------- iteration 2 - + #if 0 prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a @@ -287,27 +287,27 @@ void bli_dgemmsup_rd_haswell_asm_6x1 add(imm(4*8), rax) // a += 4*cs_a = 4*8; vfmadd231pd(ymm0, ymm3, ymm14) - + dec(rsi) // i -= 1; jne(.DLOOPKITER16) // iterate again if i != 0. - - - - - - + + + + + + label(.DCONSIDKITER4) - + mov(var(k_iter4), rsi) // i = k_iter4; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKLEFT1) // if i == 0, jump to code that // considers k_left1 loop. // else, we prepare to enter k_iter4 loop. - - + + label(.DLOOPKITER4) // EDGE LOOP (ymm) - + #if 0 prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a @@ -336,21 +336,21 @@ void bli_dgemmsup_rd_haswell_asm_6x1 add(imm(4*8), rax) // a += 4*cs_a = 4*8; vfmadd231pd(ymm0, ymm3, ymm14) - + dec(rsi) // i -= 1; jne(.DLOOPKITER4) // iterate again if i != 0. - - - + + + label(.DCONSIDKLEFT1) - + mov(var(k_left1), rsi) // i = k_left1; test(rsi, rsi) // check i via logical AND. je(.DPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left1 loop. - - + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) @@ -358,7 +358,7 @@ void bli_dgemmsup_rd_haswell_asm_6x1 // using the xmm registers would zero out the // high bits of the destination registers, // which would destory intermediate results. - + vmovsd(mem(rbx ), xmm0) add(imm(1*8), rbx) // b += 1*rs_b = 1*8; @@ -381,12 +381,12 @@ void bli_dgemmsup_rd_haswell_asm_6x1 add(imm(1*8), rax) // a += 1*cs_a = 1*8; vfmadd231pd(ymm0, ymm3, ymm14) - + dec(rsi) // i -= 1; jne(.DLOOPKLEFT1) // iterate again if i != 0. 
- - - + + + @@ -399,28 +399,28 @@ void bli_dgemmsup_rd_haswell_asm_6x1 // ymm10 // ymm12 // ymm14 - - vhaddpd( ymm5, ymm4, ymm0 ) + + vhaddpd( ymm4, ymm4, ymm0 ) vextractf128(imm(1), ymm0, xmm1 ) vaddpd( xmm0, xmm1, xmm4 ) - vhaddpd( ymm7, ymm6, ymm0 ) + vhaddpd( ymm6, ymm6, ymm0 ) vextractf128(imm(1), ymm0, xmm1 ) vaddpd( xmm0, xmm1, xmm6 ) - vhaddpd( ymm9, ymm8, ymm0 ) + vhaddpd( ymm8, ymm8, ymm0 ) vextractf128(imm(1), ymm0, xmm1 ) vaddpd( xmm0, xmm1, xmm8 ) - vhaddpd( ymm11, ymm10, ymm0 ) + vhaddpd( ymm10, ymm10, ymm0 ) vextractf128(imm(1), ymm0, xmm1 ) vaddpd( xmm0, xmm1, xmm10 ) - vhaddpd( ymm13, ymm12, ymm0 ) + vhaddpd( ymm12, ymm12, ymm0 ) vextractf128(imm(1), ymm0, xmm1 ) vaddpd( xmm0, xmm1, xmm12 ) - vhaddpd( ymm15, ymm14, ymm0 ) + vhaddpd( ymm14, ymm14, ymm0 ) vextractf128(imm(1), ymm0, xmm1 ) vaddpd( xmm0, xmm1, xmm14 ) @@ -435,114 +435,114 @@ void bli_dgemmsup_rd_haswell_asm_6x1 //mov(var(rs_c), rdi) // load rs_c //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(double) - + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate - + vmulpd(xmm0, xmm4, xmm4) // scale by alpha vmulpd(xmm0, xmm6, xmm6) vmulpd(xmm0, xmm8, xmm8) vmulpd(xmm0, xmm10, xmm10) vmulpd(xmm0, xmm12, xmm12) vmulpd(xmm0, xmm14, xmm14) - - - - - - + + + + + + //mov(var(cs_c), rsi) // load cs_c //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) - - - + + + // now avoid loading C if beta == 0 - + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm3) // set ZF if beta == 0. je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case - - + + label(.DROWSTORED) - - vmovsd(mem(rcx), xmm0) + + vmovsd(mem(rcx), xmm0) vfmadd231pd(xmm0, xmm3, xmm4) vmovsd(xmm4, mem(rcx)) add(rdi, rcx) - - vmovsd(mem(rcx), xmm0) + + vmovsd(mem(rcx), xmm0) vfmadd231pd(xmm0, xmm3, xmm6) vmovsd(xmm6, mem(rcx)) add(rdi, rcx) - - vmovsd(mem(rcx), xmm0) + + vmovsd(mem(rcx), xmm0) vfmadd231pd(xmm0, xmm3, xmm8) vmovsd(xmm8, mem(rcx)) add(rdi, rcx) - - vmovsd(mem(rcx), xmm0) + + vmovsd(mem(rcx), xmm0) vfmadd231pd(xmm0, xmm3, xmm10) vmovsd(xmm10, mem(rcx)) add(rdi, rcx) - - vmovsd(mem(rcx), xmm0) + + vmovsd(mem(rcx), xmm0) vfmadd231pd(xmm0, xmm3, xmm12) vmovsd(xmm12, mem(rcx)) add(rdi, rcx) - - vmovsd(mem(rcx), xmm0) + + vmovsd(mem(rcx), xmm0) vfmadd231pd(xmm0, xmm3, xmm14) vmovsd(xmm14, mem(rcx)) //add(rdi, rcx) - - - + + + jmp(.DDONE) // jump to end. - - - - + + + + label(.DBETAZERO) - - + + label(.DROWSTORBZ) - - + + vmovsd(xmm4, mem(rcx)) add(rdi, rcx) - + vmovsd(xmm6, mem(rcx)) add(rdi, rcx) - + vmovsd(xmm8, mem(rcx)) add(rdi, rcx) - + vmovsd(xmm10, mem(rcx)) add(rdi, rcx) - + vmovsd(xmm12, mem(rcx)) add(rdi, rcx) - + vmovsd(xmm14, mem(rcx)) //add(rdi, rcx) - - - - + + + + label(.DDONE) - + label(.DRETURN) - + end_asm( : // output operands (none) @@ -613,9 +613,9 @@ void bli_dgemmsup_rd_haswell_asm_3x1 // ------------------------------------------------------------------------- begin_asm() - + //vzeroall() // zero all xmm/ymm registers. - + mov(var(a), rax) // load address of a. 
mov(var(rs_a), r8) // load rs_a //mov(var(cs_a), r9) // load cs_a @@ -633,7 +633,7 @@ void bli_dgemmsup_rd_haswell_asm_3x1 //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b //lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a - + mov(var(c), rcx) // load address of c mov(var(rs_c), rdi) // load rs_c @@ -671,19 +671,19 @@ void bli_dgemmsup_rd_haswell_asm_3x1 prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c prefetch(0, mem(rcx, rdi, 2, 1*8)) // prefetch c + 2*rs_c #endif - - - + + + mov(var(k_iter16), rsi) // i = k_iter16; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKITER4) // if i == 0, jump to code that // contains the k_iter4 loop. - - + + label(.DLOOPKITER16) // MAIN LOOP - - + + // ---------------------------------- iteration 0 #if 0 @@ -705,7 +705,7 @@ void bli_dgemmsup_rd_haswell_asm_3x1 add(imm(4*8), rax) // a += 4*cs_a = 4*8; vfmadd231pd(ymm0, ymm3, ymm8) - + // ---------------------------------- iteration 1 vmovupd(mem(rbx ), ymm0) @@ -723,7 +723,7 @@ void bli_dgemmsup_rd_haswell_asm_3x1 // ---------------------------------- iteration 2 - + #if 0 prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a @@ -759,27 +759,27 @@ void bli_dgemmsup_rd_haswell_asm_3x1 add(imm(4*8), rax) // a += 4*cs_a = 4*8; vfmadd231pd(ymm0, ymm3, ymm8) - + dec(rsi) // i -= 1; jne(.DLOOPKITER16) // iterate again if i != 0. - - - - - - + + + + + + label(.DCONSIDKITER4) - + mov(var(k_iter4), rsi) // i = k_iter4; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKLEFT1) // if i == 0, jump to code that // considers k_left1 loop. // else, we prepare to enter k_iter4 loop. - - + + label(.DLOOPKITER4) // EDGE LOOP (ymm) - + #if 0 prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a @@ -799,21 +799,21 @@ void bli_dgemmsup_rd_haswell_asm_3x1 add(imm(4*8), rax) // a += 4*cs_a = 4*8; vfmadd231pd(ymm0, ymm3, ymm8) - + dec(rsi) // i -= 1; jne(.DLOOPKITER4) // iterate again if i != 0. - - - + + + label(.DCONSIDKLEFT1) - + mov(var(k_left1), rsi) // i = k_left1; test(rsi, rsi) // check i via logical AND. je(.DPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left1 loop. - - + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) @@ -821,7 +821,7 @@ void bli_dgemmsup_rd_haswell_asm_3x1 // using the xmm registers would zero out the // high bits of the destination registers, // which would destory intermediate results. - + vmovsd(mem(rbx ), xmm0) add(imm(1*8), rbx) // b += 1*rs_b = 1*8; @@ -835,12 +835,12 @@ void bli_dgemmsup_rd_haswell_asm_3x1 add(imm(1*8), rax) // a += 1*cs_a = 1*8; vfmadd231pd(ymm0, ymm3, ymm8) - + dec(rsi) // i -= 1; jne(.DLOOPKLEFT1) // iterate again if i != 0. 
- - - + + + @@ -850,16 +850,16 @@ void bli_dgemmsup_rd_haswell_asm_3x1 // ymm4 // ymm6 // ymm8 - - vhaddpd( ymm5, ymm4, ymm0 ) + + vhaddpd( ymm4, ymm4, ymm0 ) vextractf128(imm(1), ymm0, xmm1 ) vaddpd( xmm0, xmm1, xmm4 ) - vhaddpd( ymm7, ymm6, ymm0 ) + vhaddpd( ymm6, ymm6, ymm0 ) vextractf128(imm(1), ymm0, xmm1 ) vaddpd( xmm0, xmm1, xmm6 ) - vhaddpd( ymm9, ymm8, ymm0 ) + vhaddpd( ymm8, ymm8, ymm0 ) vextractf128(imm(1), ymm0, xmm1 ) vaddpd( xmm0, xmm1, xmm8 ) @@ -871,87 +871,87 @@ void bli_dgemmsup_rd_haswell_asm_3x1 //mov(var(rs_c), rdi) // load rs_c //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(double) - + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate - + vmulpd(xmm0, xmm4, xmm4) // scale by alpha vmulpd(xmm0, xmm6, xmm6) vmulpd(xmm0, xmm8, xmm8) - - - - - - + + + + + + //mov(var(cs_c), rsi) // load cs_c //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) - - - + + + // now avoid loading C if beta == 0 - + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm3) // set ZF if beta == 0. je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case - - + + label(.DROWSTORED) - - vmovsd(mem(rcx), xmm0) + + vmovsd(mem(rcx), xmm0) vfmadd231pd(xmm0, xmm3, xmm4) vmovsd(xmm4, mem(rcx)) add(rdi, rcx) - - vmovsd(mem(rcx), xmm0) + + vmovsd(mem(rcx), xmm0) vfmadd231pd(xmm0, xmm3, xmm6) vmovsd(xmm6, mem(rcx)) add(rdi, rcx) - - vmovsd(mem(rcx), xmm0) + + vmovsd(mem(rcx), xmm0) vfmadd231pd(xmm0, xmm3, xmm8) vmovsd(xmm8, mem(rcx)) //add(rdi, rcx) - - - + + + jmp(.DDONE) // jump to end. - - - - + + + + label(.DBETAZERO) - - + + label(.DROWSTORBZ) - - + + vmovsd(xmm4, mem(rcx)) add(rdi, rcx) - + vmovsd(xmm6, mem(rcx)) add(rdi, rcx) - + vmovsd(xmm8, mem(rcx)) //add(rdi, rcx) - - - - + + + + label(.DDONE) - + label(.DRETURN) - + end_asm( : // output operands (none) @@ -1022,9 +1022,9 @@ void bli_dgemmsup_rd_haswell_asm_2x1 // ------------------------------------------------------------------------- begin_asm() - + //vzeroall() // zero all xmm/ymm registers. - + mov(var(a), rax) // load address of a. mov(var(rs_a), r8) // load rs_a //mov(var(cs_a), r9) // load cs_a @@ -1042,7 +1042,7 @@ void bli_dgemmsup_rd_haswell_asm_2x1 //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b //lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a - + mov(var(c), rcx) // load address of c mov(var(rs_c), rdi) // load rs_c @@ -1078,19 +1078,19 @@ void bli_dgemmsup_rd_haswell_asm_2x1 prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c prefetch(0, mem(rcx, rdi, 1, 1*8)) // prefetch c + 1*rs_c #endif - - - + + + mov(var(k_iter16), rsi) // i = k_iter16; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKITER4) // if i == 0, jump to code that // contains the k_iter4 loop. 
- - + + label(.DLOOPKITER16) // MAIN LOOP - - + + // ---------------------------------- iteration 0 #if 0 @@ -1109,7 +1109,7 @@ void bli_dgemmsup_rd_haswell_asm_2x1 add(imm(4*8), rax) // a += 4*cs_a = 4*8; vfmadd231pd(ymm0, ymm3, ymm6) - + // ---------------------------------- iteration 1 vmovupd(mem(rbx ), ymm0) @@ -1124,7 +1124,7 @@ void bli_dgemmsup_rd_haswell_asm_2x1 // ---------------------------------- iteration 2 - + #if 0 prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a @@ -1154,27 +1154,27 @@ void bli_dgemmsup_rd_haswell_asm_2x1 add(imm(4*8), rax) // a += 4*cs_a = 4*8; vfmadd231pd(ymm0, ymm3, ymm6) - + dec(rsi) // i -= 1; jne(.DLOOPKITER16) // iterate again if i != 0. - - - - - - + + + + + + label(.DCONSIDKITER4) - + mov(var(k_iter4), rsi) // i = k_iter4; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKLEFT1) // if i == 0, jump to code that // considers k_left1 loop. // else, we prepare to enter k_iter4 loop. - - + + label(.DLOOPKITER4) // EDGE LOOP (ymm) - + #if 0 prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a @@ -1191,21 +1191,21 @@ void bli_dgemmsup_rd_haswell_asm_2x1 add(imm(4*8), rax) // a += 4*cs_a = 4*8; vfmadd231pd(ymm0, ymm3, ymm6) - + dec(rsi) // i -= 1; jne(.DLOOPKITER4) // iterate again if i != 0. - - - + + + label(.DCONSIDKLEFT1) - + mov(var(k_left1), rsi) // i = k_left1; test(rsi, rsi) // check i via logical AND. je(.DPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left1 loop. - - + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) @@ -1213,7 +1213,7 @@ void bli_dgemmsup_rd_haswell_asm_2x1 // using the xmm registers would zero out the // high bits of the destination registers, // which would destory intermediate results. - + vmovsd(mem(rbx ), xmm0) add(imm(1*8), rbx) // b += 1*rs_b = 1*8; @@ -1224,12 +1224,12 @@ void bli_dgemmsup_rd_haswell_asm_2x1 add(imm(1*8), rax) // a += 1*cs_a = 1*8; vfmadd231pd(ymm0, ymm3, ymm6) - + dec(rsi) // i -= 1; jne(.DLOOPKLEFT1) // iterate again if i != 0. - - - + + + @@ -1238,12 +1238,12 @@ void bli_dgemmsup_rd_haswell_asm_2x1 // ymm4 // ymm6 - - vhaddpd( ymm5, ymm4, ymm0 ) + + vhaddpd( ymm4, ymm4, ymm0 ) vextractf128(imm(1), ymm0, xmm1 ) vaddpd( xmm0, xmm1, xmm4 ) - vhaddpd( ymm7, ymm6, ymm0 ) + vhaddpd( ymm6, ymm6, ymm0 ) vextractf128(imm(1), ymm0, xmm1 ) vaddpd( xmm0, xmm1, xmm6 ) @@ -1254,78 +1254,78 @@ void bli_dgemmsup_rd_haswell_asm_2x1 //mov(var(rs_c), rdi) // load rs_c //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(double) - + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate - + vmulpd(xmm0, xmm4, xmm4) // scale by alpha vmulpd(xmm0, xmm6, xmm6) - - - - - - + + + + + + //mov(var(cs_c), rsi) // load cs_c //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) - - - + + + // now avoid loading C if beta == 0 - + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm3) // set ZF if beta == 0. je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case - - + + label(.DROWSTORED) - - vmovsd(mem(rcx), xmm0) + + vmovsd(mem(rcx), xmm0) vfmadd231pd(xmm0, xmm3, xmm4) vmovsd(xmm4, mem(rcx)) add(rdi, rcx) - - vmovsd(mem(rcx), xmm0) + + vmovsd(mem(rcx), xmm0) vfmadd231pd(xmm0, xmm3, xmm6) vmovsd(xmm6, mem(rcx)) //add(rdi, rcx) - - - + + + jmp(.DDONE) // jump to end. 
- - - - + + + + label(.DBETAZERO) - - + + label(.DROWSTORBZ) - - + + vmovsd(xmm4, mem(rcx)) add(rdi, rcx) - + vmovsd(xmm6, mem(rcx)) //add(rdi, rcx) - - - - + + + + label(.DDONE) - + label(.DRETURN) - + end_asm( : // output operands (none) @@ -1396,9 +1396,9 @@ void bli_dgemmsup_rd_haswell_asm_1x1 // ------------------------------------------------------------------------- begin_asm() - + //vzeroall() // zero all xmm/ymm registers. - + mov(var(a), rax) // load address of a. mov(var(rs_a), r8) // load rs_a //mov(var(cs_a), r9) // load cs_a @@ -1416,7 +1416,7 @@ void bli_dgemmsup_rd_haswell_asm_1x1 //lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b //lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a - + mov(var(c), rcx) // load address of c mov(var(rs_c), rdi) // load rs_c @@ -1450,19 +1450,19 @@ void bli_dgemmsup_rd_haswell_asm_1x1 //lea(mem(r10, rdi, 1), r10) // rdx = c + 3*rs_c; prefetch(0, mem(rcx, 1*8)) // prefetch c + 0*rs_c #endif - - - + + + mov(var(k_iter16), rsi) // i = k_iter16; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKITER4) // if i == 0, jump to code that // contains the k_iter4 loop. - - + + label(.DLOOPKITER16) // MAIN LOOP - - + + // ---------------------------------- iteration 0 #if 0 @@ -1478,7 +1478,7 @@ void bli_dgemmsup_rd_haswell_asm_1x1 add(imm(4*8), rax) // a += 4*cs_a = 4*8; vfmadd231pd(ymm0, ymm3, ymm4) - + // ---------------------------------- iteration 1 vmovupd(mem(rbx ), ymm0) @@ -1490,7 +1490,7 @@ void bli_dgemmsup_rd_haswell_asm_1x1 // ---------------------------------- iteration 2 - + #if 0 prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a @@ -1514,27 +1514,27 @@ void bli_dgemmsup_rd_haswell_asm_1x1 add(imm(4*8), rax) // a += 4*cs_a = 4*8; vfmadd231pd(ymm0, ymm3, ymm4) - + dec(rsi) // i -= 1; jne(.DLOOPKITER16) // iterate again if i != 0. - - - - - - + + + + + + label(.DCONSIDKITER4) - + mov(var(k_iter4), rsi) // i = k_iter4; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKLEFT1) // if i == 0, jump to code that // considers k_left1 loop. // else, we prepare to enter k_iter4 loop. - - + + label(.DLOOPKITER4) // EDGE LOOP (ymm) - + #if 0 prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a @@ -1548,21 +1548,21 @@ void bli_dgemmsup_rd_haswell_asm_1x1 add(imm(4*8), rax) // a += 4*cs_a = 4*8; vfmadd231pd(ymm0, ymm3, ymm4) - + dec(rsi) // i -= 1; jne(.DLOOPKITER4) // iterate again if i != 0. - - - + + + label(.DCONSIDKLEFT1) - + mov(var(k_left1), rsi) // i = k_left1; test(rsi, rsi) // check i via logical AND. je(.DPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left1 loop. - - + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) @@ -1570,7 +1570,7 @@ void bli_dgemmsup_rd_haswell_asm_1x1 // using the xmm registers would zero out the // high bits of the destination registers, // which would destory intermediate results. - + vmovsd(mem(rbx ), xmm0) add(imm(1*8), rbx) // b += 1*rs_b = 1*8; @@ -1578,12 +1578,12 @@ void bli_dgemmsup_rd_haswell_asm_1x1 add(imm(1*8), rax) // a += 1*cs_a = 1*8; vfmadd231pd(ymm0, ymm3, ymm4) - + dec(rsi) // i -= 1; jne(.DLOOPKLEFT1) // iterate again if i != 0. 
- - - + + + @@ -1591,8 +1591,8 @@ void bli_dgemmsup_rd_haswell_asm_1x1 label(.DPOSTACCUM) // ymm4 - - vhaddpd( ymm5, ymm4, ymm0 ) + + vhaddpd( ymm4, ymm4, ymm0 ) vextractf128(imm(1), ymm0, xmm1 ) vaddpd( xmm0, xmm1, xmm4 ) @@ -1602,69 +1602,69 @@ void bli_dgemmsup_rd_haswell_asm_1x1 //mov(var(rs_c), rdi) // load rs_c //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(double) - + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate - + vmulpd(xmm0, xmm4, xmm4) // scale by alpha - - - - - - + + + + + + //mov(var(cs_c), rsi) // load cs_c //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) - - - + + + // now avoid loading C if beta == 0 - + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm3) // set ZF if beta == 0. je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case - - + + label(.DROWSTORED) - - vmovsd(mem(rcx), xmm0) + + vmovsd(mem(rcx), xmm0) vfmadd231pd(xmm0, xmm3, xmm4) vmovsd(xmm4, mem(rcx)) //add(rdi, rcx) - - - + + + jmp(.DDONE) // jump to end. - - - - + + + + label(.DBETAZERO) - - + + label(.DROWSTORBZ) - - + + vmovsd(xmm4, mem(rcx)) //add(rdi, rcx) - - - - + + + + label(.DDONE) - + label(.DRETURN) - + end_asm( : // output operands (none) diff --git a/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx4.c b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx4.c index 4c6094b1cd..4ac275f5de 100644 --- a/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx4.c +++ b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx4.c @@ -1339,20 +1339,6 @@ void bli_dgemmsup_rd_haswell_asm_1x4 vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) - vhaddpd( ymm8, ymm5, ymm0 ) - vextractf128(imm(1), ymm0, xmm1 ) - vaddpd( xmm0, xmm1, xmm0 ) - - vhaddpd( ymm14, ymm11, ymm2 ) - vextractf128(imm(1), ymm2, xmm1 ) - vaddpd( xmm2, xmm1, xmm2 ) - - vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) - - // xmm4[0:3] = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) - - - //mov(var(rs_c), rdi) // load rs_c //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) From 9495401b7338be76c9206c57712ef23347a1fc2b Mon Sep 17 00:00:00 2001 From: Devin Matthews Date: Thu, 16 Sep 2021 10:16:17 -0500 Subject: [PATCH 158/243] Fix more copy-paste errors in the haswell gemmsup code. Fixes #486. Change-Id: I568386b5d67a698ea9c0b6b17f133df86c2894bd --- .../d6x8/bli_gemmsup_rd_haswell_asm_dMx4.c | 465 +++++++++--------- 1 file changed, 239 insertions(+), 226 deletions(-) diff --git a/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx4.c b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx4.c index 4ac275f5de..41f73bc9ec 100644 --- a/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx4.c +++ b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx4.c @@ -101,7 +101,7 @@ void bli_dgemmsup_rd_haswell_asm_6x4 begin_asm() //vzeroall() // zero all xmm/ymm registers. - + mov(var(a), r14) // load address of a. mov(var(rs_a), r8) // load rs_a //mov(var(cs_a), r9) // load cs_a @@ -119,7 +119,7 @@ void bli_dgemmsup_rd_haswell_asm_6x4 lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a - + mov(var(c), r12) // load address of c mov(var(rs_c), rdi) // load rs_c @@ -172,19 +172,19 @@ void bli_dgemmsup_rd_haswell_asm_6x4 prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c #endif lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a - - - + + + mov(var(k_iter16), rsi) // i = k_iter16; test(rsi, rsi) // check i via logical AND. 
je(.DCONSIDKITER4) // if i == 0, jump to code that // contains the k_iter4 loop. - - + + label(.DLOOPKITER16) // MAIN LOOP - - + + // ---------------------------------- iteration 0 #if 0 @@ -219,7 +219,7 @@ void bli_dgemmsup_rd_haswell_asm_6x4 vfmadd231pd(ymm1, ymm3, ymm14) vfmadd231pd(ymm2, ymm3, ymm15) - + // ---------------------------------- iteration 1 vmovupd(mem(rax ), ymm0) @@ -250,7 +250,7 @@ void bli_dgemmsup_rd_haswell_asm_6x4 // ---------------------------------- iteration 2 - + #if 0 prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a @@ -312,27 +312,27 @@ void bli_dgemmsup_rd_haswell_asm_6x4 vfmadd231pd(ymm1, ymm3, ymm14) vfmadd231pd(ymm2, ymm3, ymm15) - + dec(rsi) // i -= 1; jne(.DLOOPKITER16) // iterate again if i != 0. - - - - - - + + + + + + label(.DCONSIDKITER4) - + mov(var(k_iter4), rsi) // i = k_iter4; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKLEFT1) // if i == 0, jump to code that // considers k_left1 loop. // else, we prepare to enter k_iter4 loop. - - + + label(.DLOOPKITER4) // EDGE LOOP (ymm) - + #if 0 prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a @@ -343,7 +343,7 @@ void bli_dgemmsup_rd_haswell_asm_6x4 vmovupd(mem(rax, r8, 1), ymm1) vmovupd(mem(rax, r8, 2), ymm2) add(imm(4*8), rax) // a += 4*cs_b = 4*8; - + vmovupd(mem(rbx ), ymm3) vfmadd231pd(ymm0, ymm3, ymm4) vfmadd231pd(ymm1, ymm3, ymm5) @@ -365,21 +365,21 @@ void bli_dgemmsup_rd_haswell_asm_6x4 vfmadd231pd(ymm1, ymm3, ymm14) vfmadd231pd(ymm2, ymm3, ymm15) - + dec(rsi) // i -= 1; jne(.DLOOPKITER4) // iterate again if i != 0. - - - + + + label(.DCONSIDKLEFT1) - + mov(var(k_left1), rsi) // i = k_left1; test(rsi, rsi) // check i via logical AND. je(.DPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left1 loop. - - + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) @@ -387,12 +387,12 @@ void bli_dgemmsup_rd_haswell_asm_6x4 // using the xmm registers would zero out the // high bits of the destination registers, // which would destory intermediate results. - + vmovsd(mem(rax ), xmm0) vmovsd(mem(rax, r8, 1), xmm1) vmovsd(mem(rax, r8, 2), xmm2) add(imm(1*8), rax) // a += 1*cs_a = 1*8; - + vmovsd(mem(rbx ), xmm3) vfmadd231pd(ymm0, ymm3, ymm4) vfmadd231pd(ymm1, ymm3, ymm5) @@ -414,12 +414,12 @@ void bli_dgemmsup_rd_haswell_asm_6x4 vfmadd231pd(ymm1, ymm3, ymm14) vfmadd231pd(ymm2, ymm3, ymm15) - + dec(rsi) // i -= 1; jne(.DLOOPKLEFT1) // iterate again if i != 0. 
- - - + + + @@ -427,11 +427,11 @@ void bli_dgemmsup_rd_haswell_asm_6x4 label(.DPOSTACCUM) - - // ymm4 ymm7 ymm10 ymm13 + + // ymm4 ymm7 ymm10 ymm13 // ymm5 ymm8 ymm11 ymm14 // ymm6 ymm9 ymm12 ymm15 - + vhaddpd( ymm7, ymm4, ymm0 ) vextractf128(imm(1), ymm0, xmm1 ) vaddpd( xmm0, xmm1, xmm0 ) @@ -469,7 +469,7 @@ void bli_dgemmsup_rd_haswell_asm_6x4 // xmm6[0:3] = sum(ymm6) sum(ymm9) sum(ymm12) sum(ymm15) - + //mov(var(rs_c), rdi) // load rs_c //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) @@ -477,73 +477,73 @@ void bli_dgemmsup_rd_haswell_asm_6x4 mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate - + vmulpd(ymm0, ymm4, ymm4) // scale by alpha vmulpd(ymm0, ymm5, ymm5) vmulpd(ymm0, ymm6, ymm6) - - - - - - + + + + + + //mov(var(cs_c), rsi) // load cs_c //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) - - - + + + // now avoid loading C if beta == 0 - + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm3) // set ZF if beta == 0. je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case - - + + label(.DROWSTORED) - - + + vfmadd231pd(mem(rcx), ymm3, ymm4) vmovupd(ymm4, mem(rcx)) add(rdi, rcx) - + vfmadd231pd(mem(rcx), ymm3, ymm5) vmovupd(ymm5, mem(rcx)) add(rdi, rcx) - + vfmadd231pd(mem(rcx), ymm3, ymm6) vmovupd(ymm6, mem(rcx)) //add(rdi, rcx) - - - + + + jmp(.DDONE) // jump to end. - - - - + + + + label(.DBETAZERO) - - + + label(.DROWSTORBZ) - - + + vmovupd(ymm4, mem(rcx)) add(rdi, rcx) - + vmovupd(ymm5, mem(rcx)) add(rdi, rcx) - + vmovupd(ymm6, mem(rcx)) //add(rdi, rcx) - - - - + + + + label(.DDONE) - - + + lea(mem(r12, rdi, 2), r12) // @@ -560,7 +560,7 @@ void bli_dgemmsup_rd_haswell_asm_6x4 label(.DRETURN) - + end_asm( : // output operands (none) @@ -629,7 +629,7 @@ void bli_dgemmsup_rd_haswell_asm_2x4 // ------------------------------------------------------------------------- begin_asm() - + //vzeroall() // zero all xmm/ymm registers. mov(var(a), rax) // load address of a. @@ -649,7 +649,7 @@ void bli_dgemmsup_rd_haswell_asm_2x4 lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b //lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a - + mov(var(c), rcx) // load address of c mov(var(rs_c), rdi) // load rs_c @@ -682,7 +682,7 @@ void bli_dgemmsup_rd_haswell_asm_2x4 //lea(mem(r14), rax) // rax = a; //lea(mem(rdx), rbx) // rbx = b; - + #if 1 //mov(var(rs_c), rdi) // load rs_c //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) @@ -690,18 +690,18 @@ void bli_dgemmsup_rd_haswell_asm_2x4 prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c #endif - - - + + + mov(var(k_iter16), rsi) // i = k_iter16; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKITER4) // if i == 0, jump to code that // contains the k_iter4 loop. - - + + label(.DLOOPKITER16) // MAIN LOOP - - + + // ---------------------------------- iteration 0 #if 0 @@ -730,7 +730,7 @@ void bli_dgemmsup_rd_haswell_asm_2x4 vfmadd231pd(ymm0, ymm3, ymm13) vfmadd231pd(ymm1, ymm3, ymm14) - + // ---------------------------------- iteration 1 vmovupd(mem(rax ), ymm0) @@ -756,7 +756,7 @@ void bli_dgemmsup_rd_haswell_asm_2x4 // ---------------------------------- iteration 2 - + #if 0 prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a @@ -807,27 +807,27 @@ void bli_dgemmsup_rd_haswell_asm_2x4 vfmadd231pd(ymm0, ymm3, ymm13) vfmadd231pd(ymm1, ymm3, ymm14) - + dec(rsi) // i -= 1; jne(.DLOOPKITER16) // iterate again if i != 0. 
- - - - - - + + + + + + label(.DCONSIDKITER4) - + mov(var(k_iter4), rsi) // i = k_iter4; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKLEFT1) // if i == 0, jump to code that // considers k_left1 loop. // else, we prepare to enter k_iter4 loop. - - + + label(.DLOOPKITER4) // EDGE LOOP (ymm) - + #if 0 prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a @@ -836,7 +836,7 @@ void bli_dgemmsup_rd_haswell_asm_2x4 vmovupd(mem(rax ), ymm0) vmovupd(mem(rax, r8, 1), ymm1) add(imm(4*8), rax) // a += 4*cs_b = 4*8; - + vmovupd(mem(rbx ), ymm3) vfmadd231pd(ymm0, ymm3, ymm4) vfmadd231pd(ymm1, ymm3, ymm5) @@ -854,21 +854,21 @@ void bli_dgemmsup_rd_haswell_asm_2x4 vfmadd231pd(ymm0, ymm3, ymm13) vfmadd231pd(ymm1, ymm3, ymm14) - + dec(rsi) // i -= 1; jne(.DLOOPKITER4) // iterate again if i != 0. - - - + + + label(.DCONSIDKLEFT1) - + mov(var(k_left1), rsi) // i = k_left1; test(rsi, rsi) // check i via logical AND. je(.DPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left1 loop. - - + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) @@ -876,11 +876,11 @@ void bli_dgemmsup_rd_haswell_asm_2x4 // using the xmm registers would zero out the // high bits of the destination registers, // which would destory intermediate results. - + vmovsd(mem(rax ), xmm0) vmovsd(mem(rax, r8, 1), xmm1) add(imm(1*8), rax) // a += 1*cs_a = 1*8; - + vmovsd(mem(rbx ), xmm3) vfmadd231pd(ymm0, ymm3, ymm4) vfmadd231pd(ymm1, ymm3, ymm5) @@ -898,12 +898,12 @@ void bli_dgemmsup_rd_haswell_asm_2x4 vfmadd231pd(ymm0, ymm3, ymm13) vfmadd231pd(ymm1, ymm3, ymm14) - + dec(rsi) // i -= 1; jne(.DLOOPKLEFT1) // iterate again if i != 0. - - - + + + @@ -911,10 +911,10 @@ void bli_dgemmsup_rd_haswell_asm_2x4 label(.DPOSTACCUM) - - // ymm4 ymm7 ymm10 ymm13 + + // ymm4 ymm7 ymm10 ymm13 // ymm5 ymm8 ymm11 ymm14 - + vhaddpd( ymm7, ymm4, ymm0 ) vextractf128(imm(1), ymm0, xmm1 ) vaddpd( xmm0, xmm1, xmm0 ) @@ -943,75 +943,75 @@ void bli_dgemmsup_rd_haswell_asm_2x4 //mov(var(rs_c), rdi) // load rs_c //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) - + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate - + vmulpd(ymm0, ymm4, ymm4) // scale by alpha vmulpd(ymm0, ymm5, ymm5) - - - - - - + + + + + + //mov(var(cs_c), rsi) // load cs_c //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) - - - + + + // now avoid loading C if beta == 0 - + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm3) // set ZF if beta == 0. je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case - - + + label(.DROWSTORED) - - + + vfmadd231pd(mem(rcx), ymm3, ymm4) vmovupd(ymm4, mem(rcx)) add(rdi, rcx) - + vfmadd231pd(mem(rcx), ymm3, ymm5) vmovupd(ymm5, mem(rcx)) //add(rdi, rcx) - - - + + + jmp(.DDONE) // jump to end. - - - - + + + + label(.DBETAZERO) - - + + label(.DROWSTORBZ) - - + + vmovupd(ymm4, mem(rcx)) add(rdi, rcx) - + vmovupd(ymm5, mem(rcx)) //add(rdi, rcx) - - - - + + + + label(.DDONE) label(.DRETURN) - - + + end_asm( : // output operands (none) @@ -1079,7 +1079,7 @@ void bli_dgemmsup_rd_haswell_asm_1x4 // ------------------------------------------------------------------------- begin_asm() - + //vzeroall() // zero all xmm/ymm registers. mov(var(a), rax) // load address of a. 
@@ -1099,7 +1099,7 @@ void bli_dgemmsup_rd_haswell_asm_1x4 lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b //lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a - + mov(var(c), rcx) // load address of c mov(var(rs_c), rdi) // load rs_c @@ -1128,26 +1128,26 @@ void bli_dgemmsup_rd_haswell_asm_1x4 //lea(mem(r14), rax) // rax = a; //lea(mem(rdx), rbx) // rbx = b; - + #if 1 //mov(var(rs_c), rdi) // load rs_c //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c - prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + //prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c #endif - - - + + + mov(var(k_iter16), rsi) // i = k_iter16; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKITER4) // if i == 0, jump to code that // contains the k_iter4 loop. - - + + label(.DLOOPKITER16) // MAIN LOOP - - + + // ---------------------------------- iteration 0 #if 0 @@ -1170,7 +1170,7 @@ void bli_dgemmsup_rd_haswell_asm_1x4 add(imm(4*8), rbx) // b += 4*rs_b = 4*8; vfmadd231pd(ymm0, ymm3, ymm13) - + // ---------------------------------- iteration 1 vmovupd(mem(rax ), ymm0) @@ -1191,7 +1191,7 @@ void bli_dgemmsup_rd_haswell_asm_1x4 // ---------------------------------- iteration 2 - + #if 0 prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a #endif @@ -1231,27 +1231,27 @@ void bli_dgemmsup_rd_haswell_asm_1x4 add(imm(4*8), rbx) // b += 4*rs_b = 4*8; vfmadd231pd(ymm0, ymm3, ymm13) - + dec(rsi) // i -= 1; jne(.DLOOPKITER16) // iterate again if i != 0. - - - - - - + + + + + + label(.DCONSIDKITER4) - + mov(var(k_iter4), rsi) // i = k_iter4; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKLEFT1) // if i == 0, jump to code that // considers k_left1 loop. // else, we prepare to enter k_iter4 loop. - - + + label(.DLOOPKITER4) // EDGE LOOP (ymm) - + #if 0 prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*cs_a prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*cs_a @@ -1259,7 +1259,7 @@ void bli_dgemmsup_rd_haswell_asm_1x4 vmovupd(mem(rax ), ymm0) add(imm(4*8), rax) // a += 4*cs_b = 4*8; - + vmovupd(mem(rbx ), ymm3) vfmadd231pd(ymm0, ymm3, ymm4) @@ -1273,21 +1273,21 @@ void bli_dgemmsup_rd_haswell_asm_1x4 add(imm(4*8), rbx) // b += 4*rs_b = 4*8; vfmadd231pd(ymm0, ymm3, ymm13) - + dec(rsi) // i -= 1; jne(.DLOOPKITER4) // iterate again if i != 0. - - - + + + label(.DCONSIDKLEFT1) - + mov(var(k_left1), rsi) // i = k_left1; test(rsi, rsi) // check i via logical AND. je(.DPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left1 loop. - - + + label(.DLOOPKLEFT1) // EDGE LOOP (scalar) @@ -1295,10 +1295,10 @@ void bli_dgemmsup_rd_haswell_asm_1x4 // using the xmm registers would zero out the // high bits of the destination registers, // which would destory intermediate results. - + vmovsd(mem(rax ), xmm0) add(imm(1*8), rax) // a += 1*cs_a = 1*8; - + vmovsd(mem(rbx ), xmm3) vfmadd231pd(ymm0, ymm3, ymm4) @@ -1312,12 +1312,12 @@ void bli_dgemmsup_rd_haswell_asm_1x4 add(imm(1*8), rbx) // b += 1*rs_b = 1*8; vfmadd231pd(ymm0, ymm3, ymm13) - + dec(rsi) // i -= 1; jne(.DLOOPKLEFT1) // iterate again if i != 0. 
- - - + + + @@ -1325,9 +1325,9 @@ void bli_dgemmsup_rd_haswell_asm_1x4 label(.DPOSTACCUM) - - // ymm4 ymm7 ymm10 ymm13 - + + // ymm4 ymm7 ymm10 ymm13 + vhaddpd( ymm7, ymm4, ymm0 ) vextractf128(imm(1), ymm0, xmm1 ) vaddpd( xmm0, xmm1, xmm0 ) @@ -1339,69 +1339,82 @@ void bli_dgemmsup_rd_haswell_asm_1x4 vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + //vhaddpd( ymm8, ymm5, ymm0 ) + //vextractf128(imm(1), ymm0, xmm1 ) + //vaddpd( xmm0, xmm1, xmm0 ) + + //vhaddpd( ymm14, ymm11, ymm2 ) + //vextractf128(imm(1), ymm2, xmm1 ) + //vaddpd( xmm2, xmm1, xmm2 ) + + //vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + + // xmm4[0:3] = sum(ymm4) sum(ymm7) sum(ymm10) sum(ymm13) + + //mov(var(rs_c), rdi) // load rs_c //lea(mem(, rdi, 4), rdi) // rs_c *= sizeof(float) - + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate - + vmulpd(ymm0, ymm4, ymm4) // scale by alpha - - - - - - + + + + + + //mov(var(cs_c), rsi) // load cs_c //lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) - - - + + + // now avoid loading C if beta == 0 - + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm3) // set ZF if beta == 0. je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case - - + + label(.DROWSTORED) - - + + vfmadd231pd(mem(rcx), ymm3, ymm4) vmovupd(ymm4, mem(rcx)) //add(rdi, rcx) - - - + + + jmp(.DDONE) // jump to end. - - - - + + + + label(.DBETAZERO) - - + + label(.DROWSTORBZ) - - + + vmovupd(ymm4, mem(rcx)) //add(rdi, rcx) - - - - + + + + label(.DDONE) label(.DRETURN) - - + + end_asm( : // output operands (none) From 3d655a951b575b172bfa4c9dc0336d10a0ca5bed Mon Sep 17 00:00:00 2001 From: Devin Matthews Date: Tue, 5 Oct 2021 15:20:27 -0500 Subject: [PATCH 159/243] Fix data race in testsuite. Change-Id: I7704037bad0f7485e7b352de68c2c4535d364226 --- testsuite/src/test_libblis.c | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/testsuite/src/test_libblis.c b/testsuite/src/test_libblis.c index 6bf58831c1..925d2ece1b 100644 --- a/testsuite/src/test_libblis.c +++ b/testsuite/src/test_libblis.c @@ -2386,7 +2386,11 @@ void libblis_test_op_driver // Mark this operation as done. - op->test_done = TRUE; + if ( tdata->id == 0 ) + op->test_done = TRUE; + + // Wait here so that all threads know we are done + bli_pthread_barrier_wait( tdata->barrier ); } From 3c01fcb9fc6c5ea555865421154f1b30e2a647ae Mon Sep 17 00:00:00 2001 From: Minh Quan Ho <1337056+hominhquan@users.noreply.github.com> Date: Tue, 12 Oct 2021 19:53:04 +0200 Subject: [PATCH 160/243] Fix insufficient pool-growing logic in bli_pool.c. (#559) Details: - The current mechanism for growing a pool_t doubles the length of the block_ptrs array every time the array length needs to be increased due to new blocks being added. However, that logic did not take in account the new total number of blocks, and the fact that the caller may be requesting more blocks that would fit even after doubling the current length of block_ptrs. The code comments now contain two illustrating examples that show why, even after doubling, we must always have at least enough room to fit all of the old blocks plus the newly requested blocks. - This commit also happens to fix a memory corruption issue that stems from growing any pool_t that is initialized with a block_ptrs length of 0. 
(Previously, the memory pool for packed buffers of C was initialized with a block_ptrs length of 0, but because it is unused this bug did not manifest by default.) - Co-authored-by: Minh Quan Ho Change-Id: Ie4963c56e03cbc197d26e29f2def6494f0a6046d --- frame/base/bli_pool.c | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/frame/base/bli_pool.c b/frame/base/bli_pool.c index 7e561983c6..de48035e7d 100644 --- a/frame/base/bli_pool.c +++ b/frame/base/bli_pool.c @@ -373,7 +373,15 @@ void bli_pool_grow { // To prevent this from happening often, we double the current // length of the block_ptrs array. - const siz_t block_ptrs_len_new = 2 * block_ptrs_len_cur; + // Sanity: make sure that the block_ptrs_len_new will be at least + // num_blocks_new, in case doubling the block_ptrs_len_cur is not enough. + // Example 1: + // - block_ptrs_len_cur == num_blocks_cur == 0 and num_blocks_add = 1 + // - So doubling: 2 * block_ptrs_len_cur = 0, whereas 1 is expected + // Example 2: + // - block_ptrs_len_cur == num_blocks_cur == 10 and num_blocks_add = 30 + // - So doubling: 2 * block_ptrs_len_cur = 20, whereas 40 is expected + const siz_t block_ptrs_len_new = bli_max( (2 * block_ptrs_len_cur), num_blocks_new ); #ifdef BLIS_ENABLE_MEM_TRACING printf( "bli_pool_grow(): growing block_ptrs_len (%d -> %d): ", From 6d4d6a7514f3cf848f87cc79a40a5e73add0a52b Mon Sep 17 00:00:00 2001 From: Minh Quan Ho <1337056+hominhquan@users.noreply.github.com> Date: Wed, 13 Oct 2021 20:28:02 +0200 Subject: [PATCH 161/243] Alloc at least 1 elem in pool_t block_ptrs. (#560) Details: - Previously, the block_ptrs field of the pool_t was allowed to be initialized as any unsigned integer, including 0. However, a length of 0 could be problematic given that malloc(0) is undefined and therefore variable across implementations. As a safety measure, we check for block_ptrs array lengths of 0 and, in that case, increase them to 1. - Co-authored-by: Minh Quan Ho Change-Id: I1e885d887aaba5e73df091ef52e6c327fd6418de --- frame/base/bli_pool.c | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/frame/base/bli_pool.c b/frame/base/bli_pool.c index de48035e7d..106fbfa10b 100644 --- a/frame/base/bli_pool.c +++ b/frame/base/bli_pool.c @@ -52,6 +52,11 @@ void bli_pool_init // Make sure that block_ptrs_len is at least num_blocks. block_ptrs_len = bli_max( block_ptrs_len, num_blocks ); + // Handle the case where block_ptrs_len is zero, we explicitly set it to 1, + // to avoid any malloc() with zero size, whose behavior is not fixed, and + // also to prevent from falling into any further memory corruption bug. + block_ptrs_len = ( block_ptrs_len == 0 ) ? 1 : block_ptrs_len; + #ifdef BLIS_ENABLE_MEM_TRACING printf( "bli_pool_init(): allocating block_ptrs (length %d): ", ( int )block_ptrs_len ); From 2a81437bd8a052af7c3289e497fa763c3eaaca4d Mon Sep 17 00:00:00 2001 From: "Field G. Van Zee" Date: Mon, 31 May 2021 16:50:18 -0500 Subject: [PATCH 162/243] Fixed bugs in cpackm kernels, gemmlike code. Details: - Fixed intermittent bugs in bli_packm_haswell_asm_c3xk.c and bli_packm_haswell_asm_c8xk.c whereby the imaginary component of the kappa scalar was incorrectly loaded at an offset of 8 bytes (instead of 4 bytes) from the real component. This was almost certainly a copy- paste bug carried over from the corresonding zpackm kernels. Thanks to Devin Matthews for bringing this to my attention. 
- Added missing code to gemmlike sandbox files bls_gemm_bp_var1.c and bls_gemm_bp_var2.c that initializes the elements of the temporary microtile to zero. (This bug was never observed in output but rather noticed analytically. It probably would have also manifested as intermittent failures, this time involving edge cases.) - Minor commented-out/disabled changes to testsuite/src/test_gemm.c relating to debugging. Change-Id: I899e20df203806717fb5270b5f3dd0bf1f685011 --- kernels/haswell/1m/bli_packm_haswell_asm_c3xk.c | 2 +- kernels/haswell/1m/bli_packm_haswell_asm_c8xk.c | 2 +- sandbox/gemmlike/bls_gemm_bp_var1.c | 3 +++ sandbox/gemmlike/bls_gemm_bp_var2.c | 6 ++++++ sandbox/gemmlike/bls_l3_packm_var.c | 14 +++++++------- testsuite/src/test_gemm.c | 14 +++++++------- 6 files changed, 25 insertions(+), 16 deletions(-) diff --git a/kernels/haswell/1m/bli_packm_haswell_asm_c3xk.c b/kernels/haswell/1m/bli_packm_haswell_asm_c3xk.c index b99b6eef26..6b7321859f 100644 --- a/kernels/haswell/1m/bli_packm_haswell_asm_c3xk.c +++ b/kernels/haswell/1m/bli_packm_haswell_asm_c3xk.c @@ -125,7 +125,7 @@ void bli_cpackm_haswell_asm_3xk mov(var(kappa), rcx) // load address of kappa vbroadcastss(mem(rcx, 0), ymm10) // load kappa_r and duplicate - vbroadcastss(mem(rcx, 8), ymm11) // load kappa_i and duplicate + vbroadcastss(mem(rcx, 4), ymm11) // load kappa_i and duplicate // now branch on kappa == 1.0 diff --git a/kernels/haswell/1m/bli_packm_haswell_asm_c8xk.c b/kernels/haswell/1m/bli_packm_haswell_asm_c8xk.c index 4cad0c90c3..6184447d58 100644 --- a/kernels/haswell/1m/bli_packm_haswell_asm_c8xk.c +++ b/kernels/haswell/1m/bli_packm_haswell_asm_c8xk.c @@ -125,7 +125,7 @@ void bli_cpackm_haswell_asm_8xk mov(var(kappa), rcx) // load address of kappa vbroadcastss(mem(rcx, 0), ymm10) // load kappa_r and duplicate - vbroadcastss(mem(rcx, 8), ymm11) // load kappa_i and duplicate + vbroadcastss(mem(rcx, 4), ymm11) // load kappa_i and duplicate // now branch on kappa == 1.0 diff --git a/sandbox/gemmlike/bls_gemm_bp_var1.c b/sandbox/gemmlike/bls_gemm_bp_var1.c index ae695ce34f..330a94801b 100644 --- a/sandbox/gemmlike/bls_gemm_bp_var1.c +++ b/sandbox/gemmlike/bls_gemm_bp_var1.c @@ -230,6 +230,9 @@ void PASTECH2(bls_,ch,varname) \ thrinfo_t* restrict thread_pa = NULL; \ thrinfo_t* restrict thread_jr = NULL; \ thrinfo_t* restrict thread_ir = NULL; \ +\ + /* Clear the temporary C buffer in case it has any infs or NaNs. */ \ + PASTEMAC(ch,set0s_mxn)( MR, NR, ct, rs_ct, cs_ct ); \ \ /* Identify the current thrinfo_t node and then grow the tree. */ \ thread_jc = thread; \ diff --git a/sandbox/gemmlike/bls_gemm_bp_var2.c b/sandbox/gemmlike/bls_gemm_bp_var2.c index 957cd57944..22df767aea 100644 --- a/sandbox/gemmlike/bls_gemm_bp_var2.c +++ b/sandbox/gemmlike/bls_gemm_bp_var2.c @@ -538,6 +538,12 @@ void PASTECH2(bls_,ch,varname) \ const inc_t cs_ct = ( col_pref ? MR : 1 ); \ \ ctype zero = *PASTEMAC(ch,0); \ +\ + /* Clear the temporary C buffer in case it has any infs or NaNs. + NOTE: This initialization should really be done statically since + var2 executes this microkernel wrapper many times, and the overhead + of touching the temporary microtile adds up. */ \ + PASTEMAC(ch,set0s_mxn)( MR, NR, ct, rs_ct, cs_ct ); \ \ /* Handle interior and edge cases separately. 
*/ \ if ( mr_cur == MR && nr_cur == NR ) \ diff --git a/sandbox/gemmlike/bls_l3_packm_var.c b/sandbox/gemmlike/bls_l3_packm_var.c index 8a4c1d0206..3265ef834d 100644 --- a/sandbox/gemmlike/bls_l3_packm_var.c +++ b/sandbox/gemmlike/bls_l3_packm_var.c @@ -176,17 +176,17 @@ void PASTECH2(bls_,ch,varname) \ cntx \ ); \ } \ -\ - p_begin += ps_p; \ \ /* -if ( row_stored ) \ -PASTEMAC(ch,fprintm)( stdout, "packm_sup_var1: b packed", panel_len_max, panel_dim_max, \ - p_use, rs_p, cs_p, "%5.2f", "" ); \ if ( !row_stored ) \ -PASTEMAC(ch,fprintm)( stdout, "packm_sup_var1: a packed", panel_dim_max, panel_len_max, \ - p_use, rs_p, cs_p, "%5.2f", "" ); \ +PASTEMAC(ch,fprintm)( stdout, "packm_var1: a packed", panel_dim_max, panel_len_max, \ + p_use, rs_p, cs_p, "%5.2f", "" ); \ +else \ +PASTEMAC(ch,fprintm)( stdout, "packm_var1: b packed", panel_len_max, panel_dim_max, \ + p_use, rs_p, cs_p, "%5.2f", "" ); \ */ \ +\ + p_begin += ps_p; \ } \ } diff --git a/testsuite/src/test_gemm.c b/testsuite/src/test_gemm.c index fc25e74095..1182f07e27 100644 --- a/testsuite/src/test_gemm.c +++ b/testsuite/src/test_gemm.c @@ -267,18 +267,17 @@ void libblis_test_gemm_experiment } #endif + #if 0 + //bli_setm( &BLIS_ONE, &a ); + bli_setsc( 1.0, 0.0, &alpha ); + bli_setsc( 1.0, 0.0, &beta ); + #endif + // Randomize A, B, and C, and save C. libblis_test_mobj_randomize( params, TRUE, &a ); libblis_test_mobj_randomize( params, TRUE, &b ); libblis_test_mobj_randomize( params, TRUE, &c ); bli_copym( &c, &c_save ); -//bli_setm( &BLIS_ONE, &a ); -//bli_setsc( 1.0, 0.0, &alpha ); -//bli_setsc( 0.0, 0.0, &beta ); - -//bli_setm( &BLIS_ONE, &a ); -//bli_setsc( 1.0, 0.0, &alpha ); -//bli_setsc( 0.0, 0.0, &beta ); // Apply the parameters. bli_obj_set_conjtrans( transa, &a ); @@ -482,6 +481,7 @@ if ( bli_obj_length( c ) == 12 && bli_obj_stor3_from_strides( c, a, b ) == BLIS_RRR ) bli_printm( "c after", c, "%6.3f", "" ); #endif +//bli_printm( "c after", c, "%5.2f", "" ); break; default: From 76fbf1233d68d00a54ebf62691d6457cef623c8b Mon Sep 17 00:00:00 2001 From: Devin Matthews Date: Fri, 9 Jul 2021 14:59:48 -0500 Subject: [PATCH 163/243] Add vzeroupper to Haswell microkernels. (#524) Details: - Added vzeroupper instruction to the end of all 'gemm' and 'gemmtrsm' microkernels so as to avoid a performance penalty when mixing AVX and SSE instructions. These vzeroupper instructions were once part of the haswell kernels, but were inadvertently removed during a source code shuffle some time ago when we were managing duplicate 'haswell' and 'zen' kernel sets. Thanks to Devin Matthews for tracking this down and re-inserting the missing instructions. 
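 For context on why the one-line addition matters: leaving the upper halves of the YMM registers dirty and then executing legacy SSE instructions (as framework code outside the kernel may do) triggers a costly AVX/SSE state transition on these microarchitectures. The restored idiom, sketched in the same asm-macro style as the kernels (label shown is the one used in the sgemm kernel; operand lists of end_asm omitted):

	label(.SDONE)

	vzeroupper()    // clear the upper YMM state before leaving the asm
	                // block, so later SSE code incurs no transition stall

	end_asm( ... )
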
Change-Id: I418fea9fed27ba3ad7d395cf96d1be507955d8e9 --- kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c b/kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c index 59e239fe14..e6d47268f7 100644 --- a/kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c +++ b/kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c @@ -870,7 +870,7 @@ void bli_sgemm_haswell_asm_6x16 label(.SDONE) - + vzeroupper() end_asm( : // output operands (none) @@ -1624,6 +1624,7 @@ void bli_dgemm_haswell_asm_6x8 label(.DDONE) + vzeroupper() @@ -2158,7 +2159,7 @@ void bli_cgemm_haswell_asm_3x8 label(.CDONE) - + vzeroupper() end_asm( : // output operands (none) @@ -2758,7 +2759,7 @@ void bli_zgemm_haswell_asm_3x4 label(.ZDONE) - + vzeroupper() end_asm( : // output operands (none) From e5d4fc2a70d66a49a053b01c01a8d4638a4f783a Mon Sep 17 00:00:00 2001 From: Harihara Sudhan S Date: Thu, 21 Jul 2022 14:28:38 +0530 Subject: [PATCH 164/243] Added low precision GEMM (u8s8s16os16) Feature Addition : Added low precision GEMM to addon. The kernel takes unsigned int8 and signed int8 as inputs and performs GEMM operation. The intermediate accumulation and output are in signed int16. - The compute kernel will perform computation only if B matrix reordered to suit the usage of AVX2 instruction vpmaddubsw. - Kernel for packing the B matrix is provided. - LPGEMM bench code was modified to test the performance and accuracy of the new variant. AMD-Internal: [CPUPL-2171] Change-Id: Id9a6d90b79f4bf82fb2e2f3093974dbf37275f9b --- addon/aocl_gemm/aocl_gemm.h | 4 +- addon/aocl_gemm/aocl_gemm_u8s8s16os16.c | 136 +++ addon/aocl_gemm/aocl_gemm_u8s8s16os16.h | 57 ++ addon/aocl_gemm/aocl_gemm_u8s8s16os16_utils.c | 121 +++ addon/aocl_gemm/aocl_gemm_u8s8s16os16_utils.h | 51 ++ .../frame/u8s8s16/lpgemm_reorder_s16.c | 94 ++ .../frame/u8s8s16/lpgemm_reorder_s16.h | 45 + .../aocl_gemm/frame/u8s8s16/lpgemm_u8s8s16.c | 108 +++ .../aocl_gemm/frame/u8s8s16/lpgemm_u8s8s16.h | 58 ++ .../kernels/u8s8s16/lpgemm_6x32rowmajor.h | 58 ++ .../u8s8s16/lpgemm_6x32rowmajor_amd256.c | 477 ++++++++++ .../kernels/u8s8s16/lpgemm_m_fringe_amd256.c | 537 +++++++++++ .../kernels/u8s8s16/lpgemm_m_fringe_s16.h | 80 ++ .../kernels/u8s8s16/lpgemm_mn_fringe_amd256.c | 854 ++++++++++++++++++ .../kernels/u8s8s16/lpgemm_mn_fringe_s16.h | 119 +++ .../kernels/u8s8s16/lpgemm_n_fringe_amd256.c | 683 ++++++++++++++ .../kernels/u8s8s16/lpgemm_n_fringe_s16.h | 71 ++ .../kernels/u8s8s16/lpgemm_packb_amd256.c | 252 ++++++ .../kernels/u8s8s16/lpgemm_packb_s16.h | 47 + bench/bench_aocl_gemm/bench_lpgemm.c | 18 +- 20 files changed, 3866 insertions(+), 4 deletions(-) create mode 100644 addon/aocl_gemm/aocl_gemm_u8s8s16os16.c create mode 100644 addon/aocl_gemm/aocl_gemm_u8s8s16os16.h create mode 100644 addon/aocl_gemm/aocl_gemm_u8s8s16os16_utils.c create mode 100644 addon/aocl_gemm/aocl_gemm_u8s8s16os16_utils.h create mode 100644 addon/aocl_gemm/frame/u8s8s16/lpgemm_reorder_s16.c create mode 100644 addon/aocl_gemm/frame/u8s8s16/lpgemm_reorder_s16.h create mode 100644 addon/aocl_gemm/frame/u8s8s16/lpgemm_u8s8s16.c create mode 100644 addon/aocl_gemm/frame/u8s8s16/lpgemm_u8s8s16.h create mode 100644 addon/aocl_gemm/kernels/u8s8s16/lpgemm_6x32rowmajor.h create mode 100644 addon/aocl_gemm/kernels/u8s8s16/lpgemm_6x32rowmajor_amd256.c create mode 100644 addon/aocl_gemm/kernels/u8s8s16/lpgemm_m_fringe_amd256.c create mode 100644 addon/aocl_gemm/kernels/u8s8s16/lpgemm_m_fringe_s16.h create mode 100644 
addon/aocl_gemm/kernels/u8s8s16/lpgemm_mn_fringe_amd256.c create mode 100644 addon/aocl_gemm/kernels/u8s8s16/lpgemm_mn_fringe_s16.h create mode 100644 addon/aocl_gemm/kernels/u8s8s16/lpgemm_n_fringe_amd256.c create mode 100644 addon/aocl_gemm/kernels/u8s8s16/lpgemm_n_fringe_s16.h create mode 100644 addon/aocl_gemm/kernels/u8s8s16/lpgemm_packb_amd256.c create mode 100644 addon/aocl_gemm/kernels/u8s8s16/lpgemm_packb_s16.h diff --git a/addon/aocl_gemm/aocl_gemm.h b/addon/aocl_gemm/aocl_gemm.h index 446bdc11b7..9316bb7bdc 100644 --- a/addon/aocl_gemm/aocl_gemm.h +++ b/addon/aocl_gemm/aocl_gemm.h @@ -35,9 +35,11 @@ #ifndef BLIS_ADDON_LPGEMM #define BLIS_ADDON_LPGEMM +#include "aocl_gemm_u8s8s16os16.h" #include "aocl_gemm_u8s8s32os32.h" #include "aocl_gemm_f32f32f32of32.h" +#include "aocl_gemm_u8s8s16os16_utils.h" #include "aocl_gemm_u8s8s32os32_utils.h" #include "aocl_gemm_f32f32f32of32_utils.h" -#endif //BLIS_ADDON_LPGEMM +#endif // BLIS_ADDON_LPGEMM diff --git a/addon/aocl_gemm/aocl_gemm_u8s8s16os16.c b/addon/aocl_gemm/aocl_gemm_u8s8s16os16.c new file mode 100644 index 0000000000..8ff3c20247 --- /dev/null +++ b/addon/aocl_gemm/aocl_gemm_u8s8s16os16.c @@ -0,0 +1,136 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "aocl_gemm_u8s8s16os16.h" +#include "lpgemm_types.h" +#include "lpgemm_u8s8s16.h" +#include "lpgemm_config.h" +#include "lpgemm_utils.h" + +void aocl_gemm_u8s8s16os16( + const char transa, + const char transb, + const dim_t m, + const dim_t n, + const dim_t k, + const int16_t alpha, + const uint8_t *a, + const dim_t lda, + const char mem_format_a, + const int8_t *b, + const dim_t ldb, + const char mem_format_b, + const int16_t beta, + int16_t *c, + const dim_t ldc) +{ + trans_t blis_transa; + trans_t blis_transb; + + /* Initialize BLIS. */ + bli_init_auto(); + + // Set MC, NC, KC, NR, MR. + aocl_lpgemm_init_global_cntx(); + + // Null check for pointers. 
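/* [Editor's note, not part of the patch] Before the argument checks below,
   a sketch of how this entry point is meant to be driven end to end: B is
   reordered once via the utility functions added in this commit and then
   reused across GEMM calls. The function names and the lda == k, ldb == n,
   ldc == n requirements come from this patch; the character codes ('B',
   'r', 'n'), the alpha/beta values and the helper itself are illustrative
   assumptions only -- see the bench_lpgemm.c changes in this commit for the
   codes the test driver actually uses. Requires <stdlib.h> for malloc/free. */
static void example_u8s8s16os16_usage
     (
       dim_t m, dim_t n, dim_t k,
       const uint8_t* a,   // m x k, row-major, lda = k
       const int8_t*  b,   // k x n, row-major, ldb = n
       int16_t*       c    // m x n, row-major, ldc = n
     )
{
    // Query and allocate the reorder buffer for B.
    siz_t   b_size  = aocl_get_reorder_buf_size_u8s8s16os16( 'B', k, n );
    int8_t* b_reord = ( int8_t* )malloc( b_size );

    // Reorder B into the vpmaddubsw-friendly layout.
    aocl_reorder_u8s8s16os16( 'B', b, b_reord, k, n, n );

    // Row-major GEMM with unpacked A and reordered B.
    aocl_gemm_u8s8s16os16( 'n', 'n', m, n, k,
                           ( int16_t )1, a, k, 'n',
                           b_reord, n, 'r',
                           ( int16_t )0, c, n );

    free( b_reord );
}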
+ if ((a == NULL) || (b == NULL) || (c == NULL)) + { + return; // Error. + } + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + bli_param_map_netlib_to_blis_trans(transa, &blis_transa); + bli_param_map_netlib_to_blis_trans(transb, &blis_transb); + + /* Perform BLAS parameter checking. */ + // Transpose not supported. + if ((blis_transa != BLIS_NO_TRANSPOSE) || (blis_transb != BLIS_NO_TRANSPOSE)) + { + return; // Error. + } + + // Row major input expected with leading dimensions equal to row stride. + if ((lda != k) || (ldb != n) || (ldc != n)) + { + return; // Error. + } + + // Check if dimensions are valid. + if ((m <= 0) || (n <= 0) || (k <= 0) || (lda <= 0) || (ldb <= 0) || (ldc <= 0)) + { + return; // Error. + } + + const inc_t rs_a = lda; + const inc_t cs_a = 1; + const inc_t rs_b = ldb; + const inc_t cs_b = 1; + const inc_t rs_c = ldc; + + AOCL_MEMORY_TAG mtag_a; + AOCL_MEMORY_TAG mtag_b; + + bli_param_map_char_to_lpmtag(mem_format_a, &mtag_a); + bli_param_map_char_to_lpmtag(mem_format_b, &mtag_b); + + // B matrix needs to be packed in a certain format in order to be loaded + // and used in VNNI instrution. As such the mtag_b always needs to be either + // packed or reordered. B matrix as it is (unpacked) cannot be used, and + // the mtag_b is set to packed to enable runtime packing. + if (mtag_b == UNPACKED) + { + return; // Error. + } + + // Only unpacked A supported now. + if (mtag_a != UNPACKED) + { + return; // Error. + } + + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. + rntm_t rntm_g; + bli_rntm_init_from_global(&rntm_g); + bli_membrk_rntm_set_membrk(&rntm_g); + + lpgemm_rowvar_u8s8s16o16( + m, n, k, + a, rs_a, cs_a, + b, rs_b, cs_b, + c, rs_c, + alpha, beta); +} diff --git a/addon/aocl_gemm/aocl_gemm_u8s8s16os16.h b/addon/aocl_gemm/aocl_gemm_u8s8s16os16.h new file mode 100644 index 0000000000..4e56c705bd --- /dev/null +++ b/addon/aocl_gemm/aocl_gemm_u8s8s16os16.h @@ -0,0 +1,57 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef AOCL_GEMM_U8S8S16OS16_H +#define AOCL_GEMM_U8S8S16OS16_H + +// Only supports matrices in row major format +// Limitations: Supports mem_format_b = 'Reorder' +BLIS_EXPORT_ADDON void aocl_gemm_u8s8s16os16( + const char transa, + const char transb, + const dim_t m, + const dim_t n, + const dim_t k, + const int16_t alpha, + const uint8_t *a, + const dim_t lda, + const char mem_format_a, + const int8_t *b, + const dim_t ldb, + const char mem_format_b, + const int16_t beta, + int16_t *c, + const dim_t ldc); + +#endif // AOCL_GEMM_U8S8S16OS16_H diff --git a/addon/aocl_gemm/aocl_gemm_u8s8s16os16_utils.c b/addon/aocl_gemm/aocl_gemm_u8s8s16os16_utils.c new file mode 100644 index 0000000000..993dbb2dfb --- /dev/null +++ b/addon/aocl_gemm/aocl_gemm_u8s8s16os16_utils.c @@ -0,0 +1,121 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "aocl_gemm_u8s8s16os16_utils.h" +#include "lpgemm_types.h" +#include "lpgemm_config.h" +#include "lpgemm_utils.h" +#include "lpgemm_reorder_s16.h" + +siz_t aocl_get_reorder_buf_size_u8s8s16os16( + const char mat_type, + const dim_t k, + const dim_t n) +{ + if ((k <= 0) || (n <= 0)) + { + return 0; // Error. + } + + /* Initialize BLIS. */ + bli_init_auto(); + + // Set MC, NC, KC, NR, MR. 
+ aocl_lpgemm_init_global_cntx(); + + AOCL_MATRIX_TYPE input_mat_type; + bli_param_map_char_to_lpmat_type(mat_type, &input_mat_type); + + if (input_mat_type == A_MATRIX) + { + return 0; // A reorder not supported. + } + + // Extra space since packing does width in multiples of 16. The vpmaddubsw + // instruction can be used as long as atleast one ymm register can be fully + // loaded; and since k_dim needs to be at least 2, having n_dim atleast 16 + // should give 2x16=32 elements, enough for 1 ymm register.The padding is + // not rounded to NR (=16), since that would result in memory wastage. + dim_t n_reorder = make_multiple_of_n(n, 16); + + // Extra space since packing does length in multiples of 2. + dim_t k_reorder = make_multiple_of_n(k, 2); + + siz_t size_req = sizeof(int8_t) * k_reorder * n_reorder; + + return size_req; +} + +void aocl_reorder_u8s8s16os16( + const char mat_type, + const int8_t *input_buf_addr, + int8_t *reorder_buf_addr, + const dim_t k, + const dim_t n, + const dim_t ldb) +{ + if ((input_buf_addr == NULL) || (reorder_buf_addr == NULL) || + (k <= 0) || (n <= 0) || (ldb < n)) + { + return; // Error. + } + + /* Initialize BLIS. */ + bli_init_auto(); + + // Set MC, NC, KC, NR, MR. + aocl_lpgemm_init_global_cntx(); + + AOCL_MATRIX_TYPE input_mat_type; + bli_param_map_char_to_lpmat_type(mat_type, &input_mat_type); + + if (input_mat_type == A_MATRIX) + { + return; // A reorder not supported. + } + + // Create dummy b_reorder obj. + lpgemm_obj_t b_reorder; + b_reorder.storage.aligned_buffer = reorder_buf_addr; + + // Create dummy original b obj; + lpgemm_obj_t b; + b.storage.aligned_buffer = (void *)input_buf_addr; + b.rs = ldb; + b.width = n; + b.length = k; + + aocl_reorderb_nr32_u8s8s16o16(&b, &b_reorder); +} diff --git a/addon/aocl_gemm/aocl_gemm_u8s8s16os16_utils.h b/addon/aocl_gemm/aocl_gemm_u8s8s16os16_utils.h new file mode 100644 index 0000000000..5f76da8b38 --- /dev/null +++ b/addon/aocl_gemm/aocl_gemm_u8s8s16os16_utils.h @@ -0,0 +1,51 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef AOCL_GEMM_U8S8S16OS16_UTILS_H +#define AOCL_GEMM_U8S8S16OS16_UTILS_H + +BLIS_EXPORT_ADDON siz_t aocl_get_reorder_buf_size_u8s8s16os16( + const char mat_type, + const dim_t k, + const dim_t n); + +BLIS_EXPORT_ADDON void aocl_reorder_u8s8s16os16( + const char mat_type, + const int8_t *input_buf_addr, + int8_t *reorder_buf_addr, + const dim_t k, + const dim_t n, + const dim_t ldb); + +#endif // AOCL_GEMM_U8S8S16OS16_UTILS_H \ No newline at end of file diff --git a/addon/aocl_gemm/frame/u8s8s16/lpgemm_reorder_s16.c b/addon/aocl_gemm/frame/u8s8s16/lpgemm_reorder_s16.c new file mode 100644 index 0000000000..e900559943 --- /dev/null +++ b/addon/aocl_gemm/frame/u8s8s16/lpgemm_reorder_s16.c @@ -0,0 +1,94 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
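[Editor's note, not part of the patch] A quick numeric check of the sizing comment in aocl_get_reorder_buf_size_u8s8s16os16() above, using a hypothetical shape: for k = 101 and n = 50, k is padded to 102 and n to 64, so the requested reorder buffer is sizeof(int8_t) * 102 * 64 = 6528 bytes, about 29% more than the 101 * 50 = 5050 bytes of actual B payload for this deliberately awkward shape. For shapes already aligned to 2 in k and 16 in n there is no padding overhead.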
+ +*/ +#include "blis.h" +#include "lpgemm_utils.h" +#include "lpgemm_reorder_s16.h" +#include "lpgemm_packb_s16.h" +#include "lpgemm_config.h" + +void aocl_reorderb_nr32_u8s8s16o16 + ( + lpgemm_obj_t *b, + lpgemm_obj_t *b_reorder + ) +{ + // To Do: Constant declaration's to be moved to config + const dim_t NC = 1024; + const dim_t KC = 1024; + + // Extracting the matrix properties from the lpgemm object + dim_t rs_b = b->rs; + dim_t n = b->width; + dim_t k = b->length; + + dim_t rs_b_reorder; + dim_t cs_b_reorder; + + dim_t k_updated = k; + + // Making multiple of 2 to suit k in vpmaddubsw + k_updated += (k_updated & 0x1); + + for (dim_t jc = 0; jc < n; jc += NC) + { + dim_t nc0 = ((jc + NC) <= n) ? NC : (n % NC); + + // nc0 needs to be a multiple of 16 since this gives maximum + // vectorization. Packing B always results in buffers with width + // which is a multiple of 16. Subsequently the nc0 offsets used + // for packed/reordered buffers needs to be updated. + dim_t nc0_mod16 = nc0 % 16; + dim_t nc0_updated = nc0; + if (nc0_mod16 != 0) + { + nc0_updated += (16 - nc0_mod16); + } + + for (dim_t pc = 0; pc < k; pc += KC) + { + dim_t kc0 = ((pc + KC) <= k) ? KC : (k % KC); + + // B should always be packed. + packb_nr32_u8s8s16o16( + (((int8_t *)b_reorder->storage.aligned_buffer) + (jc * k_updated) + (nc0_updated * pc)), + (((int8_t *)b->storage.aligned_buffer) + (rs_b * pc) + jc), + rs_b, nc0, kc0, &rs_b_reorder, &cs_b_reorder); + } + } + + // Changing the packed matrix properties in the packed matrix object + b_reorder->rs = rs_b_reorder; + b_reorder->cs = cs_b_reorder; + b_reorder->mtag = REORDERED; +} diff --git a/addon/aocl_gemm/frame/u8s8s16/lpgemm_reorder_s16.h b/addon/aocl_gemm/frame/u8s8s16/lpgemm_reorder_s16.h new file mode 100644 index 0000000000..1b107d634b --- /dev/null +++ b/addon/aocl_gemm/frame/u8s8s16/lpgemm_reorder_s16.h @@ -0,0 +1,45 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
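[Editor's note, not part of the patch] On the destination offset (jc * k_updated) + (nc0_updated * pc) used in aocl_reorderb_nr32_u8s8s16o16() above: column panels of the reordered B are laid out with a stride of k_updated bytes per column, so the panel starting at column jc begins at jc * k_updated, and inside a panel each KC-deep slice of 16-padded width nc0_updated begins at nc0_updated * pc. As a hypothetical instance with NC = KC = 1024, n = 2048 and k = 1000 (already even, so k_updated = 1000), the second column panel (jc = 1024) starts 1024 * 1000 = 1,024,000 bytes into the buffer, and each panel holds a single KC block because k < KC.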
+ +*/ +#ifndef LPGEMM_REORDER_S16_H +#define LPGEMM_REORDER_S16_H + +#include "lpgemm_types.h" + +void aocl_reorderb_nr32_u8s8s16o16 + ( + lpgemm_obj_t *b, + lpgemm_obj_t *b_reorder + ); + +#endif // LPGEMM_REORDER_H \ No newline at end of file diff --git a/addon/aocl_gemm/frame/u8s8s16/lpgemm_u8s8s16.c b/addon/aocl_gemm/frame/u8s8s16/lpgemm_u8s8s16.c new file mode 100644 index 0000000000..d8e725a755 --- /dev/null +++ b/addon/aocl_gemm/frame/u8s8s16/lpgemm_u8s8s16.c @@ -0,0 +1,108 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "lpgemm_u8s8s16.h" +#include "lpgemm_packb.h" +#include "lpgemm_6x32rowmajor.h" +#include "lpgemm_utils.h" +#include "lpgemm_config.h" + +void lpgemm_rowvar_u8s8s16o16 + ( + const dim_t m, + const dim_t n, + const dim_t k, + const uint8_t *a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t *b, + const dim_t rs_b, + const dim_t cs_b, + int16_t *c, + const dim_t rs_c, + int16_t alpha, + int16_t beta + ) +{ + // To Do: Constant declaration's to be moved to config files + dim_t NC = 1024; + dim_t KC = 1024; + dim_t MC = 144; + dim_t NR = 32; + + const int8_t *b_use; + const uint8_t *a_use; + + for (dim_t jc = 0; jc < n; jc += NC) + { + dim_t nc0 = ((jc + NC) <= n) ? NC : (n % NC); + + for (dim_t pc = 0; pc < k; pc += KC) + { + int32_t beta0 = (pc == 0) ? beta : 1; + dim_t kc0 = ((pc + KC) <= k) ? KC : (k % KC); + + int kc0_updated = kc0; + + // Making multiple of 2 to suit k in vpmaddubsw + kc0_updated += (kc0_updated & 0x1); + + // B part getting processed + b_use = b + (jc * k) + (pc * nc0); + + for (dim_t ic = 0; ic < m; ic += MC) + { + dim_t mc0 = ((ic + MC) <= m) ? MC : (m % MC); + + a_use = a + (rs_a * ic) + (cs_a * pc); + + dim_t a_block_stride = rs_a; + + for (dim_t jr = 0; jr < nc0; jr += NR) + { + dim_t nr0 = ((jr + NR) <= nc0) ? 
NR : (nc0 % NR); + + // Calls for reorder B + lpgemm_rowvar_u8s8s16o16_6x32( + mc0, nr0, kc0, + a_use, rs_a, cs_a, a_block_stride, + (b_use + (jr * kc0_updated)), rs_b, cs_b, + (c + (rs_c * ic) + jc + jr), rs_c, 1, + alpha, beta0); + } + } + } + } +} diff --git a/addon/aocl_gemm/frame/u8s8s16/lpgemm_u8s8s16.h b/addon/aocl_gemm/frame/u8s8s16/lpgemm_u8s8s16.h new file mode 100644 index 0000000000..4802c90c29 --- /dev/null +++ b/addon/aocl_gemm/frame/u8s8s16/lpgemm_u8s8s16.h @@ -0,0 +1,58 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef LPGEMM_U8S8S16_H +#define LPGEMM_U8S8S16_H + +#include "blis.h" +#include "lpgemm_types.h" + +void lpgemm_rowvar_u8s8s16o16 + ( + const dim_t m, + const dim_t n, + const dim_t k, + const uint8_t *a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t *b, + const dim_t rs_b, + const dim_t cs_b, + int16_t *c, + const dim_t rs_c, + int16_t alpha, + int16_t beta + ); + +#endif // LPGEMM_U8S8S16_H \ No newline at end of file diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_6x32rowmajor.h b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_6x32rowmajor.h new file mode 100644 index 0000000000..dca3170d3f --- /dev/null +++ b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_6x32rowmajor.h @@ -0,0 +1,58 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. 
+ - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_GEMM_INT16_MNROW +#define BLIS_GEMM_INT16_MNROW + +// 6x32 int8o16 kernel +void lpgemm_rowvar_u8s8s16o16_6x32 + ( + const dim_t m0, + const dim_t n0, + const dim_t k0, + const uint8_t *a, + const dim_t rs_a, + const dim_t cs_a, + const dim_t ps_a, + const int8_t *b, + const dim_t rs_b, + const dim_t cs_b, + int16_t *c, + const dim_t rs_c, + const dim_t cs_c, + const int16_t alpha, + const int16_t beta + ); + +#endif // BLIS_GEMM_INT16_MNROW \ No newline at end of file diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_6x32rowmajor_amd256.c b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_6x32rowmajor_amd256.c new file mode 100644 index 0000000000..3c2c8e90cb --- /dev/null +++ b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_6x32rowmajor_amd256.c @@ -0,0 +1,477 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
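[Editor's note, not part of the patch] The kernels in this file are built around vpmaddubsw (_mm256_maddubs_epi16): each 16-bit output lane is the saturating sum a[2i]*b[2i] + a[2i+1]*b[2i+1], with the first operand read as unsigned bytes and the second as signed bytes. That is why A is broadcast two k-values at a time (the uint16_t broadcasts below) and why B must be packed with pairs of k interleaved, so a single instruction yields a k = 2 partial dot product for 16 output columns per ymm register. A standalone sketch with hypothetical values, illustrative only:

#include <stdio.h>
#include <stdint.h>
#include <immintrin.h>

int main( void )
{
    uint8_t a[ 32 ];
    int8_t  b[ 32 ];
    int16_t c[ 16 ];

    for ( int i = 0; i < 32; i++ )
    {
        a[ i ] = ( uint8_t )( i + 1 );   // unsigned operand (like matrix A)
        b[ i ] = ( int8_t  )( i - 16 );  // signed operand (like packed B)
    }

    __m256i va = _mm256_loadu_si256( ( const __m256i* )a );
    __m256i vb = _mm256_loadu_si256( ( const __m256i* )b );

    // Pairwise u8 x s8 products, horizontally added into saturated int16.
    __m256i vc = _mm256_maddubs_epi16( va, vb );
    _mm256_storeu_si256( ( __m256i* )c, vc );

    // c[0] = a[0]*b[0] + a[1]*b[1] = 1*(-16) + 2*(-15) = -46
    printf( "c[0] = %d\n", c[ 0 ] );
    return 0;
}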
+ +*/ +#include +#include "blis.h" +#include "lpgemm_6x32rowmajor.h" +#include "lpgemm_m_fringe_s16.h" +#include "lpgemm_n_fringe_s16.h" + +// 6x32 int8o16 kernel +void lpgemm_rowvar_u8s8s16o16_6x32 + ( + const dim_t m0, + const dim_t n0, + const dim_t k0, + const uint8_t *a, + const dim_t rs_a, + const dim_t cs_a, + const dim_t ps_a, + const int8_t *b, + const dim_t rs_b, + const dim_t cs_b, + int16_t *c, + const dim_t rs_c, + const dim_t cs_c, + const int16_t alpha, + const int16_t beta + ) +{ + dim_t MR = 6; + dim_t NR = 32; + + dim_t m_full_pieces = m0 / MR; + dim_t m_full_pieces_loop_limit = m_full_pieces * MR; + dim_t m_partial_pieces = m0 % MR; + + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + // When n fringe cases are encountered + if (n0 < NR) + { + // Split into multiple smaller fringe kernels, so as to maximize + // vectorization after packing. Any n0 < NR(32) can be expressed + // as n0 = 16 + n`. + dim_t n0_rem = n0 % 16; + dim_t n0_16 = n0 / 16; + dim_t k0_updated = k0; + + // Making multiple of 2 to suit k in vpmaddubsw + k0_updated += (k0_updated & 0x1); + + if (n0_16 == 1) + { + lpgemm_rowvar_u8s8s16o16_6x16( + m0, k0_updated, + a, rs_a, cs_a, ps_a, + b, ((rs_b / 2) * 1), cs_b, + c, rs_c, + alpha, beta); + + b = b + (16 * k0_updated); + c = c + 16; + } + + if (n0_rem > 0) + { + lpgemm_rowvar_u8s8s16o16_6xlt16( + m0, k0_updated, + a, rs_a, cs_a, ps_a, + b, ((rs_b / 2) * 1), cs_b, + c, rs_c, + alpha, beta, n0_rem); + } + + // If fringe cases are encountered, return early + return; + } + + // B matrix storage. + __m256i b0, b1; + + // A matrix storage. + __m256i a_int32_0, a_int32_1; + + // Intermediate vectors + __m256i inter_vec[4]; + + for (dim_t ir = 0; ir < m_full_pieces_loop_limit; ir += MR) + { + // Registers to use for accumulating C. + __m256i c_int16_0p0 = _mm256_setzero_si256(); + __m256i c_int16_0p1 = _mm256_setzero_si256(); + + __m256i c_int16_1p0 = _mm256_setzero_si256(); + __m256i c_int16_1p1 = _mm256_setzero_si256(); + + __m256i c_int16_2p0 = _mm256_setzero_si256(); + __m256i c_int16_2p1 = _mm256_setzero_si256(); + + __m256i c_int16_3p0 = _mm256_setzero_si256(); + __m256i c_int16_3p1 = _mm256_setzero_si256(); + + __m256i c_int16_4p0 = _mm256_setzero_si256(); + __m256i c_int16_4p1 = _mm256_setzero_si256(); + + __m256i c_int16_5p0 = _mm256_setzero_si256(); + __m256i c_int16_5p1 = _mm256_setzero_si256(); + + for (dim_t kr = 0; kr < k_full_pieces; kr += 1) + { + dim_t offset = kr * 2; + + // Broadcast a[0,kr:kr+2]. + a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 0) + (cs_a * offset))); + + b0 = _mm256_loadu_si256((__m256i const *)(b + (64 * kr) + (NR * 0))); + b1 = _mm256_loadu_si256((__m256i const *)(b + (64 * kr) + (NR * 1))); + + // Broadcast a[1,kr:kr+2]. + a_int32_1 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 1) + (cs_a * offset))); + + // Seperate register for intermediate op + inter_vec[0] = _mm256_maddubs_epi16(a_int32_0, b0); + inter_vec[1] = _mm256_maddubs_epi16(a_int32_0, b1); + + // Perform column direction mat-mul with k = 2. + // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_0p0 = _mm256_add_epi16(inter_vec[0], c_int16_0p0); + c_int16_0p1 = _mm256_add_epi16(inter_vec[1], c_int16_0p1); + + // Broadcast a[2,kr:kr+2]. + a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 2) + (cs_a * offset))); + + // Seperate register for intermediate op + inter_vec[2] = _mm256_maddubs_epi16(a_int32_1, b0); + inter_vec[3] = _mm256_maddubs_epi16(a_int32_1, b1); + + // Perform column direction mat-mul with k = 2. 
+ // c[1,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_1p0 = _mm256_add_epi16(inter_vec[2], c_int16_1p0); + c_int16_1p1 = _mm256_add_epi16(inter_vec[3], c_int16_1p1); + + // Broadcast a[3,kr:kr+2]. + a_int32_1 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 3) + (cs_a * offset))); + + // Seperate register for intermediate op + inter_vec[0] = _mm256_maddubs_epi16(a_int32_0, b0); + inter_vec[1] = _mm256_maddubs_epi16(a_int32_0, b1); + + // Perform column direction mat-mul with k = 2. + // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_2p0 = _mm256_add_epi16(inter_vec[0], c_int16_2p0); + c_int16_2p1 = _mm256_add_epi16(inter_vec[1], c_int16_2p1); + + // Broadcast a[4,kr:kr+2]. + a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 4) + (cs_a * offset))); + + // Seperate register for intermediate op + inter_vec[2] = _mm256_maddubs_epi16(a_int32_1, b0); + inter_vec[3] = _mm256_maddubs_epi16(a_int32_1, b1); + + // Perform column direction mat-mul with k = 2. + // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_3p0 = _mm256_add_epi16(inter_vec[2], c_int16_3p0); + c_int16_3p1 = _mm256_add_epi16(inter_vec[3], c_int16_3p1); + + // Broadcast a[5,kr:kr+2]. + a_int32_1 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 5) + (cs_a * offset))); + + // Seperate register for intermediate op + inter_vec[0] = _mm256_maddubs_epi16(a_int32_0, b0); + inter_vec[1] = _mm256_maddubs_epi16(a_int32_0, b1); + + // Perform column direction mat-mul with k = 2. + // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+4,0-31] + c_int16_4p0 = _mm256_add_epi16(inter_vec[0], c_int16_4p0); + c_int16_4p1 = _mm256_add_epi16(inter_vec[1], c_int16_4p1); + + // Seperate register for intermediate op + inter_vec[2] = _mm256_maddubs_epi16(a_int32_1, b0); + inter_vec[3] = _mm256_maddubs_epi16(a_int32_1, b1); + + // Perform column direction mat-mul with k = 2. + // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+4,0-31] + c_int16_5p0 = _mm256_add_epi16(inter_vec[2], c_int16_5p0); + c_int16_5p1 = _mm256_add_epi16(inter_vec[3], c_int16_5p1); + } + + // Handle k remainder. + if (k_partial_pieces > 0) + { + uint8_t a_element[6]; + + b0 = _mm256_loadu_si256((__m256i const *)(b + (64 * k_full_pieces) + (NR * 0))); + b1 = _mm256_loadu_si256((__m256i const *)(b + (64 * k_full_pieces) + (NR * 1))); + + a_element[0] = *(a + (rs_a * 0) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_element[0]); + + // Seperate register for intermediate op + inter_vec[0] = _mm256_maddubs_epi16(a_int32_0, b0); + inter_vec[1] = _mm256_maddubs_epi16(a_int32_0, b1); + + // Perform column direction mat-mul with k = 2. + // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_0p0 = _mm256_add_epi16(inter_vec[0], c_int16_0p0); + c_int16_0p1 = _mm256_add_epi16(inter_vec[1], c_int16_0p1); + + a_element[1] = *(a + (rs_a * 1) + (cs_a * (k_full_pieces * 2))); + a_int32_1 = _mm256_set1_epi8(a_element[1]); + + // Seperate register for intermediate op + inter_vec[2] = _mm256_maddubs_epi16(a_int32_1, b0); + inter_vec[3] = _mm256_maddubs_epi16(a_int32_1, b1); + + // Perform column direction mat-mul with k = 2. + // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+4,0-31] + c_int16_1p0 = _mm256_add_epi16(inter_vec[2], c_int16_1p0); + c_int16_1p1 = _mm256_add_epi16(inter_vec[3], c_int16_1p1); + + a_element[2] = *(a + (rs_a * 2) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_element[2]); + + // Seperate register for intermediate op + inter_vec[0] = _mm256_maddubs_epi16(a_int32_0, b0); + inter_vec[1] = _mm256_maddubs_epi16(a_int32_0, b1); + + // Perform column direction mat-mul with k = 2. 
+ // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_2p0 = _mm256_add_epi16(inter_vec[0], c_int16_2p0); + c_int16_2p1 = _mm256_add_epi16(inter_vec[1], c_int16_2p1); + + a_element[3] = *(a + (rs_a * 3) + (cs_a * (k_full_pieces * 2))); + a_int32_1 = _mm256_set1_epi8(a_element[3]); + + // Seperate register for intermediate op + inter_vec[2] = _mm256_maddubs_epi16(a_int32_1, b0); + inter_vec[3] = _mm256_maddubs_epi16(a_int32_1, b1); + + // Perform column direction mat-mul with k = 2. + // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_3p0 = _mm256_add_epi16(inter_vec[2], c_int16_3p0); + c_int16_3p1 = _mm256_add_epi16(inter_vec[3], c_int16_3p1); + + a_element[4] = *(a + (rs_a * 4) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_element[4]); + + // Seperate register for intermediate op + inter_vec[0] = _mm256_maddubs_epi16(a_int32_0, b0); + inter_vec[1] = _mm256_maddubs_epi16(a_int32_0, b1); + + // Perform column direction mat-mul with k = 2. + // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_4p0 = _mm256_add_epi16(inter_vec[0], c_int16_4p0); + c_int16_4p1 = _mm256_add_epi16(inter_vec[1], c_int16_4p1); + + a_element[5] = *(a + (rs_a * 5) + (cs_a * (k_full_pieces * 2))); + a_int32_1 = _mm256_set1_epi8(a_element[5]); + + // Seperate register for intermediate op + inter_vec[2] = _mm256_maddubs_epi16(a_int32_1, b0); + inter_vec[3] = _mm256_maddubs_epi16(a_int32_1, b1); + + // Perform column direction mat-mul with k = 2. + // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_5p0 = _mm256_add_epi16(inter_vec[2], c_int16_5p0); + c_int16_5p1 = _mm256_add_epi16(inter_vec[3], c_int16_5p1); + } + + // Load alpha and beta + __m256i selector1 = _mm256_set1_epi16(alpha); + __m256i selector2 = _mm256_set1_epi16(beta); + + // Scale by alpha + c_int16_0p0 = _mm256_mullo_epi16(selector1, c_int16_0p0); + c_int16_0p1 = _mm256_mullo_epi16(selector1, c_int16_0p1); + + c_int16_1p0 = _mm256_mullo_epi16(selector1, c_int16_1p0); + c_int16_1p1 = _mm256_mullo_epi16(selector1, c_int16_1p1); + + c_int16_2p0 = _mm256_mullo_epi16(selector1, c_int16_2p0); + c_int16_2p1 = _mm256_mullo_epi16(selector1, c_int16_2p1); + + c_int16_3p0 = _mm256_mullo_epi16(selector1, c_int16_3p0); + c_int16_3p1 = _mm256_mullo_epi16(selector1, c_int16_3p1); + + c_int16_4p0 = _mm256_mullo_epi16(selector1, c_int16_4p0); + c_int16_4p1 = _mm256_mullo_epi16(selector1, c_int16_4p1); + + c_int16_5p0 = _mm256_mullo_epi16(selector1, c_int16_5p0); + c_int16_5p1 = _mm256_mullo_epi16(selector1, c_int16_5p1); + + // Scale C by beta. 
+ if (beta != 0) + { + // c[0,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * (ir + 0)) + (0 * 16))); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_0p0 = _mm256_add_epi16(selector1, c_int16_0p0); + + // c[0, 16-31] + selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * (ir + 0)) + (1 * 16))); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_0p1 = _mm256_add_epi16(selector1, c_int16_0p1); + + // c[1,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * (ir + 1)) + (0 * 16))); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_1p0 = _mm256_add_epi16(selector1, c_int16_1p0); + + // c[1,16-31] + selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * (ir + 1)) + (1 * 16))); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_1p1 = _mm256_add_epi16(selector1, c_int16_1p1); + + // c[2,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * (ir + 2)) + (0 * 16))); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_2p0 = _mm256_add_epi16(selector1, c_int16_2p0); + + // c[2,16-31] + selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * (ir + 2)) + (1 * 16))); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_2p1 = _mm256_add_epi16(selector1, c_int16_2p1); + + // c[3,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * (ir + 3)) + (0 * 16))); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_3p0 = _mm256_add_epi16(selector1, c_int16_3p0); + + // c[3,16-31] + selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * (ir + 3)) + (1 * 16))); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_3p1 = _mm256_add_epi16(selector1, c_int16_3p1); + + // c[4,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * (ir + 4)) + (0 * 16))); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_4p0 = _mm256_add_epi16(selector1, c_int16_4p0); + + // c[4,16-31] + selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * (ir + 4)) + (1 * 16))); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_4p1 = _mm256_add_epi16(selector1, c_int16_4p1); + + // c[5,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * (ir + 5)) + (0 * 16))); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_5p0 = _mm256_add_epi16(selector1, c_int16_5p0); + + // c[5,16-31] + selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * (ir + 5)) + (1 * 16))); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_5p1 = _mm256_add_epi16(selector1, c_int16_5p1); + } + + // Store the results. 
+ // c[0,0-15] + _mm256_storeu_si256((__m256i *)(c + (rs_c * (ir + 0)) + (0 * 16)), c_int16_0p0); + + // c[0, 16-31] + _mm256_storeu_si256((__m256i *)(c + (rs_c * (ir + 0)) + (1 * 16)), c_int16_0p1); + + // c[1,0-15] + _mm256_storeu_si256((__m256i *)(c + (rs_c * (ir + 1)) + (0 * 16)), c_int16_1p0); + + // c[1,16-31] + _mm256_storeu_si256((__m256i *)(c + (rs_c * (ir + 1)) + (1 * 16)), c_int16_1p1); + + // c[2,0-15] + _mm256_storeu_si256((__m256i *)(c + (rs_c * (ir + 2)) + (0 * 16)), c_int16_2p0); + + // c[2,16-31] + _mm256_storeu_si256((__m256i *)(c + (rs_c * (ir + 2)) + (1 * 16)), c_int16_2p1); + + // c[3,0-15] + _mm256_storeu_si256((__m256i *)(c + (rs_c * (ir + 3)) + (0 * 16)), c_int16_3p0); + + // c[3,16-31] + _mm256_storeu_si256((__m256i *)(c + (rs_c * (ir + 3)) + (1 * 16)), c_int16_3p1); + + // c[4,0-15] + _mm256_storeu_si256((__m256i *)(c + (rs_c * (ir + 4)) + (0 * 16)), c_int16_4p0); + + // c[4,16-31] + _mm256_storeu_si256((__m256i *)(c + (rs_c * (ir + 4)) + (1 * 16)), c_int16_4p1); + + // c[5,0-15] + _mm256_storeu_si256((__m256i *)(c + (rs_c * (ir + 5)) + (0 * 16)), c_int16_5p0); + + // c[5,16-31] + _mm256_storeu_si256((__m256i *)(c + (rs_c * (ir + 5)) + (1 * 16)), c_int16_5p1); + + a = a + (MR * ps_a); + } + + if (m_partial_pieces > 0) + { + // Split into multiple smaller fringe kernels, so as to maximize + // vectorization after packing. Any m0 < MR(6) can be expressed + // as a combination of numbers from the set {4, 2, 1}. + dim_t m_partial4 = m_partial_pieces / 4; + m_partial_pieces = m_partial_pieces % 4; + + dim_t m_partial2 = m_partial_pieces / 2; + dim_t m_partial = m_partial_pieces % 2; + + if (m_partial4 == 1) + { + lpgemm_rowvar_u8s8s16o16_4x32( + k0, + a, rs_a, cs_a, + b, rs_b, cs_b, + (c + (rs_c * m_full_pieces_loop_limit)), rs_c, + alpha, beta); + + // a pointer increment + a = a + (4 * ps_a); + m_full_pieces_loop_limit += 4; + } + + if (m_partial2 == 1) + { + lpgemm_rowvar_u8s8s16o16_2x32( + k0, + a, rs_a, cs_a, + b, rs_b, cs_b, + (c + (rs_c * m_full_pieces_loop_limit)), rs_c, + alpha, beta); + + // a pointer increment + a = a + (2 * ps_a); + m_full_pieces_loop_limit += 2; + } + + if (m_partial == 1) + { + lpgemm_rowvar_u8s8s16o16_1x32( + k0, + a, rs_a, cs_a, + b, rs_b, cs_b, + (c + (rs_c * m_full_pieces_loop_limit)), rs_c, + alpha, beta); + } + } +} \ No newline at end of file diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_m_fringe_amd256.c b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_m_fringe_amd256.c new file mode 100644 index 0000000000..115e3c2a4e --- /dev/null +++ b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_m_fringe_amd256.c @@ -0,0 +1,537 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. 
+ + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include + +#include "blis.h" +#include "lpgemm_m_fringe_s16.h" + +// 4x32 int8o16 kernel +void lpgemm_rowvar_u8s8s16o16_4x32 + ( + const dim_t k0, + const uint8_t *a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t *b, + const dim_t rs_b, + const dim_t cs_b, + int16_t *c, + const dim_t rs_c, + const int16_t alpha, + const int16_t beta + ) +{ + dim_t NR = 32; + + // The division is done by considering the vpmaddubsw instruction + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + // B matrix storage. + __m256i b0; + __m256i b1; + + // A matrix storage. + __m256i a_int32_0; + __m256i a_int32_1; + __m256i inter_vec[4]; + + // Registers to use for accumulating C. + __m256i c_int16_0p0 = _mm256_setzero_si256(); + __m256i c_int16_0p1 = _mm256_setzero_si256(); + + __m256i c_int16_1p0 = _mm256_setzero_si256(); + __m256i c_int16_1p1 = _mm256_setzero_si256(); + + __m256i c_int16_2p0 = _mm256_setzero_si256(); + __m256i c_int16_2p1 = _mm256_setzero_si256(); + + __m256i c_int16_3p0 = _mm256_setzero_si256(); + __m256i c_int16_3p1 = _mm256_setzero_si256(); + + for (dim_t kr = 0; kr < k_full_pieces; kr += 1) + { + dim_t offset = kr * 2; + + // Broadcast a[0,kr:kr+2]. + a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 0) + (cs_a * offset))); + + b0 = _mm256_loadu_si256((__m256i const *)(b + (64 * kr) + (NR * 0))); + b1 = _mm256_loadu_si256((__m256i const *)(b + (64 * kr) + (NR * 1))); + + // Broadcast a[1,kr:kr+2]. + a_int32_1 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 1) + (cs_a * offset))); + + // Seperate register for intermediate op + inter_vec[0] = _mm256_maddubs_epi16(a_int32_0, b0); + inter_vec[1] = _mm256_maddubs_epi16(a_int32_0, b1); + + // Perform column direction mat-mul with k = 2. + // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_0p0 = _mm256_add_epi16(inter_vec[0], c_int16_0p0); + c_int16_0p1 = _mm256_add_epi16(inter_vec[1], c_int16_0p1); + + // Broadcast a[2,kr:kr+2]. + a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 2) + (cs_a * offset))); + + // Seperate register for intermediate op + inter_vec[2] = _mm256_maddubs_epi16(a_int32_1, b0); + inter_vec[3] = _mm256_maddubs_epi16(a_int32_1, b1); + + // Perform column direction mat-mul with k = 2. + // c[1,0-31] = a[1,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_1p0 = _mm256_add_epi16(inter_vec[2], c_int16_1p0); + c_int16_1p1 = _mm256_add_epi16(inter_vec[3], c_int16_1p1); + + // Broadcast a[3,kr:kr+2]. + a_int32_1 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 3) + (cs_a * offset))); + + // Seperate register for intermediate op + inter_vec[0] = _mm256_maddubs_epi16(a_int32_0, b0); + inter_vec[1] = _mm256_maddubs_epi16(a_int32_0, b1); + + // Perform column direction mat-mul with k = 2. 
+ // c[2,0-31] = a[2,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_2p0 = _mm256_add_epi16(inter_vec[0], c_int16_2p0); + c_int16_2p1 = _mm256_add_epi16(inter_vec[1], c_int16_2p1); + + // Seperate register for intermediate op + inter_vec[2] = _mm256_maddubs_epi16(a_int32_1, b0); + inter_vec[3] = _mm256_maddubs_epi16(a_int32_1, b1); + + // Perform column direction mat-mul with k = 2. + // c[3,0-31] = a[3,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_3p0 = _mm256_add_epi16(inter_vec[2], c_int16_3p0); + c_int16_3p1 = _mm256_add_epi16(inter_vec[3], c_int16_3p1); + } + + // Handle k remainder. + if (k_partial_pieces > 0) + { + uint8_t a_element[4]; + + b0 = _mm256_loadu_si256((__m256i const *)(b + (64 * k_full_pieces) + (NR * 0))); + b1 = _mm256_loadu_si256((__m256i const *)(b + (64 * k_full_pieces) + (NR * 1))); + + a_element[0] = *(a + (rs_a * 0) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_element[0]); + + // Seperate register for intermediate op + inter_vec[0] = _mm256_maddubs_epi16(a_int32_0, b0); + inter_vec[1] = _mm256_maddubs_epi16(a_int32_0, b1); + + // Perform column direction mat-mul with k = 2. + // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_0p0 = _mm256_add_epi16(inter_vec[0], c_int16_0p0); + c_int16_0p1 = _mm256_add_epi16(inter_vec[1], c_int16_0p1); + + a_element[1] = *(a + (rs_a * 1) + (cs_a * (k_full_pieces * 2))); + a_int32_1 = _mm256_set1_epi8(a_element[1]); + + // Seperate register for intermediate op + inter_vec[2] = _mm256_maddubs_epi16(a_int32_1, b0); + inter_vec[3] = _mm256_maddubs_epi16(a_int32_1, b1); + + // Perform column direction mat-mul with k = 2. + // c[1,0-31] = a[1,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_1p0 = _mm256_add_epi16(inter_vec[2], c_int16_1p0); + c_int16_1p1 = _mm256_add_epi16(inter_vec[3], c_int16_1p1); + + a_element[2] = *(a + (rs_a * 2) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_element[2]); + + // Seperate register for intermediate op + inter_vec[0] = _mm256_maddubs_epi16(a_int32_0, b0); + inter_vec[1] = _mm256_maddubs_epi16(a_int32_0, b1); + + // Perform column direction mat-mul with k = 4. + // c[2,0-63] = a[2,kr:kr+4]*b[kr:kr+4,0-63] + c_int16_2p0 = _mm256_add_epi16(inter_vec[0], c_int16_2p0); + c_int16_2p1 = _mm256_add_epi16(inter_vec[1], c_int16_2p1); + + a_element[3] = *(a + (rs_a * 3) + (cs_a * (k_full_pieces * 2))); + a_int32_1 = _mm256_set1_epi8(a_element[3]); + + // Seperate register for intermediate op + inter_vec[2] = _mm256_maddubs_epi16(a_int32_1, b0); + inter_vec[3] = _mm256_maddubs_epi16(a_int32_1, b1); + + // Perform column direction mat-mul with k = 2. + // c[3,0-31] = a[3,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_3p0 = _mm256_add_epi16(inter_vec[2], c_int16_3p0); + c_int16_3p1 = _mm256_add_epi16(inter_vec[3], c_int16_3p1); + } + + // Load alpha and beta + __m256i selector1 = _mm256_set1_epi16(alpha); + __m256i selector2 = _mm256_set1_epi16(beta); + + // Scale by alpha + c_int16_0p0 = _mm256_mullo_epi16(selector1, c_int16_0p0); + c_int16_0p1 = _mm256_mullo_epi16(selector1, c_int16_0p1); + + c_int16_1p0 = _mm256_mullo_epi16(selector1, c_int16_1p0); + c_int16_1p1 = _mm256_mullo_epi16(selector1, c_int16_1p1); + + c_int16_2p0 = _mm256_mullo_epi16(selector1, c_int16_2p0); + c_int16_2p1 = _mm256_mullo_epi16(selector1, c_int16_2p1); + + c_int16_3p0 = _mm256_mullo_epi16(selector1, c_int16_3p0); + c_int16_3p1 = _mm256_mullo_epi16(selector1, c_int16_3p1); + + // Scale C by beta. 
+ if (beta != 0) + { + // c[0,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * 0) + (0 * 16))); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_0p0 = _mm256_add_epi16(selector1, c_int16_0p0); + + // c[0, 16-31] + selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * 0) + (1 * 16))); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_0p1 = _mm256_add_epi16(selector1, c_int16_0p1); + + // c[1,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * 1) + (0 * 16))); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_1p0 = _mm256_add_epi16(selector1, c_int16_1p0); + + // c[1,16-31] + selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * 1) + (1 * 16))); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_1p1 = _mm256_add_epi16(selector1, c_int16_1p1); + + // c[2,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * 2) + (0 * 16))); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_2p0 = _mm256_add_epi16(selector1, c_int16_2p0); + + // c[2,16-31] + selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * 2) + (1 * 16))); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_2p1 = _mm256_add_epi16(selector1, c_int16_2p1); + + // c[3,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * 3) + (0 * 16))); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_3p0 = _mm256_add_epi16(selector1, c_int16_3p0); + + // c[3,16-31] + selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * 3) + (1 * 16))); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_3p1 = _mm256_add_epi16(selector1, c_int16_3p1); + } + + // Store the results. + // c[0,0-15] + _mm256_storeu_si256((__m256i *)(c + (rs_c * 0) + (0 * 16)), c_int16_0p0); + + // c[0, 16-31] + _mm256_storeu_si256((__m256i *)(c + (rs_c * 0) + (1 * 16)), c_int16_0p1); + + // c[1,0-15] + _mm256_storeu_si256((__m256i *)(c + (rs_c * 1) + (0 * 16)), c_int16_1p0); + + // c[1,16-31] + _mm256_storeu_si256((__m256i *)(c + (rs_c * 1) + (1 * 16)), c_int16_1p1); + + // c[2,0-15] + _mm256_storeu_si256((__m256i *)(c + (rs_c * 2) + (0 * 16)), c_int16_2p0); + + // c[2,16-31] + _mm256_storeu_si256((__m256i *)(c + (rs_c * 2) + (1 * 16)), c_int16_2p1); + + // c[3,0-15] + _mm256_storeu_si256((__m256i *)(c + (rs_c * 3) + (0 * 16)), c_int16_3p0); + + // c[3,16-31] + _mm256_storeu_si256((__m256i *)(c + (rs_c * 3) + (1 * 16)), c_int16_3p1); +} + +// 2x32 int8o16 kernel +void lpgemm_rowvar_u8s8s16o16_2x32 + ( + const dim_t k0, + const uint8_t *a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t *b, + const dim_t rs_b, + const dim_t cs_b, + int16_t *c, + const dim_t rs_c, + const int16_t alpha, + const int16_t beta + ) +{ + dim_t NR = 32; + + // The division is done by considering the vpmaddubsw instruction + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + // B matrix storage. + __m256i b0; + __m256i b1; + + // A matrix storage. + __m256i a_int32_0; + __m256i a_int32_1; + __m256i inter_vec[4]; + + // Registers to use for accumulating C. + __m256i c_int16_0p0 = _mm256_setzero_si256(); + __m256i c_int16_0p1 = _mm256_setzero_si256(); + + __m256i c_int16_1p0 = _mm256_setzero_si256(); + __m256i c_int16_1p1 = _mm256_setzero_si256(); + + for (dim_t kr = 0; kr < k_full_pieces; kr += 1) + { + dim_t offset = kr * 2; + + // Broadcast a[0,kr:kr+2]. 
+ a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 0) + (cs_a * offset))); + + b0 = _mm256_loadu_si256((__m256i const *)(b + (64 * kr) + (NR * 0))); + b1 = _mm256_loadu_si256((__m256i const *)(b + (64 * kr) + (NR * 1))); + + // Broadcast a[1,kr:kr+2]. + a_int32_1 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 1) + (cs_a * offset))); + + // Seperate register for intermediate op + inter_vec[0] = _mm256_maddubs_epi16(a_int32_0, b0); + inter_vec[1] = _mm256_maddubs_epi16(a_int32_0, b1); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_0p0 = _mm256_add_epi16(inter_vec[0], c_int16_0p0); + c_int16_0p1 = _mm256_add_epi16(inter_vec[1], c_int16_0p1); + + // Seperate register for intermediate op + inter_vec[2] = _mm256_maddubs_epi16(a_int32_1, b0); + inter_vec[3] = _mm256_maddubs_epi16(a_int32_1, b1); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_1p0 = _mm256_add_epi16(inter_vec[2], c_int16_1p0); + c_int16_1p1 = _mm256_add_epi16(inter_vec[3], c_int16_1p1); + } + // Handle k remainder. + if (k_partial_pieces > 0) + { + uint8_t a_element[2]; + + b0 = _mm256_loadu_si256((__m256i const *)(b + (64 * k_full_pieces) + (NR * 0))); + b1 = _mm256_loadu_si256((__m256i const *)(b + (64 * k_full_pieces) + (NR * 1))); + + a_element[0] = *(a + (rs_a * 0) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_element[0]); + + // Seperate register for intermediate op + inter_vec[0] = _mm256_maddubs_epi16(a_int32_0, b0); + inter_vec[1] = _mm256_maddubs_epi16(a_int32_0, b1); + + // Perform column direction mat-mul with k = 2. + // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_0p0 = _mm256_add_epi16(inter_vec[0], c_int16_0p0); + c_int16_0p1 = _mm256_add_epi16(inter_vec[1], c_int16_0p1); + + a_element[1] = *(a + (rs_a * 1) + (cs_a * (k_full_pieces * 2))); + a_int32_1 = _mm256_set1_epi8(a_element[1]); + + // Seperate register for intermediate op + inter_vec[2] = _mm256_maddubs_epi16(a_int32_1, b0); + inter_vec[3] = _mm256_maddubs_epi16(a_int32_1, b1); + + // Perform column direction mat-mul with k = 2. + // c[1,0-31] = a[1,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_1p0 = _mm256_add_epi16(inter_vec[2], c_int16_1p0); + c_int16_1p1 = _mm256_add_epi16(inter_vec[3], c_int16_1p1); + } + + // Load alpha and beta + __m256i selector1 = _mm256_set1_epi16(alpha); + __m256i selector2 = _mm256_set1_epi16(beta); + + // Scale by alpha + c_int16_0p0 = _mm256_mullo_epi16(selector1, c_int16_0p0); + c_int16_0p1 = _mm256_mullo_epi16(selector1, c_int16_0p1); + + c_int16_1p0 = _mm256_mullo_epi16(selector1, c_int16_1p0); + c_int16_1p1 = _mm256_mullo_epi16(selector1, c_int16_1p1); + + // Scale C by beta. 
+ if (beta != 0) + { + // c[0,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * 0) + (0 * 16))); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_0p0 = _mm256_add_epi16(selector1, c_int16_0p0); + + // c[0, 16-31] + selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * 0) + (1 * 16))); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_0p1 = _mm256_add_epi16(selector1, c_int16_0p1); + + // c[1,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * 1) + (0 * 16))); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_1p0 = _mm256_add_epi16(selector1, c_int16_1p0); + + // c[1,16-31] + selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * 1) + (1 * 16))); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_1p1 = _mm256_add_epi16(selector1, c_int16_1p1); + } + + // Store the results. + // c[0,0-15] + _mm256_storeu_si256((__m256i *)(c + (rs_c * 0) + (0 * 16)), c_int16_0p0); + + // c[0, 16-31] + _mm256_storeu_si256((__m256i *)(c + (rs_c * 0) + (1 * 16)), c_int16_0p1); + + // c[1,0-15] + _mm256_storeu_si256((__m256i *)(c + (rs_c * 1) + (0 * 16)), c_int16_1p0); + + // c[1,16-31] + _mm256_storeu_si256((__m256i *)(c + (rs_c * 1) + (1 * 16)), c_int16_1p1); +} + +// 1x32 int8o16 kernel +void lpgemm_rowvar_u8s8s16o16_1x32 + ( + const dim_t k0, + const uint8_t *a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t *b, + const dim_t rs_b, + const dim_t cs_b, + int16_t *c, + const dim_t rs_c, + const int16_t alpha, + const int16_t beta + ) +{ + dim_t NR = 32; + + // The division is done by considering the vpmaddubsw instruction + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + // B matrix storage. + __m256i b0; + __m256i b1; + + // A matrix storage. + __m256i a_int32_0; + __m256i inter_vec[2]; + + // Registers to use for accumulating C. + __m256i c_int16_0p0 = _mm256_setzero_si256(); + __m256i c_int16_0p1 = _mm256_setzero_si256(); + + for (dim_t kr = 0; kr < k_full_pieces; kr += 1) + { + dim_t offset = kr * 2; + + // Broadcast a[0,kr:kr+2]. + a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 0) + (cs_a * offset))); + + b0 = _mm256_loadu_si256((__m256i const *)(b + (64 * kr) + (NR * 0))); + b1 = _mm256_loadu_si256((__m256i const *)(b + (64 * kr) + (NR * 1))); + + // Seperate register for intermediate op + inter_vec[0] = _mm256_maddubs_epi16(a_int32_0, b0); + inter_vec[1] = _mm256_maddubs_epi16(a_int32_0, b1); + + // Perform column direction mat-mul with k = 2. + // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_0p0 = _mm256_add_epi16(inter_vec[0], c_int16_0p0); + c_int16_0p1 = _mm256_add_epi16(inter_vec[1], c_int16_0p1); + } + // Handle k remainder. + if (k_partial_pieces > 0) + { + uint8_t a_element[1]; + + b0 = _mm256_loadu_si256((__m256i const *)(b + (64 * k_full_pieces) + (NR * 0))); + b1 = _mm256_loadu_si256((__m256i const *)(b + (64 * k_full_pieces) + (NR * 1))); + + a_element[0] = *(a + (rs_a * 0) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_element[0]); + + // Seperate register for intermediate op + inter_vec[0] = _mm256_maddubs_epi16(a_int32_0, b0); + inter_vec[1] = _mm256_maddubs_epi16(a_int32_0, b1); + + // Perform column direction mat-mul with k = 2. 
+ // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_0p0 = _mm256_add_epi16(inter_vec[0], c_int16_0p0); + c_int16_0p1 = _mm256_add_epi16(inter_vec[1], c_int16_0p1); + } + + // Load alpha and beta + __m256i selector1 = _mm256_set1_epi16(alpha); + __m256i selector2 = _mm256_set1_epi16(beta); + + // Scale by alpha + c_int16_0p0 = _mm256_mullo_epi16(selector1, c_int16_0p0); + c_int16_0p1 = _mm256_mullo_epi16(selector1, c_int16_0p1); + + // Scale C by beta. + if (beta != 0) + { + // c[0,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * 0) + (0 * 16))); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_0p0 = _mm256_add_epi16(selector1, c_int16_0p0); + + // c[0, 16-31] + selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * 0) + (1 * 16))); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_0p1 = _mm256_add_epi16(selector1, c_int16_0p1); + } + + // Store the results. + // c[0,0-15] + _mm256_storeu_si256((__m256i *)(c + (rs_c * 0) + (0 * 16)), c_int16_0p0); + + // c[0, 16-31] + _mm256_storeu_si256((__m256i *)(c + (rs_c * 0) + (1 * 16)), c_int16_0p1); +} diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_m_fringe_s16.h b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_m_fringe_s16.h new file mode 100644 index 0000000000..f69ea2630e --- /dev/null +++ b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_m_fringe_s16.h @@ -0,0 +1,80 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+
+*/
+
+#ifndef BLIS_GEMM_INT16_MFRINGE
+#define BLIS_GEMM_INT16_MFRINGE
+
+// 4x32 int8o16 kernel
+void lpgemm_rowvar_u8s8s16o16_4x32(
+ const dim_t k0,
+ const uint8_t *a,
+ const dim_t rs_a,
+ const dim_t cs_a,
+ const int8_t *b,
+ const dim_t rs_b,
+ const dim_t cs_b,
+ int16_t *c,
+ const dim_t rs_c,
+ const int16_t alpha,
+ const int16_t beta);
+
+// 2x32 int8o16 kernel
+void lpgemm_rowvar_u8s8s16o16_2x32(
+ const dim_t k0,
+ const uint8_t *a,
+ const dim_t rs_a,
+ const dim_t cs_a,
+ const int8_t *b,
+ const dim_t rs_b,
+ const dim_t cs_b,
+ int16_t *c,
+ const dim_t rs_c,
+ const int16_t alpha,
+ const int16_t beta);
+
+// 1x32 int8o16 kernel
+void lpgemm_rowvar_u8s8s16o16_1x32(
+ const dim_t k0,
+ const uint8_t *a,
+ const dim_t rs_a,
+ const dim_t cs_a,
+ const int8_t *b,
+ const dim_t rs_b,
+ const dim_t cs_b,
+ int16_t *c,
+ const dim_t rs_c,
+ const int16_t alpha,
+ const int16_t beta);
+
+#endif // BLIS_GEMM_INT16_MFRINGE
\ No newline at end of file
diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_mn_fringe_amd256.c b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_mn_fringe_amd256.c
new file mode 100644
index 0000000000..dba8792cc1
--- /dev/null
+++ b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_mn_fringe_amd256.c
@@ -0,0 +1,854 @@
+/*
+
+ BLIS
+ An object-based framework for developing high-performance BLAS-like
+ libraries.
+
+ Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.
+
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions are
+ met:
+ - Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+ - Neither the name(s) of the copyright holder(s) nor the names of its
+ contributors may be used to endorse or promote products derived
+ from this software without specific prior written permission.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+*/
+
+#include <immintrin.h>
+
+#include "blis.h"
+#include "lpgemm_mn_fringe_s16.h"
+
+// 4x16 int8o16 kernel
+void lpgemm_rowvar_u8s8s16o16_4x16
+ (
+ const dim_t k0,
+ const uint8_t *a,
+ const dim_t rs_a,
+ const dim_t cs_a,
+ const int8_t *b,
+ const dim_t rs_b,
+ const dim_t cs_b,
+ int16_t *c,
+ const dim_t rs_c,
+ const int16_t alpha,
+ const int16_t beta
+ )
+{
+ dim_t NR = 16;
+
+ // The division is done by considering the vpmaddubsw instruction
+ dim_t k_full_pieces = k0 / 2;
+ dim_t k_partial_pieces = k0 % 2;
+
+ // B matrix storage.
+ __m256i b0;
+
+ // A matrix storage.
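+ // Register roles in the loops below (inferred from the intrinsics):
+ // a_int32_0 carries a broadcast A k-pair (the name is kept from the
+ // 32-bit variants), b0 carries 16 packed k-pairs of B, and inter_vec
+ // holds the vpmaddubsw product before it is accumulated into the
+ // c_int16_*p0 registers.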
+ __m256i a_int32_0; + __m256i inter_vec; + + // Registers to use for accumulating C. + __m256i c_int16_0p0 = _mm256_setzero_si256(); + __m256i c_int16_1p0 = _mm256_setzero_si256(); + __m256i c_int16_2p0 = _mm256_setzero_si256(); + __m256i c_int16_3p0 = _mm256_setzero_si256(); + + for (dim_t kr = 0; kr < k_full_pieces; kr += 1) + { + dim_t offset = kr * 2; + + b0 = _mm256_loadu_si256((__m256i const *)(b + (32 * kr) + (NR * 0))); + + // Broadcast a[0,kr:kr+2]. + a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 0) + (cs_a * offset))); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_0p0 = _mm256_add_epi16(inter_vec, c_int16_0p0); + + // Broadcast a[1,kr:kr+2]. + a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 1) + (cs_a * offset))); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_1p0 = _mm256_add_epi16(inter_vec, c_int16_1p0); + + // Broadcast a[2,kr:kr+2]. + a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 2) + (cs_a * offset))); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_2p0 = _mm256_add_epi16(inter_vec, c_int16_2p0); + + // Broadcast a[3,kr:kr+2]. + a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 3) + (cs_a * offset))); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[3,0-31] = a[3,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_3p0 = _mm256_add_epi16(inter_vec, c_int16_3p0); + } + + // Handle k remainder. + if (k_partial_pieces > 0) + { + uint8_t a_element[4]; + + b0 = _mm256_loadu_si256((__m256i const *)(b + (32 * k_full_pieces) + (NR * 0))); + + a_element[0] = *(a + (rs_a * 0) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_element[0]); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_0p0 = _mm256_add_epi16(inter_vec, c_int16_0p0); + + a_element[1] = *(a + (rs_a * 1) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_element[1]); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_1p0 = _mm256_add_epi16(inter_vec, c_int16_1p0); + + a_element[2] = *(a + (rs_a * 2) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_element[2]); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_2p0 = _mm256_add_epi16(inter_vec, c_int16_2p0); + + a_element[3] = *(a + (rs_a * 3) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_element[3]); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. 
+ // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_3p0 = _mm256_add_epi16(inter_vec, c_int16_3p0); + } + + // Load alpha and beta + __m256i selector1 = _mm256_set1_epi16(alpha); + __m256i selector2 = _mm256_set1_epi16(beta); + + // Scale by alpha + c_int16_0p0 = _mm256_mullo_epi16(selector1, c_int16_0p0); + + c_int16_1p0 = _mm256_mullo_epi16(selector1, c_int16_1p0); + + c_int16_2p0 = _mm256_mullo_epi16(selector1, c_int16_2p0); + + c_int16_3p0 = _mm256_mullo_epi16(selector1, c_int16_3p0); + + // Scale C by beta. + if (beta != 0) + { + // c[0,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * 0) + (0 * 16))); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_0p0 = _mm256_add_epi16(selector1, c_int16_0p0); + + // c[1,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * 1) + (0 * 16))); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_1p0 = _mm256_add_epi16(selector1, c_int16_1p0); + + // c[2,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * 2) + (0 * 16))); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_2p0 = _mm256_add_epi16(selector1, c_int16_2p0); + + // c[3,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * 3) + (0 * 16))); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_3p0 = _mm256_add_epi16(selector1, c_int16_3p0); + } + + // Store the results. + // c[0,0-15] + _mm256_storeu_si256((__m256i *)(c + (rs_c * 0) + (0 * 16)), c_int16_0p0); + + // c[1,0-15] + _mm256_storeu_si256((__m256i *)(c + (rs_c * 1) + (0 * 16)), c_int16_1p0); + + // c[2,0-15] + _mm256_storeu_si256((__m256i *)(c + (rs_c * 2) + (0 * 16)), c_int16_2p0); + + // c[3,0-15] + _mm256_storeu_si256((__m256i *)(c + (rs_c * 3) + (0 * 16)), c_int16_3p0); +} + +// 4x16 int8o16 kernel +void lpgemm_rowvar_u8s8s16o16_4xlt16 + ( + const dim_t k0, + const uint8_t *a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t *b, + const dim_t rs_b, + const dim_t cs_b, + int16_t *c, + const dim_t rs_c, + const int16_t alpha, + const int16_t beta, + dim_t n0_rem + ) +{ + dim_t NR = 16; + + // The division is done by considering the vpmaddubsw instruction + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + // B matrix storage. + __m256i b0; + + // A matrix storage. + __m256i a_int32_0; + __m256i inter_vec; + + int16_t buf0[16]; + int16_t buf1[16]; + int16_t buf2[16]; + int16_t buf3[16]; + + // Registers to use for accumulating C. + __m256i c_int16_0p0 = _mm256_setzero_si256(); + + __m256i c_int16_1p0 = _mm256_setzero_si256(); + + __m256i c_int16_2p0 = _mm256_setzero_si256(); + + __m256i c_int16_3p0 = _mm256_setzero_si256(); + + for (dim_t kr = 0; kr < k_full_pieces; kr += 1) + { + dim_t offset = kr * 2; + + b0 = _mm256_loadu_si256((__m256i const *)(b + (32 * kr) + (NR * 0))); + + // Broadcast a[0,kr:kr+2]. + a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 0) + (cs_a * offset))); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_0p0 = _mm256_add_epi16(inter_vec, c_int16_0p0); + + // Broadcast a[1,kr:kr+2]. + a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 1) + (cs_a * offset))); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. 
+ // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_1p0 = _mm256_add_epi16(inter_vec, c_int16_1p0); + + // Broadcast a[2,kr:kr+2]. + a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 2) + (cs_a * offset))); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_2p0 = _mm256_add_epi16(inter_vec, c_int16_2p0); + + // Broadcast a[3,kr:kr+2]. + a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 3) + (cs_a * offset))); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[3,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_3p0 = _mm256_add_epi16(inter_vec, c_int16_3p0); + } + + // Handle k remainder. + if (k_partial_pieces > 0) + { + uint8_t a_element[4]; + + b0 = _mm256_loadu_si256((__m256i const *)(b + (32 * k_full_pieces) + (NR * 0))); + + a_element[0] = *(a + (rs_a * 0) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_element[0]); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_0p0 = _mm256_add_epi16(inter_vec, c_int16_0p0); + + a_element[1] = *(a + (rs_a * 1) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_element[1]); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_1p0 = _mm256_add_epi16(inter_vec, c_int16_1p0); + + a_element[2] = *(a + (rs_a * 2) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_element[2]); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[2,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_2p0 = _mm256_add_epi16(inter_vec, c_int16_2p0); + + a_element[3] = *(a + (rs_a * 3) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_element[3]); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_3p0 = _mm256_add_epi16(inter_vec, c_int16_3p0); + } + + // Load alpha and beta + __m256i selector1 = _mm256_set1_epi16(alpha); + __m256i selector2 = _mm256_set1_epi16(beta); + + // Scale by alpha + c_int16_0p0 = _mm256_mullo_epi16(selector1, c_int16_0p0); + + c_int16_1p0 = _mm256_mullo_epi16(selector1, c_int16_1p0); + + c_int16_2p0 = _mm256_mullo_epi16(selector1, c_int16_2p0); + + c_int16_3p0 = _mm256_mullo_epi16(selector1, c_int16_3p0); + + // Scale C by beta. 
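+ // n0_rem (< 16) columns are handled without masked accesses: the live
+ // part of each C row is staged through a 16-element stack buffer, the
+ // vector math runs at full width on the buffer, and only n0_rem int16
+ // values are copied back, so C itself is never read or written out of
+ // bounds. The pattern used below for each row:
+ //
+ //   memcpy(buf, c_row, n0_rem * sizeof(int16_t));   // load live part
+ //   /* full-width beta scale + add on buf */
+ //   memcpy(c_row, buf, n0_rem * sizeof(int16_t));   // store live part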
+ if (beta != 0) + { + memcpy(buf0, (c + (rs_c * 0)), (n0_rem * sizeof(int16_t))); + memcpy(buf1, (c + (rs_c * 1)), (n0_rem * sizeof(int16_t))); + memcpy(buf2, (c + (rs_c * 2)), (n0_rem * sizeof(int16_t))); + memcpy(buf3, (c + (rs_c * 3)), (n0_rem * sizeof(int16_t))); + + // c[0,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)buf0); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_0p0 = _mm256_add_epi16(selector1, c_int16_0p0); + + // c[1,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)buf1); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_1p0 = _mm256_add_epi16(selector1, c_int16_1p0); + + // c[2,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)buf2); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_2p0 = _mm256_add_epi16(selector1, c_int16_2p0); + + // c[3,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)buf3); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_3p0 = _mm256_add_epi16(selector1, c_int16_3p0); + } + + // c[0,0-15] + _mm256_storeu_si256((__m256i_u *)buf0, c_int16_0p0); + + // c[1,0-15] + _mm256_storeu_si256((__m256i_u *)buf1, c_int16_1p0); + + // c[2,0-15] + _mm256_storeu_si256((__m256i_u *)buf2, c_int16_2p0); + + // c[3,0-15] + _mm256_storeu_si256((__m256i_u *)buf3, c_int16_3p0); + + memcpy(c + (rs_c * 0) + (0 * 16), buf0, (n0_rem * sizeof(int16_t))); + + // c[1,0-15] + memcpy(c + (rs_c * +1) + (0 * 16), buf1, (n0_rem * sizeof(int16_t))); + + // c[2,0-15] + memcpy(c + (rs_c * +2) + (0 * 16), buf2, (n0_rem * sizeof(int16_t))); + + // c[3,0-15] + memcpy(c + (rs_c * +3) + (0 * 16), buf3, (n0_rem * sizeof(int16_t))); +} + +// 2x16 int8o16 kernel +void lpgemm_rowvar_u8s8s16o16_2x16 + ( + const dim_t k0, + const uint8_t *a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t *b, + const dim_t rs_b, + const dim_t cs_b, + int16_t *c, + const dim_t rs_c, + const int16_t alpha, + const int16_t beta + ) +{ + dim_t NR = 16; + + // The division is done by considering the vpmaddubsw instruction + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + // B matrix storage. + __m256i b0; + + // A matrix storage. + __m256i a_int32_0; + __m256i inter_vec; + + // Registers to use for accumulating C. + __m256i c_int16_0p0 = _mm256_setzero_si256(); + + __m256i c_int16_1p0 = _mm256_setzero_si256(); + + for (dim_t kr = 0; kr < k_full_pieces; kr += 1) + { + dim_t offset = kr * 2; + + b0 = _mm256_loadu_si256((__m256i const *)(b + (32 * kr) + (NR * 0))); + + // Broadcast a[0,kr:kr+2]. + a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 0) + (cs_a * offset))); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_0p0 = _mm256_add_epi16(inter_vec, c_int16_0p0); + + // Broadcast a[1,kr:kr+2]. + a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 1) + (cs_a * offset))); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+4]*b[kr:kr+2,0-31] + c_int16_1p0 = _mm256_add_epi16(inter_vec, c_int16_1p0); + } + // Handle k remainder. 
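+ // Odd k0: the final A value is broadcast with set1_epi8, so both bytes
+ // of every vpmaddubsw pair carry the same A value. This gives the right
+ // result provided the packed B zero-pads the missing k element (an
+ // assumption about the pack routine), since a*b + a*0 == a*b.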
+ if (k_partial_pieces > 0) + { + uint8_t a_element[2]; + + b0 = _mm256_loadu_si256((__m256i const *)(b + (32 * k_full_pieces) + (NR * 0))); + + a_element[0] = *(a + (rs_a * 0) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_element[0]); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_0p0 = _mm256_add_epi16(inter_vec, c_int16_0p0); + + a_element[1] = *(a + (rs_a * 1) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_element[1]); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_1p0 = _mm256_add_epi16(inter_vec, c_int16_1p0); + } + + // Load alpha and beta + __m256i selector1 = _mm256_set1_epi16(alpha); + __m256i selector2 = _mm256_set1_epi16(beta); + + // Scale by alpha + c_int16_0p0 = _mm256_mullo_epi16(selector1, c_int16_0p0); + + c_int16_1p0 = _mm256_mullo_epi16(selector1, c_int16_1p0); + + // Scale C by beta. + if (beta != 0) + { + // c[0,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * 0) + (0 * 16))); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_0p0 = _mm256_add_epi16(selector1, c_int16_0p0); + + // c[1,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * 1) + (0 * 16))); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_1p0 = _mm256_add_epi16(selector1, c_int16_1p0); + } + + // Store the results. + // c[0,0-15] + _mm256_storeu_si256((__m256i *)(c + (rs_c * 0) + (0 * 16)), c_int16_0p0); + + // c[1,0-15] + _mm256_storeu_si256((__m256i *)(c + (rs_c * 1) + (0 * 16)), c_int16_1p0); +} + +// 2xlt16 int8o16 kernel +void lpgemm_rowvar_u8s8s16o16_2xlt16 + ( + const dim_t k0, + const uint8_t *a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t *b, + const dim_t rs_b, + const dim_t cs_b, + int16_t *c, + const dim_t rs_c, + const int16_t alpha, + const int16_t beta, + dim_t n0_rem + ) +{ + dim_t NR = 16; + + // The division is done by considering the vpmaddubsw instruction + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + // B matrix storage. + __m256i b0; + + // A matrix storage. + __m256i a_int32_0; + __m256i inter_vec; + + int16_t buf0[16]; + int16_t buf1[16]; + + // Registers to use for accumulating C. + __m256i c_int16_0p0 = _mm256_setzero_si256(); + + __m256i c_int16_1p0 = _mm256_setzero_si256(); + + for (dim_t kr = 0; kr < k_full_pieces; kr += 1) + { + dim_t offset = kr * 2; + + b0 = _mm256_loadu_si256((__m256i const *)(b + (32 * kr) + (NR * 0))); + + // Broadcast a[0,kr:kr+2]. + a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 0) + (cs_a * offset))); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_0p0 = _mm256_add_epi16(inter_vec, c_int16_0p0); + + // Broadcast a[1,kr:kr+2]. + a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 1) + (cs_a * offset))); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 4. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_1p0 = _mm256_add_epi16(inter_vec, c_int16_1p0); + } + // Handle k remainder. 
+ if (k_partial_pieces > 0) + { + uint8_t a_element[4]; + + b0 = _mm256_loadu_si256((__m256i const *)(b + (32 * k_full_pieces) + (NR * 0))); + + a_element[0] = *(a + (rs_a * 0) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_element[0]); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_0p0 = _mm256_add_epi16(inter_vec, c_int16_0p0); + + a_element[1] = *(a + (rs_a * 1) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_element[1]); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_1p0 = _mm256_add_epi16(inter_vec, c_int16_1p0); + } + + // Load alpha and beta + __m256i selector1 = _mm256_set1_epi16(alpha); + __m256i selector2 = _mm256_set1_epi16(beta); + + // Scale by alpha + c_int16_0p0 = _mm256_mullo_epi16(selector1, c_int16_0p0); + + c_int16_1p0 = _mm256_mullo_epi16(selector1, c_int16_1p0); + + // Scale C by beta. + if (beta != 0) + { + memcpy(buf0, (c + (rs_c * 0)), (n0_rem * sizeof(int16_t))); + memcpy(buf1, (c + (rs_c * 1)), (n0_rem * sizeof(int16_t))); + + // c[0,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)buf0); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_0p0 = _mm256_add_epi16(selector1, c_int16_0p0); + + // c[1,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)buf1); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_1p0 = _mm256_add_epi16(selector1, c_int16_1p0); + } + + // c[0,0-15] + _mm256_storeu_si256((__m256i_u *)buf0, c_int16_0p0); + + // c[1,0-15] + _mm256_storeu_si256((__m256i_u *)buf1, c_int16_1p0); + + // c[0,0-15] + memcpy(c + (rs_c * 0) + (0 * 16), buf0, (n0_rem * sizeof(int16_t))); + + // c[1,0-15] + memcpy(c + (rs_c * +1) + (0 * 16), buf1, (n0_rem * sizeof(int16_t))); +} + +// 1x16 int8o16 kernel +void lpgemm_rowvar_u8s8s16o16_1x16 + ( + const dim_t k0, + const uint8_t *a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t *b, + const dim_t rs_b, + const dim_t cs_b, + int16_t *c, + const dim_t rs_c, + const int16_t alpha, + const int16_t beta + ) +{ + int NR = 16; + + // The division is done by considering the vpmaddubsw instruction + int k_full_pieces = k0 / 2; + int k_partial_pieces = k0 % 2; + + // B matrix storage. + __m256i b0; + + // A matrix storage. + __m256i a_int32_0; + __m256i inter_vec; + + // Registers to use for accumulating C. + __m256i c_int16_0p0 = _mm256_setzero_si256(); + + for (int kr = 0; kr < k_full_pieces; kr += 1) + { + int offset = kr * 2; + + // Broadcast a[0,kr:kr+2]. + a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 0) + (cs_a * offset))); + + b0 = _mm256_loadu_si256((__m256i const *)(b + (32 * kr) + (NR * 0))); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_0p0 = _mm256_add_epi16(inter_vec, c_int16_0p0); + } + // Handle k remainder. 
+ if (k_partial_pieces > 0)
+ {
+ uint8_t a_element[1];
+
+ b0 = _mm256_loadu_si256((__m256i const *)(b + (32 * k_full_pieces) + (NR * 0)));
+
+ a_element[0] = *(a + (rs_a * 0) + (cs_a * (k_full_pieces * 2)));
+ a_int32_0 = _mm256_set1_epi8(a_element[0]);
+
+ // Seperate register for intermediate op
+ inter_vec = _mm256_maddubs_epi16(a_int32_0, b0);
+
+ // Perform column direction mat-mul with k = 2.
+ // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-31]
+ c_int16_0p0 = _mm256_add_epi16(inter_vec, c_int16_0p0);
+ }
+
+ // Load alpha and beta
+ __m256i selector1 = _mm256_set1_epi16(alpha);
+ __m256i selector2 = _mm256_set1_epi16(beta);
+
+ // Scale by alpha
+ c_int16_0p0 = _mm256_mullo_epi16(selector1, c_int16_0p0);
+
+ // Scale C by beta.
+ if (beta != 0)
+ {
+ // c[0,0-15]
+ selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * 0) + (0 * 16)));
+ selector1 = _mm256_mullo_epi16(selector2, selector1);
+ c_int16_0p0 = _mm256_add_epi16(selector1, c_int16_0p0);
+ }
+
+ // Store the results.
+ // c[0,0-15]
+ _mm256_storeu_si256((__m256i *)(c + (rs_c * 0) + (0 * 16)), c_int16_0p0);
+}
+
+// 1xlt16 int8o16 kernel
+void lpgemm_rowvar_u8s8s16o16_1xlt16
+ (
+ const int k0,
+ const uint8_t *a,
+ const int rs_a,
+ const int cs_a,
+ const int8_t *b,
+ const int rs_b,
+ const int cs_b,
+ int16_t *c,
+ const int rs_c,
+ const int16_t alpha,
+ const int16_t beta,
+ dim_t n0_rem
+ )
+{
+ int NR = 16;
+
+ // The division is done by considering the vpmaddubsw instruction
+ int k_full_pieces = k0 / 2;
+ int k_partial_pieces = k0 % 2;
+
+ // B matrix storage.
+ __m256i b0;
+
+ // A matrix storage.
+ __m256i a_int32_0;
+ __m256i inter_vec;
+
+ int16_t buf0[16];
+
+ // Registers to use for accumulating C.
+ __m256i c_int16_0p0 = _mm256_setzero_si256();
+
+ for (int kr = 0; kr < k_full_pieces; kr += 1)
+ {
+ int offset = kr * 2;
+
+ b0 = _mm256_loadu_si256((__m256i const *)(b + (32 * kr) + (NR * 0)));
+
+ // Broadcast a[0,kr:kr+2].
+ a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 0) + (cs_a * offset)));
+
+ // Seperate register for intermediate op
+ inter_vec = _mm256_maddubs_epi16(a_int32_0, b0);
+
+ // Perform column direction mat-mul with k = 2.
+ // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-31]
+ c_int16_0p0 = _mm256_add_epi16(inter_vec, c_int16_0p0);
+ }
+ // Handle k remainder.
+ if (k_partial_pieces > 0)
+ {
+ uint8_t a_element[1];
+
+ b0 = _mm256_loadu_si256((__m256i const *)(b + (32 * k_full_pieces) + (NR * 0)));
+
+ a_element[0] = *(a + (rs_a * 0) + (cs_a * (k_full_pieces * 2)));
+ a_int32_0 = _mm256_set1_epi8(a_element[0]);
+
+ // Seperate register for intermediate op
+ inter_vec = _mm256_maddubs_epi16(a_int32_0, b0);
+
+ // Perform column direction mat-mul with k = 2.
+ // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-31]
+ c_int16_0p0 = _mm256_add_epi16(inter_vec, c_int16_0p0);
+ }
+
+ // Load alpha and beta
+ __m256i selector1 = _mm256_set1_epi16(alpha);
+ __m256i selector2 = _mm256_set1_epi16(beta);
+
+ // Scale by alpha
+ c_int16_0p0 = _mm256_mullo_epi16(selector1, c_int16_0p0);
+
+ // Scale C by beta.
+ if (beta != 0) + { + memcpy(buf0, (c + (rs_c * 0)), (n0_rem * sizeof(int16_t))); + + // c[0,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)buf0); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_0p0 = _mm256_add_epi16(selector1, c_int16_0p0); + } + + // c[0,0-15] + _mm256_storeu_si256((__m256i_u *)buf0, c_int16_0p0); + + memcpy(c + (rs_c * 0) + (0 * 16), buf0, (n0_rem * sizeof(int16_t))); +} diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_mn_fringe_s16.h b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_mn_fringe_s16.h new file mode 100644 index 0000000000..6b34fae976 --- /dev/null +++ b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_mn_fringe_s16.h @@ -0,0 +1,119 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +*/ + +#ifndef BLIS_GEMM_INT16_MNFRINGE +#define BLIS_GEMM_INT16_MNFRINGE + +void lpgemm_rowvar_u8s8s16o16_4x16( + const dim_t k0, + const uint8_t *a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t *b, + const dim_t rs_b, + const dim_t cs_b, + int16_t *c, + const dim_t rs_c, + const int16_t alpha, + const int16_t beta); + +void lpgemm_rowvar_u8s8s16o16_4xlt16( + const dim_t k0, + const uint8_t *a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t *b, + const dim_t rs_b, + const dim_t cs_b, + int16_t *c, + const dim_t rs_c, + const int16_t alpha, + const int16_t beta, + dim_t n0_rem); + +void lpgemm_rowvar_u8s8s16o16_2x16( + const dim_t k0, + const uint8_t *a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t *b, + const dim_t rs_b, + const dim_t cs_b, + int16_t *c, + const dim_t rs_c, + const int16_t alpha, + const int16_t beta); + +void lpgemm_rowvar_u8s8s16o16_2xlt16( + const dim_t k0, + const uint8_t *a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t *b, + const dim_t rs_b, + const dim_t cs_b, + int16_t *c, + const dim_t rs_c, + const int16_t alpha, + const int16_t beta, + dim_t n0_rem); + +void lpgemm_rowvar_u8s8s16o16_1x16( + const dim_t k0, + const uint8_t *a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t *b, + const dim_t rs_b, + const dim_t cs_b, + int16_t *c, + const dim_t rs_c, + const int16_t alpha, + const int16_t beta); + +void lpgemm_rowvar_u8s8s16o16_1xlt16( + const int k0, + const uint8_t *a, + const int rs_a, + const int cs_a, + const int8_t *b, + const int rs_b, + const int cs_b, + int16_t *c, + const int rs_c, + const int16_t alpha, + const int16_t beta, + dim_t n0_rem); + +#endif // BLIS_GEMM_INT16_MNFRINGE \ No newline at end of file diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_n_fringe_amd256.c b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_n_fringe_amd256.c new file mode 100644 index 0000000000..d5cec8f0fc --- /dev/null +++ b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_n_fringe_amd256.c @@ -0,0 +1,683 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT
+ HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+*/
+
+#include <immintrin.h>
+
+#include "blis.h"
+#include "lpgemm_n_fringe_s16.h"
+#include "lpgemm_mn_fringe_s16.h"
+
+// 6x16 int8o16 kernel
+void lpgemm_rowvar_u8s8s16o16_6x16
+ (
+ const dim_t m0,
+ const dim_t k0,
+ const uint8_t *a,
+ const dim_t rs_a,
+ const dim_t cs_a,
+ const dim_t ps_a,
+ const int8_t *b,
+ const dim_t rs_b,
+ const dim_t cs_b,
+ int16_t *c,
+ const dim_t rs_c,
+ const int16_t alpha,
+ const int16_t beta
+ )
+{
+ dim_t MR = 6;
+ dim_t NR = 16;
+
+ dim_t m_full_pieces = m0 / MR;
+ dim_t m_full_pieces_loop_limit = m_full_pieces * MR;
+ dim_t m_partial_pieces = m0 % MR;
+
+ // The division is done by considering the vpmaddubsw instruction
+ dim_t k_full_pieces = k0 / 2;
+ dim_t k_partial_pieces = k0 % 2;
+
+ // B matrix storage.
+ __m256i b0;
+
+ // A matrix storage.
+ __m256i a_int32_0;
+ __m256i inter_vec;
+
+ for (dim_t ir = 0; ir < m_full_pieces_loop_limit; ir += MR)
+ {
+ // Registers to use for accumulating C.
+ __m256i c_int16_0p0 = _mm256_setzero_si256();
+
+ __m256i c_int16_1p0 = _mm256_setzero_si256();
+
+ __m256i c_int16_2p0 = _mm256_setzero_si256();
+
+ __m256i c_int16_3p0 = _mm256_setzero_si256();
+
+ __m256i c_int16_4p0 = _mm256_setzero_si256();
+
+ __m256i c_int16_5p0 = _mm256_setzero_si256();
+
+ for (dim_t kr = 0; kr < k_full_pieces; kr += 1)
+ {
+ int offset = kr * 2;
+
+ b0 = _mm256_loadu_si256((__m256i const *)(b + (32 * kr) + (NR * 0)));
+
+ // Broadcast a[0,kr:kr+2].
+ a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 0) + (cs_a * offset)));
+
+ // Seperate register for intermediate op
+ inter_vec = _mm256_maddubs_epi16(a_int32_0, b0);
+
+ // Perform column direction mat-mul with k = 2.
+ // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-31]
+ c_int16_0p0 = _mm256_add_epi16(inter_vec, c_int16_0p0);
+
+ // Broadcast a[1,kr:kr+2].
+ a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 1) + (cs_a * offset)));
+
+ // Seperate register for intermediate op
+ inter_vec = _mm256_maddubs_epi16(a_int32_0, b0);
+
+ // Perform column direction mat-mul with k = 2.
+ // c[1,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-31]
+ c_int16_1p0 = _mm256_add_epi16(inter_vec, c_int16_1p0);
+
+ // Broadcast a[2,kr:kr+2].
+ a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 2) + (cs_a * offset)));
+
+ // Seperate register for intermediate op
+ inter_vec = _mm256_maddubs_epi16(a_int32_0, b0);
+
+ // Perform column direction mat-mul with k = 2.
+ // c[2,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-31]
+ c_int16_2p0 = _mm256_add_epi16(inter_vec, c_int16_2p0);
+
+ // Broadcast a[3,kr:kr+2].
+ a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 3) + (cs_a * offset)));
+
+ // Seperate register for intermediate op
+ inter_vec = _mm256_maddubs_epi16(a_int32_0, b0);
+
+ // Perform column direction mat-mul with k = 2.
+ // c[3,0-15] = a[3,kr:kr+2]*b[kr:kr+2,0-31]
+ c_int16_3p0 = _mm256_add_epi16(inter_vec, c_int16_3p0);
+
+ // Broadcast a[4,kr:kr+2].
+ a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 4) + (cs_a * offset))); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[4,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_4p0 = _mm256_add_epi16(inter_vec, c_int16_4p0); + + // Broadcast a[5,kr:kr+2]. + a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 5) + (cs_a * offset))); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[5,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_5p0 = _mm256_add_epi16(inter_vec, c_int16_5p0); + } + + // Handle k remainder. + if (k_partial_pieces > 0) + { + uint8_t a_element[6]; + + b0 = _mm256_loadu_si256((__m256i const *)(b + (32 * k_full_pieces) + (NR * 0))); + + a_element[0] = *(a + (rs_a * 0) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_element[0]); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_0p0 = _mm256_add_epi16(inter_vec, c_int16_0p0); + + a_element[1] = *(a + (rs_a * 1) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_element[1]); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_1p0 = _mm256_add_epi16(inter_vec, c_int16_1p0); + + a_element[2] = *(a + (rs_a * 2) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_element[2]); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_2p0 = _mm256_add_epi16(inter_vec, c_int16_2p0); + + a_element[3] = *(a + (rs_a * 3) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_element[3]); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[3,0-15] = a[3,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_3p0 = _mm256_add_epi16(inter_vec, c_int16_3p0); + + a_element[4] = *(a + (rs_a * 4) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_element[4]); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[4,0-15] = a[4,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_4p0 = _mm256_add_epi16(inter_vec, c_int16_4p0); + + a_element[5] = *(a + (rs_a * 5) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_element[5]); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. 
+ // c[5,0-15] = a[5,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_5p0 = _mm256_add_epi16(inter_vec, c_int16_5p0); + } + + // Load alpha and beta + __m256i selector1 = _mm256_set1_epi16(alpha); + __m256i selector2 = _mm256_set1_epi16(beta); + + // Scale by alpha + c_int16_0p0 = _mm256_mullo_epi16(selector1, c_int16_0p0); + + c_int16_1p0 = _mm256_mullo_epi16(selector1, c_int16_1p0); + + c_int16_2p0 = _mm256_mullo_epi16(selector1, c_int16_2p0); + + c_int16_3p0 = _mm256_mullo_epi16(selector1, c_int16_3p0); + + c_int16_4p0 = _mm256_mullo_epi16(selector1, c_int16_4p0); + + c_int16_5p0 = _mm256_mullo_epi16(selector1, c_int16_5p0); + + // Scale C by beta. + if (beta != 0) + { + // c[0,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * (ir + 0)) + (0 * 16))); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_0p0 = _mm256_add_epi16(selector1, c_int16_0p0); + + // c[1,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * (ir + 1)) + (0 * 16))); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_1p0 = _mm256_add_epi16(selector1, c_int16_1p0); + + // c[2,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * (ir + 2)) + (0 * 16))); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_2p0 = _mm256_add_epi16(selector1, c_int16_2p0); + + // c[3,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * (ir + 3)) + (0 * 16))); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_3p0 = _mm256_add_epi16(selector1, c_int16_3p0); + + // c[4,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * (ir + 4)) + (0 * 16))); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_4p0 = _mm256_add_epi16(selector1, c_int16_4p0); + + // c[5,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * (ir + 5)) + (0 * 16))); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_5p0 = _mm256_add_epi16(selector1, c_int16_5p0); + } + + // Store the results. 
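+ // Full-width stores for this 6-row panel follow; after the panel loop,
+ // the m0 % 6 leftover rows are decomposed into 4-, 2- and 1-row calls
+ // (lpgemm_rowvar_u8s8s16o16_4x16/2x16/1x16). E.g. a remainder of 5 rows
+ // is handled as 4 + 1:
+ //
+ //   rem = m0 % 6;                             // 5
+ //   if (rem / 4) { /* 4x16 */ rem %= 4; }     // 4 rows
+ //   if (rem / 2) { /* 2x16 */ rem %= 2; }     // 2 rows
+ //   if (rem)     { /* 1x16 */ }               // 1 row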
+ // c[0,0-15] + _mm256_storeu_si256((__m256i *)(c + (rs_c * (ir + 0)) + (0 * 16)), c_int16_0p0); + + // c[1,0-15] + _mm256_storeu_si256((__m256i *)(c + (rs_c * (ir + 1)) + (0 * 16)), c_int16_1p0); + + // c[2,0-15] + _mm256_storeu_si256((__m256i *)(c + (rs_c * (ir + 2)) + (0 * 16)), c_int16_2p0); + + // c[3,0-15] + _mm256_storeu_si256((__m256i *)(c + (rs_c * (ir + 3)) + (0 * 16)), c_int16_3p0); + + // c[4,0-15] + _mm256_storeu_si256((__m256i *)(c + (rs_c * (ir + 4)) + (0 * 16)), c_int16_4p0); + + // c[5,0-15] + _mm256_storeu_si256((__m256i *)(c + (rs_c * (ir + 5)) + (0 * 16)), c_int16_5p0); + + a = a + (MR * ps_a); + } + + if (m_partial_pieces > 0) + { + dim_t m_partial4 = m_partial_pieces / 4; + m_partial_pieces = m_partial_pieces % 4; + + dim_t m_partial2 = m_partial_pieces / 2; + dim_t m_partial = m_partial_pieces % 2; + + if (m_partial4 == 1) + { + lpgemm_rowvar_u8s8s16o16_4x16( + k0, + a, rs_a, cs_a, + b, rs_b, cs_b, + (c + (rs_c * m_full_pieces_loop_limit)), rs_c, + alpha, beta); + + // a pointer increment + a = a + (4 * ps_a); + m_full_pieces_loop_limit += 4; + } + + if (m_partial2 == 1) + { + lpgemm_rowvar_u8s8s16o16_2x16( + k0, + a, rs_a, cs_a, + b, rs_b, cs_b, + (c + (rs_c * m_full_pieces_loop_limit)), rs_c, + alpha, beta); + + // a pointer increment + a = a + (2 * ps_a); + m_full_pieces_loop_limit += 2; + } + + if (m_partial == 1) + { + lpgemm_rowvar_u8s8s16o16_1x16( + k0, + a, rs_a, cs_a, + b, rs_b, cs_b, + (c + (rs_c * m_full_pieces_loop_limit)), rs_c, + alpha, beta); + } + } +} + +// 6xlt16 int8o16 kernel +void lpgemm_rowvar_u8s8s16o16_6xlt16 + ( + const dim_t m0, + const dim_t k0, + const uint8_t *a, + const dim_t rs_a, + const dim_t cs_a, + const dim_t ps_a, + const int8_t *b, + const dim_t rs_b, + const dim_t cs_b, + int16_t *c, + const dim_t rs_c, + const int16_t alpha, + const int16_t beta, + const dim_t n0_rem + ) +{ + dim_t MR = 6; + + dim_t m_full_pieces = m0 / MR; + dim_t m_full_pieces_loop_limit = m_full_pieces * MR; + dim_t m_partial_pieces = m0 % MR; + + // The division is done by considering the vpmaddubsw instruction + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int16_t buf0[16]; + int16_t buf1[16]; + int16_t buf2[16]; + int16_t buf3[16]; + int16_t buf4[16]; + int16_t buf5[16]; + + // B matrix storage. + __m256i b0; + + // A matrix storage. + __m256i a_int32_0; + __m256i inter_vec; + + for (dim_t ir = 0; ir < m_full_pieces_loop_limit; ir += MR) + { + // Registers to use for accumulating C. + __m256i c_int16_0p0 = _mm256_setzero_si256(); + + __m256i c_int16_1p0 = _mm256_setzero_si256(); + + __m256i c_int16_2p0 = _mm256_setzero_si256(); + + __m256i c_int16_3p0 = _mm256_setzero_si256(); + + __m256i c_int16_4p0 = _mm256_setzero_si256(); + + __m256i c_int16_5p0 = _mm256_setzero_si256(); + + for (dim_t kr = 0; kr < k_full_pieces; kr += 1) + { + dim_t offset = kr * 2; + + b0 = _mm256_loadu_si256((__m256i const *)(b + (32 * kr) + (cs_b * 0))); + + // Broadcast a[0,kr:kr+2]. + a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 0) + (cs_a * offset))); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_0p0 = _mm256_add_epi16(inter_vec, c_int16_0p0); + + // Broadcast a[1,kr:kr+2]. 
+ a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 1) + (cs_a * offset))); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_1p0 = _mm256_add_epi16(inter_vec, c_int16_1p0); + + // Broadcast a[2,kr:kr+2]. + a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 2) + (cs_a * offset))); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_2p0 = _mm256_add_epi16(inter_vec, c_int16_2p0); + + // Broadcast a[3,kr:kr+2]. + a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 3) + (cs_a * offset))); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[3,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_3p0 = _mm256_add_epi16(inter_vec, c_int16_3p0); + + // Broadcast a[4,kr:kr+2]. + a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 4) + (cs_a * offset))); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[4,0-15] = a[4,kr:kr+2]*b[kr:kr+4,0-31] + c_int16_4p0 = _mm256_add_epi16(inter_vec, c_int16_4p0); + + // Broadcast a[5,kr:kr+4]. + a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 5) + (cs_a * offset))); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[5,0-15] = a[5,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_5p0 = _mm256_add_epi16(inter_vec, c_int16_5p0); + } + + // Handle k remainder. + if (k_partial_pieces > 0) + { + uint8_t a_element[6]; + + b0 = _mm256_loadu_si256((__m256i const *)(b + (32 * k_full_pieces) + (cs_b * 0))); + + a_element[0] = *(a + (rs_a * 0) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_element[0]); + + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_0p0 = _mm256_add_epi16(inter_vec, c_int16_0p0); + + a_element[1] = *(a + (rs_a * 1) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_element[1]); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_1p0 = _mm256_add_epi16(inter_vec, c_int16_1p0); + + a_element[2] = *(a + (rs_a * 2) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_element[2]); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_2p0 = _mm256_add_epi16(inter_vec, c_int16_2p0); + + a_element[3] = *(a + (rs_a * 3) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_element[3]); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. 
+ // c[3,0-15] = a[3,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_3p0 = _mm256_add_epi16(inter_vec, c_int16_3p0); + + a_element[4] = *(a + (rs_a * 4) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_element[4]); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[4,0-15] = a[4,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_4p0 = _mm256_add_epi16(inter_vec, c_int16_4p0); + + a_element[5] = *(a + (rs_a * 5) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_element[5]); + + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); + + // Perform column direction mat-mul with k = 2. + // c[5,0-15] = a[5,kr:kr+2]*b[kr:kr+2,0-31] + c_int16_5p0 = _mm256_add_epi16(inter_vec, c_int16_5p0); + } + + // Load alpha and beta + __m256i selector1 = _mm256_set1_epi16(alpha); + __m256i selector2 = _mm256_set1_epi16(beta); + + // Scale by alpha + c_int16_0p0 = _mm256_mullo_epi16(selector1, c_int16_0p0); + + c_int16_1p0 = _mm256_mullo_epi16(selector1, c_int16_1p0); + + c_int16_2p0 = _mm256_mullo_epi16(selector1, c_int16_2p0); + + c_int16_3p0 = _mm256_mullo_epi16(selector1, c_int16_3p0); + + c_int16_4p0 = _mm256_mullo_epi16(selector1, c_int16_4p0); + + c_int16_5p0 = _mm256_mullo_epi16(selector1, c_int16_5p0); + + // Scale C by beta. + if (beta != 0) + { + memcpy(buf0, (c + (rs_c * (ir + 0))), (n0_rem * sizeof(int16_t))); + memcpy(buf1, (c + (rs_c * (ir + 1))), (n0_rem * sizeof(int16_t))); + memcpy(buf2, (c + (rs_c * (ir + 2))), (n0_rem * sizeof(int16_t))); + memcpy(buf3, (c + (rs_c * (ir + 3))), (n0_rem * sizeof(int16_t))); + memcpy(buf4, (c + (rs_c * (ir + 4))), (n0_rem * sizeof(int16_t))); + memcpy(buf5, (c + (rs_c * (ir + 5))), (n0_rem * sizeof(int16_t))); + + // c[0,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)buf0); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_0p0 = _mm256_add_epi16(selector1, c_int16_0p0); + + // c[1,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)buf1); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_1p0 = _mm256_add_epi16(selector1, c_int16_1p0); + + // c[2,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)buf2); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_2p0 = _mm256_add_epi16(selector1, c_int16_2p0); + + // c[3,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)buf3); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_3p0 = _mm256_add_epi16(selector1, c_int16_3p0); + + // c[4,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)buf4); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_4p0 = _mm256_add_epi16(selector1, c_int16_4p0); + + // c[5,0-15] + selector1 = _mm256_loadu_si256((__m256i const *)buf5); + selector1 = _mm256_mullo_epi16(selector2, selector1); + c_int16_5p0 = _mm256_add_epi16(selector1, c_int16_5p0); + } + + // Store the results. 
+ // c[0,0-15] + _mm256_storeu_si256((__m256i_u *)buf0, c_int16_0p0); + + // c[1,0-15] + _mm256_storeu_si256((__m256i_u *)buf1, c_int16_1p0); + + // c[2,0-15] + _mm256_storeu_si256((__m256i_u *)buf2, c_int16_2p0); + + // c[3,0-15] + _mm256_storeu_si256((__m256i_u *)buf3, c_int16_3p0); + + // c[4,0-15] + _mm256_storeu_si256((__m256i *)buf4, c_int16_4p0); + + // c[5,0-15] + _mm256_storeu_si256((__m256i *)buf5, c_int16_5p0); + + memcpy(c + (rs_c * (ir + 0)) + (0 * 16), buf0, (n0_rem * sizeof(int16_t))); + + // c[1,0-15] + memcpy(c + (rs_c * (ir + 1)) + (0 * 16), buf1, (n0_rem * sizeof(int16_t))); + + // c[2,0-15] + memcpy(c + (rs_c * (ir + 2)) + (0 * 16), buf2, (n0_rem * sizeof(int16_t))); + + // c[3,0-15] + memcpy(c + (rs_c * (ir + 3)) + (0 * 16), buf3, (n0_rem * sizeof(int16_t))); + + // c[4,0-15] + memcpy(c + (rs_c * (ir + 4)) + (0 * 16), buf4, (n0_rem * sizeof(int16_t))); + + // c[5,0-15] + memcpy(c + (rs_c * (ir + 5)) + (0 * 16), buf5, (n0_rem * sizeof(int16_t))); + + a = a + (MR * ps_a); + } + + if (m_partial_pieces > 0) + { + dim_t m_partial4 = m_partial_pieces / 4; + m_partial_pieces = m_partial_pieces % 4; + + dim_t m_partial2 = m_partial_pieces / 2; + dim_t m_partial = m_partial_pieces % 2; + + if (m_partial4 == 1) + { + lpgemm_rowvar_u8s8s16o16_4xlt16( + k0, + a, rs_a, cs_a, + b, rs_b, cs_b, + (c + (rs_c * m_full_pieces_loop_limit)), rs_c, + alpha, beta, n0_rem); + + // a pointer increment + a = a + (4 * ps_a); + m_full_pieces_loop_limit += 4; + } + + if (m_partial2 == 1) + { + lpgemm_rowvar_u8s8s16o16_2xlt16( + k0, + a, rs_a, cs_a, + b, rs_b, cs_b, + (c + (rs_c * m_full_pieces_loop_limit)), rs_c, + alpha, beta, n0_rem); + + // a pointer increment + a = a + (2 * ps_a); + m_full_pieces_loop_limit += 2; + } + + if (m_partial == 1) + { + lpgemm_rowvar_u8s8s16o16_1xlt16( + k0, + a, rs_a, cs_a, + b, rs_b, cs_b, + (c + (rs_c * m_full_pieces_loop_limit)), rs_c, + alpha, beta, n0_rem); + } + } +} diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_n_fringe_s16.h b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_n_fringe_s16.h new file mode 100644 index 0000000000..7987aa04aa --- /dev/null +++ b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_n_fringe_s16.h @@ -0,0 +1,71 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_GEMM_INT16_NFRINGE +#define BLIS_GEMM_INT16_NFRINGE + +// 6x16 int8o16 kernel +void lpgemm_rowvar_u8s8s16o16_6x16( + const dim_t m0, + const dim_t k0, + const uint8_t *a, + const dim_t rs_a, + const dim_t cs_a, + const dim_t ps_a, + const int8_t *b, + const dim_t rs_b, + const dim_t cs_b, + int16_t *c, + const dim_t rs_c, + const int16_t alpha, + const int16_t beta); + +// 6xlt16 int8o16 kernel +void lpgemm_rowvar_u8s8s16o16_6xlt16( + const dim_t m0, + const dim_t k0, + const uint8_t *a, + const dim_t rs_a, + const dim_t cs_a, + const dim_t ps_a, + const int8_t *b, + const dim_t rs_b, + const dim_t cs_b, + int16_t *c, + const dim_t rs_c, + const int16_t alpha, + const int16_t beta, + const dim_t n0_rem); + +#endif // BLIS_GEMM_INT16_NFRINGE \ No newline at end of file diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_packb_amd256.c b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_packb_amd256.c new file mode 100644 index 0000000000..1d0e1471b1 --- /dev/null +++ b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_packb_amd256.c @@ -0,0 +1,252 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +*/ + +#include + +#include "blis.h" +#include "lpgemm_packb_s16.h" + +void packb_nrlt16_u8s8s16o16( + int8_t *pack_b_buffer_u8s8s16o16, + const int8_t *b, + const dim_t ldb, + const dim_t rows, + dim_t n0_partial_rem) +{ + dim_t k_full_pieces_blks = rows / 2; + dim_t k_full_pieces = k_full_pieces_blks * 2; + dim_t k_partial_pieces = rows % 2; + dim_t NR = 16; + dim_t kr_new = 0; + + int8_t buf0[16], buf1[16]; + + __m128i b_vec[2], inter_vec[2]; + + for (dim_t kr = 0; kr < k_full_pieces; kr += 2) + { + memcpy(buf0, (b + (ldb * (kr + 0))), (n0_partial_rem * sizeof(int8_t))); + memcpy(buf1, (b + (ldb * (kr + 1))), (n0_partial_rem * sizeof(int8_t))); + + // Read b[0,0], b[0,1], b[0,2]......., b[0,15] + b_vec[0] = _mm_loadu_si128((__m128i *)buf0); + // Read b[1,0], b[1,1], b[1,2]......., b[1,15] + b_vec[1] = _mm_loadu_si128((__m128i *)buf1); + + // Reorder B matrix inputs to suit vpmaddubsw instructions + inter_vec[0] = _mm_unpacklo_epi8(b_vec[0], b_vec[1]); + inter_vec[1] = _mm_unpackhi_epi8(b_vec[0], b_vec[1]); + + // Store b[0,0], b[1,0], b[0,1]......., b[0,7], b[1,7] + _mm_storeu_si128((__m128i *)(pack_b_buffer_u8s8s16o16 + (kr_new * NR)), inter_vec[0]); + // Store b[0,8], b[1,8], b[0,9]......., b[0,15], b[1,15] + _mm_storeu_si128((__m128i *)(pack_b_buffer_u8s8s16o16 + ((kr_new + 1) * NR)), inter_vec[1]); + + // Increment to ignore the padded bits + kr_new += 2; + } + + // Handle k partial cases + if (k_partial_pieces > 0) + { + memcpy(buf0, (b + (ldb * (k_full_pieces + 0))), (n0_partial_rem * sizeof(int8_t))); + + // Read b[0,0], b[0,1], b[0,2]......., b[0,15] + b_vec[0] = _mm_loadu_si128((__m128i *)buf0); + b_vec[1] = _mm_setzero_si128(); // Initialize with zero for padding + + // Reorder B matrix inputs to suit vpmaddubsw instructions + inter_vec[0] = _mm_unpacklo_epi8(b_vec[0], b_vec[1]); + inter_vec[1] = _mm_unpackhi_epi8(b_vec[0], b_vec[1]); + + // Store b[0,0], 0, b[0,1]......., b[0,7], 0 + _mm_storeu_si128((__m128i *)(pack_b_buffer_u8s8s16o16 + ((kr_new + 0) * NR)), inter_vec[0]); + + // Store b[0,8], 0, b[0,9]......., b[0,15], 0 + _mm_storeu_si128((__m128i *)(pack_b_buffer_u8s8s16o16 + ((kr_new + 1) * NR)), inter_vec[1]); + } +} + +void packb_nr16_u8s8s16o16( + int8_t *pack_b_buffer_u8s8s16o16, + const int8_t *b, + const dim_t ldb, + const dim_t rows) +{ + dim_t k_full_pieces_blks = rows / 2; + dim_t k_full_pieces = k_full_pieces_blks * 2; + dim_t k_partial_pieces = rows % 2; + dim_t NR = 16; + dim_t kr_new = 0; + + __m128i b_vec[2], inter_vec[2]; + + for (dim_t kr = 0; kr < k_full_pieces; kr += 2) + { + // Read b[0,0], b[0,1], b[0,2]......., b[0,15] + b_vec[0] = _mm_loadu_si128((__m128i const *)(b + (ldb * (kr + 0)))); + + // Read b[1,0], b[1,1], b[1,2]......., b[1,15] + b_vec[1] = _mm_loadu_si128((__m128i const *)(b + (ldb * (kr + 1)))); + + // Reorder B matrix inputs to suit vpmaddubsw instructions + inter_vec[0] = _mm_unpacklo_epi8(b_vec[0], b_vec[1]); + inter_vec[1] = _mm_unpackhi_epi8(b_vec[0], b_vec[1]); + + // Store b[0,0], b[1,0], b[0,1]......., b[0,7], b[1,7] + _mm_storeu_si128((__m128i *)(pack_b_buffer_u8s8s16o16 + ((kr_new + 0) * NR)), inter_vec[0]); + + // Store b[0,8], b[1,8], b[0,9]......., b[0,15], b[1,15] + _mm_storeu_si128((__m128i *)(pack_b_buffer_u8s8s16o16 + ((kr_new + 1) * NR)), inter_vec[1]); + + // Increment to ignore the padded bits + kr_new += 2; + } + + if (k_partial_pieces > 0) + { + // Read b[0,0], b[0,1], b[0,2]......., b[0,15] + b_vec[0] = _mm_loadu_si128((__m128i const *)(b + (ldb * (k_full_pieces + 0)))); + b_vec[1] = _mm_setzero_si128(); // 
Initialize with zero for padding + + // Reorder B matrix inputs to suit vpmaddubsw instructions + inter_vec[0] = _mm_unpacklo_epi8(b_vec[0], b_vec[1]); + inter_vec[1] = _mm_unpackhi_epi8(b_vec[0], b_vec[1]); + + // Store b[0,0], 0, b[0,1]......., b[0,7], 0 + _mm_storeu_si128((__m128i *)(pack_b_buffer_u8s8s16o16 + ((kr_new + 0) * NR)), inter_vec[0]); + // Store b[0,8], 0, b[0,9]......., b[0,15], 0 + _mm_storeu_si128((__m128i *)(pack_b_buffer_u8s8s16o16 + ((kr_new + 1) * NR)), inter_vec[1]); + } +} + +void packb_nr32_u8s8s16o16( + int8_t *pack_b_buffer_u8s8s16o16, + const int8_t *b, + const dim_t ldb, + const dim_t cols, + const dim_t rows, + dim_t *rs_b, + dim_t *cs_b) +{ + dim_t NR = 32; + + dim_t n_full_pieces = cols / NR; + dim_t n_full_pieces_loop_limit = n_full_pieces * NR; + dim_t n_partial_pieces = cols % NR; + dim_t k_full_pieces_blks = rows / 2; + dim_t k_full_pieces = k_full_pieces_blks * 2; + dim_t k_partial_pieces = rows % 2; + + dim_t KC_updated = rows; + + // Making multiple of 2 to suit k in vpmaddubsw + KC_updated += (KC_updated & 0x1); + + __m256i b_vec[2], inter_vec[2]; + + for (dim_t jc = 0; jc < n_full_pieces_loop_limit; jc += NR) + { + for (dim_t kr = 0; kr < k_full_pieces; kr += 2) + { + // Read b[0,0], b[0,1], b[0,2]......., b[0,31] + b_vec[0] = _mm256_loadu_si256((__m256i const *)(b + (ldb * (kr + 0)) + jc)); + + // Read b[1,0], b[1,1], b[1,2]......., b[1,31] + b_vec[1] = _mm256_loadu_si256((__m256i const *)(b + (ldb * (kr + 1)) + jc)); + + // Reorder B matrix inputs to suit vpmaddubsw instructions + inter_vec[0] = _mm256_unpacklo_epi8(b_vec[0], b_vec[1]); + inter_vec[1] = _mm256_unpackhi_epi8(b_vec[0], b_vec[1]); + + b_vec[0] = _mm256_permute2f128_si256(inter_vec[0], inter_vec[1], 0x20); + b_vec[1] = _mm256_permute2f128_si256(inter_vec[0], inter_vec[1], 0x31); + + // Store B[0,0], B[1,0], B[0,1], B[1,1], ......, B[0,15], B[1,15] + _mm256_storeu_si256((__m256i *)(pack_b_buffer_u8s8s16o16 + ((jc * KC_updated) + (kr * NR))), b_vec[0]); + // Store B[0,16], B[1,16], B[0,17], B[1,17], ......, B[0,31], B[1,31] + _mm256_storeu_si256((__m256i *)(pack_b_buffer_u8s8s16o16 + ((jc * KC_updated) + ((kr + 1) * NR))), b_vec[1]); + } + + if (k_partial_pieces > 0) + { + // Read b[0,0], b[0,1], b[0,2]......., b[0,31] + b_vec[0] = _mm256_loadu_si256((__m256i const *)(b + (ldb * (k_full_pieces + 0)) + jc)); + b_vec[1] = _mm256_setzero_si256(); // Initialize with zero for padding + + // Reorder B matrix inputs to suit vpmaddubsw instructions + inter_vec[0] = _mm256_unpacklo_epi8(b_vec[0], b_vec[1]); + inter_vec[1] = _mm256_unpackhi_epi8(b_vec[0], b_vec[1]); + + b_vec[0] = _mm256_permute2f128_si256(inter_vec[0], inter_vec[1], 0x20); + b_vec[1] = _mm256_permute2f128_si256(inter_vec[0], inter_vec[1], 0x31); + + // Store B[0,0], B[1,0], B[0,1], B[1,1], ......, B[0,15], B[1,15] + _mm256_storeu_si256((__m256i *)(pack_b_buffer_u8s8s16o16 + ((jc * KC_updated) + (k_full_pieces * NR))), b_vec[0]); + // Store B[0,16], B[1,16], B[0,17], B[1,17], ......, B[0,31], B[1,31] + _mm256_storeu_si256((__m256i *)(pack_b_buffer_u8s8s16o16 + ((jc * KC_updated) + ((k_full_pieces + 1) * NR))), b_vec[1]); + } + } + + // B matrix packing when n < NR + if (n_partial_pieces > 0) + { + // Split into multiple smaller fringe kernels, so as to maximize + // vectorization after packing. Any n0 < NR(32) can be expressed + // as n0 = 16 + n`. 
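+ // For example, n_partial_pieces = 23 gives n0_16 = 1 and
+ // n0_partial_rem = 7: the first 16 leftover columns are packed by
+ // packb_nr16_u8s8s16o16 and the last 7 by packb_nrlt16_u8s8s16o16.
+ // (In all of these paths an odd trailing row of B is padded with a zero
+ // row, so vpmaddubsw always sees complete k pairs; the compute kernels
+ // rely on this when they broadcast the lone remaining A byte into both
+ // byte positions of each 16-bit lane for the k remainder.)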
+ dim_t n0_16 = n_partial_pieces / 16; + dim_t n0_partial_rem = n_partial_pieces % 16; + + dim_t n0_partial_pack = 0; + + if (n0_16 == 1) + { + packb_nr16_u8s8s16o16( + (pack_b_buffer_u8s8s16o16 + (n_full_pieces_loop_limit * KC_updated)), + (b + n_full_pieces_loop_limit), ldb, rows); + + n0_partial_pack = 16; + } + + if (n0_partial_rem > 0) + { + packb_nrlt16_u8s8s16o16( + (pack_b_buffer_u8s8s16o16 + (n_full_pieces_loop_limit * KC_updated) + (n0_partial_pack * KC_updated)), + (b + n_full_pieces_loop_limit + n0_partial_pack), ldb, rows, n0_partial_rem); + } + } + + *rs_b = NR * 2; + *cs_b = NR; +} \ No newline at end of file diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_packb_s16.h b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_packb_s16.h new file mode 100644 index 0000000000..31d8465dac --- /dev/null +++ b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_packb_s16.h @@ -0,0 +1,47 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +*/ + +#ifndef BLIS_GEMM_INT16_PACKB +#define BLIS_GEMM_INT16_PACKB + +void packb_nr32_u8s8s16o16( + int8_t *pack_b_buffer_u8s8s16o16, + const int8_t *b, + const dim_t ldb, + const dim_t cols, + const dim_t rows, + dim_t *rs_b, + dim_t *cs_b); + +#endif // BLIS_GEMM_INT16_PACKB \ No newline at end of file diff --git a/bench/bench_aocl_gemm/bench_lpgemm.c b/bench/bench_aocl_gemm/bench_lpgemm.c index 0af03cf2f8..37d6254e7b 100644 --- a/bench/bench_aocl_gemm/bench_lpgemm.c +++ b/bench/bench_aocl_gemm/bench_lpgemm.c @@ -78,6 +78,7 @@ void mat_mul_ ## BLAS_SFX \ c, ldc ); \ } \ +GEN_BLIS_MAT_MUL_FUNC(uint8_t, int8_t, int16_t, u8s8s16os16) GEN_BLIS_MAT_MUL_FUNC(uint8_t,int8_t,int32_t,u8s8s32os32) GEN_BLIS_MAT_MUL_FUNC(float,float,float,f32f32f32of32) @@ -107,7 +108,7 @@ void print_result { double gflops = get_gflops( m, n, k, runtime ); printf("%s m: %ld, n: %ld, k: %ld, lda: %ld, ldb: %ld, ldc: %ld," \ - " GFlops: %f, n_repeats: %d\n", + " Gops: %f, n_repeats: %d\n", msg, m, n, k, lda, ldb, ldc, gflops, n_repeats); } @@ -134,7 +135,7 @@ void mat_mul_bench_driver_ ## BLAS_SFX \ { \ if ( bench_mode == 'a' ) \ { \ - memset( ( void* ) c, 0, sizeof( float ) * m * n ); \ + memset( ( void* ) c, 0, sizeof( C_type ) * m * n ); \ } \ \ struct timespec tstart={0,0}, tend={0,0}; \ @@ -161,6 +162,7 @@ void mat_mul_bench_driver_ ## BLAS_SFX \ print_result( XSTR(BLAS_SFX), n_repeats, m, n, k, lda, ldb, ldc, min_time_diff); \ } \ +GEN_MAT_MUL_BENCH_DRV_FUNC(uint8_t, int8_t, int16_t, u8s8s16os16) GEN_MAT_MUL_BENCH_DRV_FUNC(uint8_t,int8_t,int32_t,u8s8s32os32) GEN_MAT_MUL_BENCH_DRV_FUNC(float,float,float,f32f32f32of32) @@ -214,6 +216,7 @@ cleanup_acc: \ return; \ } \ +GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t, int8_t, int16_t, u8s8s16os16) GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,int32_t,u8s8s32os32) GEN_MAT_MUL_ACC_CHK_DRV_FUNC(float,float,float,f32f32f32of32) @@ -338,6 +341,7 @@ void mat_mul_bench_main_ ## BLAS_SFX \ } \ } \ +GEN_MAT_MUL_BENCH_MAIN_FUNC(uint8_t, int8_t, int16_t, u8s8s16os16) GEN_MAT_MUL_BENCH_MAIN_FUNC(uint8_t,int8_t,int32_t,u8s8s32os32) GEN_MAT_MUL_BENCH_MAIN_FUNC(float,float,float,f32f32f32of32) @@ -461,6 +465,14 @@ int main( int argc, char** argv ) m, n, k, stride_a, stride_b, stride_c ); } + else if ((op_type_char == 's') || (op_type_char == 'S')) + { + GEN_FUNC_NAME(mat_mul_bench_main_, u8s8s16os16) + ( + fin, fout, op_t, + m, n, k, stride_a, stride_b, stride_c + ); + } } } @@ -473,4 +485,4 @@ int main( int argc, char** argv ) fclose( fout ); } return 0; -} +} \ No newline at end of file From 828d3cd3d3801436067248084800c772c073b4db Mon Sep 17 00:00:00 2001 From: mkadavil Date: Mon, 25 Jul 2022 12:57:25 +0530 Subject: [PATCH 165/243] Post operations support for low precision gemm. - Low precision gemm is often used in ML/DNN workloads and is used in conjunction with pre and post operations. Performing gemm and ops together at the micro kernel level results in better overall performance due to cache/register reuse of output matrix. The provision for defining the post-operations and invoking the micro-kernel with it from the framework is added as part of this change. This includes adding new data structures/functions to define the post-ops to be applied and an extensible template using which new post-ops can easily be integrated. As for the post-operations, RELU and Bias Add for u8s8s32 is implemented in this first cut. - aocl_gemm bench modifications to test/benchmark RELU and Bias Add. 
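As an illustration of how a caller drives the new interface (a sketch only, not
part of this patch: the helper name, its bias_vec parameter and the elided gemm
arguments are placeholders), a bias add followed by RELU is requested by listing
BIAS and ELTWISE in seq_vector and selecting RELU as the element-wise algorithm:

#include "blis.h"
#include "aocl_gemm.h" /* pulls in aocl_gemm_post_ops.h */

/* Hypothetical helper; bias_vec holds one int32 value per output column of C. */
void example_request_bias_relu( int32_t* bias_vec )
{
    AOCL_POST_OP_TYPE seq[2] = { BIAS, ELTWISE };

    aocl_post_op post_ops = { 0 };
    post_ops.seq_length = 2;
    post_ops.seq_vector = seq;

    /* BIAS: read by the POST_OPS_BIAS_* sections of the u8s8s32 kernels,
       indexed by the output column offset (post_op_c_j). */
    post_ops.bias.bias = bias_vec;

    /* ELTWISE resolves to POST_OPS_RELU via the algo selector; the other
       eltwise fields appear unused by the RELU path in this first cut. */
    post_ops.eltwise.is_power_of_2 = FALSE;
    post_ops.eltwise.scale_factor = NULL;
    post_ops.eltwise.algo.alpha = NULL;
    post_ops.eltwise.algo.beta = NULL;
    post_ops.eltwise.algo.algo_type = RELU;

    /* The descriptor is then passed as the new trailing argument, e.g.
         aocl_gemm_u8s8s32os32( ..., beta, c, ldc, &post_ops );
       while a NULL descriptor keeps the plain alpha*A*B + beta*C path. */
}

The descriptor is translated once per call into the internal lpgemm_post_op list
(lpgemm_translate_to_post_ops_list below), and the kernels only act on it on the
last k iteration, so nothing is added to the inner compute loops.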
AMD-Internal: [CPUPL-2316] Change-Id: Iad5fe9e54965bb52d5381ae459a69800946c7d18 --- addon/aocl_gemm/aocl_gemm.h | 1 + addon/aocl_gemm/aocl_gemm_f32f32f32of32.c | 14 +- addon/aocl_gemm/aocl_gemm_f32f32f32of32.h | 5 +- addon/aocl_gemm/aocl_gemm_post_ops.h | 96 + addon/aocl_gemm/aocl_gemm_u8s8s16os16.c | 35 +- addon/aocl_gemm/aocl_gemm_u8s8s16os16.h | 35 +- addon/aocl_gemm/aocl_gemm_u8s8s32os32.c | 14 +- addon/aocl_gemm/aocl_gemm_u8s8s32os32.h | 5 +- .../frame/f32f32f32/lpgemm_f32f32f32.c | 3 +- .../frame/f32f32f32/lpgemm_f32f32f32.h | 4 +- addon/aocl_gemm/frame/lpgemm_post_ops.c | 148 ++ addon/aocl_gemm/frame/lpgemm_post_ops.h | 88 + .../threading/lpgemm_thread_decor_openmp.c | 12 +- .../threading/lpgemm_thread_decor_openmp.h | 7 +- .../aocl_gemm/frame/u8s8s32/lpgemm_u8s8s32.c | 11 +- .../aocl_gemm/frame/u8s8s32/lpgemm_u8s8s32.h | 4 +- .../kernels/u8s8s32/lpgemm_6x64rowmajor.h | 8 +- .../u8s8s32/lpgemm_6x64rowmajor_amd512vnni.c | 243 ++- .../kernels/u8s8s32/lpgemm_m_fringe.h | 32 +- .../u8s8s32/lpgemm_m_fringe_amd512vnni.c | 671 ++++++- .../kernels/u8s8s32/lpgemm_mn_fringe.h | 122 +- .../u8s8s32/lpgemm_mn_fringe_amd512vnni.c | 1564 +++++++++++++++-- .../kernels/u8s8s32/lpgemm_n_fringe.h | 26 +- .../u8s8s32/lpgemm_n_fringe_amd512vnni.c | 565 +++++- bench/bench_aocl_gemm/bench_lpgemm.c | 277 ++- 25 files changed, 3679 insertions(+), 311 deletions(-) create mode 100644 addon/aocl_gemm/aocl_gemm_post_ops.h create mode 100644 addon/aocl_gemm/frame/lpgemm_post_ops.c create mode 100644 addon/aocl_gemm/frame/lpgemm_post_ops.h diff --git a/addon/aocl_gemm/aocl_gemm.h b/addon/aocl_gemm/aocl_gemm.h index 9316bb7bdc..f9e37e76cb 100644 --- a/addon/aocl_gemm/aocl_gemm.h +++ b/addon/aocl_gemm/aocl_gemm.h @@ -35,6 +35,7 @@ #ifndef BLIS_ADDON_LPGEMM #define BLIS_ADDON_LPGEMM +#include "aocl_gemm_post_ops.h" #include "aocl_gemm_u8s8s16os16.h" #include "aocl_gemm_u8s8s32os32.h" #include "aocl_gemm_f32f32f32of32.h" diff --git a/addon/aocl_gemm/aocl_gemm_f32f32f32of32.c b/addon/aocl_gemm/aocl_gemm_f32f32f32of32.c index 973d33c548..bc9ed29da1 100644 --- a/addon/aocl_gemm/aocl_gemm_f32f32f32of32.c +++ b/addon/aocl_gemm/aocl_gemm_f32f32f32of32.c @@ -35,6 +35,7 @@ #include "blis.h" #include "aocl_gemm_f32f32f32of32.h" #include "lpgemm_types.h" +#include "lpgemm_post_ops.h" #include "lpgemm_thread_decor_openmp.h" #include "lpgemm_utils.h" #include "lpgemm_f32f32f32.h" @@ -55,7 +56,8 @@ void aocl_gemm_f32f32f32of32 const char mem_format_b, const float beta, float* c, - const dim_t ldc + const dim_t ldc, + aocl_post_op* post_op_unparsed ) { trans_t blis_transa; @@ -134,6 +136,10 @@ void aocl_gemm_f32f32f32of32 return; // Error. } + // Convert post op struct to post op linked list format. + lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS]; + lpgemm_translate_to_post_ops_list( post_op_unparsed, post_op_list ); + // Initialize a local runtime with global settings if necessary. Note // that in the case that a runtime is passed in, we make a local copy. rntm_t rntm_g; @@ -148,7 +154,8 @@ void aocl_gemm_f32f32f32of32 b, rs_b, cs_b, mtag_b, c, rs_c, alpha, beta, - &rntm_g + &rntm_g, + post_op_list ); #else // Setting pack A by default for non open mp case. 
@@ -161,7 +168,8 @@ void aocl_gemm_f32f32f32of32 b, rs_b, cs_b, mtag_b, c, rs_c, alpha, beta, - &rntm_g + &rntm_g, + post_op_list ); #endif diff --git a/addon/aocl_gemm/aocl_gemm_f32f32f32of32.h b/addon/aocl_gemm/aocl_gemm_f32f32f32of32.h index 8ce4e001f9..3e450414ea 100644 --- a/addon/aocl_gemm/aocl_gemm_f32f32f32of32.h +++ b/addon/aocl_gemm/aocl_gemm_f32f32f32of32.h @@ -35,6 +35,8 @@ #ifndef AOCL_GEMM_F32F32F32OF32_H #define AOCL_GEMM_F32F32F32OF32_H +#include "aocl_gemm_post_ops.h" + // Only supports matrices in row major format. This api can perform gemm with // both normal as well as reordered B matrix as opposesd to sgemm (only // supports former). This api can be considered analogous to packed sgemm api. @@ -54,7 +56,8 @@ BLIS_EXPORT_ADDON void aocl_gemm_f32f32f32of32 const char mem_format_b, const float beta, float* c, - const dim_t ldc + const dim_t ldc, + aocl_post_op* post_op_unparsed ); #endif //AOCL_GEMM_F32F32F32OF32_H diff --git a/addon/aocl_gemm/aocl_gemm_post_ops.h b/addon/aocl_gemm/aocl_gemm_post_ops.h new file mode 100644 index 0000000000..ce69ea0b33 --- /dev/null +++ b/addon/aocl_gemm/aocl_gemm_post_ops.h @@ -0,0 +1,96 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +*/ + +#ifndef AOCL_GEMM_POST_OPS_H +#define AOCL_GEMM_POST_OPS_H + +#define AOCL_MAX_POST_OPS 5 + +typedef enum +{ + LINEAR = 0, + RELU = 1, + GELU = 2, + CLIP = 3, +} AOCL_ELT_ALGO_TYPE; + +typedef enum +{ + SUM = 1, + ELTWISE = 2, + BIAS = 3, +} AOCL_POST_OP_TYPE; + +typedef struct +{ + void* alpha; + void* beta; + AOCL_ELT_ALGO_TYPE algo_type; +} aocl_eltwise_algo; + +typedef struct +{ + bool is_power_of_2; + void* scale_factor; + void* buff; + void* zero_point; +} aocl_post_op_sum; + +typedef struct +{ + bool is_power_of_2; + void* scale_factor; + aocl_eltwise_algo algo; +} aocl_post_op_eltwise; + +typedef struct +{ + void* bias; +} aocl_post_op_bias; + +typedef struct +{ + aocl_post_op_sum sum; + aocl_post_op_eltwise eltwise; + aocl_post_op_bias bias; + + // eg: seq_length = 2 + dim_t seq_length; + + // eg: seq_vector[0] = BIAS, seq_vector[1] = ELTWISE means bias followed + // by eltwise(relu, if AOCL_ELT_ALGO_TYPE = 1). + AOCL_POST_OP_TYPE* seq_vector; +} aocl_post_op; + +#endif //AOCL_GEMM_POST_OPS_H diff --git a/addon/aocl_gemm/aocl_gemm_u8s8s16os16.c b/addon/aocl_gemm/aocl_gemm_u8s8s16os16.c index 8ff3c20247..ef0b4e227e 100644 --- a/addon/aocl_gemm/aocl_gemm_u8s8s16os16.c +++ b/addon/aocl_gemm/aocl_gemm_u8s8s16os16.c @@ -39,22 +39,25 @@ #include "lpgemm_config.h" #include "lpgemm_utils.h" -void aocl_gemm_u8s8s16os16( - const char transa, - const char transb, - const dim_t m, - const dim_t n, - const dim_t k, - const int16_t alpha, - const uint8_t *a, - const dim_t lda, - const char mem_format_a, - const int8_t *b, - const dim_t ldb, - const char mem_format_b, - const int16_t beta, - int16_t *c, - const dim_t ldc) +void aocl_gemm_u8s8s16os16 + ( + const char transa, + const char transb, + const dim_t m, + const dim_t n, + const dim_t k, + const int16_t alpha, + const uint8_t* a, + const dim_t lda, + const char mem_format_a, + const int8_t* b, + const dim_t ldb, + const char mem_format_b, + const int16_t beta, + int16_t* c, + const dim_t ldc, + aocl_post_op* post_op_unparsed + ) { trans_t blis_transa; trans_t blis_transb; diff --git a/addon/aocl_gemm/aocl_gemm_u8s8s16os16.h b/addon/aocl_gemm/aocl_gemm_u8s8s16os16.h index 4e56c705bd..926948aac5 100644 --- a/addon/aocl_gemm/aocl_gemm_u8s8s16os16.h +++ b/addon/aocl_gemm/aocl_gemm_u8s8s16os16.h @@ -37,21 +37,24 @@ // Only supports matrices in row major format // Limitations: Supports mem_format_b = 'Reorder' -BLIS_EXPORT_ADDON void aocl_gemm_u8s8s16os16( - const char transa, - const char transb, - const dim_t m, - const dim_t n, - const dim_t k, - const int16_t alpha, - const uint8_t *a, - const dim_t lda, - const char mem_format_a, - const int8_t *b, - const dim_t ldb, - const char mem_format_b, - const int16_t beta, - int16_t *c, - const dim_t ldc); +BLIS_EXPORT_ADDON void aocl_gemm_u8s8s16os16 + ( + const char transa, + const char transb, + const dim_t m, + const dim_t n, + const dim_t k, + const int16_t alpha, + const uint8_t* a, + const dim_t lda, + const char mem_format_a, + const int8_t* b, + const dim_t ldb, + const char mem_format_b, + const int16_t beta, + int16_t* c, + const dim_t ldc, + aocl_post_op* post_op_unparsed + ); #endif // AOCL_GEMM_U8S8S16OS16_H diff --git a/addon/aocl_gemm/aocl_gemm_u8s8s32os32.c b/addon/aocl_gemm/aocl_gemm_u8s8s32os32.c index 860b0e420c..4c92d3c74c 100644 --- a/addon/aocl_gemm/aocl_gemm_u8s8s32os32.c +++ b/addon/aocl_gemm/aocl_gemm_u8s8s32os32.c @@ -35,6 +35,7 @@ #include "blis.h" #include "aocl_gemm_u8s8s32os32.h" #include "lpgemm_types.h" +#include "lpgemm_post_ops.h" #include 
"lpgemm_thread_decor_openmp.h" #include "lpgemm_u8s8s32.h" #include "lpgemm_config.h" @@ -56,7 +57,8 @@ void aocl_gemm_u8s8s32os32 const char mem_format_b, const int32_t beta, int32_t* c, - const dim_t ldc + const dim_t ldc, + aocl_post_op* post_op_unparsed ) { trans_t blis_transa; @@ -133,6 +135,10 @@ void aocl_gemm_u8s8s32os32 return; // Error. } + // Convert post op struct to post op linked list format. + lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS]; + lpgemm_translate_to_post_ops_list( post_op_unparsed, post_op_list ); + // Initialize a local runtime with global settings if necessary. Note // that in the case that a runtime is passed in, we make a local copy. rntm_t rntm_g; @@ -147,7 +153,8 @@ void aocl_gemm_u8s8s32os32 b, rs_b, cs_b, mtag_b, c, rs_c, alpha, beta, - &rntm_g + &rntm_g, + post_op_list ); #else lpgemm_u8s8s32o32_thread_decorator @@ -157,7 +164,8 @@ void aocl_gemm_u8s8s32os32 b, rs_b, cs_b, mtag_b, c, rs_c, alpha, beta, - &rntm_g + &rntm_g, + post_op_list ); #endif } diff --git a/addon/aocl_gemm/aocl_gemm_u8s8s32os32.h b/addon/aocl_gemm/aocl_gemm_u8s8s32os32.h index 90be30a24b..0993d3562a 100644 --- a/addon/aocl_gemm/aocl_gemm_u8s8s32os32.h +++ b/addon/aocl_gemm/aocl_gemm_u8s8s32os32.h @@ -35,6 +35,8 @@ #ifndef AOCL_GEMM_U8S8S32OS32_H #define AOCL_GEMM_U8S8S32OS32_H +#include "aocl_gemm_post_ops.h" + // Only supports matrices in row major format. Currenlty only mem_format_b // is configurable to reorder. BLIS_EXPORT_ADDON void aocl_gemm_u8s8s32os32 @@ -53,7 +55,8 @@ BLIS_EXPORT_ADDON void aocl_gemm_u8s8s32os32 const char mem_format_b, const int32_t beta, int32_t* c, - const dim_t ldc + const dim_t ldc, + aocl_post_op* post_op_unparsed ); #endif //AOCL_GEMM_U8S8S32OS32_H diff --git a/addon/aocl_gemm/frame/f32f32f32/lpgemm_f32f32f32.c b/addon/aocl_gemm/frame/f32f32f32/lpgemm_f32f32f32.c index f9bb67803d..3744b73947 100644 --- a/addon/aocl_gemm/frame/f32f32f32/lpgemm_f32f32f32.c +++ b/addon/aocl_gemm/frame/f32f32f32/lpgemm_f32f32f32.c @@ -69,7 +69,8 @@ void lpgemm_rowvar_f32f32f32of32 float alpha, float beta, rntm_t* rntm, - lpgemm_thrinfo_t* thread + lpgemm_thrinfo_t* thread, + lpgemm_post_op* post_op_list ) { // Query the global cntx. diff --git a/addon/aocl_gemm/frame/f32f32f32/lpgemm_f32f32f32.h b/addon/aocl_gemm/frame/f32f32f32/lpgemm_f32f32f32.h index 03c1146bf0..f58754acb1 100644 --- a/addon/aocl_gemm/frame/f32f32f32/lpgemm_f32f32f32.h +++ b/addon/aocl_gemm/frame/f32f32f32/lpgemm_f32f32f32.h @@ -36,6 +36,7 @@ #define LPGEMM_F32F32F32_H #include "lpgemm_types.h" +#include "lpgemm_post_ops.h" void lpgemm_rowvar_f32f32f32of32 ( @@ -55,7 +56,8 @@ void lpgemm_rowvar_f32f32f32of32 float alpha, float beta, rntm_t* rntm, - lpgemm_thrinfo_t* thread + lpgemm_thrinfo_t* thread, + lpgemm_post_op* post_op_list ); #endif //LPGEMM_F32F32F32_H diff --git a/addon/aocl_gemm/frame/lpgemm_post_ops.c b/addon/aocl_gemm/frame/lpgemm_post_ops.c new file mode 100644 index 0000000000..700679772a --- /dev/null +++ b/addon/aocl_gemm/frame/lpgemm_post_ops.c @@ -0,0 +1,148 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. 
+ - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "lpgemm_post_ops.h" + +BLIS_INLINE void lpgemm_set_node_params + ( + lpgemm_post_op* post_op_node, + LPGEMM_POST_OP_CODE op_code, + void* op1, + void* op2, + void* op3, + void* scale_factor, + bool is_power_of_2 + ) +{ + post_op_node->op_code = op_code; + post_op_node->op_args1 = op1; + post_op_node->op_args2 = op2; + post_op_node->op_args3 = op3; + post_op_node->scale_factor = scale_factor; + post_op_node->is_power_of_2 = is_power_of_2; + post_op_node->next = NULL; +} + +void lpgemm_translate_to_post_ops_list + ( + aocl_post_op* post_op_unparsed, + lpgemm_post_op* post_op_list + ) +{ + if ( post_op_unparsed == NULL ) + { + lpgemm_set_node_params + ( + post_op_list, POST_OPS_DISABLE, + NULL, NULL, NULL, NULL, FALSE + ); + return; + } + + if ( ( post_op_unparsed->seq_length > AOCL_MAX_POST_OPS ) ) + { + lpgemm_set_node_params + ( + post_op_list, POST_OPS_DISABLE, + NULL, NULL, NULL, NULL, FALSE + ); + return; //Error, seq length exceeds max post ops permitted. + } + + for ( dim_t i = 0; i < post_op_unparsed->seq_length; ++i ) + { + // Dispatcher code + switch ( *( post_op_unparsed->seq_vector + i ) ) + { + case SUM: + lpgemm_set_node_params + ( + ( post_op_list + i ), POST_OPS_SUM, + post_op_unparsed->sum.buff, + post_op_unparsed->sum.zero_point, + NULL, + post_op_unparsed->sum.scale_factor, + post_op_unparsed->sum.is_power_of_2 + ); + break; + case ELTWISE: + { + LPGEMM_POST_OP_CODE tmp_code = POST_OPS_DISABLE; + // Eltwise algo dispatcher. + switch ( post_op_unparsed->eltwise.algo.algo_type ) + { + case LINEAR: + tmp_code = POST_OPS_LINEAR; + break; + case RELU: + tmp_code = POST_OPS_RELU; + break; + case GELU: + tmp_code = POST_OPS_GELU; + break; + case CLIP: + tmp_code = POST_OPS_CLIP; + break; + } + lpgemm_set_node_params + ( + ( post_op_list + i ), tmp_code, + NULL, + post_op_unparsed->eltwise.algo.alpha, + post_op_unparsed->eltwise.algo.beta, + post_op_unparsed->eltwise.scale_factor, + post_op_unparsed->eltwise.is_power_of_2 + ); + } + break; + case BIAS: + lpgemm_set_node_params + ( + ( post_op_list + i ), POST_OPS_BIAS, + post_op_unparsed->bias.bias, + NULL, NULL, NULL, FALSE + ); + break; + default: + break; + } + + // Simulating linked link using an array. 
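+ // (For example, seq_length = 2 with seq_vector = { BIAS, ELTWISE } and
+ // algo_type = RELU produces post_op_list[0].op_code = POST_OPS_BIAS with
+ // next = &post_op_list[1], and post_op_list[1].op_code = POST_OPS_RELU
+ // with next = NULL; the kernels then walk this chain via the jump
+ // macros in lpgemm_post_ops.h.)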
+ if ( i < ( post_op_unparsed->seq_length - 1 ) ) + { + ( post_op_list + i )->next = ( post_op_list + i + 1); + } + } +} diff --git a/addon/aocl_gemm/frame/lpgemm_post_ops.h b/addon/aocl_gemm/frame/lpgemm_post_ops.h new file mode 100644 index 0000000000..325cda5ff7 --- /dev/null +++ b/addon/aocl_gemm/frame/lpgemm_post_ops.h @@ -0,0 +1,88 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef LPGEMM_POST_OPS_H +#define LPGEMM_POST_OPS_H + +typedef enum +{ + POST_OPS_DISABLE = 0, + POST_OPS_BIAS = 1, + POST_OPS_RELU = 2, + POST_OPS_SUM = 3, + POST_OPS_LINEAR = 4, + POST_OPS_GELU = 5, + POST_OPS_CLIP = 6, +} LPGEMM_POST_OP_CODE; + +// Used as an internal structure. 
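+// One node is filled per requested post op; the nodes live in a fixed
+// array of AOCL_MAX_POST_OPS entries declared in the aocl_gemm_* entry
+// points and are chained through 'next' by lpgemm_translate_to_post_ops_list().
+// The meaning of op_args1/2/3 depends on op_code: the bias pointer for
+// BIAS, buff/zero_point for SUM, and alpha/beta for the eltwise ops.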
+typedef struct lpgemm_post_op_t +{ + LPGEMM_POST_OP_CODE op_code; + void* op_args1; + void* op_args2; // alpha, zero_point + void* op_args3; // beta + void* scale_factor; + bool is_power_of_2; + struct lpgemm_post_op_t* next; +} lpgemm_post_op; + +void lpgemm_translate_to_post_ops_list + ( + aocl_post_op* post_op_unparsed, + lpgemm_post_op* post_op_list + ); + +#define POST_OP_LABEL_LASTK_SAFE_JUMP \ + if ( ( is_last_k == TRUE ) && ( post_ops_list_temp != NULL ) ) \ + { \ + goto *post_ops_labels[post_ops_list_temp->op_code]; \ + } \ + else \ + { \ + goto *post_ops_labels[0]; \ + } + +#define POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR \ + post_ops_list_temp = post_ops_list_temp->next; \ + if ( post_ops_list_temp != NULL ) \ + { \ + goto *post_ops_labels[post_ops_list_temp->op_code]; \ + } \ + else \ + { \ + goto *post_ops_labels[0]; \ + } + +#endif //LPGEMM_POST_OPS_H diff --git a/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.c b/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.c index 5659c8286f..d35e04009f 100644 --- a/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.c +++ b/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.c @@ -373,7 +373,8 @@ void lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \ const dim_t rs_c, \ C_type alpha, \ C_type beta, \ - rntm_t* rntm_g \ + rntm_t* rntm_g, \ + lpgemm_post_op* post_op_list \ ) \ { \ dim_t n_threads; \ @@ -432,7 +433,8 @@ void lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \ alpha, \ beta, \ &rntm_l, \ - &thread \ + &thread, \ + post_op_list \ ); \ } \ if ( jc_ways > BLIS_LPGEMM_NUM_STATIC_COMMS ) \ @@ -464,7 +466,8 @@ void lpgemm_ ## LPGEMM_SFX ## _thread_decorator \ const dim_t rs_c, \ C_type alpha, \ C_type beta, \ - rntm_t* rntm_g \ + rntm_t* rntm_g, \ + lpgemm_post_op* post_op_list \ ) \ { \ dim_t n_threads = 1; \ @@ -502,7 +505,8 @@ void lpgemm_ ## LPGEMM_SFX ## _thread_decorator \ alpha, \ beta, \ rntm_g, \ - &thread \ + &thread, \ + post_op_list \ ); \ } \ diff --git a/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.h b/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.h index dd38a02ebd..51a5941481 100644 --- a/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.h +++ b/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.h @@ -36,6 +36,7 @@ #define LPGEMM_THREAD_DECOR_OPENMP_H #include "lpgemm_types.h" +#include "lpgemm_post_ops.h" #ifdef BLIS_ENABLE_OPENMP @@ -57,7 +58,8 @@ void lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \ const dim_t rs_c, \ C_type alpha, \ C_type beta, \ - rntm_t* rntm_g \ + rntm_t* rntm_g, \ + lpgemm_post_op* post_op_list \ ); \ GEN_LPGEMM_OPENMP_DECORATOR_FN(uint8_t,int8_t,int32_t,u8s8s32o32) @@ -83,7 +85,8 @@ void lpgemm_ ## LPGEMM_SFX ## _thread_decorator \ const dim_t rs_c, \ C_type alpha, \ C_type beta, \ - rntm_t* rntm_g \ + rntm_t* rntm_g, \ + lpgemm_post_op* post_op_list \ ); \ GEN_LPGEMM_DECORATOR_FN(uint8_t,int8_t,int32_t,u8s8s32o32) diff --git a/addon/aocl_gemm/frame/u8s8s32/lpgemm_u8s8s32.c b/addon/aocl_gemm/frame/u8s8s32/lpgemm_u8s8s32.c index 5ee25a92a1..50dc5ad97e 100644 --- a/addon/aocl_gemm/frame/u8s8s32/lpgemm_u8s8s32.c +++ b/addon/aocl_gemm/frame/u8s8s32/lpgemm_u8s8s32.c @@ -60,7 +60,8 @@ void lpgemm_rowvar_u8s8s32o32 int32_t alpha, int32_t beta, rntm_t* rntm, - lpgemm_thrinfo_t* thread + lpgemm_thrinfo_t* thread, + lpgemm_post_op* post_op_list ) { dim_t NC = lpgemm_get_block_size_NC_global_cntx( U8S8S32OS32 ); @@ -105,6 +106,9 @@ void lpgemm_rowvar_u8s8s32o32 // buffer needs to be updated. 
dim_t k_updated = make_multiple_of_n( k, 4 ); + // Is required to decide whether to apply post ops or not. + bool is_last_k = FALSE; + // Generate thrinfo objects for jc and ic loops from lpgemm_thrinfo_t. thrinfo_t thread_jc; thrinfo_t thread_ic; @@ -146,6 +150,8 @@ void lpgemm_rowvar_u8s8s32o32 // needs to be updated. dim_t kc0_updated = make_multiple_of_n( kc0, 4 ); + is_last_k = ( ( pc + KC ) >= k ) ? ( TRUE ) : ( FALSE ); + if ( mtag_b == PACK ) { // Pack B chunks are based on jc work id. @@ -300,7 +306,8 @@ void lpgemm_rowvar_u8s8s32o32 a_use, rs_a_use, cs_a_use, a_block_stride, ( b_use + ( jr * kc0_updated ) ), rs_b_use, cs_b_use, ( c_use_ic + jr ), rs_c, 1, - alpha, beta0 + alpha, beta0, + is_last_k, ic, ( jc + jr ), post_op_list ); } } diff --git a/addon/aocl_gemm/frame/u8s8s32/lpgemm_u8s8s32.h b/addon/aocl_gemm/frame/u8s8s32/lpgemm_u8s8s32.h index 2da9cc2de7..8f846abfdd 100644 --- a/addon/aocl_gemm/frame/u8s8s32/lpgemm_u8s8s32.h +++ b/addon/aocl_gemm/frame/u8s8s32/lpgemm_u8s8s32.h @@ -36,6 +36,7 @@ #define LPGEMM_U8S8S32_H #include "lpgemm_types.h" +#include "lpgemm_post_ops.h" // B should always be packed. void lpgemm_rowvar_u8s8s32o32 @@ -56,7 +57,8 @@ void lpgemm_rowvar_u8s8s32o32 int32_t alpha, int32_t beta, rntm_t* rntm, - lpgemm_thrinfo_t* thread + lpgemm_thrinfo_t* thread, + lpgemm_post_op* post_op_list ); #endif //LPGEMM_U8S8S32_H diff --git a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_6x64rowmajor.h b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_6x64rowmajor.h index 9a5b3644d1..8373ba1b72 100644 --- a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_6x64rowmajor.h +++ b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_6x64rowmajor.h @@ -35,6 +35,8 @@ #ifndef BLIS_GEMM_INT8_MNROW #define BLIS_GEMM_INT8_MNROW +#include "lpgemm_post_ops.h" + // 6x64 int8o32 kernel void lpgemm_rowvar_u8s8s32o32_6x64 ( @@ -52,7 +54,11 @@ void lpgemm_rowvar_u8s8s32o32_6x64 const dim_t rs_c, const dim_t cs_c, const int32_t alpha, - const int32_t beta + const int32_t beta, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op* post_ops_list ); #endif //BLIS_GEMM_INT8_MNROW diff --git a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_6x64rowmajor_amd512vnni.c b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_6x64rowmajor_amd512vnni.c index 679fc916c2..e611556171 100644 --- a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_6x64rowmajor_amd512vnni.c +++ b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_6x64rowmajor_amd512vnni.c @@ -56,9 +56,20 @@ void lpgemm_rowvar_u8s8s32o32_6x64 const dim_t rs_c, const dim_t cs_c, const int32_t alpha, - const int32_t beta + const int32_t beta, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op* post_ops_list ) { + static void* post_ops_labels[] = + { + &&POST_OPS_6x64_DISABLE, + &&POST_OPS_BIAS_6x64, + &&POST_OPS_RELU_6x64 + }; + dim_t MR = 6; dim_t NR = 64; @@ -99,11 +110,15 @@ void lpgemm_rowvar_u8s8s32o32_6x64 a, rs_a, cs_a, ps_a, b, ( ( rs_b / 4 ) * 3 ), cs_b, c, rs_c, - alpha, beta + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list ); b = b + ( 48 * k0_updated ); // k0x48 packed contiguosly. c = c + 48; + post_op_c_j += 48; } else if ( n0_32 == 1 ) { @@ -113,11 +128,15 @@ void lpgemm_rowvar_u8s8s32o32_6x64 a, rs_a, cs_a, ps_a, b, ( ( rs_b / 4 ) * 2 ), cs_b, c, rs_c, - alpha, beta + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list ); b = b + ( 32 * k0_updated ); // k0x32 packed contiguosly. 
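+ // (post_op_c_j advances in step with the c pointer across the
+ // 48/32/16/remainder panels so that column-indexed post ops such as the
+ // bias vector stay aligned with the columns being computed; post_op_c_i
+ // is advanced the same way for rows further below.)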
c = c + 32; + post_op_c_j += 32; } else if ( n0_16 == 1 ) { @@ -127,11 +146,15 @@ void lpgemm_rowvar_u8s8s32o32_6x64 a, rs_a, cs_a, ps_a, b, ( ( rs_b / 4 ) * 1 ), cs_b, c, rs_c, - alpha, beta + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list ); b = b + ( 16 * k0_updated ); // k0x16 packed contiguosly. c = c + 16; + post_op_c_j += 16; } if ( n0_rem > 0 ) @@ -142,15 +165,18 @@ void lpgemm_rowvar_u8s8s32o32_6x64 a, rs_a, cs_a, ps_a, b, ( ( rs_b / 4 ) * 1 ), cs_b, c, rs_c, - alpha, beta, n0_rem + alpha, beta, n0_rem, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list ); // No leftover fringe after this point. } - + return; } - + // B matrix storage. __m512i b0; __m512i b1; @@ -160,7 +186,7 @@ void lpgemm_rowvar_u8s8s32o32_6x64 // A matrix storage. __m512i a_int32_0; __m512i a_int32_1; - + for ( dim_t ir = 0; ir < m_full_pieces_loop_limit; ir += MR ) { // Registers to use for accumulating C. @@ -544,7 +570,180 @@ void lpgemm_rowvar_u8s8s32o32_6x64 selector1 = _mm512_mullo_epi32( selector2, selector1 ); c_int32_5p3 = _mm512_add_epi32( selector1, c_int32_5p3 ); } - + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_6x64: + { + selector1 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 2 * 16 ) ); + a_int32_1 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 3 * 16 ) ); + + // c[0,0-15] + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_add_epi32( selector2, c_int32_0p1 ); + + // c[0,32-47] + c_int32_0p2 = _mm512_add_epi32( a_int32_0, c_int32_0p2 ); + + // c[0,48-63] + c_int32_0p3 = _mm512_add_epi32( a_int32_1, c_int32_0p3 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[1, 16-31] + c_int32_1p1 = _mm512_add_epi32( selector2, c_int32_1p1 ); + + // c[1,32-47] + c_int32_1p2 = _mm512_add_epi32( a_int32_0, c_int32_1p2 ); + + // c[1,48-63] + c_int32_1p3 = _mm512_add_epi32( a_int32_1, c_int32_1p3 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[2, 16-31] + c_int32_2p1 = _mm512_add_epi32( selector2, c_int32_2p1 ); + + // c[2,32-47] + c_int32_2p2 = _mm512_add_epi32( a_int32_0, c_int32_2p2 ); + + // c[2,48-63] + c_int32_2p3 = _mm512_add_epi32( a_int32_1, c_int32_2p3 ); + + // c[3,0-15] + c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); + + // c[3, 16-31] + c_int32_3p1 = _mm512_add_epi32( selector2, c_int32_3p1 ); + + // c[3,32-47] + c_int32_3p2 = _mm512_add_epi32( a_int32_0, c_int32_3p2 ); + + // c[3,48-63] + c_int32_3p3 = _mm512_add_epi32( a_int32_1, c_int32_3p3 ); + + // c[4,0-15] + c_int32_4p0 = _mm512_add_epi32( selector1, c_int32_4p0 ); + + // c[4, 16-31] + c_int32_4p1 = _mm512_add_epi32( selector2, c_int32_4p1 ); + + // c[4,32-47] + c_int32_4p2 = _mm512_add_epi32( a_int32_0, c_int32_4p2 ); + + // c[4,48-63] + c_int32_4p3 = _mm512_add_epi32( a_int32_1, c_int32_4p3 ); + + // c[5,0-15] + c_int32_5p0 = _mm512_add_epi32( selector1, c_int32_5p0 ); + + // c[5, 16-31] + c_int32_5p1 = _mm512_add_epi32( selector2, c_int32_5p1 ); + + // c[5,32-47] + c_int32_5p2 = _mm512_add_epi32( a_int32_0, c_int32_5p2 ); + + // c[5,48-63] + c_int32_5p3 = _mm512_add_epi32( a_int32_1, c_int32_5p3 ); + + 
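+ // (The macro below advances post_ops_list_temp to its 'next' node and
+ // jumps to that node's label, e.g. to POST_OPS_RELU_6x64 when RELU
+ // follows the bias add; once the chain is exhausted it jumps to
+ // POST_OPS_6x64_DISABLE and the tile is stored.)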
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_6x64: + { + selector1 = _mm512_setzero_epi32(); + + // c[0,0-15] + c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_max_epi32( selector1, c_int32_0p1 ); + + // c[0,32-47] + c_int32_0p2 = _mm512_max_epi32( selector1, c_int32_0p2 ); + + // c[0,48-63] + c_int32_0p3 = _mm512_max_epi32( selector1, c_int32_0p3 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); + + // c[1,16-31] + c_int32_1p1 = _mm512_max_epi32( selector1, c_int32_1p1 ); + + // c[1,32-47] + c_int32_1p2 = _mm512_max_epi32( selector1, c_int32_1p2 ); + + // c[1,48-63] + c_int32_1p3 = _mm512_max_epi32( selector1, c_int32_1p3 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_max_epi32( selector1, c_int32_2p0 ); + + // c[2,16-31] + c_int32_2p1 = _mm512_max_epi32( selector1, c_int32_2p1 ); + + // c[2,32-47] + c_int32_2p2 = _mm512_max_epi32( selector1, c_int32_2p2 ); + + // c[2,48-63] + c_int32_2p3 = _mm512_max_epi32( selector1, c_int32_2p3 ); + + // c[3,0-15] + c_int32_3p0 = _mm512_max_epi32( selector1, c_int32_3p0 ); + + // c[3,16-31] + c_int32_3p1 = _mm512_max_epi32( selector1, c_int32_3p1 ); + + // c[3,32-47] + c_int32_3p2 = _mm512_max_epi32( selector1, c_int32_3p2 ); + + // c[3,48-63] + c_int32_3p3 = _mm512_max_epi32( selector1, c_int32_3p3 ); + + // c[4,0-15] + c_int32_4p0 = _mm512_max_epi32( selector1, c_int32_4p0 ); + + // c[4,16-31] + c_int32_4p1 = _mm512_max_epi32( selector1, c_int32_4p1 ); + + // c[4,32-47] + c_int32_4p2 = _mm512_max_epi32( selector1, c_int32_4p2 ); + + // c[4,48-63] + c_int32_4p3 = _mm512_max_epi32( selector1, c_int32_4p3 ); + + // c[5,0-15] + c_int32_5p0 = _mm512_max_epi32( selector1, c_int32_5p0 ); + + // c[5,16-31] + c_int32_5p1 = _mm512_max_epi32( selector1, c_int32_5p1 ); + + // c[5,32-47] + c_int32_5p2 = _mm512_max_epi32( selector1, c_int32_5p2 ); + + // c[5,48-63] + c_int32_5p3 = _mm512_max_epi32( selector1, c_int32_5p3 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_6x64_DISABLE: + ; + // Store the results. 
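+ // (Net effect for one output element c[i][j] on the last k iteration,
+ // with a BIAS + RELU chain, is roughly:
+ //   t  = alpha * dot(a_row_i, b_col_j) + beta * c[i][j];
+ //   t += bias[j];
+ //   t  = max(t, 0);
+ //   c[i][j] = t;
+ // applied while the 6x64 tile is still in registers, which is the
+ // cache/register reuse called out in the commit message above.)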
// c[0,0-15] _mm512_storeu_epi32( c + ( rs_c * ( ir + 0 ) ) + ( 0*16 ), c_int32_0p0 ); @@ -619,6 +818,7 @@ void lpgemm_rowvar_u8s8s32o32_6x64 _mm512_storeu_epi32( c + ( rs_c * ( ir + 5 ) ) + ( 3*16 ), c_int32_5p3 ); a = a + ( MR * ps_a ); + post_op_c_i += MR; } if ( m_partial_pieces > 0 ) @@ -638,7 +838,10 @@ void lpgemm_rowvar_u8s8s32o32_6x64 a, rs_a, cs_a_use, b, rs_b, cs_b, ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list ); } else if ( m_partial_pieces == 4 ) @@ -650,7 +853,10 @@ void lpgemm_rowvar_u8s8s32o32_6x64 a, rs_a, cs_a_use, b, rs_b, cs_b, ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list ); } else if ( m_partial_pieces == 3 ) @@ -662,7 +868,10 @@ void lpgemm_rowvar_u8s8s32o32_6x64 a, rs_a, cs_a_use, b, rs_b, cs_b, ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list ); } else if ( m_partial_pieces == 2 ) @@ -674,7 +883,10 @@ void lpgemm_rowvar_u8s8s32o32_6x64 a, rs_a, cs_a_use, b, rs_b, cs_b, ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list ); } else if ( m_partial_pieces == 1 ) @@ -686,7 +898,10 @@ void lpgemm_rowvar_u8s8s32o32_6x64 a, rs_a, cs_a_use, b, rs_b, cs_b, ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list ); } } diff --git a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_m_fringe.h b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_m_fringe.h index b0acdbcd64..e4cc3f763e 100644 --- a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_m_fringe.h +++ b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_m_fringe.h @@ -35,6 +35,8 @@ #ifndef BLIS_GEMM_INT8_MFRINGE #define BLIS_GEMM_INT8_MFRINGE +#include "lpgemm_post_ops.h" + // 5x64 int8o32 kernel void lpgemm_rowvar_u8s8s32o32_5x64 ( @@ -48,7 +50,11 @@ void lpgemm_rowvar_u8s8s32o32_5x64 int32_t* c, const dim_t rs_c, const int32_t alpha, - const int32_t beta + const int32_t beta, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op* post_ops_list ); // 4x64 int8o32 kernel @@ -64,7 +70,11 @@ void lpgemm_rowvar_u8s8s32o32_4x64 int32_t* c, const dim_t rs_c, const int32_t alpha, - const int32_t beta + const int32_t beta, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op* post_ops_list ); // 3x64 int8o32 kernel @@ -80,7 +90,11 @@ void lpgemm_rowvar_u8s8s32o32_3x64 int32_t* c, const dim_t rs_c, const int32_t alpha, - const int32_t beta + const int32_t beta, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op* post_ops_list ); // 2x64 int8o32 kernel @@ -96,7 +110,11 @@ void lpgemm_rowvar_u8s8s32o32_2x64 int32_t* c, const dim_t rs_c, const int32_t alpha, - const int32_t beta + const int32_t beta, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op* post_ops_list ); // 1x64 int8o32 kernel @@ -112,7 +130,11 @@ void lpgemm_rowvar_u8s8s32o32_1x64 int32_t* c, const dim_t rs_c, const int32_t alpha, - const int32_t beta + const int32_t beta, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op* post_ops_list ); #endif //BLIS_GEMM_INT8_MFRINGE diff --git a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_m_fringe_amd512vnni.c b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_m_fringe_amd512vnni.c index e02c2cce89..6c8e79fa31 100644 --- 
a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_m_fringe_amd512vnni.c +++ b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_m_fringe_amd512vnni.c @@ -51,9 +51,19 @@ void lpgemm_rowvar_u8s8s32o32_5x64 int32_t* c, const dim_t rs_c, const int32_t alpha, - const int32_t beta + const int32_t beta, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op* post_ops_list ) { + static void* post_ops_labels[] = + { + &&POST_OPS_5x64_DISABLE, + &&POST_OPS_BIAS_5x64, + &&POST_OPS_RELU_5x64 + }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -384,7 +394,156 @@ void lpgemm_rowvar_u8s8s32o32_5x64 selector1 = _mm512_mullo_epi32( selector2, selector1 ); c_int32_4p3 = _mm512_add_epi32( selector1, c_int32_4p3 ); } - + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_5x64: + { + selector1 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j ); + selector2 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 2 * 16 ) ); + a_int32_1 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 3 * 16 ) ); + + // c[0,0-15] + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_add_epi32( selector2, c_int32_0p1 ); + + // c[0,32-47] + c_int32_0p2 = _mm512_add_epi32( a_int32_0, c_int32_0p2 ); + + // c[0,48-63] + c_int32_0p3 = _mm512_add_epi32( a_int32_1, c_int32_0p3 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[1, 16-31] + c_int32_1p1 = _mm512_add_epi32( selector2, c_int32_1p1 ); + + // c[1,32-47] + c_int32_1p2 = _mm512_add_epi32( a_int32_0, c_int32_1p2 ); + + // c[1,48-63] + c_int32_1p3 = _mm512_add_epi32( a_int32_1, c_int32_1p3 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[2, 16-31] + c_int32_2p1 = _mm512_add_epi32( selector2, c_int32_2p1 ); + + // c[2,32-47] + c_int32_2p2 = _mm512_add_epi32( a_int32_0, c_int32_2p2 ); + + // c[2,48-63] + c_int32_2p3 = _mm512_add_epi32( a_int32_1, c_int32_2p3 ); + + // c[3,0-15] + c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); + + // c[3, 16-31] + c_int32_3p1 = _mm512_add_epi32( selector2, c_int32_3p1 ); + + // c[3,32-47] + c_int32_3p2 = _mm512_add_epi32( a_int32_0, c_int32_3p2 ); + + // c[3,48-63] + c_int32_3p3 = _mm512_add_epi32( a_int32_1, c_int32_3p3 ); + + // c[4,0-15] + c_int32_4p0 = _mm512_add_epi32( selector1, c_int32_4p0 ); + + // c[4, 16-31] + c_int32_4p1 = _mm512_add_epi32( selector2, c_int32_4p1 ); + + // c[4,32-47] + c_int32_4p2 = _mm512_add_epi32( a_int32_0, c_int32_4p2 ); + + // c[4,48-63] + c_int32_4p3 = _mm512_add_epi32( a_int32_1, c_int32_4p3 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_5x64: + { + selector1 = _mm512_setzero_epi32(); + + // c[0,0-15] + c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_max_epi32( selector1, c_int32_0p1 ); + + // c[0,32-47] + c_int32_0p2 = _mm512_max_epi32( selector1, c_int32_0p2 ); + + // c[0,48-63] + c_int32_0p3 = _mm512_max_epi32( selector1, c_int32_0p3 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); + + // c[1,16-31] + c_int32_1p1 = _mm512_max_epi32( selector1, c_int32_1p1 ); + + // c[1,32-47] + c_int32_1p2 = _mm512_max_epi32( selector1, c_int32_1p2 ); + + // c[1,48-63] + c_int32_1p3 = _mm512_max_epi32( 
selector1, c_int32_1p3 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_max_epi32( selector1, c_int32_2p0 ); + + // c[2,16-31] + c_int32_2p1 = _mm512_max_epi32( selector1, c_int32_2p1 ); + + // c[2,32-47] + c_int32_2p2 = _mm512_max_epi32( selector1, c_int32_2p2 ); + + // c[2,48-63] + c_int32_2p3 = _mm512_max_epi32( selector1, c_int32_2p3 ); + + // c[3,0-15] + c_int32_3p0 = _mm512_max_epi32( selector1, c_int32_3p0 ); + + // c[3,16-31] + c_int32_3p1 = _mm512_max_epi32( selector1, c_int32_3p1 ); + + // c[3,32-47] + c_int32_3p2 = _mm512_max_epi32( selector1, c_int32_3p2 ); + + // c[3,48-63] + c_int32_3p3 = _mm512_max_epi32( selector1, c_int32_3p3 ); + + // c[4,0-15] + c_int32_4p0 = _mm512_max_epi32( selector1, c_int32_4p0 ); + + // c[4,16-31] + c_int32_4p1 = _mm512_max_epi32( selector1, c_int32_4p1 ); + + // c[4,32-47] + c_int32_4p2 = _mm512_max_epi32( selector1, c_int32_4p2 ); + + // c[4,48-63] + c_int32_4p3 = _mm512_max_epi32( selector1, c_int32_4p3 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_5x64_DISABLE: + ; + // Store the results. // c[0,0-15] _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); @@ -460,9 +619,19 @@ void lpgemm_rowvar_u8s8s32o32_4x64 int32_t* c, const dim_t rs_c, const int32_t alpha, - const int32_t beta + const int32_t beta, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op* post_ops_list ) { + static void* post_ops_labels[] = + { + &&POST_OPS_4x64_DISABLE, + &&POST_OPS_BIAS_4x64, + &&POST_OPS_RELU_4x64 + }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -735,7 +904,132 @@ void lpgemm_rowvar_u8s8s32o32_4x64 selector1 = _mm512_mullo_epi32( selector2, selector1 ); c_int32_3p3 = _mm512_add_epi32( selector1, c_int32_3p3 ); } - + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_4x64: + { + selector1 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j ); + selector2 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 2 * 16 ) ); + a_int32_1 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 3 * 16 ) ); + + // c[0,0-15] + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_add_epi32( selector2, c_int32_0p1 ); + + // c[0,32-47] + c_int32_0p2 = _mm512_add_epi32( a_int32_0, c_int32_0p2 ); + + // c[0,48-63] + c_int32_0p3 = _mm512_add_epi32( a_int32_1, c_int32_0p3 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[1, 16-31] + c_int32_1p1 = _mm512_add_epi32( selector2, c_int32_1p1 ); + + // c[1,32-47] + c_int32_1p2 = _mm512_add_epi32( a_int32_0, c_int32_1p2 ); + + // c[1,48-63] + c_int32_1p3 = _mm512_add_epi32( a_int32_1, c_int32_1p3 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[2, 16-31] + c_int32_2p1 = _mm512_add_epi32( selector2, c_int32_2p1 ); + + // c[2,32-47] + c_int32_2p2 = _mm512_add_epi32( a_int32_0, c_int32_2p2 ); + + // c[2,48-63] + c_int32_2p3 = _mm512_add_epi32( a_int32_1, c_int32_2p3 ); + + // c[3,0-15] + c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); + + // c[3, 16-31] + c_int32_3p1 = _mm512_add_epi32( selector2, c_int32_3p1 ); + + // c[3,32-47] + c_int32_3p2 = _mm512_add_epi32( a_int32_0, c_int32_3p2 ); + + // c[3,48-63] + c_int32_3p3 = _mm512_add_epi32( a_int32_1, c_int32_3p3 ); + + 
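/* The bias and ReLU epilogues added in these kernels are dispatched with a
 * GCC/Clang computed goto: each kernel declares a static label table
 * ( &&POST_OPS_*_DISABLE, &&POST_OPS_BIAS_*, &&POST_OPS_RELU_* ) and the
 * POST_OP_LABEL_LASTK_SAFE_JUMP / POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR
 * macros (presumably supplied by lpgemm_post_ops.h, which is not shown in this
 * hunk) jump through that table: when is_last_k is false, or the post-op list
 * is exhausted, control lands on the *_DISABLE label and the accumulators are
 * stored untouched; otherwise each list node selects its labelled block. A
 * minimal scalar model of that control flow, using hypothetical names
 * (op_node_t, apply_post_ops_model) and assuming GNU C computed goto, is:
 *
 *     typedef enum { OP_DISABLE = 0, OP_BIAS, OP_RELU } op_kind_t;
 *     typedef struct op_node { op_kind_t kind; struct op_node* next; } op_node_t;
 *
 *     static void apply_post_ops_model
 *         ( int32_t* acc, dim_t n, const int32_t* bias,
 *           op_node_t* list, bool is_last_k )
 *     {
 *         static void* labels[] = { &&DISABLE, &&BIAS, &&RELU };
 *         // Post-ops run only on the final k iteration of the blocked loop.
 *         if ( !is_last_k || list == NULL ) goto *labels[ OP_DISABLE ];
 *         goto *labels[ list->kind ];
 *     BIAS:
 *         for ( dim_t j = 0; j < n; ++j ) acc[ j ] += bias[ j ];
 *         list = list->next;
 *         goto *labels[ list ? list->kind : OP_DISABLE ];
 *     RELU:
 *         for ( dim_t j = 0; j < n; ++j ) if ( acc[ j ] < 0 ) acc[ j ] = 0;
 *         list = list->next;
 *         goto *labels[ list ? list->kind : OP_DISABLE ];
 *     DISABLE:
 *         ; // fall through to the store of the accumulators
 *     }
 *
 * This is only a sketch of the observed pattern, not the library's macro
 * definitions; the AVX-512 blocks perform the same steps 16 int32 lanes at a
 * time with _mm512_add_epi32 and _mm512_max_epi32.
 */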
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_4x64: + { + selector1 = _mm512_setzero_epi32(); + + // c[0,0-15] + c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_max_epi32( selector1, c_int32_0p1 ); + + // c[0,32-47] + c_int32_0p2 = _mm512_max_epi32( selector1, c_int32_0p2 ); + + // c[0,48-63] + c_int32_0p3 = _mm512_max_epi32( selector1, c_int32_0p3 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); + + // c[1,16-31] + c_int32_1p1 = _mm512_max_epi32( selector1, c_int32_1p1 ); + + // c[1,32-47] + c_int32_1p2 = _mm512_max_epi32( selector1, c_int32_1p2 ); + + // c[1,48-63] + c_int32_1p3 = _mm512_max_epi32( selector1, c_int32_1p3 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_max_epi32( selector1, c_int32_2p0 ); + + // c[2,16-31] + c_int32_2p1 = _mm512_max_epi32( selector1, c_int32_2p1 ); + + // c[2,32-47] + c_int32_2p2 = _mm512_max_epi32( selector1, c_int32_2p2 ); + + // c[2,48-63] + c_int32_2p3 = _mm512_max_epi32( selector1, c_int32_2p3 ); + + // c[3,0-15] + c_int32_3p0 = _mm512_max_epi32( selector1, c_int32_3p0 ); + + // c[3,16-31] + c_int32_3p1 = _mm512_max_epi32( selector1, c_int32_3p1 ); + + // c[3,32-47] + c_int32_3p2 = _mm512_max_epi32( selector1, c_int32_3p2 ); + + // c[3,48-63] + c_int32_3p3 = _mm512_max_epi32( selector1, c_int32_3p3 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_4x64_DISABLE: + ; + // Store the results. // c[0,0-15] _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); @@ -799,14 +1093,34 @@ void lpgemm_rowvar_u8s8s32o32_3x64 int32_t* c, const dim_t rs_c, const int32_t alpha, - const int32_t beta + const int32_t beta, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op* post_ops_list ) { + static void* post_ops_labels[] = + { + &&POST_OPS_3x64_DISABLE, + &&POST_OPS_BIAS_3x64, + &&POST_OPS_RELU_3x64 + }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; uint32_t a_kfringe_buf = 0; + // B matrix storage. + __m512i b0; + __m512i b1; + __m512i b2; + __m512i b3; + + // A matrix storage. + __m512i a_int32_0; + __m512i a_int32_1; + // Registers to use for accumulating C. __m512i c_int32_0p0 = _mm512_setzero_epi32(); __m512i c_int32_0p1 = _mm512_setzero_epi32(); @@ -825,21 +1139,21 @@ void lpgemm_rowvar_u8s8s32o32_3x64 for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) { - __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); // Broadcast a[0,kr:kr+4]. - __m512i a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); - __m512i b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); - __m512i b2 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 2 ) ); - __m512i b3 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 3 ) ); + b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + b3 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 3 ) ); // Perform column direction mat-mul with k = 4. // c[0,0-63] = a[0,kr:kr+4]*b[kr:kr+4,0-63] c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); // Broadcast a[1,kr:kr+4]. 
- __m512i a_int32_1 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + a_int32_1 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); @@ -866,7 +1180,7 @@ void lpgemm_rowvar_u8s8s32o32_3x64 // Handle k remainder. if ( k_partial_pieces > 0 ) { - __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); // Broadcast a[0,kr:kr+4]. memcpy @@ -875,11 +1189,11 @@ void lpgemm_rowvar_u8s8s32o32_3x64 ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - __m512i a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - __m512i b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); - __m512i b2 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); - __m512i b3 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 3 ) ); + b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + b3 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 3 ) ); // Perform column direction mat-mul with k = 4. // c[0,0-63] = a[0,kr:kr+4]*b[kr:kr+4,0-63] @@ -892,7 +1206,7 @@ void lpgemm_rowvar_u8s8s32o32_3x64 ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - __m512i a_int32_1 = _mm512_set1_epi32( a_kfringe_buf ); + a_int32_1 = _mm512_set1_epi32( a_kfringe_buf ); c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); @@ -1006,7 +1320,108 @@ void lpgemm_rowvar_u8s8s32o32_3x64 selector1 = _mm512_mullo_epi32( selector2, selector1 ); c_int32_2p3 = _mm512_add_epi32( selector1, c_int32_2p3 ); } - + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_3x64: + { + selector1 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j ); + selector2 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 2 * 16 ) ); + a_int32_1 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 3 * 16 ) ); + + // c[0,0-15] + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_add_epi32( selector2, c_int32_0p1 ); + + // c[0,32-47] + c_int32_0p2 = _mm512_add_epi32( a_int32_0, c_int32_0p2 ); + + // c[0,48-63] + c_int32_0p3 = _mm512_add_epi32( a_int32_1, c_int32_0p3 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[1, 16-31] + c_int32_1p1 = _mm512_add_epi32( selector2, c_int32_1p1 ); + + // c[1,32-47] + c_int32_1p2 = _mm512_add_epi32( a_int32_0, c_int32_1p2 ); + + // c[1,48-63] + c_int32_1p3 = _mm512_add_epi32( a_int32_1, c_int32_1p3 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[2, 16-31] + c_int32_2p1 = _mm512_add_epi32( selector2, c_int32_2p1 ); + + // c[2,32-47] + c_int32_2p2 = _mm512_add_epi32( a_int32_0, c_int32_2p2 ); + + // c[2,48-63] + c_int32_2p3 = _mm512_add_epi32( a_int32_1, c_int32_2p3 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_3x64: + { + 
selector1 = _mm512_setzero_epi32(); + + // c[0,0-15] + c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_max_epi32( selector1, c_int32_0p1 ); + + // c[0,32-47] + c_int32_0p2 = _mm512_max_epi32( selector1, c_int32_0p2 ); + + // c[0,48-63] + c_int32_0p3 = _mm512_max_epi32( selector1, c_int32_0p3 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); + + // c[1,16-31] + c_int32_1p1 = _mm512_max_epi32( selector1, c_int32_1p1 ); + + // c[1,32-47] + c_int32_1p2 = _mm512_max_epi32( selector1, c_int32_1p2 ); + + // c[1,48-63] + c_int32_1p3 = _mm512_max_epi32( selector1, c_int32_1p3 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_max_epi32( selector1, c_int32_2p0 ); + + // c[2,16-31] + c_int32_2p1 = _mm512_max_epi32( selector1, c_int32_2p1 ); + + // c[2,32-47] + c_int32_2p2 = _mm512_max_epi32( selector1, c_int32_2p2 ); + + // c[2,48-63] + c_int32_2p3 = _mm512_max_epi32( selector1, c_int32_2p3 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_3x64_DISABLE: + ; + // Store the results. // c[0,0-15] _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); @@ -1058,14 +1473,34 @@ void lpgemm_rowvar_u8s8s32o32_2x64 int32_t* c, const dim_t rs_c, const int32_t alpha, - const int32_t beta + const int32_t beta, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op* post_ops_list ) { + static void* post_ops_labels[] = + { + &&POST_OPS_2x64_DISABLE, + &&POST_OPS_BIAS_2x64, + &&POST_OPS_RELU_2x64 + }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; uint32_t a_kfringe_buf = 0; + // B matrix storage. + __m512i b0; + __m512i b1; + __m512i b2; + __m512i b3; + + // A matrix storage. + __m512i a_int32_0; + __m512i a_int32_1; + // Registers to use for accumulating C. __m512i c_int32_0p0 = _mm512_setzero_epi32(); __m512i c_int32_0p1 = _mm512_setzero_epi32(); @@ -1079,21 +1514,21 @@ void lpgemm_rowvar_u8s8s32o32_2x64 for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) { - __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); // Broadcast a[0,kr:kr+4]. - __m512i a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); - __m512i b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); - __m512i b2 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 2 ) ); - __m512i b3 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 3 ) ); + b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + b3 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 3 ) ); // Perform column direction mat-mul with k = 4. // c[0,0-63] = a[0,kr:kr+4]*b[kr:kr+4,0-63] c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 ); // Broadcast a[1,kr:kr+4]. - __m512i a_int32_1 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + a_int32_1 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); @@ -1109,7 +1544,7 @@ void lpgemm_rowvar_u8s8s32o32_2x64 // Handle k remainder. if ( k_partial_pieces > 0 ) { - __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); // Broadcast a[0,kr:kr+4]. 
memcpy @@ -1118,11 +1553,11 @@ void lpgemm_rowvar_u8s8s32o32_2x64 ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - __m512i a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - __m512i b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); - __m512i b2 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); - __m512i b3 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 3 ) ); + b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + b3 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 3 ) ); // Perform column direction mat-mul with k = 4. // c[0,0-63] = a[0,kr:kr+4]*b[kr:kr+4,0-63] @@ -1135,7 +1570,7 @@ void lpgemm_rowvar_u8s8s32o32_2x64 ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - __m512i a_int32_1 = _mm512_set1_epi32( a_kfringe_buf ); + a_int32_1 = _mm512_set1_epi32( a_kfringe_buf ); c_int32_0p1 = _mm512_dpbusd_epi32( c_int32_0p1, a_int32_0, b1 ); c_int32_0p2 = _mm512_dpbusd_epi32( c_int32_0p2, a_int32_0, b2 ); @@ -1207,7 +1642,84 @@ void lpgemm_rowvar_u8s8s32o32_2x64 selector1 = _mm512_mullo_epi32( selector2, selector1 ); c_int32_1p3 = _mm512_add_epi32( selector1, c_int32_1p3 ); } - + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_2x64: + { + selector1 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j ); + selector2 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 2 * 16 ) ); + a_int32_1 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 3 * 16 ) ); + + // c[0,0-15] + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_add_epi32( selector2, c_int32_0p1 ); + + // c[0,32-47] + c_int32_0p2 = _mm512_add_epi32( a_int32_0, c_int32_0p2 ); + + // c[0,48-63] + c_int32_0p3 = _mm512_add_epi32( a_int32_1, c_int32_0p3 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[1, 16-31] + c_int32_1p1 = _mm512_add_epi32( selector2, c_int32_1p1 ); + + // c[1,32-47] + c_int32_1p2 = _mm512_add_epi32( a_int32_0, c_int32_1p2 ); + + // c[1,48-63] + c_int32_1p3 = _mm512_add_epi32( a_int32_1, c_int32_1p3 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_2x64: + { + selector1 = _mm512_setzero_epi32(); + + // c[0,0-15] + c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_max_epi32( selector1, c_int32_0p1 ); + + // c[0,32-47] + c_int32_0p2 = _mm512_max_epi32( selector1, c_int32_0p2 ); + + // c[0,48-63] + c_int32_0p3 = _mm512_max_epi32( selector1, c_int32_0p3 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); + + // c[1,16-31] + c_int32_1p1 = _mm512_max_epi32( selector1, c_int32_1p1 ); + + // c[1,32-47] + c_int32_1p2 = _mm512_max_epi32( selector1, c_int32_1p2 ); + + // c[1,48-63] + c_int32_1p3 = _mm512_max_epi32( selector1, c_int32_1p3 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_2x64_DISABLE: + ; + // Store the results. 
// c[0,0-15] _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); @@ -1247,14 +1759,34 @@ void lpgemm_rowvar_u8s8s32o32_1x64 int32_t* c, const dim_t rs_c, const int32_t alpha, - const int32_t beta + const int32_t beta, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op* post_ops_list ) { + static void* post_ops_labels[] = + { + &&POST_OPS_1x64_DISABLE, + &&POST_OPS_BIAS_1x64, + &&POST_OPS_RELU_1x64 + }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; uint32_t a_kfringe_buf = 0; + // B matrix storage. + __m512i b0; + __m512i b1; + __m512i b2; + __m512i b3; + + // A matrix storage. + __m512i a_int32_0; + __m512i a_int32_1; + // Registers to use for accumulating C. __m512i c_int32_0p0 = _mm512_setzero_epi32(); __m512i c_int32_0p1 = _mm512_setzero_epi32(); @@ -1263,14 +1795,14 @@ void lpgemm_rowvar_u8s8s32o32_1x64 for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) { - __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); // Broadcast a[0,kr] - __m512i a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); - __m512i b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); - __m512i b2 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 2 ) ); - __m512i b3 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 3 ) ); + b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + b3 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 3 ) ); // Perform column direction mat-mul with k = 4. // c[0,0-63] = a[0,kr:kr+4]*b[kr:kr+4,0-63] @@ -1282,7 +1814,7 @@ void lpgemm_rowvar_u8s8s32o32_1x64 // Handle k remainder. if ( k_partial_pieces > 0 ) { - __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); // Broadcast a[0,kr:kr+4]. memcpy @@ -1291,11 +1823,11 @@ void lpgemm_rowvar_u8s8s32o32_1x64 ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - __m512i a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); - __m512i b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); - __m512i b2 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); - __m512i b3 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 3 ) ); + b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + b3 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 3 ) ); // Perform column direction mat-mul with k = 4. 
// c[0,0-63] = a[0,kr:kr+4]*b[kr:kr+4,0-63] @@ -1339,6 +1871,59 @@ void lpgemm_rowvar_u8s8s32o32_1x64 c_int32_0p3 = _mm512_add_epi32( selector1, c_int32_0p3 ); } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_1x64: + { + selector1 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j ); + selector2 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 2 * 16 ) ); + a_int32_1 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 3 * 16 ) ); + + // c[0,0-15] + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_add_epi32( selector2, c_int32_0p1 ); + + // c[0,32-47] + c_int32_0p2 = _mm512_add_epi32( a_int32_0, c_int32_0p2 ); + + // c[0,48-63] + c_int32_0p3 = _mm512_add_epi32( a_int32_1, c_int32_0p3 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_1x64: + { + selector1 = _mm512_setzero_epi32(); + + // c[0,0-15] + c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_max_epi32( selector1, c_int32_0p1 ); + + // c[0,32-47] + c_int32_0p2 = _mm512_max_epi32( selector1, c_int32_0p2 ); + + // c[0,48-63] + c_int32_0p3 = _mm512_max_epi32( selector1, c_int32_0p3 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_1x64_DISABLE: + ; + // Store the accumulated results. // c[0,0-15] _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); diff --git a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_mn_fringe.h b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_mn_fringe.h index 008f254b11..e49f543d98 100644 --- a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_mn_fringe.h +++ b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_mn_fringe.h @@ -35,6 +35,8 @@ #ifndef BLIS_GEMM_INT8_MNFRINGE #define BLIS_GEMM_INT8_MNFRINGE +#include "lpgemm_post_ops.h" + // 5xlt16 int8o32 fringe kernel void lpgemm_rowvar_u8s8s32o32_5xlt16 ( @@ -49,7 +51,11 @@ void lpgemm_rowvar_u8s8s32o32_5xlt16 const dim_t rs_c, const int32_t alpha, const int32_t beta, - const dim_t n0_rem + const dim_t n0_rem, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op* post_ops_list ); // 4xlt16 int8o32 fringe kernel @@ -66,7 +72,11 @@ void lpgemm_rowvar_u8s8s32o32_4xlt16 const dim_t rs_c, const int32_t alpha, const int32_t beta, - const dim_t n0_rem + const dim_t n0_rem, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op* post_ops_list ); // 3xlt16 int8o32 fringe kernel @@ -83,7 +93,11 @@ void lpgemm_rowvar_u8s8s32o32_3xlt16 const dim_t rs_c, const int32_t alpha, const int32_t beta, - const dim_t n0_rem + const dim_t n0_rem, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op* post_ops_list ); // 2xlt16 int8o32 fringe kernel @@ -100,7 +114,11 @@ void lpgemm_rowvar_u8s8s32o32_2xlt16 const dim_t rs_c, const int32_t alpha, const int32_t beta, - const dim_t n0_rem + const dim_t n0_rem, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op* post_ops_list ); // 1xlt16 int8o32 fringe kernel @@ -117,7 +135,11 @@ void lpgemm_rowvar_u8s8s32o32_1xlt16 const dim_t rs_c, const int32_t alpha, const int32_t beta, - const dim_t n0_rem + const dim_t n0_rem, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op* post_ops_list ); // 5x16 int8o32 kernel @@ -133,7 +155,11 @@ void 
lpgemm_rowvar_u8s8s32o32_5x16 int32_t* c, const dim_t rs_c, const int32_t alpha, - const int32_t beta + const int32_t beta, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op* post_ops_list ); // 4x16 int8o32 kernel @@ -149,7 +175,11 @@ void lpgemm_rowvar_u8s8s32o32_4x16 int32_t* c, const dim_t rs_c, const int32_t alpha, - const int32_t beta + const int32_t beta, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op* post_ops_list ); // 3x16 int8o32 kernel @@ -165,7 +195,11 @@ void lpgemm_rowvar_u8s8s32o32_3x16 int32_t* c, const dim_t rs_c, const int32_t alpha, - const int32_t beta + const int32_t beta, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op* post_ops_list ); // 2x16 int8o32 kernel @@ -181,7 +215,11 @@ void lpgemm_rowvar_u8s8s32o32_2x16 int32_t* c, const dim_t rs_c, const int32_t alpha, - const int32_t beta + const int32_t beta, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op* post_ops_list ); // 1x16 int8o32 kernel @@ -197,7 +235,11 @@ void lpgemm_rowvar_u8s8s32o32_1x16 int32_t* c, const dim_t rs_c, const int32_t alpha, - const int32_t beta + const int32_t beta, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op* post_ops_list ); // 5x32 int8o32 kernel @@ -213,7 +255,11 @@ void lpgemm_rowvar_u8s8s32o32_5x32 int32_t* c, const dim_t rs_c, const int32_t alpha, - const int32_t beta + const int32_t beta, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op* post_ops_list ); // 4x32 int8o32 kernel @@ -229,7 +275,11 @@ void lpgemm_rowvar_u8s8s32o32_4x32 int32_t* c, const dim_t rs_c, const int32_t alpha, - const int32_t beta + const int32_t beta, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op* post_ops_list ); // 3x32 int8o32 kernel @@ -245,7 +295,11 @@ void lpgemm_rowvar_u8s8s32o32_3x32 int32_t* c, const dim_t rs_c, const int32_t alpha, - const int32_t beta + const int32_t beta, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op* post_ops_list ); // 2x32 int8o32 kernel @@ -261,7 +315,11 @@ void lpgemm_rowvar_u8s8s32o32_2x32 int32_t* c, const dim_t rs_c, const int32_t alpha, - const int32_t beta + const int32_t beta, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op* post_ops_list ); // 1x32 int8o32 kernel @@ -277,7 +335,11 @@ void lpgemm_rowvar_u8s8s32o32_1x32 int32_t* c, const dim_t rs_c, const int32_t alpha, - const int32_t beta + const int32_t beta, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op* post_ops_list ); // 5x48 int8o32 kernel @@ -293,7 +355,11 @@ void lpgemm_rowvar_u8s8s32o32_5x48 int32_t* c, const dim_t rs_c, const int32_t alpha, - const int32_t beta + const int32_t beta, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op* post_ops_list ); // 4x48 int8o32 kernel @@ -309,7 +375,11 @@ void lpgemm_rowvar_u8s8s32o32_4x48 int32_t* c, const dim_t rs_c, const int32_t alpha, - const int32_t beta + const int32_t beta, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op* post_ops_list ); // 3x48 int8o32 kernel @@ -325,7 +395,11 @@ void lpgemm_rowvar_u8s8s32o32_3x48 int32_t* c, const dim_t rs_c, const int32_t alpha, - const int32_t beta + const int32_t beta, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op* post_ops_list ); // 2x48 int8o32 kernel @@ -341,7 +415,11 @@ void lpgemm_rowvar_u8s8s32o32_2x48 int32_t* c, const dim_t rs_c, const 
int32_t alpha, - const int32_t beta + const int32_t beta, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op* post_ops_list ); // 1x48 int8o32 kernel @@ -357,7 +435,11 @@ void lpgemm_rowvar_u8s8s32o32_1x48 int32_t* c, const dim_t rs_c, const int32_t alpha, - const int32_t beta + const int32_t beta, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op* post_ops_list ); #endif //BLIS_GEMM_INT8_MNFRINGE diff --git a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_mn_fringe_amd512vnni.c b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_mn_fringe_amd512vnni.c index 427205292b..07e61f4ce3 100644 --- a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_mn_fringe_amd512vnni.c +++ b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_mn_fringe_amd512vnni.c @@ -52,9 +52,19 @@ void lpgemm_rowvar_u8s8s32o32_5xlt16 const dim_t rs_c, const int32_t alpha, const int32_t beta, - const dim_t n0_rem + const dim_t n0_rem, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op* post_ops_list ) { + static void* post_ops_labels[] = + { + &&POST_OPS_5xLT16_DISABLE, + &&POST_OPS_BIAS_5xLT16, + &&POST_OPS_RELU_5xLT16 + }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -213,6 +223,56 @@ void lpgemm_rowvar_u8s8s32o32_5xlt16 selector1 = _mm512_mullo_epi32( selector2, selector1 ); c_int32_4p0 = _mm512_add_epi32( selector1, c_int32_4p0 ); } + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_5xLT16: + { + selector1 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j ); + + // c[0,0-15] + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[3,0-15] + c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); + + // c[4,0-15] + c_int32_4p0 = _mm512_add_epi32( selector1, c_int32_4p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_5xLT16: + { + selector1 = _mm512_setzero_epi32(); + + // c[0,0-15] + c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_max_epi32( selector1, c_int32_2p0 ); + + // c[3,0-15] + c_int32_3p0 = _mm512_max_epi32( selector1, c_int32_3p0 ); + + // c[4,0-15] + c_int32_4p0 = _mm512_max_epi32( selector1, c_int32_4p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_5xLT16_DISABLE: + ; // Store the results. 
// c[0,0-15] @@ -262,9 +322,19 @@ void lpgemm_rowvar_u8s8s32o32_4xlt16 const dim_t rs_c, const int32_t alpha, const int32_t beta, - const dim_t n0_rem + const dim_t n0_rem, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op* post_ops_list ) { + static void* post_ops_labels[] = + { + &&POST_OPS_4xLT16_DISABLE, + &&POST_OPS_BIAS_4xLT16, + &&POST_OPS_RELU_4xLT16 + }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -398,6 +468,50 @@ void lpgemm_rowvar_u8s8s32o32_4xlt16 selector1 = _mm512_mullo_epi32( selector2, selector1 ); c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); } + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_4xLT16: + { + selector1 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j ); + + // c[0,0-15] + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[3,0-15] + c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_4xLT16: + { + selector1 = _mm512_setzero_epi32(); + + // c[0,0-15] + c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_max_epi32( selector1, c_int32_2p0 ); + + // c[3,0-15] + c_int32_3p0 = _mm512_max_epi32( selector1, c_int32_3p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_4xLT16_DISABLE: + ; // Store the results. // c[0,0-15] @@ -441,9 +555,19 @@ void lpgemm_rowvar_u8s8s32o32_3xlt16 const dim_t rs_c, const int32_t alpha, const int32_t beta, - const dim_t n0_rem + const dim_t n0_rem, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op* post_ops_list ) { + static void* post_ops_labels[] = + { + &&POST_OPS_3xLT16_DISABLE, + &&POST_OPS_BIAS_3xLT16, + &&POST_OPS_RELU_3xLT16 + }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -550,6 +674,44 @@ void lpgemm_rowvar_u8s8s32o32_3xlt16 selector1 = _mm512_mullo_epi32( selector2, selector1 ); c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); } + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_3xLT16: + { + selector1 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j ); + + // c[0,0-15] + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_3xLT16: + { + selector1 = _mm512_setzero_epi32(); + + // c[0,0-15] + c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_max_epi32( selector1, c_int32_2p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_3xLT16_DISABLE: + ; // Store the results. 
// c[0,0-15] @@ -587,9 +749,19 @@ void lpgemm_rowvar_u8s8s32o32_2xlt16 const dim_t rs_c, const int32_t alpha, const int32_t beta, - const dim_t n0_rem + const dim_t n0_rem, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op* post_ops_list ) { + static void* post_ops_labels[] = + { + &&POST_OPS_2xLT16_DISABLE, + &&POST_OPS_BIAS_2xLT16, + &&POST_OPS_RELU_2xLT16 + }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -670,6 +842,38 @@ void lpgemm_rowvar_u8s8s32o32_2xlt16 selector1 = _mm512_mullo_epi32( selector2, selector1 ); c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); } + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_2xLT16: + { + selector1 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j ); + + // c[0,0-15] + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_2xLT16: + { + selector1 = _mm512_setzero_epi32(); + + // c[0,0-15] + c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_2xLT16_DISABLE: + ; // Store the results. // c[0,0-15] @@ -701,9 +905,19 @@ void lpgemm_rowvar_u8s8s32o32_1xlt16 const dim_t rs_c, const int32_t alpha, const int32_t beta, - const dim_t n0_rem + const dim_t n0_rem, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op* post_ops_list ) { + static void* post_ops_labels[] = + { + &&POST_OPS_1xLT16_DISABLE, + &&POST_OPS_BIAS_1xLT16, + &&POST_OPS_RELU_1xLT16 + }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -758,6 +972,32 @@ void lpgemm_rowvar_u8s8s32o32_1xlt16 selector1 = _mm512_mullo_epi32( selector2, selector1 ); c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); } + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_1xLT16: + { + selector1 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j ); + + // c[0,0-15] + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_1xLT16: + { + selector1 = _mm512_setzero_epi32(); + + // c[0,0-15] + c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_1xLT16_DISABLE: + ; // Store the results. 
// c[0,0-15] @@ -782,9 +1022,19 @@ void lpgemm_rowvar_u8s8s32o32_5x16 int32_t* c, const dim_t rs_c, const int32_t alpha, - const int32_t beta + const int32_t beta, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op* post_ops_list ) { + static void* post_ops_labels[] = + { + &&POST_OPS_5x16_DISABLE, + &&POST_OPS_BIAS_5x16, + &&POST_OPS_RELU_5x16 + }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -929,7 +1179,57 @@ void lpgemm_rowvar_u8s8s32o32_5x16 selector1 = _mm512_mullo_epi32( selector2, selector1 ); c_int32_4p0 = _mm512_add_epi32( selector1, c_int32_4p0 ); } - + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_5x16: + { + selector1 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j ); + + // c[0,0-15] + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[3,0-15] + c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); + + // c[4,0-15] + c_int32_4p0 = _mm512_add_epi32( selector1, c_int32_4p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_5x16: + { + selector1 = _mm512_setzero_epi32(); + + // c[0,0-15] + c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_max_epi32( selector1, c_int32_2p0 ); + + // c[3,0-15] + c_int32_3p0 = _mm512_max_epi32( selector1, c_int32_3p0 ); + + // c[4,0-15] + c_int32_4p0 = _mm512_max_epi32( selector1, c_int32_4p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_5x16_DISABLE: + ; + // Store the results. 
// c[0,0-15] _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); @@ -960,9 +1260,19 @@ void lpgemm_rowvar_u8s8s32o32_4x16 int32_t* c, const dim_t rs_c, const int32_t alpha, - const int32_t beta + const int32_t beta, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op* post_ops_list ) { + static void* post_ops_labels[] = + { + &&POST_OPS_4x16_DISABLE, + &&POST_OPS_BIAS_4x16, + &&POST_OPS_RELU_4x16 + }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -1083,6 +1393,50 @@ void lpgemm_rowvar_u8s8s32o32_4x16 selector1 = _mm512_mullo_epi32( selector2, selector1 ); c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); } + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_4x16: + { + selector1 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j ); + + // c[0,0-15] + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[3,0-15] + c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_4x16: + { + selector1 = _mm512_setzero_epi32(); + + // c[0,0-15] + c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_max_epi32( selector1, c_int32_2p0 ); + + // c[3,0-15] + c_int32_3p0 = _mm512_max_epi32( selector1, c_int32_3p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_4x16_DISABLE: + ; // Store the results. // c[0,0-15] @@ -1111,9 +1465,19 @@ void lpgemm_rowvar_u8s8s32o32_3x16 int32_t* c, const dim_t rs_c, const int32_t alpha, - const int32_t beta + const int32_t beta, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op* post_ops_list ) { + static void* post_ops_labels[] = + { + &&POST_OPS_3x16_DISABLE, + &&POST_OPS_BIAS_3x16, + &&POST_OPS_RELU_3x16 + }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -1210,6 +1574,44 @@ void lpgemm_rowvar_u8s8s32o32_3x16 selector1 = _mm512_mullo_epi32( selector2, selector1 ); c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); } + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_3x16: + { + selector1 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j ); + + // c[0,0-15] + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_3x16: + { + selector1 = _mm512_setzero_epi32(); + + // c[0,0-15] + c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_max_epi32( selector1, c_int32_2p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_3x16_DISABLE: + ; // Store the results. 
// c[0,0-15] @@ -1235,9 +1637,19 @@ void lpgemm_rowvar_u8s8s32o32_2x16 int32_t* c, const dim_t rs_c, const int32_t alpha, - const int32_t beta + const int32_t beta, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op* post_ops_list ) { + static void* post_ops_labels[] = + { + &&POST_OPS_2x16_DISABLE, + &&POST_OPS_BIAS_2x16, + &&POST_OPS_RELU_2x16 + }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -1310,6 +1722,38 @@ void lpgemm_rowvar_u8s8s32o32_2x16 selector1 = _mm512_mullo_epi32( selector2, selector1 ); c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); } + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_2x16: + { + selector1 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j ); + + // c[0,0-15] + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_2x16: + { + selector1 = _mm512_setzero_epi32(); + + // c[0,0-15] + c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_2x16_DISABLE: + ; // Store the results. // c[0,0-15] @@ -1332,9 +1776,19 @@ void lpgemm_rowvar_u8s8s32o32_1x16 int32_t* c, const dim_t rs_c, const int32_t alpha, - const int32_t beta + const int32_t beta, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op* post_ops_list ) { + static void* post_ops_labels[] = + { + &&POST_OPS_1x16_DISABLE, + &&POST_OPS_BIAS_1x16, + &&POST_OPS_RELU_1x16 + }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -1383,6 +1837,32 @@ void lpgemm_rowvar_u8s8s32o32_1x16 selector1 = _mm512_mullo_epi32( selector2, selector1 ); c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); } + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_1x16: + { + selector1 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j ); + + // c[0,0-15] + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_1x16: + { + selector1 = _mm512_setzero_epi32(); + + // c[0,0-15] + c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_1x16_DISABLE: + ; // Store the results. // c[0,0-15] @@ -1402,14 +1882,31 @@ void lpgemm_rowvar_u8s8s32o32_5x32 int32_t* c, const dim_t rs_c, const int32_t alpha, - const int32_t beta + const int32_t beta, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op* post_ops_list ) { + static void* post_ops_labels[] = + { + &&POST_OPS_5x32_DISABLE, + &&POST_OPS_BIAS_5x32, + &&POST_OPS_RELU_5x32 + }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; uint32_t a_kfringe_buf = 0; + // B matrix storage. + __m512i b0; + __m512i b1; + + // A matrix storage. + __m512i a_int32_0; + // Registers to use for accumulating C. 
__m512i c_int32_0p0 = _mm512_setzero_epi32(); __m512i c_int32_0p1 = _mm512_setzero_epi32(); @@ -1428,11 +1925,11 @@ void lpgemm_rowvar_u8s8s32o32_5x32 for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) { - __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); - __m512i b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); // Broadcast a[0,kr:kr+4]. - __m512i a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 4. // c[0,0-31] = a[0,kr:kr+4]*b[kr:kr+4,0-31] @@ -1474,12 +1971,12 @@ void lpgemm_rowvar_u8s8s32o32_5x32 // Handle k remainder. if ( k_partial_pieces > 0 ) { - __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); - __m512i b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); // Broadcast a[0,kr:kr+4]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - __m512i a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 4. // c[0,0-31] = a[0,kr:kr+4]*b[kr:kr+4,0-31] @@ -1596,27 +2093,110 @@ void lpgemm_rowvar_u8s8s32o32_5x32 selector1 = _mm512_mullo_epi32( selector2, selector1 ); c_int32_4p1 = _mm512_add_epi32( selector1, c_int32_4p1 ); } - - // Store the results. - // c[0,0-15] - _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); - // c[0, 16-31] - _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 1*16 ), c_int32_0p1 ); + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_5x32: + { + selector1 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); - // c[1,0-15] - _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 0*16 ), c_int32_1p0 ); + // c[0,0-15] + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); - // c[1,16-31] - _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 1*16 ), c_int32_1p1 ); + // c[0, 16-31] + c_int32_0p1 = _mm512_add_epi32( selector2, c_int32_0p1 ); - // c[2,0-15] - _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 0*16 ), c_int32_2p0 ); + // c[1,0-15] + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); - // c[2,16-31] - _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 1*16 ), c_int32_2p1 ); + // c[1, 16-31] + c_int32_1p1 = _mm512_add_epi32( selector2, c_int32_1p1 ); - // c[3,0-15] + // c[2,0-15] + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[2, 16-31] + c_int32_2p1 = _mm512_add_epi32( selector2, c_int32_2p1 ); + + // c[3,0-15] + c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); + + // c[3, 16-31] + c_int32_3p1 = _mm512_add_epi32( selector2, c_int32_3p1 ); + + // c[4,0-15] + c_int32_4p0 = _mm512_add_epi32( selector1, c_int32_4p0 ); + + // c[4, 16-31] + c_int32_4p1 = _mm512_add_epi32( selector2, c_int32_4p1 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_5x32: + { + selector1 = _mm512_setzero_epi32(); + + // c[0,0-15] + c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); + + 
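/* For reference, when both the bias and ReLU post-ops are present in
 * post_ops_list, the net effect of the two blocks on every int32 accumulator
 * lane (after the alpha/beta handling above) is, in scalar form:
 *
 *     tmp         = c[ i ][ j ] + bias[ j ];   // bias indexed by global column
 *     c[ i ][ j ] = ( tmp > 0 ) ? tmp : 0;     // ReLU: max(x, 0)
 *
 * where bias stands for the int32 array read from post_ops_list->op_args1 at
 * offset post_op_c_j, as the loads in the POST_OPS_BIAS_* blocks do. This is a
 * descriptive sketch only; _mm512_add_epi32 and _mm512_max_epi32 apply it 16
 * lanes per register.
 */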
// c[0, 16-31] + c_int32_0p1 = _mm512_max_epi32( selector1, c_int32_0p1 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); + + // c[1,16-31] + c_int32_1p1 = _mm512_max_epi32( selector1, c_int32_1p1 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_max_epi32( selector1, c_int32_2p0 ); + + // c[2,16-31] + c_int32_2p1 = _mm512_max_epi32( selector1, c_int32_2p1 ); + + // c[3,0-15] + c_int32_3p0 = _mm512_max_epi32( selector1, c_int32_3p0 ); + + // c[3,16-31] + c_int32_3p1 = _mm512_max_epi32( selector1, c_int32_3p1 ); + + // c[4,0-15] + c_int32_4p0 = _mm512_max_epi32( selector1, c_int32_4p0 ); + + // c[4,16-31] + c_int32_4p1 = _mm512_max_epi32( selector1, c_int32_4p1 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_5x32_DISABLE: + ; + + // Store the results. + // c[0,0-15] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); + + // c[0, 16-31] + _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 1*16 ), c_int32_0p1 ); + + // c[1,0-15] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 0*16 ), c_int32_1p0 ); + + // c[1,16-31] + _mm512_storeu_epi32( c + ( rs_c * 1 ) + ( 1*16 ), c_int32_1p1 ); + + // c[2,0-15] + _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 0*16 ), c_int32_2p0 ); + + // c[2,16-31] + _mm512_storeu_epi32( c + ( rs_c * 2 ) + ( 1*16 ), c_int32_2p1 ); + + // c[3,0-15] _mm512_storeu_epi32( c + ( rs_c * 3 ) + ( 0*16 ), c_int32_3p0 ); // c[3,16-31] @@ -1642,14 +2222,31 @@ void lpgemm_rowvar_u8s8s32o32_4x32 int32_t* c, const dim_t rs_c, const int32_t alpha, - const int32_t beta + const int32_t beta, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op* post_ops_list ) { + static void* post_ops_labels[] = + { + &&POST_OPS_4x32_DISABLE, + &&POST_OPS_BIAS_4x32, + &&POST_OPS_RELU_4x32 + }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; uint32_t a_kfringe_buf = 0; + // B matrix storage. + __m512i b0; + __m512i b1; + + // A matrix storage. + __m512i a_int32_0; + // Registers to use for accumulating C. __m512i c_int32_0p0 = _mm512_setzero_epi32(); __m512i c_int32_0p1 = _mm512_setzero_epi32(); @@ -1665,11 +2262,11 @@ void lpgemm_rowvar_u8s8s32o32_4x32 for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) { - __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); - __m512i b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); // Broadcast a[0,kr:kr+4]. - __m512i a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 4. // c[0,0-31] = a[0,kr:kr+4]*b[kr:kr+4,0-31] @@ -1703,12 +2300,12 @@ void lpgemm_rowvar_u8s8s32o32_4x32 // Handle k remainder. if ( k_partial_pieces > 0 ) { - __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); - __m512i b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); // Broadcast a[0,kr:kr+4]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - __m512i a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 4. 
// c[0,0-31] = a[0,kr:kr+4]*b[kr:kr+4,0-31] @@ -1803,6 +2400,77 @@ void lpgemm_rowvar_u8s8s32o32_4x32 selector1 = _mm512_mullo_epi32( selector2, selector1 ); c_int32_3p1 = _mm512_add_epi32( selector1, c_int32_3p1 ); } + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_4x32: + { + selector1 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + + // c[0,0-15] + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_add_epi32( selector2, c_int32_0p1 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[1, 16-31] + c_int32_1p1 = _mm512_add_epi32( selector2, c_int32_1p1 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[2, 16-31] + c_int32_2p1 = _mm512_add_epi32( selector2, c_int32_2p1 ); + + // c[3,0-15] + c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); + + // c[3, 16-31] + c_int32_3p1 = _mm512_add_epi32( selector2, c_int32_3p1 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_4x32: + { + selector1 = _mm512_setzero_epi32(); + + // c[0,0-15] + c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_max_epi32( selector1, c_int32_0p1 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); + + // c[1,16-31] + c_int32_1p1 = _mm512_max_epi32( selector1, c_int32_1p1 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_max_epi32( selector1, c_int32_2p0 ); + + // c[2,16-31] + c_int32_2p1 = _mm512_max_epi32( selector1, c_int32_2p1 ); + + // c[3,0-15] + c_int32_3p0 = _mm512_max_epi32( selector1, c_int32_3p0 ); + + // c[3,16-31] + c_int32_3p1 = _mm512_max_epi32( selector1, c_int32_3p1 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_4x32_DISABLE: + ; // Store the results. // c[0,0-15] @@ -1843,14 +2511,31 @@ void lpgemm_rowvar_u8s8s32o32_3x32 int32_t* c, const dim_t rs_c, const int32_t alpha, - const int32_t beta + const int32_t beta, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op* post_ops_list ) { + static void* post_ops_labels[] = + { + &&POST_OPS_3x32_DISABLE, + &&POST_OPS_BIAS_3x32, + &&POST_OPS_RELU_3x32 + }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; uint32_t a_kfringe_buf = 0; + // B matrix storage. + __m512i b0; + __m512i b1; + + // A matrix storage. + __m512i a_int32_0; + // Registers to use for accumulating C. __m512i c_int32_0p0 = _mm512_setzero_epi32(); __m512i c_int32_0p1 = _mm512_setzero_epi32(); @@ -1863,11 +2548,11 @@ void lpgemm_rowvar_u8s8s32o32_3x32 for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) { - __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); - __m512i b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); // Broadcast a[0,kr:kr+4]. - __m512i a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 4. // c[0,0-31] = a[0,kr:kr+4]*b[kr:kr+4,0-31] @@ -1893,12 +2578,12 @@ void lpgemm_rowvar_u8s8s32o32_3x32 // Handle k remainder. 
if ( k_partial_pieces > 0 ) { - __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); - __m512i b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); // Broadcast a[0,kr:kr+4]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - __m512i a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 4. // c[0,0-31] = a[0,kr:kr+4]*b[kr:kr+4,0-31] @@ -1971,6 +2656,65 @@ void lpgemm_rowvar_u8s8s32o32_3x32 selector1 = _mm512_mullo_epi32( selector2, selector1 ); c_int32_2p1 = _mm512_add_epi32( selector1, c_int32_2p1 ); } + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_3x32: + { + selector1 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + + // c[0,0-15] + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_add_epi32( selector2, c_int32_0p1 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[1, 16-31] + c_int32_1p1 = _mm512_add_epi32( selector2, c_int32_1p1 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[2, 16-31] + c_int32_2p1 = _mm512_add_epi32( selector2, c_int32_2p1 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_3x32: + { + selector1 = _mm512_setzero_epi32(); + + // c[0,0-15] + c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_max_epi32( selector1, c_int32_0p1 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); + + // c[1,16-31] + c_int32_1p1 = _mm512_max_epi32( selector1, c_int32_1p1 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_max_epi32( selector1, c_int32_2p0 ); + + // c[2,16-31] + c_int32_2p1 = _mm512_max_epi32( selector1, c_int32_2p1 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_3x32_DISABLE: + ; // Store the results. // c[0,0-15] @@ -2005,14 +2749,31 @@ void lpgemm_rowvar_u8s8s32o32_2x32 int32_t* c, const dim_t rs_c, const int32_t alpha, - const int32_t beta + const int32_t beta, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op* post_ops_list ) { + static void* post_ops_labels[] = + { + &&POST_OPS_2x32_DISABLE, + &&POST_OPS_BIAS_2x32, + &&POST_OPS_RELU_2x32 + }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; uint32_t a_kfringe_buf = 0; + // B matrix storage. + __m512i b0; + __m512i b1; + + // A matrix storage. + __m512i a_int32_0; + // Registers to use for accumulating C. __m512i c_int32_0p0 = _mm512_setzero_epi32(); __m512i c_int32_0p1 = _mm512_setzero_epi32(); @@ -2022,11 +2783,11 @@ void lpgemm_rowvar_u8s8s32o32_2x32 for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) { - __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); - __m512i b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); // Broadcast a[0,kr:kr+4]. 
- __m512i a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 4. // c[0,0-31] = a[0,kr:kr+4]*b[kr:kr+4,0-31] @@ -2044,12 +2805,12 @@ void lpgemm_rowvar_u8s8s32o32_2x32 // Handle k remainder. if ( k_partial_pieces > 0 ) { - __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); - __m512i b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); // Broadcast a[0,kr:kr+4]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - __m512i a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 4. // c[0,0-31] = a[0,kr:kr+4]*b[kr:kr+4,0-31] @@ -2100,6 +2861,53 @@ void lpgemm_rowvar_u8s8s32o32_2x32 selector1 = _mm512_mullo_epi32( selector2, selector1 ); c_int32_1p1 = _mm512_add_epi32( selector1, c_int32_1p1 ); } + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_2x32: + { + selector1 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + + // c[0,0-15] + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_add_epi32( selector2, c_int32_0p1 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[1, 16-31] + c_int32_1p1 = _mm512_add_epi32( selector2, c_int32_1p1 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_2x32: + { + selector1 = _mm512_setzero_epi32(); + + // c[0,0-15] + c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_max_epi32( selector1, c_int32_0p1 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); + + // c[1,16-31] + c_int32_1p1 = _mm512_max_epi32( selector1, c_int32_1p1 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_2x32_DISABLE: + ; // Store the results. // c[0,0-15] @@ -2128,25 +2936,42 @@ void lpgemm_rowvar_u8s8s32o32_1x32 int32_t* c, const dim_t rs_c, const int32_t alpha, - const int32_t beta + const int32_t beta, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op* post_ops_list ) { + static void* post_ops_labels[] = + { + &&POST_OPS_1x32_DISABLE, + &&POST_OPS_BIAS_1x32, + &&POST_OPS_RELU_1x32 + }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; uint32_t a_kfringe_buf = 0; + // B matrix storage. + __m512i b0; + __m512i b1; + + // A matrix storage. + __m512i a_int32_0; + // Registers to use for accumulating C. __m512i c_int32_0p0 = _mm512_setzero_epi32(); __m512i c_int32_0p1 = _mm512_setzero_epi32(); for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) { - __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); - __m512i b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); // Broadcast a[0,kr:kr+4]. 
- __m512i a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 4. // c[0,0-31] = a[0,kr:kr+4]*b[kr:kr+4,0-31] @@ -2156,12 +2981,12 @@ void lpgemm_rowvar_u8s8s32o32_1x32 // Handle k remainder. if ( k_partial_pieces > 0 ) { - __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); - __m512i b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); // Broadcast a[0,kr:kr+4]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - __m512i a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 4. // c[0,0-31] = a[0,kr:kr+4]*b[kr:kr+4,0-31] @@ -2190,6 +3015,41 @@ void lpgemm_rowvar_u8s8s32o32_1x32 selector1 = _mm512_mullo_epi32( selector2, selector1 ); c_int32_0p1 = _mm512_add_epi32( selector1, c_int32_0p1 ); } + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_1x32: + { + selector1 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + + // c[0,0-15] + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_add_epi32( selector2, c_int32_0p1 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_1x32: + { + selector1 = _mm512_setzero_epi32(); + + // c[0,0-15] + c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_max_epi32( selector1, c_int32_0p1 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_1x32_DISABLE: + ; // Store the results. // c[0,0-15] @@ -2212,14 +3072,32 @@ void lpgemm_rowvar_u8s8s32o32_5x48 int32_t* c, const dim_t rs_c, const int32_t alpha, - const int32_t beta + const int32_t beta, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op* post_ops_list ) { + static void* post_ops_labels[] = + { + &&POST_OPS_5x48_DISABLE, + &&POST_OPS_BIAS_5x48, + &&POST_OPS_RELU_5x48 + }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; uint32_t a_kfringe_buf = 0; + // B matrix storage. + __m512i b0; + __m512i b1; + __m512i b2; + + // A matrix storage. + __m512i a_int32_0; + // Registers to use for accumulating C. __m512i c_int32_0p0 = _mm512_setzero_epi32(); __m512i c_int32_0p1 = _mm512_setzero_epi32(); @@ -2243,12 +3121,12 @@ void lpgemm_rowvar_u8s8s32o32_5x48 for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) { - __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); - __m512i b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); - __m512i b2 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 2 ) ); // Broadcast a[0,kr:kr+4]. 
- __m512i a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 4. // c[0,0-47] = a[0,kr:kr+4]*b[kr:kr+4,0-47] @@ -2295,13 +3173,13 @@ void lpgemm_rowvar_u8s8s32o32_5x48 // Handle k remainder. if ( k_partial_pieces > 0 ) { - __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); - __m512i b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); - __m512i b2 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); // Broadcast a[0,kr:kr+4]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - __m512i a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 4. // c[0,0-47] = a[0,kr:kr+4]*b[kr:kr+4,0-47] @@ -2453,7 +3331,123 @@ void lpgemm_rowvar_u8s8s32o32_5x48 selector1 = _mm512_mullo_epi32( selector2, selector1 ); c_int32_4p2 = _mm512_add_epi32( selector1, c_int32_4p2 ); } - + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_5x48: + { + selector1 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 2 * 16 ) ); + + // c[0,0-15] + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_add_epi32( selector2, c_int32_0p1 ); + + // c[0,32-47] + c_int32_0p2 = _mm512_add_epi32( a_int32_0, c_int32_0p2 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[1, 16-31] + c_int32_1p1 = _mm512_add_epi32( selector2, c_int32_1p1 ); + + // c[1,32-47] + c_int32_1p2 = _mm512_add_epi32( a_int32_0, c_int32_1p2 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[2, 16-31] + c_int32_2p1 = _mm512_add_epi32( selector2, c_int32_2p1 ); + + // c[2,32-47] + c_int32_2p2 = _mm512_add_epi32( a_int32_0, c_int32_2p2 ); + + // c[3,0-15] + c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); + + // c[3, 16-31] + c_int32_3p1 = _mm512_add_epi32( selector2, c_int32_3p1 ); + + // c[3,32-47] + c_int32_3p2 = _mm512_add_epi32( a_int32_0, c_int32_3p2 ); + + // c[4,0-15] + c_int32_4p0 = _mm512_add_epi32( selector1, c_int32_4p0 ); + + // c[4, 16-31] + c_int32_4p1 = _mm512_add_epi32( selector2, c_int32_4p1 ); + + // c[4,32-47] + c_int32_4p2 = _mm512_add_epi32( a_int32_0, c_int32_4p2 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_5x48: + { + selector1 = _mm512_setzero_epi32(); + + // c[0,0-15] + c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_max_epi32( selector1, c_int32_0p1 ); + + // c[0,32-47] + c_int32_0p2 = _mm512_max_epi32( selector1, c_int32_0p2 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); + + // c[1,16-31] + c_int32_1p1 = _mm512_max_epi32( selector1, c_int32_1p1 ); + + // c[1,32-47] + c_int32_1p2 = _mm512_max_epi32( 
selector1, c_int32_1p2 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_max_epi32( selector1, c_int32_2p0 ); + + // c[2,16-31] + c_int32_2p1 = _mm512_max_epi32( selector1, c_int32_2p1 ); + + // c[2,32-47] + c_int32_2p2 = _mm512_max_epi32( selector1, c_int32_2p2 ); + + // c[3,0-15] + c_int32_3p0 = _mm512_max_epi32( selector1, c_int32_3p0 ); + + // c[3,16-31] + c_int32_3p1 = _mm512_max_epi32( selector1, c_int32_3p1 ); + + // c[3,32-47] + c_int32_3p2 = _mm512_max_epi32( selector1, c_int32_3p2 ); + + // c[4,0-15] + c_int32_4p0 = _mm512_max_epi32( selector1, c_int32_4p0 ); + + // c[4,16-31] + c_int32_4p1 = _mm512_max_epi32( selector1, c_int32_4p1 ); + + // c[4,32-47] + c_int32_4p2 = _mm512_max_epi32( selector1, c_int32_4p2 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_5x48_DISABLE: + ; + // Store the results. // c[0,0-15] _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 0*16 ), c_int32_0p0 ); @@ -2514,14 +3508,32 @@ void lpgemm_rowvar_u8s8s32o32_4x48 int32_t* c, const dim_t rs_c, const int32_t alpha, - const int32_t beta + const int32_t beta, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op* post_ops_list ) { + static void* post_ops_labels[] = + { + &&POST_OPS_4x48_DISABLE, + &&POST_OPS_BIAS_4x48, + &&POST_OPS_RELU_4x48 + }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; uint32_t a_kfringe_buf = 0; + // B matrix storage. + __m512i b0; + __m512i b1; + __m512i b2; + + // A matrix storage. + __m512i a_int32_0; + // Registers to use for accumulating C. __m512i c_int32_0p0 = _mm512_setzero_epi32(); __m512i c_int32_0p1 = _mm512_setzero_epi32(); @@ -2541,12 +3553,12 @@ void lpgemm_rowvar_u8s8s32o32_4x48 for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) { - __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); - __m512i b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); - __m512i b2 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 2 ) ); // Broadcast a[0,kr:kr+4]. - __m512i a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 4. // c[0,0-47] = a[0,kr:kr+4]*b[kr:kr+4,0-47] @@ -2584,13 +3596,13 @@ void lpgemm_rowvar_u8s8s32o32_4x48 // Handle k remainder. if ( k_partial_pieces > 0 ) { - __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); - __m512i b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); - __m512i b2 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); // Broadcast a[0,kr:kr+4]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - __m512i a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 4. 
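/* Note on the "k = 4" steps in this kernel: each one presumably lowers to a
 * single AVX512-VNNI dot-product per accumulator, e.g.
 *
 *     c_int32_0p0 = _mm512_dpbusd_epi32( c_int32_0p0, a_int32_0, b0 );
 *
 * which, per 32-bit lane, multiplies four unsigned bytes of A against four
 * signed bytes of B and adds the sum of those products into the lane.  The
 * intrinsic lines themselves are unchanged context outside this hunk, so the
 * call shown here is an assumption based on the u8s8s32 data path and the
 * *_amd512vnni file name; the per-row comments that follow annotate exactly
 * these updates. */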
// c[0,0-47] = a[0,kr:kr+4]*b[kr:kr+4,0-47] @@ -2713,6 +3725,104 @@ void lpgemm_rowvar_u8s8s32o32_4x48 selector1 = _mm512_mullo_epi32( selector2, selector1 ); c_int32_3p2 = _mm512_add_epi32( selector1, c_int32_3p2 ); } + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_4x48: + { + selector1 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 2 * 16 ) ); + + // c[0,0-15] + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_add_epi32( selector2, c_int32_0p1 ); + + // c[0,32-47] + c_int32_0p2 = _mm512_add_epi32( a_int32_0, c_int32_0p2 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[1, 16-31] + c_int32_1p1 = _mm512_add_epi32( selector2, c_int32_1p1 ); + + // c[1,32-47] + c_int32_1p2 = _mm512_add_epi32( a_int32_0, c_int32_1p2 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[2, 16-31] + c_int32_2p1 = _mm512_add_epi32( selector2, c_int32_2p1 ); + + // c[2,32-47] + c_int32_2p2 = _mm512_add_epi32( a_int32_0, c_int32_2p2 ); + + // c[3,0-15] + c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); + + // c[3, 16-31] + c_int32_3p1 = _mm512_add_epi32( selector2, c_int32_3p1 ); + + // c[3,32-47] + c_int32_3p2 = _mm512_add_epi32( a_int32_0, c_int32_3p2 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_4x48: + { + selector1 = _mm512_setzero_epi32(); + + // c[0,0-15] + c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_max_epi32( selector1, c_int32_0p1 ); + + // c[0,32-47] + c_int32_0p2 = _mm512_max_epi32( selector1, c_int32_0p2 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); + + // c[1,16-31] + c_int32_1p1 = _mm512_max_epi32( selector1, c_int32_1p1 ); + + // c[1,32-47] + c_int32_1p2 = _mm512_max_epi32( selector1, c_int32_1p2 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_max_epi32( selector1, c_int32_2p0 ); + + // c[2,16-31] + c_int32_2p1 = _mm512_max_epi32( selector1, c_int32_2p1 ); + + // c[2,32-47] + c_int32_2p2 = _mm512_max_epi32( selector1, c_int32_2p2 ); + + // c[3,0-15] + c_int32_3p0 = _mm512_max_epi32( selector1, c_int32_3p0 ); + + // c[3,16-31] + c_int32_3p1 = _mm512_max_epi32( selector1, c_int32_3p1 ); + + // c[3,32-47] + c_int32_3p2 = _mm512_max_epi32( selector1, c_int32_3p2 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_4x48_DISABLE: + ; // Store the results. // c[0,0-15] @@ -2765,14 +3875,32 @@ void lpgemm_rowvar_u8s8s32o32_3x48 int32_t* c, const dim_t rs_c, const int32_t alpha, - const int32_t beta + const int32_t beta, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op* post_ops_list ) { + static void* post_ops_labels[] = + { + &&POST_OPS_3x48_DISABLE, + &&POST_OPS_BIAS_3x48, + &&POST_OPS_RELU_3x48 + }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; uint32_t a_kfringe_buf = 0; + // B matrix storage. + __m512i b0; + __m512i b1; + __m512i b2; + + // A matrix storage. + __m512i a_int32_0; + // Registers to use for accumulating C. 
__m512i c_int32_0p0 = _mm512_setzero_epi32(); __m512i c_int32_0p1 = _mm512_setzero_epi32(); @@ -2788,12 +3916,12 @@ void lpgemm_rowvar_u8s8s32o32_3x48 for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) { - __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); - __m512i b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); - __m512i b2 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 2 ) ); // Broadcast a[0,kr:kr+4]. - __m512i a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 4. // c[0,0-47] = a[0,kr:kr+4]*b[kr:kr+4,0-47] @@ -2822,13 +3950,13 @@ void lpgemm_rowvar_u8s8s32o32_3x48 // Handle k remainder. if ( k_partial_pieces > 0 ) { - __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); - __m512i b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); - __m512i b2 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); // Broadcast a[0,kr:kr+4]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - __m512i a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 4. // c[0,0-47] = a[0,kr:kr+4]*b[kr:kr+4,0-47] @@ -2922,6 +4050,86 @@ void lpgemm_rowvar_u8s8s32o32_3x48 selector1 = _mm512_mullo_epi32( selector2, selector1 ); c_int32_2p2 = _mm512_add_epi32( selector1, c_int32_2p2 ); } + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_3x48: + { + selector1 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 2 * 16 ) ); + + // c[0,0-15] + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_add_epi32( selector2, c_int32_0p1 ); + + // c[0,32-47] + c_int32_0p2 = _mm512_add_epi32( a_int32_0, c_int32_0p2 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[1, 16-31] + c_int32_1p1 = _mm512_add_epi32( selector2, c_int32_1p1 ); + + // c[1,32-47] + c_int32_1p2 = _mm512_add_epi32( a_int32_0, c_int32_1p2 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[2, 16-31] + c_int32_2p1 = _mm512_add_epi32( selector2, c_int32_2p1 ); + + // c[2,32-47] + c_int32_2p2 = _mm512_add_epi32( a_int32_0, c_int32_2p2 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_3x48: + { + selector1 = _mm512_setzero_epi32(); + + // c[0,0-15] + c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_max_epi32( selector1, c_int32_0p1 ); + + // c[0,32-47] + c_int32_0p2 = _mm512_max_epi32( selector1, c_int32_0p2 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_max_epi32( 
selector1, c_int32_1p0 ); + + // c[1,16-31] + c_int32_1p1 = _mm512_max_epi32( selector1, c_int32_1p1 ); + + // c[1,32-47] + c_int32_1p2 = _mm512_max_epi32( selector1, c_int32_1p2 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_max_epi32( selector1, c_int32_2p0 ); + + // c[2,16-31] + c_int32_2p1 = _mm512_max_epi32( selector1, c_int32_2p1 ); + + // c[2,32-47] + c_int32_2p2 = _mm512_max_epi32( selector1, c_int32_2p2 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_3x48_DISABLE: + ; // Store the results. // c[0,0-15] @@ -2965,14 +4173,32 @@ void lpgemm_rowvar_u8s8s32o32_2x48 int32_t* c, const dim_t rs_c, const int32_t alpha, - const int32_t beta + const int32_t beta, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op* post_ops_list ) { + static void* post_ops_labels[] = + { + &&POST_OPS_2x48_DISABLE, + &&POST_OPS_BIAS_2x48, + &&POST_OPS_RELU_2x48 + }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; uint32_t a_kfringe_buf = 0; + // B matrix storage. + __m512i b0; + __m512i b1; + __m512i b2; + + // A matrix storage. + __m512i a_int32_0; + // Registers to use for accumulating C. __m512i c_int32_0p0 = _mm512_setzero_epi32(); __m512i c_int32_0p1 = _mm512_setzero_epi32(); @@ -2984,12 +4210,12 @@ void lpgemm_rowvar_u8s8s32o32_2x48 for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) { - __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); - __m512i b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); - __m512i b2 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 2 ) ); // Broadcast a[0,kr:kr+4]. - __m512i a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 4. // c[0,0-47] = a[0,kr:kr+4]*b[kr:kr+4,0-47] @@ -3009,13 +4235,13 @@ void lpgemm_rowvar_u8s8s32o32_2x48 // Handle k remainder. if ( k_partial_pieces > 0 ) { - __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); - __m512i b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); - __m512i b2 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); // Broadcast a[0,kr:kr+4]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - __m512i a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 4. 
// c[0,0-47] = a[0,kr:kr+4]*b[kr:kr+4,0-47] @@ -3080,6 +4306,68 @@ void lpgemm_rowvar_u8s8s32o32_2x48 selector1 = _mm512_mullo_epi32( selector2, selector1 ); c_int32_1p2 = _mm512_add_epi32( selector1, c_int32_1p2 ); } + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_2x48: + { + selector1 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 2 * 16 ) ); + + // c[0,0-15] + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_add_epi32( selector2, c_int32_0p1 ); + + // c[0,32-47] + c_int32_0p2 = _mm512_add_epi32( a_int32_0, c_int32_0p2 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[1, 16-31] + c_int32_1p1 = _mm512_add_epi32( selector2, c_int32_1p1 ); + + // c[1,32-47] + c_int32_1p2 = _mm512_add_epi32( a_int32_0, c_int32_1p2 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_2x48: + { + selector1 = _mm512_setzero_epi32(); + + // c[0,0-15] + c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_max_epi32( selector1, c_int32_0p1 ); + + // c[0,32-47] + c_int32_0p2 = _mm512_max_epi32( selector1, c_int32_0p2 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); + + // c[1,16-31] + c_int32_1p1 = _mm512_max_epi32( selector1, c_int32_1p1 ); + + // c[1,32-47] + c_int32_1p2 = _mm512_max_epi32( selector1, c_int32_1p2 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_2x48_DISABLE: + ; // Store the results. // c[0,0-15] @@ -3114,14 +4402,32 @@ void lpgemm_rowvar_u8s8s32o32_1x48 int32_t* c, const dim_t rs_c, const int32_t alpha, - const int32_t beta + const int32_t beta, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op* post_ops_list ) { + static void* post_ops_labels[] = + { + &&POST_OPS_1x48_DISABLE, + &&POST_OPS_BIAS_1x48, + &&POST_OPS_RELU_1x48 + }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; uint32_t a_kfringe_buf = 0; + // B matrix storage. + __m512i b0; + __m512i b1; + __m512i b2; + + // A matrix storage. + __m512i a_int32_0; + // Registers to use for accumulating C. __m512i c_int32_0p0 = _mm512_setzero_epi32(); __m512i c_int32_0p1 = _mm512_setzero_epi32(); @@ -3129,12 +4435,12 @@ void lpgemm_rowvar_u8s8s32o32_1x48 for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) { - __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); - __m512i b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); - __m512i b2 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 2 ) ); // Broadcast a[0,kr:kr+4]. - __m512i a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 4. // c[0,0-47] = a[0,kr:kr+4]*b[kr:kr+4,0-47] @@ -3145,13 +4451,13 @@ void lpgemm_rowvar_u8s8s32o32_1x48 // Handle k remainder. 
if ( k_partial_pieces > 0 ) { - __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); - __m512i b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); - __m512i b2 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); // Broadcast a[0,kr:kr+4]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - __m512i a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 4. // c[0,0-47] = a[0,kr:kr+4]*b[kr:kr+4,0-47] @@ -3187,6 +4493,50 @@ void lpgemm_rowvar_u8s8s32o32_1x48 selector1 = _mm512_mullo_epi32( selector2, selector1 ); c_int32_0p2 = _mm512_add_epi32( selector1, c_int32_0p2 ); } + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_1x48: + { + selector1 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 2 * 16 ) ); + + // c[0,0-15] + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_add_epi32( selector2, c_int32_0p1 ); + + // c[0,32-47] + c_int32_0p2 = _mm512_add_epi32( a_int32_0, c_int32_0p2 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_1x48: + { + selector1 = _mm512_setzero_epi32(); + + // c[0,0-15] + c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_max_epi32( selector1, c_int32_0p1 ); + + // c[0,32-47] + c_int32_0p2 = _mm512_max_epi32( selector1, c_int32_0p2 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_1x48_DISABLE: + ; // Store the results. 
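/* How the post-op blocks above are reached: post_ops_labels[] stores label
 * addresses through the GCC &&label extension, with the *_DISABLE label at
 * index 0.  POST_OP_LABEL_LASTK_SAFE_JUMP is expected to jump to the label of
 * the first entry in post_ops_list only when is_last_k is true -- post-ops
 * must run exactly once, after the final k panel has been accumulated into
 * the C registers -- and to the DISABLE label otherwise; the *_WITH_NEXT_PTR
 * form advances post_ops_list_temp and dispatches again, which is what lets
 * bias and relu be applied in either order.  The macro bodies live in
 * lpgemm_post_ops.h outside this hunk, so the expansion sketched here is an
 * assumption (including the op_code/next field names):
 *
 *     #define POST_OP_LABEL_LASTK_SAFE_JUMP \
 *         if ( is_last_k && ( post_ops_list_temp != NULL ) ) \
 *             goto *post_ops_labels[ post_ops_list_temp->op_code ]; \
 *         else goto *post_ops_labels[ 0 ];
 *
 *     #define POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR \
 *         post_ops_list_temp = post_ops_list_temp->next; \
 *         POST_OP_LABEL_LASTK_SAFE_JUMP
 *
 * The *_DISABLE label above falls straight through to the stores below. */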
// c[0,0-15] diff --git a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_n_fringe.h b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_n_fringe.h index f2aee7c832..84e9ee5b7b 100644 --- a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_n_fringe.h +++ b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_n_fringe.h @@ -35,6 +35,8 @@ #ifndef BLIS_GEMM_INT8_NFRINGE #define BLIS_GEMM_INT8_NFRINGE +#include "lpgemm_post_ops.h" + // 6xlt16 int8o32 fringe kernel void lpgemm_rowvar_u8s8s32o32_6xlt16 ( @@ -51,7 +53,11 @@ void lpgemm_rowvar_u8s8s32o32_6xlt16 const dim_t rs_c, const int32_t alpha, const int32_t beta, - const dim_t n0_rem + const dim_t n0_rem, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op* post_ops_list ); // 6x16 int8o32 fringe kernel @@ -69,7 +75,11 @@ void lpgemm_rowvar_u8s8s32o32_6x16 int32_t* c, const dim_t rs_c, const int32_t alpha, - const int32_t beta + const int32_t beta, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op* post_ops_list ); // 6x32 int8o32 fringe kernel @@ -87,7 +97,11 @@ void lpgemm_rowvar_u8s8s32o32_6x32 int32_t* c, const dim_t rs_c, const int32_t alpha, - const int32_t beta + const int32_t beta, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op* post_ops_list ); // 6x48 int8o32 fringe kernel @@ -105,7 +119,11 @@ void lpgemm_rowvar_u8s8s32o32_6x48 int32_t* c, const dim_t rs_c, const int32_t alpha, - const int32_t beta + const int32_t beta, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op* post_ops_list ); #endif //BLIS_GEMM_INT8_NFRINGE diff --git a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_n_fringe_amd512vnni.c b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_n_fringe_amd512vnni.c index 1ab6182f17..70369f73ed 100644 --- a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_n_fringe_amd512vnni.c +++ b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_n_fringe_amd512vnni.c @@ -55,9 +55,19 @@ void lpgemm_rowvar_u8s8s32o32_6xlt16 const dim_t rs_c, const int32_t alpha, const int32_t beta, - const dim_t n0_rem + const dim_t n0_rem, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op* post_ops_list ) { + static void* post_ops_labels[] = + { + &&POST_OPS_6xLT16_DISABLE, + &&POST_OPS_BIAS_6xLT16, + &&POST_OPS_RELU_6xLT16 + }; dim_t MR = 6; dim_t m_full_pieces = m0 / MR; dim_t m_full_pieces_loop_limit = m_full_pieces * MR; @@ -68,6 +78,12 @@ void lpgemm_rowvar_u8s8s32o32_6xlt16 uint32_t a_kfringe_buf = 0; + // B matrix storage. + __m512i b0; + + // A matrix storage. + __m512i a_int32_0; + // For corner cases. int32_t buf0[16]; int32_t buf1[16]; @@ -98,10 +114,10 @@ void lpgemm_rowvar_u8s8s32o32_6xlt16 // in vnni instructions and each load to ZMM register will have 4 // elements along k direction and 16 elements across n directions, // so 4x16 elements to a ZMM register. - __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); // Broadcast a[0,kr:kr+4]. - __m512i a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 4. // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] @@ -145,7 +161,7 @@ void lpgemm_rowvar_u8s8s32o32_6xlt16 // Handle k remainder. if ( k_partial_pieces > 0 ) { - __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); // Broadcast a[0,kr:kr+4]. 
memcpy @@ -154,7 +170,7 @@ void lpgemm_rowvar_u8s8s32o32_6xlt16 ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - __m512i a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 4. // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] @@ -283,6 +299,62 @@ void lpgemm_rowvar_u8s8s32o32_6xlt16 selector1 = _mm512_mullo_epi32( selector2, selector1 ); c_int32_5p0 = _mm512_add_epi32( selector1, c_int32_5p0 ); } + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_6xLT16: + { + selector1 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j ); + + // c[0,0-15] + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[3,0-15] + c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); + + // c[4,0-15] + c_int32_4p0 = _mm512_add_epi32( selector1, c_int32_4p0 ); + + // c[5,0-15] + c_int32_5p0 = _mm512_add_epi32( selector1, c_int32_5p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_6xLT16: + { + selector1 = _mm512_setzero_epi32(); + + // c[0,0-15] + c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_max_epi32( selector1, c_int32_2p0 ); + + // c[3,0-15] + c_int32_3p0 = _mm512_max_epi32( selector1, c_int32_3p0 ); + + // c[4,0-15] + c_int32_4p0 = _mm512_max_epi32( selector1, c_int32_4p0 ); + + // c[5,0-15] + c_int32_5p0 = _mm512_max_epi32( selector1, c_int32_5p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_6xLT16_DISABLE: + ; // Store the results. 
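/* What the BIAS/RELU blocks above compute, in scalar form.  The sketch below
 * is illustrative only (ref_bias_relu_tile, c_tile, m_tile and n_tile are
 * stand-in names); in the kernel the order of the two steps follows whatever
 * order the caller placed in post_ops->seq_vector. */
static void ref_bias_relu_tile
     (
       int32_t*       c_tile,       /* m_tile x n_tile tile of C, row stride rs_ct */
       dim_t          rs_ct,
       dim_t          m_tile,
       dim_t          n_tile,
       const int32_t* bias,         /* full-length bias vector (op_args1) */
       dim_t          post_op_c_j   /* absolute column offset of this tile */
     )
{
    for ( dim_t i = 0; i < m_tile; ++i )
    {
        for ( dim_t j = 0; j < n_tile; ++j )
        {
            /* BIAS: one value per output column, shared by every row. */
            int32_t v = c_tile[ ( rs_ct * i ) + j ] + bias[ post_op_c_j + j ];
            /* ELTWISE (relu): clamp at zero. */
            c_tile[ ( rs_ct * i ) + j ] = ( v > 0 ) ? v : 0;
        }
    }
}
/* The stores below then move each finished row out through the buf0..buf5
 * scratch arrays and memcpy only n0_rem elements to C, so this partial-width
 * kernel never writes past column n. */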
// c[0,0-15] @@ -323,6 +395,7 @@ void lpgemm_rowvar_u8s8s32o32_6xlt16 memcpy( c + ( rs_c * ( ir + 5 ) ) + ( 0*16 ), buf5, ( n0_rem * sizeof( int32_t ) ) ); a = a + ( MR * ps_a ); + post_op_c_i += MR; } if ( m_partial_pieces > 0 ) @@ -336,7 +409,10 @@ void lpgemm_rowvar_u8s8s32o32_6xlt16 a, rs_a, cs_a_use, b, rs_b, cs_b, ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta, n0_rem + alpha, beta, n0_rem, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list ); } else if ( m_partial_pieces == 4 ) @@ -348,7 +424,10 @@ void lpgemm_rowvar_u8s8s32o32_6xlt16 a, rs_a, cs_a_use, b, rs_b, cs_b, ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta, n0_rem + alpha, beta, n0_rem, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list ); } else if ( m_partial_pieces == 3 ) @@ -360,7 +439,10 @@ void lpgemm_rowvar_u8s8s32o32_6xlt16 a, rs_a, cs_a_use, b, rs_b, cs_b, ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta, n0_rem + alpha, beta, n0_rem, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list ); } else if ( m_partial_pieces == 2 ) @@ -372,7 +454,10 @@ void lpgemm_rowvar_u8s8s32o32_6xlt16 a, rs_a, cs_a_use, b, rs_b, cs_b, ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta, n0_rem + alpha, beta, n0_rem, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list ); } else if ( m_partial_pieces == 1 ) @@ -384,7 +469,10 @@ void lpgemm_rowvar_u8s8s32o32_6xlt16 a, rs_a, cs_a_use, b, rs_b, cs_b, ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta, n0_rem + alpha, beta, n0_rem, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list ); } } @@ -405,9 +493,19 @@ void lpgemm_rowvar_u8s8s32o32_6x16 int32_t* c, const dim_t rs_c, const int32_t alpha, - const int32_t beta + const int32_t beta, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op* post_ops_list ) { + static void* post_ops_labels[] = + { + &&POST_OPS_6x16_DISABLE, + &&POST_OPS_BIAS_6x16, + &&POST_OPS_RELU_6x16 + }; dim_t MR = 6; dim_t m_full_pieces = m0 / MR; dim_t m_full_pieces_loop_limit = m_full_pieces * MR; @@ -418,6 +516,12 @@ void lpgemm_rowvar_u8s8s32o32_6x16 uint32_t a_kfringe_buf = 0; + // B matrix storage. + __m512i b0; + + // A matrix storage. + __m512i a_int32_0; + for ( dim_t ir = 0; ir < m_full_pieces_loop_limit; ir += MR ) { // Registers to use for accumulating C. @@ -440,10 +544,10 @@ void lpgemm_rowvar_u8s8s32o32_6x16 // instructions and each load to ZMM register will have 4 elements // along k direction and 16 elements across n directions, so 4x16 // elements to a ZMM register. - __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); // Broadcast a[0,kr:kr+4]. - __m512i a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 4. // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] @@ -487,7 +591,7 @@ void lpgemm_rowvar_u8s8s32o32_6x16 // Handle k remainder. if ( k_partial_pieces > 0 ) { - __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); // Broadcast a[0,kr:kr+4]. 
memcpy @@ -496,7 +600,7 @@ void lpgemm_rowvar_u8s8s32o32_6x16 ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - __m512i a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 4. // c[0,0-15] = a[0,kr:kr+4]*b[kr:kr+4,0-15] @@ -618,6 +722,62 @@ void lpgemm_rowvar_u8s8s32o32_6x16 selector1 = _mm512_mullo_epi32( selector2, selector1 ); c_int32_5p0 = _mm512_add_epi32( selector1, c_int32_5p0 ); } + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_6x16: + { + selector1 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j ); + + // c[0,0-15] + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[3,0-15] + c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); + + // c[4,0-15] + c_int32_4p0 = _mm512_add_epi32( selector1, c_int32_4p0 ); + + // c[5,0-15] + c_int32_5p0 = _mm512_add_epi32( selector1, c_int32_5p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_6x16: + { + selector1 = _mm512_setzero_epi32(); + + // c[0,0-15] + c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_max_epi32( selector1, c_int32_2p0 ); + + // c[3,0-15] + c_int32_3p0 = _mm512_max_epi32( selector1, c_int32_3p0 ); + + // c[4,0-15] + c_int32_4p0 = _mm512_max_epi32( selector1, c_int32_4p0 ); + + // c[5,0-15] + c_int32_5p0 = _mm512_max_epi32( selector1, c_int32_5p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_6x16_DISABLE: + ; // Store the results. 
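/* The stores below write the finished 6x16 tile straight to C; immediately
 * after them post_op_c_i is advanced by MR so it always holds the absolute
 * row offset of the next 6-row block, and the m-remainder dispatch that
 * follows forwards is_last_k, post_op_c_i, post_op_c_j and post_ops_list
 * unchanged to the 5x..1x row-fringe kernels, so leftover rows get the same
 * post-ops as full blocks.  Bias and relu only index by column
 * (post_op_c_j), but threading the row offset through keeps the kernel
 * interface uniform.  In outline (argument order as in the calls below):
 *
 *     if ( m_partial_pieces > 0 )
 *     {
 *         // pick the <m_partial_pieces>-row variant and call it on
 *         // c + ( rs_c * m_full_pieces_loop_limit ) with the same
 *         // alpha, beta, is_last_k, post_op_c_i, post_op_c_j, post_ops_list.
 *     }
 */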
// c[0,0-15] @@ -639,6 +799,7 @@ void lpgemm_rowvar_u8s8s32o32_6x16 _mm512_storeu_epi32( c + ( rs_c * ( ir + 5 ) ) + ( 0*16 ), c_int32_5p0 ); a = a + ( MR * ps_a ); + post_op_c_i += MR; } if ( m_partial_pieces > 0 ) @@ -652,7 +813,10 @@ void lpgemm_rowvar_u8s8s32o32_6x16 a, rs_a, cs_a_use, b, rs_b, cs_b, ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list ); } else if ( m_partial_pieces == 4 ) @@ -664,7 +828,10 @@ void lpgemm_rowvar_u8s8s32o32_6x16 a, rs_a, cs_a_use, b, rs_b, cs_b, ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list ); } else if ( m_partial_pieces == 3 ) @@ -676,7 +843,10 @@ void lpgemm_rowvar_u8s8s32o32_6x16 a, rs_a, cs_a_use, b, rs_b, cs_b, ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list ); } else if ( m_partial_pieces == 2 ) @@ -688,7 +858,10 @@ void lpgemm_rowvar_u8s8s32o32_6x16 a, rs_a, cs_a_use, b, rs_b, cs_b, ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list ); } else if ( m_partial_pieces == 1 ) @@ -700,7 +873,10 @@ void lpgemm_rowvar_u8s8s32o32_6x16 a, rs_a, cs_a_use, b, rs_b, cs_b, ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list ); } } @@ -721,9 +897,19 @@ void lpgemm_rowvar_u8s8s32o32_6x32 int32_t* c, const dim_t rs_c, const int32_t alpha, - const int32_t beta + const int32_t beta, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op* post_ops_list ) { + static void* post_ops_labels[] = + { + &&POST_OPS_6x32_DISABLE, + &&POST_OPS_BIAS_6x32, + &&POST_OPS_RELU_6x32 + }; dim_t MR = 6; dim_t m_full_pieces = m0 / MR; dim_t m_full_pieces_loop_limit = m_full_pieces * MR; @@ -734,6 +920,13 @@ void lpgemm_rowvar_u8s8s32o32_6x32 uint32_t a_kfringe_buf = 0; + // B matrix storage. + __m512i b0; + __m512i b1; + + // A matrix storage. + __m512i a_int32_0; + for ( dim_t ir = 0; ir < m_full_pieces_loop_limit; ir += MR ) { // Registers to use for accumulating C. @@ -762,11 +955,11 @@ void lpgemm_rowvar_u8s8s32o32_6x32 // instructions and each load to ZMM register will have 4 elements // along k direction and 16 elements across n directions, so 4x16 // elements to a ZMM register. - __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); - __m512i b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); // Broadcast a[0,kr:kr+4]. - __m512i a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 4. // c[0,0-31] = a[0,kr:kr+4]*b[kr:kr+4,0-31] @@ -816,8 +1009,8 @@ void lpgemm_rowvar_u8s8s32o32_6x32 // Handle k remainder. if ( k_partial_pieces > 0 ) { - __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); - __m512i b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); // Broadcast a[0,kr:kr+4]. 
memcpy @@ -826,7 +1019,7 @@ void lpgemm_rowvar_u8s8s32o32_6x32 ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - __m512i a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 4. // c[0,0-31] = a[0,kr:kr+4]*b[kr:kr+4,0-31] @@ -990,6 +1183,101 @@ void lpgemm_rowvar_u8s8s32o32_6x32 selector1 = _mm512_mullo_epi32( selector2, selector1 ); c_int32_5p1 = _mm512_add_epi32( selector1, c_int32_5p1 ); } + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_6x32: + { + selector1 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + + // c[0,0-15] + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_add_epi32( selector2, c_int32_0p1 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[1, 16-31] + c_int32_1p1 = _mm512_add_epi32( selector2, c_int32_1p1 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[2, 16-31] + c_int32_2p1 = _mm512_add_epi32( selector2, c_int32_2p1 ); + + // c[3,0-15] + c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); + + // c[3, 16-31] + c_int32_3p1 = _mm512_add_epi32( selector2, c_int32_3p1 ); + + // c[4,0-15] + c_int32_4p0 = _mm512_add_epi32( selector1, c_int32_4p0 ); + + // c[4, 16-31] + c_int32_4p1 = _mm512_add_epi32( selector2, c_int32_4p1 ); + + // c[5,0-15] + c_int32_5p0 = _mm512_add_epi32( selector1, c_int32_5p0 ); + + // c[5, 16-31] + c_int32_5p1 = _mm512_add_epi32( selector2, c_int32_5p1 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_6x32: + { + selector1 = _mm512_setzero_epi32(); + + // c[0,0-15] + c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_max_epi32( selector1, c_int32_0p1 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); + + // c[1,16-31] + c_int32_1p1 = _mm512_max_epi32( selector1, c_int32_1p1 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_max_epi32( selector1, c_int32_2p0 ); + + // c[2,16-31] + c_int32_2p1 = _mm512_max_epi32( selector1, c_int32_2p1 ); + + // c[3,0-15] + c_int32_3p0 = _mm512_max_epi32( selector1, c_int32_3p0 ); + + // c[3,16-31] + c_int32_3p1 = _mm512_max_epi32( selector1, c_int32_3p1 ); + + // c[4,0-15] + c_int32_4p0 = _mm512_max_epi32( selector1, c_int32_4p0 ); + + // c[4,16-31] + c_int32_4p1 = _mm512_max_epi32( selector1, c_int32_4p1 ); + + // c[5,0-15] + c_int32_5p0 = _mm512_max_epi32( selector1, c_int32_5p0 ); + + // c[5,16-31] + c_int32_5p1 = _mm512_max_epi32( selector1, c_int32_5p1 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_6x32_DISABLE: + ; // Store the results. 
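/* B addressing in these kernels (restating the packing comment at the top of
 * the k loop): the reordered B buffer is laid out so that one 64-byte ZMM
 * load covers a 4(k) x 16(n) block of int8 values, rs_b steps to the next
 * group of four k values and cs_b to the next 16-column panel.  That is why
 * the 16-wide kernels issue one B load per k group, this 32-wide kernel two
 * (b0, b1) and the 48-wide kernels three (b0..b2).  The shared address
 * arithmetic, written out as a helper (illustrative name only): */
static inline const int8_t* b_panel_ptr
     ( const int8_t* b, dim_t kr, dim_t panel, dim_t rs_b, dim_t cs_b )
{
    /* 64 int8 values = 4 k-elements x 16 columns per ZMM register. */
    return b + ( rs_b * kr ) + ( cs_b * panel );
}
/* The stores below likewise write two ZMM registers per row of the tile, at
 * column offsets 0*16 and 1*16. */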
// c[0,0-15] @@ -1029,6 +1317,7 @@ void lpgemm_rowvar_u8s8s32o32_6x32 _mm512_storeu_epi32( c + ( rs_c * ( ir + 5 ) ) + ( 1*16 ), c_int32_5p1 ); a = a + ( MR * ps_a ); + post_op_c_i += MR; } if ( m_partial_pieces > 0 ) @@ -1042,7 +1331,10 @@ void lpgemm_rowvar_u8s8s32o32_6x32 a, rs_a, cs_a_use, b, rs_b, cs_b, ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list ); } else if ( m_partial_pieces == 4 ) @@ -1054,7 +1346,10 @@ void lpgemm_rowvar_u8s8s32o32_6x32 a, rs_a, cs_a_use, b, rs_b, cs_b, ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list ); } else if ( m_partial_pieces == 3 ) @@ -1066,7 +1361,10 @@ void lpgemm_rowvar_u8s8s32o32_6x32 a, rs_a, cs_a_use, b, rs_b, cs_b, ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list ); } else if ( m_partial_pieces == 2 ) @@ -1078,7 +1376,10 @@ void lpgemm_rowvar_u8s8s32o32_6x32 a, rs_a, cs_a_use, b, rs_b, cs_b, ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list ); } else if ( m_partial_pieces == 1 ) @@ -1090,7 +1391,10 @@ void lpgemm_rowvar_u8s8s32o32_6x32 a, rs_a, cs_a_use, b, rs_b, cs_b, ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list ); } } @@ -1111,9 +1415,19 @@ void lpgemm_rowvar_u8s8s32o32_6x48 int32_t* c, const dim_t rs_c, const int32_t alpha, - const int32_t beta + const int32_t beta, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op* post_ops_list ) { + static void* post_ops_labels[] = + { + &&POST_OPS_6x48_DISABLE, + &&POST_OPS_BIAS_6x48, + &&POST_OPS_RELU_6x48 + }; dim_t MR = 6; dim_t m_full_pieces = m0 / MR; dim_t m_full_pieces_loop_limit = m_full_pieces * MR; @@ -1124,6 +1438,14 @@ void lpgemm_rowvar_u8s8s32o32_6x48 uint32_t a_kfringe_buf = 0; + // B matrix storage. + __m512i b0; + __m512i b1; + __m512i b2; + + // A matrix storage. + __m512i a_int32_0; + for ( dim_t ir = 0; ir < m_full_pieces_loop_limit; ir += MR ) { // Registers to use for accumulating C. @@ -1158,12 +1480,12 @@ void lpgemm_rowvar_u8s8s32o32_6x48 // instructions and each load to ZMM register will have 4 elements // along k direction and 16 elements across n directions, so 4x16 // elements to a ZMM register. - __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); - __m512i b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); - __m512i b2 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + b0 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi8( b + ( rs_b * kr ) + ( cs_b * 2 ) ); // Broadcast a[0,kr:kr+4]. - __m512i a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + a_int32_0 = _mm512_set1_epi32( *( uint32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 4. // c[0,0-47] = a[0,kr:kr+4]*b[kr:kr+4,0-47] @@ -1219,9 +1541,9 @@ void lpgemm_rowvar_u8s8s32o32_6x48 // Handle k remainder. 
if ( k_partial_pieces > 0 ) { - __m512i b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); - __m512i b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); - __m512i b2 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + b0 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi8( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); // Broadcast a[0,kr:kr+4]. memcpy @@ -1230,7 +1552,7 @@ void lpgemm_rowvar_u8s8s32o32_6x48 ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( uint8_t ) ) ); - __m512i a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_int32_0 = _mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 4. // c[0,0-47] = a[0,kr:kr+4]*b[kr:kr+4,0-47] @@ -1436,6 +1758,141 @@ void lpgemm_rowvar_u8s8s32o32_6x48 selector1 = _mm512_mullo_epi32( selector2, selector1 ); c_int32_5p2 = _mm512_add_epi32( selector1, c_int32_5p2 ); } + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_6x48: + { + selector1 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 2 * 16 ) ); + + // c[0,0-15] + c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_add_epi32( selector2, c_int32_0p1 ); + + // c[0,32-47] + c_int32_0p2 = _mm512_add_epi32( a_int32_0, c_int32_0p2 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_add_epi32( selector1, c_int32_1p0 ); + + // c[1, 16-31] + c_int32_1p1 = _mm512_add_epi32( selector2, c_int32_1p1 ); + + // c[1,32-47] + c_int32_1p2 = _mm512_add_epi32( a_int32_0, c_int32_1p2 ); + + // c[2,0-15] + c_int32_2p0 = _mm512_add_epi32( selector1, c_int32_2p0 ); + + // c[2, 16-31] + c_int32_2p1 = _mm512_add_epi32( selector2, c_int32_2p1 ); + + // c[2,32-47] + c_int32_2p2 = _mm512_add_epi32( a_int32_0, c_int32_2p2 ); + + // c[3,0-15] + c_int32_3p0 = _mm512_add_epi32( selector1, c_int32_3p0 ); + + // c[3, 16-31] + c_int32_3p1 = _mm512_add_epi32( selector2, c_int32_3p1 ); + + // c[3,32-47] + c_int32_3p2 = _mm512_add_epi32( a_int32_0, c_int32_3p2 ); + + // c[4,0-15] + c_int32_4p0 = _mm512_add_epi32( selector1, c_int32_4p0 ); + + // c[4, 16-31] + c_int32_4p1 = _mm512_add_epi32( selector2, c_int32_4p1 ); + + // c[4,32-47] + c_int32_4p2 = _mm512_add_epi32( a_int32_0, c_int32_4p2 ); + + // c[5,0-15] + c_int32_5p0 = _mm512_add_epi32( selector1, c_int32_5p0 ); + + // c[5, 16-31] + c_int32_5p1 = _mm512_add_epi32( selector2, c_int32_5p1 ); + + // c[5,32-47] + c_int32_5p2 = _mm512_add_epi32( a_int32_0, c_int32_5p2 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_6x48: + { + //printf("relu\n"); + selector1 = _mm512_setzero_epi32(); + + // c[0,0-15] + c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); + + // c[0, 16-31] + c_int32_0p1 = _mm512_max_epi32( selector1, c_int32_0p1 ); + + // c[0,32-47] + c_int32_0p2 = _mm512_max_epi32( selector1, c_int32_0p2 ); + + // c[1,0-15] + c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); + + // c[1,16-31] + c_int32_1p1 = _mm512_max_epi32( selector1, c_int32_1p1 ); + + // c[1,32-47] + c_int32_1p2 = _mm512_max_epi32( selector1, c_int32_1p2 ); + + // c[2,0-15] + c_int32_2p0 = 
_mm512_max_epi32( selector1, c_int32_2p0 ); + + // c[2,16-31] + c_int32_2p1 = _mm512_max_epi32( selector1, c_int32_2p1 ); + + // c[2,32-47] + c_int32_2p2 = _mm512_max_epi32( selector1, c_int32_2p2 ); + + // c[3,0-15] + c_int32_3p0 = _mm512_max_epi32( selector1, c_int32_3p0 ); + + // c[3,16-31] + c_int32_3p1 = _mm512_max_epi32( selector1, c_int32_3p1 ); + + // c[3,32-47] + c_int32_3p2 = _mm512_max_epi32( selector1, c_int32_3p2 ); + + // c[4,0-15] + c_int32_4p0 = _mm512_max_epi32( selector1, c_int32_4p0 ); + + // c[4,16-31] + c_int32_4p1 = _mm512_max_epi32( selector1, c_int32_4p1 ); + + // c[4,32-47] + c_int32_4p2 = _mm512_max_epi32( selector1, c_int32_4p2 ); + + // c[5,0-15] + c_int32_5p0 = _mm512_max_epi32( selector1, c_int32_5p0 ); + + // c[5,16-31] + c_int32_5p1 = _mm512_max_epi32( selector1, c_int32_5p1 ); + + // c[5,32-47] + c_int32_5p2 = _mm512_max_epi32( selector1, c_int32_5p2 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_6x48_DISABLE: + ; // Store the results. // c[0,0-15] @@ -1493,6 +1950,7 @@ void lpgemm_rowvar_u8s8s32o32_6x48 _mm512_storeu_epi32( c + ( rs_c * ( ir + 5 ) ) + ( 2*16 ), c_int32_5p2 ); a = a + ( MR * ps_a ); + post_op_c_i += MR; } if ( m_partial_pieces > 0 ) @@ -1506,7 +1964,10 @@ void lpgemm_rowvar_u8s8s32o32_6x48 a, rs_a, cs_a_use, b, rs_b, cs_b, ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list ); } else if ( m_partial_pieces == 4 ) @@ -1518,7 +1979,10 @@ void lpgemm_rowvar_u8s8s32o32_6x48 a, rs_a, cs_a_use, b, rs_b, cs_b, ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list ); } else if ( m_partial_pieces == 3 ) @@ -1530,7 +1994,10 @@ void lpgemm_rowvar_u8s8s32o32_6x48 a, rs_a, cs_a_use, b, rs_b, cs_b, ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list ); } else if ( m_partial_pieces == 2 ) @@ -1542,7 +2009,10 @@ void lpgemm_rowvar_u8s8s32o32_6x48 a, rs_a, cs_a_use, b, rs_b, cs_b, ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list ); } else if ( m_partial_pieces == 1 ) @@ -1554,7 +2024,10 @@ void lpgemm_rowvar_u8s8s32o32_6x48 a, rs_a, cs_a_use, b, rs_b, cs_b, ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, - alpha, beta + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list ); } } diff --git a/bench/bench_aocl_gemm/bench_lpgemm.c b/bench/bench_aocl_gemm/bench_lpgemm.c index 37d6254e7b..90a76e6868 100644 --- a/bench/bench_aocl_gemm/bench_lpgemm.c +++ b/bench/bench_aocl_gemm/bench_lpgemm.c @@ -34,6 +34,21 @@ void fill_array_ ## ctype ( void* arr, dim_t size ) \ GEN_FILL_ARRAY_FUNC(uint8_t) GEN_FILL_ARRAY_FUNC(int8_t) GEN_FILL_ARRAY_FUNC(float) +GEN_FILL_ARRAY_FUNC(int32_t) + +#define GEN_FILL_ARRAY_POST_OPS_FUNC(ctype) \ +void fill_array_post_ops_ ## ctype ( void* arr, dim_t size ) \ +{ \ + ctype* temp_arr = ( ctype* ) arr; \ + for ( dim_t i = 0; i < size; ++i ) \ + { \ + temp_arr[i] = ( ctype )( i % 20 ); \ + } \ +} \ + +GEN_FILL_ARRAY_POST_OPS_FUNC(int16_t) +GEN_FILL_ARRAY_POST_OPS_FUNC(int32_t) +GEN_FILL_ARRAY_POST_OPS_FUNC(float) #define GEN_BLIS_MAT_MUL_FUNC(A_type,B_type,C_type,BLAS_SFX) \ void mat_mul_ ## BLAS_SFX \ @@ -49,7 +64,8 @@ void mat_mul_ ## BLAS_SFX \ dim_t ldb, \ C_type beta, \ C_type* c, \ - dim_t ldc \ + dim_t ldc, \ + aocl_post_op* post_op\ ) \ { \ char transa = 'n'; \ 
@@ -75,7 +91,62 @@ void mat_mul_ ## BLAS_SFX \ a, lda, reordera, \ b, ldb, reorderb, \ beta, \ - c, ldc ); \ + c, ldc, post_op ); \ + \ + /*dim_t MR = 6; \ + dim_t NR = 16; \ + \ + __m512i selector1; \ + __m512i all_zero = _mm512_setzero_epi32(); \ + __m512i c0; \ + __m512i c1; \ + __m512i c2; \ + __m512i c3; \ + __m512i c4; \ + __m512i c5; \ + \ + for ( dim_t i = 0; i < m; i += MR ) \ + { \ + if ( ( i + MR ) > m ) \ + { \ + break; \ + } \ + for ( dim_t j = 0; j < n; j += NR ) \ + { \ + if ( ( j + NR ) > n ) \ + { \ + break; \ + } \ + selector1 = _mm512_loadu_epi32( (int32_t*)post_op->bias.bias + j ); \ + c0 = _mm512_loadu_epi32( c + ( ( i + 0 ) * ldc ) + j ); \ + c1 = _mm512_loadu_epi32( c + ( ( i + 1 ) * ldc ) + j ); \ + c2 = _mm512_loadu_epi32( c + ( ( i + 2 ) * ldc ) + j ); \ + c3 = _mm512_loadu_epi32( c + ( ( i + 3 ) * ldc ) + j ); \ + c4 = _mm512_loadu_epi32( c + ( ( i + 4 ) * ldc ) + j ); \ + c5 = _mm512_loadu_epi32( c + ( ( i + 5 ) * ldc ) + j ); \ + \ + c0 = _mm512_add_epi32( selector1, c0 ); \ + c1 = _mm512_add_epi32( selector1, c1 ); \ + c2 = _mm512_add_epi32( selector1, c2 ); \ + c3 = _mm512_add_epi32( selector1, c3 ); \ + c4 = _mm512_add_epi32( selector1, c4 ); \ + c5 = _mm512_add_epi32( selector1, c5 ); \ + \ + c0 = _mm512_max_epi32( all_zero, c0 ); \ + c1 = _mm512_max_epi32( all_zero, c1 ); \ + c2 = _mm512_max_epi32( all_zero, c2 ); \ + c3 = _mm512_max_epi32( all_zero, c3 ); \ + c4 = _mm512_max_epi32( all_zero, c4 ); \ + c5 = _mm512_max_epi32( all_zero, c5 ); \ + \ + _mm512_storeu_epi32( c + ( ( i + 0 ) * ldc ) + j, c0 ); \ + _mm512_storeu_epi32( c + ( ( i + 1 ) * ldc ) + j, c1 ); \ + _mm512_storeu_epi32( c + ( ( i + 2 ) * ldc ) + j, c2 ); \ + _mm512_storeu_epi32( c + ( ( i + 3 ) * ldc ) + j, c3 ); \ + _mm512_storeu_epi32( c + ( ( i + 4 ) * ldc ) + j, c4 ); \ + _mm512_storeu_epi32( c + ( ( i + 5 ) * ldc ) + j, c5 ); \ + } \ + } */\ } \ GEN_BLIS_MAT_MUL_FUNC(uint8_t, int8_t, int16_t, u8s8s16os16) @@ -127,7 +198,8 @@ void mat_mul_bench_driver_ ## BLAS_SFX \ dim_t ldb, \ C_type beta, \ C_type* c, \ - dim_t ldc \ + dim_t ldc, \ + aocl_post_op* post_op\ ) \ { \ double min_time_diff = DBL_MAX; \ @@ -148,7 +220,8 @@ void mat_mul_bench_driver_ ## BLAS_SFX \ a, lda, \ b, ldb, \ beta, \ - c, ldc \ + c, ldc, \ + post_op \ ); \ \ clock_gettime(CLOCK_MONOTONIC, &tend); \ @@ -182,7 +255,8 @@ void mat_mul_accuracy_check_driver_ ## BLAS_SFX \ C_type* c, \ dim_t ldc, \ C_type* c_ref, \ - dim_t ldc_ref \ + dim_t ldc_ref, \ + aocl_post_op* post_op\ ) \ { \ for ( dim_t i = 0; i < m; ++i ) \ @@ -198,6 +272,34 @@ void mat_mul_accuracy_check_driver_ ## BLAS_SFX \ \ temp_accum = ( beta * ( * (c_ref + ( ldc_ref * i ) + j ) ) ) \ + ( alpha * temp_accum ); \ + \ + if ( post_op != NULL ) \ + { \ + /* Apply bias followed by relu. */ \ + if ( post_op->seq_vector[0] == BIAS ) \ + { \ + if ( post_op->seq_length >= 1 ) \ + { \ + temp_accum += ( *( ( int32_t* )post_op->bias.bias + j ) ); \ + } \ + if ( post_op->seq_length > 1 ) \ + { \ + temp_accum = ( temp_accum > 0 ) ? temp_accum : 0 ; \ + } \ + } \ + else if ( post_op->seq_vector[0] == ELTWISE ) \ + { \ + if ( post_op->seq_length >= 1 ) \ + { \ + temp_accum = ( temp_accum > 0 ) ? 
temp_accum : 0 ; \ + } \ + if ( post_op->seq_length > 1 ) \ + { \ + temp_accum += ( *( ( int32_t* )post_op->bias.bias + j ) ); \ + } \ + } \ + } \ + \ if ( *( c + ( ldc * i ) + j ) != temp_accum ) \ { \ if ( fout ) \ @@ -219,6 +321,100 @@ cleanup_acc: \ GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t, int8_t, int16_t, u8s8s16os16) GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,int32_t,u8s8s32os32) GEN_MAT_MUL_ACC_CHK_DRV_FUNC(float,float,float,f32f32f32of32) + +/* Only supports bias followed by RELU and vice versa for now.*/ \ +#define GEN_MAT_MUL_POST_OPS_CREATOR(C_type,BLAS_SFX) \ +aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \ + ( \ + dim_t m, \ + dim_t n, \ + char* post_ops_str \ + ) \ +{ \ + aocl_post_op* post_ops = NULL; \ + post_ops = ( aocl_post_op* ) malloc( sizeof( aocl_post_op ) ); \ + \ + if ( post_ops == NULL ) \ + { \ + return NULL; \ + } \ + \ + /* Only supporting 2 post ops at max for now.*/ \ + dim_t max_post_ops_seq_length = 2; \ + post_ops->seq_vector = ( AOCL_POST_OP_TYPE* ) \ + malloc \ + ( \ + max_post_ops_seq_length * \ + sizeof( AOCL_POST_OP_TYPE ) \ + ); \ + \ + if ( post_ops->seq_vector == NULL ) \ + { \ + free( post_ops ); \ + return NULL; \ + } \ + \ + /* Parse post ops list.*/ \ + char* ops_tok = strtok(post_ops_str, ", " ); \ + dim_t cur_op_index = 0; \ + while ( ops_tok ) \ + { \ + if ( strcmp( ops_tok, "bias") == 0 ) \ + { \ + post_ops->seq_vector[cur_op_index] = BIAS; \ + } \ + else if ( strcmp( ops_tok, "relu") == 0 ) \ + { \ + post_ops->seq_vector[cur_op_index] = ELTWISE; \ + } \ + ops_tok = strtok( NULL, ", " ); \ + cur_op_index++; \ + } \ + post_ops->seq_length = cur_op_index; \ + \ + /* Allocate bias buffer, return early if alloc fails.*/ \ + post_ops->bias.bias = malloc( n * sizeof( C_type ) ); \ + if ( post_ops->bias.bias == NULL ) \ + { \ + free( post_ops->seq_vector ); \ + free( post_ops ); \ + return NULL; \ + } \ + \ + GEN_FUNC_NAME(fill_array_post_ops_,C_type)( post_ops->bias.bias, n ); \ + \ + post_ops->eltwise.is_power_of_2 = FALSE; \ + post_ops->eltwise.scale_factor = NULL; \ + post_ops->eltwise.algo.alpha = NULL; \ + post_ops->eltwise.algo.beta = NULL; \ + post_ops->eltwise.algo.algo_type = RELU; \ + \ + return post_ops; \ +} \ + +GEN_MAT_MUL_POST_OPS_CREATOR(int16_t,u8s8s16os16) +GEN_MAT_MUL_POST_OPS_CREATOR(int32_t,u8s8s32os32) +GEN_MAT_MUL_POST_OPS_CREATOR(float,f32f32f32of32) + +void lpgemm_destroy_post_ops_struct( aocl_post_op* post_ops ) +{ + if ( post_ops == NULL ) + { + return; + } + + if ( post_ops->bias.bias != NULL ) + { + free( post_ops->bias.bias ); + } + + if( post_ops->seq_vector != NULL ) + { + free( post_ops->seq_vector ); + } + + free( post_ops ); +} #define GEN_MAT_MUL_BENCH_MAIN_FUNC(A_type,B_type,C_type,BLAS_SFX) \ void mat_mul_bench_main_ ## BLAS_SFX \ @@ -231,7 +427,8 @@ void mat_mul_bench_main_ ## BLAS_SFX \ int32_t k, \ int32_t stride_a, \ int32_t stride_b, \ - int32_t stride_c \ + int32_t stride_c, \ + char* post_ops_str \ ) \ { \ if ( ( op_t != 'p' ) && ( op_t != 'P' ) && ( op_t != 'r' ) && ( op_t != 'R' ) ) \ @@ -272,6 +469,17 @@ void mat_mul_bench_main_ ## BLAS_SFX \ \ GEN_FUNC_NAME(fill_array_,A_type)( a, ( m * k ) ); \ GEN_FUNC_NAME(fill_array_,B_type)( b, ( k * n ) ); \ + \ + aocl_post_op* post_op = NULL; \ + if ( post_ops_str != NULL ) \ + { \ + post_op = GEN_FUNC_NAME(lpgemm_create_post_ops_struct_,BLAS_SFX)( m, n, post_ops_str ); \ + if ( post_op == NULL ) \ + { \ + printf(" post op struct allocation failure, returning.\n"); \ + return; \ + } \ + } \ \ if ( ( op_t == 'p' ) || ( op_t == 'P' ) ) \ { \ @@ -283,7 
+491,8 @@ void mat_mul_bench_main_ ## BLAS_SFX \ a, stride_a, \ b, stride_b, \ beta, \ - c, stride_c \ + c, stride_c, \ + post_op \ ); \ } \ else if ( ( op_t == 'r' ) || ( op_t == 'R' ) ) \ @@ -302,7 +511,8 @@ void mat_mul_bench_main_ ## BLAS_SFX \ a, stride_a, \ b_reorder, stride_b, \ beta, \ - c, stride_c \ + c, stride_c, \ + post_op \ ); \ \ bli_free_user( b_reorder ); \ @@ -319,9 +529,12 @@ void mat_mul_bench_main_ ## BLAS_SFX \ b, stride_b, \ beta, \ c, stride_c, \ - c_ref, stride_c \ + c_ref, stride_c, \ + post_op \ ); \ } \ + \ + lpgemm_destroy_post_ops_struct( post_op ); \ \ if ( a != NULL ) \ { \ @@ -350,19 +563,24 @@ int main( int argc, char** argv ) FILE* fin = NULL; if ( argc < 5 ) { - printf( "Usage: ./mat_mul -i input.txt -m mode < -n 1000 >\nMode is either a or p." \ " a is used for accuracy test, whereas p is used for" \ " performance benchmarking.\nn_repeats can be set" \ " optionally using -n argument.\n" ); + printf( "Usage: ./mat_mul -i input.txt -m mode < -n 1000 -o op1,op2.. >" \ "\nMode is either a or p. a is used for accuracy test, " \ " whereas p is used for performance benchmarking." \ "\nn_repeats can be set optionally using -n arg." \ "\nPost ops can be executed optionally by providing a " \ "comma separated list of ops after -o arg.\n Currently " \ "bias,relu and relu,bias are supported.\n" ); exit( 1 ); } char* file_name = NULL; + char* post_ops_str = NULL; + char* post_ops_str_dest = NULL; //Strtok is used to parse, need to maintain a copy. // Parse CLI arguments. opterr = 0; int opt_val; - while ( ( opt_val = getopt( argc, argv, "i:m:n:" ) ) != -1 ) + while ( ( opt_val = getopt( argc, argv, "i:m:n:o:" ) ) != -1 ) { switch ( opt_val ) { @@ -375,11 +593,19 @@ int main( int argc, char** argv ) case 'n': global_n_repeat = ( atoi( optarg ) > 0 ) ?
atoi( optarg ) : 0; break; + case 'o': + post_ops_str = optarg; + break; default: break; } } + if ( post_ops_str != NULL ) + { + post_ops_str_dest = strdup( post_ops_str ); + } + if ( bench_mode == 'p' ) { printf( "Running bench in performance benchmarking mode.\n" ); @@ -411,8 +637,8 @@ int main( int argc, char** argv ) int32_t m, n, k; int32_t stride_a, stride_b, stride_c; - const dim_t len_list_omp_cores_for_testing = 6; - const dim_t list_omp_cores_for_testing[6] = { 100, 80, 64, 24, 8, 1 }; + const dim_t len_list_omp_cores_for_testing = 2; + const dim_t list_omp_cores_for_testing[2] = { 80, 1 }; dim_t core_index = 0; bool can_run = TRUE; @@ -454,7 +680,8 @@ int main( int argc, char** argv ) GEN_FUNC_NAME(mat_mul_bench_main_,u8s8s32os32) ( fin, fout, op_t, - m, n, k, stride_a, stride_b, stride_c + m, n, k, stride_a, stride_b, stride_c, + post_ops_str_dest ); } else if ( ( op_type_char == 'f' ) || ( op_type_char == 'F' ) ) @@ -462,7 +689,8 @@ int main( int argc, char** argv ) GEN_FUNC_NAME(mat_mul_bench_main_,f32f32f32of32) ( fin, fout, op_t, - m, n, k, stride_a, stride_b, stride_c + m, n, k, stride_a, stride_b, stride_c, + NULL ); } else if ((op_type_char == 's') || (op_type_char == 'S')) @@ -470,12 +698,21 @@ int main( int argc, char** argv ) GEN_FUNC_NAME(mat_mul_bench_main_, u8s8s16os16) ( fin, fout, op_t, - m, n, k, stride_a, stride_b, stride_c + m, n, k, stride_a, stride_b, stride_c, + NULL ); } + if ( post_ops_str != NULL ) + { + strcpy( post_ops_str_dest, post_ops_str ); + } } } + if ( post_ops_str_dest != NULL ) + { + free( post_ops_str_dest ); + } if ( fin ) { fclose( fin ); @@ -485,4 +722,4 @@ int main( int argc, char** argv ) fclose( fout ); } return 0; -} \ No newline at end of file +} From 60de0a18568565783cdfb707f8b3858c51681f40 Mon Sep 17 00:00:00 2001 From: Harihara Sudhan S Date: Wed, 3 Aug 2022 14:27:03 +0530 Subject: [PATCH 166/243] Multithreading and support for unpacked B matrix in u8s8s16os16 Functionality - When the B matrix is not reordered before the u8s8s16os16 compute kernel call, packing of the B matrix is done as part of the five loop algorithm. The state of the B matrix (packed or unpacked) is given as a user input. - Packing of the B matrix is done as part of the five loop compute.
- Temporary buffer for packing B is allocated in the five loop algorithm - Multithreading for the computation kernel - Configuration constants for u8s8s16os16 are part of the lpgemm config AMD-Internal: [CPUPL-2171] Change-Id: I22b4f0ec7fc29a2add4be0cff7d75f92dd3e60b8 --- addon/aocl_gemm/aocl_gemm_u8s8s16os16.c | 43 ++- addon/aocl_gemm/aocl_gemm_u8s8s16os16.h | 2 + addon/aocl_gemm/aocl_gemm_u8s8s16os16_utils.c | 40 ++- addon/aocl_gemm/aocl_gemm_u8s8s16os16_utils.h | 28 +- addon/aocl_gemm/frame/lpgemm_config.c | 1 + .../threading/lpgemm_thread_decor_openmp.c | 56 +++ .../threading/lpgemm_thread_decor_openmp.h | 2 + .../frame/u8s8s16/lpgemm_reorder_s16.c | 5 +- .../aocl_gemm/frame/u8s8s16/lpgemm_u8s8s16.c | 323 ++++++++++++++---- .../aocl_gemm/frame/u8s8s16/lpgemm_u8s8s16.h | 36 +- .../kernels/u8s8s16/lpgemm_6x32rowmajor.h | 34 +- .../u8s8s16/lpgemm_6x32rowmajor_amd256.c | 58 ++-- .../kernels/u8s8s16/lpgemm_m_fringe_amd256.c | 112 +++--- .../kernels/u8s8s16/lpgemm_mn_fringe_amd256.c | 68 ++-- .../kernels/u8s8s16/lpgemm_n_fringe_amd256.c | 116 +++---- .../kernels/u8s8s16/lpgemm_packb_amd256.c | 23 +- .../kernels/u8s8s16/lpgemm_packb_s16.h | 24 +- 17 files changed, 637 insertions(+), 334 deletions(-) diff --git a/addon/aocl_gemm/aocl_gemm_u8s8s16os16.c b/addon/aocl_gemm/aocl_gemm_u8s8s16os16.c index ef0b4e227e..59613cfddd 100644 --- a/addon/aocl_gemm/aocl_gemm_u8s8s16os16.c +++ b/addon/aocl_gemm/aocl_gemm_u8s8s16os16.c @@ -38,6 +38,7 @@ #include "lpgemm_u8s8s16.h" #include "lpgemm_config.h" #include "lpgemm_utils.h" +#include "lpgemm_thread_decor_openmp.h" void aocl_gemm_u8s8s16os16 ( @@ -62,6 +63,13 @@ void aocl_gemm_u8s8s16os16 trans_t blis_transa; trans_t blis_transb; + // Check if avx ISA is supported, lpgemm u8s8s16os16 matmul only works with it. + if ( bli_cpuid_is_avx_supported() == FALSE ) + { + printf(" AVX2 ISA not supported by processor, cannot perform lpgemm.\n"); + return; // Error. + } + /* Initialize BLIS. */ bli_init_auto(); @@ -115,7 +123,7 @@ void aocl_gemm_u8s8s16os16 // the mtag_b is set to packed to enable runtime packing. if (mtag_b == UNPACKED) { - return; // Error. + mtag_b = PACK; } // Only unpacked A supported now. @@ -124,16 +132,37 @@ void aocl_gemm_u8s8s16os16 return; // Error. } + // Convert post op struct to post op linked list format. + lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS]; + lpgemm_translate_to_post_ops_list( post_op_unparsed, post_op_list ); + // Initialize a local runtime with global settings if necessary. Note // that in the case that a runtime is passed in, we make a local copy.
rntm_t rntm_g; bli_rntm_init_from_global(&rntm_g); bli_membrk_rntm_set_membrk(&rntm_g); - lpgemm_rowvar_u8s8s16o16( - m, n, k, - a, rs_a, cs_a, - b, rs_b, cs_b, - c, rs_c, - alpha, beta); +#ifdef BLIS_ENABLE_OPENMP + lpgemm_u8s8s16o16_openmp_thread_decorator + ( + m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + c, rs_c, + alpha, beta, + &rntm_g, + post_op_list + ); +#else + lpgemm_u8s8s16o16_thread_decorator + ( + m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + c, rs_c, + alpha, beta, + &rntm_g, + post_op_list + ); +#endif } diff --git a/addon/aocl_gemm/aocl_gemm_u8s8s16os16.h b/addon/aocl_gemm/aocl_gemm_u8s8s16os16.h index 926948aac5..920e8806b0 100644 --- a/addon/aocl_gemm/aocl_gemm_u8s8s16os16.h +++ b/addon/aocl_gemm/aocl_gemm_u8s8s16os16.h @@ -35,6 +35,8 @@ #ifndef AOCL_GEMM_U8S8S16OS16_H #define AOCL_GEMM_U8S8S16OS16_H +#include "aocl_gemm_post_ops.h" + // Only supports matrices in row major format // Limitations: Supports mem_format_b = 'Reorder' BLIS_EXPORT_ADDON void aocl_gemm_u8s8s16os16 diff --git a/addon/aocl_gemm/aocl_gemm_u8s8s16os16_utils.c b/addon/aocl_gemm/aocl_gemm_u8s8s16os16_utils.c index 993dbb2dfb..cbbae09e1a 100644 --- a/addon/aocl_gemm/aocl_gemm_u8s8s16os16_utils.c +++ b/addon/aocl_gemm/aocl_gemm_u8s8s16os16_utils.c @@ -39,16 +39,25 @@ #include "lpgemm_utils.h" #include "lpgemm_reorder_s16.h" -siz_t aocl_get_reorder_buf_size_u8s8s16os16( - const char mat_type, - const dim_t k, - const dim_t n) +siz_t aocl_get_reorder_buf_size_u8s8s16os16 + ( + const char mat_type, + const dim_t k, + const dim_t n + ) { if ((k <= 0) || (n <= 0)) { return 0; // Error. } + // Check if avx ISA is supported, lpgemm u8s8s16os16 matmul only works with it. + if ( bli_cpuid_is_avx_supported() == FALSE ) + { + printf(" AVX2 ISA not supported by processor, cannot perform lpgemm.\n"); + return 0; // Error. + } + /* Initialize BLIS. */ bli_init_auto(); @@ -78,13 +87,15 @@ siz_t aocl_get_reorder_buf_size_u8s8s16os16( return size_req; } -void aocl_reorder_u8s8s16os16( - const char mat_type, - const int8_t *input_buf_addr, - int8_t *reorder_buf_addr, - const dim_t k, - const dim_t n, - const dim_t ldb) +void aocl_reorder_u8s8s16os16 + ( + const char mat_type, + const int8_t *input_buf_addr, + int8_t *reorder_buf_addr, + const dim_t k, + const dim_t n, + const dim_t ldb + ) { if ((input_buf_addr == NULL) || (reorder_buf_addr == NULL) || (k <= 0) || (n <= 0) || (ldb < n)) @@ -92,6 +103,13 @@ void aocl_reorder_u8s8s16os16( return; // Error. } + // Check if avx ISA is supported, lpgemm u8s8s16os16 matmul only works with it. + if ( bli_cpuid_is_avx_supported() == FALSE ) + { + printf(" AVX2 ISA not supported by processor, cannot perform lpgemm.\n"); + return; // Error. + } + /* Initialize BLIS. 
*/ bli_init_auto(); diff --git a/addon/aocl_gemm/aocl_gemm_u8s8s16os16_utils.h b/addon/aocl_gemm/aocl_gemm_u8s8s16os16_utils.h index 5f76da8b38..21f7276be9 100644 --- a/addon/aocl_gemm/aocl_gemm_u8s8s16os16_utils.h +++ b/addon/aocl_gemm/aocl_gemm_u8s8s16os16_utils.h @@ -35,17 +35,21 @@ #ifndef AOCL_GEMM_U8S8S16OS16_UTILS_H #define AOCL_GEMM_U8S8S16OS16_UTILS_H -BLIS_EXPORT_ADDON siz_t aocl_get_reorder_buf_size_u8s8s16os16( - const char mat_type, - const dim_t k, - const dim_t n); - -BLIS_EXPORT_ADDON void aocl_reorder_u8s8s16os16( - const char mat_type, - const int8_t *input_buf_addr, - int8_t *reorder_buf_addr, - const dim_t k, - const dim_t n, - const dim_t ldb); +BLIS_EXPORT_ADDON siz_t aocl_get_reorder_buf_size_u8s8s16os16 + ( + const char mat_type, + const dim_t k, + const dim_t n + ); + +BLIS_EXPORT_ADDON void aocl_reorder_u8s8s16os16 + ( + const char mat_type, + const int8_t *input_buf_addr, + int8_t *reorder_buf_addr, + const dim_t k, + const dim_t n, + const dim_t ldb + ); #endif // AOCL_GEMM_U8S8S16OS16_UTILS_H \ No newline at end of file diff --git a/addon/aocl_gemm/frame/lpgemm_config.c b/addon/aocl_gemm/frame/lpgemm_config.c index a16147e4c1..a23fd409f1 100644 --- a/addon/aocl_gemm/frame/lpgemm_config.c +++ b/addon/aocl_gemm/frame/lpgemm_config.c @@ -60,6 +60,7 @@ BLIS_INLINE void lpgemm_set_block_sizes_global_cntx void aocl_lpgemm_init_global_cntx() { lpgemm_set_block_sizes_global_cntx( U8S8S32OS32, 144, 1024, 2048, 64, 6 ); + lpgemm_set_block_sizes_global_cntx( U8S8S16OS16, 144, 1024, 1024, 32, 6 ); } dim_t lpgemm_get_block_size_MC_global_cntx( AOCL_OPERATION_TYPE op_type ) diff --git a/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.c b/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.c index d35e04009f..afba56d6b1 100644 --- a/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.c +++ b/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.c @@ -36,6 +36,7 @@ #include "lpgemm_config.h" #include "lpgemm_thread_decor_openmp.h" #include "lpgemm_types.h" +#include "lpgemm_u8s8s16.h" #include "lpgemm_u8s8s32.h" #include "lpgemm_f32f32f32.h" @@ -222,6 +223,59 @@ BLIS_INLINE void lpgemm_adjust_ic_jc_ways } } +BLIS_INLINE void lpgemm_u8s8s16o16_get_threading + ( + dim_t* n_threads, + dim_t* ic_ways, + dim_t* jc_ways, + dim_t m, + dim_t n, + dim_t k, + rntm_t* rntm_g + ) +{ + *n_threads = bli_rntm_num_threads( rntm_g ); + *jc_ways = bli_rntm_jc_ways( rntm_g ); + *ic_ways = bli_rntm_ic_ways( rntm_g ); + + if ( ( ( *ic_ways ) > 0 ) || ( ( *jc_ways ) > 0 ) ) + { + // If BLIS_IC_NT or JC_NT are set. + // Default cases. + *ic_ways = ( ( *ic_ways ) > 0 ) ? ( *ic_ways ) : 1; + *jc_ways = ( ( *jc_ways ) > 0 ) ? ( *jc_ways ) : 1; + + *n_threads = ( *jc_ways ) * ( *ic_ways ); + } + else if ( ( *n_threads ) > 1 ) + { + + dim_t NR = 32; + //dim_t MR = 6; + + if ( n <= NR ) + { + // If n is less than micro panel dimension, allocating all threads + // to ic resulted in gains. + ( *ic_ways ) = ( *n_threads ); + ( *jc_ways ) = 1; + } + else + { + // If BLIS_NUM_THREADS are set, generate jc,ic from the same. + bli_thread_partition_2x2( ( *n_threads ), m, n, ic_ways, jc_ways ); + } + } + else + { + // Setting all the values to 1 in case n_threads <= 1. This ensures + // the threading parameters are valid. 
+ *n_threads = 1; + *jc_ways = 1; + *ic_ways = 1; + } +} + BLIS_INLINE void lpgemm_u8s8s32o32_get_threading ( dim_t* n_threads, @@ -443,6 +497,7 @@ void lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \ } \ } \ +GEN_LPGEMM_OPENMP_DECORATOR(uint8_t,int8_t,int16_t,u8s8s16o16) GEN_LPGEMM_OPENMP_DECORATOR(uint8_t,int8_t,int32_t,u8s8s32o32) GEN_LPGEMM_OPENMP_DECORATOR(float,float,float,f32f32f32of32) @@ -510,6 +565,7 @@ void lpgemm_ ## LPGEMM_SFX ## _thread_decorator \ ); \ } \ +GEN_LPGEMM_DECORATOR(uint8_t,int8_t,int32_t,u8s8s16o16) GEN_LPGEMM_DECORATOR(uint8_t,int8_t,int32_t,u8s8s32o32) GEN_LPGEMM_DECORATOR(float,float,float,f32f32f32of32) diff --git a/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.h b/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.h index 51a5941481..2b420b2f9c 100644 --- a/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.h +++ b/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.h @@ -62,6 +62,7 @@ void lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \ lpgemm_post_op* post_op_list \ ); \ +GEN_LPGEMM_OPENMP_DECORATOR_FN(uint8_t,int8_t,int16_t,u8s8s16o16) GEN_LPGEMM_OPENMP_DECORATOR_FN(uint8_t,int8_t,int32_t,u8s8s32o32) GEN_LPGEMM_OPENMP_DECORATOR_FN(float,float,float,f32f32f32of32) @@ -89,6 +90,7 @@ void lpgemm_ ## LPGEMM_SFX ## _thread_decorator \ lpgemm_post_op* post_op_list \ ); \ +GEN_LPGEMM_DECORATOR_FN(uint8_t,int8_t,int32_t,u8s8s16o16) GEN_LPGEMM_DECORATOR_FN(uint8_t,int8_t,int32_t,u8s8s32o32) GEN_LPGEMM_DECORATOR_FN(float,float,float,f32f32f32of32) diff --git a/addon/aocl_gemm/frame/u8s8s16/lpgemm_reorder_s16.c b/addon/aocl_gemm/frame/u8s8s16/lpgemm_reorder_s16.c index e900559943..f471244423 100644 --- a/addon/aocl_gemm/frame/u8s8s16/lpgemm_reorder_s16.c +++ b/addon/aocl_gemm/frame/u8s8s16/lpgemm_reorder_s16.c @@ -43,9 +43,8 @@ void aocl_reorderb_nr32_u8s8s16o16 lpgemm_obj_t *b_reorder ) { - // To Do: Constant declaration's to be moved to config - const dim_t NC = 1024; - const dim_t KC = 1024; + const dim_t NC = lpgemm_get_block_size_NC_global_cntx(U8S8S16OS16); + const dim_t KC = lpgemm_get_block_size_KC_global_cntx(U8S8S16OS16); // Extracting the matrix properties from the lpgemm object dim_t rs_b = b->rs; diff --git a/addon/aocl_gemm/frame/u8s8s16/lpgemm_u8s8s16.c b/addon/aocl_gemm/frame/u8s8s16/lpgemm_u8s8s16.c index d8e725a755..2926661eac 100644 --- a/addon/aocl_gemm/frame/u8s8s16/lpgemm_u8s8s16.c +++ b/addon/aocl_gemm/frame/u8s8s16/lpgemm_u8s8s16.c @@ -9,14 +9,14 @@ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. 
+ - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -34,75 +34,252 @@ #include "blis.h" #include "lpgemm_u8s8s16.h" -#include "lpgemm_packb.h" +#include "lpgemm_packb_s16.h" #include "lpgemm_6x32rowmajor.h" #include "lpgemm_utils.h" #include "lpgemm_config.h" +#include "lpgemm_thrinfo_utils.h" void lpgemm_rowvar_u8s8s16o16 - ( - const dim_t m, - const dim_t n, - const dim_t k, - const uint8_t *a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t *b, - const dim_t rs_b, - const dim_t cs_b, - int16_t *c, - const dim_t rs_c, - int16_t alpha, - int16_t beta - ) + ( + const dim_t m, + const dim_t n, + const dim_t k, + const uint8_t *a, + const dim_t rs_a, + const dim_t cs_a, + const AOCL_MEMORY_TAG mtag_a, + const int8_t *b, + const dim_t rs_b, + const dim_t cs_b, + const AOCL_MEMORY_TAG mtag_b, + int16_t *c, + const dim_t rs_c, + int16_t alpha, + int16_t beta, + rntm_t *rntm, + lpgemm_thrinfo_t *thread, + lpgemm_post_op* post_op_list + ) { - // To Do: Constant declaration's to be moved to config files - dim_t NC = 1024; - dim_t KC = 1024; - dim_t MC = 144; - dim_t NR = 32; - - const int8_t *b_use; - const uint8_t *a_use; - - for (dim_t jc = 0; jc < n; jc += NC) - { - dim_t nc0 = ((jc + NC) <= n) ? NC : (n % NC); - - for (dim_t pc = 0; pc < k; pc += KC) - { - int32_t beta0 = (pc == 0) ? beta : 1; - dim_t kc0 = ((pc + KC) <= k) ? KC : (k % KC); - - int kc0_updated = kc0; - - // Making multiple of 2 to suit k in vpmaddubsw - kc0_updated += (kc0_updated & 0x1); - - // B part getting processed - b_use = b + (jc * k) + (pc * nc0); - - for (dim_t ic = 0; ic < m; ic += MC) - { - dim_t mc0 = ((ic + MC) <= m) ? MC : (m % MC); - - a_use = a + (rs_a * ic) + (cs_a * pc); - - dim_t a_block_stride = rs_a; - - for (dim_t jr = 0; jr < nc0; jr += NR) - { - dim_t nr0 = ((jr + NR) <= nc0) ? NR : (nc0 % NR); - - // Calls for reorder B - lpgemm_rowvar_u8s8s16o16_6x32( - mc0, nr0, kc0, - a_use, rs_a, cs_a, a_block_stride, - (b_use + (jr * kc0_updated)), rs_b, cs_b, - (c + (rs_c * ic) + jc + jr), rs_c, 1, - alpha, beta0); - } - } - } - } + const dim_t NC = lpgemm_get_block_size_NC_global_cntx( U8S8S16OS16 ); + const dim_t KC = lpgemm_get_block_size_KC_global_cntx( U8S8S16OS16 ); + const dim_t MC = lpgemm_get_block_size_MC_global_cntx( U8S8S16OS16 ); + const dim_t NR = lpgemm_get_block_size_NR_global_cntx( U8S8S16OS16 ); + const dim_t MR = lpgemm_get_block_size_MR_global_cntx( U8S8S16OS16 ); + + if (mtag_b == UNPACKED) + { + // Error: can only work with packed B now. + return; + } + + const int8_t *b_use; + const uint8_t *a_use; + dim_t rs_a_use = rs_a; + dim_t cs_a_use = cs_a; + + dim_t rs_b_use = rs_b; + dim_t cs_b_use = cs_b; + + int16_t *c_use_jc = NULL; + int16_t *c_use_ic = NULL; + + // Pack buffer for B. + int8_t *pack_b_buffer_u8s8s16o16; + mem_t mem_b = BLIS_MEM_INITIALIZER; + dim_t packb_min_NR = 16; + siz_t mem_b_size_req = 0; + + // Making multiple of 2 to suit k in vpmaddubsw + dim_t k_updated = make_multiple_of_n( k, 2 ); + + // Generate thrinfo objects for jc and ic loops from lpgemm_thrinfo_t. + thrinfo_t thread_jc; + thrinfo_t thread_ic; + + lpgemm_gen_thrinfo(thread, &thread_jc, &thread_ic); + + // Compute the JC loop thread range for the current thread. 
+ dim_t jc_start, jc_end; + bli_thread_range_sub(&thread_jc, n, NR, FALSE, &jc_start, &jc_end); + + for (dim_t jc = jc_start; jc < jc_end; jc += NC) + { + dim_t nc0 = bli_min((jc_end - jc), NC); + c_use_jc = c + jc; + + dim_t jc_cur_loop = jc; + dim_t jc_cur_loop_rem = 0; + dim_t n_sub_updated = 0; + + if (mtag_b == REORDERED) + { + get_B_panel_reordered_start_offset_width + ( + jc, n, NC, packb_min_NR, + &jc_cur_loop, &jc_cur_loop_rem, + &nc0, &n_sub_updated + ); + } + + for (dim_t pc = 0; pc < k; pc += KC) + { + int16_t beta0 = (pc == 0) ? beta : 1; + dim_t kc0 = bli_min((k - pc), KC); + + // kc0 needs to be a multiple of 2 so that it can be + // used with vpmaddubsw instruction. Padding is added in + // cases this condition is not satisfied, and therefore + // the kc0 offsets used for packed/reordered buffers + // needs to be updated. + dim_t kc0_updated = make_multiple_of_n(kc0, 2); + + if (mtag_b == PACK) + { + // Pack B chunks are based on jc work id. + dim_t jc_work_id = bli_thread_work_id(&thread_jc); + + // Using child thrinfo (thread_ic) tid to decide chief thread + // per B matrix chunk (jc work id group) + if (bli_thread_am_ochief(&thread_ic)) + { + // nc0 needs to be a multiple of 16 since this gives maximum + // vectorization. Packing B always results in buffers with width + // which is a multiple of 16. Subsequently the nc0 offsets used + // for packed/reordered buffers needs to be updated. + dim_t nc0_updated = make_multiple_of_n(nc0, packb_min_NR); + mem_b_size_req = sizeof(int8_t) * nc0_updated * kc0_updated; + + lpgemm_alloc_mem_panel( + mem_b_size_req, BLIS_BUFFER_FOR_B_PANEL, + &mem_b, rntm); + + thread->comm[jc_work_id].sent_object = + bli_mem_buffer(&mem_b); + } + + // All threads in work group should wait till chief thread has + // finished allocating the packing buffers. + bli_thrcomm_barrier( + bli_thread_ocomm_id(&thread_ic), + &thread->comm[jc_work_id]); + + pack_b_buffer_u8s8s16o16 = + (int8_t *)thread->comm[jc_work_id].sent_object; + + // Compute the B panel per thread loop range for parallel + // packing using ic_ways number of threads. Since atmost only + // ic_ways threads can be used, the thread_ic attributes are + // used to split the loop range. + dim_t jc_packb_start, jc_packb_end; + bli_thread_range_sub + ( + &thread_ic, nc0, NR, FALSE, + &jc_packb_start, &jc_packb_end + ); + + // Ensure thread ranges are valid, especially cases where no: + // of threads available for parallelization are greater than + // no: of B panel NR chunks. + if ((jc_packb_end > jc_packb_start) && + (jc_packb_start < (jc + nc0))) + { + packb_nr32_u8s8s16o16 + ( + pack_b_buffer_u8s8s16o16 + + (jc_packb_start * kc0_updated), + (b + (rs_b * pc) + (cs_b * jc) + + (cs_b * jc_packb_start)), + rs_b, + (jc_packb_end - jc_packb_start), kc0, + &rs_b_use, &cs_b_use + ); + } + else + { + get_packb_nr32_u8s8s16o16_strides(&rs_b_use, &cs_b_use); + } + + // All threads in work group should wait till B matrix packing + // is completed by the participating threads. + bli_thrcomm_barrier + ( + bli_thread_ocomm_id(&thread_ic), + &thread->comm[jc_work_id] + ); + + b_use = pack_b_buffer_u8s8s16o16; + } + else if (mtag_b == REORDERED) + { + // In multi-threaded scenarios, an extra offset into a given + // packed B panel is required, since the jc loop split can + // result in per thread start offset inside the panel, instead + // of panel boundaries. 
+ b_use = b + (jc_cur_loop * k_updated) + + (n_sub_updated * pc) + + (jc_cur_loop_rem * kc0_updated); + + get_packb_nr32_u8s8s16o16_strides(&rs_b_use, &cs_b_use); + } + else + { + // Unpacked B not supported. + return; + } + + dim_t ic_start, ic_end; + bli_thread_range_sub(&thread_ic, m, MR, FALSE, &ic_start, &ic_end); + + for (dim_t ic = ic_start; ic < ic_end; ic += MC) + { + dim_t mc0 = bli_min((ic_end - ic), MC); + c_use_ic = c_use_jc + (rs_c * ic); + + a_use = a + (rs_a * ic) + (cs_a * pc); + cs_a_use = 1; + + dim_t a_block_stride = rs_a; + + for (dim_t jr = 0; jr < nc0; jr += NR) + { + dim_t nr0 = bli_min((nc0 - jr), NR); + + // Calls for reorder B + lpgemm_rowvar_u8s8s16o16_6x32 + ( + mc0, nr0, kc0, + a_use, rs_a_use, cs_a_use, a_block_stride, + (b_use + (jr * kc0_updated)), rs_b_use, cs_b_use, + (c_use_ic + +jr), rs_c, 1, + alpha, beta0 + ); + } + } + } + + if (mtag_b == REORDERED) + { + adjust_B_panel_reordered_jc(&jc, jc_cur_loop); + } + } + + // Release pack buffers. + if (mtag_b == PACK) + { + // All threads in work group should wait till B matrix usage is + // completed by the participating threads. + bli_thrcomm_barrier( + bli_thread_ocomm_id(&thread_jc), + &thread->comm[bli_thread_work_id(&thread_jc)]); + + if (bli_thread_am_ochief(&thread_ic)) + { + if (bli_mem_is_alloc(&mem_b)) + { + bli_membrk_release(rntm, &mem_b); + } + } + } } diff --git a/addon/aocl_gemm/frame/u8s8s16/lpgemm_u8s8s16.h b/addon/aocl_gemm/frame/u8s8s16/lpgemm_u8s8s16.h index 4802c90c29..ff1fcba8d8 100644 --- a/addon/aocl_gemm/frame/u8s8s16/lpgemm_u8s8s16.h +++ b/addon/aocl_gemm/frame/u8s8s16/lpgemm_u8s8s16.h @@ -37,22 +37,28 @@ #include "blis.h" #include "lpgemm_types.h" +#include "lpgemm_post_ops.h" void lpgemm_rowvar_u8s8s16o16 - ( - const dim_t m, - const dim_t n, - const dim_t k, - const uint8_t *a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t *b, - const dim_t rs_b, - const dim_t cs_b, - int16_t *c, - const dim_t rs_c, - int16_t alpha, - int16_t beta - ); + ( + const dim_t m, + const dim_t n, + const dim_t k, + const uint8_t *a, + const dim_t rs_a, + const dim_t cs_a, + const AOCL_MEMORY_TAG mtag_a, + const int8_t *b, + const dim_t rs_b, + const dim_t cs_b, + const AOCL_MEMORY_TAG mtag_b, + int16_t *c, + const dim_t rs_c, + int16_t alpha, + int16_t beta, + rntm_t *rntm, + lpgemm_thrinfo_t *thread, + lpgemm_post_op* post_op_list + ); #endif // LPGEMM_U8S8S16_H \ No newline at end of file diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_6x32rowmajor.h b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_6x32rowmajor.h index dca3170d3f..163241367b 100644 --- a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_6x32rowmajor.h +++ b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_6x32rowmajor.h @@ -37,22 +37,22 @@ // 6x32 int8o16 kernel void lpgemm_rowvar_u8s8s16o16_6x32 - ( - const dim_t m0, - const dim_t n0, - const dim_t k0, - const uint8_t *a, - const dim_t rs_a, - const dim_t cs_a, - const dim_t ps_a, - const int8_t *b, - const dim_t rs_b, - const dim_t cs_b, - int16_t *c, - const dim_t rs_c, - const dim_t cs_c, - const int16_t alpha, - const int16_t beta - ); + ( + const dim_t m0, + const dim_t n0, + const dim_t k0, + const uint8_t *a, + const dim_t rs_a, + const dim_t cs_a, + const dim_t ps_a, + const int8_t *b, + const dim_t rs_b, + const dim_t cs_b, + int16_t *c, + const dim_t rs_c, + const dim_t cs_c, + const int16_t alpha, + const int16_t beta + ); #endif // BLIS_GEMM_INT16_MNROW \ No newline at end of file diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_6x32rowmajor_amd256.c 
b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_6x32rowmajor_amd256.c index 3c2c8e90cb..1c83dc9ddc 100644 --- a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_6x32rowmajor_amd256.c +++ b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_6x32rowmajor_amd256.c @@ -40,21 +40,21 @@ // 6x32 int8o16 kernel void lpgemm_rowvar_u8s8s16o16_6x32 ( - const dim_t m0, - const dim_t n0, - const dim_t k0, - const uint8_t *a, - const dim_t rs_a, - const dim_t cs_a, - const dim_t ps_a, - const int8_t *b, - const dim_t rs_b, - const dim_t cs_b, - int16_t *c, - const dim_t rs_c, - const dim_t cs_c, - const int16_t alpha, - const int16_t beta + const dim_t m0, + const dim_t n0, + const dim_t k0, + const uint8_t *a, + const dim_t rs_a, + const dim_t cs_a, + const dim_t ps_a, + const int8_t *b, + const dim_t rs_b, + const dim_t cs_b, + int16_t *c, + const dim_t rs_c, + const dim_t cs_c, + const int16_t alpha, + const int16_t beta ) { dim_t MR = 6; @@ -78,7 +78,7 @@ void lpgemm_rowvar_u8s8s16o16_6x32 dim_t k0_updated = k0; // Making multiple of 2 to suit k in vpmaddubsw - k0_updated += (k0_updated & 0x1); + k0_updated += (k0_updated & 0x1); if (n0_16 == 1) { @@ -220,13 +220,13 @@ void lpgemm_rowvar_u8s8s16o16_6x32 // Handle k remainder. if (k_partial_pieces > 0) { - uint8_t a_element[6]; + uint8_t a_kfringe; b0 = _mm256_loadu_si256((__m256i const *)(b + (64 * k_full_pieces) + (NR * 0))); b1 = _mm256_loadu_si256((__m256i const *)(b + (64 * k_full_pieces) + (NR * 1))); - a_element[0] = *(a + (rs_a * 0) + (cs_a * (k_full_pieces * 2))); - a_int32_0 = _mm256_set1_epi8(a_element[0]); + a_kfringe = *(a + (rs_a * 0) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_kfringe); // Seperate register for intermediate op inter_vec[0] = _mm256_maddubs_epi16(a_int32_0, b0); @@ -237,8 +237,8 @@ void lpgemm_rowvar_u8s8s16o16_6x32 c_int16_0p0 = _mm256_add_epi16(inter_vec[0], c_int16_0p0); c_int16_0p1 = _mm256_add_epi16(inter_vec[1], c_int16_0p1); - a_element[1] = *(a + (rs_a * 1) + (cs_a * (k_full_pieces * 2))); - a_int32_1 = _mm256_set1_epi8(a_element[1]); + a_kfringe = *(a + (rs_a * 1) + (cs_a * (k_full_pieces * 2))); + a_int32_1 = _mm256_set1_epi8(a_kfringe); // Seperate register for intermediate op inter_vec[2] = _mm256_maddubs_epi16(a_int32_1, b0); @@ -249,8 +249,8 @@ void lpgemm_rowvar_u8s8s16o16_6x32 c_int16_1p0 = _mm256_add_epi16(inter_vec[2], c_int16_1p0); c_int16_1p1 = _mm256_add_epi16(inter_vec[3], c_int16_1p1); - a_element[2] = *(a + (rs_a * 2) + (cs_a * (k_full_pieces * 2))); - a_int32_0 = _mm256_set1_epi8(a_element[2]); + a_kfringe = *(a + (rs_a * 2) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_kfringe); // Seperate register for intermediate op inter_vec[0] = _mm256_maddubs_epi16(a_int32_0, b0); @@ -261,8 +261,8 @@ void lpgemm_rowvar_u8s8s16o16_6x32 c_int16_2p0 = _mm256_add_epi16(inter_vec[0], c_int16_2p0); c_int16_2p1 = _mm256_add_epi16(inter_vec[1], c_int16_2p1); - a_element[3] = *(a + (rs_a * 3) + (cs_a * (k_full_pieces * 2))); - a_int32_1 = _mm256_set1_epi8(a_element[3]); + a_kfringe = *(a + (rs_a * 3) + (cs_a * (k_full_pieces * 2))); + a_int32_1 = _mm256_set1_epi8(a_kfringe); // Seperate register for intermediate op inter_vec[2] = _mm256_maddubs_epi16(a_int32_1, b0); @@ -273,8 +273,8 @@ void lpgemm_rowvar_u8s8s16o16_6x32 c_int16_3p0 = _mm256_add_epi16(inter_vec[2], c_int16_3p0); c_int16_3p1 = _mm256_add_epi16(inter_vec[3], c_int16_3p1); - a_element[4] = *(a + (rs_a * 4) + (cs_a * (k_full_pieces * 2))); - a_int32_0 = _mm256_set1_epi8(a_element[4]); + a_kfringe = *(a + (rs_a * 4) + (cs_a * (k_full_pieces 
* 2))); + a_int32_0 = _mm256_set1_epi8(a_kfringe); // Seperate register for intermediate op inter_vec[0] = _mm256_maddubs_epi16(a_int32_0, b0); @@ -285,8 +285,8 @@ void lpgemm_rowvar_u8s8s16o16_6x32 c_int16_4p0 = _mm256_add_epi16(inter_vec[0], c_int16_4p0); c_int16_4p1 = _mm256_add_epi16(inter_vec[1], c_int16_4p1); - a_element[5] = *(a + (rs_a * 5) + (cs_a * (k_full_pieces * 2))); - a_int32_1 = _mm256_set1_epi8(a_element[5]); + a_kfringe = *(a + (rs_a * 5) + (cs_a * (k_full_pieces * 2))); + a_int32_1 = _mm256_set1_epi8(a_kfringe); // Seperate register for intermediate op inter_vec[2] = _mm256_maddubs_epi16(a_int32_1, b0); diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_m_fringe_amd256.c b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_m_fringe_amd256.c index 115e3c2a4e..56561ff972 100644 --- a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_m_fringe_amd256.c +++ b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_m_fringe_amd256.c @@ -38,20 +38,18 @@ #include "lpgemm_m_fringe_s16.h" // 4x32 int8o16 kernel -void lpgemm_rowvar_u8s8s16o16_4x32 - ( - const dim_t k0, - const uint8_t *a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t *b, - const dim_t rs_b, - const dim_t cs_b, - int16_t *c, - const dim_t rs_c, - const int16_t alpha, - const int16_t beta - ) +void lpgemm_rowvar_u8s8s16o16_4x32( + const dim_t k0, + const uint8_t *a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t *b, + const dim_t rs_b, + const dim_t cs_b, + int16_t *c, + const dim_t rs_c, + const int16_t alpha, + const int16_t beta) { dim_t NR = 32; @@ -140,13 +138,13 @@ void lpgemm_rowvar_u8s8s16o16_4x32 // Handle k remainder. if (k_partial_pieces > 0) { - uint8_t a_element[4]; + uint8_t a_kfringe; b0 = _mm256_loadu_si256((__m256i const *)(b + (64 * k_full_pieces) + (NR * 0))); b1 = _mm256_loadu_si256((__m256i const *)(b + (64 * k_full_pieces) + (NR * 1))); - a_element[0] = *(a + (rs_a * 0) + (cs_a * (k_full_pieces * 2))); - a_int32_0 = _mm256_set1_epi8(a_element[0]); + a_kfringe = *(a + (rs_a * 0) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_kfringe); // Seperate register for intermediate op inter_vec[0] = _mm256_maddubs_epi16(a_int32_0, b0); @@ -157,8 +155,8 @@ void lpgemm_rowvar_u8s8s16o16_4x32 c_int16_0p0 = _mm256_add_epi16(inter_vec[0], c_int16_0p0); c_int16_0p1 = _mm256_add_epi16(inter_vec[1], c_int16_0p1); - a_element[1] = *(a + (rs_a * 1) + (cs_a * (k_full_pieces * 2))); - a_int32_1 = _mm256_set1_epi8(a_element[1]); + a_kfringe = *(a + (rs_a * 1) + (cs_a * (k_full_pieces * 2))); + a_int32_1 = _mm256_set1_epi8(a_kfringe); // Seperate register for intermediate op inter_vec[2] = _mm256_maddubs_epi16(a_int32_1, b0); @@ -169,8 +167,8 @@ void lpgemm_rowvar_u8s8s16o16_4x32 c_int16_1p0 = _mm256_add_epi16(inter_vec[2], c_int16_1p0); c_int16_1p1 = _mm256_add_epi16(inter_vec[3], c_int16_1p1); - a_element[2] = *(a + (rs_a * 2) + (cs_a * (k_full_pieces * 2))); - a_int32_0 = _mm256_set1_epi8(a_element[2]); + a_kfringe = *(a + (rs_a * 2) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_kfringe); // Seperate register for intermediate op inter_vec[0] = _mm256_maddubs_epi16(a_int32_0, b0); @@ -181,8 +179,8 @@ void lpgemm_rowvar_u8s8s16o16_4x32 c_int16_2p0 = _mm256_add_epi16(inter_vec[0], c_int16_2p0); c_int16_2p1 = _mm256_add_epi16(inter_vec[1], c_int16_2p1); - a_element[3] = *(a + (rs_a * 3) + (cs_a * (k_full_pieces * 2))); - a_int32_1 = _mm256_set1_epi8(a_element[3]); + a_kfringe = *(a + (rs_a * 3) + (cs_a * (k_full_pieces * 2))); + a_int32_1 = _mm256_set1_epi8(a_kfringe); // Seperate register for 
intermediate op inter_vec[2] = _mm256_maddubs_epi16(a_int32_1, b0); @@ -282,20 +280,18 @@ void lpgemm_rowvar_u8s8s16o16_4x32 } // 2x32 int8o16 kernel -void lpgemm_rowvar_u8s8s16o16_2x32 - ( - const dim_t k0, - const uint8_t *a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t *b, - const dim_t rs_b, - const dim_t cs_b, - int16_t *c, - const dim_t rs_c, - const int16_t alpha, - const int16_t beta - ) +void lpgemm_rowvar_u8s8s16o16_2x32( + const dim_t k0, + const uint8_t *a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t *b, + const dim_t rs_b, + const dim_t cs_b, + int16_t *c, + const dim_t rs_c, + const int16_t alpha, + const int16_t beta) { dim_t NR = 32; @@ -353,13 +349,13 @@ void lpgemm_rowvar_u8s8s16o16_2x32 // Handle k remainder. if (k_partial_pieces > 0) { - uint8_t a_element[2]; + uint8_t a_kfringe; b0 = _mm256_loadu_si256((__m256i const *)(b + (64 * k_full_pieces) + (NR * 0))); b1 = _mm256_loadu_si256((__m256i const *)(b + (64 * k_full_pieces) + (NR * 1))); - a_element[0] = *(a + (rs_a * 0) + (cs_a * (k_full_pieces * 2))); - a_int32_0 = _mm256_set1_epi8(a_element[0]); + a_kfringe = *(a + (rs_a * 0) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_kfringe); // Seperate register for intermediate op inter_vec[0] = _mm256_maddubs_epi16(a_int32_0, b0); @@ -370,8 +366,8 @@ void lpgemm_rowvar_u8s8s16o16_2x32 c_int16_0p0 = _mm256_add_epi16(inter_vec[0], c_int16_0p0); c_int16_0p1 = _mm256_add_epi16(inter_vec[1], c_int16_0p1); - a_element[1] = *(a + (rs_a * 1) + (cs_a * (k_full_pieces * 2))); - a_int32_1 = _mm256_set1_epi8(a_element[1]); + a_kfringe = *(a + (rs_a * 1) + (cs_a * (k_full_pieces * 2))); + a_int32_1 = _mm256_set1_epi8(a_kfringe); // Seperate register for intermediate op inter_vec[2] = _mm256_maddubs_epi16(a_int32_1, b0); @@ -433,20 +429,18 @@ void lpgemm_rowvar_u8s8s16o16_2x32 } // 1x32 int8o16 kernel -void lpgemm_rowvar_u8s8s16o16_1x32 - ( - const dim_t k0, - const uint8_t *a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t *b, - const dim_t rs_b, - const dim_t cs_b, - int16_t *c, - const dim_t rs_c, - const int16_t alpha, - const int16_t beta - ) +void lpgemm_rowvar_u8s8s16o16_1x32( + const dim_t k0, + const uint8_t *a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t *b, + const dim_t rs_b, + const dim_t cs_b, + int16_t *c, + const dim_t rs_c, + const int16_t alpha, + const int16_t beta) { dim_t NR = 32; @@ -488,13 +482,13 @@ void lpgemm_rowvar_u8s8s16o16_1x32 // Handle k remainder. if (k_partial_pieces > 0) { - uint8_t a_element[1]; + uint8_t a_kfringe; b0 = _mm256_loadu_si256((__m256i const *)(b + (64 * k_full_pieces) + (NR * 0))); b1 = _mm256_loadu_si256((__m256i const *)(b + (64 * k_full_pieces) + (NR * 1))); - a_element[0] = *(a + (rs_a * 0) + (cs_a * (k_full_pieces * 2))); - a_int32_0 = _mm256_set1_epi8(a_element[0]); + a_kfringe = *(a + (rs_a * 0) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_kfringe); // Seperate register for intermediate op inter_vec[0] = _mm256_maddubs_epi16(a_int32_0, b0); diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_mn_fringe_amd256.c b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_mn_fringe_amd256.c index dba8792cc1..a494da8160 100644 --- a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_mn_fringe_amd256.c +++ b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_mn_fringe_amd256.c @@ -122,12 +122,12 @@ void lpgemm_rowvar_u8s8s16o16_4x16 // Handle k remainder. 
if (k_partial_pieces > 0) { - uint8_t a_element[4]; + uint8_t a_kfringe; b0 = _mm256_loadu_si256((__m256i const *)(b + (32 * k_full_pieces) + (NR * 0))); - a_element[0] = *(a + (rs_a * 0) + (cs_a * (k_full_pieces * 2))); - a_int32_0 = _mm256_set1_epi8(a_element[0]); + a_kfringe = *(a + (rs_a * 0) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_kfringe); // Seperate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); @@ -136,8 +136,8 @@ void lpgemm_rowvar_u8s8s16o16_4x16 // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-31] c_int16_0p0 = _mm256_add_epi16(inter_vec, c_int16_0p0); - a_element[1] = *(a + (rs_a * 1) + (cs_a * (k_full_pieces * 2))); - a_int32_0 = _mm256_set1_epi8(a_element[1]); + a_kfringe = *(a + (rs_a * 1) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_kfringe); // Seperate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); @@ -146,8 +146,8 @@ void lpgemm_rowvar_u8s8s16o16_4x16 // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-31] c_int16_1p0 = _mm256_add_epi16(inter_vec, c_int16_1p0); - a_element[2] = *(a + (rs_a * 2) + (cs_a * (k_full_pieces * 2))); - a_int32_0 = _mm256_set1_epi8(a_element[2]); + a_kfringe = *(a + (rs_a * 2) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_kfringe); // Seperate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); @@ -156,8 +156,8 @@ void lpgemm_rowvar_u8s8s16o16_4x16 // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-31] c_int16_2p0 = _mm256_add_epi16(inter_vec, c_int16_2p0); - a_element[3] = *(a + (rs_a * 3) + (cs_a * (k_full_pieces * 2))); - a_int32_0 = _mm256_set1_epi8(a_element[3]); + a_kfringe = *(a + (rs_a * 3) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_kfringe); // Seperate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); @@ -312,12 +312,12 @@ void lpgemm_rowvar_u8s8s16o16_4xlt16 // Handle k remainder. 
if (k_partial_pieces > 0) { - uint8_t a_element[4]; + uint8_t a_kfringe; b0 = _mm256_loadu_si256((__m256i const *)(b + (32 * k_full_pieces) + (NR * 0))); - a_element[0] = *(a + (rs_a * 0) + (cs_a * (k_full_pieces * 2))); - a_int32_0 = _mm256_set1_epi8(a_element[0]); + a_kfringe = *(a + (rs_a * 0) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_kfringe); // Seperate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); @@ -326,8 +326,8 @@ void lpgemm_rowvar_u8s8s16o16_4xlt16 // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-31] c_int16_0p0 = _mm256_add_epi16(inter_vec, c_int16_0p0); - a_element[1] = *(a + (rs_a * 1) + (cs_a * (k_full_pieces * 2))); - a_int32_0 = _mm256_set1_epi8(a_element[1]); + a_kfringe = *(a + (rs_a * 1) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_kfringe); // Seperate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); @@ -336,8 +336,8 @@ void lpgemm_rowvar_u8s8s16o16_4xlt16 // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-31] c_int16_1p0 = _mm256_add_epi16(inter_vec, c_int16_1p0); - a_element[2] = *(a + (rs_a * 2) + (cs_a * (k_full_pieces * 2))); - a_int32_0 = _mm256_set1_epi8(a_element[2]); + a_kfringe = *(a + (rs_a * 2) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_kfringe); // Seperate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); @@ -346,8 +346,8 @@ void lpgemm_rowvar_u8s8s16o16_4xlt16 // c[2,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-31] c_int16_2p0 = _mm256_add_epi16(inter_vec, c_int16_2p0); - a_element[3] = *(a + (rs_a * 3) + (cs_a * (k_full_pieces * 2))); - a_int32_0 = _mm256_set1_epi8(a_element[3]); + a_kfringe = *(a + (rs_a * 3) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_kfringe); // Seperate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); @@ -486,12 +486,12 @@ void lpgemm_rowvar_u8s8s16o16_2x16 // Handle k remainder. if (k_partial_pieces > 0) { - uint8_t a_element[2]; + uint8_t a_kfringe; b0 = _mm256_loadu_si256((__m256i const *)(b + (32 * k_full_pieces) + (NR * 0))); - a_element[0] = *(a + (rs_a * 0) + (cs_a * (k_full_pieces * 2))); - a_int32_0 = _mm256_set1_epi8(a_element[0]); + a_kfringe = *(a + (rs_a * 0) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_kfringe); // Seperate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); @@ -500,8 +500,8 @@ void lpgemm_rowvar_u8s8s16o16_2x16 // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-31] c_int16_0p0 = _mm256_add_epi16(inter_vec, c_int16_0p0); - a_element[1] = *(a + (rs_a * 1) + (cs_a * (k_full_pieces * 2))); - a_int32_0 = _mm256_set1_epi8(a_element[1]); + a_kfringe = *(a + (rs_a * 1) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_kfringe); // Seperate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); @@ -609,12 +609,12 @@ void lpgemm_rowvar_u8s8s16o16_2xlt16 // Handle k remainder. 
if (k_partial_pieces > 0) { - uint8_t a_element[4]; + uint8_t a_kfringe; b0 = _mm256_loadu_si256((__m256i const *)(b + (32 * k_full_pieces) + (NR * 0))); - a_element[0] = *(a + (rs_a * 0) + (cs_a * (k_full_pieces * 2))); - a_int32_0 = _mm256_set1_epi8(a_element[0]); + a_kfringe = *(a + (rs_a * 0) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_kfringe); // Seperate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); @@ -623,8 +623,8 @@ void lpgemm_rowvar_u8s8s16o16_2xlt16 // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-31] c_int16_0p0 = _mm256_add_epi16(inter_vec, c_int16_0p0); - a_element[1] = *(a + (rs_a * 1) + (cs_a * (k_full_pieces * 2))); - a_int32_0 = _mm256_set1_epi8(a_element[1]); + a_kfringe = *(a + (rs_a * 1) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_kfringe); // Seperate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); @@ -724,12 +724,12 @@ void lpgemm_rowvar_u8s8s16o16_1x16 // Handle k remainder. if (k_partial_pieces > 0) { - uint8_t a_element[1]; + uint8_t a_kfringe; b0 = _mm256_loadu_si256((__m256i const *)(b + (64 * k_full_pieces) + (NR * 0))); - a_element[0] = *(a + (rs_a * 0) + (cs_a * (k_full_pieces * 2))); - a_int32_0 = _mm256_set1_epi8(a_element[0]); + a_kfringe = *(a + (rs_a * 0) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_kfringe); // Seperate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); @@ -814,12 +814,12 @@ void lpgemm_rowvar_u8s8s16o16_1xlt16 // Handle k remainder. if (k_partial_pieces > 0) { - uint8_t a_element[4]; + uint8_t a_kfringe; b0 = _mm256_loadu_si256((__m256i const *)(b + (32 * k_full_pieces) + (NR * 0))); - a_element[0] = *(a + (rs_a * 0) + (cs_a * (k_full_pieces * 2))); - a_int32_0 = _mm256_set1_epi8(a_element[0]); + a_kfringe = *(a + (rs_a * 0) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_kfringe); // Seperate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_n_fringe_amd256.c b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_n_fringe_amd256.c index d5cec8f0fc..8fb2cdec1f 100644 --- a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_n_fringe_amd256.c +++ b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_n_fringe_amd256.c @@ -39,22 +39,20 @@ #include "lpgemm_mn_fringe_s16.h" // 6x16 int8o16 kernel -void lpgemm_rowvar_u8s8s16o16_6x16 - ( - const dim_t m0, - const dim_t k0, - const uint8_t *a, - const dim_t rs_a, - const dim_t cs_a, - const dim_t ps_a, - const int8_t *b, - const dim_t rs_b, - const dim_t cs_b, - int16_t *c, - const dim_t rs_c, - const int16_t alpha, - const int16_t beta - ) +void lpgemm_rowvar_u8s8s16o16_6x16( + const dim_t m0, + const dim_t k0, + const uint8_t *a, + const dim_t rs_a, + const dim_t cs_a, + const dim_t ps_a, + const int8_t *b, + const dim_t rs_b, + const dim_t cs_b, + int16_t *c, + const dim_t rs_c, + const int16_t alpha, + const int16_t beta) { dim_t MR = 6; dim_t NR = 16; @@ -159,12 +157,12 @@ void lpgemm_rowvar_u8s8s16o16_6x16 // Handle k remainder. 
if (k_partial_pieces > 0) { - uint8_t a_element[6]; + uint8_t a_kfringe; b0 = _mm256_loadu_si256((__m256i const *)(b + (32 * k_full_pieces) + (NR * 0))); - a_element[0] = *(a + (rs_a * 0) + (cs_a * (k_full_pieces * 2))); - a_int32_0 = _mm256_set1_epi8(a_element[0]); + a_kfringe = *(a + (rs_a * 0) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_kfringe); // Seperate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); @@ -173,8 +171,8 @@ void lpgemm_rowvar_u8s8s16o16_6x16 // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-31] c_int16_0p0 = _mm256_add_epi16(inter_vec, c_int16_0p0); - a_element[1] = *(a + (rs_a * 1) + (cs_a * (k_full_pieces * 2))); - a_int32_0 = _mm256_set1_epi8(a_element[1]); + a_kfringe = *(a + (rs_a * 1) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_kfringe); // Seperate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); @@ -183,8 +181,8 @@ void lpgemm_rowvar_u8s8s16o16_6x16 // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-31] c_int16_1p0 = _mm256_add_epi16(inter_vec, c_int16_1p0); - a_element[2] = *(a + (rs_a * 2) + (cs_a * (k_full_pieces * 2))); - a_int32_0 = _mm256_set1_epi8(a_element[2]); + a_kfringe = *(a + (rs_a * 2) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_kfringe); // Seperate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); @@ -193,8 +191,8 @@ void lpgemm_rowvar_u8s8s16o16_6x16 // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-31] c_int16_2p0 = _mm256_add_epi16(inter_vec, c_int16_2p0); - a_element[3] = *(a + (rs_a * 3) + (cs_a * (k_full_pieces * 2))); - a_int32_0 = _mm256_set1_epi8(a_element[3]); + a_kfringe = *(a + (rs_a * 3) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_kfringe); // Seperate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); @@ -203,8 +201,8 @@ void lpgemm_rowvar_u8s8s16o16_6x16 // c[3,0-15] = a[3,kr:kr+2]*b[kr:kr+2,0-31] c_int16_3p0 = _mm256_add_epi16(inter_vec, c_int16_3p0); - a_element[4] = *(a + (rs_a * 4) + (cs_a * (k_full_pieces * 2))); - a_int32_0 = _mm256_set1_epi8(a_element[4]); + a_kfringe = *(a + (rs_a * 4) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_kfringe); // Seperate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); @@ -213,8 +211,8 @@ void lpgemm_rowvar_u8s8s16o16_6x16 // c[4,0-15] = a[4,kr:kr+2]*b[kr:kr+2,0-31] c_int16_4p0 = _mm256_add_epi16(inter_vec, c_int16_4p0); - a_element[5] = *(a + (rs_a * 5) + (cs_a * (k_full_pieces * 2))); - a_int32_0 = _mm256_set1_epi8(a_element[5]); + a_kfringe = *(a + (rs_a * 5) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_kfringe); // Seperate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); @@ -346,23 +344,21 @@ void lpgemm_rowvar_u8s8s16o16_6x16 } // 6xlt16 int8o16 kernel -void lpgemm_rowvar_u8s8s16o16_6xlt16 - ( - const dim_t m0, - const dim_t k0, - const uint8_t *a, - const dim_t rs_a, - const dim_t cs_a, - const dim_t ps_a, - const int8_t *b, - const dim_t rs_b, - const dim_t cs_b, - int16_t *c, - const dim_t rs_c, - const int16_t alpha, - const int16_t beta, - const dim_t n0_rem - ) +void lpgemm_rowvar_u8s8s16o16_6xlt16( + const dim_t m0, + const dim_t k0, + const uint8_t *a, + const dim_t rs_a, + const dim_t cs_a, + const dim_t ps_a, + const int8_t *b, + const dim_t rs_b, + const dim_t cs_b, + int16_t *c, + const dim_t rs_c, + const int16_t alpha, + const int16_t beta, + const dim_t n0_rem) { dim_t MR = 6; @@ -473,12 +469,12 @@ void 
lpgemm_rowvar_u8s8s16o16_6xlt16 // Handle k remainder. if (k_partial_pieces > 0) { - uint8_t a_element[6]; + uint8_t a_kfringe; b0 = _mm256_loadu_si256((__m256i const *)(b + (32 * k_full_pieces) + (cs_b * 0))); - a_element[0] = *(a + (rs_a * 0) + (cs_a * (k_full_pieces * 2))); - a_int32_0 = _mm256_set1_epi8(a_element[0]); + a_kfringe = *(a + (rs_a * 0) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_kfringe); inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); @@ -486,8 +482,8 @@ void lpgemm_rowvar_u8s8s16o16_6xlt16 // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-31] c_int16_0p0 = _mm256_add_epi16(inter_vec, c_int16_0p0); - a_element[1] = *(a + (rs_a * 1) + (cs_a * (k_full_pieces * 2))); - a_int32_0 = _mm256_set1_epi8(a_element[1]); + a_kfringe = *(a + (rs_a * 1) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_kfringe); // Seperate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); @@ -496,8 +492,8 @@ void lpgemm_rowvar_u8s8s16o16_6xlt16 // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-31] c_int16_1p0 = _mm256_add_epi16(inter_vec, c_int16_1p0); - a_element[2] = *(a + (rs_a * 2) + (cs_a * (k_full_pieces * 2))); - a_int32_0 = _mm256_set1_epi8(a_element[2]); + a_kfringe = *(a + (rs_a * 2) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_kfringe); // Seperate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); @@ -506,8 +502,8 @@ void lpgemm_rowvar_u8s8s16o16_6xlt16 // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-31] c_int16_2p0 = _mm256_add_epi16(inter_vec, c_int16_2p0); - a_element[3] = *(a + (rs_a * 3) + (cs_a * (k_full_pieces * 2))); - a_int32_0 = _mm256_set1_epi8(a_element[3]); + a_kfringe = *(a + (rs_a * 3) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_kfringe); // Seperate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); @@ -516,8 +512,8 @@ void lpgemm_rowvar_u8s8s16o16_6xlt16 // c[3,0-15] = a[3,kr:kr+2]*b[kr:kr+2,0-31] c_int16_3p0 = _mm256_add_epi16(inter_vec, c_int16_3p0); - a_element[4] = *(a + (rs_a * 4) + (cs_a * (k_full_pieces * 2))); - a_int32_0 = _mm256_set1_epi8(a_element[4]); + a_kfringe = *(a + (rs_a * 4) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_kfringe); // Seperate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); @@ -526,8 +522,8 @@ void lpgemm_rowvar_u8s8s16o16_6xlt16 // c[4,0-15] = a[4,kr:kr+2]*b[kr:kr+2,0-31] c_int16_4p0 = _mm256_add_epi16(inter_vec, c_int16_4p0); - a_element[5] = *(a + (rs_a * 5) + (cs_a * (k_full_pieces * 2))); - a_int32_0 = _mm256_set1_epi8(a_element[5]); + a_kfringe = *(a + (rs_a * 5) + (cs_a * (k_full_pieces * 2))); + a_int32_0 = _mm256_set1_epi8(a_kfringe); // Seperate register for intermediate op inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); @@ -594,7 +590,7 @@ void lpgemm_rowvar_u8s8s16o16_6xlt16 selector1 = _mm256_mullo_epi16(selector2, selector1); c_int16_5p0 = _mm256_add_epi16(selector1, c_int16_5p0); } - + // Store the results. 
// c[0,0-15] _mm256_storeu_si256((__m256i_u *)buf0, c_int16_0p0); diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_packb_amd256.c b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_packb_amd256.c index 1d0e1471b1..1f6ec5a787 100644 --- a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_packb_amd256.c +++ b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_packb_amd256.c @@ -37,6 +37,14 @@ #include "blis.h" #include "lpgemm_packb_s16.h" +void get_packb_nr32_u8s8s16o16_strides( + dim_t *rs_b, + dim_t *cs_b) +{ + *rs_b = 32 * 2; + *cs_b = 32; +} + void packb_nrlt16_u8s8s16o16( int8_t *pack_b_buffer_u8s8s16o16, const int8_t *b, @@ -172,7 +180,7 @@ void packb_nr32_u8s8s16o16( dim_t KC_updated = rows; // Making multiple of 2 to suit k in vpmaddubsw - KC_updated += (KC_updated & 0x1); + KC_updated += (KC_updated & 0x1); __m256i b_vec[2], inter_vec[2]; @@ -182,10 +190,10 @@ void packb_nr32_u8s8s16o16( { // Read b[0,0], b[0,1], b[0,2]......., b[0,31] b_vec[0] = _mm256_loadu_si256((__m256i const *)(b + (ldb * (kr + 0)) + jc)); - + // Read b[1,0], b[1,1], b[1,2]......., b[1,31] b_vec[1] = _mm256_loadu_si256((__m256i const *)(b + (ldb * (kr + 1)) + jc)); - + // Reorder B matrix inputs to suit vpmaddubsw instructions inter_vec[0] = _mm256_unpacklo_epi8(b_vec[0], b_vec[1]); inter_vec[1] = _mm256_unpackhi_epi8(b_vec[0], b_vec[1]); @@ -233,7 +241,8 @@ void packb_nr32_u8s8s16o16( if (n0_16 == 1) { packb_nr16_u8s8s16o16( - (pack_b_buffer_u8s8s16o16 + (n_full_pieces_loop_limit * KC_updated)), + (pack_b_buffer_u8s8s16o16 + + (n_full_pieces_loop_limit * KC_updated)), (b + n_full_pieces_loop_limit), ldb, rows); n0_partial_pack = 16; @@ -242,8 +251,10 @@ void packb_nr32_u8s8s16o16( if (n0_partial_rem > 0) { packb_nrlt16_u8s8s16o16( - (pack_b_buffer_u8s8s16o16 + (n_full_pieces_loop_limit * KC_updated) + (n0_partial_pack * KC_updated)), - (b + n_full_pieces_loop_limit + n0_partial_pack), ldb, rows, n0_partial_rem); + (pack_b_buffer_u8s8s16o16 + (n_full_pieces_loop_limit * KC_updated) + + (n0_partial_pack * KC_updated)), + (b + n_full_pieces_loop_limit + n0_partial_pack), + ldb, rows, n0_partial_rem); } } diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_packb_s16.h b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_packb_s16.h index 31d8465dac..b8d73c862c 100644 --- a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_packb_s16.h +++ b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_packb_s16.h @@ -35,13 +35,21 @@ #ifndef BLIS_GEMM_INT16_PACKB #define BLIS_GEMM_INT16_PACKB -void packb_nr32_u8s8s16o16( - int8_t *pack_b_buffer_u8s8s16o16, - const int8_t *b, - const dim_t ldb, - const dim_t cols, - const dim_t rows, - dim_t *rs_b, - dim_t *cs_b); +void get_packb_nr32_u8s8s16o16_strides + ( + dim_t* rs_b, + dim_t* cs_b + ); + +void packb_nr32_u8s8s16o16 + ( + int8_t *pack_b_buffer_u8s8s16o16, + const int8_t *b, + const dim_t ldb, + const dim_t cols, + const dim_t rows, + dim_t *rs_b, + dim_t *cs_b + ); #endif // BLIS_GEMM_INT16_PACKB \ No newline at end of file From d1eaf65a26fefea4fa327ad83ee5a10ec48eaf7a Mon Sep 17 00:00:00 2001 From: Harihara Sudhan S Date: Mon, 8 Aug 2022 15:18:41 +0530 Subject: [PATCH 167/243] Post-Ops for u8s8s16os16 Functionality - Post-ops is an operation performed on every element of the output matrix after GEMM operation is completed. 
- Post-ops relu and bias added to all the compute kernels of u8s8s16os16 - Post-ops are done on the value loaded into the register to avoid reloading of C matrix elements - Minor bug fixes in openmp thread decorator of lpgemm - Added test cases to lpgemm bench input file AMD-Internal: [CPUPL-2171] Change-Id: If49f763fdfac19749f6665c172348691165d8631 --- addon/aocl_gemm/aocl_gemm_u8s8s16os16.c | 1 + .../threading/lpgemm_thread_decor_openmp.c | 5 +- .../aocl_gemm/frame/u8s8s16/lpgemm_u8s8s16.c | 10 +- .../kernels/u8s8s16/lpgemm_6x32rowmajor.h | 36 ++- .../u8s8s16/lpgemm_6x32rowmajor_amd256.c | 188 ++++++++++-- .../kernels/u8s8s16/lpgemm_m_fringe_amd256.c | 290 +++++++++++++++--- .../kernels/u8s8s16/lpgemm_m_fringe_s16.h | 92 +++--- .../kernels/u8s8s16/lpgemm_mn_fringe_amd256.c | 290 +++++++++++++++++- .../kernels/u8s8s16/lpgemm_mn_fringe_s16.h | 188 +++++++----- .../kernels/u8s8s16/lpgemm_n_fringe_amd256.c | 246 ++++++++++++--- .../kernels/u8s8s16/lpgemm_n_fringe_s16.h | 71 +++-- .../kernels/u8s8s16/lpgemm_packb_amd256.c | 27 +- bench/bench_aocl_gemm/bench_input.txt | 150 +++++++++ 13 files changed, 1289 insertions(+), 305 deletions(-) diff --git a/addon/aocl_gemm/aocl_gemm_u8s8s16os16.c b/addon/aocl_gemm/aocl_gemm_u8s8s16os16.c index 59613cfddd..96232a8947 100644 --- a/addon/aocl_gemm/aocl_gemm_u8s8s16os16.c +++ b/addon/aocl_gemm/aocl_gemm_u8s8s16os16.c @@ -39,6 +39,7 @@ #include "lpgemm_config.h" #include "lpgemm_utils.h" #include "lpgemm_thread_decor_openmp.h" +#include "lpgemm_post_ops.h" void aocl_gemm_u8s8s16os16 ( diff --git a/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.c b/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.c index afba56d6b1..1521fe7c5d 100644 --- a/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.c +++ b/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.c @@ -250,8 +250,7 @@ BLIS_INLINE void lpgemm_u8s8s16o16_get_threading else if ( ( *n_threads ) > 1 ) { - dim_t NR = 32; - //dim_t MR = 6; + dim_t NR = lpgemm_get_block_size_NR_global_cntx( U8S8S16OS16 ); if ( n <= NR ) { @@ -565,7 +564,7 @@ void lpgemm_ ## LPGEMM_SFX ## _thread_decorator \ ); \ } \ -GEN_LPGEMM_DECORATOR(uint8_t,int8_t,int32_t,u8s8s16o16) +GEN_LPGEMM_DECORATOR(uint8_t,int8_t,int16_t,u8s8s16o16) GEN_LPGEMM_DECORATOR(uint8_t,int8_t,int32_t,u8s8s32o32) GEN_LPGEMM_DECORATOR(float,float,float,f32f32f32of32) diff --git a/addon/aocl_gemm/frame/u8s8s16/lpgemm_u8s8s16.c b/addon/aocl_gemm/frame/u8s8s16/lpgemm_u8s8s16.c index 2926661eac..98fa45e966 100644 --- a/addon/aocl_gemm/frame/u8s8s16/lpgemm_u8s8s16.c +++ b/addon/aocl_gemm/frame/u8s8s16/lpgemm_u8s8s16.c @@ -94,6 +94,9 @@ void lpgemm_rowvar_u8s8s16o16 // Making multiple of 2 to suit k in vpmaddubsw dim_t k_updated = make_multiple_of_n( k, 2 ); + // Is required to decide whether to apply post ops or not. + bool is_last_k = FALSE; + // Generate thrinfo objects for jc and ic loops from lpgemm_thrinfo_t. thrinfo_t thread_jc; thrinfo_t thread_ic; @@ -128,6 +131,8 @@ void lpgemm_rowvar_u8s8s16o16 int16_t beta0 = (pc == 0) ? beta : 1; dim_t kc0 = bli_min((k - pc), KC); + is_last_k = ( ( pc + KC ) >= k ) ? ( TRUE ) : ( FALSE ); + // kc0 needs to be a multiple of 2 so that it can be // used with vpmaddubsw instruction. 
Padding is added in // cases this condition is not satisfied, and therefore @@ -252,8 +257,9 @@ void lpgemm_rowvar_u8s8s16o16 mc0, nr0, kc0, a_use, rs_a_use, cs_a_use, a_block_stride, (b_use + (jr * kc0_updated)), rs_b_use, cs_b_use, - (c_use_ic + +jr), rs_c, 1, - alpha, beta0 + (c_use_ic + jr), rs_c, 1, + alpha, beta0, + is_last_k, ic, ( jc + jr ), post_op_list ); } } diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_6x32rowmajor.h b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_6x32rowmajor.h index 163241367b..2dd00f4494 100644 --- a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_6x32rowmajor.h +++ b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_6x32rowmajor.h @@ -35,24 +35,30 @@ #ifndef BLIS_GEMM_INT16_MNROW #define BLIS_GEMM_INT16_MNROW +#include "lpgemm_post_ops.h" + // 6x32 int8o16 kernel void lpgemm_rowvar_u8s8s16o16_6x32 ( - const dim_t m0, - const dim_t n0, - const dim_t k0, - const uint8_t *a, - const dim_t rs_a, - const dim_t cs_a, - const dim_t ps_a, - const int8_t *b, - const dim_t rs_b, - const dim_t cs_b, - int16_t *c, - const dim_t rs_c, - const dim_t cs_c, - const int16_t alpha, - const int16_t beta + const dim_t m0, + const dim_t n0, + const dim_t k0, + const uint8_t *a, + const dim_t rs_a, + const dim_t cs_a, + const dim_t ps_a, + const int8_t *b, + const dim_t rs_b, + const dim_t cs_b, + int16_t *c, + const dim_t rs_c, + const dim_t cs_c, + const int16_t alpha, + const int16_t beta, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op *post_ops_list ); #endif // BLIS_GEMM_INT16_MNROW \ No newline at end of file diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_6x32rowmajor_amd256.c b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_6x32rowmajor_amd256.c index 1c83dc9ddc..f2cbd7affe 100644 --- a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_6x32rowmajor_amd256.c +++ b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_6x32rowmajor_amd256.c @@ -40,23 +40,33 @@ // 6x32 int8o16 kernel void lpgemm_rowvar_u8s8s16o16_6x32 ( - const dim_t m0, - const dim_t n0, - const dim_t k0, - const uint8_t *a, - const dim_t rs_a, - const dim_t cs_a, - const dim_t ps_a, - const int8_t *b, - const dim_t rs_b, - const dim_t cs_b, - int16_t *c, - const dim_t rs_c, - const dim_t cs_c, - const int16_t alpha, - const int16_t beta + const dim_t m0, + const dim_t n0, + const dim_t k0, + const uint8_t *a, + const dim_t rs_a, + const dim_t cs_a, + const dim_t ps_a, + const int8_t *b, + const dim_t rs_b, + const dim_t cs_b, + int16_t *c, + const dim_t rs_c, + const dim_t cs_c, + const int16_t alpha, + const int16_t beta, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op *post_ops_list ) { + static void *post_ops_labels[] = + { + &&POST_OPS_6x32_DISABLE, + &&POST_OPS_BIAS_6x32, + &&POST_OPS_RELU_6x32}; + dim_t MR = 6; dim_t NR = 32; @@ -87,7 +97,10 @@ void lpgemm_rowvar_u8s8s16o16_6x32 a, rs_a, cs_a, ps_a, b, ((rs_b / 2) * 1), cs_b, c, rs_c, - alpha, beta); + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list); b = b + (16 * k0_updated); c = c + 16; @@ -100,7 +113,10 @@ void lpgemm_rowvar_u8s8s16o16_6x32 a, rs_a, cs_a, ps_a, b, ((rs_b / 2) * 1), cs_b, c, rs_c, - alpha, beta, n0_rem); + alpha, beta, n0_rem, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list); } // If fringe cases are encountered, return early @@ -385,44 +401,140 @@ void lpgemm_rowvar_u8s8s16o16_6x32 c_int16_5p1 = _mm256_add_epi16(selector1, c_int16_5p1); } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_6x32: + { + selector1 = + 
_mm256_loadu_si256( (__m256i const *)((int16_t *)post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 )) ); + selector2 = + _mm256_loadu_si256( (__m256i const *)((int16_t *)post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 )) ); + + // c[0,0-15] + c_int16_0p0 = _mm256_add_epi16( selector1, c_int16_0p0 ); + + // c[0, 16-31] + c_int16_0p1 = _mm256_add_epi16( selector2, c_int16_0p1 ); + + // c[1,0-15] + c_int16_1p0 = _mm256_add_epi16( selector1, c_int16_1p0 ); + + // c[1, 16-31] + c_int16_1p1 = _mm256_add_epi16( selector2, c_int16_1p1 ); + + // c[2,0-15] + c_int16_2p0 = _mm256_add_epi16( selector1, c_int16_2p0 ); + + // c[2, 16-31] + c_int16_2p1 = _mm256_add_epi16( selector2, c_int16_2p1 ); + + // c[3,0-15] + c_int16_3p0 = _mm256_add_epi16( selector1, c_int16_3p0 ); + + // c[3, 16-31] + c_int16_3p1 = _mm256_add_epi16( selector2, c_int16_3p1 ); + + // c[4,0-15] + c_int16_4p0 = _mm256_add_epi16( selector1, c_int16_4p0 ); + + // c[4, 16-31] + c_int16_4p1 = _mm256_add_epi16( selector2, c_int16_4p1 ); + + // c[5,0-15] + c_int16_5p0 = _mm256_add_epi16( selector1, c_int16_5p0 ); + + // c[5, 16-31] + c_int16_5p1 = _mm256_add_epi16( selector2, c_int16_5p1 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_6x32: + { + selector1 = _mm256_setzero_si256 (); + + // c[0,0-15] + c_int16_0p0 = _mm256_max_epi16( selector1, c_int16_0p0 ); + + // c[0, 16-31] + c_int16_0p1 = _mm256_max_epi16( selector1, c_int16_0p1 ); + + // c[1,0-15] + c_int16_1p0 = _mm256_max_epi16( selector1, c_int16_1p0 ); + + // c[1,16-31] + c_int16_1p1 = _mm256_max_epi16( selector1, c_int16_1p1 ); + + // c[2,0-15] + c_int16_2p0 = _mm256_max_epi16( selector1, c_int16_2p0 ); + + // c[2,16-31] + c_int16_2p1 = _mm256_max_epi16( selector1, c_int16_2p1 ); + + // c[3,0-15] + c_int16_3p0 = _mm256_max_epi16( selector1, c_int16_3p0 ); + + // c[3,16-31] + c_int16_3p1 = _mm256_max_epi16( selector1, c_int16_3p1 ); + + // c[4,0-15] + c_int16_4p0 = _mm256_max_epi16( selector1, c_int16_4p0 ); + + // c[4,16-31] + c_int16_4p1 = _mm256_max_epi16( selector1, c_int16_4p1 ); + + // c[5,0-15] + c_int16_5p0 = _mm256_max_epi16( selector1, c_int16_5p0 ); + + // c[5,16-31] + c_int16_5p1 = _mm256_max_epi16( selector1, c_int16_5p1 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_6x32_DISABLE: + ; + // Store the results. 
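	// The accumulators stored below already carry the alpha/beta scaling and
	// any bias/relu post-ops applied above, so each C tile is written exactly
	// once and is not reloaded after the post-ops.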
// c[0,0-15] - _mm256_storeu_si256((__m256i *)(c + (rs_c * (ir + 0)) + (0 * 16)), c_int16_0p0); + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * ( ir + 0 ) ) + ( 0*16 )), c_int16_0p0 ); // c[0, 16-31] - _mm256_storeu_si256((__m256i *)(c + (rs_c * (ir + 0)) + (1 * 16)), c_int16_0p1); + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * ( ir + 0 ) ) + ( 1*16 )), c_int16_0p1 ); // c[1,0-15] - _mm256_storeu_si256((__m256i *)(c + (rs_c * (ir + 1)) + (0 * 16)), c_int16_1p0); + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * ( ir + 1 ) ) + ( 0*16 )), c_int16_1p0 ); // c[1,16-31] - _mm256_storeu_si256((__m256i *)(c + (rs_c * (ir + 1)) + (1 * 16)), c_int16_1p1); + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * ( ir + 1 ) ) + ( 1*16 )), c_int16_1p1 ); // c[2,0-15] - _mm256_storeu_si256((__m256i *)(c + (rs_c * (ir + 2)) + (0 * 16)), c_int16_2p0); + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * ( ir + 2 ) ) + ( 0*16 )), c_int16_2p0 ); // c[2,16-31] - _mm256_storeu_si256((__m256i *)(c + (rs_c * (ir + 2)) + (1 * 16)), c_int16_2p1); + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * ( ir + 2 ) ) + ( 1*16 )), c_int16_2p1 ); // c[3,0-15] - _mm256_storeu_si256((__m256i *)(c + (rs_c * (ir + 3)) + (0 * 16)), c_int16_3p0); + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * ( ir + 3 ) ) + ( 0*16 )), c_int16_3p0 ); // c[3,16-31] - _mm256_storeu_si256((__m256i *)(c + (rs_c * (ir + 3)) + (1 * 16)), c_int16_3p1); + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * ( ir + 3 ) ) + ( 1*16 )), c_int16_3p1 ); // c[4,0-15] - _mm256_storeu_si256((__m256i *)(c + (rs_c * (ir + 4)) + (0 * 16)), c_int16_4p0); + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * ( ir + 4 ) ) + ( 0*16 )), c_int16_4p0 ); // c[4,16-31] - _mm256_storeu_si256((__m256i *)(c + (rs_c * (ir + 4)) + (1 * 16)), c_int16_4p1); + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * ( ir + 4 ) ) + ( 1*16 )), c_int16_4p1 ); // c[5,0-15] - _mm256_storeu_si256((__m256i *)(c + (rs_c * (ir + 5)) + (0 * 16)), c_int16_5p0); + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * ( ir + 5 ) ) + ( 0*16 )), c_int16_5p0 ); // c[5,16-31] - _mm256_storeu_si256((__m256i *)(c + (rs_c * (ir + 5)) + (1 * 16)), c_int16_5p1); - - a = a + (MR * ps_a); + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * ( ir + 5 ) ) + ( 1*16 )), c_int16_5p1 ); + + a = a + ( MR * ps_a ); + post_op_c_i += MR; } if (m_partial_pieces > 0) @@ -443,7 +555,10 @@ void lpgemm_rowvar_u8s8s16o16_6x32 a, rs_a, cs_a, b, rs_b, cs_b, (c + (rs_c * m_full_pieces_loop_limit)), rs_c, - alpha, beta); + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list); // a pointer increment a = a + (4 * ps_a); @@ -457,7 +572,10 @@ void lpgemm_rowvar_u8s8s16o16_6x32 a, rs_a, cs_a, b, rs_b, cs_b, (c + (rs_c * m_full_pieces_loop_limit)), rs_c, - alpha, beta); + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list); // a pointer increment a = a + (2 * ps_a); @@ -471,7 +589,9 @@ void lpgemm_rowvar_u8s8s16o16_6x32 a, rs_a, cs_a, b, rs_b, cs_b, (c + (rs_c * m_full_pieces_loop_limit)), rs_c, - alpha, beta); + alpha, beta,is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list); } } } \ No newline at end of file diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_m_fringe_amd256.c b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_m_fringe_amd256.c index 56561ff972..84a472b45a 100644 --- a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_m_fringe_amd256.c +++ b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_m_fringe_amd256.c @@ -38,21 +38,33 @@ #include "lpgemm_m_fringe_s16.h" // 4x32 int8o16 kernel -void lpgemm_rowvar_u8s8s16o16_4x32( - const dim_t k0, - const uint8_t *a, 
- const dim_t rs_a, - const dim_t cs_a, - const int8_t *b, - const dim_t rs_b, - const dim_t cs_b, - int16_t *c, - const dim_t rs_c, - const int16_t alpha, - const int16_t beta) +void lpgemm_rowvar_u8s8s16o16_4x32 + ( + const dim_t k0, + const uint8_t *a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t *b, + const dim_t rs_b, + const dim_t cs_b, + int16_t *c, + const dim_t rs_c, + const int16_t alpha, + const int16_t beta, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op *post_ops_list + ) { dim_t NR = 32; + static void *post_ops_labels[] = + { + &&POST_OPS_4x32_DISABLE, + &&POST_OPS_BIAS_4x32, + &&POST_OPS_RELU_4x32}; + // The division is done by considering the vpmaddubsw instruction dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -252,49 +264,133 @@ void lpgemm_rowvar_u8s8s16o16_4x32( selector1 = _mm256_mullo_epi16(selector2, selector1); c_int16_3p1 = _mm256_add_epi16(selector1, c_int16_3p1); } + + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_4x32: + { + selector1 = + _mm256_loadu_si256( (__m256i const *)((int16_t *)post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 )) ); + selector2 = + _mm256_loadu_si256( (__m256i const *)((int16_t *)post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 )) ); + + // c[0,0-15] + c_int16_0p0 = _mm256_add_epi16( selector1, c_int16_0p0 ); + + // c[0, 16-31] + c_int16_0p1 = _mm256_add_epi16( selector2, c_int16_0p1 ); + + // c[1,0-15] + c_int16_1p0 = _mm256_add_epi16( selector1, c_int16_1p0 ); + + // c[1, 16-31] + c_int16_1p1 = _mm256_add_epi16( selector2, c_int16_1p1 ); + + // c[2,0-15] + c_int16_2p0 = _mm256_add_epi16( selector1, c_int16_2p0 ); + + // c[2, 16-31] + c_int16_2p1 = _mm256_add_epi16( selector2, c_int16_2p1 ); + + // c[3,0-15] + c_int16_3p0 = _mm256_add_epi16( selector1, c_int16_3p0 ); + + // c[3, 16-31] + c_int16_3p1 = _mm256_add_epi16( selector2, c_int16_3p1 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_4x32: + { + selector1 = _mm256_setzero_si256 (); + + // c[0,0-15] + c_int16_0p0 = _mm256_max_epi16( selector1, c_int16_0p0 ); + + // c[0, 16-31] + c_int16_0p1 = _mm256_max_epi16( selector1, c_int16_0p1 ); + + // c[1,0-15] + c_int16_1p0 = _mm256_max_epi16( selector1, c_int16_1p0 ); + + // c[1,16-31] + c_int16_1p1 = _mm256_max_epi16( selector1, c_int16_1p1 ); + + // c[2,0-15] + c_int16_2p0 = _mm256_max_epi16( selector1, c_int16_2p0 ); + + // c[2,16-31] + c_int16_2p1 = _mm256_max_epi16( selector1, c_int16_2p1 ); + + // c[3,0-15] + c_int16_3p0 = _mm256_max_epi16( selector1, c_int16_3p0 ); + + // c[3,16-31] + c_int16_3p1 = _mm256_max_epi16( selector1, c_int16_3p1 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_4x32_DISABLE: + ; // Store the results. 
// c[0,0-15] - _mm256_storeu_si256((__m256i *)(c + (rs_c * 0) + (0 * 16)), c_int16_0p0); + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * 0 ) + ( 0*16 )), c_int16_0p0 ); // c[0, 16-31] - _mm256_storeu_si256((__m256i *)(c + (rs_c * 0) + (1 * 16)), c_int16_0p1); + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * 0 ) + ( 1*16 )), c_int16_0p1 ); // c[1,0-15] - _mm256_storeu_si256((__m256i *)(c + (rs_c * 1) + (0 * 16)), c_int16_1p0); + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * 1 ) + ( 0*16 )), c_int16_1p0 ); // c[1,16-31] - _mm256_storeu_si256((__m256i *)(c + (rs_c * 1) + (1 * 16)), c_int16_1p1); + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * 1 ) + ( 1*16 )), c_int16_1p1 ); // c[2,0-15] - _mm256_storeu_si256((__m256i *)(c + (rs_c * 2) + (0 * 16)), c_int16_2p0); + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * 2 ) + ( 0*16 )), c_int16_2p0 ); // c[2,16-31] - _mm256_storeu_si256((__m256i *)(c + (rs_c * 2) + (1 * 16)), c_int16_2p1); + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * 2 ) + ( 1*16 )), c_int16_2p1 ); // c[3,0-15] - _mm256_storeu_si256((__m256i *)(c + (rs_c * 3) + (0 * 16)), c_int16_3p0); + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * 3 ) + ( 0*16 )), c_int16_3p0 ); // c[3,16-31] - _mm256_storeu_si256((__m256i *)(c + (rs_c * 3) + (1 * 16)), c_int16_3p1); + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * 3 ) + ( 1*16 )), c_int16_3p1 ); } + // 2x32 int8o16 kernel -void lpgemm_rowvar_u8s8s16o16_2x32( - const dim_t k0, - const uint8_t *a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t *b, - const dim_t rs_b, - const dim_t cs_b, - int16_t *c, - const dim_t rs_c, - const int16_t alpha, - const int16_t beta) +void lpgemm_rowvar_u8s8s16o16_2x32 + ( + const dim_t k0, + const uint8_t *a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t *b, + const dim_t rs_b, + const dim_t cs_b, + int16_t *c, + const dim_t rs_c, + const int16_t alpha, + const int16_t beta, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op *post_ops_list + ) { dim_t NR = 32; + static void *post_ops_labels[] = + { + &&POST_OPS_2x32_DISABLE, + &&POST_OPS_BIAS_2x32, + &&POST_OPS_RELU_2x32}; + // The division is done by considering the vpmaddubsw instruction dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -414,36 +510,95 @@ void lpgemm_rowvar_u8s8s16o16_2x32( c_int16_1p1 = _mm256_add_epi16(selector1, c_int16_1p1); } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_2x32: + { + selector1 = + _mm256_loadu_si256( (__m256i const *)((int16_t *)post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 )) ); + selector2 = + _mm256_loadu_si256( (__m256i const *)((int16_t *)post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 )) ); + + // c[0,0-15] + c_int16_0p0 = _mm256_add_epi16( selector1, c_int16_0p0 ); + + // c[0, 16-31] + c_int16_0p1 = _mm256_add_epi16( selector2, c_int16_0p1 ); + + // c[1,0-15] + c_int16_1p0 = _mm256_add_epi16( selector1, c_int16_1p0 ); + + // c[1, 16-31] + c_int16_1p1 = _mm256_add_epi16( selector2, c_int16_1p1 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_2x32: + { + selector1 = _mm256_setzero_si256 (); + + // c[0,0-15] + c_int16_0p0 = _mm256_max_epi16( selector1, c_int16_0p0 ); + + // c[0, 16-31] + c_int16_0p1 = _mm256_max_epi16( selector1, c_int16_0p1 ); + + // c[1,0-15] + c_int16_1p0 = _mm256_max_epi16( selector1, c_int16_1p0 ); + + // c[1,16-31] + c_int16_1p1 = _mm256_max_epi16( selector1, c_int16_1p1 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } 
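	// The POST_OP_LABEL_* macros (not shown in this hunk) appear to dispatch
	// through the post_ops_labels table of computed-goto targets: each post-op
	// block jumps to the label of the next op in post_ops_list, and the
	// *_DISABLE label below is the fall-through once the list is exhausted
	// (or when is_last_k is false), after which the results are stored.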
+POST_OPS_2x32_DISABLE: + ; + // Store the results. // c[0,0-15] - _mm256_storeu_si256((__m256i *)(c + (rs_c * 0) + (0 * 16)), c_int16_0p0); + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * 0 ) + ( 0*16 )), c_int16_0p0 ); // c[0, 16-31] - _mm256_storeu_si256((__m256i *)(c + (rs_c * 0) + (1 * 16)), c_int16_0p1); + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * 0 ) + ( 1*16 )), c_int16_0p1 ); // c[1,0-15] - _mm256_storeu_si256((__m256i *)(c + (rs_c * 1) + (0 * 16)), c_int16_1p0); + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * 1 ) + ( 0*16 )), c_int16_1p0 ); // c[1,16-31] - _mm256_storeu_si256((__m256i *)(c + (rs_c * 1) + (1 * 16)), c_int16_1p1); + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * 1 ) + ( 1*16 )), c_int16_1p1 ); } // 1x32 int8o16 kernel -void lpgemm_rowvar_u8s8s16o16_1x32( - const dim_t k0, - const uint8_t *a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t *b, - const dim_t rs_b, - const dim_t cs_b, - int16_t *c, - const dim_t rs_c, - const int16_t alpha, - const int16_t beta) +void lpgemm_rowvar_u8s8s16o16_1x32 + ( + const dim_t k0, + const uint8_t *a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t *b, + const dim_t rs_b, + const dim_t cs_b, + int16_t *c, + const dim_t rs_c, + const int16_t alpha, + const int16_t beta, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op *post_ops_list + ) { dim_t NR = 32; + static void *post_ops_labels[] = + { + &&POST_OPS_1x32_DISABLE, + &&POST_OPS_BIAS_1x32, + &&POST_OPS_RELU_1x32}; + // The division is done by considering the vpmaddubsw instruction dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -522,10 +677,45 @@ void lpgemm_rowvar_u8s8s16o16_1x32( c_int16_0p1 = _mm256_add_epi16(selector1, c_int16_0p1); } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_1x32: + { + selector1 = + _mm256_loadu_si256( (__m256i const *)((int16_t *)post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 )) ); + selector2 = + _mm256_loadu_si256( (__m256i const *)((int16_t *)post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 )) ); + + // c[0,0-15] + c_int16_0p0 = _mm256_add_epi16( selector1, c_int16_0p0 ); + + // c[0, 16-31] + c_int16_0p1 = _mm256_add_epi16( selector2, c_int16_0p1 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_1x32: + { + selector1 = _mm256_setzero_si256 (); + + // c[0,0-15] + c_int16_0p0 = _mm256_max_epi16( selector1, c_int16_0p0 ); + + // c[0, 16-31] + c_int16_0p1 = _mm256_max_epi16( selector1, c_int16_0p1 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_1x32_DISABLE: + ; + // Store the results. 
// c[0,0-15] - _mm256_storeu_si256((__m256i *)(c + (rs_c * 0) + (0 * 16)), c_int16_0p0); + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * 0 ) + ( 0*16 )), c_int16_0p0 ); // c[0, 16-31] - _mm256_storeu_si256((__m256i *)(c + (rs_c * 0) + (1 * 16)), c_int16_0p1); + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * 0 ) + ( 1*16 )), c_int16_0p1 ); } diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_m_fringe_s16.h b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_m_fringe_s16.h index f69ea2630e..da1930cbab 100644 --- a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_m_fringe_s16.h +++ b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_m_fringe_s16.h @@ -35,46 +35,66 @@ #ifndef BLIS_GEMM_INT16_MFRINGE #define BLIS_GEMM_INT16_MFRINGE +#include "lpgemm_post_ops.h" + // 4x32 int8o16 kernel -void lpgemm_rowvar_u8s8s16o16_4x32( - const dim_t k0, - const uint8_t *a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t *b, - const dim_t rs_b, - const dim_t cs_b, - int16_t *c, - const dim_t rs_c, - const int16_t alpha, - const int16_t beta); +void lpgemm_rowvar_u8s8s16o16_4x32 + ( + const dim_t k0, + const uint8_t *a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t *b, + const dim_t rs_b, + const dim_t cs_b, + int16_t *c, + const dim_t rs_c, + const int16_t alpha, + const int16_t beta, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op *post_ops_list + ); // 2x32 int8o16 kernel -void lpgemm_rowvar_u8s8s16o16_2x32( - const dim_t k0, - const uint8_t *a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t *b, - const dim_t rs_b, - const dim_t cs_b, - int16_t *c, - const dim_t rs_c, - const int16_t alpha, - const int16_t beta); +void lpgemm_rowvar_u8s8s16o16_2x32 + ( + const dim_t k0, + const uint8_t *a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t *b, + const dim_t rs_b, + const dim_t cs_b, + int16_t *c, + const dim_t rs_c, + const int16_t alpha, + const int16_t beta, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op *post_ops_list + ); // 1x32 int8o16 kernel -void lpgemm_rowvar_u8s8s16o16_1x32( - const dim_t k0, - const uint8_t *a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t *b, - const dim_t rs_b, - const dim_t cs_b, - int16_t *c, - const dim_t rs_c, - const int16_t alpha, - const int16_t beta); +void lpgemm_rowvar_u8s8s16o16_1x32 + ( + const dim_t k0, + const uint8_t *a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t *b, + const dim_t rs_b, + const dim_t cs_b, + int16_t *c, + const dim_t rs_c, + const int16_t alpha, + const int16_t beta, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op *post_ops_list + ); #endif // BLIS_GEMM_INT16_MFRINGE \ No newline at end of file diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_mn_fringe_amd256.c b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_mn_fringe_amd256.c index a494da8160..a8e8547cf4 100644 --- a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_mn_fringe_amd256.c +++ b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_mn_fringe_amd256.c @@ -50,11 +50,21 @@ void lpgemm_rowvar_u8s8s16o16_4x16 int16_t *c, const dim_t rs_c, const int16_t alpha, - const int16_t beta + const int16_t beta, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op *post_ops_list ) { dim_t NR = 16; + static void *post_ops_labels[] = + { + &&POST_OPS_4x16_DISABLE, + &&POST_OPS_BIAS_4x16, + &&POST_OPS_RELU_4x16}; + // The division is done by considering the vpmaddubsw instruction dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -204,6 +214,50 @@ void lpgemm_rowvar_u8s8s16o16_4x16 
c_int16_3p0 = _mm256_add_epi16(selector1, c_int16_3p0); } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_4x16: + { + selector1 = + _mm256_loadu_si256( (__m256i const *)((int16_t *)post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 )) ); + + // c[0,0-15] + c_int16_0p0 = _mm256_add_epi16( selector1, c_int16_0p0 ); + + // c[1,0-15] + c_int16_1p0 = _mm256_add_epi16( selector1, c_int16_1p0 ); + + // c[2,0-15] + c_int16_2p0 = _mm256_add_epi16( selector1, c_int16_2p0 ); + + // c[3,0-15] + c_int16_3p0 = _mm256_add_epi16( selector1, c_int16_3p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_4x16: + { + selector1 = _mm256_setzero_si256 (); + + // c[0,0-15] + c_int16_0p0 = _mm256_max_epi16( selector1, c_int16_0p0 ); + + // c[1,0-15] + c_int16_1p0 = _mm256_max_epi16( selector1, c_int16_1p0 ); + + // c[2,0-15] + c_int16_2p0 = _mm256_max_epi16( selector1, c_int16_2p0 ); + + // c[3,0-15] + c_int16_3p0 = _mm256_max_epi16( selector1, c_int16_3p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_4x16_DISABLE: + ; + // Store the results. // c[0,0-15] _mm256_storeu_si256((__m256i *)(c + (rs_c * 0) + (0 * 16)), c_int16_0p0); @@ -232,11 +286,21 @@ void lpgemm_rowvar_u8s8s16o16_4xlt16 const dim_t rs_c, const int16_t alpha, const int16_t beta, - dim_t n0_rem + dim_t n0_rem, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op *post_ops_list ) { dim_t NR = 16; + static void *post_ops_labels[] = + { + &&POST_OPS_4xlt16_DISABLE, + &&POST_OPS_BIAS_4xlt16, + &&POST_OPS_RELU_4xlt16}; + // The division is done by considering the vpmaddubsw instruction dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -399,6 +463,54 @@ void lpgemm_rowvar_u8s8s16o16_4xlt16 c_int16_3p0 = _mm256_add_epi16(selector1, c_int16_3p0); } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_4xlt16: + { + int16_t buf4[16]; + + memcpy(buf4, (int16_t *)(post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 )), (n0_rem * sizeof(int16_t))); + + selector1 = + _mm256_loadu_si256( (__m256i const *) buf4 ); + + // c[0,0-15] + c_int16_0p0 = _mm256_add_epi16( selector1, c_int16_0p0 ); + + // c[1,0-15] + c_int16_1p0 = _mm256_add_epi16( selector1, c_int16_1p0 ); + + // c[2,0-15] + c_int16_2p0 = _mm256_add_epi16( selector1, c_int16_2p0 ); + + // c[3,0-15] + c_int16_3p0 = _mm256_add_epi16( selector1, c_int16_3p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_4xlt16: + { + selector1 = _mm256_setzero_si256 (); + + // c[0,0-15] + c_int16_0p0 = _mm256_max_epi16( selector1, c_int16_0p0 ); + + // c[1,0-15] + c_int16_1p0 = _mm256_max_epi16( selector1, c_int16_1p0 ); + + // c[2,0-15] + c_int16_2p0 = _mm256_max_epi16( selector1, c_int16_2p0 ); + + // c[3,0-15] + c_int16_3p0 = _mm256_max_epi16( selector1, c_int16_3p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_4xlt16_DISABLE: + ; + // c[0,0-15] _mm256_storeu_si256((__m256i_u *)buf0, c_int16_0p0); @@ -436,11 +548,21 @@ void lpgemm_rowvar_u8s8s16o16_2x16 int16_t *c, const dim_t rs_c, const int16_t alpha, - const int16_t beta + const int16_t beta, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op *post_ops_list ) { dim_t NR = 16; + static void *post_ops_labels[] = + { + &&POST_OPS_2x16_DISABLE, + &&POST_OPS_BIAS_2x16, + &&POST_OPS_RELU_2x16}; + // The division is done by considering the vpmaddubsw instruction dim_t k_full_pieces = k0 / 2; dim_t 
k_partial_pieces = k0 % 2; @@ -534,6 +656,38 @@ void lpgemm_rowvar_u8s8s16o16_2x16 c_int16_1p0 = _mm256_add_epi16(selector1, c_int16_1p0); } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_2x16: + { + selector1 = + _mm256_loadu_si256( (__m256i const *)((int16_t *)post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 )) ); + + // c[0,0-15] + c_int16_0p0 = _mm256_add_epi16( selector1, c_int16_0p0 ); + + // c[1,0-15] + c_int16_1p0 = _mm256_add_epi16( selector1, c_int16_1p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_2x16: + { + selector1 = _mm256_setzero_si256 (); + + // c[0,0-15] + c_int16_0p0 = _mm256_max_epi16( selector1, c_int16_0p0 ); + + // c[1,0-15] + c_int16_1p0 = _mm256_max_epi16( selector1, c_int16_1p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_2x16_DISABLE: + ; + // Store the results. // c[0,0-15] _mm256_storeu_si256((__m256i *)(c + (rs_c * 0) + (0 * 16)), c_int16_0p0); @@ -556,11 +710,21 @@ void lpgemm_rowvar_u8s8s16o16_2xlt16 const dim_t rs_c, const int16_t alpha, const int16_t beta, - dim_t n0_rem + dim_t n0_rem, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op *post_ops_list ) { dim_t NR = 16; + static void *post_ops_labels[] = + { + &&POST_OPS_2xlt16_DISABLE, + &&POST_OPS_BIAS_2xlt16, + &&POST_OPS_RELU_2xlt16}; + // The division is done by considering the vpmaddubsw instruction dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -660,6 +824,42 @@ void lpgemm_rowvar_u8s8s16o16_2xlt16 c_int16_1p0 = _mm256_add_epi16(selector1, c_int16_1p0); } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_2xlt16: + { + int16_t buf4[16]; + + memcpy(buf4, (int16_t *)(post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 )), (n0_rem * sizeof(int16_t))); + + selector1 = + _mm256_loadu_si256( (__m256i const *) buf4); + + // c[0,0-15] + c_int16_0p0 = _mm256_add_epi16( selector1, c_int16_0p0 ); + + // c[1,0-15] + c_int16_1p0 = _mm256_add_epi16( selector1, c_int16_1p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_2xlt16: + { + selector1 = _mm256_setzero_si256 (); + + // c[0,0-15] + c_int16_0p0 = _mm256_max_epi16( selector1, c_int16_0p0 ); + + // c[1,0-15] + c_int16_1p0 = _mm256_max_epi16( selector1, c_int16_1p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_2xlt16_DISABLE: + ; + // c[0,0-15] _mm256_storeu_si256((__m256i_u *)buf0, c_int16_0p0); @@ -686,11 +886,21 @@ void lpgemm_rowvar_u8s8s16o16_1x16 int16_t *c, const dim_t rs_c, const int16_t alpha, - const int16_t beta + const int16_t beta, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op *post_ops_list ) { int NR = 16; + static void *post_ops_labels[] = + { + &&POST_OPS_1x16_DISABLE, + &&POST_OPS_BIAS_1x16, + &&POST_OPS_RELU_1x16}; + // The division is done by considering the vpmaddubsw instruction int k_full_pieces = k0 / 2; int k_partial_pieces = k0 % 2; @@ -755,9 +965,35 @@ void lpgemm_rowvar_u8s8s16o16_1x16 c_int16_0p0 = _mm256_add_epi16(selector1, c_int16_0p0); } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_1x16: + { + selector1 = + _mm256_loadu_si256( (__m256i const *)((int16_t *)post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 )) ); + + // c[0,0-15] + c_int16_0p0 = _mm256_add_epi16( selector1, c_int16_0p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } 
+POST_OPS_RELU_1x16: + { + selector1 = _mm256_setzero_si256 (); + + // c[0,0-15] + c_int16_0p0 = _mm256_max_epi16( selector1, c_int16_0p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_1x16_DISABLE: + ; + // Store the results. // c[0,0-15] - _mm256_storeu_si256((__m256i *)(c + (rs_c * 0) + (0 * 16)), c_int16_0p0); + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * 0 ) + ( 0*16 )), c_int16_0p0 ); } // 1xlt16 int8o16 kernel @@ -774,11 +1010,21 @@ void lpgemm_rowvar_u8s8s16o16_1xlt16 const int rs_c, const int16_t alpha, const int16_t beta, - dim_t n0_rem + dim_t n0_rem, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op *post_ops_list ) { int NR = 16; + static void *post_ops_labels[] = + { + &&POST_OPS_1xlt16_DISABLE, + &&POST_OPS_BIAS_1xlt16, + &&POST_OPS_RELU_1xlt16}; + // The division is done by considering the vpmaddubsw instruction int k_full_pieces = k0 / 2; int k_partial_pieces = k0 % 2; @@ -847,6 +1093,36 @@ void lpgemm_rowvar_u8s8s16o16_1xlt16 c_int16_0p0 = _mm256_add_epi16(selector1, c_int16_0p0); } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_1xlt16: + { + int16_t buf4[16]; + + memcpy(buf4, (int16_t *)(post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 )), (n0_rem * sizeof(int16_t))); + + selector1 = + _mm256_loadu_si256( (__m256i const *)buf4 ); + + // c[0,0-15] + c_int16_0p0 = _mm256_add_epi16( selector1, c_int16_0p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_1xlt16: + { + selector1 = _mm256_setzero_si256 (); + + // c[0,0-15] + c_int16_0p0 = _mm256_max_epi16( selector1, c_int16_0p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_1xlt16_DISABLE: + ; + // c[0,0-15] _mm256_storeu_si256((__m256i_u *)buf0, c_int16_0p0); diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_mn_fringe_s16.h b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_mn_fringe_s16.h index 6b34fae976..bfffe5c336 100644 --- a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_mn_fringe_s16.h +++ b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_mn_fringe_s16.h @@ -35,85 +35,123 @@ #ifndef BLIS_GEMM_INT16_MNFRINGE #define BLIS_GEMM_INT16_MNFRINGE -void lpgemm_rowvar_u8s8s16o16_4x16( - const dim_t k0, - const uint8_t *a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t *b, - const dim_t rs_b, - const dim_t cs_b, - int16_t *c, - const dim_t rs_c, - const int16_t alpha, - const int16_t beta); +#include "lpgemm_post_ops.h" -void lpgemm_rowvar_u8s8s16o16_4xlt16( - const dim_t k0, - const uint8_t *a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t *b, - const dim_t rs_b, - const dim_t cs_b, - int16_t *c, - const dim_t rs_c, - const int16_t alpha, - const int16_t beta, - dim_t n0_rem); +void lpgemm_rowvar_u8s8s16o16_4x16 + ( + const dim_t k0, + const uint8_t *a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t *b, + const dim_t rs_b, + const dim_t cs_b, + int16_t *c, + const dim_t rs_c, + const int16_t alpha, + const int16_t beta, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op *post_ops_list + ); -void lpgemm_rowvar_u8s8s16o16_2x16( - const dim_t k0, - const uint8_t *a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t *b, - const dim_t rs_b, - const dim_t cs_b, - int16_t *c, - const dim_t rs_c, - const int16_t alpha, - const int16_t beta); +void lpgemm_rowvar_u8s8s16o16_4xlt16 + ( + const dim_t k0, + const uint8_t *a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t *b, + const dim_t rs_b, + const dim_t cs_b, + int16_t *c, + const dim_t 
rs_c, + const int16_t alpha, + const int16_t beta, + dim_t n0_rem, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op *post_ops_list + ); -void lpgemm_rowvar_u8s8s16o16_2xlt16( - const dim_t k0, - const uint8_t *a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t *b, - const dim_t rs_b, - const dim_t cs_b, - int16_t *c, - const dim_t rs_c, - const int16_t alpha, - const int16_t beta, - dim_t n0_rem); +void lpgemm_rowvar_u8s8s16o16_2x16 + ( + const dim_t k0, + const uint8_t *a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t *b, + const dim_t rs_b, + const dim_t cs_b, + int16_t *c, + const dim_t rs_c, + const int16_t alpha, + const int16_t beta, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op *post_ops_list + ); -void lpgemm_rowvar_u8s8s16o16_1x16( - const dim_t k0, - const uint8_t *a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t *b, - const dim_t rs_b, - const dim_t cs_b, - int16_t *c, - const dim_t rs_c, - const int16_t alpha, - const int16_t beta); +void lpgemm_rowvar_u8s8s16o16_2xlt16 + ( + const dim_t k0, + const uint8_t *a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t *b, + const dim_t rs_b, + const dim_t cs_b, + int16_t *c, + const dim_t rs_c, + const int16_t alpha, + const int16_t beta, + dim_t n0_rem, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op *post_ops_list + ); -void lpgemm_rowvar_u8s8s16o16_1xlt16( - const int k0, - const uint8_t *a, - const int rs_a, - const int cs_a, - const int8_t *b, - const int rs_b, - const int cs_b, - int16_t *c, - const int rs_c, - const int16_t alpha, - const int16_t beta, - dim_t n0_rem); +void lpgemm_rowvar_u8s8s16o16_1x16 + ( + const dim_t k0, + const uint8_t *a, + const dim_t rs_a, + const dim_t cs_a, + const int8_t *b, + const dim_t rs_b, + const dim_t cs_b, + int16_t *c, + const dim_t rs_c, + const int16_t alpha, + const int16_t beta, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op *post_ops_list + ); + +void lpgemm_rowvar_u8s8s16o16_1xlt16 + ( + const int k0, + const uint8_t *a, + const int rs_a, + const int cs_a, + const int8_t *b, + const int rs_b, + const int cs_b, + int16_t *c, + const int rs_c, + const int16_t alpha, + const int16_t beta, + dim_t n0_rem, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op *post_ops_list + ); #endif // BLIS_GEMM_INT16_MNFRINGE \ No newline at end of file diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_n_fringe_amd256.c b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_n_fringe_amd256.c index 8fb2cdec1f..53ab412ded 100644 --- a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_n_fringe_amd256.c +++ b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_n_fringe_amd256.c @@ -39,24 +39,36 @@ #include "lpgemm_mn_fringe_s16.h" // 6x16 int8o16 kernel -void lpgemm_rowvar_u8s8s16o16_6x16( - const dim_t m0, - const dim_t k0, - const uint8_t *a, - const dim_t rs_a, - const dim_t cs_a, - const dim_t ps_a, - const int8_t *b, - const dim_t rs_b, - const dim_t cs_b, - int16_t *c, - const dim_t rs_c, - const int16_t alpha, - const int16_t beta) +void lpgemm_rowvar_u8s8s16o16_6x16 + ( + const dim_t m0, + const dim_t k0, + const uint8_t *a, + const dim_t rs_a, + const dim_t cs_a, + const dim_t ps_a, + const int8_t *b, + const dim_t rs_b, + const dim_t cs_b, + int16_t *c, + const dim_t rs_c, + const int16_t alpha, + const int16_t beta, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op *post_ops_list + ) { dim_t MR = 6; dim_t NR = 16; + static void 
*post_ops_labels[] = + { + &&POST_OPS_6x16_DISABLE, + &&POST_OPS_BIAS_6x16, + &&POST_OPS_RELU_6x16}; + dim_t m_full_pieces = m0 / MR; dim_t m_full_pieces_loop_limit = m_full_pieces * MR; dim_t m_partial_pieces = m0 % MR; @@ -273,26 +285,83 @@ void lpgemm_rowvar_u8s8s16o16_6x16( c_int16_5p0 = _mm256_add_epi16(selector1, c_int16_5p0); } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_6x16: + { + selector1 = + _mm256_loadu_si256( (__m256i const *)((int16_t *)post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 )) ); + + // c[0,0-15] + c_int16_0p0 = _mm256_add_epi16( selector1, c_int16_0p0 ); + + // c[1,0-15] + c_int16_1p0 = _mm256_add_epi16( selector1, c_int16_1p0 ); + + // c[2,0-15] + c_int16_2p0 = _mm256_add_epi16( selector1, c_int16_2p0 ); + + // c[3,0-15] + c_int16_3p0 = _mm256_add_epi16( selector1, c_int16_3p0 ); + + // c[4,0-15] + c_int16_4p0 = _mm256_add_epi16( selector1, c_int16_4p0 ); + + // c[5,0-15] + c_int16_5p0 = _mm256_add_epi16( selector1, c_int16_5p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_6x16: + { + selector1 = _mm256_setzero_si256 (); + + // c[0,0-15] + c_int16_0p0 = _mm256_max_epi16( selector1, c_int16_0p0 ); + + // c[1,0-15] + c_int16_1p0 = _mm256_max_epi16( selector1, c_int16_1p0 ); + + // c[2,0-15] + c_int16_2p0 = _mm256_max_epi16( selector1, c_int16_2p0 ); + + // c[3,0-15] + c_int16_3p0 = _mm256_max_epi16( selector1, c_int16_3p0 ); + + // c[4,0-15] + c_int16_4p0 = _mm256_max_epi16( selector1, c_int16_4p0 ); + + // c[5,0-15] + c_int16_5p0 = _mm256_max_epi16( selector1, c_int16_5p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_6x16_DISABLE: + ; + // Store the results. // c[0,0-15] - _mm256_storeu_si256((__m256i *)(c + (rs_c * (ir + 0)) + (0 * 16)), c_int16_0p0); + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * 0 ) + ( 0*16 )), c_int16_0p0 ); // c[1,0-15] - _mm256_storeu_si256((__m256i *)(c + (rs_c * (ir + 1)) + (0 * 16)), c_int16_1p0); + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * 1 ) + ( 0*16 )), c_int16_1p0 ); // c[2,0-15] - _mm256_storeu_si256((__m256i *)(c + (rs_c * (ir + 2)) + (0 * 16)), c_int16_2p0); + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * 2 ) + ( 0*16 )), c_int16_2p0 ); // c[3,0-15] - _mm256_storeu_si256((__m256i *)(c + (rs_c * (ir + 3)) + (0 * 16)), c_int16_3p0); + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * 3 ) + ( 0*16 )), c_int16_3p0 ); // c[4,0-15] - _mm256_storeu_si256((__m256i *)(c + (rs_c * (ir + 4)) + (0 * 16)), c_int16_4p0); + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * 4 ) + ( 0*16 )), c_int16_4p0 ); // c[5,0-15] - _mm256_storeu_si256((__m256i *)(c + (rs_c * (ir + 5)) + (0 * 16)), c_int16_5p0); - - a = a + (MR * ps_a); + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * 5 ) + ( 0*16 )), c_int16_5p0 ); + + a = a + ( MR * ps_a ); + post_op_c_i += MR; } if (m_partial_pieces > 0) @@ -310,7 +379,10 @@ void lpgemm_rowvar_u8s8s16o16_6x16( a, rs_a, cs_a, b, rs_b, cs_b, (c + (rs_c * m_full_pieces_loop_limit)), rs_c, - alpha, beta); + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list); // a pointer increment a = a + (4 * ps_a); @@ -324,7 +396,10 @@ void lpgemm_rowvar_u8s8s16o16_6x16( a, rs_a, cs_a, b, rs_b, cs_b, (c + (rs_c * m_full_pieces_loop_limit)), rs_c, - alpha, beta); + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list); // a pointer increment a = a + (2 * ps_a); @@ -338,30 +413,45 @@ void lpgemm_rowvar_u8s8s16o16_6x16( a, rs_a, cs_a, b, rs_b, cs_b, (c + (rs_c * m_full_pieces_loop_limit)), 
rs_c, - alpha, beta); + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list); } } } // 6xlt16 int8o16 kernel -void lpgemm_rowvar_u8s8s16o16_6xlt16( - const dim_t m0, - const dim_t k0, - const uint8_t *a, - const dim_t rs_a, - const dim_t cs_a, - const dim_t ps_a, - const int8_t *b, - const dim_t rs_b, - const dim_t cs_b, - int16_t *c, - const dim_t rs_c, - const int16_t alpha, - const int16_t beta, - const dim_t n0_rem) +void lpgemm_rowvar_u8s8s16o16_6xlt16 + ( + const dim_t m0, + const dim_t k0, + const uint8_t *a, + const dim_t rs_a, + const dim_t cs_a, + const dim_t ps_a, + const int8_t *b, + const dim_t rs_b, + const dim_t cs_b, + int16_t *c, + const dim_t rs_c, + const int16_t alpha, + const int16_t beta, + const dim_t n0_rem, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op *post_ops_list + ) { dim_t MR = 6; + static void *post_ops_labels[] = + { + &&POST_OPS_6xlt16_DISABLE, + &&POST_OPS_BIAS_6xlt16, + &&POST_OPS_RELU_6xlt16}; + dim_t m_full_pieces = m0 / MR; dim_t m_full_pieces_loop_limit = m_full_pieces * MR; dim_t m_partial_pieces = m0 % MR; @@ -591,6 +681,66 @@ void lpgemm_rowvar_u8s8s16o16_6xlt16( c_int16_5p0 = _mm256_add_epi16(selector1, c_int16_5p0); } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_6xlt16: + { + int16_t buf6[16]; + + memcpy(buf0, (int16_t *)(post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 )), (n0_rem * sizeof(int16_t))); + + selector1 = + _mm256_loadu_si256( (__m256i const *)buf6 ); + + // c[0,0-15] + c_int16_0p0 = _mm256_add_epi16( selector1, c_int16_0p0 ); + + // c[1,0-15] + c_int16_1p0 = _mm256_add_epi16( selector1, c_int16_1p0 ); + + // c[2,0-15] + c_int16_2p0 = _mm256_add_epi16( selector1, c_int16_2p0 ); + + // c[3,0-15] + c_int16_3p0 = _mm256_add_epi16( selector1, c_int16_3p0 ); + + // c[4,0-15] + c_int16_4p0 = _mm256_add_epi16( selector1, c_int16_4p0 ); + + // c[5,0-15] + c_int16_5p0 = _mm256_add_epi16( selector1, c_int16_5p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_6xlt16: + { + selector1 = _mm256_setzero_si256 (); + + // c[0,0-15] + c_int16_0p0 = _mm256_max_epi16( selector1, c_int16_0p0 ); + + // c[1,0-15] + c_int16_1p0 = _mm256_max_epi16( selector1, c_int16_1p0 ); + + // c[2,0-15] + c_int16_2p0 = _mm256_max_epi16( selector1, c_int16_2p0 ); + + // c[3,0-15] + c_int16_3p0 = _mm256_max_epi16( selector1, c_int16_3p0 ); + + // c[4,0-15] + c_int16_4p0 = _mm256_max_epi16( selector1, c_int16_4p0 ); + + // c[5,0-15] + c_int16_5p0 = _mm256_max_epi16( selector1, c_int16_5p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_6xlt16_DISABLE: + ; + // Store the results. 
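	// n < 16 fringe path: each row's result is first written to a 16-element
	// intermediate buffer, and only n0_rem elements of each buffer are then
	// memcpy'd into the corresponding row of C.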
// c[0,0-15] _mm256_storeu_si256((__m256i_u *)buf0, c_int16_0p0); @@ -628,6 +778,7 @@ void lpgemm_rowvar_u8s8s16o16_6xlt16( memcpy(c + (rs_c * (ir + 5)) + (0 * 16), buf5, (n0_rem * sizeof(int16_t))); a = a + (MR * ps_a); + post_op_c_i += MR; } if (m_partial_pieces > 0) @@ -645,7 +796,10 @@ void lpgemm_rowvar_u8s8s16o16_6xlt16( a, rs_a, cs_a, b, rs_b, cs_b, (c + (rs_c * m_full_pieces_loop_limit)), rs_c, - alpha, beta, n0_rem); + alpha, beta, n0_rem, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list); // a pointer increment a = a + (4 * ps_a); @@ -659,7 +813,10 @@ void lpgemm_rowvar_u8s8s16o16_6xlt16( a, rs_a, cs_a, b, rs_b, cs_b, (c + (rs_c * m_full_pieces_loop_limit)), rs_c, - alpha, beta, n0_rem); + alpha, beta, n0_rem, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list); // a pointer increment a = a + (2 * ps_a); @@ -673,7 +830,10 @@ void lpgemm_rowvar_u8s8s16o16_6xlt16( a, rs_a, cs_a, b, rs_b, cs_b, (c + (rs_c * m_full_pieces_loop_limit)), rs_c, - alpha, beta, n0_rem); + alpha, beta, n0_rem, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list); } } } diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_n_fringe_s16.h b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_n_fringe_s16.h index 7987aa04aa..9ca5578d87 100644 --- a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_n_fringe_s16.h +++ b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_n_fringe_s16.h @@ -34,38 +34,51 @@ #ifndef BLIS_GEMM_INT16_NFRINGE #define BLIS_GEMM_INT16_NFRINGE +#include "lpgemm_post_ops.h" // 6x16 int8o16 kernel -void lpgemm_rowvar_u8s8s16o16_6x16( - const dim_t m0, - const dim_t k0, - const uint8_t *a, - const dim_t rs_a, - const dim_t cs_a, - const dim_t ps_a, - const int8_t *b, - const dim_t rs_b, - const dim_t cs_b, - int16_t *c, - const dim_t rs_c, - const int16_t alpha, - const int16_t beta); +void lpgemm_rowvar_u8s8s16o16_6x16 + ( + const dim_t m0, + const dim_t k0, + const uint8_t *a, + const dim_t rs_a, + const dim_t cs_a, + const dim_t ps_a, + const int8_t *b, + const dim_t rs_b, + const dim_t cs_b, + int16_t *c, + const dim_t rs_c, + const int16_t alpha, + const int16_t beta, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op *post_ops_list + ); // 6xlt16 int8o16 kernel -void lpgemm_rowvar_u8s8s16o16_6xlt16( - const dim_t m0, - const dim_t k0, - const uint8_t *a, - const dim_t rs_a, - const dim_t cs_a, - const dim_t ps_a, - const int8_t *b, - const dim_t rs_b, - const dim_t cs_b, - int16_t *c, - const dim_t rs_c, - const int16_t alpha, - const int16_t beta, - const dim_t n0_rem); +void lpgemm_rowvar_u8s8s16o16_6xlt16 + ( + const dim_t m0, + const dim_t k0, + const uint8_t *a, + const dim_t rs_a, + const dim_t cs_a, + const dim_t ps_a, + const int8_t *b, + const dim_t rs_b, + const dim_t cs_b, + int16_t *c, + const dim_t rs_c, + const int16_t alpha, + const int16_t beta, + const dim_t n0_rem, + bool is_last_k, + dim_t post_op_c_i, + dim_t post_op_c_j, + lpgemm_post_op *post_ops_list + ); #endif // BLIS_GEMM_INT16_NFRINGE \ No newline at end of file diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_packb_amd256.c b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_packb_amd256.c index 1f6ec5a787..ac9cb469e3 100644 --- a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_packb_amd256.c +++ b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_packb_amd256.c @@ -36,21 +36,26 @@ #include "blis.h" #include "lpgemm_packb_s16.h" +#include "lpgemm_config.h" -void get_packb_nr32_u8s8s16o16_strides( - dim_t *rs_b, - dim_t *cs_b) +void get_packb_nr32_u8s8s16o16_strides + ( + dim_t *rs_b, + dim_t *cs_b + ) { - *rs_b = 32 * 2; - 
*cs_b = 32; + *rs_b = lpgemm_get_block_size_NR_global_cntx( U8S8S16OS16 ) * 2; + *cs_b = lpgemm_get_block_size_NR_global_cntx( U8S8S16OS16 ); } -void packb_nrlt16_u8s8s16o16( - int8_t *pack_b_buffer_u8s8s16o16, - const int8_t *b, - const dim_t ldb, - const dim_t rows, - dim_t n0_partial_rem) +void packb_nrlt16_u8s8s16o16 + ( + int8_t *pack_b_buffer_u8s8s16o16, + const int8_t *b, + const dim_t ldb, + const dim_t rows, + dim_t n0_partial_rem + ) { dim_t k_full_pieces_blks = rows / 2; dim_t k_full_pieces = k_full_pieces_blks * 2; diff --git a/bench/bench_aocl_gemm/bench_input.txt b/bench/bench_aocl_gemm/bench_input.txt index de76f75168..617dced6df 100644 --- a/bench/bench_aocl_gemm/bench_input.txt +++ b/bench/bench_aocl_gemm/bench_input.txt @@ -1,3 +1,153 @@ +s r 480 20 2050 2050 20 20 +s r 481 20 2050 2050 20 20 +s r 482 20 2050 2050 20 20 +s p 483 20 2050 2050 20 20 +s R 484 20 2050 2050 20 20 +s R 485 20 2050 2050 20 20 +s R 480 39 2050 2050 39 39 +s R 481 39 2050 2050 39 39 +s R 482 39 2050 2050 39 39 +s R 483 39 2050 2050 39 39 +s R 484 39 2050 2050 39 39 +s p 485 39 2050 2050 39 39 +s p 480 50 2050 2050 50 50 +s p 481 50 2050 2050 50 50 +s p 482 50 2050 2050 50 50 +s p 483 50 2050 2050 50 50 +s p 484 50 2050 2050 50 50 +s p 485 50 2050 2050 50 50 +s R 480 1108 2050 2050 1108 1108 +s R 481 1108 2050 2050 1108 1108 +s R 482 1108 2050 2050 1108 1108 +s R 483 1108 2050 2050 1108 1108 +s R 484 1108 2050 2050 1108 1108 +s R 485 1108 2050 2050 1108 1108 +s R 480 1127 2050 2050 1127 1127 +s R 481 1127 2050 2050 1127 1127 +s R 482 1127 2050 2050 1127 1127 +s R 483 1127 2050 2050 1127 1127 +s p 484 1127 2050 2050 1127 1127 +s p 485 1127 2050 2050 1127 1127 +s p 480 1138 2050 2050 1138 1138 +s p 481 1138 2050 2050 1138 1138 +s p 482 1138 2050 2050 1138 1138 +s p 483 1138 2050 2050 1138 1138 +s p 484 1138 2050 2050 1138 1138 +s p 485 1138 2050 2050 1138 1138 +s p 1 1 3 3 1 1 +s p 1 9 3 3 9 9 +s p 1 2048 3 3 2048 2048 +s p 1 2048 5192 5192 2048 2048 +s p 9 1 3 3 1 1 +s p 576 1 3500 3500 1 1 +s p 1 1 1 1 1 1 +s p 102 1088 1024 1024 1088 1088 +s p 102 2048 1024 1024 2048 2048 +s p 485 656 1024 1024 656 656 +s p 483 656 1024 1024 656 656 +s p 81 128 3 3 128 128 +s p 1022 512 515 515 512 512 +s p 74 512 515 515 512 512 +s p 253 2048 515 515 2048 2048 +s p 8192 1040 515 515 1040 1040 +s p 10 1029 515 515 1029 1029 +s p 24 1040 2050 2050 1040 1040 +s p 1024 1029 2050 2050 1029 1029 +s p 480 660 2050 2050 660 660 +s p 481 660 2050 2050 660 660 +s p 482 660 2050 2050 660 660 +s p 483 660 2050 2050 660 660 +s p 484 660 2050 2050 660 660 +s p 485 660 2050 2050 660 660 +s p 480 679 2050 2050 679 679 +s p 481 679 2050 2050 679 679 +s p 482 679 2050 2050 679 679 +s p 483 679 2050 2050 679 679 +s p 484 679 2050 2050 679 679 +s p 485 679 2050 2050 679 679 +s p 480 690 2050 2050 690 690 +s p 481 690 2050 2050 690 690 +s p 482 690 2050 2050 690 690 +s p 483 690 2050 2050 690 690 +s p 484 690 2050 2050 690 690 +s p 485 690 2050 2050 690 690 +s p 480 660 2048 2048 660 660 +s p 481 660 2048 2048 660 660 +s p 482 660 2048 2048 660 660 +s p 483 660 2048 2048 660 660 +s p 484 660 2048 2048 660 660 +s p 485 660 2048 2048 660 660 +s p 480 679 2048 2048 679 679 +s p 481 679 2048 2048 679 679 +s p 482 679 2048 2048 679 679 +s p 483 679 2048 2048 679 679 +s p 484 679 2048 2048 679 679 +s p 485 679 2048 2048 679 679 +s p 480 690 2048 2048 690 690 +s p 481 690 2048 2048 690 690 +s p 482 690 2048 2048 690 690 +s p 483 690 2048 2048 690 690 +s p 484 690 2048 2048 690 690 +s p 485 690 2048 2048 690 690 +s p 480 656 1024 1024 656 
656 +s p 480 128 3 3 128 128 +s p 1024 512 515 515 512 512 +s p 1024 2048 1024 1024 2048 2048 +s p 1024 2048 515 515 2048 2048 +s p 1024 1040 515 515 1040 1040 +s p 5 1029 515 515 1029 1029 +s p 1024 1029 515 515 1029 1029 +s p 1024 1040 2050 2050 1040 1040 +s p 1029 1029 2050 2050 1029 1029 +s R 480 646 2050 2050 646 646 +s R 481 646 2050 2050 646 646 +s R 482 646 2050 2050 646 646 +s R 483 646 2050 2050 646 646 +s R 484 646 2050 2050 646 646 +s R 485 646 2050 2050 646 646 +s R 481 656 2050 2050 656 656 +s R 482 656 2050 2050 656 656 +s R 483 656 2050 2050 656 656 +s R 484 656 2050 2050 656 656 +s p 485 656 2050 2050 656 656 +s p 480 672 2050 2050 672 672 +s p 481 672 2050 2050 672 672 +s p 482 672 2050 2050 672 672 +s p 483 672 2050 2050 672 672 +s p 484 672 2050 2050 672 672 +s p 485 672 2050 2050 672 672 +s p 480 688 2050 2050 688 688 +s p 481 688 2050 2050 688 688 +s r 482 688 2050 2050 688 688 +s r 483 688 2050 2050 688 688 +s r 484 688 2050 2050 688 688 +s r 485 688 2050 2050 688 688 +s r 1024 512 64 64 512 512 +s r 16 256 512 512 256 256 +s r 480 640 512 512 640 640 +s r 64 768 512 512 768 768 +s r 128 128 128 128 128 128 +s r 1024 64 512 512 64 64 +s r 1024 256 32 32 256 256 +s r 1024 512 64 64 512 512 +s r 480 640 512 512 640 640 +s p 1024 32 256 256 32 32 +s P 1024 64 512 512 64 64 +s P 64 800 320 320 800 800 +s P 64 768 512 512 768 768 +s P 16 256 512 512 256 256 +s P 128 128 128 128 128 128 +s P 256 512 256 256 512 512 +s P 1024 1024 1024 1024 1024 1024 +s P 480 640 1024 1024 640 640 +s P 480 640 256 256 640 640 +s P 8 64 32 32 64 64 +s P 9 64 32 32 64 64 +s P 10 128 64 64 128 128 +s P 8 8 8 8 8 8 +s P 12 12 12 12 12 12 +s P 25 25 25 25 25 25 +s P 25 25 20 20 25 25 i p 480 20 2050 2050 20 20 i p 481 20 2050 2050 20 20 i p 482 20 2050 2050 20 20 From 737e08cd7a6046475b3db80ba41f98fa0d322bfd Mon Sep 17 00:00:00 2001 From: Edward Smyth Date: Mon, 1 Aug 2022 11:59:18 -0400 Subject: [PATCH 168/243] BLIS: Improve architecture selection at runtime Enable meaningful names as options for BLIS_ARCH_TYPE environment variable. For example, BLIS_ARCH_TYPE=zen4 or BLIS_ARCH_TYPE='ZEN4' or BLIS_ARCH_TYPE=6 will select the same code path (in this release). The meaningful names are not case sensitive. This implements change 1 in the Jira ticket below. Following review comments: 1. Use names from arch_t enum in function bli_env_get_var_arch_type() rather than directly using numbers. 2. AMD copyrights updated. AMD-Internal: [CPUPL-2235] Change-Id: I8cfd43d34765d5e8c7e35680d18825d9934753ad --- frame/base/bli_arch.c | 6 +- frame/base/bli_env.c | 151 +++++++++++++++++++++++++++++++++- frame/base/bli_env.h | 4 +- frame/include/bli_type_defs.h | 2 + 4 files changed, 159 insertions(+), 4 deletions(-) diff --git a/frame/base/bli_arch.c b/frame/base/bli_arch.c index 2696236717..aa6940ba49 100644 --- a/frame/base/bli_arch.c +++ b/frame/base/bli_arch.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018-2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -83,7 +83,7 @@ void bli_arch_set_id( void ) // Check the environment variable BLIS_ARCH_TYPE to see if the user // requested that we use a specific subconfiguration. 
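	// The value may be given either as the numeric id or as a case-insensitive
	// configuration name such as "zen4"; bli_env_get_var_arch_type() performs
	// the name-to-id translation.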
- dim_t req_id = bli_env_get_var( "BLIS_ARCH_TYPE", -1 ); + dim_t req_id = bli_env_get_var_arch_type( "BLIS_ARCH_TYPE", -1 ); #ifndef BLIS_CONFIGURETIME_CPUID if ( req_id != -1 ) @@ -230,6 +230,8 @@ void bli_arch_set_id( void ) // enumeration that is typedef'ed in bli_type_defs.h. That is, the // index order of each string should correspond to the implied/assigned // enum value given to the corresponding BLIS_ARCH_ value. +// This must also be kept up-to-date with the bli_env_get_var_arch_type() +// function in bli_env.c static char* config_name[ BLIS_NUM_ARCHS ] = { "skx", diff --git a/frame/base/bli_env.c b/frame/base/bli_env.c index 23b8e059e1..2cb9efd87a 100644 --- a/frame/base/bli_env.c +++ b/frame/base/bli_env.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -69,6 +69,155 @@ gint_t bli_env_get_var( const char* env, gint_t fallback ) return r_val; } +gint_t bli_env_get_var_arch_type( const char* env, gint_t fallback ) +{ + gint_t r_val; + char* str; + int i, size; + + // Query the environment variable and store the result in str. + str = getenv( env ); + + // Set the return value based on the string obtained from getenv(). + if ( str != NULL ) + { + // If there was no error, convert the string to an integer and + // prepare to return that integer. + r_val = ( gint_t )strtol( str, NULL, 10 ); + + if (r_val == 0) + { + // Could be deliberately 0 (currently meaning "skx") or + // a non-numeric value. We still allow direct specification + // of integer value to select code path. Non-zero integer + // values bypass this code block and are handled as before. + // Here we look for known meaningful names, and return 0 + // if we cannot find a match. 
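	// The match is made case-insensitive by lowercasing the string obtained
	// from getenv() in place before the strcmp() comparisons below.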
+ // This code MUST be kept in synch with arch_t enumeration + // in bli_type_defs.h and array config_name in bli_arch.c + + // convert string to lowercase + size = strlen(str); + for (i=0;i<=size;i++) + { + str[i] = tolower(str[i]); + } + + // Intel + if (strcmp(str, "skx") == 0) + { + r_val = BLIS_ARCH_SKX; + } + else if (strcmp(str, "knl") == 0) + { + r_val = BLIS_ARCH_KNL; + } + else if (strcmp(str, "knc") == 0) + { + r_val = BLIS_ARCH_KNC; + } + else if (strcmp(str, "haswell") == 0) + { + r_val = BLIS_ARCH_HASWELL; + } + else if (strcmp(str, "sandybridge") == 0) + { + r_val = BLIS_ARCH_SANDYBRIDGE; + } + else if (strcmp(str, "penryn") == 0) + { + r_val = BLIS_ARCH_PENRYN; + } + // AMD + else if (strcmp(str, "zen4") == 0) + { + r_val = BLIS_ARCH_ZEN4; + } + else if (strcmp(str, "zen3") == 0) + { + r_val = BLIS_ARCH_ZEN3; + } + else if (strcmp(str, "zen2") == 0) + { + r_val = BLIS_ARCH_ZEN2; + } + else if ((strcmp(str, "zen") == 0) || (strcmp(str, "zen1") == 0)) + { + r_val = BLIS_ARCH_ZEN; + } + else if (strcmp(str, "excavator") == 0) + { + r_val = BLIS_ARCH_EXCAVATOR; + } + else if (strcmp(str, "steamroller") == 0) + { + r_val = BLIS_ARCH_STEAMROLLER; + } + else if (strcmp(str, "piledriver") == 0) + { + r_val = BLIS_ARCH_PILEDRIVER; + } + else if (strcmp(str, "bulldozer") == 0) + { + r_val = BLIS_ARCH_BULLDOZER; + } + // ARM + else if (strcmp(str, "thunderx2") == 0) + { + r_val = BLIS_ARCH_THUNDERX2; + } + else if (strcmp(str, "cortexa57") == 0) + { + r_val = BLIS_ARCH_CORTEXA57; + } + else if (strcmp(str, "cortexa53") == 0) + { + r_val = BLIS_ARCH_CORTEXA53; + } + else if (strcmp(str, "cortexa15") == 0) + { + r_val = BLIS_ARCH_CORTEXA15; + } + else if (strcmp(str, "cortexa9") == 0) + { + r_val = BLIS_ARCH_CORTEXA9; + } + // IBM POWER + else if (strcmp(str, "power10") == 0) + { + r_val = BLIS_ARCH_POWER10; + } + else if (strcmp(str, "power9") == 0) + { + r_val = BLIS_ARCH_POWER9; + } + else if (strcmp(str, "power7") == 0) + { + r_val = BLIS_ARCH_POWER7; + } + else if (strcmp(str, "bgq") == 0) + { + r_val = BLIS_ARCH_BGQ; + } + // Generic + else if (strcmp(str, "generic") == 0) + { + r_val = BLIS_ARCH_GENERIC; + } + + // No else case means we return r_val=0, i.e. this behaves + // the same as generic bli_env_get_var(). + } + } + else + { + // If there was an error, use the "fallback" as the return value. + r_val = fallback; + } + + return r_val; +} + #if 0 #ifdef _MSC_VER #define strerror_r(errno,buf,len) strerror_s(buf,len,errno) diff --git a/frame/base/bli_env.h b/frame/base/bli_env.h index de86fadff0..eaa778cd20 100644 --- a/frame/base/bli_env.h +++ b/frame/base/bli_env.h @@ -6,7 +6,7 @@ Copyright (C) 2014, The University of Texas at Austin Copyright (C) 2016, Hewlett Packard Enterprise Development LP - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -40,5 +40,7 @@ gint_t bli_env_get_var( const char* env, gint_t fallback ); //void bli_env_set_var( const char* env, dim_t value ); +gint_t bli_env_get_var_arch_type( const char* env, gint_t fallback ); + #endif diff --git a/frame/include/bli_type_defs.h b/frame/include/bli_type_defs.h index 4e28d8b461..47377aa250 100644 --- a/frame/include/bli_type_defs.h +++ b/frame/include/bli_type_defs.h @@ -990,6 +990,8 @@ typedef enum // string array in bli_arch.c. 
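As a quick check of the behaviour added above, the following standalone sketch sets BLIS_ARCH_TYPE to a name and reports which sub-configuration BLIS resolves it to. It is only an illustration: it assumes the existing helpers bli_arch_query_id() and bli_arch_string() (declared alongside bli_arch.c, not part of this patch) and POSIX setenv(), and the variable must be exported before the first BLIS call because the arch id is cached on first use.

    #include <stdio.h>
    #include <stdlib.h>
    #include "blis.h"

    int main( void )
    {
        /* "zen4", "ZEN4" and the matching integer value all map to the same
           arch_t; names are matched case-insensitively. Must be set before
           the first BLIS call, since the id is cached once. */
        setenv( "BLIS_ARCH_TYPE", "Zen4", 1 );

        arch_t id = bli_arch_query_id();
        printf( "selected sub-configuration: %s (id %d)\n",
                bli_arch_string( id ), ( int )id );
        return 0;
    }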
Whenever values are added/inserted // OR if values are rearranged, be sure to update the string array // in bli_arch.c. +// This must also be kept up-to-date with the bli_env_get_var_arch_type() +// function in bli_env.c typedef enum { From 8504ef013da47cdcc325f817a8fc874806fb2648 Mon Sep 17 00:00:00 2001 From: "Mangala.V" Date: Fri, 22 Jul 2022 14:52:24 +0530 Subject: [PATCH 169/243] Optimisation of DTRSM and ZTRSM 1. Extract instruction replaced with cast when accessing first 128bit, as cast inst needs no cycle but extract takes few cycles 2. Added prefetch of A buffer when computing gemm operation 3. Added prefetch of C11 buffer before TRSM operation, with offset of 7 to cs_c With above changes performance improvements observed in case of Single thread Change-Id: Id377c490ddac8b06384acfa9a6d89dbe11bbc7be --- kernels/zen/3/bli_trsm_small.c | 549 +++++++++++++++------------------ 1 file changed, 252 insertions(+), 297 deletions(-) diff --git a/kernels/zen/3/bli_trsm_small.c b/kernels/zen/3/bli_trsm_small.c index bb8a2e9cc5..5b6df35d77 100644 --- a/kernels/zen/3/bli_trsm_small.c +++ b/kernels/zen/3/bli_trsm_small.c @@ -668,9 +668,11 @@ BLIS_INLINE err_t dtrsm_XAltB_ref ymm15 = _mm256_setzero_pd(); /*GEMM block used in trsm small right cases*/ +/* B = 8x6, A = 6x6 */ #define BLIS_DTRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) \ for(k = 0; k < k_iter; k++) \ {\ + _mm_prefetch((char*)( a01 + 8), _MM_HINT_T0); \ /*load 8x1 block of B10*/ \ ymm0 = _mm256_loadu_pd((double const *)b10); \ ymm1 = _mm256_loadu_pd((double const *)(b10 + 4)); \ @@ -1278,7 +1280,7 @@ BLIS_INLINE err_t dtrsm_XAltB_ref \ _mm256_storeu_pd((double *)(b11), ymm0); /*store(B11[0-3][0])*/\ _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); /*store(B11[0-3][1])*/\ - xmm5 = _mm256_extractf128_pd(ymm2, 0);\ + xmm5 = _mm256_castpd256_pd128(ymm2);\ _mm_storeu_pd((double *)(b11 + cs_b * 2), xmm5);\ _mm_storel_pd((b11 + cs_b * 2 + 2), _mm256_extractf128_pd(ymm2, 1)); @@ -1297,7 +1299,7 @@ BLIS_INLINE err_t dtrsm_XAltB_ref ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08);\ \ _mm256_storeu_pd((double *)(b11), ymm0); /*store(B11[0-3][0])*/\ - xmm5 = _mm256_extractf128_pd(ymm1, 0);\ + xmm5 = _mm256_castpd256_pd128(ymm1);\ _mm_storeu_pd((double *)(b11 + cs_b * 1), xmm5);\ _mm_storel_pd((b11 + cs_b * 1 + 2), _mm256_extractf128_pd(ymm1, 1)); @@ -1310,7 +1312,7 @@ BLIS_INLINE err_t dtrsm_XAltB_ref ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8);\ ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08);\ \ - xmm5 = _mm256_extractf128_pd(ymm0, 0);\ + xmm5 = _mm256_castpd256_pd128(ymm0);\ _mm_storeu_pd((double *)(b11), xmm5);\ _mm_storel_pd((b11 + 2), _mm256_extractf128_pd(ymm0, 1)); @@ -1333,7 +1335,7 @@ BLIS_INLINE err_t dtrsm_XAltB_ref \ _mm256_storeu_pd((double *)(b11), ymm0); /*store(B11[0-3][0])*/\ _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); /*store(B11[0-3][1])*/\ - xmm5 = _mm256_extractf128_pd(ymm2, 0);\ + xmm5 = _mm256_castpd256_pd128(ymm2);\ _mm_storeu_pd((double *)(b11 + cs_b * 2), xmm5); #define BLIS_PRE_DTRSM_SMALL_2M_2N(AlphaVal,b11,cs_b)\ @@ -1350,7 +1352,7 @@ BLIS_INLINE err_t dtrsm_XAltB_ref ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C);\ \ _mm256_storeu_pd((double *)(b11), ymm0); /*store(B11[0-3][0])*/\ - xmm5 = _mm256_extractf128_pd(ymm1, 0);\ + xmm5 = _mm256_castpd256_pd128(ymm1);\ _mm_storeu_pd((double *)(b11 + cs_b * 1), xmm5); #define BLIS_PRE_DTRSM_SMALL_2M_1N(AlphaVal,b11,cs_b)\ @@ -1362,7 +1364,7 @@ BLIS_INLINE err_t dtrsm_XAltB_ref ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8);\ ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C);\ \ - xmm5 = 
_mm256_extractf128_pd(ymm0, 0);\ + xmm5 = _mm256_castpd256_pd128(ymm0);\ _mm_storeu_pd((double *)(b11 + cs_b * 0), xmm5); #define BLIS_PRE_DTRSM_SMALL_1M_3N(AlphaVal,b11,cs_b)\ @@ -1380,9 +1382,9 @@ BLIS_INLINE err_t dtrsm_XAltB_ref ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E);\ ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E);\ \ - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm0, 0));\ - _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm1, 0));\ - _mm_storel_pd((b11 + cs_b * 2), _mm256_extractf128_pd(ymm2, 0)); + _mm_storel_pd((b11 + cs_b * 0), _mm256_castpd256_pd128(ymm0));\ + _mm_storel_pd((b11 + cs_b * 1), _mm256_castpd256_pd128(ymm1));\ + _mm_storel_pd((b11 + cs_b * 2), _mm256_castpd256_pd128(ymm2)); #define BLIS_PRE_DTRSM_SMALL_1M_2N(AlphaVal,b11,cs_b)\ ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); /*register to hold alpha*/\ @@ -1396,8 +1398,8 @@ BLIS_INLINE err_t dtrsm_XAltB_ref ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E);\ ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E);\ \ - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm0, 0));\ - _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm1, 0)); + _mm_storel_pd((b11 + cs_b * 0), _mm256_castpd256_pd128(ymm0));\ + _mm_storel_pd((b11 + cs_b * 1), _mm256_castpd256_pd128(ymm1)); #define BLIS_PRE_DTRSM_SMALL_1M_1N(AlphaVal,b11,cs_b)\ ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); /*register to hold alpha*/\ @@ -1407,7 +1409,7 @@ BLIS_INLINE err_t dtrsm_XAltB_ref \ ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E);\ \ - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm0, 0)); + _mm_storel_pd((b11 + cs_b * 0), _mm256_castpd256_pd128(ymm0)); /* pre & post TRSM for Right remainder cases*/ #define BLIS_PRE_DTRSM_SMALL_3N_3M(AlphaVal,b11,cs_b)\ @@ -1436,7 +1438,7 @@ BLIS_INLINE err_t dtrsm_XAltB_ref \ _mm256_storeu_pd((double *)b11, ymm3);\ _mm256_storeu_pd((double *)(b11 + cs_b), ymm5);\ - xmm5 = _mm256_extractf128_pd(ymm7, 0);\ + xmm5 = _mm256_castpd256_pd128(ymm7);\ _mm_storeu_pd((double *)(b11 + cs_b * 2),xmm5);\ _mm_storel_pd((b11 + cs_b * 2 + 2), _mm256_extractf128_pd(ymm7, 1)); @@ -1464,7 +1466,7 @@ BLIS_INLINE err_t dtrsm_XAltB_ref \ _mm256_storeu_pd((double *)b11, ymm3);\ _mm256_storeu_pd((double *)(b11 + cs_b), ymm5);\ - xmm5 = _mm256_extractf128_pd(ymm7, 0);\ + xmm5 = _mm256_castpd256_pd128(ymm7);\ _mm_storeu_pd((double *)(b11 + cs_b * 2),xmm5); #define BLIS_PRE_DTRSM_SMALL_3N_1M(AlphaVal,b11,cs_b)\ @@ -1487,9 +1489,9 @@ BLIS_INLINE err_t dtrsm_XAltB_ref ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2));\ ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x01);\ \ - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm3, 0));\ - _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm5, 0));\ - _mm_storel_pd((b11 + cs_b * 2), _mm256_extractf128_pd(ymm7, 0)); + _mm_storel_pd((b11 + cs_b * 0), _mm256_castpd256_pd128(ymm3));\ + _mm_storel_pd((b11 + cs_b * 1), _mm256_castpd256_pd128(ymm5));\ + _mm_storel_pd((b11 + cs_b * 2), _mm256_castpd256_pd128(ymm7)); #define BLIS_PRE_DTRSM_SMALL_2N_3M(AlphaVal,b11,cs_b)\ ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); /*register to hold alpha*/\ @@ -1511,7 +1513,7 @@ BLIS_INLINE err_t dtrsm_XAltB_ref ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x07);\ \ _mm256_storeu_pd((double *)b11, ymm3);\ - xmm5 = _mm256_extractf128_pd(ymm5, 0);\ + xmm5 = _mm256_castpd256_pd128(ymm5);\ _mm_storeu_pd((double *)(b11 + cs_b*1), xmm5);\ _mm_storel_pd((b11 + cs_b * 1 + 2), _mm256_extractf128_pd(ymm5, 1)); @@ -1533,7 +1535,7 @@ BLIS_INLINE err_t dtrsm_XAltB_ref ymm5 = 
_mm256_blend_pd(ymm0, ymm5, 0x03);\ \ _mm256_storeu_pd((double *)b11, ymm3);\ - xmm5 = _mm256_extractf128_pd(ymm5, 0);\ + xmm5 = _mm256_castpd256_pd128(ymm5);\ _mm_storeu_pd((double *)(b11 + cs_b*1), xmm5); #define BLIS_PRE_DTRSM_SMALL_2N_1M(AlphaVal,b11,cs_b)\ @@ -1551,8 +1553,8 @@ BLIS_INLINE err_t dtrsm_XAltB_ref ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b));\ ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x01);\ \ - _mm_storel_pd(b11 , _mm256_extractf128_pd(ymm3, 0));\ - _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm5, 0)); + _mm_storel_pd(b11 , _mm256_castpd256_pd128(ymm3));\ + _mm_storel_pd((b11 + cs_b * 1), _mm256_castpd256_pd128(ymm5)); #define BLIS_PRE_DTRSM_SMALL_1N_3M(AlphaVal,b11,cs_b)\ ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); /*register to hold alpha*/\ @@ -1563,7 +1565,7 @@ BLIS_INLINE err_t dtrsm_XAltB_ref ymm3 = _mm256_fmsub_pd(ymm6, ymm15, ymm3); #define BLIS_POST_DTRSM_SMALL_1N_3M(b11,cs_b)\ - xmm5 = _mm256_extractf128_pd(ymm3, 0);\ + xmm5 = _mm256_castpd256_pd128(ymm3);\ _mm_storeu_pd((double *)(b11), xmm5);\ _mm_storel_pd((b11 + 2), _mm256_extractf128_pd(ymm3, 1)); @@ -1578,7 +1580,7 @@ BLIS_INLINE err_t dtrsm_XAltB_ref ymm0 = _mm256_loadu_pd((double const *)b11);\ ymm3 = _mm256_blend_pd(ymm6, ymm3, 0x03);\ \ - xmm5 = _mm256_extractf128_pd(ymm3, 0);\ + xmm5 = _mm256_castpd256_pd128(ymm3);\ _mm_storeu_pd((double *)(b11), xmm5); #define BLIS_PRE_DTRSM_SMALL_1N_1M(AlphaVal,b11,cs_b)\ @@ -1590,7 +1592,7 @@ BLIS_INLINE err_t dtrsm_XAltB_ref #define BLIS_POST_DTRSM_SMALL_1N_1M(b11,cs_b)\ ymm3 = _mm256_blend_pd(ymm6, ymm3, 0x01);\ \ - _mm_storel_pd(b11, _mm256_extractf128_pd(ymm3, 0)); + _mm_storel_pd(b11, _mm256_castpd256_pd128(ymm3)); /* multiply with Alpha pre TRSM for 6*8 kernel*/ #define BLIS_PRE_DTRSM_SMALL_6x8(AlphaVal,b11,cs_b)\ @@ -3439,7 +3441,6 @@ BLIS_INLINE void bli_dtrsm_small_pack __m256d ymm8, ymm9, ymm10, ymm11; __m256d ymm12, ymm13; __m128d xmm0,xmm1,xmm2,xmm3; - double zero = 0.0; if(side=='L'||side=='l') { @@ -3595,12 +3596,10 @@ BLIS_INLINE void bli_dtrsm_small_pack ymm4 = _mm256_unpacklo_pd(ymm10, ymm11); ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_broadcast_sd((double const *)&zero); ymm0 = _mm256_unpackhi_pd(ymm10, ymm11); ymm1 = _mm256_unpackhi_pd(ymm12, ymm13); ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_broadcast_sd((double const *)&zero); _mm256_storeu_pd((double *)(pbuff + p_lda * 4), ymm6); _mm256_storeu_pd((double *)(pbuff + p_lda * 5), ymm7); @@ -3611,32 +3610,19 @@ BLIS_INLINE void bli_dtrsm_small_pack ymm11 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 5 + 4)); ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_broadcast_sd((double const *)&zero); - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_broadcast_sd((double const *)&zero); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - _mm_storeu_pd((double *)(pbuff + 4), _mm256_extractf128_pd(ymm6,0)); - _mm_storeu_pd((double *)(pbuff + 4 + p_lda), _mm256_extractf128_pd(ymm7,0)); - _mm_storeu_pd((double *)(pbuff + 4 + p_lda*2), _mm256_extractf128_pd(ymm8,0)); - _mm_storeu_pd((double *)(pbuff + 4 + p_lda*3), _mm256_extractf128_pd(ymm9,0)); + _mm_storeu_pd((double *)(pbuff + 4), _mm256_castpd256_pd128(ymm4)); + _mm_storeu_pd((double *)(pbuff + 4 + p_lda), _mm256_castpd256_pd128(ymm0)); + _mm_storeu_pd((double *)(pbuff + 4 + p_lda*2), 
_mm256_extractf128_pd(ymm4,1)); + _mm_storeu_pd((double *)(pbuff + 4 + p_lda*3), _mm256_extractf128_pd(ymm0,1)); ymm4 = _mm256_unpacklo_pd(ymm10, ymm11); - ymm5 = _mm256_broadcast_sd((double const *)&zero); - - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_broadcast_sd((double const *)&zero); ymm0 = _mm256_unpackhi_pd(ymm10, ymm11); - ymm1 = _mm256_broadcast_sd((double const *)&zero); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_broadcast_sd((double const *)&zero); - _mm_storeu_pd((double *)(pbuff + p_lda * 4 + 4), _mm256_extractf128_pd(ymm6,0)); - _mm_storeu_pd((double *)(pbuff + p_lda * 5 + 4), _mm256_extractf128_pd(ymm7,0)); + _mm_storeu_pd((double *)(pbuff + p_lda * 4 + 4), _mm256_castpd256_pd128(ymm4)); + _mm_storeu_pd((double *)(pbuff + p_lda * 5 + 4), _mm256_castpd256_pd128(ymm0)); inbuf += mr*cs_a; pbuff += mr; } @@ -3740,7 +3726,7 @@ BLIS_INLINE void dtrsm_small_pack_diag_element if(is_eight){ _mm256_store_pd((double *)(d11_pack + 4), ymm5); }else{ - _mm_storeu_pd((double *)(d11_pack + 4), _mm256_extractf128_pd(ymm5,0)); + _mm_storeu_pd((double *)(d11_pack + 4), _mm256_castpd256_pd128(ymm5)); } } @@ -4291,7 +4277,7 @@ BLIS_INLINE err_t ztrsm_AuXB_ref /*get the dcomplex mul answer into register*/\ ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ ymm8 = _mm256_sub_pd(ymm15,ymm8);\ - xmm5 = _mm256_extractf128_pd(ymm8, 0);\ + xmm5 = _mm256_castpd256_pd128(ymm8);\ /*store dcomplex elements*/\ _mm_storeu_pd((double *)(b11 + cs_b * 0), xmm5);\ } @@ -4329,9 +4315,9 @@ BLIS_INLINE err_t ztrsm_AuXB_ref ymm14 = _mm256_mul_pd(ymm1, ymm14);\ ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ ymm9 = _mm256_sub_pd(ymm15,ymm9);\ - xmm4 = _mm256_extractf128_pd(ymm8, 0);\ + xmm4 = _mm256_castpd256_pd128(ymm8);\ _mm_storeu_pd((double *)(b11 + cs_b * 0), xmm4);\ - xmm5 = _mm256_extractf128_pd(ymm9, 0);\ + xmm5 = _mm256_castpd256_pd128(ymm9);\ _mm_storeu_pd((double *)(b11 + cs_b * 1), xmm5);\ } @@ -4854,7 +4840,7 @@ BLIS_INLINE err_t ztrsm_AuXB_ref ymm12 = _mm256_sub_pd(ymm15,ymm12);\ \ _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm8);\ - xmm5 = _mm256_extractf128_pd(ymm12, 0);\ + xmm5 = _mm256_castpd256_pd128(ymm12);\ _mm_storeu_pd((double *)(b11 + cs_b * 0 + 2), xmm5);\ } @@ -4910,9 +4896,9 @@ BLIS_INLINE err_t ztrsm_AuXB_ref \ _mm256_storeu_pd((double *)(b11), ymm8);\ _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm9);\ - xmm4 = _mm256_extractf128_pd(ymm12, 0);\ + xmm4 = _mm256_castpd256_pd128(ymm12);\ _mm_storeu_pd((double *)(b11 + cs_b * 0 + 2), xmm4);\ - xmm5 = _mm256_extractf128_pd(ymm13, 0);\ + xmm5 = _mm256_castpd256_pd128(ymm13);\ _mm_storeu_pd((double *)(b11 + cs_b * 1 + 2), xmm5);\ } @@ -6091,9 +6077,9 @@ BLIS_INLINE void bli_ztrsm_small_pack ymm7 = _mm256_permute2f128_pd(ymm0,ymm5,0x31); ymm8 = _mm256_permute2f128_pd(ymm1,ymm5,0x20); - _mm_storeu_pd((double *)(pbuff + 2), _mm256_extractf128_pd(ymm6,0)); - _mm_storeu_pd((double *)(pbuff + p_lda + 2), _mm256_extractf128_pd(ymm7,0)); - _mm_storeu_pd((double *)(pbuff + p_lda * 2 + 2), _mm256_extractf128_pd(ymm8,0)); + _mm_storeu_pd((double *)(pbuff + 2), _mm256_castpd256_pd128(ymm6)); + _mm_storeu_pd((double *)(pbuff + p_lda + 2), _mm256_castpd256_pd128(ymm7)); + _mm_storeu_pd((double *)(pbuff + p_lda * 2 + 2), _mm256_castpd256_pd128(ymm8)); inbuf += mr*cs_a; pbuff += mr; @@ -6227,7 +6213,6 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB dim_t k_iter; //determines the number of GEMM operations to be done double ones = 1.0; - double zero = 0.0; bool is_unitdiag = bli_obj_has_unit_diag(a); double AlphaVal = *(double 
*)AlphaObj->buffer; //value of Alpha @@ -6363,6 +6348,13 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB b. Towards the end TRSM output will be stored back into b11 */ + _mm_prefetch((char*)(b11 + 0 + 7), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + cs_b + 7), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + 2 * cs_b + 7), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + 3 * cs_b + 7), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + 4 * cs_b + 7), _MM_HINT_T0); + _mm_prefetch((char*)(b11 + 5 * cs_b + 7), _MM_HINT_T0); + //extract a00 ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); @@ -6727,12 +6719,12 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - _mm_storeu_pd((double *)b11, _mm256_extractf128_pd(ymm3,0)); - _mm_storeu_pd((double *)(b11 + cs_b), _mm256_extractf128_pd(ymm5,0)); - _mm_storeu_pd((double *)(b11 + cs_b*2), _mm256_extractf128_pd(ymm7,0)); - _mm_storeu_pd((double *)(b11 + cs_b*3), _mm256_extractf128_pd(ymm9,0)); - _mm_storeu_pd((double *)(b11 + cs_b*4), _mm256_extractf128_pd(ymm11,0)); - _mm_storeu_pd((double *)(b11 + cs_b*5), _mm256_extractf128_pd(ymm13,0)); + _mm_storeu_pd((double *)b11, _mm256_castpd256_pd128(ymm3)); + _mm_storeu_pd((double *)(b11 + cs_b), _mm256_castpd256_pd128(ymm5)); + _mm_storeu_pd((double *)(b11 + cs_b*2), _mm256_castpd256_pd128(ymm7)); + _mm_storeu_pd((double *)(b11 + cs_b*3), _mm256_castpd256_pd128(ymm9)); + _mm_storeu_pd((double *)(b11 + cs_b*4), _mm256_castpd256_pd128(ymm11)); + _mm_storeu_pd((double *)(b11 + cs_b*5), _mm256_castpd256_pd128(ymm13)); _mm_storel_pd((double *)b11 + 2, _mm256_extractf128_pd(ymm3,1)); _mm_storel_pd((double *)(b11 + cs_b + 2), _mm256_extractf128_pd(ymm5,1)); @@ -6851,12 +6843,12 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - _mm_storeu_pd((double *)b11, _mm256_extractf128_pd(ymm3,0)); - _mm_storeu_pd((double *)(b11 + cs_b), _mm256_extractf128_pd(ymm5,0)); - _mm_storeu_pd((double *)(b11 + cs_b*2), _mm256_extractf128_pd(ymm7,0)); - _mm_storeu_pd((double *)(b11 + cs_b*3), _mm256_extractf128_pd(ymm9,0)); - _mm_storeu_pd((double *)(b11 + cs_b*4), _mm256_extractf128_pd(ymm11,0)); - _mm_storeu_pd((double *)(b11 + cs_b*5), _mm256_extractf128_pd(ymm13,0)); + _mm_storeu_pd((double *)b11, _mm256_castpd256_pd128(ymm3)); + _mm_storeu_pd((double *)(b11 + cs_b), _mm256_castpd256_pd128(ymm5)); + _mm_storeu_pd((double *)(b11 + cs_b*2), _mm256_castpd256_pd128(ymm7)); + _mm_storeu_pd((double *)(b11 + cs_b*3), _mm256_castpd256_pd128(ymm9)); + _mm_storeu_pd((double *)(b11 + cs_b*4), _mm256_castpd256_pd128(ymm11)); + _mm_storeu_pd((double *)(b11 + cs_b*5), _mm256_castpd256_pd128(ymm13)); m_remainder -= 2; i += 2; @@ -6968,12 +6960,12 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - _mm_storel_pd((double *)b11, _mm256_extractf128_pd(ymm3,0)); - _mm_storel_pd((double *)(b11 + cs_b), _mm256_extractf128_pd(ymm5,0)); - _mm_storel_pd((double *)(b11 + cs_b*2), _mm256_extractf128_pd(ymm7,0)); - _mm_storel_pd((double *)(b11 + cs_b*3), _mm256_extractf128_pd(ymm9,0)); - _mm_storel_pd((double *)(b11 + cs_b*4), _mm256_extractf128_pd(ymm11,0)); - _mm_storel_pd((double *)(b11 + cs_b*5), _mm256_extractf128_pd(ymm13,0)); + _mm_storel_pd((double *)b11, _mm256_castpd256_pd128(ymm3)); + _mm_storel_pd((double *)(b11 + cs_b), _mm256_castpd256_pd128(ymm5)); + _mm_storel_pd((double *)(b11 + cs_b*2), _mm256_castpd256_pd128(ymm7)); + _mm_storel_pd((double *)(b11 + cs_b*3), _mm256_castpd256_pd128(ymm9)); + _mm_storel_pd((double *)(b11 + 
cs_b*4), _mm256_castpd256_pd128(ymm11)); + _mm_storel_pd((double *)(b11 + cs_b*5), _mm256_castpd256_pd128(ymm13)); m_remainder -= 1; i += 1; @@ -7028,21 +7020,12 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a * 5)); ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_broadcast_sd((double const *)&zero); - - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_broadcast_sd((double const *)&zero); - - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_extractf128_pd(ymm6,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_extractf128_pd(ymm7,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm8,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm9,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_castpd256_pd128(ymm4)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_castpd256_pd128(ymm0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm4,1)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm0,1)); a01 += d_nr*cs_a; ptr_a10_dup += d_nr; @@ -7365,10 +7348,10 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - _mm_storeu_pd((double *)b11, _mm256_extractf128_pd(ymm3,0)); - _mm_storeu_pd((double *)(b11 + cs_b), _mm256_extractf128_pd(ymm5,0)); - _mm_storeu_pd((double *)(b11 + cs_b*2), _mm256_extractf128_pd(ymm7,0)); - _mm_storeu_pd((double *)(b11 + cs_b*3), _mm256_extractf128_pd(ymm9,0)); + _mm_storeu_pd((double *)b11, _mm256_castpd256_pd128(ymm3)); + _mm_storeu_pd((double *)(b11 + cs_b), _mm256_castpd256_pd128(ymm5)); + _mm_storeu_pd((double *)(b11 + cs_b*2), _mm256_castpd256_pd128(ymm7)); + _mm_storeu_pd((double *)(b11 + cs_b*3), _mm256_castpd256_pd128(ymm9)); _mm_storel_pd((double *)b11 + 2, _mm256_extractf128_pd(ymm3,1)); _mm_storel_pd((double *)(b11 + cs_b + 2), _mm256_extractf128_pd(ymm5,1)); @@ -7454,10 +7437,10 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - _mm_storeu_pd((double *)b11, _mm256_extractf128_pd(ymm3,0)); - _mm_storeu_pd((double *)(b11 + cs_b), _mm256_extractf128_pd(ymm5,0)); - _mm_storeu_pd((double *)(b11 + cs_b*2), _mm256_extractf128_pd(ymm7,0)); - _mm_storeu_pd((double *)(b11 + cs_b*3), _mm256_extractf128_pd(ymm9,0)); + _mm_storeu_pd((double *)b11, _mm256_castpd256_pd128(ymm3)); + _mm_storeu_pd((double *)(b11 + cs_b), _mm256_castpd256_pd128(ymm5)); + _mm_storeu_pd((double *)(b11 + cs_b*2), _mm256_castpd256_pd128(ymm7)); + _mm_storeu_pd((double *)(b11 + cs_b*3), _mm256_castpd256_pd128(ymm9)); m_remainder -= 2; i += 2; @@ -7537,10 +7520,10 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm3, 0)); - _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm5, 0)); - _mm_storel_pd((b11 + cs_b * 2), _mm256_extractf128_pd(ymm7, 0)); - _mm_storel_pd((b11 + cs_b * 3), _mm256_extractf128_pd(ymm9, 0)); + _mm_storel_pd((b11 + cs_b * 0), _mm256_castpd256_pd128(ymm3)); + _mm_storel_pd((b11 + cs_b * 1), _mm256_castpd256_pd128(ymm5)); + _mm_storel_pd((b11 + cs_b * 2), _mm256_castpd256_pd128(ymm7)); + _mm_storel_pd((b11 + cs_b * 3), _mm256_castpd256_pd128(ymm9)); m_remainder -= 1; i += 1; @@ 
-7589,21 +7572,12 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a * 5)); ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_broadcast_sd((double const *)&zero); - - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_broadcast_sd((double const *)&zero); - - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_extractf128_pd(ymm6,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_extractf128_pd(ymm7,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm8,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm9,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_castpd256_pd128(ymm4)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_castpd256_pd128(ymm0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm4,1)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm0,1)); a01 += d_nr*cs_a; ptr_a10_dup += d_nr; @@ -8010,21 +7984,12 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a * 5)); ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_broadcast_sd((double const *)&zero); - - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_broadcast_sd((double const *)&zero); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - - _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_extractf128_pd(ymm6,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_extractf128_pd(ymm7,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm8,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm9,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_castpd256_pd128(ymm4)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_castpd256_pd128(ymm0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm4,1)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm0,1)); a01 += d_nr*cs_a; ptr_a10_dup += d_nr; @@ -8339,21 +8304,13 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a * 5)); ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_broadcast_sd((double const *)&zero); - - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_broadcast_sd((double const *)&zero); - - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_extractf128_pd(ymm6,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_extractf128_pd(ymm7,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm8,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm9,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_castpd256_pd128(ymm4)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_castpd256_pd128(ymm0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm4,1)); + 
_mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm0,1)); a01 += d_nr*cs_a; ptr_a10_dup += d_nr; @@ -9105,12 +9062,12 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - _mm_storeu_pd((double *)b11, _mm256_extractf128_pd(ymm3,0)); - _mm_storeu_pd((double *)(b11 + cs_b), _mm256_extractf128_pd(ymm5,0)); - _mm_storeu_pd((double *)(b11 + cs_b*2), _mm256_extractf128_pd(ymm7,0)); - _mm_storeu_pd((double *)(b11 + cs_b*3), _mm256_extractf128_pd(ymm9,0)); - _mm_storeu_pd((double *)(b11 + cs_b*4), _mm256_extractf128_pd(ymm11,0)); - _mm_storeu_pd((double *)(b11 + cs_b*5), _mm256_extractf128_pd(ymm13,0)); + _mm_storeu_pd((double *)b11, _mm256_castpd256_pd128(ymm3)); + _mm_storeu_pd((double *)(b11 + cs_b), _mm256_castpd256_pd128(ymm5)); + _mm_storeu_pd((double *)(b11 + cs_b*2), _mm256_castpd256_pd128(ymm7)); + _mm_storeu_pd((double *)(b11 + cs_b*3), _mm256_castpd256_pd128(ymm9)); + _mm_storeu_pd((double *)(b11 + cs_b*4), _mm256_castpd256_pd128(ymm11)); + _mm_storeu_pd((double *)(b11 + cs_b*5), _mm256_castpd256_pd128(ymm13)); _mm_storel_pd((double *)b11 + 2, _mm256_extractf128_pd(ymm3,1)); _mm_storel_pd((double *)(b11 + cs_b + 2), _mm256_extractf128_pd(ymm5,1)); @@ -9220,12 +9177,12 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - _mm_storeu_pd((double *)b11, _mm256_extractf128_pd(ymm3,0)); - _mm_storeu_pd((double *)(b11 + cs_b), _mm256_extractf128_pd(ymm5,0)); - _mm_storeu_pd((double *)(b11 + cs_b*2), _mm256_extractf128_pd(ymm7,0)); - _mm_storeu_pd((double *)(b11 + cs_b*3), _mm256_extractf128_pd(ymm9,0)); - _mm_storeu_pd((double *)(b11 + cs_b*4), _mm256_extractf128_pd(ymm11,0)); - _mm_storeu_pd((double *)(b11 + cs_b*5), _mm256_extractf128_pd(ymm13,0)); + _mm_storeu_pd((double *)b11, _mm256_castpd256_pd128(ymm3)); + _mm_storeu_pd((double *)(b11 + cs_b), _mm256_castpd256_pd128(ymm5)); + _mm_storeu_pd((double *)(b11 + cs_b*2), _mm256_castpd256_pd128(ymm7)); + _mm_storeu_pd((double *)(b11 + cs_b*3), _mm256_castpd256_pd128(ymm9)); + _mm_storeu_pd((double *)(b11 + cs_b*4), _mm256_castpd256_pd128(ymm11)); + _mm_storeu_pd((double *)(b11 + cs_b*5), _mm256_castpd256_pd128(ymm13)); m_remainder -=2; } @@ -9328,12 +9285,12 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - _mm_storel_pd((double *)b11, _mm256_extractf128_pd(ymm3,0)); - _mm_storel_pd((double *)(b11 + cs_b), _mm256_extractf128_pd(ymm5,0)); - _mm_storel_pd((double *)(b11 + cs_b*2), _mm256_extractf128_pd(ymm7,0)); - _mm_storel_pd((double *)(b11 + cs_b*3), _mm256_extractf128_pd(ymm9,0)); - _mm_storel_pd((double *)(b11 + cs_b*4), _mm256_extractf128_pd(ymm11,0)); - _mm_storel_pd((double *)(b11 + cs_b*5), _mm256_extractf128_pd(ymm13,0)); + _mm_storel_pd((double *)b11, _mm256_castpd256_pd128(ymm3)); + _mm_storel_pd((double *)(b11 + cs_b), _mm256_castpd256_pd128(ymm5)); + _mm_storel_pd((double *)(b11 + cs_b*2), _mm256_castpd256_pd128(ymm7)); + _mm_storel_pd((double *)(b11 + cs_b*3), _mm256_castpd256_pd128(ymm9)); + _mm_storel_pd((double *)(b11 + cs_b*4), _mm256_castpd256_pd128(ymm11)); + _mm_storel_pd((double *)(b11 + cs_b*5), _mm256_castpd256_pd128(ymm13)); m_remainder -=1; } @@ -9399,10 +9356,10 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_extractf128_pd(ymm6,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_extractf128_pd(ymm7,0)); - 
_mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm8,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm9,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_castpd256_pd128(ymm6)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_castpd256_pd128(ymm7)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_castpd256_pd128(ymm8)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_castpd256_pd128(ymm9)); a01 += d_nr*cs_a; ptr_a10_dup += d_nr; @@ -9714,10 +9671,10 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - _mm_storeu_pd((double *)b11, _mm256_extractf128_pd(ymm3,0)); - _mm_storeu_pd((double *)(b11 + cs_b), _mm256_extractf128_pd(ymm5,0)); - _mm_storeu_pd((double *)(b11 + cs_b*2), _mm256_extractf128_pd(ymm7,0)); - _mm_storeu_pd((double *)(b11 + cs_b*3), _mm256_extractf128_pd(ymm9,0)); + _mm_storeu_pd((double *)b11, _mm256_castpd256_pd128(ymm3)); + _mm_storeu_pd((double *)(b11 + cs_b), _mm256_castpd256_pd128(ymm5)); + _mm_storeu_pd((double *)(b11 + cs_b*2), _mm256_castpd256_pd128(ymm7)); + _mm_storeu_pd((double *)(b11 + cs_b*3), _mm256_castpd256_pd128(ymm9)); _mm_storel_pd((double *)b11 + 2, _mm256_extractf128_pd(ymm3,1)); _mm_storel_pd((double *)(b11 + cs_b + 2), _mm256_extractf128_pd(ymm5,1)); @@ -9798,10 +9755,10 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - _mm_storeu_pd((double *)b11, _mm256_extractf128_pd(ymm3,0)); - _mm_storeu_pd((double *)(b11 + cs_b), _mm256_extractf128_pd(ymm5,0)); - _mm_storeu_pd((double *)(b11 + cs_b*2), _mm256_extractf128_pd(ymm7,0)); - _mm_storeu_pd((double *)(b11 + cs_b*3), _mm256_extractf128_pd(ymm9,0)); + _mm_storeu_pd((double *)b11, _mm256_castpd256_pd128(ymm3)); + _mm_storeu_pd((double *)(b11 + cs_b), _mm256_castpd256_pd128(ymm5)); + _mm_storeu_pd((double *)(b11 + cs_b*2), _mm256_castpd256_pd128(ymm7)); + _mm_storeu_pd((double *)(b11 + cs_b*3), _mm256_castpd256_pd128(ymm9)); m_remainder -=2; } @@ -9874,12 +9831,12 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm3, 0)); - _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm5, 0)); - _mm_storel_pd((b11 + cs_b * 2), _mm256_extractf128_pd(ymm7, 0)); - _mm_storel_pd((b11 + cs_b * 3), _mm256_extractf128_pd(ymm9, 0)); + _mm_storel_pd((b11 + cs_b * 0), _mm256_castpd256_pd128(ymm3)); + _mm_storel_pd((b11 + cs_b * 1), _mm256_castpd256_pd128(ymm5)); + _mm_storel_pd((b11 + cs_b * 2), _mm256_castpd256_pd128(ymm7)); + _mm_storel_pd((b11 + cs_b * 3), _mm256_castpd256_pd128(ymm9)); m_remainder -=1; } @@ -9938,10 +9895,10 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_extractf128_pd(ymm6,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_extractf128_pd(ymm7,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm8,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm9,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_castpd256_pd128(ymm6)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_castpd256_pd128(ymm7)); + _mm_storeu_pd((double 
*)(ptr_a10_dup + 4 + p_lda*2), _mm256_castpd256_pd128(ymm8)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_castpd256_pd128(ymm9)); a01 += d_nr*cs_a; ptr_a10_dup += d_nr; @@ -10347,10 +10304,10 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_extractf128_pd(ymm6,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_extractf128_pd(ymm7,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm8,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm9,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_castpd256_pd128(ymm6)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_castpd256_pd128(ymm7)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_castpd256_pd128(ymm8)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_castpd256_pd128(ymm9)); a01 += d_nr*cs_a; ptr_a10_dup += d_nr; @@ -10673,10 +10630,10 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_extractf128_pd(ymm6,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_extractf128_pd(ymm7,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm8,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm9,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_castpd256_pd128(ymm6)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_castpd256_pd128(ymm7)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_castpd256_pd128(ymm8)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_castpd256_pd128(ymm9)); a01 += d_nr*cs_a; ptr_a10_dup += d_nr; @@ -12372,7 +12329,7 @@ BLIS_INLINE err_t bli_dtrsm_small_AltXB_AuXB _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - xmm5 = _mm256_extractf128_pd(ymm3, 0); + xmm5 = _mm256_castpd256_pd128(ymm3); _mm_storeu_pd((double *)(b11 + cs_b * 3),xmm5); _mm_storel_pd((b11 + cs_b * 3 + 2), _mm256_extractf128_pd(ymm3, 1)); @@ -12566,7 +12523,7 @@ BLIS_INLINE err_t bli_dtrsm_small_AltXB_AuXB _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - xmm5 = _mm256_extractf128_pd(ymm3, 0); + xmm5 = _mm256_castpd256_pd128(ymm3); _mm_storeu_pd((double *)(b11 + cs_b * 3), xmm5); if(transa) @@ -14431,7 +14388,7 @@ BLIS_INLINE err_t bli_dtrsm_small_AutXB_AlXB _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - xmm5 = _mm256_extractf128_pd(ymm3, 0); + xmm5 = _mm256_castpd256_pd128(ymm3); _mm_storeu_pd((double *)(b11 + cs_b * 3),xmm5); _mm_storel_pd((b11 + cs_b * 3 + 2), _mm256_extractf128_pd(ymm3, 1)); @@ -14628,7 +14585,7 @@ BLIS_INLINE err_t bli_dtrsm_small_AutXB_AlXB _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); 
//store(B11[0-3][2]) - xmm5 = _mm256_extractf128_pd(ymm3, 0); + xmm5 = _mm256_castpd256_pd128(ymm3); _mm_storeu_pd((double *)(b11 + cs_b * 3), xmm5); if(transa) @@ -14819,10 +14776,10 @@ BLIS_INLINE err_t bli_dtrsm_small_AutXB_AlXB ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0E); - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm0, 0)); - _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm1, 0)); - _mm_storel_pd((b11 + cs_b * 2), _mm256_extractf128_pd(ymm2, 0)); - _mm_storel_pd((b11 + cs_b * 3), _mm256_extractf128_pd(ymm3, 0)); + _mm_storel_pd((b11 + cs_b * 0), _mm256_castpd256_pd128(ymm0)); + _mm_storel_pd((b11 + cs_b * 1), _mm256_castpd256_pd128(ymm1)); + _mm_storel_pd((b11 + cs_b * 2), _mm256_castpd256_pd128(ymm2)); + _mm_storel_pd((b11 + cs_b * 3), _mm256_castpd256_pd128(ymm3)); if(transa) dtrsm_AutXB_ref(a11, b11, m_rem, 4, cs_a, cs_b, is_unitdiag); @@ -32335,11 +32292,11 @@ BLIS_INLINE err_t bli_ztrsm_small_AutXB_AlXB ymm13 = _mm256_sub_pd(ymm15,ymm13); _mm_storeu_pd((double *)(b11 + 2), - _mm256_extractf128_pd(ymm11,0)); + _mm256_castpd256_pd128(ymm11)); _mm_storeu_pd((double *)(b11 + cs_b * 1 + 2), - _mm256_extractf128_pd(ymm12,0)); + _mm256_castpd256_pd128(ymm12)); _mm_storeu_pd((double *)(b11 + cs_b * 2 + 2), - _mm256_extractf128_pd(ymm13,0)); + _mm256_castpd256_pd128(ymm13)); if(transa) ztrsm_AutXB_ref(a11, b11, m_rem, 3, @@ -32541,35 +32498,33 @@ BLIS_INLINE err_t bli_ztrsm_small_AutXB_AlXB { dim_t p_lda = 2; // packed leading dimension if(transa) - { - dim_t x = 0; - for(x = 0; (x + 1) < i; x += p_lda) - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - _mm_storeu_pd((double *)(ptr_a10_dup), - _mm256_extractf128_pd(ymm0, 0)); - _mm_storeu_pd((double *)(ptr_a10_dup + - p_lda), _mm256_extractf128_pd(ymm0, 1)); - a10 += p_lda; - ptr_a10_dup += p_lda * p_lda; - } - for(; x < i; x += 1) - { - xmm4 = _mm_loadu_pd((double const *)(a10)); - _mm_storeu_pd((double *)(ptr_a10_dup), xmm4); - a10 += 1; - ptr_a10_dup += 1; - } + { + dim_t x = 0; + for(x = 0; (x + 1) < i; x += p_lda) + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + _mm_storeu_pd((double *)(ptr_a10_dup), + _mm256_castpd256_pd128(ymm0)); + _mm_storeu_pd((double *)(ptr_a10_dup + + p_lda), _mm256_extractf128_pd(ymm0, 1)); + a10 += p_lda; + ptr_a10_dup += p_lda * p_lda; + } + for(; x < i; x += 1) + { + xmm4 = _mm_loadu_pd((double const *)(a10)); + _mm_storeu_pd((double *)(ptr_a10_dup), xmm4); + a10 += 1; + ptr_a10_dup += 1; + } - } + } else { for(dim_t x=0;x 0; j -= d_nr) { @@ -33835,11 +33790,11 @@ BLIS_INLINE err_t bli_ztrsm_small_AltXB_AuXB ymm10 = _mm256_sub_pd(ymm15,ymm10); _mm_storeu_pd((double *)(b11), - _mm256_extractf128_pd(ymm8,0)); + _mm256_castpd256_pd128(ymm8)); _mm_storeu_pd((double *)(b11 + cs_b * 1), - _mm256_extractf128_pd(ymm9,0) ); + _mm256_castpd256_pd128(ymm9) ); _mm_storeu_pd((double *)(b11 + cs_b * 2), - _mm256_extractf128_pd(ymm10,0)); + _mm256_castpd256_pd128(ymm10)); if(transa) ztrsm_AltXB_ref(a11, b11, m_remainder, 3, @@ -34405,15 +34360,15 @@ BLIS_INLINE err_t bli_ztrsm_small_XAutB_XAlB #endif _mm256_storeu_pd((double *)b11, ymm3); _mm_storeu_pd((double *)(b11 + 2), - _mm256_extractf128_pd(ymm4,0)); + _mm256_castpd256_pd128(ymm4)); _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); _mm_storeu_pd((double *)(b11 + cs_b + 2), - _mm256_extractf128_pd(ymm6,0)); + _mm256_castpd256_pd128(ymm6)); _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); _mm_storeu_pd((double *)(b11 + cs_b*2 + 2), - _mm256_extractf128_pd(ymm8,0)); + 
_mm256_castpd256_pd128(ymm8)); m_remainder -=3; } else if(2 == m_remainder) @@ -34710,11 +34665,11 @@ BLIS_INLINE err_t bli_ztrsm_small_XAutB_XAlB BLIS_ZTRSM_MUL(ymm3) #endif _mm_storeu_pd((double *)b11, - _mm256_extractf128_pd(ymm3,0)); + _mm256_castpd256_pd128(ymm3)); _mm_storeu_pd((double *)(b11 + cs_b), - _mm256_extractf128_pd(ymm5,0)); + _mm256_castpd256_pd128(ymm5)); _mm_storeu_pd((double *)(b11 + cs_b*2), - _mm256_extractf128_pd(ymm7,0)); + _mm256_castpd256_pd128(ymm7)); m_remainder -=1; } } @@ -34757,11 +34712,11 @@ BLIS_INLINE err_t bli_ztrsm_small_XAutB_XAlB ymm5 = _mm256_permute2f128_pd(ymm1,ymm5,0x20); _mm_storeu_pd((double *)(ptr_a10_dup + 2), - _mm256_extractf128_pd(ymm3,0)); + _mm256_castpd256_pd128(ymm3)); _mm_storeu_pd((double *)(ptr_a10_dup + p_lda + 2), - _mm256_extractf128_pd(ymm4,0)); + _mm256_castpd256_pd128(ymm4)); _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 2 + 2), - _mm256_extractf128_pd(ymm5, 0)); + _mm256_castpd256_pd128(ymm5)); a01 += d_nr*cs_a; ptr_a10_dup += d_nr; } @@ -34977,11 +34932,11 @@ BLIS_INLINE err_t bli_ztrsm_small_XAutB_XAlB #endif _mm256_storeu_pd((double *)b11, ymm3); _mm_storeu_pd((double *)(b11 + 2), - _mm256_extractf128_pd(ymm4,0)); + _mm256_castpd256_pd128(ymm4)); _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); _mm_storeu_pd((double *)(b11 + cs_b + 2), - _mm256_extractf128_pd(ymm6,0)); + _mm256_castpd256_pd128(ymm6)); m_remainder -=3; } if(2 == m_remainder) @@ -35123,9 +35078,9 @@ BLIS_INLINE err_t bli_ztrsm_small_XAutB_XAlB BLIS_ZTRSM_MUL(ymm3) #endif _mm_storeu_pd((double *)b11, - _mm256_extractf128_pd(ymm3,0)); + _mm256_castpd256_pd128(ymm3)); _mm_storeu_pd((double *)(b11 + cs_b), - _mm256_extractf128_pd(ymm5,0)); + _mm256_castpd256_pd128(ymm5)); m_remainder -=1; } n_remainder -= 2; @@ -35167,12 +35122,12 @@ BLIS_INLINE err_t bli_ztrsm_small_XAutB_XAlB ymm5 = _mm256_permute2f128_pd(ymm1,ymm5,0x20); _mm_storeu_pd((double *)(ptr_a10_dup + 2), - _mm256_extractf128_pd(ymm3,0)); + _mm256_castpd256_pd128(ymm3)); _mm_storeu_pd((double *)(ptr_a10_dup + p_lda + 2), - _mm256_extractf128_pd(ymm4,0)); + _mm256_castpd256_pd128(ymm4)); _mm_storeu_pd((double *) (ptr_a10_dup + p_lda * 2 + 2), - _mm256_extractf128_pd(ymm5, 0)); + _mm256_castpd256_pd128(ymm5)); a01 += d_nr*cs_a; ptr_a10_dup += d_nr; } @@ -35283,7 +35238,7 @@ BLIS_INLINE err_t bli_ztrsm_small_XAutB_XAlB _mm256_storeu_pd((double *)b11, ymm3); _mm_storeu_pd((double *)(b11 + 2), - _mm256_extractf128_pd(ymm4,0)); + _mm256_castpd256_pd128(ymm4)); m_remainder -=3; } @@ -35351,7 +35306,7 @@ BLIS_INLINE err_t bli_ztrsm_small_XAutB_XAlB BLIS_ZTRSM_MUL(ymm3) #endif _mm_storeu_pd((double *)b11, - _mm256_extractf128_pd(ymm3,0)); + _mm256_castpd256_pd128(ymm3)); m_remainder -=1; } n_remainder -= 1; @@ -35850,15 +35805,15 @@ BLIS_INLINE err_t bli_ztrsm_small_XAltB_XAuB _mm256_storeu_pd((double *)b11, ymm3); _mm_storeu_pd((double *)(b11 + 2), - _mm256_extractf128_pd(ymm4,0)); + _mm256_castpd256_pd128(ymm4)); _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); _mm_storeu_pd((double *)(b11 + cs_b + 2), - _mm256_extractf128_pd(ymm6,0)); + _mm256_castpd256_pd128(ymm6)); _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); _mm_storeu_pd((double *)(b11 + cs_b*2 + 2), - _mm256_extractf128_pd(ymm8,0)); + _mm256_castpd256_pd128(ymm8)); m_remainder -= 3; i += 3; @@ -36132,11 +36087,11 @@ BLIS_INLINE err_t bli_ztrsm_small_XAltB_XAuB _mm_storeu_pd((double *)b11, - _mm256_extractf128_pd(ymm3,0)); + _mm256_castpd256_pd128(ymm3)); _mm_storeu_pd((double *)(b11 + cs_b), - _mm256_extractf128_pd(ymm5,0)); + 
_mm256_castpd256_pd128(ymm5)); _mm_storeu_pd((double *)(b11 + cs_b*2), - _mm256_extractf128_pd(ymm7,0)); + _mm256_castpd256_pd128(ymm7)); m_remainder -= 1; i += 1; @@ -36184,11 +36139,11 @@ BLIS_INLINE err_t bli_ztrsm_small_XAltB_XAuB ymm5 = _mm256_permute2f128_pd(ymm1,ymm5,0x20); _mm_storeu_pd((double *)(ptr_a10_dup + 2), - _mm256_extractf128_pd(ymm3,0)); + _mm256_castpd256_pd128(ymm3)); _mm_storeu_pd((double *)(ptr_a10_dup + p_lda + 2), - _mm256_extractf128_pd(ymm4,0)); + _mm256_castpd256_pd128(ymm4)); _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 2 + 2), - _mm256_extractf128_pd(ymm5, 0)); + _mm256_castpd256_pd128(ymm5)); a01 += d_nr*cs_a; ptr_a10_dup += d_nr; } @@ -36405,11 +36360,11 @@ BLIS_INLINE err_t bli_ztrsm_small_XAltB_XAuB _mm256_storeu_pd((double *)b11, ymm3); _mm_storeu_pd((double *)(b11 + 2), - _mm256_extractf128_pd(ymm4,0)); + _mm256_castpd256_pd128(ymm4)); _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); _mm_storeu_pd((double *)(b11 + cs_b + 2), - _mm256_extractf128_pd(ymm6,0)); + _mm256_castpd256_pd128(ymm6)); m_remainder -= 3; i += 3; } @@ -36548,9 +36503,9 @@ BLIS_INLINE err_t bli_ztrsm_small_XAltB_XAuB BLIS_ZTRSM_MUL(ymm5) #endif _mm_storeu_pd((double *)b11, - _mm256_extractf128_pd(ymm3,0)); + _mm256_castpd256_pd128(ymm3)); _mm_storeu_pd((double *)(b11 + cs_b), - _mm256_extractf128_pd(ymm5,0)); + _mm256_castpd256_pd128(ymm5)); m_remainder -= 1; i += 1; } @@ -36595,11 +36550,11 @@ BLIS_INLINE err_t bli_ztrsm_small_XAltB_XAuB ymm5 = _mm256_permute2f128_pd(ymm1,ymm5,0x20); _mm_storeu_pd((double *)(ptr_a10_dup + 2), - _mm256_extractf128_pd(ymm3,0)); + _mm256_castpd256_pd128(ymm3)); _mm_storeu_pd((double *)(ptr_a10_dup + p_lda + 2), - _mm256_extractf128_pd(ymm4,0)); + _mm256_castpd256_pd128(ymm4)); _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 2 + 2), - _mm256_extractf128_pd(ymm5, 0)); + _mm256_castpd256_pd128(ymm5)); a01 += d_nr*cs_a; ptr_a10_dup += d_nr; } @@ -36710,7 +36665,7 @@ BLIS_INLINE err_t bli_ztrsm_small_XAltB_XAuB _mm256_storeu_pd((double *)b11, ymm3); _mm_storeu_pd((double *)(b11 + 2), - _mm256_extractf128_pd(ymm4,0)); + _mm256_castpd256_pd128(ymm4)); m_remainder -= 3; i += 3; } @@ -36786,7 +36741,7 @@ BLIS_INLINE err_t bli_ztrsm_small_XAltB_XAuB BLIS_ZTRSM_MUL(ymm3) #endif _mm_storeu_pd((double *)b11, - _mm256_extractf128_pd(ymm3,0)); + _mm256_castpd256_pd128(ymm3)); m_remainder -= 1; i += 1; } From 4bca7f6f4afc4af7157c50a27c6b35f6d7f830e4 Mon Sep 17 00:00:00 2001 From: Shubham Sharma Date: Wed, 27 Jul 2022 05:47:45 -0500 Subject: [PATCH 170/243] DGEMMT optimizations Details: 1. For lower and upper, non-transpose variants of gemmt, new kernels are developed and optimized to compute only the required outputs in the diagonal blocks. 2. In the previous implementation, all the 48 outputs of the given 6x8 block of C matrix are computed and stored into a temporary buffer. Later,the required elements are copied into the final C output buffer. 3. Changes are made to compute only the required outputs of the 6x8 block of C matrix and directly stored in the final C output buffer. 4. With this optimization, we are avoiding copy operation and also reducing the number of computations. 5. Kernels specific to compute Lower and Upper Variant diagonal outputs have been added. 6. SUP Framework changes to integrate the new kernels have been added. 7. These kernels are part of the SUP framework. 
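The dispatch arithmetic behind the points above can be illustrated with a minimal standalone sketch. It uses the MR = 6, NR = 8 and 24x24-tile values from this patch, and the slot numbers mirror the order of the ker_fpls lookup table added in bli_gemmt_sup_var1n2m.c; the helper name lower_ker_slot() and the plain int types are invented for the example, and the combined 12x8 kernel (slot 6) as well as the storage and edge-case checks are omitted.

    #include <stdio.h>

    #define MR   6
    #define NR   8
    #define TILE 24   /* LCM(MR, NR): the diagonal pattern repeats every 24x24 block */

    /* Return the lookup-table slot for the lower-triangular variant, or -1 when
       the block carries no diagonal elements and the default path is taken. */
    static int lower_ker_slot( int m_off, int n_off )
    {
        int m_idx = ( m_off % TILE ) / MR;   /* 0..3 */
        int n_idx = ( n_off % TILE ) / NR;   /* 0..2 */

        if ( m_idx == n_idx )     return   m_idx << 1;          /* 0x0, 6x8, 12x16 */
        if ( m_idx == n_idx + 1 ) return ( n_idx << 1 ) + 1;     /* 6x0, 12x8, 18x16 */
        return -1;
    }

    int main( void )
    {
        for ( int m_off = 0; m_off < TILE; m_off += MR )
            for ( int n_off = 0; n_off < TILE; n_off += NR )
                printf( "block C(%2d:%2d, %2d:%2d) -> slot %d\n",
                        m_off, m_off + MR - 1, n_off, n_off + NR - 1,
                        lower_ker_slot( m_off, n_off ) );
        return 0;
    }

Blocks that return -1 contain no diagonal elements of the current 24x24 tile and continue to go through the regular gemmsup millikernel followed by the triangular update, as before.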
AMD-Internal: [CPUPL-2341] Change-Id: I0ec8f24a0fb19d9b1ef7254732b8e09f06e1486a --- frame/3/gemmt/bli_gemmt_sup_var1n2m.c | 603 +- .../3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c | 7224 +++++++++++++++-- 2 files changed, 7199 insertions(+), 628 deletions(-) diff --git a/frame/3/gemmt/bli_gemmt_sup_var1n2m.c b/frame/3/gemmt/bli_gemmt_sup_var1n2m.c index 382ca6f67d..2af7e9f45f 100644 --- a/frame/3/gemmt/bli_gemmt_sup_var1n2m.c +++ b/frame/3/gemmt/bli_gemmt_sup_var1n2m.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2020 - 21, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 22, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -55,6 +55,288 @@ typedef void (*FUNCPTR_T) rntm_t* restrict rntm, thrinfo_t* restrict thread ); + + +// Declaration of gemmt specific kernels function pointer +// This is aligned to bli_dgemmsup_rv_haswell_asm_6x8m function protype. +typedef void (*gemmt_ker_ft) + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + void* restrict alpha, + void* restrict a, inc_t rs_a0, inc_t cs_a0, + void* restrict b, inc_t rs_b0, inc_t cs_b0, + void* restrict beta, + void* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ); + +// Gemmt Upper variant kernel for m_offset=0 and n_offset=0 in 24x24 block +BLIS_INLINE void bli_dgemmsup_rv_haswell_asm_6x8m_0x0_U + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + void* restrict alpha, + void* restrict a, inc_t rs_a0, inc_t cs_a0, + void* restrict b, inc_t rs_b0, inc_t cs_b0, + void* restrict beta, + void* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ); + +// Gemmt Upper variant kernel for m_offset=6 and n_offset=8 in 24x24 block +BLIS_INLINE void bli_dgemmsup_rv_haswell_asm_6x8m_6x8_U + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + void* restrict alpha, + void* restrict a, inc_t rs_a0, inc_t cs_a0, + void* restrict b, inc_t rs_b0, inc_t cs_b0, + void* restrict beta, + void* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ); + +// Gemmt Upper variant kernel for m_offset=12 and n_offset=16 in 24x24 block +BLIS_INLINE void bli_dgemmsup_rv_haswell_asm_6x8m_12x16_U + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + void* restrict alpha, + void* restrict a, inc_t rs_a0, inc_t cs_a0, + void* restrict b, inc_t rs_b0, inc_t cs_b0, + void* restrict beta, + void* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ); + +// Gemmt Upper variant combined kernel for m_offset=12, n_offset=16 and m_offset=18, n_offset=16 in 24x24 block +BLIS_INLINE void bli_dgemmsup_rv_haswell_asm_6x8m_0x0_combined_U + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + void* restrict alpha, + void* restrict a, inc_t rs_a0, inc_t cs_a0, + void* restrict b, inc_t rs_b0, inc_t cs_b0, + void* restrict beta, + void* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ); + +// Gemmt Upper variant kernel for m_offset=6 and n_offset=0 in 24x24 block +BLIS_INLINE void bli_dgemmsup_rv_haswell_asm_6x8m_6x0_U + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + void* restrict alpha, + void* restrict a, inc_t rs_a0, inc_t 
cs_a0, + void* restrict b, inc_t rs_b0, inc_t cs_b0, + void* restrict beta, + void* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ); + +// Gemmt Upper variant kernel for m_offset=12 and n_offset=8 in 24x24 block +BLIS_INLINE void bli_dgemmsup_rv_haswell_asm_6x8m_12x8_U + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + void* restrict alpha, + void* restrict a, inc_t rs_a0, inc_t cs_a0, + void* restrict b, inc_t rs_b0, inc_t cs_b0, + void* restrict beta, + void* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ); + +// Gemmt Upper variant kernel for m_offset=18 and n_offset=16 in 24x24 block +BLIS_INLINE void bli_dgemmsup_rv_haswell_asm_6x8m_18x16_U + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + void* restrict alpha, + void* restrict a, inc_t rs_a0, inc_t cs_a0, + void* restrict b, inc_t rs_b0, inc_t cs_b0, + void* restrict beta, + void* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ); + +// Gemmt Lower variant kernel for m_offset=0 and n_offset=0 in 24x24 block +BLIS_INLINE void bli_dgemmsup_rv_haswell_asm_6x8m_0x0_L + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + void* restrict alpha, + void* restrict a, inc_t rs_a0, inc_t cs_a0, + void* restrict b, inc_t rs_b0, inc_t cs_b0, + void* restrict beta, + void* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ); + +// Gemmt Lower variant kernel for m_offset=6 and n_offset=8 in 24x24 block +BLIS_INLINE void bli_dgemmsup_rv_haswell_asm_6x8m_6x8_L + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + void* restrict alpha, + void* restrict a, inc_t rs_a0, inc_t cs_a0, + void* restrict b, inc_t rs_b0, inc_t cs_b0, + void* restrict beta, + void* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ); + +// Gemmt Lower variant kernel for m_offset=12 and n_offset=16 in 24x24 block +BLIS_INLINE void bli_dgemmsup_rv_haswell_asm_6x8m_12x16_L + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + void* restrict alpha, + void* restrict a, inc_t rs_a0, inc_t cs_a0, + void* restrict b, inc_t rs_b0, inc_t cs_b0, + void* restrict beta, + void* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ); + +// Gemmt Lower variant combined kernel for m_offset=0, n_offset=0 and m_offset=6, n_offset=0 in 24x24 block +BLIS_INLINE void bli_dgemmsup_rv_haswell_asm_6x8m_16x12_combined_L + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + void* restrict alpha, + void* restrict a, inc_t rs_a0, inc_t cs_a0, + void* restrict b, inc_t rs_b0, inc_t cs_b0, + void* restrict beta, + void* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ); + +// Gemmt Lower variant kernel for m_offset=6 and n_offset=0 in 24x24 block +BLIS_INLINE void bli_dgemmsup_rv_haswell_asm_6x8m_6x0_L + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + void* restrict alpha, + void* restrict a, inc_t rs_a0, inc_t cs_a0, + void* restrict b, inc_t rs_b0, inc_t cs_b0, + void* restrict beta, + void* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ); + +// Gemmt Lower variant kernel for m_offset=12 and n_offset=8 in 24x24 block +BLIS_INLINE void bli_dgemmsup_rv_haswell_asm_6x8m_12x8_L + ( + conj_t conja, + conj_t conjb, + 
dim_t m0, + dim_t n0, + dim_t k0, + void* restrict alpha, + void* restrict a, inc_t rs_a0, inc_t cs_a0, + void* restrict b, inc_t rs_b0, inc_t cs_b0, + void* restrict beta, + void* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ); + +// Gemmt Lower variant kernel for m_offset=18 and n_offset=16 in 24x24 block +BLIS_INLINE void bli_dgemmsup_rv_haswell_asm_6x8m_18x16_L + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + void* restrict alpha, + void* restrict a, inc_t rs_a0, inc_t cs_a0, + void* restrict b, inc_t rs_b0, inc_t cs_b0, + void* restrict beta, + void* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ); + +//Look-up table for Gemmt Upper Variant Kernels +gemmt_ker_ft ker_fpus[7] = +{ + bli_dgemmsup_rv_haswell_asm_6x8m_0x0_U, + bli_dgemmsup_rv_haswell_asm_6x8m_6x0_U, + bli_dgemmsup_rv_haswell_asm_6x8m_6x8_U, + bli_dgemmsup_rv_haswell_asm_6x8m_12x8_U, + bli_dgemmsup_rv_haswell_asm_6x8m_12x16_U, + bli_dgemmsup_rv_haswell_asm_6x8m_18x16_U, + bli_dgemmsup_rv_haswell_asm_6x8m_0x0_combined_U +}; + +//Look-up table for Gemmt Lower Variant Kernels +gemmt_ker_ft ker_fpls[7] = +{ + bli_dgemmsup_rv_haswell_asm_6x8m_0x0_L, + bli_dgemmsup_rv_haswell_asm_6x8m_6x0_L, + bli_dgemmsup_rv_haswell_asm_6x8m_6x8_L, + bli_dgemmsup_rv_haswell_asm_6x8m_12x8_L, + bli_dgemmsup_rv_haswell_asm_6x8m_12x16_L, + bli_dgemmsup_rv_haswell_asm_6x8m_18x16_L, + bli_dgemmsup_rv_haswell_asm_6x8m_16x12_combined_L +}; + // // -- var1n -------------------------------------------------------------------- // @@ -1501,7 +1783,7 @@ void PASTEMACT(ch,opname,uplo,varname) \ \ /* storage-scheme of ct should be same as that of C. Since update routines only support row-major order, - col_pref flag is used to induce transpose to matrices before + col_pref flag is used to induce transpose to matrices before passing to update routine whenever C is col-stored */ \ const bool col_pref = (rs_c == 1)? 1 : 0; \ \ @@ -1833,40 +2115,138 @@ void PASTEMACT(ch,opname,uplo,varname) \ { \ const dim_t mr_cur = (i+MR-1) < mc_cur ? MR : mc_cur - i; \ \ - /* Invoke the gemmsup millikernel. */ \ - gemmsup_ker \ - ( \ - conja, \ - conjb, \ - mr_cur, \ - nr_cur, \ - kc_cur, \ - alpha_cast, \ - a_ir, rs_a_use, cs_a_use, \ - b_jr, rs_b_use, cs_b_use, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ - /* Scale the bottom edge of C and add the result from above. */ \ - /* If c and ct are col-major, induce transpose and call update for upper-triangle of C */ \ - if( col_pref ) \ - { \ - PASTEMAC(ch,update_upper_triang)( n_off_cblock, m_off_cblock, \ - nr_cur, mr_cur, \ - ct, cs_ct, rs_ct, \ + /* Prerequisites : MR = 6, NR = 8. + An optimization: allow the last jr iteration to contain up to NRE + In DGEMMT API implementation, kernel operates on 6x8 block. MR and + NR are set as 6 and 8 respectively. 24 being the LCM of 6 and 8, + the diagonal pattern repeats for every 24x24 block. + This pattern is exploited to achieve the optimization in diagonal + blocks by computing only the required elements. In the previous + implementation, all the 48 outputs of the given 6x8 block are + computed and stored into a temporary buffer. Later, the required + elements are copied into the final C output buffer. + With this optimization, we are avoiding copy operation and also + reducing the number of computations. + Variables m_off_24 and n_off_24 respectively store the m and n + offsets from the starting point of the corresponding 24x24 block. 
+ Variables m_idx and n_idx store indices of the current 6x8 block
+ along m and n dimensions, in the 24x24 block. m_idx is computed as
+ (m_off_24 / MR) while n_idx is computed as (n_off_24 / NR).
+ The range of m_idx is 0 <= m_idx <= 3 and the range of n_idx is
+ 0 <= n_idx <= 2. Based on these indices, for the given 6x8 block,
+ logic is implemented to identify the relevant kernel from the
+ look-up table.
+ When m is not a multiple of 6 or n is not a multiple of 8, the
+ default gemm kernel is used instead. MR and NR must be
+ 6 and 8 for these kernels to achieve the expected functionality.*/ \
+\
+ dim_t m_off_24 = m_off_cblock % 24; \
+ dim_t n_off_24 = n_off_cblock % 24; \
+ dim_t m_idx = (dim_t)(m_off_24 / MR); \
+ dim_t n_idx = (dim_t)(n_off_24 / NR); \
+\
+ /* Optimized kernels are not implemented for the case where B is
+ stored as column major */ \
+ bool storage_supported = (dt == BLIS_DOUBLE) && ( (stor_id == BLIS_RRR) || (stor_id == BLIS_RCR) || (stor_id == BLIS_CRR) ); \
+\
+ /* Check if the m, n indices are multiples of MR and NR respectively
+ and the current block is a complete 6x8 block */ \
+ bool idx_supported = ((m_off_24 % MR) == 0) && ((n_off_24 % NR) == 0) && (mr_cur == MR) && (nr_cur == NR); \
+\
+ /* m_idx and n_idx would be equal only if the current block is
+ a diagonal block */\
+ if( (storage_supported) && (m_idx == n_idx) && (idx_supported) ) { \
+ dim_t ker_idx; \
+ ker_idx = m_idx<<1; \
+\
+ /* If there is another 6x8 diagonal block pending for computation
+ after the current 6x8 diagonal block, then the two blocks can
+ be computed together (12x8). This combined kernel is implemented
+ only for the case where n_idx = 2, i.e., n_off_24 = 16. To call
+ it, at least 12 rows must still be pending for computation in
+ C (m_off_cblock + 2 * MR <= m).
Usage of this combined + kernel saves the entire time to execute one kernel*/ \ + if( (n_idx == 2) && (m_off_cblock + MR + MR <= m) )\ + ker_idx = 6; /* use combined kernel, index of combined kernel + in lookup table is 6 */\ + gemmt_ker_ft ker_fp = ker_fpls[ker_idx]; \ + ker_fp \ + ( \ + conja, \ + conjb, \ + mr_cur, \ + nr_cur, \ + kc_cur, \ + alpha_cast, \ + a_ir, rs_a_use, cs_a_use, \ + b_jr, rs_b_use, cs_b_use, \ beta_use, \ - c_ir, cs_c, rs_c ); \ + c_ir, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ } \ - else \ - { \ - PASTEMAC(ch,update_lower_triang)( m_off_cblock, n_off_cblock, \ - mr_cur, nr_cur, \ - ct, rs_ct, cs_ct, \ - beta_use, \ - c_ir, rs_c, cs_c ); \ + /* 6x8 block where m_idx == n_idx+1 also has some parts of the diagonal */\ + else if( (storage_supported) && (m_idx == n_idx+1) && (idx_supported) ) { \ + dim_t ker_idx = (n_idx << 1) + 1; \ + gemmt_ker_ft ker_fp = ker_fpls[ker_idx]; \ + /* If current block was already computed in the combined kernel it + can be skipped combined kernel is only implemented for n_idx=2, + i == m_zero is only true for the first iteration therefore if + i == m_zero then the current 6x8 block was not computed in + combined kernel*/ \ + if( (n_idx != 2) || (i == m_zero) ) { \ + ker_fp \ + ( \ + conja, \ + conjb, \ + mr_cur, \ + nr_cur, \ + kc_cur, \ + alpha_cast, \ + a_ir, rs_a_use, cs_a_use, \ + b_jr, rs_b_use, cs_b_use, \ + beta_use, \ + c_ir, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ + } \ } \ + /* Call the regular kernel for non applicable cases */ \ + else { \ + gemmsup_ker \ + ( \ + conja, \ + conjb, \ + mr_cur, \ + nr_cur, \ + kc_cur, \ + alpha_cast, \ + a_ir, rs_a_use, cs_a_use, \ + b_jr, rs_b_use, cs_b_use, \ + zero, \ + ct, rs_ct, cs_ct, \ + &aux, \ + cntx \ + ); \ + if( col_pref ) \ + { \ + PASTEMAC(ch,update_upper_triang)( n_off_cblock, m_off_cblock, \ + nr_cur, mr_cur, \ + ct, cs_ct, rs_ct, \ + beta_use, \ + c_ir, cs_c, rs_c ); \ + } \ + else \ + { \ + PASTEMAC(ch,update_lower_triang)( m_off_cblock, n_off_cblock, \ + mr_cur, nr_cur, \ + ct, rs_ct, cs_ct, \ + beta_use, \ + c_ir, rs_c, cs_c ); \ + }\ + }\ \ a_ir += ps_a_use; \ c_ir += irstep_c; \ @@ -2410,39 +2790,136 @@ void PASTEMACT(ch,opname,uplo,varname) \ { \ const dim_t mr_cur = (i+MR-1) < mc_cur ? MR : mc_cur - i; \ \ - /* Invoke the gemmsup millikernel. */ \ - gemmsup_ker \ - ( \ - conja, \ - conjb, \ - mr_cur, \ - nr_cur, \ - kc_cur, \ - alpha_cast, \ - a_ir, rs_a_use, cs_a_use, \ - b_jr, rs_b_use, cs_b_use, \ - zero, \ - ct, rs_ct, cs_ct, \ - &aux, \ - cntx \ - ); \ -\ - if( col_pref ) \ - { \ - PASTEMAC(ch,update_lower_triang)( n_off_cblock, m_off_cblock, \ - nr_cur, mr_cur, \ - ct, cs_ct, rs_ct, \ - beta_use, \ - c_ir, cs_c, rs_c ); \ + /* Prerequisites : MR = 6, NR = 8. + An optimization: allow the last jr iteration to contain up to NRE + In DGEMMT API implementation, kernel operates on 6x8 block. MR and + NR are set as 6 and 8 respectively. 24 being the LCM of 6 and 8, + the diagonal pattern repeats for every 24x24 block. + This pattern is exploited to achieve the optimization in diagonal + blocks by computing only the required elements. In the previous + implementation, all the 48 outputs of the given 6x8 block are + computed and stored into a temporary buffer. Later, the required + elements are copied into the final C output buffer. + With this optimization, we are avoiding copy operation and also + reducing the number of computations. + Variables m_off_24 and n_off_24 respectively store the m and n + offsets from the starting point of the corresponding 24x24 block. 
+ Variables m_idx and n_idx store indices of the current 6x8 block
+ along m and n dimensions, in the 24x24 block. m_idx is computed as
+ (m_off_24 / MR) while n_idx is computed as (n_off_24 / NR).
+ The range of m_idx is 0 <= m_idx <= 3 and the range of n_idx is
+ 0 <= n_idx <= 2. Based on these indices, for the given 6x8 block,
+ logic is implemented to identify the relevant kernel from the
+ look-up table.
+ When m is not a multiple of 6 or n is not a multiple of 8, the
+ default gemm kernel is used instead. MR and NR must be
+ 6 and 8 for these kernels to achieve the expected functionality.*/ \
+ dim_t m_off_24 = m_off_cblock % 24; \
+ dim_t n_off_24 = n_off_cblock % 24; \
+ dim_t m_idx = (dim_t)(m_off_24 / MR); \
+ dim_t n_idx = (dim_t)(n_off_24 / NR); \
+\
+ /* Optimized kernels are not implemented for the case where B is
+ stored as column major */ \
+ bool storage_supported = (dt == BLIS_DOUBLE) && ( (stor_id == BLIS_RRR) || (stor_id == BLIS_RCR) || (stor_id == BLIS_CRR) ); \
+\
+ /* Check if the m, n indices are multiples of MR and NR respectively
+ and the current block is a complete 6x8 block */ \
+ bool idx_supported = ((m_off_24 % MR) == 0) && ((n_off_24 % NR) == 0) && (mr_cur==MR) && (nr_cur==NR); \
+\
+ /* m_idx and n_idx would be equal only if the current block is
+ a diagonal block */\
+ if( (storage_supported) && (m_idx == n_idx) && (idx_supported) ) { \
+ m_idx = m_idx<<1; \
+ /* If there is another 6x8 diagonal block pending for computation
+ after the current 6x8 diagonal block, then the two blocks can
+ be computed together (12x8). This combined kernel is implemented
+ only for the case where n_idx = 0, i.e., n_off_24 = 0. To call
+ it, at least 12 rows must still be pending for computation in
+ C (i + MR + MR <= mc_cur).
Usage of this combined + kernel saves the entire time to execute one kernel*/ \ + if( (n_idx == 0) && (i+ MR + MR <= mc_cur) ) \ + m_idx = 6; /* use combined kernel, index of combined kernel + in lookup table is 6 */\ + gemmt_ker_ft ker_fp = ker_fpus[m_idx]; \ + ker_fp \ + ( \ + conja, \ + conjb, \ + mr_cur, \ + nr_cur, \ + kc_cur, \ + alpha_cast, \ + a_ir, rs_a_use, cs_a_use, \ + b_jr, rs_b_use, cs_b_use, \ + beta_use, \ + c_ir, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ } \ - else \ - { \ - PASTEMAC(ch,update_upper_triang)( m_off_cblock, n_off_cblock, \ - mr_cur, nr_cur, \ - ct, rs_ct, cs_ct, \ + /* 6x8 block where m_idx == n_idx+1 also has some parts of the diagonal */\ + else if( (storage_supported) && (m_idx == n_idx+1) && (idx_supported) ) { \ + gemmt_ker_ft ker_fp = ker_fpus[(n_idx << 1) + 1]; \ + /* If current block was already computed in the combined kernel it + can be skipped combined kernel is only implemented for n_idx=0, + i == m_rect is only true for the first iteration therefore if + i == m_rect then the current 6x8 block was not computed in + combined kernel*/ \ + if( (n_idx != 0) || (i == m_rect) ) { \ + ker_fp \ + ( \ + conja, \ + conjb, \ + mr_cur, \ + nr_cur, \ + kc_cur, \ + alpha_cast, \ + a_ir, rs_a_use, cs_a_use, \ + b_jr, rs_b_use, cs_b_use, \ beta_use, \ - c_ir, rs_c, cs_c ); \ + c_ir, rs_c, cs_c, \ + &aux, \ + cntx \ + ); \ + } \ } \ + /* call the regular kernel for non applicable cases */ \ + else { \ + gemmsup_ker \ + ( \ + conja, \ + conjb, \ + mr_cur, \ + nr_cur, \ + kc_cur, \ + alpha_cast, \ + a_ir, rs_a_use, cs_a_use, \ + b_jr, rs_b_use, cs_b_use, \ + zero, \ + ct, rs_ct, cs_ct, \ + &aux, \ + cntx \ + ); \ + \ + if( col_pref ) \ + { \ + PASTEMAC(ch,update_lower_triang)( n_off_cblock, m_off_cblock, \ + nr_cur, mr_cur, \ + ct, cs_ct, rs_ct, \ + beta_use, \ + c_ir, cs_c, rs_c ); \ + } \ + else \ + { \ + PASTEMAC(ch,update_upper_triang)( m_off_cblock, n_off_cblock, \ + mr_cur, nr_cur, \ + ct, rs_ct, cs_ct, \ + beta_use, \ + c_ir, rs_c, cs_c ); \ + } \ + } \ +\ a_ir += ps_a_use; \ c_ir += irstep_c; \ m_off_cblock += mr_cur; \ diff --git a/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c b/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c index 56227ec4dc..21c394f558 100644 --- a/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c +++ b/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 22, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -40,20 +40,20 @@ /* rrr: - -------- ------ -------- - -------- ------ -------- - -------- += ------ ... -------- - -------- ------ -------- - -------- ------ : - -------- ------ : + -------- ------ -------- + -------- ------ -------- + -------- += ------ ... -------- + -------- ------ -------- + -------- ------ : + -------- ------ : rcr: - -------- | | | | -------- - -------- | | | | -------- - -------- += | | | | ... -------- - -------- | | | | -------- - -------- | | | | : - -------- | | | | : + -------- | | | | -------- + -------- | | | | -------- + -------- += | | | | ... -------- + -------- | | | | -------- + -------- | | | | : + -------- | | | | : Assumptions: - B is row-stored; @@ -69,12 +69,12 @@ cost of the in-register transpose). 
crr: - | | | | | | | | ------ -------- - | | | | | | | | ------ -------- - | | | | | | | | += ------ ... -------- - | | | | | | | | ------ -------- - | | | | | | | | ------ : - | | | | | | | | ------ : + | | | | | | | | ------ -------- + | | | | | | | | ------ -------- + | | | | | | | | += ------ ... -------- + | | | | | | | | ------ -------- + | | | | | | | | ------ : + | | | | | | | | ------ : */ // Prototype reference microkernels. @@ -226,15 +226,15 @@ void bli_dgemmsup_rv_haswell_asm_6x8m // ------------------------------------------------------------------------- begin_asm() - + //vzeroall() // zero all xmm/ymm registers. - + mov(var(a), r14) // load address of a. mov(var(rs_a), r8) // load rs_a mov(var(cs_a), r9) // load cs_a lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) - + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a @@ -337,19 +337,19 @@ void bli_dgemmsup_rv_haswell_asm_6x8m lea(mem(rdx, r8, 2), rdx) // from next upanel of a. lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; #endif - - - - + + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. - - + + label(.DLOOPKITER) // MAIN LOOP - - + + // ---------------------------------- iteration 0 #if 0 @@ -357,7 +357,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8m #else prefetch(0, mem(rdx, 5*8)) #endif - + vmovupd(mem(rbx, 0*32), ymm0) vmovupd(mem(rbx, 1*32), ymm1) add(r10, rbx) // b += rs_b; @@ -368,14 +368,14 @@ void bli_dgemmsup_rv_haswell_asm_6x8m vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, r8, 4), ymm2) vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; @@ -384,7 +384,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8m vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + // ---------------------------------- iteration 1 #if 0 @@ -403,14 +403,14 @@ void bli_dgemmsup_rv_haswell_asm_6x8m vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, r8, 4), ymm2) vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; @@ -418,8 +418,8 @@ void bli_dgemmsup_rv_haswell_asm_6x8m vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - - + + // ---------------------------------- iteration 2 #if 0 @@ -427,7 +427,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8m #else prefetch(0, mem(rdx, r9, 2, 5*8)) #endif - + vmovupd(mem(rbx, 0*32), ymm0) vmovupd(mem(rbx, 1*32), ymm1) add(r10, rbx) // b += rs_b; @@ -438,14 +438,14 @@ void bli_dgemmsup_rv_haswell_asm_6x8m vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, r8, 4), ymm2) vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; @@ -453,7 +453,7 @@ void 
bli_dgemmsup_rv_haswell_asm_6x8m vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + // ---------------------------------- iteration 3 @@ -474,14 +474,14 @@ void bli_dgemmsup_rv_haswell_asm_6x8m vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, r8, 4), ymm2) vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; @@ -489,50 +489,50 @@ void bli_dgemmsup_rv_haswell_asm_6x8m vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - - - + + + dec(rsi) // i -= 1; jne(.DLOOPKITER) // iterate again if i != 0. - - - - - - + + + + + + label(.DCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.DPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - - + + label(.DLOOPKLEFT) // EDGE LOOP #if 1 prefetch(0, mem(rdx, 5*8)) add(r9, rdx) #endif - + vmovupd(mem(rbx, 0*32), ymm0) vmovupd(mem(rbx, 1*32), ymm1) add(r10, rbx) // b += rs_b; - + vbroadcastsd(mem(rax ), ymm2) vbroadcastsd(mem(rax, r8, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, r8, 4), ymm2) vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; @@ -540,23 +540,23 @@ void bli_dgemmsup_rv_haswell_asm_6x8m vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - - + + dec(rsi) // i -= 1; jne(.DLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.DPOSTACCUM) - - + + mov(r12, rcx) // reset rcx to current utile of c. mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate - + vmulpd(ymm0, ymm4, ymm4) // scale by alpha vmulpd(ymm0, ymm5, ymm5) vmulpd(ymm0, ymm6, ymm6) @@ -569,24 +569,24 @@ void bli_dgemmsup_rv_haswell_asm_6x8m vmulpd(ymm0, ymm13, ymm13) vmulpd(ymm0, ymm14, ymm14) vmulpd(ymm0, ymm15, ymm15) - - - - - - + + + + + + mov(var(cs_c), rsi) // load cs_c lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) - + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; - - - + + + // now avoid loading C if beta == 0 - + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm3) // set ZF if beta == 0. je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case @@ -595,60 +595,60 @@ void bli_dgemmsup_rv_haswell_asm_6x8m cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
jz(.DCOLSTORED) // jump to column storage case - - + + label(.DROWSTORED) - - + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) vmovupd(ymm4, mem(rcx, 0*32)) vfmadd231pd(mem(rcx, 1*32), ymm3, ymm5) vmovupd(ymm5, mem(rcx, 1*32)) add(rdi, rcx) - - + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) vmovupd(ymm6, mem(rcx, 0*32)) vfmadd231pd(mem(rcx, 1*32), ymm3, ymm7) vmovupd(ymm7, mem(rcx, 1*32)) add(rdi, rcx) - - + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) vmovupd(ymm8, mem(rcx, 0*32)) vfmadd231pd(mem(rcx, 1*32), ymm3, ymm9) vmovupd(ymm9, mem(rcx, 1*32)) add(rdi, rcx) - - + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10) vmovupd(ymm10, mem(rcx, 0*32)) vfmadd231pd(mem(rcx, 1*32), ymm3, ymm11) vmovupd(ymm11, mem(rcx, 1*32)) add(rdi, rcx) - - + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm12) vmovupd(ymm12, mem(rcx, 0*32)) vfmadd231pd(mem(rcx, 1*32), ymm3, ymm13) vmovupd(ymm13, mem(rcx, 1*32)) add(rdi, rcx) - - + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm14) vmovupd(ymm14, mem(rcx, 0*32)) vfmadd231pd(mem(rcx, 1*32), ymm3, ymm15) vmovupd(ymm15, mem(rcx, 1*32)) //add(rdi, rcx) - - + + jmp(.DDONE) // jump to end. @@ -735,51 +735,51 @@ void bli_dgemmsup_rv_haswell_asm_6x8m jmp(.DDONE) // jump to end. - - - - + + + + label(.DBETAZERO) - + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. jz(.DCOLSTORBZ) // jump to column storage case - + label(.DROWSTORBZ) - - + + vmovupd(ymm4, mem(rcx, 0*32)) vmovupd(ymm5, mem(rcx, 1*32)) add(rdi, rcx) - + vmovupd(ymm6, mem(rcx, 0*32)) vmovupd(ymm7, mem(rcx, 1*32)) add(rdi, rcx) - - + + vmovupd(ymm8, mem(rcx, 0*32)) vmovupd(ymm9, mem(rcx, 1*32)) add(rdi, rcx) - - + + vmovupd(ymm10, mem(rcx, 0*32)) vmovupd(ymm11, mem(rcx, 1*32)) add(rdi, rcx) - - + + vmovupd(ymm12, mem(rcx, 0*32)) vmovupd(ymm13, mem(rcx, 1*32)) add(rdi, rcx) - - + + vmovupd(ymm14, mem(rcx, 0*32)) vmovupd(ymm15, mem(rcx, 1*32)) //add(rdi, rcx) - - + + jmp(.DDONE) // jump to end. 
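For reference, the kernel-dispatch logic added in the gemmt variant macros earlier in this patch can be summarized by the sketch below. It is illustrative only and not part of the patch: pick_lower_kernel is a hypothetical helper, and it assumes MR = 6, NR = 8, the gemmt_ker_ft type and ker_fpls table declared above, and that the storage_supported and idx_supported checks have already passed. For example, a block starting at (m_off_cblock, n_off_cblock) = (36, 8) gives m_off_24 = 12 and n_off_24 = 8, hence m_idx = 2 and n_idx = 1, selecting ker_fpls[3], i.e. bli_dgemmsup_rv_haswell_asm_6x8m_12x8_L.

    /* Illustrative sketch (not part of the patch): how a complete, row-stored
       6x8 block maps onto the lower-variant look-up table. The combined 12x8
       kernel (index 6) and the fallback to the regular gemmsup kernel are
       handled by the macro code and are omitted here. */
    static gemmt_ker_ft pick_lower_kernel( dim_t m_off_cblock, dim_t n_off_cblock )
    {
        dim_t m_idx = ( m_off_cblock % 24 ) / 6;  /* MR = 6, so 0 <= m_idx <= 3 */
        dim_t n_idx = ( n_off_cblock % 24 ) / 8;  /* NR = 8, so 0 <= n_idx <= 2 */

        if ( m_idx == n_idx )     /* diagonal 6x8 block */
            return ker_fpls[ m_idx << 1 ];
        if ( m_idx == n_idx + 1 ) /* 6x8 block straddling the diagonal */
            return ker_fpls[ ( n_idx << 1 ) + 1 ];
        return NULL;              /* caller falls back to the regular kernel */
    }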
@@ -844,9 +844,9 @@ void bli_dgemmsup_rv_haswell_asm_6x8m //lea(mem(rdx, rsi, 4), rdx) - - - + + + label(.DDONE) @@ -867,8 +867,8 @@ void bli_dgemmsup_rv_haswell_asm_6x8m label(.DRETURN) - - + + end_asm( : // output operands (none) @@ -985,7 +985,143 @@ void bli_dgemmsup_rv_haswell_asm_6x8m AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); } -void bli_dgemmsup_rv_haswell_asm_6x6m +/* +24x24 block + + 1 1 1 1 1 1 1 1 1 1 2 2 2 2 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 + |- - - - - - - -|- - - - - - - -| - - - - - - - -| +0 | | | | +1 | m_off_24 = 0 | | | +2 | n_off_24 = 0 | | | +3 | m_idx = 0 | | | +4 | n_idx = 0 | | | +5 |- - - - - - - -|- - - - - - - -|- - - - - - - - | +6 | | | | +7 | m_off_24 = 6 | m_off_24 = 6 | | +8 | n_off_24 = 0 | n_off_24 = 8 | | +9 | m_idx = 1 | m_idx = 1 | | +10 | n_idx = 0 | n_idx = 1 | | +11 |- - - - - - - -|- - - - - - - -|- - - - - - - - | +12 | | | | +13 | | m_off_24 = 12 | m_off_24 = 12 | +14 | | n_off_24 = 8 | n_off_24 = 16 | +15 | | m_idx = 2 | m_idx = 2 | +16 | | n_idx = 1 | n_idx = 2 | +17 |- - - - - - - -|- - - - - - - -|- - - - - - - - | +18 | | | | +19 | | | m_off_24 = 18 | +20 | | | n_off_24 = 16 | +21 | | | m_idx = 3 | +22 | | | n_idx = 2 | +23 |- - - - - - - -|- - - - - - - -|- - - - - - - - | +*/ + +#define PREFETCH_C() \ +\ + cmp(imm(8), rdi) \ + jz(.DCOLPFETCH) \ + label(.DROWPFETCH) \ + \ + lea(mem(r12, rdi, 2), rdx) \ + lea(mem(rdx, rdi, 1), rdx) \ + prefetch(0, mem(r12, 7*8)) \ + prefetch(0, mem(r12, rdi, 1, 7*8)) \ + prefetch(0, mem(r12, rdi, 2, 7*8)) \ + prefetch(0, mem(rdx, 7*8)) \ + prefetch(0, mem(rdx, rdi, 1, 7*8)) \ + prefetch(0, mem(rdx, rdi, 2, 7*8)) \ + \ + jmp(.DPOSTPFETCH) \ + label(.DCOLPFETCH) \ + \ + mov(var(cs_c), rsi) \ + lea(mem(, rsi, 8), rsi) \ + lea(mem(r12, rsi, 2), rdx) \ + lea(mem(rdx, rsi, 1), rdx) \ + prefetch(0, mem(r12, 5*8)) \ + prefetch(0, mem(r12, rsi, 1, 5*8)) \ + prefetch(0, mem(r12, rsi, 2, 5*8)) \ + prefetch(0, mem(rdx, 5*8)) \ + prefetch(0, mem(rdx, rsi, 1, 5*8)) \ + prefetch(0, mem(rdx, rsi, 2, 5*8)) \ + lea(mem(rdx, rsi, 2), rdx) \ + prefetch(0, mem(rdx, rsi, 1, 5*8)) \ + prefetch(0, mem(rdx, rsi, 2, 5*8)) \ + +#define SUBITER4x4(a, b, r1, r2, r3, r4) \ +\ + vmovupd(mem(b, 0*32), ymm0) \ + \ + vbroadcastsd(mem(a ), ymm2) \ + vbroadcastsd(mem(a, r8, 1), ymm3) \ + vfmadd231pd(ymm0, ymm2, r1) \ + vfmadd231pd(ymm0, ymm3, r2) \ + \ + vbroadcastsd(mem(a, r8, 2), ymm2) \ + vbroadcastsd(mem(a, r13, 1), ymm3) \ + vfmadd231pd(ymm0, ymm2, r3) \ + vfmadd231pd(ymm0, ymm3, r4) \ + +#define SUBITER2x4(a, b, r1, r2) \ +\ + vmovupd(mem(b, 0*32), ymm0) \ + \ + vbroadcastsd(mem(a ), ymm2) \ + vbroadcastsd(mem(a, r8, 1), ymm3) \ + vfmadd231pd(ymm0, ymm2, r1) \ + vfmadd231pd(ymm0, ymm3, r2) \ + +#define SUBITER2x2(a, b, r1, r2) \ +\ + vmovupd(mem(b, 0*32), xmm0) \ + \ + vbroadcastsd(mem(a ), ymm2) \ + vbroadcastsd(mem(a, r8, 1), ymm3) \ + vfmadd231pd(xmm0, xmm2, r1) \ + vfmadd231pd(xmm0, xmm3, r2) \ + +#define SUBITER6x4(a, b, r1, r2, r3, r4, r5, r6) \ +\ + vmovupd(mem(b, 0*32), ymm0) \ + \ + vbroadcastsd(mem(a ), ymm2) \ + vbroadcastsd(mem(a, r8, 1), ymm3) \ + vfmadd231pd(ymm0, ymm2, r1) \ + vfmadd231pd(ymm0, ymm3, r2) \ + \ + vbroadcastsd(mem(a, r8, 2), ymm2) \ + vbroadcastsd(mem(a, r13, 1), ymm3) \ + vfmadd231pd(ymm0, ymm2, r3) \ + vfmadd231pd(ymm0, ymm3, r4) \ + \ + vbroadcastsd(mem(a, r8, 4), ymm2) \ + vbroadcastsd(mem(a, r15, 1), ymm3) \ + vfmadd231pd(ymm0, ymm2, r5) \ + vfmadd231pd(ymm0, ymm3, r6) \ +/* + +Following kernel computes the 6x8 block for the Lower vairant(L) of gemmt where +m_offset in 24x24 block is 
0 and n_offset is 0(0x0) +(0x0)_L + +the region marked with 'x' is computed by following kernel +the region marked with '-' is not computed + + <-- n_off_24 -- > + 0 1 2 3 4 5 6 7 + +↑ 0 x - - - - - - - +| 1 x x - - - - - - +m 2 x x x - - - - - +off 3 x x x x - - - - +24 4 x x x x x - - - +| 5 x x x x x x - - +↓ + + +*/ +void bli_dgemmsup_rv_haswell_asm_6x8m_0x0_L ( conj_t conja, conj_t conjb, @@ -1002,17 +1138,12 @@ void bli_dgemmsup_rv_haswell_asm_6x6m ) { AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); - //void* a_next = bli_auxinfo_next_a( data ); - //void* b_next = bli_auxinfo_next_b( data ); // Typecast local copies of integers in case dim_t and inc_t are a // different size than is expected by load instructions. uint64_t k_iter = k0 / 4; uint64_t k_left = k0 % 4; - uint64_t m_iter = m0 / 6; - uint64_t m_left = m0 % 6; - uint64_t rs_a = rs_a0; uint64_t cs_a = cs_a0; uint64_t rs_b = rs_b0; @@ -1021,31 +1152,24 @@ void bli_dgemmsup_rv_haswell_asm_6x6m uint64_t cs_c = cs_c0; // Query the panel stride of A and convert it to units of bytes. - uint64_t ps_a = bli_auxinfo_ps_a( data ); - uint64_t ps_a8 = ps_a * sizeof( double ); + uint64_t ps_a8 = bli_auxinfo_ps_a( data ) * sizeof( double ); - if ( m_iter == 0 ) goto consider_edge_cases; // ------------------------------------------------------------------------- begin_asm() - - //vzeroall() // zero all xmm/ymm registers. - + mov(var(a), r14) // load address of a. mov(var(rs_a), r8) // load rs_a mov(var(cs_a), r9) // load cs_a lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) - + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a - //mov(var(b), rbx) // load address of b. mov(var(rs_b), r10) // load rs_b - //mov(var(cs_b), r11) // load cs_b lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) - //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) // NOTE: We cannot pre-load elements of a or b // because it could eventually, in the last @@ -1057,181 +1181,6151 @@ void bli_dgemmsup_rv_haswell_asm_6x6m mov(var(rs_c), rdi) // load rs_c lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) - - // During preamble and loops: - // r12 = rcx = c - // r14 = rax = a - // read rbx from var(b) near beginning of loop - // r11 = m dim index ii - - mov(var(m_iter), r11) // ii = m_iter; - - label(.DLOOP6X8I) // LOOP OVER ii = [ m_iter ... 1 0 ] - - - -#if 0 - vzeroall() // zero all xmm/ymm registers. -#else + //for triangular kernels we can skip 1st loop around micro kernel // skylake can execute 3 vxorpd ipc with // a latency of 1 cycle, while vzeroall // has a latency of 12 cycles. - vxorpd(ymm1, ymm1, ymm1) // zero ymm1 since we only use the lower - vmovapd(ymm1, ymm4) // half (xmm1), and nans/infs may slow us - vmovapd(ymm1, ymm5) // down. - vmovapd(ymm1, ymm6) - vmovapd(ymm1, ymm7) - vmovapd(ymm1, ymm8) - vmovapd(ymm1, ymm9) - vmovapd(ymm1, ymm10) - vmovapd(ymm1, ymm11) - vmovapd(ymm1, ymm12) - vmovapd(ymm1, ymm13) - vmovapd(ymm1, ymm14) - vmovapd(ymm1, ymm15) -#endif + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) mov(var(b), rbx) // load address of b. - //mov(r12, rcx) // reset rcx to current utile of c. mov(r14, rax) // reset rax to current upanel of a. - - - cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
- jz(.DCOLPFETCH) // jump to column storage case - label(.DROWPFETCH) // row-stored prefetching on c - - lea(mem(r12, rdi, 2), rdx) // - lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; - prefetch(0, mem(r12, 5*8)) // prefetch c + 0*rs_c - prefetch(0, mem(r12, rdi, 1, 5*8)) // prefetch c + 1*rs_c - prefetch(0, mem(r12, rdi, 2, 5*8)) // prefetch c + 2*rs_c - prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*rs_c - prefetch(0, mem(rdx, rdi, 1, 5*8)) // prefetch c + 4*rs_c - prefetch(0, mem(rdx, rdi, 2, 5*8)) // prefetch c + 5*rs_c - - jmp(.DPOSTPFETCH) // jump to end of prefetching c - label(.DCOLPFETCH) // column-stored prefetching c - - mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) - lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) - lea(mem(r12, rsi, 2), rdx) // - lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; - prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c - prefetch(0, mem(r12, rsi, 1, 5*8)) // prefetch c + 1*cs_c - prefetch(0, mem(r12, rsi, 2, 5*8)) // prefetch c + 2*cs_c - prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*cs_c - prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 4*cs_c - prefetch(0, mem(rdx, rsi, 2, 5*8)) // prefetch c + 5*cs_c + PREFETCH_C() + lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; label(.DPOSTPFETCH) // done prefetching c -#if 1 mov(var(ps_a8), rdx) // load ps_a8 lea(mem(rax, rdx, 1), rdx) // rdx = a + ps_a8 lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; // use rcx, rdx for prefetching lines // from next upanel of a. -#else - lea(mem(rax, r8, 4), rdx) // use rdx for prefetching lines - lea(mem(rdx, r8, 2), rdx) // from next upanel of a. - lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; -#endif - - - - + + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. 
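+ // Accumulator layout for this triangular tile: ymm4, ymm6, ymm8, ymm10,
+ // ymm12 and ymm14 hold columns 0-3 of rows 0-5, while xmm13 and xmm15
+ // hold columns 4-5 of rows 4 and 5; the remaining accumulators of the
+ // full 6x8 kernel (ymm5, ymm7, ymm9, ymm11) are never updated here.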
- - + + // skip computation of ymm5, ymm7, ymm9, ymm11 and compute only half of ymm4, ymm6, ymm13, ymm15 label(.DLOOPKITER) // MAIN LOOP - - + + // ---------------------------------- iteration 0 -#if 0 prefetch(0, mem(rdx, 5*8)) -#else - prefetch(0, mem(rdx, 5*8)) -#endif - - vmovupd(mem(rbx, 0*32), ymm0) - vmovupd(mem(rbx, 1*32), xmm1) - add(r10, rbx) // b += rs_b; - vbroadcastsd(mem(rax ), ymm2) - vbroadcastsd(mem(rax, r8, 1), ymm3) - vfmadd231pd(ymm0, ymm2, ymm4) - vfmadd231pd(ymm1, ymm2, ymm5) - vfmadd231pd(ymm0, ymm3, ymm6) - vfmadd231pd(ymm1, ymm3, ymm7) - - vbroadcastsd(mem(rax, r8, 2), ymm2) - vbroadcastsd(mem(rax, r13, 1), ymm3) - vfmadd231pd(ymm0, ymm2, ymm8) - vfmadd231pd(ymm1, ymm2, ymm9) - vfmadd231pd(ymm0, ymm3, ymm10) - vfmadd231pd(ymm1, ymm3, ymm11) - - vbroadcastsd(mem(rax, r8, 4), ymm2) - vbroadcastsd(mem(rax, r15, 1), ymm3) + SUBITER6x4(rax, rbx, ymm4, ymm6, ymm8, ymm10, ymm12, ymm14) + lea(mem(rax, r8, 4), rbp) + lea(mem(rbx, 1*32), rcx) + SUBITER2x2(rbp, rcx, xmm13, xmm15) + add(r10, rbx) // b += rs_b; add(r9, rax) // a += cs_a; - vfmadd231pd(ymm0, ymm2, ymm12) - vfmadd231pd(ymm1, ymm2, ymm13) - vfmadd231pd(ymm0, ymm3, ymm14) - vfmadd231pd(ymm1, ymm3, ymm15) - + // ---------------------------------- iteration 1 -#if 0 - prefetch(0, mem(rdx, 5*8)) -#else prefetch(0, mem(rdx, r9, 1, 5*8)) -#endif - vmovupd(mem(rbx, 0*32), ymm0) - vmovupd(mem(rbx, 1*32), xmm1) + SUBITER6x4(rax, rbx, ymm4, ymm6, ymm8, ymm10, ymm12, ymm14) + lea(mem(rax, r8, 4), rbp) + lea(mem(rbx, 1*32), rcx) + SUBITER2x2(rbp, rcx, xmm13, xmm15) add(r10, rbx) // b += rs_b; - - vbroadcastsd(mem(rax ), ymm2) - vbroadcastsd(mem(rax, r8, 1), ymm3) - vfmadd231pd(ymm0, ymm2, ymm4) - vfmadd231pd(ymm1, ymm2, ymm5) - vfmadd231pd(ymm0, ymm3, ymm6) - vfmadd231pd(ymm1, ymm3, ymm7) - - vbroadcastsd(mem(rax, r8, 2), ymm2) - vbroadcastsd(mem(rax, r13, 1), ymm3) - vfmadd231pd(ymm0, ymm2, ymm8) - vfmadd231pd(ymm1, ymm2, ymm9) - vfmadd231pd(ymm0, ymm3, ymm10) - vfmadd231pd(ymm1, ymm3, ymm11) - - vbroadcastsd(mem(rax, r8, 4), ymm2) - vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; - vfmadd231pd(ymm0, ymm2, ymm12) - vfmadd231pd(ymm1, ymm2, ymm13) - vfmadd231pd(ymm0, ymm3, ymm14) - vfmadd231pd(ymm1, ymm3, ymm15) - - + + // ---------------------------------- iteration 2 -#if 0 - prefetch(0, mem(rdx, 5*8)) -#else prefetch(0, mem(rdx, r9, 2, 5*8)) -#endif - - vmovupd(mem(rbx, 0*32), ymm0) - vmovupd(mem(rbx, 1*32), xmm1) - add(r10, rbx) // b += rs_b; + + SUBITER6x4(rax, rbx, ymm4, ymm6, ymm8, ymm10, ymm12, ymm14) + lea(mem(rax, r8, 4), rbp) + lea(mem(rbx, 1*32), rcx) + SUBITER2x2(rbp, rcx, xmm13, xmm15) + add(r10, rbx) // b += rs_b; + add(r9, rax) // a += cs_a; + + + // ---------------------------------- iteration 3 + + prefetch(0, mem(rdx, rcx, 1, 5*8)) + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; + + SUBITER6x4(rax, rbx, ymm4, ymm6, ymm8, ymm10, ymm12, ymm14) + lea(mem(rax, r8, 4), rbp) + lea(mem(rbx, 1*32), rcx) + SUBITER2x2(rbp, rcx, xmm13, xmm15) + add(r10, rbx) // b += rs_b; + add(r9, rax) // a += cs_a; + + + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. 
+ + + label(.DLOOPKLEFT) // EDGE LOOP + + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) + + SUBITER6x4(rax, rbx, ymm4, ymm6, ymm8, ymm10, ymm12, ymm14) + lea(mem(rax, r8, 4), rbp) + lea(mem(rbx, 1*32), rcx) + SUBITER2x2(rbp, rcx, xmm13, xmm15) + add(r10, rbx) // b += rs_b; + add(r9, rax) // a += cs_a; + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(ymm0, ymm12, ymm12) + vmulpd(ymm0, ymm13, ymm13) + vmulpd(ymm0, ymm14, ymm14) + vmulpd(ymm0, ymm15, ymm15) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm4) + vmovlpd(xmm4, mem(rcx, 0*32)) // write back only lower half of xmm (8 bytes) + + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm6) + vmovupd(xmm6, mem(rcx, 0*32)) // write only lower half of ymm6 to c + + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) + vmovupd(xmm8, mem(rcx, 0*32)) // write lower half of ymm (16 bytes) + vextractf128(imm(1), ymm8, xmm1) // move upper half of ymm to xmm + vmovlpd(xmm1, mem(rcx, 2*8)) // write only lower half of xmm (8 bytes) to rcx + 16 + + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10) + vmovupd(ymm10, mem(rcx, 0*32)) + + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm12) + vmovupd(ymm12, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), xmm3, xmm13) + vmovlpd(xmm13, mem(rcx, 1*32)) // write back only xmm13[0] + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm14) + vmovupd(ymm14, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), xmm3, xmm15) + vmovupd(xmm15, mem(rcx, 1*32)) // write xmm to c (16 bytes) + //add(rdi, rcx) + + + jmp(.DDONE) // jump to end. 
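+ // The row-stored path above writes back only the lower-triangular part of
+ // this 6x8 tile: 1, 2, 3, 4, 5 and 6 elements for rows 0 through 5
+ // respectively, matching the (0x0)_L region drawn in the kernel's header
+ // comment.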
+ + + + label(.DCOLSTORED) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovupd(ymm4, mem(rcx )) + vextractf128(imm(1), ymm6, xmm1) // move upper half of ymm to xmm1 (ymm6[2], ymm6[3]) + vmovhpd(xmm6, mem(rcx, rsi, 1, 1*8)) // write upper half of xmm6(ymm6[1]) to c + rsi + 8 + vmovupd(xmm1, mem(rcx, rsi, 1, 2*8)) // write xmm1 (ymm6[2], ymm6[3]) to c + rsi + 16 + vextractf128(imm(1), ymm8, xmm1) // move upper half of ymm8 to xmm1 + vmovupd(xmm1, mem(rcx, rsi, 2, 2*8)) // write upper half of ymm8 to c + rsi*2 + 16 + vextractf128(imm(1), ymm10, xmm1) // move uppper half of ymm10 to xmm1 + vmovhpd(xmm1, mem(rcx, rax, 1, 3*8)) // move ymm8[3] to c + rsi*3 + 3*8 + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx ), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + + vfmadd231pd(mem(rdx ), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vmovupd(xmm0, mem(rdx )) // move the first half of ymm13 to c + vmovhpd(xmm1, mem(rdx, rsi, 1, 1*8)) // move the last 8 bits of ymm13 + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovlpd(xmm4, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(xmm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(xmm8, mem(rcx, 0*32)) + vextractf128(imm(1), ymm8, xmm1) + vmovlpd(xmm1, mem(rcx, 2*8)) + add(rdi, rcx) + + + vmovupd(ymm10, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(ymm12, mem(rcx, 0*32)) + vmovlpd(xmm13, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm14, mem(rcx, 0*32)) + vmovupd(xmm15, mem(rcx, 1*32)) + + + jmp(.DDONE) // jump to end. 
+ + + + label(.DCOLSTORBZ) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vmovupd(ymm4, mem(rcx )) + vextractf128(imm(1), ymm6, xmm1) // move upper half of ymm to xmm1 (ymm6[2], ymm6[3]) + vmovhpd(xmm6, mem(rcx, rsi, 1, 1*8)) // write upper half of xmm6(ymm6[1]) to c + rsi + 8 + vmovupd(xmm1, mem(rcx, rsi, 1, 2*8)) // write xmm1 (ymm6[2], ymm6[3]) to c + rsi + 16 + vextractf128(imm(1), ymm8, xmm1) // move upper half of ymm8 to xmm1 + vmovupd(xmm1, mem(rcx, rsi, 2, 2*8)) // write upper half of ymm8 to c + rsi*2 + 16 + vextractf128(imm(1), ymm10, xmm1) // move uppper half of ymm10 to xmm1 + vmovhpd(xmm1, mem(rcx, rax, 1, 3*8)) // move ymm8[3] to c + rsi*3 + 3*8 + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + + vmovupd(xmm0, mem(rdx )) // move the first half of ymm13 to c + vmovhpd(xmm1, mem(rdx, rsi, 1, 1*8)) // move the last 8 bits of ymm13 + + + label(.DDONE) + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a8] "m" (ps_a8), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); +} + +/* + +Following kernel computes the 6x8 block for the Lower vairant(L) of gemmt where +m_offset in 24x24 block is 6 and n_offset is 8(6x8) +(6x8)_L + +the region marked with 'x' is computed by following kernel +the region marked with '-' is not computed + + <-- n_off_24 -- > + 8 9 10 11 12 13 14 15 + +↑ 6 - - - - - - - - +| 7 - - - - - - - - +m 8 x - - - - - - - +off 9 x x - - - - - - +24 10 x x x - - - - - +| 11 x x x x - - - - +↓ + + +*/ +void bli_dgemmsup_rv_haswell_asm_6x8m_6x8_L + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + + // ------------------------------------------------------------------------- + + begin_asm() + mov(var(a), r14) + mov(var(c), r12) + + mov(var(rs_a), r8) + mov(var(cs_a), r9) + 
mov(var(rs_b), r10) + mov(var(rs_c), rdi) + lea(mem(, r8, 8), r8) + lea(mem(, r9, 8), r9) + lea(mem(, r10, 8), r10) + lea(mem(, rdi, 8), rdi) + + lea(mem(r8, r8, 2), r13) //3*r8 + lea(mem(r8, r8, 4), r15) //5*r8 + + vxorpd(ymm8, ymm8, ymm8) + vmovapd( ymm8, ymm10) + vmovapd( ymm8, ymm12) + vmovapd( ymm8, ymm14) + mov(var(b), rbx) // load address of b. + mov(r14, rax) + + cmp(imm(8), rdi) + jz(.DCOLPFETCH) + + label(.DROWPFETCH) + lea(mem(r12, rdi, 2), rdx) + lea(mem(rdx, rdi, 1), rdx) + prefetch(0, mem(r12, rdi, 2, 1*8)) + prefetch(0, mem(rdx, 2*8)) + prefetch(0, mem(rdx, rdi, 1, 3*8)) + prefetch(0, mem(rdx, rdi, 2, 4*8)) + jmp(.DPOSTPFETCH) + + label(.DCOLPFETCH) + mov(var(cs_c), rsi) + lea(mem(, rsi, 8), rsi) + lea(mem(r12, rsi, 2), rdx) + lea(mem(rdx, rsi, 1), rdx) + prefetch(0, mem(r12, 5*8)) + prefetch(0, mem(r12, rsi, 1, 5*8)) + prefetch(0, mem(r12, rsi, 2, 5*8)) + prefetch(0, mem(rdx, 5*8)) + + label(.DPOSTPFETCH) + lea(mem(rax, r8, 2), rax) + mov(var(k_iter), rsi) + test(rsi, rsi) + je(.DCONSIDKLEFT) + + // computer xmm8, xmm10, ymm12, ymm14 only + label(.DLOOPKITER) + //0 + SUBITER4x4(rax, rbx, ymm8, ymm10, ymm12, ymm14) + add(r10, rbx) // b += rs_b; + add(r9, rax) // a += cs_a; + //1 + SUBITER4x4(rax, rbx, ymm8, ymm10, ymm12, ymm14) + add(r10, rbx) // b += rs_b; + add(r9, rax) // a += cs_a; + //2 + SUBITER4x4(rax, rbx, ymm8, ymm10, ymm12, ymm14) + add(r10, rbx) // b += rs_b; + add(r9, rax) // a += cs_a; + //3 + SUBITER4x4(rax, rbx, ymm8, ymm10, ymm12, ymm14) + add(r10, rbx) // b += rs_b; + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + label(.DCONSIDKLEFT) + mov(var(k_left), rsi) + test(rsi, rsi) + je(.DPOSTACCUM) + + label(.DLOOPKLEFT) + SUBITER4x4(rax, rbx, ymm8, ymm10, ymm12, ymm14) + add(r10, rbx) // b += rs_b; + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + label(.DPOSTACCUM) + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(ymm0, ymm12, ymm12) + vmulpd(ymm0, ymm14, ymm14) + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + label(.DROWSTORED) + lea(mem(rcx , rdi, 2), rcx) + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm8) + vmovlpd(xmm8, mem(rcx, 0*32)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm10) + vmovupd(xmm10, mem(rcx, 0*32)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm12) + vmovupd(xmm12, mem(rcx, 0*32)) + vextractf128(imm(1), ymm12, xmm1) + vmovlpd(xmm1, mem(rcx, 2*8)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm14) + vmovupd(ymm14, mem(rcx, 0*32)) + + jmp(.DDONE) // jump to end. 
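+ // The row-stored path above skips the first two rows of the 6x8 tile (they
+ // lie entirely above the diagonal) and writes back 1, 2, 3 and 4 elements
+ // for rows 2 through 5, matching the (6x8)_L region drawn in the kernel's
+ // header comment.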
+ + + + label(.DCOLSTORED) + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + + vextractf128(imm(1), ymm4, xmm1) + vmovupd(xmm1, mem(rcx, 2*8 )) // write upper half of ymm4 to c + vextractf128(imm(1), ymm6, xmm1) + vmovhpd(xmm1, mem(rcx, rsi, 1, 3*8)) // write last element of ymm6 + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx ), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovhpd(xmm4, mem(rdx, rax, 1, 1*8)) // write only last 8 bytes of second half of ymm14 + + lea(mem(rdx, rsi, 4), rdx) + + jmp(.DDONE) // jump to end. + + label(.DBETAZERO) + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + label(.DROWSTORBZ) + lea(mem(rdi, rcx, 2), rdi) + + vmovlpd(xmm8, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(xmm10, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(xmm12, mem(rcx, 0*32)) + vextractf128(imm(1), ymm12, xmm1) + vmovlpd(xmm1, mem(rcx, 2*8)) + add(rdi, rcx) + + + vmovupd(ymm14, mem(rcx, 0*32)) + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + + vbroadcastsd(mem(rbx), ymm3) + + vextractf128(imm(1), ymm4, xmm1) + vmovupd(xmm1, mem(rcx, 2*8 )) // write upper half of ymm4 to c + vextractf128(imm(1), ymm6, xmm1) + vmovhpd(xmm1, mem(rcx, rsi, 1, 3*8)) // write last element of ymm6 + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovhpd(xmm4, mem(rdx, rax, 1, 1*8)) // write only last 8 bytes of second half of ymm14 + + + label(.DDONE) + + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); +} + +/* + +Following kernel computes the 6x8 block for the Lower vairant(L) of gemmt where +m_offset in 24x24 block is 12 and n_offset is 16(12x16) +(12x16)_L + + +the region marked with 'x' is computed by following kernel +the region marked with '-' is not computed + + <-- n_off_24 -- > + 16 17 18 19 
20 21 22 23 + +↑ 12 - - - - - - - - +| 13 - - - - - - - - +m 14 - - - - - - - - +off 15 - - - - - - - - +24 16 x - - - - - - - +| 17 x x - - - - - - +↓ + + +*/ +void bli_dgemmsup_rv_haswell_asm_6x8m_12x16_L + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + begin_asm() + mov(var(a), r14) + mov(var(b), rbx) + mov(var(c), r12) + mov(r14, rax) + + mov(var(rs_a), r8) + mov(var(cs_a), r9) + lea(mem(, r8, 8), r8) + lea(mem(, r9, 8), r9) + + mov(var(rs_b), r10) + lea(mem(, r10, 8), r10) + + mov(var(rs_c), rdi) + lea(mem(, rdi, 8), rdi) + + lea(mem(r8, r8, 4), r15) + + vxorpd(ymm12, ymm12, ymm12) + vmovapd(ymm12, ymm14) + + cmp(imm(8), rdi) + jz(.DCOLPFETCH) + + label(.DROWPFETCH) + lea(mem(r12, rdi, 2), rdx) + lea(mem(rdx, rdi, 1), rdx) + prefetch(0, mem(rdx, rdi, 1, 1*8)) + prefetch(0, mem(rdx, rdi, 2, 2*8)) + jmp(.DPOSTPFETCH) + + label(.DCOLPFETCH) + mov(var(cs_c), rsi) + lea(mem(, rsi, 8), rsi) + prefetch(0, mem(r12, 5*8)) + prefetch(0, mem(r12, rsi, 1, 5*8)) + + label(.DPOSTPFETCH) + lea(mem(rax, r8, 4), rax) + mov(var(k_iter), rsi) + test(rsi, rsi) + je(.DCONSILEFT) + + //compute xmm12 and xmm 14 + label(.DMAIN) + //0 + SUBITER2x2(rax, rbx, xmm12, xmm14) + add(r10, rbx) + add(r9, rax) + //1 + SUBITER2x2(rax, rbx, xmm12, xmm14) + add(r10, rbx) + add(r9, rax) + //2 + SUBITER2x2(rax, rbx, xmm12, xmm14) + add(r10, rbx) + add(r9, rax) + //3 + SUBITER2x2(rax, rbx, xmm12, xmm14) + add(r10, rbx) + add(r9, rax) + + dec(rsi) + jne(.DMAIN) + + label(.DCONSILEFT) + mov(var(k_left), rsi) + test(rsi, rsi) + je(.DPOSTACC) + + label(.DLEFT) + SUBITER2x2(rax, rbx, xmm12, xmm14) + add(r10, rbx) + add(r9, rax) + dec(rsi) + jne(.DLEFT) + + label(.DPOSTACC) + mov(r12, rcx) + mov(var(alpha), rax) + mov(var(beta), rbx) + vbroadcastsd(mem(rax), ymm0) + vbroadcastsd(mem(rbx), ymm3) + vmulpd(ymm0, ymm12, ymm12) + vmulpd(ymm0, ymm14, ymm14) + + mov(var(cs_c), rsi) + lea(mem(, rsi, 8), rsi) + vxorpd(ymm0, ymm0, ymm0) + + cmp(imm(8), rdi) //rs_c == 0? 
+ je(.DCOLSTOR) + + label(.DROWSTOR) + lea(mem(rcx, rdi, 4), rcx) //rcx += 4 * rdi + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm12) + vmovlpd(xmm12, mem(rcx)) + add(rdi, rcx) + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm14) + vmovlpd(xmm14, mem(rcx)) + vmovhpd(xmm14, mem(rcx, rsi, 1)) + jmp(.DRETURN) + + label(.DCOLSTOR) + + lea(mem(rcx, rdi, 4), rdx) + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vinsertf128(imm(0x1), xmm2, ymm0, ymm12) + vinsertf128(imm(0x1), xmm3, ymm1, ymm14) + + vfmadd231pd(mem(rdx), xmm3, xmm12) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm14) + vmovupd(xmm12, mem(rdx )) + vmovhpd(xmm14, mem(rdx, rsi, 1, 1*8)) + + label(.DRETURN) + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +/* + +Following kernel computes the 6x8 block for the Lower vairant(L) of gemmt where +m_offset in 24x24 block is 12, n_offset is 16(12x16) and m_offset is 18, n_offset is 16 (18x16) +(16x12)+(16x18)_L + +the region marked with 'x' is computed by following kernel +the region marked with '-' is not computed + + <-- n_off_24 -- > + 16 17 18 19 20 21 22 23 + +↑ 12 - - - - - - - - +| 13 - - - - - - - - +m 14 - - - - - - - - +off 15 - - - - - - - - +24 16 x - - - - - - - +| 17 x x - - - - - - +↓ +↑ 18 x x x - - - - - +| 19 x x x x - - - - +m 20 x x x x x - - - +off 21 x x x x x x - - +24 22 x x x x x x x - +| 23 x x x x x x x x +↓ + + +*/ +void bli_dgemmsup_rv_haswell_asm_6x8m_16x12_combined_L + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) + { + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + uint64_t ps_a8 = bli_auxinfo_ps_a( data ) * sizeof( double ); + double* a_next = a + rs_a * 6; + begin_asm() + mov(var(a), r14) + mov(var(b), rbx) + mov(var(c), r12) + mov(var(a_next), r11) + mov(r14, rax) + + mov(var(rs_a), r8) + mov(var(cs_a), r9) + lea(mem(, r8, 8), r8) + lea(mem(, r9, 8), r9) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // 5 + mov(var(rs_b), r10) + lea(mem(, r10, 8), r10) + + mov(var(rs_c), rdi) + lea(mem(, rdi, 8), rdi) + + lea(mem(r8, r8, 4), r15) + + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + cmp(imm(8), rdi) + jz(.DCOLPFETCH) + + label(.DROWPFETCH) + lea(mem(r12, rdi, 2), rdx) + lea(mem(rdx, rdi, 1), rdx) + + prefetch(0, mem(rdx, rdi, 1, 1*8)) // c + 4 * rs_c + 
prefetch(0, mem(rdx, rdi, 2, 2*8)) + lea(mem(rdx, rdi, 2), rdx) + lea(mem(rdx, rdi, 1), rdx) // c + 6 *rsc + prefetch(0, mem(rdx, 7*8)) + prefetch(0, mem(rdx, rdi, 1, 7*8)) + prefetch(0, mem(rdx, rdi, 2, 7*8)) + lea(mem(rdx, rdi, 2), rdx) + lea(mem(rdx, rdi, 1), rdx) + prefetch(0, mem(rdx, 7*8)) + prefetch(0, mem(rdx, rdi, 1, 7*8)) + prefetch(0, mem(rdx, rdi, 2, 7*8)) + + jmp(.DPOSTPFETCH) + + label(.DCOLPFETCH) + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(r12, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(r12, 11*8)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 11*8)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 11*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 11*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 11*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 11*8)) // prefetch c + 5*cs_c + lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; + prefetch(0, mem(rdx, rsi, 1, 11*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rsi, 2, 11*8)) // prefetch c + 7*cs_c + + label(.DPOSTPFETCH) + mov(var(ps_a8), rdx) + lea(mem(rax, rdx, 1), rdx) //rdx = a + ps_a8 //for prefetch + mov(var(ps_a8), rbp) + lea(mem(r11, rbp, 1), rbp) //rdx = a + ps_a8 //for prefetch + mov(var(k_iter), rsi) + test(rsi, rsi) + je(.DCONSILEFT) + + // ymm5 and ymm7 contains the data for 16x12 block, other registers contains data for 16x18 block + label(.DMAIN) + //0 + prefetch(0, mem(rdx, 5*8)) + prefetch(0, mem(rbp, 5*8)) + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm5) + vfmadd231pd(xmm0, xmm3, xmm7) + vbroadcastsd(mem(r11 ), ymm2) + vbroadcastsd(mem(r11, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(r11, r8, 2), ymm2) + vbroadcastsd(mem(r11, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(r11, r8, 4), ymm2) + vbroadcastsd(mem(r11, r15, 1), ymm3) + add(r9, r11) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + add(r10, rbx) + add(r9, rax) + //1 + prefetch(0, mem(rdx, r9, 1, 5*8)) + prefetch(0, mem(rbp, r9, 1, 5*8)) + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm5) + vfmadd231pd(xmm0, xmm3, xmm7) + vbroadcastsd(mem(r11 ), ymm2) + vbroadcastsd(mem(r11, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(r11, r8, 2), ymm2) + vbroadcastsd(mem(r11, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(r11, r8, 4), ymm2) + vbroadcastsd(mem(r11, r15, 1), ymm3) + add(r9, r11) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + add(r10, rbx) + add(r9, rax) + //2 + prefetch(0, mem(rdx, r9, 2, 5*8)) + prefetch(0, mem(rbp, r9, 2, 5*8)) + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm5) + 
vfmadd231pd(xmm0, xmm3, xmm7) + vbroadcastsd(mem(r11 ), ymm2) + vbroadcastsd(mem(r11, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(r11, r8, 2), ymm2) + vbroadcastsd(mem(r11, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(r11, r8, 4), ymm2) + vbroadcastsd(mem(r11, r15, 1), ymm3) + add(r9, r11) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + add(r10, rbx) + add(r9, rax) + //3 + prefetch(0, mem(rdx, rcx, 1, 5*8)) + prefetch(0, mem(rbp, rcx, 1, 5*8)) + lea(mem(rdx, r9, 4), rdx) + lea(mem(rbp, r9, 4), rbp) + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm5) + vfmadd231pd(xmm0, xmm3, xmm7) + vbroadcastsd(mem(r11 ), ymm2) + vbroadcastsd(mem(r11, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(r11, r8, 2), ymm2) + vbroadcastsd(mem(r11, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(r11, r8, 4), ymm2) + vbroadcastsd(mem(r11, r15, 1), ymm3) + add(r9, r11) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + add(r10, rbx) + add(r9, rax) + + dec(rsi) + jne(.DMAIN) + + label(.DCONSILEFT) + mov(var(k_left), rsi) + test(rsi, rsi) + je(.DPOSTACC) + + label(.DLEFT) + prefetch(0, mem(rdx, 5*8)) + prefetch(0, mem(rbp, 5*8)) + add(r9, rbp) + add(r9, rdx) + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm5) + vfmadd231pd(xmm0, xmm3, xmm7) + vbroadcastsd(mem(r11 ), ymm2) + vbroadcastsd(mem(r11, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(r11, r8, 2), ymm2) + vbroadcastsd(mem(r11, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(r11, r8, 4), ymm2) + vbroadcastsd(mem(r11, r15, 1), ymm3) + add(r9, r11) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + add(r10, rbx) + add(r9, rax) + + dec(rsi) + jne(.DLEFT) + + label(.DPOSTACC) + mov(r12, rcx) + mov(var(alpha), rax) + mov(var(beta), rbx) + vbroadcastsd(mem(rax), ymm0) + vbroadcastsd(mem(rbx), ymm3) + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm7, ymm7) + vmulpd(ymm0, ymm4, ymm4) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm9, ymm9) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(ymm0, ymm11, ymm11) + vmulpd(ymm0, ymm12, ymm12) + vmulpd(ymm0, ymm13, ymm13) + vmulpd(ymm0, ymm14, ymm14) + vmulpd(ymm0, ymm15, ymm15) + + mov(var(cs_c), rsi) + lea(mem(, rsi, 8), rsi) + vxorpd(ymm0, ymm0, ymm0) + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + cmp(imm(8), rdi) //rs_c == 8? 
+ je(.DCOLSTOR) + + label(.DROWSTOR) + lea(mem(rcx, rdi, 4), rcx) //rcx += 4 * rdi + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm5) + vmovlpd(xmm5, mem(rcx)) + add(rdi, rcx) + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm7) + vmovlpd(xmm7, mem(rcx)) + vmovhpd(xmm7, mem(rcx, rsi, 1)) + + //for lower 6x8 + lea(mem(rcx, rdi, 1), rcx) //rcx += 1 * rdi + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(xmm4, mem(rcx, 0*32)) + vextractf128(imm(1), ymm4, xmm1) + vmovlpd(xmm1, mem(rcx, 2*8)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) + vmovupd(ymm6, mem(rcx, 0*32)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) + vmovupd(ymm8, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm9) + vmovlpd(xmm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10) + vmovupd(ymm10, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm11) + vmovupd(xmm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm12) + vmovupd(ymm12, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm13) + vmovupd(xmm13, mem(rcx, 1*32)) + vextractf128(imm(1), ymm13, xmm1) + vmovlpd(xmm1, mem(rcx, 1*32+2*8)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm14) + vmovupd(ymm14, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm15) + vmovupd(ymm15, mem(rcx, 1*32)) + + jmp(.DRETURN) + + label(.DCOLSTOR) + vbroadcastsd(mem(rbx), ymm3) + + lea(mem(rcx, rdi, 4), rdx) //rdx = rcx + 4* rs_c + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + + vfmadd231pd(mem(rdx), xmm3, xmm5) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm7) + vmovupd(xmm5, mem(rdx )) + vmovhpd(xmm7, mem(rdx, rsi, 1, 1*8)) + + lea(mem(rcx, rdi, 4), rcx) + lea(mem(rcx, rdi, 2), rcx) + lea(mem(rcx, rdi, 4), rdx) + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vextractf128(imm(1), ymm10, xmm1) + vmovhpd(xmm10, mem(rcx, rax, 1, 1*8)) + vmovupd(xmm1, mem(rcx, rax, 1, 2*8)) + + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx ), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm5) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) + vextractf128(imm(1), ymm5, xmm1) + vmovupd(xmm1, mem(rcx, 2*8 )) + vextractf128(imm(1), ymm7, xmm1) + vmovhpd(xmm1, mem(rcx, rsi, 1, 3*8)) + + 
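+	// transpose the two rows held in ymm13/ymm15 and accumulate them into C column-wise (beta in xmm3)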
vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx ), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovhpd(xmm4, mem(rdx, rax, 1, 1*8)) + + + + label(.DRETURN) + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [a_next] "m" (a_next), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a8] "m" (ps_a8), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + } +/* + +Following kernel computes the 6x8 block for the Lower vairant(L) of gemmt where +m_offset in 24x24 block is 6 and n_offset is 0(6x0) +(6x0)_L + + +the region marked with 'x' is computed by following kernel +the region marked with '-' is not computed + + <-- n_off_24 -- > + 0 1 2 3 4 5 6 7 + +↑ 6 x x x x x x x - +| 7 x x x x x x x x +m 8 x x x x x x x x +off 9 x x x x x x x x +24 10 x x x x x x x x +| 11 x x x x x x x x +↓ + + +*/ +void bli_dgemmsup_rv_haswell_asm_6x8m_6x0_L + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of A and convert it to units of bytes. + uint64_t ps_a8 = bli_auxinfo_ps_a( data ) * sizeof( double ); + + + // ------------------------------------------------------------------------- + + begin_asm() + + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). 
+ + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + mov(var(b), rbx) // load address of b. + mov(r14, rax) // reset rax to current upanel of a. + + + + PREFETCH_C() + + label(.DPOSTPFETCH) // done prefetching c + + + mov(var(ps_a8), rdx) // load ps_a8 + lea(mem(rax, rdx, 1), rdx) // rdx = a + ps_a8 + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rdx, 5*8)) + SUBITER6x4(rax, rbx, ymm4, ymm6, ymm8, ymm10, ymm12, ymm14) + lea(mem(rbx, 1*32), rbp) + SUBITER6x4(rax, rbp, ymm5, ymm7, ymm9, ymm11, ymm13, ymm15) + add(r10, rbx) // b += rs_b; + add(r9, rax) // a += cs_a; + // ---------------------------------- iteration 1 + + prefetch(0, mem(rdx, r9, 1, 5*8)) + SUBITER6x4(rax, rbx, ymm4, ymm6, ymm8, ymm10, ymm12, ymm14) + lea(mem(rbx, 1*32), rbp) + SUBITER6x4(rax, rbp, ymm5, ymm7, ymm9, ymm11, ymm13, ymm15) + add(r10, rbx) // b += rs_b; + add(r9, rax) // a += cs_a; + // ---------------------------------- iteration 2 + + prefetch(0, mem(rdx, r9, 2, 5*8)) + SUBITER6x4(rax, rbx, ymm4, ymm6, ymm8, ymm10, ymm12, ymm14) + lea(mem(rbx, 1*32), rbp) + SUBITER6x4(rax, rbp, ymm5, ymm7, ymm9, ymm11, ymm13, ymm15) + add(r10, rbx) // b += rs_b; + add(r9, rax) // a += cs_a; + // ---------------------------------- iteration 3 + + prefetch(0, mem(rdx, rcx, 1, 5*8)) + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; + + SUBITER6x4(rax, rbx, ymm4, ymm6, ymm8, ymm10, ymm12, ymm14) + lea(mem(rbx, 1*32), rbp) + SUBITER6x4(rax, rbp, ymm5, ymm7, ymm9, ymm11, ymm13, ymm15) + add(r10, rbx) // b += rs_b; + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) + + SUBITER6x4(rax, rbx, ymm4, ymm6, ymm8, ymm10, ymm12, ymm14) + lea(mem(rbx, 1*32), rbp) + SUBITER6x4(rax, rbp, ymm5, ymm7, ymm9, ymm11, ymm13, ymm15) + add(r10, rbx) // b += rs_b; + add(r9, rax) // a += cs_a; + + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(r12, rcx) // reset rcx to current utile of c. 
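+	// load alpha/beta and scale the accumulated tile by alpha; beta is applied during the stores below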
+ mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm7, ymm7) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm9, ymm9) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(ymm0, ymm11, ymm11) + vmulpd(ymm0, ymm12, ymm12) + vmulpd(ymm0, ymm13, ymm13) + vmulpd(ymm0, ymm14, ymm14) + vmulpd(ymm0, ymm15, ymm15) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(ymm4, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm5) + vmovupd(xmm5, mem(rcx, 1*32)) + vextractf128(imm(1), ymm5, xmm1) + vmovlpd(xmm1, mem(rcx, 1*32+2*8)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) + vmovupd(ymm6, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm7) + vmovupd(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) + vmovupd(ymm8, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm9) + vmovupd(ymm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10) + vmovupd(ymm10, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm11) + vmovupd(ymm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm12) + vmovupd(ymm12, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm13) + vmovupd(ymm13, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm14) + vmovupd(ymm14, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm15) + vmovupd(ymm15, mem(rcx, 1*32)) + + + jmp(.DDONE) // jump to end. 
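+	// C is column-stored: transpose the accumulators in registers (unpck/insert/perm) before writing back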
+ + + + label(.DCOLSTORED) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx ), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm5) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm9) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm11) + vmovupd(ymm5, mem(rcx )) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovhpd(xmm11, mem(rcx, rax, 1, 1*8)) + vextractf128(imm(1), ymm11, xmm1) + vmovupd(xmm1, mem(rcx, rax, 1, 2*8)) + + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx ), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx, 0*32)) + vmovupd(xmm5, mem(rcx, 1*32)) + vextractf128(imm(1), ymm5, xmm1) + vmovlpd(xmm1, mem(rcx, 1*32+2*8)) + add(rdi, rcx) + + + vmovupd(ymm6, mem(rcx, 0*32)) + vmovupd(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm8, mem(rcx, 0*32)) + vmovupd(ymm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm10, mem(rcx, 0*32)) + vmovupd(ymm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm12, mem(rcx, 0*32)) + vmovupd(ymm13, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm14, mem(rcx, 0*32)) + vmovupd(ymm15, mem(rcx, 1*32)) + + + jmp(.DDONE) // jump to end. 
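+	// beta == 0 and C is column-stored: store the transposed tile without reading C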
+ + + + label(.DCOLSTORBZ) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vbroadcastsd(mem(rbx), ymm3) + + vmovupd(ymm5, mem(rcx )) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovhpd(xmm11, mem(rcx, rax, 1, 1*8)) + vextractf128(imm(1), ymm11, xmm1) + vmovupd(xmm1, mem(rcx, rax, 1, 2*8)) + + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + label(.DDONE) + + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a8] "m" (ps_a8), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + + // Handle edge cases in the m dimension, if they exist. 
+ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); +} + +/* + +Following kernel computes the 6x8 block for the Lower vairant(L) of gemmt where +m_offset in 24x24 block is 12 and n_offset is 8(12x8) +(12x8)_L + +the region marked with 'x' is computed by following kernel +the region marked with '-' is not computed + + <-- n_off_24 -- > + 8 9 10 11 12 13 14 15 + +↑ 12 x x x x x - - - +| 13 x x x x x x - - +m 14 x x x x x x x - +off 15 x x x x x x x x +24 16 x x x x x x x x +| 17 x x x x x x x x +↓ + + +*/ +void bli_dgemmsup_rv_haswell_asm_6x8m_12x8_L + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of A and convert it to units of bytes. + uint64_t ps_a8 = bli_auxinfo_ps_a( data ) * sizeof( double ); + + + // ------------------------------------------------------------------------- + + begin_asm() + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + mov(var(b), rbx) // load address of b. + mov(r14, rax) // reset rax to current upanel of a. + + + + PREFETCH_C() + + label(.DPOSTPFETCH) // done prefetching c + + + mov(var(ps_a8), rdx) // load ps_a8 + lea(mem(rax, rdx, 1), rdx) // rdx = a + ps_a8 + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. 
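+	// unrolled-by-4 k loop: each pass of .DLOOPKITER performs four k iterations (one SUBITER pair per k)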
+ + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rdx, 5*8)) + + SUBITER6x4(rax, rbx, ymm4, ymm6, ymm8, ymm10, ymm12, ymm14) + lea(mem(rbx, 1*32), rbp) + SUBITER6x4(rax, rbp, ymm5, ymm7, ymm9, ymm11, ymm13, ymm15) + add(r10, rbx) // b += rs_b; + add(r9, rax) // a += cs_a; + // ---------------------------------- iteration 1 + + prefetch(0, mem(rdx, r9, 1, 5*8)) + + SUBITER6x4(rax, rbx, ymm4, ymm6, ymm8, ymm10, ymm12, ymm14) + lea(mem(rbx, 1*32), rbp) + SUBITER6x4(rax, rbp, ymm5, ymm7, ymm9, ymm11, ymm13, ymm15) + add(r10, rbx) // b += rs_b; + add(r9, rax) // a += cs_a; + // ---------------------------------- iteration 2 + + prefetch(0, mem(rdx, r9, 2, 5*8)) + + SUBITER6x4(rax, rbx, ymm4, ymm6, ymm8, ymm10, ymm12, ymm14) + lea(mem(rbx, 1*32), rbp) + SUBITER6x4(rax, rbp, ymm5, ymm7, ymm9, ymm11, ymm13, ymm15) + add(r10, rbx) // b += rs_b; + add(r9, rax) // a += cs_a; + // ---------------------------------- iteration 3 + + prefetch(0, mem(rdx, rcx, 1, 5*8)) + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; + + SUBITER6x4(rax, rbx, ymm4, ymm6, ymm8, ymm10, ymm12, ymm14) + lea(mem(rbx, 1*32), rbp) + SUBITER6x4(rax, rbp, ymm5, ymm7, ymm9, ymm11, ymm13, ymm15) + add(r10, rbx) // b += rs_b; + add(r9, rax) // a += cs_a; + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) + + SUBITER6x4(rax, rbx, ymm4, ymm6, ymm8, ymm10, ymm12, ymm14) + lea(mem(rbx, 1*32), rbp) + SUBITER6x4(rax, rbp, ymm5, ymm7, ymm9, ymm11, ymm13, ymm15) + add(r10, rbx) // b += rs_b; + add(r9, rax) // a += cs_a; + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm7, ymm7) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm9, ymm9) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(ymm0, ymm11, ymm11) + vmulpd(ymm0, ymm12, ymm12) + vmulpd(ymm0, ymm13, ymm13) + vmulpd(ymm0, ymm14, ymm14) + vmulpd(ymm0, ymm15, ymm15) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
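+	// row-stored fall-through writes only the computed lower-trapezoid of the 6x8 tile (5..8 elements per row)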
+ jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(ymm4, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm5) + vmovlpd(xmm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) + vmovupd(ymm6, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm7) + vmovupd(xmm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) + vmovupd(ymm8, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm9) + vmovupd(xmm9, mem(rcx, 1*32)) + vextractf128(imm(1), ymm9, xmm1) + vmovlpd(xmm1, mem(rcx, 1*32+2*8)) + + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10) + vmovupd(ymm10, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm11) + vmovupd(ymm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm12) + vmovupd(ymm12, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm13) + vmovupd(ymm13, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm14) + vmovupd(ymm14, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm15) + vmovupd(ymm15, mem(rcx, 1*32)) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx ), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm5) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm9) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm11) + vmovupd(ymm5, mem(rcx )) + vextractf128(imm(1), ymm7, xmm1) + vmovhpd(xmm7, mem(rcx, rsi, 1, 1*8)) + vmovupd(xmm1, mem(rcx, rsi, 1, 2*8)) + vextractf128(imm(1), ymm9, xmm1) + vmovupd(xmm1, mem(rcx, rsi, 2, 2*8)) + vextractf128(imm(1), ymm11, xmm1) + vmovhpd(xmm1, mem(rcx, rax, 1, 3*8)) + + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx ), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx )) + 
vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. + jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(ymm4, mem(rcx, 0*32)) + vmovlpd(xmm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm6, mem(rcx, 0*32)) + vmovupd(xmm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm8, mem(rcx, 0*32)) + vmovupd(xmm9, mem(rcx, 1*32)) + vextractf128(imm(1), ymm9, xmm1) + vmovlpd(xmm1, mem(rcx, 1*32+2*8)) + add(rdi, rcx) + + + vmovupd(ymm10, mem(rcx, 0*32)) + vmovupd(ymm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm12, mem(rcx, 0*32)) + vmovupd(ymm13, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm14, mem(rcx, 0*32)) + vmovupd(ymm15, mem(rcx, 1*32)) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vmovupd(ymm5, mem(rcx )) + vextractf128(imm(1), ymm7, xmm1) + vmovhpd(xmm7, mem(rcx, rsi, 1, 1*8)) + vmovupd(xmm1, mem(rcx, rsi, 1, 2*8)) + vextractf128(imm(1), ymm9, xmm1) + vmovupd(xmm1, mem(rcx, rsi, 2, 2*8)) + vextractf128(imm(1), ymm11, xmm1) + vmovhpd(xmm1, mem(rcx, rax, 1, 3*8)) + + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + label(.DDONE) + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a8] "m" (ps_a8), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + + // Handle edge cases in the m dimension, if they exist. 
+ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); +} +/* + +Following kernel computes the 6x8 block for the Lower vairant(L) of gemmt where +m_offset in 24x24 block is 18 and n_offset is 16(18x16) +(18x16)_L + + +the region marked with 'x' is computed by following kernel +the region marked with '-' is not computed + + <-- n_off_24 -- > + 16 17 18 19 20 21 22 23 + +↑ 18 x x x - - - - - +| 19 x x x x - - - - +m 20 x x x x x - - - +off 21 x x x x x x - - +24 22 x x x x x x x - +| 23 x x x x x x x x +↓ + + +*/ +void bli_dgemmsup_rv_haswell_asm_6x8m_18x16_L + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of A and convert it to units of bytes. + uint64_t ps_a8 = bli_auxinfo_ps_a( data ) * sizeof( double ); + + + // ------------------------------------------------------------------------- + + begin_asm() + + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + mov(var(rs_b), r10) // load rs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + mov(var(b), rbx) // load address of b. + mov(r14, rax) // reset rax to current upanel of a. + + + + PREFETCH_C() + + label(.DPOSTPFETCH) // done prefetching c + + + mov(var(ps_a8), rdx) // load ps_a8 + lea(mem(rax, rdx, 1), rdx) // rdx = a + ps_a8 + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; + // use rcx, rdx for prefetching lines + // from next upanel of a. + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. 
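+	// SUBITER4x4 starts at a + 2*rs_a, computing only rows 2-5 of columns 4-7; rows 0-1 lie outside the lower triangle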
+ + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rdx, 5*8)) + SUBITER6x4(rax, rbx, ymm4, ymm6, ymm8, ymm10, ymm12, ymm14) + lea(mem(rbx, 1*32), rbp) + lea(mem(rax, r8, 2), rcx) + SUBITER4x4(rcx, rbp, ymm9, ymm11, ymm13, ymm15) + add(r10, rbx) // b += rs_b; + add(r9, rax) // a += cs_a; + // ---------------------------------- iteration 1 + + prefetch(0, mem(rdx, r9, 1, 5*8)) + SUBITER6x4(rax, rbx, ymm4, ymm6, ymm8, ymm10, ymm12, ymm14) + lea(mem(rax, r8, 2), rcx) + lea(mem(rbx, 1*32), rbp) + SUBITER4x4(rcx, rbp, ymm9, ymm11, ymm13, ymm15) + add(r10, rbx) // b += rs_b; + add(r9, rax) // a += cs_a; + // ---------------------------------- iteration 2 + + prefetch(0, mem(rdx, r9, 2, 5*8)) + SUBITER6x4(rax, rbx, ymm4, ymm6, ymm8, ymm10, ymm12, ymm14) + lea(mem(rax, r8, 2), rcx) + lea(mem(rbx, 1*32), rbp) + SUBITER4x4(rcx, rbp, ymm9, ymm11, ymm13, ymm15) + add(r10, rbx) // b += rs_b; + add(r9, rax) // a += cs_a; + // ---------------------------------- iteration 3 + + prefetch(0, mem(rdx, rcx, 1, 5*8)) + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; + SUBITER6x4(rax, rbx, ymm4, ymm6, ymm8, ymm10, ymm12, ymm14) + lea(mem(rax, r8, 2), rcx) + lea(mem(rbx, 1*32), rbp) + SUBITER4x4(rcx, rbp, ymm9, ymm11, ymm13, ymm15) + add(r10, rbx) // b += rs_b; + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKITER) // iterate again if i != 0. + + + + + + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) // i = k_left; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left loop. + + + label(.DLOOPKLEFT) // EDGE LOOP + + prefetch(0, mem(rdx, 5*8)) + SUBITER6x4(rax, rbx, ymm4, ymm6, ymm8, ymm10, ymm12, ymm14) + lea(mem(rax, r8, 2), rcx) + lea(mem(rbx, 1*32), rbp) + SUBITER4x4(rcx, rbp, ymm9, ymm11, ymm13, ymm15) + add(r10, rbx) // b += rs_b; + add(r9, rax) // a += cs_a; + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT) // iterate again if i != 0. + + + + label(.DPOSTACCUM) + + + + mov(r12, rcx) // reset rcx to current utile of c. + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm9, ymm9) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(ymm0, ymm11, ymm11) + vmulpd(ymm0, ymm12, ymm12) + vmulpd(ymm0, ymm13, ymm13) + vmulpd(ymm0, ymm14, ymm14) + vmulpd(ymm0, ymm15, ymm15) + + + + + + + mov(var(cs_c), rsi) // load cs_c + lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) + + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; + + lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
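+	// only the lower-triangular region of this 6x8 tile is stored: the row stores below write 3..8 elements per row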
+ jz(.DCOLSTORED) // jump to column storage case + + + + label(.DROWSTORED) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(xmm4, mem(rcx, 0*32)) + vextractf128(imm(1), ymm4, xmm1) + vmovlpd(xmm1, mem(rcx, 2*8)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) + vmovupd(ymm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) + vmovupd(ymm8, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm9) + vmovlpd(xmm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10) + vmovupd(ymm10, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm11) + vmovupd(xmm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm12) + vmovupd(ymm12, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm13) + vmovupd(xmm13, mem(rcx, 1*32)) + vextractf128(imm(1), ymm13, xmm1) + vmovlpd(xmm1, mem(rcx, 1*32+2*8)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm14) + vmovupd(ymm14, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm15) + vmovupd(ymm15, mem(rcx, 1*32)) + + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORED) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vextractf128(imm(1), ymm10, xmm1) + vmovhpd(xmm10, mem(rcx, rax, 1, 1*8)) + vmovupd(xmm1, mem(rcx, rax, 1, 2*8)) + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx ), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm5) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) + vextractf128(imm(1), ymm5, xmm1) + vmovupd(xmm1, mem(rcx, 2*8 )) + vextractf128(imm(1), ymm7, xmm1) + vmovhpd(xmm1, mem(rcx, rsi, 1, 3*8)) + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx ), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovhpd(xmm4, mem(rdx, rax, 1, 1*8)) + + + + jmp(.DDONE) // jump to end. + + + + + label(.DBETAZERO) + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
+ jz(.DCOLSTORBZ) // jump to column storage case + + + + label(.DROWSTORBZ) + + + vmovupd(xmm4, mem(rcx, 0*32)) + vextractf128(imm(1), ymm4, xmm1) + vmovlpd(xmm1, mem(rcx, 2*8)) + add(rdi, rcx) + + + vmovupd(ymm6, mem(rcx, 0*32)) + add(rdi, rcx) + + + vmovupd(ymm8, mem(rcx, 0*32)) + vmovlpd(xmm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm10, mem(rcx, 0*32)) + vmovupd(xmm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm12, mem(rcx, 0*32)) + vmovupd(xmm13, mem(rcx, 1*32)) + vextractf128(imm(1), ymm13, xmm1) + vmovlpd(xmm1, mem(rcx, 1*32+2*8)) + add(rdi, rcx) + + + vmovupd(ymm14, mem(rcx, 0*32)) + vmovupd(ymm15, mem(rcx, 1*32)) + //add(rdi, rcx) + + jmp(.DDONE) // jump to end. + + + + label(.DCOLSTORBZ) + + // begin I/O on columns 0-3 + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vextractf128(imm(1), ymm10, xmm1) + vmovhpd(xmm10, mem(rcx, rax, 1, 1*8)) + vmovupd(xmm1, mem(rcx, rax, 1, 2*8)) + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + + vbroadcastsd(mem(rbx), ymm3) + + vextractf128(imm(1), ymm5, xmm1) + vmovupd(xmm1, mem(rcx, 2*8 )) + vextractf128(imm(1), ymm7, xmm1) + vmovhpd(xmm1, mem(rcx, rsi, 1, 3*8)) + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovhpd(xmm4, mem(rdx, rax, 1, 1*8)) + + label(.DDONE) + label(.DRETURN) + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a8] "m" (ps_a8), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + + // Handle edge cases in the m dimension, if they exist. 
+ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); +} +/* + +Following kernel computes the 6x8 block for the Upper vairant(U) of gemmt where +m_offset in 24x24 block is 0 and n_offset is 0(0x0) +(0x0)_U + +the region marked with 'x' is computed by following kernel +the region marked with '-' is not computed + + <-- n_off_24 -- > + 0 1 2 3 4 5 6 7 + +↑ 0 x x x x x x x x +| 1 - x x x x x x x +m 2 - - x x x x x x +off 3 - - - x x x x x +24 4 - - - - x x x x +| 5 - - - - - x x x +↓ + + +*/ +void bli_dgemmsup_rv_haswell_asm_6x8m_0x0_U + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + uint64_t ps_a8 = bli_auxinfo_ps_a( data ) * sizeof( double ); + + begin_asm() + + + mov(var(a), r14) + mov(var(rs_a), r8) + mov(var(cs_a), r9) + lea(mem(, r8, 8), r8) + lea(mem(, r9, 8), r9) + + lea(mem(r8, r8, 2), r13) + lea(mem(r8, r8, 4), r15) + + mov(var(rs_b), r10) + lea(mem(, r10, 8), r10) + + mov(var(c), r12) + mov(var(rs_c), rdi) + lea(mem(, rdi, 8), rdi) + + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm15) + + mov(var(b), rbx) + mov(r14, rax) + + + + PREFETCH_C() + + label(.DPOSTPFETCH) + + mov(var(ps_a8), rdx) + lea(mem(rax, rdx, 1), rdx) + lea(mem(r9, r9, 2), rcx) + + mov(var(k_iter), rsi) + test(rsi, rsi) + je(.DCONSIDKLEFT) + + //ymm12, ymm14 can be skipped + label(.DLOOPKITER) // MAIN LOOP + //0 + prefetch(0, mem(rdx, 5*8)) + SUBITER4x4(rax, rbx, ymm4, ymm6, ymm8, ymm10) + lea(mem(rbx, 1*32), rbp) + SUBITER6x4(rax, rbp, ymm5, ymm7, ymm9, ymm11, ymm13, ymm15) + add(r10, rbx) // b += rs_b; + add(r9, rax) // a += cs_a; + + //1 + prefetch(0, mem(rdx, r9, 1, 5*8)) + + SUBITER4x4(rax, rbx, ymm4, ymm6, ymm8, ymm10) + lea(mem(rbx, 1*32), rbp) + SUBITER6x4(rax, rbp, ymm5, ymm7, ymm9, ymm11, ymm13, ymm15) + add(r10, rbx) // b += rs_b; + add(r9, rax) // a += cs_a; + //2 + + prefetch(0, mem(rdx, r9, 2, 5*8)) + + SUBITER4x4(rax, rbx, ymm4, ymm6, ymm8, ymm10) + lea(mem(rbx, 1*32), rbp) + SUBITER6x4(rax, rbp, ymm5, ymm7, ymm9, ymm11, ymm13, ymm15) + add(r10, rbx) // b += rs_b; + add(r9, rax) // a += cs_a; + //3 + prefetch(0, mem(rdx, rcx, 1, 5*8)) + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; + + SUBITER4x4(rax, rbx, ymm4, ymm6, ymm8, ymm10) + lea(mem(rbx, 1*32), rbp) + SUBITER6x4(rax, rbp, ymm5, ymm7, ymm9, ymm11, ymm13, ymm15) + add(r10, rbx) // b += rs_b; + add(r9, rax) // a += cs_a; + dec(rsi) + jne(.DLOOPKITER) + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) + test(rsi, rsi) + je(.DPOSTACCUM) + + label(.DLOOPKLEFT) // EDGE LOOP + + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) + + SUBITER4x4(rax, rbx, ymm4, ymm6, ymm8, ymm10) + lea(mem(rbx, 1*32), rbp) + SUBITER6x4(rax, rbp, ymm5, ymm7, ymm9, ymm11, ymm13, ymm15) + add(r10, rbx) // b += rs_b; + add(r9, rax) // a += cs_a; + dec(rsi) + jne(.DLOOPKLEFT) + + label(.DPOSTACCUM) + + + + mov(r12, rcx) + mov(var(alpha), rax) + mov(var(beta), rbx) + 
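+	// ymm12/ymm14 are never accumulated or stored in this kernel, so scaling them below has no effect on the result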
vbroadcastsd(mem(rax), ymm0) + vbroadcastsd(mem(rbx), ymm3) + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm7, ymm7) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm9, ymm9) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(ymm0, ymm11, ymm11) + vmulpd(ymm0, ymm12, ymm12) + vmulpd(ymm0, ymm13, ymm13) + vmulpd(ymm0, ymm14, ymm14) + vmulpd(ymm0, ymm15, ymm15) + + mov(var(cs_c), rsi) + lea(mem(, rsi, 8), rsi) + lea(mem(rcx, rdi, 4), rdx) // c + 4*rs_c; + lea(mem(rsi, rsi, 2), rax) // 3*cs_c; + + + vxorpd(ymm0, ymm0, ymm0) + vucomisd(xmm0, xmm3) + je(.DBETAZERO) + + cmp(imm(8), rdi) + jz(.DCOLSTORED) + + label(.DROWSTORED) + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(ymm4, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm5) + vmovupd(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) + vmovhpd(xmm6, mem(rcx, 0*32+1*8)) + vextractf128(imm(0x1), ymm6, xmm6) + vmovupd(xmm6, mem(rcx, 0*32+2*8)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm7) + vmovupd(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) + vextractf128(imm(0x1), ymm8, xmm8) + vmovupd(xmm8, mem(rcx, 0*32+2*8)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm9) + vmovupd(ymm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10) + vextractf128(imm(0x1), ymm10, xmm10) + vmovhpd(xmm10, mem(rcx, 0*32+3*8)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm11) + vmovupd(ymm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm13) + vmovupd(ymm13, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm15) + vmovhpd(xmm15, mem(rcx, 1*32+1*8)) + vextractf128(imm(0x1), ymm15, xmm15) + vmovupd(xmm15, mem(rcx, 1*32+2*8)) + + + jmp(.DDONE) + + + label(.DCOLSTORED) + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovlpd(xmm4, mem(rcx )) + vmovupd(xmm6, mem(rcx, rsi, 1)) + vmovupd(xmm8, mem(rcx, rsi, 2)) + vextractf128(imm(0x1), ymm8, xmm8) + vmovlpd(xmm8, mem(rcx, rsi, 2, 1*16)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm5) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm9) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm11) + vmovupd(ymm5, mem(rcx )) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovupd(ymm11, mem(rcx, rax, 1)) + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx ), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), 
xmm3, xmm4) + vmovlpd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + jmp(.DDONE) // jump to end. + + label(.DBETAZERO) // if beta zero + + + cmp(imm(8), rdi) + jz(.DCOLSTORBZ) + + label(.DROWSTORBZ) + + vmovupd(ymm4, mem(rcx, 0*32)) + vmovupd(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + vmovhpd(xmm6, mem(rcx, 0*32+1*8)) + vextractf128(imm(0x1), ymm6, xmm6) + vmovupd(xmm6, mem(rcx, 0*32+2*8)) + vmovupd(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + vextractf128(imm(0x1), ymm8, xmm8) + vmovupd(xmm8, mem(rcx, 0*32+2*8)) + vmovupd(ymm9, mem(rcx, 1*32)) + add(rdi, rcx) + + vextractf128(imm(0x1), ymm10, xmm10) + vmovhpd(xmm10, mem(rcx, 0*32+3*8)) + vmovupd(ymm11, mem(rcx, 1*32)) + add(rdi, rcx) + + vmovupd(ymm13, mem(rcx, 1*32)) + add(rdi, rcx) + + vmovhpd(xmm15, mem(rcx, 1*32+1*8)) + vextractf128(imm(0x1), ymm15, xmm15) + vmovupd(xmm15, mem(rcx, 1*32+2*8)) + + jmp(.DDONE) + + label(.DCOLSTORBZ) + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vmovlpd(xmm4, mem(rcx )) + vmovupd(xmm6, mem(rcx, rsi, 1)) + vmovupd(xmm8, mem(rcx, rsi, 2)) + vextractf128(imm(0x1), ymm8, xmm8) + vmovlpd(xmm8, mem(rcx, rsi, 2, 1*16)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vmovupd(ymm5, mem(rcx )) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovupd(ymm11, mem(rcx, rax, 1)) + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovlpd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + + label(.DDONE) + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a8] "m" (ps_a8), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +/* + +Following kernel computes the 6x8 block for the Upper vairant(U) of gemmt where +m_offset in 24x24 block is 6 and n_offset is 8(6x8) +(6x8)_U + +the region marked with 'x' is computed by following kernel +the region marked with '-' is not computed + + <-- n_off_24 -- > + 8 9 10 11 12 13 14 15 + +↑ 6 x x x x x x x x +| 7 x x x x x x x x +m 8 x x x x x x x x +off 9 - x x x x x x x +24 10 - - x x x x x x +| 11 - - - x x x x x +↓ + + +*/ +void bli_dgemmsup_rv_haswell_asm_6x8m_6x8_U + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + 
double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + uint64_t ps_a8 = bli_auxinfo_ps_a( data ) * sizeof( double ); + + begin_asm() + + + mov(var(a), r14) + mov(var(rs_a), r8) + mov(var(cs_a), r9) + lea(mem(, r8, 8), r8) + lea(mem(, r9, 8), r9) + + lea(mem(r8, r8, 2), r13) + lea(mem(r8, r8, 4), r15) + + mov(var(rs_b), r10) + lea(mem(, r10, 8), r10) + + mov(var(c), r12) + mov(var(rs_c), rdi) + lea(mem(, rdi, 8), rdi) + + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + mov(var(b), rbx) + mov(r14, rax) + + + PREFETCH_C() + + label(.DPOSTPFETCH) + + mov(var(ps_a8), rdx) + lea(mem(rax, rdx, 1), rdx) + lea(mem(r9, r9, 2), rcx) + + mov(var(k_iter), rsi) + test(rsi, rsi) + je(.DCONSIDKLEFT) + + label(.DLOOPKITER) // MAIN LOOP + //0 + prefetch(0, mem(rdx, 5*8)) + + SUBITER6x4(rax, rbx, ymm4, ymm6, ymm8, ymm10, ymm12, ymm14) + lea(mem(rbx, 1*32), rbp) + SUBITER6x4(rax, rbp, ymm5, ymm7, ymm9, ymm11, ymm13, ymm15) + add(r10, rbx) // b += rs_b; + add(r9, rax) // a += cs_a; + //1 + prefetch(0, mem(rdx, r9, 1, 5*8)) + + SUBITER6x4(rax, rbx, ymm4, ymm6, ymm8, ymm10, ymm12, ymm14) + lea(mem(rbx, 1*32), rbp) + SUBITER6x4(rax, rbp, ymm5, ymm7, ymm9, ymm11, ymm13, ymm15) + add(r10, rbx) // b += rs_b; + add(r9, rax) // a += cs_a; + //2 + + prefetch(0, mem(rdx, r9, 2, 5*8)) + + SUBITER6x4(rax, rbx, ymm4, ymm6, ymm8, ymm10, ymm12, ymm14) + lea(mem(rbx, 1*32), rbp) + SUBITER6x4(rax, rbp, ymm5, ymm7, ymm9, ymm11, ymm13, ymm15) + add(r10, rbx) // b += rs_b; + add(r9, rax) // a += cs_a; + //3 + prefetch(0, mem(rdx, rcx, 1, 5*8)) + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; + + SUBITER6x4(rax, rbx, ymm4, ymm6, ymm8, ymm10, ymm12, ymm14) + lea(mem(rbx, 1*32), rbp) + SUBITER6x4(rax, rbp, ymm5, ymm7, ymm9, ymm11, ymm13, ymm15) + add(r10, rbx) // b += rs_b; + add(r9, rax) // a += cs_a; + + dec(rsi) + jne(.DLOOPKITER) + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) + test(rsi, rsi) + je(.DPOSTACCUM) + + label(.DLOOPKLEFT) // EDGE LOOP + + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) + + SUBITER6x4(rax, rbx, ymm4, ymm6, ymm8, ymm10, ymm12, ymm14) + lea(mem(rbx, 1*32), rbp) + SUBITER6x4(rax, rbp, ymm5, ymm7, ymm9, ymm11, ymm13, ymm15) + add(r10, rbx) // b += rs_b; + add(r9, rax) // a += cs_a; + + dec(rsi) + jne(.DLOOPKLEFT) + + label(.DPOSTACCUM) + + + + mov(r12, rcx) + mov(var(alpha), rax) + mov(var(beta), rbx) + vbroadcastsd(mem(rax), ymm0) + vbroadcastsd(mem(rbx), ymm3) + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm7, ymm7) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm9, ymm9) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(ymm0, ymm11, ymm11) + vmulpd(ymm0, ymm12, ymm12) + vmulpd(ymm0, ymm13, ymm13) + vmulpd(ymm0, ymm14, ymm14) + vmulpd(ymm0, ymm15, ymm15) + + mov(var(cs_c), rsi) + lea(mem(, rsi, 8), rsi) + lea(mem(rcx, rdi, 4), rdx) // c + 4*rs_c; + lea(mem(rsi, 
rsi, 2), rax) // 3*cs_c; + + + vxorpd(ymm0, ymm0, ymm0) + vucomisd(xmm0, xmm3) + je(.DBETAZERO) + + cmp(imm(8), rdi) + jz(.DCOLSTORED) + + label(.DROWSTORED) + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(ymm4, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm5) + vmovupd(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) + vmovupd(ymm6, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm7) + vmovupd(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) + vmovupd(ymm8, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm9) + vmovupd(ymm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10) + vmovhpd(xmm10, mem(rcx, 0*32+1*8)) + vextractf128(imm(0x1), ymm10, xmm10) + vmovupd(xmm10, mem(rcx, 0*32+2*8)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm11) + vmovupd(ymm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm12) + vextractf128(imm(0x1), ymm12, xmm12) + vmovupd(xmm12, mem(rcx, 0*32+2*8)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm13) + vmovupd(ymm13, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm14) + vextractf128(imm(0x1), ymm14, xmm14) + vmovhpd(xmm14, mem(rcx, 0*32+3*8)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm15) + vmovupd(ymm15, mem(rcx, 1*32)) + + jmp(.DDONE) + + + label(.DCOLSTORED) + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovupd(xmm4, mem(rcx )) + vextractf128(imm(0x1), ymm4, xmm4) + vmovlpd(xmm4, mem(rcx, 2*8 )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx ), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovlpd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm5) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm9) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm11) + vmovupd(ymm5, mem(rcx )) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovupd(ymm11, mem(rcx, rax, 1)) + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx ), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + 
vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + jmp(.DDONE) // jump to end. + + label(.DBETAZERO) // if beta zero + + + cmp(imm(8), rdi) + jz(.DCOLSTORBZ) + + label(.DROWSTORBZ) + + vmovupd(ymm4, mem(rcx, 0*32)) + vmovupd(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx, 0*32)) + vmovupd(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + vmovupd(ymm8, mem(rcx, 0*32)) + vmovupd(ymm9, mem(rcx, 1*32)) + add(rdi, rcx) + + vmovhpd(xmm10, mem(rcx, 0*32+1*8)) + vextractf128(imm(0x1), ymm10, xmm10) + vmovupd(xmm10, mem(rcx, 0*32+2*8)) + vmovupd(ymm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vextractf128(imm(0x1), ymm12, xmm12) + vmovupd(xmm12, mem(rcx, 0*32+2*8)) + vmovupd(ymm13, mem(rcx, 1*32)) + add(rdi, rcx) + + vextractf128(imm(0x1), ymm14, xmm14) + vmovhpd(xmm14, mem(rcx, 0*32+3*8)) + vmovupd(ymm15, mem(rcx, 1*32)) + + jmp(.DDONE) + + label(.DCOLSTORBZ) + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vmovupd(xmm4, mem(rcx )) + vextractf128(imm(0x1), ymm4, xmm4) + vmovlpd(xmm4, mem(rcx, 2*8 )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx ), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovlpd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vmovupd(ymm5, mem(rcx )) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovupd(ymm11, mem(rcx, rax, 1)) + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + + label(.DDONE) + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a8] "m" (ps_a8), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +/* + +Following kernel computes the 6x8 block for the Upper vairant(U) of gemmt where +m_offset in 24x24 block is 12 and n_offset is 16(12x16) +(12x16)_U + + +the region marked with 'x' is computed by following kernel +the region marked with '-' is not computed + + <-- n_off_24 
-- > + 16 17 18 19 20 21 22 23 + +↑ 12 x x x x x x x x +| 13 x x x x x x x x +m 14 x x x x x x x x +off 15 x x x x x x x x +24 16 x x x x x x x x +| 17 - x x x x x x x +↓ + + +*/ +void bli_dgemmsup_rv_haswell_asm_6x8m_12x16_U + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + uint64_t ps_a8 = bli_auxinfo_ps_a( data ) * sizeof( double ); + + begin_asm() + + + mov(var(a), r14) + mov(var(rs_a), r8) + mov(var(cs_a), r9) + lea(mem(, r8, 8), r8) + lea(mem(, r9, 8), r9) + + lea(mem(r8, r8, 2), r13) + lea(mem(r8, r8, 4), r15) + + mov(var(rs_b), r10) + lea(mem(, r10, 8), r10) + + mov(var(c), r12) + mov(var(rs_c), rdi) + lea(mem(, rdi, 8), rdi) + + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + mov(var(b), rbx) + mov(r14, rax) + + + PREFETCH_C() + + label(.DPOSTPFETCH) + + mov(var(ps_a8), rdx) + lea(mem(rax, rdx, 1), rdx) + lea(mem(r9, r9, 2), rcx) + + mov(var(k_iter), rsi) + test(rsi, rsi) + je(.DCONSIDKLEFT) + + label(.DLOOPKITER) // MAIN LOOP + //0 + prefetch(0, mem(rdx, 5*8)) + + SUBITER6x4(rax, rbx, ymm4, ymm6, ymm8, ymm10, ymm12, ymm14) + lea(mem(rbx, 1*32), rbp) + SUBITER6x4(rax, rbp, ymm5, ymm7, ymm9, ymm11, ymm13, ymm15) + add(r10, rbx) // b += rs_b; + add(r9, rax) // a += cs_a + //1 + prefetch(0, mem(rdx, r9, 1, 5*8)) + SUBITER6x4(rax, rbx, ymm4, ymm6, ymm8, ymm10, ymm12, ymm14) + lea(mem(rbx, 1*32), rbp) + SUBITER6x4(rax, rbp, ymm5, ymm7, ymm9, ymm11, ymm13, ymm15) + add(r10, rbx) // b += rs_b; + add(r9, rax) // a += cs_a; + + //2 + + prefetch(0, mem(rdx, r9, 2, 5*8)) + SUBITER6x4(rax, rbx, ymm4, ymm6, ymm8, ymm10, ymm12, ymm14) + lea(mem(rbx, 1*32), rbp) + SUBITER6x4(rax, rbp, ymm5, ymm7, ymm9, ymm11, ymm13, ymm15) + add(r10, rbx) // b += rs_b; + add(r9, rax) // a += cs_a; + //3 + prefetch(0, mem(rdx, rcx, 1, 5*8)) + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; + + SUBITER6x4(rax, rbx, ymm4, ymm6, ymm8, ymm10, ymm12, ymm14) + lea(mem(rbx, 1*32), rbp) + SUBITER6x4(rax, rbp, ymm5, ymm7, ymm9, ymm11, ymm13, ymm15) + add(r10, rbx) // b += rs_b; + add(r9, rax) // a += cs_a; + dec(rsi) + jne(.DLOOPKITER) + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) + test(rsi, rsi) + je(.DPOSTACCUM) + + label(.DLOOPKLEFT) // EDGE LOOP + + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) + SUBITER6x4(rax, rbx, ymm4, ymm6, ymm8, ymm10, ymm12, ymm14) + lea(mem(rbx, 1*32), rbp) + SUBITER6x4(rax, rbp, ymm5, ymm7, ymm9, ymm11, ymm13, ymm15) + add(r10, rbx) // b += rs_b; + add(r9, rax) // a += cs_a; + + dec(rsi) + jne(.DLOOPKLEFT) + + label(.DPOSTACCUM) + + + + mov(r12, rcx) + mov(var(alpha), rax) + mov(var(beta), rbx) + vbroadcastsd(mem(rax), ymm0) + vbroadcastsd(mem(rbx), ymm3) + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm7, ymm7) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, 
ymm9, ymm9) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(ymm0, ymm11, ymm11) + vmulpd(ymm0, ymm12, ymm12) + vmulpd(ymm0, ymm13, ymm13) + vmulpd(ymm0, ymm14, ymm14) + vmulpd(ymm0, ymm15, ymm15) + + mov(var(cs_c), rsi) + lea(mem(, rsi, 8), rsi) + lea(mem(rcx, rdi, 4), rdx) // c + 4*rs_c; + lea(mem(rsi, rsi, 2), rax) // 3*cs_c; + + + vxorpd(ymm0, ymm0, ymm0) + vucomisd(xmm0, xmm3) + je(.DBETAZERO) + + cmp(imm(8), rdi) + jz(.DCOLSTORED) + + label(.DROWSTORED) + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(ymm4, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm5) + vmovupd(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) + vmovupd(ymm6, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm7) + vmovupd(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) + vmovupd(ymm8, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm9) + vmovupd(ymm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10) + vmovupd(ymm10, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm11) + vmovupd(ymm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm12) + vmovupd(ymm12, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm13) + vmovupd(ymm13, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm14) + vmovhpd(xmm14, mem(rcx, 0*32+1*8)) + vextractf128(imm(0x1), ymm14, xmm14) + vmovupd(xmm14, mem(rcx, 0*32+2*8)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm15) + vmovupd(ymm15, mem(rcx, 1*32)) + + jmp(.DDONE) + + + label(.DCOLSTORED) + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx ), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovlpd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm5) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm9) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm11) + vmovupd(ymm5, mem(rcx )) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovupd(ymm11, mem(rcx, rax, 1)) + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx ), xmm3, xmm0) 
+ vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + jmp(.DDONE) // jump to end. + + label(.DBETAZERO) + + + cmp(imm(8), rdi) + jz(.DCOLSTORBZ) + + label(.DROWSTORBZ) + + vmovupd(ymm4, mem(rcx, 0*32)) + vmovupd(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm6, mem(rcx, 0*32)) + vmovupd(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm8, mem(rcx, 0*32)) + vmovupd(ymm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm10, mem(rcx, 0*32)) + vmovupd(ymm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm12, mem(rcx, 0*32)) + vmovupd(ymm13, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovhpd(xmm14, mem(rcx, 0*32+1*8)) + vextractf128(imm(0x1), ymm14, xmm14) + vmovupd(xmm14, mem(rcx, 0*32+2*8)) + vmovupd(ymm15, mem(rcx, 1*32)) + + jmp(.DDONE) + + label(.DCOLSTORBZ) + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovlpd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vmovupd(ymm5, mem(rcx )) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovupd(ymm11, mem(rcx, rax, 1)) + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + + label(.DDONE) + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a8] "m" (ps_a8), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +/* + +Following kernel computes the 6x8 block for the Upper vairant(U) of gemmt where +m_offset in 24x24 block is 6 and n_offset is 0(6x0) +(6x0)_U + + +the region marked with 'x' is computed by following kernel +the region marked with '-' is not computed + + <-- n_off_24 -- > + 0 1 2 3 4 5 6 7 + +↑ 6 - - - - - - x x +| 7 - - - - - - - x +m 8 - - - - - 
- - - +off 9 - - - - - - - - +24 10 - - - - - - - - +| 11 - - - - - - - - +↓ + + +*/ +void bli_dgemmsup_rv_haswell_asm_6x8m_6x0_U + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + uint64_t ps_a8 = bli_auxinfo_ps_a( data ) * sizeof( double ); + + begin_asm() + mov(var(a), r14) + mov(var(b), rbx) + mov(var(c), r12) + mov(r14, rax) + + mov(var(rs_a), r8) + mov(var(cs_a), r9) + lea(mem(, r8, 8), r8) + lea(mem(, r9, 8), r9) + + mov(var(rs_b), r10) + lea(mem(, r10, 8), r10) + + mov(var(rs_c), rdi) + lea(mem(, rdi, 8), rdi) + + lea(mem(r8, r8, 4), r15) + + vxorpd(ymm5, ymm5, ymm5) + vxorpd(ymm7, ymm7, ymm7) + + cmp(imm(8), rdi) + jz(.DCOLPFETCH) + + label(.DROWPFETCH) + lea(mem(r12, rdi, 2), rdx) + lea(mem(rdx, rdi, 1), rdx) + prefetch(0, mem(rdx, rdi, 1, 1*8)) + prefetch(0, mem(rdx, rdi, 2, 2*8)) + jmp(.DPOSTPFETCH) + + label(.DCOLPFETCH) + mov(var(cs_c), rsi) + lea(mem(, rsi, 8), rsi) + prefetch(0, mem(r12, 5*8)) + prefetch(0, mem(r12, rsi, 1, 5*8)) + + label(.DPOSTPFETCH) + mov(var(k_iter), rsi) + test(rsi, rsi) + lea(mem(rbx, 1*16), rbx) + je(.DCONSILEFT) + + //compute xmm5 and xmm7 only + label(.DMAIN) + //0 + lea(mem(rbx, 1*32), rbp) + SUBITER2x2(rax, rbp, xmm5, xmm7) + add(r9, rax) + add(r10, rbx) + //1 + lea(mem(rbx, 1*32), rbp) + SUBITER2x2(rax, rbp, xmm5, xmm7) + add(r9, rax) + add(r10, rbx) + //2 + lea(mem(rbx, 1*32), rbp) + SUBITER2x2(rax, rbp, xmm5, xmm7) + add(r9, rax) + add(r10, rbx) + //3 + lea(mem(rbx, 1*32), rbp) + SUBITER2x2(rax, rbp, xmm5, xmm7) + add(r9, rax) + add(r10, rbx) + + dec(rsi) + jne(.DMAIN) + + label(.DCONSILEFT) + mov(var(k_left), rsi) + test(rsi, rsi) + je(.DPOSTACC) + + label(.DLEFT) + lea(mem(rbx, 1*32), rbp) + SUBITER2x2(rax, rbp, xmm5, xmm7) + add(r9, rax) + add(r10, rbx) + dec(rsi) + jne(.DLEFT) + + label(.DPOSTACC) + mov(r12, rcx) + mov(var(alpha), rax) + mov(var(beta), rbx) + vbroadcastsd(mem(rax), ymm0) + vbroadcastsd(mem(rbx), ymm3) + lea(mem(rsi, rsi, 2), rax) + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm7, ymm7) + + mov(var(cs_c), rsi) + lea(mem(, rsi, 8), rsi) + vxorpd(ymm0, ymm0, ymm0) + + cmp(imm(8), rdi) + je(.DCOLSTOR) + + label(.DROWSTOR) + lea(mem(rcx, 1*32), rcx) + lea(mem(rcx, 1*16), rcx) + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm5) + vmovlpd(xmm5, mem(rcx)) + vmovhpd(xmm5, mem(rcx, rsi, 1)) + add(rdi, rcx) + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm7) + vmovhpd(xmm7, mem(rcx, rsi, 1)) + + jmp(.DRETURN) + + label(.DCOLSTOR) + + vbroadcastsd(mem(rbx), ymm3) + lea(mem(rcx, rsi, 4), rcx) + lea(mem(rcx, rsi, 2), rcx) + vunpcklpd(xmm7, xmm5, xmm0) + vunpckhpd(xmm7, xmm5, xmm1) + vfmadd231pd(mem(rcx ), xmm3, xmm0) + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm1) + vmovlpd(xmm0, mem(rcx )) + vmovupd(xmm1, mem(rcx, rsi, 1)) + + label(.DRETURN) + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a8] "m" (ps_a8), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" 
(cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +/* + +Following kernel computes the 6x8 block for the Upper vairant(U) of gemmt where +m_offset in 24x24 block is 12 and n_offset is 8(12x8) +(12x8)_U + +the region marked with 'x' is computed by following kernel +the region marked with '-' is not computed + + <-- n_off_24 -- > + 8 9 10 11 12 13 14 15 + +↑ 12 - - - - x x x x +| 13 - - - - - x x x +m 14 - - - - - - x x +off 15 - - - - - - - x +24 16 - - - - - - - - +| 17 - - - - - - - - +↓ + + +*/ +void bli_dgemmsup_rv_haswell_asm_6x8m_12x8_U + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + uint64_t ps_a8 = bli_auxinfo_ps_a( data ) * sizeof( double ); + + begin_asm() + + + mov(var(a), r14) + mov(var(rs_a), r8) + mov(var(cs_a), r9) + lea(mem(, r8, 8), r8) + lea(mem(, r9, 8), r9) + + lea(mem(r8, r8, 2), r13) + lea(mem(r8, r8, 4), r15) + + mov(var(rs_b), r10) + lea(mem(, r10, 8), r10) + + mov(var(c), r12) + mov(var(rs_c), rdi) + lea(mem(, rdi, 8), rdi) + + vxorpd(ymm5, ymm5, ymm5) + vmovapd( ymm5, ymm7) + vmovapd( ymm5, ymm9) + vmovapd( ymm5, ymm11) + + mov(var(b), rbx) + mov(r14, rax) + + cmp(imm(8), rdi) + jz(.DCOLPFETCH) + label(.DROWPFETCH) + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(r12, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(r12, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c + + jmp(.DPOSTPFETCH) + label(.DCOLPFETCH) + + mov(var(cs_c), rsi) + lea(mem(, rsi, 8), rsi) + lea(mem(r12, rsi, 2), rdx) + lea(mem(rdx, rsi, 1), rdx) + prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*8)) // prefetch c + 5*cs_c + lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; + prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*8)) // prefetch c + 7*cs_c + + label(.DPOSTPFETCH) + + mov(var(ps_a8), rdx) + lea(mem(rax, rdx, 1), rdx) + lea(mem(r9, r9, 2), rcx) + + mov(var(k_iter), rsi) + test(rsi, rsi) + je(.DCONSIDKLEFT) + + //compute ymm5, 7, 9, 11 only + label(.DLOOPKITER) // MAIN LOOP + //0 + prefetch(0, mem(rdx, 5*8)) + + lea(mem(rbx, 1*32), rbp) + SUBITER4x4(rax, rbp, ymm5, ymm7, ymm9, ymm11) + add(r9, rax) // a += cs_a; + add(r10, rbx) + //1 + prefetch(0, mem(rdx, r9, 1, 5*8)) + + lea(mem(rbx, 1*32), rbp) + SUBITER4x4(rax, rbp, ymm5, ymm7, ymm9, ymm11) + add(r9, rax) // a += cs_a; + add(r10, rbx) + //2 + + prefetch(0, mem(rdx, r9, 2, 5*8)) + + lea(mem(rbx, 1*32), rbp) + SUBITER4x4(rax, rbp, ymm5, ymm7, ymm9, ymm11) + add(r9, rax) // a += cs_a; + add(r10, rbx) + //3 + prefetch(0, mem(rdx, rcx, 1, 5*8)) + lea(mem(rdx, r9, 4), rdx) // 
a_prefetch += 4*cs_a; + + lea(mem(rbx, 1*32), rbp) + SUBITER4x4(rax, rbp, ymm5, ymm7, ymm9, ymm11) + add(r9, rax) // a += cs_a; + add(r10, rbx) + dec(rsi) + jne(.DLOOPKITER) + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) + test(rsi, rsi) + je(.DPOSTACCUM) + + label(.DLOOPKLEFT) // EDGE LOOP + + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) + + lea(mem(rbx, 1*32), rbp) + SUBITER4x4(rax, rbp, ymm5, ymm7, ymm9, ymm11) + add(r9, rax) // a += cs_a; + add(r10, rbx) + dec(rsi) + jne(.DLOOPKLEFT) + + label(.DPOSTACCUM) + + + + mov(r12, rcx) + mov(var(alpha), rax) + mov(var(beta), rbx) + vbroadcastsd(mem(rax), ymm0) + vbroadcastsd(mem(rbx), ymm3) + + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm7, ymm7) + vmulpd(ymm0, ymm9, ymm9) + vmulpd(ymm0, ymm11, ymm11) + + mov(var(cs_c), rsi) + lea(mem(, rsi, 8), rsi) + lea(mem(rcx, rdi, 4), rdx) // c + 4*rs_c; + lea(mem(rsi, rsi, 2), rax) // 3*cs_c; + + + vxorpd(ymm0, ymm0, ymm0) + vucomisd(xmm0, xmm3) + je(.DBETAZERO) + + cmp(imm(8), rdi) + jz(.DCOLSTORED) + + label(.DROWSTORED) + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm5) + vmovupd(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm7) + vmovhpd(xmm7, mem(rcx, 1*32+1*8)) + vextractf128(imm(0x1), ymm7, xmm7) + vmovupd(xmm7, mem(rcx, 1*32+2*8)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm9) + vextractf128(imm(0x1), ymm9, xmm9) + vmovupd(xmm9, mem(rcx, 1*32+2*8)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm11) + vextractf128(imm(0x1), ymm11, xmm11) + vmovhpd(xmm11, mem(rcx, 1*32+3*8)) + + + jmp(.DDONE) + + + label(.DCOLSTORED) + + lea(mem(rdx, rsi, 4), rdx) + lea(mem(rcx, rsi, 4), rcx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm5) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm9) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm11) + vmovlpd(xmm5, mem(rcx )) + vmovupd(xmm7, mem(rcx, rsi, 1)) + vmovupd(xmm9, mem(rcx, rsi, 2)) + vextractf128(imm(0x1), ymm9, xmm9) + vmovlpd(xmm9, mem(rcx, rsi, 2, 1*16)) + vmovupd(ymm11, mem(rcx, rax, 1)) + + jmp(.DDONE) // jump to end. 
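+ // Beta was tested against zero above (vxorpd + vucomisd on xmm3); when it is
+ // zero, the .DBETAZERO path below is taken, which stores the alpha-scaled
+ // results without reading C, so no vfmadd blend with the existing C values.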
+ + label(.DBETAZERO) + + + cmp(imm(8), rdi) + jz(.DCOLSTORBZ) + + label(.DROWSTORBZ) + + vmovupd(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + vmovhpd(xmm7, mem(rcx, 1*32+1*8)) + vextractf128(imm(0x1), ymm7, xmm7) + vmovupd(xmm7, mem(rcx, 1*32+2*8)) + add(rdi, rcx) + + vextractf128(imm(0x1), ymm9, xmm9) + vmovupd(xmm9, mem(rcx, 1*32+2*8)) + add(rdi, rcx) + + vextractf128(imm(0x1), ymm11, xmm11) + vmovhpd(xmm11, mem(rcx, 1*32+3*8)) + + jmp(.DDONE) + + label(.DCOLSTORBZ) + + lea(mem(rdx, rsi, 4), rdx) + lea(mem(rcx, rsi, 4), rcx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vmovlpd(xmm5, mem(rcx )) + vmovupd(xmm7, mem(rcx, rsi, 1)) + vmovupd(xmm9, mem(rcx, rsi, 2)) + vextractf128(imm(0x1), ymm9, xmm9) + vmovupd(ymm9, mem(rcx, rsi, 2, 1*16)) + + + label(.DDONE) + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a8] "m" (ps_a8), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +/* + +Following kernel computes the 6x8 block for the Upper vairant(U) of gemmt where +m_offset in 24x24 block is 18 and n_offset is 16(18x16) +(18x16)_U + + +the region marked with 'x' is computed by following kernel +the region marked with '-' is not computed + + <-- n_off_24 -- > + 16 17 18 19 20 21 22 23 + +↑ 18 - - x x x x x x +| 19 - - - x x x x x +m 20 - - - - x x x x +off 21 - - - - - x x x +24 22 - - - - - - x x +| 23 - - - - - - - x +↓ + + +*/ +void bli_dgemmsup_rv_haswell_asm_6x8m_18x16_U + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + uint64_t ps_a8 = bli_auxinfo_ps_a( data ) * sizeof( double ); + + begin_asm() + + + mov(var(a), r14) + mov(var(rs_a), r8) + mov(var(cs_a), r9) + lea(mem(, r8, 8), r8) + lea(mem(, r9, 8), r9) + + lea(mem(r8, r8, 2), r13) + lea(mem(r8, r8, 4), r15) + + mov(var(rs_b), r10) + lea(mem(, r10, 8), r10) + + mov(var(c), r12) + mov(var(rs_c), rdi) + lea(mem(, rdi, 8), rdi) + + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm15) + + mov(var(b), rbx) + mov(r14, rax) + + + + PREFETCH_C() + + label(.DPOSTPFETCH) + + mov(var(ps_a8), rdx) + lea(mem(rax, rdx, 1), rdx) + lea(mem(r9, 
r9, 2), rcx) + + mov(var(k_iter), rsi) + test(rsi, rsi) + je(.DCONSIDKLEFT) + + //skip ymm8, 10, 12, 14 + label(.DLOOPKITER) // MAIN LOOP + //0 + prefetch(0, mem(rdx, 5*8)) + SUBITER2x4(rax, rbx, ymm4, ymm6) + lea(mem(rbx, 1*32), rbp) + SUBITER6x4(rax, rbp, ymm5, ymm7, ymm9, ymm11, ymm13, ymm15) + add(r10, rbx) // b += rs_b; + add(r9, rax) // a += cs_a; + //1 + prefetch(0, mem(rdx, r9, 1, 5*8)) + SUBITER2x4(rax, rbx, ymm4, ymm6) + lea(mem(rbx, 1*32), rbp) + SUBITER6x4(rax, rbp, ymm5, ymm7, ymm9, ymm11, ymm13, ymm15) + add(r10, rbx) // b += rs_b; + add(r9, rax) // a += cs_a; + //2 + prefetch(0, mem(rdx, r9, 2, 5*8)) + SUBITER2x4(rax, rbx, ymm4, ymm6) + lea(mem(rbx, 1*32), rbp) + SUBITER6x4(rax, rbp, ymm5, ymm7, ymm9, ymm11, ymm13, ymm15) + add(r10, rbx) // b += rs_b; + add(r9, rax) // a += cs_a; + //3 + prefetch(0, mem(rdx, rcx, 1, 5*8)) + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; + SUBITER2x4(rax, rbx, ymm4, ymm6) + lea(mem(rbx, 1*32), rbp) + SUBITER6x4(rax, rbp, ymm5, ymm7, ymm9, ymm11, ymm13, ymm15) + add(r10, rbx) // b += rs_b; + add(r9, rax) // a += cs_a; + dec(rsi) + jne(.DLOOPKITER) + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) + test(rsi, rsi) + je(.DPOSTACCUM) + + label(.DLOOPKLEFT) // EDGE LOOP + + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) + SUBITER2x4(rax, rbx, ymm4, ymm6) + lea(mem(rbx, 1*32), rbp) + SUBITER6x4(rax, rbp, ymm5, ymm7, ymm9, ymm11, ymm13, ymm15) + add(r10, rbx) // b += rs_b; + add(r9, rax) // a += cs_a; + + dec(rsi) + jne(.DLOOPKLEFT) + + label(.DPOSTACCUM) + + + + mov(r12, rcx) + mov(var(alpha), rax) + mov(var(beta), rbx) + vbroadcastsd(mem(rax), ymm0) + vbroadcastsd(mem(rbx), ymm3) + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm7, ymm7) + vmulpd(ymm0, ymm9, ymm9) + vmulpd(ymm0, ymm11, ymm11) + vmulpd(ymm0, ymm13, ymm13) + vmulpd(ymm0, ymm15, ymm15) + + mov(var(cs_c), rsi) + lea(mem(, rsi, 8), rsi) + lea(mem(rcx, rdi, 4), rdx) // c + 4*rs_c; + lea(mem(rsi, rsi, 2), rax) // 3*cs_c; + + + vxorpd(ymm0, ymm0, ymm0) + vucomisd(xmm0, xmm3) + je(.DBETAZERO) + + cmp(imm(8), rdi) + jz(.DCOLSTORED) + + label(.DROWSTORED) + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vextractf128(imm(0x1), ymm4, xmm4) + vmovupd(xmm4, mem(rcx, 0*32+2*8)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm5) + vmovupd(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) + vextractf128(imm(0x1), ymm6, xmm6) + vmovhpd(xmm6, mem(rcx, 0*32+3*8)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm7) + vmovupd(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm9) + vmovupd(ymm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm11) + vmovhpd(xmm11, mem(rcx, 1*32+1*8)) + vextractf128(imm(0x1), ymm11, xmm11) + vmovupd(xmm11, mem(rcx, 1*32+2*8)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm13) + vextractf128(imm(0x1), ymm13, xmm13) + vmovupd(xmm13, mem(rcx, 1*32+2*8)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm15) + vextractf128(imm(0x1), ymm15, xmm15) + vmovhpd(xmm15, mem(rcx, 1*32+3*8)) + //add(rdi, rcx) + + + jmp(.DDONE) + + + label(.DCOLSTORED) + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovlpd(xmm8, mem(rcx, rsi, 2)) + 
vmovupd(xmm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm5) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm9) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm11) + vmovupd(xmm5, mem(rcx )) + vextractf128(imm(0x1), ymm5, xmm5) + vmovlpd(xmm5, mem(rcx, 2*8 )) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovupd(ymm11, mem(rcx, rax, 1)) + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx ), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovlpd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + jmp(.DDONE) // jump to end. + + label(.DBETAZERO) + + + cmp(imm(8), rdi) + jz(.DCOLSTORBZ) + + label(.DROWSTORBZ) + + vextractf128(imm(0x1), ymm4, xmm4) + vmovupd(xmm4, mem(rcx, 0*32+2*8)) + vmovupd(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + vextractf128(imm(0x1), ymm6, xmm6) + vmovhpd(xmm6, mem(rcx, 0*32+3*8)) + vmovupd(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + vmovupd(ymm9, mem(rcx, 1*32)) + add(rdi, rcx) + + vmovhpd(xmm11, mem(rcx, 1*32+1*8)) + vextractf128(imm(0x1), ymm11, xmm11) + vmovupd(xmm11, mem(rcx, 1*32+2*8)) + add(rdi, rcx) + + vextractf128(imm(0x1), ymm13, xmm13) + vmovupd(xmm13, mem(rcx, 1*32+2*8)) + add(rdi, rcx) + + vextractf128(imm(0x1), ymm15, xmm15) + vmovhpd(xmm15, mem(rcx, 1*32+3*8)) + + jmp(.DDONE) + + label(.DCOLSTORBZ) + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vmovlpd(xmm8, mem(rcx, rsi, 2)) + vmovupd(xmm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vmovupd(xmm5, mem(rcx )) + vextractf128(imm(0x1), ymm5, xmm5) + vmovlpd(xmm5, mem(rcx )) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovupd(ymm11, mem(rcx, rax, 1)) + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovlpd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + + label(.DDONE) + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a8] "m" (ps_a8), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", 
"r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +/* + +Following kernel computes the 6x8 block for the Upper vairant(U) of gemmt where +m_offset in 24x24 block is 0, n_offset is 0(0x0) and m_offset is 6, n_offset is 0 (6x0) +(0x0)+(6x0)_L + +the region marked with 'x' is computed by following kernel +the region marked with '-' is not computed + + <-- n_off_24 -- > + 0 1 2 3 4 5 6 7 + +↑ 0 x x x x x x x x +| 1 - x x x x x x x +m 2 - - x x x x x x +off 3 - - - x x x x x +24 4 - - - - x x x x +| 5 - - - - - x x x +↓ +↑ 6 - - - - - - x x +| 7 - - - - - - - x +m 8 - - - - - - - - +off 9 - - - - - - - - +24 10 - - - - - - - - +| 11 - - - - - - - - +↓ + + +*/ +void bli_dgemmsup_rv_haswell_asm_6x8m_0x0_combined_U + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + uint64_t ps_a8 = bli_auxinfo_ps_a( data ) * sizeof( double ); + + begin_asm() + + + mov(var(a), r14) + mov(var(rs_a), r8) + mov(var(cs_a), r9) + lea(mem(, r8, 8), r8) + lea(mem(, r9, 8), r9) + + lea(mem(r8, r8, 2), r13) + lea(mem(r8, r8, 4), r15) + + mov(var(rs_b), r10) + lea(mem(, r10, 8), r10) + + mov(var(c), r12) + mov(var(rs_c), rdi) + lea(mem(, rdi, 8), rdi) + + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + mov(var(b), rbx) + mov(r14, rax) + + + + cmp(imm(8), rdi) + jz(.DCOLPFETCH) + label(.DROWPFETCH) + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(r12, 7*8)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1, 7*8)) // prefetch c + 1*rs_c + prefetch(0, mem(r12, rdi, 2, 7*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 7*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 7*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 7*8)) // prefetch c + 5*rs_c + + jmp(.DPOSTPFETCH) + label(.DCOLPFETCH) + + mov(var(cs_c), rsi) + lea(mem(, rsi, 8), rsi) + lea(mem(r12, rsi, 2), rdx) + lea(mem(rdx, rsi, 1), rdx) + prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 5*8)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 5*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*8)) // prefetch c + 5*cs_c + lea(mem(rdx, rsi, 2), rdx) // rdx = c + 5*cs_c; + prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 6*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*8)) // prefetch c + 7*cs_c + + label(.DPOSTPFETCH) + + mov(var(ps_a8), rdx) + lea(mem(rax, rdx, 1), rdx) + lea(mem(r9, r9, 2), rcx) + + mov(var(k_iter), rsi) + test(rsi, rsi) + je(.DCONSIDKLEFT) + + //ymm12 and ymm14 are used for 0x6 block + label(.DLOOPKITER) // MAIN LOOP 
+ //0 + prefetch(0, mem(rdx, 5*8)) + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm1, ymm3, ymm15) + + vmovupd(mem(rbx, 1*64), ymm0) + add(r10, rbx) // b += rs_b; + lea(mem(rax, r13, 2), rbp) + vbroadcastsd(mem(rbp ), ymm2) + vbroadcastsd(mem(rbp, r8, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm12) + vfmadd231pd(ymm1, ymm3, ymm14) + + add(r9, rax) // a += cs_a; + + + + //1 + prefetch(0, mem(rdx, r9, 1, 5*8)) + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm1, ymm3, ymm15) + + vmovupd(mem(rbx, 1*64), ymm0) + add(r10, rbx) // b += rs_b; + lea(mem(rax, r13, 2), rbp) + vbroadcastsd(mem(rbp ), ymm2) + vbroadcastsd(mem(rbp, r8, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm12) + vfmadd231pd(ymm1, ymm3, ymm14) + + add(r9, rax) // a += cs_a; + + + //2 + + prefetch(0, mem(rdx, r9, 2, 5*8)) + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm1, ymm3, ymm15) + + vmovupd(mem(rbx, 1*64), ymm0) + add(r10, rbx) // b += rs_b; + lea(mem(rax, r13, 2), rbp) + vbroadcastsd(mem(rbp ), ymm2) + vbroadcastsd(mem(rbp, r8, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm12) + vfmadd231pd(ymm1, ymm3, ymm14) + add(r9, rax) // a += cs_a; + + + //3 + prefetch(0, mem(rdx, rcx, 1, 5*8)) + lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm1, ymm3, ymm15) + + vmovupd(mem(rbx, 1*64), ymm0) + add(r10, rbx) // b += rs_b; + 
lea(mem(rax, r13, 2), rbp) + vbroadcastsd(mem(rbp ), ymm2) + vbroadcastsd(mem(rbp, r8, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm12) + vfmadd231pd(ymm1, ymm3, ymm14) + add(r9, rax) // a += cs_a; + + + dec(rsi) + jne(.DLOOPKITER) + + label(.DCONSIDKLEFT) + + mov(var(k_left), rsi) + test(rsi, rsi) + je(.DPOSTACCUM) + + label(.DLOOPKLEFT) // EDGE LOOP + + prefetch(0, mem(rdx, 5*8)) + add(r9, rdx) + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm1, ymm3, ymm15) + + vmovupd(mem(rbx, 1*64), ymm0) + add(r10, rbx) // b += rs_b; + lea(mem(rax, r13, 2), rbp) + vbroadcastsd(mem(rbp ), ymm2) + vbroadcastsd(mem(rbp, r8, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm12) + vfmadd231pd(ymm1, ymm3, ymm14) + add(r9, rax) // a += cs_a; + + + dec(rsi) + jne(.DLOOPKLEFT) + + label(.DPOSTACCUM) + + + + mov(r12, rcx) + mov(var(alpha), rax) + mov(var(beta), rbx) + vbroadcastsd(mem(rax), ymm0) + vbroadcastsd(mem(rbx), ymm3) + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm7, ymm7) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm9, ymm9) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(ymm0, ymm11, ymm11) + vmulpd(ymm0, ymm12, ymm12) + vmulpd(ymm0, ymm13, ymm13) + vmulpd(ymm0, ymm14, ymm14) + vmulpd(ymm0, ymm15, ymm15) + + mov(var(cs_c), rsi) + lea(mem(, rsi, 8), rsi) + lea(mem(rcx, rdi, 4), rdx) // c + 4*rs_c; + lea(mem(rsi, rsi, 2), rax) // 3*cs_c; + + + vxorpd(ymm0, ymm0, ymm0) + vucomisd(xmm0, xmm3) + je(.DBETAZERO) + + cmp(imm(8), rdi) + jz(.DCOLSTORED) + + label(.DROWSTORED) + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) + vmovupd(ymm4, mem(rcx, 0*32)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm5) + vmovupd(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) + vmovhpd(xmm6, mem(rcx, 0*32+1*8)) + vextractf128(imm(0x1), ymm6, xmm6) + vmovupd(xmm6, mem(rcx, 0*32+2*8)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm7) + vmovupd(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) + vextractf128(imm(0x1), ymm8, xmm8) + vmovupd(xmm8, mem(rcx, 0*32+2*8)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm9) + vmovupd(ymm9, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10) + vextractf128(imm(0x1), ymm10, xmm10) + vmovhpd(xmm10, mem(rcx, 0*32+3*8)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm11) + vmovupd(ymm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, rdi, 2, 1*32), ymm3, ymm12) + vextractf128(imm(0x1), ymm12, xmm12) + vmovupd(xmm12, mem(rcx, rdi, 2, 1*32+2*8)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm13) + vmovupd(ymm13, mem(rcx, 1*32)) + add(rdi, rcx) + + + vfmadd231pd(mem(rcx, rdi, 2, 1*32), ymm3, ymm14) + vextractf128(imm(0x1), ymm14, xmm14) + vmovhpd(xmm14, mem(rcx, rdi, 2, 1*32+3*8)) + + vfmadd231pd(mem(rcx, 1*32), ymm3, ymm15) + vmovhpd(xmm15, mem(rcx, 1*32+1*8)) + vextractf128(imm(0x1), ymm15, xmm15) + vmovupd(xmm15, mem(rcx, 1*32+2*8)) + //add(rdi, rcx) + + + jmp(.DDONE) + + + label(.DCOLSTORED) + + vunpcklpd(ymm6, ymm4, ymm0) + 
vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm4) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm10) + vmovlpd(xmm4, mem(rcx )) + vmovupd(xmm6, mem(rcx, rsi, 1)) + vmovupd(xmm8, mem(rcx, rsi, 2)) + vextractf128(imm(0x1), ymm8, xmm8) + vmovlpd(xmm8, mem(rcx, rsi, 2, 1*16)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + lea(mem(rcx, 6*8), rbp) + lea(mem(rbp, rsi, 2), rbp) + vfmadd231pd(mem(rbp ), xmm3, xmm2) + vfmadd231pd(mem(rbp, rsi, 1), xmm3, xmm4) + vmovlpd(xmm2, mem(rbp)) + vmovupd(xmm4, mem(rbp, rsi, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vbroadcastsd(mem(rbx), ymm3) + + vfmadd231pd(mem(rcx ), ymm3, ymm5) + vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) + vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm9) + vfmadd231pd(mem(rcx, rax, 1), ymm3, ymm11) + vmovupd(ymm5, mem(rcx )) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovupd(ymm11, mem(rcx, rax, 1)) + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vfmadd231pd(mem(rdx ), xmm3, xmm0) + vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) + vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) + vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) + vmovlpd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + jmp(.DDONE) // jump to end. 
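+ // Store-only epilogue for beta == 0: the same row-/column-stored layouts as
+ // above are written with the alpha-scaled accumulators directly, with no
+ // loads of C.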
+ + label(.DBETAZERO) + + + cmp(imm(8), rdi) + jz(.DCOLSTORBZ) + + label(.DROWSTORBZ) + + vmovupd(ymm4, mem(rcx, 0*32)) + vmovupd(ymm5, mem(rcx, 1*32)) + add(rdi, rcx) + + vmovhpd(xmm6, mem(rcx, 0*32+1*8)) + vextractf128(imm(0x1), ymm6, xmm6) + vmovupd(xmm6, mem(rcx, 0*32+2*8)) + vmovupd(ymm7, mem(rcx, 1*32)) + add(rdi, rcx) + + vextractf128(imm(0x1), ymm8, xmm8) + vmovupd(xmm8, mem(rcx, 0*32+2*8)) + vmovupd(ymm9, mem(rcx, 1*32)) + add(rdi, rcx) + + vextractf128(imm(0x1), ymm10, xmm10) + vmovhpd(xmm10, mem(rcx, 0*32+3*8)) + vmovupd(ymm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vextractf128(imm(0x1), ymm12, xmm12) + vmovupd(xmm12, mem(rcx, rdi, 2, 1*32+2*8)) + vmovupd(ymm13, mem(rcx, 1*32)) + add(rdi, rcx) + + + vextractf128(imm(0x1), ymm14, xmm14) + vmovhpd(xmm14, mem(rcx, rdi, 2, 1*32+3*8)) + vmovhpd(xmm15, mem(rcx, 1*32+1*8)) + vextractf128(imm(0x1), ymm15, xmm15) + vmovupd(xmm15, mem(rcx, 1*32+2*8)) + + jmp(.DDONE) + + label(.DCOLSTORBZ) + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vmovlpd(xmm4, mem(rcx )) + vmovupd(xmm6, mem(rcx, rsi, 1)) + vmovupd(xmm8, mem(rcx, rsi, 2)) + vextractf128(imm(0x1), ymm8, xmm8) + vmovlpd(xmm8, mem(rcx, rsi, 2, 1*16)) + vmovupd(ymm10, mem(rcx, rax, 1)) + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + lea(mem(rcx, rsi, 4), rbp) + lea(mem(rbp, rsi, 2), rbp) + lea(mem(rbp, 1*32+1*16), rbp) + vmovlpd(xmm2, mem(rbp)) + vmovupd(xmm4, mem(rbp, rsi, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + vperm2f128(imm(0x31), ymm2, ymm0, ymm9) + vperm2f128(imm(0x31), ymm3, ymm1, ymm11) + + vmovupd(ymm5, mem(rcx )) + vmovupd(ymm7, mem(rcx, rsi, 1)) + vmovupd(ymm9, mem(rcx, rsi, 2)) + vmovupd(ymm11, mem(rcx, rax, 1)) + + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovlpd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + + label(.DDONE) + + end_asm( + : // output operands (none) + : // input operands + [k_iter] "m" (k_iter), + [k_left] "m" (k_left), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [ps_a8] "m" (ps_a8), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) +} + +void bli_dgemmsup_rv_haswell_asm_6x6m + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, 
inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + //void* a_next = bli_auxinfo_next_a( data ); + //void* b_next = bli_auxinfo_next_b( data ); + + // Typecast local copies of integers in case dim_t and inc_t are a + // different size than is expected by load instructions. + uint64_t k_iter = k0 / 4; + uint64_t k_left = k0 % 4; + + uint64_t m_iter = m0 / 6; + uint64_t m_left = m0 % 6; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + // Query the panel stride of A and convert it to units of bytes. + uint64_t ps_a = bli_auxinfo_ps_a( data ); + uint64_t ps_a8 = ps_a * sizeof( double ); + + if ( m_iter == 0 ) goto consider_edge_cases; + + // ------------------------------------------------------------------------- + + begin_asm() + + //vzeroall() // zero all xmm/ymm registers. + + mov(var(a), r14) // load address of a. + mov(var(rs_a), r8) // load rs_a + mov(var(cs_a), r9) // load cs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + lea(mem(, r9, 8), r9) // cs_a *= sizeof(double) + + lea(mem(r8, r8, 2), r13) // r13 = 3*rs_a + lea(mem(r8, r8, 4), r15) // r15 = 5*rs_a + + //mov(var(b), rbx) // load address of b. + mov(var(rs_b), r10) // load rs_b + //mov(var(cs_b), r11) // load cs_b + lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) + //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + // NOTE: We cannot pre-load elements of a or b + // because it could eventually, in the last + // unrolled iter or the cleanup loop, result + // in reading beyond the bounds allocated mem + // (the likely result: a segmentation fault). + + mov(var(c), r12) // load address of c + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + + // During preamble and loops: + // r12 = rcx = c + // r14 = rax = a + // read rbx from var(b) near beginning of loop + // r11 = m dim index ii + + mov(var(m_iter), r11) // ii = m_iter; + + label(.DLOOP6X8I) // LOOP OVER ii = [ m_iter ... 1 0 ] + + + +#if 0 + vzeroall() // zero all xmm/ymm registers. +#else + // skylake can execute 3 vxorpd ipc with + // a latency of 1 cycle, while vzeroall + // has a latency of 12 cycles. + vxorpd(ymm1, ymm1, ymm1) // zero ymm1 since we only use the lower + vxorpd(ymm4, ymm4, ymm4) // half (xmm1), and nans/infs may slow us + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) +#endif + + mov(var(b), rbx) // load address of b. + //mov(r12, rcx) // reset rcx to current utile of c. + mov(r14, rax) // reset rax to current upanel of a. + + + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
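+ // rdi now holds rs_c in bytes; rdi == 8 means rs_c == 1, i.e. C is
+ // column-stored, so the column-wise prefetch sequence below is used.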
+ jz(.DCOLPFETCH) // jump to column storage case + label(.DROWPFETCH) // row-stored prefetching on c + + lea(mem(r12, rdi, 2), rdx) // + lea(mem(rdx, rdi, 1), rdx) // rdx = c + 3*rs_c; + prefetch(0, mem(r12, 5*8)) // prefetch c + 0*rs_c + prefetch(0, mem(r12, rdi, 1, 5*8)) // prefetch c + 1*rs_c + prefetch(0, mem(r12, rdi, 2, 5*8)) // prefetch c + 2*rs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*rs_c + prefetch(0, mem(rdx, rdi, 1, 5*8)) // prefetch c + 4*rs_c + prefetch(0, mem(rdx, rdi, 2, 5*8)) // prefetch c + 5*rs_c + + jmp(.DPOSTPFETCH) // jump to end of prefetching c + label(.DCOLPFETCH) // column-stored prefetching c + + mov(var(cs_c), rsi) // load cs_c to rsi (temporarily) + lea(mem(, rsi, 8), rsi) // cs_c *= sizeof(double) + lea(mem(r12, rsi, 2), rdx) // + lea(mem(rdx, rsi, 1), rdx) // rdx = c + 3*cs_c; + prefetch(0, mem(r12, 5*8)) // prefetch c + 0*cs_c + prefetch(0, mem(r12, rsi, 1, 5*8)) // prefetch c + 1*cs_c + prefetch(0, mem(r12, rsi, 2, 5*8)) // prefetch c + 2*cs_c + prefetch(0, mem(rdx, 5*8)) // prefetch c + 3*cs_c + prefetch(0, mem(rdx, rsi, 1, 5*8)) // prefetch c + 4*cs_c + prefetch(0, mem(rdx, rsi, 2, 5*8)) // prefetch c + 5*cs_c + + label(.DPOSTPFETCH) // done prefetching c + + +#if 1 + mov(var(ps_a8), rdx) // load ps_a8 + lea(mem(rax, rdx, 1), rdx) // rdx = a + ps_a8 + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; + // use rcx, rdx for prefetching lines + // from next upanel of a. +#else + lea(mem(rax, r8, 4), rdx) // use rdx for prefetching lines + lea(mem(rdx, r8, 2), rdx) // from next upanel of a. + lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; +#endif + + + + + mov(var(k_iter), rsi) // i = k_iter; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT) // if i == 0, jump to code that + // contains the k_left loop. + + + label(.DLOOPKITER) // MAIN LOOP + + + // ---------------------------------- iteration 0 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 1 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r9, 1, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + 
vfmadd231pd(ymm1, ymm3, ymm15) + + + // ---------------------------------- iteration 2 + +#if 0 + prefetch(0, mem(rdx, 5*8)) +#else + prefetch(0, mem(rdx, r9, 2, 5*8)) +#endif + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), xmm1) + add(r10, rbx) // b += rs_b; vbroadcastsd(mem(rax ), ymm2) vbroadcastsd(mem(rax, r8, 1), ymm3) @@ -1239,14 +7333,14 @@ void bli_dgemmsup_rv_haswell_asm_6x6m vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, r8, 4), ymm2) vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; @@ -1254,7 +7348,7 @@ void bli_dgemmsup_rv_haswell_asm_6x6m vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + // ---------------------------------- iteration 3 @@ -1275,14 +7369,14 @@ void bli_dgemmsup_rv_haswell_asm_6x6m vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, r8, 4), ymm2) vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; @@ -1290,27 +7384,27 @@ void bli_dgemmsup_rv_haswell_asm_6x6m vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - - - + + + dec(rsi) // i -= 1; jne(.DLOOPKITER) // iterate again if i != 0. - - - - - - + + + + + + label(.DCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.DPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - - + + label(.DLOOPKLEFT) // EDGE LOOP - + #if 1 prefetch(0, mem(rdx, 5*8)) add(r9, rdx) @@ -1319,21 +7413,21 @@ void bli_dgemmsup_rv_haswell_asm_6x6m vmovupd(mem(rbx, 0*32), ymm0) vmovupd(mem(rbx, 1*32), xmm1) add(r10, rbx) // b += rs_b; - + vbroadcastsd(mem(rax ), ymm2) vbroadcastsd(mem(rax, r8, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, r8, 4), ymm2) vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; @@ -1341,53 +7435,53 @@ void bli_dgemmsup_rv_haswell_asm_6x6m vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - - + + dec(rsi) // i -= 1; jne(.DLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.DPOSTACCUM) - - + + mov(r12, rcx) // reset rcx to current utile of c. 
mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate - + vmulpd(ymm0, ymm4, ymm4) // scale by alpha - vmulpd(xmm0, xmm5, xmm5) + vmulpd(ymm0, ymm5, ymm5) vmulpd(ymm0, ymm6, ymm6) - vmulpd(xmm0, xmm7, xmm7) + vmulpd(ymm0, ymm7, ymm7) vmulpd(ymm0, ymm8, ymm8) - vmulpd(xmm0, xmm9, xmm9) + vmulpd(ymm0, ymm9, ymm9) vmulpd(ymm0, ymm10, ymm10) - vmulpd(xmm0, xmm11, xmm11) + vmulpd(ymm0, ymm11, ymm11) vmulpd(ymm0, ymm12, ymm12) - vmulpd(xmm0, xmm13, xmm13) + vmulpd(ymm0, ymm13, ymm13) vmulpd(ymm0, ymm14, ymm14) - vmulpd(xmm0, xmm15, xmm15) - - - - - - + vmulpd(ymm0, ymm15, ymm15) + + + + + + mov(var(cs_c), rsi) // load cs_c lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) - + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; - - - + + + // now avoid loading C if beta == 0 - + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm3) // set ZF if beta == 0. je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case @@ -1396,60 +7490,60 @@ void bli_dgemmsup_rv_haswell_asm_6x6m cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. jz(.DCOLSTORED) // jump to column storage case - - + + label(.DROWSTORED) - - + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) vmovupd(ymm4, mem(rcx, 0*32)) vfmadd231pd(mem(rcx, 1*32), xmm3, xmm5) vmovupd(xmm5, mem(rcx, 1*32)) add(rdi, rcx) - - + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) vmovupd(ymm6, mem(rcx, 0*32)) vfmadd231pd(mem(rcx, 1*32), xmm3, xmm7) vmovupd(xmm7, mem(rcx, 1*32)) add(rdi, rcx) - - + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) vmovupd(ymm8, mem(rcx, 0*32)) vfmadd231pd(mem(rcx, 1*32), xmm3, xmm9) vmovupd(xmm9, mem(rcx, 1*32)) add(rdi, rcx) - - + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10) vmovupd(ymm10, mem(rcx, 0*32)) vfmadd231pd(mem(rcx, 1*32), xmm3, xmm11) vmovupd(xmm11, mem(rcx, 1*32)) add(rdi, rcx) - - + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm12) vmovupd(ymm12, mem(rcx, 0*32)) vfmadd231pd(mem(rcx, 1*32), xmm3, xmm13) vmovupd(xmm13, mem(rcx, 1*32)) add(rdi, rcx) - - + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm14) vmovupd(ymm14, mem(rcx, 0*32)) vfmadd231pd(mem(rcx, 1*32), xmm3, xmm15) vmovupd(xmm15, mem(rcx, 1*32)) //add(rdi, rcx) - - + + jmp(.DDONE) // jump to end. @@ -1524,51 +7618,51 @@ void bli_dgemmsup_rv_haswell_asm_6x6m jmp(.DDONE) // jump to end. - - - - + + + + label(.DBETAZERO) - + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. jz(.DCOLSTORBZ) // jump to column storage case - + label(.DROWSTORBZ) - - + + vmovupd(ymm4, mem(rcx, 0*32)) vmovupd(xmm5, mem(rcx, 1*32)) add(rdi, rcx) - + vmovupd(ymm6, mem(rcx, 0*32)) vmovupd(xmm7, mem(rcx, 1*32)) add(rdi, rcx) - - + + vmovupd(ymm8, mem(rcx, 0*32)) vmovupd(xmm9, mem(rcx, 1*32)) add(rdi, rcx) - - + + vmovupd(ymm10, mem(rcx, 0*32)) vmovupd(xmm11, mem(rcx, 1*32)) add(rdi, rcx) - - + + vmovupd(ymm12, mem(rcx, 0*32)) vmovupd(xmm13, mem(rcx, 1*32)) add(rdi, rcx) - - + + vmovupd(ymm14, mem(rcx, 0*32)) vmovupd(xmm15, mem(rcx, 1*32)) //add(rdi, rcx) - - + + jmp(.DDONE) // jump to end. 
@@ -1625,9 +7719,9 @@ void bli_dgemmsup_rv_haswell_asm_6x6m //lea(mem(rdx, rsi, 4), rdx) - - - + + + label(.DDONE) @@ -1648,8 +7742,8 @@ void bli_dgemmsup_rv_haswell_asm_6x6m label(.DRETURN) - - + + end_asm( : // output operands (none) @@ -1810,9 +7904,9 @@ void bli_dgemmsup_rv_haswell_asm_6x4m // ------------------------------------------------------------------------- begin_asm() - + //vzeroall() // zero all xmm/ymm registers. - + mov(var(a), r14) // load address of a. mov(var(rs_a), r8) // load rs_a mov(var(cs_a), r9) // load cs_a @@ -1827,7 +7921,7 @@ void bli_dgemmsup_rv_haswell_asm_6x4m //mov(var(cs_b), r11) // load cs_b lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) - + // NOTE: We cannot pre-load elements of a or b // because it could eventually, in the last // unrolled iter or the cleanup loop, result @@ -1858,11 +7952,11 @@ void bli_dgemmsup_rv_haswell_asm_6x4m // a latency of 1 cycle, while vzeroall // has a latency of 12 cycles. vxorpd(ymm4, ymm4, ymm4) - vmovapd(ymm4, ymm6) - vmovapd(ymm4, ymm8) - vmovapd(ymm4, ymm10) - vmovapd(ymm4, ymm12) - vmovapd(ymm4, ymm14) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm14) #endif mov(var(b), rbx) // load address of b. @@ -1912,17 +8006,17 @@ void bli_dgemmsup_rv_haswell_asm_6x4m #endif - - + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. - - + + label(.DLOOPKITER) // MAIN LOOP - - + + // ---------------------------------- iteration 0 #if 0 @@ -1930,7 +8024,7 @@ void bli_dgemmsup_rv_haswell_asm_6x4m #else prefetch(0, mem(rdx, 5*8)) #endif - + vmovupd(mem(rbx, 0*32), ymm0) add(r10, rbx) // b += rs_b; @@ -1938,19 +8032,19 @@ void bli_dgemmsup_rv_haswell_asm_6x4m vbroadcastsd(mem(rax, r8, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm0, ymm3, ymm6) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm0, ymm3, ymm10) - + vbroadcastsd(mem(rax, r8, 4), ymm2) vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm0, ymm3, ymm14) - + // ---------------------------------- iteration 1 #if 0 @@ -1966,18 +8060,18 @@ void bli_dgemmsup_rv_haswell_asm_6x4m vbroadcastsd(mem(rax, r8, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm0, ymm3, ymm6) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm0, ymm3, ymm10) - + vbroadcastsd(mem(rax, r8, 4), ymm2) vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm0, ymm3, ymm14) - + // ---------------------------------- iteration 2 @@ -1986,7 +8080,7 @@ void bli_dgemmsup_rv_haswell_asm_6x4m #else prefetch(0, mem(rdx, r9, 2, 5*8)) #endif - + vmovupd(mem(rbx, 0*32), ymm0) add(r10, rbx) // b += rs_b; @@ -1994,18 +8088,18 @@ void bli_dgemmsup_rv_haswell_asm_6x4m vbroadcastsd(mem(rax, r8, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm0, ymm3, ymm6) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm0, ymm3, ymm10) - + vbroadcastsd(mem(rax, r8, 4), ymm2) vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm0, ymm3, ymm14) - + // ---------------------------------- iteration 3 @@ -2023,38 
+8117,38 @@ void bli_dgemmsup_rv_haswell_asm_6x4m vbroadcastsd(mem(rax, r8, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm0, ymm3, ymm6) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm0, ymm3, ymm10) - + vbroadcastsd(mem(rax, r8, 4), ymm2) vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm0, ymm3, ymm14) - - - + + + dec(rsi) // i -= 1; jne(.DLOOPKITER) // iterate again if i != 0. - - - - - - + + + + + + label(.DCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.DPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - - + + label(.DLOOPKLEFT) // EDGE LOOP - + #if 1 prefetch(0, mem(rdx, 5*8)) add(r9, rdx) @@ -2067,58 +8161,58 @@ void bli_dgemmsup_rv_haswell_asm_6x4m vbroadcastsd(mem(rax, r8, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm0, ymm3, ymm6) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm0, ymm3, ymm10) - + vbroadcastsd(mem(rax, r8, 4), ymm2) vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm0, ymm3, ymm14) - - + + dec(rsi) // i -= 1; jne(.DLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.DPOSTACCUM) - + mov(r12, rcx) // reset rcx to current utile of c. mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate - + vmulpd(ymm0, ymm4, ymm4) // scale by alpha vmulpd(ymm0, ymm6, ymm6) vmulpd(ymm0, ymm8, ymm8) vmulpd(ymm0, ymm10, ymm10) vmulpd(ymm0, ymm12, ymm12) vmulpd(ymm0, ymm14, ymm14) - - - - - - + + + + + + mov(var(cs_c), rsi) // load cs_c lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) - + //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; - - - + + + // now avoid loading C if beta == 0 - + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm3) // set ZF if beta == 0. je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case @@ -2127,42 +8221,42 @@ void bli_dgemmsup_rv_haswell_asm_6x4m cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. jz(.DCOLSTORED) // jump to column storage case - - + + label(.DROWSTORED) - - + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm4) vmovupd(ymm4, mem(rcx, 0*32)) add(rdi, rcx) - - + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm6) vmovupd(ymm6, mem(rcx, 0*32)) add(rdi, rcx) - - + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm8) vmovupd(ymm8, mem(rcx, 0*32)) add(rdi, rcx) - - + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm10) vmovupd(ymm10, mem(rcx, 0*32)) add(rdi, rcx) - - + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm12) vmovupd(ymm12, mem(rcx, 0*32)) add(rdi, rcx) - - + + vfmadd231pd(mem(rcx, 0*32), ymm3, ymm14) vmovupd(ymm14, mem(rcx, 0*32)) //add(rdi, rcx) - - + + jmp(.DDONE) // jump to end. @@ -2210,45 +8304,45 @@ void bli_dgemmsup_rv_haswell_asm_6x4m jmp(.DDONE) // jump to end. - - - - + + + + label(.DBETAZERO) cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
jz(.DCOLSTORBZ) // jump to column storage case - - + + label(.DROWSTORBZ) - - + + vmovupd(ymm4, mem(rcx, 0*32)) add(rdi, rcx) - + vmovupd(ymm6, mem(rcx, 0*32)) add(rdi, rcx) - - + + vmovupd(ymm8, mem(rcx, 0*32)) add(rdi, rcx) - + vmovupd(ymm10, mem(rcx, 0*32)) add(rdi, rcx) - - + + vmovupd(ymm12, mem(rcx, 0*32)) add(rdi, rcx) - + vmovupd(ymm14, mem(rcx, 0*32)) //add(rdi, rcx) - + jmp(.DDONE) // jump to end. @@ -2283,15 +8377,15 @@ void bli_dgemmsup_rv_haswell_asm_6x4m vmovupd(xmm4, mem(rdx, rax, 1)) //lea(mem(rdx, rsi, 4), rdx) - - - - + + + + label(.DDONE) - + lea(mem(r12, rdi, 4), r12) // lea(mem(r12, rdi, 2), r12) // c_ii = r12 += 6*rs_c @@ -2307,8 +8401,8 @@ void bli_dgemmsup_rv_haswell_asm_6x4m label(.DRETURN) - - + + end_asm( : // output operands (none) @@ -2469,9 +8563,9 @@ void bli_dgemmsup_rv_haswell_asm_6x2m // ------------------------------------------------------------------------- begin_asm() - + //vzeroall() // zero all xmm/ymm registers. - + mov(var(a), r14) // load address of a. mov(var(rs_a), r8) // load rs_a mov(var(cs_a), r9) // load cs_a @@ -2486,7 +8580,7 @@ void bli_dgemmsup_rv_haswell_asm_6x2m //mov(var(cs_b), r11) // load cs_b lea(mem(, r10, 8), r10) // rs_b *= sizeof(double) //lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) - + // NOTE: We cannot pre-load elements of a or b // because it could eventually, in the last // unrolled iter or the cleanup loop, result @@ -2517,11 +8611,11 @@ void bli_dgemmsup_rv_haswell_asm_6x2m // a latency of 1 cycle, while vzeroall // has a latency of 12 cycles. vxorpd(xmm4, xmm4, xmm4) - vxorpd(xmm6, xmm6, xmm6) - vxorpd(xmm8, xmm8, xmm8) - vxorpd(xmm10, xmm10, xmm10) - vxorpd(xmm12, xmm12, xmm12) - vxorpd(xmm14, xmm14, xmm14) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm14) #endif mov(var(b), rbx) // load address of b. @@ -2565,19 +8659,19 @@ void bli_dgemmsup_rv_haswell_asm_6x2m lea(mem(rdx, r8, 2), rdx) // from next upanel of a. lea(mem(r9, r9, 2), rcx) // rcx = 3*cs_a; #endif - - - - + + + + mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.DCONSIDKLEFT) // if i == 0, jump to code that // contains the k_left loop. 
- - + + label(.DLOOPKITER) // MAIN LOOP - - + + // ---------------------------------- iteration 0 #if 0 @@ -2585,7 +8679,7 @@ void bli_dgemmsup_rv_haswell_asm_6x2m #else prefetch(0, mem(rdx, 5*8)) #endif - + vmovupd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; @@ -2593,19 +8687,19 @@ void bli_dgemmsup_rv_haswell_asm_6x2m vbroadcastsd(mem(rax, r8, 1), ymm3) vfmadd231pd(xmm0, xmm2, xmm4) vfmadd231pd(xmm0, xmm3, xmm6) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(xmm0, xmm2, xmm8) vfmadd231pd(xmm0, xmm3, xmm10) - + vbroadcastsd(mem(rax, r8, 4), ymm2) vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; vfmadd231pd(xmm0, xmm2, xmm12) vfmadd231pd(xmm0, xmm3, xmm14) - + // ---------------------------------- iteration 1 #if 0 @@ -2621,18 +8715,18 @@ void bli_dgemmsup_rv_haswell_asm_6x2m vbroadcastsd(mem(rax, r8, 1), ymm3) vfmadd231pd(xmm0, xmm2, xmm4) vfmadd231pd(xmm0, xmm3, xmm6) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(xmm0, xmm2, xmm8) vfmadd231pd(xmm0, xmm3, xmm10) - + vbroadcastsd(mem(rax, r8, 4), ymm2) vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; vfmadd231pd(xmm0, xmm2, xmm12) vfmadd231pd(xmm0, xmm3, xmm14) - + // ---------------------------------- iteration 2 @@ -2641,7 +8735,7 @@ void bli_dgemmsup_rv_haswell_asm_6x2m #else prefetch(0, mem(rdx, r9, 2, 5*8)) #endif - + vmovupd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; @@ -2649,18 +8743,18 @@ void bli_dgemmsup_rv_haswell_asm_6x2m vbroadcastsd(mem(rax, r8, 1), ymm3) vfmadd231pd(xmm0, xmm2, xmm4) vfmadd231pd(xmm0, xmm3, xmm6) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(xmm0, xmm2, xmm8) vfmadd231pd(xmm0, xmm3, xmm10) - + vbroadcastsd(mem(rax, r8, 4), ymm2) vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; vfmadd231pd(xmm0, xmm2, xmm12) vfmadd231pd(xmm0, xmm3, xmm14) - + // ---------------------------------- iteration 3 @@ -2678,43 +8772,43 @@ void bli_dgemmsup_rv_haswell_asm_6x2m vbroadcastsd(mem(rax, r8, 1), ymm3) vfmadd231pd(xmm0, xmm2, xmm4) vfmadd231pd(xmm0, xmm3, xmm6) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(xmm0, xmm2, xmm8) vfmadd231pd(xmm0, xmm3, xmm10) - + vbroadcastsd(mem(rax, r8, 4), ymm2) vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; vfmadd231pd(xmm0, xmm2, xmm12) vfmadd231pd(xmm0, xmm3, xmm14) - - - + + + dec(rsi) // i -= 1; jne(.DLOOPKITER) // iterate again if i != 0. - - - - - - + + + + + + label(.DCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.DPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - - + + label(.DLOOPKLEFT) // EDGE LOOP #if 1 prefetch(0, mem(rdx, 5*8)) add(r9, rdx) #endif - + vmovupd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; @@ -2722,58 +8816,57 @@ void bli_dgemmsup_rv_haswell_asm_6x2m vbroadcastsd(mem(rax, r8, 1), ymm3) vfmadd231pd(xmm0, xmm2, xmm4) vfmadd231pd(xmm0, xmm3, xmm6) - + vbroadcastsd(mem(rax, r8, 2), ymm2) vbroadcastsd(mem(rax, r13, 1), ymm3) vfmadd231pd(xmm0, xmm2, xmm8) vfmadd231pd(xmm0, xmm3, xmm10) - + vbroadcastsd(mem(rax, r8, 4), ymm2) vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; vfmadd231pd(xmm0, xmm2, xmm12) vfmadd231pd(xmm0, xmm3, xmm14) - - + + dec(rsi) // i -= 1; jne(.DLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.DPOSTACCUM) - + mov(r12, rcx) // reset rcx to current utile of c. 
mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate - - vmulpd(xmm0, xmm4, xmm4) // scale by alpha - vmulpd(xmm0, xmm6, xmm6) - vmulpd(xmm0, xmm8, xmm8) - vmulpd(xmm0, xmm10, xmm10) - vmulpd(xmm0, xmm12, xmm12) - vmulpd(xmm0, xmm14, xmm14) - - - - - - + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm6, ymm6) + vmulpd(ymm0, ymm8, ymm8) + vmulpd(ymm0, ymm10, ymm10) + vmulpd(ymm0, ymm12, ymm12) + vmulpd(ymm0, ymm14, ymm14) + + + + + + mov(var(cs_c), rsi) // load cs_c lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) - - //lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; + lea(mem(rcx, rdi, 4), rdx) // load address of c + 4*rs_c; //lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; - - - + + + // now avoid loading C if beta == 0 - + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm3) // set ZF if beta == 0. je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case @@ -2782,42 +8875,42 @@ void bli_dgemmsup_rv_haswell_asm_6x2m cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. jz(.DCOLSTORED) // jump to column storage case - - + + label(.DROWSTORED) - - + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm4) vmovupd(xmm4, mem(rcx, 0*32)) add(rdi, rcx) - - + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm6) vmovupd(xmm6, mem(rcx, 0*32)) add(rdi, rcx) - - + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm8) vmovupd(xmm8, mem(rcx, 0*32)) add(rdi, rcx) - - + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm10) vmovupd(xmm10, mem(rcx, 0*32)) add(rdi, rcx) - - + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm12) vmovupd(xmm12, mem(rcx, 0*32)) add(rdi, rcx) - - + + vfmadd231pd(mem(rcx, 0*32), xmm3, xmm14) vmovupd(xmm14, mem(rcx, 0*32)) //add(rdi, rcx) - - + + jmp(.DDONE) // jump to end. @@ -2853,40 +8946,40 @@ void bli_dgemmsup_rv_haswell_asm_6x2m jmp(.DDONE) // jump to end. - - - - + + + + label(.DBETAZERO) cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. jz(.DCOLSTORBZ) // jump to column storage case - - + + label(.DROWSTORBZ) - - + + vmovupd(xmm4, mem(rcx, 0*32)) add(rdi, rcx) - + vmovupd(xmm6, mem(rcx, 0*32)) add(rdi, rcx) - - + + vmovupd(xmm8, mem(rcx, 0*32)) add(rdi, rcx) - + vmovupd(xmm10, mem(rcx, 0*32)) add(rdi, rcx) - - + + vmovupd(xmm12, mem(rcx, 0*32)) add(rdi, rcx) - + vmovupd(xmm14, mem(rcx, 0*32)) //add(rdi, rcx) @@ -2897,7 +8990,7 @@ void bli_dgemmsup_rv_haswell_asm_6x2m label(.DCOLSTORBZ) - + // begin I/O on columns 0-3 vunpcklpd(xmm6, xmm4, xmm0) vunpckhpd(xmm6, xmm4, xmm1) @@ -2918,10 +9011,10 @@ void bli_dgemmsup_rv_haswell_asm_6x2m vmovupd(xmm1, mem(rdx, rsi, 1)) //lea(mem(rdx, rsi, 4), rdx) - - - - + + + + label(.DDONE) @@ -2943,7 +9036,7 @@ void bli_dgemmsup_rv_haswell_asm_6x2m label(.DRETURN) - + end_asm( : // output operands (none) @@ -3060,3 +9153,4 @@ void bli_dgemmsup_rv_haswell_asm_6x2m AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); } + From c85bbfdb502e073bae6de62fab5afbf12540122b Mon Sep 17 00:00:00 2001 From: "Dipal M. Zambare" Date: Fri, 12 Aug 2022 05:33:05 +0000 Subject: [PATCH 171/243] Updated BLIS version string format MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Updated version string to match the recommended format “AOCL-BLIS 3.2.1 Build 20220727”. - Fixed issues with include paths which was preventing compile time version sting definition passing via build commands. 
- Removed version string determination based on git tag using ‘git describe’, version string will always be taken from the version file. AMD-Internal: [CPUPL-2324] Change-Id: Idc7edf1211f66d348ec3b5b43f2507c2b810f088 --- CMakeLists.txt | 3 ++- common.mk | 3 ++- configure | 58 ++----------------------------------------- frame/base/bli_info.c | 4 +-- 4 files changed, 8 insertions(+), 60 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index dce532b07c..a46f2d664e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -565,7 +565,8 @@ message( STATUS "Generating monolithic cblas header file :" ${CMD_OUTPUT}) # setting the blis version string file (STRINGS "version" BLIS_VERSION) set(BLIS_VERSION_STRING ${BLIS_VERSION}) -add_definitions(-DBLIS_VERSION_STRING="AOCL BLIS ${BLIS_VERSION_STRING}") +string(TIMESTAMP BUILD_DATE "%Y%m%d") +add_definitions(-DBLIS_VERSION_STRING="AOCL-BLIS ${BLIS_VERSION_STRING} Build ${BUILD_DATE}") if(BUILD_SHARED_LIBS) add_library("${PROJECT_NAME}" SHARED ${CMAKE_SOURCE_DIR}/bli_config.h diff --git a/common.mk b/common.mk index 02f34360af..0fdf659a9f 100644 --- a/common.mk +++ b/common.mk @@ -5,7 +5,7 @@ # libraries. # # Copyright (C) 2014, The University of Texas at Austin -# Copyright (C) 2020-2021, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -1084,6 +1084,7 @@ BLIS_H_FLAT := $(BASE_INC_PATH)/$(BLIS_H) # header files. CBLAS_H := cblas.h CBLAS_H_SRC_PATH := $(filter %/$(CBLAS_H), $(FRAME_H99_FILES)) +CBLAS_H_DIRPATH := $(dir $(CBLAS_H_SRC_PATH)) # Construct the path to what will be the intermediate flattened/monolithic # cblas.h file. diff --git a/configure b/configure index b35348abdc..e18f98f6ae 100755 --- a/configure +++ b/configure @@ -1819,67 +1819,13 @@ try_assemble() set_default_version() { - local gitdir version_file gd_stderr git_describe_str git_error new_version_str - - gitdir='.git' - # The path to the version file. version_file=$1 echo "${script_name}: determining default version string." - # Check if the .git dir exists; if it does not, we do nothing. - if [ -d "${dist_path}/${gitdir}" ]; then - - echo "${script_name}: found '${gitdir}' directory; assuming git clone." - - echo "${script_name}: executing: git describe --tags." - - gd_stderr="git_describe_stderr.txt" - - # Query git for the version string, which is simply the current tag, - # followed by a number signifying how many commits have transpired - # since the tag, followed by a 'g' and a shortened hash tab. Capture - # stderr to a file. - git_describe_str=$(git -C ${dist_path} describe --tags 2> ${gd_stderr}) - - # Pull in whatever error message was generated, if any, and delete - # the file. - git_error=$(cat ${gd_stderr}) - - # Remove the stderr file. - rm -f ${gd_stderr} - - # If git returned an error, don't do anything. - if [ -n "${git_error}" ]; then - - echo "${script_name}: git returned an error: '${git_error}'." - echo "${script_name}: using string from unmodified version file." - - # Use what's in the version file as-is. - version="AOCL BLIS $(cat "${version_file}")" - else - - echo "${script_name}: got back ${git_describe_str}." - - # Strip off the commit hash label. - new_version_str=$(echo ${git_describe_str} | cut -d- -f2) - - echo "${script_name}: truncating to ${new_version_str}." - - # Write the new version string to the version file. 
- #echo "${new_version_str}" > ${version_file} - - # Set the version variable. - version="AOCL BLIS ${new_version_str}" - fi - else - - echo "${script_name}: could not find '${gitdir}' directory; using unmodified version file." - - # Use what's in the version file as-is. - version="AOCL BLIS $(cat "${version_file}")" - fi + # Use what's in the version file as-is. + version="AOCL-BLIS $(cat "${version_file}") Build $(date +%Y%m%d)" } diff --git a/frame/base/bli_info.c b/frame/base/bli_info.c index 73ac5b3f57..cc350ab606 100644 --- a/frame/base/bli_info.c +++ b/frame/base/bli_info.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -40,7 +40,7 @@ // This string gets defined via -D on the command line when BLIS is compiled. // This string is (or rather, should be) only used here. -static char* bli_version_str = "3.2.1"; //BLIS_VERSION_STRING; +static char* bli_version_str = BLIS_VERSION_STRING; static char* bli_int_type_size_str = STRINGIFY_INT( BLIS_INT_TYPE_SIZE ); char* bli_info_get_version_str( void ) { return bli_version_str; } From a226e54421073f95648229588113ef283ee8cc48 Mon Sep 17 00:00:00 2001 From: Arnav Sharma Date: Thu, 11 Aug 2022 08:45:49 +0000 Subject: [PATCH 172/243] AVX512 based SGEMM Optimizations - Updated with optimal cache-blocking sizes for MC, KC and NC for AVX512 Native SGEMM kernel. AMD-Internal: [CPUPL-2385] Change-Id: I1feae5ac79e960c6b26df24756d460243820b797 --- config/zen4/bli_cntx_init_zen4.c | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/config/zen4/bli_cntx_init_zen4.c b/config/zen4/bli_cntx_init_zen4.c index fc900a1e98..5f0728f18a 100644 --- a/config/zen4/bli_cntx_init_zen4.c +++ b/config/zen4/bli_cntx_init_zen4.c @@ -43,10 +43,10 @@ /* s d c z */ \ bli_blksz_init_easy( &blkszs[ BLIS_MR ], 32, 16, 3, 3 ); \ bli_blksz_init_easy( &blkszs[ BLIS_NR ], 12, 14, 8, 4 ); \ - bli_blksz_init_easy( &blkszs[ BLIS_MC ], 480, 240, 144, 18 ); \ - bli_blksz_init ( &blkszs[ BLIS_KC ], 384, 512, 256, 566, \ + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 512, 240, 144, 18 ); \ + bli_blksz_init ( &blkszs[ BLIS_KC ], 480, 512, 256, 566, \ 480, 320, 256, 566 ); \ - bli_blksz_init_easy( &blkszs[ BLIS_NC ], 3072, 4004, 4080, 256 ); \ + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 6144, 4004, 4080, 256 ); \ \ bli_blksz_init_easy( &blkszs[ BLIS_AF ], 8, 8, -1, -1 ); \ bli_blksz_init_easy( &blkszs[ BLIS_DF ], 8, 8, -1, -1 ); \ From f5ef30a44aff2e6423a400aa277c0d9f399a8174 Mon Sep 17 00:00:00 2001 From: Shubham Sharma Date: Fri, 12 Aug 2022 21:00:52 +0530 Subject: [PATCH 173/243] Fix in DGEMMT SUP kernel Details: 1. Due to error in C output buffer address computation in kernel bli_dgemmsup_rv_haswell_asm_6x8m_6x8_L, invalid memory is being accessed. This is causing seg fault in libflame netlib testing. 2. Validated the fix with libflame netlib testing. 
AMD-Internal: [CPUPL-2341] Change-Id: I9ca0cf09cf2d177ade73f840054b5028eae3a0ed --- kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c b/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c index 21c394f558..41da73f361 100644 --- a/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c +++ b/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c @@ -1809,7 +1809,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_6x8_L jz(.DCOLSTORBZ) // jump to column storage case label(.DROWSTORBZ) - lea(mem(rdi, rcx, 2), rdi) + lea(mem(rcx , rdi, 2), rcx) vmovlpd(xmm8, mem(rcx, 0*32)) add(rdi, rcx) From 6fbdfc3cf2d0859289d7e319cf40769e0336955c Mon Sep 17 00:00:00 2001 From: mkadavil Date: Thu, 11 Aug 2022 18:49:03 +0530 Subject: [PATCH 174/243] Low precision gemm refactoring and bug fixes. -The micro-kernel function signatures follow a common pattern. These functions can be represented as an instantiation of a MACRO as is done in BLIS, and thus the number of micro-kernel header files can be brought down. A new single header file containing all the MACRO definitions with the instantiation is added, and the existing unnecessary header files are removed. -The bias addition in micro-kernel for n remaining < 16 reads the bias array assuming it contains 16 elements. This can result in seg-faults, since out of bound memory is accessed. It is fixed by copying required elements to an intermediate buffer and using that buffer for loading. -Input matrix storage type parameter is added to lpgemm APIs. It can be either row or column major, denoted by r and c respectively. Currently only row major input matrices are supported. -Bug fix in s16 fringe micro-kernel to use correct offset while storing output. 
AMD-Internal: [CPUPL-2386] Change-Id: Idfa23e69d54ad7e06a67b1e36a5b5558fbff03a3 --- addon/aocl_gemm/aocl_gemm.h | 7 +- addon/aocl_gemm/aocl_gemm_f32f32f32of32.c | 30 +- addon/aocl_gemm/aocl_gemm_f32f32f32of32.h | 63 --- .../aocl_gemm/aocl_gemm_f32f32f32of32_utils.c | 25 +- ...f32_utils.h => aocl_gemm_interface_apis.h} | 79 +++- addon/aocl_gemm/aocl_gemm_u8s8s16os16.c | 31 +- addon/aocl_gemm/aocl_gemm_u8s8s16os16.h | 62 --- addon/aocl_gemm/aocl_gemm_u8s8s16os16_utils.c | 19 +- addon/aocl_gemm/aocl_gemm_u8s8s16os16_utils.h | 55 --- addon/aocl_gemm/aocl_gemm_u8s8s32os32.c | 28 +- addon/aocl_gemm/aocl_gemm_u8s8s32os32.h | 62 --- addon/aocl_gemm/aocl_gemm_u8s8s32os32_utils.c | 19 +- addon/aocl_gemm/aocl_gemm_u8s8s32os32_utils.h | 55 --- .../frame/f32f32f32/lpgemm_f32f32f32.c | 24 +- ...f32f32.h => lpgemm_5loop_interface_apis.h} | 54 ++- .../threading/lpgemm_thread_decor_openmp.c | 4 +- .../aocl_gemm/frame/u8s8s16/lpgemm_u8s8s16.c | 27 +- .../aocl_gemm/frame/u8s8s16/lpgemm_u8s8s16.h | 64 --- .../aocl_gemm/frame/u8s8s32/lpgemm_u8s8s32.c | 26 +- .../aocl_gemm/frame/u8s8s32/lpgemm_u8s8s32.h | 64 --- addon/aocl_gemm/kernels/lpgemm_kernels.h | 223 +++++++++ .../kernels/u8s8s16/lpgemm_6x32rowmajor.h | 64 --- .../u8s8s16/lpgemm_6x32rowmajor_amd256.c | 29 +- .../kernels/u8s8s16/lpgemm_m_fringe_amd256.c | 59 +-- .../kernels/u8s8s16/lpgemm_m_fringe_s16.h | 100 ---- .../kernels/u8s8s16/lpgemm_mn_fringe_amd256.c | 119 +---- .../kernels/u8s8s16/lpgemm_mn_fringe_s16.h | 157 ------ .../kernels/u8s8s16/lpgemm_n_fringe_amd256.c | 58 +-- .../kernels/u8s8s16/lpgemm_n_fringe_s16.h | 84 ---- .../kernels/u8s8s32/lpgemm_6x64rowmajor.h | 64 --- .../u8s8s32/lpgemm_6x64rowmajor_amd512vnni.c | 27 +- .../kernels/u8s8s32/lpgemm_m_fringe.h | 140 ------ .../u8s8s32/lpgemm_m_fringe_amd512vnni.c | 97 +--- .../kernels/u8s8s32/lpgemm_mn_fringe.h | 445 ------------------ .../u8s8s32/lpgemm_mn_fringe_amd512vnni.c | 417 ++-------------- .../kernels/u8s8s32/lpgemm_n_fringe.h | 129 ----- .../u8s8s32/lpgemm_n_fringe_amd512vnni.c | 94 +--- bench/bench_aocl_gemm/bench_lpgemm.c | 3 +- 38 files changed, 436 insertions(+), 2671 deletions(-) delete mode 100644 addon/aocl_gemm/aocl_gemm_f32f32f32of32.h rename addon/aocl_gemm/{aocl_gemm_f32f32f32of32_utils.h => aocl_gemm_interface_apis.h} (51%) delete mode 100644 addon/aocl_gemm/aocl_gemm_u8s8s16os16.h delete mode 100644 addon/aocl_gemm/aocl_gemm_u8s8s16os16_utils.h delete mode 100644 addon/aocl_gemm/aocl_gemm_u8s8s32os32.h delete mode 100644 addon/aocl_gemm/aocl_gemm_u8s8s32os32_utils.h rename addon/aocl_gemm/frame/{f32f32f32/lpgemm_f32f32f32.h => lpgemm_5loop_interface_apis.h} (63%) delete mode 100644 addon/aocl_gemm/frame/u8s8s16/lpgemm_u8s8s16.h delete mode 100644 addon/aocl_gemm/frame/u8s8s32/lpgemm_u8s8s32.h create mode 100644 addon/aocl_gemm/kernels/lpgemm_kernels.h delete mode 100644 addon/aocl_gemm/kernels/u8s8s16/lpgemm_6x32rowmajor.h delete mode 100644 addon/aocl_gemm/kernels/u8s8s16/lpgemm_m_fringe_s16.h delete mode 100644 addon/aocl_gemm/kernels/u8s8s16/lpgemm_mn_fringe_s16.h delete mode 100644 addon/aocl_gemm/kernels/u8s8s16/lpgemm_n_fringe_s16.h delete mode 100644 addon/aocl_gemm/kernels/u8s8s32/lpgemm_6x64rowmajor.h delete mode 100644 addon/aocl_gemm/kernels/u8s8s32/lpgemm_m_fringe.h delete mode 100644 addon/aocl_gemm/kernels/u8s8s32/lpgemm_mn_fringe.h delete mode 100644 addon/aocl_gemm/kernels/u8s8s32/lpgemm_n_fringe.h diff --git a/addon/aocl_gemm/aocl_gemm.h b/addon/aocl_gemm/aocl_gemm.h index f9e37e76cb..4e971d932a 100644 --- a/addon/aocl_gemm/aocl_gemm.h +++ 
b/addon/aocl_gemm/aocl_gemm.h @@ -36,11 +36,6 @@ #define BLIS_ADDON_LPGEMM #include "aocl_gemm_post_ops.h" -#include "aocl_gemm_u8s8s16os16.h" -#include "aocl_gemm_u8s8s32os32.h" -#include "aocl_gemm_f32f32f32of32.h" -#include "aocl_gemm_u8s8s16os16_utils.h" -#include "aocl_gemm_u8s8s32os32_utils.h" -#include "aocl_gemm_f32f32f32of32_utils.h" +#include "aocl_gemm_interface_apis.h" #endif // BLIS_ADDON_LPGEMM diff --git a/addon/aocl_gemm/aocl_gemm_f32f32f32of32.c b/addon/aocl_gemm/aocl_gemm_f32f32f32of32.c index bc9ed29da1..179882d412 100644 --- a/addon/aocl_gemm/aocl_gemm_f32f32f32of32.c +++ b/addon/aocl_gemm/aocl_gemm_f32f32f32of32.c @@ -33,32 +33,14 @@ */ #include "blis.h" -#include "aocl_gemm_f32f32f32of32.h" +#include "aocl_gemm_interface_apis.h" #include "lpgemm_types.h" #include "lpgemm_post_ops.h" #include "lpgemm_thread_decor_openmp.h" #include "lpgemm_utils.h" -#include "lpgemm_f32f32f32.h" - -void aocl_gemm_f32f32f32of32 - ( - const char transa, - const char transb, - const dim_t m, - const dim_t n, - const dim_t k, - const float alpha, - const float* a, - const dim_t lda, - const char mem_format_a, - const float* b, - const dim_t ldb, - const char mem_format_b, - const float beta, - float* c, - const dim_t ldc, - aocl_post_op* post_op_unparsed - ) +#include "lpgemm_5loop_interface_apis.h" + +AOCL_GEMM_MATMUL(float,float,float,f32f32f32of32) { trans_t blis_transa; trans_t blis_transb; @@ -98,6 +80,10 @@ void aocl_gemm_f32f32f32of32 "Input matrix transpose not supported."); return; // Error. } + if ( ( order != 'r' ) && ( order != 'R' ) ) + { + return; // Only row major supported. + } // Row major input expected with leading dimensions equal to row stride. if ( ( lda != k ) || ( ldb != n ) || ( ldc != n ) ) diff --git a/addon/aocl_gemm/aocl_gemm_f32f32f32of32.h b/addon/aocl_gemm/aocl_gemm_f32f32f32of32.h deleted file mode 100644 index 3e450414ea..0000000000 --- a/addon/aocl_gemm/aocl_gemm_f32f32f32of32.h +++ /dev/null @@ -1,63 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#ifndef AOCL_GEMM_F32F32F32OF32_H -#define AOCL_GEMM_F32F32F32OF32_H - -#include "aocl_gemm_post_ops.h" - -// Only supports matrices in row major format. This api can perform gemm with -// both normal as well as reordered B matrix as opposesd to sgemm (only -// supports former). This api can be considered analogous to packed sgemm api. -BLIS_EXPORT_ADDON void aocl_gemm_f32f32f32of32 - ( - const char transa, - const char transb, - const dim_t m, - const dim_t n, - const dim_t k, - const float alpha, - const float* a, - const dim_t lda, - const char mem_format_a, - const float* b, - const dim_t ldb, - const char mem_format_b, - const float beta, - float* c, - const dim_t ldc, - aocl_post_op* post_op_unparsed - ); - -#endif //AOCL_GEMM_F32F32F32OF32_H diff --git a/addon/aocl_gemm/aocl_gemm_f32f32f32of32_utils.c b/addon/aocl_gemm/aocl_gemm_f32f32f32of32_utils.c index 84a611f605..948c1383de 100644 --- a/addon/aocl_gemm/aocl_gemm_f32f32f32of32_utils.c +++ b/addon/aocl_gemm/aocl_gemm_f32f32f32of32_utils.c @@ -33,15 +33,10 @@ */ #include "blis.h" -#include "aocl_gemm_f32f32f32of32_utils.h" +#include "aocl_gemm_interface_apis.h" #include "lpgemm_utils.h" -siz_t aocl_get_reorder_buf_size_f32f32f32of32 - ( - const char mat_type, - const dim_t k, - const dim_t n - ) +AOCL_GEMM_GET_REORDER_BUF_SIZE(f32f32f32of32) { if ( ( k <= 0 ) || ( n <= 0 ) ) { @@ -82,17 +77,9 @@ siz_t aocl_get_reorder_buf_size_f32f32f32of32 } // Pack B into row stored column panels. -void aocl_reorder_f32f32f32of32 - ( - const char mat_type, - const float* input_buf_addr_b, - float* reorder_buf_addr_b, - const dim_t k, - const dim_t n, - const dim_t ldb - ) +AOCL_GEMM_REORDER(float,f32f32f32of32) { - if ( ( input_buf_addr_b == NULL ) || ( reorder_buf_addr_b == NULL ) || + if ( ( input_buf_addr == NULL ) || ( reorder_buf_addr == NULL ) || ( k <= 0 ) || ( n <= 0 ) || ( ldb < n ) ) { return; // Error. @@ -190,7 +177,7 @@ void aocl_reorder_f32f32f32of32 dim_t kc0 = bli_min( ( k - pc ), KC ); inc_t ps_p = kc0 * NR; - const float* b_temp = input_buf_addr_b + ( jc * cs_b ) + ( pc * rs_b ); + const float* b_temp = input_buf_addr + ( jc * cs_b ) + ( pc * rs_b ); // The offsets are calculated in such a way that it resembles // the reorder buffer traversal in single threaded reordering. 
@@ -227,7 +214,7 @@ void aocl_reorder_f32f32f32of32 // st = ( jc_cur_loop * k ) // + ( n_sub_updated * pc ) // + ( NC' * kc0_updated) - float* p_temp = reorder_buf_addr_b + ( jc_cur_loop * k ) + + float* p_temp = reorder_buf_addr + ( jc_cur_loop * k ) + ( n_sub_updated * pc ) + ( jc_cur_loop_rem * kc0 ); dim_t jr, it; diff --git a/addon/aocl_gemm/aocl_gemm_f32f32f32of32_utils.h b/addon/aocl_gemm/aocl_gemm_interface_apis.h similarity index 51% rename from addon/aocl_gemm/aocl_gemm_f32f32f32of32_utils.h rename to addon/aocl_gemm/aocl_gemm_interface_apis.h index 819e087691..0c656554b1 100644 --- a/addon/aocl_gemm/aocl_gemm_f32f32f32of32_utils.h +++ b/addon/aocl_gemm/aocl_gemm_interface_apis.h @@ -32,28 +32,69 @@ */ -#ifndef AOCL_GEMM_F32F32F32OF32_UTILS_H -#define AOCL_GEMM_F32F32F32OF32_UTILS_H +#ifndef AOCL_GEMM_INTERFACE_H +#define AOCL_GEMM_INTERFACE_H + +#include "aocl_gemm_post_ops.h" // Returns the size of buffer in bytes required for the reordered matrix. -BLIS_EXPORT_ADDON siz_t aocl_get_reorder_buf_size_f32f32f32of32 - ( - const char mat_type, - const dim_t k, - const dim_t n - ); +#define AOCL_GEMM_GET_REORDER_BUF_SIZE(LP_SFX) \ +BLIS_EXPORT_ADDON siz_t aocl_get_reorder_buf_size_ ## LP_SFX \ + ( \ + const char mat_type, \ + const dim_t k, \ + const dim_t n \ + ) \ + +AOCL_GEMM_GET_REORDER_BUF_SIZE(f32f32f32of32); +AOCL_GEMM_GET_REORDER_BUF_SIZE(u8s8s32os32); +AOCL_GEMM_GET_REORDER_BUF_SIZE(u8s8s16os16); // Performs reordering of input matrix. Reordering is the process of packing // the entire matrix upfront, so that the benefits of packed matrix is obtained // without incurring the packing costs during matmul computation. -BLIS_EXPORT_ADDON void aocl_reorder_f32f32f32of32 - ( - const char mat_type, - const float* input_buf_addr_b, - float* reorder_buf_addr_b, - const dim_t k, - const dim_t n, - const dim_t ldb - ); - -#endif //AOCL_GEMM_F32F32F32OF32_UTILS_H +#define AOCL_GEMM_REORDER(B_type,LP_SFX) \ +BLIS_EXPORT_ADDON void aocl_reorder_ ## LP_SFX \ + ( \ + const char mat_type, \ + const B_type* input_buf_addr, \ + B_type* reorder_buf_addr, \ + const dim_t k, \ + const dim_t n, \ + const dim_t ldb \ + ) \ + +AOCL_GEMM_REORDER(float,f32f32f32of32); +AOCL_GEMM_REORDER(int8_t,u8s8s32os32); +AOCL_GEMM_REORDER(int8_t,u8s8s16os16); + +// Only supports matrices in row major format. This api can perform gemm with +// both normal as well as reordered B matrix as opposesd to sgemm (only +// supports former). This api can be considered analogous to packed sgemm api. 
+#define AOCL_GEMM_MATMUL(A_type,B_type,C_type,LP_SFX) \ +BLIS_EXPORT_ADDON void aocl_gemm_ ## LP_SFX \ + ( \ + const char order, \ + const char transa, \ + const char transb, \ + const dim_t m, \ + const dim_t n, \ + const dim_t k, \ + const C_type alpha, \ + const A_type* a, \ + const dim_t lda, \ + const char mem_format_a, \ + const B_type* b, \ + const dim_t ldb, \ + const char mem_format_b, \ + const C_type beta, \ + C_type* c, \ + const dim_t ldc, \ + aocl_post_op* post_op_unparsed \ + ) \ + +AOCL_GEMM_MATMUL(float,float,float,f32f32f32of32); +AOCL_GEMM_MATMUL(uint8_t,int8_t,int32_t,u8s8s32os32); +AOCL_GEMM_MATMUL(uint8_t,int8_t,int16_t,u8s8s16os16); + +#endif // AOCL_GEMM_INTERFACE_H diff --git a/addon/aocl_gemm/aocl_gemm_u8s8s16os16.c b/addon/aocl_gemm/aocl_gemm_u8s8s16os16.c index 96232a8947..4d0e8565a9 100644 --- a/addon/aocl_gemm/aocl_gemm_u8s8s16os16.c +++ b/addon/aocl_gemm/aocl_gemm_u8s8s16os16.c @@ -33,33 +33,15 @@ */ #include "blis.h" -#include "aocl_gemm_u8s8s16os16.h" +#include "aocl_gemm_interface_apis.h" #include "lpgemm_types.h" -#include "lpgemm_u8s8s16.h" +#include "lpgemm_5loop_interface_apis.h" #include "lpgemm_config.h" #include "lpgemm_utils.h" #include "lpgemm_thread_decor_openmp.h" #include "lpgemm_post_ops.h" -void aocl_gemm_u8s8s16os16 - ( - const char transa, - const char transb, - const dim_t m, - const dim_t n, - const dim_t k, - const int16_t alpha, - const uint8_t* a, - const dim_t lda, - const char mem_format_a, - const int8_t* b, - const dim_t ldb, - const char mem_format_b, - const int16_t beta, - int16_t* c, - const dim_t ldc, - aocl_post_op* post_op_unparsed - ) +AOCL_GEMM_MATMUL(uint8_t,int8_t,int16_t,u8s8s16os16) { trans_t blis_transa; trans_t blis_transb; @@ -89,10 +71,15 @@ void aocl_gemm_u8s8s16os16 /* Perform BLAS parameter checking. */ // Transpose not supported. - if ((blis_transa != BLIS_NO_TRANSPOSE) || (blis_transb != BLIS_NO_TRANSPOSE)) + if ( ( blis_transa != BLIS_NO_TRANSPOSE ) || + ( blis_transb != BLIS_NO_TRANSPOSE ) ) { return; // Error. } + if ( ( order != 'r' ) && ( order != 'R' ) ) + { + return; // Only row major supported. + } // Row major input expected with leading dimensions equal to row stride. if ((lda != k) || (ldb != n) || (ldc != n)) diff --git a/addon/aocl_gemm/aocl_gemm_u8s8s16os16.h b/addon/aocl_gemm/aocl_gemm_u8s8s16os16.h deleted file mode 100644 index 920e8806b0..0000000000 --- a/addon/aocl_gemm/aocl_gemm_u8s8s16os16.h +++ /dev/null @@ -1,62 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. 
- - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#ifndef AOCL_GEMM_U8S8S16OS16_H -#define AOCL_GEMM_U8S8S16OS16_H - -#include "aocl_gemm_post_ops.h" - -// Only supports matrices in row major format -// Limitations: Supports mem_format_b = 'Reorder' -BLIS_EXPORT_ADDON void aocl_gemm_u8s8s16os16 - ( - const char transa, - const char transb, - const dim_t m, - const dim_t n, - const dim_t k, - const int16_t alpha, - const uint8_t* a, - const dim_t lda, - const char mem_format_a, - const int8_t* b, - const dim_t ldb, - const char mem_format_b, - const int16_t beta, - int16_t* c, - const dim_t ldc, - aocl_post_op* post_op_unparsed - ); - -#endif // AOCL_GEMM_U8S8S16OS16_H diff --git a/addon/aocl_gemm/aocl_gemm_u8s8s16os16_utils.c b/addon/aocl_gemm/aocl_gemm_u8s8s16os16_utils.c index cbbae09e1a..5cadd206d5 100644 --- a/addon/aocl_gemm/aocl_gemm_u8s8s16os16_utils.c +++ b/addon/aocl_gemm/aocl_gemm_u8s8s16os16_utils.c @@ -33,18 +33,13 @@ */ #include "blis.h" -#include "aocl_gemm_u8s8s16os16_utils.h" +#include "aocl_gemm_interface_apis.h" #include "lpgemm_types.h" #include "lpgemm_config.h" #include "lpgemm_utils.h" #include "lpgemm_reorder_s16.h" -siz_t aocl_get_reorder_buf_size_u8s8s16os16 - ( - const char mat_type, - const dim_t k, - const dim_t n - ) +AOCL_GEMM_GET_REORDER_BUF_SIZE(u8s8s16os16) { if ((k <= 0) || (n <= 0)) { @@ -87,15 +82,7 @@ siz_t aocl_get_reorder_buf_size_u8s8s16os16 return size_req; } -void aocl_reorder_u8s8s16os16 - ( - const char mat_type, - const int8_t *input_buf_addr, - int8_t *reorder_buf_addr, - const dim_t k, - const dim_t n, - const dim_t ldb - ) +AOCL_GEMM_REORDER(int8_t,u8s8s16os16) { if ((input_buf_addr == NULL) || (reorder_buf_addr == NULL) || (k <= 0) || (n <= 0) || (ldb < n)) diff --git a/addon/aocl_gemm/aocl_gemm_u8s8s16os16_utils.h b/addon/aocl_gemm/aocl_gemm_u8s8s16os16_utils.h deleted file mode 100644 index 21f7276be9..0000000000 --- a/addon/aocl_gemm/aocl_gemm_u8s8s16os16_utils.h +++ /dev/null @@ -1,55 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. 
- - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#ifndef AOCL_GEMM_U8S8S16OS16_UTILS_H -#define AOCL_GEMM_U8S8S16OS16_UTILS_H - -BLIS_EXPORT_ADDON siz_t aocl_get_reorder_buf_size_u8s8s16os16 - ( - const char mat_type, - const dim_t k, - const dim_t n - ); - -BLIS_EXPORT_ADDON void aocl_reorder_u8s8s16os16 - ( - const char mat_type, - const int8_t *input_buf_addr, - int8_t *reorder_buf_addr, - const dim_t k, - const dim_t n, - const dim_t ldb - ); - -#endif // AOCL_GEMM_U8S8S16OS16_UTILS_H \ No newline at end of file diff --git a/addon/aocl_gemm/aocl_gemm_u8s8s32os32.c b/addon/aocl_gemm/aocl_gemm_u8s8s32os32.c index 4c92d3c74c..6263d2ce27 100644 --- a/addon/aocl_gemm/aocl_gemm_u8s8s32os32.c +++ b/addon/aocl_gemm/aocl_gemm_u8s8s32os32.c @@ -33,33 +33,15 @@ */ #include "blis.h" -#include "aocl_gemm_u8s8s32os32.h" +#include "aocl_gemm_interface_apis.h" #include "lpgemm_types.h" #include "lpgemm_post_ops.h" #include "lpgemm_thread_decor_openmp.h" -#include "lpgemm_u8s8s32.h" +#include "lpgemm_5loop_interface_apis.h" #include "lpgemm_config.h" #include "lpgemm_utils.h" -void aocl_gemm_u8s8s32os32 - ( - const char transa, - const char transb, - const dim_t m, - const dim_t n, - const dim_t k, - const int32_t alpha, - const uint8_t* a, - const dim_t lda, - const char mem_format_a, - const int8_t* b, - const dim_t ldb, - const char mem_format_b, - const int32_t beta, - int32_t* c, - const dim_t ldc, - aocl_post_op* post_op_unparsed - ) +AOCL_GEMM_MATMUL(uint8_t,int8_t,int32_t,u8s8s32os32) { trans_t blis_transa; trans_t blis_transb; @@ -94,6 +76,10 @@ void aocl_gemm_u8s8s32os32 { return; // Error. } + if ( ( order != 'r' ) && ( order != 'R' ) ) + { + return; // Only row major supported. + } // Row major input expected with leading dimensions equal to row stride. if ( ( lda != k ) || ( ldb != n ) || ( ldc != n ) ) diff --git a/addon/aocl_gemm/aocl_gemm_u8s8s32os32.h b/addon/aocl_gemm/aocl_gemm_u8s8s32os32.h deleted file mode 100644 index 0993d3562a..0000000000 --- a/addon/aocl_gemm/aocl_gemm_u8s8s32os32.h +++ /dev/null @@ -1,62 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. 
- - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#ifndef AOCL_GEMM_U8S8S32OS32_H -#define AOCL_GEMM_U8S8S32OS32_H - -#include "aocl_gemm_post_ops.h" - -// Only supports matrices in row major format. Currenlty only mem_format_b -// is configurable to reorder. -BLIS_EXPORT_ADDON void aocl_gemm_u8s8s32os32 - ( - const char transa, - const char transb, - const dim_t m, - const dim_t n, - const dim_t k, - const int32_t alpha, - const uint8_t* a, - const dim_t lda, - const char mem_format_a, - const int8_t* b, - const dim_t ldb, - const char mem_format_b, - const int32_t beta, - int32_t* c, - const dim_t ldc, - aocl_post_op* post_op_unparsed - ); - -#endif //AOCL_GEMM_U8S8S32OS32_H diff --git a/addon/aocl_gemm/aocl_gemm_u8s8s32os32_utils.c b/addon/aocl_gemm/aocl_gemm_u8s8s32os32_utils.c index 31a56ef577..11f9f6937a 100644 --- a/addon/aocl_gemm/aocl_gemm_u8s8s32os32_utils.c +++ b/addon/aocl_gemm/aocl_gemm_u8s8s32os32_utils.c @@ -33,18 +33,13 @@ */ #include "blis.h" -#include "aocl_gemm_u8s8s32os32_utils.h" +#include "aocl_gemm_interface_apis.h" #include "lpgemm_types.h" #include "lpgemm_config.h" #include "lpgemm_utils.h" #include "lpgemm_reorder.h" -siz_t aocl_get_reorder_buf_size_u8s8s32os32 - ( - const char mat_type, - const dim_t k, - const dim_t n - ) +AOCL_GEMM_GET_REORDER_BUF_SIZE(u8s8s32os32) { if ( ( k <= 0 ) || ( n <= 0 ) ) { @@ -87,15 +82,7 @@ siz_t aocl_get_reorder_buf_size_u8s8s32os32 return size_req; } -void aocl_reorder_u8s8s32os32 - ( - const char mat_type, - const int8_t* input_buf_addr, - int8_t* reorder_buf_addr, - const dim_t k, - const dim_t n, - const dim_t ldb - ) +AOCL_GEMM_REORDER(int8_t,u8s8s32os32) { if ( ( input_buf_addr == NULL ) || ( reorder_buf_addr == NULL ) || ( k <= 0 ) || ( n <= 0 ) || ( ldb < n ) ) diff --git a/addon/aocl_gemm/aocl_gemm_u8s8s32os32_utils.h b/addon/aocl_gemm/aocl_gemm_u8s8s32os32_utils.h deleted file mode 100644 index d23660a1ec..0000000000 --- a/addon/aocl_gemm/aocl_gemm_u8s8s32os32_utils.h +++ /dev/null @@ -1,55 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. 
- - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#ifndef AOCL_GEMM_U8S8S32OS32_UTILS_H -#define AOCL_GEMM_U8S8S32OS32_UTILS_H - -BLIS_EXPORT_ADDON siz_t aocl_get_reorder_buf_size_u8s8s32os32 - ( - const char mat_type, - const dim_t k, - const dim_t n - ); - -BLIS_EXPORT_ADDON void aocl_reorder_u8s8s32os32 - ( - const char mat_type, - const int8_t* input_buf_addr, - int8_t* reorder_buf_addr, - const dim_t k, - const dim_t n, - const dim_t ldb - ); - -#endif //AOCL_GEMM_U8S8S32OS32_UTILS_H diff --git a/addon/aocl_gemm/frame/f32f32f32/lpgemm_f32f32f32.c b/addon/aocl_gemm/frame/f32f32f32/lpgemm_f32f32f32.c index 3744b73947..41a5151259 100644 --- a/addon/aocl_gemm/frame/f32f32f32/lpgemm_f32f32f32.c +++ b/addon/aocl_gemm/frame/f32f32f32/lpgemm_f32f32f32.c @@ -33,7 +33,7 @@ */ #include "blis.h" -#include "lpgemm_f32f32f32.h" +#include "lpgemm_5loop_interface_apis.h" #include "lpgemm_types.h" #include "lpgemm_utils.h" #include "lpgemm_thrinfo_utils.h" @@ -51,27 +51,7 @@ void lpgemm_pack_a_f32f32f32of32 cntx_t* cntx ); -void lpgemm_rowvar_f32f32f32of32 - ( - const dim_t m, - const dim_t n, - const dim_t k, - const float* a, - const dim_t rs_a, - const dim_t cs_a, - const AOCL_MEMORY_TAG mtag_a, - const float* b, - const dim_t rs_b, - const dim_t cs_b, - const AOCL_MEMORY_TAG mtag_b, - float* c, - const dim_t rs_c, - float alpha, - float beta, - rntm_t* rntm, - lpgemm_thrinfo_t* thread, - lpgemm_post_op* post_op_list - ) +LPGEMM_5LOOP(float,float,float,f32f32f32of32) { // Query the global cntx. 
cntx_t* cntx = bli_gks_query_cntx(); diff --git a/addon/aocl_gemm/frame/f32f32f32/lpgemm_f32f32f32.h b/addon/aocl_gemm/frame/lpgemm_5loop_interface_apis.h similarity index 63% rename from addon/aocl_gemm/frame/f32f32f32/lpgemm_f32f32f32.h rename to addon/aocl_gemm/frame/lpgemm_5loop_interface_apis.h index f58754acb1..65a8715d29 100644 --- a/addon/aocl_gemm/frame/f32f32f32/lpgemm_f32f32f32.h +++ b/addon/aocl_gemm/frame/lpgemm_5loop_interface_apis.h @@ -32,32 +32,36 @@ */ -#ifndef LPGEMM_F32F32F32_H -#define LPGEMM_F32F32F32_H +#ifndef LPGEMM_5LOOP_INTF_H +#define LPGEMM_5LOOP_INTF_H #include "lpgemm_types.h" #include "lpgemm_post_ops.h" -void lpgemm_rowvar_f32f32f32of32 - ( - const dim_t m, - const dim_t n, - const dim_t k, - const float* a, - const dim_t rs_a, - const dim_t cs_a, - const AOCL_MEMORY_TAG mtag_a, - const float* b, - const dim_t rs_b, - const dim_t cs_b, - const AOCL_MEMORY_TAG mtag_b, - float* c, - const dim_t rs_c, - float alpha, - float beta, - rntm_t* rntm, - lpgemm_thrinfo_t* thread, - lpgemm_post_op* post_op_list - ); - -#endif //LPGEMM_F32F32F32_H +#define LPGEMM_5LOOP(A_type,B_type,C_type,LP_SFX) \ +void lpgemm_rowvar_ ## LP_SFX \ + ( \ + const dim_t m, \ + const dim_t n, \ + const dim_t k, \ + const A_type* a, \ + const dim_t rs_a, \ + const dim_t cs_a, \ + const AOCL_MEMORY_TAG mtag_a, \ + const B_type* b, \ + const dim_t rs_b, \ + const dim_t cs_b, \ + const AOCL_MEMORY_TAG mtag_b, \ + C_type* c, \ + const dim_t rs_c, \ + C_type alpha, \ + C_type beta, \ + rntm_t* rntm, \ + lpgemm_thrinfo_t* thread, \ + lpgemm_post_op* post_op_list \ + ) \ + +LPGEMM_5LOOP(uint8_t,int8_t,int32_t,u8s8s32o32); +LPGEMM_5LOOP(uint8_t,int8_t,int16_t,u8s8s16o16); +LPGEMM_5LOOP(float,float,float,f32f32f32of32); +#endif // LPGEMM_5LOOP_INTF_H diff --git a/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.c b/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.c index 1521fe7c5d..544f5b6c70 100644 --- a/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.c +++ b/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.c @@ -36,9 +36,7 @@ #include "lpgemm_config.h" #include "lpgemm_thread_decor_openmp.h" #include "lpgemm_types.h" -#include "lpgemm_u8s8s16.h" -#include "lpgemm_u8s8s32.h" -#include "lpgemm_f32f32f32.h" +#include "lpgemm_5loop_interface_apis.h" #ifdef BLIS_ENABLE_OPENMP diff --git a/addon/aocl_gemm/frame/u8s8s16/lpgemm_u8s8s16.c b/addon/aocl_gemm/frame/u8s8s16/lpgemm_u8s8s16.c index 98fa45e966..76e5640f73 100644 --- a/addon/aocl_gemm/frame/u8s8s16/lpgemm_u8s8s16.c +++ b/addon/aocl_gemm/frame/u8s8s16/lpgemm_u8s8s16.c @@ -33,34 +33,15 @@ */ #include "blis.h" -#include "lpgemm_u8s8s16.h" +#include "lpgemm_5loop_interface_apis.h" #include "lpgemm_packb_s16.h" -#include "lpgemm_6x32rowmajor.h" +#include "lpgemm_kernels.h" #include "lpgemm_utils.h" #include "lpgemm_config.h" #include "lpgemm_thrinfo_utils.h" -void lpgemm_rowvar_u8s8s16o16 - ( - const dim_t m, - const dim_t n, - const dim_t k, - const uint8_t *a, - const dim_t rs_a, - const dim_t cs_a, - const AOCL_MEMORY_TAG mtag_a, - const int8_t *b, - const dim_t rs_b, - const dim_t cs_b, - const AOCL_MEMORY_TAG mtag_b, - int16_t *c, - const dim_t rs_c, - int16_t alpha, - int16_t beta, - rntm_t *rntm, - lpgemm_thrinfo_t *thread, - lpgemm_post_op* post_op_list - ) +// B should always be packed. 
+LPGEMM_5LOOP(uint8_t,int8_t,int16_t,u8s8s16o16) { const dim_t NC = lpgemm_get_block_size_NC_global_cntx( U8S8S16OS16 ); const dim_t KC = lpgemm_get_block_size_KC_global_cntx( U8S8S16OS16 ); diff --git a/addon/aocl_gemm/frame/u8s8s16/lpgemm_u8s8s16.h b/addon/aocl_gemm/frame/u8s8s16/lpgemm_u8s8s16.h deleted file mode 100644 index ff1fcba8d8..0000000000 --- a/addon/aocl_gemm/frame/u8s8s16/lpgemm_u8s8s16.h +++ /dev/null @@ -1,64 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#ifndef LPGEMM_U8S8S16_H -#define LPGEMM_U8S8S16_H - -#include "blis.h" -#include "lpgemm_types.h" -#include "lpgemm_post_ops.h" - -void lpgemm_rowvar_u8s8s16o16 - ( - const dim_t m, - const dim_t n, - const dim_t k, - const uint8_t *a, - const dim_t rs_a, - const dim_t cs_a, - const AOCL_MEMORY_TAG mtag_a, - const int8_t *b, - const dim_t rs_b, - const dim_t cs_b, - const AOCL_MEMORY_TAG mtag_b, - int16_t *c, - const dim_t rs_c, - int16_t alpha, - int16_t beta, - rntm_t *rntm, - lpgemm_thrinfo_t *thread, - lpgemm_post_op* post_op_list - ); - -#endif // LPGEMM_U8S8S16_H \ No newline at end of file diff --git a/addon/aocl_gemm/frame/u8s8s32/lpgemm_u8s8s32.c b/addon/aocl_gemm/frame/u8s8s32/lpgemm_u8s8s32.c index 50dc5ad97e..2c5ff4999c 100644 --- a/addon/aocl_gemm/frame/u8s8s32/lpgemm_u8s8s32.c +++ b/addon/aocl_gemm/frame/u8s8s32/lpgemm_u8s8s32.c @@ -33,36 +33,16 @@ */ #include "blis.h" -#include "lpgemm_u8s8s32.h" +#include "lpgemm_5loop_interface_apis.h" #include "lpgemm_packa.h" #include "lpgemm_packb.h" -#include "lpgemm_6x64rowmajor.h" +#include "lpgemm_kernels.h" #include "lpgemm_utils.h" #include "lpgemm_thrinfo_utils.h" #include "lpgemm_config.h" // B should always be packed. 
-void lpgemm_rowvar_u8s8s32o32 - ( - const dim_t m, - const dim_t n, - const dim_t k, - const uint8_t* a, - const dim_t rs_a, - const dim_t cs_a, - const AOCL_MEMORY_TAG mtag_a, - const int8_t* b, - const dim_t rs_b, - const dim_t cs_b, - const AOCL_MEMORY_TAG mtag_b, - int32_t* c, - const dim_t rs_c, - int32_t alpha, - int32_t beta, - rntm_t* rntm, - lpgemm_thrinfo_t* thread, - lpgemm_post_op* post_op_list - ) +LPGEMM_5LOOP(uint8_t,int8_t,int32_t,u8s8s32o32) { dim_t NC = lpgemm_get_block_size_NC_global_cntx( U8S8S32OS32 ); dim_t KC = lpgemm_get_block_size_KC_global_cntx( U8S8S32OS32 ); diff --git a/addon/aocl_gemm/frame/u8s8s32/lpgemm_u8s8s32.h b/addon/aocl_gemm/frame/u8s8s32/lpgemm_u8s8s32.h deleted file mode 100644 index 8f846abfdd..0000000000 --- a/addon/aocl_gemm/frame/u8s8s32/lpgemm_u8s8s32.h +++ /dev/null @@ -1,64 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#ifndef LPGEMM_U8S8S32_H -#define LPGEMM_U8S8S32_H - -#include "lpgemm_types.h" -#include "lpgemm_post_ops.h" - -// B should always be packed. -void lpgemm_rowvar_u8s8s32o32 - ( - const dim_t m, - const dim_t n, - const dim_t k, - const uint8_t* a, - const dim_t rs_a, - const dim_t cs_a, - const AOCL_MEMORY_TAG mtag_a, - const int8_t* b, - const dim_t rs_b, - const dim_t cs_b, - const AOCL_MEMORY_TAG mtag_b, - int32_t* c, - const dim_t rs_c, - int32_t alpha, - int32_t beta, - rntm_t* rntm, - lpgemm_thrinfo_t* thread, - lpgemm_post_op* post_op_list - ); - -#endif //LPGEMM_U8S8S32_H diff --git a/addon/aocl_gemm/kernels/lpgemm_kernels.h b/addon/aocl_gemm/kernels/lpgemm_kernels.h new file mode 100644 index 0000000000..f2e66d3277 --- /dev/null +++ b/addon/aocl_gemm/kernels/lpgemm_kernels.h @@ -0,0 +1,223 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. 
+ + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLIS_LPGEMM_KERN_H +#define BLIS_LPGEMM_KERN_H + +#include "lpgemm_post_ops.h" + +#define LPGEMM_MAIN_KERN(A_type,B_type,C_type,LP_SFX) \ +void lpgemm_rowvar_ ## LP_SFX \ + ( \ + const dim_t m0, \ + const dim_t n0, \ + const dim_t k0, \ + const A_type* a, \ + const dim_t rs_a, \ + const dim_t cs_a, \ + const dim_t ps_a, \ + const B_type* b, \ + const dim_t rs_b, \ + const dim_t cs_b, \ + C_type* c, \ + const dim_t rs_c, \ + const dim_t cs_c, \ + const C_type alpha, \ + const C_type beta, \ + bool is_last_k, \ + dim_t post_op_c_i, \ + dim_t post_op_c_j, \ + lpgemm_post_op* post_ops_list \ + ) \ + +LPGEMM_MAIN_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x64); +LPGEMM_MAIN_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x32); + +#define LPGEMM_M_FRINGE_KERN(A_type,B_type,C_type,LP_SFX) \ +void lpgemm_rowvar_ ## LP_SFX \ + ( \ + const dim_t k0, \ + const A_type* a, \ + const dim_t rs_a, \ + const dim_t cs_a, \ + const B_type* b, \ + const dim_t rs_b, \ + const dim_t cs_b, \ + C_type* c, \ + const dim_t rs_c, \ + const C_type alpha, \ + const C_type beta, \ + bool is_last_k, \ + dim_t post_op_c_i, \ + dim_t post_op_c_j, \ + lpgemm_post_op* post_ops_list \ + ) \ + +LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5x64); +LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4x64); +LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3x64); +LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2x64); +LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1x64); + +LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4x32); +LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2x32); +LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1x32); + +#define LPGEMM_N_FRINGE_KERN(A_type,B_type,C_type,LP_SFX) \ +void lpgemm_rowvar_ ## LP_SFX \ + ( \ + const dim_t m0, \ + const dim_t k0, \ + const A_type* a, \ + const dim_t rs_a, \ + const dim_t cs_a, \ + const dim_t ps_a, \ + const B_type* b, \ + const dim_t rs_b, \ + const dim_t cs_b, \ + C_type* c, \ + const dim_t rs_c, \ + const C_type alpha, \ + const C_type beta, \ + bool 
is_last_k, \ + dim_t post_op_c_i, \ + dim_t post_op_c_j, \ + lpgemm_post_op* post_ops_list \ + ) \ + +LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x16); +LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x32); +LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x48); + +LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x16); + +#define LPGEMM_N_LT_NR0_FRINGE_KERN(A_type,B_type,C_type,LP_SFX) \ +void lpgemm_rowvar_ ## LP_SFX \ + ( \ + const dim_t m0, \ + const dim_t k0, \ + const A_type* a, \ + const dim_t rs_a, \ + const dim_t cs_a, \ + const dim_t ps_a, \ + const B_type* b, \ + const dim_t rs_b, \ + const dim_t cs_b, \ + C_type* c, \ + const dim_t rs_c, \ + const C_type alpha, \ + const C_type beta, \ + const dim_t n0_rem, \ + bool is_last_k, \ + dim_t post_op_c_i, \ + dim_t post_op_c_j, \ + lpgemm_post_op* post_ops_list \ + ) \ + +LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6xlt16); + +LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6xlt16); + +#define LPGEMM_MN_FRINGE_KERN(A_type,B_type,C_type,LP_SFX) \ +void lpgemm_rowvar_ ## LP_SFX \ + ( \ + const dim_t k0, \ + const A_type* a, \ + const dim_t rs_a, \ + const dim_t cs_a, \ + const B_type* b, \ + const dim_t rs_b, \ + const dim_t cs_b, \ + C_type* c, \ + const dim_t rs_c, \ + const C_type alpha, \ + const C_type beta, \ + bool is_last_k, \ + dim_t post_op_c_i, \ + dim_t post_op_c_j, \ + lpgemm_post_op* post_ops_list \ + ) \ + +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5x16); +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4x16); +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3x16); +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2x16); +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1x16); +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5x32); +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4x32); +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3x32); +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2x32); +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1x32); +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5x48); +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4x48); +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3x48); +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2x48); +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1x48); + +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4x16); +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2x16); +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1x16); + +#define LPGEMM_MN_LT_NR0_FRINGE_KERN(A_type,B_type,C_type,LP_SFX) \ +void lpgemm_rowvar_ ## LP_SFX \ + ( \ + const dim_t k0, \ + const A_type* a, \ + const dim_t rs_a, \ + const dim_t cs_a, \ + const B_type* b, \ + const dim_t rs_b, \ + const dim_t cs_b, \ + C_type* c, \ + const dim_t rs_c, \ + const C_type alpha, \ + const C_type beta, \ + const dim_t n0_rem, \ + bool is_last_k, \ + dim_t post_op_c_i, \ + dim_t post_op_c_j, \ + lpgemm_post_op* post_ops_list \ + ) \ + +LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5xlt16); +LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4xlt16); +LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3xlt16); +LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2xlt16); +LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1xlt16); + +LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4xlt16); 
+LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2xlt16); +LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1xlt16); + +#endif //BLIS_LPGEMM_KERN_H diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_6x32rowmajor.h b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_6x32rowmajor.h deleted file mode 100644 index 2dd00f4494..0000000000 --- a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_6x32rowmajor.h +++ /dev/null @@ -1,64 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- -*/ - -#ifndef BLIS_GEMM_INT16_MNROW -#define BLIS_GEMM_INT16_MNROW - -#include "lpgemm_post_ops.h" - -// 6x32 int8o16 kernel -void lpgemm_rowvar_u8s8s16o16_6x32 - ( - const dim_t m0, - const dim_t n0, - const dim_t k0, - const uint8_t *a, - const dim_t rs_a, - const dim_t cs_a, - const dim_t ps_a, - const int8_t *b, - const dim_t rs_b, - const dim_t cs_b, - int16_t *c, - const dim_t rs_c, - const dim_t cs_c, - const int16_t alpha, - const int16_t beta, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op *post_ops_list - ); - -#endif // BLIS_GEMM_INT16_MNROW \ No newline at end of file diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_6x32rowmajor_amd256.c b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_6x32rowmajor_amd256.c index f2cbd7affe..64a2293041 100644 --- a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_6x32rowmajor_amd256.c +++ b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_6x32rowmajor_amd256.c @@ -33,33 +33,10 @@ */ #include #include "blis.h" -#include "lpgemm_6x32rowmajor.h" -#include "lpgemm_m_fringe_s16.h" -#include "lpgemm_n_fringe_s16.h" +#include "lpgemm_kernels.h" // 6x32 int8o16 kernel -void lpgemm_rowvar_u8s8s16o16_6x32 - ( - const dim_t m0, - const dim_t n0, - const dim_t k0, - const uint8_t *a, - const dim_t rs_a, - const dim_t cs_a, - const dim_t ps_a, - const int8_t *b, - const dim_t rs_b, - const dim_t cs_b, - int16_t *c, - const dim_t rs_c, - const dim_t cs_c, - const int16_t alpha, - const int16_t beta, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op *post_ops_list - ) +LPGEMM_MAIN_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x32) { static void *post_ops_labels[] = { @@ -594,4 +571,4 @@ void lpgemm_rowvar_u8s8s16o16_6x32 post_ops_list); } } -} \ No newline at end of file +} diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_m_fringe_amd256.c b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_m_fringe_amd256.c index 84a472b45a..c4071f0428 100644 --- a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_m_fringe_amd256.c +++ b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_m_fringe_amd256.c @@ -35,27 +35,10 @@ #include #include "blis.h" -#include "lpgemm_m_fringe_s16.h" +#include "lpgemm_kernels.h" // 4x32 int8o16 kernel -void lpgemm_rowvar_u8s8s16o16_4x32 - ( - const dim_t k0, - const uint8_t *a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t *b, - const dim_t rs_b, - const dim_t cs_b, - int16_t *c, - const dim_t rs_c, - const int16_t alpha, - const int16_t beta, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op *post_ops_list - ) +LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4x32) { dim_t NR = 32; @@ -364,24 +347,7 @@ void lpgemm_rowvar_u8s8s16o16_4x32 // 2x32 int8o16 kernel -void lpgemm_rowvar_u8s8s16o16_2x32 - ( - const dim_t k0, - const uint8_t *a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t *b, - const dim_t rs_b, - const dim_t cs_b, - int16_t *c, - const dim_t rs_c, - const int16_t alpha, - const int16_t beta, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op *post_ops_list - ) +LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2x32) { dim_t NR = 32; @@ -572,24 +538,7 @@ void lpgemm_rowvar_u8s8s16o16_2x32 } // 1x32 int8o16 kernel -void lpgemm_rowvar_u8s8s16o16_1x32 - ( - const dim_t k0, - const uint8_t *a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t *b, - const dim_t rs_b, - const dim_t cs_b, - int16_t *c, - const dim_t rs_c, - const int16_t alpha, - const int16_t beta, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op 
*post_ops_list - ) +LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1x32) { dim_t NR = 32; diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_m_fringe_s16.h b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_m_fringe_s16.h deleted file mode 100644 index da1930cbab..0000000000 --- a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_m_fringe_s16.h +++ /dev/null @@ -1,100 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- -*/ - -#ifndef BLIS_GEMM_INT16_MFRINGE -#define BLIS_GEMM_INT16_MFRINGE - -#include "lpgemm_post_ops.h" - -// 4x32 int8o16 kernel -void lpgemm_rowvar_u8s8s16o16_4x32 - ( - const dim_t k0, - const uint8_t *a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t *b, - const dim_t rs_b, - const dim_t cs_b, - int16_t *c, - const dim_t rs_c, - const int16_t alpha, - const int16_t beta, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op *post_ops_list - ); - -// 2x32 int8o16 kernel -void lpgemm_rowvar_u8s8s16o16_2x32 - ( - const dim_t k0, - const uint8_t *a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t *b, - const dim_t rs_b, - const dim_t cs_b, - int16_t *c, - const dim_t rs_c, - const int16_t alpha, - const int16_t beta, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op *post_ops_list - ); - -// 1x32 int8o16 kernel -void lpgemm_rowvar_u8s8s16o16_1x32 - ( - const dim_t k0, - const uint8_t *a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t *b, - const dim_t rs_b, - const dim_t cs_b, - int16_t *c, - const dim_t rs_c, - const int16_t alpha, - const int16_t beta, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op *post_ops_list - ); - -#endif // BLIS_GEMM_INT16_MFRINGE \ No newline at end of file diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_mn_fringe_amd256.c b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_mn_fringe_amd256.c index a8e8547cf4..a3c2ed4ead 100644 --- a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_mn_fringe_amd256.c +++ b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_mn_fringe_amd256.c @@ -35,27 +35,10 @@ #include #include "blis.h" -#include "lpgemm_mn_fringe_s16.h" +#include "lpgemm_kernels.h" // 4x32 int8o16 kernel -void lpgemm_rowvar_u8s8s16o16_4x16 - ( - const dim_t k0, - const uint8_t *a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t *b, - const dim_t rs_b, - const dim_t cs_b, - int16_t *c, - const dim_t rs_c, - const int16_t alpha, - const int16_t beta, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op *post_ops_list - ) +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4x16) { dim_t NR = 16; @@ -273,25 +256,7 @@ void lpgemm_rowvar_u8s8s16o16_4x16 } // 4x16 int8o16 kernel -void lpgemm_rowvar_u8s8s16o16_4xlt16 - ( - const dim_t k0, - const uint8_t *a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t *b, - const dim_t rs_b, - const dim_t cs_b, - int16_t *c, - const dim_t rs_c, - const int16_t alpha, - const int16_t beta, - dim_t n0_rem, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op *post_ops_list - ) +LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4xlt16) { dim_t NR = 16; @@ -536,24 +501,7 @@ void lpgemm_rowvar_u8s8s16o16_4xlt16 } // 2x16 int8o16 kernel -void lpgemm_rowvar_u8s8s16o16_2x16 - ( - const dim_t k0, - const uint8_t *a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t *b, - const dim_t rs_b, - const dim_t cs_b, - int16_t *c, - const dim_t rs_c, - const int16_t alpha, - const int16_t beta, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op *post_ops_list - ) +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2x16) { dim_t NR = 16; @@ -697,25 +645,7 @@ void lpgemm_rowvar_u8s8s16o16_2x16 } // 2xlt16 int8o16 kernel -void lpgemm_rowvar_u8s8s16o16_2xlt16 - ( - const dim_t k0, - const uint8_t *a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t *b, - const dim_t rs_b, - const dim_t cs_b, - int16_t *c, - const dim_t rs_c, - const int16_t alpha, - const 
int16_t beta, - dim_t n0_rem, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op *post_ops_list - ) +LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2xlt16) { dim_t NR = 16; @@ -874,24 +804,7 @@ void lpgemm_rowvar_u8s8s16o16_2xlt16 } // 1x16 int8o16 kernel -void lpgemm_rowvar_u8s8s16o16_1x16 - ( - const dim_t k0, - const uint8_t *a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t *b, - const dim_t rs_b, - const dim_t cs_b, - int16_t *c, - const dim_t rs_c, - const int16_t alpha, - const int16_t beta, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op *post_ops_list - ) +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1x16) { int NR = 16; @@ -997,25 +910,7 @@ void lpgemm_rowvar_u8s8s16o16_1x16 } // 1xlt16 int8o16 kernel -void lpgemm_rowvar_u8s8s16o16_1xlt16 - ( - const int k0, - const uint8_t *a, - const int rs_a, - const int cs_a, - const int8_t *b, - const int rs_b, - const int cs_b, - int16_t *c, - const int rs_c, - const int16_t alpha, - const int16_t beta, - dim_t n0_rem, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op *post_ops_list - ) +LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1xlt16) { int NR = 16; diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_mn_fringe_s16.h b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_mn_fringe_s16.h deleted file mode 100644 index bfffe5c336..0000000000 --- a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_mn_fringe_s16.h +++ /dev/null @@ -1,157 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- -*/ - -#ifndef BLIS_GEMM_INT16_MNFRINGE -#define BLIS_GEMM_INT16_MNFRINGE - -#include "lpgemm_post_ops.h" - -void lpgemm_rowvar_u8s8s16o16_4x16 - ( - const dim_t k0, - const uint8_t *a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t *b, - const dim_t rs_b, - const dim_t cs_b, - int16_t *c, - const dim_t rs_c, - const int16_t alpha, - const int16_t beta, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op *post_ops_list - ); - -void lpgemm_rowvar_u8s8s16o16_4xlt16 - ( - const dim_t k0, - const uint8_t *a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t *b, - const dim_t rs_b, - const dim_t cs_b, - int16_t *c, - const dim_t rs_c, - const int16_t alpha, - const int16_t beta, - dim_t n0_rem, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op *post_ops_list - ); - -void lpgemm_rowvar_u8s8s16o16_2x16 - ( - const dim_t k0, - const uint8_t *a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t *b, - const dim_t rs_b, - const dim_t cs_b, - int16_t *c, - const dim_t rs_c, - const int16_t alpha, - const int16_t beta, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op *post_ops_list - ); - -void lpgemm_rowvar_u8s8s16o16_2xlt16 - ( - const dim_t k0, - const uint8_t *a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t *b, - const dim_t rs_b, - const dim_t cs_b, - int16_t *c, - const dim_t rs_c, - const int16_t alpha, - const int16_t beta, - dim_t n0_rem, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op *post_ops_list - ); - -void lpgemm_rowvar_u8s8s16o16_1x16 - ( - const dim_t k0, - const uint8_t *a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t *b, - const dim_t rs_b, - const dim_t cs_b, - int16_t *c, - const dim_t rs_c, - const int16_t alpha, - const int16_t beta, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op *post_ops_list - ); - -void lpgemm_rowvar_u8s8s16o16_1xlt16 - ( - const int k0, - const uint8_t *a, - const int rs_a, - const int cs_a, - const int8_t *b, - const int rs_b, - const int cs_b, - int16_t *c, - const int rs_c, - const int16_t alpha, - const int16_t beta, - dim_t n0_rem, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op *post_ops_list - ); - -#endif // BLIS_GEMM_INT16_MNFRINGE \ No newline at end of file diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_n_fringe_amd256.c b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_n_fringe_amd256.c index 53ab412ded..7a34a636ac 100644 --- a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_n_fringe_amd256.c +++ b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_n_fringe_amd256.c @@ -35,30 +35,10 @@ #include #include "blis.h" -#include "lpgemm_n_fringe_s16.h" -#include "lpgemm_mn_fringe_s16.h" +#include "lpgemm_kernels.h" // 6x16 int8o16 kernel -void lpgemm_rowvar_u8s8s16o16_6x16 - ( - const dim_t m0, - const dim_t k0, - const uint8_t *a, - const dim_t rs_a, - const dim_t cs_a, - const dim_t ps_a, - const int8_t *b, - const dim_t rs_b, - const dim_t cs_b, - int16_t *c, - const dim_t rs_c, - const int16_t alpha, - const int16_t beta, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op *post_ops_list - ) +LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x16) { dim_t MR = 6; dim_t NR = 16; @@ -343,22 +323,22 @@ void lpgemm_rowvar_u8s8s16o16_6x16 // Store the results. 
// c[0,0-15] - _mm256_storeu_si256( (__m256i *)(c + ( rs_c * 0 ) + ( 0*16 )), c_int16_0p0 ); + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * ( ir + 0 ) ) + ( 0 * 16 ) ), c_int16_0p0 ); // c[1,0-15] - _mm256_storeu_si256( (__m256i *)(c + ( rs_c * 1 ) + ( 0*16 )), c_int16_1p0 ); + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * ( ir + 1 ) ) + ( 0 * 16 ) ), c_int16_1p0 ); // c[2,0-15] - _mm256_storeu_si256( (__m256i *)(c + ( rs_c * 2 ) + ( 0*16 )), c_int16_2p0 ); + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * ( ir + 2 ) ) + ( 0 * 16 ) ), c_int16_2p0 ); // c[3,0-15] - _mm256_storeu_si256( (__m256i *)(c + ( rs_c * 3 ) + ( 0*16 )), c_int16_3p0 ); + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * ( ir + 3 ) ) + ( 0 * 16 ) ), c_int16_3p0 ); // c[4,0-15] - _mm256_storeu_si256( (__m256i *)(c + ( rs_c * 4 ) + ( 0*16 )), c_int16_4p0 ); + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * ( ir + 4 ) ) + ( 0 * 16 ) ), c_int16_4p0 ); // c[5,0-15] - _mm256_storeu_si256( (__m256i *)(c + ( rs_c * 5 ) + ( 0*16 )), c_int16_5p0 ); + _mm256_storeu_si256( (__m256i *)(c + ( rs_c * ( ir + 5 ) ) + ( 0 * 16 ) ), c_int16_5p0 ); a = a + ( MR * ps_a ); post_op_c_i += MR; @@ -422,27 +402,7 @@ void lpgemm_rowvar_u8s8s16o16_6x16 } // 6xlt16 int8o16 kernel -void lpgemm_rowvar_u8s8s16o16_6xlt16 - ( - const dim_t m0, - const dim_t k0, - const uint8_t *a, - const dim_t rs_a, - const dim_t cs_a, - const dim_t ps_a, - const int8_t *b, - const dim_t rs_b, - const dim_t cs_b, - int16_t *c, - const dim_t rs_c, - const int16_t alpha, - const int16_t beta, - const dim_t n0_rem, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op *post_ops_list - ) +LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6xlt16) { dim_t MR = 6; diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_n_fringe_s16.h b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_n_fringe_s16.h deleted file mode 100644 index 9ca5578d87..0000000000 --- a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_n_fringe_s16.h +++ /dev/null @@ -1,84 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#ifndef BLIS_GEMM_INT16_NFRINGE -#define BLIS_GEMM_INT16_NFRINGE -#include "lpgemm_post_ops.h" - -// 6x16 int8o16 kernel -void lpgemm_rowvar_u8s8s16o16_6x16 - ( - const dim_t m0, - const dim_t k0, - const uint8_t *a, - const dim_t rs_a, - const dim_t cs_a, - const dim_t ps_a, - const int8_t *b, - const dim_t rs_b, - const dim_t cs_b, - int16_t *c, - const dim_t rs_c, - const int16_t alpha, - const int16_t beta, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op *post_ops_list - ); - -// 6xlt16 int8o16 kernel -void lpgemm_rowvar_u8s8s16o16_6xlt16 - ( - const dim_t m0, - const dim_t k0, - const uint8_t *a, - const dim_t rs_a, - const dim_t cs_a, - const dim_t ps_a, - const int8_t *b, - const dim_t rs_b, - const dim_t cs_b, - int16_t *c, - const dim_t rs_c, - const int16_t alpha, - const int16_t beta, - const dim_t n0_rem, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op *post_ops_list - ); - -#endif // BLIS_GEMM_INT16_NFRINGE \ No newline at end of file diff --git a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_6x64rowmajor.h b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_6x64rowmajor.h deleted file mode 100644 index 8373ba1b72..0000000000 --- a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_6x64rowmajor.h +++ /dev/null @@ -1,64 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- -*/ - -#ifndef BLIS_GEMM_INT8_MNROW -#define BLIS_GEMM_INT8_MNROW - -#include "lpgemm_post_ops.h" - -// 6x64 int8o32 kernel -void lpgemm_rowvar_u8s8s32o32_6x64 - ( - const dim_t m0, - const dim_t n0, - const dim_t k0, - const uint8_t* a, - const dim_t rs_a, - const dim_t cs_a, - const dim_t ps_a, - const int8_t* b, - const dim_t rs_b, - const dim_t cs_b, - int32_t* c, - const dim_t rs_c, - const dim_t cs_c, - const int32_t alpha, - const int32_t beta, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op* post_ops_list - ); - -#endif //BLIS_GEMM_INT8_MNROW diff --git a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_6x64rowmajor_amd512vnni.c b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_6x64rowmajor_amd512vnni.c index e611556171..109b4784fa 100644 --- a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_6x64rowmajor_amd512vnni.c +++ b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_6x64rowmajor_amd512vnni.c @@ -35,33 +35,10 @@ #include #include "blis.h" -#include "lpgemm_6x64rowmajor.h" -#include "lpgemm_n_fringe.h" -#include "lpgemm_m_fringe.h" +#include "lpgemm_kernels.h" // 6x64 int8o32 kernel -void lpgemm_rowvar_u8s8s32o32_6x64 - ( - const dim_t m0, - const dim_t n0, - const dim_t k0, - const uint8_t* a, - const dim_t rs_a, - const dim_t cs_a, - const dim_t ps_a, - const int8_t* b, - const dim_t rs_b, - const dim_t cs_b, - int32_t* c, - const dim_t rs_c, - const dim_t cs_c, - const int32_t alpha, - const int32_t beta, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op* post_ops_list - ) +LPGEMM_MAIN_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x64) { static void* post_ops_labels[] = { diff --git a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_m_fringe.h b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_m_fringe.h deleted file mode 100644 index e4cc3f763e..0000000000 --- a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_m_fringe.h +++ /dev/null @@ -1,140 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- -*/ - -#ifndef BLIS_GEMM_INT8_MFRINGE -#define BLIS_GEMM_INT8_MFRINGE - -#include "lpgemm_post_ops.h" - -// 5x64 int8o32 kernel -void lpgemm_rowvar_u8s8s32o32_5x64 - ( - const dim_t k0, - const uint8_t* a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t* b, - const dim_t rs_b, - const dim_t cs_b, - int32_t* c, - const dim_t rs_c, - const int32_t alpha, - const int32_t beta, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op* post_ops_list - ); - -// 4x64 int8o32 kernel -void lpgemm_rowvar_u8s8s32o32_4x64 - ( - const dim_t k0, - const uint8_t* a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t* b, - const dim_t rs_b, - const dim_t cs_b, - int32_t* c, - const dim_t rs_c, - const int32_t alpha, - const int32_t beta, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op* post_ops_list - ); - -// 3x64 int8o32 kernel -void lpgemm_rowvar_u8s8s32o32_3x64 - ( - const dim_t k0, - const uint8_t* a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t* b, - const dim_t rs_b, - const dim_t cs_b, - int32_t* c, - const dim_t rs_c, - const int32_t alpha, - const int32_t beta, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op* post_ops_list - ); - -// 2x64 int8o32 kernel -void lpgemm_rowvar_u8s8s32o32_2x64 - ( - const dim_t k0, - const uint8_t* a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t* b, - const dim_t rs_b, - const dim_t cs_b, - int32_t* c, - const dim_t rs_c, - const int32_t alpha, - const int32_t beta, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op* post_ops_list - ); - -// 1x64 int8o32 kernel -void lpgemm_rowvar_u8s8s32o32_1x64 - ( - const dim_t k0, - const uint8_t* a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t* b, - const dim_t rs_b, - const dim_t cs_b, - int32_t* c, - const dim_t rs_c, - const int32_t alpha, - const int32_t beta, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op* post_ops_list - ); - -#endif //BLIS_GEMM_INT8_MFRINGE diff --git a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_m_fringe_amd512vnni.c b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_m_fringe_amd512vnni.c index 6c8e79fa31..934260e7ed 100644 --- a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_m_fringe_amd512vnni.c +++ b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_m_fringe_amd512vnni.c @@ -36,27 +36,10 @@ #include #include "blis.h" -#include "lpgemm_m_fringe.h" +#include "lpgemm_kernels.h" // 5x64 int8o32 kernel -void lpgemm_rowvar_u8s8s32o32_5x64 - ( - const dim_t k0, - const uint8_t* a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t* b, - const dim_t rs_b, - const dim_t cs_b, - int32_t* c, - const dim_t rs_c, - const int32_t alpha, - const int32_t beta, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op* post_ops_list - ) +LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5x64) { static void* post_ops_labels[] = { @@ -607,24 +590,7 @@ void lpgemm_rowvar_u8s8s32o32_5x64 } // 4x64 int8o32 kernel -void lpgemm_rowvar_u8s8s32o32_4x64 - ( - const dim_t k0, - const uint8_t* a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t* b, - const dim_t rs_b, - const dim_t cs_b, - int32_t* c, - const dim_t rs_c, - const int32_t alpha, - const int32_t beta, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op* post_ops_list - ) +LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4x64) { static void* post_ops_labels[] = { @@ -1081,24 +1047,7 @@ void lpgemm_rowvar_u8s8s32o32_4x64 } // 3x64 int8o32 kernel -void 
lpgemm_rowvar_u8s8s32o32_3x64 - ( - const dim_t k0, - const uint8_t* a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t* b, - const dim_t rs_b, - const dim_t cs_b, - int32_t* c, - const dim_t rs_c, - const int32_t alpha, - const int32_t beta, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op* post_ops_list - ) +LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3x64) { static void* post_ops_labels[] = { @@ -1461,24 +1410,7 @@ void lpgemm_rowvar_u8s8s32o32_3x64 } // 2x64 int8o32 kernel -void lpgemm_rowvar_u8s8s32o32_2x64 - ( - const dim_t k0, - const uint8_t* a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t* b, - const dim_t rs_b, - const dim_t cs_b, - int32_t* c, - const dim_t rs_c, - const int32_t alpha, - const int32_t beta, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op* post_ops_list - ) +LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2x64) { static void* post_ops_labels[] = { @@ -1747,24 +1679,7 @@ void lpgemm_rowvar_u8s8s32o32_2x64 } // 1x64 int8o32 kernel -void lpgemm_rowvar_u8s8s32o32_1x64 - ( - const dim_t k0, - const uint8_t* a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t* b, - const dim_t rs_b, - const dim_t cs_b, - int32_t* c, - const dim_t rs_c, - const int32_t alpha, - const int32_t beta, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op* post_ops_list - ) +LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1x64) { static void* post_ops_labels[] = { diff --git a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_mn_fringe.h b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_mn_fringe.h deleted file mode 100644 index e49f543d98..0000000000 --- a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_mn_fringe.h +++ /dev/null @@ -1,445 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- -*/ - -#ifndef BLIS_GEMM_INT8_MNFRINGE -#define BLIS_GEMM_INT8_MNFRINGE - -#include "lpgemm_post_ops.h" - -// 5xlt16 int8o32 fringe kernel -void lpgemm_rowvar_u8s8s32o32_5xlt16 - ( - const dim_t k0, - const uint8_t* a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t* b, - const dim_t rs_b, - const dim_t cs_b, - int32_t* c, - const dim_t rs_c, - const int32_t alpha, - const int32_t beta, - const dim_t n0_rem, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op* post_ops_list - ); - -// 4xlt16 int8o32 fringe kernel -void lpgemm_rowvar_u8s8s32o32_4xlt16 - ( - const dim_t k0, - const uint8_t* a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t* b, - const dim_t rs_b, - const dim_t cs_b, - int32_t* c, - const dim_t rs_c, - const int32_t alpha, - const int32_t beta, - const dim_t n0_rem, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op* post_ops_list - ); - -// 3xlt16 int8o32 fringe kernel -void lpgemm_rowvar_u8s8s32o32_3xlt16 - ( - const dim_t k0, - const uint8_t* a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t* b, - const dim_t rs_b, - const dim_t cs_b, - int32_t* c, - const dim_t rs_c, - const int32_t alpha, - const int32_t beta, - const dim_t n0_rem, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op* post_ops_list - ); - -// 2xlt16 int8o32 fringe kernel -void lpgemm_rowvar_u8s8s32o32_2xlt16 - ( - const dim_t k0, - const uint8_t* a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t* b, - const dim_t rs_b, - const dim_t cs_b, - int32_t* c, - const dim_t rs_c, - const int32_t alpha, - const int32_t beta, - const dim_t n0_rem, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op* post_ops_list - ); - -// 1xlt16 int8o32 fringe kernel -void lpgemm_rowvar_u8s8s32o32_1xlt16 - ( - const dim_t k0, - const uint8_t* a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t* b, - const dim_t rs_b, - const dim_t cs_b, - int32_t* c, - const dim_t rs_c, - const int32_t alpha, - const int32_t beta, - const dim_t n0_rem, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op* post_ops_list - ); - -// 5x16 int8o32 kernel -void lpgemm_rowvar_u8s8s32o32_5x16 - ( - const dim_t k0, - const uint8_t* a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t* b, - const dim_t rs_b, - const dim_t cs_b, - int32_t* c, - const dim_t rs_c, - const int32_t alpha, - const int32_t beta, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op* post_ops_list - ); - -// 4x16 int8o32 kernel -void lpgemm_rowvar_u8s8s32o32_4x16 - ( - const dim_t k0, - const uint8_t* a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t* b, - const dim_t rs_b, - const dim_t cs_b, - int32_t* c, - const dim_t rs_c, - const int32_t alpha, - const int32_t beta, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op* post_ops_list - ); - -// 3x16 int8o32 kernel -void lpgemm_rowvar_u8s8s32o32_3x16 - ( - const dim_t k0, - const uint8_t* a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t* b, - const dim_t rs_b, - const dim_t cs_b, - int32_t* c, - const dim_t rs_c, - const int32_t alpha, - const int32_t beta, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op* post_ops_list - ); - -// 2x16 int8o32 kernel -void lpgemm_rowvar_u8s8s32o32_2x16 - ( - const dim_t k0, - const uint8_t* a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t* b, - const dim_t rs_b, - const dim_t cs_b, - int32_t* c, - const dim_t rs_c, - const 
int32_t alpha, - const int32_t beta, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op* post_ops_list - ); - -// 1x16 int8o32 kernel -void lpgemm_rowvar_u8s8s32o32_1x16 - ( - const dim_t k0, - const uint8_t* a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t* b, - const dim_t rs_b, - const dim_t cs_b, - int32_t* c, - const dim_t rs_c, - const int32_t alpha, - const int32_t beta, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op* post_ops_list - ); - -// 5x32 int8o32 kernel -void lpgemm_rowvar_u8s8s32o32_5x32 - ( - const dim_t k0, - const uint8_t* a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t* b, - const dim_t rs_b, - const dim_t cs_b, - int32_t* c, - const dim_t rs_c, - const int32_t alpha, - const int32_t beta, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op* post_ops_list - ); - -// 4x32 int8o32 kernel -void lpgemm_rowvar_u8s8s32o32_4x32 - ( - const dim_t k0, - const uint8_t* a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t* b, - const dim_t rs_b, - const dim_t cs_b, - int32_t* c, - const dim_t rs_c, - const int32_t alpha, - const int32_t beta, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op* post_ops_list - ); - -// 3x32 int8o32 kernel -void lpgemm_rowvar_u8s8s32o32_3x32 - ( - const dim_t k0, - const uint8_t* a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t* b, - const dim_t rs_b, - const dim_t cs_b, - int32_t* c, - const dim_t rs_c, - const int32_t alpha, - const int32_t beta, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op* post_ops_list - ); - -// 2x32 int8o32 kernel -void lpgemm_rowvar_u8s8s32o32_2x32 - ( - const dim_t k0, - const uint8_t* a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t* b, - const dim_t rs_b, - const dim_t cs_b, - int32_t* c, - const dim_t rs_c, - const int32_t alpha, - const int32_t beta, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op* post_ops_list - ); - -// 1x32 int8o32 kernel -void lpgemm_rowvar_u8s8s32o32_1x32 - ( - const dim_t k0, - const uint8_t* a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t* b, - const dim_t rs_b, - const dim_t cs_b, - int32_t* c, - const dim_t rs_c, - const int32_t alpha, - const int32_t beta, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op* post_ops_list - ); - -// 5x48 int8o32 kernel -void lpgemm_rowvar_u8s8s32o32_5x48 - ( - const dim_t k0, - const uint8_t* a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t* b, - const dim_t rs_b, - const dim_t cs_b, - int32_t* c, - const dim_t rs_c, - const int32_t alpha, - const int32_t beta, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op* post_ops_list - ); - -// 4x48 int8o32 kernel -void lpgemm_rowvar_u8s8s32o32_4x48 - ( - const dim_t k0, - const uint8_t* a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t* b, - const dim_t rs_b, - const dim_t cs_b, - int32_t* c, - const dim_t rs_c, - const int32_t alpha, - const int32_t beta, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op* post_ops_list - ); - -// 3x48 int8o32 kernel -void lpgemm_rowvar_u8s8s32o32_3x48 - ( - const dim_t k0, - const uint8_t* a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t* b, - const dim_t rs_b, - const dim_t cs_b, - int32_t* c, - const dim_t rs_c, - const int32_t alpha, - const int32_t beta, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op* post_ops_list - ); - -// 
2x48 int8o32 kernel -void lpgemm_rowvar_u8s8s32o32_2x48 - ( - const dim_t k0, - const uint8_t* a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t* b, - const dim_t rs_b, - const dim_t cs_b, - int32_t* c, - const dim_t rs_c, - const int32_t alpha, - const int32_t beta, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op* post_ops_list - ); - -// 1x48 int8o32 kernel -void lpgemm_rowvar_u8s8s32o32_1x48 - ( - const dim_t k0, - const uint8_t* a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t* b, - const dim_t rs_b, - const dim_t cs_b, - int32_t* c, - const dim_t rs_c, - const int32_t alpha, - const int32_t beta, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op* post_ops_list - ); - -#endif //BLIS_GEMM_INT8_MNFRINGE diff --git a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_mn_fringe_amd512vnni.c b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_mn_fringe_amd512vnni.c index 07e61f4ce3..c8642475f5 100644 --- a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_mn_fringe_amd512vnni.c +++ b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_mn_fringe_amd512vnni.c @@ -36,28 +36,10 @@ #include #include "blis.h" -#include "lpgemm_mn_fringe.h" +#include "lpgemm_kernels.h" // 5xlt16 int8o32 fringe kernel -void lpgemm_rowvar_u8s8s32o32_5xlt16 - ( - const dim_t k0, - const uint8_t* a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t* b, - const dim_t rs_b, - const dim_t cs_b, - int32_t* c, - const dim_t rs_c, - const int32_t alpha, - const int32_t beta, - const dim_t n0_rem, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op* post_ops_list - ) +LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5xlt16) { static void* post_ops_labels[] = { @@ -229,9 +211,9 @@ void lpgemm_rowvar_u8s8s32o32_5xlt16 POST_OP_LABEL_LASTK_SAFE_JUMP POST_OPS_BIAS_5xLT16: { - selector1 = - _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j ); + memcpy( buf0, ( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j ), ( n0_rem * sizeof( int32_t ) ) ); + selector1 = _mm512_loadu_epi32( buf0 ); // c[0,0-15] c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); @@ -309,25 +291,7 @@ void lpgemm_rowvar_u8s8s32o32_5xlt16 } // 4xlt16 int8o32 fringe kernel -void lpgemm_rowvar_u8s8s32o32_4xlt16 - ( - const dim_t k0, - const uint8_t* a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t* b, - const dim_t rs_b, - const dim_t cs_b, - int32_t* c, - const dim_t rs_c, - const int32_t alpha, - const int32_t beta, - const dim_t n0_rem, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op* post_ops_list - ) +LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4xlt16) { static void* post_ops_labels[] = { @@ -474,9 +438,9 @@ void lpgemm_rowvar_u8s8s32o32_4xlt16 POST_OP_LABEL_LASTK_SAFE_JUMP POST_OPS_BIAS_4xLT16: { - selector1 = - _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j ); + memcpy( buf0, ( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j ), ( n0_rem * sizeof( int32_t ) ) ); + selector1 = _mm512_loadu_epi32( buf0 ); // c[0,0-15] c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); @@ -542,25 +506,7 @@ void lpgemm_rowvar_u8s8s32o32_4xlt16 } // 3xlt16 int8o32 fringe kernel -void lpgemm_rowvar_u8s8s32o32_3xlt16 - ( - const dim_t k0, - const uint8_t* a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t* b, - const dim_t rs_b, - const dim_t cs_b, - int32_t* c, - const dim_t rs_c, - const int32_t alpha, - const int32_t beta, - const dim_t n0_rem, - bool is_last_k, 
- dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op* post_ops_list - ) +LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3xlt16) { static void* post_ops_labels[] = { @@ -680,9 +626,9 @@ void lpgemm_rowvar_u8s8s32o32_3xlt16 POST_OP_LABEL_LASTK_SAFE_JUMP POST_OPS_BIAS_3xLT16: { - selector1 = - _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j ); + memcpy( buf0, ( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j ), ( n0_rem * sizeof( int32_t ) ) ); + selector1 = _mm512_loadu_epi32( buf0 ); // c[0,0-15] c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); @@ -736,25 +682,7 @@ void lpgemm_rowvar_u8s8s32o32_3xlt16 } // 2xlt16 int8o32 fringe kernel -void lpgemm_rowvar_u8s8s32o32_2xlt16 - ( - const dim_t k0, - const uint8_t* a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t* b, - const dim_t rs_b, - const dim_t cs_b, - int32_t* c, - const dim_t rs_c, - const int32_t alpha, - const int32_t beta, - const dim_t n0_rem, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op* post_ops_list - ) +LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2xlt16) { static void* post_ops_labels[] = { @@ -848,9 +776,9 @@ void lpgemm_rowvar_u8s8s32o32_2xlt16 POST_OP_LABEL_LASTK_SAFE_JUMP POST_OPS_BIAS_2xLT16: { - selector1 = - _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j ); + memcpy( buf0, ( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j ), ( n0_rem * sizeof( int32_t ) ) ); + selector1 = _mm512_loadu_epi32( buf0 ); // c[0,0-15] c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); @@ -892,25 +820,7 @@ void lpgemm_rowvar_u8s8s32o32_2xlt16 } // 1xlt16 int8o32 fringe kernel -void lpgemm_rowvar_u8s8s32o32_1xlt16 - ( - const dim_t k0, - const uint8_t* a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t* b, - const dim_t rs_b, - const dim_t cs_b, - int32_t* c, - const dim_t rs_c, - const int32_t alpha, - const int32_t beta, - const dim_t n0_rem, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op* post_ops_list - ) +LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1xlt16) { static void* post_ops_labels[] = { @@ -978,9 +888,9 @@ void lpgemm_rowvar_u8s8s32o32_1xlt16 POST_OP_LABEL_LASTK_SAFE_JUMP POST_OPS_BIAS_1xLT16: { - selector1 = - _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j ); + memcpy( buf0, ( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j ), ( n0_rem * sizeof( int32_t ) ) ); + selector1 = _mm512_loadu_epi32( buf0 ); // c[0,0-15] c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); @@ -1010,24 +920,7 @@ void lpgemm_rowvar_u8s8s32o32_1xlt16 } // 5x16 int8o32 kernel -void lpgemm_rowvar_u8s8s32o32_5x16 - ( - const dim_t k0, - const uint8_t* a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t* b, - const dim_t rs_b, - const dim_t cs_b, - int32_t* c, - const dim_t rs_c, - const int32_t alpha, - const int32_t beta, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op* post_ops_list - ) +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5x16) { static void* post_ops_labels[] = { @@ -1248,24 +1141,7 @@ void lpgemm_rowvar_u8s8s32o32_5x16 } // 4x16 int8o32 kernel -void lpgemm_rowvar_u8s8s32o32_4x16 - ( - const dim_t k0, - const uint8_t* a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t* b, - const dim_t rs_b, - const dim_t cs_b, - int32_t* c, - const dim_t rs_c, - const int32_t alpha, - const int32_t beta, - bool is_last_k, - dim_t post_op_c_i, - 
dim_t post_op_c_j, - lpgemm_post_op* post_ops_list - ) +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4x16) { static void* post_ops_labels[] = { @@ -1453,24 +1329,7 @@ void lpgemm_rowvar_u8s8s32o32_4x16 } // 3x16 int8o32 kernel -void lpgemm_rowvar_u8s8s32o32_3x16 - ( - const dim_t k0, - const uint8_t* a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t* b, - const dim_t rs_b, - const dim_t cs_b, - int32_t* c, - const dim_t rs_c, - const int32_t alpha, - const int32_t beta, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op* post_ops_list - ) +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3x16) { static void* post_ops_labels[] = { @@ -1625,24 +1484,7 @@ void lpgemm_rowvar_u8s8s32o32_3x16 } // 2x16 int8o32 kernel -void lpgemm_rowvar_u8s8s32o32_2x16 - ( - const dim_t k0, - const uint8_t* a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t* b, - const dim_t rs_b, - const dim_t cs_b, - int32_t* c, - const dim_t rs_c, - const int32_t alpha, - const int32_t beta, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op* post_ops_list - ) +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2x16) { static void* post_ops_labels[] = { @@ -1764,24 +1606,7 @@ void lpgemm_rowvar_u8s8s32o32_2x16 } // 1x16 int8o32 kernel -void lpgemm_rowvar_u8s8s32o32_1x16 - ( - const dim_t k0, - const uint8_t* a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t* b, - const dim_t rs_b, - const dim_t cs_b, - int32_t* c, - const dim_t rs_c, - const int32_t alpha, - const int32_t beta, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op* post_ops_list - ) +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1x16) { static void* post_ops_labels[] = { @@ -1870,24 +1695,7 @@ void lpgemm_rowvar_u8s8s32o32_1x16 } // 5x32 int8o32 kernel -void lpgemm_rowvar_u8s8s32o32_5x32 - ( - const dim_t k0, - const uint8_t* a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t* b, - const dim_t rs_b, - const dim_t cs_b, - int32_t* c, - const dim_t rs_c, - const int32_t alpha, - const int32_t beta, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op* post_ops_list - ) +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5x32) { static void* post_ops_labels[] = { @@ -2210,24 +2018,7 @@ void lpgemm_rowvar_u8s8s32o32_5x32 } // 4x32 int8o32 kernel -void lpgemm_rowvar_u8s8s32o32_4x32 - ( - const dim_t k0, - const uint8_t* a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t* b, - const dim_t rs_b, - const dim_t cs_b, - int32_t* c, - const dim_t rs_c, - const int32_t alpha, - const int32_t beta, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op* post_ops_list - ) +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4x32) { static void* post_ops_labels[] = { @@ -2499,24 +2290,7 @@ void lpgemm_rowvar_u8s8s32o32_4x32 } // 3x32 int8o32 kernel -void lpgemm_rowvar_u8s8s32o32_3x32 - ( - const dim_t k0, - const uint8_t* a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t* b, - const dim_t rs_b, - const dim_t cs_b, - int32_t* c, - const dim_t rs_c, - const int32_t alpha, - const int32_t beta, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op* post_ops_list - ) +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3x32) { static void* post_ops_labels[] = { @@ -2737,24 +2511,7 @@ void lpgemm_rowvar_u8s8s32o32_3x32 } // 2x32 int8o32 kernel -void lpgemm_rowvar_u8s8s32o32_2x32 - ( - const dim_t k0, - const uint8_t* a, - const dim_t rs_a, - const 
dim_t cs_a, - const int8_t* b, - const dim_t rs_b, - const dim_t cs_b, - int32_t* c, - const dim_t rs_c, - const int32_t alpha, - const int32_t beta, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op* post_ops_list - ) +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2x32) { static void* post_ops_labels[] = { @@ -2924,24 +2681,7 @@ void lpgemm_rowvar_u8s8s32o32_2x32 } // 1x32 int8o32 kernel -void lpgemm_rowvar_u8s8s32o32_1x32 - ( - const dim_t k0, - const uint8_t* a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t* b, - const dim_t rs_b, - const dim_t cs_b, - int32_t* c, - const dim_t rs_c, - const int32_t alpha, - const int32_t beta, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op* post_ops_list - ) +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1x32) { static void* post_ops_labels[] = { @@ -3060,24 +2800,7 @@ void lpgemm_rowvar_u8s8s32o32_1x32 } // 5x48 int8o32 kernel -void lpgemm_rowvar_u8s8s32o32_5x48 - ( - const dim_t k0, - const uint8_t* a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t* b, - const dim_t rs_b, - const dim_t cs_b, - int32_t* c, - const dim_t rs_c, - const int32_t alpha, - const int32_t beta, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op* post_ops_list - ) +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5x48) { static void* post_ops_labels[] = { @@ -3496,24 +3219,7 @@ void lpgemm_rowvar_u8s8s32o32_5x48 } // 4x48 int8o32 kernel -void lpgemm_rowvar_u8s8s32o32_4x48 - ( - const dim_t k0, - const uint8_t* a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t* b, - const dim_t rs_b, - const dim_t cs_b, - int32_t* c, - const dim_t rs_c, - const int32_t alpha, - const int32_t beta, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op* post_ops_list - ) +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4x48) { static void* post_ops_labels[] = { @@ -3863,24 +3569,7 @@ void lpgemm_rowvar_u8s8s32o32_4x48 } // 3x48 int8o32 kernel -void lpgemm_rowvar_u8s8s32o32_3x48 - ( - const dim_t k0, - const uint8_t* a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t* b, - const dim_t rs_b, - const dim_t cs_b, - int32_t* c, - const dim_t rs_c, - const int32_t alpha, - const int32_t beta, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op* post_ops_list - ) +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3x48) { static void* post_ops_labels[] = { @@ -4161,24 +3850,7 @@ void lpgemm_rowvar_u8s8s32o32_3x48 } // 2x48 int8o32 kernel -void lpgemm_rowvar_u8s8s32o32_2x48 - ( - const dim_t k0, - const uint8_t* a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t* b, - const dim_t rs_b, - const dim_t cs_b, - int32_t* c, - const dim_t rs_c, - const int32_t alpha, - const int32_t beta, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op* post_ops_list - ) +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2x48) { static void* post_ops_labels[] = { @@ -4390,24 +4062,7 @@ void lpgemm_rowvar_u8s8s32o32_2x48 } // 1x48 int8o32 kernel -void lpgemm_rowvar_u8s8s32o32_1x48 - ( - const dim_t k0, - const uint8_t* a, - const dim_t rs_a, - const dim_t cs_a, - const int8_t* b, - const dim_t rs_b, - const dim_t cs_b, - int32_t* c, - const dim_t rs_c, - const int32_t alpha, - const int32_t beta, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op* post_ops_list - ) +LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1x48) { static void* post_ops_labels[] = 
{ diff --git a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_n_fringe.h b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_n_fringe.h deleted file mode 100644 index 84e9ee5b7b..0000000000 --- a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_n_fringe.h +++ /dev/null @@ -1,129 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- -*/ - -#ifndef BLIS_GEMM_INT8_NFRINGE -#define BLIS_GEMM_INT8_NFRINGE - -#include "lpgemm_post_ops.h" - -// 6xlt16 int8o32 fringe kernel -void lpgemm_rowvar_u8s8s32o32_6xlt16 - ( - const dim_t m0, - const dim_t k0, - const uint8_t* a, - const dim_t rs_a, - const dim_t cs_a, - const dim_t ps_a, - const int8_t* b, - const dim_t rs_b, - const dim_t cs_b, - int32_t* c, - const dim_t rs_c, - const int32_t alpha, - const int32_t beta, - const dim_t n0_rem, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op* post_ops_list - ); - -// 6x16 int8o32 fringe kernel -void lpgemm_rowvar_u8s8s32o32_6x16 - ( - const dim_t m0, - const dim_t k0, - const uint8_t* a, - const dim_t rs_a, - const dim_t cs_a, - const dim_t ps_a, - const int8_t* b, - const dim_t rs_b, - const dim_t cs_b, - int32_t* c, - const dim_t rs_c, - const int32_t alpha, - const int32_t beta, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op* post_ops_list - ); - -// 6x32 int8o32 fringe kernel -void lpgemm_rowvar_u8s8s32o32_6x32 - ( - const dim_t m0, - const dim_t k0, - const uint8_t* a, - const dim_t rs_a, - const dim_t cs_a, - const dim_t ps_a, - const int8_t* b, - const dim_t rs_b, - const dim_t cs_b, - int32_t* c, - const dim_t rs_c, - const int32_t alpha, - const int32_t beta, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op* post_ops_list - ); - -// 6x48 int8o32 fringe kernel -void lpgemm_rowvar_u8s8s32o32_6x48 - ( - const dim_t m0, - const dim_t k0, - const uint8_t* a, - const dim_t rs_a, - const dim_t cs_a, - const dim_t ps_a, - const int8_t* b, - const dim_t rs_b, - const dim_t cs_b, - int32_t* c, - const dim_t rs_c, - const int32_t alpha, - const int32_t beta, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op* post_ops_list - ); - -#endif //BLIS_GEMM_INT8_NFRINGE diff --git a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_n_fringe_amd512vnni.c b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_n_fringe_amd512vnni.c index 70369f73ed..090a120232 100644 --- a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_n_fringe_amd512vnni.c +++ b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_n_fringe_amd512vnni.c @@ -36,31 +36,10 @@ #include #include "blis.h" -#include "lpgemm_n_fringe.h" -#include "lpgemm_mn_fringe.h" +#include "lpgemm_kernels.h" // 6xlt16 int8o32 fringe kernel -void lpgemm_rowvar_u8s8s32o32_6xlt16 - ( - const dim_t m0, - const dim_t k0, - const uint8_t* a, - const dim_t rs_a, - const dim_t cs_a, - const dim_t ps_a, - const int8_t* b, - const dim_t rs_b, - const dim_t cs_b, - int32_t* c, - const dim_t rs_c, - const int32_t alpha, - const int32_t beta, - const dim_t n0_rem, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op* post_ops_list - ) +LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6xlt16) { static void* post_ops_labels[] = { @@ -305,9 +284,9 @@ void lpgemm_rowvar_u8s8s32o32_6xlt16 POST_OP_LABEL_LASTK_SAFE_JUMP POST_OPS_BIAS_6xLT16: { - selector1 = - _mm512_loadu_epi32( ( int32_t* )post_ops_list_temp->op_args1 + - post_op_c_j ); + memcpy( buf0, ( ( int32_t* )post_ops_list_temp->op_args1 + + post_op_c_j ), ( n0_rem * sizeof( int32_t ) ) ); + selector1 = _mm512_loadu_epi32( buf0 ); // c[0,0-15] c_int32_0p0 = _mm512_add_epi32( selector1, c_int32_0p0 ); @@ -479,26 +458,7 @@ void lpgemm_rowvar_u8s8s32o32_6xlt16 } // 6x16 int8o32 fringe kernel -void lpgemm_rowvar_u8s8s32o32_6x16 - ( - const dim_t m0, - const dim_t k0, - const uint8_t* a, - const dim_t rs_a, - const dim_t cs_a, - const dim_t ps_a, - const int8_t* 
b, - const dim_t rs_b, - const dim_t cs_b, - int32_t* c, - const dim_t rs_c, - const int32_t alpha, - const int32_t beta, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op* post_ops_list - ) +LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x16) { static void* post_ops_labels[] = { @@ -883,26 +843,7 @@ void lpgemm_rowvar_u8s8s32o32_6x16 } // 6x32 int8o32 fringe kernel -void lpgemm_rowvar_u8s8s32o32_6x32 - ( - const dim_t m0, - const dim_t k0, - const uint8_t* a, - const dim_t rs_a, - const dim_t cs_a, - const dim_t ps_a, - const int8_t* b, - const dim_t rs_b, - const dim_t cs_b, - int32_t* c, - const dim_t rs_c, - const int32_t alpha, - const int32_t beta, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op* post_ops_list - ) +LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x32) { static void* post_ops_labels[] = { @@ -1401,26 +1342,7 @@ void lpgemm_rowvar_u8s8s32o32_6x32 } // 6x48 int8o32 fringe kernel -void lpgemm_rowvar_u8s8s32o32_6x48 - ( - const dim_t m0, - const dim_t k0, - const uint8_t* a, - const dim_t rs_a, - const dim_t cs_a, - const dim_t ps_a, - const int8_t* b, - const dim_t rs_b, - const dim_t cs_b, - int32_t* c, - const dim_t rs_c, - const int32_t alpha, - const int32_t beta, - bool is_last_k, - dim_t post_op_c_i, - dim_t post_op_c_j, - lpgemm_post_op* post_ops_list - ) +LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x48) { static void* post_ops_labels[] = { diff --git a/bench/bench_aocl_gemm/bench_lpgemm.c b/bench/bench_aocl_gemm/bench_lpgemm.c index 90a76e6868..1bf9587635 100644 --- a/bench/bench_aocl_gemm/bench_lpgemm.c +++ b/bench/bench_aocl_gemm/bench_lpgemm.c @@ -68,6 +68,7 @@ void mat_mul_ ## BLAS_SFX \ aocl_post_op* post_op\ ) \ { \ + char storage = 'r'; \ char transa = 'n'; \ char transb = 'n'; \ char reordera = 'n'; \ @@ -86,7 +87,7 @@ void mat_mul_ ## BLAS_SFX \ reorderb = 'r'; \ } \ \ - aocl_gemm_ ## BLAS_SFX( transa, transb, m, n, k, \ + aocl_gemm_ ## BLAS_SFX( storage, transa, transb, m, n, k, \ alpha, \ a, lda, reordera, \ b, ldb, reorderb, \ From 88e44c64e330d77af8f33d541537e23857ea8fb6 Mon Sep 17 00:00:00 2001 From: satish kumar nuggu Date: Thu, 11 Aug 2022 14:44:16 +0530 Subject: [PATCH 175/243] Fixed Memory Leaks in TRSM 1. Fixed the memory leaks in corner cases caused by extra loads, in all datatypes (s,d,c,z). 2. In the remainder cases, extra elements were loaded instead of only the required number, which led to memory leaks. Fixed the memory leaks by restricting the number of loads to the required number of elements.
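For illustration, a minimal sketch of the partial-load pattern used by the remainder paths in this patch follows. It is not the kernel code itself: the helper name load_upto3_pd, the int element count, and the zero-filled unused lanes are illustrative assumptions; the kernels inline the same intrinsics directly through macros such as BLIS_PRE_DTRSM_SMALL_6x3/6x2/6x1.

#include <immintrin.h>

/* Sketch: load exactly n (1 <= n <= 3) doubles starting at b into a
 * __m256d without reading b[n] or anything beyond it. */
static inline __m256d load_upto3_pd( const double* b, int n )
{
    __m256d v = _mm256_setzero_pd();
    __m128d lo;

    switch ( n )
    {
        case 1: /* touches only b[0] */
            v = _mm256_broadcast_sd( b );
            break;
        case 2: /* one 128-bit load covers b[0..1] */
            lo = _mm_loadu_pd( b );
            v  = _mm256_insertf128_pd( v, lo, 0 );
            break;
        case 3: /* broadcast b[2], then insert b[0..1]; b[3] is never read */
            v  = _mm256_broadcast_sd( b + 2 );
            lo = _mm_loadu_pd( b );
            v  = _mm256_insertf128_pd( v, lo, 0 );
            break;
        default:
            break;
    }
    return v;
}

The write-back side of the remainder cases is restricted in the same way (for example _mm_storel_pi and _mm_store_ss in place of a full _mm256_storeu_ps), so neither the loads nor the stores of a remainder iteration access memory beyond the valid elements.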
AMD-Internal: [CPUPL-2280] Change-Id: Ia49a02565e01d5ed05e98090b7773a444587cd8a --- kernels/zen/3/bli_trsm_small.c | 1259 +++++++++++++++++++++++--------- 1 file changed, 925 insertions(+), 334 deletions(-) diff --git a/kernels/zen/3/bli_trsm_small.c b/kernels/zen/3/bli_trsm_small.c index 5b6df35d77..168fe48d7d 100644 --- a/kernels/zen/3/bli_trsm_small.c +++ b/kernels/zen/3/bli_trsm_small.c @@ -1682,6 +1682,99 @@ BLIS_INLINE err_t dtrsm_XAltB_ref ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5));\ ymm13 = _mm256_fmsub_pd(ymm0, ymm15, ymm13); +#define BLIS_PRE_DTRSM_SMALL_6x3(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); /*register to hold alpha*/\ +\ + ymm0 = _mm256_broadcast_sd ((double const *)(b11 + 2));\ + xmm5 = _mm_loadu_pd((double const *)(b11));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ +\ + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3);\ +\ + ymm0 = _mm256_broadcast_sd ((double const *)(b11 + 2 + cs_b));\ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ +\ + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5);\ +\ + ymm0 = _mm256_broadcast_sd ((double const *)(b11 + 2 + cs_b*2));\ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b*2));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ +\ + ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7);\ +\ + ymm0 = _mm256_broadcast_sd ((double const *)(b11 + 2 + cs_b*3));\ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b*3));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ +\ + ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9);\ +\ + ymm0 = _mm256_broadcast_sd ((double const *)(b11 + 2 + cs_b*4));\ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b*4));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ +\ + ymm11 = _mm256_fmsub_pd(ymm0, ymm15, ymm11);\ +\ + ymm0 = _mm256_broadcast_sd ((double const *)(b11 + 2 + cs_b*5));\ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b*5));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ +\ + ymm13 = _mm256_fmsub_pd(ymm0, ymm15, ymm13); + +#define BLIS_PRE_DTRSM_SMALL_6x2(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); /*register to hold alpha*/\ +\ + xmm5 = _mm_loadu_pd((double const *)(b11));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ +\ + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3);\ +\ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ +\ + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5);\ +\ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b*2));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ +\ + ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7);\ +\ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b*3));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ +\ + ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9);\ +\ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b*4));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ +\ + ymm11 = _mm256_fmsub_pd(ymm0, ymm15, ymm11);\ +\ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b*5));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ +\ + ymm13 = _mm256_fmsub_pd(ymm0, ymm15, ymm13); + +#define BLIS_PRE_DTRSM_SMALL_6x1(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); /*register to hold alpha*/\ +\ + ymm0 = _mm256_broadcast_sd ((double const *)(b11));\ + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3);\ +\ + ymm0 = _mm256_broadcast_sd ((double const *)(b11 + cs_b));\ + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5);\ +\ + ymm0 = _mm256_broadcast_sd ((double const *)(b11 + cs_b*2));\ + ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7);\ +\ + 
ymm0 = _mm256_broadcast_sd ((double const *)(b11 + cs_b*3));\ + ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9);\ +\ + ymm0 = _mm256_broadcast_sd ((double const *)(b11 + cs_b*4));\ + ymm11 = _mm256_fmsub_pd(ymm0, ymm15, ymm11);\ +\ + ymm0 = _mm256_broadcast_sd ((double const *)(b11 + cs_b*5));\ + ymm13 = _mm256_fmsub_pd(ymm0, ymm15, ymm13); + #ifdef BLIS_DISABLE_TRSM_PREINVERSION #define STRSM_SMALL_DIV_OR_SCALE _mm256_div_ps #endif @@ -1936,6 +2029,22 @@ BLIS_INLINE err_t dtrsm_XAltB_ref b10 += cs_b;\ } +#define BLIS_STRSM_SMALL_GEMM_1nx5m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 5x1 block of B10*/\ + ymm0 = _mm256_broadcast_ss((float const *)(b10 + 4));\ + xmm5 = _mm_loadu_ps((float const *)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + /*GEMM block used in strsm small left cases*/ #define BLIS_STRSM_SMALL_GEMM_16mx6n(a10,b01,cs_b,p_lda,k_iter) \ float *b01_prefetch = b01 + 8; \ @@ -3228,6 +3337,280 @@ BLIS_INLINE err_t dtrsm_XAltB_ref ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5));\ ymm13 = _mm256_fmsub_ps(ymm0, ymm15, ymm13); +#define BLIS_PRE_STRSM_SMALL_6x7(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); /*register to hold alpha*/\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + xmm5 = _mm_broadcast_ss((float *)(b11 + 6));\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + 4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ + \ + ymm3 = _mm256_fmsub_ps(ymm0, ymm15, ymm3);\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + xmm5 = _mm_broadcast_ss((float *)(b11 + cs_b + 6));\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + cs_b + 4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ + \ + ymm5 = _mm256_fmsub_ps(ymm0, ymm15, ymm5);\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b*2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + xmm5 = _mm_broadcast_ss((float *)(b11 + cs_b*2 + 6));\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + cs_b*2 + 4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ + \ + ymm7 = _mm256_fmsub_ps(ymm0, ymm15, ymm7);\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b*3));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + xmm5 = _mm_broadcast_ss((float *)(b11 + cs_b*3 + 6));\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + cs_b*3 + 4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ + \ + ymm9 = _mm256_fmsub_ps(ymm0, ymm15, ymm9);\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b*4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + xmm5 = _mm_broadcast_ss((float *)(b11 + cs_b*4 + 6));\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + cs_b*4 + 4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ + \ + ymm11 = _mm256_fmsub_ps(ymm0, ymm15, ymm11);\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b*5));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + xmm5 = _mm_broadcast_ss((float *)(b11 + cs_b*5 + 6));\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + cs_b*5 + 4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ + \ + ymm13 = _mm256_fmsub_ps(ymm0, ymm15, ymm13); + +#define BLIS_PRE_STRSM_SMALL_6x6(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); /*register to hold alpha*/\ +\ + xmm5 = _mm_loadu_ps((float const 
*)(b11));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + 4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ +\ + ymm3 = _mm256_fmsub_ps(ymm0, ymm15, ymm3);\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + 4 + cs_b));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ + ymm5 = _mm256_fmsub_ps(ymm0, ymm15, ymm5);\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b*2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + 4 + cs_b*2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ + ymm7 = _mm256_fmsub_ps(ymm0, ymm15, ymm7);\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b*3));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + 4 + cs_b*3));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ + ymm9 = _mm256_fmsub_ps(ymm0, ymm15, ymm9);\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b*4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + 4 + cs_b*4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ + ymm11 = _mm256_fmsub_ps(ymm0, ymm15, ymm11);\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b*5));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + 4 + cs_b*5));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ + ymm13 = _mm256_fmsub_ps(ymm0, ymm15, ymm13); + +#define BLIS_PRE_STRSM_SMALL_6x5(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); /*register to hold alpha*/\ +\ + ymm0 = _mm256_broadcast_ss((float const *)(b11 + 4));\ + xmm5 = _mm_loadu_ps((float const *)(b11));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm3 = _mm256_fmsub_ps(ymm0, ymm15, ymm3);\ +\ + ymm0 = _mm256_broadcast_ss((float const *)(b11 + 4 + cs_b));\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm5 = _mm256_fmsub_ps(ymm0, ymm15, ymm5);\ +\ + ymm0 = _mm256_broadcast_ss((float const *)(b11 + 4 + cs_b*2));\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b*2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm7 = _mm256_fmsub_ps(ymm0, ymm15, ymm7);\ +\ + ymm0 = _mm256_broadcast_ss((float const *)(b11 + 4 + cs_b*3));\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b*3));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm9 = _mm256_fmsub_ps(ymm0, ymm15, ymm9);\ +\ + ymm0 = _mm256_broadcast_ss((float const *)(b11 + 4 + cs_b*4));\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b*4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm11 = _mm256_fmsub_ps(ymm0, ymm15, ymm11);\ +\ + ymm0 = _mm256_broadcast_ss((float const *)(b11 + 4 + cs_b*5));\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b*5));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm13 = _mm256_fmsub_ps(ymm0, ymm15, ymm13); + +#define BLIS_PRE_STRSM_SMALL_6x4(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); /*register to hold alpha*/\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm3 = _mm256_fmsub_ps(ymm0, ymm15, ymm3);\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm5 = _mm256_fmsub_ps(ymm0, ymm15, ymm5);\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b*2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm7 = _mm256_fmsub_ps(ymm0, ymm15, ymm7);\ +\ + xmm5 = 
_mm_loadu_ps((float const *)(b11 + cs_b*3));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm9 = _mm256_fmsub_ps(ymm0, ymm15, ymm9);\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b*4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm11 = _mm256_fmsub_ps(ymm0, ymm15, ymm11);\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b*5));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm13 = _mm256_fmsub_ps(ymm0, ymm15, ymm13); + +#define BLIS_PRE_STRSM_SMALL_6x3(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); /*register to hold alpha*/\ +\ + xmm5 = _mm_broadcast_ss((float *)(b11 + 2));\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm3 = _mm256_fmsub_ps(ymm0, ymm15, ymm3);\ +\ + xmm5 = _mm_broadcast_ss((float *)(b11 + 2 + cs_b));\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + cs_b));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm5 = _mm256_fmsub_ps(ymm0, ymm15, ymm5);\ +\ + xmm5 = _mm_broadcast_ss((float *)(b11 + 2 + cs_b*2));\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + cs_b*2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm7 = _mm256_fmsub_ps(ymm0, ymm15, ymm7);\ +\ + xmm5 = _mm_broadcast_ss((float *)(b11 + 2 + cs_b*3));\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + cs_b*3));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm9 = _mm256_fmsub_ps(ymm0, ymm15, ymm9);\ +\ + xmm5 = _mm_broadcast_ss((float *)(b11 + 2 + cs_b*4));\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + cs_b*4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm11 = _mm256_fmsub_ps(ymm0, ymm15, ymm11);\ +\ + xmm5 = _mm_broadcast_ss((float *)(b11 + 2 + cs_b*5));\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + cs_b*5));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm13 = _mm256_fmsub_ps(ymm0, ymm15, ymm13); + +#define BLIS_PRE_STRSM_SMALL_6x2(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); /*register to hold alpha*/\ +\ + xmm5 = _mm_broadcast_ss((float *)&zero);\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm3 = _mm256_fmsub_ps(ymm0, ymm15, ymm3);\ +\ + xmm5 = _mm_broadcast_ss((float *)&zero);\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + cs_b));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm5 = _mm256_fmsub_ps(ymm0, ymm15, ymm5);\ +\ + xmm5 = _mm_broadcast_ss((float *)&zero);\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + cs_b*2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm7 = _mm256_fmsub_ps(ymm0, ymm15, ymm7);\ +\ + xmm5 = _mm_broadcast_ss((float *)&zero);\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + cs_b*3));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm9 = _mm256_fmsub_ps(ymm0, ymm15, ymm9);\ +\ + xmm5 = _mm_broadcast_ss((float *)&zero);\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + cs_b*4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm11 = _mm256_fmsub_ps(ymm0, ymm15, ymm11);\ +\ + xmm5 = _mm_broadcast_ss((float *)&zero);\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + cs_b*5));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm13 = _mm256_fmsub_ps(ymm0, ymm15, ymm13); + +#define BLIS_PRE_STRSM_SMALL_6x1(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); /*register to hold alpha*/\ +\ + ymm0 = _mm256_broadcast_ss((float const *)b11);\ + ymm3 = _mm256_fmsub_ps(ymm0, ymm15, ymm3);\ +\ + ymm0 = _mm256_broadcast_ss((float const *)(b11 + cs_b));\ + ymm5 = _mm256_fmsub_ps(ymm0, ymm15, ymm5);\ +\ + ymm0 = _mm256_broadcast_ss((float const *)(b11 + 
cs_b*2));\ + ymm7 = _mm256_fmsub_ps(ymm0, ymm15, ymm7);\ +\ + ymm0 = _mm256_broadcast_ss((float const *)(b11 + cs_b*3));\ + ymm9 = _mm256_fmsub_ps(ymm0, ymm15, ymm9);\ +\ + ymm0 = _mm256_broadcast_ss((float const *)(b11 + cs_b*4));\ + ymm11 = _mm256_fmsub_ps(ymm0, ymm15, ymm11);\ +\ + ymm0 = _mm256_broadcast_ss((float const *)(b11 + cs_b*5));\ + ymm13 = _mm256_fmsub_ps(ymm0, ymm15, ymm13); + /* Load b11 of size 6x8 and multiply with alpha Add the GEMM output and perform inregister transose of b11 @@ -6628,7 +7011,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB BLIS_DTRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_DTRSM_SMALL_6x4(AlphaVal,b11,cs_b) + BLIS_PRE_DTRSM_SMALL_6x3(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -6752,7 +7135,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB BLIS_DTRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_DTRSM_SMALL_6x4(AlphaVal,b11,cs_b) + BLIS_PRE_DTRSM_SMALL_6x2(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -6869,7 +7252,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB BLIS_DTRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_DTRSM_SMALL_6x4(AlphaVal,b11,cs_b) + BLIS_PRE_DTRSM_SMALL_6x1(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -8276,7 +8659,8 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB if(transa) { - for(dim_t x =0;x < p_lda;x+=d_nr) + dim_t x = 0; + for(x = 0;(x + d_nr - 1) < p_lda;x+=d_nr) { ymm0 = _mm256_loadu_pd((double const *)(a01)); ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); @@ -8315,6 +8699,34 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB a01 += d_nr*cs_a; ptr_a10_dup += d_nr; } + dim_t remainder_loop_count = p_lda - x; + if(remainder_loop_count >= 4) + { + ymm0 = _mm256_loadu_pd((double const *)(a01)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); + ymm2 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2)); + ymm3 = _mm256_loadu_pd((double const *)(a01 + cs_a * 3)); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + a01 += 4*cs_a; + ptr_a10_dup += 4; + remainder_loop_count = remainder_loop_count - 4; + } } else { @@ -8979,7 +9391,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB BLIS_DTRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_DTRSM_SMALL_6x4(AlphaVal,b11,cs_b) + BLIS_PRE_DTRSM_SMALL_6x3(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -9094,7 +9506,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB BLIS_DTRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_DTRSM_SMALL_6x4(AlphaVal,b11,cs_b) + BLIS_PRE_DTRSM_SMALL_6x2(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -9202,7 +9614,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB BLIS_DTRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_DTRSM_SMALL_6x4(AlphaVal,b11,cs_b) + 
BLIS_PRE_DTRSM_SMALL_6x1(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -10591,7 +11003,8 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB if(transa) { - for(dim_t x =0;x < p_lda;x+=d_nr) + dim_t x =0; + for(x =0;(x+d_nr-1) < p_lda;x+=d_nr) { ymm0 = _mm256_loadu_pd((double const *)(a01)); ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); @@ -10638,6 +11051,34 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB a01 += d_nr*cs_a; ptr_a10_dup += d_nr; } + dim_t remainder_loop_count = p_lda - x; + if(remainder_loop_count >= 4) + { + ymm0 = _mm256_loadu_pd((double const *)(a01)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); + ymm2 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2)); + ymm3 = _mm256_loadu_pd((double const *)(a01 + cs_a * 3)); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + a01 += 4*cs_a; + ptr_a10_dup += 4; + remainder_loop_count = remainder_loop_count - 4; + } } else { @@ -11696,7 +12137,7 @@ BLIS_INLINE err_t bli_dtrsm_small_AltXB_AuXB dim_t p_lda = 4; // packed leading dimension if(transa) { - for(dim_t x =0;x < m-i+4;x+=p_lda) + for(dim_t x =0;x < m-i-4;x+=p_lda) { ymm0 = _mm256_loadu_pd((double const *)(a10)); ymm1 = _mm256_loadu_pd((double const *)(a10 + cs_a)); @@ -14530,7 +14971,9 @@ BLIS_INLINE err_t bli_dtrsm_small_AutXB_AlXB _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 5)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); @@ -14538,7 +14981,8 @@ BLIS_INLINE err_t bli_dtrsm_small_AutXB_AlXB ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store(B11[0-3][3]) + xmm5 = _mm256_castpd256_pd128(ymm1); + _mm_storeu_pd((double *)(b11 + cs_b * 5), xmm5); if(transa) dtrsm_AutXB_ref(a11, b11, m_rem, 6, cs_a, cs_b, is_unitdiag); @@ -14585,7 +15029,7 @@ BLIS_INLINE err_t bli_dtrsm_small_AutXB_AlXB _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - xmm5 = _mm256_castpd256_pd128(ymm3); + xmm5 = _mm256_castpd256_pd128(ymm3); _mm_storeu_pd((double *)(b11 + cs_b * 3), xmm5); if(transa) @@ -15820,7 +16264,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + BLIS_PRE_STRSM_SMALL_6x7(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -15903,25 +16347,29 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_loadu_ps((float const *)b11); - ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x7F); - ymm0 = 
_mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x7F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x7F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x7F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x7F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x7F); - - _mm256_storeu_ps((float *)b11, ymm3); - _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); - _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); - _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); - _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); - _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + _mm_storeu_ps((float *)(b11),_mm256_extractf128_ps(ymm3, 0)); + _mm_storel_pi((__m64 *)(b11 + 4),_mm256_extractf128_ps(ymm3, 1)); + _mm_store_ss((float *)(b11 + 6),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm3,ymm3), 1)); + + _mm_storeu_ps((float *)(b11 + cs_b),_mm256_extractf128_ps(ymm5, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b),_mm256_extractf128_ps(ymm5, 1)); + _mm_store_ss((float *)(b11 + 6 + cs_b),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm5,ymm5), 1)); + + _mm_storeu_ps((float *)(b11 + cs_b*2),_mm256_extractf128_ps(ymm7, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*2),_mm256_extractf128_ps(ymm7, 1)); + _mm_store_ss((float *)(b11 + 6 + cs_b*2),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm7,ymm7), 1)); + + _mm_storeu_ps((float *)(b11 + cs_b*3),_mm256_extractf128_ps(ymm9, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*3),_mm256_extractf128_ps(ymm9, 1)); + _mm_store_ss((float *)(b11 + 6 + cs_b*3),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm9,ymm9), 1)); + + _mm_storeu_ps((float *)(b11 + cs_b*4),_mm256_extractf128_ps(ymm11, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*4),_mm256_extractf128_ps(ymm11, 1)); + _mm_store_ss((float *)(b11 + 6 + cs_b*4),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm11,ymm11), 1)); + + _mm_storeu_ps((float *)(b11 + cs_b*5),_mm256_extractf128_ps(ymm13, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*5),_mm256_extractf128_ps(ymm13, 1)); + _mm_store_ss((float *)(b11 + 6 + cs_b*5),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm13,ymm13), 1)); m_remainder -=7; } @@ -15941,7 +16389,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + BLIS_PRE_STRSM_SMALL_6x6(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -16024,25 +16472,18 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_loadu_ps((float const *)b11); - ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x3F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x3F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x3F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x3F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = 
_mm256_blend_ps(ymm0, ymm11, 0x3F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x3F); - - _mm256_storeu_ps((float *)b11, ymm3); - _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); - _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); - _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); - _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); - _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + _mm_storeu_ps((float *)(b11),_mm256_extractf128_ps(ymm3, 0)); + _mm_storel_pi((__m64 *)(b11 + 4),_mm256_extractf128_ps(ymm3, 1)); + _mm_storeu_ps((float *)(b11 + cs_b),_mm256_extractf128_ps(ymm5, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b),_mm256_extractf128_ps(ymm5, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*2),_mm256_extractf128_ps(ymm7, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*2),_mm256_extractf128_ps(ymm7, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*3),_mm256_extractf128_ps(ymm9, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*3),_mm256_extractf128_ps(ymm9, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*4),_mm256_extractf128_ps(ymm11, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*4),_mm256_extractf128_ps(ymm11, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*5),_mm256_extractf128_ps(ymm13, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*5),_mm256_extractf128_ps(ymm13, 1)); m_remainder -=6; } @@ -16062,7 +16503,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + BLIS_PRE_STRSM_SMALL_6x5(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -16145,25 +16586,18 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_loadu_ps((float const *)b11); - ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x1F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x1F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x1F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x1F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x1F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x1F); - - _mm256_storeu_ps((float *)b11, ymm3); - _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); - _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); - _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); - _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); - _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + _mm_storeu_ps((float *)(b11),_mm256_extractf128_ps(ymm3, 0)); + _mm_store_ss((float *)(b11 + 4),_mm256_extractf128_ps(ymm3, 1)); + _mm_storeu_ps((float *)(b11 + cs_b),_mm256_extractf128_ps(ymm5, 0)); + _mm_store_ss((float *)(b11 + 4 + cs_b),_mm256_extractf128_ps(ymm5, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*2),_mm256_extractf128_ps(ymm7, 0)); + _mm_store_ss((float *)(b11 + 4 + cs_b*2),_mm256_extractf128_ps(ymm7, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*3),_mm256_extractf128_ps(ymm9, 0)); + _mm_store_ss((float *)(b11 + 4 + cs_b*3),_mm256_extractf128_ps(ymm9, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*4),_mm256_extractf128_ps(ymm11, 0)); + _mm_store_ss((float 
*)(b11 + 4 + cs_b*4),_mm256_extractf128_ps(ymm11, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*5),_mm256_extractf128_ps(ymm13, 0)); + _mm_store_ss((float *)(b11 + 4 + cs_b*5),_mm256_extractf128_ps(ymm13, 1)); m_remainder -=5; } @@ -16183,7 +16617,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + BLIS_PRE_STRSM_SMALL_6x4(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -16266,25 +16700,12 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_loadu_ps((float const *)b11); - ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x0F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x0F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x0F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x0F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x0F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x0F); - - _mm256_storeu_ps((float *)b11, ymm3); - _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); - _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); - _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); - _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); - _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + _mm_storeu_ps((float *)(b11),_mm256_extractf128_ps(ymm3, 0)); + _mm_storeu_ps((float *)(b11 + cs_b),_mm256_extractf128_ps(ymm5, 0)); + _mm_storeu_ps((float *)(b11 + cs_b*2),_mm256_extractf128_ps(ymm7, 0)); + _mm_storeu_ps((float *)(b11 + cs_b*3),_mm256_extractf128_ps(ymm9, 0)); + _mm_storeu_ps((float *)(b11 + cs_b*4),_mm256_extractf128_ps(ymm11, 0)); + _mm_storeu_ps((float *)(b11 + cs_b*5),_mm256_extractf128_ps(ymm13, 0)); m_remainder -=4; } @@ -16304,7 +16725,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + BLIS_PRE_STRSM_SMALL_6x3(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -16387,25 +16808,29 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_loadu_ps((float const *)b11); - ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x07); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x07); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x07); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x07); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x07); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x07); - - _mm256_storeu_ps((float *)b11, ymm3); - _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); - _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); - _mm256_storeu_ps((float 
*)(b11 + cs_b*3), ymm9); - _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); - _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + xmm5 = _mm256_extractf128_ps(ymm3, 0); + _mm_storel_pi((__m64 *)(b11),xmm5); + _mm_store_ss((float *)(b11+2),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm3,ymm3), 0)); + + xmm5 = _mm256_extractf128_ps(ymm5, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b),xmm5); + _mm_store_ss((float *)(b11+ 2 + cs_b),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm5,ymm5), 0)); + + xmm5 = _mm256_extractf128_ps(ymm7, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*2),xmm5); + _mm_store_ss((float *)(b11 + 2 + cs_b*2),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm7,ymm7), 0)); + + xmm5 = _mm256_extractf128_ps(ymm9, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*3),xmm5); + _mm_store_ss((float *)(b11 + 2 + cs_b*3),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm9,ymm9), 0)); + + xmm5 = _mm256_extractf128_ps(ymm11, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*4),xmm5); + _mm_store_ss((float *)(b11 + 2 + cs_b*4),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm11,ymm11), 0)); + + xmm5 = _mm256_extractf128_ps(ymm13, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*5),xmm5); + _mm_store_ss((float *)(b11 + 2 + cs_b*5),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm13,ymm13), 0)); m_remainder -=3; } @@ -16425,7 +16850,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + BLIS_PRE_STRSM_SMALL_6x2(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -16508,25 +16933,23 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_loadu_ps((float const *)b11); - ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x03); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x03); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x03); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x03); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x03); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x03); - - _mm256_storeu_ps((float *)b11, ymm3); - _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); - _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); - _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); - _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); - _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + xmm5 = _mm256_extractf128_ps(ymm3, 0); + _mm_storel_pi((__m64 *)(b11),xmm5); + + xmm5 = _mm256_extractf128_ps(ymm5, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b),xmm5); + + xmm5 = _mm256_extractf128_ps(ymm7, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*2),xmm5); + + xmm5 = _mm256_extractf128_ps(ymm9, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*3),xmm5); + + xmm5 = _mm256_extractf128_ps(ymm11, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*4),xmm5); + + xmm5 = _mm256_extractf128_ps(ymm13, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*5),xmm5); m_remainder -=2; } @@ -16546,7 +16969,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + 
BLIS_PRE_STRSM_SMALL_6x1(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -16629,25 +17052,12 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_loadu_ps((float const *)b11); - ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x01); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x01); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x01); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x01); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x01); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x01); - - _mm256_storeu_ps((float *)b11, ymm3); - _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); - _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); - _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); - _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); - _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + _mm_store_ss((b11 + cs_b * 0), _mm256_extractf128_ps(ymm3, 0)); + _mm_store_ss((b11 + cs_b * 1), _mm256_extractf128_ps(ymm5, 0)); + _mm_store_ss((b11 + cs_b * 2), _mm256_extractf128_ps(ymm7, 0)); + _mm_store_ss((b11 + cs_b * 3), _mm256_extractf128_ps(ymm9, 0)); + _mm_store_ss((b11 + cs_b * 4), _mm256_extractf128_ps(ymm11, 0)); + _mm_store_ss((b11 + cs_b * 5), _mm256_extractf128_ps(ymm13, 0)); m_remainder -=1; } @@ -18690,7 +19100,8 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB __m128 xmm0, xmm1, xmm2, xmm3; __m128 xmm4, xmm5, xmm6, xmm7; __m128 xmm8, xmm9; - for(dim_t x =0;x < p_lda;x+=d_nr) + dim_t x = 0; + for(x =0;(x+d_nr-1) < p_lda;x+=d_nr) { xmm0 = _mm_loadu_ps((float const *)(a01)); xmm1 = _mm_loadu_ps((float const *)(a01 + cs_a)); @@ -18733,6 +19144,33 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB a01 += d_nr*cs_a; ptr_a10_dup += d_nr; } + dim_t remainder_count = p_lda - x; + if(remainder_count >= 4) + { + xmm0 = _mm_loadu_ps((float const *)(a01)); + xmm1 = _mm_loadu_ps((float const *)(a01 + cs_a)); + xmm2 = _mm_loadu_ps((float const *)(a01 + cs_a * 2)); + xmm3 = _mm_loadu_ps((float const *)(a01 + cs_a * 3)); + + xmm4 = _mm_unpacklo_ps(xmm0, xmm1); + xmm5 = _mm_unpacklo_ps(xmm2, xmm3); + xmm6 = _mm_shuffle_ps(xmm4,xmm5,0x44); + xmm7 = _mm_shuffle_ps(xmm4,xmm5,0xEE); + + xmm0 = _mm_unpackhi_ps(xmm0, xmm1); + xmm1 = _mm_unpackhi_ps(xmm2, xmm3); + xmm8 = _mm_shuffle_ps(xmm0,xmm1,0x44); + xmm9 = _mm_shuffle_ps(xmm0,xmm1,0xEE); + + _mm_storeu_ps((float *)(ptr_a10_dup), xmm6); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda), xmm7); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda*2), xmm8); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda*3), xmm9); + + a01 += 4*cs_a; + ptr_a10_dup += 4; + remainder_count = remainder_count - 4; + } } else { @@ -18909,7 +19347,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB ymm3 = _mm256_setzero_ps(); ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_1nx5m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_1N_5M(AlphaVal,b11,cs_b) @@ -19510,7 +19948,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 8x6 and multiply with alpha - 
BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + BLIS_PRE_STRSM_SMALL_6x7(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -19601,25 +20039,29 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB ymm13 = STRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - ymm0 = _mm256_loadu_ps((float const *)b11); - ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x7F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x7F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x7F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x7F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x7F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x7F); + _mm_storeu_ps((float *)(b11),_mm256_extractf128_ps(ymm3, 0)); + _mm_storel_pi((__m64 *)(b11 + 4),_mm256_extractf128_ps(ymm3, 1)); + _mm_store_ss((float *)(b11 + 6),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm3,ymm3), 1)); - _mm256_storeu_ps((float *)b11, ymm3); - _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); - _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); - _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); - _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); - _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + _mm_storeu_ps((float *)(b11 + cs_b),_mm256_extractf128_ps(ymm5, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b),_mm256_extractf128_ps(ymm5, 1)); + _mm_store_ss((float *)(b11 + 6 + cs_b),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm5,ymm5), 1)); + + _mm_storeu_ps((float *)(b11 + cs_b*2),_mm256_extractf128_ps(ymm7, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*2),_mm256_extractf128_ps(ymm7, 1)); + _mm_store_ss((float *)(b11 + 6 + cs_b*2),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm7,ymm7), 1)); + + _mm_storeu_ps((float *)(b11 + cs_b*3),_mm256_extractf128_ps(ymm9, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*3),_mm256_extractf128_ps(ymm9, 1)); + _mm_store_ss((float *)(b11 + 6 + cs_b*3),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm9,ymm9), 1)); + + _mm_storeu_ps((float *)(b11 + cs_b*4),_mm256_extractf128_ps(ymm11, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*4),_mm256_extractf128_ps(ymm11, 1)); + _mm_store_ss((float *)(b11 + 6 + cs_b*4),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm11,ymm11), 1)); + + _mm_storeu_ps((float *)(b11 + cs_b*5),_mm256_extractf128_ps(ymm13, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*5),_mm256_extractf128_ps(ymm13, 1)); + _mm_store_ss((float *)(b11 + 6 + cs_b*5),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm13,ymm13), 1)); m_remainder -= 7; i += 7; @@ -19640,7 +20082,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 8x6 and multiply with alpha - BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + BLIS_PRE_STRSM_SMALL_6x6(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -19731,25 +20173,18 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB ymm13 = STRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - ymm0 = _mm256_loadu_ps((float const *)b11); - ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x3F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x3F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] 
B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x3F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x3F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x3F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x3F); - - _mm256_storeu_ps((float *)b11, ymm3); - _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); - _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); - _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); - _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); - _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + _mm_storeu_ps((float *)(b11),_mm256_extractf128_ps(ymm3, 0)); + _mm_storel_pi((__m64 *)(b11 + 4),_mm256_extractf128_ps(ymm3, 1)); + _mm_storeu_ps((float *)(b11 + cs_b),_mm256_extractf128_ps(ymm5, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b),_mm256_extractf128_ps(ymm5, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*2),_mm256_extractf128_ps(ymm7, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*2),_mm256_extractf128_ps(ymm7, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*3),_mm256_extractf128_ps(ymm9, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*3),_mm256_extractf128_ps(ymm9, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*4),_mm256_extractf128_ps(ymm11, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*4),_mm256_extractf128_ps(ymm11, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*5),_mm256_extractf128_ps(ymm13, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*5),_mm256_extractf128_ps(ymm13, 1)); m_remainder -= 6; i += 6; @@ -19770,7 +20205,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 8x6 and multiply with alpha - BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + BLIS_PRE_STRSM_SMALL_6x5(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -19861,25 +20296,18 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB ymm13 = STRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - ymm0 = _mm256_loadu_ps((float const *)b11); - ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x1F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x1F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x1F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x1F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x1F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x1F); - - _mm256_storeu_ps((float *)b11, ymm3); - _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); - _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); - _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); - _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); - _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + _mm_storeu_ps((float *)(b11),_mm256_extractf128_ps(ymm3, 0)); + _mm_store_ss((float *)(b11 + 4),_mm256_extractf128_ps(ymm3, 1)); + _mm_storeu_ps((float *)(b11 + cs_b),_mm256_extractf128_ps(ymm5, 0)); + _mm_store_ss((float *)(b11 + 4 + cs_b),_mm256_extractf128_ps(ymm5, 1)); + _mm_storeu_ps((float *)(b11 + 
cs_b*2),_mm256_extractf128_ps(ymm7, 0)); + _mm_store_ss((float *)(b11 + 4 + cs_b*2),_mm256_extractf128_ps(ymm7, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*3),_mm256_extractf128_ps(ymm9, 0)); + _mm_store_ss((float *)(b11 + 4 + cs_b*3),_mm256_extractf128_ps(ymm9, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*4),_mm256_extractf128_ps(ymm11, 0)); + _mm_store_ss((float *)(b11 + 4 + cs_b*4),_mm256_extractf128_ps(ymm11, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*5),_mm256_extractf128_ps(ymm13, 0)); + _mm_store_ss((float *)(b11 + 4 + cs_b*5),_mm256_extractf128_ps(ymm13, 1)); m_remainder -= 5; i += 5; @@ -19900,7 +20328,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 8x6 and multiply with alpha - BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + BLIS_PRE_STRSM_SMALL_6x4(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -19991,25 +20419,12 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB ymm13 = STRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - ymm0 = _mm256_loadu_ps((float const *)b11); - ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x0F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x0F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x0F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x0F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x0F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x0F); - - _mm256_storeu_ps((float *)b11, ymm3); - _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); - _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); - _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); - _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); - _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + _mm_storeu_ps((float *)(b11),_mm256_extractf128_ps(ymm3, 0)); + _mm_storeu_ps((float *)(b11 + cs_b),_mm256_extractf128_ps(ymm5, 0)); + _mm_storeu_ps((float *)(b11 + cs_b*2),_mm256_extractf128_ps(ymm7, 0)); + _mm_storeu_ps((float *)(b11 + cs_b*3),_mm256_extractf128_ps(ymm9, 0)); + _mm_storeu_ps((float *)(b11 + cs_b*4),_mm256_extractf128_ps(ymm11, 0)); + _mm_storeu_ps((float *)(b11 + cs_b*5),_mm256_extractf128_ps(ymm13, 0)); m_remainder -= 4; i += 4; @@ -20030,7 +20445,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 8x6 and multiply with alpha - BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + BLIS_PRE_STRSM_SMALL_6x3(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -20121,25 +20536,29 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB ymm13 = STRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - ymm0 = _mm256_loadu_ps((float const *)b11); - ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x07); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x07); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x07); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x07); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] 
B11[2][1] B11[3][1] - ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x07); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x07); + xmm5 = _mm256_extractf128_ps(ymm3, 0); + _mm_storel_pi((__m64 *)(b11),xmm5); + _mm_store_ss((float *)(b11+2),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm3,ymm3), 0)); - _mm256_storeu_ps((float *)b11, ymm3); - _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); - _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); - _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); - _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); - _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + xmm5 = _mm256_extractf128_ps(ymm5, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b),xmm5); + _mm_store_ss((float *)(b11+ 2 + cs_b),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm5,ymm5), 0)); + + xmm5 = _mm256_extractf128_ps(ymm7, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*2),xmm5); + _mm_store_ss((float *)(b11 + 2 + cs_b*2),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm7,ymm7), 0)); + + xmm5 = _mm256_extractf128_ps(ymm9, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*3),xmm5); + _mm_store_ss((float *)(b11 + 2 + cs_b*3),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm9,ymm9), 0)); + + xmm5 = _mm256_extractf128_ps(ymm11, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*4),xmm5); + _mm_store_ss((float *)(b11 + 2 + cs_b*4),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm11,ymm11), 0)); + + xmm5 = _mm256_extractf128_ps(ymm13, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*5),xmm5); + _mm_store_ss((float *)(b11 + 2 + cs_b*5),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm13,ymm13), 0)); m_remainder -= 3; i += 3; @@ -20160,7 +20579,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + BLIS_PRE_STRSM_SMALL_6x2(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -20251,25 +20670,23 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB ymm13 = STRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - ymm0 = _mm256_loadu_ps((float const *)b11); - ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x03); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x03); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x03); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x03); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x03); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x03); + xmm5 = _mm256_extractf128_ps(ymm3, 0); + _mm_storel_pi((__m64 *)(b11),xmm5); - _mm256_storeu_ps((float *)b11, ymm3); - _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); - _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); - _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); - _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); - _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + xmm5 = _mm256_extractf128_ps(ymm5, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b),xmm5); + + xmm5 = _mm256_extractf128_ps(ymm7, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*2),xmm5); + + xmm5 = _mm256_extractf128_ps(ymm9, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*3),xmm5); + + xmm5 = _mm256_extractf128_ps(ymm11, 0); + 
_mm_storel_pi((__m64 *)(b11 + cs_b*4),xmm5); + + xmm5 = _mm256_extractf128_ps(ymm13, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*5),xmm5); m_remainder -= 2; i += 2; @@ -20290,7 +20707,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + BLIS_PRE_STRSM_SMALL_6x1(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -20381,25 +20798,12 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB ymm13 = STRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - ymm0 = _mm256_loadu_ps((float const *)b11); - ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x01); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x01); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x01); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x01); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x01); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x01); - - _mm256_storeu_ps((float *)b11, ymm3); - _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); - _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); - _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); - _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); - _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + _mm_store_ss((b11 + cs_b * 0), _mm256_extractf128_ps(ymm3, 0)); + _mm_store_ss((b11 + cs_b * 1), _mm256_extractf128_ps(ymm5, 0)); + _mm_store_ss((b11 + cs_b * 2), _mm256_extractf128_ps(ymm7, 0)); + _mm_store_ss((b11 + cs_b * 3), _mm256_extractf128_ps(ymm9, 0)); + _mm_store_ss((b11 + cs_b * 4), _mm256_extractf128_ps(ymm11, 0)); + _mm_store_ss((b11 + cs_b * 5), _mm256_extractf128_ps(ymm13, 0)); m_remainder -= 1; i += 1; @@ -22523,7 +22927,8 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB __m128 xmm4, xmm5, xmm6, xmm7; __m128 xmm8, xmm9; - for(dim_t x =0;x < p_lda;x+=d_nr) + dim_t x = 0; + for(x =0;(x+d_nr-1) < p_lda;x+=d_nr) { xmm0 = _mm_loadu_ps((float const *)(a01)); xmm1 = _mm_loadu_ps((float const *)(a01 + cs_a)); @@ -22566,6 +22971,32 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB a01 += d_nr*cs_a; ptr_a10_dup += d_nr; } + dim_t remainder_count = p_lda - x; + if(remainder_count >= 4) + { + xmm0 = _mm_loadu_ps((float const *)(a01)); + xmm1 = _mm_loadu_ps((float const *)(a01 + cs_a)); + xmm2 = _mm_loadu_ps((float const *)(a01 + cs_a * 2)); + xmm3 = _mm_loadu_ps((float const *)(a01 + cs_a * 3)); + + xmm4 = _mm_unpacklo_ps(xmm0, xmm1); + xmm5 = _mm_unpacklo_ps(xmm2, xmm3); + xmm6 = _mm_shuffle_ps(xmm4,xmm5,0x44); + xmm7 = _mm_shuffle_ps(xmm4,xmm5,0xEE); + + xmm0 = _mm_unpackhi_ps(xmm0, xmm1); + xmm1 = _mm_unpackhi_ps(xmm2, xmm3); + xmm8 = _mm_shuffle_ps(xmm0,xmm1,0x44); + xmm9 = _mm_shuffle_ps(xmm0,xmm1,0xEE); + + _mm_storeu_ps((float *)(ptr_a10_dup), xmm6); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda), xmm7); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda*2), xmm8); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda*3), xmm9); + + a01 += 4*cs_a; + ptr_a10_dup += 4; + } } else { @@ -29414,7 +29845,7 @@ BLIS_INLINE err_t bli_strsm_small_AltXB_AuXB dim_t p_lda = 8; // packed leading dimension if(transa) { - for(dim_t x =0;x < m-i+8;x+=p_lda) + for(dim_t x 
=0;x < m-i-8;x+=p_lda) { ymm0 = _mm256_loadu_ps((float const *)(a10)); ymm1 = _mm256_loadu_ps((float const *)(a10 + cs_a)); @@ -30332,7 +30763,7 @@ BLIS_INLINE err_t bli_strsm_small_AltXB_AuXB __m128 xmm6,xmm7,xmm8,xmm9; if(transa) { - for(dim_t x =0;x < m-i+4;x+=p_lda) + for(dim_t x =0;x < m-i-4;x+=p_lda) { xmm0 = _mm_loadu_ps((float const *)(a10)); xmm1 = _mm_loadu_ps((float const *)(a10 + cs_a)); @@ -36293,10 +36724,10 @@ BLIS_INLINE err_t bli_ztrsm_small_XAltB_XAuB BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_ZTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_ZTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_ZTRSM_SMALL_3x4(AlphaVal,b11,cs_b) + BLIS_PRE_ZTRSM_SMALL_2x4(AlphaVal,b11,cs_b) ///implement TRSM/// ////extract a00 @@ -37265,7 +37696,8 @@ BLIS_INLINE void ctrsm_small_pack_diag_element ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal));\ ymm16 = _mm256_permute_ps(ymm16, 0x44);\ \ - ymm0 = _mm256_loadu_ps((float const *)(b11));\ + xmm0 = _mm_loadu_ps((float const *)(b11));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0);\ /*in register transpose * ymm0,ymm1,ymm2 holds * two dcomplex elements of b11 cols*/\ @@ -37367,8 +37799,10 @@ BLIS_INLINE void ctrsm_small_pack_diag_element ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal));\ ymm16 = _mm256_permute_ps(ymm16, 0x44);\ \ - ymm0 = _mm256_loadu_ps((float const *)(b11));\ - ymm1 = _mm256_loadu_ps((float const *)(b11 + cs_b *1));\ + xmm0 = _mm_loadu_ps((float const *)(b11));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0);\ + xmm1 = _mm_loadu_ps((float const *)(b11 + cs_b * 1));\ + ymm1 = _mm256_insertf128_ps(ymm1, xmm1, 0);\ /*in register transpose * ymm0,ymm1,ymm2 holds * two dcomplex elements of b11 cols*/\ @@ -37513,6 +37947,132 @@ BLIS_INLINE void ctrsm_small_pack_diag_element }\ } +/** + * Multiplies Alpha with one scomplex + * element of three column. + */ +#define BLIS_PRE_CTRSM_SMALL_3x1(AlphaVal, b11,cs_b){\ + ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal));\ + ymm16 = _mm256_permute_ps(ymm16, 0x44);\ + \ + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 0);\ + \ + ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11);\ + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0);\ + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5);\ + ymm19 = _mm256_mul_ps(ymm19, ymm17);\ + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19);\ + ymm8 = _mm256_sub_ps(ymm19, ymm8);\ + \ + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11 + cs_b));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 0);\ + \ + ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11);\ + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0);\ + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5);\ + ymm19 = _mm256_mul_ps(ymm19, ymm17);\ + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19);\ + ymm10 = _mm256_sub_ps(ymm19, ymm10);\ + \ + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11 + cs_b*2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 0);\ + \ + ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11);\ + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0);\ + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5);\ + ymm19 = _mm256_mul_ps(ymm19, ymm17);\ + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19);\ + ymm12 = _mm256_sub_ps(ymm19, ymm12);\ + \ +} + +/** + * Multiplies Alpha with two scomplex + * element of three column. 
+ */ +#define BLIS_PRE_CTRSM_SMALL_3x2(AlphaVal, b11,cs_b){\ + ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal));\ + ymm16 = _mm256_permute_ps(ymm16, 0x44);\ + \ + xmm0 = _mm_loadu_ps((float const *)(b11));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0);\ + \ + ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11);\ + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0);\ + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5);\ + ymm19 = _mm256_mul_ps(ymm19, ymm17);\ + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19);\ + ymm8 = _mm256_sub_ps(ymm19, ymm8);\ + \ + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0);\ + \ + ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11);\ + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0);\ + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5);\ + ymm19 = _mm256_mul_ps(ymm19, ymm17);\ + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19);\ + ymm10 = _mm256_sub_ps(ymm19, ymm10);\ + \ + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b*2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0);\ + \ + ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11);\ + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0);\ + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5);\ + ymm19 = _mm256_mul_ps(ymm19, ymm17);\ + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19);\ + ymm12 = _mm256_sub_ps(ymm19, ymm12);\ + \ +} + +/** + * Multiplies Alpha with three scomplex + * element of three column. + */ +#define BLIS_PRE_CTRSM_SMALL_3x3(AlphaVal, b11,cs_b){\ + ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal));\ + ymm16 = _mm256_permute_ps(ymm16, 0x44);\ + \ + xmm0 = _mm_loadu_ps((float const *)(b11));\ + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11 + 2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0);\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 1);\ + \ + ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11);\ + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0);\ + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5);\ + ymm19 = _mm256_mul_ps(ymm19, ymm17);\ + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19);\ + ymm8 = _mm256_sub_ps(ymm19, ymm8);\ + \ + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b));\ + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11 + cs_b + 2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0);\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 1);\ + \ + ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11);\ + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0);\ + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5);\ + ymm19 = _mm256_mul_ps(ymm19, ymm17);\ + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19);\ + ymm10 = _mm256_sub_ps(ymm19, ymm10);\ + \ + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b*2));\ + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11 + cs_b*2 + 2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0);\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 1);\ + \ + ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11);\ + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0);\ + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5);\ + ymm19 = _mm256_mul_ps(ymm19, ymm17);\ + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19);\ + ymm12 = _mm256_sub_ps(ymm19, ymm12);\ + \ +} + /** * Multiplies Alpha with four scomplex * element of three column. 
@@ -40496,8 +41056,7 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB (float const *)(a10 + cs_a)); ymm2 = _mm256_loadu_ps( (float const *)(a10 + cs_a * 2)); - ymm3 = _mm256_loadu_ps( - (float const *)(a10 + cs_a * 3)); + ymm3 = _mm256_broadcast_ss((float const *)&ones); ymm4 = _mm256_shuffle_ps(ymm0, ymm1, 0x44); ymm5 = _mm256_shuffle_ps(ymm2, ymm3, 0x44); @@ -40709,10 +41268,12 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); - ymm0 = _mm256_loadu_ps((float const *)(b11)); - ymm1 = _mm256_loadu_ps((float const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_ps((float const *)(b11 + cs_b *2)); - + xmm0 = _mm_loadu_ps((float const *)(b11)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b)); + ymm1 = _mm256_insertf128_ps(ymm1, xmm0, 0); + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b*2)); + ymm2 = _mm256_insertf128_ps(ymm2, xmm0, 0); ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); @@ -42321,7 +42882,7 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB dim_t p_lda = 4; if(transa) { - for(dim_t x =0;x < m-i+4;x+=p_lda) + for(dim_t x =0;x < m-i-4;x+=p_lda) { ymm0 = _mm256_loadu_ps((float const *)(a10)); ymm1 = _mm256_loadu_ps((float const *)(a10 + cs_a)); @@ -42360,11 +42921,11 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB { if(transa) { - ctrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,m_rem); + ctrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,4); } else { - ctrsm_small_pack_diag_element(is_unitdiag,a11,rs_a,d11_pack,m_rem); + ctrsm_small_pack_diag_element(is_unitdiag,a11,rs_a,d11_pack,4); } } @@ -43556,7 +44117,7 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB BLIS_CTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_CTRSM_SMALL_3x4(AlphaVal,b11,cs_b) + BLIS_PRE_CTRSM_SMALL_3x3(AlphaVal,b11,cs_b) ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); @@ -43664,7 +44225,7 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB BLIS_CTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_CTRSM_SMALL_3x4(AlphaVal,b11,cs_b) + BLIS_PRE_CTRSM_SMALL_3x2(AlphaVal,b11,cs_b) ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); @@ -43763,7 +44324,7 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB BLIS_CTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_CTRSM_SMALL_3x4(AlphaVal,b11,cs_b) + BLIS_PRE_CTRSM_SMALL_3x1(AlphaVal,b11,cs_b) ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); @@ -44111,7 +44672,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); - ymm0 = _mm256_loadu_ps((float const *)(b11)); + xmm0 = _mm_loadu_ps((float const *)(b11)); + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11 + 2)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 1); ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); @@ -44120,7 +44684,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); ymm8 = _mm256_sub_ps(ymm19, ymm8); - ymm0 = _mm256_loadu_ps((float const *)(b11 
+ cs_b * 1)); + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b)); + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11 + cs_b + 2)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 1); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5); @@ -44184,12 +44751,13 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB k_iter = (n-n_rem); BLIS_SET_S_YMM_REG_ZEROS - ///GEMM implementation starts/// + ///GEMM implementation starts/// BLIS_CTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); - ymm0 = _mm256_loadu_ps((float const *)(b11)); + xmm0 = _mm_loadu_ps((float const *)(b11)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); @@ -44198,7 +44766,8 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); ymm8 = _mm256_sub_ps(ymm19, ymm8); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b * 1)); + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5); @@ -44261,7 +44830,8 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); - ymm0 = _mm256_loadu_ps((float const *)(b11)); + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 0); ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); @@ -44270,7 +44840,8 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); ymm8 = _mm256_sub_ps(ymm19, ymm8); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b * 1)); + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11 + cs_b)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 0); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5); @@ -44486,12 +45057,15 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB k_iter = (n-n_rem); BLIS_SET_S_YMM_REG_ZEROS - ///GEMM implementation starts/// + ///GEMM implementation starts/// BLIS_CTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); - ymm0 = _mm256_loadu_ps((float const *)(b11)); + xmm0 = _mm_loadu_ps((float const *)(b11)); + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11 + 2)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 1); ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); @@ -44530,7 +45104,8 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); - ymm0 = _mm256_loadu_ps((float const *)(b11)); + xmm0 = _mm_loadu_ps((float const *)(b11)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); @@ -44567,7 +45142,8 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); - ymm0 = _mm256_loadu_ps((float const *)(b11)); + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 0); ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); ymm18 = _mm256_shuffle_ps(ymm0, 
ymm0, 0xA0); @@ -44994,7 +45570,7 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB BLIS_CTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_CTRSM_SMALL_3x4(AlphaVal,b11,cs_b) + BLIS_PRE_CTRSM_SMALL_3x3(AlphaVal,b11,cs_b) ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); @@ -45116,7 +45692,7 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB BLIS_CTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_CTRSM_SMALL_3x4(AlphaVal,b11,cs_b) + BLIS_PRE_CTRSM_SMALL_3x2(AlphaVal,b11,cs_b) ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); @@ -45232,7 +45808,7 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB BLIS_CTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_CTRSM_SMALL_3x4(AlphaVal,b11,cs_b) + BLIS_PRE_CTRSM_SMALL_3x1(AlphaVal,b11,cs_b) ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); @@ -45598,7 +46174,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); - ymm0 = _mm256_loadu_ps((float const *)(b11)); + xmm0 = _mm_loadu_ps((float const *)(b11)); + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11 + 2)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 1); ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); @@ -45607,7 +46186,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); ymm8 = _mm256_sub_ps(ymm19, ymm8); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b * 1)); + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b)); + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11 + cs_b + 2)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 1); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5); @@ -45678,7 +46260,8 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); - ymm0 = _mm256_loadu_ps((float const *)(b11)); + xmm0 = _mm_loadu_ps((float const *)(b11)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); @@ -45687,7 +46270,8 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); ymm8 = _mm256_sub_ps(ymm19, ymm8); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b * 1)); + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5); @@ -45753,7 +46337,8 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); - ymm0 = _mm256_loadu_ps((float const *)(b11)); + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 0); ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); @@ -45762,7 +46347,8 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); ymm8 = _mm256_sub_ps(ymm19, ymm8); - ymm0 = 
_mm256_loadu_ps((float const *)(b11 + cs_b * 1)); + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11 + cs_b)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 0); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5); @@ -45984,7 +46570,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); - ymm0 = _mm256_loadu_ps((float const *)(b11)); + xmm0 = _mm_loadu_ps((float const *)(b11)); + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11 + 2)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 1); ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); @@ -46026,7 +46615,8 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); - ymm0 = _mm256_loadu_ps((float const *)(b11)); + xmm0 = _mm_loadu_ps((float const *)(b11)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); @@ -46066,7 +46656,8 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); - ymm0 = _mm256_loadu_ps((float const *)(b11)); + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 0); ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); From e2e1dadee11e3ca1c2b9821192bfd23dcab6578e Mon Sep 17 00:00:00 2001 From: Harsh Dave Date: Wed, 10 Aug 2022 03:59:00 -0500 Subject: [PATCH 176/243] DGEMM Improvements - We prefetch next panel while packing 8xk panel. - Modified prefetch offsets for dgemm native and dgemm_small kernel. AMD-Internal: [CPUPL-2366] Change-Id: Ife609e789c8b87169c73bb0a30d6f1af20fb30ed --- frame/compat/bla_gemm_amd.c | 7 +- .../haswell/1m/bli_packm_haswell_asm_d8xk.c | 14 +- kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c | 435 +++++++++--------- .../3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c | 3 +- kernels/zen/3/bli_gemm_small.c | 36 +- 5 files changed, 260 insertions(+), 235 deletions(-) diff --git a/frame/compat/bla_gemm_amd.c b/frame/compat/bla_gemm_amd.c index d46a69f0f8..942d94f34a 100644 --- a/frame/compat/bla_gemm_amd.c +++ b/frame/compat/bla_gemm_amd.c @@ -556,9 +556,10 @@ void dgemm_ #ifdef BLIS_ENABLE_SMALL_MATRIX - //if( ((m0 + n0 -k0) < 2000) && ((m0 + k0-n0) < 2000) && ((n0 + k0-m0) < 2000) && (n0 > 2)) - if( ( ( (m0 + n0 -k0) < 2000) && ((m0 + k0-n0) < 2000) && ((n0 + k0-m0) < 2000) ) || - ((n0 <= 10) && (k0 <=10)) ) + if(((m0 == n0) && (m0 < 400) && (k0 < 1000)) || + ( (m0 != n0) && (( ((m0 + n0 -k0) < 1500) && + ((m0 + k0-n0) < 1500) && ((n0 + k0-m0) < 1500) ) || + ((n0 <= 100) && (k0 <=100))))) { err_t status = BLIS_FAILURE; if (bli_is_notrans(blis_transa)) diff --git a/kernels/haswell/1m/bli_packm_haswell_asm_d8xk.c b/kernels/haswell/1m/bli_packm_haswell_asm_d8xk.c index 9deb564ce4..3b03d38fb7 100644 --- a/kernels/haswell/1m/bli_packm_haswell_asm_d8xk.c +++ b/kernels/haswell/1m/bli_packm_haswell_asm_d8xk.c @@ -101,6 +101,8 @@ void bli_dpackm_haswell_asm_8xk // assembly region, this constraint should be lifted. 
const bool unitk = bli_deq1( *kappa ); + double* restrict a_next = a + cdim0; + // ------------------------------------------------------------------------- @@ -267,7 +269,7 @@ void bli_dpackm_haswell_asm_8xk label(.DCOLUNIT) lea(mem(r10, r10, 2), r13) // r13 = 3*lda - + mov(var(a_next), rcx) mov(var(k_iter), rsi) // i = k_iter; test(rsi, rsi) // check i via logical AND. je(.DCONKLEFTCOLU) // if i == 0, jump to code that @@ -278,22 +280,27 @@ void bli_dpackm_haswell_asm_8xk vmovupd(mem(rax, 0), ymm0) vmovupd(mem(rax, 32), ymm1) + prefetch(0, mem(rcx,7*8)) vmovupd(ymm0, mem(rbx, 0*64+ 0)) vmovupd(ymm1, mem(rbx, 0*64+32)) vmovupd(mem(rax, r10, 1, 0), ymm2) vmovupd(mem(rax, r10, 1, 32), ymm3) + prefetch(0, mem(rcx, r10, 1,7*8)) vmovupd(ymm2, mem(rbx, 1*64+ 0)) vmovupd(ymm3, mem(rbx, 1*64+32)) vmovupd(mem(rax, r10, 2, 0), ymm4) vmovupd(mem(rax, r10, 2, 32), ymm5) + prefetch(0, mem(rcx, r10, 2,7*8)) vmovupd(ymm4, mem(rbx, 2*64+ 0)) vmovupd(ymm5, mem(rbx, 2*64+32)) vmovupd(mem(rax, r13, 1, 0), ymm6) vmovupd(mem(rax, r13, 1, 32), ymm7) + prefetch(0, mem(rcx, r13, 1,7*8)) add(r14, rax) // a += 4*lda; + add(r14, rcx) vmovupd(ymm6, mem(rbx, 3*64+ 0)) vmovupd(ymm7, mem(rbx, 3*64+32)) add(imm(4*8*8), rbx) // p += 4*ldp = 4*8; @@ -315,7 +322,9 @@ void bli_dpackm_haswell_asm_8xk vmovupd(mem(rax, 0), ymm0) vmovupd(mem(rax, 32), ymm1) + prefetch(0, mem(rcx,7*8)) add(r10, rax) // a += lda; + add(r10, rcx) vmovupd(ymm0, mem(rbx, 0*64+ 0)) vmovupd(ymm1, mem(rbx, 0*64+32)) add(imm(8*8), rbx) // p += ldp = 8; @@ -343,7 +352,8 @@ void bli_dpackm_haswell_asm_8xk [p] "m" (p), [ldp] "m" (ldp), [kappa] "m" (kappa), - [one] "m" (one) + [one] "m" (one), + [a_next] "m" (a_next) : // register clobber list "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", /*"r9",*/ "r10", /*"r11",*/ "r12", "r13", "r14", "r15", diff --git a/kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c b/kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c index e6d47268f7..5187d0bcb0 100644 --- a/kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c +++ b/kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c @@ -102,7 +102,19 @@ void bli_sgemm_haswell_asm_6x16 begin_asm() - vzeroall() // zero all xmm/ymm registers. + //vzeroall() // zero all xmm/ymm registers. + vxorps( ymm4, ymm4, ymm4) + vmovaps( ymm4, ymm5) + vmovaps( ymm4, ymm6) + vmovaps( ymm4, ymm7) + vmovaps( ymm4, ymm8) + vmovaps( ymm4, ymm9) + vmovaps( ymm4, ymm10) + vmovaps( ymm4, ymm11) + vmovaps( ymm4, ymm12) + vmovaps( ymm4, ymm13) + vmovaps( ymm4, ymm14) + vmovaps( ymm4, ymm15) mov(var(a), rax) // load address of a. 
@@ -141,7 +153,7 @@ void bli_sgemm_haswell_asm_6x16 // iteration 0 prefetch(0, mem(rax, 64*4)) - + vbroadcastss(mem(rax, 0*4), ymm2) vbroadcastss(mem(rax, 1*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm4) @@ -167,6 +179,8 @@ void bli_sgemm_haswell_asm_6x16 vmovaps(mem(rbx, -1*32), ymm1) // iteration 1 + prefetch(0, mem(rax, 72*4)) + vbroadcastss(mem(rax, 6*4), ymm2) vbroadcastss(mem(rax, 7*4), ymm3) vfmadd231ps(ymm0, ymm2, ymm4) @@ -192,7 +206,7 @@ void bli_sgemm_haswell_asm_6x16 vmovaps(mem(rbx, 1*32), ymm1) // iteration 2 - prefetch(0, mem(rax, 76*4)) + prefetch(0, mem(rax, 80*4)) vbroadcastss(mem(rax, 12*4), ymm2) vbroadcastss(mem(rax, 13*4), ymm3) @@ -1010,76 +1024,78 @@ void bli_dgemm_haswell_asm_6x8 vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, 2*8), ymm2) vbroadcastsd(mem(rax, 3*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, 4*8), ymm2) vbroadcastsd(mem(rax, 5*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + vmovapd(mem(rbx, -2*32), ymm0) vmovapd(mem(rbx, -1*32), ymm1) - + // iteration 1 + prefetch(0, mem(rax, 72*8)) + vbroadcastsd(mem(rax, 6*8), ymm2) vbroadcastsd(mem(rax, 7*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, 8*8), ymm2) vbroadcastsd(mem(rax, 9*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, 10*8), ymm2) vbroadcastsd(mem(rax, 11*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + vmovapd(mem(rbx, 0*32), ymm0) vmovapd(mem(rbx, 1*32), ymm1) - + // iteration 2 - prefetch(0, mem(rax, 76*8)) - + prefetch(0, mem(rax, 80*8)) + vbroadcastsd(mem(rax, 12*8), ymm2) vbroadcastsd(mem(rax, 13*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, 14*8), ymm2) vbroadcastsd(mem(rax, 15*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, 16*8), ymm2) vbroadcastsd(mem(rax, 17*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + vmovapd(mem(rbx, 2*32), ymm0) vmovapd(mem(rbx, 3*32), ymm1) - + // iteration 3 vbroadcastsd(mem(rax, 18*8), ymm2) vbroadcastsd(mem(rax, 19*8), ymm3) @@ -1087,91 +1103,91 @@ void bli_dgemm_haswell_asm_6x8 vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, 20*8), ymm2) vbroadcastsd(mem(rax, 21*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, 22*8), ymm2) vbroadcastsd(mem(rax, 23*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + add(imm(4*6*8), rax) // a += 4*6 (unroll x mr) add(imm(4*8*8), rbx) // b += 4*8 (unroll x nr) - + vmovapd(mem(rbx, -4*32), ymm0) vmovapd(mem(rbx, -3*32), ymm1) - - + + dec(rsi) // i -= 1; jne(.DLOOPKITER) // iterate again if i != 0. 
- - - - - - + + + + + + label(.DCONSIDKLEFT) - + mov(var(k_left), rsi) // i = k_left; test(rsi, rsi) // check i via logical AND. je(.DPOSTACCUM) // if i == 0, we're done; jump to end. // else, we prepare to enter k_left loop. - - + + label(.DLOOPKLEFT) // EDGE LOOP - + prefetch(0, mem(rax, 64*8)) - + vbroadcastsd(mem(rax, 0*8), ymm2) vbroadcastsd(mem(rax, 1*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm4) vfmadd231pd(ymm1, ymm2, ymm5) vfmadd231pd(ymm0, ymm3, ymm6) vfmadd231pd(ymm1, ymm3, ymm7) - + vbroadcastsd(mem(rax, 2*8), ymm2) vbroadcastsd(mem(rax, 3*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm8) vfmadd231pd(ymm1, ymm2, ymm9) vfmadd231pd(ymm0, ymm3, ymm10) vfmadd231pd(ymm1, ymm3, ymm11) - + vbroadcastsd(mem(rax, 4*8), ymm2) vbroadcastsd(mem(rax, 5*8), ymm3) vfmadd231pd(ymm0, ymm2, ymm12) vfmadd231pd(ymm1, ymm2, ymm13) vfmadd231pd(ymm0, ymm3, ymm14) vfmadd231pd(ymm1, ymm3, ymm15) - + add(imm(1*6*8), rax) // a += 1*6 (unroll x mr) add(imm(1*8*8), rbx) // b += 1*8 (unroll x nr) - + vmovapd(mem(rbx, -4*32), ymm0) vmovapd(mem(rbx, -3*32), ymm1) - - + + dec(rsi) // i -= 1; jne(.DLOOPKLEFT) // iterate again if i != 0. - - - + + + label(.DPOSTACCUM) - - - - + + + + mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate - + vmulpd(ymm0, ymm4, ymm4) // scale by alpha vmulpd(ymm0, ymm5, ymm5) vmulpd(ymm0, ymm6, ymm6) @@ -1184,179 +1200,179 @@ void bli_dgemm_haswell_asm_6x8 vmulpd(ymm0, ymm13, ymm13) vmulpd(ymm0, ymm14, ymm14) vmulpd(ymm0, ymm15, ymm15) - - - - - - + + + + + + mov(var(cs_c), rsi) // load cs_c lea(mem(, rsi, 8), rsi) // rsi = cs_c * sizeof(double) - + lea(mem(rcx, rsi, 4), rdx) // load address of c + 4*cs_c; lea(mem(rcx, rdi, 4), r14) // load address of c + 4*rs_c; - + lea(mem(rsi, rsi, 2), r13) // r13 = 3*cs_c; //lea(mem(rsi, rsi, 4), r15) // r15 = 5*cs_c; //lea(mem(r13, rsi, 4), r10) // r10 = 7*cs_c; - - + + // now avoid loading C if beta == 0 - + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. vucomisd(xmm0, xmm3) // set ZF if beta == 0. je(.DBETAZERO) // if ZF = 1, jump to beta == 0 case - - + + cmp(imm(8), rsi) // set ZF if (8*cs_c) == 8. jz(.DROWSTORED) // jump to row storage case - - + + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. 
jz(.DCOLSTORED) // jump to column storage case - - - + + + label(.DGENSTORED) - - + + DGEMM_INPUT_GS_BETA_NZ vfmadd213pd(ymm4, ymm3, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + DGEMM_INPUT_GS_BETA_NZ vfmadd213pd(ymm6, ymm3, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + DGEMM_INPUT_GS_BETA_NZ vfmadd213pd(ymm8, ymm3, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + DGEMM_INPUT_GS_BETA_NZ vfmadd213pd(ymm10, ymm3, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + DGEMM_INPUT_GS_BETA_NZ vfmadd213pd(ymm12, ymm3, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + DGEMM_INPUT_GS_BETA_NZ vfmadd213pd(ymm14, ymm3, ymm0) DGEMM_OUTPUT_GS_BETA_NZ - - + + mov(rdx, rcx) // rcx = c + 4*cs_c - - + + DGEMM_INPUT_GS_BETA_NZ vfmadd213pd(ymm5, ymm3, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + DGEMM_INPUT_GS_BETA_NZ vfmadd213pd(ymm7, ymm3, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + DGEMM_INPUT_GS_BETA_NZ vfmadd213pd(ymm9, ymm3, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + DGEMM_INPUT_GS_BETA_NZ vfmadd213pd(ymm11, ymm3, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + DGEMM_INPUT_GS_BETA_NZ vfmadd213pd(ymm13, ymm3, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + DGEMM_INPUT_GS_BETA_NZ vfmadd213pd(ymm15, ymm3, ymm0) DGEMM_OUTPUT_GS_BETA_NZ - - - + + + jmp(.DDONE) // jump to end. - - - + + + label(.DROWSTORED) - - + + vfmadd231pd(mem(rcx), ymm3, ymm4) vmovupd(ymm4, mem(rcx)) add(rdi, rcx) vfmadd231pd(mem(rdx), ymm3, ymm5) vmovupd(ymm5, mem(rdx)) add(rdi, rdx) - - + + vfmadd231pd(mem(rcx), ymm3, ymm6) vmovupd(ymm6, mem(rcx)) add(rdi, rcx) vfmadd231pd(mem(rdx), ymm3, ymm7) vmovupd(ymm7, mem(rdx)) add(rdi, rdx) - - + + vfmadd231pd(mem(rcx), ymm3, ymm8) vmovupd(ymm8, mem(rcx)) add(rdi, rcx) vfmadd231pd(mem(rdx), ymm3, ymm9) vmovupd(ymm9, mem(rdx)) add(rdi, rdx) - - + + vfmadd231pd(mem(rcx), ymm3, ymm10) vmovupd(ymm10, mem(rcx)) add(rdi, rcx) vfmadd231pd(mem(rdx), ymm3, ymm11) vmovupd(ymm11, mem(rdx)) add(rdi, rdx) - - + + vfmadd231pd(mem(rcx), ymm3, ymm12) vmovupd(ymm12, mem(rcx)) add(rdi, rcx) vfmadd231pd(mem(rdx), ymm3, ymm13) vmovupd(ymm13, mem(rdx)) add(rdi, rdx) - - + + vfmadd231pd(mem(rcx), ymm3, ymm14) vmovupd(ymm14, mem(rcx)) //add(rdi, rcx) vfmadd231pd(mem(rdx), ymm3, ymm15) vmovupd(ymm15, mem(rdx)) //add(rdi, rdx) - - - + + + jmp(.DDONE) // jump to end. 
- - - + + + label(.DCOLSTORED) - - + + vunpcklpd(ymm6, ymm4, ymm0) vunpckhpd(ymm6, ymm4, ymm1) vunpcklpd(ymm10, ymm8, ymm2) @@ -1365,9 +1381,9 @@ void bli_dgemm_haswell_asm_6x8 vinsertf128(imm(0x1), xmm3, ymm1, ymm6) vperm2f128(imm(0x31), ymm2, ymm0, ymm8) vperm2f128(imm(0x31), ymm3, ymm1, ymm10) - + vbroadcastsd(mem(rbx), ymm3) - + vfmadd231pd(mem(rcx), ymm3, ymm4) vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm6) vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm8) @@ -1376,14 +1392,14 @@ void bli_dgemm_haswell_asm_6x8 vmovupd(ymm6, mem(rcx, rsi, 1)) vmovupd(ymm8, mem(rcx, rsi, 2)) vmovupd(ymm10, mem(rcx, r13, 1)) - + lea(mem(rcx, rsi, 4), rcx) - + vunpcklpd(ymm14, ymm12, ymm0) vunpckhpd(ymm14, ymm12, ymm1) vextractf128(imm(0x1), ymm0, xmm2) vextractf128(imm(0x1), ymm1, xmm4) - + vfmadd231pd(mem(r14), xmm3, xmm0) vfmadd231pd(mem(r14, rsi, 1), xmm3, xmm1) vfmadd231pd(mem(r14, rsi, 2), xmm3, xmm2) @@ -1392,10 +1408,10 @@ void bli_dgemm_haswell_asm_6x8 vmovupd(xmm1, mem(r14, rsi, 1)) vmovupd(xmm2, mem(r14, rsi, 2)) vmovupd(xmm4, mem(r14, r13, 1)) - + lea(mem(r14, rsi, 4), r14) - - + + vunpcklpd(ymm7, ymm5, ymm0) vunpckhpd(ymm7, ymm5, ymm1) vunpcklpd(ymm11, ymm9, ymm2) @@ -1404,9 +1420,9 @@ void bli_dgemm_haswell_asm_6x8 vinsertf128(imm(0x1), xmm3, ymm1, ymm7) vperm2f128(imm(0x31), ymm2, ymm0, ymm9) vperm2f128(imm(0x31), ymm3, ymm1, ymm11) - + vbroadcastsd(mem(rbx), ymm3) - + vfmadd231pd(mem(rcx), ymm3, ymm5) vfmadd231pd(mem(rcx, rsi, 1), ymm3, ymm7) vfmadd231pd(mem(rcx, rsi, 2), ymm3, ymm9) @@ -1415,14 +1431,14 @@ void bli_dgemm_haswell_asm_6x8 vmovupd(ymm7, mem(rcx, rsi, 1)) vmovupd(ymm9, mem(rcx, rsi, 2)) vmovupd(ymm11, mem(rcx, r13, 1)) - + //lea(mem(rcx, rsi, 4), rcx) - + vunpcklpd(ymm15, ymm13, ymm0) vunpckhpd(ymm15, ymm13, ymm1) vextractf128(imm(0x1), ymm0, xmm2) vextractf128(imm(0x1), ymm1, xmm4) - + vfmadd231pd(mem(r14), xmm3, xmm0) vfmadd231pd(mem(r14, rsi, 1), xmm3, xmm1) vfmadd231pd(mem(r14, rsi, 2), xmm3, xmm2) @@ -1431,139 +1447,139 @@ void bli_dgemm_haswell_asm_6x8 vmovupd(xmm1, mem(r14, rsi, 1)) vmovupd(xmm2, mem(r14, rsi, 2)) vmovupd(xmm4, mem(r14, r13, 1)) - + //lea(mem(r14, rsi, 4), r14) - - - + + + jmp(.DDONE) // jump to end. - - - + + + label(.DBETAZERO) - + cmp(imm(8), rsi) // set ZF if (8*cs_c) == 8. jz(.DROWSTORBZ) // jump to row storage case - + cmp(imm(8), rdi) // set ZF if (8*rs_c) == 8. jz(.DCOLSTORBZ) // jump to column storage case - - - + + + label(.DGENSTORBZ) - - + + vmovapd(ymm4, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + vmovapd(ymm6, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + vmovapd(ymm8, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + vmovapd(ymm10, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + vmovapd(ymm12, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + vmovapd(ymm14, ymm0) DGEMM_OUTPUT_GS_BETA_NZ - - + + mov(rdx, rcx) // rcx = c + 4*cs_c - - + + vmovapd(ymm5, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + vmovapd(ymm7, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + vmovapd(ymm9, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + vmovapd(ymm11, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + vmovapd(ymm13, ymm0) DGEMM_OUTPUT_GS_BETA_NZ add(rdi, rcx) // c += rs_c; - - + + vmovapd(ymm15, ymm0) DGEMM_OUTPUT_GS_BETA_NZ - - - + + + jmp(.DDONE) // jump to end. 
- - - + + + label(.DROWSTORBZ) - - + + vmovupd(ymm4, mem(rcx)) add(rdi, rcx) vmovupd(ymm5, mem(rdx)) add(rdi, rdx) - + vmovupd(ymm6, mem(rcx)) add(rdi, rcx) vmovupd(ymm7, mem(rdx)) add(rdi, rdx) - - + + vmovupd(ymm8, mem(rcx)) add(rdi, rcx) vmovupd(ymm9, mem(rdx)) add(rdi, rdx) - - + + vmovupd(ymm10, mem(rcx)) add(rdi, rcx) vmovupd(ymm11, mem(rdx)) add(rdi, rdx) - - + + vmovupd(ymm12, mem(rcx)) add(rdi, rcx) vmovupd(ymm13, mem(rdx)) add(rdi, rdx) - - + + vmovupd(ymm14, mem(rcx)) //add(rdi, rcx) vmovupd(ymm15, mem(rdx)) //add(rdi, rdx) - - + + jmp(.DDONE) // jump to end. - - - + + + label(.DCOLSTORBZ) - - + + vunpcklpd(ymm6, ymm4, ymm0) vunpckhpd(ymm6, ymm4, ymm1) vunpcklpd(ymm10, ymm8, ymm2) @@ -1572,27 +1588,27 @@ void bli_dgemm_haswell_asm_6x8 vinsertf128(imm(0x1), xmm3, ymm1, ymm6) vperm2f128(imm(0x31), ymm2, ymm0, ymm8) vperm2f128(imm(0x31), ymm3, ymm1, ymm10) - + vmovupd(ymm4, mem(rcx)) vmovupd(ymm6, mem(rcx, rsi, 1)) vmovupd(ymm8, mem(rcx, rsi, 2)) vmovupd(ymm10, mem(rcx, r13, 1)) - + lea(mem(rcx, rsi, 4), rcx) - + vunpcklpd(ymm14, ymm12, ymm0) vunpckhpd(ymm14, ymm12, ymm1) vextractf128(imm(0x1), ymm0, xmm2) vextractf128(imm(0x1), ymm1, xmm4) - + vmovupd(xmm0, mem(r14)) vmovupd(xmm1, mem(r14, rsi, 1)) vmovupd(xmm2, mem(r14, rsi, 2)) vmovupd(xmm4, mem(r14, r13, 1)) - + lea(mem(r14, rsi, 4), r14) - - + + vunpcklpd(ymm7, ymm5, ymm0) vunpckhpd(ymm7, ymm5, ymm1) vunpcklpd(ymm11, ymm9, ymm2) @@ -1601,32 +1617,31 @@ void bli_dgemm_haswell_asm_6x8 vinsertf128(imm(0x1), xmm3, ymm1, ymm7) vperm2f128(imm(0x31), ymm2, ymm0, ymm9) vperm2f128(imm(0x31), ymm3, ymm1, ymm11) - + vmovupd(ymm5, mem(rcx)) vmovupd(ymm7, mem(rcx, rsi, 1)) vmovupd(ymm9, mem(rcx, rsi, 2)) vmovupd(ymm11, mem(rcx, r13, 1)) - + //lea(mem(rcx, rsi, 4), rcx) - + vunpcklpd(ymm15, ymm13, ymm0) vunpckhpd(ymm15, ymm13, ymm1) vextractf128(imm(0x1), ymm0, xmm2) vextractf128(imm(0x1), ymm1, xmm4) - + vmovupd(xmm0, mem(r14)) vmovupd(xmm1, mem(r14, rsi, 1)) vmovupd(xmm2, mem(r14, rsi, 2)) vmovupd(xmm4, mem(r14, r13, 1)) - + //lea(mem(r14, rsi, 4), r14) - - - - label(.DDONE) + + + label(.DDONE) vzeroupper() - + end_asm( : // output operands (none) diff --git a/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c b/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c index 41da73f361..107917d078 100644 --- a/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c +++ b/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c @@ -867,8 +867,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8m label(.DRETURN) - - + vzeroupper() end_asm( : // output operands (none) diff --git a/kernels/zen/3/bli_gemm_small.c b/kernels/zen/3/bli_gemm_small.c index 0cf5c8c5ce..e232e28e51 100644 --- a/kernels/zen/3/bli_gemm_small.c +++ b/kernels/zen/3/bli_gemm_small.c @@ -1951,12 +1951,12 @@ static err_t bli_sgemm_small tA_packed = D_A_pack; #ifdef BLIS_ENABLE_PREFETCH - _mm_prefetch((char*)(tC + 0), _MM_HINT_T0); - _mm_prefetch((char*)(tC + 8), _MM_HINT_T0); - _mm_prefetch((char*)(tC + ldc), _MM_HINT_T0); - _mm_prefetch((char*)(tC + ldc + 8), _MM_HINT_T0); - _mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0); - _mm_prefetch((char*)(tC + 2 * ldc + 8), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 7), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 15), _MM_HINT_T0); + _mm_prefetch((char*)(tC + ldc + 7), _MM_HINT_T0); + _mm_prefetch((char*)(tC + ldc + 15), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 2 * ldc + 7), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 2 * ldc + 15), _MM_HINT_T0); #endif // clear scratch registers. 
ymm4 = _mm256_setzero_pd(); @@ -2111,12 +2111,12 @@ static err_t bli_sgemm_small tA = tA_packed + row_idx_packed; #ifdef BLIS_ENABLE_PREFETCH - _mm_prefetch((char*)(tC + 0), _MM_HINT_T0); - _mm_prefetch((char*)(tC + 8), _MM_HINT_T0); - _mm_prefetch((char*)(tC + ldc), _MM_HINT_T0); - _mm_prefetch((char*)(tC + ldc + 8), _MM_HINT_T0); - _mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0); - _mm_prefetch((char*)(tC + 2 * ldc + 8), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 7), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 15), _MM_HINT_T0); + _mm_prefetch((char*)(tC + ldc + 7), _MM_HINT_T0); + _mm_prefetch((char*)(tC + ldc + 15), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 2 * ldc + 7), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 2 * ldc + 15), _MM_HINT_T0); #endif // clear scratch registers. ymm4 = _mm256_setzero_pd(); @@ -4513,12 +4513,12 @@ err_t bli_dgemm_small_At tA = tA_packed + row_idx_packed; #ifdef BLIS_ENABLE_PREFETCH - _mm_prefetch((char*)(tC + 0), _MM_HINT_T0); - _mm_prefetch((char*)(tC + 8), _MM_HINT_T0); - _mm_prefetch((char*)(tC + ldc), _MM_HINT_T0); - _mm_prefetch((char*)(tC + ldc + 8), _MM_HINT_T0); - _mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0); - _mm_prefetch((char*)(tC + 2 * ldc + 8), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 7), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 15), _MM_HINT_T0); + _mm_prefetch((char*)(tC + ldc + 7), _MM_HINT_T0); + _mm_prefetch((char*)(tC + ldc + 15), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 2 * ldc + 7), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 2 * ldc + 15), _MM_HINT_T0); #endif // clear scratch registers. ymm4 = _mm256_setzero_pd(); From 39196d163e134b61a031c5aa58ac430e9176588d Mon Sep 17 00:00:00 2001 From: Nallani Bhaskar Date: Mon, 8 Aug 2022 17:20:57 +0530 Subject: [PATCH 177/243] Enable packing of A & B dynamically in dgemmsup. Details: - When work distributed for each thread is larger than caches, it is advisable to perform packing of B for sup dgemm. - Work distribution per thread is calculated based on the values of jc_nt and ic_nt. - For RRC and CRC cases we want to avoid rd kernels which are not efficient in performance compared to rv kernels. Therefore we perform packing of A as well so that rv kernels are invoked for these cases. - These changes result in improved DGEMM performance. 
- Dynamic packing is done using the API "bli_rntm_set_pack_b( 1, rntm )" Change-Id: I8344520b4a2591e57518bb54183a15957f60f94b --- frame/3/bli_l3_sup_int_amd.c | 64 +++++++++++++++++++++++++++++------- 1 file changed, 52 insertions(+), 12 deletions(-) diff --git a/frame/3/bli_l3_sup_int_amd.c b/frame/3/bli_l3_sup_int_amd.c index bbf5637555..1bc4fa8caa 100644 --- a/frame/3/bli_l3_sup_int_amd.c +++ b/frame/3/bli_l3_sup_int_amd.c @@ -113,22 +113,42 @@ err_t bli_gemmsup_int bli_l3_sup_thrinfo_update_root( rntm, thread ); } - /*Enable packing for B matrix for higher sizes*/ + //Enable packing for B matrix for higher sizes if(bli_is_float(dt) && (n_threads==1)) { if((m > 240) && (k > 240) && (n > 240)) - bli_rntm_set_pack_b( 1, rntm ); + bli_rntm_set_pack_b( 1, rntm );//packb } - /*Enable packing of B matrix for complex data type*/ + //Enable packing of B matrix for complex data type if (bli_is_dcomplex(dt) && (n_threads == 1)) { if ((m > 55) && (k > 55) && (n > 55)) - bli_rntm_set_pack_b(1, rntm); + bli_rntm_set_pack_b(1, rntm);//packb } - bli_gemmsup_ref_var2m( BLIS_NO_TRANSPOSE, - alpha, a, b, beta, c, - stor_id, cntx, rntm, thread ); + //Enable packing of B matrix for double data type when dims at per + //thread level are above caches and enable packing of A when transA + //(RRC or CRC storage ids) to avoid rd kernels + if(bli_is_double(dt)) + { + dim_t m_pt = (m/bli_rntm_ways_for( BLIS_MC, rntm )); + dim_t n_pt = (n/bli_rntm_ways_for( BLIS_NC, rntm )); + + if(k > 120) + { + if(((m_pt > 320) && (n_pt > 120)) || ((m_pt > 120) && (n_pt > 320))) + { + bli_rntm_set_pack_b(1, rntm);//packb + + if(stor_id==BLIS_RRC || stor_id==BLIS_CRC) + bli_rntm_set_pack_a(1, rntm);//packa + } + } + } + + bli_gemmsup_ref_var2m(BLIS_NO_TRANSPOSE, + alpha, a, b, beta, c, + stor_id, cntx, rntm, thread ); } else { @@ -156,19 +176,39 @@ err_t bli_gemmsup_int * becomes pack B inside var2m because this is transpose case*/ if(bli_is_float(dt) && (n_threads==1)) { if((m > 240) && (k > 240) && (n > 240)) - bli_rntm_set_pack_a( 1, rntm ); + bli_rntm_set_pack_a( 1, rntm );//packb } /*Enable packing of A matrix for complex data type*/ if (bli_is_dcomplex(dt) && (n_threads == 1)) { if ((m > 55) && (k > 55) && (n > 55)) - bli_rntm_set_pack_a(1, rntm); + bli_rntm_set_pack_a(1, rntm);//packb } - bli_gemmsup_ref_var2m( BLIS_TRANSPOSE, - alpha, a, b, beta, c, - stor_id, cntx, rntm, thread ); + //Enable packing of B matrix for double data type when dims at per + //thread level are above caches and enable packing of A when transA + //(RRC or CRC storage ids) to avoid rd kernels + if(bli_is_double(dt)) + { + dim_t m_pt = (m/bli_rntm_ways_for( BLIS_NC, rntm )); + dim_t n_pt = (n/bli_rntm_ways_for( BLIS_MC, rntm )); + + if(k > 120) + { + if(((m_pt > 320) && (n_pt > 120)) || ((m_pt > 120) && (n_pt > 320))) + { + bli_rntm_set_pack_a(1, rntm);//packb + + if(stor_id==BLIS_RRC || stor_id==BLIS_CRC) + bli_rntm_set_pack_b(1, rntm);//packa + } + } + } + + bli_gemmsup_ref_var2m(BLIS_TRANSPOSE, + alpha, a, b, beta, c, + stor_id, cntx, rntm, thread ); } AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4); From 46e7727ea8d112c928b50dd3802f9ae3d2e43e6d Mon Sep 17 00:00:00 2001 From: Harsh Dave Date: Wed, 17 Aug 2022 02:24:46 -0500 Subject: [PATCH 178/243] DGEMM Improvements - Incase of DGEMM when m, n and leading dimensions are large packing of A and B matrixes are required for optimal performance. - Modified decision logic to choose between sup vs native, now apart from matrix dimensions, we also incorporate matrix leading dimensions into this decision. 
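To make the new routing rule concrete, the sketch below (not part of this patch) restates the decision the diff adds to bli_cntx_l3_sup_thresh_is_met() in frame/base/bli_cntx.h: when the problem is large in every dimension and both operands also have large strides, report the sup thresholds as not met so the operation falls back to the native path, which packs both A and B by default. The helper name dgemm_prefer_sup and the plain long/bool types are illustrative only; the thresholds (m > 5000, n > 700, k > 120, strides > 5000) are taken from the diff that follows.

    #include <stdbool.h>
    #include <stdio.h>

    /* Illustrative helper (not a BLIS API). Returns true if a DGEMM problem
     * should stay in the sup path, false if it should fall back to the
     * native path, which packs both A and B by default. The thresholds
     * mirror the values added to bli_cntx_l3_sup_thresh_is_met() below. */
    static bool dgemm_prefer_sup( long m, long n, long k,
                                  long rs_a, long cs_a,
                                  long rs_b, long cs_b )
    {
        /* Effective leading dimension = larger of the row/column strides. */
        long stride_a = rs_a > cs_a ? rs_a : cs_a;
        long stride_b = rs_b > cs_b ? rs_b : cs_b;

        /* Large in all three dimensions AND both operands have large
         * strides: unpacked accesses jump far apart in memory, so the
         * packing done by the native path wins. */
        if ( ( m > 5000 && n > 700 && k > 120 ) &&
             ( stride_a > 5000 && stride_b > 5000 ) )
            return false;

        return true; /* otherwise the existing dimension thresholds decide */
    }

    int main( void )
    {
        /* Column-major 6000x6000x6000 DGEMM with lda = ldb = 6000:
         * large dimensions and large strides, so route to native. */
        printf( "%s\n",
                dgemm_prefer_sup( 6000, 6000, 6000, 1, 6000, 1, 6000 )
                ? "sup" : "native" );
        return 0;
    }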
AMD-Internal: [CPUPL-2366] Change-Id: I255db5f7049d783e22d7c912edf8bbf023e32ed8 --- frame/base/bli_cntx.h | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/frame/base/bli_cntx.h b/frame/base/bli_cntx.h index d868167234..3715d70c9f 100644 --- a/frame/base/bli_cntx.h +++ b/frame/base/bli_cntx.h @@ -621,6 +621,30 @@ BLIS_INLINE bool bli_cntx_l3_sup_thresh_is_met( obj_t* a, obj_t* b, obj_t* c, cn } + + if(dt == BLIS_DOUBLE) + { + /** + * In case of both matrices having large strides, + * are to be handled in native path, since native + * path does packing of both matrices by default. + * It helps avoiding huge memory jumps while accessing + * matrices during GEMM computation. + */ + dim_t k = bli_obj_width( a ); + inc_t rs_a = bli_obj_row_stride( a ); + inc_t cs_a = bli_obj_col_stride( a ); + inc_t rs_b = bli_obj_row_stride( b ); + inc_t cs_b = bli_obj_col_stride( b ); + inc_t stride_a = rs_a > cs_a ? rs_a : cs_a; + inc_t stride_b = rs_b > cs_b ? rs_b : cs_b; + if( (m > 5000 && n > 700 && k > 120) && (stride_a > 5000 && stride_b > 5000) ) + { + return FALSE; + } + } + + if ( m < bli_cntx_get_l3_sup_thresh_dt( dt, BLIS_MT, cntx ) ) return TRUE; if ( n < bli_cntx_get_l3_sup_thresh_dt( dt, BLIS_NT, cntx ) ) return TRUE; if ( k < bli_cntx_get_l3_sup_thresh_dt( dt, BLIS_KT, cntx ) ) return TRUE; From 171fb7358dce3f973c5c76a22eb6a65c60979b82 Mon Sep 17 00:00:00 2001 From: mkadavil Date: Wed, 17 Aug 2022 16:14:48 +0530 Subject: [PATCH 179/243] SGEMM Optimization -sup GEMM - 2 variants var2m (block-panel) and var1n (panel-block). We added decision logic to choose between var1n and var2m for single thread SGEMM.var1n is favorable option when "n" is very large compared to "m". -Also fixed a bug related to fetching "MR" "NR" values in bli_gemmsup_int(). We replaced "bli_cntx_get_blksz_def_dt()(used for Native)" with "bli_cntx_get_l3_sup_blksz_def_dt()". AMD-Internal: [CPUPL-2406] Change-Id: If36529015b1c5f8f87eb40c05ebcf433c471d4d5 --- frame/3/bli_l3_sup_int_amd.c | 52 +++++++++++++++++++++++++----------- 1 file changed, 36 insertions(+), 16 deletions(-) diff --git a/frame/3/bli_l3_sup_int_amd.c b/frame/3/bli_l3_sup_int_amd.c index 1bc4fa8caa..b226b135d0 100644 --- a/frame/3/bli_l3_sup_int_amd.c +++ b/frame/3/bli_l3_sup_int_amd.c @@ -52,21 +52,15 @@ err_t bli_gemmsup_int const dim_t m = bli_obj_length( c ); const dim_t n = bli_obj_width( c ); const dim_t k = bli_obj_width( a ); - const dim_t MR = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); - const dim_t NR = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); + const dim_t MR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_MR, cntx ); + const dim_t NR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NR, cntx ); + const dim_t KC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_KC, cntx ); const bool auto_factor = bli_rntm_auto_factor( rntm ); const dim_t n_threads = bli_rntm_num_threads( rntm ); - + bool use_pb = FALSE; dim_t jc_new; dim_t ic_new; - - //bli_gemmsup_ref_var2 - //bli_gemmsup_ref_var1 - #if 0 - bli_gemmsup_ref_var1n - #else - #endif const stor3_t stor_id = bli_obj_stor3_from_strides( c, a, b ); const bool is_rrr_rrc_rcr_crr = ( stor_id == BLIS_RRR || stor_id == BLIS_RRC || @@ -96,6 +90,9 @@ err_t bli_gemmsup_int const dim_t mu = m / MR; const dim_t nu = n / NR; + // Heuristic to decide whether to use 1n variant or not for sgemm. + use_pb = ( ( nu >= ( 4 * mu ) ) && ( k >= KC ) ) ? 
TRUE : FALSE; + // If the parallel thread factorization was automatic, we update it // with a new factorization based on the matrix dimensions in units // of micropanels. However in case smart threading is enabled, @@ -146,9 +143,21 @@ err_t bli_gemmsup_int } } - bli_gemmsup_ref_var2m(BLIS_NO_TRANSPOSE, - alpha, a, b, beta, c, - stor_id, cntx, rntm, thread ); + // Using the 1n kernel (B broadcast) gave better performance for sgemm + // in single-thread scenario, given the number of n panels are + // sufficiently larger than m panels. + if ( bli_is_float( dt ) && ( n_threads == 1 ) && ( use_pb == TRUE ) ) + { + bli_gemmsup_ref_var1n( BLIS_NO_TRANSPOSE, + alpha, a, b, beta, c, + stor_id, cntx, rntm, thread ); + } + else + { + bli_gemmsup_ref_var2m( BLIS_NO_TRANSPOSE, + alpha, a, b, beta, c, + stor_id, cntx, rntm, thread ); + } } else { @@ -159,6 +168,8 @@ err_t bli_gemmsup_int const dim_t mu = n / MR; // the n becomes m after a transposition const dim_t nu = m / NR; // the m becomes n after a transposition + use_pb = ( ( nu >= ( 4 * mu ) ) && ( k >= KC ) ) ? TRUE : FALSE; + if ( auto_factor ) { // In the block-panel algorithm, the m dimension is parallelized @@ -206,9 +217,18 @@ err_t bli_gemmsup_int } } - bli_gemmsup_ref_var2m(BLIS_TRANSPOSE, - alpha, a, b, beta, c, - stor_id, cntx, rntm, thread ); + if ( bli_is_float( dt ) && ( n_threads == 1 ) && ( use_pb == TRUE ) ) + { + bli_gemmsup_ref_var1n( BLIS_TRANSPOSE, + alpha, a, b, beta, c, + stor_id, cntx, rntm, thread ); + } + else + { + bli_gemmsup_ref_var2m( BLIS_TRANSPOSE, + alpha, a, b, beta, c, + stor_id, cntx, rntm, thread ); + } } AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4); From 7f322da01db3dc98c2245888a8f536eb0d4de7e3 Mon Sep 17 00:00:00 2001 From: Edward Smyth Date: Tue, 9 Aug 2022 06:06:07 -0400 Subject: [PATCH 180/243] BLIS: BLAS3 quick return functionality Implement netlib BLAS style quick return functionality for when no work is required. Similar functionality was already in HERK and HER2K routines. AMD copyrights updated. AMD-Internal: [CPUPL-2373] Change-Id: I0ebe9d76465b0e48b2ff5c2f1cc2a75763fe187c --- frame/compat/bla_gemm.c | 35 ++++++++++++++++++++++- frame/compat/bla_gemm3m.c | 22 ++++++++++++++- frame/compat/bla_gemm_amd.c | 56 ++++++++++++++++++++++++++++++++++++- frame/compat/bla_gemmt.c | 22 ++++++++++++++- frame/compat/bla_hemm.c | 22 ++++++++++++++- frame/compat/bla_symm.c | 22 ++++++++++++++- frame/compat/bla_syr2k.c | 22 ++++++++++++++- frame/compat/bla_syrk.c | 22 ++++++++++++++- frame/compat/bla_trmm.c | 20 ++++++++++++- frame/compat/bla_trsm.c | 18 ++++++++++++ frame/compat/bla_trsm_amd.c | 54 +++++++++++++++++++++++++++++++++++ 11 files changed, 306 insertions(+), 9 deletions(-) diff --git a/frame/compat/bla_gemm.c b/frame/compat/bla_gemm.c index 406ff69d53..d3601952ba 100644 --- a/frame/compat/bla_gemm.c +++ b/frame/compat/bla_gemm.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019 - 22, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -86,6 +86,17 @@ void PASTEF77(ch,blasname) \ ldb, \ ldc \ ); \ +\ + /* Quick return if possible. 
*/ \ + if ( *m == 0 || *n == 0 || (( PASTEMAC(ch,eq0)( *alpha ) || *k == 0) \ + && PASTEMAC(ch,eq1)( *beta ) )) \ + { \ + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ + return; \ + } \ \ /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ @@ -175,6 +186,17 @@ void PASTEF77(ch,blasname) \ ldb, \ ldc \ ); \ +\ + /* Quick return if possible. */ \ + if ( *m == 0 || *n == 0 || (( PASTEMAC(ch,eq0)( *alpha ) || *k == 0) \ + && PASTEMAC(ch,eq1)( *beta ) )) \ + { \ + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ + return; \ + } \ \ /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ @@ -344,6 +366,17 @@ void dzgemm_ ldc ); + /* Quick return if possible. */ + if ( *m == 0 || *n == 0 || (( PASTEMAC(ch,eq0)( *alpha ) || *k == 0) + && PASTEMAC(ch,eq1)( *beta ) )) + { + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS. */ + bli_finalize_auto(); + return; \ + } + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); bli_param_map_netlib_to_blis_trans( *transb, &blis_transb ); diff --git a/frame/compat/bla_gemm3m.c b/frame/compat/bla_gemm3m.c index e51cc314de..665c8643dd 100644 --- a/frame/compat/bla_gemm3m.c +++ b/frame/compat/bla_gemm3m.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -83,6 +83,16 @@ void PASTEF77(ch,blasname) \ ldb, \ ldc \ ); \ +\ + /* Quick return if possible. */ \ + if ( *m == 0 || *n == 0 || (( PASTEMAC(ch,eq0)( *alpha ) || *k == 0) \ + && PASTEMAC(ch,eq1)( *beta ) )) \ + { \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ + return; \ + } \ \ /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ @@ -164,6 +174,16 @@ void PASTEF77(ch,blasname) \ ldb, \ ldc \ ); \ +\ + /* Quick return if possible. */ \ + if ( *m == 0 || *n == 0 || (( PASTEMAC(ch,eq0)( *alpha ) || *k == 0) \ + && PASTEMAC(ch,eq1)( *beta ) )) \ + { \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ + return; \ + } \ \ /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ diff --git a/frame/compat/bla_gemm_amd.c b/frame/compat/bla_gemm_amd.c index 942d94f34a..adc83f073d 100644 --- a/frame/compat/bla_gemm_amd.c +++ b/frame/compat/bla_gemm_amd.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019 - 22, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2022, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -86,6 +86,17 @@ void PASTEF77(ch,blasname) \ ldb, \ ldc \ ); \ +\ + /* Quick return if possible. */ \ + if ( *m == 0 || *n == 0 || (( PASTEMAC(ch,eq0)( *alpha ) || *k == 0) \ + && PASTEMAC(ch,eq1)( *beta ) )) \ + { \ + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ + return; \ + } \ \ /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ @@ -175,6 +186,17 @@ void PASTEF77(ch,blasname) \ ldb, \ ldc \ ); \ +\ + /* Quick return if possible. */ \ + if ( *m == 0 || *n == 0 || (( PASTEMAC(ch,eq0)( *alpha ) || *k == 0) \ + && PASTEMAC(ch,eq1)( *beta ) )) \ + { \ + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ + return; \ + } \ \ /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ @@ -343,6 +365,16 @@ void dgemm_ ldc ); + /* Quick return if possible. */ + if ( *m == 0 || *n == 0 || ((*alpha == 0.0 || *k == 0) && *beta == 1.0)) + { + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS. */ + bli_finalize_auto(); + return; + } + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ bli_param_map_netlib_to_blis_trans(*transa, &blis_transa); bli_param_map_netlib_to_blis_trans(*transb, &blis_transb); @@ -666,6 +698,17 @@ void zgemm_ ldc ); + /* Quick return if possible. */ + if ( *m == 0 || *n == 0 || (( PASTEMAC(z,eq0)( *alpha ) || *k == 0) + && PASTEMAC(z,eq1)( *beta ) )) + { + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS. */ + bli_finalize_auto(); + return; + } + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); bli_param_map_netlib_to_blis_trans( *transb, &blis_transb ); @@ -918,6 +961,17 @@ void dzgemm_ ldc ); + /* Quick return if possible. */ + if ( *m == 0 || *n == 0 || (( PASTEMAC(z,eq0)( *alpha ) || *k == 0) + && PASTEMAC(z,eq1)( *beta ) )) + { + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS. */ + bli_finalize_auto(); + return; + } + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); bli_param_map_netlib_to_blis_trans( *transb, &blis_transb ); diff --git a/frame/compat/bla_gemmt.c b/frame/compat/bla_gemmt.c index e51b943667..7abad40acf 100644 --- a/frame/compat/bla_gemmt.c +++ b/frame/compat/bla_gemmt.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2020, Advanced Micro Devices, Inc. + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -84,6 +84,16 @@ void PASTEF77(ch,blasname) \ ldb, \ ldc \ ); \ +\ + /* Quick return if possible. 
*/ \ + if ( *n == 0 || (( PASTEMAC(ch,eq0)( *alpha ) || *k == 0) \ + && PASTEMAC(ch,eq1)( *beta ) )) \ + { \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ + return; \ + } \ \ /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ @@ -170,6 +180,16 @@ void PASTEF77(ch,blasname) \ ldb, \ ldc \ ); \ +\ + /* Quick return if possible. */ \ + if ( *n == 0 || (( PASTEMAC(ch,eq0)( *alpha ) || *k == 0) \ + && PASTEMAC(ch,eq1)( *beta ) )) \ + { \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ + return; \ + } \ \ /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ diff --git a/frame/compat/bla_hemm.c b/frame/compat/bla_hemm.c index fcd7858731..0e003012d2 100644 --- a/frame/compat/bla_hemm.c +++ b/frame/compat/bla_hemm.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019 - 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -84,6 +84,16 @@ void PASTEF77(ch,blasname) \ ldb, \ ldc \ ); \ +\ + /* Quick return if possible. */ \ + if ( *m == 0 || *n == 0 || ( PASTEMAC(ch,eq0)( *alpha ) \ + && PASTEMAC(ch,eq1)( *beta ) )) \ + { \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ + return; \ + } \ \ /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ bli_param_map_netlib_to_blis_side( *side, &blis_side ); \ @@ -165,6 +175,16 @@ void PASTEF77(ch,blasname) \ ldb, \ ldc \ ); \ +\ + /* Quick return if possible. */ \ + if ( *m == 0 || *n == 0 || ( PASTEMAC(ch,eq0)( *alpha ) \ + && PASTEMAC(ch,eq1)( *beta ) )) \ + { \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ + return; \ + } \ \ /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ bli_param_map_netlib_to_blis_side( *side, &blis_side ); \ diff --git a/frame/compat/bla_symm.c b/frame/compat/bla_symm.c index 078cbf743c..85aebb435f 100755 --- a/frame/compat/bla_symm.c +++ b/frame/compat/bla_symm.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019 - 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -83,6 +83,16 @@ void PASTEF77(ch,blasname) \ ldb, \ ldc \ ); \ +\ + /* Quick return if possible. */ \ + if ( *m == 0 || *n == 0 || ( PASTEMAC(ch,eq0)( *alpha ) \ + && PASTEMAC(ch,eq1)( *beta ) )) \ + { \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ + return; \ + } \ \ /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ bli_param_map_netlib_to_blis_side( *side, &blis_side ); \ @@ -163,6 +173,16 @@ void PASTEF77(ch,blasname) \ ldb, \ ldc \ ); \ +\ + /* Quick return if possible. 
*/ \ + if ( *m == 0 || *n == 0 || ( PASTEMAC(ch,eq0)( *alpha ) \ + && PASTEMAC(ch,eq1)( *beta ) )) \ + { \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ + return; \ + } \ \ /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ bli_param_map_netlib_to_blis_side( *side, &blis_side ); \ diff --git a/frame/compat/bla_syr2k.c b/frame/compat/bla_syr2k.c index b2280423a7..6a4f31b969 100644 --- a/frame/compat/bla_syr2k.c +++ b/frame/compat/bla_syr2k.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin. - Copyright (C) 2019 - 2021, Advanced Micro Devices, Inc.All Rights Reserved. + Copyright (C) 2019 - 2022, Advanced Micro Devices, Inc.All Rights Reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -83,6 +83,16 @@ void PASTEF77(ch,blasname) \ ldb, \ ldc \ ); \ +\ + /* Quick return if possible. */ \ + if ( *m == 0 || (( PASTEMAC(ch,eq0)( *alpha ) || *k == 0) \ + && PASTEMAC(ch,eq1)( *beta ) )) \ + { \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ + return; \ + } \ \ /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ bli_param_map_netlib_to_blis_uplo( *uploc, &blis_uploc ); \ @@ -172,6 +182,16 @@ void PASTEF77(ch,blasname) \ ldb, \ ldc \ ); \ +\ + /* Quick return if possible. */ \ + if ( *m == 0 || (( PASTEMAC(ch,eq0)( *alpha ) || *k == 0) \ + && PASTEMAC(ch,eq1)( *beta ) )) \ + { \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ + return; \ + } \ \ /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ bli_param_map_netlib_to_blis_uplo( *uploc, &blis_uploc ); \ diff --git a/frame/compat/bla_syrk.c b/frame/compat/bla_syrk.c index 547fceaa79..376b23aec9 100644 --- a/frame/compat/bla_syrk.c +++ b/frame/compat/bla_syrk.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin. - Copyright (C) 2019 - 2021, Advanced Micro Devices, Inc.All Rights Reserved. + Copyright (C) 2019 - 2022, Advanced Micro Devices, Inc.All Rights Reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -81,6 +81,16 @@ void PASTEF77(ch,blasname) \ lda, \ ldc \ ); \ +\ + /* Quick return if possible. */ \ + if ( *m == 0 || (( PASTEMAC(ch,eq0)( *alpha ) || *k == 0) \ + && PASTEMAC(ch,eq1)( *beta ) )) \ + { \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ + return; \ + } \ \ /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ bli_param_map_netlib_to_blis_uplo( *uploc, &blis_uploc ); \ @@ -164,6 +174,16 @@ void PASTEF77(ch,blasname) \ lda, \ ldc \ ); \ +\ + /* Quick return if possible. */ \ + if ( *m == 0 || (( PASTEMAC(ch,eq0)( *alpha ) || *k == 0) \ + && PASTEMAC(ch,eq1)( *beta ) )) \ + { \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ + return; \ + } \ \ /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ bli_param_map_netlib_to_blis_uplo( *uploc, &blis_uploc ); \ diff --git a/frame/compat/bla_trmm.c b/frame/compat/bla_trmm.c index ee87b96c04..c319b3ab51 100644 --- a/frame/compat/bla_trmm.c +++ b/frame/compat/bla_trmm.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin. 
- Copyright (C) 2019 - 2021, Advanced Micro Devices, Inc.All Rights Reserved. + Copyright (C) 2019 - 2022, Advanced Micro Devices, Inc.All Rights Reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -86,6 +86,15 @@ void PASTEF77(ch,blasname) \ lda, \ ldb \ ); \ +\ + /* Quick return if possible. */ \ + if ( *m == 0 || *n == 0 ) \ + { \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ + return; \ + } \ \ /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ bli_param_map_netlib_to_blis_side( *side, &blis_side ); \ @@ -168,6 +177,15 @@ void PASTEF77(ch,blasname) \ lda, \ ldb \ ); \ +\ + /* Quick return if possible. */ \ + if ( *m == 0 || *n == 0 ) \ + { \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ + return; \ + } \ \ /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ bli_param_map_netlib_to_blis_side( *side, &blis_side ); \ diff --git a/frame/compat/bla_trsm.c b/frame/compat/bla_trsm.c index fea7ba6f17..e99805d8dd 100644 --- a/frame/compat/bla_trsm.c +++ b/frame/compat/bla_trsm.c @@ -85,6 +85,15 @@ void PASTEF77(ch,blasname) \ lda, \ ldb \ ); \ +\ + /* Quick return if possible. */ \ + if ( *m == 0 || *n == 0 ) \ + { \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ + return; \ + } \ \ /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ bli_param_map_netlib_to_blis_side( *side, &blis_side ); \ @@ -169,6 +178,15 @@ void PASTEF77(ch,blasname) \ lda, \ ldb \ ); \ +\ + /* Quick return if possible. */ \ + if ( *m == 0 || *n == 0 ) \ + { \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ + return; \ + } \ \ /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ bli_param_map_netlib_to_blis_side( *side, &blis_side ); \ diff --git a/frame/compat/bla_trsm_amd.c b/frame/compat/bla_trsm_amd.c index f479b5eac0..e1c997717d 100644 --- a/frame/compat/bla_trsm_amd.c +++ b/frame/compat/bla_trsm_amd.c @@ -85,6 +85,15 @@ void PASTEF77(ch,blasname) \ lda, \ ldb \ ); \ +\ + /* Quick return if possible. */ \ + if ( *m == 0 || *n == 0 ) \ + { \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ + return; \ + } \ \ /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ bli_param_map_netlib_to_blis_side( *side, &blis_side ); \ @@ -169,6 +178,15 @@ void PASTEF77(ch,blasname) \ lda, \ ldb \ ); \ +\ + /* Quick return if possible. */ \ + if ( *m == 0 || *n == 0 ) \ + { \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ + return; \ + } \ \ /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ bli_param_map_netlib_to_blis_side( *side, &blis_side ); \ @@ -424,6 +442,15 @@ void strsm_ ldb ); + /* Quick return if possible. */ + if ( *m == 0 || *n == 0 ) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + /* Finalize BLIS. */ + bli_finalize_auto(); + return; + } + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ bli_param_map_netlib_to_blis_side( *side, &blis_side ); bli_param_map_netlib_to_blis_uplo( *uploa, &blis_uploa ); @@ -686,6 +713,15 @@ void dtrsm_ ldb ); + /* Quick return if possible. */ + if ( *m == 0 || *n == 0 ) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + /* Finalize BLIS. 
*/ + bli_finalize_auto(); + return; + } + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ bli_param_map_netlib_to_blis_side( *side, &blis_side ); bli_param_map_netlib_to_blis_uplo( *uploa, &blis_uploa ); @@ -982,6 +1018,15 @@ void ztrsm_ ldb ); + /* Quick return if possible. */ + if ( *m == 0 || *n == 0 ) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + /* Finalize BLIS. */ + bli_finalize_auto(); + return; + } + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ bli_param_map_netlib_to_blis_side( *side, &blis_side ); bli_param_map_netlib_to_blis_uplo( *uploa, &blis_uploa ); @@ -1308,6 +1353,15 @@ void ctrsm_ ldb ); + /* Quick return if possible. */ + if ( *m == 0 || *n == 0 ) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + /* Finalize BLIS. */ + bli_finalize_auto(); + return; + } + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ bli_param_map_netlib_to_blis_side( *side, &blis_side ); bli_param_map_netlib_to_blis_uplo( *uploa, &blis_uploa ); From 8adef27aca822f1867f3aeaa2bcb6cc0a11f27db Mon Sep 17 00:00:00 2001 From: Shubham Sharma Date: Tue, 16 Aug 2022 12:27:46 +0530 Subject: [PATCH 181/243] Optimization of DGEMMT SUP kernel for beta zero cases. Details: 1. In kernels for non-transpose variants, changes are made to optimize the cases of beta zero. 2. Validated the changes with BLIS Testsuite, GTestSuite(Functionality, Valgrind, Integer Tests) and Netlib Tests. 3. Fixed warnings during the build process. AMD-Internal: [CPUPL-2341] Change-Id: I8bb53ad619eb2413c999fe18eafd67c75fe1f83a --- frame/3/gemmt/bli_gemmt_sup_var1n2m.c | 288 ++----------- .../3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c | 395 +++++++++++++----- kernels/haswell/bli_kernels_haswell.h | 16 + 3 files changed, 331 insertions(+), 368 deletions(-) diff --git a/frame/3/gemmt/bli_gemmt_sup_var1n2m.c b/frame/3/gemmt/bli_gemmt_sup_var1n2m.c index 2af7e9f45f..1023821b86 100644 --- a/frame/3/gemmt/bli_gemmt_sup_var1n2m.c +++ b/frame/3/gemmt/bli_gemmt_sup_var1n2m.c @@ -66,249 +66,11 @@ typedef void (*gemmt_ker_ft) dim_t m0, dim_t n0, dim_t k0, - void* restrict alpha, - void* restrict a, inc_t rs_a0, inc_t cs_a0, - void* restrict b, inc_t rs_b0, inc_t cs_b0, - void* restrict beta, - void* restrict c, inc_t rs_c0, inc_t cs_c0, - auxinfo_t* restrict data, - cntx_t* restrict cntx - ); - -// Gemmt Upper variant kernel for m_offset=0 and n_offset=0 in 24x24 block -BLIS_INLINE void bli_dgemmsup_rv_haswell_asm_6x8m_0x0_U - ( - conj_t conja, - conj_t conjb, - dim_t m0, - dim_t n0, - dim_t k0, - void* restrict alpha, - void* restrict a, inc_t rs_a0, inc_t cs_a0, - void* restrict b, inc_t rs_b0, inc_t cs_b0, - void* restrict beta, - void* restrict c, inc_t rs_c0, inc_t cs_c0, - auxinfo_t* restrict data, - cntx_t* restrict cntx - ); - -// Gemmt Upper variant kernel for m_offset=6 and n_offset=8 in 24x24 block -BLIS_INLINE void bli_dgemmsup_rv_haswell_asm_6x8m_6x8_U - ( - conj_t conja, - conj_t conjb, - dim_t m0, - dim_t n0, - dim_t k0, - void* restrict alpha, - void* restrict a, inc_t rs_a0, inc_t cs_a0, - void* restrict b, inc_t rs_b0, inc_t cs_b0, - void* restrict beta, - void* restrict c, inc_t rs_c0, inc_t cs_c0, - auxinfo_t* restrict data, - cntx_t* restrict cntx - ); - -// Gemmt Upper variant kernel for m_offset=12 and n_offset=16 in 24x24 block -BLIS_INLINE void bli_dgemmsup_rv_haswell_asm_6x8m_12x16_U - ( - conj_t conja, - conj_t conjb, - dim_t m0, - dim_t n0, - dim_t k0, - void* restrict alpha, - void* restrict a, inc_t rs_a0, inc_t cs_a0, - void* restrict b, inc_t 
rs_b0, inc_t cs_b0, - void* restrict beta, - void* restrict c, inc_t rs_c0, inc_t cs_c0, - auxinfo_t* restrict data, - cntx_t* restrict cntx - ); - -// Gemmt Upper variant combined kernel for m_offset=12, n_offset=16 and m_offset=18, n_offset=16 in 24x24 block -BLIS_INLINE void bli_dgemmsup_rv_haswell_asm_6x8m_0x0_combined_U - ( - conj_t conja, - conj_t conjb, - dim_t m0, - dim_t n0, - dim_t k0, - void* restrict alpha, - void* restrict a, inc_t rs_a0, inc_t cs_a0, - void* restrict b, inc_t rs_b0, inc_t cs_b0, - void* restrict beta, - void* restrict c, inc_t rs_c0, inc_t cs_c0, - auxinfo_t* restrict data, - cntx_t* restrict cntx - ); - -// Gemmt Upper variant kernel for m_offset=6 and n_offset=0 in 24x24 block -BLIS_INLINE void bli_dgemmsup_rv_haswell_asm_6x8m_6x0_U - ( - conj_t conja, - conj_t conjb, - dim_t m0, - dim_t n0, - dim_t k0, - void* restrict alpha, - void* restrict a, inc_t rs_a0, inc_t cs_a0, - void* restrict b, inc_t rs_b0, inc_t cs_b0, - void* restrict beta, - void* restrict c, inc_t rs_c0, inc_t cs_c0, - auxinfo_t* restrict data, - cntx_t* restrict cntx - ); - -// Gemmt Upper variant kernel for m_offset=12 and n_offset=8 in 24x24 block -BLIS_INLINE void bli_dgemmsup_rv_haswell_asm_6x8m_12x8_U - ( - conj_t conja, - conj_t conjb, - dim_t m0, - dim_t n0, - dim_t k0, - void* restrict alpha, - void* restrict a, inc_t rs_a0, inc_t cs_a0, - void* restrict b, inc_t rs_b0, inc_t cs_b0, - void* restrict beta, - void* restrict c, inc_t rs_c0, inc_t cs_c0, - auxinfo_t* restrict data, - cntx_t* restrict cntx - ); - -// Gemmt Upper variant kernel for m_offset=18 and n_offset=16 in 24x24 block -BLIS_INLINE void bli_dgemmsup_rv_haswell_asm_6x8m_18x16_U - ( - conj_t conja, - conj_t conjb, - dim_t m0, - dim_t n0, - dim_t k0, - void* restrict alpha, - void* restrict a, inc_t rs_a0, inc_t cs_a0, - void* restrict b, inc_t rs_b0, inc_t cs_b0, - void* restrict beta, - void* restrict c, inc_t rs_c0, inc_t cs_c0, - auxinfo_t* restrict data, - cntx_t* restrict cntx - ); - -// Gemmt Lower variant kernel for m_offset=0 and n_offset=0 in 24x24 block -BLIS_INLINE void bli_dgemmsup_rv_haswell_asm_6x8m_0x0_L - ( - conj_t conja, - conj_t conjb, - dim_t m0, - dim_t n0, - dim_t k0, - void* restrict alpha, - void* restrict a, inc_t rs_a0, inc_t cs_a0, - void* restrict b, inc_t rs_b0, inc_t cs_b0, - void* restrict beta, - void* restrict c, inc_t rs_c0, inc_t cs_c0, - auxinfo_t* restrict data, - cntx_t* restrict cntx - ); - -// Gemmt Lower variant kernel for m_offset=6 and n_offset=8 in 24x24 block -BLIS_INLINE void bli_dgemmsup_rv_haswell_asm_6x8m_6x8_L - ( - conj_t conja, - conj_t conjb, - dim_t m0, - dim_t n0, - dim_t k0, - void* restrict alpha, - void* restrict a, inc_t rs_a0, inc_t cs_a0, - void* restrict b, inc_t rs_b0, inc_t cs_b0, - void* restrict beta, - void* restrict c, inc_t rs_c0, inc_t cs_c0, - auxinfo_t* restrict data, - cntx_t* restrict cntx - ); - -// Gemmt Lower variant kernel for m_offset=12 and n_offset=16 in 24x24 block -BLIS_INLINE void bli_dgemmsup_rv_haswell_asm_6x8m_12x16_L - ( - conj_t conja, - conj_t conjb, - dim_t m0, - dim_t n0, - dim_t k0, - void* restrict alpha, - void* restrict a, inc_t rs_a0, inc_t cs_a0, - void* restrict b, inc_t rs_b0, inc_t cs_b0, - void* restrict beta, - void* restrict c, inc_t rs_c0, inc_t cs_c0, - auxinfo_t* restrict data, - cntx_t* restrict cntx - ); - -// Gemmt Lower variant combined kernel for m_offset=0, n_offset=0 and m_offset=6, n_offset=0 in 24x24 block -BLIS_INLINE void bli_dgemmsup_rv_haswell_asm_6x8m_16x12_combined_L - ( - conj_t conja, - conj_t 
conjb, - dim_t m0, - dim_t n0, - dim_t k0, - void* restrict alpha, - void* restrict a, inc_t rs_a0, inc_t cs_a0, - void* restrict b, inc_t rs_b0, inc_t cs_b0, - void* restrict beta, - void* restrict c, inc_t rs_c0, inc_t cs_c0, - auxinfo_t* restrict data, - cntx_t* restrict cntx - ); - -// Gemmt Lower variant kernel for m_offset=6 and n_offset=0 in 24x24 block -BLIS_INLINE void bli_dgemmsup_rv_haswell_asm_6x8m_6x0_L - ( - conj_t conja, - conj_t conjb, - dim_t m0, - dim_t n0, - dim_t k0, - void* restrict alpha, - void* restrict a, inc_t rs_a0, inc_t cs_a0, - void* restrict b, inc_t rs_b0, inc_t cs_b0, - void* restrict beta, - void* restrict c, inc_t rs_c0, inc_t cs_c0, - auxinfo_t* restrict data, - cntx_t* restrict cntx - ); - -// Gemmt Lower variant kernel for m_offset=12 and n_offset=8 in 24x24 block -BLIS_INLINE void bli_dgemmsup_rv_haswell_asm_6x8m_12x8_L - ( - conj_t conja, - conj_t conjb, - dim_t m0, - dim_t n0, - dim_t k0, - void* restrict alpha, - void* restrict a, inc_t rs_a0, inc_t cs_a0, - void* restrict b, inc_t rs_b0, inc_t cs_b0, - void* restrict beta, - void* restrict c, inc_t rs_c0, inc_t cs_c0, - auxinfo_t* restrict data, - cntx_t* restrict cntx - ); - -// Gemmt Lower variant kernel for m_offset=18 and n_offset=16 in 24x24 block -BLIS_INLINE void bli_dgemmsup_rv_haswell_asm_6x8m_18x16_L - ( - conj_t conja, - conj_t conjb, - dim_t m0, - dim_t n0, - dim_t k0, - void* restrict alpha, - void* restrict a, inc_t rs_a0, inc_t cs_a0, - void* restrict b, inc_t rs_b0, inc_t cs_b0, - void* restrict beta, - void* restrict c, inc_t rs_c0, inc_t cs_c0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, auxinfo_t* restrict data, cntx_t* restrict cntx ); @@ -2177,11 +1939,11 @@ void PASTEMACT(ch,opname,uplo,varname) \ mr_cur, \ nr_cur, \ kc_cur, \ - alpha_cast, \ - a_ir, rs_a_use, cs_a_use, \ - b_jr, rs_b_use, cs_b_use, \ - beta_use, \ - c_ir, rs_c, cs_c, \ + (double*) alpha_cast, \ + (double*) a_ir, rs_a_use, cs_a_use, \ + (double*) b_jr, rs_b_use, cs_b_use, \ + (double*) beta_use, \ + (double*) c_ir, rs_c, cs_c, \ &aux, \ cntx \ ); \ @@ -2203,11 +1965,11 @@ void PASTEMACT(ch,opname,uplo,varname) \ mr_cur, \ nr_cur, \ kc_cur, \ - alpha_cast, \ - a_ir, rs_a_use, cs_a_use, \ - b_jr, rs_b_use, cs_b_use, \ - beta_use, \ - c_ir, rs_c, cs_c, \ + (double*) alpha_cast, \ + (double*) a_ir, rs_a_use, cs_a_use, \ + (double*) b_jr, rs_b_use, cs_b_use, \ + (double*) beta_use, \ + (double*) c_ir, rs_c, cs_c, \ &aux, \ cntx \ ); \ @@ -2849,11 +2611,11 @@ void PASTEMACT(ch,opname,uplo,varname) \ mr_cur, \ nr_cur, \ kc_cur, \ - alpha_cast, \ - a_ir, rs_a_use, cs_a_use, \ - b_jr, rs_b_use, cs_b_use, \ - beta_use, \ - c_ir, rs_c, cs_c, \ + (double*) alpha_cast, \ + (double*) a_ir, rs_a_use, cs_a_use, \ + (double*) b_jr, rs_b_use, cs_b_use, \ + (double*) beta_use, \ + (double*) c_ir, rs_c, cs_c, \ &aux, \ cntx \ ); \ @@ -2874,11 +2636,11 @@ void PASTEMACT(ch,opname,uplo,varname) \ mr_cur, \ nr_cur, \ kc_cur, \ - alpha_cast, \ - a_ir, rs_a_use, cs_a_use, \ - b_jr, rs_b_use, cs_b_use, \ - beta_use, \ - c_ir, rs_c, cs_c, \ + (double*) alpha_cast, \ + (double*) a_ir, rs_a_use, cs_a_use, \ + (double*) b_jr, rs_b_use, cs_b_use, \ + (double*) beta_use, \ + (double*) c_ir, rs_c, cs_c, \ &aux, \ cntx \ ); \ diff --git a/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c b/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c index 107917d078..eb734fe0d7 100644 --- 
a/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c +++ b/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c @@ -1127,11 +1127,11 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_0x0_L dim_t m0, dim_t n0, dim_t k0, - double* restrict alpha, - double* restrict a, inc_t rs_a0, inc_t cs_a0, - double* restrict b, inc_t rs_b0, inc_t cs_b0, - double* restrict beta, - double* restrict c, inc_t rs_c0, inc_t cs_c0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, auxinfo_t* restrict data, cntx_t* restrict cntx ) @@ -1541,7 +1541,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_0x0_L label(.DDONE) - + vzeroupper() end_asm( : // output operands (none) @@ -1604,11 +1604,11 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_6x8_L dim_t m0, dim_t n0, dim_t k0, - double* restrict alpha, - double* restrict a, inc_t rs_a0, inc_t cs_a0, - double* restrict b, inc_t rs_b0, inc_t cs_b0, - double* restrict beta, - double* restrict c, inc_t rs_c0, inc_t cs_c0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, auxinfo_t* restrict data, cntx_t* restrict cntx ) @@ -1861,7 +1861,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_6x8_L label(.DDONE) - + vzeroupper() end_asm( @@ -1925,11 +1925,11 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_12x16_L dim_t m0, dim_t n0, dim_t k0, - double* restrict alpha, - double* restrict a, inc_t rs_a0, inc_t cs_a0, - double* restrict b, inc_t rs_b0, inc_t cs_b0, - double* restrict beta, - double* restrict c, inc_t rs_c0, inc_t cs_c0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, auxinfo_t* restrict data, cntx_t* restrict cntx ) @@ -2032,19 +2032,21 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_12x16_L mov(var(cs_c), rsi) lea(mem(, rsi, 8), rsi) vxorpd(ymm0, ymm0, ymm0) + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) - cmp(imm(8), rdi) //rs_c == 0? + cmp(imm(8), rdi) //rs_c == 0? je(.DCOLSTOR) label(.DROWSTOR) - lea(mem(rcx, rdi, 4), rcx) //rcx += 4 * rdi + lea(mem(rcx, rdi, 4), rcx) //rcx += 4 * rdi vfmadd231pd(mem(rcx, 0*32), xmm3, xmm12) vmovlpd(xmm12, mem(rcx)) add(rdi, rcx) vfmadd231pd(mem(rcx, 0*32), xmm3, xmm14) vmovlpd(xmm14, mem(rcx)) vmovhpd(xmm14, mem(rcx, rsi, 1)) - jmp(.DRETURN) + jmp(.DDONE) label(.DCOLSTOR) @@ -2058,8 +2060,35 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_12x16_L vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm14) vmovupd(xmm12, mem(rdx )) vmovhpd(xmm14, mem(rdx, rsi, 1, 1*8)) + jmp(.DDONE) + + label(.DBETAZERO) + cmp(imm(8), rdi) //rs_c == 0? 
+ je(.DCOLSTORBZ) + + label(.DROWSTORBZ) + lea(mem(rcx, rdi, 4), rcx) //rcx += 4 * rdi + vmovlpd(xmm12, mem(rcx)) + add(rdi, rcx) + vmovlpd(xmm14, mem(rcx)) + vmovhpd(xmm14, mem(rcx, rsi, 1)) + jmp(.DDONE) + + label(.DCOLSTORBZ) + + lea(mem(rcx, rdi, 4), rdx) + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vinsertf128(imm(0x1), xmm2, ymm0, ymm12) + vinsertf128(imm(0x1), xmm3, ymm1, ymm14) + + vmovupd(xmm12, mem(rdx )) + vmovhpd(xmm14, mem(rdx, rsi, 1, 1*8)) + jmp(.DDONE) + + label(.DDONE) + vzeroupper() - label(.DRETURN) end_asm( : // output operands (none) : // input operands @@ -2125,11 +2154,11 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_16x12_combined_L dim_t m0, dim_t n0, dim_t k0, - double* restrict alpha, - double* restrict a, inc_t rs_a0, inc_t cs_a0, - double* restrict b, inc_t rs_b0, inc_t cs_b0, - double* restrict beta, - double* restrict c, inc_t rs_c0, inc_t cs_c0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, auxinfo_t* restrict data, cntx_t* restrict cntx ) @@ -2143,7 +2172,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_16x12_combined_L uint64_t rs_c = rs_c0; uint64_t cs_c = cs_c0; uint64_t ps_a8 = bli_auxinfo_ps_a( data ) * sizeof( double ); - double* a_next = a + rs_a * 6; + double* a_next = ( (double*)a ) + rs_a * 6; begin_asm() mov(var(a), r14) mov(var(b), rbx) @@ -2428,6 +2457,9 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_16x12_combined_L lea(mem(, rsi, 8), rsi) vxorpd(ymm0, ymm0, ymm0) lea(mem(rsi, rsi, 2), rax) // rax = 3*cs_c; + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO) + cmp(imm(8), rdi) //rs_c == 8? je(.DCOLSTOR) @@ -2484,7 +2516,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_16x12_combined_L vfmadd231pd(mem(rcx, 1*32), ymm3, ymm15) vmovupd(ymm15, mem(rcx, 1*32)) - jmp(.DRETURN) + jmp(.DDONE) label(.DCOLSTOR) vbroadcastsd(mem(rbx), ymm3) @@ -2575,10 +2607,131 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_16x12_combined_L vmovupd(xmm1, mem(rdx, rsi, 1)) vmovupd(xmm2, mem(rdx, rsi, 2)) vmovhpd(xmm4, mem(rdx, rax, 1, 1*8)) + jmp(.DDONE) + + label(.DBETAZERO) + cmp(imm(8), rdi) + je(.DCOLSTORBZ) + + label(.DROWSTORBZ) + lea(mem(rcx, rdi, 4), rcx) //rcx += 4 * rdi + vmovlpd(xmm5, mem(rcx)) + add(rdi, rcx) + vmovlpd(xmm7, mem(rcx)) + vmovhpd(xmm7, mem(rcx, rsi, 1)) + + //For lower 6x8 block + lea(mem(rcx, rdi, 1), rcx) //rcx += 1 * rdi + vmovupd(xmm4, mem(rcx, 0*32)) + vextractf128(imm(1), ymm4, xmm1) + vmovlpd(xmm1, mem(rcx, 2*8)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx, 0*32)) + add(rdi, rcx) + + vmovupd(ymm8, mem(rcx, 0*32)) + + vmovlpd(xmm9, mem(rcx, 1*32)) + add(rdi, rcx) + + vmovupd(ymm10, mem(rcx, 0*32)) + + vmovupd(xmm11, mem(rcx, 1*32)) + add(rdi, rcx) + + + vmovupd(ymm12, mem(rcx, 0*32)) + + vmovupd(xmm13, mem(rcx, 1*32)) + vextractf128(imm(1), ymm13, xmm1) + vmovlpd(xmm1, mem(rcx, 1*32+2*8)) + add(rdi, rcx) + + vmovupd(ymm14, mem(rcx, 0*32)) + + vmovupd(ymm15, mem(rcx, 1*32)) + + jmp(.DDONE) + + label(.DCOLSTORBZ) + vbroadcastsd(mem(rbx), ymm3) + + lea(mem(rcx, rdi, 4), rdx) //rdx = rcx + 4* rs_c + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + + vmovupd(xmm5, mem(rdx )) + vmovhpd(xmm7, mem(rdx, rsi, 1, 1*8)) + + lea(mem(rcx, rdi, 4), rcx) + lea(mem(rcx, rdi, 2), rcx) + lea(mem(rcx, rdi, 4), rdx) + + vunpcklpd(ymm6, ymm4, ymm0) + vunpckhpd(ymm6, ymm4, ymm1) + vunpcklpd(ymm10, ymm8, ymm2) + 
vunpckhpd(ymm10, ymm8, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm4) + vinsertf128(imm(0x1), xmm3, ymm1, ymm6) + vperm2f128(imm(0x31), ymm2, ymm0, ymm8) + vperm2f128(imm(0x31), ymm3, ymm1, ymm10) + + vbroadcastsd(mem(rbx), ymm3) + + vmovupd(ymm4, mem(rcx )) + vmovupd(ymm6, mem(rcx, rsi, 1)) + vmovupd(ymm8, mem(rcx, rsi, 2)) + vextractf128(imm(1), ymm10, xmm1) + vmovhpd(xmm10, mem(rcx, rax, 1, 1*8)) + vmovupd(xmm1, mem(rcx, rax, 1, 2*8)) + + + lea(mem(rcx, rsi, 4), rcx) + + vunpcklpd(ymm14, ymm12, ymm0) + vunpckhpd(ymm14, ymm12, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovupd(xmm4, mem(rdx, rax, 1)) + + lea(mem(rdx, rsi, 4), rdx) + + // begin I/O on columns 4-7 + vunpcklpd(ymm7, ymm5, ymm0) + vunpckhpd(ymm7, ymm5, ymm1) + vunpcklpd(ymm11, ymm9, ymm2) + vunpckhpd(ymm11, ymm9, ymm3) + vinsertf128(imm(0x1), xmm2, ymm0, ymm5) + vinsertf128(imm(0x1), xmm3, ymm1, ymm7) + + vbroadcastsd(mem(rbx), ymm3) + + vextractf128(imm(1), ymm5, xmm1) + vmovupd(xmm1, mem(rcx, 2*8 )) + vextractf128(imm(1), ymm7, xmm1) + vmovhpd(xmm1, mem(rcx, rsi, 1, 3*8)) + vunpcklpd(ymm15, ymm13, ymm0) + vunpckhpd(ymm15, ymm13, ymm1) + vextractf128(imm(0x1), ymm0, xmm2) + vextractf128(imm(0x1), ymm1, xmm4) + vmovupd(xmm0, mem(rdx )) + vmovupd(xmm1, mem(rdx, rsi, 1)) + vmovupd(xmm2, mem(rdx, rsi, 2)) + vmovhpd(xmm4, mem(rdx, rax, 1, 1*8)) + jmp(.DDONE) + + label(.DDONE) + vzeroupper() - label(.DRETURN) end_asm( : // output operands (none) : // input operands @@ -2639,11 +2792,11 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_6x0_L dim_t m0, dim_t n0, dim_t k0, - double* restrict alpha, - double* restrict a, inc_t rs_a0, inc_t cs_a0, - double* restrict b, inc_t rs_b0, inc_t cs_b0, - double* restrict beta, - double* restrict c, inc_t rs_c0, inc_t cs_c0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, auxinfo_t* restrict data, cntx_t* restrict cntx ) @@ -3097,7 +3250,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_6x0_L vmovupd(xmm4, mem(rdx, rax, 1)) label(.DDONE) - + vzeroupper() end_asm( : // output operands (none) @@ -3162,11 +3315,11 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_12x8_L dim_t m0, dim_t n0, dim_t k0, - double* restrict alpha, - double* restrict a, inc_t rs_a0, inc_t cs_a0, - double* restrict b, inc_t rs_b0, inc_t cs_b0, - double* restrict beta, - double* restrict c, inc_t rs_c0, inc_t cs_c0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, auxinfo_t* restrict data, cntx_t* restrict cntx ) @@ -3624,6 +3777,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_12x8_L vmovupd(xmm4, mem(rdx, rax, 1)) label(.DDONE) + vzeroupper() end_asm( : // output operands (none) @@ -3688,11 +3842,11 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_18x16_L dim_t m0, dim_t n0, dim_t k0, - double* restrict alpha, - double* restrict a, inc_t rs_a0, inc_t cs_a0, - double* restrict b, inc_t rs_b0, inc_t cs_b0, - double* restrict beta, - double* restrict c, inc_t rs_c0, inc_t cs_c0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, auxinfo_t* restrict data, cntx_t* restrict cntx ) @@ -4140,6 +4294,7 
@@ void bli_dgemmsup_rv_haswell_asm_6x8m_18x16_L label(.DDONE) label(.DRETURN) + vzeroupper() end_asm( : // output operands (none) @@ -4203,11 +4358,11 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_0x0_U dim_t m0, dim_t n0, dim_t k0, - double* restrict alpha, - double* restrict a, inc_t rs_a0, inc_t cs_a0, - double* restrict b, inc_t rs_b0, inc_t cs_b0, - double* restrict beta, - double* restrict c, inc_t rs_c0, inc_t cs_c0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, auxinfo_t* restrict data, cntx_t* restrict cntx ) @@ -4572,6 +4727,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_0x0_U label(.DDONE) + vzeroupper() end_asm( : // output operands (none) @@ -4632,11 +4788,11 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_6x8_U dim_t m0, dim_t n0, dim_t k0, - double* restrict alpha, - double* restrict a, inc_t rs_a0, inc_t cs_a0, - double* restrict b, inc_t rs_b0, inc_t cs_b0, - double* restrict beta, - double* restrict c, inc_t rs_c0, inc_t cs_c0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, auxinfo_t* restrict data, cntx_t* restrict cntx ) @@ -4993,10 +5149,6 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_6x8_U vextractf128(imm(0x1), ymm0, xmm2) vextractf128(imm(0x1), ymm1, xmm4) - vfmadd231pd(mem(rdx ), xmm3, xmm0) - vfmadd231pd(mem(rdx, rsi, 1), xmm3, xmm1) - vfmadd231pd(mem(rdx, rsi, 2), xmm3, xmm2) - vfmadd231pd(mem(rdx, rax, 1), xmm3, xmm4) vmovlpd(xmm2, mem(rdx, rsi, 2)) vmovupd(xmm4, mem(rdx, rax, 1)) @@ -5029,6 +5181,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_6x8_U label(.DDONE) + vzeroupper() end_asm( : // output operands (none) @@ -5090,11 +5243,11 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_12x16_U dim_t m0, dim_t n0, dim_t k0, - double* restrict alpha, - double* restrict a, inc_t rs_a0, inc_t cs_a0, - double* restrict b, inc_t rs_b0, inc_t cs_b0, - double* restrict beta, - double* restrict c, inc_t rs_c0, inc_t cs_c0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, auxinfo_t* restrict data, cntx_t* restrict cntx ) @@ -5480,6 +5633,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_12x16_U label(.DDONE) + vzeroupper() end_asm( : // output operands (none) @@ -5541,11 +5695,11 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_6x0_U dim_t m0, dim_t n0, dim_t k0, - double* restrict alpha, - double* restrict a, inc_t rs_a0, inc_t cs_a0, - double* restrict b, inc_t rs_b0, inc_t cs_b0, - double* restrict beta, - double* restrict c, inc_t rs_c0, inc_t cs_c0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, auxinfo_t* restrict data, cntx_t* restrict cntx ) @@ -5656,36 +5810,67 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_6x0_U mov(var(cs_c), rsi) lea(mem(, rsi, 8), rsi) vxorpd(ymm0, ymm0, ymm0) + vucomisd(xmm0, xmm3) + je(.DBETAZERO) cmp(imm(8), rdi) je(.DCOLSTOR) label(.DROWSTOR) - lea(mem(rcx, 1*32), rcx) - lea(mem(rcx, 1*16), rcx) + lea(mem(rcx, 1*32), rcx) + lea(mem(rcx, 1*16), rcx) vfmadd231pd(mem(rcx, 0*32), xmm3, xmm5) - vmovlpd(xmm5, mem(rcx)) + vmovlpd(xmm5, mem(rcx)) vmovhpd(xmm5, mem(rcx, rsi, 1)) add(rdi, rcx) vfmadd231pd(mem(rcx, 0*32), 
xmm3, xmm7) vmovhpd(xmm7, mem(rcx, rsi, 1)) - jmp(.DRETURN) + jmp(.DDONE) label(.DCOLSTOR) - vbroadcastsd(mem(rbx), ymm3) + vbroadcastsd(mem(rbx), ymm3) lea(mem(rcx, rsi, 4), rcx) lea(mem(rcx, rsi, 2), rcx) vunpcklpd(xmm7, xmm5, xmm0) vunpckhpd(xmm7, xmm5, xmm1) - vfmadd231pd(mem(rcx ), xmm3, xmm0) - vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm1) - vmovlpd(xmm0, mem(rcx )) - vmovupd(xmm1, mem(rcx, rsi, 1)) + vfmadd231pd(mem(rcx ), xmm3, xmm0) + vfmadd231pd(mem(rcx, rsi, 1), xmm3, xmm1) + vmovlpd(xmm0, mem(rcx )) + vmovupd(xmm1, mem(rcx, rsi, 1)) + jmp(.DDONE) + + label(.DBETAZERO) + cmp(imm(8), rdi) + je(.DCOLSTORBZ) + + label(.DROWSTORBZ) + lea(mem(rcx, 1*32), rcx) + lea(mem(rcx, 1*16), rcx) + + vmovlpd(xmm5, mem(rcx)) + vmovhpd(xmm5, mem(rcx, rsi, 1)) + add(rdi, rcx) + vmovhpd(xmm7, mem(rcx, rsi, 1)) + + jmp(.DDONE) + + label(.DCOLSTORBZ) + + lea(mem(rcx, rsi, 4), rcx) + lea(mem(rcx, rsi, 2), rcx) + vunpcklpd(xmm7, xmm5, xmm0) + vunpckhpd(xmm7, xmm5, xmm1) + vmovlpd(xmm0, mem(rcx )) + vmovupd(xmm1, mem(rcx, rsi, 1)) + jmp(.DDONE) + + + label(.DDONE) + vzeroupper() - label(.DRETURN) end_asm( : // output operands (none) : // input operands @@ -5745,11 +5930,11 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_12x8_U dim_t m0, dim_t n0, dim_t k0, - double* restrict alpha, - double* restrict a, inc_t rs_a0, inc_t cs_a0, - double* restrict b, inc_t rs_b0, inc_t cs_b0, - double* restrict beta, - double* restrict c, inc_t rs_c0, inc_t cs_c0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, auxinfo_t* restrict data, cntx_t* restrict cntx ) @@ -6008,10 +6193,11 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_12x8_U vmovupd(xmm7, mem(rcx, rsi, 1)) vmovupd(xmm9, mem(rcx, rsi, 2)) vextractf128(imm(0x1), ymm9, xmm9) - vmovupd(ymm9, mem(rcx, rsi, 2, 1*16)) - + vmovlpd(xmm9, mem(rcx, rsi, 2, 2*8)) + vmovupd(ymm11, mem(rcx, rax, 1)) label(.DDONE) + vzeroupper() end_asm( : // output operands (none) @@ -6073,11 +6259,11 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_18x16_U dim_t m0, dim_t n0, dim_t k0, - double* restrict alpha, - double* restrict a, inc_t rs_a0, inc_t cs_a0, - double* restrict b, inc_t rs_b0, inc_t cs_b0, - double* restrict beta, - double* restrict c, inc_t rs_c0, inc_t cs_c0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, auxinfo_t* restrict data, cntx_t* restrict cntx ) @@ -6394,7 +6580,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_18x16_U vmovupd(xmm5, mem(rcx )) vextractf128(imm(0x1), ymm5, xmm5) - vmovlpd(xmm5, mem(rcx )) + vmovlpd(xmm5, mem(rcx, 2*8 )) vmovupd(ymm7, mem(rcx, rsi, 1)) vmovupd(ymm9, mem(rcx, rsi, 2)) vmovupd(ymm11, mem(rcx, rax, 1)) @@ -6409,6 +6595,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_18x16_U label(.DDONE) + vzeroupper() end_asm( : // output operands (none) @@ -6476,11 +6663,11 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_0x0_combined_U dim_t m0, dim_t n0, dim_t k0, - double* restrict alpha, - double* restrict a, inc_t rs_a0, inc_t cs_a0, - double* restrict b, inc_t rs_b0, inc_t cs_b0, - double* restrict beta, - double* restrict c, inc_t rs_c0, inc_t cs_c0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, auxinfo_t* restrict data, cntx_t* restrict cntx ) @@ 
-7012,9 +7199,9 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_0x0_combined_U vunpckhpd(ymm14, ymm12, ymm1) vextractf128(imm(0x1), ymm0, xmm2) vextractf128(imm(0x1), ymm1, xmm4) - lea(mem(rcx, rsi, 4), rbp) + lea(mem(rcx, rdi, 4), rbp) + lea(mem(rbp, rdi, 2), rbp) lea(mem(rbp, rsi, 2), rbp) - lea(mem(rbp, 1*32+1*16), rbp) vmovlpd(xmm2, mem(rbp)) vmovupd(xmm4, mem(rbp, rsi, 1)) @@ -7047,6 +7234,7 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_0x0_combined_U label(.DDONE) + vzeroupper() end_asm( : // output operands (none) @@ -7741,8 +7929,7 @@ void bli_dgemmsup_rv_haswell_asm_6x6m label(.DRETURN) - - + vzeroupper() end_asm( : // output operands (none) @@ -8400,8 +8587,7 @@ void bli_dgemmsup_rv_haswell_asm_6x4m label(.DRETURN) - - + vzeroupper() end_asm( : // output operands (none) @@ -9034,8 +9220,7 @@ void bli_dgemmsup_rv_haswell_asm_6x2m label(.DRETURN) - - + vzeroupper() end_asm( : // output operands (none) diff --git a/kernels/haswell/bli_kernels_haswell.h b/kernels/haswell/bli_kernels_haswell.h index 1c35122a4e..5b4c8a05bc 100644 --- a/kernels/haswell/bli_kernels_haswell.h +++ b/kernels/haswell/bli_kernels_haswell.h @@ -278,6 +278,22 @@ GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_1x1 ) // gemmsup_rd (mkernel in m dim) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m_0x0_U ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m_6x0_U ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m_6x8_U ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m_12x8_U ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m_12x16_U ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m_18x16_U ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m_0x0_combined_U ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m_0x0_L ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m_6x0_L ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m_6x8_L ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m_12x8_L ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m_12x16_L ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m_18x16_L ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m_16x12_combined_L ) +GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m_0x0_combined_U ) + GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_6x8m ) GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_6x4m ) GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_6x2m ) From 32c9239c7fdf7c1b6a3f7bd01dce687c2d9fc47a Mon Sep 17 00:00:00 2001 From: Shubham Sharma Date: Thu, 18 Aug 2022 08:22:45 -0400 Subject: [PATCH 182/243] Optimization of DGEMMT SUP kernels Details: 1. Optimized the kernels by replacing the macros with the actual computation of required output elements. 
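   To illustrate the idea described above (this sketch is not part of the patch): a
   minimal C example of replacing a one-size-fits-all tile-update macro with code
   that computes only the output elements a lower-triangular (L) gemmt block
   actually needs. The names TILE_UPDATE_6x8 and tile_update_6x8_lower are
   hypothetical; the real kernels below perform the same pruning on ymm/xmm
   registers with hand-written vbroadcastsd/vfmadd231pd instructions.

   #include <stddef.h>

   /* Generic macro: accumulates a full 6x8 tile, C(i,j) += a[i] * b[j]. */
   #define TILE_UPDATE_6x8( c, ldc, a, b ) \
       for ( size_t i = 0; i < 6; ++i ) \
           for ( size_t j = 0; j < 8; ++j ) \
               (c)[ i*(ldc) + j ] += (a)[i] * (b)[j];

   /* Specialized update for a diagonal block of the lower (L) variant:
      only elements with j <= i are referenced by the result, so the FMAs
      on the strictly upper part are dropped instead of computed and then
      discarded. */
   static void tile_update_6x8_lower( double* c, size_t ldc,
                                      const double* a, const double* b )
   {
       for ( size_t i = 0; i < 6; ++i )
           for ( size_t j = 0; j <= i; ++j )
               c[ i*ldc + j ] += a[i] * b[j];
   }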
AMD-Internal: [CPUPL-2341] Change-Id: Ieefb80ac9b2dc2955b683710e259cf45d581e1b5 --- .../3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c | 1279 ++++++++++++++--- 1 file changed, 1051 insertions(+), 228 deletions(-) diff --git a/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c b/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c index eb734fe0d7..8ac3612bdf 100644 --- a/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c +++ b/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_d6x8m.c @@ -1048,56 +1048,6 @@ void bli_dgemmsup_rv_haswell_asm_6x8m prefetch(0, mem(rdx, rsi, 1, 5*8)) \ prefetch(0, mem(rdx, rsi, 2, 5*8)) \ -#define SUBITER4x4(a, b, r1, r2, r3, r4) \ -\ - vmovupd(mem(b, 0*32), ymm0) \ - \ - vbroadcastsd(mem(a ), ymm2) \ - vbroadcastsd(mem(a, r8, 1), ymm3) \ - vfmadd231pd(ymm0, ymm2, r1) \ - vfmadd231pd(ymm0, ymm3, r2) \ - \ - vbroadcastsd(mem(a, r8, 2), ymm2) \ - vbroadcastsd(mem(a, r13, 1), ymm3) \ - vfmadd231pd(ymm0, ymm2, r3) \ - vfmadd231pd(ymm0, ymm3, r4) \ - -#define SUBITER2x4(a, b, r1, r2) \ -\ - vmovupd(mem(b, 0*32), ymm0) \ - \ - vbroadcastsd(mem(a ), ymm2) \ - vbroadcastsd(mem(a, r8, 1), ymm3) \ - vfmadd231pd(ymm0, ymm2, r1) \ - vfmadd231pd(ymm0, ymm3, r2) \ - -#define SUBITER2x2(a, b, r1, r2) \ -\ - vmovupd(mem(b, 0*32), xmm0) \ - \ - vbroadcastsd(mem(a ), ymm2) \ - vbroadcastsd(mem(a, r8, 1), ymm3) \ - vfmadd231pd(xmm0, xmm2, r1) \ - vfmadd231pd(xmm0, xmm3, r2) \ - -#define SUBITER6x4(a, b, r1, r2, r3, r4, r5, r6) \ -\ - vmovupd(mem(b, 0*32), ymm0) \ - \ - vbroadcastsd(mem(a ), ymm2) \ - vbroadcastsd(mem(a, r8, 1), ymm3) \ - vfmadd231pd(ymm0, ymm2, r1) \ - vfmadd231pd(ymm0, ymm3, r2) \ - \ - vbroadcastsd(mem(a, r8, 2), ymm2) \ - vbroadcastsd(mem(a, r13, 1), ymm3) \ - vfmadd231pd(ymm0, ymm2, r3) \ - vfmadd231pd(ymm0, ymm3, r4) \ - \ - vbroadcastsd(mem(a, r8, 4), ymm2) \ - vbroadcastsd(mem(a, r15, 1), ymm3) \ - vfmadd231pd(ymm0, ymm2, r5) \ - vfmadd231pd(ymm0, ymm3, r6) \ /* Following kernel computes the 6x8 block for the Lower vairant(L) of gemmt where @@ -1224,36 +1174,81 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_0x0_L prefetch(0, mem(rdx, 5*8)) - SUBITER6x4(rax, rbx, ymm4, ymm6, ymm8, ymm10, ymm12, ymm14) - lea(mem(rax, r8, 4), rbp) - lea(mem(rbx, 1*32), rcx) - SUBITER2x2(rbp, rcx, xmm13, xmm15) + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(xmm1, xmm2, xmm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(xmm1, xmm3, xmm15) // ---------------------------------- iteration 1 prefetch(0, mem(rdx, r9, 1, 5*8)) - SUBITER6x4(rax, rbx, ymm4, ymm6, ymm8, ymm10, ymm12, ymm14) - lea(mem(rax, r8, 4), rbp) - lea(mem(rbx, 1*32), rcx) - SUBITER2x2(rbp, rcx, xmm13, xmm15) + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 
1), ymm3) add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(xmm1, xmm2, xmm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(xmm1, xmm3, xmm15) // ---------------------------------- iteration 2 prefetch(0, mem(rdx, r9, 2, 5*8)) - SUBITER6x4(rax, rbx, ymm4, ymm6, ymm8, ymm10, ymm12, ymm14) - lea(mem(rax, r8, 4), rbp) - lea(mem(rbx, 1*32), rcx) - SUBITER2x2(rbp, rcx, xmm13, xmm15) + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(xmm1, xmm2, xmm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(xmm1, xmm3, xmm15) // ---------------------------------- iteration 3 @@ -1261,12 +1256,27 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_0x0_L prefetch(0, mem(rdx, rcx, 1, 5*8)) lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; - SUBITER6x4(rax, rbx, ymm4, ymm6, ymm8, ymm10, ymm12, ymm14) - lea(mem(rax, r8, 4), rbp) - lea(mem(rbx, 1*32), rcx) - SUBITER2x2(rbp, rcx, xmm13, xmm15) + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(xmm1, xmm2, xmm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(xmm1, xmm3, xmm15) @@ -1287,13 +1297,27 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_0x0_L prefetch(0, mem(rdx, 5*8)) add(r9, rdx) - SUBITER6x4(rax, rbx, ymm4, ymm6, ymm8, ymm10, ymm12, ymm14) - lea(mem(rax, r8, 4), rbp) - lea(mem(rbx, 1*32), rcx) - SUBITER2x2(rbp, rcx, xmm13, xmm15) + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) add(r10, rbx) // b += rs_b; - add(r9, rax) // a += cs_a; + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm0, xmm2, xmm4) + vfmadd231pd(xmm0, xmm3, xmm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm0, ymm3, ymm10) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(xmm1, xmm2, xmm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(xmm1, xmm3, xmm15) dec(rsi) // i -= 1; jne(.DLOOPKLEFT) // iterate again if i != 0. 
@@ -1672,7 +1696,6 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_6x8_L prefetch(0, mem(rdx, 5*8)) label(.DPOSTPFETCH) - lea(mem(rax, r8, 2), rax) mov(var(k_iter), rsi) test(rsi, rsi) je(.DCONSIDKLEFT) @@ -1680,19 +1703,51 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_6x8_L // computer xmm8, xmm10, ymm12, ymm14 only label(.DLOOPKITER) //0 - SUBITER4x4(rax, rbx, ymm8, ymm10, ymm12, ymm14) + vmovupd(mem(rbx, 0*32), ymm0) + vbroadcastsd(mem(rax, r8, 2), ymm1) + vbroadcastsd(mem(rax, r13, 1), ymm2) + vbroadcastsd(mem(rax, r8, 4), ymm3) + vbroadcastsd(mem(rax, r15, 1), ymm4) + vfmadd231pd(xmm0, xmm1, xmm8) + vfmadd231pd(xmm0, xmm2, xmm10) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm0, ymm4, ymm14) add(r10, rbx) // b += rs_b; add(r9, rax) // a += cs_a; //1 - SUBITER4x4(rax, rbx, ymm8, ymm10, ymm12, ymm14) + vmovupd(mem(rbx, 0*32), ymm0) + vbroadcastsd(mem(rax, r8, 2), ymm1) + vbroadcastsd(mem(rax, r13, 1), ymm2) + vbroadcastsd(mem(rax, r8, 4), ymm3) + vbroadcastsd(mem(rax, r15, 1), ymm4) + vfmadd231pd(xmm0, xmm1, xmm8) + vfmadd231pd(xmm0, xmm2, xmm10) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm0, ymm4, ymm14) add(r10, rbx) // b += rs_b; add(r9, rax) // a += cs_a; //2 - SUBITER4x4(rax, rbx, ymm8, ymm10, ymm12, ymm14) + vmovupd(mem(rbx, 0*32), ymm0) + vbroadcastsd(mem(rax, r8, 2), ymm1) + vbroadcastsd(mem(rax, r13, 1), ymm2) + vbroadcastsd(mem(rax, r8, 4), ymm3) + vbroadcastsd(mem(rax, r15, 1), ymm4) + vfmadd231pd(xmm0, xmm1, xmm8) + vfmadd231pd(xmm0, xmm2, xmm10) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm0, ymm4, ymm14) add(r10, rbx) // b += rs_b; add(r9, rax) // a += cs_a; //3 - SUBITER4x4(rax, rbx, ymm8, ymm10, ymm12, ymm14) + vmovupd(mem(rbx, 0*32), ymm0) + vbroadcastsd(mem(rax, r8, 2), ymm1) + vbroadcastsd(mem(rax, r13, 1), ymm2) + vbroadcastsd(mem(rax, r8, 4), ymm3) + vbroadcastsd(mem(rax, r15, 1), ymm4) + vfmadd231pd(xmm0, xmm1, xmm8) + vfmadd231pd(xmm0, xmm2, xmm10) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm0, ymm4, ymm14) add(r10, rbx) // b += rs_b; add(r9, rax) // a += cs_a; @@ -1706,7 +1761,15 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_6x8_L je(.DPOSTACCUM) label(.DLOOPKLEFT) - SUBITER4x4(rax, rbx, ymm8, ymm10, ymm12, ymm14) + vmovupd(mem(rbx, 0*32), ymm0) + vbroadcastsd(mem(rax, r8, 2), ymm1) + vbroadcastsd(mem(rax, r13, 1), ymm2) + vbroadcastsd(mem(rax, r8, 4), ymm3) + vbroadcastsd(mem(rax, r15, 1), ymm4) + vfmadd231pd(xmm0, xmm1, xmm8) + vfmadd231pd(xmm0, xmm2, xmm10) + vfmadd231pd(ymm0, ymm3, ymm12) + vfmadd231pd(ymm0, ymm4, ymm14) add(r10, rbx) // b += rs_b; add(r9, rax) // a += cs_a; @@ -1989,19 +2052,35 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_12x16_L //compute xmm12 and xmm 14 label(.DMAIN) //0 - SUBITER2x2(rax, rbx, xmm12, xmm14) + vmovupd(mem(rbx, 0*32), xmm0) + vbroadcastsd(mem(rax, r8, 4), ymm3) + vbroadcastsd(mem(rax, r15, 1), ymm4) + vfmadd231pd(xmm0, xmm3, xmm12) + vfmadd231pd(xmm0, xmm4, xmm14) add(r10, rbx) add(r9, rax) //1 - SUBITER2x2(rax, rbx, xmm12, xmm14) + vmovupd(mem(rbx, 0*32), xmm0) + vbroadcastsd(mem(rax, r8, 4), ymm3) + vbroadcastsd(mem(rax, r15, 1), ymm4) + vfmadd231pd(xmm0, xmm3, xmm12) + vfmadd231pd(xmm0, xmm4, xmm14) add(r10, rbx) add(r9, rax) //2 - SUBITER2x2(rax, rbx, xmm12, xmm14) + vmovupd(mem(rbx, 0*32), xmm0) + vbroadcastsd(mem(rax, r8, 4), ymm3) + vbroadcastsd(mem(rax, r15, 1), ymm4) + vfmadd231pd(xmm0, xmm3, xmm12) + vfmadd231pd(xmm0, xmm4, xmm14) add(r10, rbx) add(r9, rax) //3 - SUBITER2x2(rax, rbx, xmm12, xmm14) + vmovupd(mem(rbx, 0*32), xmm0) + vbroadcastsd(mem(rax, r8, 4), ymm3) + vbroadcastsd(mem(rax, r15, 1), ymm4) + 
vfmadd231pd(xmm0, xmm3, xmm12) + vfmadd231pd(xmm0, xmm4, xmm14) add(r10, rbx) add(r9, rax) @@ -2014,7 +2093,11 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_12x16_L je(.DPOSTACC) label(.DLEFT) - SUBITER2x2(rax, rbx, xmm12, xmm14) + vmovupd(mem(rbx, 0*32), xmm0) + vbroadcastsd(mem(rax, r8, 4), ymm3) + vbroadcastsd(mem(rax, r15, 1), ymm4) + vfmadd231pd(xmm0, xmm3, xmm12) + vfmadd231pd(xmm0, xmm4, xmm14) add(r10, rbx) add(r9, rax) dec(rsi) @@ -2886,37 +2969,117 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_6x0_L // ---------------------------------- iteration 0 prefetch(0, mem(rdx, 5*8)) - SUBITER6x4(rax, rbx, ymm4, ymm6, ymm8, ymm10, ymm12, ymm14) - lea(mem(rbx, 1*32), rbp) - SUBITER6x4(rax, rbp, ymm5, ymm7, ymm9, ymm11, ymm13, ymm15) + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) // ---------------------------------- iteration 1 prefetch(0, mem(rdx, r9, 1, 5*8)) - SUBITER6x4(rax, rbx, ymm4, ymm6, ymm8, ymm10, ymm12, ymm14) - lea(mem(rbx, 1*32), rbp) - SUBITER6x4(rax, rbp, ymm5, ymm7, ymm9, ymm11, ymm13, ymm15) + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) // ---------------------------------- iteration 2 prefetch(0, mem(rdx, r9, 2, 5*8)) - SUBITER6x4(rax, rbx, ymm4, ymm6, ymm8, ymm10, ymm12, ymm14) - lea(mem(rbx, 1*32), rbp) - SUBITER6x4(rax, rbp, ymm5, ymm7, ymm9, ymm11, ymm13, ymm15) + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) // ---------------------------------- iteration 3 prefetch(0, mem(rdx, rcx, 1, 5*8)) lea(mem(rdx, r9, 4), rdx) // a_prefetch += 
4*cs_a; - SUBITER6x4(rax, rbx, ymm4, ymm6, ymm8, ymm10, ymm12, ymm14) - lea(mem(rbx, 1*32), rbp) - SUBITER6x4(rax, rbp, ymm5, ymm7, ymm9, ymm11, ymm13, ymm15) + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) dec(rsi) // i -= 1; jne(.DLOOPKITER) // iterate again if i != 0. @@ -2939,11 +3102,31 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_6x0_L prefetch(0, mem(rdx, 5*8)) add(r9, rdx) - SUBITER6x4(rax, rbx, ymm4, ymm6, ymm8, ymm10, ymm12, ymm14) - lea(mem(rbx, 1*32), rbp) - SUBITER6x4(rax, rbp, ymm5, ymm7, ymm9, ymm11, ymm13, ymm15) + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) dec(rsi) // i -= 1; @@ -3410,39 +3593,120 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_12x8_L prefetch(0, mem(rdx, 5*8)) - SUBITER6x4(rax, rbx, ymm4, ymm6, ymm8, ymm10, ymm12, ymm14) - lea(mem(rbx, 1*32), rbp) - SUBITER6x4(rax, rbp, ymm5, ymm7, ymm9, ymm11, ymm13, ymm15) + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) // ---------------------------------- iteration 1 prefetch(0, mem(rdx, r9, 1, 5*8)) - SUBITER6x4(rax, rbx, ymm4, ymm6, ymm8, ymm10, ymm12, ymm14) - lea(mem(rbx, 1*32), rbp) - SUBITER6x4(rax, rbp, ymm5, ymm7, ymm9, ymm11, ymm13, ymm15) + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + 
vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) // ---------------------------------- iteration 2 prefetch(0, mem(rdx, r9, 2, 5*8)) - SUBITER6x4(rax, rbx, ymm4, ymm6, ymm8, ymm10, ymm12, ymm14) - lea(mem(rbx, 1*32), rbp) - SUBITER6x4(rax, rbp, ymm5, ymm7, ymm9, ymm11, ymm13, ymm15) + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) // ---------------------------------- iteration 3 prefetch(0, mem(rdx, rcx, 1, 5*8)) lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; - SUBITER6x4(rax, rbx, ymm4, ymm6, ymm8, ymm10, ymm12, ymm14) - lea(mem(rbx, 1*32), rbp) - SUBITER6x4(rax, rbp, ymm5, ymm7, ymm9, ymm11, ymm13, ymm15) + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + dec(rsi) // i -= 1; jne(.DLOOPKITER) // iterate again if i != 0. @@ -3464,11 +3728,32 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_12x8_L prefetch(0, mem(rdx, 5*8)) add(r9, rdx) - SUBITER6x4(rax, rbx, ymm4, ymm6, ymm8, ymm10, ymm12, ymm14) - lea(mem(rbx, 1*32), rbp) - SUBITER6x4(rax, rbp, ymm5, ymm7, ymm9, ymm11, ymm13, ymm15) + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + dec(rsi) // i -= 1; jne(.DLOOPKLEFT) // iterate again if i != 0. 
@@ -3935,40 +4220,109 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_18x16_L // ---------------------------------- iteration 0 prefetch(0, mem(rdx, 5*8)) - SUBITER6x4(rax, rbx, ymm4, ymm6, ymm8, ymm10, ymm12, ymm14) - lea(mem(rbx, 1*32), rbp) - lea(mem(rax, r8, 2), rcx) - SUBITER4x4(rcx, rbp, ymm9, ymm11, ymm13, ymm15) + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) // ---------------------------------- iteration 1 prefetch(0, mem(rdx, r9, 1, 5*8)) - SUBITER6x4(rax, rbx, ymm4, ymm6, ymm8, ymm10, ymm12, ymm14) - lea(mem(rax, r8, 2), rcx) - lea(mem(rbx, 1*32), rbp) - SUBITER4x4(rcx, rbp, ymm9, ymm11, ymm13, ymm15) + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) // ---------------------------------- iteration 2 prefetch(0, mem(rdx, r9, 2, 5*8)) - SUBITER6x4(rax, rbx, ymm4, ymm6, ymm8, ymm10, ymm12, ymm14) - lea(mem(rax, r8, 2), rcx) - lea(mem(rbx, 1*32), rbp) - SUBITER4x4(rcx, rbp, ymm9, ymm11, ymm13, ymm15) + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) // ---------------------------------- iteration 3 prefetch(0, mem(rdx, rcx, 1, 5*8)) lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; - SUBITER6x4(rax, rbx, ymm4, ymm6, ymm8, ymm10, ymm12, ymm14) - lea(mem(rax, r8, 2), rcx) - lea(mem(rbx, 1*32), rbp) - SUBITER4x4(rcx, rbp, ymm9, ymm11, ymm13, ymm15) + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + 
vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) dec(rsi) // i -= 1; jne(.DLOOPKITER) // iterate again if i != 0. @@ -3989,12 +4343,31 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_18x16_L label(.DLOOPKLEFT) // EDGE LOOP prefetch(0, mem(rdx, 5*8)) - SUBITER6x4(rax, rbx, ymm4, ymm6, ymm8, ymm10, ymm12, ymm14) - lea(mem(rax, r8, 2), rcx) - lea(mem(rbx, 1*32), rbp) - SUBITER4x4(rcx, rbp, ymm9, ymm11, ymm13, ymm15) + add(r9, rdx) + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm0, ymm3, ymm6) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) dec(rsi) // i -= 1; jne(.DLOOPKLEFT) // iterate again if i != 0. @@ -4293,7 +4666,6 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_18x16_L vmovhpd(xmm4, mem(rdx, rax, 1, 1*8)) label(.DDONE) - label(.DRETURN) vzeroupper() end_asm( @@ -4432,42 +4804,116 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_0x0_U label(.DLOOPKITER) // MAIN LOOP //0 prefetch(0, mem(rdx, 5*8)) - SUBITER4x4(rax, rbx, ymm4, ymm6, ymm8, ymm10) - lea(mem(rbx, 1*32), rbp) - SUBITER6x4(rax, rbp, ymm5, ymm7, ymm9, ymm11, ymm13, ymm15) + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm1, ymm3, ymm15) //1 prefetch(0, mem(rdx, r9, 1, 5*8)) - SUBITER4x4(rax, rbx, ymm4, ymm6, ymm8, ymm10) - lea(mem(rbx, 1*32), rbp) - SUBITER6x4(rax, rbp, ymm5, ymm7, ymm9, ymm11, ymm13, ymm15) + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm1, ymm3, ymm15) //2 prefetch(0, mem(rdx, r9, 2, 5*8)) - SUBITER4x4(rax, rbx, ymm4, ymm6, ymm8, ymm10) - lea(mem(rbx, 1*32), rbp) - SUBITER6x4(rax, rbp, ymm5, ymm7, ymm9, ymm11, ymm13, ymm15) + vmovupd(mem(rbx, 0*32), ymm0) + 
vmovupd(mem(rbx, 1*32), ymm1) add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm1, ymm3, ymm15) //3 prefetch(0, mem(rdx, rcx, 1, 5*8)) lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; - SUBITER4x4(rax, rbx, ymm4, ymm6, ymm8, ymm10) - lea(mem(rbx, 1*32), rbp) - SUBITER6x4(rax, rbp, ymm5, ymm7, ymm9, ymm11, ymm13, ymm15) + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) add(r10, rbx) // b += rs_b; - add(r9, rax) // a += cs_a; - dec(rsi) - jne(.DLOOPKITER) - label(.DCONSIDKLEFT) + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm1, ymm3, ymm15) + + dec(rsi) + jne(.DLOOPKITER) + + label(.DCONSIDKLEFT) mov(var(k_left), rsi) test(rsi, rsi) @@ -4478,11 +4924,30 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_0x0_U prefetch(0, mem(rdx, 5*8)) add(r9, rdx) - SUBITER4x4(rax, rbx, ymm4, ymm6, ymm8, ymm10) - lea(mem(rbx, 1*32), rbp) - SUBITER6x4(rax, rbp, ymm5, ymm7, ymm9, ymm11, ymm13, ymm15) + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm1, ymm3, ymm15) + dec(rsi) jne(.DLOOPKLEFT) @@ -4863,37 +5328,120 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_6x8_U //0 prefetch(0, mem(rdx, 5*8)) - SUBITER6x4(rax, rbx, ymm4, ymm6, ymm8, ymm10, ymm12, ymm14) - lea(mem(rbx, 1*32), rbp) - SUBITER6x4(rax, rbp, ymm5, ymm7, ymm9, ymm11, ymm13, ymm15) + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + 
vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + //1 prefetch(0, mem(rdx, r9, 1, 5*8)) - SUBITER6x4(rax, rbx, ymm4, ymm6, ymm8, ymm10, ymm12, ymm14) - lea(mem(rbx, 1*32), rbp) - SUBITER6x4(rax, rbp, ymm5, ymm7, ymm9, ymm11, ymm13, ymm15) + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + //2 prefetch(0, mem(rdx, r9, 2, 5*8)) - SUBITER6x4(rax, rbx, ymm4, ymm6, ymm8, ymm10, ymm12, ymm14) - lea(mem(rbx, 1*32), rbp) - SUBITER6x4(rax, rbp, ymm5, ymm7, ymm9, ymm11, ymm13, ymm15) + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + //3 prefetch(0, mem(rdx, rcx, 1, 5*8)) lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; - SUBITER6x4(rax, rbx, ymm4, ymm6, ymm8, ymm10, ymm12, ymm14) - lea(mem(rbx, 1*32), rbp) - SUBITER6x4(rax, rbp, ymm5, ymm7, ymm9, ymm11, ymm13, ymm15) + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) dec(rsi) jne(.DLOOPKITER) @@ -4909,11 +5457,31 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_6x8_U prefetch(0, mem(rdx, 5*8)) add(r9, rdx) - SUBITER6x4(rax, rbx, ymm4, ymm6, ymm8, ymm10, ymm12, ymm14) - lea(mem(rbx, 1*32), rbp) - SUBITER6x4(rax, rbp, ymm5, ymm7, ymm9, ymm11, ymm13, ymm15) + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), 
ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) dec(rsi) jne(.DLOOPKLEFT) @@ -5318,36 +5886,119 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_12x16_U //0 prefetch(0, mem(rdx, 5*8)) - SUBITER6x4(rax, rbx, ymm4, ymm6, ymm8, ymm10, ymm12, ymm14) - lea(mem(rbx, 1*32), rbp) - SUBITER6x4(rax, rbp, ymm5, ymm7, ymm9, ymm11, ymm13, ymm15) + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) add(r10, rbx) // b += rs_b; - add(r9, rax) // a += cs_a + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) + add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) //1 prefetch(0, mem(rdx, r9, 1, 5*8)) - SUBITER6x4(rax, rbx, ymm4, ymm6, ymm8, ymm10, ymm12, ymm14) - lea(mem(rbx, 1*32), rbp) - SUBITER6x4(rax, rbp, ymm5, ymm7, ymm9, ymm11, ymm13, ymm15) + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) //2 prefetch(0, mem(rdx, r9, 2, 5*8)) - SUBITER6x4(rax, rbx, ymm4, ymm6, ymm8, ymm10, ymm12, ymm14) - lea(mem(rbx, 1*32), rbp) - SUBITER6x4(rax, rbp, ymm5, ymm7, ymm9, ymm11, ymm13, ymm15) + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) //3 prefetch(0, mem(rdx, rcx, 1, 5*8)) lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; - SUBITER6x4(rax, rbx, ymm4, ymm6, ymm8, ymm10, ymm12, ymm14) - lea(mem(rbx, 1*32), rbp) - SUBITER6x4(rax, rbp, ymm5, ymm7, ymm9, ymm11, ymm13, 
ymm15) + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) + dec(rsi) jne(.DLOOPKITER) @@ -5361,11 +6012,32 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_12x16_U prefetch(0, mem(rdx, 5*8)) add(r9, rdx) - SUBITER6x4(rax, rbx, ymm4, ymm6, ymm8, ymm10, ymm12, ymm14) - lea(mem(rbx, 1*32), rbp) - SUBITER6x4(rax, rbp, ymm5, ymm7, ymm9, ymm11, ymm13, ymm15) + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm8) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; + vfmadd231pd(ymm0, ymm2, ymm12) + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm0, ymm3, ymm14) + vfmadd231pd(ymm1, ymm3, ymm15) dec(rsi) jne(.DLOOPKLEFT) @@ -5761,25 +6433,37 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_6x0_U //compute xmm5 and xmm7 only label(.DMAIN) //0 - lea(mem(rbx, 1*32), rbp) - SUBITER2x2(rax, rbp, xmm5, xmm7) - add(r9, rax) + vmovupd(mem(rbx, 1*32), xmm1) + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm1, xmm2, xmm5) + vfmadd231pd(xmm1, xmm3, xmm7) add(r10, rbx) - //1 - lea(mem(rbx, 1*32), rbp) - SUBITER2x2(rax, rbp, xmm5, xmm7) add(r9, rax) + //1 + vmovupd(mem(rbx, 1*32), xmm1) + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm1, xmm2, xmm5) + vfmadd231pd(xmm1, xmm3, xmm7) add(r10, rbx) - //2 - lea(mem(rbx, 1*32), rbp) - SUBITER2x2(rax, rbp, xmm5, xmm7) add(r9, rax) + //2 + vmovupd(mem(rbx, 1*32), xmm1) + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm1, xmm2, xmm5) + vfmadd231pd(xmm1, xmm3, xmm7) add(r10, rbx) - //3 - lea(mem(rbx, 1*32), rbp) - SUBITER2x2(rax, rbp, xmm5, xmm7) add(r9, rax) + //3 + vmovupd(mem(rbx, 1*32), xmm1) + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm1, xmm2, xmm5) + vfmadd231pd(xmm1, xmm3, xmm7) add(r10, rbx) + add(r9, rax) dec(rsi) jne(.DMAIN) @@ -5790,10 +6474,13 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_6x0_U je(.DPOSTACC) label(.DLEFT) - lea(mem(rbx, 1*32), rbp) - SUBITER2x2(rax, rbp, xmm5, xmm7) - add(r9, rax) + vmovupd(mem(rbx, 1*32), xmm1) + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(xmm1, xmm2, xmm5) + vfmadd231pd(xmm1, xmm3, xmm7) add(r10, rbx) + add(r9, rax) dec(rsi) jne(.DLEFT) @@ -6019,33 +6706,70 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_12x8_U //0 prefetch(0, mem(rdx, 5*8)) - lea(mem(rbx, 1*32), rbp) - SUBITER4x4(rax, rbp, 
ymm5, ymm7, ymm9, ymm11) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) add(r9, rax) // a += cs_a; - add(r10, rbx) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm1, ymm3, ymm11) //1 prefetch(0, mem(rdx, r9, 1, 5*8)) - lea(mem(rbx, 1*32), rbp) - SUBITER4x4(rax, rbp, ymm5, ymm7, ymm9, ymm11) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) add(r9, rax) // a += cs_a; - add(r10, rbx) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm1, ymm3, ymm11) //2 prefetch(0, mem(rdx, r9, 2, 5*8)) - lea(mem(rbx, 1*32), rbp) - SUBITER4x4(rax, rbp, ymm5, ymm7, ymm9, ymm11) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) add(r9, rax) // a += cs_a; - add(r10, rbx) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm1, ymm3, ymm11) //3 prefetch(0, mem(rdx, rcx, 1, 5*8)) lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; - lea(mem(rbx, 1*32), rbp) - SUBITER4x4(rax, rbp, ymm5, ymm7, ymm9, ymm11) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) add(r9, rax) // a += cs_a; - add(r10, rbx) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm1, ymm3, ymm11) + dec(rsi) jne(.DLOOPKITER) @@ -6060,10 +6784,20 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_12x8_U prefetch(0, mem(rdx, 5*8)) add(r9, rdx) - lea(mem(rbx, 1*32), rbp) - SUBITER4x4(rax, rbp, ymm5, ymm7, ymm9, ymm11) + vmovupd(mem(rbx, 1*32), ymm1) + add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) add(r9, rax) // a += cs_a; - add(r10, rbx) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm1, ymm3, ymm11) + dec(rsi) jne(.DLOOPKLEFT) @@ -6331,33 +7065,105 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_18x16_U label(.DLOOPKITER) // MAIN LOOP //0 prefetch(0, mem(rdx, 5*8)) - SUBITER2x4(rax, rbx, ymm4, ymm6) - lea(mem(rbx, 1*32), rbp) - SUBITER6x4(rax, rbp, ymm5, ymm7, ymm9, ymm11, ymm13, ymm15) + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm1, ymm3, ymm15) + //1 prefetch(0, mem(rdx, r9, 1, 5*8)) - SUBITER2x4(rax, rbx, ymm4, ymm6) - lea(mem(rbx, 1*32), rbp) - 
SUBITER6x4(rax, rbp, ymm5, ymm7, ymm9, ymm11, ymm13, ymm15) + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm1, ymm3, ymm15) + //2 prefetch(0, mem(rdx, r9, 2, 5*8)) - SUBITER2x4(rax, rbx, ymm4, ymm6) - lea(mem(rbx, 1*32), rbp) - SUBITER6x4(rax, rbp, ymm5, ymm7, ymm9, ymm11, ymm13, ymm15) + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm1, ymm3, ymm15) + //3 prefetch(0, mem(rdx, rcx, 1, 5*8)) lea(mem(rdx, r9, 4), rdx) // a_prefetch += 4*cs_a; - SUBITER2x4(rax, rbx, ymm4, ymm6) - lea(mem(rbx, 1*32), rbp) - SUBITER6x4(rax, rbp, ymm5, ymm7, ymm9, ymm11, ymm13, ymm15) + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm1, ymm3, ymm15) + dec(rsi) jne(.DLOOPKITER) @@ -6371,11 +7177,28 @@ void bli_dgemmsup_rv_haswell_asm_6x8m_18x16_U prefetch(0, mem(rdx, 5*8)) add(r9, rdx) - SUBITER2x4(rax, rbx, ymm4, ymm6) - lea(mem(rbx, 1*32), rbp) - SUBITER6x4(rax, rbp, ymm5, ymm7, ymm9, ymm11, ymm13, ymm15) + + vmovupd(mem(rbx, 0*32), ymm0) + vmovupd(mem(rbx, 1*32), ymm1) add(r10, rbx) // b += rs_b; + + vbroadcastsd(mem(rax ), ymm2) + vbroadcastsd(mem(rax, r8, 1), ymm3) + vfmadd231pd(ymm0, ymm2, ymm4) + vfmadd231pd(ymm1, ymm2, ymm5) + vfmadd231pd(ymm0, ymm3, ymm6) + vfmadd231pd(ymm1, ymm3, ymm7) + + vbroadcastsd(mem(rax, r8, 2), ymm2) + vbroadcastsd(mem(rax, r13, 1), ymm3) + vfmadd231pd(ymm1, ymm2, ymm9) + vfmadd231pd(ymm1, ymm3, ymm11) + + vbroadcastsd(mem(rax, r8, 4), ymm2) + vbroadcastsd(mem(rax, r15, 1), ymm3) add(r9, rax) // a += cs_a; + vfmadd231pd(ymm1, ymm2, ymm13) + vfmadd231pd(ymm1, ymm3, ymm15) dec(rsi) jne(.DLOOPKLEFT) From cf31fcd02077840dccbe2ec707a688b0004af80e Mon Sep 17 00:00:00 2001 From: Vignesh Balasubramanian Date: Wed, 17 Aug 2022 17:17:43 +0530 Subject: [PATCH 183/243] Fine tuned threshold and aocl dynamic for zgemm for skinny matrices. -Updated optimal threads in zgemm sup path for skinny matrices. -Fine tuned the threshold values for small and sup paths to improve overall zgemm. 
-Zgemm small is selected for inputs with transb as N. -Redirection of input among small, sup and native path was fine tuned. AMD-Internal : [CPUPL-1900] Change-Id: Ide37c8255def770b4b74bc6e7c6edb5ee15d3b1f --- config/zen/bli_cntx_init_zen.c | 4 ++-- frame/base/bli_rntm.c | 19 +++++++++++++++++-- frame/compat/bla_gemm_amd.c | 6 ++++-- 3 files changed, 23 insertions(+), 6 deletions(-) diff --git a/config/zen/bli_cntx_init_zen.c b/config/zen/bli_cntx_init_zen.c index 3fea3ea8f9..f527fe58d1 100644 --- a/config/zen/bli_cntx_init_zen.c +++ b/config/zen/bli_cntx_init_zen.c @@ -231,9 +231,9 @@ void bli_cntx_init_zen( cntx_t* cntx ) // Initialize sup thresholds with architecture-appropriate values. // s d c z - bli_blksz_init_easy( &thresh[ BLIS_MT ], 512, 256, 380, 110 ); + bli_blksz_init_easy( &thresh[ BLIS_MT ], 512, 256, 380, 128 ); bli_blksz_init_easy( &thresh[ BLIS_NT ], 512, 256, 256, 128 ); - bli_blksz_init_easy( &thresh[ BLIS_KT ], 440, 220, 220, 110 ); + bli_blksz_init_easy( &thresh[ BLIS_KT ], 440, 220, 220, 128 ); // Initialize the context with the sup thresholds. bli_cntx_set_l3_sup_thresh diff --git a/frame/base/bli_rntm.c b/frame/base/bli_rntm.c index 5efba4f2f0..d6712b060a 100644 --- a/frame/base/bli_rntm.c +++ b/frame/base/bli_rntm.c @@ -624,14 +624,29 @@ void bli_nthreads_optimum( dim_t n = bli_obj_width(c); dim_t k = bli_obj_width_after_trans(a); - if((m<=128 || n<=128 || k<=128) && ((m+n+k) <= 400) ) + if((m<=128 || n<=128 || k<=128) && ((m+n+k) <= 400)) { n_threads_ideal = 8; } - else if((m<=256 || n<=256 || k<=256) && ((m+n+k) <= 800) ) + else if((m<=256 || n<=256 || k<=256) && ((m+n+k) <= 800)) { n_threads_ideal = 16; } + if((m<=48) || (n<=48) || (k<=48)) + { + if((m+n+k) <= 840) + { + n_threads_ideal = 8; + } + else if((m+n+k) <= 1240) + { + n_threads_ideal = 16; + } + else if((m+n+k) <= 1540) + { + n_threads_ideal = 32; + } + } } else if( family == BLIS_SYRK && bli_obj_is_double(c)) { diff --git a/frame/compat/bla_gemm_amd.c b/frame/compat/bla_gemm_amd.c index adc83f073d..2a9dcb99db 100644 --- a/frame/compat/bla_gemm_amd.c +++ b/frame/compat/bla_gemm_amd.c @@ -762,7 +762,7 @@ void zgemm_ - For single thread, the API has no constraints before invoking. - For multiple threads, the constraint is that m and n should individually be less than 128. */ - if((k0==1) && ((nt==0) || ((nt==1) && (m0 < 128) && (n0 < 128))) + if((k0 == 1) && ((nt == 0) || ((nt == 1) && (m0 < 128) && (n0 < 128))) && bli_is_notrans(blis_transa) && bli_is_notrans(blis_transb)) { @@ -853,9 +853,11 @@ void zgemm_ } #endif } + #ifdef BLIS_ENABLE_SMALL_MATRIX - if (((nt == 0) && (m0 <= 40) && (n0 <= 40) && (k0 <= 512)) || + if (((nt == 0) && (((m0 <= 40) && (n0 <= 40)) || + (m0 <= 128) && (n0 <= 128) && bli_is_notrans(blis_transb)) && (k0 <= 512)) || ((nt == 1) && (((m0 <= 32) || (n0 <= 32) || (k0 <= 32)) && ((m0 + n0 + k0) <= 100)))) { err_t status = BLIS_NOT_YET_IMPLEMENTED; From 22af681a11549e99dee86cc78b298c4c72fd991b Mon Sep 17 00:00:00 2001 From: Sireesha Sanga Date: Thu, 18 Aug 2022 22:02:19 +0530 Subject: [PATCH 184/243] Runtime Thread Control Feature Update Details: 1. Runtime Thread Control Feature is enhanced to create a provision for the application to allocate a different number of threads to BLIS from the number of threads application is using for itself. 2. In the previous implementation, if application sets BLIS_NUM_THREADS with a valid value, BLIS internally calls omp_set_num_threads() API with same value. 
Due to this, application could not differentiate between the number of threads used in BLIS library and the application. 3. With the current solution, if Application wants to allocate different number of threads for BLIS API and application, Application can choose either BLIS_NUM_THREADS environment variable or bli_thread_set_num_threads(nt) API for BLIS, and OpenMP APIs or environment variables for itself, respectively. 4. If BLIS_NUM_THREADS is set with a valid value, same value will be used in the subsequent parallel regions unless bli_thread_set_num_threads() API is used by the Application to modify the desired number of threads during BLIS API execution. 5. Once BLIS_NUM_THREADS environment variable or bli_thread_set_num_threads(nt) API is used by the application, BLIS module would always give precedence to these values. BLIS API would not consider the values set using OpenMP API omp_set_num_threads(nt) API or OMP_NUM_THREADS environment variable. 6. If BLIS_NUM_THREADS is not set, then if Application is multithreaded and issued omp_set_num_threads(nt) with desired number of threads, omp_get_max_threads() API will fetch the number of threads set earlier. 7. If BLIS_NUM_THREADS is not set, omp_set_num_threads(nt) is not called by the application, but only OMP_NUM_THREADS is set, omp_get_max_threads() API will fetch the value of OMP_NUM_THREADS. 8. If both environment variables are not set, or if they are set with invalid values, and omp_set_num_threads(nt) is not issued by application, omp_get_max_threads() API will return the number of the cores in the current context. 9. BLIS will initialize rntm->num_threads with the same value. However if omp_set_nested is false - BLIS APIs called from parallel threads will run in sequential. But if nested parallelism is enabled Then each application will launch MT BLIS. 10. Order of precedence used for number of threads: 0. value set using bli_thread_set_num_threads(nt) by the application 1. valid value set for BLIS_NUM_THREADS environment variable 2. omp_set_num_threads(nt) issued by the application 3. valid value set for OMP_NUM_THREADS environment variable 4. Number of cores 11. If nt is not a valid value for omp_set_num_threads(nt) API, number of threads would be set to 1. omp_get_max_threads() API will return 1. 12. OMP_NUM_THREADS env. variable is applicable only when OpenMP is enabled. AMD-Internal: [CPUPL-2342] Change-Id: I2041ac1d824f0b57a23a2a69abd6017c800f21b6 --- frame/base/bli_rntm.c | 16 ++++++--- frame/base/bli_rntm.h | 10 ++++++ frame/include/bli_type_defs.h | 2 ++ frame/thread/bli_thread.c | 63 +++++++++++++++++++++++------------ 4 files changed, 66 insertions(+), 25 deletions(-) diff --git a/frame/base/bli_rntm.c b/frame/base/bli_rntm.c index d6712b060a..c6d2cf5b4a 100644 --- a/frame/base/bli_rntm.c +++ b/frame/base/bli_rntm.c @@ -59,12 +59,20 @@ void bli_rntm_init_from_global( rntm_t* rntm ) // Acquire the mutex protecting global_rntm. bli_pthread_mutex_lock( &global_rntm_mutex ); - // Update the latest value of number of threads into global rntm structure, - // before copying into local rntm structure. This updated value will be - // used in the subsequent parallel regions. + // If BLIS_NUM_THREADS environment variable is not set or + // if bli_thread_set_num_threads() API is not used by the + // application, blis_mt flag will be false. + // Then we derive number of threads using OpenMP API + // omp_get_max_threads(), and update into global rntm structure, + // before copying into local rntm structure. 
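In outline, the gate introduced by this hunk behaves like the toy sketch below. The struct and function names are illustrative only (not BLIS internals); the real code operates on the global rntm_t under its mutex.

    #include <stdbool.h>
    #include <omp.h>

    /* Toy model of the decision: if the application never used BLIS's own
       thread controls (blis_mt == false), fall back to OpenMP's view.     */
    typedef struct { bool blis_mt; int num_threads; } toy_rntm_t;

    static int effective_blis_threads( const toy_rntm_t* g )
    {
        if ( g->blis_mt )
            return g->num_threads;      /* BLIS_NUM_THREADS or bli_thread_set_num_threads() won */
        return omp_get_max_threads();   /* otherwise honour the OpenMP setting */
    }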
+ + // This updated value will be used in the subsequent parallel regions. + if(!(global_rntm.blis_mt)) + { #ifdef BLIS_ENABLE_OPENMP - global_rntm.num_threads = n_threads; + global_rntm.num_threads = n_threads; #endif + } *rntm = global_rntm; diff --git a/frame/base/bli_rntm.h b/frame/base/bli_rntm.h index e28463c5ab..c45184c57d 100644 --- a/frame/base/bli_rntm.h +++ b/frame/base/bli_rntm.h @@ -66,6 +66,11 @@ BLIS_INLINE bool bli_rntm_auto_factor( rntm_t* rntm ) return rntm->auto_factor; } +BLIS_INLINE bool bli_rntm_blis_mt( rntm_t* rntm ) +{ + return rntm->blis_mt; +} + BLIS_INLINE dim_t bli_rntm_num_threads( rntm_t* rntm ) { return rntm->num_threads; @@ -154,6 +159,11 @@ BLIS_INLINE void bli_rntm_set_auto_factor_only( bool auto_factor, rntm_t* rntm ) rntm->auto_factor = auto_factor; } +BLIS_INLINE void bli_rntm_set_blis_mt_only( bool blis_mt, rntm_t* rntm ) +{ + rntm->blis_mt = blis_mt; +} + BLIS_INLINE void bli_rntm_set_num_threads_only( dim_t nt, rntm_t* rntm ) { rntm->num_threads = nt; diff --git a/frame/include/bli_type_defs.h b/frame/include/bli_type_defs.h index 47377aa250..2ad2126352 100644 --- a/frame/include/bli_type_defs.h +++ b/frame/include/bli_type_defs.h @@ -1478,6 +1478,8 @@ typedef struct rntm_s bool pack_a; // enable/disable packing of left-hand matrix A. bool pack_b; // enable/disable packing of right-hand matrix B. bool l3_sup; // enable/disable small matrix handling in level-3 ops. + // blis_mt, flag to figure out whether number of + bool blis_mt;// threads is set using BLIS APIS or OpenMP APIs. // "Internal" fields: these should not be exposed to the end-user. diff --git a/frame/thread/bli_thread.c b/frame/thread/bli_thread.c index 097d136e7e..f721bae7e6 100644 --- a/frame/thread/bli_thread.c +++ b/frame/thread/bli_thread.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 21, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 22, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -1614,12 +1614,11 @@ void bli_thread_set_num_threads( dim_t n_threads ) bli_rntm_set_num_threads_only( n_threads, &global_rntm ); -#ifdef BLIS_ENABLE_OPENMP - // In the function bli_rntm_init_from_global() we extract n_threads - // using the API omp_get_max_threads(). Following step ensures that - // omp_get_max_threads returns the same value as set here. - omp_set_num_threads( n_threads ); -#endif + // BLIS_NUM_THREADS env variable or BLIS API to set the + // number of threads is used. Setting the blis_mt flag to TRUE + // so that OMP API or OMP env variables will not be of effect + // going forward. + bli_rntm_set_blis_mt_only(TRUE, &global_rntm); // Release the mutex protecting global_rntm. bli_pthread_mutex_unlock( &global_rntm_mutex ); @@ -1642,30 +1641,41 @@ void bli_thread_init_rntm_from_env #ifdef BLIS_ENABLE_MULTITHREADING - // Try to read BLIS_NUM_THREADS first. - nt = bli_env_get_var( "BLIS_NUM_THREADS", -1 ); - - -#ifdef BLIS_ENABLE_OPENMP - // Scenarios: - // 1. If BLIS_NUM_THREADS is set with valid value, set the nt using omp_set_num_threads(nt) - // so that this value can be fetched inside BLIS API as well. - // 2. If BLIS_NUM_THREADS is not set, then if Application is multithreaded and issued + // 1. 
If BLIS_NUM_THREADS is set with a valid value, same value + // will be used in the subsequent parallel regions unless + // bli_thread_set_num_threads() API is used by the Application + // to modify the desired number of threads during BLIS API execution. + // + // 2. Once BLIS_NUM_THREADS environment variable or bli_thread_set_num_threads(nt) + // API is used by the application, BLIS module would always give precedence to + // these values. BLIS API would not consider the values set using OpenMP API + // omp_set_num_threads(nt) API or OMP_NUM_THREADS environment variable. + // + // 3. If Application wants to allocate separate number of threads for BLIS API execution + // and application, Application can choose either BLIS_NUM_THREADS environement variable + // or bli_thread_set_num_threads(nt) API, to set the desired number of threads + // in BLIS API Execution. Application can use OpenMP APIs or environment variables for + // itself. + // + // 4. If BLIS_NUM_THREADS is not set, then if Application is multithreaded and issued // omp_set_num_threads(nt) with desired number of threads, // omp_get_max_threads() API will fetch the number of threads set earlier. - // 3. If BLIS_NUM_THREADS is not set, omp_set_num_threads(nt) is not called by the application, + // + // 5. If BLIS_NUM_THREADS is not set, omp_set_num_threads(nt) is not called by the application, // but only OMP_NUM_THREADS is set, // omp_get_max_threads() API will fetch the value of OMP_NUM_THREADS. - // 4. If both environment variables are not set, or if they are set with invalid values, and + // + // 6. If both environment variables are not set, or if they are set with invalid values, and // omp_set_num_threads(nt) is not issued by application, // omp_get_max_threads() API will return the number of the cores in the current context. // - // BLIS will rntm->num_threads will also get initialized with the same value. + // BLIS will initialize rntm->num_threads with the same value. // However if omp_set_nested is false - BLIS APIs called from parallel threads will run in sequential. // But if nested parallelism is enabled - Then each application will launch MT BLIS. // // Order of precedence used for number of threads: + // 0. valid value set using bli_thread_set_num_threads(nt) by the application // 1. valid value set for BLIS_NUM_THREADS environment variable // 2. omp_set_num_threads(nt) issued by the application // 3. valid value set for OMP_NUM_THREADS environment variable @@ -1676,16 +1686,27 @@ void bli_thread_init_rntm_from_env // // OMP_NUM_THREADS environment variable is applicable only when OpenMP is enabled. + + // Try to read BLIS_NUM_THREADS first. + nt = bli_env_get_var( "BLIS_NUM_THREADS", -1 ); + + // If BLIS_NUM_THREADS is set with a valid value, set the blis_mt flag in global runtime + // structure. Later during API execution, this flag will be checked for TRUE or FALSE. + // If the flag is FALSE, only then the value set by the application using OpenMP API, + // would be fetched and used subsequently. if(nt > 0) { - omp_set_num_threads(nt); + bli_rntm_set_blis_mt_only(TRUE, rntm); } else { + bli_rntm_set_blis_mt_only(FALSE, rntm); + +#ifdef BLIS_ENABLE_OPENMP nt = omp_get_max_threads(); +#endif } -#endif // Read the environment variables for the number of threads (ways // of parallelism) for each individual loop. 
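From the application's point of view, the precedence spelled out above can be exercised as in the sketch below. This is a hedged usage example: it assumes BLIS_NUM_THREADS was exported before launch and uses the public bli_thread_set_num_threads() / bli_thread_get_num_threads() entry points together with standard OpenMP calls.

    #include <omp.h>
    #include "blis.h"

    int main( void )
    {
        /* Suppose BLIS_NUM_THREADS=8 was exported before the program started. */

        /* The application keeps 16 threads for its own parallel regions...    */
        omp_set_num_threads( 16 );

        /* ...but BLIS still plans for 8: OpenMP settings no longer override
           the BLIS-specific controls.                                         */
        dim_t nt_blis = bli_thread_get_num_threads();
        ( void )nt_blis;

        /* The application may later hand BLIS a different count directly;
           this takes precedence over everything else from here on.            */
        bli_thread_set_num_threads( 4 );

        return 0;
    }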
jc = bli_env_get_var( "BLIS_JC_NT", -1 ); From 6861fcae918a83d28b53e1e2d87cf859b68f5a5d Mon Sep 17 00:00:00 2001 From: Edward Smyth Date: Wed, 10 Aug 2022 12:20:45 -0400 Subject: [PATCH 185/243] BLIS: Improve architecture selection at runtime Make BLIS_ARCH_TYPE=0 be an error, so that incorrect meaningful names will get an error rather than "skx" code path. BLIS_ARCH_TYPE=1 is now "generic", so that it should be constant as new code paths are added. Thus all other code path enum values have increased by 2. Also added new options to BLIS configure program to allow: 1. BLIS_ARCH_TYPE functionality to be disabled, e.g.: ./configure --disable-blis-arch-type amdzen 2. Renaming the environment variable tested from "BLIS_ARCH_TYPE" to a specified value, e.g.: ./configure --rename-blis-arch-type=MY_NAME_FOR_ARCH_TYPE amdzen On Windows, these can be enabled with e.g.: cmake ... -DDISABLE_BLIS_ARCH_TYPE=ON or cmake ... -DRENAME_BLIS_ARCH_TYPE=MY_NAME_FOR_ARCH_TYPE This implements changes 2 and 3 in the Jira ticket below. AMD-Internal: [CPUPL-2235] Change-Id: Ie42906bd909f9d83f00a90c5bef9c5bf3ef5adb4 --- CMakeLists.txt | 20 +++++++++++++++++- build/bli_config.h.in | 8 ++++++- build/bli_win_config.h.in | 4 ++++ configure | 40 +++++++++++++++++++++++++++++++++++ frame/base/bli_arch.c | 22 ++++++++++++++++--- frame/base/bli_env.c | 20 ++++++++++-------- frame/include/bli_type_defs.h | 14 +++++++++--- 7 files changed, 111 insertions(+), 17 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index a46f2d664e..5b01a8f4e4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -107,6 +107,8 @@ option (ENABLE_UPPERCASE_API "export APIs with uppercase" OFF) option (ENABLE_COMPLEX_RETURN_INTEL "Enable complex_return_intel" OFF) option (ENABLE_TRSM_PREINVERSION "Enable TRSM preinversion" ON) option (ENABLE_AOCL_DYNAMIC "Enable Dynamic Multi-threading" OFF) +option(DISABLE_BLIS_ARCH_TYPE "Disable BLIS_ARCH_TYPE functionality" OFF) +option(RENAME_BLIS_ARCH_TYPE "Rename BLIS_ARCH_TYPE env var renamed to supplied value" BLIS_ARCH_TYPE) if (${AOCL_BLIS_FAMILY} STREQUAL "amdzen") set(REF_KERNEL_MIRRORING_PY "${CMAKE_SOURCE_DIR}/build/blis_ref_kernel_mirror.py") @@ -282,6 +284,21 @@ else() endif() endif() +if(DISABLE_BLIS_ARCH_TYPE) + set(BLIS_DISABLE_BLIS_ARCH_TYPE TRUE) +else() + set(BLIS_DISABLE_BLIS_ARCH_TYPE FALSE) +endif() + +if(RENAME_BLIS_ARCH_TYPE) + set(__blis_arch_type_name TRUE) + set(rename_blis_arch_type "${RENAME_BLIS_ARCH_TYPE}") +else() + set(__blis_arch_type_name TRUE) + set(rename_blis_arch_type "BLIS_ARCH_TYPE") +endif() + + #print configurations message("---cmake configurations---") message(CMAKE_C_COMPILER_ID : ${CMAKE_C_COMPILER_ID}) @@ -303,7 +320,8 @@ message(BLIS_ENABLE_MEMKIND : ${BLIS_ENABLE_MEMKIND}) message(BLIS_ENABLE_PRAGMA_OMP_SIMD : ${BLIS_ENABLE_PRAGMA_OMP_SIMD}) message(BLIS_ENABLE_SANDBOX : ${BLIS_ENABLE_SANDBOX}) message(BLIS_ENABLE_SHARED : ${BLIS_ENABLE_SHARED}) - +message(DISABLE_BLIS_ARCH_TYPE : ${DISABLE_BLIS_ARCH_TYPE}) +message(RENAME_BLIS_ARCH_TYPE : ${RENAME_BLIS_ARCH_TYPE}) SET(ENABLE_SIMD_FLAGS "AVX2" CACHE STRING "Set compiler SIMD flags") SET_PROPERTY(CACHE ENABLE_SIMD_FLAGS PROPERTY STRINGS none SSE2 AVX AVX2) diff --git a/build/bli_config.h.in b/build/bli_config.h.in index 73f51baed2..6c17fc5e74 100644 --- a/build/bli_config.h.in +++ b/build/bli_config.h.in @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2021, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. 
All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -194,4 +194,10 @@ #define BLIS_DISABLE_COMPLEX_RETURN_INTEL #endif +#if @disable_blis_arch_type@ +#define DISABLE_BLIS_ARCH_TYPE +#endif + +#define __blis_arch_type_name "@rename_blis_arch_type@" + #endif diff --git a/build/bli_win_config.h.in b/build/bli_win_config.h.in index 6c61b2b1a4..3903763873 100644 --- a/build/bli_win_config.h.in +++ b/build/bli_win_config.h.in @@ -47,4 +47,8 @@ #cmakedefine BLIS_ENABLE_COMPLEX_RETURN_INTEL +#cmakedefine DISABLE_BLIS_ARCH_TYPE + +#cmakedefine __blis_arch_type_name "@rename_blis_arch_type@" + #endif diff --git a/configure b/configure index e18f98f6ae..73dc8cc358 100755 --- a/configure +++ b/configure @@ -353,6 +353,19 @@ print_usage() echo " Num_threads is derived from either environment variable" echo " OMP_NUM_THREADS or BLIS_NUM_THREADS' or bli_set_num_threads() API." echo " " + echo " --enable-blis-arch-type, --disable-blis-arch-type" + echo " " + echo " Disable (Enabled by default) support for BLIS_ARCH_TYPE" + echo " environment variable, which allows user to select" + echo " architecture-specific code path at runtime." + echo " If disabled, in builds with multiple code paths, BLIS" + echo " will still select path automatically." + echo " " + echo " --rename-blis-arch-type=STRING" + echo " " + echo " Change environment variable used to select architecture-specific" + echo " code path from BLIS_ARCH_TYPE to STRING" + echo " " echo " -q, --quiet Suppress informational output. By default, configure" echo " is verbose. (NOTE: -q is not yet implemented)" echo " " @@ -1145,8 +1158,11 @@ auto_detect() # NOTE: -D_GNU_SOURCE is needed to enable POSIX extensions to # pthreads (i.e., barriers). + double_quote_open=\"\\\" + double_quote_close=\\\"\" cmd="${cc} ${config_defines} \ -DBLIS_CONFIGURETIME_CPUID \ + -D__blis_arch_type_name=${double_quote_open}${rename_blis_arch_type}${double_quote_close} \ ${c_hdr_paths} \ -std=c99 -D_GNU_SOURCE \ ${cflags} \ @@ -2025,6 +2041,8 @@ main() enable_aocl_dynamic='yes' force_version='no' complex_return='default' + disable_blis_arch_type='no' + rename_blis_arch_type='BLIS_ARCH_TYPE' # The addon flag and names. addon_flag='' @@ -2254,6 +2272,15 @@ main() complex-return=*) complex_return=${OPTARG#*=} ;; + enable-blis-arch-type) + disable_blis_arch_type='no' + ;; + disable-blis-arch-type) + disable_blis_arch_type='yes' + ;; + rename-blis-arch-type=*) + rename_blis_arch_type=${OPTARG#*=} + ;; *) print_usage ;; @@ -3229,6 +3256,17 @@ main() exit 1 fi + if [ "x${disable_blis_arch_type}" = "xyes" ]; then + echo "${script_name}: user selection of code path using BLIS_ARCH_TYPE env var is disabled." + disable_blis_arch_type_01='1' + else + disable_blis_arch_type_01='0' + fi + + # Check if the user requested a custom env var name to replace BLIS_ARCH_TYPE. + if [ "x${rename_blis_arch_type}" != "xBLIS_ARCH_TYPE" ]; then + echo "${script_name}: configuring with BLIS_ARCH_TYPE env var renamed to '${rename_blis_arch_type}'." + fi echo "${script_name}: configuring complex return type as \"${complex_return}\"." 
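Whichever of these configure options is chosen, the code path BLIS finally settles on can still be inspected at runtime. A small sketch, assuming the existing query helpers bli_arch_query_id() and bli_arch_string():

    #include <stdio.h>
    #include "blis.h"

    int main( void )
    {
        /* Reports the subconfiguration BLIS selected, whether it came from
           auto-detection or from the (possibly renamed) BLIS_ARCH_TYPE var. */
        arch_t id = bli_arch_query_id();
        printf( "BLIS code path: %s\n", bli_arch_string( id ) );
        return 0;
    }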
@@ -3442,6 +3480,8 @@ main() | sed -e "s/@enable_sandbox@/${enable_sandbox_01}/g" \ | sed -e "s/@enable_shared@/${enable_shared_01}/g" \ | sed -e "s/@complex_return_intel@/${complex_return_intel01}/g" \ + | sed -e "s/@disable_blis_arch_type@/${disable_blis_arch_type_01}/g" \ + | sed -e "s/@rename_blis_arch_type@/${rename_blis_arch_type}/g" \ > "${bli_config_h_out_path}" # -- Instantiate bli_addon.h file from template ---------------------------- diff --git a/frame/base/bli_arch.c b/frame/base/bli_arch.c index aa6940ba49..fecc353161 100644 --- a/frame/base/bli_arch.c +++ b/frame/base/bli_arch.c @@ -81,9 +81,19 @@ void bli_arch_set_id( void ) bool do_logging = bli_env_get_var( "BLIS_ARCH_DEBUG", 0 ); bli_arch_set_logging( do_logging ); - // Check the environment variable BLIS_ARCH_TYPE to see if the user - // requested that we use a specific subconfiguration. - dim_t req_id = bli_env_get_var_arch_type( "BLIS_ARCH_TYPE", -1 ); + // DISABLE_BLIS_ARCH_TYPE and BLIS_CONFIGURETIME_CPUID seem similar but + // have different use cases: + // * BLIS_CONFIGURETIME_CPUID is used by the "configure auto" option to + // select a single code path, and affects other parts of the code. + // * DISABLE_BLIS_ARCH_TYPE disables user selection of code path here in + // builds with multiple code paths. + +#ifndef DISABLE_BLIS_ARCH_TYPE + // Check the environment variable (that "__blis_arch_type_name" is + // defined to be) to see if the user requested that we use a specific + // subconfiguration. "__blis_arch_type_name" will be defined by the + // configure command in bli_config.h, with the default name of BLIS_ARCH_TYPE + dim_t req_id = bli_env_get_var_arch_type( __blis_arch_type_name, -1 ); #ifndef BLIS_CONFIGURETIME_CPUID if ( req_id != -1 ) @@ -118,6 +128,8 @@ void bli_arch_set_id( void ) id = req_id; } else +#endif + #endif { // BLIS_ARCH_TYPE was unset. Proceed with normal subconfiguration @@ -234,6 +246,10 @@ void bli_arch_set_id( void ) // function in bli_env.c static char* config_name[ BLIS_NUM_ARCHS ] = { + "error", + + "generic", + "skx", "knl", "knc", diff --git a/frame/base/bli_env.c b/frame/base/bli_env.c index 2cb9efd87a..7fabc2b955 100644 --- a/frame/base/bli_env.c +++ b/frame/base/bli_env.c @@ -87,14 +87,15 @@ gint_t bli_env_get_var_arch_type( const char* env, gint_t fallback ) if (r_val == 0) { - // Could be deliberately 0 (currently meaning "skx") or - // a non-numeric value. We still allow direct specification - // of integer value to select code path. Non-zero integer - // values bypass this code block and are handled as before. - // Here we look for known meaningful names, and return 0 - // if we cannot find a match. - // This code MUST be kept in synch with arch_t enumeration - // in bli_type_defs.h and array config_name in bli_arch.c + // Could be deliberately 0 (now meaning an ERROR) + // or a non-numeric value. We still allow direct + // specification of integer value to select code + // path. Non-zero integer values bypass this code + // block and are handled as before. Here we look + // for known meaningful names, and return 0 if + // we cannot find a match. 
This code MUST be kept + // in synch with arch_t enumeration in + // bli_type_defs.h and array config_name in bli_arch.c // convert string to lowercase size = strlen(str); @@ -141,7 +142,8 @@ gint_t bli_env_get_var_arch_type( const char* env, gint_t fallback ) { r_val = BLIS_ARCH_ZEN2; } - else if ((strcmp(str, "zen") == 0) || (strcmp(str, "zen1") == 0)) + else if ((strcmp(str, "zen") == 0) || + (strcmp(str, "zen1") == 0)) { r_val = BLIS_ARCH_ZEN; } diff --git a/frame/include/bli_type_defs.h b/frame/include/bli_type_defs.h index 2ad2126352..89f9aada33 100644 --- a/frame/include/bli_type_defs.h +++ b/frame/include/bli_type_defs.h @@ -998,6 +998,13 @@ typedef enum // NOTE: The C language standard guarantees that the first enum value // starts at 0. + // Initial value, will be selected for an unrecognized (non-integer) + // value of BLIS_ARCH_TYPE + BLIS_ARCH_ERROR, + + // Generic architecture/configuration + BLIS_ARCH_GENERIC, + // Intel BLIS_ARCH_SKX, BLIS_ARCH_KNL, @@ -1029,12 +1036,13 @@ typedef enum BLIS_ARCH_POWER7, BLIS_ARCH_BGQ, - // Generic architecture/configuration - BLIS_ARCH_GENERIC + // Dummy value, always the last one. + // In config_name in bli_arch.c this is also set to "generic" + BLIS_ARCH_GENERIC_LAST } arch_t; -#define BLIS_NUM_ARCHS (BLIS_ARCH_GENERIC + 1) +#define BLIS_NUM_ARCHS (BLIS_ARCH_GENERIC_LAST + 1) // From 035ed98b51bb8e4eb50eb3b75d96d61d4646b271 Mon Sep 17 00:00:00 2001 From: Arnav Sharma Date: Fri, 19 Aug 2022 12:05:53 +0530 Subject: [PATCH 186/243] Temporarily disabling optimized ZHER - Disabling optimized ZHER pending verification with netlib BLAS test. AMD-Internal: [CPUPL-2416] Change-Id: I74c4d16e1c99ddeb1df91130a8e14feafd0952d0 --- frame/2/her/bli_her_unb_var1_amd.c | 7 +++++-- frame/2/her/bli_her_unb_var2_amd.c | 7 +++++-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/frame/2/her/bli_her_unb_var1_amd.c b/frame/2/her/bli_her_unb_var1_amd.c index 13334c5418..297c9200ee 100644 --- a/frame/2/her/bli_her_unb_var1_amd.c +++ b/frame/2/her/bli_her_unb_var1_amd.c @@ -163,8 +163,11 @@ void PASTEMAC(ch,varname) \ ) \ { \ const num_t dt = PASTEMAC(ch,type); \ + /* ToDo: + Enable intrinsic implementation after verifying + with netlib BLAS tests. */ \ /* Redirect to intrinsic implementation of HER for dcomplex */ \ - if ( bli_cpuid_is_avx_supported() == TRUE && bli_is_conj(conjh) && incx == 1 ) \ + /* if ( bli_cpuid_is_avx_supported() == TRUE && bli_is_conj(conjh) && incx == 1 ) \ { \ bli_zher_zen_int_var1 \ ( \ @@ -181,7 +184,7 @@ void PASTEMAC(ch,varname) \ cntx \ ); \ } \ - else \ + else \ */ \ { \ ctype* x0; \ ctype* chi1; \ diff --git a/frame/2/her/bli_her_unb_var2_amd.c b/frame/2/her/bli_her_unb_var2_amd.c index 6fb4a5d295..c101200d2d 100644 --- a/frame/2/her/bli_her_unb_var2_amd.c +++ b/frame/2/her/bli_her_unb_var2_amd.c @@ -163,8 +163,11 @@ void PASTEMAC(ch,varname) \ ) \ { \ const num_t dt = PASTEMAC(ch,type); \ + /* ToDo: + Enable intrinsic implementation after verifying + with netlib BLAS tests. 
*/ \ /* Redirect to intrinsic implementation of HER for unit increment */ \ - if ( bli_cpuid_is_avx_supported() == TRUE && bli_is_conj(conjh) && incx == 1 ) \ + /* if ( bli_cpuid_is_avx_supported() == TRUE && bli_is_conj(conjh) && incx == 1 ) \ { \ bli_zher_zen_int_var2 \ ( \ @@ -181,7 +184,7 @@ void PASTEMAC(ch,varname) \ cntx \ ); \ } \ - else \ + else \ */ \ { \ ctype* chi1; \ ctype* x2; \ From b8b339416a16fb6c9055ad260a16d2fbed2b88b9 Mon Sep 17 00:00:00 2001 From: Shubham Sharma Date: Fri, 19 Aug 2022 08:38:26 -0400 Subject: [PATCH 187/243] DGEMMT optimizations Details: 1. For lower and upper, "B" column major storage variants of gemmt, new kernels are developed and optimized to compute only the required outputs in the diagonal blocks. 2. In the previous implementation, all the 48 outputs of the given 6x8 block of C matrix are computed and stored into a temporary buffer. Later,the required elements are copied into the final C output buffer. 3. Changes are made to compute only the required outputs of the 6x8 block of C matrix and directly stored in the final C output buffer. 4. With this optimization, we are avoiding copy operation and also reducing the number of computations. 5. Customized bli_dgemmsup_rd_haswell_asm_6x8m Kernels specific to compute Lower and Upper Variant diagonal outputs have been added. 6. SUP Framework changes to integrate the new kernels have been added. 7. These kernels are part of the SUP framework. AMD-Internal: [CPUPL-2341] Change-Id: I9748b2b52557718e7497ecf046530d3031636a63 --- frame/3/gemmt/bli_gemmt_sup_var1n2m.c | 85 +- .../3/sup/bli_gemmsup_rd_haswell_asm_d6x8m.c | 12188 +++++++++++++++- kernels/haswell/bli_kernels_haswell.h | 18 +- 3 files changed, 12219 insertions(+), 72 deletions(-) diff --git a/frame/3/gemmt/bli_gemmt_sup_var1n2m.c b/frame/3/gemmt/bli_gemmt_sup_var1n2m.c index 1023821b86..a026ed8d39 100644 --- a/frame/3/gemmt/bli_gemmt_sup_var1n2m.c +++ b/frame/3/gemmt/bli_gemmt_sup_var1n2m.c @@ -76,19 +76,25 @@ typedef void (*gemmt_ker_ft) ); //Look-up table for Gemmt Upper Variant Kernels -gemmt_ker_ft ker_fpus[7] = -{ - bli_dgemmsup_rv_haswell_asm_6x8m_0x0_U, - bli_dgemmsup_rv_haswell_asm_6x8m_6x0_U, - bli_dgemmsup_rv_haswell_asm_6x8m_6x8_U, - bli_dgemmsup_rv_haswell_asm_6x8m_12x8_U, - bli_dgemmsup_rv_haswell_asm_6x8m_12x16_U, - bli_dgemmsup_rv_haswell_asm_6x8m_18x16_U, - bli_dgemmsup_rv_haswell_asm_6x8m_0x0_combined_U -}; +gemmt_ker_ft ker_fpus[14] = + { + bli_dgemmsup_rv_haswell_asm_6x8m_0x0_U, + bli_dgemmsup_rv_haswell_asm_6x8m_6x0_U, + bli_dgemmsup_rv_haswell_asm_6x8m_6x8_U, + bli_dgemmsup_rv_haswell_asm_6x8m_12x8_U, + bli_dgemmsup_rv_haswell_asm_6x8m_12x16_U, + bli_dgemmsup_rv_haswell_asm_6x8m_18x16_U, + bli_dgemmsup_rv_haswell_asm_6x8m_0x0_combined_U, + bli_dgemmsup_rd_haswell_asm_6x8m_0x0_U, + bli_dgemmsup_rd_haswell_asm_6x8m_6x0_U, + bli_dgemmsup_rd_haswell_asm_6x8m_6x8_U, + bli_dgemmsup_rd_haswell_asm_6x8m_12x8_U, + bli_dgemmsup_rd_haswell_asm_6x8m_12x16_U, + bli_dgemmsup_rd_haswell_asm_6x8m_18x16_U, + bli_dgemmsup_rd_haswell_asm_6x8m_0x0_combined_U}; //Look-up table for Gemmt Lower Variant Kernels -gemmt_ker_ft ker_fpls[7] = +gemmt_ker_ft ker_fpls[14] = { bli_dgemmsup_rv_haswell_asm_6x8m_0x0_L, bli_dgemmsup_rv_haswell_asm_6x8m_6x0_L, @@ -96,7 +102,14 @@ gemmt_ker_ft ker_fpls[7] = bli_dgemmsup_rv_haswell_asm_6x8m_12x8_L, bli_dgemmsup_rv_haswell_asm_6x8m_12x16_L, bli_dgemmsup_rv_haswell_asm_6x8m_18x16_L, - bli_dgemmsup_rv_haswell_asm_6x8m_16x12_combined_L + bli_dgemmsup_rv_haswell_asm_6x8m_16x12_combined_L, + 
bli_dgemmsup_rd_haswell_asm_6x8m_0x0_L, + bli_dgemmsup_rd_haswell_asm_6x8m_6x0_L, + bli_dgemmsup_rd_haswell_asm_6x8m_6x8_L, + bli_dgemmsup_rd_haswell_asm_6x8m_12x8_L, + bli_dgemmsup_rd_haswell_asm_6x8m_12x16_L, + bli_dgemmsup_rd_haswell_asm_6x8m_18x16_L, + bli_dgemmsup_rd_haswell_asm_6x8m_16x12_combined_L }; // @@ -1906,10 +1919,6 @@ void PASTEMACT(ch,opname,uplo,varname) \ dim_t n_off_24 = n_off_cblock % 24; \ dim_t m_idx = (dim_t)(m_off_24 / MR); \ dim_t n_idx = (dim_t)(n_off_24 / NR); \ -\ - /* Optimized kernels are not implemented for the case where B is - stored as column major */ \ - bool storage_supported = (dt == BLIS_DOUBLE) && ( (stor_id == BLIS_RRR) || (stor_id == BLIS_RCR) || (stor_id == BLIS_CRR) ); \ \ /* Check if m, n indices are multiple of MR and NR respectively and current block is a complete 6x8 block */ \ @@ -1917,7 +1926,8 @@ void PASTEMACT(ch,opname,uplo,varname) \ \ /* m_idx and n_idx would be equal only if the current block is a diagonal block */\ - if( (storage_supported) && (m_idx == n_idx) && (idx_supported) ) { \ + if( (dt == BLIS_DOUBLE) && (m_idx == n_idx) && (idx_supported) ) { \ + /* index of kernel in lookup table is 2*m_idx) */ \ dim_t ker_idx; \ ker_idx = m_idx<<1; \ \ @@ -1928,9 +1938,14 @@ void PASTEMACT(ch,opname,uplo,varname) \ this, it has to be ensured that at least 12 rows are pending in C for computation. (m_off + 2 * MR <=m). Usage of this combined kernel saves the entire time to execute one kernel*/ \ - if( (n_idx == 2) && (m_off_cblock + MR + MR <= m) )\ + if( (n_idx == 2) && (m_off_cblock + MR + MR <= m) ) {\ ker_idx = 6; /* use combined kernel, index of combined kernel in lookup table is 6 */\ + } \ + /* use rd kernel if B is column major storage */ \ + if( stor_id == BLIS_RRC ) { \ + ker_idx += 7; /* index of rd kernel*/ \ + } \ gemmt_ker_ft ker_fp = ker_fpls[ker_idx]; \ ker_fp \ ( \ @@ -1949,15 +1964,17 @@ void PASTEMACT(ch,opname,uplo,varname) \ ); \ } \ /* 6x8 block where m_idx == n_idx+1 also has some parts of the diagonal */\ - else if( (storage_supported) && (m_idx == n_idx+1) && (idx_supported) ) { \ - dim_t ker_idx = (n_idx << 1) + 1; \ - gemmt_ker_ft ker_fp = ker_fpls[ker_idx]; \ + else if( (dt == BLIS_DOUBLE) && (m_idx == n_idx+1) && (idx_supported) ) { \ /* If current block was already computed in the combined kernel it can be skipped combined kernel is only implemented for n_idx=2, i == m_zero is only true for the first iteration therefore if i == m_zero then the current 6x8 block was not computed in combined kernel*/ \ if( (n_idx != 2) || (i == m_zero) ) { \ + dim_t ker_idx = (n_idx << 1) + 1; \ + /* use rd kernel if B is column major storage */ \ + if( stor_id == BLIS_RRC ) { ker_idx += 7; } \ + gemmt_ker_ft ker_fp = ker_fpls[ker_idx]; \ ker_fp \ ( \ conja, \ @@ -2580,10 +2597,6 @@ void PASTEMACT(ch,opname,uplo,varname) \ dim_t n_off_24 = n_off_cblock % 24; \ dim_t m_idx = (dim_t)(m_off_24 / MR); \ dim_t n_idx = (dim_t)(n_off_24 / NR); \ -\ - /* Optimized kernels are not implemented for the case where B is - stored as column major */ \ - bool storage_supported = (dt == BLIS_DOUBLE) && ( (stor_id == BLIS_RRR) || (stor_id == BLIS_RCR) || (stor_id == BLIS_CRR) ); \ \ /* Check if m, n indices are multiple of MR and NR respectively and current block is a complete 6x8 block */ \ @@ -2591,8 +2604,8 @@ void PASTEMACT(ch,opname,uplo,varname) \ \ /* m_idx and n_idx would be equal only if the current block is a diagonal block */\ - if( (storage_supported) && (m_idx == n_idx) && (idx_supported) ) { \ - m_idx = m_idx<<1; \ + if( (dt == 
BLIS_DOUBLE) && (m_idx == n_idx) && idx_supported ) { \ + dim_t ker_idx = m_idx<<1; \ /* If there is another 6x8 diagonal block pending for computation after the current 6x8 diagonal block, then the two blocks can be computed together(12x8). This combined kernel is implemented @@ -2600,10 +2613,15 @@ void PASTEMACT(ch,opname,uplo,varname) \ this, it has to be ensured that at least 12 rows are pending in C for computation (i+ MR + MR <= mc_cur). Usage of this combined kernel saves the entire time to execute one kernel*/ \ - if( (n_idx == 0) && (i+ MR + MR <= mc_cur) ) \ - m_idx = 6; /* use combined kernel, index of combined kernel + if( (n_idx == 0) && (i+ MR + MR <= mc_cur) ) { \ + ker_idx = 6; /* use combined kernel, index of combined kernel in lookup table is 6 */\ - gemmt_ker_ft ker_fp = ker_fpus[m_idx]; \ + } \ + /* if B is column storage we use rd kernel*/ \ + if( stor_id == BLIS_RRC ) { \ + ker_idx += 7; /* index of rd kernel*/\ + } \ + gemmt_ker_ft ker_fp = ker_fpus[ker_idx]; \ ker_fp \ ( \ conja, \ @@ -2621,14 +2639,17 @@ void PASTEMACT(ch,opname,uplo,varname) \ ); \ } \ /* 6x8 block where m_idx == n_idx+1 also has some parts of the diagonal */\ - else if( (storage_supported) && (m_idx == n_idx+1) && (idx_supported) ) { \ - gemmt_ker_ft ker_fp = ker_fpus[(n_idx << 1) + 1]; \ + else if( (dt == BLIS_DOUBLE) && (m_idx == n_idx+1) && (idx_supported) ) { \ /* If current block was already computed in the combined kernel it can be skipped combined kernel is only implemented for n_idx=0, i == m_rect is only true for the first iteration therefore if i == m_rect then the current 6x8 block was not computed in combined kernel*/ \ if( (n_idx != 0) || (i == m_rect) ) { \ + dim_t ker_idx = (n_idx << 1) + 1 ; \ + /* use rd kernel if B is column major storage */ \ + if( stor_id == BLIS_RRC ) { ker_idx += 7; } \ + gemmt_ker_ft ker_fp = ker_fpus[ker_idx]; \ ker_fp \ ( \ conja, \ diff --git a/kernels/haswell/3/sup/bli_gemmsup_rd_haswell_asm_d6x8m.c b/kernels/haswell/3/sup/bli_gemmsup_rd_haswell_asm_d6x8m.c index f6edad70bf..2f25755ef4 100644 --- a/kernels/haswell/3/sup/bli_gemmsup_rd_haswell_asm_d6x8m.c +++ b/kernels/haswell/3/sup/bli_gemmsup_rd_haswell_asm_d6x8m.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2022, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -654,7 +654,12153 @@ void bli_dgemmsup_rd_haswell_asm_6x8m label(.DRETURN) + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + consider_edge_cases: + + // Handle edge cases in the m dimension, if they exist. 
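Because the rd micro-kernel variants in this file walk C three rows at a time (m_iter = m0 / 3), the m-remainder handled next can only be one or two rows, which is why only the 2x8 and 1x8 edge kernels appear below. A minimal illustration of that split (hypothetical helper, not library code):

    /* With 3-row tiles, the leftover row count is m0 % 3, i.e. 0, 1 or 2. */
    static void split_m( int m0, int* m_iter, int* m_left )
    {
        *m_iter = m0 / 3;   /* full 3-row tiles handled by the asm loop   */
        *m_left = m0 % 3;   /* rows handed to the 2x8 / 1x8 edge kernels  */
    }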
+ if ( m_left ) + { + const dim_t nr_cur = 8; + const dim_t i_edge = m0 - ( dim_t )m_left; + + double* restrict cij = c + i_edge*rs_c; + double* restrict bj = b; + double* restrict ai = a + i_edge*rs_a; + + if ( 2 == m_left ) + { + const dim_t mr_cur = 2; + + bli_dgemmsup_rd_haswell_asm_2x8 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + //cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; + } + if ( 1 == m_left ) + { + const dim_t mr_cur = 1; + + bli_dgemmsup_rd_haswell_asm_1x8 + ( + conja, conjb, mr_cur, nr_cur, k0, + alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, + beta, cij, rs_c0, cs_c0, data, cntx + ); + } + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); +} + +/* +24x24 block + + 1 1 1 1 1 1 1 1 1 1 2 2 2 2 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 + |- - - - - - - -|- - - - - - - -| - - - - - - - -| +0 | | | | +1 | m_off_24 = 0 | | | +2 | n_off_24 = 0 | | | +3 | m_idx = 0 | | | +4 | n_idx = 0 | | | +5 |- - - - - - - -|- - - - - - - -|- - - - - - - - | +6 | | | | +7 | m_off_24 = 6 | m_off_24 = 6 | | +8 | n_off_24 = 0 | n_off_24 = 8 | | +9 | m_idx = 1 | m_idx = 1 | | +10 | n_idx = 0 | n_idx = 1 | | +11 |- - - - - - - -|- - - - - - - -|- - - - - - - - | +12 | | | | +13 | | m_off_24 = 12 | m_off_24 = 12 | +14 | | n_off_24 = 8 | n_off_24 = 16 | +15 | | m_idx = 2 | m_idx = 2 | +16 | | n_idx = 1 | n_idx = 2 | +17 |- - - - - - - -|- - - - - - - -|- - - - - - - - | +18 | | | | +19 | | | m_off_24 = 18 | +20 | | | n_off_24 = 16 | +21 | | | m_idx = 3 | +22 | | | n_idx = 2 | +23 |- - - - - - - -|- - - - - - - -|- - - - - - - - | +*/ + + +#define SUBITER_K4_3x4(a, b) \ +\ + vmovupd(mem(a ), ymm0) \ + vmovupd(mem(a, r8, 1), ymm1) \ + vmovupd(mem(a, r8, 2), ymm2) \ + add(imm(4*8), a) \ +\ + vmovupd(mem(b ), ymm3) \ + vfmadd231pd(ymm0, ymm3, ymm4) \ + vfmadd231pd(ymm1, ymm3, ymm5) \ + vfmadd231pd(ymm2, ymm3, ymm6) \ +\ + vmovupd(mem(b, r11, 1), ymm3) \ + vfmadd231pd(ymm0, ymm3, ymm7) \ + vfmadd231pd(ymm1, ymm3, ymm8) \ + vfmadd231pd(ymm2, ymm3, ymm9) \ +\ + vmovupd(mem(b, r11, 2), ymm3) \ + vfmadd231pd(ymm0, ymm3, ymm10) \ + vfmadd231pd(ymm1, ymm3, ymm11) \ + vfmadd231pd(ymm2, ymm3, ymm12) \ +\ + vmovupd(mem(b, r13, 1), ymm3) \ + add(imm(4*8), b) \ + vfmadd231pd(ymm0, ymm3, ymm13) \ + vfmadd231pd(ymm1, ymm3, ymm14) \ + vfmadd231pd(ymm2, ymm3, ymm15) \ + +#define SUBITER_K1_3x4(a, b) \ +\ + vmovsd(mem(a ), xmm0) \ + vmovsd(mem(a, r8, 1), xmm1) \ + vmovsd(mem(a, r8, 2), xmm2) \ + add(imm(1*8), a) \ +\ + vmovsd(mem(b ), xmm3) \ + vfmadd231pd(ymm0, ymm3, ymm4) \ + vfmadd231pd(ymm1, ymm3, ymm5) \ + vfmadd231pd(ymm2, ymm3, ymm6) \ +\ + vmovsd(mem(b, r11, 1), xmm3) \ + vfmadd231pd(ymm0, ymm3, ymm7) \ + vfmadd231pd(ymm1, ymm3, ymm8) \ + vfmadd231pd(ymm2, ymm3, ymm9) \ +\ + vmovsd(mem(b, r11, 2), xmm3) \ + vfmadd231pd(ymm0, ymm3, ymm10) \ + vfmadd231pd(ymm1, ymm3, ymm11) \ + vfmadd231pd(ymm2, ymm3, ymm12) \ +\ + vmovsd(mem(b, r13, 1), xmm3) \ + add(imm(1*8), b) \ + vfmadd231pd(ymm0, ymm3, ymm13) \ + vfmadd231pd(ymm1, ymm3, ymm14) \ + vfmadd231pd(ymm2, ymm3, ymm15) \ + +#define SUBITER_K4_2x4(a, b) \ +\ + vmovupd(mem(a ), ymm0) \ + vmovupd(mem(a, r8, 1), ymm1) \ + add(imm(4*8), a) \ +\ + vmovupd(mem(b ), ymm3) \ + vfmadd231pd(ymm0, ymm3, ymm4) \ + vfmadd231pd(ymm1, ymm3, ymm5) \ +\ + vmovupd(mem(b, r11, 1), ymm3) \ + vfmadd231pd(ymm0, ymm3, ymm7) \ + vfmadd231pd(ymm1, ymm3, ymm8) \ +\ + vmovupd(mem(b, r11, 2), ymm3) \ + vfmadd231pd(ymm0, ymm3, ymm10) \ + vfmadd231pd(ymm1, ymm3, ymm11) \ +\ + vmovupd(mem(b, r13, 1), 
ymm3) \ + add(imm(4*8), b) \ + vfmadd231pd(ymm0, ymm3, ymm13) \ + vfmadd231pd(ymm1, ymm3, ymm14) \ + +#define SUBITER_K1_2x4(a, b) \ +\ + vmovsd(mem(a ), xmm0) \ + vmovsd(mem(a, r8, 1), xmm1) \ + add(imm(1*8), a) \ +\ + vmovsd(mem(b ), xmm3) \ + vfmadd231pd(ymm0, ymm3, ymm4) \ + vfmadd231pd(ymm1, ymm3, ymm5) \ +\ + vmovsd(mem(b, r11, 1), xmm3) \ + vfmadd231pd(ymm0, ymm3, ymm7) \ + vfmadd231pd(ymm1, ymm3, ymm8) \ +\ + vmovsd(mem(b, r11, 2), xmm3) \ + vfmadd231pd(ymm0, ymm3, ymm10) \ + vfmadd231pd(ymm1, ymm3, ymm11) \ +\ + vmovsd(mem(b, r13, 1), xmm3) \ + add(imm(1*8), b) \ + vfmadd231pd(ymm0, ymm3, ymm13) \ + vfmadd231pd(ymm1, ymm3, ymm14) \ + +#define SUBITER_K4_1x4(a, b) \ +\ + vmovupd(mem(a ), ymm0) \ + add(imm(4*8), a) \ + vmovupd(mem(b ), ymm3) \ + vfmadd231pd(ymm0, ymm3, ymm4) \ + vmovupd(mem(b, r11, 1), ymm3) \ + vfmadd231pd(ymm0, ymm3, ymm7) \ + vmovupd(mem(b, r11, 2), ymm3) \ + vfmadd231pd(ymm0, ymm3, ymm10) \ + vmovupd(mem(b, r13, 1), ymm3) \ + add(imm(4*8), b) \ + vfmadd231pd(ymm0, ymm3, ymm13) \ + +#define SUBITER_K1_1x4(a, b) \ +\ + vmovsd(mem(a ), xmm0) \ + add(imm(1*8), a) \ + vmovsd(mem(b ), xmm3) \ + vfmadd231pd(ymm0, ymm3, ymm4) \ + vmovsd(mem(b, r11, 1), xmm3) \ + vfmadd231pd(ymm0, ymm3, ymm7) \ + vmovsd(mem(b, r11, 2), xmm3) \ + vfmadd231pd(ymm0, ymm3, ymm10) \ + vmovsd(mem(b, r13, 1), xmm3) \ + add(imm(1*8), b) \ + vfmadd231pd(ymm0, ymm3, ymm13) \ + +/* + +Following kernel computes the 6x8 block for the Upper vairant(U) of gemmt where +m_offset in 24x24 block is 0 and n_offset is 0(0x0) +(0x0)_U + +the region marked with 'x' is computed by following kernel +the region marked with '-' is not computed + + <-- n_off_24 -- > + 0 1 2 3 4 5 6 7 + +↑ 0 x x x x x x x x +| 1 - x x x x x x x +m 2 - - x x x x x x +off 3 - - - x x x x x +24 4 - - - - x x x x +| 5 - - - - - x x x +↓ + + +*/ +void bli_dgemmsup_rd_haswell_asm_6x8m_0x0_U + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t m_iter = m0 / 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + begin_asm() + + mov(var(rs_a), r8) // load rs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + mov(var(cs_b), r11) // load cs_b + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(imm(0), r15) // jj = 0; + +// ----------------------- Block 1 + + mov(var(a), r14) // load address of a + mov(var(b), rdx) // load address of b + mov(var(c), r12) // load address of c + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + mov(var(m_iter), r9) // ii = m_iter; + + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, 
ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK1) // if i == 0, jump to code that + // contains the k_iter4 loop. + + label(.DLOOPKITER16_BLOCK1) // MAIN LOOP + // ---------------------------------- iteration 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + // ---------------------------------- iteration 1 + SUBITER_K4_3x4(rax, rbx) + // ---------------------------------- iteration 2 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + // ---------------------------------- iteration 3 + SUBITER_K4_3x4(rax, rbx) + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK1) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK1) + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK1) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK1) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK1) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + label(.DLOOPKLEFT1_BLOCK1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + SUBITER_K1_3x4(rax, rbx) + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK1) // iterate again if i != 0. 
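Read in scalar terms, each SUBITER_K4_3x4 / SUBITER_K1_3x4 step used in the loops above accumulates dot products for a 3x4 tile of C: three rows of A against four columns of B, four k-values per step in the ymm form and one in the scalar form. A plain-C reference of a single k-step follows (sketch only, with hypothetical array arguments; the asm keeps the twelve partial sums in ymm4..ymm15):

    /* acc[i][j] accumulates row i of A times column j of B for one k value.
       The kernel then advances a and b by one element along k.              */
    static void subiter_k1_3x4_ref( const double* a, long rs_a,
                                    const double* b, long cs_b,
                                    double acc[3][4] )
    {
        for ( int i = 0; i < 3; ++i )
            for ( int j = 0; j < 4; ++j )
                acc[i][j] += a[ i * rs_a ] * b[ j * cs_b ];
    }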
+ label(.DPOSTACCUM_BLOCK1) + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + // now avoid loading C if beta == 0 + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK1) // if ZF = 1, jump to beta == 0 case + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vextractf128(imm(1), ymm5, xmm1 ) + vmovhpd(xmm5, mem(rcx, 1*8)) + vmovupd(xmm1, mem(rcx ,2*8)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vextractf128(imm(1), ymm6, xmm1 ) + vmovupd(xmm1, mem(rcx, 2*8)) + jmp(.DDONE_BLOCK1) // jump to end. + label(.DBETAZERO_BLOCK1) + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vextractf128(imm(1), ymm5, xmm1 ) + vmovhpd(xmm5, mem(rcx, 1*8)) + vmovupd(xmm1, mem(rcx ,2*8)) + add(rdi, rcx) + vextractf128(imm(1), ymm6, xmm1 ) + vmovupd(xmm1, mem(rcx, 2*8)) + label(.DDONE_BLOCK1) + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + dec(r9) // ii -= 1; + +// ----------------------- Block 2 + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK2) // if i == 0, jump to code that + // contains the k_iter4 loop. 
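The vhaddpd / vextractf128 / vperm2f128 sequence in the post-accumulation block above only performs a horizontal reduction: every ymm accumulator holds four partial sums belonging to one (row, column) entry of the 3x4 tile, and those four lanes are added before scaling by alpha. A scalar reference of that step (sketch only):

    /* Sum of the four lanes of one accumulator; the real code reduces four
       accumulators at once and packs the results into a single ymm.        */
    static double hsum4( const double v[4] )
    {
        return ( v[0] + v[1] ) + ( v[2] + v[3] );
    }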
+ + label(.DLOOPKITER16_BLOCK2) // MAIN LOOP + + // ---------------------------------- iteration 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + SUBITER_K4_1x4(rax, rbx) + // ---------------------------------- iteration 1 + SUBITER_K4_1x4(rax, rbx) + // ---------------------------------- iteration 2 + prefetch(0, mem(rax, r10, 1, 0*8)) + SUBITER_K4_1x4(rax, rbx) + // ---------------------------------- iteration 3 + SUBITER_K4_1x4(rax, rbx) + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK2) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK2) + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK2) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK2) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + SUBITER_K4_1x4(rax, rbx) + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK2) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK2) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK2) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + label(.DLOOPKLEFT1_BLOCK2) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + SUBITER_K1_1x4(rax, rbx) + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK2) // iterate again if i != 0. + + label(.DPOSTACCUM_BLOCK2) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK2) // if ZF = 1, jump to beta == 0 case + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vextractf128(imm(1), ymm4, xmm1 ) + vmovhpd(xmm1, mem(rcx ,3*8)) + jmp(.DDONE_BLOCK2) // jump to end. 
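The partial stores above (vmovhpd / vmovupd on selected lanes) are what restrict the update to entries on or above the diagonal for this upper (U) variant. In reference form the write-back amounts to the predicate sketched below; the argument names are hypothetical, and the real kernels additionally branch on beta == 0 so that C is never read in that case.

    /* Upper-variant write-back: only entries with global column index >=
       global row index are scaled and stored.                              */
    static void store_upper_tile_ref( double* c, long rs_c, long cs_c,
                                      long m_off, long n_off, int mr, int nr,
                                      const double tile[][8],
                                      double alpha, double beta )
    {
        for ( int i = 0; i < mr; ++i )
            for ( int j = 0; j < nr; ++j )
                if ( n_off + j >= m_off + i )
                    c[ i * rs_c + j * cs_c ] =
                        beta * c[ i * rs_c + j * cs_c ] + alpha * tile[ i ][ j ];
    }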
+ label(.DBETAZERO_BLOCK2) + + vextractf128(imm(1), ymm4, xmm1 ) + vmovhpd(xmm1, mem(rcx ,3*8)) + label(.DDONE_BLOCK2) + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + dec(r9) // ii -= 1; + add(imm(4), r15) // jj += 4; + +// ----------------------- Block 3 + mov(var(a), r14) // load address of a + mov(var(b), rdx) // load address of b + mov(var(c), r12) // load address of c + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + mov(var(m_iter), r9) // ii = m_iter; + + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK3) // if i == 0, jump to code that + // contains the k_iter4 loop. + label(.DLOOPKITER16_BLOCK3) // MAIN LOOP + // ---------------------------------- iteration 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + // ---------------------------------- iteration 1 + SUBITER_K4_3x4(rax, rbx) + // ---------------------------------- iteration 2 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + // ---------------------------------- iteration 3 + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK3) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK3) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK3) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK3) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK3) // iterate again if i != 0. + label(.DCONSIDKLEFT1_BLOCK3) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK3) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK3) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. 
+ + SUBITER_K1_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK3) // iterate again if i != 0. + label(.DPOSTACCUM_BLOCK3) + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK3) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + jmp(.DDONE_BLOCK3) // jump to end. + + label(.DBETAZERO_BLOCK3) + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + + label(.DDONE_BLOCK3) + + lea(mem(r12, rdi, 2), r12) + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + +// ----------------------- Block 4 + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK4) // if i == 0, jump to code that + // contains the k_iter4 loop. 
+ + label(.DLOOPKITER16_BLOCK4) // MAIN LOOP + // ---------------------------------- iteration 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + // ---------------------------------- iteration 1 + SUBITER_K4_3x4(rax, rbx) + // ---------------------------------- iteration 2 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + // ---------------------------------- iteration 3 + SUBITER_K4_3x4(rax, rbx) + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK4) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK4) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK4) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK4) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK4) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK4) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK4) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + SUBITER_K1_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK4) // iterate again if i != 0. + + label(.DPOSTACCUM_BLOCK4) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. 
+ vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK4) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vextractf128(imm(1), ymm6, xmm1 ) + vmovhpd(xmm6, mem(rcx, 1*8)) + vmovupd(xmm1, mem(rcx, 2*8)) + + + jmp(.DDONE_BLOCK4) // jump to end. + + label(.DBETAZERO_BLOCK4) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vextractf128(imm(1), ymm6, xmm1 ) + vmovhpd(xmm6, mem(rcx, 1*8)) + vmovupd(xmm1, mem(rcx, 2*8)) + + + label(.DDONE_BLOCK4) + + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); +} + +/* + +Following kernel computes the 6x8 block for the Upper vairant(U) of gemmt where +m_offset in 24x24 block is 6 and n_offset is 0(6x0) +(6x0)_U + + +the region marked with 'x' is computed by following kernel +the region marked with '-' is not computed + + <-- n_off_24 -- > + 0 1 2 3 4 5 6 7 + +↑ 6 - - - - - - x x +| 7 - - - - - - - x +m 8 - - - - - - - - +off 9 - - - - - - - - +24 10 - - - - - - - - +| 11 - - - - - - - - +↓ + + +*/ +void bli_dgemmsup_rd_haswell_asm_6x8m_6x0_U + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t m_iter = m0 / 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + begin_asm() + + mov(var(rs_a), r8) // load rs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + mov(var(cs_b), r11) // load cs_b + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(imm(4), r15) // jj = 4; +// ----------------------- Block 3 + mov(var(a), r14) // load address of a + mov(var(b), rdx) // load address of b + mov(var(c), r12) // load address of c + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + mov(var(m_iter), r9) // 
ii = m_iter; + +// ----------------------- Block 3 + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK3) // if i == 0, jump to code that + // contains the k_iter4 loop. + + label(.DLOOPKITER16_BLOCK3) // MAIN LOOP + // ---------------------------------- iteration 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_2x4(rax, rbx) + // ---------------------------------- iteration 1 + SUBITER_K4_2x4(rax, rbx) + // ---------------------------------- iteration 2 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_2x4(rax, rbx) + // ---------------------------------- iteration 3 + SUBITER_K4_2x4(rax, rbx) + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK3) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK3) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK3) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK3) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_2x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK3) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK3) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK3) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK3) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + SUBITER_K1_2x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK3) // iterate again if i != 0. 
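+	// Only the first two rows of the accumulator grid carry live data in
+	// this block (the SUBITER_*_2x4 updates above touch ymm4/7/10/13 and
+	// ymm5/8/11/14); the third row reduces to zeros and is not stored.
+	// Per the block diagram above, the stores that follow write only the
+	// elements on or above the diagonal of the 24x24 tile.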
+ + label(.DPOSTACCUM_BLOCK3) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK3) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vextractf128(imm(1), ymm4, xmm1 ) + vmovupd(xmm1, mem(rcx, 2*8)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vextractf128(imm(1), ymm5, xmm1 ) + vmovhpd(xmm1, mem(rcx, 3*8)) + add(rdi, rcx) + + // vfmadd231pd(mem(rcx), ymm3, ymm6) + // vmovupd(ymm6, mem(rcx)) + + + jmp(.DDONE_BLOCK3) // jump to end. 
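+	// beta == 0 store path: the same masked stores as above, but
+	// alpha*(A*B) is written directly without reading C. vextractf128 /
+	// vmovhpd / vmovupd are used so that only the doubles inside the
+	// region marked 'x' in the diagram are touched.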
+ + label(.DBETAZERO_BLOCK3) + + + vextractf128(imm(1), ymm4, xmm1 ) + vmovupd(xmm1, mem(rcx, 2*8)) + add(rdi, rcx) + + vextractf128(imm(1), ymm5, xmm1 ) + vmovhpd(xmm1, mem(rcx, 3*8)) + add(rdi, rcx) + + + label(.DDONE_BLOCK3) + + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); +} + +/* + +Following kernel computes the 6x8 block for the Upper vairant(U) of gemmt where +m_offset in 24x24 block is 0, n_offset is 0(0x0) and m_offset is 6, n_offset is 0 (6x0) +(0x0)+(6x0)_L + +the region marked with 'x' is computed by following kernel +the region marked with '-' is not computed + + <-- n_off_24 -- > + 0 1 2 3 4 5 6 7 + +↑ 0 x x x x x x x x +| 1 - x x x x x x x +m 2 - - x x x x x x +off 3 - - - x x x x x +24 4 - - - - x x x x +| 5 - - - - - x x x +↓ +↑ 6 - - - - - - x x +| 7 - - - - - - - x +m 8 - - - - - - - - +off 9 - - - - - - - - +24 10 - - - - - - - - +| 11 - - - - - - - - +↓ + + +*/ +void bli_dgemmsup_rd_haswell_asm_6x8m_0x0_combined_U + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t m_iter = m0 / 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + begin_asm() + + mov(var(rs_a), r8) // load rs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + mov(var(cs_b), r11) // load cs_b + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(imm(0), r15) // jj = 0; + +// ----------------------- Block 1 + + mov(var(a), r14) // load address of a + mov(var(b), rdx) // load address of b + mov(var(c), r12) // load address of c + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + mov(var(m_iter), r9) // ii = m_iter; + + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + 
vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK1) // if i == 0, jump to code that + // contains the k_iter4 loop. + + label(.DLOOPKITER16_BLOCK1) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + // ---------------------------------- iteration 1 + SUBITER_K4_3x4(rax, rbx) + // ---------------------------------- iteration 2 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + // ---------------------------------- iteration 3 + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK1) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK1) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK1) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK1) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK1) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + SUBITER_K1_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK1) // iterate again if i != 0. 
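+	// Post-accumulation: each ymm accumulator still holds four k-partial
+	// sums of a single dot product. The vhaddpd / vextractf128 / vaddpd /
+	// vperm2f128 sequence below folds them so that, as the existing
+	// comments note,
+	//   ymm4 -> { sum(ymm4), sum(ymm7), sum(ymm10), sum(ymm13) }   C row 0
+	//   ymm5 -> { sum(ymm5), sum(ymm8), sum(ymm11), sum(ymm14) }   C row 1
+	//   ymm6 -> { sum(ymm6), sum(ymm9), sum(ymm12), sum(ymm15) }   C row 2
+	// i.e. one packed row of four C results per register, ready for the
+	// alpha/beta update.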
+ + label(.DPOSTACCUM_BLOCK1) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK1) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vextractf128(imm(1), ymm5, xmm1 ) + vmovhpd(xmm5, mem(rcx, 1*8)) + vmovupd(xmm1, mem(rcx ,2*8)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vextractf128(imm(1), ymm6, xmm1 ) + vmovupd(xmm1, mem(rcx, 2*8)) + + + jmp(.DDONE_BLOCK1) // jump to end. + + label(.DBETAZERO_BLOCK1) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vextractf128(imm(1), ymm5, xmm1 ) + vmovhpd(xmm5, mem(rcx, 1*8)) + vmovupd(xmm1, mem(rcx ,2*8)) + + add(rdi, rcx) + + vextractf128(imm(1), ymm6, xmm1 ) + vmovupd(xmm1, mem(rcx, 2*8)) + + + + label(.DDONE_BLOCK1) + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + +// ----------------------- Block 2 + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK2) // if i == 0, jump to code that + // contains the k_iter4 loop. 
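+	// Block 2 is the tile for rows 3-5 of the first four columns; per the
+	// block diagram above only C(3,3) of this 3x4 tile lies in the region
+	// marked 'x', so a single-row update (SUBITER_*_1x4) is enough and
+	// just one double is written back after the reduction.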
+ + label(.DLOOPKITER16_BLOCK2) // MAIN LOOP + + // ---------------------------------- iteration 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + SUBITER_K4_1x4(rax, rbx) + // ---------------------------------- iteration 1 + SUBITER_K4_1x4(rax, rbx) + // ---------------------------------- iteration 2 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + SUBITER_K4_1x4(rax, rbx) + // ---------------------------------- iteration 3 + SUBITER_K4_1x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK2) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK2) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK2) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK2) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + SUBITER_K4_1x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK2) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK2) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK2) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK2) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + SUBITER_K1_1x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK2) // iterate again if i != 0. + + label(.DPOSTACCUM_BLOCK2) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK2) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vextractf128(imm(1), ymm4, xmm1 ) + vmovhpd(xmm1, mem(rcx ,3*8)) + + + jmp(.DDONE_BLOCK2) // jump to end. 
+ + label(.DBETAZERO_BLOCK2) + + + vextractf128(imm(1), ymm4, xmm1 ) + vmovhpd(xmm1, mem(rcx ,3*8)) + + + label(.DDONE_BLOCK2) + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + + add(imm(4), r15) // jj += 4; + +// ----------------------- Block 3 + mov(var(a), r14) // load address of a + mov(var(b), rdx) // load address of b + mov(var(c), r12) // load address of c + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + mov(var(m_iter), r9) // ii = m_iter; + + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK3) // if i == 0, jump to code that + // contains the k_iter4 loop. + + label(.DLOOPKITER16_BLOCK3) // MAIN LOOP + + // ---------------------------------- iteration 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + // ---------------------------------- iteration 1 + SUBITER_K4_3x4(rax, rbx) + // ---------------------------------- iteration 2 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + // ---------------------------------- iteration 3 + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK3) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK3) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK3) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK3) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK3) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK3) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK3) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK3) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. 
+ SUBITER_K1_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK3) // iterate again if i != 0. + + label(.DPOSTACCUM_BLOCK3) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK3) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + + jmp(.DDONE_BLOCK3) // jump to end. + + label(.DBETAZERO_BLOCK3) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + + + label(.DDONE_BLOCK3) + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + +// ----------------------- Block 4 + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK4) // if i == 0, jump to code that + // contains the k_iter4 loop. 
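+	// The main loop below writes out the same 4-k-step update that
+	// SUBITER_K4_3x4 performs elsewhere in this file: four k-values are
+	// loaded from each of three A rows (ymm0-ymm2) and from four B
+	// columns (reusing ymm3), then FMA'd into the 3x4 accumulator grid.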
+ + label(.DLOOPKITER16_BLOCK4) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK4) // iterate again if i != 0. 
+ + label(.DCONSIDKITER4_BLOCK4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK4) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK4) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK4) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK4) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK4) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK4) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK4) // iterate again if i != 0. 
+ + label(.DPOSTACCUM_BLOCK4) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK4) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vextractf128(imm(1), ymm6, xmm1 ) + vmovhpd(xmm6, mem(rcx, 1*8)) + vmovupd(xmm1, mem(rcx, 2*8)) + + + jmp(.DDONE_BLOCK4) // jump to end. 
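+	// beta == 0 store path for this block: alpha*(A*B) is stored without
+	// reading C, with the last row again written via vmovhpd/vmovupd so
+	// that only the elements on or above the diagonal are updated.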
+ + label(.DBETAZERO_BLOCK4) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vextractf128(imm(1), ymm6, xmm1 ) + vmovhpd(xmm6, mem(rcx, 1*8)) + vmovupd(xmm1, mem(rcx, 2*8)) + + + label(.DDONE_BLOCK4) + + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + a += 6 * rs_a; + c += 6 * rs_c; + + begin_asm() + + mov(var(rs_a), r8) // load rs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + mov(var(cs_b), r11) // load cs_b + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(imm(0), r15) // jj = 0; + +// ----------------------- Block 1 + mov(var(a), r14) // load address of a + mov(var(b), rdx) // load address of b + mov(var(c), r12) // load address of c + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + mov(var(m_iter), r9) // ii = m_iter; + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + +// ----------------------- Block 2 + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + + add(imm(4), r15) // jj += 4; + +// ----------------------- Block 3 + mov(var(a), r14) // load address of a + mov(var(b), rdx) // load address of b + mov(var(c), r12) // load address of c + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + mov(var(m_iter), r9) // ii = m_iter; + + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + 
test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK3) // if i == 0, jump to code that + // contains the k_iter4 loop. + + label(.DLOOPKITER16_BLOCK3) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK3) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK3) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK3) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. 
+ + label(.DLOOPKITER4_BLOCK3) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK3) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK3) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK3) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK3) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK3) // iterate again if i != 0. + + label(.DPOSTACCUM_BLOCK3) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. 
+ je(.DBETAZERO_BLOCK3) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vextractf128(imm(1), ymm4, xmm1 ) + vmovupd(xmm1, mem(rcx, 2*8)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vextractf128(imm(1), ymm5, xmm1 ) + vmovhpd(xmm1, mem(rcx, 3*8)) + add(rdi, rcx) + + // vfmadd231pd(mem(rcx), ymm3, ymm6) + // vmovupd(ymm6, mem(rcx)) + + + jmp(.DDONE_BLOCK3) // jump to end. + + label(.DBETAZERO_BLOCK3) + + + vextractf128(imm(1), ymm4, xmm1 ) + vmovupd(xmm1, mem(rcx, 2*8)) + add(rdi, rcx) + + vextractf128(imm(1), ymm5, xmm1 ) + vmovhpd(xmm1, mem(rcx, 3*8)) + add(rdi, rcx) + + + label(.DDONE_BLOCK3) + + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); +} + +/* + +Following kernel computes the 6x8 block for the Upper vairant(U) of gemmt where +m_offset in 24x24 block is 6 and n_offset is 8(6x8) +(6x8)_U + +the region marked with 'x' is computed by following kernel +the region marked with '-' is not computed + + <-- n_off_24 -- > + 8 9 10 11 12 13 14 15 + +↑ 6 x x x x x x x x +| 7 x x x x x x x x +m 8 x x x x x x x x +off 9 - x x x x x x x +24 10 - - x x x x x x +| 11 - - - x x x x x +↓ + + +*/ +void bli_dgemmsup_rd_haswell_asm_6x8m_6x8_U + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t m_iter = m0 / 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + begin_asm() + + mov(var(rs_a), r8) // load rs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + mov(var(cs_b), r11) // load cs_b + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(imm(0), r15) // jj = 0; + +// ----------------------- Block 1 + + mov(var(a), r14) // load address of a + mov(var(b), rdx) // load address of b + mov(var(c), r12) // load address of c + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + mov(var(m_iter), r9) // ii = m_iter; + + vxorpd(ymm4, ymm4, 
ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK1) // if i == 0, jump to code that + // contains the k_iter4 loop. + + label(.DLOOPKITER16_BLOCK1) // MAIN LOOP + + // ---------------------------------- iteration 0 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + // ---------------------------------- iteration 1 + SUBITER_K4_3x4(rax, rbx) + // ---------------------------------- iteration 2 + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + // ---------------------------------- iteration 3 + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK1) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK1) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK1) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK1) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK1) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + SUBITER_K1_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK1) // iterate again if i != 0. 
+ + label(.DPOSTACCUM_BLOCK1) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK1) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + + jmp(.DDONE_BLOCK1) // jump to end. + + label(.DBETAZERO_BLOCK1) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + + + label(.DDONE_BLOCK1) + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + +// ----------------------- Block 2 + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK2) // if i == 0, jump to code that + // contains the k_iter4 loop. 
+ + label(.DLOOPKITER16_BLOCK2) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 1 + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 3 + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK2) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK2) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK2) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK2) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK2) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK2) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK2) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK2) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + SUBITER_K1_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK2) // iterate again if i != 0. + + label(.DPOSTACCUM_BLOCK2) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. 
+ vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK2) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vextractf128(imm(1), ymm4, xmm1 ) + vmovhpd(xmm4, mem(rcx, 1*8)) + vmovupd(xmm1, mem(rcx, 2*8)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vextractf128(imm(1), ymm5, xmm1 ) + vmovupd(xmm1, mem(rcx, 2*8)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vextractf128(imm(1), ymm6, xmm1 ) + vmovhpd(xmm1, mem(rcx, 3*8)) + + + jmp(.DDONE_BLOCK2) // jump to end. + + label(.DBETAZERO_BLOCK2) + + + vextractf128(imm(1), ymm4, xmm1 ) + vmovhpd(xmm4, mem(rcx, 1*8)) + vmovupd(xmm1, mem(rcx, 2*8)) + add(rdi, rcx) + + vextractf128(imm(1), ymm5, xmm1 ) + vmovupd(xmm1, mem(rcx, 2*8)) + add(rdi, rcx) + + vextractf128(imm(1), ymm6, xmm1 ) + vmovhpd(xmm1, mem(rcx, 3*8)) + + + label(.DDONE_BLOCK2) + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + + add(imm(4), r15) // jj += 4; + +// ----------------------- Block 3 + mov(var(a), r14) // load address of a + mov(var(b), rdx) // load address of b + mov(var(c), r12) // load address of c + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + mov(var(m_iter), r9) // ii = m_iter; + + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK3) // if i == 0, jump to code that + // contains the k_iter4 loop. + + label(.DLOOPKITER16_BLOCK3) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 1 + + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 3 + + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK3) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK3) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK3) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. 
+ + label(.DLOOPKITER4_BLOCK3) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK3) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK3) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK3) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK3) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + SUBITER_K1_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK3) // iterate again if i != 0. + + label(.DPOSTACCUM_BLOCK3) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK3) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + + jmp(.DDONE_BLOCK3) // jump to end. 
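+
+    /*
+     A scalar sketch (illustrative only, not part of the kernel) of what the
+     .DPOSTACCUM_BLOCK3 reduction above computes: accumulator ymm(4+i+3*j)
+     holds four partial sums of the dot product for element (i,j) of the
+     3x4 tile, and the vhaddpd/vextractf128/vperm2f128 sequence folds each
+     of them into a single value:
+
+       double hsum4( const double v[4] )             // hypothetical helper
+       { return ( v[0] + v[1] ) + ( v[2] + v[3] ); }
+
+       ab[i][j] = hsum4( acc[i][j] );                // i < 3, j < 4
+
+     ymm4, ymm5 and ymm6 then each hold one reduced row of the tile, ready
+     for the alpha/beta update.
+    */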
+ + label(.DBETAZERO_BLOCK3) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + + + label(.DDONE_BLOCK3) + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + +// ----------------------- Block 4 + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + label(.DLOOPKITER16_BLOCK4) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 1 + + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 3 + + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK4) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK4) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK4) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK4) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK4) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK4) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK4) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + SUBITER_K1_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK4) // iterate again if i != 0. 
+ + label(.DPOSTACCUM_BLOCK4) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK4) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + + jmp(.DDONE_BLOCK4) // jump to end. 
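+
+    /*
+     The two store paths around this point are, in scalar form (sketch only):
+
+       if ( beta != 0.0 ) C(i,j) = beta * C(i,j) + alpha * ab(i,j);
+       else               C(i,j) =                 alpha * ab(i,j);
+
+     vucomisd above sets ZF when beta == 0.0, so je(.DBETAZERO_BLOCK4)
+     selects the path below, which stores alpha*ab without ever reading C.
+    */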
+ + label(.DBETAZERO_BLOCK4) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + + + label(.DDONE_BLOCK4) + + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); +} + +/* + +Following kernel computes the 6x8 block for the Upper vairant(U) of gemmt where +m_offset in 24x24 block is 12 and n_offset is 8(12x8) +(12x8)_U + +the region marked with 'x' is computed by following kernel +the region marked with '-' is not computed + + <-- n_off_24 -- > + 8 9 10 11 12 13 14 15 + +↑ 12 - - - - x x x x +| 13 - - - - - x x x +m 14 - - - - - - x x +off 15 - - - - - - - x +24 16 - - - - - - - - +| 17 - - - - - - - - +↓ + + +*/ +void bli_dgemmsup_rd_haswell_asm_6x8m_12x8_U + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t m_iter = m0 / 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + begin_asm() + + mov(var(rs_a), r8) // load rs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + mov(var(cs_b), r11) // load cs_b + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(imm(4), r15) // jj = 0; + +// ----------------------- Block 3 + + mov(var(a), r14) // load address of a + mov(var(b), rdx) // load address of b + mov(var(c), r12) // load address of c + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + mov(var(m_iter), r9) // ii = m_iter; + + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 
3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK3) // if i == 0, jump to code that + // contains the k_iter4 loop. + + label(.DLOOPKITER16_BLOCK3) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 1 + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 3 + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK3) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK3) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK3) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK3) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK3) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK3) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK3) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK3) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + SUBITER_K1_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK3) // iterate again if i != 0. 
+ + label(.DPOSTACCUM_BLOCK3) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK3) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vextractf128(imm(1), ymm5, xmm1 ) + vmovhpd(xmm5, mem(rcx, 1*8)) + vmovupd(xmm1, mem(rcx ,2*8)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vextractf128(imm(1), ymm6, xmm1 ) + vmovupd(xmm1, mem(rcx, 2*8)) + + + jmp(.DDONE_BLOCK3) // jump to end. + + label(.DBETAZERO_BLOCK3) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vextractf128(imm(1), ymm5, xmm1 ) + vmovhpd(xmm5, mem(rcx, 1*8)) + vmovupd(xmm1, mem(rcx ,2*8)) + + add(rdi, rcx) + + vextractf128(imm(1), ymm6, xmm1 ) + vmovupd(xmm1, mem(rcx, 2*8)) + + + + label(.DDONE_BLOCK3) + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + +// ----------------------- Block 4 + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK4) // if i == 0, jump to code that + // contains the k_iter4 loop. 
+ + label(.DLOOPKITER16_BLOCK4) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_1x4(rax, rbx) + + // ---------------------------------- iteration 1 + SUBITER_K4_1x4(rax, rbx) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_1x4(rax, rbx) + + // ---------------------------------- iteration 3 + SUBITER_K4_1x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK4) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK4) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK4) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_1x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK4) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK4) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK4) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK4) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + SUBITER_K1_1x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK4) // iterate again if i != 0. + + label(.DPOSTACCUM_BLOCK4) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK4) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vextractf128(imm(1), ymm4, xmm1 ) + vmovhpd(xmm1, mem(rcx ,3*8)) + + + jmp(.DDONE_BLOCK4) // jump to end. 
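+
+    /*
+     The masked stores in this kernel (a full row, then 3, 2 and finally 1
+     trailing element) realize the (12x8)_U region shown in the block
+     diagram above.  The same update in scalar form (ii, jj are local tile
+     indices; dot() is a hypothetical helper used only for illustration):
+
+       // m_off = 12, n_off = 8 locate this 6x8 tile in the 24x24 block
+       for ( ii = 0; ii < 6; ++ii )
+         for ( jj = 0; jj < 8; ++jj )
+           if ( n_off + jj >= m_off + ii )   // upper variant: keep j >= i
+             C(ii,jj) = beta * C(ii,jj)
+                      + alpha * dot( row ii of A, column jj of B, k0 );
+    */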
+ + label(.DBETAZERO_BLOCK4) + + + vextractf128(imm(1), ymm4, xmm1 ) + vmovhpd(xmm1, mem(rcx ,3*8)) + + + label(.DDONE_BLOCK4) + + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); +} + +/* + +Following kernel computes the 6x8 block for the Upper vairant(U) of gemmt where +m_offset in 24x24 block is 12 and n_offset is 16(12x16) +(12x16)_U + + +the region marked with 'x' is computed by following kernel +the region marked with '-' is not computed + + <-- n_off_24 -- > + 16 17 18 19 20 21 22 23 + +↑ 12 x x x x x x x x +| 13 x x x x x x x x +m 14 x x x x x x x x +off 15 x x x x x x x x +24 16 x x x x x x x x +| 17 - x x x x x x x +↓ + + +*/ +void bli_dgemmsup_rd_haswell_asm_6x8m_12x16_U + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t m_iter = m0 / 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + begin_asm() + + mov(var(rs_a), r8) // load rs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + mov(var(cs_b), r11) // load cs_b + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(imm(0), r15) // jj = 0; + +// ----------------------- Block 1 + + mov(var(a), r14) // load address of a + mov(var(b), rdx) // load address of b + mov(var(c), r12) // load address of c + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + mov(var(m_iter), r9) // ii = m_iter; + + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + 
prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK1) // if i == 0, jump to code that + // contains the k_iter4 loop. + + label(.DLOOPKITER16_BLOCK1) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 1 + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 3 + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK1) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK1) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK1) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK1) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK1) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + SUBITER_K1_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK1) // iterate again if i != 0. 
+ + label(.DPOSTACCUM_BLOCK1) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK1) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + + jmp(.DDONE_BLOCK1) // jump to end. + + label(.DBETAZERO_BLOCK1) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + + + label(.DDONE_BLOCK1) + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + +// ----------------------- Block 2 + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK2) // if i == 0, jump to code that + // contains the k_iter4 loop. 
+ + label(.DLOOPKITER16_BLOCK2) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 1 + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 3 + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK2) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK2) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK2) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK2) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK2) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK2) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK2) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK2) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + SUBITER_K1_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK2) // iterate again if i != 0. + + label(.DPOSTACCUM_BLOCK2) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. 
+ vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK2) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vextractf128(imm(1), ymm6, xmm1 ) + vmovhpd(xmm6, mem(rcx, 1*8)) + vmovupd(xmm1, mem(rcx, 2*8)) + + + jmp(.DDONE_BLOCK2) // jump to end. + + label(.DBETAZERO_BLOCK2) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vextractf128(imm(1), ymm6, xmm1 ) + vmovhpd(xmm6, mem(rcx, 1*8)) + vmovupd(xmm1, mem(rcx, 2*8)) + + + label(.DDONE_BLOCK2) + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + + add(imm(4), r15) // jj += 4; + +// ----------------------- Block 3 + mov(var(a), r14) // load address of a + mov(var(b), rdx) // load address of b + mov(var(c), r12) // load address of c + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + mov(var(m_iter), r9) // ii = m_iter; + + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK3) // if i == 0, jump to code that + // contains the k_iter4 loop. + + label(.DLOOPKITER16_BLOCK3) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 1 + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 3 + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK3) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK3) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK3) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. 
+ + label(.DLOOPKITER4_BLOCK3) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK3) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK3) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK3) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK3) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + SUBITER_K1_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK3) // iterate again if i != 0. + + label(.DPOSTACCUM_BLOCK3) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK3) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + + jmp(.DDONE_BLOCK3) // jump to end. 
+ + label(.DBETAZERO_BLOCK3) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + + + label(.DDONE_BLOCK3) + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + +// ----------------------- Block 4 + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + label(.DLOOPKITER16_BLOCK4) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 1 + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 3 + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK4) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK4) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK4) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK4) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK4) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK4) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK4) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + SUBITER_K1_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK4) // iterate again if i != 0. 
+ + label(.DPOSTACCUM_BLOCK4) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK4) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + + jmp(.DDONE_BLOCK4) // jump to end. 
+
+ label(.DBETAZERO_BLOCK4)
+
+
+ vmovupd(ymm4, mem(rcx))
+ add(rdi, rcx)
+
+ vmovupd(ymm5, mem(rcx))
+ add(rdi, rcx)
+
+ vmovupd(ymm6, mem(rcx))
+
+
+ label(.DDONE_BLOCK4)
+
+ vzeroupper()
+
+ end_asm(
+ : // output operands (none)
+ : // input operands
+ [m_iter] "m" (m_iter),
+ [k_iter16] "m" (k_iter16),
+ [k_iter4] "m" (k_iter4),
+ [k_left1] "m" (k_left1),
+ [a] "m" (a),
+ [rs_a] "m" (rs_a),
+ [cs_a] "m" (cs_a),
+ [b] "m" (b),
+ [rs_b] "m" (rs_b),
+ [cs_b] "m" (cs_b),
+ [alpha] "m" (alpha),
+ [beta] "m" (beta),
+ [c] "m" (c),
+ [rs_c] "m" (rs_c),
+ [cs_c] "m" (cs_c)/*,
+ [a_next] "m" (a_next),
+ [b_next] "m" (b_next)*/
+ : // register clobber list
+ "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp",
+ "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
+ "xmm0", "xmm1", "xmm2", "xmm3",
+ "xmm4", "xmm5", "xmm6", "xmm7",
+ "xmm8", "xmm9", "xmm10", "xmm11",
+ "xmm12", "xmm13", "xmm14", "xmm15",
+ "memory"
+ )
+
+ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7);
+}
+
+/*
+
+The following kernel computes the 6x8 block for the Upper variant(U) of gemmt where
+m_offset in 24x24 block is 18 and n_offset is 16(18x16)
+(18x16)_U
+
+the region marked with 'x' is computed by the following kernel
+the region marked with '-' is not computed
+
+ <-- n_off_24 -- >
+ 16 17 18 19 20 21 22 23
+
+↑ 18 - - x x x x x x
+| 19 - - - x x x x x
+m 20 - - - - x x x x
+off 21 - - - - - x x x
+24 22 - - - - - - x x
+| 23 - - - - - - - x
+↓
+
+
+*/
+void bli_dgemmsup_rd_haswell_asm_6x8m_18x16_U
+ (
+ conj_t conja,
+ conj_t conjb,
+ dim_t m0,
+ dim_t n0,
+ dim_t k0,
+ double* restrict alpha,
+ double* restrict a, inc_t rs_a0, inc_t cs_a0,
+ double* restrict b, inc_t rs_b0, inc_t cs_b0,
+ double* restrict beta,
+ double* restrict c, inc_t rs_c0, inc_t cs_c0,
+ auxinfo_t* restrict data,
+ cntx_t* restrict cntx
+ )
+{
+ AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7);
+
+ uint64_t k_iter16 = k0 / 16;
+ uint64_t k_left16 = k0 % 16;
+ uint64_t k_iter4 = k_left16 / 4;
+ uint64_t k_left1 = k_left16 % 4;
+
+ uint64_t m_iter = m0 / 3;
+
+ uint64_t rs_a = rs_a0;
+ uint64_t cs_a = cs_a0;
+ uint64_t rs_b = rs_b0;
+ uint64_t cs_b = cs_b0;
+ uint64_t rs_c = rs_c0;
+ uint64_t cs_c = cs_c0;
+
+ begin_asm()
+
+ mov(var(rs_a), r8) // load rs_a
+ lea(mem(, r8, 8), r8) // rs_a *= sizeof(double)
+ mov(var(cs_b), r11) // load cs_b
+ lea(mem(, r11, 8), r11) // cs_b *= sizeof(double)
+
+ lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b
+ lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a
+
+ mov(var(rs_c), rdi) // load rs_c
+ lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
+
+ mov(imm(0), r15) // jj = 0;
+
+// ----------------------- Block 1
+ mov(var(a), r14) // load address of a
+ mov(var(b), rdx) // load address of b
+ mov(var(c), r12) // load address of c
+
+ lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj;
+ imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8
+ lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c;
+
+ lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj;
+ imul(r11, rsi) // rsi *= cs_b;
+ lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b;
+
+ mov(var(m_iter), r9) // ii = m_iter;
+
+ vxorpd(ymm4, ymm4, ymm4)
+ vmovapd( ymm4, ymm5)
+ vmovapd( ymm4, ymm6)
+ vmovapd( ymm4, ymm7)
+ vmovapd( ymm4, ymm8)
+ vmovapd( ymm4, ymm9)
+ vmovapd( ymm4, ymm10)
+ vmovapd( ymm4, ymm11)
+ vmovapd( ymm4, ymm12)
+ vmovapd( ymm4, ymm13)
+ vmovapd( ymm4, ymm14)
+ vmovapd( ymm4, ymm15)
+
+ lea(mem(r12), rcx) // rcx = c_iijj;
+ lea(mem(r14), rax) // rax = a_ii;
+ lea(mem(rdx), rbx) // rbx = b_jj;
+
+
+ prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c
+ prefetch(0, mem(rcx, rdi, 1, 
3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK1) // if i == 0, jump to code that + // contains the k_iter4 loop. + + label(.DLOOPKITER16_BLOCK1) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_2x4(rax, rbx) + + // ---------------------------------- iteration 1 + SUBITER_K4_2x4(rax, rbx) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_2x4(rax, rbx) + + // ---------------------------------- iteration 3 + SUBITER_K4_2x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK1) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK1) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK1) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_2x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK1) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK1) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + SUBITER_K1_2x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK1) // iterate again if i != 0. 
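+
+    /*
+     The three loops above split the k dimension exactly as computed in the
+     C preamble of this function:
+
+       k_iter16 = k0 / 16;        // main loop: 4 x SUBITER_K4 = 16 k-steps
+       k_iter4  = (k0 % 16) / 4;  // edge loop: 1 x SUBITER_K4 =  4 k-steps
+       k_left1  = (k0 % 16) % 4;  // scalar remainder: SUBITER_K1
+
+     e.g. k0 = 23 gives k_iter16 = 1, k_iter4 = 1, k_left1 = 3, so
+     16 + 4 + 3 = 23 k-steps of the dot products are accumulated.
+    */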
+ + label(.DPOSTACCUM_BLOCK1) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK1) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vextractf128(imm(1), ymm4, xmm1 ) + vmovupd(xmm1, mem(rcx, 2*8)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vextractf128(imm(1), ymm5, xmm1 ) + vmovhpd(xmm1, mem(rcx, 3*8)) + + // vfmadd231pd(mem(rcx), ymm3, ymm6) + // vmovupd(ymm6, mem(rcx)) + + + jmp(.DDONE_BLOCK1) // jump to end. + + label(.DBETAZERO_BLOCK1) + + + vextractf128(imm(1), ymm4, xmm1 ) + vmovupd(xmm1, mem(rcx, 2*8)) + add(rdi, rcx) + + vextractf128(imm(1), ymm5, xmm1 ) + vmovhpd(xmm1, mem(rcx, 3*8)) + + + label(.DDONE_BLOCK1) +// ----------------------- Block 2 + add(imm(4), r15) // jj += 4; + +// ----------------------- Block 3 + mov(var(a), r14) // load address of a + mov(var(b), rdx) // load address of b + mov(var(c), r12) // load address of c + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + mov(var(m_iter), r9) // ii = m_iter; + + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK3) // if i == 0, jump to code that + // contains the k_iter4 loop. 
+ + label(.DLOOPKITER16_BLOCK3) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 1 + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 3 + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK3) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK3) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK3) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK3) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK3) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK3) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK3) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK3) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + SUBITER_K1_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK3) // iterate again if i != 0. + + label(.DPOSTACCUM_BLOCK3) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. 
+ vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK3) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + + jmp(.DDONE_BLOCK3) // jump to end. + + label(.DBETAZERO_BLOCK3) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + + + label(.DDONE_BLOCK3) + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + +// ----------------------- Block 4 + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + label(.DLOOPKITER16_BLOCK4) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 1 + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 3 + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK4) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK4) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK4) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK4) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK4) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK4) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK4) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. 
+ SUBITER_K1_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK4) // iterate again if i != 0. + + label(.DPOSTACCUM_BLOCK4) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK4) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vextractf128(imm(1), ymm4, xmm1 ) + vmovhpd(xmm4, mem(rcx, 1*8)) + vmovupd(xmm1, mem(rcx, 2*8)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vextractf128(imm(1), ymm5, xmm1 ) + vmovupd(xmm1, mem(rcx, 2*8)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vextractf128(imm(1), ymm6, xmm1 ) + vmovhpd(xmm1, mem(rcx, 3*8)) + + + jmp(.DDONE_BLOCK4) // jump to end. 
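/* The store path above writes only the part of each 4-wide C row that lies on
   or below the diagonal, by extracting the upper 128-bit half and using
   vmovhpd / vmovupd (and vmovlpd in the other kernels) for 1-, 2- or 3-element
   tails. A sketch of the same idea in intrinsics; store_row_tail_sketch() and
   its j0 argument (index of the first element to store) are illustrative, not
   BLIS code. */

#include <immintrin.h>

static void store_row_tail_sketch( double* c, __m256d row, int j0 )
{
    __m128d lo = _mm256_castpd256_pd128( row );    // elements 0,1
    __m128d hi = _mm256_extractf128_pd( row, 1 );  // elements 2,3

    switch ( j0 )
    {
        case 0: _mm256_storeu_pd( c, row );      break;  // full row
        case 1: _mm_storeh_pd( c + 1, lo );              // element 1 only
                _mm_storeu_pd( c + 2, hi );      break;  // elements 2,3
        case 2: _mm_storeu_pd( c + 2, hi );      break;  // elements 2,3
        case 3: _mm_storeh_pd( c + 3, hi );      break;  // element 3 only
        default:                                 break;  // nothing to store
    }
}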
+ + label(.DBETAZERO_BLOCK4) + + + vextractf128(imm(1), ymm4, xmm1 ) + vmovhpd(xmm4, mem(rcx, 1*8)) + vmovupd(xmm1, mem(rcx, 2*8)) + add(rdi, rcx) + + vextractf128(imm(1), ymm5, xmm1 ) + vmovupd(xmm1, mem(rcx, 2*8)) + add(rdi, rcx) + + vextractf128(imm(1), ymm6, xmm1 ) + vmovhpd(xmm1, mem(rcx, 3*8)) + + + label(.DDONE_BLOCK4) + + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); +} + +/* + +Following kernel computes the 6x8 block for the Lower vairant(L) of gemmt where +m_offset in 24x24 block is 0 and n_offset is 0(0x0) +(0x0)_L + +the region marked with 'x' is computed by following kernel +the region marked with '-' is not computed + + <-- n_off_24 -- > + 0 1 2 3 4 5 6 7 + +↑ 0 x - - - - - - - +| 1 x x - - - - - - +m 2 x x x - - - - - +off 3 x x x x - - - - +24 4 x x x x x - - - +| 5 x x x x x x - - +↓ + + +*/ +void bli_dgemmsup_rd_haswell_asm_6x8m_0x0_L + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t m_iter = m0 / 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + begin_asm() + + mov(var(rs_a), r8) // load rs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + mov(var(cs_b), r11) // load cs_b + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(imm(0), r15) // jj = 0; + +// ----------------------- Block 1 + + mov(var(a), r14) // load address of a + mov(var(b), rdx) // load address of b + mov(var(c), r12) // load address of c + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + mov(var(m_iter), r9) // ii = m_iter; + + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // 
rax = a_ii;
+ lea(mem(rdx), rbx) // rbx = b_jj;
+
+
+ prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c
+ prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c
+ prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c
+ lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a
+
+ mov(var(k_iter16), rsi) // i = k_iter16;
+ test(rsi, rsi) // check i via logical AND.
+ je(.DCONSIDKITER4_BLOCK1) // if i == 0, jump to code that
+ // contains the k_iter4 loop.
+
+ label(.DLOOPKITER16_BLOCK1) // MAIN LOOP
+
+ // ---------------------------------- iteration 0
+
+ prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a
+ prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a
+ prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a
+ SUBITER_K4_3x4(rax, rbx)
+
+ // ---------------------------------- iteration 1
+ SUBITER_K4_3x4(rax, rbx)
+
+ // ---------------------------------- iteration 2
+
+ prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a
+ prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a
+ prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a
+ SUBITER_K4_3x4(rax, rbx)
+
+ // ---------------------------------- iteration 3
+ SUBITER_K4_3x4(rax, rbx)
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKITER16_BLOCK1) // iterate again if i != 0.
+
+ label(.DCONSIDKITER4_BLOCK1)
+
+ mov(var(k_iter4), rsi) // i = k_iter4;
+ test(rsi, rsi) // check i via logical AND.
+ je(.DCONSIDKLEFT1_BLOCK1) // if i == 0, jump to code that
+ // considers k_left1 loop.
+ // else, we prepare to enter k_iter4 loop.
+
+ label(.DLOOPKITER4_BLOCK1) // EDGE LOOP (ymm)
+
+ prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a
+ prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a
+ prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a
+ SUBITER_K4_3x4(rax, rbx)
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKITER4_BLOCK1) // iterate again if i != 0.
+
+ label(.DCONSIDKLEFT1_BLOCK1)
+
+ mov(var(k_left1), rsi) // i = k_left1;
+ test(rsi, rsi) // check i via logical AND.
+ je(.DPOSTACCUM_BLOCK1) // if i == 0, we're done; jump to end.
+ // else, we prepare to enter k_left1 loop.
+
+ label(.DLOOPKLEFT1_BLOCK1) // EDGE LOOP (scalar)
+ // NOTE: We must use ymm registers here because
+ // using the xmm registers would zero out the
+ // high bits of the destination registers,
+ // which would destroy intermediate results.
+
+ SUBITER_K1_3x4(rax, rbx)
+
+ dec(rsi) // i -= 1;
+ jne(.DLOOPKLEFT1_BLOCK1) // iterate again if i != 0.
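/* The 'x'/'-' maps in these kernel headers all follow one rule: for the lower
   (L) variant of gemmt, element (i,j) of the 6x8 tile is updated only when its
   global row index m_off_24 + i is greater than or equal to its global column
   index n_off_24 + j. A scalar reference sketch of that rule, assuming the same
   row/column strides as the kernels above; dgemmt_l_tile_ref() is an
   illustrative reference, not one of the kernels in this file. */

static void dgemmt_l_tile_ref
     (
       dim_t m_off, dim_t n_off, dim_t k, double alpha,
       const double* a, inc_t rs_a, inc_t cs_a,
       const double* b, inc_t rs_b, inc_t cs_b,
       double beta, double* c, inc_t rs_c, inc_t cs_c
     )
{
    for ( dim_t i = 0; i < 6; ++i )
        for ( dim_t j = 0; j < 8; ++j )
        {
            if ( m_off + i < n_off + j ) continue;   // strictly above the diagonal

            double ab = 0.0;
            for ( dim_t l = 0; l < k; ++l )
                ab += a[ i*rs_a + l*cs_a ] * b[ l*rs_b + j*cs_b ];

            double* cij = c + i*rs_c + j*cs_c;
            *cij = ( beta == 0.0 ) ? alpha * ab : beta * (*cij) + alpha * ab;
        }
}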
+ + label(.DPOSTACCUM_BLOCK1) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK1) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovlpd(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(xmm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(xmm6, mem(rcx)) + vextractf128(imm(1), ymm6, xmm1 ) + vmovlpd(xmm1, mem(rcx, 2*8)) + + + jmp(.DDONE_BLOCK1) // jump to end. + + label(.DBETAZERO_BLOCK1) + + + vmovlpd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm6, mem(rcx)) + vextractf128(imm(1), ymm6, xmm1 ) + vmovlpd(xmm1, mem(rcx, 2*8)) + + + label(.DDONE_BLOCK1) + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + +// ----------------------- Block 2 + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK2) // if i == 0, jump to code that + // contains the k_iter4 loop. 
+ + label(.DLOOPKITER16_BLOCK2) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 1 + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 3 + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK2) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK2) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK2) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK2) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK2) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK2) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK2) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK2) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + SUBITER_K1_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK2) // iterate again if i != 0. + + label(.DPOSTACCUM_BLOCK2) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. 
+ vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK2) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + + jmp(.DDONE_BLOCK2) // jump to end. + + label(.DBETAZERO_BLOCK2) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + + + label(.DDONE_BLOCK2) + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + + add(imm(4), r15) // jj += 4; + +// ----------------------- Block 3 + mov(var(a), r14) // load address of a + mov(var(b), rdx) // load address of b + mov(var(c), r12) // load address of c + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + mov(var(m_iter), r9) // ii = m_iter; + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + +// ----------------------- Block 4 + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK4) // if i == 0, jump to code that + // contains the k_iter4 loop. 
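/* Block 3 above only advances the C and A pointers: for the (0x0)_L tile, rows
   0..2 against columns 4..7 sit entirely above the diagonal, so that 3x4 panel
   is skipped while the 3-row stepping (c_ii += 3*rs_c, a_ii += 3*rs_a) is kept.
   A sketch of that panel walk; panel_walk_sketch() and its compute placeholder
   are illustrative assumptions, not BLIS code. */

static void panel_walk_sketch
     (
       dim_t m0, dim_t m_off, dim_t n_off,   // n_off: first column of this 4-wide slab
       double* c, inc_t rs_c, const double* a, inc_t rs_a
     )
{
    dim_t m_iter = m0 / 3;                   // three C rows per panel

    for ( dim_t ii = 0; ii < m_iter; ++ii )
    {
        dim_t row0 = m_off + 3*ii;           // first global row of this panel

        if ( row0 + 2 >= n_off )             // panel reaches the diagonal or below
        {
            // ... 3x4 dot-product block, as in Blocks 1, 2 and 4 above ...
        }

        c += 3 * rs_c;                       // c_ii += 3*rs_c
        a += 3 * rs_a;                       // a_ii += 3*rs_a
    }
}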
+ + label(.DLOOPKITER16_BLOCK4) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 1 + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 3 + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK4) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK4) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. 
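/* The main loop above inlines a reduced form of SUBITER_K4_3x4: panel row 0
   (global row 3) lies entirely above columns 4..7, so its ymm4/7/10/13
   accumulators are dropped and only the remaining two rows are accumulated
   against the four B columns. One 4-wide k-step of that 2x4 accumulation in
   intrinsics; rd_2x4_kstep_sketch() is an illustrative helper only. */

#include <immintrin.h>

static void rd_2x4_kstep_sketch
     (
       const double* a1, const double* a2,       // rows 1 and 2 of the A panel
       const double* const b[ 4 ],               // the four B columns
       __m256d acc[ 2 ][ 4 ]                     // partial sums, cf. ymm5..ymm15
     )
{
    __m256d av1 = _mm256_loadu_pd( a1 );         // A row 1, elements k..k+3
    __m256d av2 = _mm256_loadu_pd( a2 );         // A row 2, elements k..k+3

    for ( int j = 0; j < 4; ++j )
    {
        __m256d bv = _mm256_loadu_pd( b[ j ] );  // B column j, elements k..k+3
        acc[ 0 ][ j ] = _mm256_fmadd_pd( av1, bv, acc[ 0 ][ j ] );
        acc[ 1 ][ j ] = _mm256_fmadd_pd( av2, bv, acc[ 1 ][ j ] );
    }
}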
+ + label(.DLOOPKITER4_BLOCK4) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK4) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK4) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK4) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK4) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + vmovsd(mem(rax, r8, 1), xmm1) + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK4) // iterate again if i != 0. + + label(.DPOSTACCUM_BLOCK4) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK4) // if ZF = 1, jump to beta == 0 case + + + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovlpd(xmm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(xmm6, mem(rcx)) + + + jmp(.DDONE_BLOCK4) // jump to end. 
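/* The scalar k_left1 loop above loads single elements with vmovsd but keeps the
   FMA destinations as ymm registers, so the three partial sums already sitting
   in the upper lanes are preserved (the point of the NOTE). The intrinsics
   analogue is to widen the scalar into lane 0 of a zeroed 256-bit vector and
   issue a full-width FMA; kleft1_step_sketch() is illustrative only. */

#include <immintrin.h>

static __m256d kleft1_step_sketch( const double* a, const double* b, __m256d acc )
{
    __m256d av = _mm256_set_pd( 0.0, 0.0, 0.0, *a );   // { a_k, 0, 0, 0 }
    __m256d bv = _mm256_set_pd( 0.0, 0.0, 0.0, *b );   // { b_k, 0, 0, 0 }

    return _mm256_fmadd_pd( av, bv, acc );             // only lane 0 of acc changes
}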
+ + label(.DBETAZERO_BLOCK4) + + + add(rdi, rcx) + + vmovlpd(xmm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm6, mem(rcx)) + + + label(.DDONE_BLOCK4) + + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); +} + +/* + +Following kernel computes the 6x8 block for the Lower vairant(L) of gemmt where +m_offset in 24x24 block is 6 and n_offset is 0(6x0) +(6x0)_L + + +the region marked with 'x' is computed by following kernel +the region marked with '-' is not computed + + <-- n_off_24 -- > + 0 1 2 3 4 5 6 7 + +↑ 6 x x x x x x x - +| 7 x x x x x x x x +m 8 x x x x x x x x +off 9 x x x x x x x x +24 10 x x x x x x x x +| 11 x x x x x x x x +↓ + + +*/ +void bli_dgemmsup_rd_haswell_asm_6x8m_6x0_L + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t m_iter = m0 / 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + begin_asm() + + mov(var(rs_a), r8) // load rs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + mov(var(cs_b), r11) // load cs_b + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(imm(0), r15) // jj = 0; + +// ----------------------- Block 1 + + mov(var(a), r14) // load address of a + mov(var(b), rdx) // load address of b + mov(var(c), r12) // load address of c + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + mov(var(m_iter), r9) // ii = m_iter; + + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + 
prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK1) // if i == 0, jump to code that + // contains the k_iter4 loop. + + label(.DLOOPKITER16_BLOCK1) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 1 + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 3 + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK1) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK1) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK1) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK1) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK1) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + SUBITER_K1_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK1) // iterate again if i != 0. 
+ + label(.DPOSTACCUM_BLOCK1) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK1) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + + jmp(.DDONE_BLOCK1) // jump to end. + + label(.DBETAZERO_BLOCK1) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + + + label(.DDONE_BLOCK1) + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + +// ----------------------- Block 2 + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK2) // if i == 0, jump to code that + // contains the k_iter4 loop. 
+ + label(.DLOOPKITER16_BLOCK2) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 1 + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 3 + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK2) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK2) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK2) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK2) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK2) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK2) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK2) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK2) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + SUBITER_K1_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK2) // iterate again if i != 0. + + label(.DPOSTACCUM_BLOCK2) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. 
+ vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK2) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + + jmp(.DDONE_BLOCK2) // jump to end. + + label(.DBETAZERO_BLOCK2) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + + + label(.DDONE_BLOCK2) + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + + add(imm(4), r15) // jj += 4; + +// ----------------------- Block 3 + mov(var(a), r14) // load address of a + mov(var(b), rdx) // load address of b + mov(var(c), r12) // load address of c + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + mov(var(m_iter), r9) // ii = m_iter; + + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK3) // if i == 0, jump to code that + // contains the k_iter4 loop. + + label(.DLOOPKITER16_BLOCK3) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 1 + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 3 + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK3) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK3) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK3) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK3) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK3) // iterate again if i != 0. 
+ + label(.DCONSIDKLEFT1_BLOCK3) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK3) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK3) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + SUBITER_K1_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK3) // iterate again if i != 0. + + label(.DPOSTACCUM_BLOCK3) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK3) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vextractf128(imm(1), ymm4, xmm1 ) + vmovupd(xmm4, mem(rcx)) + vmovlpd(xmm1, mem(rcx, 2*8)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + + jmp(.DDONE_BLOCK3) // jump to end. 
+ + label(.DBETAZERO_BLOCK3) + + + vextractf128(imm(1), ymm4, xmm1 ) + vmovupd(xmm4, mem(rcx)) + vmovlpd(xmm1, mem(rcx, 2*8)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + + + label(.DDONE_BLOCK3) + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + +// ----------------------- Block 4 + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + label(.DLOOPKITER16_BLOCK4) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 1 + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 3 + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK4) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK4) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK4) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK4) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK4) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK4) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK4) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + SUBITER_K1_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK4) // iterate again if i != 0. 
+ + label(.DPOSTACCUM_BLOCK4) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK4) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + + jmp(.DDONE_BLOCK4) // jump to end. 
+ + label(.DBETAZERO_BLOCK4) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + + + label(.DDONE_BLOCK4) + + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); +} + +/* + +Following kernel computes the 6x8 block for the Lower vairant(L) of gemmt where +m_offset in 24x24 block is 6 and n_offset is 8(6x8) +(6x8)_L + +the region marked with 'x' is computed by following kernel +the region marked with '-' is not computed + + <-- n_off_24 -- > + 8 9 10 11 12 13 14 15 + +↑ 6 - - - - - - - - +| 7 - - - - - - - - +m 8 x - - - - - - - +off 9 x x - - - - - - +24 10 x x x - - - - - +| 11 x x x x - - - - +↓ + + +*/ +void bli_dgemmsup_rd_haswell_asm_6x8m_6x8_L + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t m_iter = m0 / 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + begin_asm() + + mov(var(rs_a), r8) // load rs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + mov(var(cs_b), r11) // load cs_b + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(imm(0), r15) // jj = 0; + +// ----------------------- Block 1 + + mov(var(a), r14) // load address of a + mov(var(b), rdx) // load address of b + mov(var(c), r12) // load address of c + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + mov(var(m_iter), r9) // ii = m_iter; + + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // 
prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK1) // if i == 0, jump to code that + // contains the k_iter4 loop. + + label(.DLOOPKITER16_BLOCK1) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 1 + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 3 + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK1) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK1) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK1) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK1) // iterate again if i != 0. 
+ + label(.DCONSIDKLEFT1_BLOCK1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK1) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK1) // iterate again if i != 0. + + label(.DPOSTACCUM_BLOCK1) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm6, ymm6) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK1) // if ZF = 1, jump to beta == 0 case + + + add(rdi, rcx) + add(rdi, rcx) + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovlpd(xmm6, mem(rcx)) + + + jmp(.DDONE_BLOCK1) // jump to end. + + label(.DBETAZERO_BLOCK1) + + add(rdi, rcx) + add(rdi, rcx) + vmovlpd(xmm6, mem(rcx)) + + + + label(.DDONE_BLOCK1) + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + +// ----------------------- Block 2 + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK2) // if i == 0, jump to code that + // contains the k_iter4 loop. 
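+	// Block 2 accumulates a full 3x4 tile of dot products: each ymm accumulator
+	// keeps four k-partial products for one C entry, and the lanes are only
+	// folded to a scalar in .DPOSTACCUM_BLOCK2.  Rough scalar equivalent of what
+	// one such block computes (illustrative only; acc/ii/jj are descriptive
+	// names for this 3x4 tile, not kernel symbols):
+	//
+	//   for (dim_t k = 0; k < k0; ++k)
+	//     for (dim_t i = 0; i < 3; ++i)
+	//       for (dim_t j = 0; j < 4; ++j)
+	//         acc[i][j] += a[(ii+i)*rs_a + k*cs_a] * b[k*rs_b + (jj+j)*cs_b];
+	//
+	// Only the entries of the tile on or below the diagonal are written back.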
+ + label(.DLOOPKITER16_BLOCK2) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 1 + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 3 + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK2) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK2) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK2) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK2) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK2) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK2) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK2) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK2) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + SUBITER_K1_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK2) // iterate again if i != 0. + + label(.DPOSTACCUM_BLOCK2) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. 
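+	// beta == 0 is special-cased below so that C is never read on that path;
+	// this skips a load and avoids propagating NaN/Inf from uninitialized C.
+	// Illustrative scalar form of the update (not kernel code):
+	//
+	//   c[i][j] = (beta == 0.0) ? alpha*acc[i][j]
+	//                           : alpha*acc[i][j] + beta*c[i][j];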
+ vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK2) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vextractf128(imm(1), ymm5, xmm1 ) + vmovupd(xmm5, mem(rcx)) + vmovlpd(xmm1, mem(rcx, 2*8)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + + jmp(.DDONE_BLOCK2) // jump to end. + + label(.DBETAZERO_BLOCK2) + + + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vextractf128(imm(1), ymm5, xmm1 ) + vmovupd(xmm5, mem(rcx)) + vmovlpd(xmm1, mem(rcx, 2*8)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + + + label(.DDONE_BLOCK2) + + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); +} + +/* + +Following kernel computes the 6x8 block for the Lower vairant(L) of gemmt where +m_offset in 24x24 block is 12 and n_offset is 8(12x8) +(12x8)_L + +the region marked with 'x' is computed by following kernel +the region marked with '-' is not computed + + <-- n_off_24 -- > + 8 9 10 11 12 13 14 15 + +↑ 12 x x x x x - - - +| 13 x x x x x x - - +m 14 x x x x x x x - +off 15 x x x x x x x x +24 16 x x x x x x x x +| 17 x x x x x x x x +↓ + + +*/ +void bli_dgemmsup_rd_haswell_asm_6x8m_12x8_L + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t m_iter = m0 / 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + begin_asm() + + mov(var(rs_a), r8) // load rs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + mov(var(cs_b), r11) // load cs_b + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(imm(0), r15) // jj = 0; + +// ----------------------- Block 1 + + mov(var(a), r14) // load address of a + mov(var(b), rdx) // load address of b + mov(var(c), r12) // load address of c + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + mov(var(m_iter), 
r9) // ii = m_iter; + + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK1) // if i == 0, jump to code that + // contains the k_iter4 loop. + + label(.DLOOPKITER16_BLOCK1) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 1 + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 3 + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK1) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK1) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK1) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK1) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK1) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + SUBITER_K1_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK1) // iterate again if i != 0. 
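+	// At this point every accumulator still holds four k-partial sums per C
+	// entry.  The post-accumulation code below folds them with
+	// vhaddpd/vextractf128/vaddpd and then packs one scalar per lane with
+	// vperm2f128.  The reduction of a single accumulator, written as an
+	// illustrative C helper (hypothetical, not part of this kernel):
+	//
+	//   static inline double hsum4(const double v[4])
+	//   { return (v[0] + v[1]) + (v[2] + v[3]); }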
+ + label(.DPOSTACCUM_BLOCK1) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK1) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + + jmp(.DDONE_BLOCK1) // jump to end. + + label(.DBETAZERO_BLOCK1) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + + label(.DDONE_BLOCK1) + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + +// ----------------------- Block 2 + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK2) // if i == 0, jump to code that + // contains the k_iter4 loop. 
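+	// SUBITER_K4_3x4 / SUBITER_K1_3x4 (defined earlier in this file) appear to
+	// expand to the same 3-row x 4-column fma sequence that is written out
+	// explicitly in the 6x8 kernel above, consuming 4 (resp. 1) k values per
+	// invocation; they are reused below to keep the per-block loops compact.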
+ + label(.DLOOPKITER16_BLOCK2) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 1 + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 3 + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK2) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK2) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK2) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK2) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK2) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK2) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK2) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK2) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + SUBITER_K1_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK2) // iterate again if i != 0. + + label(.DPOSTACCUM_BLOCK2) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. 
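+	// For this block (rows 15..17 of the 24x24 tile against columns 8..11) the
+	// whole 3x4 tile lies strictly below the diagonal, so both store paths below
+	// write full 4-element rows of C; partial-row stores are only needed in the
+	// blocks that touch the diagonal.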
+ vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK2) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + jmp(.DDONE_BLOCK2) // jump to end. + + label(.DBETAZERO_BLOCK2) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + + label(.DDONE_BLOCK2) + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + + add(imm(4), r15) // jj += 4; + +// ----------------------- Block 3 + mov(var(a), r14) // load address of a + mov(var(b), rdx) // load address of b + mov(var(c), r12) // load address of c + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + mov(var(m_iter), r9) // ii = m_iter; + + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK3) // if i == 0, jump to code that + // contains the k_iter4 loop. + + label(.DLOOPKITER16_BLOCK3) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 1 + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + // ---------------------------------- iteration 3 + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK3) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK3) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK3) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK3) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK3) // iterate again if i != 0. 
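+	// The "NOTE" in the scalar remainder loop below is about VEX zeroing: a
+	// VEX-encoded write to an xmm destination clears bits 128..255 of the same
+	// ymm register, which would wipe the three partial sums already held in the
+	// upper lanes of each accumulator.  The scalar loads (vmovsd, as in the
+	// expanded edge loops elsewhere in this file) leave the upper lanes of the
+	// loaded values zero, so the ymm fma only updates lane 0.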
+ + label(.DCONSIDKLEFT1_BLOCK3) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK3) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK3) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + SUBITER_K1_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK3) // iterate again if i != 0. + + label(.DPOSTACCUM_BLOCK3) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK3) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovlpd(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(xmm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vextractf128(imm(1), ymm6, xmm1 ) + vmovupd(xmm6, mem(rcx)) + vmovlpd(xmm1, mem(rcx, 2*8)) + + jmp(.DDONE_BLOCK3) // jump to end. 
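+	// Block 3 covers rows 12..14 against columns 12..15, which touches the
+	// diagonal, so the beta == 0 path below mirrors the stores above: 1 element
+	// for row 12 (vmovlpd), 2 for row 13 (vmovupd xmm), and 3 for row 14
+	// (vmovupd xmm plus vmovlpd of the extracted upper lane).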
+ + label(.DBETAZERO_BLOCK3) + + + vmovlpd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm5, mem(rcx)) + add(rdi, rcx) + + vextractf128(imm(1), ymm6, xmm1 ) + vmovupd(xmm6, mem(rcx)) + vmovlpd(xmm1, mem(rcx, 2*8)) + + label(.DDONE_BLOCK3) + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + +// ----------------------- Block 4 + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + label(.DLOOPKITER16_BLOCK4) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + SUBITER_K4_3x4(rax, rbx) + // ---------------------------------- iteration 1 + SUBITER_K4_3x4(rax, rbx) + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + SUBITER_K4_3x4(rax, rbx) + // ---------------------------------- iteration 3 + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK4) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK4) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK4) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + SUBITER_K4_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK4) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK4) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK4) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK4) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + SUBITER_K1_3x4(rax, rbx) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK4) // iterate again if i != 0. 
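+	// Block 4 (rows 15..17 against columns 12..15) lies entirely on or below the
+	// diagonal, so after the reduction below the kernel again stores full
+	// 4-element rows of C, as in Block 2.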
+ + label(.DPOSTACCUM_BLOCK4) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK4) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + + jmp(.DDONE_BLOCK4) // jump to end. 
+ + label(.DBETAZERO_BLOCK4) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + + + label(.DDONE_BLOCK4) + + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); +} + +/* + +Following kernel computes the 6x8 block for the Lower vairant(L) of gemmt where +m_offset in 24x24 block is 12 and n_offset is 16(12x16) +(12x16)_L + + +the region marked with 'x' is computed by following kernel +the region marked with '-' is not computed + + <-- n_off_24 -- > + 16 17 18 19 20 21 22 23 + +↑ 12 - - - - - - - - +| 13 - - - - - - - - +m 14 - - - - - - - - +off 15 - - - - - - - - +24 16 x - - - - - - - +| 17 x x - - - - - - +↓ + + +*/ +void bli_dgemmsup_rd_haswell_asm_6x8m_12x16_L + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t m_iter = m0 / 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + begin_asm() + + mov(var(rs_a), r8) // load rs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + mov(var(cs_b), r11) // load cs_b + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(imm(0), r15) // jj = 0; + +// ----------------------- Block 1 + + mov(var(a), r14) // load address of a + mov(var(b), rdx) // load address of b + mov(var(c), r12) // load address of c + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + mov(var(m_iter), r9) // ii = m_iter; + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + +// ----------------------- Block 2 + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + 
vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK2) // if i == 0, jump to code that + // contains the k_iter4 loop. + + label(.DLOOPKITER16_BLOCK2) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 3 + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK2) // iterate again if i != 0. 
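+	// In this block only rows 16 and 17 of the 24x24 tile carry lower-triangle
+	// entries for columns 16..23, so the loop above loads just two rows of A
+	// (rax + rs_a and rax + 2*rs_a) and accumulates into the ymm5/ymm8/ymm11/
+	// ymm14 and ymm6/ymm9/ymm12/ymm15 chains; the row-0 accumulators stay zero.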
+ + label(.DCONSIDKITER4_BLOCK2) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK2) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK2) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK2) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK2) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK2) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK2) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax, r8, 1), xmm1) + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK2) // iterate again if i != 0. + + label(.DPOSTACCUM_BLOCK2) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. 
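+	// Note: both store paths below begin with add(rdi, rcx) to step past the row
+	// of C at m-offset 15, which has no lower-triangle entries in this column
+	// panel; only one element of row 16 and two elements of row 17 are written.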
+ je(.DBETAZERO_BLOCK2) // if ZF = 1, jump to beta == 0 case + + + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovlpd(xmm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(xmm6, mem(rcx)) + + + jmp(.DDONE_BLOCK2) // jump to end. + + label(.DBETAZERO_BLOCK2) + + + add(rdi, rcx) + + vmovlpd(xmm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm6, mem(rcx)) + + label(.DDONE_BLOCK2) + + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); +} + +/* + +Following kernel computes the 6x8 block for the Lower vairant(L) of gemmt where +m_offset in 24x24 block is 12, n_offset is 16(12x16) and m_offset is 18, n_offset is 16 (18x16) +(16x12)+(16x18)_L + +the region marked with 'x' is computed by following kernel +the region marked with '-' is not computed + + <-- n_off_24 -- > + 16 17 18 19 20 21 22 23 + +↑ 12 - - - - - - - - +| 13 - - - - - - - - +m 14 - - - - - - - - +off 15 - - - - - - - - +24 16 x - - - - - - - +| 17 x x - - - - - - +↓ +↑ 18 x x x - - - - - +| 19 x x x x - - - - +m 20 x x x x x - - - +off 21 x x x x x x - - +24 22 x x x x x x x - +| 23 x x x x x x x x +↓ + + +*/ +void bli_dgemmsup_rd_haswell_asm_6x8m_16x12_combined_L + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t m_iter = m0 / 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + begin_asm() + + mov(var(rs_a), r8) // load rs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + mov(var(cs_b), r11) // load cs_b + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(imm(0), r15) // jj = 0; + +// ----------------------- Block 1 + + mov(var(a), r14) // load address of a + mov(var(b), rdx) // load address of b + mov(var(c), r12) // load address of c + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + mov(var(m_iter), r9) // ii = m_iter; + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, 
rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + +// ----------------------- Block 2 + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK2) // if i == 0, jump to code that + // contains the k_iter4 loop. + + label(.DLOOPKITER16_BLOCK2) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm1, ymm3, ymm8) + 
vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK2) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK2) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK2) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK2) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK2) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK2) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK2) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK2) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax, r8, 1), xmm1) + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK2) // iterate again if i != 0. 
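+	// This first asm region of the combined kernel reproduces the (12x16) case:
+	// only rows 16 and 17 contribute in columns 16..23.  The (18x16) part is
+	// handled by the second asm region that follows the
+	// "a += 6 * rs_a; c += 6 * rs_c;" pointer advance in the C code below.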
+ + label(.DPOSTACCUM_BLOCK2) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK2) // if ZF = 1, jump to beta == 0 case + + + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovlpd(xmm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(xmm6, mem(rcx)) + + jmp(.DDONE_BLOCK2) // jump to end. + + label(.DBETAZERO_BLOCK2) + + + add(rdi, rcx) + + vmovlpd(xmm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm6, mem(rcx)) + + label(.DDONE_BLOCK2) + + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + a += 6 * rs_a; + c += 6 * rs_c; + begin_asm() + + mov(var(rs_a), r8) // load rs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + mov(var(cs_b), r11) // load cs_b + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(imm(0), r15) // jj = 0; + +// ----------------------- Block 1 + + mov(var(a), r14) // load address of a + mov(var(b), rdx) // load address of b + mov(var(c), r12) // load address of c + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + mov(var(m_iter), r9) // ii = m_iter; + + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + 
vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK1) // if i == 0, jump to code that + // contains the k_iter4 loop. + + label(.DLOOPKITER16_BLOCK1) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + 
vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK1) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK1) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK1) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK1) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK1) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK1) // iterate again if i != 0. 
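+
+	// Each accumulator ymm4..ymm15 holds four k-partial products of one C
+	// element (a row of the A panel dotted with a column of the B panel).
+	// The post-accumulation below folds every accumulator to a scalar
+	// (vhaddpd pairs adjacent lanes, vextractf128/vaddpd adds the two
+	// 128-bit halves) and vperm2f128 packs the four sums of an output row
+	// into one register, so alpha/beta scaling and the C update stay vectorized.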
+ + label(.DPOSTACCUM_BLOCK1) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK1) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vextractf128(imm(1), ymm4, xmm1 ) + vmovupd(xmm4, mem(rcx)) + vmovlpd(xmm1, mem(rcx, 2*8)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + jmp(.DDONE_BLOCK1) // jump to end. + + label(.DBETAZERO_BLOCK1) + + + vextractf128(imm(1), ymm4, xmm1 ) + vmovupd(xmm4, mem(rcx)) + vmovlpd(xmm1, mem(rcx, 2*8)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + + label(.DDONE_BLOCK1) + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + +// ----------------------- Block 2 + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK2) // if i == 0, jump to code that + // contains the k_iter4 loop. 
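+
+	// Block 2 computes the next 3x4 micro-tile: r12 and r14 were advanced
+	// by 3*rs_c and 3*rs_a at the end of Block 1, while rdx still points
+	// at the same B panel for this jj.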
+ + label(.DLOOPKITER16_BLOCK2) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK2) // iterate again if i != 0. 
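+
+	// The k dimension is consumed in three stages: the loop above retires
+	// 16 k values per pass (four unrolled 4-wide iterations), the k_iter4
+	// loop below retires 4 at a time, and k_left1 finishes the remainder
+	// with scalar vmovsd loads.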
+ + label(.DCONSIDKITER4_BLOCK2) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK2) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK2) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK2) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK2) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK2) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK2) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK2) // iterate again if i != 0. 
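+
+	// After the same reduction as Block 1, this tile is written back at
+	// full 4-element width for all three rows (it lies entirely inside the
+	// region this gemmt kernel stores), whereas Block 1's first row was
+	// clipped to three elements.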
+ + label(.DPOSTACCUM_BLOCK2) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK2) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + jmp(.DDONE_BLOCK2) // jump to end. + + label(.DBETAZERO_BLOCK2) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + + label(.DDONE_BLOCK2) + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + + add(imm(4), r15) // jj += 4; + +// ----------------------- Block 3 + mov(var(a), r14) // load address of a + mov(var(b), rdx) // load address of b + mov(var(c), r12) // load address of c + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + mov(var(m_iter), r9) // ii = m_iter; + + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. 
+ je(.DCONSIDKITER4_BLOCK3) // if i == 0, jump to code that + // contains the k_iter4 loop. + + label(.DLOOPKITER16_BLOCK3) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK3) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK3) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK3) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. 
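+
+	// Block 3 accumulates only the sums fed by the third row of the A
+	// panel (ymm6/ymm9/ymm12/ymm15); the edge loops below load just that
+	// row (ymm2/xmm2), and a single C element is written back after the
+	// reduction.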
+ + label(.DLOOPKITER4_BLOCK3) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK3) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK3) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK3) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK3) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK3) // iterate again if i != 0. + + label(.DPOSTACCUM_BLOCK3) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm6, ymm6) + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK3) // if ZF = 1, jump to beta == 0 case + + + add(rdi, rcx) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovlpd(xmm6, mem(rcx)) + + jmp(.DDONE_BLOCK3) // jump to end. 
+ + label(.DBETAZERO_BLOCK3) + + + add(rdi, rcx) + add(rdi, rcx) + vmovlpd(xmm6, mem(rcx)) + + label(.DDONE_BLOCK3) + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + +// ----------------------- Block 4 + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + label(.DLOOPKITER16_BLOCK4) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + 
vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK4) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK4) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK4) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK4) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK4) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK4) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK4) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. 
+ + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK4) // iterate again if i != 0. + + label(.DPOSTACCUM_BLOCK4) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK4) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(xmm5, mem(rcx)) + vextractf128(imm(1), ymm5, xmm1 ) + vmovlpd(xmm1, mem(rcx, 2*8)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + jmp(.DDONE_BLOCK4) // jump to end. 
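+
+	// Block 4 straddles the edge of the computed region, so the writeback
+	// above is a staircase: two elements of the first row (xmm4), three of
+	// the second (xmm5 plus one lane extracted from ymm5), and the full
+	// third row (ymm6). The beta == 0 path below applies the same masking
+	// without reading C.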
+ + label(.DBETAZERO_BLOCK4) + + + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm5, mem(rcx)) + vextractf128(imm(1), ymm5, xmm1 ) + vmovlpd(xmm1, mem(rcx, 2*8)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + + label(.DDONE_BLOCK4) + + vzeroupper() + + end_asm( + : // output operands (none) + : // input operands + [m_iter] "m" (m_iter), + [k_iter16] "m" (k_iter16), + [k_iter4] "m" (k_iter4), + [k_left1] "m" (k_left1), + [a] "m" (a), + [rs_a] "m" (rs_a), + [cs_a] "m" (cs_a), + [b] "m" (b), + [rs_b] "m" (rs_b), + [cs_b] "m" (cs_b), + [alpha] "m" (alpha), + [beta] "m" (beta), + [c] "m" (c), + [rs_c] "m" (rs_c), + [cs_c] "m" (cs_c)/*, + [a_next] "m" (a_next), + [b_next] "m" (b_next)*/ + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", + "xmm0", "xmm1", "xmm2", "xmm3", + "xmm4", "xmm5", "xmm6", "xmm7", + "xmm8", "xmm9", "xmm10", "xmm11", + "xmm12", "xmm13", "xmm14", "xmm15", + "memory" + ) + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); +} + +/* + +Following kernel computes the 6x8 block for the Lower vairant(L) of gemmt where +m_offset in 24x24 block is 18 and n_offset is 16(18x16) +(18x16)_L + + +the region marked with 'x' is computed by following kernel +the region marked with '-' is not computed + + <-- n_off_24 -- > + 16 17 18 19 20 21 22 23 + +↑ 18 x x x - - - - - +| 19 x x x x - - - - +m 20 x x x x x - - - +off 21 x x x x x x - - +24 22 x x x x x x x - +| 23 x x x x x x x x +↓ + + +*/ +void bli_dgemmsup_rd_haswell_asm_6x8m_18x16_L + ( + conj_t conja, + conj_t conjb, + dim_t m0, + dim_t n0, + dim_t k0, + double* restrict alpha, + double* restrict a, inc_t rs_a0, inc_t cs_a0, + double* restrict b, inc_t rs_b0, inc_t cs_b0, + double* restrict beta, + double* restrict c, inc_t rs_c0, inc_t cs_c0, + auxinfo_t* restrict data, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + + uint64_t k_iter16 = k0 / 16; + uint64_t k_left16 = k0 % 16; + uint64_t k_iter4 = k_left16 / 4; + uint64_t k_left1 = k_left16 % 4; + + uint64_t m_iter = m0 / 3; + + uint64_t rs_a = rs_a0; + uint64_t cs_a = cs_a0; + uint64_t rs_b = rs_b0; + uint64_t cs_b = cs_b0; + uint64_t rs_c = rs_c0; + uint64_t cs_c = cs_c0; + + begin_asm() + + mov(var(rs_a), r8) // load rs_a + lea(mem(, r8, 8), r8) // rs_a *= sizeof(double) + mov(var(cs_b), r11) // load cs_b + lea(mem(, r11, 8), r11) // cs_b *= sizeof(double) + + lea(mem(r11, r11, 2), r13) // r13 = 3*cs_b + lea(mem(r8, r8, 2), r10) // r10 = 3*rs_a + + mov(var(rs_c), rdi) // load rs_c + lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + + mov(imm(0), r15) // jj = 0; + +// ----------------------- Block 1 + + mov(var(a), r14) // load address of a + mov(var(b), rdx) // load address of b + mov(var(c), r12) // load address of c + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + mov(var(m_iter), r9) // ii = m_iter; + + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + prefetch(0, 
mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK1) // if i == 0, jump to code that + // contains the k_iter4 loop. + + label(.DLOOPKITER16_BLOCK1) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + 
vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK1) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK1) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK1) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK1) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK1) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK1) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK1) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK1) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK1) // iterate again if i != 0. 
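+
+	// Per the diagram above, Block 1 of this kernel covers rows 18-20 and
+	// columns 16-19 of the 24x24 block. Row 18 meets the diagonal at
+	// column 18, so only three of its four reduced sums are stored below;
+	// rows 19 and 20 are stored in full.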
+ + label(.DPOSTACCUM_BLOCK1) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK1) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vextractf128(imm(1), ymm4, xmm1 ) + vmovupd(xmm4, mem(rcx)) + vmovlpd(xmm1, mem(rcx, 2*8)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + jmp(.DDONE_BLOCK1) // jump to end. + + label(.DBETAZERO_BLOCK1) + + + vextractf128(imm(1), ymm4, xmm1 ) + vmovupd(xmm4, mem(rcx)) + vmovlpd(xmm1, mem(rcx, 2*8)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + + label(.DDONE_BLOCK1) + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + +// ----------------------- Block 2 + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK2) // if i == 0, jump to code that + // contains the k_iter4 loop. 
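+
+	// Block 2 covers rows 21-23, columns 16-19, which lie entirely below
+	// the diagonal, so all three rows are accumulated and stored at full
+	// 4-element width.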
+ + label(.DLOOPKITER16_BLOCK2) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK2) // iterate again if i != 0. 
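+
+	// The a-panel prefetches issued in these loops (r10 = 3*rs_a,
+	// rbp = 5*rs_a) touch A three, four and five rows ahead of the rows
+	// currently being read, i.e. the rows the next 3-row block will
+	// consume once r14 has been advanced by 3*rs_a.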
+ + label(.DCONSIDKITER4_BLOCK2) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK2) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK2) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK2) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK2) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK2) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK2) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK2) // iterate again if i != 0. 
+ + label(.DPOSTACCUM_BLOCK2) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK2) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + jmp(.DDONE_BLOCK2) // jump to end. + + label(.DBETAZERO_BLOCK2) + + + vmovupd(ymm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm5, mem(rcx)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + + label(.DDONE_BLOCK2) + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + + add(imm(4), r15) // jj += 4; + +// ----------------------- Block 3 + mov(var(a), r14) // load address of a + mov(var(b), rdx) // load address of b + mov(var(c), r12) // load address of c + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(imm(1*8), rsi) // rsi *= cs_c*sizeof(double) = 1*8 + lea(mem(r12, rsi, 1), r12) // r12 = c + 4*jj*cs_c; + + lea(mem( , r15, 1), rsi) // rsi = r15 = 4*jj; + imul(r11, rsi) // rsi *= cs_b; + lea(mem(rdx, rsi, 1), rdx) // rbx = b + 4*jj*cs_b; + + mov(var(m_iter), r9) // ii = m_iter; + + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. 
+ je(.DCONSIDKITER4_BLOCK3) // if i == 0, jump to code that + // contains the k_iter4 loop. + + label(.DLOOPKITER16_BLOCK3) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK3) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK3) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK3) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. 
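+
+	// Block 3 covers rows 18-20, columns 20-23; per the diagram only the
+	// (20,20) element falls inside the computed region, so just the sums
+	// built from the third A row (ymm6/ymm9/ymm12/ymm15) are reduced and a
+	// single element is written back.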
+ + label(.DLOOPKITER4_BLOCK3) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK3) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK3) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK3) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK3) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. + + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK3) // iterate again if i != 0. + + label(.DPOSTACCUM_BLOCK3) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm6, ymm6) + + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK3) // if ZF = 1, jump to beta == 0 case + + + add(rdi, rcx) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovlpd(xmm6, mem(rcx)) + + jmp(.DDONE_BLOCK3) // jump to end. 
+ + label(.DBETAZERO_BLOCK3) + + + add(rdi, rcx) + add(rdi, rcx) + vmovlpd(xmm6, mem(rcx)) + + label(.DDONE_BLOCK3) + + lea(mem(r12, rdi, 2), r12) // + lea(mem(r12, rdi, 1), r12) // c_ii = r12 += 3*rs_c + + lea(mem(r14, r8, 2), r14) // + lea(mem(r14, r8, 1), r14) // a_ii = r14 += 3*rs_a + + dec(r9) // ii -= 1; + +// ----------------------- Block 4 + vxorpd(ymm4, ymm4, ymm4) + vmovapd( ymm4, ymm5) + vmovapd( ymm4, ymm6) + vmovapd( ymm4, ymm7) + vmovapd( ymm4, ymm8) + vmovapd( ymm4, ymm9) + vmovapd( ymm4, ymm10) + vmovapd( ymm4, ymm11) + vmovapd( ymm4, ymm12) + vmovapd( ymm4, ymm13) + vmovapd( ymm4, ymm14) + vmovapd( ymm4, ymm15) + + lea(mem(r12), rcx) // rcx = c_iijj; + lea(mem(r14), rax) // rax = a_ii; + lea(mem(rdx), rbx) // rbx = b_jj; + + prefetch(0, mem(rcx, 3*8)) // prefetch c + 0*rs_c + prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c + prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c + lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a + + mov(var(k_iter16), rsi) // i = k_iter16; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKITER4_BLOCK4) // if i == 0, jump to code that + // contains the k_iter4 loop. + + label(.DLOOPKITER16_BLOCK4) // MAIN LOOP + + // ---------------------------------- iteration 0 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 1 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 2 + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + 
vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + // ---------------------------------- iteration 3 + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER16_BLOCK4) // iterate again if i != 0. + + label(.DCONSIDKITER4_BLOCK4) + + mov(var(k_iter4), rsi) // i = k_iter4; + test(rsi, rsi) // check i via logical AND. + je(.DCONSIDKLEFT1_BLOCK4) // if i == 0, jump to code that + // considers k_left1 loop. + // else, we prepare to enter k_iter4 loop. + + label(.DLOOPKITER4_BLOCK4) // EDGE LOOP (ymm) + + prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a + prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a + prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + + vmovupd(mem(rax ), ymm0) + vmovupd(mem(rax, r8, 1), ymm1) + vmovupd(mem(rax, r8, 2), ymm2) + add(imm(4*8), rax) // a += 4*cs_a = 4*8; + + vmovupd(mem(rbx ), ymm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovupd(mem(rbx, r11, 1), ymm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovupd(mem(rbx, r11, 2), ymm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovupd(mem(rbx, r13, 1), ymm3) + add(imm(4*8), rbx) // b += 4*rs_b = 4*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKITER4_BLOCK4) // iterate again if i != 0. + + label(.DCONSIDKLEFT1_BLOCK4) + + mov(var(k_left1), rsi) // i = k_left1; + test(rsi, rsi) // check i via logical AND. + je(.DPOSTACCUM_BLOCK4) // if i == 0, we're done; jump to end. + // else, we prepare to enter k_left1 loop. + + label(.DLOOPKLEFT1_BLOCK4) // EDGE LOOP (scalar) + // NOTE: We must use ymm registers here bc + // using the xmm registers would zero out the + // high bits of the destination registers, + // which would destory intermediate results. 
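	/*
	   Editorial note (not part of the patch), expanding the NOTE above: the
	   vector loops have already deposited partial sums in all four lanes of
	   ymm4-ymm15. The scalar tail below loads single doubles with vmovsd, so
	   lanes 1-3 of the source registers are zero and the ymm-encoded FMAs add
	   zero to lanes 1-3 of each accumulator, leaving them intact. An
	   xmm-encoded (128-bit VEX) FMA would instead zero bits 255:128 of the
	   destination, discarding the partial sums held in the upper two lanes.
	   The horizontal reduction in .DPOSTACCUM_BLOCK4 then folds all four
	   lanes together, which is why lane-0-only updates are safe here.
	*/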
+ + vmovsd(mem(rax ), xmm0) + vmovsd(mem(rax, r8, 1), xmm1) + vmovsd(mem(rax, r8, 2), xmm2) + add(imm(1*8), rax) // a += 1*cs_a = 1*8; + + vmovsd(mem(rbx ), xmm3) + vfmadd231pd(ymm0, ymm3, ymm4) + vfmadd231pd(ymm1, ymm3, ymm5) + vfmadd231pd(ymm2, ymm3, ymm6) + + vmovsd(mem(rbx, r11, 1), xmm3) + vfmadd231pd(ymm0, ymm3, ymm7) + vfmadd231pd(ymm1, ymm3, ymm8) + vfmadd231pd(ymm2, ymm3, ymm9) + + vmovsd(mem(rbx, r11, 2), xmm3) + vfmadd231pd(ymm0, ymm3, ymm10) + vfmadd231pd(ymm1, ymm3, ymm11) + vfmadd231pd(ymm2, ymm3, ymm12) + + vmovsd(mem(rbx, r13, 1), xmm3) + add(imm(1*8), rbx) // b += 1*rs_b = 1*8; + vfmadd231pd(ymm0, ymm3, ymm13) + vfmadd231pd(ymm1, ymm3, ymm14) + vfmadd231pd(ymm2, ymm3, ymm15) + + dec(rsi) // i -= 1; + jne(.DLOOPKLEFT1_BLOCK4) // iterate again if i != 0. + + label(.DPOSTACCUM_BLOCK4) + + // ymm4 ymm7 ymm10 ymm13 + // ymm5 ymm8 ymm11 ymm14 + // ymm6 ymm9 ymm12 ymm15 + + vhaddpd( ymm7, ymm4, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm13, ymm10, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm4 ) + // xmm4[0] = sum(ymm4); xmm4[1] = sum(ymm7) + // xmm4[2] = sum(ymm10); xmm4[3] = sum(ymm13) + + vhaddpd( ymm8, ymm5, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm14, ymm11, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm5 ) + // xmm5[0] = sum(ymm5); xmm5[1] = sum(ymm8) + // xmm5[2] = sum(ymm11); xmm5[3] = sum(ymm14) + + vhaddpd( ymm9, ymm6, ymm0 ) + vextractf128(imm(1), ymm0, xmm1 ) + vaddpd( xmm0, xmm1, xmm0 ) + + vhaddpd( ymm15, ymm12, ymm2 ) + vextractf128(imm(1), ymm2, xmm1 ) + vaddpd( xmm2, xmm1, xmm2 ) + + vperm2f128(imm(0x20), ymm2, ymm0, ymm6 ) + // xmm6[0] = sum(ymm6); xmm6[1] = sum(ymm9) + // xmm6[2] = sum(ymm12); xmm6[3] = sum(ymm15) + + mov(var(alpha), rax) // load address of alpha + mov(var(beta), rbx) // load address of beta + vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate + vbroadcastsd(mem(rbx), ymm3) // load beta and duplicate + + vmulpd(ymm0, ymm4, ymm4) // scale by alpha + vmulpd(ymm0, ymm5, ymm5) + vmulpd(ymm0, ymm6, ymm6) + + // now avoid loading C if beta == 0 + + vxorpd(ymm0, ymm0, ymm0) // set ymm0 to zero. + vucomisd(xmm0, xmm3) // set ZF if beta == 0. + je(.DBETAZERO_BLOCK4) // if ZF = 1, jump to beta == 0 case + + + vfmadd231pd(mem(rcx), ymm3, ymm4) + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm5) + vmovupd(xmm5, mem(rcx)) + vextractf128(imm(1), ymm5, xmm1 ) + vmovlpd(xmm1, mem(rcx, 2*8)) + add(rdi, rcx) + + vfmadd231pd(mem(rcx), ymm3, ymm6) + vmovupd(ymm6, mem(rcx)) + + jmp(.DDONE_BLOCK4) // jump to end. + + label(.DBETAZERO_BLOCK4) + + + vmovupd(xmm4, mem(rcx)) + add(rdi, rcx) + + vmovupd(xmm5, mem(rcx)) + vextractf128(imm(1), ymm5, xmm1 ) + vmovlpd(xmm1, mem(rcx, 2*8)) + add(rdi, rcx) + + vmovupd(ymm6, mem(rcx)) + + + label(.DDONE_BLOCK4) + + vzeroupper() end_asm( : // output operands (none) @@ -686,42 +12832,6 @@ void bli_dgemmsup_rd_haswell_asm_6x8m "memory" ) - consider_edge_cases: - - // Handle edge cases in the m dimension, if they exist. 
- if ( m_left ) - { - const dim_t nr_cur = 8; - const dim_t i_edge = m0 - ( dim_t )m_left; - - double* restrict cij = c + i_edge*rs_c; - double* restrict bj = b; - double* restrict ai = a + i_edge*rs_a; - - if ( 2 == m_left ) - { - const dim_t mr_cur = 2; - - bli_dgemmsup_rd_haswell_asm_2x8 - ( - conja, conjb, mr_cur, nr_cur, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, - beta, cij, rs_c0, cs_c0, data, cntx - ); - //cij += mr_cur*rs_c0; ai += mr_cur*rs_a0; m_left -= mr_cur; - } - if ( 1 == m_left ) - { - const dim_t mr_cur = 1; - - bli_dgemmsup_rd_haswell_asm_1x8 - ( - conja, conjb, mr_cur, nr_cur, k0, - alpha, ai, rs_a0, cs_a0, bj, rs_b0, cs_b0, - beta, cij, rs_c0, cs_c0, data, cntx - ); - } - } AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); } @@ -1228,7 +13338,7 @@ void bli_dgemmsup_rd_haswell_asm_6x4m label(.DRETURN) - + vzeroupper() end_asm( : // output operands (none) @@ -1844,7 +13954,7 @@ void bli_dgemmsup_rd_haswell_asm_6x2m label(.DRETURN) - + vzeroupper() end_asm( : // output operands (none) diff --git a/kernels/haswell/bli_kernels_haswell.h b/kernels/haswell/bli_kernels_haswell.h index 5b4c8a05bc..d841d715f3 100644 --- a/kernels/haswell/bli_kernels_haswell.h +++ b/kernels/haswell/bli_kernels_haswell.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2022, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -294,6 +294,22 @@ GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m_18x16_L ) GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m_16x12_combined_L ) GEMMSUP_KER_PROT( double, d, gemmsup_rv_haswell_asm_6x8m_0x0_combined_U ) +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_6x8m_0x0_U ) +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_6x8m_6x0_U ) +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_6x8m_6x8_U ) +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_6x8m_12x8_U ) +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_6x8m_12x16_U ) +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_6x8m_18x16_U ) +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_6x8m_0x0_combined_U ) +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_6x8m_0x0_L ) +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_6x8m_6x0_L ) +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_6x8m_6x8_L ) +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_6x8m_12x8_L ) +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_6x8m_12x16_L ) +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_6x8m_18x16_L ) +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_6x8m_16x12_combined_L ) +GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_6x8m_0x0_combined_U ) + GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_6x8m ) GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_6x4m ) GEMMSUP_KER_PROT( double, d, gemmsup_rd_haswell_asm_6x2m ) From 0b81f530746f856f7cc8f7e5be9684b016d712db Mon Sep 17 00:00:00 2001 From: satish kumar nuggu Date: Fri, 19 Aug 2022 08:47:02 +0530 Subject: [PATCH 188/243] Fixed bug in DZGEMM 1. In zen4 dgemm and sgemm native kernels are column-prefer kernels, cgemm and zgemm native kernels are row-prefer kernels. zen3 and older arch (uses row-prefer kernels for all datatypes) hence induced-transpose carried out based on kernel preference check. 
Added a condition check, output matrix storage format need to be checked along with kernel preference to avoid induced-transpose for zen4. 2. Added functions bli_cntx_l3_vir_ukr_dislikes_storage_of_md, bli_cntx_l3_vir_ukr_prefers_storage_of_md for checking output matrix storage format and micro kernel preference of mixed datatypes. AMD-Internal: [CPUPL-2347] Change-Id: Ib77676f4e2152f7876ad7dc91de716547f5ba3a5 --- frame/3/gemm/bli_gemm_md.c | 8 ++------ frame/base/bli_cntx.h | 21 +++++++++++++++++++++ 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/frame/3/gemm/bli_gemm_md.c b/frame/3/gemm/bli_gemm_md.c index 0f82b15f3e..c9450a26c2 100644 --- a/frame/3/gemm/bli_gemm_md.c +++ b/frame/3/gemm/bli_gemm_md.c @@ -172,14 +172,12 @@ mddm_t bli_gemm_md_ccr // that computation datatype to query the corresponding ukernel output // preference. const num_t dt = BLIS_REAL | bli_obj_comp_prec( c ); - const bool row_pref - = bli_cntx_l3_nat_ukr_prefers_rows_dt( dt, BLIS_GEMM_UKR, *cntx ); // We can only perform this case of mixed-domain gemm, C += A*B where // B is real, if the microkernel prefers column output. If it prefers // row output, we must induce a transposition and perform C += A*B // where A (formerly B) is real. - if ( row_pref ) + if ( bli_cntx_l3_vir_ukr_dislikes_storage_of_md( c, dt, BLIS_GEMM_UKR, *cntx ) ) { bli_obj_swap( a, b ); @@ -273,14 +271,12 @@ mddm_t bli_gemm_md_crc // that computation datatype to query the corresponding ukernel output // preference. const num_t dt = BLIS_REAL | bli_obj_comp_prec( c ); - const bool col_pref - = bli_cntx_l3_nat_ukr_prefers_cols_dt( dt, BLIS_GEMM_UKR, *cntx ); // We can only perform this case of mixed-domain gemm, C += A*B where // A is real, if the microkernel prefers row output. If it prefers // column output, we must induce a transposition and perform C += A*B // where B (formerly A) is real. - if ( col_pref ) + if ( bli_cntx_l3_vir_ukr_dislikes_storage_of_md( c, dt, BLIS_GEMM_UKR, *cntx ) ) { bli_obj_swap( a, b ); diff --git a/frame/base/bli_cntx.h b/frame/base/bli_cntx.h index 3715d70c9f..fcef4738fd 100644 --- a/frame/base/bli_cntx.h +++ b/frame/base/bli_cntx.h @@ -601,6 +601,27 @@ BLIS_INLINE bool bli_cntx_l3_vir_ukr_dislikes_storage_of( obj_t* obj, l3ukr_t uk !bli_cntx_l3_vir_ukr_prefers_storage_of( obj, ukr_id, cntx ); } +BLIS_INLINE bool bli_cntx_l3_vir_ukr_prefers_storage_of_md( obj_t* obj, num_t dt, l3ukr_t ukr_id, cntx_t* cntx ) +{ + // we use the computation datatype, which may differ from the + // storage datatype of C + const bool ukr_prefers_rows + = bli_cntx_l3_vir_ukr_prefers_rows_dt( dt, ukr_id, cntx ); + const bool ukr_prefers_cols + = bli_cntx_l3_vir_ukr_prefers_cols_dt( dt, ukr_id, cntx ); + bool r_val = FALSE; + + if ( bli_obj_is_row_stored( obj ) && ukr_prefers_rows ) r_val = TRUE; + else if ( bli_obj_is_col_stored( obj ) && ukr_prefers_cols ) r_val = TRUE; + return r_val; +} + +BLIS_INLINE bool bli_cntx_l3_vir_ukr_dislikes_storage_of_md( obj_t* obj, num_t dt, l3ukr_t ukr_id, cntx_t* cntx ) +{ + return ( bool ) + !bli_cntx_l3_vir_ukr_prefers_storage_of_md( obj, dt, ukr_id, cntx ); +} + // ----------------------------------------------------------------------------- BLIS_INLINE bool bli_cntx_l3_sup_thresh_is_met( obj_t* a, obj_t* b, obj_t* c, cntx_t* cntx ) { From 219c41ded91f76f7f5903a70988094b6a6d765f6 Mon Sep 17 00:00:00 2001 From: satish kumar nuggu Date: Sat, 20 Aug 2022 13:04:33 +0530 Subject: [PATCH 189/243] ZTRSM Improvements Details: 1. Optimized ztrsm for small sizes upto 500 in multi thread scenarios. 
2. Enabled multithreading execution for bli_trsm_small implementation for double complex data type. 3. Added decision logic to choose between native vs multi-threaded small path for sizes upto 500 and threads upto 8. AMD-Internal: [CPUPL-2340] Change-Id: I4df9d7e6ee152baa9cf33e58d36e1c17f75a00c1 --- frame/compat/bla_trsm_amd.c | 23 +++++++++++++++++++++ kernels/zen/3/bli_trsm_small.c | 37 +++++++++++++++++++++++++++++++--- kernels/zen/bli_kernels_zen.h | 11 ++++++++++ 3 files changed, 68 insertions(+), 3 deletions(-) diff --git a/frame/compat/bla_trsm_amd.c b/frame/compat/bla_trsm_amd.c index e1c997717d..8ca7434bd8 100644 --- a/frame/compat/bla_trsm_amd.c +++ b/frame/compat/bla_trsm_amd.c @@ -1291,6 +1291,29 @@ void ztrsm_ return; } } +#ifdef BLIS_ENABLE_OPENMP + + // bli_trsm_small_mt supports till n_threads equal to 8 + if( bli_cntx_trsm_small_thresh_is_met_zen(&ao, m0, n0) == true ) + { + err_t status; + status = bli_trsm_small_mt( + blis_side, + &alphao, + &ao, + &bo, + NULL, + NULL); + + if ( status == BLIS_SUCCESS ) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + /* Finalize BLIS. */ + bli_finalize_auto(); + return; + } + } +#endif// BLIS_ENABLE_OPENMP } // bli_cpuid_is_avx_supported} #endif diff --git a/kernels/zen/3/bli_trsm_small.c b/kernels/zen/3/bli_trsm_small.c index 168fe48d7d..0b47ddd862 100644 --- a/kernels/zen/3/bli_trsm_small.c +++ b/kernels/zen/3/bli_trsm_small.c @@ -4201,14 +4201,16 @@ err_t bli_trsm_small case BLIS_FLOAT: case BLIS_SCOMPLEX: { - if(m > 1000 || n > 1000) { + bool nt = bli_thread_get_is_parallel(); + if((nt == 0) && (m > 1000 || n > 1000)) { return BLIS_NOT_YET_IMPLEMENTED; } break; } case BLIS_DCOMPLEX: { - if(m > 500 || n > 500) { + bool nt = bli_thread_get_is_parallel(); + if((nt == 0) && (m > 500 || n > 500)) { return BLIS_NOT_YET_IMPLEMENTED; } break; @@ -4289,6 +4291,11 @@ err_t bli_trsm_small_mt d_mr = 8,d_nr = 6; break; } + case BLIS_DCOMPLEX: + { + d_mr = 4,d_nr = 3; + break; + } default: { return BLIS_NOT_YET_IMPLEMENTED; @@ -4303,7 +4310,7 @@ err_t bli_trsm_small_mt // If dynamic-threading is enabled, calculate optimum number // of threads. // rntm will be updated with optimum number of threads. 
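 /*
    Editorial note (not part of the patch): a concrete reading of the new
    threshold bli_cntx_trsm_small_thresh_is_met_zen (added later in this patch)
    for double-complex TRSM. With BLIS_NUM_THREADS=8 and m = n = 400 the
    predicate returns true and ztrsm_ dispatches to bli_trsm_small_mt; with
    n_threads = 12, or with m or n above 500, it returns false and the regular
    multi-threaded path is used instead. Single-threaded calls keep the
    existing bli_trsm_small route, which for dcomplex remains limited to
    m, n <= 500.
 */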
- if( bli_obj_is_double(b)) + if( bli_obj_is_double(b) ) { bli_nthreads_optimum(a, b, b, BLIS_TRSM, &rntm); } @@ -46694,4 +46701,28 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB return BLIS_SUCCESS; } +/* + * Check if the TRSM small path should be taken for this + * input and threads combination + */ +bool bli_cntx_trsm_small_thresh_is_met_zen(obj_t* a,dim_t m, dim_t n) +{ + rntm_t rntm; + bli_rntm_init_from_global(&rntm); + dim_t n_threads = bli_rntm_num_threads(&rntm); + + if(bli_obj_is_dcomplex(a)) + { + if ((n_threads > 1) && (n_threads <= 8) && (m <= 500) && (n <= 500)) + { + return true; + } + else + { + return false; + } + } + return false; +} + #endif //BLIS_ENABLE_SMALL_MATRIX_TRSM diff --git a/kernels/zen/bli_kernels_zen.h b/kernels/zen/bli_kernels_zen.h index 1d18d711e1..4cec80773f 100644 --- a/kernels/zen/bli_kernels_zen.h +++ b/kernels/zen/bli_kernels_zen.h @@ -380,6 +380,17 @@ bool bli_cntx_syrksup_thresh_is_met_zen cntx_t* cntx ); +/* + * Check if the TRSM small path should be taken for this + * input and threads combination + */ +bool bli_cntx_trsm_small_thresh_is_met_zen + ( + obj_t* a, + dim_t m, + dim_t n + ); + #ifdef BLIS_ENABLE_FAST_MATH void bli_dnorm2fv_unb_var1 ( From 4e3e00fb7ec46cba5760f58869cb403bc6d62c42 Mon Sep 17 00:00:00 2001 From: eashdash Date: Wed, 17 Aug 2022 08:25:30 +0000 Subject: [PATCH 190/243] Added low precision GEMM - bf16bf16f32of32 Feature Addition: Added a new variant of low precision GEMM to addon - BFloat16. The kernel takes bf16 type inputs and perform BF16 GEMM operations. The intermediate accumulation and output are in float. 1. Compute kernels will perform computations only if B matrix is reordered in accordance with the usage of AVX-512 BF16 instruction - dpbf16_ps 2. Kernel for packing B matrix is provided Change-Id: If5d08213068869eff060c9998596d2d2703a6793 --- Makefile | 4 +- addon/aocl_gemm/aocl_bf16_type.h | 36 + addon/aocl_gemm/aocl_gemm_bf16.c | 157 + addon/aocl_gemm/aocl_gemm_bf16_utils.c | 126 + addon/aocl_gemm/aocl_gemm_interface_apis.h | 4 + .../aocl_gemm/frame/bf16bf16f32/lpgemm_bf16.c | 137 + .../frame/bf16bf16f32/lpgemm_reorder_bf16.c | 93 + .../frame/bf16bf16f32/lpgemm_reorder_bf16.h | 46 + .../frame/lpgemm_5loop_interface_apis.h | 2 + addon/aocl_gemm/frame/lpgemm_config.c | 7 +- addon/aocl_gemm/frame/lpgemm_config.h | 3 +- addon/aocl_gemm/frame/lpgemm_types.h | 3 +- .../threading/lpgemm_thread_decor_openmp.c | 54 + .../threading/lpgemm_thread_decor_openmp.h | 3 + .../lpgemm_6x64rowmajor_bf16_amd512vnni.c | 694 ++++ .../lpgemm_m_fringe_bf16_amd512vnni.c | 1313 +++++++ .../lpgemm_mn_fringe_bf16_amd512vnni.c | 3069 +++++++++++++++++ .../lpgemm_n_fringe_bf16_amd512vnni.c | 1590 +++++++++ .../kernels/bf16bf16f32/lpgemm_packb_bf16.h | 67 + .../lpgemm_packb_bf16_amd512vnni.c | 504 +++ addon/aocl_gemm/kernels/lpgemm_kernels.h | 36 + bench/bench_aocl_gemm/Makefile | 2 +- bench/bench_aocl_gemm/bench_lpgemm.c | 726 ---- 23 files changed, 7942 insertions(+), 734 deletions(-) create mode 100644 addon/aocl_gemm/aocl_bf16_type.h create mode 100644 addon/aocl_gemm/aocl_gemm_bf16.c create mode 100644 addon/aocl_gemm/aocl_gemm_bf16_utils.c create mode 100644 addon/aocl_gemm/frame/bf16bf16f32/lpgemm_bf16.c create mode 100644 addon/aocl_gemm/frame/bf16bf16f32/lpgemm_reorder_bf16.c create mode 100644 addon/aocl_gemm/frame/bf16bf16f32/lpgemm_reorder_bf16.h create mode 100644 addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_6x64rowmajor_bf16_amd512vnni.c create mode 100644 addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_m_fringe_bf16_amd512vnni.c create 
mode 100644 addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_mn_fringe_bf16_amd512vnni.c create mode 100644 addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_n_fringe_bf16_amd512vnni.c create mode 100644 addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_packb_bf16.h create mode 100644 addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_packb_bf16_amd512vnni.c delete mode 100644 bench/bench_aocl_gemm/bench_lpgemm.c diff --git a/Makefile b/Makefile index f42914024e..7a730d9c25 100644 --- a/Makefile +++ b/Makefile @@ -599,11 +599,11 @@ define make-c99-addon-rule $(BASE_OBJ_ADDON_PATH)/%.o: $(ADDON_PATH)/%.$(2) $(BLIS_H_FLAT) $(ADDON_H99_FILES) $(MAKE_DEFS_MK_PATHS) ifeq ($(ENABLE_VERBOSE),yes) $$(if $$(findstring _amd512vnni,$$<),$$(eval LPGEMM_MARCH_VAR=icelake-server),$$(eval LPGEMM_MARCH_VAR=znver3)) - $(CC) -march=$$(LPGEMM_MARCH_VAR) $(call get-addon-c99flags-for,$(1)) -c $$< -o $$@ + $(CC) -march=$$(LPGEMM_MARCH_VAR) -mavx512bf16 $(call get-addon-c99flags-for,$(1)) -c $$< -o $$@ else @echo "Compiling $$@" $(call get-addon-c99text-for,$(1)) $$(if $$(findstring _amd512vnni,$$<),$$(eval LPGEMM_MARCH_VAR=icelake-server),$$(eval LPGEMM_MARCH_VAR=znver3)) - @$(CC) -march=$$(LPGEMM_MARCH_VAR) $(call get-addon-c99flags-for,$(1)) -c $$< -o $$@ + @$(CC) -march=$$(LPGEMM_MARCH_VAR) -mavx512bf16 $(call get-addon-c99flags-for,$(1)) -c $$< -o $$@ endif endef diff --git a/addon/aocl_gemm/aocl_bf16_type.h b/addon/aocl_gemm/aocl_bf16_type.h new file mode 100644 index 0000000000..f8b2fd431a --- /dev/null +++ b/addon/aocl_gemm/aocl_bf16_type.h @@ -0,0 +1,36 @@ + +/* + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+*/ +#ifndef AOCL_GEMM_HALF_PRECISION_TYPE_H +#define AOCL_GEMM_HALF_PRECISION_TYPE_H + +typedef int16_t bfloat16; + +#endif // AOCL_GEMM_HALF_PRECISION_TYPE_H + diff --git a/addon/aocl_gemm/aocl_gemm_bf16.c b/addon/aocl_gemm/aocl_gemm_bf16.c new file mode 100644 index 0000000000..b61a2efb43 --- /dev/null +++ b/addon/aocl_gemm/aocl_gemm_bf16.c @@ -0,0 +1,157 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "aocl_gemm_interface_apis.h" +#include "lpgemm_types.h" +#include "lpgemm_post_ops.h" +#include "lpgemm_thread_decor_openmp.h" +#include "lpgemm_5loop_interface_apis.h" +#include "lpgemm_config.h" +#include "lpgemm_utils.h" + +AOCL_GEMM_MATMUL(bfloat16,bfloat16,float,bf16bf16f32of32) +{ + trans_t blis_transa; + trans_t blis_transb; + + // Check if avx512_vnni ISA is supported, lpgemm matmul only works with it. + if ( bli_cpuid_is_avx512vnni_supported() == FALSE ) + { + printf(" AVX512_BF16 ISA not supported by processor, cannot perform lpgemm.\n"); + return; // Error. + } + + /* Initialize BLIS. */ + bli_init_auto(); + + // Set MC, NC, KC, NR, MR. + aocl_lpgemm_init_global_cntx(); + + // Null check for pointers. + if ( ( a == NULL ) || ( b == NULL ) || ( c == NULL ) ) + { + return; // Error. + } + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + bli_param_map_netlib_to_blis_trans( transa, &blis_transa ); + bli_param_map_netlib_to_blis_trans( transb, &blis_transb ); + + /* Perform BLAS parameter checking. */ + // Transpose not supported. + if ( ( blis_transa != BLIS_NO_TRANSPOSE ) || + ( blis_transb != BLIS_NO_TRANSPOSE ) ) + { + return; // Error. + } + if ( ( order != 'r' ) && ( order != 'R' ) ) + { + return; // Only row major supported. + } + + // Row major input expected with leading dimensions equal to row stride. + if ( ( lda != k ) || ( ldb != n ) || ( ldc != n ) ) + { + return; // Error. + } + + // Check if dimensions are valid. 
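	// Editorial note (not part of the patch): taken together with the
	// surrounding checks, a call is accepted only for dense row-major A, B, C
	// with no transpose; e.g. m=4, n=8, k=16 requires lda=16, ldb=8 and
	// ldc=8, and any other leading dimensions (or any non-positive size)
	// cause the routine to return without computing.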
+ if ( ( m <= 0) || ( n <= 0 ) || ( k <= 0 ) || + ( lda <= 0 ) || ( ldb <= 0 ) || ( ldc <= 0 ) ) + { + return; // Error. + } + + const inc_t rs_a = lda; + const inc_t cs_a = 1; + const inc_t rs_b = ldb; + const inc_t cs_b = 1; + const inc_t rs_c = ldc; + + AOCL_MEMORY_TAG mtag_a; + AOCL_MEMORY_TAG mtag_b; + + bli_param_map_char_to_lpmtag( mem_format_a, &mtag_a ); + bli_param_map_char_to_lpmtag( mem_format_b, &mtag_b ); + + // B matrix needs to be packed in a certain format in order to be loaded + // and used in bf16 instrution. As such the mtag_b always needs to be either + // packed or reordered. B matrix as it is (unpacked) cannot be used, and + // the mtag_b is set to packed to enable runtime packing. + if ( mtag_b == UNPACKED ) + { + mtag_b = PACK; + } + + // Only unpacked A supported now. + if ( mtag_a != UNPACKED ) + { + return; // Error. + } + + // Convert post op struct to post op linked list format. + lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS]; + lpgemm_translate_to_post_ops_list( post_op_unparsed, post_op_list ); + + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. + rntm_t rntm_g; + bli_rntm_init_from_global( &rntm_g ); + bli_membrk_rntm_set_membrk( &rntm_g ); + +#ifdef BLIS_ENABLE_OPENMP + lpgemm_bf16bf16f32of32_openmp_thread_decorator + ( + m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + c, rs_c, + alpha, beta, + &rntm_g, + post_op_list + ); +#else + lpgemm_bf16bf16f32of32_thread_decorator + ( + m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + c, rs_c, + alpha, beta, + &rntm_g, + post_op_list + ); +#endif +} diff --git a/addon/aocl_gemm/aocl_gemm_bf16_utils.c b/addon/aocl_gemm/aocl_gemm_bf16_utils.c new file mode 100644 index 0000000000..b65d04bd43 --- /dev/null +++ b/addon/aocl_gemm/aocl_gemm_bf16_utils.c @@ -0,0 +1,126 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +*/ + +#include "blis.h" +#include "aocl_gemm_interface_apis.h" +#include "lpgemm_types.h" +#include "lpgemm_config.h" +#include "lpgemm_utils.h" +#include "lpgemm_reorder_bf16.h" + +AOCL_GEMM_GET_REORDER_BUF_SIZE(bf16bf16f32of32) +{ + if ( ( k <= 0 ) || ( n <= 0 ) ) + { + return 0; // Error. + } + + // Check if avx512_bf16 ISA is supported, lpgemm matmul only works with it. + if ( bli_cpuid_is_avx512vnni_supported() == FALSE ) + { + printf(" AVX512_BF16 ISA not supported by processor, cannot perform lpgemm.\n"); + return 0; // Error. + } + + /* Initialize BLIS. */ + bli_init_auto(); + + // Set MC, NC, KC, NR, MR. + aocl_lpgemm_init_global_cntx(); + + AOCL_MATRIX_TYPE input_mat_type; + bli_param_map_char_to_lpmat_type( mat_type, &input_mat_type ); + + if ( input_mat_type == A_MATRIX ) + { + return 0; // A reorder not supported. + } + + // Extra space since packing does width in multiples of 16. The bf16 + // instruction can be used as long as atleast one zmm register can be fully + // loaded; and since k_dim needs to be atleast 2, having n_dim atleast 16 + // should give 2x16=32 elements, enough for 1 zmm register.The padding is + // not rounded to NR (=64), since that would result in memory wastage. + dim_t n_reorder = make_multiple_of_n( n, 16 ); + + // Extra space since packing does length in multiples of 2. + dim_t k_reorder = make_multiple_of_n( k, 2 ); + + siz_t size_req = sizeof( int16_t ) * k_reorder * n_reorder; + + return size_req; +} + +AOCL_GEMM_REORDER(bfloat16, bf16bf16f32of32) +{ + if ( ( input_buf_addr == NULL ) || ( reorder_buf_addr == NULL ) || + ( k <= 0 ) || ( n <= 0 ) || ( ldb < n ) ) + { + return; // Error. + } + + // Check if avx512_bf16 ISA is supported, lpgemm matmul only works with it. + if ( bli_cpuid_is_avx512vnni_supported() == FALSE ) + { + printf(" AVX512_BF16 ISA not supported by processor, cannot perform lpgemm.\n"); + return; // Error. + } + + /* Initialize BLIS. */ + bli_init_auto(); + + // Set MC, NC, KC, NR, MR. + aocl_lpgemm_init_global_cntx(); + + AOCL_MATRIX_TYPE input_mat_type; + bli_param_map_char_to_lpmat_type( mat_type, &input_mat_type ); + + if ( input_mat_type == A_MATRIX ) + { + return; // A reorder not supported. + } + + // Create dummy b_reorder obj. + lpgemm_obj_t b_reorder; + b_reorder.storage.aligned_buffer = reorder_buf_addr; + + // Create dummy original b obj; + lpgemm_obj_t b; + b.storage.aligned_buffer = ( void* )input_buf_addr; + b.rs = ldb; + b.width = n; + b.length = k; + + reorderb_nr64_bf16bf16f32of32( &b, &b_reorder ); +} diff --git a/addon/aocl_gemm/aocl_gemm_interface_apis.h b/addon/aocl_gemm/aocl_gemm_interface_apis.h index 0c656554b1..b38c2c1599 100644 --- a/addon/aocl_gemm/aocl_gemm_interface_apis.h +++ b/addon/aocl_gemm/aocl_gemm_interface_apis.h @@ -36,6 +36,7 @@ #define AOCL_GEMM_INTERFACE_H #include "aocl_gemm_post_ops.h" +#include "aocl_bf16_type.h" // Returns the size of buffer in bytes required for the reordered matrix. #define AOCL_GEMM_GET_REORDER_BUF_SIZE(LP_SFX) \ @@ -49,6 +50,7 @@ BLIS_EXPORT_ADDON siz_t aocl_get_reorder_buf_size_ ## LP_SFX \ AOCL_GEMM_GET_REORDER_BUF_SIZE(f32f32f32of32); AOCL_GEMM_GET_REORDER_BUF_SIZE(u8s8s32os32); AOCL_GEMM_GET_REORDER_BUF_SIZE(u8s8s16os16); +AOCL_GEMM_GET_REORDER_BUF_SIZE(bf16bf16f32of32); // Performs reordering of input matrix. 
Reordering is the process of packing // the entire matrix upfront, so that the benefits of packed matrix is obtained @@ -67,6 +69,7 @@ BLIS_EXPORT_ADDON void aocl_reorder_ ## LP_SFX \ AOCL_GEMM_REORDER(float,f32f32f32of32); AOCL_GEMM_REORDER(int8_t,u8s8s32os32); AOCL_GEMM_REORDER(int8_t,u8s8s16os16); +AOCL_GEMM_REORDER(bfloat16,bf16bf16f32of32); // Only supports matrices in row major format. This api can perform gemm with // both normal as well as reordered B matrix as opposesd to sgemm (only @@ -96,5 +99,6 @@ BLIS_EXPORT_ADDON void aocl_gemm_ ## LP_SFX \ AOCL_GEMM_MATMUL(float,float,float,f32f32f32of32); AOCL_GEMM_MATMUL(uint8_t,int8_t,int32_t,u8s8s32os32); AOCL_GEMM_MATMUL(uint8_t,int8_t,int16_t,u8s8s16os16); +AOCL_GEMM_MATMUL(bfloat16,bfloat16,float,bf16bf16f32of32); #endif // AOCL_GEMM_INTERFACE_H diff --git a/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_bf16.c b/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_bf16.c new file mode 100644 index 0000000000..8988537184 --- /dev/null +++ b/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_bf16.c @@ -0,0 +1,137 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "lpgemm_5loop_interface_apis.h" +#include "lpgemm_packb_bf16.h" +#include "lpgemm_kernels.h" +#include "lpgemm_utils.h" +#include "lpgemm_thrinfo_utils.h" +#include "lpgemm_config.h" + +// B should always be packed. 
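// Editorial note (not part of the patch): "packed" here means B has been
// rearranged by packb_nr64_bf16bf16f32of32 so that pairs of k-elements sit
// contiguously for the dpbf16_ps instruction. Two paddings follow from that
// and drive the index arithmetic below: k is rounded up to a multiple of 2
// (k = 5 packs as k_updated = 6) and each column block is rounded up to a
// multiple of 16 (nc0 = 100 packs as nc0_updated = 112), so the reordered
// panel for block (jc, pc) starts at b + jc*k_updated + pc*nc0_updated
// bfloat16 elements from the base of the reordered buffer.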
+LPGEMM_5LOOP(bfloat16,bfloat16,float,bf16bf16f32of32) +{ + dim_t NC = lpgemm_get_block_size_NC_global_cntx( BF16BF16F32OF32 ); + dim_t KC = lpgemm_get_block_size_KC_global_cntx( BF16BF16F32OF32 ); + dim_t MC = lpgemm_get_block_size_MC_global_cntx( BF16BF16F32OF32 ); + dim_t NR = lpgemm_get_block_size_NR_global_cntx( BF16BF16F32OF32 ); + dim_t MR = lpgemm_get_block_size_MR_global_cntx( BF16BF16F32OF32 ); + + const int16_t* a_use = NULL; + dim_t rs_a_use = rs_a; + dim_t cs_a_use = cs_a; + dim_t a_block_stride = 0; + + const int16_t* b_use = NULL; + dim_t rs_b_use = rs_b; + dim_t cs_b_use = cs_b; + + float* c_use_jc = NULL; + float* c_use_ic = NULL; + + // kc needs to be a multiple of 2 so that it can be used with dpbf16_ps + // instruction. Padding is added in cases this condition is not + // satisfied, and therefore the k offset used for packed/reordered + // buffer needs to be updated. + dim_t k_updated = k; + k_updated += (k_updated & 0x1); + + // Is required to decide whether to apply post ops or not. + bool is_last_k = FALSE; + + for ( dim_t jc = 0; jc < n; jc += NC ) + { + dim_t nc0 = ( ( jc + NC ) <= n ) ? NC : ( n % NC ); + + dim_t nc0_mod16 = nc0 % 16; + dim_t nc0_updated = nc0; + if ( nc0_mod16 != 0 ) + { + nc0_updated += ( 16 - nc0_mod16 ); + } + + for ( dim_t pc = 0; pc < k; pc += KC ) + { + float beta0 = ( pc == 0 ) ? beta : 1; + dim_t kc0 = ( ( pc + KC ) <= k ) ? KC : ( k % KC ); + + // kc0 needs to be a multiple of 2 so that it can be + // used with dpbf16_ps instruction. Padding is added in + // cases this condition is not satisfied, and therefore + // the kc0 offsets used for packed/reordered buffers + // needs to be updated. + dim_t kc0_updated = kc0; + kc0_updated += (kc0_updated & 0x1); + + is_last_k = ( ( pc + KC ) >= k ) ? ( TRUE ) : ( FALSE ); + + // B part getting processed + if ( mtag_b == REORDERED ) + { + b_use = b + ( jc * k_updated ) + ( pc * nc0_updated ); + get_packb_nr64_bf16bf16f32of32_strides( &rs_b_use, &cs_b_use ); + } + + for ( dim_t ic = 0; ic < m; ic += MC ) + { + dim_t mc0 = ( ( ic + MC ) <= m ) ? MC : ( m % MC ); + + a_use = a + ( rs_a * ic ) + ( cs_a * pc ); + + // bf16 kernel reads 2 elements, totalling 4 bytes in a + // single broadcast for use in bf16 instruction. + // Non bf16 based kernel requires update to this code. + cs_a_use = 2; + a_block_stride = rs_a; + + for ( dim_t jr = 0; jr < nc0; jr += NR ) + { + dim_t nr0 = ( ( jr + NR ) <= nc0 ) ? NR : ( nc0 % NR ); + + // Reorder/Packed B, Reorder/Packed/Unpacked A call. + lpgemm_rowvar_bf16bf16f32of32_6x64 + ( + mc0, nr0, kc0, + a_use, rs_a, cs_a_use, a_block_stride, + ( b_use + ( jr * kc0_updated ) ), rs_b_use, cs_b_use, + ( c + ( rs_c * ic ) + jc + jr ), rs_c, 1, + alpha, beta0, + is_last_k, ic, ( jc + jr ), post_op_list + ); + } + } + } + } +} diff --git a/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_reorder_bf16.c b/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_reorder_bf16.c new file mode 100644 index 0000000000..5b3461d73d --- /dev/null +++ b/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_reorder_bf16.c @@ -0,0 +1,93 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. 
+ - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "lpgemm_utils.h" +#include "lpgemm_reorder_bf16.h" +#include "lpgemm_packb_bf16.h" +#include "lpgemm_config.h" +#include "aocl_bf16_type.h" + +void reorderb_nr64_bf16bf16f32of32 + ( + lpgemm_obj_t *b, + lpgemm_obj_t *b_reorder + ) +{ + dim_t NC = lpgemm_get_block_size_NC_global_cntx( BF16BF16F32OF32 ); + dim_t KC = lpgemm_get_block_size_KC_global_cntx( BF16BF16F32OF32 ); + + // Extracting the matrix properties from the lpgemm object + dim_t rs_b = b->rs; + dim_t n = b->width; + dim_t k = b->length; + + dim_t rs_b_reorder; + dim_t cs_b_reorder; + + // k needs to be a multiple of 2 so that it can be used with vpdpbusd + // instruction. Padding is added in cases this condition is not + // satisfied, and therefore the k offset used for packed/reordered + // buffer needs to be updated. + dim_t k_updated = k; + k_updated += (k_updated & 0x1); + + for ( dim_t jc = 0; jc < n; jc += NC ) + { + dim_t nc0 = ( ( jc + NC ) <= n ) ? NC : ( n % NC ); + + dim_t nc0_mod16 = nc0 % 16; + dim_t nc0_updated = nc0; + if ( nc0_mod16 != 0 ) + { + nc0_updated += ( 16 - nc0_mod16 ); + } + for ( dim_t pc = 0; pc < k; pc += KC ) + { + dim_t kc0 = ( ( pc + KC ) <= k ) ? KC : ( k % KC ); + // B should always be packed. + packb_nr64_bf16bf16f32of32 + ( + ( ( ( bfloat16* )b_reorder->storage.aligned_buffer ) + ( jc * k_updated ) + + ( nc0_updated * pc ) ), + ( ( ( bfloat16* )b->storage.aligned_buffer ) + ( rs_b * pc ) + jc ), + rs_b, nc0, kc0, &rs_b_reorder, &cs_b_reorder + ); + } + } + + b_reorder->rs = rs_b_reorder; + b_reorder->cs = cs_b_reorder; + b_reorder->mtag = REORDERED; +} diff --git a/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_reorder_bf16.h b/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_reorder_bf16.h new file mode 100644 index 0000000000..c1b83c1b75 --- /dev/null +++ b/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_reorder_bf16.h @@ -0,0 +1,46 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. 
+ - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef LPGEMM_REORDER_BF16_H +#define LPGEMM_REORDER_BF16_H + +#include "lpgemm_types.h" + +void reorderb_nr64_bf16bf16f32of32 + ( + lpgemm_obj_t *b, + lpgemm_obj_t *b_reorder + ); + +#endif // LPGEMM_REORDER_H diff --git a/addon/aocl_gemm/frame/lpgemm_5loop_interface_apis.h b/addon/aocl_gemm/frame/lpgemm_5loop_interface_apis.h index 65a8715d29..5a5d2eaff8 100644 --- a/addon/aocl_gemm/frame/lpgemm_5loop_interface_apis.h +++ b/addon/aocl_gemm/frame/lpgemm_5loop_interface_apis.h @@ -37,6 +37,7 @@ #include "lpgemm_types.h" #include "lpgemm_post_ops.h" +#include "aocl_bf16_type.h" #define LPGEMM_5LOOP(A_type,B_type,C_type,LP_SFX) \ void lpgemm_rowvar_ ## LP_SFX \ @@ -64,4 +65,5 @@ void lpgemm_rowvar_ ## LP_SFX \ LPGEMM_5LOOP(uint8_t,int8_t,int32_t,u8s8s32o32); LPGEMM_5LOOP(uint8_t,int8_t,int16_t,u8s8s16o16); LPGEMM_5LOOP(float,float,float,f32f32f32of32); +LPGEMM_5LOOP(bfloat16,bfloat16,float,bf16bf16f32of32); #endif // LPGEMM_5LOOP_INTF_H diff --git a/addon/aocl_gemm/frame/lpgemm_config.c b/addon/aocl_gemm/frame/lpgemm_config.c index a23fd409f1..901ec087d2 100644 --- a/addon/aocl_gemm/frame/lpgemm_config.c +++ b/addon/aocl_gemm/frame/lpgemm_config.c @@ -35,7 +35,7 @@ #include "blis.h" #include "lpgemm_config.h" -lpgemm_cntx_t global_cntx_t_list[3]; //Only one op type supported now. +lpgemm_cntx_t global_cntx_t_list[4]; //Only one op type supported now. BLIS_INLINE void lpgemm_set_block_sizes_global_cntx ( @@ -59,8 +59,9 @@ BLIS_INLINE void lpgemm_set_block_sizes_global_cntx // to be configurable from application. 
void aocl_lpgemm_init_global_cntx() { - lpgemm_set_block_sizes_global_cntx( U8S8S32OS32, 144, 1024, 2048, 64, 6 ); - lpgemm_set_block_sizes_global_cntx( U8S8S16OS16, 144, 1024, 1024, 32, 6 ); + lpgemm_set_block_sizes_global_cntx( U8S8S32OS32, 144, 1024, 2048, 64, 6 ); + lpgemm_set_block_sizes_global_cntx( U8S8S16OS16, 144, 1024, 1024, 32, 6 ); + lpgemm_set_block_sizes_global_cntx( BF16BF16F32OF32, 144, 1024, 2048, 64, 6 ); } dim_t lpgemm_get_block_size_MC_global_cntx( AOCL_OPERATION_TYPE op_type ) diff --git a/addon/aocl_gemm/frame/lpgemm_config.h b/addon/aocl_gemm/frame/lpgemm_config.h index 8e25986f6d..7e7f3bb2ad 100644 --- a/addon/aocl_gemm/frame/lpgemm_config.h +++ b/addon/aocl_gemm/frame/lpgemm_config.h @@ -37,7 +37,8 @@ #include "lpgemm_types.h" -extern lpgemm_cntx_t lpgemm_global_cntx_t_list[3]; // equals to number of ops in enum AOCL_OPERATION_TYPE. +// equals to number of ops in enum AOCL_OPERATION_TYPE. +extern lpgemm_cntx_t lpgemm_global_cntx_t_list[4]; void aocl_lpgemm_init_global_cntx(); diff --git a/addon/aocl_gemm/frame/lpgemm_types.h b/addon/aocl_gemm/frame/lpgemm_types.h index 2d9cca79a4..aebd485d0d 100644 --- a/addon/aocl_gemm/frame/lpgemm_types.h +++ b/addon/aocl_gemm/frame/lpgemm_types.h @@ -47,7 +47,8 @@ typedef enum { U8S8S16OS16 = 0, // uint8_t - A, int8_t - B, int16_t - C U8S8S32OS32 = 1, // uint8_t - A, int8_t - B, int32_t - C - F16F16F16OF16 = 2 // float16 - A, float16 - B, float16 - C + F16F16F16OF16 = 2, // float16 - A, float16 - B, float16 - C + BF16BF16F32OF32 = 3 // bf16 - A, bf16 - B, float - C } AOCL_OPERATION_TYPE; typedef enum diff --git a/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.c b/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.c index 544f5b6c70..cf2c9231c3 100644 --- a/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.c +++ b/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.c @@ -334,6 +334,58 @@ BLIS_INLINE void lpgemm_u8s8s32o32_get_threading } } +BLIS_INLINE void lpgemm_bf16bf16f32of32_get_threading + ( + dim_t* n_threads, + dim_t* ic_ways, + dim_t* jc_ways, + dim_t m, + dim_t n, + dim_t k, + rntm_t* rntm_g + ) +{ + *n_threads = bli_rntm_num_threads( rntm_g ); + *jc_ways = bli_rntm_jc_ways( rntm_g ); + *ic_ways = bli_rntm_ic_ways( rntm_g ); + + if ( ( ( *ic_ways ) > 0 ) || ( ( *jc_ways ) > 0 ) ) + { + // If BLIS_IC_NT or JC_NT are set. + // Default cases. + *ic_ways = ( ( *ic_ways ) > 0 ) ? ( *ic_ways ) : 1; + *jc_ways = ( ( *jc_ways ) > 0 ) ? ( *jc_ways ) : 1; + + *n_threads = ( *jc_ways ) * ( *ic_ways ); + } + else if ( ( *n_threads ) > 1 ) + { + + dim_t NR = lpgemm_get_block_size_NR_global_cntx( BF16BF16F32OF32 ); + + if ( n <= NR ) + { + // If n is less than micro panel dimension, allocating all threads + // to ic resulted in gains. + ( *ic_ways ) = ( *n_threads ); + ( *jc_ways ) = 1; + } + else + { + // If BLIS_NUM_THREADS are set, generate jc,ic from the same. + bli_thread_partition_2x2( ( *n_threads ), m, n, ic_ways, jc_ways ); + } + } + else + { + // Setting all the values to 1 in case n_threads <= 1. This ensures + // the threading parameters are valid. + *n_threads = 1; + *jc_ways = 1; + *ic_ways = 1; + } +} + // Some aspects of sgemm smart threading incorporated here. Eventually this // will be redirected to the sgemm smart threading API. 
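 /*
    Editorial note (not part of the patch): a concrete reading of the
    bf16bf16f32of32 threading choice above. With BLIS_IC_NT/BLIS_JC_NT unset
    and BLIS_NUM_THREADS=8, a problem with n <= NR (64 for the 6x64 bf16
    kernel) places all 8 ways on the ic (m) dimension; for larger n the 8
    ways are split across ic and jc by bli_thread_partition_2x2 in proportion
    to m and n. If BLIS_IC_NT=2 and BLIS_JC_NT=2 are set explicitly, the
    effective thread count becomes ic_ways * jc_ways = 4 irrespective of
    BLIS_NUM_THREADS.
 */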
BLIS_INLINE void lpgemm_f32f32f32of32_get_threading @@ -496,6 +548,7 @@ void lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \ GEN_LPGEMM_OPENMP_DECORATOR(uint8_t,int8_t,int16_t,u8s8s16o16) GEN_LPGEMM_OPENMP_DECORATOR(uint8_t,int8_t,int32_t,u8s8s32o32) +GEN_LPGEMM_OPENMP_DECORATOR(bfloat16,bfloat16,float,bf16bf16f32of32) GEN_LPGEMM_OPENMP_DECORATOR(float,float,float,f32f32f32of32) #else @@ -564,6 +617,7 @@ void lpgemm_ ## LPGEMM_SFX ## _thread_decorator \ GEN_LPGEMM_DECORATOR(uint8_t,int8_t,int16_t,u8s8s16o16) GEN_LPGEMM_DECORATOR(uint8_t,int8_t,int32_t,u8s8s32o32) +GEN_LPGEMM_DECORATOR(bfloat16,bfloat16,float,bf16bf16f32of32) GEN_LPGEMM_DECORATOR(float,float,float,f32f32f32of32) #endif diff --git a/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.h b/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.h index 2b420b2f9c..82702a4cf6 100644 --- a/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.h +++ b/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.h @@ -37,6 +37,7 @@ #include "lpgemm_types.h" #include "lpgemm_post_ops.h" +#include "aocl_bf16_type.h" #ifdef BLIS_ENABLE_OPENMP @@ -64,6 +65,7 @@ void lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \ GEN_LPGEMM_OPENMP_DECORATOR_FN(uint8_t,int8_t,int16_t,u8s8s16o16) GEN_LPGEMM_OPENMP_DECORATOR_FN(uint8_t,int8_t,int32_t,u8s8s32o32) +GEN_LPGEMM_OPENMP_DECORATOR_FN(bfloat16,bfloat16,float,bf16bf16f32of32) GEN_LPGEMM_OPENMP_DECORATOR_FN(float,float,float,f32f32f32of32) #else @@ -92,6 +94,7 @@ void lpgemm_ ## LPGEMM_SFX ## _thread_decorator \ GEN_LPGEMM_DECORATOR_FN(uint8_t,int8_t,int32_t,u8s8s16o16) GEN_LPGEMM_DECORATOR_FN(uint8_t,int8_t,int32_t,u8s8s32o32) +GEN_LPGEMM_DECORATOR_FN(bfloat16,bfloat16,float,bf16bf16f32of32) GEN_LPGEMM_DECORATOR_FN(float,float,float,f32f32f32of32) #endif diff --git a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_6x64rowmajor_bf16_amd512vnni.c b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_6x64rowmajor_bf16_amd512vnni.c new file mode 100644 index 0000000000..e98c6e9872 --- /dev/null +++ b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_6x64rowmajor_bf16_amd512vnni.c @@ -0,0 +1,694 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS dim_tERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include + +#include "blis.h" +#include "lpgemm_kernels.h" + +// 6x64 bf16 kernel +LPGEMM_MAIN_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x64) +{ + dim_t MR = 6; + dim_t NR = 64; + + dim_t m_full_pieces = m0 / MR; + dim_t m_full_pieces_loop_limit = m_full_pieces * MR; + dim_t m_partial_pieces = m0 % MR; + + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int32_t a_kfringe_buf = 0; + + if ( n0 < NR ) + { + dim_t n0_rem = n0 % 16; + + // Split dim_to multiple smaller fringe kernels, so as to maximize + // vectorization. Any n0 < NR(64) can be expressed as n0 = 48 + n` + // or n0 = 32 + n` or n0 = 16 + n`, where n` < 16. + dim_t n0_48 = n0 / 48; + dim_t n0_32 = n0 / 32; + dim_t n0_16 = n0 / 16; + + // KC when not multiple of 2 will have padding to make it multiple of + // 2 in packed buffer. Also the k0 cannot be passed as the updated + // value since A matrix is not packed and requires original k0. + dim_t k0_updated = k0; + k0_updated += (k0_updated & 0x1); + + if ( n0_48 == 1 ) + { + lpgemm_rowvar_bf16bf16f32of32_6x48 + ( + m0, k0, + a, rs_a, cs_a, ps_a, + b, ( ( rs_b / 4 ) * 3 ), cs_b, + c, rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list + ); + + b = b + ( 48 * k0_updated ); // k0x48 packed contiguosly. + c = c + 48; + } + + else if ( n0_32 == 1 ) + { + lpgemm_rowvar_bf16bf16f32of32_6x32 + ( + m0, k0, + a, rs_a, cs_a, ps_a, + b, ( ( rs_b / 4 ) * 2 ), cs_b, + c, rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list + ); + + b = b + ( 32 * k0_updated ); // k0x32 packed contiguosly. + c = c + 32; + } + + else if ( n0_16 == 1 ) + { + lpgemm_rowvar_bf16bf16f32of32_6x16 + ( + m0, k0, + a, rs_a, cs_a, ps_a, + b, ( ( rs_b / 4 ) * 1 ), cs_b, + c, rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list + ); + + b = b + ( 16 * k0_updated ); // k0x16 packed contiguosly. + c = c + 16; + } + + if ( n0_rem > 0 ) + { + lpgemm_rowvar_bf16bf16f32of32_6xlt16 + ( + m0, k0, + a, rs_a, cs_a, ps_a, + b, ( ( rs_b / 4 ) * 1 ), cs_b, + c, rs_c, + alpha, beta, n0_rem, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list + ); + + // No leftover fringe after this podint. + } + return; + } + + // B matrix storage bfloat type + __m512bh b0; + __m512bh b1; + __m512bh b2; + __m512bh b3; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + __m512bh a_bf16_1; + + for ( dim_t ir = 0; ir < m_full_pieces_loop_limit; ir += MR ) + { + // Registers to use for accumulating C. 
+ __m512 c_float_0p0 = _mm512_setzero_ps(); + __m512 c_float_0p1 = _mm512_setzero_ps(); + __m512 c_float_0p2 = _mm512_setzero_ps(); + __m512 c_float_0p3 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + __m512 c_float_1p1 = _mm512_setzero_ps(); + __m512 c_float_1p2 = _mm512_setzero_ps(); + __m512 c_float_1p3 = _mm512_setzero_ps(); + + __m512 c_float_2p0 = _mm512_setzero_ps(); + __m512 c_float_2p1 = _mm512_setzero_ps(); + __m512 c_float_2p2 = _mm512_setzero_ps(); + __m512 c_float_2p3 = _mm512_setzero_ps(); + + __m512 c_float_3p0 = _mm512_setzero_ps(); + __m512 c_float_3p1 = _mm512_setzero_ps(); + __m512 c_float_3p2 = _mm512_setzero_ps(); + __m512 c_float_3p3 = _mm512_setzero_ps(); + + __m512 c_float_4p0 = _mm512_setzero_ps(); + __m512 c_float_4p1 = _mm512_setzero_ps(); + __m512 c_float_4p2 = _mm512_setzero_ps(); + __m512 c_float_4p3 = _mm512_setzero_ps(); + + __m512 c_float_5p0 = _mm512_setzero_ps(); + __m512 c_float_5p1 = _mm512_setzero_ps(); + __m512 c_float_5p2 = _mm512_setzero_ps(); + __m512 c_float_5p3 = _mm512_setzero_ps(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + // The instructions are arranged in a mixed way to reduce data + // chain dependencies. + + b0 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+2] + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )(a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + b1 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + b3 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 3 ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-63] = a[0,kr:kr+2]*b[kr:kr+2,0-63] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_1 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + c_float_0p3 = _mm512_dpbf16_ps( c_float_0p3, a_bf16_0, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-63] = a[1,kr:kr+2]*b[kr:kr+2,0-63] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_1, b0 ); + + // Broadcast a[2,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_1, b1 ); + c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_1, b2 ); + c_float_1p3 = _mm512_dpbf16_ps( c_float_1p3, a_bf16_1, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-63] = a[2,kr:kr+2]*b[kr:kr+2,0-63] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + + // Broadcast a[3,kr:kr+2]. + a_bf16_1 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + c_float_2p2 = _mm512_dpbf16_ps( c_float_2p2, a_bf16_0, b2 ); + c_float_2p3 = _mm512_dpbf16_ps( c_float_2p3, a_bf16_0, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-63] = a[3,kr:kr+2]*b[kr:kr+2,0-63] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_1, b0 ); + + // Broadcast a[4,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); + + c_float_3p1 = _mm512_dpbf16_ps( c_float_3p1, a_bf16_1, b1 ); + c_float_3p2 = _mm512_dpbf16_ps( c_float_3p2, a_bf16_1, b2 ); + c_float_3p3 = _mm512_dpbf16_ps( c_float_3p3, a_bf16_1, b3 ); + + // Perform column direction mat-mul with k = 2. 
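/*
 * Each _mm512_dpbf16_ps( acc, a, b ) above accumulates, per fp32 lane, the
 * product of one bf16 pair from A (the 32-bit broadcast) with one bf16 pair
 * from B. A scalar model of a single lane is sketched below, assuming bf16
 * holds the upper 16 bits of an IEEE-754 float; the real instruction's
 * rounding/denormal handling may differ slightly. Names here are
 * illustrative, not the kernel's.
 */
#include <stdint.h>
#include <stdio.h>
#include <string.h>

typedef uint16_t bf16_t;                  /* stand-in for the bfloat16 element */

static float bf16_to_f32( bf16_t x )
{
    uint32_t bits = ( uint32_t )x << 16;  /* bf16 is the high half of an f32 */
    float f;
    memcpy( &f, &bits, sizeof( f ) );
    return f;
}

/* One lane of the k = 2 accumulate: acc += a[0]*b[0] + a[1]*b[1]. */
static float dpbf16_lane( float acc, const bf16_t a[2], const bf16_t b[2] )
{
    return acc + bf16_to_f32( a[0] ) * bf16_to_f32( b[0] )
               + bf16_to_f32( a[1] ) * bf16_to_f32( b[1] );
}

int main( void )
{
    bf16_t ones[2] = { 0x3F80, 0x3F80 }; /* 1.0f, 1.0f in bf16 */
    bf16_t vals[2] = { 0x4000, 0x4040 }; /* 2.0f, 3.0f in bf16 */
    printf( "%f\n", dpbf16_lane( 0.0f, ones, vals ) ); /* prints 5.000000 */
    return 0;
}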
+ // c[4,0-63] = a[4,kr:kr+2]*b[kr:kr+2,0-63] + c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); + + // Broadcast a[5,kr:kr+2]. + a_bf16_1 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 5 ) + ( cs_a * kr ) ) ); + + c_float_4p1 = _mm512_dpbf16_ps( c_float_4p1, a_bf16_0, b1 ); + c_float_4p2 = _mm512_dpbf16_ps( c_float_4p2, a_bf16_0, b2 ); + c_float_4p3 = _mm512_dpbf16_ps( c_float_4p3, a_bf16_0, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[5,0-63] = a[5,kr:kr+2]*b[kr:kr+2,0-63] + c_float_5p0 = _mm512_dpbf16_ps( c_float_5p0, a_bf16_1, b0 ); + c_float_5p1 = _mm512_dpbf16_ps( c_float_5p1, a_bf16_1, b1 ); + c_float_5p2 = _mm512_dpbf16_ps( c_float_5p2, a_bf16_1, b2 ); + c_float_5p3 = _mm512_dpbf16_ps( c_float_5p3, a_bf16_1, b3 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + b1 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + b3 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 3 ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-63] = a[0,kr:kr+2]*b[kr:kr+2,0-63] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_1 = _mm512_set1_epi32( a_kfringe_buf ); + + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + c_float_0p3 = _mm512_dpbf16_ps( c_float_0p3, a_bf16_0, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-63] = a[1,kr:kr+2]*b[kr:kr+2,0-63] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_1, b0 ); + + // Broadcast a[2,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_1, b1 ); + c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_1, b2 ); + c_float_1p3 = _mm512_dpbf16_ps( c_float_1p3, a_bf16_1, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-63] = a[2,kr:kr+2]*b[kr:kr+2,0-63] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + + // Broadcast a[3,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_1 = _mm512_set1_epi32( a_kfringe_buf ); + + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + c_float_2p2 = _mm512_dpbf16_ps( c_float_2p2, a_bf16_0, b2 ); + c_float_2p3 = _mm512_dpbf16_ps( c_float_2p3, a_bf16_0, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-63] = a[3,kr:kr+2]*b[kr:kr+2,0-63] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_1, b0 ); + + // Broadcast a[4,kr:kr+2]. 
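/*
 * The k-remainder path above never broadcasts a full 32-bit pair straight
 * from A: when k0 is odd only one bf16 of the last pair is valid, so the
 * kernels memcpy just the valid bytes into a zeroed int32 and broadcast
 * that, leaving the missing element as 0 and avoiding a read past the end
 * of the row. A small sketch of that idea, with an illustrative helper name
 * and a uint16_t stand-in for the bfloat16 element type.
 */
#include <stddef.h>
#include <stdint.h>
#include <string.h>

typedef uint16_t bfloat16_elem;

static int32_t load_a_pair_safe( const bfloat16_elem* a_row,
                                 size_t cs_a,
                                 size_t k_full_pieces,
                                 size_t k_partial_pieces )
{
    int32_t buf = 0;  /* upper bf16 stays 0 when only one element is valid */
    memcpy( &buf, a_row + ( cs_a * k_full_pieces ),
            k_partial_pieces * sizeof( bfloat16_elem ) );
    return buf;       /* value the kernel then splats with _mm512_set1_epi32() */
}

int main( void )
{
    bfloat16_elem row[3] = { 0x3F80, 0x4000, 0x4040 }; /* 1.0, 2.0, 3.0 in bf16 */
    /* k0 = 3: one full pair, then a single leftover element (0x4040).
       Little-endian layout assumed for this check. */
    return ( load_a_pair_safe( row, 2, 1, 1 ) == 0x4040 ) ? 0 : 1;
}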
+ memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 4 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + c_float_3p1 = _mm512_dpbf16_ps( c_float_3p1, a_bf16_1, b1 ); + c_float_3p2 = _mm512_dpbf16_ps( c_float_3p2, a_bf16_1, b2 ); + c_float_3p3 = _mm512_dpbf16_ps( c_float_3p3, a_bf16_1, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[4,0-63] = a[4,kr:kr+2]*b[kr:kr+2,0-63] + c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); + + // Broadcast a[5,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 5 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_1 = _mm512_set1_epi32( a_kfringe_buf ); + + c_float_4p1 = _mm512_dpbf16_ps( c_float_4p1, a_bf16_0, b1 ); + c_float_4p2 = _mm512_dpbf16_ps( c_float_4p2, a_bf16_0, b2 ); + c_float_4p3 = _mm512_dpbf16_ps( c_float_4p3, a_bf16_0, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[5,0-63] = a[5,kr:kr+2]*b[kr:kr+2,0-63] + c_float_5p0 = _mm512_dpbf16_ps( c_float_5p0, a_bf16_1, b0 ); + c_float_5p1 = _mm512_dpbf16_ps( c_float_5p1, a_bf16_1, b1 ); + c_float_5p2 = _mm512_dpbf16_ps( c_float_5p2, a_bf16_1, b2 ); + c_float_5p3 = _mm512_dpbf16_ps( c_float_5p3, a_bf16_1, b3 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps ( alpha ); + __m512 selector2 = _mm512_set1_ps ( beta ); + + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + c_float_0p1 = _mm512_mul_ps( selector1, c_float_0p1 ); + c_float_0p2 = _mm512_mul_ps( selector1, c_float_0p2 ); + c_float_0p3 = _mm512_mul_ps( selector1, c_float_0p3 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + c_float_1p1 = _mm512_mul_ps( selector1, c_float_1p1 ); + c_float_1p2 = _mm512_mul_ps( selector1, c_float_1p2 ); + c_float_1p3 = _mm512_mul_ps( selector1, c_float_1p3 ); + + c_float_2p0 = _mm512_mul_ps( selector1, c_float_2p0 ); + c_float_2p1 = _mm512_mul_ps( selector1, c_float_2p1 ); + c_float_2p2 = _mm512_mul_ps( selector1, c_float_2p2 ); + c_float_2p3 = _mm512_mul_ps( selector1, c_float_2p3 ); + + c_float_3p0 = _mm512_mul_ps( selector1, c_float_3p0 ); + c_float_3p1 = _mm512_mul_ps( selector1, c_float_3p1 ); + c_float_3p2 = _mm512_mul_ps( selector1, c_float_3p2 ); + c_float_3p3 = _mm512_mul_ps( selector1, c_float_3p3 ); + + c_float_4p0 = _mm512_mul_ps( selector1, c_float_4p0 ); + c_float_4p1 = _mm512_mul_ps( selector1, c_float_4p1 ); + c_float_4p2 = _mm512_mul_ps( selector1, c_float_4p2 ); + c_float_4p3 = _mm512_mul_ps( selector1, c_float_4p3 ); + + c_float_5p0 = _mm512_mul_ps( selector1, c_float_5p0 ); + c_float_5p1 = _mm512_mul_ps( selector1, c_float_5p1 ); + c_float_5p2 = _mm512_mul_ps( selector1, c_float_5p2 ); + c_float_5p3 = _mm512_mul_ps( selector1, c_float_5p3 ); + + // Scale C by beta. 
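/*
 * The alpha/beta epilogue above applies c := alpha*acc + beta*c to the 6x64
 * register block, and only reads memory holding c when beta != 0, so an
 * uninitialized output buffer is never touched. A scalar reference of that
 * update; the function name and the row-major acc layout are assumptions
 * made for illustration.
 */
#include <stddef.h>

void scale_and_update_tile( float* c, size_t rs_c,
                            const float* acc, size_t rows, size_t cols,
                            float alpha, float beta )
{
    for ( size_t i = 0; i < rows; ++i )
    {
        for ( size_t j = 0; j < cols; ++j )
        {
            float v = alpha * acc[ ( i * cols ) + j ];
            if ( beta != 0.0f )
            {
                v += beta * c[ ( i * rs_c ) + j ];  /* load C only when needed */
            }
            c[ ( i * rs_c ) + j ] = v;
        }
    }
}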
+ if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); + + // c[0,48-63] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 3*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p3 = _mm512_add_ps( selector1, c_float_0p3 ); + + // c[1,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p1 = _mm512_add_ps( selector1, c_float_1p1 ); + + // c[1,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p2 = _mm512_add_ps( selector1, c_float_1p2 ); + + // c[1,48-63] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 3*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p3 = _mm512_add_ps( selector1, c_float_1p3 ); + + // c[2,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[2,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p1 = _mm512_add_ps( selector1, c_float_2p1 ); + + // c[2,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p2 = _mm512_add_ps( selector1, c_float_2p2 ); + + // c[2,48-63] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 3*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p3 = _mm512_add_ps( selector1, c_float_2p3 ); + + // c[3,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[3,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_3p1 = _mm512_add_ps( selector1, c_float_3p1 ); + + // c[3,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_3p2 = _mm512_add_ps( selector1, c_float_3p2 ); + + // c[3,48-63] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 3*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_3p3 = _mm512_add_ps( selector1, c_float_3p3 ); + + // c[4,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); + + // c[4,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + 
c_float_4p1 = _mm512_add_ps( selector1, c_float_4p1 ); + + // c[4,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_4p2 = _mm512_add_ps( selector1, c_float_4p2 ); + + // c[4,48-63] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 3*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_4p3 = _mm512_add_ps( selector1, c_float_4p3 ); + + // c[5,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_5p0 = _mm512_add_ps( selector1, c_float_5p0 ); + + // c[5,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_5p1 = _mm512_add_ps( selector1, c_float_5p1 ); + + // c[5,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_5p2 = _mm512_add_ps( selector1, c_float_5p2 ); + + // c[5,48-63] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 3*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_5p3 = _mm512_add_ps( selector1, c_float_5p3 ); + } + + // Store the results. + // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 0*16 ), c_float_0p0 ); + + // c[0, 16-31] + _mm512_storeu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 1*16 ), c_float_0p1 ); + + // c[0,32-47] + _mm512_storeu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 2*16 ), c_float_0p2 ); + + // c[0,48-63] + _mm512_storeu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 3*16 ), c_float_0p3 ); + + // c[1,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 0*16 ), c_float_1p0 ); + + // c[1,16-31] + _mm512_storeu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 1*16 ), c_float_1p1 ); + + // c[1,32-47] + _mm512_storeu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 2*16 ), c_float_1p2 ); + + // c[1,48-63] + _mm512_storeu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 3*16 ), c_float_1p3 ); + + // c[2,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 0*16 ), c_float_2p0 ); + + // c[2,16-31] + _mm512_storeu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 1*16 ), c_float_2p1 ); + + // c[2,32-47] + _mm512_storeu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 2*16 ), c_float_2p2 ); + + // c[2,48-63] + _mm512_storeu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 3*16 ), c_float_2p3 ); + + // c[3,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 0*16 ), c_float_3p0 ); + + // c[3,16-31] + _mm512_storeu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 1*16 ), c_float_3p1 ); + + // c[3,32-47] + _mm512_storeu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 2*16 ), c_float_3p2 ); + + // c[3,48-63] + _mm512_storeu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 3*16 ), c_float_3p3 ); + + // c[4,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 0*16 ), c_float_4p0 ); + + // c[4,16-31] + _mm512_storeu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 1*16 ), c_float_4p1 ); + + // c[4,32-47] + _mm512_storeu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 2*16 ), c_float_4p2 ); + + // c[4,48-63] + _mm512_storeu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 3*16 ), c_float_4p3 ); + + // c[5,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 0*16 ), c_float_5p0 ); + + // c[5,16-31] + _mm512_storeu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 1*16 ), c_float_5p1 ); + + // c[5,32-47] + _mm512_storeu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 2*16 ), c_float_5p2 ); + + // c[5,48-63] + _mm512_storeu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 3*16 ), c_float_5p3 ); + + a = a + ( MR * ps_a ); + } + + if ( m_partial_pieces > 0 ) + { + if ( m_partial_pieces == 5 ) + 
{ + // In cases where A matrix is packed cs_a is set to 12, since the + // next column in a given row is accessed after 2*6 elements, where + // 6 is MR and 2 elements are broadcasted each time from A (bf16). + // In fringe case, where m < MR, the next column will be after m'*2 + // elements, and subsequently following adjustment of cs_a is + // required before calling m fringe kernels. + dim_t cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 5 ); + lpgemm_rowvar_bf16bf16f32of32_5x64 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list + ); + } + else if ( m_partial_pieces == 4 ) + { + dim_t cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 4 ); + lpgemm_rowvar_bf16bf16f32of32_4x64 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list + ); + } + else if ( m_partial_pieces == 3 ) + { + dim_t cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 3 ); + lpgemm_rowvar_bf16bf16f32of32_3x64 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list + ); + } + else if ( m_partial_pieces == 2 ) + { + dim_t cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 2 ); + lpgemm_rowvar_bf16bf16f32of32_2x64 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list + ); + } + else if ( m_partial_pieces == 1 ) + { + dim_t cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 1 ); + lpgemm_rowvar_bf16bf16f32of32_1x64 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list + ); + } + } +} diff --git a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_m_fringe_bf16_amd512vnni.c b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_m_fringe_bf16_amd512vnni.c new file mode 100644 index 0000000000..9d1e742bea --- /dev/null +++ b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_m_fringe_bf16_amd512vnni.c @@ -0,0 +1,1313 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
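/*
 * Worked example of the cs_a adjustment performed by the m-fringe dispatch
 * above: with MR = 6 and a packed A that stores one bf16 pair per row per k
 * step, full panels use cs_a = 2*6 = 12, an m'-row fringe panel uses 2*m',
 * and an unpacked row-major A (cs_a == 2) is left untouched. The helper
 * name below is illustrative only.
 */
#include <stdio.h>

static long fringe_cs_a( long cs_a, long m_left )
{
    return ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * m_left );
}

int main( void )
{
    printf( "packed A,   5 rows left: cs_a_use = %ld\n", fringe_cs_a( 12, 5 ) ); /* 10 */
    printf( "packed A,   2 rows left: cs_a_use = %ld\n", fringe_cs_a( 12, 2 ) ); /* 4  */
    printf( "unpacked A, 3 rows left: cs_a_use = %ld\n", fringe_cs_a( 2, 3 ) );  /* 2  */
    return 0;
}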
IN NO EVENT SHALL THE COPYRIGHT
+   HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+   SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+   LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+   DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+   THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+   (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+   OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+*/
+
+#include <immintrin.h>
+#include <string.h>
+
+#include "blis.h"
+#include "lpgemm_kernels.h"
+
+// 5x64 bf16 kernel
+LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x64)
+{
+    dim_t k_full_pieces = k0 / 2;
+    dim_t k_partial_pieces = k0 % 2;
+
+    int32_t a_kfringe_buf = 0;
+
+    // B matrix storage bfloat type
+    __m512bh b0;
+    __m512bh b1;
+    __m512bh b2;
+    __m512bh b3;
+
+    // A matrix storage bfloat type
+    __m512bh a_bf16_0;
+    __m512bh a_bf16_1;
+
+    // Registers to use for accumulating C.
+    __m512 c_float_0p0 = _mm512_setzero_ps();
+    __m512 c_float_0p1 = _mm512_setzero_ps();
+    __m512 c_float_0p2 = _mm512_setzero_ps();
+    __m512 c_float_0p3 = _mm512_setzero_ps();
+
+    __m512 c_float_1p0 = _mm512_setzero_ps();
+    __m512 c_float_1p1 = _mm512_setzero_ps();
+    __m512 c_float_1p2 = _mm512_setzero_ps();
+    __m512 c_float_1p3 = _mm512_setzero_ps();
+
+    __m512 c_float_2p0 = _mm512_setzero_ps();
+    __m512 c_float_2p1 = _mm512_setzero_ps();
+    __m512 c_float_2p2 = _mm512_setzero_ps();
+    __m512 c_float_2p3 = _mm512_setzero_ps();
+
+    __m512 c_float_3p0 = _mm512_setzero_ps();
+    __m512 c_float_3p1 = _mm512_setzero_ps();
+    __m512 c_float_3p2 = _mm512_setzero_ps();
+    __m512 c_float_3p3 = _mm512_setzero_ps();
+
+    __m512 c_float_4p0 = _mm512_setzero_ps();
+    __m512 c_float_4p1 = _mm512_setzero_ps();
+    __m512 c_float_4p2 = _mm512_setzero_ps();
+    __m512 c_float_4p3 = _mm512_setzero_ps();
+
+    for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 )
+    {
+        b0 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) );
+
+        // Broadcast a[0,kr:kr+2].
+        a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) );
+
+        b1 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) );
+        b2 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 2 ) );
+        b3 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 3 ) );
+
+        // Perform column direction mat-mul with k = 2.
+        // c[0,0-63] = a[0,kr:kr+2]*b[kr:kr+2,0-63]
+        c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 );
+
+        // Broadcast a[1,kr:kr+2].
+        a_bf16_1 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) );
+
+        c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 );
+        c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 );
+        c_float_0p3 = _mm512_dpbf16_ps( c_float_0p3, a_bf16_0, b3 );
+
+        // Perform column direction mat-mul with k = 2.
+        // c[1,0-63] = a[1,kr:kr+2]*b[kr:kr+2,0-63]
+        c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_1, b0 );
+
+        // Broadcast a[2,kr:kr+2].
+        a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) );
+
+        c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_1, b1 );
+        c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_1, b2 );
+        c_float_1p3 = _mm512_dpbf16_ps( c_float_1p3, a_bf16_1, b3 );
+
+        // Perform column direction mat-mul with k = 2.
+        // c[2,0-63] = a[2,kr:kr+2]*b[kr:kr+2,0-63]
+        c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 );
+
+        // Broadcast a[3,kr:kr+2].
+ a_bf16_1 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + c_float_2p2 = _mm512_dpbf16_ps( c_float_2p2, a_bf16_0, b2 ); + c_float_2p3 = _mm512_dpbf16_ps( c_float_2p3, a_bf16_0, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-63] = a[3,kr:kr+2]*b[kr:kr+2,0-63] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_1, b0 ); + + // Broadcast a[4,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); + + c_float_3p1 = _mm512_dpbf16_ps( c_float_3p1, a_bf16_1, b1 ); + c_float_3p2 = _mm512_dpbf16_ps( c_float_3p2, a_bf16_1, b2 ); + c_float_3p3 = _mm512_dpbf16_ps( c_float_3p3, a_bf16_1, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[4,0-63] = a[4,kr:kr+2]*b[kr:kr+2,0-63] + c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); + c_float_4p1 = _mm512_dpbf16_ps( c_float_4p1, a_bf16_0, b1 ); + c_float_4p2 = _mm512_dpbf16_ps( c_float_4p2, a_bf16_0, b2 ); + c_float_4p3 = _mm512_dpbf16_ps( c_float_4p3, a_bf16_0, b3 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + b1 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + b3 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 3 ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-63] = a[0,kr:kr+2]*b[kr:kr+2,0-63] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_1 = _mm512_set1_epi32( a_kfringe_buf ); + + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + c_float_0p3 = _mm512_dpbf16_ps( c_float_0p3, a_bf16_0, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-63] = a[1,kr:kr+2]*b[kr:kr+2,0-63] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_1, b0 ); + + // Broadcast a[2,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_1, b1 ); + c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_1, b2 ); + c_float_1p3 = _mm512_dpbf16_ps( c_float_1p3, a_bf16_1, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-63] = a[2,kr:kr+2]*b[kr:kr+2,0-63] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + + // Broadcast a[3,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_1 = _mm512_set1_epi32( a_kfringe_buf ); + + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + c_float_2p2 = _mm512_dpbf16_ps( c_float_2p2, a_bf16_0, b2 ); + c_float_2p3 = _mm512_dpbf16_ps( c_float_2p3, a_bf16_0, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-63] = a[3,kr:kr+2]*b[kr:kr+2,0-63] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_1, b0 ); + + // Broadcast a[4,kr:kr+2]. 
+ memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 4 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + c_float_3p1 = _mm512_dpbf16_ps( c_float_3p1, a_bf16_1, b1 ); + c_float_3p2 = _mm512_dpbf16_ps( c_float_3p2, a_bf16_1, b2 ); + c_float_3p3 = _mm512_dpbf16_ps( c_float_3p3, a_bf16_1, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[4,0-63] = a[4,kr:kr+2]*b[kr:kr+2,0-63] + c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); + c_float_4p1 = _mm512_dpbf16_ps( c_float_4p1, a_bf16_0, b1 ); + c_float_4p2 = _mm512_dpbf16_ps( c_float_4p2, a_bf16_0, b2 ); + c_float_4p3 = _mm512_dpbf16_ps( c_float_4p3, a_bf16_0, b3 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + c_float_0p1 = _mm512_mul_ps( selector1, c_float_0p1 ); + c_float_0p2 = _mm512_mul_ps( selector1, c_float_0p2 ); + c_float_0p3 = _mm512_mul_ps( selector1, c_float_0p3 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + c_float_1p1 = _mm512_mul_ps( selector1, c_float_1p1 ); + c_float_1p2 = _mm512_mul_ps( selector1, c_float_1p2 ); + c_float_1p3 = _mm512_mul_ps( selector1, c_float_1p3 ); + + c_float_2p0 = _mm512_mul_ps( selector1, c_float_2p0 ); + c_float_2p1 = _mm512_mul_ps( selector1, c_float_2p1 ); + c_float_2p2 = _mm512_mul_ps( selector1, c_float_2p2 ); + c_float_2p3 = _mm512_mul_ps( selector1, c_float_2p3 ); + + c_float_3p0 = _mm512_mul_ps( selector1, c_float_3p0 ); + c_float_3p1 = _mm512_mul_ps( selector1, c_float_3p1 ); + c_float_3p2 = _mm512_mul_ps( selector1, c_float_3p2 ); + c_float_3p3 = _mm512_mul_ps( selector1, c_float_3p3 ); + + c_float_4p0 = _mm512_mul_ps( selector1, c_float_4p0 ); + c_float_4p1 = _mm512_mul_ps( selector1, c_float_4p1 ); + c_float_4p2 = _mm512_mul_ps( selector1, c_float_4p2 ); + c_float_4p3 = _mm512_mul_ps( selector1, c_float_4p3 ); + + // Scale C by beta. 
+ if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); + + // c[0,48-63] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 3*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p3 = _mm512_add_ps( selector1, c_float_0p3 ); + + // c[1,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p1 = _mm512_add_ps( selector1, c_float_1p1 ); + + // c[1,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p2 = _mm512_add_ps( selector1, c_float_1p2 ); + + // c[1,48-63] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 3*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p3 = _mm512_add_ps( selector1, c_float_1p3 ); + + // c[2,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 2 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[2,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 2 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p1 = _mm512_add_ps( selector1, c_float_2p1 ); + + // c[2,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * 2 ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p2 = _mm512_add_ps( selector1, c_float_2p2 ); + + // c[2,48-63] + selector1 = _mm512_loadu_ps( c + ( rs_c * 2 ) + ( 3*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p3 = _mm512_add_ps( selector1, c_float_2p3 ); + + // c[3,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 3 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[3,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 3 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_3p1 = _mm512_add_ps( selector1, c_float_3p1 ); + + // c[3,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * 3 ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_3p2 = _mm512_add_ps( selector1, c_float_3p2 ); + + // c[3,48-63] + selector1 = _mm512_loadu_ps( c + ( rs_c * 3 ) + ( 3*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_3p3 = _mm512_add_ps( selector1, c_float_3p3 ); + + // c[4,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 4 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); + + // c[4,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 4 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_4p1 = _mm512_add_ps( selector1, c_float_4p1 ); + + // c[4,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * 4 ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( 
selector2, selector1 ); + c_float_4p2 = _mm512_add_ps( selector1, c_float_4p2 ); + + // c[4,48-63] + selector1 = _mm512_loadu_ps( c + ( rs_c * 4 ) + ( 3*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_4p3 = _mm512_add_ps( selector1, c_float_4p3 ); + } + + // Store the results. + // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), c_float_0p0 ); + + // c[0, 16-31] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 1*16 ), c_float_0p1 ); + + // c[0,32-47] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 2*16 ), c_float_0p2 ); + + // c[0,48-63] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 3*16 ), c_float_0p3 ); + + // c[1,0-15] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 0*16 ), c_float_1p0 ); + + // c[1,16-31] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 1*16 ), c_float_1p1 ); + + // c[1,32-47] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 2*16 ), c_float_1p2 ); + + // c[1,48-63] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 3*16 ), c_float_1p3 ); + + // c[2,0-15] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 0*16 ), c_float_2p0 ); + + // c[2,16-31] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 1*16 ), c_float_2p1 ); + + // c[2,32-47] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 2*16 ), c_float_2p2 ); + + // c[2,48-63] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 3*16 ), c_float_2p3 ); + + // c[3,0-15] + _mm512_storeu_ps( c + ( rs_c * 3 ) + ( 0*16 ), c_float_3p0 ); + + // c[3,16-31] + _mm512_storeu_ps( c + ( rs_c * 3 ) + ( 1*16 ), c_float_3p1 ); + + // c[3,32-47] + _mm512_storeu_ps( c + ( rs_c * 3 ) + ( 2*16 ), c_float_3p2 ); + + // c[3,48-63] + _mm512_storeu_ps( c + ( rs_c * 3 ) + ( 3*16 ), c_float_3p3 ); + + // c[4,0-15] + _mm512_storeu_ps( c + ( rs_c * 4 ) + ( 0*16 ), c_float_4p0 ); + + // c[4,16-31] + _mm512_storeu_ps( c + ( rs_c * 4 ) + ( 1*16 ), c_float_4p1 ); + + // c[4,32-47] + _mm512_storeu_ps( c + ( rs_c * 4 ) + ( 2*16 ), c_float_4p2 ); + + // c[4,48-63] + _mm512_storeu_ps( c + ( rs_c * 4 ) + ( 3*16 ), c_float_4p3 ); +} + +// 4x64 bf16 kernel +LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x64) +{ + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int32_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + __m512bh b1; + __m512bh b2; + __m512bh b3; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + __m512bh a_bf16_1; + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + __m512 c_float_0p1 = _mm512_setzero_ps(); + __m512 c_float_0p2 = _mm512_setzero_ps(); + __m512 c_float_0p3 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + __m512 c_float_1p1 = _mm512_setzero_ps(); + __m512 c_float_1p2 = _mm512_setzero_ps(); + __m512 c_float_1p3 = _mm512_setzero_ps(); + + __m512 c_float_2p0 = _mm512_setzero_ps(); + __m512 c_float_2p1 = _mm512_setzero_ps(); + __m512 c_float_2p2 = _mm512_setzero_ps(); + __m512 c_float_2p3 = _mm512_setzero_ps(); + + __m512 c_float_3p0 = _mm512_setzero_ps(); + __m512 c_float_3p1 = _mm512_setzero_ps(); + __m512 c_float_3p2 = _mm512_setzero_ps(); + __m512 c_float_3p3 = _mm512_setzero_ps(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + b1 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + b3 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 3 ) ); + + // Perform column direction mat-mul with k = 2. 
+ // c[0,0-63] = a[0,kr:kr+4]*b[kr:kr+4,0-63] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_1 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + c_float_0p3 = _mm512_dpbf16_ps( c_float_0p3, a_bf16_0, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-63] = a[1,kr:kr+2]*b[kr:kr+2,0-63] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_1, b0 ); + + // Broadcast a[2,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_1, b1 ); + c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_1, b2 ); + c_float_1p3 = _mm512_dpbf16_ps( c_float_1p3, a_bf16_1, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-63] = a[2,kr:kr+2]*b[kr:kr+2,0-63] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + + // Broadcast a[3,kr:kr+2]. + a_bf16_1 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + c_float_2p2 = _mm512_dpbf16_ps( c_float_2p2, a_bf16_0, b2 ); + c_float_2p3 = _mm512_dpbf16_ps( c_float_2p3, a_bf16_0, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-63] = a[3,kr:kr+2]*b[kr:kr+2,0-63] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_1, b0 ); + c_float_3p1 = _mm512_dpbf16_ps( c_float_3p1, a_bf16_1, b1 ); + c_float_3p2 = _mm512_dpbf16_ps( c_float_3p2, a_bf16_1, b2 ); + c_float_3p3 = _mm512_dpbf16_ps( c_float_3p3, a_bf16_1, b3 ); + } + + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16) ) + ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + b1 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + b3 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 3 ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-63] = a[0,kr:kr+2]*b[kr:kr+2,0-63] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_1 = _mm512_set1_epi32( a_kfringe_buf ); + + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + c_float_0p3 = _mm512_dpbf16_ps( c_float_0p3, a_bf16_0, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-63] = a[1,kr:kr+2]*b[kr:kr+2,0-63] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_1, b0 ); + + // Broadcast a[2,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_1, b1 ); + c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_1, b2 ); + c_float_1p3 = _mm512_dpbf16_ps( c_float_1p3, a_bf16_1, b3 ); + + // Perform column direction mat-mul with k = 2. 
+ // c[2,0-63] = a[2,kr:kr+2]*b[kr:kr+2,0-63] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + + // Broadcast a[3,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_1 = _mm512_set1_epi32( a_kfringe_buf ); + + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + c_float_2p2 = _mm512_dpbf16_ps( c_float_2p2, a_bf16_0, b2 ); + c_float_2p3 = _mm512_dpbf16_ps( c_float_2p3, a_bf16_0, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-63] = a[3,kr:kr+2]*b[kr:kr+2,0-63] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_1, b0 ); + c_float_3p1 = _mm512_dpbf16_ps( c_float_3p1, a_bf16_1, b1 ); + c_float_3p2 = _mm512_dpbf16_ps( c_float_3p2, a_bf16_1, b2 ); + c_float_3p3 = _mm512_dpbf16_ps( c_float_3p3, a_bf16_1, b3 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + c_float_0p1 = _mm512_mul_ps( selector1, c_float_0p1 ); + c_float_0p2 = _mm512_mul_ps( selector1, c_float_0p2 ); + c_float_0p3 = _mm512_mul_ps( selector1, c_float_0p3 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + c_float_1p1 = _mm512_mul_ps( selector1, c_float_1p1 ); + c_float_1p2 = _mm512_mul_ps( selector1, c_float_1p2 ); + c_float_1p3 = _mm512_mul_ps( selector1, c_float_1p3 ); + + c_float_2p0 = _mm512_mul_ps( selector1, c_float_2p0 ); + c_float_2p1 = _mm512_mul_ps( selector1, c_float_2p1 ); + c_float_2p2 = _mm512_mul_ps( selector1, c_float_2p2 ); + c_float_2p3 = _mm512_mul_ps( selector1, c_float_2p3 ); + + c_float_3p0 = _mm512_mul_ps( selector1, c_float_3p0 ); + c_float_3p1 = _mm512_mul_ps( selector1, c_float_3p1 ); + c_float_3p2 = _mm512_mul_ps( selector1, c_float_3p2 ); + c_float_3p3 = _mm512_mul_ps( selector1, c_float_3p3 ); + + // Scale C by beta. 
+ if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); + + // c[0,48-63] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 3*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p3 = _mm512_add_ps( selector1, c_float_0p3 ); + + // c[1,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p1 = _mm512_add_ps( selector1, c_float_1p1 ); + + // c[1,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p2 = _mm512_add_ps( selector1, c_float_1p2 ); + + // c[1,48-63] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 3*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p3 = _mm512_add_ps( selector1, c_float_1p3 ); + + // c[2,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 2 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[2,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 2 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p1 = _mm512_add_ps( selector1, c_float_2p1 ); + + // c[2,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * 2 ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p2 = _mm512_add_ps( selector1, c_float_2p2 ); + + // c[2,48-63] + selector1 = _mm512_loadu_ps( c + ( rs_c * 2 ) + ( 3*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p3 = _mm512_add_ps( selector1, c_float_2p3 ); + + // c[3,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 3 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[3,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 3 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_3p1 = _mm512_add_ps( selector1, c_float_3p1 ); + + // c[3,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * 3 ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_3p2 = _mm512_add_ps( selector1, c_float_3p2 ); + + // c[3,48-63] + selector1 = _mm512_loadu_ps( c + ( rs_c * 3 ) + ( 3*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_3p3 = _mm512_add_ps( selector1, c_float_3p3 ); + } + + // Store the results. 
+ // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), c_float_0p0 ); + + // c[0, 16-31] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 1*16 ), c_float_0p1 ); + + // c[0,32-47] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 2*16 ), c_float_0p2 ); + + // c[0,48-63] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 3*16 ), c_float_0p3 ); + + // c[1,0-15] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 0*16 ), c_float_1p0 ); + + // c[1,16-31] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 1*16 ), c_float_1p1 ); + + // c[1,32-47] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 2*16 ), c_float_1p2 ); + + // c[1,48-63] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 3*16 ), c_float_1p3 ); + + // c[2,0-15] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 0*16 ), c_float_2p0 ); + + // c[2,16-31] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 1*16 ), c_float_2p1 ); + + // c[2,32-47] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 2*16 ), c_float_2p2 ); + + // c[2,48-63] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 3*16 ), c_float_2p3 ); + + // c[3,0-15] + _mm512_storeu_ps( c + ( rs_c * 3 ) + ( 0*16 ), c_float_3p0 ); + + // c[3,16-31] + _mm512_storeu_ps( c + ( rs_c * 3 ) + ( 1*16 ), c_float_3p1 ); + + // c[3,32-47] + _mm512_storeu_ps( c + ( rs_c * 3 ) + ( 2*16 ), c_float_3p2 ); + + // c[3,48-63] + _mm512_storeu_ps( c + ( rs_c * 3 ) + ( 3*16 ), c_float_3p3 ); +} + +// 3x64 bf16 kernel +LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x64) +{ + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int32_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + __m512bh b1; + __m512bh b2; + __m512bh b3; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + __m512bh a_bf16_1; + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + __m512 c_float_0p1 = _mm512_setzero_ps(); + __m512 c_float_0p2 = _mm512_setzero_ps(); + __m512 c_float_0p3 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + __m512 c_float_1p1 = _mm512_setzero_ps(); + __m512 c_float_1p2 = _mm512_setzero_ps(); + __m512 c_float_1p3 = _mm512_setzero_ps(); + + __m512 c_float_2p0 = _mm512_setzero_ps(); + __m512 c_float_2p1 = _mm512_setzero_ps(); + __m512 c_float_2p2 = _mm512_setzero_ps(); + __m512 c_float_2p3 = _mm512_setzero_ps(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + b1 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + b3 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 3 ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-63] = a[0,kr:kr+2]*b[kr:kr+2,0-63] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_1 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + c_float_0p3 = _mm512_dpbf16_ps( c_float_0p3, a_bf16_0, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-63] = a[1,kr:kr+2]*b[kr:kr+2,0-63] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_1, b0 ); + + // Broadcast a[2,kr:kr+4]. 
+ a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_1, b1 ); + c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_1, b2 ); + c_float_1p3 = _mm512_dpbf16_ps( c_float_1p3, a_bf16_1, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-63] = a[2,kr:kr+2]*b[kr:kr+2,0-63] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + c_float_2p2 = _mm512_dpbf16_ps( c_float_2p2, a_bf16_0, b2 ); + c_float_2p3 = _mm512_dpbf16_ps( c_float_2p3, a_bf16_0, b3 ); + } + + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + b1 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + b3 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 3 ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-63] = a[0,kr:kr+2]*b[kr:kr+2,0-63] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + __m512i a_bf16_1 = _mm512_set1_epi32( a_kfringe_buf ); + + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + c_float_0p3 = _mm512_dpbf16_ps( c_float_0p3, a_bf16_0, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-63] = a[1,kr:kr+2]*b[kr:kr+2,0-63] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_1, b0 ); + + // Broadcast a[2,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_1, b1 ); + c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_1, b2 ); + c_float_1p3 = _mm512_dpbf16_ps( c_float_1p3, a_bf16_1, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-63] = a[2,kr:kr+2]*b[kr:kr+2,0-63] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + c_float_2p2 = _mm512_dpbf16_ps( c_float_2p2, a_bf16_0, b2 ); + c_float_2p3 = _mm512_dpbf16_ps( c_float_2p3, a_bf16_0, b3 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + c_float_0p1 = _mm512_mul_ps( selector1, c_float_0p1 ); + c_float_0p2 = _mm512_mul_ps( selector1, c_float_0p2 ); + c_float_0p3 = _mm512_mul_ps( selector1, c_float_0p3 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + c_float_1p1 = _mm512_mul_ps( selector1, c_float_1p1 ); + c_float_1p2 = _mm512_mul_ps( selector1, c_float_1p2 ); + c_float_1p3 = _mm512_mul_ps( selector1, c_float_1p3 ); + + c_float_2p0 = _mm512_mul_ps( selector1, c_float_2p0 ); + c_float_2p1 = _mm512_mul_ps( selector1, c_float_2p1 ); + c_float_2p2 = _mm512_mul_ps( selector1, c_float_2p2 ); + c_float_2p3 = _mm512_mul_ps( selector1, c_float_2p3 ); + + // Scale C by beta. 
+ if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); + + // c[0,48-63] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 3*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p3 = _mm512_add_ps( selector1, c_float_0p3 ); + + // c[1,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p1 = _mm512_add_ps( selector1, c_float_1p1 ); + + // c[1,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p2 = _mm512_add_ps( selector1, c_float_1p2 ); + + // c[1,48-63] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 3*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p3 = _mm512_add_ps( selector1, c_float_1p3 ); + + // c[2,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 2 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[2,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 2 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p1 = _mm512_add_ps( selector1, c_float_2p1 ); + + // c[2,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * 2 ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p2 = _mm512_add_ps( selector1, c_float_2p2 ); + + // c[2,48-63] + selector1 = _mm512_loadu_ps( c + ( rs_c * 2 ) + ( 3*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p3 = _mm512_add_ps( selector1, c_float_2p3 ); + } + + // Store the results. 
+ // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), c_float_0p0 ); + + // c[0, 16-31] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 1*16 ), c_float_0p1 ); + + // c[0,32-47] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 2*16 ), c_float_0p2 ); + + // c[0,48-63] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 3*16 ), c_float_0p3 ); + + // c[1,0-15] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 0*16 ), c_float_1p0 ); + + // c[1,16-31] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 1*16 ), c_float_1p1 ); + + // c[1,32-47] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 2*16 ), c_float_1p2 ); + + // c[1,48-63] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 3*16 ), c_float_1p3 ); + + // c[2,0-15] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 0*16 ), c_float_2p0 ); + + // c[2,16-31] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 1*16 ), c_float_2p1 ); + + // c[2,32-47] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 2*16 ), c_float_2p2 ); + + // c[2,48-63] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 3*16 ), c_float_2p3 ); +} + +// 2x64 bf16 kernel +LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x64) +{ + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int32_t a_kfringe_buf = 0; + // B matrix storage bfloat type + __m512bh b0; + __m512bh b1; + __m512bh b2; + __m512bh b3; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + __m512bh a_bf16_1; + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + __m512 c_float_0p1 = _mm512_setzero_ps(); + __m512 c_float_0p2 = _mm512_setzero_ps(); + __m512 c_float_0p3 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + __m512 c_float_1p1 = _mm512_setzero_ps(); + __m512 c_float_1p2 = _mm512_setzero_ps(); + __m512 c_float_1p3 = _mm512_setzero_ps(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + b1 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + b3 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 3 ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-63] = a[0,kr:kr+2]*b[kr:kr+2,0-63] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_1 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + c_float_0p3 = _mm512_dpbf16_ps( c_float_0p3, a_bf16_0, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-63] = a[1,kr:kr+2]*b[kr:kr+2,0-63] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_1, b0 ); + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_1, b1 ); + c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_1, b2 ); + c_float_1p3 = _mm512_dpbf16_ps( c_float_1p3, a_bf16_1, b3 ); + } + + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + __m512i b0 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+2]. 
+ memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + __m512i a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + __m512i b1 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + __m512i b2 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + __m512i b3 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 3 ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-63] = a[0,kr:kr+2]*b[kr:kr+2,0-63] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + __m512i a_bf16_1 = _mm512_set1_epi32( a_kfringe_buf ); + + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + c_float_0p3 = _mm512_dpbf16_ps( c_float_0p3, a_bf16_0, b3 ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-63] = a[1,kr:kr+2]*b[kr:kr+2,0-63] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_1, b0 ); + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_1, b1 ); + c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_1, b2 ); + c_float_1p3 = _mm512_dpbf16_ps( c_float_1p3, a_bf16_1, b3 ); + } + + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + c_float_0p1 = _mm512_mul_ps( selector1, c_float_0p1 ); + c_float_0p2 = _mm512_mul_ps( selector1, c_float_0p2 ); + c_float_0p3 = _mm512_mul_ps( selector1, c_float_0p3 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + c_float_1p1 = _mm512_mul_ps( selector1, c_float_1p1 ); + c_float_1p2 = _mm512_mul_ps( selector1, c_float_1p2 ); + c_float_1p3 = _mm512_mul_ps( selector1, c_float_1p3 ); + + // Scale C by beta. + if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); + + // c[0,48-63] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 3*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p3 = _mm512_add_ps( selector1, c_float_0p3 ); + + // c[1,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p1 = _mm512_add_ps( selector1, c_float_1p1 ); + + // c[1,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p2 = _mm512_add_ps( selector1, c_float_1p2 ); + + // c[1,48-63] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 3*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p3 = _mm512_add_ps( selector1, c_float_1p3 ); + } + + // Store the results. 
+ // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), c_float_0p0 ); + + // c[0, 16-31] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 1*16 ), c_float_0p1 ); + + // c[0,32-47] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 2*16 ), c_float_0p2 ); + + // c[0,48-63] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 3*16 ), c_float_0p3 ); + + // c[1,0-15] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 0*16 ), c_float_1p0 ); + + // c[1,16-31] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 1*16 ), c_float_1p1 ); + + // c[1,32-47] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 2*16 ), c_float_1p2 ); + + // c[1,48-63] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 3*16 ), c_float_1p3 ); +} + +// 1x64 bf16 kernel +LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1x64) +{ + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int32_t a_kfringe_buf = 0; + + // Registers to use for accumulating C. + __m512 c_float32_0p0 = _mm512_setzero_ps(); + __m512 c_float32_0p1 = _mm512_setzero_ps(); + __m512 c_float32_0p2 = _mm512_setzero_ps(); + __m512 c_float32_0p3 = _mm512_setzero_ps(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + __m512bh b0 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr] + __m512bh a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + __m512bh b1 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + __m512bh b2 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + __m512bh b3 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 3 ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-63] = a[0,kr:kr+2]*b[kr:kr+2,0-63] + c_float32_0p0 = _mm512_dpbf16_ps( c_float32_0p0, a_bf16_0, b0 ); + c_float32_0p1 = _mm512_dpbf16_ps( c_float32_0p1, a_bf16_0, b1 ); + c_float32_0p2 = _mm512_dpbf16_ps( c_float32_0p2, a_bf16_0, b2 ); + c_float32_0p3 = _mm512_dpbf16_ps( c_float32_0p3, a_bf16_0, b3 ); + } + + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + __m512i b0 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + __m512i a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + __m512i b1 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + __m512i b2 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + __m512i b3 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 3 ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-63] = a[0,kr:kr+2]*b[kr:kr+2,0-63] + c_float32_0p0 = _mm512_dpbf16_ps( c_float32_0p0, a_bf16_0, b0 ); + c_float32_0p1 = _mm512_dpbf16_ps( c_float32_0p1, a_bf16_0, b1 ); + c_float32_0p2 = _mm512_dpbf16_ps( c_float32_0p2, a_bf16_0, b2 ); + c_float32_0p3 = _mm512_dpbf16_ps( c_float32_0p3, a_bf16_0, b3 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + // Scale by alpha + c_float32_0p0 = _mm512_mul_ps( selector1, c_float32_0p0 ); + c_float32_0p1 = _mm512_mul_ps( selector1, c_float32_0p1 ); + c_float32_0p2 = _mm512_mul_ps( selector1, c_float32_0p2 ); + c_float32_0p3 = _mm512_mul_ps( selector1, c_float32_0p3 ); + + // Scale C by beta. 
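+ // Per element the epilogue computes c = alpha*(A*B) + beta*c. When beta
+ // is zero, C is never read, so it need not hold valid data on entry.
+ // Scalar sketch of the update (illustrative only):
+ //   c[i][j] = alpha * acc[i][j] + ( ( beta != 0 ) ? beta * c[i][j] : 0.0f );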
+ if ( beta != 0) + { + // c[0,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float32_0p0 = _mm512_add_ps( selector1, c_float32_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float32_0p1 = _mm512_add_ps( selector1, c_float32_0p1 ); + + // c[0,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float32_0p2 = _mm512_add_ps( selector1, c_float32_0p2 ); + + // c[0,48-63] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 3*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float32_0p3 = _mm512_add_ps( selector1, c_float32_0p3 ); + } + + // Store the accumulated results. + // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), c_float32_0p0 ); + + // c[0, 16-31] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 1*16 ), c_float32_0p1 ); + + // c[0,32-47] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 2*16 ), c_float32_0p2 ); + + // c[0,48-63] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 3*16 ), c_float32_0p3 ); +} diff --git a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_mn_fringe_bf16_amd512vnni.c b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_mn_fringe_bf16_amd512vnni.c new file mode 100644 index 0000000000..978b7944c1 --- /dev/null +++ b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_mn_fringe_bf16_amd512vnni.c @@ -0,0 +1,3069 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+
+*/
+
+#include <immintrin.h>
+#include <string.h>
+
+#include "blis.h"
+#include "lpgemm_kernels.h"
+
+// 5xlt16 bf16 fringe kernel
+LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5xlt16)
+{
+ dim_t k_full_pieces = k0 / 2;
+ dim_t k_partial_pieces = k0 % 2;
+
+ int32_t a_kfringe_buf = 0;
+
+ // B matrix storage bfloat type
+ __m512bh b0;
+
+ // A matrix storage bfloat type
+ __m512bh a_bf16_0;
+
+ // For corner cases.
+ float buf0[16];
+ float buf1[16];
+ float buf2[16];
+ float buf3[16];
+ float buf4[16];
+
+ // Registers to use for accumulating C.
+ __m512 c_float_0p0 = _mm512_setzero_ps();
+
+ __m512 c_float_1p0 = _mm512_setzero_ps();
+
+ __m512 c_float_2p0 = _mm512_setzero_ps();
+
+ __m512 c_float_3p0 = _mm512_setzero_ps();
+
+ __m512 c_float_4p0 = _mm512_setzero_ps();
+
+ for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 )
+ {
+ b0 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) );
+
+ // Broadcast a[0,kr:kr+2].
+ a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) );
+
+ // Perform column direction mat-mul with k = 2.
+ // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15]
+ c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 );
+
+ // Broadcast a[1,kr:kr+2].
+ a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) );
+
+ // Perform column direction mat-mul with k = 2.
+ // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15]
+ c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 );
+
+ // Broadcast a[2,kr:kr+2].
+ a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) );
+
+ // Perform column direction mat-mul with k = 2.
+ // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15]
+ c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 );
+
+ // Broadcast a[3,kr:kr+2].
+ a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) );
+
+ // Perform column direction mat-mul with k = 2.
+ // c[3,0-15] = a[3,kr:kr+2]*b[kr:kr+2,0-15]
+ c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 );
+
+ // Broadcast a[4,kr:kr+2].
+ a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) );
+
+ // Perform column direction mat-mul with k = 2.
+ // c[4,0-15] = a[4,kr:kr+2]*b[kr:kr+2,0-15]
+ c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 );
+ }
+ // Handle k remainder.
+ if ( k_partial_pieces > 0 )
+ {
+ b0 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) );
+
+ // Broadcast a[0,kr:kr+2].
+ memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) );
+ a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf );
+
+ // Perform column direction mat-mul with k = 2.
+ // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15]
+ c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 );
+
+ // Broadcast a[1,kr:kr+2].
+ memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) );
+ a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf );
+
+ // Perform column direction mat-mul with k = 2.
+ // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15]
+ c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 );
+
+ // Broadcast a[2,kr:kr+2].
+ memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) );
+ a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf );
+
+ // Perform column direction mat-mul with k = 2.
+ // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15]
+ c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 );
+
+ // Broadcast a[3,kr:kr+2].
+ memcpy( &a_kfringe_buf, ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-15] = a[3,kr:kr+2]*b[kr:kr+2,0-15] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + + // Broadcast a[4,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 4 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[4,0-15] = a[4,kr:kr+2]*b[kr:kr+2,0-15] + c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + + c_float_2p0 = _mm512_mul_ps( selector1, c_float_2p0 ); + + c_float_3p0 = _mm512_mul_ps( selector1, c_float_3p0 ); + + c_float_4p0 = _mm512_mul_ps( selector1, c_float_4p0 ); + + // Scale C by beta. + if ( beta != 0 ) + { + memcpy( buf0, ( c + ( rs_c * 0 ) ), ( n0_rem * sizeof( float ) ) ); + memcpy( buf1, ( c + ( rs_c * 1 ) ), ( n0_rem * sizeof( float ) ) ); + memcpy( buf2, ( c + ( rs_c * 2 ) ), ( n0_rem * sizeof( float ) ) ); + memcpy( buf3, ( c + ( rs_c * 3 ) ), ( n0_rem * sizeof( float ) ) ); + memcpy( buf4, ( c + ( rs_c * 4 ) ), ( n0_rem * sizeof( float ) ) ); + + // c[0,0-15] + selector1 = _mm512_loadu_ps( buf0 ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + selector1 = _mm512_loadu_ps( buf1 ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + selector1 = _mm512_loadu_ps( buf2 ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[3,0-15] + selector1 = _mm512_loadu_ps( buf3 ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[4,0-15] + selector1 = _mm512_loadu_ps( buf4 ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); + } + + // Store the results. + // c[0,0-15] + _mm512_storeu_ps( buf0, c_float_0p0 ); + + // c[1,0-15] + _mm512_storeu_ps( buf1, c_float_1p0 ); + + // c[2,0-15] + _mm512_storeu_ps( buf2, c_float_2p0 ); + + // c[3,0-15] + _mm512_storeu_ps( buf3, c_float_3p0 ); + + // c[4,0-15] + _mm512_storeu_ps( buf4, c_float_4p0 ); + + // Memcpy partial parts. + // c[0,0-15] + memcpy( c + ( rs_c * 0 ) + ( 0*16 ), buf0, ( n0_rem * sizeof( float ) ) ); + + // c[1,0-15] + memcpy( c + ( rs_c * 1 ) + ( 0*16 ), buf1, ( n0_rem * sizeof( float ) ) ); + + // c[2,0-15] + memcpy( c + ( rs_c * 2 ) + ( 0*16 ), buf2, ( n0_rem * sizeof( float ) ) ); + + // c[3,0-15] + memcpy( c + ( rs_c * 3 ) + ( 0*16 ), buf3, ( n0_rem * sizeof( float ) ) ); + + // c[4,0-15] + memcpy( c + ( rs_c * 4 ) + ( 0*16 ), buf4, ( n0_rem * sizeof( float ) ) ); + +} + +// 4xlt16 bf16 fringe kernel +LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4xlt16) +{ + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int32_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + // For corner cases. 
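+ // n0_rem is smaller than a full 16-float vector here, so rows of C are
+ // staged through these 16-element stack buffers: only n0_rem floats are
+ // memcpy'd in and out, while the vector loads/stores in between always
+ // use the full buffer. This keeps the kernel from touching memory past
+ // the last column of C.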
+ float buf0[16]; + float buf1[16]; + float buf2[16]; + float buf3[16]; + + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + + __m512 c_float_2p0 = _mm512_setzero_ps(); + + __m512 c_float_3p0 = _mm512_setzero_ps(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + + // Broadcast a[2,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + + // Broadcast a[3,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-15] = a[3,kr:kr+2]*b[kr:kr+2,0-15] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + } + // Handle k remainder. + + if ( k_partial_pieces > 0 ) + { + b0 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + + // Broadcast a[2,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + + // Broadcast a[3,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-15] = a[3,kr:kr+2]*b[kr:kr+2,0-15] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + + c_float_2p0 = _mm512_mul_ps( selector1, c_float_2p0 ); + + c_float_3p0 = _mm512_mul_ps( selector1, c_float_3p0 ); + + // Scale C by beta. 
+ if ( beta != 0 ) + { + memcpy( buf0, ( c + ( rs_c * 0 ) ), ( n0_rem * sizeof( float ) ) ); + memcpy( buf1, ( c + ( rs_c * 1 ) ), ( n0_rem * sizeof( float ) ) ); + memcpy( buf2, ( c + ( rs_c * 2 ) ), ( n0_rem * sizeof( float ) ) ); + memcpy( buf3, ( c + ( rs_c * 3 ) ), ( n0_rem * sizeof( float ) ) ); + + // c[0,0-15] + selector1 = _mm512_loadu_ps( buf0 ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + selector1 = _mm512_loadu_ps( buf1 ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + selector1 = _mm512_loadu_ps( buf2 ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[3,0-15] + selector1 = _mm512_loadu_ps( buf3 ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + } + + // Store the results. + // c[0,0-15] + _mm512_storeu_ps( buf0, c_float_0p0 ); + + // c[1,0-15] + _mm512_storeu_ps( buf1, c_float_1p0 ); + + // c[2,0-15] + _mm512_storeu_ps( buf2, c_float_2p0 ); + + // c[3,0-15] + _mm512_storeu_ps( buf3, c_float_3p0 ); + + // Memcpy partial parts. + // c[0,0-15] + memcpy( c + ( rs_c * 0 ) + ( 0*16 ), buf0, ( n0_rem * sizeof( float ) ) ); + + // c[1,0-15] + memcpy( c + ( rs_c * 1 ) + ( 0*16 ), buf1, ( n0_rem * sizeof( float ) ) ); + + // c[2,0-15] + memcpy( c + ( rs_c * 2 ) + ( 0*16 ), buf2, ( n0_rem * sizeof( float ) ) ); + + // c[3,0-15] + memcpy( c + ( rs_c * 3 ) + ( 0*16 ), buf3, ( n0_rem * sizeof( float ) ) ); + +} + +// 3xlt16 bf16 fringe kernel +LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3xlt16) +{ + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int32_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + // For corner cases. + float buf0[16]; + float buf1[16]; + float buf2[16]; + + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + + __m512 c_float_2p0 = _mm512_setzero_ps(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + + // Broadcast a[2,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+2]. 
+ memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + + // Broadcast a[2,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + + c_float_2p0 = _mm512_mul_ps( selector1, c_float_2p0 ); + + // Scale C by beta. + if ( beta != 0 ) + { + memcpy( buf0, ( c + ( rs_c * 0 ) ), ( n0_rem * sizeof( float ) ) ); + memcpy( buf1, ( c + ( rs_c * 1 ) ), ( n0_rem * sizeof( float ) ) ); + memcpy( buf2, ( c + ( rs_c * 2 ) ), ( n0_rem * sizeof( float) ) ); + + // c[0,0-15] + selector1 = _mm512_loadu_ps( buf0 ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + selector1 = _mm512_loadu_ps( buf1 ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + selector1 = _mm512_loadu_ps( buf2 ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + } + + // Store the results. + // c[0,0-15] + _mm512_storeu_ps( buf0, c_float_0p0 ); + + // c[1,0-15] + _mm512_storeu_ps( buf1, c_float_1p0 ); + + // c[2,0-15] + _mm512_storeu_ps( buf2, c_float_2p0 ); + + // Memcpy partial parts. + // c[0,0-15] + memcpy( c + ( rs_c * 0 ) + ( 0*16 ), buf0, ( n0_rem * sizeof( float ) ) ); + + // c[1,0-15] + memcpy( c + ( rs_c * 1 ) + ( 0*16 ), buf1, ( n0_rem * sizeof( float ) ) ); + + // c[2,0-15] + memcpy( c + ( rs_c * 2 ) + ( 0*16 ), buf2, ( n0_rem * sizeof( float ) ) ); + +} + +// 2xlt16 bf16 fringe kernel +LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2xlt16) +{ + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int32_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + // For corner cases. + float buf0[16]; + float buf1[16]; + + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. 
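+ // _mm512_dpbf16_ps treats every 32-bit lane as a bf16 pair: for each
+ // fp32 lane j it adds a_lo*b_lo[j] + a_hi*b_hi[j] to the accumulator,
+ // which is why one instruction advances k by two. Rough scalar sketch
+ // (illustrative only):
+ //   acc[j] += (float)a[r][kr] * (float)b[kr][j]
+ //          +  (float)a[r][kr+1] * (float)b[kr+1][j];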
+ // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + + // Scale C by beta. + if ( beta != 0 ) + { + memcpy( buf0, ( c + ( rs_c * 0 ) ), ( n0_rem * sizeof( float ) ) ); + memcpy( buf1, ( c + ( rs_c * 1 ) ), ( n0_rem * sizeof( float) ) ); + + // c[0,0-15] + selector1 = _mm512_loadu_ps( buf0 ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + selector1 = _mm512_loadu_ps( buf1 ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + } + + // Store the results. + // c[0,0-15] + _mm512_storeu_ps( buf0, c_float_0p0 ); + + // c[1,0-15] + _mm512_storeu_ps( buf1, c_float_1p0 ); + + // Memcpy partial parts. + // c[0,0-15] + memcpy( c + ( rs_c * 0 ) + ( 0*16 ), buf0, ( n0_rem * sizeof( float ) ) ); + + // c[1,0-15] + memcpy( c + ( rs_c * 1 ) + ( 0*16 ), buf1, ( n0_rem * sizeof( float ) ) ); + +} + +// 1xlt16 bf16 fringe kernel +LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1xlt16) +{ + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int32_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + // For corner cases. + float buf0[16]; + + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+2]. 
+ memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + + // Scale C by beta. + if ( beta != 0 ) + { + memcpy( buf0, ( c + ( rs_c * 0 ) ), ( n0_rem * sizeof( float ) ) ); + + // c[0,0-15] + selector1 = _mm512_loadu_ps( buf0 ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + } + + // Store the results. + // c[0,0-15] + _mm512_storeu_ps( buf0, c_float_0p0 ); + + // Memcpy partial parts. + // c[0,0-15] + memcpy( c + ( rs_c * 0 ) + ( 0*16 ), buf0, ( n0_rem * sizeof( float ) ) ); + +} + +// 5x16 bf16 kernel +LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x16) +{ + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int32_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + + __m512 c_float_2p0 = _mm512_setzero_ps(); + + __m512 c_float_3p0 = _mm512_setzero_ps(); + + __m512 c_float_4p0 = _mm512_setzero_ps(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + + // Broadcast a[2,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + + // Broadcast a[3,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-15] = a[3,kr:kr+2]*b[kr:kr+2,0-15] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + + // Broadcast a[4,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[4,0-15] = a[4,kr:kr+2]*b[kr:kr+2,0-15] + c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. 
+ // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + + // Broadcast a[2,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + + // Broadcast a[3,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-15] = a[3,kr:kr+2]*b[kr:kr+2,0-15] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + + // Broadcast a[4,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 4 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[4,0-15] = a[4,kr:kr+2]*b[kr:kr+2,0-15] + c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + + c_float_2p0 = _mm512_mul_ps( selector1, c_float_2p0 ); + + c_float_3p0 = _mm512_mul_ps( selector1, c_float_3p0 ); + + c_float_4p0 = _mm512_mul_ps( selector1, c_float_4p0 ); + + // Scale C by beta. + if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 2 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[3,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 3 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[4,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 4 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); + } + + // Store the results. 
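+ // Unlike the xlt16 variants above, a full 16-column panel of C is
+ // available here, so the results are stored directly with full-width
+ // vector stores and no staging buffers are needed.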
+ // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), c_float_0p0 ); + + // c[1,0-15] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 0*16 ), c_float_1p0 ); + + // c[2,0-15] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 0*16 ), c_float_2p0 ); + + // c[3,0-15] + _mm512_storeu_ps( c + ( rs_c * 3 ) + ( 0*16 ), c_float_3p0 ); + + // c[4,0-15] + _mm512_storeu_ps( c + ( rs_c * 4 ) + ( 0*16 ), c_float_4p0 ); +} + +// 4x16 bf16 kernel +LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x16) +{ + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int32_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + + __m512 c_float_2p0 = _mm512_setzero_ps(); + + __m512 c_float_3p0 = _mm512_setzero_ps(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + + // Broadcast a[2,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + + // Broadcast a[3,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-15] = a[3,kr:kr+2]*b[kr:kr+2,0-15] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + + // Broadcast a[2,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + + // Broadcast a[3,kr:kr+2]. 
+ memcpy( &a_kfringe_buf, ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-15] = a[3,kr:kr+2]*b[kr:kr+2,0-15] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + + c_float_2p0 = _mm512_mul_ps( selector1, c_float_2p0 ); + + c_float_3p0 = _mm512_mul_ps( selector1, c_float_3p0 ); + + // Scale C by beta. + if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 2 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[3,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 3 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + } + + // Store the results. + // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), c_float_0p0 ); + + // c[1,0-15] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 0*16 ), c_float_1p0 ); + + // c[2,0-15] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 0*16 ), c_float_2p0 ); + + // c[3,0-15] + _mm512_storeu_ps( c + ( rs_c * 3 ) + ( 0*16 ), c_float_3p0 ); +} + +// 3x16 bf16 kernel +LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x16) +{ + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int32_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + + __m512 c_float_2p0 = _mm512_setzero_ps(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + + // Broadcast a[2,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+2]. 
+ memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + + // Broadcast a[2,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + + c_float_2p0 = _mm512_mul_ps( selector1, c_float_2p0 ); + + // Scale C by beta. + if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 2 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + } + + // Store the results. + // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), c_float_0p0 ); + + // c[1,0-15] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 0*16 ), c_float_1p0 ); + + // c[2,0-15] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 0*16 ), c_float_2p0 ); +} + +// 2x16 bf16 kernel +LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x16) +{ + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int32_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+2]. 
+ memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) );
+ a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf );
+
+ // Perform column direction mat-mul with k = 2.
+ // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15]
+ c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 );
+
+ // Broadcast a[1,kr:kr+2].
+ memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) );
+ a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf );
+
+ // Perform column direction mat-mul with k = 2.
+ // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15]
+ c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 );
+ }
+
+ // Load alpha and beta
+ __m512 selector1 = _mm512_set1_ps( alpha );
+ __m512 selector2 = _mm512_set1_ps( beta );
+
+ // Scale by alpha
+ c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 );
+
+ c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 );
+
+ // Scale C by beta.
+ if ( beta != 0 )
+ {
+ // c[0,0-15]
+ selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 0*16 ) );
+ selector1 = _mm512_mul_ps( selector2, selector1 );
+ c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 );
+
+ // c[1,0-15]
+ selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 0*16 ) );
+ selector1 = _mm512_mul_ps( selector2, selector1 );
+ c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 );
+ }
+
+ // Store the results.
+ // c[0,0-15]
+ _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), c_float_0p0 );
+
+ // c[1,0-15]
+ _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 0*16 ), c_float_1p0 );
+}
+
+// 1x16 bf16 kernel
+LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1x16)
+{
+ dim_t k_full_pieces = k0 / 2;
+ dim_t k_partial_pieces = k0 % 2;
+
+ int32_t a_kfringe_buf = 0;
+
+ // B matrix storage bfloat type
+ __m512bh b0;
+
+ // A matrix storage bfloat type
+ __m512bh a_bf16_0;
+
+ // Registers to use for accumulating C.
+ __m512 c_float_0p0 = _mm512_setzero_ps();
+
+ for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 )
+ {
+ b0 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) );
+
+ // Broadcast a[0,kr:kr+2].
+ a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) );
+
+ // Perform column direction mat-mul with k = 2.
+ // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15]
+ c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 );
+ }
+ // Handle k remainder.
+ if ( k_partial_pieces > 0 )
+ {
+ b0 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) );
+
+ // Broadcast a[0,kr:kr+2].
+ memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) );
+ a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf );
+
+ // Perform column direction mat-mul with k = 2.
+ // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15]
+ c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 );
+ }
+
+ // Load alpha and beta
+ __m512 selector1 = _mm512_set1_ps( alpha );
+ __m512 selector2 = _mm512_set1_ps( beta );
+
+ // Scale by alpha
+ c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 );
+
+ // Scale C by beta.
+ if ( beta != 0 )
+ {
+ // c[0,0-15]
+ selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 0*16 ) );
+ selector1 = _mm512_mul_ps( selector2, selector1 );
+ c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 );
+ }
+
+ // Store the results.
+ // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), c_float_0p0 ); +} + +// 5x32 bf16 kernel +LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x32) +{ + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int32_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + __m512bh b1; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + __m512 c_float_0p1 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + __m512 c_float_1p1 = _mm512_setzero_ps(); + + __m512 c_float_2p0 = _mm512_setzero_ps(); + __m512 c_float_2p1 = _mm512_setzero_ps(); + + __m512 c_float_3p0 = _mm512_setzero_ps(); + __m512 c_float_3p1 = _mm512_setzero_ps(); + + __m512 c_float_4p0 = _mm512_setzero_ps(); + __m512 c_float_4p1 = _mm512_setzero_ps(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-31] = a[1,kr:kr+2]*b[kr:kr+2,0-31] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 ); + + // Broadcast a[2,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-31] = a[2,kr:kr+2]*b[kr:kr+2,0-31] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + + // Broadcast a[3,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-31] = a[3,kr:kr+2]*b[kr:kr+2,0-31] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + c_float_3p1 = _mm512_dpbf16_ps( c_float_3p1, a_bf16_0, b1 ); + + // Broadcast a[4,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[4,0-31] = a[4,kr:kr+2]*b[kr:kr+2,0-31] + c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); + c_float_4p1 = _mm512_dpbf16_ps( c_float_4p1, a_bf16_0, b1 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + + // Broadcast a[0,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + + // Broadcast a[1,kr:kr+2]. 
+ memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-31] = a[1,kr:kr+2]*b[kr:kr+2,0-31] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 ); + + // Broadcast a[2,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-31] = a[2,kr:kr+2]*b[kr:kr+2,0-31] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + + // Broadcast a[3,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-31] = a[3,kr:kr+2]*b[kr:kr+2,0-31] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + c_float_3p1 = _mm512_dpbf16_ps( c_float_3p1, a_bf16_0, b1 ); + + // Broadcast a[4,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 4 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[4,0-31] = a[4,kr:kr+2]*b[kr:kr+2,0-31] + c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); + c_float_4p1 = _mm512_dpbf16_ps( c_float_4p1, a_bf16_0, b1 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + c_float_0p1 = _mm512_mul_ps( selector1, c_float_0p1 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + c_float_1p1 = _mm512_mul_ps( selector1, c_float_1p1 ); + + c_float_2p0 = _mm512_mul_ps( selector1, c_float_2p0 ); + c_float_2p1 = _mm512_mul_ps( selector1, c_float_2p1 ); + + c_float_3p0 = _mm512_mul_ps( selector1, c_float_3p0 ); + c_float_3p1 = _mm512_mul_ps( selector1, c_float_3p1 ); + + c_float_4p0 = _mm512_mul_ps( selector1, c_float_4p0 ); + c_float_4p1 = _mm512_mul_ps( selector1, c_float_4p1 ); + + // Scale C by beta. 
+ if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[1,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p1 = _mm512_add_ps( selector1, c_float_1p1 ); + + // c[2,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 2 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[2,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 2 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p1 = _mm512_add_ps( selector1, c_float_2p1 ); + + // c[3,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 3 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[3,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 3 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_3p1 = _mm512_add_ps( selector1, c_float_3p1 ); + + // c[4,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 4 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); + + // c[4,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 4 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_4p1 = _mm512_add_ps( selector1, c_float_4p1 ); + } + + // Store the results. + // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), c_float_0p0 ); + + // c[0, 16-31] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 1*16 ), c_float_0p1 ); + + // c[1,0-15] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 0*16 ), c_float_1p0 ); + + // c[1,16-31] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 1*16 ), c_float_1p1 ); + + // c[2,0-15] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 0*16 ), c_float_2p0 ); + + // c[2,16-31] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 1*16 ), c_float_2p1 ); + + // c[3,0-15] + _mm512_storeu_ps( c + ( rs_c * 3 ) + ( 0*16 ), c_float_3p0 ); + + // c[3,16-31] + _mm512_storeu_ps( c + ( rs_c * 3 ) + ( 1*16 ), c_float_3p1 ); + + // c[4,0-15] + _mm512_storeu_ps( c + ( rs_c * 4 ) + ( 0*16 ), c_float_4p0 ); + + // c[4,16-31] + _mm512_storeu_ps( c + ( rs_c * 4 ) + ( 1*16 ), c_float_4p1 ); +} + +// 4x32 bf16 kernel +LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x32) +{ + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int32_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + __m512bh b1; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + // Registers to use for accumulating C. 
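+	// 4 rows x 32 columns: two 16-float ZMM accumulators per row, 8 in total.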
+ __m512 c_float_0p0 = _mm512_setzero_ps(); + __m512 c_float_0p1 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + __m512 c_float_1p1 = _mm512_setzero_ps(); + + __m512 c_float_2p0 = _mm512_setzero_ps(); + __m512 c_float_2p1 = _mm512_setzero_ps(); + + __m512 c_float_3p0 = _mm512_setzero_ps(); + __m512 c_float_3p1 = _mm512_setzero_ps(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-31] = a[1,kr:kr+2]*b[kr:kr+2,0-31] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 ); + + // Broadcast a[2,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-31] = a[2,kr:kr+2]*b[kr:kr+2,0-31] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + + // Broadcast a[3,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-31] = a[3,kr:kr+2]*b[kr:kr+2,0-31] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + c_float_3p1 = _mm512_dpbf16_ps( c_float_3p1, a_bf16_0, b1 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + + // Broadcast a[0,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + + // Broadcast a[1,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-31] = a[1,kr:kr+2]*b[kr:kr+2,0-31] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 ); + + // Broadcast a[2,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-31] = a[2,kr:kr+2]*b[kr:kr+2,0-31] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + + // Broadcast a[3,kr:kr+2]. 
+ memcpy( &a_kfringe_buf, ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-31] = a[3,kr:kr+2]*b[kr:kr+2,0-31] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + c_float_3p1 = _mm512_dpbf16_ps( c_float_3p1, a_bf16_0, b1 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + c_float_0p1 = _mm512_mul_ps( selector1, c_float_0p1 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + c_float_1p1 = _mm512_mul_ps( selector1, c_float_1p1 ); + + c_float_2p0 = _mm512_mul_ps( selector1, c_float_2p0 ); + c_float_2p1 = _mm512_mul_ps( selector1, c_float_2p1 ); + + c_float_3p0 = _mm512_mul_ps( selector1, c_float_3p0 ); + c_float_3p1 = _mm512_mul_ps( selector1, c_float_3p1 ); + + // Scale C by beta. + if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[1,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p1 = _mm512_add_ps( selector1, c_float_1p1 ); + + // c[2,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 2 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[2,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 2 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p1 = _mm512_add_ps( selector1, c_float_2p1 ); + + // c[3,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 3 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[3,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 3 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_3p1 = _mm512_add_ps( selector1, c_float_3p1 ); + } + + // Store the results. 
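+	// Each row is written back as two full 16-float vectors (columns 0-15
+	// and 16-31); this n fringe is a full 32-column block, so no store mask
+	// is used.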
+ // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), c_float_0p0 ); + + // c[0, 16-31] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 1*16 ), c_float_0p1 ); + + // c[1,0-15] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 0*16 ), c_float_1p0 ); + + // c[1,16-31] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 1*16 ), c_float_1p1 ); + + // c[2,0-15] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 0*16 ), c_float_2p0 ); + + // c[2,16-31] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 1*16 ), c_float_2p1 ); + + // c[3,0-15] + _mm512_storeu_ps( c + ( rs_c * 3 ) + ( 0*16 ), c_float_3p0 ); + + // c[3,16-31] + _mm512_storeu_ps( c + ( rs_c * 3 ) + ( 1*16 ), c_float_3p1 ); +} + +// 3x32 bf16 kernel +LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x32) +{ + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int32_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + __m512bh b1; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + __m512 c_float_0p1 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + __m512 c_float_1p1 = _mm512_setzero_ps(); + + __m512 c_float_2p0 = _mm512_setzero_ps(); + __m512 c_float_2p1 = _mm512_setzero_ps(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-31] = a[1,kr:kr+2]*b[kr:kr+2,0-31] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 ); + + // Broadcast a[2,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-31] = a[2,kr:kr+2]*b[kr:kr+2,0-31] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + + // Broadcast a[0,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + + // Broadcast a[1,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. 
+ // c[1,0-31] = a[1,kr:kr+2]*b[kr:kr+2,0-31] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 ); + + // Broadcast a[2,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-31] = a[2,kr:kr+2]*b[kr:kr+2,0-31] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + c_float_0p1 = _mm512_mul_ps( selector1, c_float_0p1 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + c_float_1p1 = _mm512_mul_ps( selector1, c_float_1p1 ); + + c_float_2p0 = _mm512_mul_ps( selector1, c_float_2p0 ); + c_float_2p1 = _mm512_mul_ps( selector1, c_float_2p1 ); + + // Scale C by beta. + if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[1,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p1 = _mm512_add_ps( selector1, c_float_1p1 ); + + // c[2,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 2 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[2,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 2 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p1 = _mm512_add_ps( selector1, c_float_2p1 ); + } + + // Store the results. + // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), c_float_0p0 ); + + // c[0, 16-31] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 1*16 ), c_float_0p1 ); + + // c[1,0-15] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 0*16 ), c_float_1p0 ); + + // c[1,16-31] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 1*16 ), c_float_1p1 ); + + // c[2,0-15] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 0*16 ), c_float_2p0 ); + + // c[2,16-31] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 1*16 ), c_float_2p1 ); +} + +// 2x32 bf16 kernel +LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x32) +{ + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int32_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + __m512bh b1; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + __m512 c_float_0p1 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + __m512 c_float_1p1 = _mm512_setzero_ps(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + + // Broadcast a[0,kr:kr+2]. 
+ a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-31] = a[1,kr:kr+2]*b[kr:kr+2,0-31] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + + // Broadcast a[0,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + + // Broadcast a[1,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-31] = a[1,kr:kr+2]*b[kr:kr+2,0-31] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + c_float_0p1 = _mm512_mul_ps( selector1, c_float_0p1 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + c_float_1p1 = _mm512_mul_ps( selector1, c_float_1p1 ); + + // Scale C by beta. + if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[1,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p1 = _mm512_add_ps( selector1, c_float_1p1 ); + } + + // Store the results. 
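+	// Reference view of the update performed above (an illustrative sketch,
+	// not part of the kernel), with a_f32()/b_f32() denoting the bf16 inputs
+	// widened to float:
+	//   for ( dim_t i = 0; i < 2; ++i )
+	//     for ( dim_t j = 0; j < 32; ++j )
+	//     {
+	//       float acc = 0.0f;
+	//       for ( dim_t l = 0; l < k0; ++l ) acc += a_f32( i, l ) * b_f32( l, j );
+	//       c[ ( rs_c * i ) + j ] = ( alpha * acc ) +
+	//              ( ( beta != 0 ) ? ( beta * c[ ( rs_c * i ) + j ] ) : 0.0f );
+	//     }
+	// _mm512_dpbf16_ps realizes the inner loop two k elements at a time: each
+	// 32-bit lane of the broadcast A register holds the pair a[i,kr:kr+2],
+	// each 32-bit lane of the packed B register holds b[kr:kr+2,j], and both
+	// products are accumulated into the float lane of the accumulator.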
+ // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), c_float_0p0 ); + + // c[0, 16-31] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 1*16 ), c_float_0p1 ); + + // c[1,0-15] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 0*16 ), c_float_1p0 ); + + // c[1,16-31] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 1*16 ), c_float_1p1 ); +} + +// 1x32 bf16 kernel +LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1x32) +{ + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int32_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + __m512bh b1; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + __m512 c_float_0p1 = _mm512_setzero_ps(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + + // Broadcast a[0,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + c_float_0p1 = _mm512_mul_ps( selector1, c_float_0p1 ); + + // Scale C by beta. + if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + } + + // Store the results. + // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), c_float_0p0 ); + + // c[0, 16-31] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 1*16 ), c_float_0p1 ); +} + +// 5x48 bf16 kernel +LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x48) +{ + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int32_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + __m512bh b1; + __m512bh b2; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + // Registers to use for accumulating C. 
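+	// 5 rows x 48 columns: three 16-float ZMM accumulators per row, 15 in
+	// total; B is kept in three registers (b0, b1, b2) per k pair.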
+ __m512 c_float_0p0 = _mm512_setzero_ps(); + __m512 c_float_0p1 = _mm512_setzero_ps(); + __m512 c_float_0p2 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + __m512 c_float_1p1 = _mm512_setzero_ps(); + __m512 c_float_1p2 = _mm512_setzero_ps(); + + __m512 c_float_2p0 = _mm512_setzero_ps(); + __m512 c_float_2p1 = _mm512_setzero_ps(); + __m512 c_float_2p2 = _mm512_setzero_ps(); + + __m512 c_float_3p0 = _mm512_setzero_ps(); + __m512 c_float_3p1 = _mm512_setzero_ps(); + __m512 c_float_3p2 = _mm512_setzero_ps(); + + __m512 c_float_4p0 = _mm512_setzero_ps(); + __m512 c_float_4p1 = _mm512_setzero_ps(); + __m512 c_float_4p2 = _mm512_setzero_ps(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-47] = a[0,kr:kr+2]*b[kr:kr+2,0-47] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-47] = a[1,kr:kr+2]*b[kr:kr+2,0-47] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 ); + c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_0, b2 ); + + // Broadcast a[2,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-47] = a[2,kr:kr+2]*b[kr:kr+2,0-47] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + c_float_2p2 = _mm512_dpbf16_ps( c_float_2p2, a_bf16_0, b2 ); + + // Broadcast a[3,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-47] = a[3,kr:kr+2]*b[kr:kr+2,0-47] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + c_float_3p1 = _mm512_dpbf16_ps( c_float_3p1, a_bf16_0, b1 ); + c_float_3p2 = _mm512_dpbf16_ps( c_float_3p2, a_bf16_0, b2 ); + + // Broadcast a[4,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[4,0-47] = a[4,kr:kr+2]*b[kr:kr+2,0-47] + c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); + c_float_4p1 = _mm512_dpbf16_ps( c_float_4p1, a_bf16_0, b1 ); + c_float_4p2 = _mm512_dpbf16_ps( c_float_4p2, a_bf16_0, b2 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + + // Broadcast a[0,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. 
+ // c[0,0-47] = a[0,kr:kr+2]*b[kr:kr+2,0-47] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + + // Broadcast a[1,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-47] = a[1,kr:kr+2]*b[kr:kr+2,0-47] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 ); + c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_0, b2 ); + + // Broadcast a[2,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-47] = a[2,kr:kr+2]*b[kr:kr+2,0-47] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + c_float_2p2 = _mm512_dpbf16_ps( c_float_2p2, a_bf16_0, b2 ); + + // Broadcast a[3,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-47] = a[3,kr:kr+2]*b[kr:kr+2,0-47] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + c_float_3p1 = _mm512_dpbf16_ps( c_float_3p1, a_bf16_0, b1 ); + c_float_3p2 = _mm512_dpbf16_ps( c_float_3p2, a_bf16_0, b2 ); + + // Broadcast a[4,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 4 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[4,0-47] = a[4,kr:kr+2]*b[kr:kr+2,0-47] + c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); + c_float_4p1 = _mm512_dpbf16_ps( c_float_4p1, a_bf16_0, b1 ); + c_float_4p2 = _mm512_dpbf16_ps( c_float_4p2, a_bf16_0, b2 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + c_float_0p1 = _mm512_mul_ps( selector1, c_float_0p1 ); + c_float_0p2 = _mm512_mul_ps( selector1, c_float_0p2 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + c_float_1p1 = _mm512_mul_ps( selector1, c_float_1p1 ); + c_float_1p2 = _mm512_mul_ps( selector1, c_float_1p2 ); + + c_float_2p0 = _mm512_mul_ps( selector1, c_float_2p0 ); + c_float_2p1 = _mm512_mul_ps( selector1, c_float_2p1 ); + c_float_2p2 = _mm512_mul_ps( selector1, c_float_2p2 ); + + c_float_3p0 = _mm512_mul_ps( selector1, c_float_3p0 ); + c_float_3p1 = _mm512_mul_ps( selector1, c_float_3p1 ); + c_float_3p2 = _mm512_mul_ps( selector1, c_float_3p2 ); + + c_float_4p0 = _mm512_mul_ps( selector1, c_float_4p0 ); + c_float_4p1 = _mm512_mul_ps( selector1, c_float_4p1 ); + c_float_4p2 = _mm512_mul_ps( selector1, c_float_4p2 ); + + // Scale C by beta. 
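+	// alpha has already been applied above, so selector1 is reused below as
+	// a scratch register for the C vectors loaded from memory.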
+ if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); + + // c[1,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p1 = _mm512_add_ps( selector1, c_float_1p1 ); + + // c[1,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p2 = _mm512_add_ps( selector1, c_float_1p2 ); + + // c[2,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 2 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[2,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 2 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p1 = _mm512_add_ps( selector1, c_float_2p1 ); + + // c[2,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * 2 ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p2 = _mm512_add_ps( selector1, c_float_2p2 ); + + // c[3,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 3 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[3,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 3 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_3p1 = _mm512_add_ps( selector1, c_float_3p1 ); + + // c[3,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * 3 ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_3p2 = _mm512_add_ps( selector1, c_float_3p2 ); + + // c[4,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 4 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); + + // c[4,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 4 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_4p1 = _mm512_add_ps( selector1, c_float_4p1 ); + + // c[4,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * 4 ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_4p2 = _mm512_add_ps( selector1, c_float_4p2 ); + } + + // Store the results. 
+ // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), c_float_0p0 ); + + // c[0, 16-31] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 1*16 ), c_float_0p1 ); + + // c[0,32-47] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 2*16 ), c_float_0p2 ); + + // c[1,0-15] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 0*16 ), c_float_1p0 ); + + // c[1,16-31] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 1*16 ), c_float_1p1 ); + + // c[1,32-47] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 2*16 ), c_float_1p2 ); + + // c[2,0-15] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 0*16 ), c_float_2p0 ); + + // c[2,16-31] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 1*16 ), c_float_2p1 ); + + // c[2,32-47] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 2*16 ), c_float_2p2 ); + + // c[3,0-15] + _mm512_storeu_ps( c + ( rs_c * 3 ) + ( 0*16 ), c_float_3p0 ); + + // c[3,16-31] + _mm512_storeu_ps( c + ( rs_c * 3 ) + ( 1*16 ), c_float_3p1 ); + + // c[3,32-47] + _mm512_storeu_ps( c + ( rs_c * 3 ) + ( 2*16 ), c_float_3p2 ); + + // c[4,0-15] + _mm512_storeu_ps( c + ( rs_c * 4 ) + ( 0*16 ), c_float_4p0 ); + + // c[4,16-31] + _mm512_storeu_ps( c + ( rs_c * 4 ) + ( 1*16 ), c_float_4p1 ); + + // c[4,32-47] + _mm512_storeu_ps( c + ( rs_c * 4 ) + ( 2*16 ), c_float_4p2 ); +} + +// 4x48 bf16 kernel +LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x48) +{ + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int32_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + __m512bh b1; + __m512bh b2; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + __m512 c_float_0p1 = _mm512_setzero_ps(); + __m512 c_float_0p2 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + __m512 c_float_1p1 = _mm512_setzero_ps(); + __m512 c_float_1p2 = _mm512_setzero_ps(); + + __m512 c_float_2p0 = _mm512_setzero_ps(); + __m512 c_float_2p1 = _mm512_setzero_ps(); + __m512 c_float_2p2 = _mm512_setzero_ps(); + + __m512 c_float_3p0 = _mm512_setzero_ps(); + __m512 c_float_3p1 = _mm512_setzero_ps(); + __m512 c_float_3p2 = _mm512_setzero_ps(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-47] = a[0,kr:kr+2]*b[kr:kr+2,0-47] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-47] = a[1,kr:kr+2]*b[kr:kr+2,0-47] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 ); + c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_0, b2 ); + + // Broadcast a[2,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. 
+ // c[2,0-47] = a[2,kr:kr+2]*b[kr:kr+2,0-47] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + c_float_2p2 = _mm512_dpbf16_ps( c_float_2p2, a_bf16_0, b2 ); + + // Broadcast a[3,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-47] = a[3,kr:kr+2]*b[kr:kr+2,0-47] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + c_float_3p1 = _mm512_dpbf16_ps( c_float_3p1, a_bf16_0, b1 ); + c_float_3p2 = _mm512_dpbf16_ps( c_float_3p2, a_bf16_0, b2 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + + // Broadcast a[0,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-47] = a[0,kr:kr+2]*b[kr:kr+2,0-47] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + + // Broadcast a[1,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-47] = a[1,kr:kr+2]*b[kr:kr+2,0-47] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 ); + c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_0, b2 ); + + // Broadcast a[2,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-47] = a[2,kr:kr+2]*b[kr:kr+2,0-47] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + c_float_2p2 = _mm512_dpbf16_ps( c_float_2p2, a_bf16_0, b2 ); + + // Broadcast a[3,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. 
+ // c[3,0-47] = a[3,kr:kr+2]*b[kr:kr+2,0-47] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + c_float_3p1 = _mm512_dpbf16_ps( c_float_3p1, a_bf16_0, b1 ); + c_float_3p2 = _mm512_dpbf16_ps( c_float_3p2, a_bf16_0, b2 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + c_float_0p1 = _mm512_mul_ps( selector1, c_float_0p1 ); + c_float_0p2 = _mm512_mul_ps( selector1, c_float_0p2 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + c_float_1p1 = _mm512_mul_ps( selector1, c_float_1p1 ); + c_float_1p2 = _mm512_mul_ps( selector1, c_float_1p2 ); + + c_float_2p0 = _mm512_mul_ps( selector1, c_float_2p0 ); + c_float_2p1 = _mm512_mul_ps( selector1, c_float_2p1 ); + c_float_2p2 = _mm512_mul_ps( selector1, c_float_2p2 ); + + c_float_3p0 = _mm512_mul_ps( selector1, c_float_3p0 ); + c_float_3p1 = _mm512_mul_ps( selector1, c_float_3p1 ); + c_float_3p2 = _mm512_mul_ps( selector1, c_float_3p2 ); + + // Scale C by beta. + if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); + + // c[1,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p1 = _mm512_add_ps( selector1, c_float_1p1 ); + + // c[1,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p2 = _mm512_add_ps( selector1, c_float_1p2 ); + + // c[2,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 2 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[2,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 2 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p1 = _mm512_add_ps( selector1, c_float_2p1 ); + + // c[2,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * 2 ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p2 = _mm512_add_ps( selector1, c_float_2p2 ); + + // c[3,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 3 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[3,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 3 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_3p1 = _mm512_add_ps( selector1, c_float_3p1 ); + + // c[3,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * 3 ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_3p2 = _mm512_add_ps( selector1, c_float_3p2 ); + } + + // Store the results. 
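+	// Row i, 16-column block j of the tile is stored at c + ( rs_c * i ) + ( j * 16 ).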
+ // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), c_float_0p0 ); + + // c[0, 16-31] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 1*16 ), c_float_0p1 ); + + // c[0,32-47] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 2*16 ), c_float_0p2 ); + + // c[1,0-15] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 0*16 ), c_float_1p0 ); + + // c[1,16-31] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 1*16 ), c_float_1p1 ); + + // c[1,32-47] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 2*16 ), c_float_1p2 ); + + // c[2,0-15] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 0*16 ), c_float_2p0 ); + + // c[2,16-31] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 1*16 ), c_float_2p1 ); + + // c[2,32-47] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 2*16 ), c_float_2p2 ); + + // c[3,0-15] + _mm512_storeu_ps( c + ( rs_c * 3 ) + ( 0*16 ), c_float_3p0 ); + + // c[3,16-31] + _mm512_storeu_ps( c + ( rs_c * 3 ) + ( 1*16 ), c_float_3p1 ); + + // c[3,32-47] + _mm512_storeu_ps( c + ( rs_c * 3 ) + ( 2*16 ), c_float_3p2 ); +} + +// 3x48 bf16 kernel +LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x48) +{ + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int32_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + __m512bh b1; + __m512bh b2; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + // Registers to use for accumulating C. + __m512 float_0p0 = _mm512_setzero_ps(); + __m512 float_0p1 = _mm512_setzero_ps(); + __m512 float_0p2 = _mm512_setzero_ps(); + + __m512 float_1p0 = _mm512_setzero_ps(); + __m512 float_1p1 = _mm512_setzero_ps(); + __m512 float_1p2 = _mm512_setzero_ps(); + + __m512 float_2p0 = _mm512_setzero_ps(); + __m512 float_2p1 = _mm512_setzero_ps(); + __m512 float_2p2 = _mm512_setzero_ps(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-47] = a[0,kr:kr+2]*b[kr:kr+2,0-47] + float_0p0 = _mm512_dpbf16_ps( float_0p0, a_bf16_0, b0 ); + float_0p1 = _mm512_dpbf16_ps( float_0p1, a_bf16_0, b1 ); + float_0p2 = _mm512_dpbf16_ps( float_0p2, a_bf16_0, b2 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-47] = a[1,kr:kr+2]*b[kr:kr+2,0-47] + float_1p0 = _mm512_dpbf16_ps( float_1p0, a_bf16_0, b0 ); + float_1p1 = _mm512_dpbf16_ps( float_1p1, a_bf16_0, b1 ); + float_1p2 = _mm512_dpbf16_ps( float_1p2, a_bf16_0, b2 ); + + // Broadcast a[2,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-47] = a[2,kr:kr+2]*b[kr:kr+2,0-47] + float_2p0 = _mm512_dpbf16_ps( float_2p0, a_bf16_0, b0 ); + float_2p1 = _mm512_dpbf16_ps( float_2p1, a_bf16_0, b1 ); + float_2p2 = _mm512_dpbf16_ps( float_2p2, a_bf16_0, b2 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + + // Broadcast a[0,kr:kr+2]. 
+ memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-47] = a[0,kr:kr+2]*b[kr:kr+2,0-47] + float_0p0 = _mm512_dpbf16_ps( float_0p0, a_bf16_0, b0 ); + float_0p1 = _mm512_dpbf16_ps( float_0p1, a_bf16_0, b1 ); + float_0p2 = _mm512_dpbf16_ps( float_0p2, a_bf16_0, b2 ); + + // Broadcast a[1,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-47] = a[1,kr:kr+2]*b[kr:kr+2,0-47] + float_1p0 = _mm512_dpbf16_ps( float_1p0, a_bf16_0, b0 ); + float_1p1 = _mm512_dpbf16_ps( float_1p1, a_bf16_0, b1 ); + float_1p2 = _mm512_dpbf16_ps( float_1p2, a_bf16_0, b2 ); + + // Broadcast a[2,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-47] = a[2,kr:kr+2]*b[kr:kr+2,0-47] + float_2p0 = _mm512_dpbf16_ps( float_2p0, a_bf16_0, b0 ); + float_2p1 = _mm512_dpbf16_ps( float_2p1, a_bf16_0, b1 ); + float_2p2 = _mm512_dpbf16_ps( float_2p2, a_bf16_0, b2 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + // Scale by alpha + float_0p0 = _mm512_mul_ps( selector1, float_0p0 ); + float_0p1 = _mm512_mul_ps( selector1, float_0p1 ); + float_0p2 = _mm512_mul_ps( selector1, float_0p2 ); + + float_1p0 = _mm512_mul_ps( selector1, float_1p0 ); + float_1p1 = _mm512_mul_ps( selector1, float_1p1 ); + float_1p2 = _mm512_mul_ps( selector1, float_1p2 ); + + float_2p0 = _mm512_mul_ps( selector1, float_2p0 ); + float_2p1 = _mm512_mul_ps( selector1, float_2p1 ); + float_2p2 = _mm512_mul_ps( selector1, float_2p2 ); + + // Scale C by beta. 
+ if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + float_0p0 = _mm512_add_ps( selector1, float_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + float_0p1 = _mm512_add_ps( selector1, float_0p1 ); + + // c[0,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + float_0p2 = _mm512_add_ps( selector1, float_0p2 ); + + // c[1,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + float_1p0 = _mm512_add_ps( selector1, float_1p0 ); + + // c[1,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + float_1p1 = _mm512_add_ps( selector1, float_1p1 ); + + // c[1,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + float_1p2 = _mm512_add_ps( selector1, float_1p2 ); + + // c[2,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 2 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + float_2p0 = _mm512_add_ps( selector1, float_2p0 ); + + // c[2,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 2 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + float_2p1 = _mm512_add_ps( selector1, float_2p1 ); + + // c[2,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * 2 ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + float_2p2 = _mm512_add_ps( selector1, float_2p2 ); + } + + // Store the results. + // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), float_0p0 ); + + // c[0, 16-31] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 1*16 ), float_0p1 ); + + // c[0,32-47] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 2*16 ), float_0p2 ); + + // c[1,0-15] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 0*16 ), float_1p0 ); + + // c[1,16-31] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 1*16 ), float_1p1 ); + + // c[1,32-47] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 2*16 ), float_1p2 ); + + // c[2,0-15] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 0*16 ), float_2p0 ); + + // c[2,16-31] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 1*16 ), float_2p1 ); + + // c[2,32-47] + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 2*16 ), float_2p2 ); +} + +// 2x48 bf16 kernel +LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x48) +{ + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int32_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + __m512bh b1; + __m512bh b2; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + __m512 c_float_0p1 = _mm512_setzero_ps(); + __m512 c_float_0p2 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + __m512 c_float_1p1 = _mm512_setzero_ps(); + __m512 c_float_1p2 = _mm512_setzero_ps(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. 
+ // c[0,0-47] = a[0,kr:kr+2]*b[kr:kr+2,0-47] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-47] = a[1,kr:kr+2]*b[kr:kr+2,0-47] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 ); + c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_0, b2 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + + // Broadcast a[0,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-47] = a[0,kr:kr+2]*b[kr:kr+2,0-47] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + + // Broadcast a[1,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-47] = a[1,kr:kr+2]*b[kr:kr+2,0-47] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 ); + c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_0, b2 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + c_float_0p1 = _mm512_mul_ps( selector1, c_float_0p1 ); + c_float_0p2 = _mm512_mul_ps( selector1, c_float_0p2 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + c_float_1p1 = _mm512_mul_ps( selector1, c_float_1p1 ); + c_float_1p2 = _mm512_mul_ps( selector1, c_float_1p2 ); + + // Scale C by beta. 
+ if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); + + // c[1,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p1 = _mm512_add_ps( selector1, c_float_1p1 ); + + // c[1,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p2 = _mm512_add_ps( selector1, c_float_1p2 ); + } + + // Store the results. + // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), c_float_0p0 ); + + // c[0, 16-31] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 1*16 ), c_float_0p1 ); + + // c[0,32-47] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 2*16 ), c_float_0p2 ); + + // c[1,0-15] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 0*16 ), c_float_1p0 ); + + // c[1,16-31] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 1*16 ), c_float_1p1 ); + + // c[1,32-47] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 2*16 ), c_float_1p2 ); +} + +// 1x48 bf16 kernel +LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1x48) +{ + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int32_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + __m512bh b1; + __m512bh b2; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + __m512 c_float_0p1 = _mm512_setzero_ps(); + __m512 c_float_0p2 = _mm512_setzero_ps(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + b0 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-47] = a[0,kr:kr+2]*b[kr:kr+2,0-47] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + + // Broadcast a[0,kr:kr+2]. + memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. 
+ // c[0,0-47] = a[0,kr:kr+2]*b[kr:kr+2,0-47] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + c_float_0p1 = _mm512_mul_ps( selector1, c_float_0p1 ); + c_float_0p2 = _mm512_mul_ps( selector1, c_float_0p2 ); + + // Scale C by beta. + if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); + } + + // Store the results. + // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), c_float_0p0 ); + + // c[0, 16-31] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 1*16 ), c_float_0p1 ); + + // c[0,32-47] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 2*16 ), c_float_0p2 ); +} diff --git a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_n_fringe_bf16_amd512vnni.c b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_n_fringe_bf16_amd512vnni.c new file mode 100644 index 0000000000..bdfb8de8e4 --- /dev/null +++ b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_n_fringe_bf16_amd512vnni.c @@ -0,0 +1,1590 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +*/ +#include +#include + +#include "blis.h" +#include "lpgemm_kernels.h" + +// 6xlt16 bf16 fringe kernel +LPGEMM_N_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6xlt16) +{ + dim_t MR = 6; + dim_t m_full_pieces = m0 / MR; + dim_t m_full_pieces_loop_limit = m_full_pieces * MR; + dim_t m_partial_pieces = m0 % MR; + + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int32_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + // For corner cases. + float buf0[16]; + float buf1[16]; + float buf2[16]; + float buf3[16]; + float buf4[16]; + float buf5[16]; + + for ( dim_t ir = 0; ir < m_full_pieces_loop_limit; ir += MR ) + { + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + + __m512 c_float_2p0 = _mm512_setzero_ps(); + + __m512 c_float_3p0 = _mm512_setzero_ps(); + + __m512 c_float_4p0 = _mm512_setzero_ps(); + + __m512 c_float_5p0 = _mm512_setzero_ps(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + // Load 2 rows with 16 extended elements each from B to 1 ZMM + // registers. It is to be noted that the B matrix is packed for use + // in bf16 instructions and each load to ZMM register will have 2 + // elements along k direction and 16 elements across n directions, + // so 2x16 elements to a ZMM register. + b0 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + + // Broadcast a[2,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + + // Broadcast a[3,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-15] = a[3,kr:kr+2]*b[kr:kr+2,0-15] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + + // Broadcast a[4,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[4,0-15] = a[4,kr:kr+2]*b[kr:kr+2,0-15] + c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); + + // Broadcast a[5,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 5 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[5,0-15] = a[5,kr:kr+2]*b[kr:kr+2,0-15] + c_float_5p0 = _mm512_dpbf16_ps( c_float_5p0, a_bf16_0, b0 ); + } + + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. 
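The _mm512_dpbf16_ps update that follows is, per output column, a two-term dot product: the broadcast pair (a[r,2kr], a[r,2kr+1]) is multiplied with the packed pair (b[2kr,n], b[2kr+1,n]) held in the same 32-bit lane and accumulated into a float. A scalar model of one such step, assuming the bfloat16 inputs have already been widened to float (the conversion happens inside the instruction and is not spelled out here; all names below are illustrative):

/* Scalar model of one dpbf16 step for row r and k-pair kr: acc[] plays the
   role of one c_float_rp0 register, i.e. 16 float accumulators. */
static void dpbf16_step_model(float acc[16],
                              const float a_pair[2],      /* a[r,2kr], a[r,2kr+1]                 */
                              const float b_pair[16][2])  /* b_pair[n] = { b[2kr,n], b[2kr+1,n] } */
{
    for (int n = 0; n < 16; ++n)
    {
        acc[n] += a_pair[0] * b_pair[n][0]    /* k = 2*kr   contribution */
                + a_pair[1] * b_pair[n][1];   /* k = 2*kr+1 contribution */
    }
}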
+ // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + + // Broadcast a[2,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + + // Broadcast a[3,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-15] = a[3,kr:kr+2]*b[kr:kr+2,0-15] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + + // Broadcast a[4,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 4 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[4,0-15] = a[4,kr:kr+2]*b[kr:kr+2,0-15] + c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); + + // Broadcast a[5,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 5 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[5,0-15] = a[5,kr:kr+2]*b[kr:kr+2,0-15] + c_float_5p0 = _mm512_dpbf16_ps( c_float_5p0, a_bf16_0, b0 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + + c_float_2p0 = _mm512_mul_ps( selector1, c_float_2p0 ); + + c_float_3p0 = _mm512_mul_ps( selector1, c_float_3p0 ); + + c_float_4p0 = _mm512_mul_ps( selector1, c_float_4p0 ); + + c_float_5p0 = _mm512_mul_ps( selector1, c_float_5p0 ); + + // Scale C by beta. 
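Because this kernel owns only n0_rem (< 16) live columns, the beta and store path below never touches C with full-width vector loads or stores. Each C row is first copied into a 16-float stack buffer, updated at full vector width, and only the first n0_rem floats are copied back. A one-row sketch of that round trip, written as a stand-alone helper with illustrative names (the kernel itself inlines this per row with buf0..buf5):

#include <immintrin.h>
#include <string.h>

/* Stage one row of C in a 16-float buffer so the AVX-512 load/store never
   runs past the last valid column; only n_rem results are written back. */
static void update_c_row_partial_n(float* c_row, __m512 acc,
                                   float alpha, float beta, int n_rem)
{
    float row_buf[16];                                   /* full vector width scratch     */

    acc = _mm512_mul_ps(_mm512_set1_ps(alpha), acc);     /* scale accumulator by alpha    */

    if (beta != 0.0f)
    {
        memcpy(row_buf, c_row, n_rem * sizeof(float));   /* read only the valid columns   */
        __m512 cv = _mm512_loadu_ps(row_buf);            /* full-width load is now safe   */
        acc = _mm512_add_ps(_mm512_mul_ps(_mm512_set1_ps(beta), cv), acc);
    }

    _mm512_storeu_ps(row_buf, acc);                      /* spill the vector to scratch   */
    memcpy(c_row, row_buf, n_rem * sizeof(float));       /* write back valid columns only */
}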
+ if ( beta != 0 ) + { + memcpy( buf0, ( c + ( rs_c * ( ir + 0 ) ) ), ( n0_rem * sizeof( float ) ) ); + memcpy( buf1, ( c + ( rs_c * ( ir + 1 ) ) ), ( n0_rem * sizeof( float ) ) ); + memcpy( buf2, ( c + ( rs_c * ( ir + 2 ) ) ), ( n0_rem * sizeof( float ) ) ); + memcpy( buf3, ( c + ( rs_c * ( ir + 3 ) ) ), ( n0_rem * sizeof( float) ) ); + memcpy( buf4, ( c + ( rs_c * ( ir + 4 ) ) ), ( n0_rem * sizeof( float ) ) ); + memcpy( buf5, ( c + ( rs_c * ( ir + 5 ) ) ), ( n0_rem * sizeof( float ) ) ); + + // c[0,0-15] + selector1 = _mm512_loadu_ps( buf0 ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + selector1 = _mm512_loadu_ps( buf1 ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + selector1 = _mm512_loadu_ps( buf2 ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[3,0-15] + selector1 = _mm512_loadu_ps( buf3 ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[4,0-15] + selector1 = _mm512_loadu_ps( buf4 ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); + + // c[5,0-15] + selector1 = _mm512_loadu_ps( buf5 ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_5p0 = _mm512_add_ps( selector1, c_float_5p0 ); + } + + // Store the results. + // c[0,0-15] + _mm512_storeu_ps( buf0, c_float_0p0 ); + + // c[1,0-15] + _mm512_storeu_ps( buf1, c_float_1p0 ); + + // c[2,0-15] + _mm512_storeu_ps( buf2, c_float_2p0 ); + + // c[3,0-15] + _mm512_storeu_ps( buf3, c_float_3p0 ); + + // c[4,0-15] + _mm512_storeu_ps( buf4, c_float_4p0 ); + + // c[5,0-15] + _mm512_storeu_ps( buf5, c_float_5p0 ); + + // Memcpy partial parts. + // c[0,0-15] + memcpy( c + ( rs_c * ( ir + 0 ) ) + ( 0*16 ), buf0, ( n0_rem * sizeof( float ) ) ); + + // c[1,0-15] + memcpy( c + ( rs_c * ( ir + 1 ) ) + ( 0*16 ), buf1, ( n0_rem * sizeof( float ) ) ); + + // c[2,0-15] + memcpy( c + ( rs_c * ( ir + 2 ) ) + ( 0*16 ), buf2, ( n0_rem * sizeof( float ) ) ); + + // c[3,0-15] + memcpy( c + ( rs_c * ( ir + 3 ) ) + ( 0*16 ), buf3, ( n0_rem * sizeof( float ) ) ); + + // c[4,0-15] + memcpy( c + ( rs_c * ( ir + 4 ) ) + ( 0*16 ), buf4, ( n0_rem * sizeof( float ) ) ); + + // c[5,0-15] + memcpy( c + ( rs_c * ( ir + 5 ) ) + ( 0*16 ), buf5, ( n0_rem * sizeof( float ) ) ); + + a = a + ( MR * ps_a ); + } + + if ( m_partial_pieces > 0 ) + { + if ( m_partial_pieces == 5 ) + { + int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 5 ); + lpgemm_rowvar_bf16bf16f32of32_5xlt16 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, n0_rem, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list + ); + } + else if ( m_partial_pieces == 4 ) + { + int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 4 ); + lpgemm_rowvar_bf16bf16f32of32_4xlt16 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, n0_rem, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list + ); + } + else if ( m_partial_pieces == 3 ) + { + int cs_a_use = ( cs_a == 2 ) ? 
2 : ( ( cs_a / 6 ) * 3 ); + lpgemm_rowvar_bf16bf16f32of32_3xlt16 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, n0_rem, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list + ); + } + else if ( m_partial_pieces == 2 ) + { + int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 2 ); + lpgemm_rowvar_bf16bf16f32of32_2xlt16 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, n0_rem, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list + ); + } + else if ( m_partial_pieces == 1 ) + { + int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 1 ); + lpgemm_rowvar_bf16bf16f32of32_1xlt16 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, n0_rem, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list + ); + } + } +} + +// 6x16 bf16 fringe kernel +LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x16) +{ + dim_t MR = 6; + dim_t m_full_pieces = m0 / MR; + dim_t m_full_pieces_loop_limit = m_full_pieces * MR; + dim_t m_partial_pieces = m0 % MR; + + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int32_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + for ( dim_t ir = 0; ir < m_full_pieces_loop_limit; ir += MR ) + { + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + + __m512 c_float_2p0 = _mm512_setzero_ps(); + + __m512 c_float_3p0 = _mm512_setzero_ps(); + + __m512 c_float_4p0 = _mm512_setzero_ps(); + + __m512 c_float_5p0 = _mm512_setzero_ps(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + // Load 2 rows with 16 elements each from B to 1 ZMM registers. It + // is to be noted that the B matrix is packed for use in bf16 + // instructions and each load to ZMM register will have 2 elements + // along k direction and 16 elements across n directions, so 2x16 + // elements to a ZMM register. + b0 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + + // Broadcast a[2,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + + // Broadcast a[3,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-15] = a[3,kr:kr+2]*b[kr:kr+2,0-15] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + + // Broadcast a[4,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. 
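The broadcasts above exploit the fact that two consecutive bfloat16 values of a row of A occupy exactly 32 bits: reinterpreting them as a single int32_t and feeding it to _mm512_set1_epi32 replicates the (a[r,2kr], a[r,2kr+1]) pair into every 32-bit lane, lining it up with the (b[2kr,n], b[2kr+1,n]) pair stored per lane of the packed B panel. A small sketch of that reinterpretation (bf16_t and broadcast_bf16_pair are illustrative; memcpy is used here only to keep the sketch free of aliasing casts):

#include <immintrin.h>
#include <stdint.h>
#include <string.h>

typedef uint16_t bf16_t;                     /* raw 16-bit container for bfloat16 bits */

/* View two adjacent bf16 values of one row of A as one 32-bit word and
   replicate that word into all sixteen 32-bit lanes of a ZMM register. */
static inline __m512i broadcast_bf16_pair(const bf16_t* a_k)
{
    int32_t pair;
    memcpy(&pair, a_k, sizeof(pair));        /* same bytes, no numeric conversion    */
    return _mm512_set1_epi32(pair);          /* lane i = (a_k[0], a_k[1]) for all i  */
}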
+ // c[4,0-15] = a[4,kr:kr+2]*b[kr:kr+2,0-15] + c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); + + // Broadcast a[5,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 5 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[5,0-15] = a[5,kr:kr+2]*b[kr:kr+2,0-15] + c_float_5p0 = _mm512_dpbf16_ps( c_float_5p0, a_bf16_0, b0 ); + } + // Handle k remainder. + + if ( k_partial_pieces > 0 ) + { + b0 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + + // Broadcast a[0,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + + // Broadcast a[1,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + + // Broadcast a[2,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + + // Broadcast a[3,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-15] = a[3,kr:kr+2]*b[kr:kr+2,0-15] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + + // Broadcast a[4,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 4 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[4,0-15] = a[4,kr:kr+2]*b[kr:kr+2,0-15] + c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); + + // Broadcast a[5,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 5 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[5,0-15] = a[5,kr:kr+2]*b[kr:kr+2,0-15] + c_float_5p0 = _mm512_dpbf16_ps( c_float_5p0, a_bf16_0, b0 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + + c_float_2p0 = _mm512_mul_ps( selector1, c_float_2p0 ); + + c_float_3p0 = _mm512_mul_ps( selector1, c_float_3p0 ); + + c_float_4p0 = _mm512_mul_ps( selector1, c_float_4p0 ); + + c_float_5p0 = _mm512_mul_ps( selector1, c_float_5p0 ); + + // Scale C by beta. 
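Together, the alpha scaling above and the beta block below implement the usual GEMM update C = alpha*(A*B) + beta*C, with the beta != 0 guard skipping the read of C entirely, so an output buffer that was never initialized is not loaded when beta is zero. The per-element scalar form, as a small illustrative helper:

/* Scalar form of the epilogue these kernels apply to every element of the
   output tile: c[i,j] = alpha*acc + beta*c[i,j], reading C only if needed. */
static inline void epilogue_elem(float* c, long rs_c, long i, long j,
                                 float acc, float alpha, float beta)
{
    float out = alpha * acc;
    if (beta != 0.0f)
    {
        out += beta * c[(rs_c * i) + j];   /* skip the load when beta == 0 */
    }
    c[(rs_c * i) + j] = out;
}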
+ if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[3,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[4,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); + + // c[5,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_5p0 = _mm512_add_ps( selector1, c_float_5p0 ); + } + + // Store the results. + // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 0*16 ), c_float_0p0 ); + + // c[1,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 0*16 ), c_float_1p0 ); + + // c[2,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 0*16 ), c_float_2p0 ); + + // c[3,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 0*16 ), c_float_3p0 ); + + // c[4,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 0*16 ), c_float_4p0 ); + + // c[5,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 0*16 ), c_float_5p0 ); + + a = a + ( MR * ps_a ); + } + + if ( m_partial_pieces > 0 ) + { + if ( m_partial_pieces == 5 ) + { + int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 5 ); + lpgemm_rowvar_bf16bf16f32of32_5x16 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list + ); + } + else if ( m_partial_pieces == 4 ) + { + int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 4 ); + lpgemm_rowvar_bf16bf16f32of32_4x16 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list + ); + } + else if ( m_partial_pieces == 3 ) + { + int cs_a_use = ( cs_a == 2) ? 2 : ( ( cs_a / 6 ) * 3 ); + lpgemm_rowvar_bf16bf16f32of32_3x16 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list + ); + } + else if ( m_partial_pieces == 2 ) + { + int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 2 ); + lpgemm_rowvar_bf16bf16f32of32_2x16 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list + ); + } + else if ( m_partial_pieces == 1 ) + { + int cs_a_use = ( cs_a == 2 ) ? 
2 : ( ( cs_a / 6 ) * 1 ); + lpgemm_rowvar_bf16bf16f32of32_1x16 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list + ); + } + } +} + +// 6x32 bf16 fringe kernel +LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x32) +{ + dim_t MR = 6; + dim_t m_full_pieces = m0 / MR; + dim_t m_full_pieces_loop_limit = m_full_pieces * MR; + dim_t m_partial_pieces = m0 % MR; + + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int32_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + __m512bh b1; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + for ( dim_t ir = 0; ir < m_full_pieces_loop_limit; ir += MR ) + { + // Registers to use for accumulating C. + __m512 c_float_0p0 = _mm512_setzero_ps(); + __m512 c_float_0p1 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + __m512 c_float_1p1 = _mm512_setzero_ps(); + + __m512 c_float_2p0 = _mm512_setzero_ps(); + __m512 c_float_2p1 = _mm512_setzero_ps(); + + __m512 c_float_3p0 = _mm512_setzero_ps(); + __m512 c_float_3p1 = _mm512_setzero_ps(); + + __m512 c_float_4p0 = _mm512_setzero_ps(); + __m512 c_float_4p1 = _mm512_setzero_ps(); + + __m512 c_float_5p0 = _mm512_setzero_ps(); + __m512 c_float_5p1 = _mm512_setzero_ps(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + // Load 2 rows with 32 elements each from B to 2 ZMM registers. It + // is to be noted that the B matrix is packed for use in bf16 + // instructions and each load to ZMM register will have 2 elements + // along k direction and 32 elements across n directions, so 2x16 + // elements to a ZMM register. + b0 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-31] = a[1,kr:kr+2]*b[kr:kr+2,0-31] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 ); + + // Broadcast a[2,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-31] = a[2,kr:kr+2]*b[kr:kr+2,0-31] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + + // Broadcast a[3,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-31] = a[3,kr:kr+2]*b[kr:kr+2,0-31] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + c_float_3p1 = _mm512_dpbf16_ps( c_float_3p1, a_bf16_0, b1 ); + + // Broadcast a[4,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. 
+ // c[4,0-31] = a[4,kr:kr+2]*b[kr:kr+2,0-31] + c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); + c_float_4p1 = _mm512_dpbf16_ps( c_float_4p1, a_bf16_0, b1 ); + + // Broadcast a[5,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 5 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[5,0-31] = a[5,kr:kr+2]*b[kr:kr+2,0-31] + c_float_5p0 = _mm512_dpbf16_ps( c_float_5p0, a_bf16_0, b0 ); + c_float_5p1 = _mm512_dpbf16_ps( c_float_5p1, a_bf16_0, b1 ); + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + + // Broadcast a[0,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + + // Broadcast a[1,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-31] = a[1,kr:kr+2]*b[kr:kr+2,0-31] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 ); + + // Broadcast a[2,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-31] = a[2,kr:kr+2]*b[kr:kr+2,0-31] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + + // Broadcast a[3,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-31] = a[3,kr:kr+2]*b[kr:kr+2,0-31] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + c_float_3p1 = _mm512_dpbf16_ps( c_float_3p1, a_bf16_0, b1 ); + + // Broadcast a[4,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 4 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[4,0-31] = a[4,kr:kr+2]*b[kr:kr+2,0-31] + c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); + c_float_4p1 = _mm512_dpbf16_ps( c_float_4p1, a_bf16_0, b1 ); + + // Broadcast a[5,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 5 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. 
+ // c[5,0-31] = a[5,kr:kr+2]*b[kr:kr+2,0-31] + c_float_5p0 = _mm512_dpbf16_ps( c_float_5p0, a_bf16_0, b0 ); + c_float_5p1 = _mm512_dpbf16_ps( c_float_5p1, a_bf16_0, b1 ); + } + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + c_float_0p1 = _mm512_mul_ps( selector1, c_float_0p1 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + c_float_1p1 = _mm512_mul_ps( selector1, c_float_1p1 ); + + c_float_2p0 = _mm512_mul_ps( selector1, c_float_2p0 ); + c_float_2p1 = _mm512_mul_ps( selector1, c_float_2p1 ); + + c_float_3p0 = _mm512_mul_ps( selector1, c_float_3p0 ); + c_float_3p1 = _mm512_mul_ps( selector1, c_float_3p1 ); + + c_float_4p0 = _mm512_mul_ps( selector1, c_float_4p0 ); + c_float_4p1 = _mm512_mul_ps( selector1, c_float_4p1 ); + + c_float_5p0 = _mm512_mul_ps( selector1, c_float_5p0 ); + c_float_5p1 = _mm512_mul_ps( selector1, c_float_5p1 ); + + // Scale C by beta. + if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[1,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p1 = _mm512_add_ps( selector1, c_float_1p1 ); + + // c[2,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[2,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p1 = _mm512_add_ps( selector1, c_float_2p1 ); + + // c[3,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[3,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_3p1 = _mm512_add_ps( selector1, c_float_3p1 ); + + // c[4,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); + + // c[4,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_4p1 = _mm512_add_ps( selector1, c_float_4p1 ); + + // c[5,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_5p0 = _mm512_add_ps( selector1, c_float_5p0 ); + + // c[5,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_5p1 = _mm512_add_ps( selector1, c_float_5p1 ); + } + + // Store the results. 
+ // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 0*16 ), c_float_0p0 ); + + // c[0, 16-31] + _mm512_storeu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 1*16 ), c_float_0p1 ); + + // c[1,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 0*16 ), c_float_1p0 ); + + // c[1,16-31] + _mm512_storeu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 1*16 ), c_float_1p1 ); + + // c[2,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 0*16 ), c_float_2p0 ); + + // c[2,16-31] + _mm512_storeu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 1*16 ), c_float_2p1 ); + + // c[3,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 0*16 ), c_float_3p0 ); + + // c[3,16-31] + _mm512_storeu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 1*16 ), c_float_3p1 ); + + // c[4,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 0*16 ), c_float_4p0 ); + + // c[4,16-31] + _mm512_storeu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 1*16 ), c_float_4p1 ); + + // c[5,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 0*16 ), c_float_5p0 ); + + // c[5,16-31] + _mm512_storeu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 1*16 ), c_float_5p1 ); + + a = a + ( MR * ps_a ); + } + + if ( m_partial_pieces > 0 ) + { + if ( m_partial_pieces == 5 ) + { + int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 5 ); + lpgemm_rowvar_bf16bf16f32of32_5x32 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list + ); + } + else if ( m_partial_pieces == 4 ) + { + int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 4 ); + lpgemm_rowvar_bf16bf16f32of32_4x32 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list + ); + } + else if ( m_partial_pieces == 3 ) + { + int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 3 ); + lpgemm_rowvar_bf16bf16f32of32_3x32 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list + ); + } + else if ( m_partial_pieces == 2 ) + { + int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 2 ); + lpgemm_rowvar_bf16bf16f32of32_2x32 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list + ); + } + else if ( m_partial_pieces == 1 ) + { + int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 1 ); + lpgemm_rowvar_bf16bf16f32of32_1x32 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list + ); + } + } +} + +// 6x48 bf16 fringe kernel +LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x48) +{ + dim_t MR = 6; + dim_t m_full_pieces = m0 / MR; + dim_t m_full_pieces_loop_limit = m_full_pieces * MR; + dim_t m_partial_pieces = m0 % MR; + + dim_t k_full_pieces = k0 / 2; + dim_t k_partial_pieces = k0 % 2; + + int32_t a_kfringe_buf = 0; + + // B matrix storage bfloat type + __m512bh b0; + __m512bh b1; + __m512bh b2; + + // A matrix storage bfloat type + __m512bh a_bf16_0; + + for ( dim_t ir = 0; ir < m_full_pieces_loop_limit; ir += MR ) + { + // Registers to use for accumulating C. 
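Before the 6x48 loop below fills its accumulators, note the m-fringe dispatch that closed the 6x32 kernel just above: the leftover m0 % 6 rows are forwarded to the matching 5x../4x../3x../2x../1x.. kernel at row offset m_full_pieces_loop_limit, and only the k stride of A is adjusted. The ternary appears to distinguish an unpacked A (cs_a == 2, two bfloat16 per k step, stride kept as is) from a 6-row packed A panel, whose per-k stride is shrunk in proportion to the rows actually present. A condensed sketch of that stride rule (dim_t replaced by long to keep the sketch self-contained):

/* Stride of A per k step for an m-fringe call with m_rem rows (1..5),
   following the cs_a_use expressions used in the dispatch above. */
static inline long cs_a_for_m_fringe(long cs_a, long m_rem)
{
    /* cs_a == 2 : A consumed unpacked, keep the 2-element k stride.
       otherwise : cs_a is the k stride of a 6-row packed panel, so keep
                   only m_rem rows' worth of it. */
    return (cs_a == 2) ? 2 : ((cs_a / 6) * m_rem);
}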
+ __m512 c_float_0p0 = _mm512_setzero_ps(); + __m512 c_float_0p1 = _mm512_setzero_ps(); + __m512 c_float_0p2 = _mm512_setzero_ps(); + + __m512 c_float_1p0 = _mm512_setzero_ps(); + __m512 c_float_1p1 = _mm512_setzero_ps(); + __m512 c_float_1p2 = _mm512_setzero_ps(); + + __m512 c_float_2p0 = _mm512_setzero_ps(); + __m512 c_float_2p1 = _mm512_setzero_ps(); + __m512 c_float_2p2 = _mm512_setzero_ps(); + + __m512 c_float_3p0 = _mm512_setzero_ps(); + __m512 c_float_3p1 = _mm512_setzero_ps(); + __m512 c_float_3p2 = _mm512_setzero_ps(); + + __m512 c_float_4p0 = _mm512_setzero_ps(); + __m512 c_float_4p1 = _mm512_setzero_ps(); + __m512 c_float_4p2 = _mm512_setzero_ps(); + + __m512 c_float_5p0 = _mm512_setzero_ps(); + __m512 c_float_5p1 = _mm512_setzero_ps(); + __m512 c_float_5p2 = _mm512_setzero_ps(); + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) + { + + // Load 2 rows with 48 elements each from B to 3 ZMM registers. It + // is to be noted that the B matrix is packed for use in bf16 + // instructions and each load to ZMM register will have 2 elements + // along k direction and 16 elements across n directions, so 2x16 + // elements to a ZMM register. + b0 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + + // Broadcast a[0,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-47] = a[0,kr:kr+2]*b[kr:kr+2,0-47] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + + // Broadcast a[1,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-47] = a[1,kr:kr+2]*b[kr:kr+2,0-47] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 ); + c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_0, b2 ); + + // Broadcast a[2,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-47] = a[2,kr:kr+2]*b[kr:kr+2,0-47] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + c_float_2p2 = _mm512_dpbf16_ps( c_float_2p2, a_bf16_0, b2 ); + + // Broadcast a[3,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-47] = a[3,kr:kr+2]*b[kr:kr+2,0-47] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + c_float_3p1 = _mm512_dpbf16_ps( c_float_3p1, a_bf16_0, b1 ); + c_float_3p2 = _mm512_dpbf16_ps( c_float_3p2, a_bf16_0, b2 ); + + // Broadcast a[4,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. + // c[4,0-47] = a[4,kr:kr+2]*b[kr:kr+2,0-47] + c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); + c_float_4p1 = _mm512_dpbf16_ps( c_float_4p1, a_bf16_0, b1 ); + c_float_4p2 = _mm512_dpbf16_ps( c_float_4p2, a_bf16_0, b2 ); + + // Broadcast a[5,kr:kr+2]. + a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 5 ) + ( cs_a * kr ) ) ); + + // Perform column direction mat-mul with k = 2. 
+ // c[5,0-47] = a[5,kr:kr+2]*b[kr:kr+2,0-47] + c_float_5p0 = _mm512_dpbf16_ps( c_float_5p0, a_bf16_0, b0 ); + c_float_5p1 = _mm512_dpbf16_ps( c_float_5p1, a_bf16_0, b1 ); + c_float_5p2 = _mm512_dpbf16_ps( c_float_5p2, a_bf16_0, b2 ); + + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + b0 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b1 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b2 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + + // Broadcast a[0,kr:kr+4]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[0,0-47] = a[0,kr:kr+2]*b[kr:kr+2,0-47] + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + + // Broadcast a[1,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[1,0-47] = a[1,kr:kr+2]*b[kr:kr+2,0-47] + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 ); + c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_0, b2 ); + + // Broadcast a[2,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[2,0-47] = a[2,kr:kr+2]*b[kr:kr+2,0-47] + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + c_float_2p2 = _mm512_dpbf16_ps( c_float_2p2, a_bf16_0, b2 ); + + // Broadcast a[3,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[3,0-47] = a[3,kr:kr+2]*b[kr:kr+2,0-47] + c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); + c_float_3p1 = _mm512_dpbf16_ps( c_float_3p1, a_bf16_0, b1 ); + c_float_3p2 = _mm512_dpbf16_ps( c_float_3p2, a_bf16_0, b2 ); + + // Broadcast a[4,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 4 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. + // c[4,0-47] = a[4,kr:kr+2]*b[kr:kr+2,0-47] + c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); + c_float_4p1 = _mm512_dpbf16_ps( c_float_4p1, a_bf16_0, b1 ); + c_float_4p2 = _mm512_dpbf16_ps( c_float_4p2, a_bf16_0, b2 ); + + // Broadcast a[5,kr:kr+2]. + memcpy + ( + &a_kfringe_buf, + ( a + ( rs_a * 5 ) + ( cs_a * k_full_pieces ) ), + ( k_partial_pieces * sizeof( bfloat16 ) ) + ); + a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + + // Perform column direction mat-mul with k = 2. 
+ // c[5,0-47] = a[5,kr:kr+2]*b[kr:kr+2,0-47] + c_float_5p0 = _mm512_dpbf16_ps( c_float_5p0, a_bf16_0, b0 ); + c_float_5p1 = _mm512_dpbf16_ps( c_float_5p1, a_bf16_0, b1 ); + c_float_5p2 = _mm512_dpbf16_ps( c_float_5p2, a_bf16_0, b2 ); + } + + // Load alpha and beta + __m512 selector1 = _mm512_set1_ps( alpha ); + __m512 selector2 = _mm512_set1_ps( beta ); + + // Scale by alpha + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + c_float_0p1 = _mm512_mul_ps( selector1, c_float_0p1 ); + c_float_0p2 = _mm512_mul_ps( selector1, c_float_0p2 ); + + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + c_float_1p1 = _mm512_mul_ps( selector1, c_float_1p1 ); + c_float_1p2 = _mm512_mul_ps( selector1, c_float_1p2 ); + + c_float_2p0 = _mm512_mul_ps( selector1, c_float_2p0 ); + c_float_2p1 = _mm512_mul_ps( selector1, c_float_2p1 ); + c_float_2p2 = _mm512_mul_ps( selector1, c_float_2p2 ); + + c_float_3p0 = _mm512_mul_ps( selector1, c_float_3p0 ); + c_float_3p1 = _mm512_mul_ps( selector1, c_float_3p1 ); + c_float_3p2 = _mm512_mul_ps( selector1, c_float_3p2 ); + + c_float_4p0 = _mm512_mul_ps( selector1, c_float_4p0 ); + c_float_4p1 = _mm512_mul_ps( selector1, c_float_4p1 ); + c_float_4p2 = _mm512_mul_ps( selector1, c_float_4p2 ); + + c_float_5p0 = _mm512_mul_ps( selector1, c_float_5p0 ); + c_float_5p1 = _mm512_mul_ps( selector1, c_float_5p1 ); + c_float_5p2 = _mm512_mul_ps( selector1, c_float_5p2 ); + + // Scale C by beta. + if ( beta != 0 ) + { + // c[0,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); + + // c[1,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p1 = _mm512_add_ps( selector1, c_float_1p1 ); + + // c[1,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_1p2 = _mm512_add_ps( selector1, c_float_1p2 ); + + // c[2,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[2,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p1 = _mm512_add_ps( selector1, c_float_2p1 ); + + // c[2,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_2p2 = _mm512_add_ps( selector1, c_float_2p2 ); + + // c[3,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[3,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 1*16 ) ); + selector1 = 
_mm512_mul_ps( selector2, selector1 ); + c_float_3p1 = _mm512_add_ps( selector1, c_float_3p1 ); + + // c[3,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_3p2 = _mm512_add_ps( selector1, c_float_3p2 ); + + // c[4,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); + + // c[4,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_4p1 = _mm512_add_ps( selector1, c_float_4p1 ); + + // c[4,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_4p2 = _mm512_add_ps( selector1, c_float_4p2 ); + + // c[5,0-15] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 0*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_5p0 = _mm512_add_ps( selector1, c_float_5p0 ); + + // c[5,16-31] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 1*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_5p1 = _mm512_add_ps( selector1, c_float_5p1 ); + + // c[5,32-47] + selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 2*16 ) ); + selector1 = _mm512_mul_ps( selector2, selector1 ); + c_float_5p2 = _mm512_add_ps( selector1, c_float_5p2 ); + + } + + // Store the results. + // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 0*16 ), c_float_0p0 ); + + // c[0, 16-31] + _mm512_storeu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 1*16 ), c_float_0p1 ); + + // c[0,32-47] + _mm512_storeu_ps( c + ( rs_c * ( ir + 0 ) ) + ( 2*16 ), c_float_0p2 ); + + // c[1,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 0*16 ), c_float_1p0 ); + + // c[1,16-31] + _mm512_storeu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 1*16 ), c_float_1p1 ); + + // c[1,32-47] + _mm512_storeu_ps( c + ( rs_c * ( ir + 1 ) ) + ( 2*16 ), c_float_1p2 ); + + // c[2,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 0*16 ), c_float_2p0 ); + + // c[2,16-31] + _mm512_storeu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 1*16 ), c_float_2p1 ); + + // c[2,32-47] + _mm512_storeu_ps( c + ( rs_c * ( ir + 2 ) ) + ( 2*16 ), c_float_2p2 ); + + // c[3,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 0*16 ), c_float_3p0 ); + + // c[3,16-31] + _mm512_storeu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 1*16 ), c_float_3p1 ); + + // c[3,32-47] + _mm512_storeu_ps( c + ( rs_c * ( ir + 3 ) ) + ( 2*16 ), c_float_3p2 ); + + // c[4,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 0*16 ), c_float_4p0 ); + + // c[4,16-31] + _mm512_storeu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 1*16 ), c_float_4p1 ); + + // c[4,32-47] + _mm512_storeu_ps( c + ( rs_c * ( ir + 4 ) ) + ( 2*16 ), c_float_4p2 ); + + // c[5,0-15] + _mm512_storeu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 0*16 ), c_float_5p0 ); + + // c[5,16-31] + _mm512_storeu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 1*16 ), c_float_5p1 ); + + // c[5,32-47] + _mm512_storeu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 2*16 ), c_float_5p2 ); + + a = a + ( MR * ps_a ); + + } + + if ( m_partial_pieces > 0 ) + { + if ( m_partial_pieces == 5 ) + { + int cs_a_use = ( cs_a == 2 ) ? 
2 : ( ( cs_a / 6 ) * 5 ); + lpgemm_rowvar_bf16bf16f32of32_5x48 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list + ); + } + else if ( m_partial_pieces == 4 ) + { + int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 4 ); + lpgemm_rowvar_bf16bf16f32of32_4x48 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list + ); + } + else if ( m_partial_pieces == 3 ) + { + int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 3 ); + lpgemm_rowvar_bf16bf16f32of32_3x48 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list + ); + } + else if ( m_partial_pieces == 2 ) + { + int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 2 ); + lpgemm_rowvar_bf16bf16f32of32_2x48 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list + ); + } + else if ( m_partial_pieces == 1 ) + { + int cs_a_use = ( cs_a == 2 ) ? 2 : ( ( cs_a / 6 ) * 1 ); + lpgemm_rowvar_bf16bf16f32of32_1x48 + ( + k0, + a, rs_a, cs_a_use, + b, rs_b, cs_b, + ( c + ( rs_c * m_full_pieces_loop_limit ) ), rs_c, + alpha, beta, + is_last_k, + post_op_c_i, post_op_c_j, + post_ops_list + ); + } + } +} diff --git a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_packb_bf16.h b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_packb_bf16.h new file mode 100644 index 0000000000..07b22a5b25 --- /dev/null +++ b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_packb_bf16.h @@ -0,0 +1,67 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +*/ + +#ifndef BLIS_GEMM_BF16_PACKB +#define BLIS_GEMM_BF16_PACKB + +#include "lpgemm_kernels.h" + +BLIS_INLINE dim_t get_packb_bf16bf16f32of32_min_NR() +{ + // This is the minimum NR' required for use in bf16bf16f32 kernels. The idea + // here is that since k needs to be a multiple of 2 (BF16 instr), NR'=16 + // results in total of 2 * NR' = 64 bytes to be loaded, which fits in 1 ZMM + // register. Thus the smallest n fringe kernel dimension has n=16, and thus + // any rounding for buffer sizes should be to 16. + return 16; +} + +void get_packb_nr64_bf16bf16f32of32_strides + ( + dim_t* rs_b, + dim_t* cs_b + ); + +void packb_nr64_bf16bf16f32of32 + ( + bfloat16* pack_b_buffer_bf16bf16f32of32, + const bfloat16* b, + const dim_t ldb, + const dim_t NC, + const dim_t KC, + dim_t* rs_b, + dim_t* cs_b + ); + +#endif //BLIS_GEMM_BF16_PACKB diff --git a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_packb_bf16_amd512vnni.c b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_packb_bf16_amd512vnni.c new file mode 100644 index 0000000000..19725b2768 --- /dev/null +++ b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_packb_bf16_amd512vnni.c @@ -0,0 +1,504 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +*/ + +#include +#include + +#include "blis.h" +#include "lpgemm_config.h" +#include "aocl_bf16_type.h" + +void packb_nrlt16_bf16bf16f32of32 + ( + bfloat16* pack_b_buffer_bf16bf16f32of32, + const bfloat16* b, + const dim_t ldb, + const dim_t KC, + const dim_t n0_partial_rem + ); + +void packb_nr16_bf16bf16f32of32 + ( + bfloat16* pack_b_buffer_bf16bf16f32of32, + const bfloat16* b, + const dim_t ldb, + const dim_t KC + ); + +void packb_nr32_bf16bf16f32of32 + ( + bfloat16* pack_b_buffer_bf16bf16f32of32, + const bfloat16* b, + const dim_t ldb, + const dim_t KC + ); + +void packb_nr48_bf16bf16f32of32 + ( + bfloat16* pack_b_buffer_bf16bf16f32of32, + const bfloat16* b, + const dim_t ldb, + const dim_t KC + ); + +void get_packb_nr64_bf16bf16f32of32_strides + ( + dim_t* rs_b, + dim_t* cs_b + ) +{ + *rs_b = lpgemm_get_block_size_NR_global_cntx( BF16BF16F32OF32 ) * 2; + *cs_b = lpgemm_get_block_size_NR_global_cntx( BF16BF16F32OF32 ) / 2; +} + +void packb_nr64_bf16bf16f32of32 + ( + bfloat16* pack_b_buffer_bf16bf16f32of32, + const bfloat16* b, + const dim_t ldb, + const dim_t NC, + const dim_t KC, + dim_t* rs_b, + dim_t* cs_b + ) +{ + dim_t NR = 64; + + // Used for permuting the mm512i elements for use in dpbf16_ps instruction. + __m512i selector1 = _mm512_setr_epi64(0x0, 0x1, 0x8, 0x9, 0x2, 0x3, 0xA, 0xB); + __m512i selector1_1 = _mm512_setr_epi64( 0x4, 0x5, 0xC, 0xD, 0x6, 0x7, 0xE, 0xF ); + + __m512i a0; + __m512i b0; + __m512i c0; + __m512i d0; + __m512i a01; + __m512i c01; + + dim_t n_full_pieces = NC / NR; + dim_t n_full_pieces_loop_limit = n_full_pieces * NR; + dim_t n_partial_pieces = NC % NR; + + dim_t k_full_pieces_blks = KC / 2; + dim_t k_full_pieces = k_full_pieces_blks * 2; + dim_t k_partial_pieces = KC % 2; + + // KC when not multiple of 2 will have padding to make it multiple of 2 in packed buffer. + dim_t KC_updated = KC; + if ( k_partial_pieces > 0 ) + { + KC_updated += ( 2 - k_partial_pieces ); + } + + for ( dim_t jc = 0; jc < n_full_pieces_loop_limit; jc += NR ) + { + for ( dim_t kr = 0; kr < k_full_pieces; kr += 2 ) + { + // Rearrange for dpbf16_ps, read 2 rows from B with 64 elements in each row. + a0 = _mm512_loadu_epi16( b + ( ldb * ( kr + 0 ) ) + jc ); + b0 = _mm512_loadu_epi16( b + ( ldb * ( kr + 0 ) ) + jc + 32 ); + c0 = _mm512_loadu_epi16( b + ( ldb * ( kr + 1 ) ) + jc ); + d0 = _mm512_loadu_epi16( b + ( ldb * ( kr + 1 ) ) + jc + 32 ); + + a01 = _mm512_unpacklo_epi16( a0, c0 ); + a0 = _mm512_unpackhi_epi16( a0, c0 ); + + c01 = _mm512_unpacklo_epi16( b0, d0 ); + c0 = _mm512_unpackhi_epi16( b0, d0 ); + + b0 = _mm512_permutex2var_epi64( a01, selector1, a0 ); + d0 = _mm512_permutex2var_epi64( c01, selector1, c0 ); + a0 = _mm512_permutex2var_epi64( a01, selector1_1, a0 ); + c0 = _mm512_permutex2var_epi64( c01, selector1_1, c0 ); + + //store to pack_b buffer + _mm512_storeu_epi64( pack_b_buffer_bf16bf16f32of32 + ( jc * KC_updated ) + ( ( kr + 0 ) * NR ), b0 ); + _mm512_storeu_epi64( pack_b_buffer_bf16bf16f32of32 + ( jc * KC_updated ) + ( ( kr + 0 ) * NR ) + 32, a0 ); + _mm512_storeu_epi64( pack_b_buffer_bf16bf16f32of32 + ( jc * KC_updated ) + ( ( kr + 1 ) * NR ), d0 ); + _mm512_storeu_epi64( pack_b_buffer_bf16bf16f32of32 + ( jc * KC_updated ) + ( ( kr + 1 ) * NR ) + 32, c0 ); + } + // Handle k remainder. 
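The stores above define the packed layout the compute kernels assume: for every pair of k rows, the NR = 64 columns are written as four 16-column groups, and within a group each 32-bit slot carries the (B[k,n], B[k+1,n]) pair that _mm512_dpbf16_ps consumes. That is why get_packb_nr64_bf16bf16f32of32_strides reports rs_b = NR*2 elements per k-pair and cs_b = NR/2 elements per 16-column group, why KC is padded up to a multiple of 2 before sizing the buffer, and why 16 columns (16 x 2 k values x 2 bytes = one 64-byte ZMM load) is the minimum packing width noted in the header. A scalar sketch of the index mapping inside one 64-column panel, inferred from the stores above (the helper name is illustrative):

#include <stddef.h>

/* Offset of B[k][n] inside one NR = 64 packed panel; a full panel for
   columns [jc, jc+64) starts at jc * KC_updated elements. */
static inline size_t packed_b_offset(size_t k, size_t n)
{
    const size_t NR = 64;
    size_t kr = k / 2;                  /* k-pair index              */
    size_t kk = k % 2;                  /* position inside the pair  */
    size_t g  = n / 16;                 /* 16-column group           */
    size_t t  = n % 16;                 /* column inside the group   */
    return (kr * (NR * 2))              /* rs_b = NR*2 per k-pair    */
         + (g  * (NR / 2))              /* cs_b = NR/2 per group     */
         + (t * 2) + kk;                /* (k, k+1) pair per column  */
}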
+ if( k_partial_pieces > 0) + { + a0 = _mm512_loadu_epi16( b + ( ldb * ( k_full_pieces + 0 ) ) + jc ); + b0 = _mm512_loadu_epi16( b + ( ldb * ( k_full_pieces + 0 ) ) + jc + 32 ); + c0 = _mm512_setzero_si512(); + d0 = _mm512_setzero_si512(); + + a01 = _mm512_unpacklo_epi16( a0, c0 ); + a0 = _mm512_unpackhi_epi16( a0, c0 ); + + c01 = _mm512_unpacklo_epi16( b0, d0 ); + c0 = _mm512_unpackhi_epi16( b0, d0 ); + + b0 = _mm512_permutex2var_epi64( a01, selector1, a0 ); + d0 = _mm512_permutex2var_epi64( c01, selector1, c0 ); + a0 = _mm512_permutex2var_epi64( a01, selector1_1, a0 ); + c0 = _mm512_permutex2var_epi64( c01, selector1_1, c0 ); + + //store to pack_b buffer + _mm512_storeu_epi64( pack_b_buffer_bf16bf16f32of32 + ( jc * KC_updated ) + ( ( k_full_pieces + 0 ) * NR ), b0 ); + _mm512_storeu_epi64( pack_b_buffer_bf16bf16f32of32 + ( jc * KC_updated ) + ( ( k_full_pieces + 0 ) * NR ) + 32, a0 ); + _mm512_storeu_epi64( pack_b_buffer_bf16bf16f32of32 + ( jc * KC_updated ) + ( ( k_full_pieces + 1 ) * NR ), d0 ); + _mm512_storeu_epi64( pack_b_buffer_bf16bf16f32of32 + ( jc * KC_updated ) + ( ( k_full_pieces + 1 ) * NR ) + 32, c0 ); + } + } + + if(n_partial_pieces > 0) + { + dim_t n0_partial_rem = n_partial_pieces % 16; + dim_t n0_partial_pack = 0; + + // Split into multiple smaller fringe kernels, so as to maximize + // vectorization after packing. Any n0 < NR(64) can be expressed + // as n0 = 48 + n` / n0 = 32 + n` / n0 = 16 + n`, where n` < 16. + dim_t n0_48 = n_partial_pieces / 48; + dim_t n0_32 = n_partial_pieces / 32; + dim_t n0_16 = n_partial_pieces / 16; + + if ( n0_48 == 1 ) + { + packb_nr48_bf16bf16f32of32 + ( + ( pack_b_buffer_bf16bf16f32of32 + ( n_full_pieces_loop_limit * KC_updated ) ), + ( b + n_full_pieces_loop_limit ), ldb, KC + ); + + n0_partial_pack = 48; + } + else if ( n0_32 == 1 ) + { + packb_nr32_bf16bf16f32of32 + ( + ( pack_b_buffer_bf16bf16f32of32 + ( n_full_pieces_loop_limit * KC_updated ) ), + ( b + n_full_pieces_loop_limit ), ldb, KC + ); + + n0_partial_pack = 32; + } + else if ( n0_16 == 1 ) + { + packb_nr16_bf16bf16f32of32 + ( + ( pack_b_buffer_bf16bf16f32of32 + ( n_full_pieces_loop_limit * KC_updated ) ), + ( b + n_full_pieces_loop_limit ), ldb, KC + ); + + n0_partial_pack = 16; + } + + if ( n0_partial_rem > 0 ) + { + packb_nrlt16_bf16bf16f32of32 + ( + ( pack_b_buffer_bf16bf16f32of32 + ( n_full_pieces_loop_limit * KC_updated ) + + ( n0_partial_pack * KC_updated ) ), + ( b + n_full_pieces_loop_limit + n0_partial_pack ), ldb, KC, + n0_partial_rem + ); + } + } + *rs_b = NR * 2; + *cs_b = NR / 2; +} + +void packb_nr48_bf16bf16f32of32 +( + bfloat16* pack_b_buffer_bf16bf16f32of32, + const bfloat16* b, + const dim_t ldb, + const dim_t KC + ) +{ + dim_t NR1 = 32; + dim_t NR2 = 16; + + // Used for permuting the mm512i elements for use in dpbf16_ps instruction. + __m512i selector1 = _mm512_setr_epi64(0x0, 0x1, 0x8, 0x9, 0x2, 0x3, 0xA, 0xB); + __m512i selector1_1 = _mm512_setr_epi64( 0x4, 0x5, 0xC, 0xD, 0x6, 0x7, 0xE, 0xF ); + + __m512i a0x; + __m512i b0x; + __m512i c0x; + __m512i a01x; + + __m256i a0; + __m256i b0; + __m256i c0; + __m256i a01; + + dim_t k_full_pieces_blks = KC / 2; + dim_t k_full_pieces = k_full_pieces_blks * 2; + dim_t k_partial_pieces = KC % 2; + + dim_t kr_new = 0; + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 2 ) + { + // Rearrange for dpbf16_ps, read 2 rows from B with 32 elements in each row. 
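With 48 columns, each k-pair below is emitted as two 32-element ZMM stores covering the first 32 columns plus two 16-element YMM stores covering the last 16 columns, i.e. 48 columns x 2 k values = 96 bfloat16 elements per k-pair, which is why kr_new advances by three NR1 = 32 slots per iteration. For the split decision one level up: a remainder such as n_partial_pieces = 58 takes the 48-wide branch (58 / 48 == 1) and leaves 58 % 16 = 10 columns for packb_nrlt16. A small sketch of the per-k-pair offset implied by that layout (illustrative helper name):

#include <stddef.h>

/* Start of k-pair kr_pair in the nr48 packed buffer: 2 x 32-element ZMM
   stores plus 2 x 16-element YMM stores = 96 elements per k-pair. */
static inline size_t packb_nr48_kpair_offset(size_t kr_pair)
{
    const size_t NR1 = 32;
    return (kr_pair * 3) * NR1;
}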
+ a0x = _mm512_loadu_epi16( b + ( ldb * ( kr + 0 ) ) ); + c0x = _mm512_loadu_epi16( b + ( ldb * ( kr + 1 ) ) ); + + a01x = _mm512_unpacklo_epi16( a0x, c0x ); + a0x = _mm512_unpackhi_epi16( a0x, c0x ); + + b0x = _mm512_permutex2var_epi64( a01x, selector1, a0x ); + a0x = _mm512_permutex2var_epi64( a01x, selector1_1, a0x ); + + //First 2x32 elements + _mm512_storeu_epi64( pack_b_buffer_bf16bf16f32of32 + ( ( kr_new + 0 ) * NR1 ), b0x ); + _mm512_storeu_epi64( pack_b_buffer_bf16bf16f32of32 + ( ( kr_new + 1 ) * NR1 ), a0x ); + + // Rearrange for dpbf16_ps, read 2 rows from B with next 16 elements in each row. + a0 = _mm256_loadu_epi16( b + ( ldb * ( kr + 0 ) ) + NR1 ); + c0 = _mm256_loadu_epi16( b + ( ldb * ( kr + 1 ) ) + NR1 ); + + a01 = _mm256_unpacklo_epi16( a0, c0 ); + a0 = _mm256_unpackhi_epi16( a0, c0 ); + + b0 = _mm256_permute2f128_si256(a01, a0, 0x20); + a0 = _mm256_permute2f128_si256(a01, a0, 0x31); + + //Last 2x16 elements + _mm256_storeu_epi64( pack_b_buffer_bf16bf16f32of32 + ( ( kr_new + 2 ) * NR1 ), b0 ); + _mm256_storeu_epi64( pack_b_buffer_bf16bf16f32of32 + ( ( kr_new + 2 ) * NR1 ) + NR2, a0 ); + + kr_new += 3; + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + a0x = _mm512_loadu_epi16( b + ( ldb * ( k_full_pieces + 0 ) ) ); + c0x = _mm512_setzero_si512(); + + a01x = _mm512_unpacklo_epi16( a0x, c0x ); + a0x = _mm512_unpackhi_epi16( a0x, c0x ); + + b0x = _mm512_permutex2var_epi64( a01x, selector1, a0x ); + a0x = _mm512_permutex2var_epi64( a01x, selector1_1, a0x ); + + //First 2x32 elements + _mm512_storeu_epi64( pack_b_buffer_bf16bf16f32of32 + ( ( kr_new + 0 ) * NR1 ), b0x ); + _mm512_storeu_epi64( pack_b_buffer_bf16bf16f32of32 + ( ( kr_new + 1 ) * NR1 ), a0x ); + + a0 = _mm256_loadu_epi16( b + ( ldb * ( k_full_pieces + 0 ) ) + NR1 ); + c0 = _mm256_setzero_si256(); + + a01 = _mm256_unpacklo_epi16( a0, c0 ); + a0 = _mm256_unpackhi_epi16( a0, c0 ); + + b0 = _mm256_permute2f128_si256(a01, a0, 0x20); + a0 = _mm256_permute2f128_si256(a01, a0, 0x31); + + //Last 2x16 elements + _mm256_storeu_epi64( pack_b_buffer_bf16bf16f32of32 + ( ( kr_new + 2 ) * NR1 ), b0 ); + _mm256_storeu_epi64( pack_b_buffer_bf16bf16f32of32 + ( ( kr_new + 2 ) * NR1 ) + NR2, a0 ); + } +} + +void packb_nr32_bf16bf16f32of32 +( + bfloat16* pack_b_buffer_bf16bf16f32of32, + const bfloat16* b, + const dim_t ldb, + const dim_t KC + ) +{ + dim_t NR = 32; + + // Used for permuting the mm512i elements for use in dpbf16_ps instruction. + __m512i selector1 = _mm512_setr_epi64(0x0, 0x1, 0x8, 0x9, 0x2, 0x3, 0xA, 0xB); + __m512i selector1_1 = _mm512_setr_epi64( 0x4, 0x5, 0xC, 0xD, 0x6, 0x7, 0xE, 0xF ); + + __m512i a0; + __m512i b0; + __m512i c0; + __m512i a01; + + dim_t k_full_pieces_blks = KC / 2; + dim_t k_full_pieces = k_full_pieces_blks * 2; + dim_t k_partial_pieces = KC % 2; + + dim_t kr_new = 0; + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 2 ) + { + // Rearrange for dpbf16_ps, read 2 rows from B with 32 elements in each row. + a0 = _mm512_loadu_epi16( b + ( ldb * ( kr + 0 ) ) ); + c0 = _mm512_loadu_epi16( b + ( ldb * ( kr + 1 ) ) ); + + a01 = _mm512_unpacklo_epi16( a0, c0 ); + a0 = _mm512_unpackhi_epi16( a0, c0 ); + + b0 = _mm512_permutex2var_epi64( a01, selector1, a0 ); + a0 = _mm512_permutex2var_epi64( a01, selector1_1, a0 ); + + _mm512_storeu_epi64( pack_b_buffer_bf16bf16f32of32 + ( ( kr_new ) * NR ), b0 ); + _mm512_storeu_epi64( pack_b_buffer_bf16bf16f32of32 + ( ( kr_new + 1 ) * NR ), a0 ); + + kr_new += 2; + } + // Handle k remainder. 
+ if ( k_partial_pieces > 0 ) + { + a0 = _mm512_loadu_epi16( b + ( ldb * ( k_full_pieces + 0 ) ) ); + c0 = _mm512_setzero_si512(); + + a01 = _mm512_unpacklo_epi16( a0, c0 ); + a0 = _mm512_unpackhi_epi16( a0, c0 ); + + b0 = _mm512_permutex2var_epi64( a01, selector1, a0 ); + a0 = _mm512_permutex2var_epi64( a01, selector1_1, a0 ); + + _mm512_storeu_epi64( pack_b_buffer_bf16bf16f32of32 + ( ( kr_new ) * NR ), b0 ); + _mm512_storeu_epi64( pack_b_buffer_bf16bf16f32of32 + ( ( kr_new + 1 ) * NR ), a0 ); + } +} + +void packb_nr16_bf16bf16f32of32 +( + bfloat16* pack_b_buffer_bf16bf16f32of32, + const bfloat16* b, + const dim_t ldb, + const dim_t KC + ) +{ + dim_t NR = 16; + + __m256i a0; + __m256i b0; + __m256i c0; + __m256i a01; + + dim_t k_full_pieces_blks = KC / 2; + dim_t k_full_pieces = k_full_pieces_blks * 2; + dim_t k_partial_pieces = KC % 2; + + dim_t kr_new = 0; + + for ( dim_t kr = 0; kr < k_full_pieces; kr += 2 ) + { + // Rearrange for dpbf16_ps, read 2 rows from B with 16 elements in each row. + a0 = _mm256_loadu_epi16( b + ( ldb * ( kr + 0 ) ) ); + c0 = _mm256_loadu_epi16( b + ( ldb * ( kr + 1 ) ) ); + + a01 = _mm256_unpacklo_epi16( a0, c0 ); + a0 = _mm256_unpackhi_epi16( a0, c0 ); + + b0 = _mm256_permute2f128_si256(a01, a0, 0x20); + a0 = _mm256_permute2f128_si256(a01, a0, 0x31); + + _mm256_storeu_epi64( pack_b_buffer_bf16bf16f32of32 + ( ( kr_new + 0 ) * NR ), b0 ); + _mm256_storeu_epi64( pack_b_buffer_bf16bf16f32of32 + ( ( kr_new + 1 ) * NR ), a0 ); + + kr_new += 2; + } + // Handle k remainder. + if ( k_partial_pieces > 0 ) + { + a0 = _mm256_loadu_epi16( b + ( ldb * ( k_full_pieces + 0 ) ) ); + c0 = _mm256_setzero_si256(); + + a01 = _mm256_unpacklo_epi16( a0, c0 ); + a0 = _mm256_unpackhi_epi16( a0, c0 ); + + b0 = _mm256_permute2f128_si256(a01, a0, 0x20); + a0 = _mm256_permute2f128_si256(a01, a0, 0x31); + + _mm256_storeu_epi64( pack_b_buffer_bf16bf16f32of32 + ( ( kr_new + 0 ) * NR ), b0 ); + _mm256_storeu_epi64( pack_b_buffer_bf16bf16f32of32 + ( ( kr_new + 1 ) * NR ), a0 ); + } +} + +void packb_nrlt16_bf16bf16f32of32 +( + bfloat16* pack_b_buffer_bf16bf16f32of32, + const bfloat16* b, + const dim_t ldb, + const dim_t KC, + const dim_t n0_partial_rem + ) +{ + dim_t NR = 16; + + __m256i a0; + __m256i b0; + __m256i c0; + __m256i a01; + + dim_t k_full_pieces_blks = KC / 2; + dim_t k_full_pieces = k_full_pieces_blks * 2; + dim_t k_partial_pieces = KC % 2; + + dim_t kr_new = 0; + + bfloat16 buf0[16]; + bfloat16 buf1[16]; + + for ( int kr = 0; kr < k_full_pieces; kr += 2 ) + { + memcpy( buf0, ( b + ( ldb * ( kr + 0 ) ) ), ( n0_partial_rem * sizeof( bfloat16 ) ) ); + memcpy( buf1, ( b + ( ldb * ( kr + 1 ) ) ), ( n0_partial_rem * sizeof( bfloat16 ) ) ); + // Rearrange for dpbf16_ps, read 2 rows from B with next 16 elements in each row. + a0 = _mm256_loadu_epi16( buf0 ); + c0 = _mm256_loadu_epi16( buf1 ); + + a01 = _mm256_unpacklo_epi16( a0, c0 ); + a0 = _mm256_unpackhi_epi16( a0, c0 ); + + b0 = _mm256_permute2f128_si256(a01, a0, 0x20); + a0 = _mm256_permute2f128_si256(a01, a0, 0x31); + + _mm256_storeu_epi64( pack_b_buffer_bf16bf16f32of32 + ( ( kr_new + 0 ) * NR ), b0 ); + _mm256_storeu_epi64( pack_b_buffer_bf16bf16f32of32 + ( ( kr_new + 1 ) * NR ), a0 ); + + kr_new += 2; + } + // Handle k remainder. 
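+	// The remainder row is staged through the buf0 scratch buffer so that
+	// only n0_partial_rem elements are read from B, and it is paired with
+	// a zeroed register just as in the wider packb kernels.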
+ if ( k_partial_pieces > 0 ) + { + memcpy( buf0, ( b + ( ldb * ( k_full_pieces + 0 ) ) ), ( n0_partial_rem * sizeof( bfloat16 ) ) ); + a0 = _mm256_loadu_epi16( buf0 ); + c0 = _mm256_setzero_si256(); + + a01 = _mm256_unpacklo_epi16( a0, c0 ); + a0 = _mm256_unpackhi_epi16( a0, c0 ); + + b0 = _mm256_permute2f128_si256(a01, a0, 0x20); + a0 = _mm256_permute2f128_si256(a01, a0, 0x31); + + _mm256_storeu_epi64( pack_b_buffer_bf16bf16f32of32 + ( ( kr_new + 0 ) * NR ), b0 ); + _mm256_storeu_epi64( pack_b_buffer_bf16bf16f32of32 + ( ( kr_new + 1 ) * NR ), a0 ); + } +} diff --git a/addon/aocl_gemm/kernels/lpgemm_kernels.h b/addon/aocl_gemm/kernels/lpgemm_kernels.h index f2e66d3277..3e79a0ce58 100644 --- a/addon/aocl_gemm/kernels/lpgemm_kernels.h +++ b/addon/aocl_gemm/kernels/lpgemm_kernels.h @@ -36,6 +36,7 @@ #define BLIS_LPGEMM_KERN_H #include "lpgemm_post_ops.h" +#include "aocl_bf16_type.h" #define LPGEMM_MAIN_KERN(A_type,B_type,C_type,LP_SFX) \ void lpgemm_rowvar_ ## LP_SFX \ @@ -63,6 +64,7 @@ void lpgemm_rowvar_ ## LP_SFX \ LPGEMM_MAIN_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x64); LPGEMM_MAIN_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x32); +LPGEMM_MAIN_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_6x64); #define LPGEMM_M_FRINGE_KERN(A_type,B_type,C_type,LP_SFX) \ void lpgemm_rowvar_ ## LP_SFX \ @@ -94,6 +96,12 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4x32); LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2x32); LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1x32); +LPGEMM_M_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_5x64); +LPGEMM_M_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_4x64); +LPGEMM_M_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_3x64); +LPGEMM_M_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_2x64); +LPGEMM_M_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_1x64); + #define LPGEMM_N_FRINGE_KERN(A_type,B_type,C_type,LP_SFX) \ void lpgemm_rowvar_ ## LP_SFX \ ( \ @@ -122,6 +130,10 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x48); LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x16); +LPGEMM_N_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_6x16); +LPGEMM_N_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_6x32); +LPGEMM_N_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_6x48); + #define LPGEMM_N_LT_NR0_FRINGE_KERN(A_type,B_type,C_type,LP_SFX) \ void lpgemm_rowvar_ ## LP_SFX \ ( \ @@ -149,6 +161,8 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6xlt16); LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6xlt16); +LPGEMM_N_LT_NR0_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_6xlt16); + #define LPGEMM_MN_FRINGE_KERN(A_type,B_type,C_type,LP_SFX) \ void lpgemm_rowvar_ ## LP_SFX \ ( \ @@ -189,6 +203,22 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4x16); LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2x16); LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1x16); +LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_5x16); +LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_4x16); +LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_3x16); +LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_2x16); +LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_1x16); +LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_5x32); +LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_4x32); +LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_3x32); 
+LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_2x32); +LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_1x32); +LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_5x48); +LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_4x48); +LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_3x48); +LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_2x48); +LPGEMM_MN_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_1x48); + #define LPGEMM_MN_LT_NR0_FRINGE_KERN(A_type,B_type,C_type,LP_SFX) \ void lpgemm_rowvar_ ## LP_SFX \ ( \ @@ -220,4 +250,10 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4xlt16); LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2xlt16); LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1xlt16); +LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_5xlt16); +LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_4xlt16); +LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_3xlt16); +LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_2xlt16); +LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16,bfloat16,float,bf16bf16f32of32_1xlt16); + #endif //BLIS_LPGEMM_KERN_H diff --git a/bench/bench_aocl_gemm/Makefile b/bench/bench_aocl_gemm/Makefile index 6344fe9396..a760d84688 100755 --- a/bench/bench_aocl_gemm/Makefile +++ b/bench/bench_aocl_gemm/Makefile @@ -87,7 +87,7 @@ TEST_OBJS := $(patsubst $(TEST_SRC_PATH)/%.c, \ CINCFLAGS := -I$(INC_PATH) -I$(CBLAS_HEADER_PATH) # Use the CFLAGS for the configuration family. -CFLAGS := $(call get-user-cflags-for,$(CONFIG_NAME)) +CFLAGS := $(call get-user-cflags-for,$(CONFIG_NAME)) -march=icelake-server -mavx512bf16 # Add local header paths to CFLAGS CFLAGS += -I$(TEST_SRC_PATH) diff --git a/bench/bench_aocl_gemm/bench_lpgemm.c b/bench/bench_aocl_gemm/bench_lpgemm.c deleted file mode 100644 index 1bf9587635..0000000000 --- a/bench/bench_aocl_gemm/bench_lpgemm.c +++ /dev/null @@ -1,726 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include - -#include "blis.h" - -// Mode can be one of the follwoing: -// 1. p - performance, used for benchmarks. -// 2. a - accuracy, used to test accuracy/correctness. -// Default value is p, can be modified by passing command line arg. 
-char bench_mode = 'p'; - -int32_t global_n_repeat = 0; - -#define _XSTR(str) #str -#define XSTR(str) _XSTR(str) - -#define GEN_FUNC_NAME(prototype,ctype) prototype ## ctype - -#define GEN_FILL_ARRAY_FUNC(ctype) \ -void fill_array_ ## ctype ( void* arr, dim_t size ) \ -{ \ - ctype* temp_arr = ( ctype* ) arr; \ - for ( dim_t i = 0; i < size; ++i ) \ - { \ - temp_arr[i] = ( ctype )( i % 100 ); \ - } \ -} \ - -GEN_FILL_ARRAY_FUNC(uint8_t) -GEN_FILL_ARRAY_FUNC(int8_t) -GEN_FILL_ARRAY_FUNC(float) -GEN_FILL_ARRAY_FUNC(int32_t) - -#define GEN_FILL_ARRAY_POST_OPS_FUNC(ctype) \ -void fill_array_post_ops_ ## ctype ( void* arr, dim_t size ) \ -{ \ - ctype* temp_arr = ( ctype* ) arr; \ - for ( dim_t i = 0; i < size; ++i ) \ - { \ - temp_arr[i] = ( ctype )( i % 20 ); \ - } \ -} \ - -GEN_FILL_ARRAY_POST_OPS_FUNC(int16_t) -GEN_FILL_ARRAY_POST_OPS_FUNC(int32_t) -GEN_FILL_ARRAY_POST_OPS_FUNC(float) - -#define GEN_BLIS_MAT_MUL_FUNC(A_type,B_type,C_type,BLAS_SFX) \ -void mat_mul_ ## BLAS_SFX \ - ( \ - char op_t, \ - dim_t m, \ - dim_t n, \ - dim_t k, \ - C_type alpha, \ - A_type* a, \ - dim_t lda, \ - B_type* b, \ - dim_t ldb, \ - C_type beta, \ - C_type* c, \ - dim_t ldc, \ - aocl_post_op* post_op\ - ) \ -{ \ - char storage = 'r'; \ - char transa = 'n'; \ - char transb = 'n'; \ - char reordera = 'n'; \ - char reorderb = 'n'; \ - \ - if ( ( op_t == 'p' ) || ( op_t == 'P' ) ) \ - { \ - /* No reordering of B.*/ \ - reordera = 'n'; \ - reorderb = 'n'; \ - } \ - else if ( ( op_t == 'r' ) || ( op_t == 'R' ) ) \ - { \ - /* Reordered B.*/ \ - reordera = 'n'; \ - reorderb = 'r'; \ - } \ - \ - aocl_gemm_ ## BLAS_SFX( storage, transa, transb, m, n, k, \ - alpha, \ - a, lda, reordera, \ - b, ldb, reorderb, \ - beta, \ - c, ldc, post_op ); \ - \ - /*dim_t MR = 6; \ - dim_t NR = 16; \ - \ - __m512i selector1; \ - __m512i all_zero = _mm512_setzero_epi32(); \ - __m512i c0; \ - __m512i c1; \ - __m512i c2; \ - __m512i c3; \ - __m512i c4; \ - __m512i c5; \ - \ - for ( dim_t i = 0; i < m; i += MR ) \ - { \ - if ( ( i + MR ) > m ) \ - { \ - break; \ - } \ - for ( dim_t j = 0; j < n; j += NR ) \ - { \ - if ( ( j + NR ) > n ) \ - { \ - break; \ - } \ - selector1 = _mm512_loadu_epi32( (int32_t*)post_op->bias.bias + j ); \ - c0 = _mm512_loadu_epi32( c + ( ( i + 0 ) * ldc ) + j ); \ - c1 = _mm512_loadu_epi32( c + ( ( i + 1 ) * ldc ) + j ); \ - c2 = _mm512_loadu_epi32( c + ( ( i + 2 ) * ldc ) + j ); \ - c3 = _mm512_loadu_epi32( c + ( ( i + 3 ) * ldc ) + j ); \ - c4 = _mm512_loadu_epi32( c + ( ( i + 4 ) * ldc ) + j ); \ - c5 = _mm512_loadu_epi32( c + ( ( i + 5 ) * ldc ) + j ); \ - \ - c0 = _mm512_add_epi32( selector1, c0 ); \ - c1 = _mm512_add_epi32( selector1, c1 ); \ - c2 = _mm512_add_epi32( selector1, c2 ); \ - c3 = _mm512_add_epi32( selector1, c3 ); \ - c4 = _mm512_add_epi32( selector1, c4 ); \ - c5 = _mm512_add_epi32( selector1, c5 ); \ - \ - c0 = _mm512_max_epi32( all_zero, c0 ); \ - c1 = _mm512_max_epi32( all_zero, c1 ); \ - c2 = _mm512_max_epi32( all_zero, c2 ); \ - c3 = _mm512_max_epi32( all_zero, c3 ); \ - c4 = _mm512_max_epi32( all_zero, c4 ); \ - c5 = _mm512_max_epi32( all_zero, c5 ); \ - \ - _mm512_storeu_epi32( c + ( ( i + 0 ) * ldc ) + j, c0 ); \ - _mm512_storeu_epi32( c + ( ( i + 1 ) * ldc ) + j, c1 ); \ - _mm512_storeu_epi32( c + ( ( i + 2 ) * ldc ) + j, c2 ); \ - _mm512_storeu_epi32( c + ( ( i + 3 ) * ldc ) + j, c3 ); \ - _mm512_storeu_epi32( c + ( ( i + 4 ) * ldc ) + j, c4 ); \ - _mm512_storeu_epi32( c + ( ( i + 5 ) * ldc ) + j, c5 ); \ - } \ - } */\ -} \ - -GEN_BLIS_MAT_MUL_FUNC(uint8_t, int8_t, int16_t, 
u8s8s16os16) -GEN_BLIS_MAT_MUL_FUNC(uint8_t,int8_t,int32_t,u8s8s32os32) -GEN_BLIS_MAT_MUL_FUNC(float,float,float,f32f32f32of32) - -double get_gflops - ( - dim_t m, - dim_t n, - dim_t k, - double runtime - ) -{ - return ( ( 2.0 * m * n * k ) / ( runtime * 1.0e9 ) ); -} - -void print_result - ( - const char* msg, - int32_t n_repeats, - dim_t m, - dim_t n, - dim_t k, - dim_t lda, - dim_t ldb, - dim_t ldc, - double runtime - ) -{ - double gflops = get_gflops( m, n, k, runtime ); - printf("%s m: %ld, n: %ld, k: %ld, lda: %ld, ldb: %ld, ldc: %ld," \ - " Gops: %f, n_repeats: %d\n", - msg, m, n, k, lda, ldb, ldc, gflops, n_repeats); -} - -#define GEN_MAT_MUL_BENCH_DRV_FUNC(A_type,B_type,C_type,BLAS_SFX) \ -void mat_mul_bench_driver_ ## BLAS_SFX \ - ( \ - char op_t, \ - int32_t n_repeats, \ - dim_t m, \ - dim_t n, \ - dim_t k, \ - C_type alpha, \ - A_type* a, \ - dim_t lda, \ - B_type* b, \ - dim_t ldb, \ - C_type beta, \ - C_type* c, \ - dim_t ldc, \ - aocl_post_op* post_op\ - ) \ -{ \ - double min_time_diff = DBL_MAX; \ - for ( int32_t nr = 0; nr < n_repeats; ++nr ) \ - { \ - if ( bench_mode == 'a' ) \ - { \ - memset( ( void* ) c, 0, sizeof( C_type ) * m * n ); \ - } \ - \ - struct timespec tstart={0,0}, tend={0,0}; \ - clock_gettime(CLOCK_MONOTONIC, &tstart); \ - \ - GEN_FUNC_NAME(mat_mul_,BLAS_SFX) \ - ( \ - op_t, m, n, k, \ - alpha, \ - a, lda, \ - b, ldb, \ - beta, \ - c, ldc, \ - post_op \ - ); \ - \ - clock_gettime(CLOCK_MONOTONIC, &tend); \ - \ - double diff = \ - ( ( double ) tend.tv_sec + ( 1.0e-9 * tend.tv_nsec ) ) - \ - ( ( double ) tstart.tv_sec + ( 1.0e-9 * tstart.tv_nsec ) ); \ - min_time_diff = ( diff < min_time_diff ) ? diff : min_time_diff; \ - } \ - \ - print_result( XSTR(BLAS_SFX), n_repeats, m, n, k, lda, ldb, ldc, min_time_diff); \ -} \ - -GEN_MAT_MUL_BENCH_DRV_FUNC(uint8_t, int8_t, int16_t, u8s8s16os16) -GEN_MAT_MUL_BENCH_DRV_FUNC(uint8_t,int8_t,int32_t,u8s8s32os32) -GEN_MAT_MUL_BENCH_DRV_FUNC(float,float,float,f32f32f32of32) - -#define GEN_MAT_MUL_ACC_CHK_DRV_FUNC(A_type,B_type,C_type,BLAS_SFX) \ -void mat_mul_accuracy_check_driver_ ## BLAS_SFX \ - ( \ - FILE* fout, \ - dim_t m, \ - dim_t n, \ - dim_t k, \ - C_type alpha, \ - A_type* a, \ - dim_t lda, \ - B_type* b, \ - dim_t ldb, \ - C_type beta, \ - C_type* c, \ - dim_t ldc, \ - C_type* c_ref, \ - dim_t ldc_ref, \ - aocl_post_op* post_op\ - ) \ -{ \ - for ( dim_t i = 0; i < m; ++i ) \ - { \ - for ( dim_t j = 0; j < n; ++j ) \ - { \ - C_type temp_accum = 0; \ - \ - for ( dim_t p = 0; p < k; ++p) \ - { \ - temp_accum += ( *( a + ( i * lda ) + p ) * *( b + ( p * ldb ) + j ) ); \ - } \ - \ - temp_accum = ( beta * ( * (c_ref + ( ldc_ref * i ) + j ) ) ) \ - + ( alpha * temp_accum ); \ - \ - if ( post_op != NULL ) \ - { \ - /* Apply bias followed by relu. */ \ - if ( post_op->seq_vector[0] == BIAS ) \ - { \ - if ( post_op->seq_length >= 1 ) \ - { \ - temp_accum += ( *( ( int32_t* )post_op->bias.bias + j ) ); \ - } \ - if ( post_op->seq_length > 1 ) \ - { \ - temp_accum = ( temp_accum > 0 ) ? temp_accum : 0 ; \ - } \ - } \ - else if ( post_op->seq_vector[0] == ELTWISE ) \ - { \ - if ( post_op->seq_length >= 1 ) \ - { \ - temp_accum = ( temp_accum > 0 ) ? 
temp_accum : 0 ; \ - } \ - if ( post_op->seq_length > 1 ) \ - { \ - temp_accum += ( *( ( int32_t* )post_op->bias.bias + j ) ); \ - } \ - } \ - } \ - \ - if ( *( c + ( ldc * i ) + j ) != temp_accum ) \ - { \ - if ( fout ) \ - { \ - fprintf( fout, "%s Failure input m: %ld, n: %ld, k: %ld," \ - " lda: %ld, ldb: %ld, ldc: %ld\n", \ - XSTR(BLAS_SFX), m, n, k, lda, ldb, ldc ); \ - fflush( fout ); \ - } \ - printf("failure, m: %ld, n: %ld, k: %ld\n", i, j, k); \ - goto cleanup_acc; \ - } \ - } \ - } \ -cleanup_acc: \ - return; \ -} \ - -GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t, int8_t, int16_t, u8s8s16os16) -GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,int32_t,u8s8s32os32) -GEN_MAT_MUL_ACC_CHK_DRV_FUNC(float,float,float,f32f32f32of32) - -/* Only supports bias followed by RELU and vice versa for now.*/ \ -#define GEN_MAT_MUL_POST_OPS_CREATOR(C_type,BLAS_SFX) \ -aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \ - ( \ - dim_t m, \ - dim_t n, \ - char* post_ops_str \ - ) \ -{ \ - aocl_post_op* post_ops = NULL; \ - post_ops = ( aocl_post_op* ) malloc( sizeof( aocl_post_op ) ); \ - \ - if ( post_ops == NULL ) \ - { \ - return NULL; \ - } \ - \ - /* Only supporting 2 post ops at max for now.*/ \ - dim_t max_post_ops_seq_length = 2; \ - post_ops->seq_vector = ( AOCL_POST_OP_TYPE* ) \ - malloc \ - ( \ - max_post_ops_seq_length * \ - sizeof( AOCL_POST_OP_TYPE ) \ - ); \ - \ - if ( post_ops->seq_vector == NULL ) \ - { \ - free( post_ops ); \ - return NULL; \ - } \ - \ - /* Parse post ops list.*/ \ - char* ops_tok = strtok(post_ops_str, ", " ); \ - dim_t cur_op_index = 0; \ - while ( ops_tok ) \ - { \ - if ( strcmp( ops_tok, "bias") == 0 ) \ - { \ - post_ops->seq_vector[cur_op_index] = BIAS; \ - } \ - else if ( strcmp( ops_tok, "relu") == 0 ) \ - { \ - post_ops->seq_vector[cur_op_index] = ELTWISE; \ - } \ - ops_tok = strtok( NULL, ", " ); \ - cur_op_index++; \ - } \ - post_ops->seq_length = cur_op_index; \ - \ - /* Allocate bias buffer, return early if alloc fails.*/ \ - post_ops->bias.bias = malloc( n * sizeof( C_type ) ); \ - if ( post_ops->bias.bias == NULL ) \ - { \ - free( post_ops->seq_vector ); \ - free( post_ops ); \ - return NULL; \ - } \ - \ - GEN_FUNC_NAME(fill_array_post_ops_,C_type)( post_ops->bias.bias, n ); \ - \ - post_ops->eltwise.is_power_of_2 = FALSE; \ - post_ops->eltwise.scale_factor = NULL; \ - post_ops->eltwise.algo.alpha = NULL; \ - post_ops->eltwise.algo.beta = NULL; \ - post_ops->eltwise.algo.algo_type = RELU; \ - \ - return post_ops; \ -} \ - -GEN_MAT_MUL_POST_OPS_CREATOR(int16_t,u8s8s16os16) -GEN_MAT_MUL_POST_OPS_CREATOR(int32_t,u8s8s32os32) -GEN_MAT_MUL_POST_OPS_CREATOR(float,f32f32f32of32) - -void lpgemm_destroy_post_ops_struct( aocl_post_op* post_ops ) -{ - if ( post_ops == NULL ) - { - return; - } - - if ( post_ops->bias.bias != NULL ) - { - free( post_ops->bias.bias ); - } - - if( post_ops->seq_vector != NULL ) - { - free( post_ops->seq_vector ); - } - - free( post_ops ); -} - -#define GEN_MAT_MUL_BENCH_MAIN_FUNC(A_type,B_type,C_type,BLAS_SFX) \ -void mat_mul_bench_main_ ## BLAS_SFX \ - ( \ - FILE* fin, \ - FILE* fout, \ - char op_t, \ - int32_t m, \ - int32_t n, \ - int32_t k, \ - int32_t stride_a, \ - int32_t stride_b, \ - int32_t stride_c, \ - char* post_ops_str \ - ) \ -{ \ - if ( ( op_t != 'p' ) && ( op_t != 'P' ) && ( op_t != 'r' ) && ( op_t != 'R' ) ) \ - { \ - printf("The op_t ( 2nd arg in input.txt) is not valid\n"); \ - return; \ - } \ - \ - int32_t n_repeats = bli_max( 30, bli_min(( 3e10 / ( ( int64_t )m * n * k )), 100 )); \ - if ( global_n_repeat > 0 ) \ - { 
\ - n_repeats = global_n_repeat; \ - } \ - \ - /* Get 64 byte aligned memory.*/ \ - A_type* a = ( A_type* ) bli_malloc_user( sizeof( A_type ) * m * k ); \ - \ - B_type* b = ( B_type* ) bli_malloc_user( sizeof( B_type ) * n * k ); \ - \ - C_type* c = ( C_type* ) bli_malloc_user( sizeof( C_type ) * m * n ); \ - memset( ( void* ) c, 0, sizeof( C_type ) * m * n ); \ - \ - C_type* c_ref = ( C_type* ) bli_malloc_user( sizeof( C_type ) * m * n ); \ - memset( ( void* ) c_ref, 0, sizeof( C_type ) * m * n ); \ - \ - C_type alpha; \ - C_type beta; \ - if ( bench_mode == 'p' ) \ - { \ - alpha = 1; \ - beta = 0; \ - } \ - else if ( bench_mode == 'a' ) \ - { \ - alpha = 2; \ - beta = 9; \ - } \ - \ - GEN_FUNC_NAME(fill_array_,A_type)( a, ( m * k ) ); \ - GEN_FUNC_NAME(fill_array_,B_type)( b, ( k * n ) ); \ - \ - aocl_post_op* post_op = NULL; \ - if ( post_ops_str != NULL ) \ - { \ - post_op = GEN_FUNC_NAME(lpgemm_create_post_ops_struct_,BLAS_SFX)( m, n, post_ops_str ); \ - if ( post_op == NULL ) \ - { \ - printf(" post op struct allocation failure, returning.\n"); \ - return; \ - } \ - } \ - \ - if ( ( op_t == 'p' ) || ( op_t == 'P' ) ) \ - { \ - /* No reordering of B.*/ \ - GEN_FUNC_NAME(mat_mul_bench_driver_,BLAS_SFX) \ - ( \ - op_t, n_repeats, m, n, k, \ - alpha, \ - a, stride_a, \ - b, stride_b, \ - beta, \ - c, stride_c, \ - post_op \ - ); \ - } \ - else if ( ( op_t == 'r' ) || ( op_t == 'R' ) ) \ - { \ - /* Reorder B.*/ \ - siz_t b_reorder_buf_siz_req = \ - GEN_FUNC_NAME(aocl_get_reorder_buf_size_,BLAS_SFX)( 'B', k, n ); \ - \ - B_type* b_reorder = ( B_type* ) bli_malloc_user( b_reorder_buf_siz_req ); \ - GEN_FUNC_NAME(aocl_reorder_,BLAS_SFX)( 'B', b, b_reorder, k, n, stride_b ); \ - \ - GEN_FUNC_NAME(mat_mul_bench_driver_,BLAS_SFX) \ - ( \ - op_t, n_repeats, m, n, k, \ - alpha, \ - a, stride_a, \ - b_reorder, stride_b, \ - beta, \ - c, stride_c, \ - post_op \ - ); \ - \ - bli_free_user( b_reorder ); \ - } \ - \ - if ( bench_mode == 'a' ) \ - { \ - printf(" Running accuracy check.\n"); \ - GEN_FUNC_NAME(mat_mul_accuracy_check_driver_,BLAS_SFX) \ - ( \ - fout, m, n, k, \ - alpha, \ - a, stride_a, \ - b, stride_b, \ - beta, \ - c, stride_c, \ - c_ref, stride_c, \ - post_op \ - ); \ - } \ - \ - lpgemm_destroy_post_ops_struct( post_op ); \ - \ - if ( a != NULL ) \ - { \ - bli_free_user( a ); \ - } \ - if ( b != NULL ) \ - { \ - bli_free_user( b ); \ - } \ - if ( c != NULL ) \ - { \ - bli_free_user( c ); \ - } \ - if ( c_ref != NULL ) \ - { \ - bli_free_user( c_ref ); \ - } \ -} \ - -GEN_MAT_MUL_BENCH_MAIN_FUNC(uint8_t, int8_t, int16_t, u8s8s16os16) -GEN_MAT_MUL_BENCH_MAIN_FUNC(uint8_t,int8_t,int32_t,u8s8s32os32) -GEN_MAT_MUL_BENCH_MAIN_FUNC(float,float,float,f32f32f32of32) - -int main( int argc, char** argv ) -{ - FILE* fin = NULL; - if ( argc < 5 ) - { - printf( "Usage: ./mat_mul -i input.txt -m mode < -n 1000 -o op1,op2.. >" \ - "\nMode is either a or p. a is used for accuracy test, " \ - " whereas p is used for performance benchmarking." \ - "\nn_repeats can be set optionally using -n arg." \ - "\nPost ops can be executed optionaly by providing a " \ - "coma separated list of ops after -o arg.\n Currently " \ - "bias,relu and relu,bias is supported.\n" ); - exit( 1 ); - } - - char* file_name = NULL; - char* post_ops_str = NULL; - char* post_ops_str_dest = NULL; //Strtok is used to parse, need to maintain a copy. - - // Parse CLI arguments. 
- opterr = 0; - int opt_val; - while ( ( opt_val = getopt( argc, argv, "i:m:n:o:" ) ) != -1 ) - { - switch ( opt_val ) - { - case 'i': - file_name = optarg; - break; - case 'm': - bench_mode = ( ( ( *optarg ) == 'a' ) || ( ( *optarg ) == 'p' ) ) ? ( *optarg ) : 'p'; - break; - case 'n': - global_n_repeat = ( atoi( optarg ) > 0 ) ? atoi( optarg ) : 0; - break; - case 'o': - post_ops_str = optarg; - break; - default: - break; - } - } - - if ( post_ops_str != NULL ) - { - post_ops_str_dest = strdup( post_ops_str ); - } - - if ( bench_mode == 'p' ) - { - printf( "Running bench in performance benchmarking mode.\n" ); - } - else if ( bench_mode == 'a' ) - { - printf( "Running bench in accuracy/correctness testing mode.\n" ); - } - - if ( file_name == NULL ) - { - printf( " File name provided is invalid.\n" ); - exit( 1 ); - } - - fin = fopen( file_name, "r" ); - if (fin == NULL) - { - printf( "Error opening the file %s\n", argv[1] ); - exit( 1 ); - } - - FILE* fout = NULL; - - fout = fopen( "lpgemm_accuracy_test_failures.txt", "w" ); - - char op_type_char; - char op_t; - int32_t m, n, k; - int32_t stride_a, stride_b, stride_c; - - const dim_t len_list_omp_cores_for_testing = 2; - const dim_t list_omp_cores_for_testing[2] = { 80, 1 }; - - dim_t core_index = 0; - bool can_run = TRUE; - while ( ( can_run == TRUE ) && ( fseek( fin, 0L, SEEK_SET ) == 0 ) ) - { - if ( bench_mode == 'p' ) - { - can_run = FALSE; - } - else if ( bench_mode == 'a' ) - { - // For accuracy testing, we test accuracy using multiple different - // number of cores. This helps uncover any bugs related to over - // subscription or varying thread factorizations. - // Set current number of cores. -#ifdef BLIS_ENABLE_OPENMP - omp_set_num_threads( list_omp_cores_for_testing[core_index] ); -#endif - printf( "Accuracy test using %ld threads.\n", - list_omp_cores_for_testing[core_index] ); - - core_index++; - if ( core_index < len_list_omp_cores_for_testing ) - { - can_run = TRUE; - } - else - { - can_run = FALSE; - } - } - - while ( fscanf( fin, "%c %c %d %d %d %d %d %d\n", - &op_type_char, &op_t, &m, &n, &k, - &stride_a, &stride_b, &stride_c ) == 8 ) - { - if ( ( op_type_char == 'i' ) || ( op_type_char == 'I' ) ) - { - GEN_FUNC_NAME(mat_mul_bench_main_,u8s8s32os32) - ( - fin, fout, op_t, - m, n, k, stride_a, stride_b, stride_c, - post_ops_str_dest - ); - } - else if ( ( op_type_char == 'f' ) || ( op_type_char == 'F' ) ) - { - GEN_FUNC_NAME(mat_mul_bench_main_,f32f32f32of32) - ( - fin, fout, op_t, - m, n, k, stride_a, stride_b, stride_c, - NULL - ); - } - else if ((op_type_char == 's') || (op_type_char == 'S')) - { - GEN_FUNC_NAME(mat_mul_bench_main_, u8s8s16os16) - ( - fin, fout, op_t, - m, n, k, stride_a, stride_b, stride_c, - NULL - ); - } - if ( post_ops_str != NULL ) - { - strcpy( post_ops_str_dest, post_ops_str ); - } - } - } - - if ( post_ops_str_dest != NULL ) - { - free( post_ops_str_dest ); - } - if ( fin ) - { - fclose( fin ); - } - if ( fout ) - { - fclose( fout ); - } - return 0; -} From 584069bf74959c76164f765396c1b501deb9d0c7 Mon Sep 17 00:00:00 2001 From: mkadavil Date: Wed, 24 Aug 2022 19:28:59 +0530 Subject: [PATCH 191/243] Parametric ReLU post-ops support for u8s8s32 and u8s8s16 GEMM. -Parametric ReLU is the generalization of leaky ReLU in which the leakage coefficient is tunable. The support for the same is added following the register-level fusion technique. -Low precision bench enhancement to check accuracy/performance of low precision gemm with PReLU. -Bug fixes in low precision gemm kernels. 
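For reference, the PReLU post-op added here scales only the negative accumulator
values by a user-supplied coefficient and passes non-negative values through
unchanged. A minimal scalar sketch of that intent (illustrative only; the function
name and the use of int32_t are assumptions, not part of the kernel API):

    #include <stdint.h>

    // Parametric ReLU on a single accumulator value: only negative inputs
    // are scaled by the tunable coefficient alpha.
    static inline int32_t prelu_s32_ref( int32_t x, int32_t alpha )
    {
        return ( x >= 0 ) ? x : ( x * alpha );
    }

The vectorized kernels realize the same selection with a compare mask, a multiply
of the negative subset, and a blend (see the RELU_SCALE_OP_S16_AVX2 and
RELU_SCALE_OP_S32_AVX512 macros introduced later in this patch).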
AMD-Internal: [CPUPL-2442] Change-Id: I81336405b185a994297d122b2d868b758ae6dad5 --- Makefile | 8 +- addon/aocl_gemm/aocl_gemm_post_ops.h | 6 +- .../aocl_gemm/frame/bf16bf16f32/lpgemm_bf16.c | 2 - addon/aocl_gemm/frame/lpgemm_post_ops.c | 10 +- addon/aocl_gemm/frame/lpgemm_post_ops.h | 6 +- .../u8s8s16/lpgemm_6x32rowmajor_amd256.c | 53 +- .../kernels/u8s8s16/lpgemm_m_fringe_amd256.c | 76 +- .../kernels/u8s8s16/lpgemm_mn_fringe_amd256.c | 137 +++- .../kernels/u8s8s16/lpgemm_n_fringe_amd256.c | 76 +- .../kernels/u8s8s16/lpgemm_s16_kern_macros.h | 55 ++ .../u8s8s32/lpgemm_6x64rowmajor_amd512vnni.c | 86 +- .../u8s8s32/lpgemm_m_fringe_amd512vnni.c | 246 +++++- .../u8s8s32/lpgemm_mn_fringe_amd512vnni.c | 576 ++++++++++++- .../u8s8s32/lpgemm_n_fringe_amd512vnni.c | 179 +++- .../kernels/u8s8s32/lpgemm_s32_kern_macros.h | 45 ++ bench/bench_aocl_gemm/Makefile | 2 +- bench/bench_aocl_gemm/bench_lpgemm.c | 761 ++++++++++++++++++ 17 files changed, 2232 insertions(+), 92 deletions(-) create mode 100644 addon/aocl_gemm/kernels/u8s8s16/lpgemm_s16_kern_macros.h create mode 100644 addon/aocl_gemm/kernels/u8s8s32/lpgemm_s32_kern_macros.h create mode 100644 bench/bench_aocl_gemm/bench_lpgemm.c diff --git a/Makefile b/Makefile index 7a730d9c25..e066ea36aa 100644 --- a/Makefile +++ b/Makefile @@ -598,12 +598,12 @@ endef define make-c99-addon-rule $(BASE_OBJ_ADDON_PATH)/%.o: $(ADDON_PATH)/%.$(2) $(BLIS_H_FLAT) $(ADDON_H99_FILES) $(MAKE_DEFS_MK_PATHS) ifeq ($(ENABLE_VERBOSE),yes) - $$(if $$(findstring _amd512vnni,$$<),$$(eval LPGEMM_MARCH_VAR=icelake-server),$$(eval LPGEMM_MARCH_VAR=znver3)) - $(CC) -march=$$(LPGEMM_MARCH_VAR) -mavx512bf16 $(call get-addon-c99flags-for,$(1)) -c $$< -o $$@ + $$(if $$(findstring _amd512vnni,$$<),$$(eval LPGEMM_MARCH_VAR=icelake-server -mavx512bf16),$$(eval LPGEMM_MARCH_VAR=znver3)) + $(CC) -march=$$(LPGEMM_MARCH_VAR) $(call get-addon-c99flags-for,$(1)) -c $$< -o $$@ else @echo "Compiling $$@" $(call get-addon-c99text-for,$(1)) - $$(if $$(findstring _amd512vnni,$$<),$$(eval LPGEMM_MARCH_VAR=icelake-server),$$(eval LPGEMM_MARCH_VAR=znver3)) - @$(CC) -march=$$(LPGEMM_MARCH_VAR) -mavx512bf16 $(call get-addon-c99flags-for,$(1)) -c $$< -o $$@ + $$(if $$(findstring _amd512vnni,$$<),$$(eval LPGEMM_MARCH_VAR=icelake-server -mavx512bf16),$$(eval LPGEMM_MARCH_VAR=znver3)) + @$(CC) -march=$$(LPGEMM_MARCH_VAR) $(call get-addon-c99flags-for,$(1)) -c $$< -o $$@ endif endef diff --git a/addon/aocl_gemm/aocl_gemm_post_ops.h b/addon/aocl_gemm/aocl_gemm_post_ops.h index ce69ea0b33..4a739892a4 100644 --- a/addon/aocl_gemm/aocl_gemm_post_ops.h +++ b/addon/aocl_gemm/aocl_gemm_post_ops.h @@ -39,10 +39,8 @@ typedef enum { - LINEAR = 0, - RELU = 1, - GELU = 2, - CLIP = 3, + RELU = 0, + PRELU = 1, } AOCL_ELT_ALGO_TYPE; typedef enum diff --git a/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_bf16.c b/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_bf16.c index 8988537184..3a1f473d58 100644 --- a/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_bf16.c +++ b/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_bf16.c @@ -47,10 +47,8 @@ LPGEMM_5LOOP(bfloat16,bfloat16,float,bf16bf16f32of32) dim_t KC = lpgemm_get_block_size_KC_global_cntx( BF16BF16F32OF32 ); dim_t MC = lpgemm_get_block_size_MC_global_cntx( BF16BF16F32OF32 ); dim_t NR = lpgemm_get_block_size_NR_global_cntx( BF16BF16F32OF32 ); - dim_t MR = lpgemm_get_block_size_MR_global_cntx( BF16BF16F32OF32 ); const int16_t* a_use = NULL; - dim_t rs_a_use = rs_a; dim_t cs_a_use = cs_a; dim_t a_block_stride = 0; diff --git a/addon/aocl_gemm/frame/lpgemm_post_ops.c 
b/addon/aocl_gemm/frame/lpgemm_post_ops.c index 700679772a..45b479c19e 100644 --- a/addon/aocl_gemm/frame/lpgemm_post_ops.c +++ b/addon/aocl_gemm/frame/lpgemm_post_ops.c @@ -103,17 +103,13 @@ void lpgemm_translate_to_post_ops_list // Eltwise algo dispatcher. switch ( post_op_unparsed->eltwise.algo.algo_type ) { - case LINEAR: - tmp_code = POST_OPS_LINEAR; - break; case RELU: tmp_code = POST_OPS_RELU; break; - case GELU: - tmp_code = POST_OPS_GELU; + case PRELU: + tmp_code = POST_OPS_RELU_SCALE; break; - case CLIP: - tmp_code = POST_OPS_CLIP; + default: break; } lpgemm_set_node_params diff --git a/addon/aocl_gemm/frame/lpgemm_post_ops.h b/addon/aocl_gemm/frame/lpgemm_post_ops.h index 325cda5ff7..7dea44b0c7 100644 --- a/addon/aocl_gemm/frame/lpgemm_post_ops.h +++ b/addon/aocl_gemm/frame/lpgemm_post_ops.h @@ -40,10 +40,8 @@ typedef enum POST_OPS_DISABLE = 0, POST_OPS_BIAS = 1, POST_OPS_RELU = 2, - POST_OPS_SUM = 3, - POST_OPS_LINEAR = 4, - POST_OPS_GELU = 5, - POST_OPS_CLIP = 6, + POST_OPS_RELU_SCALE = 3, + POST_OPS_SUM = 4, } LPGEMM_POST_OP_CODE; // Used as an internal structure. diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_6x32rowmajor_amd256.c b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_6x32rowmajor_amd256.c index 64a2293041..d90f195550 100644 --- a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_6x32rowmajor_amd256.c +++ b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_6x32rowmajor_amd256.c @@ -34,6 +34,7 @@ #include #include "blis.h" #include "lpgemm_kernels.h" +#include "lpgemm_s16_kern_macros.h" // 6x32 int8o16 kernel LPGEMM_MAIN_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x32) @@ -42,7 +43,9 @@ LPGEMM_MAIN_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x32) { &&POST_OPS_6x32_DISABLE, &&POST_OPS_BIAS_6x32, - &&POST_OPS_RELU_6x32}; + &&POST_OPS_RELU_6x32, + &&POST_OPS_RELU_SCALE_6x32 + }; dim_t MR = 6; dim_t NR = 32; @@ -70,7 +73,7 @@ LPGEMM_MAIN_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x32) if (n0_16 == 1) { lpgemm_rowvar_u8s8s16o16_6x16( - m0, k0_updated, + m0, k0, a, rs_a, cs_a, ps_a, b, ((rs_b / 2) * 1), cs_b, c, rs_c, @@ -81,12 +84,13 @@ LPGEMM_MAIN_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x32) b = b + (16 * k0_updated); c = c + 16; + post_op_c_j += 16; } if (n0_rem > 0) { lpgemm_rowvar_u8s8s16o16_6xlt16( - m0, k0_updated, + m0, k0, a, rs_a, cs_a, ps_a, b, ((rs_b / 2) * 1), cs_b, c, rs_c, @@ -468,6 +472,49 @@ LPGEMM_MAIN_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x32) // c[5,16-31] c_int16_5p1 = _mm256_max_epi16( selector1, c_int16_5p1 ); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_6x32: + { + selector2 = + _mm256_set1_epi16( *( ( int16_t* )post_ops_list_temp->op_args2 ) ); + + // c[0,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_0p0) + + // c[0,16-31] + RELU_SCALE_OP_S16_AVX2(c_int16_0p1) + + // c[1,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_1p0) + + // c[1,16-31] + RELU_SCALE_OP_S16_AVX2(c_int16_1p1) + + // c[2,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_2p0) + + // c[2,16-31] + RELU_SCALE_OP_S16_AVX2(c_int16_2p1) + + // c[3,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_3p0) + + // c[3,16-31] + RELU_SCALE_OP_S16_AVX2(c_int16_3p1) + + // c[4,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_4p0) + + // c[4,16-31] + RELU_SCALE_OP_S16_AVX2(c_int16_4p1) + + // c[5,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_5p0) + + // c[5,16-31] + RELU_SCALE_OP_S16_AVX2(c_int16_5p1) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_6x32_DISABLE: diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_m_fringe_amd256.c b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_m_fringe_amd256.c index c4071f0428..753a083d93 100644 --- 
a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_m_fringe_amd256.c +++ b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_m_fringe_amd256.c @@ -36,6 +36,7 @@ #include "blis.h" #include "lpgemm_kernels.h" +#include "lpgemm_s16_kern_macros.h" // 4x32 int8o16 kernel LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4x32) @@ -46,7 +47,9 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4x32) { &&POST_OPS_4x32_DISABLE, &&POST_OPS_BIAS_4x32, - &&POST_OPS_RELU_4x32}; + &&POST_OPS_RELU_4x32, + &&POST_OPS_RELU_SCALE_4x32 + }; // The division is done by considering the vpmaddubsw instruction dim_t k_full_pieces = k0 / 2; @@ -314,6 +317,37 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4x32) // c[3,16-31] c_int16_3p1 = _mm256_max_epi16( selector1, c_int16_3p1 ); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_4x32: + { + selector2 = + _mm256_set1_epi16( *( ( int16_t* )post_ops_list_temp->op_args2 ) ); + + // c[0,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_0p0) + + // c[0,16-31] + RELU_SCALE_OP_S16_AVX2(c_int16_0p1) + + // c[1,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_1p0) + + // c[1,16-31] + RELU_SCALE_OP_S16_AVX2(c_int16_1p1) + + // c[2,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_2p0) + + // c[2,16-31] + RELU_SCALE_OP_S16_AVX2(c_int16_2p1) + + // c[3,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_3p0) + + // c[3,16-31] + RELU_SCALE_OP_S16_AVX2(c_int16_3p1) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_4x32_DISABLE: @@ -355,7 +389,9 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2x32) { &&POST_OPS_2x32_DISABLE, &&POST_OPS_BIAS_2x32, - &&POST_OPS_RELU_2x32}; + &&POST_OPS_RELU_2x32, + &&POST_OPS_RELU_SCALE_2x32 + }; // The division is done by considering the vpmaddubsw instruction dim_t k_full_pieces = k0 / 2; @@ -518,6 +554,25 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2x32) // c[1,16-31] c_int16_1p1 = _mm256_max_epi16( selector1, c_int16_1p1 ); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_2x32: + { + selector2 = + _mm256_set1_epi16( *( ( int16_t* )post_ops_list_temp->op_args2 ) ); + + // c[0,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_0p0) + + // c[0,16-31] + RELU_SCALE_OP_S16_AVX2(c_int16_0p1) + + // c[1,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_1p0) + + // c[1,16-31] + RELU_SCALE_OP_S16_AVX2(c_int16_1p1) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_2x32_DISABLE: @@ -546,7 +601,9 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1x32) { &&POST_OPS_1x32_DISABLE, &&POST_OPS_BIAS_1x32, - &&POST_OPS_RELU_1x32}; + &&POST_OPS_RELU_1x32, + &&POST_OPS_RELU_SCALE_1x32 + }; // The division is done by considering the vpmaddubsw instruction dim_t k_full_pieces = k0 / 2; @@ -656,6 +713,19 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1x32) // c[0, 16-31] c_int16_0p1 = _mm256_max_epi16( selector1, c_int16_0p1 ); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_1x32: + { + selector2 = + _mm256_set1_epi16( *( ( int16_t* )post_ops_list_temp->op_args2 ) ); + + // c[0,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_0p0) + + // c[0,16-31] + RELU_SCALE_OP_S16_AVX2(c_int16_0p1) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_1x32_DISABLE: diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_mn_fringe_amd256.c b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_mn_fringe_amd256.c index a3c2ed4ead..77ccd2d27d 100644 --- a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_mn_fringe_amd256.c +++ b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_mn_fringe_amd256.c @@ -36,6 +36,7 @@ #include "blis.h" #include "lpgemm_kernels.h" +#include 
"lpgemm_s16_kern_macros.h" // 4x32 int8o16 kernel LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4x16) @@ -46,7 +47,9 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4x16) { &&POST_OPS_4x16_DISABLE, &&POST_OPS_BIAS_4x16, - &&POST_OPS_RELU_4x16}; + &&POST_OPS_RELU_4x16, + &&POST_OPS_RELU_SCALE_4x16 + }; // The division is done by considering the vpmaddubsw instruction dim_t k_full_pieces = k0 / 2; @@ -236,6 +239,25 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4x16) // c[3,0-15] c_int16_3p0 = _mm256_max_epi16( selector1, c_int16_3p0 ); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_4x16: + { + selector2 = + _mm256_set1_epi16( *( ( int16_t* )post_ops_list_temp->op_args2 ) ); + + // c[0,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_0p0) + + // c[1,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_1p0) + + // c[2,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_2p0) + + // c[3,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_3p0) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_4x16_DISABLE: @@ -264,7 +286,9 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4xlt16) { &&POST_OPS_4xlt16_DISABLE, &&POST_OPS_BIAS_4xlt16, - &&POST_OPS_RELU_4xlt16}; + &&POST_OPS_RELU_4xlt16, + &&POST_OPS_RELU_SCALE_4xlt16 + }; // The division is done by considering the vpmaddubsw instruction dim_t k_full_pieces = k0 / 2; @@ -433,13 +457,11 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4xlt16) POST_OP_LABEL_LASTK_SAFE_JUMP POST_OPS_BIAS_4xlt16: { - int16_t buf4[16]; - - memcpy(buf4, (int16_t *)(post_ops_list_temp->op_args1 - + post_op_c_j + ( 0 * 16 )), (n0_rem * sizeof(int16_t))); + memcpy( buf0, ( ( int16_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 ) ), ( n0_rem * sizeof( int16_t ) ) ); selector1 = - _mm256_loadu_si256( (__m256i const *) buf4 ); + _mm256_loadu_si256( (__m256i const *) buf0 ); // c[0,0-15] c_int16_0p0 = _mm256_add_epi16( selector1, c_int16_0p0 ); @@ -471,6 +493,25 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4xlt16) // c[3,0-15] c_int16_3p0 = _mm256_max_epi16( selector1, c_int16_3p0 ); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_4xlt16: + { + selector2 = + _mm256_set1_epi16( *( ( int16_t* )post_ops_list_temp->op_args2 ) ); + + // c[0,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_0p0) + + // c[1,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_1p0) + + // c[2,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_2p0) + + // c[3,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_3p0) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_4xlt16_DISABLE: @@ -509,7 +550,9 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2x16) { &&POST_OPS_2x16_DISABLE, &&POST_OPS_BIAS_2x16, - &&POST_OPS_RELU_2x16}; + &&POST_OPS_RELU_2x16, + &&POST_OPS_RELU_SCALE_2x16 + }; // The division is done by considering the vpmaddubsw instruction dim_t k_full_pieces = k0 / 2; @@ -631,6 +674,19 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2x16) // c[1,0-15] c_int16_1p0 = _mm256_max_epi16( selector1, c_int16_1p0 ); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_2x16: + { + selector2 = + _mm256_set1_epi16( *( ( int16_t* )post_ops_list_temp->op_args2 ) ); + + // c[0,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_0p0) + + // c[1,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_1p0) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_2x16_DISABLE: @@ -653,7 +709,9 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2xlt16) { &&POST_OPS_2xlt16_DISABLE, &&POST_OPS_BIAS_2xlt16, - &&POST_OPS_RELU_2xlt16}; + 
&&POST_OPS_RELU_2xlt16, + &&POST_OPS_RELU_SCALE_2xlt16 + }; // The division is done by considering the vpmaddubsw instruction dim_t k_full_pieces = k0 / 2; @@ -759,13 +817,11 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2xlt16) POST_OP_LABEL_LASTK_SAFE_JUMP POST_OPS_BIAS_2xlt16: { - int16_t buf4[16]; - - memcpy(buf4, (int16_t *)(post_ops_list_temp->op_args1 + - post_op_c_j + ( 0 * 16 )), (n0_rem * sizeof(int16_t))); + memcpy( buf0, ( ( int16_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 ) ), ( n0_rem * sizeof( int16_t ) ) ); selector1 = - _mm256_loadu_si256( (__m256i const *) buf4); + _mm256_loadu_si256( (__m256i const *) buf0); // c[0,0-15] c_int16_0p0 = _mm256_add_epi16( selector1, c_int16_0p0 ); @@ -785,6 +841,19 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2xlt16) // c[1,0-15] c_int16_1p0 = _mm256_max_epi16( selector1, c_int16_1p0 ); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_2xlt16: + { + selector2 = + _mm256_set1_epi16( *( ( int16_t* )post_ops_list_temp->op_args2 ) ); + + // c[0,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_0p0) + + // c[1,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_1p0) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_2xlt16_DISABLE: @@ -800,7 +869,7 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2xlt16) memcpy(c + (rs_c * 0) + (0 * 16), buf0, (n0_rem * sizeof(int16_t))); // c[1,0-15] - memcpy(c + (rs_c * +1) + (0 * 16), buf1, (n0_rem * sizeof(int16_t))); + memcpy(c + (rs_c * 1) + (0 * 16), buf1, (n0_rem * sizeof(int16_t))); } // 1x16 int8o16 kernel @@ -812,7 +881,9 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1x16) { &&POST_OPS_1x16_DISABLE, &&POST_OPS_BIAS_1x16, - &&POST_OPS_RELU_1x16}; + &&POST_OPS_RELU_1x16, + &&POST_OPS_RELU_SCALE_1x16 + }; // The division is done by considering the vpmaddubsw instruction int k_full_pieces = k0 / 2; @@ -849,7 +920,7 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1x16) { uint8_t a_kfringe; - b0 = _mm256_loadu_si256((__m256i const *)(b + (64 * k_full_pieces) + (NR * 0))); + b0 = _mm256_loadu_si256((__m256i const *)(b + (32 * k_full_pieces) + (NR * 0))); a_kfringe = *(a + (rs_a * 0) + (cs_a * (k_full_pieces * 2))); a_int32_0 = _mm256_set1_epi8(a_kfringe); @@ -899,6 +970,16 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1x16) // c[0,0-15] c_int16_0p0 = _mm256_max_epi16( selector1, c_int16_0p0 ); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_1x16: + { + selector2 = + _mm256_set1_epi16( *( ( int16_t* )post_ops_list_temp->op_args2 ) ); + + // c[0,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_0p0) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_1x16_DISABLE: @@ -918,7 +999,9 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1xlt16) { &&POST_OPS_1xlt16_DISABLE, &&POST_OPS_BIAS_1xlt16, - &&POST_OPS_RELU_1xlt16}; + &&POST_OPS_RELU_1xlt16, + &&POST_OPS_RELU_SCALE_1xlt16 + }; // The division is done by considering the vpmaddubsw instruction int k_full_pieces = k0 / 2; @@ -993,13 +1076,11 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1xlt16) POST_OP_LABEL_LASTK_SAFE_JUMP POST_OPS_BIAS_1xlt16: { - int16_t buf4[16]; - - memcpy(buf4, (int16_t *)(post_ops_list_temp->op_args1 - + post_op_c_j + ( 0 * 16 )), (n0_rem * sizeof(int16_t))); + memcpy( buf0, ( ( int16_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 ) ), ( n0_rem * sizeof( int16_t ) ) ); selector1 = - _mm256_loadu_si256( (__m256i const *)buf4 ); + _mm256_loadu_si256( (__m256i 
const *)buf0 ); // c[0,0-15] c_int16_0p0 = _mm256_add_epi16( selector1, c_int16_0p0 ); @@ -1013,6 +1094,16 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1xlt16) // c[0,0-15] c_int16_0p0 = _mm256_max_epi16( selector1, c_int16_0p0 ); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_1xlt16: + { + selector2 = + _mm256_set1_epi16( *( ( int16_t* )post_ops_list_temp->op_args2 ) ); + + // c[0,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_0p0) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_1xlt16_DISABLE: diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_n_fringe_amd256.c b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_n_fringe_amd256.c index 7a34a636ac..631e01f27f 100644 --- a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_n_fringe_amd256.c +++ b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_n_fringe_amd256.c @@ -36,6 +36,7 @@ #include "blis.h" #include "lpgemm_kernels.h" +#include "lpgemm_s16_kern_macros.h" // 6x16 int8o16 kernel LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x16) @@ -47,7 +48,9 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x16) { &&POST_OPS_6x16_DISABLE, &&POST_OPS_BIAS_6x16, - &&POST_OPS_RELU_6x16}; + &&POST_OPS_RELU_6x16, + &&POST_OPS_RELU_SCALE_6x16 + }; dim_t m_full_pieces = m0 / MR; dim_t m_full_pieces_loop_limit = m_full_pieces * MR; @@ -316,6 +319,31 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x16) // c[5,0-15] c_int16_5p0 = _mm256_max_epi16( selector1, c_int16_5p0 ); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_6x16: + { + selector2 = + _mm256_set1_epi16( *( ( int16_t* )post_ops_list_temp->op_args2 ) ); + + // c[0,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_0p0) + + // c[1,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_1p0) + + // c[2,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_2p0) + + // c[3,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_3p0) + + // c[4,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_4p0) + + // c[5,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_5p0) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_6x16_DISABLE: @@ -410,7 +438,9 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6xlt16) { &&POST_OPS_6xlt16_DISABLE, &&POST_OPS_BIAS_6xlt16, - &&POST_OPS_RELU_6xlt16}; + &&POST_OPS_RELU_6xlt16, + &&POST_OPS_RELU_SCALE_6xlt16 + }; dim_t m_full_pieces = m0 / MR; dim_t m_full_pieces_loop_limit = m_full_pieces * MR; @@ -646,13 +676,12 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6xlt16) POST_OP_LABEL_LASTK_SAFE_JUMP POST_OPS_BIAS_6xlt16: { - int16_t buf6[16]; - - memcpy(buf0, (int16_t *)(post_ops_list_temp->op_args1 + - post_op_c_j + ( 0 * 16 )), (n0_rem * sizeof(int16_t))); + memcpy( buf0, ( ( int16_t* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 ) ), + ( n0_rem * sizeof( int16_t ) ) ); selector1 = - _mm256_loadu_si256( (__m256i const *)buf6 ); + _mm256_loadu_si256( ( __m256i const* )buf0 ); // c[0,0-15] c_int16_0p0 = _mm256_add_epi16( selector1, c_int16_0p0 ); @@ -696,6 +725,31 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6xlt16) // c[5,0-15] c_int16_5p0 = _mm256_max_epi16( selector1, c_int16_5p0 ); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_6xlt16: + { + selector2 = + _mm256_set1_epi16( *( ( int16_t* )post_ops_list_temp->op_args2 ) ); + + // c[0,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_0p0) + + // c[1,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_1p0) + + // c[2,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_2p0) + + // c[3,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_3p0) + + // c[4,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_4p0) + + // 
c[5,0-15] + RELU_SCALE_OP_S16_AVX2(c_int16_5p0) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_6xlt16_DISABLE: @@ -703,16 +757,16 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6xlt16) // Store the results. // c[0,0-15] - _mm256_storeu_si256((__m256i_u *)buf0, c_int16_0p0); + _mm256_storeu_si256((__m256i *)buf0, c_int16_0p0); // c[1,0-15] - _mm256_storeu_si256((__m256i_u *)buf1, c_int16_1p0); + _mm256_storeu_si256((__m256i *)buf1, c_int16_1p0); // c[2,0-15] - _mm256_storeu_si256((__m256i_u *)buf2, c_int16_2p0); + _mm256_storeu_si256((__m256i *)buf2, c_int16_2p0); // c[3,0-15] - _mm256_storeu_si256((__m256i_u *)buf3, c_int16_3p0); + _mm256_storeu_si256((__m256i *)buf3, c_int16_3p0); // c[4,0-15] _mm256_storeu_si256((__m256i *)buf4, c_int16_4p0); diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_s16_kern_macros.h b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_s16_kern_macros.h new file mode 100644 index 0000000000..07131d1099 --- /dev/null +++ b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_s16_kern_macros.h @@ -0,0 +1,55 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef LPGEMM_S16_KERN_MACROS_H +#define LPGEMM_S16_KERN_MACROS_H + +#define RELU_SCALE_OP_S16_AVX2(reg) \ + selector1 = _mm256_setzero_si256();\ + selector1 = _mm256_cmpgt_epi16 ( selector1, reg ); \ + \ + /* Only < 0 elements in b0. */ \ + b0 = _mm256_and_si256 ( selector1, reg ); \ +\ + /* Only >= 0 elements in c_int16_0p0. */ \ + reg = _mm256_andnot_si256( selector1, reg ); \ + \ + /* Only scaling for < 0 elements. */ \ + b0 = _mm256_mullo_epi16( b0, selector2 ); \ + \ + /* Combine the scaled < 0 and >= 0 elements. 
*/ \ + reg = _mm256_or_si256( b0, reg ); \ + \ + +#endif //LPGEMM_S16_KERN_MACROS_H diff --git a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_6x64rowmajor_amd512vnni.c b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_6x64rowmajor_amd512vnni.c index 109b4784fa..64c61a9c57 100644 --- a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_6x64rowmajor_amd512vnni.c +++ b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_6x64rowmajor_amd512vnni.c @@ -36,6 +36,7 @@ #include "blis.h" #include "lpgemm_kernels.h" +#include "lpgemm_s32_kern_macros.h" // 6x64 int8o32 kernel LPGEMM_MAIN_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x64) @@ -44,7 +45,8 @@ LPGEMM_MAIN_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x64) { &&POST_OPS_6x64_DISABLE, &&POST_OPS_BIAS_6x64, - &&POST_OPS_RELU_6x64 + &&POST_OPS_RELU_6x64, + &&POST_OPS_RELU_SCALE_6x64 }; dim_t MR = 6; @@ -716,6 +718,88 @@ LPGEMM_MAIN_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x64) // c[5,48-63] c_int32_5p3 = _mm512_max_epi32( selector1, c_int32_5p3 ); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_6x64: + { + selector1 = _mm512_setzero_epi32(); + selector2 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_0p1) + + // c[0, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_0p2) + + // c[0, 48-63] + RELU_SCALE_OP_S32_AVX512(c_int32_0p3) + + // c[1, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_1p1) + + // c[1, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_1p2) + + // c[1, 48-63] + RELU_SCALE_OP_S32_AVX512(c_int32_1p3) + + // c[2, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_2p0) + + // c[2, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_2p1) + + // c[2, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_2p2) + + // c[2, 48-63] + RELU_SCALE_OP_S32_AVX512(c_int32_2p3) + + // c[3, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_3p0) + + // c[3, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_3p1) + + // c[3, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_3p2) + + // c[3, 48-63] + RELU_SCALE_OP_S32_AVX512(c_int32_3p3) + + // c[4, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_4p0) + + // c[4, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_4p1) + + // c[4, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_4p2) + + // c[4, 48-63] + RELU_SCALE_OP_S32_AVX512(c_int32_4p3) + + // c[5, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_5p0) + + // c[5, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_5p1) + + // c[5, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_5p2) + + // c[5, 48-63] + RELU_SCALE_OP_S32_AVX512(c_int32_5p3) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_6x64_DISABLE: diff --git a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_m_fringe_amd512vnni.c b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_m_fringe_amd512vnni.c index 934260e7ed..896aa23d60 100644 --- a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_m_fringe_amd512vnni.c +++ b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_m_fringe_amd512vnni.c @@ -37,6 +37,7 @@ #include "blis.h" #include "lpgemm_kernels.h" +#include "lpgemm_s32_kern_macros.h" // 5x64 int8o32 kernel LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5x64) @@ -45,7 +46,8 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5x64) { &&POST_OPS_5x64_DISABLE, &&POST_OPS_BIAS_5x64, - &&POST_OPS_RELU_5x64 + &&POST_OPS_RELU_5x64, + &&POST_OPS_RELU_SCALE_5x64 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -522,6 +524,76 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5x64) // c[4,48-63] c_int32_4p3 = 
_mm512_max_epi32( selector1, c_int32_4p3 ); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_5x64: + { + selector1 = _mm512_setzero_epi32(); + selector2 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_0p1) + + // c[0, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_0p2) + + // c[0, 48-63] + RELU_SCALE_OP_S32_AVX512(c_int32_0p3) + + // c[1, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_1p1) + + // c[1, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_1p2) + + // c[1, 48-63] + RELU_SCALE_OP_S32_AVX512(c_int32_1p3) + + // c[2, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_2p0) + + // c[2, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_2p1) + + // c[2, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_2p2) + + // c[2, 48-63] + RELU_SCALE_OP_S32_AVX512(c_int32_2p3) + + // c[3, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_3p0) + + // c[3, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_3p1) + + // c[3, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_3p2) + + // c[3, 48-63] + RELU_SCALE_OP_S32_AVX512(c_int32_3p3) + + // c[4, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_4p0) + + // c[4, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_4p1) + + // c[4, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_4p2) + + // c[4, 48-63] + RELU_SCALE_OP_S32_AVX512(c_int32_4p3) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_5x64_DISABLE: @@ -596,7 +668,8 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4x64) { &&POST_OPS_4x64_DISABLE, &&POST_OPS_BIAS_4x64, - &&POST_OPS_RELU_4x64 + &&POST_OPS_RELU_4x64, + &&POST_OPS_RELU_SCALE_4x64 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -991,6 +1064,64 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4x64) // c[3,48-63] c_int32_3p3 = _mm512_max_epi32( selector1, c_int32_3p3 ); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_4x64: + { + selector1 = _mm512_setzero_epi32(); + selector2 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_0p1) + + // c[0, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_0p2) + + // c[0, 48-63] + RELU_SCALE_OP_S32_AVX512(c_int32_0p3) + + // c[1, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_1p1) + + // c[1, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_1p2) + + // c[1, 48-63] + RELU_SCALE_OP_S32_AVX512(c_int32_1p3) + + // c[2, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_2p0) + + // c[2, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_2p1) + + // c[2, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_2p2) + + // c[2, 48-63] + RELU_SCALE_OP_S32_AVX512(c_int32_2p3) + + // c[3, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_3p0) + + // c[3, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_3p1) + + // c[3, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_3p2) + + // c[3, 48-63] + RELU_SCALE_OP_S32_AVX512(c_int32_3p3) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_4x64_DISABLE: @@ -1053,7 +1184,8 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3x64) { &&POST_OPS_3x64_DISABLE, &&POST_OPS_BIAS_3x64, - &&POST_OPS_RELU_3x64 + &&POST_OPS_RELU_3x64, + &&POST_OPS_RELU_SCALE_3x64 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -1366,6 +1498,52 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3x64) // 
c[2,48-63] c_int32_2p3 = _mm512_max_epi32( selector1, c_int32_2p3 ); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_3x64: + { + selector1 = _mm512_setzero_epi32(); + selector2 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_0p1) + + // c[0, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_0p2) + + // c[0, 48-63] + RELU_SCALE_OP_S32_AVX512(c_int32_0p3) + + // c[1, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_1p1) + + // c[1, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_1p2) + + // c[1, 48-63] + RELU_SCALE_OP_S32_AVX512(c_int32_1p3) + + // c[2, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_2p0) + + // c[2, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_2p1) + + // c[2, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_2p2) + + // c[2, 48-63] + RELU_SCALE_OP_S32_AVX512(c_int32_2p3) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_3x64_DISABLE: @@ -1416,7 +1594,8 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2x64) { &&POST_OPS_2x64_DISABLE, &&POST_OPS_BIAS_2x64, - &&POST_OPS_RELU_2x64 + &&POST_OPS_RELU_2x64, + &&POST_OPS_RELU_SCALE_2x64 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -1647,6 +1826,40 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2x64) // c[1,48-63] c_int32_1p3 = _mm512_max_epi32( selector1, c_int32_1p3 ); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_2x64: + { + selector1 = _mm512_setzero_epi32(); + selector2 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_0p1) + + // c[0, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_0p2) + + // c[0, 48-63] + RELU_SCALE_OP_S32_AVX512(c_int32_0p3) + + // c[1, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_1p1) + + // c[1, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_1p2) + + // c[1, 48-63] + RELU_SCALE_OP_S32_AVX512(c_int32_1p3) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_2x64_DISABLE: @@ -1685,7 +1898,8 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1x64) { &&POST_OPS_1x64_DISABLE, &&POST_OPS_BIAS_1x64, - &&POST_OPS_RELU_1x64 + &&POST_OPS_RELU_1x64, + &&POST_OPS_RELU_SCALE_1x64 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -1834,6 +2048,28 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1x64) // c[0,48-63] c_int32_0p3 = _mm512_max_epi32( selector1, c_int32_0p3 ); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_1x64: + { + selector1 = _mm512_setzero_epi32(); + selector2 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_0p1) + + // c[0, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_0p2) + + // c[0, 48-63] + RELU_SCALE_OP_S32_AVX512(c_int32_0p3) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_1x64_DISABLE: diff --git a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_mn_fringe_amd512vnni.c b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_mn_fringe_amd512vnni.c index c8642475f5..07c2d8fef9 100644 --- a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_mn_fringe_amd512vnni.c +++ 
b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_mn_fringe_amd512vnni.c @@ -37,6 +37,7 @@ #include "blis.h" #include "lpgemm_kernels.h" +#include "lpgemm_s32_kern_macros.h" // 5xlt16 int8o32 fringe kernel LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5xlt16) @@ -45,7 +46,8 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5xlt16) { &&POST_OPS_5xLT16_DISABLE, &&POST_OPS_BIAS_5xLT16, - &&POST_OPS_RELU_5xLT16 + &&POST_OPS_RELU_5xLT16, + &&POST_OPS_RELU_SCALE_5xLT16 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -251,6 +253,31 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5xlt16) // c[4,0-15] c_int32_4p0 = _mm512_max_epi32( selector1, c_int32_4p0 ); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_5xLT16: + { + selector1 = _mm512_setzero_epi32(); + selector2 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_0p0) + + // c[1, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_1p0) + + // c[2, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_2p0) + + // c[3, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_3p0) + + // c[4, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_4p0) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_5xLT16_DISABLE: @@ -297,7 +324,8 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4xlt16) { &&POST_OPS_4xLT16_DISABLE, &&POST_OPS_BIAS_4xLT16, - &&POST_OPS_RELU_4xLT16 + &&POST_OPS_RELU_4xLT16, + &&POST_OPS_RELU_SCALE_4xLT16 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -472,6 +500,28 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4xlt16) // c[3,0-15] c_int32_3p0 = _mm512_max_epi32( selector1, c_int32_3p0 ); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_4xLT16: + { + selector1 = _mm512_setzero_epi32(); + selector2 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_0p0) + + // c[1, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_1p0) + + // c[2, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_2p0) + + // c[3, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_3p0) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_4xLT16_DISABLE: @@ -512,7 +562,8 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3xlt16) { &&POST_OPS_3xLT16_DISABLE, &&POST_OPS_BIAS_3xLT16, - &&POST_OPS_RELU_3xLT16 + &&POST_OPS_RELU_3xLT16, + &&POST_OPS_RELU_SCALE_3xLT16 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -654,6 +705,25 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3xlt16) // c[2,0-15] c_int32_2p0 = _mm512_max_epi32( selector1, c_int32_2p0 ); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_3xLT16: + { + selector1 = _mm512_setzero_epi32(); + selector2 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_0p0) + + // c[1, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_1p0) + + // c[2, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_2p0) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_3xLT16_DISABLE: @@ -688,7 +758,8 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2xlt16) { &&POST_OPS_2xLT16_DISABLE, &&POST_OPS_BIAS_2xLT16, - &&POST_OPS_RELU_2xLT16 + &&POST_OPS_RELU_2xLT16, + &&POST_OPS_RELU_SCALE_2xLT16 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -798,6 
+869,22 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2xlt16) // c[1,0-15] c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_2xLT16: + { + selector1 = _mm512_setzero_epi32(); + selector2 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_0p0) + + // c[1, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_1p0) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_2xLT16_DISABLE: @@ -826,7 +913,8 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1xlt16) { &&POST_OPS_1xLT16_DISABLE, &&POST_OPS_BIAS_1xLT16, - &&POST_OPS_RELU_1xLT16 + &&POST_OPS_RELU_1xLT16, + &&POST_OPS_RELU_SCALE_1xLT16 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -904,6 +992,19 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1xlt16) // c[0,0-15] c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_1xLT16: + { + selector1 = _mm512_setzero_epi32(); + selector2 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_0p0) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_1xLT16_DISABLE: @@ -926,7 +1027,8 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5x16) { &&POST_OPS_5x16_DISABLE, &&POST_OPS_BIAS_5x16, - &&POST_OPS_RELU_5x16 + &&POST_OPS_RELU_5x16, + &&POST_OPS_RELU_SCALE_5x16 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -1118,6 +1220,31 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5x16) // c[4,0-15] c_int32_4p0 = _mm512_max_epi32( selector1, c_int32_4p0 ); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_5x16: + { + selector1 = _mm512_setzero_epi32(); + selector2 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_0p0) + + // c[1, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_1p0) + + // c[2, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_2p0) + + // c[3, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_3p0) + + // c[4, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_4p0) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_5x16_DISABLE: @@ -1147,7 +1274,8 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4x16) { &&POST_OPS_4x16_DISABLE, &&POST_OPS_BIAS_4x16, - &&POST_OPS_RELU_4x16 + &&POST_OPS_RELU_4x16, + &&POST_OPS_RELU_SCALE_4x16 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -1309,6 +1437,28 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4x16) // c[3,0-15] c_int32_3p0 = _mm512_max_epi32( selector1, c_int32_3p0 ); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_4x16: + { + selector1 = _mm512_setzero_epi32(); + selector2 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_0p0) + + // c[1, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_1p0) + + // c[2, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_2p0) + + // c[3, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_3p0) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_4x16_DISABLE: @@ -1335,7 +1485,8 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3x16) { &&POST_OPS_3x16_DISABLE, &&POST_OPS_BIAS_3x16, - &&POST_OPS_RELU_3x16 + 
&&POST_OPS_RELU_3x16, + &&POST_OPS_RELU_SCALE_3x16 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -1467,6 +1618,25 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3x16) // c[2,0-15] c_int32_2p0 = _mm512_max_epi32( selector1, c_int32_2p0 ); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_3x16: + { + selector1 = _mm512_setzero_epi32(); + selector2 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_0p0) + + // c[1, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_1p0) + + // c[2, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_2p0) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_3x16_DISABLE: @@ -1490,7 +1660,8 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2x16) { &&POST_OPS_2x16_DISABLE, &&POST_OPS_BIAS_2x16, - &&POST_OPS_RELU_2x16 + &&POST_OPS_RELU_2x16, + &&POST_OPS_RELU_SCALE_2x16 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -1592,6 +1763,22 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2x16) // c[1,0-15] c_int32_1p0 = _mm512_max_epi32( selector1, c_int32_1p0 ); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_2x16: + { + selector1 = _mm512_setzero_epi32(); + selector2 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_0p0) + + // c[1, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_1p0) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_2x16_DISABLE: @@ -1612,7 +1799,8 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1x16) { &&POST_OPS_1x16_DISABLE, &&POST_OPS_BIAS_1x16, - &&POST_OPS_RELU_1x16 + &&POST_OPS_RELU_1x16, + &&POST_OPS_RELU_SCALE_1x16 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -1684,6 +1872,19 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1x16) // c[0,0-15] c_int32_0p0 = _mm512_max_epi32( selector1, c_int32_0p0 ); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_1x16: + { + selector1 = _mm512_setzero_epi32(); + selector2 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_0p0) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_1x16_DISABLE: @@ -1701,7 +1902,8 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5x32) { &&POST_OPS_5x32_DISABLE, &&POST_OPS_BIAS_5x32, - &&POST_OPS_RELU_5x32 + &&POST_OPS_RELU_5x32, + &&POST_OPS_RELU_SCALE_5x32 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -1980,6 +2182,46 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5x32) // c[4,16-31] c_int32_4p1 = _mm512_max_epi32( selector1, c_int32_4p1 ); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_5x32: + { + selector1 = _mm512_setzero_epi32(); + selector2 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_0p1) + + // c[1, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_1p1) + + // c[2, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_2p0) + + // c[2, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_2p1) + + // c[3, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_3p0) + + // c[3, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_3p1) + + // c[4, 0-15] + 
RELU_SCALE_OP_S32_AVX512(c_int32_4p0) + + // c[4, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_4p1) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_5x32_DISABLE: @@ -2024,7 +2266,8 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4x32) { &&POST_OPS_4x32_DISABLE, &&POST_OPS_BIAS_4x32, - &&POST_OPS_RELU_4x32 + &&POST_OPS_RELU_4x32, + &&POST_OPS_RELU_SCALE_4x32 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -2258,6 +2501,40 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4x32) // c[3,16-31] c_int32_3p1 = _mm512_max_epi32( selector1, c_int32_3p1 ); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_4x32: + { + selector1 = _mm512_setzero_epi32(); + selector2 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_0p1) + + // c[1, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_1p1) + + // c[2, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_2p0) + + // c[2, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_2p1) + + // c[3, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_3p0) + + // c[3, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_3p1) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_4x32_DISABLE: @@ -2296,7 +2573,8 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3x32) { &&POST_OPS_3x32_DISABLE, &&POST_OPS_BIAS_3x32, - &&POST_OPS_RELU_3x32 + &&POST_OPS_RELU_3x32, + &&POST_OPS_RELU_SCALE_3x32 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -2485,6 +2763,34 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3x32) // c[2,16-31] c_int32_2p1 = _mm512_max_epi32( selector1, c_int32_2p1 ); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_3x32: + { + selector1 = _mm512_setzero_epi32(); + selector2 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_0p1) + + // c[1, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_1p1) + + // c[2, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_2p0) + + // c[2, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_2p1) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_3x32_DISABLE: @@ -2517,7 +2823,8 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2x32) { &&POST_OPS_2x32_DISABLE, &&POST_OPS_BIAS_2x32, - &&POST_OPS_RELU_2x32 + &&POST_OPS_RELU_2x32, + &&POST_OPS_RELU_SCALE_2x32 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -2661,6 +2968,28 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2x32) // c[1,16-31] c_int32_1p1 = _mm512_max_epi32( selector1, c_int32_1p1 ); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_2x32: + { + selector1 = _mm512_setzero_epi32(); + selector2 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_0p1) + + // c[1, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_1p1) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_2x32_DISABLE: @@ -2687,7 +3016,8 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1x32) { &&POST_OPS_1x32_DISABLE, 
&&POST_OPS_BIAS_1x32, - &&POST_OPS_RELU_1x32 + &&POST_OPS_RELU_1x32, + &&POST_OPS_RELU_SCALE_1x32 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -2786,6 +3116,22 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1x32) // c[0, 16-31] c_int32_0p1 = _mm512_max_epi32( selector1, c_int32_0p1 ); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_1x32: + { + selector1 = _mm512_setzero_epi32(); + selector2 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_0p1) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_1x32_DISABLE: @@ -2806,7 +3152,8 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5x48) { &&POST_OPS_5x48_DISABLE, &&POST_OPS_BIAS_5x48, - &&POST_OPS_RELU_5x48 + &&POST_OPS_RELU_5x48, + &&POST_OPS_RELU_SCALE_5x48 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -3166,6 +3513,61 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5x48) // c[4,32-47] c_int32_4p2 = _mm512_max_epi32( selector1, c_int32_4p2 ); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_5x48: + { + selector1 = _mm512_setzero_epi32(); + selector2 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_0p1) + + // c[0, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_0p2) + + // c[1, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_1p1) + + // c[1, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_1p2) + + // c[2, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_2p0) + + // c[2, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_2p1) + + // c[2, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_2p2) + + // c[3, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_3p0) + + // c[3, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_3p1) + + // c[3, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_3p2) + + // c[4, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_4p0) + + // c[4, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_4p1) + + // c[4, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_4p2) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_5x48_DISABLE: @@ -3225,7 +3627,8 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4x48) { &&POST_OPS_4x48_DISABLE, &&POST_OPS_BIAS_4x48, - &&POST_OPS_RELU_4x48 + &&POST_OPS_RELU_4x48, + &&POST_OPS_RELU_SCALE_4x48 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -3525,6 +3928,52 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4x48) // c[3,32-47] c_int32_3p2 = _mm512_max_epi32( selector1, c_int32_3p2 ); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_4x48: + { + selector1 = _mm512_setzero_epi32(); + selector2 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_0p1) + + // c[0, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_0p2) + + // c[1, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_1p1) + + // c[1, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_1p2) + + // c[2, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_2p0) + + // c[2, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_2p1) + + // c[2, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_2p2) + + 
// c[3, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_3p0) + + // c[3, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_3p1) + + // c[3, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_3p2) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_4x48_DISABLE: @@ -3575,7 +4024,8 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3x48) { &&POST_OPS_3x48_DISABLE, &&POST_OPS_BIAS_3x48, - &&POST_OPS_RELU_3x48 + &&POST_OPS_RELU_3x48, + &&POST_OPS_RELU_SCALE_3x48 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -3815,6 +4265,43 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3x48) // c[2,32-47] c_int32_2p2 = _mm512_max_epi32( selector1, c_int32_2p2 ); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_3x48: + { + selector1 = _mm512_setzero_epi32(); + selector2 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_0p1) + + // c[0, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_0p2) + + // c[1, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_1p1) + + // c[1, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_1p2) + + // c[2, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_2p0) + + // c[2, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_2p1) + + // c[2, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_2p2) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_3x48_DISABLE: @@ -3856,7 +4343,8 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2x48) { &&POST_OPS_2x48_DISABLE, &&POST_OPS_BIAS_2x48, - &&POST_OPS_RELU_2x48 + &&POST_OPS_RELU_2x48, + &&POST_OPS_RELU_SCALE_2x48 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -4036,6 +4524,34 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2x48) // c[1,32-47] c_int32_1p2 = _mm512_max_epi32( selector1, c_int32_1p2 ); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_2x48: + { + selector1 = _mm512_setzero_epi32(); + selector2 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_0p1) + + // c[0, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_0p2) + + // c[1, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_1p1) + + // c[1, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_1p2) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_2x48_DISABLE: @@ -4068,7 +4584,8 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1x48) { &&POST_OPS_1x48_DISABLE, &&POST_OPS_BIAS_1x48, - &&POST_OPS_RELU_1x48 + &&POST_OPS_RELU_1x48, + &&POST_OPS_RELU_SCALE_1x48 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -4188,6 +4705,25 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1x48) // c[0,32-47] c_int32_0p2 = _mm512_max_epi32( selector1, c_int32_0p2 ); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_1x48: + { + selector1 = _mm512_setzero_epi32(); + selector2 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_0p1) + + // c[0, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_0p2) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_1x48_DISABLE: diff --git 
a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_n_fringe_amd512vnni.c b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_n_fringe_amd512vnni.c index 090a120232..b710f2eb49 100644 --- a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_n_fringe_amd512vnni.c +++ b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_n_fringe_amd512vnni.c @@ -37,6 +37,7 @@ #include "blis.h" #include "lpgemm_kernels.h" +#include "lpgemm_s32_kern_macros.h" // 6xlt16 int8o32 fringe kernel LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6xlt16) @@ -45,7 +46,8 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6xlt16) { &&POST_OPS_6xLT16_DISABLE, &&POST_OPS_BIAS_6xLT16, - &&POST_OPS_RELU_6xLT16 + &&POST_OPS_RELU_6xLT16, + &&POST_OPS_RELU_SCALE_6xLT16 }; dim_t MR = 6; dim_t m_full_pieces = m0 / MR; @@ -330,6 +332,34 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6xlt16) // c[5,0-15] c_int32_5p0 = _mm512_max_epi32( selector1, c_int32_5p0 ); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_6xLT16: + { + selector1 = _mm512_setzero_epi32(); + selector2 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_0p0) + + // c[1, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_1p0) + + // c[2, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_2p0) + + // c[3, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_3p0) + + // c[4, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_4p0) + + // c[5, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_5p0) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_6xLT16_DISABLE: @@ -464,7 +494,8 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x16) { &&POST_OPS_6x16_DISABLE, &&POST_OPS_BIAS_6x16, - &&POST_OPS_RELU_6x16 + &&POST_OPS_RELU_6x16, + &&POST_OPS_RELU_SCALE_6x16 }; dim_t MR = 6; dim_t m_full_pieces = m0 / MR; @@ -734,6 +765,34 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x16) // c[5,0-15] c_int32_5p0 = _mm512_max_epi32( selector1, c_int32_5p0 ); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_6x16: + { + selector1 = _mm512_setzero_epi32(); + selector2 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_0p0) + + // c[1, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_1p0) + + // c[2, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_2p0) + + // c[3, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_3p0) + + // c[4, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_4p0) + + // c[5, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_5p0) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_6x16_DISABLE: @@ -849,7 +908,8 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x32) { &&POST_OPS_6x32_DISABLE, &&POST_OPS_BIAS_6x32, - &&POST_OPS_RELU_6x32 + &&POST_OPS_RELU_6x32, + &&POST_OPS_RELU_SCALE_6x32 }; dim_t MR = 6; dim_t m_full_pieces = m0 / MR; @@ -1215,6 +1275,52 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x32) // c[5,16-31] c_int32_5p1 = _mm512_max_epi32( selector1, c_int32_5p1 ); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_6x32: + { + selector1 = _mm512_setzero_epi32(); + selector2 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_0p1) + + // c[1, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_1p1) + + // c[2, 0-15] + 
RELU_SCALE_OP_S32_AVX512(c_int32_2p0) + + // c[2, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_2p1) + + // c[3, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_3p0) + + // c[3, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_3p1) + + // c[4, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_4p0) + + // c[4, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_4p1) + + // c[5, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_5p0) + + // c[5, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_5p1) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_6x32_DISABLE: @@ -1348,7 +1454,8 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x48) { &&POST_OPS_6x48_DISABLE, &&POST_OPS_BIAS_6x48, - &&POST_OPS_RELU_6x48 + &&POST_OPS_RELU_6x48, + &&POST_OPS_RELU_SCALE_6x48 }; dim_t MR = 6; dim_t m_full_pieces = m0 / MR; @@ -1811,6 +1918,70 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x48) // c[5,32-47] c_int32_5p2 = _mm512_max_epi32( selector1, c_int32_5p2 ); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_6x48: + { + selector1 = _mm512_setzero_epi32(); + selector2 = + _mm512_set1_epi32( *( ( int32_t* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_0p1) + + // c[0, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_0p2) + + // c[1, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_1p1) + + // c[1, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_1p2) + + // c[2, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_2p0) + + // c[2, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_2p1) + + // c[2, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_2p2) + + // c[3, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_3p0) + + // c[3, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_3p1) + + // c[3, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_3p2) + + // c[4, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_4p0) + + // c[4, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_4p1) + + // c[4, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_4p2) + + // c[5, 0-15] + RELU_SCALE_OP_S32_AVX512(c_int32_5p0) + + // c[5, 16-31] + RELU_SCALE_OP_S32_AVX512(c_int32_5p1) + + // c[5, 32-47] + RELU_SCALE_OP_S32_AVX512(c_int32_5p2) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_6x48_DISABLE: diff --git a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_s32_kern_macros.h b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_s32_kern_macros.h new file mode 100644 index 0000000000..bb82d04d34 --- /dev/null +++ b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_s32_kern_macros.h @@ -0,0 +1,45 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. 
+ + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef LPGEMM_S32_KERN_MACROS_H +#define LPGEMM_S32_KERN_MACROS_H + +#define RELU_SCALE_OP_S32_AVX512(reg) \ + /* Generate index of elements <= 0.*/ \ + relu_cmp_mask = _mm512_cmple_epi32_mask( reg, selector1 ); \ + \ + /* Apply scaling only for <= 0 elements.*/ \ + reg = _mm512_mask_mullo_epi32( reg, relu_cmp_mask, reg, selector2 ); \ + +#endif // LPGEMM_S32_KERN_MACROS_H diff --git a/bench/bench_aocl_gemm/Makefile b/bench/bench_aocl_gemm/Makefile index a760d84688..91b3a7b587 100755 --- a/bench/bench_aocl_gemm/Makefile +++ b/bench/bench_aocl_gemm/Makefile @@ -87,7 +87,7 @@ TEST_OBJS := $(patsubst $(TEST_SRC_PATH)/%.c, \ CINCFLAGS := -I$(INC_PATH) -I$(CBLAS_HEADER_PATH) # Use the CFLAGS for the configuration family. -CFLAGS := $(call get-user-cflags-for,$(CONFIG_NAME)) -march=icelake-server -mavx512bf16 +CFLAGS := $(call get-user-cflags-for,$(CONFIG_NAME)) # Add local header paths to CFLAGS CFLAGS += -I$(TEST_SRC_PATH) diff --git a/bench/bench_aocl_gemm/bench_lpgemm.c b/bench/bench_aocl_gemm/bench_lpgemm.c new file mode 100644 index 0000000000..67a9277a5a --- /dev/null +++ b/bench/bench_aocl_gemm/bench_lpgemm.c @@ -0,0 +1,761 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "blis.h" + +// Mode can be one of the following: +// 1. p - performance, used for benchmarks. +// 2. a - accuracy, used to test accuracy/correctness. +// Default value is p, can be modified by passing command line arg.
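Both RELU_SCALE macros above (the AVX2 s16 variant and the AVX512 s32 variant) implement the parametric-ReLU post-op that this bench exercises through the "prelu" option: elements greater than zero pass through unchanged, and the remaining elements are multiplied by the alpha scale broadcast into selector2 (an element equal to zero stays zero either way). A minimal scalar sketch of that per-element semantics, using the alpha value of 6 that the bench assigns for prelu; the function name relu_scale_ref_s32 is illustrative only, not part of the patch:

#include <stdint.h>
#include <stdio.h>

/* Per-element reference for the RELU_SCALE (prelu) post-op: positive values
   pass through, non-positive values are scaled by alpha. */
static inline int32_t relu_scale_ref_s32( int32_t x, int32_t alpha )
{
    return ( x > 0 ) ? x : ( x * alpha );
}

int main( void )
{
    int32_t alpha = 6; /* same value the bench assigns for the prelu post-op */
    int32_t in[4] = { -8, -1, 0, 7 };
    for ( int i = 0; i < 4; ++i )
    {
        printf( "%d -> %d\n", in[i], relu_scale_ref_s32( in[i], alpha ) );
    }
    return 0; /* prints -48, -6, 0, 7 */
}
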
+char bench_mode = 'p'; + +int32_t global_n_repeat = 0; + +#define _XSTR(str) #str +#define XSTR(str) _XSTR(str) + +#define GEN_FUNC_NAME(prototype,ctype) prototype ## ctype + +#define GEN_FILL_ARRAY_FUNC(ctype) \ +void fill_array_ ## ctype ( void* arr, dim_t size ) \ +{ \ + ctype* temp_arr = ( ctype* ) arr; \ + for ( dim_t i = 0; i < size; ++i ) \ + { \ + temp_arr[i] = ( ctype )( i % 100 ); \ + } \ +} \ + +GEN_FILL_ARRAY_FUNC(uint8_t) +GEN_FILL_ARRAY_FUNC(int8_t) +GEN_FILL_ARRAY_FUNC(float) +GEN_FILL_ARRAY_FUNC(int32_t) + +#define GEN_FILL_ARRAY_POST_OPS_FUNC(ctype) \ +void fill_array_post_ops_ ## ctype ( void* arr, dim_t size ) \ +{ \ + ctype* temp_arr = ( ctype* ) arr; \ + for ( dim_t i = 0; i < size; ++i ) \ + { \ + temp_arr[i] = ( ctype )( i % 20 ); \ + } \ +} \ + +GEN_FILL_ARRAY_POST_OPS_FUNC(int16_t) +GEN_FILL_ARRAY_POST_OPS_FUNC(int32_t) +GEN_FILL_ARRAY_POST_OPS_FUNC(float) + +#define GEN_BLIS_MAT_MUL_FUNC(A_type,B_type,C_type,BLAS_SFX) \ +void mat_mul_ ## BLAS_SFX \ + ( \ + char op_t, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + C_type alpha, \ + A_type* a, \ + dim_t lda, \ + B_type* b, \ + dim_t ldb, \ + C_type beta, \ + C_type* c, \ + dim_t ldc, \ + aocl_post_op* post_op\ + ) \ +{ \ + char storage = 'r'; \ + char transa = 'n'; \ + char transb = 'n'; \ + char reordera = 'n'; \ + char reorderb = 'n'; \ + \ + if ( ( op_t == 'p' ) || ( op_t == 'P' ) ) \ + { \ + /* No reordering of B.*/ \ + reordera = 'n'; \ + reorderb = 'n'; \ + } \ + else if ( ( op_t == 'r' ) || ( op_t == 'R' ) ) \ + { \ + /* Reordered B.*/ \ + reordera = 'n'; \ + reorderb = 'r'; \ + } \ + \ + aocl_gemm_ ## BLAS_SFX( storage, transa, transb, m, n, k, \ + alpha, \ + a, lda, reordera, \ + b, ldb, reorderb, \ + beta, \ + c, ldc, post_op ); \ + \ + /*dim_t MR = 6; \ + dim_t NR = 16; \ + \ + __m512i selector1; \ + __m512i all_zero = _mm512_setzero_epi32(); \ + __m512i c0; \ + __m512i c1; \ + __m512i c2; \ + __m512i c3; \ + __m512i c4; \ + __m512i c5; \ + \ + for ( dim_t i = 0; i < m; i += MR ) \ + { \ + if ( ( i + MR ) > m ) \ + { \ + break; \ + } \ + for ( dim_t j = 0; j < n; j += NR ) \ + { \ + if ( ( j + NR ) > n ) \ + { \ + break; \ + } \ + selector1 = _mm512_loadu_epi32( (int32_t*)post_op->bias.bias + j ); \ + c0 = _mm512_loadu_epi32( c + ( ( i + 0 ) * ldc ) + j ); \ + c1 = _mm512_loadu_epi32( c + ( ( i + 1 ) * ldc ) + j ); \ + c2 = _mm512_loadu_epi32( c + ( ( i + 2 ) * ldc ) + j ); \ + c3 = _mm512_loadu_epi32( c + ( ( i + 3 ) * ldc ) + j ); \ + c4 = _mm512_loadu_epi32( c + ( ( i + 4 ) * ldc ) + j ); \ + c5 = _mm512_loadu_epi32( c + ( ( i + 5 ) * ldc ) + j ); \ + \ + c0 = _mm512_add_epi32( selector1, c0 ); \ + c1 = _mm512_add_epi32( selector1, c1 ); \ + c2 = _mm512_add_epi32( selector1, c2 ); \ + c3 = _mm512_add_epi32( selector1, c3 ); \ + c4 = _mm512_add_epi32( selector1, c4 ); \ + c5 = _mm512_add_epi32( selector1, c5 ); \ + \ + c0 = _mm512_max_epi32( all_zero, c0 ); \ + c1 = _mm512_max_epi32( all_zero, c1 ); \ + c2 = _mm512_max_epi32( all_zero, c2 ); \ + c3 = _mm512_max_epi32( all_zero, c3 ); \ + c4 = _mm512_max_epi32( all_zero, c4 ); \ + c5 = _mm512_max_epi32( all_zero, c5 ); \ + \ + _mm512_storeu_epi32( c + ( ( i + 0 ) * ldc ) + j, c0 ); \ + _mm512_storeu_epi32( c + ( ( i + 1 ) * ldc ) + j, c1 ); \ + _mm512_storeu_epi32( c + ( ( i + 2 ) * ldc ) + j, c2 ); \ + _mm512_storeu_epi32( c + ( ( i + 3 ) * ldc ) + j, c3 ); \ + _mm512_storeu_epi32( c + ( ( i + 4 ) * ldc ) + j, c4 ); \ + _mm512_storeu_epi32( c + ( ( i + 5 ) * ldc ) + j, c5 ); \ + } \ + } */\ +} \ + +GEN_BLIS_MAT_MUL_FUNC(uint8_t, int8_t, int16_t, 
u8s8s16os16) +GEN_BLIS_MAT_MUL_FUNC(uint8_t,int8_t,int32_t,u8s8s32os32) +GEN_BLIS_MAT_MUL_FUNC(float,float,float,f32f32f32of32) + +double get_gflops + ( + dim_t m, + dim_t n, + dim_t k, + double runtime + ) +{ + return ( ( 2.0 * m * n * k ) / ( runtime * 1.0e9 ) ); +} + +void print_result + ( + const char* msg, + int32_t n_repeats, + dim_t m, + dim_t n, + dim_t k, + dim_t lda, + dim_t ldb, + dim_t ldc, + double runtime + ) +{ + double gflops = get_gflops( m, n, k, runtime ); + printf("%s m: %ld, n: %ld, k: %ld, lda: %ld, ldb: %ld, ldc: %ld," \ + " Gops: %f, n_repeats: %d\n", + msg, m, n, k, lda, ldb, ldc, gflops, n_repeats); +} + +#define GEN_MAT_MUL_BENCH_DRV_FUNC(A_type,B_type,C_type,BLAS_SFX) \ +void mat_mul_bench_driver_ ## BLAS_SFX \ + ( \ + char op_t, \ + int32_t n_repeats, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + C_type alpha, \ + A_type* a, \ + dim_t lda, \ + B_type* b, \ + dim_t ldb, \ + C_type beta, \ + C_type* c, \ + dim_t ldc, \ + aocl_post_op* post_op\ + ) \ +{ \ + double min_time_diff = DBL_MAX; \ + for ( int32_t nr = 0; nr < n_repeats; ++nr ) \ + { \ + if ( bench_mode == 'a' ) \ + { \ + memset( ( void* ) c, 0, sizeof( C_type ) * m * n ); \ + } \ + \ + struct timespec tstart={0,0}, tend={0,0}; \ + clock_gettime(CLOCK_MONOTONIC, &tstart); \ + \ + GEN_FUNC_NAME(mat_mul_,BLAS_SFX) \ + ( \ + op_t, m, n, k, \ + alpha, \ + a, lda, \ + b, ldb, \ + beta, \ + c, ldc, \ + post_op \ + ); \ + \ + clock_gettime(CLOCK_MONOTONIC, &tend); \ + \ + double diff = \ + ( ( double ) tend.tv_sec + ( 1.0e-9 * tend.tv_nsec ) ) - \ + ( ( double ) tstart.tv_sec + ( 1.0e-9 * tstart.tv_nsec ) ); \ + min_time_diff = ( diff < min_time_diff ) ? diff : min_time_diff; \ + } \ + \ + print_result( XSTR(BLAS_SFX), n_repeats, m, n, k, lda, ldb, ldc, min_time_diff); \ +} \ + +GEN_MAT_MUL_BENCH_DRV_FUNC(uint8_t, int8_t, int16_t, u8s8s16os16) +GEN_MAT_MUL_BENCH_DRV_FUNC(uint8_t,int8_t,int32_t,u8s8s32os32) +GEN_MAT_MUL_BENCH_DRV_FUNC(float,float,float,f32f32f32of32) + +#define GEN_MAT_MUL_ACC_CHK_DRV_FUNC(A_type,B_type,C_type,BLAS_SFX) \ +void mat_mul_accuracy_check_driver_ ## BLAS_SFX \ + ( \ + FILE* fout, \ + dim_t m, \ + dim_t n, \ + dim_t k, \ + C_type alpha, \ + A_type* a, \ + dim_t lda, \ + B_type* b, \ + dim_t ldb, \ + C_type beta, \ + C_type* c, \ + dim_t ldc, \ + C_type* c_ref, \ + dim_t ldc_ref, \ + aocl_post_op* post_op\ + ) \ +{ \ + for ( dim_t i = 0; i < m; ++i ) \ + { \ + for ( dim_t j = 0; j < n; ++j ) \ + { \ + C_type temp_accum = 0; \ + \ + for ( dim_t p = 0; p < k; ++p) \ + { \ + temp_accum += ( *( a + ( i * lda ) + p ) * *( b + ( p * ldb ) + j ) ); \ + } \ + \ + temp_accum = ( beta * ( * (c_ref + ( ldc_ref * i ) + j ) ) ) \ + + ( alpha * temp_accum ); \ + \ + if ( post_op != NULL ) \ + { \ + /* Apply bias followed by relu. */ \ + if ( post_op->seq_vector[0] == BIAS ) \ + { \ + if ( post_op->seq_length >= 1 ) \ + { \ + temp_accum += ( *( ( C_type* )post_op->bias.bias + j ) ); \ + } \ + if ( post_op->seq_length > 1 ) \ + { \ + if ( post_op->eltwise.algo.alpha != NULL ) /* PReLU*/ \ + { \ + temp_accum = ( temp_accum > 0 ) ? \ + temp_accum : \ + ( temp_accum * *( ( C_type* ) post_op->eltwise.algo.alpha ) ); \ + } \ + else \ + { \ + temp_accum = ( temp_accum > 0 ) ? temp_accum : 0 ; \ + } \ + } \ + } \ + else if ( post_op->seq_vector[0] == ELTWISE ) \ + { \ + if ( post_op->seq_length >= 1 ) \ + { \ + if ( post_op->eltwise.algo.alpha != NULL ) /* PReLU*/ \ + { \ + temp_accum = ( temp_accum > 0 ) ? 
\ + temp_accum : \ + ( temp_accum * *( ( C_type* ) post_op->eltwise.algo.alpha ) ); \ + } \ + else \ + { \ + temp_accum = ( temp_accum > 0 ) ? temp_accum : 0 ; \ + } \ + } \ + if ( post_op->seq_length > 1 ) \ + { \ + temp_accum += ( *( ( C_type* )post_op->bias.bias + j ) ); \ + } \ + } \ + } \ + \ + if ( *( c + ( ldc * i ) + j ) != temp_accum ) \ + { \ + if ( fout ) \ + { \ + fprintf( fout, "%s Failure input m: %ld, n: %ld, k: %ld," \ + " lda: %ld, ldb: %ld, ldc: %ld\n", \ + XSTR(BLAS_SFX), m, n, k, lda, ldb, ldc ); \ + fflush( fout ); \ + } \ + printf("failure, m: %ld, n: %ld, k: %ld\n", i, j, k); \ + goto cleanup_acc; \ + } \ + } \ + } \ +cleanup_acc: \ + return; \ +} \ + +GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t, int8_t, int16_t, u8s8s16os16) +GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,int32_t,u8s8s32os32) +GEN_MAT_MUL_ACC_CHK_DRV_FUNC(float,float,float,f32f32f32of32) + +/* Only supports bias followed by RELU and vice versa for now.*/ \ +#define GEN_MAT_MUL_POST_OPS_CREATOR(C_type,BLAS_SFX) \ +aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \ + ( \ + dim_t m, \ + dim_t n, \ + char* post_ops_str \ + ) \ +{ \ + aocl_post_op* post_ops = NULL; \ + post_ops = ( aocl_post_op* ) malloc( sizeof( aocl_post_op ) ); \ + \ + if ( post_ops == NULL ) \ + { \ + return NULL; \ + } \ + \ + /* Only supporting 2 post ops at max for now.*/ \ + dim_t max_post_ops_seq_length = 2; \ + post_ops->seq_vector = ( AOCL_POST_OP_TYPE* ) \ + malloc \ + ( \ + max_post_ops_seq_length * \ + sizeof( AOCL_POST_OP_TYPE ) \ + ); \ + \ + if ( post_ops->seq_vector == NULL ) \ + { \ + free( post_ops ); \ + return NULL; \ + } \ + \ + /* Parse post ops list.*/ \ + char* ops_tok = strtok(post_ops_str, ", " ); \ + bool is_param_relu = FALSE; \ + dim_t cur_op_index = 0; \ + while ( ops_tok ) \ + { \ + if ( strcmp( ops_tok, "bias") == 0 ) \ + { \ + post_ops->seq_vector[cur_op_index] = BIAS; \ + } \ + else if ( strcmp( ops_tok, "relu") == 0 ) \ + { \ + post_ops->seq_vector[cur_op_index] = ELTWISE; \ + } \ + else if ( strcmp( ops_tok, "prelu") == 0 ) \ + { \ + post_ops->seq_vector[cur_op_index] = ELTWISE; \ + is_param_relu = TRUE; \ + } \ + ops_tok = strtok( NULL, ", " ); \ + cur_op_index++; \ + } \ + post_ops->seq_length = cur_op_index; \ + \ + /* Allocate bias buffer, return early if alloc fails.*/ \ + post_ops->bias.bias = malloc( n * sizeof( C_type ) ); \ + if ( post_ops->bias.bias == NULL ) \ + { \ + free( post_ops->seq_vector ); \ + free( post_ops ); \ + return NULL; \ + } \ + \ + GEN_FUNC_NAME(fill_array_post_ops_,C_type)( post_ops->bias.bias, n ); \ + \ + post_ops->eltwise.is_power_of_2 = FALSE; \ + post_ops->eltwise.scale_factor = NULL; \ + post_ops->eltwise.algo.alpha = NULL; \ + post_ops->eltwise.algo.algo_type = RELU; \ + if ( is_param_relu == TRUE ) \ + { \ + post_ops->eltwise.algo.alpha = malloc( sizeof( C_type ) ); \ + *( ( C_type* ) post_ops->eltwise.algo.alpha ) = ( C_type )6; \ + post_ops->eltwise.algo.algo_type = PRELU; \ + } \ + post_ops->eltwise.algo.beta = NULL; \ + \ + return post_ops; \ +} \ + +GEN_MAT_MUL_POST_OPS_CREATOR(int16_t,u8s8s16os16) +GEN_MAT_MUL_POST_OPS_CREATOR(int32_t,u8s8s32os32) +GEN_MAT_MUL_POST_OPS_CREATOR(float,f32f32f32of32) + +void lpgemm_destroy_post_ops_struct( aocl_post_op* post_ops ) +{ + if ( post_ops == NULL ) + { + return; + } + + if ( post_ops->eltwise.algo.alpha != NULL ) + { + free( post_ops->eltwise.algo.alpha ); + } + if ( post_ops->bias.bias != NULL ) + { + free( post_ops->bias.bias ); + } + + if( post_ops->seq_vector != NULL ) + { + free( post_ops->seq_vector ); + } + + 
free( post_ops ); +} + +#define GEN_MAT_MUL_BENCH_MAIN_FUNC(A_type,B_type,C_type,BLAS_SFX) \ +void mat_mul_bench_main_ ## BLAS_SFX \ + ( \ + FILE* fin, \ + FILE* fout, \ + char op_t, \ + int32_t m, \ + int32_t n, \ + int32_t k, \ + int32_t stride_a, \ + int32_t stride_b, \ + int32_t stride_c, \ + char* post_ops_str \ + ) \ +{ \ + if ( ( op_t != 'p' ) && ( op_t != 'P' ) && ( op_t != 'r' ) && ( op_t != 'R' ) ) \ + { \ + printf("The op_t ( 2nd arg in input.txt) is not valid\n"); \ + return; \ + } \ + \ + int32_t n_repeats = bli_max( 30, bli_min(( 3e10 / ( ( int64_t )m * n * k )), 100 )); \ + if ( global_n_repeat > 0 ) \ + { \ + n_repeats = global_n_repeat; \ + } \ + \ + /* Get 64 byte aligned memory.*/ \ + A_type* a = ( A_type* ) bli_malloc_user( sizeof( A_type ) * m * k ); \ + \ + B_type* b = ( B_type* ) bli_malloc_user( sizeof( B_type ) * n * k ); \ + \ + C_type* c = ( C_type* ) bli_malloc_user( sizeof( C_type ) * m * n ); \ + memset( ( void* ) c, 0, sizeof( C_type ) * m * n ); \ + \ + C_type* c_ref = ( C_type* ) bli_malloc_user( sizeof( C_type ) * m * n ); \ + memset( ( void* ) c_ref, 0, sizeof( C_type ) * m * n ); \ + \ + C_type alpha; \ + C_type beta; \ + if ( bench_mode == 'p' ) \ + { \ + alpha = 1; \ + beta = 0; \ + } \ + else if ( bench_mode == 'a' ) \ + { \ + alpha = 2; \ + beta = 9; \ + } \ + \ + GEN_FUNC_NAME(fill_array_,A_type)( a, ( m * k ) ); \ + GEN_FUNC_NAME(fill_array_,B_type)( b, ( k * n ) ); \ + \ + aocl_post_op* post_op = NULL; \ + if ( post_ops_str != NULL ) \ + { \ + post_op = GEN_FUNC_NAME(lpgemm_create_post_ops_struct_,BLAS_SFX)( m, n, post_ops_str ); \ + if ( post_op == NULL ) \ + { \ + printf(" post op struct allocation failure, returning.\n"); \ + return; \ + } \ + } \ + \ + if ( ( op_t == 'p' ) || ( op_t == 'P' ) ) \ + { \ + /* No reordering of B.*/ \ + GEN_FUNC_NAME(mat_mul_bench_driver_,BLAS_SFX) \ + ( \ + op_t, n_repeats, m, n, k, \ + alpha, \ + a, stride_a, \ + b, stride_b, \ + beta, \ + c, stride_c, \ + post_op \ + ); \ + } \ + else if ( ( op_t == 'r' ) || ( op_t == 'R' ) ) \ + { \ + /* Reorder B.*/ \ + siz_t b_reorder_buf_siz_req = \ + GEN_FUNC_NAME(aocl_get_reorder_buf_size_,BLAS_SFX)( 'B', k, n ); \ + \ + B_type* b_reorder = ( B_type* ) bli_malloc_user( b_reorder_buf_siz_req ); \ + GEN_FUNC_NAME(aocl_reorder_,BLAS_SFX)( 'B', b, b_reorder, k, n, stride_b ); \ + \ + GEN_FUNC_NAME(mat_mul_bench_driver_,BLAS_SFX) \ + ( \ + op_t, n_repeats, m, n, k, \ + alpha, \ + a, stride_a, \ + b_reorder, stride_b, \ + beta, \ + c, stride_c, \ + post_op \ + ); \ + \ + bli_free_user( b_reorder ); \ + } \ + \ + if ( bench_mode == 'a' ) \ + { \ + printf(" Running accuracy check.\n"); \ + GEN_FUNC_NAME(mat_mul_accuracy_check_driver_,BLAS_SFX) \ + ( \ + fout, m, n, k, \ + alpha, \ + a, stride_a, \ + b, stride_b, \ + beta, \ + c, stride_c, \ + c_ref, stride_c, \ + post_op \ + ); \ + } \ + \ + lpgemm_destroy_post_ops_struct( post_op ); \ + \ + if ( a != NULL ) \ + { \ + bli_free_user( a ); \ + } \ + if ( b != NULL ) \ + { \ + bli_free_user( b ); \ + } \ + if ( c != NULL ) \ + { \ + bli_free_user( c ); \ + } \ + if ( c_ref != NULL ) \ + { \ + bli_free_user( c_ref ); \ + } \ +} \ + +GEN_MAT_MUL_BENCH_MAIN_FUNC(uint8_t, int8_t, int16_t, u8s8s16os16) +GEN_MAT_MUL_BENCH_MAIN_FUNC(uint8_t,int8_t,int32_t,u8s8s32os32) +GEN_MAT_MUL_BENCH_MAIN_FUNC(float,float,float,f32f32f32of32) + +int main( int argc, char** argv ) +{ + FILE* fin = NULL; + if ( argc < 5 ) + { + printf( "Usage: ./mat_mul -i input.txt -m mode < -n 1000 -o op1,op2.. >" \ + "\nMode is either a or p. 
a is used for accuracy test, " \ + "whereas p is used for performance benchmarking." \ + "\nn_repeats can be set optionally using -n arg." \ + "\nPost ops can be executed optionaly by providing a " \ + "coma separated list of ops after -o arg.\nCurrently " \ + "bias and relu/prelu is supported and can be specified " \ + "as a single post op or combination of the same. eg: -o bias,relu ; -o prelu.\n" ); + exit( 1 ); + } + + char* file_name = NULL; + char* post_ops_str = NULL; + char* post_ops_str_dest = NULL; //Strtok is used to parse, need to maintain a copy. + + // Parse CLI arguments. + opterr = 0; + int opt_val; + while ( ( opt_val = getopt( argc, argv, "i:m:n:o:" ) ) != -1 ) + { + switch ( opt_val ) + { + case 'i': + file_name = optarg; + break; + case 'm': + bench_mode = ( ( ( *optarg ) == 'a' ) || ( ( *optarg ) == 'p' ) ) ? ( *optarg ) : 'p'; + break; + case 'n': + global_n_repeat = ( atoi( optarg ) > 0 ) ? atoi( optarg ) : 0; + break; + case 'o': + post_ops_str = optarg; + break; + default: + break; + } + } + + if ( post_ops_str != NULL ) + { + post_ops_str_dest = strdup( post_ops_str ); + } + + if ( bench_mode == 'p' ) + { + printf( "Running bench in performance benchmarking mode.\n" ); + } + else if ( bench_mode == 'a' ) + { + printf( "Running bench in accuracy/correctness testing mode.\n" ); + } + + if ( file_name == NULL ) + { + printf( " File name provided is invalid.\n" ); + exit( 1 ); + } + + fin = fopen( file_name, "r" ); + if (fin == NULL) + { + printf( "Error opening the file %s\n", argv[1] ); + exit( 1 ); + } + + FILE* fout = NULL; + + fout = fopen( "lpgemm_accuracy_test_failures.txt", "w" ); + + char op_type_char; + char op_t; + int32_t m, n, k; + int32_t stride_a, stride_b, stride_c; + + const dim_t len_list_omp_cores_for_testing = 2; + const dim_t list_omp_cores_for_testing[2] = { 80, 1 }; + + dim_t core_index = 0; + bool can_run = TRUE; + while ( ( can_run == TRUE ) && ( fseek( fin, 0L, SEEK_SET ) == 0 ) ) + { + if ( bench_mode == 'p' ) + { + can_run = FALSE; + } + else if ( bench_mode == 'a' ) + { + // For accuracy testing, we test accuracy using multiple different + // number of cores. This helps uncover any bugs related to over + // subscription or varying thread factorizations. + // Set current number of cores. 
+#ifdef BLIS_ENABLE_OPENMP + omp_set_num_threads( list_omp_cores_for_testing[core_index] ); +#endif + printf( "Accuracy test using %ld threads.\n", + list_omp_cores_for_testing[core_index] ); + + core_index++; + if ( core_index < len_list_omp_cores_for_testing ) + { + can_run = TRUE; + } + else + { + can_run = FALSE; + } + } + + while ( fscanf( fin, "%c %c %d %d %d %d %d %d\n", + &op_type_char, &op_t, &m, &n, &k, + &stride_a, &stride_b, &stride_c ) == 8 ) + { + if ( ( op_type_char == 'i' ) || ( op_type_char == 'I' ) ) + { + GEN_FUNC_NAME(mat_mul_bench_main_,u8s8s32os32) + ( + fin, fout, op_t, + m, n, k, stride_a, stride_b, stride_c, + post_ops_str_dest + ); + } + else if ( ( op_type_char == 'f' ) || ( op_type_char == 'F' ) ) + { + GEN_FUNC_NAME(mat_mul_bench_main_,f32f32f32of32) + ( + fin, fout, op_t, + m, n, k, stride_a, stride_b, stride_c, + NULL + ); + } + else if ((op_type_char == 's') || (op_type_char == 'S')) + { + GEN_FUNC_NAME(mat_mul_bench_main_, u8s8s16os16) + ( + fin, fout, op_t, + m, n, k, stride_a, stride_b, stride_c, + post_ops_str_dest + ); + } + if ( post_ops_str != NULL ) + { + strcpy( post_ops_str_dest, post_ops_str ); + } + } + } + + if ( post_ops_str_dest != NULL ) + { + free( post_ops_str_dest ); + } + if ( fin ) + { + fclose( fin ); + } + if ( fout ) + { + fclose( fout ); + } + return 0; +} From 5ca632e0f007a2fa978d5f945af3e5223c43685a Mon Sep 17 00:00:00 2001 From: Harihara Sudhan S Date: Tue, 23 Aug 2022 13:33:28 +0530 Subject: [PATCH 192/243] Added API to check for BF16 ISA support - Checking for AVX512 bfloat 16 instructions support in architecture using the CPUID AMD-Internal: [CPUPL-2446] Change-Id: I088a8aa46b037af837b2e58a96b59eae70c1dbf0 --- frame/base/bli_cpuid.c | 63 ++++++++++++++++++++++++++++++++++++++++++ frame/base/bli_cpuid.h | 36 ++++++++++++------------ 2 files changed, 81 insertions(+), 18 deletions(-) diff --git a/frame/base/bli_cpuid.c b/frame/base/bli_cpuid.c index 552ab6e7aa..4dba53080a 100644 --- a/frame/base/bli_cpuid.c +++ b/frame/base/bli_cpuid.c @@ -597,7 +597,43 @@ void bli_cpuid_check_avx512vnni_support( void ) } } +// The support for AVX512_BF16 is checked only once (when this API is called +// first time). On subsequent calls the cached value is returned. +static bool is_avx512bf16_supported = FALSE; + +// Determine if the CPU has support for AVX512_BF16. +void bli_cpuid_check_avx512_bf16_support( void ) +{ + uint32_t family, model, features; + + // Call the CPUID instruction and parse its results into a family id, + // model id, and a feature bit field. + bli_cpuid_query( &family, &model, &features ); + + // Check for expected CPU features. 
+ const uint32_t expected = FEATURE_AVX | + FEATURE_FMA3 | + FEATURE_AVX2 | + FEATURE_AVX512F | + FEATURE_AVX512DQ | + FEATURE_AVX512BW | + FEATURE_AVX512VL | + FEATURE_AVX512VNNI | + FEATURE_AVX512BF16 + ; + + if ( !bli_cpuid_has_features( features, expected ) ) + { + is_avx512bf16_supported = FALSE; + } + else + { + is_avx512bf16_supported = TRUE; + } +} + static bli_pthread_once_t once_check_avx512vnni_support = BLIS_PTHREAD_ONCE_INIT; +static bli_pthread_once_t once_check_avx512_bf16_support = BLIS_PTHREAD_ONCE_INIT; // Ensure that actual support determination happens only once void bli_cpuid_check_avx512vnni_support_once( void ) @@ -607,6 +643,14 @@ void bli_cpuid_check_avx512vnni_support_once( void ) #endif } +// Ensure that actual support determination happens only once to avoid performance hit +void bli_cpuid_check_avx512_bf16_support_once( void ) +{ +#ifndef BLIS_CONFIGURETIME_CPUID + bli_pthread_once( &once_check_avx512_bf16_support, bli_cpuid_check_avx512_bf16_support ); +#endif +} + // API to check if AVX512_VNNI is supported or not on the current platform. bool bli_cpuid_is_avx512vnni_supported( void ) { @@ -615,6 +659,14 @@ bool bli_cpuid_is_avx512vnni_supported( void ) return is_avx512vnni_supported; } +// API to check if AVX512_bf16 is supported or not on the current platform. +bool bli_cpuid_is_avx512_bf16_supported( void ) +{ + bli_cpuid_check_avx512_bf16_support_once(); + + return is_avx512bf16_supported; +} + #elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) arch_t bli_cpuid_query_id( void ) @@ -816,6 +868,7 @@ enum FEATURE_MASK_AVX512BW = (1u<<30), // cpuid[eax=7,ecx=0] :ebx[30] FEATURE_MASK_AVX512VL = (1u<<31), // cpuid[eax=7,ecx=0] :ebx[31] FEATURE_MASK_AVX512VNNI = (1u<<11), // cpuid[eax=7,ecx=0] :ecx[11] + FEATURE_MASK_AVX512BF16 = (1u<< 5), // cpuid[eax=7,ecx=1] :eax[5] FEATURE_MASK_XGETBV = (1u<<26)| (1u<<27), // cpuid[eax=1] :ecx[27:26] XGETBV_MASK_XMM = 0x02u, // xcr0[1] @@ -884,6 +937,16 @@ uint32_t bli_cpuid_query if ( bli_cpuid_has_features( ebx, FEATURE_MASK_AVX512VL ) ) *features |= FEATURE_AVX512VL; if ( bli_cpuid_has_features( ecx, FEATURE_MASK_AVX512VNNI ) ) *features |= FEATURE_AVX512VNNI; + + // This is actually a macro that modifies the last four operands, + // hence why they are not passed by address. + // This returns extended feature flags in EAX. + // The availability of AVX512_BF16 can be found using the + // 5th feature bit of the returned value + __cpuid_count( 7, 1, eax, ebx, ecx, edx ); + + if ( bli_cpuid_has_features( eax, FEATURE_MASK_AVX512BF16 ) ) *features |= FEATURE_AVX512BF16; + } // Check extended processor info / features bits for AMD-specific features. 
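The new detection path amounts to one extra CPUID query: leaf 7, sub-leaf 1 reports AVX512_BF16 in bit 5 of EAX, and the result is cached behind bli_pthread_once so the query runs only once per process. A minimal standalone sketch of that query, assuming a GCC/Clang toolchain that provides <cpuid.h>; in the library the flag is additionally gated on the other AVX-512 feature bits and the XGETBV state checks handled above:

#include <cpuid.h>
#include <stdio.h>

int main( void )
{
    unsigned int eax = 0, ebx = 0, ecx = 0, edx = 0;

    /* __get_cpuid_count() returns 0 when the requested leaf is unsupported. */
    if ( __get_cpuid_count( 7, 1, &eax, &ebx, &ecx, &edx ) == 0 )
    {
        printf( "CPUID leaf 7 is not available on this CPU.\n" );
        return 0;
    }

    /* cpuid[eax=7,ecx=1] :eax[5] is the AVX512_BF16 feature bit. */
    int has_avx512_bf16 = ( int )( ( eax >> 5 ) & 1u );
    printf( "AVX512_BF16 %s\n", has_avx512_bf16 ? "supported" : "not supported" );

    return 0;
}
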
diff --git a/frame/base/bli_cpuid.h b/frame/base/bli_cpuid.h index 439cef3e41..805f31bf2e 100644 --- a/frame/base/bli_cpuid.h +++ b/frame/base/bli_cpuid.h @@ -135,6 +135,7 @@ void get_cpu_name( char *cpu_name ); int vpu_count( void ); bool bli_cpuid_is_avx_supported(void); bool bli_cpuid_is_avx512vnni_supported(void); +bool bli_cpuid_is_avx512_bf16_supported(void); enum { @@ -144,26 +145,25 @@ enum }; enum { - FEATURE_SSE3 = 0x0001, - FEATURE_SSSE3 = 0x0002, - FEATURE_SSE41 = 0x0004, - FEATURE_SSE42 = 0x0008, - FEATURE_AVX = 0x0010, - FEATURE_AVX2 = 0x0020, - FEATURE_FMA3 = 0x0040, - FEATURE_FMA4 = 0x0080, - FEATURE_AVX512F = 0x0100, - FEATURE_AVX512DQ = 0x0200, - FEATURE_AVX512PF = 0x0400, - FEATURE_AVX512ER = 0x0800, - FEATURE_AVX512CD = 0x1000, - FEATURE_AVX512BW = 0x2000, - FEATURE_AVX512VL = 0x4000, - FEATURE_AVX512VNNI = 0x8000 + FEATURE_SSE3 = 0x0001, + FEATURE_SSSE3 = 0x0002, + FEATURE_SSE41 = 0x0004, + FEATURE_SSE42 = 0x0008, + FEATURE_AVX = 0x0010, + FEATURE_AVX2 = 0x0020, + FEATURE_FMA3 = 0x0040, + FEATURE_FMA4 = 0x0080, + FEATURE_AVX512F = 0x0100, + FEATURE_AVX512DQ = 0x0200, + FEATURE_AVX512PF = 0x0400, + FEATURE_AVX512ER = 0x0800, + FEATURE_AVX512CD = 0x1000, + FEATURE_AVX512BW = 0x2000, + FEATURE_AVX512VL = 0x4000, + FEATURE_AVX512VNNI = 0x8000, + FEATURE_AVX512BF16 = 0x10000 }; - - #elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) char* find_string_in( char* target, char* buffer, size_t buf_len, char* filepath ); From 2114a43df805138e96a5605f1ed60914486bfb41 Mon Sep 17 00:00:00 2001 From: satish kumar nuggu Date: Tue, 23 Aug 2022 16:28:52 +0530 Subject: [PATCH 193/243] Fixes to avoid Out of Bound Memory Access in TRSM small algorithm Details: 1. Fixed the issues corresponding to out-of-bound memory access during loads and stores. 2. In the intrinsic code: i. AVX2 registers can hold 4 double elements. ii. In the remainder case the number of elements is less than the vector register width. Although fewer than 4 elements are required, the code was reading and writing in chunks of 4 elements due to vectorization, which can cause out-of-bound memory access. 3. Redesigned the code to prevent out-of-bound access by loading and storing only the exact number of elements required (see the sketch after this list).
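A minimal sketch of the exact-count load/store idea in point 3 follows, using AVX intrinsics from immintrin.h. The helper names load_exact_m and store_exact_m are hypothetical; they only mirror the pattern the reworked macros below adopt (a 128-bit load plus a scalar broadcast for the 3-element case, and matching 128-bit and scalar stores) and are not the BLIS macros themselves.

/* Sketch only (hypothetical helpers, AVX assumed): touch exactly m <= 3
 * doubles of b, never the bytes beyond them. */
#include <immintrin.h>

static inline __m256d load_exact_m( const double *b, int m )
{
    __m256d y = _mm256_setzero_pd();
    if ( m == 3 )
    {
        /* b[0],b[1] into the low 128 bits, b[2] broadcast into the rest */
        __m128d lo = _mm_loadu_pd( b );
        y = _mm256_broadcast_sd( b + 2 );
        y = _mm256_insertf128_pd( y, lo, 0 );
    }
    else if ( m == 2 )
    {
        y = _mm256_insertf128_pd( y, _mm_loadu_pd( b ), 0 );
    }
    else if ( m == 1 )
    {
        y = _mm256_broadcast_sd( b );
    }
    return y; /* lanes at index >= m hold duplicated or zero data */
}

static inline void store_exact_m( double *b, __m256d y, int m )
{
    __m128d lo = _mm256_castpd256_pd128( y );
    if ( m == 3 )
    {
        _mm_storeu_pd( b, lo );                                /* b[0], b[1] */
        _mm_storel_pd( b + 2, _mm256_extractf128_pd( y, 1 ) ); /* b[2]       */
    }
    else if ( m == 2 )
    {
        _mm_storeu_pd( b, lo );
    }
    else if ( m == 1 )
    {
        _mm_storel_pd( b, lo );
    }
}

The unused upper lanes of the 256-bit register are never written back, so the fix trades a few extra instructions in the remainder paths for the guarantee that no memory outside the m x n tile of B is accessed.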
AMD-Internal: [SWLCSG-1470] Change-Id: I786f8023cf5a5f3e5343bea413c59bd0e764df9b --- kernels/zen/3/bli_trsm_small.c | 5083 +++++++++++++++++++++++++------- 1 file changed, 3987 insertions(+), 1096 deletions(-) diff --git a/kernels/zen/3/bli_trsm_small.c b/kernels/zen/3/bli_trsm_small.c index 0b47ddd862..ffb00b83c1 100644 --- a/kernels/zen/3/bli_trsm_small.c +++ b/kernels/zen/3/bli_trsm_small.c @@ -711,7 +711,7 @@ BLIS_INLINE err_t dtrsm_XAltB_ref #define BLIS_DTRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) \ for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ {\ - /*load 8x1 block of B10*/\ + /*load 4x1 block of B10*/\ ymm0 = _mm256_loadu_pd((double const *)b10); /*B10[0][0] B10[1][0] B10[2][0] B10[3][0]*/\ \ /*broadcast 1st row of A01*/\ @@ -737,6 +737,96 @@ BLIS_INLINE err_t dtrsm_XAltB_ref b10 += cs_b;\ } +#define BLIS_DTRSM_SMALL_GEMM_6nx3m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 3x1 block of B10*/\ + xmm5 = _mm_loadu_pd((double const*)(b10));\ + ymm0 = _mm256_broadcast_sd((double const *)(b10 + 2));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); /*A01[0][3]*/\ + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 4)); /*A01[0][4]*/\ + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 5)); /*A01[0][5]*/\ + ymm13 = _mm256_fmadd_pd(ymm2, ymm0, ymm13);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_DTRSM_SMALL_GEMM_6nx2m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 2x1 block of B10*/\ + xmm5 = _mm_loadu_pd((double const*)(b10));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); /*A01[0][3]*/\ + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 4)); /*A01[0][4]*/\ + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 5)); /*A01[0][5]*/\ + ymm13 = _mm256_fmadd_pd(ymm2, ymm0, ymm13);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_DTRSM_SMALL_GEMM_6nx1m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 1x1 block of B10*/\ + ymm0 = _mm256_broadcast_sd((double const *)b10);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3);\ +\ + ymm2 = 
_mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); /*A01[0][3]*/\ + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 4)); /*A01[0][4]*/\ + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 5)); /*A01[0][5]*/\ + ymm13 = _mm256_fmadd_pd(ymm2, ymm0, ymm13);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + #define BLIS_DTRSM_SMALL_GEMM_4nx8m(a01,b10,cs_b,p_lda,k_iter) \ for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ {\ @@ -765,6 +855,101 @@ BLIS_INLINE err_t dtrsm_XAltB_ref b10 += cs_b;\ } +#define BLIS_DTRSM_SMALL_GEMM_4nx4m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 4x1 block of B10*/\ + ymm0 = _mm256_loadu_pd((double const *)b10);/*B10[0][0] B10[1][0] B10[2][0] B10[3][0]*/\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); /*A01[0][3]*/\ + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_DTRSM_SMALL_GEMM_4nx3m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 3x1 block of B10*/\ + xmm5 = _mm_loadu_pd((double const*)(b10));\ + ymm0 = _mm256_broadcast_sd((double const *)(b10 + 2));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); /*A01[0][3]*/\ + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_DTRSM_SMALL_GEMM_4nx2m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 2x1 block of B10*/\ + xmm5 = _mm_loadu_pd((double const*)(b10));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); /*A01[0][3]*/\ + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + 
+#define BLIS_DTRSM_SMALL_GEMM_4nx1m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 1x1 block of B10*/\ + ymm0 = _mm256_broadcast_sd((double const *)b10);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); /*A01[0][3]*/\ + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + #define BLIS_DTRSM_SMALL_GEMM_3nx8m(a01,b10,cs_b,p_lda,k_iter)\ for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ {\ @@ -789,47 +974,54 @@ BLIS_INLINE err_t dtrsm_XAltB_ref b10 += cs_b;\ } -#define BLIS_DTRSM_SMALL_GEMM_2nx8m(a01,b10,cs_b,p_lda,k_iter)\ +#define BLIS_DTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) \ for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ {\ - /*load 8x1 block of B10*/\ + /*load 4x1 block of B10*/\ ymm0 = _mm256_loadu_pd((double const *)b10);/*B10[0][0] B10[1][0] B10[2][0] B10[3][0]*/\ - ymm1 = _mm256_loadu_pd((double const *)(b10 + 4));/*B10[4][0] B10[5][0] B10[6][0] B10[7][0]*/\ \ /*broadcast 1st row of A01*/\ ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3);\ - ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);\ \ ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5);\ - ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7);\ \ a01 += 1; /*move to next row*/\ b10 += cs_b;\ } -#define BLIS_DTRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter)\ +#define BLIS_DTRSM_SMALL_GEMM_3nx3m(a01,b10,cs_b,p_lda,k_iter) \ for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ {\ - /*load 8x1 block of B10*/\ - ymm0 = _mm256_loadu_pd((double const *)b10);/*B10[0][0] B10[1][0] B10[2][0] B10[3][0]*/\ - ymm1 = _mm256_loadu_pd((double const *)(b10 + 4));/*B10[4][0] B10[5][0] B10[6][0] B10[7][0]*/\ + /*load 3x1 block of B10*/\ + xmm5 = _mm_loadu_pd((double const*)(b10));\ + ymm0 = _mm256_broadcast_sd((double const *)(b10 + 2));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ \ /*broadcast 1st row of A01*/\ ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3);\ - ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7);\ \ a01 += 1; /*move to next row*/\ b10 += cs_b;\ } -#define BLIS_DTRSM_SMALL_GEMM_4nx4m(a01,b10,cs_b,p_lda,k_iter) \ +#define BLIS_DTRSM_SMALL_GEMM_3nx2m(a01,b10,cs_b,p_lda,k_iter) \ for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ {\ - /*load 8x1 block of B10*/\ - ymm0 = _mm256_loadu_pd((double const *)b10);/*B10[0][0] B10[1][0] B10[2][0] B10[3][0]*/\ + /*load 2x1 block of B10*/\ + xmm5 = _mm_loadu_pd((double const*)(b10));\ + ymm0 = 
_mm256_insertf128_pd(ymm0, xmm5, 0);\ \ /*broadcast 1st row of A01*/\ ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ @@ -840,19 +1032,16 @@ BLIS_INLINE err_t dtrsm_XAltB_ref \ ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7);\ -\ - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); /*A01[0][3]*/\ - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9);\ \ a01 += 1; /*move to next row*/\ b10 += cs_b;\ } -#define BLIS_DTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) \ +#define BLIS_DTRSM_SMALL_GEMM_3nx1m(a01,b10,cs_b,p_lda,k_iter) \ for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ {\ - /*load 8x1 block of B10*/\ - ymm0 = _mm256_loadu_pd((double const *)b10);/*B10[0][0] B10[1][0] B10[2][0] B10[3][0]*/\ + /*load 1x1 block of B10*/\ + ymm0 = _mm256_broadcast_sd((double const *)b10);\ \ /*broadcast 1st row of A01*/\ ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ @@ -868,28 +1057,163 @@ BLIS_INLINE err_t dtrsm_XAltB_ref b10 += cs_b;\ } -#define BLIS_DTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) \ +#define BLIS_DTRSM_SMALL_GEMM_2nx8m(a01,b10,cs_b,p_lda,k_iter)\ for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ {\ /*load 8x1 block of B10*/\ ymm0 = _mm256_loadu_pd((double const *)b10);/*B10[0][0] B10[1][0] B10[2][0] B10[3][0]*/\ + ymm1 = _mm256_loadu_pd((double const *)(b10 + 4));/*B10[4][0] B10[5][0] B10[6][0] B10[7][0]*/\ \ /*broadcast 1st row of A01*/\ ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3);\ + ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);\ \ ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5);\ + ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6);\ \ a01 += 1; /*move to next row*/\ b10 += cs_b;\ } -#define BLIS_DTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) \ +#define BLIS_DTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 4x1 block of B10*/\ + ymm0 = _mm256_loadu_pd((double const *)b10);/*B10[0][0] B10[1][0] B10[2][0] B10[3][0]*/\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_DTRSM_SMALL_GEMM_2nx3m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 3x1 block of B10*/\ + xmm5 = _mm_loadu_pd((double const*)(b10));\ + ymm0 = _mm256_broadcast_sd((double const *)(b10 + 2));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_DTRSM_SMALL_GEMM_2nx2m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 2x1 block of B10*/\ + xmm5 = _mm_loadu_pd((double const*)(b10));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ +\ + /*broadcast 1st row 
of A01*/\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_DTRSM_SMALL_GEMM_2nx1m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 1x1 block of B10*/\ + ymm0 = _mm256_broadcast_sd((double const *)b10);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_DTRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter)\ for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ {\ /*load 8x1 block of B10*/\ ymm0 = _mm256_loadu_pd((double const *)b10);/*B10[0][0] B10[1][0] B10[2][0] B10[3][0]*/\ + ymm1 = _mm256_loadu_pd((double const *)(b10 + 4));/*B10[4][0] B10[5][0] B10[6][0] B10[7][0]*/\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3);\ + ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_DTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 4x1 block of B10*/\ + ymm0 = _mm256_loadu_pd((double const *)b10);/*B10[0][0] B10[1][0] B10[2][0] B10[3][0]*/\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_DTRSM_SMALL_GEMM_1nx3m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 3x1 block of B10*/\ + xmm5 = _mm_loadu_pd((double const*)(b10));\ + ymm0 = _mm256_broadcast_sd((double const *)(b10+ 2));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_DTRSM_SMALL_GEMM_1nx2m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 2x1 block of B10*/\ + xmm5 = _mm_loadu_pd((double const*)(b10));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_DTRSM_SMALL_GEMM_1nx1m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 1x1 block of B10*/\ + ymm0 = _mm256_broadcast_sd((double const *)b10);\ \ /*broadcast 1st row of A01*/\ ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ @@ -1264,8 +1588,12 @@ BLIS_INLINE err_t dtrsm_XAltB_ref #define BLIS_PRE_DTRSM_SMALL_3M_3N(AlphaVal,b11,cs_b)\ ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); /*register to hold alpha*/\ \ - ymm0 = 
_mm256_loadu_pd((double const *)(b11 + cs_b *0));\ - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1));\ + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 0));\ + ymm0 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 0 + 2));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1));\ + ymm1 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 1 + 2));\ + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0);\ xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2));\ ymm2 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 2 + 2));\ ymm2 = _mm256_insertf128_pd(ymm2, xmm5, 0);\ @@ -1274,20 +1602,22 @@ BLIS_INLINE err_t dtrsm_XAltB_ref ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9);\ ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10);\ \ - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08);\ - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08);\ - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08);\ -\ - _mm256_storeu_pd((double *)(b11), ymm0); /*store(B11[0-3][0])*/\ - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); /*store(B11[0-3][1])*/\ - xmm5 = _mm256_castpd256_pd128(ymm2);\ + xmm5 = _mm256_castpd256_pd128(ymm8);\ + _mm_storeu_pd((double *)(b11 + cs_b * 0), xmm5);\ + _mm_storel_pd((b11 + cs_b * 0 + 2), _mm256_extractf128_pd(ymm8, 1));\ + xmm5 = _mm256_castpd256_pd128(ymm9);\ + _mm_storeu_pd((double *)(b11 + cs_b * 1), xmm5);\ + _mm_storel_pd((b11 + cs_b * 1 + 2), _mm256_extractf128_pd(ymm9, 1));\ + xmm5 = _mm256_castpd256_pd128(ymm10);\ _mm_storeu_pd((double *)(b11 + cs_b * 2), xmm5);\ - _mm_storel_pd((b11 + cs_b * 2 + 2), _mm256_extractf128_pd(ymm2, 1)); + _mm_storel_pd((b11 + cs_b * 2 + 2), _mm256_extractf128_pd(ymm10, 1)); #define BLIS_PRE_DTRSM_SMALL_3M_2N(AlphaVal,b11,cs_b)\ ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); /*register to hold alpha*/\ \ - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));\ + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 0));\ + ymm0 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 0 + 2));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1));\ ymm1 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 1 + 2));\ ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0);\ @@ -1295,13 +1625,12 @@ BLIS_INLINE err_t dtrsm_XAltB_ref ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8);\ ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9);\ \ - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08);\ - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08);\ -\ - _mm256_storeu_pd((double *)(b11), ymm0); /*store(B11[0-3][0])*/\ - xmm5 = _mm256_castpd256_pd128(ymm1);\ + xmm5 = _mm256_castpd256_pd128(ymm8);\ + _mm_storeu_pd((double *)(b11 + cs_b * 0), xmm5);\ + _mm_storel_pd((b11 + cs_b * 0 + 2), _mm256_extractf128_pd(ymm8, 1));\ + xmm5 = _mm256_castpd256_pd128(ymm9);\ _mm_storeu_pd((double *)(b11 + cs_b * 1), xmm5);\ - _mm_storel_pd((b11 + cs_b * 1 + 2), _mm256_extractf128_pd(ymm1, 1)); + _mm_storel_pd((b11 + cs_b * 1 + 2), _mm256_extractf128_pd(ymm9, 1)); #define BLIS_PRE_DTRSM_SMALL_3M_1N(AlphaVal,b11,cs_b)\ ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); /*register to hold alpha*/\ @@ -1310,18 +1639,19 @@ BLIS_INLINE err_t dtrsm_XAltB_ref ymm0 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 0 + 2));\ ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8);\ - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08);\ \ - xmm5 = _mm256_castpd256_pd128(ymm0);\ + xmm5 = _mm256_castpd256_pd128(ymm8);\ _mm_storeu_pd((double *)(b11), xmm5);\ - _mm_storel_pd((b11 + 2), _mm256_extractf128_pd(ymm0, 1)); + _mm_storel_pd((b11 + 2), 
_mm256_extractf128_pd(ymm8, 1)); #define BLIS_PRE_DTRSM_SMALL_2M_3N(AlphaVal,b11,cs_b)\ ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); /*register to hold alpha*/\ \ - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));\ - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1));\ + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 0));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1));\ + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0);\ xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2));\ ymm2 = _mm256_insertf128_pd(ymm2, xmm5, 0);\ \ @@ -1329,30 +1659,27 @@ BLIS_INLINE err_t dtrsm_XAltB_ref ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9);\ ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10);\ \ - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C);\ - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C);\ - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0C);\ -\ - _mm256_storeu_pd((double *)(b11), ymm0); /*store(B11[0-3][0])*/\ - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); /*store(B11[0-3][1])*/\ - xmm5 = _mm256_castpd256_pd128(ymm2);\ + xmm5 = _mm256_castpd256_pd128(ymm8);\ + _mm_storeu_pd((double *)(b11 + cs_b * 0), xmm5);\ + xmm5 = _mm256_castpd256_pd128(ymm9);\ + _mm_storeu_pd((double *)(b11 + cs_b * 1), xmm5);\ + xmm5 = _mm256_castpd256_pd128(ymm10);\ _mm_storeu_pd((double *)(b11 + cs_b * 2), xmm5); #define BLIS_PRE_DTRSM_SMALL_2M_2N(AlphaVal,b11,cs_b)\ ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); /*register to hold alpha*/\ \ - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));\ + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 0));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1));\ ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0);\ \ ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8);\ ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9);\ \ - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C);\ - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C);\ -\ - _mm256_storeu_pd((double *)(b11), ymm0); /*store(B11[0-3][0])*/\ - xmm5 = _mm256_castpd256_pd128(ymm1);\ + xmm5 = _mm256_castpd256_pd128(ymm8);\ + _mm_storeu_pd((double *)(b11 + cs_b * 0), xmm5);\ + xmm5 = _mm256_castpd256_pd128(ymm9);\ _mm_storeu_pd((double *)(b11 + cs_b * 1), xmm5); #define BLIS_PRE_DTRSM_SMALL_2M_1N(AlphaVal,b11,cs_b)\ @@ -1362,9 +1689,8 @@ BLIS_INLINE err_t dtrsm_XAltB_ref ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ \ ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8);\ - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C);\ \ - xmm5 = _mm256_castpd256_pd128(ymm0);\ + xmm5 = _mm256_castpd256_pd128(ymm8);\ _mm_storeu_pd((double *)(b11 + cs_b * 0), xmm5); #define BLIS_PRE_DTRSM_SMALL_1M_3N(AlphaVal,b11,cs_b)\ @@ -1378,13 +1704,9 @@ BLIS_INLINE err_t dtrsm_XAltB_ref ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9);\ ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10);\ \ - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E);\ - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E);\ - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E);\ -\ - _mm_storel_pd((b11 + cs_b * 0), _mm256_castpd256_pd128(ymm0));\ - _mm_storel_pd((b11 + cs_b * 1), _mm256_castpd256_pd128(ymm1));\ - _mm_storel_pd((b11 + cs_b * 2), _mm256_castpd256_pd128(ymm2)); + _mm_storel_pd((double *)(b11), _mm256_extractf128_pd(ymm8,0));\ + _mm_storel_pd((double *)(b11 + cs_b * 1), _mm256_extractf128_pd(ymm9,0));\ + _mm_storel_pd((double *)(b11 + cs_b * 2), _mm256_extractf128_pd(ymm10,0)); #define BLIS_PRE_DTRSM_SMALL_1M_2N(AlphaVal,b11,cs_b)\ ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); /*register to hold alpha*/\ @@ -1395,11 +1717,8 @@ BLIS_INLINE err_t 
dtrsm_XAltB_ref ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8);\ ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9);\ \ - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E);\ - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E);\ -\ - _mm_storel_pd((b11 + cs_b * 0), _mm256_castpd256_pd128(ymm0));\ - _mm_storel_pd((b11 + cs_b * 1), _mm256_castpd256_pd128(ymm1)); + _mm_storel_pd((double *)(b11), _mm256_extractf128_pd(ymm8,0));\ + _mm_storel_pd((double *)(b11 + cs_b * 1), _mm256_extractf128_pd(ymm9,0)); #define BLIS_PRE_DTRSM_SMALL_1M_1N(AlphaVal,b11,cs_b)\ ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); /*register to hold alpha*/\ @@ -1407,18 +1726,20 @@ BLIS_INLINE err_t dtrsm_XAltB_ref ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b *0));\ ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8);\ \ - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E);\ -\ - _mm_storel_pd((b11 + cs_b * 0), _mm256_castpd256_pd128(ymm0)); + _mm_storel_pd((double *)(b11), _mm256_extractf128_pd(ymm8,0)); /* pre & post TRSM for Right remainder cases*/ #define BLIS_PRE_DTRSM_SMALL_3N_3M(AlphaVal,b11,cs_b)\ ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); /*register to hold alpha*/\ \ - ymm0 = _mm256_loadu_pd((double const *)b11);\ + xmm5 = _mm_loadu_pd((double const*)(b11));\ + ymm0 = _mm256_broadcast_sd((double const *)(b11 + 2));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3);\ \ - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b));\ + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b));\ + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b + 2));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5);\ \ xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2));\ @@ -1427,17 +1748,13 @@ BLIS_INLINE err_t dtrsm_XAltB_ref ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); #define BLIS_POST_DTRSM_SMALL_3N_3M(b11,cs_b)\ - ymm0 = _mm256_loadu_pd((double const *)b11);\ - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x07);\ - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b));\ - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x07);\ - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2));\ - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2 + 2));\ - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x07);\ \ - _mm256_storeu_pd((double *)b11, ymm3);\ - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5);\ + xmm5 = _mm256_castpd256_pd128(ymm3);\ + _mm_storeu_pd((double *)(b11),xmm5);\ + _mm_storel_pd((b11 + 2), _mm256_extractf128_pd(ymm3, 1));\ + xmm5 = _mm256_castpd256_pd128(ymm5);\ + _mm_storeu_pd((double *)(b11 + cs_b),xmm5);\ + _mm_storel_pd((b11 + cs_b + 2), _mm256_extractf128_pd(ymm5, 1));\ xmm5 = _mm256_castpd256_pd128(ymm7);\ _mm_storeu_pd((double *)(b11 + cs_b * 2),xmm5);\ _mm_storel_pd((b11 + cs_b * 2 + 2), _mm256_extractf128_pd(ymm7, 1)); @@ -1445,10 +1762,12 @@ BLIS_INLINE err_t dtrsm_XAltB_ref #define BLIS_PRE_DTRSM_SMALL_3N_2M(AlphaVal,b11,cs_b)\ ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); /*register to hold alpha*/\ \ - ymm0 = _mm256_loadu_pd((double const *)b11);\ + xmm5 = _mm_loadu_pd((double const*)(b11));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3);\ \ - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b));\ + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5);\ \ xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2));\ @@ -1456,16 +1775,11 @@ BLIS_INLINE err_t dtrsm_XAltB_ref ymm7 = _mm256_fmsub_pd(ymm0, 
ymm15, ymm7); #define BLIS_POST_DTRSM_SMALL_3N_2M(b11,cs_b)\ - ymm0 = _mm256_loadu_pd((double const *)b11);\ - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x03);\ - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b));\ - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x03);\ - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2));\ - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x03);\ \ - _mm256_storeu_pd((double *)b11, ymm3);\ - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5);\ + xmm5 = _mm256_castpd256_pd128(ymm3);\ + _mm_storeu_pd((double *)(b11),xmm5);\ + xmm5 = _mm256_castpd256_pd128(ymm5);\ + _mm_storeu_pd((double *)(b11 + cs_b),xmm5);\ xmm5 = _mm256_castpd256_pd128(ymm7);\ _mm_storeu_pd((double *)(b11 + cs_b * 2),xmm5); @@ -1482,12 +1796,6 @@ BLIS_INLINE err_t dtrsm_XAltB_ref ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); #define BLIS_POST_DTRSM_SMALL_3N_1M(b11,cs_b)\ - ymm0 = _mm256_broadcast_sd((double const *)b11);\ - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x01);\ - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b));\ - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x01);\ - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2));\ - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x01);\ \ _mm_storel_pd((b11 + cs_b * 0), _mm256_castpd256_pd128(ymm3));\ _mm_storel_pd((b11 + cs_b * 1), _mm256_castpd256_pd128(ymm5));\ @@ -1496,7 +1804,9 @@ BLIS_INLINE err_t dtrsm_XAltB_ref #define BLIS_PRE_DTRSM_SMALL_2N_3M(AlphaVal,b11,cs_b)\ ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); /*register to hold alpha*/\ \ - ymm0 = _mm256_loadu_pd((double const *)b11);\ + xmm5 = _mm_loadu_pd((double const*)(b11));\ + ymm0 = _mm256_broadcast_sd((double const *)(b11 + 2));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3);\ \ xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1));\ @@ -1505,14 +1815,10 @@ BLIS_INLINE err_t dtrsm_XAltB_ref ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); #define BLIS_POST_DTRSM_SMALL_2N_3M(b11,cs_b)\ - ymm0 = _mm256_loadu_pd((double const *)b11);\ - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x07);\ - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1));\ - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*1 + 2));\ - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x07);\ \ - _mm256_storeu_pd((double *)b11, ymm3);\ + xmm5 = _mm256_castpd256_pd128(ymm3);\ + _mm_storeu_pd((double *)(b11), xmm5);\ + _mm_storel_pd((b11 + 2), _mm256_extractf128_pd(ymm3, 1));\ xmm5 = _mm256_castpd256_pd128(ymm5);\ _mm_storeu_pd((double *)(b11 + cs_b*1), xmm5);\ _mm_storel_pd((b11 + cs_b * 1 + 2), _mm256_extractf128_pd(ymm5, 1)); @@ -1520,7 +1826,8 @@ BLIS_INLINE err_t dtrsm_XAltB_ref #define BLIS_PRE_DTRSM_SMALL_2N_2M(AlphaVal,b11,cs_b)\ ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); /*register to hold alpha*/\ \ - ymm0 = _mm256_loadu_pd((double const *)b11);\ + xmm5 = _mm_loadu_pd((double const*)(b11));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3);\ \ xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1));\ @@ -1528,13 +1835,9 @@ BLIS_INLINE err_t dtrsm_XAltB_ref ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); #define BLIS_POST_DTRSM_SMALL_2N_2M(b11,cs_b)\ - ymm0 = _mm256_loadu_pd((double const *)b11);\ - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x03);\ - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1));\ - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x03);\ \ - _mm256_storeu_pd((double *)b11, ymm3);\ + xmm5 = _mm256_castpd256_pd128(ymm3);\ + 
_mm_storeu_pd((double *)(b11), xmm5);\ xmm5 = _mm256_castpd256_pd128(ymm5);\ _mm_storeu_pd((double *)(b11 + cs_b*1), xmm5); @@ -1548,10 +1851,6 @@ BLIS_INLINE err_t dtrsm_XAltB_ref ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); #define BLIS_POST_DTRSM_SMALL_2N_1M(b11,cs_b)\ - ymm0 = _mm256_broadcast_sd((double const *)b11);\ - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x01);\ - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b));\ - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x01);\ \ _mm_storel_pd(b11 , _mm256_castpd256_pd128(ymm3));\ _mm_storel_pd((b11 + cs_b * 1), _mm256_castpd256_pd128(ymm5)); @@ -1561,8 +1860,8 @@ BLIS_INLINE err_t dtrsm_XAltB_ref \ xmm5 = _mm_loadu_pd((double const*)(b11));\ ymm0 = _mm256_broadcast_sd((double const *)(b11+ 2));\ - ymm6 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ - ymm3 = _mm256_fmsub_pd(ymm6, ymm15, ymm3); + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); #define BLIS_POST_DTRSM_SMALL_1N_3M(b11,cs_b)\ xmm5 = _mm256_castpd256_pd128(ymm3);\ @@ -1573,12 +1872,10 @@ BLIS_INLINE err_t dtrsm_XAltB_ref ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); /*register to hold alpha*/\ \ xmm5 = _mm_loadu_pd((double const*)(b11));\ - ymm6 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ - ymm3 = _mm256_fmsub_pd(ymm6, ymm15, ymm3); + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); #define BLIS_POST_DTRSM_SMALL_1N_2M(b11,cs_b)\ - ymm0 = _mm256_loadu_pd((double const *)b11);\ - ymm3 = _mm256_blend_pd(ymm6, ymm3, 0x03);\ \ xmm5 = _mm256_castpd256_pd128(ymm3);\ _mm_storeu_pd((double *)(b11), xmm5); @@ -1586,11 +1883,10 @@ BLIS_INLINE err_t dtrsm_XAltB_ref #define BLIS_PRE_DTRSM_SMALL_1N_1M(AlphaVal,b11,cs_b)\ ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); /*register to hold alpha*/\ \ - ymm6 = _mm256_broadcast_sd((double const *)b11);\ - ymm3 = _mm256_fmsub_pd(ymm6, ymm15, ymm3); + ymm0 = _mm256_broadcast_sd((double const *)b11);\ + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); #define BLIS_POST_DTRSM_SMALL_1N_1M(b11,cs_b)\ - ymm3 = _mm256_blend_pd(ymm6, ymm3, 0x01);\ \ _mm_storel_pd(b11, _mm256_castpd256_pd128(ymm3)); @@ -1867,6 +2163,223 @@ BLIS_INLINE err_t dtrsm_XAltB_ref b10 += cs_b;\ } +#define BLIS_STRSM_SMALL_GEMM_6nx7m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 7x1 block of B10*/\ + xmm5 = _mm_loadu_ps((float const*)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + xmm5 = _mm_broadcast_ss((float const*)(b10 + 6));\ + xmm5 = _mm_loadl_pi(xmm5,(__m64*)(b10 + 4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_ps(ymm2, ymm0, ymm7);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 3)); /*A01[0][3]*/\ + ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 4)); /*A01[0][4]*/\ + ymm11 = _mm256_fmadd_ps(ymm2, ymm0, ymm11);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 5)); /*A01[0][5]*/\ + ymm13 = _mm256_fmadd_ps(ymm2, ymm0, ymm13);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_STRSM_SMALL_GEMM_6nx6m(a01,b10,cs_b,p_lda,k_iter) \ + 
for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 6x1 block of B10*/\ + xmm5 = _mm_loadu_ps((float const*)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + xmm5 = _mm_loadl_pi(xmm5,(__m64*)(b10 + 4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_ps(ymm2, ymm0, ymm7);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 3)); /*A01[0][3]*/\ + ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 4)); /*A01[0][4]*/\ + ymm11 = _mm256_fmadd_ps(ymm2, ymm0, ymm11);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 5)); /*A01[0][5]*/\ + ymm13 = _mm256_fmadd_ps(ymm2, ymm0, ymm13);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_STRSM_SMALL_GEMM_6nx5m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 5x1 block of B10*/\ + ymm0 = _mm256_broadcast_ss((float const *)(b10 + 4));\ + xmm5 = _mm_loadu_ps((float const*)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_ps(ymm2, ymm0, ymm7);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 3)); /*A01[0][3]*/\ + ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 4)); /*A01[0][4]*/\ + ymm11 = _mm256_fmadd_ps(ymm2, ymm0, ymm11);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 5)); /*A01[0][5]*/\ + ymm13 = _mm256_fmadd_ps(ymm2, ymm0, ymm13);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_STRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 4x1 block of B10*/\ + xmm5 = _mm_loadu_ps((float const*)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_ps(ymm2, ymm0, ymm7);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 3)); /*A01[0][3]*/\ + ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 4)); /*A01[0][4]*/\ + ymm11 = _mm256_fmadd_ps(ymm2, ymm0, ymm11);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 5)); /*A01[0][5]*/\ + ymm13 = _mm256_fmadd_ps(ymm2, ymm0, ymm13);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_STRSM_SMALL_GEMM_6nx3m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number 
of GEMM operations*/\ + {\ + /*load 3x1 block of B10*/\ + __m128 xmm6 = _mm_broadcast_ss((float const *)(b10+ 2));\ + xmm5 = _mm_loadl_pi(xmm6,(__m64*)(b10)); \ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_ps(ymm2, ymm0, ymm7);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 3)); /*A01[0][3]*/\ + ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 4)); /*A01[0][4]*/\ + ymm11 = _mm256_fmadd_ps(ymm2, ymm0, ymm11);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 5)); /*A01[0][5]*/\ + ymm13 = _mm256_fmadd_ps(ymm2, ymm0, ymm13);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_STRSM_SMALL_GEMM_6nx2m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 2x1 block of B10*/\ + xmm5 = _mm_setzero_ps();\ + xmm5 = _mm_loadl_pi(xmm5,(__m64*)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_ps(ymm2, ymm0, ymm7);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 3)); /*A01[0][3]*/\ + ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 4)); /*A01[0][4]*/\ + ymm11 = _mm256_fmadd_ps(ymm2, ymm0, ymm11);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 5)); /*A01[0][5]*/\ + ymm13 = _mm256_fmadd_ps(ymm2, ymm0, ymm13);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_STRSM_SMALL_GEMM_6nx1m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 1x1 block of B10*/\ + ymm0 = _mm256_broadcast_ss((float const *)b10);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_ps(ymm2, ymm0, ymm7);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 3)); /*A01[0][3]*/\ + ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 4)); /*A01[0][4]*/\ + ymm11 = _mm256_fmadd_ps(ymm2, ymm0, ymm11);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 5)); /*A01[0][5]*/\ + ymm13 = _mm256_fmadd_ps(ymm2, ymm0, ymm13);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + #define BLIS_STRSM_SMALL_GEMM_4nx16m(a01,b10,cs_b,p_lda,k_iter) \ for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ {\ @@ -1978,6 +2491,181 @@ BLIS_INLINE err_t dtrsm_XAltB_ref b10 += cs_b;\ } +#define 
BLIS_STRSM_SMALL_GEMM_4nx7m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 7x1 block of B10*/\ + xmm5 = _mm_loadu_ps((float const*)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + xmm5 = _mm_broadcast_ss((float const*)(b10 + 6));\ + xmm5 = _mm_loadl_pi(xmm5,(__m64*)(b10 + 4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_ps(ymm2, ymm0, ymm7);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 3)); /*A01[0][3]*/\ + ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_STRSM_SMALL_GEMM_4nx6m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 6x1 block of B10*/\ + xmm5 = _mm_loadu_ps((float const*)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + xmm5 = _mm_loadl_pi(xmm5,(__m64*)(b10 + 4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_ps(ymm2, ymm0, ymm7);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 3)); /*A01[0][3]*/\ + ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_STRSM_SMALL_GEMM_4nx5m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 5x1 block of B10*/\ + ymm0 = _mm256_broadcast_ss((float const *)(b10 + 4));\ + xmm5 = _mm_loadu_ps((float const*)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_ps(ymm2, ymm0, ymm7);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 3)); /*A01[0][3]*/\ + ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_STRSM_SMALL_GEMM_4nx4m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 4x1 block of B10*/\ + xmm5 = _mm_loadu_ps((float const*)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_ps(ymm2, ymm0, ymm7);\ +\ + 
ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 3)); /*A01[0][3]*/\ + ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_STRSM_SMALL_GEMM_4nx3m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 3x1 block of B10*/\ + __m128 xmm6 = _mm_broadcast_ss((float const *)(b10+ 2));\ + xmm5 = _mm_loadl_pi(xmm6,(__m64*)(b10)); \ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_ps(ymm2, ymm0, ymm7);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 3)); /*A01[0][3]*/\ + ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_STRSM_SMALL_GEMM_4nx2m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 2x1 block of B10*/\ + xmm5 = _mm_setzero_ps();\ + xmm5 = _mm_loadl_pi(xmm5,(__m64*)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_ps(ymm2, ymm0, ymm7);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 3)); /*A01[0][3]*/\ + ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_STRSM_SMALL_GEMM_4nx1m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 1x1 block of B10*/\ + ymm0 = _mm256_broadcast_ss((float const *)b10);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_ps(ymm2, ymm0, ymm7);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 3)); /*A01[0][3]*/\ + ymm9 = _mm256_fmadd_ps(ymm2, ymm0, ymm9);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + #define BLIS_STRSM_SMALL_GEMM_3nx8m(a01,b10,cs_b,p_lda,k_iter) \ for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ {\ @@ -1998,6 +2686,160 @@ BLIS_INLINE err_t dtrsm_XAltB_ref b10 += cs_b;\ } +#define BLIS_STRSM_SMALL_GEMM_3nx7m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 7x1 block of B10*/\ + xmm5 = _mm_loadu_ps((float const*)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + xmm5 = _mm_broadcast_ss((float const*)(b10 + 6));\ + xmm5 = _mm_loadl_pi(xmm5,(__m64*)(b10 + 4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = 
_mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_ps(ymm2, ymm0, ymm7);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_STRSM_SMALL_GEMM_3nx6m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 6x1 block of B10*/\ + xmm5 = _mm_loadu_ps((float const*)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + xmm5 = _mm_loadl_pi(xmm5,(__m64*)(b10 + 4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_ps(ymm2, ymm0, ymm7);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_STRSM_SMALL_GEMM_3nx5m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 5x1 block of B10*/\ + ymm0 = _mm256_broadcast_ss((float const *)(b10 + 4));\ + xmm5 = _mm_loadu_ps((float const*)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_ps(ymm2, ymm0, ymm7);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_STRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 4x1 block of B10*/\ + xmm5 = _mm_loadu_ps((float const*)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_ps(ymm2, ymm0, ymm7);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_STRSM_SMALL_GEMM_3nx3m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 3x1 block of B10*/\ + __m128 xmm6 = _mm_broadcast_ss((float const *)(b10+ 2));\ + xmm5 = _mm_loadl_pi(xmm6,(__m64*)(b10)); \ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_ps(ymm2, ymm0, ymm7);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define 
BLIS_STRSM_SMALL_GEMM_3nx2m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 2x1 block of B10*/\ + xmm5 = _mm_setzero_ps();\ + xmm5 = _mm_loadl_pi(xmm5,(__m64*)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_ps(ymm2, ymm0, ymm7);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_STRSM_SMALL_GEMM_3nx1m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 1x1 block of B10*/\ + ymm0 = _mm256_broadcast_ss((float const *)b10);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_ps(ymm2, ymm0, ymm7);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + #define BLIS_STRSM_SMALL_GEMM_2nx8m(a01,b10,cs_b,p_lda,k_iter) \ for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ {\ @@ -2015,6 +2857,139 @@ BLIS_INLINE err_t dtrsm_XAltB_ref b10 += cs_b;\ } +#define BLIS_STRSM_SMALL_GEMM_2nx7m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 7x1 block of B10*/\ + xmm5 = _mm_loadu_ps((float const*)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + xmm5 = _mm_broadcast_ss((float const*)(b10 + 6));\ + xmm5 = _mm_loadl_pi(xmm5,(__m64*)(b10 + 4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_STRSM_SMALL_GEMM_2nx6m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 6x1 block of B10*/\ + xmm5 = _mm_loadu_ps((float const*)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + xmm5 = _mm_loadl_pi(xmm5,(__m64*)(b10 + 4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_STRSM_SMALL_GEMM_2nx5m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 5x1 block of B10*/\ + ymm0 = _mm256_broadcast_ss((float const *)(b10 + 4));\ + xmm5 = _mm_loadu_ps((float const*)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = 
_mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_STRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 4x1 block of B10*/\ + xmm5 = _mm_loadu_ps((float const*)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_STRSM_SMALL_GEMM_2nx3m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 3x1 block of B10*/\ + __m128 xmm6 = _mm_broadcast_ss((float const *)(b10+ 2));\ + xmm5 = _mm_loadl_pi(xmm6,(__m64*)(b10)); \ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_STRSM_SMALL_GEMM_2nx2m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 2x1 block of B10*/\ + xmm5 = _mm_setzero_ps();\ + xmm5 = _mm_loadl_pi(xmm5,(__m64*)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_STRSM_SMALL_GEMM_2nx1m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 1x1 block of B10*/\ + ymm0 = _mm256_broadcast_ss((float const *)b10);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_ps(ymm2, ymm0, ymm5);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + #define BLIS_STRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter) \ for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ {\ @@ -2029,12 +3004,94 @@ BLIS_INLINE err_t dtrsm_XAltB_ref b10 += cs_b;\ } +#define BLIS_STRSM_SMALL_GEMM_1nx7m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 7x1 block of B10*/\ + xmm5 = _mm_loadu_ps((float const*)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + xmm5 = _mm_broadcast_ss((float const*)(b10 + 6));\ + xmm5 = _mm_loadl_pi(xmm5,(__m64*)(b10 + 4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define 
BLIS_STRSM_SMALL_GEMM_1nx6m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 6x1 block of B10*/\ + xmm5 = _mm_loadu_ps((float const*)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + xmm5 = _mm_loadl_pi(xmm5,(__m64*)(b10 + 4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + #define BLIS_STRSM_SMALL_GEMM_1nx5m(a01,b10,cs_b,p_lda,k_iter) \ for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ {\ /*load 5x1 block of B10*/\ ymm0 = _mm256_broadcast_ss((float const *)(b10 + 4));\ - xmm5 = _mm_loadu_ps((float const *)(b10));\ + xmm5 = _mm_loadu_ps((float const*)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_STRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 4x1 block of B10*/\ + xmm5 = _mm_loadu_ps((float const*)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_STRSM_SMALL_GEMM_1nx3m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 3x1 block of B10*/\ + __m128 xmm6 = _mm_broadcast_ss((float const *)(b10+ 2));\ + xmm5 = _mm_loadl_pi(xmm6,(__m64*)(b10)); \ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_STRSM_SMALL_GEMM_1nx2m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 2x1 block of B10*/\ + xmm5 = _mm_setzero_ps();\ + xmm5 = _mm_loadl_pi(xmm5,(__m64*)(b10));\ ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ \ /*broadcast 1st row of A01*/\ @@ -2045,6 +3102,20 @@ BLIS_INLINE err_t dtrsm_XAltB_ref b10 += cs_b;\ } +#define BLIS_STRSM_SMALL_GEMM_1nx1m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 1x1 block of B10*/\ + ymm0 = _mm256_broadcast_ss((float const *)b10);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + /*GEMM block used in strsm small left cases*/ #define BLIS_STRSM_SMALL_GEMM_16mx6n(a10,b01,cs_b,p_lda,k_iter) \ float *b01_prefetch = b01 + 8; \ @@ -5004,7 +6075,286 @@ BLIS_INLINE err_t ztrsm_AuXB_ref ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ {\ - ymm0 = _mm256_loadu_pd((double const *)(a10));\ + ymm0 = _mm256_loadu_pd((double const *)(a10));\ + ymm0 = _mm256_mul_pd(ymm0, ymm18);\ + \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0));\ + ymm2 = 
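For reference, the BLIS_STRSM_SMALL_GEMM_<n>nx<m>m fringe macros introduced above all accumulate the same product, acc += B10 * A01, for an m x n tile; only the way the m-remainder of B10 is loaded changes (full 8-float ymm load, partial xmm loads, or a single broadcast). The plain-C sketch below is illustrative only and is not part of the kernels: the helper name and the dense column-major m x n accumulator are assumptions, while the strides (A01 advancing one row per iteration with column stride p_lda, B10 advancing one column per iteration with column stride cs_b) follow the macros.

/* Illustrative scalar equivalent of the <n>nx<m>m GEMM accumulation (assumed helper,
 * not part of this patch). acc is a dense m x n column-major scratch tile. */
static void strsm_small_gemm_nxm_ref
     (
       const float *a01, const float *b10,
       float *acc,               /* m x n accumulator, ld = m */
       dim_t m, dim_t n, dim_t k_iter,
       dim_t p_lda, dim_t cs_b
     )
{
    for (dim_t k = 0; k < k_iter; k++)
    {
        for (dim_t j = 0; j < n; j++)
        {
            /* A01[k][j]: the value each macro broadcasts into a ymm register */
            float a_kj = a01[k + j * p_lda];
            for (dim_t i = 0; i < m; i++)
                acc[i + j * m] += b10[i + k * cs_b] * a_kj;  /* the fmadd step */
        }
    }
}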
_mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0 + 1));\ + \ + ymm8 = _mm256_fmadd_pd(ymm0, ymm1, ymm8);\ + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4);\ + \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1));\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1 + 1));\ + \ + ymm9 = _mm256_fmadd_pd(ymm0, ymm1, ymm9);\ + ymm5 = _mm256_fmadd_pd(ymm0, ymm2, ymm5);\ + \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 2));\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 2 + 1));\ + \ + ymm10 = _mm256_fmadd_pd(ymm0, ymm1, ymm10);\ + ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6);\ + \ + tptr += 2; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_pd((double const *)(a10));\ + \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0));\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0 + 1));\ + \ + ymm8 = _mm256_fmadd_pd(ymm0, ymm1, ymm8);\ + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4);\ + \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1));\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1 + 1));\ + \ + ymm9 = _mm256_fmadd_pd(ymm0, ymm1, ymm9);\ + ymm5 = _mm256_fmadd_pd(ymm0, ymm2, ymm5);\ + \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 2));\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 2 + 1));\ + \ + ymm10 = _mm256_fmadd_pd(ymm0, ymm1, ymm10);\ + ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6);\ + \ + tptr += 2; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + }\ + }\ + ymm4 = _mm256_permute_pd(ymm4, 0x5);\ + ymm5 = _mm256_permute_pd(ymm5, 0x5);\ + ymm6 = _mm256_permute_pd(ymm6, 0x5);\ + ymm8 = _mm256_addsub_pd(ymm8, ymm4);\ + ymm9 = _mm256_addsub_pd(ymm9, ymm5);\ + ymm10 = _mm256_addsub_pd(ymm10, ymm6);\ +} + +#define BLIS_ZTRSM_SMALL_GEMM_2mx2n(a10,b01,cs_b,p_lda,k_iter){\ + double *tptr = (double *)b01;\ + if(conjtransa) {\ + ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_pd((double const *)(a10));\ + ymm0 = _mm256_mul_pd(ymm0, ymm18);\ + \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0));\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0 + 1));\ + \ + ymm8 = _mm256_fmadd_pd(ymm0, ymm1, ymm8);\ + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4);\ + \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1));\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1 + 1));\ + \ + ymm9 = _mm256_fmadd_pd(ymm0, ymm1, ymm9);\ + ymm5 = _mm256_fmadd_pd(ymm0, ymm2, ymm5);\ + \ + tptr += 2; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_pd((double const *)(a10));\ + \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0));\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0 + 1));\ + \ + ymm8 = _mm256_fmadd_pd(ymm0, ymm1, ymm8);\ + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4);\ + \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1));\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1 + 1));\ + \ + ymm9 = _mm256_fmadd_pd(ymm0, ymm1, ymm9);\ + ymm5 = _mm256_fmadd_pd(ymm0, ymm2, 
ymm5);\ + \ + tptr += 2; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + }\ + }\ + ymm4 = _mm256_permute_pd(ymm4, 0x5);\ + ymm5 = _mm256_permute_pd(ymm5, 0x5);\ + ymm8 = _mm256_addsub_pd(ymm8, ymm4);\ + ymm9 = _mm256_addsub_pd(ymm9, ymm5);\ +} + +#define BLIS_ZTRSM_SMALL_GEMM_2mx1n(a10,b01,cs_b,p_lda,k_iter){\ + double *tptr = (double *)b01;\ + if(conjtransa) {\ + ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_pd((double const *)(a10));\ + ymm0 = _mm256_mul_pd(ymm0, ymm18);\ + \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0));\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0 + 1));\ + \ + ymm8 = _mm256_fmadd_pd(ymm0, ymm1, ymm8);\ + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4);\ + \ + tptr += 2; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_pd((double const *)(a10));\ + \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0));\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0 + 1));\ + \ + ymm8 = _mm256_fmadd_pd(ymm0, ymm1, ymm8);\ + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4);\ + \ + tptr += 2; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + }\ + }\ + ymm4 = _mm256_permute_pd(ymm4, 0x5);\ + ymm8 = _mm256_addsub_pd(ymm8, ymm4);\ +} + +#define BLIS_ZTRSM_SMALL_GEMM_1mx3n(a10,b01,cs_b,p_lda,k_iter){\ + double *tptr = (double *)b01;\ + if(conjtransa) {\ + ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + xmm4 = _mm_loadu_pd((double const *)(a10));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm4, 0); \ + ymm0 = _mm256_mul_pd(ymm0, ymm18);\ + \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0));\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0 + 1));\ + \ + ymm8 = _mm256_fmadd_pd(ymm0, ymm1, ymm8);\ + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4);\ + \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1));\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1 + 1));\ + \ + ymm9 = _mm256_fmadd_pd(ymm0, ymm1, ymm9);\ + ymm5 = _mm256_fmadd_pd(ymm0, ymm2, ymm5);\ + \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 2));\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 2 + 1));\ + \ + ymm10 = _mm256_fmadd_pd(ymm0, ymm1, ymm10);\ + ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6);\ + \ + tptr += 2; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + xmm4 = _mm_loadu_pd((double const *)(a10));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm4, 0); \ + \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0));\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0 + 1));\ + \ + ymm8 = _mm256_fmadd_pd(ymm0, ymm1, ymm8);\ + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4);\ + \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1));\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1 + 1));\ + \ + ymm9 = _mm256_fmadd_pd(ymm0, ymm1, ymm9);\ + ymm5 = _mm256_fmadd_pd(ymm0, ymm2, ymm5);\ + \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 2));\ + ymm2 = 
_mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 2 + 1));\ + \ + ymm10 = _mm256_fmadd_pd(ymm0, ymm1, ymm10);\ + ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6);\ + \ + tptr += 2; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + }\ + }\ + ymm4 = _mm256_permute_pd(ymm4, 0x5);\ + ymm5 = _mm256_permute_pd(ymm5, 0x5);\ + ymm6 = _mm256_permute_pd(ymm6, 0x5);\ + ymm8 = _mm256_addsub_pd(ymm8, ymm4);\ + ymm9 = _mm256_addsub_pd(ymm9, ymm5);\ + ymm10 = _mm256_addsub_pd(ymm10, ymm6);\ +} + +#define BLIS_ZTRSM_SMALL_GEMM_1mx2n(a10,b01,cs_b,p_lda,k_iter){\ + double *tptr = (double *)b01;\ + if(conjtransa) {\ + ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + xmm4 = _mm_loadu_pd((double const *)(a10));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm4, 0); \ + ymm0 = _mm256_mul_pd(ymm0, ymm18);\ + \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0));\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0 + 1));\ + \ + ymm8 = _mm256_fmadd_pd(ymm0, ymm1, ymm8);\ + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4);\ + \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1));\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1 + 1));\ + \ + ymm9 = _mm256_fmadd_pd(ymm0, ymm1, ymm9);\ + ymm5 = _mm256_fmadd_pd(ymm0, ymm2, ymm5);\ + \ + tptr += 2; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + xmm4 = _mm_loadu_pd((double const *)(a10));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm4, 0); \ + \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0));\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0 + 1));\ + \ + ymm8 = _mm256_fmadd_pd(ymm0, ymm1, ymm8);\ + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4);\ + \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1));\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1 + 1));\ + \ + ymm9 = _mm256_fmadd_pd(ymm0, ymm1, ymm9);\ + ymm5 = _mm256_fmadd_pd(ymm0, ymm2, ymm5);\ + \ + tptr += 2; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + }\ + }\ + ymm4 = _mm256_permute_pd(ymm4, 0x5);\ + ymm5 = _mm256_permute_pd(ymm5, 0x5);\ + ymm8 = _mm256_addsub_pd(ymm8, ymm4);\ + ymm9 = _mm256_addsub_pd(ymm9, ymm5);\ +} + +#define BLIS_ZTRSM_SMALL_GEMM_1mx1n(a10,b01,cs_b,p_lda,k_iter){\ + double *tptr = (double *)b01;\ + if(conjtransa) {\ + ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + xmm4 = _mm_loadu_pd((double const *)(a10));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm4, 0); \ ymm0 = _mm256_mul_pd(ymm0, ymm18);\ \ ymm1 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0));\ @@ -5013,18 +6363,6 @@ BLIS_INLINE err_t ztrsm_AuXB_ref ymm8 = _mm256_fmadd_pd(ymm0, ymm1, ymm8);\ ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4);\ \ - ymm1 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1));\ - ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1 + 1));\ - \ - ymm9 = _mm256_fmadd_pd(ymm0, ymm1, ymm9);\ - ymm5 = _mm256_fmadd_pd(ymm0, ymm2, ymm5);\ - \ - ymm1 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 2));\ - ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 2 + 1));\ - \ - ymm10 = _mm256_fmadd_pd(ymm0, ymm1, ymm10);\ - ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6);\ - \ tptr += 2; 
/*move to next row of B*/\ a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ }\ @@ -5032,7 +6370,8 @@ BLIS_INLINE err_t ztrsm_AuXB_ref else {\ for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ {\ - ymm0 = _mm256_loadu_pd((double const *)(a10));\ + xmm4 = _mm_loadu_pd((double const *)(a10));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm4, 0); \ \ ymm1 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0));\ ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0 + 1));\ @@ -5040,28 +6379,12 @@ BLIS_INLINE err_t ztrsm_AuXB_ref ymm8 = _mm256_fmadd_pd(ymm0, ymm1, ymm8);\ ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4);\ \ - ymm1 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1));\ - ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1 + 1));\ - \ - ymm9 = _mm256_fmadd_pd(ymm0, ymm1, ymm9);\ - ymm5 = _mm256_fmadd_pd(ymm0, ymm2, ymm5);\ - \ - ymm1 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 2));\ - ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 2 + 1));\ - \ - ymm10 = _mm256_fmadd_pd(ymm0, ymm1, ymm10);\ - ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6);\ - \ tptr += 2; /*move to next row of B*/\ a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ }\ }\ - ymm4 = _mm256_permute_pd(ymm4, 0x5);\ - ymm5 = _mm256_permute_pd(ymm5, 0x5);\ - ymm6 = _mm256_permute_pd(ymm6, 0x5);\ + ymm4 = _mm256_permute_pd(ymm4, 0x5);\ ymm8 = _mm256_addsub_pd(ymm8, ymm4);\ - ymm9 = _mm256_addsub_pd(ymm9, ymm5);\ - ymm10 = _mm256_addsub_pd(ymm10, ymm6);\ } /** @@ -5493,6 +6816,136 @@ BLIS_INLINE err_t ztrsm_AuXB_ref ymm6 = _mm256_addsub_pd(ymm6, ymm11);\ } +#define BLIS_ZTRSM_SMALL_GEMM_2nx3m(a01,b10,cs_b,p_lda,k_iter) {\ + double *tptr = (double *)a01;\ + if(conjtransa) {\ + ymm18 = _mm256_set_pd(-1.0, -1.0, -1.0, -1.0);\ + for(k = 0; k< k_iter; k++) \ + { \ + ymm0 = _mm256_loadu_pd((double const *)(b10)); \ + xmm5 = _mm_loadu_pd((double const *)(b10 + 2));\ + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); \ + \ + _mm_prefetch((char*)( b10 + 4*cs_b), _MM_HINT_T0); \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0)); \ + ymm12 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1)); \ + ymm12 = _mm256_mul_pd(ymm12, ymm18);\ + \ + ymm3 = _mm256_fmadd_pd(ymm0, ymm2, ymm3);\ + ymm4 = _mm256_fmadd_pd(ymm1, ymm2, ymm4);\ + ymm8 = _mm256_fmadd_pd(ymm0, ymm12, ymm8);\ + ymm9 = _mm256_fmadd_pd(ymm1, ymm12, ymm9);\ + \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1)); \ + ymm12 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1 + 1)); \ + ymm12 = _mm256_mul_pd(ymm12, ymm18);\ + \ + ymm5 = _mm256_fmadd_pd(ymm0, ymm2, ymm5);\ + ymm6 = _mm256_fmadd_pd(ymm1, ymm2, ymm6);\ + ymm10 = _mm256_fmadd_pd(ymm0, ymm12, ymm10);\ + ymm11 = _mm256_fmadd_pd(ymm1, ymm12, ymm11);\ + \ + tptr += 2; \ + b10 += cs_b; \ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) \ + { \ + ymm0 = _mm256_loadu_pd((double const *)(b10)); \ + xmm5 = _mm_loadu_pd((double const *)(b10 + 2));\ + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); \ + \ + _mm_prefetch((char*)( b10 + 4*cs_b), _MM_HINT_T0); \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0)); \ + ymm12 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1)); \ + \ + ymm3 = _mm256_fmadd_pd(ymm0, ymm2, ymm3);\ + ymm4 = _mm256_fmadd_pd(ymm1, ymm2, ymm4);\ + ymm8 = _mm256_fmadd_pd(ymm0, ymm12, ymm8);\ + ymm9 = _mm256_fmadd_pd(ymm1, ymm12, ymm9);\ + \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1)); \ + ymm12 = 
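Taken together, the dcomplex GEMM helper macros in this block (the <m>mx<n>n left-side variants and the <n>nx<m>m right-side variants) build up a complex matrix product: one operand is loaded as packed complex pairs, the real and imaginary parts of the other operand are broadcast separately and accumulated with fmadd into two register sets, and the trailing permute/addsub pair folds those halves into the complex result; conjtransa simply flips the sign of the imaginary parts of the transposed operand. A hedged plain-C sketch of the left-side case follows — the function name, the <complex.h> usage, and the dense accumulator layout are assumptions, not code from this patch; the right-side macros mirror it with the roles of the packed and broadcast operands swapped.

#include <complex.h>

/* Illustrative scalar equivalent of the <m>mx<n>n dcomplex GEMM macros
 * (assumed helper, not part of this patch): acc += op(A10) * B01. */
static void ztrsm_small_gemm_mxn_ref
     (
       const double complex *a10, const double complex *b01,
       double complex *acc,       /* m x n accumulator, column-major, ld = m */
       dim_t m, dim_t n, dim_t k_iter,
       dim_t p_lda, dim_t cs_b,
       bool conjtransa
     )
{
    for (dim_t k = 0; k < k_iter; k++)
    {
        for (dim_t i = 0; i < m; i++)
        {
            double complex a = a10[i + k * p_lda];  /* the packed ymm load of A10 */
            if (conjtransa) a = conj(a);            /* the ymm18 = {1,-1,1,-1} path */
            for (dim_t j = 0; j < n; j++)
                acc[i + j * m] += a * b01[k + j * cs_b];  /* fmadd + permute/addsub */
        }
    }
}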
_mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1 + 1)); \ + \ + ymm5 = _mm256_fmadd_pd(ymm0, ymm2, ymm5);\ + ymm6 = _mm256_fmadd_pd(ymm1, ymm2, ymm6);\ + ymm10 = _mm256_fmadd_pd(ymm0, ymm12, ymm10);\ + ymm11 = _mm256_fmadd_pd(ymm1, ymm12, ymm11);\ + \ + tptr += 2; \ + b10 += cs_b; \ + }\ + }\ + ymm8 = _mm256_permute_pd(ymm8, 0x5);\ + ymm9 = _mm256_permute_pd(ymm9, 0x5);\ + ymm10 = _mm256_permute_pd(ymm10, 0x5);\ + ymm11 = _mm256_permute_pd(ymm11, 0x5);\ + ymm3 = _mm256_addsub_pd(ymm3, ymm8);\ + ymm4 = _mm256_addsub_pd(ymm4, ymm9);\ + ymm5 = _mm256_addsub_pd(ymm5, ymm10);\ + ymm6 = _mm256_addsub_pd(ymm6, ymm11);\ +} + +#define BLIS_ZTRSM_SMALL_GEMM_2nx1m(a01,b10,cs_b,p_lda,k_iter){\ + double *tptr = (double *)a01;\ + if(conjtransa) {\ + ymm18 = _mm256_set_pd(-1.0, -1.0, -1.0, -1.0);\ + for(k = 0; k< k_iter; k++) \ + { \ + xmm5 = _mm_loadu_pd((double const *)(b10));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); \ + \ + _mm_prefetch((char*)( b10 + 2*cs_b), _MM_HINT_T0); \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0)); \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1)); \ + ymm2 = _mm256_mul_pd(ymm2, ymm18);\ + \ + ymm3 = _mm256_fmadd_pd(ymm0, ymm1, ymm3);\ + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4);\ + \ + \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1)); \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1 + 1)); \ + ymm2 = _mm256_mul_pd(ymm2, ymm18);\ + \ + ymm5 = _mm256_fmadd_pd(ymm0, ymm1, ymm5);\ + ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6);\ + \ + tptr += 2; \ + b10 += cs_b; \ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) \ + { \ + xmm5 = _mm_loadu_pd((double const *)(b10));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); \ + \ + _mm_prefetch((char*)( b10 + 2*cs_b), _MM_HINT_T0); \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0)); \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1)); \ + \ + ymm3 = _mm256_fmadd_pd(ymm0, ymm1, ymm3);\ + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4);\ + \ + \ + ymm1 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1)); \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1 + 1)); \ + \ + ymm5 = _mm256_fmadd_pd(ymm0, ymm1, ymm5);\ + ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6);\ + \ + tptr += 2; \ + b10 += cs_b; \ + }\ + }\ + ymm4 = _mm256_permute_pd(ymm4, 0x5);\ + ymm6 = _mm256_permute_pd(ymm6, 0x5);\ + ymm3 = _mm256_addsub_pd(ymm3, ymm4);\ + ymm5 = _mm256_addsub_pd(ymm5, ymm6);\ +} + /** * Performs GEMM operation * ymm0 holds 2 elements of a column. @@ -5625,59 +7078,62 @@ BLIS_INLINE err_t ztrsm_AuXB_ref * 3 elements of a columns get held by ymm0(2 element) * and xmm5 (1 element). 
*/ -#define BLIS_ZTRSM_SMALL_GEMM_1nx3m(a01,b10,cs_b,p_lda,k_iter) {\ + #define BLIS_ZTRSM_SMALL_GEMM_1nx3m(a01,b10,cs_b,p_lda,k_iter) {\ double *tptr = (double *)a01;\ if(conjtransa) {\ ymm18 = _mm256_set_pd(-1.0, -1.0, -1.0, -1.0);\ - for(k = 0; k< k_iter; k++) \ + for(k = 0; k < k_iter; k++)\ {\ - ymm0 = _mm256_loadu_pd((double const *)(b10)); \ - /*ymm1 = _mm256_loadu_pd((double const *)(b10 + 2));*/\ + ymm0 = _mm256_loadu_pd((double const *)b10);\ xmm5 = _mm_loadu_pd((double const *)(b10 + 2));\ - ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0);\ + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); \ \ _mm_prefetch((char*)( b10 + 4*cs_b), _MM_HINT_T0); \ - ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0)); \ - ymm5 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1)); \ - ymm5 = _mm256_mul_pd(ymm5, ymm18);\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0));\ + ymm7 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1));\ + ymm7 = _mm256_mul_pd(ymm7, ymm18);\ + /*dcomplex multiplication and substraction*/\ \ ymm3 = _mm256_fmadd_pd(ymm0, ymm2, ymm3);\ - ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6);\ - ymm4 = _mm256_fmadd_pd(ymm1, ymm5, ymm4);\ - ymm7 = _mm256_fmadd_pd(ymm1, ymm5, ymm7);\ + ymm4 = _mm256_fmadd_pd(ymm1, ymm2, ymm4);\ + ymm5 = _mm256_fmadd_pd(ymm0, ymm7, ymm5);\ + ymm6 = _mm256_fmadd_pd(ymm1, ymm7, ymm6);\ + /*dcomplex multiplication and substraction*/\ \ tptr += 2;\ b10 += cs_b;\ }\ }\ else {\ - for(k = 0; k< k_iter; k++) \ + for(k = 0; k < k_iter; k++)\ {\ - ymm0 = _mm256_loadu_pd((double const *)(b10)); \ - /*ymm1 = _mm256_loadu_pd((double const *)(b10 + 2));*/\ + ymm0 = _mm256_loadu_pd((double const *)b10);\ xmm5 = _mm_loadu_pd((double const *)(b10 + 2));\ - ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0);\ + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); \ \ _mm_prefetch((char*)( b10 + 4*cs_b), _MM_HINT_T0); \ - ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0)); \ - ymm5 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1)); \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0));\ + ymm7 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1));\ + /*dcomplex multiplication and substraction*/\ \ ymm3 = _mm256_fmadd_pd(ymm0, ymm2, ymm3);\ - ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6);\ - ymm4 = _mm256_fmadd_pd(ymm1, ymm5, ymm4);\ - ymm7 = _mm256_fmadd_pd(ymm1, ymm5, ymm7);\ + ymm4 = _mm256_fmadd_pd(ymm1, ymm2, ymm4);\ + ymm5 = _mm256_fmadd_pd(ymm0, ymm7, ymm5);\ + ymm6 = _mm256_fmadd_pd(ymm1, ymm7, ymm6);\ + /*ymm3 = _mm256_add_pd(ymm15, ymm3);*/\ + /*dcomplex multiplication and substraction*/\ \ tptr += 2;\ b10 += cs_b;\ }\ }\ + ymm5 = _mm256_permute_pd(ymm5, 0x5);\ ymm6 = _mm256_permute_pd(ymm6, 0x5);\ - ymm7 = _mm256_permute_pd(ymm7, 0x5);\ - ymm3 = _mm256_addsub_pd(ymm3, ymm6);\ - ymm4 = _mm256_addsub_pd(ymm5, ymm7);\ +\ + ymm3 = _mm256_addsub_pd(ymm3, ymm5);\ + ymm4 = _mm256_addsub_pd(ymm4, ymm6);\ } - /** * Performs GEMM operation. * 1 elements of a column are kept in ymm0. 
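The reworked BLIS_ZTRSM_SMALL_GEMM_1nx3m above now pairs each real-part accumulator with its own imaginary-part accumulator, so the closing permute(..., 0x5)/addsub sequence reconstructs a correct complex product: the permute swaps the real/imaginary slots within each 128-bit lane, and addsub then forms Re(a*b) in the even slots and Im(a*b) in the odd slots. The small standalone demonstration below shows only that idiom; it is a self-contained example, not code taken from this file (build with something like cc -mavx).

#include <immintrin.h>
#include <stdio.h>

int main(void)
{
    /* a = 1 + 2i, b = 3 + 4i; expected product: -5 + 10i */
    double ar = 1.0, ai = 2.0, br = 3.0, bi = 4.0;

    __m256d a   = _mm256_setr_pd(ar, ai, ar, ai);  /* packed complex pairs */
    __m256d vbr = _mm256_set1_pd(br);              /* broadcast real part of b */
    __m256d vbi = _mm256_set1_pd(bi);              /* broadcast imag part of b */

    __m256d re = _mm256_mul_pd(a, vbr);            /* (ar*br, ai*br, ...) */
    __m256d im = _mm256_mul_pd(a, vbi);            /* (ar*bi, ai*bi, ...) */
    im = _mm256_permute_pd(im, 0x5);               /* (ai*bi, ar*bi, ...)  */
    __m256d prod = _mm256_addsub_pd(re, im);       /* (ar*br - ai*bi, ai*br + ar*bi, ...) */

    double out[4];
    _mm256_storeu_pd(out, prod);
    printf("(%f, %f)\n", out[0], out[1]);          /* prints (-5.000000, 10.000000) */
    return 0;
}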
@@ -5721,7 +7177,7 @@ BLIS_INLINE err_t ztrsm_AuXB_ref }\ }\ ymm4 = _mm256_permute_pd(ymm4, 0x5);\ - ymm3 = _mm256_addsub_pd(ymm3, ymm4);\ + ymm3 = _mm256_addsub_pd(ymm3, ymm4);\ } @@ -5766,7 +7222,7 @@ BLIS_INLINE err_t ztrsm_AuXB_ref }\ }\ ymm4 = _mm256_permute_pd(ymm4, 0x5);\ - ymm3 = _mm256_addsub_pd(ymm3, ymm4);\ + ymm3 = _mm256_addsub_pd(ymm3, ymm4);\ } /** @@ -5864,6 +7320,184 @@ BLIS_INLINE err_t ztrsm_AuXB_ref ymm8 = _mm256_addsub_pd(ymm8, ymm15);\ } +#define BLIS_ZTRSM_SMALL_GEMM_3nx3m(a01,b10,cs_b,p_lda,k_iter) {\ + double *tptr = (double *)a01;\ + if(conjtransa) {\ + ymm18 = _mm256_set_pd(-1.0, -1.0, -1.0, -1.0);\ + for(k = 0; k< k_iter; k++) \ + { \ + ymm0 = _mm256_loadu_pd((double const *)(b10)); \ + xmm5 = _mm_loadu_pd((double const *)(b10 + 2));\ + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); \ + \ + _mm_prefetch((char*)( b10 + 4*cs_b), _MM_HINT_T0); \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0));\ + ymm9 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1));\ + ymm9 = _mm256_mul_pd(ymm9, ymm18);\ + \ + ymm3 = _mm256_fmadd_pd(ymm0, ymm2, ymm3);\ + ymm4 = _mm256_fmadd_pd(ymm1, ymm2, ymm4);\ + ymm10 = _mm256_fmadd_pd(ymm0, ymm9, ymm10);\ + ymm11 = _mm256_fmadd_pd(ymm1, ymm9, ymm11);\ + \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1)); \ + ymm9 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1 + 1)); \ + ymm9 = _mm256_mul_pd(ymm9, ymm18);\ + \ + ymm5 = _mm256_fmadd_pd(ymm0, ymm2, ymm5);\ + ymm6 = _mm256_fmadd_pd(ymm1, ymm2, ymm6);\ + ymm12 = _mm256_fmadd_pd(ymm0, ymm9, ymm12);\ + ymm13 = _mm256_fmadd_pd(ymm1, ymm9, ymm13);\ + \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 2)); \ + ymm9 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 2 + 1)); \ + ymm9 = _mm256_mul_pd(ymm9, ymm18);\ + \ + ymm7 = _mm256_fmadd_pd(ymm0, ymm2, ymm7);\ + ymm8 = _mm256_fmadd_pd(ymm1, ymm2, ymm8);\ + ymm14 = _mm256_fmadd_pd(ymm0, ymm9, ymm14);\ + ymm15 = _mm256_fmadd_pd(ymm1, ymm9, ymm15);\ + \ + tptr += 2; \ + b10 += cs_b; \ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) \ + { \ + ymm0 = _mm256_loadu_pd((double const *)(b10)); \ + xmm5 = _mm_loadu_pd((double const *)(b10 + 2));\ + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); \ + \ + _mm_prefetch((char*)( b10 + 4*cs_b), _MM_HINT_T0); \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0));\ + ymm9 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1));\ + \ + ymm3 = _mm256_fmadd_pd(ymm0, ymm2, ymm3);\ + ymm4 = _mm256_fmadd_pd(ymm1, ymm2, ymm4);\ + ymm10 = _mm256_fmadd_pd(ymm0, ymm9, ymm10);\ + ymm11 = _mm256_fmadd_pd(ymm1, ymm9, ymm11);\ + \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1)); \ + ymm9 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1 + 1)); \ + \ + ymm5 = _mm256_fmadd_pd(ymm0, ymm2, ymm5);\ + ymm6 = _mm256_fmadd_pd(ymm1, ymm2, ymm6);\ + ymm12 = _mm256_fmadd_pd(ymm0, ymm9, ymm12);\ + ymm13 = _mm256_fmadd_pd(ymm1, ymm9, ymm13);\ + \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 2)); \ + ymm9 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 2 + 1)); \ + \ + ymm7 = _mm256_fmadd_pd(ymm0, ymm2, ymm7);\ + ymm8 = _mm256_fmadd_pd(ymm1, ymm2, ymm8);\ + ymm14 = _mm256_fmadd_pd(ymm0, ymm9, ymm14);\ + ymm15 = _mm256_fmadd_pd(ymm1, ymm9, ymm15);\ + \ + tptr += 2; \ + b10 += cs_b; \ + }\ + }\ + ymm10 = _mm256_permute_pd(ymm10, 0x5);\ + ymm11 = _mm256_permute_pd(ymm11, 0x5);\ + ymm12 = _mm256_permute_pd(ymm12, 0x5);\ + ymm13 = _mm256_permute_pd(ymm13, 0x5);\ + ymm14 = 
_mm256_permute_pd(ymm14, 0x5);\ + ymm15 = _mm256_permute_pd(ymm15, 0x5);\ +\ + ymm3 = _mm256_addsub_pd(ymm3, ymm10);\ + ymm4 = _mm256_addsub_pd(ymm4, ymm11);\ + ymm5 = _mm256_addsub_pd(ymm5, ymm12);\ + ymm6 = _mm256_addsub_pd(ymm6, ymm13);\ + ymm7 = _mm256_addsub_pd(ymm7, ymm14);\ + ymm8 = _mm256_addsub_pd(ymm8, ymm15);\ +} + +#define BLIS_ZTRSM_SMALL_GEMM_3nx1m(a01,b10,cs_b,p_lda,k_iter) {\ + double *tptr = (double *)a01;\ + if(conjtransa) {\ + ymm18 = _mm256_set_pd(-1.0, -1.0, -1.0, -1.0);\ + for(k = 0; k< k_iter; k++) \ + {\ + xmm5 = _mm_loadu_pd((double const *)(b10));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); \ + \ + _mm_prefetch((char*)( b10 + 2*cs_b), _MM_HINT_T0); \ + ymm4 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0)); \ + ymm6 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1)); \ + ymm6 = _mm256_mul_pd(ymm6, ymm18);\ + /*dcomplex multiplication and substraction*/\ + \ + ymm3 = _mm256_fmadd_pd(ymm0, ymm4, ymm3);\ + ymm8 = _mm256_fmadd_pd(ymm0, ymm6, ymm8);\ + \ + ymm4 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1)); \ + ymm6 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1 + 1)); \ + ymm6 = _mm256_mul_pd(ymm6, ymm18);\ + \ + /*dcomplex multiplication and substraction*/\ + \ + ymm5 = _mm256_fmadd_pd(ymm0, ymm4, ymm5);\ + ymm9 = _mm256_fmadd_pd(ymm0, ymm6, ymm9);\ + \ + ymm4 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 2)); \ + ymm6 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 2 + 1)); \ + ymm6 = _mm256_mul_pd(ymm6, ymm18);\ + \ + /*dcomplex multiplication and substraction*/\ + \ + ymm7 = _mm256_fmadd_pd(ymm0, ymm4, ymm7);\ + ymm10 = _mm256_fmadd_pd(ymm0, ymm6, ymm10);\ + \ + tptr += 2; \ + b10 += cs_b; \ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) \ + {\ + xmm5 = _mm_loadu_pd((double const *)(b10));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); \ + \ + _mm_prefetch((char*)( b10 + 2*cs_b), _MM_HINT_T0); \ + ymm4 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0)); \ + ymm6 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 0 + 1)); \ + /*dcomplex multiplication and substraction*/\ + \ + ymm3 = _mm256_fmadd_pd(ymm0, ymm4, ymm3);\ + ymm8 = _mm256_fmadd_pd(ymm0, ymm6, ymm8);\ + /*ymm3 = _mm256_add_pd(ymm15, ymm3);*/\ + \ + ymm4 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1)); \ + ymm6 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 1 + 1)); \ + \ + /*dcomplex multiplication and substraction*/\ + \ + ymm5 = _mm256_fmadd_pd(ymm0, ymm4, ymm5);\ + ymm9 = _mm256_fmadd_pd(ymm0, ymm6, ymm9);\ + /*ymm5 = _mm256_add_pd(ymm15, ymm5);*/\ + \ + ymm4 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 2)); \ + ymm6 = _mm256_broadcast_sd((double const *)(tptr + p_lda * 2 * 2 + 1)); \ + \ + /*dcomplex multiplication and substraction*/\ + \ + ymm7 = _mm256_fmadd_pd(ymm0, ymm4, ymm7);\ + ymm10 = _mm256_fmadd_pd(ymm0, ymm6, ymm10);\ + /*ymm7 = _mm256_add_pd(ymm15, ymm7);*/\ + \ + tptr += 2; \ + b10 += cs_b; \ + }\ + }\ + ymm8 = _mm256_permute_pd(ymm8, 0x5);\ + ymm9 = _mm256_permute_pd(ymm9, 0x5);\ + ymm10 = _mm256_permute_pd(ymm10, 0x5);\ + ymm3 = _mm256_addsub_pd(ymm3, ymm8);\ + ymm5 = _mm256_addsub_pd(ymm5, ymm9);\ + ymm7 = _mm256_addsub_pd(ymm7, ymm10);\ +} + /** * Multiplies Alpha with 4 element of 2 columns. * ymm0 and ymm1 holds 4 elements of a column. 
@@ -5907,6 +7541,72 @@ BLIS_INLINE err_t ztrsm_AuXB_ref ymm6 = _mm256_sub_pd(ymm15,ymm6);\ } +#define BLIS_PRE_ZTRSM_SMALL_2x1(AlphaVal,b11,cs_b){\ + ymm16 = _mm256_broadcast_pd(( __m128d const*)(&AlphaVal));\ + \ + xmm5 = _mm_loadu_pd((double const *)(b11));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); \ + ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ + \ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm0, ymm16);\ + ymm14 = _mm256_mul_pd(ymm0, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm3 = _mm256_sub_pd(ymm15,ymm3);\ + \ + xmm5 = _mm_loadu_pd((double const *)(b11 +cs_b));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); \ +\ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm0, ymm16);\ + ymm14 = _mm256_mul_pd(ymm0, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm5 = _mm256_sub_pd(ymm15,ymm5);\ +} + +#define BLIS_PRE_ZTRSM_SMALL_2x3(AlphaVal,b11,cs_b) {\ + ymm16 = _mm256_broadcast_pd(( __m128d const*)(&AlphaVal));\ + \ + ymm0 = _mm256_loadu_pd((double const *)(b11));\ + xmm5 = _mm_loadu_pd((double const *)(b11 + 2));\ + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); \ + ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ + \ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm0, ymm16);\ + ymm14 = _mm256_mul_pd(ymm0, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm3 = _mm256_sub_pd(ymm15,ymm3);\ + \ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm1, ymm16);\ + ymm14 = _mm256_mul_pd(ymm1, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm4 = _mm256_sub_pd(ymm15,ymm4);\ + \ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *1));\ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b *1 + 2));\ + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); \ +\ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm0, ymm16);\ + ymm14 = _mm256_mul_pd(ymm0, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm5 = _mm256_sub_pd(ymm15,ymm5);\ + \ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm1, ymm16);\ + ymm14 = _mm256_mul_pd(ymm1, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm6 = _mm256_sub_pd(ymm15,ymm6);\ +} + /** * Multiplies Alpha with 4 element of 3 columns. * ymm0 and ymm1 holds 4 elements of a column. 
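The BLIS_PRE_ZTRSM_SMALL_<n>x<m> macros above perform the alpha pre-scaling for dcomplex fringe tiles: alpha is broadcast, the permute plus the {1,-1,...} sign vector form (alpha_i, -alpha_r), both are multiplied against the loaded B11 pairs, hsub then yields the real and imaginary parts of alpha*B11, and finally the GEMM accumulator is subtracted. In plain C the whole sequence reduces to the sketch below; the helper name and the dense accumulator layout are assumptions, and the real macros of course keep everything in ymm registers and load only as many elements as the fringe needs.

#include <complex.h>

/* Illustrative scalar equivalent of the dcomplex alpha pre-scaling
 * (assumed helper, not part of this patch): acc = alpha * B11 - acc. */
static void ztrsm_small_prescale_ref
     (
       double complex alpha,
       const double complex *b11,
       double complex *acc,       /* m x n accumulator, column-major, ld = m */
       dim_t m, dim_t n, dim_t cs_b
     )
{
    for (dim_t j = 0; j < n; j++)
        for (dim_t i = 0; i < m; i++)
            acc[i + j * m] = alpha * b11[i + j * cs_b] - acc[i + j * m];
}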
@@ -5968,6 +7668,102 @@ BLIS_INLINE err_t ztrsm_AuXB_ref \ } +#define BLIS_PRE_ZTRSM_SMALL_3x3(AlphaVal,b11,cs_b) {\ + ymm16 = _mm256_broadcast_pd(( __m128d const*)(&AlphaVal));\ + \ + ymm0 = _mm256_loadu_pd((double const *)(b11));\ + xmm5 = _mm_loadu_pd((double const *)(b11 + 2));\ + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); \ + ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ + \ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm0, ymm16);\ + ymm14 = _mm256_mul_pd(ymm0, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm3 = _mm256_sub_pd(ymm15,ymm3);\ + \ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm1, ymm16);\ + ymm14 = _mm256_mul_pd(ymm1, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm4 = _mm256_sub_pd(ymm15,ymm4);\ + \ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *1));\ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b *1 + 2));\ + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); \ +\ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm0, ymm16);\ + ymm14 = _mm256_mul_pd(ymm0, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm5 = _mm256_sub_pd(ymm15,ymm5);\ + \ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm1, ymm16);\ + ymm14 = _mm256_mul_pd(ymm1, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm6 = _mm256_sub_pd(ymm15,ymm6);\ + \ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *2));\ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b *2 + 2));\ + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); \ + \ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm0, ymm16);\ + ymm14 = _mm256_mul_pd(ymm0, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm7 = _mm256_sub_pd(ymm15,ymm7);\ + \ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm1, ymm16);\ + ymm14 = _mm256_mul_pd(ymm1, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm8 = _mm256_sub_pd(ymm15,ymm8);\ + \ +} + +#define BLIS_PRE_ZTRSM_SMALL_3x1(AlphaVal,b11,cs_b) {\ + ymm16 = _mm256_broadcast_pd(( __m128d const*)(&AlphaVal));\ + \ + xmm5 = _mm_loadu_pd((double const *)(b11));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); \ + ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ + \ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm0, ymm16);\ + ymm14 = _mm256_mul_pd(ymm0, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm3 = _mm256_sub_pd(ymm15,ymm3);\ + \ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b*1));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); \ +\ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm0, ymm16);\ + ymm14 = _mm256_mul_pd(ymm0, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm5 = _mm256_sub_pd(ymm15,ymm5);\ + \ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b*2));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); \ + \ + ymm14 = _mm256_permute_pd(ymm16, 0x5);\ + ymm14 = _mm256_mul_pd(ymm14, ymm18);\ + ymm17 = _mm256_mul_pd(ymm0, ymm16);\ + ymm14 = _mm256_mul_pd(ymm0, ymm14);\ + ymm15 = _mm256_hsub_pd(ymm17, ymm14);\ + ymm7 = _mm256_sub_pd(ymm15,ymm7);\ + \ +} + /* * Pack a block of 4xk or 3xk from input buffer into packed buffer * directly or after transpose based on input params @@ -6158,7 +7954,7 @@ 
BLIS_INLINE err_t ztrsm_AuXB_ref \ ymm8 = _mm256_addsub_pd(ymm8, ymm4);\ ymm12 = _mm256_addsub_pd(ymm12, ymm5);\ - ymm9 = _mm256_addsub_pd(ymm9, ymm6);\ + ymm9 = _mm256_addsub_pd(ymm9, ymm6);\ ymm13 = _mm256_addsub_pd(ymm13, ymm7);\ } @@ -7015,7 +8811,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_DTRSM_SMALL_GEMM_6nx3m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha BLIS_PRE_DTRSM_SMALL_6x3(AlphaVal,b11,cs_b) @@ -7139,7 +8935,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_DTRSM_SMALL_GEMM_6nx2m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha BLIS_PRE_DTRSM_SMALL_6x2(AlphaVal,b11,cs_b) @@ -7256,7 +9052,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_DTRSM_SMALL_GEMM_6nx1m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha BLIS_PRE_DTRSM_SMALL_6x1(AlphaVal,b11,cs_b) @@ -7674,7 +9470,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_4nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_DTRSM_SMALL_GEMM_4nx3m(a01,b10,cs_b,p_lda,k_iter) ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha @@ -7764,7 +9560,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_4nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_DTRSM_SMALL_GEMM_4nx2m(a01,b10,cs_b,p_lda,k_iter) ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha @@ -7848,7 +9644,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_4nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_DTRSM_SMALL_GEMM_4nx1m(a01,b10,cs_b,p_lda,k_iter) ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha @@ -8191,7 +9987,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_DTRSM_SMALL_GEMM_3nx3m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_DTRSM_SMALL_3N_3M(AlphaVal,b11,cs_b) @@ -8242,7 +10038,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_DTRSM_SMALL_GEMM_3nx2m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_DTRSM_SMALL_3N_2M(AlphaVal,b11,cs_b) @@ -8293,7 +10089,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_DTRSM_SMALL_GEMM_3nx1m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_DTRSM_SMALL_3N_1M(AlphaVal,b11,cs_b) @@ -8554,7 +10350,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB ymm5 = _mm256_setzero_pd(); ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_DTRSM_SMALL_GEMM_2nx3m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_DTRSM_SMALL_2N_3M(AlphaVal,b11,cs_b) @@ -8591,7 +10387,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB ymm5 = _mm256_setzero_pd(); ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) + 
BLIS_DTRSM_SMALL_GEMM_2nx2m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_DTRSM_SMALL_2N_2M(AlphaVal,b11,cs_b) @@ -8627,7 +10423,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB ymm5 = _mm256_setzero_pd(); ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_DTRSM_SMALL_GEMM_2nx1m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_DTRSM_SMALL_2N_1M(AlphaVal,b11,cs_b) @@ -8855,7 +10651,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB ymm3 = _mm256_setzero_pd(); ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_DTRSM_SMALL_GEMM_1nx3m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_DTRSM_SMALL_1N_3M(AlphaVal,b11,cs_b) @@ -8882,7 +10678,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB ymm3 = _mm256_setzero_pd(); ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_DTRSM_SMALL_GEMM_1nx2m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_DTRSM_SMALL_1N_2M(AlphaVal,b11,cs_b) @@ -8909,7 +10705,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB ymm3 = _mm256_setzero_pd(); ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_DTRSM_SMALL_GEMM_1nx1m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_DTRSM_SMALL_1N_1M(AlphaVal,b11,cs_b) @@ -9395,7 +11191,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_DTRSM_SMALL_GEMM_6nx3m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha BLIS_PRE_DTRSM_SMALL_6x3(AlphaVal,b11,cs_b) @@ -9510,7 +11306,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_DTRSM_SMALL_GEMM_6nx2m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha BLIS_PRE_DTRSM_SMALL_6x2(AlphaVal,b11,cs_b) @@ -9618,7 +11414,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_DTRSM_SMALL_GEMM_6nx1m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha BLIS_PRE_DTRSM_SMALL_6x1(AlphaVal,b11,cs_b) @@ -10030,7 +11826,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_4nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_DTRSM_SMALL_GEMM_4nx3m(a01,b10,cs_b,p_lda,k_iter) ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha @@ -10115,7 +11911,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_4nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_DTRSM_SMALL_GEMM_4nx2m(a01,b10,cs_b,p_lda,k_iter) ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha @@ -10194,7 +11990,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_4nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_DTRSM_SMALL_GEMM_4nx1m(a01,b10,cs_b,p_lda,k_iter) ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha @@ -10539,7 +12335,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_DTRSM_SMALL_GEMM_3nx3m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_DTRSM_SMALL_3N_3M(AlphaVal,b11,cs_b) @@ -10586,7 +12382,7 @@ BLIS_INLINE err_t 
bli_dtrsm_small_XAutB_XAlB BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_DTRSM_SMALL_GEMM_3nx2m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_DTRSM_SMALL_3N_2M(AlphaVal,b11,cs_b) @@ -10634,7 +12430,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_DTRSM_SMALL_GEMM_3nx1m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_DTRSM_SMALL_3N_1M(AlphaVal,b11,cs_b) @@ -10902,7 +12698,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_DTRSM_SMALL_GEMM_2nx3m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_DTRSM_SMALL_2N_3M(AlphaVal,b11,cs_b) @@ -10938,7 +12734,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_DTRSM_SMALL_GEMM_2nx2m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_DTRSM_SMALL_2N_2M(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -10973,7 +12769,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_DTRSM_SMALL_GEMM_2nx1m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_DTRSM_SMALL_2N_1M(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -11206,7 +13002,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB ymm3 = _mm256_setzero_pd(); ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_DTRSM_SMALL_GEMM_1nx3m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_DTRSM_SMALL_1N_3M(AlphaVal,b11,cs_b) @@ -11234,7 +13030,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB ymm3 = _mm256_setzero_pd(); ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_DTRSM_SMALL_GEMM_1nx2m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_DTRSM_SMALL_1N_2M(AlphaVal,b11,cs_b) @@ -11259,7 +13055,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB ymm3 = _mm256_setzero_pd(); ///GEMM implementation starts/// - BLIS_DTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_DTRSM_SMALL_GEMM_1nx1m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_DTRSM_SMALL_1N_1M(AlphaVal,b11,cs_b) @@ -12701,35 +14497,54 @@ BLIS_INLINE err_t bli_dtrsm_small_AltXB_AuXB ///GEMM code ends/// ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to store alpha value - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 0)); + ymm0 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 0 + 2)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1)); + ymm1 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 1 + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2)); + ymm2 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 2 + 2)); + ymm2 = _mm256_insertf128_pd(ymm2, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); + ymm3 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 3 + 2)); + ymm3 = _mm256_insertf128_pd(ymm3, xmm5, 0); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); ymm11 = 
_mm256_fmsub_pd(ymm3, ymm16, ymm11); - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x08); + _mm_storeu_pd((double *)(b11), _mm256_castpd256_pd128(ymm8)); + _mm_storeu_pd((double *)(b11 + cs_b * 1), _mm256_castpd256_pd128(ymm9)); + _mm_storeu_pd((double *)(b11 + cs_b * 2), _mm256_castpd256_pd128(ymm10)); + _mm_storeu_pd((double *)(b11 + cs_b * 3), _mm256_castpd256_pd128(ymm11)); - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) + _mm_storel_pd((double *)(b11 + 2), _mm256_extractf128_pd(ymm8,1)); + _mm_storel_pd((double *)(b11 + cs_b * 1 + 2), _mm256_extractf128_pd(ymm9,1)); + _mm_storel_pd((double *)(b11 + cs_b * 2 + 2), _mm256_extractf128_pd(ymm10,1)); + _mm_storel_pd((double *)(b11 + cs_b * 3 + 2), _mm256_extractf128_pd(ymm11,1)); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 4)); + ymm0 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 4 + 2)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 5)); + ymm1 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 5 + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); + ymm4 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); + ymm5 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store(B11[0-3][3]) + _mm_storeu_pd((double *)(b11 + cs_b * 4), _mm256_castpd256_pd128(ymm4)); + _mm_storeu_pd((double *)(b11 + cs_b * 5), _mm256_castpd256_pd128(ymm5)); + + _mm_storel_pd((double *)(b11 + cs_b * 4 + 2), _mm256_extractf128_pd(ymm4,1)); + _mm_storel_pd((double *)(b11 + cs_b * 5 + 2), _mm256_extractf128_pd(ymm5,1)); if(transa) dtrsm_AltXB_ref(a11, b11, m_remainder, 6, cs_a, cs_b, is_unitdiag); @@ -12757,11 +14572,20 @@ BLIS_INLINE err_t bli_dtrsm_small_AltXB_AuXB ///implement TRSM/// - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 0)); + ymm0 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 0 + 2)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1)); + ymm1 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 1 + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2)); + ymm2 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 2 + 2)); + ymm2 = _mm256_insertf128_pd(ymm2, xmm5, 0); + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); - ymm3 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3 + 2)); + ymm3 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 3 + 2)); ymm3 = _mm256_insertf128_pd(ymm3, xmm5, 0); ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); @@ -12769,17 +14593,15 @@ BLIS_INLINE err_t bli_dtrsm_small_AltXB_AuXB ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); ymm11 = 
_mm256_fmsub_pd(ymm3, ymm16, ymm11); - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x08); + _mm_storeu_pd((double *)(b11), _mm256_castpd256_pd128(ymm8)); + _mm_storeu_pd((double *)(b11 + cs_b * 1), _mm256_castpd256_pd128(ymm9)); + _mm_storeu_pd((double *)(b11 + cs_b * 2), _mm256_castpd256_pd128(ymm10)); + _mm_storeu_pd((double *)(b11 + cs_b * 3), _mm256_castpd256_pd128(ymm11)); - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - xmm5 = _mm256_castpd256_pd128(ymm3); - _mm_storeu_pd((double *)(b11 + cs_b * 3),xmm5); - _mm_storel_pd((b11 + cs_b * 3 + 2), _mm256_extractf128_pd(ymm3, 1)); + _mm_storel_pd((double *)(b11 + 2), _mm256_extractf128_pd(ymm8,1)); + _mm_storel_pd((double *)(b11 + cs_b * 1 + 2), _mm256_extractf128_pd(ymm9,1)); + _mm_storel_pd((double *)(b11 + cs_b * 2 + 2), _mm256_extractf128_pd(ymm10,1)); + _mm_storel_pd((double *)(b11 + cs_b * 3 + 2), _mm256_extractf128_pd(ymm11,1)); if(transa) dtrsm_AltXB_ref(a11, b11, m_remainder, 4, cs_a, cs_b, is_unitdiag); @@ -12897,35 +14719,40 @@ BLIS_INLINE err_t bli_dtrsm_small_AltXB_AuXB ///GEMM code ends/// ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to store alpha value - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 0)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2)); + ymm2 = _mm256_insertf128_pd(ymm2, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); + ymm3 = _mm256_insertf128_pd(ymm3, xmm5, 0); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0C); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0C); + _mm_storeu_pd((double *)(b11), _mm256_castpd256_pd128(ymm8)); + _mm_storeu_pd((double *)(b11 + cs_b * 1), _mm256_castpd256_pd128(ymm9)); + _mm_storeu_pd((double *)(b11 + cs_b * 2), _mm256_castpd256_pd128(ymm10)); + _mm_storeu_pd((double *)(b11 + cs_b * 3), _mm256_castpd256_pd128(ymm11)); - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 4)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 5)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); + ymm4 = 
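The remainder handling above replaces the earlier load/blend/256-bit-store pattern: rather than writing a full ymm and relying on a blend to carry the untouched rows back unchanged, the updated code loads exactly the m_remainder elements of each B11 column, applies alpha*B11 - acc with fmsub, and writes back only those elements through 128-bit and scalar stores, so no rows beyond the remainder are read and rewritten. The standalone sketch below mirrors the three-row case for one column; the helper name is an assumption (the kernels inline this per column), and it needs AVX plus FMA as the rest of the file does.

#include <immintrin.h>

/* Illustrative three-row remainder update for one column of B
 * (assumed helper, not part of this patch): B[0..2] = alpha*B[0..2] - acc[0..2]. */
static void dtrsm_small_store3(double *b_col, __m256d acc, __m256d valpha)
{
    __m128d lo = _mm_loadu_pd(b_col);                       /* B[0], B[1]            */
    __m256d v  = _mm256_broadcast_sd(b_col + 2);            /* B[2] in both lanes    */
    v = _mm256_insertf128_pd(v, lo, 0);                     /* (B0, B1, B2, B2)      */
    v = _mm256_fmsub_pd(v, valpha, acc);                    /* alpha*B - acc         */
    _mm_storeu_pd(b_col, _mm256_castpd256_pd128(v));        /* write back B[0], B[1] */
    _mm_storel_pd(b_col + 2, _mm256_extractf128_pd(v, 1));  /* write back B[2] only  */
}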
_mm256_fmsub_pd(ymm0, ymm16, ymm4); + ymm5 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store(B11[0-3][3]) + _mm_storeu_pd((double *)(b11 + cs_b * 4), _mm256_castpd256_pd128(ymm4)); + _mm_storeu_pd((double *)(b11 + cs_b * 5), _mm256_castpd256_pd128(ymm5)); if(transa) dtrsm_AltXB_ref(a11, b11, m_remainder, 6, cs_a, cs_b, is_unitdiag); @@ -12952,9 +14779,15 @@ BLIS_INLINE err_t bli_dtrsm_small_AltXB_AuXB ///implement TRSM/// - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 0)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2)); + ymm2 = _mm256_insertf128_pd(ymm2, xmm5, 0); + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); ymm3 = _mm256_insertf128_pd(ymm3, xmm5, 0); @@ -12963,16 +14796,10 @@ BLIS_INLINE err_t bli_dtrsm_small_AltXB_AuXB ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0C); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0C); - - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - xmm5 = _mm256_castpd256_pd128(ymm3); - _mm_storeu_pd((double *)(b11 + cs_b * 3), xmm5); + _mm_storeu_pd((double *)(b11), _mm256_castpd256_pd128(ymm8)); + _mm_storeu_pd((double *)(b11 + cs_b * 1), _mm256_castpd256_pd128(ymm9)); + _mm_storeu_pd((double *)(b11 + cs_b * 2), _mm256_castpd256_pd128(ymm10)); + _mm_storeu_pd((double *)(b11 + cs_b * 3), _mm256_castpd256_pd128(ymm11)); if(transa) dtrsm_AltXB_ref(a11, b11, m_remainder, 4, cs_a, cs_b, is_unitdiag); @@ -13089,35 +14916,30 @@ BLIS_INLINE err_t bli_dtrsm_small_AltXB_AuXB ///GEMM code ends/// ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to store alpha value - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_broadcast_sd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_broadcast_sd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_broadcast_sd((double const *)(b11 + cs_b *3)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0E); + _mm_storel_pd((double *)(b11), _mm256_extractf128_pd(ymm8,0)); + _mm_storel_pd((double *)(b11 + cs_b * 1), _mm256_extractf128_pd(ymm9,0)); + _mm_storel_pd((double *)(b11 + cs_b * 2), _mm256_extractf128_pd(ymm10,0)); + _mm_storel_pd((double *)(b11 + cs_b * 3), _mm256_extractf128_pd(ymm11,0)); - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - 
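Throughout these m_remainder cases the patch replaces the old blend-plus-full-ymm store sequence with explicit 128-bit and 64-bit accesses, so only the valid remainder elements of each B11 column are read or written. For reference, that access pattern can be summarized by a few standalone helpers; the names below (load3_pd, store3_pd, store2_pd, store1_pd) are illustrative assumptions and do not appear in the patch.

#include <immintrin.h>

/* Illustrative only: the partial B11 accesses used in the m_remainder
   cases of this region, written out as standalone helpers.  Exactly
   m_remainder doubles of a column are touched, so the unused ymm lanes
   never reach memory. */
static inline __m256d load3_pd(const double *p)
{
    __m128d lo = _mm_loadu_pd(p);             /* p[0], p[1]             */
    __m256d v  = _mm256_broadcast_sd(p + 2);  /* p[2] in all four lanes */
    return _mm256_insertf128_pd(v, lo, 0);    /* p[0], p[1], p[2], p[2] */
}

static inline void store3_pd(double *p, __m256d v)
{
    _mm_storeu_pd(p, _mm256_castpd256_pd128(v));        /* p[0], p[1] */
    _mm_storel_pd(p + 2, _mm256_extractf128_pd(v, 1));  /* p[2]       */
}

static inline void store2_pd(double *p, __m256d v)
{
    _mm_storeu_pd(p, _mm256_castpd256_pd128(v));        /* p[0], p[1] */
}

static inline void store1_pd(double *p, __m256d v)
{
    _mm_storel_pd(p, _mm256_castpd256_pd128(v));        /* p[0] */
}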
_mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b *4)); + ymm1 = _mm256_broadcast_sd((double const *)(b11 + cs_b *5)); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); + ymm4 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); + ymm5 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); - - _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store(B11[0-3][3]) + _mm_storel_pd((double *)(b11 + cs_b * 4), _mm256_extractf128_pd(ymm4,0)); + _mm_storel_pd((double *)(b11 + cs_b * 5), _mm256_extractf128_pd(ymm5,0)); if(transa) dtrsm_AltXB_ref(a11, b11, m_remainder, 6, cs_a, cs_b, is_unitdiag); @@ -13144,24 +14966,20 @@ BLIS_INLINE err_t bli_dtrsm_small_AltXB_AuXB ///implement TRSM/// - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_broadcast_sd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_broadcast_sd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_broadcast_sd((double const *)(b11 + cs_b *3)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0E); - - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) + _mm_storel_pd((double *)(b11), _mm256_extractf128_pd(ymm8,0)); + _mm_storel_pd((double *)(b11 + cs_b * 1), _mm256_extractf128_pd(ymm9,0)); + _mm_storel_pd((double *)(b11 + cs_b * 2), _mm256_extractf128_pd(ymm10,0)); + _mm_storel_pd((double *)(b11 + cs_b * 3), _mm256_extractf128_pd(ymm11,0)); if(transa) dtrsm_AltXB_ref(a11, b11, m_remainder, 4, cs_a, cs_b, is_unitdiag); @@ -13228,6 +15046,7 @@ BLIS_INLINE err_t bli_dtrsm_small_AltXB_AuXB return BLIS_SUCCESS; } + /* TRSM for the Left Upper case AX = alpha * B, Double precision * A is Left side, upper-triangular, transpose, non-unit/unit diagonal * dimensions A: mxm X: mxn B: mxn @@ -14761,238 +16580,56 @@ BLIS_INLINE err_t bli_dtrsm_small_AutXB_AlXB ///GEMM code ends/// ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to store alpha value - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - ymm11 = 
_mm256_fmsub_pd(ymm3, ymm16, ymm11); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x08); - - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); - - _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store(B11[0-3][3]) - - if(transa) - dtrsm_AutXB_ref(a11, b11, m_rem, 6, cs_a, cs_b,is_unitdiag); - else - dtrsm_AlXB_ref(a11, b11, m_rem, 6, rs_a, cs_b, is_unitdiag); - } - - dim_t n_rem = n-j; - if((n_rem >= 4)) - { - a10 = D_A_pack; //pointer to block of A to be used for GEMM - a11 = L + (i*rs_a) + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM - b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM - - k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) - - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS - - ///GEMM code begins/// - BLIS_DTRSM_SMALL_GEMM_4mx4n(a10,b01,cs_b,p_lda,k_iter) + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 0)); + ymm0 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 0 + 2)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1)); + ymm1 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 1 + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); - ///implement TRSM/// + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2)); + ymm2 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 2 + 2)); + ymm2 = _mm256_insertf128_pd(ymm2, xmm5, 0); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); - ymm3 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3 + 2)); + ymm3 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 3 + 2)); ymm3 = _mm256_insertf128_pd(ymm3, xmm5, 0); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x08); - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - xmm5 = _mm256_castpd256_pd128(ymm3); - _mm_storeu_pd((double *)(b11 + cs_b * 3),xmm5); - _mm_storel_pd((b11 + cs_b * 3 + 2), _mm256_extractf128_pd(ymm3, 1)); - - if(transa) - dtrsm_AutXB_ref(a11, b11, m_rem, 4, cs_a, cs_b,is_unitdiag); - else - dtrsm_AlXB_ref(a11, b11, 
m_rem, 4, rs_a, cs_b, is_unitdiag); - n_rem -= 4; - j +=4; - } - - if(n_rem) - { - a10 = D_A_pack; //pointer to block of A to be used for GEMM - a11 = L + (i*rs_a) + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM - b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM - - k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) - - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS - - if(3 == n_rem) - { - ///GEMM code begins/// - BLIS_DTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) - - BLIS_PRE_DTRSM_SMALL_3M_3N(AlphaVal,b11,cs_b) - - if(transa) - dtrsm_AutXB_ref(a11, b11, m_rem, 3, cs_a, cs_b,is_unitdiag); - else - dtrsm_AlXB_ref(a11, b11, m_rem, 3, rs_a, cs_b, is_unitdiag); - } - else if(2 == n_rem) - { - ///GEMM code begins/// - BLIS_DTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b,p_lda,k_iter) - - BLIS_PRE_DTRSM_SMALL_3M_2N(AlphaVal,b11,cs_b) - - if(transa) - dtrsm_AutXB_ref(a11, b11, m_rem, 2, cs_a, cs_b,is_unitdiag); - else - dtrsm_AlXB_ref(a11, b11, m_rem, 2, rs_a, cs_b, is_unitdiag); - } - else if(1 == n_rem) - { - ///GEMM code begins/// - BLIS_DTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b,p_lda,k_iter) - - BLIS_PRE_DTRSM_SMALL_3M_1N(AlphaVal,b11,cs_b) - - if(transa) - dtrsm_AutXB_ref(a11, b11, m_rem, 1, cs_a, cs_b, is_unitdiag); - else - dtrsm_AlXB_ref(a11, b11, m_rem, 1, rs_a, cs_b, is_unitdiag); - } - } - } - else if(2 == m_rem) // Repetative A blocks will be 2*2 - { - dim_t p_lda = 4; // packed leading dimension - if(transa) - { - for(dim_t x=0;x= 4)) + { + a10 = D_A_pack; //pointer to block of A to be used for GEMM + a11 = L + (i*rs_a) + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM + b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM + + k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx4n(a10,b01,cs_b,p_lda,k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + + ///implement TRSM/// + + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 0)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2)); + ymm2 = _mm256_insertf128_pd(ymm2, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); + ymm3 = _mm256_insertf128_pd(ymm3, xmm5, 0); + + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + _mm_storeu_pd((double *)(b11), _mm256_castpd256_pd128(ymm8)); + _mm_storeu_pd((double *)(b11 + cs_b * 1), _mm256_castpd256_pd128(ymm9)); + _mm_storeu_pd((double *)(b11 + cs_b * 2), _mm256_castpd256_pd128(ymm10)); + _mm_storeu_pd((double *)(b11 + cs_b * 3), _mm256_castpd256_pd128(ymm11)); if(transa) dtrsm_AutXB_ref(a11, b11, m_rem, 4, cs_a, cs_b, is_unitdiag); @@ -15158,35 +17004,29 @@ BLIS_INLINE err_t bli_dtrsm_small_AutXB_AlXB ///GEMM code ends/// ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to store alpha value - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const 
*)(b11 + cs_b *2)); - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_broadcast_sd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_broadcast_sd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_broadcast_sd((double const *)(b11 + cs_b *3)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0E); - - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) + _mm_storel_pd((double *)(b11), _mm256_extractf128_pd(ymm8,0)); + _mm_storel_pd((double *)(b11 + cs_b * 1), _mm256_extractf128_pd(ymm9,0)); + _mm_storel_pd((double *)(b11 + cs_b * 2), _mm256_extractf128_pd(ymm10,0)); + _mm_storel_pd((double *)(b11 + cs_b * 3), _mm256_extractf128_pd(ymm11,0)); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b *4)); + ymm1 = _mm256_broadcast_sd((double const *)(b11 + cs_b *5)); - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); + ymm4 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); + ymm5 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store(B11[0-3][3]) + _mm_storel_pd((double *)(b11 + cs_b * 4), _mm256_extractf128_pd(ymm4,0)); + _mm_storel_pd((double *)(b11 + cs_b * 5), _mm256_extractf128_pd(ymm5,0)); if(transa) dtrsm_AutXB_ref(a11, b11, m_rem, 6, cs_a, cs_b, is_unitdiag); @@ -15222,15 +17062,10 @@ BLIS_INLINE err_t bli_dtrsm_small_AutXB_AlXB ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0E); - - _mm_storel_pd((b11 + cs_b * 0), _mm256_castpd256_pd128(ymm0)); - _mm_storel_pd((b11 + cs_b * 1), _mm256_castpd256_pd128(ymm1)); - _mm_storel_pd((b11 + cs_b * 2), _mm256_castpd256_pd128(ymm2)); - _mm_storel_pd((b11 + cs_b * 3), _mm256_castpd256_pd128(ymm3)); + _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm8,0)); + _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm9,0)); + _mm_storel_pd((b11 + cs_b * 2), _mm256_extractf128_pd(ymm10,0)); + _mm_storel_pd((b11 + cs_b * 3), _mm256_extractf128_pd(ymm11,0)); if(transa) dtrsm_AutXB_ref(a11, b11, m_rem, 4, cs_a, cs_b, is_unitdiag); @@ -16268,7 +18103,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_6nx7m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha BLIS_PRE_STRSM_SMALL_6x7(AlphaVal,b11,cs_b) @@ -16393,7 +18228,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation 
starts/// - BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_6nx6m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha BLIS_PRE_STRSM_SMALL_6x6(AlphaVal,b11,cs_b) @@ -16507,7 +18342,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_6nx5m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha BLIS_PRE_STRSM_SMALL_6x5(AlphaVal,b11,cs_b) @@ -16621,7 +18456,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha BLIS_PRE_STRSM_SMALL_6x4(AlphaVal,b11,cs_b) @@ -16729,7 +18564,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_6nx3m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha BLIS_PRE_STRSM_SMALL_6x3(AlphaVal,b11,cs_b) @@ -16854,7 +18689,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_6nx2m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha BLIS_PRE_STRSM_SMALL_6x2(AlphaVal,b11,cs_b) @@ -16973,7 +18808,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_6nx1m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha BLIS_PRE_STRSM_SMALL_6x1(AlphaVal,b11,cs_b) @@ -17383,7 +19218,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_4nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_4nx7m(a01,b10,cs_b,p_lda,k_iter) ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); //register to hold alpha @@ -17497,7 +19332,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_4nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_4nx6m(a01,b10,cs_b,p_lda,k_iter) ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); //register to hold alpha @@ -17600,7 +19435,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_4nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_4nx5m(a01,b10,cs_b,p_lda,k_iter) ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); //register to hold alpha @@ -17691,7 +19526,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_4nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_4nx4m(a01,b10,cs_b,p_lda,k_iter) ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); //register to hold alpha @@ -17774,7 +19609,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_4nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_4nx3m(a01,b10,cs_b,p_lda,k_iter) ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); //register to hold alpha @@ -17872,7 +19707,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB 
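The width-specific BLIS_STRSM_SMALL_GEMM_<n>nx<m>m variants substituted in these hunks compute the same update as the former 8-column macro, restricted to the m-remainder rows so that no vector access runs past the edge of B. A minimal scalar sketch of the net update (GEMM accumulation followed by the alpha step) is shown here; the function and argument names are illustrative assumptions, not BLIS identifiers.

#include "blis.h"  /* assumed only for dim_t; this sketch is illustrative */

/* Net effect of BLIS_STRSM_SMALL_GEMM_<n>nx<m>m plus the alpha step:
 * b11 := alpha*b11 - b10 * a01 on an m_rem x n tile, where b10 is
 * m_rem x k_iter with column stride cs_b, a01 is k_iter x n packed with
 * leading dimension p_lda, and b11 has column stride cs_b. */
static void strsm_small_gemm_ref(const float *a01, const float *b10,
                                 float *b11, float alpha,
                                 dim_t m_rem, dim_t n, dim_t k_iter,
                                 dim_t cs_b, dim_t p_lda)
{
    for (dim_t j = 0; j < n; j++)
        for (dim_t i = 0; i < m_rem; i++)
        {
            float acc = 0.0f;
            for (dim_t k = 0; k < k_iter; k++)
                acc += b10[i + k*cs_b] * a01[k + j*p_lda];
            b11[i + j*cs_b] = alpha*b11[i + j*cs_b] - acc;
        }
}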
BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_4nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_4nx2m(a01,b10,cs_b,p_lda,k_iter) ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); //register to hold alpha @@ -17963,7 +19798,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_4nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_4nx1m(a01,b10,cs_b,p_lda,k_iter) ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); //register to hold alpha @@ -18306,7 +20141,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_3nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_3nx7m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_3N_7M(AlphaVal,b11,cs_b) @@ -18353,7 +20188,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_3nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_3nx6m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_3N_6M(AlphaVal,b11,cs_b) @@ -18400,7 +20235,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_3nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_3nx5m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_3N_5M(AlphaVal,b11,cs_b) @@ -18448,7 +20283,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_3nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_3N_4M(AlphaVal,b11,cs_b) @@ -18495,7 +20330,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_3nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_3nx3m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_3N_3M(AlphaVal,b11,cs_b) @@ -18542,7 +20377,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_3nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_3nx2m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_3N_2M(AlphaVal,b11,cs_b) @@ -18590,7 +20425,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_3nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_3nx1m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_3N_1M(AlphaVal,b11,cs_b) @@ -18852,7 +20687,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_2nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_2nx7m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_2N_7M(AlphaVal,b11,cs_b) @@ -18888,7 +20723,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_2nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_2nx6m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_2N_6M(AlphaVal,b11,cs_b) @@ -18924,7 +20759,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_2nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_2nx5m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_2N_5M(AlphaVal,b11,cs_b) @@ -18960,7 +20795,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB BLIS_SET_S_YMM_REG_ZEROS ///GEMM 
implementation starts/// - BLIS_STRSM_SMALL_GEMM_2nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_2N_4M(AlphaVal,b11,cs_b) @@ -18996,7 +20831,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_2nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_2nx3m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_2N_3M(AlphaVal,b11,cs_b) @@ -19032,7 +20867,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_2nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_2nx2m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_2N_2M(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -19067,7 +20902,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_2nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_2nx1m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_2N_1M(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -19302,7 +21137,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB ymm3 = _mm256_setzero_ps(); ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_1nx7m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_1N_7M(AlphaVal,b11,cs_b) @@ -19328,7 +21163,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB ymm3 = _mm256_setzero_ps(); ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_1nx6m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_1N_6M(AlphaVal,b11,cs_b) @@ -19379,7 +21214,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB ymm3 = _mm256_setzero_ps(); ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_1N_4M(AlphaVal,b11,cs_b) @@ -19404,7 +21239,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB ymm3 = _mm256_setzero_ps(); ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_1nx3m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_1N_3M(AlphaVal,b11,cs_b) @@ -19429,7 +21264,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB ymm3 = _mm256_setzero_ps(); ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_1nx2m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_1N_2M(AlphaVal,b11,cs_b) @@ -19454,7 +21289,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB ymm3 = _mm256_setzero_ps(); ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_1nx1m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_1N_1M(AlphaVal,b11,cs_b) @@ -19952,7 +21787,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_6nx7m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 8x6 and multiply with alpha BLIS_PRE_STRSM_SMALL_6x7(AlphaVal,b11,cs_b) @@ -20086,7 +21921,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_6nx6m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 8x6 and multiply with alpha BLIS_PRE_STRSM_SMALL_6x6(AlphaVal,b11,cs_b) @@ -20209,7 +22044,7 @@ BLIS_INLINE err_t 
bli_strsm_small_XAltB_XAuB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_6nx5m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 8x6 and multiply with alpha BLIS_PRE_STRSM_SMALL_6x5(AlphaVal,b11,cs_b) @@ -20332,7 +22167,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 8x6 and multiply with alpha BLIS_PRE_STRSM_SMALL_6x4(AlphaVal,b11,cs_b) @@ -20449,7 +22284,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_6nx3m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 8x6 and multiply with alpha BLIS_PRE_STRSM_SMALL_6x3(AlphaVal,b11,cs_b) @@ -20583,7 +22418,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_6nx2m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha BLIS_PRE_STRSM_SMALL_6x2(AlphaVal,b11,cs_b) @@ -20711,7 +22546,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_6nx1m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha BLIS_PRE_STRSM_SMALL_6x1(AlphaVal,b11,cs_b) @@ -21137,7 +22972,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_4nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_4nx7m(a01,b10,cs_b,p_lda,k_iter) ymm15 = _mm256_broadcast_ss((float const *)(&AlphaVal)); //register to hold alpha @@ -21256,7 +23091,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_4nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_4nx6m(a01,b10,cs_b,p_lda,k_iter) ymm15 = _mm256_broadcast_ss((float const *)(&AlphaVal)); //register to hold alpha @@ -21364,7 +23199,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_4nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_4nx5m(a01,b10,cs_b,p_lda,k_iter) ymm15 = _mm256_broadcast_ss((float const *)(&AlphaVal)); //register to hold alpha @@ -21460,7 +23295,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_4nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_4nx4m(a01,b10,cs_b,p_lda,k_iter) ymm15 = _mm256_broadcast_ss((float const *)(&AlphaVal)); //register to hold alpha @@ -21548,7 +23383,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_4nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_4nx3m(a01,b10,cs_b,p_lda,k_iter) ymm15 = _mm256_broadcast_ss((float const *)(&AlphaVal)); //register to hold alpha @@ -21651,7 +23486,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_4nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_4nx2m(a01,b10,cs_b,p_lda,k_iter) ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); //register to hold 
alpha @@ -21747,7 +23582,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_4nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_4nx1m(a01,b10,cs_b,p_lda,k_iter) ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); //register to hold alpha @@ -22098,7 +23933,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_3nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_3nx7m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_3N_7M(AlphaVal,b11,cs_b) @@ -22149,7 +23984,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_3nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_3nx6m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_3N_6M(AlphaVal,b11,cs_b) @@ -22200,7 +24035,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_3nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_3nx5m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_3N_5M(AlphaVal,b11,cs_b) @@ -22251,7 +24086,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_3nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_3N_4M(AlphaVal,b11,cs_b) @@ -22302,7 +24137,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_3nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_3nx3m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_3N_3M(AlphaVal,b11,cs_b) @@ -22353,7 +24188,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_3nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_3nx2m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_3N_2M(AlphaVal,b11,cs_b) @@ -22404,7 +24239,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_3nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_3nx1m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_3N_1M(AlphaVal,b11,cs_b) @@ -22669,7 +24504,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB ymm5 = _mm256_setzero_ps(); ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_2nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_2nx7m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_2N_7M(AlphaVal,b11,cs_b) @@ -22706,7 +24541,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB ymm5 = _mm256_setzero_ps(); ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_2nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_2nx6m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_2N_6M(AlphaVal,b11,cs_b) @@ -22743,7 +24578,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB ymm5 = _mm256_setzero_ps(); ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_2nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_2nx5m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_2N_5M(AlphaVal,b11,cs_b) @@ -22780,7 +24615,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB ymm5 = _mm256_setzero_ps(); ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_2nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_2N_4M(AlphaVal,b11,cs_b) @@ -22817,7 +24652,7 @@ BLIS_INLINE err_t 
bli_strsm_small_XAltB_XAuB ymm5 = _mm256_setzero_ps(); ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_2nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_2nx3m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_2N_3M(AlphaVal,b11,cs_b) @@ -22854,7 +24689,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB ymm5 = _mm256_setzero_ps(); ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_2nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_2nx2m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_2N_2M(AlphaVal,b11,cs_b) @@ -22890,7 +24725,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB ymm5 = _mm256_setzero_ps(); ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_2nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_2nx1m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_2N_1M(AlphaVal,b11,cs_b) @@ -23130,7 +24965,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB ymm3 = _mm256_setzero_ps(); ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_1nx7m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_1N_7M(AlphaVal,b11,cs_b) @@ -23156,7 +24991,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB ymm3 = _mm256_setzero_ps(); ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_1nx6m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_1N_6M(AlphaVal,b11,cs_b) @@ -23182,7 +25017,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB ymm3 = _mm256_setzero_ps(); ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_1nx5m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_1N_5M(AlphaVal,b11,cs_b) @@ -23208,7 +25043,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB ymm3 = _mm256_setzero_ps(); ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_1N_4M(AlphaVal,b11,cs_b) @@ -23234,7 +25069,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB ymm3 = _mm256_setzero_ps(); ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_1nx3m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_1N_3M(AlphaVal,b11,cs_b) @@ -23261,7 +25096,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB ymm3 = _mm256_setzero_ps(); ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_1nx2m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_1N_2M(AlphaVal,b11,cs_b) @@ -23288,7 +25123,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB ymm3 = _mm256_setzero_ps(); ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_1nx1m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_1N_1M(AlphaVal,b11,cs_b) @@ -32700,12 +34535,15 @@ BLIS_INLINE err_t bli_ztrsm_small_AutXB_AlXB _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm9); _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm10); - ymm0 = _mm256_loadu_pd((double const *) - (b11 + cs_b *0 + 2)); - ymm1 = _mm256_loadu_pd((double const *) - (b11 + cs_b *1 + 2)); - ymm2 = _mm256_loadu_pd((double const *) - (b11 + cs_b *2 + 2)); + xmm4 = _mm_loadu_pd((double const *) + (b11 + cs_b *0 + 2)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm4, 0); + xmm4 = _mm_loadu_pd((double const *) + (b11 + cs_b *1 + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm4, 0); + xmm4 = _mm_loadu_pd((double const *) + (b11 + cs_b *2 + 2)); + ymm2 = _mm256_insertf128_pd(ymm2, xmm4, 
0); ymm14 = _mm256_permute_pd(ymm16, 0x5); ymm14 = _mm256_mul_pd(ymm14, ymm18); @@ -32896,7 +34734,7 @@ BLIS_INLINE err_t bli_ztrsm_small_AutXB_AlXB if(2 == n_rem) { ///GEMM code begins/// - BLIS_ZTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b, + BLIS_ZTRSM_SMALL_GEMM_2mx2n(a10,b01,cs_b, p_lda,k_iter) BLIS_PRE_ZTRSM_SMALL_2M_2N(AlphaVal,b11,cs_b) @@ -32913,7 +34751,7 @@ BLIS_INLINE err_t bli_ztrsm_small_AutXB_AlXB else if(1 == n_rem) { ///GEMM code begins/// - BLIS_ZTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b, + BLIS_ZTRSM_SMALL_GEMM_2mx1n(a10,b01,cs_b, p_lda,k_iter) BLIS_PRE_ZTRSM_SMALL_2M_1N(AlphaVal,b11,cs_b) @@ -32976,14 +34814,17 @@ BLIS_INLINE err_t bli_ztrsm_small_AutXB_AlXB BLIS_SET_YMM_REG_ZEROS ///GEMM code begins/// - BLIS_ZTRSM_SMALL_GEMM_2mx3n(a10,b01,cs_b,p_lda,k_iter) + BLIS_ZTRSM_SMALL_GEMM_1mx3n(a10,b01,cs_b,p_lda,k_iter) ///GEMM code ends/// ymm16 = _mm256_broadcast_pd((__m128d const *) (&AlphaVal)); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + xmm4 = _mm_loadu_pd((double const *)(b11 + cs_b *0)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm4, 0); + xmm4 = _mm_loadu_pd((double const *)(b11 + cs_b *1)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm4, 0); + xmm4 = _mm_loadu_pd((double const *)(b11 + cs_b *2)); + ymm2 = _mm256_insertf128_pd(ymm2, xmm4, 0); ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); ymm14 = _mm256_permute_pd(ymm16, 0x5); @@ -33040,7 +34881,7 @@ BLIS_INLINE err_t bli_ztrsm_small_AutXB_AlXB if(2 == n_rem) { ///GEMM code begins/// - BLIS_ZTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b, + BLIS_ZTRSM_SMALL_GEMM_1mx2n(a10,b01,cs_b, p_lda,k_iter) BLIS_PRE_ZTRSM_SMALL_1M_2N(AlphaVal,b11,cs_b) @@ -33057,7 +34898,7 @@ BLIS_INLINE err_t bli_ztrsm_small_AutXB_AlXB else if(1 == n_rem) { ///GEMM code begins/// - BLIS_ZTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b, + BLIS_ZTRSM_SMALL_GEMM_1mx1n(a10,b01,cs_b, p_lda,k_iter) BLIS_PRE_ZTRSM_SMALL_1M_1N(AlphaVal,b11,cs_b) @@ -33883,7 +35724,7 @@ BLIS_INLINE err_t bli_ztrsm_small_AltXB_AuXB BLIS_SET_YMM_REG_ZEROS ///GEMM code begins/// - BLIS_ZTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) + BLIS_ZTRSM_SMALL_GEMM_3mx3n(a10,b01,cs_b,p_lda,k_iter) ///GEMM code ends/// ymm16 = _mm256_broadcast_pd((__m128d const *) (&AlphaVal)); @@ -33922,12 +35763,15 @@ BLIS_INLINE err_t bli_ztrsm_small_AltXB_AuXB _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm9); _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm10); - ymm0 = _mm256_loadu_pd((double const *) + xmm4 = _mm_loadu_pd((double const *) (b11 + cs_b *0 + 2)); - ymm1 = _mm256_loadu_pd((double const *) + ymm0 = _mm256_insertf128_pd(ymm0, xmm4, 0); + xmm4 = _mm_loadu_pd((double const *) (b11 + cs_b *1 + 2)); - ymm2 = _mm256_loadu_pd((double const *) + ymm1 = _mm256_insertf128_pd(ymm1, xmm4, 0); + xmm4 = _mm_loadu_pd((double const *) (b11 + cs_b *2 + 2)); + ymm2 = _mm256_insertf128_pd(ymm2, xmm4, 0); ymm14 = _mm256_permute_pd(ymm16, 0x5); ymm14 = _mm256_mul_pd(ymm14, ymm18); @@ -33979,7 +35823,7 @@ BLIS_INLINE err_t bli_ztrsm_small_AltXB_AuXB if(2 == n_remainder) { ///GEMM code begins/// - BLIS_ZTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b, + BLIS_ZTRSM_SMALL_GEMM_3mx2n(a10,b01,cs_b, p_lda,k_iter) BLIS_PRE_ZTRSM_SMALL_3M_2N(AlphaVal,b11,cs_b) @@ -33996,7 +35840,7 @@ BLIS_INLINE err_t bli_ztrsm_small_AltXB_AuXB else if(1 == n_remainder) { ///GEMM code begins/// - BLIS_ZTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b, + BLIS_ZTRSM_SMALL_GEMM_3mx1n(a10,b01,cs_b, p_lda,k_iter) BLIS_PRE_ZTRSM_SMALL_3M_1N(AlphaVal,b11,cs_b) @@ -34116,7 
+35960,7 @@ BLIS_INLINE err_t bli_ztrsm_small_AltXB_AuXB if(2 == n_remainder) { ///GEMM code begins/// - BLIS_ZTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b, + BLIS_ZTRSM_SMALL_GEMM_2mx2n(a10,b01,cs_b, p_lda,k_iter) BLIS_PRE_ZTRSM_SMALL_2M_2N(AlphaVal,b11,cs_b) @@ -34133,7 +35977,7 @@ BLIS_INLINE err_t bli_ztrsm_small_AltXB_AuXB else if(1 == n_remainder) { ///GEMM code begins/// - BLIS_ZTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b, + BLIS_ZTRSM_SMALL_GEMM_2mx1n(a10,b01,cs_b, p_lda,k_iter) BLIS_PRE_ZTRSM_SMALL_2M_1N(AlphaVal,b11,cs_b) @@ -34194,14 +36038,17 @@ BLIS_INLINE err_t bli_ztrsm_small_AltXB_AuXB BLIS_SET_YMM_REG_ZEROS ///GEMM code begins/// - BLIS_ZTRSM_SMALL_GEMM_2mx3n(a10,b01,cs_b,p_lda,k_iter) + BLIS_ZTRSM_SMALL_GEMM_1mx3n(a10,b01,cs_b,p_lda,k_iter) ///GEMM code ends/// ymm16 = _mm256_broadcast_pd((__m128d const *) (&AlphaVal)); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + xmm4 = _mm_loadu_pd((double const *)(b11 + cs_b *0)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm4, 0); + xmm4 = _mm_loadu_pd((double const *)(b11 + cs_b *1)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm4, 0); + xmm4 = _mm_loadu_pd((double const *)(b11 + cs_b *2)); + ymm2 = _mm256_insertf128_pd(ymm2, xmm4, 0); ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); ymm14 = _mm256_permute_pd(ymm16, 0x5); @@ -34257,7 +36104,7 @@ BLIS_INLINE err_t bli_ztrsm_small_AltXB_AuXB { ///GEMM code begins/// - BLIS_ZTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b, + BLIS_ZTRSM_SMALL_GEMM_1mx2n(a10,b01,cs_b, p_lda,k_iter) BLIS_PRE_ZTRSM_SMALL_1M_2N(AlphaVal,b11,cs_b) @@ -34274,7 +36121,7 @@ BLIS_INLINE err_t bli_ztrsm_small_AltXB_AuXB else if(1 == n_remainder) { ///GEMM code begins/// - BLIS_ZTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b, + BLIS_ZTRSM_SMALL_GEMM_1mx1n(a10,b01,cs_b, p_lda,k_iter) BLIS_PRE_ZTRSM_SMALL_1M_1N(AlphaVal,b11,cs_b) @@ -34632,7 +36479,7 @@ BLIS_INLINE err_t bli_ztrsm_small_XAutB_XAlB where k_iter are zero */ - BLIS_ZTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_ZTRSM_SMALL_GEMM_3nx3m(a01,b10,cs_b,p_lda,k_iter) /* Load b11 multiply with alpha @@ -34640,7 +36487,7 @@ BLIS_INLINE err_t bli_ztrsm_small_XAutB_XAlB and peform TRSM operation. */ - BLIS_PRE_ZTRSM_SMALL_3x4(AlphaVal,b11,cs_b) + BLIS_PRE_ZTRSM_SMALL_3x3(AlphaVal,b11,cs_b) ///implement TRSM/// /* Compute 3x3 TRSM block by using GEMM block output in @@ -34973,7 +36820,7 @@ BLIS_INLINE err_t bli_ztrsm_small_XAutB_XAlB where k_iter are zero */ - BLIS_ZTRSM_SMALL_GEMM_3nx2m(a01,b10,cs_b,p_lda,k_iter) + BLIS_ZTRSM_SMALL_GEMM_3nx1m(a01,b10,cs_b,p_lda,k_iter) /* Load b11 and multiply with alpha @@ -34981,7 +36828,7 @@ BLIS_INLINE err_t bli_ztrsm_small_XAutB_XAlB and peform TRSM operation. 
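The load-and-scale step described here has a simple scalar equivalent: each element of the B11 tile is multiplied by the complex alpha and the GEMM accumulation is subtracted from it before the triangular solve. A minimal sketch follows; the function name and the explicit acc array are illustrative assumptions (in the kernel the accumulation lives in ymm registers), and dcomplex is assumed to be the BLIS real/imag pair.

// Scalar model of the complex pre-TRSM step: b11 := alpha*b11 - acc.
// Names are illustrative; only the arithmetic mirrors the kernel.
static inline void ztrsm_small_pre_ref(dcomplex *b11, const dcomplex *acc,
                                       dcomplex alpha,
                                       dim_t m, dim_t n, dim_t cs_b)
{
    for (dim_t j = 0; j < n; j++)
        for (dim_t i = 0; i < m; i++)
        {
            dcomplex t = b11[i + j*cs_b];
            double re = alpha.real * t.real - alpha.imag * t.imag;
            double im = alpha.real * t.imag + alpha.imag * t.real;
            b11[i + j*cs_b].real = re - acc[i + j*cs_b].real;
            b11[i + j*cs_b].imag = im - acc[i + j*cs_b].imag;
        }
}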
*/ - BLIS_PRE_ZTRSM_SMALL_3x2(AlphaVal,b11,cs_b) + BLIS_PRE_ZTRSM_SMALL_3x1(AlphaVal,b11,cs_b) ///implement TRSM/// /* Compute 3x3 TRSM block by using GEMM block output @@ -35306,10 +37153,12 @@ BLIS_INLINE err_t bli_ztrsm_small_XAutB_XAlB For first itteration there will be no GEMM operation where k_iter are zero */ - BLIS_ZTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) + //BLIS_ZTRSM_SMALL_GEMM_3nx3m(a01,b10,cs_b,p_lda,k_iter) + BLIS_ZTRSM_SMALL_GEMM_2nx3m(a01,b10,cs_b,p_lda,k_iter) // Load b11 and multiply with alpha - BLIS_PRE_ZTRSM_SMALL_3x4(AlphaVal,b11,cs_b) + //BLIS_PRE_ZTRSM_SMALL_3x3(AlphaVal,b11,cs_b) + BLIS_PRE_ZTRSM_SMALL_2x3(AlphaVal,b11,cs_b) ////extract a00 ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); @@ -35393,10 +37242,10 @@ BLIS_INLINE err_t bli_ztrsm_small_XAutB_XAlB For first itteration there will be no GEMM operation where k_iter are zero */ - BLIS_ZTRSM_SMALL_GEMM_3nx2m(a01,b10,cs_b,p_lda,k_iter) + BLIS_ZTRSM_SMALL_GEMM_2nx2m(a01,b10,cs_b,p_lda,k_iter) // Load b11 and multiply with alpha - BLIS_PRE_ZTRSM_SMALL_3x2(AlphaVal,b11,cs_b) + BLIS_PRE_ZTRSM_SMALL_2x2(AlphaVal,b11,cs_b) ////extract a00 ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); @@ -35464,10 +37313,10 @@ BLIS_INLINE err_t bli_ztrsm_small_XAutB_XAlB For first itteration there will be no GEMM operation where k_iter are zero */ - BLIS_ZTRSM_SMALL_GEMM_3nx2m(a01,b10,cs_b,p_lda,k_iter) + BLIS_ZTRSM_SMALL_GEMM_2nx1m(a01,b10,cs_b,p_lda,k_iter) // Load b11 and multiply with alpha - BLIS_PRE_ZTRSM_SMALL_3x2(AlphaVal,b11,cs_b) + BLIS_PRE_ZTRSM_SMALL_2x1(AlphaVal,b11,cs_b) ////extract a00 ymm0 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); ymm15 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); @@ -35657,7 +37506,7 @@ BLIS_INLINE err_t bli_ztrsm_small_XAutB_XAlB BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_ZTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_ZTRSM_SMALL_GEMM_1nx3m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_ZTRSM_SMALL_1x3(b11,cs_b,AlphaVal) ///implement TRSM/// @@ -36088,10 +37937,10 @@ BLIS_INLINE err_t bli_ztrsm_small_XAltB_XAuB BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_ZTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_ZTRSM_SMALL_GEMM_3nx3m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_ZTRSM_SMALL_3x4(AlphaVal,b11,cs_b) + BLIS_PRE_ZTRSM_SMALL_3x3(AlphaVal,b11,cs_b) ///implement TRSM/// ////extract a00 @@ -36405,10 +38254,10 @@ BLIS_INLINE err_t bli_ztrsm_small_XAltB_XAuB BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_ZTRSM_SMALL_GEMM_3nx2m(a01,b10,cs_b,p_lda,k_iter) + BLIS_ZTRSM_SMALL_GEMM_3nx1m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 2x3 and multiply with alpha - BLIS_PRE_ZTRSM_SMALL_3x2(AlphaVal,b11,cs_b) + BLIS_PRE_ZTRSM_SMALL_3x1(AlphaVal,b11,cs_b) ///implement TRSM/// ////extract a00 @@ -36731,10 +38580,10 @@ BLIS_INLINE err_t bli_ztrsm_small_XAltB_XAuB BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_ZTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_ZTRSM_SMALL_GEMM_2nx3m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_ZTRSM_SMALL_2x4(AlphaVal,b11,cs_b) + BLIS_PRE_ZTRSM_SMALL_2x3(AlphaVal,b11,cs_b) ///implement TRSM/// ////extract a00 @@ -36888,10 +38737,10 @@ BLIS_INLINE err_t bli_ztrsm_small_XAltB_XAuB BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_ZTRSM_SMALL_GEMM_2nx2m(a01,b10,cs_b,p_lda,k_iter) + 
BLIS_ZTRSM_SMALL_GEMM_2nx1m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_ZTRSM_SMALL_2x2(AlphaVal,b11,cs_b) + BLIS_PRE_ZTRSM_SMALL_2x1(AlphaVal,b11,cs_b) ///implement TRSM/// ////extract a00 @@ -37084,7 +38933,7 @@ BLIS_INLINE err_t bli_ztrsm_small_XAltB_XAuB BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_ZTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_ZTRSM_SMALL_GEMM_1nx3m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_ZTRSM_SMALL_1x3(b11,cs_b,AlphaVal) ///implement TRSM/// @@ -37673,7 +39522,7 @@ BLIS_INLINE void ctrsm_small_pack_diag_element ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal));\ ymm16 = _mm256_permute_ps(ymm16, 0x44);\ \ - xmm0 = _mm_loadu_ps((float const *)(b11));\ + xmm0 = _mm_loadl_pi(xmm0,(__m64 const *)(b11));\ ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0);\ /*in register transpose * ymm0,ymm1,ymm2 holds @@ -37733,7 +39582,7 @@ BLIS_INLINE void ctrsm_small_pack_diag_element ymm16 = _mm256_permute_ps(ymm16, 0x44);\ \ xmm0 = _mm_loadu_ps((float const *)(b11));\ - xmm1 = _mm_loadu_ps((float const *)(b11 + 2));\ + xmm1 = _mm_loadl_pi(xmm1,(__m64 const *)(b11 + 2));\ ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0);\ ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 1);\ /*in register transpose @@ -37766,9 +39615,9 @@ BLIS_INLINE void ctrsm_small_pack_diag_element ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal));\ ymm16 = _mm256_permute_ps(ymm16, 0x44);\ \ - xmm0 = _mm_loadu_ps((float const *)(b11));\ + xmm0 = _mm_loadl_pi(xmm0,(__m64 const *)(b11));\ ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0);\ - xmm1 = _mm_loadu_ps((float const *)(b11 + cs_b * 1));\ + xmm1 = _mm_loadl_pi(xmm1,(__m64 const *)(b11 + cs_b * 1));\ ymm1 = _mm256_insertf128_ps(ymm1, xmm1, 0);\ /*in register transpose * ymm0,ymm1,ymm2 holds @@ -38236,6 +40085,186 @@ BLIS_INLINE void ctrsm_small_pack_diag_element ymm10 = _mm256_addsub_ps(ymm10, ymm6);\ } +#define BLIS_CTRSM_SMALL_GEMM_2nx3m(a01,b10,cs_b,p_lda,k_iter){\ + float *tptr = (float *)a01;\ + if(conjtransa) {\ + ymm18 = _mm256_setr_ps(-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0);\ + for(k = 0; k< k_iter; k++) \ + { \ + xmm5 = _mm_loadu_ps((float const *)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b10 + 2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 0 + 1);\ + ymm3 = _mm256_mul_ps(ymm3, ymm18);\ + \ + _mm_prefetch((char*)( b10 + 2*cs_b), _MM_HINT_T0); \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 1 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 1 * 2 + 1);\ + \ + ymm3 = _mm256_mul_ps(ymm3, ymm18);\ + ymm10 = _mm256_fmadd_ps(ymm0, ymm2, ymm10);\ + ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);\ + \ + tptr += 2;\ + b10 += cs_b;\ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) \ + { \ + xmm5 = _mm_loadu_ps((float const *)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b10 + 2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 0 + 1);\ + \ + _mm_prefetch((char*)( b10 + 2*cs_b), _MM_HINT_T0); \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 1 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 1 * 2 + 1);\ + \ + ymm10 = 
_mm256_fmadd_ps(ymm0, ymm2, ymm10);\ + ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);\ + \ + tptr += 2;\ + b10 += cs_b;\ + }\ + }\ + ymm4 = _mm256_permute_ps(ymm4, 0xb1);\ + ymm6 = _mm256_permute_ps(ymm6, 0xb1);\ + \ + ymm8 = _mm256_addsub_ps(ymm8, ymm4);\ + ymm10 = _mm256_addsub_ps(ymm10, ymm6);\ +} + +#define BLIS_CTRSM_SMALL_GEMM_2nx2m(a01,b10,cs_b,p_lda,k_iter){\ + float *tptr = (float *)a01;\ + if(conjtransa) {\ + ymm18 = _mm256_setr_ps(-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0);\ + for(k = 0; k< k_iter; k++) \ + { \ + xmm5 = _mm_loadu_ps((float const *)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 0 + 1);\ + ymm3 = _mm256_mul_ps(ymm3, ymm18);\ + \ + _mm_prefetch((char*)( b10 + 2*cs_b), _MM_HINT_T0); \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 1 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 1 * 2 + 1);\ + \ + ymm3 = _mm256_mul_ps(ymm3, ymm18);\ + ymm10 = _mm256_fmadd_ps(ymm0, ymm2, ymm10);\ + ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);\ + \ + tptr += 2;\ + b10 += cs_b;\ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) \ + { \ + xmm5 = _mm_loadu_ps((float const *)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 0 + 1);\ + \ + _mm_prefetch((char*)( b10 + 2*cs_b), _MM_HINT_T0); \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 1 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 1 * 2 + 1);\ + \ + ymm10 = _mm256_fmadd_ps(ymm0, ymm2, ymm10);\ + ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);\ + \ + tptr += 2;\ + b10 += cs_b;\ + }\ + }\ + ymm4 = _mm256_permute_ps(ymm4, 0xb1);\ + ymm6 = _mm256_permute_ps(ymm6, 0xb1);\ + \ + ymm8 = _mm256_addsub_ps(ymm8, ymm4);\ + ymm10 = _mm256_addsub_ps(ymm10, ymm6);\ +} + +#define BLIS_CTRSM_SMALL_GEMM_2nx1m(a01,b10,cs_b,p_lda,k_iter){\ + float *tptr = (float *)a01;\ + if(conjtransa) {\ + ymm18 = _mm256_setr_ps(-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0);\ + for(k = 0; k< k_iter; k++) \ + { \ + ymm0 = _mm256_broadcast_ps(( __m128 const *)(b10));\ + ymm0 = _mm256_permute_ps(ymm0, 0x44);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 0 + 1);\ + ymm3 = _mm256_mul_ps(ymm3, ymm18);\ + \ + _mm_prefetch((char*)( b10 + 2*cs_b), _MM_HINT_T0); \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 1 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 1 * 2 + 1);\ + \ + ymm3 = _mm256_mul_ps(ymm3, ymm18);\ + ymm10 = _mm256_fmadd_ps(ymm0, ymm2, ymm10);\ + ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);\ + \ + tptr += 2;\ + b10 += cs_b;\ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) \ + { \ + ymm0 = _mm256_broadcast_ps(( __m128 const *)(b10 + 2));\ + ymm0 = _mm256_permute_ps(ymm0, 0x44);\ + xmm5 = _mm_loadu_ps((float const *)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 0 + 1);\ + \ + _mm_prefetch((char*)( b10 + 2*cs_b), _MM_HINT_T0); \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 1 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + 
p_lda * 1 * 2 + 1);\ + \ + ymm10 = _mm256_fmadd_ps(ymm0, ymm2, ymm10);\ + ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);\ + \ + tptr += 2;\ + b10 += cs_b;\ + }\ + }\ + ymm4 = _mm256_permute_ps(ymm4, 0xb1);\ + ymm6 = _mm256_permute_ps(ymm6, 0xb1);\ + \ + ymm8 = _mm256_addsub_ps(ymm8, ymm4);\ + ymm10 = _mm256_addsub_ps(ymm10, ymm6);\ +} + /** * Performs GEMM operation. * Four elements of column in ymm0 @@ -38282,6 +40311,139 @@ BLIS_INLINE void ctrsm_small_pack_diag_element ymm8 = _mm256_addsub_ps(ymm8, ymm4);\ } +#define BLIS_CTRSM_SMALL_GEMM_1nx3m(a01,b10,cs_b,p_lda,k_iter){\ + float *tptr = (float *)a01;\ + if(conjtransa) {\ + ymm18 = _mm256_setr_ps(-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0);\ + for(k = 0; k< k_iter; k++) \ + { \ + xmm5 = _mm_loadu_ps((float const *)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b10 + 2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 0 + 1);\ + ymm3 = _mm256_mul_ps(ymm3, ymm18);\ + \ + _mm_prefetch((char*)( b10 + 4*cs_b), _MM_HINT_T0); \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + tptr += 2;\ + b10 += cs_b;\ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) \ + { \ + xmm5 = _mm_loadu_ps((float const *)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b10 + 2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 0 + 1);\ + \ + _mm_prefetch((char*)( b10 + 4*cs_b), _MM_HINT_T0); \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + tptr += 2;\ + b10 += cs_b;\ + }\ + }\ + ymm4 = _mm256_permute_ps(ymm4, 0xb1);\ + \ + ymm8 = _mm256_addsub_ps(ymm8, ymm4);\ +} + +#define BLIS_CTRSM_SMALL_GEMM_1nx2m(a01,b10,cs_b,p_lda,k_iter){\ + float *tptr = (float *)a01;\ + if(conjtransa) {\ + ymm18 = _mm256_setr_ps(-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0);\ + for(k = 0; k< k_iter; k++) \ + { \ + xmm5 = _mm_loadu_ps((float const *)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 0 + 1);\ + ymm3 = _mm256_mul_ps(ymm3, ymm18);\ + \ + _mm_prefetch((char*)( b10 + 4*cs_b), _MM_HINT_T0); \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + tptr += 2;\ + b10 += cs_b;\ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) \ + { \ + xmm5 = _mm_loadu_ps((float const *)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 0 + 1);\ + \ + _mm_prefetch((char*)( b10 + 4*cs_b), _MM_HINT_T0); \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + tptr += 2;\ + b10 += cs_b;\ + }\ + }\ + ymm4 = _mm256_permute_ps(ymm4, 0xb1);\ + \ + ymm8 = _mm256_addsub_ps(ymm8, ymm4);\ +} + +#define BLIS_CTRSM_SMALL_GEMM_1nx1m(a01,b10,cs_b,p_lda,k_iter){\ + float *tptr = (float *)a01;\ + if(conjtransa) {\ + ymm18 = _mm256_setr_ps(-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0);\ + for(k = 0; k< k_iter; k++) \ + { \ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 0 + 1);\ + ymm3 = 
_mm256_mul_ps(ymm3, ymm18);\ + \ + _mm_prefetch((char*)( b10 + 4*cs_b), _MM_HINT_T0); \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + tptr += 2;\ + b10 += cs_b;\ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) \ + { \ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 0 + 1);\ + \ + _mm_prefetch((char*)( b10 + 4*cs_b), _MM_HINT_T0); \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + tptr += 2;\ + b10 += cs_b;\ + }\ + }\ + ymm4 = _mm256_permute_ps(ymm4, 0xb1);\ + \ + ymm8 = _mm256_addsub_ps(ymm8, ymm4);\ +} + /** * Performs GEMM operation. * Eight elements of column in ymm0, ymm1 @@ -38487,6 +40649,228 @@ BLIS_INLINE void ctrsm_small_pack_diag_element ymm12 = _mm256_addsub_ps(ymm12, ymm14);\ } +#define BLIS_CTRSM_SMALL_GEMM_3nx3m(a01,b10,cs_b,p_lda,k_iter) {\ + float *tptr = (float *)a01;\ + if(conjtransa) {\ + ymm18 = _mm256_setr_ps(-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0);\ + for(k = 0; k< k_iter; k++) \ + { \ + xmm5 = _mm_loadu_ps((float const *)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b10 + 2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 0 + 1);\ + ymm3 = _mm256_mul_ps(ymm3, ymm18);\ + \ + _mm_prefetch((char*)( b10 + 4*cs_b), _MM_HINT_T0); \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 1 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 1 * 2 + 1);\ + \ + ymm3 = _mm256_mul_ps(ymm3, ymm18);\ + ymm10 = _mm256_fmadd_ps(ymm0, ymm2, ymm10);\ + ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 2 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 2 * 2 + 1);\ + \ + ymm3 = _mm256_mul_ps(ymm3, ymm18);\ + ymm12 = _mm256_fmadd_ps(ymm0, ymm2, ymm12);\ + ymm14 = _mm256_fmadd_ps(ymm0, ymm3, ymm14);\ + \ + tptr += 2;\ + b10 += cs_b;\ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) \ + { \ + xmm5 = _mm_loadu_ps((float const *)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b10 + 2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ + \ + _mm_prefetch((char*)( b10 + 4*cs_b), _MM_HINT_T0); \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 0 + 1);\ + \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 1 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 1 * 2 + 1);\ + \ + ymm10 = _mm256_fmadd_ps(ymm0, ymm2, ymm10);\ + ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 2 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 2 * 2 + 1);\ + \ + ymm12 = _mm256_fmadd_ps(ymm0, ymm2, ymm12);\ + ymm14 = _mm256_fmadd_ps(ymm0, ymm3, ymm14);\ + \ + tptr += 2;\ + b10 += cs_b;\ + }\ + }\ + ymm4 = _mm256_permute_ps(ymm4, 0xb1);\ + ymm6 = _mm256_permute_ps(ymm6, 0xb1);\ + ymm14 = _mm256_permute_ps(ymm14, 0xb1);\ + \ + ymm8 = _mm256_addsub_ps(ymm8, ymm4);\ + ymm10 = _mm256_addsub_ps(ymm10, ymm6);\ + ymm12 = _mm256_addsub_ps(ymm12, ymm14);\ +} + +#define BLIS_CTRSM_SMALL_GEMM_3nx2m(a01,b10,cs_b,p_lda,k_iter) {\ + float *tptr = (float *)a01;\ + if(conjtransa) {\ + 
ymm18 = _mm256_setr_ps(-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0);\ + for(k = 0; k< k_iter; k++) \ + { \ + xmm5 = _mm_loadu_ps((float const *)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 0 + 1);\ + ymm3 = _mm256_mul_ps(ymm3, ymm18);\ + \ + _mm_prefetch((char*)( b10 + 4*cs_b), _MM_HINT_T0); \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 1 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 1 * 2 + 1);\ + \ + ymm3 = _mm256_mul_ps(ymm3, ymm18);\ + ymm10 = _mm256_fmadd_ps(ymm0, ymm2, ymm10);\ + ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 2 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 2 * 2 + 1);\ + \ + ymm3 = _mm256_mul_ps(ymm3, ymm18);\ + ymm12 = _mm256_fmadd_ps(ymm0, ymm2, ymm12);\ + ymm14 = _mm256_fmadd_ps(ymm0, ymm3, ymm14);\ + \ + tptr += 2;\ + b10 += cs_b;\ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) \ + { \ + xmm5 = _mm_loadu_ps((float const *)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + \ + _mm_prefetch((char*)( b10 + 4*cs_b), _MM_HINT_T0); \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 0 + 1);\ + \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 1 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 1 * 2 + 1);\ + \ + ymm10 = _mm256_fmadd_ps(ymm0, ymm2, ymm10);\ + ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 2 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 2 * 2 + 1);\ + \ + ymm12 = _mm256_fmadd_ps(ymm0, ymm2, ymm12);\ + ymm14 = _mm256_fmadd_ps(ymm0, ymm3, ymm14);\ + \ + tptr += 2;\ + b10 += cs_b;\ + }\ + }\ + ymm4 = _mm256_permute_ps(ymm4, 0xb1);\ + ymm6 = _mm256_permute_ps(ymm6, 0xb1);\ + ymm14 = _mm256_permute_ps(ymm14, 0xb1);\ + \ + ymm8 = _mm256_addsub_ps(ymm8, ymm4);\ + ymm10 = _mm256_addsub_ps(ymm10, ymm6);\ + ymm12 = _mm256_addsub_ps(ymm12, ymm14);\ +} + +#define BLIS_CTRSM_SMALL_GEMM_3nx1m(a01,b10,cs_b,p_lda,k_iter) {\ + float *tptr = (float *)a01;\ + if(conjtransa) {\ + ymm18 = _mm256_setr_ps(-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0);\ + for(k = 0; k< k_iter; k++) \ + { \ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 0 + 1);\ + ymm3 = _mm256_mul_ps(ymm3, ymm18);\ + \ + _mm_prefetch((char*)( b10 + 4*cs_b), _MM_HINT_T0); \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 1 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 1 * 2 + 1);\ + \ + ymm3 = _mm256_mul_ps(ymm3, ymm18);\ + ymm10 = _mm256_fmadd_ps(ymm0, ymm2, ymm10);\ + ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 2 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 2 * 2 + 1);\ + \ + ymm3 = _mm256_mul_ps(ymm3, ymm18);\ + ymm12 = _mm256_fmadd_ps(ymm0, ymm2, ymm12);\ + ymm14 = _mm256_fmadd_ps(ymm0, ymm3, ymm14);\ + \ + tptr += 2;\ + b10 += cs_b;\ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) \ + { \ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + \ + _mm_prefetch((char*)( b10 + 4*cs_b), _MM_HINT_T0); \ + ymm2 = 
_mm256_broadcast_ss(tptr + p_lda * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 0 + 1);\ + \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 1 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 1 * 2 + 1);\ + \ + ymm10 = _mm256_fmadd_ps(ymm0, ymm2, ymm10);\ + ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + p_lda * 2 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + p_lda * 2 * 2 + 1);\ + \ + ymm12 = _mm256_fmadd_ps(ymm0, ymm2, ymm12);\ + ymm14 = _mm256_fmadd_ps(ymm0, ymm3, ymm14);\ + \ + tptr += 2;\ + b10 += cs_b;\ + }\ + }\ + ymm4 = _mm256_permute_ps(ymm4, 0xb1);\ + ymm6 = _mm256_permute_ps(ymm6, 0xb1);\ + ymm14 = _mm256_permute_ps(ymm14, 0xb1);\ + \ + ymm8 = _mm256_addsub_ps(ymm8, ymm4);\ + ymm10 = _mm256_addsub_ps(ymm10, ymm6);\ + ymm12 = _mm256_addsub_ps(ymm12, ymm14);\ +} /** * Performs GEMM operation. @@ -38878,6 +41262,513 @@ BLIS_INLINE void ctrsm_small_pack_diag_element ymm10 = _mm256_addsub_ps(ymm10, ymm11);\ } +#define BLIS_CTRSM_SMALL_GEMM_2mx3n(a10,b01,cs_b,p_lda,k_iter) {\ + float *tptr = (float *)b01;\ + if(conjtransa) {\ + ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0);\ + for(k = 0; k< k_iter; k++) \ + { \ + xmm5 = _mm_loadu_ps((float const *)(a10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + ymm0 = _mm256_mul_ps(ymm0, ymm18);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 0 + 1);\ + \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 1 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 1 * 2 + 1);\ + \ + ymm9 = _mm256_fmadd_ps(ymm0, ymm2, ymm9);\ + ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 2 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 2 * 2 + 1);\ + \ + ymm10 = _mm256_fmadd_ps(ymm0, ymm2, ymm10);\ + ymm11 = _mm256_fmadd_ps(ymm0, ymm3, ymm11);\ + \ + tptr += 2;\ + a10 += p_lda;\ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) \ + { \ + xmm5 = _mm_loadu_ps((float const *)(a10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 0 + 1);\ + \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 1 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 1 * 2 + 1);\ + \ + ymm9 = _mm256_fmadd_ps(ymm0, ymm2, ymm9);\ + ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 2 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 2 * 2 + 1);\ + \ + ymm10 = _mm256_fmadd_ps(ymm0, ymm2, ymm10);\ + ymm11 = _mm256_fmadd_ps(ymm0, ymm3, ymm11);\ + \ + tptr += 2;\ + a10 += p_lda;\ + }\ + }\ + ymm4 = _mm256_permute_ps(ymm4, 0xb1);\ + ymm6 = _mm256_permute_ps(ymm6, 0xb1);\ + ymm11 = _mm256_permute_ps(ymm11, 0xb1);\ + \ + ymm8 = _mm256_addsub_ps(ymm8, ymm4);\ + ymm9 = _mm256_addsub_ps(ymm9, ymm6);\ + ymm10 = _mm256_addsub_ps(ymm10, ymm11);\ +} + +#define BLIS_CTRSM_SMALL_GEMM_3mx3n(a10,b01,cs_b,p_lda,k_iter) {\ + float *tptr = (float *)b01;\ + if(conjtransa) {\ + ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0);\ + for(k = 0; k< k_iter; k++) \ + { \ + xmm5 = _mm_loadu_ps((float const *)(a10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(a10 + 2));\ + ymm0 = 
_mm256_insertf128_ps(ymm0, xmm5, 1);\ + ymm0 = _mm256_mul_ps(ymm0, ymm18);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 0 + 1);\ + \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 1 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 1 * 2 + 1);\ + \ + ymm9 = _mm256_fmadd_ps(ymm0, ymm2, ymm9);\ + ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 2 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 2 * 2 + 1);\ + \ + ymm10 = _mm256_fmadd_ps(ymm0, ymm2, ymm10);\ + ymm11 = _mm256_fmadd_ps(ymm0, ymm3, ymm11);\ + \ + tptr += 2;\ + a10 += p_lda;\ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) \ + { \ + xmm5 = _mm_loadu_ps((float const *)(a10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(a10 + 2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 0 + 1);\ + \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 1 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 1 * 2 + 1);\ + \ + ymm9 = _mm256_fmadd_ps(ymm0, ymm2, ymm9);\ + ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 2 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 2 * 2 + 1);\ + \ + ymm10 = _mm256_fmadd_ps(ymm0, ymm2, ymm10);\ + ymm11 = _mm256_fmadd_ps(ymm0, ymm3, ymm11);\ + \ + tptr += 2;\ + a10 += p_lda;\ + }\ + }\ + ymm4 = _mm256_permute_ps(ymm4, 0xb1);\ + ymm6 = _mm256_permute_ps(ymm6, 0xb1);\ + ymm11 = _mm256_permute_ps(ymm11, 0xb1);\ + \ + ymm8 = _mm256_addsub_ps(ymm8, ymm4);\ + ymm9 = _mm256_addsub_ps(ymm9, ymm6);\ + ymm10 = _mm256_addsub_ps(ymm10, ymm11);\ +} + +#define BLIS_CTRSM_SMALL_GEMM_3mx2n(a10,b01,cs_b,p_lda,k_iter) {\ + float *tptr = (float *)b01;\ + if(conjtransa) {\ + ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0);\ + for(k = 0; k< k_iter; k++) \ + { \ + xmm5 = _mm_loadu_ps((float const *)(a10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(a10 + 2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ + ymm0 = _mm256_mul_ps(ymm0, ymm18);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 0 + 1);\ + \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 1 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 1 * 2 + 1);\ + \ + ymm9 = _mm256_fmadd_ps(ymm0, ymm2, ymm9);\ + ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);\ + \ + tptr += 2;\ + a10 += p_lda;\ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) \ + { \ + xmm5 = _mm_loadu_ps((float const *)(a10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(a10 + 2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 0 + 1);\ + \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 1 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 1 * 2 + 1);\ + \ + ymm9 = _mm256_fmadd_ps(ymm0, ymm2, ymm9);\ + ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);\ + \ + tptr += 2;\ + a10 += p_lda;\ + }\ + }\ + ymm4 = _mm256_permute_ps(ymm4, 0xb1);\ + 
ymm6 = _mm256_permute_ps(ymm6, 0xb1);\ + \ + ymm8 = _mm256_addsub_ps(ymm8, ymm4);\ + ymm9 = _mm256_addsub_ps(ymm9, ymm6);\ +} + +#define BLIS_CTRSM_SMALL_GEMM_3mx1n(a10,b01,cs_b,p_lda,k_iter) {\ + float *tptr = (float *)b01;\ + if(conjtransa) {\ + ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0);\ + for(k = 0; k< k_iter; k++) \ + { \ + xmm5 = _mm_loadu_ps((float const *)(a10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(a10 + 2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ + ymm0 = _mm256_mul_ps(ymm0, ymm18);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 0 + 1);\ + \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + tptr += 2;\ + a10 += p_lda;\ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) \ + { \ + xmm5 = _mm_loadu_ps((float const *)(a10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(a10 + 2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 0 + 1);\ + \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + tptr += 2;\ + a10 += p_lda;\ + }\ + }\ + ymm4 = _mm256_permute_ps(ymm4, 0xb1);\ + \ + ymm8 = _mm256_addsub_ps(ymm8, ymm4);\ +} + +#define BLIS_CTRSM_SMALL_GEMM_2mx2n(a10,b01,cs_b,p_lda,k_iter) {\ + float *tptr = (float *)b01;\ + if(conjtransa) {\ + ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0);\ + for(k = 0; k< k_iter; k++) \ + { \ + xmm5 = _mm_loadu_ps((float const *)(a10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + ymm0 = _mm256_mul_ps(ymm0, ymm18);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 0 + 1);\ + \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 1 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 1 * 2 + 1);\ + \ + ymm9 = _mm256_fmadd_ps(ymm0, ymm2, ymm9);\ + ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);\ + \ + tptr += 2;\ + a10 += p_lda;\ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) \ + { \ + xmm5 = _mm_loadu_ps((float const *)(a10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 0 + 1);\ + \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 1 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 1 * 2 + 1);\ + \ + ymm9 = _mm256_fmadd_ps(ymm0, ymm2, ymm9);\ + ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);\ + \ + tptr += 2;\ + a10 += p_lda;\ + }\ + }\ + ymm4 = _mm256_permute_ps(ymm4, 0xb1);\ + ymm6 = _mm256_permute_ps(ymm6, 0xb1);\ + \ + ymm8 = _mm256_addsub_ps(ymm8, ymm4);\ + ymm9 = _mm256_addsub_ps(ymm9, ymm6);\ +} + +#define BLIS_CTRSM_SMALL_GEMM_2mx1n(a10,b01,cs_b,p_lda,k_iter) {\ + float *tptr = (float *)b01;\ + if(conjtransa) {\ + ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0);\ + for(k = 0; k< k_iter; k++) \ + { \ + xmm5 = _mm_loadu_ps((float const *)(a10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + ymm0 = _mm256_mul_ps(ymm0, ymm18);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 0 + 1);\ + \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + tptr += 
2;\ + a10 += p_lda;\ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) \ + { \ + xmm5 = _mm_loadu_ps((float const *)(a10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 0 + 1);\ + \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + tptr += 2;\ + a10 += p_lda;\ + }\ + }\ + ymm4 = _mm256_permute_ps(ymm4, 0xb1);\ + \ + ymm8 = _mm256_addsub_ps(ymm8, ymm4);\ +} + +#define BLIS_CTRSM_SMALL_GEMM_1mx3n(a10,b01,cs_b,p_lda,k_iter) {\ + float *tptr = (float *)b01;\ + if(conjtransa) {\ + ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0);\ + for(k = 0; k< k_iter; k++) \ + { \ + xmm0 = _mm_loadl_pi(xmm0,(__m64 *)(a10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0);\ + ymm0 = _mm256_mul_ps(ymm0, ymm18);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 0 + 1);\ + \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 1 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 1 * 2 + 1);\ + \ + ymm9 = _mm256_fmadd_ps(ymm0, ymm2, ymm9);\ + ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 2 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 2 * 2 + 1);\ + \ + ymm10 = _mm256_fmadd_ps(ymm0, ymm2, ymm10);\ + ymm11 = _mm256_fmadd_ps(ymm0, ymm3, ymm11);\ + \ + tptr += 2;\ + a10 += p_lda;\ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) \ + { \ + xmm0 = _mm_loadl_pi(xmm0,(__m64 *)(a10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 0 + 1);\ + \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 1 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 1 * 2 + 1);\ + \ + ymm9 = _mm256_fmadd_ps(ymm0, ymm2, ymm9);\ + ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 2 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 2 * 2 + 1);\ + \ + ymm10 = _mm256_fmadd_ps(ymm0, ymm2, ymm10);\ + ymm11 = _mm256_fmadd_ps(ymm0, ymm3, ymm11);\ + \ + tptr += 2;\ + a10 += p_lda;\ + }\ + }\ + ymm4 = _mm256_permute_ps(ymm4, 0xb1);\ + ymm6 = _mm256_permute_ps(ymm6, 0xb1);\ + ymm11 = _mm256_permute_ps(ymm11, 0xb1);\ + \ + ymm8 = _mm256_addsub_ps(ymm8, ymm4);\ + ymm9 = _mm256_addsub_ps(ymm9, ymm6);\ + ymm10 = _mm256_addsub_ps(ymm10, ymm11);\ +} + +#define BLIS_CTRSM_SMALL_GEMM_1mx2n(a10,b01,cs_b,p_lda,k_iter) {\ + float *tptr = (float *)b01;\ + if(conjtransa) {\ + ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0);\ + for(k = 0; k< k_iter; k++) \ + { \ + xmm0 = _mm_loadl_pi(xmm0,(__m64 *)(a10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0);\ + ymm0 = _mm256_mul_ps(ymm0, ymm18);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 0 + 1);\ + \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 1 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 1 * 2 + 1);\ + \ + ymm9 = _mm256_fmadd_ps(ymm0, ymm2, ymm9);\ + ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);\ + \ + tptr += 2;\ + a10 += p_lda;\ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) \ + { \ + xmm0 = _mm_loadl_pi(xmm0,(__m64 *)(a10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0);\ + \ + 
ymm2 = _mm256_broadcast_ss(tptr + cs_b * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 0 + 1);\ + \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 1 * 2+ 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 1 * 2 + 1);\ + \ + ymm9 = _mm256_fmadd_ps(ymm0, ymm2, ymm9);\ + ymm6 = _mm256_fmadd_ps(ymm0, ymm3, ymm6);\ + \ + tptr += 2;\ + a10 += p_lda;\ + }\ + }\ + ymm4 = _mm256_permute_ps(ymm4, 0xb1);\ + ymm6 = _mm256_permute_ps(ymm6, 0xb1);\ + \ + ymm8 = _mm256_addsub_ps(ymm8, ymm4);\ + ymm9 = _mm256_addsub_ps(ymm9, ymm6);\ +} + +#define BLIS_CTRSM_SMALL_GEMM_1mx1n(a10,b01,cs_b,p_lda,k_iter) {\ + float *tptr = (float *)b01;\ + if(conjtransa) {\ + ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0);\ + for(k = 0; k< k_iter; k++) \ + { \ + xmm0 = _mm_loadl_pi(xmm0,(__m64 *)(a10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0);\ + ymm0 = _mm256_mul_ps(ymm0, ymm18);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 0 + 1);\ + \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + tptr += 2;\ + a10 += p_lda;\ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) \ + { \ + xmm0 = _mm_loadl_pi(xmm0,(__m64 *)(a10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0);\ + \ + ymm2 = _mm256_broadcast_ss(tptr + cs_b * 0 + 0);\ + ymm3 = _mm256_broadcast_ss(tptr + cs_b * 0 + 1);\ + \ + ymm8 = _mm256_fmadd_ps(ymm0, ymm2, ymm8);\ + ymm4 = _mm256_fmadd_ps(ymm0, ymm3, ymm4);\ + \ + tptr += 2;\ + a10 += p_lda;\ + }\ + }\ + ymm4 = _mm256_permute_ps(ymm4, 0xb1);\ + \ + ymm8 = _mm256_addsub_ps(ymm8, ymm4);\ +} + /** * Performs GEMM operation. * Eight elements of column in ymm0, ymm1 @@ -38974,7 +41865,7 @@ BLIS_INLINE void ctrsm_small_pack_diag_element #define BLIS_CTRSM_SMALL_NREG_TRANSPOSE_1x4(b11,cs_b,AlphaVal){\ ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal));\ ymm16 = _mm256_permute_ps(ymm16, 0x44);\ -\ + \ ymm0 = _mm256_loadu_ps((float const *)(b11));\ ymm3 = _mm256_broadcast_ps((__m128 const *)&ones);\ ymm3 = _mm256_permute_ps(ymm3, 0x44);\ @@ -39355,6 +42246,7 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB __m256 ymm16, ymm17, ymm18, ymm19; __m128 xmm0, xmm1, xmm2, xmm3, xmm4; + __m128 xmm5; gint_t required_packing_A = 1; mem_t local_mem_buf_A_s = {0}; @@ -41109,14 +44001,23 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB k_iter = i; BLIS_SET_S_YMM_REG_ZEROS - BLIS_CTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) + BLIS_CTRSM_SMALL_GEMM_3mx3n(a10,b01,cs_b,p_lda,k_iter) ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); - ymm0 = _mm256_loadu_ps((float const *)(b11)); - ymm1 = _mm256_loadu_ps((float const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_ps((float const *)(b11 + cs_b *2)); + xmm0 = _mm_loadu_ps((float const *)(b11)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + xmm0 = _mm_loadl_pi(xmm0,(__m64 *)(b11 + 2)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 1); + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b)); + ymm1 = _mm256_insertf128_ps(ymm1, xmm0, 0); + xmm0 = _mm_loadl_pi(xmm0,(__m64 *)(b11 + cs_b + 2)); + ymm1 = _mm256_insertf128_ps(ymm1, xmm0, 1); + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b*2)); + ymm2 = _mm256_insertf128_ps(ymm2, xmm0, 0); + xmm0 = _mm_loadl_pi(xmm0,(__m64 *)(b11 + cs_b*2 + 2)); + ymm2 = _mm256_insertf128_ps(ymm2, xmm0, 1); ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); @@ -41179,7 +44080,7 @@ BLIS_INLINE err_t 
bli_ctrsm_small_AutXB_AlXB if(2 == n_rem) { ///GEMM code begins/// - BLIS_CTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b, + BLIS_CTRSM_SMALL_GEMM_3mx2n(a10,b01,cs_b, p_lda,k_iter) BLIS_PRE_CTRSM_SMALL_3M_2N(AlphaVal,b11,cs_b) @@ -41196,7 +44097,7 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB else if(1 == n_rem) { ///GEMM code begins/// - BLIS_CTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b, + BLIS_CTRSM_SMALL_GEMM_3mx1n(a10,b01,cs_b, p_lda,k_iter) BLIS_PRE_CTRSM_SMALL_3M_1N(AlphaVal,b11,cs_b) @@ -41270,7 +44171,7 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB k_iter = i; BLIS_SET_S_YMM_REG_ZEROS - BLIS_CTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) + BLIS_CTRSM_SMALL_GEMM_2mx3n(a10,b01,cs_b,p_lda,k_iter) ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); @@ -41335,7 +44236,7 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB if(2 == n_rem) { ///GEMM code begins/// - BLIS_CTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b, + BLIS_CTRSM_SMALL_GEMM_2mx2n(a10,b01,cs_b, p_lda,k_iter) BLIS_PRE_CTRSM_SMALL_2M_2N(AlphaVal,b11,cs_b) @@ -41352,7 +44253,7 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB else if(1 == n_rem) { ///GEMM code begins/// - BLIS_CTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b, + BLIS_CTRSM_SMALL_GEMM_2mx1n(a10,b01,cs_b, p_lda,k_iter) BLIS_PRE_CTRSM_SMALL_2M_1N(AlphaVal,b11,cs_b) @@ -41424,14 +44325,17 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB k_iter = i; BLIS_SET_S_YMM_REG_ZEROS - BLIS_CTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) + BLIS_CTRSM_SMALL_GEMM_1mx3n(a10,b01,cs_b,p_lda,k_iter) ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); - ymm0 = _mm256_loadu_ps((float const *)(b11)); - ymm1 = _mm256_loadu_ps((float const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_ps((float const *)(b11 + cs_b *2)); + xmm0 = _mm_loadl_pi(xmm0,(__m64 *)(b11)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + xmm0 = _mm_loadl_pi(xmm0,(__m64 *)(b11 + cs_b)); + ymm1 = _mm256_insertf128_ps(ymm1, xmm0, 0); + xmm0 = _mm_loadl_pi(xmm0,(__m64 *)(b11 + cs_b*2)); + ymm2 = _mm256_insertf128_ps(ymm2, xmm0, 0); ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); @@ -41487,7 +44391,7 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB if(2 == n_rem) { ///GEMM code begins/// - BLIS_CTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b, + BLIS_CTRSM_SMALL_GEMM_1mx2n(a10,b01,cs_b, p_lda,k_iter) BLIS_PRE_CTRSM_SMALL_1M_2N(AlphaVal,b11, @@ -41507,7 +44411,7 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB else if(1 == n_rem) { ///GEMM code begins/// - BLIS_CTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b, + BLIS_CTRSM_SMALL_GEMM_1mx1n(a10,b01,cs_b, p_lda,k_iter) BLIS_PRE_CTRSM_SMALL_1M_1N(AlphaVal,b11, @@ -41586,6 +44490,7 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB __m256 ymm16, ymm17, ymm18, ymm19; __m128 xmm0, xmm1, xmm2, xmm3, xmm4; + __m128 xmm5; gint_t required_packing_A = 1; mem_t local_mem_buf_A_s = {0}; @@ -43311,14 +46216,23 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB k_iter = (m - m_rem); BLIS_SET_S_YMM_REG_ZEROS - BLIS_CTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) + BLIS_CTRSM_SMALL_GEMM_3mx3n(a10,b01,cs_b,p_lda,k_iter) ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); - ymm0 = _mm256_loadu_ps((float const *)(b11)); - ymm1 = _mm256_loadu_ps((float const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_ps((float const *)(b11 + cs_b *2)); + xmm0 = _mm_loadu_ps((float const *)(b11)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + xmm0 = _mm_loadl_pi(xmm0,(__m64 *)(b11 + 2)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 1); + xmm0 = 
_mm_loadu_ps((float const *)(b11 + cs_b)); + ymm1 = _mm256_insertf128_ps(ymm1, xmm0, 0); + xmm0 = _mm_loadl_pi(xmm0,(__m64 *)(b11 + cs_b + 2)); + ymm1 = _mm256_insertf128_ps(ymm1, xmm0, 1); + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b*2)); + ymm2 = _mm256_insertf128_ps(ymm2, xmm0, 0); + xmm0 = _mm_loadl_pi(xmm0,(__m64 *)(b11 + cs_b*2 + 2)); + ymm2 = _mm256_insertf128_ps(ymm2, xmm0, 1); ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); @@ -43381,7 +46295,7 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB if(2 == n_rem) { ///GEMM code begins/// - BLIS_CTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b, + BLIS_CTRSM_SMALL_GEMM_3mx2n(a10,b01,cs_b, p_lda,k_iter) BLIS_PRE_CTRSM_SMALL_3M_2N(AlphaVal,b11,cs_b) @@ -43398,7 +46312,7 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB else if(1 == n_rem) { ///GEMM code begins/// - BLIS_CTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b, + BLIS_CTRSM_SMALL_GEMM_3mx1n(a10,b01,cs_b, p_lda,k_iter) BLIS_PRE_CTRSM_SMALL_3M_1N(AlphaVal,b11,cs_b) @@ -43473,15 +46387,17 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB k_iter = (m - m_rem); BLIS_SET_S_YMM_REG_ZEROS - BLIS_CTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) + BLIS_CTRSM_SMALL_GEMM_2mx3n(a10,b01,cs_b,p_lda,k_iter) ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); - ymm0 = _mm256_loadu_ps((float const *)(b11)); - ymm1 = _mm256_loadu_ps((float const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_ps((float const *)(b11 + cs_b *2)); - + xmm0 = _mm_loadu_ps((float const *)(b11)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b)); + ymm1 = _mm256_insertf128_ps(ymm1, xmm0, 0); + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b*2)); + ymm2 = _mm256_insertf128_ps(ymm2, xmm0, 0); ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); @@ -43537,7 +46453,7 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB if(2 == n_rem) { ///GEMM code begins/// - BLIS_CTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b, + BLIS_CTRSM_SMALL_GEMM_2mx2n(a10,b01,cs_b, p_lda,k_iter) BLIS_PRE_CTRSM_SMALL_2M_2N(AlphaVal,b11,cs_b) @@ -43554,7 +46470,7 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB else if(1 == n_rem) { ///GEMM code begins/// - BLIS_CTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b, + BLIS_CTRSM_SMALL_GEMM_2mx1n(a10,b01,cs_b, p_lda,k_iter) BLIS_PRE_CTRSM_SMALL_2M_1N(AlphaVal,b11,cs_b) @@ -43627,15 +46543,17 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB k_iter = (m - m_rem); BLIS_SET_S_YMM_REG_ZEROS - BLIS_CTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) + BLIS_CTRSM_SMALL_GEMM_1mx3n(a10,b01,cs_b,p_lda,k_iter) ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); - - ymm0 = _mm256_loadu_ps((float const *)(b11)); - ymm1 = _mm256_loadu_ps((float const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_ps((float const *)(b11 + cs_b *2)); - + + xmm0 = _mm_loadl_pi(xmm0,(__m64 *)(b11)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + xmm0 = _mm_loadl_pi(xmm0,(__m64 *)(b11 + cs_b)); + ymm1 = _mm256_insertf128_ps(ymm1, xmm0, 0); + xmm0 = _mm_loadl_pi(xmm0,(__m64 *)(b11 + cs_b*2)); + ymm2 = _mm256_insertf128_ps(ymm2, xmm0, 0); ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); @@ -43692,7 +46610,7 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB if(2 == n_rem) { ///GEMM code begins/// - BLIS_CTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b, + BLIS_CTRSM_SMALL_GEMM_1mx2n(a10,b01,cs_b, p_lda,k_iter) BLIS_PRE_CTRSM_SMALL_1M_2N(AlphaVal,b11, @@ -43712,7 +46630,7 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB else if(1 == n_rem) { ///GEMM code begins/// - BLIS_CTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b, 
+ BLIS_CTRSM_SMALL_GEMM_1mx1n(a10,b01,cs_b, p_lda,k_iter) BLIS_PRE_CTRSM_SMALL_1M_1N(AlphaVal,b11, @@ -43794,6 +46712,7 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB __m256 ymm16, ymm17, ymm18, ymm19; __m128 xmm0, xmm1, xmm2; + __m128 xmm5; gint_t required_packing_A = 1; mem_t local_mem_buf_A_s = {0}; @@ -44121,7 +47040,7 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_CTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_CTRSM_SMALL_GEMM_3nx3m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha BLIS_PRE_CTRSM_SMALL_3x3(AlphaVal,b11,cs_b) @@ -44229,7 +47148,7 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_CTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_CTRSM_SMALL_GEMM_3nx2m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha BLIS_PRE_CTRSM_SMALL_3x2(AlphaVal,b11,cs_b) @@ -44328,7 +47247,7 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_CTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_CTRSM_SMALL_GEMM_3nx1m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha BLIS_PRE_CTRSM_SMALL_3x1(AlphaVal,b11,cs_b) @@ -44675,7 +47594,7 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_CTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_CTRSM_SMALL_GEMM_2nx3m(a01,b10,cs_b,p_lda,k_iter) ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); @@ -44759,7 +47678,7 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_CTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_CTRSM_SMALL_GEMM_2nx2m(a01,b10,cs_b,p_lda,k_iter) ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); @@ -44833,7 +47752,7 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_CTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_CTRSM_SMALL_GEMM_2nx1m(a01,b10,cs_b,p_lda,k_iter) ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); @@ -45065,7 +47984,7 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_CTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_CTRSM_SMALL_GEMM_1nx3m(a01,b10,cs_b,p_lda,k_iter) ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); @@ -45107,7 +48026,7 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_CTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_CTRSM_SMALL_GEMM_1nx2m(a01,b10,cs_b,p_lda,k_iter) ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); @@ -45145,7 +48064,7 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_CTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_CTRSM_SMALL_GEMM_1nx1m(a01,b10,cs_b,p_lda,k_iter) ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); @@ -45235,6 +48154,7 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB __m256 ymm16, ymm17, ymm18, ymm19; __m128 xmm0, xmm1, xmm2; + __m128 xmm5; gint_t required_packing_A = 1; mem_t local_mem_buf_A_s = {0}; @@ 
-45556,354 +48476,325 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB _mm256_storeu_ps((float *)b11, ymm8); _mm256_storeu_ps((float *)(b11 + cs_b), ymm10); _mm256_storeu_ps((float *)(b11 + cs_b * 2), ymm12); - - m_rem -= 4; - i += 4; - } - if(m_rem == 3) - { - - a01 = D_A_pack; - a11 = L + j*cs_a + j*rs_a; - b10 = B + i; - b11 = B + i + j*cs_b; - - k_iter = j; - - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_S_YMM_REG_ZEROS - - ///GEMM implementation starts/// - BLIS_CTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) - - // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_CTRSM_SMALL_3x3(AlphaVal,b11,cs_b) - - ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); - ymm1 = _mm256_permute_ps(ymm1, 0x44); -#ifndef BLIS_ENABLE_TRSM_PREINVERSION - BLIS_CTRSM_DIV(ymm8) -#else - BLIS_CTRSM_MUL(ymm8) -#endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*1) ); - ymm2 = _mm256_permute_ps(ymm2, 0x44); - if(conjtransa) - { - ymm2 = _mm256_mul_ps(ymm2, ymm18); - } - - //extract a11 - //(ROw1): FMA operations - //For ymm8 - ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); - ymm1 = _mm256_shuffle_ps(ymm8, ymm8, 0xA0); - ymm16 = _mm256_shuffle_ps(ymm8, ymm8,0xF5); - ymm16 = _mm256_mul_ps(ymm16, ymm3); - ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); - ymm10 = _mm256_sub_ps(ymm10,ymm16); - - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*2) ); - ymm2 = _mm256_permute_ps(ymm2, 0x44); - if(conjtransa) - { - ymm2 = _mm256_mul_ps(ymm2, ymm18); - } - - //For ymm8 - ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); - ymm1 = _mm256_shuffle_ps(ymm8, ymm8, 0xA0); - ymm16 = _mm256_shuffle_ps(ymm8, ymm8,0xF5); - ymm16 = _mm256_mul_ps(ymm16, ymm3); - ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); - ymm12 = _mm256_sub_ps(ymm12,ymm16); - - - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); - ymm1 = _mm256_permute_ps(ymm1, 0x44); -#ifndef BLIS_ENABLE_TRSM_PREINVERSION - BLIS_CTRSM_DIV(ymm10) -#else - BLIS_CTRSM_MUL(ymm10) -#endif - - - a11 += cs_a; - - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*2) ); - ymm2 = _mm256_permute_ps(ymm2, 0x44); - if(conjtransa) - { - ymm2 = _mm256_mul_ps(ymm2, ymm18); - } - - //For ymm9 - ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); - ymm1 = _mm256_shuffle_ps(ymm10, ymm10, 0xA0); - ymm16 = _mm256_shuffle_ps(ymm10, ymm10,0xF5); - ymm16 = _mm256_mul_ps(ymm16, ymm3); - ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); - ymm12 = _mm256_sub_ps(ymm12,ymm16); - - - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); - ymm1 = _mm256_permute_ps(ymm1, 0x44); - -#ifndef BLIS_ENABLE_TRSM_PREINVERSION - BLIS_CTRSM_DIV(ymm12) -#else - BLIS_CTRSM_MUL(ymm12) -#endif - -/* ymm0 = _mm256_loadu_ps((float const *)b11); - ymm1 = _mm256_loadu_ps((float const *)(b11 + cs_b)); - ymm2 = _mm256_loadu_ps((float const *)(b11 + cs_b * 2)); - ymm8 = _mm256_blend_ps(ymm8, ymm0, 0xC0); - ymm10 = _mm256_blend_ps(ymm10, ymm1, 0xC0); - ymm12 = _mm256_blend_ps(ymm12, ymm2, 0xC0); - _mm256_storeu_ps((float *)b11, ymm8); - _mm256_storeu_ps((float *)(b11 + cs_b), ymm10); - _mm256_storeu_ps((float *)(b11 + cs_b * 2), ymm12);*/ - xmm0 = _mm256_extractf128_ps(ymm8, 0); - xmm1 = _mm256_extractf128_ps(ymm8, 1); - _mm_storeu_ps((float *)(b11), xmm0); - _mm_storel_pi((__m64 *)(b11 + 2), xmm1); - xmm0 = _mm256_extractf128_ps(ymm10, 0); - xmm1 = _mm256_extractf128_ps(ymm10, 1); - _mm_storeu_ps((float *)(b11 + cs_b), xmm0); - _mm_storel_pi((__m64 *)(b11 + cs_b + 2), xmm1); - xmm0 = 
_mm256_extractf128_ps(ymm12, 0); - xmm1 = _mm256_extractf128_ps(ymm12, 1); - _mm_storeu_ps((float *)(b11 + cs_b * 2), xmm0); - _mm_storel_pi((__m64 *)(b11 + cs_b * 2 + 2), xmm1); - - m_rem -= 3; - i += 3; - - } - if(m_rem == 2) - { - - a01 = D_A_pack; - a11 = L + j*cs_a + j*rs_a; - b10 = B + i; - b11 = B + i + j*cs_b; - - k_iter = j; - - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_S_YMM_REG_ZEROS - - ///GEMM implementation starts/// - BLIS_CTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) - - // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_CTRSM_SMALL_3x2(AlphaVal,b11,cs_b) - - ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); - ymm1 = _mm256_permute_ps(ymm1, 0x44); -#ifndef BLIS_ENABLE_TRSM_PREINVERSION - BLIS_CTRSM_DIV(ymm8) -#else - BLIS_CTRSM_MUL(ymm8) -#endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*1) ); - ymm2 = _mm256_permute_ps(ymm2, 0x44); - if(conjtransa) - { - ymm2 = _mm256_mul_ps(ymm2, ymm18); - } - - //extract a11 - //(ROw1): FMA operations - //For ymm8 - ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); - ymm1 = _mm256_shuffle_ps(ymm8, ymm8, 0xA0); - ymm16 = _mm256_shuffle_ps(ymm8, ymm8,0xF5); - ymm16 = _mm256_mul_ps(ymm16, ymm3); - ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); - ymm10 = _mm256_sub_ps(ymm10,ymm16); - - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*2) ); - ymm2 = _mm256_permute_ps(ymm2, 0x44); - if(conjtransa) - { - ymm2 = _mm256_mul_ps(ymm2, ymm18); - } - - //For ymm8 - ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); - ymm1 = _mm256_shuffle_ps(ymm8, ymm8, 0xA0); - ymm16 = _mm256_shuffle_ps(ymm8, ymm8,0xF5); - ymm16 = _mm256_mul_ps(ymm16, ymm3); - ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); - ymm12 = _mm256_sub_ps(ymm12,ymm16); - - - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); - ymm1 = _mm256_permute_ps(ymm1, 0x44); -#ifndef BLIS_ENABLE_TRSM_PREINVERSION - BLIS_CTRSM_DIV(ymm10) -#else - BLIS_CTRSM_MUL(ymm10) -#endif - - - a11 += cs_a; - - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*2) ); - ymm2 = _mm256_permute_ps(ymm2, 0x44); - if(conjtransa) - { - ymm2 = _mm256_mul_ps(ymm2, ymm18); - } - - //For ymm9 - ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); - ymm1 = _mm256_shuffle_ps(ymm10, ymm10, 0xA0); - ymm16 = _mm256_shuffle_ps(ymm10, ymm10,0xF5); - ymm16 = _mm256_mul_ps(ymm16, ymm3); - ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); - ymm12 = _mm256_sub_ps(ymm12,ymm16); - - - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); - ymm1 = _mm256_permute_ps(ymm1, 0x44); - -#ifndef BLIS_ENABLE_TRSM_PREINVERSION - BLIS_CTRSM_DIV(ymm12) -#else - BLIS_CTRSM_MUL(ymm12) -#endif - -/* ymm0 = _mm256_loadu_ps((float const *)b11); - ymm1 = _mm256_loadu_ps((float const *)(b11 + cs_b)); - ymm2 = _mm256_loadu_ps((float const *)(b11 + cs_b * 2)); - ymm8 = _mm256_blend_ps(ymm8, ymm0, 0xF0); - ymm10 = _mm256_blend_ps(ymm10, ymm1, 0xF0); - ymm12 = _mm256_blend_ps(ymm12, ymm2, 0xF0); - _mm256_storeu_ps((float *)b11, ymm8); - _mm256_storeu_ps((float *)(b11 + cs_b), ymm10); - _mm256_storeu_ps((float *)(b11 + cs_b * 2), ymm12); -*/ - xmm0 = _mm256_extractf128_ps(ymm8, 0); - _mm_storeu_ps((float *)(b11), xmm0); - xmm0 = _mm256_extractf128_ps(ymm10, 0); - _mm_storeu_ps((float *)(b11 + cs_b), xmm0); - xmm0 = _mm256_extractf128_ps(ymm12, 0); - _mm_storeu_ps((float *)(b11 + cs_b * 2), xmm0); - - - m_rem -= 2; - i += 2; - } - if(m_rem == 1) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j*rs_a; - b10 = B + i; - b11 = B + i + 
j*cs_b; - - k_iter = j; - - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_S_YMM_REG_ZEROS - - ///GEMM implementation starts/// - BLIS_CTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) - - // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_CTRSM_SMALL_3x1(AlphaVal,b11,cs_b) - - ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); - ymm1 = _mm256_permute_ps(ymm1, 0x44); -#ifndef BLIS_ENABLE_TRSM_PREINVERSION - BLIS_CTRSM_DIV(ymm8) -#else - BLIS_CTRSM_MUL(ymm8) -#endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*1) ); - ymm2 = _mm256_permute_ps(ymm2, 0x44); - if(conjtransa) - { - ymm2 = _mm256_mul_ps(ymm2, ymm18); - } - - //extract a11 - //(ROw1): FMA operations - //For ymm8 - ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); - ymm1 = _mm256_shuffle_ps(ymm8, ymm8, 0xA0); - ymm16 = _mm256_shuffle_ps(ymm8, ymm8,0xF5); - ymm16 = _mm256_mul_ps(ymm16, ymm3); - ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); - ymm10 = _mm256_sub_ps(ymm10,ymm16); - - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*2) ); - ymm2 = _mm256_permute_ps(ymm2, 0x44); - if(conjtransa) - { - ymm2 = _mm256_mul_ps(ymm2, ymm18); - } - - //For ymm8 - ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); - ymm1 = _mm256_shuffle_ps(ymm8, ymm8, 0xA0); - ymm16 = _mm256_shuffle_ps(ymm8, ymm8,0xF5); - ymm16 = _mm256_mul_ps(ymm16, ymm3); - ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); - ymm12 = _mm256_sub_ps(ymm12,ymm16); - - - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); - ymm1 = _mm256_permute_ps(ymm1, 0x44); -#ifndef BLIS_ENABLE_TRSM_PREINVERSION - BLIS_CTRSM_DIV(ymm10) -#else - BLIS_CTRSM_MUL(ymm10) -#endif - - - a11 += cs_a; - - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*2) ); - ymm2 = _mm256_permute_ps(ymm2, 0x44); - if(conjtransa) - { - ymm2 = _mm256_mul_ps(ymm2, ymm18); - } - - //For ymm9 - ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); - ymm1 = _mm256_shuffle_ps(ymm10, ymm10, 0xA0); - ymm16 = _mm256_shuffle_ps(ymm10, ymm10,0xF5); - ymm16 = _mm256_mul_ps(ymm16, ymm3); - ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); - ymm12 = _mm256_sub_ps(ymm12,ymm16); - - - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); - ymm1 = _mm256_permute_ps(ymm1, 0x44); - -#ifndef BLIS_ENABLE_TRSM_PREINVERSION - BLIS_CTRSM_DIV(ymm12) -#else - BLIS_CTRSM_MUL(ymm12) -#endif - -/* ymm0 = _mm256_loadu_ps((float const *)b11); - ymm1 = _mm256_loadu_ps((float const *)(b11 + cs_b)); - ymm2 = _mm256_loadu_ps((float const *)(b11 + cs_b * 2)); - ymm8 = _mm256_blend_ps(ymm8, ymm0, 0xFC); - ymm10 = _mm256_blend_ps(ymm10, ymm1, 0xFC); - ymm12 = _mm256_blend_ps(ymm12, ymm2, 0xFC); - _mm256_storeu_ps((float *)b11, ymm8); - _mm256_storeu_ps((float *)(b11 + cs_b), ymm10); - _mm256_storeu_ps((float *)(b11 + cs_b * 2), ymm12); -*/ + + m_rem -= 4; + i += 4; + } + if(m_rem == 3) + { + + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; + b10 = B + i; + b11 = B + i + j*cs_b; + + k_iter = j; + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_CTRSM_SMALL_GEMM_3nx3m(a01,b10,cs_b,p_lda,k_iter) + + // Load b11 of size 4x6 and multiply with alpha + BLIS_PRE_CTRSM_SMALL_3x3(AlphaVal,b11,cs_b) + + ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm8) +#else + BLIS_CTRSM_MUL(ymm8) 
+#endif + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*1) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //extract a11 + //(ROw1): FMA operations + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm8, ymm8, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm8, ymm8,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm10 = _mm256_sub_ps(ymm10,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*2) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm8, ymm8, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm8, ymm8,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm12 = _mm256_sub_ps(ymm12,ymm16); + + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm10) +#else + BLIS_CTRSM_MUL(ymm10) +#endif + + + a11 += cs_a; + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*2) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm9 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm10, ymm10, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm10, ymm10,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm12 = _mm256_sub_ps(ymm12,ymm16); + + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); + +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm12) +#else + BLIS_CTRSM_MUL(ymm12) +#endif + + xmm0 = _mm256_extractf128_ps(ymm8, 0); + xmm1 = _mm256_extractf128_ps(ymm8, 1); + _mm_storeu_ps((float *)(b11), xmm0); + _mm_storel_pi((__m64 *)(b11 + 2), xmm1); + xmm0 = _mm256_extractf128_ps(ymm10, 0); + xmm1 = _mm256_extractf128_ps(ymm10, 1); + _mm_storeu_ps((float *)(b11 + cs_b), xmm0); + _mm_storel_pi((__m64 *)(b11 + cs_b + 2), xmm1); + xmm0 = _mm256_extractf128_ps(ymm12, 0); + xmm1 = _mm256_extractf128_ps(ymm12, 1); + _mm_storeu_ps((float *)(b11 + cs_b * 2), xmm0); + _mm_storel_pi((__m64 *)(b11 + cs_b * 2 + 2), xmm1); + + m_rem -= 3; + i += 3; + + } + if(m_rem == 2) + { + + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; + b10 = B + i; + b11 = B + i + j*cs_b; + + k_iter = j; + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_CTRSM_SMALL_GEMM_3nx2m(a01,b10,cs_b,p_lda,k_iter) + + // Load b11 of size 4x6 and multiply with alpha + BLIS_PRE_CTRSM_SMALL_3x2(AlphaVal,b11,cs_b) + + ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm8) +#else + BLIS_CTRSM_MUL(ymm8) +#endif + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*1) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //extract a11 + //(ROw1): FMA operations + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm8, ymm8, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm8, ymm8,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm10 = 
_mm256_sub_ps(ymm10,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*2) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm8, ymm8, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm8, ymm8,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm12 = _mm256_sub_ps(ymm12,ymm16); + + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm10) +#else + BLIS_CTRSM_MUL(ymm10) +#endif + + + a11 += cs_a; + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*2) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm9 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm10, ymm10, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm10, ymm10,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm12 = _mm256_sub_ps(ymm12,ymm16); + + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); + +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm12) +#else + BLIS_CTRSM_MUL(ymm12) +#endif + + xmm0 = _mm256_extractf128_ps(ymm8, 0); + _mm_storeu_ps((float *)(b11), xmm0); + xmm0 = _mm256_extractf128_ps(ymm10, 0); + _mm_storeu_ps((float *)(b11 + cs_b), xmm0); + xmm0 = _mm256_extractf128_ps(ymm12, 0); + _mm_storeu_ps((float *)(b11 + cs_b * 2), xmm0); + + + m_rem -= 2; + i += 2; + } + if(m_rem == 1) + { + a01 = D_A_pack; + a11 = L + j*cs_a + j*rs_a; + b10 = B + i; + b11 = B + i + j*cs_b; + + k_iter = j; + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_S_YMM_REG_ZEROS + + ///GEMM implementation starts/// + BLIS_CTRSM_SMALL_GEMM_3nx1m(a01,b10,cs_b,p_lda,k_iter) + + // Load b11 of size 4x6 and multiply with alpha + BLIS_PRE_CTRSM_SMALL_3x1(AlphaVal,b11,cs_b) + + ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm8) +#else + BLIS_CTRSM_MUL(ymm8) +#endif + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*1) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //extract a11 + //(ROw1): FMA operations + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm8, ymm8, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm8, ymm8,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm10 = _mm256_sub_ps(ymm10,ymm16); + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*2) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm8 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm8, ymm8, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm8, ymm8,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm12 = _mm256_sub_ps(ymm12,ymm16); + + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm10) +#else + BLIS_CTRSM_MUL(ymm10) +#endif + + + a11 += cs_a; + + ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + 
rs_a*2) ); + ymm2 = _mm256_permute_ps(ymm2, 0x44); + if(conjtransa) + { + ymm2 = _mm256_mul_ps(ymm2, ymm18); + } + + //For ymm9 + ymm3 = _mm256_shuffle_ps(ymm2, ymm2, 0x11); + ymm1 = _mm256_shuffle_ps(ymm10, ymm10, 0xA0); + ymm16 = _mm256_shuffle_ps(ymm10, ymm10,0xF5); + ymm16 = _mm256_mul_ps(ymm16, ymm3); + ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); + ymm12 = _mm256_sub_ps(ymm12,ymm16); + + + ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); + ymm1 = _mm256_permute_ps(ymm1, 0x44); + +#ifndef BLIS_ENABLE_TRSM_PREINVERSION + BLIS_CTRSM_DIV(ymm12) +#else + BLIS_CTRSM_MUL(ymm12) +#endif + xmm0 = _mm256_extractf128_ps(ymm8, 0); xmm1 = _mm256_extractf128_ps(ymm10, 0); xmm2 = _mm256_extractf128_ps(ymm12, 0); @@ -46177,7 +49068,7 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_CTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_CTRSM_SMALL_GEMM_2nx3m(a01,b10,cs_b,p_lda,k_iter) ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); @@ -46263,7 +49154,7 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_CTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_CTRSM_SMALL_GEMM_2nx2m(a01,b10,cs_b,p_lda,k_iter) ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); @@ -46340,7 +49231,7 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_CTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_CTRSM_SMALL_GEMM_2nx1m(a01,b10,cs_b,p_lda,k_iter) ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); @@ -46573,7 +49464,7 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_CTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_CTRSM_SMALL_GEMM_1nx3m(a01,b10,cs_b,p_lda,k_iter) ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); @@ -46618,7 +49509,7 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_CTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_CTRSM_SMALL_GEMM_1nx2m(a01,b10,cs_b,p_lda,k_iter) ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); @@ -46659,7 +49550,7 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB BLIS_SET_S_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_CTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_CTRSM_SMALL_GEMM_1nx1m(a01,b10,cs_b,p_lda,k_iter) ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); From d925ebeb06fefae9d77f1fbcd9e0029d601aaa29 Mon Sep 17 00:00:00 2001 From: Chandrashekara K R Date: Tue, 23 Aug 2022 14:58:26 +0530 Subject: [PATCH 194/243] CBLAS/BLAS interface decoupling for level 3 APIs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ->In BLIS the cblas interface is implemented as a wrapper around the blas interface. For example the CBLAS api ‘cblas_dgemm’ internally invokes BLAS API ‘dgemm_’. ->If the end user wants to use the different libraries for CBLAS and BLAS, current implantation of BLIS doesn’t allow it and may result in recursion ->This change separate the CBLAS and BLAS implantation by adding and additional level of abstraction. 
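(Illustrative sketch, not part of the patch: the expanded names below are assumptions — PASTEF77S(d,gemm) is taken to expand to a new internal symbol such as dgemm_blis_impl, next to the existing BLAS symbol dgemm_.) For the double-precision GEMM entry point, the code generated by the macro hunks below then looks roughly like this:

    /* dgemm_blis_impl (hypothetical expanded name) now holds the former
       body of dgemm_: the dimension/parameter checks, object wrapping and
       the call into the BLIS gemm front-end. */

    /* BLAS wrapper: reduced to a thin forwarder to the shared implementation. */
    void dgemm_
         (
           const f77_char* transa, const f77_char* transb,
           const f77_int*  m, const f77_int* n, const f77_int* k,
           const double*   alpha,
           const double*   a, const f77_int* lda,
           const double*   b, const f77_int* ldb,
           const double*   beta,
           double*         c, const f77_int* ldc
         )
    {
        dgemm_blis_impl( transa, transb, m, n, k,
                         alpha, a, lda, b, ldb, beta, c, ldc );
    }

    /* The CBLAS wrapper (cblas_dgemm) is likewise retargeted at the
       implementation function rather than at dgemm_, so a CBLAS call can no
       longer recurse through a foreign dgemm_ symbol when CBLAS and BLAS come
       from different libraries. */
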
The implementation of the API is moved to the new function which is invoked directly from the CBLAS and BLAS wrappers. AMD-Internal: [SWLCSG-1477] Change-Id: I6218a3e81060fc8045f4de0ace87f708465dfae5 --- frame/compat/bla_gemm.c | 41 +++++- frame/compat/bla_gemm.h | 13 ++ frame/compat/bla_gemm3m.c | 40 +++++- frame/compat/bla_gemm3m.h | 15 +- frame/compat/bla_gemm_amd.c | 82 ++++++++++- frame/compat/bla_gemmt.c | 40 +++++- frame/compat/bla_gemmt.h | 15 +- frame/compat/bla_hemm.c | 38 ++++- frame/compat/bla_hemm.h | 13 ++ frame/compat/bla_her2k.c | 40 +++++- frame/compat/bla_her2k.h | 13 ++ frame/compat/bla_herk.c | 38 ++++- frame/compat/bla_herk.h | 12 ++ frame/compat/bla_symm.c | 38 ++++- frame/compat/bla_symm.h | 13 ++ frame/compat/bla_syr2k.c | 38 ++++- frame/compat/bla_syr2k.h | 13 ++ frame/compat/bla_syrk.c | 36 ++++- frame/compat/bla_syrk.h | 12 ++ frame/compat/bla_trmm.c | 38 ++++- frame/compat/bla_trmm.h | 13 ++ frame/compat/bla_trsm.c | 38 ++++- frame/compat/bla_trsm.h | 13 ++ frame/compat/bla_trsm_amd.c | 108 ++++++++++++-- frame/compat/cblas/src/cblas_f77.h | 74 +++++----- frame/include/bli_macro_defs.h | 4 +- frame/util/bli_util_api_wrap.c | 218 ++++++++++++++--------------- 27 files changed, 846 insertions(+), 210 deletions(-) diff --git a/frame/compat/bla_gemm.c b/frame/compat/bla_gemm.c index d3601952ba..ae14196e89 100644 --- a/frame/compat/bla_gemm.c +++ b/frame/compat/bla_gemm.c @@ -44,7 +44,7 @@ #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ -void PASTEF77(ch,blasname) \ +void PASTEF77S(ch,blasname) \ ( \ const f77_char* transa, \ const f77_char* transb, \ @@ -136,14 +136,30 @@ void PASTEF77(ch,blasname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ /* Finalize BLIS. */ \ bli_finalize_auto(); \ -} +} \ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* transa, \ + const f77_char* transb, \ + const f77_int* m, \ + const f77_int* n, \ + const f77_int* k, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype* b, const f77_int* ldb, \ + const ftype* beta, \ + ftype* c, const f77_int* ldc \ + ) \ +{ \ + PASTEF77S(ch,blasname) ( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc ); \ +} \ #else #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ -void PASTEF77(ch,blasname) \ +void PASTEF77S(ch,blasname) \ ( \ const f77_char* transa, \ const f77_char* transb, \ @@ -318,7 +334,24 @@ void PASTEF77(ch,blasname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ /* Finalize BLIS. 
*/ \ bli_finalize_auto(); \ -} +} \ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* transa, \ + const f77_char* transb, \ + const f77_int* m, \ + const f77_int* n, \ + const f77_int* k, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype* b, const f77_int* ldb, \ + const ftype* beta, \ + ftype* c, const f77_int* ldc \ + ) \ +{ \ + PASTEF77S(ch,blasname) ( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc ); \ +} \ + #endif #ifdef BLIS_ENABLE_BLAS diff --git a/frame/compat/bla_gemm.h b/frame/compat/bla_gemm.h index c9ea83149a..d8fe6ddb94 100644 --- a/frame/compat/bla_gemm.h +++ b/frame/compat/bla_gemm.h @@ -41,6 +41,19 @@ #define GENTPROT( ftype, ch, blasname ) \ \ BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ + ( \ + const f77_char* transa, \ + const f77_char* transb, \ + const f77_int* m, \ + const f77_int* n, \ + const f77_int* k, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype* b, const f77_int* ldb, \ + const ftype* beta, \ + ftype* c, const f77_int* ldc \ + ); \ +BLIS_EXPORT_BLAS void PASTEF77S(ch,blasname) \ ( \ const f77_char* transa, \ const f77_char* transb, \ diff --git a/frame/compat/bla_gemm3m.c b/frame/compat/bla_gemm3m.c index 665c8643dd..4ecbba5551 100644 --- a/frame/compat/bla_gemm3m.c +++ b/frame/compat/bla_gemm3m.c @@ -44,7 +44,7 @@ #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ -void PASTEF77(ch,blasname) \ +void PASTEF77S(ch,blasname) \ ( \ const f77_char* transa, \ const f77_char* transb, \ @@ -131,14 +131,30 @@ void PASTEF77(ch,blasname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ /* Finalize BLIS. */ \ bli_finalize_auto(); \ -} +} \ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* transa, \ + const f77_char* transb, \ + const f77_int* m, \ + const f77_int* n, \ + const f77_int* k, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype* b, const f77_int* ldb, \ + const ftype* beta, \ + ftype* c, const f77_int* ldc \ + ) \ +{ \ + PASTEF77S(ch,blasname) ( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc ); \ +} \ #else #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ -void PASTEF77(ch,blasname) \ +void PASTEF77S(ch,blasname) \ ( \ const f77_char* transa, \ const f77_char* transb, \ @@ -240,7 +256,23 @@ void PASTEF77(ch,blasname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) \ /* Finalize BLIS. */ \ bli_finalize_auto(); \ -} +} \ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* transa, \ + const f77_char* transb, \ + const f77_int* m, \ + const f77_int* n, \ + const f77_int* k, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype* b, const f77_int* ldb, \ + const ftype* beta, \ + ftype* c, const f77_int* ldc \ + ) \ +{ \ + PASTEF77S(ch,blasname) ( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc ); \ +} \ #endif diff --git a/frame/compat/bla_gemm3m.h b/frame/compat/bla_gemm3m.h index 1063d85c03..d64e3f199f 100644 --- a/frame/compat/bla_gemm3m.h +++ b/frame/compat/bla_gemm3m.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -40,6 +40,19 @@ #define GENTPROT( ftype, ch, blasname ) \ \ BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ + ( \ + const f77_char* transa, \ + const f77_char* transb, \ + const f77_int* m, \ + const f77_int* n, \ + const f77_int* k, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype* b, const f77_int* ldb, \ + const ftype* beta, \ + ftype* c, const f77_int* ldc \ + ); \ +BLIS_EXPORT_BLAS void PASTEF77S(ch,blasname) \ ( \ const f77_char* transa, \ const f77_char* transb, \ diff --git a/frame/compat/bla_gemm_amd.c b/frame/compat/bla_gemm_amd.c index 2a9dcb99db..6bc0dcd557 100644 --- a/frame/compat/bla_gemm_amd.c +++ b/frame/compat/bla_gemm_amd.c @@ -44,7 +44,7 @@ #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ -void PASTEF77(ch,blasname) \ +void PASTEF77S(ch,blasname) \ ( \ const f77_char* transa, \ const f77_char* transb, \ @@ -136,14 +136,32 @@ void PASTEF77(ch,blasname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ /* Finalize BLIS. */ \ bli_finalize_auto(); \ -} +} \ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* transa, \ + const f77_char* transb, \ + const f77_int* m, \ + const f77_int* n, \ + const f77_int* k, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype* b, const f77_int* ldb, \ + const ftype* beta, \ + ftype* c, const f77_int* ldc \ + ) \ +{ \ +\ + PASTEF77S(ch,blasname) ( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc ); \ +} \ #else #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ -void PASTEF77(ch,blasname) \ +void PASTEF77S(ch,blasname) \ ( \ const f77_char* transa, \ const f77_char* transb, \ @@ -318,11 +336,30 @@ void PASTEF77(ch,blasname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ /* Finalize BLIS. 
*/ \ bli_finalize_auto(); \ -} +} \ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* transa, \ + const f77_char* transb, \ + const f77_int* m, \ + const f77_int* n, \ + const f77_int* k, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype* b, const f77_int* ldb, \ + const ftype* beta, \ + ftype* c, const f77_int* ldc \ + ) \ +{ \ +\ + PASTEF77S(ch,blasname) ( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc ); \ +} \ + #endif #ifdef BLIS_ENABLE_BLAS -void dgemm_ +void dgemm_blis_impl ( const f77_char* transa, const f77_char* transb, @@ -658,7 +695,24 @@ void dgemm_ bli_finalize_auto(); } // end of dgemm_ -void zgemm_ +void dgemm_ +( + const f77_char* transa, + const f77_char* transb, + const f77_int* m, + const f77_int* n, + const f77_int* k, + const double* alpha, + const double* a, const f77_int* lda, + const double* b, const f77_int* ldb, + const double* beta, + double* c, const f77_int* ldc +) +{ + dgemm_blis_impl(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void zgemm_blis_impl ( const f77_char* transa, const f77_char* transb, @@ -915,6 +969,22 @@ void zgemm_ bli_finalize_auto(); }// end of zgemm_ +void zgemm_ + ( + const f77_char* transa, + const f77_char* transb, + const f77_int* m, + const f77_int* n, + const f77_int* k, + const dcomplex* alpha, + const dcomplex* a, const f77_int* lda, + const dcomplex* b, const f77_int* ldb, + const dcomplex* beta, + dcomplex* c, const f77_int* ldc + ) +{ + zgemm_blis_impl(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); +} INSERT_GENTFUNC_BLAS_SC( gemm, gemm ) diff --git a/frame/compat/bla_gemmt.c b/frame/compat/bla_gemmt.c index 7abad40acf..f8f6fa2de6 100644 --- a/frame/compat/bla_gemmt.c +++ b/frame/compat/bla_gemmt.c @@ -44,7 +44,7 @@ #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ -void PASTEF77(ch,blasname) \ +void PASTEF77S(ch,blasname) \ ( \ const f77_char* uploc, \ const f77_char* transa, \ @@ -134,14 +134,30 @@ void PASTEF77(ch,blasname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ /* Finalize BLIS. */ \ bli_finalize_auto(); \ -} +} \ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* uploc, \ + const f77_char* transa, \ + const f77_char* transb, \ + const f77_int* n, \ + const f77_int* k, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype* b, const f77_int* ldb, \ + const ftype* beta, \ + ftype* c, const f77_int* ldc \ + ) \ +{ \ + PASTEF77S(ch,blasname) ( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc ); \ +} \ #else #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ -void PASTEF77(ch,blasname) \ +void PASTEF77S(ch,blasname) \ ( \ const f77_char* uploc, \ const f77_char* transa, \ @@ -247,7 +263,23 @@ void PASTEF77(ch,blasname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) \ /* Finalize BLIS. 
*/ \ bli_finalize_auto(); \ -} +} \ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* uploc, \ + const f77_char* transa, \ + const f77_char* transb, \ + const f77_int* n, \ + const f77_int* k, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype* b, const f77_int* ldb, \ + const ftype* beta, \ + ftype* c, const f77_int* ldc \ + ) \ +{ \ + PASTEF77S(ch,blasname) ( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc ); \ +} \ #endif diff --git a/frame/compat/bla_gemmt.h b/frame/compat/bla_gemmt.h index 8043d68291..d4efb995ce 100644 --- a/frame/compat/bla_gemmt.h +++ b/frame/compat/bla_gemmt.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2020, Advanced Micro Devices, Inc. + Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -40,6 +40,19 @@ #define GENTPROT( ftype, ch, blasname ) \ \ BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ + ( \ + const f77_char* uploc, \ + const f77_char* transa, \ + const f77_char* transb, \ + const f77_int* n, \ + const f77_int* k, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype* b, const f77_int* ldb, \ + const ftype* beta, \ + ftype* c, const f77_int* ldc \ + ); \ +BLIS_EXPORT_BLAS void PASTEF77S(ch,blasname) \ ( \ const f77_char* uploc, \ const f77_char* transa, \ diff --git a/frame/compat/bla_hemm.c b/frame/compat/bla_hemm.c index 0e003012d2..ed3cbb5178 100644 --- a/frame/compat/bla_hemm.c +++ b/frame/compat/bla_hemm.c @@ -45,7 +45,7 @@ #undef GENTFUNCCO #define GENTFUNCCO( ftype, ftype_r, ch, chr, blasname, blisname ) \ \ -void PASTEF77(ch,blasname) \ +void PASTEF77S(ch,blasname) \ ( \ const f77_char* side, \ const f77_char* uploa, \ @@ -132,14 +132,29 @@ void PASTEF77(ch,blasname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ /* Finalize BLIS. */ \ bli_finalize_auto(); \ -} +} \ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* side, \ + const f77_char* uploa, \ + const f77_int* m, \ + const f77_int* n, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype* b, const f77_int* ldb, \ + const ftype* beta, \ + ftype* c, const f77_int* ldc \ + ) \ +{ \ + PASTEF77S(ch,blasname) ( side, uploa, m, n, alpha, a, lda, b, ldb, beta, c, ldc ); \ + } \ #else #undef GENTFUNCCO #define GENTFUNCCO( ftype, ftype_r, ch, chr, blasname, blisname ) \ \ -void PASTEF77(ch,blasname) \ +void PASTEF77S(ch,blasname) \ ( \ const f77_char* side, \ const f77_char* uploa, \ @@ -248,7 +263,22 @@ void PASTEF77(ch,blasname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ /* Finalize BLIS. */ \ bli_finalize_auto(); \ -} +} \ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* side, \ + const f77_char* uploa, \ + const f77_int* m, \ + const f77_int* n, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype* b, const f77_int* ldb, \ + const ftype* beta, \ + ftype* c, const f77_int* ldc \ + ) \ +{ \ + PASTEF77S(ch,blasname) ( side, uploa, m, n, alpha, a, lda, b, ldb, beta, c, ldc ); \ + } \ #endif diff --git a/frame/compat/bla_hemm.h b/frame/compat/bla_hemm.h index 711877edee..7054be7c90 100644 --- a/frame/compat/bla_hemm.h +++ b/frame/compat/bla_hemm.h @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -40,6 +41,18 @@ #define GENTPROTCO( ftype, ftype_r, ch, chr, blasname ) \ \ BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ + ( \ + const f77_char* side, \ + const f77_char* uploa, \ + const f77_int* m, \ + const f77_int* n, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype* b, const f77_int* ldb, \ + const ftype* beta, \ + ftype* c, const f77_int* ldc \ + ); \ +BLIS_EXPORT_BLAS void PASTEF77S(ch,blasname) \ ( \ const f77_char* side, \ const f77_char* uploa, \ diff --git a/frame/compat/bla_her2k.c b/frame/compat/bla_her2k.c index e21a2cda41..cba6432eb3 100755 --- a/frame/compat/bla_her2k.c +++ b/frame/compat/bla_her2k.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019 - 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -45,7 +45,7 @@ #undef GENTFUNCCO #define GENTFUNCCO( ftype, ftype_r, ch, chr, blasname, blisname ) \ \ -void PASTEF77(ch,blasname) \ +void PASTEF77S(ch,blasname) \ ( \ const f77_char* uploc, \ const f77_char* transa, \ @@ -137,14 +137,29 @@ void PASTEF77(ch,blasname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ /* Finalize BLIS. */ \ bli_finalize_auto(); \ -} +} \ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* uploc, \ + const f77_char* transa, \ + const f77_int* m, \ + const f77_int* k, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype* b, const f77_int* ldb, \ + const ftype_r* beta, \ + ftype* c, const f77_int* ldc \ + ) \ +{ \ + PASTEF77S(ch,blasname) ( uploc, transa, m, k, alpha, a, lda, b, ldb, beta, c, ldc ); \ + } \ #else #undef GENTFUNCCO #define GENTFUNCCO( ftype, ftype_r, ch, chr, blasname, blisname ) \ \ -void PASTEF77(ch,blasname) \ +void PASTEF77S(ch,blasname) \ ( \ const f77_char* uploc, \ const f77_char* transa, \ @@ -258,7 +273,22 @@ void PASTEF77(ch,blasname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ /* Finalize BLIS. */ \ bli_finalize_auto(); \ -} +} \ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* uploc, \ + const f77_char* transa, \ + const f77_int* m, \ + const f77_int* k, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype* b, const f77_int* ldb, \ + const ftype_r* beta, \ + ftype* c, const f77_int* ldc \ + ) \ +{ \ + PASTEF77S(ch,blasname) ( uploc, transa, m, k, alpha, a, lda, b, ldb, beta, c, ldc ); \ + } \ #endif diff --git a/frame/compat/bla_her2k.h b/frame/compat/bla_her2k.h index c771f78d4c..a3fa413027 100644 --- a/frame/compat/bla_her2k.h +++ b/frame/compat/bla_her2k.h @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -40,6 +41,18 @@ #define GENTPROTCO( ftype, ftype_r, ch, chr, blasname ) \ \ BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ + ( \ + const f77_char* uploc, \ + const f77_char* transa, \ + const f77_int* m, \ + const f77_int* k, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype* b, const f77_int* ldb, \ + const ftype_r* beta, \ + ftype* c, const f77_int* ldc \ + ); \ +BLIS_EXPORT_BLAS void PASTEF77S(ch,blasname) \ ( \ const f77_char* uploc, \ const f77_char* transa, \ diff --git a/frame/compat/bla_herk.c b/frame/compat/bla_herk.c index 36188e6a66..b07ee180cc 100755 --- a/frame/compat/bla_herk.c +++ b/frame/compat/bla_herk.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019 - 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -45,7 +45,7 @@ #undef GENTFUNCCO #define GENTFUNCCO( ftype, ftype_r, ch, chr, blasname, blisname ) \ \ -void PASTEF77(ch,blasname) \ +void PASTEF77S(ch,blasname) \ ( \ const f77_char* uploc, \ const f77_char* transa, \ @@ -131,14 +131,28 @@ void PASTEF77(ch,blasname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ /* Finalize BLIS. */ \ bli_finalize_auto(); \ -} +} \ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* uploc, \ + const f77_char* transa, \ + const f77_int* m, \ + const f77_int* k, \ + const ftype_r* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype_r* beta, \ + ftype* c, const f77_int* ldc \ + ) \ +{ \ + PASTEF77S(ch,blasname) ( uploc, transa, m, k, alpha, a, lda, beta, c, ldc ); \ + } \ #else #undef GENTFUNCCO #define GENTFUNCCO( ftype, ftype_r, ch, chr, blasname, blisname ) \ \ -void PASTEF77(ch,blasname) \ +void PASTEF77S(ch,blasname) \ ( \ const f77_char* uploc, \ const f77_char* transa, \ @@ -242,7 +256,21 @@ void PASTEF77(ch,blasname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ /* Finalize BLIS. */ \ bli_finalize_auto(); \ -} +} \ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* uploc, \ + const f77_char* transa, \ + const f77_int* m, \ + const f77_int* k, \ + const ftype_r* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype_r* beta, \ + ftype* c, const f77_int* ldc \ + ) \ +{ \ + PASTEF77S(ch,blasname) ( uploc, transa, m, k, alpha, a, lda, beta, c, ldc ); \ +} \ #endif diff --git a/frame/compat/bla_herk.h b/frame/compat/bla_herk.h index e649a74abb..8ec9183e8f 100644 --- a/frame/compat/bla_herk.h +++ b/frame/compat/bla_herk.h @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -40,6 +41,17 @@ #define GENTPROTCO( ftype, ftype_r, ch, chr, blasname ) \ \ BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ + ( \ + const f77_char* uploc, \ + const f77_char* transa, \ + const f77_int* m, \ + const f77_int* k, \ + const ftype_r* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype_r* beta, \ + ftype* c, const f77_int* ldc \ + ); \ +BLIS_EXPORT_BLAS void PASTEF77S(ch,blasname) \ ( \ const f77_char* uploc, \ const f77_char* transa, \ diff --git a/frame/compat/bla_symm.c b/frame/compat/bla_symm.c index 85aebb435f..7b915a5edb 100755 --- a/frame/compat/bla_symm.c +++ b/frame/compat/bla_symm.c @@ -45,7 +45,7 @@ #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ -void PASTEF77(ch,blasname) \ +void PASTEF77S(ch,blasname) \ ( \ const f77_char* side, \ const f77_char* uploa, \ @@ -131,14 +131,29 @@ void PASTEF77(ch,blasname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ /* Finalize BLIS. */ \ bli_finalize_auto(); \ -} +} \ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* side, \ + const f77_char* uploa, \ + const f77_int* m, \ + const f77_int* n, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype* b, const f77_int* ldb, \ + const ftype* beta, \ + ftype* c, const f77_int* ldc \ + ) \ +{ \ + PASTEF77S(ch,blasname) ( side, uploa, m, n, alpha, a, lda, b, ldb, beta, c, ldc ); \ + } \ #else #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ -void PASTEF77(ch,blasname) \ +void PASTEF77S(ch,blasname) \ ( \ const f77_char* side, \ const f77_char* uploa, \ @@ -246,7 +261,22 @@ void PASTEF77(ch,blasname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ /* Finalize BLIS. */ \ bli_finalize_auto(); \ -} +} \ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* side, \ + const f77_char* uploa, \ + const f77_int* m, \ + const f77_int* n, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype* b, const f77_int* ldb, \ + const ftype* beta, \ + ftype* c, const f77_int* ldc \ + ) \ +{ \ + PASTEF77S(ch,blasname) ( side, uploa, m, n, alpha, a, lda, b, ldb, beta, c, ldc ); \ +} \ #endif diff --git a/frame/compat/bla_symm.h b/frame/compat/bla_symm.h index b186e4b436..f10e1cbb86 100644 --- a/frame/compat/bla_symm.h +++ b/frame/compat/bla_symm.h @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -40,6 +41,18 @@ #define GENTPROT( ftype, ch, blasname ) \ \ BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ + ( \ + const f77_char* side, \ + const f77_char* uploa, \ + const f77_int* m, \ + const f77_int* n, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype* b, const f77_int* ldb, \ + const ftype* beta, \ + ftype* c, const f77_int* ldc \ + ); \ +BLIS_EXPORT_BLAS void PASTEF77S(ch,blasname) \ ( \ const f77_char* side, \ const f77_char* uploa, \ diff --git a/frame/compat/bla_syr2k.c b/frame/compat/bla_syr2k.c index 6a4f31b969..751e008ae4 100644 --- a/frame/compat/bla_syr2k.c +++ b/frame/compat/bla_syr2k.c @@ -45,7 +45,7 @@ #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ -void PASTEF77(ch,blasname) \ +void PASTEF77S(ch,blasname) \ ( \ const f77_char* uploc, \ const f77_char* transa, \ @@ -139,14 +139,29 @@ void PASTEF77(ch,blasname) \ /* Finalize BLIS. */ \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ bli_finalize_auto(); \ -} +} \ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* uploc, \ + const f77_char* transa, \ + const f77_int* m, \ + const f77_int* k, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype* b, const f77_int* ldb, \ + const ftype* beta, \ + ftype* c, const f77_int* ldc \ + ) \ +{ \ + PASTEF77S(ch,blasname) ( uploc, transa, m, k, alpha, a, lda, b, ldb, beta, c, ldc ); \ + } \ #else #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ -void PASTEF77(ch,blasname) \ +void PASTEF77S(ch,blasname) \ ( \ const f77_char* uploc, \ const f77_char* transa, \ @@ -262,7 +277,22 @@ void PASTEF77(ch,blasname) \ /* Finalize BLIS. */ \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ bli_finalize_auto(); \ -} +} \ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* uploc, \ + const f77_char* transa, \ + const f77_int* m, \ + const f77_int* k, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype* b, const f77_int* ldb, \ + const ftype* beta, \ + ftype* c, const f77_int* ldc \ + ) \ +{ \ + PASTEF77S(ch,blasname) ( uploc, transa, m, k, alpha, a, lda, b, ldb, beta, c, ldc ); \ +} \ #endif diff --git a/frame/compat/bla_syr2k.h b/frame/compat/bla_syr2k.h index 91d9a3acf8..fc127d9ea5 100644 --- a/frame/compat/bla_syr2k.h +++ b/frame/compat/bla_syr2k.h @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -40,6 +41,18 @@ #define GENTPROT( ftype, ch, blasname ) \ \ BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ + ( \ + const f77_char* uploc, \ + const f77_char* transa, \ + const f77_int* m, \ + const f77_int* k, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype* b, const f77_int* ldb, \ + const ftype* beta, \ + ftype* c, const f77_int* ldc \ + ); \ +BLIS_EXPORT_BLAS void PASTEF77S(ch,blasname) \ ( \ const f77_char* uploc, \ const f77_char* transa, \ diff --git a/frame/compat/bla_syrk.c b/frame/compat/bla_syrk.c index 376b23aec9..b2ec611f58 100644 --- a/frame/compat/bla_syrk.c +++ b/frame/compat/bla_syrk.c @@ -45,7 +45,7 @@ #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ -void PASTEF77(ch,blasname) \ +void PASTEF77S(ch,blasname) \ ( \ const f77_char* uploc, \ const f77_char* transa, \ @@ -133,14 +133,28 @@ void PASTEF77(ch,blasname) \ /* Finalize BLIS. */ \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ bli_finalize_auto(); \ -} +} \ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* uploc, \ + const f77_char* transa, \ + const f77_int* m, \ + const f77_int* k, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype* beta, \ + ftype* c, const f77_int* ldc \ + ) \ +{ \ + PASTEF77S(ch,blasname) ( uploc, transa, m, k, alpha, a, lda, beta, c, ldc ); \ +} \ #else #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ -void PASTEF77(ch,blasname) \ +void PASTEF77S(ch,blasname) \ ( \ const f77_char* uploc, \ const f77_char* transa, \ @@ -245,7 +259,21 @@ void PASTEF77(ch,blasname) \ /* Finalize BLIS. */ \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ bli_finalize_auto(); \ -} +} \ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* uploc, \ + const f77_char* transa, \ + const f77_int* m, \ + const f77_int* k, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype* beta, \ + ftype* c, const f77_int* ldc \ + ) \ +{ \ + PASTEF77S(ch,blasname) ( uploc, transa, m, k, alpha, a, lda, beta, c, ldc ); \ +} \ #endif diff --git a/frame/compat/bla_syrk.h b/frame/compat/bla_syrk.h index b6ca938a6f..c87dc6694c 100644 --- a/frame/compat/bla_syrk.h +++ b/frame/compat/bla_syrk.h @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -40,6 +41,17 @@ #define GENTPROT( ftype, ch, blasname ) \ \ BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ + ( \ + const f77_char* uploc, \ + const f77_char* transa, \ + const f77_int* m, \ + const f77_int* k, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype* beta, \ + ftype* c, const f77_int* ldc \ + ); \ +BLIS_EXPORT_BLAS void PASTEF77S(ch,blasname) \ ( \ const f77_char* uploc, \ const f77_char* transa, \ diff --git a/frame/compat/bla_trmm.c b/frame/compat/bla_trmm.c index c319b3ab51..59c64b90e1 100644 --- a/frame/compat/bla_trmm.c +++ b/frame/compat/bla_trmm.c @@ -45,7 +45,7 @@ #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ -void PASTEF77(ch,blasname) \ +void PASTEF77S(ch,blasname) \ ( \ const f77_char* side, \ const f77_char* uploa, \ @@ -131,14 +131,29 @@ void PASTEF77(ch,blasname) \ /* Finalize BLIS. 
*/ \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ bli_finalize_auto(); \ -} +} \ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* side, \ + const f77_char* uploa, \ + const f77_char* transa, \ + const f77_char* diaga, \ + const f77_int* m, \ + const f77_int* n, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + ftype* b, const f77_int* ldb \ + ) \ +{ \ + PASTEF77S(ch,blasname) ( side, uploa, transa, diaga, m, n, alpha, a, lda, b, ldb ); \ +} \ #else #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ -void PASTEF77(ch,blasname) \ +void PASTEF77S(ch,blasname) \ ( \ const f77_char* side, \ const f77_char* uploa, \ @@ -239,7 +254,22 @@ void PASTEF77(ch,blasname) \ /* Finalize BLIS. */ \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ bli_finalize_auto(); \ -} +} \ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* side, \ + const f77_char* uploa, \ + const f77_char* transa, \ + const f77_char* diaga, \ + const f77_int* m, \ + const f77_int* n, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + ftype* b, const f77_int* ldb \ + ) \ +{ \ + PASTEF77S(ch,blasname) ( side, uploa, transa, diaga, m, n, alpha, a, lda, b, ldb ); \ +} \ #endif diff --git a/frame/compat/bla_trmm.h b/frame/compat/bla_trmm.h index 4f0c20b1b2..10cbb6cbc2 100644 --- a/frame/compat/bla_trmm.h +++ b/frame/compat/bla_trmm.h @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -40,6 +41,18 @@ #define GENTPROT( ftype, ch, blasname ) \ \ BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ + ( \ + const f77_char* side, \ + const f77_char* uploa, \ + const f77_char* transa, \ + const f77_char* diaga, \ + const f77_int* m, \ + const f77_int* n, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + ftype* b, const f77_int* ldb \ + ); \ +BLIS_EXPORT_BLAS void PASTEF77S(ch,blasname) \ ( \ const f77_char* side, \ const f77_char* uploa, \ diff --git a/frame/compat/bla_trsm.c b/frame/compat/bla_trsm.c index e99805d8dd..f709a8cd0a 100644 --- a/frame/compat/bla_trsm.c +++ b/frame/compat/bla_trsm.c @@ -45,7 +45,7 @@ #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ -void PASTEF77(ch,blasname) \ +void PASTEF77S(ch,blasname) \ ( \ const f77_char* side, \ const f77_char* uploa, \ @@ -130,14 +130,29 @@ void PASTEF77(ch,blasname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) \ /* Finalize BLIS. */ \ bli_finalize_auto(); \ -} +} \ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* side, \ + const f77_char* uploa, \ + const f77_char* transa, \ + const f77_char* diaga, \ + const f77_int* m, \ + const f77_int* n, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + ftype* b, const f77_int* ldb \ + ) \ +{ \ + PASTEF77S(ch,blasname) ( side, uploa, transa, diaga, m, n, alpha, a, lda, b, ldb ); \ +} \ #else #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ -void PASTEF77(ch,blasname) \ +void PASTEF77S(ch,blasname) \ ( \ const f77_char* side, \ const f77_char* uploa, \ @@ -393,7 +408,22 @@ void PASTEF77(ch,blasname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) \ /* Finalize BLIS. 
*/ \ bli_finalize_auto(); \ -} +} \ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* side, \ + const f77_char* uploa, \ + const f77_char* transa, \ + const f77_char* diaga, \ + const f77_int* m, \ + const f77_int* n, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + ftype* b, const f77_int* ldb \ + ) \ +{ \ + PASTEF77S(ch,blasname) ( side, uploa, transa, diaga, m, n, alpha, a, lda, b, ldb ); \ +} \ #endif diff --git a/frame/compat/bla_trsm.h b/frame/compat/bla_trsm.h index 5694db52a8..af1b626dff 100644 --- a/frame/compat/bla_trsm.h +++ b/frame/compat/bla_trsm.h @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -40,6 +41,18 @@ #define GENTPROT( ftype, ch, blasname ) \ \ BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ + ( \ + const f77_char* side, \ + const f77_char* uploa, \ + const f77_char* transa, \ + const f77_char* diaga, \ + const f77_int* m, \ + const f77_int* n, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + ftype* b, const f77_int* ldb \ + ); \ +BLIS_EXPORT_BLAS void PASTEF77S(ch,blasname) \ ( \ const f77_char* side, \ const f77_char* uploa, \ diff --git a/frame/compat/bla_trsm_amd.c b/frame/compat/bla_trsm_amd.c index 8ca7434bd8..4479725fb9 100644 --- a/frame/compat/bla_trsm_amd.c +++ b/frame/compat/bla_trsm_amd.c @@ -45,7 +45,7 @@ #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ -void PASTEF77(ch,blasname) \ +void PASTEF77S(ch,blasname) \ ( \ const f77_char* side, \ const f77_char* uploa, \ @@ -130,14 +130,29 @@ void PASTEF77(ch,blasname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) \ /* Finalize BLIS. */ \ bli_finalize_auto(); \ -} +} \ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* side, \ + const f77_char* uploa, \ + const f77_char* transa, \ + const f77_char* diaga, \ + const f77_int* m, \ + const f77_int* n, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + ftype* b, const f77_int* ldb \ + ) \ +{ \ + PASTEF77S(ch,blasname) ( side, uploa, transa, diaga, m, n, alpha, a, lda, b, ldb ); \ + } \ #else #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ -void PASTEF77(ch,blasname) \ +void PASTEF77S(ch,blasname) \ ( \ const f77_char* side, \ const f77_char* uploa, \ @@ -393,13 +408,28 @@ void PASTEF77(ch,blasname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) \ /* Finalize BLIS. */ \ bli_finalize_auto(); \ -} +} \ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* side, \ + const f77_char* uploa, \ + const f77_char* transa, \ + const f77_char* diaga, \ + const f77_int* m, \ + const f77_int* n, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + ftype* b, const f77_int* ldb \ + ) \ +{ \ + PASTEF77S(ch,blasname) ( side, uploa, transa, diaga, m, n, alpha, a, lda, b, ldb ); \ +} \ #endif #ifdef BLIS_ENABLE_BLAS -void strsm_ +void strsm_blis_impl ( const f77_char* side, const f77_char* uploa, @@ -669,8 +699,23 @@ void strsm_ /* Finalize BLIS. 
*/ bli_finalize_auto(); } +void strsm_ +( + const f77_char* side, + const f77_char* uploa, + const f77_char* transa, + const f77_char* diaga, + const f77_int* m, + const f77_int* n, + const float* alpha, + const float* a, const f77_int* lda, + float* b, const f77_int* ldb +) +{ + strsm_blis_impl ( side, uploa, transa, diaga, m, n, alpha, a, lda, b, ldb ); +} -void dtrsm_ +void dtrsm_blis_impl ( const f77_char* side, const f77_char* uploa, @@ -892,7 +937,7 @@ void dtrsm_ bli_obj_set_conjtrans( blis_transa, &ao ); bli_obj_set_struc( struca, &ao ); - + #ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM // This function is invoked on all architectures including ‘generic’. // Non-AVX platforms will use the kernels derived from the context. @@ -973,9 +1018,24 @@ void dtrsm_ /* Finalize BLIS. */ bli_finalize_auto(); } +void dtrsm_ +( + const f77_char* side, + const f77_char* uploa, + const f77_char* transa, + const f77_char* diaga, + const f77_int* m, + const f77_int* n, + const double* alpha, + const double* a, const f77_int* lda, + double* b, const f77_int* ldb +) +{ + dtrsm_blis_impl ( side, uploa, transa, diaga, m, n, alpha, a, lda, b, ldb ); +} -void ztrsm_ +void ztrsm_blis_impl ( const f77_char* side, const f77_char* uploa, @@ -1331,9 +1391,24 @@ void ztrsm_ /* Finalize BLIS. */ bli_finalize_auto(); } +void ztrsm_ +( + const f77_char* side, + const f77_char* uploa, + const f77_char* transa, + const f77_char* diaga, + const f77_int* m, + const f77_int* n, + const dcomplex* alpha, + const dcomplex* a, const f77_int* lda, + dcomplex* b, const f77_int* ldb +) +{ + ztrsm_blis_impl ( side, uploa, transa, diaga, m, n, alpha, a, lda, b, ldb ); +} -void ctrsm_ +void ctrsm_blis_impl ( const f77_char* side, const f77_char* uploa, @@ -1664,5 +1739,20 @@ void ctrsm_ /* Finalize BLIS. */ bli_finalize_auto(); } +void ctrsm_ +( + const f77_char* side, + const f77_char* uploa, + const f77_char* transa, + const f77_char* diaga, + const f77_int* m, + const f77_int* n, + const scomplex* alpha, + const scomplex* a, const f77_int* lda, + scomplex* b, const f77_int* ldb +) +{ + ctrsm_blis_impl ( side, uploa, transa, diaga, m, n, alpha, a, lda, b, ldb ); +} #endif diff --git a/frame/compat/cblas/src/cblas_f77.h b/frame/compat/cblas/src/cblas_f77.h index fabf3efb1c..5ec518de9e 100644 --- a/frame/compat/cblas/src/cblas_f77.h +++ b/frame/compat/cblas/src/cblas_f77.h @@ -7,7 +7,7 @@ * * (Heavily hacked down from the original) * - * Copyright (C) 2020 - 2021, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. 
* */ @@ -326,40 +326,40 @@ /* * Level 3 BLAS */ -#define F77_chemm chemm_ -#define F77_cherk cherk_ -#define F77_cher2k cher2k_ -#define F77_zhemm zhemm_ -#define F77_zherk zherk_ -#define F77_zher2k zher2k_ -#define F77_sgemm sgemm_ -#define F77_ssymm ssymm_ -#define F77_ssyrk ssyrk_ -#define F77_ssyr2k ssyr2k_ -#define F77_strmm strmm_ -#define F77_strsm strsm_ -#define F77_dgemm dgemm_ -#define F77_dsymm dsymm_ -#define F77_dsyrk dsyrk_ -#define F77_dsyr2k dsyr2k_ -#define F77_dtrmm dtrmm_ -#define F77_dtrsm dtrsm_ -#define F77_cgemm cgemm_ -#define F77_csymm csymm_ -#define F77_csyrk csyrk_ -#define F77_csyr2k csyr2k_ -#define F77_ctrmm ctrmm_ -#define F77_ctrsm ctrsm_ -#define F77_zgemm zgemm_ -#define F77_zsymm zsymm_ -#define F77_zsyrk zsyrk_ -#define F77_zsyr2k zsyr2k_ -#define F77_ztrmm ztrmm_ -#define F77_ztrsm ztrsm_ -#define F77_dgemmt dgemmt_ -#define F77_sgemmt sgemmt_ -#define F77_cgemmt cgemmt_ -#define F77_zgemmt zgemmt_ +#define F77_chemm chemm_blis_impl +#define F77_cherk cherk_blis_impl +#define F77_cher2k cher2k_blis_impl +#define F77_zhemm zhemm_blis_impl +#define F77_zherk zherk_blis_impl +#define F77_zher2k zher2k_blis_impl +#define F77_sgemm sgemm_blis_impl +#define F77_ssymm ssymm_blis_impl +#define F77_ssyrk ssyrk_blis_impl +#define F77_ssyr2k ssyr2k_blis_impl +#define F77_strmm strmm_blis_impl +#define F77_strsm strsm_blis_impl +#define F77_dgemm dgemm_blis_impl +#define F77_dsymm dsymm_blis_impl +#define F77_dsyrk dsyrk_blis_impl +#define F77_dsyr2k dsyr2k_blis_impl +#define F77_dtrmm dtrmm_blis_impl +#define F77_dtrsm dtrsm_blis_impl +#define F77_cgemm cgemm_blis_impl +#define F77_csymm csymm_blis_impl +#define F77_csyrk csyrk_blis_impl +#define F77_csyr2k csyr2k_blis_impl +#define F77_ctrmm ctrmm_blis_impl +#define F77_ctrsm ctrsm_blis_impl +#define F77_zgemm zgemm_blis_impl +#define F77_zsymm zsymm_blis_impl +#define F77_zsyrk zsyrk_blis_impl +#define F77_zsyr2k zsyr2k_blis_impl +#define F77_ztrmm ztrmm_blis_impl +#define F77_ztrsm ztrsm_blis_impl +#define F77_dgemmt dgemmt_blis_impl +#define F77_sgemmt sgemmt_blis_impl +#define F77_cgemmt cgemmt_blis_impl +#define F77_zgemmt zgemmt_blis_impl /* * Aux Function @@ -375,8 +375,8 @@ #define F77_daxpby daxpby_ #define F77_caxpby caxpby_ #define F77_zaxpby zaxpby_ -#define F77_cgemm3m cgemm3m_ -#define F77_zgemm3m zgemm3m_ +#define F77_cgemm3m cgemm3m_blis_impl +#define F77_zgemm3m zgemm3m_blis_impl #define F77_isamin_sub isaminsub_ #define F77_idamin_sub idaminsub_ diff --git a/frame/include/bli_macro_defs.h b/frame/include/bli_macro_defs.h index f29fdc1fe4..75b9c9fdc4 100644 --- a/frame/include/bli_macro_defs.h +++ b/frame/include/bli_macro_defs.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018-2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -162,11 +162,13 @@ #define PASTEF77(ch1,name) ch1 ## name #define PASTEF772(ch1,ch2,name) ch1 ## ch2 ## name #define PASTEF773(ch1,ch2,ch3,name) ch1 ## ch2 ## ch3 ## name +#define PASTEF77S(ch1,name) ch1 ## name ## _blis_impl #else #define PASTEF770(name) name ## _ #define PASTEF77(ch1,name) ch1 ## name ## _ #define PASTEF772(ch1,ch2,name) ch1 ## ch2 ## name ## _ #define PASTEF773(ch1,ch2,ch3,name) ch1 ## ch2 ## ch3 ## name ## _ +#define PASTEF77S(ch1,name) ch1 ## name ## _blis_impl #endif // -- Include other groups of macros diff --git a/frame/util/bli_util_api_wrap.c b/frame/util/bli_util_api_wrap.c index 81300761fb..9e8d1ccc38 100644 --- a/frame/util/bli_util_api_wrap.c +++ b/frame/util/bli_util_api_wrap.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021-2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -210,17 +210,17 @@ void CGBMV_(const char *trans,const f77_int *m,const f77_int *n,const f77_int void CGEMM(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) { - cgemm_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + cgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void cgemm(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) { - cgemm_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + cgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void CGEMM_(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) { - cgemm_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + cgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void CGEMV(const char *trans,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy) @@ -285,17 +285,17 @@ void CHBMV_(const char *uplo,const f77_int *n,const f77_int *k,const scomplex void CHEMM(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) { - chemm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + chemm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void chemm(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) { - chemm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + 
chemm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void CHEMM_(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) { - chemm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + chemm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void CHEMV(const char *uplo,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy) @@ -345,32 +345,32 @@ void CHER2_(const char *uplo,const f77_int *n,const scomplex *alpha,const sco void CHER2K(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const float *beta,scomplex *c,const f77_int *ldc) { - cher2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + cher2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void cher2k(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const float *beta,scomplex *c,const f77_int *ldc) { - cher2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + cher2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void CHER2K_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const float *beta,scomplex *c,const f77_int *ldc) { - cher2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + cher2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void CHERK(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const float *alpha,const scomplex *a,const f77_int *lda,const float *beta,scomplex *c,const f77_int *ldc) { - cherk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + cherk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void cherk(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const float *alpha,const scomplex *a,const f77_int *lda,const float *beta,scomplex *c,const f77_int *ldc) { - cherk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + cherk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void CHERK_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const float *alpha,const scomplex *a,const f77_int *lda,const float *beta,scomplex *c,const f77_int *ldc) { - cherk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + cherk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void CHPMV(const char *uplo,const f77_int *n,const scomplex *alpha,const scomplex *ap,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy) @@ -495,47 +495,47 @@ void CSWAP_(const f77_int *n,scomplex *cx,const f77_int *incx,scomplex *cy,con void CSYMM(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) { - csymm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + csymm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void csymm(const char *side,const char *uplo,const f77_int *m,const f77_int 
*n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) { - csymm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + csymm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void CSYMM_(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) { - csymm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + csymm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void CSYR2K(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) { - csyr2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + csyr2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void csyr2k(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) { - csyr2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + csyr2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void CSYR2K_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) { - csyr2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + csyr2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void CSYRK(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *beta,scomplex *c,const f77_int *ldc) { - csyrk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + csyrk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void csyrk(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *beta,scomplex *c,const f77_int *ldc) { - csyrk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + csyrk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void CSYRK_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *beta,scomplex *c,const f77_int *ldc) { - csyrk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + csyrk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void CTBMV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const scomplex *a,const f77_int *lda,scomplex *x,const f77_int *incx) @@ -600,17 +600,17 @@ void CTPSV_(const char *uplo,const char *trans,const char *diag,const f77_ void CTRMM(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,scomplex *b,const f77_int *ldb) { - ctrmm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + ctrmm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void ctrmm(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const 
f77_int *lda,scomplex *b,const f77_int *ldb) { - ctrmm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + ctrmm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void CTRMM_(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,scomplex *b,const f77_int *ldb) { - ctrmm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + ctrmm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void CTRMV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const scomplex *a,const f77_int *lda,scomplex *x,const f77_int *incx) @@ -630,17 +630,17 @@ void CTRMV_(const char *uplo,const char *trans,const char *diag,const f77_ void CTRSM(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,scomplex *b,const f77_int *ldb) { - ctrsm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + ctrsm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void ctrsm(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,scomplex *b,const f77_int *ldb) { - ctrsm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + ctrsm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void CTRSM_(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,scomplex *b,const f77_int *ldb) { - ctrsm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + ctrsm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void CTRSV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const scomplex *a,const f77_int *lda,scomplex *x,const f77_int *incx) @@ -750,17 +750,17 @@ void DGBMV_(const char *trans,const f77_int *m,const f77_int *n,const f77_int void DGEMM(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *b,const f77_int *ldb,const double *beta,double *c,const f77_int *ldc) { - dgemm_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + dgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void dgemm(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *b,const f77_int *ldb,const double *beta,double *c,const f77_int *ldc) { - dgemm_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + dgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void DGEMM_(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *b,const f77_int *ldb,const double *beta,double *c,const f77_int *ldc) { - dgemm_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + dgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void DGEMV(const char *trans,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) @@ -975,17 +975,17 @@ void DSWAP_(const f77_int *n,double 
*dx,const f77_int *incx,double *dy,const f77 void DSYMM(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,const double *b,const f77_int *ldb,const double *beta,double *c,const f77_int *ldc) { - dsymm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + dsymm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void dsymm(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,const double *b,const f77_int *ldb,const double *beta,double *c,const f77_int *ldc) { - dsymm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + dsymm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void DSYMM_(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,const double *b,const f77_int *ldb,const double *beta,double *c,const f77_int *ldc) { - dsymm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + dsymm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void DSYMV(const char *uplo,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) @@ -1035,32 +1035,32 @@ void DSYR2_(const char *uplo,const f77_int *n,const double *alpha,const double void DSYR2K(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *b,const f77_int *ldb,const double *beta,double *c,const f77_int *ldc) { - dsyr2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + dsyr2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void dsyr2k(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *b,const f77_int *ldb,const double *beta,double *c,const f77_int *ldc) { - dsyr2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + dsyr2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void DSYR2K_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *b,const f77_int *ldb,const double *beta,double *c,const f77_int *ldc) { - dsyr2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + dsyr2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void DSYRK(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *beta,double *c,const f77_int *ldc) { - dsyrk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + dsyrk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void dsyrk(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *beta,double *c,const f77_int *ldc) { - dsyrk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + dsyrk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void DSYRK_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *beta,double *c,const f77_int *ldc) { - dsyrk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + dsyrk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void DTBMV(const char *uplo,const char *trans,const char *diag,const f77_int 
*n,const f77_int *k,const double *a,const f77_int *lda,double *x,const f77_int *incx) @@ -1125,17 +1125,17 @@ void DTPSV_(const char *uplo,const char *trans,const char *diag,const f77_ void DTRMM(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,double *b,const f77_int *ldb) { - dtrmm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + dtrmm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void dtrmm(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,double *b,const f77_int *ldb) { - dtrmm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + dtrmm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void DTRMM_(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,double *b,const f77_int *ldb) { - dtrmm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + dtrmm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void DTRMV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const double *a,const f77_int *lda,double *x,const f77_int *incx) @@ -1155,17 +1155,17 @@ void DTRMV_(const char *uplo,const char *trans,const char *diag,const f77_ void DTRSM(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,double *b,const f77_int *ldb) { - dtrsm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + dtrsm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void dtrsm(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,double *b,const f77_int *ldb) { - dtrsm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + dtrsm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void DTRSM_(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,double *b,const f77_int *ldb) { - dtrsm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + dtrsm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void DTRSV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const double *a,const f77_int *lda,double *x,const f77_int *incx) @@ -1417,17 +1417,17 @@ void SGBMV_(const char *trans,const f77_int *m,const f77_int *n,const f77_int void SGEMM(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *b,const f77_int *ldb,const float *beta,float *c,const f77_int *ldc) { - sgemm_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + sgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void sgemm(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *b,const f77_int *ldb,const float *beta,float *c,const f77_int *ldc) { - sgemm_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + sgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void SGEMM_(const 
char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *b,const f77_int *ldb,const float *beta,float *c,const f77_int *ldc) { - sgemm_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + sgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void SGEMV(const char *trans,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) @@ -1629,17 +1629,17 @@ void SSWAP_(const f77_int *n,float *sx,const f77_int *incx,float *sy,const f77 void SSYMM(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,const float *b,const f77_int *ldb,const float *beta,float *c,const f77_int *ldc) { - ssymm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + ssymm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void ssymm(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,const float *b,const f77_int *ldb,const float *beta,float *c,const f77_int *ldc) { - ssymm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + ssymm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void SSYMM_(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,const float *b,const f77_int *ldb,const float *beta,float *c,const f77_int *ldc) { - ssymm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + ssymm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void SSYMV(const char *uplo,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) @@ -1689,32 +1689,32 @@ void SSYR2_(const char *uplo,const f77_int *n,const float *alpha,const float void SSYR2K(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *b,const f77_int *ldb,const float *beta,float *c,const f77_int *ldc) { - ssyr2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + ssyr2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void ssyr2k(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *b,const f77_int *ldb,const float *beta,float *c,const f77_int *ldc) { - ssyr2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + ssyr2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void SSYR2K_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *b,const f77_int *ldb,const float *beta,float *c,const f77_int *ldc) { - ssyr2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + ssyr2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void SSYRK(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *beta,float *c,const f77_int *ldc) { - ssyrk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + ssyrk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void ssyrk(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const float *alpha,const float 
*a,const f77_int *lda,const float *beta,float *c,const f77_int *ldc) { - ssyrk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + ssyrk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void SSYRK_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *beta,float *c,const f77_int *ldc) { - ssyrk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + ssyrk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void STBMV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const float *a,const f77_int *lda,float *x,const f77_int *incx) @@ -1779,17 +1779,17 @@ void STPSV_(const char *uplo,const char *trans,const char *diag,const f77_ void STRMM(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,float *b,const f77_int *ldb) { - strmm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + strmm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void strmm(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,float *b,const f77_int *ldb) { - strmm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + strmm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void STRMM_(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,float *b,const f77_int *ldb) { - strmm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + strmm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void STRMV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const float *a,const f77_int *lda,float *x,const f77_int *incx) @@ -1809,17 +1809,17 @@ void STRMV_(const char *uplo,const char *trans,const char *diag,const f77_ void STRSM(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,float *b,const f77_int *ldb) { - strsm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + strsm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void strsm(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,float *b,const f77_int *ldb) { - strsm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + strsm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void STRSM_(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,float *b,const f77_int *ldb) { - strsm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + strsm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void STRSV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const float *a,const f77_int *lda,float *x,const f77_int *incx) @@ -1929,17 +1929,17 @@ void ZGBMV_(const char *trans,const f77_int *m,const f77_int *n,const f77_int void ZGEMM(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const 
f77_int *ldc) { - zgemm_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + zgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void zgemm(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) { - zgemm_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + zgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void ZGEMM_(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) { - zgemm_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + zgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void ZGEMV(const char *trans,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) @@ -2004,17 +2004,17 @@ void ZHBMV_(const char *uplo,const f77_int *n,const f77_int *k,const dcomplex void ZHEMM(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) { - zhemm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + zhemm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void zhemm(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) { - zhemm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + zhemm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void ZHEMM_(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) { - zhemm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + zhemm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void ZHEMV(const char *uplo,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) @@ -2064,32 +2064,32 @@ void ZHER2_(const char *uplo,const f77_int *n,const dcomplex *alpha,const dcom void ZHER2K(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const double *beta,dcomplex *c,const f77_int *ldc) { - zher2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + zher2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void zher2k(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const double *beta,dcomplex *c,const f77_int *ldc) { - zher2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + zher2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void ZHER2K_(const char *uplo,const char *trans,const f77_int *n,const 
f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const double *beta,dcomplex *c,const f77_int *ldc) { - zher2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + zher2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void ZHERK(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const double *alpha,const dcomplex *a,const f77_int *lda,const double *beta,dcomplex *c,const f77_int *ldc) { - zherk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + zherk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void zherk(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const double *alpha,const dcomplex *a,const f77_int *lda,const double *beta,dcomplex *c,const f77_int *ldc) { - zherk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + zherk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void ZHERK_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const double *alpha,const dcomplex *a,const f77_int *lda,const double *beta,dcomplex *c,const f77_int *ldc) { - zherk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + zherk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void ZHPMV(const char *uplo,const f77_int *n,const dcomplex *alpha,const dcomplex *ap,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) @@ -2184,47 +2184,47 @@ void ZSWAP_(const f77_int *n,dcomplex *zx,const f77_int *incx,dcomplex *zy,const void ZSYMM(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) { - zsymm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + zsymm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void zsymm(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) { - zsymm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + zsymm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void ZSYMM_(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) { - zsymm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + zsymm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void ZSYR2K(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) { - zsyr2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + zsyr2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void zsyr2k(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) { - zsyr2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + zsyr2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void ZSYR2K_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const 
f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) { - zsyr2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + zsyr2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void ZSYRK(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *beta,dcomplex *c,const f77_int *ldc) { - zsyrk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + zsyrk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void zsyrk(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *beta,dcomplex *c,const f77_int *ldc) { - zsyrk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + zsyrk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void ZSYRK_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *beta,dcomplex *c,const f77_int *ldc) { - zsyrk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + zsyrk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void ZTBMV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const dcomplex *a,const f77_int *lda,dcomplex *x,const f77_int *incx) @@ -2289,17 +2289,17 @@ void ZTPSV_(const char *uplo,const char *trans,const char *diag,const f77_ void ZTRMM(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,dcomplex *b,const f77_int *ldb) { - ztrmm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + ztrmm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void ztrmm(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,dcomplex *b,const f77_int *ldb) { - ztrmm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + ztrmm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void ZTRMM_(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,dcomplex *b,const f77_int *ldb) { - ztrmm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + ztrmm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void ZTRMV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const dcomplex *a,const f77_int *lda,dcomplex *x,const f77_int *incx) @@ -2319,17 +2319,17 @@ void ZTRMV_(const char *uplo,const char *trans,const char *diag,const f77_ void ZTRSM(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,dcomplex *b,const f77_int *ldb) { - ztrsm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + ztrsm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void ztrsm(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,dcomplex *b,const f77_int *ldb) { - ztrsm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + ztrsm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void ZTRSM_(const char *side,const char *uplo,const char 
*transa,const char *diag,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,dcomplex *b,const f77_int *ldb) { - ztrsm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + ztrsm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void ZTRSV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const dcomplex *a,const f77_int *lda,dcomplex *x,const f77_int *incx) @@ -2380,17 +2380,17 @@ void CDOTUSUB_( const f77_int* n, const scomplex* x,const f77_int* incxy, const void CGEMM3M( const f77_char* transa, const f77_char* transb, const f77_int* m, const f77_int* n, const f77_int* k, const scomplex* alpha, const scomplex* a, const f77_int* lda, const scomplex* b, const f77_int* ldb, const scomplex* beta, scomplex* c, const f77_int* ldc) { - cgemm3m_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + cgemm3m_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void cgemm3m( const f77_char* transa, const f77_char* transb, const f77_int* m, const f77_int* n, const f77_int* k, const scomplex* alpha, const scomplex* a, const f77_int* lda, const scomplex* b, const f77_int* ldb, const scomplex* beta, scomplex* c, const f77_int* ldc) { - cgemm3m_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + cgemm3m_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void CGEMM3M_( const f77_char* transa, const f77_char* transb, const f77_int* m, const f77_int* n, const f77_int* k, const scomplex* alpha, const scomplex* a, const f77_int* lda, const scomplex* b, const f77_int* ldb, const scomplex* beta, scomplex* c, const f77_int* ldc) { - cgemm3m_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + cgemm3m_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void CGEMM_BATCH( const f77_char* transa_array, const f77_char* transb_array,const f77_int *m_array, const f77_int *n_array, const f77_int *k_array,const scomplex* alpha_array, const scomplex** a_array, const f77_int *lda_array, const scomplex** b_array, const f77_int *ldb_array, const scomplex* beta_array, scomplex** c_array, const f77_int *ldc_array, const f77_int* group_count, const f77_int *group_size) @@ -2410,17 +2410,17 @@ void CGEMM_BATCH_( const f77_char* transa_array, const f77_char* transb_array,co void CGEMMT( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const scomplex* alpha, const scomplex* a, const f77_int* lda, const scomplex* b, const f77_int* ldb, const scomplex* beta, scomplex* c, const f77_int* ldc) { - cgemmt_( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + cgemmt_blis_impl( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void cgemmt( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const scomplex* alpha, const scomplex* a, const f77_int* lda, const scomplex* b, const f77_int* ldb, const scomplex* beta, scomplex* c, const f77_int* ldc) { - cgemmt_( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + cgemmt_blis_impl( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void CGEMMT_( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const scomplex* alpha, const scomplex* a, const f77_int* lda, const scomplex* b, const f77_int* ldb, const scomplex* beta, scomplex* c, const f77_int* ldc) { - cgemmt_( 
uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + cgemmt_blis_impl( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void CIMATCOPY(f77_char* trans, f77_int* rows, f77_int* cols, const scomplex* alpha,scomplex* aptr, f77_int* lda, f77_int* ldb) @@ -2545,17 +2545,17 @@ void DGEMM_BATCH_( const f77_char* transa_array, const f77_char* transb_array,co void DGEMMT( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const double* alpha, const double* a, const f77_int* lda, const double* b, const f77_int* ldb, const double* beta, double* c, const f77_int* ldc) { - dgemmt_( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + dgemmt_blis_impl( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void dgemmt( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const double* alpha, const double* a, const f77_int* lda, const double* b, const f77_int* ldb, const double* beta, double* c, const f77_int* ldc) { - dgemmt_( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + dgemmt_blis_impl( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void DGEMMT_( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const double* alpha, const double* a, const f77_int* lda, const double* b, const f77_int* ldb, const double* beta, double* c, const f77_int* ldc) { - dgemmt_( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + dgemmt_blis_impl( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void DNRM2SUB(const f77_int* n, const double* x, const f77_int* incx, double *rval) @@ -2920,17 +2920,17 @@ void SGEMM_BATCH_(const f77_char* transa_array, const f77_char* transb_array,con void SGEMMT( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const float* alpha, const float* a, const f77_int* lda, const float* b, const f77_int* ldb, const float* beta, float* c, const f77_int* ldc) { - sgemmt_( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + sgemmt_blis_impl( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void sgemmt( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const float* alpha, const float* a, const f77_int* lda, const float* b, const f77_int* ldb, const float* beta, float* c, const f77_int* ldc) { - sgemmt_( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + sgemmt_blis_impl( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void SGEMMT_( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const float* alpha, const float* a, const f77_int* lda, const float* b, const f77_int* ldb, const float* beta, float* c, const f77_int* ldc) { - sgemmt_( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + sgemmt_blis_impl( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void SIMATCOPY( f77_char* trans, f77_int* rows, f77_int* cols, const float* alpha,float* aptr, f77_int* lda, f77_int* ldb) @@ -3055,17 +3055,17 @@ void ZDOTUSUB_( const f77_int* n, const dcomplex* x, const f77_int* incx,const d void ZGEMM3M( const f77_char* transa, const f77_char* transb, const f77_int* m, const f77_int* n, const f77_int* k, const dcomplex* alpha, const 
dcomplex* a, const f77_int* lda, const dcomplex* b, const f77_int* ldb, const dcomplex* beta, dcomplex* c, const f77_int* ldc) { - zgemm3m_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + zgemm3m_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void zgemm3m( const f77_char* transa, const f77_char* transb, const f77_int* m, const f77_int* n, const f77_int* k, const dcomplex* alpha, const dcomplex* a, const f77_int* lda, const dcomplex* b, const f77_int* ldb, const dcomplex* beta, dcomplex* c, const f77_int* ldc) { - zgemm3m_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + zgemm3m_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void ZGEMM3M_( const f77_char* transa, const f77_char* transb, const f77_int* m, const f77_int* n, const f77_int* k, const dcomplex* alpha, const dcomplex* a, const f77_int* lda, const dcomplex* b, const f77_int* ldb, const dcomplex* beta, dcomplex* c, const f77_int* ldc) { - zgemm3m_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + zgemm3m_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void ZGEMM_BATCH( const f77_char* transa_array, const f77_char* transb_array,const f77_int *m_array, const f77_int *n_array, const f77_int *k_array,const dcomplex* alpha_array, const dcomplex** a_array, const f77_int *lda_array, const dcomplex** b_array, const f77_int *ldb_array, const dcomplex* beta_array, dcomplex** c_array, const f77_int *ldc_array, const f77_int* group_count, const f77_int *group_size) @@ -3085,17 +3085,17 @@ void ZGEMM_BATCH_( const f77_char* transa_array, const f77_char* transb_array,c void ZGEMMT( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const dcomplex* alpha, const dcomplex* a, const f77_int* lda, const dcomplex* b, const f77_int* ldb, const dcomplex* beta, dcomplex* c, const f77_int* ldc) { - zgemmt_( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + zgemmt_blis_impl( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void zgemmt( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const dcomplex* alpha, const dcomplex* a, const f77_int* lda, const dcomplex* b, const f77_int* ldb, const dcomplex* beta, dcomplex* c, const f77_int* ldc) { - zgemmt_( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + zgemmt_blis_impl( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void ZGEMMT_( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const dcomplex* alpha, const dcomplex* a, const f77_int* lda, const dcomplex* b, const f77_int* ldb, const dcomplex* beta, dcomplex* c, const f77_int* ldc) { - zgemmt_( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + zgemmt_blis_impl( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void ZIMATCOPY(f77_char* trans, f77_int* rows, f77_int* cols, const dcomplex* alpha,dcomplex* aptr, f77_int* lda, f77_int* ldb) From 192f5313a13cdc21f0e07fdc8c3536bbf11c1c0e Mon Sep 17 00:00:00 2001 From: jagar Date: Fri, 26 Aug 2022 20:46:31 +0530 Subject: [PATCH 195/243] CBLAS/BLAS interface decoupling for level 2 APIs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - In BLIS the cblas interface is implemented as a wrapper around the blas interface. 
For example, the CBLAS API ‘cblas_dgemm’ internally invokes the BLAS API ‘dgemm_’. - If the end user wants to use different libraries for CBLAS and BLAS, the current implementation of BLIS doesn’t allow it and may result in recursion. - This change separates the CBLAS and BLAS implementations by adding an additional level of abstraction. The implementation of the API is moved to a new function which is invoked directly from the CBLAS and BLAS wrappers, as sketched below. AMD-Internal: [SWLCSG-1477] Change-Id: I8380b6468683028035f2aece48916939e0fede8a --- frame/compat/bla_gemv.c | 25 +- frame/compat/bla_gemv.h | 15 +- frame/compat/bla_gemv_amd.c | 90 ++++++- frame/compat/bla_ger.c | 18 +- frame/compat/bla_ger.h | 15 +- frame/compat/bla_hemv.c | 19 +- frame/compat/bla_hemv.h | 16 +- frame/compat/bla_her.c | 19 +- frame/compat/bla_her.h | 14 +- frame/compat/bla_her2.c | 18 +- frame/compat/bla_her2.h | 15 +- frame/compat/bla_symv.c | 19 +- frame/compat/bla_symv.h | 14 +- frame/compat/bla_syr.c | 17 +- frame/compat/bla_syr.h | 12 +- frame/compat/bla_syr2.c | 18 +- frame/compat/bla_syr2.h | 13 +- frame/compat/bla_trmv.c | 18 +- frame/compat/bla_trmv.h | 13 +- frame/compat/bla_trsv.c | 18 +- frame/compat/bla_trsv.h | 13 +- frame/compat/cblas/src/cblas_f77.h | 132 +++++----- frame/compat/f2c/bla_gbmv.c | 34 ++- frame/compat/f2c/bla_gbmv.h | 7 +- frame/compat/f2c/bla_hbmv.c | 19 +- frame/compat/f2c/bla_hbmv.h | 5 +- frame/compat/f2c/bla_hpmv.c | 19 +- frame/compat/f2c/bla_hpmv.h | 5 +- frame/compat/f2c/bla_hpr.c | 19 +- frame/compat/f2c/bla_hpr.h | 5 +- frame/compat/f2c/bla_hpr2.c | 19 +- frame/compat/f2c/bla_hpr2.h | 5 +- frame/compat/f2c/bla_sbmv.c | 19 +- frame/compat/f2c/bla_sbmv.h | 5 +- frame/compat/f2c/bla_spmv.c | 19 +- frame/compat/f2c/bla_spmv.h | 5 +- frame/compat/f2c/bla_spr.c | 19 +- frame/compat/f2c/bla_spr.h | 5 +- frame/compat/f2c/bla_spr2.c | 19 +- frame/compat/f2c/bla_spr2.h | 5 +- frame/compat/f2c/bla_tbmv.c | 51 +++- frame/compat/f2c/bla_tbmv.h | 7 +- frame/compat/f2c/bla_tbsv.c | 35 ++- frame/compat/f2c/bla_tbsv.h | 7 +- frame/compat/f2c/bla_tpmv.c | 36 ++- frame/compat/f2c/bla_tpmv.h | 7 +- frame/compat/f2c/bla_tpsv.c | 35 ++- frame/compat/f2c/bla_tpsv.h | 7 +- frame/include/bli_macro_defs.h | 26 +- frame/util/bli_util_api_wrap.c | 396 ++++++++++++++--------------- 50 files changed, 997 insertions(+), 394 deletions(-) diff --git a/frame/compat/bla_gemv.c b/frame/compat/bla_gemv.c index 9dba1b43c4..f5a314331a 100644 --- a/frame/compat/bla_gemv.c +++ b/frame/compat/bla_gemv.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 22, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -35,14 +35,10 @@ #include "blis.h" - -// -// Define BLAS-to-BLIS interfaces. -// #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ -void PASTEF77(ch,blasname) \ +void PASTEF77S(ch,blasname) \ ( \ const f77_char* transa, \ const f77_int* m, \ @@ -143,9 +139,24 @@ void PASTEF77(ch,blasname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \
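A minimal sketch of the pattern referenced above, written out for the double-precision gemv case. This is illustrative only: the names dgemv_blis_impl and dgemv_ and their signatures are taken from bla_gemv_amd.c in this patch, the implementation body is elided, and it is not the literal expansion of the GENTFUNC/PASTEF77 macros.

#include "blis.h"

/* Shared implementation: what used to be the body of dgemv_ lives here. */
void dgemv_blis_impl
     (
       const f77_char* transa,
       const f77_int*  m,
       const f77_int*  n,
       const double*   alpha,
       const double*   a, const f77_int* lda,
       const double*   x, const f77_int* incx,
       const double*   beta,
       double*         y, const f77_int* incy
     )
{
    /* Parameter checking, BLIS object wrapping and the actual gemv
       computation go here (elided in this sketch). */
}

/* Fortran BLAS entry point: a thin forwarder to the shared implementation. */
void dgemv_
     (
       const f77_char* transa,
       const f77_int*  m,
       const f77_int*  n,
       const double*   alpha,
       const double*   a, const f77_int* lda,
       const double*   x, const f77_int* incx,
       const double*   beta,
       double*         y, const f77_int* incy
     )
{
    dgemv_blis_impl( transa, m, n, alpha, a, lda,
                     x, incx, beta, y, incy );
}

With this split, the CBLAS layer binds to dgemv_blis_impl directly through the F77_dgemv mapping in cblas_f77.h, so cblas_dgemv no longer routes through the dgemv_ symbol and cannot re-enter a dgemv_ provided by a different BLAS library.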
/* Finalize BLIS. */ \ bli_finalize_auto(); \ +}\ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* transa, \ + const f77_int* m, \ + const f77_int* n, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype* x, const f77_int* incx, \ + const ftype* beta, \ + ftype* y, const f77_int* incy \ + ) \ +{ \ + PASTEF77S(ch,blasname) \ + ( transa, m, n, alpha, a, lda, x, incx, beta, y, incy ); \ } - #ifdef BLIS_ENABLE_BLAS INSERT_GENTFUNC_BLAS( gemv, gemv ) #endif diff --git a/frame/compat/bla_gemv.h b/frame/compat/bla_gemv.h index 22c8bf1c07..54b5471bd4 100644 --- a/frame/compat/bla_gemv.h +++ b/frame/compat/bla_gemv.h @@ -5,7 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -40,6 +41,18 @@ #define GENTPROT( ftype, ch, blasname ) \ \ BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ + ( \ + const f77_char* transa, \ + const f77_int* m, \ + const f77_int* n, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype* x, const f77_int* incx, \ + const ftype* beta, \ + ftype* y, const f77_int* incy \ + );\ +\ +BLIS_EXPORT_BLAS void PASTEF77S(ch,blasname) \ ( \ const f77_char* transa, \ const f77_int* m, \ diff --git a/frame/compat/bla_gemv_amd.c b/frame/compat/bla_gemv_amd.c index 354f45fe1b..61834948fc 100644 --- a/frame/compat/bla_gemv_amd.c +++ b/frame/compat/bla_gemv_amd.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 22, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -42,7 +42,7 @@ #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ -void PASTEF77(ch,blasname) \ +void PASTEF77S(ch,blasname) \ ( \ const f77_char* transa, \ const f77_int* m, \ @@ -143,11 +143,26 @@ void PASTEF77(ch,blasname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ /* Finalize BLIS.
*/ \ bli_finalize_auto(); \ +}\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* transa, \ + const f77_int* m, \ + const f77_int* n, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype* x, const f77_int* incx, \ + const ftype* beta, \ + ftype* y, const f77_int* incy \ + ) \ +{ \ + PASTEF77S(ch,blasname) \ + ( transa, m, n, alpha, a, lda, x, incx, beta, y, incy ); \ } #ifdef BLIS_ENABLE_BLAS -void dgemv_ +void dgemv_blis_impl ( const f77_char* transa, const f77_int* m, @@ -331,8 +346,23 @@ void dgemv_ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); } +void dgemv_ + ( + const f77_char* transa, + const f77_int* m, + const f77_int* n, + const double* alpha, + const double* a, const f77_int* lda, + const double* x, const f77_int* incx, + const double* beta, + double* y, const f77_int* incy + ) +{ + dgemv_blis_impl( transa, m, n, alpha, a, lda, + x, incx, beta, y, incy ); +} -void sgemv_ +void sgemv_blis_impl ( const f77_char* transa, const f77_int* m, @@ -510,9 +540,23 @@ void sgemv_ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); } +void sgemv_ + ( + const f77_char* transa, + const f77_int* m, + const f77_int* n, + const float* alpha, + const float* a, const f77_int* lda, + const float* x, const f77_int* incx, + const float* beta, + float* y, const f77_int* incy + ) +{ + sgemv_blis_impl( transa, m, n, alpha, a, lda, + x, incx, beta, y, incy ); +} - -void cgemv_ +void cgemv_blis_impl ( const f77_char* transa, const f77_int* m, @@ -733,9 +777,23 @@ void cgemv_ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); } +void cgemv_ + ( + const f77_char* transa, + const f77_int* m, + const f77_int* n, + const scomplex* alpha, + const scomplex* a, const f77_int* lda, + const scomplex* x, const f77_int* incx, + const scomplex* beta, + scomplex* y, const f77_int* incy + ) +{ + cgemv_blis_impl( transa, m, n, alpha, a, lda, + x, incx, beta, y, incy ); +} - -void zgemv_ +void zgemv_blis_impl ( const f77_char* transa, const f77_int* m, @@ -957,7 +1015,21 @@ void zgemv_ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); } - +void zgemv_ + ( + const f77_char* transa, + const f77_int* m, + const f77_int* n, + const dcomplex* alpha, + const dcomplex* a, const f77_int* lda, + const dcomplex* x, const f77_int* incx, + const dcomplex* beta, + dcomplex* y, const f77_int* incy + ) +{ + zgemv_blis_impl( transa, m, n, alpha, a, lda, + x, incx, beta, y, incy ); +} #endif diff --git a/frame/compat/bla_ger.c b/frame/compat/bla_ger.c index b7613842ae..f489bd356e 100644 --- a/frame/compat/bla_ger.c +++ b/frame/compat/bla_ger.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -42,7 +42,7 @@ #undef GENTFUNCDOT #define GENTFUNCDOT( ftype, ch, chc, blis_conjy, blasname, blisname ) \ \ -void PASTEF772(ch,blasname,chc) \ +void PASTEF772S(ch,blasname,chc) \ ( \ const f77_int* m, \ const f77_int* n, \ @@ -110,6 +110,20 @@ void PASTEF772(ch,blasname,chc) \ \ /* Finalize BLIS. 
*/ \ bli_finalize_auto(); \ +} \ +\ +void PASTEF772(ch,blasname,chc) \ + ( \ + const f77_int* m, \ + const f77_int* n, \ + const ftype* alpha, \ + const ftype* x, const f77_int* incx, \ + const ftype* y, const f77_int* incy, \ + ftype* a, const f77_int* lda \ + ) \ +{ \ + PASTEF772S(ch,blasname,chc) \ + ( m, n, alpha, x, incx, y, incy, a, lda ); \ } #ifdef BLIS_ENABLE_BLAS diff --git a/frame/compat/bla_ger.h b/frame/compat/bla_ger.h index a31548f610..769bcffe46 100644 --- a/frame/compat/bla_ger.h +++ b/frame/compat/bla_ger.h @@ -5,7 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -40,6 +41,16 @@ #define GENTPROTDOT( ftype, chxy, chc, blasname ) \ \ BLIS_EXPORT_BLAS void PASTEF772(chxy,blasname,chc) \ + ( \ + const f77_int* m, \ + const f77_int* n, \ + const ftype* alpha, \ + const ftype* x, const f77_int* incx, \ + const ftype* y, const f77_int* incy, \ + ftype* a, const f77_int* lda \ + );\ +\ +BLIS_EXPORT_BLAS void PASTEF772S(chxy,blasname,chc) \ ( \ const f77_int* m, \ const f77_int* n, \ @@ -48,7 +59,7 @@ BLIS_EXPORT_BLAS void PASTEF772(chxy,blasname,chc) \ const ftype* y, const f77_int* incy, \ ftype* a, const f77_int* lda \ ); - + #ifdef BLIS_ENABLE_BLAS INSERT_GENTPROTDOT_BLAS( ger ) #endif diff --git a/frame/compat/bla_hemv.c b/frame/compat/bla_hemv.c index a722f3095d..f9fff26ba7 100644 --- a/frame/compat/bla_hemv.c +++ b/frame/compat/bla_hemv.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -42,7 +42,7 @@ #undef GENTFUNCCO #define GENTFUNCCO( ftype, ftype_r, ch, chr, blasname, blisname ) \ \ -void PASTEF77(ch,blasname) \ +void PASTEF77S(ch,blasname) \ ( \ const f77_char* uploa, \ const f77_int* m, \ @@ -113,6 +113,21 @@ void PASTEF77(ch,blasname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ /* Finalize BLIS. */ \ bli_finalize_auto(); \ +} \ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* uploa, \ + const f77_int* m, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype* x, const f77_int* incx, \ + const ftype* beta, \ + ftype* y, const f77_int* incy \ + ) \ +{ \ + PASTEF77S(ch,blasname) \ + ( uploa, m, alpha, a, lda, x, incx, beta, y, incy ); \ } #ifdef BLIS_ENABLE_BLAS diff --git a/frame/compat/bla_hemv.h b/frame/compat/bla_hemv.h index 4e82301146..68887fb0ff 100644 --- a/frame/compat/bla_hemv.h +++ b/frame/compat/bla_hemv.h @@ -5,7 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. 
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -40,6 +41,17 @@ #define GENTPROTCO( ftype, ftype_r, ch, chr, blasname ) \ \ BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ + ( \ + const f77_char* uploa, \ + const f77_int* m, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype* x, const f77_int* incx, \ + const ftype* beta, \ + ftype* y, const f77_int* incy \ + );\ +\ +BLIS_EXPORT_BLAS void PASTEF77S(ch,blasname) \ ( \ const f77_char* uploa, \ const f77_int* m, \ @@ -49,7 +61,7 @@ BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ const ftype* beta, \ ftype* y, const f77_int* incy \ ); - + #ifdef BLIS_ENABLE_BLAS INSERT_GENTPROTCO_BLAS( hemv ) #endif diff --git a/frame/compat/bla_her.c b/frame/compat/bla_her.c index abe0f1e372..f759cbda15 100755 --- a/frame/compat/bla_her.c +++ b/frame/compat/bla_her.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -42,7 +42,7 @@ #undef GENTFUNCCO #define GENTFUNCCO( ftype, ftype_r, ch, chr, blasname, blisname ) \ \ -void PASTEF77(ch,blasname) \ +void PASTEF77S(ch,blasname) \ ( \ const f77_char* uploa, \ const f77_int* m, \ @@ -103,8 +103,21 @@ void PASTEF77(ch,blasname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ /* Finalize BLIS. */ \ bli_finalize_auto(); \ +}\ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* uploa, \ + const f77_int* m, \ + const ftype_r* alpha, \ + const ftype* x, const f77_int* incx, \ + ftype* a, const f77_int* lda \ + ) \ +{\ + PASTEF77S(ch,blasname) \ + ( uploa, m, alpha, x, incx, a, lda ); \ } - + #ifdef BLIS_ENABLE_BLAS INSERT_GENTFUNCCO_BLAS( her, her ) #endif diff --git a/frame/compat/bla_her.h b/frame/compat/bla_her.h index b9ae30d903..0708dafba7 100644 --- a/frame/compat/bla_her.h +++ b/frame/compat/bla_her.h @@ -5,7 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -40,6 +41,15 @@ #define GENTPROTCO( ftype, ftype_r, ch, chr, blasname ) \ \ BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ + ( \ + const f77_char* uploa, \ + const f77_int* m, \ + const ftype_r* alpha, \ + const ftype* x, const f77_int* incx, \ + ftype* a, const f77_int* lda \ + );\ +\ +BLIS_EXPORT_BLAS void PASTEF77S(ch,blasname) \ ( \ const f77_char* uploa, \ const f77_int* m, \ @@ -47,7 +57,7 @@ BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ const ftype* x, const f77_int* incx, \ ftype* a, const f77_int* lda \ ); - + #ifdef BLIS_ENABLE_BLAS INSERT_GENTPROTCO_BLAS( her ) #endif diff --git a/frame/compat/bla_her2.c b/frame/compat/bla_her2.c index ce65be0cb5..8671e9e174 100644 --- a/frame/compat/bla_her2.c +++ b/frame/compat/bla_her2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -42,7 +42,7 @@ #undef GENTFUNCCO #define GENTFUNCCO( ftype, ftype_r, ch, chr, blasname, blisname ) \ \ -void PASTEF77(ch,blasname) \ +void PASTEF77S(ch,blasname) \ ( \ const f77_char* uploa, \ const f77_int* m, \ @@ -111,6 +111,20 @@ void PASTEF77(ch,blasname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ /* Finalize BLIS. */ \ bli_finalize_auto(); \ +}\ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* uploa, \ + const f77_int* m, \ + const ftype* alpha, \ + const ftype* x, const f77_int* incx, \ + const ftype* y, const f77_int* incy, \ + ftype* a, const f77_int* lda \ + ) \ +{ \ + PASTEF77S(ch,blasname) \ + ( uploa, m, alpha, x, incx, y, incy, a, lda ); \ } #ifdef BLIS_ENABLE_BLAS diff --git a/frame/compat/bla_her2.h b/frame/compat/bla_her2.h index 7cf0bb867c..2868f83a9c 100644 --- a/frame/compat/bla_her2.h +++ b/frame/compat/bla_her2.h @@ -5,7 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -40,6 +41,16 @@ #define GENTPROTCO( ftype, ftype_r, ch, chr, blasname ) \ \ BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ + ( \ + const f77_char* uploa, \ + const f77_int* m, \ + const ftype* alpha, \ + const ftype* x, const f77_int* incx, \ + const ftype* y, const f77_int* incy, \ + ftype* a, const f77_int* lda \ + );\ +\ +BLIS_EXPORT_BLAS void PASTEF77S(ch,blasname) \ ( \ const f77_char* uploa, \ const f77_int* m, \ @@ -48,7 +59,7 @@ BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ const ftype* y, const f77_int* incy, \ ftype* a, const f77_int* lda \ ); - + #ifdef BLIS_ENABLE_BLAS INSERT_GENTPROTCO_BLAS( her2 ) #endif diff --git a/frame/compat/bla_symv.c b/frame/compat/bla_symv.c index c105be329e..1c460ea186 100755 --- a/frame/compat/bla_symv.c +++ b/frame/compat/bla_symv.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -42,7 +42,7 @@ #undef GENTFUNCRO #define GENTFUNCRO( ftype, ch, blasname, blisname ) \ \ -void PASTEF77(ch,blasname) \ +void PASTEF77S(ch,blasname) \ ( \ const f77_char* uploa, \ const f77_int* m, \ @@ -112,6 +112,21 @@ void PASTEF77(ch,blasname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ /* Finalize BLIS. */ \ bli_finalize_auto(); \ +}\ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* uploa, \ + const f77_int* m, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype* x, const f77_int* incx, \ + const ftype* beta, \ + ftype* y, const f77_int* incy \ + ) \ +{ \ + PASTEF77S(ch,blasname) \ + ( uploa, m, alpha, a, lda, x, incx, beta, y, incy ); \ } #ifdef BLIS_ENABLE_BLAS diff --git a/frame/compat/bla_symv.h b/frame/compat/bla_symv.h index 9d1662fadf..efe434884a 100644 --- a/frame/compat/bla_symv.h +++ b/frame/compat/bla_symv.h @@ -5,7 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. 
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -40,6 +41,17 @@ #define GENTPROTRO( ftype, ch, blasname ) \ \ BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ + ( \ + const f77_char* uploa, \ + const f77_int* m, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype* x, const f77_int* incx, \ + const ftype* beta, \ + ftype* y, const f77_int* incy \ + );\ +\ +BLIS_EXPORT_BLAS void PASTEF77S(ch,blasname) \ ( \ const f77_char* uploa, \ const f77_int* m, \ diff --git a/frame/compat/bla_syr.c b/frame/compat/bla_syr.c index 55251ea254..97f9610294 100644 --- a/frame/compat/bla_syr.c +++ b/frame/compat/bla_syr.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin. - Copyright (C) 2020 - 2021, Advanced Micro Devices, Inc.All Rights Reserved. + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc.All Rights Reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -42,7 +42,7 @@ #undef GENTFUNCRO #define GENTFUNCRO( ftype, ch, blasname, blisname ) \ \ -void PASTEF77(ch,blasname) \ +void PASTEF77S(ch,blasname) \ ( \ const f77_char* uploa, \ const f77_int* m, \ @@ -104,6 +104,19 @@ void PASTEF77(ch,blasname) \ /* Finalize BLIS. */ \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ bli_finalize_auto(); \ +}\ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* uploa, \ + const f77_int* m, \ + const ftype* alpha, \ + const ftype* x, const f77_int* incx, \ + ftype* a, const f77_int* lda \ + ) \ +{ \ + PASTEF77S(ch,blasname) \ + ( uploa, m, alpha, x, incx, a, lda ); \ } #ifdef BLIS_ENABLE_BLAS diff --git a/frame/compat/bla_syr.h b/frame/compat/bla_syr.h index 0d2a1e0314..21d4324171 100644 --- a/frame/compat/bla_syr.h +++ b/frame/compat/bla_syr.h @@ -5,7 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -40,6 +41,15 @@ #define GENTPROTRO( ftype, ch, blasname ) \ \ BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ + ( \ + const f77_char* uploa, \ + const f77_int* m, \ + const ftype* alpha, \ + const ftype* x, const f77_int* incx, \ + ftype* a, const f77_int* lda \ + );\ +\ +BLIS_EXPORT_BLAS void PASTEF77S(ch,blasname) \ ( \ const f77_char* uploa, \ const f77_int* m, \ diff --git a/frame/compat/bla_syr2.c b/frame/compat/bla_syr2.c index 047dc64f9c..9208c770b2 100644 --- a/frame/compat/bla_syr2.c +++ b/frame/compat/bla_syr2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin. - Copyright (C) 2020 - 2021, Advanced Micro Devices, Inc.All Rights Reserved. + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc.All Rights Reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -42,7 +42,7 @@ #undef GENTFUNCRO #define GENTFUNCRO( ftype, ch, blasname, blisname ) \ \ -void PASTEF77(ch,blasname) \ +void PASTEF77S(ch,blasname) \ ( \ const f77_char* uploa, \ const f77_int* m, \ @@ -112,6 +112,20 @@ void PASTEF77(ch,blasname) \ /* Finalize BLIS. 
*/ \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ bli_finalize_auto(); \ +}\ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* uploa, \ + const f77_int* m, \ + const ftype* alpha, \ + const ftype* x, const f77_int* incx, \ + const ftype* y, const f77_int* incy, \ + ftype* a, const f77_int* lda \ + ) \ +{ \ + PASTEF77S(ch,blasname) \ + ( uploa, m, alpha, x, incx, y, incy, a, lda ); \ } #ifdef BLIS_ENABLE_BLAS diff --git a/frame/compat/bla_syr2.h b/frame/compat/bla_syr2.h index b458767941..00af4eefdc 100644 --- a/frame/compat/bla_syr2.h +++ b/frame/compat/bla_syr2.h @@ -5,7 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -40,6 +41,16 @@ #define GENTPROTRO( ftype, ch, blasname ) \ \ BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ + ( \ + const f77_char* uploa, \ + const f77_int* m, \ + const ftype* alpha, \ + const ftype* x, const f77_int* incx, \ + const ftype* y, const f77_int* incy, \ + ftype* a, const f77_int* lda \ + );\ +\ +BLIS_EXPORT_BLAS void PASTEF77S(ch,blasname) \ ( \ const f77_char* uploa, \ const f77_int* m, \ diff --git a/frame/compat/bla_trmv.c b/frame/compat/bla_trmv.c index 9c98ad787a..067a9d2fa3 100644 --- a/frame/compat/bla_trmv.c +++ b/frame/compat/bla_trmv.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin. - Copyright (C) 2020 - 2021, Advanced Micro Devices, Inc.All Rights Reserved. + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc.All Rights Reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -42,7 +42,7 @@ #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ -void PASTEF77(ch,blasname) \ +void PASTEF77S(ch,blasname) \ ( \ const f77_char* uploa, \ const f77_char* transa, \ @@ -116,6 +116,20 @@ void PASTEF77(ch,blasname) \ /* Finalize BLIS. */ \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ bli_finalize_auto(); \ +}\ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* uploa, \ + const f77_char* transa, \ + const f77_char* diaga, \ + const f77_int* m, \ + const ftype* a, const f77_int* lda, \ + ftype* x, const f77_int* incx \ + ) \ +{ \ + PASTEF77S(ch,blasname) \ + ( uploa, transa, diaga, m, a, lda, x, incx );\ } #ifdef BLIS_ENABLE_BLAS diff --git a/frame/compat/bla_trmv.h b/frame/compat/bla_trmv.h index 4096ffe793..06dda90bd1 100644 --- a/frame/compat/bla_trmv.h +++ b/frame/compat/bla_trmv.h @@ -5,7 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. 
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -40,6 +41,16 @@ #define GENTPROT( ftype, ch, blasname ) \ \ BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ + ( \ + const f77_char* uploa, \ + const f77_char* transa, \ + const f77_char* diaga, \ + const f77_int* m, \ + const ftype* a, const f77_int* lda, \ + ftype* x, const f77_int* incx \ + );\ +\ +BLIS_EXPORT_BLAS void PASTEF77S(ch,blasname) \ ( \ const f77_char* uploa, \ const f77_char* transa, \ diff --git a/frame/compat/bla_trsv.c b/frame/compat/bla_trsv.c index 8baac6a8ba..01f8c2d713 100644 --- a/frame/compat/bla_trsv.c +++ b/frame/compat/bla_trsv.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin. - Copyright (C) 2020 - 2021, Advanced Micro Devices, Inc.All Rights Reserved. + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc.All Rights Reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -42,7 +42,7 @@ #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ -void PASTEF77(ch,blasname) \ +void PASTEF77S(ch,blasname) \ ( \ const f77_char* uploa, \ const f77_char* transa, \ @@ -116,6 +116,20 @@ void PASTEF77(ch,blasname) \ /* Finalize BLIS. */ \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ bli_finalize_auto(); \ +}\ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* uploa, \ + const f77_char* transa, \ + const f77_char* diaga, \ + const f77_int* m, \ + const ftype* a, const f77_int* lda, \ + ftype* x, const f77_int* incx \ + ) \ +{ \ + PASTEF77S(ch,blasname) \ + ( uploa, transa, diaga, m, a, lda, x, incx );\ } #ifdef BLIS_ENABLE_BLAS diff --git a/frame/compat/bla_trsv.h b/frame/compat/bla_trsv.h index 6edb435f10..3d47272e2c 100644 --- a/frame/compat/bla_trsv.h +++ b/frame/compat/bla_trsv.h @@ -5,7 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. 
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -40,6 +41,16 @@ #define GENTPROT( ftype, ch, blasname ) \ \ BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ + ( \ + const f77_char* uploa, \ + const f77_char* transa, \ + const f77_char* diaga, \ + const f77_int* m, \ + const ftype* a, const f77_int* lda, \ + ftype* x, const f77_int* incx \ + );\ +\ +BLIS_EXPORT_BLAS void PASTEF77S(ch,blasname) \ ( \ const f77_char* uploa, \ const f77_char* transa, \ diff --git a/frame/compat/cblas/src/cblas_f77.h b/frame/compat/cblas/src/cblas_f77.h index 5ec518de9e..d13d833ab9 100644 --- a/frame/compat/cblas/src/cblas_f77.h +++ b/frame/compat/cblas/src/cblas_f77.h @@ -257,72 +257,72 @@ /* * Level 2 BLAS */ -#define F77_ssymv ssymv_ -#define F77_ssbmv ssbmv_ -#define F77_sspmv sspmv_ -#define F77_sger sger_ -#define F77_ssyr ssyr_ -#define F77_sspr sspr_ -#define F77_ssyr2 ssyr2_ -#define F77_sspr2 sspr2_ -#define F77_dsymv dsymv_ -#define F77_dsbmv dsbmv_ -#define F77_dspmv dspmv_ -#define F77_dger dger_ -#define F77_dsyr dsyr_ -#define F77_dspr dspr_ -#define F77_dsyr2 dsyr2_ -#define F77_dspr2 dspr2_ -#define F77_chemv chemv_ -#define F77_chbmv chbmv_ -#define F77_chpmv chpmv_ -#define F77_cgeru cgeru_ -#define F77_cgerc cgerc_ -#define F77_cher cher_ -#define F77_chpr chpr_ -#define F77_cher2 cher2_ -#define F77_chpr2 chpr2_ -#define F77_zhemv zhemv_ -#define F77_zhbmv zhbmv_ -#define F77_zhpmv zhpmv_ -#define F77_zgeru zgeru_ -#define F77_zgerc zgerc_ -#define F77_zher zher_ -#define F77_zhpr zhpr_ -#define F77_zher2 zher2_ -#define F77_zhpr2 zhpr2_ -#define F77_sgemv sgemv_ -#define F77_sgbmv sgbmv_ -#define F77_strmv strmv_ -#define F77_stbmv stbmv_ -#define F77_stpmv stpmv_ -#define F77_strsv strsv_ -#define F77_stbsv stbsv_ -#define F77_stpsv stpsv_ -#define F77_dgemv dgemv_ -#define F77_dgbmv dgbmv_ -#define F77_dtrmv dtrmv_ -#define F77_dtbmv dtbmv_ -#define F77_dtpmv dtpmv_ -#define F77_dtrsv dtrsv_ -#define F77_dtbsv dtbsv_ -#define F77_dtpsv dtpsv_ -#define F77_cgemv cgemv_ -#define F77_cgbmv cgbmv_ -#define F77_ctrmv ctrmv_ -#define F77_ctbmv ctbmv_ -#define F77_ctpmv ctpmv_ -#define F77_ctrsv ctrsv_ -#define F77_ctbsv ctbsv_ -#define F77_ctpsv ctpsv_ -#define F77_zgemv zgemv_ -#define F77_zgbmv zgbmv_ -#define F77_ztrmv ztrmv_ -#define F77_ztbmv ztbmv_ -#define F77_ztpmv ztpmv_ -#define F77_ztrsv ztrsv_ -#define F77_ztbsv ztbsv_ -#define F77_ztpsv ztpsv_ +#define F77_ssymv ssymv_blis_impl +#define F77_ssbmv ssbmv_blis_impl +#define F77_sspmv sspmv_blis_impl +#define F77_sger sger_blis_impl +#define F77_ssyr ssyr_blis_impl +#define F77_sspr sspr_blis_impl +#define F77_ssyr2 ssyr2_blis_impl +#define F77_sspr2 sspr2_blis_impl +#define F77_dsymv dsymv_blis_impl +#define F77_dsbmv dsbmv_blis_impl +#define F77_dspmv dspmv_blis_impl +#define F77_dger dger_blis_impl +#define F77_dsyr dsyr_blis_impl +#define F77_dspr dspr_blis_impl +#define F77_dsyr2 dsyr2_blis_impl +#define F77_dspr2 dspr2_blis_impl +#define F77_chemv chemv_blis_impl +#define F77_chbmv chbmv_blis_impl +#define F77_chpmv chpmv_blis_impl +#define F77_cgeru cgeru_blis_impl +#define F77_cgerc cgerc_blis_impl +#define F77_cher cher_blis_impl +#define F77_chpr chpr_blis_impl +#define F77_cher2 cher2_blis_impl +#define F77_chpr2 chpr2_blis_impl +#define F77_zhemv zhemv_blis_impl +#define F77_zhbmv zhbmv_blis_impl +#define F77_zhpmv zhpmv_blis_impl +#define F77_zgeru zgeru_blis_impl +#define F77_zgerc zgerc_blis_impl +#define F77_zher 
zher_blis_impl +#define F77_zhpr zhpr_blis_impl +#define F77_zher2 zher2_blis_impl +#define F77_zhpr2 zhpr2_blis_impl +#define F77_sgemv sgemv_blis_impl +#define F77_sgbmv sgbmv_blis_impl +#define F77_strmv strmv_blis_impl +#define F77_stbmv stbmv_blis_impl +#define F77_stpmv stpmv_blis_impl +#define F77_strsv strsv_blis_impl +#define F77_stbsv stbsv_blis_impl +#define F77_stpsv stpsv_blis_impl +#define F77_dgemv dgemv_blis_impl +#define F77_dgbmv dgbmv_blis_impl +#define F77_dtrmv dtrmv_blis_impl +#define F77_dtbmv dtbmv_blis_impl +#define F77_dtpmv dtpmv_blis_impl +#define F77_dtrsv dtrsv_blis_impl +#define F77_dtbsv dtbsv_blis_impl +#define F77_dtpsv dtpsv_blis_impl +#define F77_cgemv cgemv_blis_impl +#define F77_cgbmv cgbmv_blis_impl +#define F77_ctrmv ctrmv_blis_impl +#define F77_ctbmv ctbmv_blis_impl +#define F77_ctpmv ctpmv_blis_impl +#define F77_ctrsv ctrsv_blis_impl +#define F77_ctbsv ctbsv_blis_impl +#define F77_ctpsv ctpsv_blis_impl +#define F77_zgemv zgemv_blis_impl +#define F77_zgbmv zgbmv_blis_impl +#define F77_ztrmv ztrmv_blis_impl +#define F77_ztbmv ztbmv_blis_impl +#define F77_ztpmv ztpmv_blis_impl +#define F77_ztrsv ztrsv_blis_impl +#define F77_ztbsv ztbsv_blis_impl +#define F77_ztpsv ztpsv_blis_impl /* * Level 3 BLAS */ diff --git a/frame/compat/f2c/bla_gbmv.c b/frame/compat/f2c/bla_gbmv.c index d53dd322ad..06999fdd93 100644 --- a/frame/compat/f2c/bla_gbmv.c +++ b/frame/compat/f2c/bla_gbmv.c @@ -5,7 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -41,7 +42,8 @@ -lf2c -lm (in that order) */ -/* Subroutine */ int PASTEF77(c,gbmv)(const bla_character *trans, const bla_integer *m, const bla_integer *n, const bla_integer *kl, const bla_integer *ku, const bla_scomplex *alpha, const bla_scomplex *a, const bla_integer *lda, const bla_scomplex *x, const bla_integer *incx, const bla_scomplex *beta, bla_scomplex *y, const bla_integer *incy) +/* Subroutine */ +int PASTEF77S(c,gbmv)(const bla_character *trans, const bla_integer *m, const bla_integer *n, const bla_integer *kl, const bla_integer *ku, const bla_scomplex *alpha, const bla_scomplex *a, const bla_integer *lda, const bla_scomplex *x, const bla_integer *incx, const bla_scomplex *beta, bla_scomplex *y, const bla_integer *incy) { /* System generated locals */ bla_integer a_dim1, a_offset, i__1, i__2, i__3, i__4, i__5, i__6; @@ -482,7 +484,8 @@ -lf2c -lm (in that order) */ -/* Subroutine */ int PASTEF77(d,gbmv)(const bla_character *trans, const bla_integer *m, const bla_integer *n, const bla_integer *kl, const bla_integer *ku, const bla_double *alpha, const bla_double *a, const bla_integer *lda, const bla_double *x, const bla_integer *incx, const bla_double *beta, bla_double *y, const bla_integer *incy) +/* Subroutine */ +int PASTEF77S(d,gbmv)(const bla_character *trans, const bla_integer *m, const bla_integer *n, const bla_integer *kl, const bla_integer *ku, const bla_double *alpha, const bla_double *a, const bla_integer *lda, const bla_double *x, const bla_integer *incx, const bla_double *beta, bla_double *y, const bla_integer *incy) { /* System generated locals */ bla_integer a_dim1, a_offset, i__1, i__2, i__3, i__4, i__5, i__6; @@ -838,7 +841,8 @@ -lf2c -lm (in that order) */ -/* Subroutine */ int PASTEF77(s,gbmv)(const bla_character *trans, const bla_integer *m, const 
bla_integer *n, const bla_integer *kl, const bla_integer *ku, const bla_real *alpha, const bla_real *a, const bla_integer *lda, const bla_real *x, const bla_integer * incx, const bla_real *beta, bla_real *y, const bla_integer *incy) +/* Subroutine */ +int PASTEF77S(s,gbmv)(const bla_character *trans, const bla_integer *m, const bla_integer *n, const bla_integer *kl, const bla_integer *ku, const bla_real *alpha, const bla_real *a, const bla_integer *lda, const bla_real *x, const bla_integer * incx, const bla_real *beta, bla_real *y, const bla_integer *incy) { /* System generated locals */ bla_integer a_dim1, a_offset, i__1, i__2, i__3, i__4, i__5, i__6; @@ -1194,7 +1198,8 @@ -lf2c -lm (in that order) */ -/* Subroutine */ int PASTEF77(z,gbmv)(const bla_character *trans, const bla_integer *m, const bla_integer *n, const bla_integer *kl, const bla_integer *ku, const bla_dcomplex *alpha, const bla_dcomplex *a, const bla_integer *lda, const bla_dcomplex *x, const bla_integer *incx, const bla_dcomplex *beta, bla_dcomplex * y, const bla_integer *incy) +/* Subroutine */ +int PASTEF77S(z,gbmv)(const bla_character *trans, const bla_integer *m, const bla_integer *n, const bla_integer *kl, const bla_integer *ku, const bla_dcomplex *alpha, const bla_dcomplex *a, const bla_integer *lda, const bla_dcomplex *x, const bla_integer *incx, const bla_dcomplex *beta, bla_dcomplex * y, const bla_integer *incy) { /* System generated locals */ bla_integer a_dim1, a_offset, i__1, i__2, i__3, i__4, i__5, i__6; @@ -1630,5 +1635,24 @@ } /* zgbmv_ */ +int PASTEF77(s,gbmv)(const bla_character *trans, const bla_integer *m, const bla_integer *n, const bla_integer *kl, const bla_integer *ku, const bla_real *alpha, const bla_real *a, const bla_integer *lda, const bla_real *x, const bla_integer * incx, const bla_real *beta, bla_real *y, const bla_integer *incy) +{ + return PASTEF77S(s,gbmv)( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy ); +} + +int PASTEF77(d,gbmv)(const bla_character *trans, const bla_integer *m, const bla_integer *n, const bla_integer *kl, const bla_integer *ku, const bla_double *alpha, const bla_double *a, const bla_integer *lda, const bla_double *x, const bla_integer *incx, const bla_double *beta, bla_double *y, const bla_integer *incy) +{ + return PASTEF77S(d,gbmv)( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy ); +} +int PASTEF77(c,gbmv)(const bla_character *trans, const bla_integer *m, const bla_integer *n, const bla_integer *kl, const bla_integer *ku, const bla_scomplex *alpha, const bla_scomplex *a, const bla_integer *lda, const bla_scomplex *x, const bla_integer *incx, const bla_scomplex *beta, bla_scomplex *y, const bla_integer *incy) +{ + return PASTEF77S(c,gbmv)( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy ); +} + +int PASTEF77(z,gbmv)(const bla_character *trans, const bla_integer *m, const bla_integer *n, const bla_integer *kl, const bla_integer *ku, const bla_dcomplex *alpha, const bla_dcomplex *a, const bla_integer *lda, const bla_dcomplex *x, const bla_integer *incx, const bla_dcomplex *beta, bla_dcomplex * y, const bla_integer *incy) +{ + return PASTEF77S(z,gbmv)( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy ); +} + #endif diff --git a/frame/compat/f2c/bla_gbmv.h b/frame/compat/f2c/bla_gbmv.h index eb8ce25342..2ee9a638b3 100644 --- a/frame/compat/f2c/bla_gbmv.h +++ b/frame/compat/f2c/bla_gbmv.h @@ -5,7 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. 
All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -38,5 +39,9 @@ BLIS_EXPORT_BLAS int PASTEF77(c,gbmv)(const bla_character *trans, const bla_inte BLIS_EXPORT_BLAS int PASTEF77(d,gbmv)(const bla_character *trans, const bla_integer *m, const bla_integer *n, const bla_integer *kl, const bla_integer *ku, const bla_double *alpha, const bla_double *a, const bla_integer *lda, const bla_double *x, const bla_integer *incx, const bla_double *beta, bla_double *y, const bla_integer *incy); BLIS_EXPORT_BLAS int PASTEF77(s,gbmv)(const bla_character *trans, const bla_integer *m, const bla_integer *n, const bla_integer *kl, const bla_integer *ku, const bla_real *alpha, const bla_real *a, const bla_integer *lda, const bla_real *x, const bla_integer * incx, const bla_real *beta, bla_real *y, const bla_integer *incy); BLIS_EXPORT_BLAS int PASTEF77(z,gbmv)(const bla_character *trans, const bla_integer *m, const bla_integer *n, const bla_integer *kl, const bla_integer *ku, const bla_dcomplex *alpha, const bla_dcomplex *a, const bla_integer *lda, const bla_dcomplex *x, const bla_integer *incx, const bla_dcomplex *beta, bla_dcomplex * y, const bla_integer *incy); +BLIS_EXPORT_BLAS int PASTEF77S(c,gbmv)(const bla_character *trans, const bla_integer *m, const bla_integer *n, const bla_integer *kl, const bla_integer *ku, const bla_scomplex *alpha, const bla_scomplex *a, const bla_integer *lda, const bla_scomplex *x, const bla_integer *incx, const bla_scomplex *beta, bla_scomplex *y, const bla_integer *incy); +BLIS_EXPORT_BLAS int PASTEF77S(d,gbmv)(const bla_character *trans, const bla_integer *m, const bla_integer *n, const bla_integer *kl, const bla_integer *ku, const bla_double *alpha, const bla_double *a, const bla_integer *lda, const bla_double *x, const bla_integer *incx, const bla_double *beta, bla_double *y, const bla_integer *incy); +BLIS_EXPORT_BLAS int PASTEF77S(s,gbmv)(const bla_character *trans, const bla_integer *m, const bla_integer *n, const bla_integer *kl, const bla_integer *ku, const bla_real *alpha, const bla_real *a, const bla_integer *lda, const bla_real *x, const bla_integer * incx, const bla_real *beta, bla_real *y, const bla_integer *incy); +BLIS_EXPORT_BLAS int PASTEF77S(z,gbmv)(const bla_character *trans, const bla_integer *m, const bla_integer *n, const bla_integer *kl, const bla_integer *ku, const bla_dcomplex *alpha, const bla_dcomplex *a, const bla_integer *lda, const bla_dcomplex *x, const bla_integer *incx, const bla_dcomplex *beta, bla_dcomplex * y, const bla_integer *incy); #endif diff --git a/frame/compat/f2c/bla_hbmv.c b/frame/compat/f2c/bla_hbmv.c index 198336d048..af02c3f0ca 100644 --- a/frame/compat/f2c/bla_hbmv.c +++ b/frame/compat/f2c/bla_hbmv.c @@ -5,7 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. 
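Note on the cblas_f77.h hunk a few files above: it retargets every Level 2 F77_* macro from the Fortran-suffixed symbol (for example dtrsv_) to the corresponding _blis_impl symbol, so the CBLAS wrappers call the implementation directly instead of going through the new forwarding shims. The sketch below illustrates the effect with a hypothetical caller; the extern prototype and the cblas_dtrsv_sketch function are illustrative assumptions, not code from this patch.

/* Sketch only: effect of the F77_* remapping on a CBLAS-style caller. */
extern void dtrsv_blis_impl( const char* uploa, const char* transa,
                             const char* diaga, const int* m,
                             const double* a, const int* lda,
                             double* x, const int* incx );

#define F77_dtrsv dtrsv_blis_impl      /* previously: dtrsv_ */

static void cblas_dtrsv_sketch( char uplo, char trans, char diag, int n,
                                const double* a, int lda,
                                double* x, int incx )
{
    /* The macro now expands to a direct call into the implementation,
       avoiding one extra function hop per CBLAS invocation. */
    F77_dtrsv( &uplo, &trans, &diag, &n, a, &lda, x, &incx );
}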
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -41,7 +42,8 @@ -lf2c -lm (in that order) */ -/* Subroutine */ int PASTEF77(c,hbmv)(const bla_character *uplo, const bla_integer *n, const bla_integer *k, const bla_scomplex * alpha, const bla_scomplex *a, const bla_integer *lda, const bla_scomplex *x, const bla_integer *incx, const bla_scomplex *beta, bla_scomplex *y, const bla_integer *incy) +/* Subroutine */ +int PASTEF77S(c,hbmv)(const bla_character *uplo, const bla_integer *n, const bla_integer *k, const bla_scomplex * alpha, const bla_scomplex *a, const bla_integer *lda, const bla_scomplex *x, const bla_integer *incx, const bla_scomplex *beta, bla_scomplex *y, const bla_integer *incy) { /* System generated locals */ bla_integer a_dim1, a_offset, i__1, i__2, i__3, i__4, i__5; @@ -487,7 +489,8 @@ -lf2c -lm (in that order) */ -/* Subroutine */ int PASTEF77(z,hbmv)(const bla_character *uplo, const bla_integer *n, const bla_integer *k, const bla_dcomplex *alpha, const bla_dcomplex *a, const bla_integer *lda, const bla_dcomplex *x, const bla_integer * incx, const bla_dcomplex *beta, bla_dcomplex *y, const bla_integer *incy) +/* Subroutine */ +int PASTEF77S(z,hbmv)(const bla_character *uplo, const bla_integer *n, const bla_integer *k, const bla_dcomplex *alpha, const bla_dcomplex *a, const bla_integer *lda, const bla_dcomplex *x, const bla_integer * incx, const bla_dcomplex *beta, bla_dcomplex *y, const bla_integer *incy) { /* System generated locals */ bla_integer a_dim1, a_offset, i__1, i__2, i__3, i__4, i__5; @@ -928,5 +931,15 @@ } /* zhbmv_ */ +int PASTEF77(c,hbmv)(const bla_character *uplo, const bla_integer *n, const bla_integer *k, const bla_scomplex * alpha, const bla_scomplex *a, const bla_integer *lda, const bla_scomplex *x, const bla_integer *incx, const bla_scomplex *beta, bla_scomplex *y, const bla_integer *incy) +{ + return PASTEF77S(c,hbmv)( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy ); +} + +int PASTEF77(z,hbmv)(const bla_character *uplo, const bla_integer *n, const bla_integer *k, const bla_dcomplex *alpha, const bla_dcomplex *a, const bla_integer *lda, const bla_dcomplex *x, const bla_integer * incx, const bla_dcomplex *beta, bla_dcomplex *y, const bla_integer *incy) +{ + return PASTEF77S(z,hbmv)( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy ); +} + #endif diff --git a/frame/compat/f2c/bla_hbmv.h b/frame/compat/f2c/bla_hbmv.h index 1ddb838071..fa4dde14a1 100644 --- a/frame/compat/f2c/bla_hbmv.h +++ b/frame/compat/f2c/bla_hbmv.h @@ -5,7 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. 
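Note on the f2c-translated Level 2 routines (gbmv, hbmv, hpmv, and the rest below): they return int rather than void, so each new Fortran-named wrapper forwards both the arguments and the return value. A minimal sketch for chbmv follows; it assumes the PASTEF77/PASTEF77S macros expand to chbmv_ and chbmv_blis_impl, and it shows the complex arguments as void pointers purely for brevity. It is an illustration, not part of the patch.

/* Sketch only: forwarding wrapper for an int-returning f2c routine. */
extern int chbmv_blis_impl( const char* uplo, const int* n, const int* k,
                            const void* alpha, const void* a, const int* lda,
                            const void* x, const int* incx,
                            const void* beta, void* y, const int* incy );

int chbmv_( const char* uplo, const int* n, const int* k,
            const void* alpha, const void* a, const int* lda,
            const void* x, const int* incx,
            const void* beta, void* y, const int* incy )
{
    /* f2c-style routines conventionally return 0; pass that through. */
    return chbmv_blis_impl( uplo, n, k, alpha, a, lda,
                            x, incx, beta, y, incy );
}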
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -36,5 +37,7 @@ BLIS_EXPORT_BLAS int PASTEF77(c,hbmv)(const bla_character *uplo, const bla_integer *n, const bla_integer *k, const bla_scomplex *alpha, const bla_scomplex *a, const bla_integer *lda, const bla_scomplex *x, const bla_integer *incx, const bla_scomplex *beta, bla_scomplex *y, const bla_integer *incy); BLIS_EXPORT_BLAS int PASTEF77(z,hbmv)(const bla_character *uplo, const bla_integer *n, const bla_integer *k, const bla_dcomplex *alpha, const bla_dcomplex *a, const bla_integer *lda, const bla_dcomplex *x, const bla_integer *incx, const bla_dcomplex *beta, bla_dcomplex *y, const bla_integer *incy); +BLIS_EXPORT_BLAS int PASTEF77S(c,hbmv)(const bla_character *uplo, const bla_integer *n, const bla_integer *k, const bla_scomplex *alpha, const bla_scomplex *a, const bla_integer *lda, const bla_scomplex *x, const bla_integer *incx, const bla_scomplex *beta, bla_scomplex *y, const bla_integer *incy); +BLIS_EXPORT_BLAS int PASTEF77S(z,hbmv)(const bla_character *uplo, const bla_integer *n, const bla_integer *k, const bla_dcomplex *alpha, const bla_dcomplex *a, const bla_integer *lda, const bla_dcomplex *x, const bla_integer *incx, const bla_dcomplex *beta, bla_dcomplex *y, const bla_integer *incy); #endif diff --git a/frame/compat/f2c/bla_hpmv.c b/frame/compat/f2c/bla_hpmv.c index 0d7ebce9d7..344c2868ae 100644 --- a/frame/compat/f2c/bla_hpmv.c +++ b/frame/compat/f2c/bla_hpmv.c @@ -5,7 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -41,7 +42,8 @@ -lf2c -lm (in that order) */ -/* Subroutine */ int PASTEF77(c,hpmv)(const bla_character *uplo, const bla_integer *n, const bla_scomplex *alpha, const bla_scomplex * ap, const bla_scomplex *x, const bla_integer *incx, const bla_scomplex *beta, bla_scomplex *y, const bla_integer *incy) +/* Subroutine */ +int PASTEF77S(c,hpmv)(const bla_character *uplo, const bla_integer *n, const bla_scomplex *alpha, const bla_scomplex * ap, const bla_scomplex *x, const bla_integer *incx, const bla_scomplex *beta, bla_scomplex *y, const bla_integer *incy) { /* System generated locals */ bla_integer i__1, i__2, i__3, i__4, i__5; @@ -439,7 +441,8 @@ -lf2c -lm (in that order) */ -/* Subroutine */ int PASTEF77(z,hpmv)(const bla_character *uplo, const bla_integer *n, const bla_dcomplex *alpha, const bla_dcomplex *ap, const bla_dcomplex *x, const bla_integer *incx, const bla_dcomplex *beta, bla_dcomplex *y, const bla_integer *incy) +/* Subroutine */ +int PASTEF77S(z,hpmv)(const bla_character *uplo, const bla_integer *n, const bla_dcomplex *alpha, const bla_dcomplex *ap, const bla_dcomplex *x, const bla_integer *incx, const bla_dcomplex *beta, bla_dcomplex *y, const bla_integer *incy) { /* System generated locals */ bla_integer i__1, i__2, i__3, i__4, i__5; @@ -832,5 +835,15 @@ } /* zhpmv_ */ +int PASTEF77(c,hpmv)(const bla_character *uplo, const bla_integer *n, const bla_scomplex *alpha, const bla_scomplex * ap, const bla_scomplex *x, const bla_integer *incx, const bla_scomplex *beta, bla_scomplex *y, const bla_integer *incy) +{ + return PASTEF77S(c,hpmv)( uplo, n, alpha, ap, x, incx, beta, y, incy ); +} + +int PASTEF77(z,hpmv)(const bla_character *uplo, const bla_integer *n, 
const bla_dcomplex *alpha, const bla_dcomplex *ap, const bla_dcomplex *x, const bla_integer *incx, const bla_dcomplex *beta, bla_dcomplex *y, const bla_integer *incy) +{ + return PASTEF77S(z,hpmv)( uplo, n, alpha, ap, x, incx, beta, y, incy ); +} + #endif diff --git a/frame/compat/f2c/bla_hpmv.h b/frame/compat/f2c/bla_hpmv.h index 26d055effd..61a8e9c1d8 100644 --- a/frame/compat/f2c/bla_hpmv.h +++ b/frame/compat/f2c/bla_hpmv.h @@ -5,7 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -36,5 +37,7 @@ BLIS_EXPORT_BLAS int PASTEF77(c,hpmv)(const bla_character *uplo, const bla_integer *n, const bla_scomplex *alpha, const bla_scomplex *ap, const bla_scomplex *x, const bla_integer *incx, const bla_scomplex *beta, bla_scomplex *y, const bla_integer *incy); BLIS_EXPORT_BLAS int PASTEF77(z,hpmv)(const bla_character *uplo, const bla_integer *n, const bla_dcomplex *alpha, const bla_dcomplex *ap, const bla_dcomplex *x, const bla_integer *incx, const bla_dcomplex *beta, bla_dcomplex *y, const bla_integer *incy); +BLIS_EXPORT_BLAS int PASTEF77S(c,hpmv)(const bla_character *uplo, const bla_integer *n, const bla_scomplex *alpha, const bla_scomplex *ap, const bla_scomplex *x, const bla_integer *incx, const bla_scomplex *beta, bla_scomplex *y, const bla_integer *incy); +BLIS_EXPORT_BLAS int PASTEF77S(z,hpmv)(const bla_character *uplo, const bla_integer *n, const bla_dcomplex *alpha, const bla_dcomplex *ap, const bla_dcomplex *x, const bla_integer *incx, const bla_dcomplex *beta, bla_dcomplex *y, const bla_integer *incy); #endif diff --git a/frame/compat/f2c/bla_hpr.c b/frame/compat/f2c/bla_hpr.c index da1f0a0f39..ae013f869d 100644 --- a/frame/compat/f2c/bla_hpr.c +++ b/frame/compat/f2c/bla_hpr.c @@ -5,7 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. 
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -41,7 +42,8 @@ -lf2c -lm (in that order) */ -/* Subroutine */ int PASTEF77(c,hpr)(const bla_character *uplo, const bla_integer *n, const bla_real *alpha, const bla_scomplex *x, const bla_integer *incx, bla_scomplex *ap) +/* Subroutine */ +int PASTEF77S(c,hpr)(const bla_character *uplo, const bla_integer *n, const bla_real *alpha, const bla_scomplex *x, const bla_integer *incx, bla_scomplex *ap) { /* System generated locals */ bla_integer i__1, i__2, i__3, i__4, i__5; @@ -353,7 +355,8 @@ -lf2c -lm (in that order) */ -/* Subroutine */ int PASTEF77(z,hpr)(const bla_character *uplo, const bla_integer *n, const bla_double *alpha, const bla_dcomplex *x, const bla_integer *incx, bla_dcomplex *ap) +/* Subroutine */ +int PASTEF77S(z,hpr)(const bla_character *uplo, const bla_integer *n, const bla_double *alpha, const bla_dcomplex *x, const bla_integer *incx, bla_dcomplex *ap) { /* System generated locals */ bla_integer i__1, i__2, i__3, i__4, i__5; @@ -660,5 +663,15 @@ } /* zhpr_ */ +int PASTEF77(c,hpr)(const bla_character *uplo, const bla_integer *n, const bla_real *alpha, const bla_scomplex *x, const bla_integer *incx, bla_scomplex *ap) +{ + return PASTEF77S(c,hpr)( uplo, n, alpha, x, incx, ap ); +} + +int PASTEF77(z,hpr)(const bla_character *uplo, const bla_integer *n, const bla_double *alpha, const bla_dcomplex *x, const bla_integer *incx, bla_dcomplex *ap) +{ + return PASTEF77S(z,hpr)( uplo, n, alpha, x, incx, ap ); +} + #endif diff --git a/frame/compat/f2c/bla_hpr.h b/frame/compat/f2c/bla_hpr.h index cfce9e1779..3f6ffa7064 100644 --- a/frame/compat/f2c/bla_hpr.h +++ b/frame/compat/f2c/bla_hpr.h @@ -5,7 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -36,5 +37,7 @@ BLIS_EXPORT_BLAS int PASTEF77(c,hpr)(const bla_character *uplo, const bla_integer *n, const bla_real *alpha, const bla_scomplex *x, const bla_integer *incx, bla_scomplex *ap); BLIS_EXPORT_BLAS int PASTEF77(z,hpr)(const bla_character *uplo, const bla_integer *n, const bla_double *alpha, const bla_dcomplex *x, const bla_integer *incx, bla_dcomplex *ap); +BLIS_EXPORT_BLAS int PASTEF77S(c,hpr)(const bla_character *uplo, const bla_integer *n, const bla_real *alpha, const bla_scomplex *x, const bla_integer *incx, bla_scomplex *ap); +BLIS_EXPORT_BLAS int PASTEF77S(z,hpr)(const bla_character *uplo, const bla_integer *n, const bla_double *alpha, const bla_dcomplex *x, const bla_integer *incx, bla_dcomplex *ap); #endif diff --git a/frame/compat/f2c/bla_hpr2.c b/frame/compat/f2c/bla_hpr2.c index c78c1eec04..f99e8181b6 100644 --- a/frame/compat/f2c/bla_hpr2.c +++ b/frame/compat/f2c/bla_hpr2.c @@ -5,7 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. 
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -41,7 +42,8 @@ -lf2c -lm (in that order) */ -/* Subroutine */ int PASTEF77(c,hpr2)(const bla_character *uplo, const bla_integer *n, const bla_scomplex *alpha, const bla_scomplex *x, const bla_integer *incx, const bla_scomplex *y, const bla_integer *incy, bla_scomplex *ap) +/* Subroutine */ +int PASTEF77S(c,hpr2)(const bla_character *uplo, const bla_integer *n, const bla_scomplex *alpha, const bla_scomplex *x, const bla_integer *incx, const bla_scomplex *y, const bla_integer *incy, bla_scomplex *ap) { /* System generated locals */ bla_integer i__1, i__2, i__3, i__4, i__5, i__6; @@ -429,7 +431,8 @@ -lf2c -lm (in that order) */ -/* Subroutine */ int PASTEF77(z,hpr2)(const bla_character *uplo, const bla_integer *n, const bla_dcomplex *alpha, const bla_dcomplex *x, const bla_integer *incx, const bla_dcomplex *y, const bla_integer *incy, bla_dcomplex *ap) +/* Subroutine */ +int PASTEF77S(z,hpr2)(const bla_character *uplo, const bla_integer *n, const bla_dcomplex *alpha, const bla_dcomplex *x, const bla_integer *incx, const bla_dcomplex *y, const bla_integer *incy, bla_dcomplex *ap) { /* System generated locals */ bla_integer i__1, i__2, i__3, i__4, i__5, i__6; @@ -812,5 +815,15 @@ } /* zhpr2_ */ +int PASTEF77(c,hpr2)(const bla_character *uplo, const bla_integer *n, const bla_scomplex *alpha, const bla_scomplex *x, const bla_integer *incx, const bla_scomplex *y, const bla_integer *incy, bla_scomplex *ap) +{ + return PASTEF77S(c,hpr2)( uplo, n, alpha, x, incx, y, incy, ap ); +} + +int PASTEF77(z,hpr2)(const bla_character *uplo, const bla_integer *n, const bla_dcomplex *alpha, const bla_dcomplex *x, const bla_integer *incx, const bla_dcomplex *y, const bla_integer *incy, bla_dcomplex *ap) +{ + return PASTEF77S(z,hpr2)( uplo, n, alpha, x, incx, y, incy, ap ); +} + #endif diff --git a/frame/compat/f2c/bla_hpr2.h b/frame/compat/f2c/bla_hpr2.h index 16f929d611..6e56b5e053 100644 --- a/frame/compat/f2c/bla_hpr2.h +++ b/frame/compat/f2c/bla_hpr2.h @@ -5,7 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. 
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -36,5 +37,7 @@ BLIS_EXPORT_BLAS int PASTEF77(c,hpr2)(const bla_character *uplo, const bla_integer *n, const bla_scomplex *alpha, const bla_scomplex *x, const bla_integer *incx, const bla_scomplex *y, const bla_integer *incy, bla_scomplex *ap); BLIS_EXPORT_BLAS int PASTEF77(z,hpr2)(const bla_character *uplo, const bla_integer *n, const bla_dcomplex *alpha, const bla_dcomplex *x, const bla_integer *incx, const bla_dcomplex *y, const bla_integer *incy, bla_dcomplex *ap); +BLIS_EXPORT_BLAS int PASTEF77S(c,hpr2)(const bla_character *uplo, const bla_integer *n, const bla_scomplex *alpha, const bla_scomplex *x, const bla_integer *incx, const bla_scomplex *y, const bla_integer *incy, bla_scomplex *ap); +BLIS_EXPORT_BLAS int PASTEF77S(z,hpr2)(const bla_character *uplo, const bla_integer *n, const bla_dcomplex *alpha, const bla_dcomplex *x, const bla_integer *incx, const bla_dcomplex *y, const bla_integer *incy, bla_dcomplex *ap); #endif diff --git a/frame/compat/f2c/bla_sbmv.c b/frame/compat/f2c/bla_sbmv.c index 566fabd81c..bfbbcf0091 100644 --- a/frame/compat/f2c/bla_sbmv.c +++ b/frame/compat/f2c/bla_sbmv.c @@ -5,7 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -41,7 +42,8 @@ -lf2c -lm (in that order) */ -/* Subroutine */ int PASTEF77(d,sbmv)(const bla_character *uplo, const bla_integer *n, const bla_integer *k, const bla_double *alpha, const bla_double *a, const bla_integer *lda, const bla_double *x, const bla_integer *incx, const bla_double *beta, bla_double *y, const bla_integer *incy) +/* Subroutine */ +int PASTEF77S(d,sbmv)(const bla_character *uplo, const bla_integer *n, const bla_integer *k, const bla_double *alpha, const bla_double *a, const bla_integer *lda, const bla_double *x, const bla_integer *incx, const bla_double *beta, bla_double *y, const bla_integer *incy) { /* System generated locals */ bla_integer a_dim1, a_offset, i__1, i__2, i__3, i__4; @@ -392,7 +394,8 @@ -lf2c -lm (in that order) */ -/* Subroutine */ int PASTEF77(s,sbmv)(const bla_character *uplo, const bla_integer *n, const bla_integer *k, const bla_real *alpha, const bla_real *a, const bla_integer *lda, const bla_real *x, const bla_integer *incx, const bla_real *beta, bla_real *y, const bla_integer *incy) +/* Subroutine */ +int PASTEF77S(s,sbmv)(const bla_character *uplo, const bla_integer *n, const bla_integer *k, const bla_real *alpha, const bla_real *a, const bla_integer *lda, const bla_real *x, const bla_integer *incx, const bla_real *beta, bla_real *y, const bla_integer *incy) { /* System generated locals */ bla_integer a_dim1, a_offset, i__1, i__2, i__3, i__4; @@ -738,5 +741,15 @@ } /* ssbmv_ */ +int PASTEF77(d,sbmv)(const bla_character *uplo, const bla_integer *n, const bla_integer *k, const bla_double *alpha, const bla_double *a, const bla_integer *lda, const bla_double *x, const bla_integer *incx, const bla_double *beta, bla_double *y, const bla_integer *incy) +{ + return PASTEF77S(d,sbmv)(uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); +} + +int PASTEF77(s,sbmv)(const bla_character *uplo, const bla_integer *n, const bla_integer *k, const bla_real *alpha, const bla_real *a, const bla_integer *lda, const bla_real 
*x, const bla_integer *incx, const bla_real *beta, bla_real *y, const bla_integer *incy) +{ + return PASTEF77S(s,sbmv)(uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); +} + #endif diff --git a/frame/compat/f2c/bla_sbmv.h b/frame/compat/f2c/bla_sbmv.h index c3f3fc24f8..84e86273a1 100644 --- a/frame/compat/f2c/bla_sbmv.h +++ b/frame/compat/f2c/bla_sbmv.h @@ -5,7 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -36,5 +37,7 @@ BLIS_EXPORT_BLAS int PASTEF77(d,sbmv)(const bla_character *uplo, const bla_integer *n, const bla_integer *k, const bla_double *alpha, const bla_double *a, const bla_integer *lda, const bla_double *x, const bla_integer *incx, const bla_double *beta, bla_double *y, const bla_integer *incy); BLIS_EXPORT_BLAS int PASTEF77(s,sbmv)(const bla_character *uplo, const bla_integer *n, const bla_integer *k, const bla_real *alpha, const bla_real *a, const bla_integer *lda, const bla_real *x, const bla_integer *incx, const bla_real *beta, bla_real *y, const bla_integer *incy); +BLIS_EXPORT_BLAS int PASTEF77S(d,sbmv)(const bla_character *uplo, const bla_integer *n, const bla_integer *k, const bla_double *alpha, const bla_double *a, const bla_integer *lda, const bla_double *x, const bla_integer *incx, const bla_double *beta, bla_double *y, const bla_integer *incy); +BLIS_EXPORT_BLAS int PASTEF77S(s,sbmv)(const bla_character *uplo, const bla_integer *n, const bla_integer *k, const bla_real *alpha, const bla_real *a, const bla_integer *lda, const bla_real *x, const bla_integer *incx, const bla_real *beta, bla_real *y, const bla_integer *incy); #endif diff --git a/frame/compat/f2c/bla_spmv.c b/frame/compat/f2c/bla_spmv.c index 0485e1dc3a..e3cdbd70c8 100644 --- a/frame/compat/f2c/bla_spmv.c +++ b/frame/compat/f2c/bla_spmv.c @@ -5,7 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. 
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -41,7 +42,8 @@ -lf2c -lm (in that order) */ -/* Subroutine */ int PASTEF77(d,spmv)(const bla_character *uplo, const bla_integer *n, const bla_double *alpha, const bla_double *ap, const bla_double *x, const bla_integer *incx, const bla_double *beta, bla_double *y, const bla_integer *incy) +/* Subroutine */ +int PASTEF77S(d,spmv)(const bla_character *uplo, const bla_integer *n, const bla_double *alpha, const bla_double *ap, const bla_double *x, const bla_integer *incx, const bla_double *beta, bla_double *y, const bla_integer *incy) { /* System generated locals */ bla_integer i__1, i__2; @@ -342,7 +344,8 @@ -lf2c -lm (in that order) */ -/* Subroutine */ int PASTEF77(s,spmv)(const bla_character *uplo, const bla_integer *n, const bla_real *alpha, const bla_real *ap, const bla_real *x, const bla_integer *incx, const bla_real *beta, bla_real *y, const bla_integer *incy) +/* Subroutine */ +int PASTEF77S(s,spmv)(const bla_character *uplo, const bla_integer *n, const bla_real *alpha, const bla_real *ap, const bla_real *x, const bla_integer *incx, const bla_real *beta, bla_real *y, const bla_integer *incy) { /* System generated locals */ bla_integer i__1, i__2; @@ -638,5 +641,15 @@ } /* sspmv_ */ +int PASTEF77(d,spmv)(const bla_character *uplo, const bla_integer *n, const bla_double *alpha, const bla_double *ap, const bla_double *x, const bla_integer *incx, const bla_double *beta, bla_double *y, const bla_integer *incy) +{ + return PASTEF77S(d,spmv)( uplo, n, alpha, ap, x, incx, beta, y, incy); +} + +int PASTEF77(s,spmv)(const bla_character *uplo, const bla_integer *n, const bla_real *alpha, const bla_real *ap, const bla_real *x, const bla_integer *incx, const bla_real *beta, bla_real *y, const bla_integer *incy) +{ + return PASTEF77S(s,spmv)( uplo, n, alpha, ap, x, incx, beta, y, incy); +} + #endif diff --git a/frame/compat/f2c/bla_spmv.h b/frame/compat/f2c/bla_spmv.h index 7db7d4a8b6..df85babf1e 100644 --- a/frame/compat/f2c/bla_spmv.h +++ b/frame/compat/f2c/bla_spmv.h @@ -5,7 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. 
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -36,5 +37,7 @@ BLIS_EXPORT_BLAS int PASTEF77(d,spmv)(const bla_character *uplo, const bla_integer *n, const bla_double *alpha, const bla_double *ap, const bla_double *x, const bla_integer *incx, const bla_double *beta, bla_double *y, const bla_integer *incy); BLIS_EXPORT_BLAS int PASTEF77(s,spmv)(const bla_character *uplo, const bla_integer *n, const bla_real *alpha, const bla_real *ap, const bla_real *x, const bla_integer *incx, const bla_real *beta, bla_real *y, const bla_integer *incy); +BLIS_EXPORT_BLAS int PASTEF77S(d,spmv)(const bla_character *uplo, const bla_integer *n, const bla_double *alpha, const bla_double *ap, const bla_double *x, const bla_integer *incx, const bla_double *beta, bla_double *y, const bla_integer *incy); +BLIS_EXPORT_BLAS int PASTEF77S(s,spmv)(const bla_character *uplo, const bla_integer *n, const bla_real *alpha, const bla_real *ap, const bla_real *x, const bla_integer *incx, const bla_real *beta, bla_real *y, const bla_integer *incy); #endif diff --git a/frame/compat/f2c/bla_spr.c b/frame/compat/f2c/bla_spr.c index d276458b49..9b3ee1d886 100644 --- a/frame/compat/f2c/bla_spr.c +++ b/frame/compat/f2c/bla_spr.c @@ -5,7 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -41,7 +42,8 @@ -lf2c -lm (in that order) */ -/* Subroutine */ int PASTEF77(d,spr)(const bla_character *uplo, const bla_integer *n, const bla_double *alpha, const bla_double *x, const bla_integer *incx, bla_double *ap) +/* Subroutine */ +int PASTEF77S(d,spr)(const bla_character *uplo, const bla_integer *n, const bla_double *alpha, const bla_double *x, const bla_integer *incx, bla_double *ap) { /* System generated locals */ bla_integer i__1, i__2; @@ -268,7 +270,8 @@ -lf2c -lm (in that order) */ -/* Subroutine */ int PASTEF77(s,spr)(const bla_character *uplo, const bla_integer *n, const bla_real *alpha, const bla_real *x, const bla_integer *incx, bla_real *ap) +/* Subroutine */ +int PASTEF77S(s,spr)(const bla_character *uplo, const bla_integer *n, const bla_real *alpha, const bla_real *x, const bla_integer *incx, bla_real *ap) { /* System generated locals */ bla_integer i__1, i__2; @@ -490,5 +493,15 @@ } /* sspr_ */ +int PASTEF77(d,spr)(const bla_character *uplo, const bla_integer *n, const bla_double *alpha, const bla_double *x, const bla_integer *incx, bla_double *ap) +{ + return PASTEF77S(d,spr)( uplo, n, alpha, x, incx, ap ); +} + +int PASTEF77(s,spr)(const bla_character *uplo, const bla_integer *n, const bla_real *alpha, const bla_real *x, const bla_integer *incx, bla_real *ap) +{ + return PASTEF77S(s,spr)( uplo, n, alpha, x, incx, ap ); +} + #endif diff --git a/frame/compat/f2c/bla_spr.h b/frame/compat/f2c/bla_spr.h index 6712d7c166..d7519ca049 100644 --- a/frame/compat/f2c/bla_spr.h +++ b/frame/compat/f2c/bla_spr.h @@ -5,7 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. 
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -36,5 +37,7 @@ BLIS_EXPORT_BLAS int PASTEF77(d,spr)(const bla_character *uplo, const bla_integer *n, const bla_double *alpha, const bla_double *x, const bla_integer *incx, bla_double *ap); BLIS_EXPORT_BLAS int PASTEF77(s,spr)(const bla_character *uplo, const bla_integer *n, const bla_real *alpha, const bla_real *x, const bla_integer *incx, bla_real *ap); +BLIS_EXPORT_BLAS int PASTEF77S(d,spr)(const bla_character *uplo, const bla_integer *n, const bla_double *alpha, const bla_double *x, const bla_integer *incx, bla_double *ap); +BLIS_EXPORT_BLAS int PASTEF77S(s,spr)(const bla_character *uplo, const bla_integer *n, const bla_real *alpha, const bla_real *x, const bla_integer *incx, bla_real *ap); #endif diff --git a/frame/compat/f2c/bla_spr2.c b/frame/compat/f2c/bla_spr2.c index 7c75382122..6955172512 100644 --- a/frame/compat/f2c/bla_spr2.c +++ b/frame/compat/f2c/bla_spr2.c @@ -5,7 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -41,7 +42,8 @@ -lf2c -lm (in that order) */ -/* Subroutine */ int PASTEF77(d,spr2)(const bla_character *uplo, const bla_integer *n, const bla_double *alpha, const bla_double *x, const bla_integer *incx, const bla_double *y, const bla_integer *incy, bla_double *ap) +/* Subroutine */ +int PASTEF77S(d,spr2)(const bla_character *uplo, const bla_integer *n, const bla_double *alpha, const bla_double *x, const bla_integer *incx, const bla_double *y, const bla_integer *incy, bla_double *ap) { /* System generated locals */ bla_integer i__1, i__2; @@ -300,7 +302,8 @@ -lf2c -lm (in that order) */ -/* Subroutine */ int PASTEF77(s,spr2)(const bla_character *uplo, const bla_integer *n, const bla_real *alpha, const bla_real *x, const bla_integer *incx, const bla_real *y, const bla_integer *incy, bla_real *ap) +/* Subroutine */ +int PASTEF77S(s,spr2)(const bla_character *uplo, const bla_integer *n, const bla_real *alpha, const bla_real *x, const bla_integer *incx, const bla_real *y, const bla_integer *incy, bla_real *ap) { /* System generated locals */ bla_integer i__1, i__2; @@ -554,5 +557,15 @@ } /* sspr2_ */ +int PASTEF77(d,spr2)(const bla_character *uplo, const bla_integer *n, const bla_double *alpha, const bla_double *x, const bla_integer *incx, const bla_double *y, const bla_integer *incy, bla_double *ap) +{ + return PASTEF77S(d,spr2)( uplo, n, alpha, x, incx, y, incy,ap ); +} + +int PASTEF77(s,spr2)(const bla_character *uplo, const bla_integer *n, const bla_real *alpha, const bla_real *x, const bla_integer *incx, const bla_real *y, const bla_integer *incy, bla_real *ap) +{ + return PASTEF77S(s,spr2)( uplo, n, alpha, x, incx, y, incy,ap ); +} + #endif diff --git a/frame/compat/f2c/bla_spr2.h b/frame/compat/f2c/bla_spr2.h index 5a1d607471..1f02990dac 100644 --- a/frame/compat/f2c/bla_spr2.h +++ b/frame/compat/f2c/bla_spr2.h @@ -5,7 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. 
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -36,5 +37,7 @@ BLIS_EXPORT_BLAS int PASTEF77(d,spr2)(const bla_character *uplo, const bla_integer *n, const bla_double *alpha, const bla_double *x, const bla_integer *incx, const bla_double *y, const bla_integer *incy, bla_double *ap); BLIS_EXPORT_BLAS int PASTEF77(s,spr2)(const bla_character *uplo, const bla_integer *n, const bla_real *alpha, const bla_real *x, const bla_integer *incx, const bla_real *y, const bla_integer *incy, bla_real *ap); +BLIS_EXPORT_BLAS int PASTEF77S(d,spr2)(const bla_character *uplo, const bla_integer *n, const bla_double *alpha, const bla_double *x, const bla_integer *incx, const bla_double *y, const bla_integer *incy, bla_double *ap); +BLIS_EXPORT_BLAS int PASTEF77S(s,spr2)(const bla_character *uplo, const bla_integer *n, const bla_real *alpha, const bla_real *x, const bla_integer *incx, const bla_real *y, const bla_integer *incy, bla_real *ap); #endif diff --git a/frame/compat/f2c/bla_tbmv.c b/frame/compat/f2c/bla_tbmv.c index 78feb70562..de0cfe92db 100644 --- a/frame/compat/f2c/bla_tbmv.c +++ b/frame/compat/f2c/bla_tbmv.c @@ -5,7 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -41,7 +42,8 @@ -lf2c -lm (in that order) */ -/* Subroutine */ int PASTEF77(c,tbmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_scomplex *a, const bla_integer *lda, bla_scomplex *x, const bla_integer *incx) +/* Subroutine */ +int PASTEF77S(c,tbmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_scomplex *a, const bla_integer *lda, bla_scomplex *x, const bla_integer *incx) { /* System generated locals */ bla_integer a_dim1, a_offset, i__1, i__2, i__3, i__4, i__5; @@ -212,11 +214,11 @@ if (! PASTEF770(lsame)(uplo, "U", (ftnlen)1, (ftnlen)1) && ! PASTEF770(lsame)(uplo, "L", ( ftnlen)1, (ftnlen)1)) { info = 1; - } else if (! PASTEF770(lsame)(trans, "N", (ftnlen)1, (ftnlen)1) && ! PASTEF770(lsame)(trans, + } else if (! PASTEF770(lsame)(trans, "N", (ftnlen)1, (ftnlen)1) && ! PASTEF770(lsame)(trans, "T", (ftnlen)1, (ftnlen)1) && ! PASTEF770(lsame)(trans, "C", (ftnlen)1, ( ftnlen)1)) { info = 2; - } else if (! PASTEF770(lsame)(diag, "U", (ftnlen)1, (ftnlen)1) && ! PASTEF770(lsame)(diag, + } else if (! PASTEF770(lsame)(diag, "U", (ftnlen)1, (ftnlen)1) && ! PASTEF770(lsame)(diag, "N", (ftnlen)1, (ftnlen)1)) { info = 3; } else if (*n < 0) { @@ -611,7 +613,8 @@ -lf2c -lm (in that order) */ -/* Subroutine */ int PASTEF77(d,tbmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_double *a, const bla_integer *lda, bla_double *x, const bla_integer *incx) +/* Subroutine */ +int PASTEF77S(d,tbmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_double *a, const bla_integer *lda, bla_double *x, const bla_integer *incx) { /* System generated locals */ bla_integer a_dim1, a_offset, i__1, i__2, i__3, i__4; @@ -778,11 +781,11 @@ if (! 
PASTEF770(lsame)(uplo, "U", (ftnlen)1, (ftnlen)1) && ! PASTEF770(lsame)(uplo, "L", ( ftnlen)1, (ftnlen)1)) { info = 1; - } else if (! PASTEF770(lsame)(trans, "N", (ftnlen)1, (ftnlen)1) && ! PASTEF770(lsame)(trans, + } else if (! PASTEF770(lsame)(trans, "N", (ftnlen)1, (ftnlen)1) && ! PASTEF770(lsame)(trans, "T", (ftnlen)1, (ftnlen)1) && ! PASTEF770(lsame)(trans, "C", (ftnlen)1, ( ftnlen)1)) { info = 2; - } else if (! PASTEF770(lsame)(diag, "U", (ftnlen)1, (ftnlen)1) && ! PASTEF770(lsame)(diag, + } else if (! PASTEF770(lsame)(diag, "U", (ftnlen)1, (ftnlen)1) && ! PASTEF770(lsame)(diag, "N", (ftnlen)1, (ftnlen)1)) { info = 3; } else if (*n < 0) { @@ -1022,7 +1025,8 @@ -lf2c -lm (in that order) */ -/* Subroutine */ int PASTEF77(s,tbmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_real *a, const bla_integer *lda, bla_real *x, const bla_integer *incx) +/* Subroutine */ +int PASTEF77S(s,tbmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_real *a, const bla_integer *lda, bla_real *x, const bla_integer *incx) { /* System generated locals */ bla_integer a_dim1, a_offset, i__1, i__2, i__3, i__4; @@ -1189,11 +1193,11 @@ if (! PASTEF770(lsame)(uplo, "U", (ftnlen)1, (ftnlen)1) && ! PASTEF770(lsame)(uplo, "L", ( ftnlen)1, (ftnlen)1)) { info = 1; - } else if (! PASTEF770(lsame)(trans, "N", (ftnlen)1, (ftnlen)1) && ! PASTEF770(lsame)(trans, + } else if (! PASTEF770(lsame)(trans, "N", (ftnlen)1, (ftnlen)1) && ! PASTEF770(lsame)(trans, "T", (ftnlen)1, (ftnlen)1) && ! PASTEF770(lsame)(trans, "C", (ftnlen)1, ( ftnlen)1)) { info = 2; - } else if (! PASTEF770(lsame)(diag, "U", (ftnlen)1, (ftnlen)1) && ! PASTEF770(lsame)(diag, + } else if (! PASTEF770(lsame)(diag, "U", (ftnlen)1, (ftnlen)1) && ! PASTEF770(lsame)(diag, "N", (ftnlen)1, (ftnlen)1)) { info = 3; } else if (*n < 0) { @@ -1433,7 +1437,8 @@ -lf2c -lm (in that order) */ -/* Subroutine */ int PASTEF77(z,tbmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_dcomplex *a, const bla_integer *lda, bla_dcomplex *x, const bla_integer *incx) +/* Subroutine */ +int PASTEF77S(z,tbmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_dcomplex *a, const bla_integer *lda, bla_dcomplex *x, const bla_integer *incx) { /* System generated locals */ bla_integer a_dim1, a_offset, i__1, i__2, i__3, i__4, i__5; @@ -1604,11 +1609,11 @@ if (! PASTEF770(lsame)(uplo, "U", (ftnlen)1, (ftnlen)1) && ! PASTEF770(lsame)(uplo, "L", ( ftnlen)1, (ftnlen)1)) { info = 1; - } else if (! PASTEF770(lsame)(trans, "N", (ftnlen)1, (ftnlen)1) && ! PASTEF770(lsame)(trans, + } else if (! PASTEF770(lsame)(trans, "N", (ftnlen)1, (ftnlen)1) && ! PASTEF770(lsame)(trans, "T", (ftnlen)1, (ftnlen)1) && ! PASTEF770(lsame)(trans, "C", (ftnlen)1, ( ftnlen)1)) { info = 2; - } else if (! PASTEF770(lsame)(diag, "U", (ftnlen)1, (ftnlen)1) && ! PASTEF770(lsame)(diag, + } else if (! PASTEF770(lsame)(diag, "U", (ftnlen)1, (ftnlen)1) && ! 
PASTEF770(lsame)(diag, "N", (ftnlen)1, (ftnlen)1)) { info = 3; } else if (*n < 0) { @@ -1998,5 +2003,25 @@ } /* ztbmv_ */ +int PASTEF77(s,tbmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_real *a, const bla_integer *lda, bla_real *x, const bla_integer *incx) +{ + return PASTEF77S(s,tbmv)( uplo, trans, diag, n, k, a, lda, x, incx ); +} + +int PASTEF77(d,tbmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_double *a, const bla_integer *lda, bla_double *x, const bla_integer *incx) +{ + return PASTEF77S(d,tbmv)( uplo, trans, diag, n, k, a, lda, x, incx ); +} + +int PASTEF77(c,tbmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_scomplex *a, const bla_integer *lda, bla_scomplex *x, const bla_integer *incx) +{ + return PASTEF77S(c,tbmv)( uplo, trans, diag, n, k, a, lda, x, incx ); +} + +int PASTEF77(z,tbmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_dcomplex *a, const bla_integer *lda, bla_dcomplex *x, const bla_integer *incx) +{ + return PASTEF77S(z,tbmv)( uplo, trans, diag, n, k, a, lda, x, incx ); +} + #endif diff --git a/frame/compat/f2c/bla_tbmv.h b/frame/compat/f2c/bla_tbmv.h index f34654762b..cce9f18c8f 100644 --- a/frame/compat/f2c/bla_tbmv.h +++ b/frame/compat/f2c/bla_tbmv.h @@ -5,7 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -38,5 +39,9 @@ BLIS_EXPORT_BLAS int PASTEF77(c,tbmv)(const bla_character *uplo, const bla_chara BLIS_EXPORT_BLAS int PASTEF77(d,tbmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_double *a, const bla_integer *lda, bla_double *x, const bla_integer *incx); BLIS_EXPORT_BLAS int PASTEF77(s,tbmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_real *a, const bla_integer *lda, bla_real *x, const bla_integer *incx); BLIS_EXPORT_BLAS int PASTEF77(z,tbmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_dcomplex *a, const bla_integer *lda, bla_dcomplex *x, const bla_integer *incx); +BLIS_EXPORT_BLAS int PASTEF77S(c,tbmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_scomplex *a, const bla_integer *lda, bla_scomplex *x, const bla_integer *incx); +BLIS_EXPORT_BLAS int PASTEF77S(d,tbmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_double *a, const bla_integer *lda, bla_double *x, const bla_integer *incx); +BLIS_EXPORT_BLAS int PASTEF77S(s,tbmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_real *a, const bla_integer *lda, bla_real *x, const bla_integer *incx); +BLIS_EXPORT_BLAS int PASTEF77S(z,tbmv)(const bla_character *uplo, const 
bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_dcomplex *a, const bla_integer *lda, bla_dcomplex *x, const bla_integer *incx); #endif diff --git a/frame/compat/f2c/bla_tbsv.c b/frame/compat/f2c/bla_tbsv.c index 819456f029..25239c780d 100644 --- a/frame/compat/f2c/bla_tbsv.c +++ b/frame/compat/f2c/bla_tbsv.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021-2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -42,7 +42,8 @@ -lf2c -lm (in that order) */ -/* Subroutine */ int PASTEF77(c,tbsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_scomplex *a, const bla_integer *lda, bla_scomplex *x, const bla_integer *incx) +/* Subroutine */ +int PASTEF77S(c,tbsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_scomplex *a, const bla_integer *lda, bla_scomplex *x, const bla_integer *incx) { /* System generated locals */ bla_integer a_dim1, a_offset, i__1, i__2, i__3, i__4, i__5; @@ -622,7 +623,8 @@ -lf2c -lm (in that order) */ -/* Subroutine */ int PASTEF77(d,tbsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_double *a, const bla_integer *lda, bla_double *x, const bla_integer *incx) +/* Subroutine */ +int PASTEF77S(d,tbsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_double *a, const bla_integer *lda, bla_double *x, const bla_integer *incx) { /* System generated locals */ bla_integer a_dim1, a_offset, i__1, i__2, i__3, i__4; @@ -1053,7 +1055,8 @@ -lf2c -lm (in that order) */ -/* Subroutine */ int PASTEF77(s,tbsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_real *a, const bla_integer *lda, bla_real *x, const bla_integer *incx) +/* Subroutine */ +int PASTEF77S(s,tbsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_real *a, const bla_integer *lda, bla_real *x, const bla_integer *incx) { /* System generated locals */ bla_integer a_dim1, a_offset, i__1, i__2, i__3, i__4; @@ -1484,7 +1487,8 @@ -lf2c -lm (in that order) */ -/* Subroutine */ int PASTEF77(z,tbsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_dcomplex *a, const bla_integer *lda, bla_dcomplex *x, const bla_integer *incx) +/* Subroutine */ +int PASTEF77S(z,tbsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_dcomplex *a, const bla_integer *lda, bla_dcomplex *x, const bla_integer *incx) { /* System generated locals */ bla_integer a_dim1, a_offset, i__1, i__2, i__3, i__4, i__5; @@ -2058,5 +2062,26 @@ } /* ztbsv_ */ + +int PASTEF77(s,tbsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_real *a, const bla_integer *lda, bla_real *x, const 
bla_integer *incx) +{ + return PASTEF77S(s,tbsv)( uplo, trans, diag, n, k, a, lda, x, incx ); +} + +int PASTEF77(d,tbsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_double *a, const bla_integer *lda, bla_double *x, const bla_integer *incx) +{ + return PASTEF77S(d,tbsv)( uplo, trans, diag, n, k, a, lda, x, incx ); +} + +int PASTEF77(c,tbsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_scomplex *a, const bla_integer *lda, bla_scomplex *x, const bla_integer *incx) +{ + return PASTEF77S(c,tbsv)( uplo, trans, diag, n, k, a, lda, x, incx ); +} + +int PASTEF77(z,tbsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_dcomplex *a, const bla_integer *lda, bla_dcomplex *x, const bla_integer *incx) +{ + return PASTEF77S(z,tbsv)( uplo, trans, diag, n, k, a, lda, x, incx ); +} + #endif diff --git a/frame/compat/f2c/bla_tbsv.h b/frame/compat/f2c/bla_tbsv.h index 5e84f5c363..dd11e3f3f5 100644 --- a/frame/compat/f2c/bla_tbsv.h +++ b/frame/compat/f2c/bla_tbsv.h @@ -5,7 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -38,5 +39,9 @@ BLIS_EXPORT_BLAS int PASTEF77(c,tbsv)(const bla_character *uplo, const bla_chara BLIS_EXPORT_BLAS int PASTEF77(d,tbsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_double *a, const bla_integer *lda, bla_double *x, const bla_integer *incx); BLIS_EXPORT_BLAS int PASTEF77(s,tbsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_real *a, const bla_integer *lda, bla_real *x, const bla_integer *incx); BLIS_EXPORT_BLAS int PASTEF77(z,tbsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_dcomplex *a, const bla_integer *lda, bla_dcomplex *x, const bla_integer *incx); +BLIS_EXPORT_BLAS int PASTEF77S(c,tbsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_scomplex *a, const bla_integer *lda, bla_scomplex *x, const bla_integer *incx); +BLIS_EXPORT_BLAS int PASTEF77S(d,tbsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_double *a, const bla_integer *lda, bla_double *x, const bla_integer *incx); +BLIS_EXPORT_BLAS int PASTEF77S(s,tbsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_real *a, const bla_integer *lda, bla_real *x, const bla_integer *incx); +BLIS_EXPORT_BLAS int PASTEF77S(z,tbsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_dcomplex *a, const bla_integer *lda, bla_dcomplex *x, const bla_integer *incx); #endif diff --git a/frame/compat/f2c/bla_tpmv.c b/frame/compat/f2c/bla_tpmv.c index 8fa46f4c4f..f01e9f9b42 100644 --- a/frame/compat/f2c/bla_tpmv.c +++ 
b/frame/compat/f2c/bla_tpmv.c @@ -5,7 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -41,7 +42,8 @@ -lf2c -lm (in that order) */ -/* Subroutine */ int PASTEF77(c,tpmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_scomplex *ap, bla_scomplex *x, const bla_integer *incx) +/* Subroutine */ +int PASTEF77S(c,tpmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_scomplex *ap, bla_scomplex *x, const bla_integer *incx) { /* System generated locals */ bla_integer i__1, i__2, i__3, i__4, i__5; @@ -542,7 +544,8 @@ -lf2c -lm (in that order) */ -/* Subroutine */ int PASTEF77(d,tpmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_double *ap, bla_double *x, const bla_integer *incx) +/* Subroutine */ +int PASTEF77S(d,tpmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_double *ap, bla_double *x, const bla_integer *incx) { /* System generated locals */ bla_integer i__1, i__2; @@ -890,7 +893,8 @@ -lf2c -lm (in that order) */ -/* Subroutine */ int PASTEF77(s,tpmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_real *ap, bla_real *x, const bla_integer *incx) +/* Subroutine */ +int PASTEF77S(s,tpmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_real *ap, bla_real *x, const bla_integer *incx) { /* System generated locals */ bla_integer i__1, i__2; @@ -1238,7 +1242,8 @@ -lf2c -lm (in that order) */ -/* Subroutine */ int PASTEF77(z,tpmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_dcomplex *ap, bla_dcomplex *x, const bla_integer *incx) +/* Subroutine */ +int PASTEF77S(z,tpmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_dcomplex *ap, bla_dcomplex *x, const bla_integer *incx) { /* System generated locals */ bla_integer i__1, i__2, i__3, i__4, i__5; @@ -1734,5 +1739,24 @@ } /* ztpmv_ */ -#endif +int PASTEF77(s,tpmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_real *ap, bla_real *x, const bla_integer *incx) +{ + return PASTEF77S(s,tpmv)( uplo, trans, diag, n, ap, x, incx ); +} + +int PASTEF77(d,tpmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_double *ap, bla_double *x, const bla_integer *incx) +{ + return PASTEF77S(d,tpmv)( uplo, trans, diag, n, ap, x, incx ); +} + +int PASTEF77(c,tpmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_scomplex *ap, bla_scomplex *x, const bla_integer *incx) +{ + return PASTEF77S(c,tpmv)( uplo, trans, diag, n, ap, x, incx ); +} + +int PASTEF77(z,tpmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_dcomplex *ap, bla_dcomplex *x, const bla_integer *incx) +{ + return PASTEF77S(z,tpmv)( uplo, trans, diag, n, ap, x, incx ); +} +#endif \ No 
newline at end of file diff --git a/frame/compat/f2c/bla_tpmv.h b/frame/compat/f2c/bla_tpmv.h index 2376ecfe33..6af438f1da 100644 --- a/frame/compat/f2c/bla_tpmv.h +++ b/frame/compat/f2c/bla_tpmv.h @@ -5,7 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -38,5 +39,9 @@ BLIS_EXPORT_BLAS int PASTEF77(c,tpmv)(const bla_character *uplo, const bla_chara BLIS_EXPORT_BLAS int PASTEF77(d,tpmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_double *ap, bla_double *x, const bla_integer *incx); BLIS_EXPORT_BLAS int PASTEF77(s,tpmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_real *ap, bla_real *x, const bla_integer *incx); BLIS_EXPORT_BLAS int PASTEF77(z,tpmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_dcomplex *ap, bla_dcomplex *x, const bla_integer *incx); +BLIS_EXPORT_BLAS int PASTEF77S(c,tpmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_scomplex *ap, bla_scomplex *x, const bla_integer *incx); +BLIS_EXPORT_BLAS int PASTEF77S(d,tpmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_double *ap, bla_double *x, const bla_integer *incx); +BLIS_EXPORT_BLAS int PASTEF77S(s,tpmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_real *ap, bla_real *x, const bla_integer *incx); +BLIS_EXPORT_BLAS int PASTEF77S(z,tpmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_dcomplex *ap, bla_dcomplex *x, const bla_integer *incx); #endif diff --git a/frame/compat/f2c/bla_tpsv.c b/frame/compat/f2c/bla_tpsv.c index 0764940979..2619df9fea 100644 --- a/frame/compat/f2c/bla_tpsv.c +++ b/frame/compat/f2c/bla_tpsv.c @@ -5,7 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. 
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -41,7 +42,8 @@ -lf2c -lm (in that order) */ -/* Subroutine */ int PASTEF77(c,tpsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_scomplex *ap, bla_scomplex *x, const bla_integer *incx) +/* Subroutine */ +int PASTEF77S(c,tpsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_scomplex *ap, bla_scomplex *x, const bla_integer *incx) { /* System generated locals */ bla_integer i__1, i__2, i__3, i__4, i__5; @@ -534,7 +536,8 @@ -lf2c -lm (in that order) */ -/* Subroutine */ int PASTEF77(d,tpsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_double *ap, bla_double *x, const bla_integer *incx) +/* Subroutine */ +int PASTEF77S(d,tpsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_double *ap, bla_double *x, const bla_integer *incx) { /* System generated locals */ bla_integer i__1, i__2; @@ -885,7 +888,8 @@ -lf2c -lm (in that order) */ -/* Subroutine */ int PASTEF77(s,tpsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_real *ap, bla_real *x, const bla_integer *incx) +/* Subroutine */ +int PASTEF77S(s,tpsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_real *ap, bla_real *x, const bla_integer *incx) { /* System generated locals */ bla_integer i__1, i__2; @@ -1236,7 +1240,8 @@ -lf2c -lm (in that order) */ -/* Subroutine */ int PASTEF77(z,tpsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_dcomplex *ap, bla_dcomplex *x, const bla_integer *incx) +/* Subroutine */ +int PASTEF77S(z,tpsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_dcomplex *ap, bla_dcomplex *x, const bla_integer *incx) { /* System generated locals */ bla_integer i__1, i__2, i__3, i__4, i__5; @@ -1725,5 +1730,25 @@ } /* ztpsv_ */ +int PASTEF77(s,tpsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_real *ap, bla_real *x, const bla_integer *incx) +{ + return PASTEF77S(s,tpsv)( uplo, trans, diag, n, ap, x, incx ); +} + +int PASTEF77(d,tpsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_double *ap, bla_double *x, const bla_integer *incx) +{ + return PASTEF77S(d,tpsv)( uplo, trans, diag, n, ap, x, incx ); +} + +int PASTEF77(c,tpsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_scomplex *ap, bla_scomplex *x, const bla_integer *incx) +{ + return PASTEF77S(c,tpsv)( uplo, trans, diag, n, ap, x, incx ); +} + +int PASTEF77(z,tpsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_dcomplex *ap, bla_dcomplex *x, const bla_integer *incx) +{ + return PASTEF77S(z,tpsv)( uplo, trans, diag, n, ap, x, incx ); +} + #endif diff --git a/frame/compat/f2c/bla_tpsv.h b/frame/compat/f2c/bla_tpsv.h index 77bd55979a..7a3332424d 100644 --- a/frame/compat/f2c/bla_tpsv.h +++ b/frame/compat/f2c/bla_tpsv.h @@ -5,7 +5,8 @@ libraries. 
 
   Copyright (C) 2014, The University of Texas at Austin
-
+   Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved.
+
   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions are
   met:
@@ -38,5 +39,9 @@ BLIS_EXPORT_BLAS int PASTEF77(c,tpsv)(const bla_character *uplo, const bla_chara
 BLIS_EXPORT_BLAS int PASTEF77(d,tpsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_double *ap, bla_double *x, const bla_integer *incx);
 BLIS_EXPORT_BLAS int PASTEF77(s,tpsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_real *ap, bla_real *x, const bla_integer *incx);
 BLIS_EXPORT_BLAS int PASTEF77(z,tpsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_dcomplex *ap, bla_dcomplex *x, const bla_integer *incx);
+BLIS_EXPORT_BLAS int PASTEF77S(c,tpsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_scomplex *ap, bla_scomplex *x, const bla_integer *incx);
+BLIS_EXPORT_BLAS int PASTEF77S(d,tpsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_double *ap, bla_double *x, const bla_integer *incx);
+BLIS_EXPORT_BLAS int PASTEF77S(s,tpsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_real *ap, bla_real *x, const bla_integer *incx);
+BLIS_EXPORT_BLAS int PASTEF77S(z,tpsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_dcomplex *ap, bla_dcomplex *x, const bla_integer *incx);
 
 #endif
diff --git a/frame/include/bli_macro_defs.h b/frame/include/bli_macro_defs.h
index 75b9c9fdc4..9ab7c00aa7 100644
--- a/frame/include/bli_macro_defs.h
+++ b/frame/include/bli_macro_defs.h
@@ -158,17 +158,23 @@ #define PASTEMACT(ch1, ch2, ch3, ch4) bli_ ## ch1 ## ch2 ## _ ## ch3 ## _ ## ch4
 
 // name-mangling macros.
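/*
 * Illustrative sketch (not part of the patch): the hunk below splits the
 * name-mangling macros into two families. PASTEF77S names the core routine,
 * which now carries a `_blis_impl` suffix, while PASTEF77 keeps naming the
 * Fortran-callable symbol. Assuming a build without
 * BLIS_ENABLE_NO_UNDERSCORE_API, the expansions for dtpmv are:
 *
 *   PASTEF77S(d,tpmv)  ->  dtpmv_blis_impl   // core implementation
 *   PASTEF77(d,tpmv)   ->  dtpmv_            // Fortran-style symbol, thin wrapper
 *
 * so the wrapper added at the end of bla_tpmv.c is effectively:
 */
int dtpmv_(const bla_character *uplo, const bla_character *trans, const bla_character *diag,
           const bla_integer *n, const bla_double *ap, bla_double *x, const bla_integer *incx)
{
    /* forward straight to the _blis_impl core routine */
    return dtpmv_blis_impl( uplo, trans, diag, n, ap, x, incx );
}
/*
 * The upper-case and trailing-underscore entry points in bli_util_api_wrap.c
 * (e.g. DTPMV, dtpmv, DTPMV_) are likewise rewired to call dtpmv_blis_impl
 * directly rather than going through the Fortran symbol.
 */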
 #ifdef BLIS_ENABLE_NO_UNDERSCORE_API
-#define PASTEF770(name)              name
-#define PASTEF77(ch1,name)           ch1 ## name
-#define PASTEF772(ch1,ch2,name)      ch1 ## ch2 ## name
-#define PASTEF773(ch1,ch2,ch3,name)  ch1 ## ch2 ## ch3 ## name
-#define PASTEF77S(ch1,name)          ch1 ## name ## _blis_impl
+#define PASTEF770(name)              name
+#define PASTEF77(ch1,name)           ch1 ## name
+#define PASTEF772(ch1,ch2,name)      ch1 ## ch2 ## name
+#define PASTEF773(ch1,ch2,ch3,name)  ch1 ## ch2 ## ch3 ## name
+#define PASTEF770S(name)             name ## _blis_impl
+#define PASTEF77S(ch1,name)          ch1 ## name ## _blis_impl
+#define PASTEF772S(ch1,ch2,name)     ch1 ## ch2 ## name ## _blis_impl
+#define PASTEF773S(ch1,ch2,ch3,name) ch1 ## ch2 ## ch3 ## name ## _blis_impl
 #else
-#define PASTEF770(name)              name ## _
-#define PASTEF77(ch1,name)           ch1 ## name ## _
-#define PASTEF772(ch1,ch2,name)      ch1 ## ch2 ## name ## _
-#define PASTEF773(ch1,ch2,ch3,name)  ch1 ## ch2 ## ch3 ## name ## _
-#define PASTEF77S(ch1,name)          ch1 ## name ## _blis_impl
+#define PASTEF770(name)              name ## _
+#define PASTEF77(ch1,name)           ch1 ## name ## _
+#define PASTEF772(ch1,ch2,name)      ch1 ## ch2 ## name ## _
+#define PASTEF773(ch1,ch2,ch3,name)  ch1 ## ch2 ## ch3 ## name ## _
+#define PASTEF770S(name)             name ## _blis_impl
+#define PASTEF77S(ch1,name)          ch1 ## name ## _blis_impl
+#define PASTEF772S(ch1,ch2,name)     ch1 ## ch2 ## name ## _blis_impl
+#define PASTEF773S(ch1,ch2,ch3,name) ch1 ## ch2 ## ch3 ## name ## _blis_impl
 #endif
 
 // -- Include other groups of macros
 
diff --git a/frame/util/bli_util_api_wrap.c b/frame/util/bli_util_api_wrap.c
index 9e8d1ccc38..098a7c33e8 100644
--- a/frame/util/bli_util_api_wrap.c
+++ b/frame/util/bli_util_api_wrap.c
@@ -195,17 +195,17 @@ void ZDOTU_(dcomplex* retval,const f77_int *n, const dcomplex *zx, const f77_int
 
 void CGBMV(const char *trans,const f77_int *m,const f77_int *n,const f77_int *kl,const f77_int *ku,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy)
 {
-    cgbmv_( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy);
+    cgbmv_blis_impl( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy);
 }
 
 void cgbmv(const char *trans,const f77_int *m,const f77_int *n,const f77_int *kl,const f77_int *ku,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy)
 {
-    cgbmv_( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy);
+    cgbmv_blis_impl( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy);
 }
 
 void CGBMV_(const char *trans,const f77_int *m,const f77_int *n,const f77_int *kl,const f77_int *ku,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy)
 {
-    cgbmv_( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy);
+    cgbmv_blis_impl( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy);
 }
 
 void CGEMM(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc)
@@ -225,62 +225,62 @@ void CGEMM_(const char *transa,const char *transb,const f77_int *m,const f77
 
 void CGEMV(const char *trans,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy)
 {
-    cgemv_( trans,
m, n, alpha, a, lda, x, incx, beta, y, incy); + cgemv_blis_impl( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); } void cgemv(const char *trans,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy) { - cgemv_( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); + cgemv_blis_impl( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); } void CGEMV_(const char *trans,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy) { - cgemv_( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); + cgemv_blis_impl( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); } void CGERC(const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *x,const f77_int *incx,const scomplex *y,const f77_int *incy,scomplex *a,const f77_int *lda) { - cgerc_( m, n, alpha, x, incx, y, incy, a, lda); + cgerc_blis_impl( m, n, alpha, x, incx, y, incy, a, lda); } void cgerc(const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *x,const f77_int *incx,const scomplex *y,const f77_int *incy,scomplex *a,const f77_int *lda) { - cgerc_( m, n, alpha, x, incx, y, incy, a, lda); + cgerc_blis_impl( m, n, alpha, x, incx, y, incy, a, lda); } void CGERC_(const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *x,const f77_int *incx,const scomplex *y,const f77_int *incy,scomplex *a,const f77_int *lda) { - cgerc_( m, n, alpha, x, incx, y, incy, a, lda); + cgerc_blis_impl( m, n, alpha, x, incx, y, incy, a, lda); } void CGERU(const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *x,const f77_int *incx,const scomplex *y,const f77_int *incy,scomplex *a,const f77_int *lda) { - cgeru_( m, n, alpha, x, incx, y, incy, a, lda); + cgeru_blis_impl( m, n, alpha, x, incx, y, incy, a, lda); } void cgeru(const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *x,const f77_int *incx,const scomplex *y,const f77_int *incy,scomplex *a,const f77_int *lda) { - cgeru_( m, n, alpha, x, incx, y, incy, a, lda); + cgeru_blis_impl( m, n, alpha, x, incx, y, incy, a, lda); } void CGERU_(const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *x,const f77_int *incx,const scomplex *y,const f77_int *incy,scomplex *a,const f77_int *lda) { - cgeru_( m, n, alpha, x, incx, y, incy, a, lda); + cgeru_blis_impl( m, n, alpha, x, incx, y, incy, a, lda); } void CHBMV(const char *uplo,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy) { - chbmv_( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); + chbmv_blis_impl( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); } void chbmv(const char *uplo,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy) { - chbmv_( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); + chbmv_blis_impl( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); } void CHBMV_(const char *uplo,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy) { - chbmv_( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); + chbmv_blis_impl( 
uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); } void CHEMM(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) @@ -300,47 +300,47 @@ void CHEMM_(const char *side,const char *uplo,const f77_int *m,const f77_int void CHEMV(const char *uplo,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy) { - chemv_( uplo, n, alpha, a, lda, x, incx, beta, y, incy); + chemv_blis_impl( uplo, n, alpha, a, lda, x, incx, beta, y, incy); } void chemv(const char *uplo,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy) { - chemv_( uplo, n, alpha, a, lda, x, incx, beta, y, incy); + chemv_blis_impl( uplo, n, alpha, a, lda, x, incx, beta, y, incy); } void CHEMV_(const char *uplo,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy) { - chemv_( uplo, n, alpha, a, lda, x, incx, beta, y, incy); + chemv_blis_impl( uplo, n, alpha, a, lda, x, incx, beta, y, incy); } void CHER(const char *uplo,const f77_int *n,const float *alpha,const scomplex *x,const f77_int *incx,scomplex *a,const f77_int *lda) { - cher_( uplo, n, alpha, x, incx, a, lda); + cher_blis_impl( uplo, n, alpha, x, incx, a, lda); } void cher(const char *uplo,const f77_int *n,const float *alpha,const scomplex *x,const f77_int *incx,scomplex *a,const f77_int *lda) { - cher_( uplo, n, alpha, x, incx, a, lda); + cher_blis_impl( uplo, n, alpha, x, incx, a, lda); } void CHER_(const char *uplo,const f77_int *n,const float *alpha,const scomplex *x,const f77_int *incx,scomplex *a,const f77_int *lda) { - cher_( uplo, n, alpha, x, incx, a, lda); + cher_blis_impl( uplo, n, alpha, x, incx, a, lda); } void CHER2(const char *uplo,const f77_int *n,const scomplex *alpha,const scomplex *x,const f77_int *incx,const scomplex *y,const f77_int *incy,scomplex *a,const f77_int *lda) { - cher2_( uplo, n, alpha, x, incx, y, incy, a, lda); + cher2_blis_impl( uplo, n, alpha, x, incx, y, incy, a, lda); } void cher2(const char *uplo,const f77_int *n,const scomplex *alpha,const scomplex *x,const f77_int *incx,const scomplex *y,const f77_int *incy,scomplex *a,const f77_int *lda) { - cher2_( uplo, n, alpha, x, incx, y, incy, a, lda); + cher2_blis_impl( uplo, n, alpha, x, incx, y, incy, a, lda); } void CHER2_(const char *uplo,const f77_int *n,const scomplex *alpha,const scomplex *x,const f77_int *incx,const scomplex *y,const f77_int *incy,scomplex *a,const f77_int *lda) { - cher2_( uplo, n, alpha, x, incx, y, incy, a, lda); + cher2_blis_impl( uplo, n, alpha, x, incx, y, incy, a, lda); } void CHER2K(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const float *beta,scomplex *c,const f77_int *ldc) @@ -375,47 +375,47 @@ void CHERK_(const char *uplo,const char *trans,const f77_int *n,const f77_in void CHPMV(const char *uplo,const f77_int *n,const scomplex *alpha,const scomplex *ap,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy) { - chpmv_( uplo, n, alpha, ap, x, incx, beta, y, incy); + chpmv_blis_impl( uplo, n, alpha, ap, x, incx, 
beta, y, incy); } void chpmv(const char *uplo,const f77_int *n,const scomplex *alpha,const scomplex *ap,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy) { - chpmv_( uplo, n, alpha, ap, x, incx, beta, y, incy); + chpmv_blis_impl( uplo, n, alpha, ap, x, incx, beta, y, incy); } void CHPMV_(const char *uplo,const f77_int *n,const scomplex *alpha,const scomplex *ap,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy) { - chpmv_( uplo, n, alpha, ap, x, incx, beta, y, incy); + chpmv_blis_impl( uplo, n, alpha, ap, x, incx, beta, y, incy); } void CHPR(const char *uplo,const f77_int *n,const float *alpha,const scomplex *x,const f77_int *incx,scomplex *ap) { - chpr_( uplo, n, alpha, x, incx, ap); + chpr_blis_impl( uplo, n, alpha, x, incx, ap); } void chpr(const char *uplo,const f77_int *n,const float *alpha,const scomplex *x,const f77_int *incx,scomplex *ap) { - chpr_( uplo, n, alpha, x, incx, ap); + chpr_blis_impl( uplo, n, alpha, x, incx, ap); } void CHPR_(const char *uplo,const f77_int *n,const float *alpha,const scomplex *x,const f77_int *incx,scomplex *ap) { - chpr_( uplo, n, alpha, x, incx, ap); + chpr_blis_impl( uplo, n, alpha, x, incx, ap); } void CHPR2(const char *uplo,const f77_int *n,const scomplex *alpha,const scomplex *x,const f77_int *incx,const scomplex *y,const f77_int *incy,scomplex *ap) { - chpr2_( uplo, n, alpha, x, incx, y, incy, ap); + chpr2_blis_impl( uplo, n, alpha, x, incx, y, incy, ap); } void chpr2(const char *uplo,const f77_int *n,const scomplex *alpha,const scomplex *x,const f77_int *incx,const scomplex *y,const f77_int *incy,scomplex *ap) { - chpr2_( uplo, n, alpha, x, incx, y, incy, ap); + chpr2_blis_impl( uplo, n, alpha, x, incx, y, incy, ap); } void CHPR2_(const char *uplo,const f77_int *n,const scomplex *alpha,const scomplex *x,const f77_int *incx,const scomplex *y,const f77_int *incy,scomplex *ap) { - chpr2_( uplo, n, alpha, x, incx, y, incy, ap); + chpr2_blis_impl( uplo, n, alpha, x, incx, y, incy, ap); } void CROTG(scomplex *ca, bla_scomplex *cb, bla_real *c,scomplex *s) @@ -540,62 +540,62 @@ void CSYRK_(const char *uplo,const char *trans,const f77_int *n,const f77_in void CTBMV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const scomplex *a,const f77_int *lda,scomplex *x,const f77_int *incx) { - ctbmv_( uplo, trans, diag, n, k, a, lda, x, incx); + ctbmv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); } void ctbmv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const scomplex *a,const f77_int *lda,scomplex *x,const f77_int *incx) { - ctbmv_( uplo, trans, diag, n, k, a, lda, x, incx); + ctbmv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); } void CTBMV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const scomplex *a,const f77_int *lda,scomplex *x,const f77_int *incx) { - ctbmv_( uplo, trans, diag, n, k, a, lda, x, incx); + ctbmv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); } void CTBSV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const scomplex *a,const f77_int *lda,scomplex *x,const f77_int *incx) { - ctbsv_( uplo, trans, diag, n, k, a, lda, x, incx); + ctbsv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); } void ctbsv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const scomplex *a,const f77_int *lda,scomplex *x,const f77_int *incx) { - ctbsv_( uplo, trans, diag, 
n, k, a, lda, x, incx); + ctbsv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); } void CTBSV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const scomplex *a,const f77_int *lda,scomplex *x,const f77_int *incx) { - ctbsv_( uplo, trans, diag, n, k, a, lda, x, incx); + ctbsv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); } void CTPMV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const scomplex *ap,scomplex *x,const f77_int *incx) { - ctpmv_( uplo, trans, diag, n, ap, x, incx); + ctpmv_blis_impl( uplo, trans, diag, n, ap, x, incx); } void ctpmv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const scomplex *ap,scomplex *x,const f77_int *incx) { - ctpmv_( uplo, trans, diag, n, ap, x, incx); + ctpmv_blis_impl( uplo, trans, diag, n, ap, x, incx); } void CTPMV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const scomplex *ap,scomplex *x,const f77_int *incx) { - ctpmv_( uplo, trans, diag, n, ap, x, incx); + ctpmv_blis_impl( uplo, trans, diag, n, ap, x, incx); } void CTPSV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const scomplex *ap,scomplex *x,const f77_int *incx) { - ctpsv_( uplo, trans, diag, n, ap, x, incx); + ctpsv_blis_impl( uplo, trans, diag, n, ap, x, incx); } void ctpsv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const scomplex *ap,scomplex *x,const f77_int *incx) { - ctpsv_( uplo, trans, diag, n, ap, x, incx); + ctpsv_blis_impl( uplo, trans, diag, n, ap, x, incx); } void CTPSV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const scomplex *ap,scomplex *x,const f77_int *incx) { - ctpsv_( uplo, trans, diag, n, ap, x, incx); + ctpsv_blis_impl( uplo, trans, diag, n, ap, x, incx); } void CTRMM(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,scomplex *b,const f77_int *ldb) @@ -615,17 +615,17 @@ void CTRMM_(const char *side,const char *uplo,const char *transa,const cha void CTRMV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const scomplex *a,const f77_int *lda,scomplex *x,const f77_int *incx) { - ctrmv_( uplo, trans, diag, n, a, lda, x, incx); + ctrmv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); } void ctrmv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const scomplex *a,const f77_int *lda,scomplex *x,const f77_int *incx) { - ctrmv_( uplo, trans, diag, n, a, lda, x, incx); + ctrmv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); } void CTRMV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const scomplex *a,const f77_int *lda,scomplex *x,const f77_int *incx) { - ctrmv_( uplo, trans, diag, n, a, lda, x, incx); + ctrmv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); } void CTRSM(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,scomplex *b,const f77_int *ldb) @@ -645,17 +645,17 @@ void CTRSM_(const char *side,const char *uplo,const char *transa,const cha void CTRSV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const scomplex *a,const f77_int *lda,scomplex *x,const f77_int *incx) { - ctrsv_( uplo, trans, diag, n, a, lda, x, incx); + ctrsv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); } void ctrsv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const scomplex 
*a,const f77_int *lda,scomplex *x,const f77_int *incx) { - ctrsv_( uplo, trans, diag, n, a, lda, x, incx); + ctrsv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); } void CTRSV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const scomplex *a,const f77_int *lda,scomplex *x,const f77_int *incx) { - ctrsv_( uplo, trans, diag, n, a, lda, x, incx); + ctrsv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); } double DASUM(const f77_int *n,const double *dx,const f77_int *incx) @@ -735,17 +735,17 @@ double DDOT_(const f77_int *n,const double *dx,const f77_int *incx,const double void DGBMV(const char *trans,const f77_int *m,const f77_int *n,const f77_int *kl,const f77_int *ku,const double *alpha,const double *a,const f77_int *lda,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) { - dgbmv_( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); + dgbmv_blis_impl( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); } void dgbmv(const char *trans,const f77_int *m,const f77_int *n,const f77_int *kl,const f77_int *ku,const double *alpha,const double *a,const f77_int *lda,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) { - dgbmv_( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); + dgbmv_blis_impl( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); } void DGBMV_(const char *trans,const f77_int *m,const f77_int *n,const f77_int *kl,const f77_int *ku,const double *alpha,const double *a,const f77_int *lda,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) { - dgbmv_( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); + dgbmv_blis_impl( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); } void DGEMM(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *b,const f77_int *ldb,const double *beta,double *c,const f77_int *ldc) @@ -765,32 +765,32 @@ void DGEMM_(const char *transa,const char *transb,const f77_int *m,const f77 void DGEMV(const char *trans,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) { - dgemv_( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); + dgemv_blis_impl( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); } void dgemv(const char *trans,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) { - dgemv_( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); + dgemv_blis_impl( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); } void DGEMV_(const char *trans,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) { - dgemv_( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); + dgemv_blis_impl( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); } void DGER(const f77_int *m,const f77_int *n,const double *alpha,const double *x,const f77_int *incx,const double *y,const f77_int *incy,double *a,const f77_int *lda) { - dger_( m, n, alpha, x, incx, y, incy, a, lda); + dger_blis_impl( m, n, alpha, x, incx, y, incy, a, lda); } void dger(const f77_int *m,const f77_int *n,const double *alpha,const double *x,const f77_int *incx,const double *y,const 
f77_int *incy,double *a,const f77_int *lda) { - dger_( m, n, alpha, x, incx, y, incy, a, lda); + dger_blis_impl( m, n, alpha, x, incx, y, incy, a, lda); } void DGER_(const f77_int *m,const f77_int *n,const double *alpha,const double *x,const f77_int *incx,const double *y,const f77_int *incy,double *a,const f77_int *lda) { - dger_( m, n, alpha, x, incx, y, incy, a, lda); + dger_blis_impl( m, n, alpha, x, incx, y, incy, a, lda); } double DNRM2(const f77_int *n,const double *x,const f77_int *incx) @@ -870,17 +870,17 @@ void DROTMG_(double *dd1,double *dd2,double *dx1,const double *dy1,double *dpara void DSBMV(const char *uplo,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) { - dsbmv_( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); + dsbmv_blis_impl( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); } void dsbmv(const char *uplo,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) { - dsbmv_( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); + dsbmv_blis_impl( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); } void DSBMV_(const char *uplo,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) { - dsbmv_( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); + dsbmv_blis_impl( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); } void DSCAL(const f77_int *n,const double *da,double *dx,const f77_int *incx) @@ -915,47 +915,47 @@ double DSDOT_(const f77_int *n,const float *sx,const f77_int *incx,const float void DSPMV(const char *uplo,const f77_int *n,const double *alpha,const double *ap,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) { - dspmv_( uplo, n, alpha, ap, x, incx, beta, y, incy); + dspmv_blis_impl( uplo, n, alpha, ap, x, incx, beta, y, incy); } void dspmv(const char *uplo,const f77_int *n,const double *alpha,const double *ap,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) { - dspmv_( uplo, n, alpha, ap, x, incx, beta, y, incy); + dspmv_blis_impl( uplo, n, alpha, ap, x, incx, beta, y, incy); } void DSPMV_(const char *uplo,const f77_int *n,const double *alpha,const double *ap,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) { - dspmv_( uplo, n, alpha, ap, x, incx, beta, y, incy); + dspmv_blis_impl( uplo, n, alpha, ap, x, incx, beta, y, incy); } void DSPR(const char *uplo,const f77_int *n,const double *alpha,const double *x,const f77_int *incx,double *ap) { - dspr_( uplo, n, alpha, x, incx, ap); + dspr_blis_impl( uplo, n, alpha, x, incx, ap); } void dspr(const char *uplo,const f77_int *n,const double *alpha,const double *x,const f77_int *incx,double *ap) { - dspr_( uplo, n, alpha, x, incx, ap); + dspr_blis_impl( uplo, n, alpha, x, incx, ap); } void DSPR_(const char *uplo,const f77_int *n,const double *alpha,const double *x,const f77_int *incx,double *ap) { - dspr_( uplo, n, alpha, x, incx, ap); + dspr_blis_impl( uplo, n, alpha, x, incx, ap); } void DSPR2(const char *uplo,const f77_int *n,const double *alpha,const double *x,const f77_int *incx,const double *y,const f77_int *incy,double *ap) { - dspr2_( uplo, n, alpha, x, incx, y, incy, ap); + dspr2_blis_impl( uplo, n, alpha, x, incx, y, incy, ap); } void 
dspr2(const char *uplo,const f77_int *n,const double *alpha,const double *x,const f77_int *incx,const double *y,const f77_int *incy,double *ap) { - dspr2_( uplo, n, alpha, x, incx, y, incy, ap); + dspr2_blis_impl( uplo, n, alpha, x, incx, y, incy, ap); } void DSPR2_(const char *uplo,const f77_int *n,const double *alpha,const double *x,const f77_int *incx,const double *y,const f77_int *incy,double *ap) { - dspr2_( uplo, n, alpha, x, incx, y, incy, ap); + dspr2_blis_impl( uplo, n, alpha, x, incx, y, incy, ap); } void DSWAP(const f77_int *n,double *dx,const f77_int *incx,double *dy,const f77_int *incy) @@ -990,47 +990,47 @@ void DSYMM_(const char *side,const char *uplo,const f77_int *m,const f77_int void DSYMV(const char *uplo,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) { - dsymv_( uplo, n, alpha, a, lda, x, incx, beta, y, incy); + dsymv_blis_impl( uplo, n, alpha, a, lda, x, incx, beta, y, incy); } void dsymv(const char *uplo,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) { - dsymv_( uplo, n, alpha, a, lda, x, incx, beta, y, incy); + dsymv_blis_impl( uplo, n, alpha, a, lda, x, incx, beta, y, incy); } void DSYMV_(const char *uplo,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) { - dsymv_( uplo, n, alpha, a, lda, x, incx, beta, y, incy); + dsymv_blis_impl( uplo, n, alpha, a, lda, x, incx, beta, y, incy); } void DSYR(const char *uplo,const f77_int *n,const double *alpha,const double *x,const f77_int *incx,double *a,const f77_int *lda) { - dsyr_( uplo, n, alpha, x, incx, a, lda); + dsyr_blis_impl( uplo, n, alpha, x, incx, a, lda); } void dsyr(const char *uplo,const f77_int *n,const double *alpha,const double *x,const f77_int *incx,double *a,const f77_int *lda) { - dsyr_( uplo, n, alpha, x, incx, a, lda); + dsyr_blis_impl( uplo, n, alpha, x, incx, a, lda); } void DSYR_(const char *uplo,const f77_int *n,const double *alpha,const double *x,const f77_int *incx,double *a,const f77_int *lda) { - dsyr_( uplo, n, alpha, x, incx, a, lda); + dsyr_blis_impl( uplo, n, alpha, x, incx, a, lda); } void DSYR2(const char *uplo,const f77_int *n,const double *alpha,const double *x,const f77_int *incx,const double *y,const f77_int *incy,double *a,const f77_int *lda) { - dsyr2_( uplo, n, alpha, x, incx, y, incy, a, lda); + dsyr2_blis_impl( uplo, n, alpha, x, incx, y, incy, a, lda); } void dsyr2(const char *uplo,const f77_int *n,const double *alpha,const double *x,const f77_int *incx,const double *y,const f77_int *incy,double *a,const f77_int *lda) { - dsyr2_( uplo, n, alpha, x, incx, y, incy, a, lda); + dsyr2_blis_impl( uplo, n, alpha, x, incx, y, incy, a, lda); } void DSYR2_(const char *uplo,const f77_int *n,const double *alpha,const double *x,const f77_int *incx,const double *y,const f77_int *incy,double *a,const f77_int *lda) { - dsyr2_( uplo, n, alpha, x, incx, y, incy, a, lda); + dsyr2_blis_impl( uplo, n, alpha, x, incx, y, incy, a, lda); } void DSYR2K(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *b,const f77_int *ldb,const double *beta,double *c,const f77_int *ldc) @@ -1065,62 +1065,62 @@ void DSYRK_(const char *uplo,const char *trans,const f77_int *n,const f77_in void DTBMV(const char 
*uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const double *a,const f77_int *lda,double *x,const f77_int *incx) { - dtbmv_( uplo, trans, diag, n, k, a, lda, x, incx); + dtbmv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); } void dtbmv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const double *a,const f77_int *lda,double *x,const f77_int *incx) { - dtbmv_( uplo, trans, diag, n, k, a, lda, x, incx); + dtbmv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); } void DTBMV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const double *a,const f77_int *lda,double *x,const f77_int *incx) { - dtbmv_( uplo, trans, diag, n, k, a, lda, x, incx); + dtbmv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); } void DTBSV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const double *a,const f77_int *lda,double *x,const f77_int *incx) { - dtbsv_( uplo, trans, diag, n, k, a, lda, x, incx); + dtbsv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); } void dtbsv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const double *a,const f77_int *lda,double *x,const f77_int *incx) { - dtbsv_( uplo, trans, diag, n, k, a, lda, x, incx); + dtbsv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); } void DTBSV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const double *a,const f77_int *lda,double *x,const f77_int *incx) { - dtbsv_( uplo, trans, diag, n, k, a, lda, x, incx); + dtbsv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); } void DTPMV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const double *ap,double *x,const f77_int *incx) { - dtpmv_( uplo, trans, diag, n, ap, x, incx); + dtpmv_blis_impl( uplo, trans, diag, n, ap, x, incx); } void dtpmv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const double *ap,double *x,const f77_int *incx) { - dtpmv_( uplo, trans, diag, n, ap, x, incx); + dtpmv_blis_impl( uplo, trans, diag, n, ap, x, incx); } void DTPMV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const double *ap,double *x,const f77_int *incx) { - dtpmv_( uplo, trans, diag, n, ap, x, incx); + dtpmv_blis_impl( uplo, trans, diag, n, ap, x, incx); } void DTPSV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const double *ap,double *x,const f77_int *incx) { - dtpsv_( uplo, trans, diag, n, ap, x, incx); + dtpsv_blis_impl( uplo, trans, diag, n, ap, x, incx); } void dtpsv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const double *ap,double *x,const f77_int *incx) { - dtpsv_( uplo, trans, diag, n, ap, x, incx); + dtpsv_blis_impl( uplo, trans, diag, n, ap, x, incx); } void DTPSV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const double *ap,double *x,const f77_int *incx) { - dtpsv_( uplo, trans, diag, n, ap, x, incx); + dtpsv_blis_impl( uplo, trans, diag, n, ap, x, incx); } void DTRMM(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,double *b,const f77_int *ldb) @@ -1140,17 +1140,17 @@ void DTRMM_(const char *side,const char *uplo,const char *transa,const cha void DTRMV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const double *a,const f77_int *lda,double *x,const f77_int *incx) { - dtrmv_( uplo, trans, diag, n, a, 
lda, x, incx); + dtrmv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); } void dtrmv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const double *a,const f77_int *lda,double *x,const f77_int *incx) { - dtrmv_( uplo, trans, diag, n, a, lda, x, incx); + dtrmv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); } void DTRMV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const double *a,const f77_int *lda,double *x,const f77_int *incx) { - dtrmv_( uplo, trans, diag, n, a, lda, x, incx); + dtrmv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); } void DTRSM(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,double *b,const f77_int *ldb) @@ -1170,17 +1170,17 @@ void DTRSM_(const char *side,const char *uplo,const char *transa,const cha void DTRSV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const double *a,const f77_int *lda,double *x,const f77_int *incx) { - dtrsv_( uplo, trans, diag, n, a, lda, x, incx); + dtrsv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); } void dtrsv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const double *a,const f77_int *lda,double *x,const f77_int *incx) { - dtrsv_( uplo, trans, diag, n, a, lda, x, incx); + dtrsv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); } void DTRSV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const double *a,const f77_int *lda,double *x,const f77_int *incx) { - dtrsv_( uplo, trans, diag, n, a, lda, x, incx); + dtrsv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); } double DZASUM(const f77_int *n,const dcomplex *zx,const f77_int *incx) @@ -1402,17 +1402,17 @@ float SDSDOT_(const f77_int *n,const float *sb, const float *sx, const f77_int void SGBMV(const char *trans,const f77_int *m,const f77_int *n,const f77_int *kl,const f77_int *ku,const float *alpha,const float *a,const f77_int *lda,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) { - sgbmv_( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); + sgbmv_blis_impl( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); } void sgbmv(const char *trans,const f77_int *m,const f77_int *n,const f77_int *kl,const f77_int *ku,const float *alpha,const float *a,const f77_int *lda,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) { - sgbmv_( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); + sgbmv_blis_impl( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); } void SGBMV_(const char *trans,const f77_int *m,const f77_int *n,const f77_int *kl,const f77_int *ku,const float *alpha,const float *a,const f77_int *lda,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) { - sgbmv_( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); + sgbmv_blis_impl( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); } void SGEMM(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *b,const f77_int *ldb,const float *beta,float *c,const f77_int *ldc) @@ -1432,32 +1432,32 @@ void SGEMM_(const char *transa,const char *transb,const f77_int *m,const f77 void SGEMV(const char *trans,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) { - 
sgemv_( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); + sgemv_blis_impl( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); } void sgemv(const char *trans,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) { - sgemv_( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); + sgemv_blis_impl( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); } void SGEMV_(const char *trans,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) { - sgemv_( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); + sgemv_blis_impl( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); } void SGER(const f77_int *m,const f77_int *n,const float *alpha,const float *x,const f77_int *incx,const float *y,const f77_int *incy,float *a,const f77_int *lda) { - sger_( m, n, alpha, x, incx, y, incy, a, lda); + sger_blis_impl( m, n, alpha, x, incx, y, incy, a, lda); } void sger(const f77_int *m,const f77_int *n,const float *alpha,const float *x,const f77_int *incx,const float *y,const f77_int *incy,float *a,const f77_int *lda) { - sger_( m, n, alpha, x, incx, y, incy, a, lda); + sger_blis_impl( m, n, alpha, x, incx, y, incy, a, lda); } void SGER_(const f77_int *m,const f77_int *n,const float *alpha,const float *x,const f77_int *incx,const float *y,const f77_int *incy,float *a,const f77_int *lda) { - sger_( m, n, alpha, x, incx, y, incy, a, lda); + sger_blis_impl( m, n, alpha, x, incx, y, incy, a, lda); } @@ -1539,17 +1539,17 @@ void SROTMG_(float *sd1,float *sd2,float *sx1,const float *sy1,float *spara void SSBMV(const char *uplo,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) { - ssbmv_( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); + ssbmv_blis_impl( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); } void ssbmv(const char *uplo,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) { - ssbmv_( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); + ssbmv_blis_impl( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); } void SSBMV_(const char *uplo,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) { - ssbmv_( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); + ssbmv_blis_impl( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); } void SSCAL(const f77_int *n,const float *sa,float *sx,const f77_int *incx) @@ -1569,47 +1569,47 @@ void SSCAL_(const f77_int *n,const float *sa,float *sx,const f77_int *incx) void SSPMV(const char *uplo,const f77_int *n,const float *alpha,const float *ap,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) { - sspmv_( uplo, n, alpha, ap, x, incx, beta, y, incy); + sspmv_blis_impl( uplo, n, alpha, ap, x, incx, beta, y, incy); } void sspmv(const char *uplo,const f77_int *n,const float *alpha,const float *ap,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) { - sspmv_( uplo, n, alpha, ap, x, incx, beta, y, incy); + sspmv_blis_impl( uplo, n, alpha, ap, x, incx, beta, y, incy); } void SSPMV_(const char *uplo,const f77_int *n,const float 
*alpha,const float *ap,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) { - sspmv_( uplo, n, alpha, ap, x, incx, beta, y, incy); + sspmv_blis_impl( uplo, n, alpha, ap, x, incx, beta, y, incy); } void SSPR(const char *uplo,const f77_int *n,const float *alpha,const float *x,const f77_int *incx,float *ap) { - sspr_( uplo, n, alpha, x, incx, ap); + sspr_blis_impl( uplo, n, alpha, x, incx, ap); } void sspr(const char *uplo,const f77_int *n,const float *alpha,const float *x,const f77_int *incx,float *ap) { - sspr_( uplo, n, alpha, x, incx, ap); + sspr_blis_impl( uplo, n, alpha, x, incx, ap); } void SSPR_(const char *uplo,const f77_int *n,const float *alpha,const float *x,const f77_int *incx,float *ap) { - sspr_( uplo, n, alpha, x, incx, ap); + sspr_blis_impl( uplo, n, alpha, x, incx, ap); } void SSPR2(const char *uplo,const f77_int *n,const float *alpha,const float *x,const f77_int *incx,const float *y,const f77_int *incy,float *ap) { - sspr2_( uplo, n, alpha, x, incx, y, incy, ap); + sspr2_blis_impl( uplo, n, alpha, x, incx, y, incy, ap); } void sspr2(const char *uplo,const f77_int *n,const float *alpha,const float *x,const f77_int *incx,const float *y,const f77_int *incy,float *ap) { - sspr2_( uplo, n, alpha, x, incx, y, incy, ap); + sspr2_blis_impl( uplo, n, alpha, x, incx, y, incy, ap); } void SSPR2_(const char *uplo,const f77_int *n,const float *alpha,const float *x,const f77_int *incx,const float *y,const f77_int *incy,float *ap) { - sspr2_( uplo, n, alpha, x, incx, y, incy, ap); + sspr2_blis_impl( uplo, n, alpha, x, incx, y, incy, ap); } void SSWAP(const f77_int *n,float *sx,const f77_int *incx,float *sy,const f77_int *incy) @@ -1644,47 +1644,47 @@ void SSYMM_(const char *side,const char *uplo,const f77_int *m,const f77_int void SSYMV(const char *uplo,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) { - ssymv_( uplo, n, alpha, a, lda, x, incx, beta, y, incy); + ssymv_blis_impl( uplo, n, alpha, a, lda, x, incx, beta, y, incy); } void ssymv(const char *uplo,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) { - ssymv_( uplo, n, alpha, a, lda, x, incx, beta, y, incy); + ssymv_blis_impl( uplo, n, alpha, a, lda, x, incx, beta, y, incy); } void SSYMV_(const char *uplo,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) { - ssymv_( uplo, n, alpha, a, lda, x, incx, beta, y, incy); + ssymv_blis_impl( uplo, n, alpha, a, lda, x, incx, beta, y, incy); } void SSYR(const char *uplo,const f77_int *n,const float *alpha,const float *x,const f77_int *incx,float *a,const f77_int *lda) { - ssyr_( uplo, n, alpha, x, incx, a, lda); + ssyr_blis_impl( uplo, n, alpha, x, incx, a, lda); } void ssyr(const char *uplo,const f77_int *n,const float *alpha,const float *x,const f77_int *incx,float *a,const f77_int *lda) { - ssyr_( uplo, n, alpha, x, incx, a, lda); + ssyr_blis_impl( uplo, n, alpha, x, incx, a, lda); } void SSYR_(const char *uplo,const f77_int *n,const float *alpha,const float *x,const f77_int *incx,float *a,const f77_int *lda) { - ssyr_( uplo, n, alpha, x, incx, a, lda); + ssyr_blis_impl( uplo, n, alpha, x, incx, a, lda); } void SSYR2(const char *uplo,const f77_int *n,const float *alpha,const float *x,const f77_int *incx,const float *y,const f77_int *incy,float *a,const 
f77_int *lda) { - ssyr2_( uplo, n, alpha, x, incx, y, incy, a, lda); + ssyr2_blis_impl( uplo, n, alpha, x, incx, y, incy, a, lda); } void ssyr2(const char *uplo,const f77_int *n,const float *alpha,const float *x,const f77_int *incx,const float *y,const f77_int *incy,float *a,const f77_int *lda) { - ssyr2_( uplo, n, alpha, x, incx, y, incy, a, lda); + ssyr2_blis_impl( uplo, n, alpha, x, incx, y, incy, a, lda); } void SSYR2_(const char *uplo,const f77_int *n,const float *alpha,const float *x,const f77_int *incx,const float *y,const f77_int *incy,float *a,const f77_int *lda) { - ssyr2_( uplo, n, alpha, x, incx, y, incy, a, lda); + ssyr2_blis_impl( uplo, n, alpha, x, incx, y, incy, a, lda); } void SSYR2K(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *b,const f77_int *ldb,const float *beta,float *c,const f77_int *ldc) @@ -1719,62 +1719,62 @@ void SSYRK_(const char *uplo,const char *trans,const f77_int *n,const f77_in void STBMV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const float *a,const f77_int *lda,float *x,const f77_int *incx) { - stbmv_( uplo, trans, diag, n, k, a, lda, x, incx); + stbmv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); } void stbmv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const float *a,const f77_int *lda,float *x,const f77_int *incx) { - stbmv_( uplo, trans, diag, n, k, a, lda, x, incx); + stbmv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); } void STBMV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const float *a,const f77_int *lda,float *x,const f77_int *incx) { - stbmv_( uplo, trans, diag, n, k, a, lda, x, incx); + stbmv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); } void STBSV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const float *a,const f77_int *lda,float *x,const f77_int *incx) { - stbsv_( uplo, trans, diag, n, k, a, lda, x, incx); + stbsv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); } void stbsv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const float *a,const f77_int *lda,float *x,const f77_int *incx) { - stbsv_( uplo, trans, diag, n, k, a, lda, x, incx); + stbsv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); } void STBSV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const float *a,const f77_int *lda,float *x,const f77_int *incx) { - stbsv_( uplo, trans, diag, n, k, a, lda, x, incx); + stbsv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); } void STPMV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const float *ap,float *x,const f77_int *incx) { - stpmv_( uplo, trans, diag, n, ap, x, incx); + stpmv_blis_impl( uplo, trans, diag, n, ap, x, incx); } void stpmv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const float *ap,float *x,const f77_int *incx) { - stpmv_( uplo, trans, diag, n, ap, x, incx); + stpmv_blis_impl( uplo, trans, diag, n, ap, x, incx); } void STPMV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const float *ap,float *x,const f77_int *incx) { - stpmv_( uplo, trans, diag, n, ap, x, incx); + stpmv_blis_impl( uplo, trans, diag, n, ap, x, incx); } void STPSV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const float *ap,float *x,const f77_int *incx) { - stpsv_( uplo, trans, diag, n, ap, 
x, incx); + stpsv_blis_impl( uplo, trans, diag, n, ap, x, incx); } void stpsv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const float *ap,float *x,const f77_int *incx) { - stpsv_( uplo, trans, diag, n, ap, x, incx); + stpsv_blis_impl( uplo, trans, diag, n, ap, x, incx); } void STPSV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const float *ap,float *x,const f77_int *incx) { - stpsv_( uplo, trans, diag, n, ap, x, incx); + stpsv_blis_impl( uplo, trans, diag, n, ap, x, incx); } void STRMM(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,float *b,const f77_int *ldb) @@ -1794,17 +1794,17 @@ void STRMM_(const char *side,const char *uplo,const char *transa,const cha void STRMV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const float *a,const f77_int *lda,float *x,const f77_int *incx) { - strmv_( uplo, trans, diag, n, a, lda, x, incx); + strmv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); } void strmv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const float *a,const f77_int *lda,float *x,const f77_int *incx) { - strmv_( uplo, trans, diag, n, a, lda, x, incx); + strmv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); } void STRMV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const float *a,const f77_int *lda,float *x,const f77_int *incx) { - strmv_( uplo, trans, diag, n, a, lda, x, incx); + strmv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); } void STRSM(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,float *b,const f77_int *ldb) @@ -1824,17 +1824,17 @@ void STRSM_(const char *side,const char *uplo,const char *transa,const cha void STRSV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const float *a,const f77_int *lda,float *x,const f77_int *incx) { - strsv_( uplo, trans, diag, n, a, lda, x, incx); + strsv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); } void strsv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const float *a,const f77_int *lda,float *x,const f77_int *incx) { - strsv_( uplo, trans, diag, n, a, lda, x, incx); + strsv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); } void STRSV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const float *a,const f77_int *lda,float *x,const f77_int *incx) { - strsv_( uplo, trans, diag, n, a, lda, x, incx); + strsv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); } int XERBLA(const char *srname,const f77_int *info, ftnlen n) @@ -1914,17 +1914,17 @@ void ZDSCAL_(const f77_int *n,const double *da,dcomplex *zx,const f77_int *incx) void ZGBMV(const char *trans,const f77_int *m,const f77_int *n,const f77_int *kl,const f77_int *ku,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) { - zgbmv_( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); + zgbmv_blis_impl( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); } void zgbmv(const char *trans,const f77_int *m,const f77_int *n,const f77_int *kl,const f77_int *ku,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) { - zgbmv_( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, 
incy); + zgbmv_blis_impl( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); } void ZGBMV_(const char *trans,const f77_int *m,const f77_int *n,const f77_int *kl,const f77_int *ku,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) { - zgbmv_( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); + zgbmv_blis_impl( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); } void ZGEMM(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) @@ -1944,62 +1944,62 @@ void ZGEMM_(const char *transa,const char *transb,const f77_int *m,const f77 void ZGEMV(const char *trans,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) { - zgemv_( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); + zgemv_blis_impl( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); } void zgemv(const char *trans,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) { - zgemv_( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); + zgemv_blis_impl( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); } void ZGEMV_(const char *trans,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) { - zgemv_( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); + zgemv_blis_impl( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); } void ZGERC(const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *x,const f77_int *incx,const dcomplex *y,const f77_int *incy,dcomplex *a,const f77_int *lda) { - zgerc_( m, n, alpha, x, incx, y, incy, a, lda); + zgerc_blis_impl( m, n, alpha, x, incx, y, incy, a, lda); } void zgerc(const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *x,const f77_int *incx,const dcomplex *y,const f77_int *incy,dcomplex *a,const f77_int *lda) { - zgerc_( m, n, alpha, x, incx, y, incy, a, lda); + zgerc_blis_impl( m, n, alpha, x, incx, y, incy, a, lda); } void ZGERC_(const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *x,const f77_int *incx,const dcomplex *y,const f77_int *incy,dcomplex *a,const f77_int *lda) { - zgerc_( m, n, alpha, x, incx, y, incy, a, lda); + zgerc_blis_impl( m, n, alpha, x, incx, y, incy, a, lda); } void ZGERU(const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *x,const f77_int *incx,const dcomplex *y,const f77_int *incy,dcomplex *a,const f77_int *lda) { - zgeru_( m, n, alpha, x, incx, y, incy, a, lda); + zgeru_blis_impl( m, n, alpha, x, incx, y, incy, a, lda); } void zgeru(const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *x,const f77_int *incx,const dcomplex *y,const f77_int *incy,dcomplex *a,const f77_int *lda) { - zgeru_( m, n, alpha, x, incx, y, incy, a, lda); + zgeru_blis_impl( m, n, alpha, x, incx, y, incy, a, lda); } void ZGERU_(const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *x,const f77_int *incx,const dcomplex *y,const f77_int *incy,dcomplex *a,const f77_int *lda) { - zgeru_( m, n, alpha, x, 
incx, y, incy, a, lda); + zgeru_blis_impl( m, n, alpha, x, incx, y, incy, a, lda); } void ZHBMV(const char *uplo,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) { - zhbmv_( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); + zhbmv_blis_impl( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); } void zhbmv(const char *uplo,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) { - zhbmv_( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); + zhbmv_blis_impl( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); } void ZHBMV_(const char *uplo,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) { - zhbmv_( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); + zhbmv_blis_impl( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); } void ZHEMM(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) @@ -2019,47 +2019,47 @@ void ZHEMM_(const char *side,const char *uplo,const f77_int *m,const f77_int void ZHEMV(const char *uplo,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) { - zhemv_( uplo, n, alpha, a, lda, x, incx, beta, y, incy); + zhemv_blis_impl( uplo, n, alpha, a, lda, x, incx, beta, y, incy); } void zhemv(const char *uplo,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) { - zhemv_( uplo, n, alpha, a, lda, x, incx, beta, y, incy); + zhemv_blis_impl( uplo, n, alpha, a, lda, x, incx, beta, y, incy); } void ZHEMV_(const char *uplo,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) { - zhemv_( uplo, n, alpha, a, lda, x, incx, beta, y, incy); + zhemv_blis_impl( uplo, n, alpha, a, lda, x, incx, beta, y, incy); } void ZHER(const char *uplo,const f77_int *n,const double *alpha,const dcomplex *x,const f77_int *incx,dcomplex *a,const f77_int *lda) { - zher_( uplo, n, alpha, x, incx, a, lda); + zher_blis_impl( uplo, n, alpha, x, incx, a, lda); } void zher(const char *uplo,const f77_int *n,const double *alpha,const dcomplex *x,const f77_int *incx,dcomplex *a,const f77_int *lda) { - zher_( uplo, n, alpha, x, incx, a, lda); + zher_blis_impl( uplo, n, alpha, x, incx, a, lda); } void ZHER_(const char *uplo,const f77_int *n,const double *alpha,const dcomplex *x,const f77_int *incx,dcomplex *a,const f77_int *lda) { - zher_( uplo, n, alpha, x, incx, a, lda); + zher_blis_impl( uplo, n, alpha, x, incx, a, lda); } void ZHER2(const char *uplo,const f77_int *n,const dcomplex *alpha,const dcomplex *x,const f77_int *incx,const dcomplex *y,const f77_int *incy,dcomplex *a,const f77_int *lda) { - zher2_( uplo, n, alpha, x, incx, y, incy, a, lda); + zher2_blis_impl( uplo, n, alpha, x, incx, y, incy, a, lda); } void zher2(const char *uplo,const f77_int *n,const dcomplex *alpha,const dcomplex *x,const f77_int *incx,const 
dcomplex *y,const f77_int *incy,dcomplex *a,const f77_int *lda) { - zher2_( uplo, n, alpha, x, incx, y, incy, a, lda); + zher2_blis_impl( uplo, n, alpha, x, incx, y, incy, a, lda); } void ZHER2_(const char *uplo,const f77_int *n,const dcomplex *alpha,const dcomplex *x,const f77_int *incx,const dcomplex *y,const f77_int *incy,dcomplex *a,const f77_int *lda) { - zher2_( uplo, n, alpha, x, incx, y, incy, a, lda); + zher2_blis_impl( uplo, n, alpha, x, incx, y, incy, a, lda); } void ZHER2K(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const double *beta,dcomplex *c,const f77_int *ldc) @@ -2094,47 +2094,47 @@ void ZHERK_(const char *uplo,const char *trans,const f77_int *n,const f77_in void ZHPMV(const char *uplo,const f77_int *n,const dcomplex *alpha,const dcomplex *ap,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) { - zhpmv_( uplo, n, alpha, ap, x, incx, beta, y, incy); + zhpmv_blis_impl( uplo, n, alpha, ap, x, incx, beta, y, incy); } void zhpmv(const char *uplo,const f77_int *n,const dcomplex *alpha,const dcomplex *ap,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) { - zhpmv_( uplo, n, alpha, ap, x, incx, beta, y, incy); + zhpmv_blis_impl( uplo, n, alpha, ap, x, incx, beta, y, incy); } void ZHPMV_(const char *uplo,const f77_int *n,const dcomplex *alpha,const dcomplex *ap,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) { - zhpmv_( uplo, n, alpha, ap, x, incx, beta, y, incy); + zhpmv_blis_impl( uplo, n, alpha, ap, x, incx, beta, y, incy); } void ZHPR(const char *uplo,const f77_int *n,const bla_double *alpha,const dcomplex *x,const f77_int *incx,dcomplex *ap) { - zhpr_( uplo, n, alpha, x, incx, ap); + zhpr_blis_impl( uplo, n, alpha, x, incx, ap); } void zhpr(const char *uplo,const f77_int *n,const bla_double *alpha,const dcomplex *x,const f77_int *incx,dcomplex *ap) { - zhpr_( uplo, n, alpha, x, incx, ap); + zhpr_blis_impl( uplo, n, alpha, x, incx, ap); } void ZHPR_(const char *uplo,const f77_int *n,const bla_double *alpha,const dcomplex *x,const f77_int *incx,dcomplex *ap) { - zhpr_( uplo, n, alpha, x, incx, ap); + zhpr_blis_impl( uplo, n, alpha, x, incx, ap); } void ZHPR2(const char *uplo,const f77_int *n,const dcomplex *alpha,const dcomplex *x,const f77_int *incx,const dcomplex *y,const f77_int *incy,dcomplex *ap) { - zhpr2_( uplo, n, alpha, x, incx, y, incy, ap); + zhpr2_blis_impl( uplo, n, alpha, x, incx, y, incy, ap); } void zhpr2(const char *uplo,const f77_int *n,const dcomplex *alpha,const dcomplex *x,const f77_int *incx,const dcomplex *y,const f77_int *incy,dcomplex *ap) { - zhpr2_( uplo, n, alpha, x, incx, y, incy, ap); + zhpr2_blis_impl( uplo, n, alpha, x, incx, y, incy, ap); } void ZHPR2_(const char *uplo,const f77_int *n,const dcomplex *alpha,const dcomplex *x,const f77_int *incx,const dcomplex *y,const f77_int *incy,dcomplex *ap) { - zhpr2_( uplo, n, alpha, x, incx, y, incy, ap); + zhpr2_blis_impl( uplo, n, alpha, x, incx, y, incy, ap); } void ZROTG(dcomplex *ca,bla_dcomplex *cb,bla_double *c,dcomplex *s) @@ -2229,62 +2229,62 @@ void ZSYRK_(const char *uplo,const char *trans,const f77_int *n,const f77_in void ZTBMV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const dcomplex *a,const f77_int *lda,dcomplex *x,const f77_int *incx) { - ztbmv_( uplo, trans, diag, n, k, a, lda, x, incx); + 
ztbmv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); } void ztbmv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const dcomplex *a,const f77_int *lda,dcomplex *x,const f77_int *incx) { - ztbmv_( uplo, trans, diag, n, k, a, lda, x, incx); + ztbmv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); } void ZTBMV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const dcomplex *a,const f77_int *lda,dcomplex *x,const f77_int *incx) { - ztbmv_( uplo, trans, diag, n, k, a, lda, x, incx); + ztbmv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); } void ZTBSV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const dcomplex *a,const f77_int *lda,dcomplex *x,const f77_int *incx) { - ztbsv_( uplo, trans, diag, n, k, a, lda, x, incx); + ztbsv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); } void ztbsv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const dcomplex *a,const f77_int *lda,dcomplex *x,const f77_int *incx) { - ztbsv_( uplo, trans, diag, n, k, a, lda, x, incx); + ztbsv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); } void ZTBSV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const dcomplex *a,const f77_int *lda,dcomplex *x,const f77_int *incx) { - ztbsv_( uplo, trans, diag, n, k, a, lda, x, incx); + ztbsv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); } void ZTPMV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const dcomplex *ap,dcomplex *x,const f77_int *incx) { - ztpmv_( uplo, trans, diag, n, ap, x, incx); + ztpmv_blis_impl( uplo, trans, diag, n, ap, x, incx); } void ztpmv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const dcomplex *ap,dcomplex *x,const f77_int *incx) { - ztpmv_( uplo, trans, diag, n, ap, x, incx); + ztpmv_blis_impl( uplo, trans, diag, n, ap, x, incx); } void ZTPMV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const dcomplex *ap,dcomplex *x,const f77_int *incx) { - ztpmv_( uplo, trans, diag, n, ap, x, incx); + ztpmv_blis_impl( uplo, trans, diag, n, ap, x, incx); } void ZTPSV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const dcomplex *ap,dcomplex *x,const f77_int *incx) { - ztpsv_( uplo, trans, diag, n, ap, x, incx); + ztpsv_blis_impl( uplo, trans, diag, n, ap, x, incx); } void ztpsv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const dcomplex *ap,dcomplex *x,const f77_int *incx) { - ztpsv_( uplo, trans, diag, n, ap, x, incx); + ztpsv_blis_impl( uplo, trans, diag, n, ap, x, incx); } void ZTPSV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const dcomplex *ap,dcomplex *x,const f77_int *incx) { - ztpsv_( uplo, trans, diag, n, ap, x, incx); + ztpsv_blis_impl( uplo, trans, diag, n, ap, x, incx); } void ZTRMM(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,dcomplex *b,const f77_int *ldb) @@ -2304,17 +2304,17 @@ void ZTRMM_(const char *side,const char *uplo,const char *transa,const cha void ZTRMV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const dcomplex *a,const f77_int *lda,dcomplex *x,const f77_int *incx) { - ztrmv_( uplo, trans, diag, n, a, lda, x, incx); + ztrmv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); } void ztrmv(const char *uplo,const char *trans,const char *diag,const 
f77_int *n,const dcomplex *a,const f77_int *lda,dcomplex *x,const f77_int *incx) { - ztrmv_( uplo, trans, diag, n, a, lda, x, incx); + ztrmv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); } void ZTRMV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const dcomplex *a,const f77_int *lda,dcomplex *x,const f77_int *incx) { - ztrmv_( uplo, trans, diag, n, a, lda, x, incx); + ztrmv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); } void ZTRSM(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,dcomplex *b,const f77_int *ldb) @@ -2334,17 +2334,17 @@ void ZTRSM_(const char *side,const char *uplo,const char *transa,const cha void ZTRSV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const dcomplex *a,const f77_int *lda,dcomplex *x,const f77_int *incx) { - ztrsv_( uplo, trans, diag, n, a, lda, x, incx); + ztrsv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); } void ztrsv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const dcomplex *a,const f77_int *lda,dcomplex *x,const f77_int *incx) { - ztrsv_( uplo, trans, diag, n, a, lda, x, incx); + ztrsv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); } void ZTRSV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const dcomplex *a,const f77_int *lda,dcomplex *x,const f77_int *incx) { - ztrsv_( uplo, trans, diag, n, a, lda, x, incx); + ztrsv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); } From 326d8a557f88c08725d35487460a0140039707a8 Mon Sep 17 00:00:00 2001 From: Harihara Sudhan S Date: Fri, 26 Aug 2022 11:35:35 +0530 Subject: [PATCH 196/243] Performance regression in u8s8s16os16 - Performance of u8s8s16os16 came down by 40% after the introduction of post-ops - Analysis revealed that the target compiler assumed false dependency and was generating sub-optimal code due to the post-ops structure - Inserted vzeroupper to hint the compiler that no ISA change will occur AMD-Internal: [CPUPL-2447] Change-Id: I0b383b9742ad237d0e053394602428872691ef0c --- .../u8s8s16/lpgemm_6x32rowmajor_amd256.c | 297 ++++++++++-------- 1 file changed, 171 insertions(+), 126 deletions(-) diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_6x32rowmajor_amd256.c b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_6x32rowmajor_amd256.c index d90f195550..ad0d6e2a66 100644 --- a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_6x32rowmajor_amd256.c +++ b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_6x32rowmajor_amd256.c @@ -104,17 +104,11 @@ LPGEMM_MAIN_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x32) return; } - // B matrix storage. - __m256i b0, b1; - - // A matrix storage. - __m256i a_int32_0, a_int32_1; - - // Intermediate vectors - __m256i inter_vec[4]; - for (dim_t ir = 0; ir < m_full_pieces_loop_limit; ir += MR) { + + _mm256_zeroupper(); + // Registers to use for accumulating C. __m256i c_int16_0p0 = _mm256_setzero_si256(); __m256i c_int16_0p1 = _mm256_setzero_si256(); @@ -139,246 +133,293 @@ LPGEMM_MAIN_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x32) dim_t offset = kr * 2; // Broadcast a[0,kr:kr+2]. - a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 0) + (cs_a * offset))); - - b0 = _mm256_loadu_si256((__m256i const *)(b + (64 * kr) + (NR * 0))); - b1 = _mm256_loadu_si256((__m256i const *)(b + (64 * kr) + (NR * 1))); + __m256i a_int32_0 = + _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 0) + + (cs_a * offset))); - // Broadcast a[1,kr:kr+2]. 
- a_int32_1 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 1) + (cs_a * offset))); + __m256i b0 = + _mm256_loadu_si256((__m256i const *)(b + (64 * kr) + (NR * 0))); + __m256i b1 = + _mm256_loadu_si256((__m256i const *)(b + (64 * kr) + (NR * 1))); // Seperate register for intermediate op - inter_vec[0] = _mm256_maddubs_epi16(a_int32_0, b0); - inter_vec[1] = _mm256_maddubs_epi16(a_int32_0, b1); + __m256i inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] - c_int16_0p0 = _mm256_add_epi16(inter_vec[0], c_int16_0p0); - c_int16_0p1 = _mm256_add_epi16(inter_vec[1], c_int16_0p1); + c_int16_0p0 = _mm256_add_epi16(inter_vec, c_int16_0p0); - // Broadcast a[2,kr:kr+2]. - a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 2) + (cs_a * offset))); + inter_vec = _mm256_maddubs_epi16(a_int32_0, b1); + c_int16_0p1 = _mm256_add_epi16(inter_vec, c_int16_0p1); + + // Broadcast a[1,kr:kr+2]. + a_int32_0 = + _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 1) + (cs_a * offset))); // Seperate register for intermediate op - inter_vec[2] = _mm256_maddubs_epi16(a_int32_1, b0); - inter_vec[3] = _mm256_maddubs_epi16(a_int32_1, b1); + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. // c[1,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] - c_int16_1p0 = _mm256_add_epi16(inter_vec[2], c_int16_1p0); - c_int16_1p1 = _mm256_add_epi16(inter_vec[3], c_int16_1p1); + c_int16_1p0 = _mm256_add_epi16(inter_vec, c_int16_1p0); - // Broadcast a[3,kr:kr+2]. - a_int32_1 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 3) + (cs_a * offset))); + inter_vec = _mm256_maddubs_epi16(a_int32_0, b1); + c_int16_1p1 = _mm256_add_epi16(inter_vec, c_int16_1p1); - // Seperate register for intermediate op - inter_vec[0] = _mm256_maddubs_epi16(a_int32_0, b0); - inter_vec[1] = _mm256_maddubs_epi16(a_int32_0, b1); + // Broadcast a[2,kr:kr+2]. + a_int32_0 = + _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 2) + (cs_a * offset))); + // Seperate register for intermediate op + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] - c_int16_2p0 = _mm256_add_epi16(inter_vec[0], c_int16_2p0); - c_int16_2p1 = _mm256_add_epi16(inter_vec[1], c_int16_2p1); + c_int16_2p0 = _mm256_add_epi16(inter_vec, c_int16_2p0); - // Broadcast a[4,kr:kr+2]. - a_int32_0 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 4) + (cs_a * offset))); + inter_vec = _mm256_maddubs_epi16(a_int32_0, b1); + c_int16_2p1 = _mm256_add_epi16(inter_vec, c_int16_2p1); + + // Broadcast a[3,kr:kr+2]. + a_int32_0 = + _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 3) + (cs_a * offset))); // Seperate register for intermediate op - inter_vec[2] = _mm256_maddubs_epi16(a_int32_1, b0); - inter_vec[3] = _mm256_maddubs_epi16(a_int32_1, b1); + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] - c_int16_3p0 = _mm256_add_epi16(inter_vec[2], c_int16_3p0); - c_int16_3p1 = _mm256_add_epi16(inter_vec[3], c_int16_3p1); + c_int16_3p0 = _mm256_add_epi16(inter_vec, c_int16_3p0); - // Broadcast a[5,kr:kr+2]. - a_int32_1 = _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 5) + (cs_a * offset))); + inter_vec = _mm256_maddubs_epi16(a_int32_0, b1); + c_int16_3p1 = _mm256_add_epi16(inter_vec, c_int16_3p1); + + // Broadcast a[4,kr:kr+2]. 
+ a_int32_0 = + _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 4) + (cs_a * offset))); // Seperate register for intermediate op - inter_vec[0] = _mm256_maddubs_epi16(a_int32_0, b0); - inter_vec[1] = _mm256_maddubs_epi16(a_int32_0, b1); + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+4,0-31] - c_int16_4p0 = _mm256_add_epi16(inter_vec[0], c_int16_4p0); - c_int16_4p1 = _mm256_add_epi16(inter_vec[1], c_int16_4p1); + c_int16_4p0 = _mm256_add_epi16(inter_vec, c_int16_4p0); + + inter_vec = _mm256_maddubs_epi16(a_int32_0, b1); + + c_int16_4p1 = _mm256_add_epi16(inter_vec, c_int16_4p1); + + // Broadcast a[5,kr:kr+2]. + a_int32_0 = + _mm256_set1_epi16(*(uint16_t *)(a + (rs_a * 5) + (cs_a * offset))); // Seperate register for intermediate op - inter_vec[2] = _mm256_maddubs_epi16(a_int32_1, b0); - inter_vec[3] = _mm256_maddubs_epi16(a_int32_1, b1); + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+4,0-31] - c_int16_5p0 = _mm256_add_epi16(inter_vec[2], c_int16_5p0); - c_int16_5p1 = _mm256_add_epi16(inter_vec[3], c_int16_5p1); + c_int16_5p0 = _mm256_add_epi16(inter_vec, c_int16_5p0); + + inter_vec = _mm256_maddubs_epi16(a_int32_0, b1); + c_int16_5p1 = _mm256_add_epi16(inter_vec, c_int16_5p1); } // Handle k remainder. if (k_partial_pieces > 0) { - uint8_t a_kfringe; - b0 = _mm256_loadu_si256((__m256i const *)(b + (64 * k_full_pieces) + (NR * 0))); - b1 = _mm256_loadu_si256((__m256i const *)(b + (64 * k_full_pieces) + (NR * 1))); + __m256i b0 = _mm256_loadu_si256((__m256i const *) + (b + (64 * k_full_pieces) + (NR * 0))); + __m256i b1 = _mm256_loadu_si256((__m256i const *) + (b + (64 * k_full_pieces) + (NR * 1))); - a_kfringe = *(a + (rs_a * 0) + (cs_a * (k_full_pieces * 2))); - a_int32_0 = _mm256_set1_epi8(a_kfringe); + uint8_t a_kfringe = *(a + (rs_a * 0) + (cs_a * (k_full_pieces * 2))); + __m256i a_int32_0 = _mm256_set1_epi8(a_kfringe); // Seperate register for intermediate op - inter_vec[0] = _mm256_maddubs_epi16(a_int32_0, b0); - inter_vec[1] = _mm256_maddubs_epi16(a_int32_0, b1); + __m256i inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] - c_int16_0p0 = _mm256_add_epi16(inter_vec[0], c_int16_0p0); - c_int16_0p1 = _mm256_add_epi16(inter_vec[1], c_int16_0p1); + c_int16_0p0 = _mm256_add_epi16(inter_vec, c_int16_0p0); + + inter_vec = _mm256_maddubs_epi16(a_int32_0, b1); + c_int16_0p1 = _mm256_add_epi16(inter_vec, c_int16_0p1); a_kfringe = *(a + (rs_a * 1) + (cs_a * (k_full_pieces * 2))); - a_int32_1 = _mm256_set1_epi8(a_kfringe); + a_int32_0 = _mm256_set1_epi8(a_kfringe); // Seperate register for intermediate op - inter_vec[2] = _mm256_maddubs_epi16(a_int32_1, b0); - inter_vec[3] = _mm256_maddubs_epi16(a_int32_1, b1); + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. 
// c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+4,0-31] - c_int16_1p0 = _mm256_add_epi16(inter_vec[2], c_int16_1p0); - c_int16_1p1 = _mm256_add_epi16(inter_vec[3], c_int16_1p1); + c_int16_1p0 = _mm256_add_epi16(inter_vec, c_int16_1p0); + + inter_vec = _mm256_maddubs_epi16(a_int32_0, b1); + c_int16_1p1 = _mm256_add_epi16(inter_vec, c_int16_1p1); a_kfringe = *(a + (rs_a * 2) + (cs_a * (k_full_pieces * 2))); a_int32_0 = _mm256_set1_epi8(a_kfringe); // Seperate register for intermediate op - inter_vec[0] = _mm256_maddubs_epi16(a_int32_0, b0); - inter_vec[1] = _mm256_maddubs_epi16(a_int32_0, b1); + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] - c_int16_2p0 = _mm256_add_epi16(inter_vec[0], c_int16_2p0); - c_int16_2p1 = _mm256_add_epi16(inter_vec[1], c_int16_2p1); + c_int16_2p0 = _mm256_add_epi16(inter_vec, c_int16_2p0); + + inter_vec = _mm256_maddubs_epi16(a_int32_0, b1); + + c_int16_2p1 = _mm256_add_epi16(inter_vec, c_int16_2p1); a_kfringe = *(a + (rs_a * 3) + (cs_a * (k_full_pieces * 2))); - a_int32_1 = _mm256_set1_epi8(a_kfringe); + a_int32_0 = _mm256_set1_epi8(a_kfringe); // Seperate register for intermediate op - inter_vec[2] = _mm256_maddubs_epi16(a_int32_1, b0); - inter_vec[3] = _mm256_maddubs_epi16(a_int32_1, b1); + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] - c_int16_3p0 = _mm256_add_epi16(inter_vec[2], c_int16_3p0); - c_int16_3p1 = _mm256_add_epi16(inter_vec[3], c_int16_3p1); + c_int16_3p0 = _mm256_add_epi16(inter_vec, c_int16_3p0); + + inter_vec = _mm256_maddubs_epi16(a_int32_0, b1); + c_int16_3p1 = _mm256_add_epi16(inter_vec, c_int16_3p1); a_kfringe = *(a + (rs_a * 4) + (cs_a * (k_full_pieces * 2))); a_int32_0 = _mm256_set1_epi8(a_kfringe); // Seperate register for intermediate op - inter_vec[0] = _mm256_maddubs_epi16(a_int32_0, b0); - inter_vec[1] = _mm256_maddubs_epi16(a_int32_0, b1); + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] - c_int16_4p0 = _mm256_add_epi16(inter_vec[0], c_int16_4p0); - c_int16_4p1 = _mm256_add_epi16(inter_vec[1], c_int16_4p1); + c_int16_4p0 = _mm256_add_epi16(inter_vec, c_int16_4p0); + + inter_vec = _mm256_maddubs_epi16(a_int32_0, b1); + c_int16_4p1 = _mm256_add_epi16(inter_vec, c_int16_4p1); a_kfringe = *(a + (rs_a * 5) + (cs_a * (k_full_pieces * 2))); - a_int32_1 = _mm256_set1_epi8(a_kfringe); + a_int32_0 = _mm256_set1_epi8(a_kfringe); // Seperate register for intermediate op - inter_vec[2] = _mm256_maddubs_epi16(a_int32_1, b0); - inter_vec[3] = _mm256_maddubs_epi16(a_int32_1, b1); + inter_vec = _mm256_maddubs_epi16(a_int32_0, b0); // Perform column direction mat-mul with k = 2. 
// c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] - c_int16_5p0 = _mm256_add_epi16(inter_vec[2], c_int16_5p0); - c_int16_5p1 = _mm256_add_epi16(inter_vec[3], c_int16_5p1); + c_int16_5p0 = _mm256_add_epi16(inter_vec, c_int16_5p0); + + inter_vec = _mm256_maddubs_epi16(a_int32_0, b1); + c_int16_5p1 = _mm256_add_epi16(inter_vec, c_int16_5p1); } // Load alpha and beta - __m256i selector1 = _mm256_set1_epi16(alpha); - __m256i selector2 = _mm256_set1_epi16(beta); + __m256i alphav = _mm256_set1_epi16(alpha); + __m256i betav = _mm256_set1_epi16(beta); // Scale by alpha - c_int16_0p0 = _mm256_mullo_epi16(selector1, c_int16_0p0); - c_int16_0p1 = _mm256_mullo_epi16(selector1, c_int16_0p1); + c_int16_0p0 = _mm256_mullo_epi16(alphav, c_int16_0p0); + c_int16_0p1 = _mm256_mullo_epi16(alphav, c_int16_0p1); - c_int16_1p0 = _mm256_mullo_epi16(selector1, c_int16_1p0); - c_int16_1p1 = _mm256_mullo_epi16(selector1, c_int16_1p1); + c_int16_1p0 = _mm256_mullo_epi16(alphav, c_int16_1p0); + c_int16_1p1 = _mm256_mullo_epi16(alphav, c_int16_1p1); - c_int16_2p0 = _mm256_mullo_epi16(selector1, c_int16_2p0); - c_int16_2p1 = _mm256_mullo_epi16(selector1, c_int16_2p1); + c_int16_2p0 = _mm256_mullo_epi16(alphav, c_int16_2p0); + c_int16_2p1 = _mm256_mullo_epi16(alphav, c_int16_2p1); - c_int16_3p0 = _mm256_mullo_epi16(selector1, c_int16_3p0); - c_int16_3p1 = _mm256_mullo_epi16(selector1, c_int16_3p1); + c_int16_3p0 = _mm256_mullo_epi16(alphav, c_int16_3p0); + c_int16_3p1 = _mm256_mullo_epi16(alphav, c_int16_3p1); - c_int16_4p0 = _mm256_mullo_epi16(selector1, c_int16_4p0); - c_int16_4p1 = _mm256_mullo_epi16(selector1, c_int16_4p1); + c_int16_4p0 = _mm256_mullo_epi16(alphav, c_int16_4p0); + c_int16_4p1 = _mm256_mullo_epi16(alphav, c_int16_4p1); - c_int16_5p0 = _mm256_mullo_epi16(selector1, c_int16_5p0); - c_int16_5p1 = _mm256_mullo_epi16(selector1, c_int16_5p1); + c_int16_5p0 = _mm256_mullo_epi16(alphav, c_int16_5p0); + c_int16_5p1 = _mm256_mullo_epi16(alphav, c_int16_5p1); // Scale C by beta. 
if (beta != 0) { // c[0,0-15] - selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * (ir + 0)) + (0 * 16))); - selector1 = _mm256_mullo_epi16(selector2, selector1); + __m256i selector1 = + _mm256_loadu_si256((__m256i const *) + (c + (rs_c * (ir + 0)) + (0 * 16))); + selector1 = _mm256_mullo_epi16(betav, selector1); c_int16_0p0 = _mm256_add_epi16(selector1, c_int16_0p0); // c[0, 16-31] - selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * (ir + 0)) + (1 * 16))); - selector1 = _mm256_mullo_epi16(selector2, selector1); + selector1 = + _mm256_loadu_si256((__m256i const *) + (c + (rs_c * (ir + 0)) + (1 * 16))); + selector1 = _mm256_mullo_epi16(betav, selector1); c_int16_0p1 = _mm256_add_epi16(selector1, c_int16_0p1); // c[1,0-15] - selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * (ir + 1)) + (0 * 16))); - selector1 = _mm256_mullo_epi16(selector2, selector1); + selector1 = + _mm256_loadu_si256((__m256i const *) + (c + (rs_c * (ir + 1)) + (0 * 16))); + selector1 = _mm256_mullo_epi16(betav, selector1); c_int16_1p0 = _mm256_add_epi16(selector1, c_int16_1p0); // c[1,16-31] - selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * (ir + 1)) + (1 * 16))); - selector1 = _mm256_mullo_epi16(selector2, selector1); + selector1 = + _mm256_loadu_si256((__m256i const *) + (c + (rs_c * (ir + 1)) + (1 * 16))); + selector1 = _mm256_mullo_epi16(betav, selector1); c_int16_1p1 = _mm256_add_epi16(selector1, c_int16_1p1); // c[2,0-15] - selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * (ir + 2)) + (0 * 16))); - selector1 = _mm256_mullo_epi16(selector2, selector1); + selector1 = + _mm256_loadu_si256((__m256i const *) + (c + (rs_c * (ir + 2)) + (0 * 16))); + selector1 = _mm256_mullo_epi16(betav, selector1); c_int16_2p0 = _mm256_add_epi16(selector1, c_int16_2p0); // c[2,16-31] - selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * (ir + 2)) + (1 * 16))); - selector1 = _mm256_mullo_epi16(selector2, selector1); + selector1 = + _mm256_loadu_si256((__m256i const *) + (c + (rs_c * (ir + 2)) + (1 * 16))); + selector1 = _mm256_mullo_epi16(betav, selector1); c_int16_2p1 = _mm256_add_epi16(selector1, c_int16_2p1); // c[3,0-15] - selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * (ir + 3)) + (0 * 16))); - selector1 = _mm256_mullo_epi16(selector2, selector1); + selector1 = + _mm256_loadu_si256((__m256i const *) + (c + (rs_c * (ir + 3)) + (0 * 16))); + selector1 = _mm256_mullo_epi16(betav, selector1); c_int16_3p0 = _mm256_add_epi16(selector1, c_int16_3p0); // c[3,16-31] - selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * (ir + 3)) + (1 * 16))); - selector1 = _mm256_mullo_epi16(selector2, selector1); + selector1 = + _mm256_loadu_si256((__m256i const *) + (c + (rs_c * (ir + 3)) + (1 * 16))); + selector1 = _mm256_mullo_epi16(betav, selector1); c_int16_3p1 = _mm256_add_epi16(selector1, c_int16_3p1); // c[4,0-15] - selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * (ir + 4)) + (0 * 16))); - selector1 = _mm256_mullo_epi16(selector2, selector1); + selector1 = + _mm256_loadu_si256((__m256i const *) + (c + (rs_c * (ir + 4)) + (0 * 16))); + selector1 = _mm256_mullo_epi16(betav, selector1); c_int16_4p0 = _mm256_add_epi16(selector1, c_int16_4p0); // c[4,16-31] - selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * (ir + 4)) + (1 * 16))); - selector1 = _mm256_mullo_epi16(selector2, selector1); + selector1 = + _mm256_loadu_si256((__m256i const *) + (c + (rs_c * (ir + 4)) + (1 * 16))); + selector1 = _mm256_mullo_epi16(betav, selector1); c_int16_4p1 = 
_mm256_add_epi16(selector1, c_int16_4p1); // c[5,0-15] - selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * (ir + 5)) + (0 * 16))); - selector1 = _mm256_mullo_epi16(selector2, selector1); + selector1 = + _mm256_loadu_si256((__m256i const *) + (c + (rs_c * (ir + 5)) + (0 * 16))); + selector1 = _mm256_mullo_epi16(betav, selector1); c_int16_5p0 = _mm256_add_epi16(selector1, c_int16_5p0); // c[5,16-31] - selector1 = _mm256_loadu_si256((__m256i const *)(c + (rs_c * (ir + 5)) + (1 * 16))); - selector1 = _mm256_mullo_epi16(selector2, selector1); + selector1 = + _mm256_loadu_si256((__m256i const *) + (c + (rs_c * (ir + 5)) + (1 * 16))); + selector1 = _mm256_mullo_epi16(betav, selector1); c_int16_5p1 = _mm256_add_epi16(selector1, c_int16_5p1); } @@ -387,15 +428,17 @@ LPGEMM_MAIN_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x32) POST_OP_LABEL_LASTK_SAFE_JUMP POST_OPS_BIAS_6x32: { - selector1 = - _mm256_loadu_si256( (__m256i const *)((int16_t *)post_ops_list_temp->op_args1 + + __m256i selector1 = + _mm256_loadu_si256( (__m256i const *)( + (int16_t *)post_ops_list_temp->op_args1 + post_op_c_j + ( 0 * 16 )) ); - selector2 = - _mm256_loadu_si256( (__m256i const *)((int16_t *)post_ops_list_temp->op_args1 + + __m256i selector2 = + _mm256_loadu_si256( (__m256i const *)( + (int16_t *)post_ops_list_temp->op_args1 + post_op_c_j + ( 1 * 16 )) ); - + // c[0,0-15] - c_int16_0p0 = _mm256_add_epi16( selector1, c_int16_0p0 ); + c_int16_0p0 = _mm256_add_epi16(selector1, c_int16_0p0); // c[0, 16-31] c_int16_0p1 = _mm256_add_epi16( selector2, c_int16_0p1 ); @@ -434,7 +477,7 @@ LPGEMM_MAIN_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x32) } POST_OPS_RELU_6x32: { - selector1 = _mm256_setzero_si256 (); + __m256i selector1 = _mm256_setzero_si256 (); // c[0,0-15] c_int16_0p0 = _mm256_max_epi16( selector1, c_int16_0p0 ); @@ -476,9 +519,11 @@ LPGEMM_MAIN_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x32) } POST_OPS_RELU_SCALE_6x32: { - selector2 = + __m256i selector2 = _mm256_set1_epi16( *( ( int16_t* )post_ops_list_temp->op_args2 ) ); + __m256i selector1, b0; + // c[0,0-15] RELU_SCALE_OP_S16_AVX2(c_int16_0p0) From 95169ca8066c991bc0a3166a5e7da9594c0c393c Mon Sep 17 00:00:00 2001 From: jagar Date: Fri, 26 Aug 2022 22:39:51 +0530 Subject: [PATCH 197/243] CBLAS/BLAS interface decoupling for level 1 APIs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - In BLIS the cblas interface is implemented as a wrapper around the blas interface. For example the CBLAS api ‘cblas_dgemm’ internally invokes BLAS API ‘dgemm_’. - If the end user wants to use the different libraries for CBLAS and BLAS, current implantation of BLIS doesn’t allow it and may result in recursion - This change separate the CBLAS and BLAS implantation by adding an additional level of abstraction. The implementation of the API is moved to the new function which is invoked directly from the CBLAS and BLAS wrappers. 
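As an illustration of the new structure (a minimal sketch, not the verbatim BLIS source), the pattern applied to daxpy looks as follows: the computation lives in daxpy_blis_impl(), and the Fortran-style BLAS symbol daxpy_() becomes a one-line pass-through. The local f77_int typedef and the reference loop body are stand-ins for this sketch only; the real daxpy_blis_impl() dispatches into the optimized BLIS paths, as in bla_axpy_amd.c below.

/* Sketch only: assumes positive increments and uses a local stand-in
   for the BLIS f77_int typedef. */
typedef long f77_int;

/* Core routine: holds the actual implementation (a plain reference
   loop here; BLIS dispatches to its optimized axpyv kernels). */
void daxpy_blis_impl
     (
       const f77_int* n,
       const double*  alpha,
       const double*  x, const f77_int* incx,
       double*        y, const f77_int* incy
     )
{
    for ( f77_int i = 0; i < *n; ++i )
    {
        y[ i * (*incy) ] += (*alpha) * x[ i * (*incx) ];
    }
}

/* BLAS entry point: now a thin wrapper over the core routine. */
void daxpy_
     (
       const f77_int* n,
       const double*  alpha,
       const double*  x, const f77_int* incx,
       double*        y, const f77_int* incy
     )
{
    daxpy_blis_impl( n, alpha, x, incx, y, incy );
}

With this split, the CBLAS layer can be pointed at daxpy_blis_impl() directly (via the cblas_f77.h mapping touched below), so building BLIS's CBLAS against a third-party BLAS, or the reverse, no longer routes one interface through the other's daxpy_ symbol and cannot recurse back into it.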
AMD-Internal: [SWLCSG-1477] Change-Id: I0f4521e70a02f6132bdadbd4c07715c9d52fe62a --- frame/compat/bla_amax.c | 11 +- frame/compat/bla_amax.h | 9 +- frame/compat/bla_amax_amd.c | 32 ++- frame/compat/bla_amin.c | 14 +- frame/compat/bla_amin.h | 8 +- frame/compat/bla_asum.c | 13 +- frame/compat/bla_asum.h | 9 +- frame/compat/bla_axpby.c | 17 +- frame/compat/bla_axpby.h | 14 +- frame/compat/bla_axpy.c | 15 +- frame/compat/bla_axpy.h | 12 +- frame/compat/bla_axpy_amd.c | 66 ++++- frame/compat/bla_copy.c | 12 +- frame/compat/bla_copy.h | 10 +- frame/compat/bla_copy_amd.c | 35 ++- frame/compat/bla_dot.c | 48 +++- frame/compat/bla_dot.h | 33 ++- frame/compat/bla_dot_amd.c | 119 ++++++++- frame/compat/bla_nrm2.c | 13 +- frame/compat/bla_nrm2.h | 9 +- frame/compat/bla_scal.c | 13 +- frame/compat/bla_scal.h | 10 +- frame/compat/bla_scal_amd.c | 35 ++- frame/compat/cblas/f77_sub/f77_amax_sub.c | 17 +- frame/compat/cblas/f77_sub/f77_amax_sub.h | 10 +- frame/compat/cblas/f77_sub/f77_amin_sub.c | 18 +- frame/compat/cblas/f77_sub/f77_amin_sub.h | 9 +- frame/compat/cblas/f77_sub/f77_asum_sub.c | 17 +- frame/compat/cblas/f77_sub/f77_asum_sub.h | 10 +- frame/compat/cblas/f77_sub/f77_dot_sub.c | 60 ++++- frame/compat/cblas/f77_sub/f77_dot_sub.h | 29 ++- frame/compat/cblas/f77_sub/f77_nrm2_sub.c | 17 +- frame/compat/cblas/f77_sub/f77_nrm2_sub.h | 10 +- frame/compat/cblas/src/cblas_f77.h | 103 ++++---- frame/compat/f2c/bla_rot.c | 19 +- frame/compat/f2c/bla_rot.h | 5 +- frame/compat/f2c/bla_rotg.c | 19 +- frame/compat/f2c/bla_rotg.h | 5 +- frame/compat/f2c/bla_rotm.c | 19 +- frame/compat/f2c/bla_rotm.h | 5 +- frame/compat/f2c/bla_rotmg.c | 19 +- frame/compat/f2c/bla_rotmg.h | 5 +- frame/util/bli_util_api_wrap.c | 302 +++++++++++----------- 43 files changed, 946 insertions(+), 309 deletions(-) diff --git a/frame/compat/bla_amax.c b/frame/compat/bla_amax.c index b1cf77e7b8..bb924d53ba 100644 --- a/frame/compat/bla_amax.c +++ b/frame/compat/bla_amax.c @@ -41,7 +41,7 @@ #undef GENTFUNC #define GENTFUNC( ftype_x, chx, blasname, blisname ) \ \ -f77_int PASTEF772(i,chx,blasname) \ +f77_int PASTEF772S(i,chx,blasname) \ ( \ const f77_int* n, \ const ftype_x* x, const f77_int* incx \ @@ -95,6 +95,15 @@ f77_int PASTEF772(i,chx,blasname) \ \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ return f77_index; \ +}\ +\ +f77_int PASTEF772(i,chx,blasname) \ + ( \ + const f77_int* n, \ + const ftype_x* x, const f77_int* incx \ + ) \ +{ \ + return PASTEF772S(i,chx,blasname)( n, x, incx );\ } #ifdef BLIS_ENABLE_BLAS diff --git a/frame/compat/bla_amax.h b/frame/compat/bla_amax.h index 1f13715dc4..093f1f45cf 100644 --- a/frame/compat/bla_amax.h +++ b/frame/compat/bla_amax.h @@ -5,7 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. 
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -40,6 +41,12 @@ #define GENTPROT( ftype_x, chx, blasname ) \ \ BLIS_EXPORT_BLAS f77_int PASTEF772(i,chx,blasname) \ + ( \ + const f77_int* n, \ + const ftype_x* x, const f77_int* incx \ + );\ +\ +BLIS_EXPORT_BLAS f77_int PASTEF772S(i,chx,blasname) \ ( \ const f77_int* n, \ const ftype_x* x, const f77_int* incx \ diff --git a/frame/compat/bla_amax_amd.c b/frame/compat/bla_amax_amd.c index 2f7c2d2491..8804045350 100644 --- a/frame/compat/bla_amax_amd.c +++ b/frame/compat/bla_amax_amd.c @@ -41,7 +41,7 @@ #undef GENTFUNC #define GENTFUNC( ftype_x, chx, blasname, blisname ) \ \ -f77_int PASTEF772(i,chx,blasname) \ +f77_int PASTEF772S(i,chx,blasname) \ ( \ const f77_int* n, \ const ftype_x* x, const f77_int* incx \ @@ -95,11 +95,20 @@ f77_int PASTEF772(i,chx,blasname) \ \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ return f77_index; \ +}\ +\ +f77_int PASTEF772(i,chx,blasname) \ + ( \ + const f77_int* n, \ + const ftype_x* x, const f77_int* incx \ + ) \ +{ \ + return PASTEF772S(i,chx,blasname)( n, x, incx );\ } #ifdef BLIS_ENABLE_BLAS -f77_int isamax_ +f77_int isamax_blis_impl ( const f77_int* n, const float* x, const f77_int* incx @@ -197,8 +206,16 @@ f77_int isamax_ return f77_index; } +f77_int isamax_ + ( + const f77_int* n, + const float* x, const f77_int* incx + ) +{ + return isamax_blis_impl( n, x, incx ); +} -f77_int idamax_ +f77_int idamax_blis_impl ( const f77_int* n, const double* x, const f77_int* incx @@ -293,7 +310,14 @@ f77_int idamax_ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); return f77_index; } - +f77_int idamax_ + ( + const f77_int* n, + const double* x, const f77_int* incx + ) +{ + return idamax_blis_impl( n, x, incx ); +} INSERT_GENTFUNC_BLAS_CZ( amax, amaxv ) #endif diff --git a/frame/compat/bla_amin.c b/frame/compat/bla_amin.c index 7930fc1854..7c8be7e51f 100644 --- a/frame/compat/bla_amin.c +++ b/frame/compat/bla_amin.c @@ -5,7 +5,8 @@ libraries. Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. - + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -41,7 +42,7 @@ #undef GENTFUNC #define GENTFUNC( ftype_x, chx, blasname, blisname ) \ \ -f77_int PASTEF772(i,chx,blasname) \ +f77_int PASTEF772S(i,chx,blasname) \ ( \ const f77_int* n, \ const ftype_x* x, const f77_int* incx \ @@ -88,6 +89,15 @@ f77_int PASTEF772(i,chx,blasname) \ bli_finalize_auto(); \ \ return f77_index; \ +}\ +\ +f77_int PASTEF772(i,chx,blasname) \ + ( \ + const f77_int* n, \ + const ftype_x* x, const f77_int* incx \ + ) \ +{ \ + return PASTEF772S(i,chx,blasname)( n, x, incx );\ } #ifdef BLIS_ENABLE_BLAS diff --git a/frame/compat/bla_amin.h b/frame/compat/bla_amin.h index ebbed8262b..9b3ff7524a 100644 --- a/frame/compat/bla_amin.h +++ b/frame/compat/bla_amin.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -40,6 +40,12 @@ #define GENTPROT( ftype_x, chx, blasname ) \ \ BLIS_EXPORT_BLAS f77_int PASTEF772(i,chx,blasname) \ + ( \ + const f77_int* n, \ + const ftype_x* x, const f77_int* incx \ + );\ +\ +BLIS_EXPORT_BLAS f77_int PASTEF772S(i,chx,blasname) \ ( \ const f77_int* n, \ const ftype_x* x, const f77_int* incx \ diff --git a/frame/compat/bla_asum.c b/frame/compat/bla_asum.c index c104be96bd..024d821efb 100644 --- a/frame/compat/bla_asum.c +++ b/frame/compat/bla_asum.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -41,7 +41,7 @@ #undef GENTFUNCR2 #define GENTFUNCR2( ftype_x, ftype_r, chx, chr, blasname, blisname ) \ \ -ftype_r PASTEF772(chr,chx,blasname) \ +ftype_r PASTEF772S(chr,chx,blasname) \ ( \ const f77_int* n, \ const ftype_x* x, const f77_int* incx \ @@ -79,6 +79,15 @@ ftype_r PASTEF772(chr,chx,blasname) \ bli_finalize_auto(); \ \ return asum; \ +}\ +\ +ftype_r PASTEF772(chr,chx,blasname) \ + ( \ + const f77_int* n, \ + const ftype_x* x, const f77_int* incx \ + ) \ +{ \ + return PASTEF772S(chr,chx,blasname)( n, x, incx );\ } #ifdef BLIS_ENABLE_BLAS diff --git a/frame/compat/bla_asum.h b/frame/compat/bla_asum.h index a9ef27a036..6460a11178 100644 --- a/frame/compat/bla_asum.h +++ b/frame/compat/bla_asum.h @@ -5,7 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -40,6 +41,12 @@ #define GENTPROTR2( ftype_x, ftype_r, chx, chr, blasname ) \ \ BLIS_EXPORT_BLAS ftype_r PASTEF772(chr,chx,blasname) \ + ( \ + const f77_int* n, \ + const ftype_x* x, const f77_int* incx \ + );\ +\ +BLIS_EXPORT_BLAS ftype_r PASTEF772S(chr,chx,blasname) \ ( \ const f77_int* n, \ const ftype_x* x, const f77_int* incx \ diff --git a/frame/compat/bla_axpby.c b/frame/compat/bla_axpby.c index be53ec480b..90d0563190 100644 --- a/frame/compat/bla_axpby.c +++ b/frame/compat/bla_axpby.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -41,7 +41,7 @@ #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ -void PASTEF77(ch,blasname) \ +void PASTEF77S(ch,blasname) \ ( \ const f77_int* n, \ const ftype* alpha, \ @@ -85,6 +85,19 @@ void PASTEF77(ch,blasname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ /* Finalize BLIS. 
*/ \ bli_finalize_auto(); \ +}\ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_int* n, \ + const ftype* alpha, \ + const ftype* x, const f77_int* incx, \ + const ftype* beta, \ + ftype* y, const f77_int* incy \ + ) \ +{ \ + PASTEF77S(ch,blasname) \ + ( n, alpha, x, incx, beta, y, incy ); \ } #ifdef BLIS_ENABLE_BLAS diff --git a/frame/compat/bla_axpby.h b/frame/compat/bla_axpby.h index ab2952be98..74ca8908a7 100644 --- a/frame/compat/bla_axpby.h +++ b/frame/compat/bla_axpby.h @@ -5,7 +5,8 @@ libraries. Copyright (C) 2020, Advanced Micro Devices, Inc. - + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -40,6 +41,15 @@ #define GENTPROT( ftype, ch, blasname ) \ \ BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ + ( \ + const f77_int* n, \ + const ftype* alpha, \ + const ftype* x, const f77_int* incx, \ + const ftype* beta, \ + ftype* y, const f77_int* incy \ + );\ +\ +BLIS_EXPORT_BLAS void PASTEF77S(ch,blasname) \ ( \ const f77_int* n, \ const ftype* alpha, \ @@ -47,7 +57,7 @@ BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ const ftype* beta, \ ftype* y, const f77_int* incy \ ); - + #ifdef BLIS_ENABLE_BLAS INSERT_GENTPROT_BLAS( axpby ) #endif diff --git a/frame/compat/bla_axpy.c b/frame/compat/bla_axpy.c index 1a30f417b3..e5ca995914 100644 --- a/frame/compat/bla_axpy.c +++ b/frame/compat/bla_axpy.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 22, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -42,7 +42,7 @@ #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ -void PASTEF77(ch,blasname) \ +void PASTEF77S(ch,blasname) \ ( \ const f77_int* n, \ const ftype* alpha, \ @@ -83,6 +83,17 @@ void PASTEF77(ch,blasname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ /* Finalize BLIS. */ \ bli_finalize_auto(); \ +}\ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_int* n, \ + const ftype* alpha, \ + const ftype* x, const f77_int* incx, \ + ftype* y, const f77_int* incy \ + ) \ +{ \ + PASTEF77S(ch,blasname)( n, alpha, x, incx, y, incy ) ; \ } #ifdef BLIS_ENABLE_BLAS diff --git a/frame/compat/bla_axpy.h b/frame/compat/bla_axpy.h index 294a385c78..dcbc5df8c1 100644 --- a/frame/compat/bla_axpy.h +++ b/frame/compat/bla_axpy.h @@ -5,7 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. 
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -40,13 +41,20 @@ #define GENTPROT( ftype, ch, blasname ) \ \ BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ + ( \ + const f77_int* n, \ + const ftype* alpha, \ + const ftype* x, const f77_int* incx, \ + ftype* y, const f77_int* incy \ + );\ +\ +BLIS_EXPORT_BLAS void PASTEF77S(ch,blasname) \ ( \ const f77_int* n, \ const ftype* alpha, \ const ftype* x, const f77_int* incx, \ ftype* y, const f77_int* incy \ ); - #ifdef BLIS_ENABLE_BLAS INSERT_GENTPROT_BLAS( axpy ) #endif diff --git a/frame/compat/bla_axpy_amd.c b/frame/compat/bla_axpy_amd.c index 8a9f0280c6..62f0c4df3d 100644 --- a/frame/compat/bla_axpy_amd.c +++ b/frame/compat/bla_axpy_amd.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 22, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -42,7 +42,7 @@ #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ -void PASTEF77(ch,blasname) \ +void PASTEF77S(ch,blasname) \ ( \ const f77_int* n, \ const ftype* alpha, \ @@ -83,11 +83,23 @@ void PASTEF77(ch,blasname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ /* Finalize BLIS. */ \ bli_finalize_auto(); \ +}\ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_int* n, \ + const ftype* alpha, \ + const ftype* x, const f77_int* incx, \ + ftype* y, const f77_int* incy \ + ) \ +{ \ + PASTEF77S(ch,blasname)( n, alpha, x, incx, y, incy ) ; \ } + #ifdef BLIS_ENABLE_BLAS -void saxpy_ +void saxpy_blis_impl ( const f77_int* n, const float* alpha, @@ -178,8 +190,18 @@ void saxpy_ // bli_finalize_auto(); AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); } +void saxpy_ +( + const f77_int* n, + const float* alpha, + const float* x, const f77_int* incx, + float* y, const f77_int* incy + ) +{ + saxpy_blis_impl( n, alpha, x, incx, y, incy ) ; +} -void daxpy_ +void daxpy_blis_impl ( const f77_int* n, const double* alpha, @@ -271,8 +293,18 @@ void daxpy_ /* Finalize BLIS. */ // bli_finalize_auto(); } +void daxpy_ +( + const f77_int* n, + const double* alpha, + const double* x, const f77_int* incx, + double* y, const f77_int* incy + ) +{ + daxpy_blis_impl( n, alpha, x, incx, y, incy ) ; +} -void caxpy_ +void caxpy_blis_impl ( const f77_int* n, const scomplex* alpha, @@ -363,8 +395,17 @@ void caxpy_ /* Finalize BLIS. */ // bli_finalize_auto(); } - -void zaxpy_ +void caxpy_ +( + const f77_int* n, + const scomplex* alpha, + const scomplex* x, const f77_int* incx, + scomplex* y, const f77_int* incy + ) +{ + caxpy_blis_impl( n, alpha, x, incx, y, incy ) ; +} +void zaxpy_blis_impl ( const f77_int* n, const dcomplex* alpha, @@ -456,7 +497,16 @@ void zaxpy_ /* Finalize BLIS. 
*/ // bli_finalize_auto(); } - +void zaxpy_ +( + const f77_int* n, + const dcomplex* alpha, + const dcomplex* x, const f77_int* incx, + dcomplex* y, const f77_int* incy + ) +{ + zaxpy_blis_impl( n, alpha, x, incx, y, incy ) ; +} #endif diff --git a/frame/compat/bla_copy.c b/frame/compat/bla_copy.c index 74baba689c..4f4e1e874e 100644 --- a/frame/compat/bla_copy.c +++ b/frame/compat/bla_copy.c @@ -42,7 +42,7 @@ #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ -void PASTEF77(ch,blasname) \ +void PASTEF77S(ch,blasname) \ ( \ const f77_int* n, \ const ftype* x, const f77_int* incx, \ @@ -85,6 +85,16 @@ void PASTEF77(ch,blasname) \ \ /* Finalize BLIS. */ \ bli_finalize_auto(); \ +}\ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_int* n, \ + const ftype* x, const f77_int* incx, \ + ftype* y, const f77_int* incy \ + ) \ +{ \ + PASTEF77S(ch,blasname)( n, x, incx, y, incy ); \ } #ifdef BLIS_ENABLE_BLAS diff --git a/frame/compat/bla_copy.h b/frame/compat/bla_copy.h index 679017b19d..cfe67967c4 100644 --- a/frame/compat/bla_copy.h +++ b/frame/compat/bla_copy.h @@ -5,7 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -40,6 +41,13 @@ #define GENTPROT( ftype, ch, blasname ) \ \ BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ + ( \ + const f77_int* n, \ + const ftype* x, const f77_int* incx, \ + ftype* y, const f77_int* incy \ + );\ +\ +BLIS_EXPORT_BLAS void PASTEF77S(ch,blasname) \ ( \ const f77_int* n, \ const ftype* x, const f77_int* incx, \ diff --git a/frame/compat/bla_copy_amd.c b/frame/compat/bla_copy_amd.c index 8dc4d5287c..4ed0c7f548 100644 --- a/frame/compat/bla_copy_amd.c +++ b/frame/compat/bla_copy_amd.c @@ -42,7 +42,7 @@ #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ -void PASTEF77(ch,blasname) \ +void PASTEF77S(ch,blasname) \ ( \ const f77_int* n, \ const ftype* x, const f77_int* incx, \ @@ -85,11 +85,21 @@ void PASTEF77(ch,blasname) \ \ /* Finalize BLIS. */ \ bli_finalize_auto(); \ +}\ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_int* n, \ + const ftype* x, const f77_int* incx, \ + ftype* y, const f77_int* incy \ + ) \ +{ \ + PASTEF77S(ch,blasname)( n, x, incx, y, incy ); \ } #ifdef BLIS_ENABLE_BLAS -void scopy_ +void scopy_blis_impl ( const f77_int* n, const float* x, const f77_int* incx, @@ -183,8 +193,17 @@ void scopy_ /* Finalize BLIS. */ // bli_finalize_auto(); } +void scopy_ +( + const f77_int* n, + const float* x, const f77_int* incx, + float* y, const f77_int* incy +) +{ + scopy_blis_impl( n, x, incx, y, incy ); +} -void dcopy_ +void dcopy_blis_impl ( const f77_int* n, const double* x, const f77_int* incx, @@ -279,7 +298,15 @@ void dcopy_ /* Finalize BLIS. 
*/ // bli_finalize_auto(); } - +void dcopy_ +( + const f77_int* n, + const double* x, const f77_int* incx, + double* y, const f77_int* incy +) +{ + dcopy_blis_impl( n, x, incx, y, incy ); +} INSERT_GENTFUNC_BLAS_CZ(copy, copyv) #endif diff --git a/frame/compat/bla_dot.c b/frame/compat/bla_dot.c index 3c4d8c538f..79c65c4d8d 100644 --- a/frame/compat/bla_dot.c +++ b/frame/compat/bla_dot.c @@ -42,7 +42,7 @@ #undef GENTFUNCDOT #define GENTFUNCDOT( ftype, ch, chc, blis_conjx, blasname, blisname ) \ \ -ftype PASTEF772(ch,blasname,chc) \ +ftype PASTEF772S(ch,blasname,chc) \ ( \ const f77_int* n, \ const ftype* x, const f77_int* incx, \ @@ -87,6 +87,16 @@ ftype PASTEF772(ch,blasname,chc) \ bli_finalize_auto(); \ \ return rho; \ +}\ +\ +ftype PASTEF772(ch,blasname,chc) \ + ( \ + const f77_int* n, \ + const ftype* x, const f77_int* incx, \ + const ftype* y, const f77_int* incy \ + ) \ +{ \ + return PASTEF772S(ch,blasname,chc)( n, x, incx, y, incy );\ } #ifdef BLIS_ENABLE_BLAS @@ -100,7 +110,7 @@ INSERT_GENTFUNCDOTC_BLAS( dot, dotv ) #undef GENTFUNCDOT #define GENTFUNCDOT( ftype, ch, chc, blis_conjx, blasname, blisname ) \ \ -void PASTEF772(ch,blasname,chc) \ +void PASTEF772S(ch,blasname,chc) \ ( \ ftype* rhop, \ const f77_int* n, \ @@ -146,6 +156,17 @@ void PASTEF772(ch,blasname,chc) \ bli_finalize_auto(); \ \ *rhop = rho; \ +}\ +\ +void PASTEF772(ch,blasname,chc) \ + ( \ + ftype* rhop, \ + const f77_int* n, \ + const ftype* x, const f77_int* incx, \ + const ftype* y, const f77_int* incy \ + ) \ +{ \ + PASTEF772S(ch,blasname,chc)( rhop, n, x, incx, y, incy );\ } INSERT_GENTFUNCDOTC_BLAS( dot, dotv ) @@ -157,7 +178,7 @@ INSERT_GENTFUNCDOTC_BLAS( dot, dotv ) // Input vectors stored in single precision, computed in double precision, // with result returned in single precision. -float PASTEF77(sd,sdot) +float PASTEF77S(sd,sdot) ( const f77_int* n, const float* sb, @@ -176,10 +197,20 @@ float PASTEF77(sd,sdot) ) ); } +float PASTEF77(sd,sdot) + ( + const f77_int* n, + const float* sb, + const float* x, const f77_int* incx, + const float* y, const f77_int* incy + ) +{ + return PASTEF77S(sd,sdot)( n, sb, x, incx, y, incy ); +} // Input vectors stored in single precision, computed in double precision, // with result returned in double precision. -double PASTEF77(d,sdot) +double PASTEF77S(d,sdot) ( const f77_int* n, const float* x, const f77_int* incx, @@ -223,5 +254,14 @@ double PASTEF77(d,sdot) return rho; } +double PASTEF77(d,sdot) + ( + const f77_int* n, + const float* x, const f77_int* incx, + const float* y, const f77_int* incy + ) +{ + return PASTEF77S(d,sdot)( n, x, incx, y, incy ); +} #endif // BLIS_ENABLE_BLAS diff --git a/frame/compat/bla_dot.h b/frame/compat/bla_dot.h index 16bc3f97cc..a582503753 100644 --- a/frame/compat/bla_dot.h +++ b/frame/compat/bla_dot.h @@ -5,7 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. 
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -40,6 +41,13 @@ #define GENTPROTDOT( ftype, ch, chc, blasname ) \ \ BLIS_EXPORT_BLAS ftype PASTEF772(ch,blasname,chc) \ + ( \ + const f77_int* n, \ + const ftype* x, const f77_int* incx, \ + const ftype* y, const f77_int* incy \ + );\ +\ +BLIS_EXPORT_BLAS ftype PASTEF772S(ch,blasname,chc) \ ( \ const f77_int* n, \ const ftype* x, const f77_int* incx, \ @@ -60,6 +68,14 @@ INSERT_GENTPROTDOTC_BLAS( dot ) #define GENTPROTDOT( ftype, ch, chc, blasname ) \ \ BLIS_EXPORT_BLAS void PASTEF772(ch,blasname,chc) \ + ( \ + ftype* rhop, \ + const f77_int* n, \ + const ftype* x, const f77_int* incx, \ + const ftype* y, const f77_int* incy \ + );\ +\ +BLIS_EXPORT_BLAS void PASTEF772S(ch,blasname,chc) \ ( \ ftype* rhop, \ const f77_int* n, \ @@ -80,12 +96,25 @@ BLIS_EXPORT_BLAS float PASTEF77(sd,sdot) const float* x, const f77_int* incx, const float* y, const f77_int* incy ); - +BLIS_EXPORT_BLAS float PASTEF77S(sd,sdot) + ( + const f77_int* n, + const float* sb, + const float* x, const f77_int* incx, + const float* y, const f77_int* incy + ); + BLIS_EXPORT_BLAS double PASTEF77(d,sdot) ( const f77_int* n, const float* x, const f77_int* incx, const float* y, const f77_int* incy ); +BLIS_EXPORT_BLAS double PASTEF77S(d,sdot) + ( + const f77_int* n, + const float* x, const f77_int* incx, + const float* y, const f77_int* incy + ); #endif diff --git a/frame/compat/bla_dot_amd.c b/frame/compat/bla_dot_amd.c index 0cdaa6535b..0e954f3317 100644 --- a/frame/compat/bla_dot_amd.c +++ b/frame/compat/bla_dot_amd.c @@ -42,7 +42,7 @@ #undef GENTFUNCDOT #define GENTFUNCDOT( ftype, ch, chc, blis_conjx, blasname, blisname ) \ \ -ftype PASTEF772(ch,blasname,chc) \ +ftype PASTEF772S(ch,blasname,chc) \ ( \ const f77_int* n, \ const ftype* x, const f77_int* incx, \ @@ -87,10 +87,20 @@ ftype PASTEF772(ch,blasname,chc) \ bli_finalize_auto(); \ \ return rho; \ +}\ +\ +ftype PASTEF772(ch,blasname,chc) \ + ( \ + const f77_int* n, \ + const ftype* x, const f77_int* incx, \ + const ftype* y, const f77_int* incy \ + ) \ +{ \ + return PASTEF772S(ch,blasname,chc)( n, x, incx, y, incy );\ } #ifdef BLIS_ENABLE_BLAS -float sdot_ +float sdot_blis_impl ( const f77_int* n, const float* x, const f77_int* incx, @@ -191,7 +201,17 @@ float sdot_ return rho; } -double ddot_ +float sdot_ + ( + const f77_int* n, + const float* x, const f77_int* incx, + const float* y, const f77_int* incy + ) +{ + return sdot_blis_impl( n, x, incx, y, incy ); +} + +double ddot_blis_impl ( const f77_int* n, const double* x, const f77_int* incx, @@ -291,9 +311,18 @@ double ddot_ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); return rho; } +double ddot_ + ( + const f77_int* n, + const double* x, const f77_int* incx, + const double* y, const f77_int* incy + ) +{ + return ddot_blis_impl( n, x, incx, y, incy ); +} #ifdef BLIS_DISABLE_COMPLEX_RETURN_INTEL -scomplex cdotu_ +scomplex cdotu_blis_impl ( const f77_int* n, const scomplex* x, const f77_int* incx, @@ -393,8 +422,17 @@ scomplex cdotu_ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); return rho; } +scomplex cdotu_ + ( + const f77_int* n, + const scomplex* x, const f77_int* incx, + const scomplex* y, const f77_int* incy + ) +{ + return cdotu_blis_impl( n, x, incx, y, incy ); +} -dcomplex zdotu_ +dcomplex zdotu_blis_impl ( const f77_int* n, const dcomplex* x, const f77_int* incx, @@ -497,9 +535,17 @@ dcomplex zdotu_ return rho; } +dcomplex zdotu_ + ( + const f77_int* n, + const 
dcomplex* x, const f77_int* incx, + const dcomplex* y, const f77_int* incy + ) +{ + return zdotu_blis_impl( n, x, incx, y, incy ); +} - -scomplex cdotc_ +scomplex cdotc_blis_impl ( const f77_int* n, const scomplex* x, const f77_int* incx, @@ -601,8 +647,17 @@ scomplex cdotc_ return rho; } +scomplex cdotc_ + ( + const f77_int* n, + const scomplex* x, const f77_int* incx, + const scomplex* y, const f77_int* incy + ) +{ + return cdotc_blis_impl( n, x, incx, y, incy ); +} -dcomplex zdotc_ +dcomplex zdotc_blis_impl ( const f77_int* n, const dcomplex* x, const f77_int* incx, @@ -708,13 +763,21 @@ dcomplex zdotc_ return rho; } - +dcomplex zdotc_ + ( + const f77_int* n, + const dcomplex* x, const f77_int* incx, + const dcomplex* y, const f77_int* incy + ) +{ + return zdotc_blis_impl( n, x, incx, y, incy ); +} #else // BLIS_DISABLE_COMPLEX_RETURN_INTEL // For the "intel" complex return type, use a hidden parameter to return the result #undef GENTFUNCDOT #define GENTFUNCDOT( ftype, ch, chc, blis_conjx, blasname, blisname ) \ \ -void PASTEF772(ch,blasname,chc) \ +void PASTEF772S(ch,blasname,chc) \ ( \ ftype* rhop, \ const f77_int* n, \ @@ -760,6 +823,17 @@ void PASTEF772(ch,blasname,chc) \ bli_finalize_auto(); \ \ *rhop = rho; \ +}\ +\ +void PASTEF772(ch,blasname,chc) \ + ( \ + ftype* rhop, \ + const f77_int* n, \ + const ftype* x, const f77_int* incx, \ + const ftype* y, const f77_int* incy \ + ) \ +{ \ + PASTEF772S(ch,blasname,chc)( rhop, n, x, incx, y, incy );\ } INSERT_GENTFUNCDOTC_BLAS( dot, dotv ) @@ -771,7 +845,7 @@ INSERT_GENTFUNCDOTC_BLAS( dot, dotv ) // Input vectors stored in single precision, computed in double precision, // with result returned in single precision. -float PASTEF77(sd,sdot) +float PASTEF77S(sd,sdot) ( const f77_int* n, const float* sb, @@ -782,7 +856,7 @@ float PASTEF77(sd,sdot) return ( float ) ( ( double )(*sb) + - PASTEF77(d,sdot) + PASTEF77S(d,sdot) ( n, x, incx, @@ -790,10 +864,20 @@ float PASTEF77(sd,sdot) ) ); } +float PASTEF77(sd,sdot) + ( + const f77_int* n, + const float* sb, + const float* x, const f77_int* incx, + const float* y, const f77_int* incy + ) +{ + return PASTEF77S(sd,sdot)( n,sb, x, incx, y, incy ); +} // Input vectors stored in single precision, computed in double precision, // with result returned in double precision. -double PASTEF77(d,sdot) +double PASTEF77S(d,sdot) ( const f77_int* n, const float* x, const f77_int* incx, @@ -837,5 +921,14 @@ double PASTEF77(d,sdot) return rho; } +double PASTEF77(d,sdot) + ( + const f77_int* n, + const float* x, const f77_int* incx, + const float* y, const f77_int* incy + ) +{ + return PASTEF77S(d,sdot)( n, x, incx, y, incy ); +} #endif diff --git a/frame/compat/bla_nrm2.c b/frame/compat/bla_nrm2.c index 576d9eda8c..a823747ec9 100755 --- a/frame/compat/bla_nrm2.c +++ b/frame/compat/bla_nrm2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -42,7 +42,7 @@ #undef GENTFUNCR2 #define GENTFUNCR2( ftype_x, ftype_r, chx, chr, blasname, blisname ) \ \ -ftype_r PASTEF772(chr,chx,blasname) \ +ftype_r PASTEF772S(chr,chx,blasname) \ ( \ const f77_int* n, \ const ftype_x* x, const f77_int* incx \ @@ -80,6 +80,15 @@ ftype_r PASTEF772(chr,chx,blasname) \ bli_finalize_auto(); \ \ return norm; \ +}\ +\ +ftype_r PASTEF772(chr,chx,blasname) \ + ( \ + const f77_int* n, \ + const ftype_x* x, const f77_int* incx \ + ) \ +{ \ + return PASTEF772S(chr,chx,blasname)( n, x, incx );\ } #ifdef BLIS_ENABLE_BLAS diff --git a/frame/compat/bla_nrm2.h b/frame/compat/bla_nrm2.h index a8bc25ef48..f690b4e071 100644 --- a/frame/compat/bla_nrm2.h +++ b/frame/compat/bla_nrm2.h @@ -5,7 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -40,6 +41,12 @@ #define GENTPROTR2( ftype_x, ftype_r, chx, chr, blasname ) \ \ BLIS_EXPORT_BLAS ftype_r PASTEF772(chr,chx,blasname) \ + ( \ + const f77_int* n, \ + const ftype_x* x, const f77_int* incx \ + );\ +\ +BLIS_EXPORT_BLAS ftype_r PASTEF772S(chr,chx,blasname) \ ( \ const f77_int* n, \ const ftype_x* x, const f77_int* incx \ diff --git a/frame/compat/bla_scal.c b/frame/compat/bla_scal.c index b9651577eb..44a97144eb 100644 --- a/frame/compat/bla_scal.c +++ b/frame/compat/bla_scal.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020-22, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -42,7 +42,7 @@ #undef GENTFUNCSCAL #define GENTFUNCSCAL( ftype_x, ftype_a, chx, cha, blasname, blisname ) \ \ -void PASTEF772(chx,cha,blasname) \ +void PASTEF772S(chx,cha,blasname) \ ( \ const f77_int* n, \ const ftype_a* alpha, \ @@ -90,6 +90,15 @@ void PASTEF772(chx,cha,blasname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ /* Finalize BLIS. */ \ bli_finalize_auto(); \ +}\ +void PASTEF772(chx,cha,blasname) \ + ( \ + const f77_int* n, \ + const ftype_a* alpha, \ + ftype_x* x, const f77_int* incx \ + ) \ +{ \ + PASTEF772S(chx,cha,blasname)( n, alpha, x, incx ); \ } #ifdef BLIS_ENABLE_BLAS diff --git a/frame/compat/bla_scal.h b/frame/compat/bla_scal.h index c8e898b6ba..c3b4540bb0 100644 --- a/frame/compat/bla_scal.h +++ b/frame/compat/bla_scal.h @@ -5,7 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. 
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -40,6 +41,13 @@ #define GENTPROTSCAL( ftype_a, ftype_x, cha, chx, blasname ) \ \ BLIS_EXPORT_BLAS void PASTEF772(chx,cha,blasname) \ + ( \ + const f77_int* n, \ + const ftype_a* alpha, \ + ftype_x* x, const f77_int* incx \ + );\ +\ +BLIS_EXPORT_BLAS void PASTEF772S(chx,cha,blasname) \ ( \ const f77_int* n, \ const ftype_a* alpha, \ diff --git a/frame/compat/bla_scal_amd.c b/frame/compat/bla_scal_amd.c index 178776a149..518058b060 100644 --- a/frame/compat/bla_scal_amd.c +++ b/frame/compat/bla_scal_amd.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020-22, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -42,7 +42,7 @@ #undef GENTFUNCSCAL #define GENTFUNCSCAL( ftype_x, ftype_a, chx, cha, blasname, blisname ) \ \ -void PASTEF772(chx,cha,blasname) \ +void PASTEF772S(chx,cha,blasname) \ ( \ const f77_int* n, \ const ftype_a* alpha, \ @@ -90,11 +90,20 @@ void PASTEF772(chx,cha,blasname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ /* Finalize BLIS. */ \ bli_finalize_auto(); \ +}\ +void PASTEF772(chx,cha,blasname) \ + ( \ + const f77_int* n, \ + const ftype_a* alpha, \ + ftype_x* x, const f77_int* incx \ + ) \ +{ \ + PASTEF772S(chx,cha,blasname)( n, alpha, x, incx ); \ } #ifdef BLIS_ENABLE_BLAS -void sscal_ +void sscal_blis_impl ( const f77_int* n, const float* alpha, @@ -173,8 +182,17 @@ void sscal_ // bli_finalize_auto(); AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) } +void sscal_ + ( + const f77_int* n, + const float* alpha, + float* x, const f77_int* incx + ) +{ + sscal_blis_impl( n, alpha, x, incx ); +} -void dscal_ +void dscal_blis_impl ( const f77_int* n, const double* alpha, @@ -254,6 +272,15 @@ void dscal_ // bli_finalize_auto(); AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) } +void dscal_ + ( + const f77_int* n, + const double* alpha, + double* x, const f77_int* incx + ) +{ + dscal_blis_impl( n, alpha, x, incx ); +} INSERT_GENTFUNCSCAL_BLAS_CZ( scal, scalv ) diff --git a/frame/compat/cblas/f77_sub/f77_amax_sub.c b/frame/compat/cblas/f77_sub/f77_amax_sub.c index cc26196d79..c394ed4d40 100644 --- a/frame/compat/cblas/f77_sub/f77_amax_sub.c +++ b/frame/compat/cblas/f77_sub/f77_amax_sub.c @@ -5,7 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. 
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -42,18 +43,28 @@ #undef GENTFUNC #define GENTFUNC( ftype_x, chx, blasname, blisname ) \ \ -void PASTEF773(i,chx,blasname,sub) \ +void PASTEF773S(i,chx,blasname,sub) \ ( \ const f77_int* n, \ const ftype_x* x, const f77_int* incx, \ f77_int* rval \ ) \ { \ - *rval = PASTEF772(i,chx,blasname) \ + *rval = PASTEF772S(i,chx,blasname) \ ( \ n, \ x, incx \ ); \ +}\ +\ +void PASTEF773(i,chx,blasname,sub) \ + ( \ + const f77_int* n, \ + const ftype_x* x, const f77_int* incx, \ + f77_int* rval \ + ) \ +{ \ + PASTEF773S(i,chx,blasname,sub) ( n, x, incx, rval );\ } #ifdef BLIS_ENABLE_CBLAS diff --git a/frame/compat/cblas/f77_sub/f77_amax_sub.h b/frame/compat/cblas/f77_sub/f77_amax_sub.h index 9cd1202d26..35d501ba4a 100644 --- a/frame/compat/cblas/f77_sub/f77_amax_sub.h +++ b/frame/compat/cblas/f77_sub/f77_amax_sub.h @@ -5,7 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -40,6 +41,13 @@ #define GENTPROT( ftype_x, chx, blasname ) \ \ BLIS_EXPORT_BLAS void PASTEF773(i,chx,blasname,sub) \ + ( \ + const f77_int* n, \ + const ftype_x* x, const f77_int* incx, \ + f77_int* rval \ + );\ +\ +BLIS_EXPORT_BLAS void PASTEF773S(i,chx,blasname,sub) \ ( \ const f77_int* n, \ const ftype_x* x, const f77_int* incx, \ diff --git a/frame/compat/cblas/f77_sub/f77_amin_sub.c b/frame/compat/cblas/f77_sub/f77_amin_sub.c index 73e1951839..2eaa231061 100644 --- a/frame/compat/cblas/f77_sub/f77_amin_sub.c +++ b/frame/compat/cblas/f77_sub/f77_amin_sub.c @@ -4,8 +4,8 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. - + Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -42,18 +42,28 @@ #undef GENTFUNC #define GENTFUNC( ftype_x, chx, blasname, blisname ) \ \ -void PASTEF773(i,chx,blasname,sub) \ +void PASTEF773S(i,chx,blasname,sub) \ ( \ const f77_int* n, \ const ftype_x* x, const f77_int* incx, \ f77_int* rval \ ) \ { \ - *rval = PASTEF772(i,chx,blasname) \ + *rval = PASTEF772S(i,chx,blasname) \ ( \ n, \ x, incx \ ); \ +}\ +\ +void PASTEF773(i,chx,blasname,sub) \ + ( \ + const f77_int* n, \ + const ftype_x* x, const f77_int* incx, \ + f77_int* rval \ + ) \ +{ \ + PASTEF773S(i,chx,blasname,sub) ( n, x, incx, rval );\ } #ifdef BLIS_ENABLE_CBLAS diff --git a/frame/compat/cblas/f77_sub/f77_amin_sub.h b/frame/compat/cblas/f77_sub/f77_amin_sub.h index 522dcc7938..90b4f25b5f 100644 --- a/frame/compat/cblas/f77_sub/f77_amin_sub.h +++ b/frame/compat/cblas/f77_sub/f77_amin_sub.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -39,6 +39,13 @@ #define GENTPROT( ftype_x, chx, blasname ) \ \ BLIS_EXPORT_BLAS void PASTEF773(i,chx,blasname,sub) \ + ( \ + const f77_int* n, \ + const ftype_x* x, const f77_int* incx, \ + f77_int* rval \ + );\ +\ +BLIS_EXPORT_BLAS void PASTEF773S(i,chx,blasname,sub) \ ( \ const f77_int* n, \ const ftype_x* x, const f77_int* incx, \ diff --git a/frame/compat/cblas/f77_sub/f77_asum_sub.c b/frame/compat/cblas/f77_sub/f77_asum_sub.c index f1cb35b0cc..befac150e0 100644 --- a/frame/compat/cblas/f77_sub/f77_asum_sub.c +++ b/frame/compat/cblas/f77_sub/f77_asum_sub.c @@ -5,7 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -42,18 +43,28 @@ #undef GENTFUNCR2 #define GENTFUNCR2( ftype_x, ftype_r, chx, chr, blasname, blisname ) \ \ -void PASTEF773(chr,chx,blasname,sub) \ +void PASTEF773S(chr,chx,blasname,sub) \ ( \ const f77_int* n, \ const ftype_x* x, const f77_int* incx, \ ftype_r* rval \ ) \ { \ - *rval = PASTEF772(chr,chx,blasname) \ + *rval = PASTEF772S(chr,chx,blasname) \ ( \ n, \ x, incx \ ); \ +}\ +\ +void PASTEF773(chr,chx,blasname,sub) \ + ( \ + const f77_int* n, \ + const ftype_x* x, const f77_int* incx, \ + ftype_r* rval \ + ) \ +{ \ + PASTEF773S(chr,chx,blasname,sub) ( n, x, incx, rval ); \ } #ifdef BLIS_ENABLE_CBLAS diff --git a/frame/compat/cblas/f77_sub/f77_asum_sub.h b/frame/compat/cblas/f77_sub/f77_asum_sub.h index 4b8634c166..de3d99bfc9 100644 --- a/frame/compat/cblas/f77_sub/f77_asum_sub.h +++ b/frame/compat/cblas/f77_sub/f77_asum_sub.h @@ -5,7 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -40,6 +41,13 @@ #define GENTPROTR2( ftype_x, ftype_r, chx, chr, blasname ) \ \ BLIS_EXPORT_BLAS void PASTEF773(chr,chx,blasname,sub) \ + ( \ + const f77_int* n, \ + const ftype_x* x, const f77_int* incx, \ + ftype_r* rval \ + );\ +\ +BLIS_EXPORT_BLAS void PASTEF773S(chr,chx,blasname,sub) \ ( \ const f77_int* n, \ const ftype_x* x, const f77_int* incx, \ diff --git a/frame/compat/cblas/f77_sub/f77_dot_sub.c b/frame/compat/cblas/f77_sub/f77_dot_sub.c index 0ca80464d3..f497ab97f0 100644 --- a/frame/compat/cblas/f77_sub/f77_dot_sub.c +++ b/frame/compat/cblas/f77_sub/f77_dot_sub.c @@ -5,7 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. 
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -43,7 +44,7 @@ #undef GENTFUNCDOT #define GENTFUNCDOT( ftype, ch, chc, blis_conjx, blasname, blisname ) \ \ -void PASTEF773(ch,blasname,chc,sub) \ +void PASTEF773S(ch,blasname,chc,sub) \ ( \ const f77_int* n, \ const ftype* x, const f77_int* incx, \ @@ -51,12 +52,23 @@ void PASTEF773(ch,blasname,chc,sub) \ ftype* rval \ ) \ { \ - *rval = PASTEF772(ch,blasname,chc) \ + *rval = PASTEF772S(ch,blasname,chc) \ ( \ n, \ x, incx, \ y, incy \ ); \ +}\ +\ +void PASTEF773(ch,blasname,chc,sub) \ + ( \ + const f77_int* n, \ + const ftype* x, const f77_int* incx, \ + const ftype* y, const f77_int* incy, \ + ftype* rval \ + ) \ +{ \ + PASTEF773S(ch,blasname,chc,sub)( n, x, incx, y, incy, rval); \ } INSERT_GENTFUNCDOTR_BLAS( dot, NULL ) @@ -75,7 +87,7 @@ INSERT_GENTFUNCDOTC_BLAS( dot, NULL ) #undef GENTFUNCDOT #define GENTFUNCDOT( ftype, ch, chc, blis_conjx, blasname, blisname ) \ \ -void PASTEF773(ch,blasname,chc,sub) \ +void PASTEF773S(ch,blasname,chc,sub) \ ( \ const f77_int* n, \ const ftype* x, const f77_int* incx, \ @@ -90,6 +102,17 @@ void PASTEF773(ch,blasname,chc,sub) \ x, incx, \ y, incy \ ); \ +}\ +\ +void PASTEF773(ch,blasname,chc,sub) \ + ( \ + const f77_int* n, \ + const ftype* x, const f77_int* incx, \ + const ftype* y, const f77_int* incy, \ + ftype* rval \ + ) \ +{ \ + PASTEF773S(ch,blasname,chc,sub)( n, x, incx, y, incy, rval); \ } INSERT_GENTFUNCDOTC_BLAS( dot, NULL ) @@ -100,7 +123,7 @@ INSERT_GENTFUNCDOTC_BLAS( dot, NULL ) // Input vectors stored in single precision, computed in double precision, // with result returned in single precision. -void PASTEF772(sds,dot,sub) +void PASTEF772S(sds,dot,sub) ( const f77_int* n, const float* sb, @@ -109,7 +132,7 @@ void PASTEF772(sds,dot,sub) float* rval ) { - *rval = PASTEF77(sds,dot) + *rval = PASTEF77S(sds,dot) ( n, sb, @@ -117,10 +140,21 @@ void PASTEF772(sds,dot,sub) y, incy ); } +void PASTEF772(sds,dot,sub) + ( + const f77_int* n, + const float* sb, + const float* x, const f77_int* incx, + const float* y, const f77_int* incy, + float* rval + ) +{ + PASTEF772S(sds,dot,sub)( n, sb, x, incx, y, incy, rval); +} // Input vectors stored in single precision, computed in double precision, // with result returned in double precision. -void PASTEF772(ds,dot,sub) +void PASTEF772S(ds,dot,sub) ( const f77_int* n, const float* x, const f77_int* incx, @@ -128,13 +162,23 @@ void PASTEF772(ds,dot,sub) double* rval ) { - *rval = PASTEF77(ds,dot) + *rval = PASTEF77S(ds,dot) ( n, x, incx, y, incy ); } +void PASTEF772(ds,dot,sub) + ( + const f77_int* n, + const float* x, const f77_int* incx, + const float* y, const f77_int* incy, + double* rval + ) +{ + PASTEF772S(ds,dot,sub)( n, x, incx, y, incy, rval); +} #endif diff --git a/frame/compat/cblas/f77_sub/f77_dot_sub.h b/frame/compat/cblas/f77_sub/f77_dot_sub.h index 8aab2728bf..54a40a9a02 100644 --- a/frame/compat/cblas/f77_sub/f77_dot_sub.h +++ b/frame/compat/cblas/f77_sub/f77_dot_sub.h @@ -5,7 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. 
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -40,6 +41,14 @@ #define GENTPROTDOT( ftype, ch, chc, blasname ) \ \ BLIS_EXPORT_BLAS void PASTEF773(ch,blasname,chc,sub) \ + ( \ + const f77_int* n, \ + const ftype* x, const f77_int* incx, \ + const ftype* y, const f77_int* incy, \ + ftype* rval \ + );\ +\ +BLIS_EXPORT_BLAS void PASTEF773S(ch,blasname,chc,sub) \ ( \ const f77_int* n, \ const ftype* x, const f77_int* incx, \ @@ -61,7 +70,15 @@ BLIS_EXPORT_BLAS void PASTEF772(sds,dot,sub) const float* y, const f77_int* incy, float* rval ); - +BLIS_EXPORT_BLAS void PASTEF772S(sds,dot,sub) + ( + const f77_int* n, + const float* sb, + const float* x, const f77_int* incx, + const float* y, const f77_int* incy, + float* rval + ); + BLIS_EXPORT_BLAS void PASTEF772(ds,dot,sub) ( const f77_int* n, @@ -69,4 +86,12 @@ BLIS_EXPORT_BLAS void PASTEF772(ds,dot,sub) const float* y, const f77_int* incy, double* rval ); +BLIS_EXPORT_BLAS void PASTEF772S(ds,dot,sub) + ( + const f77_int* n, + const float* x, const f77_int* incx, + const float* y, const f77_int* incy, + double* rval + ); + #endif diff --git a/frame/compat/cblas/f77_sub/f77_nrm2_sub.c b/frame/compat/cblas/f77_sub/f77_nrm2_sub.c index 54ce1a5b49..72fa07593a 100644 --- a/frame/compat/cblas/f77_sub/f77_nrm2_sub.c +++ b/frame/compat/cblas/f77_sub/f77_nrm2_sub.c @@ -5,7 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -42,18 +43,28 @@ #undef GENTFUNCR2 #define GENTFUNCR2( ftype_x, ftype_r, chx, chr, blasname, blisname ) \ \ -void PASTEF773(chr,chx,blasname,sub) \ +void PASTEF773S(chr,chx,blasname,sub) \ ( \ const f77_int* n, \ const ftype_x* x, const f77_int* incx, \ ftype_r* rval \ ) \ { \ - *rval = PASTEF772(chr,chx,blasname) \ + *rval = PASTEF772S(chr,chx,blasname) \ ( \ n, \ x, incx \ ); \ +}\ +\ +void PASTEF773(chr,chx,blasname,sub) \ + ( \ + const f77_int* n, \ + const ftype_x* x, const f77_int* incx, \ + ftype_r* rval \ + ) \ +{ \ + PASTEF773S(chr,chx,blasname,sub)( n, x, incx, rval );\ } #ifdef BLIS_ENABLE_CBLAS diff --git a/frame/compat/cblas/f77_sub/f77_nrm2_sub.h b/frame/compat/cblas/f77_sub/f77_nrm2_sub.h index c51a94292b..dbe2809741 100644 --- a/frame/compat/cblas/f77_sub/f77_nrm2_sub.h +++ b/frame/compat/cblas/f77_sub/f77_nrm2_sub.h @@ -5,7 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. 
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -40,6 +41,13 @@ #define GENTPROTR2( ftype_x, ftype_r, chx, chr, blasname ) \ \ BLIS_EXPORT_BLAS void PASTEF773(chr,chx,blasname,sub) \ + ( \ + const f77_int* n, \ + const ftype_x* x, const f77_int* incx, \ + ftype_r* rval \ + );\ +\ +BLIS_EXPORT_BLAS void PASTEF773S(chr,chx,blasname,sub) \ ( \ const f77_int* n, \ const ftype_x* x, const f77_int* incx, \ diff --git a/frame/compat/cblas/src/cblas_f77.h b/frame/compat/cblas/src/cblas_f77.h index d13d833ab9..864d78895e 100644 --- a/frame/compat/cblas/src/cblas_f77.h +++ b/frame/compat/cblas/src/cblas_f77.h @@ -207,53 +207,52 @@ * Level 1 BLAS */ #define F77_xerbla xerbla_ -#define F77_srotg srotg_ -#define F77_srotmg srotmg_ -#define F77_srot srot_ -#define F77_srotm srotm_ -#define F77_drotg drotg_ -#define F77_drotmg drotmg_ -#define F77_drot drot_ -#define F77_drotm drotm_ +#define F77_srotg srotg_blis_impl +#define F77_srotmg srotmg_blis_impl +#define F77_srot srot_blis_impl +#define F77_srotm srotm_blis_impl +#define F77_drotg drotg_blis_impl +#define F77_drotmg drotmg_blis_impl +#define F77_drot drot_blis_impl +#define F77_drotm drotm_blis_impl #define F77_sswap sswap_ -#define F77_scopy scopy_ -#define F77_saxpy saxpy_ -#define F77_isamax_sub isamaxsub_ +#define F77_scopy scopy_blis_impl +#define F77_saxpy saxpy_blis_impl +#define F77_isamax_sub isamaxsub_blis_impl #define F77_dswap dswap_ -#define F77_dcopy dcopy_ -#define F77_daxpy daxpy_ -#define F77_idamax_sub idamaxsub_ +#define F77_dcopy dcopy_blis_impl +#define F77_daxpy daxpy_blis_impl +#define F77_idamax_sub idamaxsub_blis_impl #define F77_cswap cswap_ -#define F77_ccopy ccopy_ -#define F77_caxpy caxpy_ -#define F77_icamax_sub icamaxsub_ +#define F77_ccopy ccopy_blis_impl +#define F77_caxpy caxpy_blis_impl +#define F77_icamax_sub icamaxsub_blis_impl #define F77_zswap zswap_ -#define F77_zcopy zcopy_ -#define F77_zaxpy zaxpy_ -#define F77_zaxpby zaxpby_ -#define F77_izamax_sub izamaxsub_ -#define F77_sdot_sub sdotsub_ -#define F77_ddot_sub ddotsub_ -#define F77_dsdot_sub dsdotsub_ -#define F77_sscal sscal_ -#define F77_dscal dscal_ -#define F77_cscal cscal_ -#define F77_zscal zscal_ -#define F77_csscal csscal_ -#define F77_zdscal zdscal_ -#define F77_cdotu_sub cdotusub_ -#define F77_cdotc_sub cdotcsub_ -#define F77_zdotu_sub zdotusub_ -#define F77_zdotc_sub zdotcsub_ -#define F77_snrm2_sub snrm2sub_ -#define F77_sasum_sub sasumsub_ -#define F77_dnrm2_sub dnrm2sub_ -#define F77_dasum_sub dasumsub_ -#define F77_scnrm2_sub scnrm2sub_ -#define F77_scasum_sub scasumsub_ -#define F77_dznrm2_sub dznrm2sub_ -#define F77_dzasum_sub dzasumsub_ -#define F77_sdsdot_sub sdsdotsub_ +#define F77_zcopy zcopy_blis_impl +#define F77_zaxpy zaxpy_blis_impl +#define F77_izamax_sub izamaxsub_blis_impl +#define F77_sdot_sub sdotsub_blis_impl +#define F77_ddot_sub ddotsub_blis_impl +#define F77_dsdot_sub dsdotsub_blis_impl +#define F77_sscal sscal_blis_impl +#define F77_dscal dscal_blis_impl +#define F77_cscal cscal_blis_impl +#define F77_zscal zscal_blis_impl +#define F77_csscal csscal_blis_impl +#define F77_zdscal zdscal_blis_impl +#define F77_cdotu_sub cdotusub_blis_impl +#define F77_cdotc_sub cdotcsub_blis_impl +#define F77_zdotu_sub zdotusub_blis_impl +#define F77_zdotc_sub zdotcsub_blis_impl +#define F77_snrm2_sub snrm2sub_blis_impl +#define F77_sasum_sub sasumsub_blis_impl +#define F77_dnrm2_sub dnrm2sub_blis_impl +#define F77_dasum_sub 
dasumsub_blis_impl +#define F77_scnrm2_sub scnrm2sub_blis_impl +#define F77_scasum_sub scasumsub_blis_impl +#define F77_dznrm2_sub dznrm2sub_blis_impl +#define F77_dzasum_sub dzasumsub_blis_impl +#define F77_sdsdot_sub sdsdotsub_blis_impl /* * Level 2 BLAS */ @@ -371,17 +370,17 @@ * -- BLAS Extension APIs -- */ -#define F77_saxpby saxpby_ -#define F77_daxpby daxpby_ -#define F77_caxpby caxpby_ -#define F77_zaxpby zaxpby_ +#define F77_saxpby saxpby_blis_impl +#define F77_daxpby daxpby_blis_impl +#define F77_caxpby caxpby_blis_impl +#define F77_zaxpby zaxpby_blis_impl #define F77_cgemm3m cgemm3m_blis_impl #define F77_zgemm3m zgemm3m_blis_impl -#define F77_isamin_sub isaminsub_ -#define F77_idamin_sub idaminsub_ -#define F77_icamin_sub icaminsub_ -#define F77_izamin_sub izaminsub_ +#define F77_isamin_sub isaminsub_blis_impl +#define F77_idamin_sub idaminsub_blis_impl +#define F77_icamin_sub icaminsub_blis_impl +#define F77_izamin_sub izaminsub_blis_impl // -- Batch APIs -- #define F77_sgemm_batch sgemm_batch_ @@ -390,4 +389,4 @@ #define F77_zgemm_batch zgemm_batch_ #endif -#endif /* CBLAS_F77_H */ \ No newline at end of file +#endif /* CBLAS_F77_H */ diff --git a/frame/compat/f2c/bla_rot.c b/frame/compat/f2c/bla_rot.c index c79769bc05..f66aad12c0 100644 --- a/frame/compat/f2c/bla_rot.c +++ b/frame/compat/f2c/bla_rot.c @@ -5,7 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -41,7 +42,8 @@ -lf2c -lm (in that order) */ -/* Subroutine */ int PASTEF77(s,rot)(const bla_integer *n, bla_real *sx, const bla_integer *incx, bla_real *sy, const bla_integer *incy, const bla_real *c__, const bla_real *s) +/* Subroutine */ +int PASTEF77S(s,rot)(const bla_integer *n, bla_real *sx, const bla_integer *incx, bla_real *sy, const bla_integer *incy, const bla_real *c__, const bla_real *s) { /* System generated locals */ bla_integer i__1; @@ -109,7 +111,8 @@ -lf2c -lm (in that order) */ -/* Subroutine */ int PASTEF77(d,rot)(const bla_integer *n, bla_double *dx, const bla_integer *incx, bla_double *dy, const bla_integer *incy, const bla_double *c__, const bla_double *s) +/* Subroutine */ +int PASTEF77S(d,rot)(const bla_integer *n, bla_double *dx, const bla_integer *incx, bla_double *dy, const bla_integer *incy, const bla_double *c__, const bla_double *s) { /* System generated locals */ bla_integer i__1; @@ -172,6 +175,16 @@ return 0; } /* drot_ */ +int PASTEF77(s,rot)(const bla_integer *n, bla_real *sx, const bla_integer *incx, bla_real *sy, const bla_integer *incy, const bla_real *c__, const bla_real *s) +{ + return PASTEF77S(s,rot)( n, sx, incx, sy, incy, c__, s ); +} + +int PASTEF77(d,rot)(const bla_integer *n, bla_double *dx, const bla_integer *incx, bla_double *dy, const bla_integer *incy, const bla_double *c__, const bla_double *s) +{ + return PASTEF77S(d,rot)( n, dx, incx, dy, incy, c__, s ); +} + /* csrot.f -- translated by f2c (version 19991025). You must link the resulting object file with the libraries: -lf2c -lm (in that order) diff --git a/frame/compat/f2c/bla_rot.h b/frame/compat/f2c/bla_rot.h index 6093555600..c8bd42c254 100644 --- a/frame/compat/f2c/bla_rot.h +++ b/frame/compat/f2c/bla_rot.h @@ -5,7 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. 
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -38,5 +39,7 @@ BLIS_EXPORT_BLAS int PASTEF77(s,rot)(const bla_integer *n, bla_real *sx, const b BLIS_EXPORT_BLAS int PASTEF77(d,rot)(const bla_integer *n, bla_double *dx, const bla_integer *incx, bla_double *dy, const bla_integer *incy, const bla_double *c__, const bla_double *s); BLIS_EXPORT_BLAS int PASTEF77(cs,rot)(const bla_integer *n, bla_scomplex *cx, const bla_integer *incx, bla_scomplex *cy, const bla_integer *incy, const bla_real *c__, const bla_real *s); BLIS_EXPORT_BLAS int PASTEF77(zd,rot)(const bla_integer *n, bla_dcomplex *zx, const bla_integer *incx, bla_dcomplex *zy, const bla_integer *incy, const bla_double *c__, const bla_double *s); +BLIS_EXPORT_BLAS int PASTEF77S(s,rot)(const bla_integer *n, bla_real *sx, const bla_integer *incx, bla_real *sy, const bla_integer *incy, const bla_real *c__, const bla_real *s); +BLIS_EXPORT_BLAS int PASTEF77S(d,rot)(const bla_integer *n, bla_double *dx, const bla_integer *incx, bla_double *dy, const bla_integer *incy, const bla_double *c__, const bla_double *s); #endif diff --git a/frame/compat/f2c/bla_rotg.c b/frame/compat/f2c/bla_rotg.c index 1572689f57..613574c955 100644 --- a/frame/compat/f2c/bla_rotg.c +++ b/frame/compat/f2c/bla_rotg.c @@ -5,7 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -45,7 +46,8 @@ static bla_real sc_b4 = 1.f; -/* Subroutine */ int PASTEF77(s,rotg)(bla_real *sa, bla_real *sb, bla_real *c__, bla_real *s) +/* Subroutine */ +int PASTEF77S(s,rotg)(bla_real *sa, bla_real *sb, bla_real *c__, bla_real *s) { /* System generated locals */ bla_real r__1, r__2; @@ -105,7 +107,8 @@ static bla_real sc_b4 = 1.f; static bla_double dc_b4 = 1.; -/* Subroutine */ int PASTEF77(d,rotg)(bla_double *da, bla_double *db, bla_double *c__, bla_double *s) +/* Subroutine */ +int PASTEF77S(d,rotg)(bla_double *da, bla_double *db, bla_double *c__, bla_double *s) { /* System generated locals */ bla_double d__1, d__2; @@ -156,6 +159,16 @@ static bla_double dc_b4 = 1.; return 0; } /* drotg_ */ +int PASTEF77(s,rotg)(bla_real *sa, bla_real *sb, bla_real *c__, bla_real *s) +{ + return PASTEF77S(s,rotg)( sa, sb, c__, s ); +} + +int PASTEF77(d,rotg)(bla_double *da, bla_double *db, bla_double *c__, bla_double *s) +{ + return PASTEF77S(d,rotg)( da, db, c__, s ); +} + /* crotg.f -- translated by f2c (version 19991025). You must link the resulting object file with the libraries: -lf2c -lm (in that order) diff --git a/frame/compat/f2c/bla_rotg.h b/frame/compat/f2c/bla_rotg.h index b968ebbea2..067f67c22a 100644 --- a/frame/compat/f2c/bla_rotg.h +++ b/frame/compat/f2c/bla_rotg.h @@ -5,7 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. 
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -38,5 +39,7 @@ BLIS_EXPORT_BLAS int PASTEF77(s,rotg)(bla_real *sa, bla_real *sb, bla_real *c__, BLIS_EXPORT_BLAS int PASTEF77(d,rotg)(bla_double *da, bla_double *db, bla_double *c__, bla_double *s); BLIS_EXPORT_BLAS int PASTEF77(c,rotg)(bla_scomplex *ca, bla_scomplex *cb, bla_real *c__, bla_scomplex *s); BLIS_EXPORT_BLAS int PASTEF77(z,rotg)(bla_dcomplex *ca, bla_dcomplex *cb, bla_double *c__, bla_dcomplex *s); +BLIS_EXPORT_BLAS int PASTEF77S(s,rotg)(bla_real *sa, bla_real *sb, bla_real *c__, bla_real *s); +BLIS_EXPORT_BLAS int PASTEF77S(d,rotg)(bla_double *da, bla_double *db, bla_double *c__, bla_double *s); #endif diff --git a/frame/compat/f2c/bla_rotm.c b/frame/compat/f2c/bla_rotm.c index 003dea7155..b60442e387 100644 --- a/frame/compat/f2c/bla_rotm.c +++ b/frame/compat/f2c/bla_rotm.c @@ -5,7 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -41,7 +42,8 @@ -lf2c -lm (in that order) */ -/* Subroutine */ int PASTEF77(s,rotm)(const bla_integer *n, bla_real *sx, const bla_integer *incx, bla_real *sy, const bla_integer *incy, const bla_real *sparam) +/* Subroutine */ +int PASTEF77S(s,rotm)(const bla_integer *n, bla_real *sx, const bla_integer *incx, bla_real *sy, const bla_integer *incy, const bla_real *sparam) { /* Initialized data */ @@ -207,7 +209,8 @@ -lf2c -lm (in that order) */ -/* Subroutine */ int PASTEF77(d,rotm)(const bla_integer *n, bla_double *dx, const bla_integer *incx, bla_double *dy, const bla_integer *incy, const bla_double *dparam) +/* Subroutine */ +int PASTEF77S(d,rotm)(const bla_integer *n, bla_double *dx, const bla_integer *incx, bla_double *dy, const bla_integer *incy, const bla_double *dparam) { /* Initialized data */ @@ -368,5 +371,15 @@ return 0; } /* drotm_ */ +int PASTEF77(s,rotm)(const bla_integer *n, bla_real *sx, const bla_integer *incx, bla_real *sy, const bla_integer *incy, const bla_real *sparam) +{ + return PASTEF77S(s,rotm)( n, sx, incx, sy, incy, sparam); +} + +int PASTEF77(d,rotm)(const bla_integer *n, bla_double *dx, const bla_integer *incx, bla_double *dy, const bla_integer *incy, const bla_double *dparam) +{ + return PASTEF77S(d,rotm)( n, dx, incx, dy, incy, dparam); +} + #endif diff --git a/frame/compat/f2c/bla_rotm.h b/frame/compat/f2c/bla_rotm.h index 21906358be..8f610ba5a9 100644 --- a/frame/compat/f2c/bla_rotm.h +++ b/frame/compat/f2c/bla_rotm.h @@ -5,7 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. 
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -36,5 +37,7 @@ BLIS_EXPORT_BLAS int PASTEF77(s,rotm)(const bla_integer *n, bla_real *sx, const bla_integer *incx, bla_real *sy, const bla_integer *incy, const bla_real *sparam); BLIS_EXPORT_BLAS int PASTEF77(d,rotm)(const bla_integer *n, bla_double *dx, const bla_integer *incx, bla_double *dy, const bla_integer *incy, const bla_double *dparam); +BLIS_EXPORT_BLAS int PASTEF77S(s,rotm)(const bla_integer *n, bla_real *sx, const bla_integer *incx, bla_real *sy, const bla_integer *incy, const bla_real *sparam); +BLIS_EXPORT_BLAS int PASTEF77S(d,rotm)(const bla_integer *n, bla_double *dx, const bla_integer *incx, bla_double *dy, const bla_integer *incy, const bla_double *dparam); #endif diff --git a/frame/compat/f2c/bla_rotmg.c b/frame/compat/f2c/bla_rotmg.c index 11ccc6f333..a285e69a39 100644 --- a/frame/compat/f2c/bla_rotmg.c +++ b/frame/compat/f2c/bla_rotmg.c @@ -5,7 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -41,7 +42,8 @@ -lf2c -lm (in that order) */ -/* Subroutine */ int PASTEF77(s,rotmg)(bla_real *sd1, bla_real *sd2, bla_real *sx1, const bla_real *sy1, bla_real *sparam) +/* Subroutine */ +int PASTEF77S(s,rotmg)(bla_real *sd1, bla_real *sd2, bla_real *sx1, const bla_real *sy1, bla_real *sparam) { /* Initialized data */ @@ -281,7 +283,8 @@ -lf2c -lm (in that order) */ -/* Subroutine */ int PASTEF77(d,rotmg)(bla_double *dd1, bla_double *dd2, bla_double *dx1, const bla_double *dy1, bla_double *dparam) +/* Subroutine */ +int PASTEF77S(d,rotmg)(bla_double *dd1, bla_double *dd2, bla_double *dx1, const bla_double *dy1, bla_double *dparam) { /* Initialized data */ @@ -516,5 +519,15 @@ return 0; } /* drotmg_ */ +int PASTEF77(s,rotmg)(bla_real *sd1, bla_real *sd2, bla_real *sx1, const bla_real *sy1, bla_real *sparam) +{ + return PASTEF77S(s,rotmg)( sd1, sd2, sx1, sy1, sparam ); +} + +int PASTEF77(d,rotmg)(bla_double *dd1, bla_double *dd2, bla_double *dx1, const bla_double *dy1, bla_double *dparam) +{ + return PASTEF77S(d,rotmg)( dd1, dd2, dx1, dy1, dparam ); +} + #endif diff --git a/frame/compat/f2c/bla_rotmg.h b/frame/compat/f2c/bla_rotmg.h index 63e9710da1..fe3a10beb4 100644 --- a/frame/compat/f2c/bla_rotmg.h +++ b/frame/compat/f2c/bla_rotmg.h @@ -5,7 +5,8 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. 
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -36,5 +37,7 @@ BLIS_EXPORT_BLAS int PASTEF77(s,rotmg)(bla_real *sd1, bla_real *sd2, bla_real *sx1, const bla_real *sy1, bla_real *sparam); BLIS_EXPORT_BLAS int PASTEF77(d,rotmg)(bla_double *dd1, bla_double *dd2, bla_double *dx1, const bla_double *dy1, bla_double *dparam); +BLIS_EXPORT_BLAS int PASTEF77S(s,rotmg)(bla_real *sd1, bla_real *sd2, bla_real *sx1, const bla_real *sy1, bla_real *sparam); +BLIS_EXPORT_BLAS int PASTEF77S(d,rotmg)(bla_double *dd1, bla_double *dd2, bla_double *dx1, const bla_double *dy1, bla_double *dparam); #endif diff --git a/frame/util/bli_util_api_wrap.c b/frame/util/bli_util_api_wrap.c index 098a7c33e8..26810531d3 100644 --- a/frame/util/bli_util_api_wrap.c +++ b/frame/util/bli_util_api_wrap.c @@ -43,32 +43,32 @@ #ifndef BLIS_ENABLE_UPPERCASE_API void CAXPY(const f77_int *n,const scomplex *ca,const scomplex *cx,const f77_int *incx,scomplex *cy,const f77_int *incy) { - caxpy_( n, ca, cx, incx, cy, incy); + caxpy_blis_impl( n, ca, cx, incx, cy, incy); } void caxpy(const f77_int *n,const scomplex *ca,const scomplex *cx,const f77_int *incx,scomplex *cy,const f77_int *incy) { - caxpy_( n, ca, cx, incx, cy, incy); + caxpy_blis_impl( n, ca, cx, incx, cy, incy); } void CAXPY_(const f77_int *n,const scomplex *ca,const scomplex *cx,const f77_int *incx,scomplex *cy,const f77_int *incy) { - caxpy_( n, ca, cx, incx, cy, incy); + caxpy_blis_impl( n, ca, cx, incx, cy, incy); } void CCOPY(const f77_int *n,const scomplex *cx,const f77_int *incx,scomplex *cy,const f77_int *incy) { - ccopy_( n, cx, incx, cy, incy); + ccopy_blis_impl( n, cx, incx, cy, incy); } void ccopy(const f77_int *n,const scomplex *cx,const f77_int *incx,scomplex *cy,const f77_int *incy) { - ccopy_( n, cx, incx, cy, incy); + ccopy_blis_impl( n, cx, incx, cy, incy); } void CCOPY_(const f77_int *n,const scomplex *cx,const f77_int *incx,scomplex *cy,const f77_int *incy) { - ccopy_( n, cx, incx, cy, incy); + ccopy_blis_impl( n, cx, incx, cy, incy); } #ifdef BLIS_DISABLE_COMPLEX_RETURN_INTEL @@ -435,17 +435,17 @@ void CROTG_(scomplex *ca, bla_scomplex *cb, bla_real *c,scomplex *s) void CSCAL(const f77_int *n,const scomplex *ca,scomplex *cx,const f77_int *incx) { - cscal_( n, ca, cx, incx); + cscal_blis_impl( n, ca, cx, incx); } void cscal(const f77_int *n,const scomplex *ca,scomplex *cx,const f77_int *incx) { - cscal_( n, ca, cx, incx); + cscal_blis_impl( n, ca, cx, incx); } void CSCAL_(const f77_int *n,const scomplex *ca,scomplex *cx,const f77_int *incx) { - cscal_( n, ca, cx, incx); + cscal_blis_impl( n, ca, cx, incx); } void CSROT(const f77_int *n,scomplex *cx,const f77_int *incx,scomplex *cy,const f77_int *incy,const float *c,const float *s) @@ -465,17 +465,17 @@ void CSROT_(const f77_int *n,scomplex *cx,const f77_int *incx,scomplex *cy,con void CSSCAL(const f77_int *n,const float *sa,scomplex *cx,const f77_int *incx) { - csscal_( n, sa, cx, incx); + csscal_blis_impl( n, sa, cx, incx); } void csscal(const f77_int *n,const float *sa,scomplex *cx,const f77_int *incx) { - csscal_( n, sa, cx, incx); + csscal_blis_impl( n, sa, cx, incx); } void CSSCAL_(const f77_int *n,const float *sa,scomplex *cx,const f77_int *incx) { - csscal_( n, sa, cx, incx); + csscal_blis_impl( n, sa, cx, incx); } void CSWAP(const f77_int *n,scomplex *cx,const f77_int *incx,scomplex *cy,const f77_int *incy) @@ -675,17 +675,17 @@ double DASUM_(const f77_int *n,const double *dx,const f77_int *incx) 
void DAXPY(const f77_int *n,const double *da,const double *dx,const f77_int *incx,double *dy,const f77_int *incy) { - daxpy_( n, da, dx, incx, dy, incy); + daxpy_blis_impl( n, da, dx, incx, dy, incy); } void daxpy(const f77_int *n,const double *da,const double *dx,const f77_int *incx,double *dy,const f77_int *incy) { - daxpy_( n, da, dx, incx, dy, incy); + daxpy_blis_impl( n, da, dx, incx, dy, incy); } void DAXPY_(const f77_int *n,const double *da,const double *dx,const f77_int *incx,double *dy,const f77_int *incy) { - daxpy_( n, da, dx, incx, dy, incy); + daxpy_blis_impl( n, da, dx, incx, dy, incy); } double DCABS1(bla_dcomplex *z) @@ -705,17 +705,17 @@ double DCABS1_(bla_dcomplex *z) void DCOPY(const f77_int *n,const double *dx,const f77_int *incx,double *dy,const f77_int *incy) { - dcopy_( n, dx, incx, dy, incy); + dcopy_blis_impl( n, dx, incx, dy, incy); } void dcopy(const f77_int *n,const double *dx,const f77_int *incx,double *dy,const f77_int *incy) { - dcopy_( n, dx, incx, dy, incy); + dcopy_blis_impl( n, dx, incx, dy, incy); } void DCOPY_(const f77_int *n,const double *dx,const f77_int *incx,double *dy,const f77_int *incy) { - dcopy_( n, dx, incx, dy, incy); + dcopy_blis_impl( n, dx, incx, dy, incy); } double DDOT(const f77_int *n,const double *dx,const f77_int *incx,const double *dy,const f77_int *incy) @@ -810,62 +810,62 @@ double DNRM2_(const f77_int *n,const double *x,const f77_int *incx) void DROT(const f77_int *n,double *dx,const f77_int *incx,double *dy,const f77_int *incy,const double *c,const double *s) { - drot_( n, dx, incx, dy, incy, c, s); + drot_blis_impl( n, dx, incx, dy, incy, c, s); } void drot(const f77_int *n,double *dx,const f77_int *incx,double *dy,const f77_int *incy,const double *c,const double *s) { - drot_( n, dx, incx, dy, incy, c, s); + drot_blis_impl( n, dx, incx, dy, incy, c, s); } void DROT_(const f77_int *n,double *dx,const f77_int *incx,double *dy,const f77_int *incy,const double *c,const double *s) { - drot_( n, dx, incx, dy, incy, c, s); + drot_blis_impl( n, dx, incx, dy, incy, c, s); } void DROTG(double *da,double *db,double *c,double *s) { - drotg_( da, db, c, s); + drotg_blis_impl( da, db, c, s); } void drotg(double *da,double *db,double *c,double *s) { - drotg_( da, db, c, s); + drotg_blis_impl( da, db, c, s); } void DROTG_(double *da,double *db,double *c,double *s) { - drotg_( da, db, c, s); + drotg_blis_impl( da, db, c, s); } void DROTM(const f77_int *n,double *dx,const f77_int *incx,double *dy,const f77_int *incy,const double *dparam) { - drotm_( n, dx, incx, dy, incy, dparam); + drotm_blis_impl( n, dx, incx, dy, incy, dparam); } void drotm(const f77_int *n,double *dx,const f77_int *incx,double *dy,const f77_int *incy,const double *dparam) { - drotm_( n, dx, incx, dy, incy, dparam); + drotm_blis_impl( n, dx, incx, dy, incy, dparam); } void DROTM_(const f77_int *n,double *dx,const f77_int *incx,double *dy,const f77_int *incy,const double *dparam) { - drotm_( n, dx, incx, dy, incy, dparam); + drotm_blis_impl( n, dx, incx, dy, incy, dparam); } void DROTMG(double *dd1,double *dd2,double *dx1,const double *dy1,double *dparam) { - drotmg_( dd1, dd2, dx1, dy1, dparam); + drotmg_blis_impl( dd1, dd2, dx1, dy1, dparam); } void drotmg(double *dd1,double *dd2,double *dx1,const double *dy1,double *dparam) { - drotmg_( dd1, dd2, dx1, dy1, dparam); + drotmg_blis_impl( dd1, dd2, dx1, dy1, dparam); } void DROTMG_(double *dd1,double *dd2,double *dx1,const double *dy1,double *dparam) { - drotmg_( dd1, dd2, dx1, dy1, dparam); + drotmg_blis_impl( dd1, dd2, dx1, 
dy1, dparam);
}
void DSBMV(const char *uplo,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy)
@@ -885,17 +885,17 @@ void DSBMV_(const char *uplo,const f77_int *n,const f77_int *k,const double *a
void DSCAL(const f77_int *n,const double *da,double *dx,const f77_int *incx)
{
- dscal_( n, da, dx, incx);
+ dscal_blis_impl( n, da, dx, incx);
}
void dscal(const f77_int *n,const double *da,double *dx,const f77_int *incx)
{
- dscal_( n, da, dx, incx);
+ dscal_blis_impl( n, da, dx, incx);
}
void DSCAL_(const f77_int *n,const double *da,double *dx,const f77_int *incx)
{
- dscal_( n, da, dx, incx);
+ dscal_blis_impl( n, da, dx, incx);
}
double DSDOT(const f77_int *n,const float *sx,const f77_int *incx,const float *sy,const f77_int *incy)
@@ -1305,17 +1305,17 @@ float SASUM_(const f77_int *n,const float *sx, const f77_int *incx)
void SAXPY(const f77_int *n,const float *sa,const float *sx,const f77_int *incx,float *sy,const f77_int *incy)
{
- saxpy_( n, sa, sx, incx, sy, incy);
+ saxpy_blis_impl( n, sa, sx, incx, sy, incy);
}
void saxpy(const f77_int *n,const float *sa,const float *sx,const f77_int *incx,float *sy,const f77_int *incy)
{
- saxpy_( n, sa, sx, incx, sy, incy);
+ saxpy_blis_impl( n, sa, sx, incx, sy, incy);
}
void SAXPY_(const f77_int *n,const float *sa,const float *sx,const f77_int *incx,float *sy,const f77_int *incy)
{
- saxpy_( n, sa, sx, incx, sy, incy);
+ saxpy_blis_impl( n, sa, sx, incx, sy, incy);
}
@@ -1354,17 +1354,17 @@ float SCNRM2_(const f77_int *n,const scomplex *x, const f77_int *incx)
void SCOPY(const f77_int *n,const float *sx,const f77_int *incx,float *sy,const f77_int *incy)
{
- scopy_( n, sx, incx, sy, incy);
+ scopy_blis_impl( n, sx, incx, sy, incy);
}
void scopy(const f77_int *n,const float *sx,const f77_int *incx,float *sy,const f77_int *incy)
{
- scopy_( n, sx, incx, sy, incy);
+ scopy_blis_impl( n, sx, incx, sy, incy);
}
void SCOPY_(const f77_int *n,const float *sx,const f77_int *incx,float *sy,const f77_int *incy)
{
- scopy_( n, sx, incx, sy, incy);
+ scopy_blis_impl( n, sx, incx, sy, incy);
}
@@ -1479,62 +1479,62 @@ float SNRM2_(const f77_int *n,const float *x, const f77_int *incx)
void SROT(const f77_int *n,float *sx,const f77_int *incx,float *sy,const f77_int *incy,const float *c,const float *s)
{
- srot_( n, sx, incx, sy, incy, c, s);
+ srot_blis_impl( n, sx, incx, sy, incy, c, s);
}
void srot(const f77_int *n,float *sx,const f77_int *incx,float *sy,const f77_int *incy,const float *c,const float *s)
{
- srot_( n, sx, incx, sy, incy, c, s);
+ srot_blis_impl( n, sx, incx, sy, incy, c, s);
}
-void SROT_(const f77_int *n,float *sx,const f77_int *incx,float *sy,const f77_int *incy,const float *c,const float *s)
+void SROT_(const f77_int *n,float *sx,const f77_int *incx,float *sy,const f77_int *incy,const float *c,const float *s)
{
- srot_( n, sx, incx, sy, incy, c, s);
+ srot_blis_impl( n, sx, incx, sy, incy, c, s);
}
void SROTG(float *sa,float *sb,float *c,float *s)
{
- srotg_( sa, sb, c, s);
+ srotg_blis_impl( sa, sb, c, s);
}
void srotg(float *sa,float *sb,float *c,float *s)
{
- srotg_( sa, sb, c, s);
+ srotg_blis_impl( sa, sb, c, s);
}
void SROTG_(float *sa,float *sb,float *c,float *s)
{
- srotg_( sa, sb, c, s);
+ srotg_blis_impl( sa, sb, c, s);
}
void SROTM(const f77_int *n,float *sx,const f77_int *incx,float *sy,const f77_int *incy,const float *sparam)
{
- srotm_( n, sx, incx, sy, incy, sparam);
+ srotm_blis_impl( n, sx, incx,
sy, incy, sparam); } void srotm(const f77_int *n,float *sx,const f77_int *incx,float *sy,const f77_int *incy,const float *sparam) { - srotm_( n, sx, incx, sy, incy, sparam); + srotm_blis_impl( n, sx, incx, sy, incy, sparam); } void SROTM_(const f77_int *n,float *sx,const f77_int *incx,float *sy,const f77_int *incy,const float *sparam) { - srotm_( n, sx, incx, sy, incy, sparam); + srotm_blis_impl( n, sx, incx, sy, incy, sparam); } void SROTMG(float *sd1,float *sd2,float *sx1,const float *sy1,float *sparam) { - srotmg_( sd1, sd2, sx1, sy1, sparam); + srotmg_blis_impl( sd1, sd2, sx1, sy1, sparam); } void srotmg(float *sd1,float *sd2,float *sx1,const float *sy1,float *sparam) { - srotmg_( sd1, sd2, sx1, sy1, sparam); + srotmg_blis_impl( sd1, sd2, sx1, sy1, sparam); } void SROTMG_(float *sd1,float *sd2,float *sx1,const float *sy1,float *sparam) { - srotmg_( sd1, sd2, sx1, sy1, sparam); + srotmg_blis_impl( sd1, sd2, sx1, sy1, sparam); } void SSBMV(const char *uplo,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) @@ -1554,17 +1554,17 @@ void SSBMV_(const char *uplo,const f77_int *n,const f77_int *k,const float *a void SSCAL(const f77_int *n,const float *sa,float *sx,const f77_int *incx) { - sscal_( n, sa, sx, incx); + sscal_blis_impl( n, sa, sx, incx); } void sscal(const f77_int *n,const float *sa,float *sx,const f77_int *incx) { - sscal_( n, sa, sx, incx); + sscal_blis_impl( n, sa, sx, incx); } void SSCAL_(const f77_int *n,const float *sa,float *sx,const f77_int *incx) { - sscal_( n, sa, sx, incx); + sscal_blis_impl( n, sa, sx, incx); } void SSPMV(const char *uplo,const f77_int *n,const float *alpha,const float *ap,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) @@ -1854,32 +1854,32 @@ int xerbla(const char *srname,const f77_int *info, ftnlen n) void ZAXPY(const f77_int *n,const dcomplex *za,const dcomplex *zx,const f77_int *incx,dcomplex *zy,const f77_int *incy) { - zaxpy_( n, za, zx, incx, zy, incy); + zaxpy_blis_impl( n, za, zx, incx, zy, incy); } void zaxpy(const f77_int *n,const dcomplex *za,const dcomplex *zx,const f77_int *incx,dcomplex *zy,const f77_int *incy) { - zaxpy_( n, za, zx, incx, zy, incy); + zaxpy_blis_impl( n, za, zx, incx, zy, incy); } void ZAXPY_(const f77_int *n,const dcomplex *za,const dcomplex *zx,const f77_int *incx,dcomplex *zy,const f77_int *incy) { - zaxpy_( n, za, zx, incx, zy, incy); + zaxpy_blis_impl( n, za, zx, incx, zy, incy); } void ZCOPY(const f77_int *n,const dcomplex *zx,const f77_int *incx,dcomplex *zy,const f77_int *incy) { - zcopy_( n, zx, incx, zy, incy); + zcopy_blis_impl( n, zx, incx, zy, incy); } void zcopy(const f77_int *n,const dcomplex *zx,const f77_int *incx,dcomplex *zy,const f77_int *incy) { - zcopy_( n, zx, incx, zy, incy); + zcopy_blis_impl( n, zx, incx, zy, incy); } void ZCOPY_(const f77_int *n,const dcomplex *zx,const f77_int *incx,dcomplex *zy,const f77_int *incy) { - zcopy_( n, zx, incx, zy, incy); + zcopy_blis_impl( n, zx, incx, zy, incy); } void ZDROT(const f77_int *n,dcomplex *cx,const f77_int *incx,dcomplex *cy,const f77_int *incy,const double *c,const double *s) @@ -1899,17 +1899,17 @@ void ZDROT_(const f77_int *n,dcomplex *cx,const f77_int *incx,dcomplex *cy,const void ZDSCAL(const f77_int *n,const double *da,dcomplex *zx,const f77_int *incx) { - zdscal_( n, da, zx, incx); + zdscal_blis_impl( n, da, zx, incx); } void zdscal(const f77_int *n,const double *da,dcomplex *zx,const 
f77_int *incx) { - zdscal_( n, da, zx, incx); + zdscal_blis_impl( n, da, zx, incx); } void ZDSCAL_(const f77_int *n,const double *da,dcomplex *zx,const f77_int *incx) { - zdscal_( n, da, zx, incx); + zdscal_blis_impl( n, da, zx, incx); } void ZGBMV(const char *trans,const f77_int *m,const f77_int *n,const f77_int *kl,const f77_int *ku,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) @@ -2154,17 +2154,17 @@ void ZROTG_(dcomplex *ca,bla_dcomplex *cb,bla_double *c,dcomplex *s) void ZSCAL(const f77_int *n,const dcomplex *za,dcomplex *zx,const f77_int *incx) { - zscal_( n, za, zx, incx); + zscal_blis_impl( n, za, zx, incx); } void zscal(const f77_int *n,const dcomplex *za,dcomplex *zx,const f77_int *incx) { - zscal_( n, za, zx, incx); + zscal_blis_impl( n, za, zx, incx); } void ZSCAL_(const f77_int *n,const dcomplex *za,dcomplex *zx,const f77_int *incx) { - zscal_( n, za, zx, incx); + zscal_blis_impl( n, za, zx, incx); } void ZSWAP(const f77_int *n,dcomplex *zx,const f77_int *incx,dcomplex *zy,const f77_int *incy) @@ -2350,32 +2350,32 @@ void ZTRSV_(const char *uplo,const char *trans,const char *diag,const f77_ void CDOTCSUB( const f77_int* n, const scomplex* x,const f77_int* incx, const scomplex* y, const f77_int* incy, scomplex* rval) { - cdotcsub_( n, x, incx, y, incy, rval); + cdotcsub_blis_impl( n, x, incx, y, incy, rval); } void cdotcsub( const f77_int* n, const scomplex* x,const f77_int* incx, const scomplex* y, const f77_int* incy, scomplex* rval) { - cdotcsub_( n, x, incx, y, incy, rval); + cdotcsub_blis_impl( n, x, incx, y, incy, rval); } void CDOTCSUB_( const f77_int* n, const scomplex* x,const f77_int* incx, const scomplex* y, const f77_int* incy, scomplex* rval) { - cdotcsub_( n, x, incx, y, incy, rval); + cdotcsub_blis_impl( n, x, incx, y, incy, rval); } void CDOTUSUB( const f77_int* n, const scomplex* x,const f77_int* incxy, const scomplex* y, const f77_int* incy, scomplex* rval) { - cdotusub_( n, x, incxy, y, incy, rval); + cdotusub_blis_impl( n, x, incxy, y, incy, rval); } void cdotusub( const f77_int* n, const scomplex* x,const f77_int* incxy, const scomplex* y, const f77_int* incy, scomplex* rval) { - cdotusub_( n, x, incxy, y, incy, rval); + cdotusub_blis_impl( n, x, incxy, y, incy, rval); } void CDOTUSUB_( const f77_int* n, const scomplex* x,const f77_int* incxy, const scomplex* y, const f77_int* incy, scomplex* rval) { - cdotusub_( n, x, incxy, y, incy, rval); + cdotusub_blis_impl( n, x, incxy, y, incy, rval); } void CGEMM3M( const f77_char* transa, const f77_char* transb, const f77_int* m, const f77_int* n, const f77_int* k, const scomplex* alpha, const scomplex* a, const f77_int* lda, const scomplex* b, const f77_int* ldb, const scomplex* beta, scomplex* c, const f77_int* ldc) @@ -2485,47 +2485,47 @@ void COMATCOPY_(f77_char* trans, f77_int* rows, f77_int* cols, const scomplex* a void DASUMSUB(const f77_int* n, const double* x, const f77_int* incx, double* rval) { - dasumsub_( n, x, incx, rval); + dasumsub_blis_impl( n, x, incx, rval); } void dasumsub(const f77_int* n, const double* x, const f77_int* incx, double* rval) { - dasumsub_( n, x, incx, rval); + dasumsub_blis_impl( n, x, incx, rval); } void DASUMSUB_(const f77_int* n, const double* x, const f77_int* incx, double* rval) { - dasumsub_( n, x, incx, rval); + dasumsub_blis_impl( n, x, incx, rval); } void DAXPBY(const f77_int* n, const double* alpha, const double *x, const f77_int* incx, const double* beta, double 
*y, const f77_int* incy) { - daxpby_( n, alpha, x, incx, beta, y, incy); + daxpby_blis_impl( n, alpha, x, incx, beta, y, incy); } void daxpby(const f77_int* n, const double* alpha, const double *x, const f77_int* incx, const double* beta, double *y, const f77_int* incy) { - daxpby_( n, alpha, x, incx, beta, y, incy); + daxpby_blis_impl( n, alpha, x, incx, beta, y, incy); } void DAXPBY_(const f77_int* n, const double* alpha, const double *x, const f77_int* incx, const double* beta, double *y, const f77_int* incy) { - daxpby_( n, alpha, x, incx, beta, y, incy); + daxpby_blis_impl( n, alpha, x, incx, beta, y, incy); } void DDOTSUB(const f77_int* n, const double* x, const f77_int* incx, const double* y, const f77_int* incy, double* rval) { - ddotsub_( n, x, incx, y, incy, rval); + ddotsub_blis_impl( n, x, incx, y, incy, rval); } void ddotsub(const f77_int* n, const double* x, const f77_int* incx, const double* y, const f77_int* incy, double* rval) { - ddotsub_( n, x, incx, y, incy, rval); + ddotsub_blis_impl( n, x, incx, y, incy, rval); } void DDOTSUB_(const f77_int* n, const double* x, const f77_int* incx, const double* y, const f77_int* incy, double* rval) { - ddotsub_( n, x, incx, y, incy, rval); + ddotsub_blis_impl( n, x, incx, y, incy, rval); } void DGEMM_BATCH( const f77_char* transa_array, const f77_char* transb_array,const f77_int *m_array, const f77_int *n_array, const f77_int *k_array,const double* alpha_array, const double** a_array, const f77_int *lda_array, const double** b_array, const f77_int *ldb_array, const double* beta_array, double** c_array, const f77_int *ldc_array, const f77_int* group_count, const f77_int *group_size) @@ -2560,17 +2560,17 @@ void DGEMMT_( const f77_char* uploc, const f77_char* transa, const f77_char* tra void DNRM2SUB(const f77_int* n, const double* x, const f77_int* incx, double *rval) { - dnrm2sub_( n, x, incx, rval); + dnrm2sub_blis_impl( n, x, incx, rval); } void dnrm2sub(const f77_int* n, const double* x, const f77_int* incx, double *rval) { - dnrm2sub_( n, x, incx, rval); + dnrm2sub_blis_impl( n, x, incx, rval); } void DNRM2SUB_(const f77_int* n, const double* x, const f77_int* incx, double *rval) { - dnrm2sub_( n, x, incx, rval); + dnrm2sub_blis_impl( n, x, incx, rval); } void DOMATADD(f77_char* transa,f77_char* transb, f77_int* m, f77_int* n, const double* alpha, const double* A, f77_int* lda, const double* beta, const double* B, f77_int* ldb, double* C, f77_int* ldc) @@ -2620,47 +2620,47 @@ void DOMATCOPY_(f77_char* trans, f77_int* rows, f77_int* cols, const double* alp void DZASUMSUB(const f77_int* n, const dcomplex* x, const f77_int* incx, double* rval) { - dzasumsub_( n, x, incx, rval); + dzasumsub_blis_impl( n, x, incx, rval); } void dzasumsub(const f77_int* n, const dcomplex* x, const f77_int* incx, double* rval) { - dzasumsub_( n, x, incx, rval); + dzasumsub_blis_impl( n, x, incx, rval); } void DZASUMSUB_(const f77_int* n, const dcomplex* x, const f77_int* incx, double* rval) { - dzasumsub_( n, x, incx, rval); + dzasumsub_blis_impl( n, x, incx, rval); } void DZNRM2SUB(const f77_int* n, const dcomplex* x, const f77_int* incx, double* rval) { - dznrm2sub_( n, x, incx, rval); + dznrm2sub_blis_impl( n, x, incx, rval); } void dznrm2sub(const f77_int* n, const dcomplex* x, const f77_int* incx, double* rval) { - dznrm2sub_( n, x, incx, rval); + dznrm2sub_blis_impl( n, x, incx, rval); } void DZNRM2SUB_(const f77_int* n, const dcomplex* x, const f77_int* incx, double* rval) { - dznrm2sub_( n, x, incx, rval); + dznrm2sub_blis_impl( n, x, incx, 
rval); } void ICAMAXSUB(const f77_int* n, const scomplex* x, const f77_int* incx, f77_int* rval) { - icamaxsub_( n, x, incx, rval); + icamaxsub_blis_impl( n, x, incx, rval); } void icamaxsub(const f77_int* n, const scomplex* x, const f77_int* incx, f77_int* rval) { - icamaxsub_( n, x, incx, rval); + icamaxsub_blis_impl( n, x, incx, rval); } void ICAMAXSUB_(const f77_int* n, const scomplex* x, const f77_int* incx, f77_int* rval) { - icamaxsub_( n, x, incx, rval); + icamaxsub_blis_impl( n, x, incx, rval); } f77_int ICAMIN( const f77_int* n, const scomplex* x, const f77_int* incx) @@ -2680,32 +2680,32 @@ f77_int ICAMIN_( const f77_int* n, const scomplex* x, const f77_int* incx) void ICAMINSUB( const f77_int* n, const scomplex* x, const f77_int* incx, f77_int* rval) { - icaminsub_( n, x, incx, rval); + icaminsub_blis_impl( n, x, incx, rval); } void icaminsub( const f77_int* n, const scomplex* x, const f77_int* incx, f77_int* rval) { - icaminsub_( n, x, incx, rval); + icaminsub_blis_impl( n, x, incx, rval); } void ICAMINSUB_( const f77_int* n, const scomplex* x, const f77_int* incx, f77_int* rval) { - icaminsub_( n, x, incx, rval); + icaminsub_blis_impl( n, x, incx, rval); } void IDAMAXSUB( const f77_int* n, const double* x, const f77_int* incx, f77_int* rval) { - idamaxsub_( n, x, incx, rval); + idamaxsub_blis_impl( n, x, incx, rval); } void idamaxsub( const f77_int* n, const double* x, const f77_int* incx, f77_int* rval) { - idamaxsub_( n, x, incx, rval); + idamaxsub_blis_impl( n, x, incx, rval); } void IDAMAXSUB_( const f77_int* n, const double* x, const f77_int* incx, f77_int* rval) { - idamaxsub_( n, x, incx, rval); + idamaxsub_blis_impl( n, x, incx, rval); } f77_int IDAMIN( const f77_int* n, const double* x, const f77_int* incx) @@ -2725,32 +2725,32 @@ f77_int IDAMIN_( const f77_int* n, const double* x, const f77_int* incx) void IDAMINSUB(const f77_int* n, const double* x, const f77_int* incx, f77_int* rval) { - idaminsub_( n, x, incx, rval); + idaminsub_blis_impl( n, x, incx, rval); } void idaminsub(const f77_int* n, const double* x, const f77_int* incx, f77_int* rval) { - idaminsub_( n, x, incx, rval); + idaminsub_blis_impl( n, x, incx, rval); } void IDAMINSUB_(const f77_int* n, const double* x, const f77_int* incx, f77_int* rval) { - idaminsub_( n, x, incx, rval); + idaminsub_blis_impl( n, x, incx, rval); } void ISAMAXSUB( const f77_int* n, const float* x, const f77_int* incx, f77_int* rval) { - isamaxsub_( n, x, incx, rval); + isamaxsub_blis_impl( n, x, incx, rval); } void isamaxsub( const f77_int* n, const float* x, const f77_int* incx, f77_int* rval) { - isamaxsub_( n, x, incx, rval); + isamaxsub_blis_impl( n, x, incx, rval); } void ISAMAXSUB_( const f77_int* n, const float* x, const f77_int* incx, f77_int* rval) { - isamaxsub_( n, x, incx, rval); + isamaxsub_blis_impl( n, x, incx, rval); } f77_int ISAMIN( const f77_int* n, const float* x, const f77_int* incx) @@ -2770,32 +2770,32 @@ f77_int ISAMIN_( const f77_int* n, const float* x, const f77_int* incx) void ISAMINSUB( const f77_int* n, const float* x, const f77_int* incx, f77_int* rval) { - isaminsub_( n, x, incx, rval); + isaminsub_blis_impl( n, x, incx, rval); } void isaminsub( const f77_int* n, const float* x, const f77_int* incx, f77_int* rval) { - isaminsub_( n, x, incx, rval); + isaminsub_blis_impl( n, x, incx, rval); } void ISAMINSUB_( const f77_int* n, const float* x, const f77_int* incx, f77_int* rval) { - isaminsub_( n, x, incx, rval); + isaminsub_blis_impl( n, x, incx, rval); } void IZAMAXSUB( const f77_int* n, const 
dcomplex* x, const f77_int* incx, f77_int* rval) { - izamaxsub_( n, x, incx, rval); + izamaxsub_blis_impl( n, x, incx, rval); } void izamaxsub( const f77_int* n, const dcomplex* x, const f77_int* incx, f77_int* rval) { - izamaxsub_( n, x, incx, rval); + izamaxsub_blis_impl( n, x, incx, rval); } void IZAMAXSUB_( const f77_int* n, const dcomplex* x, const f77_int* incx, f77_int* rval) { - izamaxsub_( n, x, incx, rval); + izamaxsub_blis_impl( n, x, incx, rval); } f77_int IZAMIN( const f77_int* n, const dcomplex* x, const f77_int* incx) @@ -2815,92 +2815,92 @@ f77_int IZAMIN_( const f77_int* n, const dcomplex* x, const f77_int* incx) void IZAMINSUB( const f77_int* n, const dcomplex* x, const f77_int* incx, f77_int* rval) { - izaminsub_( n, x, incx, rval); + izaminsub_blis_impl( n, x, incx, rval); } void izaminsub( const f77_int* n, const dcomplex* x, const f77_int* incx, f77_int* rval) { - izaminsub_( n, x, incx, rval); + izaminsub_blis_impl( n, x, incx, rval); } void IZAMINSUB_( const f77_int* n, const dcomplex* x, const f77_int* incx, f77_int* rval) { - izaminsub_( n, x, incx, rval); + izaminsub_blis_impl( n, x, incx, rval); } void SASUMSUB( const f77_int* n, const float* x, const f77_int* incx, float* rval) { - sasumsub_( n, x, incx, rval); + sasumsub_blis_impl( n, x, incx, rval); } void sasumsub( const f77_int* n, const float* x, const f77_int* incx, float* rval) { - sasumsub_( n, x, incx, rval); + sasumsub_blis_impl( n, x, incx, rval); } void SASUMSUB_( const f77_int* n, const float* x, const f77_int* incx, float* rval) { - sasumsub_( n, x, incx, rval); + sasumsub_blis_impl( n, x, incx, rval); } void SAXPBY( const f77_int* n, const float* alpha, const float *x, const f77_int* incx, const float* beta, float *y, const f77_int* incy) { - saxpby_( n, alpha, x, incx, beta, y, incy); + saxpby_blis_impl( n, alpha, x, incx, beta, y, incy); } void saxpby( const f77_int* n, const float* alpha, const float *x, const f77_int* incx, const float* beta, float *y, const f77_int* incy) { - saxpby_( n, alpha, x, incx, beta, y, incy); + saxpby_blis_impl( n, alpha, x, incx, beta, y, incy); } void SAXPBY_( const f77_int* n, const float* alpha, const float *x, const f77_int* incx, const float* beta, float *y, const f77_int* incy) { - saxpby_( n, alpha, x, incx, beta, y, incy); + saxpby_blis_impl( n, alpha, x, incx, beta, y, incy); } void SCASUMSUB( const f77_int* n, const scomplex* x, const f77_int* incx, float* rval) { - scasumsub_( n, x, incx, rval); + scasumsub_blis_impl( n, x, incx, rval); } void scasumsub( const f77_int* n, const scomplex* x, const f77_int* incx, float* rval) { - scasumsub_( n, x, incx, rval); + scasumsub_blis_impl( n, x, incx, rval); } void SCASUMSUB_( const f77_int* n, const scomplex* x, const f77_int* incx, float* rval) { - scasumsub_( n, x, incx, rval); + scasumsub_blis_impl( n, x, incx, rval); } void SCNRM2SUB( const f77_int* n, const scomplex* x, const f77_int* incx, float* rval) { - scnrm2sub_( n, x, incx, rval); + scnrm2sub_blis_impl( n, x, incx, rval); } void scnrm2sub( const f77_int* n, const scomplex* x, const f77_int* incx, float* rval) { - scnrm2sub_( n, x, incx, rval); + scnrm2sub_blis_impl( n, x, incx, rval); } void SCNRM2SUB_( const f77_int* n, const scomplex* x, const f77_int* incx, float* rval) { - scnrm2sub_( n, x, incx, rval); + scnrm2sub_blis_impl( n, x, incx, rval); } void SDOTSUB( const f77_int* n, const float* x, const f77_int* incx, const float* y, const f77_int* incy, float* rval) { - sdotsub_( n, x, incx, y, incy, rval); + sdotsub_blis_impl( n, x, incx, y, incy, 
rval); } void sdotsub( const f77_int* n, const float* x, const f77_int* incx, const float* y, const f77_int* incy, float* rval) { - sdotsub_( n, x, incx, y, incy, rval); + sdotsub_blis_impl( n, x, incx, y, incy, rval); } void SDOTSUB_( const f77_int* n, const float* x, const f77_int* incx, const float* y, const f77_int* incy, float* rval) { - sdotsub_( n, x, incx, y, incy, rval); + sdotsub_blis_impl( n, x, incx, y, incy, rval); } void SGEMM_BATCH(const f77_char* transa_array, const f77_char* transb_array,const f77_int *m_array, const f77_int *n_array, const f77_int *k_array,const float* alpha_array, const float** a_array, const f77_int *lda_array, const float** b_array, const f77_int *ldb_array, const float* beta_array, float** c_array, const f77_int *ldc_array, const f77_int* group_count, const f77_int *group_size) @@ -2950,17 +2950,17 @@ void SIMATCOPY_( f77_char* trans, f77_int* rows, f77_int* cols, const float* alp void SNRM2SUB( const f77_int* n, const float* x, const f77_int* incx, float *rval) { - snrm2sub_( n, x, incx, rval); + snrm2sub_blis_impl( n, x, incx, rval); } void snrm2sub( const f77_int* n, const float* x, const f77_int* incx, float *rval) { - snrm2sub_( n, x, incx, rval); + snrm2sub_blis_impl( n, x, incx, rval); } void SNRM2SUB_( const f77_int* n, const float* x, const f77_int* incx, float *rval) { - snrm2sub_( n, x, incx, rval); + snrm2sub_blis_impl( n, x, incx, rval); } void SOMATADD( f77_char* transa,f77_char* transb, f77_int* m, f77_int* n, const float* alpha, const float* A, f77_int* lda, const float* beta, const float* B, f77_int* ldb, float* C, f77_int* ldc) @@ -3010,47 +3010,47 @@ void SOMATCOPY_( f77_char* trans, f77_int* rows, f77_int* cols, const float* alp void ZAXPBY( const f77_int* n, const dcomplex* alpha, const dcomplex *x, const f77_int* incx, const dcomplex* beta, dcomplex *y, const f77_int* incy) { - zaxpby_( n, alpha, x, incx, beta, y, incy); + zaxpby_blis_impl( n, alpha, x, incx, beta, y, incy); } void zaxpby( const f77_int* n, const dcomplex* alpha, const dcomplex *x, const f77_int* incx, const dcomplex* beta, dcomplex *y, const f77_int* incy) { - zaxpby_( n, alpha, x, incx, beta, y, incy); + zaxpby_blis_impl( n, alpha, x, incx, beta, y, incy); } void ZAXPBY_( const f77_int* n, const dcomplex* alpha, const dcomplex *x, const f77_int* incx, const dcomplex* beta, dcomplex *y, const f77_int* incy) { - zaxpby_( n, alpha, x, incx, beta, y, incy); + zaxpby_blis_impl( n, alpha, x, incx, beta, y, incy); } void ZDOTCSUB( const f77_int* n, const dcomplex* x, const f77_int* incx, const dcomplex* y, const f77_int* incy, dcomplex* rval) { - zdotcsub_( n, x, incx, y, incy, rval); + zdotcsub_blis_impl( n, x, incx, y, incy, rval); } void zdotcsub( const f77_int* n, const dcomplex* x, const f77_int* incx, const dcomplex* y, const f77_int* incy, dcomplex* rval) { - zdotcsub_( n, x, incx, y, incy, rval); + zdotcsub_blis_impl( n, x, incx, y, incy, rval); } void ZDOTCSUB_( const f77_int* n, const dcomplex* x, const f77_int* incx, const dcomplex* y, const f77_int* incy, dcomplex* rval) { - zdotcsub_( n, x, incx, y, incy, rval); + zdotcsub_blis_impl( n, x, incx, y, incy, rval); } void ZDOTUSUB( const f77_int* n, const dcomplex* x, const f77_int* incx,const dcomplex* y, const f77_int* incy, dcomplex* rval) { - zdotusub_( n, x, incx, y, incy, rval); + zdotusub_blis_impl( n, x, incx, y, incy, rval); } void zdotusub( const f77_int* n, const dcomplex* x, const f77_int* incx,const dcomplex* y, const f77_int* incy, dcomplex* rval) { - zdotusub_( n, x, incx, y, incy, rval); + 
zdotusub_blis_impl( n, x, incx, y, incy, rval); } void ZDOTUSUB_( const f77_int* n, const dcomplex* x, const f77_int* incx,const dcomplex* y, const f77_int* incy, dcomplex* rval) { - zdotusub_( n, x, incx, y, incy, rval); + zdotusub_blis_impl( n, x, incx, y, incy, rval); } void ZGEMM3M( const f77_char* transa, const f77_char* transb, const f77_int* m, const f77_int* n, const f77_int* k, const dcomplex* alpha, const dcomplex* a, const f77_int* lda, const dcomplex* b, const f77_int* ldb, const dcomplex* beta, dcomplex* c, const f77_int* ldc) @@ -3178,47 +3178,47 @@ float SCABS1_(bla_scomplex* z) void SDSDOTSUB( const f77_int* n, float* sb, const float* x, const f77_int* incx, const float* y, const f77_int* incy, float* dot) { - sdsdotsub_( n, sb, x, incx, y, incy, dot); + sdsdotsub_blis_impl( n, sb, x, incx, y, incy, dot); } void sdsdotsub( const f77_int* n, float* sb, const float* x, const f77_int* incx, const float* y, const f77_int* incy, float* dot) { - sdsdotsub_( n, sb, x, incx, y, incy, dot); + sdsdotsub_blis_impl( n, sb, x, incx, y, incy, dot); } void SDSDOTSUB_( const f77_int* n, float* sb, const float* x, const f77_int* incx, const float* y, const f77_int* incy, float* dot) { - sdsdotsub_( n, sb, x, incx, y, incy, dot); + sdsdotsub_blis_impl( n, sb, x, incx, y, incy, dot); } void DSDOTSUB( const f77_int* n, const float* x, const f77_int* incx, const float* y, const f77_int* incy, double* dot) { - dsdotsub_( n, x, incx, y, incy, dot); + dsdotsub_blis_impl( n, x, incx, y, incy, dot); } void dsdotsub( const f77_int* n, const float* x, const f77_int* incx, const float* y, const f77_int* incy, double* dot) { - dsdotsub_( n, x, incx, y, incy, dot); + dsdotsub_blis_impl( n, x, incx, y, incy, dot); } void DSDOTSUB_( const f77_int* n, const float* x, const f77_int* incx, const float* y, const f77_int* incy, double* dot) { - dsdotsub_( n, x, incx, y, incy, dot); + dsdotsub_blis_impl( n, x, incx, y, incy, dot); } void CAXPBY( const f77_int* n, const scomplex* alpha, const scomplex *x, const f77_int* incx, const scomplex* beta, scomplex *y, const f77_int* incy) { - caxpby_(n, alpha, x, incx, beta, y, incy); + caxpby_blis_impl(n, alpha, x, incx, beta, y, incy); } void caxpby( const f77_int* n, const scomplex* alpha, const scomplex *x, const f77_int* incx, const scomplex* beta, scomplex *y, const f77_int* incy) { - caxpby_(n, alpha, x, incx, beta, y, incy); + caxpby_blis_impl(n, alpha, x, incx, beta, y, incy); } void CAXPBY_( const f77_int* n, const scomplex* alpha, const scomplex *x, const f77_int* incx, const scomplex* beta, scomplex *y, const f77_int* incy) { - caxpby_(n, alpha, x, incx, beta, y, incy); + caxpby_blis_impl(n, alpha, x, incx, beta, y, incy); } #endif From d3b503bbf251d8d4402d64ed32480e4757ad4dfb Mon Sep 17 00:00:00 2001 From: Dipal M Zambare Date: Mon, 29 Aug 2022 12:59:37 +0530 Subject: [PATCH 198/243] Code cleanup and warnings fixes - Removed all compiler warnings as reported by GCC 11 and AOCC 3.2 - Removed unused files - Removed commented and disabled code (#if 0, #if 1) from some files AMD-Internal: [CPUPL-2460] Change-Id: Ifc976f6fe585b09e2e387b6793961ad6ef05bb4a --- config/zen/bli_cntx_init_zen.c | 19 +- config/zen2/bli_cntx_init_zen2.c | 4 +- config/zen3/bli_cntx_init_zen3.c | 4 +- config/zen4/bli_cntx_init_zen4.c | 26 - frame/3/gemm/bli_gemm_front.c | 84 -- frame/3/gemm/bli_gemm_front_amd.c | 86 -- frame/3/gemmt/bli_gemmt_front.c | 87 +- frame/compat/bla_gemm.c | 3 +- frame/compat/bla_gemm_amd.c | 43 +- kernels/zen/3/CMakeLists.txt | 2 - kernels/zen/3/bli_gemm_small.c 
| 12 +- kernels/zen/3/bli_gemm_sqp.c | 1203 ------------------ kernels/zen/3/bli_gemm_sqp_kernels.c | 1750 -------------------------- kernels/zen/3/bli_gemm_sqp_kernels.h | 65 - kernels/zen/bli_kernels_zen.h | 12 - testsuite/src/test_gemm.c | 46 +- 16 files changed, 22 insertions(+), 3424 deletions(-) delete mode 100644 kernels/zen/3/bli_gemm_sqp.c delete mode 100644 kernels/zen/3/bli_gemm_sqp_kernels.c delete mode 100644 kernels/zen/3/bli_gemm_sqp_kernels.h diff --git a/config/zen/bli_cntx_init_zen.c b/config/zen/bli_cntx_init_zen.c index f527fe58d1..9d4197712e 100644 --- a/config/zen/bli_cntx_init_zen.c +++ b/config/zen/bli_cntx_init_zen.c @@ -104,11 +104,11 @@ void bli_cntx_init_zen( cntx_t* cntx ) bli_cntx_set_l1v_kers ( 26, -#if 1 + // amaxv BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int, BLIS_AMAXV_KER, BLIS_DOUBLE, bli_damaxv_zen_int, -#endif + // axpbyv BLIS_AXPBYV_KER, BLIS_FLOAT, bli_saxpbyv_zen_int10, BLIS_AXPBYV_KER, BLIS_DOUBLE, bli_daxpbyv_zen_int10, @@ -116,16 +116,11 @@ void bli_cntx_init_zen( cntx_t* cntx ) BLIS_AXPBYV_KER, BLIS_DCOMPLEX, bli_zaxpbyv_zen_int, // axpyv -#if 0 - BLIS_AXPYV_KER, BLIS_FLOAT, bli_saxpyv_zen_int, - BLIS_AXPYV_KER, BLIS_DOUBLE, bli_daxpyv_zen_int, -#else BLIS_AXPYV_KER, BLIS_FLOAT, bli_saxpyv_zen_int10, BLIS_AXPYV_KER, BLIS_DOUBLE, bli_daxpyv_zen_int10, BLIS_AXPYV_KER, BLIS_SCOMPLEX, bli_caxpyv_zen_int5, BLIS_AXPYV_KER, BLIS_DCOMPLEX, bli_zaxpyv_zen_int5, -#endif // dotv BLIS_DOTV_KER, BLIS_FLOAT, bli_sdotv_zen_int, BLIS_DOTV_KER, BLIS_DOUBLE, bli_ddotv_zen_int, @@ -138,18 +133,18 @@ void bli_cntx_init_zen( cntx_t* cntx ) BLIS_DOTXV_KER, BLIS_DCOMPLEX, bli_zdotxv_zen_int, BLIS_DOTXV_KER, BLIS_SCOMPLEX, bli_cdotxv_zen_int, // scalv -#if 0 - BLIS_SCALV_KER, BLIS_FLOAT, bli_sscalv_zen_int, - BLIS_SCALV_KER, BLIS_DOUBLE, bli_dscalv_zen_int, -#else + BLIS_SCALV_KER, BLIS_FLOAT, bli_sscalv_zen_int10, BLIS_SCALV_KER, BLIS_DOUBLE, bli_dscalv_zen_int10, -#endif + + // swapv BLIS_SWAPV_KER, BLIS_FLOAT, bli_sswapv_zen_int8, BLIS_SWAPV_KER, BLIS_DOUBLE, bli_dswapv_zen_int8, + // copyv BLIS_COPYV_KER, BLIS_FLOAT, bli_scopyv_zen_int, BLIS_COPYV_KER, BLIS_DOUBLE, bli_dcopyv_zen_int, + //set BLIS_SETV_KER, BLIS_FLOAT, bli_ssetv_zen_int, BLIS_SETV_KER, BLIS_DOUBLE, bli_dsetv_zen_int, diff --git a/config/zen2/bli_cntx_init_zen2.c b/config/zen2/bli_cntx_init_zen2.c index 1ecb62ff52..3ce2fced92 100644 --- a/config/zen2/bli_cntx_init_zen2.c +++ b/config/zen2/bli_cntx_init_zen2.c @@ -116,11 +116,11 @@ void bli_cntx_init_zen2( cntx_t* cntx ) bli_cntx_set_l1v_kers ( 26, -#if 1 + // amaxv BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int, BLIS_AMAXV_KER, BLIS_DOUBLE, bli_damaxv_zen_int, -#endif + // axpbyv BLIS_AXPBYV_KER, BLIS_FLOAT, bli_saxpbyv_zen_int10, BLIS_AXPBYV_KER, BLIS_DOUBLE, bli_daxpbyv_zen_int10, diff --git a/config/zen3/bli_cntx_init_zen3.c b/config/zen3/bli_cntx_init_zen3.c index 02e264d277..779bb7277c 100644 --- a/config/zen3/bli_cntx_init_zen3.c +++ b/config/zen3/bli_cntx_init_zen3.c @@ -116,11 +116,11 @@ void bli_cntx_init_zen3( cntx_t* cntx ) bli_cntx_set_l1v_kers ( 26, -#if 1 + // amaxv BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int, BLIS_AMAXV_KER, BLIS_DOUBLE, bli_damaxv_zen_int, -#endif + // axpbyv BLIS_AXPBYV_KER, BLIS_FLOAT, bli_saxpbyv_zen_int10, BLIS_AXPBYV_KER, BLIS_DOUBLE, bli_daxpbyv_zen_int10, diff --git a/config/zen4/bli_cntx_init_zen4.c b/config/zen4/bli_cntx_init_zen4.c index 5f0728f18a..1de13061b2 100644 --- a/config/zen4/bli_cntx_init_zen4.c +++ b/config/zen4/bli_cntx_init_zen4.c @@ -204,32 +204,6 @@ void bli_cntx_init_zen4( cntx_t* cntx 
) ); // ------------------------------------------------------------------------- -#if 0 // Replaced with runtime blocksize override - - //Initialize TRSM blocksize objects with architecture-specific values. - //Using different cache block sizes for TRSM instead of common level-3 block sizes. - //Tuning is done for double-precision only. - // s d c z - bli_blksz_init_easy( &blkszs[ BLIS_MR ], 6, 6, 3, 3 ); - bli_blksz_init_easy( &blkszs[ BLIS_NR ], 16, 8, 8, 4 ); - bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 72, 144, 72 ); - bli_blksz_init_easy( &blkszs[ BLIS_KC ], 256, 492, 256, 256 ); - bli_blksz_init_easy( &blkszs[ BLIS_NC ], 4080, 1600, 4080, 4080 ); - - // Update the context with the current architecture's register and cache - // blocksizes for level-3 TRSM problems. - bli_cntx_set_trsm_blkszs - ( - 5, - BLIS_NC, &blkszs[ BLIS_NC ], - BLIS_KC, &blkszs[ BLIS_KC ], - BLIS_MC, &blkszs[ BLIS_MC ], - BLIS_NR, &blkszs[ BLIS_NR ], - BLIS_MR, &blkszs[ BLIS_MR ], - cntx - ); -#endif - // Initialize sup thresholds with architecture-appropriate values. s d c z bli_blksz_init_easy( &thresh[ BLIS_MT ], 512, 256, 380, 110 ); bli_blksz_init_easy( &thresh[ BLIS_NT ], 200, 256, 256, 128 ); diff --git a/frame/3/gemm/bli_gemm_front.c b/frame/3/gemm/bli_gemm_front.c index a9bada995d..063f40ff9c 100644 --- a/frame/3/gemm/bli_gemm_front.c +++ b/frame/3/gemm/bli_gemm_front.c @@ -294,88 +294,4 @@ void bli_gemm_front // ----------------------------------------------------------------------------- -#if 0 - if ( bli_obj_dt( a ) != bli_obj_dt( b ) || - bli_obj_dt( a ) != bli_obj_dt( c ) || - bli_obj_comp_prec( c ) != bli_obj_prec( c ) ) - { - const bool a_is_real = bli_obj_is_real( a ); - const bool a_is_comp = bli_obj_is_complex( a ); - const bool b_is_real = bli_obj_is_real( b ); - const bool b_is_comp = bli_obj_is_complex( b ); - const bool c_is_real = bli_obj_is_real( c ); - const bool c_is_comp = bli_obj_is_complex( c ); - - const bool a_is_single = bli_obj_is_single_prec( a ); - const bool a_is_double = bli_obj_is_double_prec( a ); - const bool b_is_single = bli_obj_is_single_prec( b ); - const bool b_is_double = bli_obj_is_double_prec( b ); - const bool c_is_single = bli_obj_is_single_prec( c ); - const bool c_is_double = bli_obj_is_double_prec( c ); - - const bool comp_single = bli_obj_comp_prec( c ) == BLIS_SINGLE_PREC; - const bool comp_double = bli_obj_comp_prec( c ) == BLIS_DOUBLE_PREC; - - const bool mixeddomain = bli_obj_domain( c ) != bli_obj_domain( a ) || - bli_obj_domain( c ) != bli_obj_domain( b ); - - ( void )a_is_real; ( void )a_is_comp; - ( void )b_is_real; ( void )b_is_comp; - ( void )c_is_real; ( void )c_is_comp; - ( void )a_is_single; ( void )a_is_double; - ( void )b_is_single; ( void )b_is_double; - ( void )c_is_single; ( void )c_is_double; - ( void )comp_single; ( void )comp_double; - - if ( - //( c_is_comp && a_is_comp && b_is_real ) || - //( c_is_comp && a_is_real && b_is_comp ) || - //( c_is_real && a_is_comp && b_is_comp ) || - //( c_is_comp && a_is_real && b_is_real ) || - //( c_is_real && a_is_comp && b_is_real ) || - //( c_is_real && a_is_real && b_is_comp ) || - //FALSE - TRUE - ) - { - if ( - ( c_is_single && a_is_single && b_is_single && mixeddomain ) || - ( c_is_single && a_is_single && b_is_single && comp_single ) || - ( c_is_single && a_is_single && b_is_single && comp_double ) || - ( c_is_single && a_is_single && b_is_double ) || - ( c_is_single && a_is_double && b_is_single ) || - ( c_is_double && a_is_single && b_is_single ) || - ( c_is_single && a_is_double && 
b_is_double ) || - ( c_is_double && a_is_single && b_is_double ) || - ( c_is_double && a_is_double && b_is_single ) || - ( c_is_double && a_is_double && b_is_double && comp_single ) || - ( c_is_double && a_is_double && b_is_double && comp_double ) || - ( c_is_double && a_is_double && b_is_double && mixeddomain ) || - FALSE - ) - bli_gemm_md_front( alpha, a, b, beta, c, cntx, cntl ); - else - bli_gemm_md_zgemm( alpha, a, b, beta, c, cntx, cntl ); - } - else - bli_gemm_md_zgemm( alpha, a, b, beta, c, cntx, cntl ); - return; - } -#else -#if 0 - // If any of the storage datatypes differ, or if the execution precision - // differs from the storage precision of C, utilize the mixed datatype - // code path. - // NOTE: We could check the exec dt against the storage dt of C, but for - // now we don't support the caller setting the execution domain - // explicitly. - if ( bli_obj_dt( a ) != bli_obj_dt( b ) || - bli_obj_dt( a ) != bli_obj_dt( c ) || - bli_obj_comp_prec( c ) != bli_obj_prec( c ) ) - { - bli_gemm_md_front( alpha, a, b, beta, c, cntx, cntl ); - return; - } -#endif -#endif diff --git a/frame/3/gemm/bli_gemm_front_amd.c b/frame/3/gemm/bli_gemm_front_amd.c index 34b41f0568..b15d906dd8 100644 --- a/frame/3/gemm/bli_gemm_front_amd.c +++ b/frame/3/gemm/bli_gemm_front_amd.c @@ -319,89 +319,3 @@ void bli_gemm_front } // ----------------------------------------------------------------------------- - -#if 0 - if ( bli_obj_dt( a ) != bli_obj_dt( b ) || - bli_obj_dt( a ) != bli_obj_dt( c ) || - bli_obj_comp_prec( c ) != bli_obj_prec( c ) ) - { - const bool a_is_real = bli_obj_is_real( a ); - const bool a_is_comp = bli_obj_is_complex( a ); - const bool b_is_real = bli_obj_is_real( b ); - const bool b_is_comp = bli_obj_is_complex( b ); - const bool c_is_real = bli_obj_is_real( c ); - const bool c_is_comp = bli_obj_is_complex( c ); - - const bool a_is_single = bli_obj_is_single_prec( a ); - const bool a_is_double = bli_obj_is_double_prec( a ); - const bool b_is_single = bli_obj_is_single_prec( b ); - const bool b_is_double = bli_obj_is_double_prec( b ); - const bool c_is_single = bli_obj_is_single_prec( c ); - const bool c_is_double = bli_obj_is_double_prec( c ); - - const bool comp_single = bli_obj_comp_prec( c ) == BLIS_SINGLE_PREC; - const bool comp_double = bli_obj_comp_prec( c ) == BLIS_DOUBLE_PREC; - - const bool mixeddomain = bli_obj_domain( c ) != bli_obj_domain( a ) || - bli_obj_domain( c ) != bli_obj_domain( b ); - - ( void )a_is_real; ( void )a_is_comp; - ( void )b_is_real; ( void )b_is_comp; - ( void )c_is_real; ( void )c_is_comp; - ( void )a_is_single; ( void )a_is_double; - ( void )b_is_single; ( void )b_is_double; - ( void )c_is_single; ( void )c_is_double; - ( void )comp_single; ( void )comp_double; - - if ( - //( c_is_comp && a_is_comp && b_is_real ) || - //( c_is_comp && a_is_real && b_is_comp ) || - //( c_is_real && a_is_comp && b_is_comp ) || - //( c_is_comp && a_is_real && b_is_real ) || - //( c_is_real && a_is_comp && b_is_real ) || - //( c_is_real && a_is_real && b_is_comp ) || - //FALSE - TRUE - ) - { - if ( - ( c_is_single && a_is_single && b_is_single && mixeddomain ) || - ( c_is_single && a_is_single && b_is_single && comp_single ) || - ( c_is_single && a_is_single && b_is_single && comp_double ) || - ( c_is_single && a_is_single && b_is_double ) || - ( c_is_single && a_is_double && b_is_single ) || - ( c_is_double && a_is_single && b_is_single ) || - ( c_is_single && a_is_double && b_is_double ) || - ( c_is_double && a_is_single && b_is_double ) || - ( c_is_double && 
a_is_double && b_is_single ) || - ( c_is_double && a_is_double && b_is_double && comp_single ) || - ( c_is_double && a_is_double && b_is_double && comp_double ) || - ( c_is_double && a_is_double && b_is_double && mixeddomain ) || - FALSE - ) - bli_gemm_md_front( alpha, a, b, beta, c, cntx, cntl ); - else - bli_gemm_md_zgemm( alpha, a, b, beta, c, cntx, cntl ); - } - else - bli_gemm_md_zgemm( alpha, a, b, beta, c, cntx, cntl ); - return; - } -#else -#if 0 - // If any of the storage datatypes differ, or if the execution precision - // differs from the storage precision of C, utilize the mixed datatype - // code path. - // NOTE: We could check the exec dt against the storage dt of C, but for - // now we don't support the caller setting the execution domain - // explicitly. - if ( bli_obj_dt( a ) != bli_obj_dt( b ) || - bli_obj_dt( a ) != bli_obj_dt( c ) || - bli_obj_comp_prec( c ) != bli_obj_prec( c ) ) - { - bli_gemm_md_front( alpha, a, b, beta, c, cntx, cntl ); - return; - } -#endif -#endif - diff --git a/frame/3/gemmt/bli_gemmt_front.c b/frame/3/gemmt/bli_gemmt_front.c index 86940c1bd2..b2155a0bcd 100644 --- a/frame/3/gemmt/bli_gemmt_front.c +++ b/frame/3/gemmt/bli_gemmt_front.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -278,88 +278,3 @@ void bli_gemmt_front } // ----------------------------------------------------------------------------- - -#if 0 - if ( bli_obj_dt( a ) != bli_obj_dt( b ) || - bli_obj_dt( a ) != bli_obj_dt( c ) || - bli_obj_comp_prec( c ) != bli_obj_prec( c ) ) - { - const bool a_is_real = bli_obj_is_real( a ); - const bool a_is_comp = bli_obj_is_complex( a ); - const bool b_is_real = bli_obj_is_real( b ); - const bool b_is_comp = bli_obj_is_complex( b ); - const bool c_is_real = bli_obj_is_real( c ); - const bool c_is_comp = bli_obj_is_complex( c ); - - const bool a_is_single = bli_obj_is_single_prec( a ); - const bool a_is_double = bli_obj_is_double_prec( a ); - const bool b_is_single = bli_obj_is_single_prec( b ); - const bool b_is_double = bli_obj_is_double_prec( b ); - const bool c_is_single = bli_obj_is_single_prec( c ); - const bool c_is_double = bli_obj_is_double_prec( c ); - - const bool comp_single = bli_obj_comp_prec( c ) == BLIS_SINGLE_PREC; - const bool comp_double = bli_obj_comp_prec( c ) == BLIS_DOUBLE_PREC; - - const bool mixeddomain = bli_obj_domain( c ) != bli_obj_domain( a ) || - bli_obj_domain( c ) != bli_obj_domain( b ); - - ( void )a_is_real; ( void )a_is_comp; - ( void )b_is_real; ( void )b_is_comp; - ( void )c_is_real; ( void )c_is_comp; - ( void )a_is_single; ( void )a_is_double; - ( void )b_is_single; ( void )b_is_double; - ( void )c_is_single; ( void )c_is_double; - ( void )comp_single; ( void )comp_double; - - if ( - //( c_is_comp && a_is_comp && b_is_real ) || - //( c_is_comp && a_is_real && b_is_comp ) || - //( c_is_real && a_is_comp && b_is_comp ) || - //( c_is_comp && a_is_real && b_is_real ) || - //( c_is_real && a_is_comp && b_is_real ) || - //( c_is_real && a_is_real && b_is_comp ) || - //FALSE - TRUE - ) - { - if ( - ( c_is_single && a_is_single && b_is_single && mixeddomain ) || - ( c_is_single && a_is_single && b_is_single && comp_single ) || - ( c_is_single && a_is_single && b_is_single && 
comp_double ) || - ( c_is_single && a_is_single && b_is_double ) || - ( c_is_single && a_is_double && b_is_single ) || - ( c_is_double && a_is_single && b_is_single ) || - ( c_is_single && a_is_double && b_is_double ) || - ( c_is_double && a_is_single && b_is_double ) || - ( c_is_double && a_is_double && b_is_single ) || - ( c_is_double && a_is_double && b_is_double && comp_single ) || - ( c_is_double && a_is_double && b_is_double && comp_double ) || - ( c_is_double && a_is_double && b_is_double && mixeddomain ) || - FALSE - ) - bli_gemm_md_front( alpha, a, b, beta, c, cntx, cntl ); - else - bli_gemm_md_zgemm( alpha, a, b, beta, c, cntx, cntl ); - } - else - bli_gemm_md_zgemm( alpha, a, b, beta, c, cntx, cntl ); - return; - } -#else -#if 0 - // If any of the storage datatypes differ, or if the execution precision - // differs from the storage precision of C, utilize the mixed datatype - // code path. - // NOTE: We could check the exec dt against the storage dt of C, but for - // now we don't support the caller setting the execution domain - // explicitly. - if ( bli_obj_dt( a ) != bli_obj_dt( b ) || - bli_obj_dt( a ) != bli_obj_dt( c ) || - bli_obj_comp_prec( c ) != bli_obj_prec( c ) ) - { - bli_gemm_md_front( alpha, a, b, beta, c, cntx, cntl ); - return; - } -#endif -#endif diff --git a/frame/compat/bla_gemm.c b/frame/compat/bla_gemm.c index ae14196e89..bf5cc502a7 100644 --- a/frame/compat/bla_gemm.c +++ b/frame/compat/bla_gemm.c @@ -357,7 +357,6 @@ void PASTEF77(ch,blasname) \ #ifdef BLIS_ENABLE_BLAS INSERT_GENTFUNC_BLAS( gemm,gemm ) -#if 1 void dzgemm_ ( const f77_char* transa, @@ -460,5 +459,5 @@ void dzgemm_ /* Finalize BLIS. */ bli_finalize_auto(); }// end of dzgemm_ -#endif + #endif diff --git a/frame/compat/bla_gemm_amd.c b/frame/compat/bla_gemm_amd.c index 6bc0dcd557..bf710bdae7 100644 --- a/frame/compat/bla_gemm_amd.c +++ b/frame/compat/bla_gemm_amd.c @@ -853,23 +853,6 @@ void zgemm_blis_impl return; } -#if 0 -/*** Code is disabled as bli_zgemv_unf_var1 not optimised *** - Calling below unoptimised variant causes regression ***/ - else - { - bli_zgemv_unf_var1( - blis_transa, - bli_extract_conj(blis_transb), - k0, m0, - (dcomplex *)alpha, - (dcomplex *)a, rs_a, cs_a, - (dcomplex *)b, bli_is_notrans(blis_transb) ? rs_b : cs_b, - (dcomplex *)beta, - c, rs_c, - ((void *)0)); - } -#endif } else if (m0 == 1) { @@ -888,30 +871,12 @@ void zgemm_blis_impl AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); return; } -#if 0 -/*** Code is disabled as bli_zgemv_unf_var1 not optimised *** - Calling below unoptimised variant causes regression ***/ - - else - { - bli_zgemv_unf_var1( - blis_transb, - bli_extract_conj(blis_transa), - n0, k0, - (dcomplex *)alpha, - (dcomplex *)b, cs_b, rs_b, - (dcomplex *)a, bli_is_notrans(blis_transa) ? cs_a : rs_a, - (dcomplex *)beta, - c, cs_c, - ((void *)0)); - } -#endif } #ifdef BLIS_ENABLE_SMALL_MATRIX if (((nt == 0) && (((m0 <= 40) && (n0 <= 40)) || - (m0 <= 128) && (n0 <= 128) && bli_is_notrans(blis_transb)) && (k0 <= 512)) || + ((m0 <= 128) && (n0 <= 128) && bli_is_notrans(blis_transb))) && (k0 <= 512)) || ((nt == 1) && (((m0 <= 32) || (n0 <= 32) || (k0 <= 32)) && ((m0 + n0 + k0) <= 100)))) { err_t status = BLIS_NOT_YET_IMPLEMENTED; @@ -988,10 +953,6 @@ void zgemm_ INSERT_GENTFUNC_BLAS_SC( gemm, gemm ) - -// Observed a regression in dgemm with this function addition. -// Disabling temporarily. -#if 1 void dzgemm_ ( const f77_char* transa, @@ -1094,5 +1055,5 @@ void dzgemm_ /* Finalize BLIS. 
*/ bli_finalize_auto(); }// end of dzgemm_ -#endif + #endif diff --git a/kernels/zen/3/CMakeLists.txt b/kernels/zen/3/CMakeLists.txt index a52740ecaf..d90e4e3902 100644 --- a/kernels/zen/3/CMakeLists.txt +++ b/kernels/zen/3/CMakeLists.txt @@ -5,8 +5,6 @@ target_sources("${PROJECT_NAME}" ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemm_small.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_trsm_small.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_dgemm_ref_k1.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemm_sqp_kernels.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemm_sqp.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_zgemm_ref_k1.c ) diff --git a/kernels/zen/3/bli_gemm_small.c b/kernels/zen/3/bli_gemm_small.c index e232e28e51..22bb48f737 100644 --- a/kernels/zen/3/bli_gemm_small.c +++ b/kernels/zen/3/bli_gemm_small.c @@ -7274,7 +7274,7 @@ err_t bli_zgemm_small } m_remainder = M - row_idx; - if ((m_remainder == 3)) + if (m_remainder == 3) { m_remainder -= 3; __m128d xmm0; @@ -8213,7 +8213,7 @@ err_t bli_zgemm_small _mm_storeu_pd((double *)(tC + 2), xmm0); } } - if ((m_remainder == 2)) + if (m_remainder == 2) { m_remainder -= 2; @@ -8952,7 +8952,7 @@ err_t bli_zgemm_small _mm256_storeu_pd((double *)tC, ymm8); } } - if ((m_remainder == 1)) + if (m_remainder == 1) { m_remainder -= 1; __m128d xmm0; @@ -10842,7 +10842,7 @@ err_t bli_zgemm_small_At } m_remainder = M - row_idx; - if ((m_remainder == 3)) + if (m_remainder == 3) { m_remainder -= 3; __m128d xmm0; @@ -11832,7 +11832,7 @@ err_t bli_zgemm_small_At _mm_storeu_pd((double *)(tC + 2), xmm0); } } - if ((m_remainder == 2)) + if (m_remainder == 2) { m_remainder -= 2; @@ -12615,7 +12615,7 @@ err_t bli_zgemm_small_At _mm256_storeu_pd((double *)tC, ymm8); } } - if ((m_remainder == 1)) + if (m_remainder == 1) { m_remainder -= 1; __m128d xmm0; diff --git a/kernels/zen/3/bli_gemm_sqp.c b/kernels/zen/3/bli_gemm_sqp.c deleted file mode 100644 index ceab622bf3..0000000000 --- a/kernels/zen/3/bli_gemm_sqp.c +++ /dev/null @@ -1,1203 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2021, Advanced Micro Devices, Inc. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ -#include "blis.h" -#include "immintrin.h" -#include "bli_gemm_sqp_kernels.h" - -#define SQP_THREAD_ENABLE 0//currently disabled -#define BLI_SQP_MAX_THREADS 128 -#define BLIS_LOADFIRST 0 -#define MEM_ALLOC 1//malloc performs better than bli_malloc. - -#define SET_TRANS(X,Y)\ - Y = BLIS_NO_TRANSPOSE;\ - if(bli_obj_has_trans( a ))\ - {\ - Y = BLIS_TRANSPOSE;\ - if(bli_obj_has_conj(a))\ - {\ - Y = BLIS_CONJ_TRANSPOSE;\ - }\ - }\ - else if(bli_obj_has_conj(a))\ - {\ - Y = BLIS_CONJ_NO_TRANSPOSE;\ - } - -//Macro for 3m_sqp n loop -#define BLI_SQP_ZGEMM_N(MX)\ - int j=0;\ - for(; j<=(n-nx); j+= nx)\ - {\ - status = bli_sqp_zgemm_m8( m, nx, k, a, lda, b+(j*ldb), ldb, c+(j*ldc), ldc, alpha_real, beta_real, transa, MX, p_istart, kx, &mem_3m_sqp);\ - }\ - if(jreal; - double alpha_imag = alphap->imag; - double beta_real = betap->real; - double beta_imag = betap->imag; - if( (alpha_imag!=0)||(beta_imag!=0) ) - { - return BLIS_NOT_YET_IMPLEMENTED; - } - //printf("zsqp "); - return bli_sqp_zgemm( m, n, k, ap, lda, bp, ldb, cp, ldc, alpha_real, beta_real, transa, nt); - } - else if(dt == BLIS_DOUBLE) - { - double *alpha_cast, *beta_cast; - alpha_cast = bli_obj_buffer_for_1x1(BLIS_DOUBLE, alpha); - beta_cast = bli_obj_buffer_for_1x1(BLIS_DOUBLE, beta); - - if((*beta_cast)!=1.0) - { - return BLIS_NOT_YET_IMPLEMENTED; - } - if(((*alpha_cast)!=1.0)&&((*alpha_cast)!=-1.0)) - { - return BLIS_NOT_YET_IMPLEMENTED; - } - //printf("dsqp "); - // dgemm case only transpose or no-transpose is handled. - // conjugate_transpose and conjugate no transpose are not applicable. 
- return bli_sqp_dgemm( m, n, k, ap, lda, bp, ldb, cp, ldc, *alpha_cast, *beta_cast, isTransA, nt); - } - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); - return BLIS_NOT_YET_IMPLEMENTED; -}; - -//sqp_dgemm k partition -BLIS_INLINE void bli_sqp_dgemm_kx( gint_t m, - gint_t n, - gint_t kx, - gint_t p, - double* a, - guint_t lda, - double* b, - guint_t ldb, - double* c, - guint_t ldc, - bool isTransA, - double alpha, - gint_t mx, - gint_t i, - bool pack_on, - double *aligned) -{ - inc_t j = 0; - double* ci = c + i; - double* aPacked; - //packing - if(pack_on==true) - { - aPacked = aligned; - double *pa = a + i + (p*lda); - if(isTransA==true) - { - pa = a + (i*lda) + p; - } - bli_sqp_prepackA(pa, aPacked, kx, lda, isTransA, alpha, mx); - } - else - { - aPacked = a+i + (p*lda); - } - - //compute - if(mx==8) - { - //printf("\n mx8i:%3ld ", i); - if (j <= (n - 6)) - { - j = bli_sqp_dgemm_kernel_8mx6n(n, kx, j, aPacked, lda, b + p, ldb, ci, ldc); - } - if (j <= (n - 5)) - { - j = bli_sqp_dgemm_kernel_8mx5n(n, kx, j, aPacked, lda, b + (j * ldb) + p, ldb, ci + (j * ldc), ldc); - } - if (j <= (n - 4)) - { - j = bli_sqp_dgemm_kernel_8mx4n(n, kx, j, aPacked, lda, b + (j * ldb) + p, ldb, ci + (j * ldc), ldc); - } - if (j <= (n - 3)) - { - j = bli_sqp_dgemm_kernel_8mx3n(n, kx, j, aPacked, lda, b + (j * ldb) + p, ldb, ci + (j * ldc), ldc); - } - if (j <= (n - 2)) - { - j = bli_sqp_dgemm_kernel_8mx2n(n, kx, j, aPacked, lda, b + (j * ldb) + p, ldb, ci + (j * ldc), ldc); - } - if (j <= (n - 1)) - { - j = bli_sqp_dgemm_kernel_8mx1n(n, kx, j, aPacked, lda, b + (j * ldb) + p, ldb, ci + (j * ldc), ldc); - } - } - /* mx==4 to be implemented */ - else - { - // this residue kernel needs to be improved. - j = bli_sqp_dgemm_kernel_mxn(n, kx, j, aPacked, lda, b + p, ldb, ci, ldc, mx); - } -} - -//sqp dgemm m loop -void bli_sqp_dgemm_m( gint_t i_start, - gint_t i_end, - gint_t m, - gint_t n, - gint_t k, - gint_t kx, - double* a, - guint_t lda, - double* b, - guint_t ldb, - double* c, - guint_t ldc, - bool isTransA, - double alpha, - gint_t mx, - bool pack_on, - double *aligned ) -{ -#if SQP_THREAD_ENABLE - if(pack_on==true) - { - //NEEDED IN THREADING CASE: - aligned = (double*)bli_malloc_user(sizeof(double) * kx * mx); - if(aligned==NULL) - { - return BLIS_MALLOC_RETURNED_NULL;// return to be removed - } - } -#endif//SQP_THREAD_ENABLE - - for (gint_t i = i_start; i <= (i_end-mx); i += mx) //this loop can be threaded. 
no of workitems = m/8 - { - int p = 0; - for(; p <= (k-kx); p += kx) - { - bli_sqp_dgemm_kx(m, n, kx, p, a, lda, b, ldb, c, ldc, isTransA, alpha, mx, i, pack_on, aligned); - }// k loop end - - if(pi_start, - arg->i_end, - arg->m, - arg->n, - arg->k, - arg->kx, - arg->a, - arg->lda, - arg->b, - arg->ldb, - arg->c, - arg->ldc, - arg->isTransA, - arg->alpha, - arg->mx, - arg->pack_on, - arg->aligned); -} - -// sqp_dgemm m loop -BLIS_INLINE err_t bli_sqp_dgemm_m8( gint_t m, - gint_t n, - gint_t k, - double* a, - guint_t lda, - double* b, - guint_t ldb, - double* c, - guint_t ldc, - bool isTransA, - double alpha, - gint_t mx, - gint_t* p_istart, - gint_t kx, - double *aligned) -{ - gint_t i; - if(kx > k) - { - kx = k; - } - - bool pack_on = false; - if((m!=mx)||(m!=lda)||isTransA) - { - pack_on = true; - } - -#if 0//SQP_THREAD_ENABLE//ENABLE Threading - gint_t status = 0; - gint_t workitems = (m-(*p_istart))/mx; - gint_t inputThreadCount = bli_thread_get_num_threads(); - inputThreadCount = bli_min(inputThreadCount, BLI_SQP_MAX_THREADS); - inputThreadCount = bli_min(inputThreadCount,workitems);// limit input thread count when workitems are lesser. - inputThreadCount = bli_max(inputThreadCount,1); - gint_t num_threads; - num_threads = bli_max(inputThreadCount,1); - gint_t mx_per_thread = workitems/num_threads;//no of workitems per thread - //printf("\nistart %d workitems %d inputThreadCount %d num_threads %d mx_per_thread %d mx %d " , - *p_istart, workitems,inputThreadCount,num_threads,mx_per_thread, mx); - - pthread_t ptid[BLI_SQP_MAX_THREADS]; - bli_sqp_thread_info thread_info[BLI_SQP_MAX_THREADS]; - - //create threads - for (gint_t t = 0; t < num_threads; t++) - { - //ptid[t].tid = t; - gint_t i_end = ((mx_per_thread*(t+1))*mx)+(*p_istart); - if(i_end>m) - { - i_end = m; - } - - if(t==(num_threads-1)) - { - if((i_end+mx)==m) - { - i_end = m; - } - - if(mx==1) - { - i_end = m; - } - } - - thread_info[t].i_start = ((mx_per_thread*t)*mx)+(*p_istart); - thread_info[t].i_end = i_end; - //printf("\n threadid %d istart %d iend %d m %d mx %d", t, thread_info[t].i_start, i_end, m, mx); - thread_info[t].m = m; - thread_info[t].n = n; - thread_info[t].k = k; - thread_info[t].kx = kx; - thread_info[t].a = a; - thread_info[t].lda = lda; - thread_info[t].b = b; - thread_info[t].ldb = ldb; - thread_info[t].c = c; - thread_info[t].ldc = ldc; - thread_info[t].isTransA = isTransA; - thread_info[t].alpha = alpha; - thread_info[t].mx = mx; - thread_info[t].pack_on = pack_on; - thread_info[t].aligned = aligned; -#if 1 - if ((status = pthread_create(&ptid[t], NULL, bli_sqp_thread, (void*)&thread_info[t]))) - { - printf("error sqp pthread_create\n"); - return BLIS_FAILURE; - } -#else - //simulate thread for debugging.. - bli_sqp_thread((void*)&thread_info[t]); -#endif - } - - //wait for completion - for (gint_t t = 0; t < num_threads; t++) - { - pthread_join(ptid[t], NULL); - } - - if(num_threads>0) - { - *p_istart = thread_info[(num_threads-1)].i_end; - } -#else//SQP_THREAD_ENABLE - - if(pack_on==true) - { - //aligned = (double*)bli_malloc_user(sizeof(double) * kx * mx); // allocation moved to top. - if(aligned==NULL) - { - return BLIS_MALLOC_RETURNED_NULL; - } - } - - for (i = (*p_istart); i <= (m-mx); i += mx) //this loop can be threaded. 
no of workitems = m/8 - { - int p = 0; - for(; p <= (k-kx); p += kx) - { - bli_sqp_dgemm_kx(m, n, kx, p, a, lda, b, ldb, c, ldc, - isTransA, alpha, mx, i, pack_on, aligned); - }// k loop end - - if(pdata_size * mem_req->size; - if (memSize == 0) - { - return -1; - } - memSize += 128;// extra 128 bytes added for alignment. Could be minimized to 64. -#if MEM_ALLOC -#ifdef BLIS_ENABLE_MEM_TRACING - printf( "malloc(): size %ld\n",( long )memSize); - fflush( stdout ); -#endif - mem_req->unalignedBuf = (double*)malloc(memSize); - if (mem_req->unalignedBuf == NULL) - { - return -1; - } - - int64_t address = (int64_t)mem_req->unalignedBuf; - address += (-address) & 63; //64 bytes alignment done. - mem_req->alignedBuf = (double*)address; -#else - mem_req->alignedBuf = bli_malloc_user( memSize ); - if (mem_req->alignedBuf == NULL) - { - return -1; - } -#endif - return 0; -} - -gint_t bli_allocateWorkspace(gint_t n, gint_t k, mem_block *mxr, mem_block *mxi, mem_block *msx) -{ - //allocate workspace - mxr->data_size = mxi->data_size = msx->data_size = sizeof(double); - mxr->size = mxi->size = n * k; - msx->size = n * k; - mxr->alignedBuf = mxi->alignedBuf = msx->alignedBuf = NULL; - mxr->unalignedBuf = mxi->unalignedBuf = msx->unalignedBuf = NULL; - - if (!((bli_getaligned(mxr) == 0) && (bli_getaligned(mxi) == 0) && (bli_getaligned(msx) == 0))) - { -#if MEM_ALLOC - if(mxr->unalignedBuf) - { - free(mxr->unalignedBuf); - } - if(mxi->unalignedBuf) - { - free(mxi->unalignedBuf); - } - if(msx->unalignedBuf) - { - free(msx->unalignedBuf); - } -#else - bli_free_user(mxr->alignedBuf); - bli_free_user(mxi->alignedBuf); - bli_free_user(msx->alignedBuf); -#endif - return -1; - } - return 0; -} - -//3m_sqp k loop -BLIS_INLINE void bli_sqp_zgemm_kx( gint_t m, - gint_t n, - gint_t kx, - gint_t p, - double* a, - guint_t lda, - guint_t ldb, - double* c, - guint_t ldc, - trans_t transa, - double alpha, - double beta, - gint_t mx, - gint_t i, - double* ar, - double* ai, - double* as, - double* br, - double* bi, - double* bs, - double* cr, - double* ci, - double* w, - double *a_aligned) -{ - gint_t j; - - ////////////// operation 1 ///////////////// - /* Split a (ar, ai) and - compute as = ar + ai */ - double* par = ar; - double* pai = ai; - double* pas = as; - - /* a matrix real and imag packing and compute. */ - bli_3m_sqp_packA_real_imag_sum(a, i, kx+p, lda, par, pai, pas, transa, mx, p); - - double* pcr = cr; - double* pci = ci; - - //Split Cr and Ci and beta multiplication done. 
- double* pc = c + i; - if(p==0) - { - bli_3m_sqp_packC_real_imag(pc, n, mx, ldc, pcr, pci, beta, mx); - } - //Ci := rgemm( SA, SB, Ci ) - gint_t istart = 0; - gint_t* p_is = &istart; - *p_is = 0; - bli_sqp_dgemm_m8(mx, n, kx, as, mx, bs, ldb, ci, mx, false, 1.0, mx, p_is, kx, a_aligned); - - ////////////// operation 2 ///////////////// - //Wr: = dgemm_sqp(Ar, Br, 0) // Wr output 8xn - double* wr = w; - for (j = 0; j < n; j++) { - for (gint_t ii = 0; ii < mx; ii += 1) { - *wr = 0; - wr++; - } - } - wr = w; - - *p_is = 0; - bli_sqp_dgemm_m8(mx, n, kx, ar, mx, br, ldb, wr, mx, false, 1.0, mx, p_is, kx, a_aligned); - //Cr : = addm(Wr, Cr) - bli_add_m(mx, n, wr, cr); - //Ci : = subm(Wr, Ci) - bli_sub_m(mx, n, wr, ci); - - - ////////////// operation 3 ///////////////// - //Wi : = dgemm_sqp(Ai, Bi, 0) // Wi output 8xn - double* wi = w; - for (j = 0; j < n; j++) { - for (gint_t ii = 0; ii < mx; ii += 1) { - *wi = 0; - wi++; - } - } - wi = w; - - *p_is = 0; - bli_sqp_dgemm_m8(mx, n, kx, ai, mx, bi, ldb, wi, mx, false, 1.0, mx, p_is, kx, a_aligned); - //Cr : = subm(Wi, Cr) - bli_sub_m(mx, n, wi, cr); - //Ci : = subm(Wi, Ci) - bli_sub_m(mx, n, wi, ci); - - pcr = cr; - pci = ci; - - for (j = 0; j < n; j++) - { - for (gint_t ii = 0; ii < (mx*2); ii += 2) - { - c[(j * ldc) + i + ii] = *pcr; - c[(j * ldc) + i + ii + 1] = *pci; - pcr++; pci++; - } - } -} - -/**************************************************************/ -/* workspace memory allocation for 3m_sqp algorithm for zgemm */ -/**************************************************************/ -err_t allocate_3m_Sqp_workspace(workspace_3m_sqp *mem_3m_sqp, - gint_t mx, - gint_t nx, - gint_t k, - gint_t kx ) -{ - //3m_sqp workspace Memory allocation - /* B matrix */ - // B matrix packed with n x k size. without kx smaller sizes for now. 
- mem_block mbr, mbi, mbs; - if(bli_allocateWorkspace(nx, k, &mbr, &mbi, &mbs)!=0) - { - return BLIS_FAILURE; - } - mem_3m_sqp->br = (double*)mbr.alignedBuf; - mem_3m_sqp->bi = (double*)mbi.alignedBuf; - mem_3m_sqp->bs = (double*)mbs.alignedBuf; - mem_3m_sqp->br_unaligned = (double*)mbr.unalignedBuf; - mem_3m_sqp->bi_unaligned = (double*)mbi.unalignedBuf; - mem_3m_sqp->bs_unaligned = (double*)mbs.unalignedBuf; - - /* Workspace memory allocation currently done dynamically - This needs to be taken from already allocated memory pool in application for better performance */ - /* A matrix */ - mem_block mar, mai, mas; - if(bli_allocateWorkspace(mx, kx, &mar, &mai, &mas) !=0) - { - return BLIS_FAILURE; - } - mem_3m_sqp->ar = (double*)mar.alignedBuf; - mem_3m_sqp->ai = (double*)mai.alignedBuf; - mem_3m_sqp->as = (double*)mas.alignedBuf; - mem_3m_sqp->ar_unaligned = (double*)mar.unalignedBuf; - mem_3m_sqp->ai_unaligned = (double*)mai.unalignedBuf; - mem_3m_sqp->as_unaligned = (double*)mas.unalignedBuf; - - /* w matrix */ - mem_block mw; - mw.data_size = sizeof(double); - mw.size = mx * nx; - if (bli_getaligned(&mw) != 0) - { - return BLIS_FAILURE; - } - mem_3m_sqp->w = (double*)mw.alignedBuf; - mem_3m_sqp->w_unaligned = (double*)mw.unalignedBuf; - /* cr matrix */ - mem_block mcr; - mcr.data_size = sizeof(double); - mcr.size = mx * nx; - if (bli_getaligned(&mcr) != 0) - { - return BLIS_FAILURE; - } - mem_3m_sqp->cr = (double*)mcr.alignedBuf; - mem_3m_sqp->cr_unaligned = (double*)mcr.unalignedBuf; - - - /* ci matrix */ - mem_block mci; - mci.data_size = sizeof(double); - mci.size = mx * nx; - if (bli_getaligned(&mci) != 0) - { - return BLIS_FAILURE; - } - mem_3m_sqp->ci = (double*)mci.alignedBuf; - mem_3m_sqp->ci_unaligned = (double*)mci.unalignedBuf; - - // A packing buffer - mem_3m_sqp->aPacked = (double*)bli_malloc_user(sizeof(double) * kx * mx); - if (mem_3m_sqp->aPacked == NULL) - { - return BLIS_FAILURE; - } - - return BLIS_SUCCESS; -} - -void free_3m_Sqp_workspace(workspace_3m_sqp *mem_3m_sqp) -{ - // A packing buffer free - bli_free_user(mem_3m_sqp->aPacked); - -#if MEM_ALLOC - if(mem_3m_sqp->ar_unaligned) - { - free(mem_3m_sqp->ar_unaligned); - } - if(mem_3m_sqp->ai_unaligned) - { - free(mem_3m_sqp->ai_unaligned); - } - if(mem_3m_sqp->as_unaligned) - { - free(mem_3m_sqp->as_unaligned); - } - - if(mem_3m_sqp->br_unaligned) - { - free(mem_3m_sqp->br_unaligned); - } - if(mem_3m_sqp->bi_unaligned) - { - free(mem_3m_sqp->bi_unaligned); - } - if(mem_3m_sqp->bs_unaligned) - { - free(mem_3m_sqp->bs_unaligned); - } - - if(mem_3m_sqp->w_unaligned) - { - free(mem_3m_sqp->w_unaligned); - } - if(mem_3m_sqp->cr_unaligned) - { - free(mem_3m_sqp->cr_unaligned); - } - if(mem_3m_sqp->ci_unaligned) - { - free(mem_3m_sqp->ci_unaligned); - } - -#else//MEM_ALLOC - /* free workspace buffers */ - bli_free_user(mem_3m_sqp->br); - bli_free_user(mem_3m_sqp->bi); - bli_free_user(mem_3m_sqp->bs); - bli_free_user(mem_3m_sqp->ar); - bli_free_user(mem_3m_sqp->ai); - bli_free_user(mem_3m_sqp->as); - bli_free_user(mem_3m_sqp->w); - bli_free_user(mem_3m_sqp->cr); - bli_free_user(mem_3m_sqp->ci); -#endif//MEM_ALLOC -} - -//3m_sqp m loop -BLIS_INLINE err_t bli_sqp_zgemm_m8( gint_t m, - gint_t n, - gint_t k, - double* a, - guint_t lda, - double* b, - guint_t ldb, - double* c, - guint_t ldc, - double alpha, - double beta, - trans_t transa, - gint_t mx, - gint_t* p_istart, - gint_t kx, - workspace_3m_sqp *mem_3m_sqp) -{ - inc_t m2 = m<<1; - inc_t mxmul2 = mx<<1; - - if((*p_istart) > (m2-mxmul2)) - { - return BLIS_SUCCESS; - } - 
inc_t i; - gint_t max_m = (m2-mxmul2); - - //get workspace - double* ar, * ai, * as; - ar = mem_3m_sqp->ar; - ai = mem_3m_sqp->ai; - as = mem_3m_sqp->as; - - double* br, * bi, * bs; - br = mem_3m_sqp->br; - bi = mem_3m_sqp->bi; - bs = mem_3m_sqp->bs; - - double* cr, * ci; - cr = mem_3m_sqp->cr; - ci = mem_3m_sqp->ci; - - double *w; - w = mem_3m_sqp->w; - - double* a_aligned; - a_aligned = mem_3m_sqp->aPacked; - - /* Split b (br, bi) and - compute bs = br + bi */ - double* pbr = br; - double* pbi = bi; - double* pbs = bs; - /* b matrix real and imag packing and compute. */ - bli_3m_sqp_packB_real_imag_sum(b, n, k, ldb, pbr, pbi, pbs, alpha, mx); - - for (i = (*p_istart); i <= max_m; i += mxmul2) //this loop can be threaded. - { -#if KLP//kloop - int p = 0; - for(; p <= (k-kx); p += kx) - { - bli_sqp_zgemm_kx(m, n, kx, p, a, lda, k, c, ldc, - transa, alpha, beta, mx, i, ar, ai, as, - br + p, bi + p, bs + p, cr, ci, w, a_aligned); - }// k loop end - - if(p>3)<<3); - - workspace_3m_sqp mem_3m_sqp; - - /* multiply lda, ldb and ldc by 2 to account for - real & imaginary components per dcomplex. */ - lda = lda * 2; - ldb = ldb * 2; - ldc = ldc * 2; - - /* user can set BLIS_MULTI_INSTANCE macro for - better performance while runing multi-instance use-case. - */ - dim_t multi_instance = bli_env_get_var( "BLIS_MULTI_INSTANCE", -1 ); - gint_t nx = n; - if(multi_instance>0) - { - //limited nx size helps in reducing memory footprint in multi-instance case. - nx = 84; - // 84 is derived based on tuning results - } - - if(nx>n) - { - nx = n; - } - - gint_t kx = k;// kx is configurable at run-time. -#if KLP - if (kx > k) - { - kx = k; - } - // for tn case there is a bug in handling k parts. To be fixed. - if(transa!=BLIS_NO_TRANSPOSE) - { - kx = k; - } -#else - kx = k; -#endif - //3m_sqp workspace Memory allocation - if(allocate_3m_Sqp_workspace(&mem_3m_sqp, mx, nx, k, kx)!=BLIS_SUCCESS) - { - return BLIS_FAILURE; - } - - BLI_SQP_ZGEMM_N(mx) - *p_istart = (m-m8rem)*2; - - if(m8rem!=0) - { - //complete residue m blocks - BLI_SQP_ZGEMM_N(m8rem) - } - - free_3m_Sqp_workspace(&mem_3m_sqp); - return status; -} - -/****************************************************************************/ -/*********************** dgemm_sqp implementation****************************/ -/****************************************************************************/ -/* dgemm_sqp implementation packs A matrix based on lda and m size. - dgemm_sqp focuses mainly on square matrixes but also supports non-square matrix. - Current support is limiteed to m multiple of 8 and column storage. - C = AxB and C = AtxB is handled in the design. - AtxB case is done by transposing A matrix while packing A. - In majority of use-case, alpha are +/-1, so instead of explicitly multiplying - alpha its done during packing itself by changing sign. 
-*/ -BLIS_INLINE err_t bli_sqp_dgemm(gint_t m, - gint_t n, - gint_t k, - double* a, - guint_t lda, - double* b, - guint_t ldb, - double* c, - guint_t ldc, - double alpha, - double beta, - bool isTransA, - dim_t nt) -{ - gint_t istart = 0; - gint_t* p_istart = &istart; - *p_istart = 0; - err_t status = BLIS_SUCCESS; - dim_t m8rem = m - ((m>>3)<<3); - - /* dgemm implementation with 8mx5n major kernel and column preferred storage */ - gint_t mx = 8; - gint_t kx = k; - double* a_aligned = NULL; - - if(nt<=1)//single pack buffer allocated for single thread case - { - a_aligned = (double*)bli_malloc_user(sizeof(double) * kx * mx); - } - - gint_t nx = n;//MAX; - if(nx>n) - { - nx = n; - } - - //mx==8 case for dgemm. - BLI_SQP_DGEMM_N(mx) - *p_istart = (m-m8rem); - - if(nt>1) - { - //2nd level thread for mx=8 - gint_t rem_m = m - (*p_istart); - if((rem_m>=mx)&&(status==BLIS_SUCCESS)) - { - status = bli_sqp_dgemm_m8( m, n, k, a, lda, b, ldb, c, ldc, - isTransA, alpha, mx, p_istart, kx, a_aligned); - } - } - - if(status==BLIS_SUCCESS) - { - if(m8rem!=0) - { - //complete residue m blocks - BLI_SQP_DGEMM_N(m8rem) - } - } - - if(nt<=1)//single pack buffer allocated for single thread case - { - bli_free_user(a_aligned); - } - return status; -} \ No newline at end of file diff --git a/kernels/zen/3/bli_gemm_sqp_kernels.c b/kernels/zen/3/bli_gemm_sqp_kernels.c deleted file mode 100644 index 0f20c0a956..0000000000 --- a/kernels/zen/3/bli_gemm_sqp_kernels.c +++ /dev/null @@ -1,1750 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. - - Copyright (C) 2021, Advanced Micro Devices, Inc. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- -*/ -#include "blis.h" -#include "immintrin.h" -#include "bli_gemm_sqp_kernels.h" - -#define BLIS_LOADFIRST 0 -#define BLIS_ENABLE_PREFETCH 1 - -#define BLIS_MX8 8 -#define BLIS_MX4 4 -#define BLIS_MX1 1 - -/****************************************************************************/ -/*************** dgemm kernels (8mxn) column preffered *********************/ -/****************************************************************************/ - -/* Main dgemm kernel 8mx6n with single load and store of C matrix block - alpha = +/-1 and beta = +/-1,0 handled while packing.*/ -inc_t bli_sqp_dgemm_kernel_8mx6n(gint_t n, - gint_t k, - gint_t j, - double* aPacked, - guint_t lda, - double* b, - guint_t ldb, - double* c, - guint_t ldc) -{ - gint_t p; - - __m256d av0, av1; - __m256d bv0, bv1; - __m256d cv0, cv1, cv2, cv3, cv4, cv5; - __m256d cx0, cx1, cx2, cx3, cx4, cx5; - double* pb, * pc; - - pb = b; - pc = c; - inc_t ldc6 = ldc * 6; inc_t ldb6 = ldb * 6; - - for (j = 0; j <= (n - 6); j += 6) { - double* pcldc = pc + ldc; - double* pcldc2 = pcldc + ldc; - double* pcldc3 = pcldc2 + ldc; - double* pcldc4 = pcldc3 + ldc; - double* pcldc5 = pcldc4 + ldc; - - double* pbldb = pb + ldb; - double* pbldb2 = pbldb + ldb; - double* pbldb3 = pbldb2 + ldb; - double* pbldb4 = pbldb3 + ldb; - double* pbldb5 = pbldb4 + ldb; - -#if BLIS_ENABLE_PREFETCH - _mm_prefetch((char*)(pc), _MM_HINT_T0); - _mm_prefetch((char*)(pcldc), _MM_HINT_T0); - _mm_prefetch((char*)(pcldc2), _MM_HINT_T0); - _mm_prefetch((char*)(pcldc3), _MM_HINT_T0); - _mm_prefetch((char*)(pcldc4), _MM_HINT_T0); - _mm_prefetch((char*)(pcldc5), _MM_HINT_T0); - - _mm_prefetch((char*)(aPacked), _MM_HINT_T0); - - _mm_prefetch((char*)(pb), _MM_HINT_T0); - _mm_prefetch((char*)(pbldb), _MM_HINT_T0); - _mm_prefetch((char*)(pbldb2), _MM_HINT_T0); - _mm_prefetch((char*)(pbldb3), _MM_HINT_T0); - _mm_prefetch((char*)(pbldb4), _MM_HINT_T0); - _mm_prefetch((char*)(pbldb5), _MM_HINT_T0); -#endif - /* C matrix column major load */ -#if BLIS_LOADFIRST - cv0 = _mm256_loadu_pd(pc); cx0 = _mm256_loadu_pd(pc + 4); - cv1 = _mm256_loadu_pd(pcldc); cx1 = _mm256_loadu_pd(pcldc + 4); - cv2 = _mm256_loadu_pd(pcldc2); cx2 = _mm256_loadu_pd(pcldc2 + 4); - cv3 = _mm256_loadu_pd(pcldc3); cx3 = _mm256_loadu_pd(pcldc3 + 4); - cv4 = _mm256_loadu_pd(pcldc4); cx4 = _mm256_loadu_pd(pcldc4 + 4); - cv5 = _mm256_loadu_pd(pcldc5); cx5 = _mm256_loadu_pd(pcldc5 + 4); -#else - cv0 = _mm256_setzero_pd(); cx0 = _mm256_setzero_pd(); - cv1 = _mm256_setzero_pd(); cx1 = _mm256_setzero_pd(); - cv2 = _mm256_setzero_pd(); cx2 = _mm256_setzero_pd(); - cv3 = _mm256_setzero_pd(); cx3 = _mm256_setzero_pd(); - cv4 = _mm256_setzero_pd(); cx4 = _mm256_setzero_pd(); - cv5 = _mm256_setzero_pd(); cx5 = _mm256_setzero_pd(); -#endif - double* x = aPacked; - double* pb0 = pb; - for (p = 0; p < k; p += 1) { - av0 = _mm256_loadu_pd(x); x += 4; av1 = _mm256_loadu_pd(x); x += 4; - bv0 = _mm256_broadcast_sd (pb0); pb0++; - bv1 = _mm256_broadcast_sd(pbldb); pbldb++; - cv0 = _mm256_fmadd_pd(av0, bv0, cv0); - cx0 = _mm256_fmadd_pd(av1, bv0, cx0); - cv1 = _mm256_fmadd_pd(av0, bv1, cv1); - cx1 = _mm256_fmadd_pd(av1, bv1, cx1); - - bv0 = _mm256_broadcast_sd(pbldb2);pbldb2++; - bv1 = _mm256_broadcast_sd(pbldb3);pbldb3++; - cv2 = _mm256_fmadd_pd(av0, bv0, cv2); - cx2 = _mm256_fmadd_pd(av1, bv0, cx2); - cv3 = _mm256_fmadd_pd(av0, bv1, cv3); - cx3 = _mm256_fmadd_pd(av1, bv1, cx3); - - bv0 = _mm256_broadcast_sd(pbldb4);pbldb4++; - bv1 = _mm256_broadcast_sd(pbldb5);pbldb5++; - cv4 = _mm256_fmadd_pd(av0, bv0, cv4); - cx4 = 
_mm256_fmadd_pd(av1, bv0, cx4); - cv5 = _mm256_fmadd_pd(av0, bv1, cv5); - cx5 = _mm256_fmadd_pd(av1, bv1, cx5); - } -#if BLIS_LOADFIRST -#else - bv0 = _mm256_loadu_pd(pc); bv1 = _mm256_loadu_pd(pc + 4); - cv0 = _mm256_add_pd(cv0, bv0); cx0 = _mm256_add_pd(cx0, bv1); - - av0 = _mm256_loadu_pd(pcldc); av1 = _mm256_loadu_pd(pcldc + 4); - cv1 = _mm256_add_pd(cv1, av0); cx1 = _mm256_add_pd(cx1, av1); - - bv0 = _mm256_loadu_pd(pcldc2); bv1 = _mm256_loadu_pd(pcldc2 + 4); - cv2 = _mm256_add_pd(cv2, bv0); cx2 = _mm256_add_pd(cx2, bv1); - - av0 = _mm256_loadu_pd(pcldc3); av1 = _mm256_loadu_pd(pcldc3 + 4); - cv3 = _mm256_add_pd(cv3, av0); cx3 = _mm256_add_pd(cx3, av1); - - bv0 = _mm256_loadu_pd(pcldc4); bv1 = _mm256_loadu_pd(pcldc4 + 4); - cv4 = _mm256_add_pd(cv4, bv0); cx4 = _mm256_add_pd(cx4, bv1); - - av0 = _mm256_loadu_pd(pcldc5); av1 = _mm256_loadu_pd(pcldc5 + 4); - cv5 = _mm256_add_pd(cv5, av0); cx5 = _mm256_add_pd(cx5, av1); -#endif - /* C matrix column major store */ - _mm256_storeu_pd(pc, cv0); - _mm256_storeu_pd(pc + 4, cx0); - - _mm256_storeu_pd(pcldc, cv1); - _mm256_storeu_pd(pcldc + 4, cx1); - - _mm256_storeu_pd(pcldc2, cv2); - _mm256_storeu_pd(pcldc2 + 4, cx2); - - _mm256_storeu_pd(pcldc3, cv3); - _mm256_storeu_pd(pcldc3 + 4, cx3); - - _mm256_storeu_pd(pcldc4, cv4); - _mm256_storeu_pd(pcldc4 + 4, cx4); - - _mm256_storeu_pd(pcldc5, cv5); - _mm256_storeu_pd(pcldc5 + 4, cx5); - - pc += ldc6;pb += ldb6; - } - //printf(" 8x6:j:%d ", j); - return j; -} - -/* alternative Main dgemm kernel 8mx5n with single load and store of C matrix block - alpha = +/-1 and beta = +/-1,0 handled while packing.*/ -inc_t bli_sqp_dgemm_kernel_8mx5n( gint_t n, - gint_t k, - gint_t j, - double* aPacked, - guint_t lda, - double* b, - guint_t ldb, - double* c, - guint_t ldc) -{ - gint_t p; - __m256d av0; - __m256d bv0, bv1, bv2, bv3; - __m256d cv0, cv1, cv2, cv3; - __m256d cx0, cx1, cx2, cx3; - __m256d bv4, cv4, cx4; - double* pb, * pc; - - pb = b; - pc = c; - inc_t ldc5 = ldc * 5; inc_t ldb5 = ldb * 5; - - for (; j <= (n - 5); j += 5) { - - double* pcldc = pc + ldc; - double* pcldc2 = pcldc + ldc; - double* pcldc3 = pcldc2 + ldc; - double* pcldc4 = pcldc3 + ldc; - - double* pbldb = pb + ldb; - double* pbldb2 = pbldb + ldb; - double* pbldb3 = pbldb2 + ldb; - double* pbldb4 = pbldb3 + ldb; - -#if BLIS_ENABLE_PREFETCH - _mm_prefetch((char*)(pc), _MM_HINT_T0); - _mm_prefetch((char*)(pcldc), _MM_HINT_T0); - _mm_prefetch((char*)(pcldc2), _MM_HINT_T0); - _mm_prefetch((char*)(pcldc3), _MM_HINT_T0); - _mm_prefetch((char*)(pcldc4), _MM_HINT_T0); - - _mm_prefetch((char*)(aPacked), _MM_HINT_T0); - - _mm_prefetch((char*)(pb), _MM_HINT_T0); - _mm_prefetch((char*)(pbldb), _MM_HINT_T0); - _mm_prefetch((char*)(pbldb2), _MM_HINT_T0); - _mm_prefetch((char*)(pbldb3), _MM_HINT_T0); - _mm_prefetch((char*)(pbldb4), _MM_HINT_T0); -#endif - /* C matrix column major load */ -#if BLIS_LOADFIRST - cv0 = _mm256_loadu_pd(pc); cx0 = _mm256_loadu_pd(pc + 4); - cv1 = _mm256_loadu_pd(pcldc); cx1 = _mm256_loadu_pd(pcldc + 4); - cv2 = _mm256_loadu_pd(pcldc2); cx2 = _mm256_loadu_pd(pcldc2 + 4); - cv3 = _mm256_loadu_pd(pcldc3); cx3 = _mm256_loadu_pd(pcldc3 + 4); - cv4 = _mm256_loadu_pd(pcldc4); cx4 = _mm256_loadu_pd(pcldc4 + 4); -#else - cv0 = _mm256_setzero_pd(); cx0 = _mm256_setzero_pd(); - cv1 = _mm256_setzero_pd(); cx1 = _mm256_setzero_pd(); - cv2 = _mm256_setzero_pd(); cx2 = _mm256_setzero_pd(); - cv3 = _mm256_setzero_pd(); cx3 = _mm256_setzero_pd(); - cv4 = _mm256_setzero_pd(); cx4 = _mm256_setzero_pd(); -#endif - double* x = aPacked; - double* 
pb0 = pb; - for (p = 0; p < k; p += 1) { - bv0 = _mm256_broadcast_sd(pb0); pb0++; - bv1 = _mm256_broadcast_sd(pbldb); pbldb++; - bv2 = _mm256_broadcast_sd(pbldb2); pbldb2++; - bv3 = _mm256_broadcast_sd(pbldb3);pbldb3++; - bv4 = _mm256_broadcast_sd(pbldb4);pbldb4++; - - av0 = _mm256_loadu_pd(x); x += 4; - cv0 = _mm256_fmadd_pd(av0, bv0, cv0); - cv1 = _mm256_fmadd_pd(av0, bv1, cv1); - cv2 = _mm256_fmadd_pd(av0, bv2, cv2); - cv3 = _mm256_fmadd_pd(av0, bv3, cv3); - cv4 = _mm256_fmadd_pd(av0, bv4, cv4); - - av0 = _mm256_loadu_pd(x); x += 4; - cx0 = _mm256_fmadd_pd(av0, bv0, cx0); - cx1 = _mm256_fmadd_pd(av0, bv1, cx1); - cx2 = _mm256_fmadd_pd(av0, bv2, cx2); - cx3 = _mm256_fmadd_pd(av0, bv3, cx3); - cx4 = _mm256_fmadd_pd(av0, bv4, cx4); - } -#if BLIS_LOADFIRST -#else - bv0 = _mm256_loadu_pd(pc); bv1 = _mm256_loadu_pd(pc + 4); - cv0 = _mm256_add_pd(cv0, bv0); cx0 = _mm256_add_pd(cx0, bv1); - - bv2 = _mm256_loadu_pd(pcldc); bv3 = _mm256_loadu_pd(pcldc + 4); - cv1 = _mm256_add_pd(cv1, bv2); cx1 = _mm256_add_pd(cx1, bv3); - - bv0 = _mm256_loadu_pd(pcldc2); bv1 = _mm256_loadu_pd(pcldc2 + 4); - cv2 = _mm256_add_pd(cv2, bv0); cx2 = _mm256_add_pd(cx2, bv1); - - bv2 = _mm256_loadu_pd(pcldc3); bv3 = _mm256_loadu_pd(pcldc3 + 4); - cv3 = _mm256_add_pd(cv3, bv2); cx3 = _mm256_add_pd(cx3, bv3); - - bv0 = _mm256_loadu_pd(pcldc4); bv1 = _mm256_loadu_pd(pcldc4 + 4); - cv4 = _mm256_add_pd(cv4, bv0); cx4 = _mm256_add_pd(cx4, bv1); -#endif - /* C matrix column major store */ - _mm256_storeu_pd(pc, cv0); - _mm256_storeu_pd(pc + 4, cx0); - - _mm256_storeu_pd(pcldc, cv1); - _mm256_storeu_pd(pcldc + 4, cx1); - - _mm256_storeu_pd(pcldc2, cv2); - _mm256_storeu_pd(pcldc2 + 4, cx2); - - _mm256_storeu_pd(pcldc3, cv3); - _mm256_storeu_pd(pcldc3 + 4, cx3); - - _mm256_storeu_pd(pcldc4, cv4); - _mm256_storeu_pd(pcldc4 + 4, cx4); - - pc += ldc5;pb += ldb5; - } - //printf(" 8x5:j:%d ", j); - return j; -} - -/* residue dgemm kernel 8mx4n with single load and store of C matrix block - Code could be optimized further, complete ymm register set is not used. - Being residue kernel, its of lesser priority. -*/ -inc_t bli_sqp_dgemm_kernel_8mx4n( gint_t n, - gint_t k, - gint_t j, - double* aPacked, - guint_t lda, - double* b, - guint_t ldb, - double* c, - guint_t ldc) -{ - gint_t p; - __m256d av0; - __m256d bv0, bv1, bv2, bv3; - __m256d cv0, cv1, cv2, cv3; - __m256d cx0, cx1, cx2, cx3; - double* pb, * pc; - - pb = b; - pc = c; - inc_t ldc4 = ldc * 4; inc_t ldb4 = ldb * 4; - - for (; j <= (n - 4); j += 4) { - - double* pcldc = pc + ldc; double* pcldc2 = pcldc + ldc; double* pcldc3 = pcldc2 + ldc; - double* pbldb = pb + ldb; double* pbldb2 = pbldb + ldb; double* pbldb3 = pbldb2 + ldb; - - cv0 = _mm256_loadu_pd(pc); cx0 = _mm256_loadu_pd(pc + 4); - cv1 = _mm256_loadu_pd(pcldc); cx1 = _mm256_loadu_pd(pcldc + 4); - cv2 = _mm256_loadu_pd(pcldc2); cx2 = _mm256_loadu_pd(pcldc2 + 4); - cv3 = _mm256_loadu_pd(pcldc3); cx3 = _mm256_loadu_pd(pcldc3 + 4); - { - double* x = aPacked; - double* pb0 = pb; - for (p = 0; p < k; p += 1) { - // better kernel to be written since more register are available. 
- bv0 = _mm256_broadcast_sd(pb0); pb0++; - bv1 = _mm256_broadcast_sd(pbldb); pbldb++; - bv2 = _mm256_broadcast_sd(pbldb2); pbldb2++; - bv3 = _mm256_broadcast_sd(pbldb3); pbldb3++; - - av0 = _mm256_loadu_pd(x); x += 4; - cv0 = _mm256_fmadd_pd(av0, bv0, cv0); - cv1 = _mm256_fmadd_pd(av0, bv1, cv1); - cv2 = _mm256_fmadd_pd(av0, bv2, cv2); - cv3 = _mm256_fmadd_pd(av0, bv3, cv3); - - av0 = _mm256_loadu_pd(x); x += 4; - cx0 = _mm256_fmadd_pd(av0, bv0, cx0); - cx1 = _mm256_fmadd_pd(av0, bv1, cx1); - cx2 = _mm256_fmadd_pd(av0, bv2, cx2); - cx3 = _mm256_fmadd_pd(av0, bv3, cx3); - } - } - _mm256_storeu_pd(pc, cv0); - _mm256_storeu_pd(pc + 4, cx0); - _mm256_storeu_pd(pcldc, cv1); - _mm256_storeu_pd(pcldc + 4, cx1); - _mm256_storeu_pd(pcldc2, cv2); - _mm256_storeu_pd(pcldc2 + 4, cx2); - _mm256_storeu_pd(pcldc3, cv3); - _mm256_storeu_pd(pcldc3 + 4, cx3); - - pc += ldc4;pb += ldb4; - }// j loop 4 multiple - //printf(" 8x4:j:%d ", j); - return j; -} - -/* residue dgemm kernel 8mx3n with single load and store of C matrix block - Code could be optimized further, complete ymm register set is not used. - Being residue kernel, its of lesser priority. -*/ -inc_t bli_sqp_dgemm_kernel_8mx3n( gint_t n, - gint_t k, - gint_t j, - double* aPacked, - guint_t lda, - double* b, - guint_t ldb, - double* c, - guint_t ldc) -{ - gint_t p; - __m256d av0; - __m256d bv0, bv1, bv2; - __m256d cv0, cv1, cv2; - __m256d cx0, cx1, cx2; - double* pb, * pc; - - pb = b; - pc = c; - - inc_t ldc3 = ldc * 3; inc_t ldb3 = ldb * 3; - - for (; j <= (n - 3); j += 3) { - - double* pcldc = pc + ldc; double* pcldc2 = pcldc + ldc; - double* pbldb = pb + ldb; double* pbldb2 = pbldb + ldb; - - cv0 = _mm256_loadu_pd(pc); cx0 = _mm256_loadu_pd(pc + 4); - cv1 = _mm256_loadu_pd(pcldc); cx1 = _mm256_loadu_pd(pcldc + 4); - cv2 = _mm256_loadu_pd(pcldc2); cx2 = _mm256_loadu_pd(pcldc2 + 4); - { - double* x = aPacked; - double* pb0 = pb; - for (p = 0; p < k; p += 1) { - bv0 = _mm256_broadcast_sd(pb0); pb0++; - bv1 = _mm256_broadcast_sd(pbldb); pbldb++; - bv2 = _mm256_broadcast_sd(pbldb2); pbldb2++; - - av0 = _mm256_loadu_pd(x); x += 4; - cv0 = _mm256_fmadd_pd(av0, bv0, cv0); - cv1 = _mm256_fmadd_pd(av0, bv1, cv1); - cv2 = _mm256_fmadd_pd(av0, bv2, cv2); - - av0 = _mm256_loadu_pd(x); x += 4; - cx0 = _mm256_fmadd_pd(av0, bv0, cx0); - cx1 = _mm256_fmadd_pd(av0, bv1, cx1); - cx2 = _mm256_fmadd_pd(av0, bv2, cx2); - } - } - - _mm256_storeu_pd(pc, cv0); - _mm256_storeu_pd(pc + 4, cx0); - _mm256_storeu_pd(pcldc, cv1); - _mm256_storeu_pd(pcldc + 4, cx1); - _mm256_storeu_pd(pcldc2, cv2); - _mm256_storeu_pd(pcldc2 + 4, cx2); - - pc += ldc3;pb += ldb3; - }// j loop 3 multiple - //printf(" 8x3:j:%d ", j); - return j; -} - -/* residue dgemm kernel 8mx2n with single load and store of C matrix block - Code could be optimized further, complete ymm register set is not used. - Being residue kernel, its of lesser priority. 
-*/ -inc_t bli_sqp_dgemm_kernel_8mx2n( gint_t n, - gint_t k, - gint_t j, - double* aPacked, - guint_t lda, - double* b, - guint_t ldb, - double* c, - guint_t ldc) -{ - gint_t p; - __m256d av0; - __m256d bv0, bv1; - __m256d cv0, cv1; - __m256d cx0, cx1; - double* pb, * pc; - - pb = b; - pc = c; - inc_t ldc2 = ldc * 2; inc_t ldb2 = ldb * 2; - - for (; j <= (n - 2); j += 2) { - double* pcldc = pc + ldc; - double* pbldb = pb + ldb; - - cv0 = _mm256_loadu_pd(pc); cx0 = _mm256_loadu_pd(pc + 4); - cv1 = _mm256_loadu_pd(pcldc); cx1 = _mm256_loadu_pd(pcldc + 4); - { - double* x = aPacked; - double* pb0 = pb; - for (p = 0; p < k; p += 1) { - bv0 = _mm256_broadcast_sd(pb0); pb0++; - bv1 = _mm256_broadcast_sd(pbldb); pbldb++; - - av0 = _mm256_loadu_pd(x); x += 4; - cv0 = _mm256_fmadd_pd(av0, bv0, cv0); - cv1 = _mm256_fmadd_pd(av0, bv1, cv1); - - av0 = _mm256_loadu_pd(x); x += 4; - cx0 = _mm256_fmadd_pd(av0, bv0, cx0); - cx1 = _mm256_fmadd_pd(av0, bv1, cx1); - } - } - _mm256_storeu_pd(pc, cv0); - _mm256_storeu_pd(pc + 4, cx0); - _mm256_storeu_pd(pcldc, cv1); - _mm256_storeu_pd(pcldc + 4, cx1); - - pc += ldc2;pb += ldb2; - }// j loop 2 multiple - //printf(" 8x2:j:%d ", j); - return j; -} - -/* residue dgemm kernel 8mx1n with single load and store of C matrix block - Code could be optimized further, complete ymm register set is not used. - Being residue kernel, its of lesser priority. -*/ -inc_t bli_sqp_dgemm_kernel_8mx1n( gint_t n, - gint_t k, - gint_t j, - double* aPacked, - guint_t lda, - double* b, - guint_t ldb, - double* c, - guint_t ldc) -{ - gint_t p; - __m256d av0; - __m256d bv0; - __m256d cv0; - __m256d cx0; - double* pb, * pc; - - pb = b; - pc = c; - - for (; j <= (n - 1); j += 1) { - cv0 = _mm256_loadu_pd(pc); cx0 = _mm256_loadu_pd(pc + 4); - double* x = aPacked; - double* pb0 = pb; - for (p = 0; p < k; p += 1) { - bv0 = _mm256_broadcast_sd(pb0); pb0++; - - av0 = _mm256_loadu_pd(x); x += 4; - cv0 = _mm256_fmadd_pd(av0, bv0, cv0); - - av0 = _mm256_loadu_pd(x); x += 4; - cx0 = _mm256_fmadd_pd(av0, bv0, cx0); - } - _mm256_storeu_pd(pc, cv0); - _mm256_storeu_pd(pc + 4, cx0); - pc += ldc;pb += ldb; - }// j loop 1 multiple - //printf(" 8x1:j:%d ", j); - return j; -} - -#if 0 -/************************************************************************************************************/ -/************************** dgemm kernels (4mxn) column preffered ******************************************/ -/************************************************************************************************************/ -/* Residue dgemm kernel 4mx10n with single load and store of C matrix block - alpha = +/-1 and beta = +/-1,0 handled while packing.*/ -inc_t bli_sqp_dgemm_kernel_4mx10n( gint_t n, - gint_t k, - gint_t j, - double* aPacked, - guint_t lda, - double* b, - guint_t ldb, - double* c, - guint_t ldc) -{ - gint_t p; - /* incomplete */ - __m256d av0; - __m256d bv0, bv1, bv2, bv3; - __m256d cv0, cv1, cv2, cv3; - __m256d cx0, cx1, cx2, cx3; - __m256d bv4, cv4, cx4; - double* pb, * pc; - - pb = b; - pc = c; - inc_t ldc10 = ldc * 10; inc_t ldb10 = ldb * 10; - - for (j = 0; j <= (n - 10); j += 10) { - - double* pcldc = pc + ldc; double* pcldc2 = pcldc + ldc; double* pcldc3 = pcldc2 + ldc; double* pcldc4 = pcldc3 + ldc; - double* pbldb = pb + ldb; double* pbldb2 = pbldb + ldb; double* pbldb3 = pbldb2 + ldb; double* pbldb4 = pbldb3 + ldb; - -#if BLIS_ENABLE_PREFETCH - _mm_prefetch((char*)(pc), _MM_HINT_T0); - _mm_prefetch((char*)(pcldc), _MM_HINT_T0); - _mm_prefetch((char*)(pcldc2), _MM_HINT_T0); - 
_mm_prefetch((char*)(pcldc3), _MM_HINT_T0); - _mm_prefetch((char*)(pcldc4), _MM_HINT_T0); - - _mm_prefetch((char*)(aPacked), _MM_HINT_T0); - - _mm_prefetch((char*)(pb), _MM_HINT_T0); - _mm_prefetch((char*)(pbldb), _MM_HINT_T0); - _mm_prefetch((char*)(pbldb2), _MM_HINT_T0); - _mm_prefetch((char*)(pbldb3), _MM_HINT_T0); - _mm_prefetch((char*)(pbldb4), _MM_HINT_T0); -#endif - /* C matrix column major load */ -#if BLIS_LOADFIRST - cv0 = _mm256_loadu_pd(pc); - cv1 = _mm256_loadu_pd(pcldc); - cv2 = _mm256_loadu_pd(pcldc2); - cv3 = _mm256_loadu_pd(pcldc3); - cv4 = _mm256_loadu_pd(pcldc4); -#else - cv0 = _mm256_setzero_pd(); - cv1 = _mm256_setzero_pd(); - cv2 = _mm256_setzero_pd(); - cv3 = _mm256_setzero_pd(); - cv4 = _mm256_setzero_pd(); -#endif - double* x = aPacked; - double* pb0 = pb; - for (p = 0; p < k; p += 1) { - bv0 = _mm256_broadcast_sd(pb0); pb0++; - bv1 = _mm256_broadcast_sd(pbldb); pbldb++; - bv2 = _mm256_broadcast_sd(pbldb2); pbldb2++; - bv3 = _mm256_broadcast_sd(pbldb3);pbldb3++; - bv4 = _mm256_broadcast_sd(pbldb4);pbldb4++; - - av0 = _mm256_loadu_pd(x); x += 4; - cv0 = _mm256_fmadd_pd(av0, bv0, cv0); - cv1 = _mm256_fmadd_pd(av0, bv1, cv1); - cv2 = _mm256_fmadd_pd(av0, bv2, cv2); - cv3 = _mm256_fmadd_pd(av0, bv3, cv3); - cv4 = _mm256_fmadd_pd(av0, bv4, cv4); - - } -#if BLIS_LOADFIRST -#else - bv0 = _mm256_loadu_pd(pc); - cv0 = _mm256_add_pd(cv0, bv0); - - bv2 = _mm256_loadu_pd(pcldc); - cv1 = _mm256_add_pd(cv1, bv2); - - bv0 = _mm256_loadu_pd(pcldc2); - cv2 = _mm256_add_pd(cv2, bv0); - - bv2 = _mm256_loadu_pd(pcldc3); - cv3 = _mm256_add_pd(cv3, bv2); - - bv0 = _mm256_loadu_pd(pcldc4); - cv4 = _mm256_add_pd(cv4, bv0); -#endif - /* C matrix column major store */ - _mm256_storeu_pd(pc, cv0); - _mm256_storeu_pd(pcldc, cv1); - _mm256_storeu_pd(pcldc2, cv2); - _mm256_storeu_pd(pcldc3, cv3); - _mm256_storeu_pd(pcldc4, cv4); - - - pc += ldc10;pb += ldb10; - } - - return j; -} - -/* residue dgemm kernel 4mx1n with single load and store of C matrix block - Code could be optimized further, complete ymm register set is not used. - Being residue kernel, its of lesser priority. -*/ -inc_t bli_sqp_dgemm_kernel_4mx1n( gint_t n, - gint_t k, - gint_t j, - double* aPacked, - guint_t lda, - double* b, - guint_t ldb, - double* c, - guint_t ldc) -{ - gint_t p; - __m256d av0; - __m256d bv0; - __m256d cv0; - double* pb, * pc; - - pb = b; - pc = c; - - for (; j <= (n - 1); j += 1) { - cv0 = _mm256_loadu_pd(pc); - double* x = aPacked; - double* pb0 = pb; - for (p = 0; p < k; p += 1) { - bv0 = _mm256_broadcast_sd(pb0); pb0++; - av0 = _mm256_loadu_pd(x); x += 4; - cv0 = _mm256_fmadd_pd(av0, bv0, cv0); - } - _mm256_storeu_pd(pc, cv0); - pc += ldc;pb += ldb; - }// j loop 1 multiple - return j; -} - -#endif -/************************************************************************************************************/ -/************************** dgemm kernels (1mxn) column preffered ******************************************/ -/************************************************************************************************************/ - -/* residue dgemm kernel 1mx1n with single load and store of C matrix block - Code could be optimized further, complete ymm register set is not used. - Being residue kernel, its of lesser priority. 
-*/ -inc_t bli_sqp_dgemm_kernel_1mx1n( gint_t n, - gint_t k, - gint_t j, - double* aPacked, - guint_t lda, - double* b, - guint_t ldb, - double* c, - guint_t ldc) -{ - gint_t p; - double a0; - double b0; - double c0; - double* pb, * pc; - - pb = b; - pc = c; - - for (; j <= (n - 1); j += 1) { - c0 = *pc; - double* x = aPacked; - double* pb0 = pb; - for (p = 0; p < k; p += 1) { - b0 = *pb0; pb0++; - a0 = *x; x++; - c0 += (a0 * b0); - } - *pc = c0; - pc += ldc;pb += ldb; - }// j loop 1 multiple - //printf(" 1x1:j:%d ", j); - return j; -} - -inc_t bli_sqp_dgemm_kernel_mxn( gint_t n, - gint_t k, - gint_t j, - double* aPacked, - guint_t lda, - double* b, - guint_t ldb, - double* c, - guint_t ldc, - gint_t mx) -{ - gint_t p; - double cx[7]; - - double* pb, * pc; - - pb = b; - pc = c; - - for (; j <= (n - 1); j += 1) { - //cv0 = _mm256_loadu_pd(pc); - for (int i = 0; i < mx; i++) - { - cx[i] = *(pc + i); - } - - double* x = aPacked; - double* pb0 = pb; - for (p = 0; p < k; p += 1) { - //bv0 = _mm256_broadcast_sd(pb0); - double b0 = *pb0; - pb0++; - for (int i = 0; i < mx; i++) - { - cx[i] += (*(x + i)) * b0;//cv0 = _mm256_fmadd_pd(av0, bv0, cv0); - } - //av0 = _mm256_loadu_pd(x); - x += mx; - } - //_mm256_storeu_pd(pc, cv0); - for (int i = 0; i < mx; i++) - { - *(pc + i) = cx[i]; - } - pc += ldc;pb += ldb; - }// j loop 1 multiple - //printf(" mx1:j:%d ", j); - return j; -} - -void bli_sqp_prepackA( double* pa, - double* aPacked, - gint_t k, - guint_t lda, - bool isTransA, - double alpha, - gint_t mx) -{ - //printf(" pmx:%d ",mx); - if(mx==8) - { - bli_prepackA_8(pa,aPacked,k, lda,isTransA, alpha); - } - else if(mx==4) - { - bli_prepackA_4(pa,aPacked,k, lda,isTransA, alpha); - } - else if(mx>4) - { - bli_prepackA_G4(pa,aPacked,k, lda,isTransA, alpha, mx); - } - else - { - bli_prepackA_L4(pa,aPacked,k, lda,isTransA, alpha, mx); - } -} - -/* Ax8 packing subroutine */ -void bli_prepackA_8(double* pa, - double* aPacked, - gint_t k, - guint_t lda, - bool isTransA, - double alpha) -{ - __m256d av0, av1, ymm0; - if(isTransA==false) - { - if(alpha==1.0) - { - for (gint_t p = 0; p < k; p += 1) { - av0 = _mm256_loadu_pd(pa); av1 = _mm256_loadu_pd(pa + 4); pa += lda; - _mm256_storeu_pd(aPacked, av0); _mm256_storeu_pd(aPacked + 4, av1); - aPacked += BLIS_MX8; - } - } - else if(alpha==-1.0) - { - ymm0 = _mm256_setzero_pd();//set zero - for (gint_t p = 0; p < k; p += 1) { - av0 = _mm256_loadu_pd(pa); av1 = _mm256_loadu_pd(pa + 4); pa += lda; - av0 = _mm256_sub_pd(ymm0,av0); av1 = _mm256_sub_pd(ymm0,av1); // a = 0 - a; - _mm256_storeu_pd(aPacked, av0); _mm256_storeu_pd(aPacked + 4, av1); - aPacked += BLIS_MX8; - } - } - } - else //subroutine below to be optimized - { - if(alpha==1.0) - { - //A Transpose case: - for (gint_t i = 0; i < BLIS_MX8 ; i++) - { - gint_t idx = i * lda; - for (gint_t p = 0; p < k; p ++) - { - double ar_ = *(pa+idx+p); - gint_t sidx = p * BLIS_MX8; - *(aPacked + sidx + i) = ar_; - } - } - } - else if(alpha==-1.0) - { - //A Transpose case: - for (gint_t i = 0; i < BLIS_MX8 ; i++) - { - gint_t idx = i * lda; - for (gint_t p = 0; p < k; p ++) - { - double ar_ = *(pa+idx+p); - gint_t sidx = p * BLIS_MX8; - *(aPacked + sidx + i) = -ar_; - } - } - } - } -} - -/* Ax4 packing subroutine */ -void bli_prepackA_4(double* pa, - double* aPacked, - gint_t k, - guint_t lda, - bool isTransA, - double alpha) -{ - __m256d av0, ymm0; - if(isTransA==false) - { - if(alpha==1.0) - { - for (gint_t p = 0; p < k; p += 1) { - av0 = _mm256_loadu_pd(pa); pa += lda; - _mm256_storeu_pd(aPacked, av0); - aPacked += 
BLIS_MX4; - } - } - else if(alpha==-1.0) - { - ymm0 = _mm256_setzero_pd();//set zero - for (gint_t p = 0; p < k; p += 1) { - av0 = _mm256_loadu_pd(pa); pa += lda; - av0 = _mm256_sub_pd(ymm0,av0); // a = 0 - a; - _mm256_storeu_pd(aPacked, av0); - aPacked += BLIS_MX4; - } - } - } - else //subroutine below to be optimized - { - if(alpha==1.0) - { - //A Transpose case: - for (gint_t i = 0; i < BLIS_MX4 ; i++) - { - gint_t idx = i * lda; - for (gint_t p = 0; p < k; p ++) - { - double ar_ = *(pa+idx+p); - gint_t sidx = p * BLIS_MX4; - *(aPacked + sidx + i) = ar_; - } - } - } - else if(alpha==-1.0) - { - //A Transpose case: - for (gint_t i = 0; i < BLIS_MX4 ; i++) - { - gint_t idx = i * lda; - for (gint_t p = 0; p < k; p ++) - { - double ar_ = *(pa+idx+p); - gint_t sidx = p * BLIS_MX4; - *(aPacked + sidx + i) = -ar_; - } - } - } - } - -} - -/* A packing m>4 subroutine */ -void bli_prepackA_G4( double* pa, - double* aPacked, - gint_t k, - guint_t lda, - bool isTransA, - double alpha, - gint_t mx) -{ - __m256d av0, ymm0; - gint_t mrem = mx - 4; - - if(isTransA==false) - { - if(alpha==1.0) - { - for (gint_t p = 0; p < k; p += 1) { - av0 = _mm256_loadu_pd(pa); - _mm256_storeu_pd(aPacked, av0); - for (gint_t i = 0; i < mrem; i += 1) { - *(aPacked+4+i) = *(pa+4+i); - } - aPacked += mx;pa += lda; - } - } - else if(alpha==-1.0) - { - ymm0 = _mm256_setzero_pd();//set zero - for (gint_t p = 0; p < k; p += 1) { - av0 = _mm256_loadu_pd(pa); - av0 = _mm256_sub_pd(ymm0,av0); // a = 0 - a; - _mm256_storeu_pd(aPacked, av0); - for (gint_t i = 0; i < mrem; i += 1) { - *(aPacked+4+i) = -*(pa+4+i); - } - aPacked += mx;pa += lda; - } - } - } - else //subroutine below to be optimized - { - if(alpha==1.0) - { - //A Transpose case: - for (gint_t i = 0; i < mx ; i++) - { - gint_t idx = i * lda; - for (gint_t p = 0; p < k; p ++) - { - double ar_ = *(pa+idx+p); - gint_t sidx = p * mx; - *(aPacked + sidx + i) = ar_; - } - } - } - else if(alpha==-1.0) - { - //A Transpose case: - for (gint_t i = 0; i < mx ; i++) - { - gint_t idx = i * lda; - for (gint_t p = 0; p < k; p ++) - { - double ar_ = *(pa+idx+p); - gint_t sidx = p * mx; - *(aPacked + sidx + i) = -ar_; - } - } - } - } - -} - -/* A packing m<4 subroutine */ -void bli_prepackA_L4( double* pa, - double* aPacked, - gint_t k, - guint_t lda, - bool isTransA, - double alpha, - gint_t mx) -{ - if(isTransA==false) - { - if(alpha==1.0) - { - for (gint_t p = 0; p < k; p += 1) - { - for (gint_t i = 0; i < mx; i += 1) - { - *(aPacked+i) = *(pa+i); - } - aPacked += mx;pa += lda; - } - } - else if(alpha==-1.0) - { - for (gint_t p = 0; p < k; p += 1) - { - for (gint_t i = 0; i < mx; i += 1) - { - *(aPacked+i) = -*(pa+i); - } - aPacked += mx;pa += lda; - } - } - } - else - { - if(alpha==1.0) - { - //A Transpose case: - for (gint_t i = 0; i < mx ; i++) - { - gint_t idx = i * lda; - for (gint_t p = 0; p < k; p ++) - { - double ar_ = *(pa+idx+p); - gint_t sidx = p * mx; - *(aPacked + sidx + i) = ar_; - } - } - } - else if(alpha==-1.0) - { - //A Transpose case: - for (gint_t i = 0; i < mx ; i++) - { - gint_t idx = i * lda; - for (gint_t p = 0; p < k; p ++) - { - double ar_ = *(pa+idx+p); - gint_t sidx = p * mx; - *(aPacked + sidx + i) = -ar_; - } - } - } - } - - -} - -/* Ax1 packing subroutine */ -void bli_prepackA_1(double* pa, - double* aPacked, - gint_t k, - guint_t lda, - bool isTransA, - double alpha) -{ - if(isTransA==false) - { - if(alpha==1.0) - { - for (gint_t p = 0; p < k; p += 1) { - *aPacked = *pa; - pa += lda; - aPacked++; - } - } - else if(alpha==-1.0) - { - for (gint_t p = 
0; p < k; p += 1) { - *aPacked = -(*pa); - pa += lda; - aPacked++; - } - } - } - else - { - if(alpha==1.0) - { - //A Transpose case: - for (gint_t p = 0; p < k; p ++) - { - double ar_ = *(pa+p); - *(aPacked + p) = ar_; - } - } - else if(alpha==-1.0) - { - //A Transpose case: - for (gint_t p = 0; p < k; p ++) - { - double ar_ = *(pa+p); - *(aPacked + p) = -ar_; - } - } - } -} - - -void bli_add_m( gint_t m, - gint_t n, - double* w, - double* c) -{ - double* pc = c; - double* pw = w; - gint_t count = m*n; - gint_t i = 0; - __m256d cv0, wv0; - - for (; i <= (count-4); i+=4) - { - cv0 = _mm256_loadu_pd(pc); - wv0 = _mm256_loadu_pd(pw); pw += 4; - cv0 = _mm256_add_pd(cv0,wv0); - _mm256_storeu_pd(pc, cv0); pc += 4; - } - for (; i < count; i++) - { - *pc = *pc + *pw; - pc++; pw++; - } -} - -void bli_sub_m( gint_t m, - gint_t n, - double* w, - double* c) -{ - double* pc = c; - double* pw = w; - gint_t count = m*n; - gint_t i = 0; - __m256d cv0, wv0; - - for (; i <= (count-4); i+=4) - { - cv0 = _mm256_loadu_pd(pc); - wv0 = _mm256_loadu_pd(pw); pw += 4; - cv0 = _mm256_sub_pd(cv0,wv0); - _mm256_storeu_pd(pc, cv0); pc += 4; - } - for (; i < count; i++) - { - *pc = *pc - *pw; - pc++; pw++; - } -} - -/* Pack real and imaginary parts in separate buffers and also multipy with multiplication factor */ -void bli_3m_sqp_packC_real_imag(double* pc, - guint_t n, - guint_t m, - guint_t ldc, - double* pcr, - double* pci, - double mul, - gint_t mx) -{ - gint_t j, p; - __m256d av0, av1, zerov; - __m256d tv0, tv1; - gint_t max_m = (m*2)-8; - - if((mul ==1.0)||(mul==-1.0)) - { - if(mul ==1.0) /* handles alpha or beta = 1.0 */ - { - for (j = 0; j < n; j++) - { - for (p = 0; p <= max_m; p += 8) - { - double* pbp = pc + p; - av0 = _mm256_loadu_pd(pbp); //ai1, ar1, ai0, ar0 - av1 = _mm256_loadu_pd(pbp+4); //ai3, ar3, ai2, ar2 - - tv0 = _mm256_permute2f128_pd(av0, av1, 0x20);//ai2, ar2, ai0, ar0 - tv1 = _mm256_permute2f128_pd(av0, av1, 0x31);//ai3, ar3, ai1, ar1 - av0 = _mm256_unpacklo_pd(tv0, tv1);//ar3, ar2, ar1, ar0 - av1 = _mm256_unpackhi_pd(tv0, tv1);//ai3, ai2, ai1, ai0 - - _mm256_storeu_pd(pcr, av0); pcr += 4; - _mm256_storeu_pd(pci, av1); pci += 4; - } - - for (; p < (m*2); p += 2)// (real + imag)*m - { - double br = *(pc + p) ; - double bi = *(pc + p + 1); - *pcr = br; - *pci = bi; - pcr++; pci++; - } - pc = pc + ldc; - } - } - else /* handles alpha or beta = - 1.0 */ - { - zerov = _mm256_setzero_pd(); - for (j = 0; j < n; j++) - { - for (p = 0; p <= max_m; p += 8) - { - double* pbp = pc + p; - av0 = _mm256_loadu_pd(pbp); //ai1, ar1, ai0, ar0 - av1 = _mm256_loadu_pd(pbp+4);//ai3, ar3, ai2, ar2 - - tv0 = _mm256_permute2f128_pd(av0, av1, 0x20);//ai2, ar2, ai0, ar0 - tv1 = _mm256_permute2f128_pd(av0, av1, 0x31);//ai3, ar3, ai1, ar1 - av0 = _mm256_unpacklo_pd(tv0, tv1);//ar3, ar2, ar1, ar0 - av1 = _mm256_unpackhi_pd(tv0, tv1);//ai3, ai2, ai1, ai0 - - //negate - av0 = _mm256_sub_pd(zerov,av0); - av1 = _mm256_sub_pd(zerov,av1); - - _mm256_storeu_pd(pcr, av0); pcr += 4; - _mm256_storeu_pd(pci, av1); pci += 4; - } - - for (; p < (m*2); p += 2)// (real + imag)*m - { - double br = -*(pc + p) ; - double bi = -*(pc + p + 1); - *pcr = br; - *pci = bi; - pcr++; pci++; - } - pc = pc + ldc; - } - } - } - else if(mul==0) /* handles alpha or beta is equal to zero */ - { - double br_ = 0; - double bi_ = 0; - for (j = 0; j < n; j++) - { - for (p = 0; p < (m*2); p += 2)// (real + imag)*m - { - *pcr = br_; - *pci = bi_; - pcr++; pci++; - } - pc = pc + ldc; - } - } - else /* handles alpha or beta is not equal +/- 1.0 and zero */ - { - 
for (j = 0; j < n; j++) - { - for (p = 0; p < (m*2); p += 2)// (real + imag)*m - { - double br_ = mul * (*(pc + p)); - double bi_ = mul * (*(pc + p + 1)); - *pcr = br_; - *pci = bi_; - pcr++; pci++; - } - pc = pc + ldc; - } - } -} - -/* Pack real and imaginary parts in separate buffers and compute sum of real and imaginary part */ -void bli_3m_sqp_packB_real_imag_sum(double* pb, - guint_t n, - guint_t k, - guint_t ldb, - double* pbr, - double* pbi, - double* pbs, - double mul, - gint_t mx) -{ - gint_t j, p; - __m256d av0, av1, zerov; - __m256d tv0, tv1, sum; - gint_t max_k = (k*2) - 8; - if((mul ==1.0)||(mul==-1.0)) - { - if(mul ==1.0) - { - for (j = 0; j < n; j++) - { - for (p=0; p <= max_k; p += 8) - { - double* pbp = pb + p; - av0 = _mm256_loadu_pd(pbp);//ai1, ar1, ai0, ar0 - av1 = _mm256_loadu_pd(pbp+4);//ai3, ar3, ai2, ar2 - - tv0 = _mm256_permute2f128_pd(av0, av1, 0x20);//ai2, ar2, ai0, ar0 - tv1 = _mm256_permute2f128_pd(av0, av1, 0x31);//ai3, ar3, ai1, ar1 - av0 = _mm256_unpacklo_pd(tv0, tv1);//ar3, ar2, ar1, ar0 - av1 = _mm256_unpackhi_pd(tv0, tv1);//ai3, ai2, ai1, ai0 - sum = _mm256_add_pd(av0, av1); - _mm256_storeu_pd(pbr, av0); pbr += 4; - _mm256_storeu_pd(pbi, av1); pbi += 4; - _mm256_storeu_pd(pbs, sum); pbs += 4; - } - - for (; p < (k*2); p += 2)// (real + imag)*k - { - double br = *(pb + p) ; - double bi = *(pb + p + 1); - *pbr = br; - *pbi = bi; - *pbs = br + bi; - - pbr++; pbi++; pbs++; - } - pb = pb + ldb; - } - } - else - { - zerov = _mm256_setzero_pd(); - for (j = 0; j < n; j++) - { - for (p = 0; p <= max_k; p += 8) - { - double* pbp = pb + p; - av0 = _mm256_loadu_pd(pbp);//ai1, ar1, ai0, ar0 - av1 = _mm256_loadu_pd(pbp+4);//ai3, ar3, ai2, ar2 - - tv0 = _mm256_permute2f128_pd(av0, av1, 0x20);//ai2, ar2, ai0, ar0 - tv1 = _mm256_permute2f128_pd(av0, av1, 0x31);//ai3, ar3, ai1, ar1 - av0 = _mm256_unpacklo_pd(tv0, tv1);//ar3, ar2, ar1, ar0 - av1 = _mm256_unpackhi_pd(tv0, tv1);//ai3, ai2, ai1, ai0 - - //negate - av0 = _mm256_sub_pd(zerov,av0); - av1 = _mm256_sub_pd(zerov,av1); - - sum = _mm256_add_pd(av0, av1); - _mm256_storeu_pd(pbr, av0); pbr += 4; - _mm256_storeu_pd(pbi, av1); pbi += 4; - _mm256_storeu_pd(pbs, sum); pbs += 4; - } - - for (; p < (k*2); p += 2)// (real + imag)*k - { - double br = -*(pb + p) ; - double bi = -*(pb + p + 1); - *pbr = br; - *pbi = bi; - *pbs = br + bi; - - pbr++; pbi++; pbs++; - } - pb = pb + ldb; - } - } - } - else - { - for (j = 0; j < n; j++) - { - for (p = 0; p < (k*2); p += 2)// (real + imag)*k - { - double br_ = mul * (*(pb + p)); - double bi_ = mul * (*(pb + p + 1)); - *pbr = br_; - *pbi = bi_; - *pbs = br_ + bi_; - - pbr++; pbi++; pbs++; - } - pb = pb + ldb; - } - } -} - -/* Pack real and imaginary parts of A matrix in separate buffers and compute sum of real and imaginary part */ -void bli_3m_sqp_packA_real_imag_sum(double *pa, - gint_t i, - guint_t k, - guint_t lda, - double *par, - double *pai, - double *pas, - trans_t transa, - gint_t mx, - gint_t p) -{ - __m256d av0, av1, av2, av3; - __m256d tv0, tv1, sum, zerov; - gint_t poffset = p; -#if KLP -#endif - if(mx==8) - { - if(transa == BLIS_NO_TRANSPOSE) - { - pa = pa +i; -#if KLP - pa = pa + (p*lda); -#else - p = 0; -#endif - for (; p < k; p += 1) - { - //for (int ii = 0; ii < MX8 * 2; ii += 2) //real + imag : Rkernel needs 8 elements each. 
- av0 = _mm256_loadu_pd(pa); - av1 = _mm256_loadu_pd(pa+4); - av2 = _mm256_loadu_pd(pa+8); - av3 = _mm256_loadu_pd(pa+12); - - tv0 = _mm256_permute2f128_pd(av0, av1, 0x20); - tv1 = _mm256_permute2f128_pd(av0, av1, 0x31); - av0 = _mm256_unpacklo_pd(tv0, tv1); - av1 = _mm256_unpackhi_pd(tv0, tv1); - sum = _mm256_add_pd(av0, av1); - _mm256_storeu_pd(par, av0); par += 4; - _mm256_storeu_pd(pai, av1); pai += 4; - _mm256_storeu_pd(pas, sum); pas += 4; - - tv0 = _mm256_permute2f128_pd(av2, av3, 0x20); - tv1 = _mm256_permute2f128_pd(av2, av3, 0x31); - av2 = _mm256_unpacklo_pd(tv0, tv1); - av3 = _mm256_unpackhi_pd(tv0, tv1); - sum = _mm256_add_pd(av2, av3); - _mm256_storeu_pd(par, av2); par += 4; - _mm256_storeu_pd(pai, av3); pai += 4; - _mm256_storeu_pd(pas, sum); pas += 4; - - pa = pa + lda; - } - } - else if(transa == BLIS_CONJ_NO_TRANSPOSE) - { - zerov = _mm256_setzero_pd(); - pa = pa +i; -#if KLP - pa = pa + (p*lda); -#else - p = 0; -#endif - for (; p < k; p += 1) - { - //for (int ii = 0; ii < MX8 * 2; ii += 2) //real + imag : Rkernel needs 8 elements each. - av0 = _mm256_loadu_pd(pa); - av1 = _mm256_loadu_pd(pa+4); - av2 = _mm256_loadu_pd(pa+8); - av3 = _mm256_loadu_pd(pa+12); - - tv0 = _mm256_permute2f128_pd(av0, av1, 0x20); - tv1 = _mm256_permute2f128_pd(av0, av1, 0x31); - av0 = _mm256_unpacklo_pd(tv0, tv1); - av1 = _mm256_unpackhi_pd(tv0, tv1); - av1 = _mm256_sub_pd(zerov,av1);//negate imaginary component - sum = _mm256_add_pd(av0, av1); - _mm256_storeu_pd(par, av0); par += 4; - _mm256_storeu_pd(pai, av1); pai += 4; - _mm256_storeu_pd(pas, sum); pas += 4; - - tv0 = _mm256_permute2f128_pd(av2, av3, 0x20); - tv1 = _mm256_permute2f128_pd(av2, av3, 0x31); - av2 = _mm256_unpacklo_pd(tv0, tv1); - av3 = _mm256_unpackhi_pd(tv0, tv1); - av3 = _mm256_sub_pd(zerov,av3);//negate imaginary component - sum = _mm256_add_pd(av2, av3); - _mm256_storeu_pd(par, av2); par += 4; - _mm256_storeu_pd(pai, av3); pai += 4; - _mm256_storeu_pd(pas, sum); pas += 4; - - pa = pa + lda; - } - } - else if(transa == BLIS_TRANSPOSE) - { - gint_t idx = (i/2) * lda; - pa = pa + idx; -#if KLP -#else - p = 0; -#endif - //A Transpose case: - for (gint_t ii = 0; ii < BLIS_MX8 ; ii++) - { - gint_t idx = ii * lda; - gint_t sidx; - gint_t pidx = 0; - gint_t max_k = (k*2) - 8; - for (p = poffset; p <= max_k; p += 8) - { - double ar0_ = *(pa + idx + p); - double ai0_ = *(pa + idx + p + 1); - - double ar1_ = *(pa + idx + p + 2); - double ai1_ = *(pa + idx + p + 3); - - double ar2_ = *(pa + idx + p + 4); - double ai2_ = *(pa + idx + p + 5); - - double ar3_ = *(pa + idx + p + 6); - double ai3_ = *(pa + idx + p + 7); - - sidx = (pidx/2) * BLIS_MX8; - *(par + sidx + ii) = ar0_; - *(pai + sidx + ii) = ai0_; - *(pas + sidx + ii) = ar0_ + ai0_; - - sidx = ((pidx+2)/2) * BLIS_MX8; - *(par + sidx + ii) = ar1_; - *(pai + sidx + ii) = ai1_; - *(pas + sidx + ii) = ar1_ + ai1_; - - sidx = ((pidx+4)/2) * BLIS_MX8; - *(par + sidx + ii) = ar2_; - *(pai + sidx + ii) = ai2_; - *(pas + sidx + ii) = ar2_ + ai2_; - - sidx = ((pidx+6)/2) * BLIS_MX8; - *(par + sidx + ii) = ar3_; - *(pai + sidx + ii) = ai3_; - *(pas + sidx + ii) = ar3_ + ai3_; - pidx += 8; - - } - - for (; p < (k*2); p += 2) - { - double ar_ = *(pa + idx + p); - double ai_ = *(pa + idx + p + 1); - gint_t sidx = (pidx/2) * BLIS_MX8; - *(par + sidx + ii) = ar_; - *(pai + sidx + ii) = ai_; - *(pas + sidx + ii) = ar_ + ai_; - pidx += 2; - } - } - } - else if(transa == BLIS_CONJ_TRANSPOSE) - { - gint_t idx = (i/2) * lda; - pa = pa + idx; -#if KLP -#else - p = 0; -#endif - //A conjugate Transpose 
case: - for (gint_t ii = 0; ii < BLIS_MX8 ; ii++) - { - gint_t idx = ii * lda; - gint_t sidx; - gint_t pidx = 0; - gint_t max_k = (k*2) - 8; - for (p = poffset; p <= max_k; p += 8) - { - double ar0_ = *(pa + idx + p); - double ai0_ = -(*(pa + idx + p + 1)); - - double ar1_ = *(pa + idx + p + 2); - double ai1_ = -(*(pa + idx + p + 3)); - - double ar2_ = *(pa + idx + p + 4); - double ai2_ = -(*(pa + idx + p + 5)); - - double ar3_ = *(pa + idx + p + 6); - double ai3_ = -(*(pa + idx + p + 7)); - - sidx = (pidx/2) * BLIS_MX8; - *(par + sidx + ii) = ar0_; - *(pai + sidx + ii) = ai0_; - *(pas + sidx + ii) = ar0_ + ai0_; - - sidx = ((pidx+2)/2) * BLIS_MX8; - *(par + sidx + ii) = ar1_; - *(pai + sidx + ii) = ai1_; - *(pas + sidx + ii) = ar1_ + ai1_; - - sidx = ((pidx+4)/2) * BLIS_MX8; - *(par + sidx + ii) = ar2_; - *(pai + sidx + ii) = ai2_; - *(pas + sidx + ii) = ar2_ + ai2_; - - sidx = ((pidx+6)/2) * BLIS_MX8; - *(par + sidx + ii) = ar3_; - *(pai + sidx + ii) = ai3_; - *(pas + sidx + ii) = ar3_ + ai3_; - pidx += 8; - } - - for (; p < (k*2); p += 2) - { - double ar_ = *(pa + idx + p); - double ai_ = -(*(pa + idx + p + 1)); - gint_t sidx = (pidx/2) * BLIS_MX8; - *(par + sidx + ii) = ar_; - *(pai + sidx + ii) = ai_; - *(pas + sidx + ii) = ar_ + ai_; - pidx += 2; - } - } - } - } //mx==8 - else//mx==1 - { - if(transa == BLIS_NO_TRANSPOSE) - { - pa = pa + i; -#if KLP -#else - p = 0; -#endif - //A No transpose case: - for (; p < k; p += 1) - { - gint_t idx = p * lda; - for (gint_t ii = 0; ii < (mx*2) ; ii += 2) - { //real + imag : Rkernel needs 8 elements each. - double ar_ = *(pa + idx + ii); - double ai_ = *(pa + idx + ii + 1); - *par = ar_; - *pai = ai_; - *pas = ar_ + ai_; - par++; pai++; pas++; - } - } - } - else if(transa == BLIS_CONJ_NO_TRANSPOSE) - { - pa = pa + i; -#if KLP -#else - p = 0; -#endif - //A conjuate No transpose case: - for (; p < k; p += 1) - { - gint_t idx = p * lda; - for (gint_t ii = 0; ii < (mx*2) ; ii += 2) - { //real + imag : Rkernel needs 8 elements each. - double ar_ = *(pa + idx + ii); - double ai_ = -(*(pa + idx + ii + 1));// conjugate: negate imaginary component - *par = ar_; - *pai = ai_; - *pas = ar_ + ai_; - par++; pai++; pas++; - } - } - } - else if(transa == BLIS_TRANSPOSE) - { - gint_t idx = (i/2) * lda; - pa = pa + idx; -#if KLP -#else - p = 0; -#endif - //A Transpose case: - for (gint_t ii = 0; ii < mx ; ii++) - { - gint_t idx = ii * lda; - gint_t sidx; - gint_t pidx = 0; - for (p = poffset;p < (k*2); p += 2) - { - double ar0_ = *(pa + idx + p); - double ai0_ = *(pa + idx + p + 1); - - sidx = (pidx/2) * mx; - *(par + sidx + ii) = ar0_; - *(pai + sidx + ii) = ai0_; - *(pas + sidx + ii) = ar0_ + ai0_; - pidx += 2; - - } - } - } - else if(transa == BLIS_CONJ_TRANSPOSE) - { - gint_t idx = (i/2) * lda; - pa = pa + idx; -#if KLP -#else - p = 0; -#endif - //A Transpose case: - for (gint_t ii = 0; ii < mx ; ii++) - { - gint_t idx = ii * lda; - gint_t sidx; - gint_t pidx = 0; - for (p = poffset;p < (k*2); p += 2) - { - double ar0_ = *(pa + idx + p); - double ai0_ = -(*(pa + idx + p + 1)); - - sidx = (pidx/2) * mx; - *(par + sidx + ii) = ar0_; - *(pai + sidx + ii) = ai0_; - *(pas + sidx + ii) = ar0_ + ai0_; - pidx += 2; - - } - } - } - }//mx==1 -} - diff --git a/kernels/zen/3/bli_gemm_sqp_kernels.h b/kernels/zen/3/bli_gemm_sqp_kernels.h deleted file mode 100644 index 588981fad0..0000000000 --- a/kernels/zen/3/bli_gemm_sqp_kernels.h +++ /dev/null @@ -1,65 +0,0 @@ -/* - - BLIS - An object-based framework for developing high-performance BLAS-like - libraries. 
- - Copyright (C) 2021, Advanced Micro Devices, Inc. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ -/* square packed (sqp) kernels */ -#define KLP 1// k loop partition. - -/* sqp dgemm core kernels, targetted mainly for square sizes by default. - sqp framework allows tunning for other shapes.*/ -inc_t bli_sqp_dgemm_kernel_8mx6n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc); -inc_t bli_sqp_dgemm_kernel_8mx5n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc); -inc_t bli_sqp_dgemm_kernel_8mx4n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc); -inc_t bli_sqp_dgemm_kernel_8mx3n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc); -inc_t bli_sqp_dgemm_kernel_8mx2n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc); -inc_t bli_sqp_dgemm_kernel_8mx1n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc); -inc_t bli_sqp_dgemm_kernel_1mx1n(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc); -inc_t bli_sqp_dgemm_kernel_mxn(gint_t n, gint_t k, gint_t j, double* aPacked, guint_t lda, double* b, guint_t ldb, double* c, guint_t ldc, gint_t mx); - -//add and sub kernels -void bli_add_m(gint_t m,gint_t n,double* w,double* c); -void bli_sub_m(gint_t m, gint_t n, double* w, double* c); - -//packing kernels -//Pack A with alpha multiplication -void bli_sqp_prepackA(double* pa, double* aPacked, gint_t k, guint_t lda, bool isTransA, double alpha, gint_t mx); - -void bli_prepackA_8(double* pa, double* aPacked, gint_t k, guint_t lda, bool isTransA, double alpha); -void bli_prepackA_4(double* pa, double* aPacked, gint_t k, guint_t lda, bool isTransA, double alpha); -void bli_prepackA_G4(double* pa, double* aPacked, gint_t k, guint_t lda, bool isTransA, double alpha, gint_t mx); -void bli_prepackA_L4(double* pa, 
double* aPacked, gint_t k, guint_t lda, bool isTransA, double alpha, gint_t mx); -void bli_prepackA_1(double* pa, double* aPacked, gint_t k, guint_t lda, bool isTransA, double alpha); - -/* Pack real and imaginary parts in separate buffers and also multipy with multiplication factor */ -void bli_3m_sqp_packC_real_imag(double* pb, guint_t n, guint_t k, guint_t ldb, double* pbr, double* pbi, double mul, gint_t mx); -void bli_3m_sqp_packB_real_imag_sum(double* pb, guint_t n, guint_t k, guint_t ldb, double* pbr, double* pbi, double* pbs, double mul, gint_t mx); -void bli_3m_sqp_packA_real_imag_sum(double *pa, gint_t i, guint_t k, guint_t lda, double *par, double *pai, double *pas, trans_t transa, gint_t mx, gint_t p); \ No newline at end of file diff --git a/kernels/zen/bli_kernels_zen.h b/kernels/zen/bli_kernels_zen.h index 4cec80773f..d600c4ac45 100644 --- a/kernels/zen/bli_kernels_zen.h +++ b/kernels/zen/bli_kernels_zen.h @@ -289,18 +289,6 @@ err_t bli_zgemm_small_At cntl_t* cntl ); -// gemm square matrix size friendly implementation -err_t bli_gemm_sqp - ( - obj_t* alpha, - obj_t* a, - obj_t* b, - obj_t* beta, - obj_t* c, - cntx_t* cntx, - cntl_t* cntl - ); - void bli_dgemm_ref_k1_nn ( dim_t m, diff --git a/testsuite/src/test_gemm.c b/testsuite/src/test_gemm.c index 1182f07e27..c49f0a3287 100644 --- a/testsuite/src/test_gemm.c +++ b/testsuite/src/test_gemm.c @@ -35,7 +35,6 @@ #include "blis.h" #include "test_libblis.h" -#define TEST_SQP 0// ENABLE to test sqp path. // Static variables. static char* op_str = "gemm"; @@ -243,18 +242,6 @@ void libblis_test_gemm_experiment sc_str[0], m, n, &c_save ); // Set alpha and beta. -#if TEST_SQP - if ( bli_obj_is_real( &c ) ) - { - bli_setsc( 1.0, 0.0, &alpha ); - bli_setsc( 1.0, 0.0, &beta ); - } - else - { - bli_setsc( 1.0, 0.0, &alpha ); - bli_setsc( 1.0, 0.0, &beta ); - } -#else if ( bli_obj_is_real( &c ) ) { bli_setsc( 1.2, 0.0, &alpha ); @@ -265,13 +252,6 @@ void libblis_test_gemm_experiment bli_setsc( 1.2, 0.8, &alpha ); bli_setsc( 0.9, 1.0, &beta ); } -#endif - - #if 0 - //bli_setm( &BLIS_ONE, &a ); - bli_setsc( 1.0, 0.0, &alpha ); - bli_setsc( 1.0, 0.0, &beta ); - #endif // Randomize A, B, and C, and save C. libblis_test_mobj_randomize( params, TRUE, &a ); @@ -457,31 +437,7 @@ void libblis_test_gemm_impl switch ( iface ) { case BLIS_TEST_SEQ_FRONT_END: -#if 0 -//bli_printm( "alpha", alpha, "%5.2f", "" ); -//bli_printm( "beta", beta, "%5.2f", "" ); -bli_printm( "a", a, "%5.2f", "" ); -bli_printm( "b", b, "%5.2f", "" ); -bli_printm( "c", c, "%5.2f", "" ); -#endif -//if ( bli_obj_length( b ) == 16 && -// bli_obj_stor3_from_strides( c, a, b ) == BLIS_CRR ) -//bli_printm( "c before", c, "%6.3f", "" ); - -#if TEST_SQP - if(bli_gemm_sqp(alpha,a,b,beta,c,NULL,NULL)!=BLIS_SUCCESS) - { - bli_gemm( alpha, a, b, beta, c ); - } -#else//TEST_SQP - bli_gemm( alpha, a, b, beta, c ); -#endif//TEST_SQP -#if 0 -if ( bli_obj_length( c ) == 12 && - bli_obj_stor3_from_strides( c, a, b ) == BLIS_RRR ) -bli_printm( "c after", c, "%6.3f", "" ); -#endif -//bli_printm( "c after", c, "%5.2f", "" ); + bli_gemm( alpha, a, b, beta, c ); break; default: From 2beaa6a0e64feaa8352871642a88c5c74e3f399d Mon Sep 17 00:00:00 2001 From: jagar Date: Mon, 29 Aug 2022 12:04:53 +0530 Subject: [PATCH 199/243] CBLAS/BLAS interface decoupling for swap api MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - In BLIS the cblas interface is implemented as a wrapper around the blas interface. 
For example the CBLAS api ‘cblas_dgemm’ internally invokes BLAS API ‘dgemm_’. - If the end user wants to use the different libraries for CBLAS and BLAS, current implantation of BLIS doesn’t allow it. - This change separates the CBLAS and BLAS implantation by adding an additional level of abstraction. The implementation of the API is moved to the new function which is invoked directly from the CBLAS and BLAS wrappers. AMD-Internal: [SWLCSG-1477] Change-Id: I8d81072aaca739f175318b82f6510d386103c24b --- frame/compat/bla_gemv.c | 4 +- frame/compat/bla_swap.c | 12 +- frame/compat/bla_swap.h | 7 + frame/compat/bla_swap_amd.c | 36 ++++- frame/compat/cblas/src/cblas_f77.h | 8 +- frame/include/bli_macro_defs.h | 2 +- frame/util/bli_util_api_wrap.c | 242 ++++++++++++++--------------- 7 files changed, 178 insertions(+), 133 deletions(-) diff --git a/frame/compat/bla_gemv.c b/frame/compat/bla_gemv.c index f5a314331a..12fda66605 100644 --- a/frame/compat/bla_gemv.c +++ b/frame/compat/bla_gemv.c @@ -141,7 +141,7 @@ void PASTEF77S(ch,blasname) \ bli_finalize_auto(); \ }\ \ -void PASTEF77S(ch,blasname) \ +void PASTEF77(ch,blasname) \ ( \ const f77_char* transa, \ const f77_int* m, \ @@ -153,7 +153,7 @@ void PASTEF77S(ch,blasname) \ ftype* y, const f77_int* incy \ ) \ { \ - PASTEF77(ch,blasname) \ + PASTEF77S(ch,blasname) \ ( transa, m, n, alpha, a, lda, x, incx, beta, y, incy ); \ } diff --git a/frame/compat/bla_swap.c b/frame/compat/bla_swap.c index d653426478..67d58c8cbd 100644 --- a/frame/compat/bla_swap.c +++ b/frame/compat/bla_swap.c @@ -42,7 +42,7 @@ #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ -void PASTEF77(ch,blasname) \ +void PASTEF77S(ch,blasname) \ ( \ const f77_int* n, \ ftype* x, const f77_int* incx, \ @@ -80,6 +80,16 @@ void PASTEF77(ch,blasname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ /* Finalize BLIS. */ \ bli_finalize_auto(); \ +}\ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_int* n, \ + ftype* x, const f77_int* incx, \ + ftype* y, const f77_int* incy \ + ) \ +{ \ + PASTEF77S(ch,blasname)( n, x, incx, y, incy ); \ } #ifdef BLIS_ENABLE_BLAS diff --git a/frame/compat/bla_swap.h b/frame/compat/bla_swap.h index 54c0613a92..ccb31688c7 100644 --- a/frame/compat/bla_swap.h +++ b/frame/compat/bla_swap.h @@ -40,6 +40,13 @@ #define GENTPROT( ftype, ch, blasname ) \ \ BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ + ( \ + const f77_int* n, \ + ftype* x, const f77_int* incx, \ + ftype* y, const f77_int* incy \ + ); \ +\ +BLIS_EXPORT_BLAS void PASTEF77S(ch,blasname) \ ( \ const f77_int* n, \ ftype* x, const f77_int* incx, \ diff --git a/frame/compat/bla_swap_amd.c b/frame/compat/bla_swap_amd.c index 617c78a4aa..77e8afcca0 100644 --- a/frame/compat/bla_swap_amd.c +++ b/frame/compat/bla_swap_amd.c @@ -42,7 +42,7 @@ #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ -void PASTEF77(ch,blasname) \ +void PASTEF77S(ch,blasname) \ ( \ const f77_int* n, \ ftype* x, const f77_int* incx, \ @@ -80,11 +80,21 @@ void PASTEF77(ch,blasname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ /* Finalize BLIS. 
*/ \ bli_finalize_auto(); \ +}\ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_int* n, \ + ftype* x, const f77_int* incx, \ + ftype* y, const f77_int* incy \ + ) \ +{ \ + PASTEF77S(ch,blasname)( n, x, incx, y, incy ); \ } #ifdef BLIS_ENABLE_BLAS -void sswap_ +void sswap_blis_impl ( const f77_int* n, float* x, const f77_int* incx, @@ -172,8 +182,17 @@ void sswap_ // bli_finalize_auto(); AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) } - -void dswap_ +void sswap_ + ( + const f77_int* n, + float* x, const f77_int* incx, + float* y, const f77_int* incy + ) +{ + sswap_blis_impl( n, x, incx, y, incy ); +} + +void dswap_blis_impl ( const f77_int* n, double* x, const f77_int* incx, @@ -261,6 +280,15 @@ void dswap_ // bli_finalize_auto(); AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) } +void dswap_ + ( + const f77_int* n, + double* x, const f77_int* incx, + double* y, const f77_int* incy + ) +{ + dswap_blis_impl( n, x, incx, y, incy ); +} INSERT_GENTFUNC_BLAS_CZ( swap, swapv ) diff --git a/frame/compat/cblas/src/cblas_f77.h b/frame/compat/cblas/src/cblas_f77.h index 864d78895e..8e43c05d17 100644 --- a/frame/compat/cblas/src/cblas_f77.h +++ b/frame/compat/cblas/src/cblas_f77.h @@ -215,19 +215,19 @@ #define F77_drotmg drotmg_blis_impl #define F77_drot drot_blis_impl #define F77_drotm drotm_blis_impl -#define F77_sswap sswap_ +#define F77_sswap sswap_blis_impl #define F77_scopy scopy_blis_impl #define F77_saxpy saxpy_blis_impl #define F77_isamax_sub isamaxsub_blis_impl -#define F77_dswap dswap_ +#define F77_dswap dswap_blis_impl #define F77_dcopy dcopy_blis_impl #define F77_daxpy daxpy_blis_impl #define F77_idamax_sub idamaxsub_blis_impl -#define F77_cswap cswap_ +#define F77_cswap cswap_blis_impl #define F77_ccopy ccopy_blis_impl #define F77_caxpy caxpy_blis_impl #define F77_icamax_sub icamaxsub_blis_impl -#define F77_zswap zswap_ +#define F77_zswap zswap_blis_impl #define F77_zcopy zcopy_blis_impl #define F77_zaxpy zaxpy_blis_impl #define F77_izamax_sub izamaxsub_blis_impl diff --git a/frame/include/bli_macro_defs.h b/frame/include/bli_macro_defs.h index 9ab7c00aa7..99d1de6180 100644 --- a/frame/include/bli_macro_defs.h +++ b/frame/include/bli_macro_defs.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018-2021, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/frame/util/bli_util_api_wrap.c b/frame/util/bli_util_api_wrap.c index 26810531d3..6e500b0b88 100644 --- a/frame/util/bli_util_api_wrap.c +++ b/frame/util/bli_util_api_wrap.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021-2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -210,17 +210,17 @@ void CGBMV_(const char *trans,const f77_int *m,const f77_int *n,const f77_int void CGEMM(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) { - cgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + cgemm_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void cgemm(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) { - cgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + cgemm_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void CGEMM_(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) { - cgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + cgemm_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void CGEMV(const char *trans,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy) @@ -285,17 +285,17 @@ void CHBMV_(const char *uplo,const f77_int *n,const f77_int *k,const scomplex void CHEMM(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) { - chemm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + chemm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void chemm(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) { - chemm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + chemm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void CHEMM_(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) { - chemm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + chemm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void CHEMV(const char *uplo,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy) @@ -345,32 +345,32 @@ void CHER2_(const char *uplo,const f77_int *n,const scomplex *alpha,const sco void CHER2K(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const float *beta,scomplex *c,const f77_int *ldc) { - cher2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + cher2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, 
beta, c, ldc); } void cher2k(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const float *beta,scomplex *c,const f77_int *ldc) { - cher2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + cher2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void CHER2K_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const float *beta,scomplex *c,const f77_int *ldc) { - cher2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + cher2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void CHERK(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const float *alpha,const scomplex *a,const f77_int *lda,const float *beta,scomplex *c,const f77_int *ldc) { - cherk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + cherk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void cherk(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const float *alpha,const scomplex *a,const f77_int *lda,const float *beta,scomplex *c,const f77_int *ldc) { - cherk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + cherk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void CHERK_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const float *alpha,const scomplex *a,const f77_int *lda,const float *beta,scomplex *c,const f77_int *ldc) { - cherk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + cherk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void CHPMV(const char *uplo,const f77_int *n,const scomplex *alpha,const scomplex *ap,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy) @@ -480,62 +480,62 @@ void CSSCAL_(const f77_int *n,const float *sa,scomplex *cx,const f77_int *incx void CSWAP(const f77_int *n,scomplex *cx,const f77_int *incx,scomplex *cy,const f77_int *incy) { - cswap_( n, cx, incx, cy, incy); + cswap_blis_impl( n, cx, incx, cy, incy); } void cswap(const f77_int *n,scomplex *cx,const f77_int *incx,scomplex *cy,const f77_int *incy) { - cswap_( n, cx, incx, cy, incy); + cswap_blis_impl( n, cx, incx, cy, incy); } void CSWAP_(const f77_int *n,scomplex *cx,const f77_int *incx,scomplex *cy,const f77_int *incy) { - cswap_( n, cx, incx, cy, incy); + cswap_blis_impl( n, cx, incx, cy, incy); } void CSYMM(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) { - csymm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + csymm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void csymm(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) { - csymm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + csymm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void CSYMM_(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) { - csymm_blis_impl( side, uplo, m, n, alpha, a, 
lda, b, ldb, beta, c, ldc); + csymm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void CSYR2K(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) { - csyr2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + csyr2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void csyr2k(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) { - csyr2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + csyr2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void CSYR2K_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) { - csyr2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + csyr2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void CSYRK(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *beta,scomplex *c,const f77_int *ldc) { - csyrk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + csyrk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void csyrk(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *beta,scomplex *c,const f77_int *ldc) { - csyrk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + csyrk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void CSYRK_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *beta,scomplex *c,const f77_int *ldc) { - csyrk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + csyrk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void CTBMV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const scomplex *a,const f77_int *lda,scomplex *x,const f77_int *incx) @@ -600,17 +600,17 @@ void CTPSV_(const char *uplo,const char *trans,const char *diag,const f77_ void CTRMM(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,scomplex *b,const f77_int *ldb) { - ctrmm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + ctrmm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void ctrmm(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,scomplex *b,const f77_int *ldb) { - ctrmm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + ctrmm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void CTRMM_(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,scomplex *b,const f77_int *ldb) { - ctrmm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + ctrmm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void CTRMV(const char 
*uplo,const char *trans,const char *diag,const f77_int *n,const scomplex *a,const f77_int *lda,scomplex *x,const f77_int *incx) @@ -630,17 +630,17 @@ void CTRMV_(const char *uplo,const char *trans,const char *diag,const f77_ void CTRSM(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,scomplex *b,const f77_int *ldb) { - ctrsm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + ctrsm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void ctrsm(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,scomplex *b,const f77_int *ldb) { - ctrsm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + ctrsm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void CTRSM_(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,scomplex *b,const f77_int *ldb) { - ctrsm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + ctrsm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void CTRSV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const scomplex *a,const f77_int *lda,scomplex *x,const f77_int *incx) @@ -750,17 +750,17 @@ void DGBMV_(const char *trans,const f77_int *m,const f77_int *n,const f77_int void DGEMM(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *b,const f77_int *ldb,const double *beta,double *c,const f77_int *ldc) { - dgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + dgemm_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void dgemm(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *b,const f77_int *ldb,const double *beta,double *c,const f77_int *ldc) { - dgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + dgemm_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void DGEMM_(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *b,const f77_int *ldb,const double *beta,double *c,const f77_int *ldc) { - dgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + dgemm_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void DGEMV(const char *trans,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) @@ -960,32 +960,32 @@ void DSPR2_(const char *uplo,const f77_int *n,const double *alpha,const double void DSWAP(const f77_int *n,double *dx,const f77_int *incx,double *dy,const f77_int *incy) { - dswap_( n, dx, incx, dy, incy); + dswap_blis_impl( n, dx, incx, dy, incy); } void dswap(const f77_int *n,double *dx,const f77_int *incx,double *dy,const f77_int *incy) { - dswap_( n, dx, incx, dy, incy); + dswap_blis_impl( n, dx, incx, dy, incy); } void DSWAP_(const f77_int *n,double *dx,const f77_int *incx,double *dy,const f77_int *incy) { - dswap_( n, dx, incx, dy, incy); + dswap_blis_impl( n, dx, incx, dy, incy); } void 
DSYMM(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,const double *b,const f77_int *ldb,const double *beta,double *c,const f77_int *ldc) { - dsymm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + dsymm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void dsymm(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,const double *b,const f77_int *ldb,const double *beta,double *c,const f77_int *ldc) { - dsymm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + dsymm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void DSYMM_(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,const double *b,const f77_int *ldb,const double *beta,double *c,const f77_int *ldc) { - dsymm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + dsymm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void DSYMV(const char *uplo,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) @@ -1035,32 +1035,32 @@ void DSYR2_(const char *uplo,const f77_int *n,const double *alpha,const double void DSYR2K(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *b,const f77_int *ldb,const double *beta,double *c,const f77_int *ldc) { - dsyr2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + dsyr2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void dsyr2k(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *b,const f77_int *ldb,const double *beta,double *c,const f77_int *ldc) { - dsyr2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + dsyr2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void DSYR2K_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *b,const f77_int *ldb,const double *beta,double *c,const f77_int *ldc) { - dsyr2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + dsyr2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void DSYRK(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *beta,double *c,const f77_int *ldc) { - dsyrk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + dsyrk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void dsyrk(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *beta,double *c,const f77_int *ldc) { - dsyrk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + dsyrk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void DSYRK_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *beta,double *c,const f77_int *ldc) { - dsyrk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + dsyrk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void DTBMV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const double *a,const f77_int 
*lda,double *x,const f77_int *incx) @@ -1125,17 +1125,17 @@ void DTPSV_(const char *uplo,const char *trans,const char *diag,const f77_ void DTRMM(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,double *b,const f77_int *ldb) { - dtrmm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + dtrmm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void dtrmm(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,double *b,const f77_int *ldb) { - dtrmm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + dtrmm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void DTRMM_(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,double *b,const f77_int *ldb) { - dtrmm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + dtrmm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void DTRMV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const double *a,const f77_int *lda,double *x,const f77_int *incx) @@ -1155,17 +1155,17 @@ void DTRMV_(const char *uplo,const char *trans,const char *diag,const f77_ void DTRSM(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,double *b,const f77_int *ldb) { - dtrsm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + dtrsm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void dtrsm(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,double *b,const f77_int *ldb) { - dtrsm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + dtrsm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void DTRSM_(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,double *b,const f77_int *ldb) { - dtrsm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + dtrsm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void DTRSV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const double *a,const f77_int *lda,double *x,const f77_int *incx) @@ -1417,17 +1417,17 @@ void SGBMV_(const char *trans,const f77_int *m,const f77_int *n,const f77_int void SGEMM(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *b,const f77_int *ldb,const float *beta,float *c,const f77_int *ldc) { - sgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + sgemm_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void sgemm(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *b,const f77_int *ldb,const float *beta,float *c,const f77_int *ldc) { - sgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + sgemm_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void SGEMM_(const char *transa,const char *transb,const f77_int 
*m,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *b,const f77_int *ldb,const float *beta,float *c,const f77_int *ldc) { - sgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + sgemm_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void SGEMV(const char *trans,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) @@ -1614,32 +1614,32 @@ void SSPR2_(const char *uplo,const f77_int *n,const float *alpha,const float void SSWAP(const f77_int *n,float *sx,const f77_int *incx,float *sy,const f77_int *incy) { - sswap_( n, sx, incx, sy, incy); + sswap_blis_impl( n, sx, incx, sy, incy); } void sswap(const f77_int *n,float *sx,const f77_int *incx,float *sy,const f77_int *incy) { - sswap_( n, sx, incx, sy, incy); + sswap_blis_impl( n, sx, incx, sy, incy); } void SSWAP_(const f77_int *n,float *sx,const f77_int *incx,float *sy,const f77_int *incy) { - sswap_( n, sx, incx, sy, incy); + sswap_blis_impl( n, sx, incx, sy, incy); } void SSYMM(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,const float *b,const f77_int *ldb,const float *beta,float *c,const f77_int *ldc) { - ssymm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + ssymm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void ssymm(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,const float *b,const f77_int *ldb,const float *beta,float *c,const f77_int *ldc) { - ssymm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + ssymm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void SSYMM_(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,const float *b,const f77_int *ldb,const float *beta,float *c,const f77_int *ldc) { - ssymm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + ssymm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void SSYMV(const char *uplo,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) @@ -1689,32 +1689,32 @@ void SSYR2_(const char *uplo,const f77_int *n,const float *alpha,const float void SSYR2K(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *b,const f77_int *ldb,const float *beta,float *c,const f77_int *ldc) { - ssyr2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + ssyr2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void ssyr2k(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *b,const f77_int *ldb,const float *beta,float *c,const f77_int *ldc) { - ssyr2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + ssyr2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void SSYR2K_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *b,const f77_int *ldb,const float *beta,float *c,const f77_int *ldc) { - ssyr2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + ssyr2k_( uplo, trans, n, k, 
alpha, a, lda, b, ldb, beta, c, ldc); } void SSYRK(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *beta,float *c,const f77_int *ldc) { - ssyrk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + ssyrk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void ssyrk(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *beta,float *c,const f77_int *ldc) { - ssyrk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + ssyrk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void SSYRK_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *beta,float *c,const f77_int *ldc) { - ssyrk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + ssyrk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void STBMV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const float *a,const f77_int *lda,float *x,const f77_int *incx) @@ -1779,17 +1779,17 @@ void STPSV_(const char *uplo,const char *trans,const char *diag,const f77_ void STRMM(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,float *b,const f77_int *ldb) { - strmm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + strmm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void strmm(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,float *b,const f77_int *ldb) { - strmm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + strmm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void STRMM_(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,float *b,const f77_int *ldb) { - strmm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + strmm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void STRMV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const float *a,const f77_int *lda,float *x,const f77_int *incx) @@ -1809,17 +1809,17 @@ void STRMV_(const char *uplo,const char *trans,const char *diag,const f77_ void STRSM(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,float *b,const f77_int *ldb) { - strsm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + strsm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void strsm(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,float *b,const f77_int *ldb) { - strsm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + strsm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void STRSM_(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,float *b,const f77_int *ldb) { - strsm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + strsm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void STRSV(const char 
*uplo,const char *trans,const char *diag,const f77_int *n,const float *a,const f77_int *lda,float *x,const f77_int *incx) @@ -1929,17 +1929,17 @@ void ZGBMV_(const char *trans,const f77_int *m,const f77_int *n,const f77_int void ZGEMM(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) { - zgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + zgemm_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void zgemm(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) { - zgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + zgemm_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void ZGEMM_(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) { - zgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + zgemm_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void ZGEMV(const char *trans,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) @@ -2004,17 +2004,17 @@ void ZHBMV_(const char *uplo,const f77_int *n,const f77_int *k,const dcomplex void ZHEMM(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) { - zhemm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + zhemm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void zhemm(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) { - zhemm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + zhemm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void ZHEMM_(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) { - zhemm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + zhemm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void ZHEMV(const char *uplo,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) @@ -2064,32 +2064,32 @@ void ZHER2_(const char *uplo,const f77_int *n,const dcomplex *alpha,const dcom void ZHER2K(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const double *beta,dcomplex *c,const f77_int *ldc) { - zher2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + zher2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, 
ldc); } void zher2k(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const double *beta,dcomplex *c,const f77_int *ldc) { - zher2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + zher2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void ZHER2K_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const double *beta,dcomplex *c,const f77_int *ldc) { - zher2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + zher2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void ZHERK(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const double *alpha,const dcomplex *a,const f77_int *lda,const double *beta,dcomplex *c,const f77_int *ldc) { - zherk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + zherk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void zherk(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const double *alpha,const dcomplex *a,const f77_int *lda,const double *beta,dcomplex *c,const f77_int *ldc) { - zherk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + zherk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void ZHERK_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const double *alpha,const dcomplex *a,const f77_int *lda,const double *beta,dcomplex *c,const f77_int *ldc) { - zherk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + zherk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void ZHPMV(const char *uplo,const f77_int *n,const dcomplex *alpha,const dcomplex *ap,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) @@ -2169,62 +2169,62 @@ void ZSCAL_(const f77_int *n,const dcomplex *za,dcomplex *zx,const f77_int *incx void ZSWAP(const f77_int *n,dcomplex *zx,const f77_int *incx,dcomplex *zy,const f77_int *incy) { - zswap_( n, zx, incx, zy, incy); + zswap_blis_impl( n, zx, incx, zy, incy); } void zswap(const f77_int *n,dcomplex *zx,const f77_int *incx,dcomplex *zy,const f77_int *incy) { - zswap_( n, zx, incx, zy, incy); + zswap_blis_impl( n, zx, incx, zy, incy); } void ZSWAP_(const f77_int *n,dcomplex *zx,const f77_int *incx,dcomplex *zy,const f77_int *incy) { - zswap_( n, zx, incx, zy, incy); + zswap_blis_impl( n, zx, incx, zy, incy); } void ZSYMM(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) { - zsymm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + zsymm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void zsymm(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) { - zsymm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + zsymm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void ZSYMM_(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) { - zsymm_blis_impl( side, uplo, m, n, alpha, 
a, lda, b, ldb, beta, c, ldc); + zsymm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void ZSYR2K(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) { - zsyr2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + zsyr2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void zsyr2k(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) { - zsyr2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + zsyr2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void ZSYR2K_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) { - zsyr2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + zsyr2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void ZSYRK(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *beta,dcomplex *c,const f77_int *ldc) { - zsyrk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + zsyrk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void zsyrk(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *beta,dcomplex *c,const f77_int *ldc) { - zsyrk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + zsyrk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void ZSYRK_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *beta,dcomplex *c,const f77_int *ldc) { - zsyrk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + zsyrk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void ZTBMV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const dcomplex *a,const f77_int *lda,dcomplex *x,const f77_int *incx) @@ -2289,17 +2289,17 @@ void ZTPSV_(const char *uplo,const char *trans,const char *diag,const f77_ void ZTRMM(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,dcomplex *b,const f77_int *ldb) { - ztrmm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + ztrmm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void ztrmm(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,dcomplex *b,const f77_int *ldb) { - ztrmm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + ztrmm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void ZTRMM_(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,dcomplex *b,const f77_int *ldb) { - ztrmm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + ztrmm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void ZTRMV(const char 
*uplo,const char *trans,const char *diag,const f77_int *n,const dcomplex *a,const f77_int *lda,dcomplex *x,const f77_int *incx) @@ -2319,17 +2319,17 @@ void ZTRMV_(const char *uplo,const char *trans,const char *diag,const f77_ void ZTRSM(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,dcomplex *b,const f77_int *ldb) { - ztrsm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + ztrsm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void ztrsm(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,dcomplex *b,const f77_int *ldb) { - ztrsm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + ztrsm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void ZTRSM_(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,dcomplex *b,const f77_int *ldb) { - ztrsm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + ztrsm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void ZTRSV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const dcomplex *a,const f77_int *lda,dcomplex *x,const f77_int *incx) @@ -2380,17 +2380,17 @@ void CDOTUSUB_( const f77_int* n, const scomplex* x,const f77_int* incxy, const void CGEMM3M( const f77_char* transa, const f77_char* transb, const f77_int* m, const f77_int* n, const f77_int* k, const scomplex* alpha, const scomplex* a, const f77_int* lda, const scomplex* b, const f77_int* ldb, const scomplex* beta, scomplex* c, const f77_int* ldc) { - cgemm3m_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + cgemm3m_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void cgemm3m( const f77_char* transa, const f77_char* transb, const f77_int* m, const f77_int* n, const f77_int* k, const scomplex* alpha, const scomplex* a, const f77_int* lda, const scomplex* b, const f77_int* ldb, const scomplex* beta, scomplex* c, const f77_int* ldc) { - cgemm3m_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + cgemm3m_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void CGEMM3M_( const f77_char* transa, const f77_char* transb, const f77_int* m, const f77_int* n, const f77_int* k, const scomplex* alpha, const scomplex* a, const f77_int* lda, const scomplex* b, const f77_int* ldb, const scomplex* beta, scomplex* c, const f77_int* ldc) { - cgemm3m_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + cgemm3m_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void CGEMM_BATCH( const f77_char* transa_array, const f77_char* transb_array,const f77_int *m_array, const f77_int *n_array, const f77_int *k_array,const scomplex* alpha_array, const scomplex** a_array, const f77_int *lda_array, const scomplex** b_array, const f77_int *ldb_array, const scomplex* beta_array, scomplex** c_array, const f77_int *ldc_array, const f77_int* group_count, const f77_int *group_size) @@ -2410,17 +2410,17 @@ void CGEMM_BATCH_( const f77_char* transa_array, const f77_char* transb_array,co void CGEMMT( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const scomplex* alpha, const scomplex* a, const f77_int* lda, const scomplex* 
b, const f77_int* ldb, const scomplex* beta, scomplex* c, const f77_int* ldc) { - cgemmt_blis_impl( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + cgemmt_( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void cgemmt( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const scomplex* alpha, const scomplex* a, const f77_int* lda, const scomplex* b, const f77_int* ldb, const scomplex* beta, scomplex* c, const f77_int* ldc) { - cgemmt_blis_impl( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + cgemmt_( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void CGEMMT_( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const scomplex* alpha, const scomplex* a, const f77_int* lda, const scomplex* b, const f77_int* ldb, const scomplex* beta, scomplex* c, const f77_int* ldc) { - cgemmt_blis_impl( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + cgemmt_( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void CIMATCOPY(f77_char* trans, f77_int* rows, f77_int* cols, const scomplex* alpha,scomplex* aptr, f77_int* lda, f77_int* ldb) @@ -2545,17 +2545,17 @@ void DGEMM_BATCH_( const f77_char* transa_array, const f77_char* transb_array,co void DGEMMT( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const double* alpha, const double* a, const f77_int* lda, const double* b, const f77_int* ldb, const double* beta, double* c, const f77_int* ldc) { - dgemmt_blis_impl( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + dgemmt_( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void dgemmt( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const double* alpha, const double* a, const f77_int* lda, const double* b, const f77_int* ldb, const double* beta, double* c, const f77_int* ldc) { - dgemmt_blis_impl( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + dgemmt_( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void DGEMMT_( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const double* alpha, const double* a, const f77_int* lda, const double* b, const f77_int* ldb, const double* beta, double* c, const f77_int* ldc) { - dgemmt_blis_impl( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + dgemmt_( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void DNRM2SUB(const f77_int* n, const double* x, const f77_int* incx, double *rval) @@ -2920,17 +2920,17 @@ void SGEMM_BATCH_(const f77_char* transa_array, const f77_char* transb_array,con void SGEMMT( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const float* alpha, const float* a, const f77_int* lda, const float* b, const f77_int* ldb, const float* beta, float* c, const f77_int* ldc) { - sgemmt_blis_impl( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + sgemmt_( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void sgemmt( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const float* alpha, const float* a, const f77_int* lda, const float* b, const f77_int* ldb, const float* beta, float* c, const f77_int* 
ldc) { - sgemmt_blis_impl( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + sgemmt_( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void SGEMMT_( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const float* alpha, const float* a, const f77_int* lda, const float* b, const f77_int* ldb, const float* beta, float* c, const f77_int* ldc) { - sgemmt_blis_impl( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + sgemmt_( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void SIMATCOPY( f77_char* trans, f77_int* rows, f77_int* cols, const float* alpha,float* aptr, f77_int* lda, f77_int* ldb) @@ -3055,17 +3055,17 @@ void ZDOTUSUB_( const f77_int* n, const dcomplex* x, const f77_int* incx,const d void ZGEMM3M( const f77_char* transa, const f77_char* transb, const f77_int* m, const f77_int* n, const f77_int* k, const dcomplex* alpha, const dcomplex* a, const f77_int* lda, const dcomplex* b, const f77_int* ldb, const dcomplex* beta, dcomplex* c, const f77_int* ldc) { - zgemm3m_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + zgemm3m_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void zgemm3m( const f77_char* transa, const f77_char* transb, const f77_int* m, const f77_int* n, const f77_int* k, const dcomplex* alpha, const dcomplex* a, const f77_int* lda, const dcomplex* b, const f77_int* ldb, const dcomplex* beta, dcomplex* c, const f77_int* ldc) { - zgemm3m_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + zgemm3m_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void ZGEMM3M_( const f77_char* transa, const f77_char* transb, const f77_int* m, const f77_int* n, const f77_int* k, const dcomplex* alpha, const dcomplex* a, const f77_int* lda, const dcomplex* b, const f77_int* ldb, const dcomplex* beta, dcomplex* c, const f77_int* ldc) { - zgemm3m_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + zgemm3m_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void ZGEMM_BATCH( const f77_char* transa_array, const f77_char* transb_array,const f77_int *m_array, const f77_int *n_array, const f77_int *k_array,const dcomplex* alpha_array, const dcomplex** a_array, const f77_int *lda_array, const dcomplex** b_array, const f77_int *ldb_array, const dcomplex* beta_array, dcomplex** c_array, const f77_int *ldc_array, const f77_int* group_count, const f77_int *group_size) @@ -3085,17 +3085,17 @@ void ZGEMM_BATCH_( const f77_char* transa_array, const f77_char* transb_array,c void ZGEMMT( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const dcomplex* alpha, const dcomplex* a, const f77_int* lda, const dcomplex* b, const f77_int* ldb, const dcomplex* beta, dcomplex* c, const f77_int* ldc) { - zgemmt_blis_impl( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + zgemmt_( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void zgemmt( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const dcomplex* alpha, const dcomplex* a, const f77_int* lda, const dcomplex* b, const f77_int* ldb, const dcomplex* beta, dcomplex* c, const f77_int* ldc) { - zgemmt_blis_impl( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + zgemmt_( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void 
ZGEMMT_( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const dcomplex* alpha, const dcomplex* a, const f77_int* lda, const dcomplex* b, const f77_int* ldb, const dcomplex* beta, dcomplex* c, const f77_int* ldc) { - zgemmt_blis_impl( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + zgemmt_( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void ZIMATCOPY(f77_char* trans, f77_int* rows, f77_int* cols, const dcomplex* alpha,dcomplex* aptr, f77_int* lda, f77_int* ldb) From eb83a0fe9d1060918791b339301d0cbd845b1199 Mon Sep 17 00:00:00 2001 From: Arnav Sharma Date: Mon, 29 Aug 2022 15:20:07 +0530 Subject: [PATCH 200/243] Enabled ZHER Optimized Path - While calculating the diagonal and corner elements, the combined operation of calculating the product of x and x hermitian and simultaneously scaling it with alpha and adding the result to the matrix was the cause of increased underflow and overflow errors in netlib tests. - So the above calculation is now being done in three steps: scaling x vector with alpha, then calculating its product with x hermitian and later adding the final result to the matrix. AMD-Internal: [CPUPL-2213] Change-Id: I32df572b013bc3189340662dbf17eddcaec9f0f8 --- frame/2/her/bli_her_unb_var1_amd.c | 10 +- frame/2/her/bli_her_unb_var2_amd.c | 10 +- kernels/zen/2/bli_her_zen_int_amd.c | 136 ++++++++++++++++++++++------ 3 files changed, 119 insertions(+), 37 deletions(-) diff --git a/frame/2/her/bli_her_unb_var1_amd.c b/frame/2/her/bli_her_unb_var1_amd.c index 297c9200ee..1dcb6d0eeb 100644 --- a/frame/2/her/bli_her_unb_var1_amd.c +++ b/frame/2/her/bli_her_unb_var1_amd.c @@ -163,11 +163,11 @@ void PASTEMAC(ch,varname) \ ) \ { \ const num_t dt = PASTEMAC(ch,type); \ - /* ToDo: - Enable intrinsic implementation after verifying - with netlib BLAS tests. */ \ /* Redirect to intrinsic implementation of HER for dcomplex */ \ - /* if ( bli_cpuid_is_avx_supported() == TRUE && bli_is_conj(conjh) && incx == 1 ) \ + if ( bli_cpuid_is_avx_supported() == TRUE && \ + ( rs_c == 1 || cs_c == 1 ) && \ + ( bli_is_upper( uplo ) || bli_is_lower( uplo ) ) && \ + bli_is_conj(conjh) && incx == 1 ) \ { \ bli_zher_zen_int_var1 \ ( \ @@ -184,7 +184,7 @@ void PASTEMAC(ch,varname) \ cntx \ ); \ } \ - else \ */ \ + else \ { \ ctype* x0; \ ctype* chi1; \ diff --git a/frame/2/her/bli_her_unb_var2_amd.c b/frame/2/her/bli_her_unb_var2_amd.c index c101200d2d..f16ef42a76 100644 --- a/frame/2/her/bli_her_unb_var2_amd.c +++ b/frame/2/her/bli_her_unb_var2_amd.c @@ -163,11 +163,11 @@ void PASTEMAC(ch,varname) \ ) \ { \ const num_t dt = PASTEMAC(ch,type); \ - /* ToDo: - Enable intrinsic implementation after verifying - with netlib BLAS tests. 
*/ \ /* Redirect to intrinsic implementation of HER for unit increment */ \ - /* if ( bli_cpuid_is_avx_supported() == TRUE && bli_is_conj(conjh) && incx == 1 ) \ + if ( bli_cpuid_is_avx_supported() == TRUE && \ + ( rs_c == 1 || cs_c == 1 ) && \ + ( bli_is_upper( uplo ) || bli_is_lower( uplo ) ) && \ + bli_is_conj(conjh) && incx == 1 ) \ { \ bli_zher_zen_int_var2 \ ( \ @@ -184,7 +184,7 @@ void PASTEMAC(ch,varname) \ cntx \ ); \ } \ - else \ */ \ + else \ { \ ctype* chi1; \ ctype* x2; \ diff --git a/kernels/zen/2/bli_her_zen_int_amd.c b/kernels/zen/2/bli_her_zen_int_amd.c index 393797f8cb..ee259b7e3e 100644 --- a/kernels/zen/2/bli_her_zen_int_amd.c +++ b/kernels/zen/2/bli_her_zen_int_amd.c @@ -60,6 +60,7 @@ void bli_zher_zen_int_var1 double xcR, xcI; double xhermcR, xhermcI; double alphaR; + double interR, interI; dcomplex* xc; dcomplex* xhermc; @@ -112,8 +113,15 @@ void bli_zher_zen_int_var1 xc = x + i*incx; xcR = xc->real; xcI = xc->imag; + xhermcR = xc->real; + xhermcI = xc->imag; + + xcR = alphaR * xcR; + xcI = alphaR * xcI; + interR = xcR * xhermcR + xcI * xhermcI; + cc = c + (i)*rs_ct + (i)*cs_ct; - cc->real += alphaR * ((xcR * xcR) + (xcI * xcI)); + cc->real += interR; cc->imag = 0; } @@ -152,9 +160,14 @@ void bli_zher_zen_int_var1 xhermcR = xhermc->real; xhermcI = -1 * conj_multiplier * xhermc->imag; + xcR = alphaR * xcR; + xcI = alphaR * xcI; + interR = xcR * xhermcR - xcI * xhermcI; + interI = xcR * xhermcI + xcI * xhermcR; + cc = c + (i + 1)*rs_ct + (i + 0)*cs_ct; - cc->real += alphaR * ( (xcR * xhermcR) - (xcI * xhermcI) ); - cc->imag += alphaR * ( (xcR * xhermcI) + (xcI * xhermcR) ); + cc->real += interR; + cc->imag += interI; xc = x + (i + 3)*incx; xcR = xc->real; @@ -164,9 +177,14 @@ void bli_zher_zen_int_var1 xhermcR = xhermc->real; xhermcI = -1 * conj_multiplier * xhermc->imag; + xcR = alphaR * xcR; + xcI = alphaR * xcI; + interR = xcR * xhermcR - xcI * xhermcI; + interI = xcR * xhermcI + xcI * xhermcR; + cc = c + (i + 3)*rs_ct + (i + 2)*cs_ct; - cc->real += alphaR * ( (xcR * xhermcR) - (xcI * xhermcI) ); - cc->imag += alphaR * ( (xcR * xhermcI) + (xcI * xhermcR) ); + cc->real += interR; + cc->imag += interI; // Solving the 2x2 square tile inside the triangular block // using intrinsics @@ -502,9 +520,14 @@ void bli_zher_zen_int_var1 xhermcR = xhermc->real; xhermcI = -1 * conj_multiplier * xhermc->imag; + xcR = alphaR * xcR; + xcI = alphaR * xcI; + interR = xcR * xhermcR - xcI * xhermcI; + interI = xcR * xhermcI + xcI * xhermcR; + cc = c + (i + 1)*rs_ct + i*cs_ct; - cc->real += alphaR * ( (xcR * xhermcR) - (xcI * xhermcI) ); - cc->imag += alphaR * ( (xcR * xhermcI) + (xcI * xhermcR) ); + cc->real += interR; + cc->imag += interI; // Loading elements of x to ymm0 for computing xherm vector ymm0 = _mm256_loadu_pd( (double*)( x + i*incx ) ); @@ -571,6 +594,7 @@ void bli_zher_zen_int_var2 double xcR, xcI; double xhermcR, xhermcI; double alphaR; + double interR, interI; dcomplex* xc; dcomplex* xhermc; @@ -613,7 +637,13 @@ void bli_zher_zen_int_var2 if ( bli_is_conj( conjx ) ) alphaRv = _mm256_broadcast_sd( &alphaR ); else alphaRv = _mm256_set_pd( -alphaR, alphaR, -alphaR, alphaR ); - __m256d conj_mulv = _mm256_set_pd( conj_multiplier, -1 * conj_multiplier, conj_multiplier, -1 * conj_multiplier ); + __m256d conj_mulv = _mm256_set_pd + ( + conj_multiplier, + -1 * conj_multiplier, + conj_multiplier, + -1 * conj_multiplier + ); /********* DIAGONAL ELEMENTS *********/ // Solving for the diagonal elements using a scalar loop @@ -622,8 +652,15 @@ void bli_zher_zen_int_var2 xc = x + 
i*incx; xcR = xc->real; xcI = xc->imag; + xhermcR = xc->real; + xhermcI = xc->imag; + + xcR = alphaR * xcR; + xcI = alphaR * xcI; + interR = xcR * xhermcR + xcI * xhermcI; + cc = c + (i)*rs_ct + (i)*cs_ct; - cc->real += alphaR * ((xcR * xcR) + (xcI * xcI)); + cc->real += interR; cc->imag = 0; } @@ -664,9 +701,14 @@ void bli_zher_zen_int_var2 xhermcR = xhermc->real; xhermcI = -1 * conj_multiplier * xhermc->imag; + xcR = alphaR * xcR; + xcI = alphaR * xcI; + interR = xcR * xhermcR - xcI * xhermcI; + interI = xcR * xhermcI + xcI * xhermcR; + cc = c + (i + 1)*rs_ct + (i + 0)*cs_ct; - cc->real += alphaR * ( (xcR * xhermcR) - (xcI * xhermcI) ); - cc->imag += alphaR * ( (xcR * xhermcI) + (xcI * xhermcR) ); + cc->real += interR; + cc->imag += interI; xc = x + (i + 3)*incx; xcR = xc->real; @@ -676,9 +718,14 @@ void bli_zher_zen_int_var2 xhermcR = xhermc->real; xhermcI = -1 * conj_multiplier * xhermc->imag; + xcR = alphaR * xcR; + xcI = alphaR * xcI; + interR = xcR * xhermcR - xcI * xhermcI; + interI = xcR * xhermcI + xcI * xhermcR; + cc = c + (i + 3)*rs_ct + (i + 2)*cs_ct; - cc->real += alphaR * ( (xcR * xhermcR) - (xcI * xhermcI) ); - cc->imag += alphaR * ( (xcR * xhermcI) + (xcI * xhermcR) ); + cc->real += interR; + cc->imag += interI; // Solving the 2x2 square tile inside the triangular block // using intrinsics @@ -949,10 +996,15 @@ void bli_zher_zen_int_var2 xhermcR = xhermc->real; xhermcI = -1 * conj_multiplier * xhermc->imag; + xcR = alphaR * xcR; + xcI = alphaR * xcI; + interR = xcR * xhermcR - xcI * xhermcI; + interI = xcR * xhermcI + xcI * xhermcR; + // c + ((alpha * x) * xherm) cc = c + (j)*rs_ct + (i)*cs_ct; - cc->real += (alphaR * ((xcR * xhermcR) - (xcI * xhermcI))); - cc->imag += (alphaR * ((xcR * xhermcI) + (xcI * xhermcR))); + cc->real += interR; + cc->imag += interI; xc = x + j*incx; xcR = xc->real; @@ -962,10 +1014,15 @@ void bli_zher_zen_int_var2 xhermcR = xhermc->real; xhermcI = -1 * conj_multiplier * xhermc->imag; + xcR = alphaR * xcR; + xcI = alphaR * xcI; + interR = xcR * xhermcR - xcI * xhermcI; + interI = xcR * xhermcI + xcI * xhermcR; + // c + ((alpha * x) * xherm) cc = c + (j)*rs_ct + (i + 1)*cs_ct; - cc->real += (alphaR * ((xcR * xhermcR) - (xcI * xhermcI))); - cc->imag += (alphaR * ((xcR * xhermcI) + (xcI * xhermcR))); + cc->real += interR; + cc->imag += interI; xc = x + j*incx; xcR = xc->real; @@ -975,10 +1032,15 @@ void bli_zher_zen_int_var2 xhermcR = xhermc->real; xhermcI = -1 * conj_multiplier * xhermc->imag; + xcR = alphaR * xcR; + xcI = alphaR * xcI; + interR = xcR * xhermcR - xcI * xhermcI; + interI = xcR * xhermcI + xcI * xhermcR; + // c + ((alpha * x) * xherm) cc = c + (j)*rs_ct + (i + 2)*cs_ct; - cc->real += (alphaR * ((xcR * xhermcR) - (xcI * xhermcI))); - cc->imag += (alphaR * ((xcR * xhermcI) + (xcI * xhermcR))); + cc->real += interR; + cc->imag += interI; xc = x + j*incx; xcR = xc->real; @@ -988,10 +1050,15 @@ void bli_zher_zen_int_var2 xhermcR = xhermc->real; xhermcI = -1 * conj_multiplier * xhermc->imag; + xcR = alphaR * xcR; + xcI = alphaR * xcI; + interR = xcR * xhermcR - xcI * xhermcI; + interI = xcR * xhermcI + xcI * xhermcR; + // c + ((alpha * x) * xherm) cc = c + (j)*rs_ct + (i + 3)*cs_ct; - cc->real += (alphaR * ((xcR * xhermcR) - (xcI * xhermcI))); - cc->imag += (alphaR * ((xcR * xhermcI) + (xcI * xhermcR))); + cc->real += interR; + cc->imag += interI; } } @@ -1008,9 +1075,14 @@ void bli_zher_zen_int_var2 xhermcR = xhermc->real; xhermcI = -1 * conj_multiplier * xhermc->imag; + xcR = alphaR * xcR; + xcI = alphaR * xcI; + interR = xcR * xhermcR - 
xcI * xhermcI; + interI = xcR * xhermcI + xcI * xhermcR; + cc = c + (i + 1)*rs_ct + i*cs_ct; - cc->real += alphaR * ( (xcR * xhermcR) - (xcI * xhermcI) ); - cc->imag += alphaR * ( (xcR * xhermcI) + (xcI * xhermcR) ); + cc->real += interR; + cc->imag += interI; // Solving the remaining elements in square block // using scalar code @@ -1024,10 +1096,15 @@ void bli_zher_zen_int_var2 xhermcR = xhermc->real; xhermcI = -1 * conj_multiplier * xhermc->imag; + xcR = alphaR * xcR; + xcI = alphaR * xcI; + interR = xcR * xhermcR - xcI * xhermcI; + interI = xcR * xhermcI + xcI * xhermcR; + // c + ((alpha * x) * xherm) cc = c + (j)*rs_ct + (i)*cs_ct; - cc->real += (alphaR * ((xcR * xhermcR) - (xcI * xhermcI))); - cc->imag += (alphaR * ((xcR * xhermcI) + (xcI * xhermcR))); + cc->real += interR; + cc->imag += interI; xc = x + j*incx; xcR = xc->real; @@ -1037,10 +1114,15 @@ void bli_zher_zen_int_var2 xhermcR = xhermc->real; xhermcI = -1 * conj_multiplier * xhermc->imag; + xcR = alphaR * xcR; + xcI = alphaR * xcI; + interR = xcR * xhermcR - xcI * xhermcI; + interI = xcR * xhermcI + xcI * xhermcR; + // c + ((alpha * x) * xherm) cc = c + (j)*rs_ct + (i + 1)*cs_ct; - cc->real += (alphaR * ((xcR * xhermcR) - (xcI * xhermcI))); - cc->imag += (alphaR * ((xcR * xhermcI) + (xcI * xhermcR))); + cc->real += interR; + cc->imag += interI; } } } \ No newline at end of file From 9b28371f45d1724e87a8a04fe3ceb1684eeba95d Mon Sep 17 00:00:00 2001 From: Edward Smyth Date: Fri, 26 Aug 2022 07:04:38 -0400 Subject: [PATCH 201/243] BLIS: BLAS3 quick return functionality Bugfix for http://gerrit-git.amd.com/q/I0ebe9d76465b0e48b2ff5c2f1cc2a75763fe187c changes in frame/compat/bla_gemm.c AMD-Internal: [CPUPL-2373] Change-Id: Ia58b76f060eda204bba68be6730105f4d8db7537 --- frame/compat/bla_gemm.c | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/frame/compat/bla_gemm.c b/frame/compat/bla_gemm.c index bf5cc502a7..78c7f5f4b6 100644 --- a/frame/compat/bla_gemm.c +++ b/frame/compat/bla_gemm.c @@ -399,14 +399,14 @@ void dzgemm_ ); /* Quick return if possible. */ - if ( *m == 0 || *n == 0 || (( PASTEMAC(ch,eq0)( *alpha ) || *k == 0) - && PASTEMAC(ch,eq1)( *beta ) )) + if ( *m == 0 || *n == 0 || (( PASTEMAC(z,eq0)( *alpha ) || *k == 0) + && PASTEMAC(z,eq1)( *beta ) )) { AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); /* Finalize BLIS. */ bli_finalize_auto(); - return; \ + return; } /* Map BLAS chars to their corresponding BLIS enumerated type value. 
*/ From abf848ad1209bf525388c7868ad29784d6f9f51b Mon Sep 17 00:00:00 2001 From: Edward Smyth Date: Mon, 29 Aug 2022 07:21:11 -0400 Subject: [PATCH 202/243] Code cleanup and warnings fixes - Removed some additional compiler warnings reported by GCC 12.1 - Fixed a couple of typos in comments - frame/3/bli_l3_sup.c: routines were returning before final call to AOCL_DTL_TRACE_EXIT - frame/2/gemv/bli_gemv_unf_var1_amd.c: bli_multi_sgemv_4x2 is only defined in header file if BLIS_ENABLE_OPENMP is defined AMD-Internal: [CPUPL-2460] Change-Id: I2eacd5687f2548d8f40c24bd1b930859eefbbcde --- blastest/f2c/rdfmt.c | 12 ++++++++---- frame/2/gemv/bli_gemv_unf_var1_amd.c | 2 ++ frame/3/bli_l3_sup.c | 12 ++++++------ frame/base/bli_init.c | 4 ++-- frame/util/bli_util_update.c | 14 +++++++------- 5 files changed, 25 insertions(+), 19 deletions(-) diff --git a/blastest/f2c/rdfmt.c b/blastest/f2c/rdfmt.c index 6349e3f3fd..0d8a0bf12e 100644 --- a/blastest/f2c/rdfmt.c +++ b/blastest/f2c/rdfmt.c @@ -249,9 +249,13 @@ static int rd_F(ufloat *p, int w, int d, ftnlen len) } while(ch == ' ') { blankdrop: - if (!w--) goto zero; GET(ch); } - while(ch == '0') - { if (!w--) goto zero; GET(ch); } + if (!w--) goto zero; + GET(ch); + } + while(ch == '0') { + if (!w--) goto zero; + GET(ch); + } if (ch == ' ' && f__cblank) goto blankdrop; scale1 = f__scale; @@ -262,7 +266,7 @@ static int rd_F(ufloat *p, int w, int d, ftnlen len) digloop1e: if (!w--) goto done; GET(ch); - } + } if (ch == ' ') { if (f__cblank) { ch = '0'; goto digloop1; } diff --git a/frame/2/gemv/bli_gemv_unf_var1_amd.c b/frame/2/gemv/bli_gemv_unf_var1_amd.c index 3347a133aa..a9534bd9a0 100644 --- a/frame/2/gemv/bli_gemv_unf_var1_amd.c +++ b/frame/2/gemv/bli_gemv_unf_var1_amd.c @@ -519,6 +519,7 @@ void bli_sgemv_unf_var1 if ( ( nt_max > 1 ) & ( is_omp_mt_enabled == TRUE ) ) { +#ifdef BLIS_ENABLE_OPENMP b_fuse = 4; //Setting the thread count to the maximum number of threads provided @@ -544,6 +545,7 @@ void bli_sgemv_unf_var1 cntx, nt ); +#endif// BLIS_ENABLE_OPENMP } else { diff --git a/frame/3/bli_l3_sup.c b/frame/3/bli_l3_sup.c index d23df8c1e5..867ccd200c 100644 --- a/frame/3/bli_l3_sup.c +++ b/frame/3/bli_l3_sup.c @@ -152,8 +152,7 @@ err_t bli_gemmsup // Query the small/unpacked handler from the context and invoke it. gemmsup_oft gemmsup_fp = bli_cntx_get_l3_sup_handler( BLIS_GEMM, cntx ); - return - gemmsup_fp + err_t ret_gemmsup_fp = gemmsup_fp ( alpha, a, @@ -165,6 +164,7 @@ err_t bli_gemmsup ); AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2); + return ret_gemmsup_fp; } err_t bli_gemmtsup @@ -285,8 +285,7 @@ printf( "dims: %d %d %d (threshs: %d %d %d)\n", // Query the small/unpacked handler from the context and invoke it. gemmtsup_oft gemmtsup_fp = bli_cntx_get_l3_sup_handler( BLIS_GEMMT, cntx ); - return - gemmtsup_fp + err_t ret_gemmtsup_fp = gemmtsup_fp ( alpha, a, @@ -298,6 +297,7 @@ printf( "dims: %d %d %d (threshs: %d %d %d)\n", ); AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2); + return ret_gemmtsup_fp; } err_t bli_syrksup @@ -414,8 +414,7 @@ printf( "dims: %d %d %d (threshs: %d %d %d)\n", // Query the small/unpacked handler from the context and invoke it. 
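/*
 * Illustrative sketch of the return-value pattern applied by this and the
 * matching gemmsup/gemmtsup hunks in frame/3/bli_l3_sup.c: the handler's
 * status is captured in a local so that the DTL trace-exit macro still runs
 * before the routine returns, instead of being skipped by a direct
 * `return handler(...)`. The handler and macro names are taken from the
 * surrounding hunk; the middle arguments of the call are elided by the diff
 * context and are filled in here from the bli_gemmsup object interface, so
 * treat that argument list as an assumption.
 */
gemmsup_oft gemmsup_fp = bli_cntx_get_l3_sup_handler( BLIS_GEMM, cntx );

/* Capture the handler's status rather than returning the call expression. */
err_t ret_gemmsup_fp = gemmsup_fp( alpha, a, b, beta, c, cntx, rntm );

/* The trace-exit now always executes on this path before returning. */
AOCL_DTL_TRACE_EXIT( AOCL_DTL_LEVEL_TRACE_2 );
return ret_gemmsup_fp;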
gemmtsup_oft gemmtsup_fp = bli_cntx_get_l3_sup_handler( BLIS_GEMMT, cntx ); - return - gemmtsup_fp + err_t ret_gemmtsup_fp = gemmtsup_fp ( alpha, a, @@ -427,4 +426,5 @@ printf( "dims: %d %d %d (threshs: %d %d %d)\n", ); AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_2); + return ret_gemmtsup_fp; } diff --git a/frame/base/bli_init.c b/frame/base/bli_init.c index 1207058f12..b037fbd217 100644 --- a/frame/base/bli_init.c +++ b/frame/base/bli_init.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -74,7 +74,7 @@ void bli_finalize_auto( void ) void bli_init_apis( void ) { - /* Initialzie DTL Libary with trace level set by the user */ + /* Initialize DTL Library with trace level set by the user */ AOCL_DTL_INITIALIZE(AOCL_DTL_TRACE_LEVEL); // Initialize various sub-APIs. bli_gks_init(); diff --git a/frame/util/bli_util_update.c b/frame/util/bli_util_update.c index b57c065721..6bcd31dff2 100644 --- a/frame/util/bli_util_update.c +++ b/frame/util/bli_util_update.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2020 - 21, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -212,12 +212,12 @@ void PASTEMAC(ch, varname) \ c[m*rs_c + n].imag = ct[m*rs_ct + n].imag; \ } \ \ - for(; m < m_cur; m++) \ - for(n = 0; n < n_cur; n++) \ - { \ - c[m*rs_c + n].real = ct[m*rs_ct + n].real; \ - c[m*rs_c + n].imag = ct[m*rs_ct + n].imag; \ - } \ + for(; m < m_cur; m++) \ + for(n = 0; n < n_cur; n++) \ + { \ + c[m*rs_c + n].real = ct[m*rs_ct + n].real; \ + c[m*rs_c + n].imag = ct[m*rs_ct + n].imag; \ + } \ } \ \ return; \ From 40c71dd2e1584a0e4a13ff613f7862b409e4f8c2 Mon Sep 17 00:00:00 2001 From: Dipal M Zambare Date: Tue, 30 Aug 2022 10:08:37 +0530 Subject: [PATCH 203/243] Revert "CBLAS/BLAS interface decoupling for swap api" This reverts commit 2beaa6a0e64feaa8352871642a88c5c74e3f399d. Reverting it as it is planned for the next release. 
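[Editorial aside, not part of the commit message: the interface decoupling being
reverted here splits each affected BLAS routine into a thin Fortran-callable
symbol plus a *_blis_impl worker that the CBLAS layer can call directly, as the
bla_swap and cblas_f77.h diffs below show. The sketch that follows is only a
simplified illustration of that pattern, not the actual BLIS macro expansion;
the f77_int typedef and the positive-increment-only loop are assumptions made
for brevity. The revert collapses the shim back, so dswap_ again holds the
implementation and F77_dswap maps back to dswap_.]

    /* Hypothetical sketch of the decoupled pattern that this commit reverts. */
    typedef long f77_int;                /* stand-in for BLIS's f77_int type */

    /* Worker routine: does the actual swap; CBLAS calls this directly. */
    void dswap_blis_impl( const f77_int* n,
                          double* x, const f77_int* incx,
                          double* y, const f77_int* incy )
    {
        /* Positive increments only, for brevity. */
        for ( f77_int i = 0; i < *n; ++i )
        {
            double t       = x[ i * *incx ];
            x[ i * *incx ] = y[ i * *incy ];
            y[ i * *incy ] = t;
        }
    }

    /* BLAS entry point: forwards to the worker so both interfaces share code. */
    void dswap_( const f77_int* n,
                 double* x, const f77_int* incx,
                 double* y, const f77_int* incy )
    {
        dswap_blis_impl( n, x, incx, y, incy );
    }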
Change-Id: Ib9271acd0b5b4cfd10c8f8b7bbb6ef93a3d594ea --- frame/compat/bla_gemv.c | 4 +- frame/compat/bla_swap.c | 12 +- frame/compat/bla_swap.h | 7 - frame/compat/bla_swap_amd.c | 36 +---- frame/compat/cblas/src/cblas_f77.h | 8 +- frame/include/bli_macro_defs.h | 2 +- frame/util/bli_util_api_wrap.c | 242 ++++++++++++++--------------- 7 files changed, 133 insertions(+), 178 deletions(-) diff --git a/frame/compat/bla_gemv.c b/frame/compat/bla_gemv.c index 12fda66605..f5a314331a 100644 --- a/frame/compat/bla_gemv.c +++ b/frame/compat/bla_gemv.c @@ -141,7 +141,7 @@ void PASTEF77S(ch,blasname) \ bli_finalize_auto(); \ }\ \ -void PASTEF77(ch,blasname) \ +void PASTEF77S(ch,blasname) \ ( \ const f77_char* transa, \ const f77_int* m, \ @@ -153,7 +153,7 @@ void PASTEF77(ch,blasname) \ ftype* y, const f77_int* incy \ ) \ { \ - PASTEF77S(ch,blasname) \ + PASTEF77(ch,blasname) \ ( transa, m, n, alpha, a, lda, x, incx, beta, y, incy ); \ } diff --git a/frame/compat/bla_swap.c b/frame/compat/bla_swap.c index 67d58c8cbd..d653426478 100644 --- a/frame/compat/bla_swap.c +++ b/frame/compat/bla_swap.c @@ -42,7 +42,7 @@ #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ -void PASTEF77S(ch,blasname) \ +void PASTEF77(ch,blasname) \ ( \ const f77_int* n, \ ftype* x, const f77_int* incx, \ @@ -80,16 +80,6 @@ void PASTEF77S(ch,blasname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ /* Finalize BLIS. */ \ bli_finalize_auto(); \ -}\ -\ -void PASTEF77(ch,blasname) \ - ( \ - const f77_int* n, \ - ftype* x, const f77_int* incx, \ - ftype* y, const f77_int* incy \ - ) \ -{ \ - PASTEF77S(ch,blasname)( n, x, incx, y, incy ); \ } #ifdef BLIS_ENABLE_BLAS diff --git a/frame/compat/bla_swap.h b/frame/compat/bla_swap.h index ccb31688c7..54c0613a92 100644 --- a/frame/compat/bla_swap.h +++ b/frame/compat/bla_swap.h @@ -40,13 +40,6 @@ #define GENTPROT( ftype, ch, blasname ) \ \ BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ - ( \ - const f77_int* n, \ - ftype* x, const f77_int* incx, \ - ftype* y, const f77_int* incy \ - ); \ -\ -BLIS_EXPORT_BLAS void PASTEF77S(ch,blasname) \ ( \ const f77_int* n, \ ftype* x, const f77_int* incx, \ diff --git a/frame/compat/bla_swap_amd.c b/frame/compat/bla_swap_amd.c index 77e8afcca0..617c78a4aa 100644 --- a/frame/compat/bla_swap_amd.c +++ b/frame/compat/bla_swap_amd.c @@ -42,7 +42,7 @@ #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ -void PASTEF77S(ch,blasname) \ +void PASTEF77(ch,blasname) \ ( \ const f77_int* n, \ ftype* x, const f77_int* incx, \ @@ -80,21 +80,11 @@ void PASTEF77S(ch,blasname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ /* Finalize BLIS. 
*/ \ bli_finalize_auto(); \ -}\ -\ -void PASTEF77(ch,blasname) \ - ( \ - const f77_int* n, \ - ftype* x, const f77_int* incx, \ - ftype* y, const f77_int* incy \ - ) \ -{ \ - PASTEF77S(ch,blasname)( n, x, incx, y, incy ); \ } #ifdef BLIS_ENABLE_BLAS -void sswap_blis_impl +void sswap_ ( const f77_int* n, float* x, const f77_int* incx, @@ -182,17 +172,8 @@ void sswap_blis_impl // bli_finalize_auto(); AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) } -void sswap_ - ( - const f77_int* n, - float* x, const f77_int* incx, - float* y, const f77_int* incy - ) -{ - sswap_blis_impl( n, x, incx, y, incy ); -} - -void dswap_blis_impl + +void dswap_ ( const f77_int* n, double* x, const f77_int* incx, @@ -280,15 +261,6 @@ void dswap_blis_impl // bli_finalize_auto(); AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) } -void dswap_ - ( - const f77_int* n, - double* x, const f77_int* incx, - double* y, const f77_int* incy - ) -{ - dswap_blis_impl( n, x, incx, y, incy ); -} INSERT_GENTFUNC_BLAS_CZ( swap, swapv ) diff --git a/frame/compat/cblas/src/cblas_f77.h b/frame/compat/cblas/src/cblas_f77.h index 8e43c05d17..864d78895e 100644 --- a/frame/compat/cblas/src/cblas_f77.h +++ b/frame/compat/cblas/src/cblas_f77.h @@ -215,19 +215,19 @@ #define F77_drotmg drotmg_blis_impl #define F77_drot drot_blis_impl #define F77_drotm drotm_blis_impl -#define F77_sswap sswap_blis_impl +#define F77_sswap sswap_ #define F77_scopy scopy_blis_impl #define F77_saxpy saxpy_blis_impl #define F77_isamax_sub isamaxsub_blis_impl -#define F77_dswap dswap_blis_impl +#define F77_dswap dswap_ #define F77_dcopy dcopy_blis_impl #define F77_daxpy daxpy_blis_impl #define F77_idamax_sub idamaxsub_blis_impl -#define F77_cswap cswap_blis_impl +#define F77_cswap cswap_ #define F77_ccopy ccopy_blis_impl #define F77_caxpy caxpy_blis_impl #define F77_icamax_sub icamaxsub_blis_impl -#define F77_zswap zswap_blis_impl +#define F77_zswap zswap_ #define F77_zcopy zcopy_blis_impl #define F77_zaxpy zaxpy_blis_impl #define F77_izamax_sub izamaxsub_blis_impl diff --git a/frame/include/bli_macro_defs.h b/frame/include/bli_macro_defs.h index 99d1de6180..9ab7c00aa7 100644 --- a/frame/include/bli_macro_defs.h +++ b/frame/include/bli_macro_defs.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018-2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/frame/util/bli_util_api_wrap.c b/frame/util/bli_util_api_wrap.c index 6e500b0b88..26810531d3 100644 --- a/frame/util/bli_util_api_wrap.c +++ b/frame/util/bli_util_api_wrap.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021-2022, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -210,17 +210,17 @@ void CGBMV_(const char *trans,const f77_int *m,const f77_int *n,const f77_int void CGEMM(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) { - cgemm_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + cgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void cgemm(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) { - cgemm_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + cgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void CGEMM_(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) { - cgemm_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + cgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void CGEMV(const char *trans,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy) @@ -285,17 +285,17 @@ void CHBMV_(const char *uplo,const f77_int *n,const f77_int *k,const scomplex void CHEMM(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) { - chemm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + chemm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void chemm(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) { - chemm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + chemm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void CHEMM_(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) { - chemm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + chemm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void CHEMV(const char *uplo,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy) @@ -345,32 +345,32 @@ void CHER2_(const char *uplo,const f77_int *n,const scomplex *alpha,const sco void CHER2K(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const float *beta,scomplex *c,const f77_int *ldc) { - cher2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + cher2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, 
beta, c, ldc); } void cher2k(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const float *beta,scomplex *c,const f77_int *ldc) { - cher2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + cher2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void CHER2K_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const float *beta,scomplex *c,const f77_int *ldc) { - cher2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + cher2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void CHERK(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const float *alpha,const scomplex *a,const f77_int *lda,const float *beta,scomplex *c,const f77_int *ldc) { - cherk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + cherk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void cherk(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const float *alpha,const scomplex *a,const f77_int *lda,const float *beta,scomplex *c,const f77_int *ldc) { - cherk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + cherk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void CHERK_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const float *alpha,const scomplex *a,const f77_int *lda,const float *beta,scomplex *c,const f77_int *ldc) { - cherk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + cherk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void CHPMV(const char *uplo,const f77_int *n,const scomplex *alpha,const scomplex *ap,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy) @@ -480,62 +480,62 @@ void CSSCAL_(const f77_int *n,const float *sa,scomplex *cx,const f77_int *incx void CSWAP(const f77_int *n,scomplex *cx,const f77_int *incx,scomplex *cy,const f77_int *incy) { - cswap_blis_impl( n, cx, incx, cy, incy); + cswap_( n, cx, incx, cy, incy); } void cswap(const f77_int *n,scomplex *cx,const f77_int *incx,scomplex *cy,const f77_int *incy) { - cswap_blis_impl( n, cx, incx, cy, incy); + cswap_( n, cx, incx, cy, incy); } void CSWAP_(const f77_int *n,scomplex *cx,const f77_int *incx,scomplex *cy,const f77_int *incy) { - cswap_blis_impl( n, cx, incx, cy, incy); + cswap_( n, cx, incx, cy, incy); } void CSYMM(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) { - csymm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + csymm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void csymm(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) { - csymm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + csymm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void CSYMM_(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) { - csymm_( side, uplo, m, n, alpha, a, lda, b, 
ldb, beta, c, ldc); + csymm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void CSYR2K(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) { - csyr2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + csyr2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void csyr2k(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) { - csyr2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + csyr2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void CSYR2K_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) { - csyr2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + csyr2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void CSYRK(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *beta,scomplex *c,const f77_int *ldc) { - csyrk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + csyrk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void csyrk(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *beta,scomplex *c,const f77_int *ldc) { - csyrk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + csyrk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void CSYRK_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *beta,scomplex *c,const f77_int *ldc) { - csyrk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + csyrk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void CTBMV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const scomplex *a,const f77_int *lda,scomplex *x,const f77_int *incx) @@ -600,17 +600,17 @@ void CTPSV_(const char *uplo,const char *trans,const char *diag,const f77_ void CTRMM(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,scomplex *b,const f77_int *ldb) { - ctrmm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + ctrmm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void ctrmm(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,scomplex *b,const f77_int *ldb) { - ctrmm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + ctrmm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void CTRMM_(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,scomplex *b,const f77_int *ldb) { - ctrmm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + ctrmm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void CTRMV(const char 
*uplo,const char *trans,const char *diag,const f77_int *n,const scomplex *a,const f77_int *lda,scomplex *x,const f77_int *incx) @@ -630,17 +630,17 @@ void CTRMV_(const char *uplo,const char *trans,const char *diag,const f77_ void CTRSM(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,scomplex *b,const f77_int *ldb) { - ctrsm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + ctrsm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void ctrsm(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,scomplex *b,const f77_int *ldb) { - ctrsm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + ctrsm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void CTRSM_(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,scomplex *b,const f77_int *ldb) { - ctrsm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + ctrsm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void CTRSV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const scomplex *a,const f77_int *lda,scomplex *x,const f77_int *incx) @@ -750,17 +750,17 @@ void DGBMV_(const char *trans,const f77_int *m,const f77_int *n,const f77_int void DGEMM(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *b,const f77_int *ldb,const double *beta,double *c,const f77_int *ldc) { - dgemm_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + dgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void dgemm(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *b,const f77_int *ldb,const double *beta,double *c,const f77_int *ldc) { - dgemm_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + dgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void DGEMM_(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *b,const f77_int *ldb,const double *beta,double *c,const f77_int *ldc) { - dgemm_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + dgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void DGEMV(const char *trans,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) @@ -960,32 +960,32 @@ void DSPR2_(const char *uplo,const f77_int *n,const double *alpha,const double void DSWAP(const f77_int *n,double *dx,const f77_int *incx,double *dy,const f77_int *incy) { - dswap_blis_impl( n, dx, incx, dy, incy); + dswap_( n, dx, incx, dy, incy); } void dswap(const f77_int *n,double *dx,const f77_int *incx,double *dy,const f77_int *incy) { - dswap_blis_impl( n, dx, incx, dy, incy); + dswap_( n, dx, incx, dy, incy); } void DSWAP_(const f77_int *n,double *dx,const f77_int *incx,double *dy,const f77_int *incy) { - dswap_blis_impl( n, dx, incx, dy, incy); + dswap_( n, dx, incx, dy, incy); } void 
DSYMM(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,const double *b,const f77_int *ldb,const double *beta,double *c,const f77_int *ldc) { - dsymm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + dsymm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void dsymm(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,const double *b,const f77_int *ldb,const double *beta,double *c,const f77_int *ldc) { - dsymm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + dsymm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void DSYMM_(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,const double *b,const f77_int *ldb,const double *beta,double *c,const f77_int *ldc) { - dsymm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + dsymm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void DSYMV(const char *uplo,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) @@ -1035,32 +1035,32 @@ void DSYR2_(const char *uplo,const f77_int *n,const double *alpha,const double void DSYR2K(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *b,const f77_int *ldb,const double *beta,double *c,const f77_int *ldc) { - dsyr2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + dsyr2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void dsyr2k(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *b,const f77_int *ldb,const double *beta,double *c,const f77_int *ldc) { - dsyr2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + dsyr2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void DSYR2K_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *b,const f77_int *ldb,const double *beta,double *c,const f77_int *ldc) { - dsyr2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + dsyr2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void DSYRK(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *beta,double *c,const f77_int *ldc) { - dsyrk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + dsyrk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void dsyrk(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *beta,double *c,const f77_int *ldc) { - dsyrk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + dsyrk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void DSYRK_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *beta,double *c,const f77_int *ldc) { - dsyrk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + dsyrk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void DTBMV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const double *a,const f77_int 
*lda,double *x,const f77_int *incx) @@ -1125,17 +1125,17 @@ void DTPSV_(const char *uplo,const char *trans,const char *diag,const f77_ void DTRMM(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,double *b,const f77_int *ldb) { - dtrmm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + dtrmm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void dtrmm(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,double *b,const f77_int *ldb) { - dtrmm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + dtrmm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void DTRMM_(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,double *b,const f77_int *ldb) { - dtrmm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + dtrmm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void DTRMV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const double *a,const f77_int *lda,double *x,const f77_int *incx) @@ -1155,17 +1155,17 @@ void DTRMV_(const char *uplo,const char *trans,const char *diag,const f77_ void DTRSM(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,double *b,const f77_int *ldb) { - dtrsm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + dtrsm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void dtrsm(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,double *b,const f77_int *ldb) { - dtrsm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + dtrsm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void DTRSM_(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,double *b,const f77_int *ldb) { - dtrsm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + dtrsm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void DTRSV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const double *a,const f77_int *lda,double *x,const f77_int *incx) @@ -1417,17 +1417,17 @@ void SGBMV_(const char *trans,const f77_int *m,const f77_int *n,const f77_int void SGEMM(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *b,const f77_int *ldb,const float *beta,float *c,const f77_int *ldc) { - sgemm_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + sgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void sgemm(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *b,const f77_int *ldb,const float *beta,float *c,const f77_int *ldc) { - sgemm_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + sgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void SGEMM_(const char *transa,const char *transb,const f77_int 
*m,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *b,const f77_int *ldb,const float *beta,float *c,const f77_int *ldc) { - sgemm_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + sgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void SGEMV(const char *trans,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) @@ -1614,32 +1614,32 @@ void SSPR2_(const char *uplo,const f77_int *n,const float *alpha,const float void SSWAP(const f77_int *n,float *sx,const f77_int *incx,float *sy,const f77_int *incy) { - sswap_blis_impl( n, sx, incx, sy, incy); + sswap_( n, sx, incx, sy, incy); } void sswap(const f77_int *n,float *sx,const f77_int *incx,float *sy,const f77_int *incy) { - sswap_blis_impl( n, sx, incx, sy, incy); + sswap_( n, sx, incx, sy, incy); } void SSWAP_(const f77_int *n,float *sx,const f77_int *incx,float *sy,const f77_int *incy) { - sswap_blis_impl( n, sx, incx, sy, incy); + sswap_( n, sx, incx, sy, incy); } void SSYMM(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,const float *b,const f77_int *ldb,const float *beta,float *c,const f77_int *ldc) { - ssymm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + ssymm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void ssymm(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,const float *b,const f77_int *ldb,const float *beta,float *c,const f77_int *ldc) { - ssymm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + ssymm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void SSYMM_(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,const float *b,const f77_int *ldb,const float *beta,float *c,const f77_int *ldc) { - ssymm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + ssymm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void SSYMV(const char *uplo,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) @@ -1689,32 +1689,32 @@ void SSYR2_(const char *uplo,const f77_int *n,const float *alpha,const float void SSYR2K(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *b,const f77_int *ldb,const float *beta,float *c,const f77_int *ldc) { - ssyr2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + ssyr2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void ssyr2k(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *b,const f77_int *ldb,const float *beta,float *c,const f77_int *ldc) { - ssyr2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + ssyr2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void SSYR2K_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *b,const f77_int *ldb,const float *beta,float *c,const f77_int *ldc) { - ssyr2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + ssyr2k_blis_impl( uplo, trans, n, k, 
alpha, a, lda, b, ldb, beta, c, ldc); } void SSYRK(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *beta,float *c,const f77_int *ldc) { - ssyrk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + ssyrk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void ssyrk(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *beta,float *c,const f77_int *ldc) { - ssyrk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + ssyrk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void SSYRK_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *beta,float *c,const f77_int *ldc) { - ssyrk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + ssyrk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void STBMV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const float *a,const f77_int *lda,float *x,const f77_int *incx) @@ -1779,17 +1779,17 @@ void STPSV_(const char *uplo,const char *trans,const char *diag,const f77_ void STRMM(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,float *b,const f77_int *ldb) { - strmm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + strmm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void strmm(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,float *b,const f77_int *ldb) { - strmm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + strmm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void STRMM_(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,float *b,const f77_int *ldb) { - strmm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + strmm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void STRMV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const float *a,const f77_int *lda,float *x,const f77_int *incx) @@ -1809,17 +1809,17 @@ void STRMV_(const char *uplo,const char *trans,const char *diag,const f77_ void STRSM(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,float *b,const f77_int *ldb) { - strsm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + strsm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void strsm(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,float *b,const f77_int *ldb) { - strsm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + strsm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void STRSM_(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,float *b,const f77_int *ldb) { - strsm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + strsm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void STRSV(const char 
*uplo,const char *trans,const char *diag,const f77_int *n,const float *a,const f77_int *lda,float *x,const f77_int *incx) @@ -1929,17 +1929,17 @@ void ZGBMV_(const char *trans,const f77_int *m,const f77_int *n,const f77_int void ZGEMM(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) { - zgemm_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + zgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void zgemm(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) { - zgemm_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + zgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void ZGEMM_(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) { - zgemm_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + zgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void ZGEMV(const char *trans,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) @@ -2004,17 +2004,17 @@ void ZHBMV_(const char *uplo,const f77_int *n,const f77_int *k,const dcomplex void ZHEMM(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) { - zhemm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + zhemm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void zhemm(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) { - zhemm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + zhemm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void ZHEMM_(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) { - zhemm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + zhemm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void ZHEMV(const char *uplo,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) @@ -2064,32 +2064,32 @@ void ZHER2_(const char *uplo,const f77_int *n,const dcomplex *alpha,const dcom void ZHER2K(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const double *beta,dcomplex *c,const f77_int *ldc) { - zher2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + zher2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, 
ldc); } void zher2k(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const double *beta,dcomplex *c,const f77_int *ldc) { - zher2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + zher2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void ZHER2K_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const double *beta,dcomplex *c,const f77_int *ldc) { - zher2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + zher2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void ZHERK(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const double *alpha,const dcomplex *a,const f77_int *lda,const double *beta,dcomplex *c,const f77_int *ldc) { - zherk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + zherk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void zherk(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const double *alpha,const dcomplex *a,const f77_int *lda,const double *beta,dcomplex *c,const f77_int *ldc) { - zherk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + zherk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void ZHERK_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const double *alpha,const dcomplex *a,const f77_int *lda,const double *beta,dcomplex *c,const f77_int *ldc) { - zherk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + zherk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void ZHPMV(const char *uplo,const f77_int *n,const dcomplex *alpha,const dcomplex *ap,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) @@ -2169,62 +2169,62 @@ void ZSCAL_(const f77_int *n,const dcomplex *za,dcomplex *zx,const f77_int *incx void ZSWAP(const f77_int *n,dcomplex *zx,const f77_int *incx,dcomplex *zy,const f77_int *incy) { - zswap_blis_impl( n, zx, incx, zy, incy); + zswap_( n, zx, incx, zy, incy); } void zswap(const f77_int *n,dcomplex *zx,const f77_int *incx,dcomplex *zy,const f77_int *incy) { - zswap_blis_impl( n, zx, incx, zy, incy); + zswap_( n, zx, incx, zy, incy); } void ZSWAP_(const f77_int *n,dcomplex *zx,const f77_int *incx,dcomplex *zy,const f77_int *incy) { - zswap_blis_impl( n, zx, incx, zy, incy); + zswap_( n, zx, incx, zy, incy); } void ZSYMM(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) { - zsymm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + zsymm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void zsymm(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) { - zsymm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + zsymm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void ZSYMM_(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) { - zsymm_( side, uplo, m, n, alpha, a, lda, 
b, ldb, beta, c, ldc); + zsymm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void ZSYR2K(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) { - zsyr2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + zsyr2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void zsyr2k(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) { - zsyr2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + zsyr2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void ZSYR2K_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) { - zsyr2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + zsyr2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void ZSYRK(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *beta,dcomplex *c,const f77_int *ldc) { - zsyrk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + zsyrk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void zsyrk(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *beta,dcomplex *c,const f77_int *ldc) { - zsyrk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + zsyrk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void ZSYRK_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *beta,dcomplex *c,const f77_int *ldc) { - zsyrk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + zsyrk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void ZTBMV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const dcomplex *a,const f77_int *lda,dcomplex *x,const f77_int *incx) @@ -2289,17 +2289,17 @@ void ZTPSV_(const char *uplo,const char *trans,const char *diag,const f77_ void ZTRMM(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,dcomplex *b,const f77_int *ldb) { - ztrmm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + ztrmm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void ztrmm(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,dcomplex *b,const f77_int *ldb) { - ztrmm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + ztrmm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void ZTRMM_(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,dcomplex *b,const f77_int *ldb) { - ztrmm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + ztrmm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void ZTRMV(const char 
*uplo,const char *trans,const char *diag,const f77_int *n,const dcomplex *a,const f77_int *lda,dcomplex *x,const f77_int *incx) @@ -2319,17 +2319,17 @@ void ZTRMV_(const char *uplo,const char *trans,const char *diag,const f77_ void ZTRSM(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,dcomplex *b,const f77_int *ldb) { - ztrsm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + ztrsm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void ztrsm(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,dcomplex *b,const f77_int *ldb) { - ztrsm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + ztrsm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void ZTRSM_(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,dcomplex *b,const f77_int *ldb) { - ztrsm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + ztrsm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void ZTRSV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const dcomplex *a,const f77_int *lda,dcomplex *x,const f77_int *incx) @@ -2380,17 +2380,17 @@ void CDOTUSUB_( const f77_int* n, const scomplex* x,const f77_int* incxy, const void CGEMM3M( const f77_char* transa, const f77_char* transb, const f77_int* m, const f77_int* n, const f77_int* k, const scomplex* alpha, const scomplex* a, const f77_int* lda, const scomplex* b, const f77_int* ldb, const scomplex* beta, scomplex* c, const f77_int* ldc) { - cgemm3m_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + cgemm3m_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void cgemm3m( const f77_char* transa, const f77_char* transb, const f77_int* m, const f77_int* n, const f77_int* k, const scomplex* alpha, const scomplex* a, const f77_int* lda, const scomplex* b, const f77_int* ldb, const scomplex* beta, scomplex* c, const f77_int* ldc) { - cgemm3m_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + cgemm3m_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void CGEMM3M_( const f77_char* transa, const f77_char* transb, const f77_int* m, const f77_int* n, const f77_int* k, const scomplex* alpha, const scomplex* a, const f77_int* lda, const scomplex* b, const f77_int* ldb, const scomplex* beta, scomplex* c, const f77_int* ldc) { - cgemm3m_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + cgemm3m_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void CGEMM_BATCH( const f77_char* transa_array, const f77_char* transb_array,const f77_int *m_array, const f77_int *n_array, const f77_int *k_array,const scomplex* alpha_array, const scomplex** a_array, const f77_int *lda_array, const scomplex** b_array, const f77_int *ldb_array, const scomplex* beta_array, scomplex** c_array, const f77_int *ldc_array, const f77_int* group_count, const f77_int *group_size) @@ -2410,17 +2410,17 @@ void CGEMM_BATCH_( const f77_char* transa_array, const f77_char* transb_array,co void CGEMMT( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const scomplex* alpha, const scomplex* a, const f77_int* lda, const scomplex* 
b, const f77_int* ldb, const scomplex* beta, scomplex* c, const f77_int* ldc) { - cgemmt_( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + cgemmt_blis_impl( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void cgemmt( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const scomplex* alpha, const scomplex* a, const f77_int* lda, const scomplex* b, const f77_int* ldb, const scomplex* beta, scomplex* c, const f77_int* ldc) { - cgemmt_( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + cgemmt_blis_impl( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void CGEMMT_( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const scomplex* alpha, const scomplex* a, const f77_int* lda, const scomplex* b, const f77_int* ldb, const scomplex* beta, scomplex* c, const f77_int* ldc) { - cgemmt_( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + cgemmt_blis_impl( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void CIMATCOPY(f77_char* trans, f77_int* rows, f77_int* cols, const scomplex* alpha,scomplex* aptr, f77_int* lda, f77_int* ldb) @@ -2545,17 +2545,17 @@ void DGEMM_BATCH_( const f77_char* transa_array, const f77_char* transb_array,co void DGEMMT( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const double* alpha, const double* a, const f77_int* lda, const double* b, const f77_int* ldb, const double* beta, double* c, const f77_int* ldc) { - dgemmt_( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + dgemmt_blis_impl( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void dgemmt( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const double* alpha, const double* a, const f77_int* lda, const double* b, const f77_int* ldb, const double* beta, double* c, const f77_int* ldc) { - dgemmt_( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + dgemmt_blis_impl( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void DGEMMT_( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const double* alpha, const double* a, const f77_int* lda, const double* b, const f77_int* ldb, const double* beta, double* c, const f77_int* ldc) { - dgemmt_( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + dgemmt_blis_impl( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void DNRM2SUB(const f77_int* n, const double* x, const f77_int* incx, double *rval) @@ -2920,17 +2920,17 @@ void SGEMM_BATCH_(const f77_char* transa_array, const f77_char* transb_array,con void SGEMMT( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const float* alpha, const float* a, const f77_int* lda, const float* b, const f77_int* ldb, const float* beta, float* c, const f77_int* ldc) { - sgemmt_( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + sgemmt_blis_impl( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void sgemmt( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const float* alpha, const float* a, const f77_int* lda, const float* b, const f77_int* ldb, const float* beta, float* c, const f77_int* 
ldc) { - sgemmt_( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + sgemmt_blis_impl( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void SGEMMT_( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const float* alpha, const float* a, const f77_int* lda, const float* b, const f77_int* ldb, const float* beta, float* c, const f77_int* ldc) { - sgemmt_( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + sgemmt_blis_impl( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void SIMATCOPY( f77_char* trans, f77_int* rows, f77_int* cols, const float* alpha,float* aptr, f77_int* lda, f77_int* ldb) @@ -3055,17 +3055,17 @@ void ZDOTUSUB_( const f77_int* n, const dcomplex* x, const f77_int* incx,const d void ZGEMM3M( const f77_char* transa, const f77_char* transb, const f77_int* m, const f77_int* n, const f77_int* k, const dcomplex* alpha, const dcomplex* a, const f77_int* lda, const dcomplex* b, const f77_int* ldb, const dcomplex* beta, dcomplex* c, const f77_int* ldc) { - zgemm3m_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + zgemm3m_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void zgemm3m( const f77_char* transa, const f77_char* transb, const f77_int* m, const f77_int* n, const f77_int* k, const dcomplex* alpha, const dcomplex* a, const f77_int* lda, const dcomplex* b, const f77_int* ldb, const dcomplex* beta, dcomplex* c, const f77_int* ldc) { - zgemm3m_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + zgemm3m_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void ZGEMM3M_( const f77_char* transa, const f77_char* transb, const f77_int* m, const f77_int* n, const f77_int* k, const dcomplex* alpha, const dcomplex* a, const f77_int* lda, const dcomplex* b, const f77_int* ldb, const dcomplex* beta, dcomplex* c, const f77_int* ldc) { - zgemm3m_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + zgemm3m_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void ZGEMM_BATCH( const f77_char* transa_array, const f77_char* transb_array,const f77_int *m_array, const f77_int *n_array, const f77_int *k_array,const dcomplex* alpha_array, const dcomplex** a_array, const f77_int *lda_array, const dcomplex** b_array, const f77_int *ldb_array, const dcomplex* beta_array, dcomplex** c_array, const f77_int *ldc_array, const f77_int* group_count, const f77_int *group_size) @@ -3085,17 +3085,17 @@ void ZGEMM_BATCH_( const f77_char* transa_array, const f77_char* transb_array,c void ZGEMMT( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const dcomplex* alpha, const dcomplex* a, const f77_int* lda, const dcomplex* b, const f77_int* ldb, const dcomplex* beta, dcomplex* c, const f77_int* ldc) { - zgemmt_( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + zgemmt_blis_impl( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void zgemmt( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const dcomplex* alpha, const dcomplex* a, const f77_int* lda, const dcomplex* b, const f77_int* ldb, const dcomplex* beta, dcomplex* c, const f77_int* ldc) { - zgemmt_( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + zgemmt_blis_impl( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void 
ZGEMMT_( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const dcomplex* alpha, const dcomplex* a, const f77_int* lda, const dcomplex* b, const f77_int* ldb, const dcomplex* beta, dcomplex* c, const f77_int* ldc) { - zgemmt_( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + zgemmt_blis_impl( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void ZIMATCOPY(f77_char* trans, f77_int* rows, f77_int* cols, const dcomplex* alpha,dcomplex* aptr, f77_int* lda, f77_int* ldb) From ea79efa915af7c2ac1c6c2f3b4a86ec83b6a446d Mon Sep 17 00:00:00 2001 From: Nallani Bhaskar Date: Mon, 29 Aug 2022 23:06:28 +0530 Subject: [PATCH 204/243] Fixed out of bound memory access in sgemmsup zen rv kernels Details: 1. In the sgemmsup_zen_rv_?x2 kernels the "vmovups" instruction is used to load the B matrix in the k loop and the k last loop, which loads 128 bits into the xmm register rather than the expected 64 bits. 2. Changed the vmovups instruction to the vmovsd instruction, which loads only 64 bits into the xmm register. 3. Avoided accessing C memory directly from the vfma instruction when scaling by a non-zero beta at the corner cases, since the memory operand requires a 128-bit access and leads to out-of-bounds reads. Replaced it with a vmovq to load the 64 bits of data first, then performed the vfma on the xmm register in rv_6x8m and rv_6x4m. AMD-Internal: [CPUPL-2472] Change-Id: Iad397f8f5b5cc607b4278b603b1e0ea3f6b082f2 --- .../zen/3/sup/bli_gemmsup_rv_zen_asm_s6x16.c | 60 +++++++++---------- .../zen/3/sup/bli_gemmsup_rv_zen_asm_s6x16m.c | 30 ++++++---- 2 files changed, 50 insertions(+), 40 deletions(-) diff --git a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_s6x16.c b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_s6x16.c index 507ff5a717..7befbb69bb 100644 --- a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_s6x16.c +++ b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_s6x16.c @@ -6853,7 +6853,7 @@ void bli_sgemmsup_rv_zen_asm_6x2 // ---------------------------------- iteration 0 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), xmm2) @@ -6874,7 +6874,7 @@ void bli_sgemmsup_rv_zen_asm_6x2 // ---------------------------------- iteration 1 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), xmm2) @@ -6895,7 +6895,7 @@ void bli_sgemmsup_rv_zen_asm_6x2 // ---------------------------------- iteration 2 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), xmm2) @@ -6916,7 +6916,7 @@ void bli_sgemmsup_rv_zen_asm_6x2 // ---------------------------------- iteration 3 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), xmm2) @@ -6949,7 +6949,7 @@ void bli_sgemmsup_rv_zen_asm_6x2 label(.SLOOPKLEFT) // EDGE LOOP - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), xmm2) @@ -7237,7 +7237,7 @@ void bli_sgemmsup_rv_zen_asm_5x2 label(.SLOOPKITER) // MAIN LOOP // ---------------------------------- iteration 0 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), xmm2) @@ -7256,7 +7256,7 @@ void bli_sgemmsup_rv_zen_asm_5x2 // ---------------------------------- iteration 1 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), xmm2) @@ -7275,7 +7275,7 @@ void bli_sgemmsup_rv_zen_asm_5x2 // ---------------------------------- iteration 2 -
vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), xmm2) @@ -7294,7 +7294,7 @@ void bli_sgemmsup_rv_zen_asm_5x2 // ---------------------------------- iteration 3 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), xmm2) @@ -7325,7 +7325,7 @@ void bli_sgemmsup_rv_zen_asm_5x2 label(.SLOOPKLEFT) // EDGE LOOP - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), xmm2) @@ -7612,7 +7612,7 @@ void bli_sgemmsup_rv_zen_asm_4x2 label(.SLOOPKITER) // MAIN LOOP // ---------------------------------- iteration 0 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), ymm2) @@ -7627,7 +7627,7 @@ void bli_sgemmsup_rv_zen_asm_4x2 vfmadd231ps(xmm0, xmm3, xmm10) // ---------------------------------- iteration 1 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), ymm2) @@ -7642,7 +7642,7 @@ void bli_sgemmsup_rv_zen_asm_4x2 vfmadd231ps(xmm0, xmm3, xmm10) // ---------------------------------- iteration 2 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), ymm2) @@ -7657,7 +7657,7 @@ void bli_sgemmsup_rv_zen_asm_4x2 vfmadd231ps(xmm0, xmm3, xmm10) // ---------------------------------- iteration 3 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), ymm2) @@ -7685,7 +7685,7 @@ void bli_sgemmsup_rv_zen_asm_4x2 label(.SLOOPKLEFT) // EDGE LOOP - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), ymm2) @@ -7940,7 +7940,7 @@ void bli_sgemmsup_rv_zen_asm_3x2 label(.SLOOPKITER) // MAIN LOOP // ---------------------------------- iteration 0 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), ymm2) @@ -7953,7 +7953,7 @@ void bli_sgemmsup_rv_zen_asm_3x2 vfmadd231ps(xmm0, xmm2, xmm8) // ---------------------------------- iteration 1 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), ymm2) @@ -7966,7 +7966,7 @@ void bli_sgemmsup_rv_zen_asm_3x2 vfmadd231ps(xmm0, xmm2, xmm8) // ---------------------------------- iteration 2 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), ymm2) @@ -7979,7 +7979,7 @@ void bli_sgemmsup_rv_zen_asm_3x2 vfmadd231ps(xmm0, xmm2, xmm8) // ---------------------------------- iteration 3 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), ymm2) @@ -8005,7 +8005,7 @@ void bli_sgemmsup_rv_zen_asm_3x2 label(.SLOOPKLEFT) // EDGE LOOP - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), ymm2) @@ -8245,7 +8245,7 @@ void bli_sgemmsup_rv_zen_asm_2x2 // ---------------------------------- iteration 0 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), ymm2) vbroadcastss(mem(rax, r8, 1), ymm3) @@ -8254,7 +8254,7 @@ void bli_sgemmsup_rv_zen_asm_2x2 vfmadd231ps(xmm0, xmm3, xmm6) // ---------------------------------- iteration 1 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), ymm2) 
vbroadcastss(mem(rax, r8, 1), ymm3) @@ -8263,7 +8263,7 @@ void bli_sgemmsup_rv_zen_asm_2x2 vfmadd231ps(xmm0, xmm3, xmm6) // ---------------------------------- iteration 2 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), ymm2) @@ -8273,7 +8273,7 @@ void bli_sgemmsup_rv_zen_asm_2x2 vfmadd231ps(xmm0, xmm3, xmm6) // ---------------------------------- iteration 3 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), ymm2) @@ -8294,7 +8294,7 @@ void bli_sgemmsup_rv_zen_asm_2x2 label(.SLOOPKLEFT) // EDGE LOOP - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), ymm2) @@ -8503,7 +8503,7 @@ void bli_sgemmsup_rv_zen_asm_1x2 label(.SLOOPKITER) // MAIN LOOP // ---------------------------------- iteration 0 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), ymm2) @@ -8511,7 +8511,7 @@ void bli_sgemmsup_rv_zen_asm_1x2 vfmadd231ps(xmm0, xmm2, xmm4) // ---------------------------------- iteration 1 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), ymm2) @@ -8520,7 +8520,7 @@ void bli_sgemmsup_rv_zen_asm_1x2 // ---------------------------------- iteration 2 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), ymm2) @@ -8528,7 +8528,7 @@ void bli_sgemmsup_rv_zen_asm_1x2 vfmadd231ps(xmm0, xmm2, xmm4) // ---------------------------------- iteration 3 - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), ymm2) @@ -8548,7 +8548,7 @@ void bli_sgemmsup_rv_zen_asm_1x2 label(.SLOOPKLEFT) // EDGE LOOP - vmovups(mem(rbx, 0*32), xmm0) + vmovsd(mem(rbx, 0*32), xmm0) add(r10, rbx) // b += rs_b; vbroadcastss(mem(rax ), xmm2) diff --git a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_s6x16m.c b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_s6x16m.c index e6ecd47f47..d5e2135a66 100644 --- a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_s6x16m.c +++ b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_s6x16m.c @@ -1291,13 +1291,17 @@ void bli_sgemmsup_rv_zen_asm_6x8m vextractf128(imm(0x1), ymm0, xmm2) vpermilps(imm(0xe),xmm0,xmm5) vpermilps(imm(0xe),xmm2,xmm6) - vfmadd231ps(mem(rdx), xmm3, xmm0) - vfmadd231ps(mem(rdx, rsi, 4), xmm3, xmm2) + vmovq(mem(rdx),xmm4) + vmovq(mem(rdx, rsi, 4),xmm1) + vfmadd231ps(xmm4, xmm3, xmm0) + vfmadd231ps(xmm1, xmm3, xmm2) vmovlpd(xmm0, mem(rdx)) // store ( gamma40..gamma50 ) vmovlpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma44..gamma54 ) lea(mem(rdx, rsi, 1), rdx) - vfmadd231ps(mem(rdx), xmm3, xmm5) - vfmadd231ps(mem(rdx, rsi, 4), xmm3, xmm6) + vmovq(mem(rdx),xmm4) + vmovq(mem(rdx, rsi, 4),xmm1) + vfmadd231ps(xmm4, xmm3, xmm5) + vfmadd231ps(xmm1, xmm3, xmm6) vmovlpd(xmm5, mem(rdx)) // store ( gamma41..gamma51 ) vmovlpd(xmm6, mem(rdx, rsi, 4)) // store ( gamma45..gamma55 ) lea(mem(rdx, rsi, 1), rdx) @@ -1306,13 +1310,17 @@ void bli_sgemmsup_rv_zen_asm_6x8m vextractf128(imm(0x1), ymm0, xmm2) vpermilps(imm(0xe),xmm0,xmm5) vpermilps(imm(0xe),xmm2,xmm6) - vfmadd231ps(mem(rdx), xmm3, xmm0) - vfmadd231ps(mem(rdx, rsi, 4), xmm3, xmm2) + vmovq(mem(rdx),xmm4) + vmovq(mem(rdx, rsi, 4),xmm1) + vfmadd231ps(xmm4, xmm3, xmm0) + vfmadd231ps(xmm1, xmm3, xmm2) vmovlpd(xmm0, mem(rdx)) // store ( gamma42..gamma52 ) vmovlpd(xmm2, mem(rdx, rsi, 4)) // store ( gamma46..gamma56 ) lea(mem(rdx, rsi, 1), rdx) - 
vfmadd231ps(mem(rdx), xmm3, xmm5) - vfmadd231ps(mem(rdx, rsi, 4), xmm3, xmm6) + vmovq(mem(rdx),xmm4) + vmovq(mem(rdx, rsi, 4),xmm1) + vfmadd231ps(xmm4, xmm3, xmm5) + vfmadd231ps(xmm1, xmm3, xmm6) vmovlpd(xmm5, mem(rdx)) // store ( gamma43..gamma53 ) vmovlpd(xmm6, mem(rdx, rsi, 4)) // store ( gamma47..gamma57 ) @@ -1810,11 +1818,13 @@ void bli_sgemmsup_rv_zen_asm_6x4m lea(mem(rdx, rsi, 1), rdx) vunpckhps(xmm14, xmm12, xmm0) vpermilps(imm(0x4e), xmm0, xmm5) - vfmadd231ps(mem(rdx), xmm3, xmm0) + vmovq(mem(rdx),xmm4) + vfmadd231ps(xmm4, xmm3, xmm0) vmovlpd(xmm0, mem(rdx)) // store ( gamma42..gamma52 ) lea(mem(rdx, rsi, 1), rdx) - vfmadd231ps(mem(rdx), xmm3, xmm5) + vmovq(mem(rdx),xmm4) + vfmadd231ps(xmm4, xmm3, xmm5) vmovlpd(xmm5, mem(rdx)) // store ( gamma43..gamma53 ) jmp(.SDONE) // jump to end. From 6cff8b030e0021b23bdb2bbd5d05a442edefc105 Mon Sep 17 00:00:00 2001 From: Dipal M Zambare Date: Tue, 30 Aug 2022 10:43:01 +0530 Subject: [PATCH 205/243] Revert "CBLAS/BLAS interface decoupling for level 1 APIs" This reverts commit 95169ca8066c991bc0a3166a5e7da9594c0c393c. Change-Id: Ic441aca616be6f27c7f1ba64e4480edcc6b17632 --- frame/compat/bla_amax.c | 11 +- frame/compat/bla_amax.h | 9 +- frame/compat/bla_amax_amd.c | 32 +-- frame/compat/bla_amin.c | 14 +- frame/compat/bla_amin.h | 8 +- frame/compat/bla_asum.c | 13 +- frame/compat/bla_asum.h | 9 +- frame/compat/bla_axpby.c | 17 +- frame/compat/bla_axpby.h | 14 +- frame/compat/bla_axpy.c | 15 +- frame/compat/bla_axpy.h | 12 +- frame/compat/bla_axpy_amd.c | 66 +---- frame/compat/bla_copy.c | 12 +- frame/compat/bla_copy.h | 10 +- frame/compat/bla_copy_amd.c | 35 +-- frame/compat/bla_dot.c | 48 +--- frame/compat/bla_dot.h | 33 +-- frame/compat/bla_dot_amd.c | 119 +-------- frame/compat/bla_nrm2.c | 13 +- frame/compat/bla_nrm2.h | 9 +- frame/compat/bla_scal.c | 13 +- frame/compat/bla_scal.h | 10 +- frame/compat/bla_scal_amd.c | 35 +-- frame/compat/cblas/f77_sub/f77_amax_sub.c | 17 +- frame/compat/cblas/f77_sub/f77_amax_sub.h | 10 +- frame/compat/cblas/f77_sub/f77_amin_sub.c | 18 +- frame/compat/cblas/f77_sub/f77_amin_sub.h | 9 +- frame/compat/cblas/f77_sub/f77_asum_sub.c | 17 +- frame/compat/cblas/f77_sub/f77_asum_sub.h | 10 +- frame/compat/cblas/f77_sub/f77_dot_sub.c | 60 +---- frame/compat/cblas/f77_sub/f77_dot_sub.h | 29 +-- frame/compat/cblas/f77_sub/f77_nrm2_sub.c | 17 +- frame/compat/cblas/f77_sub/f77_nrm2_sub.h | 10 +- frame/compat/cblas/src/cblas_f77.h | 103 ++++---- frame/compat/f2c/bla_rot.c | 19 +- frame/compat/f2c/bla_rot.h | 5 +- frame/compat/f2c/bla_rotg.c | 19 +- frame/compat/f2c/bla_rotg.h | 5 +- frame/compat/f2c/bla_rotm.c | 19 +- frame/compat/f2c/bla_rotm.h | 5 +- frame/compat/f2c/bla_rotmg.c | 19 +- frame/compat/f2c/bla_rotmg.h | 5 +- frame/util/bli_util_api_wrap.c | 302 +++++++++++----------- 43 files changed, 309 insertions(+), 946 deletions(-) diff --git a/frame/compat/bla_amax.c b/frame/compat/bla_amax.c index bb924d53ba..b1cf77e7b8 100644 --- a/frame/compat/bla_amax.c +++ b/frame/compat/bla_amax.c @@ -41,7 +41,7 @@ #undef GENTFUNC #define GENTFUNC( ftype_x, chx, blasname, blisname ) \ \ -f77_int PASTEF772S(i,chx,blasname) \ +f77_int PASTEF772(i,chx,blasname) \ ( \ const f77_int* n, \ const ftype_x* x, const f77_int* incx \ @@ -95,15 +95,6 @@ f77_int PASTEF772S(i,chx,blasname) \ \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ return f77_index; \ -}\ -\ -f77_int PASTEF772(i,chx,blasname) \ - ( \ - const f77_int* n, \ - const ftype_x* x, const f77_int* incx \ - ) \ -{ \ - return PASTEF772S(i,chx,blasname)( n, x, incx );\ } #ifdef 
BLIS_ENABLE_BLAS diff --git a/frame/compat/bla_amax.h b/frame/compat/bla_amax.h index 093f1f45cf..1f13715dc4 100644 --- a/frame/compat/bla_amax.h +++ b/frame/compat/bla_amax.h @@ -5,8 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. - + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -41,12 +40,6 @@ #define GENTPROT( ftype_x, chx, blasname ) \ \ BLIS_EXPORT_BLAS f77_int PASTEF772(i,chx,blasname) \ - ( \ - const f77_int* n, \ - const ftype_x* x, const f77_int* incx \ - );\ -\ -BLIS_EXPORT_BLAS f77_int PASTEF772S(i,chx,blasname) \ ( \ const f77_int* n, \ const ftype_x* x, const f77_int* incx \ diff --git a/frame/compat/bla_amax_amd.c b/frame/compat/bla_amax_amd.c index 8804045350..2f7c2d2491 100644 --- a/frame/compat/bla_amax_amd.c +++ b/frame/compat/bla_amax_amd.c @@ -41,7 +41,7 @@ #undef GENTFUNC #define GENTFUNC( ftype_x, chx, blasname, blisname ) \ \ -f77_int PASTEF772S(i,chx,blasname) \ +f77_int PASTEF772(i,chx,blasname) \ ( \ const f77_int* n, \ const ftype_x* x, const f77_int* incx \ @@ -95,20 +95,11 @@ f77_int PASTEF772S(i,chx,blasname) \ \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ return f77_index; \ -}\ -\ -f77_int PASTEF772(i,chx,blasname) \ - ( \ - const f77_int* n, \ - const ftype_x* x, const f77_int* incx \ - ) \ -{ \ - return PASTEF772S(i,chx,blasname)( n, x, incx );\ } #ifdef BLIS_ENABLE_BLAS -f77_int isamax_blis_impl +f77_int isamax_ ( const f77_int* n, const float* x, const f77_int* incx @@ -206,16 +197,8 @@ f77_int isamax_blis_impl return f77_index; } -f77_int isamax_ - ( - const f77_int* n, - const float* x, const f77_int* incx - ) -{ - return isamax_blis_impl( n, x, incx ); -} -f77_int idamax_blis_impl +f77_int idamax_ ( const f77_int* n, const double* x, const f77_int* incx @@ -310,14 +293,7 @@ f77_int idamax_blis_impl AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); return f77_index; } -f77_int idamax_ - ( - const f77_int* n, - const double* x, const f77_int* incx - ) -{ - return idamax_blis_impl( n, x, incx ); -} + INSERT_GENTFUNC_BLAS_CZ( amax, amaxv ) #endif diff --git a/frame/compat/bla_amin.c b/frame/compat/bla_amin.c index 7c8be7e51f..7930fc1854 100644 --- a/frame/compat/bla_amin.c +++ b/frame/compat/bla_amin.c @@ -5,8 +5,7 @@ libraries. Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. - + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -42,7 +41,7 @@ #undef GENTFUNC #define GENTFUNC( ftype_x, chx, blasname, blisname ) \ \ -f77_int PASTEF772S(i,chx,blasname) \ +f77_int PASTEF772(i,chx,blasname) \ ( \ const f77_int* n, \ const ftype_x* x, const f77_int* incx \ @@ -89,15 +88,6 @@ f77_int PASTEF772S(i,chx,blasname) \ bli_finalize_auto(); \ \ return f77_index; \ -}\ -\ -f77_int PASTEF772(i,chx,blasname) \ - ( \ - const f77_int* n, \ - const ftype_x* x, const f77_int* incx \ - ) \ -{ \ - return PASTEF772S(i,chx,blasname)( n, x, incx );\ } #ifdef BLIS_ENABLE_BLAS diff --git a/frame/compat/bla_amin.h b/frame/compat/bla_amin.h index 9b3ff7524a..ebbed8262b 100644 --- a/frame/compat/bla_amin.h +++ b/frame/compat/bla_amin.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All rights reserved. 
+ Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -40,12 +40,6 @@ #define GENTPROT( ftype_x, chx, blasname ) \ \ BLIS_EXPORT_BLAS f77_int PASTEF772(i,chx,blasname) \ - ( \ - const f77_int* n, \ - const ftype_x* x, const f77_int* incx \ - );\ -\ -BLIS_EXPORT_BLAS f77_int PASTEF772S(i,chx,blasname) \ ( \ const f77_int* n, \ const ftype_x* x, const f77_int* incx \ diff --git a/frame/compat/bla_asum.c b/frame/compat/bla_asum.c index 024d821efb..c104be96bd 100644 --- a/frame/compat/bla_asum.c +++ b/frame/compat/bla_asum.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -41,7 +41,7 @@ #undef GENTFUNCR2 #define GENTFUNCR2( ftype_x, ftype_r, chx, chr, blasname, blisname ) \ \ -ftype_r PASTEF772S(chr,chx,blasname) \ +ftype_r PASTEF772(chr,chx,blasname) \ ( \ const f77_int* n, \ const ftype_x* x, const f77_int* incx \ @@ -79,15 +79,6 @@ ftype_r PASTEF772S(chr,chx,blasname) \ bli_finalize_auto(); \ \ return asum; \ -}\ -\ -ftype_r PASTEF772(chr,chx,blasname) \ - ( \ - const f77_int* n, \ - const ftype_x* x, const f77_int* incx \ - ) \ -{ \ - return PASTEF772S(chr,chx,blasname)( n, x, incx );\ } #ifdef BLIS_ENABLE_BLAS diff --git a/frame/compat/bla_asum.h b/frame/compat/bla_asum.h index 6460a11178..a9ef27a036 100644 --- a/frame/compat/bla_asum.h +++ b/frame/compat/bla_asum.h @@ -5,8 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. - + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -41,12 +40,6 @@ #define GENTPROTR2( ftype_x, ftype_r, chx, chr, blasname ) \ \ BLIS_EXPORT_BLAS ftype_r PASTEF772(chr,chx,blasname) \ - ( \ - const f77_int* n, \ - const ftype_x* x, const f77_int* incx \ - );\ -\ -BLIS_EXPORT_BLAS ftype_r PASTEF772S(chr,chx,blasname) \ ( \ const f77_int* n, \ const ftype_x* x, const f77_int* incx \ diff --git a/frame/compat/bla_axpby.c b/frame/compat/bla_axpby.c index 90d0563190..be53ec480b 100644 --- a/frame/compat/bla_axpby.c +++ b/frame/compat/bla_axpby.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -41,7 +41,7 @@ #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ -void PASTEF77S(ch,blasname) \ +void PASTEF77(ch,blasname) \ ( \ const f77_int* n, \ const ftype* alpha, \ @@ -85,19 +85,6 @@ void PASTEF77S(ch,blasname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ /* Finalize BLIS. 
*/ \ bli_finalize_auto(); \ -}\ -\ -void PASTEF77(ch,blasname) \ - ( \ - const f77_int* n, \ - const ftype* alpha, \ - const ftype* x, const f77_int* incx, \ - const ftype* beta, \ - ftype* y, const f77_int* incy \ - ) \ -{ \ - PASTEF77S(ch,blasname) \ - ( n, alpha, x, incx, beta, y, incy ); \ } #ifdef BLIS_ENABLE_BLAS diff --git a/frame/compat/bla_axpby.h b/frame/compat/bla_axpby.h index 74ca8908a7..ab2952be98 100644 --- a/frame/compat/bla_axpby.h +++ b/frame/compat/bla_axpby.h @@ -5,8 +5,7 @@ libraries. Copyright (C) 2020, Advanced Micro Devices, Inc. - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. - + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -41,15 +40,6 @@ #define GENTPROT( ftype, ch, blasname ) \ \ BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ - ( \ - const f77_int* n, \ - const ftype* alpha, \ - const ftype* x, const f77_int* incx, \ - const ftype* beta, \ - ftype* y, const f77_int* incy \ - );\ -\ -BLIS_EXPORT_BLAS void PASTEF77S(ch,blasname) \ ( \ const f77_int* n, \ const ftype* alpha, \ @@ -57,7 +47,7 @@ BLIS_EXPORT_BLAS void PASTEF77S(ch,blasname) \ const ftype* beta, \ ftype* y, const f77_int* incy \ ); - + #ifdef BLIS_ENABLE_BLAS INSERT_GENTPROT_BLAS( axpby ) #endif diff --git a/frame/compat/bla_axpy.c b/frame/compat/bla_axpy.c index e5ca995914..1a30f417b3 100644 --- a/frame/compat/bla_axpy.c +++ b/frame/compat/bla_axpy.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 22, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -42,7 +42,7 @@ #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ -void PASTEF77S(ch,blasname) \ +void PASTEF77(ch,blasname) \ ( \ const f77_int* n, \ const ftype* alpha, \ @@ -83,17 +83,6 @@ void PASTEF77S(ch,blasname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ /* Finalize BLIS. */ \ bli_finalize_auto(); \ -}\ -\ -void PASTEF77(ch,blasname) \ - ( \ - const f77_int* n, \ - const ftype* alpha, \ - const ftype* x, const f77_int* incx, \ - ftype* y, const f77_int* incy \ - ) \ -{ \ - PASTEF77S(ch,blasname)( n, alpha, x, incx, y, incy ) ; \ } #ifdef BLIS_ENABLE_BLAS diff --git a/frame/compat/bla_axpy.h b/frame/compat/bla_axpy.h index dcbc5df8c1..294a385c78 100644 --- a/frame/compat/bla_axpy.h +++ b/frame/compat/bla_axpy.h @@ -5,8 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. 
- + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -41,20 +40,13 @@ #define GENTPROT( ftype, ch, blasname ) \ \ BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ - ( \ - const f77_int* n, \ - const ftype* alpha, \ - const ftype* x, const f77_int* incx, \ - ftype* y, const f77_int* incy \ - );\ -\ -BLIS_EXPORT_BLAS void PASTEF77S(ch,blasname) \ ( \ const f77_int* n, \ const ftype* alpha, \ const ftype* x, const f77_int* incx, \ ftype* y, const f77_int* incy \ ); + #ifdef BLIS_ENABLE_BLAS INSERT_GENTPROT_BLAS( axpy ) #endif diff --git a/frame/compat/bla_axpy_amd.c b/frame/compat/bla_axpy_amd.c index 62f0c4df3d..8a9f0280c6 100644 --- a/frame/compat/bla_axpy_amd.c +++ b/frame/compat/bla_axpy_amd.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 22, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -42,7 +42,7 @@ #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ -void PASTEF77S(ch,blasname) \ +void PASTEF77(ch,blasname) \ ( \ const f77_int* n, \ const ftype* alpha, \ @@ -83,23 +83,11 @@ void PASTEF77S(ch,blasname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ /* Finalize BLIS. */ \ bli_finalize_auto(); \ -}\ -\ -void PASTEF77(ch,blasname) \ - ( \ - const f77_int* n, \ - const ftype* alpha, \ - const ftype* x, const f77_int* incx, \ - ftype* y, const f77_int* incy \ - ) \ -{ \ - PASTEF77S(ch,blasname)( n, alpha, x, incx, y, incy ) ; \ } - #ifdef BLIS_ENABLE_BLAS -void saxpy_blis_impl +void saxpy_ ( const f77_int* n, const float* alpha, @@ -190,18 +178,8 @@ void saxpy_blis_impl // bli_finalize_auto(); AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); } -void saxpy_ -( - const f77_int* n, - const float* alpha, - const float* x, const f77_int* incx, - float* y, const f77_int* incy - ) -{ - saxpy_blis_impl( n, alpha, x, incx, y, incy ) ; -} -void daxpy_blis_impl +void daxpy_ ( const f77_int* n, const double* alpha, @@ -293,18 +271,8 @@ void daxpy_blis_impl /* Finalize BLIS. */ // bli_finalize_auto(); } -void daxpy_ -( - const f77_int* n, - const double* alpha, - const double* x, const f77_int* incx, - double* y, const f77_int* incy - ) -{ - daxpy_blis_impl( n, alpha, x, incx, y, incy ) ; -} -void caxpy_blis_impl +void caxpy_ ( const f77_int* n, const scomplex* alpha, @@ -395,17 +363,8 @@ void caxpy_blis_impl /* Finalize BLIS. */ // bli_finalize_auto(); } -void caxpy_ -( - const f77_int* n, - const scomplex* alpha, - const scomplex* x, const f77_int* incx, - scomplex* y, const f77_int* incy - ) -{ - caxpy_blis_impl( n, alpha, x, incx, y, incy ) ; -} -void zaxpy_blis_impl + +void zaxpy_ ( const f77_int* n, const dcomplex* alpha, @@ -497,16 +456,7 @@ void zaxpy_blis_impl /* Finalize BLIS. 
*/ // bli_finalize_auto(); } -void zaxpy_ -( - const f77_int* n, - const dcomplex* alpha, - const dcomplex* x, const f77_int* incx, - dcomplex* y, const f77_int* incy - ) -{ - zaxpy_blis_impl( n, alpha, x, incx, y, incy ) ; -} + #endif diff --git a/frame/compat/bla_copy.c b/frame/compat/bla_copy.c index 4f4e1e874e..74baba689c 100644 --- a/frame/compat/bla_copy.c +++ b/frame/compat/bla_copy.c @@ -42,7 +42,7 @@ #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ -void PASTEF77S(ch,blasname) \ +void PASTEF77(ch,blasname) \ ( \ const f77_int* n, \ const ftype* x, const f77_int* incx, \ @@ -85,16 +85,6 @@ void PASTEF77S(ch,blasname) \ \ /* Finalize BLIS. */ \ bli_finalize_auto(); \ -}\ -\ -void PASTEF77(ch,blasname) \ - ( \ - const f77_int* n, \ - const ftype* x, const f77_int* incx, \ - ftype* y, const f77_int* incy \ - ) \ -{ \ - PASTEF77S(ch,blasname)( n, x, incx, y, incy ); \ } #ifdef BLIS_ENABLE_BLAS diff --git a/frame/compat/bla_copy.h b/frame/compat/bla_copy.h index cfe67967c4..679017b19d 100644 --- a/frame/compat/bla_copy.h +++ b/frame/compat/bla_copy.h @@ -5,8 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. - + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -41,13 +40,6 @@ #define GENTPROT( ftype, ch, blasname ) \ \ BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ - ( \ - const f77_int* n, \ - const ftype* x, const f77_int* incx, \ - ftype* y, const f77_int* incy \ - );\ -\ -BLIS_EXPORT_BLAS void PASTEF77S(ch,blasname) \ ( \ const f77_int* n, \ const ftype* x, const f77_int* incx, \ diff --git a/frame/compat/bla_copy_amd.c b/frame/compat/bla_copy_amd.c index 4ed0c7f548..8dc4d5287c 100644 --- a/frame/compat/bla_copy_amd.c +++ b/frame/compat/bla_copy_amd.c @@ -42,7 +42,7 @@ #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ -void PASTEF77S(ch,blasname) \ +void PASTEF77(ch,blasname) \ ( \ const f77_int* n, \ const ftype* x, const f77_int* incx, \ @@ -85,21 +85,11 @@ void PASTEF77S(ch,blasname) \ \ /* Finalize BLIS. */ \ bli_finalize_auto(); \ -}\ -\ -void PASTEF77(ch,blasname) \ - ( \ - const f77_int* n, \ - const ftype* x, const f77_int* incx, \ - ftype* y, const f77_int* incy \ - ) \ -{ \ - PASTEF77S(ch,blasname)( n, x, incx, y, incy ); \ } #ifdef BLIS_ENABLE_BLAS -void scopy_blis_impl +void scopy_ ( const f77_int* n, const float* x, const f77_int* incx, @@ -193,17 +183,8 @@ void scopy_blis_impl /* Finalize BLIS. */ // bli_finalize_auto(); } -void scopy_ -( - const f77_int* n, - const float* x, const f77_int* incx, - float* y, const f77_int* incy -) -{ - scopy_blis_impl( n, x, incx, y, incy ); -} -void dcopy_blis_impl +void dcopy_ ( const f77_int* n, const double* x, const f77_int* incx, @@ -298,15 +279,7 @@ void dcopy_blis_impl /* Finalize BLIS. 
*/ // bli_finalize_auto(); } -void dcopy_ -( - const f77_int* n, - const double* x, const f77_int* incx, - double* y, const f77_int* incy -) -{ - dcopy_blis_impl( n, x, incx, y, incy ); -} + INSERT_GENTFUNC_BLAS_CZ(copy, copyv) #endif diff --git a/frame/compat/bla_dot.c b/frame/compat/bla_dot.c index 79c65c4d8d..3c4d8c538f 100644 --- a/frame/compat/bla_dot.c +++ b/frame/compat/bla_dot.c @@ -42,7 +42,7 @@ #undef GENTFUNCDOT #define GENTFUNCDOT( ftype, ch, chc, blis_conjx, blasname, blisname ) \ \ -ftype PASTEF772S(ch,blasname,chc) \ +ftype PASTEF772(ch,blasname,chc) \ ( \ const f77_int* n, \ const ftype* x, const f77_int* incx, \ @@ -87,16 +87,6 @@ ftype PASTEF772S(ch,blasname,chc) \ bli_finalize_auto(); \ \ return rho; \ -}\ -\ -ftype PASTEF772(ch,blasname,chc) \ - ( \ - const f77_int* n, \ - const ftype* x, const f77_int* incx, \ - const ftype* y, const f77_int* incy \ - ) \ -{ \ - return PASTEF772S(ch,blasname,chc)( n, x, incx, y, incy );\ } #ifdef BLIS_ENABLE_BLAS @@ -110,7 +100,7 @@ INSERT_GENTFUNCDOTC_BLAS( dot, dotv ) #undef GENTFUNCDOT #define GENTFUNCDOT( ftype, ch, chc, blis_conjx, blasname, blisname ) \ \ -void PASTEF772S(ch,blasname,chc) \ +void PASTEF772(ch,blasname,chc) \ ( \ ftype* rhop, \ const f77_int* n, \ @@ -156,17 +146,6 @@ void PASTEF772S(ch,blasname,chc) \ bli_finalize_auto(); \ \ *rhop = rho; \ -}\ -\ -void PASTEF772(ch,blasname,chc) \ - ( \ - ftype* rhop, \ - const f77_int* n, \ - const ftype* x, const f77_int* incx, \ - const ftype* y, const f77_int* incy \ - ) \ -{ \ - PASTEF772S(ch,blasname,chc)( rhop, n, x, incx, y, incy );\ } INSERT_GENTFUNCDOTC_BLAS( dot, dotv ) @@ -178,7 +157,7 @@ INSERT_GENTFUNCDOTC_BLAS( dot, dotv ) // Input vectors stored in single precision, computed in double precision, // with result returned in single precision. -float PASTEF77S(sd,sdot) +float PASTEF77(sd,sdot) ( const f77_int* n, const float* sb, @@ -197,20 +176,10 @@ float PASTEF77S(sd,sdot) ) ); } -float PASTEF77(sd,sdot) - ( - const f77_int* n, - const float* sb, - const float* x, const f77_int* incx, - const float* y, const f77_int* incy - ) -{ - return PASTEF77S(sd,sdot)( n, sb, x, incx, y, incy ); -} // Input vectors stored in single precision, computed in double precision, // with result returned in double precision. -double PASTEF77S(d,sdot) +double PASTEF77(d,sdot) ( const f77_int* n, const float* x, const f77_int* incx, @@ -254,14 +223,5 @@ double PASTEF77S(d,sdot) return rho; } -double PASTEF77(d,sdot) - ( - const f77_int* n, - const float* x, const f77_int* incx, - const float* y, const f77_int* incy - ) -{ - return PASTEF77S(d,sdot)( n, x, incx, y, incy ); -} #endif // BLIS_ENABLE_BLAS diff --git a/frame/compat/bla_dot.h b/frame/compat/bla_dot.h index a582503753..16bc3f97cc 100644 --- a/frame/compat/bla_dot.h +++ b/frame/compat/bla_dot.h @@ -5,8 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. 
- + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -41,13 +40,6 @@ #define GENTPROTDOT( ftype, ch, chc, blasname ) \ \ BLIS_EXPORT_BLAS ftype PASTEF772(ch,blasname,chc) \ - ( \ - const f77_int* n, \ - const ftype* x, const f77_int* incx, \ - const ftype* y, const f77_int* incy \ - );\ -\ -BLIS_EXPORT_BLAS ftype PASTEF772S(ch,blasname,chc) \ ( \ const f77_int* n, \ const ftype* x, const f77_int* incx, \ @@ -68,14 +60,6 @@ INSERT_GENTPROTDOTC_BLAS( dot ) #define GENTPROTDOT( ftype, ch, chc, blasname ) \ \ BLIS_EXPORT_BLAS void PASTEF772(ch,blasname,chc) \ - ( \ - ftype* rhop, \ - const f77_int* n, \ - const ftype* x, const f77_int* incx, \ - const ftype* y, const f77_int* incy \ - );\ -\ -BLIS_EXPORT_BLAS void PASTEF772S(ch,blasname,chc) \ ( \ ftype* rhop, \ const f77_int* n, \ @@ -96,25 +80,12 @@ BLIS_EXPORT_BLAS float PASTEF77(sd,sdot) const float* x, const f77_int* incx, const float* y, const f77_int* incy ); -BLIS_EXPORT_BLAS float PASTEF77S(sd,sdot) - ( - const f77_int* n, - const float* sb, - const float* x, const f77_int* incx, - const float* y, const f77_int* incy - ); - + BLIS_EXPORT_BLAS double PASTEF77(d,sdot) ( const f77_int* n, const float* x, const f77_int* incx, const float* y, const f77_int* incy ); -BLIS_EXPORT_BLAS double PASTEF77S(d,sdot) - ( - const f77_int* n, - const float* x, const f77_int* incx, - const float* y, const f77_int* incy - ); #endif diff --git a/frame/compat/bla_dot_amd.c b/frame/compat/bla_dot_amd.c index 0e954f3317..0cdaa6535b 100644 --- a/frame/compat/bla_dot_amd.c +++ b/frame/compat/bla_dot_amd.c @@ -42,7 +42,7 @@ #undef GENTFUNCDOT #define GENTFUNCDOT( ftype, ch, chc, blis_conjx, blasname, blisname ) \ \ -ftype PASTEF772S(ch,blasname,chc) \ +ftype PASTEF772(ch,blasname,chc) \ ( \ const f77_int* n, \ const ftype* x, const f77_int* incx, \ @@ -87,20 +87,10 @@ ftype PASTEF772S(ch,blasname,chc) \ bli_finalize_auto(); \ \ return rho; \ -}\ -\ -ftype PASTEF772(ch,blasname,chc) \ - ( \ - const f77_int* n, \ - const ftype* x, const f77_int* incx, \ - const ftype* y, const f77_int* incy \ - ) \ -{ \ - return PASTEF772S(ch,blasname,chc)( n, x, incx, y, incy );\ } #ifdef BLIS_ENABLE_BLAS -float sdot_blis_impl +float sdot_ ( const f77_int* n, const float* x, const f77_int* incx, @@ -201,17 +191,7 @@ float sdot_blis_impl return rho; } -float sdot_ - ( - const f77_int* n, - const float* x, const f77_int* incx, - const float* y, const f77_int* incy - ) -{ - return sdot_blis_impl( n, x, incx, y, incy ); -} - -double ddot_blis_impl +double ddot_ ( const f77_int* n, const double* x, const f77_int* incx, @@ -311,18 +291,9 @@ double ddot_blis_impl AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); return rho; } -double ddot_ - ( - const f77_int* n, - const double* x, const f77_int* incx, - const double* y, const f77_int* incy - ) -{ - return ddot_blis_impl( n, x, incx, y, incy ); -} #ifdef BLIS_DISABLE_COMPLEX_RETURN_INTEL -scomplex cdotu_blis_impl +scomplex cdotu_ ( const f77_int* n, const scomplex* x, const f77_int* incx, @@ -422,17 +393,8 @@ scomplex cdotu_blis_impl AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); return rho; } -scomplex cdotu_ - ( - const f77_int* n, - const scomplex* x, const f77_int* incx, - const scomplex* y, const f77_int* incy - ) -{ - return cdotu_blis_impl( n, x, incx, y, incy ); -} -dcomplex zdotu_blis_impl +dcomplex zdotu_ ( const f77_int* n, const dcomplex* x, const f77_int* incx, @@ -535,17 +497,9 @@ dcomplex zdotu_blis_impl return rho; } -dcomplex 
zdotu_ - ( - const f77_int* n, - const dcomplex* x, const f77_int* incx, - const dcomplex* y, const f77_int* incy - ) -{ - return zdotu_blis_impl( n, x, incx, y, incy ); -} -scomplex cdotc_blis_impl + +scomplex cdotc_ ( const f77_int* n, const scomplex* x, const f77_int* incx, @@ -647,17 +601,8 @@ scomplex cdotc_blis_impl return rho; } -scomplex cdotc_ - ( - const f77_int* n, - const scomplex* x, const f77_int* incx, - const scomplex* y, const f77_int* incy - ) -{ - return cdotc_blis_impl( n, x, incx, y, incy ); -} -dcomplex zdotc_blis_impl +dcomplex zdotc_ ( const f77_int* n, const dcomplex* x, const f77_int* incx, @@ -763,21 +708,13 @@ dcomplex zdotc_blis_impl return rho; } -dcomplex zdotc_ - ( - const f77_int* n, - const dcomplex* x, const f77_int* incx, - const dcomplex* y, const f77_int* incy - ) -{ - return zdotc_blis_impl( n, x, incx, y, incy ); -} + #else // BLIS_DISABLE_COMPLEX_RETURN_INTEL // For the "intel" complex return type, use a hidden parameter to return the result #undef GENTFUNCDOT #define GENTFUNCDOT( ftype, ch, chc, blis_conjx, blasname, blisname ) \ \ -void PASTEF772S(ch,blasname,chc) \ +void PASTEF772(ch,blasname,chc) \ ( \ ftype* rhop, \ const f77_int* n, \ @@ -823,17 +760,6 @@ void PASTEF772S(ch,blasname,chc) \ bli_finalize_auto(); \ \ *rhop = rho; \ -}\ -\ -void PASTEF772(ch,blasname,chc) \ - ( \ - ftype* rhop, \ - const f77_int* n, \ - const ftype* x, const f77_int* incx, \ - const ftype* y, const f77_int* incy \ - ) \ -{ \ - PASTEF772S(ch,blasname,chc)( rhop, n, x, incx, y, incy );\ } INSERT_GENTFUNCDOTC_BLAS( dot, dotv ) @@ -845,7 +771,7 @@ INSERT_GENTFUNCDOTC_BLAS( dot, dotv ) // Input vectors stored in single precision, computed in double precision, // with result returned in single precision. -float PASTEF77S(sd,sdot) +float PASTEF77(sd,sdot) ( const f77_int* n, const float* sb, @@ -856,7 +782,7 @@ float PASTEF77S(sd,sdot) return ( float ) ( ( double )(*sb) + - PASTEF77S(d,sdot) + PASTEF77(d,sdot) ( n, x, incx, @@ -864,20 +790,10 @@ float PASTEF77S(sd,sdot) ) ); } -float PASTEF77(sd,sdot) - ( - const f77_int* n, - const float* sb, - const float* x, const f77_int* incx, - const float* y, const f77_int* incy - ) -{ - return PASTEF77S(sd,sdot)( n,sb, x, incx, y, incy ); -} // Input vectors stored in single precision, computed in double precision, // with result returned in double precision. -double PASTEF77S(d,sdot) +double PASTEF77(d,sdot) ( const f77_int* n, const float* x, const f77_int* incx, @@ -921,14 +837,5 @@ double PASTEF77S(d,sdot) return rho; } -double PASTEF77(d,sdot) - ( - const f77_int* n, - const float* x, const f77_int* incx, - const float* y, const f77_int* incy - ) -{ - return PASTEF77S(d,sdot)( n, x, incx, y, incy ); -} #endif diff --git a/frame/compat/bla_nrm2.c b/frame/compat/bla_nrm2.c index a823747ec9..576d9eda8c 100755 --- a/frame/compat/bla_nrm2.c +++ b/frame/compat/bla_nrm2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -42,7 +42,7 @@ #undef GENTFUNCR2 #define GENTFUNCR2( ftype_x, ftype_r, chx, chr, blasname, blisname ) \ \ -ftype_r PASTEF772S(chr,chx,blasname) \ +ftype_r PASTEF772(chr,chx,blasname) \ ( \ const f77_int* n, \ const ftype_x* x, const f77_int* incx \ @@ -80,15 +80,6 @@ ftype_r PASTEF772S(chr,chx,blasname) \ bli_finalize_auto(); \ \ return norm; \ -}\ -\ -ftype_r PASTEF772(chr,chx,blasname) \ - ( \ - const f77_int* n, \ - const ftype_x* x, const f77_int* incx \ - ) \ -{ \ - return PASTEF772S(chr,chx,blasname)( n, x, incx );\ } #ifdef BLIS_ENABLE_BLAS diff --git a/frame/compat/bla_nrm2.h b/frame/compat/bla_nrm2.h index f690b4e071..a8bc25ef48 100644 --- a/frame/compat/bla_nrm2.h +++ b/frame/compat/bla_nrm2.h @@ -5,8 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. - + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -41,12 +40,6 @@ #define GENTPROTR2( ftype_x, ftype_r, chx, chr, blasname ) \ \ BLIS_EXPORT_BLAS ftype_r PASTEF772(chr,chx,blasname) \ - ( \ - const f77_int* n, \ - const ftype_x* x, const f77_int* incx \ - );\ -\ -BLIS_EXPORT_BLAS ftype_r PASTEF772S(chr,chx,blasname) \ ( \ const f77_int* n, \ const ftype_x* x, const f77_int* incx \ diff --git a/frame/compat/bla_scal.c b/frame/compat/bla_scal.c index 44a97144eb..b9651577eb 100644 --- a/frame/compat/bla_scal.c +++ b/frame/compat/bla_scal.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020-22, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -42,7 +42,7 @@ #undef GENTFUNCSCAL #define GENTFUNCSCAL( ftype_x, ftype_a, chx, cha, blasname, blisname ) \ \ -void PASTEF772S(chx,cha,blasname) \ +void PASTEF772(chx,cha,blasname) \ ( \ const f77_int* n, \ const ftype_a* alpha, \ @@ -90,15 +90,6 @@ void PASTEF772S(chx,cha,blasname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ /* Finalize BLIS. */ \ bli_finalize_auto(); \ -}\ -void PASTEF772(chx,cha,blasname) \ - ( \ - const f77_int* n, \ - const ftype_a* alpha, \ - ftype_x* x, const f77_int* incx \ - ) \ -{ \ - PASTEF772S(chx,cha,blasname)( n, alpha, x, incx ); \ } #ifdef BLIS_ENABLE_BLAS diff --git a/frame/compat/bla_scal.h b/frame/compat/bla_scal.h index c3b4540bb0..c8e898b6ba 100644 --- a/frame/compat/bla_scal.h +++ b/frame/compat/bla_scal.h @@ -5,8 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. 
- + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -41,13 +40,6 @@ #define GENTPROTSCAL( ftype_a, ftype_x, cha, chx, blasname ) \ \ BLIS_EXPORT_BLAS void PASTEF772(chx,cha,blasname) \ - ( \ - const f77_int* n, \ - const ftype_a* alpha, \ - ftype_x* x, const f77_int* incx \ - );\ -\ -BLIS_EXPORT_BLAS void PASTEF772S(chx,cha,blasname) \ ( \ const f77_int* n, \ const ftype_a* alpha, \ diff --git a/frame/compat/bla_scal_amd.c b/frame/compat/bla_scal_amd.c index 518058b060..178776a149 100644 --- a/frame/compat/bla_scal_amd.c +++ b/frame/compat/bla_scal_amd.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020-22, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -42,7 +42,7 @@ #undef GENTFUNCSCAL #define GENTFUNCSCAL( ftype_x, ftype_a, chx, cha, blasname, blisname ) \ \ -void PASTEF772S(chx,cha,blasname) \ +void PASTEF772(chx,cha,blasname) \ ( \ const f77_int* n, \ const ftype_a* alpha, \ @@ -90,20 +90,11 @@ void PASTEF772S(chx,cha,blasname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ /* Finalize BLIS. */ \ bli_finalize_auto(); \ -}\ -void PASTEF772(chx,cha,blasname) \ - ( \ - const f77_int* n, \ - const ftype_a* alpha, \ - ftype_x* x, const f77_int* incx \ - ) \ -{ \ - PASTEF772S(chx,cha,blasname)( n, alpha, x, incx ); \ } #ifdef BLIS_ENABLE_BLAS -void sscal_blis_impl +void sscal_ ( const f77_int* n, const float* alpha, @@ -182,17 +173,8 @@ void sscal_blis_impl // bli_finalize_auto(); AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) } -void sscal_ - ( - const f77_int* n, - const float* alpha, - float* x, const f77_int* incx - ) -{ - sscal_blis_impl( n, alpha, x, incx ); -} -void dscal_blis_impl +void dscal_ ( const f77_int* n, const double* alpha, @@ -272,15 +254,6 @@ void dscal_blis_impl // bli_finalize_auto(); AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) } -void dscal_ - ( - const f77_int* n, - const double* alpha, - double* x, const f77_int* incx - ) -{ - dscal_blis_impl( n, alpha, x, incx ); -} INSERT_GENTFUNCSCAL_BLAS_CZ( scal, scalv ) diff --git a/frame/compat/cblas/f77_sub/f77_amax_sub.c b/frame/compat/cblas/f77_sub/f77_amax_sub.c index c394ed4d40..cc26196d79 100644 --- a/frame/compat/cblas/f77_sub/f77_amax_sub.c +++ b/frame/compat/cblas/f77_sub/f77_amax_sub.c @@ -5,8 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. 
- + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -43,28 +42,18 @@ #undef GENTFUNC #define GENTFUNC( ftype_x, chx, blasname, blisname ) \ \ -void PASTEF773S(i,chx,blasname,sub) \ +void PASTEF773(i,chx,blasname,sub) \ ( \ const f77_int* n, \ const ftype_x* x, const f77_int* incx, \ f77_int* rval \ ) \ { \ - *rval = PASTEF772S(i,chx,blasname) \ + *rval = PASTEF772(i,chx,blasname) \ ( \ n, \ x, incx \ ); \ -}\ -\ -void PASTEF773(i,chx,blasname,sub) \ - ( \ - const f77_int* n, \ - const ftype_x* x, const f77_int* incx, \ - f77_int* rval \ - ) \ -{ \ - PASTEF773S(i,chx,blasname,sub) ( n, x, incx, rval );\ } #ifdef BLIS_ENABLE_CBLAS diff --git a/frame/compat/cblas/f77_sub/f77_amax_sub.h b/frame/compat/cblas/f77_sub/f77_amax_sub.h index 35d501ba4a..9cd1202d26 100644 --- a/frame/compat/cblas/f77_sub/f77_amax_sub.h +++ b/frame/compat/cblas/f77_sub/f77_amax_sub.h @@ -5,8 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. - + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -41,13 +40,6 @@ #define GENTPROT( ftype_x, chx, blasname ) \ \ BLIS_EXPORT_BLAS void PASTEF773(i,chx,blasname,sub) \ - ( \ - const f77_int* n, \ - const ftype_x* x, const f77_int* incx, \ - f77_int* rval \ - );\ -\ -BLIS_EXPORT_BLAS void PASTEF773S(i,chx,blasname,sub) \ ( \ const f77_int* n, \ const ftype_x* x, const f77_int* incx, \ diff --git a/frame/compat/cblas/f77_sub/f77_amin_sub.c b/frame/compat/cblas/f77_sub/f77_amin_sub.c index 2eaa231061..73e1951839 100644 --- a/frame/compat/cblas/f77_sub/f77_amin_sub.c +++ b/frame/compat/cblas/f77_sub/f77_amin_sub.c @@ -4,8 +4,8 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All rights reserved. - + Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -42,28 +42,18 @@ #undef GENTFUNC #define GENTFUNC( ftype_x, chx, blasname, blisname ) \ \ -void PASTEF773S(i,chx,blasname,sub) \ +void PASTEF773(i,chx,blasname,sub) \ ( \ const f77_int* n, \ const ftype_x* x, const f77_int* incx, \ f77_int* rval \ ) \ { \ - *rval = PASTEF772S(i,chx,blasname) \ + *rval = PASTEF772(i,chx,blasname) \ ( \ n, \ x, incx \ ); \ -}\ -\ -void PASTEF773(i,chx,blasname,sub) \ - ( \ - const f77_int* n, \ - const ftype_x* x, const f77_int* incx, \ - f77_int* rval \ - ) \ -{ \ - PASTEF773S(i,chx,blasname,sub) ( n, x, incx, rval );\ } #ifdef BLIS_ENABLE_CBLAS diff --git a/frame/compat/cblas/f77_sub/f77_amin_sub.h b/frame/compat/cblas/f77_sub/f77_amin_sub.h index 90b4f25b5f..522dcc7938 100644 --- a/frame/compat/cblas/f77_sub/f77_amin_sub.h +++ b/frame/compat/cblas/f77_sub/f77_amin_sub.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -39,13 +39,6 @@ #define GENTPROT( ftype_x, chx, blasname ) \ \ BLIS_EXPORT_BLAS void PASTEF773(i,chx,blasname,sub) \ - ( \ - const f77_int* n, \ - const ftype_x* x, const f77_int* incx, \ - f77_int* rval \ - );\ -\ -BLIS_EXPORT_BLAS void PASTEF773S(i,chx,blasname,sub) \ ( \ const f77_int* n, \ const ftype_x* x, const f77_int* incx, \ diff --git a/frame/compat/cblas/f77_sub/f77_asum_sub.c b/frame/compat/cblas/f77_sub/f77_asum_sub.c index befac150e0..f1cb35b0cc 100644 --- a/frame/compat/cblas/f77_sub/f77_asum_sub.c +++ b/frame/compat/cblas/f77_sub/f77_asum_sub.c @@ -5,8 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. - + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -43,28 +42,18 @@ #undef GENTFUNCR2 #define GENTFUNCR2( ftype_x, ftype_r, chx, chr, blasname, blisname ) \ \ -void PASTEF773S(chr,chx,blasname,sub) \ +void PASTEF773(chr,chx,blasname,sub) \ ( \ const f77_int* n, \ const ftype_x* x, const f77_int* incx, \ ftype_r* rval \ ) \ { \ - *rval = PASTEF772S(chr,chx,blasname) \ + *rval = PASTEF772(chr,chx,blasname) \ ( \ n, \ x, incx \ ); \ -}\ -\ -void PASTEF773(chr,chx,blasname,sub) \ - ( \ - const f77_int* n, \ - const ftype_x* x, const f77_int* incx, \ - ftype_r* rval \ - ) \ -{ \ - PASTEF773S(chr,chx,blasname,sub) ( n, x, incx, rval ); \ } #ifdef BLIS_ENABLE_CBLAS diff --git a/frame/compat/cblas/f77_sub/f77_asum_sub.h b/frame/compat/cblas/f77_sub/f77_asum_sub.h index de3d99bfc9..4b8634c166 100644 --- a/frame/compat/cblas/f77_sub/f77_asum_sub.h +++ b/frame/compat/cblas/f77_sub/f77_asum_sub.h @@ -5,8 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. - + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -41,13 +40,6 @@ #define GENTPROTR2( ftype_x, ftype_r, chx, chr, blasname ) \ \ BLIS_EXPORT_BLAS void PASTEF773(chr,chx,blasname,sub) \ - ( \ - const f77_int* n, \ - const ftype_x* x, const f77_int* incx, \ - ftype_r* rval \ - );\ -\ -BLIS_EXPORT_BLAS void PASTEF773S(chr,chx,blasname,sub) \ ( \ const f77_int* n, \ const ftype_x* x, const f77_int* incx, \ diff --git a/frame/compat/cblas/f77_sub/f77_dot_sub.c b/frame/compat/cblas/f77_sub/f77_dot_sub.c index f497ab97f0..0ca80464d3 100644 --- a/frame/compat/cblas/f77_sub/f77_dot_sub.c +++ b/frame/compat/cblas/f77_sub/f77_dot_sub.c @@ -5,8 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. 
- + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -44,7 +43,7 @@ #undef GENTFUNCDOT #define GENTFUNCDOT( ftype, ch, chc, blis_conjx, blasname, blisname ) \ \ -void PASTEF773S(ch,blasname,chc,sub) \ +void PASTEF773(ch,blasname,chc,sub) \ ( \ const f77_int* n, \ const ftype* x, const f77_int* incx, \ @@ -52,23 +51,12 @@ void PASTEF773S(ch,blasname,chc,sub) \ ftype* rval \ ) \ { \ - *rval = PASTEF772S(ch,blasname,chc) \ + *rval = PASTEF772(ch,blasname,chc) \ ( \ n, \ x, incx, \ y, incy \ ); \ -}\ -\ -void PASTEF773(ch,blasname,chc,sub) \ - ( \ - const f77_int* n, \ - const ftype* x, const f77_int* incx, \ - const ftype* y, const f77_int* incy, \ - ftype* rval \ - ) \ -{ \ - PASTEF773S(ch,blasname,chc,sub)( n, x, incx, y, incy, rval); \ } INSERT_GENTFUNCDOTR_BLAS( dot, NULL ) @@ -87,7 +75,7 @@ INSERT_GENTFUNCDOTC_BLAS( dot, NULL ) #undef GENTFUNCDOT #define GENTFUNCDOT( ftype, ch, chc, blis_conjx, blasname, blisname ) \ \ -void PASTEF773S(ch,blasname,chc,sub) \ +void PASTEF773(ch,blasname,chc,sub) \ ( \ const f77_int* n, \ const ftype* x, const f77_int* incx, \ @@ -102,17 +90,6 @@ void PASTEF773S(ch,blasname,chc,sub) \ x, incx, \ y, incy \ ); \ -}\ -\ -void PASTEF773(ch,blasname,chc,sub) \ - ( \ - const f77_int* n, \ - const ftype* x, const f77_int* incx, \ - const ftype* y, const f77_int* incy, \ - ftype* rval \ - ) \ -{ \ - PASTEF773S(ch,blasname,chc,sub)( n, x, incx, y, incy, rval); \ } INSERT_GENTFUNCDOTC_BLAS( dot, NULL ) @@ -123,7 +100,7 @@ INSERT_GENTFUNCDOTC_BLAS( dot, NULL ) // Input vectors stored in single precision, computed in double precision, // with result returned in single precision. -void PASTEF772S(sds,dot,sub) +void PASTEF772(sds,dot,sub) ( const f77_int* n, const float* sb, @@ -132,7 +109,7 @@ void PASTEF772S(sds,dot,sub) float* rval ) { - *rval = PASTEF77S(sds,dot) + *rval = PASTEF77(sds,dot) ( n, sb, @@ -140,21 +117,10 @@ void PASTEF772S(sds,dot,sub) y, incy ); } -void PASTEF772(sds,dot,sub) - ( - const f77_int* n, - const float* sb, - const float* x, const f77_int* incx, - const float* y, const f77_int* incy, - float* rval - ) -{ - PASTEF772S(sds,dot,sub)( n, sb, x, incx, y, incy, rval); -} // Input vectors stored in single precision, computed in double precision, // with result returned in double precision. -void PASTEF772S(ds,dot,sub) +void PASTEF772(ds,dot,sub) ( const f77_int* n, const float* x, const f77_int* incx, @@ -162,23 +128,13 @@ void PASTEF772S(ds,dot,sub) double* rval ) { - *rval = PASTEF77S(ds,dot) + *rval = PASTEF77(ds,dot) ( n, x, incx, y, incy ); } -void PASTEF772(ds,dot,sub) - ( - const f77_int* n, - const float* x, const f77_int* incx, - const float* y, const f77_int* incy, - double* rval - ) -{ - PASTEF772S(ds,dot,sub)( n, x, incx, y, incy, rval); -} #endif diff --git a/frame/compat/cblas/f77_sub/f77_dot_sub.h b/frame/compat/cblas/f77_sub/f77_dot_sub.h index 54a40a9a02..8aab2728bf 100644 --- a/frame/compat/cblas/f77_sub/f77_dot_sub.h +++ b/frame/compat/cblas/f77_sub/f77_dot_sub.h @@ -5,8 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. 
- + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -41,14 +40,6 @@ #define GENTPROTDOT( ftype, ch, chc, blasname ) \ \ BLIS_EXPORT_BLAS void PASTEF773(ch,blasname,chc,sub) \ - ( \ - const f77_int* n, \ - const ftype* x, const f77_int* incx, \ - const ftype* y, const f77_int* incy, \ - ftype* rval \ - );\ -\ -BLIS_EXPORT_BLAS void PASTEF773S(ch,blasname,chc,sub) \ ( \ const f77_int* n, \ const ftype* x, const f77_int* incx, \ @@ -70,15 +61,7 @@ BLIS_EXPORT_BLAS void PASTEF772(sds,dot,sub) const float* y, const f77_int* incy, float* rval ); -BLIS_EXPORT_BLAS void PASTEF772S(sds,dot,sub) - ( - const f77_int* n, - const float* sb, - const float* x, const f77_int* incx, - const float* y, const f77_int* incy, - float* rval - ); - + BLIS_EXPORT_BLAS void PASTEF772(ds,dot,sub) ( const f77_int* n, @@ -86,12 +69,4 @@ BLIS_EXPORT_BLAS void PASTEF772(ds,dot,sub) const float* y, const f77_int* incy, double* rval ); -BLIS_EXPORT_BLAS void PASTEF772S(ds,dot,sub) - ( - const f77_int* n, - const float* x, const f77_int* incx, - const float* y, const f77_int* incy, - double* rval - ); - #endif diff --git a/frame/compat/cblas/f77_sub/f77_nrm2_sub.c b/frame/compat/cblas/f77_sub/f77_nrm2_sub.c index 72fa07593a..54ce1a5b49 100644 --- a/frame/compat/cblas/f77_sub/f77_nrm2_sub.c +++ b/frame/compat/cblas/f77_sub/f77_nrm2_sub.c @@ -5,8 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. - + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -43,28 +42,18 @@ #undef GENTFUNCR2 #define GENTFUNCR2( ftype_x, ftype_r, chx, chr, blasname, blisname ) \ \ -void PASTEF773S(chr,chx,blasname,sub) \ +void PASTEF773(chr,chx,blasname,sub) \ ( \ const f77_int* n, \ const ftype_x* x, const f77_int* incx, \ ftype_r* rval \ ) \ { \ - *rval = PASTEF772S(chr,chx,blasname) \ + *rval = PASTEF772(chr,chx,blasname) \ ( \ n, \ x, incx \ ); \ -}\ -\ -void PASTEF773(chr,chx,blasname,sub) \ - ( \ - const f77_int* n, \ - const ftype_x* x, const f77_int* incx, \ - ftype_r* rval \ - ) \ -{ \ - PASTEF773S(chr,chx,blasname,sub)( n, x, incx, rval );\ } #ifdef BLIS_ENABLE_CBLAS diff --git a/frame/compat/cblas/f77_sub/f77_nrm2_sub.h b/frame/compat/cblas/f77_sub/f77_nrm2_sub.h index dbe2809741..c51a94292b 100644 --- a/frame/compat/cblas/f77_sub/f77_nrm2_sub.h +++ b/frame/compat/cblas/f77_sub/f77_nrm2_sub.h @@ -5,8 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. 
- + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -41,13 +40,6 @@ #define GENTPROTR2( ftype_x, ftype_r, chx, chr, blasname ) \ \ BLIS_EXPORT_BLAS void PASTEF773(chr,chx,blasname,sub) \ - ( \ - const f77_int* n, \ - const ftype_x* x, const f77_int* incx, \ - ftype_r* rval \ - );\ -\ -BLIS_EXPORT_BLAS void PASTEF773S(chr,chx,blasname,sub) \ ( \ const f77_int* n, \ const ftype_x* x, const f77_int* incx, \ diff --git a/frame/compat/cblas/src/cblas_f77.h b/frame/compat/cblas/src/cblas_f77.h index 864d78895e..d13d833ab9 100644 --- a/frame/compat/cblas/src/cblas_f77.h +++ b/frame/compat/cblas/src/cblas_f77.h @@ -207,52 +207,53 @@ * Level 1 BLAS */ #define F77_xerbla xerbla_ -#define F77_srotg srotg_blis_impl -#define F77_srotmg srotmg_blis_impl -#define F77_srot srot_blis_impl -#define F77_srotm srotm_blis_impl -#define F77_drotg drotg_blis_impl -#define F77_drotmg drotmg_blis_impl -#define F77_drot drot_blis_impl -#define F77_drotm drotm_blis_impl +#define F77_srotg srotg_ +#define F77_srotmg srotmg_ +#define F77_srot srot_ +#define F77_srotm srotm_ +#define F77_drotg drotg_ +#define F77_drotmg drotmg_ +#define F77_drot drot_ +#define F77_drotm drotm_ #define F77_sswap sswap_ -#define F77_scopy scopy_blis_impl -#define F77_saxpy saxpy_blis_impl -#define F77_isamax_sub isamaxsub_blis_impl +#define F77_scopy scopy_ +#define F77_saxpy saxpy_ +#define F77_isamax_sub isamaxsub_ #define F77_dswap dswap_ -#define F77_dcopy dcopy_blis_impl -#define F77_daxpy daxpy_blis_impl -#define F77_idamax_sub idamaxsub_blis_impl +#define F77_dcopy dcopy_ +#define F77_daxpy daxpy_ +#define F77_idamax_sub idamaxsub_ #define F77_cswap cswap_ -#define F77_ccopy ccopy_blis_impl -#define F77_caxpy caxpy_blis_impl -#define F77_icamax_sub icamaxsub_blis_impl +#define F77_ccopy ccopy_ +#define F77_caxpy caxpy_ +#define F77_icamax_sub icamaxsub_ #define F77_zswap zswap_ -#define F77_zcopy zcopy_blis_impl -#define F77_zaxpy zaxpy_blis_impl -#define F77_izamax_sub izamaxsub_blis_impl -#define F77_sdot_sub sdotsub_blis_impl -#define F77_ddot_sub ddotsub_blis_impl -#define F77_dsdot_sub dsdotsub_blis_impl -#define F77_sscal sscal_blis_impl -#define F77_dscal dscal_blis_impl -#define F77_cscal cscal_blis_impl -#define F77_zscal zscal_blis_impl -#define F77_csscal csscal_blis_impl -#define F77_zdscal zdscal_blis_impl -#define F77_cdotu_sub cdotusub_blis_impl -#define F77_cdotc_sub cdotcsub_blis_impl -#define F77_zdotu_sub zdotusub_blis_impl -#define F77_zdotc_sub zdotcsub_blis_impl -#define F77_snrm2_sub snrm2sub_blis_impl -#define F77_sasum_sub sasumsub_blis_impl -#define F77_dnrm2_sub dnrm2sub_blis_impl -#define F77_dasum_sub dasumsub_blis_impl -#define F77_scnrm2_sub scnrm2sub_blis_impl -#define F77_scasum_sub scasumsub_blis_impl -#define F77_dznrm2_sub dznrm2sub_blis_impl -#define F77_dzasum_sub dzasumsub_blis_impl -#define F77_sdsdot_sub sdsdotsub_blis_impl +#define F77_zcopy zcopy_ +#define F77_zaxpy zaxpy_ +#define F77_zaxpby zaxpby_ +#define F77_izamax_sub izamaxsub_ +#define F77_sdot_sub sdotsub_ +#define F77_ddot_sub ddotsub_ +#define F77_dsdot_sub dsdotsub_ +#define F77_sscal sscal_ +#define F77_dscal dscal_ +#define F77_cscal cscal_ +#define F77_zscal zscal_ +#define F77_csscal csscal_ +#define F77_zdscal zdscal_ +#define F77_cdotu_sub cdotusub_ +#define F77_cdotc_sub cdotcsub_ +#define F77_zdotu_sub zdotusub_ +#define F77_zdotc_sub zdotcsub_ +#define F77_snrm2_sub snrm2sub_ +#define F77_sasum_sub sasumsub_ +#define 
F77_dnrm2_sub dnrm2sub_ +#define F77_dasum_sub dasumsub_ +#define F77_scnrm2_sub scnrm2sub_ +#define F77_scasum_sub scasumsub_ +#define F77_dznrm2_sub dznrm2sub_ +#define F77_dzasum_sub dzasumsub_ +#define F77_sdsdot_sub sdsdotsub_ /* * Level 2 BLAS */ @@ -370,17 +371,17 @@ * -- BLAS Extension APIs -- */ -#define F77_saxpby saxpby_blis_impl -#define F77_daxpby daxpby_blis_impl -#define F77_caxpby caxpby_blis_impl -#define F77_zaxpby zaxpby_blis_impl +#define F77_saxpby saxpby_ +#define F77_daxpby daxpby_ +#define F77_caxpby caxpby_ +#define F77_zaxpby zaxpby_ #define F77_cgemm3m cgemm3m_blis_impl #define F77_zgemm3m zgemm3m_blis_impl -#define F77_isamin_sub isaminsub_blis_impl -#define F77_idamin_sub idaminsub_blis_impl -#define F77_icamin_sub icaminsub_blis_impl -#define F77_izamin_sub izaminsub_blis_impl +#define F77_isamin_sub isaminsub_ +#define F77_idamin_sub idaminsub_ +#define F77_icamin_sub icaminsub_ +#define F77_izamin_sub izaminsub_ // -- Batch APIs -- #define F77_sgemm_batch sgemm_batch_ @@ -389,4 +390,4 @@ #define F77_zgemm_batch zgemm_batch_ #endif -#endif /* CBLAS_F77_H */ +#endif /* CBLAS_F77_H */ \ No newline at end of file diff --git a/frame/compat/f2c/bla_rot.c b/frame/compat/f2c/bla_rot.c index f66aad12c0..c79769bc05 100644 --- a/frame/compat/f2c/bla_rot.c +++ b/frame/compat/f2c/bla_rot.c @@ -5,8 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. - + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -42,8 +41,7 @@ -lf2c -lm (in that order) */ -/* Subroutine */ -int PASTEF77S(s,rot)(const bla_integer *n, bla_real *sx, const bla_integer *incx, bla_real *sy, const bla_integer *incy, const bla_real *c__, const bla_real *s) +/* Subroutine */ int PASTEF77(s,rot)(const bla_integer *n, bla_real *sx, const bla_integer *incx, bla_real *sy, const bla_integer *incy, const bla_real *c__, const bla_real *s) { /* System generated locals */ bla_integer i__1; @@ -111,8 +109,7 @@ int PASTEF77S(s,rot)(const bla_integer *n, bla_real *sx, const bla_integer *incx -lf2c -lm (in that order) */ -/* Subroutine */ -int PASTEF77S(d,rot)(const bla_integer *n, bla_double *dx, const bla_integer *incx, bla_double *dy, const bla_integer *incy, const bla_double *c__, const bla_double *s) +/* Subroutine */ int PASTEF77(d,rot)(const bla_integer *n, bla_double *dx, const bla_integer *incx, bla_double *dy, const bla_integer *incy, const bla_double *c__, const bla_double *s) { /* System generated locals */ bla_integer i__1; @@ -175,16 +172,6 @@ int PASTEF77S(d,rot)(const bla_integer *n, bla_double *dx, const bla_integer *in return 0; } /* drot_ */ -int PASTEF77(s,rot)(const bla_integer *n, bla_real *sx, const bla_integer *incx, bla_real *sy, const bla_integer *incy, const bla_real *c__, const bla_real *s) -{ - return PASTEF77S(s,rot)( n, sx, incx, sy, incy, c__, s ); -} - -int PASTEF77(d,rot)(const bla_integer *n, bla_double *dx, const bla_integer *incx, bla_double *dy, const bla_integer *incy, const bla_double *c__, const bla_double *s) -{ - return PASTEF77S(d,rot)( n, dx, incx, dy, incy, c__, s ); -} - /* csrot.f -- translated by f2c (version 19991025). 
You must link the resulting object file with the libraries: -lf2c -lm (in that order) diff --git a/frame/compat/f2c/bla_rot.h b/frame/compat/f2c/bla_rot.h index c8bd42c254..6093555600 100644 --- a/frame/compat/f2c/bla_rot.h +++ b/frame/compat/f2c/bla_rot.h @@ -5,8 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. - + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -39,7 +38,5 @@ BLIS_EXPORT_BLAS int PASTEF77(s,rot)(const bla_integer *n, bla_real *sx, const b BLIS_EXPORT_BLAS int PASTEF77(d,rot)(const bla_integer *n, bla_double *dx, const bla_integer *incx, bla_double *dy, const bla_integer *incy, const bla_double *c__, const bla_double *s); BLIS_EXPORT_BLAS int PASTEF77(cs,rot)(const bla_integer *n, bla_scomplex *cx, const bla_integer *incx, bla_scomplex *cy, const bla_integer *incy, const bla_real *c__, const bla_real *s); BLIS_EXPORT_BLAS int PASTEF77(zd,rot)(const bla_integer *n, bla_dcomplex *zx, const bla_integer *incx, bla_dcomplex *zy, const bla_integer *incy, const bla_double *c__, const bla_double *s); -BLIS_EXPORT_BLAS int PASTEF77S(s,rot)(const bla_integer *n, bla_real *sx, const bla_integer *incx, bla_real *sy, const bla_integer *incy, const bla_real *c__, const bla_real *s); -BLIS_EXPORT_BLAS int PASTEF77S(d,rot)(const bla_integer *n, bla_double *dx, const bla_integer *incx, bla_double *dy, const bla_integer *incy, const bla_double *c__, const bla_double *s); #endif diff --git a/frame/compat/f2c/bla_rotg.c b/frame/compat/f2c/bla_rotg.c index 613574c955..1572689f57 100644 --- a/frame/compat/f2c/bla_rotg.c +++ b/frame/compat/f2c/bla_rotg.c @@ -5,8 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. - + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -46,8 +45,7 @@ static bla_real sc_b4 = 1.f; -/* Subroutine */ -int PASTEF77S(s,rotg)(bla_real *sa, bla_real *sb, bla_real *c__, bla_real *s) +/* Subroutine */ int PASTEF77(s,rotg)(bla_real *sa, bla_real *sb, bla_real *c__, bla_real *s) { /* System generated locals */ bla_real r__1, r__2; @@ -107,8 +105,7 @@ int PASTEF77S(s,rotg)(bla_real *sa, bla_real *sb, bla_real *c__, bla_real *s) static bla_double dc_b4 = 1.; -/* Subroutine */ -int PASTEF77S(d,rotg)(bla_double *da, bla_double *db, bla_double *c__, bla_double *s) +/* Subroutine */ int PASTEF77(d,rotg)(bla_double *da, bla_double *db, bla_double *c__, bla_double *s) { /* System generated locals */ bla_double d__1, d__2; @@ -159,16 +156,6 @@ int PASTEF77S(d,rotg)(bla_double *da, bla_double *db, bla_double *c__, bla_doubl return 0; } /* drotg_ */ -int PASTEF77(s,rotg)(bla_real *sa, bla_real *sb, bla_real *c__, bla_real *s) -{ - return PASTEF77S(s,rotg)( sa, sb, c__, s ); -} - -int PASTEF77(d,rotg)(bla_double *da, bla_double *db, bla_double *c__, bla_double *s) -{ - return PASTEF77S(d,rotg)( da, db, c__, s ); -} - /* crotg.f -- translated by f2c (version 19991025). You must link the resulting object file with the libraries: -lf2c -lm (in that order) diff --git a/frame/compat/f2c/bla_rotg.h b/frame/compat/f2c/bla_rotg.h index 067f67c22a..b968ebbea2 100644 --- a/frame/compat/f2c/bla_rotg.h +++ b/frame/compat/f2c/bla_rotg.h @@ -5,8 +5,7 @@ libraries. 
Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. - + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -39,7 +38,5 @@ BLIS_EXPORT_BLAS int PASTEF77(s,rotg)(bla_real *sa, bla_real *sb, bla_real *c__, BLIS_EXPORT_BLAS int PASTEF77(d,rotg)(bla_double *da, bla_double *db, bla_double *c__, bla_double *s); BLIS_EXPORT_BLAS int PASTEF77(c,rotg)(bla_scomplex *ca, bla_scomplex *cb, bla_real *c__, bla_scomplex *s); BLIS_EXPORT_BLAS int PASTEF77(z,rotg)(bla_dcomplex *ca, bla_dcomplex *cb, bla_double *c__, bla_dcomplex *s); -BLIS_EXPORT_BLAS int PASTEF77S(s,rotg)(bla_real *sa, bla_real *sb, bla_real *c__, bla_real *s); -BLIS_EXPORT_BLAS int PASTEF77S(d,rotg)(bla_double *da, bla_double *db, bla_double *c__, bla_double *s); #endif diff --git a/frame/compat/f2c/bla_rotm.c b/frame/compat/f2c/bla_rotm.c index b60442e387..003dea7155 100644 --- a/frame/compat/f2c/bla_rotm.c +++ b/frame/compat/f2c/bla_rotm.c @@ -5,8 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. - + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -42,8 +41,7 @@ -lf2c -lm (in that order) */ -/* Subroutine */ -int PASTEF77S(s,rotm)(const bla_integer *n, bla_real *sx, const bla_integer *incx, bla_real *sy, const bla_integer *incy, const bla_real *sparam) +/* Subroutine */ int PASTEF77(s,rotm)(const bla_integer *n, bla_real *sx, const bla_integer *incx, bla_real *sy, const bla_integer *incy, const bla_real *sparam) { /* Initialized data */ @@ -209,8 +207,7 @@ int PASTEF77S(s,rotm)(const bla_integer *n, bla_real *sx, const bla_integer *inc -lf2c -lm (in that order) */ -/* Subroutine */ -int PASTEF77S(d,rotm)(const bla_integer *n, bla_double *dx, const bla_integer *incx, bla_double *dy, const bla_integer *incy, const bla_double *dparam) +/* Subroutine */ int PASTEF77(d,rotm)(const bla_integer *n, bla_double *dx, const bla_integer *incx, bla_double *dy, const bla_integer *incy, const bla_double *dparam) { /* Initialized data */ @@ -371,15 +368,5 @@ int PASTEF77S(d,rotm)(const bla_integer *n, bla_double *dx, const bla_integer *i return 0; } /* drotm_ */ -int PASTEF77(s,rotm)(const bla_integer *n, bla_real *sx, const bla_integer *incx, bla_real *sy, const bla_integer *incy, const bla_real *sparam) -{ - return PASTEF77S(s,rotm)( n, sx, incx, sy, incy, sparam); -} - -int PASTEF77(d,rotm)(const bla_integer *n, bla_double *dx, const bla_integer *incx, bla_double *dy, const bla_integer *incy, const bla_double *dparam) -{ - return PASTEF77S(d,rotm)( n, dx, incx, dy, incy, dparam); -} - #endif diff --git a/frame/compat/f2c/bla_rotm.h b/frame/compat/f2c/bla_rotm.h index 8f610ba5a9..21906358be 100644 --- a/frame/compat/f2c/bla_rotm.h +++ b/frame/compat/f2c/bla_rotm.h @@ -5,8 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. 
- + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -37,7 +36,5 @@ BLIS_EXPORT_BLAS int PASTEF77(s,rotm)(const bla_integer *n, bla_real *sx, const bla_integer *incx, bla_real *sy, const bla_integer *incy, const bla_real *sparam); BLIS_EXPORT_BLAS int PASTEF77(d,rotm)(const bla_integer *n, bla_double *dx, const bla_integer *incx, bla_double *dy, const bla_integer *incy, const bla_double *dparam); -BLIS_EXPORT_BLAS int PASTEF77S(s,rotm)(const bla_integer *n, bla_real *sx, const bla_integer *incx, bla_real *sy, const bla_integer *incy, const bla_real *sparam); -BLIS_EXPORT_BLAS int PASTEF77S(d,rotm)(const bla_integer *n, bla_double *dx, const bla_integer *incx, bla_double *dy, const bla_integer *incy, const bla_double *dparam); #endif diff --git a/frame/compat/f2c/bla_rotmg.c b/frame/compat/f2c/bla_rotmg.c index a285e69a39..11ccc6f333 100644 --- a/frame/compat/f2c/bla_rotmg.c +++ b/frame/compat/f2c/bla_rotmg.c @@ -5,8 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. - + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -42,8 +41,7 @@ -lf2c -lm (in that order) */ -/* Subroutine */ -int PASTEF77S(s,rotmg)(bla_real *sd1, bla_real *sd2, bla_real *sx1, const bla_real *sy1, bla_real *sparam) +/* Subroutine */ int PASTEF77(s,rotmg)(bla_real *sd1, bla_real *sd2, bla_real *sx1, const bla_real *sy1, bla_real *sparam) { /* Initialized data */ @@ -283,8 +281,7 @@ int PASTEF77S(s,rotmg)(bla_real *sd1, bla_real *sd2, bla_real *sx1, const bla_re -lf2c -lm (in that order) */ -/* Subroutine */ -int PASTEF77S(d,rotmg)(bla_double *dd1, bla_double *dd2, bla_double *dx1, const bla_double *dy1, bla_double *dparam) +/* Subroutine */ int PASTEF77(d,rotmg)(bla_double *dd1, bla_double *dd2, bla_double *dx1, const bla_double *dy1, bla_double *dparam) { /* Initialized data */ @@ -519,15 +516,5 @@ int PASTEF77S(d,rotmg)(bla_double *dd1, bla_double *dd2, bla_double *dx1, const return 0; } /* drotmg_ */ -int PASTEF77(s,rotmg)(bla_real *sd1, bla_real *sd2, bla_real *sx1, const bla_real *sy1, bla_real *sparam) -{ - return PASTEF77S(s,rotmg)( sd1, sd2, sx1, sy1, sparam ); -} - -int PASTEF77(d,rotmg)(bla_double *dd1, bla_double *dd2, bla_double *dx1, const bla_double *dy1, bla_double *dparam) -{ - return PASTEF77S(d,rotmg)( dd1, dd2, dx1, dy1, dparam ); -} - #endif diff --git a/frame/compat/f2c/bla_rotmg.h b/frame/compat/f2c/bla_rotmg.h index fe3a10beb4..63e9710da1 100644 --- a/frame/compat/f2c/bla_rotmg.h +++ b/frame/compat/f2c/bla_rotmg.h @@ -5,8 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. 
- + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -37,7 +36,5 @@ BLIS_EXPORT_BLAS int PASTEF77(s,rotmg)(bla_real *sd1, bla_real *sd2, bla_real *sx1, const bla_real *sy1, bla_real *sparam); BLIS_EXPORT_BLAS int PASTEF77(d,rotmg)(bla_double *dd1, bla_double *dd2, bla_double *dx1, const bla_double *dy1, bla_double *dparam); -BLIS_EXPORT_BLAS int PASTEF77S(s,rotmg)(bla_real *sd1, bla_real *sd2, bla_real *sx1, const bla_real *sy1, bla_real *sparam); -BLIS_EXPORT_BLAS int PASTEF77S(d,rotmg)(bla_double *dd1, bla_double *dd2, bla_double *dx1, const bla_double *dy1, bla_double *dparam); #endif diff --git a/frame/util/bli_util_api_wrap.c b/frame/util/bli_util_api_wrap.c index 26810531d3..098a7c33e8 100644 --- a/frame/util/bli_util_api_wrap.c +++ b/frame/util/bli_util_api_wrap.c @@ -43,32 +43,32 @@ #ifndef BLIS_ENABLE_UPPERCASE_API void CAXPY(const f77_int *n,const scomplex *ca,const scomplex *cx,const f77_int *incx,scomplex *cy,const f77_int *incy) { - caxpy_blis_impl( n, ca, cx, incx, cy, incy); + caxpy_( n, ca, cx, incx, cy, incy); } void caxpy(const f77_int *n,const scomplex *ca,const scomplex *cx,const f77_int *incx,scomplex *cy,const f77_int *incy) { - caxpy_blis_impl( n, ca, cx, incx, cy, incy); + caxpy_( n, ca, cx, incx, cy, incy); } void CAXPY_(const f77_int *n,const scomplex *ca,const scomplex *cx,const f77_int *incx,scomplex *cy,const f77_int *incy) { - caxpy_blis_impl( n, ca, cx, incx, cy, incy); + caxpy_( n, ca, cx, incx, cy, incy); } void CCOPY(const f77_int *n,const scomplex *cx,const f77_int *incx,scomplex *cy,const f77_int *incy) { - ccopy_blis_impl( n, cx, incx, cy, incy); + ccopy_( n, cx, incx, cy, incy); } void ccopy(const f77_int *n,const scomplex *cx,const f77_int *incx,scomplex *cy,const f77_int *incy) { - ccopy_blis_impl( n, cx, incx, cy, incy); + ccopy_( n, cx, incx, cy, incy); } void CCOPY_(const f77_int *n,const scomplex *cx,const f77_int *incx,scomplex *cy,const f77_int *incy) { - ccopy_blis_impl( n, cx, incx, cy, incy); + ccopy_( n, cx, incx, cy, incy); } #ifdef BLIS_DISABLE_COMPLEX_RETURN_INTEL @@ -435,17 +435,17 @@ void CROTG_(scomplex *ca, bla_scomplex *cb, bla_real *c,scomplex *s) void CSCAL(const f77_int *n,const scomplex *ca,scomplex *cx,const f77_int *incx) { - cscal_blis_impl( n, ca, cx, incx); + cscal_( n, ca, cx, incx); } void cscal(const f77_int *n,const scomplex *ca,scomplex *cx,const f77_int *incx) { - cscal_blis_impl( n, ca, cx, incx); + cscal_( n, ca, cx, incx); } void CSCAL_(const f77_int *n,const scomplex *ca,scomplex *cx,const f77_int *incx) { - cscal_blis_impl( n, ca, cx, incx); + cscal_( n, ca, cx, incx); } void CSROT(const f77_int *n,scomplex *cx,const f77_int *incx,scomplex *cy,const f77_int *incy,const float *c,const float *s) @@ -465,17 +465,17 @@ void CSROT_(const f77_int *n,scomplex *cx,const f77_int *incx,scomplex *cy,con void CSSCAL(const f77_int *n,const float *sa,scomplex *cx,const f77_int *incx) { - csscal_blis_impl( n, sa, cx, incx); + csscal_( n, sa, cx, incx); } void csscal(const f77_int *n,const float *sa,scomplex *cx,const f77_int *incx) { - csscal_blis_impl( n, sa, cx, incx); + csscal_( n, sa, cx, incx); } void CSSCAL_(const f77_int *n,const float *sa,scomplex *cx,const f77_int *incx) { - csscal_blis_impl( n, sa, cx, incx); + csscal_( n, sa, cx, incx); } void CSWAP(const f77_int *n,scomplex *cx,const f77_int *incx,scomplex *cy,const f77_int *incy) @@ -675,17 +675,17 @@ double DASUM_(const f77_int *n,const double *dx,const f77_int *incx) 
void DAXPY(const f77_int *n,const double *da,const double *dx,const f77_int *incx,double *dy,const f77_int *incy) { - daxpy_blis_impl( n, da, dx, incx, dy, incy); + daxpy_( n, da, dx, incx, dy, incy); } void daxpy(const f77_int *n,const double *da,const double *dx,const f77_int *incx,double *dy,const f77_int *incy) { - daxpy_blis_impl( n, da, dx, incx, dy, incy); + daxpy_( n, da, dx, incx, dy, incy); } void DAXPY_(const f77_int *n,const double *da,const double *dx,const f77_int *incx,double *dy,const f77_int *incy) { - daxpy_blis_impl( n, da, dx, incx, dy, incy); + daxpy_( n, da, dx, incx, dy, incy); } double DCABS1(bla_dcomplex *z) @@ -705,17 +705,17 @@ double DCABS1_(bla_dcomplex *z) void DCOPY(const f77_int *n,const double *dx,const f77_int *incx,double *dy,const f77_int *incy) { - dcopy_blis_impl( n, dx, incx, dy, incy); + dcopy_( n, dx, incx, dy, incy); } void dcopy(const f77_int *n,const double *dx,const f77_int *incx,double *dy,const f77_int *incy) { - dcopy_blis_impl( n, dx, incx, dy, incy); + dcopy_( n, dx, incx, dy, incy); } void DCOPY_(const f77_int *n,const double *dx,const f77_int *incx,double *dy,const f77_int *incy) { - dcopy_blis_impl( n, dx, incx, dy, incy); + dcopy_( n, dx, incx, dy, incy); } double DDOT(const f77_int *n,const double *dx,const f77_int *incx,const double *dy,const f77_int *incy) @@ -810,62 +810,62 @@ double DNRM2_(const f77_int *n,const double *x,const f77_int *incx) void DROT(const f77_int *n,double *dx,const f77_int *incx,double *dy,const f77_int *incy,const double *c,const double *s) { - drot_blis_impl( n, dx, incx, dy, incy, c, s); + drot_( n, dx, incx, dy, incy, c, s); } void drot(const f77_int *n,double *dx,const f77_int *incx,double *dy,const f77_int *incy,const double *c,const double *s) { - drot_blis_impl( n, dx, incx, dy, incy, c, s); + drot_( n, dx, incx, dy, incy, c, s); } void DROT_(const f77_int *n,double *dx,const f77_int *incx,double *dy,const f77_int *incy,const double *c,const double *s) { - drot_blis_impl( n, dx, incx, dy, incy, c, s); + drot_( n, dx, incx, dy, incy, c, s); } void DROTG(double *da,double *db,double *c,double *s) { - drotg_blis_impl( da, db, c, s); + drotg_( da, db, c, s); } void drotg(double *da,double *db,double *c,double *s) { - drotg_blis_impl( da, db, c, s); + drotg_( da, db, c, s); } void DROTG_(double *da,double *db,double *c,double *s) { - drotg_blis_impl( da, db, c, s); + drotg_( da, db, c, s); } void DROTM(const f77_int *n,double *dx,const f77_int *incx,double *dy,const f77_int *incy,const double *dparam) { - drotm_blis_impl( n, dx, incx, dy, incy, dparam); + drotm_( n, dx, incx, dy, incy, dparam); } void drotm(const f77_int *n,double *dx,const f77_int *incx,double *dy,const f77_int *incy,const double *dparam) { - drotm_blis_impl( n, dx, incx, dy, incy, dparam); + drotm_( n, dx, incx, dy, incy, dparam); } void DROTM_(const f77_int *n,double *dx,const f77_int *incx,double *dy,const f77_int *incy,const double *dparam) { - drotm_blis_impl( n, dx, incx, dy, incy, dparam); + drotm_( n, dx, incx, dy, incy, dparam); } void DROTMG(double *dd1,double *dd2,double *dx1,const double *dy1,double *dparam) { - drotmg_blis_impl( dd1, dd2, dx1, dy1, dparam); + drotmg_( dd1, dd2, dx1, dy1, dparam); } void drotmg(double *dd1,double *dd2,double *dx1,const double *dy1,double *dparam) { - drotmg_blis_impl( dd1, dd2, dx1, dy1, dparam); + drotmg_( dd1, dd2, dx1, dy1, dparam); } void DROTMG_(double *dd1,double *dd2,double *dx1,const double *dy1,double *dparam) { - drotmg_blis_impl( dd1, dd2, dx1, dy1, dparam); + drotmg_( dd1, dd2, dx1, 
dy1, dparam); } void DSBMV(const char *uplo,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) @@ -885,17 +885,17 @@ void DSBMV_(const char *uplo,const f77_int *n,const f77_int *k,const double *a void DSCAL(const f77_int *n,const double *da,double *dx,const f77_int *incx) { - dscal_blis_impl( n, da, dx, incx); + dscal_( n, da, dx, incx); } void dscal(const f77_int *n,const double *da,double *dx,const f77_int *incx) { - dscal_blis_impl( n, da, dx, incx); + dscal_( n, da, dx, incx); } void DSCAL_(const f77_int *n,const double *da,double *dx,const f77_int *incx) { - dscal_blis_impl( n, da, dx, incx); + dscal_( n, da, dx, incx); } double DSDOT(const f77_int *n,const float *sx,const f77_int *incx,const float *sy,const f77_int *incy) @@ -1305,17 +1305,17 @@ float SASUM_(const f77_int *n,const float *sx, const f77_int *incx) void SAXPY(const f77_int *n,const float *sa,const float *sx,const f77_int *incx,float *sy,const f77_int *incy) { - saxpy_blis_impl( n, sa, sx, incx, sy, incy); + saxpy_( n, sa, sx, incx, sy, incy); } void saxpy(const f77_int *n,const float *sa,const float *sx,const f77_int *incx,float *sy,const f77_int *incy) { - saxpy_blis_impl( n, sa, sx, incx, sy, incy); + saxpy_( n, sa, sx, incx, sy, incy); } void SAXPY_(const f77_int *n,const float *sa,const float *sx,const f77_int *incx,float *sy,const f77_int *incy) { - saxpy_blis_impl( n, sa, sx, incx, sy, incy); + saxpy_( n, sa, sx, incx, sy, incy); } @@ -1354,17 +1354,17 @@ float SCNRM2_(const f77_int *n,const scomplex *x, const f77_int *incx) void SCOPY(const f77_int *n,const float *sx,const f77_int *incx,float *sy,const f77_int *incy) { - scopy_blis_impl( n, sx, incx, sy, incy); + scopy_( n, sx, incx, sy, incy); } void scopy(const f77_int *n,const float *sx,const f77_int *incx,float *sy,const f77_int *incy) { - scopy_blis_impl( n, sx, incx, sy, incy); + scopy_( n, sx, incx, sy, incy); } void SCOPY_(const f77_int *n,const float *sx,const f77_int *incx,float *sy,const f77_int *incy) { - scopy_blis_impl( n, sx, incx, sy, incy); + scopy_( n, sx, incx, sy, incy); } @@ -1479,62 +1479,62 @@ float SNRM2_(const f77_int *n,const float *x, const f77_int *incx) void SROT(const f77_int *n,float *sx,const f77_int *incx,float *sy,const f77_int *incy,const float *c,const float *s) { - srot_blis_impl( n, sx, incx, sy, incy, c, s); + srot_( n, sx, incx, sy, incy, c, s); } void srot(const f77_int *n,float *sx,const f77_int *incx,float *sy,const f77_int *incy,const float *c,const float *s) { - srot_blis_impl( n, sx, incx, sy, incy, c, s); + srot_( n, sx, incx, sy, incy, c, s); } -void SROT_blis_impl(const f77_int *n,float *sx,const f77_int *incx,float *sy,const f77_int *incy,const float *c,const float *s) +void SROT_(const f77_int *n,float *sx,const f77_int *incx,float *sy,const f77_int *incy,const float *c,const float *s) { - srot_blis_impl( n, sx, incx, sy, incy, c, s); + srot_( n, sx, incx, sy, incy, c, s); } void SROTG(float *sa,float *sb,float *c,float *s) { - srotg_blis_impl( sa, sb, c, s); + srotg_( sa, sb, c, s); } void srotg(float *sa,float *sb,float *c,float *s) { - srotg_blis_impl( sa, sb, c, s); + srotg_( sa, sb, c, s); } void SROTG_(float *sa,float *sb,float *c,float *s) { - srotg_blis_impl( sa, sb, c, s); + srotg_( sa, sb, c, s); } void SROTM(const f77_int *n,float *sx,const f77_int *incx,float *sy,const f77_int *incy,const float *sparam) { - srotm_blis_impl( n, sx, incx, sy, incy, sparam); + srotm_( n, sx, incx, 
sy, incy, sparam); } void srotm(const f77_int *n,float *sx,const f77_int *incx,float *sy,const f77_int *incy,const float *sparam) { - srotm_blis_impl( n, sx, incx, sy, incy, sparam); + srotm_( n, sx, incx, sy, incy, sparam); } void SROTM_(const f77_int *n,float *sx,const f77_int *incx,float *sy,const f77_int *incy,const float *sparam) { - srotm_blis_impl( n, sx, incx, sy, incy, sparam); + srotm_( n, sx, incx, sy, incy, sparam); } void SROTMG(float *sd1,float *sd2,float *sx1,const float *sy1,float *sparam) { - srotmg_blis_impl( sd1, sd2, sx1, sy1, sparam); + srotmg_( sd1, sd2, sx1, sy1, sparam); } void srotmg(float *sd1,float *sd2,float *sx1,const float *sy1,float *sparam) { - srotmg_blis_impl( sd1, sd2, sx1, sy1, sparam); + srotmg_( sd1, sd2, sx1, sy1, sparam); } void SROTMG_(float *sd1,float *sd2,float *sx1,const float *sy1,float *sparam) { - srotmg_blis_impl( sd1, sd2, sx1, sy1, sparam); + srotmg_( sd1, sd2, sx1, sy1, sparam); } void SSBMV(const char *uplo,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) @@ -1554,17 +1554,17 @@ void SSBMV_(const char *uplo,const f77_int *n,const f77_int *k,const float *a void SSCAL(const f77_int *n,const float *sa,float *sx,const f77_int *incx) { - sscal_blis_impl( n, sa, sx, incx); + sscal_( n, sa, sx, incx); } void sscal(const f77_int *n,const float *sa,float *sx,const f77_int *incx) { - sscal_blis_impl( n, sa, sx, incx); + sscal_( n, sa, sx, incx); } void SSCAL_(const f77_int *n,const float *sa,float *sx,const f77_int *incx) { - sscal_blis_impl( n, sa, sx, incx); + sscal_( n, sa, sx, incx); } void SSPMV(const char *uplo,const f77_int *n,const float *alpha,const float *ap,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) @@ -1854,32 +1854,32 @@ int xerbla(const char *srname,const f77_int *info, ftnlen n) void ZAXPY(const f77_int *n,const dcomplex *za,const dcomplex *zx,const f77_int *incx,dcomplex *zy,const f77_int *incy) { - zaxpy_blis_impl( n, za, zx, incx, zy, incy); + zaxpy_( n, za, zx, incx, zy, incy); } void zaxpy(const f77_int *n,const dcomplex *za,const dcomplex *zx,const f77_int *incx,dcomplex *zy,const f77_int *incy) { - zaxpy_blis_impl( n, za, zx, incx, zy, incy); + zaxpy_( n, za, zx, incx, zy, incy); } void ZAXPY_(const f77_int *n,const dcomplex *za,const dcomplex *zx,const f77_int *incx,dcomplex *zy,const f77_int *incy) { - zaxpy_blis_impl( n, za, zx, incx, zy, incy); + zaxpy_( n, za, zx, incx, zy, incy); } void ZCOPY(const f77_int *n,const dcomplex *zx,const f77_int *incx,dcomplex *zy,const f77_int *incy) { - zcopy_blis_impl( n, zx, incx, zy, incy); + zcopy_( n, zx, incx, zy, incy); } void zcopy(const f77_int *n,const dcomplex *zx,const f77_int *incx,dcomplex *zy,const f77_int *incy) { - zcopy_blis_impl( n, zx, incx, zy, incy); + zcopy_( n, zx, incx, zy, incy); } void ZCOPY_(const f77_int *n,const dcomplex *zx,const f77_int *incx,dcomplex *zy,const f77_int *incy) { - zcopy_blis_impl( n, zx, incx, zy, incy); + zcopy_( n, zx, incx, zy, incy); } void ZDROT(const f77_int *n,dcomplex *cx,const f77_int *incx,dcomplex *cy,const f77_int *incy,const double *c,const double *s) @@ -1899,17 +1899,17 @@ void ZDROT_(const f77_int *n,dcomplex *cx,const f77_int *incx,dcomplex *cy,const void ZDSCAL(const f77_int *n,const double *da,dcomplex *zx,const f77_int *incx) { - zdscal_blis_impl( n, da, zx, incx); + zdscal_( n, da, zx, incx); } void zdscal(const f77_int *n,const double *da,dcomplex *zx,const 
f77_int *incx) { - zdscal_blis_impl( n, da, zx, incx); + zdscal_( n, da, zx, incx); } void ZDSCAL_(const f77_int *n,const double *da,dcomplex *zx,const f77_int *incx) { - zdscal_blis_impl( n, da, zx, incx); + zdscal_( n, da, zx, incx); } void ZGBMV(const char *trans,const f77_int *m,const f77_int *n,const f77_int *kl,const f77_int *ku,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) @@ -2154,17 +2154,17 @@ void ZROTG_(dcomplex *ca,bla_dcomplex *cb,bla_double *c,dcomplex *s) void ZSCAL(const f77_int *n,const dcomplex *za,dcomplex *zx,const f77_int *incx) { - zscal_blis_impl( n, za, zx, incx); + zscal_( n, za, zx, incx); } void zscal(const f77_int *n,const dcomplex *za,dcomplex *zx,const f77_int *incx) { - zscal_blis_impl( n, za, zx, incx); + zscal_( n, za, zx, incx); } void ZSCAL_(const f77_int *n,const dcomplex *za,dcomplex *zx,const f77_int *incx) { - zscal_blis_impl( n, za, zx, incx); + zscal_( n, za, zx, incx); } void ZSWAP(const f77_int *n,dcomplex *zx,const f77_int *incx,dcomplex *zy,const f77_int *incy) @@ -2350,32 +2350,32 @@ void ZTRSV_(const char *uplo,const char *trans,const char *diag,const f77_ void CDOTCSUB( const f77_int* n, const scomplex* x,const f77_int* incx, const scomplex* y, const f77_int* incy, scomplex* rval) { - cdotcsub_blis_impl( n, x, incx, y, incy, rval); + cdotcsub_( n, x, incx, y, incy, rval); } void cdotcsub( const f77_int* n, const scomplex* x,const f77_int* incx, const scomplex* y, const f77_int* incy, scomplex* rval) { - cdotcsub_blis_impl( n, x, incx, y, incy, rval); + cdotcsub_( n, x, incx, y, incy, rval); } void CDOTCSUB_( const f77_int* n, const scomplex* x,const f77_int* incx, const scomplex* y, const f77_int* incy, scomplex* rval) { - cdotcsub_blis_impl( n, x, incx, y, incy, rval); + cdotcsub_( n, x, incx, y, incy, rval); } void CDOTUSUB( const f77_int* n, const scomplex* x,const f77_int* incxy, const scomplex* y, const f77_int* incy, scomplex* rval) { - cdotusub_blis_impl( n, x, incxy, y, incy, rval); + cdotusub_( n, x, incxy, y, incy, rval); } void cdotusub( const f77_int* n, const scomplex* x,const f77_int* incxy, const scomplex* y, const f77_int* incy, scomplex* rval) { - cdotusub_blis_impl( n, x, incxy, y, incy, rval); + cdotusub_( n, x, incxy, y, incy, rval); } void CDOTUSUB_( const f77_int* n, const scomplex* x,const f77_int* incxy, const scomplex* y, const f77_int* incy, scomplex* rval) { - cdotusub_blis_impl( n, x, incxy, y, incy, rval); + cdotusub_( n, x, incxy, y, incy, rval); } void CGEMM3M( const f77_char* transa, const f77_char* transb, const f77_int* m, const f77_int* n, const f77_int* k, const scomplex* alpha, const scomplex* a, const f77_int* lda, const scomplex* b, const f77_int* ldb, const scomplex* beta, scomplex* c, const f77_int* ldc) @@ -2485,47 +2485,47 @@ void COMATCOPY_(f77_char* trans, f77_int* rows, f77_int* cols, const scomplex* a void DASUMSUB(const f77_int* n, const double* x, const f77_int* incx, double* rval) { - dasumsub_blis_impl( n, x, incx, rval); + dasumsub_( n, x, incx, rval); } void dasumsub(const f77_int* n, const double* x, const f77_int* incx, double* rval) { - dasumsub_blis_impl( n, x, incx, rval); + dasumsub_( n, x, incx, rval); } void DASUMSUB_(const f77_int* n, const double* x, const f77_int* incx, double* rval) { - dasumsub_blis_impl( n, x, incx, rval); + dasumsub_( n, x, incx, rval); } void DAXPBY(const f77_int* n, const double* alpha, const double *x, const f77_int* incx, const double* beta, double 
*y, const f77_int* incy) { - daxpby_blis_impl( n, alpha, x, incx, beta, y, incy); + daxpby_( n, alpha, x, incx, beta, y, incy); } void daxpby(const f77_int* n, const double* alpha, const double *x, const f77_int* incx, const double* beta, double *y, const f77_int* incy) { - daxpby_blis_impl( n, alpha, x, incx, beta, y, incy); + daxpby_( n, alpha, x, incx, beta, y, incy); } void DAXPBY_(const f77_int* n, const double* alpha, const double *x, const f77_int* incx, const double* beta, double *y, const f77_int* incy) { - daxpby_blis_impl( n, alpha, x, incx, beta, y, incy); + daxpby_( n, alpha, x, incx, beta, y, incy); } void DDOTSUB(const f77_int* n, const double* x, const f77_int* incx, const double* y, const f77_int* incy, double* rval) { - ddotsub_blis_impl( n, x, incx, y, incy, rval); + ddotsub_( n, x, incx, y, incy, rval); } void ddotsub(const f77_int* n, const double* x, const f77_int* incx, const double* y, const f77_int* incy, double* rval) { - ddotsub_blis_impl( n, x, incx, y, incy, rval); + ddotsub_( n, x, incx, y, incy, rval); } void DDOTSUB_(const f77_int* n, const double* x, const f77_int* incx, const double* y, const f77_int* incy, double* rval) { - ddotsub_blis_impl( n, x, incx, y, incy, rval); + ddotsub_( n, x, incx, y, incy, rval); } void DGEMM_BATCH( const f77_char* transa_array, const f77_char* transb_array,const f77_int *m_array, const f77_int *n_array, const f77_int *k_array,const double* alpha_array, const double** a_array, const f77_int *lda_array, const double** b_array, const f77_int *ldb_array, const double* beta_array, double** c_array, const f77_int *ldc_array, const f77_int* group_count, const f77_int *group_size) @@ -2560,17 +2560,17 @@ void DGEMMT_( const f77_char* uploc, const f77_char* transa, const f77_char* tra void DNRM2SUB(const f77_int* n, const double* x, const f77_int* incx, double *rval) { - dnrm2sub_blis_impl( n, x, incx, rval); + dnrm2sub_( n, x, incx, rval); } void dnrm2sub(const f77_int* n, const double* x, const f77_int* incx, double *rval) { - dnrm2sub_blis_impl( n, x, incx, rval); + dnrm2sub_( n, x, incx, rval); } void DNRM2SUB_(const f77_int* n, const double* x, const f77_int* incx, double *rval) { - dnrm2sub_blis_impl( n, x, incx, rval); + dnrm2sub_( n, x, incx, rval); } void DOMATADD(f77_char* transa,f77_char* transb, f77_int* m, f77_int* n, const double* alpha, const double* A, f77_int* lda, const double* beta, const double* B, f77_int* ldb, double* C, f77_int* ldc) @@ -2620,47 +2620,47 @@ void DOMATCOPY_(f77_char* trans, f77_int* rows, f77_int* cols, const double* alp void DZASUMSUB(const f77_int* n, const dcomplex* x, const f77_int* incx, double* rval) { - dzasumsub_blis_impl( n, x, incx, rval); + dzasumsub_( n, x, incx, rval); } void dzasumsub(const f77_int* n, const dcomplex* x, const f77_int* incx, double* rval) { - dzasumsub_blis_impl( n, x, incx, rval); + dzasumsub_( n, x, incx, rval); } void DZASUMSUB_(const f77_int* n, const dcomplex* x, const f77_int* incx, double* rval) { - dzasumsub_blis_impl( n, x, incx, rval); + dzasumsub_( n, x, incx, rval); } void DZNRM2SUB(const f77_int* n, const dcomplex* x, const f77_int* incx, double* rval) { - dznrm2sub_blis_impl( n, x, incx, rval); + dznrm2sub_( n, x, incx, rval); } void dznrm2sub(const f77_int* n, const dcomplex* x, const f77_int* incx, double* rval) { - dznrm2sub_blis_impl( n, x, incx, rval); + dznrm2sub_( n, x, incx, rval); } void DZNRM2SUB_(const f77_int* n, const dcomplex* x, const f77_int* incx, double* rval) { - dznrm2sub_blis_impl( n, x, incx, rval); + dznrm2sub_( n, x, incx, 
rval); } void ICAMAXSUB(const f77_int* n, const scomplex* x, const f77_int* incx, f77_int* rval) { - icamaxsub_blis_impl( n, x, incx, rval); + icamaxsub_( n, x, incx, rval); } void icamaxsub(const f77_int* n, const scomplex* x, const f77_int* incx, f77_int* rval) { - icamaxsub_blis_impl( n, x, incx, rval); + icamaxsub_( n, x, incx, rval); } void ICAMAXSUB_(const f77_int* n, const scomplex* x, const f77_int* incx, f77_int* rval) { - icamaxsub_blis_impl( n, x, incx, rval); + icamaxsub_( n, x, incx, rval); } f77_int ICAMIN( const f77_int* n, const scomplex* x, const f77_int* incx) @@ -2680,32 +2680,32 @@ f77_int ICAMIN_( const f77_int* n, const scomplex* x, const f77_int* incx) void ICAMINSUB( const f77_int* n, const scomplex* x, const f77_int* incx, f77_int* rval) { - icaminsub_blis_impl( n, x, incx, rval); + icaminsub_( n, x, incx, rval); } void icaminsub( const f77_int* n, const scomplex* x, const f77_int* incx, f77_int* rval) { - icaminsub_blis_impl( n, x, incx, rval); + icaminsub_( n, x, incx, rval); } void ICAMINSUB_( const f77_int* n, const scomplex* x, const f77_int* incx, f77_int* rval) { - icaminsub_blis_impl( n, x, incx, rval); + icaminsub_( n, x, incx, rval); } void IDAMAXSUB( const f77_int* n, const double* x, const f77_int* incx, f77_int* rval) { - idamaxsub_blis_impl( n, x, incx, rval); + idamaxsub_( n, x, incx, rval); } void idamaxsub( const f77_int* n, const double* x, const f77_int* incx, f77_int* rval) { - idamaxsub_blis_impl( n, x, incx, rval); + idamaxsub_( n, x, incx, rval); } void IDAMAXSUB_( const f77_int* n, const double* x, const f77_int* incx, f77_int* rval) { - idamaxsub_blis_impl( n, x, incx, rval); + idamaxsub_( n, x, incx, rval); } f77_int IDAMIN( const f77_int* n, const double* x, const f77_int* incx) @@ -2725,32 +2725,32 @@ f77_int IDAMIN_( const f77_int* n, const double* x, const f77_int* incx) void IDAMINSUB(const f77_int* n, const double* x, const f77_int* incx, f77_int* rval) { - idaminsub_blis_impl( n, x, incx, rval); + idaminsub_( n, x, incx, rval); } void idaminsub(const f77_int* n, const double* x, const f77_int* incx, f77_int* rval) { - idaminsub_blis_impl( n, x, incx, rval); + idaminsub_( n, x, incx, rval); } void IDAMINSUB_(const f77_int* n, const double* x, const f77_int* incx, f77_int* rval) { - idaminsub_blis_impl( n, x, incx, rval); + idaminsub_( n, x, incx, rval); } void ISAMAXSUB( const f77_int* n, const float* x, const f77_int* incx, f77_int* rval) { - isamaxsub_blis_impl( n, x, incx, rval); + isamaxsub_( n, x, incx, rval); } void isamaxsub( const f77_int* n, const float* x, const f77_int* incx, f77_int* rval) { - isamaxsub_blis_impl( n, x, incx, rval); + isamaxsub_( n, x, incx, rval); } void ISAMAXSUB_( const f77_int* n, const float* x, const f77_int* incx, f77_int* rval) { - isamaxsub_blis_impl( n, x, incx, rval); + isamaxsub_( n, x, incx, rval); } f77_int ISAMIN( const f77_int* n, const float* x, const f77_int* incx) @@ -2770,32 +2770,32 @@ f77_int ISAMIN_( const f77_int* n, const float* x, const f77_int* incx) void ISAMINSUB( const f77_int* n, const float* x, const f77_int* incx, f77_int* rval) { - isaminsub_blis_impl( n, x, incx, rval); + isaminsub_( n, x, incx, rval); } void isaminsub( const f77_int* n, const float* x, const f77_int* incx, f77_int* rval) { - isaminsub_blis_impl( n, x, incx, rval); + isaminsub_( n, x, incx, rval); } void ISAMINSUB_( const f77_int* n, const float* x, const f77_int* incx, f77_int* rval) { - isaminsub_blis_impl( n, x, incx, rval); + isaminsub_( n, x, incx, rval); } void IZAMAXSUB( const f77_int* n, const 
dcomplex* x, const f77_int* incx, f77_int* rval) { - izamaxsub_blis_impl( n, x, incx, rval); + izamaxsub_( n, x, incx, rval); } void izamaxsub( const f77_int* n, const dcomplex* x, const f77_int* incx, f77_int* rval) { - izamaxsub_blis_impl( n, x, incx, rval); + izamaxsub_( n, x, incx, rval); } void IZAMAXSUB_( const f77_int* n, const dcomplex* x, const f77_int* incx, f77_int* rval) { - izamaxsub_blis_impl( n, x, incx, rval); + izamaxsub_( n, x, incx, rval); } f77_int IZAMIN( const f77_int* n, const dcomplex* x, const f77_int* incx) @@ -2815,92 +2815,92 @@ f77_int IZAMIN_( const f77_int* n, const dcomplex* x, const f77_int* incx) void IZAMINSUB( const f77_int* n, const dcomplex* x, const f77_int* incx, f77_int* rval) { - izaminsub_blis_impl( n, x, incx, rval); + izaminsub_( n, x, incx, rval); } void izaminsub( const f77_int* n, const dcomplex* x, const f77_int* incx, f77_int* rval) { - izaminsub_blis_impl( n, x, incx, rval); + izaminsub_( n, x, incx, rval); } void IZAMINSUB_( const f77_int* n, const dcomplex* x, const f77_int* incx, f77_int* rval) { - izaminsub_blis_impl( n, x, incx, rval); + izaminsub_( n, x, incx, rval); } void SASUMSUB( const f77_int* n, const float* x, const f77_int* incx, float* rval) { - sasumsub_blis_impl( n, x, incx, rval); + sasumsub_( n, x, incx, rval); } void sasumsub( const f77_int* n, const float* x, const f77_int* incx, float* rval) { - sasumsub_blis_impl( n, x, incx, rval); + sasumsub_( n, x, incx, rval); } void SASUMSUB_( const f77_int* n, const float* x, const f77_int* incx, float* rval) { - sasumsub_blis_impl( n, x, incx, rval); + sasumsub_( n, x, incx, rval); } void SAXPBY( const f77_int* n, const float* alpha, const float *x, const f77_int* incx, const float* beta, float *y, const f77_int* incy) { - saxpby_blis_impl( n, alpha, x, incx, beta, y, incy); + saxpby_( n, alpha, x, incx, beta, y, incy); } void saxpby( const f77_int* n, const float* alpha, const float *x, const f77_int* incx, const float* beta, float *y, const f77_int* incy) { - saxpby_blis_impl( n, alpha, x, incx, beta, y, incy); + saxpby_( n, alpha, x, incx, beta, y, incy); } void SAXPBY_( const f77_int* n, const float* alpha, const float *x, const f77_int* incx, const float* beta, float *y, const f77_int* incy) { - saxpby_blis_impl( n, alpha, x, incx, beta, y, incy); + saxpby_( n, alpha, x, incx, beta, y, incy); } void SCASUMSUB( const f77_int* n, const scomplex* x, const f77_int* incx, float* rval) { - scasumsub_blis_impl( n, x, incx, rval); + scasumsub_( n, x, incx, rval); } void scasumsub( const f77_int* n, const scomplex* x, const f77_int* incx, float* rval) { - scasumsub_blis_impl( n, x, incx, rval); + scasumsub_( n, x, incx, rval); } void SCASUMSUB_( const f77_int* n, const scomplex* x, const f77_int* incx, float* rval) { - scasumsub_blis_impl( n, x, incx, rval); + scasumsub_( n, x, incx, rval); } void SCNRM2SUB( const f77_int* n, const scomplex* x, const f77_int* incx, float* rval) { - scnrm2sub_blis_impl( n, x, incx, rval); + scnrm2sub_( n, x, incx, rval); } void scnrm2sub( const f77_int* n, const scomplex* x, const f77_int* incx, float* rval) { - scnrm2sub_blis_impl( n, x, incx, rval); + scnrm2sub_( n, x, incx, rval); } void SCNRM2SUB_( const f77_int* n, const scomplex* x, const f77_int* incx, float* rval) { - scnrm2sub_blis_impl( n, x, incx, rval); + scnrm2sub_( n, x, incx, rval); } void SDOTSUB( const f77_int* n, const float* x, const f77_int* incx, const float* y, const f77_int* incy, float* rval) { - sdotsub_blis_impl( n, x, incx, y, incy, rval); + sdotsub_( n, x, incx, y, incy, 
rval); } void sdotsub( const f77_int* n, const float* x, const f77_int* incx, const float* y, const f77_int* incy, float* rval) { - sdotsub_blis_impl( n, x, incx, y, incy, rval); + sdotsub_( n, x, incx, y, incy, rval); } void SDOTSUB_( const f77_int* n, const float* x, const f77_int* incx, const float* y, const f77_int* incy, float* rval) { - sdotsub_blis_impl( n, x, incx, y, incy, rval); + sdotsub_( n, x, incx, y, incy, rval); } void SGEMM_BATCH(const f77_char* transa_array, const f77_char* transb_array,const f77_int *m_array, const f77_int *n_array, const f77_int *k_array,const float* alpha_array, const float** a_array, const f77_int *lda_array, const float** b_array, const f77_int *ldb_array, const float* beta_array, float** c_array, const f77_int *ldc_array, const f77_int* group_count, const f77_int *group_size) @@ -2950,17 +2950,17 @@ void SIMATCOPY_( f77_char* trans, f77_int* rows, f77_int* cols, const float* alp void SNRM2SUB( const f77_int* n, const float* x, const f77_int* incx, float *rval) { - snrm2sub_blis_impl( n, x, incx, rval); + snrm2sub_( n, x, incx, rval); } void snrm2sub( const f77_int* n, const float* x, const f77_int* incx, float *rval) { - snrm2sub_blis_impl( n, x, incx, rval); + snrm2sub_( n, x, incx, rval); } void SNRM2SUB_( const f77_int* n, const float* x, const f77_int* incx, float *rval) { - snrm2sub_blis_impl( n, x, incx, rval); + snrm2sub_( n, x, incx, rval); } void SOMATADD( f77_char* transa,f77_char* transb, f77_int* m, f77_int* n, const float* alpha, const float* A, f77_int* lda, const float* beta, const float* B, f77_int* ldb, float* C, f77_int* ldc) @@ -3010,47 +3010,47 @@ void SOMATCOPY_( f77_char* trans, f77_int* rows, f77_int* cols, const float* alp void ZAXPBY( const f77_int* n, const dcomplex* alpha, const dcomplex *x, const f77_int* incx, const dcomplex* beta, dcomplex *y, const f77_int* incy) { - zaxpby_blis_impl( n, alpha, x, incx, beta, y, incy); + zaxpby_( n, alpha, x, incx, beta, y, incy); } void zaxpby( const f77_int* n, const dcomplex* alpha, const dcomplex *x, const f77_int* incx, const dcomplex* beta, dcomplex *y, const f77_int* incy) { - zaxpby_blis_impl( n, alpha, x, incx, beta, y, incy); + zaxpby_( n, alpha, x, incx, beta, y, incy); } void ZAXPBY_( const f77_int* n, const dcomplex* alpha, const dcomplex *x, const f77_int* incx, const dcomplex* beta, dcomplex *y, const f77_int* incy) { - zaxpby_blis_impl( n, alpha, x, incx, beta, y, incy); + zaxpby_( n, alpha, x, incx, beta, y, incy); } void ZDOTCSUB( const f77_int* n, const dcomplex* x, const f77_int* incx, const dcomplex* y, const f77_int* incy, dcomplex* rval) { - zdotcsub_blis_impl( n, x, incx, y, incy, rval); + zdotcsub_( n, x, incx, y, incy, rval); } void zdotcsub( const f77_int* n, const dcomplex* x, const f77_int* incx, const dcomplex* y, const f77_int* incy, dcomplex* rval) { - zdotcsub_blis_impl( n, x, incx, y, incy, rval); + zdotcsub_( n, x, incx, y, incy, rval); } void ZDOTCSUB_( const f77_int* n, const dcomplex* x, const f77_int* incx, const dcomplex* y, const f77_int* incy, dcomplex* rval) { - zdotcsub_blis_impl( n, x, incx, y, incy, rval); + zdotcsub_( n, x, incx, y, incy, rval); } void ZDOTUSUB( const f77_int* n, const dcomplex* x, const f77_int* incx,const dcomplex* y, const f77_int* incy, dcomplex* rval) { - zdotusub_blis_impl( n, x, incx, y, incy, rval); + zdotusub_( n, x, incx, y, incy, rval); } void zdotusub( const f77_int* n, const dcomplex* x, const f77_int* incx,const dcomplex* y, const f77_int* incy, dcomplex* rval) { - zdotusub_blis_impl( n, x, incx, y, incy, 
rval);
+    zdotusub_( n, x, incx, y, incy, rval);
 }
 
 void ZDOTUSUB_( const f77_int* n, const dcomplex* x, const f77_int* incx,const dcomplex* y, const f77_int* incy, dcomplex* rval)
 {
-    zdotusub_blis_impl( n, x, incx, y, incy, rval);
+    zdotusub_( n, x, incx, y, incy, rval);
 }
 
 void ZGEMM3M( const f77_char* transa, const f77_char* transb, const f77_int* m, const f77_int* n, const f77_int* k, const dcomplex* alpha, const dcomplex* a, const f77_int* lda, const dcomplex* b, const f77_int* ldb, const dcomplex* beta, dcomplex* c, const f77_int* ldc)
@@ -3178,47 +3178,47 @@ float SCABS1_(bla_scomplex* z)
 
 void SDSDOTSUB( const f77_int* n, float* sb, const float* x, const f77_int* incx, const float* y, const f77_int* incy, float* dot)
 {
-    sdsdotsub_blis_impl( n, sb, x, incx, y, incy, dot);
+    sdsdotsub_( n, sb, x, incx, y, incy, dot);
 }
 
 void sdsdotsub( const f77_int* n, float* sb, const float* x, const f77_int* incx, const float* y, const f77_int* incy, float* dot)
 {
-    sdsdotsub_blis_impl( n, sb, x, incx, y, incy, dot);
+    sdsdotsub_( n, sb, x, incx, y, incy, dot);
 }
 
 void SDSDOTSUB_( const f77_int* n, float* sb, const float* x, const f77_int* incx, const float* y, const f77_int* incy, float* dot)
 {
-    sdsdotsub_blis_impl( n, sb, x, incx, y, incy, dot);
+    sdsdotsub_( n, sb, x, incx, y, incy, dot);
 }
 
 void DSDOTSUB( const f77_int* n, const float* x, const f77_int* incx, const float* y, const f77_int* incy, double* dot)
 {
-    dsdotsub_blis_impl( n, x, incx, y, incy, dot);
+    dsdotsub_( n, x, incx, y, incy, dot);
 }
 
 void dsdotsub( const f77_int* n, const float* x, const f77_int* incx, const float* y, const f77_int* incy, double* dot)
 {
-    dsdotsub_blis_impl( n, x, incx, y, incy, dot);
+    dsdotsub_( n, x, incx, y, incy, dot);
 }
 
 void DSDOTSUB_( const f77_int* n, const float* x, const f77_int* incx, const float* y, const f77_int* incy, double* dot)
 {
-    dsdotsub_blis_impl( n, x, incx, y, incy, dot);
+    dsdotsub_( n, x, incx, y, incy, dot);
 }
 
 void CAXPBY( const f77_int* n, const scomplex* alpha, const scomplex *x, const f77_int* incx, const scomplex* beta, scomplex *y, const f77_int* incy)
 {
-    caxpby_blis_impl(n, alpha, x, incx, beta, y, incy);
+    caxpby_(n, alpha, x, incx, beta, y, incy);
 }
 
 void caxpby( const f77_int* n, const scomplex* alpha, const scomplex *x, const f77_int* incx, const scomplex* beta, scomplex *y, const f77_int* incy)
 {
-    caxpby_blis_impl(n, alpha, x, incx, beta, y, incy);
+    caxpby_(n, alpha, x, incx, beta, y, incy);
 }
 
 void CAXPBY_( const f77_int* n, const scomplex* alpha, const scomplex *x, const f77_int* incx, const scomplex* beta, scomplex *y, const f77_int* incy)
 {
-    caxpby_blis_impl(n, alpha, x, incx, beta, y, incy);
+    caxpby_(n, alpha, x, incx, beta, y, incy);
 }
 #endif

From 7e42b3d2e0f4eead4f988b1d40930150f77f3e17 Mon Sep 17 00:00:00 2001
From: Dipal M Zambare
Date: Tue, 30 Aug 2022 11:53:35 +0530
Subject: [PATCH 206/243] Revert "CBLAS/BLAS interface decoupling for level 2 APIs"

This reverts commit 192f5313a13cdc21f0e07fdc8c3536bbf11c1c0e.
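The "decoupling" rolled back here is the two-layer layout used in frame/compat: the BLAS
computation lives in a *_blis_impl function, and the Fortran-callable symbol (for example
dgemv_) is reduced to a thin forwarder, so that the F77_* macros in cblas_f77.h can bind to
the *_blis_impl entry points. The sketch below is only an illustration of that wrapper shape,
distilled from the dgemv hunks later in this patch (frame/compat/bla_gemv_amd.c); every name
in it comes from those hunks, and it assumes the f77_* typedefs from blis.h together with the
dgemv_blis_impl definition shown there. It is exactly the kind of stub this revert folds back
into a single dgemv_ definition.

/* Decoupled form (pre-revert), shown for double-precision gemv only:
 * dgemv_() is a pass-through and dgemv_blis_impl() carries the actual
 * implementation.  The revert removes this stub and places the body back
 * under the dgemv_ symbol itself.  Assumes "blis.h" for the f77_* types
 * and for the dgemv_blis_impl() definition referenced below.
 */
void dgemv_
     (
       const f77_char* transa,
       const f77_int*  m,
       const f77_int*  n,
       const double*   alpha,
       const double*   a, const f77_int* lda,
       const double*   x, const f77_int* incx,
       const double*   beta,
       double*         y, const f77_int* incy
     )
{
    dgemv_blis_impl( transa, m, n, alpha, a, lda,
                     x, incx, beta, y, incy );
}

With the stub gone, the F77_* macros are pointed back at the plain name_ symbols, as the
level-1 cblas_f77.h hunks earlier in this series already show for entries such as F77_saxpby.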
Change-Id: I876cad90902970ebc61550f109eb0ce32539ea1c --- frame/compat/bla_gemv.c | 25 +- frame/compat/bla_gemv.h | 15 +- frame/compat/bla_gemv_amd.c | 90 +------ frame/compat/bla_ger.c | 18 +- frame/compat/bla_ger.h | 15 +- frame/compat/bla_hemv.c | 19 +- frame/compat/bla_hemv.h | 16 +- frame/compat/bla_her.c | 19 +- frame/compat/bla_her.h | 14 +- frame/compat/bla_her2.c | 18 +- frame/compat/bla_her2.h | 15 +- frame/compat/bla_symv.c | 19 +- frame/compat/bla_symv.h | 14 +- frame/compat/bla_syr.c | 17 +- frame/compat/bla_syr.h | 12 +- frame/compat/bla_syr2.c | 18 +- frame/compat/bla_syr2.h | 13 +- frame/compat/bla_trmv.c | 18 +- frame/compat/bla_trmv.h | 13 +- frame/compat/bla_trsv.c | 18 +- frame/compat/bla_trsv.h | 13 +- frame/compat/cblas/src/cblas_f77.h | 132 +++++----- frame/compat/f2c/bla_gbmv.c | 34 +-- frame/compat/f2c/bla_gbmv.h | 7 +- frame/compat/f2c/bla_hbmv.c | 19 +- frame/compat/f2c/bla_hbmv.h | 5 +- frame/compat/f2c/bla_hpmv.c | 19 +- frame/compat/f2c/bla_hpmv.h | 5 +- frame/compat/f2c/bla_hpr.c | 19 +- frame/compat/f2c/bla_hpr.h | 5 +- frame/compat/f2c/bla_hpr2.c | 19 +- frame/compat/f2c/bla_hpr2.h | 5 +- frame/compat/f2c/bla_sbmv.c | 19 +- frame/compat/f2c/bla_sbmv.h | 5 +- frame/compat/f2c/bla_spmv.c | 19 +- frame/compat/f2c/bla_spmv.h | 5 +- frame/compat/f2c/bla_spr.c | 19 +- frame/compat/f2c/bla_spr.h | 5 +- frame/compat/f2c/bla_spr2.c | 19 +- frame/compat/f2c/bla_spr2.h | 5 +- frame/compat/f2c/bla_tbmv.c | 51 +--- frame/compat/f2c/bla_tbmv.h | 7 +- frame/compat/f2c/bla_tbsv.c | 35 +-- frame/compat/f2c/bla_tbsv.h | 7 +- frame/compat/f2c/bla_tpmv.c | 36 +-- frame/compat/f2c/bla_tpmv.h | 7 +- frame/compat/f2c/bla_tpsv.c | 35 +-- frame/compat/f2c/bla_tpsv.h | 7 +- frame/include/bli_macro_defs.h | 26 +- frame/util/bli_util_api_wrap.c | 396 ++++++++++++++--------------- 50 files changed, 394 insertions(+), 997 deletions(-) diff --git a/frame/compat/bla_gemv.c b/frame/compat/bla_gemv.c index f5a314331a..9dba1b43c4 100644 --- a/frame/compat/bla_gemv.c +++ b/frame/compat/bla_gemv.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 22, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -35,10 +35,14 @@ #include "blis.h" + +// +// Define BLAS-to-BLIS interfaces. +// #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ -void PASTEF77S(ch,blasname) \ +void PASTEF77(ch,blasname) \ ( \ const f77_char* transa, \ const f77_int* m, \ @@ -139,24 +143,9 @@ void PASTEF77S(ch,blasname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ /* Finalize BLIS. */ \ bli_finalize_auto(); \ -}\ -\ -void PASTEF77S(ch,blasname) \ - ( \ - const f77_char* transa, \ - const f77_int* m, \ - const f77_int* n, \ - const ftype* alpha, \ - const ftype* a, const f77_int* lda, \ - const ftype* x, const f77_int* incx, \ - const ftype* beta, \ - ftype* y, const f77_int* incy \ - ) \ -{ \ - PASTEF77(ch,blasname) \ - ( transa, m, n, alpha, a, lda, x, incx, beta, y, incy ); \ } + #ifdef BLIS_ENABLE_BLAS INSERT_GENTFUNC_BLAS( gemv, gemv ) #endif diff --git a/frame/compat/bla_gemv.h b/frame/compat/bla_gemv.h index 54b5471bd4..22c8bf1c07 100644 --- a/frame/compat/bla_gemv.h +++ b/frame/compat/bla_gemv.h @@ -5,8 +5,7 @@ libraries. 
Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. - + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -41,18 +40,6 @@ #define GENTPROT( ftype, ch, blasname ) \ \ BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ - ( \ - const f77_char* transa, \ - const f77_int* m, \ - const f77_int* n, \ - const ftype* alpha, \ - const ftype* a, const f77_int* lda, \ - const ftype* x, const f77_int* incx, \ - const ftype* beta, \ - ftype* y, const f77_int* incy \ - );\ -\ -BLIS_EXPORT_BLAS void PASTEF77S(ch,blasname) \ ( \ const f77_char* transa, \ const f77_int* m, \ diff --git a/frame/compat/bla_gemv_amd.c b/frame/compat/bla_gemv_amd.c index 61834948fc..354f45fe1b 100644 --- a/frame/compat/bla_gemv_amd.c +++ b/frame/compat/bla_gemv_amd.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 22, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -42,7 +42,7 @@ #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ -void PASTEF77S(ch,blasname) \ +void PASTEF77(ch,blasname) \ ( \ const f77_char* transa, \ const f77_int* m, \ @@ -143,26 +143,11 @@ void PASTEF77S(ch,blasname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ /* Finalize BLIS. */ \ bli_finalize_auto(); \ -}\ -void PASTEF77(ch,blasname) \ - ( \ - const f77_char* transa, \ - const f77_int* m, \ - const f77_int* n, \ - const ftype* alpha, \ - const ftype* a, const f77_int* lda, \ - const ftype* x, const f77_int* incx, \ - const ftype* beta, \ - ftype* y, const f77_int* incy \ - ) \ -{ \ - PASTEF77S(ch,blasname) \ - ( transa, m, n, alpha, a, lda, x, incx, beta, y, incy ); \ } #ifdef BLIS_ENABLE_BLAS -void dgemv_blis_impl +void dgemv_ ( const f77_char* transa, const f77_int* m, @@ -346,23 +331,8 @@ void dgemv_blis_impl AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); } -void dgemv_ - ( - const f77_char* transa, - const f77_int* m, - const f77_int* n, - const double* alpha, - const double* a, const f77_int* lda, - const double* x, const f77_int* incx, - const double* beta, - double* y, const f77_int* incy - ) -{ - dgemv_blis_impl( transa, m, n, alpha, a, lda, - x, incx, beta, y, incy ); -} -void sgemv_blis_impl +void sgemv_ ( const f77_char* transa, const f77_int* m, @@ -540,23 +510,9 @@ void sgemv_blis_impl AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); } -void sgemv_ - ( - const f77_char* transa, - const f77_int* m, - const f77_int* n, - const float* alpha, - const float* a, const f77_int* lda, - const float* x, const f77_int* incx, - const float* beta, - float* y, const f77_int* incy - ) -{ - sgemv_blis_impl( transa, m, n, alpha, a, lda, - x, incx, beta, y, incy ); -} -void cgemv_blis_impl + +void cgemv_ ( const f77_char* transa, const f77_int* m, @@ -777,23 +733,9 @@ void cgemv_blis_impl AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); } -void cgemv_ - ( - const f77_char* transa, - const f77_int* m, - const f77_int* n, - const scomplex* alpha, - const scomplex* a, const f77_int* lda, - const scomplex* x, const f77_int* incx, - const scomplex* beta, - scomplex* y, const f77_int* incy - ) -{ - cgemv_blis_impl( transa, m, n, alpha, a, lda, - x, incx, beta, y, incy ); -} -void zgemv_blis_impl + +void zgemv_ 
( const f77_char* transa, const f77_int* m, @@ -1015,21 +957,7 @@ void zgemv_blis_impl AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); } -void zgemv_ - ( - const f77_char* transa, - const f77_int* m, - const f77_int* n, - const dcomplex* alpha, - const dcomplex* a, const f77_int* lda, - const dcomplex* x, const f77_int* incx, - const dcomplex* beta, - dcomplex* y, const f77_int* incy - ) -{ - zgemv_blis_impl( transa, m, n, alpha, a, lda, - x, incx, beta, y, incy ); -} + #endif diff --git a/frame/compat/bla_ger.c b/frame/compat/bla_ger.c index f489bd356e..b7613842ae 100644 --- a/frame/compat/bla_ger.c +++ b/frame/compat/bla_ger.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -42,7 +42,7 @@ #undef GENTFUNCDOT #define GENTFUNCDOT( ftype, ch, chc, blis_conjy, blasname, blisname ) \ \ -void PASTEF772S(ch,blasname,chc) \ +void PASTEF772(ch,blasname,chc) \ ( \ const f77_int* m, \ const f77_int* n, \ @@ -110,20 +110,6 @@ void PASTEF772S(ch,blasname,chc) \ \ /* Finalize BLIS. */ \ bli_finalize_auto(); \ -} \ -\ -void PASTEF772(ch,blasname,chc) \ - ( \ - const f77_int* m, \ - const f77_int* n, \ - const ftype* alpha, \ - const ftype* x, const f77_int* incx, \ - const ftype* y, const f77_int* incy, \ - ftype* a, const f77_int* lda \ - ) \ -{ \ - PASTEF772S(ch,blasname,chc) \ - ( m, n, alpha, x, incx, y, incy, a, lda ); \ } #ifdef BLIS_ENABLE_BLAS diff --git a/frame/compat/bla_ger.h b/frame/compat/bla_ger.h index 769bcffe46..a31548f610 100644 --- a/frame/compat/bla_ger.h +++ b/frame/compat/bla_ger.h @@ -5,8 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. - + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -41,16 +40,6 @@ #define GENTPROTDOT( ftype, chxy, chc, blasname ) \ \ BLIS_EXPORT_BLAS void PASTEF772(chxy,blasname,chc) \ - ( \ - const f77_int* m, \ - const f77_int* n, \ - const ftype* alpha, \ - const ftype* x, const f77_int* incx, \ - const ftype* y, const f77_int* incy, \ - ftype* a, const f77_int* lda \ - );\ -\ -BLIS_EXPORT_BLAS void PASTEF772S(chxy,blasname,chc) \ ( \ const f77_int* m, \ const f77_int* n, \ @@ -59,7 +48,7 @@ BLIS_EXPORT_BLAS void PASTEF772S(chxy,blasname,chc) \ const ftype* y, const f77_int* incy, \ ftype* a, const f77_int* lda \ ); - + #ifdef BLIS_ENABLE_BLAS INSERT_GENTPROTDOT_BLAS( ger ) #endif diff --git a/frame/compat/bla_hemv.c b/frame/compat/bla_hemv.c index f9fff26ba7..a722f3095d 100644 --- a/frame/compat/bla_hemv.c +++ b/frame/compat/bla_hemv.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2021, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -42,7 +42,7 @@ #undef GENTFUNCCO #define GENTFUNCCO( ftype, ftype_r, ch, chr, blasname, blisname ) \ \ -void PASTEF77S(ch,blasname) \ +void PASTEF77(ch,blasname) \ ( \ const f77_char* uploa, \ const f77_int* m, \ @@ -113,21 +113,6 @@ void PASTEF77S(ch,blasname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ /* Finalize BLIS. */ \ bli_finalize_auto(); \ -} \ -\ -void PASTEF77(ch,blasname) \ - ( \ - const f77_char* uploa, \ - const f77_int* m, \ - const ftype* alpha, \ - const ftype* a, const f77_int* lda, \ - const ftype* x, const f77_int* incx, \ - const ftype* beta, \ - ftype* y, const f77_int* incy \ - ) \ -{ \ - PASTEF77S(ch,blasname) \ - ( uploa, m, alpha, a, lda, x, incx, beta, y, incy ); \ } #ifdef BLIS_ENABLE_BLAS diff --git a/frame/compat/bla_hemv.h b/frame/compat/bla_hemv.h index 68887fb0ff..4e82301146 100644 --- a/frame/compat/bla_hemv.h +++ b/frame/compat/bla_hemv.h @@ -5,8 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. - + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -41,17 +40,6 @@ #define GENTPROTCO( ftype, ftype_r, ch, chr, blasname ) \ \ BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ - ( \ - const f77_char* uploa, \ - const f77_int* m, \ - const ftype* alpha, \ - const ftype* a, const f77_int* lda, \ - const ftype* x, const f77_int* incx, \ - const ftype* beta, \ - ftype* y, const f77_int* incy \ - );\ -\ -BLIS_EXPORT_BLAS void PASTEF77S(ch,blasname) \ ( \ const f77_char* uploa, \ const f77_int* m, \ @@ -61,7 +49,7 @@ BLIS_EXPORT_BLAS void PASTEF77S(ch,blasname) \ const ftype* beta, \ ftype* y, const f77_int* incy \ ); - + #ifdef BLIS_ENABLE_BLAS INSERT_GENTPROTCO_BLAS( hemv ) #endif diff --git a/frame/compat/bla_her.c b/frame/compat/bla_her.c index f759cbda15..abe0f1e372 100755 --- a/frame/compat/bla_her.c +++ b/frame/compat/bla_her.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2021, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -42,7 +42,7 @@ #undef GENTFUNCCO #define GENTFUNCCO( ftype, ftype_r, ch, chr, blasname, blisname ) \ \ -void PASTEF77S(ch,blasname) \ +void PASTEF77(ch,blasname) \ ( \ const f77_char* uploa, \ const f77_int* m, \ @@ -103,21 +103,8 @@ void PASTEF77S(ch,blasname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ /* Finalize BLIS. */ \ bli_finalize_auto(); \ -}\ -\ -void PASTEF77(ch,blasname) \ - ( \ - const f77_char* uploa, \ - const f77_int* m, \ - const ftype_r* alpha, \ - const ftype* x, const f77_int* incx, \ - ftype* a, const f77_int* lda \ - ) \ -{\ - PASTEF77S(ch,blasname) \ - ( uploa, m, alpha, x, incx, a, lda ); \ } - + #ifdef BLIS_ENABLE_BLAS INSERT_GENTFUNCCO_BLAS( her, her ) #endif diff --git a/frame/compat/bla_her.h b/frame/compat/bla_her.h index 0708dafba7..b9ae30d903 100644 --- a/frame/compat/bla_her.h +++ b/frame/compat/bla_her.h @@ -5,8 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. 
- + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -41,15 +40,6 @@ #define GENTPROTCO( ftype, ftype_r, ch, chr, blasname ) \ \ BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ - ( \ - const f77_char* uploa, \ - const f77_int* m, \ - const ftype_r* alpha, \ - const ftype* x, const f77_int* incx, \ - ftype* a, const f77_int* lda \ - );\ -\ -BLIS_EXPORT_BLAS void PASTEF77S(ch,blasname) \ ( \ const f77_char* uploa, \ const f77_int* m, \ @@ -57,7 +47,7 @@ BLIS_EXPORT_BLAS void PASTEF77S(ch,blasname) \ const ftype* x, const f77_int* incx, \ ftype* a, const f77_int* lda \ ); - + #ifdef BLIS_ENABLE_BLAS INSERT_GENTPROTCO_BLAS( her ) #endif diff --git a/frame/compat/bla_her2.c b/frame/compat/bla_her2.c index 8671e9e174..ce65be0cb5 100644 --- a/frame/compat/bla_her2.c +++ b/frame/compat/bla_her2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2021, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -42,7 +42,7 @@ #undef GENTFUNCCO #define GENTFUNCCO( ftype, ftype_r, ch, chr, blasname, blisname ) \ \ -void PASTEF77S(ch,blasname) \ +void PASTEF77(ch,blasname) \ ( \ const f77_char* uploa, \ const f77_int* m, \ @@ -111,20 +111,6 @@ void PASTEF77S(ch,blasname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ /* Finalize BLIS. */ \ bli_finalize_auto(); \ -}\ -\ -void PASTEF77(ch,blasname) \ - ( \ - const f77_char* uploa, \ - const f77_int* m, \ - const ftype* alpha, \ - const ftype* x, const f77_int* incx, \ - const ftype* y, const f77_int* incy, \ - ftype* a, const f77_int* lda \ - ) \ -{ \ - PASTEF77S(ch,blasname) \ - ( uploa, m, alpha, x, incx, y, incy, a, lda ); \ } #ifdef BLIS_ENABLE_BLAS diff --git a/frame/compat/bla_her2.h b/frame/compat/bla_her2.h index 2868f83a9c..7cf0bb867c 100644 --- a/frame/compat/bla_her2.h +++ b/frame/compat/bla_her2.h @@ -5,8 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. - + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -41,16 +40,6 @@ #define GENTPROTCO( ftype, ftype_r, ch, chr, blasname ) \ \ BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ - ( \ - const f77_char* uploa, \ - const f77_int* m, \ - const ftype* alpha, \ - const ftype* x, const f77_int* incx, \ - const ftype* y, const f77_int* incy, \ - ftype* a, const f77_int* lda \ - );\ -\ -BLIS_EXPORT_BLAS void PASTEF77S(ch,blasname) \ ( \ const f77_char* uploa, \ const f77_int* m, \ @@ -59,7 +48,7 @@ BLIS_EXPORT_BLAS void PASTEF77S(ch,blasname) \ const ftype* y, const f77_int* incy, \ ftype* a, const f77_int* lda \ ); - + #ifdef BLIS_ENABLE_BLAS INSERT_GENTPROTCO_BLAS( her2 ) #endif diff --git a/frame/compat/bla_symv.c b/frame/compat/bla_symv.c index 1c460ea186..c105be329e 100755 --- a/frame/compat/bla_symv.c +++ b/frame/compat/bla_symv.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2021, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -42,7 +42,7 @@ #undef GENTFUNCRO #define GENTFUNCRO( ftype, ch, blasname, blisname ) \ \ -void PASTEF77S(ch,blasname) \ +void PASTEF77(ch,blasname) \ ( \ const f77_char* uploa, \ const f77_int* m, \ @@ -112,21 +112,6 @@ void PASTEF77S(ch,blasname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ /* Finalize BLIS. */ \ bli_finalize_auto(); \ -}\ -\ -void PASTEF77(ch,blasname) \ - ( \ - const f77_char* uploa, \ - const f77_int* m, \ - const ftype* alpha, \ - const ftype* a, const f77_int* lda, \ - const ftype* x, const f77_int* incx, \ - const ftype* beta, \ - ftype* y, const f77_int* incy \ - ) \ -{ \ - PASTEF77S(ch,blasname) \ - ( uploa, m, alpha, a, lda, x, incx, beta, y, incy ); \ } #ifdef BLIS_ENABLE_BLAS diff --git a/frame/compat/bla_symv.h b/frame/compat/bla_symv.h index efe434884a..9d1662fadf 100644 --- a/frame/compat/bla_symv.h +++ b/frame/compat/bla_symv.h @@ -5,8 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. - + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -41,17 +40,6 @@ #define GENTPROTRO( ftype, ch, blasname ) \ \ BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ - ( \ - const f77_char* uploa, \ - const f77_int* m, \ - const ftype* alpha, \ - const ftype* a, const f77_int* lda, \ - const ftype* x, const f77_int* incx, \ - const ftype* beta, \ - ftype* y, const f77_int* incy \ - );\ -\ -BLIS_EXPORT_BLAS void PASTEF77S(ch,blasname) \ ( \ const f77_char* uploa, \ const f77_int* m, \ diff --git a/frame/compat/bla_syr.c b/frame/compat/bla_syr.c index 97f9610294..55251ea254 100644 --- a/frame/compat/bla_syr.c +++ b/frame/compat/bla_syr.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin. - Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc.All Rights Reserved. + Copyright (C) 2020 - 2021, Advanced Micro Devices, Inc.All Rights Reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -42,7 +42,7 @@ #undef GENTFUNCRO #define GENTFUNCRO( ftype, ch, blasname, blisname ) \ \ -void PASTEF77S(ch,blasname) \ +void PASTEF77(ch,blasname) \ ( \ const f77_char* uploa, \ const f77_int* m, \ @@ -104,19 +104,6 @@ void PASTEF77S(ch,blasname) \ /* Finalize BLIS. */ \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ bli_finalize_auto(); \ -}\ -\ -void PASTEF77(ch,blasname) \ - ( \ - const f77_char* uploa, \ - const f77_int* m, \ - const ftype* alpha, \ - const ftype* x, const f77_int* incx, \ - ftype* a, const f77_int* lda \ - ) \ -{ \ - PASTEF77S(ch,blasname) \ - ( uploa, m, alpha, x, incx, a, lda ); \ } #ifdef BLIS_ENABLE_BLAS diff --git a/frame/compat/bla_syr.h b/frame/compat/bla_syr.h index 21d4324171..0d2a1e0314 100644 --- a/frame/compat/bla_syr.h +++ b/frame/compat/bla_syr.h @@ -5,8 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. 
- + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -41,15 +40,6 @@ #define GENTPROTRO( ftype, ch, blasname ) \ \ BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ - ( \ - const f77_char* uploa, \ - const f77_int* m, \ - const ftype* alpha, \ - const ftype* x, const f77_int* incx, \ - ftype* a, const f77_int* lda \ - );\ -\ -BLIS_EXPORT_BLAS void PASTEF77S(ch,blasname) \ ( \ const f77_char* uploa, \ const f77_int* m, \ diff --git a/frame/compat/bla_syr2.c b/frame/compat/bla_syr2.c index 9208c770b2..047dc64f9c 100644 --- a/frame/compat/bla_syr2.c +++ b/frame/compat/bla_syr2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin. - Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc.All Rights Reserved. + Copyright (C) 2020 - 2021, Advanced Micro Devices, Inc.All Rights Reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -42,7 +42,7 @@ #undef GENTFUNCRO #define GENTFUNCRO( ftype, ch, blasname, blisname ) \ \ -void PASTEF77S(ch,blasname) \ +void PASTEF77(ch,blasname) \ ( \ const f77_char* uploa, \ const f77_int* m, \ @@ -112,20 +112,6 @@ void PASTEF77S(ch,blasname) \ /* Finalize BLIS. */ \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ bli_finalize_auto(); \ -}\ -\ -void PASTEF77(ch,blasname) \ - ( \ - const f77_char* uploa, \ - const f77_int* m, \ - const ftype* alpha, \ - const ftype* x, const f77_int* incx, \ - const ftype* y, const f77_int* incy, \ - ftype* a, const f77_int* lda \ - ) \ -{ \ - PASTEF77S(ch,blasname) \ - ( uploa, m, alpha, x, incx, y, incy, a, lda ); \ } #ifdef BLIS_ENABLE_BLAS diff --git a/frame/compat/bla_syr2.h b/frame/compat/bla_syr2.h index 00af4eefdc..b458767941 100644 --- a/frame/compat/bla_syr2.h +++ b/frame/compat/bla_syr2.h @@ -5,8 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. - + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -41,16 +40,6 @@ #define GENTPROTRO( ftype, ch, blasname ) \ \ BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ - ( \ - const f77_char* uploa, \ - const f77_int* m, \ - const ftype* alpha, \ - const ftype* x, const f77_int* incx, \ - const ftype* y, const f77_int* incy, \ - ftype* a, const f77_int* lda \ - );\ -\ -BLIS_EXPORT_BLAS void PASTEF77S(ch,blasname) \ ( \ const f77_char* uploa, \ const f77_int* m, \ diff --git a/frame/compat/bla_trmv.c b/frame/compat/bla_trmv.c index 067a9d2fa3..9c98ad787a 100644 --- a/frame/compat/bla_trmv.c +++ b/frame/compat/bla_trmv.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin. - Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc.All Rights Reserved. + Copyright (C) 2020 - 2021, Advanced Micro Devices, Inc.All Rights Reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -42,7 +42,7 @@ #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ -void PASTEF77S(ch,blasname) \ +void PASTEF77(ch,blasname) \ ( \ const f77_char* uploa, \ const f77_char* transa, \ @@ -116,20 +116,6 @@ void PASTEF77S(ch,blasname) \ /* Finalize BLIS. 
*/ \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ bli_finalize_auto(); \ -}\ -\ -void PASTEF77(ch,blasname) \ - ( \ - const f77_char* uploa, \ - const f77_char* transa, \ - const f77_char* diaga, \ - const f77_int* m, \ - const ftype* a, const f77_int* lda, \ - ftype* x, const f77_int* incx \ - ) \ -{ \ - PASTEF77S(ch,blasname) \ - ( uploa, transa, diaga, m, a, lda, x, incx );\ } #ifdef BLIS_ENABLE_BLAS diff --git a/frame/compat/bla_trmv.h b/frame/compat/bla_trmv.h index 06dda90bd1..4096ffe793 100644 --- a/frame/compat/bla_trmv.h +++ b/frame/compat/bla_trmv.h @@ -5,8 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. - + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -41,16 +40,6 @@ #define GENTPROT( ftype, ch, blasname ) \ \ BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ - ( \ - const f77_char* uploa, \ - const f77_char* transa, \ - const f77_char* diaga, \ - const f77_int* m, \ - const ftype* a, const f77_int* lda, \ - ftype* x, const f77_int* incx \ - );\ -\ -BLIS_EXPORT_BLAS void PASTEF77S(ch,blasname) \ ( \ const f77_char* uploa, \ const f77_char* transa, \ diff --git a/frame/compat/bla_trsv.c b/frame/compat/bla_trsv.c index 01f8c2d713..8baac6a8ba 100644 --- a/frame/compat/bla_trsv.c +++ b/frame/compat/bla_trsv.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin. - Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc.All Rights Reserved. + Copyright (C) 2020 - 2021, Advanced Micro Devices, Inc.All Rights Reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -42,7 +42,7 @@ #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ -void PASTEF77S(ch,blasname) \ +void PASTEF77(ch,blasname) \ ( \ const f77_char* uploa, \ const f77_char* transa, \ @@ -116,20 +116,6 @@ void PASTEF77S(ch,blasname) \ /* Finalize BLIS. */ \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ bli_finalize_auto(); \ -}\ -\ -void PASTEF77(ch,blasname) \ - ( \ - const f77_char* uploa, \ - const f77_char* transa, \ - const f77_char* diaga, \ - const f77_int* m, \ - const ftype* a, const f77_int* lda, \ - ftype* x, const f77_int* incx \ - ) \ -{ \ - PASTEF77S(ch,blasname) \ - ( uploa, transa, diaga, m, a, lda, x, incx );\ } #ifdef BLIS_ENABLE_BLAS diff --git a/frame/compat/bla_trsv.h b/frame/compat/bla_trsv.h index 3d47272e2c..6edb435f10 100644 --- a/frame/compat/bla_trsv.h +++ b/frame/compat/bla_trsv.h @@ -5,8 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. 
-
+
   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions are
   met:
@@ -41,16 +40,6 @@
 #define GENTPROT( ftype, ch, blasname ) \
 \
 BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \
-  ( \
-    const f77_char* uploa, \
-    const f77_char* transa, \
-    const f77_char* diaga, \
-    const f77_int* m, \
-    const ftype* a, const f77_int* lda, \
-    ftype* x, const f77_int* incx \
-  );\
-\
-BLIS_EXPORT_BLAS void PASTEF77S(ch,blasname) \
   ( \
     const f77_char* uploa, \
     const f77_char* transa, \
diff --git a/frame/compat/cblas/src/cblas_f77.h b/frame/compat/cblas/src/cblas_f77.h
index d13d833ab9..5ec518de9e 100644
--- a/frame/compat/cblas/src/cblas_f77.h
+++ b/frame/compat/cblas/src/cblas_f77.h
@@ -257,72 +257,72 @@
 /*
  * Level 2 BLAS
  */
-#define F77_ssymv ssymv_blis_impl
-#define F77_ssbmv ssbmv_blis_impl
-#define F77_sspmv sspmv_blis_impl
-#define F77_sger sger_blis_impl
-#define F77_ssyr ssyr_blis_impl
-#define F77_sspr sspr_blis_impl
-#define F77_ssyr2 ssyr2_blis_impl
-#define F77_sspr2 sspr2_blis_impl
-#define F77_dsymv dsymv_blis_impl
-#define F77_dsbmv dsbmv_blis_impl
-#define F77_dspmv dspmv_blis_impl
-#define F77_dger dger_blis_impl
-#define F77_dsyr dsyr_blis_impl
-#define F77_dspr dspr_blis_impl
-#define F77_dsyr2 dsyr2_blis_impl
-#define F77_dspr2 dspr2_blis_impl
-#define F77_chemv chemv_blis_impl
-#define F77_chbmv chbmv_blis_impl
-#define F77_chpmv chpmv_blis_impl
-#define F77_cgeru cgeru_blis_impl
-#define F77_cgerc cgerc_blis_impl
-#define F77_cher cher_blis_impl
-#define F77_chpr chpr_blis_impl
-#define F77_cher2 cher2_blis_impl
-#define F77_chpr2 chpr2_blis_impl
-#define F77_zhemv zhemv_blis_impl
-#define F77_zhbmv zhbmv_blis_impl
-#define F77_zhpmv zhpmv_blis_impl
-#define F77_zgeru zgeru_blis_impl
-#define F77_zgerc zgerc_blis_impl
-#define F77_zher zher_blis_impl
-#define F77_zhpr zhpr_blis_impl
-#define F77_zher2 zher2_blis_impl
-#define F77_zhpr2 zhpr2_blis_impl
-#define F77_sgemv sgemv_blis_impl
-#define F77_sgbmv sgbmv_blis_impl
-#define F77_strmv strmv_blis_impl
-#define F77_stbmv stbmv_blis_impl
-#define F77_stpmv stpmv_blis_impl
-#define F77_strsv strsv_blis_impl
-#define F77_stbsv stbsv_blis_impl
-#define F77_stpsv stpsv_blis_impl
-#define F77_dgemv dgemv_blis_impl
-#define F77_dgbmv dgbmv_blis_impl
-#define F77_dtrmv dtrmv_blis_impl
-#define F77_dtbmv dtbmv_blis_impl
-#define F77_dtpmv dtpmv_blis_impl
-#define F77_dtrsv dtrsv_blis_impl
-#define F77_dtbsv dtbsv_blis_impl
-#define F77_dtpsv dtpsv_blis_impl
-#define F77_cgemv cgemv_blis_impl
-#define F77_cgbmv cgbmv_blis_impl
-#define F77_ctrmv ctrmv_blis_impl
-#define F77_ctbmv ctbmv_blis_impl
-#define F77_ctpmv ctpmv_blis_impl
-#define F77_ctrsv ctrsv_blis_impl
-#define F77_ctbsv ctbsv_blis_impl
-#define F77_ctpsv ctpsv_blis_impl
-#define F77_zgemv zgemv_blis_impl
-#define F77_zgbmv zgbmv_blis_impl
-#define F77_ztrmv ztrmv_blis_impl
-#define F77_ztbmv ztbmv_blis_impl
-#define F77_ztpmv ztpmv_blis_impl
-#define F77_ztrsv ztrsv_blis_impl
-#define F77_ztbsv ztbsv_blis_impl
-#define F77_ztpsv ztpsv_blis_impl
+#define F77_ssymv ssymv_
+#define F77_ssbmv ssbmv_
+#define F77_sspmv sspmv_
+#define F77_sger sger_
+#define F77_ssyr ssyr_
+#define F77_sspr sspr_
+#define F77_ssyr2 ssyr2_
+#define F77_sspr2 sspr2_
+#define F77_dsymv dsymv_
+#define F77_dsbmv dsbmv_
+#define F77_dspmv dspmv_
+#define F77_dger dger_
+#define F77_dsyr dsyr_
+#define F77_dspr dspr_
+#define F77_dsyr2 dsyr2_
+#define F77_dspr2 dspr2_
+#define F77_chemv chemv_
+#define F77_chbmv chbmv_
+#define F77_chpmv chpmv_
+#define F77_cgeru cgeru_
+#define F77_cgerc cgerc_
+#define F77_cher cher_
+#define F77_chpr chpr_
+#define F77_cher2 cher2_
+#define F77_chpr2 chpr2_
+#define F77_zhemv zhemv_
+#define F77_zhbmv zhbmv_
+#define F77_zhpmv zhpmv_
+#define F77_zgeru zgeru_
+#define F77_zgerc zgerc_
+#define F77_zher zher_
+#define F77_zhpr zhpr_
+#define F77_zher2 zher2_
+#define F77_zhpr2 zhpr2_
+#define F77_sgemv sgemv_
+#define F77_sgbmv sgbmv_
+#define F77_strmv strmv_
+#define F77_stbmv stbmv_
+#define F77_stpmv stpmv_
+#define F77_strsv strsv_
+#define F77_stbsv stbsv_
+#define F77_stpsv stpsv_
+#define F77_dgemv dgemv_
+#define F77_dgbmv dgbmv_
+#define F77_dtrmv dtrmv_
+#define F77_dtbmv dtbmv_
+#define F77_dtpmv dtpmv_
+#define F77_dtrsv dtrsv_
+#define F77_dtbsv dtbsv_
+#define F77_dtpsv dtpsv_
+#define F77_cgemv cgemv_
+#define F77_cgbmv cgbmv_
+#define F77_ctrmv ctrmv_
+#define F77_ctbmv ctbmv_
+#define F77_ctpmv ctpmv_
+#define F77_ctrsv ctrsv_
+#define F77_ctbsv ctbsv_
+#define F77_ctpsv ctpsv_
+#define F77_zgemv zgemv_
+#define F77_zgbmv zgbmv_
+#define F77_ztrmv ztrmv_
+#define F77_ztbmv ztbmv_
+#define F77_ztpmv ztpmv_
+#define F77_ztrsv ztrsv_
+#define F77_ztbsv ztbsv_
+#define F77_ztpsv ztpsv_
 /*
  * Level 3 BLAS
  */
diff --git a/frame/compat/f2c/bla_gbmv.c b/frame/compat/f2c/bla_gbmv.c
index 06999fdd93..d53dd322ad 100644
--- a/frame/compat/f2c/bla_gbmv.c
+++ b/frame/compat/f2c/bla_gbmv.c
@@ -5,8 +5,7 @@
    libraries.
 
    Copyright (C) 2014, The University of Texas at Austin
-   Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved.
-
+
   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions are
   met:
@@ -42,8 +41,7 @@
    -lf2c -lm (in that order)
 */
 
-/* Subroutine */
-int PASTEF77S(c,gbmv)(const bla_character *trans, const bla_integer *m, const bla_integer *n, const bla_integer *kl, const bla_integer *ku, const bla_scomplex *alpha, const bla_scomplex *a, const bla_integer *lda, const bla_scomplex *x, const bla_integer *incx, const bla_scomplex *beta, bla_scomplex *y, const bla_integer *incy)
+/* Subroutine */ int PASTEF77(c,gbmv)(const bla_character *trans, const bla_integer *m, const bla_integer *n, const bla_integer *kl, const bla_integer *ku, const bla_scomplex *alpha, const bla_scomplex *a, const bla_integer *lda, const bla_scomplex *x, const bla_integer *incx, const bla_scomplex *beta, bla_scomplex *y, const bla_integer *incy)
 {
     /* System generated locals */
     bla_integer a_dim1, a_offset, i__1, i__2, i__3, i__4, i__5, i__6;
@@ -484,8 +482,7 @@ int PASTEF77S(c,gbmv)(const bla_character *trans, const bl
    -lf2c -lm (in that order)
 */
 
-/* Subroutine */
-int PASTEF77S(d,gbmv)(const bla_character *trans, const bla_integer *m, const bla_integer *n, const bla_integer *kl, const bla_integer *ku, const bla_double *alpha, const bla_double *a, const bla_integer *lda, const bla_double *x, const bla_integer *incx, const bla_double *beta, bla_double *y, const bla_integer *incy)
+/* Subroutine */ int PASTEF77(d,gbmv)(const bla_character *trans, const bla_integer *m, const bla_integer *n, const bla_integer *kl, const bla_integer *ku, const bla_double *alpha, const bla_double *a, const bla_integer *lda, const bla_double *x, const bla_integer *incx, const bla_double *beta, bla_double *y, const bla_integer *incy)
 {
     /* System generated locals */
     bla_integer a_dim1, a_offset, i__1, i__2, i__3, i__4, i__5, i__6;
@@ -841,8 +838,7 @@ int PASTEF77S(d,gbmv)(const bla_character *trans, const
bla_integer *m, const bl -lf2c -lm (in that order) */ -/* Subroutine */ -int PASTEF77S(s,gbmv)(const bla_character *trans, const bla_integer *m, const bla_integer *n, const bla_integer *kl, const bla_integer *ku, const bla_real *alpha, const bla_real *a, const bla_integer *lda, const bla_real *x, const bla_integer * incx, const bla_real *beta, bla_real *y, const bla_integer *incy) +/* Subroutine */ int PASTEF77(s,gbmv)(const bla_character *trans, const bla_integer *m, const bla_integer *n, const bla_integer *kl, const bla_integer *ku, const bla_real *alpha, const bla_real *a, const bla_integer *lda, const bla_real *x, const bla_integer * incx, const bla_real *beta, bla_real *y, const bla_integer *incy) { /* System generated locals */ bla_integer a_dim1, a_offset, i__1, i__2, i__3, i__4, i__5, i__6; @@ -1198,8 +1194,7 @@ int PASTEF77S(s,gbmv)(const bla_character *trans, const bla_integer *m, const bl -lf2c -lm (in that order) */ -/* Subroutine */ -int PASTEF77S(z,gbmv)(const bla_character *trans, const bla_integer *m, const bla_integer *n, const bla_integer *kl, const bla_integer *ku, const bla_dcomplex *alpha, const bla_dcomplex *a, const bla_integer *lda, const bla_dcomplex *x, const bla_integer *incx, const bla_dcomplex *beta, bla_dcomplex * y, const bla_integer *incy) +/* Subroutine */ int PASTEF77(z,gbmv)(const bla_character *trans, const bla_integer *m, const bla_integer *n, const bla_integer *kl, const bla_integer *ku, const bla_dcomplex *alpha, const bla_dcomplex *a, const bla_integer *lda, const bla_dcomplex *x, const bla_integer *incx, const bla_dcomplex *beta, bla_dcomplex * y, const bla_integer *incy) { /* System generated locals */ bla_integer a_dim1, a_offset, i__1, i__2, i__3, i__4, i__5, i__6; @@ -1635,24 +1630,5 @@ int PASTEF77S(z,gbmv)(const bla_character *trans, const bla_integer *m, const bl } /* zgbmv_ */ -int PASTEF77(s,gbmv)(const bla_character *trans, const bla_integer *m, const bla_integer *n, const bla_integer *kl, const bla_integer *ku, const bla_real *alpha, const bla_real *a, const bla_integer *lda, const bla_real *x, const bla_integer * incx, const bla_real *beta, bla_real *y, const bla_integer *incy) -{ - return PASTEF77S(s,gbmv)( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy ); -} - -int PASTEF77(d,gbmv)(const bla_character *trans, const bla_integer *m, const bla_integer *n, const bla_integer *kl, const bla_integer *ku, const bla_double *alpha, const bla_double *a, const bla_integer *lda, const bla_double *x, const bla_integer *incx, const bla_double *beta, bla_double *y, const bla_integer *incy) -{ - return PASTEF77S(d,gbmv)( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy ); -} -int PASTEF77(c,gbmv)(const bla_character *trans, const bla_integer *m, const bla_integer *n, const bla_integer *kl, const bla_integer *ku, const bla_scomplex *alpha, const bla_scomplex *a, const bla_integer *lda, const bla_scomplex *x, const bla_integer *incx, const bla_scomplex *beta, bla_scomplex *y, const bla_integer *incy) -{ - return PASTEF77S(c,gbmv)( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy ); -} - -int PASTEF77(z,gbmv)(const bla_character *trans, const bla_integer *m, const bla_integer *n, const bla_integer *kl, const bla_integer *ku, const bla_dcomplex *alpha, const bla_dcomplex *a, const bla_integer *lda, const bla_dcomplex *x, const bla_integer *incx, const bla_dcomplex *beta, bla_dcomplex * y, const bla_integer *incy) -{ - return PASTEF77S(z,gbmv)( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy ); -} - #endif diff 
--git a/frame/compat/f2c/bla_gbmv.h b/frame/compat/f2c/bla_gbmv.h index 2ee9a638b3..eb8ce25342 100644 --- a/frame/compat/f2c/bla_gbmv.h +++ b/frame/compat/f2c/bla_gbmv.h @@ -5,8 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. - + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -39,9 +38,5 @@ BLIS_EXPORT_BLAS int PASTEF77(c,gbmv)(const bla_character *trans, const bla_inte BLIS_EXPORT_BLAS int PASTEF77(d,gbmv)(const bla_character *trans, const bla_integer *m, const bla_integer *n, const bla_integer *kl, const bla_integer *ku, const bla_double *alpha, const bla_double *a, const bla_integer *lda, const bla_double *x, const bla_integer *incx, const bla_double *beta, bla_double *y, const bla_integer *incy); BLIS_EXPORT_BLAS int PASTEF77(s,gbmv)(const bla_character *trans, const bla_integer *m, const bla_integer *n, const bla_integer *kl, const bla_integer *ku, const bla_real *alpha, const bla_real *a, const bla_integer *lda, const bla_real *x, const bla_integer * incx, const bla_real *beta, bla_real *y, const bla_integer *incy); BLIS_EXPORT_BLAS int PASTEF77(z,gbmv)(const bla_character *trans, const bla_integer *m, const bla_integer *n, const bla_integer *kl, const bla_integer *ku, const bla_dcomplex *alpha, const bla_dcomplex *a, const bla_integer *lda, const bla_dcomplex *x, const bla_integer *incx, const bla_dcomplex *beta, bla_dcomplex * y, const bla_integer *incy); -BLIS_EXPORT_BLAS int PASTEF77S(c,gbmv)(const bla_character *trans, const bla_integer *m, const bla_integer *n, const bla_integer *kl, const bla_integer *ku, const bla_scomplex *alpha, const bla_scomplex *a, const bla_integer *lda, const bla_scomplex *x, const bla_integer *incx, const bla_scomplex *beta, bla_scomplex *y, const bla_integer *incy); -BLIS_EXPORT_BLAS int PASTEF77S(d,gbmv)(const bla_character *trans, const bla_integer *m, const bla_integer *n, const bla_integer *kl, const bla_integer *ku, const bla_double *alpha, const bla_double *a, const bla_integer *lda, const bla_double *x, const bla_integer *incx, const bla_double *beta, bla_double *y, const bla_integer *incy); -BLIS_EXPORT_BLAS int PASTEF77S(s,gbmv)(const bla_character *trans, const bla_integer *m, const bla_integer *n, const bla_integer *kl, const bla_integer *ku, const bla_real *alpha, const bla_real *a, const bla_integer *lda, const bla_real *x, const bla_integer * incx, const bla_real *beta, bla_real *y, const bla_integer *incy); -BLIS_EXPORT_BLAS int PASTEF77S(z,gbmv)(const bla_character *trans, const bla_integer *m, const bla_integer *n, const bla_integer *kl, const bla_integer *ku, const bla_dcomplex *alpha, const bla_dcomplex *a, const bla_integer *lda, const bla_dcomplex *x, const bla_integer *incx, const bla_dcomplex *beta, bla_dcomplex * y, const bla_integer *incy); #endif diff --git a/frame/compat/f2c/bla_hbmv.c b/frame/compat/f2c/bla_hbmv.c index af02c3f0ca..198336d048 100644 --- a/frame/compat/f2c/bla_hbmv.c +++ b/frame/compat/f2c/bla_hbmv.c @@ -5,8 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. 
- + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -42,8 +41,7 @@ -lf2c -lm (in that order) */ -/* Subroutine */ -int PASTEF77S(c,hbmv)(const bla_character *uplo, const bla_integer *n, const bla_integer *k, const bla_scomplex * alpha, const bla_scomplex *a, const bla_integer *lda, const bla_scomplex *x, const bla_integer *incx, const bla_scomplex *beta, bla_scomplex *y, const bla_integer *incy) +/* Subroutine */ int PASTEF77(c,hbmv)(const bla_character *uplo, const bla_integer *n, const bla_integer *k, const bla_scomplex * alpha, const bla_scomplex *a, const bla_integer *lda, const bla_scomplex *x, const bla_integer *incx, const bla_scomplex *beta, bla_scomplex *y, const bla_integer *incy) { /* System generated locals */ bla_integer a_dim1, a_offset, i__1, i__2, i__3, i__4, i__5; @@ -489,8 +487,7 @@ int PASTEF77S(c,hbmv)(const bla_character *uplo, const bla_integer *n, const bla -lf2c -lm (in that order) */ -/* Subroutine */ -int PASTEF77S(z,hbmv)(const bla_character *uplo, const bla_integer *n, const bla_integer *k, const bla_dcomplex *alpha, const bla_dcomplex *a, const bla_integer *lda, const bla_dcomplex *x, const bla_integer * incx, const bla_dcomplex *beta, bla_dcomplex *y, const bla_integer *incy) +/* Subroutine */ int PASTEF77(z,hbmv)(const bla_character *uplo, const bla_integer *n, const bla_integer *k, const bla_dcomplex *alpha, const bla_dcomplex *a, const bla_integer *lda, const bla_dcomplex *x, const bla_integer * incx, const bla_dcomplex *beta, bla_dcomplex *y, const bla_integer *incy) { /* System generated locals */ bla_integer a_dim1, a_offset, i__1, i__2, i__3, i__4, i__5; @@ -931,15 +928,5 @@ int PASTEF77S(z,hbmv)(const bla_character *uplo, const bla_integer *n, const bla } /* zhbmv_ */ -int PASTEF77(c,hbmv)(const bla_character *uplo, const bla_integer *n, const bla_integer *k, const bla_scomplex * alpha, const bla_scomplex *a, const bla_integer *lda, const bla_scomplex *x, const bla_integer *incx, const bla_scomplex *beta, bla_scomplex *y, const bla_integer *incy) -{ - return PASTEF77S(c,hbmv)( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy ); -} - -int PASTEF77(z,hbmv)(const bla_character *uplo, const bla_integer *n, const bla_integer *k, const bla_dcomplex *alpha, const bla_dcomplex *a, const bla_integer *lda, const bla_dcomplex *x, const bla_integer * incx, const bla_dcomplex *beta, bla_dcomplex *y, const bla_integer *incy) -{ - return PASTEF77S(z,hbmv)( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy ); -} - #endif diff --git a/frame/compat/f2c/bla_hbmv.h b/frame/compat/f2c/bla_hbmv.h index fa4dde14a1..1ddb838071 100644 --- a/frame/compat/f2c/bla_hbmv.h +++ b/frame/compat/f2c/bla_hbmv.h @@ -5,8 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. 
- + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -37,7 +36,5 @@ BLIS_EXPORT_BLAS int PASTEF77(c,hbmv)(const bla_character *uplo, const bla_integer *n, const bla_integer *k, const bla_scomplex *alpha, const bla_scomplex *a, const bla_integer *lda, const bla_scomplex *x, const bla_integer *incx, const bla_scomplex *beta, bla_scomplex *y, const bla_integer *incy); BLIS_EXPORT_BLAS int PASTEF77(z,hbmv)(const bla_character *uplo, const bla_integer *n, const bla_integer *k, const bla_dcomplex *alpha, const bla_dcomplex *a, const bla_integer *lda, const bla_dcomplex *x, const bla_integer *incx, const bla_dcomplex *beta, bla_dcomplex *y, const bla_integer *incy); -BLIS_EXPORT_BLAS int PASTEF77S(c,hbmv)(const bla_character *uplo, const bla_integer *n, const bla_integer *k, const bla_scomplex *alpha, const bla_scomplex *a, const bla_integer *lda, const bla_scomplex *x, const bla_integer *incx, const bla_scomplex *beta, bla_scomplex *y, const bla_integer *incy); -BLIS_EXPORT_BLAS int PASTEF77S(z,hbmv)(const bla_character *uplo, const bla_integer *n, const bla_integer *k, const bla_dcomplex *alpha, const bla_dcomplex *a, const bla_integer *lda, const bla_dcomplex *x, const bla_integer *incx, const bla_dcomplex *beta, bla_dcomplex *y, const bla_integer *incy); #endif diff --git a/frame/compat/f2c/bla_hpmv.c b/frame/compat/f2c/bla_hpmv.c index 344c2868ae..0d7ebce9d7 100644 --- a/frame/compat/f2c/bla_hpmv.c +++ b/frame/compat/f2c/bla_hpmv.c @@ -5,8 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. - + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -42,8 +41,7 @@ -lf2c -lm (in that order) */ -/* Subroutine */ -int PASTEF77S(c,hpmv)(const bla_character *uplo, const bla_integer *n, const bla_scomplex *alpha, const bla_scomplex * ap, const bla_scomplex *x, const bla_integer *incx, const bla_scomplex *beta, bla_scomplex *y, const bla_integer *incy) +/* Subroutine */ int PASTEF77(c,hpmv)(const bla_character *uplo, const bla_integer *n, const bla_scomplex *alpha, const bla_scomplex * ap, const bla_scomplex *x, const bla_integer *incx, const bla_scomplex *beta, bla_scomplex *y, const bla_integer *incy) { /* System generated locals */ bla_integer i__1, i__2, i__3, i__4, i__5; @@ -441,8 +439,7 @@ int PASTEF77S(c,hpmv)(const bla_character *uplo, const bla_integer *n, const bla -lf2c -lm (in that order) */ -/* Subroutine */ -int PASTEF77S(z,hpmv)(const bla_character *uplo, const bla_integer *n, const bla_dcomplex *alpha, const bla_dcomplex *ap, const bla_dcomplex *x, const bla_integer *incx, const bla_dcomplex *beta, bla_dcomplex *y, const bla_integer *incy) +/* Subroutine */ int PASTEF77(z,hpmv)(const bla_character *uplo, const bla_integer *n, const bla_dcomplex *alpha, const bla_dcomplex *ap, const bla_dcomplex *x, const bla_integer *incx, const bla_dcomplex *beta, bla_dcomplex *y, const bla_integer *incy) { /* System generated locals */ bla_integer i__1, i__2, i__3, i__4, i__5; @@ -835,15 +832,5 @@ int PASTEF77S(z,hpmv)(const bla_character *uplo, const bla_integer *n, const bla } /* zhpmv_ */ -int PASTEF77(c,hpmv)(const bla_character *uplo, const bla_integer *n, const bla_scomplex *alpha, const bla_scomplex * ap, const bla_scomplex *x, const bla_integer *incx, const bla_scomplex *beta, bla_scomplex *y, const 
bla_integer *incy) -{ - return PASTEF77S(c,hpmv)( uplo, n, alpha, ap, x, incx, beta, y, incy ); -} - -int PASTEF77(z,hpmv)(const bla_character *uplo, const bla_integer *n, const bla_dcomplex *alpha, const bla_dcomplex *ap, const bla_dcomplex *x, const bla_integer *incx, const bla_dcomplex *beta, bla_dcomplex *y, const bla_integer *incy) -{ - return PASTEF77S(z,hpmv)( uplo, n, alpha, ap, x, incx, beta, y, incy ); -} - #endif diff --git a/frame/compat/f2c/bla_hpmv.h b/frame/compat/f2c/bla_hpmv.h index 61a8e9c1d8..26d055effd 100644 --- a/frame/compat/f2c/bla_hpmv.h +++ b/frame/compat/f2c/bla_hpmv.h @@ -5,8 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. - + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -37,7 +36,5 @@ BLIS_EXPORT_BLAS int PASTEF77(c,hpmv)(const bla_character *uplo, const bla_integer *n, const bla_scomplex *alpha, const bla_scomplex *ap, const bla_scomplex *x, const bla_integer *incx, const bla_scomplex *beta, bla_scomplex *y, const bla_integer *incy); BLIS_EXPORT_BLAS int PASTEF77(z,hpmv)(const bla_character *uplo, const bla_integer *n, const bla_dcomplex *alpha, const bla_dcomplex *ap, const bla_dcomplex *x, const bla_integer *incx, const bla_dcomplex *beta, bla_dcomplex *y, const bla_integer *incy); -BLIS_EXPORT_BLAS int PASTEF77S(c,hpmv)(const bla_character *uplo, const bla_integer *n, const bla_scomplex *alpha, const bla_scomplex *ap, const bla_scomplex *x, const bla_integer *incx, const bla_scomplex *beta, bla_scomplex *y, const bla_integer *incy); -BLIS_EXPORT_BLAS int PASTEF77S(z,hpmv)(const bla_character *uplo, const bla_integer *n, const bla_dcomplex *alpha, const bla_dcomplex *ap, const bla_dcomplex *x, const bla_integer *incx, const bla_dcomplex *beta, bla_dcomplex *y, const bla_integer *incy); #endif diff --git a/frame/compat/f2c/bla_hpr.c b/frame/compat/f2c/bla_hpr.c index ae013f869d..da1f0a0f39 100644 --- a/frame/compat/f2c/bla_hpr.c +++ b/frame/compat/f2c/bla_hpr.c @@ -5,8 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. 
- + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -42,8 +41,7 @@ -lf2c -lm (in that order) */ -/* Subroutine */ -int PASTEF77S(c,hpr)(const bla_character *uplo, const bla_integer *n, const bla_real *alpha, const bla_scomplex *x, const bla_integer *incx, bla_scomplex *ap) +/* Subroutine */ int PASTEF77(c,hpr)(const bla_character *uplo, const bla_integer *n, const bla_real *alpha, const bla_scomplex *x, const bla_integer *incx, bla_scomplex *ap) { /* System generated locals */ bla_integer i__1, i__2, i__3, i__4, i__5; @@ -355,8 +353,7 @@ int PASTEF77S(c,hpr)(const bla_character *uplo, const bla_integer *n, const bla_ -lf2c -lm (in that order) */ -/* Subroutine */ -int PASTEF77S(z,hpr)(const bla_character *uplo, const bla_integer *n, const bla_double *alpha, const bla_dcomplex *x, const bla_integer *incx, bla_dcomplex *ap) +/* Subroutine */ int PASTEF77(z,hpr)(const bla_character *uplo, const bla_integer *n, const bla_double *alpha, const bla_dcomplex *x, const bla_integer *incx, bla_dcomplex *ap) { /* System generated locals */ bla_integer i__1, i__2, i__3, i__4, i__5; @@ -663,15 +660,5 @@ int PASTEF77S(z,hpr)(const bla_character *uplo, const bla_integer *n, const bla_ } /* zhpr_ */ -int PASTEF77(c,hpr)(const bla_character *uplo, const bla_integer *n, const bla_real *alpha, const bla_scomplex *x, const bla_integer *incx, bla_scomplex *ap) -{ - return PASTEF77S(c,hpr)( uplo, n, alpha, x, incx, ap ); -} - -int PASTEF77(z,hpr)(const bla_character *uplo, const bla_integer *n, const bla_double *alpha, const bla_dcomplex *x, const bla_integer *incx, bla_dcomplex *ap) -{ - return PASTEF77S(z,hpr)( uplo, n, alpha, x, incx, ap ); -} - #endif diff --git a/frame/compat/f2c/bla_hpr.h b/frame/compat/f2c/bla_hpr.h index 3f6ffa7064..cfce9e1779 100644 --- a/frame/compat/f2c/bla_hpr.h +++ b/frame/compat/f2c/bla_hpr.h @@ -5,8 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. - + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -37,7 +36,5 @@ BLIS_EXPORT_BLAS int PASTEF77(c,hpr)(const bla_character *uplo, const bla_integer *n, const bla_real *alpha, const bla_scomplex *x, const bla_integer *incx, bla_scomplex *ap); BLIS_EXPORT_BLAS int PASTEF77(z,hpr)(const bla_character *uplo, const bla_integer *n, const bla_double *alpha, const bla_dcomplex *x, const bla_integer *incx, bla_dcomplex *ap); -BLIS_EXPORT_BLAS int PASTEF77S(c,hpr)(const bla_character *uplo, const bla_integer *n, const bla_real *alpha, const bla_scomplex *x, const bla_integer *incx, bla_scomplex *ap); -BLIS_EXPORT_BLAS int PASTEF77S(z,hpr)(const bla_character *uplo, const bla_integer *n, const bla_double *alpha, const bla_dcomplex *x, const bla_integer *incx, bla_dcomplex *ap); #endif diff --git a/frame/compat/f2c/bla_hpr2.c b/frame/compat/f2c/bla_hpr2.c index f99e8181b6..c78c1eec04 100644 --- a/frame/compat/f2c/bla_hpr2.c +++ b/frame/compat/f2c/bla_hpr2.c @@ -5,8 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. 
- + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -42,8 +41,7 @@ -lf2c -lm (in that order) */ -/* Subroutine */ -int PASTEF77S(c,hpr2)(const bla_character *uplo, const bla_integer *n, const bla_scomplex *alpha, const bla_scomplex *x, const bla_integer *incx, const bla_scomplex *y, const bla_integer *incy, bla_scomplex *ap) +/* Subroutine */ int PASTEF77(c,hpr2)(const bla_character *uplo, const bla_integer *n, const bla_scomplex *alpha, const bla_scomplex *x, const bla_integer *incx, const bla_scomplex *y, const bla_integer *incy, bla_scomplex *ap) { /* System generated locals */ bla_integer i__1, i__2, i__3, i__4, i__5, i__6; @@ -431,8 +429,7 @@ int PASTEF77S(c,hpr2)(const bla_character *uplo, const bla_integer *n, const bla -lf2c -lm (in that order) */ -/* Subroutine */ -int PASTEF77S(z,hpr2)(const bla_character *uplo, const bla_integer *n, const bla_dcomplex *alpha, const bla_dcomplex *x, const bla_integer *incx, const bla_dcomplex *y, const bla_integer *incy, bla_dcomplex *ap) +/* Subroutine */ int PASTEF77(z,hpr2)(const bla_character *uplo, const bla_integer *n, const bla_dcomplex *alpha, const bla_dcomplex *x, const bla_integer *incx, const bla_dcomplex *y, const bla_integer *incy, bla_dcomplex *ap) { /* System generated locals */ bla_integer i__1, i__2, i__3, i__4, i__5, i__6; @@ -815,15 +812,5 @@ int PASTEF77S(z,hpr2)(const bla_character *uplo, const bla_integer *n, const bla } /* zhpr2_ */ -int PASTEF77(c,hpr2)(const bla_character *uplo, const bla_integer *n, const bla_scomplex *alpha, const bla_scomplex *x, const bla_integer *incx, const bla_scomplex *y, const bla_integer *incy, bla_scomplex *ap) -{ - return PASTEF77S(c,hpr2)( uplo, n, alpha, x, incx, y, incy, ap ); -} - -int PASTEF77(z,hpr2)(const bla_character *uplo, const bla_integer *n, const bla_dcomplex *alpha, const bla_dcomplex *x, const bla_integer *incx, const bla_dcomplex *y, const bla_integer *incy, bla_dcomplex *ap) -{ - return PASTEF77S(z,hpr2)( uplo, n, alpha, x, incx, y, incy, ap ); -} - #endif diff --git a/frame/compat/f2c/bla_hpr2.h b/frame/compat/f2c/bla_hpr2.h index 6e56b5e053..16f929d611 100644 --- a/frame/compat/f2c/bla_hpr2.h +++ b/frame/compat/f2c/bla_hpr2.h @@ -5,8 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. 
- + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -37,7 +36,5 @@ BLIS_EXPORT_BLAS int PASTEF77(c,hpr2)(const bla_character *uplo, const bla_integer *n, const bla_scomplex *alpha, const bla_scomplex *x, const bla_integer *incx, const bla_scomplex *y, const bla_integer *incy, bla_scomplex *ap); BLIS_EXPORT_BLAS int PASTEF77(z,hpr2)(const bla_character *uplo, const bla_integer *n, const bla_dcomplex *alpha, const bla_dcomplex *x, const bla_integer *incx, const bla_dcomplex *y, const bla_integer *incy, bla_dcomplex *ap); -BLIS_EXPORT_BLAS int PASTEF77S(c,hpr2)(const bla_character *uplo, const bla_integer *n, const bla_scomplex *alpha, const bla_scomplex *x, const bla_integer *incx, const bla_scomplex *y, const bla_integer *incy, bla_scomplex *ap); -BLIS_EXPORT_BLAS int PASTEF77S(z,hpr2)(const bla_character *uplo, const bla_integer *n, const bla_dcomplex *alpha, const bla_dcomplex *x, const bla_integer *incx, const bla_dcomplex *y, const bla_integer *incy, bla_dcomplex *ap); #endif diff --git a/frame/compat/f2c/bla_sbmv.c b/frame/compat/f2c/bla_sbmv.c index bfbbcf0091..566fabd81c 100644 --- a/frame/compat/f2c/bla_sbmv.c +++ b/frame/compat/f2c/bla_sbmv.c @@ -5,8 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. - + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -42,8 +41,7 @@ -lf2c -lm (in that order) */ -/* Subroutine */ -int PASTEF77S(d,sbmv)(const bla_character *uplo, const bla_integer *n, const bla_integer *k, const bla_double *alpha, const bla_double *a, const bla_integer *lda, const bla_double *x, const bla_integer *incx, const bla_double *beta, bla_double *y, const bla_integer *incy) +/* Subroutine */ int PASTEF77(d,sbmv)(const bla_character *uplo, const bla_integer *n, const bla_integer *k, const bla_double *alpha, const bla_double *a, const bla_integer *lda, const bla_double *x, const bla_integer *incx, const bla_double *beta, bla_double *y, const bla_integer *incy) { /* System generated locals */ bla_integer a_dim1, a_offset, i__1, i__2, i__3, i__4; @@ -394,8 +392,7 @@ int PASTEF77S(d,sbmv)(const bla_character *uplo, const bla_integer *n, const bla -lf2c -lm (in that order) */ -/* Subroutine */ -int PASTEF77S(s,sbmv)(const bla_character *uplo, const bla_integer *n, const bla_integer *k, const bla_real *alpha, const bla_real *a, const bla_integer *lda, const bla_real *x, const bla_integer *incx, const bla_real *beta, bla_real *y, const bla_integer *incy) +/* Subroutine */ int PASTEF77(s,sbmv)(const bla_character *uplo, const bla_integer *n, const bla_integer *k, const bla_real *alpha, const bla_real *a, const bla_integer *lda, const bla_real *x, const bla_integer *incx, const bla_real *beta, bla_real *y, const bla_integer *incy) { /* System generated locals */ bla_integer a_dim1, a_offset, i__1, i__2, i__3, i__4; @@ -741,15 +738,5 @@ int PASTEF77S(s,sbmv)(const bla_character *uplo, const bla_integer *n, const bla } /* ssbmv_ */ -int PASTEF77(d,sbmv)(const bla_character *uplo, const bla_integer *n, const bla_integer *k, const bla_double *alpha, const bla_double *a, const bla_integer *lda, const bla_double *x, const bla_integer *incx, const bla_double *beta, bla_double *y, const bla_integer *incy) -{ - return PASTEF77S(d,sbmv)(uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); -} - -int 
PASTEF77(s,sbmv)(const bla_character *uplo, const bla_integer *n, const bla_integer *k, const bla_real *alpha, const bla_real *a, const bla_integer *lda, const bla_real *x, const bla_integer *incx, const bla_real *beta, bla_real *y, const bla_integer *incy) -{ - return PASTEF77S(s,sbmv)(uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); -} - #endif diff --git a/frame/compat/f2c/bla_sbmv.h b/frame/compat/f2c/bla_sbmv.h index 84e86273a1..c3f3fc24f8 100644 --- a/frame/compat/f2c/bla_sbmv.h +++ b/frame/compat/f2c/bla_sbmv.h @@ -5,8 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. - + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -37,7 +36,5 @@ BLIS_EXPORT_BLAS int PASTEF77(d,sbmv)(const bla_character *uplo, const bla_integer *n, const bla_integer *k, const bla_double *alpha, const bla_double *a, const bla_integer *lda, const bla_double *x, const bla_integer *incx, const bla_double *beta, bla_double *y, const bla_integer *incy); BLIS_EXPORT_BLAS int PASTEF77(s,sbmv)(const bla_character *uplo, const bla_integer *n, const bla_integer *k, const bla_real *alpha, const bla_real *a, const bla_integer *lda, const bla_real *x, const bla_integer *incx, const bla_real *beta, bla_real *y, const bla_integer *incy); -BLIS_EXPORT_BLAS int PASTEF77S(d,sbmv)(const bla_character *uplo, const bla_integer *n, const bla_integer *k, const bla_double *alpha, const bla_double *a, const bla_integer *lda, const bla_double *x, const bla_integer *incx, const bla_double *beta, bla_double *y, const bla_integer *incy); -BLIS_EXPORT_BLAS int PASTEF77S(s,sbmv)(const bla_character *uplo, const bla_integer *n, const bla_integer *k, const bla_real *alpha, const bla_real *a, const bla_integer *lda, const bla_real *x, const bla_integer *incx, const bla_real *beta, bla_real *y, const bla_integer *incy); #endif diff --git a/frame/compat/f2c/bla_spmv.c b/frame/compat/f2c/bla_spmv.c index e3cdbd70c8..0485e1dc3a 100644 --- a/frame/compat/f2c/bla_spmv.c +++ b/frame/compat/f2c/bla_spmv.c @@ -5,8 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. 
- + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -42,8 +41,7 @@ -lf2c -lm (in that order) */ -/* Subroutine */ -int PASTEF77S(d,spmv)(const bla_character *uplo, const bla_integer *n, const bla_double *alpha, const bla_double *ap, const bla_double *x, const bla_integer *incx, const bla_double *beta, bla_double *y, const bla_integer *incy) +/* Subroutine */ int PASTEF77(d,spmv)(const bla_character *uplo, const bla_integer *n, const bla_double *alpha, const bla_double *ap, const bla_double *x, const bla_integer *incx, const bla_double *beta, bla_double *y, const bla_integer *incy) { /* System generated locals */ bla_integer i__1, i__2; @@ -344,8 +342,7 @@ int PASTEF77S(d,spmv)(const bla_character *uplo, const bla_integer *n, const bla -lf2c -lm (in that order) */ -/* Subroutine */ -int PASTEF77S(s,spmv)(const bla_character *uplo, const bla_integer *n, const bla_real *alpha, const bla_real *ap, const bla_real *x, const bla_integer *incx, const bla_real *beta, bla_real *y, const bla_integer *incy) +/* Subroutine */ int PASTEF77(s,spmv)(const bla_character *uplo, const bla_integer *n, const bla_real *alpha, const bla_real *ap, const bla_real *x, const bla_integer *incx, const bla_real *beta, bla_real *y, const bla_integer *incy) { /* System generated locals */ bla_integer i__1, i__2; @@ -641,15 +638,5 @@ int PASTEF77S(s,spmv)(const bla_character *uplo, const bla_integer *n, const bla } /* sspmv_ */ -int PASTEF77(d,spmv)(const bla_character *uplo, const bla_integer *n, const bla_double *alpha, const bla_double *ap, const bla_double *x, const bla_integer *incx, const bla_double *beta, bla_double *y, const bla_integer *incy) -{ - return PASTEF77S(d,spmv)( uplo, n, alpha, ap, x, incx, beta, y, incy); -} - -int PASTEF77(s,spmv)(const bla_character *uplo, const bla_integer *n, const bla_real *alpha, const bla_real *ap, const bla_real *x, const bla_integer *incx, const bla_real *beta, bla_real *y, const bla_integer *incy) -{ - return PASTEF77S(s,spmv)( uplo, n, alpha, ap, x, incx, beta, y, incy); -} - #endif diff --git a/frame/compat/f2c/bla_spmv.h b/frame/compat/f2c/bla_spmv.h index df85babf1e..7db7d4a8b6 100644 --- a/frame/compat/f2c/bla_spmv.h +++ b/frame/compat/f2c/bla_spmv.h @@ -5,8 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. 
- + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -37,7 +36,5 @@ BLIS_EXPORT_BLAS int PASTEF77(d,spmv)(const bla_character *uplo, const bla_integer *n, const bla_double *alpha, const bla_double *ap, const bla_double *x, const bla_integer *incx, const bla_double *beta, bla_double *y, const bla_integer *incy); BLIS_EXPORT_BLAS int PASTEF77(s,spmv)(const bla_character *uplo, const bla_integer *n, const bla_real *alpha, const bla_real *ap, const bla_real *x, const bla_integer *incx, const bla_real *beta, bla_real *y, const bla_integer *incy); -BLIS_EXPORT_BLAS int PASTEF77S(d,spmv)(const bla_character *uplo, const bla_integer *n, const bla_double *alpha, const bla_double *ap, const bla_double *x, const bla_integer *incx, const bla_double *beta, bla_double *y, const bla_integer *incy); -BLIS_EXPORT_BLAS int PASTEF77S(s,spmv)(const bla_character *uplo, const bla_integer *n, const bla_real *alpha, const bla_real *ap, const bla_real *x, const bla_integer *incx, const bla_real *beta, bla_real *y, const bla_integer *incy); #endif diff --git a/frame/compat/f2c/bla_spr.c b/frame/compat/f2c/bla_spr.c index 9b3ee1d886..d276458b49 100644 --- a/frame/compat/f2c/bla_spr.c +++ b/frame/compat/f2c/bla_spr.c @@ -5,8 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. - + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -42,8 +41,7 @@ -lf2c -lm (in that order) */ -/* Subroutine */ -int PASTEF77S(d,spr)(const bla_character *uplo, const bla_integer *n, const bla_double *alpha, const bla_double *x, const bla_integer *incx, bla_double *ap) +/* Subroutine */ int PASTEF77(d,spr)(const bla_character *uplo, const bla_integer *n, const bla_double *alpha, const bla_double *x, const bla_integer *incx, bla_double *ap) { /* System generated locals */ bla_integer i__1, i__2; @@ -270,8 +268,7 @@ int PASTEF77S(d,spr)(const bla_character *uplo, const bla_integer *n, const bla_ -lf2c -lm (in that order) */ -/* Subroutine */ -int PASTEF77S(s,spr)(const bla_character *uplo, const bla_integer *n, const bla_real *alpha, const bla_real *x, const bla_integer *incx, bla_real *ap) +/* Subroutine */ int PASTEF77(s,spr)(const bla_character *uplo, const bla_integer *n, const bla_real *alpha, const bla_real *x, const bla_integer *incx, bla_real *ap) { /* System generated locals */ bla_integer i__1, i__2; @@ -493,15 +490,5 @@ int PASTEF77S(s,spr)(const bla_character *uplo, const bla_integer *n, const bla_ } /* sspr_ */ -int PASTEF77(d,spr)(const bla_character *uplo, const bla_integer *n, const bla_double *alpha, const bla_double *x, const bla_integer *incx, bla_double *ap) -{ - return PASTEF77S(d,spr)( uplo, n, alpha, x, incx, ap ); -} - -int PASTEF77(s,spr)(const bla_character *uplo, const bla_integer *n, const bla_real *alpha, const bla_real *x, const bla_integer *incx, bla_real *ap) -{ - return PASTEF77S(s,spr)( uplo, n, alpha, x, incx, ap ); -} - #endif diff --git a/frame/compat/f2c/bla_spr.h b/frame/compat/f2c/bla_spr.h index d7519ca049..6712d7c166 100644 --- a/frame/compat/f2c/bla_spr.h +++ b/frame/compat/f2c/bla_spr.h @@ -5,8 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. 
- + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -37,7 +36,5 @@ BLIS_EXPORT_BLAS int PASTEF77(d,spr)(const bla_character *uplo, const bla_integer *n, const bla_double *alpha, const bla_double *x, const bla_integer *incx, bla_double *ap); BLIS_EXPORT_BLAS int PASTEF77(s,spr)(const bla_character *uplo, const bla_integer *n, const bla_real *alpha, const bla_real *x, const bla_integer *incx, bla_real *ap); -BLIS_EXPORT_BLAS int PASTEF77S(d,spr)(const bla_character *uplo, const bla_integer *n, const bla_double *alpha, const bla_double *x, const bla_integer *incx, bla_double *ap); -BLIS_EXPORT_BLAS int PASTEF77S(s,spr)(const bla_character *uplo, const bla_integer *n, const bla_real *alpha, const bla_real *x, const bla_integer *incx, bla_real *ap); #endif diff --git a/frame/compat/f2c/bla_spr2.c b/frame/compat/f2c/bla_spr2.c index 6955172512..7c75382122 100644 --- a/frame/compat/f2c/bla_spr2.c +++ b/frame/compat/f2c/bla_spr2.c @@ -5,8 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. - + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -42,8 +41,7 @@ -lf2c -lm (in that order) */ -/* Subroutine */ -int PASTEF77S(d,spr2)(const bla_character *uplo, const bla_integer *n, const bla_double *alpha, const bla_double *x, const bla_integer *incx, const bla_double *y, const bla_integer *incy, bla_double *ap) +/* Subroutine */ int PASTEF77(d,spr2)(const bla_character *uplo, const bla_integer *n, const bla_double *alpha, const bla_double *x, const bla_integer *incx, const bla_double *y, const bla_integer *incy, bla_double *ap) { /* System generated locals */ bla_integer i__1, i__2; @@ -302,8 +300,7 @@ int PASTEF77S(d,spr2)(const bla_character *uplo, const bla_integer *n, const bla -lf2c -lm (in that order) */ -/* Subroutine */ -int PASTEF77S(s,spr2)(const bla_character *uplo, const bla_integer *n, const bla_real *alpha, const bla_real *x, const bla_integer *incx, const bla_real *y, const bla_integer *incy, bla_real *ap) +/* Subroutine */ int PASTEF77(s,spr2)(const bla_character *uplo, const bla_integer *n, const bla_real *alpha, const bla_real *x, const bla_integer *incx, const bla_real *y, const bla_integer *incy, bla_real *ap) { /* System generated locals */ bla_integer i__1, i__2; @@ -557,15 +554,5 @@ int PASTEF77S(s,spr2)(const bla_character *uplo, const bla_integer *n, const bla } /* sspr2_ */ -int PASTEF77(d,spr2)(const bla_character *uplo, const bla_integer *n, const bla_double *alpha, const bla_double *x, const bla_integer *incx, const bla_double *y, const bla_integer *incy, bla_double *ap) -{ - return PASTEF77S(d,spr2)( uplo, n, alpha, x, incx, y, incy,ap ); -} - -int PASTEF77(s,spr2)(const bla_character *uplo, const bla_integer *n, const bla_real *alpha, const bla_real *x, const bla_integer *incx, const bla_real *y, const bla_integer *incy, bla_real *ap) -{ - return PASTEF77S(s,spr2)( uplo, n, alpha, x, incx, y, incy,ap ); -} - #endif diff --git a/frame/compat/f2c/bla_spr2.h b/frame/compat/f2c/bla_spr2.h index 1f02990dac..5a1d607471 100644 --- a/frame/compat/f2c/bla_spr2.h +++ b/frame/compat/f2c/bla_spr2.h @@ -5,8 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. 
- + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -37,7 +36,5 @@ BLIS_EXPORT_BLAS int PASTEF77(d,spr2)(const bla_character *uplo, const bla_integer *n, const bla_double *alpha, const bla_double *x, const bla_integer *incx, const bla_double *y, const bla_integer *incy, bla_double *ap); BLIS_EXPORT_BLAS int PASTEF77(s,spr2)(const bla_character *uplo, const bla_integer *n, const bla_real *alpha, const bla_real *x, const bla_integer *incx, const bla_real *y, const bla_integer *incy, bla_real *ap); -BLIS_EXPORT_BLAS int PASTEF77S(d,spr2)(const bla_character *uplo, const bla_integer *n, const bla_double *alpha, const bla_double *x, const bla_integer *incx, const bla_double *y, const bla_integer *incy, bla_double *ap); -BLIS_EXPORT_BLAS int PASTEF77S(s,spr2)(const bla_character *uplo, const bla_integer *n, const bla_real *alpha, const bla_real *x, const bla_integer *incx, const bla_real *y, const bla_integer *incy, bla_real *ap); #endif diff --git a/frame/compat/f2c/bla_tbmv.c b/frame/compat/f2c/bla_tbmv.c index de0cfe92db..78feb70562 100644 --- a/frame/compat/f2c/bla_tbmv.c +++ b/frame/compat/f2c/bla_tbmv.c @@ -5,8 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. - + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -42,8 +41,7 @@ -lf2c -lm (in that order) */ -/* Subroutine */ -int PASTEF77S(c,tbmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_scomplex *a, const bla_integer *lda, bla_scomplex *x, const bla_integer *incx) +/* Subroutine */ int PASTEF77(c,tbmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_scomplex *a, const bla_integer *lda, bla_scomplex *x, const bla_integer *incx) { /* System generated locals */ bla_integer a_dim1, a_offset, i__1, i__2, i__3, i__4, i__5; @@ -214,11 +212,11 @@ int PASTEF77S(c,tbmv)(const bla_character *uplo, const bla_character *trans, con if (! PASTEF770(lsame)(uplo, "U", (ftnlen)1, (ftnlen)1) && ! PASTEF770(lsame)(uplo, "L", ( ftnlen)1, (ftnlen)1)) { info = 1; - } else if (! PASTEF770(lsame)(trans, "N", (ftnlen)1, (ftnlen)1) && ! PASTEF770(lsame)(trans, + } else if (! PASTEF770(lsame)(trans, "N", (ftnlen)1, (ftnlen)1) && ! PASTEF770(lsame)(trans, "T", (ftnlen)1, (ftnlen)1) && ! PASTEF770(lsame)(trans, "C", (ftnlen)1, ( ftnlen)1)) { info = 2; - } else if (! PASTEF770(lsame)(diag, "U", (ftnlen)1, (ftnlen)1) && ! PASTEF770(lsame)(diag, + } else if (! PASTEF770(lsame)(diag, "U", (ftnlen)1, (ftnlen)1) && ! 
PASTEF770(lsame)(diag, "N", (ftnlen)1, (ftnlen)1)) { info = 3; } else if (*n < 0) { @@ -613,8 +611,7 @@ int PASTEF77S(c,tbmv)(const bla_character *uplo, const bla_character *trans, con -lf2c -lm (in that order) */ -/* Subroutine */ -int PASTEF77S(d,tbmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_double *a, const bla_integer *lda, bla_double *x, const bla_integer *incx) +/* Subroutine */ int PASTEF77(d,tbmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_double *a, const bla_integer *lda, bla_double *x, const bla_integer *incx) { /* System generated locals */ bla_integer a_dim1, a_offset, i__1, i__2, i__3, i__4; @@ -781,11 +778,11 @@ int PASTEF77S(d,tbmv)(const bla_character *uplo, const bla_character *trans, con if (! PASTEF770(lsame)(uplo, "U", (ftnlen)1, (ftnlen)1) && ! PASTEF770(lsame)(uplo, "L", ( ftnlen)1, (ftnlen)1)) { info = 1; - } else if (! PASTEF770(lsame)(trans, "N", (ftnlen)1, (ftnlen)1) && ! PASTEF770(lsame)(trans, + } else if (! PASTEF770(lsame)(trans, "N", (ftnlen)1, (ftnlen)1) && ! PASTEF770(lsame)(trans, "T", (ftnlen)1, (ftnlen)1) && ! PASTEF770(lsame)(trans, "C", (ftnlen)1, ( ftnlen)1)) { info = 2; - } else if (! PASTEF770(lsame)(diag, "U", (ftnlen)1, (ftnlen)1) && ! PASTEF770(lsame)(diag, + } else if (! PASTEF770(lsame)(diag, "U", (ftnlen)1, (ftnlen)1) && ! PASTEF770(lsame)(diag, "N", (ftnlen)1, (ftnlen)1)) { info = 3; } else if (*n < 0) { @@ -1025,8 +1022,7 @@ int PASTEF77S(d,tbmv)(const bla_character *uplo, const bla_character *trans, con -lf2c -lm (in that order) */ -/* Subroutine */ -int PASTEF77S(s,tbmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_real *a, const bla_integer *lda, bla_real *x, const bla_integer *incx) +/* Subroutine */ int PASTEF77(s,tbmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_real *a, const bla_integer *lda, bla_real *x, const bla_integer *incx) { /* System generated locals */ bla_integer a_dim1, a_offset, i__1, i__2, i__3, i__4; @@ -1193,11 +1189,11 @@ int PASTEF77S(s,tbmv)(const bla_character *uplo, const bla_character *trans, con if (! PASTEF770(lsame)(uplo, "U", (ftnlen)1, (ftnlen)1) && ! PASTEF770(lsame)(uplo, "L", ( ftnlen)1, (ftnlen)1)) { info = 1; - } else if (! PASTEF770(lsame)(trans, "N", (ftnlen)1, (ftnlen)1) && ! PASTEF770(lsame)(trans, + } else if (! PASTEF770(lsame)(trans, "N", (ftnlen)1, (ftnlen)1) && ! PASTEF770(lsame)(trans, "T", (ftnlen)1, (ftnlen)1) && ! PASTEF770(lsame)(trans, "C", (ftnlen)1, ( ftnlen)1)) { info = 2; - } else if (! PASTEF770(lsame)(diag, "U", (ftnlen)1, (ftnlen)1) && ! PASTEF770(lsame)(diag, + } else if (! PASTEF770(lsame)(diag, "U", (ftnlen)1, (ftnlen)1) && ! 
PASTEF770(lsame)(diag, "N", (ftnlen)1, (ftnlen)1)) { info = 3; } else if (*n < 0) { @@ -1437,8 +1433,7 @@ int PASTEF77S(s,tbmv)(const bla_character *uplo, const bla_character *trans, con -lf2c -lm (in that order) */ -/* Subroutine */ -int PASTEF77S(z,tbmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_dcomplex *a, const bla_integer *lda, bla_dcomplex *x, const bla_integer *incx) +/* Subroutine */ int PASTEF77(z,tbmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_dcomplex *a, const bla_integer *lda, bla_dcomplex *x, const bla_integer *incx) { /* System generated locals */ bla_integer a_dim1, a_offset, i__1, i__2, i__3, i__4, i__5; @@ -1609,11 +1604,11 @@ int PASTEF77S(z,tbmv)(const bla_character *uplo, const bla_character *trans, con if (! PASTEF770(lsame)(uplo, "U", (ftnlen)1, (ftnlen)1) && ! PASTEF770(lsame)(uplo, "L", ( ftnlen)1, (ftnlen)1)) { info = 1; - } else if (! PASTEF770(lsame)(trans, "N", (ftnlen)1, (ftnlen)1) && ! PASTEF770(lsame)(trans, + } else if (! PASTEF770(lsame)(trans, "N", (ftnlen)1, (ftnlen)1) && ! PASTEF770(lsame)(trans, "T", (ftnlen)1, (ftnlen)1) && ! PASTEF770(lsame)(trans, "C", (ftnlen)1, ( ftnlen)1)) { info = 2; - } else if (! PASTEF770(lsame)(diag, "U", (ftnlen)1, (ftnlen)1) && ! PASTEF770(lsame)(diag, + } else if (! PASTEF770(lsame)(diag, "U", (ftnlen)1, (ftnlen)1) && ! PASTEF770(lsame)(diag, "N", (ftnlen)1, (ftnlen)1)) { info = 3; } else if (*n < 0) { @@ -2003,25 +1998,5 @@ int PASTEF77S(z,tbmv)(const bla_character *uplo, const bla_character *trans, con } /* ztbmv_ */ -int PASTEF77(s,tbmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_real *a, const bla_integer *lda, bla_real *x, const bla_integer *incx) -{ - return PASTEF77S(s,tbmv)( uplo, trans, diag, n, k, a, lda, x, incx ); -} - -int PASTEF77(d,tbmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_double *a, const bla_integer *lda, bla_double *x, const bla_integer *incx) -{ - return PASTEF77S(d,tbmv)( uplo, trans, diag, n, k, a, lda, x, incx ); -} - -int PASTEF77(c,tbmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_scomplex *a, const bla_integer *lda, bla_scomplex *x, const bla_integer *incx) -{ - return PASTEF77S(c,tbmv)( uplo, trans, diag, n, k, a, lda, x, incx ); -} - -int PASTEF77(z,tbmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_dcomplex *a, const bla_integer *lda, bla_dcomplex *x, const bla_integer *incx) -{ - return PASTEF77S(z,tbmv)( uplo, trans, diag, n, k, a, lda, x, incx ); -} - #endif diff --git a/frame/compat/f2c/bla_tbmv.h b/frame/compat/f2c/bla_tbmv.h index cce9f18c8f..f34654762b 100644 --- a/frame/compat/f2c/bla_tbmv.h +++ b/frame/compat/f2c/bla_tbmv.h @@ -5,8 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. 
- + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -39,9 +38,5 @@ BLIS_EXPORT_BLAS int PASTEF77(c,tbmv)(const bla_character *uplo, const bla_chara BLIS_EXPORT_BLAS int PASTEF77(d,tbmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_double *a, const bla_integer *lda, bla_double *x, const bla_integer *incx); BLIS_EXPORT_BLAS int PASTEF77(s,tbmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_real *a, const bla_integer *lda, bla_real *x, const bla_integer *incx); BLIS_EXPORT_BLAS int PASTEF77(z,tbmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_dcomplex *a, const bla_integer *lda, bla_dcomplex *x, const bla_integer *incx); -BLIS_EXPORT_BLAS int PASTEF77S(c,tbmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_scomplex *a, const bla_integer *lda, bla_scomplex *x, const bla_integer *incx); -BLIS_EXPORT_BLAS int PASTEF77S(d,tbmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_double *a, const bla_integer *lda, bla_double *x, const bla_integer *incx); -BLIS_EXPORT_BLAS int PASTEF77S(s,tbmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_real *a, const bla_integer *lda, bla_real *x, const bla_integer *incx); -BLIS_EXPORT_BLAS int PASTEF77S(z,tbmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_dcomplex *a, const bla_integer *lda, bla_dcomplex *x, const bla_integer *incx); #endif diff --git a/frame/compat/f2c/bla_tbsv.c b/frame/compat/f2c/bla_tbsv.c index 25239c780d..819456f029 100644 --- a/frame/compat/f2c/bla_tbsv.c +++ b/frame/compat/f2c/bla_tbsv.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2021-2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -42,8 +42,7 @@ -lf2c -lm (in that order) */ -/* Subroutine */ -int PASTEF77S(c,tbsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_scomplex *a, const bla_integer *lda, bla_scomplex *x, const bla_integer *incx) +/* Subroutine */ int PASTEF77(c,tbsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_scomplex *a, const bla_integer *lda, bla_scomplex *x, const bla_integer *incx) { /* System generated locals */ bla_integer a_dim1, a_offset, i__1, i__2, i__3, i__4, i__5; @@ -623,8 +622,7 @@ int PASTEF77S(c,tbsv)(const bla_character *uplo, const bla_character *trans, con -lf2c -lm (in that order) */ -/* Subroutine */ -int PASTEF77S(d,tbsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_double *a, const bla_integer *lda, bla_double *x, const bla_integer *incx) +/* Subroutine */ int PASTEF77(d,tbsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_double *a, const bla_integer *lda, bla_double *x, const bla_integer *incx) { /* System generated locals */ bla_integer a_dim1, a_offset, i__1, i__2, i__3, i__4; @@ -1055,8 +1053,7 @@ int PASTEF77S(d,tbsv)(const bla_character *uplo, const bla_character *trans, con -lf2c -lm (in that order) */ -/* Subroutine */ -int PASTEF77S(s,tbsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_real *a, const bla_integer *lda, bla_real *x, const bla_integer *incx) +/* Subroutine */ int PASTEF77(s,tbsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_real *a, const bla_integer *lda, bla_real *x, const bla_integer *incx) { /* System generated locals */ bla_integer a_dim1, a_offset, i__1, i__2, i__3, i__4; @@ -1487,8 +1484,7 @@ int PASTEF77S(s,tbsv)(const bla_character *uplo, const bla_character *trans, con -lf2c -lm (in that order) */ -/* Subroutine */ -int PASTEF77S(z,tbsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_dcomplex *a, const bla_integer *lda, bla_dcomplex *x, const bla_integer *incx) +/* Subroutine */ int PASTEF77(z,tbsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_dcomplex *a, const bla_integer *lda, bla_dcomplex *x, const bla_integer *incx) { /* System generated locals */ bla_integer a_dim1, a_offset, i__1, i__2, i__3, i__4, i__5; @@ -2062,26 +2058,5 @@ int PASTEF77S(z,tbsv)(const bla_character *uplo, const bla_character *trans, con } /* ztbsv_ */ - -int PASTEF77(s,tbsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_real *a, const bla_integer *lda, bla_real *x, const bla_integer *incx) -{ - return PASTEF77S(s,tbsv)( uplo, trans, diag, n, k, a, lda, x, incx ); -} - -int PASTEF77(d,tbsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_double *a, 
const bla_integer *lda, bla_double *x, const bla_integer *incx) -{ - return PASTEF77S(d,tbsv)( uplo, trans, diag, n, k, a, lda, x, incx ); -} - -int PASTEF77(c,tbsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_scomplex *a, const bla_integer *lda, bla_scomplex *x, const bla_integer *incx) -{ - return PASTEF77S(c,tbsv)( uplo, trans, diag, n, k, a, lda, x, incx ); -} - -int PASTEF77(z,tbsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_dcomplex *a, const bla_integer *lda, bla_dcomplex *x, const bla_integer *incx) -{ - return PASTEF77S(z,tbsv)( uplo, trans, diag, n, k, a, lda, x, incx ); -} - #endif diff --git a/frame/compat/f2c/bla_tbsv.h b/frame/compat/f2c/bla_tbsv.h index dd11e3f3f5..5e84f5c363 100644 --- a/frame/compat/f2c/bla_tbsv.h +++ b/frame/compat/f2c/bla_tbsv.h @@ -5,8 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. - + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -39,9 +38,5 @@ BLIS_EXPORT_BLAS int PASTEF77(c,tbsv)(const bla_character *uplo, const bla_chara BLIS_EXPORT_BLAS int PASTEF77(d,tbsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_double *a, const bla_integer *lda, bla_double *x, const bla_integer *incx); BLIS_EXPORT_BLAS int PASTEF77(s,tbsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_real *a, const bla_integer *lda, bla_real *x, const bla_integer *incx); BLIS_EXPORT_BLAS int PASTEF77(z,tbsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_dcomplex *a, const bla_integer *lda, bla_dcomplex *x, const bla_integer *incx); -BLIS_EXPORT_BLAS int PASTEF77S(c,tbsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_scomplex *a, const bla_integer *lda, bla_scomplex *x, const bla_integer *incx); -BLIS_EXPORT_BLAS int PASTEF77S(d,tbsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_double *a, const bla_integer *lda, bla_double *x, const bla_integer *incx); -BLIS_EXPORT_BLAS int PASTEF77S(s,tbsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_real *a, const bla_integer *lda, bla_real *x, const bla_integer *incx); -BLIS_EXPORT_BLAS int PASTEF77S(z,tbsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_integer *k, const bla_dcomplex *a, const bla_integer *lda, bla_dcomplex *x, const bla_integer *incx); #endif diff --git a/frame/compat/f2c/bla_tpmv.c b/frame/compat/f2c/bla_tpmv.c index f01e9f9b42..8fa46f4c4f 100644 --- a/frame/compat/f2c/bla_tpmv.c +++ b/frame/compat/f2c/bla_tpmv.c @@ -5,8 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. 
- + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -42,8 +41,7 @@ -lf2c -lm (in that order) */ -/* Subroutine */ -int PASTEF77S(c,tpmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_scomplex *ap, bla_scomplex *x, const bla_integer *incx) +/* Subroutine */ int PASTEF77(c,tpmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_scomplex *ap, bla_scomplex *x, const bla_integer *incx) { /* System generated locals */ bla_integer i__1, i__2, i__3, i__4, i__5; @@ -544,8 +542,7 @@ int PASTEF77S(c,tpmv)(const bla_character *uplo, const bla_character *trans, con -lf2c -lm (in that order) */ -/* Subroutine */ -int PASTEF77S(d,tpmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_double *ap, bla_double *x, const bla_integer *incx) +/* Subroutine */ int PASTEF77(d,tpmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_double *ap, bla_double *x, const bla_integer *incx) { /* System generated locals */ bla_integer i__1, i__2; @@ -893,8 +890,7 @@ int PASTEF77S(d,tpmv)(const bla_character *uplo, const bla_character *trans, con -lf2c -lm (in that order) */ -/* Subroutine */ -int PASTEF77S(s,tpmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_real *ap, bla_real *x, const bla_integer *incx) +/* Subroutine */ int PASTEF77(s,tpmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_real *ap, bla_real *x, const bla_integer *incx) { /* System generated locals */ bla_integer i__1, i__2; @@ -1242,8 +1238,7 @@ int PASTEF77S(s,tpmv)(const bla_character *uplo, const bla_character *trans, con -lf2c -lm (in that order) */ -/* Subroutine */ -int PASTEF77S(z,tpmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_dcomplex *ap, bla_dcomplex *x, const bla_integer *incx) +/* Subroutine */ int PASTEF77(z,tpmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_dcomplex *ap, bla_dcomplex *x, const bla_integer *incx) { /* System generated locals */ bla_integer i__1, i__2, i__3, i__4, i__5; @@ -1739,24 +1734,5 @@ int PASTEF77S(z,tpmv)(const bla_character *uplo, const bla_character *trans, con } /* ztpmv_ */ -int PASTEF77(s,tpmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_real *ap, bla_real *x, const bla_integer *incx) -{ - return PASTEF77S(s,tpmv)( uplo, trans, diag, n, ap, x, incx ); -} - -int PASTEF77(d,tpmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_double *ap, bla_double *x, const bla_integer *incx) -{ - return PASTEF77S(d,tpmv)( uplo, trans, diag, n, ap, x, incx ); -} - -int PASTEF77(c,tpmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_scomplex *ap, bla_scomplex *x, const bla_integer *incx) -{ - return PASTEF77S(c,tpmv)( uplo, trans, diag, n, ap, x, incx ); -} - -int PASTEF77(z,tpmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_dcomplex *ap, 
bla_dcomplex *x, const bla_integer *incx) -{ - return PASTEF77S(z,tpmv)( uplo, trans, diag, n, ap, x, incx ); -} +#endif -#endif \ No newline at end of file diff --git a/frame/compat/f2c/bla_tpmv.h b/frame/compat/f2c/bla_tpmv.h index 6af438f1da..2376ecfe33 100644 --- a/frame/compat/f2c/bla_tpmv.h +++ b/frame/compat/f2c/bla_tpmv.h @@ -5,8 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. - + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -39,9 +38,5 @@ BLIS_EXPORT_BLAS int PASTEF77(c,tpmv)(const bla_character *uplo, const bla_chara BLIS_EXPORT_BLAS int PASTEF77(d,tpmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_double *ap, bla_double *x, const bla_integer *incx); BLIS_EXPORT_BLAS int PASTEF77(s,tpmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_real *ap, bla_real *x, const bla_integer *incx); BLIS_EXPORT_BLAS int PASTEF77(z,tpmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_dcomplex *ap, bla_dcomplex *x, const bla_integer *incx); -BLIS_EXPORT_BLAS int PASTEF77S(c,tpmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_scomplex *ap, bla_scomplex *x, const bla_integer *incx); -BLIS_EXPORT_BLAS int PASTEF77S(d,tpmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_double *ap, bla_double *x, const bla_integer *incx); -BLIS_EXPORT_BLAS int PASTEF77S(s,tpmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_real *ap, bla_real *x, const bla_integer *incx); -BLIS_EXPORT_BLAS int PASTEF77S(z,tpmv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_dcomplex *ap, bla_dcomplex *x, const bla_integer *incx); #endif diff --git a/frame/compat/f2c/bla_tpsv.c b/frame/compat/f2c/bla_tpsv.c index 2619df9fea..0764940979 100644 --- a/frame/compat/f2c/bla_tpsv.c +++ b/frame/compat/f2c/bla_tpsv.c @@ -5,8 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. 
- + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -42,8 +41,7 @@ -lf2c -lm (in that order) */ -/* Subroutine */ -int PASTEF77S(c,tpsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_scomplex *ap, bla_scomplex *x, const bla_integer *incx) +/* Subroutine */ int PASTEF77(c,tpsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_scomplex *ap, bla_scomplex *x, const bla_integer *incx) { /* System generated locals */ bla_integer i__1, i__2, i__3, i__4, i__5; @@ -536,8 +534,7 @@ int PASTEF77S(c,tpsv)(const bla_character *uplo, const bla_character *trans, con -lf2c -lm (in that order) */ -/* Subroutine */ -int PASTEF77S(d,tpsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_double *ap, bla_double *x, const bla_integer *incx) +/* Subroutine */ int PASTEF77(d,tpsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_double *ap, bla_double *x, const bla_integer *incx) { /* System generated locals */ bla_integer i__1, i__2; @@ -888,8 +885,7 @@ int PASTEF77S(d,tpsv)(const bla_character *uplo, const bla_character *trans, con -lf2c -lm (in that order) */ -/* Subroutine */ -int PASTEF77S(s,tpsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_real *ap, bla_real *x, const bla_integer *incx) +/* Subroutine */ int PASTEF77(s,tpsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_real *ap, bla_real *x, const bla_integer *incx) { /* System generated locals */ bla_integer i__1, i__2; @@ -1240,8 +1236,7 @@ int PASTEF77S(s,tpsv)(const bla_character *uplo, const bla_character *trans, con -lf2c -lm (in that order) */ -/* Subroutine */ -int PASTEF77S(z,tpsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_dcomplex *ap, bla_dcomplex *x, const bla_integer *incx) +/* Subroutine */ int PASTEF77(z,tpsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_dcomplex *ap, bla_dcomplex *x, const bla_integer *incx) { /* System generated locals */ bla_integer i__1, i__2, i__3, i__4, i__5; @@ -1730,25 +1725,5 @@ int PASTEF77S(z,tpsv)(const bla_character *uplo, const bla_character *trans, con } /* ztpsv_ */ -int PASTEF77(s,tpsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_real *ap, bla_real *x, const bla_integer *incx) -{ - return PASTEF77S(s,tpsv)( uplo, trans, diag, n, ap, x, incx ); -} - -int PASTEF77(d,tpsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_double *ap, bla_double *x, const bla_integer *incx) -{ - return PASTEF77S(d,tpsv)( uplo, trans, diag, n, ap, x, incx ); -} - -int PASTEF77(c,tpsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_scomplex *ap, bla_scomplex *x, const bla_integer *incx) -{ - return PASTEF77S(c,tpsv)( uplo, trans, diag, n, ap, x, incx ); -} - -int PASTEF77(z,tpsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_dcomplex *ap, 
bla_dcomplex *x, const bla_integer *incx) -{ - return PASTEF77S(z,tpsv)( uplo, trans, diag, n, ap, x, incx ); -} - #endif diff --git a/frame/compat/f2c/bla_tpsv.h b/frame/compat/f2c/bla_tpsv.h index 7a3332424d..77bd55979a 100644 --- a/frame/compat/f2c/bla_tpsv.h +++ b/frame/compat/f2c/bla_tpsv.h @@ -5,8 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. - + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: @@ -39,9 +38,5 @@ BLIS_EXPORT_BLAS int PASTEF77(c,tpsv)(const bla_character *uplo, const bla_chara BLIS_EXPORT_BLAS int PASTEF77(d,tpsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_double *ap, bla_double *x, const bla_integer *incx); BLIS_EXPORT_BLAS int PASTEF77(s,tpsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_real *ap, bla_real *x, const bla_integer *incx); BLIS_EXPORT_BLAS int PASTEF77(z,tpsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_dcomplex *ap, bla_dcomplex *x, const bla_integer *incx); -BLIS_EXPORT_BLAS int PASTEF77S(c,tpsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_scomplex *ap, bla_scomplex *x, const bla_integer *incx); -BLIS_EXPORT_BLAS int PASTEF77S(d,tpsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_double *ap, bla_double *x, const bla_integer *incx); -BLIS_EXPORT_BLAS int PASTEF77S(s,tpsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_real *ap, bla_real *x, const bla_integer *incx); -BLIS_EXPORT_BLAS int PASTEF77S(z,tpsv)(const bla_character *uplo, const bla_character *trans, const bla_character *diag, const bla_integer *n, const bla_dcomplex *ap, bla_dcomplex *x, const bla_integer *incx); #endif diff --git a/frame/include/bli_macro_defs.h b/frame/include/bli_macro_defs.h index 9ab7c00aa7..75b9c9fdc4 100644 --- a/frame/include/bli_macro_defs.h +++ b/frame/include/bli_macro_defs.h @@ -158,23 +158,17 @@ #define PASTEMACT(ch1, ch2, ch3, ch4) bli_ ## ch1 ## ch2 ## _ ## ch3 ## _ ## ch4 // name-mangling macros. 
 #ifdef BLIS_ENABLE_NO_UNDERSCORE_API
-#define PASTEF770(name) name
-#define PASTEF77(ch1,name) ch1 ## name
-#define PASTEF772(ch1,ch2,name) ch1 ## ch2 ## name
-#define PASTEF773(ch1,ch2,ch3,name) ch1 ## ch2 ## ch3 ## name
-#define PASTEF770S(name) name ## _blis_impl
-#define PASTEF77S(ch1,name) ch1 ## name ## _blis_impl
-#define PASTEF772S(ch1,ch2,name) ch1 ## ch2 ## name ## _blis_impl
-#define PASTEF773S(ch1,ch2,ch3,name) ch1 ## ch2 ## ch3 ## name ## _blis_impl
+#define PASTEF770(name) name
+#define PASTEF77(ch1,name) ch1 ## name
+#define PASTEF772(ch1,ch2,name) ch1 ## ch2 ## name
+#define PASTEF773(ch1,ch2,ch3,name) ch1 ## ch2 ## ch3 ## name
+#define PASTEF77S(ch1,name) ch1 ## name ## _blis_impl
 #else
-#define PASTEF770(name) name ## _
-#define PASTEF77(ch1,name) ch1 ## name ## _
-#define PASTEF772(ch1,ch2,name) ch1 ## ch2 ## name ## _
-#define PASTEF773(ch1,ch2,ch3,name) ch1 ## ch2 ## ch3 ## name ## _
-#define PASTEF770S(name) name ## _blis_impl
-#define PASTEF77S(ch1,name) ch1 ## name ## _blis_impl
-#define PASTEF772S(ch1,ch2,name) ch1 ## ch2 ## name ## _blis_impl
-#define PASTEF773S(ch1,ch2,ch3,name) ch1 ## ch2 ## ch3 ## name ## _blis_impl
+#define PASTEF770(name) name ## _
+#define PASTEF77(ch1,name) ch1 ## name ## _
+#define PASTEF772(ch1,ch2,name) ch1 ## ch2 ## name ## _
+#define PASTEF773(ch1,ch2,ch3,name) ch1 ## ch2 ## ch3 ## name ## _
+#define PASTEF77S(ch1,name) ch1 ## name ## _blis_impl
 #endif
 
 // -- Include other groups of macros
diff --git a/frame/util/bli_util_api_wrap.c b/frame/util/bli_util_api_wrap.c
index 098a7c33e8..9e8d1ccc38 100644
--- a/frame/util/bli_util_api_wrap.c
+++ b/frame/util/bli_util_api_wrap.c
@@ -195,17 +195,17 @@ void ZDOTU_(dcomplex* retval,const f77_int *n, const dcomplex *zx, const f77_int
 
 void CGBMV(const char *trans,const f77_int *m,const f77_int *n,const f77_int *kl,const f77_int *ku,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy)
 {
-    cgbmv_blis_impl( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy);
+    cgbmv_( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy);
 }
 
 void cgbmv(const char *trans,const f77_int *m,const f77_int *n,const f77_int *kl,const f77_int *ku,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy)
 {
-    cgbmv_blis_impl( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy);
+    cgbmv_( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy);
 }
 
 void CGBMV_(const char *trans,const f77_int *m,const f77_int *n,const f77_int *kl,const f77_int *ku,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy)
 {
-    cgbmv_blis_impl( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy);
+    cgbmv_( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy);
 }
 
 void CGEMM(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc)
@@ -225,62 +225,62 @@ void CGEMM_(const char *transa,const char *transb,const f77_int *m,const f77
 
 void CGEMV(const char *trans,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy)
 {
-
cgemv_blis_impl( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); + cgemv_( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); } void cgemv(const char *trans,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy) { - cgemv_blis_impl( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); + cgemv_( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); } void CGEMV_(const char *trans,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy) { - cgemv_blis_impl( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); + cgemv_( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); } void CGERC(const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *x,const f77_int *incx,const scomplex *y,const f77_int *incy,scomplex *a,const f77_int *lda) { - cgerc_blis_impl( m, n, alpha, x, incx, y, incy, a, lda); + cgerc_( m, n, alpha, x, incx, y, incy, a, lda); } void cgerc(const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *x,const f77_int *incx,const scomplex *y,const f77_int *incy,scomplex *a,const f77_int *lda) { - cgerc_blis_impl( m, n, alpha, x, incx, y, incy, a, lda); + cgerc_( m, n, alpha, x, incx, y, incy, a, lda); } void CGERC_(const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *x,const f77_int *incx,const scomplex *y,const f77_int *incy,scomplex *a,const f77_int *lda) { - cgerc_blis_impl( m, n, alpha, x, incx, y, incy, a, lda); + cgerc_( m, n, alpha, x, incx, y, incy, a, lda); } void CGERU(const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *x,const f77_int *incx,const scomplex *y,const f77_int *incy,scomplex *a,const f77_int *lda) { - cgeru_blis_impl( m, n, alpha, x, incx, y, incy, a, lda); + cgeru_( m, n, alpha, x, incx, y, incy, a, lda); } void cgeru(const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *x,const f77_int *incx,const scomplex *y,const f77_int *incy,scomplex *a,const f77_int *lda) { - cgeru_blis_impl( m, n, alpha, x, incx, y, incy, a, lda); + cgeru_( m, n, alpha, x, incx, y, incy, a, lda); } void CGERU_(const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *x,const f77_int *incx,const scomplex *y,const f77_int *incy,scomplex *a,const f77_int *lda) { - cgeru_blis_impl( m, n, alpha, x, incx, y, incy, a, lda); + cgeru_( m, n, alpha, x, incx, y, incy, a, lda); } void CHBMV(const char *uplo,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy) { - chbmv_blis_impl( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); + chbmv_( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); } void chbmv(const char *uplo,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy) { - chbmv_blis_impl( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); + chbmv_( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); } void CHBMV_(const char *uplo,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy) { - chbmv_blis_impl( uplo, n, k, alpha, a, lda, x, incx, beta, y, 
incy); + chbmv_( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); } void CHEMM(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) @@ -300,47 +300,47 @@ void CHEMM_(const char *side,const char *uplo,const f77_int *m,const f77_int void CHEMV(const char *uplo,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy) { - chemv_blis_impl( uplo, n, alpha, a, lda, x, incx, beta, y, incy); + chemv_( uplo, n, alpha, a, lda, x, incx, beta, y, incy); } void chemv(const char *uplo,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy) { - chemv_blis_impl( uplo, n, alpha, a, lda, x, incx, beta, y, incy); + chemv_( uplo, n, alpha, a, lda, x, incx, beta, y, incy); } void CHEMV_(const char *uplo,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy) { - chemv_blis_impl( uplo, n, alpha, a, lda, x, incx, beta, y, incy); + chemv_( uplo, n, alpha, a, lda, x, incx, beta, y, incy); } void CHER(const char *uplo,const f77_int *n,const float *alpha,const scomplex *x,const f77_int *incx,scomplex *a,const f77_int *lda) { - cher_blis_impl( uplo, n, alpha, x, incx, a, lda); + cher_( uplo, n, alpha, x, incx, a, lda); } void cher(const char *uplo,const f77_int *n,const float *alpha,const scomplex *x,const f77_int *incx,scomplex *a,const f77_int *lda) { - cher_blis_impl( uplo, n, alpha, x, incx, a, lda); + cher_( uplo, n, alpha, x, incx, a, lda); } void CHER_(const char *uplo,const f77_int *n,const float *alpha,const scomplex *x,const f77_int *incx,scomplex *a,const f77_int *lda) { - cher_blis_impl( uplo, n, alpha, x, incx, a, lda); + cher_( uplo, n, alpha, x, incx, a, lda); } void CHER2(const char *uplo,const f77_int *n,const scomplex *alpha,const scomplex *x,const f77_int *incx,const scomplex *y,const f77_int *incy,scomplex *a,const f77_int *lda) { - cher2_blis_impl( uplo, n, alpha, x, incx, y, incy, a, lda); + cher2_( uplo, n, alpha, x, incx, y, incy, a, lda); } void cher2(const char *uplo,const f77_int *n,const scomplex *alpha,const scomplex *x,const f77_int *incx,const scomplex *y,const f77_int *incy,scomplex *a,const f77_int *lda) { - cher2_blis_impl( uplo, n, alpha, x, incx, y, incy, a, lda); + cher2_( uplo, n, alpha, x, incx, y, incy, a, lda); } void CHER2_(const char *uplo,const f77_int *n,const scomplex *alpha,const scomplex *x,const f77_int *incx,const scomplex *y,const f77_int *incy,scomplex *a,const f77_int *lda) { - cher2_blis_impl( uplo, n, alpha, x, incx, y, incy, a, lda); + cher2_( uplo, n, alpha, x, incx, y, incy, a, lda); } void CHER2K(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const float *beta,scomplex *c,const f77_int *ldc) @@ -375,47 +375,47 @@ void CHERK_(const char *uplo,const char *trans,const f77_int *n,const f77_in void CHPMV(const char *uplo,const f77_int *n,const scomplex *alpha,const scomplex *ap,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy) { - chpmv_blis_impl( uplo, n, alpha, ap, x, incx, beta, y, incy); + chpmv_( uplo, n, 
alpha, ap, x, incx, beta, y, incy); } void chpmv(const char *uplo,const f77_int *n,const scomplex *alpha,const scomplex *ap,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy) { - chpmv_blis_impl( uplo, n, alpha, ap, x, incx, beta, y, incy); + chpmv_( uplo, n, alpha, ap, x, incx, beta, y, incy); } void CHPMV_(const char *uplo,const f77_int *n,const scomplex *alpha,const scomplex *ap,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy) { - chpmv_blis_impl( uplo, n, alpha, ap, x, incx, beta, y, incy); + chpmv_( uplo, n, alpha, ap, x, incx, beta, y, incy); } void CHPR(const char *uplo,const f77_int *n,const float *alpha,const scomplex *x,const f77_int *incx,scomplex *ap) { - chpr_blis_impl( uplo, n, alpha, x, incx, ap); + chpr_( uplo, n, alpha, x, incx, ap); } void chpr(const char *uplo,const f77_int *n,const float *alpha,const scomplex *x,const f77_int *incx,scomplex *ap) { - chpr_blis_impl( uplo, n, alpha, x, incx, ap); + chpr_( uplo, n, alpha, x, incx, ap); } void CHPR_(const char *uplo,const f77_int *n,const float *alpha,const scomplex *x,const f77_int *incx,scomplex *ap) { - chpr_blis_impl( uplo, n, alpha, x, incx, ap); + chpr_( uplo, n, alpha, x, incx, ap); } void CHPR2(const char *uplo,const f77_int *n,const scomplex *alpha,const scomplex *x,const f77_int *incx,const scomplex *y,const f77_int *incy,scomplex *ap) { - chpr2_blis_impl( uplo, n, alpha, x, incx, y, incy, ap); + chpr2_( uplo, n, alpha, x, incx, y, incy, ap); } void chpr2(const char *uplo,const f77_int *n,const scomplex *alpha,const scomplex *x,const f77_int *incx,const scomplex *y,const f77_int *incy,scomplex *ap) { - chpr2_blis_impl( uplo, n, alpha, x, incx, y, incy, ap); + chpr2_( uplo, n, alpha, x, incx, y, incy, ap); } void CHPR2_(const char *uplo,const f77_int *n,const scomplex *alpha,const scomplex *x,const f77_int *incx,const scomplex *y,const f77_int *incy,scomplex *ap) { - chpr2_blis_impl( uplo, n, alpha, x, incx, y, incy, ap); + chpr2_( uplo, n, alpha, x, incx, y, incy, ap); } void CROTG(scomplex *ca, bla_scomplex *cb, bla_real *c,scomplex *s) @@ -540,62 +540,62 @@ void CSYRK_(const char *uplo,const char *trans,const f77_int *n,const f77_in void CTBMV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const scomplex *a,const f77_int *lda,scomplex *x,const f77_int *incx) { - ctbmv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); + ctbmv_( uplo, trans, diag, n, k, a, lda, x, incx); } void ctbmv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const scomplex *a,const f77_int *lda,scomplex *x,const f77_int *incx) { - ctbmv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); + ctbmv_( uplo, trans, diag, n, k, a, lda, x, incx); } void CTBMV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const scomplex *a,const f77_int *lda,scomplex *x,const f77_int *incx) { - ctbmv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); + ctbmv_( uplo, trans, diag, n, k, a, lda, x, incx); } void CTBSV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const scomplex *a,const f77_int *lda,scomplex *x,const f77_int *incx) { - ctbsv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); + ctbsv_( uplo, trans, diag, n, k, a, lda, x, incx); } void ctbsv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const scomplex *a,const f77_int *lda,scomplex *x,const f77_int *incx) { - 
ctbsv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); + ctbsv_( uplo, trans, diag, n, k, a, lda, x, incx); } void CTBSV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const scomplex *a,const f77_int *lda,scomplex *x,const f77_int *incx) { - ctbsv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); + ctbsv_( uplo, trans, diag, n, k, a, lda, x, incx); } void CTPMV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const scomplex *ap,scomplex *x,const f77_int *incx) { - ctpmv_blis_impl( uplo, trans, diag, n, ap, x, incx); + ctpmv_( uplo, trans, diag, n, ap, x, incx); } void ctpmv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const scomplex *ap,scomplex *x,const f77_int *incx) { - ctpmv_blis_impl( uplo, trans, diag, n, ap, x, incx); + ctpmv_( uplo, trans, diag, n, ap, x, incx); } void CTPMV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const scomplex *ap,scomplex *x,const f77_int *incx) { - ctpmv_blis_impl( uplo, trans, diag, n, ap, x, incx); + ctpmv_( uplo, trans, diag, n, ap, x, incx); } void CTPSV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const scomplex *ap,scomplex *x,const f77_int *incx) { - ctpsv_blis_impl( uplo, trans, diag, n, ap, x, incx); + ctpsv_( uplo, trans, diag, n, ap, x, incx); } void ctpsv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const scomplex *ap,scomplex *x,const f77_int *incx) { - ctpsv_blis_impl( uplo, trans, diag, n, ap, x, incx); + ctpsv_( uplo, trans, diag, n, ap, x, incx); } void CTPSV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const scomplex *ap,scomplex *x,const f77_int *incx) { - ctpsv_blis_impl( uplo, trans, diag, n, ap, x, incx); + ctpsv_( uplo, trans, diag, n, ap, x, incx); } void CTRMM(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,scomplex *b,const f77_int *ldb) @@ -615,17 +615,17 @@ void CTRMM_(const char *side,const char *uplo,const char *transa,const cha void CTRMV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const scomplex *a,const f77_int *lda,scomplex *x,const f77_int *incx) { - ctrmv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); + ctrmv_( uplo, trans, diag, n, a, lda, x, incx); } void ctrmv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const scomplex *a,const f77_int *lda,scomplex *x,const f77_int *incx) { - ctrmv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); + ctrmv_( uplo, trans, diag, n, a, lda, x, incx); } void CTRMV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const scomplex *a,const f77_int *lda,scomplex *x,const f77_int *incx) { - ctrmv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); + ctrmv_( uplo, trans, diag, n, a, lda, x, incx); } void CTRSM(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,scomplex *b,const f77_int *ldb) @@ -645,17 +645,17 @@ void CTRSM_(const char *side,const char *uplo,const char *transa,const cha void CTRSV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const scomplex *a,const f77_int *lda,scomplex *x,const f77_int *incx) { - ctrsv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); + ctrsv_( uplo, trans, diag, n, a, lda, x, incx); } void ctrsv(const char *uplo,const char *trans,const char *diag,const 
f77_int *n,const scomplex *a,const f77_int *lda,scomplex *x,const f77_int *incx) { - ctrsv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); + ctrsv_( uplo, trans, diag, n, a, lda, x, incx); } void CTRSV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const scomplex *a,const f77_int *lda,scomplex *x,const f77_int *incx) { - ctrsv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); + ctrsv_( uplo, trans, diag, n, a, lda, x, incx); } double DASUM(const f77_int *n,const double *dx,const f77_int *incx) @@ -735,17 +735,17 @@ double DDOT_(const f77_int *n,const double *dx,const f77_int *incx,const double void DGBMV(const char *trans,const f77_int *m,const f77_int *n,const f77_int *kl,const f77_int *ku,const double *alpha,const double *a,const f77_int *lda,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) { - dgbmv_blis_impl( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); + dgbmv_( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); } void dgbmv(const char *trans,const f77_int *m,const f77_int *n,const f77_int *kl,const f77_int *ku,const double *alpha,const double *a,const f77_int *lda,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) { - dgbmv_blis_impl( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); + dgbmv_( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); } void DGBMV_(const char *trans,const f77_int *m,const f77_int *n,const f77_int *kl,const f77_int *ku,const double *alpha,const double *a,const f77_int *lda,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) { - dgbmv_blis_impl( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); + dgbmv_( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); } void DGEMM(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *b,const f77_int *ldb,const double *beta,double *c,const f77_int *ldc) @@ -765,32 +765,32 @@ void DGEMM_(const char *transa,const char *transb,const f77_int *m,const f77 void DGEMV(const char *trans,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) { - dgemv_blis_impl( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); + dgemv_( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); } void dgemv(const char *trans,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) { - dgemv_blis_impl( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); + dgemv_( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); } void DGEMV_(const char *trans,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) { - dgemv_blis_impl( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); + dgemv_( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); } void DGER(const f77_int *m,const f77_int *n,const double *alpha,const double *x,const f77_int *incx,const double *y,const f77_int *incy,double *a,const f77_int *lda) { - dger_blis_impl( m, n, alpha, x, incx, y, incy, a, lda); + dger_( m, n, alpha, x, incx, y, incy, a, lda); } void dger(const f77_int *m,const f77_int *n,const double *alpha,const double *x,const f77_int 
*incx,const double *y,const f77_int *incy,double *a,const f77_int *lda) { - dger_blis_impl( m, n, alpha, x, incx, y, incy, a, lda); + dger_( m, n, alpha, x, incx, y, incy, a, lda); } void DGER_(const f77_int *m,const f77_int *n,const double *alpha,const double *x,const f77_int *incx,const double *y,const f77_int *incy,double *a,const f77_int *lda) { - dger_blis_impl( m, n, alpha, x, incx, y, incy, a, lda); + dger_( m, n, alpha, x, incx, y, incy, a, lda); } double DNRM2(const f77_int *n,const double *x,const f77_int *incx) @@ -870,17 +870,17 @@ void DROTMG_(double *dd1,double *dd2,double *dx1,const double *dy1,double *dpara void DSBMV(const char *uplo,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) { - dsbmv_blis_impl( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); + dsbmv_( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); } void dsbmv(const char *uplo,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) { - dsbmv_blis_impl( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); + dsbmv_( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); } void DSBMV_(const char *uplo,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) { - dsbmv_blis_impl( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); + dsbmv_( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); } void DSCAL(const f77_int *n,const double *da,double *dx,const f77_int *incx) @@ -915,47 +915,47 @@ double DSDOT_(const f77_int *n,const float *sx,const f77_int *incx,const float void DSPMV(const char *uplo,const f77_int *n,const double *alpha,const double *ap,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) { - dspmv_blis_impl( uplo, n, alpha, ap, x, incx, beta, y, incy); + dspmv_( uplo, n, alpha, ap, x, incx, beta, y, incy); } void dspmv(const char *uplo,const f77_int *n,const double *alpha,const double *ap,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) { - dspmv_blis_impl( uplo, n, alpha, ap, x, incx, beta, y, incy); + dspmv_( uplo, n, alpha, ap, x, incx, beta, y, incy); } void DSPMV_(const char *uplo,const f77_int *n,const double *alpha,const double *ap,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) { - dspmv_blis_impl( uplo, n, alpha, ap, x, incx, beta, y, incy); + dspmv_( uplo, n, alpha, ap, x, incx, beta, y, incy); } void DSPR(const char *uplo,const f77_int *n,const double *alpha,const double *x,const f77_int *incx,double *ap) { - dspr_blis_impl( uplo, n, alpha, x, incx, ap); + dspr_( uplo, n, alpha, x, incx, ap); } void dspr(const char *uplo,const f77_int *n,const double *alpha,const double *x,const f77_int *incx,double *ap) { - dspr_blis_impl( uplo, n, alpha, x, incx, ap); + dspr_( uplo, n, alpha, x, incx, ap); } void DSPR_(const char *uplo,const f77_int *n,const double *alpha,const double *x,const f77_int *incx,double *ap) { - dspr_blis_impl( uplo, n, alpha, x, incx, ap); + dspr_( uplo, n, alpha, x, incx, ap); } void DSPR2(const char *uplo,const f77_int *n,const double *alpha,const double *x,const f77_int *incx,const double *y,const f77_int *incy,double *ap) { - dspr2_blis_impl( uplo, n, alpha, x, incx, y, incy, ap); + dspr2_( uplo, n, alpha, x, 
incx, y, incy, ap); } void dspr2(const char *uplo,const f77_int *n,const double *alpha,const double *x,const f77_int *incx,const double *y,const f77_int *incy,double *ap) { - dspr2_blis_impl( uplo, n, alpha, x, incx, y, incy, ap); + dspr2_( uplo, n, alpha, x, incx, y, incy, ap); } void DSPR2_(const char *uplo,const f77_int *n,const double *alpha,const double *x,const f77_int *incx,const double *y,const f77_int *incy,double *ap) { - dspr2_blis_impl( uplo, n, alpha, x, incx, y, incy, ap); + dspr2_( uplo, n, alpha, x, incx, y, incy, ap); } void DSWAP(const f77_int *n,double *dx,const f77_int *incx,double *dy,const f77_int *incy) @@ -990,47 +990,47 @@ void DSYMM_(const char *side,const char *uplo,const f77_int *m,const f77_int void DSYMV(const char *uplo,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) { - dsymv_blis_impl( uplo, n, alpha, a, lda, x, incx, beta, y, incy); + dsymv_( uplo, n, alpha, a, lda, x, incx, beta, y, incy); } void dsymv(const char *uplo,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) { - dsymv_blis_impl( uplo, n, alpha, a, lda, x, incx, beta, y, incy); + dsymv_( uplo, n, alpha, a, lda, x, incx, beta, y, incy); } void DSYMV_(const char *uplo,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) { - dsymv_blis_impl( uplo, n, alpha, a, lda, x, incx, beta, y, incy); + dsymv_( uplo, n, alpha, a, lda, x, incx, beta, y, incy); } void DSYR(const char *uplo,const f77_int *n,const double *alpha,const double *x,const f77_int *incx,double *a,const f77_int *lda) { - dsyr_blis_impl( uplo, n, alpha, x, incx, a, lda); + dsyr_( uplo, n, alpha, x, incx, a, lda); } void dsyr(const char *uplo,const f77_int *n,const double *alpha,const double *x,const f77_int *incx,double *a,const f77_int *lda) { - dsyr_blis_impl( uplo, n, alpha, x, incx, a, lda); + dsyr_( uplo, n, alpha, x, incx, a, lda); } void DSYR_(const char *uplo,const f77_int *n,const double *alpha,const double *x,const f77_int *incx,double *a,const f77_int *lda) { - dsyr_blis_impl( uplo, n, alpha, x, incx, a, lda); + dsyr_( uplo, n, alpha, x, incx, a, lda); } void DSYR2(const char *uplo,const f77_int *n,const double *alpha,const double *x,const f77_int *incx,const double *y,const f77_int *incy,double *a,const f77_int *lda) { - dsyr2_blis_impl( uplo, n, alpha, x, incx, y, incy, a, lda); + dsyr2_( uplo, n, alpha, x, incx, y, incy, a, lda); } void dsyr2(const char *uplo,const f77_int *n,const double *alpha,const double *x,const f77_int *incx,const double *y,const f77_int *incy,double *a,const f77_int *lda) { - dsyr2_blis_impl( uplo, n, alpha, x, incx, y, incy, a, lda); + dsyr2_( uplo, n, alpha, x, incx, y, incy, a, lda); } void DSYR2_(const char *uplo,const f77_int *n,const double *alpha,const double *x,const f77_int *incx,const double *y,const f77_int *incy,double *a,const f77_int *lda) { - dsyr2_blis_impl( uplo, n, alpha, x, incx, y, incy, a, lda); + dsyr2_( uplo, n, alpha, x, incx, y, incy, a, lda); } void DSYR2K(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *b,const f77_int *ldb,const double *beta,double *c,const f77_int *ldc) @@ -1065,62 +1065,62 @@ void DSYRK_(const char *uplo,const char *trans,const f77_int *n,const f77_in 
void DTBMV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const double *a,const f77_int *lda,double *x,const f77_int *incx) { - dtbmv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); + dtbmv_( uplo, trans, diag, n, k, a, lda, x, incx); } void dtbmv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const double *a,const f77_int *lda,double *x,const f77_int *incx) { - dtbmv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); + dtbmv_( uplo, trans, diag, n, k, a, lda, x, incx); } void DTBMV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const double *a,const f77_int *lda,double *x,const f77_int *incx) { - dtbmv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); + dtbmv_( uplo, trans, diag, n, k, a, lda, x, incx); } void DTBSV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const double *a,const f77_int *lda,double *x,const f77_int *incx) { - dtbsv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); + dtbsv_( uplo, trans, diag, n, k, a, lda, x, incx); } void dtbsv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const double *a,const f77_int *lda,double *x,const f77_int *incx) { - dtbsv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); + dtbsv_( uplo, trans, diag, n, k, a, lda, x, incx); } void DTBSV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const double *a,const f77_int *lda,double *x,const f77_int *incx) { - dtbsv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); + dtbsv_( uplo, trans, diag, n, k, a, lda, x, incx); } void DTPMV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const double *ap,double *x,const f77_int *incx) { - dtpmv_blis_impl( uplo, trans, diag, n, ap, x, incx); + dtpmv_( uplo, trans, diag, n, ap, x, incx); } void dtpmv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const double *ap,double *x,const f77_int *incx) { - dtpmv_blis_impl( uplo, trans, diag, n, ap, x, incx); + dtpmv_( uplo, trans, diag, n, ap, x, incx); } void DTPMV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const double *ap,double *x,const f77_int *incx) { - dtpmv_blis_impl( uplo, trans, diag, n, ap, x, incx); + dtpmv_( uplo, trans, diag, n, ap, x, incx); } void DTPSV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const double *ap,double *x,const f77_int *incx) { - dtpsv_blis_impl( uplo, trans, diag, n, ap, x, incx); + dtpsv_( uplo, trans, diag, n, ap, x, incx); } void dtpsv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const double *ap,double *x,const f77_int *incx) { - dtpsv_blis_impl( uplo, trans, diag, n, ap, x, incx); + dtpsv_( uplo, trans, diag, n, ap, x, incx); } void DTPSV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const double *ap,double *x,const f77_int *incx) { - dtpsv_blis_impl( uplo, trans, diag, n, ap, x, incx); + dtpsv_( uplo, trans, diag, n, ap, x, incx); } void DTRMM(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,double *b,const f77_int *ldb) @@ -1140,17 +1140,17 @@ void DTRMM_(const char *side,const char *uplo,const char *transa,const cha void DTRMV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const double *a,const f77_int *lda,double *x,const f77_int *incx) { - 
dtrmv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); + dtrmv_( uplo, trans, diag, n, a, lda, x, incx); } void dtrmv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const double *a,const f77_int *lda,double *x,const f77_int *incx) { - dtrmv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); + dtrmv_( uplo, trans, diag, n, a, lda, x, incx); } void DTRMV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const double *a,const f77_int *lda,double *x,const f77_int *incx) { - dtrmv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); + dtrmv_( uplo, trans, diag, n, a, lda, x, incx); } void DTRSM(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,double *b,const f77_int *ldb) @@ -1170,17 +1170,17 @@ void DTRSM_(const char *side,const char *uplo,const char *transa,const cha void DTRSV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const double *a,const f77_int *lda,double *x,const f77_int *incx) { - dtrsv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); + dtrsv_( uplo, trans, diag, n, a, lda, x, incx); } void dtrsv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const double *a,const f77_int *lda,double *x,const f77_int *incx) { - dtrsv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); + dtrsv_( uplo, trans, diag, n, a, lda, x, incx); } void DTRSV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const double *a,const f77_int *lda,double *x,const f77_int *incx) { - dtrsv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); + dtrsv_( uplo, trans, diag, n, a, lda, x, incx); } double DZASUM(const f77_int *n,const dcomplex *zx,const f77_int *incx) @@ -1402,17 +1402,17 @@ float SDSDOT_(const f77_int *n,const float *sb, const float *sx, const f77_int void SGBMV(const char *trans,const f77_int *m,const f77_int *n,const f77_int *kl,const f77_int *ku,const float *alpha,const float *a,const f77_int *lda,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) { - sgbmv_blis_impl( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); + sgbmv_( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); } void sgbmv(const char *trans,const f77_int *m,const f77_int *n,const f77_int *kl,const f77_int *ku,const float *alpha,const float *a,const f77_int *lda,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) { - sgbmv_blis_impl( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); + sgbmv_( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); } void SGBMV_(const char *trans,const f77_int *m,const f77_int *n,const f77_int *kl,const f77_int *ku,const float *alpha,const float *a,const f77_int *lda,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) { - sgbmv_blis_impl( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); + sgbmv_( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); } void SGEMM(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *b,const f77_int *ldb,const float *beta,float *c,const f77_int *ldc) @@ -1432,32 +1432,32 @@ void SGEMM_(const char *transa,const char *transb,const f77_int *m,const f77 void SGEMV(const char *trans,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,const float *x,const f77_int *incx,const float 
*beta,float *y,const f77_int *incy) { - sgemv_blis_impl( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); + sgemv_( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); } void sgemv(const char *trans,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) { - sgemv_blis_impl( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); + sgemv_( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); } void SGEMV_(const char *trans,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) { - sgemv_blis_impl( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); + sgemv_( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); } void SGER(const f77_int *m,const f77_int *n,const float *alpha,const float *x,const f77_int *incx,const float *y,const f77_int *incy,float *a,const f77_int *lda) { - sger_blis_impl( m, n, alpha, x, incx, y, incy, a, lda); + sger_( m, n, alpha, x, incx, y, incy, a, lda); } void sger(const f77_int *m,const f77_int *n,const float *alpha,const float *x,const f77_int *incx,const float *y,const f77_int *incy,float *a,const f77_int *lda) { - sger_blis_impl( m, n, alpha, x, incx, y, incy, a, lda); + sger_( m, n, alpha, x, incx, y, incy, a, lda); } void SGER_(const f77_int *m,const f77_int *n,const float *alpha,const float *x,const f77_int *incx,const float *y,const f77_int *incy,float *a,const f77_int *lda) { - sger_blis_impl( m, n, alpha, x, incx, y, incy, a, lda); + sger_( m, n, alpha, x, incx, y, incy, a, lda); } @@ -1539,17 +1539,17 @@ void SROTMG_(float *sd1,float *sd2,float *sx1,const float *sy1,float *spara void SSBMV(const char *uplo,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) { - ssbmv_blis_impl( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); + ssbmv_( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); } void ssbmv(const char *uplo,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) { - ssbmv_blis_impl( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); + ssbmv_( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); } void SSBMV_(const char *uplo,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) { - ssbmv_blis_impl( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); + ssbmv_( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); } void SSCAL(const f77_int *n,const float *sa,float *sx,const f77_int *incx) @@ -1569,47 +1569,47 @@ void SSCAL_(const f77_int *n,const float *sa,float *sx,const f77_int *incx) void SSPMV(const char *uplo,const f77_int *n,const float *alpha,const float *ap,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) { - sspmv_blis_impl( uplo, n, alpha, ap, x, incx, beta, y, incy); + sspmv_( uplo, n, alpha, ap, x, incx, beta, y, incy); } void sspmv(const char *uplo,const f77_int *n,const float *alpha,const float *ap,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) { - sspmv_blis_impl( uplo, n, alpha, ap, x, incx, beta, y, incy); + sspmv_( uplo, n, alpha, ap, x, incx, beta, y, incy); } void SSPMV_(const char 
*uplo,const f77_int *n,const float *alpha,const float *ap,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) { - sspmv_blis_impl( uplo, n, alpha, ap, x, incx, beta, y, incy); + sspmv_( uplo, n, alpha, ap, x, incx, beta, y, incy); } void SSPR(const char *uplo,const f77_int *n,const float *alpha,const float *x,const f77_int *incx,float *ap) { - sspr_blis_impl( uplo, n, alpha, x, incx, ap); + sspr_( uplo, n, alpha, x, incx, ap); } void sspr(const char *uplo,const f77_int *n,const float *alpha,const float *x,const f77_int *incx,float *ap) { - sspr_blis_impl( uplo, n, alpha, x, incx, ap); + sspr_( uplo, n, alpha, x, incx, ap); } void SSPR_(const char *uplo,const f77_int *n,const float *alpha,const float *x,const f77_int *incx,float *ap) { - sspr_blis_impl( uplo, n, alpha, x, incx, ap); + sspr_( uplo, n, alpha, x, incx, ap); } void SSPR2(const char *uplo,const f77_int *n,const float *alpha,const float *x,const f77_int *incx,const float *y,const f77_int *incy,float *ap) { - sspr2_blis_impl( uplo, n, alpha, x, incx, y, incy, ap); + sspr2_( uplo, n, alpha, x, incx, y, incy, ap); } void sspr2(const char *uplo,const f77_int *n,const float *alpha,const float *x,const f77_int *incx,const float *y,const f77_int *incy,float *ap) { - sspr2_blis_impl( uplo, n, alpha, x, incx, y, incy, ap); + sspr2_( uplo, n, alpha, x, incx, y, incy, ap); } void SSPR2_(const char *uplo,const f77_int *n,const float *alpha,const float *x,const f77_int *incx,const float *y,const f77_int *incy,float *ap) { - sspr2_blis_impl( uplo, n, alpha, x, incx, y, incy, ap); + sspr2_( uplo, n, alpha, x, incx, y, incy, ap); } void SSWAP(const f77_int *n,float *sx,const f77_int *incx,float *sy,const f77_int *incy) @@ -1644,47 +1644,47 @@ void SSYMM_(const char *side,const char *uplo,const f77_int *m,const f77_int void SSYMV(const char *uplo,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) { - ssymv_blis_impl( uplo, n, alpha, a, lda, x, incx, beta, y, incy); + ssymv_( uplo, n, alpha, a, lda, x, incx, beta, y, incy); } void ssymv(const char *uplo,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) { - ssymv_blis_impl( uplo, n, alpha, a, lda, x, incx, beta, y, incy); + ssymv_( uplo, n, alpha, a, lda, x, incx, beta, y, incy); } void SSYMV_(const char *uplo,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) { - ssymv_blis_impl( uplo, n, alpha, a, lda, x, incx, beta, y, incy); + ssymv_( uplo, n, alpha, a, lda, x, incx, beta, y, incy); } void SSYR(const char *uplo,const f77_int *n,const float *alpha,const float *x,const f77_int *incx,float *a,const f77_int *lda) { - ssyr_blis_impl( uplo, n, alpha, x, incx, a, lda); + ssyr_( uplo, n, alpha, x, incx, a, lda); } void ssyr(const char *uplo,const f77_int *n,const float *alpha,const float *x,const f77_int *incx,float *a,const f77_int *lda) { - ssyr_blis_impl( uplo, n, alpha, x, incx, a, lda); + ssyr_( uplo, n, alpha, x, incx, a, lda); } void SSYR_(const char *uplo,const f77_int *n,const float *alpha,const float *x,const f77_int *incx,float *a,const f77_int *lda) { - ssyr_blis_impl( uplo, n, alpha, x, incx, a, lda); + ssyr_( uplo, n, alpha, x, incx, a, lda); } void SSYR2(const char *uplo,const f77_int *n,const float *alpha,const float *x,const f77_int *incx,const float 
*y,const f77_int *incy,float *a,const f77_int *lda) { - ssyr2_blis_impl( uplo, n, alpha, x, incx, y, incy, a, lda); + ssyr2_( uplo, n, alpha, x, incx, y, incy, a, lda); } void ssyr2(const char *uplo,const f77_int *n,const float *alpha,const float *x,const f77_int *incx,const float *y,const f77_int *incy,float *a,const f77_int *lda) { - ssyr2_blis_impl( uplo, n, alpha, x, incx, y, incy, a, lda); + ssyr2_( uplo, n, alpha, x, incx, y, incy, a, lda); } void SSYR2_(const char *uplo,const f77_int *n,const float *alpha,const float *x,const f77_int *incx,const float *y,const f77_int *incy,float *a,const f77_int *lda) { - ssyr2_blis_impl( uplo, n, alpha, x, incx, y, incy, a, lda); + ssyr2_( uplo, n, alpha, x, incx, y, incy, a, lda); } void SSYR2K(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *b,const f77_int *ldb,const float *beta,float *c,const f77_int *ldc) @@ -1719,62 +1719,62 @@ void SSYRK_(const char *uplo,const char *trans,const f77_int *n,const f77_in void STBMV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const float *a,const f77_int *lda,float *x,const f77_int *incx) { - stbmv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); + stbmv_( uplo, trans, diag, n, k, a, lda, x, incx); } void stbmv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const float *a,const f77_int *lda,float *x,const f77_int *incx) { - stbmv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); + stbmv_( uplo, trans, diag, n, k, a, lda, x, incx); } void STBMV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const float *a,const f77_int *lda,float *x,const f77_int *incx) { - stbmv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); + stbmv_( uplo, trans, diag, n, k, a, lda, x, incx); } void STBSV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const float *a,const f77_int *lda,float *x,const f77_int *incx) { - stbsv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); + stbsv_( uplo, trans, diag, n, k, a, lda, x, incx); } void stbsv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const float *a,const f77_int *lda,float *x,const f77_int *incx) { - stbsv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); + stbsv_( uplo, trans, diag, n, k, a, lda, x, incx); } void STBSV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const float *a,const f77_int *lda,float *x,const f77_int *incx) { - stbsv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); + stbsv_( uplo, trans, diag, n, k, a, lda, x, incx); } void STPMV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const float *ap,float *x,const f77_int *incx) { - stpmv_blis_impl( uplo, trans, diag, n, ap, x, incx); + stpmv_( uplo, trans, diag, n, ap, x, incx); } void stpmv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const float *ap,float *x,const f77_int *incx) { - stpmv_blis_impl( uplo, trans, diag, n, ap, x, incx); + stpmv_( uplo, trans, diag, n, ap, x, incx); } void STPMV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const float *ap,float *x,const f77_int *incx) { - stpmv_blis_impl( uplo, trans, diag, n, ap, x, incx); + stpmv_( uplo, trans, diag, n, ap, x, incx); } void STPSV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const float *ap,float *x,const f77_int *incx) 
{ - stpsv_blis_impl( uplo, trans, diag, n, ap, x, incx); + stpsv_( uplo, trans, diag, n, ap, x, incx); } void stpsv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const float *ap,float *x,const f77_int *incx) { - stpsv_blis_impl( uplo, trans, diag, n, ap, x, incx); + stpsv_( uplo, trans, diag, n, ap, x, incx); } void STPSV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const float *ap,float *x,const f77_int *incx) { - stpsv_blis_impl( uplo, trans, diag, n, ap, x, incx); + stpsv_( uplo, trans, diag, n, ap, x, incx); } void STRMM(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,float *b,const f77_int *ldb) @@ -1794,17 +1794,17 @@ void STRMM_(const char *side,const char *uplo,const char *transa,const cha void STRMV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const float *a,const f77_int *lda,float *x,const f77_int *incx) { - strmv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); + strmv_( uplo, trans, diag, n, a, lda, x, incx); } void strmv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const float *a,const f77_int *lda,float *x,const f77_int *incx) { - strmv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); + strmv_( uplo, trans, diag, n, a, lda, x, incx); } void STRMV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const float *a,const f77_int *lda,float *x,const f77_int *incx) { - strmv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); + strmv_( uplo, trans, diag, n, a, lda, x, incx); } void STRSM(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,float *b,const f77_int *ldb) @@ -1824,17 +1824,17 @@ void STRSM_(const char *side,const char *uplo,const char *transa,const cha void STRSV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const float *a,const f77_int *lda,float *x,const f77_int *incx) { - strsv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); + strsv_( uplo, trans, diag, n, a, lda, x, incx); } void strsv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const float *a,const f77_int *lda,float *x,const f77_int *incx) { - strsv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); + strsv_( uplo, trans, diag, n, a, lda, x, incx); } void STRSV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const float *a,const f77_int *lda,float *x,const f77_int *incx) { - strsv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); + strsv_( uplo, trans, diag, n, a, lda, x, incx); } int XERBLA(const char *srname,const f77_int *info, ftnlen n) @@ -1914,17 +1914,17 @@ void ZDSCAL_(const f77_int *n,const double *da,dcomplex *zx,const f77_int *incx) void ZGBMV(const char *trans,const f77_int *m,const f77_int *n,const f77_int *kl,const f77_int *ku,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) { - zgbmv_blis_impl( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); + zgbmv_( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); } void zgbmv(const char *trans,const f77_int *m,const f77_int *n,const f77_int *kl,const f77_int *ku,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) { - zgbmv_blis_impl( trans, 
m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); + zgbmv_( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); } void ZGBMV_(const char *trans,const f77_int *m,const f77_int *n,const f77_int *kl,const f77_int *ku,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) { - zgbmv_blis_impl( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); + zgbmv_( trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, incy); } void ZGEMM(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) @@ -1944,62 +1944,62 @@ void ZGEMM_(const char *transa,const char *transb,const f77_int *m,const f77 void ZGEMV(const char *trans,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) { - zgemv_blis_impl( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); + zgemv_( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); } void zgemv(const char *trans,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) { - zgemv_blis_impl( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); + zgemv_( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); } void ZGEMV_(const char *trans,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) { - zgemv_blis_impl( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); + zgemv_( trans, m, n, alpha, a, lda, x, incx, beta, y, incy); } void ZGERC(const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *x,const f77_int *incx,const dcomplex *y,const f77_int *incy,dcomplex *a,const f77_int *lda) { - zgerc_blis_impl( m, n, alpha, x, incx, y, incy, a, lda); + zgerc_( m, n, alpha, x, incx, y, incy, a, lda); } void zgerc(const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *x,const f77_int *incx,const dcomplex *y,const f77_int *incy,dcomplex *a,const f77_int *lda) { - zgerc_blis_impl( m, n, alpha, x, incx, y, incy, a, lda); + zgerc_( m, n, alpha, x, incx, y, incy, a, lda); } void ZGERC_(const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *x,const f77_int *incx,const dcomplex *y,const f77_int *incy,dcomplex *a,const f77_int *lda) { - zgerc_blis_impl( m, n, alpha, x, incx, y, incy, a, lda); + zgerc_( m, n, alpha, x, incx, y, incy, a, lda); } void ZGERU(const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *x,const f77_int *incx,const dcomplex *y,const f77_int *incy,dcomplex *a,const f77_int *lda) { - zgeru_blis_impl( m, n, alpha, x, incx, y, incy, a, lda); + zgeru_( m, n, alpha, x, incx, y, incy, a, lda); } void zgeru(const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *x,const f77_int *incx,const dcomplex *y,const f77_int *incy,dcomplex *a,const f77_int *lda) { - zgeru_blis_impl( m, n, alpha, x, incx, y, incy, a, lda); + zgeru_( m, n, alpha, x, incx, y, incy, a, lda); } void ZGERU_(const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *x,const f77_int *incx,const dcomplex *y,const f77_int *incy,dcomplex *a,const f77_int 
*lda) { - zgeru_blis_impl( m, n, alpha, x, incx, y, incy, a, lda); + zgeru_( m, n, alpha, x, incx, y, incy, a, lda); } void ZHBMV(const char *uplo,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) { - zhbmv_blis_impl( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); + zhbmv_( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); } void zhbmv(const char *uplo,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) { - zhbmv_blis_impl( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); + zhbmv_( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); } void ZHBMV_(const char *uplo,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) { - zhbmv_blis_impl( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); + zhbmv_( uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); } void ZHEMM(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) @@ -2019,47 +2019,47 @@ void ZHEMM_(const char *side,const char *uplo,const f77_int *m,const f77_int void ZHEMV(const char *uplo,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) { - zhemv_blis_impl( uplo, n, alpha, a, lda, x, incx, beta, y, incy); + zhemv_( uplo, n, alpha, a, lda, x, incx, beta, y, incy); } void zhemv(const char *uplo,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) { - zhemv_blis_impl( uplo, n, alpha, a, lda, x, incx, beta, y, incy); + zhemv_( uplo, n, alpha, a, lda, x, incx, beta, y, incy); } void ZHEMV_(const char *uplo,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) { - zhemv_blis_impl( uplo, n, alpha, a, lda, x, incx, beta, y, incy); + zhemv_( uplo, n, alpha, a, lda, x, incx, beta, y, incy); } void ZHER(const char *uplo,const f77_int *n,const double *alpha,const dcomplex *x,const f77_int *incx,dcomplex *a,const f77_int *lda) { - zher_blis_impl( uplo, n, alpha, x, incx, a, lda); + zher_( uplo, n, alpha, x, incx, a, lda); } void zher(const char *uplo,const f77_int *n,const double *alpha,const dcomplex *x,const f77_int *incx,dcomplex *a,const f77_int *lda) { - zher_blis_impl( uplo, n, alpha, x, incx, a, lda); + zher_( uplo, n, alpha, x, incx, a, lda); } void ZHER_(const char *uplo,const f77_int *n,const double *alpha,const dcomplex *x,const f77_int *incx,dcomplex *a,const f77_int *lda) { - zher_blis_impl( uplo, n, alpha, x, incx, a, lda); + zher_( uplo, n, alpha, x, incx, a, lda); } void ZHER2(const char *uplo,const f77_int *n,const dcomplex *alpha,const dcomplex *x,const f77_int *incx,const dcomplex *y,const f77_int *incy,dcomplex *a,const f77_int *lda) { - zher2_blis_impl( uplo, n, alpha, x, incx, y, incy, a, lda); + zher2_( uplo, n, alpha, x, incx, y, incy, a, lda); } void zher2(const char *uplo,const f77_int *n,const dcomplex *alpha,const 
dcomplex *x,const f77_int *incx,const dcomplex *y,const f77_int *incy,dcomplex *a,const f77_int *lda) { - zher2_blis_impl( uplo, n, alpha, x, incx, y, incy, a, lda); + zher2_( uplo, n, alpha, x, incx, y, incy, a, lda); } void ZHER2_(const char *uplo,const f77_int *n,const dcomplex *alpha,const dcomplex *x,const f77_int *incx,const dcomplex *y,const f77_int *incy,dcomplex *a,const f77_int *lda) { - zher2_blis_impl( uplo, n, alpha, x, incx, y, incy, a, lda); + zher2_( uplo, n, alpha, x, incx, y, incy, a, lda); } void ZHER2K(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const double *beta,dcomplex *c,const f77_int *ldc) @@ -2094,47 +2094,47 @@ void ZHERK_(const char *uplo,const char *trans,const f77_int *n,const f77_in void ZHPMV(const char *uplo,const f77_int *n,const dcomplex *alpha,const dcomplex *ap,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) { - zhpmv_blis_impl( uplo, n, alpha, ap, x, incx, beta, y, incy); + zhpmv_( uplo, n, alpha, ap, x, incx, beta, y, incy); } void zhpmv(const char *uplo,const f77_int *n,const dcomplex *alpha,const dcomplex *ap,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) { - zhpmv_blis_impl( uplo, n, alpha, ap, x, incx, beta, y, incy); + zhpmv_( uplo, n, alpha, ap, x, incx, beta, y, incy); } void ZHPMV_(const char *uplo,const f77_int *n,const dcomplex *alpha,const dcomplex *ap,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) { - zhpmv_blis_impl( uplo, n, alpha, ap, x, incx, beta, y, incy); + zhpmv_( uplo, n, alpha, ap, x, incx, beta, y, incy); } void ZHPR(const char *uplo,const f77_int *n,const bla_double *alpha,const dcomplex *x,const f77_int *incx,dcomplex *ap) { - zhpr_blis_impl( uplo, n, alpha, x, incx, ap); + zhpr_( uplo, n, alpha, x, incx, ap); } void zhpr(const char *uplo,const f77_int *n,const bla_double *alpha,const dcomplex *x,const f77_int *incx,dcomplex *ap) { - zhpr_blis_impl( uplo, n, alpha, x, incx, ap); + zhpr_( uplo, n, alpha, x, incx, ap); } void ZHPR_(const char *uplo,const f77_int *n,const bla_double *alpha,const dcomplex *x,const f77_int *incx,dcomplex *ap) { - zhpr_blis_impl( uplo, n, alpha, x, incx, ap); + zhpr_( uplo, n, alpha, x, incx, ap); } void ZHPR2(const char *uplo,const f77_int *n,const dcomplex *alpha,const dcomplex *x,const f77_int *incx,const dcomplex *y,const f77_int *incy,dcomplex *ap) { - zhpr2_blis_impl( uplo, n, alpha, x, incx, y, incy, ap); + zhpr2_( uplo, n, alpha, x, incx, y, incy, ap); } void zhpr2(const char *uplo,const f77_int *n,const dcomplex *alpha,const dcomplex *x,const f77_int *incx,const dcomplex *y,const f77_int *incy,dcomplex *ap) { - zhpr2_blis_impl( uplo, n, alpha, x, incx, y, incy, ap); + zhpr2_( uplo, n, alpha, x, incx, y, incy, ap); } void ZHPR2_(const char *uplo,const f77_int *n,const dcomplex *alpha,const dcomplex *x,const f77_int *incx,const dcomplex *y,const f77_int *incy,dcomplex *ap) { - zhpr2_blis_impl( uplo, n, alpha, x, incx, y, incy, ap); + zhpr2_( uplo, n, alpha, x, incx, y, incy, ap); } void ZROTG(dcomplex *ca,bla_dcomplex *cb,bla_double *c,dcomplex *s) @@ -2229,62 +2229,62 @@ void ZSYRK_(const char *uplo,const char *trans,const f77_int *n,const f77_in void ZTBMV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const dcomplex *a,const f77_int *lda,dcomplex *x,const f77_int *incx) { - ztbmv_blis_impl( uplo, 
trans, diag, n, k, a, lda, x, incx); + ztbmv_( uplo, trans, diag, n, k, a, lda, x, incx); } void ztbmv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const dcomplex *a,const f77_int *lda,dcomplex *x,const f77_int *incx) { - ztbmv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); + ztbmv_( uplo, trans, diag, n, k, a, lda, x, incx); } void ZTBMV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const dcomplex *a,const f77_int *lda,dcomplex *x,const f77_int *incx) { - ztbmv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); + ztbmv_( uplo, trans, diag, n, k, a, lda, x, incx); } void ZTBSV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const dcomplex *a,const f77_int *lda,dcomplex *x,const f77_int *incx) { - ztbsv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); + ztbsv_( uplo, trans, diag, n, k, a, lda, x, incx); } void ztbsv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const dcomplex *a,const f77_int *lda,dcomplex *x,const f77_int *incx) { - ztbsv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); + ztbsv_( uplo, trans, diag, n, k, a, lda, x, incx); } void ZTBSV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const dcomplex *a,const f77_int *lda,dcomplex *x,const f77_int *incx) { - ztbsv_blis_impl( uplo, trans, diag, n, k, a, lda, x, incx); + ztbsv_( uplo, trans, diag, n, k, a, lda, x, incx); } void ZTPMV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const dcomplex *ap,dcomplex *x,const f77_int *incx) { - ztpmv_blis_impl( uplo, trans, diag, n, ap, x, incx); + ztpmv_( uplo, trans, diag, n, ap, x, incx); } void ztpmv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const dcomplex *ap,dcomplex *x,const f77_int *incx) { - ztpmv_blis_impl( uplo, trans, diag, n, ap, x, incx); + ztpmv_( uplo, trans, diag, n, ap, x, incx); } void ZTPMV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const dcomplex *ap,dcomplex *x,const f77_int *incx) { - ztpmv_blis_impl( uplo, trans, diag, n, ap, x, incx); + ztpmv_( uplo, trans, diag, n, ap, x, incx); } void ZTPSV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const dcomplex *ap,dcomplex *x,const f77_int *incx) { - ztpsv_blis_impl( uplo, trans, diag, n, ap, x, incx); + ztpsv_( uplo, trans, diag, n, ap, x, incx); } void ztpsv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const dcomplex *ap,dcomplex *x,const f77_int *incx) { - ztpsv_blis_impl( uplo, trans, diag, n, ap, x, incx); + ztpsv_( uplo, trans, diag, n, ap, x, incx); } void ZTPSV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const dcomplex *ap,dcomplex *x,const f77_int *incx) { - ztpsv_blis_impl( uplo, trans, diag, n, ap, x, incx); + ztpsv_( uplo, trans, diag, n, ap, x, incx); } void ZTRMM(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,dcomplex *b,const f77_int *ldb) @@ -2304,17 +2304,17 @@ void ZTRMM_(const char *side,const char *uplo,const char *transa,const cha void ZTRMV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const dcomplex *a,const f77_int *lda,dcomplex *x,const f77_int *incx) { - ztrmv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); + ztrmv_( uplo, trans, diag, n, a, lda, x, incx); } void ztrmv(const char *uplo,const char 
*trans,const char *diag,const f77_int *n,const dcomplex *a,const f77_int *lda,dcomplex *x,const f77_int *incx) { - ztrmv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); + ztrmv_( uplo, trans, diag, n, a, lda, x, incx); } void ZTRMV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const dcomplex *a,const f77_int *lda,dcomplex *x,const f77_int *incx) { - ztrmv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); + ztrmv_( uplo, trans, diag, n, a, lda, x, incx); } void ZTRSM(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,dcomplex *b,const f77_int *ldb) @@ -2334,17 +2334,17 @@ void ZTRSM_(const char *side,const char *uplo,const char *transa,const cha void ZTRSV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const dcomplex *a,const f77_int *lda,dcomplex *x,const f77_int *incx) { - ztrsv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); + ztrsv_( uplo, trans, diag, n, a, lda, x, incx); } void ztrsv(const char *uplo,const char *trans,const char *diag,const f77_int *n,const dcomplex *a,const f77_int *lda,dcomplex *x,const f77_int *incx) { - ztrsv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); + ztrsv_( uplo, trans, diag, n, a, lda, x, incx); } void ZTRSV_(const char *uplo,const char *trans,const char *diag,const f77_int *n,const dcomplex *a,const f77_int *lda,dcomplex *x,const f77_int *incx) { - ztrsv_blis_impl( uplo, trans, diag, n, a, lda, x, incx); + ztrsv_( uplo, trans, diag, n, a, lda, x, incx); } From a7d1cc736953b3992d20640e5cddb1ad70109058 Mon Sep 17 00:00:00 2001 From: mkadavil Date: Mon, 29 Aug 2022 11:54:38 +0530 Subject: [PATCH 207/243] Multi-Threading support for BFloat16 gemm. -OpenMP based multi-threading support added for BFloat16 gemm. Both gemm and reorder api's are parallelized. -Multi-threading support for u8s8s16 reorder api. -Typecast issues fixed for bfloat16 gemm kernels. AMD-Internal: [CPUPL-2459] Change-Id: I6502d71ab32aa73bb159245976ea3d3a8e0ed109 --- .../aocl_gemm/frame/bf16bf16f32/lpgemm_bf16.c | 251 +++++++++--- .../frame/bf16bf16f32/lpgemm_reorder_bf16.c | 145 +++++-- .../threading/lpgemm_thread_decor_openmp.h | 2 +- .../frame/u8s8s16/lpgemm_reorder_s16.c | 177 +++++--- .../frame/u8s8s16/lpgemm_reorder_s16.h | 10 +- addon/aocl_gemm/frame/u8s8s32/lpgemm_utils.h | 2 +- .../lpgemm_6x64rowmajor_bf16_amd512vnni.c | 40 +- .../lpgemm_m_fringe_bf16_amd512vnni.c | 140 +++---- .../lpgemm_mn_fringe_bf16_amd512vnni.c | 382 +++++++++--------- .../lpgemm_n_fringe_bf16_amd512vnni.c | 124 +++--- 10 files changed, 785 insertions(+), 488 deletions(-) diff --git a/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_bf16.c b/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_bf16.c index 3a1f473d58..e03ec87421 100644 --- a/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_bf16.c +++ b/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_bf16.c @@ -43,47 +43,72 @@ // B should always be packed. 
LPGEMM_5LOOP(bfloat16,bfloat16,float,bf16bf16f32of32) { - dim_t NC = lpgemm_get_block_size_NC_global_cntx( BF16BF16F32OF32 ); - dim_t KC = lpgemm_get_block_size_KC_global_cntx( BF16BF16F32OF32 ); - dim_t MC = lpgemm_get_block_size_MC_global_cntx( BF16BF16F32OF32 ); - dim_t NR = lpgemm_get_block_size_NR_global_cntx( BF16BF16F32OF32 ); - - const int16_t* a_use = NULL; - dim_t cs_a_use = cs_a; - dim_t a_block_stride = 0; - - const int16_t* b_use = NULL; - dim_t rs_b_use = rs_b; - dim_t cs_b_use = cs_b; - - float* c_use_jc = NULL; - float* c_use_ic = NULL; - - // kc needs to be a multiple of 2 so that it can be used with dpbf16_ps - // instruction. Padding is added in cases this condition is not - // satisfied, and therefore the k offset used for packed/reordered - // buffer needs to be updated. - dim_t k_updated = k; - k_updated += (k_updated & 0x1); - - // Is required to decide whether to apply post ops or not. - bool is_last_k = FALSE; - - for ( dim_t jc = 0; jc < n; jc += NC ) - { - dim_t nc0 = ( ( jc + NC ) <= n ) ? NC : ( n % NC ); - - dim_t nc0_mod16 = nc0 % 16; - dim_t nc0_updated = nc0; - if ( nc0_mod16 != 0 ) + dim_t NC = lpgemm_get_block_size_NC_global_cntx( BF16BF16F32OF32 ); + dim_t KC = lpgemm_get_block_size_KC_global_cntx( BF16BF16F32OF32 ); + dim_t MC = lpgemm_get_block_size_MC_global_cntx( BF16BF16F32OF32 ); + dim_t NR = lpgemm_get_block_size_NR_global_cntx( BF16BF16F32OF32 ); + dim_t MR = lpgemm_get_block_size_MR_global_cntx( BF16BF16F32OF32 ); + + const int16_t* a_use = NULL; + dim_t cs_a_use = cs_a; + dim_t a_block_stride = 0; + + const int16_t* b_use = NULL; + dim_t rs_b_use = rs_b; + dim_t cs_b_use = cs_b; + + float* c_use_jc = NULL; + float* c_use_ic = NULL; + + // Pack buffer for B. + bfloat16* pack_b_buffer_bf16; + mem_t mem_b = BLIS_MEM_INITIALIZER; + siz_t mem_b_size_req = 0; + dim_t packb_min_NR = 16; + + // kc needs to be a multiple of 2 so that it can be used with dpbf16_ps + // instruction. Padding is added in cases this condition is not + // satisfied, and therefore the k offset used for packed/reordered + // buffer needs to be updated. + dim_t k_updated = k; + k_updated += (k_updated & 0x1); + + // Is required to decide whether to apply post ops or not. + bool is_last_k = FALSE; + + // Generate thrinfo objects for jc and ic loops from lpgemm_thrinfo_t. + thrinfo_t thread_jc; + thrinfo_t thread_ic; + + lpgemm_gen_thrinfo( thread, &thread_jc, &thread_ic ); + + // Compute the JC loop thread range for the current thread. + dim_t jc_start, jc_end; + bli_thread_range_sub( &thread_jc, n, NR, FALSE, &jc_start, &jc_end ); + + for ( dim_t jc = jc_start; jc < jc_end; jc += NC ) + { + dim_t nc0 = bli_min( ( jc_end - jc ), NC ); + c_use_jc = c + jc; + + dim_t jc_cur_loop = jc; + dim_t jc_cur_loop_rem = 0; + dim_t n_sub_updated; + + if ( mtag_b == REORDERED ) { - nc0_updated += ( 16 - nc0_mod16 ); + get_B_panel_reordered_start_offset_width + ( + jc, n, NC, packb_min_NR, + &jc_cur_loop, &jc_cur_loop_rem, + &nc0, &n_sub_updated + ); } for ( dim_t pc = 0; pc < k; pc += KC ) { float beta0 = ( pc == 0 ) ? beta : 1; - dim_t kc0 = ( ( pc + KC ) <= k ) ? KC : ( k % KC ); + dim_t kc0 = bli_min( ( k - pc ), KC ); // kc0 needs to be a multiple of 2 so that it can be // used with dpbf16_ps instruction. Padding is added in @@ -93,43 +118,159 @@ LPGEMM_5LOOP(bfloat16,bfloat16,float,bf16bf16f32of32) dim_t kc0_updated = kc0; kc0_updated += (kc0_updated & 0x1); - is_last_k = ( ( pc + KC ) >= k ) ? ( TRUE ) : ( FALSE ); + is_last_k = ( ( pc + KC ) >= k ) ? 
( TRUE ) : ( FALSE ); + + if ( mtag_b == PACK ) + { + // Pack B chunks are based on jc work id. + dim_t jc_work_id = bli_thread_work_id( &thread_jc ); + + // Using child thrinfo (thread_ic) tid to decide chief thread + // per B matrix chunk (jc work id group) + if ( bli_thread_am_ochief( &thread_ic ) ) + { + // nc0 needs to be a multiple of 16 since this gives maximum + // vectorization. Packing B always results in buffers with width + // which is a multiple of 16. Subsequently the nc0 offsets used + // for packed/reordered buffers needs to be updated. + dim_t nc0_updated = make_multiple_of_n( nc0, packb_min_NR ); + mem_b_size_req = sizeof( bfloat16 ) * nc0_updated * kc0_updated; + + lpgemm_alloc_mem_panel + ( + mem_b_size_req, BLIS_BUFFER_FOR_B_PANEL, + &mem_b, rntm + ); + + thread->comm[jc_work_id].sent_object = + bli_mem_buffer( &mem_b ); + } + + // All threads in work group should wait till chief thread has + // finished allocating the packing buffers. + bli_thrcomm_barrier + ( + bli_thread_ocomm_id( &thread_ic ), + &thread->comm[jc_work_id] + ); + + pack_b_buffer_bf16 = + ( bfloat16* ) thread->comm[jc_work_id].sent_object; + + // Compute the B panel per thread loop range for parallel + // packing using ic_ways number of threads. Since atmost only + // ic_ways threads can be used, the thread_ic attributes are + // used to split the loop range. + dim_t jc_packb_start, jc_packb_end; + bli_thread_range_sub + ( + &thread_ic, nc0, NR, FALSE, + &jc_packb_start, &jc_packb_end + ); + + // Ensure thread ranges are valid, especially cases where no: + // of threads available for parallelization are greater than + // no: of B panel NR chunks. + if ( ( jc_packb_end > jc_packb_start ) && + ( jc_packb_start < ( jc + nc0 ) ) ) + { + packb_nr64_bf16bf16f32of32 + ( + pack_b_buffer_bf16 + ( jc_packb_start * kc0_updated ), + ( b + ( rs_b * pc ) + ( cs_b * jc ) + + ( cs_b * jc_packb_start ) ), rs_b, + ( jc_packb_end - jc_packb_start ), kc0, + &rs_b_use, &cs_b_use + ); + } + else + { + get_packb_nr64_bf16bf16f32of32_strides( &rs_b_use, &cs_b_use ); + } + + // All threads in work group should wait till B matrix packing + // is completed by the participating threads. + bli_thrcomm_barrier + ( + bli_thread_ocomm_id( &thread_ic ), + &thread->comm[jc_work_id] + ); + b_use = pack_b_buffer_bf16; + } // B part getting processed - if ( mtag_b == REORDERED ) + if ( mtag_b == REORDERED ) { - b_use = b + ( jc * k_updated ) + ( pc * nc0_updated ); + // In multi-threaded scenarios, an extra offset into a given + // packed B panel is required, since the jc loop split can + // result in per thread start offset inside the panel, instead + // of panel boundaries. + b_use = b + ( jc_cur_loop * k_updated ) + + ( n_sub_updated * pc ) + + ( jc_cur_loop_rem * kc0_updated ); + get_packb_nr64_bf16bf16f32of32_strides( &rs_b_use, &cs_b_use ); } - for ( dim_t ic = 0; ic < m; ic += MC ) + dim_t ic_start, ic_end; + bli_thread_range_sub( &thread_ic, m, MR, FALSE, &ic_start, &ic_end ); + + for ( dim_t ic = ic_start; ic < ic_end; ic += MC ) { - dim_t mc0 = ( ( ic + MC ) <= m ) ? MC : ( m % MC ); + dim_t mc0 = bli_min( ( ic_end - ic ), MC ); + c_use_ic = c_use_jc + ( rs_c * ic ); - a_use = a + ( rs_a * ic ) + ( cs_a * pc ); + if ( mtag_a == UNPACKED ) + { + a_use = a + ( rs_a * ic ) + ( cs_a * pc ); - // bf16 kernel reads 2 elements, totalling 4 bytes in a - // single broadcast for use in bf16 instruction. - // Non bf16 based kernel requires update to this code. 
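
The PACK branch above follows a chief/worker scheme: the chief thread of each jc work group allocates the shared pack buffer (lpgemm_alloc_mem_panel), every thread in the group then packs its own NR-aligned slice of the B panel, and bli_thrcomm_barrier calls separate allocation, packing, and consumption. The fragment below is a minimal, self-contained sketch of that pattern using plain OpenMP rather than the BLIS thrinfo/thrcomm machinery; the pack_slice() helper, the panel sizes, and the chunk split are illustrative assumptions, not code from this patch.

/* Sketch only: chief thread allocates, the team packs NR-sized slices,
 * barriers separate the allocate / pack / consume phases. */
#include <omp.h>
#include <stdio.h>
#include <stdlib.h>

#define NR 64
#define N  256   /* illustrative panel width  */
#define KC 128   /* illustrative panel depth  */

static float* shared_pack_buf = NULL;

static void pack_slice( float* dst, int jr_start, int jr_end )
{
    /* Stand-in for the real packb kernel: fill this thread's slice. */
    for ( int j = jr_start; j < jr_end; ++j )
        for ( int p = 0; p < KC; ++p )
            dst[ (size_t)j * KC + p ] = (float)( j + p );
}

int main( void )
{
    #pragma omp parallel
    {
        int tid  = omp_get_thread_num();
        int nthr = omp_get_num_threads();

        /* Chief thread of the work group allocates the shared buffer. */
        if ( tid == 0 )
            shared_pack_buf = malloc( sizeof( float ) * N * KC );

        /* All threads wait until the buffer exists. */
        #pragma omp barrier

        /* Split the panel across threads in NR-sized chunks. */
        int n_chunks = ( N + NR - 1 ) / NR;
        int lo = ( n_chunks * tid ) / nthr * NR;
        int hi = ( n_chunks * ( tid + 1 ) ) / nthr * NR;
        if ( hi > N ) hi = N;
        if ( lo < hi )
            pack_slice( shared_pack_buf, lo, hi );

        /* Packing must complete before anyone consumes the buffer. */
        #pragma omp barrier

        /* ... micro-kernels would read shared_pack_buf here ... */

        #pragma omp barrier
        if ( tid == 0 ) { free( shared_pack_buf ); shared_pack_buf = NULL; }
    }
    printf( "done\n" );
    return 0;
}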
- cs_a_use = 2; - a_block_stride = rs_a; + // bf16 kernel reads 2 elements, totalling 4 bytes in a + // single broadcast for use in bf16 instruction. + // Non bf16 based kernel requires update to this code. + cs_a_use = 2; + a_block_stride = rs_a; + } for ( dim_t jr = 0; jr < nc0; jr += NR ) { - dim_t nr0 = ( ( jr + NR ) <= nc0 ) ? NR : ( nc0 % NR ); + dim_t nr0 = bli_min( ( nc0 - jr ), NR ); // Reorder/Packed B, Reorder/Packed/Unpacked A call. lpgemm_rowvar_bf16bf16f32of32_6x64 - ( - mc0, nr0, kc0, - a_use, rs_a, cs_a_use, a_block_stride, - ( b_use + ( jr * kc0_updated ) ), rs_b_use, cs_b_use, - ( c + ( rs_c * ic ) + jc + jr ), rs_c, 1, - alpha, beta0, - is_last_k, ic, ( jc + jr ), post_op_list - ); + ( + mc0, nr0, kc0, + a_use, rs_a, cs_a_use, a_block_stride, + ( b_use + ( jr * kc0_updated ) ), rs_b_use, cs_b_use, + ( c_use_ic + jr ), rs_c, 1, + alpha, beta0, + is_last_k, ic, ( jc + jr ), post_op_list + ); } } } + if ( mtag_b == REORDERED ) + { + adjust_B_panel_reordered_jc( &jc, jc_cur_loop ); + } + } + + // Release pack buffers. + if ( mtag_b == PACK ) + { + // All threads in work group should wait till B matrix usage is + // completed by the participating threads. + bli_thrcomm_barrier + ( + bli_thread_ocomm_id( &thread_jc ), + &thread->comm[bli_thread_work_id( &thread_jc)] + ); + + if ( bli_thread_am_ochief( &thread_ic ) ) + { + if ( bli_mem_is_alloc( &mem_b ) ) + { + bli_membrk_release( rntm, &mem_b ); + } + } } } diff --git a/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_reorder_bf16.c b/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_reorder_bf16.c index 5b3461d73d..07b087b790 100644 --- a/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_reorder_bf16.c +++ b/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_reorder_bf16.c @@ -45,49 +45,130 @@ void reorderb_nr64_bf16bf16f32of32 lpgemm_obj_t *b_reorder ) { - dim_t NC = lpgemm_get_block_size_NC_global_cntx( BF16BF16F32OF32 ); - dim_t KC = lpgemm_get_block_size_KC_global_cntx( BF16BF16F32OF32 ); + dim_t NC = lpgemm_get_block_size_NC_global_cntx( BF16BF16F32OF32 ); + dim_t NR = lpgemm_get_block_size_NR_global_cntx( BF16BF16F32OF32 ); + dim_t KC = lpgemm_get_block_size_KC_global_cntx( BF16BF16F32OF32 ); - // Extracting the matrix properties from the lpgemm object - dim_t rs_b = b->rs; - dim_t n = b->width; - dim_t k = b->length; + // Extracting the matrix properties from the lpgemm object + dim_t rs_b = b->rs; + dim_t n = b->width; + dim_t k = b->length; - dim_t rs_b_reorder; - dim_t cs_b_reorder; + dim_t rs_b_reorder; + dim_t cs_b_reorder; - // k needs to be a multiple of 2 so that it can be used with vpdpbusd - // instruction. Padding is added in cases this condition is not - // satisfied, and therefore the k offset used for packed/reordered - // buffer needs to be updated. - dim_t k_updated = k; - k_updated += (k_updated & 0x1); + // k needs to be a multiple of 2 so that it can be used with dpbf + // instruction. Padding is added in cases this condition is not + // satisfied, and therefore the k offset used for packed/reordered + // buffer needs to be updated. + dim_t k_updated = k; + k_updated += (k_updated & 0x1); - for ( dim_t jc = 0; jc < n; jc += NC ) + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. + rntm_t rntm_g; + bli_rntm_init_from_global( &rntm_g ); + + dim_t n_threads = bli_rntm_num_threads( &rntm_g ); + n_threads = ( n_threads > 0 ) ? 
n_threads : 1; + +#ifdef BLIS_ENABLE_OPENMP + _Pragma( "omp parallel num_threads(n_threads)" ) + { + // Initialise a local thrinfo obj for work split across threads. + thrinfo_t thread_jc; + bli_thrinfo_set_n_way( n_threads, &thread_jc ); + bli_thrinfo_set_work_id( omp_get_thread_num(), &thread_jc ); +#else { - dim_t nc0 = ( ( jc + NC ) <= n ) ? NC : ( n % NC ); + // Initialise a local thrinfo obj for work split across threads. + thrinfo_t thread_jc; + bli_thrinfo_set_n_way( 1, &thread_jc ); + bli_thrinfo_set_work_id( 0, &thread_jc ); +#endif + // Compute the JC loop thread range for the current thread. + dim_t jc_start, jc_end; + bli_thread_range_sub( &thread_jc, n, NR, FALSE, &jc_start, &jc_end ); - dim_t nc0_mod16 = nc0 % 16; - dim_t nc0_updated = nc0; - if ( nc0_mod16 != 0 ) + for ( dim_t jc = jc_start; jc < jc_end; jc += NC ) { - nc0_updated += ( 16 - nc0_mod16 ); - } - for ( dim_t pc = 0; pc < k; pc += KC ) - { - dim_t kc0 = ( ( pc + KC ) <= k ) ? KC : ( k % KC ); - // B should always be packed. - packb_nr64_bf16bf16f32of32 + dim_t nc0 = bli_min( ( jc_end - jc ), NC ); + + dim_t jc_cur_loop = jc; + dim_t jc_cur_loop_rem = 0; + dim_t n_sub_updated; + + get_B_panel_reordered_start_offset_width ( - ( ( ( bfloat16* )b_reorder->storage.aligned_buffer ) + ( jc * k_updated ) + - ( nc0_updated * pc ) ), - ( ( ( bfloat16* )b->storage.aligned_buffer ) + ( rs_b * pc ) + jc ), - rs_b, nc0, kc0, &rs_b_reorder, &cs_b_reorder - ); + jc, n, NC, 16, + &jc_cur_loop, &jc_cur_loop_rem, + &nc0, &n_sub_updated + ); + + for ( dim_t pc = 0; pc < k; pc += KC ) + { + dim_t kc0 = bli_min( ( k - pc ), KC ); + + // k needs to be a multiple of 2 so that it can be used with dpbf + // instruction. Padding is added in cases this condition is not + // satisfied, and therefore the k offset used for packed/reordered + // buffer needs to be updated. + dim_t kc0_updated = kc0; + kc0_updated += (kc0_updated & 0x1); + + // The offsets are calculated in such a way that it resembles + // the reorder buffer traversal in single threaded reordering. + // The panel boundaries (KCxNC) remain as it is accessed in + // single thread, and as a consequence a thread with jc_start + // inside the panel cannot consider NC range for reorder. It + // has to work with NC' < NC, and the offset is calulated using + // prev NC panels spanning k dim + cur NC panel spaning pc loop + // cur iteration + (NC - NC') spanning current kc0 (<= KC). + // + //Eg: Consider the following reordered buffer diagram: + // t1 t2 + // | | + // | |..NC..| + // | | | + // |.NC. |.NC. |NC'|NC" + // pc=0-+-----+-----+---+--+ + // KC| | | | | + // | 1 | 3 | 5 | + // pc=KC-+-----+-----+---st-+ + // KC| | | | | + // | 2 | 4 | 6 | 7| + // pc=k=2KC-+-----+-----+---+--+ + // |jc=0 |jc=NC|jc=2NC| + // + // The numbers 1,2..6,7 denotes the order in which reordered + // KCxNC blocks are stored in memory, ie: block 1 followed by 2 + // followed by 3, etc. Given two threads t1 and t2, and t2 needs + // to acces point st in the reorder buffer to write the data: + // The offset calulation logic will be: + // jc_cur_loop = 2NC, jc_cur_loop_rem = NC', pc = KC, + // n_sub_updated = NC, k = 2KC, kc0_updated = KC + // + // st = ( jc_cur_loop * k ) + // + ( n_sub_updated * pc ) + // + ( NC' * kc0_updated) + + // B should always be packed. 
+ packb_nr64_bf16bf16f32of32 + ( + ( ( ( bfloat16* )b_reorder->storage.aligned_buffer ) + + ( jc_cur_loop * k_updated ) + ( n_sub_updated * pc ) + + ( jc_cur_loop_rem * kc0_updated ) ), + ( ( ( bfloat16* )b->storage.aligned_buffer ) + + ( rs_b * pc ) + jc ), + rs_b, nc0, kc0, &rs_b_reorder, &cs_b_reorder + ); + } + + adjust_B_panel_reordered_jc( &jc, jc_cur_loop ); } } b_reorder->rs = rs_b_reorder; b_reorder->cs = cs_b_reorder; b_reorder->mtag = REORDERED; -} +} diff --git a/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.h b/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.h index 82702a4cf6..a32a3b580a 100644 --- a/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.h +++ b/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.h @@ -92,7 +92,7 @@ void lpgemm_ ## LPGEMM_SFX ## _thread_decorator \ lpgemm_post_op* post_op_list \ ); \ -GEN_LPGEMM_DECORATOR_FN(uint8_t,int8_t,int32_t,u8s8s16o16) +GEN_LPGEMM_DECORATOR_FN(uint8_t,int8_t,int16_t,u8s8s16o16) GEN_LPGEMM_DECORATOR_FN(uint8_t,int8_t,int32_t,u8s8s32o32) GEN_LPGEMM_DECORATOR_FN(bfloat16,bfloat16,float,bf16bf16f32of32) GEN_LPGEMM_DECORATOR_FN(float,float,float,f32f32f32of32) diff --git a/addon/aocl_gemm/frame/u8s8s16/lpgemm_reorder_s16.c b/addon/aocl_gemm/frame/u8s8s16/lpgemm_reorder_s16.c index f471244423..0b55f31215 100644 --- a/addon/aocl_gemm/frame/u8s8s16/lpgemm_reorder_s16.c +++ b/addon/aocl_gemm/frame/u8s8s16/lpgemm_reorder_s16.c @@ -38,56 +38,131 @@ #include "lpgemm_config.h" void aocl_reorderb_nr32_u8s8s16o16 - ( - lpgemm_obj_t *b, - lpgemm_obj_t *b_reorder - ) + ( + lpgemm_obj_t *b, + lpgemm_obj_t *b_reorder + ) { - const dim_t NC = lpgemm_get_block_size_NC_global_cntx(U8S8S16OS16); - const dim_t KC = lpgemm_get_block_size_KC_global_cntx(U8S8S16OS16); - - // Extracting the matrix properties from the lpgemm object - dim_t rs_b = b->rs; - dim_t n = b->width; - dim_t k = b->length; - - dim_t rs_b_reorder; - dim_t cs_b_reorder; - - dim_t k_updated = k; - - // Making multiple of 2 to suit k in vpmaddubsw - k_updated += (k_updated & 0x1); - - for (dim_t jc = 0; jc < n; jc += NC) - { - dim_t nc0 = ((jc + NC) <= n) ? NC : (n % NC); - - // nc0 needs to be a multiple of 16 since this gives maximum - // vectorization. Packing B always results in buffers with width - // which is a multiple of 16. Subsequently the nc0 offsets used - // for packed/reordered buffers needs to be updated. - dim_t nc0_mod16 = nc0 % 16; - dim_t nc0_updated = nc0; - if (nc0_mod16 != 0) - { - nc0_updated += (16 - nc0_mod16); - } - - for (dim_t pc = 0; pc < k; pc += KC) - { - dim_t kc0 = ((pc + KC) <= k) ? KC : (k % KC); - - // B should always be packed. 
- packb_nr32_u8s8s16o16( - (((int8_t *)b_reorder->storage.aligned_buffer) + (jc * k_updated) + (nc0_updated * pc)), - (((int8_t *)b->storage.aligned_buffer) + (rs_b * pc) + jc), - rs_b, nc0, kc0, &rs_b_reorder, &cs_b_reorder); - } - } - - // Changing the packed matrix properties in the packed matrix object - b_reorder->rs = rs_b_reorder; - b_reorder->cs = cs_b_reorder; - b_reorder->mtag = REORDERED; + const dim_t NC = lpgemm_get_block_size_NC_global_cntx(U8S8S16OS16); + const dim_t KC = lpgemm_get_block_size_KC_global_cntx(U8S8S16OS16); + const dim_t NR = lpgemm_get_block_size_NR_global_cntx(U8S8S16OS16); + + // Extracting the matrix properties from the lpgemm object + dim_t rs_b = b->rs; + dim_t n = b->width; + dim_t k = b->length; + + dim_t rs_b_reorder; + dim_t cs_b_reorder; + + dim_t k_updated = k; + + // Making multiple of 2 to suit k in vpmaddubsw + k_updated += (k_updated & 0x1); + + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. + rntm_t rntm_g; + bli_rntm_init_from_global( &rntm_g ); + + dim_t n_threads = bli_rntm_num_threads( &rntm_g ); + n_threads = ( n_threads > 0 ) ? n_threads : 1; + +#ifdef BLIS_ENABLE_OPENMP + _Pragma( "omp parallel num_threads(n_threads)" ) + { + // Initialise a local thrinfo obj for work split across threads. + thrinfo_t thread_jc; + bli_thrinfo_set_n_way( n_threads, &thread_jc ); + bli_thrinfo_set_work_id( omp_get_thread_num(), &thread_jc ); +#else + { + // Initialise a local thrinfo obj for work split across threads. + thrinfo_t thread_jc; + bli_thrinfo_set_n_way( 1, &thread_jc ); + bli_thrinfo_set_work_id( 0, &thread_jc ); +#endif + // Compute the JC loop thread range for the current thread. + dim_t jc_start, jc_end; + bli_thread_range_sub( &thread_jc, n, NR, FALSE, &jc_start, &jc_end ); + + for ( dim_t jc = jc_start; jc < jc_end; jc += NC ) + { + dim_t nc0 = bli_min( ( jc_end - jc ), NC ); + + dim_t jc_cur_loop = jc; + dim_t jc_cur_loop_rem = 0; + dim_t n_sub_updated; + + get_B_panel_reordered_start_offset_width + ( + jc, n, NC, 16, + &jc_cur_loop, &jc_cur_loop_rem, + &nc0, &n_sub_updated + ); + + for ( dim_t pc = 0; pc < k; pc += KC ) + { + dim_t kc0 = bli_min( ( k - pc ), KC ); + + // kc0 needs to be a multiple of 2 so that it can be used with + // vmaddubsw instruction. Padding is added in cases this + // condition is not satisfied, and therefore the kc0 offsets + // used for packed/reordered buffers needs to be updated. + dim_t kc0_updated = make_multiple_of_n( kc0, 2 ); + + // The offsets are calculated in such a way that it resembles + // the reorder buffer traversal in single threaded reordering. + // The panel boundaries (KCxNC) remain as it is accessed in + // single thread, and as a consequence a thread with jc_start + // inside the panel cannot consider NC range for reorder. It + // has to work with NC' < NC, and the offset is calulated using + // prev NC panels spanning k dim + cur NC panel spaning pc loop + // cur iteration + (NC - NC') spanning current kc0 (<= KC). + // + //Eg: Consider the following reordered buffer diagram: + // t1 t2 + // | | + // | |..NC..| + // | | | + // |.NC. |.NC. 
|NC'|NC" + // pc=0-+-----+-----+---+--+ + // KC| | | | | + // | 1 | 3 | 5 | + // pc=KC-+-----+-----+---st-+ + // KC| | | | | + // | 2 | 4 | 6 | 7| + // pc=k=2KC-+-----+-----+---+--+ + // |jc=0 |jc=NC|jc=2NC| + // + // The numbers 1,2..6,7 denotes the order in which reordered + // KCxNC blocks are stored in memory, ie: block 1 followed by 2 + // followed by 3, etc. Given two threads t1 and t2, and t2 needs + // to acces point st in the reorder buffer to write the data: + // The offset calulation logic will be: + // jc_cur_loop = 2NC, jc_cur_loop_rem = NC', pc = KC, + // n_sub_updated = NC, k = 2KC, kc0_updated = KC + // + // st = ( jc_cur_loop * k ) + // + ( n_sub_updated * pc ) + // + ( NC' * kc0_updated) + packb_nr32_u8s8s16o16 + ( + ( ( ( int8_t* )b_reorder->storage.aligned_buffer ) + + ( jc_cur_loop * k_updated ) + ( n_sub_updated * pc ) + + ( jc_cur_loop_rem * kc0_updated ) ), + ( ( ( int8_t* )b->storage.aligned_buffer ) + + ( rs_b * pc ) + jc ), + rs_b, nc0, kc0, &rs_b_reorder, &cs_b_reorder + ); + } + + adjust_B_panel_reordered_jc( &jc, jc_cur_loop ); + } + } + + // Changing the packed matrix properties in the packed matrix object + b_reorder->rs = rs_b_reorder; + b_reorder->cs = cs_b_reorder; + b_reorder->mtag = REORDERED; } diff --git a/addon/aocl_gemm/frame/u8s8s16/lpgemm_reorder_s16.h b/addon/aocl_gemm/frame/u8s8s16/lpgemm_reorder_s16.h index 1b107d634b..6018978bc7 100644 --- a/addon/aocl_gemm/frame/u8s8s16/lpgemm_reorder_s16.h +++ b/addon/aocl_gemm/frame/u8s8s16/lpgemm_reorder_s16.h @@ -37,9 +37,9 @@ #include "lpgemm_types.h" void aocl_reorderb_nr32_u8s8s16o16 - ( - lpgemm_obj_t *b, - lpgemm_obj_t *b_reorder - ); + ( + lpgemm_obj_t *b, + lpgemm_obj_t *b_reorder + ); -#endif // LPGEMM_REORDER_H \ No newline at end of file +#endif // LPGEMM_REORDER_H diff --git a/addon/aocl_gemm/frame/u8s8s32/lpgemm_utils.h b/addon/aocl_gemm/frame/u8s8s32/lpgemm_utils.h index 743af4f3ec..93acad6ac9 100644 --- a/addon/aocl_gemm/frame/u8s8s32/lpgemm_utils.h +++ b/addon/aocl_gemm/frame/u8s8s32/lpgemm_utils.h @@ -187,7 +187,7 @@ BLIS_INLINE void get_B_panel_reordered_start_offset_width dim_t* panel_start, dim_t* panel_offset, dim_t* panel_width, - dim_t* panel_width_kdim_trav + dim_t* panel_width_kdim_trav ) { // Since n dimension is split across threads in units of NR blocks, diff --git a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_6x64rowmajor_bf16_amd512vnni.c b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_6x64rowmajor_bf16_amd512vnni.c index e98c6e9872..b3db64e0b0 100644 --- a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_6x64rowmajor_bf16_amd512vnni.c +++ b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_6x64rowmajor_bf16_amd512vnni.c @@ -190,21 +190,21 @@ LPGEMM_MAIN_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x64) // The instructions are arranged in a mixed way to reduce data // chain dependencies. 
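
The hunks below re-type the raw 512-bit loads and broadcasts as __m512bh before they reach _mm512_dpbf16_ps, which expects bf16 vector operands. A minimal sketch of that load/broadcast/accumulate step outside the kernel macros; it assumes an AVX512-BF16/BW target, uses uint16_t as a stand-in for the addon's bfloat16 storage type, and relies on the same __m512i-to-__m512bh vector cast the patch itself uses:

#include <immintrin.h>
#include <stdint.h>

// c[0,0:15] += a[0,kr:kr+2] * b[kr:kr+2,0:15] for one k-pair of a bf16 GEMM.
// Sketch only; compile with AVX512-BF16/BW enabled (e.g. -mavx512bf16 -mavx512bw).
static inline __m512 bf16_dot_step
     (
       __m512          c_acc,  // 16 running float accumulators
       const uint16_t* a_pair, // &a[0,kr]: two adjacent bf16 values of A
       const uint16_t* b_row   // &b[kr,0]: 32 packed bf16 values of B
     )
{
	// Broadcast the bf16 pair of A to every 32-bit lane.
	__m512bh a_bh = (__m512bh)_mm512_set1_epi32( *( const int32_t* )a_pair );

	// Load 32 bf16 values of B (the packing pairs 2 k values per column lane).
	__m512bh b_bh = (__m512bh)_mm512_loadu_epi16( b_row );

	// Pairwise bf16 dot product accumulated into the float lanes.
	return _mm512_dpbf16_ps( c_acc, a_bh, b_bh );
}
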
- b0 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); // Broadcast a[0,kr:kr+2] - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )(a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )(a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); - b1 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); - b2 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 2 ) ); - b3 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 3 ) ); + b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b2 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + b3 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 3 ) ); // Perform column direction mat-mul with k = 2. // c[0,0-63] = a[0,kr:kr+2]*b[kr:kr+2,0-63] c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); // Broadcast a[1,kr:kr+2]. - a_bf16_1 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + a_bf16_1 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); @@ -215,7 +215,7 @@ LPGEMM_MAIN_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x64) c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_1, b0 ); // Broadcast a[2,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_1, b1 ); c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_1, b2 ); @@ -226,7 +226,7 @@ LPGEMM_MAIN_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x64) c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); // Broadcast a[3,kr:kr+2]. - a_bf16_1 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + a_bf16_1 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); c_float_2p2 = _mm512_dpbf16_ps( c_float_2p2, a_bf16_0, b2 ); @@ -237,7 +237,7 @@ LPGEMM_MAIN_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x64) c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_1, b0 ); // Broadcast a[4,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); c_float_3p1 = _mm512_dpbf16_ps( c_float_3p1, a_bf16_1, b1 ); c_float_3p2 = _mm512_dpbf16_ps( c_float_3p2, a_bf16_1, b2 ); @@ -248,7 +248,7 @@ LPGEMM_MAIN_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x64) c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); // Broadcast a[5,kr:kr+2]. - a_bf16_1 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 5 ) + ( cs_a * kr ) ) ); + a_bf16_1 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 5 ) + ( cs_a * kr ) ) ); c_float_4p1 = _mm512_dpbf16_ps( c_float_4p1, a_bf16_0, b1 ); c_float_4p2 = _mm512_dpbf16_ps( c_float_4p2, a_bf16_0, b2 ); @@ -264,7 +264,7 @@ LPGEMM_MAIN_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x64) // Handle k remainder. if ( k_partial_pieces > 0 ) { - b0 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); // Broadcast a[0,kr:kr+2]. 
memcpy @@ -273,11 +273,11 @@ LPGEMM_MAIN_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x64) ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); - b1 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); - b2 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); - b3 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 3 ) ); + b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b2 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + b3 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 3 ) ); // Perform column direction mat-mul with k = 2. // c[0,0-63] = a[0,kr:kr+2]*b[kr:kr+2,0-63] @@ -290,7 +290,7 @@ LPGEMM_MAIN_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x64) ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_1 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_1 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); @@ -307,7 +307,7 @@ LPGEMM_MAIN_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x64) ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_1, b1 ); c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_1, b2 ); @@ -324,7 +324,7 @@ LPGEMM_MAIN_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x64) ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_1 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_1 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); c_float_2p2 = _mm512_dpbf16_ps( c_float_2p2, a_bf16_0, b2 ); @@ -341,7 +341,7 @@ LPGEMM_MAIN_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x64) ( a + ( rs_a * 4 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); c_float_3p1 = _mm512_dpbf16_ps( c_float_3p1, a_bf16_1, b1 ); c_float_3p2 = _mm512_dpbf16_ps( c_float_3p2, a_bf16_1, b2 ); @@ -358,7 +358,7 @@ LPGEMM_MAIN_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x64) ( a + ( rs_a * 5 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_1 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_1 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); c_float_4p1 = _mm512_dpbf16_ps( c_float_4p1, a_bf16_0, b1 ); c_float_4p2 = _mm512_dpbf16_ps( c_float_4p2, a_bf16_0, b2 ); diff --git a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_m_fringe_bf16_amd512vnni.c b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_m_fringe_bf16_amd512vnni.c index 9d1e742bea..b4b2d2e27a 100644 --- a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_m_fringe_bf16_amd512vnni.c +++ b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_m_fringe_bf16_amd512vnni.c @@ -84,21 +84,21 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x64) for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) { - b0 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); // Broadcast 
a[0,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); - b1 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); - b2 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 2 ) ); - b3 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 3 ) ); + b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b2 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + b3 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 3 ) ); // Perform column direction mat-mul with k = 2. // c[0,0-63] = a[0,kr:kr+2]*b[kr:kr+2,0-63] c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); // Broadcast a[1,kr:kr+2]. - a_bf16_1 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + a_bf16_1 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); @@ -109,7 +109,7 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x64) c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_1, b0 ); // Broadcast a[2,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_1, b1 ); c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_1, b2 ); @@ -120,7 +120,7 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x64) c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); // Broadcast a[3,kr:kr+2]. - a_bf16_1 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + a_bf16_1 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); c_float_2p2 = _mm512_dpbf16_ps( c_float_2p2, a_bf16_0, b2 ); @@ -131,7 +131,7 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x64) c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_1, b0 ); // Broadcast a[4,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); c_float_3p1 = _mm512_dpbf16_ps( c_float_3p1, a_bf16_1, b1 ); c_float_3p2 = _mm512_dpbf16_ps( c_float_3p2, a_bf16_1, b2 ); @@ -147,7 +147,7 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x64) // Handle k remainder. if ( k_partial_pieces > 0 ) { - b0 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); // Broadcast a[0,kr:kr+4]. 
memcpy @@ -156,11 +156,11 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x64) ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); - b1 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); - b2 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); - b3 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 3 ) ); + b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b2 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + b3 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 3 ) ); // Perform column direction mat-mul with k = 2. // c[0,0-63] = a[0,kr:kr+2]*b[kr:kr+2,0-63] @@ -173,7 +173,7 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x64) ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_1 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_1 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); @@ -190,7 +190,7 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x64) ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_1, b1 ); c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_1, b2 ); @@ -207,7 +207,7 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x64) ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_1 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_1 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); c_float_2p2 = _mm512_dpbf16_ps( c_float_2p2, a_bf16_0, b2 ); @@ -224,7 +224,7 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x64) ( a + ( rs_a * 4 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); c_float_3p1 = _mm512_dpbf16_ps( c_float_3p1, a_bf16_1, b1 ); c_float_3p2 = _mm512_dpbf16_ps( c_float_3p2, a_bf16_1, b2 ); @@ -475,21 +475,21 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x64) for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) { - b0 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); // Broadcast a[0,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); - b1 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); - b2 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 2 ) ); - b3 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 3 ) ); + b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b2 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + b3 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 3 ) ); // Perform column direction mat-mul with k = 2. 
// c[0,0-63] = a[0,kr:kr+4]*b[kr:kr+4,0-63] c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); // Broadcast a[1,kr:kr+2]. - a_bf16_1 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + a_bf16_1 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); @@ -500,7 +500,7 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x64) c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_1, b0 ); // Broadcast a[2,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_1, b1 ); c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_1, b2 ); @@ -511,7 +511,7 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x64) c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); // Broadcast a[3,kr:kr+2]. - a_bf16_1 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + a_bf16_1 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); c_float_2p2 = _mm512_dpbf16_ps( c_float_2p2, a_bf16_0, b2 ); @@ -528,7 +528,7 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x64) // Handle k remainder. if ( k_partial_pieces > 0 ) { - b0 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); // Broadcast a[0,kr:kr+2]. memcpy @@ -537,11 +537,11 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x64) ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); - b1 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); - b2 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); - b3 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 3 ) ); + b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b2 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + b3 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 3 ) ); // Perform column direction mat-mul with k = 2. 
// c[0,0-63] = a[0,kr:kr+2]*b[kr:kr+2,0-63] @@ -554,7 +554,7 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x64) ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_1 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_1 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); @@ -571,7 +571,7 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x64) ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_1, b1 ); c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_1, b2 ); @@ -588,7 +588,7 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x64) ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_1 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_1 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); c_float_2p2 = _mm512_dpbf16_ps( c_float_2p2, a_bf16_0, b2 ); @@ -797,21 +797,21 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x64) for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) { - b0 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); // Broadcast a[0,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); - b1 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); - b2 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 2 ) ); - b3 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 3 ) ); + b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b2 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + b3 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 3 ) ); // Perform column direction mat-mul with k = 2. // c[0,0-63] = a[0,kr:kr+2]*b[kr:kr+2,0-63] c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); // Broadcast a[1,kr:kr+2]. - a_bf16_1 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + a_bf16_1 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); @@ -822,7 +822,7 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x64) c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_1, b0 ); // Broadcast a[2,kr:kr+4]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_1, b1 ); c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_1, b2 ); @@ -839,7 +839,7 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x64) // Handle k remainder. if ( k_partial_pieces > 0 ) { - b0 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); // Broadcast a[0,kr:kr+2]. 
memcpy @@ -848,11 +848,11 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x64) ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); - b1 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); - b2 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); - b3 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 3 ) ); + b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b2 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + b3 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 3 ) ); // Perform column direction mat-mul with k = 2. // c[0,0-63] = a[0,kr:kr+2]*b[kr:kr+2,0-63] @@ -865,7 +865,7 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x64) ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - __m512i a_bf16_1 = _mm512_set1_epi32( a_kfringe_buf ); + __m512i a_bf16_1 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); @@ -882,7 +882,7 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x64) ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_1, b1 ); c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_1, b2 ); @@ -1048,21 +1048,21 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x64) for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) { - b0 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); // Broadcast a[0,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); - b1 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); - b2 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 2 ) ); - b3 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 3 ) ); + b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b2 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + b3 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 3 ) ); // Perform column direction mat-mul with k = 2. // c[0,0-63] = a[0,kr:kr+2]*b[kr:kr+2,0-63] c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); // Broadcast a[1,kr:kr+2]. - a_bf16_1 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + a_bf16_1 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); @@ -1079,7 +1079,7 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x64) // Handle k remainder. if ( k_partial_pieces > 0 ) { - __m512i b0 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + __m512i b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); // Broadcast a[0,kr:kr+2]. 
memcpy @@ -1088,11 +1088,11 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x64) ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - __m512i a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + __m512i a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); - __m512i b1 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); - __m512i b2 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); - __m512i b3 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 3 ) ); + __m512i b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + __m512i b2 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + __m512i b3 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 3 ) ); // Perform column direction mat-mul with k = 2. // c[0,0-63] = a[0,kr:kr+2]*b[kr:kr+2,0-63] @@ -1105,7 +1105,7 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x64) ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - __m512i a_bf16_1 = _mm512_set1_epi32( a_kfringe_buf ); + __m512i a_bf16_1 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); @@ -1221,14 +1221,14 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1x64) for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) { - __m512bh b0 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + __m512bh b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); // Broadcast a[0,kr] - __m512bh a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + __m512bh a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); - __m512bh b1 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); - __m512bh b2 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 2 ) ); - __m512bh b3 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 3 ) ); + __m512bh b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + __m512bh b2 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + __m512bh b3 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 3 ) ); // Perform column direction mat-mul with k = 2. // c[0,0-63] = a[0,kr:kr+2]*b[kr:kr+2,0-63] @@ -1241,7 +1241,7 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1x64) // Handle k remainder. if ( k_partial_pieces > 0 ) { - __m512i b0 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + __m512i b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); // Broadcast a[0,kr:kr+2]. 
memcpy @@ -1250,11 +1250,11 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1x64) ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - __m512i a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + __m512i a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); - __m512i b1 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); - __m512i b2 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); - __m512i b3 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 3 ) ); + __m512i b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + __m512i b2 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + __m512i b3 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 3 ) ); // Perform column direction mat-mul with k = 2. // c[0,0-63] = a[0,kr:kr+2]*b[kr:kr+2,0-63] diff --git a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_mn_fringe_bf16_amd512vnni.c b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_mn_fringe_bf16_amd512vnni.c index 978b7944c1..24c4cd0e85 100644 --- a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_mn_fringe_bf16_amd512vnni.c +++ b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_mn_fringe_bf16_amd512vnni.c @@ -72,38 +72,38 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5xlt16) for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) { - b0 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); // Broadcast a[0,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); // Broadcast a[1,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); // Broadcast a[2,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); // Broadcast a[3,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[3,0-15] = a[3,kr:kr+2]*b[kr:kr+2,0-15] c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); // Broadcast a[4,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[4,0-15] = a[4,kr:kr+2]*b[kr:kr+2,0-15] @@ -112,11 +112,11 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5xlt16) // Handle k remainder. 
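
In the k-remainder path below, only the valid trailing element(s) of the A row are memcpy'd into a small staging variable before being broadcast, so the kernel never reads past the end of A when k is odd. A minimal sketch of that staging step (the names and the zero initialisation are illustrative, and uint16_t stands in for the addon's bfloat16 type):

#include <stdint.h>
#include <string.h>

// Stage the trailing k element(s) of a row of A into a 32-bit scalar that can
// then be broadcast with _mm512_set1_epi32(), copying only what is valid.
static inline int32_t stage_k_fringe
     (
       const uint16_t* a_row_tail,      // &a[row, k_full_pieces]
       size_t          k_partial_pieces // leftover bf16 elements (at most 1 here)
     )
{
	int32_t a_kfringe_buf = 0; // unused upper half stays zero in this sketch
	memcpy( &a_kfringe_buf, a_row_tail, k_partial_pieces * sizeof( uint16_t ) );
	return a_kfringe_buf;
}
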
if ( k_partial_pieces > 0 ) { - b0 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); // Broadcast a[0,kr:kr+2]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] @@ -124,7 +124,7 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5xlt16) // Broadcast a[1,kr:kr+2]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] @@ -132,7 +132,7 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5xlt16) // Broadcast a[2,kr:kr+2]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] @@ -140,7 +140,7 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5xlt16) // Broadcast a[3,kr:kr+2]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[3,0-15] = a[3,kr:kr+2]*b[kr:kr+2,0-15] @@ -148,7 +148,7 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5xlt16) // Broadcast a[4,kr:kr+2]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 4 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[4,0-15] = a[4,kr:kr+2]*b[kr:kr+2,0-15] @@ -271,31 +271,31 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4xlt16) for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) { - b0 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); // Broadcast a[0,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); // Broadcast a[1,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); // Broadcast a[2,kr:kr+2]. 
- a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); // Broadcast a[3,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[3,0-15] = a[3,kr:kr+2]*b[kr:kr+2,0-15] @@ -305,11 +305,11 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4xlt16) if ( k_partial_pieces > 0 ) { - b0 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); // Broadcast a[0,kr:kr+2]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] @@ -317,7 +317,7 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4xlt16) // Broadcast a[1,kr:kr+2]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] @@ -325,7 +325,7 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4xlt16) // Broadcast a[2,kr:kr+2]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] @@ -333,7 +333,7 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4xlt16) // Broadcast a[3,kr:kr+2]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[3,0-15] = a[3,kr:kr+2]*b[kr:kr+2,0-15] @@ -439,24 +439,24 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3xlt16) for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) { - b0 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); // Broadcast a[0,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); // Broadcast a[1,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. 
// c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); // Broadcast a[2,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] @@ -465,11 +465,11 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3xlt16) // Handle k remainder. if ( k_partial_pieces > 0 ) { - b0 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); // Broadcast a[0,kr:kr+2]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] @@ -477,7 +477,7 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3xlt16) // Broadcast a[1,kr:kr+2]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] @@ -485,7 +485,7 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3xlt16) // Broadcast a[2,kr:kr+2]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] @@ -574,17 +574,17 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2xlt16) for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) { - b0 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); // Broadcast a[0,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); // Broadcast a[1,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] @@ -593,11 +593,11 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2xlt16) // Handle k remainder. if ( k_partial_pieces > 0 ) { - b0 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); // Broadcast a[0,kr:kr+2]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. 
// c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] @@ -605,7 +605,7 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2xlt16) // Broadcast a[1,kr:kr+2]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] @@ -677,10 +677,10 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1xlt16) for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) { - b0 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); // Broadcast a[0,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] @@ -689,11 +689,11 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1xlt16) // Handle k remainder. if ( k_partial_pieces > 0 ) { - b0 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); // Broadcast a[0,kr:kr+2]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] @@ -755,38 +755,38 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x16) for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) { - b0 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); // Broadcast a[0,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); // Broadcast a[1,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); // Broadcast a[2,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); // Broadcast a[3,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[3,0-15] = a[3,kr:kr+2]*b[kr:kr+2,0-15] c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); // Broadcast a[4,kr:kr+2]. 
- a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[4,0-15] = a[4,kr:kr+2]*b[kr:kr+2,0-15] @@ -795,11 +795,11 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x16) // Handle k remainder. if ( k_partial_pieces > 0 ) { - b0 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); // Broadcast a[0,kr:kr+2]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] @@ -807,7 +807,7 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x16) // Broadcast a[1,kr:kr+2]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] @@ -815,7 +815,7 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x16) // Broadcast a[2,kr:kr+2]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] @@ -823,7 +823,7 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x16) // Broadcast a[3,kr:kr+2]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[3,0-15] = a[3,kr:kr+2]*b[kr:kr+2,0-15] @@ -831,7 +831,7 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x16) // Broadcast a[4,kr:kr+2]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 4 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[4,0-15] = a[4,kr:kr+2]*b[kr:kr+2,0-15] @@ -924,31 +924,31 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x16) for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) { - b0 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); // Broadcast a[0,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); // Broadcast a[1,kr:kr+2]. 
- a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); // Broadcast a[2,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); // Broadcast a[3,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[3,0-15] = a[3,kr:kr+2]*b[kr:kr+2,0-15] @@ -957,11 +957,11 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x16) // Handle k remainder. if ( k_partial_pieces > 0 ) { - b0 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); // Broadcast a[0,kr:kr+2]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] @@ -969,7 +969,7 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x16) // Broadcast a[1,kr:kr+2]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] @@ -977,7 +977,7 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x16) // Broadcast a[2,kr:kr+2]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] @@ -985,7 +985,7 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x16) // Broadcast a[3,kr:kr+2]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[3,0-15] = a[3,kr:kr+2]*b[kr:kr+2,0-15] @@ -1066,24 +1066,24 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x16) for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) { - b0 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); // Broadcast a[0,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. 
// c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); // Broadcast a[1,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); // Broadcast a[2,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] @@ -1092,11 +1092,11 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x16) // Handle k remainder. if ( k_partial_pieces > 0 ) { - b0 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); // Broadcast a[0,kr:kr+2]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] @@ -1104,7 +1104,7 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x16) // Broadcast a[1,kr:kr+2]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] @@ -1112,7 +1112,7 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x16) // Broadcast a[2,kr:kr+2]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] @@ -1181,17 +1181,17 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x16) for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) { - b0 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); // Broadcast a[0,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); // Broadcast a[1,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] @@ -1200,11 +1200,11 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x16) // Handle k remainder. 
if ( k_partial_pieces > 0 ) { - b0 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); // Broadcast a[0,kr:kr+2]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] @@ -1212,7 +1212,7 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x16) // Broadcast a[1,kr:kr+2]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] @@ -1269,10 +1269,10 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1x16) for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) { - b0 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); // Broadcast a[0,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] @@ -1281,11 +1281,11 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1x16) // Handle k remainder. if ( k_partial_pieces > 0 ) { - b0 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); // Broadcast a[0,kr:kr+2]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] @@ -1303,7 +1303,7 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1x16) if ( beta != 0 ) { // c[0,0-15] - selector1 = _mm512_loadu_epi32( c + ( rs_c * 0 ) + ( 0*16 ) ); + selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 0*16 ) ); selector1 = _mm512_mul_ps( selector2, selector1 ); c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); } @@ -1346,11 +1346,11 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x32) for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) { - b0 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); - b1 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); // Broadcast a[0,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] @@ -1358,7 +1358,7 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x32) c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); // Broadcast a[1,kr:kr+2]. 
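Alongside the __m512bh casts, note the beta path above in the 1x16 kernel: the previously stored C tile is now loaded with _mm512_loadu_ps rather than _mm512_loadu_epi32, keeping the load in the fp32 domain that _mm512_mul_ps and _mm512_add_ps expect instead of reinterpreting an integer vector. Element-wise this is the usual C = beta*C + acc update; a scalar sketch follows (beta_update_row is an illustrative helper, and acc is assumed to hold the already alpha-scaled accumulators):

    /* Scale the previously stored C row by beta and add the fresh accumulators,
       mirroring what the 16-wide vector code above does per lane. */
    static void beta_update_row( float* c_row, const float* acc, float beta, int n )
    {
        for ( int j = 0; j < n; ++j )
        {
            c_row[ j ] = ( beta * c_row[ j ] ) + acc[ j ];
        }
    }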
- a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[1,0-31] = a[1,kr:kr+2]*b[kr:kr+2,0-31] @@ -1366,7 +1366,7 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x32) c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 ); // Broadcast a[2,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[2,0-31] = a[2,kr:kr+2]*b[kr:kr+2,0-31] @@ -1374,7 +1374,7 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x32) c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); // Broadcast a[3,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[3,0-31] = a[3,kr:kr+2]*b[kr:kr+2,0-31] @@ -1382,7 +1382,7 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x32) c_float_3p1 = _mm512_dpbf16_ps( c_float_3p1, a_bf16_0, b1 ); // Broadcast a[4,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[4,0-31] = a[4,kr:kr+2]*b[kr:kr+2,0-31] @@ -1392,12 +1392,12 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x32) // Handle k remainder. if ( k_partial_pieces > 0 ) { - b0 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); - b1 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); // Broadcast a[0,kr:kr+2]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] @@ -1406,7 +1406,7 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x32) // Broadcast a[1,kr:kr+2]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[1,0-31] = a[1,kr:kr+2]*b[kr:kr+2,0-31] @@ -1415,7 +1415,7 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x32) // Broadcast a[2,kr:kr+2]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[2,0-31] = a[2,kr:kr+2]*b[kr:kr+2,0-31] @@ -1424,7 +1424,7 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x32) // Broadcast a[3,kr:kr+2]. 
memcpy( &a_kfringe_buf, ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[3,0-31] = a[3,kr:kr+2]*b[kr:kr+2,0-31] @@ -1433,7 +1433,7 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x32) // Broadcast a[4,kr:kr+2]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 4 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[4,0-31] = a[4,kr:kr+2]*b[kr:kr+2,0-31] @@ -1577,11 +1577,11 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x32) for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) { - b0 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); - b1 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); // Broadcast a[0,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] @@ -1589,7 +1589,7 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x32) c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); // Broadcast a[1,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[1,0-31] = a[1,kr:kr+2]*b[kr:kr+2,0-31] @@ -1597,7 +1597,7 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x32) c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 ); // Broadcast a[2,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[2,0-31] = a[2,kr:kr+2]*b[kr:kr+2,0-31] @@ -1605,7 +1605,7 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x32) c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); // Broadcast a[3,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[3,0-31] = a[3,kr:kr+2]*b[kr:kr+2,0-31] @@ -1615,12 +1615,12 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x32) // Handle k remainder. if ( k_partial_pieces > 0 ) { - b0 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); - b1 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); // Broadcast a[0,kr:kr+2]. 
memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] @@ -1629,7 +1629,7 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x32) // Broadcast a[1,kr:kr+2]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[1,0-31] = a[1,kr:kr+2]*b[kr:kr+2,0-31] @@ -1638,7 +1638,7 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x32) // Broadcast a[2,kr:kr+2]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[2,0-31] = a[2,kr:kr+2]*b[kr:kr+2,0-31] @@ -1647,7 +1647,7 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x32) // Broadcast a[3,kr:kr+2]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[3,0-31] = a[3,kr:kr+2]*b[kr:kr+2,0-31] @@ -1769,11 +1769,11 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x32) for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) { - b0 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); - b1 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); // Broadcast a[0,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] @@ -1781,7 +1781,7 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x32) c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); // Broadcast a[1,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[1,0-31] = a[1,kr:kr+2]*b[kr:kr+2,0-31] @@ -1789,7 +1789,7 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x32) c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 ); // Broadcast a[2,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[2,0-31] = a[2,kr:kr+2]*b[kr:kr+2,0-31] @@ -1799,12 +1799,12 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x32) // Handle k remainder. 
if ( k_partial_pieces > 0 ) { - b0 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); - b1 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); // Broadcast a[0,kr:kr+2]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] @@ -1813,7 +1813,7 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x32) // Broadcast a[1,kr:kr+2]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[1,0-31] = a[1,kr:kr+2]*b[kr:kr+2,0-31] @@ -1822,7 +1822,7 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x32) // Broadcast a[2,kr:kr+2]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[2,0-31] = a[2,kr:kr+2]*b[kr:kr+2,0-31] @@ -1922,11 +1922,11 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x32) for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) { - b0 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); - b1 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); // Broadcast a[0,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] @@ -1934,7 +1934,7 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x32) c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); // Broadcast a[1,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[1,0-31] = a[1,kr:kr+2]*b[kr:kr+2,0-31] @@ -1944,12 +1944,12 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x32) // Handle k remainder. if ( k_partial_pieces > 0 ) { - b0 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); - b1 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); // Broadcast a[0,kr:kr+2]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. 
// c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] @@ -1958,7 +1958,7 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x32) // Broadcast a[1,kr:kr+2]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[1,0-31] = a[1,kr:kr+2]*b[kr:kr+2,0-31] @@ -2036,11 +2036,11 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1x32) for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) { - b0 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); - b1 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); // Broadcast a[0,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] @@ -2050,12 +2050,12 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1x32) // Handle k remainder. if ( k_partial_pieces > 0 ) { - b0 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); - b1 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); // Broadcast a[0,kr:kr+2]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] @@ -2132,12 +2132,12 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x48) for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) { - b0 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); - b1 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); - b2 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b2 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 2 ) ); // Broadcast a[0,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[0,0-47] = a[0,kr:kr+2]*b[kr:kr+2,0-47] @@ -2146,7 +2146,7 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x48) c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); // Broadcast a[1,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[1,0-47] = a[1,kr:kr+2]*b[kr:kr+2,0-47] @@ -2155,7 +2155,7 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x48) c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_0, b2 ); // Broadcast a[2,kr:kr+2]. 
- a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[2,0-47] = a[2,kr:kr+2]*b[kr:kr+2,0-47] @@ -2164,7 +2164,7 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x48) c_float_2p2 = _mm512_dpbf16_ps( c_float_2p2, a_bf16_0, b2 ); // Broadcast a[3,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[3,0-47] = a[3,kr:kr+2]*b[kr:kr+2,0-47] @@ -2173,7 +2173,7 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x48) c_float_3p2 = _mm512_dpbf16_ps( c_float_3p2, a_bf16_0, b2 ); // Broadcast a[4,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[4,0-47] = a[4,kr:kr+2]*b[kr:kr+2,0-47] @@ -2184,13 +2184,13 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x48) // Handle k remainder. if ( k_partial_pieces > 0 ) { - b0 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); - b1 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); - b2 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b2 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); // Broadcast a[0,kr:kr+2]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[0,0-47] = a[0,kr:kr+2]*b[kr:kr+2,0-47] @@ -2200,7 +2200,7 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x48) // Broadcast a[1,kr:kr+2]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[1,0-47] = a[1,kr:kr+2]*b[kr:kr+2,0-47] @@ -2210,7 +2210,7 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x48) // Broadcast a[2,kr:kr+2]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[2,0-47] = a[2,kr:kr+2]*b[kr:kr+2,0-47] @@ -2220,7 +2220,7 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x48) // Broadcast a[3,kr:kr+2]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. 
// c[3,0-47] = a[3,kr:kr+2]*b[kr:kr+2,0-47] @@ -2230,7 +2230,7 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x48) // Broadcast a[4,kr:kr+2]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 4 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[4,0-47] = a[4,kr:kr+2]*b[kr:kr+2,0-47] @@ -2425,12 +2425,12 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x48) for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) { - b0 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); - b1 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); - b2 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b2 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 2 ) ); // Broadcast a[0,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[0,0-47] = a[0,kr:kr+2]*b[kr:kr+2,0-47] @@ -2439,7 +2439,7 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x48) c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); // Broadcast a[1,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[1,0-47] = a[1,kr:kr+2]*b[kr:kr+2,0-47] @@ -2448,7 +2448,7 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x48) c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_0, b2 ); // Broadcast a[2,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[2,0-47] = a[2,kr:kr+2]*b[kr:kr+2,0-47] @@ -2457,7 +2457,7 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x48) c_float_2p2 = _mm512_dpbf16_ps( c_float_2p2, a_bf16_0, b2 ); // Broadcast a[3,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[3,0-47] = a[3,kr:kr+2]*b[kr:kr+2,0-47] @@ -2468,13 +2468,13 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x48) // Handle k remainder. if ( k_partial_pieces > 0 ) { - b0 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); - b1 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); - b2 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b2 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); // Broadcast a[0,kr:kr+2]. 
memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[0,0-47] = a[0,kr:kr+2]*b[kr:kr+2,0-47] @@ -2484,7 +2484,7 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x48) // Broadcast a[1,kr:kr+2]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[1,0-47] = a[1,kr:kr+2]*b[kr:kr+2,0-47] @@ -2494,7 +2494,7 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x48) // Broadcast a[2,kr:kr+2]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[2,0-47] = a[2,kr:kr+2]*b[kr:kr+2,0-47] @@ -2504,7 +2504,7 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x48) // Broadcast a[3,kr:kr+2]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[3,0-47] = a[3,kr:kr+2]*b[kr:kr+2,0-47] @@ -2667,12 +2667,12 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x48) for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) { - b0 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); - b1 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); - b2 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b2 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 2 ) ); // Broadcast a[0,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[0,0-47] = a[0,kr:kr+2]*b[kr:kr+2,0-47] @@ -2681,7 +2681,7 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x48) float_0p2 = _mm512_dpbf16_ps( float_0p2, a_bf16_0, b2 ); // Broadcast a[1,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[1,0-47] = a[1,kr:kr+2]*b[kr:kr+2,0-47] @@ -2690,7 +2690,7 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x48) float_1p2 = _mm512_dpbf16_ps( float_1p2, a_bf16_0, b2 ); // Broadcast a[2,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[2,0-47] = a[2,kr:kr+2]*b[kr:kr+2,0-47] @@ -2701,13 +2701,13 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x48) // Handle k remainder. 
if ( k_partial_pieces > 0 ) { - b0 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); - b1 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); - b2 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b2 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); // Broadcast a[0,kr:kr+2]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[0,0-47] = a[0,kr:kr+2]*b[kr:kr+2,0-47] @@ -2717,7 +2717,7 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x48) // Broadcast a[1,kr:kr+2]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[1,0-47] = a[1,kr:kr+2]*b[kr:kr+2,0-47] @@ -2727,7 +2727,7 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x48) // Broadcast a[2,kr:kr+2]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[2,0-47] = a[2,kr:kr+2]*b[kr:kr+2,0-47] @@ -2858,12 +2858,12 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x48) for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) { - b0 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); - b1 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); - b2 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b2 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 2 ) ); // Broadcast a[0,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[0,0-47] = a[0,kr:kr+2]*b[kr:kr+2,0-47] @@ -2872,7 +2872,7 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x48) c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); // Broadcast a[1,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[1,0-47] = a[1,kr:kr+2]*b[kr:kr+2,0-47] @@ -2883,13 +2883,13 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x48) // Handle k remainder. 
if ( k_partial_pieces > 0 ) { - b0 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); - b1 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); - b2 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b2 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); // Broadcast a[0,kr:kr+2]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[0,0-47] = a[0,kr:kr+2]*b[kr:kr+2,0-47] @@ -2899,7 +2899,7 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x48) // Broadcast a[1,kr:kr+2]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[1,0-47] = a[1,kr:kr+2]*b[kr:kr+2,0-47] @@ -2998,12 +2998,12 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1x48) for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) { - b0 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); - b1 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); - b2 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b2 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 2 ) ); // Broadcast a[0,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[0,0-47] = a[0,kr:kr+2]*b[kr:kr+2,0-47] @@ -3014,13 +3014,13 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1x48) // Handle k remainder. if ( k_partial_pieces > 0 ) { - b0 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); - b1 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); - b2 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b2 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); // Broadcast a[0,kr:kr+2]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. 
// c[0,0-47] = a[0,kr:kr+2]*b[kr:kr+2,0-47] diff --git a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_n_fringe_bf16_amd512vnni.c b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_n_fringe_bf16_amd512vnni.c index bdfb8de8e4..1908aa0c3f 100644 --- a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_n_fringe_bf16_amd512vnni.c +++ b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_n_fringe_bf16_amd512vnni.c @@ -86,45 +86,45 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6xlt16) // in bf16 instructions and each load to ZMM register will have 2 // elements along k direction and 16 elements across n directions, // so 2x16 elements to a ZMM register. - b0 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); // Broadcast a[0,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); // Broadcast a[1,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); // Broadcast a[2,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); // Broadcast a[3,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[3,0-15] = a[3,kr:kr+2]*b[kr:kr+2,0-15] c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); // Broadcast a[4,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[4,0-15] = a[4,kr:kr+2]*b[kr:kr+2,0-15] c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); // Broadcast a[5,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 5 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 5 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[5,0-15] = a[5,kr:kr+2]*b[kr:kr+2,0-15] @@ -134,7 +134,7 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6xlt16) // Handle k remainder. if ( k_partial_pieces > 0 ) { - b0 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); // Broadcast a[0,kr:kr+2]. 
memcpy @@ -143,7 +143,7 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6xlt16) ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] @@ -156,7 +156,7 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6xlt16) ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] @@ -169,7 +169,7 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6xlt16) ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] @@ -182,7 +182,7 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6xlt16) ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[3,0-15] = a[3,kr:kr+2]*b[kr:kr+2,0-15] @@ -195,7 +195,7 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6xlt16) ( a + ( rs_a * 4 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[4,0-15] = a[4,kr:kr+2]*b[kr:kr+2,0-15] @@ -208,7 +208,7 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6xlt16) ( a + ( rs_a * 5 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[5,0-15] = a[5,kr:kr+2]*b[kr:kr+2,0-15] @@ -435,45 +435,45 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x16) // instructions and each load to ZMM register will have 2 elements // along k direction and 16 elements across n directions, so 2x16 // elements to a ZMM register. - b0 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); // Broadcast a[0,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); // Broadcast a[1,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); // Broadcast a[2,kr:kr+2]. 
- a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); // Broadcast a[3,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[3,0-15] = a[3,kr:kr+2]*b[kr:kr+2,0-15] c_float_3p0 = _mm512_dpbf16_ps( c_float_3p0, a_bf16_0, b0 ); // Broadcast a[4,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[4,0-15] = a[4,kr:kr+2]*b[kr:kr+2,0-15] c_float_4p0 = _mm512_dpbf16_ps( c_float_4p0, a_bf16_0, b0 ); // Broadcast a[5,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 5 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 5 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[5,0-15] = a[5,kr:kr+2]*b[kr:kr+2,0-15] @@ -483,7 +483,7 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x16) if ( k_partial_pieces > 0 ) { - b0 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); // Broadcast a[0,kr:kr+2]. memcpy @@ -492,7 +492,7 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x16) ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[0,0-15] = a[0,kr:kr+2]*b[kr:kr+2,0-15] @@ -505,7 +505,7 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x16) ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[1,0-15] = a[1,kr:kr+2]*b[kr:kr+2,0-15] @@ -518,7 +518,7 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x16) ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[2,0-15] = a[2,kr:kr+2]*b[kr:kr+2,0-15] @@ -531,7 +531,7 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x16) ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[3,0-15] = a[3,kr:kr+2]*b[kr:kr+2,0-15] @@ -544,7 +544,7 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x16) ( a + ( rs_a * 4 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. 
// c[4,0-15] = a[4,kr:kr+2]*b[kr:kr+2,0-15] @@ -557,7 +557,7 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x16) ( a + ( rs_a * 5 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[5,0-15] = a[5,kr:kr+2]*b[kr:kr+2,0-15] @@ -765,11 +765,11 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x32) // instructions and each load to ZMM register will have 2 elements // along k direction and 32 elements across n directions, so 2x16 // elements to a ZMM register. - b0 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); - b1 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); // Broadcast a[0,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] @@ -777,7 +777,7 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x32) c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); // Broadcast a[1,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[1,0-31] = a[1,kr:kr+2]*b[kr:kr+2,0-31] @@ -785,7 +785,7 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x32) c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 ); // Broadcast a[2,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[2,0-31] = a[2,kr:kr+2]*b[kr:kr+2,0-31] @@ -793,7 +793,7 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x32) c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); // Broadcast a[3,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[3,0-31] = a[3,kr:kr+2]*b[kr:kr+2,0-31] @@ -801,7 +801,7 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x32) c_float_3p1 = _mm512_dpbf16_ps( c_float_3p1, a_bf16_0, b1 ); // Broadcast a[4,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[4,0-31] = a[4,kr:kr+2]*b[kr:kr+2,0-31] @@ -809,7 +809,7 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x32) c_float_4p1 = _mm512_dpbf16_ps( c_float_4p1, a_bf16_0, b1 ); // Broadcast a[5,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 5 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 5 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. 
// c[5,0-31] = a[5,kr:kr+2]*b[kr:kr+2,0-31] @@ -819,8 +819,8 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x32) // Handle k remainder. if ( k_partial_pieces > 0 ) { - b0 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); - b1 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); // Broadcast a[0,kr:kr+2]. memcpy @@ -829,7 +829,7 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x32) ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[0,0-31] = a[0,kr:kr+2]*b[kr:kr+2,0-31] @@ -843,7 +843,7 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x32) ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[1,0-31] = a[1,kr:kr+2]*b[kr:kr+2,0-31] @@ -857,7 +857,7 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x32) ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[2,0-31] = a[2,kr:kr+2]*b[kr:kr+2,0-31] @@ -871,7 +871,7 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x32) ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[3,0-31] = a[3,kr:kr+2]*b[kr:kr+2,0-31] @@ -885,7 +885,7 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x32) ( a + ( rs_a * 4 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[4,0-31] = a[4,kr:kr+2]*b[kr:kr+2,0-31] @@ -899,7 +899,7 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x32) ( a + ( rs_a * 5 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[5,0-31] = a[5,kr:kr+2]*b[kr:kr+2,0-31] @@ -1169,12 +1169,12 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x48) // instructions and each load to ZMM register will have 2 elements // along k direction and 16 elements across n directions, so 2x16 // elements to a ZMM register. - b0 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); - b1 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); - b2 = _mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 2 ) ); + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 0 ) ); + b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 1 ) ); + b2 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * kr ) + ( cs_b * 2 ) ); // Broadcast a[0,kr:kr+2]. 
- a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 0 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[0,0-47] = a[0,kr:kr+2]*b[kr:kr+2,0-47] @@ -1183,7 +1183,7 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x48) c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); // Broadcast a[1,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[1,0-47] = a[1,kr:kr+2]*b[kr:kr+2,0-47] @@ -1192,7 +1192,7 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x48) c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_0, b2 ); // Broadcast a[2,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[2,0-47] = a[2,kr:kr+2]*b[kr:kr+2,0-47] @@ -1201,7 +1201,7 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x48) c_float_2p2 = _mm512_dpbf16_ps( c_float_2p2, a_bf16_0, b2 ); // Broadcast a[3,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 3 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[3,0-47] = a[3,kr:kr+2]*b[kr:kr+2,0-47] @@ -1210,7 +1210,7 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x48) c_float_3p2 = _mm512_dpbf16_ps( c_float_3p2, a_bf16_0, b2 ); // Broadcast a[4,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 4 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[4,0-47] = a[4,kr:kr+2]*b[kr:kr+2,0-47] @@ -1219,7 +1219,7 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x48) c_float_4p2 = _mm512_dpbf16_ps( c_float_4p2, a_bf16_0, b2 ); // Broadcast a[5,kr:kr+2]. - a_bf16_0 = _mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 5 ) + ( cs_a * kr ) ) ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 5 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[5,0-47] = a[5,kr:kr+2]*b[kr:kr+2,0-47] @@ -1231,9 +1231,9 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x48) // Handle k remainder. if ( k_partial_pieces > 0 ) { - b0 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); - b1 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); - b2 = _mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b2 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); // Broadcast a[0,kr:kr+4]. memcpy @@ -1242,7 +1242,7 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x48) ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. 
// c[0,0-47] = a[0,kr:kr+2]*b[kr:kr+2,0-47] @@ -1257,7 +1257,7 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x48) ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[1,0-47] = a[1,kr:kr+2]*b[kr:kr+2,0-47] @@ -1272,7 +1272,7 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x48) ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[2,0-47] = a[2,kr:kr+2]*b[kr:kr+2,0-47] @@ -1287,7 +1287,7 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x48) ( a + ( rs_a * 3 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[3,0-47] = a[3,kr:kr+2]*b[kr:kr+2,0-47] @@ -1302,7 +1302,7 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x48) ( a + ( rs_a * 4 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[4,0-47] = a[4,kr:kr+2]*b[kr:kr+2,0-47] @@ -1317,7 +1317,7 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x48) ( a + ( rs_a * 5 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - a_bf16_0 = _mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); // Perform column direction mat-mul with k = 2. // c[5,0-47] = a[5,kr:kr+2]*b[kr:kr+2,0-47] From e674fae75883529e695e0b11ce904cdc6b0644af Mon Sep 17 00:00:00 2001 From: eashdash Date: Fri, 26 Aug 2022 09:11:53 +0000 Subject: [PATCH 208/243] Post-Ops for bf16bf16f32 Functionality - Post-ops is a set of operations performed elemnent wise on the output matrix post GEMM operation. The support for the same is added by fusing post-ops with GEMM operations. - Post-ops Bias, Relu and Parametric Relu are added to all the compute kernels of bf16bf16f32of32 - Modified bf16 interface files to add check for bf16 ISA support Change-Id: I2f7069a405037a59ea188a41bd8d10c4aae72fb3 --- addon/aocl_gemm/aocl_gemm_bf16.c | 4 +- addon/aocl_gemm/aocl_gemm_bf16_utils.c | 4 +- .../lpgemm_6x64rowmajor_bf16_amd512vnni.c | 266 +++ .../bf16bf16f32/lpgemm_f32_kern_macros.h | 45 + .../lpgemm_m_fringe_bf16_amd512vnni.c | 838 +++++++- .../lpgemm_mn_fringe_bf16_amd512vnni.c | 1842 ++++++++++++++++- .../lpgemm_n_fringe_bf16_amd512vnni.c | 538 ++++- 7 files changed, 3431 insertions(+), 106 deletions(-) create mode 100644 addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_f32_kern_macros.h diff --git a/addon/aocl_gemm/aocl_gemm_bf16.c b/addon/aocl_gemm/aocl_gemm_bf16.c index b61a2efb43..6138b4e259 100644 --- a/addon/aocl_gemm/aocl_gemm_bf16.c +++ b/addon/aocl_gemm/aocl_gemm_bf16.c @@ -46,8 +46,8 @@ AOCL_GEMM_MATMUL(bfloat16,bfloat16,float,bf16bf16f32of32) trans_t blis_transa; trans_t blis_transb; - // Check if avx512_vnni ISA is supported, lpgemm matmul only works with it. 
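For reference, the post-ops introduced by this patch act element-wise on the alpha/beta-scaled output tile, as described in the patch description above. The following scalar model is only an editorial sketch of that behaviour; the names (apply_post_ops_ref, bias_vec, prelu_alpha) are placeholders and not part of the aocl_gemm API, and the kernels modified below apply the same operations one 512-bit register at a time.

#include <stddef.h>

/* Scalar sketch of the fused post-ops: bias add per column, followed by
   either ReLU or parametric ReLU. Placeholder names, illustrative only. */
static void apply_post_ops_ref
     (
       float*       c,           /* m x n output tile                  */
       size_t       m,
       size_t       n,
       size_t       rs_c,        /* row stride of c                    */
       const float* bias_vec,    /* one bias value per output column   */
       int          use_prelu,
       float        prelu_alpha  /* scale applied to elements <= 0     */
     )
{
	for ( size_t i = 0; i < m; ++i )
	{
		for ( size_t j = 0; j < n; ++j )
		{
			float v = c[ ( i * rs_c ) + j ] + bias_vec[ j ];
			if ( use_prelu ) v = ( v > 0.0f ) ? v : ( prelu_alpha * v );
			else             v = ( v > 0.0f ) ? v : 0.0f;
			c[ ( i * rs_c ) + j ] = v;
		}
	}
}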
- if ( bli_cpuid_is_avx512vnni_supported() == FALSE ) + // Check if avx512_bf16 ISA is supported, lpgemm matmul only works with it. + if ( bli_cpuid_is_avx512_bf16_supported() == FALSE ) { printf(" AVX512_BF16 ISA not supported by processor, cannot perform lpgemm.\n"); return; // Error. diff --git a/addon/aocl_gemm/aocl_gemm_bf16_utils.c b/addon/aocl_gemm/aocl_gemm_bf16_utils.c index b65d04bd43..7af08b751b 100644 --- a/addon/aocl_gemm/aocl_gemm_bf16_utils.c +++ b/addon/aocl_gemm/aocl_gemm_bf16_utils.c @@ -47,7 +47,7 @@ AOCL_GEMM_GET_REORDER_BUF_SIZE(bf16bf16f32of32) } // Check if avx512_bf16 ISA is supported, lpgemm matmul only works with it. - if ( bli_cpuid_is_avx512vnni_supported() == FALSE ) + if ( bli_cpuid_is_avx512_bf16_supported() == FALSE ) { printf(" AVX512_BF16 ISA not supported by processor, cannot perform lpgemm.\n"); return 0; // Error. @@ -91,7 +91,7 @@ AOCL_GEMM_REORDER(bfloat16, bf16bf16f32of32) } // Check if avx512_bf16 ISA is supported, lpgemm matmul only works with it. - if ( bli_cpuid_is_avx512vnni_supported() == FALSE ) + if ( bli_cpuid_is_avx512_bf16_supported() == FALSE ) { printf(" AVX512_BF16 ISA not supported by processor, cannot perform lpgemm.\n"); return; // Error. diff --git a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_6x64rowmajor_bf16_amd512vnni.c b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_6x64rowmajor_bf16_amd512vnni.c index b3db64e0b0..872aa9a747 100644 --- a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_6x64rowmajor_bf16_amd512vnni.c +++ b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_6x64rowmajor_bf16_amd512vnni.c @@ -36,10 +36,18 @@ #include "blis.h" #include "lpgemm_kernels.h" +#include "lpgemm_f32_kern_macros.h" // 6x64 bf16 kernel LPGEMM_MAIN_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x64) { + static void* post_ops_labels[] = + { + &&POST_OPS_6x64_DISABLE, + &&POST_OPS_BIAS_6x64, + &&POST_OPS_RELU_6x64, + &&POST_OPS_RELU_SCALE_6x64 + }; dim_t MR = 6; dim_t NR = 64; @@ -85,6 +93,7 @@ LPGEMM_MAIN_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x64) b = b + ( 48 * k0_updated ); // k0x48 packed contiguosly. c = c + 48; + post_op_c_j += 48; } else if ( n0_32 == 1 ) @@ -103,6 +112,7 @@ LPGEMM_MAIN_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x64) b = b + ( 32 * k0_updated ); // k0x32 packed contiguosly. c = c + 32; + post_op_c_j += 32; } else if ( n0_16 == 1 ) @@ -121,6 +131,7 @@ LPGEMM_MAIN_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x64) b = b + ( 16 * k0_updated ); // k0x16 packed contiguosly. 
c = c + 16; + post_op_c_j += 16; } if ( n0_rem > 0 ) @@ -530,6 +541,260 @@ LPGEMM_MAIN_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x64) selector1 = _mm512_mul_ps( selector2, selector1 ); c_float_5p3 = _mm512_add_ps( selector1, c_float_5p3 ); } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_6x64: + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + __m512 selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 2 * 16 ) ); + __m512 selector4 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 3 * 16 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector3, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_add_ps( selector4, c_float_0p3 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector3, c_float_1p2 ); + + // c[1,48-63] + c_float_1p3 = _mm512_add_ps( selector4, c_float_1p3 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); + + // c[2,48-63] + c_float_2p3 = _mm512_add_ps( selector4, c_float_2p3 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector2, c_float_3p1 ); + + // c[3,32-47] + c_float_3p2 = _mm512_add_ps( selector3, c_float_3p2 ); + + // c[3,48-63] + c_float_3p3 = _mm512_add_ps( selector4, c_float_3p3 ); + + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); + + // c[4, 16-31] + c_float_4p1 = _mm512_add_ps( selector2, c_float_4p1 ); + + // c[4,32-47] + c_float_4p2 = _mm512_add_ps( selector3, c_float_4p2 ); + + // c[4,48-63] + c_float_4p3 = _mm512_add_ps( selector4, c_float_4p3 ); + + // c[5,0-15] + c_float_5p0 = _mm512_add_ps( selector1, c_float_5p0 ); + + // c[5, 16-31] + c_float_5p1 = _mm512_add_ps( selector2, c_float_5p1 ); + + // c[5,32-47] + c_float_5p2 = _mm512_add_ps( selector3, c_float_5p2 ); + + // c[5,48-63] + c_float_5p3 = _mm512_add_ps( selector4, c_float_5p3 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_6x64: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_max_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_max_ps( selector1, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_max_ps( selector1, c_float_0p3 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + c_float_1p1 = _mm512_max_ps( selector1, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_max_ps( selector1, c_float_1p2 ); + + // c[1,48-63] + c_float_1p3 = _mm512_max_ps( selector1, c_float_1p3 ); + + // c[2,0-15] + c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); + + // c[2,16-31] + c_float_2p1 = _mm512_max_ps( selector1, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_max_ps( selector1, c_float_2p2 ); + + // 
c[2,48-63] + c_float_2p3 = _mm512_max_ps( selector1, c_float_2p3 ); + + // c[3,0-15] + c_float_3p0 = _mm512_max_ps( selector1, c_float_3p0 ); + + // c[3,16-31] + c_float_3p1 = _mm512_max_ps( selector1, c_float_3p1 ); + + // c[3,32-47] + c_float_3p2 = _mm512_max_ps( selector1, c_float_3p2 ); + + // c[3,48-63] + c_float_3p3 = _mm512_max_ps( selector1, c_float_3p3 ); + + // c[4,0-15] + c_float_4p0 = _mm512_max_ps( selector1, c_float_4p0 ); + + // c[4,16-31] + c_float_4p1 = _mm512_max_ps( selector1, c_float_4p1 ); + + // c[4,32-47] + c_float_4p2 = _mm512_max_ps( selector1, c_float_4p2 ); + + // c[4,48-63] + c_float_4p3 = _mm512_max_ps( selector1, c_float_4p3 ); + + // c[5,0-15] + c_float_5p0 = _mm512_max_ps( selector1, c_float_5p0 ); + + // c[5,16-31] + c_float_5p1 = _mm512_max_ps( selector1, c_float_5p1 ); + + // c[5,32-47] + c_float_5p2 = _mm512_max_ps( selector1, c_float_5p2 ); + + // c[5,48-63] + c_float_5p3 = _mm512_max_ps( selector1, c_float_5p3 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_6x64: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_0p1) + + // c[0, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_0p2) + + // c[0, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_0p3) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_1p1) + + // c[1, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_1p2) + + // c[1, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_1p3) + + // c[2, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_2p0) + + // c[2, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_2p1) + + // c[2, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_2p2) + + // c[2, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_2p3) + + // c[3, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_3p0) + + // c[3, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_3p1) + + // c[3, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_3p2) + + // c[3, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_3p3) + + // c[4, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_4p0) + + // c[4, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_4p1) + + // c[4, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_4p2) + + // c[4, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_4p3) + + // c[5, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_5p0) + + // c[5, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_5p1) + + // c[5, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_5p2) + + // c[5, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_5p3) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_6x64_DISABLE: + ; // Store the results. // c[0,0-15] @@ -605,6 +870,7 @@ LPGEMM_MAIN_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x64) _mm512_storeu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 3*16 ), c_float_5p3 ); a = a + ( MR * ps_a ); + post_op_c_i += MR; } if ( m_partial_pieces > 0 ) diff --git a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_f32_kern_macros.h b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_f32_kern_macros.h new file mode 100644 index 0000000000..5ed1c6350f --- /dev/null +++ b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_f32_kern_macros.h @@ -0,0 +1,45 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. 
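The post-ops block just added to the 6x64 kernel is driven by the post_ops_labels table and the POST_OP_LABEL_* jump macros, i.e. the GCC/Clang "labels as values" extension: each entry of the post-op list selects a label, and after one op finishes the code jumps to the label of the next entry until the *_DISABLE label ends the chain. Below is a rough, self-contained sketch of that dispatch pattern; the node layout and opcode values are invented for illustration, and the real lpgemm_post_op type and jump macros carry more state (for example the last-k handling implied by their names).

/* Computed-goto dispatch sketch (requires GCC/Clang labels-as-values). */
typedef struct post_op_node
{
	int                  op_code;  /* 0 = done, 1 = bias, 2 = relu (illustrative) */
	void*                args;
	struct post_op_node* next;
} post_op_node;

static void dispatch_post_ops( post_op_node* node, float* acc, int len )
{
	static void* labels[] = { &&DONE, &&BIAS, &&RELU };

	goto *labels[ ( node != NULL ) ? node->op_code : 0 ];

BIAS:
	for ( int i = 0; i < len; ++i ) acc[ i ] += ( ( float* )node->args )[ i ];
	node = node->next;
	goto *labels[ ( node != NULL ) ? node->op_code : 0 ];

RELU:
	for ( int i = 0; i < len; ++i ) acc[ i ] = ( acc[ i ] > 0.0f ) ? acc[ i ] : 0.0f;
	node = node->next;
	goto *labels[ ( node != NULL ) ? node->op_code : 0 ];

DONE:
	return;
}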
+ + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef LPGEMM_F32_KERN_MACROS_H +#define LPGEMM_F32_KERN_MACROS_H + +#define RELU_SCALE_OP_F32_AVX512(reg) \ + /* Generate indenx of elements <= 0.*/ \ + relu_cmp_mask = _mm512_cmple_ps_mask( reg, selector1 ); \ + \ + /* Apply scaling on for <= 0 elements.*/ \ + reg = _mm512_mask_mul_ps( reg, relu_cmp_mask, reg, selector2 ); \ + +#endif // LPGEMM_F32_KERN_MACROS_H \ No newline at end of file diff --git a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_m_fringe_bf16_amd512vnni.c b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_m_fringe_bf16_amd512vnni.c index b4b2d2e27a..f88aaa9106 100644 --- a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_m_fringe_bf16_amd512vnni.c +++ b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_m_fringe_bf16_amd512vnni.c @@ -37,10 +37,18 @@ #include "blis.h" #include "lpgemm_kernels.h" +#include "lpgemm_f32_kern_macros.h" // 5x64 bf16 kernel LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x64) { + static void* post_ops_labels[] = + { + &&POST_OPS_5x64_DISABLE, + &&POST_OPS_BIAS_5x64, + &&POST_OPS_RELU_5x64, + &&POST_OPS_RELU_SCALE_5x64 + }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -371,6 +379,224 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x64) selector1 = _mm512_mul_ps( selector2, selector1 ); c_float_4p3 = _mm512_add_ps( selector1, c_float_4p3 ); } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_5x64: + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + __m512 selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 2 * 16 ) ); + __m512 selector4 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 3 * 16 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = 
_mm512_add_ps( selector3, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_add_ps( selector4, c_float_0p3 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector3, c_float_1p2 ); + + // c[1,48-63] + c_float_1p3 = _mm512_add_ps( selector4, c_float_1p3 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); + + // c[2,48-63] + c_float_2p3 = _mm512_add_ps( selector4, c_float_2p3 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector2, c_float_3p1 ); + + // c[3,32-47] + c_float_3p2 = _mm512_add_ps( selector3, c_float_3p2 ); + + // c[3,48-63] + c_float_3p3 = _mm512_add_ps( selector4, c_float_3p3 ); + + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); + + // c[4, 16-31] + c_float_4p1 = _mm512_add_ps( selector2, c_float_4p1 ); + + // c[4,32-47] + c_float_4p2 = _mm512_add_ps( selector3, c_float_4p2 ); + + // c[4,48-63] + c_float_4p3 = _mm512_add_ps( selector4, c_float_4p3 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_5x64: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_max_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_max_ps( selector1, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_max_ps( selector1, c_float_0p3 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + c_float_1p1 = _mm512_max_ps( selector1, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_max_ps( selector1, c_float_1p2 ); + + // c[1,48-63] + c_float_1p3 = _mm512_max_ps( selector1, c_float_1p3 ); + + // c[2,0-15] + c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); + + // c[2,16-31] + c_float_2p1 = _mm512_max_ps( selector1, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_max_ps( selector1, c_float_2p2 ); + + // c[2,48-63] + c_float_2p3 = _mm512_max_ps( selector1, c_float_2p3 ); + + // c[3,0-15] + c_float_3p0 = _mm512_max_ps( selector1, c_float_3p0 ); + + // c[3,16-31] + c_float_3p1 = _mm512_max_ps( selector1, c_float_3p1 ); + + // c[3,32-47] + c_float_3p2 = _mm512_max_ps( selector1, c_float_3p2 ); + + // c[3,48-63] + c_float_3p3 = _mm512_max_ps( selector1, c_float_3p3 ); + + // c[4,0-15] + c_float_4p0 = _mm512_max_ps( selector1, c_float_4p0 ); + + // c[4,16-31] + c_float_4p1 = _mm512_max_ps( selector1, c_float_4p1 ); + + // c[4,32-47] + c_float_4p2 = _mm512_max_ps( selector1, c_float_4p2 ); + + // c[4,48-63] + c_float_4p3 = _mm512_max_ps( selector1, c_float_4p3 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_5x64: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_0p1) + + // c[0, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_0p2) + + // c[0, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_0p3) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_1p1) + + // c[1, 32-47] + 
RELU_SCALE_OP_F32_AVX512(c_float_1p2) + + // c[1, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_1p3) + + // c[2, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_2p0) + + // c[2, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_2p1) + + // c[2, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_2p2) + + // c[2, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_2p3) + + // c[3, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_3p0) + + // c[3, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_3p1) + + // c[3, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_3p2) + + // c[3, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_3p3) + + // c[4, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_4p0) + + // c[4, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_4p1) + + // c[4, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_4p2) + + // c[4, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_4p3) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_5x64_DISABLE: + ; // Store the results. // c[0,0-15] @@ -437,6 +663,13 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x64) // 4x64 bf16 kernel LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x64) { + static void* post_ops_labels[] = + { + &&POST_OPS_4x64_DISABLE, + &&POST_OPS_BIAS_4x64, + &&POST_OPS_RELU_4x64, + &&POST_OPS_RELU_SCALE_4x64 + }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -710,6 +943,188 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x64) selector1 = _mm512_mul_ps( selector2, selector1 ); c_float_3p3 = _mm512_add_ps( selector1, c_float_3p3 ); } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_4x64: + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + __m512 selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 2 * 16 ) ); + __m512 selector4 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 3 * 16 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector3, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_add_ps( selector4, c_float_0p3 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector3, c_float_1p2 ); + + // c[1,48-63] + c_float_1p3 = _mm512_add_ps( selector4, c_float_1p3 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); + + // c[2,48-63] + c_float_2p3 = _mm512_add_ps( selector4, c_float_2p3 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector2, c_float_3p1 ); + + // c[3,32-47] + c_float_3p2 = _mm512_add_ps( selector3, c_float_3p2 ); + + // c[3,48-63] + c_float_3p3 = _mm512_add_ps( selector4, c_float_3p3 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_4x64: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_max_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = 
_mm512_max_ps( selector1, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_max_ps( selector1, c_float_0p3 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + c_float_1p1 = _mm512_max_ps( selector1, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_max_ps( selector1, c_float_1p2 ); + + // c[1,48-63] + c_float_1p3 = _mm512_max_ps( selector1, c_float_1p3 ); + + // c[2,0-15] + c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); + + // c[2,16-31] + c_float_2p1 = _mm512_max_ps( selector1, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_max_ps( selector1, c_float_2p2 ); + + // c[2,48-63] + c_float_2p3 = _mm512_max_ps( selector1, c_float_2p3 ); + + // c[3,0-15] + c_float_3p0 = _mm512_max_ps( selector1, c_float_3p0 ); + + // c[3,16-31] + c_float_3p1 = _mm512_max_ps( selector1, c_float_3p1 ); + + // c[3,32-47] + c_float_3p2 = _mm512_max_ps( selector1, c_float_3p2 ); + + // c[3,48-63] + c_float_3p3 = _mm512_max_ps( selector1, c_float_3p3 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_4x64: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_0p1) + + // c[0, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_0p2) + + // c[0, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_0p3) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_1p1) + + // c[1, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_1p2) + + // c[1, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_1p3) + + // c[2, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_2p0) + + // c[2, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_2p1) + + // c[2, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_2p2) + + // c[2, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_2p3) + + // c[3, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_3p0) + + // c[3, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_3p1) + + // c[3, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_3p2) + + // c[3, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_3p3) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_4x64_DISABLE: + ; // Store the results. 
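RELU_SCALE_OP_F32_AVX512, defined in the new lpgemm_f32_kern_macros.h earlier in this patch, implements the parametric ReLU on one accumulator register: build a lane mask of elements <= 0, then multiply only those lanes by the scale read from op_args2. Expanded into a stand-alone helper (the function name is illustrative; the macro operates in place using the kernel's selector1/selector2 registers):

#include <immintrin.h>

/* Parametric ReLU on one zmm register of 16 floats. */
static inline __m512 prelu_zmm( __m512 reg, float scale )
{
	const __m512    zero  = _mm512_setzero_ps();
	const __m512    alpha = _mm512_set1_ps( scale );
	/* Lanes with reg <= 0. */
	const __mmask16 neg   = _mm512_cmple_ps_mask( reg, zero );
	/* Scale only those lanes; positive lanes pass through unchanged. */
	return _mm512_mask_mul_ps( reg, neg, reg, alpha );
}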
// c[0,0-15] @@ -764,6 +1179,13 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x64) // 3x64 bf16 kernel LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x64) { + static void* post_ops_labels[] = + { + &&POST_OPS_3x64_DISABLE, + &&POST_OPS_BIAS_3x64, + &&POST_OPS_RELU_3x64, + &&POST_OPS_RELU_SCALE_3x64 + }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -865,7 +1287,7 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x64) ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - __m512i a_bf16_1 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + a_bf16_1 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); @@ -979,6 +1401,152 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x64) selector1 = _mm512_mul_ps( selector2, selector1 ); c_float_2p3 = _mm512_add_ps( selector1, c_float_2p3 ); } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_3x64: + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + __m512 selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 2 * 16 ) ); + __m512 selector4 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 3 * 16 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector3, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_add_ps( selector4, c_float_0p3 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector3, c_float_1p2 ); + + // c[1,48-63] + c_float_1p3 = _mm512_add_ps( selector4, c_float_1p3 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); + + // c[2,48-63] + c_float_2p3 = _mm512_add_ps( selector4, c_float_2p3 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_3x64: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_max_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_max_ps( selector1, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_max_ps( selector1, c_float_0p3 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + c_float_1p1 = _mm512_max_ps( selector1, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_max_ps( selector1, c_float_1p2 ); + + // c[1,48-63] + c_float_1p3 = _mm512_max_ps( selector1, c_float_1p3 ); + + // c[2,0-15] + c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); + + // c[2,16-31] + c_float_2p1 = _mm512_max_ps( selector1, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_max_ps( selector1, c_float_2p2 ); + + // c[2,48-63] + c_float_2p3 = _mm512_max_ps( selector1, c_float_2p3 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } 
+POST_OPS_RELU_SCALE_3x64: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_0p1) + + // c[0, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_0p2) + + // c[0, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_0p3) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_1p1) + + // c[1, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_1p2) + + // c[1, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_1p3) + + // c[2, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_2p0) + + // c[2, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_2p1) + + // c[2, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_2p2) + + // c[2, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_2p3) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_3x64_DISABLE: + ; // Store the results. // c[0,0-15] @@ -1021,6 +1589,13 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x64) // 2x64 bf16 kernel LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x64) { + static void* post_ops_labels[] = + { + &&POST_OPS_2x64_DISABLE, + &&POST_OPS_BIAS_2x64, + &&POST_OPS_RELU_2x64, + &&POST_OPS_RELU_SCALE_2x64 + }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -1079,7 +1654,7 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x64) // Handle k remainder. if ( k_partial_pieces > 0 ) { - __m512i b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); // Broadcast a[0,kr:kr+2]. memcpy @@ -1088,11 +1663,11 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x64) ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - __m512i a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); - __m512i b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); - __m512i b2 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); - __m512i b3 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 3 ) ); + b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + b2 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + b3 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 3 ) ); // Perform column direction mat-mul with k = 2. 
// c[0,0-63] = a[0,kr:kr+2]*b[kr:kr+2,0-63] @@ -1105,7 +1680,7 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x64) ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - __m512i a_bf16_1 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + a_bf16_1 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); @@ -1178,6 +1753,116 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x64) selector1 = _mm512_mul_ps( selector2, selector1 ); c_float_1p3 = _mm512_add_ps( selector1, c_float_1p3 ); } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_2x64: + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + __m512 selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 2 * 16 ) ); + __m512 selector4 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 3 * 16 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector3, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_add_ps( selector4, c_float_0p3 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector3, c_float_1p2 ); + + // c[1,48-63] + c_float_1p3 = _mm512_add_ps( selector4, c_float_1p3 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_2x64: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_max_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_max_ps( selector1, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_max_ps( selector1, c_float_0p3 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + c_float_1p1 = _mm512_max_ps( selector1, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_max_ps( selector1, c_float_1p2 ); + + // c[1,48-63] + c_float_1p3 = _mm512_max_ps( selector1, c_float_1p3 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_2x64: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_0p1) + + // c[0, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_0p2) + + // c[0, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_0p3) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_1p1) + + // c[1, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_1p2) + + // c[1, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_1p3) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_2x64_DISABLE: + ; // Store the results. 
// c[0,0-15] @@ -1208,16 +1893,23 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x64) // 1x64 bf16 kernel LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1x64) { + static void* post_ops_labels[] = + { + &&POST_OPS_1x64_DISABLE, + &&POST_OPS_BIAS_1x64, + &&POST_OPS_RELU_1x64, + &&POST_OPS_RELU_SCALE_1x64 + }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; int32_t a_kfringe_buf = 0; // Registers to use for accumulating C. - __m512 c_float32_0p0 = _mm512_setzero_ps(); - __m512 c_float32_0p1 = _mm512_setzero_ps(); - __m512 c_float32_0p2 = _mm512_setzero_ps(); - __m512 c_float32_0p3 = _mm512_setzero_ps(); + __m512 c_float_0p0 = _mm512_setzero_ps(); + __m512 c_float_0p1 = _mm512_setzero_ps(); + __m512 c_float_0p2 = _mm512_setzero_ps(); + __m512 c_float_0p3 = _mm512_setzero_ps(); for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) { @@ -1232,16 +1924,16 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1x64) // Perform column direction mat-mul with k = 2. // c[0,0-63] = a[0,kr:kr+2]*b[kr:kr+2,0-63] - c_float32_0p0 = _mm512_dpbf16_ps( c_float32_0p0, a_bf16_0, b0 ); - c_float32_0p1 = _mm512_dpbf16_ps( c_float32_0p1, a_bf16_0, b1 ); - c_float32_0p2 = _mm512_dpbf16_ps( c_float32_0p2, a_bf16_0, b2 ); - c_float32_0p3 = _mm512_dpbf16_ps( c_float32_0p3, a_bf16_0, b3 ); + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + c_float_0p3 = _mm512_dpbf16_ps( c_float_0p3, a_bf16_0, b3 ); } // Handle k remainder. if ( k_partial_pieces > 0 ) { - __m512i b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); + __m512bh b0 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 0 ) ); // Broadcast a[0,kr:kr+2]. memcpy @@ -1250,18 +1942,18 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1x64) ( a + ( rs_a * 0 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); - __m512i a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); + __m512bh a_bf16_0 = (__m512bh)_mm512_set1_epi32( a_kfringe_buf ); - __m512i b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); - __m512i b2 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); - __m512i b3 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 3 ) ); + __m512bh b1 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 1 ) ); + __m512bh b2 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 2 ) ); + __m512bh b3 = (__m512bh)_mm512_loadu_epi16( b + ( rs_b * k_full_pieces ) + ( cs_b * 3 ) ); // Perform column direction mat-mul with k = 2. 
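/* Editorial sketch, not part of the patch: every __m512bh register holds
 * 32 bf16 values, i.e. 16 (k, k+1) pairs, and _mm512_dpbf16_ps adds each
 * pair's dot product into the corresponding f32 lane, which is why one
 * broadcast A pair against four packed B registers updates a 1x64 row
 * with k = 2. When only one k element remains, the memcpy above
 * zero-pads the pair before broadcasting. In isolation, and assuming
 * bf16 values stored as 16-bit integers, the remainder update is: */
#include <immintrin.h>
#include <stdint.h>
#include <string.h>

static inline __m512 bf16_k2_update
     (
       __m512          acc,    /* 16 f32 accumulators                 */
       const uint16_t* a_ptr,  /* k_left (1 or 2) bf16 values from A  */
       const uint16_t* b_ptr,  /* 2 x 16 packed bf16 values from B    */
       int             k_left
     )
{
	int32_t pair = 0;
	memcpy( &pair, a_ptr, k_left * sizeof( uint16_t ) );    /* zero-pad pair */
	__m512bh a2 = ( __m512bh )_mm512_set1_epi32( pair );
	__m512bh b2 = ( __m512bh )_mm512_loadu_epi16( b_ptr );
	return _mm512_dpbf16_ps( acc, a2, b2 );                 /* k = 2 update  */
}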
// c[0,0-63] = a[0,kr:kr+2]*b[kr:kr+2,0-63] - c_float32_0p0 = _mm512_dpbf16_ps( c_float32_0p0, a_bf16_0, b0 ); - c_float32_0p1 = _mm512_dpbf16_ps( c_float32_0p1, a_bf16_0, b1 ); - c_float32_0p2 = _mm512_dpbf16_ps( c_float32_0p2, a_bf16_0, b2 ); - c_float32_0p3 = _mm512_dpbf16_ps( c_float32_0p3, a_bf16_0, b3 ); + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); + c_float_0p3 = _mm512_dpbf16_ps( c_float_0p3, a_bf16_0, b3 ); } // Load alpha and beta @@ -1269,10 +1961,10 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1x64) __m512 selector2 = _mm512_set1_ps( beta ); // Scale by alpha - c_float32_0p0 = _mm512_mul_ps( selector1, c_float32_0p0 ); - c_float32_0p1 = _mm512_mul_ps( selector1, c_float32_0p1 ); - c_float32_0p2 = _mm512_mul_ps( selector1, c_float32_0p2 ); - c_float32_0p3 = _mm512_mul_ps( selector1, c_float32_0p3 ); + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + c_float_0p1 = _mm512_mul_ps( selector1, c_float_0p1 ); + c_float_0p2 = _mm512_mul_ps( selector1, c_float_0p2 ); + c_float_0p3 = _mm512_mul_ps( selector1, c_float_0p3 ); // Scale C by beta. if ( beta != 0) @@ -1280,34 +1972,108 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1x64) // c[0,0-15] selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 0*16 ) ); selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float32_0p0 = _mm512_add_ps( selector1, c_float32_0p0 ); + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); // c[0, 16-31] selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 1*16 ) ); selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float32_0p1 = _mm512_add_ps( selector1, c_float32_0p1 ); + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); // c[0,32-47] selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 2*16 ) ); selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float32_0p2 = _mm512_add_ps( selector1, c_float32_0p2 ); + c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); // c[0,48-63] selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 3*16 ) ); selector1 = _mm512_mul_ps( selector2, selector1 ); - c_float32_0p3 = _mm512_add_ps( selector1, c_float32_0p3 ); + c_float_0p3 = _mm512_add_ps( selector1, c_float_0p3 ); + } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_1x64: + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + __m512 selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 2 * 16 ) ); + __m512 selector4 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 3 * 16 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector3, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_add_ps( selector4, c_float_0p3 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_RELU_1x64: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_max_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_max_ps( selector1, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_max_ps( 
selector1, c_float_0p3 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_1x64: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_0p1) + + // c[0, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_0p2) + + // c[0, 48-63] + RELU_SCALE_OP_F32_AVX512(c_float_0p3) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_1x64_DISABLE: + ; // Store the accumulated results. // c[0,0-15] - _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), c_float32_0p0 ); + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), c_float_0p0 ); // c[0, 16-31] - _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 1*16 ), c_float32_0p1 ); + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 1*16 ), c_float_0p1 ); // c[0,32-47] - _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 2*16 ), c_float32_0p2 ); + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 2*16 ), c_float_0p2 ); // c[0,48-63] - _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 3*16 ), c_float32_0p3 ); + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 3*16 ), c_float_0p3 ); } diff --git a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_mn_fringe_bf16_amd512vnni.c b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_mn_fringe_bf16_amd512vnni.c index 24c4cd0e85..6e864b45c3 100644 --- a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_mn_fringe_bf16_amd512vnni.c +++ b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_mn_fringe_bf16_amd512vnni.c @@ -37,10 +37,18 @@ #include "blis.h" #include "lpgemm_kernels.h" +#include "lpgemm_f32_kern_macros.h" // 5xlt16 bf16 fringe kernel LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5xlt16) { + static void* post_ops_labels[] = + { + &&POST_OPS_5xLT16_DISABLE, + &&POST_OPS_BIAS_5xLT16, + &&POST_OPS_RELU_5xLT16, + &&POST_OPS_RELU_SCALE_5xLT16 + }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -204,6 +212,80 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5xlt16) selector1 = _mm512_mul_ps( selector2, selector1 ); c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_5xLT16: + { + memcpy( buf0, ( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j ), ( n0_rem * sizeof( float ) ) ); + selector1 = _mm512_loadu_ps( buf0 ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_5xLT16: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); + + // c[3,0-15] + c_float_3p0 = _mm512_max_ps( selector1, c_float_3p0 ); + + // c[4,0-15] + c_float_4p0 = _mm512_max_ps( selector1, c_float_4p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_5xLT16: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + 
RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[2, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_2p0) + + // c[3, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_3p0) + + // c[4, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_4p0) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_5xLT16_DISABLE: + ; // Store the results. // c[0,0-15] @@ -242,6 +324,13 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5xlt16) // 4xlt16 bf16 fringe kernel LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4xlt16) { + static void* post_ops_labels[] = + { + &&POST_OPS_4xLT16_DISABLE, + &&POST_OPS_BIAS_4xLT16, + &&POST_OPS_RELU_4xLT16, + &&POST_OPS_RELU_SCALE_4xLT16 + }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -381,7 +470,72 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4xlt16) selector1 = _mm512_mul_ps( selector2, selector1 ); c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); } - + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_4xLT16: + { + memcpy( buf0, ( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j ), ( n0_rem * sizeof( float ) ) ); + selector1 = _mm512_loadu_ps( buf0 ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_4xLT16: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); + + // c[3,0-15] + c_float_3p0 = _mm512_max_ps( selector1, c_float_3p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_4xLT16: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[2, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_2p0) + + // c[3, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_3p0) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_4xLT16_DISABLE: + ; + // Store the results. 
// c[0,0-15] _mm512_storeu_ps( buf0, c_float_0p0 ); @@ -413,6 +567,13 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4xlt16) // 3xlt16 bf16 fringe kernel LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3xlt16) { + static void* post_ops_labels[] = + { + &&POST_OPS_3xLT16_DISABLE, + &&POST_OPS_BIAS_3xLT16, + &&POST_OPS_RELU_3xLT16, + &&POST_OPS_RELU_SCALE_3xLT16 + }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -525,6 +686,62 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3xlt16) selector1 = _mm512_mul_ps( selector2, selector1 ); c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_3xLT16: + { + memcpy( buf0, ( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j ), ( n0_rem * sizeof( float ) ) ); + selector1 = _mm512_loadu_ps( buf0 ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_3xLT16: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_3xLT16: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[2, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_2p0) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_3xLT16_DISABLE: + ; // Store the results. 
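In the *xlt16 kernels above only n0_rem (< 16) columns are valid, so the bias post-op first memcpy's the valid tail of the bias vector into a local scratch buffer and then performs a full-width _mm512_loadu_ps on it, avoiding a read past the end of the caller's bias array; the C loads and stores of these kernels go through the same kind of buffer. A self-contained sketch of that pattern follows (illustrative names; an AVX-512 masked load is an equivalent alternative):

#include <immintrin.h>
#include <string.h>

/* Load n_rem (< 16) floats safely into a full zmm register. */
static inline __m512 load_f32_fringe( const float* src, int n_rem )
{
	float scratch[ 16 ] = { 0.0f };                     /* bounce buffer */
	memcpy( scratch, src, n_rem * sizeof( float ) );
	return _mm512_loadu_ps( scratch );

	/* Masked equivalent:
	   __mmask16 m = ( __mmask16 )( ( 1U << n_rem ) - 1 );
	   return _mm512_maskz_loadu_ps( m, src );          */
}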
// c[0,0-15] @@ -551,6 +768,13 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3xlt16) // 2xlt16 bf16 fringe kernel LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2xlt16) { + static void* post_ops_labels[] = + { + &&POST_OPS_2xLT16_DISABLE, + &&POST_OPS_BIAS_2xLT16, + &&POST_OPS_RELU_2xLT16, + &&POST_OPS_RELU_SCALE_2xLT16 + }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -637,6 +861,53 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2xlt16) selector1 = _mm512_mul_ps( selector2, selector1 ); c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_2xLT16: + { + memcpy( buf0, ( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j ), ( n0_rem * sizeof( float ) ) ); + selector1 = _mm512_loadu_ps( buf0 ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_2xLT16: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_2xLT16: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_2xLT16_DISABLE: + ; // Store the results. // c[0,0-15] @@ -657,6 +928,13 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2xlt16) // 1xlt16 bf16 fringe kernel LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1xlt16) { + static void* post_ops_labels[] = + { + &&POST_OPS_1xLT16_DISABLE, + &&POST_OPS_BIAS_1xLT16, + &&POST_OPS_RELU_1xLT16, + &&POST_OPS_RELU_SCALE_1xLT16 + }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -717,6 +995,44 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1xlt16) selector1 = _mm512_mul_ps( selector2, selector1 ); c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_1xLT16: + { + memcpy( buf0, ( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j ), ( n0_rem * sizeof( float ) ) ); + selector1 = _mm512_loadu_ps( buf0 ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_1xLT16: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_1xLT16: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_1xLT16_DISABLE: + ; // Store the results. 
// c[0,0-15] @@ -731,6 +1047,13 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1xlt16) // 5x16 bf16 kernel LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x16) { + static void* post_ops_labels[] = + { + &&POST_OPS_5x16_DISABLE, + &&POST_OPS_BIAS_5x16, + &&POST_OPS_RELU_5x16, + &&POST_OPS_RELU_SCALE_5x16 + }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -881,6 +1204,80 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x16) selector1 = _mm512_mul_ps( selector2, selector1 ); c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_5x16: + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_5x16: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); + + // c[3,0-15] + c_float_3p0 = _mm512_max_ps( selector1, c_float_3p0 ); + + // c[4,0-15] + c_float_4p0 = _mm512_max_ps( selector1, c_float_4p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_5x16: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[2, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_2p0) + + // c[3, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_3p0) + + // c[4, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_4p0) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_5x16_DISABLE: + ; // Store the results. 
// c[0,0-15] @@ -902,6 +1299,13 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x16) // 4x16 bf16 kernel LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x16) { + static void* post_ops_labels[] = + { + &&POST_OPS_4x16_DISABLE, + &&POST_OPS_BIAS_4x16, + &&POST_OPS_RELU_4x16, + &&POST_OPS_RELU_SCALE_4x16 + }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -1028,6 +1432,71 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x16) selector1 = _mm512_mul_ps( selector2, selector1 ); c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_4x16: + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_4x16: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); + + // c[3,0-15] + c_float_3p0 = _mm512_max_ps( selector1, c_float_3p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_4x16: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[2, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_2p0) + + // c[3, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_3p0) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_4x16_DISABLE: + ; // Store the results. 
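The bias load in the 5x16/4x16 kernels above indexes the bias vector with post_op_c_j, the global column offset of the tile currently being written, which is why the earlier hunks advance post_op_c_j (and post_op_c_i by MR) as the main kernel walks the output. Reduced to its essentials, with illustrative names:

#include <immintrin.h>

/* Add the bias for the 16 columns starting at global column j0 to each
   row's accumulator of one m_rows x 16 tile. */
static inline void add_bias_16cols
     (
       __m512*      acc_rows,  /* per-row f32 accumulators            */
       int          m_rows,
       const float* bias_vec,  /* one value per column of C           */
       int          j0         /* global column offset (post_op_c_j)  */
     )
{
	const __m512 bias = _mm512_loadu_ps( bias_vec + j0 );
	for ( int i = 0; i < m_rows; ++i )
	{
		acc_rows[ i ] = _mm512_add_ps( acc_rows[ i ], bias );
	}
}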
// c[0,0-15] @@ -1046,6 +1515,13 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x16) // 3x16 bf16 kernel LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x16) { + static void* post_ops_labels[] = + { + &&POST_OPS_3x16_DISABLE, + &&POST_OPS_BIAS_3x16, + &&POST_OPS_RELU_3x16, + &&POST_OPS_RELU_SCALE_3x16 + }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -1148,6 +1624,63 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x16) selector1 = _mm512_mul_ps( selector2, selector1 ); c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_3x16: + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_3x16: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_3x16: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[2, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_2p0) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_3x16_DISABLE: + ; + // Store the results. 
// c[0,0-15] @@ -1163,6 +1696,13 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x16) // 2x16 bf16 kernel LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x16) { + static void* post_ops_labels[] = + { + &&POST_OPS_2x16_DISABLE, + &&POST_OPS_BIAS_2x16, + &&POST_OPS_RELU_2x16, + &&POST_OPS_RELU_SCALE_2x16 + }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -1241,6 +1781,53 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x16) selector1 = _mm512_mul_ps( selector2, selector1 ); c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_2x16: + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_2x16: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_2x16: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_2x16_DISABLE: + ; // Store the results. // c[0,0-15] @@ -1253,6 +1840,13 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x16) // 1x16 bf16 kernel LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1x16) { + static void* post_ops_labels[] = + { + &&POST_OPS_1x16_DISABLE, + &&POST_OPS_BIAS_1x16, + &&POST_OPS_RELU_1x16, + &&POST_OPS_RELU_SCALE_1x16 + }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -1307,6 +1901,44 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1x16) selector1 = _mm512_mul_ps( selector2, selector1 ); c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_1x16: + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_1x16: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_1x16: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_1x16_DISABLE: + ; // Store the results. 
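/*
 * [Editorial sketch -- not part of the patch.]  RELU_SCALE_OP_F32_AVX512 is
 * presumably supplied by lpgemm_f32_kern_macros.h (which this patch adds as
 * an include); its body is not visible in this hunk.  Given the surrounding
 * setup -- selector1 holds zero, selector2 holds the broadcast scale from
 * op_args2, and relu_cmp_mask is declared beforehand -- one plausible
 * expansion per accumulator is the leaky-ReLU-style sequence below.  Treat it
 * as an assumption, not the library's actual definition.
 */
#include <immintrin.h>

static inline __m512 sketch_relu_scale_f32( __m512 reg, __m512 zero, __m512 scale )
{
    /* Flag lanes below zero, then multiply only those lanes by the scale. */
    __mmask16 relu_cmp_mask = _mm512_cmp_ps_mask( reg, zero, _CMP_LT_OQ );
    return _mm512_mask_mul_ps( reg, relu_cmp_mask, reg, scale );
}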
// c[0,0-15] @@ -1316,6 +1948,13 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1x16) // 5x32 bf16 kernel LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x32) { + static void* post_ops_labels[] = + { + &&POST_OPS_5x32_DISABLE, + &&POST_OPS_BIAS_5x32, + &&POST_OPS_RELU_5x32, + &&POST_OPS_RELU_SCALE_5x32 + }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -1514,6 +2153,128 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x32) selector1 = _mm512_mul_ps( selector2, selector1 ); c_float_4p1 = _mm512_add_ps( selector1, c_float_4p1 ); } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_5x32: + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector2, c_float_3p1 ); + + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); + + // c[4, 16-31] + c_float_4p1 = _mm512_add_ps( selector2, c_float_4p1 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_5x32: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_max_ps( selector1, c_float_0p1 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + c_float_1p1 = _mm512_max_ps( selector1, c_float_1p1 ); + + // c[2,0-15] + c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); + + // c[2,16-31] + c_float_2p1 = _mm512_max_ps( selector1, c_float_2p1 ); + + // c[3,0-15] + c_float_3p0 = _mm512_max_ps( selector1, c_float_3p0 ); + + // c[3,16-31] + c_float_3p1 = _mm512_max_ps( selector1, c_float_3p1 ); + + // c[4,0-15] + c_float_4p0 = _mm512_max_ps( selector1, c_float_4p0 ); + + // c[4,16-31] + c_float_4p1 = _mm512_max_ps( selector1, c_float_4p1 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_5x32: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_0p1) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_1p1) + + // c[2, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_2p0) + + // c[2, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_2p1) + + // c[3, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_3p0) + + // c[3, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_3p1) + + // c[4, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_4p0) + + // c[4, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_4p1) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_5x32_DISABLE: + ; // Store the results. 
// c[0,0-15] @@ -1550,6 +2311,13 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x32) // 4x32 bf16 kernel LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x32) { + static void* post_ops_labels[] = + { + &&POST_OPS_4x32_DISABLE, + &&POST_OPS_BIAS_4x32, + &&POST_OPS_RELU_4x32, + &&POST_OPS_RELU_SCALE_4x32 + }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -1715,19 +2483,123 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x32) selector1 = _mm512_mul_ps( selector2, selector1 ); c_float_3p1 = _mm512_add_ps( selector1, c_float_3p1 ); } - - // Store the results. - // c[0,0-15] - _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), c_float_0p0 ); + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_4x32: + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); - // c[0, 16-31] - _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 1*16 ), c_float_0p1 ); + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); - // c[1,0-15] - _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 0*16 ), c_float_1p0 ); + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); - // c[1,16-31] - _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 1*16 ), c_float_1p1 ); + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector2, c_float_3p1 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_4x32: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_max_ps( selector1, c_float_0p1 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + c_float_1p1 = _mm512_max_ps( selector1, c_float_1p1 ); + + // c[2,0-15] + c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); + + // c[2,16-31] + c_float_2p1 = _mm512_max_ps( selector1, c_float_2p1 ); + + // c[3,0-15] + c_float_3p0 = _mm512_max_ps( selector1, c_float_3p0 ); + + // c[3,16-31] + c_float_3p1 = _mm512_max_ps( selector1, c_float_3p1 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_4x32: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_0p1) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_1p1) + + // c[2, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_2p0) + + // c[2, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_2p1) + + // c[3, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_3p0) + + // c[3, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_3p1) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_4x32_DISABLE: + ; + + // Store the results. 
+ // c[0,0-15] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), c_float_0p0 ); + + // c[0, 16-31] + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 1*16 ), c_float_0p1 ); + + // c[1,0-15] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 0*16 ), c_float_1p0 ); + + // c[1,16-31] + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 1*16 ), c_float_1p1 ); // c[2,0-15] _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 0*16 ), c_float_2p0 ); @@ -1745,6 +2617,13 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x32) // 3x32 bf16 kernel LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x32) { + static void* post_ops_labels[] = + { + &&POST_OPS_3x32_DISABLE, + &&POST_OPS_BIAS_3x32, + &&POST_OPS_RELU_3x32, + &&POST_OPS_RELU_SCALE_3x32 + }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -1877,6 +2756,92 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x32) selector1 = _mm512_mul_ps( selector2, selector1 ); c_float_2p1 = _mm512_add_ps( selector1, c_float_2p1 ); } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_3x32: + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_3x32: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_max_ps( selector1, c_float_0p1 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + c_float_1p1 = _mm512_max_ps( selector1, c_float_1p1 ); + + // c[2,0-15] + c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); + + // c[2,16-31] + c_float_2p1 = _mm512_max_ps( selector1, c_float_2p1 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_3x32: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_0p1) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_1p1) + + // c[2, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_2p0) + + // c[2, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_2p1) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_3x32_DISABLE: + ; // Store the results. 
// c[0,0-15] @@ -1901,6 +2866,13 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x32) // 2x32 bf16 kernel LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x32) { + static void* post_ops_labels[] = + { + &&POST_OPS_2x32_DISABLE, + &&POST_OPS_BIAS_2x32, + &&POST_OPS_RELU_2x32, + &&POST_OPS_RELU_SCALE_2x32 + }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -2000,6 +2972,74 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x32) selector1 = _mm512_mul_ps( selector2, selector1 ); c_float_1p1 = _mm512_add_ps( selector1, c_float_1p1 ); } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_2x32: + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_2x32: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_max_ps( selector1, c_float_0p1 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + c_float_1p1 = _mm512_max_ps( selector1, c_float_1p1 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_2x32: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_0p1) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_1p1) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_2x32_DISABLE: + ; // Store the results. 
// c[0,0-15] @@ -2018,6 +3058,13 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x32) // 1x32 bf16 kernel LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1x32) { + static void* post_ops_labels[] = + { + &&POST_OPS_1x32_DISABLE, + &&POST_OPS_BIAS_1x32, + &&POST_OPS_RELU_1x32, + &&POST_OPS_RELU_SCALE_1x32 + }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -2084,6 +3131,56 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1x32) selector1 = _mm512_mul_ps( selector2, selector1 ); c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_1x32: + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_1x32: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_max_ps( selector1, c_float_0p1 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_1x32: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_0p1) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_1x32_DISABLE: + ; // Store the results. 
// c[0,0-15] @@ -2096,6 +3193,13 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1x32) // 5x48 bf16 kernel LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x48) { + static void* post_ops_labels[] = + { + &&POST_OPS_5x48_DISABLE, + &&POST_OPS_BIAS_5x48, + &&POST_OPS_RELU_5x48, + &&POST_OPS_RELU_SCALE_5x48 + }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -2342,6 +3446,176 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x48) selector1 = _mm512_mul_ps( selector2, selector1 ); c_float_4p2 = _mm512_add_ps( selector1, c_float_4p2 ); } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_5x48: + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + __m512 selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 2 * 16 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector3, c_float_0p2 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector3, c_float_1p2 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector2, c_float_3p1 ); + + // c[3,32-47] + c_float_3p2 = _mm512_add_ps( selector3, c_float_3p2 ); + + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); + + // c[4, 16-31] + c_float_4p1 = _mm512_add_ps( selector2, c_float_4p1 ); + + // c[4,32-47] + c_float_4p2 = _mm512_add_ps( selector3, c_float_4p2 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_5x48: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_max_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_max_ps( selector1, c_float_0p2 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + c_float_1p1 = _mm512_max_ps( selector1, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_max_ps( selector1, c_float_1p2 ); + + // c[2,0-15] + c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); + + // c[2,16-31] + c_float_2p1 = _mm512_max_ps( selector1, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_max_ps( selector1, c_float_2p2 ); + + // c[3,0-15] + c_float_3p0 = _mm512_max_ps( selector1, c_float_3p0 ); + + // c[3,16-31] + c_float_3p1 = _mm512_max_ps( selector1, c_float_3p1 ); + + // c[3,32-47] + c_float_3p2 = _mm512_max_ps( selector1, c_float_3p2 ); + + // c[4,0-15] + c_float_4p0 = _mm512_max_ps( selector1, c_float_4p0 ); + + // c[4,16-31] + c_float_4p1 = _mm512_max_ps( selector1, c_float_4p1 ); + + // c[4,32-47] + c_float_4p2 = _mm512_max_ps( selector1, c_float_4p2 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_5x48: + { + selector1 = _mm512_setzero_ps(); + selector2 = + 
_mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_0p1) + + // c[0, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_0p2) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_1p1) + + // c[1, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_1p2) + + // c[2, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_2p0) + + // c[2, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_2p1) + + // c[2, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_2p2) + + // c[3, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_3p0) + + // c[3, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_3p1) + + // c[3, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_3p2) + + // c[4, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_4p0) + + // c[4, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_4p1) + + // c[4, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_4p2) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_5x48_DISABLE: + ; // Store the results. // c[0,0-15] @@ -2393,6 +3667,13 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x48) // 4x48 bf16 kernel LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x48) { + static void* post_ops_labels[] = + { + &&POST_OPS_4x48_DISABLE, + &&POST_OPS_BIAS_4x48, + &&POST_OPS_RELU_4x48, + &&POST_OPS_RELU_SCALE_4x48 + }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -2597,6 +3878,149 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x48) selector1 = _mm512_mul_ps( selector2, selector1 ); c_float_3p2 = _mm512_add_ps( selector1, c_float_3p2 ); } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_4x48: + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + __m512 selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 2 * 16 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector3, c_float_0p2 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector3, c_float_1p2 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector2, c_float_3p1 ); + + // c[3,32-47] + c_float_3p2 = _mm512_add_ps( selector3, c_float_3p2 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_4x48: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_max_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_max_ps( selector1, c_float_0p2 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + c_float_1p1 = _mm512_max_ps( selector1, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = 
_mm512_max_ps( selector1, c_float_1p2 ); + + // c[2,0-15] + c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); + + // c[2,16-31] + c_float_2p1 = _mm512_max_ps( selector1, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_max_ps( selector1, c_float_2p2 ); + + // c[3,0-15] + c_float_3p0 = _mm512_max_ps( selector1, c_float_3p0 ); + + // c[3,16-31] + c_float_3p1 = _mm512_max_ps( selector1, c_float_3p1 ); + + // c[3,32-47] + c_float_3p2 = _mm512_max_ps( selector1, c_float_3p2 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_4x48: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_0p1) + + // c[0, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_0p2) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_1p1) + + // c[1, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_1p2) + + // c[2, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_2p0) + + // c[2, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_2p1) + + // c[2, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_2p2) + + // c[3, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_3p0) + + // c[3, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_3p1) + + // c[3, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_3p2) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_4x48_DISABLE: + ; // Store the results. // c[0,0-15] @@ -2639,6 +4063,13 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x48) // 3x48 bf16 kernel LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x48) { + static void* post_ops_labels[] = + { + &&POST_OPS_3x48_DISABLE, + &&POST_OPS_BIAS_3x48, + &&POST_OPS_RELU_3x48, + &&POST_OPS_RELU_SCALE_3x48 + }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -2653,17 +4084,17 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x48) __m512bh a_bf16_0; // Registers to use for accumulating C. - __m512 float_0p0 = _mm512_setzero_ps(); - __m512 float_0p1 = _mm512_setzero_ps(); - __m512 float_0p2 = _mm512_setzero_ps(); + __m512 c_float_0p0 = _mm512_setzero_ps(); + __m512 c_float_0p1 = _mm512_setzero_ps(); + __m512 c_float_0p2 = _mm512_setzero_ps(); - __m512 float_1p0 = _mm512_setzero_ps(); - __m512 float_1p1 = _mm512_setzero_ps(); - __m512 float_1p2 = _mm512_setzero_ps(); + __m512 c_float_1p0 = _mm512_setzero_ps(); + __m512 c_float_1p1 = _mm512_setzero_ps(); + __m512 c_float_1p2 = _mm512_setzero_ps(); - __m512 float_2p0 = _mm512_setzero_ps(); - __m512 float_2p1 = _mm512_setzero_ps(); - __m512 float_2p2 = _mm512_setzero_ps(); + __m512 c_float_2p0 = _mm512_setzero_ps(); + __m512 c_float_2p1 = _mm512_setzero_ps(); + __m512 c_float_2p2 = _mm512_setzero_ps(); for ( dim_t kr = 0; kr < k_full_pieces; kr += 1 ) { @@ -2676,27 +4107,27 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x48) // Perform column direction mat-mul with k = 2. // c[0,0-47] = a[0,kr:kr+2]*b[kr:kr+2,0-47] - float_0p0 = _mm512_dpbf16_ps( float_0p0, a_bf16_0, b0 ); - float_0p1 = _mm512_dpbf16_ps( float_0p1, a_bf16_0, b1 ); - float_0p2 = _mm512_dpbf16_ps( float_0p2, a_bf16_0, b2 ); + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); // Broadcast a[1,kr:kr+2]. 
a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 1 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[1,0-47] = a[1,kr:kr+2]*b[kr:kr+2,0-47] - float_1p0 = _mm512_dpbf16_ps( float_1p0, a_bf16_0, b0 ); - float_1p1 = _mm512_dpbf16_ps( float_1p1, a_bf16_0, b1 ); - float_1p2 = _mm512_dpbf16_ps( float_1p2, a_bf16_0, b2 ); + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 ); + c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_0, b2 ); // Broadcast a[2,kr:kr+2]. a_bf16_0 = (__m512bh)_mm512_set1_epi32( *( int32_t* )( a + ( rs_a * 2 ) + ( cs_a * kr ) ) ); // Perform column direction mat-mul with k = 2. // c[2,0-47] = a[2,kr:kr+2]*b[kr:kr+2,0-47] - float_2p0 = _mm512_dpbf16_ps( float_2p0, a_bf16_0, b0 ); - float_2p1 = _mm512_dpbf16_ps( float_2p1, a_bf16_0, b1 ); - float_2p2 = _mm512_dpbf16_ps( float_2p2, a_bf16_0, b2 ); + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + c_float_2p2 = _mm512_dpbf16_ps( c_float_2p2, a_bf16_0, b2 ); } // Handle k remainder. if ( k_partial_pieces > 0 ) @@ -2711,9 +4142,9 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x48) // Perform column direction mat-mul with k = 2. // c[0,0-47] = a[0,kr:kr+2]*b[kr:kr+2,0-47] - float_0p0 = _mm512_dpbf16_ps( float_0p0, a_bf16_0, b0 ); - float_0p1 = _mm512_dpbf16_ps( float_0p1, a_bf16_0, b1 ); - float_0p2 = _mm512_dpbf16_ps( float_0p2, a_bf16_0, b2 ); + c_float_0p0 = _mm512_dpbf16_ps( c_float_0p0, a_bf16_0, b0 ); + c_float_0p1 = _mm512_dpbf16_ps( c_float_0p1, a_bf16_0, b1 ); + c_float_0p2 = _mm512_dpbf16_ps( c_float_0p2, a_bf16_0, b2 ); // Broadcast a[1,kr:kr+2]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 1 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); @@ -2721,9 +4152,9 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x48) // Perform column direction mat-mul with k = 2. // c[1,0-47] = a[1,kr:kr+2]*b[kr:kr+2,0-47] - float_1p0 = _mm512_dpbf16_ps( float_1p0, a_bf16_0, b0 ); - float_1p1 = _mm512_dpbf16_ps( float_1p1, a_bf16_0, b1 ); - float_1p2 = _mm512_dpbf16_ps( float_1p2, a_bf16_0, b2 ); + c_float_1p0 = _mm512_dpbf16_ps( c_float_1p0, a_bf16_0, b0 ); + c_float_1p1 = _mm512_dpbf16_ps( c_float_1p1, a_bf16_0, b1 ); + c_float_1p2 = _mm512_dpbf16_ps( c_float_1p2, a_bf16_0, b2 ); // Broadcast a[2,kr:kr+2]. memcpy( &a_kfringe_buf, ( a + ( rs_a * 2 ) + ( cs_a * k_full_pieces ) ), ( k_partial_pieces * sizeof( bfloat16 ) ) ); @@ -2731,9 +4162,9 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x48) // Perform column direction mat-mul with k = 2. 
// c[2,0-47] = a[2,kr:kr+2]*b[kr:kr+2,0-47] - float_2p0 = _mm512_dpbf16_ps( float_2p0, a_bf16_0, b0 ); - float_2p1 = _mm512_dpbf16_ps( float_2p1, a_bf16_0, b1 ); - float_2p2 = _mm512_dpbf16_ps( float_2p2, a_bf16_0, b2 ); + c_float_2p0 = _mm512_dpbf16_ps( c_float_2p0, a_bf16_0, b0 ); + c_float_2p1 = _mm512_dpbf16_ps( c_float_2p1, a_bf16_0, b1 ); + c_float_2p2 = _mm512_dpbf16_ps( c_float_2p2, a_bf16_0, b2 ); } // Load alpha and beta @@ -2741,17 +4172,17 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x48) __m512 selector2 = _mm512_set1_ps( beta ); // Scale by alpha - float_0p0 = _mm512_mul_ps( selector1, float_0p0 ); - float_0p1 = _mm512_mul_ps( selector1, float_0p1 ); - float_0p2 = _mm512_mul_ps( selector1, float_0p2 ); + c_float_0p0 = _mm512_mul_ps( selector1, c_float_0p0 ); + c_float_0p1 = _mm512_mul_ps( selector1, c_float_0p1 ); + c_float_0p2 = _mm512_mul_ps( selector1, c_float_0p2 ); - float_1p0 = _mm512_mul_ps( selector1, float_1p0 ); - float_1p1 = _mm512_mul_ps( selector1, float_1p1 ); - float_1p2 = _mm512_mul_ps( selector1, float_1p2 ); + c_float_1p0 = _mm512_mul_ps( selector1, c_float_1p0 ); + c_float_1p1 = _mm512_mul_ps( selector1, c_float_1p1 ); + c_float_1p2 = _mm512_mul_ps( selector1, c_float_1p2 ); - float_2p0 = _mm512_mul_ps( selector1, float_2p0 ); - float_2p1 = _mm512_mul_ps( selector1, float_2p1 ); - float_2p2 = _mm512_mul_ps( selector1, float_2p2 ); + c_float_2p0 = _mm512_mul_ps( selector1, c_float_2p0 ); + c_float_2p1 = _mm512_mul_ps( selector1, c_float_2p1 ); + c_float_2p2 = _mm512_mul_ps( selector1, c_float_2p2 ); // Scale C by beta. if ( beta != 0 ) @@ -2759,81 +4190,204 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x48) // c[0,0-15] selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 0*16 ) ); selector1 = _mm512_mul_ps( selector2, selector1 ); - float_0p0 = _mm512_add_ps( selector1, float_0p0 ); + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); // c[0, 16-31] selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 1*16 ) ); selector1 = _mm512_mul_ps( selector2, selector1 ); - float_0p1 = _mm512_add_ps( selector1, float_0p1 ); + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); // c[0,32-47] selector1 = _mm512_loadu_ps( c + ( rs_c * 0 ) + ( 2*16 ) ); selector1 = _mm512_mul_ps( selector2, selector1 ); - float_0p2 = _mm512_add_ps( selector1, float_0p2 ); + c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); // c[1,0-15] selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 0*16 ) ); selector1 = _mm512_mul_ps( selector2, selector1 ); - float_1p0 = _mm512_add_ps( selector1, float_1p0 ); + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); // c[1,16-31] selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 1*16 ) ); selector1 = _mm512_mul_ps( selector2, selector1 ); - float_1p1 = _mm512_add_ps( selector1, float_1p1 ); + c_float_1p1 = _mm512_add_ps( selector1, c_float_1p1 ); // c[1,32-47] selector1 = _mm512_loadu_ps( c + ( rs_c * 1 ) + ( 2*16 ) ); selector1 = _mm512_mul_ps( selector2, selector1 ); - float_1p2 = _mm512_add_ps( selector1, float_1p2 ); + c_float_1p2 = _mm512_add_ps( selector1, c_float_1p2 ); // c[2,0-15] selector1 = _mm512_loadu_ps( c + ( rs_c * 2 ) + ( 0*16 ) ); selector1 = _mm512_mul_ps( selector2, selector1 ); - float_2p0 = _mm512_add_ps( selector1, float_2p0 ); + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); // c[2,16-31] selector1 = _mm512_loadu_ps( c + ( rs_c * 2 ) + ( 1*16 ) ); selector1 = _mm512_mul_ps( selector2, selector1 ); - float_2p1 = _mm512_add_ps( selector1, float_2p1 ); + c_float_2p1 = 
_mm512_add_ps( selector1, c_float_2p1 ); // c[2,32-47] selector1 = _mm512_loadu_ps( c + ( rs_c * 2 ) + ( 2*16 ) ); selector1 = _mm512_mul_ps( selector2, selector1 ); - float_2p2 = _mm512_add_ps( selector1, float_2p2 ); + c_float_2p2 = _mm512_add_ps( selector1, c_float_2p2 ); } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_3x48: + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + __m512 selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 2 * 16 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector3, c_float_0p2 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector3, c_float_1p2 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_3x48: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_max_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_max_ps( selector1, c_float_0p2 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + c_float_1p1 = _mm512_max_ps( selector1, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_max_ps( selector1, c_float_1p2 ); + + // c[2,0-15] + c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); + + // c[2,16-31] + c_float_2p1 = _mm512_max_ps( selector1, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_max_ps( selector1, c_float_2p2 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_3x48: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_0p1) + + // c[0, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_0p2) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_1p1) + + // c[1, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_1p2) + + // c[2, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_2p0) + + // c[2, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_2p1) + + // c[2, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_2p2) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_3x48_DISABLE: + ; // Store the results. 
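/*
 * [Editorial sketch -- not part of the patch.]  In the 3x48 k-loop above,
 * each _mm512_set1_epi32 broadcast carries two adjacent bf16 elements of A
 * (columns k and k+1), and _mm512_dpbf16_ps accumulates, per 32-bit C lane,
 * the pairwise product a(m,k)*B(k,j) + a(m,k+1)*B(k+1,j).  The kernel's
 * b0/b1/b2 registers are 16-column blocks of packed B holding those k/k+1
 * pairs interleaved; the scalar reference below separates the two k rows and
 * uses plain float in place of bf16 purely for clarity.
 */
static void sketch_bf16_dot2_ref
     (
       const float* a_pair, /* a_pair[0] = A(m,k), a_pair[1] = A(m,k+1)    */
       const float* bk0,    /* 16 columns of packed B, row k               */
       const float* bk1,    /* 16 columns of packed B, row k+1             */
       float*       c       /* 16 accumulator lanes, updated in place      */
     )
{
    for ( int j = 0; j < 16; ++j )
    {
        c[j] += ( a_pair[0] * bk0[j] ) + ( a_pair[1] * bk1[j] );
    }
}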
// c[0,0-15] - _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), float_0p0 ); + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), c_float_0p0 ); // c[0, 16-31] - _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 1*16 ), float_0p1 ); + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 1*16 ), c_float_0p1 ); // c[0,32-47] - _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 2*16 ), float_0p2 ); + _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 2*16 ), c_float_0p2 ); // c[1,0-15] - _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 0*16 ), float_1p0 ); + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 0*16 ), c_float_1p0 ); // c[1,16-31] - _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 1*16 ), float_1p1 ); + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 1*16 ), c_float_1p1 ); // c[1,32-47] - _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 2*16 ), float_1p2 ); + _mm512_storeu_ps( c + ( rs_c * 1 ) + ( 2*16 ), c_float_1p2 ); // c[2,0-15] - _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 0*16 ), float_2p0 ); + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 0*16 ), c_float_2p0 ); // c[2,16-31] - _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 1*16 ), float_2p1 ); + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 1*16 ), c_float_2p1 ); // c[2,32-47] - _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 2*16 ), float_2p2 ); + _mm512_storeu_ps( c + ( rs_c * 2 ) + ( 2*16 ), c_float_2p2 ); } // 2x48 bf16 kernel LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x48) { + static void* post_ops_labels[] = + { + &&POST_OPS_2x48_DISABLE, + &&POST_OPS_BIAS_2x48, + &&POST_OPS_RELU_2x48, + &&POST_OPS_RELU_SCALE_2x48 + }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -2954,6 +4508,95 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x48) selector1 = _mm512_mul_ps( selector2, selector1 ); c_float_1p2 = _mm512_add_ps( selector1, c_float_1p2 ); } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_2x48: + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + __m512 selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 2 * 16 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector3, c_float_0p2 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector3, c_float_1p2 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_2x48: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_max_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_max_ps( selector1, c_float_0p2 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + c_float_1p1 = _mm512_max_ps( selector1, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_max_ps( selector1, c_float_1p2 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_2x48: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[0, 16-31] + 
RELU_SCALE_OP_F32_AVX512(c_float_0p1) + + // c[0, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_0p2) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_1p1) + + // c[1, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_1p2) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_2x48_DISABLE: + ; // Store the results. // c[0,0-15] @@ -2978,6 +4621,13 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x48) // 1x48 bf16 kernel LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1x48) { + static void* post_ops_labels[] = + { + &&POST_OPS_1x48_DISABLE, + &&POST_OPS_BIAS_1x48, + &&POST_OPS_RELU_1x48, + &&POST_OPS_RELU_SCALE_1x48 + }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -3056,6 +4706,68 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1x48) selector1 = _mm512_mul_ps( selector2, selector1 ); c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_1x48: + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + __m512 selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 2 * 16 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector3, c_float_0p2 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_1x48: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_max_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_max_ps( selector1, c_float_0p2 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_1x48: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_0p1) + + // c[0, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_0p2) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_1x48_DISABLE: + ; // Store the results. 
// c[0,0-15] diff --git a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_n_fringe_bf16_amd512vnni.c b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_n_fringe_bf16_amd512vnni.c index 1908aa0c3f..f796a7d0ea 100644 --- a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_n_fringe_bf16_amd512vnni.c +++ b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_n_fringe_bf16_amd512vnni.c @@ -36,10 +36,18 @@ #include "blis.h" #include "lpgemm_kernels.h" +#include "lpgemm_f32_kern_macros.h" // 6xlt16 bf16 fringe kernel LPGEMM_N_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6xlt16) { + static void* post_ops_labels[] = + { + &&POST_OPS_6xLT16_DISABLE, + &&POST_OPS_BIAS_6xLT16, + &&POST_OPS_RELU_6xLT16, + &&POST_OPS_RELU_SCALE_6xLT16 + }; dim_t MR = 6; dim_t m_full_pieces = m0 / MR; dim_t m_full_pieces_loop_limit = m_full_pieces * MR; @@ -272,6 +280,89 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6xlt16) selector1 = _mm512_mul_ps( selector2, selector1 ); c_float_5p0 = _mm512_add_ps( selector1, c_float_5p0 ); } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_6xLT16: + { + memcpy( buf0, ( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j ), ( n0_rem * sizeof( float ) ) ); + selector1 = _mm512_loadu_ps( buf0 ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); + + // c[5,0-15] + c_float_5p0 = _mm512_add_ps( selector1, c_float_5p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_6xLT16: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); + + // c[3,0-15] + c_float_3p0 = _mm512_max_ps( selector1, c_float_3p0 ); + + // c[4,0-15] + c_float_4p0 = _mm512_max_ps( selector1, c_float_4p0 ); + + // c[5,0-15] + c_float_5p0 = _mm512_max_ps( selector1, c_float_5p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_6xLT16: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[2, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_2p0) + + // c[3, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_3p0) + + // c[4, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_4p0) + + // c[5, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_5p0) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_6xLT16_DISABLE: + ; // Store the results. 
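/*
 * [Editorial note -- not part of the patch.]  In the 6xlt16 bias section
 * above, the n0_rem (< 16) bias floats are staged into buf0 with memcpy and
 * then read with a full _mm512_loadu_ps, matching how this kernel already
 * handles partial rows of C.  An alternative formulation -- offered only as a
 * sketch, not as a suggested change to the patch -- would load the remainder
 * directly through a mask:
 */
#include <immintrin.h>

static inline __m512 sketch_load_partial_f32( const float* src, int n_rem )
{
    /* Assumes 0 < n_rem < 16; lanes past n_rem come back zeroed. */
    __mmask16 load_mask = (__mmask16)( ( 1U << n_rem ) - 1U );
    return _mm512_maskz_loadu_ps( load_mask, src );
}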
// c[0,0-15] @@ -312,6 +403,7 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6xlt16) memcpy( c + ( rs_c * ( ir + 5 ) ) + ( 0*16 ), buf5, ( n0_rem * sizeof( float ) ) ); a = a + ( MR * ps_a ); + post_op_c_i += MR; } if ( m_partial_pieces > 0 ) @@ -397,6 +489,13 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6xlt16) // 6x16 bf16 fringe kernel LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x16) { + static void* post_ops_labels[] = + { + &&POST_OPS_6x16_DISABLE, + &&POST_OPS_BIAS_6x16, + &&POST_OPS_RELU_6x16, + &&POST_OPS_RELU_SCALE_6x16 + }; dim_t MR = 6; dim_t m_full_pieces = m0 / MR; dim_t m_full_pieces_loop_limit = m_full_pieces * MR; @@ -614,6 +713,89 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x16) selector1 = _mm512_mul_ps( selector2, selector1 ); c_float_5p0 = _mm512_add_ps( selector1, c_float_5p0 ); } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_6x16: + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); + + // c[5,0-15] + c_float_5p0 = _mm512_add_ps( selector1, c_float_5p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_6x16: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); + + // c[3,0-15] + c_float_3p0 = _mm512_max_ps( selector1, c_float_3p0 ); + + // c[4,0-15] + c_float_4p0 = _mm512_max_ps( selector1, c_float_4p0 ); + + // c[5,0-15] + c_float_5p0 = _mm512_max_ps( selector1, c_float_5p0 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_6x16: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[2, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_2p0) + + // c[3, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_3p0) + + // c[4, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_4p0) + + // c[5, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_5p0) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_6x16_DISABLE: + ; // Store the results. 
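/*
 * [Editorial note -- not part of the patch.]  The "post_op_c_i += MR;" lines
 * added after each full MR-row panel advance the kernels' running row offset
 * into C alongside the existing "a = a + ( MR * ps_a );" pointer bump.  None
 * of the three post-ops shown here (bias, relu, relu-scale) index by output
 * row, so this appears to be bookkeeping for row-indexed post-ops handled
 * elsewhere; that reading is inferred from the visible context only, since
 * the variable's consumers are outside this hunk.
 */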
// c[0,0-15] @@ -635,6 +817,7 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x16) _mm512_storeu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 0*16 ), c_float_5p0 ); a = a + ( MR * ps_a ); + post_op_c_i += MR; } if ( m_partial_pieces > 0 ) @@ -720,6 +903,13 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x16) // 6x32 bf16 fringe kernel LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x32) { + static void* post_ops_labels[] = + { + &&POST_OPS_6x32_DISABLE, + &&POST_OPS_BIAS_6x32, + &&POST_OPS_RELU_6x32, + &&POST_OPS_RELU_SCALE_6x32 + }; dim_t MR = 6; dim_t m_full_pieces = m0 / MR; dim_t m_full_pieces_loop_limit = m_full_pieces * MR; @@ -992,6 +1182,146 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x32) selector1 = _mm512_mul_ps( selector2, selector1 ); c_float_5p1 = _mm512_add_ps( selector1, c_float_5p1 ); } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_6x32: + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector2, c_float_3p1 ); + + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); + + // c[4, 16-31] + c_float_4p1 = _mm512_add_ps( selector2, c_float_4p1 ); + + // c[5,0-15] + c_float_5p0 = _mm512_add_ps( selector1, c_float_5p0 ); + + // c[5, 16-31] + c_float_5p1 = _mm512_add_ps( selector2, c_float_5p1 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_6x32: + { + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_max_ps( selector1, c_float_0p1 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + c_float_1p1 = _mm512_max_ps( selector1, c_float_1p1 ); + + // c[2,0-15] + c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); + + // c[2,16-31] + c_float_2p1 = _mm512_max_ps( selector1, c_float_2p1 ); + + // c[3,0-15] + c_float_3p0 = _mm512_max_ps( selector1, c_float_3p0 ); + + // c[3,16-31] + c_float_3p1 = _mm512_max_ps( selector1, c_float_3p1 ); + + // c[4,0-15] + c_float_4p0 = _mm512_max_ps( selector1, c_float_4p0 ); + + // c[4,16-31] + c_float_4p1 = _mm512_max_ps( selector1, c_float_4p1 ); + + // c[5,0-15] + c_float_5p0 = _mm512_max_ps( selector1, c_float_5p0 ); + + // c[5,16-31] + c_float_5p1 = _mm512_max_ps( selector1, c_float_5p1 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_6x32: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_0p1) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[1, 16-31] + 
RELU_SCALE_OP_F32_AVX512(c_float_1p1) + + // c[2, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_2p0) + + // c[2, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_2p1) + + // c[3, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_3p0) + + // c[3, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_3p1) + + // c[4, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_4p0) + + // c[4, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_4p1) + + // c[5, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_5p0) + + // c[5, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_5p1) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_6x32_DISABLE: + ; // Store the results. // c[0,0-15] @@ -1031,6 +1361,7 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x32) _mm512_storeu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 1*16 ), c_float_5p1 ); a = a + ( MR * ps_a ); + post_op_c_i += MR; } if ( m_partial_pieces > 0 ) @@ -1116,6 +1447,13 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x32) // 6x48 bf16 fringe kernel LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x48) { + static void* post_ops_labels[] = + { + &&POST_OPS_6x48_DISABLE, + &&POST_OPS_BIAS_6x48, + &&POST_OPS_RELU_6x48, + &&POST_OPS_RELU_SCALE_6x48 + }; dim_t MR = 6; dim_t m_full_pieces = m0 / MR; dim_t m_full_pieces_loop_limit = m_full_pieces * MR; @@ -1447,8 +1785,205 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x48) selector1 = _mm512_loadu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 2*16 ) ); selector1 = _mm512_mul_ps( selector2, selector1 ); c_float_5p2 = _mm512_add_ps( selector1, c_float_5p2 ); - } + // Post Ops + lpgemm_post_op* post_ops_list_temp = post_ops_list; + POST_OP_LABEL_LASTK_SAFE_JUMP +POST_OPS_BIAS_6x48: + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + __m512 selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 2 * 16 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector3, c_float_0p2 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector3, c_float_1p2 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector2, c_float_3p1 ); + + // c[3,32-47] + c_float_3p2 = _mm512_add_ps( selector3, c_float_3p2 ); + + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); + + // c[4, 16-31] + c_float_4p1 = _mm512_add_ps( selector2, c_float_4p1 ); + + // c[4,32-47] + c_float_4p2 = _mm512_add_ps( selector3, c_float_4p2 ); + + // c[5,0-15] + c_float_5p0 = _mm512_add_ps( selector1, c_float_5p0 ); + + // c[5, 16-31] + c_float_5p1 = _mm512_add_ps( selector2, c_float_5p1 ); + + // c[5,32-47] + c_float_5p2 = _mm512_add_ps( selector3, c_float_5p2 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_6x48: + { + //printf("relu\n"); + selector1 = _mm512_setzero_ps(); + + // c[0,0-15] + c_float_0p0 = _mm512_max_ps( 
selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_max_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_max_ps( selector1, c_float_0p2 ); + + // c[1,0-15] + c_float_1p0 = _mm512_max_ps( selector1, c_float_1p0 ); + + // c[1,16-31] + c_float_1p1 = _mm512_max_ps( selector1, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_max_ps( selector1, c_float_1p2 ); + + // c[2,0-15] + c_float_2p0 = _mm512_max_ps( selector1, c_float_2p0 ); + + // c[2,16-31] + c_float_2p1 = _mm512_max_ps( selector1, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_max_ps( selector1, c_float_2p2 ); + + // c[3,0-15] + c_float_3p0 = _mm512_max_ps( selector1, c_float_3p0 ); + + // c[3,16-31] + c_float_3p1 = _mm512_max_ps( selector1, c_float_3p1 ); + + // c[3,32-47] + c_float_3p2 = _mm512_max_ps( selector1, c_float_3p2 ); + + // c[4,0-15] + c_float_4p0 = _mm512_max_ps( selector1, c_float_4p0 ); + + // c[4,16-31] + c_float_4p1 = _mm512_max_ps( selector1, c_float_4p1 ); + + // c[4,32-47] + c_float_4p2 = _mm512_max_ps( selector1, c_float_4p2 ); + + // c[5,0-15] + c_float_5p0 = _mm512_max_ps( selector1, c_float_5p0 ); + + // c[5,16-31] + c_float_5p1 = _mm512_max_ps( selector1, c_float_5p1 ); + + // c[5,32-47] + c_float_5p2 = _mm512_max_ps( selector1, c_float_5p2 ); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_RELU_SCALE_6x48: + { + selector1 = _mm512_setzero_ps(); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args2 ) ); + + __mmask16 relu_cmp_mask; + + // c[0, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_0p0) + + // c[0, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_0p1) + + // c[0, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_0p2) + + // c[1, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_1p0) + + // c[1, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_1p1) + + // c[1, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_1p2) + + // c[2, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_2p0) + + // c[2, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_2p1) + + // c[2, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_2p2) + + // c[3, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_3p0) + + // c[3, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_3p1) + + // c[3, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_3p2) + + // c[4, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_4p0) + + // c[4, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_4p1) + + // c[4, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_4p2) + + // c[5, 0-15] + RELU_SCALE_OP_F32_AVX512(c_float_5p0) + + // c[5, 16-31] + RELU_SCALE_OP_F32_AVX512(c_float_5p1) + + // c[5, 32-47] + RELU_SCALE_OP_F32_AVX512(c_float_5p2) + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_6x48_DISABLE: + ; // Store the results. // c[0,0-15] @@ -1506,6 +2041,7 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x48) _mm512_storeu_ps( c + ( rs_c * ( ir + 5 ) ) + ( 2*16 ), c_float_5p2 ); a = a + ( MR * ps_a ); + post_op_c_i += MR; } From 5c42afada856dff98aed43284c210d800bf91a06 Mon Sep 17 00:00:00 2001 From: Dipal M Zambare Date: Tue, 30 Aug 2022 12:21:38 +0530 Subject: [PATCH 209/243] Revert "CBLAS/BLAS interface decoupling for level 3 APIs" This reverts commit d925ebeb06fefae9d77f1fbcd9e0029d601aaa29. 
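
For reference, the interface decoupling that this commit undoes split each level 3 BLAS entry point into two symbols: a *_blis_impl routine holding the BLIS-backed implementation, and a thin Fortran-callable wrapper that only forwards to it, with the CBLAS layer's F77_* macros in cblas_f77.h (and the uppercase/no-underscore aliases in bli_util_api_wrap.c) pointed at the *_blis_impl symbols. The sketch below condenses that pattern from the removed hunks further down; it is illustrative only, and the f77_char/f77_int types are the Fortran-compatibility typedefs assumed to come from blis.h.

    #include "blis.h"  /* assumed to provide f77_char, f77_int */

    /* Implementation symbol introduced by the decoupling (removed by this revert);
       its body contains the full BLIS-backed dgemm path. */
    void dgemm_blis_impl
         (
           const f77_char* transa, const f77_char* transb,
           const f77_int*  m, const f77_int* n, const f77_int* k,
           const double*   alpha,
           const double*   a, const f77_int* lda,
           const double*   b, const f77_int* ldb,
           const double*   beta,
                 double*   c, const f77_int* ldc
         );

    /* Fortran BLAS symbol: under the decoupling, a pass-through wrapper. */
    void dgemm_
         (
           const f77_char* transa, const f77_char* transb,
           const f77_int*  m, const f77_int* n, const f77_int* k,
           const double*   alpha,
           const double*   a, const f77_int* lda,
           const double*   b, const f77_int* ldb,
           const double*   beta,
                 double*   c, const f77_int* ldc
         )
    {
        /* Forward every argument unchanged to the implementation symbol. */
        dgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda,
                         b, ldb, beta, c, ldc );
    }

    /* CBLAS side (cblas_f77.h) pointed its macro at the implementation:
     *     #define F77_dgemm dgemm_blis_impl
     * After this revert the macro targets the Fortran symbol again:
     *     #define F77_dgemm dgemm_
     */

After the revert, the wrapper layer is gone and dgemm_ (and its PASTEF77-generated siblings) is again the single BLIS-backed definition, which is what the hunks below restore.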
Change-Id: I2e842b29c1fedbe14bf913949cf978f3e7515ff3 --- frame/compat/bla_gemm.c | 41 +----- frame/compat/bla_gemm.h | 13 -- frame/compat/bla_gemm3m.c | 40 +----- frame/compat/bla_gemm3m.h | 15 +- frame/compat/bla_gemm_amd.c | 82 +---------- frame/compat/bla_gemmt.c | 40 +----- frame/compat/bla_gemmt.h | 15 +- frame/compat/bla_hemm.c | 38 +---- frame/compat/bla_hemm.h | 13 -- frame/compat/bla_her2k.c | 40 +----- frame/compat/bla_her2k.h | 13 -- frame/compat/bla_herk.c | 38 +---- frame/compat/bla_herk.h | 12 -- frame/compat/bla_symm.c | 38 +---- frame/compat/bla_symm.h | 13 -- frame/compat/bla_syr2k.c | 38 +---- frame/compat/bla_syr2k.h | 13 -- frame/compat/bla_syrk.c | 36 +---- frame/compat/bla_syrk.h | 12 -- frame/compat/bla_trmm.c | 38 +---- frame/compat/bla_trmm.h | 13 -- frame/compat/bla_trsm.c | 38 +---- frame/compat/bla_trsm.h | 13 -- frame/compat/bla_trsm_amd.c | 108 ++------------ frame/compat/cblas/src/cblas_f77.h | 74 +++++----- frame/include/bli_macro_defs.h | 4 +- frame/util/bli_util_api_wrap.c | 218 ++++++++++++++--------------- 27 files changed, 210 insertions(+), 846 deletions(-) diff --git a/frame/compat/bla_gemm.c b/frame/compat/bla_gemm.c index 78c7f5f4b6..931c80243a 100644 --- a/frame/compat/bla_gemm.c +++ b/frame/compat/bla_gemm.c @@ -44,7 +44,7 @@ #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ -void PASTEF77S(ch,blasname) \ +void PASTEF77(ch,blasname) \ ( \ const f77_char* transa, \ const f77_char* transb, \ @@ -136,30 +136,14 @@ void PASTEF77S(ch,blasname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ /* Finalize BLIS. */ \ bli_finalize_auto(); \ -} \ -void PASTEF77(ch,blasname) \ - ( \ - const f77_char* transa, \ - const f77_char* transb, \ - const f77_int* m, \ - const f77_int* n, \ - const f77_int* k, \ - const ftype* alpha, \ - const ftype* a, const f77_int* lda, \ - const ftype* b, const f77_int* ldb, \ - const ftype* beta, \ - ftype* c, const f77_int* ldc \ - ) \ -{ \ - PASTEF77S(ch,blasname) ( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc ); \ -} \ +} #else #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ -void PASTEF77S(ch,blasname) \ +void PASTEF77(ch,blasname) \ ( \ const f77_char* transa, \ const f77_char* transb, \ @@ -334,24 +318,7 @@ void PASTEF77S(ch,blasname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ /* Finalize BLIS. 
*/ \ bli_finalize_auto(); \ -} \ -void PASTEF77(ch,blasname) \ - ( \ - const f77_char* transa, \ - const f77_char* transb, \ - const f77_int* m, \ - const f77_int* n, \ - const f77_int* k, \ - const ftype* alpha, \ - const ftype* a, const f77_int* lda, \ - const ftype* b, const f77_int* ldb, \ - const ftype* beta, \ - ftype* c, const f77_int* ldc \ - ) \ -{ \ - PASTEF77S(ch,blasname) ( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc ); \ -} \ - +} #endif #ifdef BLIS_ENABLE_BLAS diff --git a/frame/compat/bla_gemm.h b/frame/compat/bla_gemm.h index d8fe6ddb94..c9ea83149a 100644 --- a/frame/compat/bla_gemm.h +++ b/frame/compat/bla_gemm.h @@ -41,19 +41,6 @@ #define GENTPROT( ftype, ch, blasname ) \ \ BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ - ( \ - const f77_char* transa, \ - const f77_char* transb, \ - const f77_int* m, \ - const f77_int* n, \ - const f77_int* k, \ - const ftype* alpha, \ - const ftype* a, const f77_int* lda, \ - const ftype* b, const f77_int* ldb, \ - const ftype* beta, \ - ftype* c, const f77_int* ldc \ - ); \ -BLIS_EXPORT_BLAS void PASTEF77S(ch,blasname) \ ( \ const f77_char* transa, \ const f77_char* transb, \ diff --git a/frame/compat/bla_gemm3m.c b/frame/compat/bla_gemm3m.c index 4ecbba5551..665c8643dd 100644 --- a/frame/compat/bla_gemm3m.c +++ b/frame/compat/bla_gemm3m.c @@ -44,7 +44,7 @@ #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ -void PASTEF77S(ch,blasname) \ +void PASTEF77(ch,blasname) \ ( \ const f77_char* transa, \ const f77_char* transb, \ @@ -131,30 +131,14 @@ void PASTEF77S(ch,blasname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ /* Finalize BLIS. */ \ bli_finalize_auto(); \ -} \ -void PASTEF77(ch,blasname) \ - ( \ - const f77_char* transa, \ - const f77_char* transb, \ - const f77_int* m, \ - const f77_int* n, \ - const f77_int* k, \ - const ftype* alpha, \ - const ftype* a, const f77_int* lda, \ - const ftype* b, const f77_int* ldb, \ - const ftype* beta, \ - ftype* c, const f77_int* ldc \ - ) \ -{ \ - PASTEF77S(ch,blasname) ( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc ); \ -} \ +} #else #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ -void PASTEF77S(ch,blasname) \ +void PASTEF77(ch,blasname) \ ( \ const f77_char* transa, \ const f77_char* transb, \ @@ -256,23 +240,7 @@ void PASTEF77S(ch,blasname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) \ /* Finalize BLIS. */ \ bli_finalize_auto(); \ -} \ -void PASTEF77(ch,blasname) \ - ( \ - const f77_char* transa, \ - const f77_char* transb, \ - const f77_int* m, \ - const f77_int* n, \ - const f77_int* k, \ - const ftype* alpha, \ - const ftype* a, const f77_int* lda, \ - const ftype* b, const f77_int* ldb, \ - const ftype* beta, \ - ftype* c, const f77_int* ldc \ - ) \ -{ \ - PASTEF77S(ch,blasname) ( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc ); \ -} \ +} #endif diff --git a/frame/compat/bla_gemm3m.h b/frame/compat/bla_gemm3m.h index d64e3f199f..1063d85c03 100644 --- a/frame/compat/bla_gemm3m.h +++ b/frame/compat/bla_gemm3m.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -40,19 +40,6 @@ #define GENTPROT( ftype, ch, blasname ) \ \ BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ - ( \ - const f77_char* transa, \ - const f77_char* transb, \ - const f77_int* m, \ - const f77_int* n, \ - const f77_int* k, \ - const ftype* alpha, \ - const ftype* a, const f77_int* lda, \ - const ftype* b, const f77_int* ldb, \ - const ftype* beta, \ - ftype* c, const f77_int* ldc \ - ); \ -BLIS_EXPORT_BLAS void PASTEF77S(ch,blasname) \ ( \ const f77_char* transa, \ const f77_char* transb, \ diff --git a/frame/compat/bla_gemm_amd.c b/frame/compat/bla_gemm_amd.c index bf710bdae7..c9a5c96039 100644 --- a/frame/compat/bla_gemm_amd.c +++ b/frame/compat/bla_gemm_amd.c @@ -44,7 +44,7 @@ #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ -void PASTEF77S(ch,blasname) \ +void PASTEF77(ch,blasname) \ ( \ const f77_char* transa, \ const f77_char* transb, \ @@ -136,32 +136,14 @@ void PASTEF77S(ch,blasname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ /* Finalize BLIS. */ \ bli_finalize_auto(); \ -} \ -\ -void PASTEF77(ch,blasname) \ - ( \ - const f77_char* transa, \ - const f77_char* transb, \ - const f77_int* m, \ - const f77_int* n, \ - const f77_int* k, \ - const ftype* alpha, \ - const ftype* a, const f77_int* lda, \ - const ftype* b, const f77_int* ldb, \ - const ftype* beta, \ - ftype* c, const f77_int* ldc \ - ) \ -{ \ -\ - PASTEF77S(ch,blasname) ( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc ); \ -} \ +} #else #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ -void PASTEF77S(ch,blasname) \ +void PASTEF77(ch,blasname) \ ( \ const f77_char* transa, \ const f77_char* transb, \ @@ -336,30 +318,11 @@ void PASTEF77S(ch,blasname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ /* Finalize BLIS. 
*/ \ bli_finalize_auto(); \ -} \ -\ -void PASTEF77(ch,blasname) \ - ( \ - const f77_char* transa, \ - const f77_char* transb, \ - const f77_int* m, \ - const f77_int* n, \ - const f77_int* k, \ - const ftype* alpha, \ - const ftype* a, const f77_int* lda, \ - const ftype* b, const f77_int* ldb, \ - const ftype* beta, \ - ftype* c, const f77_int* ldc \ - ) \ -{ \ -\ - PASTEF77S(ch,blasname) ( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc ); \ -} \ - +} #endif #ifdef BLIS_ENABLE_BLAS -void dgemm_blis_impl +void dgemm_ ( const f77_char* transa, const f77_char* transb, @@ -695,24 +658,7 @@ void dgemm_blis_impl bli_finalize_auto(); } // end of dgemm_ -void dgemm_ -( - const f77_char* transa, - const f77_char* transb, - const f77_int* m, - const f77_int* n, - const f77_int* k, - const double* alpha, - const double* a, const f77_int* lda, - const double* b, const f77_int* ldb, - const double* beta, - double* c, const f77_int* ldc -) -{ - dgemm_blis_impl(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); -} - -void zgemm_blis_impl +void zgemm_ ( const f77_char* transa, const f77_char* transb, @@ -934,22 +880,6 @@ void zgemm_blis_impl bli_finalize_auto(); }// end of zgemm_ -void zgemm_ - ( - const f77_char* transa, - const f77_char* transb, - const f77_int* m, - const f77_int* n, - const f77_int* k, - const dcomplex* alpha, - const dcomplex* a, const f77_int* lda, - const dcomplex* b, const f77_int* ldb, - const dcomplex* beta, - dcomplex* c, const f77_int* ldc - ) -{ - zgemm_blis_impl(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); -} INSERT_GENTFUNC_BLAS_SC( gemm, gemm ) diff --git a/frame/compat/bla_gemmt.c b/frame/compat/bla_gemmt.c index f8f6fa2de6..7abad40acf 100644 --- a/frame/compat/bla_gemmt.c +++ b/frame/compat/bla_gemmt.c @@ -44,7 +44,7 @@ #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ -void PASTEF77S(ch,blasname) \ +void PASTEF77(ch,blasname) \ ( \ const f77_char* uploc, \ const f77_char* transa, \ @@ -134,30 +134,14 @@ void PASTEF77S(ch,blasname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ /* Finalize BLIS. */ \ bli_finalize_auto(); \ -} \ -void PASTEF77(ch,blasname) \ - ( \ - const f77_char* uploc, \ - const f77_char* transa, \ - const f77_char* transb, \ - const f77_int* n, \ - const f77_int* k, \ - const ftype* alpha, \ - const ftype* a, const f77_int* lda, \ - const ftype* b, const f77_int* ldb, \ - const ftype* beta, \ - ftype* c, const f77_int* ldc \ - ) \ -{ \ - PASTEF77S(ch,blasname) ( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc ); \ -} \ +} #else #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ -void PASTEF77S(ch,blasname) \ +void PASTEF77(ch,blasname) \ ( \ const f77_char* uploc, \ const f77_char* transa, \ @@ -263,23 +247,7 @@ void PASTEF77S(ch,blasname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) \ /* Finalize BLIS. 
*/ \ bli_finalize_auto(); \ -} \ -void PASTEF77(ch,blasname) \ - ( \ - const f77_char* uploc, \ - const f77_char* transa, \ - const f77_char* transb, \ - const f77_int* n, \ - const f77_int* k, \ - const ftype* alpha, \ - const ftype* a, const f77_int* lda, \ - const ftype* b, const f77_int* ldb, \ - const ftype* beta, \ - ftype* c, const f77_int* ldc \ - ) \ -{ \ - PASTEF77S(ch,blasname) ( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc ); \ -} \ +} #endif diff --git a/frame/compat/bla_gemmt.h b/frame/compat/bla_gemmt.h index d4efb995ce..8043d68291 100644 --- a/frame/compat/bla_gemmt.h +++ b/frame/compat/bla_gemmt.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -40,19 +40,6 @@ #define GENTPROT( ftype, ch, blasname ) \ \ BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ - ( \ - const f77_char* uploc, \ - const f77_char* transa, \ - const f77_char* transb, \ - const f77_int* n, \ - const f77_int* k, \ - const ftype* alpha, \ - const ftype* a, const f77_int* lda, \ - const ftype* b, const f77_int* ldb, \ - const ftype* beta, \ - ftype* c, const f77_int* ldc \ - ); \ -BLIS_EXPORT_BLAS void PASTEF77S(ch,blasname) \ ( \ const f77_char* uploc, \ const f77_char* transa, \ diff --git a/frame/compat/bla_hemm.c b/frame/compat/bla_hemm.c index ed3cbb5178..0e003012d2 100644 --- a/frame/compat/bla_hemm.c +++ b/frame/compat/bla_hemm.c @@ -45,7 +45,7 @@ #undef GENTFUNCCO #define GENTFUNCCO( ftype, ftype_r, ch, chr, blasname, blisname ) \ \ -void PASTEF77S(ch,blasname) \ +void PASTEF77(ch,blasname) \ ( \ const f77_char* side, \ const f77_char* uploa, \ @@ -132,29 +132,14 @@ void PASTEF77S(ch,blasname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ /* Finalize BLIS. */ \ bli_finalize_auto(); \ -} \ -void PASTEF77(ch,blasname) \ - ( \ - const f77_char* side, \ - const f77_char* uploa, \ - const f77_int* m, \ - const f77_int* n, \ - const ftype* alpha, \ - const ftype* a, const f77_int* lda, \ - const ftype* b, const f77_int* ldb, \ - const ftype* beta, \ - ftype* c, const f77_int* ldc \ - ) \ -{ \ - PASTEF77S(ch,blasname) ( side, uploa, m, n, alpha, a, lda, b, ldb, beta, c, ldc ); \ - } \ +} #else #undef GENTFUNCCO #define GENTFUNCCO( ftype, ftype_r, ch, chr, blasname, blisname ) \ \ -void PASTEF77S(ch,blasname) \ +void PASTEF77(ch,blasname) \ ( \ const f77_char* side, \ const f77_char* uploa, \ @@ -263,22 +248,7 @@ void PASTEF77S(ch,blasname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ /* Finalize BLIS. */ \ bli_finalize_auto(); \ -} \ -void PASTEF77(ch,blasname) \ - ( \ - const f77_char* side, \ - const f77_char* uploa, \ - const f77_int* m, \ - const f77_int* n, \ - const ftype* alpha, \ - const ftype* a, const f77_int* lda, \ - const ftype* b, const f77_int* ldb, \ - const ftype* beta, \ - ftype* c, const f77_int* ldc \ - ) \ -{ \ - PASTEF77S(ch,blasname) ( side, uploa, m, n, alpha, a, lda, b, ldb, beta, c, ldc ); \ - } \ +} #endif diff --git a/frame/compat/bla_hemm.h b/frame/compat/bla_hemm.h index 7054be7c90..711877edee 100644 --- a/frame/compat/bla_hemm.h +++ b/frame/compat/bla_hemm.h @@ -5,7 +5,6 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -41,18 +40,6 @@ #define GENTPROTCO( ftype, ftype_r, ch, chr, blasname ) \ \ BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ - ( \ - const f77_char* side, \ - const f77_char* uploa, \ - const f77_int* m, \ - const f77_int* n, \ - const ftype* alpha, \ - const ftype* a, const f77_int* lda, \ - const ftype* b, const f77_int* ldb, \ - const ftype* beta, \ - ftype* c, const f77_int* ldc \ - ); \ -BLIS_EXPORT_BLAS void PASTEF77S(ch,blasname) \ ( \ const f77_char* side, \ const f77_char* uploa, \ diff --git a/frame/compat/bla_her2k.c b/frame/compat/bla_her2k.c index cba6432eb3..e21a2cda41 100755 --- a/frame/compat/bla_her2k.c +++ b/frame/compat/bla_her2k.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019 - 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2021, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -45,7 +45,7 @@ #undef GENTFUNCCO #define GENTFUNCCO( ftype, ftype_r, ch, chr, blasname, blisname ) \ \ -void PASTEF77S(ch,blasname) \ +void PASTEF77(ch,blasname) \ ( \ const f77_char* uploc, \ const f77_char* transa, \ @@ -137,29 +137,14 @@ void PASTEF77S(ch,blasname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ /* Finalize BLIS. */ \ bli_finalize_auto(); \ -} \ -void PASTEF77(ch,blasname) \ - ( \ - const f77_char* uploc, \ - const f77_char* transa, \ - const f77_int* m, \ - const f77_int* k, \ - const ftype* alpha, \ - const ftype* a, const f77_int* lda, \ - const ftype* b, const f77_int* ldb, \ - const ftype_r* beta, \ - ftype* c, const f77_int* ldc \ - ) \ -{ \ - PASTEF77S(ch,blasname) ( uploc, transa, m, k, alpha, a, lda, b, ldb, beta, c, ldc ); \ - } \ +} #else #undef GENTFUNCCO #define GENTFUNCCO( ftype, ftype_r, ch, chr, blasname, blisname ) \ \ -void PASTEF77S(ch,blasname) \ +void PASTEF77(ch,blasname) \ ( \ const f77_char* uploc, \ const f77_char* transa, \ @@ -273,22 +258,7 @@ void PASTEF77S(ch,blasname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ /* Finalize BLIS. */ \ bli_finalize_auto(); \ -} \ -void PASTEF77(ch,blasname) \ - ( \ - const f77_char* uploc, \ - const f77_char* transa, \ - const f77_int* m, \ - const f77_int* k, \ - const ftype* alpha, \ - const ftype* a, const f77_int* lda, \ - const ftype* b, const f77_int* ldb, \ - const ftype_r* beta, \ - ftype* c, const f77_int* ldc \ - ) \ -{ \ - PASTEF77S(ch,blasname) ( uploc, transa, m, k, alpha, a, lda, b, ldb, beta, c, ldc ); \ - } \ +} #endif diff --git a/frame/compat/bla_her2k.h b/frame/compat/bla_her2k.h index a3fa413027..c771f78d4c 100644 --- a/frame/compat/bla_her2k.h +++ b/frame/compat/bla_her2k.h @@ -5,7 +5,6 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -41,18 +40,6 @@ #define GENTPROTCO( ftype, ftype_r, ch, chr, blasname ) \ \ BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ - ( \ - const f77_char* uploc, \ - const f77_char* transa, \ - const f77_int* m, \ - const f77_int* k, \ - const ftype* alpha, \ - const ftype* a, const f77_int* lda, \ - const ftype* b, const f77_int* ldb, \ - const ftype_r* beta, \ - ftype* c, const f77_int* ldc \ - ); \ -BLIS_EXPORT_BLAS void PASTEF77S(ch,blasname) \ ( \ const f77_char* uploc, \ const f77_char* transa, \ diff --git a/frame/compat/bla_herk.c b/frame/compat/bla_herk.c index b07ee180cc..36188e6a66 100755 --- a/frame/compat/bla_herk.c +++ b/frame/compat/bla_herk.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019 - 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2021, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -45,7 +45,7 @@ #undef GENTFUNCCO #define GENTFUNCCO( ftype, ftype_r, ch, chr, blasname, blisname ) \ \ -void PASTEF77S(ch,blasname) \ +void PASTEF77(ch,blasname) \ ( \ const f77_char* uploc, \ const f77_char* transa, \ @@ -131,28 +131,14 @@ void PASTEF77S(ch,blasname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ /* Finalize BLIS. */ \ bli_finalize_auto(); \ -} \ -void PASTEF77(ch,blasname) \ - ( \ - const f77_char* uploc, \ - const f77_char* transa, \ - const f77_int* m, \ - const f77_int* k, \ - const ftype_r* alpha, \ - const ftype* a, const f77_int* lda, \ - const ftype_r* beta, \ - ftype* c, const f77_int* ldc \ - ) \ -{ \ - PASTEF77S(ch,blasname) ( uploc, transa, m, k, alpha, a, lda, beta, c, ldc ); \ - } \ +} #else #undef GENTFUNCCO #define GENTFUNCCO( ftype, ftype_r, ch, chr, blasname, blisname ) \ \ -void PASTEF77S(ch,blasname) \ +void PASTEF77(ch,blasname) \ ( \ const f77_char* uploc, \ const f77_char* transa, \ @@ -256,21 +242,7 @@ void PASTEF77S(ch,blasname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ /* Finalize BLIS. */ \ bli_finalize_auto(); \ -} \ -void PASTEF77(ch,blasname) \ - ( \ - const f77_char* uploc, \ - const f77_char* transa, \ - const f77_int* m, \ - const f77_int* k, \ - const ftype_r* alpha, \ - const ftype* a, const f77_int* lda, \ - const ftype_r* beta, \ - ftype* c, const f77_int* ldc \ - ) \ -{ \ - PASTEF77S(ch,blasname) ( uploc, transa, m, k, alpha, a, lda, beta, c, ldc ); \ -} \ +} #endif diff --git a/frame/compat/bla_herk.h b/frame/compat/bla_herk.h index 8ec9183e8f..e649a74abb 100644 --- a/frame/compat/bla_herk.h +++ b/frame/compat/bla_herk.h @@ -5,7 +5,6 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -41,17 +40,6 @@ #define GENTPROTCO( ftype, ftype_r, ch, chr, blasname ) \ \ BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ - ( \ - const f77_char* uploc, \ - const f77_char* transa, \ - const f77_int* m, \ - const f77_int* k, \ - const ftype_r* alpha, \ - const ftype* a, const f77_int* lda, \ - const ftype_r* beta, \ - ftype* c, const f77_int* ldc \ - ); \ -BLIS_EXPORT_BLAS void PASTEF77S(ch,blasname) \ ( \ const f77_char* uploc, \ const f77_char* transa, \ diff --git a/frame/compat/bla_symm.c b/frame/compat/bla_symm.c index 7b915a5edb..85aebb435f 100755 --- a/frame/compat/bla_symm.c +++ b/frame/compat/bla_symm.c @@ -45,7 +45,7 @@ #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ -void PASTEF77S(ch,blasname) \ +void PASTEF77(ch,blasname) \ ( \ const f77_char* side, \ const f77_char* uploa, \ @@ -131,29 +131,14 @@ void PASTEF77S(ch,blasname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ /* Finalize BLIS. */ \ bli_finalize_auto(); \ -} \ -void PASTEF77(ch,blasname) \ - ( \ - const f77_char* side, \ - const f77_char* uploa, \ - const f77_int* m, \ - const f77_int* n, \ - const ftype* alpha, \ - const ftype* a, const f77_int* lda, \ - const ftype* b, const f77_int* ldb, \ - const ftype* beta, \ - ftype* c, const f77_int* ldc \ - ) \ -{ \ - PASTEF77S(ch,blasname) ( side, uploa, m, n, alpha, a, lda, b, ldb, beta, c, ldc ); \ - } \ +} #else #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ -void PASTEF77S(ch,blasname) \ +void PASTEF77(ch,blasname) \ ( \ const f77_char* side, \ const f77_char* uploa, \ @@ -261,22 +246,7 @@ void PASTEF77S(ch,blasname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ /* Finalize BLIS. */ \ bli_finalize_auto(); \ -} \ -void PASTEF77(ch,blasname) \ - ( \ - const f77_char* side, \ - const f77_char* uploa, \ - const f77_int* m, \ - const f77_int* n, \ - const ftype* alpha, \ - const ftype* a, const f77_int* lda, \ - const ftype* b, const f77_int* ldb, \ - const ftype* beta, \ - ftype* c, const f77_int* ldc \ - ) \ -{ \ - PASTEF77S(ch,blasname) ( side, uploa, m, n, alpha, a, lda, b, ldb, beta, c, ldc ); \ -} \ +} #endif diff --git a/frame/compat/bla_symm.h b/frame/compat/bla_symm.h index f10e1cbb86..b186e4b436 100644 --- a/frame/compat/bla_symm.h +++ b/frame/compat/bla_symm.h @@ -5,7 +5,6 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -41,18 +40,6 @@ #define GENTPROT( ftype, ch, blasname ) \ \ BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ - ( \ - const f77_char* side, \ - const f77_char* uploa, \ - const f77_int* m, \ - const f77_int* n, \ - const ftype* alpha, \ - const ftype* a, const f77_int* lda, \ - const ftype* b, const f77_int* ldb, \ - const ftype* beta, \ - ftype* c, const f77_int* ldc \ - ); \ -BLIS_EXPORT_BLAS void PASTEF77S(ch,blasname) \ ( \ const f77_char* side, \ const f77_char* uploa, \ diff --git a/frame/compat/bla_syr2k.c b/frame/compat/bla_syr2k.c index 751e008ae4..6a4f31b969 100644 --- a/frame/compat/bla_syr2k.c +++ b/frame/compat/bla_syr2k.c @@ -45,7 +45,7 @@ #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ -void PASTEF77S(ch,blasname) \ +void PASTEF77(ch,blasname) \ ( \ const f77_char* uploc, \ const f77_char* transa, \ @@ -139,29 +139,14 @@ void PASTEF77S(ch,blasname) \ /* Finalize BLIS. */ \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ bli_finalize_auto(); \ -} \ -void PASTEF77(ch,blasname) \ - ( \ - const f77_char* uploc, \ - const f77_char* transa, \ - const f77_int* m, \ - const f77_int* k, \ - const ftype* alpha, \ - const ftype* a, const f77_int* lda, \ - const ftype* b, const f77_int* ldb, \ - const ftype* beta, \ - ftype* c, const f77_int* ldc \ - ) \ -{ \ - PASTEF77S(ch,blasname) ( uploc, transa, m, k, alpha, a, lda, b, ldb, beta, c, ldc ); \ - } \ +} #else #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ -void PASTEF77S(ch,blasname) \ +void PASTEF77(ch,blasname) \ ( \ const f77_char* uploc, \ const f77_char* transa, \ @@ -277,22 +262,7 @@ void PASTEF77S(ch,blasname) \ /* Finalize BLIS. */ \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ bli_finalize_auto(); \ -} \ -void PASTEF77(ch,blasname) \ - ( \ - const f77_char* uploc, \ - const f77_char* transa, \ - const f77_int* m, \ - const f77_int* k, \ - const ftype* alpha, \ - const ftype* a, const f77_int* lda, \ - const ftype* b, const f77_int* ldb, \ - const ftype* beta, \ - ftype* c, const f77_int* ldc \ - ) \ -{ \ - PASTEF77S(ch,blasname) ( uploc, transa, m, k, alpha, a, lda, b, ldb, beta, c, ldc ); \ -} \ +} #endif diff --git a/frame/compat/bla_syr2k.h b/frame/compat/bla_syr2k.h index fc127d9ea5..91d9a3acf8 100644 --- a/frame/compat/bla_syr2k.h +++ b/frame/compat/bla_syr2k.h @@ -5,7 +5,6 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -41,18 +40,6 @@ #define GENTPROT( ftype, ch, blasname ) \ \ BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ - ( \ - const f77_char* uploc, \ - const f77_char* transa, \ - const f77_int* m, \ - const f77_int* k, \ - const ftype* alpha, \ - const ftype* a, const f77_int* lda, \ - const ftype* b, const f77_int* ldb, \ - const ftype* beta, \ - ftype* c, const f77_int* ldc \ - ); \ -BLIS_EXPORT_BLAS void PASTEF77S(ch,blasname) \ ( \ const f77_char* uploc, \ const f77_char* transa, \ diff --git a/frame/compat/bla_syrk.c b/frame/compat/bla_syrk.c index b2ec611f58..376b23aec9 100644 --- a/frame/compat/bla_syrk.c +++ b/frame/compat/bla_syrk.c @@ -45,7 +45,7 @@ #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ -void PASTEF77S(ch,blasname) \ +void PASTEF77(ch,blasname) \ ( \ const f77_char* uploc, \ const f77_char* transa, \ @@ -133,28 +133,14 @@ void PASTEF77S(ch,blasname) \ /* Finalize BLIS. */ \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ bli_finalize_auto(); \ -} \ -void PASTEF77(ch,blasname) \ - ( \ - const f77_char* uploc, \ - const f77_char* transa, \ - const f77_int* m, \ - const f77_int* k, \ - const ftype* alpha, \ - const ftype* a, const f77_int* lda, \ - const ftype* beta, \ - ftype* c, const f77_int* ldc \ - ) \ -{ \ - PASTEF77S(ch,blasname) ( uploc, transa, m, k, alpha, a, lda, beta, c, ldc ); \ -} \ +} #else #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ -void PASTEF77S(ch,blasname) \ +void PASTEF77(ch,blasname) \ ( \ const f77_char* uploc, \ const f77_char* transa, \ @@ -259,21 +245,7 @@ void PASTEF77S(ch,blasname) \ /* Finalize BLIS. */ \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ bli_finalize_auto(); \ -} \ -void PASTEF77(ch,blasname) \ - ( \ - const f77_char* uploc, \ - const f77_char* transa, \ - const f77_int* m, \ - const f77_int* k, \ - const ftype* alpha, \ - const ftype* a, const f77_int* lda, \ - const ftype* beta, \ - ftype* c, const f77_int* ldc \ - ) \ -{ \ - PASTEF77S(ch,blasname) ( uploc, transa, m, k, alpha, a, lda, beta, c, ldc ); \ -} \ +} #endif diff --git a/frame/compat/bla_syrk.h b/frame/compat/bla_syrk.h index c87dc6694c..b6ca938a6f 100644 --- a/frame/compat/bla_syrk.h +++ b/frame/compat/bla_syrk.h @@ -5,7 +5,6 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -41,17 +40,6 @@ #define GENTPROT( ftype, ch, blasname ) \ \ BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ - ( \ - const f77_char* uploc, \ - const f77_char* transa, \ - const f77_int* m, \ - const f77_int* k, \ - const ftype* alpha, \ - const ftype* a, const f77_int* lda, \ - const ftype* beta, \ - ftype* c, const f77_int* ldc \ - ); \ -BLIS_EXPORT_BLAS void PASTEF77S(ch,blasname) \ ( \ const f77_char* uploc, \ const f77_char* transa, \ diff --git a/frame/compat/bla_trmm.c b/frame/compat/bla_trmm.c index 59c64b90e1..c319b3ab51 100644 --- a/frame/compat/bla_trmm.c +++ b/frame/compat/bla_trmm.c @@ -45,7 +45,7 @@ #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ -void PASTEF77S(ch,blasname) \ +void PASTEF77(ch,blasname) \ ( \ const f77_char* side, \ const f77_char* uploa, \ @@ -131,29 +131,14 @@ void PASTEF77S(ch,blasname) \ /* Finalize BLIS. 
*/ \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ bli_finalize_auto(); \ -} \ -void PASTEF77(ch,blasname) \ - ( \ - const f77_char* side, \ - const f77_char* uploa, \ - const f77_char* transa, \ - const f77_char* diaga, \ - const f77_int* m, \ - const f77_int* n, \ - const ftype* alpha, \ - const ftype* a, const f77_int* lda, \ - ftype* b, const f77_int* ldb \ - ) \ -{ \ - PASTEF77S(ch,blasname) ( side, uploa, transa, diaga, m, n, alpha, a, lda, b, ldb ); \ -} \ +} #else #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ -void PASTEF77S(ch,blasname) \ +void PASTEF77(ch,blasname) \ ( \ const f77_char* side, \ const f77_char* uploa, \ @@ -254,22 +239,7 @@ void PASTEF77S(ch,blasname) \ /* Finalize BLIS. */ \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ bli_finalize_auto(); \ -} \ -void PASTEF77(ch,blasname) \ - ( \ - const f77_char* side, \ - const f77_char* uploa, \ - const f77_char* transa, \ - const f77_char* diaga, \ - const f77_int* m, \ - const f77_int* n, \ - const ftype* alpha, \ - const ftype* a, const f77_int* lda, \ - ftype* b, const f77_int* ldb \ - ) \ -{ \ - PASTEF77S(ch,blasname) ( side, uploa, transa, diaga, m, n, alpha, a, lda, b, ldb ); \ -} \ +} #endif diff --git a/frame/compat/bla_trmm.h b/frame/compat/bla_trmm.h index 10cbb6cbc2..4f0c20b1b2 100644 --- a/frame/compat/bla_trmm.h +++ b/frame/compat/bla_trmm.h @@ -5,7 +5,6 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -41,18 +40,6 @@ #define GENTPROT( ftype, ch, blasname ) \ \ BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ - ( \ - const f77_char* side, \ - const f77_char* uploa, \ - const f77_char* transa, \ - const f77_char* diaga, \ - const f77_int* m, \ - const f77_int* n, \ - const ftype* alpha, \ - const ftype* a, const f77_int* lda, \ - ftype* b, const f77_int* ldb \ - ); \ -BLIS_EXPORT_BLAS void PASTEF77S(ch,blasname) \ ( \ const f77_char* side, \ const f77_char* uploa, \ diff --git a/frame/compat/bla_trsm.c b/frame/compat/bla_trsm.c index f709a8cd0a..e99805d8dd 100644 --- a/frame/compat/bla_trsm.c +++ b/frame/compat/bla_trsm.c @@ -45,7 +45,7 @@ #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ -void PASTEF77S(ch,blasname) \ +void PASTEF77(ch,blasname) \ ( \ const f77_char* side, \ const f77_char* uploa, \ @@ -130,29 +130,14 @@ void PASTEF77S(ch,blasname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) \ /* Finalize BLIS. */ \ bli_finalize_auto(); \ -} \ -void PASTEF77(ch,blasname) \ - ( \ - const f77_char* side, \ - const f77_char* uploa, \ - const f77_char* transa, \ - const f77_char* diaga, \ - const f77_int* m, \ - const f77_int* n, \ - const ftype* alpha, \ - const ftype* a, const f77_int* lda, \ - ftype* b, const f77_int* ldb \ - ) \ -{ \ - PASTEF77S(ch,blasname) ( side, uploa, transa, diaga, m, n, alpha, a, lda, b, ldb ); \ -} \ +} #else #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ -void PASTEF77S(ch,blasname) \ +void PASTEF77(ch,blasname) \ ( \ const f77_char* side, \ const f77_char* uploa, \ @@ -408,22 +393,7 @@ void PASTEF77S(ch,blasname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) \ /* Finalize BLIS. 
*/ \ bli_finalize_auto(); \ -} \ -void PASTEF77(ch,blasname) \ - ( \ - const f77_char* side, \ - const f77_char* uploa, \ - const f77_char* transa, \ - const f77_char* diaga, \ - const f77_int* m, \ - const f77_int* n, \ - const ftype* alpha, \ - const ftype* a, const f77_int* lda, \ - ftype* b, const f77_int* ldb \ - ) \ -{ \ - PASTEF77S(ch,blasname) ( side, uploa, transa, diaga, m, n, alpha, a, lda, b, ldb ); \ -} \ +} #endif diff --git a/frame/compat/bla_trsm.h b/frame/compat/bla_trsm.h index af1b626dff..5694db52a8 100644 --- a/frame/compat/bla_trsm.h +++ b/frame/compat/bla_trsm.h @@ -5,7 +5,6 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -41,18 +40,6 @@ #define GENTPROT( ftype, ch, blasname ) \ \ BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ - ( \ - const f77_char* side, \ - const f77_char* uploa, \ - const f77_char* transa, \ - const f77_char* diaga, \ - const f77_int* m, \ - const f77_int* n, \ - const ftype* alpha, \ - const ftype* a, const f77_int* lda, \ - ftype* b, const f77_int* ldb \ - ); \ -BLIS_EXPORT_BLAS void PASTEF77S(ch,blasname) \ ( \ const f77_char* side, \ const f77_char* uploa, \ diff --git a/frame/compat/bla_trsm_amd.c b/frame/compat/bla_trsm_amd.c index 4479725fb9..8ca7434bd8 100644 --- a/frame/compat/bla_trsm_amd.c +++ b/frame/compat/bla_trsm_amd.c @@ -45,7 +45,7 @@ #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ -void PASTEF77S(ch,blasname) \ +void PASTEF77(ch,blasname) \ ( \ const f77_char* side, \ const f77_char* uploa, \ @@ -130,29 +130,14 @@ void PASTEF77S(ch,blasname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) \ /* Finalize BLIS. */ \ bli_finalize_auto(); \ -} \ -void PASTEF77(ch,blasname) \ - ( \ - const f77_char* side, \ - const f77_char* uploa, \ - const f77_char* transa, \ - const f77_char* diaga, \ - const f77_int* m, \ - const f77_int* n, \ - const ftype* alpha, \ - const ftype* a, const f77_int* lda, \ - ftype* b, const f77_int* ldb \ - ) \ -{ \ - PASTEF77S(ch,blasname) ( side, uploa, transa, diaga, m, n, alpha, a, lda, b, ldb ); \ - } \ +} #else #undef GENTFUNC #define GENTFUNC( ftype, ch, blasname, blisname ) \ \ -void PASTEF77S(ch,blasname) \ +void PASTEF77(ch,blasname) \ ( \ const f77_char* side, \ const f77_char* uploa, \ @@ -408,28 +393,13 @@ void PASTEF77S(ch,blasname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) \ /* Finalize BLIS. */ \ bli_finalize_auto(); \ -} \ -void PASTEF77(ch,blasname) \ - ( \ - const f77_char* side, \ - const f77_char* uploa, \ - const f77_char* transa, \ - const f77_char* diaga, \ - const f77_int* m, \ - const f77_int* n, \ - const ftype* alpha, \ - const ftype* a, const f77_int* lda, \ - ftype* b, const f77_int* ldb \ - ) \ -{ \ - PASTEF77S(ch,blasname) ( side, uploa, transa, diaga, m, n, alpha, a, lda, b, ldb ); \ -} \ +} #endif #ifdef BLIS_ENABLE_BLAS -void strsm_blis_impl +void strsm_ ( const f77_char* side, const f77_char* uploa, @@ -699,23 +669,8 @@ void strsm_blis_impl /* Finalize BLIS. 
*/ bli_finalize_auto(); } -void strsm_ -( - const f77_char* side, - const f77_char* uploa, - const f77_char* transa, - const f77_char* diaga, - const f77_int* m, - const f77_int* n, - const float* alpha, - const float* a, const f77_int* lda, - float* b, const f77_int* ldb -) -{ - strsm_blis_impl ( side, uploa, transa, diaga, m, n, alpha, a, lda, b, ldb ); -} -void dtrsm_blis_impl +void dtrsm_ ( const f77_char* side, const f77_char* uploa, @@ -937,7 +892,7 @@ void dtrsm_blis_impl bli_obj_set_conjtrans( blis_transa, &ao ); bli_obj_set_struc( struca, &ao ); - + #ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM // This function is invoked on all architectures including ‘generic’. // Non-AVX platforms will use the kernels derived from the context. @@ -1018,24 +973,9 @@ void dtrsm_blis_impl /* Finalize BLIS. */ bli_finalize_auto(); } -void dtrsm_ -( - const f77_char* side, - const f77_char* uploa, - const f77_char* transa, - const f77_char* diaga, - const f77_int* m, - const f77_int* n, - const double* alpha, - const double* a, const f77_int* lda, - double* b, const f77_int* ldb -) -{ - dtrsm_blis_impl ( side, uploa, transa, diaga, m, n, alpha, a, lda, b, ldb ); -} -void ztrsm_blis_impl +void ztrsm_ ( const f77_char* side, const f77_char* uploa, @@ -1391,24 +1331,9 @@ void ztrsm_blis_impl /* Finalize BLIS. */ bli_finalize_auto(); } -void ztrsm_ -( - const f77_char* side, - const f77_char* uploa, - const f77_char* transa, - const f77_char* diaga, - const f77_int* m, - const f77_int* n, - const dcomplex* alpha, - const dcomplex* a, const f77_int* lda, - dcomplex* b, const f77_int* ldb -) -{ - ztrsm_blis_impl ( side, uploa, transa, diaga, m, n, alpha, a, lda, b, ldb ); -} -void ctrsm_blis_impl +void ctrsm_ ( const f77_char* side, const f77_char* uploa, @@ -1739,20 +1664,5 @@ void ctrsm_blis_impl /* Finalize BLIS. */ bli_finalize_auto(); } -void ctrsm_ -( - const f77_char* side, - const f77_char* uploa, - const f77_char* transa, - const f77_char* diaga, - const f77_int* m, - const f77_int* n, - const scomplex* alpha, - const scomplex* a, const f77_int* lda, - scomplex* b, const f77_int* ldb -) -{ - ctrsm_blis_impl ( side, uploa, transa, diaga, m, n, alpha, a, lda, b, ldb ); -} #endif diff --git a/frame/compat/cblas/src/cblas_f77.h b/frame/compat/cblas/src/cblas_f77.h index 5ec518de9e..fabf3efb1c 100644 --- a/frame/compat/cblas/src/cblas_f77.h +++ b/frame/compat/cblas/src/cblas_f77.h @@ -7,7 +7,7 @@ * * (Heavily hacked down from the original) * - * Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2020 - 2021, Advanced Micro Devices, Inc. All rights reserved. 
* */ @@ -326,40 +326,40 @@ /* * Level 3 BLAS */ -#define F77_chemm chemm_blis_impl -#define F77_cherk cherk_blis_impl -#define F77_cher2k cher2k_blis_impl -#define F77_zhemm zhemm_blis_impl -#define F77_zherk zherk_blis_impl -#define F77_zher2k zher2k_blis_impl -#define F77_sgemm sgemm_blis_impl -#define F77_ssymm ssymm_blis_impl -#define F77_ssyrk ssyrk_blis_impl -#define F77_ssyr2k ssyr2k_blis_impl -#define F77_strmm strmm_blis_impl -#define F77_strsm strsm_blis_impl -#define F77_dgemm dgemm_blis_impl -#define F77_dsymm dsymm_blis_impl -#define F77_dsyrk dsyrk_blis_impl -#define F77_dsyr2k dsyr2k_blis_impl -#define F77_dtrmm dtrmm_blis_impl -#define F77_dtrsm dtrsm_blis_impl -#define F77_cgemm cgemm_blis_impl -#define F77_csymm csymm_blis_impl -#define F77_csyrk csyrk_blis_impl -#define F77_csyr2k csyr2k_blis_impl -#define F77_ctrmm ctrmm_blis_impl -#define F77_ctrsm ctrsm_blis_impl -#define F77_zgemm zgemm_blis_impl -#define F77_zsymm zsymm_blis_impl -#define F77_zsyrk zsyrk_blis_impl -#define F77_zsyr2k zsyr2k_blis_impl -#define F77_ztrmm ztrmm_blis_impl -#define F77_ztrsm ztrsm_blis_impl -#define F77_dgemmt dgemmt_blis_impl -#define F77_sgemmt sgemmt_blis_impl -#define F77_cgemmt cgemmt_blis_impl -#define F77_zgemmt zgemmt_blis_impl +#define F77_chemm chemm_ +#define F77_cherk cherk_ +#define F77_cher2k cher2k_ +#define F77_zhemm zhemm_ +#define F77_zherk zherk_ +#define F77_zher2k zher2k_ +#define F77_sgemm sgemm_ +#define F77_ssymm ssymm_ +#define F77_ssyrk ssyrk_ +#define F77_ssyr2k ssyr2k_ +#define F77_strmm strmm_ +#define F77_strsm strsm_ +#define F77_dgemm dgemm_ +#define F77_dsymm dsymm_ +#define F77_dsyrk dsyrk_ +#define F77_dsyr2k dsyr2k_ +#define F77_dtrmm dtrmm_ +#define F77_dtrsm dtrsm_ +#define F77_cgemm cgemm_ +#define F77_csymm csymm_ +#define F77_csyrk csyrk_ +#define F77_csyr2k csyr2k_ +#define F77_ctrmm ctrmm_ +#define F77_ctrsm ctrsm_ +#define F77_zgemm zgemm_ +#define F77_zsymm zsymm_ +#define F77_zsyrk zsyrk_ +#define F77_zsyr2k zsyr2k_ +#define F77_ztrmm ztrmm_ +#define F77_ztrsm ztrsm_ +#define F77_dgemmt dgemmt_ +#define F77_sgemmt sgemmt_ +#define F77_cgemmt cgemmt_ +#define F77_zgemmt zgemmt_ /* * Aux Function @@ -375,8 +375,8 @@ #define F77_daxpby daxpby_ #define F77_caxpby caxpby_ #define F77_zaxpby zaxpby_ -#define F77_cgemm3m cgemm3m_blis_impl -#define F77_zgemm3m zgemm3m_blis_impl +#define F77_cgemm3m cgemm3m_ +#define F77_zgemm3m zgemm3m_ #define F77_isamin_sub isaminsub_ #define F77_idamin_sub idaminsub_ diff --git a/frame/include/bli_macro_defs.h b/frame/include/bli_macro_defs.h index 75b9c9fdc4..f29fdc1fe4 100644 --- a/frame/include/bli_macro_defs.h +++ b/frame/include/bli_macro_defs.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018-2021, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -162,13 +162,11 @@ #define PASTEF77(ch1,name) ch1 ## name #define PASTEF772(ch1,ch2,name) ch1 ## ch2 ## name #define PASTEF773(ch1,ch2,ch3,name) ch1 ## ch2 ## ch3 ## name -#define PASTEF77S(ch1,name) ch1 ## name ## _blis_impl #else #define PASTEF770(name) name ## _ #define PASTEF77(ch1,name) ch1 ## name ## _ #define PASTEF772(ch1,ch2,name) ch1 ## ch2 ## name ## _ #define PASTEF773(ch1,ch2,ch3,name) ch1 ## ch2 ## ch3 ## name ## _ -#define PASTEF77S(ch1,name) ch1 ## name ## _blis_impl #endif // -- Include other groups of macros diff --git a/frame/util/bli_util_api_wrap.c b/frame/util/bli_util_api_wrap.c index 9e8d1ccc38..81300761fb 100644 --- a/frame/util/bli_util_api_wrap.c +++ b/frame/util/bli_util_api_wrap.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021-2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -210,17 +210,17 @@ void CGBMV_(const char *trans,const f77_int *m,const f77_int *n,const f77_int void CGEMM(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) { - cgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + cgemm_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void cgemm(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) { - cgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + cgemm_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void CGEMM_(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) { - cgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + cgemm_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void CGEMV(const char *trans,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy) @@ -285,17 +285,17 @@ void CHBMV_(const char *uplo,const f77_int *n,const f77_int *k,const scomplex void CHEMM(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) { - chemm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + chemm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void chemm(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) { - chemm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, 
ldc); + chemm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void CHEMM_(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) { - chemm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + chemm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void CHEMV(const char *uplo,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy) @@ -345,32 +345,32 @@ void CHER2_(const char *uplo,const f77_int *n,const scomplex *alpha,const sco void CHER2K(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const float *beta,scomplex *c,const f77_int *ldc) { - cher2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + cher2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void cher2k(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const float *beta,scomplex *c,const f77_int *ldc) { - cher2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + cher2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void CHER2K_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const float *beta,scomplex *c,const f77_int *ldc) { - cher2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + cher2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void CHERK(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const float *alpha,const scomplex *a,const f77_int *lda,const float *beta,scomplex *c,const f77_int *ldc) { - cherk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + cherk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void cherk(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const float *alpha,const scomplex *a,const f77_int *lda,const float *beta,scomplex *c,const f77_int *ldc) { - cherk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + cherk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void CHERK_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const float *alpha,const scomplex *a,const f77_int *lda,const float *beta,scomplex *c,const f77_int *ldc) { - cherk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + cherk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void CHPMV(const char *uplo,const f77_int *n,const scomplex *alpha,const scomplex *ap,const scomplex *x,const f77_int *incx,const scomplex *beta,scomplex *y,const f77_int *incy) @@ -495,47 +495,47 @@ void CSWAP_(const f77_int *n,scomplex *cx,const f77_int *incx,scomplex *cy,con void CSYMM(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) { - csymm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + csymm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void csymm(const char *side,const char *uplo,const f77_int *m,const f77_int 
*n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) { - csymm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + csymm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void CSYMM_(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) { - csymm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + csymm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void CSYR2K(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) { - csyr2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + csyr2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void csyr2k(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) { - csyr2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + csyr2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void CSYR2K_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *b,const f77_int *ldb,const scomplex *beta,scomplex *c,const f77_int *ldc) { - csyr2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + csyr2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void CSYRK(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *beta,scomplex *c,const f77_int *ldc) { - csyrk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + csyrk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void csyrk(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *beta,scomplex *c,const f77_int *ldc) { - csyrk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + csyrk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void CSYRK_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const scomplex *alpha,const scomplex *a,const f77_int *lda,const scomplex *beta,scomplex *c,const f77_int *ldc) { - csyrk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + csyrk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void CTBMV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const scomplex *a,const f77_int *lda,scomplex *x,const f77_int *incx) @@ -600,17 +600,17 @@ void CTPSV_(const char *uplo,const char *trans,const char *diag,const f77_ void CTRMM(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,scomplex *b,const f77_int *ldb) { - ctrmm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + ctrmm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void ctrmm(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const 
f77_int *lda,scomplex *b,const f77_int *ldb) { - ctrmm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + ctrmm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void CTRMM_(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,scomplex *b,const f77_int *ldb) { - ctrmm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + ctrmm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void CTRMV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const scomplex *a,const f77_int *lda,scomplex *x,const f77_int *incx) @@ -630,17 +630,17 @@ void CTRMV_(const char *uplo,const char *trans,const char *diag,const f77_ void CTRSM(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,scomplex *b,const f77_int *ldb) { - ctrsm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + ctrsm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void ctrsm(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,scomplex *b,const f77_int *ldb) { - ctrsm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + ctrsm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void CTRSM_(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const scomplex *alpha,const scomplex *a,const f77_int *lda,scomplex *b,const f77_int *ldb) { - ctrsm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + ctrsm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void CTRSV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const scomplex *a,const f77_int *lda,scomplex *x,const f77_int *incx) @@ -750,17 +750,17 @@ void DGBMV_(const char *trans,const f77_int *m,const f77_int *n,const f77_int void DGEMM(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *b,const f77_int *ldb,const double *beta,double *c,const f77_int *ldc) { - dgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + dgemm_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void dgemm(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *b,const f77_int *ldb,const double *beta,double *c,const f77_int *ldc) { - dgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + dgemm_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void DGEMM_(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *b,const f77_int *ldb,const double *beta,double *c,const f77_int *ldc) { - dgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + dgemm_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void DGEMV(const char *trans,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) @@ -975,17 +975,17 @@ void DSWAP_(const f77_int *n,double 
*dx,const f77_int *incx,double *dy,const f77 void DSYMM(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,const double *b,const f77_int *ldb,const double *beta,double *c,const f77_int *ldc) { - dsymm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + dsymm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void dsymm(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,const double *b,const f77_int *ldb,const double *beta,double *c,const f77_int *ldc) { - dsymm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + dsymm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void DSYMM_(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,const double *b,const f77_int *ldb,const double *beta,double *c,const f77_int *ldc) { - dsymm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + dsymm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void DSYMV(const char *uplo,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,const double *x,const f77_int *incx,const double *beta,double *y,const f77_int *incy) @@ -1035,32 +1035,32 @@ void DSYR2_(const char *uplo,const f77_int *n,const double *alpha,const double void DSYR2K(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *b,const f77_int *ldb,const double *beta,double *c,const f77_int *ldc) { - dsyr2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + dsyr2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void dsyr2k(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *b,const f77_int *ldb,const double *beta,double *c,const f77_int *ldc) { - dsyr2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + dsyr2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void DSYR2K_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *b,const f77_int *ldb,const double *beta,double *c,const f77_int *ldc) { - dsyr2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + dsyr2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void DSYRK(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *beta,double *c,const f77_int *ldc) { - dsyrk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + dsyrk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void dsyrk(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *beta,double *c,const f77_int *ldc) { - dsyrk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + dsyrk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void DSYRK_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const double *alpha,const double *a,const f77_int *lda,const double *beta,double *c,const f77_int *ldc) { - dsyrk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + dsyrk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void DTBMV(const char *uplo,const char *trans,const char *diag,const f77_int 
*n,const f77_int *k,const double *a,const f77_int *lda,double *x,const f77_int *incx) @@ -1125,17 +1125,17 @@ void DTPSV_(const char *uplo,const char *trans,const char *diag,const f77_ void DTRMM(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,double *b,const f77_int *ldb) { - dtrmm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + dtrmm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void dtrmm(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,double *b,const f77_int *ldb) { - dtrmm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + dtrmm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void DTRMM_(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,double *b,const f77_int *ldb) { - dtrmm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + dtrmm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void DTRMV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const double *a,const f77_int *lda,double *x,const f77_int *incx) @@ -1155,17 +1155,17 @@ void DTRMV_(const char *uplo,const char *trans,const char *diag,const f77_ void DTRSM(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,double *b,const f77_int *ldb) { - dtrsm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + dtrsm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void dtrsm(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,double *b,const f77_int *ldb) { - dtrsm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + dtrsm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void DTRSM_(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const double *alpha,const double *a,const f77_int *lda,double *b,const f77_int *ldb) { - dtrsm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + dtrsm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void DTRSV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const double *a,const f77_int *lda,double *x,const f77_int *incx) @@ -1417,17 +1417,17 @@ void SGBMV_(const char *trans,const f77_int *m,const f77_int *n,const f77_int void SGEMM(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *b,const f77_int *ldb,const float *beta,float *c,const f77_int *ldc) { - sgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + sgemm_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void sgemm(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *b,const f77_int *ldb,const float *beta,float *c,const f77_int *ldc) { - sgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + sgemm_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void SGEMM_(const 
char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *b,const f77_int *ldb,const float *beta,float *c,const f77_int *ldc) { - sgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + sgemm_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void SGEMV(const char *trans,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) @@ -1629,17 +1629,17 @@ void SSWAP_(const f77_int *n,float *sx,const f77_int *incx,float *sy,const f77 void SSYMM(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,const float *b,const f77_int *ldb,const float *beta,float *c,const f77_int *ldc) { - ssymm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + ssymm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void ssymm(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,const float *b,const f77_int *ldb,const float *beta,float *c,const f77_int *ldc) { - ssymm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + ssymm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void SSYMM_(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,const float *b,const f77_int *ldb,const float *beta,float *c,const f77_int *ldc) { - ssymm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + ssymm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void SSYMV(const char *uplo,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,const float *x,const f77_int *incx,const float *beta,float *y,const f77_int *incy) @@ -1689,32 +1689,32 @@ void SSYR2_(const char *uplo,const f77_int *n,const float *alpha,const float void SSYR2K(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *b,const f77_int *ldb,const float *beta,float *c,const f77_int *ldc) { - ssyr2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + ssyr2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void ssyr2k(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *b,const f77_int *ldb,const float *beta,float *c,const f77_int *ldc) { - ssyr2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + ssyr2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void SSYR2K_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *b,const f77_int *ldb,const float *beta,float *c,const f77_int *ldc) { - ssyr2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + ssyr2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void SSYRK(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *beta,float *c,const f77_int *ldc) { - ssyrk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + ssyrk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void ssyrk(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const float *alpha,const float 
*a,const f77_int *lda,const float *beta,float *c,const f77_int *ldc) { - ssyrk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + ssyrk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void SSYRK_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const float *alpha,const float *a,const f77_int *lda,const float *beta,float *c,const f77_int *ldc) { - ssyrk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + ssyrk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void STBMV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const float *a,const f77_int *lda,float *x,const f77_int *incx) @@ -1779,17 +1779,17 @@ void STPSV_(const char *uplo,const char *trans,const char *diag,const f77_ void STRMM(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,float *b,const f77_int *ldb) { - strmm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + strmm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void strmm(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,float *b,const f77_int *ldb) { - strmm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + strmm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void STRMM_(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,float *b,const f77_int *ldb) { - strmm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + strmm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void STRMV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const float *a,const f77_int *lda,float *x,const f77_int *incx) @@ -1809,17 +1809,17 @@ void STRMV_(const char *uplo,const char *trans,const char *diag,const f77_ void STRSM(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,float *b,const f77_int *ldb) { - strsm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + strsm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void strsm(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,float *b,const f77_int *ldb) { - strsm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + strsm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void STRSM_(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const float *alpha,const float *a,const f77_int *lda,float *b,const f77_int *ldb) { - strsm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + strsm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void STRSV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const float *a,const f77_int *lda,float *x,const f77_int *incx) @@ -1929,17 +1929,17 @@ void ZGBMV_(const char *trans,const f77_int *m,const f77_int *n,const f77_int void ZGEMM(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const 
f77_int *ldc) { - zgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + zgemm_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void zgemm(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) { - zgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + zgemm_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void ZGEMM_(const char *transa,const char *transb,const f77_int *m,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) { - zgemm_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + zgemm_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void ZGEMV(const char *trans,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) @@ -2004,17 +2004,17 @@ void ZHBMV_(const char *uplo,const f77_int *n,const f77_int *k,const dcomplex void ZHEMM(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) { - zhemm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + zhemm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void zhemm(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) { - zhemm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + zhemm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void ZHEMM_(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) { - zhemm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + zhemm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void ZHEMV(const char *uplo,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) @@ -2064,32 +2064,32 @@ void ZHER2_(const char *uplo,const f77_int *n,const dcomplex *alpha,const dcom void ZHER2K(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const double *beta,dcomplex *c,const f77_int *ldc) { - zher2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + zher2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void zher2k(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const double *beta,dcomplex *c,const f77_int *ldc) { - zher2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + zher2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void ZHER2K_(const char *uplo,const char *trans,const f77_int *n,const 
f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const double *beta,dcomplex *c,const f77_int *ldc) { - zher2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + zher2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void ZHERK(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const double *alpha,const dcomplex *a,const f77_int *lda,const double *beta,dcomplex *c,const f77_int *ldc) { - zherk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + zherk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void zherk(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const double *alpha,const dcomplex *a,const f77_int *lda,const double *beta,dcomplex *c,const f77_int *ldc) { - zherk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + zherk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void ZHERK_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const double *alpha,const dcomplex *a,const f77_int *lda,const double *beta,dcomplex *c,const f77_int *ldc) { - zherk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + zherk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void ZHPMV(const char *uplo,const f77_int *n,const dcomplex *alpha,const dcomplex *ap,const dcomplex *x,const f77_int *incx,const dcomplex *beta,dcomplex *y,const f77_int *incy) @@ -2184,47 +2184,47 @@ void ZSWAP_(const f77_int *n,dcomplex *zx,const f77_int *incx,dcomplex *zy,const void ZSYMM(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) { - zsymm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + zsymm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void zsymm(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) { - zsymm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + zsymm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void ZSYMM_(const char *side,const char *uplo,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) { - zsymm_blis_impl( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); + zsymm_( side, uplo, m, n, alpha, a, lda, b, ldb, beta, c, ldc); } void ZSYR2K(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) { - zsyr2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + zsyr2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void zsyr2k(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) { - zsyr2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + zsyr2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void ZSYR2K_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const 
f77_int *lda,const dcomplex *b,const f77_int *ldb,const dcomplex *beta,dcomplex *c,const f77_int *ldc) { - zsyr2k_blis_impl( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + zsyr2k_( uplo, trans, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void ZSYRK(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *beta,dcomplex *c,const f77_int *ldc) { - zsyrk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + zsyrk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void zsyrk(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *beta,dcomplex *c,const f77_int *ldc) { - zsyrk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + zsyrk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void ZSYRK_(const char *uplo,const char *trans,const f77_int *n,const f77_int *k,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,const dcomplex *beta,dcomplex *c,const f77_int *ldc) { - zsyrk_blis_impl( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); + zsyrk_( uplo, trans, n, k, alpha, a, lda, beta, c, ldc); } void ZTBMV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const f77_int *k,const dcomplex *a,const f77_int *lda,dcomplex *x,const f77_int *incx) @@ -2289,17 +2289,17 @@ void ZTPSV_(const char *uplo,const char *trans,const char *diag,const f77_ void ZTRMM(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,dcomplex *b,const f77_int *ldb) { - ztrmm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + ztrmm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void ztrmm(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,dcomplex *b,const f77_int *ldb) { - ztrmm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + ztrmm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void ZTRMM_(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,dcomplex *b,const f77_int *ldb) { - ztrmm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + ztrmm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void ZTRMV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const dcomplex *a,const f77_int *lda,dcomplex *x,const f77_int *incx) @@ -2319,17 +2319,17 @@ void ZTRMV_(const char *uplo,const char *trans,const char *diag,const f77_ void ZTRSM(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,dcomplex *b,const f77_int *ldb) { - ztrsm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + ztrsm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void ztrsm(const char *side,const char *uplo,const char *transa,const char *diag,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,dcomplex *b,const f77_int *ldb) { - ztrsm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + ztrsm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void ZTRSM_(const char *side,const char *uplo,const char 
*transa,const char *diag,const f77_int *m,const f77_int *n,const dcomplex *alpha,const dcomplex *a,const f77_int *lda,dcomplex *b,const f77_int *ldb) { - ztrsm_blis_impl( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); + ztrsm_( side, uplo, transa, diag, m, n, alpha, a, lda, b, ldb); } void ZTRSV(const char *uplo,const char *trans,const char *diag,const f77_int *n,const dcomplex *a,const f77_int *lda,dcomplex *x,const f77_int *incx) @@ -2380,17 +2380,17 @@ void CDOTUSUB_( const f77_int* n, const scomplex* x,const f77_int* incxy, const void CGEMM3M( const f77_char* transa, const f77_char* transb, const f77_int* m, const f77_int* n, const f77_int* k, const scomplex* alpha, const scomplex* a, const f77_int* lda, const scomplex* b, const f77_int* ldb, const scomplex* beta, scomplex* c, const f77_int* ldc) { - cgemm3m_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + cgemm3m_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void cgemm3m( const f77_char* transa, const f77_char* transb, const f77_int* m, const f77_int* n, const f77_int* k, const scomplex* alpha, const scomplex* a, const f77_int* lda, const scomplex* b, const f77_int* ldb, const scomplex* beta, scomplex* c, const f77_int* ldc) { - cgemm3m_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + cgemm3m_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void CGEMM3M_( const f77_char* transa, const f77_char* transb, const f77_int* m, const f77_int* n, const f77_int* k, const scomplex* alpha, const scomplex* a, const f77_int* lda, const scomplex* b, const f77_int* ldb, const scomplex* beta, scomplex* c, const f77_int* ldc) { - cgemm3m_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + cgemm3m_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void CGEMM_BATCH( const f77_char* transa_array, const f77_char* transb_array,const f77_int *m_array, const f77_int *n_array, const f77_int *k_array,const scomplex* alpha_array, const scomplex** a_array, const f77_int *lda_array, const scomplex** b_array, const f77_int *ldb_array, const scomplex* beta_array, scomplex** c_array, const f77_int *ldc_array, const f77_int* group_count, const f77_int *group_size) @@ -2410,17 +2410,17 @@ void CGEMM_BATCH_( const f77_char* transa_array, const f77_char* transb_array,co void CGEMMT( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const scomplex* alpha, const scomplex* a, const f77_int* lda, const scomplex* b, const f77_int* ldb, const scomplex* beta, scomplex* c, const f77_int* ldc) { - cgemmt_blis_impl( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + cgemmt_( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void cgemmt( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const scomplex* alpha, const scomplex* a, const f77_int* lda, const scomplex* b, const f77_int* ldb, const scomplex* beta, scomplex* c, const f77_int* ldc) { - cgemmt_blis_impl( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + cgemmt_( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void CGEMMT_( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const scomplex* alpha, const scomplex* a, const f77_int* lda, const scomplex* b, const f77_int* ldb, const scomplex* beta, scomplex* c, const f77_int* ldc) { - 
cgemmt_blis_impl( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + cgemmt_( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void CIMATCOPY(f77_char* trans, f77_int* rows, f77_int* cols, const scomplex* alpha,scomplex* aptr, f77_int* lda, f77_int* ldb) @@ -2545,17 +2545,17 @@ void DGEMM_BATCH_( const f77_char* transa_array, const f77_char* transb_array,co void DGEMMT( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const double* alpha, const double* a, const f77_int* lda, const double* b, const f77_int* ldb, const double* beta, double* c, const f77_int* ldc) { - dgemmt_blis_impl( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + dgemmt_( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void dgemmt( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const double* alpha, const double* a, const f77_int* lda, const double* b, const f77_int* ldb, const double* beta, double* c, const f77_int* ldc) { - dgemmt_blis_impl( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + dgemmt_( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void DGEMMT_( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const double* alpha, const double* a, const f77_int* lda, const double* b, const f77_int* ldb, const double* beta, double* c, const f77_int* ldc) { - dgemmt_blis_impl( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + dgemmt_( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void DNRM2SUB(const f77_int* n, const double* x, const f77_int* incx, double *rval) @@ -2920,17 +2920,17 @@ void SGEMM_BATCH_(const f77_char* transa_array, const f77_char* transb_array,con void SGEMMT( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const float* alpha, const float* a, const f77_int* lda, const float* b, const f77_int* ldb, const float* beta, float* c, const f77_int* ldc) { - sgemmt_blis_impl( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + sgemmt_( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void sgemmt( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const float* alpha, const float* a, const f77_int* lda, const float* b, const f77_int* ldb, const float* beta, float* c, const f77_int* ldc) { - sgemmt_blis_impl( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + sgemmt_( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void SGEMMT_( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const float* alpha, const float* a, const f77_int* lda, const float* b, const f77_int* ldb, const float* beta, float* c, const f77_int* ldc) { - sgemmt_blis_impl( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + sgemmt_( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void SIMATCOPY( f77_char* trans, f77_int* rows, f77_int* cols, const float* alpha,float* aptr, f77_int* lda, f77_int* ldb) @@ -3055,17 +3055,17 @@ void ZDOTUSUB_( const f77_int* n, const dcomplex* x, const f77_int* incx,const d void ZGEMM3M( const f77_char* transa, const f77_char* transb, const f77_int* m, const f77_int* n, const f77_int* k, const dcomplex* alpha, 
const dcomplex* a, const f77_int* lda, const dcomplex* b, const f77_int* ldb, const dcomplex* beta, dcomplex* c, const f77_int* ldc) { - zgemm3m_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + zgemm3m_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void zgemm3m( const f77_char* transa, const f77_char* transb, const f77_int* m, const f77_int* n, const f77_int* k, const dcomplex* alpha, const dcomplex* a, const f77_int* lda, const dcomplex* b, const f77_int* ldb, const dcomplex* beta, dcomplex* c, const f77_int* ldc) { - zgemm3m_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + zgemm3m_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void ZGEMM3M_( const f77_char* transa, const f77_char* transb, const f77_int* m, const f77_int* n, const f77_int* k, const dcomplex* alpha, const dcomplex* a, const f77_int* lda, const dcomplex* b, const f77_int* ldb, const dcomplex* beta, dcomplex* c, const f77_int* ldc) { - zgemm3m_blis_impl( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + zgemm3m_( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void ZGEMM_BATCH( const f77_char* transa_array, const f77_char* transb_array,const f77_int *m_array, const f77_int *n_array, const f77_int *k_array,const dcomplex* alpha_array, const dcomplex** a_array, const f77_int *lda_array, const dcomplex** b_array, const f77_int *ldb_array, const dcomplex* beta_array, dcomplex** c_array, const f77_int *ldc_array, const f77_int* group_count, const f77_int *group_size) @@ -3085,17 +3085,17 @@ void ZGEMM_BATCH_( const f77_char* transa_array, const f77_char* transb_array,c void ZGEMMT( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const dcomplex* alpha, const dcomplex* a, const f77_int* lda, const dcomplex* b, const f77_int* ldb, const dcomplex* beta, dcomplex* c, const f77_int* ldc) { - zgemmt_blis_impl( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + zgemmt_( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void zgemmt( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const dcomplex* alpha, const dcomplex* a, const f77_int* lda, const dcomplex* b, const f77_int* ldb, const dcomplex* beta, dcomplex* c, const f77_int* ldc) { - zgemmt_blis_impl( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + zgemmt_( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void ZGEMMT_( const f77_char* uploc, const f77_char* transa, const f77_char* transb, const f77_int* n, const f77_int* k, const dcomplex* alpha, const dcomplex* a, const f77_int* lda, const dcomplex* b, const f77_int* ldb, const dcomplex* beta, dcomplex* c, const f77_int* ldc) { - zgemmt_blis_impl( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); + zgemmt_( uploc, transa, transb, n, k, alpha, a, lda, b, ldb, beta, c, ldc); } void ZIMATCOPY(f77_char* trans, f77_int* rows, f77_int* cols, const dcomplex* alpha,dcomplex* aptr, f77_int* lda, f77_int* ldb) From 99dc9066fbbdc4b7444b382e4ae98b3b5fbebbf2 Mon Sep 17 00:00:00 2001 From: satish kumar nuggu Date: Tue, 30 Aug 2022 18:02:19 +0530 Subject: [PATCH 210/243] Fix in ZTRSM Small MT Details: 1. Changes are made in ztrsm small MT path,to avoid accuracy issues reported in libflame tests. 
AMD-Internal: [CPUPL-2476] Change-Id: Ic279106343fb1744e89ff4c920023adbe1d0158a --- frame/compat/bla_trsm_amd.c | 25 +------------------------ 1 file changed, 1 insertion(+), 24 deletions(-) diff --git a/frame/compat/bla_trsm_amd.c b/frame/compat/bla_trsm_amd.c index 8ca7434bd8..e143a88d37 100644 --- a/frame/compat/bla_trsm_amd.c +++ b/frame/compat/bla_trsm_amd.c @@ -1291,30 +1291,7 @@ void ztrsm_ return; } } -#ifdef BLIS_ENABLE_OPENMP - - // bli_trsm_small_mt supports till n_threads equal to 8 - if( bli_cntx_trsm_small_thresh_is_met_zen(&ao, m0, n0) == true ) - { - err_t status; - status = bli_trsm_small_mt( - blis_side, - &alphao, - &ao, - &bo, - NULL, - NULL); - - if ( status == BLIS_SUCCESS ) - { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - /* Finalize BLIS. */ - bli_finalize_auto(); - return; - } - } -#endif// BLIS_ENABLE_OPENMP - } // bli_cpuid_is_avx_supported} + } // bli_cpuid_is_avx_supported #endif bli_trsmnat From d62f12a18a9e8c56ac9a64253c266017e4502a8e Mon Sep 17 00:00:00 2001 From: satish kumar nuggu Date: Mon, 29 Aug 2022 22:56:19 +0530 Subject: [PATCH 211/243] Fixed bug in DZGEMM Details: 1. In zen4 dgemm and sgemm native kernels are column-prefer kernels, cgemm and zgemm kernels are row-prefer kernels. zen3 and older arch (uses row-prefer kernels for all datatypes). Induced-transpose carried out based on kernel preference check in both crc and ccr. we don't require kernel preference check to apply Induced-transpose. 2. In case of ccr, B matrix is real. We must Induce a transposition and perform C+=A*B (crc). where A (formerly B) is real. 3. In case of crc, A matrix is real. We don't require any Induced-transpose. AMD-Internal: [CPUPL-2440] [CPUPL-2449] Change-Id: I44c53a20c8def7ddbb84797ba20260acec2086a2 --- frame/3/gemm/bli_gemm_md.c | 105 +++---------------------------------- 1 file changed, 7 insertions(+), 98 deletions(-) diff --git a/frame/3/gemm/bli_gemm_md.c b/frame/3/gemm/bli_gemm_md.c index c9450a26c2..d7e402b7c7 100644 --- a/frame/3/gemm/bli_gemm_md.c +++ b/frame/3/gemm/bli_gemm_md.c @@ -156,90 +156,20 @@ mddm_t bli_gemm_md_ccr cntx_t** cntx ) { - mddm_t doms; - - // We assume that the requested computation domain is complex. - //dom_t dom_comp_in = bli_obj_comp_domain( c ); - //dom_t dom_comp_in = BLIS_COMPLEX; - - // For ccr, the computation (ukernel) will be real, but the execution - // will appear complex to other parts of the implementation. - doms.comp = BLIS_REAL; - doms.exec = BLIS_COMPLEX; - - // Here we construct the computation datatype, which for the ccr case - // is equal to the real projection of the execution datatype, and use - // that computation datatype to query the corresponding ukernel output - // preference. - const num_t dt = BLIS_REAL | bli_obj_comp_prec( c ); - - // We can only perform this case of mixed-domain gemm, C += A*B where - // B is real, if the microkernel prefers column output. If it prefers - // row output, we must induce a transposition and perform C += A*B - // where A (formerly B) is real. - if ( bli_cntx_l3_vir_ukr_dislikes_storage_of_md( c, dt, BLIS_GEMM_UKR, *cntx ) ) - { - bli_obj_swap( a, b ); - - bli_obj_induce_trans( a ); - bli_obj_induce_trans( b ); - bli_obj_induce_trans( c ); - - return bli_gemm_md_crc( a, b, beta, c, cntx_local, cntx ); - } - // Create a local copy of the context and then prepare to use this // context instead of the one passed in. *cntx_local = **cntx; *cntx = cntx_local; - // Copy the real domain blocksizes into the slots of their complex - // counterparts. 
- blksz_t* blksz_mr = bli_cntx_get_blksz( BLIS_MR, *cntx ); - blksz_t* blksz_nr = bli_cntx_get_blksz( BLIS_NR, *cntx ); - blksz_t* blksz_mc = bli_cntx_get_blksz( BLIS_MC, *cntx ); - blksz_t* blksz_nc = bli_cntx_get_blksz( BLIS_NC, *cntx ); - blksz_t* blksz_kc = bli_cntx_get_blksz( BLIS_KC, *cntx ); - - bli_blksz_copy_dt( BLIS_FLOAT, blksz_mr, BLIS_SCOMPLEX, blksz_mr ); - bli_blksz_copy_dt( BLIS_DOUBLE, blksz_mr, BLIS_DCOMPLEX, blksz_mr ); - - bli_blksz_copy_dt( BLIS_FLOAT, blksz_nr, BLIS_SCOMPLEX, blksz_nr ); - bli_blksz_copy_dt( BLIS_DOUBLE, blksz_nr, BLIS_DCOMPLEX, blksz_nr ); - - bli_blksz_copy_dt( BLIS_FLOAT, blksz_mc, BLIS_SCOMPLEX, blksz_mc ); - bli_blksz_copy_dt( BLIS_DOUBLE, blksz_mc, BLIS_DCOMPLEX, blksz_mc ); - - bli_blksz_copy_dt( BLIS_FLOAT, blksz_nc, BLIS_SCOMPLEX, blksz_nc ); - bli_blksz_copy_dt( BLIS_DOUBLE, blksz_nc, BLIS_DCOMPLEX, blksz_nc ); - - bli_blksz_copy_dt( BLIS_FLOAT, blksz_kc, BLIS_SCOMPLEX, blksz_kc ); - bli_blksz_copy_dt( BLIS_DOUBLE, blksz_kc, BLIS_DCOMPLEX, blksz_kc ); - - // Halve both the real and complex MR's (which are both real MR's). - bli_blksz_scale_def_max( 1, 2, BLIS_FLOAT, blksz_mr ); - bli_blksz_scale_def_max( 1, 2, BLIS_DOUBLE, blksz_mr ); - bli_blksz_scale_def_max( 1, 2, BLIS_SCOMPLEX, blksz_mr ); - bli_blksz_scale_def_max( 1, 2, BLIS_DCOMPLEX, blksz_mr ); - - // Halve both the real and complex MC's (which are both real MC's). - bli_blksz_scale_def_max( 1, 2, BLIS_FLOAT, blksz_mc ); - bli_blksz_scale_def_max( 1, 2, BLIS_DOUBLE, blksz_mc ); - bli_blksz_scale_def_max( 1, 2, BLIS_SCOMPLEX, blksz_mc ); - bli_blksz_scale_def_max( 1, 2, BLIS_DCOMPLEX, blksz_mc ); - - // Use the default pack schemas in the context. - - // static func_t* bli_cntx_get_l3_vir_ukrs( l3ukr_t ukr_id, cntx_t* cntx ) - func_t* l3_vir_ukrs = bli_cntx_get_l3_vir_ukrs( BLIS_GEMM_UKR, *cntx ); + //we must induce a transposition and perform C += A*B + // where A (formerly B) is real. + bli_obj_swap( a, b ); - // Rather than check which complex datatype dt_comp refers to, we set - // the mixed-domain virtual microkernel for both types. - bli_func_set_dt( bli_cgemm_md_c2r_ref, BLIS_SCOMPLEX, l3_vir_ukrs ); - bli_func_set_dt( bli_zgemm_md_c2r_ref, BLIS_DCOMPLEX, l3_vir_ukrs ); + bli_obj_induce_trans( a ); + bli_obj_induce_trans( b ); + bli_obj_induce_trans( c ); - // Return the computation and execution domains. - return doms; + return bli_gemm_md_crc( a, b, beta, c, cntx_local, cntx ); } // ----------------------------------------------------------------------------- @@ -266,27 +196,6 @@ mddm_t bli_gemm_md_crc doms.comp = BLIS_REAL; doms.exec = BLIS_COMPLEX; - // Here we construct the computation datatype, which for the crc case - // is equal to the real projection of the execution datatype, and use - // that computation datatype to query the corresponding ukernel output - // preference. - const num_t dt = BLIS_REAL | bli_obj_comp_prec( c ); - - // We can only perform this case of mixed-domain gemm, C += A*B where - // A is real, if the microkernel prefers row output. If it prefers - // column output, we must induce a transposition and perform C += A*B - // where B (formerly A) is real. - if ( bli_cntx_l3_vir_ukr_dislikes_storage_of_md( c, dt, BLIS_GEMM_UKR, *cntx ) ) - { - bli_obj_swap( a, b ); - - bli_obj_induce_trans( a ); - bli_obj_induce_trans( b ); - bli_obj_induce_trans( c ); - - return bli_gemm_md_ccr( a, b, beta, c, cntx_local, cntx ); - } - // Create a local copy of the context and then prepare to use this // context instead of the one passed in. 
*cntx_local = **cntx; From 958c9238acc92556bb302e2fc0c7d1307fc45425 Mon Sep 17 00:00:00 2001 From: mkadavil Date: Tue, 30 Aug 2022 15:30:34 +0530 Subject: [PATCH 212/243] Output downscaling support for low precision GEMM. - Downscaling is used when GEMM output is accumulated at a higher precision and needs to be converted to a lower precision afterwards. This is required in AI workloads where quantization/dequantization routines are used. - New GEMM APIs are introduced specifically to support this use case. Currently downscaling support is added for s32, s16 and bfloat16 GEMM. AMD-Internal: [CPUPL-2475] Change-Id: I81c3ee1ba5414f62427a7a0abb6ecef0c5ff71bf --- addon/aocl_gemm/aocl_gemm_bf16bf16f32obf16.c | 157 +++++ ...emm_bf16.c => aocl_gemm_bf16bf16f32of32.c} | 10 +- addon/aocl_gemm/aocl_gemm_f32f32f32of32.c | 6 +- addon/aocl_gemm/aocl_gemm_interface_apis.h | 3 + addon/aocl_gemm/aocl_gemm_post_ops.h | 3 +- addon/aocl_gemm/aocl_gemm_u8s8s16os16.c | 30 +- addon/aocl_gemm/aocl_gemm_u8s8s16os8.c | 156 +++++ addon/aocl_gemm/aocl_gemm_u8s8s32os32.c | 6 +- addon/aocl_gemm/aocl_gemm_u8s8s32os8.c | 157 +++++ .../aocl_gemm/frame/bf16bf16f32/lpgemm_bf16.c | 73 ++- .../frame/lpgemm_5loop_interface_apis.h | 3 +- addon/aocl_gemm/frame/lpgemm_post_ops.c | 12 +- addon/aocl_gemm/frame/lpgemm_post_ops.h | 8 +- .../threading/lpgemm_thread_decor_openmp.c | 10 +- .../threading/lpgemm_thread_decor_openmp.h | 6 +- .../aocl_gemm/frame/u8s8s16/lpgemm_u8s8s16.c | 70 ++- .../aocl_gemm/frame/u8s8s32/lpgemm_u8s8s32.c | 71 ++- .../lpgemm_6x64rowmajor_bf16_amd512vnni.c | 18 +- .../lpgemm_n_fringe_bf16_amd512vnni.c | 40 +- addon/aocl_gemm/kernels/lpgemm_kernels.h | 18 +- .../u8s8s16/lpgemm_6x32rowmajor_amd256.c | 10 +- .../kernels/u8s8s16/lpgemm_n_fringe_amd256.c | 12 +- .../u8s8s32/lpgemm_6x64rowmajor_amd512vnni.c | 110 +++- .../u8s8s32/lpgemm_m_fringe_amd512vnni.c | 280 ++++++++- .../u8s8s32/lpgemm_mn_fringe_amd512vnni.c | 580 +++++++++++++++++- .../u8s8s32/lpgemm_n_fringe_amd512vnni.c | 219 ++++++- .../kernels/u8s8s32/lpgemm_s32_kern_macros.h | 40 ++ 27 files changed, 1939 insertions(+), 169 deletions(-) create mode 100644 addon/aocl_gemm/aocl_gemm_bf16bf16f32obf16.c rename addon/aocl_gemm/{aocl_gemm_bf16.c => aocl_gemm_bf16bf16f32of32.c} (93%) create mode 100644 addon/aocl_gemm/aocl_gemm_u8s8s16os8.c create mode 100644 addon/aocl_gemm/aocl_gemm_u8s8s32os8.c diff --git a/addon/aocl_gemm/aocl_gemm_bf16bf16f32obf16.c b/addon/aocl_gemm/aocl_gemm_bf16bf16f32obf16.c new file mode 100644 index 0000000000..482e7c264e --- /dev/null +++ b/addon/aocl_gemm/aocl_gemm_bf16bf16f32obf16.c @@ -0,0 +1,157 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. 
+ + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "aocl_gemm_interface_apis.h" +#include "lpgemm_types.h" +#include "lpgemm_post_ops.h" +#include "lpgemm_thread_decor_openmp.h" +#include "lpgemm_5loop_interface_apis.h" +#include "lpgemm_config.h" +#include "lpgemm_utils.h" + +AOCL_GEMM_MATMUL(bfloat16,bfloat16,bfloat16,bf16bf16f32obf16) +{ + trans_t blis_transa; + trans_t blis_transb; + + // Check if avx512_vnni ISA is supported, lpgemm matmul only works with it. + if ( bli_cpuid_is_avx512vnni_supported() == FALSE ) + { + printf(" AVX512_BF16 ISA not supported by processor, cannot perform lpgemm.\n"); + return; // Error. + } + + /* Initialize BLIS. */ + bli_init_auto(); + + // Set MC, NC, KC, NR, MR. + aocl_lpgemm_init_global_cntx(); + + // Null check for pointers. + if ( ( a == NULL ) || ( b == NULL ) || ( c == NULL ) ) + { + return; // Error. + } + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + bli_param_map_netlib_to_blis_trans( transa, &blis_transa ); + bli_param_map_netlib_to_blis_trans( transb, &blis_transb ); + + /* Perform BLAS parameter checking. */ + // Transpose not supported. + if ( ( blis_transa != BLIS_NO_TRANSPOSE ) || + ( blis_transb != BLIS_NO_TRANSPOSE ) ) + { + return; // Error. + } + if ( ( order != 'r' ) && ( order != 'R' ) ) + { + return; // Only row major supported. + } + + // Row major input expected with leading dimensions equal to row stride. + if ( ( lda != k ) || ( ldb != n ) || ( ldc != n ) ) + { + return; // Error. + } + + // Check if dimensions are valid. + if ( ( m <= 0) || ( n <= 0 ) || ( k <= 0 ) || + ( lda <= 0 ) || ( ldb <= 0 ) || ( ldc <= 0 ) ) + { + return; // Error. + } + + const inc_t rs_a = lda; + const inc_t cs_a = 1; + const inc_t rs_b = ldb; + const inc_t cs_b = 1; + const inc_t rs_c = ldc; + + AOCL_MEMORY_TAG mtag_a; + AOCL_MEMORY_TAG mtag_b; + + bli_param_map_char_to_lpmtag( mem_format_a, &mtag_a ); + bli_param_map_char_to_lpmtag( mem_format_b, &mtag_b ); + + // B matrix needs to be packed in a certain format in order to be loaded + // and used in bf16 instrution. As such the mtag_b always needs to be either + // packed or reordered. B matrix as it is (unpacked) cannot be used, and + // the mtag_b is set to packed to enable runtime packing. + if ( mtag_b == UNPACKED ) + { + mtag_b = PACK; + } + + // Only unpacked A supported now. + if ( mtag_a != UNPACKED ) + { + return; // Error. + } + + // Convert post op struct to post op linked list format. + lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS]; + lpgemm_translate_to_post_ops_list( post_op_unparsed, post_op_list, ( void* )c ); + + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. 
+ rntm_t rntm_g; + bli_rntm_init_from_global( &rntm_g ); + bli_membrk_rntm_set_membrk( &rntm_g ); + +#ifdef BLIS_ENABLE_OPENMP + lpgemm_bf16bf16f32of32_openmp_thread_decorator + ( + m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + ( float* )c, rs_c, + alpha, beta, + &rntm_g, + post_op_list, TRUE + ); +#else + lpgemm_bf16bf16f32of32_thread_decorator + ( + m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + ( float* )c, rs_c, + alpha, beta, + &rntm_g, + post_op_list, TRUE + ); +#endif +} diff --git a/addon/aocl_gemm/aocl_gemm_bf16.c b/addon/aocl_gemm/aocl_gemm_bf16bf16f32of32.c similarity index 93% rename from addon/aocl_gemm/aocl_gemm_bf16.c rename to addon/aocl_gemm/aocl_gemm_bf16bf16f32of32.c index 6138b4e259..e36399ba55 100644 --- a/addon/aocl_gemm/aocl_gemm_bf16.c +++ b/addon/aocl_gemm/aocl_gemm_bf16bf16f32of32.c @@ -46,8 +46,8 @@ AOCL_GEMM_MATMUL(bfloat16,bfloat16,float,bf16bf16f32of32) trans_t blis_transa; trans_t blis_transb; - // Check if avx512_bf16 ISA is supported, lpgemm matmul only works with it. - if ( bli_cpuid_is_avx512_bf16_supported() == FALSE ) + // Check if avx512_vnni ISA is supported, lpgemm matmul only works with it. + if ( bli_cpuid_is_avx512vnni_supported() == FALSE ) { printf(" AVX512_BF16 ISA not supported by processor, cannot perform lpgemm.\n"); return; // Error. @@ -123,7 +123,7 @@ AOCL_GEMM_MATMUL(bfloat16,bfloat16,float,bf16bf16f32of32) // Convert post op struct to post op linked list format. lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS]; - lpgemm_translate_to_post_ops_list( post_op_unparsed, post_op_list ); + lpgemm_translate_to_post_ops_list( post_op_unparsed, post_op_list, ( void* )c ); // Initialize a local runtime with global settings if necessary. Note // that in the case that a runtime is passed in, we make a local copy. @@ -140,7 +140,7 @@ AOCL_GEMM_MATMUL(bfloat16,bfloat16,float,bf16bf16f32of32) c, rs_c, alpha, beta, &rntm_g, - post_op_list + post_op_list, FALSE ); #else lpgemm_bf16bf16f32of32_thread_decorator @@ -151,7 +151,7 @@ AOCL_GEMM_MATMUL(bfloat16,bfloat16,float,bf16bf16f32of32) c, rs_c, alpha, beta, &rntm_g, - post_op_list + post_op_list, FALSE ); #endif } diff --git a/addon/aocl_gemm/aocl_gemm_f32f32f32of32.c b/addon/aocl_gemm/aocl_gemm_f32f32f32of32.c index 179882d412..90c3ff41c2 100644 --- a/addon/aocl_gemm/aocl_gemm_f32f32f32of32.c +++ b/addon/aocl_gemm/aocl_gemm_f32f32f32of32.c @@ -124,7 +124,7 @@ AOCL_GEMM_MATMUL(float,float,float,f32f32f32of32) // Convert post op struct to post op linked list format. lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS]; - lpgemm_translate_to_post_ops_list( post_op_unparsed, post_op_list ); + lpgemm_translate_to_post_ops_list( post_op_unparsed, post_op_list, ( void* )c ); // Initialize a local runtime with global settings if necessary. Note // that in the case that a runtime is passed in, we make a local copy. @@ -141,7 +141,7 @@ AOCL_GEMM_MATMUL(float,float,float,f32f32f32of32) c, rs_c, alpha, beta, &rntm_g, - post_op_list + post_op_list, FALSE ); #else // Setting pack A by default for non open mp case. 
@@ -155,7 +155,7 @@ AOCL_GEMM_MATMUL(float,float,float,f32f32f32of32) c, rs_c, alpha, beta, &rntm_g, - post_op_list + post_op_list, FALSE ); #endif diff --git a/addon/aocl_gemm/aocl_gemm_interface_apis.h b/addon/aocl_gemm/aocl_gemm_interface_apis.h index b38c2c1599..69ee4ad8e1 100644 --- a/addon/aocl_gemm/aocl_gemm_interface_apis.h +++ b/addon/aocl_gemm/aocl_gemm_interface_apis.h @@ -100,5 +100,8 @@ AOCL_GEMM_MATMUL(float,float,float,f32f32f32of32); AOCL_GEMM_MATMUL(uint8_t,int8_t,int32_t,u8s8s32os32); AOCL_GEMM_MATMUL(uint8_t,int8_t,int16_t,u8s8s16os16); AOCL_GEMM_MATMUL(bfloat16,bfloat16,float,bf16bf16f32of32); +AOCL_GEMM_MATMUL(uint8_t,int8_t,int8_t,u8s8s32os8); +AOCL_GEMM_MATMUL(uint8_t,int8_t,int8_t,u8s8s16os8); +AOCL_GEMM_MATMUL(bfloat16,bfloat16,bfloat16,bf16bf16f32obf16); #endif // AOCL_GEMM_INTERFACE_H diff --git a/addon/aocl_gemm/aocl_gemm_post_ops.h b/addon/aocl_gemm/aocl_gemm_post_ops.h index 4a739892a4..86034598ac 100644 --- a/addon/aocl_gemm/aocl_gemm_post_ops.h +++ b/addon/aocl_gemm/aocl_gemm_post_ops.h @@ -48,6 +48,7 @@ typedef enum SUM = 1, ELTWISE = 2, BIAS = 3, + SCALE = 4, } AOCL_POST_OP_TYPE; typedef struct @@ -63,7 +64,7 @@ typedef struct void* scale_factor; void* buff; void* zero_point; -} aocl_post_op_sum; +} aocl_post_op_sum; // Also use for scale. typedef struct { diff --git a/addon/aocl_gemm/aocl_gemm_u8s8s16os16.c b/addon/aocl_gemm/aocl_gemm_u8s8s16os16.c index 4d0e8565a9..20b987d3b6 100644 --- a/addon/aocl_gemm/aocl_gemm_u8s8s16os16.c +++ b/addon/aocl_gemm/aocl_gemm_u8s8s16os16.c @@ -122,7 +122,7 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int16_t,u8s8s16os16) // Convert post op struct to post op linked list format. lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS]; - lpgemm_translate_to_post_ops_list( post_op_unparsed, post_op_list ); + lpgemm_translate_to_post_ops_list( post_op_unparsed, post_op_list, ( void* )c ); // Initialize a local runtime with global settings if necessary. Note // that in the case that a runtime is passed in, we make a local copy. @@ -133,24 +133,24 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int16_t,u8s8s16os16) #ifdef BLIS_ENABLE_OPENMP lpgemm_u8s8s16o16_openmp_thread_decorator ( - m, n, k, - a, rs_a, cs_a, mtag_a, - b, rs_b, cs_b, mtag_b, - c, rs_c, - alpha, beta, - &rntm_g, - post_op_list + m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + c, rs_c, + alpha, beta, + &rntm_g, + post_op_list, FALSE ); #else lpgemm_u8s8s16o16_thread_decorator ( - m, n, k, - a, rs_a, cs_a, mtag_a, - b, rs_b, cs_b, mtag_b, - c, rs_c, - alpha, beta, - &rntm_g, - post_op_list + m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + c, rs_c, + alpha, beta, + &rntm_g, + post_op_list, FALSE ); #endif } diff --git a/addon/aocl_gemm/aocl_gemm_u8s8s16os8.c b/addon/aocl_gemm/aocl_gemm_u8s8s16os8.c new file mode 100644 index 0000000000..90aeb4c7b8 --- /dev/null +++ b/addon/aocl_gemm/aocl_gemm_u8s8s16os8.c @@ -0,0 +1,156 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. 
+ - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "aocl_gemm_interface_apis.h" +#include "lpgemm_types.h" +#include "lpgemm_5loop_interface_apis.h" +#include "lpgemm_config.h" +#include "lpgemm_utils.h" +#include "lpgemm_thread_decor_openmp.h" +#include "lpgemm_post_ops.h" + +AOCL_GEMM_MATMUL(uint8_t,int8_t,int8_t,u8s8s16os8) +{ + trans_t blis_transa; + trans_t blis_transb; + + // Check if avx ISA is supported, lpgemm u8s8s16os16 matmul only works with it. + if ( bli_cpuid_is_avx_supported() == FALSE ) + { + printf(" AVX2 ISA not supported by processor, cannot perform lpgemm.\n"); + return; // Error. + } + + /* Initialize BLIS. */ + bli_init_auto(); + + // Set MC, NC, KC, NR, MR. + aocl_lpgemm_init_global_cntx(); + + // Null check for pointers. + if ((a == NULL) || (b == NULL) || (c == NULL)) + { + return; // Error. + } + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + bli_param_map_netlib_to_blis_trans(transa, &blis_transa); + bli_param_map_netlib_to_blis_trans(transb, &blis_transb); + + /* Perform BLAS parameter checking. */ + // Transpose not supported. + if ( ( blis_transa != BLIS_NO_TRANSPOSE ) || + ( blis_transb != BLIS_NO_TRANSPOSE ) ) + { + return; // Error. + } + if ( ( order != 'r' ) && ( order != 'R' ) ) + { + return; // Only row major supported. + } + + // Row major input expected with leading dimensions equal to row stride. + if ((lda != k) || (ldb != n) || (ldc != n)) + { + return; // Error. + } + + // Check if dimensions are valid. + if ((m <= 0) || (n <= 0) || (k <= 0) || (lda <= 0) || (ldb <= 0) || (ldc <= 0)) + { + return; // Error. + } + + const inc_t rs_a = lda; + const inc_t cs_a = 1; + const inc_t rs_b = ldb; + const inc_t cs_b = 1; + const inc_t rs_c = ldc; + + AOCL_MEMORY_TAG mtag_a; + AOCL_MEMORY_TAG mtag_b; + + bli_param_map_char_to_lpmtag(mem_format_a, &mtag_a); + bli_param_map_char_to_lpmtag(mem_format_b, &mtag_b); + + // B matrix needs to be packed in a certain format in order to be loaded + // and used in VNNI instrution. As such the mtag_b always needs to be either + // packed or reordered. B matrix as it is (unpacked) cannot be used, and + // the mtag_b is set to packed to enable runtime packing. + if (mtag_b == UNPACKED) + { + mtag_b = PACK; + } + + // Only unpacked A supported now. + if (mtag_a != UNPACKED) + { + return; // Error. 
+ } + + // Convert post op struct to post op linked list format. + lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS]; + lpgemm_translate_to_post_ops_list( post_op_unparsed, post_op_list, ( void* )c ); + + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. + rntm_t rntm_g; + bli_rntm_init_from_global(&rntm_g); + bli_membrk_rntm_set_membrk(&rntm_g); + +#ifdef BLIS_ENABLE_OPENMP + lpgemm_u8s8s16o16_openmp_thread_decorator + ( + m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + ( int16_t* )c, rs_c, + alpha, beta, + &rntm_g, + post_op_list, TRUE + ); +#else + lpgemm_u8s8s16o16_thread_decorator + ( + m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + ( int16_t* )c, rs_c, + alpha, beta, + &rntm_g, + post_op_list, TRUE + ); +#endif +} diff --git a/addon/aocl_gemm/aocl_gemm_u8s8s32os32.c b/addon/aocl_gemm/aocl_gemm_u8s8s32os32.c index 6263d2ce27..8f75db1274 100644 --- a/addon/aocl_gemm/aocl_gemm_u8s8s32os32.c +++ b/addon/aocl_gemm/aocl_gemm_u8s8s32os32.c @@ -123,7 +123,7 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int32_t,u8s8s32os32) // Convert post op struct to post op linked list format. lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS]; - lpgemm_translate_to_post_ops_list( post_op_unparsed, post_op_list ); + lpgemm_translate_to_post_ops_list( post_op_unparsed, post_op_list, ( void* )c ); // Initialize a local runtime with global settings if necessary. Note // that in the case that a runtime is passed in, we make a local copy. @@ -140,7 +140,7 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int32_t,u8s8s32os32) c, rs_c, alpha, beta, &rntm_g, - post_op_list + post_op_list, FALSE ); #else lpgemm_u8s8s32o32_thread_decorator @@ -151,7 +151,7 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int32_t,u8s8s32os32) c, rs_c, alpha, beta, &rntm_g, - post_op_list + post_op_list, FALSE ); #endif } diff --git a/addon/aocl_gemm/aocl_gemm_u8s8s32os8.c b/addon/aocl_gemm/aocl_gemm_u8s8s32os8.c new file mode 100644 index 0000000000..df90de4968 --- /dev/null +++ b/addon/aocl_gemm/aocl_gemm_u8s8s32os8.c @@ -0,0 +1,157 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "aocl_gemm_interface_apis.h" +#include "lpgemm_types.h" +#include "lpgemm_post_ops.h" +#include "lpgemm_thread_decor_openmp.h" +#include "lpgemm_5loop_interface_apis.h" +#include "lpgemm_config.h" +#include "lpgemm_utils.h" + +AOCL_GEMM_MATMUL(uint8_t,int8_t,int8_t,u8s8s32os8) +{ + trans_t blis_transa; + trans_t blis_transb; + + // Check if avx512_vnni ISA is supported, lpgemm matmul only works with it. + if ( bli_cpuid_is_avx512vnni_supported() == FALSE ) + { + printf(" AVX512_VNNI ISA not supported by processor, cannot perform lpgemm.\n"); + return; // Error. + } + + /* Initialize BLIS. */ + bli_init_auto(); + + // Set MC, NC, KC, NR, MR. + aocl_lpgemm_init_global_cntx(); + + // Null check for pointers. + if ( ( a == NULL ) || ( b == NULL ) || ( c == NULL ) ) + { + return; // Error. + } + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + bli_param_map_netlib_to_blis_trans( transa, &blis_transa ); + bli_param_map_netlib_to_blis_trans( transb, &blis_transb ); + + /* Perform BLAS parameter checking. */ + // Transpose not supported. + if ( ( blis_transa != BLIS_NO_TRANSPOSE ) || + ( blis_transb != BLIS_NO_TRANSPOSE ) ) + { + return; // Error. + } + if ( ( order != 'r' ) && ( order != 'R' ) ) + { + return; // Only row major supported. + } + + // Row major input expected with leading dimensions equal to row stride. + if ( ( lda != k ) || ( ldb != n ) || ( ldc != n ) ) + { + return; // Error. + } + + // Check if dimensions are valid. + if ( ( m <= 0) || ( n <= 0 ) || ( k <= 0 ) || + ( lda <= 0 ) || ( ldb <= 0 ) || ( ldc <= 0 ) ) + { + return; // Error. + } + + const inc_t rs_a = lda; + const inc_t cs_a = 1; + const inc_t rs_b = ldb; + const inc_t cs_b = 1; + const inc_t rs_c = ldc; + + AOCL_MEMORY_TAG mtag_a; + AOCL_MEMORY_TAG mtag_b; + + bli_param_map_char_to_lpmtag( mem_format_a, &mtag_a ); + bli_param_map_char_to_lpmtag( mem_format_b, &mtag_b ); + + // B matrix needs to be packed in a certain format in order to be loaded + // and used in VNNI instrution. As such the mtag_b always needs to be either + // packed or reordered. B matrix as it is (unpacked) cannot be used, and + // the mtag_b is set to packed to enable runtime packing. + if ( mtag_b == UNPACKED ) + { + mtag_b = PACK; + } + + // Only unpacked A supported now. + if ( mtag_a != UNPACKED ) + { + return; // Error. + } + + // Convert post op struct to post op linked list format. + lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS]; + lpgemm_translate_to_post_ops_list( post_op_unparsed, post_op_list, ( void* )c ); + + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. 
+ rntm_t rntm_g; + bli_rntm_init_from_global( &rntm_g ); + bli_membrk_rntm_set_membrk( &rntm_g ); + +#ifdef BLIS_ENABLE_OPENMP + lpgemm_u8s8s32o32_openmp_thread_decorator + ( + m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + ( int32_t* )c, rs_c, + alpha, beta, + &rntm_g, + post_op_list, TRUE + ); +#else + lpgemm_u8s8s32o32_thread_decorator + ( + m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + ( int32_t* )c, rs_c, + alpha, beta, + &rntm_g, + post_op_list, TRUE + ); +#endif +} diff --git a/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_bf16.c b/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_bf16.c index e03ec87421..1a092377e7 100644 --- a/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_bf16.c +++ b/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_bf16.c @@ -59,6 +59,8 @@ LPGEMM_5LOOP(bfloat16,bfloat16,float,bf16bf16f32of32) float* c_use_jc = NULL; float* c_use_ic = NULL; + dim_t rs_c_use = rs_c; + dim_t rs_c_downscale = rs_c; // Pack buffer for B. bfloat16* pack_b_buffer_bf16; @@ -66,6 +68,11 @@ LPGEMM_5LOOP(bfloat16,bfloat16,float,bf16bf16f32of32) siz_t mem_b_size_req = 0; dim_t packb_min_NR = 16; + // Temporary buffer for C accumulation when downscaling is required. + float* temp_scal_c_buffer_bf16; + mem_t mem_scale_c = BLIS_MEM_INITIALIZER; + siz_t mem_scale_c_size_req = 0; + // kc needs to be a multiple of 2 so that it can be used with dpbf16_ps // instruction. Padding is added in cases this condition is not // satisfied, and therefore the k offset used for packed/reordered @@ -82,14 +89,16 @@ LPGEMM_5LOOP(bfloat16,bfloat16,float,bf16bf16f32of32) lpgemm_gen_thrinfo( thread, &thread_jc, &thread_ic ); - // Compute the JC loop thread range for the current thread. + // Compute the JC, IC loop thread range for the current thread. dim_t jc_start, jc_end; bli_thread_range_sub( &thread_jc, n, NR, FALSE, &jc_start, &jc_end ); + dim_t ic_start, ic_end; + bli_thread_range_sub( &thread_ic, m, MR, FALSE, &ic_start, &ic_end ); + for ( dim_t jc = jc_start; jc < jc_end; jc += NC ) { dim_t nc0 = bli_min( ( jc_end - jc ), NC ); - c_use_jc = c + jc; dim_t jc_cur_loop = jc; dim_t jc_cur_loop_rem = 0; @@ -105,6 +114,47 @@ LPGEMM_5LOOP(bfloat16,bfloat16,float,bf16bf16f32of32) ); } + if ( c_downscale == FALSE ) + { + c_use_jc = c + jc; + } + // Temp accumulaton buffer for C allocation. + else if ( c_downscale == TRUE ) + { + mem_scale_c_size_req = sizeof( float ) * nc0 * ( ic_end - ic_start ); + + lpgemm_alloc_mem_panel + ( + mem_scale_c_size_req, BLIS_BUFFER_FOR_C_PANEL, + &mem_scale_c, rntm + ); + + temp_scal_c_buffer_bf16 = bli_mem_buffer( &mem_scale_c ); + + c_use_jc = ( float* )temp_scal_c_buffer_bf16; + + if ( beta != 0 ) + { + dim_t i_temp = 0; + dim_t j_temp = 0; + // Upscale out C to temporary C matrix. + for ( dim_t i_dscale = ic_start; i_dscale < ic_end; ++i_dscale ) + { + j_temp = 0; + for ( dim_t j_dscale = jc; j_dscale < nc0; ++j_dscale ) + { + *( temp_scal_c_buffer_bf16 + ( nc0 * i_temp ) + j_temp ) = + ( float )( *( c + ( rs_c * i_dscale ) + j_dscale ) ); + j_temp++; + } + i_temp++; + } + } + + // The temp c buffer stride is modified as opposed to original C matrix. + rs_c_use = nc0; + } + for ( dim_t pc = 0; pc < k; pc += KC ) { float beta0 = ( pc == 0 ) ? 
beta : 1; @@ -197,9 +247,8 @@ LPGEMM_5LOOP(bfloat16,bfloat16,float,bf16bf16f32of32) ); b_use = pack_b_buffer_bf16; } - // B part getting processed - if ( mtag_b == REORDERED ) + else if ( mtag_b == REORDERED ) { // In multi-threaded scenarios, an extra offset into a given // packed B panel is required, since the jc loop split can @@ -212,13 +261,10 @@ LPGEMM_5LOOP(bfloat16,bfloat16,float,bf16bf16f32of32) get_packb_nr64_bf16bf16f32of32_strides( &rs_b_use, &cs_b_use ); } - dim_t ic_start, ic_end; - bli_thread_range_sub( &thread_ic, m, MR, FALSE, &ic_start, &ic_end ); - for ( dim_t ic = ic_start; ic < ic_end; ic += MC ) { dim_t mc0 = bli_min( ( ic_end - ic ), MC ); - c_use_ic = c_use_jc + ( rs_c * ic ); + c_use_ic = c_use_jc + ( rs_c_use * ic ); if ( mtag_a == UNPACKED ) { @@ -241,9 +287,9 @@ LPGEMM_5LOOP(bfloat16,bfloat16,float,bf16bf16f32of32) mc0, nr0, kc0, a_use, rs_a, cs_a_use, a_block_stride, ( b_use + ( jr * kc0_updated ) ), rs_b_use, cs_b_use, - ( c_use_ic + jr ), rs_c, 1, + ( c_use_ic + jr ), rs_c_use, 1, alpha, beta0, - is_last_k, ic, ( jc + jr ), post_op_list + is_last_k, ic, ( jc + jr ), post_op_list, rs_c_downscale ); } } @@ -273,4 +319,11 @@ LPGEMM_5LOOP(bfloat16,bfloat16,float,bf16bf16f32of32) } } } + if ( c_downscale == TRUE ) + { + if ( bli_mem_is_alloc( &mem_scale_c ) ) + { + bli_membrk_release( rntm, &mem_scale_c ); + } + } } diff --git a/addon/aocl_gemm/frame/lpgemm_5loop_interface_apis.h b/addon/aocl_gemm/frame/lpgemm_5loop_interface_apis.h index 5a5d2eaff8..ef63ef09eb 100644 --- a/addon/aocl_gemm/frame/lpgemm_5loop_interface_apis.h +++ b/addon/aocl_gemm/frame/lpgemm_5loop_interface_apis.h @@ -59,7 +59,8 @@ void lpgemm_rowvar_ ## LP_SFX \ C_type beta, \ rntm_t* rntm, \ lpgemm_thrinfo_t* thread, \ - lpgemm_post_op* post_op_list \ + lpgemm_post_op* post_op_list, \ + bool c_downscale \ ) \ LPGEMM_5LOOP(uint8_t,int8_t,int32_t,u8s8s32o32); diff --git a/addon/aocl_gemm/frame/lpgemm_post_ops.c b/addon/aocl_gemm/frame/lpgemm_post_ops.c index 45b479c19e..3b25527a20 100644 --- a/addon/aocl_gemm/frame/lpgemm_post_ops.c +++ b/addon/aocl_gemm/frame/lpgemm_post_ops.c @@ -58,7 +58,8 @@ BLIS_INLINE void lpgemm_set_node_params void lpgemm_translate_to_post_ops_list ( aocl_post_op* post_op_unparsed, - lpgemm_post_op* post_op_list + lpgemm_post_op* post_op_list, + void* scale_buffer ) { if ( post_op_unparsed == NULL ) @@ -131,6 +132,15 @@ void lpgemm_translate_to_post_ops_list NULL, NULL, NULL, FALSE ); break; + case SCALE: + lpgemm_set_node_params + ( + ( post_op_list + i ), POST_OPS_DOWNSCALE, + post_op_unparsed->sum.zero_point, + NULL, scale_buffer, + post_op_unparsed->sum.scale_factor, FALSE + ); + break; default: break; } diff --git a/addon/aocl_gemm/frame/lpgemm_post_ops.h b/addon/aocl_gemm/frame/lpgemm_post_ops.h index 7dea44b0c7..ce26fe1397 100644 --- a/addon/aocl_gemm/frame/lpgemm_post_ops.h +++ b/addon/aocl_gemm/frame/lpgemm_post_ops.h @@ -41,7 +41,8 @@ typedef enum POST_OPS_BIAS = 1, POST_OPS_RELU = 2, POST_OPS_RELU_SCALE = 3, - POST_OPS_SUM = 4, + POST_OPS_DOWNSCALE = 4, + POST_OPS_SUM = 5, } LPGEMM_POST_OP_CODE; // Used as an internal structure. 
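The new SCALE entry is translated above into a POST_OPS_DOWNSCALE node whose scale_factor and zero_point are taken from the reused aocl_post_op_sum struct, together with the caller's C pointer as the scale_buffer. As a rough scalar picture of the arithmetic the vectorized CVT_MULRND_CVT32_CVT8 macros later apply with that node, the sketch below multiplies an int32 accumulator by a per-column float scale, rounds, and saturates to int8. The helper name, the rounding mode, and the omission of zero-point handling are assumptions for illustration only, not taken from this patch.

#include <math.h>
#include <stdint.h>

/* Hypothetical scalar reference for the downscale post-op. */
static inline int8_t downscale_s32_to_s8_ref( int32_t acc, float scale )
{
    float scaled = ( float )acc * scale; /* apply the per-column scale factor */
    long  r      = lrintf( scaled );     /* round to nearest integer          */
    if ( r > 127 )  { r = 127; }         /* saturate to the int8 range        */
    if ( r < -128 ) { r = -128; }
    return ( int8_t )r;
}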
@@ -50,7 +51,7 @@ typedef struct lpgemm_post_op_t LPGEMM_POST_OP_CODE op_code; void* op_args1; void* op_args2; // alpha, zero_point - void* op_args3; // beta + void* op_args3; // beta, downscale buffer/original C matrix void* scale_factor; bool is_power_of_2; struct lpgemm_post_op_t* next; @@ -59,7 +60,8 @@ typedef struct lpgemm_post_op_t void lpgemm_translate_to_post_ops_list ( aocl_post_op* post_op_unparsed, - lpgemm_post_op* post_op_list + lpgemm_post_op* post_op_list, + void* scale_buffer ); #define POST_OP_LABEL_LASTK_SAFE_JUMP \ diff --git a/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.c b/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.c index cf2c9231c3..e22b248157 100644 --- a/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.c +++ b/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.c @@ -477,7 +477,8 @@ void lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \ C_type alpha, \ C_type beta, \ rntm_t* rntm_g, \ - lpgemm_post_op* post_op_list \ + lpgemm_post_op* post_op_list, \ + bool c_downscale \ ) \ { \ dim_t n_threads; \ @@ -537,7 +538,7 @@ void lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \ beta, \ &rntm_l, \ &thread, \ - post_op_list \ + post_op_list, c_downscale \ ); \ } \ if ( jc_ways > BLIS_LPGEMM_NUM_STATIC_COMMS ) \ @@ -572,7 +573,8 @@ void lpgemm_ ## LPGEMM_SFX ## _thread_decorator \ C_type alpha, \ C_type beta, \ rntm_t* rntm_g, \ - lpgemm_post_op* post_op_list \ + lpgemm_post_op* post_op_list, \ + bool c_downscale \ ) \ { \ dim_t n_threads = 1; \ @@ -611,7 +613,7 @@ void lpgemm_ ## LPGEMM_SFX ## _thread_decorator \ beta, \ rntm_g, \ &thread, \ - post_op_list \ + post_op_list, c_downscale \ ); \ } \ diff --git a/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.h b/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.h index a32a3b580a..78e01291ac 100644 --- a/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.h +++ b/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.h @@ -60,7 +60,8 @@ void lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \ C_type alpha, \ C_type beta, \ rntm_t* rntm_g, \ - lpgemm_post_op* post_op_list \ + lpgemm_post_op* post_op_list, \ + bool c_downscale \ ); \ GEN_LPGEMM_OPENMP_DECORATOR_FN(uint8_t,int8_t,int16_t,u8s8s16o16) @@ -89,7 +90,8 @@ void lpgemm_ ## LPGEMM_SFX ## _thread_decorator \ C_type alpha, \ C_type beta, \ rntm_t* rntm_g, \ - lpgemm_post_op* post_op_list \ + lpgemm_post_op* post_op_list, \ + bool c_downscale \ ); \ GEN_LPGEMM_DECORATOR_FN(uint8_t,int8_t,int16_t,u8s8s16o16) diff --git a/addon/aocl_gemm/frame/u8s8s16/lpgemm_u8s8s16.c b/addon/aocl_gemm/frame/u8s8s16/lpgemm_u8s8s16.c index 76e5640f73..0465abce66 100644 --- a/addon/aocl_gemm/frame/u8s8s16/lpgemm_u8s8s16.c +++ b/addon/aocl_gemm/frame/u8s8s16/lpgemm_u8s8s16.c @@ -65,6 +65,8 @@ LPGEMM_5LOOP(uint8_t,int8_t,int16_t,u8s8s16o16) int16_t *c_use_jc = NULL; int16_t *c_use_ic = NULL; + dim_t rs_c_use = rs_c; + dim_t rs_c_downscale = rs_c; // Pack buffer for B. int8_t *pack_b_buffer_u8s8s16o16; @@ -72,6 +74,11 @@ LPGEMM_5LOOP(uint8_t,int8_t,int16_t,u8s8s16o16) dim_t packb_min_NR = 16; siz_t mem_b_size_req = 0; + // Temporary buffer for C accumulation when downscaling is required. 
+ int16_t* temp_scal_c_buffer_u8s8s16o16; + mem_t mem_scale_c = BLIS_MEM_INITIALIZER; + siz_t mem_scale_c_size_req = 0; + // Making multiple of 2 to suit k in vpmaddubsw dim_t k_updated = make_multiple_of_n( k, 2 ); @@ -84,14 +91,16 @@ LPGEMM_5LOOP(uint8_t,int8_t,int16_t,u8s8s16o16) lpgemm_gen_thrinfo(thread, &thread_jc, &thread_ic); - // Compute the JC loop thread range for the current thread. + // Compute the JC, IC loop thread range for the current thread. dim_t jc_start, jc_end; bli_thread_range_sub(&thread_jc, n, NR, FALSE, &jc_start, &jc_end); + dim_t ic_start, ic_end; + bli_thread_range_sub(&thread_ic, m, MR, FALSE, &ic_start, &ic_end); + for (dim_t jc = jc_start; jc < jc_end; jc += NC) { dim_t nc0 = bli_min((jc_end - jc), NC); - c_use_jc = c + jc; dim_t jc_cur_loop = jc; dim_t jc_cur_loop_rem = 0; @@ -107,6 +116,47 @@ LPGEMM_5LOOP(uint8_t,int8_t,int16_t,u8s8s16o16) ); } + if ( c_downscale == FALSE ) + { + c_use_jc = c + jc; + } + // Temp accumulaton buffer for C allocation. + else if ( c_downscale == TRUE ) + { + mem_scale_c_size_req = sizeof( int16_t ) * nc0 * ( ic_end - ic_start ); + + lpgemm_alloc_mem_panel + ( + mem_scale_c_size_req, BLIS_BUFFER_FOR_C_PANEL, + &mem_scale_c, rntm + ); + + temp_scal_c_buffer_u8s8s16o16 = bli_mem_buffer( &mem_scale_c ); + + c_use_jc = ( int16_t* )temp_scal_c_buffer_u8s8s16o16; + + if ( beta != 0 ) + { + dim_t i_temp = 0; + dim_t j_temp = 0; + // Upscale out C to temporary C matrix. + for ( dim_t i_dscale = ic_start; i_dscale < ic_end; ++i_dscale ) + { + j_temp = 0; + for ( dim_t j_dscale = jc; j_dscale < nc0; ++j_dscale ) + { + *( temp_scal_c_buffer_u8s8s16o16 + ( nc0 * i_dscale ) + j_dscale ) = + ( int16_t )( *( c + ( rs_c * i_dscale ) + j_dscale ) ); + j_temp++; + } + i_temp++; + } + } + + // The temp c buffer stride is modified as opposed to original C matrix. + rs_c_use = nc0; + } + for (dim_t pc = 0; pc < k; pc += KC) { int16_t beta0 = (pc == 0) ? beta : 1; @@ -215,13 +265,10 @@ LPGEMM_5LOOP(uint8_t,int8_t,int16_t,u8s8s16o16) return; } - dim_t ic_start, ic_end; - bli_thread_range_sub(&thread_ic, m, MR, FALSE, &ic_start, &ic_end); - for (dim_t ic = ic_start; ic < ic_end; ic += MC) { dim_t mc0 = bli_min((ic_end - ic), MC); - c_use_ic = c_use_jc + (rs_c * ic); + c_use_ic = c_use_jc + (rs_c_use * ic); a_use = a + (rs_a * ic) + (cs_a * pc); cs_a_use = 1; @@ -238,9 +285,9 @@ LPGEMM_5LOOP(uint8_t,int8_t,int16_t,u8s8s16o16) mc0, nr0, kc0, a_use, rs_a_use, cs_a_use, a_block_stride, (b_use + (jr * kc0_updated)), rs_b_use, cs_b_use, - (c_use_ic + jr), rs_c, 1, + (c_use_ic + jr), rs_c_use, 1, alpha, beta0, - is_last_k, ic, ( jc + jr ), post_op_list + is_last_k, ic, ( jc + jr ), post_op_list, rs_c_downscale ); } } @@ -269,4 +316,11 @@ LPGEMM_5LOOP(uint8_t,int8_t,int16_t,u8s8s16o16) } } } + if ( c_downscale == TRUE ) + { + if ( bli_mem_is_alloc( &mem_scale_c ) ) + { + bli_membrk_release( rntm, &mem_scale_c ); + } + } } diff --git a/addon/aocl_gemm/frame/u8s8s32/lpgemm_u8s8s32.c b/addon/aocl_gemm/frame/u8s8s32/lpgemm_u8s8s32.c index 2c5ff4999c..019f081860 100644 --- a/addon/aocl_gemm/frame/u8s8s32/lpgemm_u8s8s32.c +++ b/addon/aocl_gemm/frame/u8s8s32/lpgemm_u8s8s32.c @@ -68,6 +68,8 @@ LPGEMM_5LOOP(uint8_t,int8_t,int32_t,u8s8s32o32) int32_t* c_use_jc = NULL; int32_t* c_use_ic = NULL; + dim_t rs_c_use = rs_c; + dim_t rs_c_downscale = rs_c; // Pack buffer for A. 
uint8_t* pack_a_buffer_u8s8s32o32; @@ -80,6 +82,11 @@ LPGEMM_5LOOP(uint8_t,int8_t,int32_t,u8s8s32o32) siz_t mem_b_size_req = 0; dim_t packb_min_NR = get_packb_u8s8s32o32_min_NR(); + // Temporary buffer for C accumulation when downscaling is required. + int32_t* temp_scal_c_buffer_u8s8s32o32; + mem_t mem_scale_c = BLIS_MEM_INITIALIZER; + siz_t mem_scale_c_size_req = 0; + // kc needs to be a multiple of 4 so that it can be used with vpdpbusd // instruction. Padding is added in cases this condition is not // satisfied, and therefore the k offset used for packed/reordered @@ -95,14 +102,16 @@ LPGEMM_5LOOP(uint8_t,int8_t,int32_t,u8s8s32o32) lpgemm_gen_thrinfo( thread, &thread_jc, &thread_ic ); - // Compute the JC loop thread range for the current thread. + // Compute the JC, IC loop thread range for the current thread. dim_t jc_start, jc_end; bli_thread_range_sub( &thread_jc, n, NR, FALSE, &jc_start, &jc_end ); + dim_t ic_start, ic_end; + bli_thread_range_sub( &thread_ic, m, MR, FALSE, &ic_start, &ic_end ); + for ( dim_t jc = jc_start; jc < jc_end; jc += NC ) { dim_t nc0 = bli_min( ( jc_end - jc ), NC ); - c_use_jc = c + jc; dim_t jc_cur_loop = jc; dim_t jc_cur_loop_rem = 0; @@ -118,6 +127,47 @@ LPGEMM_5LOOP(uint8_t,int8_t,int32_t,u8s8s32o32) ); } + if ( c_downscale == FALSE ) + { + c_use_jc = c + jc; + } + // Temp accumulaton buffer for C allocation. + else if ( c_downscale == TRUE ) + { + mem_scale_c_size_req = sizeof( int32_t ) * nc0 * ( ic_end - ic_start ); + + lpgemm_alloc_mem_panel + ( + mem_scale_c_size_req, BLIS_BUFFER_FOR_C_PANEL, + &mem_scale_c, rntm + ); + + temp_scal_c_buffer_u8s8s32o32 = bli_mem_buffer( &mem_scale_c ); + + c_use_jc = ( int32_t* )temp_scal_c_buffer_u8s8s32o32; + + if ( beta != 0 ) + { + dim_t i_temp = 0; + dim_t j_temp = 0; + // Upscale out C to temporary C matrix. + for ( dim_t i_dscale = ic_start; i_dscale < ic_end; ++i_dscale ) + { + j_temp = 0; + for ( dim_t j_dscale = jc; j_dscale < nc0; ++j_dscale ) + { + *( temp_scal_c_buffer_u8s8s32o32 + ( nc0 * i_temp ) + j_temp ) = + ( int32_t )( *( c + ( rs_c * i_dscale ) + j_dscale ) ); + j_temp++; + } + i_temp++; + } + } + + // The temp c buffer stride is modified as opposed to original C matrix. + rs_c_use = nc0; + } + for ( dim_t pc = 0; pc < k; pc += KC ) { int32_t beta0 = ( pc == 0 ) ? beta : 1; @@ -227,13 +277,11 @@ LPGEMM_5LOOP(uint8_t,int8_t,int32_t,u8s8s32o32) return; } - dim_t ic_start, ic_end; - bli_thread_range_sub( &thread_ic, m, MR, FALSE, &ic_start, &ic_end ); - for ( dim_t ic = ic_start; ic < ic_end; ic += MC ) { dim_t mc0 = bli_min( ( ic_end - ic ), MC ); - c_use_ic = c_use_jc + ( rs_c * ic ); + + c_use_ic = c_use_jc + ( rs_c_use * ic ); // Matrix A packed and reordered code path is not triggerred // currently since we do not support it yet. 
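The hunk above is the core of the downscale path in the 5-loop: when c_downscale is TRUE the kernels accumulate into an int32 scratch panel of width nc0, and any existing int8 C values are first widened into that panel so beta can still be applied at full precision. The helper below is only a compact restatement of that widening step under illustrative names; it is a sketch of the idea, not code from the patch.

#include <stddef.h>
#include <stdint.h>

static void widen_c_panel_s8_to_s32
     (
       const int8_t* c,        /* caller's int8 C matrix                     */
       size_t        rs_c,     /* row stride of the caller's C               */
       size_t        ic_start, /* first row covered by this thread           */
       size_t        ic_end,   /* one past the last row covered              */
       size_t        jc,       /* first column of the current NC panel       */
       size_t        nc0,      /* panel width, also the scratch row stride   */
       int32_t*      c_temp    /* (ic_end - ic_start) x nc0 scratch panel    */
     )
{
    for ( size_t i = ic_start, it = 0; i < ic_end; ++i, ++it )
    {
        for ( size_t jt = 0; jt < nc0; ++jt )
        {
            /* Widen the existing output so beta can be applied in int32. */
            c_temp[ ( nc0 * it ) + jt ] =
                ( int32_t )c[ ( rs_c * i ) + ( jc + jt ) ];
        }
    }
}

Using nc0 as the scratch row stride is what the rs_c_use = nc0 assignment above reflects.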
@@ -285,9 +333,9 @@ LPGEMM_5LOOP(uint8_t,int8_t,int32_t,u8s8s32o32) mc0, nr0, kc0, a_use, rs_a_use, cs_a_use, a_block_stride, ( b_use + ( jr * kc0_updated ) ), rs_b_use, cs_b_use, - ( c_use_ic + jr ), rs_c, 1, + ( c_use_ic + jr ), rs_c_use, 1, alpha, beta0, - is_last_k, ic, ( jc + jr ), post_op_list + is_last_k, ic, ( jc + jr ), post_op_list, rs_c_downscale ); } } @@ -324,4 +372,11 @@ LPGEMM_5LOOP(uint8_t,int8_t,int32_t,u8s8s32o32) bli_membrk_release( rntm, &mem_a ); } } + if ( c_downscale == TRUE ) + { + if ( bli_mem_is_alloc( &mem_scale_c ) ) + { + bli_membrk_release( rntm, &mem_scale_c ); + } + } } diff --git a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_6x64rowmajor_bf16_amd512vnni.c b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_6x64rowmajor_bf16_amd512vnni.c index 872aa9a747..81d15c4e04 100644 --- a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_6x64rowmajor_bf16_amd512vnni.c +++ b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_6x64rowmajor_bf16_amd512vnni.c @@ -88,7 +88,7 @@ LPGEMM_MAIN_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x64) alpha, beta, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list + post_ops_list, rs_c_downscale ); b = b + ( 48 * k0_updated ); // k0x48 packed contiguosly. @@ -107,7 +107,7 @@ LPGEMM_MAIN_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x64) alpha, beta, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list + post_ops_list, rs_c_downscale ); b = b + ( 32 * k0_updated ); // k0x32 packed contiguosly. @@ -126,7 +126,7 @@ LPGEMM_MAIN_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x64) alpha, beta, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list + post_ops_list, rs_c_downscale ); b = b + ( 16 * k0_updated ); // k0x16 packed contiguosly. @@ -145,7 +145,7 @@ LPGEMM_MAIN_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x64) alpha, beta, n0_rem, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list + post_ops_list, rs_c_downscale ); // No leftover fringe after this podint. 
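With the scratch panel in place, every kernel now receives two row strides: rs_c_use for the buffer it accumulates into and rs_c_downscale for the caller's original C which, together with the C pointer carried in the post-op node, appears to be where the downscaled output is finally written. A minimal scalar model of that final store, under those assumptions and with hypothetical names, is sketched below; the real kernels do this with the AVX-512 CVT_MULRND_CVT32_CVT8 macros.

#include <math.h>
#include <stddef.h>
#include <stdint.h>

static void store_downscaled_tile
     (
       const int32_t* c_acc,          /* int32 accumulation panel (temp buffer) */
       size_t         rs_c_use,       /* its row stride, i.e. the panel width   */
       int8_t*        c_out,          /* caller's original int8 C matrix        */
       size_t         rs_c_downscale, /* row stride of the caller's C           */
       size_t         m0,             /* rows in this tile                      */
       size_t         n0,             /* columns in this tile                   */
       const float*   scale           /* per-column scale factors               */
     )
{
    for ( size_t i = 0; i < m0; ++i )
    {
        for ( size_t j = 0; j < n0; ++j )
        {
            float scaled = ( float )c_acc[ ( rs_c_use * i ) + j ] * scale[ j ];
            long  r      = lrintf( scaled );  /* round to nearest     */
            if ( r > 127 )  { r = 127; }      /* saturate to int8     */
            if ( r < -128 ) { r = -128; }
            c_out[ ( rs_c_downscale * i ) + j ] = ( int8_t )r;
        }
    }
}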
@@ -893,7 +893,7 @@ LPGEMM_MAIN_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x64) alpha, beta, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list + post_ops_list, rs_c_downscale ); } else if ( m_partial_pieces == 4 ) @@ -908,7 +908,7 @@ LPGEMM_MAIN_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x64) alpha, beta, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list + post_ops_list, rs_c_downscale ); } else if ( m_partial_pieces == 3 ) @@ -923,7 +923,7 @@ LPGEMM_MAIN_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x64) alpha, beta, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list + post_ops_list, rs_c_downscale ); } else if ( m_partial_pieces == 2 ) @@ -938,7 +938,7 @@ LPGEMM_MAIN_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x64) alpha, beta, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list + post_ops_list, rs_c_downscale ); } else if ( m_partial_pieces == 1 ) @@ -953,7 +953,7 @@ LPGEMM_MAIN_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x64) alpha, beta, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list + post_ops_list, rs_c_downscale ); } } diff --git a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_n_fringe_bf16_amd512vnni.c b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_n_fringe_bf16_amd512vnni.c index f796a7d0ea..1a6a238d9d 100644 --- a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_n_fringe_bf16_amd512vnni.c +++ b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_n_fringe_bf16_amd512vnni.c @@ -420,7 +420,7 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6xlt16) alpha, beta, n0_rem, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list + post_ops_list, rs_c_downscale ); } else if ( m_partial_pieces == 4 ) @@ -435,7 +435,7 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6xlt16) alpha, beta, n0_rem, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list + post_ops_list, rs_c_downscale ); } else if ( m_partial_pieces == 3 ) @@ -450,7 +450,7 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6xlt16) alpha, beta, n0_rem, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list + post_ops_list, rs_c_downscale ); } else if ( m_partial_pieces == 2 ) @@ -465,7 +465,7 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6xlt16) alpha, beta, n0_rem, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list + post_ops_list, rs_c_downscale ); } else if ( m_partial_pieces == 1 ) @@ -480,7 +480,7 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6xlt16) alpha, beta, n0_rem, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list + post_ops_list, rs_c_downscale ); } } @@ -834,7 +834,7 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x16) alpha, beta, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list + post_ops_list, rs_c_downscale ); } else if ( m_partial_pieces == 4 ) @@ -849,7 +849,7 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x16) alpha, beta, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list + post_ops_list, rs_c_downscale ); } else if ( m_partial_pieces == 3 ) @@ -864,7 +864,7 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x16) alpha, beta, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list + post_ops_list, rs_c_downscale ); } else if ( m_partial_pieces == 2 ) @@ -879,7 +879,7 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x16) alpha, beta, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list + post_ops_list, rs_c_downscale ); } else if ( m_partial_pieces == 1 ) @@ -894,7 +894,7 @@ 
LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x16) alpha, beta, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list + post_ops_list, rs_c_downscale ); } } @@ -1378,7 +1378,7 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x32) alpha, beta, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list + post_ops_list, rs_c_downscale ); } else if ( m_partial_pieces == 4 ) @@ -1393,7 +1393,7 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x32) alpha, beta, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list + post_ops_list, rs_c_downscale ); } else if ( m_partial_pieces == 3 ) @@ -1408,7 +1408,7 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x32) alpha, beta, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list + post_ops_list, rs_c_downscale ); } else if ( m_partial_pieces == 2 ) @@ -1423,7 +1423,7 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x32) alpha, beta, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list + post_ops_list, rs_c_downscale ); } else if ( m_partial_pieces == 1 ) @@ -1438,7 +1438,7 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x32) alpha, beta, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list + post_ops_list, rs_c_downscale ); } } @@ -2059,7 +2059,7 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x48) alpha, beta, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list + post_ops_list, rs_c_downscale ); } else if ( m_partial_pieces == 4 ) @@ -2074,7 +2074,7 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x48) alpha, beta, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list + post_ops_list, rs_c_downscale ); } else if ( m_partial_pieces == 3 ) @@ -2089,7 +2089,7 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x48) alpha, beta, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list + post_ops_list, rs_c_downscale ); } else if ( m_partial_pieces == 2 ) @@ -2104,7 +2104,7 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x48) alpha, beta, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list + post_ops_list, rs_c_downscale ); } else if ( m_partial_pieces == 1 ) @@ -2119,7 +2119,7 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x48) alpha, beta, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list + post_ops_list, rs_c_downscale ); } } diff --git a/addon/aocl_gemm/kernels/lpgemm_kernels.h b/addon/aocl_gemm/kernels/lpgemm_kernels.h index 3e79a0ce58..7b73ba27e9 100644 --- a/addon/aocl_gemm/kernels/lpgemm_kernels.h +++ b/addon/aocl_gemm/kernels/lpgemm_kernels.h @@ -59,7 +59,8 @@ void lpgemm_rowvar_ ## LP_SFX \ bool is_last_k, \ dim_t post_op_c_i, \ dim_t post_op_c_j, \ - lpgemm_post_op* post_ops_list \ + lpgemm_post_op* post_ops_list, \ + const dim_t rs_c_downscale \ ) \ LPGEMM_MAIN_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x64); @@ -83,7 +84,8 @@ void lpgemm_rowvar_ ## LP_SFX \ bool is_last_k, \ dim_t post_op_c_i, \ dim_t post_op_c_j, \ - lpgemm_post_op* post_ops_list \ + lpgemm_post_op* post_ops_list, \ + const dim_t rs_c_downscale \ ) \ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5x64); @@ -121,7 +123,8 @@ void lpgemm_rowvar_ ## LP_SFX \ bool is_last_k, \ dim_t post_op_c_i, \ dim_t post_op_c_j, \ - lpgemm_post_op* post_ops_list \ + lpgemm_post_op* post_ops_list, \ + const dim_t rs_c_downscale \ ) \ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x16); @@ -154,7 +157,8 @@ void lpgemm_rowvar_ ## LP_SFX \ bool is_last_k, \ dim_t post_op_c_i, \ dim_t 
post_op_c_j, \ - lpgemm_post_op* post_ops_list \ + lpgemm_post_op* post_ops_list, \ + const dim_t rs_c_downscale \ ) \ LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6xlt16); @@ -180,7 +184,8 @@ void lpgemm_rowvar_ ## LP_SFX \ bool is_last_k, \ dim_t post_op_c_i, \ dim_t post_op_c_j, \ - lpgemm_post_op* post_ops_list \ + lpgemm_post_op* post_ops_list, \ + const dim_t rs_c_downscale \ ) \ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5x16); @@ -237,7 +242,8 @@ void lpgemm_rowvar_ ## LP_SFX \ bool is_last_k, \ dim_t post_op_c_i, \ dim_t post_op_c_j, \ - lpgemm_post_op* post_ops_list \ + lpgemm_post_op* post_ops_list, \ + const dim_t rs_c_downscale \ ) \ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5xlt16); diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_6x32rowmajor_amd256.c b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_6x32rowmajor_amd256.c index ad0d6e2a66..553ce1a134 100644 --- a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_6x32rowmajor_amd256.c +++ b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_6x32rowmajor_amd256.c @@ -80,7 +80,7 @@ LPGEMM_MAIN_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x32) alpha, beta, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list); + post_ops_list, rs_c_downscale); b = b + (16 * k0_updated); c = c + 16; @@ -97,7 +97,7 @@ LPGEMM_MAIN_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x32) alpha, beta, n0_rem, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list); + post_ops_list, rs_c_downscale); } // If fringe cases are encountered, return early @@ -627,7 +627,7 @@ LPGEMM_MAIN_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x32) alpha, beta, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list); + post_ops_list, rs_c_downscale); // a pointer increment a = a + (4 * ps_a); @@ -644,7 +644,7 @@ LPGEMM_MAIN_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x32) alpha, beta, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list); + post_ops_list, rs_c_downscale); // a pointer increment a = a + (2 * ps_a); @@ -660,7 +660,7 @@ LPGEMM_MAIN_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x32) (c + (rs_c * m_full_pieces_loop_limit)), rs_c, alpha, beta,is_last_k, post_op_c_i, post_op_c_j, - post_ops_list); + post_ops_list, rs_c_downscale); } } } diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_n_fringe_amd256.c b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_n_fringe_amd256.c index 631e01f27f..171e730e0a 100644 --- a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_n_fringe_amd256.c +++ b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_n_fringe_amd256.c @@ -390,7 +390,7 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x16) alpha, beta, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list); + post_ops_list, rs_c_downscale); // a pointer increment a = a + (4 * ps_a); @@ -407,7 +407,7 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x16) alpha, beta, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list); + post_ops_list, rs_c_downscale); // a pointer increment a = a + (2 * ps_a); @@ -424,7 +424,7 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x16) alpha, beta, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list); + post_ops_list, rs_c_downscale); } } } @@ -813,7 +813,7 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6xlt16) alpha, beta, n0_rem, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list); + post_ops_list, rs_c_downscale); // a pointer increment a = a + (4 * ps_a); @@ -830,7 +830,7 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6xlt16) alpha, beta, n0_rem, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list); + post_ops_list, 
rs_c_downscale); // a pointer increment a = a + (2 * ps_a); @@ -847,7 +847,7 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6xlt16) alpha, beta, n0_rem, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list); + post_ops_list, rs_c_downscale); } } } diff --git a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_6x64rowmajor_amd512vnni.c b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_6x64rowmajor_amd512vnni.c index 64c61a9c57..ea5c8bf280 100644 --- a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_6x64rowmajor_amd512vnni.c +++ b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_6x64rowmajor_amd512vnni.c @@ -46,7 +46,8 @@ LPGEMM_MAIN_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x64) &&POST_OPS_6x64_DISABLE, &&POST_OPS_BIAS_6x64, &&POST_OPS_RELU_6x64, - &&POST_OPS_RELU_SCALE_6x64 + &&POST_OPS_RELU_SCALE_6x64, + &&POST_OPS_DOWNSCALE_6x64 }; dim_t MR = 6; @@ -92,7 +93,7 @@ LPGEMM_MAIN_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x64) alpha, beta, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list + post_ops_list, rs_c_downscale ); b = b + ( 48 * k0_updated ); // k0x48 packed contiguosly. @@ -110,7 +111,7 @@ LPGEMM_MAIN_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x64) alpha, beta, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list + post_ops_list, rs_c_downscale ); b = b + ( 32 * k0_updated ); // k0x32 packed contiguosly. @@ -128,7 +129,7 @@ LPGEMM_MAIN_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x64) alpha, beta, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list + post_ops_list, rs_c_downscale ); b = b + ( 16 * k0_updated ); // k0x16 packed contiguosly. @@ -147,7 +148,7 @@ LPGEMM_MAIN_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x64) alpha, beta, n0_rem, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list + post_ops_list, rs_c_downscale ); // No leftover fringe after this point. @@ -800,6 +801,95 @@ LPGEMM_MAIN_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x64) // c[5, 48-63] RELU_SCALE_OP_S32_AVX512(c_int32_5p3) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_6x64: + { + selector1 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 2 * 16 ) ); + a_int32_1 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 3 * 16 ) ); + + // c[0, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); + + // c[0, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_0p1,selector2,0,1); + + // c[0, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_0p2,a_int32_0,0,2); + + // c[0, 48-63] + CVT_MULRND_CVT32_CVT8(c_int32_0p3,a_int32_1,0,3); + + // c[1, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_1p0,selector1,1,0); + + // c[1, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_1p1,selector2,1,1); + + // c[1, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_1p2,a_int32_0,1,2); + + // c[1, 48-63] + CVT_MULRND_CVT32_CVT8(c_int32_1p3,a_int32_1,1,3); + + // c[2, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_2p0,selector1,2,0); + + // c[2, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_2p1,selector2,2,1); + + // c[2, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_2p2,a_int32_0,2,2); + + // c[2, 48-63] + CVT_MULRND_CVT32_CVT8(c_int32_2p3,a_int32_1,2,3); + + // c[3, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_3p0,selector1,3,0); + + // c[3, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_3p1,selector2,3,1); + + // c[3, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_3p2,a_int32_0,3,2); + + // c[3, 48-63] + 
CVT_MULRND_CVT32_CVT8(c_int32_3p3,a_int32_1,3,3); + + // c[4, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_4p0,selector1,4,0); + + // c[4, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_4p1,selector2,4,1); + + // c[4, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_4p2,a_int32_0,4,2); + + // c[4, 48-63] + CVT_MULRND_CVT32_CVT8(c_int32_4p3,a_int32_1,4,3); + + // c[5, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_5p0,selector1,5,0); + + // c[5, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_5p1,selector2,5,1); + + // c[5, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_5p2,a_int32_0,5,2); + + // c[5, 48-63] + CVT_MULRND_CVT32_CVT8(c_int32_5p3,a_int32_1,5,3); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_6x64_DISABLE: @@ -902,7 +992,7 @@ LPGEMM_MAIN_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x64) alpha, beta, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list + post_ops_list, rs_c_downscale ); } else if ( m_partial_pieces == 4 ) @@ -917,7 +1007,7 @@ LPGEMM_MAIN_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x64) alpha, beta, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list + post_ops_list, rs_c_downscale ); } else if ( m_partial_pieces == 3 ) @@ -932,7 +1022,7 @@ LPGEMM_MAIN_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x64) alpha, beta, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list + post_ops_list, rs_c_downscale ); } else if ( m_partial_pieces == 2 ) @@ -947,7 +1037,7 @@ LPGEMM_MAIN_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x64) alpha, beta, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list + post_ops_list, rs_c_downscale ); } else if ( m_partial_pieces == 1 ) @@ -962,7 +1052,7 @@ LPGEMM_MAIN_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x64) alpha, beta, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list + post_ops_list, rs_c_downscale ); } } diff --git a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_m_fringe_amd512vnni.c b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_m_fringe_amd512vnni.c index 896aa23d60..e6079cce23 100644 --- a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_m_fringe_amd512vnni.c +++ b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_m_fringe_amd512vnni.c @@ -47,7 +47,8 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5x64) &&POST_OPS_5x64_DISABLE, &&POST_OPS_BIAS_5x64, &&POST_OPS_RELU_5x64, - &&POST_OPS_RELU_SCALE_5x64 + &&POST_OPS_RELU_SCALE_5x64, + &&POST_OPS_DOWNSCALE_5x64 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -594,6 +595,83 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5x64) // c[4, 48-63] RELU_SCALE_OP_S32_AVX512(c_int32_4p3) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_5x64: + { + selector1 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 2 * 16 ) ); + a_int32_1 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 3 * 16 ) ); + + // c[0, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); + + // c[0, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_0p1,selector2,0,1); + + // c[0, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_0p2,a_int32_0,0,2); + + // c[0, 48-63] + CVT_MULRND_CVT32_CVT8(c_int32_0p3,a_int32_1,0,3); + + // c[1, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_1p0,selector1,1,0); + + // c[1, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_1p1,selector2,1,1); + + // c[1, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_1p2,a_int32_0,1,2); + + // c[1, 48-63] + 
CVT_MULRND_CVT32_CVT8(c_int32_1p3,a_int32_1,1,3); + + // c[2, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_2p0,selector1,2,0); + + // c[2, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_2p1,selector2,2,1); + + // c[2, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_2p2,a_int32_0,2,2); + + // c[2, 48-63] + CVT_MULRND_CVT32_CVT8(c_int32_2p3,a_int32_1,2,3); + + // c[3, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_3p0,selector1,3,0); + + // c[3, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_3p1,selector2,3,1); + + // c[3, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_3p2,a_int32_0,3,2); + + // c[3, 48-63] + CVT_MULRND_CVT32_CVT8(c_int32_3p3,a_int32_1,3,3); + + // c[4, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_4p0,selector1,4,0); + + // c[4, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_4p1,selector2,4,1); + + // c[4, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_4p2,a_int32_0,4,2); + + // c[4, 48-63] + CVT_MULRND_CVT32_CVT8(c_int32_4p3,a_int32_1,4,3); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_5x64_DISABLE: @@ -669,7 +747,8 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4x64) &&POST_OPS_4x64_DISABLE, &&POST_OPS_BIAS_4x64, &&POST_OPS_RELU_4x64, - &&POST_OPS_RELU_SCALE_4x64 + &&POST_OPS_RELU_SCALE_4x64, + &&POST_OPS_DOWNSCALE_4x64 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -1122,6 +1201,71 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4x64) // c[3, 48-63] RELU_SCALE_OP_S32_AVX512(c_int32_3p3) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_4x64: + { + selector1 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 2 * 16 ) ); + a_int32_1 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 3 * 16 ) ); + + // c[0, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); + + // c[0, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_0p1,selector2,0,1); + + // c[0, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_0p2,a_int32_0,0,2); + + // c[0, 48-63] + CVT_MULRND_CVT32_CVT8(c_int32_0p3,a_int32_1,0,3); + + // c[1, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_1p0,selector1,1,0); + + // c[1, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_1p1,selector2,1,1); + + // c[1, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_1p2,a_int32_0,1,2); + + // c[1, 48-63] + CVT_MULRND_CVT32_CVT8(c_int32_1p3,a_int32_1,1,3); + + // c[2, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_2p0,selector1,2,0); + + // c[2, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_2p1,selector2,2,1); + + // c[2, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_2p2,a_int32_0,2,2); + + // c[2, 48-63] + CVT_MULRND_CVT32_CVT8(c_int32_2p3,a_int32_1,2,3); + + // c[3, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_3p0,selector1,3,0); + + // c[3, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_3p1,selector2,3,1); + + // c[3, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_3p2,a_int32_0,3,2); + + // c[3, 48-63] + CVT_MULRND_CVT32_CVT8(c_int32_3p3,a_int32_1,3,3); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_4x64_DISABLE: @@ -1185,7 +1329,8 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3x64) &&POST_OPS_3x64_DISABLE, &&POST_OPS_BIAS_3x64, &&POST_OPS_RELU_3x64, - &&POST_OPS_RELU_SCALE_3x64 + &&POST_OPS_RELU_SCALE_3x64, + &&POST_OPS_DOWNSCALE_3x64 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -1544,6 +1689,59 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3x64) // c[2, 48-63] 
RELU_SCALE_OP_S32_AVX512(c_int32_2p3) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_3x64: + { + selector1 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 2 * 16 ) ); + a_int32_1 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 3 * 16 ) ); + + // c[0, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); + + // c[0, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_0p1,selector2,0,1); + + // c[0, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_0p2,a_int32_0,0,2); + + // c[0, 48-63] + CVT_MULRND_CVT32_CVT8(c_int32_0p3,a_int32_1,0,3); + + // c[1, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_1p0,selector1,1,0); + + // c[1, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_1p1,selector2,1,1); + + // c[1, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_1p2,a_int32_0,1,2); + + // c[1, 48-63] + CVT_MULRND_CVT32_CVT8(c_int32_1p3,a_int32_1,1,3); + + // c[2, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_2p0,selector1,2,0); + + // c[2, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_2p1,selector2,2,1); + + // c[2, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_2p2,a_int32_0,2,2); + + // c[2, 48-63] + CVT_MULRND_CVT32_CVT8(c_int32_2p3,a_int32_1,2,3); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_3x64_DISABLE: @@ -1595,7 +1793,8 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2x64) &&POST_OPS_2x64_DISABLE, &&POST_OPS_BIAS_2x64, &&POST_OPS_RELU_2x64, - &&POST_OPS_RELU_SCALE_2x64 + &&POST_OPS_RELU_SCALE_2x64, + &&POST_OPS_DOWNSCALE_2x64 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -1860,6 +2059,47 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2x64) // c[1, 48-63] RELU_SCALE_OP_S32_AVX512(c_int32_1p3) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_2x64: + { + selector1 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 2 * 16 ) ); + a_int32_1 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 3 * 16 ) ); + + // c[0, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); + + // c[0, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_0p1,selector2,0,1); + + // c[0, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_0p2,a_int32_0,0,2); + + // c[0, 48-63] + CVT_MULRND_CVT32_CVT8(c_int32_0p3,a_int32_1,0,3); + + // c[1, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_1p0,selector1,1,0); + + // c[1, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_1p1,selector2,1,1); + + // c[1, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_1p2,a_int32_0,1,2); + + // c[1, 48-63] + CVT_MULRND_CVT32_CVT8(c_int32_1p3,a_int32_1,1,3); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_2x64_DISABLE: @@ -1899,7 +2139,8 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1x64) &&POST_OPS_1x64_DISABLE, &&POST_OPS_BIAS_1x64, &&POST_OPS_RELU_1x64, - &&POST_OPS_RELU_SCALE_1x64 + &&POST_OPS_RELU_SCALE_1x64, + &&POST_OPS_DOWNSCALE_1x64 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -2070,6 +2311,35 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1x64) // c[0, 48-63] RELU_SCALE_OP_S32_AVX512(c_int32_0p3) + 
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_1x64: + { + selector1 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 2 * 16 ) ); + a_int32_1 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 3 * 16 ) ); + + // c[0, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); + + // c[0, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_0p1,selector2,0,1); + + // c[0, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_0p2,a_int32_0,0,2); + + // c[0, 48-63] + CVT_MULRND_CVT32_CVT8(c_int32_0p3,a_int32_1,0,3); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_1x64_DISABLE: diff --git a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_mn_fringe_amd512vnni.c b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_mn_fringe_amd512vnni.c index 07c2d8fef9..3900dcb82b 100644 --- a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_mn_fringe_amd512vnni.c +++ b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_mn_fringe_amd512vnni.c @@ -47,7 +47,8 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5xlt16) &&POST_OPS_5xLT16_DISABLE, &&POST_OPS_BIAS_5xLT16, &&POST_OPS_RELU_5xLT16, - &&POST_OPS_RELU_SCALE_5xLT16 + &&POST_OPS_RELU_SCALE_5xLT16, + &&POST_OPS_DOWNSCALE_5xLT16 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -280,6 +281,29 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5xlt16) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_DOWNSCALE_5xLT16: + { + memcpy( buf0, ( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j ), ( n0_rem * sizeof( float ) ) ); + selector1 = _mm512_loadu_epi32( buf0 ); + + // c[0, 0-15] + CVT_MULRND_CVT32_CVT8_LT16(c_int32_0p0,selector1,0,0); + + // c[1, 0-15] + CVT_MULRND_CVT32_CVT8_LT16(c_int32_1p0,selector1,1,0); + + // c[2, 0-15] + CVT_MULRND_CVT32_CVT8_LT16(c_int32_2p0,selector1,2,0); + + // c[3, 0-15] + CVT_MULRND_CVT32_CVT8_LT16(c_int32_3p0,selector1,3,0); + + // c[4, 0-15] + CVT_MULRND_CVT32_CVT8_LT16(c_int32_4p0,selector1,4,0); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_5xLT16_DISABLE: ; @@ -325,7 +349,8 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4xlt16) &&POST_OPS_4xLT16_DISABLE, &&POST_OPS_BIAS_4xLT16, &&POST_OPS_RELU_4xLT16, - &&POST_OPS_RELU_SCALE_4xLT16 + &&POST_OPS_RELU_SCALE_4xLT16, + &&POST_OPS_DOWNSCALE_4xLT16 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -524,6 +549,26 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4xlt16) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_DOWNSCALE_4xLT16: + { + memcpy( buf0, ( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j ), ( n0_rem * sizeof( float ) ) ); + selector1 = _mm512_loadu_epi32( buf0 ); + + // c[0, 0-15] + CVT_MULRND_CVT32_CVT8_LT16(c_int32_0p0,selector1,0,0); + + // c[1, 0-15] + CVT_MULRND_CVT32_CVT8_LT16(c_int32_1p0,selector1,1,0); + + // c[2, 0-15] + CVT_MULRND_CVT32_CVT8_LT16(c_int32_2p0,selector1,2,0); + + // c[3, 0-15] + CVT_MULRND_CVT32_CVT8_LT16(c_int32_3p0,selector1,3,0); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_4xLT16_DISABLE: ; @@ -563,7 +608,8 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3xlt16) &&POST_OPS_3xLT16_DISABLE, &&POST_OPS_BIAS_3xLT16, &&POST_OPS_RELU_3xLT16, - &&POST_OPS_RELU_SCALE_3xLT16 + &&POST_OPS_RELU_SCALE_3xLT16, + 
&&POST_OPS_DOWNSCALE_3xLT16 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -726,6 +772,23 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3xlt16) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_DOWNSCALE_3xLT16: + { + memcpy( buf0, ( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j ), ( n0_rem * sizeof( float ) ) ); + selector1 = _mm512_loadu_epi32( buf0 ); + + // c[0, 0-15] + CVT_MULRND_CVT32_CVT8_LT16(c_int32_0p0,selector1,0,0); + + // c[1, 0-15] + CVT_MULRND_CVT32_CVT8_LT16(c_int32_1p0,selector1,1,0); + + // c[2, 0-15] + CVT_MULRND_CVT32_CVT8_LT16(c_int32_2p0,selector1,2,0); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_3xLT16_DISABLE: ; @@ -759,7 +822,8 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2xlt16) &&POST_OPS_2xLT16_DISABLE, &&POST_OPS_BIAS_2xLT16, &&POST_OPS_RELU_2xLT16, - &&POST_OPS_RELU_SCALE_2xLT16 + &&POST_OPS_RELU_SCALE_2xLT16, + &&POST_OPS_DOWNSCALE_2xLT16 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -887,6 +951,20 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2xlt16) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_DOWNSCALE_2xLT16: + { + memcpy( buf0, ( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j ), ( n0_rem * sizeof( float ) ) ); + selector1 = _mm512_loadu_epi32( buf0 ); + + // c[0, 0-15] + CVT_MULRND_CVT32_CVT8_LT16(c_int32_0p0,selector1,0,0); + + // c[1, 0-15] + CVT_MULRND_CVT32_CVT8_LT16(c_int32_1p0,selector1,1,0); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_2xLT16_DISABLE: ; @@ -914,7 +992,8 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1xlt16) &&POST_OPS_1xLT16_DISABLE, &&POST_OPS_BIAS_1xLT16, &&POST_OPS_RELU_1xLT16, - &&POST_OPS_RELU_SCALE_1xLT16 + &&POST_OPS_RELU_SCALE_1xLT16, + &&POST_OPS_DOWNSCALE_1xLT16 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -1007,6 +1086,17 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1xlt16) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_DOWNSCALE_1xLT16: + { + memcpy( buf0, ( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j ), ( n0_rem * sizeof( float ) ) ); + selector1 = _mm512_loadu_epi32( buf0 ); + + // c[0, 0-15] + CVT_MULRND_CVT32_CVT8_LT16(c_int32_0p0,selector1,0,0); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_1xLT16_DISABLE: ; @@ -1028,7 +1118,8 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5x16) &&POST_OPS_5x16_DISABLE, &&POST_OPS_BIAS_5x16, &&POST_OPS_RELU_5x16, - &&POST_OPS_RELU_SCALE_5x16 + &&POST_OPS_RELU_SCALE_5x16, + &&POST_OPS_DOWNSCALE_5x16 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -1245,6 +1336,29 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5x16) // c[4, 0-15] RELU_SCALE_OP_S32_AVX512(c_int32_4p0) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_5x16: + { + selector1 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 0 * 16 ) ); + + // c[0, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); + + // c[1, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_1p0,selector1,1,0); + + // c[2, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_2p0,selector1,2,0); + + // c[3, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_3p0,selector1,3,0); + + // c[4, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_4p0,selector1,4,0); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_5x16_DISABLE: @@ -1275,7 +1389,8 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4x16) 
&&POST_OPS_4x16_DISABLE, &&POST_OPS_BIAS_4x16, &&POST_OPS_RELU_4x16, - &&POST_OPS_RELU_SCALE_4x16 + &&POST_OPS_RELU_SCALE_4x16, + &&POST_OPS_DOWNSCALE_4x16 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -1459,6 +1574,26 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4x16) // c[3, 0-15] RELU_SCALE_OP_S32_AVX512(c_int32_3p0) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_4x16: + { + selector1 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 0 * 16 ) ); + + // c[0, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); + + // c[1, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_1p0,selector1,1,0); + + // c[2, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_2p0,selector1,2,0); + + // c[3, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_3p0,selector1,3,0); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_4x16_DISABLE: @@ -1486,7 +1621,8 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3x16) &&POST_OPS_3x16_DISABLE, &&POST_OPS_BIAS_3x16, &&POST_OPS_RELU_3x16, - &&POST_OPS_RELU_SCALE_3x16 + &&POST_OPS_RELU_SCALE_3x16, + &&POST_OPS_DOWNSCALE_3x16 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -1637,6 +1773,23 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3x16) // c[2, 0-15] RELU_SCALE_OP_S32_AVX512(c_int32_2p0) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_3x16: + { + selector1 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 0 * 16 ) ); + + // c[0, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); + + // c[1, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_1p0,selector1,1,0); + + // c[2, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_2p0,selector1,2,0); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_3x16_DISABLE: @@ -1661,7 +1814,8 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2x16) &&POST_OPS_2x16_DISABLE, &&POST_OPS_BIAS_2x16, &&POST_OPS_RELU_2x16, - &&POST_OPS_RELU_SCALE_2x16 + &&POST_OPS_RELU_SCALE_2x16, + &&POST_OPS_DOWNSCALE_2x16 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -1779,6 +1933,20 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2x16) // c[1, 0-15] RELU_SCALE_OP_S32_AVX512(c_int32_1p0) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_2x16: + { + selector1 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 0 * 16 ) ); + + // c[0, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); + + // c[1, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_1p0,selector1,1,0); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_2x16_DISABLE: @@ -1800,7 +1968,8 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1x16) &&POST_OPS_1x16_DISABLE, &&POST_OPS_BIAS_1x16, &&POST_OPS_RELU_1x16, - &&POST_OPS_RELU_SCALE_1x16 + &&POST_OPS_RELU_SCALE_1x16, + &&POST_OPS_DOWNSCALE_1x16 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -1885,6 +2054,17 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1x16) // c[0, 0-15] RELU_SCALE_OP_S32_AVX512(c_int32_0p0) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_1x16: + { + selector1 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 0 * 16 ) ); + + // c[0, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_1x16_DISABLE: @@ -1903,7 +2083,8 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5x32) &&POST_OPS_5x32_DISABLE, 
&&POST_OPS_BIAS_5x32, &&POST_OPS_RELU_5x32, - &&POST_OPS_RELU_SCALE_5x32 + &&POST_OPS_RELU_SCALE_5x32, + &&POST_OPS_DOWNSCALE_5x32 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -2222,6 +2403,47 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5x32) // c[4, 16-31] RELU_SCALE_OP_S32_AVX512(c_int32_4p1) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_5x32: + { + selector1 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 1 * 16 ) ); + + // c[0, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); + + // c[0, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_0p1,selector2,0,1); + + // c[1, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_1p0,selector1,1,0); + + // c[1, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_1p1,selector2,1,1); + + // c[2, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_2p0,selector1,2,0); + + // c[2, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_2p1,selector2,2,1); + + // c[3, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_3p0,selector1,3,0); + + // c[3, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_3p1,selector2,3,1); + + // c[4, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_4p0,selector1,4,0); + + // c[4, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_4p1,selector2,4,1); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_5x32_DISABLE: @@ -2267,7 +2489,8 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4x32) &&POST_OPS_4x32_DISABLE, &&POST_OPS_BIAS_4x32, &&POST_OPS_RELU_4x32, - &&POST_OPS_RELU_SCALE_4x32 + &&POST_OPS_RELU_SCALE_4x32, + &&POST_OPS_DOWNSCALE_4x32 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -2535,6 +2758,41 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4x32) // c[3, 16-31] RELU_SCALE_OP_S32_AVX512(c_int32_3p1) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_4x32: + { + selector1 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 1 * 16 ) ); + + // c[0, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); + + // c[0, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_0p1,selector2,0,1); + + // c[1, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_1p0,selector1,1,0); + + // c[1, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_1p1,selector2,1,1); + + // c[2, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_2p0,selector1,2,0); + + // c[2, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_2p1,selector2,2,1); + + // c[3, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_3p0,selector1,3,0); + + // c[3, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_3p1,selector2,3,1); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_4x32_DISABLE: @@ -2574,7 +2832,8 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3x32) &&POST_OPS_3x32_DISABLE, &&POST_OPS_BIAS_3x32, &&POST_OPS_RELU_3x32, - &&POST_OPS_RELU_SCALE_3x32 + &&POST_OPS_RELU_SCALE_3x32, + &&POST_OPS_DOWNSCALE_3x32 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -2791,6 +3050,35 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3x32) // c[2, 16-31] RELU_SCALE_OP_S32_AVX512(c_int32_2p1) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_3x32: + { + selector1 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 1 * 16 ) ); + + // c[0, 
0-15] + CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); + + // c[0, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_0p1,selector2,0,1); + + // c[1, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_1p0,selector1,1,0); + + // c[1, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_1p1,selector2,1,1); + + // c[2, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_2p0,selector1,2,0); + + // c[2, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_2p1,selector2,2,1); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_3x32_DISABLE: @@ -2824,7 +3112,8 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2x32) &&POST_OPS_2x32_DISABLE, &&POST_OPS_BIAS_2x32, &&POST_OPS_RELU_2x32, - &&POST_OPS_RELU_SCALE_2x32 + &&POST_OPS_RELU_SCALE_2x32, + &&POST_OPS_DOWNSCALE_2x32 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -2990,6 +3279,29 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2x32) // c[1, 16-31] RELU_SCALE_OP_S32_AVX512(c_int32_1p1) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_2x32: + { + selector1 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 1 * 16 ) ); + + // c[0, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); + + // c[0, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_0p1,selector2,0,1); + + // c[1, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_1p0,selector1,1,0); + + // c[1, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_1p1,selector2,1,1); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_2x32_DISABLE: @@ -3017,7 +3329,8 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1x32) &&POST_OPS_1x32_DISABLE, &&POST_OPS_BIAS_1x32, &&POST_OPS_RELU_1x32, - &&POST_OPS_RELU_SCALE_1x32 + &&POST_OPS_RELU_SCALE_1x32, + &&POST_OPS_DOWNSCALE_1x32 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -3132,6 +3445,23 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1x32) // c[0, 16-31] RELU_SCALE_OP_S32_AVX512(c_int32_0p1) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_1x32: + { + selector1 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 1 * 16 ) ); + + // c[0, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); + + // c[0, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_0p1,selector2,0,1); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_1x32_DISABLE: @@ -3153,7 +3483,8 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5x48) &&POST_OPS_5x48_DISABLE, &&POST_OPS_BIAS_5x48, &&POST_OPS_RELU_5x48, - &&POST_OPS_RELU_SCALE_5x48 + &&POST_OPS_RELU_SCALE_5x48, + &&POST_OPS_DOWNSCALE_5x48 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -3568,6 +3899,65 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5x48) // c[4, 32-47] RELU_SCALE_OP_S32_AVX512(c_int32_4p2) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_5x48: + { + selector1 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 2 * 16 ) ); + + // c[0, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); + + // c[0, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_0p1,selector2,0,1); + + // c[0, 
32-47] + CVT_MULRND_CVT32_CVT8(c_int32_0p2,a_int32_0,0,2); + + // c[1, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_1p0,selector1,1,0); + + // c[1, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_1p1,selector2,1,1); + + // c[1, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_1p2,a_int32_0,1,2); + + // c[2, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_2p0,selector1,2,0); + + // c[2, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_2p1,selector2,2,1); + + // c[2, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_2p2,a_int32_0,2,2); + + // c[3, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_3p0,selector1,3,0); + + // c[3, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_3p1,selector2,3,1); + + // c[3, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_3p2,a_int32_0,3,2); + + // c[4, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_4p0,selector1,4,0); + + // c[4, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_4p1,selector2,4,1); + + // c[4, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_4p2,a_int32_0,4,2); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_5x48_DISABLE: @@ -3628,7 +4018,8 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4x48) &&POST_OPS_4x48_DISABLE, &&POST_OPS_BIAS_4x48, &&POST_OPS_RELU_4x48, - &&POST_OPS_RELU_SCALE_4x48 + &&POST_OPS_RELU_SCALE_4x48, + &&POST_OPS_DOWNSCALE_4x48 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -3974,6 +4365,56 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_4x48) // c[3, 32-47] RELU_SCALE_OP_S32_AVX512(c_int32_3p2) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_4x48: + { + selector1 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 2 * 16 ) ); + + // c[0, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); + + // c[0, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_0p1,selector2,0,1); + + // c[0, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_0p2,a_int32_0,0,2); + + // c[1, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_1p0,selector1,1,0); + + // c[1, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_1p1,selector2,1,1); + + // c[1, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_1p2,a_int32_0,1,2); + + // c[2, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_2p0,selector1,2,0); + + // c[2, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_2p1,selector2,2,1); + + // c[2, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_2p2,a_int32_0,2,2); + + // c[3, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_3p0,selector1,3,0); + + // c[3, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_3p1,selector2,3,1); + + // c[3, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_3p2,a_int32_0,3,2); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_4x48_DISABLE: @@ -4025,7 +4466,8 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3x48) &&POST_OPS_3x48_DISABLE, &&POST_OPS_BIAS_3x48, &&POST_OPS_RELU_3x48, - &&POST_OPS_RELU_SCALE_3x48 + &&POST_OPS_RELU_SCALE_3x48, + &&POST_OPS_DOWNSCALE_3x48 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -4302,6 +4744,47 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_3x48) // c[2, 32-47] RELU_SCALE_OP_S32_AVX512(c_int32_2p2) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_3x48: + { + selector1 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_epi32( ( float* 
)post_ops_list_temp->scale_factor + + post_op_c_j + ( 2 * 16 ) ); + + // c[0, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); + + // c[0, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_0p1,selector2,0,1); + + // c[0, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_0p2,a_int32_0,0,2); + + // c[1, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_1p0,selector1,1,0); + + // c[1, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_1p1,selector2,1,1); + + // c[1, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_1p2,a_int32_0,1,2); + + // c[2, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_2p0,selector1,2,0); + + // c[2, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_2p1,selector2,2,1); + + // c[2, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_2p2,a_int32_0,2,2); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_3x48_DISABLE: @@ -4344,7 +4827,8 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2x48) &&POST_OPS_2x48_DISABLE, &&POST_OPS_BIAS_2x48, &&POST_OPS_RELU_2x48, - &&POST_OPS_RELU_SCALE_2x48 + &&POST_OPS_RELU_SCALE_2x48, + &&POST_OPS_DOWNSCALE_2x48 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -4552,6 +5036,38 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_2x48) // c[1, 32-47] RELU_SCALE_OP_S32_AVX512(c_int32_1p2) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_2x48: + { + selector1 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 2 * 16 ) ); + + // c[0, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); + + // c[0, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_0p1,selector2,0,1); + + // c[0, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_0p2,a_int32_0,0,2); + + // c[1, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_1p0,selector1,1,0); + + // c[1, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_1p1,selector2,1,1); + + // c[1, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_1p2,a_int32_0,1,2); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_2x48_DISABLE: @@ -4585,7 +5101,8 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1x48) &&POST_OPS_1x48_DISABLE, &&POST_OPS_BIAS_1x48, &&POST_OPS_RELU_1x48, - &&POST_OPS_RELU_SCALE_1x48 + &&POST_OPS_RELU_SCALE_1x48, + &&POST_OPS_DOWNSCALE_1x48 }; dim_t k_full_pieces = k0 / 4; dim_t k_partial_pieces = k0 % 4; @@ -4724,6 +5241,29 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1x48) // c[0, 32-47] RELU_SCALE_OP_S32_AVX512(c_int32_0p2) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_1x48: + { + selector1 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 2 * 16 ) ); + + // c[0, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); + + // c[0, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_0p1,selector2,0,1); + + // c[0, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_0p2,a_int32_0,0,2); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_1x48_DISABLE: diff --git a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_n_fringe_amd512vnni.c b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_n_fringe_amd512vnni.c index b710f2eb49..609a0641db 100644 --- a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_n_fringe_amd512vnni.c +++ 
b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_n_fringe_amd512vnni.c @@ -47,7 +47,8 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6xlt16) &&POST_OPS_6xLT16_DISABLE, &&POST_OPS_BIAS_6xLT16, &&POST_OPS_RELU_6xLT16, - &&POST_OPS_RELU_SCALE_6xLT16 + &&POST_OPS_RELU_SCALE_6xLT16, + &&POST_OPS_DOWNSCALE_6xLT16 }; dim_t MR = 6; dim_t m_full_pieces = m0 / MR; @@ -362,6 +363,32 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6xlt16) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_DOWNSCALE_6xLT16: + { + memcpy( buf0, ( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j ), ( n0_rem * sizeof( float ) ) ); + selector1 = _mm512_loadu_epi32( buf0 ); + + // c[0, 0-15] + CVT_MULRND_CVT32_CVT8_LT16(c_int32_0p0,selector1,0,0); + + // c[1, 0-15] + CVT_MULRND_CVT32_CVT8_LT16(c_int32_1p0,selector1,1,0); + + // c[2, 0-15] + CVT_MULRND_CVT32_CVT8_LT16(c_int32_2p0,selector1,2,0); + + // c[3, 0-15] + CVT_MULRND_CVT32_CVT8_LT16(c_int32_3p0,selector1,3,0); + + // c[4, 0-15] + CVT_MULRND_CVT32_CVT8_LT16(c_int32_4p0,selector1,4,0); + + // c[5, 0-15] + CVT_MULRND_CVT32_CVT8_LT16(c_int32_5p0,selector1,5,0); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_6xLT16_DISABLE: ; @@ -421,7 +448,7 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6xlt16) alpha, beta, n0_rem, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list + post_ops_list, rs_c_downscale ); } else if ( m_partial_pieces == 4 ) @@ -436,7 +463,7 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6xlt16) alpha, beta, n0_rem, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list + post_ops_list, rs_c_downscale ); } else if ( m_partial_pieces == 3 ) @@ -451,7 +478,7 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6xlt16) alpha, beta, n0_rem, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list + post_ops_list, rs_c_downscale ); } else if ( m_partial_pieces == 2 ) @@ -466,7 +493,7 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6xlt16) alpha, beta, n0_rem, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list + post_ops_list, rs_c_downscale ); } else if ( m_partial_pieces == 1 ) @@ -481,7 +508,7 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6xlt16) alpha, beta, n0_rem, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list + post_ops_list, rs_c_downscale ); } } @@ -495,7 +522,8 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x16) &&POST_OPS_6x16_DISABLE, &&POST_OPS_BIAS_6x16, &&POST_OPS_RELU_6x16, - &&POST_OPS_RELU_SCALE_6x16 + &&POST_OPS_RELU_SCALE_6x16, + &&POST_OPS_DOWNSCALE_6x16 }; dim_t MR = 6; dim_t m_full_pieces = m0 / MR; @@ -795,6 +823,32 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x16) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_DOWNSCALE_6x16: + { + selector1 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 0 * 16 ) ); + + // c[0, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); + + // c[1, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_1p0,selector1,1,0); + + // c[2, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_2p0,selector1,2,0); + + // c[3, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_3p0,selector1,3,0); + + // c[4, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_4p0,selector1,4,0); + + // c[5, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_5p0,selector1,5,0); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_6x16_DISABLE: ; @@ -835,7 +889,7 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x16) alpha, beta, is_last_k, post_op_c_i, post_op_c_j, - 
post_ops_list + post_ops_list, rs_c_downscale ); } else if ( m_partial_pieces == 4 ) @@ -850,7 +904,7 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x16) alpha, beta, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list + post_ops_list, rs_c_downscale ); } else if ( m_partial_pieces == 3 ) @@ -865,7 +919,7 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x16) alpha, beta, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list + post_ops_list, rs_c_downscale ); } else if ( m_partial_pieces == 2 ) @@ -880,7 +934,7 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x16) alpha, beta, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list + post_ops_list, rs_c_downscale ); } else if ( m_partial_pieces == 1 ) @@ -895,7 +949,7 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x16) alpha, beta, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list + post_ops_list, rs_c_downscale ); } } @@ -909,7 +963,8 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x32) &&POST_OPS_6x32_DISABLE, &&POST_OPS_BIAS_6x32, &&POST_OPS_RELU_6x32, - &&POST_OPS_RELU_SCALE_6x32 + &&POST_OPS_RELU_SCALE_6x32, + &&POST_OPS_DOWNSCALE_6x32 }; dim_t MR = 6; dim_t m_full_pieces = m0 / MR; @@ -1323,6 +1378,53 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x32) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_DOWNSCALE_6x32: + { + selector1 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 1 * 16 ) ); + + // c[0, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); + + // c[0, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_0p1,selector2,0,1); + + // c[1, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_1p0,selector1,1,0); + + // c[1, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_1p1,selector2,1,1); + + // c[2, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_2p0,selector1,2,0); + + // c[2, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_2p1,selector2,2,1); + + // c[3, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_3p0,selector1,3,0); + + // c[3, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_3p1,selector2,3,1); + + // c[4, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_4p0,selector1,4,0); + + // c[4, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_4p1,selector2,4,1); + + // c[5, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_5p0,selector1,5,0); + + // c[5, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_5p1,selector2,5,1); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_6x32_DISABLE: ; @@ -1381,7 +1483,7 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x32) alpha, beta, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list + post_ops_list, rs_c_downscale ); } else if ( m_partial_pieces == 4 ) @@ -1396,7 +1498,7 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x32) alpha, beta, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list + post_ops_list, rs_c_downscale ); } else if ( m_partial_pieces == 3 ) @@ -1411,7 +1513,7 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x32) alpha, beta, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list + post_ops_list, rs_c_downscale ); } else if ( m_partial_pieces == 2 ) @@ -1426,7 +1528,7 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x32) alpha, beta, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list + post_ops_list, rs_c_downscale ); } else if ( m_partial_pieces == 1 ) @@ -1441,7 +1543,7 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x32) alpha, beta, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list + 
post_ops_list, rs_c_downscale ); } } @@ -1455,7 +1557,8 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x48) &&POST_OPS_6x48_DISABLE, &&POST_OPS_BIAS_6x48, &&POST_OPS_RELU_6x48, - &&POST_OPS_RELU_SCALE_6x48 + &&POST_OPS_RELU_SCALE_6x48, + &&POST_OPS_DOWNSCALE_6x48 }; dim_t MR = 6; dim_t m_full_pieces = m0 / MR; @@ -1984,6 +2087,74 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x48) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_DOWNSCALE_6x48: + { + selector1 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 1 * 16 ) ); + a_int32_0 = + _mm512_loadu_epi32( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j + ( 2 * 16 ) ); + + // c[0, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_0p0,selector1,0,0); + + // c[0, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_0p1,selector2,0,1); + + // c[0, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_0p2,a_int32_0,0,2); + + // c[1, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_1p0,selector1,1,0); + + // c[1, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_1p1,selector2,1,1); + + // c[1, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_1p2,a_int32_0,1,2); + + // c[2, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_2p0,selector1,2,0); + + // c[2, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_2p1,selector2,2,1); + + // c[2, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_2p2,a_int32_0,2,2); + + // c[3, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_3p0,selector1,3,0); + + // c[3, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_3p1,selector2,3,1); + + // c[3, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_3p2,a_int32_0,3,2); + + // c[4, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_4p0,selector1,4,0); + + // c[4, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_4p1,selector2,4,1); + + // c[4, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_4p2,a_int32_0,4,2); + + // c[5, 0-15] + CVT_MULRND_CVT32_CVT8(c_int32_5p0,selector1,5,0); + + // c[5, 16-31] + CVT_MULRND_CVT32_CVT8(c_int32_5p1,selector2,5,1); + + // c[5, 32-47] + CVT_MULRND_CVT32_CVT8(c_int32_5p2,a_int32_0,5,2); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_6x48_DISABLE: ; @@ -2060,7 +2231,7 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x48) alpha, beta, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list + post_ops_list, rs_c_downscale ); } else if ( m_partial_pieces == 4 ) @@ -2075,7 +2246,7 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x48) alpha, beta, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list + post_ops_list, rs_c_downscale ); } else if ( m_partial_pieces == 3 ) @@ -2090,7 +2261,7 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x48) alpha, beta, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list + post_ops_list, rs_c_downscale ); } else if ( m_partial_pieces == 2 ) @@ -2105,7 +2276,7 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x48) alpha, beta, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list + post_ops_list, rs_c_downscale ); } else if ( m_partial_pieces == 1 ) @@ -2120,7 +2291,7 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x48) alpha, beta, is_last_k, post_op_c_i, post_op_c_j, - post_ops_list + post_ops_list, rs_c_downscale ); } } diff --git a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_s32_kern_macros.h b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_s32_kern_macros.h index bb82d04d34..d095ef683b 100644 --- a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_s32_kern_macros.h +++ b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_s32_kern_macros.h @@ -42,4 +42,44 @@ /* Apply 
scaling on for <= 0 elements.*/ \ reg = _mm512_mask_mullo_epi32( reg, relu_cmp_mask, reg, selector2 ); \ +#define CVT_MULRND_CVT32_CVT8(reg,selector,m_ind,n_ind) \ + _mm_storeu_epi8 \ + ( \ + ( int8_t* )post_ops_list_temp->op_args3 + \ + ( rs_c_downscale * ( post_op_c_i + m_ind ) ) + post_op_c_j + ( n_ind * 16 ), \ + _mm512_cvtepi32_epi8 \ + ( \ + _mm512_cvtps_epi32 \ + ( \ + _mm512_mul_round_ps \ + ( \ + _mm512_cvtepi32_ps( reg ), \ + ( __m512 )selector, \ + ( _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC ) \ + ) \ + ) \ + ) \ + ) \ + +#define CVT_MULRND_CVT32_CVT8_LT16(reg,selector,m_ind,n_ind) \ + _mm_storeu_epi8 \ + ( \ + buf0, \ + _mm512_cvtepi32_epi8 \ + ( \ + _mm512_cvtps_epi32 \ + ( \ + _mm512_mul_round_ps \ + ( \ + _mm512_cvtepi32_ps( reg ), \ + ( __m512 )selector, \ + ( _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC ) \ + ) \ + ) \ + ) \ + ); \ + memcpy( ( int8_t* )post_ops_list_temp->op_args3 + \ + ( rs_c_downscale * ( post_op_c_i + m_ind ) ) + post_op_c_j + \ + ( n_ind * 16 ) , buf0, ( n0_rem * sizeof( int8_t ) ) ); \ + #endif // LPGEMM_S32_KERN_MACROS_H From 5faab43e661964c52bd422dc9fffcd85620f08b7 Mon Sep 17 00:00:00 2001 From: Harihara Sudhan S Date: Tue, 30 Aug 2022 16:39:40 +0530 Subject: [PATCH 213/243] Downscaling as part of u8s8s16os16 - int16 c matrix intermediate values are converted to int32, then the int32 values are converted to fp32. On these fp32 values scaling is done - The resultant value is down scaled to int8 and stored in a separate buffer AMD-Internal: [2171] Change-Id: I76ff04098def04d55d1bd88ac8c8d3f267964cab --- .../u8s8s16/lpgemm_6x32rowmajor_amd256.c | 439 +++++++++++++- .../kernels/u8s8s16/lpgemm_m_fringe_amd256.c | 547 +++++++++++++++++- .../kernels/u8s8s16/lpgemm_mn_fringe_amd256.c | 506 +++++++++++++++- .../kernels/u8s8s16/lpgemm_n_fringe_amd256.c | 462 ++++++++++++++- 4 files changed, 1944 insertions(+), 10 deletions(-) diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_6x32rowmajor_amd256.c b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_6x32rowmajor_amd256.c index 553ce1a134..d5bd5ef93a 100644 --- a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_6x32rowmajor_amd256.c +++ b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_6x32rowmajor_amd256.c @@ -44,7 +44,8 @@ LPGEMM_MAIN_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x32) &&POST_OPS_6x32_DISABLE, &&POST_OPS_BIAS_6x32, &&POST_OPS_RELU_6x32, - &&POST_OPS_RELU_SCALE_6x32 + &&POST_OPS_RELU_SCALE_6x32, + &&POST_OPS_DOWNSCALE_6x32 }; dim_t MR = 6; @@ -560,6 +561,442 @@ LPGEMM_MAIN_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x32) // c[5,16-31] RELU_SCALE_OP_S16_AVX2(c_int16_5p1) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_6x32: + { + __m128i temp[2]; + __m256i temp_32[2]; + __m256 temp_float[2]; + + // Load the scale vector values into the register + __m256 scale_1 = + _mm256_loadu_ps( + (float *)post_ops_list_temp->scale_factor + + post_op_c_j + (0 * 8)); + __m256 scale_2 = + _mm256_loadu_ps( + (float *)post_ops_list_temp->scale_factor + + post_op_c_j + (1 * 8)); + + // Extract the first 128 bits of the register + temp[0] = _mm256_extractf128_si256(c_int16_0p0, 0); + + // Extract the second 128 bits of the register + temp[1] = _mm256_extractf128_si256(c_int16_0p0, 1); + + // Since s16 values cannot be converted to f32 directly, + // they are converted to s32, then to f32 and the scale is performed + temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); + + temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); + temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); + + // Multiply the 
C matrix by the scale value + __m256 res_1 = _mm256_mul_ps(temp_float[0], scale_1); + __m256 res_2 = _mm256_mul_ps(temp_float[0], scale_2); + + // Round the resultant value to the nearest integer + res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + + // Convert float32 scaled rounded value to int32 + temp_32[0] = _mm256_cvtps_epi32(res_1); + temp_32[1] = _mm256_cvtps_epi32(res_2); + + // Convert the s32 to s16 + c_int16_0p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); + + // Permute to make sure the order is correct + c_int16_0p0 = _mm256_permute4x64_epi64(c_int16_0p0, 0XD8); + + // Extract the first 128 bits of the register + temp[0] = _mm256_extractf128_si256(c_int16_0p1, 0); + + // Extract the second 128 bits of the register + temp[1] = _mm256_extractf128_si256(c_int16_0p1, 1); + + // Since s16 values cannot be converted to f32 directly, + // they are converted to s32, then to f32 and the scale is performed + temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); + + temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); + temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); + + // Multiply the C matrix by the scale value + res_1 = _mm256_mul_ps(temp_float[0], scale_1); + res_2 = _mm256_mul_ps(temp_float[0], scale_2); + + // Round the resultant value to the nearest integer + res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + + // Convert float32 scaled rounded value to int32 + temp_32[0] = _mm256_cvtps_epi32(res_1); + temp_32[1] = _mm256_cvtps_epi32(res_2); + + // Convert the s32 to s16 + c_int16_0p1 = _mm256_packs_epi32(temp_32[0], temp_32[1]); + + // Permute to make sure the order is correct + c_int16_0p1 = _mm256_permute4x64_epi64(c_int16_0p1, 0XD8); + + __m256i store_reg = _mm256_packs_epi16(c_int16_0p0, c_int16_0p1); + + _mm256_storeu_si256((__m256i *)(c + (rs_c * (ir + 0)) + (0 * 16)), store_reg); + + //-------------------------------------------------------------------------- + + // Extract the first 128 bits of the register + temp[0] = _mm256_extractf128_si256(c_int16_1p0, 0); + + // Extract the second 128 bits of the register + temp[1] = _mm256_extractf128_si256(c_int16_1p0, 1); + + // Since s16 values cannot be converted to f32 directly, + // they are converted to s32, then to f32 and the scale is performed + temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); + + temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); + temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); + + // Multiply the C matrix by the scale value + res_1 = _mm256_mul_ps(temp_float[0], scale_1); + res_2 = _mm256_mul_ps(temp_float[0], scale_2); + + // Round the resultant value to the nearest integer + res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + + // Convert float32 scaled rounded value to int32 + temp_32[0] = _mm256_cvtps_epi32(res_1); + temp_32[1] = _mm256_cvtps_epi32(res_2); + + // Convert the s32 to s16 + c_int16_1p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); + + // Permute to make sure the order is correct + c_int16_1p0 = _mm256_permute4x64_epi64(c_int16_1p0, 0XD8); + + // Extract the first 128 bits of the register + temp[0] = _mm256_extractf128_si256(c_int16_1p1, 0); + + // Extract the second 128 
bits of the register + temp[1] = _mm256_extractf128_si256(c_int16_1p1, 1); + + // Since s16 values cannot be converted to f32 directly, + // they are converted to s32, then to f32 and the scale is performed + temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); + + temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); + temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); + + // Multiply the C matrix by the scale value + res_1 = _mm256_mul_ps(temp_float[0], scale_1); + res_2 = _mm256_mul_ps(temp_float[0], scale_2); + + // Round the resultant value to the nearest integer + res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + + // Convert float32 scaled rounded value to int32 + temp_32[0] = _mm256_cvtps_epi32(res_1); + temp_32[1] = _mm256_cvtps_epi32(res_2); + + // Convert the s32 to s16 + c_int16_1p1 = _mm256_packs_epi32(temp_32[0], temp_32[1]); + + // Permute to make sure the order is correct + c_int16_1p1 = _mm256_permute4x64_epi64(c_int16_1p1, 0XD8); + + store_reg = _mm256_packs_epi16(c_int16_1p0, c_int16_1p1); + + _mm256_storeu_si256((__m256i *)(c + (rs_c * (ir + 0)) + (0 * 16)), store_reg); + + //-------------------------------------------------------------------------- + + // Extract the first 128 bits of the register + temp[0] = _mm256_extractf128_si256(c_int16_2p0, 0); + + // Extract the second 128 bits of the register + temp[1] = _mm256_extractf128_si256(c_int16_2p0, 1); + + // Since s16 values cannot be converted to f32 directly, + // they are converted to s32, then to f32 and the scale is performed + temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); + + temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); + temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); + + // Multiply the C matrix by the scale value + res_1 = _mm256_mul_ps(temp_float[0], scale_1); + res_2 = _mm256_mul_ps(temp_float[0], scale_2); + + // Round the resultant value to the nearest integer + res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + + // Convert float32 scaled rounded value to int32 + temp_32[0] = _mm256_cvtps_epi32(res_1); + temp_32[1] = _mm256_cvtps_epi32(res_2); + + // Convert the s32 to s16 + c_int16_2p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); + + // Permute to make sure the order is correct + c_int16_2p0 = _mm256_permute4x64_epi64(c_int16_2p0, 0XD8); + + // Extract the first 128 bits of the register + temp[0] = _mm256_extractf128_si256(c_int16_2p1, 0); + + // Extract the second 128 bits of the register + temp[1] = _mm256_extractf128_si256(c_int16_2p1, 1); + + // Since s16 values cannot be converted to f32 directly, + // they are converted to s32, then to f32 and the scale is performed + temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); + + temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); + temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); + + // Multiply the C matrix by the scale value + res_1 = _mm256_mul_ps(temp_float[0], scale_1); + res_2 = _mm256_mul_ps(temp_float[0], scale_2); + + // Round the resultant value to the nearest integer + res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + + // Convert float32 scaled rounded value to int32 + temp_32[0] = 
_mm256_cvtps_epi32(res_1); + temp_32[1] = _mm256_cvtps_epi32(res_2); + + // Convert the s32 to s16 + c_int16_2p1 = _mm256_packs_epi32(temp_32[0], temp_32[1]); + + // Permute to make sure the order is correct + c_int16_2p1 = _mm256_permute4x64_epi64(c_int16_2p1, 0XD8); + + store_reg = _mm256_packs_epi16(c_int16_2p0, c_int16_2p1); + + _mm256_storeu_si256((__m256i *)(c + (rs_c * (ir + 0)) + (0 * 16)), store_reg); + + //-------------------------------------------------------------------------- + + // Extract the first 128 bits of the register + temp[0] = _mm256_extractf128_si256(c_int16_3p0, 0); + + // Extract the second 128 bits of the register + temp[1] = _mm256_extractf128_si256(c_int16_3p0, 1); + + // Since s16 values cannot be converted to f32 directly, + // they are converted to s32, then to f32 and the scale is performed + temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); + + temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); + temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); + + // Multiply the C matrix by the scale value + res_1 = _mm256_mul_ps(temp_float[0], scale_1); + res_2 = _mm256_mul_ps(temp_float[0], scale_2); + + // Round the resultant value to the nearest integer + res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + + // Convert float32 scaled rounded value to int32 + temp_32[0] = _mm256_cvtps_epi32(res_1); + temp_32[1] = _mm256_cvtps_epi32(res_2); + + // Convert the s32 to s16 + c_int16_3p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); + + // Permute to make sure the order is correct + c_int16_3p0 = _mm256_permute4x64_epi64(c_int16_3p0, 0XD8); + + // Extract the first 128 bits of the register + temp[0] = _mm256_extractf128_si256(c_int16_3p1, 0); + + // Extract the second 128 bits of the register + temp[1] = _mm256_extractf128_si256(c_int16_3p1, 1); + + // Since s16 values cannot be converted to f32 directly, + // they are converted to s32, then to f32 and the scale is performed + temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); + + temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); + temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); + + // Multiply the C matrix by the scale value + res_1 = _mm256_mul_ps(temp_float[0], scale_1); + res_2 = _mm256_mul_ps(temp_float[0], scale_2); + + // Round the resultant value to the nearest integer + res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + + // Convert float32 scaled rounded value to int32 + temp_32[0] = _mm256_cvtps_epi32(res_1); + temp_32[1] = _mm256_cvtps_epi32(res_2); + + // Convert the s32 to s16 + c_int16_3p1 = _mm256_packs_epi32(temp_32[0], temp_32[1]); + + // Permute to make sure the order is correct + c_int16_3p1 = _mm256_permute4x64_epi64(c_int16_3p1, 0XD8); + + store_reg = _mm256_packs_epi16(c_int16_3p0, c_int16_3p1); + + _mm256_storeu_si256((__m256i *)(c + (rs_c * (ir + 0)) + (0 * 16)), store_reg); + + //-------------------------------------------------------------------------- + + // Extract the first 128 bits of the register + temp[0] = _mm256_extractf128_si256(c_int16_4p0, 0); + + // Extract the second 128 bits of the register + temp[1] = _mm256_extractf128_si256(c_int16_4p0, 1); + + // Since s16 values cannot be converted to f32 directly, + // they are converted to s32, then to f32 and the scale is performed + 
temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); + + temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); + temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); + + // Multiply the C matrix by the scale value + res_1 = _mm256_mul_ps(temp_float[0], scale_1); + res_2 = _mm256_mul_ps(temp_float[0], scale_2); + + // Round the resultant value to the nearest integer + res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + + // Convert float32 scaled rounded value to int32 + temp_32[0] = _mm256_cvtps_epi32(res_1); + temp_32[1] = _mm256_cvtps_epi32(res_2); + + // Convert the s32 to s16 + c_int16_4p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); + + // Permute to make sure the order is correct + c_int16_4p0 = _mm256_permute4x64_epi64(c_int16_4p0, 0XD8); + + // Extract the first 128 bits of the register + temp[0] = _mm256_extractf128_si256(c_int16_4p1, 0); + + // Extract the second 128 bits of the register + temp[1] = _mm256_extractf128_si256(c_int16_4p1, 1); + + // Since s16 values cannot be converted to f32 directly, + // they are converted to s32, then to f32 and the scale is performed + temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); + + temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); + temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); + + // Multiply the C matrix by the scale value + res_1 = _mm256_mul_ps(temp_float[0], scale_1); + res_2 = _mm256_mul_ps(temp_float[0], scale_2); + + // Round the resultant value to the nearest integer + res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + + // Convert float32 scaled rounded value to int32 + temp_32[0] = _mm256_cvtps_epi32(res_1); + temp_32[1] = _mm256_cvtps_epi32(res_2); + + // Convert the s32 to s16 + c_int16_4p1 = _mm256_packs_epi32(temp_32[0], temp_32[1]); + + // Permute to make sure the order is correct + c_int16_4p1 = _mm256_permute4x64_epi64(c_int16_4p1, 0XD8); + + store_reg = _mm256_packs_epi16(c_int16_4p0, c_int16_4p1); + + _mm256_storeu_si256((__m256i *)(c + (rs_c * (ir + 0)) + (0 * 16)), store_reg); + + //-------------------------------------------------------------------------- + + // Extract the first 128 bits of the register + temp[0] = _mm256_extractf128_si256(c_int16_5p0, 0); + + // Extract the second 128 bits of the register + temp[1] = _mm256_extractf128_si256(c_int16_5p0, 1); + + // Since s16 values cannot be converted to f32 directly, + // they are converted to s32, then to f32 and the scale is performed + temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); + + temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); + temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); + + // Multiply the C matrix by the scale value + res_1 = _mm256_mul_ps(temp_float[0], scale_1); + res_2 = _mm256_mul_ps(temp_float[0], scale_2); + + // Round the resultant value to the nearest integer + res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + + // Convert float32 scaled rounded value to int32 + temp_32[0] = _mm256_cvtps_epi32(res_1); + temp_32[1] = _mm256_cvtps_epi32(res_2); + + // Convert the s32 to s16 + c_int16_5p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); + + // Permute to make sure the order is correct + 
c_int16_5p0 = _mm256_permute4x64_epi64(c_int16_5p0, 0XD8); + + // Extract the first 128 bits of the register + temp[0] = _mm256_extractf128_si256(c_int16_5p1, 0); + + // Extract the second 128 bits of the register + temp[1] = _mm256_extractf128_si256(c_int16_5p1, 1); + + // Since s16 values cannot be converted to f32 directly, + // they are converted to s32, then to f32 and the scale is performed + temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); + + temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); + temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); + + // Multiply the C matrix by the scale value + res_1 = _mm256_mul_ps(temp_float[0], scale_1); + res_2 = _mm256_mul_ps(temp_float[0], scale_2); + + // Round the resultant value to the nearest integer + res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + + // Convert float32 scaled rounded value to int32 + temp_32[0] = _mm256_cvtps_epi32(res_1); + temp_32[1] = _mm256_cvtps_epi32(res_2); + + // Convert the s32 to s16 + c_int16_5p1 = _mm256_packs_epi32(temp_32[0], temp_32[1]); + + // Permute to make sure the order is correct + c_int16_5p1 = _mm256_permute4x64_epi64(c_int16_4p1, 0XD8); + + store_reg = _mm256_packs_epi16(c_int16_5p0, c_int16_5p1); + + _mm256_storeu_si256((__m256i *)(c + (rs_c * (ir + 0)) + (0 * 16)), store_reg); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_6x32_DISABLE: diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_m_fringe_amd256.c b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_m_fringe_amd256.c index 753a083d93..12026ad0e0 100644 --- a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_m_fringe_amd256.c +++ b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_m_fringe_amd256.c @@ -48,7 +48,8 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4x32) &&POST_OPS_4x32_DISABLE, &&POST_OPS_BIAS_4x32, &&POST_OPS_RELU_4x32, - &&POST_OPS_RELU_SCALE_4x32 + &&POST_OPS_RELU_SCALE_4x32, + &&POST_OPS_DOWNSCALE_4x32 }; // The division is done by considering the vpmaddubsw instruction @@ -348,6 +349,302 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4x32) // c[3,16-31] RELU_SCALE_OP_S16_AVX2(c_int16_3p1) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_4x32: + { + __m128i temp[2]; + __m256i temp_32[2]; + __m256 temp_float[2]; + + // Load the scale vector values into the register + __m256 scale_1 = + _mm256_loadu_ps( + (float *)post_ops_list_temp->scale_factor + + post_op_c_j + (0 * 8)); + __m256 scale_2 = + _mm256_loadu_ps( + (float *)post_ops_list_temp->scale_factor + + post_op_c_j + (1 * 8)); + + // Extract the first 128 bits of the register + temp[0] = _mm256_extractf128_si256(c_int16_0p0, 0); + + // Extract the second 128 bits of the register + temp[1] = _mm256_extractf128_si256(c_int16_0p0, 1); + + // Since s16 values cannot be converted to f32 directly, + // they are converted to s32, then to f32 and the scale is performed + temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); + + temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); + temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); + + // Multiply the C matrix by the scale value + __m256 res_1 = _mm256_mul_ps(temp_float[0], scale_1); + __m256 res_2 = _mm256_mul_ps(temp_float[0], scale_2); + + // Round the resultant value to the nearest integer + res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + res_2 = _mm256_round_ps(res_2, 
(_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + + // Convert float32 scaled rounded value to int32 + temp_32[0] = _mm256_cvtps_epi32(res_1); + temp_32[1] = _mm256_cvtps_epi32(res_2); + + // Convert the s32 to s16 + c_int16_0p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); + + // Permute to make sure the order is correct + c_int16_0p0 = _mm256_permute4x64_epi64(c_int16_0p0, 0XD8); + + // Extract the first 128 bits of the register + temp[0] = _mm256_extractf128_si256(c_int16_0p1, 0); + + // Extract the second 128 bits of the register + temp[1] = _mm256_extractf128_si256(c_int16_0p1, 1); + + // Since s16 values cannot be converted to f32 directly, + // they are converted to s32, then to f32 and the scale is performed + temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); + + temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); + temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); + + // Multiply the C matrix by the scale value + res_1 = _mm256_mul_ps(temp_float[0], scale_1); + res_2 = _mm256_mul_ps(temp_float[0], scale_2); + + // Round the resultant value to the nearest integer + res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + + // Convert float32 scaled rounded value to int32 + temp_32[0] = _mm256_cvtps_epi32(res_1); + temp_32[1] = _mm256_cvtps_epi32(res_2); + + // Convert the s32 to s16 + c_int16_0p1 = _mm256_packs_epi32(temp_32[0], temp_32[1]); + + // Permute to make sure the order is correct + c_int16_0p1 = _mm256_permute4x64_epi64(c_int16_0p1, 0XD8); + + __m256i store_reg = _mm256_packs_epi16(c_int16_0p0, c_int16_0p1); + + _mm256_storeu_si256((__m256i *)(c + (rs_c * 0)) + (0 * 16), store_reg); + + //-------------------------------------------------------------------------- + + // Extract the first 128 bits of the register + temp[0] = _mm256_extractf128_si256(c_int16_1p0, 0); + + // Extract the second 128 bits of the register + temp[1] = _mm256_extractf128_si256(c_int16_1p0, 1); + + // Since s16 values cannot be converted to f32 directly, + // they are converted to s32, then to f32 and the scale is performed + temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); + + temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); + temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); + + // Multiply the C matrix by the scale value + res_1 = _mm256_mul_ps(temp_float[0], scale_1); + res_2 = _mm256_mul_ps(temp_float[0], scale_2); + + // Round the resultant value to the nearest integer + res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + + // Convert float32 scaled rounded value to int32 + temp_32[0] = _mm256_cvtps_epi32(res_1); + temp_32[1] = _mm256_cvtps_epi32(res_2); + + // Convert the s32 to s16 + c_int16_1p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); + + // Permute to make sure the order is correct + c_int16_1p0 = _mm256_permute4x64_epi64(c_int16_1p0, 0XD8); + + // Extract the first 128 bits of the register + temp[0] = _mm256_extractf128_si256(c_int16_1p1, 0); + + // Extract the second 128 bits of the register + temp[1] = _mm256_extractf128_si256(c_int16_1p1, 1); + + // Since s16 values cannot be converted to f32 directly, + // they are converted to s32, then to f32 and the scale is performed + temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); + + temp_32[1] = 
_mm256_cvtepi16_epi32(temp[1]); + temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); + + // Multiply the C matrix by the scale value + res_1 = _mm256_mul_ps(temp_float[0], scale_1); + res_2 = _mm256_mul_ps(temp_float[0], scale_2); + + // Round the resultant value to the nearest integer + res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + + // Convert float32 scaled rounded value to int32 + temp_32[0] = _mm256_cvtps_epi32(res_1); + temp_32[1] = _mm256_cvtps_epi32(res_2); + + // Convert the s32 to s16 + c_int16_1p1 = _mm256_packs_epi32(temp_32[0], temp_32[1]); + + // Permute to make sure the order is correct + c_int16_1p1 = _mm256_permute4x64_epi64(c_int16_1p1, 0XD8); + + store_reg = _mm256_packs_epi16(c_int16_1p0, c_int16_1p1); + + _mm256_storeu_si256((__m256i *)(c + (rs_c * 0) + (0 * 16)), store_reg); + + //-------------------------------------------------------------------------- + + // Extract the first 128 bits of the register + temp[0] = _mm256_extractf128_si256(c_int16_2p0, 0); + + // Extract the second 128 bits of the register + temp[1] = _mm256_extractf128_si256(c_int16_2p0, 1); + + // Since s16 values cannot be converted to f32 directly, + // they are converted to s32, then to f32 and the scale is performed + temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); + + temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); + temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); + + // Multiply the C matrix by the scale value + res_1 = _mm256_mul_ps(temp_float[0], scale_1); + res_2 = _mm256_mul_ps(temp_float[0], scale_2); + + // Round the resultant value to the nearest integer + res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + + // Convert float32 scaled rounded value to int32 + temp_32[0] = _mm256_cvtps_epi32(res_1); + temp_32[1] = _mm256_cvtps_epi32(res_2); + + // Convert the s32 to s16 + c_int16_2p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); + + // Permute to make sure the order is correct + c_int16_2p0 = _mm256_permute4x64_epi64(c_int16_2p0, 0XD8); + + // Extract the first 128 bits of the register + temp[0] = _mm256_extractf128_si256(c_int16_2p1, 0); + + // Extract the second 128 bits of the register + temp[1] = _mm256_extractf128_si256(c_int16_2p1, 1); + + // Since s16 values cannot be converted to f32 directly, + // they are converted to s32, then to f32 and the scale is performed + temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); + + temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); + temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); + + // Multiply the C matrix by the scale value + res_1 = _mm256_mul_ps(temp_float[0], scale_1); + res_2 = _mm256_mul_ps(temp_float[0], scale_2); + + // Round the resultant value to the nearest integer + res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + + // Convert float32 scaled rounded value to int32 + temp_32[0] = _mm256_cvtps_epi32(res_1); + temp_32[1] = _mm256_cvtps_epi32(res_2); + + // Convert the s32 to s16 + c_int16_2p1 = _mm256_packs_epi32(temp_32[0], temp_32[1]); + + // Permute to make sure the order is correct + c_int16_2p1 = _mm256_permute4x64_epi64(c_int16_2p1, 0XD8); + + store_reg = _mm256_packs_epi16(c_int16_2p0, c_int16_2p1); + 
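/*
 * Illustrative scalar sketch (not part of the committed kernels; the helper
 * name is hypothetical): the repeated AVX2 sequence in these u8s8s16 hunks
 * follows the flow described in the commit message. Each 256-bit s16
 * accumulator is split into two 128-bit halves, widened to s32, converted to
 * f32, multiplied by the per-column scale, rounded to the nearest integer,
 * converted back to s32, re-packed to s16 with saturation
 * (_mm256_packs_epi32 plus the 0xD8 permute to restore lane order), then
 * packed to s8 with saturation and stored. Per element that amounts to:
 */
#include <math.h>
#include <stdint.h>

static inline int8_t ref_downscale_s16_to_s8( int16_t acc, float scale )
{
	/* s16 -> s32 -> f32, scale, round to nearest; the two saturating pack
	 * steps net out to a clamp into the int8 range. */
	float scaled = nearbyintf( ( float )acc * scale );
	if ( scaled > 127.0f ) scaled = 127.0f;
	if ( scaled < -128.0f ) scaled = -128.0f;
	return ( int8_t )scaled;
}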
+ _mm256_storeu_si256((__m256i *)(c + (rs_c * 0) + (0 * 16)), store_reg); + + //-------------------------------------------------------------------------- + + // Extract the first 128 bits of the register + temp[0] = _mm256_extractf128_si256(c_int16_3p0, 0); + + // Extract the second 128 bits of the register + temp[1] = _mm256_extractf128_si256(c_int16_3p0, 1); + + // Since s16 values cannot be converted to f32 directly, + // they are converted to s32, then to f32 and the scale is performed + temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); + + temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); + temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); + + // Multiply the C matrix by the scale value + res_1 = _mm256_mul_ps(temp_float[0], scale_1); + res_2 = _mm256_mul_ps(temp_float[0], scale_2); + + // Round the resultant value to the nearest integer + res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + + // Convert float32 scaled rounded value to int32 + temp_32[0] = _mm256_cvtps_epi32(res_1); + temp_32[1] = _mm256_cvtps_epi32(res_2); + + // Convert the s32 to s16 + c_int16_3p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); + + // Permute to make sure the order is correct + c_int16_3p0 = _mm256_permute4x64_epi64(c_int16_3p0, 0XD8); + + // Extract the first 128 bits of the register + temp[0] = _mm256_extractf128_si256(c_int16_3p1, 0); + + // Extract the second 128 bits of the register + temp[1] = _mm256_extractf128_si256(c_int16_3p1, 1); + + // Since s16 values cannot be converted to f32 directly, + // they are converted to s32, then to f32 and the scale is performed + temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); + + temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); + temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); + + // Multiply the C matrix by the scale value + res_1 = _mm256_mul_ps(temp_float[0], scale_1); + res_2 = _mm256_mul_ps(temp_float[0], scale_2); + + // Round the resultant value to the nearest integer + res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + + // Convert float32 scaled rounded value to int32 + temp_32[0] = _mm256_cvtps_epi32(res_1); + temp_32[1] = _mm256_cvtps_epi32(res_2); + + // Convert the s32 to s16 + c_int16_3p1 = _mm256_packs_epi32(temp_32[0], temp_32[1]); + + // Permute to make sure the order is correct + c_int16_3p1 = _mm256_permute4x64_epi64(c_int16_3p1, 0XD8); + + store_reg = _mm256_packs_epi16(c_int16_3p0, c_int16_3p1); + + _mm256_storeu_si256((__m256i *)(c + (rs_c * 0) + (0 * 16)), store_reg); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_4x32_DISABLE: @@ -390,7 +687,8 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2x32) &&POST_OPS_2x32_DISABLE, &&POST_OPS_BIAS_2x32, &&POST_OPS_RELU_2x32, - &&POST_OPS_RELU_SCALE_2x32 + &&POST_OPS_RELU_SCALE_2x32, + &&POST_OPS_DOWNSCALE_2x32 }; // The division is done by considering the vpmaddubsw instruction @@ -573,6 +871,162 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2x32) // c[1,16-31] RELU_SCALE_OP_S16_AVX2(c_int16_1p1) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_2x32: + { + __m128i temp[2]; + __m256i temp_32[2]; + __m256 temp_float[2]; + + // Load the scale vector values into the register + __m256 scale_1 = + _mm256_loadu_ps( + (float 
*)post_ops_list_temp->scale_factor + + post_op_c_j + (0 * 8)); + __m256 scale_2 = + _mm256_loadu_ps( + (float *)post_ops_list_temp->scale_factor + + post_op_c_j + (1 * 8)); + + // Extract the first 128 bits of the register + temp[0] = _mm256_extractf128_si256(c_int16_0p0, 0); + + // Extract the second 128 bits of the register + temp[1] = _mm256_extractf128_si256(c_int16_0p0, 1); + + // Since s16 values cannot be converted to f32 directly, + // they are converted to s32, then to f32 and the scale is performed + temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); + + temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); + temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); + + // Multiply the C matrix by the scale value + __m256 res_1 = _mm256_mul_ps(temp_float[0], scale_1); + __m256 res_2 = _mm256_mul_ps(temp_float[0], scale_2); + + // Round the resultant value to the nearest integer + res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + + // Convert float32 scaled rounded value to int32 + temp_32[0] = _mm256_cvtps_epi32(res_1); + temp_32[1] = _mm256_cvtps_epi32(res_2); + + // Convert the s32 to s16 + c_int16_0p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); + + // Permute to make sure the order is correct + c_int16_0p0 = _mm256_permute4x64_epi64(c_int16_0p0, 0XD8); + + // Extract the first 128 bits of the register + temp[0] = _mm256_extractf128_si256(c_int16_0p1, 0); + + // Extract the second 128 bits of the register + temp[1] = _mm256_extractf128_si256(c_int16_0p1, 1); + + // Since s16 values cannot be converted to f32 directly, + // they are converted to s32, then to f32 and the scale is performed + temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); + + temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); + temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); + + // Multiply the C matrix by the scale value + res_1 = _mm256_mul_ps(temp_float[0], scale_1); + res_2 = _mm256_mul_ps(temp_float[0], scale_2); + + // Round the resultant value to the nearest integer + res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + + // Convert float32 scaled rounded value to int32 + temp_32[0] = _mm256_cvtps_epi32(res_1); + temp_32[1] = _mm256_cvtps_epi32(res_2); + + // Convert the s32 to s16 + c_int16_0p1 = _mm256_packs_epi32(temp_32[0], temp_32[1]); + + // Permute to make sure the order is correct + c_int16_0p1 = _mm256_permute4x64_epi64(c_int16_0p1, 0XD8); + + __m256i store_reg = _mm256_packs_epi16(c_int16_0p0, c_int16_0p1); + + _mm256_storeu_si256((__m256i *)(c + (rs_c * 0)) + (0 * 16), store_reg); + + //-------------------------------------------------------------------------- + + // Extract the first 128 bits of the register + temp[0] = _mm256_extractf128_si256(c_int16_1p0, 0); + + // Extract the second 128 bits of the register + temp[1] = _mm256_extractf128_si256(c_int16_1p0, 1); + + // Since s16 values cannot be converted to f32 directly, + // they are converted to s32, then to f32 and the scale is performed + temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); + + temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); + temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); + + // Multiply the C matrix by the scale value + res_1 = _mm256_mul_ps(temp_float[0], scale_1); + res_2 = 
_mm256_mul_ps(temp_float[0], scale_2); + + // Round the resultant value to the nearest integer + res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + + // Convert float32 scaled rounded value to int32 + temp_32[0] = _mm256_cvtps_epi32(res_1); + temp_32[1] = _mm256_cvtps_epi32(res_2); + + // Convert the s32 to s16 + c_int16_1p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); + + // Permute to make sure the order is correct + c_int16_1p0 = _mm256_permute4x64_epi64(c_int16_1p0, 0XD8); + + // Extract the first 128 bits of the register + temp[0] = _mm256_extractf128_si256(c_int16_1p1, 0); + + // Extract the second 128 bits of the register + temp[1] = _mm256_extractf128_si256(c_int16_1p1, 1); + + // Since s16 values cannot be converted to f32 directly, + // they are converted to s32, then to f32 and the scale is performed + temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); + + temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); + temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); + + // Multiply the C matrix by the scale value + res_1 = _mm256_mul_ps(temp_float[0], scale_1); + res_2 = _mm256_mul_ps(temp_float[0], scale_2); + + // Round the resultant value to the nearest integer + res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + + // Convert float32 scaled rounded value to int32 + temp_32[0] = _mm256_cvtps_epi32(res_1); + temp_32[1] = _mm256_cvtps_epi32(res_2); + + // Convert the s32 to s16 + c_int16_1p1 = _mm256_packs_epi32(temp_32[0], temp_32[1]); + + // Permute to make sure the order is correct + c_int16_1p1 = _mm256_permute4x64_epi64(c_int16_1p1, 0XD8); + + store_reg = _mm256_packs_epi16(c_int16_1p0, c_int16_1p1); + + _mm256_storeu_si256((__m256i *)(c + (rs_c * 0) + (0 * 16)), store_reg); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_2x32_DISABLE: @@ -602,7 +1056,8 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1x32) &&POST_OPS_1x32_DISABLE, &&POST_OPS_BIAS_1x32, &&POST_OPS_RELU_1x32, - &&POST_OPS_RELU_SCALE_1x32 + &&POST_OPS_RELU_SCALE_1x32, + &&POST_OPS_DOWNSCALE_1x32 }; // The division is done by considering the vpmaddubsw instruction @@ -726,6 +1181,92 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1x32) // c[0,16-31] RELU_SCALE_OP_S16_AVX2(c_int16_0p1) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_1x32: + { + __m128i temp[2]; + __m256i temp_32[2]; + __m256 temp_float[2]; + + // Load the scale vector values into the register + __m256 scale_1 = + _mm256_loadu_ps( + (float *)post_ops_list_temp->scale_factor + + post_op_c_j + (0 * 8)); + __m256 scale_2 = + _mm256_loadu_ps( + (float *)post_ops_list_temp->scale_factor + + post_op_c_j + (1 * 8)); + + // Extract the first 128 bits of the register + temp[0] = _mm256_extractf128_si256(c_int16_0p0, 0); + + // Extract the second 128 bits of the register + temp[1] = _mm256_extractf128_si256(c_int16_0p0, 1); + + // Since s16 values cannot be converted to f32 directly, + // they are converted to s32, then to f32 and the scale is performed + temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); + + temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); + temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); + + // Multiply the C matrix by the scale value + __m256 res_1 = _mm256_mul_ps(temp_float[0], scale_1); + 
__m256 res_2 = _mm256_mul_ps(temp_float[0], scale_2); + + // Round the resultant value to the nearest integer + res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + + // Convert float32 scaled rounded value to int32 + temp_32[0] = _mm256_cvtps_epi32(res_1); + temp_32[1] = _mm256_cvtps_epi32(res_2); + + // Convert the s32 to s16 + c_int16_0p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); + + // Permute to make sure the order is correct + c_int16_0p0 = _mm256_permute4x64_epi64(c_int16_0p0, 0XD8); + + // Extract the first 128 bits of the register + temp[0] = _mm256_extractf128_si256(c_int16_0p1, 0); + + // Extract the second 128 bits of the register + temp[1] = _mm256_extractf128_si256(c_int16_0p1, 1); + + // Since s16 values cannot be converted to f32 directly, + // they are converted to s32, then to f32 and the scale is performed + temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); + + temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); + temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); + + // Multiply the C matrix by the scale value + res_1 = _mm256_mul_ps(temp_float[0], scale_1); + res_2 = _mm256_mul_ps(temp_float[0], scale_2); + + // Round the resultant value to the nearest integer + res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + + // Convert float32 scaled rounded value to int32 + temp_32[0] = _mm256_cvtps_epi32(res_1); + temp_32[1] = _mm256_cvtps_epi32(res_2); + + // Convert the s32 to s16 + c_int16_0p1 = _mm256_packs_epi32(temp_32[0], temp_32[1]); + + // Permute to make sure the order is correct + c_int16_0p1 = _mm256_permute4x64_epi64(c_int16_0p1, 0XD8); + + __m256i store_reg = _mm256_packs_epi16(c_int16_0p0, c_int16_0p1); + + _mm256_storeu_si256((__m256i *)(c + (rs_c * 0)) + (0 * 16), store_reg); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_1x32_DISABLE: diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_mn_fringe_amd256.c b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_mn_fringe_amd256.c index 77ccd2d27d..f50dd98826 100644 --- a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_mn_fringe_amd256.c +++ b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_mn_fringe_amd256.c @@ -48,7 +48,8 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4x16) &&POST_OPS_4x16_DISABLE, &&POST_OPS_BIAS_4x16, &&POST_OPS_RELU_4x16, - &&POST_OPS_RELU_SCALE_4x16 + &&POST_OPS_RELU_SCALE_4x16, + &&POST_OPS_DOWNSCALE_4x16 }; // The division is done by considering the vpmaddubsw instruction @@ -258,6 +259,164 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4x16) // c[3,0-15] RELU_SCALE_OP_S16_AVX2(c_int16_3p0) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_4x16: + { + __m128i temp[2]; + __m256i temp_32[2]; + __m256 temp_float[2]; + + // Load the scale vector values into the register + __m256 scale_1 = + _mm256_loadu_ps( + (float *)post_ops_list_temp->scale_factor + + post_op_c_j + (0 * 8)); + __m256 scale_2 = + _mm256_loadu_ps( + (float *)post_ops_list_temp->scale_factor + + post_op_c_j + (1 * 8)); + + // Extract the first 128 bits of the register + temp[0] = _mm256_extractf128_si256(c_int16_0p0, 0); + + // Extract the second 128 bits of the register + temp[1] = _mm256_extractf128_si256(c_int16_0p0, 1); + + // Since s16 values cannot be converted to f32 directly, + // they are converted to s32, then to f32 and the scale is 
performed + temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); + + temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); + temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); + + // Multiply the C matrix by the scale value + __m256 res_1 = _mm256_mul_ps(temp_float[0], scale_1); + __m256 res_2 = _mm256_mul_ps(temp_float[0], scale_2); + + // Round the resultant value to the nearest integer + res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + + // Convert float32 scaled rounded value to int32 + temp_32[0] = _mm256_cvtps_epi32(res_1); + temp_32[1] = _mm256_cvtps_epi32(res_2); + + // Convert the s32 to s16 + c_int16_0p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); + + // Permute to make sure the order is correct + c_int16_0p0 = _mm256_permute4x64_epi64(c_int16_0p0, 0XD8); + //-------------------------------------------------------------------------- + + // Extract the first 128 bits of the register + temp[0] = _mm256_extractf128_si256(c_int16_1p0, 0); + + // Extract the second 128 bits of the register + temp[1] = _mm256_extractf128_si256(c_int16_1p0, 1); + + // Since s16 values cannot be converted to f32 directly, + // they are converted to s32, then to f32 and the scale is performed + temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); + + temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); + temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); + + // Multiply the C matrix by the scale value + res_1 = _mm256_mul_ps(temp_float[0], scale_1); + res_2 = _mm256_mul_ps(temp_float[0], scale_2); + + // Round the resultant value to the nearest integer + res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + + // Convert float32 scaled rounded value to int32 + temp_32[0] = _mm256_cvtps_epi32(res_1); + temp_32[1] = _mm256_cvtps_epi32(res_2); + + // Convert the s32 to s16 + c_int16_1p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); + + // Permute to make sure the order is correct + c_int16_1p0 = _mm256_permute4x64_epi64(c_int16_1p0, 0XD8); + + __m256i store_reg = _mm256_packs_epi16(c_int16_0p0, c_int16_1p0); + + _mm256_storeu_si256((__m256i *)(c + (rs_c * 0) + (0 * 16)), store_reg); + + //-------------------------------------------------------------------------- + + // Extract the first 128 bits of the register + temp[0] = _mm256_extractf128_si256(c_int16_2p0, 0); + + // Extract the second 128 bits of the register + temp[1] = _mm256_extractf128_si256(c_int16_2p0, 1); + + // Since s16 values cannot be converted to f32 directly, + // they are converted to s32, then to f32 and the scale is performed + temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); + + temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); + temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); + + // Multiply the C matrix by the scale value + res_1 = _mm256_mul_ps(temp_float[0], scale_1); + res_2 = _mm256_mul_ps(temp_float[0], scale_2); + + // Round the resultant value to the nearest integer + res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + + // Convert float32 scaled rounded value to int32 + temp_32[0] = _mm256_cvtps_epi32(res_1); + temp_32[1] = _mm256_cvtps_epi32(res_2); + + // Convert the s32 to s16 + 
c_int16_2p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); + + // Permute to make sure the order is correct + c_int16_2p0 = _mm256_permute4x64_epi64(c_int16_2p0, 0XD8); + //-------------------------------------------------------------------------- + + // Extract the first 128 bits of the register + temp[0] = _mm256_extractf128_si256(c_int16_3p0, 0); + + // Extract the second 128 bits of the register + temp[1] = _mm256_extractf128_si256(c_int16_3p0, 1); + + // Since s16 values cannot be converted to f32 directly, + // they are converted to s32, then to f32 and the scale is performed + temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); + + temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); + temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); + + // Multiply the C matrix by the scale value + res_1 = _mm256_mul_ps(temp_float[0], scale_1); + res_2 = _mm256_mul_ps(temp_float[0], scale_2); + + // Round the resultant value to the nearest integer + res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + + // Convert float32 scaled rounded value to int32 + temp_32[0] = _mm256_cvtps_epi32(res_1); + temp_32[1] = _mm256_cvtps_epi32(res_2); + + // Convert the s32 to s16 + c_int16_3p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); + + // Permute to make sure the order is correct + c_int16_3p0 = _mm256_permute4x64_epi64(c_int16_3p0, 0XD8); + + store_reg = _mm256_packs_epi16(c_int16_2p0, c_int16_3p0); + + _mm256_storeu_si256((__m256i *)(c + (rs_c * 0) + (0 * 16)), store_reg); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_4x16_DISABLE: @@ -287,7 +446,8 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4xlt16) &&POST_OPS_4xlt16_DISABLE, &&POST_OPS_BIAS_4xlt16, &&POST_OPS_RELU_4xlt16, - &&POST_OPS_RELU_SCALE_4xlt16 + &&POST_OPS_RELU_SCALE_4xlt16, + &&POST_OPS_DOWNSCALE_4xlt16 }; // The division is done by considering the vpmaddubsw instruction @@ -512,6 +672,164 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4xlt16) // c[3,0-15] RELU_SCALE_OP_S16_AVX2(c_int16_3p0) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_4xlt16: + { + __m128i temp[2]; + __m256i temp_32[2]; + __m256 temp_float[2]; + + // Load the scale vector values into the register + __m256 scale_1 = + _mm256_loadu_ps( + (float *)post_ops_list_temp->scale_factor + + post_op_c_j + (0 * 8)); + __m256 scale_2 = + _mm256_loadu_ps( + (float *)post_ops_list_temp->scale_factor + + post_op_c_j + (1 * 8)); + + // Extract the first 128 bits of the register + temp[0] = _mm256_extractf128_si256(c_int16_0p0, 0); + + // Extract the second 128 bits of the register + temp[1] = _mm256_extractf128_si256(c_int16_0p0, 1); + + // Since s16 values cannot be converted to f32 directly, + // they are converted to s32, then to f32 and the scale is performed + temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); + + temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); + temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); + + // Multiply the C matrix by the scale value + __m256 res_1 = _mm256_mul_ps(temp_float[0], scale_1); + __m256 res_2 = _mm256_mul_ps(temp_float[0], scale_2); + + // Round the resultant value to the nearest integer + res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + + // Convert float32 scaled rounded value to 
int32 + temp_32[0] = _mm256_cvtps_epi32(res_1); + temp_32[1] = _mm256_cvtps_epi32(res_2); + + // Convert the s32 to s16 + c_int16_0p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); + + // Permute to make sure the order is correct + c_int16_0p0 = _mm256_permute4x64_epi64(c_int16_0p0, 0XD8); + //-------------------------------------------------------------------------- + + // Extract the first 128 bits of the register + temp[0] = _mm256_extractf128_si256(c_int16_1p0, 0); + + // Extract the second 128 bits of the register + temp[1] = _mm256_extractf128_si256(c_int16_1p0, 1); + + // Since s16 values cannot be converted to f32 directly, + // they are converted to s32, then to f32 and the scale is performed + temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); + + temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); + temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); + + // Multiply the C matrix by the scale value + res_1 = _mm256_mul_ps(temp_float[0], scale_1); + res_2 = _mm256_mul_ps(temp_float[0], scale_2); + + // Round the resultant value to the nearest integer + res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + + // Convert float32 scaled rounded value to int32 + temp_32[0] = _mm256_cvtps_epi32(res_1); + temp_32[1] = _mm256_cvtps_epi32(res_2); + + // Convert the s32 to s16 + c_int16_1p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); + + // Permute to make sure the order is correct + c_int16_1p0 = _mm256_permute4x64_epi64(c_int16_1p0, 0XD8); + + __m256i store_reg = _mm256_packs_epi16(c_int16_0p0, c_int16_1p0); + + _mm256_storeu_si256((__m256i *)(c + (rs_c * 0) + (0 * 16)), store_reg); + + //-------------------------------------------------------------------------- + + // Extract the first 128 bits of the register + temp[0] = _mm256_extractf128_si256(c_int16_2p0, 0); + + // Extract the second 128 bits of the register + temp[1] = _mm256_extractf128_si256(c_int16_2p0, 1); + + // Since s16 values cannot be converted to f32 directly, + // they are converted to s32, then to f32 and the scale is performed + temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); + + temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); + temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); + + // Multiply the C matrix by the scale value + res_1 = _mm256_mul_ps(temp_float[0], scale_1); + res_2 = _mm256_mul_ps(temp_float[0], scale_2); + + // Round the resultant value to the nearest integer + res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + + // Convert float32 scaled rounded value to int32 + temp_32[0] = _mm256_cvtps_epi32(res_1); + temp_32[1] = _mm256_cvtps_epi32(res_2); + + // Convert the s32 to s16 + c_int16_2p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); + + // Permute to make sure the order is correct + c_int16_2p0 = _mm256_permute4x64_epi64(c_int16_2p0, 0XD8); + //-------------------------------------------------------------------------- + + // Extract the first 128 bits of the register + temp[0] = _mm256_extractf128_si256(c_int16_3p0, 0); + + // Extract the second 128 bits of the register + temp[1] = _mm256_extractf128_si256(c_int16_3p0, 1); + + // Since s16 values cannot be converted to f32 directly, + // they are converted to s32, then to f32 and the scale is performed + temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); + 
temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); + + temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); + temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); + + // Multiply the C matrix by the scale value + res_1 = _mm256_mul_ps(temp_float[0], scale_1); + res_2 = _mm256_mul_ps(temp_float[0], scale_2); + + // Round the resultant value to the nearest integer + res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + + // Convert float32 scaled rounded value to int32 + temp_32[0] = _mm256_cvtps_epi32(res_1); + temp_32[1] = _mm256_cvtps_epi32(res_2); + + // Convert the s32 to s16 + c_int16_3p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); + + // Permute to make sure the order is correct + c_int16_3p0 = _mm256_permute4x64_epi64(c_int16_3p0, 0XD8); + + store_reg = _mm256_packs_epi16(c_int16_2p0, c_int16_3p0); + + _mm256_storeu_si256((__m256i *)(c + (rs_c * 0) + (0 * 16)), store_reg); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_4xlt16_DISABLE: @@ -551,7 +869,8 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2x16) &&POST_OPS_2x16_DISABLE, &&POST_OPS_BIAS_2x16, &&POST_OPS_RELU_2x16, - &&POST_OPS_RELU_SCALE_2x16 + &&POST_OPS_RELU_SCALE_2x16, + &&POST_OPS_DOWNSCALE_2x16 }; // The division is done by considering the vpmaddubsw instruction @@ -687,6 +1006,95 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2x16) // c[1,0-15] RELU_SCALE_OP_S16_AVX2(c_int16_1p0) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_2x16: + { + __m128i temp[2]; + __m256i temp_32[2]; + __m256 temp_float[2]; + + // Load the scale vector values into the register + __m256 scale_1 = + _mm256_loadu_ps( + (float *)post_ops_list_temp->scale_factor + + post_op_c_j + (0 * 8)); + __m256 scale_2 = + _mm256_loadu_ps( + (float *)post_ops_list_temp->scale_factor + + post_op_c_j + (1 * 8)); + + // Extract the first 128 bits of the register + temp[0] = _mm256_extractf128_si256(c_int16_0p0, 0); + + // Extract the second 128 bits of the register + temp[1] = _mm256_extractf128_si256(c_int16_0p0, 1); + + // Since s16 values cannot be converted to f32 directly, + // they are converted to s32, then to f32 and the scale is performed + temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); + + temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); + temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); + + // Multiply the C matrix by the scale value + __m256 res_1 = _mm256_mul_ps(temp_float[0], scale_1); + __m256 res_2 = _mm256_mul_ps(temp_float[0], scale_2); + + // Round the resultant value to the nearest integer + res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + + // Convert float32 scaled rounded value to int32 + temp_32[0] = _mm256_cvtps_epi32(res_1); + temp_32[1] = _mm256_cvtps_epi32(res_2); + + // Convert the s32 to s16 + c_int16_0p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); + + // Permute to make sure the order is correct + c_int16_0p0 = _mm256_permute4x64_epi64(c_int16_0p0, 0XD8); + //-------------------------------------------------------------------------- + + // Extract the first 128 bits of the register + temp[0] = _mm256_extractf128_si256(c_int16_1p0, 0); + + // Extract the second 128 bits of the register + temp[1] = _mm256_extractf128_si256(c_int16_1p0, 1); + + // Since s16 values cannot be converted to f32 directly, + // they are converted to 
s32, then to f32 and the scale is performed + temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); + + temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); + temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); + + // Multiply the C matrix by the scale value + res_1 = _mm256_mul_ps(temp_float[0], scale_1); + res_2 = _mm256_mul_ps(temp_float[0], scale_2); + + // Round the resultant value to the nearest integer + res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + + // Convert float32 scaled rounded value to int32 + temp_32[0] = _mm256_cvtps_epi32(res_1); + temp_32[1] = _mm256_cvtps_epi32(res_2); + + // Convert the s32 to s16 + c_int16_1p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); + + // Permute to make sure the order is correct + c_int16_1p0 = _mm256_permute4x64_epi64(c_int16_1p0, 0XD8); + + __m256i store_reg = _mm256_packs_epi16(c_int16_0p0, c_int16_1p0); + + _mm256_storeu_si256((__m256i *)(c + (rs_c * 0) + (0 * 16)), store_reg); + + //-------------------------------------------------------------------------- + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_2x16_DISABLE: @@ -710,7 +1118,8 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2xlt16) &&POST_OPS_2xlt16_DISABLE, &&POST_OPS_BIAS_2xlt16, &&POST_OPS_RELU_2xlt16, - &&POST_OPS_RELU_SCALE_2xlt16 + &&POST_OPS_RELU_SCALE_2xlt16, + &&POST_OPS_DOWNSCALE_2xlt16 }; // The division is done by considering the vpmaddubsw instruction @@ -854,6 +1263,95 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2xlt16) // c[1,0-15] RELU_SCALE_OP_S16_AVX2(c_int16_1p0) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_2xlt16: + { + __m128i temp[2]; + __m256i temp_32[2]; + __m256 temp_float[2]; + + // Load the scale vector values into the register + __m256 scale_1 = + _mm256_loadu_ps( + (float *)post_ops_list_temp->scale_factor + + post_op_c_j + (0 * 8)); + __m256 scale_2 = + _mm256_loadu_ps( + (float *)post_ops_list_temp->scale_factor + + post_op_c_j + (1 * 8)); + + // Extract the first 128 bits of the register + temp[0] = _mm256_extractf128_si256(c_int16_0p0, 0); + + // Extract the second 128 bits of the register + temp[1] = _mm256_extractf128_si256(c_int16_0p0, 1); + + // Since s16 values cannot be converted to f32 directly, + // they are converted to s32, then to f32 and the scale is performed + temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); + + temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); + temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); + + // Multiply the C matrix by the scale value + __m256 res_1 = _mm256_mul_ps(temp_float[0], scale_1); + __m256 res_2 = _mm256_mul_ps(temp_float[0], scale_2); + + // Round the resultant value to the nearest integer + res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + + // Convert float32 scaled rounded value to int32 + temp_32[0] = _mm256_cvtps_epi32(res_1); + temp_32[1] = _mm256_cvtps_epi32(res_2); + + // Convert the s32 to s16 + c_int16_0p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); + + // Permute to make sure the order is correct + c_int16_0p0 = _mm256_permute4x64_epi64(c_int16_0p0, 0XD8); + //-------------------------------------------------------------------------- + + // Extract the first 128 bits of the register + temp[0] = 
_mm256_extractf128_si256(c_int16_1p0, 0); + + // Extract the second 128 bits of the register + temp[1] = _mm256_extractf128_si256(c_int16_1p0, 1); + + // Since s16 values cannot be converted to f32 directly, + // they are converted to s32, then to f32 and the scale is performed + temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); + + temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); + temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); + + // Multiply the C matrix by the scale value + res_1 = _mm256_mul_ps(temp_float[0], scale_1); + res_2 = _mm256_mul_ps(temp_float[0], scale_2); + + // Round the resultant value to the nearest integer + res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + + // Convert float32 scaled rounded value to int32 + temp_32[0] = _mm256_cvtps_epi32(res_1); + temp_32[1] = _mm256_cvtps_epi32(res_2); + + // Convert the s32 to s16 + c_int16_1p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); + + // Permute to make sure the order is correct + c_int16_1p0 = _mm256_permute4x64_epi64(c_int16_1p0, 0XD8); + + __m256i store_reg = _mm256_packs_epi16(c_int16_0p0, c_int16_1p0); + + _mm256_storeu_si256((__m256i *)(c + (rs_c * 0) + (0 * 16)), store_reg); + + //-------------------------------------------------------------------------- + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_2xlt16_DISABLE: diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_n_fringe_amd256.c b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_n_fringe_amd256.c index 171e730e0a..f70d494c70 100644 --- a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_n_fringe_amd256.c +++ b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_n_fringe_amd256.c @@ -49,7 +49,8 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x16) &&POST_OPS_6x16_DISABLE, &&POST_OPS_BIAS_6x16, &&POST_OPS_RELU_6x16, - &&POST_OPS_RELU_SCALE_6x16 + &&POST_OPS_RELU_SCALE_6x16, + &&POST_OPS_DOWNSCALE_6x16 }; dim_t m_full_pieces = m0 / MR; @@ -344,6 +345,234 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x16) // c[5,0-15] RELU_SCALE_OP_S16_AVX2(c_int16_5p0) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_6x16: + { + __m128i temp[2]; + __m256i temp_32[2]; + __m256 temp_float[2]; + + // Load the scale vector values into the register + __m256 scale_1 = + _mm256_loadu_ps( + (float *)post_ops_list_temp->scale_factor + + post_op_c_j + (0 * 8)); + __m256 scale_2 = + _mm256_loadu_ps( + (float *)post_ops_list_temp->scale_factor + + post_op_c_j + (1 * 8)); + + // Extract the first 128 bits of the register + temp[0] = _mm256_extractf128_si256(c_int16_0p0, 0); + + // Extract the second 128 bits of the register + temp[1] = _mm256_extractf128_si256(c_int16_0p0, 1); + + // Since s16 values cannot be converted to f32 directly, + // they are converted to s32, then to f32 and the scale is performed + temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); + + temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); + temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); + + // Multiply the C matrix by the scale value + __m256 res_1 = _mm256_mul_ps(temp_float[0], scale_1); + __m256 res_2 = _mm256_mul_ps(temp_float[0], scale_2); + + // Round the resultant value to the nearest integer + res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + + // Convert float32 scaled rounded value 
to int32 + temp_32[0] = _mm256_cvtps_epi32(res_1); + temp_32[1] = _mm256_cvtps_epi32(res_2); + + // Convert the s32 to s16 + c_int16_0p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); + + //-------------------------------------------------------------------------- + + // Extract the first 128 bits of the register + temp[0] = _mm256_extractf128_si256(c_int16_1p0, 0); + + // Extract the second 128 bits of the register + temp[1] = _mm256_extractf128_si256(c_int16_1p0, 1); + + // Since s16 values cannot be converted to f32 directly, + // they are converted to s32, then to f32 and the scale is performed + temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); + + temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); + temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); + + // Multiply the C matrix by the scale value + res_1 = _mm256_mul_ps(temp_float[0], scale_1); + res_2 = _mm256_mul_ps(temp_float[0], scale_2); + + // Round the resultant value to the nearest integer + res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + + // Convert float32 scaled rounded value to int32 + temp_32[0] = _mm256_cvtps_epi32(res_1); + temp_32[1] = _mm256_cvtps_epi32(res_2); + + // Convert the s32 to s16 + c_int16_1p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); + + // Permute to make sure the order is correct + c_int16_1p0 = _mm256_permute4x64_epi64(c_int16_1p0, 0XD8); + + __m256i store_reg = _mm256_packs_epi16(c_int16_0p0, c_int16_1p0); + + _mm256_storeu_si256((__m256i *)(c + (rs_c * (ir + 0)) + (0 * 16)), store_reg); + + //-------------------------------------------------------------------------- + + // Extract the first 128 bits of the register + temp[0] = _mm256_extractf128_si256(c_int16_2p0, 0); + + // Extract the second 128 bits of the register + temp[1] = _mm256_extractf128_si256(c_int16_2p0, 1); + + // Since s16 values cannot be converted to f32 directly, + // they are converted to s32, then to f32 and the scale is performed + temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); + + temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); + temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); + + // Multiply the C matrix by the scale value + res_1 = _mm256_mul_ps(temp_float[0], scale_1); + res_2 = _mm256_mul_ps(temp_float[0], scale_2); + + // Round the resultant value to the nearest integer + res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + + // Convert float32 scaled rounded value to int32 + temp_32[0] = _mm256_cvtps_epi32(res_1); + temp_32[1] = _mm256_cvtps_epi32(res_2); + + // Convert the s32 to s16 + c_int16_2p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); + + // Permute to make sure the order is correct + c_int16_2p0 = _mm256_permute4x64_epi64(c_int16_2p0, 0XD8); + //-------------------------------------------------------------------------- + + // Extract the first 128 bits of the register + temp[0] = _mm256_extractf128_si256(c_int16_3p0, 0); + + // Extract the second 128 bits of the register + temp[1] = _mm256_extractf128_si256(c_int16_3p0, 1); + + // Since s16 values cannot be converted to f32 directly, + // they are converted to s32, then to f32 and the scale is performed + temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); + + temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); + 
temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); + + // Multiply the C matrix by the scale value + res_1 = _mm256_mul_ps(temp_float[0], scale_1); + res_2 = _mm256_mul_ps(temp_float[0], scale_2); + + // Round the resultant value to the nearest integer + res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + + // Convert float32 scaled rounded value to int32 + temp_32[0] = _mm256_cvtps_epi32(res_1); + temp_32[1] = _mm256_cvtps_epi32(res_2); + + // Convert the s32 to s16 + c_int16_3p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); + + // Permute to make sure the order is correct + c_int16_3p0 = _mm256_permute4x64_epi64(c_int16_3p0, 0XD8); + + store_reg = _mm256_packs_epi16(c_int16_2p0, c_int16_3p0); + + _mm256_storeu_si256((__m256i *)(c + (rs_c * (ir + 0)) + (0 * 16)), store_reg); + + //-------------------------------------------------------------------------- + + // Extract the first 128 bits of the register + temp[0] = _mm256_extractf128_si256(c_int16_4p0, 0); + + // Extract the second 128 bits of the register + temp[1] = _mm256_extractf128_si256(c_int16_4p0, 1); + + // Since s16 values cannot be converted to f32 directly, + // they are converted to s32, then to f32 and the scale is performed + temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); + + temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); + temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); + + // Multiply the C matrix by the scale value + res_1 = _mm256_mul_ps(temp_float[0], scale_1); + res_2 = _mm256_mul_ps(temp_float[0], scale_2); + + // Round the resultant value to the nearest integer + res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + + // Convert float32 scaled rounded value to int32 + temp_32[0] = _mm256_cvtps_epi32(res_1); + temp_32[1] = _mm256_cvtps_epi32(res_2); + + // Convert the s32 to s16 + c_int16_4p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); + + // Permute to make sure the order is correct + c_int16_4p0 = _mm256_permute4x64_epi64(c_int16_4p0, 0XD8); + + //-------------------------------------------------------------------------- + + // Extract the first 128 bits of the register + temp[0] = _mm256_extractf128_si256(c_int16_5p0, 0); + + // Extract the second 128 bits of the register + temp[1] = _mm256_extractf128_si256(c_int16_5p0, 1); + + // Since s16 values cannot be converted to f32 directly, + // they are converted to s32, then to f32 and the scale is performed + temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); + + temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); + temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); + + // Multiply the C matrix by the scale value + res_1 = _mm256_mul_ps(temp_float[0], scale_1); + res_2 = _mm256_mul_ps(temp_float[0], scale_2); + + // Round the resultant value to the nearest integer + res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + + // Convert float32 scaled rounded value to int32 + temp_32[0] = _mm256_cvtps_epi32(res_1); + temp_32[1] = _mm256_cvtps_epi32(res_2); + + // Convert the s32 to s16 + c_int16_5p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); + + // Permute to make sure the order is correct + c_int16_5p0 = _mm256_permute4x64_epi64(c_int16_5p0, 0XD8); + + 
store_reg = _mm256_packs_epi16(c_int16_4p0, c_int16_4p0); + + _mm256_storeu_si256((__m256i *)(c + (rs_c * (ir + 0)) + (0 * 16)), store_reg); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_6x16_DISABLE: @@ -439,7 +668,8 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6xlt16) &&POST_OPS_6xlt16_DISABLE, &&POST_OPS_BIAS_6xlt16, &&POST_OPS_RELU_6xlt16, - &&POST_OPS_RELU_SCALE_6xlt16 + &&POST_OPS_RELU_SCALE_6xlt16, + &&POST_OPS_DOWNSCALE_6xlt16 }; dim_t m_full_pieces = m0 / MR; @@ -750,6 +980,234 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6xlt16) // c[5,0-15] RELU_SCALE_OP_S16_AVX2(c_int16_5p0) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_DOWNSCALE_6xlt16: + { + __m128i temp[2]; + __m256i temp_32[2]; + __m256 temp_float[2]; + + // Load the scale vector values into the register + __m256 scale_1 = + _mm256_loadu_ps( + (float *)post_ops_list_temp->scale_factor + + post_op_c_j + (0 * 8)); + __m256 scale_2 = + _mm256_loadu_ps( + (float *)post_ops_list_temp->scale_factor + + post_op_c_j + (1 * 8)); + + // Extract the first 128 bits of the register + temp[0] = _mm256_extractf128_si256(c_int16_0p0, 0); + + // Extract the second 128 bits of the register + temp[1] = _mm256_extractf128_si256(c_int16_0p0, 1); + + // Since s16 values cannot be converted to f32 directly, + // they are converted to s32, then to f32 and the scale is performed + temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); + + temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); + temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); + + // Multiply the C matrix by the scale value + __m256 res_1 = _mm256_mul_ps(temp_float[0], scale_1); + __m256 res_2 = _mm256_mul_ps(temp_float[0], scale_2); + + // Round the resultant value to the nearest integer + res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + + // Convert float32 scaled rounded value to int32 + temp_32[0] = _mm256_cvtps_epi32(res_1); + temp_32[1] = _mm256_cvtps_epi32(res_2); + + // Convert the s32 to s16 + c_int16_0p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); + + //-------------------------------------------------------------------------- + + // Extract the first 128 bits of the register + temp[0] = _mm256_extractf128_si256(c_int16_1p0, 0); + + // Extract the second 128 bits of the register + temp[1] = _mm256_extractf128_si256(c_int16_1p0, 1); + + // Since s16 values cannot be converted to f32 directly, + // they are converted to s32, then to f32 and the scale is performed + temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); + + temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); + temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); + + // Multiply the C matrix by the scale value + res_1 = _mm256_mul_ps(temp_float[0], scale_1); + res_2 = _mm256_mul_ps(temp_float[0], scale_2); + + // Round the resultant value to the nearest integer + res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + + // Convert float32 scaled rounded value to int32 + temp_32[0] = _mm256_cvtps_epi32(res_1); + temp_32[1] = _mm256_cvtps_epi32(res_2); + + // Convert the s32 to s16 + c_int16_1p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); + + // Permute to make sure the order is correct + c_int16_1p0 = _mm256_permute4x64_epi64(c_int16_1p0, 0XD8); + + __m256i 
store_reg = _mm256_packs_epi16(c_int16_0p0, c_int16_1p0); + + _mm256_storeu_si256((__m256i *)(c + (rs_c * (ir + 0)) + (0 * 16)), store_reg); + + //-------------------------------------------------------------------------- + + // Extract the first 128 bits of the register + temp[0] = _mm256_extractf128_si256(c_int16_2p0, 0); + + // Extract the second 128 bits of the register + temp[1] = _mm256_extractf128_si256(c_int16_2p0, 1); + + // Since s16 values cannot be converted to f32 directly, + // they are converted to s32, then to f32 and the scale is performed + temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); + + temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); + temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); + + // Multiply the C matrix by the scale value + res_1 = _mm256_mul_ps(temp_float[0], scale_1); + res_2 = _mm256_mul_ps(temp_float[0], scale_2); + + // Round the resultant value to the nearest integer + res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + + // Convert float32 scaled rounded value to int32 + temp_32[0] = _mm256_cvtps_epi32(res_1); + temp_32[1] = _mm256_cvtps_epi32(res_2); + + // Convert the s32 to s16 + c_int16_2p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); + + // Permute to make sure the order is correct + c_int16_2p0 = _mm256_permute4x64_epi64(c_int16_2p0, 0XD8); + //-------------------------------------------------------------------------- + + // Extract the first 128 bits of the register + temp[0] = _mm256_extractf128_si256(c_int16_3p0, 0); + + // Extract the second 128 bits of the register + temp[1] = _mm256_extractf128_si256(c_int16_3p0, 1); + + // Since s16 values cannot be converted to f32 directly, + // they are converted to s32, then to f32 and the scale is performed + temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); + + temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); + temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); + + // Multiply the C matrix by the scale value + res_1 = _mm256_mul_ps(temp_float[0], scale_1); + res_2 = _mm256_mul_ps(temp_float[0], scale_2); + + // Round the resultant value to the nearest integer + res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + + // Convert float32 scaled rounded value to int32 + temp_32[0] = _mm256_cvtps_epi32(res_1); + temp_32[1] = _mm256_cvtps_epi32(res_2); + + // Convert the s32 to s16 + c_int16_3p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); + + // Permute to make sure the order is correct + c_int16_3p0 = _mm256_permute4x64_epi64(c_int16_3p0, 0XD8); + + store_reg = _mm256_packs_epi16(c_int16_2p0, c_int16_3p0); + + _mm256_storeu_si256((__m256i *)(c + (rs_c * (ir + 0)) + (0 * 16)), store_reg); + + //-------------------------------------------------------------------------- + + // Extract the first 128 bits of the register + temp[0] = _mm256_extractf128_si256(c_int16_4p0, 0); + + // Extract the second 128 bits of the register + temp[1] = _mm256_extractf128_si256(c_int16_4p0, 1); + + // Since s16 values cannot be converted to f32 directly, + // they are converted to s32, then to f32 and the scale is performed + temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); + + temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); + temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); + + // 
Multiply the C matrix by the scale value + res_1 = _mm256_mul_ps(temp_float[0], scale_1); + res_2 = _mm256_mul_ps(temp_float[0], scale_2); + + // Round the resultant value to the nearest integer + res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + + // Convert float32 scaled rounded value to int32 + temp_32[0] = _mm256_cvtps_epi32(res_1); + temp_32[1] = _mm256_cvtps_epi32(res_2); + + // Convert the s32 to s16 + c_int16_4p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); + + // Permute to make sure the order is correct + c_int16_4p0 = _mm256_permute4x64_epi64(c_int16_4p0, 0XD8); + + //-------------------------------------------------------------------------- + + // Extract the first 128 bits of the register + temp[0] = _mm256_extractf128_si256(c_int16_5p0, 0); + + // Extract the second 128 bits of the register + temp[1] = _mm256_extractf128_si256(c_int16_5p0, 1); + + // Since s16 values cannot be converted to f32 directly, + // they are converted to s32, then to f32 and the scale is performed + temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); + + temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); + temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); + + // Multiply the C matrix by the scale value + res_1 = _mm256_mul_ps(temp_float[0], scale_1); + res_2 = _mm256_mul_ps(temp_float[0], scale_2); + + // Round the resultant value to the nearest integer + res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + + // Convert float32 scaled rounded value to int32 + temp_32[0] = _mm256_cvtps_epi32(res_1); + temp_32[1] = _mm256_cvtps_epi32(res_2); + + // Convert the s32 to s16 + c_int16_5p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); + + // Permute to make sure the order is correct + c_int16_5p0 = _mm256_permute4x64_epi64(c_int16_5p0, 0XD8); + + store_reg = _mm256_packs_epi16(c_int16_4p0, c_int16_4p0); + + _mm256_storeu_si256((__m256i *)(c + (rs_c * (ir + 0)) + (0 * 16)), store_reg); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_6xlt16_DISABLE: From 32a9e735f1c9d3d4b391608462ea761d841c7a65 Mon Sep 17 00:00:00 2001 From: eashdash Date: Tue, 30 Aug 2022 17:27:16 +0000 Subject: [PATCH 214/243] BF16 Output downscaling functionality - BF16 instructions output is accumulated at a higher precision of FP32 which needs to be converted to a lower precison of bf16 post the GEMM operations. This is required in AI workloads where both input and output are in BF16 format. - BF16 downscaling is implemented as post-ops inside the GEMM microkernels. 
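An illustrative scalar sketch of the downscale step described above (not part of this patch): it only mirrors the round-to-nearest-even behaviour of the VCVTNEPS2BF16 instruction (_mm512_cvtneps_pbh) that the CVT_F32_BF16 macros wrap, the helper name f32_to_bf16_rne is invented for illustration, and NaN inputs are not special-cased here.

    #include <stdint.h>
    #include <string.h>

    /* Truncate one FP32 accumulator value to its BF16 storage bits,
       rounding to nearest with ties to even (sketch only; NaN handling
       omitted). */
    static inline uint16_t f32_to_bf16_rne( float f )
    {
        uint32_t bits;
        memcpy( &bits, &f, sizeof( bits ) );

        /* Bias the value so that dropping the low 16 mantissa bits
           rounds to nearest-even instead of truncating. */
        uint32_t rounding_bias = 0x00007FFF + ( ( bits >> 16 ) & 1 );

        return ( uint16_t )( ( bits + rounding_bias ) >> 16 );
    }

The vector kernels perform the same reduction 16 lanes at a time and store the resulting bfloat16 elements to the downscale output buffer instead of the FP32 C matrix.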
Change-Id: Id1606746e3db4f3ed88cba385a7709c8604002a8 --- .../lpgemm_6x64rowmajor_bf16_amd512vnni.c | 82 +++- .../bf16bf16f32/lpgemm_f32_kern_macros.h | 21 + .../lpgemm_m_fringe_bf16_amd512vnni.c | 216 ++++++++- .../lpgemm_mn_fringe_bf16_amd512vnni.c | 457 +++++++++++++++++- .../lpgemm_n_fringe_bf16_amd512vnni.c | 154 +++++- 5 files changed, 898 insertions(+), 32 deletions(-) diff --git a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_6x64rowmajor_bf16_amd512vnni.c b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_6x64rowmajor_bf16_amd512vnni.c index 81d15c4e04..87f713705e 100644 --- a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_6x64rowmajor_bf16_amd512vnni.c +++ b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_6x64rowmajor_bf16_amd512vnni.c @@ -46,7 +46,8 @@ LPGEMM_MAIN_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x64) &&POST_OPS_6x64_DISABLE, &&POST_OPS_BIAS_6x64, &&POST_OPS_RELU_6x64, - &&POST_OPS_RELU_SCALE_6x64 + &&POST_OPS_RELU_SCALE_6x64, + &&POST_OPS_DOWNSCALE_6x64 }; dim_t MR = 6; dim_t NR = 64; @@ -792,7 +793,84 @@ LPGEMM_MAIN_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x64) RELU_SCALE_OP_F32_AVX512(c_float_5p3) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR - } + } +POST_OPS_DOWNSCALE_6x64: +{ + // c[0, 0-15] + CVT_F32_BF16(c_float_0p0,0,0); + + // c[0, 16-31] + CVT_F32_BF16(c_float_0p1,0,1); + + // c[0, 32-47] + CVT_F32_BF16(c_float_0p2,0,2); + + // c[0, 48-63] + CVT_F32_BF16(c_float_0p3,0,3); + + // c[1, 0-15] + CVT_F32_BF16(c_float_1p0,1,0); + + // c[1, 16-31] + CVT_F32_BF16(c_float_1p1,1,1); + + // c[1, 32-47] + CVT_F32_BF16(c_float_1p2,1,2); + + // c[1, 48-63] + CVT_F32_BF16(c_float_1p3,1,3); + + // c[2, 0-15] + CVT_F32_BF16(c_float_2p0,2,0); + + // c[2, 16-31] + CVT_F32_BF16(c_float_2p1,2,1); + + // c[2, 32-47] + CVT_F32_BF16(c_float_2p2,2,2); + + // c[2, 48-63] + CVT_F32_BF16(c_float_2p3,2,3); + + // c[3, 0-15] + CVT_F32_BF16(c_float_3p0,3,0); + + // c[3, 16-31] + CVT_F32_BF16(c_float_3p1,3,1); + + // c[3, 32-47] + CVT_F32_BF16(c_float_3p2,3,2); + + // c[3, 48-63] + CVT_F32_BF16(c_float_3p3,3,3); + + // c[4, 0-15] + CVT_F32_BF16(c_float_4p0,4,0); + + // c[4, 16-31] + CVT_F32_BF16(c_float_4p1,4,1); + + // c[4, 32-47] + CVT_F32_BF16(c_float_4p2,4,2); + + // c[4, 48-63] + CVT_F32_BF16(c_float_4p3,4,3); + + // c[5, 0-15] + CVT_F32_BF16(c_float_5p0,5,0); + + // c[5, 16-31] + CVT_F32_BF16(c_float_5p1,5,1); + + // c[5, 32-47] + CVT_F32_BF16(c_float_5p2,5,2); + + // c[5, 48-63] + CVT_F32_BF16(c_float_5p3,5,3); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR +} + POST_OPS_6x64_DISABLE: ; diff --git a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_f32_kern_macros.h b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_f32_kern_macros.h index 5ed1c6350f..88b449cfd5 100644 --- a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_f32_kern_macros.h +++ b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_f32_kern_macros.h @@ -31,6 +31,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
*/ +#include "aocl_bf16_type.h" #ifndef LPGEMM_F32_KERN_MACROS_H #define LPGEMM_F32_KERN_MACROS_H @@ -42,4 +43,24 @@ /* Apply scaling on for <= 0 elements.*/ \ reg = _mm512_mask_mul_ps( reg, relu_cmp_mask, reg, selector2 ); \ +#define CVT_F32_BF16(reg,m_ind,n_ind) \ + _mm256_storeu_epi16 \ + ( \ + ( bfloat16* )post_ops_list_temp->op_args3 + \ + ( rs_c_downscale * ( post_op_c_i + m_ind ) ) + post_op_c_j + ( n_ind * 16 ), \ + (__m256i) \ + _mm512_cvtneps_pbh( reg ) \ + ) \ + +#define CVT_F32_BF16_LT16(reg,m_ind,n_ind) \ + _mm256_storeu_epi16 \ + ( \ + buf0, \ + (__m256i) \ + _mm512_cvtneps_pbh( reg ) \ + ); \ + memcpy( ( bfloat16* )post_ops_list_temp->op_args3 + \ + ( rs_c_downscale * ( post_op_c_i + m_ind ) ) + post_op_c_j + \ + ( n_ind * 16 ) , buf0, ( n0_rem * sizeof( bfloat16 ) ) ); \ + #endif // LPGEMM_F32_KERN_MACROS_H \ No newline at end of file diff --git a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_m_fringe_bf16_amd512vnni.c b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_m_fringe_bf16_amd512vnni.c index f88aaa9106..073668ba1f 100644 --- a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_m_fringe_bf16_amd512vnni.c +++ b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_m_fringe_bf16_amd512vnni.c @@ -47,7 +47,8 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x64) &&POST_OPS_5x64_DISABLE, &&POST_OPS_BIAS_5x64, &&POST_OPS_RELU_5x64, - &&POST_OPS_RELU_SCALE_5x64 + &&POST_OPS_RELU_SCALE_5x64, + &&POST_OPS_DOWNSCALE_5x64 }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -595,6 +596,70 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x64) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_DOWNSCALE_5x64: + { + // c[0, 0-15] + CVT_F32_BF16(c_float_0p0,0,0); + + // c[0, 16-31] + CVT_F32_BF16(c_float_0p1,0,1); + + // c[0, 32-47] + CVT_F32_BF16(c_float_0p2,0,2); + + // c[0, 48-63] + CVT_F32_BF16(c_float_0p3,0,3); + + // c[1, 0-15] + CVT_F32_BF16(c_float_1p0,1,0); + + // c[1, 16-31] + CVT_F32_BF16(c_float_1p1,1,1); + + // c[1, 32-47] + CVT_F32_BF16(c_float_1p2,1,2); + + // c[1, 48-63] + CVT_F32_BF16(c_float_1p3,1,3); + + // c[2, 0-15] + CVT_F32_BF16(c_float_2p0,2,0); + + // c[2, 16-31] + CVT_F32_BF16(c_float_2p1,2,1); + + // c[2, 32-47] + CVT_F32_BF16(c_float_2p2,2,2); + + // c[2, 48-63] + CVT_F32_BF16(c_float_2p3,2,3); + + // c[3, 0-15] + CVT_F32_BF16(c_float_3p0,3,0); + + // c[3, 16-31] + CVT_F32_BF16(c_float_3p1,3,1); + + // c[3, 32-47] + CVT_F32_BF16(c_float_3p2,3,2); + + // c[3, 48-63] + CVT_F32_BF16(c_float_3p3,3,3); + + // c[4, 0-15] + CVT_F32_BF16(c_float_4p0,4,0); + + // c[4, 16-31] + CVT_F32_BF16(c_float_4p1,4,1); + + // c[4, 32-47] + CVT_F32_BF16(c_float_4p2,4,2); + + // c[4, 48-63] + CVT_F32_BF16(c_float_4p3,4,3); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_5x64_DISABLE: ; @@ -668,7 +733,8 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x64) &&POST_OPS_4x64_DISABLE, &&POST_OPS_BIAS_4x64, &&POST_OPS_RELU_4x64, - &&POST_OPS_RELU_SCALE_4x64 + &&POST_OPS_RELU_SCALE_4x64, + &&POST_OPS_DOWNSCALE_4x64 }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -1123,6 +1189,59 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x64) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_DOWNSCALE_4x64: + { + // c[0, 0-15] + CVT_F32_BF16(c_float_0p0,0,0); + + // c[0, 16-31] + CVT_F32_BF16(c_float_0p1,0,1); + + // c[0, 32-47] + CVT_F32_BF16(c_float_0p2,0,2); + + // c[0, 48-63] + CVT_F32_BF16(c_float_0p3,0,3); + + // c[1, 0-15] + CVT_F32_BF16(c_float_1p0,1,0); + + // c[1, 16-31] + 
CVT_F32_BF16(c_float_1p1,1,1); + + // c[1, 32-47] + CVT_F32_BF16(c_float_1p2,1,2); + + // c[1, 48-63] + CVT_F32_BF16(c_float_1p3,1,3); + + // c[2, 0-15] + CVT_F32_BF16(c_float_2p0,2,0); + + // c[2, 16-31] + CVT_F32_BF16(c_float_2p1,2,1); + + // c[2, 32-47] + CVT_F32_BF16(c_float_2p2,2,2); + + // c[2, 48-63] + CVT_F32_BF16(c_float_2p3,2,3); + + // c[3, 0-15] + CVT_F32_BF16(c_float_3p0,3,0); + + // c[3, 16-31] + CVT_F32_BF16(c_float_3p1,3,1); + + // c[3, 32-47] + CVT_F32_BF16(c_float_3p2,3,2); + + // c[3, 48-63] + CVT_F32_BF16(c_float_3p3,3,3); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_4x64_DISABLE: ; @@ -1184,7 +1303,8 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x64) &&POST_OPS_3x64_DISABLE, &&POST_OPS_BIAS_3x64, &&POST_OPS_RELU_3x64, - &&POST_OPS_RELU_SCALE_3x64 + &&POST_OPS_RELU_SCALE_3x64, + &&POST_OPS_DOWNSCALE_3x64 }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -1545,6 +1665,46 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x64) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_DOWNSCALE_3x64: + { + // c[0, 0-15] + CVT_F32_BF16(c_float_0p0,0,0); + + // c[0, 16-31] + CVT_F32_BF16(c_float_0p1,0,1); + + // c[0, 32-47] + CVT_F32_BF16(c_float_0p2,0,2); + + // c[0, 48-63] + CVT_F32_BF16(c_float_0p3,0,3); + + // c[1, 0-15] + CVT_F32_BF16(c_float_1p0,1,0); + + // c[1, 16-31] + CVT_F32_BF16(c_float_1p1,1,1); + + // c[1, 32-47] + CVT_F32_BF16(c_float_1p2,1,2); + + // c[1, 48-63] + CVT_F32_BF16(c_float_1p3,1,3); + + // c[2, 0-15] + CVT_F32_BF16(c_float_2p0,2,0); + + // c[2, 16-31] + CVT_F32_BF16(c_float_2p1,2,1); + + // c[2, 32-47] + CVT_F32_BF16(c_float_2p2,2,2); + + // c[2, 48-63] + CVT_F32_BF16(c_float_2p3,2,3); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_3x64_DISABLE: ; @@ -1594,7 +1754,8 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x64) &&POST_OPS_2x64_DISABLE, &&POST_OPS_BIAS_2x64, &&POST_OPS_RELU_2x64, - &&POST_OPS_RELU_SCALE_2x64 + &&POST_OPS_RELU_SCALE_2x64, + &&POST_OPS_DOWNSCALE_2x64 }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -1859,6 +2020,34 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x64) // c[1, 48-63] RELU_SCALE_OP_F32_AVX512(c_float_1p3) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } + POST_OPS_DOWNSCALE_2x64: + { + // c[0, 0-15] + CVT_F32_BF16(c_float_0p0,0,0); + + // c[0, 16-31] + CVT_F32_BF16(c_float_0p1,0,1); + + // c[0, 32-47] + CVT_F32_BF16(c_float_0p2,0,2); + + // c[0, 48-63] + CVT_F32_BF16(c_float_0p3,0,3); + + // c[1, 0-15] + CVT_F32_BF16(c_float_1p0,1,0); + + // c[1, 16-31] + CVT_F32_BF16(c_float_1p1,1,1); + + // c[1, 32-47] + CVT_F32_BF16(c_float_1p2,1,2); + + // c[1, 48-63] + CVT_F32_BF16(c_float_1p3,1,3); + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_2x64_DISABLE: @@ -1898,7 +2087,8 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1x64) &&POST_OPS_1x64_DISABLE, &&POST_OPS_BIAS_1x64, &&POST_OPS_RELU_1x64, - &&POST_OPS_RELU_SCALE_1x64 + &&POST_OPS_RELU_SCALE_1x64, + &&POST_OPS_DOWNSCALE_1x64 }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -2061,6 +2251,22 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1x64) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_DOWNSCALE_1x64: + { + // c[0, 0-15] + CVT_F32_BF16(c_float_0p0,0,0); + + // c[0, 16-31] + CVT_F32_BF16(c_float_0p1,0,1); + + // c[0, 32-47] + CVT_F32_BF16(c_float_0p2,0,2); + + // c[0, 48-63] + CVT_F32_BF16(c_float_0p3,0,3); + + 
POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_1x64_DISABLE: ; diff --git a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_mn_fringe_bf16_amd512vnni.c b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_mn_fringe_bf16_amd512vnni.c index 6e864b45c3..77dd30acd2 100644 --- a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_mn_fringe_bf16_amd512vnni.c +++ b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_mn_fringe_bf16_amd512vnni.c @@ -47,7 +47,8 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5xlt16) &&POST_OPS_5xLT16_DISABLE, &&POST_OPS_BIAS_5xLT16, &&POST_OPS_RELU_5xLT16, - &&POST_OPS_RELU_SCALE_5xLT16 + &&POST_OPS_RELU_SCALE_5xLT16, + &&POST_OPS_DOWNSCALE_5xLT16 }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -284,6 +285,25 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5xlt16) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_DOWNSCALE_5xLT16: + { + // c[0, 0-15] + CVT_F32_BF16_LT16(c_float_0p0,0,0); + + // c[1, 0-15] + CVT_F32_BF16_LT16(c_float_1p0,1,0); + + // c[2, 0-15] + CVT_F32_BF16_LT16(c_float_2p0,2,0); + + // c[3, 0-15] + CVT_F32_BF16_LT16(c_float_3p0,3,0); + + // c[4, 0-15] + CVT_F32_BF16_LT16(c_float_4p0,4,0); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_5xLT16_DISABLE: ; @@ -329,7 +349,8 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4xlt16) &&POST_OPS_4xLT16_DISABLE, &&POST_OPS_BIAS_4xLT16, &&POST_OPS_RELU_4xLT16, - &&POST_OPS_RELU_SCALE_4xLT16 + &&POST_OPS_RELU_SCALE_4xLT16, + &&POST_OPS_DOWNSCALE_4xLT16 }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -533,6 +554,22 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4xlt16) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_DOWNSCALE_4xLT16: + { + // c[0, 0-15] + CVT_F32_BF16_LT16(c_float_0p0,0,0); + + // c[1, 0-15] + CVT_F32_BF16_LT16(c_float_1p0,1,0); + + // c[2, 0-15] + CVT_F32_BF16_LT16(c_float_2p0,2,0); + + // c[3, 0-15] + CVT_F32_BF16_LT16(c_float_3p0,3,0); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_4xLT16_DISABLE: ; @@ -572,7 +609,8 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3xlt16) &&POST_OPS_3xLT16_DISABLE, &&POST_OPS_BIAS_3xLT16, &&POST_OPS_RELU_3xLT16, - &&POST_OPS_RELU_SCALE_3xLT16 + &&POST_OPS_RELU_SCALE_3xLT16, + &&POST_OPS_DOWNSCALE_3xLT16 }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -740,6 +778,19 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3xlt16) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_DOWNSCALE_3xLT16: + { + // c[0, 0-15] + CVT_F32_BF16_LT16(c_float_0p0,0,0); + + // c[1, 0-15] + CVT_F32_BF16_LT16(c_float_1p0,1,0); + + // c[2, 0-15] + CVT_F32_BF16_LT16(c_float_2p0,2,0); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_3xLT16_DISABLE: ; @@ -773,7 +824,8 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2xlt16) &&POST_OPS_2xLT16_DISABLE, &&POST_OPS_BIAS_2xLT16, &&POST_OPS_RELU_2xLT16, - &&POST_OPS_RELU_SCALE_2xLT16 + &&POST_OPS_RELU_SCALE_2xLT16, + &&POST_OPS_DOWNSCALE_2xLT16 }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -906,6 +958,16 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2xlt16) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_DOWNSCALE_2xLT16: + { + // c[0, 0-15] + CVT_F32_BF16_LT16(c_float_0p0,0,0); + + // c[1, 0-15] + CVT_F32_BF16_LT16(c_float_1p0,1,0); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } 
POST_OPS_2xLT16_DISABLE: ; @@ -933,7 +995,8 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1xlt16) &&POST_OPS_1xLT16_DISABLE, &&POST_OPS_BIAS_1xLT16, &&POST_OPS_RELU_1xLT16, - &&POST_OPS_RELU_SCALE_1xLT16 + &&POST_OPS_RELU_SCALE_1xLT16, + &&POST_OPS_DOWNSCALE_1xLT16 }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -1031,6 +1094,13 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1xlt16) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_DOWNSCALE_1xLT16: + { + // c[0, 0-15] + CVT_F32_BF16_LT16(c_float_0p0,0,0); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_1xLT16_DISABLE: ; @@ -1052,7 +1122,8 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x16) &&POST_OPS_5x16_DISABLE, &&POST_OPS_BIAS_5x16, &&POST_OPS_RELU_5x16, - &&POST_OPS_RELU_SCALE_5x16 + &&POST_OPS_RELU_SCALE_5x16, + &&POST_OPS_DOWNSCALE_5x16 }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -1276,6 +1347,25 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x16) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_DOWNSCALE_5x16: + { + // c[0, 0-15] + CVT_F32_BF16(c_float_0p0,0,0); + + // c[1, 0-15] + CVT_F32_BF16(c_float_1p0,1,0); + + // c[2, 0-15] + CVT_F32_BF16(c_float_2p0,2,0); + + // c[3, 0-15] + CVT_F32_BF16(c_float_3p0,3,0); + + // c[4, 0-15] + CVT_F32_BF16(c_float_4p0,4,0); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_5x16_DISABLE: ; @@ -1304,7 +1394,8 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x16) &&POST_OPS_4x16_DISABLE, &&POST_OPS_BIAS_4x16, &&POST_OPS_RELU_4x16, - &&POST_OPS_RELU_SCALE_4x16 + &&POST_OPS_RELU_SCALE_4x16, + &&POST_OPS_DOWNSCALE_4x16 }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -1495,6 +1586,22 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x16) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_DOWNSCALE_4x16: + { + // c[0, 0-15] + CVT_F32_BF16(c_float_0p0,0,0); + + // c[1, 0-15] + CVT_F32_BF16(c_float_1p0,1,0); + + // c[2, 0-15] + CVT_F32_BF16(c_float_2p0,2,0); + + // c[3, 0-15] + CVT_F32_BF16(c_float_3p0,3,0); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_4x16_DISABLE: ; @@ -1520,7 +1627,8 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x16) &&POST_OPS_3x16_DISABLE, &&POST_OPS_BIAS_3x16, &&POST_OPS_RELU_3x16, - &&POST_OPS_RELU_SCALE_3x16 + &&POST_OPS_RELU_SCALE_3x16, + &&POST_OPS_DOWNSCALE_3x16 }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -1678,10 +1786,23 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x16) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_DOWNSCALE_3x16: + { + // c[0, 0-15] + CVT_F32_BF16(c_float_0p0,0,0); + + // c[1, 0-15] + CVT_F32_BF16(c_float_1p0,1,0); + + // c[2, 0-15] + CVT_F32_BF16(c_float_2p0,2,0); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_3x16_DISABLE: ; - + // Store the results. 
// c[0,0-15] _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 0*16 ), c_float_0p0 ); @@ -1701,7 +1822,8 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x16) &&POST_OPS_2x16_DISABLE, &&POST_OPS_BIAS_2x16, &&POST_OPS_RELU_2x16, - &&POST_OPS_RELU_SCALE_2x16 + &&POST_OPS_RELU_SCALE_2x16, + &&POST_OPS_DOWNSCALE_2x16 }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -1826,6 +1948,16 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x16) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_DOWNSCALE_2x16: + { + // c[0, 0-15] + CVT_F32_BF16(c_float_0p0,0,0); + + // c[1, 0-15] + CVT_F32_BF16(c_float_1p0,1,0); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_2x16_DISABLE: ; @@ -1845,7 +1977,8 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1x16) &&POST_OPS_1x16_DISABLE, &&POST_OPS_BIAS_1x16, &&POST_OPS_RELU_1x16, - &&POST_OPS_RELU_SCALE_1x16 + &&POST_OPS_RELU_SCALE_1x16, + &&POST_OPS_DOWNSCALE_1x16 }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -1937,6 +2070,13 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1x16) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_DOWNSCALE_1x16: + { + // c[0, 0-15] + CVT_F32_BF16(c_float_0p0,0,0); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_1x16_DISABLE: ; @@ -1953,7 +2093,8 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x32) &&POST_OPS_5x32_DISABLE, &&POST_OPS_BIAS_5x32, &&POST_OPS_RELU_5x32, - &&POST_OPS_RELU_SCALE_5x32 + &&POST_OPS_RELU_SCALE_5x32, + &&POST_OPS_DOWNSCALE_5x32 }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -2273,6 +2414,40 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x32) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_DOWNSCALE_5x32: + { + // c[0, 0-15] + CVT_F32_BF16(c_float_0p0,0,0); + + // c[0, 16-31] + CVT_F32_BF16(c_float_0p1,0,1); + + // c[1, 0-15] + CVT_F32_BF16(c_float_1p0,1,0); + + // c[1, 16-31] + CVT_F32_BF16(c_float_1p1,1,1); + + // c[2, 0-15] + CVT_F32_BF16(c_float_2p0,2,0); + + // c[2, 16-31] + CVT_F32_BF16(c_float_2p1,2,1); + + // c[3, 0-15] + CVT_F32_BF16(c_float_3p0,3,0); + + // c[3, 16-31] + CVT_F32_BF16(c_float_3p1,3,1); + + // c[4, 0-15] + CVT_F32_BF16(c_float_4p0,4,0); + + // c[4, 16-31] + CVT_F32_BF16(c_float_4p1,4,1); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_5x32_DISABLE: ; @@ -2316,7 +2491,8 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x32) &&POST_OPS_4x32_DISABLE, &&POST_OPS_BIAS_4x32, &&POST_OPS_RELU_4x32, - &&POST_OPS_RELU_SCALE_4x32 + &&POST_OPS_RELU_SCALE_4x32, + &&POST_OPS_DOWNSCALE_4x32 }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -2585,6 +2761,34 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x32) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_DOWNSCALE_4x32: + { + // c[0, 0-15] + CVT_F32_BF16(c_float_0p0,0,0); + + // c[0, 16-31] + CVT_F32_BF16(c_float_0p1,0,1); + + // c[1, 0-15] + CVT_F32_BF16(c_float_1p0,1,0); + + // c[1, 16-31] + CVT_F32_BF16(c_float_1p1,1,1); + + // c[2, 0-15] + CVT_F32_BF16(c_float_2p0,2,0); + + // c[2, 16-31] + CVT_F32_BF16(c_float_2p1,2,1); + + // c[3, 0-15] + CVT_F32_BF16(c_float_3p0,3,0); + + // c[3, 16-31] + CVT_F32_BF16(c_float_3p1,3,1); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_4x32_DISABLE: ; @@ -2622,7 +2826,8 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x32) &&POST_OPS_3x32_DISABLE, &&POST_OPS_BIAS_3x32, 
&&POST_OPS_RELU_3x32, - &&POST_OPS_RELU_SCALE_3x32 + &&POST_OPS_RELU_SCALE_3x32, + &&POST_OPS_DOWNSCALE_3x32 }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -2840,6 +3045,28 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x32) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_DOWNSCALE_3x32: + { + // c[0, 0-15] + CVT_F32_BF16(c_float_0p0,0,0); + + // c[0, 16-31] + CVT_F32_BF16(c_float_0p1,0,1); + + // c[1, 0-15] + CVT_F32_BF16(c_float_1p0,1,0); + + // c[1, 16-31] + CVT_F32_BF16(c_float_1p1,1,1); + + // c[2, 0-15] + CVT_F32_BF16(c_float_2p0,2,0); + + // c[2, 16-31] + CVT_F32_BF16(c_float_2p1,2,1); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_3x32_DISABLE: ; @@ -2871,7 +3098,8 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x32) &&POST_OPS_2x32_DISABLE, &&POST_OPS_BIAS_2x32, &&POST_OPS_RELU_2x32, - &&POST_OPS_RELU_SCALE_2x32 + &&POST_OPS_RELU_SCALE_2x32, + &&POST_OPS_DOWNSCALE_2x32 }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -3038,6 +3266,22 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x32) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_DOWNSCALE_2x32: + { + // c[0, 0-15] + CVT_F32_BF16(c_float_0p0,0,0); + + // c[0, 16-31] + CVT_F32_BF16(c_float_0p1,0,1); + + // c[1, 0-15] + CVT_F32_BF16(c_float_1p0,1,0); + + // c[1, 16-31] + CVT_F32_BF16(c_float_1p1,1,1); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_2x32_DISABLE: ; @@ -3063,7 +3307,8 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1x32) &&POST_OPS_1x32_DISABLE, &&POST_OPS_BIAS_1x32, &&POST_OPS_RELU_1x32, - &&POST_OPS_RELU_SCALE_1x32 + &&POST_OPS_RELU_SCALE_1x32, + &&POST_OPS_DOWNSCALE_1x32 }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -3179,6 +3424,16 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1x32) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_DOWNSCALE_1x32: + { + // c[0, 0-15] + CVT_F32_BF16(c_float_0p0,0,0); + + // c[0, 16-31] + CVT_F32_BF16(c_float_0p1,0,1); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_1x32_DISABLE: ; @@ -3198,7 +3453,8 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x48) &&POST_OPS_5x48_DISABLE, &&POST_OPS_BIAS_5x48, &&POST_OPS_RELU_5x48, - &&POST_OPS_RELU_SCALE_5x48 + &&POST_OPS_RELU_SCALE_5x48, + &&POST_OPS_DOWNSCALE_5x48 }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -3614,6 +3870,55 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x48) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_DOWNSCALE_5x48: + { + // c[0, 0-15] + CVT_F32_BF16(c_float_0p0,0,0); + + // c[0, 16-31] + CVT_F32_BF16(c_float_0p1,0,1); + + // c[0, 32-47] + CVT_F32_BF16(c_float_0p2,0,2); + + // c[1, 0-15] + CVT_F32_BF16(c_float_1p0,1,0); + + // c[1, 16-31] + CVT_F32_BF16(c_float_1p1,1,1); + + // c[1, 32-47] + CVT_F32_BF16(c_float_1p2,1,2); + + // c[2, 0-15] + CVT_F32_BF16(c_float_2p0,2,0); + + // c[2, 16-31] + CVT_F32_BF16(c_float_2p1,2,1); + + // c[2, 32-47] + CVT_F32_BF16(c_float_2p2,2,2); + + // c[3, 0-15] + CVT_F32_BF16(c_float_3p0,3,0); + + // c[3, 16-31] + CVT_F32_BF16(c_float_3p1,3,1); + + // c[3, 32-47] + CVT_F32_BF16(c_float_3p2,3,2); + + // c[4, 0-15] + CVT_F32_BF16(c_float_4p0,4,0); + + // c[4, 16-31] + CVT_F32_BF16(c_float_4p1,4,1); + + // c[4, 32-47] + CVT_F32_BF16(c_float_4p2,4,2); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_5x48_DISABLE: ; @@ -3672,7 +3977,8 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, 
bfloat16, float, bf16bf16f32of32_4x48) &&POST_OPS_4x48_DISABLE, &&POST_OPS_BIAS_4x48, &&POST_OPS_RELU_4x48, - &&POST_OPS_RELU_SCALE_4x48 + &&POST_OPS_RELU_SCALE_4x48, + &&POST_OPS_DOWNSCALE_4x48 }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -4019,6 +4325,46 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x48) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_DOWNSCALE_4x48: + { + // c[0, 0-15] + CVT_F32_BF16(c_float_0p0,0,0); + + // c[0, 16-31] + CVT_F32_BF16(c_float_0p1,0,1); + + // c[0, 32-47] + CVT_F32_BF16(c_float_0p2,0,2); + + // c[1, 0-15] + CVT_F32_BF16(c_float_1p0,1,0); + + // c[1, 16-31] + CVT_F32_BF16(c_float_1p1,1,1); + + // c[1, 32-47] + CVT_F32_BF16(c_float_1p2,1,2); + + // c[2, 0-15] + CVT_F32_BF16(c_float_2p0,2,0); + + // c[2, 16-31] + CVT_F32_BF16(c_float_2p1,2,1); + + // c[2, 32-47] + CVT_F32_BF16(c_float_2p2,2,2); + + // c[3, 0-15] + CVT_F32_BF16(c_float_3p0,3,0); + + // c[3, 16-31] + CVT_F32_BF16(c_float_3p1,3,1); + + // c[3, 32-47] + CVT_F32_BF16(c_float_3p2,3,2); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_4x48_DISABLE: ; @@ -4068,7 +4414,8 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x48) &&POST_OPS_3x48_DISABLE, &&POST_OPS_BIAS_3x48, &&POST_OPS_RELU_3x48, - &&POST_OPS_RELU_SCALE_3x48 + &&POST_OPS_RELU_SCALE_3x48, + &&POST_OPS_DOWNSCALE_3x48 }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -4346,6 +4693,37 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x48) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_DOWNSCALE_3x48: + { + // c[0, 0-15] + CVT_F32_BF16(c_float_0p0,0,0); + + // c[0, 16-31] + CVT_F32_BF16(c_float_0p1,0,1); + + // c[0, 32-47] + CVT_F32_BF16(c_float_0p2,0,2); + + // c[1, 0-15] + CVT_F32_BF16(c_float_1p0,1,0); + + // c[1, 16-31] + CVT_F32_BF16(c_float_1p1,1,1); + + // c[1, 32-47] + CVT_F32_BF16(c_float_1p2,1,2); + + // c[2, 0-15] + CVT_F32_BF16(c_float_2p0,2,0); + + // c[2, 16-31] + CVT_F32_BF16(c_float_2p1,2,1); + + // c[2, 32-47] + CVT_F32_BF16(c_float_2p2,2,2); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_3x48_DISABLE: ; @@ -4386,7 +4764,8 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x48) &&POST_OPS_2x48_DISABLE, &&POST_OPS_BIAS_2x48, &&POST_OPS_RELU_2x48, - &&POST_OPS_RELU_SCALE_2x48 + &&POST_OPS_RELU_SCALE_2x48, + &&POST_OPS_DOWNSCALE_2x48 }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -4595,6 +4974,28 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x48) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_DOWNSCALE_2x48: + { + // c[0, 0-15] + CVT_F32_BF16(c_float_0p0,0,0); + + // c[0, 16-31] + CVT_F32_BF16(c_float_0p1,0,1); + + // c[0, 32-47] + CVT_F32_BF16(c_float_0p2,0,2); + + // c[1, 0-15] + CVT_F32_BF16(c_float_1p0,1,0); + + // c[1, 16-31] + CVT_F32_BF16(c_float_1p1,1,1); + + // c[1, 32-47] + CVT_F32_BF16(c_float_1p2,1,2); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_2x48_DISABLE: ; @@ -4626,7 +5027,8 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1x48) &&POST_OPS_1x48_DISABLE, &&POST_OPS_BIAS_1x48, &&POST_OPS_RELU_1x48, - &&POST_OPS_RELU_SCALE_1x48 + &&POST_OPS_RELU_SCALE_1x48, + &&POST_OPS_DOWNSCALE_1x48 }; dim_t k_full_pieces = k0 / 2; dim_t k_partial_pieces = k0 % 2; @@ -4766,6 +5168,19 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1x48) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_DOWNSCALE_1x48: + { + // c[0, 0-15] + CVT_F32_BF16(c_float_0p0,0,0); 
+ + // c[0, 16-31] + CVT_F32_BF16(c_float_0p1,0,1); + + // c[0, 32-47] + CVT_F32_BF16(c_float_0p2,0,2); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_1x48_DISABLE: ; diff --git a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_n_fringe_bf16_amd512vnni.c b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_n_fringe_bf16_amd512vnni.c index 1a6a238d9d..e3c3f1c524 100644 --- a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_n_fringe_bf16_amd512vnni.c +++ b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_n_fringe_bf16_amd512vnni.c @@ -46,7 +46,8 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6xlt16) &&POST_OPS_6xLT16_DISABLE, &&POST_OPS_BIAS_6xLT16, &&POST_OPS_RELU_6xLT16, - &&POST_OPS_RELU_SCALE_6xLT16 + &&POST_OPS_RELU_SCALE_6xLT16, + &&POST_OPS_DOWNSCALE_6xLT16 }; dim_t MR = 6; dim_t m_full_pieces = m0 / MR; @@ -361,6 +362,28 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6xlt16) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_DOWNSCALE_6xLT16: + { + // c[0, 0-15] + CVT_F32_BF16_LT16(c_float_0p0,0,0); + + // c[1, 0-15] + CVT_F32_BF16_LT16(c_float_1p0,1,0); + + // c[2, 0-15] + CVT_F32_BF16_LT16(c_float_2p0,2,0); + + // c[3, 0-15] + CVT_F32_BF16_LT16(c_float_3p0,3,0); + + // c[4, 0-15] + CVT_F32_BF16_LT16(c_float_4p0,4,0); + + // c[5, 0-15] + CVT_F32_BF16_LT16(c_float_5p0,5,0); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_6xLT16_DISABLE: ; @@ -494,7 +517,8 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x16) &&POST_OPS_6x16_DISABLE, &&POST_OPS_BIAS_6x16, &&POST_OPS_RELU_6x16, - &&POST_OPS_RELU_SCALE_6x16 + &&POST_OPS_RELU_SCALE_6x16, + &&POST_OPS_DOWNSCALE_6x16 }; dim_t MR = 6; dim_t m_full_pieces = m0 / MR; @@ -794,6 +818,28 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x16) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_DOWNSCALE_6x16: + { + // c[0, 0-15] + CVT_F32_BF16(c_float_0p0,0,0); + + // c[1, 0-15] + CVT_F32_BF16(c_float_1p0,1,0); + + // c[2, 0-15] + CVT_F32_BF16(c_float_2p0,2,0); + + // c[3, 0-15] + CVT_F32_BF16(c_float_3p0,3,0); + + // c[4, 0-15] + CVT_F32_BF16(c_float_4p0,4,0); + + // c[5, 0-15] + CVT_F32_BF16(c_float_5p0,5,0); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_6x16_DISABLE: ; @@ -908,7 +954,8 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x32) &&POST_OPS_6x32_DISABLE, &&POST_OPS_BIAS_6x32, &&POST_OPS_RELU_6x32, - &&POST_OPS_RELU_SCALE_6x32 + &&POST_OPS_RELU_SCALE_6x32, + &&POST_OPS_DOWNSCALE_6x32 }; dim_t MR = 6; dim_t m_full_pieces = m0 / MR; @@ -1320,6 +1367,46 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x32) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_DOWNSCALE_6x32: + { + // c[0, 0-15] + CVT_F32_BF16(c_float_0p0,0,0); + + // c[0, 16-31] + CVT_F32_BF16(c_float_0p1,0,1); + + // c[1, 0-15] + CVT_F32_BF16(c_float_1p0,1,0); + + // c[1, 16-31] + CVT_F32_BF16(c_float_1p1,1,1); + + // c[2, 0-15] + CVT_F32_BF16(c_float_2p0,2,0); + + // c[2, 16-31] + CVT_F32_BF16(c_float_2p1,2,1); + + // c[3, 0-15] + CVT_F32_BF16(c_float_3p0,3,0); + + // c[3, 16-31] + CVT_F32_BF16(c_float_3p1,3,1); + + // c[4, 0-15] + CVT_F32_BF16(c_float_4p0,4,0); + + // c[4, 16-31] + CVT_F32_BF16(c_float_4p1,4,1); + + // c[5, 0-15] + CVT_F32_BF16(c_float_5p0,5,0); + + // c[5, 16-31] + CVT_F32_BF16(c_float_5p1,5,1); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_6x32_DISABLE: ; @@ -1452,7 +1539,8 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x48) &&POST_OPS_6x48_DISABLE, 
&&POST_OPS_BIAS_6x48, &&POST_OPS_RELU_6x48, - &&POST_OPS_RELU_SCALE_6x48 + &&POST_OPS_RELU_SCALE_6x48, + &&POST_OPS_DOWNSCALE_6x48 }; dim_t MR = 6; dim_t m_full_pieces = m0 / MR; @@ -1982,6 +2070,64 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x48) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_DOWNSCALE_6x48: + { + // c[0, 0-15] + CVT_F32_BF16(c_float_0p0,0,0); + + // c[0, 16-31] + CVT_F32_BF16(c_float_0p1,0,1); + + // c[0, 32-47] + CVT_F32_BF16(c_float_0p2,0,2); + + // c[1, 0-15] + CVT_F32_BF16(c_float_1p0,1,0); + + // c[1, 16-31] + CVT_F32_BF16(c_float_1p1,1,1); + + // c[1, 32-47] + CVT_F32_BF16(c_float_1p2,1,2); + + // c[2, 0-15] + CVT_F32_BF16(c_float_2p0,2,0); + + // c[2, 16-31] + CVT_F32_BF16(c_float_2p1,2,1); + + // c[2, 32-47] + CVT_F32_BF16(c_float_2p2,2,2); + + // c[3, 0-15] + CVT_F32_BF16(c_float_3p0,3,0); + + // c[3, 16-31] + CVT_F32_BF16(c_float_3p1,3,1); + + // c[3, 32-47] + CVT_F32_BF16(c_float_3p2,3,2); + + // c[4, 0-15] + CVT_F32_BF16(c_float_4p0,4,0); + + // c[4, 16-31] + CVT_F32_BF16(c_float_4p1,4,1); + + // c[4, 32-47] + CVT_F32_BF16(c_float_4p2,4,2); + + // c[5, 0-15] + CVT_F32_BF16(c_float_5p0,5,0); + + // c[5, 16-31] + CVT_F32_BF16(c_float_5p1,5,1); + + // c[5, 32-47] + CVT_F32_BF16(c_float_5p2,5,2); + + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } POST_OPS_6x48_DISABLE: ; From 559ebe6ad9ffe08afef8c2f2563c689c593b3b02 Mon Sep 17 00:00:00 2001 From: eashdash Date: Tue, 6 Sep 2022 11:41:09 +0000 Subject: [PATCH 215/243] LPGEMM BF16 MT panel based balancing Introduced multi-thread panel based balancing for BF16 to improve the overall MT performance. AMD-Internal: [CPUPL-2502] Change-Id: Iddce9548fa96e5f57bd3d3eb3e8268855ca47f25 --- addon/aocl_gemm/aocl_gemm_bf16bf16f32obf16.c | 2 +- addon/aocl_gemm/aocl_gemm_bf16bf16f32of32.c | 2 +- .../aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.c | 7 +++++++ .../aocl_gemm/kernels/bf16bf16f32/lpgemm_f32_kern_macros.h | 6 +++--- 4 files changed, 12 insertions(+), 5 deletions(-) diff --git a/addon/aocl_gemm/aocl_gemm_bf16bf16f32obf16.c b/addon/aocl_gemm/aocl_gemm_bf16bf16f32obf16.c index 482e7c264e..60f66d9405 100644 --- a/addon/aocl_gemm/aocl_gemm_bf16bf16f32obf16.c +++ b/addon/aocl_gemm/aocl_gemm_bf16bf16f32obf16.c @@ -47,7 +47,7 @@ AOCL_GEMM_MATMUL(bfloat16,bfloat16,bfloat16,bf16bf16f32obf16) trans_t blis_transb; // Check if avx512_vnni ISA is supported, lpgemm matmul only works with it. - if ( bli_cpuid_is_avx512vnni_supported() == FALSE ) + if ( bli_cpuid_is_avx512_bf16_supported() == FALSE ) { printf(" AVX512_BF16 ISA not supported by processor, cannot perform lpgemm.\n"); return; // Error. diff --git a/addon/aocl_gemm/aocl_gemm_bf16bf16f32of32.c b/addon/aocl_gemm/aocl_gemm_bf16bf16f32of32.c index e36399ba55..faedc060dc 100644 --- a/addon/aocl_gemm/aocl_gemm_bf16bf16f32of32.c +++ b/addon/aocl_gemm/aocl_gemm_bf16bf16f32of32.c @@ -47,7 +47,7 @@ AOCL_GEMM_MATMUL(bfloat16,bfloat16,float,bf16bf16f32of32) trans_t blis_transb; // Check if avx512_vnni ISA is supported, lpgemm matmul only works with it. - if ( bli_cpuid_is_avx512vnni_supported() == FALSE ) + if ( bli_cpuid_is_avx512_bf16_supported() == FALSE ) { printf(" AVX512_BF16 ISA not supported by processor, cannot perform lpgemm.\n"); return; // Error. 
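The threading hunk that follows takes the ic_ways x jc_ways split produced by bli_thread_partition_2x2() and re-adjusts it using the BF16 block sizes MR and NR for panel-based work balancing. The sketch below only illustrates one simple form of panel-aware capping, with hypothetical names; it is not the lpgemm_pnl_wrk_heur_adjust_ic_jc_ways implementation.

#include <stdint.h>

/* Illustrative helper (not the BLIS code): m is carved into MR-row panels
   across ic_ways threads and n into NR-column panels across jc_ways threads;
   handing a dimension more ways than it has panels leaves threads idle, so
   each split is clamped to its panel count. */
static void cap_ways_to_panels
     (
       int64_t  MR, int64_t NR,
       int64_t  m,  int64_t n,
       int64_t* ic_ways, int64_t* jc_ways
     )
{
    int64_t m_panels = ( m + MR - 1 ) / MR; /* whole or partial MR panels in m */
    int64_t n_panels = ( n + NR - 1 ) / NR; /* whole or partial NR panels in n */

    if ( *ic_ways > m_panels ) { *ic_ways = m_panels; }
    if ( *jc_ways > n_panels ) { *jc_ways = n_panels; }
}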
diff --git a/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.c b/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.c index e22b248157..bf4f622dea 100644 --- a/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.c +++ b/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.c @@ -362,6 +362,7 @@ BLIS_INLINE void lpgemm_bf16bf16f32of32_get_threading { dim_t NR = lpgemm_get_block_size_NR_global_cntx( BF16BF16F32OF32 ); + dim_t MR = lpgemm_get_block_size_MR_global_cntx( BF16BF16F32OF32 ); if ( n <= NR ) { @@ -374,6 +375,12 @@ BLIS_INLINE void lpgemm_bf16bf16f32of32_get_threading { // If BLIS_NUM_THREADS are set, generate jc,ic from the same. bli_thread_partition_2x2( ( *n_threads ), m, n, ic_ways, jc_ways ); + lpgemm_adjust_ic_jc_ways( m, n, n_threads, ic_ways, jc_ways ); + lpgemm_pnl_wrk_heur_adjust_ic_jc_ways + ( + MR, NR, m, n, + n_threads, ic_ways, jc_ways + ); } } else diff --git a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_f32_kern_macros.h b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_f32_kern_macros.h index 88b449cfd5..c8c2a04c91 100644 --- a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_f32_kern_macros.h +++ b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_f32_kern_macros.h @@ -49,8 +49,8 @@ ( bfloat16* )post_ops_list_temp->op_args3 + \ ( rs_c_downscale * ( post_op_c_i + m_ind ) ) + post_op_c_j + ( n_ind * 16 ), \ (__m256i) \ - _mm512_cvtneps_pbh( reg ) \ - ) \ + _mm512_cvtneps_pbh( reg ) \ + ) \ #define CVT_F32_BF16_LT16(reg,m_ind,n_ind) \ _mm256_storeu_epi16 \ @@ -61,6 +61,6 @@ ); \ memcpy( ( bfloat16* )post_ops_list_temp->op_args3 + \ ( rs_c_downscale * ( post_op_c_i + m_ind ) ) + post_op_c_j + \ - ( n_ind * 16 ) , buf0, ( n0_rem * sizeof( bfloat16 ) ) ); \ + ( n_ind * 16 ) , buf0, ( n0_rem * sizeof( bfloat16 ) ) ); \ #endif // LPGEMM_F32_KERN_MACROS_H \ No newline at end of file From 4402bebdcffafeba18645d643a1695ff3c0dc394 Mon Sep 17 00:00:00 2001 From: Dipal M Zambare Date: Mon, 5 Sep 2022 11:41:36 +0530 Subject: [PATCH 216/243] GCC 8 support for zen4 (and amdzen) configuration - Added check for GCC version 8, - Added AVX512 compiler flags needed for zen4 build using GCC 8. AMD-Internal: [CPUPL-2494] Change-Id: I7fd72e4b197fdd754633f674a8b87f01da8dd320 --- config/zen4/make_defs.mk | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/config/zen4/make_defs.mk b/config/zen4/make_defs.mk index d5d13167aa..501ace287c 100644 --- a/config/zen4/make_defs.mk +++ b/config/zen4/make_defs.mk @@ -86,10 +86,15 @@ ifeq ($(shell test $(GCC_VERSION) -ge 9; echo $$?),0) CKVECFLAGS += -march=znver2 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mfpmath=sse CRVECFLAGS += -march=znver2 else -# If gcc is older than 9.1.0 but at least 6.1.0, then we can use -march=znver1 +ifeq ($(shell test $(GCC_VERSION) -ge 8; echo $$?),0) +CKVECFLAGS += -march=znver1 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mfpmath=sse +CRVECFLAGS += -march=znver1 +else +# If gcc is older than 8.0.0 but at least 6.1.0, then we can use -march=znver1 # as the fallback option. 
CKVECFLAGS += -march=znver1 -mno-avx256-split-unaligned-store CRVECFLAGS += -march=znver1 -mno-avx256-split-unaligned-store +endif # GCC 8 endif # GCC 9 endif # GCC 11 endif # GCC 12 From c9e69c44885b3e082b916e4e3bca477cf4c41918 Mon Sep 17 00:00:00 2001 From: Harihara Sudhan S Date: Thu, 8 Sep 2022 17:34:58 +0530 Subject: [PATCH 217/243] Bug fix in s16 downscale operation - Store operations was done to c matrix and not to c buffer AMD-Internal:[CPUPL-2171] Change-Id: Ic0897a20850fdae96db52f0ccc6fa087c84239fa --- .../u8s8s16/lpgemm_6x32rowmajor_amd256.c | 430 +------------- .../kernels/u8s8s16/lpgemm_m_fringe_amd256.c | 536 ++--------------- .../kernels/u8s8s16/lpgemm_mn_fringe_amd256.c | 546 ++++-------------- .../kernels/u8s8s16/lpgemm_n_fringe_amd256.c | 446 +------------- .../kernels/u8s8s16/lpgemm_s16_kern_macros.h | 340 +++++++++++ 5 files changed, 530 insertions(+), 1768 deletions(-) diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_6x32rowmajor_amd256.c b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_6x32rowmajor_amd256.c index d5bd5ef93a..5c4e440407 100644 --- a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_6x32rowmajor_amd256.c +++ b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_6x32rowmajor_amd256.c @@ -568,434 +568,40 @@ LPGEMM_MAIN_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x32) __m128i temp[2]; __m256i temp_32[2]; __m256 temp_float[2]; + __m256 scale_1, scale_2; + __m256 res_1, res_2; + __m256i store_reg; - // Load the scale vector values into the register - __m256 scale_1 = + /* Load the scale vector values into the register*/ + scale_1 = _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_op_c_j + (0 * 8)); - __m256 scale_2 = + (float *)post_ops_list_temp->scale_factor + + post_op_c_j + (0 * 8)); + scale_2 = _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_op_c_j + (1 * 8)); - - // Extract the first 128 bits of the register - temp[0] = _mm256_extractf128_si256(c_int16_0p0, 0); - - // Extract the second 128 bits of the register - temp[1] = _mm256_extractf128_si256(c_int16_0p0, 1); - - // Since s16 values cannot be converted to f32 directly, - // they are converted to s32, then to f32 and the scale is performed - temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); - temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); - - temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); - temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); - - // Multiply the C matrix by the scale value - __m256 res_1 = _mm256_mul_ps(temp_float[0], scale_1); - __m256 res_2 = _mm256_mul_ps(temp_float[0], scale_2); - - // Round the resultant value to the nearest integer - res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - - // Convert float32 scaled rounded value to int32 - temp_32[0] = _mm256_cvtps_epi32(res_1); - temp_32[1] = _mm256_cvtps_epi32(res_2); - - // Convert the s32 to s16 - c_int16_0p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); - - // Permute to make sure the order is correct - c_int16_0p0 = _mm256_permute4x64_epi64(c_int16_0p0, 0XD8); - - // Extract the first 128 bits of the register - temp[0] = _mm256_extractf128_si256(c_int16_0p1, 0); - - // Extract the second 128 bits of the register - temp[1] = _mm256_extractf128_si256(c_int16_0p1, 1); - - // Since s16 values cannot be converted to f32 directly, - // they are converted to s32, then to f32 and the scale is performed - temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); - temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); - - temp_32[1] = 
_mm256_cvtepi16_epi32(temp[1]); - temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); - - // Multiply the C matrix by the scale value - res_1 = _mm256_mul_ps(temp_float[0], scale_1); - res_2 = _mm256_mul_ps(temp_float[0], scale_2); - - // Round the resultant value to the nearest integer - res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - - // Convert float32 scaled rounded value to int32 - temp_32[0] = _mm256_cvtps_epi32(res_1); - temp_32[1] = _mm256_cvtps_epi32(res_2); - - // Convert the s32 to s16 - c_int16_0p1 = _mm256_packs_epi32(temp_32[0], temp_32[1]); - - // Permute to make sure the order is correct - c_int16_0p1 = _mm256_permute4x64_epi64(c_int16_0p1, 0XD8); - - __m256i store_reg = _mm256_packs_epi16(c_int16_0p0, c_int16_0p1); - - _mm256_storeu_si256((__m256i *)(c + (rs_c * (ir + 0)) + (0 * 16)), store_reg); + (float *)post_ops_list_temp->scale_factor + + post_op_c_j + (1 * 8)); + bli_mm256_s16_downscale(c_int16_0p0, c_int16_0p1, 0); //-------------------------------------------------------------------------- - // Extract the first 128 bits of the register - temp[0] = _mm256_extractf128_si256(c_int16_1p0, 0); - - // Extract the second 128 bits of the register - temp[1] = _mm256_extractf128_si256(c_int16_1p0, 1); - - // Since s16 values cannot be converted to f32 directly, - // they are converted to s32, then to f32 and the scale is performed - temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); - temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); - - temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); - temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); - - // Multiply the C matrix by the scale value - res_1 = _mm256_mul_ps(temp_float[0], scale_1); - res_2 = _mm256_mul_ps(temp_float[0], scale_2); - - // Round the resultant value to the nearest integer - res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - - // Convert float32 scaled rounded value to int32 - temp_32[0] = _mm256_cvtps_epi32(res_1); - temp_32[1] = _mm256_cvtps_epi32(res_2); - - // Convert the s32 to s16 - c_int16_1p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); - - // Permute to make sure the order is correct - c_int16_1p0 = _mm256_permute4x64_epi64(c_int16_1p0, 0XD8); - - // Extract the first 128 bits of the register - temp[0] = _mm256_extractf128_si256(c_int16_1p1, 0); - - // Extract the second 128 bits of the register - temp[1] = _mm256_extractf128_si256(c_int16_1p1, 1); - - // Since s16 values cannot be converted to f32 directly, - // they are converted to s32, then to f32 and the scale is performed - temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); - temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); - - temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); - temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); - - // Multiply the C matrix by the scale value - res_1 = _mm256_mul_ps(temp_float[0], scale_1); - res_2 = _mm256_mul_ps(temp_float[0], scale_2); - - // Round the resultant value to the nearest integer - res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - - // Convert float32 scaled rounded value to int32 - temp_32[0] = _mm256_cvtps_epi32(res_1); - temp_32[1] = _mm256_cvtps_epi32(res_2); - - // Convert the s32 to s16 - c_int16_1p1 = _mm256_packs_epi32(temp_32[0], temp_32[1]); - - // Permute to make sure the order 
is correct - c_int16_1p1 = _mm256_permute4x64_epi64(c_int16_1p1, 0XD8); - - store_reg = _mm256_packs_epi16(c_int16_1p0, c_int16_1p1); - - _mm256_storeu_si256((__m256i *)(c + (rs_c * (ir + 0)) + (0 * 16)), store_reg); - + bli_mm256_s16_downscale(c_int16_1p0, c_int16_1p1, 1); + //-------------------------------------------------------------------------- - // Extract the first 128 bits of the register - temp[0] = _mm256_extractf128_si256(c_int16_2p0, 0); - - // Extract the second 128 bits of the register - temp[1] = _mm256_extractf128_si256(c_int16_2p0, 1); - - // Since s16 values cannot be converted to f32 directly, - // they are converted to s32, then to f32 and the scale is performed - temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); - temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); - - temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); - temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); - - // Multiply the C matrix by the scale value - res_1 = _mm256_mul_ps(temp_float[0], scale_1); - res_2 = _mm256_mul_ps(temp_float[0], scale_2); - - // Round the resultant value to the nearest integer - res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - - // Convert float32 scaled rounded value to int32 - temp_32[0] = _mm256_cvtps_epi32(res_1); - temp_32[1] = _mm256_cvtps_epi32(res_2); - - // Convert the s32 to s16 - c_int16_2p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); - - // Permute to make sure the order is correct - c_int16_2p0 = _mm256_permute4x64_epi64(c_int16_2p0, 0XD8); - - // Extract the first 128 bits of the register - temp[0] = _mm256_extractf128_si256(c_int16_2p1, 0); - - // Extract the second 128 bits of the register - temp[1] = _mm256_extractf128_si256(c_int16_2p1, 1); - - // Since s16 values cannot be converted to f32 directly, - // they are converted to s32, then to f32 and the scale is performed - temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); - temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); - - temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); - temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); - - // Multiply the C matrix by the scale value - res_1 = _mm256_mul_ps(temp_float[0], scale_1); - res_2 = _mm256_mul_ps(temp_float[0], scale_2); - - // Round the resultant value to the nearest integer - res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - - // Convert float32 scaled rounded value to int32 - temp_32[0] = _mm256_cvtps_epi32(res_1); - temp_32[1] = _mm256_cvtps_epi32(res_2); - - // Convert the s32 to s16 - c_int16_2p1 = _mm256_packs_epi32(temp_32[0], temp_32[1]); - - // Permute to make sure the order is correct - c_int16_2p1 = _mm256_permute4x64_epi64(c_int16_2p1, 0XD8); - - store_reg = _mm256_packs_epi16(c_int16_2p0, c_int16_2p1); - - _mm256_storeu_si256((__m256i *)(c + (rs_c * (ir + 0)) + (0 * 16)), store_reg); - + bli_mm256_s16_downscale(c_int16_2p0, c_int16_2p1, 2); + //-------------------------------------------------------------------------- - // Extract the first 128 bits of the register - temp[0] = _mm256_extractf128_si256(c_int16_3p0, 0); - - // Extract the second 128 bits of the register - temp[1] = _mm256_extractf128_si256(c_int16_3p0, 1); - - // Since s16 values cannot be converted to f32 directly, - // they are converted to s32, then to f32 and the scale is performed - temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); - temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); 
- - temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); - temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); - - // Multiply the C matrix by the scale value - res_1 = _mm256_mul_ps(temp_float[0], scale_1); - res_2 = _mm256_mul_ps(temp_float[0], scale_2); - - // Round the resultant value to the nearest integer - res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - - // Convert float32 scaled rounded value to int32 - temp_32[0] = _mm256_cvtps_epi32(res_1); - temp_32[1] = _mm256_cvtps_epi32(res_2); - - // Convert the s32 to s16 - c_int16_3p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); - - // Permute to make sure the order is correct - c_int16_3p0 = _mm256_permute4x64_epi64(c_int16_3p0, 0XD8); - - // Extract the first 128 bits of the register - temp[0] = _mm256_extractf128_si256(c_int16_3p1, 0); - - // Extract the second 128 bits of the register - temp[1] = _mm256_extractf128_si256(c_int16_3p1, 1); - - // Since s16 values cannot be converted to f32 directly, - // they are converted to s32, then to f32 and the scale is performed - temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); - temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); - - temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); - temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); - - // Multiply the C matrix by the scale value - res_1 = _mm256_mul_ps(temp_float[0], scale_1); - res_2 = _mm256_mul_ps(temp_float[0], scale_2); - - // Round the resultant value to the nearest integer - res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - - // Convert float32 scaled rounded value to int32 - temp_32[0] = _mm256_cvtps_epi32(res_1); - temp_32[1] = _mm256_cvtps_epi32(res_2); - - // Convert the s32 to s16 - c_int16_3p1 = _mm256_packs_epi32(temp_32[0], temp_32[1]); - - // Permute to make sure the order is correct - c_int16_3p1 = _mm256_permute4x64_epi64(c_int16_3p1, 0XD8); - - store_reg = _mm256_packs_epi16(c_int16_3p0, c_int16_3p1); - - _mm256_storeu_si256((__m256i *)(c + (rs_c * (ir + 0)) + (0 * 16)), store_reg); + bli_mm256_s16_downscale(c_int16_3p0, c_int16_3p1, 3); //-------------------------------------------------------------------------- - // Extract the first 128 bits of the register - temp[0] = _mm256_extractf128_si256(c_int16_4p0, 0); - - // Extract the second 128 bits of the register - temp[1] = _mm256_extractf128_si256(c_int16_4p0, 1); - - // Since s16 values cannot be converted to f32 directly, - // they are converted to s32, then to f32 and the scale is performed - temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); - temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); - - temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); - temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); - - // Multiply the C matrix by the scale value - res_1 = _mm256_mul_ps(temp_float[0], scale_1); - res_2 = _mm256_mul_ps(temp_float[0], scale_2); - - // Round the resultant value to the nearest integer - res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - - // Convert float32 scaled rounded value to int32 - temp_32[0] = _mm256_cvtps_epi32(res_1); - temp_32[1] = _mm256_cvtps_epi32(res_2); - - // Convert the s32 to s16 - c_int16_4p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); - - // Permute to make sure the order is correct - c_int16_4p0 = 
_mm256_permute4x64_epi64(c_int16_4p0, 0XD8); - - // Extract the first 128 bits of the register - temp[0] = _mm256_extractf128_si256(c_int16_4p1, 0); - - // Extract the second 128 bits of the register - temp[1] = _mm256_extractf128_si256(c_int16_4p1, 1); - - // Since s16 values cannot be converted to f32 directly, - // they are converted to s32, then to f32 and the scale is performed - temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); - temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); - - temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); - temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); - - // Multiply the C matrix by the scale value - res_1 = _mm256_mul_ps(temp_float[0], scale_1); - res_2 = _mm256_mul_ps(temp_float[0], scale_2); - - // Round the resultant value to the nearest integer - res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - - // Convert float32 scaled rounded value to int32 - temp_32[0] = _mm256_cvtps_epi32(res_1); - temp_32[1] = _mm256_cvtps_epi32(res_2); - - // Convert the s32 to s16 - c_int16_4p1 = _mm256_packs_epi32(temp_32[0], temp_32[1]); - - // Permute to make sure the order is correct - c_int16_4p1 = _mm256_permute4x64_epi64(c_int16_4p1, 0XD8); - - store_reg = _mm256_packs_epi16(c_int16_4p0, c_int16_4p1); - - _mm256_storeu_si256((__m256i *)(c + (rs_c * (ir + 0)) + (0 * 16)), store_reg); + bli_mm256_s16_downscale(c_int16_4p0, c_int16_4p1, 4); //-------------------------------------------------------------------------- - // Extract the first 128 bits of the register - temp[0] = _mm256_extractf128_si256(c_int16_5p0, 0); - - // Extract the second 128 bits of the register - temp[1] = _mm256_extractf128_si256(c_int16_5p0, 1); - - // Since s16 values cannot be converted to f32 directly, - // they are converted to s32, then to f32 and the scale is performed - temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); - temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); - - temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); - temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); - - // Multiply the C matrix by the scale value - res_1 = _mm256_mul_ps(temp_float[0], scale_1); - res_2 = _mm256_mul_ps(temp_float[0], scale_2); - - // Round the resultant value to the nearest integer - res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - - // Convert float32 scaled rounded value to int32 - temp_32[0] = _mm256_cvtps_epi32(res_1); - temp_32[1] = _mm256_cvtps_epi32(res_2); - - // Convert the s32 to s16 - c_int16_5p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); - - // Permute to make sure the order is correct - c_int16_5p0 = _mm256_permute4x64_epi64(c_int16_5p0, 0XD8); - - // Extract the first 128 bits of the register - temp[0] = _mm256_extractf128_si256(c_int16_5p1, 0); - - // Extract the second 128 bits of the register - temp[1] = _mm256_extractf128_si256(c_int16_5p1, 1); - - // Since s16 values cannot be converted to f32 directly, - // they are converted to s32, then to f32 and the scale is performed - temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); - temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); - - temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); - temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); - - // Multiply the C matrix by the scale value - res_1 = _mm256_mul_ps(temp_float[0], scale_1); - res_2 = _mm256_mul_ps(temp_float[0], scale_2); - - // Round the resultant value to the nearest integer - res_1 = 
_mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - - // Convert float32 scaled rounded value to int32 - temp_32[0] = _mm256_cvtps_epi32(res_1); - temp_32[1] = _mm256_cvtps_epi32(res_2); - - // Convert the s32 to s16 - c_int16_5p1 = _mm256_packs_epi32(temp_32[0], temp_32[1]); - - // Permute to make sure the order is correct - c_int16_5p1 = _mm256_permute4x64_epi64(c_int16_4p1, 0XD8); - - store_reg = _mm256_packs_epi16(c_int16_5p0, c_int16_5p1); - - _mm256_storeu_si256((__m256i *)(c + (rs_c * (ir + 0)) + (0 * 16)), store_reg); + bli_mm256_s16_downscale(c_int16_5p0, c_int16_5p1, 5); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_m_fringe_amd256.c b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_m_fringe_amd256.c index 12026ad0e0..7c8f3b57b0 100644 --- a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_m_fringe_amd256.c +++ b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_m_fringe_amd256.c @@ -356,294 +356,34 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4x32) __m128i temp[2]; __m256i temp_32[2]; __m256 temp_float[2]; + __m256 scale_1, scale_2; + __m256 res_1, res_2; + __m256i store_reg; - // Load the scale vector values into the register - __m256 scale_1 = + /* Load the scale vector values into the register*/ + scale_1 = _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_op_c_j + (0 * 8)); - __m256 scale_2 = + (float *)post_ops_list_temp->scale_factor + + post_op_c_j + (0 * 8)); + scale_2 = _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_op_c_j + (1 * 8)); - - // Extract the first 128 bits of the register - temp[0] = _mm256_extractf128_si256(c_int16_0p0, 0); - - // Extract the second 128 bits of the register - temp[1] = _mm256_extractf128_si256(c_int16_0p0, 1); - - // Since s16 values cannot be converted to f32 directly, - // they are converted to s32, then to f32 and the scale is performed - temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); - temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); - - temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); - temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); - - // Multiply the C matrix by the scale value - __m256 res_1 = _mm256_mul_ps(temp_float[0], scale_1); - __m256 res_2 = _mm256_mul_ps(temp_float[0], scale_2); - - // Round the resultant value to the nearest integer - res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - - // Convert float32 scaled rounded value to int32 - temp_32[0] = _mm256_cvtps_epi32(res_1); - temp_32[1] = _mm256_cvtps_epi32(res_2); - - // Convert the s32 to s16 - c_int16_0p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); - - // Permute to make sure the order is correct - c_int16_0p0 = _mm256_permute4x64_epi64(c_int16_0p0, 0XD8); - - // Extract the first 128 bits of the register - temp[0] = _mm256_extractf128_si256(c_int16_0p1, 0); - - // Extract the second 128 bits of the register - temp[1] = _mm256_extractf128_si256(c_int16_0p1, 1); - - // Since s16 values cannot be converted to f32 directly, - // they are converted to s32, then to f32 and the scale is performed - temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); - temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); - - temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); - temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); - - // Multiply the C matrix by the scale value - res_1 = _mm256_mul_ps(temp_float[0], scale_1); - 
res_2 = _mm256_mul_ps(temp_float[0], scale_2); - - // Round the resultant value to the nearest integer - res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - - // Convert float32 scaled rounded value to int32 - temp_32[0] = _mm256_cvtps_epi32(res_1); - temp_32[1] = _mm256_cvtps_epi32(res_2); - - // Convert the s32 to s16 - c_int16_0p1 = _mm256_packs_epi32(temp_32[0], temp_32[1]); - - // Permute to make sure the order is correct - c_int16_0p1 = _mm256_permute4x64_epi64(c_int16_0p1, 0XD8); - - __m256i store_reg = _mm256_packs_epi16(c_int16_0p0, c_int16_0p1); - - _mm256_storeu_si256((__m256i *)(c + (rs_c * 0)) + (0 * 16), store_reg); + (float *)post_ops_list_temp->scale_factor + + post_op_c_j + (1 * 8)); + bli_mm256_s16_downscale(c_int16_0p0, c_int16_0p1, 0); //-------------------------------------------------------------------------- - // Extract the first 128 bits of the register - temp[0] = _mm256_extractf128_si256(c_int16_1p0, 0); - - // Extract the second 128 bits of the register - temp[1] = _mm256_extractf128_si256(c_int16_1p0, 1); - - // Since s16 values cannot be converted to f32 directly, - // they are converted to s32, then to f32 and the scale is performed - temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); - temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); - - temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); - temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); - - // Multiply the C matrix by the scale value - res_1 = _mm256_mul_ps(temp_float[0], scale_1); - res_2 = _mm256_mul_ps(temp_float[0], scale_2); - - // Round the resultant value to the nearest integer - res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - - // Convert float32 scaled rounded value to int32 - temp_32[0] = _mm256_cvtps_epi32(res_1); - temp_32[1] = _mm256_cvtps_epi32(res_2); - - // Convert the s32 to s16 - c_int16_1p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); - - // Permute to make sure the order is correct - c_int16_1p0 = _mm256_permute4x64_epi64(c_int16_1p0, 0XD8); - - // Extract the first 128 bits of the register - temp[0] = _mm256_extractf128_si256(c_int16_1p1, 0); - - // Extract the second 128 bits of the register - temp[1] = _mm256_extractf128_si256(c_int16_1p1, 1); - - // Since s16 values cannot be converted to f32 directly, - // they are converted to s32, then to f32 and the scale is performed - temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); - temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); - - temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); - temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); - - // Multiply the C matrix by the scale value - res_1 = _mm256_mul_ps(temp_float[0], scale_1); - res_2 = _mm256_mul_ps(temp_float[0], scale_2); - - // Round the resultant value to the nearest integer - res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - - // Convert float32 scaled rounded value to int32 - temp_32[0] = _mm256_cvtps_epi32(res_1); - temp_32[1] = _mm256_cvtps_epi32(res_2); - - // Convert the s32 to s16 - c_int16_1p1 = _mm256_packs_epi32(temp_32[0], temp_32[1]); - - // Permute to make sure the order is correct - c_int16_1p1 = _mm256_permute4x64_epi64(c_int16_1p1, 0XD8); - - store_reg = _mm256_packs_epi16(c_int16_1p0, c_int16_1p1); - - _mm256_storeu_si256((__m256i *)(c + (rs_c * 0) + 
(0 * 16)), store_reg); - + bli_mm256_s16_downscale(c_int16_1p0, c_int16_1p1, 1); + //-------------------------------------------------------------------------- - // Extract the first 128 bits of the register - temp[0] = _mm256_extractf128_si256(c_int16_2p0, 0); - - // Extract the second 128 bits of the register - temp[1] = _mm256_extractf128_si256(c_int16_2p0, 1); - - // Since s16 values cannot be converted to f32 directly, - // they are converted to s32, then to f32 and the scale is performed - temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); - temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); - - temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); - temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); - - // Multiply the C matrix by the scale value - res_1 = _mm256_mul_ps(temp_float[0], scale_1); - res_2 = _mm256_mul_ps(temp_float[0], scale_2); - - // Round the resultant value to the nearest integer - res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - - // Convert float32 scaled rounded value to int32 - temp_32[0] = _mm256_cvtps_epi32(res_1); - temp_32[1] = _mm256_cvtps_epi32(res_2); - - // Convert the s32 to s16 - c_int16_2p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); - - // Permute to make sure the order is correct - c_int16_2p0 = _mm256_permute4x64_epi64(c_int16_2p0, 0XD8); - - // Extract the first 128 bits of the register - temp[0] = _mm256_extractf128_si256(c_int16_2p1, 0); - - // Extract the second 128 bits of the register - temp[1] = _mm256_extractf128_si256(c_int16_2p1, 1); - - // Since s16 values cannot be converted to f32 directly, - // they are converted to s32, then to f32 and the scale is performed - temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); - temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); - - temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); - temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); - - // Multiply the C matrix by the scale value - res_1 = _mm256_mul_ps(temp_float[0], scale_1); - res_2 = _mm256_mul_ps(temp_float[0], scale_2); - - // Round the resultant value to the nearest integer - res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - - // Convert float32 scaled rounded value to int32 - temp_32[0] = _mm256_cvtps_epi32(res_1); - temp_32[1] = _mm256_cvtps_epi32(res_2); - - // Convert the s32 to s16 - c_int16_2p1 = _mm256_packs_epi32(temp_32[0], temp_32[1]); - - // Permute to make sure the order is correct - c_int16_2p1 = _mm256_permute4x64_epi64(c_int16_2p1, 0XD8); - - store_reg = _mm256_packs_epi16(c_int16_2p0, c_int16_2p1); - - _mm256_storeu_si256((__m256i *)(c + (rs_c * 0) + (0 * 16)), store_reg); - + bli_mm256_s16_downscale(c_int16_2p0, c_int16_2p1, 2); + //-------------------------------------------------------------------------- - // Extract the first 128 bits of the register - temp[0] = _mm256_extractf128_si256(c_int16_3p0, 0); - - // Extract the second 128 bits of the register - temp[1] = _mm256_extractf128_si256(c_int16_3p0, 1); - - // Since s16 values cannot be converted to f32 directly, - // they are converted to s32, then to f32 and the scale is performed - temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); - temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); - - temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); - temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); - - // Multiply the C matrix by the scale value - res_1 = _mm256_mul_ps(temp_float[0], scale_1); - 
res_2 = _mm256_mul_ps(temp_float[0], scale_2); - - // Round the resultant value to the nearest integer - res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - - // Convert float32 scaled rounded value to int32 - temp_32[0] = _mm256_cvtps_epi32(res_1); - temp_32[1] = _mm256_cvtps_epi32(res_2); - - // Convert the s32 to s16 - c_int16_3p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); - - // Permute to make sure the order is correct - c_int16_3p0 = _mm256_permute4x64_epi64(c_int16_3p0, 0XD8); - - // Extract the first 128 bits of the register - temp[0] = _mm256_extractf128_si256(c_int16_3p1, 0); + bli_mm256_s16_downscale(c_int16_3p0, c_int16_3p1, 3); - // Extract the second 128 bits of the register - temp[1] = _mm256_extractf128_si256(c_int16_3p1, 1); - - // Since s16 values cannot be converted to f32 directly, - // they are converted to s32, then to f32 and the scale is performed - temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); - temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); - - temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); - temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); - - // Multiply the C matrix by the scale value - res_1 = _mm256_mul_ps(temp_float[0], scale_1); - res_2 = _mm256_mul_ps(temp_float[0], scale_2); - - // Round the resultant value to the nearest integer - res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - - // Convert float32 scaled rounded value to int32 - temp_32[0] = _mm256_cvtps_epi32(res_1); - temp_32[1] = _mm256_cvtps_epi32(res_2); - - // Convert the s32 to s16 - c_int16_3p1 = _mm256_packs_epi32(temp_32[0], temp_32[1]); - - // Permute to make sure the order is correct - c_int16_3p1 = _mm256_permute4x64_epi64(c_int16_3p1, 0XD8); - - store_reg = _mm256_packs_epi16(c_int16_3p0, c_int16_3p1); - - _mm256_storeu_si256((__m256i *)(c + (rs_c * 0) + (0 * 16)), store_reg); + //-------------------------------------------------------------------------- POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -878,154 +618,26 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2x32) __m128i temp[2]; __m256i temp_32[2]; __m256 temp_float[2]; + __m256 scale_1, scale_2; + __m256 res_1, res_2; + __m256i store_reg; - // Load the scale vector values into the register - __m256 scale_1 = + /* Load the scale vector values into the register*/ + scale_1 = _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_op_c_j + (0 * 8)); - __m256 scale_2 = + (float *)post_ops_list_temp->scale_factor + + post_op_c_j + (0 * 8)); + scale_2 = _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_op_c_j + (1 * 8)); - - // Extract the first 128 bits of the register - temp[0] = _mm256_extractf128_si256(c_int16_0p0, 0); - - // Extract the second 128 bits of the register - temp[1] = _mm256_extractf128_si256(c_int16_0p0, 1); - - // Since s16 values cannot be converted to f32 directly, - // they are converted to s32, then to f32 and the scale is performed - temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); - temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); - - temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); - temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); - - // Multiply the C matrix by the scale value - __m256 res_1 = _mm256_mul_ps(temp_float[0], scale_1); - __m256 res_2 = _mm256_mul_ps(temp_float[0], scale_2); - - // Round the resultant value to the nearest integer - 
res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - - // Convert float32 scaled rounded value to int32 - temp_32[0] = _mm256_cvtps_epi32(res_1); - temp_32[1] = _mm256_cvtps_epi32(res_2); - - // Convert the s32 to s16 - c_int16_0p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); - - // Permute to make sure the order is correct - c_int16_0p0 = _mm256_permute4x64_epi64(c_int16_0p0, 0XD8); - - // Extract the first 128 bits of the register - temp[0] = _mm256_extractf128_si256(c_int16_0p1, 0); - - // Extract the second 128 bits of the register - temp[1] = _mm256_extractf128_si256(c_int16_0p1, 1); - - // Since s16 values cannot be converted to f32 directly, - // they are converted to s32, then to f32 and the scale is performed - temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); - temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); - - temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); - temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); - - // Multiply the C matrix by the scale value - res_1 = _mm256_mul_ps(temp_float[0], scale_1); - res_2 = _mm256_mul_ps(temp_float[0], scale_2); - - // Round the resultant value to the nearest integer - res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - - // Convert float32 scaled rounded value to int32 - temp_32[0] = _mm256_cvtps_epi32(res_1); - temp_32[1] = _mm256_cvtps_epi32(res_2); - - // Convert the s32 to s16 - c_int16_0p1 = _mm256_packs_epi32(temp_32[0], temp_32[1]); - - // Permute to make sure the order is correct - c_int16_0p1 = _mm256_permute4x64_epi64(c_int16_0p1, 0XD8); - - __m256i store_reg = _mm256_packs_epi16(c_int16_0p0, c_int16_0p1); - - _mm256_storeu_si256((__m256i *)(c + (rs_c * 0)) + (0 * 16), store_reg); + (float *)post_ops_list_temp->scale_factor + + post_op_c_j + (1 * 8)); + bli_mm256_s16_downscale(c_int16_0p0, c_int16_0p1, 0); //-------------------------------------------------------------------------- - // Extract the first 128 bits of the register - temp[0] = _mm256_extractf128_si256(c_int16_1p0, 0); - - // Extract the second 128 bits of the register - temp[1] = _mm256_extractf128_si256(c_int16_1p0, 1); - - // Since s16 values cannot be converted to f32 directly, - // they are converted to s32, then to f32 and the scale is performed - temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); - temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); - - temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); - temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); - - // Multiply the C matrix by the scale value - res_1 = _mm256_mul_ps(temp_float[0], scale_1); - res_2 = _mm256_mul_ps(temp_float[0], scale_2); - - // Round the resultant value to the nearest integer - res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - - // Convert float32 scaled rounded value to int32 - temp_32[0] = _mm256_cvtps_epi32(res_1); - temp_32[1] = _mm256_cvtps_epi32(res_2); - - // Convert the s32 to s16 - c_int16_1p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); - - // Permute to make sure the order is correct - c_int16_1p0 = _mm256_permute4x64_epi64(c_int16_1p0, 0XD8); - - // Extract the first 128 bits of the register - temp[0] = _mm256_extractf128_si256(c_int16_1p1, 0); - - // Extract the second 128 bits of the register - temp[1] = _mm256_extractf128_si256(c_int16_1p1, 1); - - // 
Since s16 values cannot be converted to f32 directly, - // they are converted to s32, then to f32 and the scale is performed - temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); - temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); - - temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); - temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); - - // Multiply the C matrix by the scale value - res_1 = _mm256_mul_ps(temp_float[0], scale_1); - res_2 = _mm256_mul_ps(temp_float[0], scale_2); - - // Round the resultant value to the nearest integer - res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - - // Convert float32 scaled rounded value to int32 - temp_32[0] = _mm256_cvtps_epi32(res_1); - temp_32[1] = _mm256_cvtps_epi32(res_2); - - // Convert the s32 to s16 - c_int16_1p1 = _mm256_packs_epi32(temp_32[0], temp_32[1]); - - // Permute to make sure the order is correct - c_int16_1p1 = _mm256_permute4x64_epi64(c_int16_1p1, 0XD8); - - store_reg = _mm256_packs_epi16(c_int16_1p0, c_int16_1p1); - - _mm256_storeu_si256((__m256i *)(c + (rs_c * 0) + (0 * 16)), store_reg); + bli_mm256_s16_downscale(c_int16_1p0, c_int16_1p1, 1); + + //-------------------------------------------------------------------------- POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -1188,84 +800,22 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1x32) __m128i temp[2]; __m256i temp_32[2]; __m256 temp_float[2]; + __m256 scale_1, scale_2; + __m256 res_1, res_2; + __m256i store_reg; - // Load the scale vector values into the register - __m256 scale_1 = + /* Load the scale vector values into the register*/ + scale_1 = _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_op_c_j + (0 * 8)); - __m256 scale_2 = + (float *)post_ops_list_temp->scale_factor + + post_op_c_j + (0 * 8)); + scale_2 = _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_op_c_j + (1 * 8)); - - // Extract the first 128 bits of the register - temp[0] = _mm256_extractf128_si256(c_int16_0p0, 0); - - // Extract the second 128 bits of the register - temp[1] = _mm256_extractf128_si256(c_int16_0p0, 1); - - // Since s16 values cannot be converted to f32 directly, - // they are converted to s32, then to f32 and the scale is performed - temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); - temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); + (float *)post_ops_list_temp->scale_factor + + post_op_c_j + (1 * 8)); - temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); - temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); - - // Multiply the C matrix by the scale value - __m256 res_1 = _mm256_mul_ps(temp_float[0], scale_1); - __m256 res_2 = _mm256_mul_ps(temp_float[0], scale_2); - - // Round the resultant value to the nearest integer - res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - - // Convert float32 scaled rounded value to int32 - temp_32[0] = _mm256_cvtps_epi32(res_1); - temp_32[1] = _mm256_cvtps_epi32(res_2); - - // Convert the s32 to s16 - c_int16_0p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); - - // Permute to make sure the order is correct - c_int16_0p0 = _mm256_permute4x64_epi64(c_int16_0p0, 0XD8); - - // Extract the first 128 bits of the register - temp[0] = _mm256_extractf128_si256(c_int16_0p1, 0); - - // Extract the second 128 bits of the register - temp[1] = _mm256_extractf128_si256(c_int16_0p1, 1); - - // Since s16 values cannot be 
converted to f32 directly, - // they are converted to s32, then to f32 and the scale is performed - temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); - temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); - - temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); - temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); - - // Multiply the C matrix by the scale value - res_1 = _mm256_mul_ps(temp_float[0], scale_1); - res_2 = _mm256_mul_ps(temp_float[0], scale_2); - - // Round the resultant value to the nearest integer - res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - - // Convert float32 scaled rounded value to int32 - temp_32[0] = _mm256_cvtps_epi32(res_1); - temp_32[1] = _mm256_cvtps_epi32(res_2); - - // Convert the s32 to s16 - c_int16_0p1 = _mm256_packs_epi32(temp_32[0], temp_32[1]); - - // Permute to make sure the order is correct - c_int16_0p1 = _mm256_permute4x64_epi64(c_int16_0p1, 0XD8); - - __m256i store_reg = _mm256_packs_epi16(c_int16_0p0, c_int16_0p1); - - _mm256_storeu_si256((__m256i *)(c + (rs_c * 0)) + (0 * 16), store_reg); + bli_mm256_s16_downscale(c_int16_0p0, c_int16_0p1, 0); + //-------------------------------------------------------------------------- POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_mn_fringe_amd256.c b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_mn_fringe_amd256.c index f50dd98826..968959ca7f 100644 --- a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_mn_fringe_amd256.c +++ b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_mn_fringe_amd256.c @@ -266,157 +266,28 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4x16) __m128i temp[2]; __m256i temp_32[2]; __m256 temp_float[2]; + __m256 scale_1, scale_2; + __m256 res_1, res_2; + __m256i store_reg; - // Load the scale vector values into the register - __m256 scale_1 = + /* Load the scale vector values into the register*/ + scale_1 = _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_op_c_j + (0 * 8)); - __m256 scale_2 = + (float *)post_ops_list_temp->scale_factor + + post_op_c_j + (0 * 8)); + scale_2 = _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_op_c_j + (1 * 8)); - - // Extract the first 128 bits of the register - temp[0] = _mm256_extractf128_si256(c_int16_0p0, 0); - - // Extract the second 128 bits of the register - temp[1] = _mm256_extractf128_si256(c_int16_0p0, 1); - - // Since s16 values cannot be converted to f32 directly, - // they are converted to s32, then to f32 and the scale is performed - temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); - temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); - - temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); - temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); - - // Multiply the C matrix by the scale value - __m256 res_1 = _mm256_mul_ps(temp_float[0], scale_1); - __m256 res_2 = _mm256_mul_ps(temp_float[0], scale_2); - - // Round the resultant value to the nearest integer - res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - - // Convert float32 scaled rounded value to int32 - temp_32[0] = _mm256_cvtps_epi32(res_1); - temp_32[1] = _mm256_cvtps_epi32(res_2); - - // Convert the s32 to s16 - c_int16_0p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); - - // Permute to make sure the order is correct - c_int16_0p0 = _mm256_permute4x64_epi64(c_int16_0p0, 0XD8); - 
//-------------------------------------------------------------------------- - - // Extract the first 128 bits of the register - temp[0] = _mm256_extractf128_si256(c_int16_1p0, 0); - - // Extract the second 128 bits of the register - temp[1] = _mm256_extractf128_si256(c_int16_1p0, 1); - - // Since s16 values cannot be converted to f32 directly, - // they are converted to s32, then to f32 and the scale is performed - temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); - temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); - - temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); - temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); - - // Multiply the C matrix by the scale value - res_1 = _mm256_mul_ps(temp_float[0], scale_1); - res_2 = _mm256_mul_ps(temp_float[0], scale_2); - - // Round the resultant value to the nearest integer - res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - - // Convert float32 scaled rounded value to int32 - temp_32[0] = _mm256_cvtps_epi32(res_1); - temp_32[1] = _mm256_cvtps_epi32(res_2); - - // Convert the s32 to s16 - c_int16_1p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); + (float *)post_ops_list_temp->scale_factor + + post_op_c_j + (1 * 8)); - // Permute to make sure the order is correct - c_int16_1p0 = _mm256_permute4x64_epi64(c_int16_1p0, 0XD8); - - __m256i store_reg = _mm256_packs_epi16(c_int16_0p0, c_int16_1p0); - - _mm256_storeu_si256((__m256i *)(c + (rs_c * 0) + (0 * 16)), store_reg); + bli_mm256_s16_downscale2(c_int16_0p0, c_int16_1p0, 0, 1); //-------------------------------------------------------------------------- - // Extract the first 128 bits of the register - temp[0] = _mm256_extractf128_si256(c_int16_2p0, 0); - - // Extract the second 128 bits of the register - temp[1] = _mm256_extractf128_si256(c_int16_2p0, 1); - - // Since s16 values cannot be converted to f32 directly, - // they are converted to s32, then to f32 and the scale is performed - temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); - temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); - - temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); - temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); + bli_mm256_s16_downscale2(c_int16_2p0, c_int16_3p0, 2, 3); - // Multiply the C matrix by the scale value - res_1 = _mm256_mul_ps(temp_float[0], scale_1); - res_2 = _mm256_mul_ps(temp_float[0], scale_2); - - // Round the resultant value to the nearest integer - res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - - // Convert float32 scaled rounded value to int32 - temp_32[0] = _mm256_cvtps_epi32(res_1); - temp_32[1] = _mm256_cvtps_epi32(res_2); - - // Convert the s32 to s16 - c_int16_2p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); - - // Permute to make sure the order is correct - c_int16_2p0 = _mm256_permute4x64_epi64(c_int16_2p0, 0XD8); //-------------------------------------------------------------------------- - // Extract the first 128 bits of the register - temp[0] = _mm256_extractf128_si256(c_int16_3p0, 0); - - // Extract the second 128 bits of the register - temp[1] = _mm256_extractf128_si256(c_int16_3p0, 1); - - // Since s16 values cannot be converted to f32 directly, - // they are converted to s32, then to f32 and the scale is performed - temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); - temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); - - temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); - temp_float[1] = 
_mm256_cvtepi32_ps(temp_32[1]); - - // Multiply the C matrix by the scale value - res_1 = _mm256_mul_ps(temp_float[0], scale_1); - res_2 = _mm256_mul_ps(temp_float[0], scale_2); - - // Round the resultant value to the nearest integer - res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - - // Convert float32 scaled rounded value to int32 - temp_32[0] = _mm256_cvtps_epi32(res_1); - temp_32[1] = _mm256_cvtps_epi32(res_2); - - // Convert the s32 to s16 - c_int16_3p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); - - // Permute to make sure the order is correct - c_int16_3p0 = _mm256_permute4x64_epi64(c_int16_3p0, 0XD8); - - store_reg = _mm256_packs_epi16(c_int16_2p0, c_int16_3p0); - - _mm256_storeu_si256((__m256i *)(c + (rs_c * 0) + (0 * 16)), store_reg); - POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_4x16_DISABLE: @@ -677,158 +548,27 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4xlt16) POST_OPS_DOWNSCALE_4xlt16: { __m128i temp[2]; - __m256i temp_32[2]; - __m256 temp_float[2]; - - // Load the scale vector values into the register - __m256 scale_1 = - _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_op_c_j + (0 * 8)); - __m256 scale_2 = - _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_op_c_j + (1 * 8)); - - // Extract the first 128 bits of the register - temp[0] = _mm256_extractf128_si256(c_int16_0p0, 0); - - // Extract the second 128 bits of the register - temp[1] = _mm256_extractf128_si256(c_int16_0p0, 1); - - // Since s16 values cannot be converted to f32 directly, - // they are converted to s32, then to f32 and the scale is performed - temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); - temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); - - temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); - temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); - - // Multiply the C matrix by the scale value - __m256 res_1 = _mm256_mul_ps(temp_float[0], scale_1); - __m256 res_2 = _mm256_mul_ps(temp_float[0], scale_2); - - // Round the resultant value to the nearest integer - res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - - // Convert float32 scaled rounded value to int32 - temp_32[0] = _mm256_cvtps_epi32(res_1); - temp_32[1] = _mm256_cvtps_epi32(res_2); - - // Convert the s32 to s16 - c_int16_0p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); - - // Permute to make sure the order is correct - c_int16_0p0 = _mm256_permute4x64_epi64(c_int16_0p0, 0XD8); - //-------------------------------------------------------------------------- - - // Extract the first 128 bits of the register - temp[0] = _mm256_extractf128_si256(c_int16_1p0, 0); - - // Extract the second 128 bits of the register - temp[1] = _mm256_extractf128_si256(c_int16_1p0, 1); - - // Since s16 values cannot be converted to f32 directly, - // they are converted to s32, then to f32 and the scale is performed - temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); - temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); - - temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); - temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); - - // Multiply the C matrix by the scale value - res_1 = _mm256_mul_ps(temp_float[0], scale_1); - res_2 = _mm256_mul_ps(temp_float[0], scale_2); - - // Round the resultant value to the nearest integer - res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | 
_MM_FROUND_NO_EXC)); - res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - - // Convert float32 scaled rounded value to int32 - temp_32[0] = _mm256_cvtps_epi32(res_1); - temp_32[1] = _mm256_cvtps_epi32(res_2); - - // Convert the s32 to s16 - c_int16_1p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); - - // Permute to make sure the order is correct - c_int16_1p0 = _mm256_permute4x64_epi64(c_int16_1p0, 0XD8); - - __m256i store_reg = _mm256_packs_epi16(c_int16_0p0, c_int16_1p0); - - _mm256_storeu_si256((__m256i *)(c + (rs_c * 0) + (0 * 16)), store_reg); - - //-------------------------------------------------------------------------- - - // Extract the first 128 bits of the register - temp[0] = _mm256_extractf128_si256(c_int16_2p0, 0); - - // Extract the second 128 bits of the register - temp[1] = _mm256_extractf128_si256(c_int16_2p0, 1); - - // Since s16 values cannot be converted to f32 directly, - // they are converted to s32, then to f32 and the scale is performed - temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); - temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); - - temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); - temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); - - // Multiply the C matrix by the scale value - res_1 = _mm256_mul_ps(temp_float[0], scale_1); - res_2 = _mm256_mul_ps(temp_float[0], scale_2); - - // Round the resultant value to the nearest integer - res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - - // Convert float32 scaled rounded value to int32 - temp_32[0] = _mm256_cvtps_epi32(res_1); - temp_32[1] = _mm256_cvtps_epi32(res_2); - - // Convert the s32 to s16 - c_int16_2p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); + __m256i temp_32[2]; + __m256 temp_float[2]; + __m256 scale_1, scale_2; + __m256 res_1, res_2; + __m256i store_reg; - // Permute to make sure the order is correct - c_int16_2p0 = _mm256_permute4x64_epi64(c_int16_2p0, 0XD8); - //-------------------------------------------------------------------------- - - // Extract the first 128 bits of the register - temp[0] = _mm256_extractf128_si256(c_int16_3p0, 0); - - // Extract the second 128 bits of the register - temp[1] = _mm256_extractf128_si256(c_int16_3p0, 1); - - // Since s16 values cannot be converted to f32 directly, - // they are converted to s32, then to f32 and the scale is performed - temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); - temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); - - temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); - temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); + float float_buf[16]; + int8 store_buf[16]; - // Multiply the C matrix by the scale value - res_1 = _mm256_mul_ps(temp_float[0], scale_1); - res_2 = _mm256_mul_ps(temp_float[0], scale_2); + memcpy( float_buf, ( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j ), ( n0_rem * sizeof( float ) ) ); - // Round the resultant value to the nearest integer - res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + // Load the scale vector values into the register + scale_1 = _mm256_loadu_ps(float_buf + (0 * 8)); + scale_2 = _mm256_loadu_ps(float_buf + (1 * 8)); - // Convert float32 scaled rounded value to int32 - temp_32[0] = _mm256_cvtps_epi32(res_1); - temp_32[1] = _mm256_cvtps_epi32(res_2); + bli_mm256_s16_downscale2_lt16(c_int16_0p0, c_int16_1p0, 0, 1) - // Convert the s32 to s16 - 
c_int16_3p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); + //-------------------------------------------------------------------------- - // Permute to make sure the order is correct - c_int16_3p0 = _mm256_permute4x64_epi64(c_int16_3p0, 0XD8); - - store_reg = _mm256_packs_epi16(c_int16_2p0, c_int16_3p0); - - _mm256_storeu_si256((__m256i *)(c + (rs_c * 0) + (0 * 16)), store_reg); + bli_mm256_s16_downscale2_lt16(c_int16_2p0, c_int16_3p0, 2, 3) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -1013,85 +753,21 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2x16) __m128i temp[2]; __m256i temp_32[2]; __m256 temp_float[2]; + __m256 scale_1, scale_2; + __m256 res_1, res_2; + __m256i store_reg; - // Load the scale vector values into the register - __m256 scale_1 = + /* Load the scale vector values into the register*/ + scale_1 = _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_op_c_j + (0 * 8)); - __m256 scale_2 = + (float *)post_ops_list_temp->scale_factor + + post_op_c_j + (0 * 8)); + scale_2 = _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_op_c_j + (1 * 8)); - - // Extract the first 128 bits of the register - temp[0] = _mm256_extractf128_si256(c_int16_0p0, 0); - - // Extract the second 128 bits of the register - temp[1] = _mm256_extractf128_si256(c_int16_0p0, 1); - - // Since s16 values cannot be converted to f32 directly, - // they are converted to s32, then to f32 and the scale is performed - temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); - temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); - - temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); - temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); - - // Multiply the C matrix by the scale value - __m256 res_1 = _mm256_mul_ps(temp_float[0], scale_1); - __m256 res_2 = _mm256_mul_ps(temp_float[0], scale_2); - - // Round the resultant value to the nearest integer - res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - - // Convert float32 scaled rounded value to int32 - temp_32[0] = _mm256_cvtps_epi32(res_1); - temp_32[1] = _mm256_cvtps_epi32(res_2); - - // Convert the s32 to s16 - c_int16_0p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); - - // Permute to make sure the order is correct - c_int16_0p0 = _mm256_permute4x64_epi64(c_int16_0p0, 0XD8); - //-------------------------------------------------------------------------- - - // Extract the first 128 bits of the register - temp[0] = _mm256_extractf128_si256(c_int16_1p0, 0); + (float *)post_ops_list_temp->scale_factor + + post_op_c_j + (1 * 8)); - // Extract the second 128 bits of the register - temp[1] = _mm256_extractf128_si256(c_int16_1p0, 1); - - // Since s16 values cannot be converted to f32 directly, - // they are converted to s32, then to f32 and the scale is performed - temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); - temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); - - temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); - temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); - - // Multiply the C matrix by the scale value - res_1 = _mm256_mul_ps(temp_float[0], scale_1); - res_2 = _mm256_mul_ps(temp_float[0], scale_2); - - // Round the resultant value to the nearest integer - res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - - // Convert float32 scaled rounded value to int32 - temp_32[0] = _mm256_cvtps_epi32(res_1); - temp_32[1] = 
_mm256_cvtps_epi32(res_2); - - // Convert the s32 to s16 - c_int16_1p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); - - // Permute to make sure the order is correct - c_int16_1p0 = _mm256_permute4x64_epi64(c_int16_1p0, 0XD8); - - __m256i store_reg = _mm256_packs_epi16(c_int16_0p0, c_int16_1p0); - - _mm256_storeu_si256((__m256i *)(c + (rs_c * 0) + (0 * 16)), store_reg); + bli_mm256_s16_downscale2(c_int16_0p0, c_int16_1p0, 0, 1); //-------------------------------------------------------------------------- @@ -1268,89 +944,23 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2xlt16) POST_OPS_DOWNSCALE_2xlt16: { __m128i temp[2]; - __m256i temp_32[2]; - __m256 temp_float[2]; - - // Load the scale vector values into the register - __m256 scale_1 = - _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_op_c_j + (0 * 8)); - __m256 scale_2 = - _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_op_c_j + (1 * 8)); - - // Extract the first 128 bits of the register - temp[0] = _mm256_extractf128_si256(c_int16_0p0, 0); - - // Extract the second 128 bits of the register - temp[1] = _mm256_extractf128_si256(c_int16_0p0, 1); - - // Since s16 values cannot be converted to f32 directly, - // they are converted to s32, then to f32 and the scale is performed - temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); - temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); - - temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); - temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); + __m256i temp_32[2]; + __m256 temp_float[2]; + __m256 scale_1, scale_2; + __m256 res_1, res_2; + __m256i store_reg; - // Multiply the C matrix by the scale value - __m256 res_1 = _mm256_mul_ps(temp_float[0], scale_1); - __m256 res_2 = _mm256_mul_ps(temp_float[0], scale_2); + float float_buf[16]; + int8 store_buf[16]; - // Round the resultant value to the nearest integer - res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + memcpy( float_buf, ( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j ), ( n0_rem * sizeof( float ) ) ); - // Convert float32 scaled rounded value to int32 - temp_32[0] = _mm256_cvtps_epi32(res_1); - temp_32[1] = _mm256_cvtps_epi32(res_2); + // Load the scale vector values into the register + scale_1 = _mm256_loadu_ps(float_buf + (0 * 8)); + scale_2 = _mm256_loadu_ps(float_buf + (1 * 8)); - // Convert the s32 to s16 - c_int16_0p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); - - // Permute to make sure the order is correct - c_int16_0p0 = _mm256_permute4x64_epi64(c_int16_0p0, 0XD8); - //-------------------------------------------------------------------------- - - // Extract the first 128 bits of the register - temp[0] = _mm256_extractf128_si256(c_int16_1p0, 0); - - // Extract the second 128 bits of the register - temp[1] = _mm256_extractf128_si256(c_int16_1p0, 1); - - // Since s16 values cannot be converted to f32 directly, - // they are converted to s32, then to f32 and the scale is performed - temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); - temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); - - temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); - temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); - - // Multiply the C matrix by the scale value - res_1 = _mm256_mul_ps(temp_float[0], scale_1); - res_2 = _mm256_mul_ps(temp_float[0], scale_2); - - // Round the resultant value to the nearest integer - res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - 
res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - - // Convert float32 scaled rounded value to int32 - temp_32[0] = _mm256_cvtps_epi32(res_1); - temp_32[1] = _mm256_cvtps_epi32(res_2); - - // Convert the s32 to s16 - c_int16_1p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); - - // Permute to make sure the order is correct - c_int16_1p0 = _mm256_permute4x64_epi64(c_int16_1p0, 0XD8); - - __m256i store_reg = _mm256_packs_epi16(c_int16_0p0, c_int16_1p0); - - _mm256_storeu_si256((__m256i *)(c + (rs_c * 0) + (0 * 16)), store_reg); - - //-------------------------------------------------------------------------- + bli_mm256_s16_downscale2_lt16(c_int16_0p0, c_int16_1p0, 0, 1) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -1380,7 +990,8 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1x16) &&POST_OPS_1x16_DISABLE, &&POST_OPS_BIAS_1x16, &&POST_OPS_RELU_1x16, - &&POST_OPS_RELU_SCALE_1x16 + &&POST_OPS_RELU_SCALE_1x16, + &&POST_OPS_DOWNSCALE_1x16 }; // The division is done by considering the vpmaddubsw instruction @@ -1478,6 +1089,31 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1x16) // c[0,0-15] RELU_SCALE_OP_S16_AVX2(c_int16_0p0) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR + } +POST_OPS_DOWNSCALE_1x16: + { + __m128i temp[2]; + __m256i temp_32[2]; + __m256 temp_float[2]; + __m256 scale_1, scale_2; + __m256 res_1, res_2; + __m256i store_reg; + + /* Load the scale vector values into the register*/ + scale_1 = + _mm256_loadu_ps( + (float *)post_ops_list_temp->scale_factor + + post_op_c_j + (0 * 8)); + scale_2 = + _mm256_loadu_ps( + (float *)post_ops_list_temp->scale_factor + + post_op_c_j + (1 * 8)); + + temp_32[1] = _mm256_setzero_si256(); + + bli_mm256_s16_downscale2_edge(c_int16_0p0, temp_32[1]) + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_1x16_DISABLE: @@ -1498,7 +1134,8 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1xlt16) &&POST_OPS_1xlt16_DISABLE, &&POST_OPS_BIAS_1xlt16, &&POST_OPS_RELU_1xlt16, - &&POST_OPS_RELU_SCALE_1xlt16 + &&POST_OPS_RELU_SCALE_1xlt16, + &&POST_OPS_DOWNSCALE_1xlt16 }; // The division is done by considering the vpmaddubsw instruction @@ -1604,6 +1241,29 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1xlt16) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } +POST_OPS_DOWNSCALE_1xlt16: + { + __m128i temp[2]; + __m256i temp_32[2]; + __m256 temp_float[2]; + __m256 scale_1, scale_2; + __m256 res_1, res_2; + __m256i store_reg; + + float float_buf[16]; + int8 store_buf[16]; + + memcpy( float_buf, ( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j ), ( n0_rem * sizeof( float ) ) ); + + // Load the scale vector values into the register + scale_1 = _mm256_loadu_ps(float_buf + (0 * 8)); + scale_2 = _mm256_loadu_ps(float_buf + (1 * 8)); + + temp_32[1] = _mm256_setzero_si256(); + + bli_mm256_s16_downscale2_edge_lt16(c_int16_0p0, temp_32[1]) + } POST_OPS_1xlt16_DISABLE: ; diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_n_fringe_amd256.c b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_n_fringe_amd256.c index f70d494c70..cd5a4113e5 100644 --- a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_n_fringe_amd256.c +++ b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_n_fringe_amd256.c @@ -352,226 +352,29 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x16) __m128i temp[2]; __m256i temp_32[2]; __m256 temp_float[2]; + __m256 scale_1, scale_2; + __m256 res_1, res_2; + __m256i store_reg; - // Load the scale vector values into the register - __m256 scale_1 = + /* Load the scale vector values into 
the register*/ + scale_1 = _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_op_c_j + (0 * 8)); - __m256 scale_2 = + (float *)post_ops_list_temp->scale_factor + + post_op_c_j + (0 * 8)); + scale_2 = _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_op_c_j + (1 * 8)); - - // Extract the first 128 bits of the register - temp[0] = _mm256_extractf128_si256(c_int16_0p0, 0); - - // Extract the second 128 bits of the register - temp[1] = _mm256_extractf128_si256(c_int16_0p0, 1); - - // Since s16 values cannot be converted to f32 directly, - // they are converted to s32, then to f32 and the scale is performed - temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); - temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); - - temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); - temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); - - // Multiply the C matrix by the scale value - __m256 res_1 = _mm256_mul_ps(temp_float[0], scale_1); - __m256 res_2 = _mm256_mul_ps(temp_float[0], scale_2); - - // Round the resultant value to the nearest integer - res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - - // Convert float32 scaled rounded value to int32 - temp_32[0] = _mm256_cvtps_epi32(res_1); - temp_32[1] = _mm256_cvtps_epi32(res_2); - - // Convert the s32 to s16 - c_int16_0p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); - - //-------------------------------------------------------------------------- - - // Extract the first 128 bits of the register - temp[0] = _mm256_extractf128_si256(c_int16_1p0, 0); - - // Extract the second 128 bits of the register - temp[1] = _mm256_extractf128_si256(c_int16_1p0, 1); - - // Since s16 values cannot be converted to f32 directly, - // they are converted to s32, then to f32 and the scale is performed - temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); - temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); - - temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); - temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); - - // Multiply the C matrix by the scale value - res_1 = _mm256_mul_ps(temp_float[0], scale_1); - res_2 = _mm256_mul_ps(temp_float[0], scale_2); + (float *)post_ops_list_temp->scale_factor + + post_op_c_j + (1 * 8)); - // Round the resultant value to the nearest integer - res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - - // Convert float32 scaled rounded value to int32 - temp_32[0] = _mm256_cvtps_epi32(res_1); - temp_32[1] = _mm256_cvtps_epi32(res_2); - - // Convert the s32 to s16 - c_int16_1p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); - - // Permute to make sure the order is correct - c_int16_1p0 = _mm256_permute4x64_epi64(c_int16_1p0, 0XD8); - - __m256i store_reg = _mm256_packs_epi16(c_int16_0p0, c_int16_1p0); - - _mm256_storeu_si256((__m256i *)(c + (rs_c * (ir + 0)) + (0 * 16)), store_reg); + bli_mm256_s16_downscale2(c_int16_0p0, c_int16_1p0, 0, 1); //-------------------------------------------------------------------------- - // Extract the first 128 bits of the register - temp[0] = _mm256_extractf128_si256(c_int16_2p0, 0); - - // Extract the second 128 bits of the register - temp[1] = _mm256_extractf128_si256(c_int16_2p0, 1); - - // Since s16 values cannot be converted to f32 directly, - // they are converted to s32, then to f32 and the scale is performed - temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); - temp_float[0] = 
_mm256_cvtepi32_ps(temp_32[0]); - - temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); - temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); + bli_mm256_s16_downscale2(c_int16_2p0, c_int16_3p0, 2, 3); - // Multiply the C matrix by the scale value - res_1 = _mm256_mul_ps(temp_float[0], scale_1); - res_2 = _mm256_mul_ps(temp_float[0], scale_2); - - // Round the resultant value to the nearest integer - res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - - // Convert float32 scaled rounded value to int32 - temp_32[0] = _mm256_cvtps_epi32(res_1); - temp_32[1] = _mm256_cvtps_epi32(res_2); - - // Convert the s32 to s16 - c_int16_2p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); - - // Permute to make sure the order is correct - c_int16_2p0 = _mm256_permute4x64_epi64(c_int16_2p0, 0XD8); //-------------------------------------------------------------------------- - // Extract the first 128 bits of the register - temp[0] = _mm256_extractf128_si256(c_int16_3p0, 0); - - // Extract the second 128 bits of the register - temp[1] = _mm256_extractf128_si256(c_int16_3p0, 1); - - // Since s16 values cannot be converted to f32 directly, - // they are converted to s32, then to f32 and the scale is performed - temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); - temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); - - temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); - temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); - - // Multiply the C matrix by the scale value - res_1 = _mm256_mul_ps(temp_float[0], scale_1); - res_2 = _mm256_mul_ps(temp_float[0], scale_2); - - // Round the resultant value to the nearest integer - res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - - // Convert float32 scaled rounded value to int32 - temp_32[0] = _mm256_cvtps_epi32(res_1); - temp_32[1] = _mm256_cvtps_epi32(res_2); - - // Convert the s32 to s16 - c_int16_3p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); - - // Permute to make sure the order is correct - c_int16_3p0 = _mm256_permute4x64_epi64(c_int16_3p0, 0XD8); - - store_reg = _mm256_packs_epi16(c_int16_2p0, c_int16_3p0); - - _mm256_storeu_si256((__m256i *)(c + (rs_c * (ir + 0)) + (0 * 16)), store_reg); - - //-------------------------------------------------------------------------- - - // Extract the first 128 bits of the register - temp[0] = _mm256_extractf128_si256(c_int16_4p0, 0); - - // Extract the second 128 bits of the register - temp[1] = _mm256_extractf128_si256(c_int16_4p0, 1); - - // Since s16 values cannot be converted to f32 directly, - // they are converted to s32, then to f32 and the scale is performed - temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); - temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); - - temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); - temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); - - // Multiply the C matrix by the scale value - res_1 = _mm256_mul_ps(temp_float[0], scale_1); - res_2 = _mm256_mul_ps(temp_float[0], scale_2); - - // Round the resultant value to the nearest integer - res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - - // Convert float32 scaled rounded value to int32 - temp_32[0] = _mm256_cvtps_epi32(res_1); - temp_32[1] = _mm256_cvtps_epi32(res_2); - - // Convert the s32 to s16 - c_int16_4p0 = 
_mm256_packs_epi32(temp_32[0], temp_32[1]); - - // Permute to make sure the order is correct - c_int16_4p0 = _mm256_permute4x64_epi64(c_int16_4p0, 0XD8); - - //-------------------------------------------------------------------------- - - // Extract the first 128 bits of the register - temp[0] = _mm256_extractf128_si256(c_int16_5p0, 0); - - // Extract the second 128 bits of the register - temp[1] = _mm256_extractf128_si256(c_int16_5p0, 1); - - // Since s16 values cannot be converted to f32 directly, - // they are converted to s32, then to f32 and the scale is performed - temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); - temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); - - temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); - temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); - - // Multiply the C matrix by the scale value - res_1 = _mm256_mul_ps(temp_float[0], scale_1); - res_2 = _mm256_mul_ps(temp_float[0], scale_2); - - // Round the resultant value to the nearest integer - res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - - // Convert float32 scaled rounded value to int32 - temp_32[0] = _mm256_cvtps_epi32(res_1); - temp_32[1] = _mm256_cvtps_epi32(res_2); - - // Convert the s32 to s16 - c_int16_5p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); - - // Permute to make sure the order is correct - c_int16_5p0 = _mm256_permute4x64_epi64(c_int16_5p0, 0XD8); - - store_reg = _mm256_packs_epi16(c_int16_4p0, c_int16_4p0); - - _mm256_storeu_si256((__m256i *)(c + (rs_c * (ir + 0)) + (0 * 16)), store_reg); + bli_mm256_s16_downscale2(c_int16_4p0, c_int16_5p0, 4, 5); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -987,226 +790,29 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6xlt16) __m128i temp[2]; __m256i temp_32[2]; __m256 temp_float[2]; + __m256 scale_1, scale_2; + __m256 res_1, res_2; + __m256i store_reg; - // Load the scale vector values into the register - __m256 scale_1 = - _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_op_c_j + (0 * 8)); - __m256 scale_2 = - _mm256_loadu_ps( - (float *)post_ops_list_temp->scale_factor + - post_op_c_j + (1 * 8)); - - // Extract the first 128 bits of the register - temp[0] = _mm256_extractf128_si256(c_int16_0p0, 0); - - // Extract the second 128 bits of the register - temp[1] = _mm256_extractf128_si256(c_int16_0p0, 1); - - // Since s16 values cannot be converted to f32 directly, - // they are converted to s32, then to f32 and the scale is performed - temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); - temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); - - temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); - temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); - - // Multiply the C matrix by the scale value - __m256 res_1 = _mm256_mul_ps(temp_float[0], scale_1); - __m256 res_2 = _mm256_mul_ps(temp_float[0], scale_2); - - // Round the resultant value to the nearest integer - res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - - // Convert float32 scaled rounded value to int32 - temp_32[0] = _mm256_cvtps_epi32(res_1); - temp_32[1] = _mm256_cvtps_epi32(res_2); - - // Convert the s32 to s16 - c_int16_0p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); - - //-------------------------------------------------------------------------- - - // Extract the first 128 bits of the register - temp[0] = _mm256_extractf128_si256(c_int16_1p0, 0); 
- - // Extract the second 128 bits of the register - temp[1] = _mm256_extractf128_si256(c_int16_1p0, 1); - - // Since s16 values cannot be converted to f32 directly, - // they are converted to s32, then to f32 and the scale is performed - temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); - temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); - - temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); - temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); - - // Multiply the C matrix by the scale value - res_1 = _mm256_mul_ps(temp_float[0], scale_1); - res_2 = _mm256_mul_ps(temp_float[0], scale_2); - - // Round the resultant value to the nearest integer - res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - - // Convert float32 scaled rounded value to int32 - temp_32[0] = _mm256_cvtps_epi32(res_1); - temp_32[1] = _mm256_cvtps_epi32(res_2); - - // Convert the s32 to s16 - c_int16_1p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); - - // Permute to make sure the order is correct - c_int16_1p0 = _mm256_permute4x64_epi64(c_int16_1p0, 0XD8); - - __m256i store_reg = _mm256_packs_epi16(c_int16_0p0, c_int16_1p0); - - _mm256_storeu_si256((__m256i *)(c + (rs_c * (ir + 0)) + (0 * 16)), store_reg); - - //-------------------------------------------------------------------------- - - // Extract the first 128 bits of the register - temp[0] = _mm256_extractf128_si256(c_int16_2p0, 0); - - // Extract the second 128 bits of the register - temp[1] = _mm256_extractf128_si256(c_int16_2p0, 1); - - // Since s16 values cannot be converted to f32 directly, - // they are converted to s32, then to f32 and the scale is performed - temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); - temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); - - temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); - temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); - - // Multiply the C matrix by the scale value - res_1 = _mm256_mul_ps(temp_float[0], scale_1); - res_2 = _mm256_mul_ps(temp_float[0], scale_2); - - // Round the resultant value to the nearest integer - res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + float float_buf[16]; + int8 store_buf[16]; - // Convert float32 scaled rounded value to int32 - temp_32[0] = _mm256_cvtps_epi32(res_1); - temp_32[1] = _mm256_cvtps_epi32(res_2); + memcpy( float_buf, ( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j ), ( n0_rem * sizeof( float ) ) ); - // Convert the s32 to s16 - c_int16_2p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); - - // Permute to make sure the order is correct - c_int16_2p0 = _mm256_permute4x64_epi64(c_int16_2p0, 0XD8); - //-------------------------------------------------------------------------- - - // Extract the first 128 bits of the register - temp[0] = _mm256_extractf128_si256(c_int16_3p0, 0); - - // Extract the second 128 bits of the register - temp[1] = _mm256_extractf128_si256(c_int16_3p0, 1); - - // Since s16 values cannot be converted to f32 directly, - // they are converted to s32, then to f32 and the scale is performed - temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); - temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); - - temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); - temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); - - // Multiply the C matrix by the scale value - res_1 = _mm256_mul_ps(temp_float[0], scale_1); - res_2 = _mm256_mul_ps(temp_float[0], scale_2); - - // Round the 
resultant value to the nearest integer - res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - - // Convert float32 scaled rounded value to int32 - temp_32[0] = _mm256_cvtps_epi32(res_1); - temp_32[1] = _mm256_cvtps_epi32(res_2); - - // Convert the s32 to s16 - c_int16_3p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); - - // Permute to make sure the order is correct - c_int16_3p0 = _mm256_permute4x64_epi64(c_int16_3p0, 0XD8); - - store_reg = _mm256_packs_epi16(c_int16_2p0, c_int16_3p0); + // Load the scale vector values into the register + scale_1 = _mm256_loadu_ps(float_buf + (0 * 8)); + scale_2 = _mm256_loadu_ps(float_buf + (1 * 8)); - _mm256_storeu_si256((__m256i *)(c + (rs_c * (ir + 0)) + (0 * 16)), store_reg); + bli_mm256_s16_downscale2_lt16(c_int16_0p0, c_int16_1p0, 0, 1) //-------------------------------------------------------------------------- - // Extract the first 128 bits of the register - temp[0] = _mm256_extractf128_si256(c_int16_4p0, 0); - - // Extract the second 128 bits of the register - temp[1] = _mm256_extractf128_si256(c_int16_4p0, 1); - - // Since s16 values cannot be converted to f32 directly, - // they are converted to s32, then to f32 and the scale is performed - temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); - temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); - - temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); - temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); - - // Multiply the C matrix by the scale value - res_1 = _mm256_mul_ps(temp_float[0], scale_1); - res_2 = _mm256_mul_ps(temp_float[0], scale_2); - - // Round the resultant value to the nearest integer - res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - - // Convert float32 scaled rounded value to int32 - temp_32[0] = _mm256_cvtps_epi32(res_1); - temp_32[1] = _mm256_cvtps_epi32(res_2); - - // Convert the s32 to s16 - c_int16_4p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); - - // Permute to make sure the order is correct - c_int16_4p0 = _mm256_permute4x64_epi64(c_int16_4p0, 0XD8); + bli_mm256_s16_downscale2_lt16(c_int16_2p0, c_int16_3p0, 2, 3) //-------------------------------------------------------------------------- - // Extract the first 128 bits of the register - temp[0] = _mm256_extractf128_si256(c_int16_5p0, 0); - - // Extract the second 128 bits of the register - temp[1] = _mm256_extractf128_si256(c_int16_5p0, 1); - - // Since s16 values cannot be converted to f32 directly, - // they are converted to s32, then to f32 and the scale is performed - temp_32[0] = _mm256_cvtepi16_epi32(temp[0]); - temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]); - - temp_32[1] = _mm256_cvtepi16_epi32(temp[1]); - temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]); - - // Multiply the C matrix by the scale value - res_1 = _mm256_mul_ps(temp_float[0], scale_1); - res_2 = _mm256_mul_ps(temp_float[0], scale_2); - - // Round the resultant value to the nearest integer - res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - - // Convert float32 scaled rounded value to int32 - temp_32[0] = _mm256_cvtps_epi32(res_1); - temp_32[1] = _mm256_cvtps_epi32(res_2); - - // Convert the s32 to s16 - c_int16_5p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]); - - // Permute to make sure the order is correct - c_int16_5p0 = 
_mm256_permute4x64_epi64(c_int16_5p0, 0XD8); - - store_reg = _mm256_packs_epi16(c_int16_4p0, c_int16_4p0); - - _mm256_storeu_si256((__m256i *)(c + (rs_c * (ir + 0)) + (0 * 16)), store_reg); + bli_mm256_s16_downscale2_lt16(c_int16_4p0, c_int16_5p0, 4, 5) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_s16_kern_macros.h b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_s16_kern_macros.h index 07131d1099..21876f555e 100644 --- a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_s16_kern_macros.h +++ b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_s16_kern_macros.h @@ -52,4 +52,344 @@ reg = _mm256_or_si256( b0, reg ); \ \ +//-------------------------------------------------------------------------- + +#define bli_mm256_s16_downscale(c_int16__p0, c_int16__p1, vec_loc)\ +\ + /* Extract the first 128 bits of the register*/\ + temp[0] = _mm256_extractf128_si256(c_int16__p0, 0);\ + /* Extract the second 128 bits of the register*/\ + temp[1] = _mm256_extractf128_si256(c_int16__p0, 1);\ +\ + temp_32[0] = _mm256_cvtepi16_epi32(temp[0]);\ + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]);\ +\ + /* Since s16 values cannot be converted to f32 directly, + they are converted to s32, then to f32 and the scale is performed*/\ + temp_32[1] = _mm256_cvtepi16_epi32(temp[1]);\ + temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]);\ +\ + /* Multiply the C matrix by the scale value*/\ + res_1 = _mm256_mul_ps(temp_float[0], scale_1);\ + res_2 = _mm256_mul_ps(temp_float[0], scale_2);\ +\ + /* Round the resultant value to the nearest integer*/\ + res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));\ + res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));\ +\ + /* Convert float32 scaled rounded value to int32 */\ + temp_32[0] = _mm256_cvtps_epi32(res_1);\ + temp_32[1] = _mm256_cvtps_epi32(res_2);\ +\ + /* Convert the s32 to s16 */\ + c_int16__p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]);\ +\ + /*Permute to make sure the order is correct*/\ + c_int16__p0 = _mm256_permute4x64_epi64(c_int16__p0, 0XD8);\ +\ + /* Extract the first 128 bits of the register*/\ + temp[0] = _mm256_extractf128_si256(c_int16__p1, 0);\ +\ + /* Extract the second 128 bits of the register*/\ + temp[1] = _mm256_extractf128_si256(c_int16__p1, 1);\ +\ + temp_32[0] = _mm256_cvtepi16_epi32(temp[0]);\ + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]);\ +\ + /* Since s16 values cannot be converted to f32 directly, + they are converted to s32, then to f32 and the scale is performed*/\ + temp_32[1] = _mm256_cvtepi16_epi32(temp[1]);\ + temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]);\ +\ + /* Multiply the C matrix by the scale value*/\ + res_1 = _mm256_mul_ps(temp_float[0], scale_1);\ + res_2 = _mm256_mul_ps(temp_float[0], scale_2);\ +\ + /* Round the resultant value to the nearest integer*/\ + res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));\ + res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));\ +\ + /* Convert float32 scaled rounded value to int32 */\ + temp_32[0] = _mm256_cvtps_epi32(res_1);\ + temp_32[1] = _mm256_cvtps_epi32(res_2);\ +\ + /* Convert the s32 to s16 */\ + c_int16__p1 = _mm256_packs_epi32(temp_32[0], temp_32[1]);\ +\ + /*Permute to make sure the order is correct*/\ + c_int16__p1 = _mm256_permute4x64_epi64(c_int16__p1, 0XD8);\ +\ + /* Convert the s16 to s8 */\ + store_reg = _mm256_packs_epi16(c_int16__p0, c_int16__p1);\ +\ + /* Store the result in s8 form */\ + _mm256_storeu_si256((__m256i *)(( int8_t* 
)post_ops_list_temp->op_args3 + \ + ( rs_c_downscale * ( post_op_c_i + vec_loc ) ) + post_op_c_j + ( 0 * 16 )), store_reg);\ +\ + +//-------------------------------------------------------------------------- + +#define bli_mm256_s16_downscale2(c_int16__p0, c_int16__p1, vec_loc1, vec_loc2)\ +\ + /* Extract the first 128 bits of the register*/\ + temp[0] = _mm256_extractf128_si256(c_int16__p0, 0);\ + /* Extract the second 128 bits of the register*/\ + temp[1] = _mm256_extractf128_si256(c_int16__p0, 1);\ +\ + temp_32[0] = _mm256_cvtepi16_epi32(temp[0]);\ + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]);\ +\ + /* Since s16 values cannot be converted to f32 directly, + they are converted to s32, then to f32 and the scale is performed*/\ + temp_32[1] = _mm256_cvtepi16_epi32(temp[1]);\ + temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]);\ +\ + /* Multiply the C matrix by the scale value*/\ + res_1 = _mm256_mul_ps(temp_float[0], scale_1);\ + res_2 = _mm256_mul_ps(temp_float[0], scale_2);\ +\ + /* Round the resultant value to the nearest integer*/\ + res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));\ + res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));\ +\ + /* Convert float32 scaled rounded value to int32 */\ + temp_32[0] = _mm256_cvtps_epi32(res_1);\ + temp_32[1] = _mm256_cvtps_epi32(res_2);\ +\ + /* Convert the s32 to s16 */\ + c_int16__p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]);\ +\ + /*Permute to make sure the order is correct*/\ + c_int16__p0 = _mm256_permute4x64_epi64(c_int16__p0, 0XD8);\ +\ + /* Extract the first 128 bits of the register*/\ + temp[0] = _mm256_extractf128_si256(c_int16__p1, 0);\ +\ + /* Extract the second 128 bits of the register*/\ + temp[1] = _mm256_extractf128_si256(c_int16__p1, 1);\ +\ + temp_32[0] = _mm256_cvtepi16_epi32(temp[0]);\ + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]);\ +\ + /* Since s16 values cannot be converted to f32 directly, + they are converted to s32, then to f32 and the scale is performed*/\ + temp_32[1] = _mm256_cvtepi16_epi32(temp[1]);\ + temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]);\ +\ + /* Multiply the C matrix by the scale value*/\ + res_1 = _mm256_mul_ps(temp_float[0], scale_1);\ + res_2 = _mm256_mul_ps(temp_float[0], scale_2);\ +\ + /* Round the resultant value to the nearest integer*/\ + res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));\ + res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));\ +\ + /* Convert float32 scaled rounded value to int32 */\ + temp_32[0] = _mm256_cvtps_epi32(res_1);\ + temp_32[1] = _mm256_cvtps_epi32(res_2);\ +\ + /* Convert the s32 to s16 */\ + c_int16__p1 = _mm256_packs_epi32(temp_32[0], temp_32[1]);\ +\ + /*Permute to make sure the order is correct*/\ + c_int16__p1 = _mm256_permute4x64_epi64(c_int16__p1, 0XD8);\ +\ + /* Convert the s16 to s8 */\ + store_reg = _mm256_packs_epi16(c_int16__p0, c_int16__p1);\ + /* Extract the first 128 bits of the register*/\ + temp[0] = _mm256_extractf128_si256(store_reg, 0);\ + /* Extract the second 128 bits of the register*/\ + temp[1] = _mm256_extractf128_si256(store_reg, 1);\ +\ + /* Store the result in s8 form */\ + _mm_storeu_si128((__m128i *)(( int8_t* )post_ops_list_temp->op_args3 + \ + ( rs_c_downscale * ( post_op_c_i + vec_loc1 ) ) + post_op_c_j + ( 0 * 16 )), temp[0]);\ + _mm_storeu_si128((__m128i *)(( int8_t* )post_ops_list_temp->op_args3 + \ + ( rs_c_downscale * ( post_op_c_i + vec_loc2 ) ) + post_op_c_j + ( 0 * 16 )), temp[1]);\ +\ + 
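For reference, the per-element arithmetic that bli_mm256_s16_downscale and bli_mm256_s16_downscale2 vectorize can be written as a small scalar helper. The sketch below is illustrative only and not part of the patch (the helper name is invented); it assumes round-to-nearest for the scaled float and the same signed saturation that the _mm256_packs_epi32/_mm256_packs_epi16 chain applies on the way down to s8.

#include <math.h>
#include <stdint.h>

/* Scalar sketch: downscale one s16 accumulator to s8 using its
 * per-column float scale factor. */
static inline int8_t s16_downscale_elem_ref( int16_t acc, float scale )
{
	/* s16 -> s32 -> f32, then apply the scale factor. */
	float scaled = ( float )( ( int32_t )acc ) * scale;

	/* Round to the nearest integer (the vector code uses
	 * _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC). */
	float rounded = nearbyintf( scaled );

	/* Saturate to the s8 range, mirroring the saturating packs. */
	if ( rounded > 127.0f ) rounded = 127.0f;
	if ( rounded < -128.0f ) rounded = -128.0f;

	return ( int8_t )rounded;
}

In the macros, sixteen such elements are handled per register pair, with scale_1 and scale_2 holding the scale factors for the lower and upper eight output columns.
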
+//-------------------------------------------------------------------------- + +#define bli_mm256_s16_downscale2_lt16(c_int16__p0, c_int16__p1, vec_loc1, vec_loc2)\ +\ + /* Extract the first 128 bits of the register*/\ + temp[0] = _mm256_extractf128_si256(c_int16__p0, 0);\ + /* Extract the second 128 bits of the register*/\ + temp[1] = _mm256_extractf128_si256(c_int16__p0, 1);\ +\ + temp_32[0] = _mm256_cvtepi16_epi32(temp[0]);\ + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]);\ +\ + /* Since s16 values cannot be converted to f32 directly, + they are converted to s32, then to f32 and the scale is performed*/\ + temp_32[1] = _mm256_cvtepi16_epi32(temp[1]);\ + temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]);\ +\ + /* Multiply the C matrix by the scale value*/\ + res_1 = _mm256_mul_ps(temp_float[0], scale_1);\ + res_2 = _mm256_mul_ps(temp_float[0], scale_2);\ +\ + /* Round the resultant value to the nearest integer*/\ + res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));\ + res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));\ +\ + /* Convert float32 scaled rounded value to int32 */\ + temp_32[0] = _mm256_cvtps_epi32(res_1);\ + temp_32[1] = _mm256_cvtps_epi32(res_2);\ +\ + /* Convert the s32 to s16 */\ + c_int16__p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]);\ +\ + /*Permute to make sure the order is correct*/\ + c_int16__p0 = _mm256_permute4x64_epi64(c_int16__p0, 0XD8);\ +\ + /* Extract the first 128 bits of the register*/\ + temp[0] = _mm256_extractf128_si256(c_int16__p1, 0);\ +\ + /* Extract the second 128 bits of the register*/\ + temp[1] = _mm256_extractf128_si256(c_int16__p1, 1);\ +\ + temp_32[0] = _mm256_cvtepi16_epi32(temp[0]);\ + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]);\ +\ + /* Since s16 values cannot be converted to f32 directly, + they are converted to s32, then to f32 and the scale is performed*/\ + temp_32[1] = _mm256_cvtepi16_epi32(temp[1]);\ + temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]);\ +\ + /* Multiply the C matrix by the scale value*/\ + res_1 = _mm256_mul_ps(temp_float[0], scale_1);\ + res_2 = _mm256_mul_ps(temp_float[0], scale_2);\ +\ + /* Round the resultant value to the nearest integer*/\ + res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));\ + res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));\ +\ + /* Convert float32 scaled rounded value to int32 */\ + temp_32[0] = _mm256_cvtps_epi32(res_1);\ + temp_32[1] = _mm256_cvtps_epi32(res_2);\ +\ + /* Convert the s32 to s16 */\ + c_int16__p1 = _mm256_packs_epi32(temp_32[0], temp_32[1]);\ +\ + /*Permute to make sure the order is correct*/\ + c_int16__p1 = _mm256_permute4x64_epi64(c_int16__p1, 0XD8);\ +\ + /* Convert the s16 to s8 */\ + store_reg = _mm256_packs_epi16(c_int16__p0, c_int16__p1);\ + /* Extract the first 128 bits of the register*/\ + temp[0] = _mm256_extractf128_si256(store_reg, 0);\ + /* Extract the second 128 bits of the register*/\ + temp[1] = _mm256_extractf128_si256(store_reg, 1);\ +\ + /* Store the result in s8 form */\ + _mm_storeu_si128((__m128i *)store_buf, temp[0]);\ + memcpy( ( int8_t* )post_ops_list_temp->op_args3 + \ + ( rs_c_downscale * ( post_op_c_i + vec_loc1 ) ) + post_op_c_j + \ + ( 0 * 16 ) , store_buf, ( n0_rem * sizeof( int8_t ) ) ); \ +\ + _mm_storeu_si128((__m128i *)store_buf, temp[1]);\ + memcpy( ( int8_t* )post_ops_list_temp->op_args3 + \ + ( rs_c_downscale * ( post_op_c_i + vec_loc1 ) ) + post_op_c_j + \ + ( 0 * 16 ) , store_buf, ( n0_rem * sizeof( int8_t ) ) ); \ 
+\ + +//-------------------------------------------------------------------------- + +#define bli_mm256_s16_downscale2_edge(c_int16__p0, c_int16__p1)\ +\ + /* Extract the first 128 bits of the register*/\ + temp[0] = _mm256_extractf128_si256(c_int16__p0, 0);\ + /* Extract the second 128 bits of the register*/\ + temp[1] = _mm256_extractf128_si256(c_int16__p0, 1);\ +\ + temp_32[0] = _mm256_cvtepi16_epi32(temp[0]);\ + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]);\ +\ + /* Since s16 values cannot be converted to f32 directly, + they are converted to s32, then to f32 and the scale is performed*/\ + temp_32[1] = _mm256_cvtepi16_epi32(temp[1]);\ + temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]);\ +\ + /* Multiply the C matrix by the scale value*/\ + res_1 = _mm256_mul_ps(temp_float[0], scale_1);\ + res_2 = _mm256_mul_ps(temp_float[0], scale_2);\ +\ + /* Round the resultant value to the nearest integer*/\ + res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));\ + res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));\ +\ + /* Convert float32 scaled rounded value to int32 */\ + temp_32[0] = _mm256_cvtps_epi32(res_1);\ + temp_32[1] = _mm256_cvtps_epi32(res_2);\ +\ + /* Convert the s32 to s16 */\ + c_int16__p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]);\ +\ + /*Permute to make sure the order is correct*/\ + c_int16__p0 = _mm256_permute4x64_epi64(c_int16__p0, 0XD8);\ +\ + /* Convert the s32 to s16 */\ + c_int16__p1 = _mm256_packs_epi32(temp_32[0], temp_32[1]);\ +\ + /*Permute to make sure the order is correct*/\ + c_int16__p1 = _mm256_permute4x64_epi64(c_int16__p1, 0XD8);\ +\ + /* Convert the s16 to s8 */\ + store_reg = _mm256_packs_epi16(c_int16__p0, c_int16__p1);\ + /* Extract the first 128 bits of the register*/\ + temp[0] = _mm256_extractf128_si256(store_reg, 0);\ +\ + /* Store the result in s8 form */\ + _mm_storeu_si128((__m128i *)(( int8_t* )post_ops_list_temp->op_args3 + \ + ( rs_c_downscale * ( post_op_c_i + 0 ) ) + post_op_c_j + ( 0 * 16 )), temp[0]);\ +\ + +//-------------------------------------------------------------------------- + +#define bli_mm256_s16_downscale2_edge_lt16(c_int16__p0, c_int16__p1)\ +\ + /* Extract the first 128 bits of the register*/\ + temp[0] = _mm256_extractf128_si256(c_int16__p0, 0);\ + /* Extract the second 128 bits of the register*/\ + temp[1] = _mm256_extractf128_si256(c_int16__p0, 1);\ +\ + temp_32[0] = _mm256_cvtepi16_epi32(temp[0]);\ + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]);\ +\ + /* Since s16 values cannot be converted to f32 directly, + they are converted to s32, then to f32 and the scale is performed*/\ + temp_32[1] = _mm256_cvtepi16_epi32(temp[1]);\ + temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]);\ +\ + /* Multiply the C matrix by the scale value*/\ + res_1 = _mm256_mul_ps(temp_float[0], scale_1);\ + res_2 = _mm256_mul_ps(temp_float[0], scale_2);\ +\ + /* Round the resultant value to the nearest integer*/\ + res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));\ + res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));\ +\ + /* Convert float32 scaled rounded value to int32 */\ + temp_32[0] = _mm256_cvtps_epi32(res_1);\ + temp_32[1] = _mm256_cvtps_epi32(res_2);\ +\ + /* Convert the s32 to s16 */\ + c_int16__p0 = _mm256_packs_epi32(temp_32[0], temp_32[1]);\ +\ + /*Permute to make sure the order is correct*/\ + c_int16__p0 = _mm256_permute4x64_epi64(c_int16__p0, 0XD8);\ +\ + /* Convert the s16 to s8 */\ + store_reg = 
_mm256_packs_epi16(c_int16__p0, c_int16__p1);\ + /* Extract the first 128 bits of the register*/\ + temp[0] = _mm256_extractf128_si256(store_reg, 0);\ +\ + /* Store the result in s8 form */\ + _mm_storeu_si128((__m128i *)store_buf, temp[0]);\ + memcpy( ( int8_t* )post_ops_list_temp->op_args3 + \ + ( rs_c_downscale * ( post_op_c_i + 0 ) ) + post_op_c_j + \ + ( 0 * 16 ) , store_buf, ( n0_rem * sizeof( int8_t ) ) ); \ +\ + #endif //LPGEMM_S16_KERN_MACROS_H From 0dfcdac9cf5190facac5f54eedb408e6cf6e786c Mon Sep 17 00:00:00 2001 From: mkadavil Date: Wed, 7 Sep 2022 17:53:39 +0530 Subject: [PATCH 218/243] Low Precision GEMM framework fixes for downscaling. - The temporary buffer allocated for C matrix when downscaling is enabled is not filled properly. This results in wrong gemm accumulation when beta != 0, and thus wrong output after downscaling. The C panel iterators used for filling the temporary buffer are updated to fix it. - Low precision gemm bench updated for testing/benchmarking downscaling. AMD-Internal: [CPUPL-2514] Change-Id: Ib1ba25ba9df2d2997edaaf0763ff0113fb35d6eb --- .../aocl_gemm/frame/bf16bf16f32/lpgemm_bf16.c | 29 ++- .../aocl_gemm/frame/u8s8s16/lpgemm_u8s8s16.c | 21 +- .../aocl_gemm/frame/u8s8s32/lpgemm_u8s8s32.c | 20 +- bench/bench_aocl_gemm/bench_lpgemm.c | 246 ++++++++++++------ 4 files changed, 225 insertions(+), 91 deletions(-) diff --git a/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_bf16.c b/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_bf16.c index 1a092377e7..a92c70c8fc 100644 --- a/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_bf16.c +++ b/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_bf16.c @@ -137,14 +137,25 @@ LPGEMM_5LOOP(bfloat16,bfloat16,float,bf16bf16f32of32) { dim_t i_temp = 0; dim_t j_temp = 0; + int32_t temp_conv_buf = 0; // Upscale out C to temporary C matrix. for ( dim_t i_dscale = ic_start; i_dscale < ic_end; ++i_dscale ) { j_temp = 0; for ( dim_t j_dscale = jc; j_dscale < nc0; ++j_dscale ) { - *( temp_scal_c_buffer_bf16 + ( nc0 * i_temp ) + j_temp ) = - ( float )( *( c + ( rs_c * i_dscale ) + j_dscale ) ); + // Implemented with the idea sizeof(float)=4. + temp_conv_buf = 0; + temp_conv_buf = *( ( int16_t* )( ( bfloat16* )c + + ( rs_c * i_dscale ) + j_dscale ) ); + + // Add 16 bits in the fractional part. + temp_conv_buf = temp_conv_buf << 16; + + // Store the bytes in float format. + *( temp_scal_c_buffer_bf16 + ( nc0 * i_temp ) + j_temp ) + = *( ( float* )( &temp_conv_buf ) ); + j_temp++; } i_temp++; @@ -193,7 +204,7 @@ LPGEMM_5LOOP(bfloat16,bfloat16,float,bf16bf16f32of32) ); thread->comm[jc_work_id].sent_object = - bli_mem_buffer( &mem_b ); + bli_mem_buffer( &mem_b ); } // All threads in work group should wait till chief thread has @@ -264,7 +275,17 @@ LPGEMM_5LOOP(bfloat16,bfloat16,float,bf16bf16f32of32) for ( dim_t ic = ic_start; ic < ic_end; ic += MC ) { dim_t mc0 = bli_min( ( ic_end - ic ), MC ); - c_use_ic = c_use_jc + ( rs_c_use * ic ); + + // Only per thread C matrix is stored in temp buffer, so both + // per thread jc and ic start should be normalized to zero. 
+ if ( c_downscale == TRUE ) + { + c_use_ic = c_use_jc + ( rs_c_use * ( ic - ic_start ) ); + } + else + { + c_use_ic = c_use_jc + ( rs_c_use * ic ); + } if ( mtag_a == UNPACKED ) { diff --git a/addon/aocl_gemm/frame/u8s8s16/lpgemm_u8s8s16.c b/addon/aocl_gemm/frame/u8s8s16/lpgemm_u8s8s16.c index 0465abce66..b8f5115429 100644 --- a/addon/aocl_gemm/frame/u8s8s16/lpgemm_u8s8s16.c +++ b/addon/aocl_gemm/frame/u8s8s16/lpgemm_u8s8s16.c @@ -143,10 +143,13 @@ LPGEMM_5LOOP(uint8_t,int8_t,int16_t,u8s8s16o16) for ( dim_t i_dscale = ic_start; i_dscale < ic_end; ++i_dscale ) { j_temp = 0; - for ( dim_t j_dscale = jc; j_dscale < nc0; ++j_dscale ) + for ( dim_t j_dscale = jc; j_dscale < ( jc + nc0 ); ++j_dscale ) { - *( temp_scal_c_buffer_u8s8s16o16 + ( nc0 * i_dscale ) + j_dscale ) = - ( int16_t )( *( c + ( rs_c * i_dscale ) + j_dscale ) ); + *( temp_scal_c_buffer_u8s8s16o16 + + ( nc0 * i_temp ) + j_temp ) = + ( int16_t )( *( ( ( int8_t* )c ) + + ( rs_c * i_dscale ) + j_dscale ) ); + j_temp++; } i_temp++; @@ -268,7 +271,17 @@ LPGEMM_5LOOP(uint8_t,int8_t,int16_t,u8s8s16o16) for (dim_t ic = ic_start; ic < ic_end; ic += MC) { dim_t mc0 = bli_min((ic_end - ic), MC); - c_use_ic = c_use_jc + (rs_c_use * ic); + + // Only per thread C matrix is stored in temp buffer, so both + // per thread jc and ic start should be normalized to zero. + if ( c_downscale == TRUE ) + { + c_use_ic = c_use_jc + ( rs_c_use * ( ic - ic_start ) ); + } + else + { + c_use_ic = c_use_jc + ( rs_c_use * ic ); + } a_use = a + (rs_a * ic) + (cs_a * pc); cs_a_use = 1; diff --git a/addon/aocl_gemm/frame/u8s8s32/lpgemm_u8s8s32.c b/addon/aocl_gemm/frame/u8s8s32/lpgemm_u8s8s32.c index 019f081860..a62d42a1f1 100644 --- a/addon/aocl_gemm/frame/u8s8s32/lpgemm_u8s8s32.c +++ b/addon/aocl_gemm/frame/u8s8s32/lpgemm_u8s8s32.c @@ -154,10 +154,13 @@ LPGEMM_5LOOP(uint8_t,int8_t,int32_t,u8s8s32o32) for ( dim_t i_dscale = ic_start; i_dscale < ic_end; ++i_dscale ) { j_temp = 0; - for ( dim_t j_dscale = jc; j_dscale < nc0; ++j_dscale ) + for ( dim_t j_dscale = jc; j_dscale < ( jc + nc0 ); ++j_dscale ) { - *( temp_scal_c_buffer_u8s8s32o32 + ( nc0 * i_temp ) + j_temp ) = - ( int32_t )( *( c + ( rs_c * i_dscale ) + j_dscale ) ); + *( temp_scal_c_buffer_u8s8s32o32 + + ( nc0 * i_temp ) + j_temp ) = + ( int32_t )( *( ( ( int8_t* )c ) + + ( rs_c * i_dscale ) + j_dscale ) ); + j_temp++; } i_temp++; @@ -281,7 +284,16 @@ LPGEMM_5LOOP(uint8_t,int8_t,int32_t,u8s8s32o32) { dim_t mc0 = bli_min( ( ic_end - ic ), MC ); - c_use_ic = c_use_jc + ( rs_c_use * ic ); + // Only per thread C matrix is stored in temp buffer, so both + // per thread jc and ic start should be normalized to zero. + if ( c_downscale == TRUE ) + { + c_use_ic = c_use_jc + ( rs_c_use * ( ic - ic_start ) ); + } + else + { + c_use_ic = c_use_jc + ( rs_c_use * ic ); + } // Matrix A packed and reordered code path is not triggerred // currently since we do not support it yet. 
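
For illustration, the hunks above all apply one indexing rule: the per-thread scratch C panel used for downscaling is nc0 elements wide and starts at (0,0), whereas the user's C matrix is addressed from (ic_start, jc), so both iterators must be normalized before touching the scratch buffer. The condensed sketch below only restates what the patch already does; the helper name fill_downscale_scratch and the float element type are placeholders for this illustration, not part of the change.

/* Sketch: fill the per-thread scratch C panel prior to downscaling.
   dim_t is the BLIS index type from blis.h. */
static void fill_downscale_scratch
     (
       float*       scratch,  /* nc0-wide per-thread scratch buffer */
       const float* c,        /* user C matrix, row stride rs_c */
       dim_t        ic_start,
       dim_t        ic_end,
       dim_t        jc,
       dim_t        nc0,
       dim_t        rs_c
     )
{
	for ( dim_t i = ic_start; i < ic_end; ++i )
	{
		for ( dim_t j = jc; j < ( jc + nc0 ); ++j )
		{
			/* Scratch indices are zero based; C indices are global. */
			scratch[ ( ( i - ic_start ) * nc0 ) + ( j - jc ) ] =
			    *( c + ( rs_c * i ) + j );
		}
	}
}
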
diff --git a/bench/bench_aocl_gemm/bench_lpgemm.c b/bench/bench_aocl_gemm/bench_lpgemm.c index 67a9277a5a..3484dffb3d 100644 --- a/bench/bench_aocl_gemm/bench_lpgemm.c +++ b/bench/bench_aocl_gemm/bench_lpgemm.c @@ -16,6 +16,8 @@ char bench_mode = 'p'; int32_t global_n_repeat = 0; +char global_dscale_out = 'n'; + #define _XSTR(str) #str #define XSTR(str) _XSTR(str) @@ -27,7 +29,7 @@ void fill_array_ ## ctype ( void* arr, dim_t size ) \ ctype* temp_arr = ( ctype* ) arr; \ for ( dim_t i = 0; i < size; ++i ) \ { \ - temp_arr[i] = ( ctype )( i % 100 ); \ + temp_arr[i] = ( ctype )( i % 10 ); \ } \ } \ @@ -150,8 +152,10 @@ void mat_mul_ ## BLAS_SFX \ } */\ } \ -GEN_BLIS_MAT_MUL_FUNC(uint8_t, int8_t, int16_t, u8s8s16os16) +GEN_BLIS_MAT_MUL_FUNC(uint8_t,int8_t,int16_t,u8s8s16os16) +GEN_BLIS_MAT_MUL_FUNC(uint8_t,int8_t,int8_t,u8s8s16os8) GEN_BLIS_MAT_MUL_FUNC(uint8_t,int8_t,int32_t,u8s8s32os32) +GEN_BLIS_MAT_MUL_FUNC(uint8_t,int8_t,int8_t,u8s8s32os8) GEN_BLIS_MAT_MUL_FUNC(float,float,float,f32f32f32of32) double get_gflops @@ -236,11 +240,13 @@ void mat_mul_bench_driver_ ## BLAS_SFX \ print_result( XSTR(BLAS_SFX), n_repeats, m, n, k, lda, ldb, ldc, min_time_diff); \ } \ -GEN_MAT_MUL_BENCH_DRV_FUNC(uint8_t, int8_t, int16_t, u8s8s16os16) +GEN_MAT_MUL_BENCH_DRV_FUNC(uint8_t,int8_t,int16_t,u8s8s16os16) +GEN_MAT_MUL_BENCH_DRV_FUNC(uint8_t,int8_t,int8_t,u8s8s16os8) GEN_MAT_MUL_BENCH_DRV_FUNC(uint8_t,int8_t,int32_t,u8s8s32os32) +GEN_MAT_MUL_BENCH_DRV_FUNC(uint8_t,int8_t,int8_t,u8s8s32os8) GEN_MAT_MUL_BENCH_DRV_FUNC(float,float,float,f32f32f32of32) -#define GEN_MAT_MUL_ACC_CHK_DRV_FUNC(A_type,B_type,C_type,BLAS_SFX) \ +#define GEN_MAT_MUL_ACC_CHK_DRV_FUNC(A_type,B_type,C_type,DSCALE_type,SCALE_type,BLAS_SFX) \ void mat_mul_accuracy_check_driver_ ## BLAS_SFX \ ( \ FILE* fout, \ @@ -264,7 +270,8 @@ void mat_mul_accuracy_check_driver_ ## BLAS_SFX \ { \ for ( dim_t j = 0; j < n; ++j ) \ { \ - C_type temp_accum = 0; \ + DSCALE_type temp_accum = 0; \ + C_type out_temp_accum = 0; \ \ for ( dim_t p = 0; p < k; ++p) \ { \ @@ -281,15 +288,17 @@ void mat_mul_accuracy_check_driver_ ## BLAS_SFX \ { \ if ( post_op->seq_length >= 1 ) \ { \ - temp_accum += ( *( ( C_type* )post_op->bias.bias + j ) ); \ + temp_accum += ( *( ( DSCALE_type* )post_op->bias.bias + j ) ); \ } \ - if ( post_op->seq_length > 1 ) \ + if ( ( post_op->seq_length > 1 ) && \ + ( post_op->seq_vector[1] == ELTWISE ) ) \ { \ if ( post_op->eltwise.algo.alpha != NULL ) /* PReLU*/ \ { \ temp_accum = ( temp_accum > 0 ) ? \ - temp_accum : \ - ( temp_accum * *( ( C_type* ) post_op->eltwise.algo.alpha ) ); \ + temp_accum : \ + ( temp_accum * \ + *( ( DSCALE_type* ) post_op->eltwise.algo.alpha ) ); \ } \ else \ { \ @@ -305,21 +314,30 @@ void mat_mul_accuracy_check_driver_ ## BLAS_SFX \ { \ temp_accum = ( temp_accum > 0 ) ? \ temp_accum : \ - ( temp_accum * *( ( C_type* ) post_op->eltwise.algo.alpha ) ); \ + ( temp_accum * *( ( DSCALE_type* ) post_op->eltwise.algo.alpha ) ); \ } \ else \ { \ temp_accum = ( temp_accum > 0 ) ? 
temp_accum : 0 ; \ } \ } \ - if ( post_op->seq_length > 1 ) \ + if ( ( post_op->seq_length > 1 ) && ( post_op->seq_vector[1] == BIAS ) ) \ { \ - temp_accum += ( *( ( C_type* )post_op->bias.bias + j ) ); \ + temp_accum += ( *( ( DSCALE_type* )post_op->bias.bias + j ) ); \ } \ } \ } \ + if ( global_dscale_out == 'y' ) \ + { \ + out_temp_accum = ( C_type )lroundf( ( SCALE_type )temp_accum * \ + ( *( ( SCALE_type* )post_op->sum.scale_factor + j ) ) ); \ + } \ + else \ + { \ + out_temp_accum = ( C_type )temp_accum; \ + } \ \ - if ( *( c + ( ldc * i ) + j ) != temp_accum ) \ + if ( *( c + ( ldc * i ) + j ) != out_temp_accum ) \ { \ if ( fout ) \ { \ @@ -328,7 +346,7 @@ void mat_mul_accuracy_check_driver_ ## BLAS_SFX \ XSTR(BLAS_SFX), m, n, k, lda, ldb, ldc ); \ fflush( fout ); \ } \ - printf("failure, m: %ld, n: %ld, k: %ld\n", i, j, k); \ + printf("failure, m: %ld, n: %ld, k: %ld\n", i, j, k ); \ goto cleanup_acc; \ } \ } \ @@ -337,12 +355,14 @@ cleanup_acc: \ return; \ } \ -GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t, int8_t, int16_t, u8s8s16os16) -GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,int32_t,u8s8s32os32) -GEN_MAT_MUL_ACC_CHK_DRV_FUNC(float,float,float,f32f32f32of32) +GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,int16_t,int16_t,float,u8s8s16os16) +GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,int8_t,int16_t,float,u8s8s16os8) +GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,int32_t,int32_t,float,u8s8s32os32) +GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,int8_t,int32_t,float,u8s8s32os8) +GEN_MAT_MUL_ACC_CHK_DRV_FUNC(float,float,float,float,float,f32f32f32of32) /* Only supports bias followed by RELU and vice versa for now.*/ \ -#define GEN_MAT_MUL_POST_OPS_CREATOR(C_type,BLAS_SFX) \ +#define GEN_MAT_MUL_POST_OPS_CREATOR(C_type,DSCALE_type,BLAS_SFX) \ aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \ ( \ dim_t m, \ @@ -353,13 +373,13 @@ aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \ aocl_post_op* post_ops = NULL; \ post_ops = ( aocl_post_op* ) malloc( sizeof( aocl_post_op ) ); \ \ - if ( post_ops == NULL ) \ + if ( ( post_ops == NULL ) && ( global_dscale_out == 'n' ) ) \ { \ return NULL; \ } \ \ - /* Only supporting 2 post ops at max for now.*/ \ - dim_t max_post_ops_seq_length = 2; \ + /* Only supporting 3 post ops at max for now.*/ \ + dim_t max_post_ops_seq_length = 3; \ post_ops->seq_vector = ( AOCL_POST_OP_TYPE* ) \ malloc \ ( \ @@ -374,58 +394,94 @@ aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \ } \ \ /* Parse post ops list.*/ \ - char* ops_tok = strtok(post_ops_str, ", " ); \ - bool is_param_relu = FALSE; \ dim_t cur_op_index = 0; \ - while ( ops_tok ) \ + /* Ensure the buffers that use NULL check in deinit code is properly set to NULL.*/ \ + post_ops->eltwise.algo.alpha = NULL; \ + post_ops->bias.bias = NULL; \ + post_ops->sum.scale_factor = NULL; \ + if ( post_ops_str != NULL ) \ { \ - if ( strcmp( ops_tok, "bias") == 0 ) \ + char* ops_tok = strtok(post_ops_str, ", " ); \ + bool is_param_relu = FALSE; \ + while ( ops_tok ) \ { \ - post_ops->seq_vector[cur_op_index] = BIAS; \ + if ( strcmp( ops_tok, "bias") == 0 ) \ + { \ + post_ops->seq_vector[cur_op_index] = BIAS; \ + } \ + else if ( strcmp( ops_tok, "relu") == 0 ) \ + { \ + post_ops->seq_vector[cur_op_index] = ELTWISE; \ + } \ + else if ( strcmp( ops_tok, "prelu") == 0 ) \ + { \ + post_ops->seq_vector[cur_op_index] = ELTWISE; \ + is_param_relu = TRUE; \ + } \ + ops_tok = strtok( NULL, ", " ); \ + cur_op_index++; \ } \ - else if ( strcmp( ops_tok, "relu") == 0 ) \ + \ + /* Allocate bias buffer, return early if 
alloc fails.*/ \ + post_ops->bias.bias = malloc( n * sizeof( C_type ) ); \ + if ( post_ops->bias.bias == NULL ) \ { \ - post_ops->seq_vector[cur_op_index] = ELTWISE; \ + free( post_ops->seq_vector ); \ + free( post_ops ); \ + return NULL; \ } \ - else if ( strcmp( ops_tok, "prelu") == 0 ) \ + GEN_FUNC_NAME(fill_array_post_ops_,C_type)( post_ops->bias.bias, n ); \ + \ + post_ops->eltwise.is_power_of_2 = FALSE; \ + post_ops->eltwise.scale_factor = NULL; \ + post_ops->eltwise.algo.alpha = NULL; \ + post_ops->eltwise.algo.algo_type = RELU; \ + if ( is_param_relu == TRUE ) \ { \ - post_ops->seq_vector[cur_op_index] = ELTWISE; \ - is_param_relu = TRUE; \ + post_ops->eltwise.algo.alpha = malloc( sizeof( C_type ) ); \ + *( ( C_type* ) post_ops->eltwise.algo.alpha ) = ( C_type )6; \ + post_ops->eltwise.algo.algo_type = PRELU; \ } \ - ops_tok = strtok( NULL, ", " ); \ - cur_op_index++; \ + post_ops->eltwise.algo.beta = NULL; \ } \ - post_ops->seq_length = cur_op_index; \ \ - /* Allocate bias buffer, return early if alloc fails.*/ \ - post_ops->bias.bias = malloc( n * sizeof( C_type ) ); \ - if ( post_ops->bias.bias == NULL ) \ + if ( global_dscale_out == 'y' ) \ { \ - free( post_ops->seq_vector ); \ - free( post_ops ); \ - return NULL; \ - } \ - \ - GEN_FUNC_NAME(fill_array_post_ops_,C_type)( post_ops->bias.bias, n ); \ + post_ops->seq_vector[cur_op_index] = SCALE; \ + cur_op_index++; \ \ - post_ops->eltwise.is_power_of_2 = FALSE; \ - post_ops->eltwise.scale_factor = NULL; \ - post_ops->eltwise.algo.alpha = NULL; \ - post_ops->eltwise.algo.algo_type = RELU; \ - if ( is_param_relu == TRUE ) \ - { \ - post_ops->eltwise.algo.alpha = malloc( sizeof( C_type ) ); \ - *( ( C_type* ) post_ops->eltwise.algo.alpha ) = ( C_type )6; \ - post_ops->eltwise.algo.algo_type = PRELU; \ + post_ops->sum.is_power_of_2 = FALSE; \ + post_ops->sum.scale_factor = NULL; \ + post_ops->sum.buff = NULL; \ + post_ops->sum.zero_point = NULL; \ + if ( global_dscale_out == 'y' ) \ + { \ + /* Allocate scale buffer, return early if alloc fails.*/ \ + post_ops->sum.scale_factor = malloc( n * sizeof( DSCALE_type ) ); \ + if ( post_ops->sum.scale_factor == NULL ) \ + { \ + free ( post_ops->bias.bias ); \ + free( post_ops->seq_vector ); \ + free( post_ops ); \ + return NULL; \ + } \ + /* Fill scale factor.*/ \ + DSCALE_type* temp_dscale_ptr = ( DSCALE_type* )post_ops->sum.scale_factor; \ + for ( dim_t i = 0; i < n; ++i ) \ + { \ + temp_dscale_ptr[i] = ( ( DSCALE_type )1 )/ ( ( DSCALE_type )1000 ); \ + } \ + } \ } \ - post_ops->eltwise.algo.beta = NULL; \ + \ + post_ops->seq_length = cur_op_index; \ \ return post_ops; \ } \ -GEN_MAT_MUL_POST_OPS_CREATOR(int16_t,u8s8s16os16) -GEN_MAT_MUL_POST_OPS_CREATOR(int32_t,u8s8s32os32) -GEN_MAT_MUL_POST_OPS_CREATOR(float,f32f32f32of32) +GEN_MAT_MUL_POST_OPS_CREATOR(int16_t,float,u8s8s16os16) +GEN_MAT_MUL_POST_OPS_CREATOR(int32_t,float,u8s8s32os32) +GEN_MAT_MUL_POST_OPS_CREATOR(float,float,f32f32f32of32) void lpgemm_destroy_post_ops_struct( aocl_post_op* post_ops ) { @@ -438,11 +494,14 @@ void lpgemm_destroy_post_ops_struct( aocl_post_op* post_ops ) { free( post_ops->eltwise.algo.alpha ); } + if ( post_ops->sum.scale_factor != NULL ) + { + free( post_ops->sum.scale_factor ); + } if ( post_ops->bias.bias != NULL ) { free( post_ops->bias.bias ); } - if( post_ops->seq_vector != NULL ) { free( post_ops->seq_vector ); @@ -451,7 +510,7 @@ void lpgemm_destroy_post_ops_struct( aocl_post_op* post_ops ) free( post_ops ); } -#define GEN_MAT_MUL_BENCH_MAIN_FUNC(A_type,B_type,C_type,BLAS_SFX) \ +#define 
GEN_MAT_MUL_BENCH_MAIN_FUNC(A_type,B_type,C_type,BLAS_SFX,REORDER_SFX) \ void mat_mul_bench_main_ ## BLAS_SFX \ ( \ FILE* fin, \ @@ -506,9 +565,9 @@ void mat_mul_bench_main_ ## BLAS_SFX \ GEN_FUNC_NAME(fill_array_,B_type)( b, ( k * n ) ); \ \ aocl_post_op* post_op = NULL; \ - if ( post_ops_str != NULL ) \ + if ( ( post_ops_str != NULL ) || ( global_dscale_out == 'y' ) ) \ { \ - post_op = GEN_FUNC_NAME(lpgemm_create_post_ops_struct_,BLAS_SFX)( m, n, post_ops_str ); \ + post_op = GEN_FUNC_NAME(lpgemm_create_post_ops_struct_,REORDER_SFX)( m, n, post_ops_str ); \ if ( post_op == NULL ) \ { \ printf(" post op struct allocation failure, returning.\n"); \ @@ -534,10 +593,10 @@ void mat_mul_bench_main_ ## BLAS_SFX \ { \ /* Reorder B.*/ \ siz_t b_reorder_buf_siz_req = \ - GEN_FUNC_NAME(aocl_get_reorder_buf_size_,BLAS_SFX)( 'B', k, n ); \ + GEN_FUNC_NAME(aocl_get_reorder_buf_size_,REORDER_SFX)( 'B', k, n ); \ \ B_type* b_reorder = ( B_type* ) bli_malloc_user( b_reorder_buf_siz_req ); \ - GEN_FUNC_NAME(aocl_reorder_,BLAS_SFX)( 'B', b, b_reorder, k, n, stride_b ); \ + GEN_FUNC_NAME(aocl_reorder_,REORDER_SFX)( 'B', b, b_reorder, k, n, stride_b ); \ \ GEN_FUNC_NAME(mat_mul_bench_driver_,BLAS_SFX) \ ( \ @@ -555,7 +614,7 @@ void mat_mul_bench_main_ ## BLAS_SFX \ \ if ( bench_mode == 'a' ) \ { \ - printf(" Running accuracy check.\n"); \ + printf("Running accuracy check.\n"); \ GEN_FUNC_NAME(mat_mul_accuracy_check_driver_,BLAS_SFX) \ ( \ fout, m, n, k, \ @@ -589,9 +648,11 @@ void mat_mul_bench_main_ ## BLAS_SFX \ } \ } \ -GEN_MAT_MUL_BENCH_MAIN_FUNC(uint8_t, int8_t, int16_t, u8s8s16os16) -GEN_MAT_MUL_BENCH_MAIN_FUNC(uint8_t,int8_t,int32_t,u8s8s32os32) -GEN_MAT_MUL_BENCH_MAIN_FUNC(float,float,float,f32f32f32of32) +GEN_MAT_MUL_BENCH_MAIN_FUNC(uint8_t,int8_t,int16_t,u8s8s16os16,u8s8s16os16) +GEN_MAT_MUL_BENCH_MAIN_FUNC(uint8_t,int8_t,int8_t,u8s8s16os8,u8s8s16os16) +GEN_MAT_MUL_BENCH_MAIN_FUNC(uint8_t,int8_t,int32_t,u8s8s32os32,u8s8s32os32) +GEN_MAT_MUL_BENCH_MAIN_FUNC(uint8_t,int8_t,int8_t,u8s8s32os8,u8s8s32os32) +GEN_MAT_MUL_BENCH_MAIN_FUNC(float,float,float,f32f32f32of32,f32f32f32of32) int main( int argc, char** argv ) { @@ -616,7 +677,7 @@ int main( int argc, char** argv ) // Parse CLI arguments. 
opterr = 0; int opt_val; - while ( ( opt_val = getopt( argc, argv, "i:m:n:o:" ) ) != -1 ) + while ( ( opt_val = getopt( argc, argv, "i:m:n:o:d" ) ) != -1 ) { switch ( opt_val ) { @@ -632,6 +693,9 @@ int main( int argc, char** argv ) case 'o': post_ops_str = optarg; break; + case 'd': + global_dscale_out = 'y'; + break; default: break; } @@ -713,12 +777,24 @@ int main( int argc, char** argv ) { if ( ( op_type_char == 'i' ) || ( op_type_char == 'I' ) ) { - GEN_FUNC_NAME(mat_mul_bench_main_,u8s8s32os32) - ( - fin, fout, op_t, - m, n, k, stride_a, stride_b, stride_c, - post_ops_str_dest - ); + if ( global_dscale_out == 'n' ) + { + GEN_FUNC_NAME(mat_mul_bench_main_,u8s8s32os32) + ( + fin, fout, op_t, + m, n, k, stride_a, stride_b, stride_c, + post_ops_str_dest + ); + } + else + { + GEN_FUNC_NAME(mat_mul_bench_main_,u8s8s32os8) + ( + fin, fout, op_t, + m, n, k, stride_a, stride_b, stride_c, + post_ops_str_dest + ); + } } else if ( ( op_type_char == 'f' ) || ( op_type_char == 'F' ) ) { @@ -731,12 +807,24 @@ int main( int argc, char** argv ) } else if ((op_type_char == 's') || (op_type_char == 'S')) { - GEN_FUNC_NAME(mat_mul_bench_main_, u8s8s16os16) - ( - fin, fout, op_t, - m, n, k, stride_a, stride_b, stride_c, - post_ops_str_dest - ); + if ( global_dscale_out == 'n' ) + { + GEN_FUNC_NAME(mat_mul_bench_main_,u8s8s16os16) + ( + fin, fout, op_t, + m, n, k, stride_a, stride_b, stride_c, + post_ops_str_dest + ); + } + else + { + GEN_FUNC_NAME(mat_mul_bench_main_,u8s8s16os8) + ( + fin, fout, op_t, + m, n, k, stride_a, stride_b, stride_c, + post_ops_str_dest + ); + } } if ( post_ops_str != NULL ) { From 8f2efc4519fb70a820b0e50f3f8c3a7091a31e65 Mon Sep 17 00:00:00 2001 From: Edward Smyth Date: Tue, 30 Aug 2022 02:43:50 -0400 Subject: [PATCH 219/243] Update zen4 test in bli_cpuid.c Remove test on is_arch in bli_cpuid_is_zen4() so it will report true for all Zen4 models, based on family and AVX512 feature tests. AMD-Internal: [CPUPL-2474] Change-Id: I85d2e230b33391d5c9779df4585ae2a358788e72 (cherry picked from commit d69473c8f72735bf19fa41b43042ec6f941bcd5b) --- frame/base/bli_cpuid.c | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/frame/base/bli_cpuid.c b/frame/base/bli_cpuid.c index 4dba53080a..2796fadc05 100644 --- a/frame/base/bli_cpuid.c +++ b/frame/base/bli_cpuid.c @@ -295,14 +295,10 @@ bool bli_cpuid_is_zen4 // For zen4 the family id is 0x19 if ( family != 0x19 ) return FALSE; - // Finally, check for specific models: - // Zen 4 maps to couple of different model number ranges - // we check for all of them. - const bool is_arch - = - (0x10 <= model && model <= 0x1f ); - - if ( !is_arch ) return FALSE; + // All family 0x19 CPUs that support AVX512 instructions are zen4, + // thus no need to check model numbers here. Family 0x19 CPUs that + // don't support AVX512 are zen3. Their model ranges are tested in + // a separate function below. return TRUE; } From b4850513d6d7f5f91fb5db2091095cd9b475702b Mon Sep 17 00:00:00 2001 From: Dipal M Zambare Date: Wed, 7 Sep 2022 13:33:51 +0530 Subject: [PATCH 220/243] AOCL progress callback hardening MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - BLIS uses callback function to report the progress of the operation. The callback is implemented in the user application and is invoked by BLIS. - Updated callback function prototype to make all arguments const. This will ensure that any attempt to write using callback’s argument is prevented at the compile time itself. 
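
A minimal usage sketch of the const-qualified callback described above (the prototype and AOCL_BLIS_set_progress() are taken from the diff that follows; the function name my_progress, the return value, and the message text are illustrative only):

#include <stdio.h>
#include "blis.h" /* dim_t, AOCL_progress_callback, AOCL_BLIS_set_progress() */

/* User-side callback: every argument is const, so an accidental write
   through any of them is now rejected at compile time. */
dim_t my_progress
     (
       const char* const api,
       const dim_t       lapi,
       const dim_t       progress,
       const dim_t       current_thread,
       const dim_t       total_threads
     )
{
	printf( "%s: thread %ld/%ld processed %ld elements\n",
	        api, ( long )current_thread, ( long )total_threads,
	        ( long )progress );
	return 0;
}

/* Registered once, typically at application start-up:
       AOCL_BLIS_set_progress( my_progress );                              */
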
AMD-Internal: [CPUPL-2504] Change-Id: I8ceb671242365d2a9155b485301cd8c75043e667 --- frame/util/bli_util_progress.h | 10 +++++----- test/test_gemm.c | 12 ++++++------ test/test_trsm.c | 10 +++++----- 3 files changed, 16 insertions(+), 16 deletions(-) diff --git a/frame/util/bli_util_progress.h b/frame/util/bli_util_progress.h index 0e2a63eb1c..ed7a79cb66 100644 --- a/frame/util/bli_util_progress.h +++ b/frame/util/bli_util_progress.h @@ -37,11 +37,11 @@ // Public interface for the end user. -typedef dim_t (*AOCL_progress_callback)(char *api, - dim_t lapi, - dim_t progress, - dim_t current_thread, - dim_t total_threads); +typedef dim_t (*AOCL_progress_callback)(const char* const api, + const dim_t lapi, + const dim_t progress, + const dim_t current_thread, + const dim_t total_threads); BLIS_EXPORT_BLIS void AOCL_BLIS_set_progress(AOCL_progress_callback func); diff --git a/test/test_gemm.c b/test/test_gemm.c index 81b7e36616..cc50eb04ae 100644 --- a/test/test_gemm.c +++ b/test/test_gemm.c @@ -48,14 +48,14 @@ //#define CBLAS // Uncomment to enable progress printing. -//#define PROGRESS_ENABLED +#define PROGRESS_ENABLED #ifdef PROGRESS_ENABLED -dim_t AOCL_progress(char *api, - dim_t lapi, - dim_t progress, - dim_t current_thread, - dim_t total_threads) +dim_t AOCL_progress( const char* const api, + const dim_t lapi, + const dim_t progress, + const dim_t current_thread, + const dim_t total_threads ) { printf("\n%s, len = %ld, nt = %ld, tid = %ld, Processed %ld Elements", api, lapi, total_threads, current_thread, progress); diff --git a/test/test_trsm.c b/test/test_trsm.c index f6709f5d7f..3c48015967 100644 --- a/test/test_trsm.c +++ b/test/test_trsm.c @@ -54,11 +54,11 @@ //#define PROGRESS_ENABLED #ifdef PROGRESS_ENABLED -dim_t AOCL_progress(char *api, - dim_t lapi, - dim_t progress, - dim_t current_thread, - dim_t total_threads) +dim_t AOCL_progress( const char* const api, + const dim_t lapi, + const dim_t progress, + const dim_t current_thread, + const dim_t total_threads ) { printf("\n%s, len = %ld, nt = %ld, tid = %ld, Processed %ld Elements", api, lapi, total_threads, current_thread, progress); From 822bf0d1fe1b17565df891f51d4cba4c66bb4eb6 Mon Sep 17 00:00:00 2001 From: satish kumar nuggu Date: Fri, 9 Sep 2022 15:24:15 +0530 Subject: [PATCH 221/243] Fix in DTRSM Small MT Details: 1. Changes are made in dtrsm small MT path,to avoid accuracy issues. AMD-Internal: [SWLCSG-1470] Change-Id: I65237225892f97b7222fe71f66b02841b5956560 --- frame/compat/bla_trsm_amd.c | 34 ---------------------------------- 1 file changed, 34 deletions(-) diff --git a/frame/compat/bla_trsm_amd.c b/frame/compat/bla_trsm_amd.c index e143a88d37..13330a5d08 100644 --- a/frame/compat/bla_trsm_amd.c +++ b/frame/compat/bla_trsm_amd.c @@ -922,40 +922,6 @@ void dtrsm_ return; } } - - // bli_trsm_small_mt is performing better than native multithread - // for certain sizes of m & n. -#ifdef BLIS_ENABLE_OPENMP - rntm_t rntm; - bli_rntm_init_from_global( &rntm ); - - // Query the total number of threads from the rntm_t object. 
- dim_t n_threads = bli_rntm_num_threads( &rntm ); - if ( ( (n_threads > 1) && (m0 <= 1500) && (n0 <= 1500) ) || - ( (n_threads == 32) && (m0 <= 2300) && (n0 <= 2300) ) || - ( (n_threads == 16) && (m0 <= 3800) && (n0 <= 3800) ) || - ( (n_threads == 8) && (m0 <= 2800) && (n0 <= 2800) ) || - ( (n_threads == 4) && (m0 <= 2000) && (n0 <= 2000) ) || - ( (n_threads == 2) && (m0 <= 2000) && (n0 <= 2000) ) ) - { - err_t status; - status = bli_trsm_small_mt( - blis_side, - &alphao, - &ao, - &bo, - NULL, - NULL); - - if ( status == BLIS_SUCCESS ) - { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - /* Finalize BLIS. */ - bli_finalize_auto(); - return; - } - } -#endif// BLIS_ENABLE_OPENMP } // bli_cpuid_is_avx_supported #endif// END of BLIS_ENABLE_SMALL_MATRIX_TRSM From d4668c195e17d09cd3334c2181ce377241a35c80 Mon Sep 17 00:00:00 2001 From: Harihara Sudhan S Date: Wed, 14 Sep 2022 16:10:23 +0530 Subject: [PATCH 222/243] u8s8s16os16 bug fix for downscale operation - Removed some read code from the macros for downscale - Store permute correction - Simplified macros for edge cases and corrected intermediate operation AMD-Internal: [CPUPL-2171] Change-Id: I9cba497211cdc2c9267c62d3db549cce35e2d6eb --- .../u8s8s16/lpgemm_6x32rowmajor_amd256.c | 24 ++-- .../kernels/u8s8s16/lpgemm_m_fringe_amd256.c | 25 ++--- .../kernels/u8s8s16/lpgemm_mn_fringe_amd256.c | 82 +++++++------- .../kernels/u8s8s16/lpgemm_n_fringe_amd256.c | 28 +++-- .../kernels/u8s8s16/lpgemm_s16_kern_macros.h | 103 +++++++----------- 5 files changed, 105 insertions(+), 157 deletions(-) diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_6x32rowmajor_amd256.c b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_6x32rowmajor_amd256.c index 5c4e440407..f7ad5f2d23 100644 --- a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_6x32rowmajor_amd256.c +++ b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_6x32rowmajor_amd256.c @@ -582,26 +582,17 @@ LPGEMM_MAIN_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x32) (float *)post_ops_list_temp->scale_factor + post_op_c_j + (1 * 8)); - bli_mm256_s16_downscale(c_int16_0p0, c_int16_0p1, 0); - //-------------------------------------------------------------------------- + BLI_MM256_S16_DOWNSCALE(c_int16_0p0, c_int16_0p1, 0); - bli_mm256_s16_downscale(c_int16_1p0, c_int16_1p1, 1); - - //-------------------------------------------------------------------------- + BLI_MM256_S16_DOWNSCALE(c_int16_1p0, c_int16_1p1, 1); - bli_mm256_s16_downscale(c_int16_2p0, c_int16_2p1, 2); - - //-------------------------------------------------------------------------- + BLI_MM256_S16_DOWNSCALE(c_int16_2p0, c_int16_2p1, 2); - bli_mm256_s16_downscale(c_int16_3p0, c_int16_3p1, 3); + BLI_MM256_S16_DOWNSCALE(c_int16_3p0, c_int16_3p1, 3); - //-------------------------------------------------------------------------- + BLI_MM256_S16_DOWNSCALE(c_int16_4p0, c_int16_4p1, 4); - bli_mm256_s16_downscale(c_int16_4p0, c_int16_4p1, 4); - - //-------------------------------------------------------------------------- - - bli_mm256_s16_downscale(c_int16_5p0, c_int16_5p1, 5); + BLI_MM256_S16_DOWNSCALE(c_int16_5p0, c_int16_5p1, 5); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -675,6 +666,7 @@ LPGEMM_MAIN_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x32) // a pointer increment a = a + (4 * ps_a); m_full_pieces_loop_limit += 4; + post_op_c_i += 4; } if (m_partial2 == 1) @@ -692,6 +684,7 @@ LPGEMM_MAIN_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x32) // a pointer increment a = a + (2 * ps_a); m_full_pieces_loop_limit += 2; + post_op_c_i += 2; } if (m_partial == 1) @@ -704,6 +697,7 @@ 
LPGEMM_MAIN_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x32) alpha, beta,is_last_k, post_op_c_i, post_op_c_j, post_ops_list, rs_c_downscale); + post_op_c_i += 1; } } } diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_m_fringe_amd256.c b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_m_fringe_amd256.c index 7c8f3b57b0..4934b8b11c 100644 --- a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_m_fringe_amd256.c +++ b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_m_fringe_amd256.c @@ -370,20 +370,13 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4x32) (float *)post_ops_list_temp->scale_factor + post_op_c_j + (1 * 8)); - bli_mm256_s16_downscale(c_int16_0p0, c_int16_0p1, 0); - //-------------------------------------------------------------------------- + BLI_MM256_S16_DOWNSCALE(c_int16_0p0, c_int16_0p1, 0); - bli_mm256_s16_downscale(c_int16_1p0, c_int16_1p1, 1); - - //-------------------------------------------------------------------------- - - bli_mm256_s16_downscale(c_int16_2p0, c_int16_2p1, 2); - - //-------------------------------------------------------------------------- + BLI_MM256_S16_DOWNSCALE(c_int16_1p0, c_int16_1p1, 1); - bli_mm256_s16_downscale(c_int16_3p0, c_int16_3p1, 3); + BLI_MM256_S16_DOWNSCALE(c_int16_2p0, c_int16_2p1, 2); - //-------------------------------------------------------------------------- + BLI_MM256_S16_DOWNSCALE(c_int16_3p0, c_int16_3p1, 3); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -632,12 +625,9 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2x32) (float *)post_ops_list_temp->scale_factor + post_op_c_j + (1 * 8)); - bli_mm256_s16_downscale(c_int16_0p0, c_int16_0p1, 0); - //-------------------------------------------------------------------------- + BLI_MM256_S16_DOWNSCALE(c_int16_0p0, c_int16_0p1, 0); - bli_mm256_s16_downscale(c_int16_1p0, c_int16_1p1, 1); - - //-------------------------------------------------------------------------- + BLI_MM256_S16_DOWNSCALE(c_int16_1p0, c_int16_1p1, 1); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -814,8 +804,7 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1x32) (float *)post_ops_list_temp->scale_factor + post_op_c_j + (1 * 8)); - bli_mm256_s16_downscale(c_int16_0p0, c_int16_0p1, 0); - //-------------------------------------------------------------------------- + BLI_MM256_S16_DOWNSCALE(c_int16_0p0, c_int16_0p1, 0); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_mn_fringe_amd256.c b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_mn_fringe_amd256.c index 968959ca7f..f24455036e 100644 --- a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_mn_fringe_amd256.c +++ b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_mn_fringe_amd256.c @@ -280,13 +280,9 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4x16) (float *)post_ops_list_temp->scale_factor + post_op_c_j + (1 * 8)); - bli_mm256_s16_downscale2(c_int16_0p0, c_int16_1p0, 0, 1); + BLI_MM256_S16_DOWNSCALE2(c_int16_0p0, c_int16_1p0, 0, 1); - //-------------------------------------------------------------------------- - - bli_mm256_s16_downscale2(c_int16_2p0, c_int16_3p0, 2, 3); - - //-------------------------------------------------------------------------- + BLI_MM256_S16_DOWNSCALE2(c_int16_2p0, c_int16_3p0, 2, 3); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -548,27 +544,25 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_4xlt16) POST_OPS_DOWNSCALE_4xlt16: { __m128i temp[2]; - __m256i temp_32[2]; - __m256 temp_float[2]; - __m256 scale_1, scale_2; - __m256 res_1, res_2; - __m256i store_reg; - - float 
float_buf[16]; - int8 store_buf[16]; + __m256i temp_32[2]; + __m256 temp_float[2]; + __m256 scale_1, scale_2; + __m256 res_1, res_2; + __m256i store_reg; - memcpy( float_buf, ( ( float* )post_ops_list_temp->scale_factor + - post_op_c_j ), ( n0_rem * sizeof( float ) ) ); + float float_buf[16]; + int8_t store_buf[16]; - // Load the scale vector values into the register - scale_1 = _mm256_loadu_ps(float_buf + (0 * 8)); - scale_2 = _mm256_loadu_ps(float_buf + (1 * 8)); + memcpy( float_buf, ( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j ), ( n0_rem * sizeof( float ) ) ); - bli_mm256_s16_downscale2_lt16(c_int16_0p0, c_int16_1p0, 0, 1) + // Load the scale vector values into the register + scale_1 = _mm256_loadu_ps(float_buf + (0 * 8)); + scale_2 = _mm256_loadu_ps(float_buf + (1 * 8)); - //-------------------------------------------------------------------------- + BLI_MM256_S16_DOWNSCALE2_LT16(c_int16_0p0, c_int16_1p0, 0, 1) - bli_mm256_s16_downscale2_lt16(c_int16_2p0, c_int16_3p0, 2, 3) + BLI_MM256_S16_DOWNSCALE2_LT16(c_int16_2p0, c_int16_3p0, 2, 3) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -767,9 +761,7 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2x16) (float *)post_ops_list_temp->scale_factor + post_op_c_j + (1 * 8)); - bli_mm256_s16_downscale2(c_int16_0p0, c_int16_1p0, 0, 1); - - //-------------------------------------------------------------------------- + BLI_MM256_S16_DOWNSCALE2(c_int16_0p0, c_int16_1p0, 0, 1); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -944,23 +936,23 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_2xlt16) POST_OPS_DOWNSCALE_2xlt16: { __m128i temp[2]; - __m256i temp_32[2]; - __m256 temp_float[2]; - __m256 scale_1, scale_2; - __m256 res_1, res_2; - __m256i store_reg; + __m256i temp_32[2]; + __m256 temp_float[2]; + __m256 scale_1, scale_2; + __m256 res_1, res_2; + __m256i store_reg; - float float_buf[16]; - int8 store_buf[16]; + float float_buf[16]; + int8_t store_buf[16]; - memcpy( float_buf, ( ( float* )post_ops_list_temp->scale_factor + - post_op_c_j ), ( n0_rem * sizeof( float ) ) ); + memcpy( float_buf, ( ( float* )post_ops_list_temp->scale_factor + + post_op_c_j ), ( n0_rem * sizeof( float ) ) ); - // Load the scale vector values into the register - scale_1 = _mm256_loadu_ps(float_buf + (0 * 8)); - scale_2 = _mm256_loadu_ps(float_buf + (1 * 8)); + // Load the scale vector values into the register + scale_1 = _mm256_loadu_ps(float_buf + (0 * 8)); + scale_2 = _mm256_loadu_ps(float_buf + (1 * 8)); - bli_mm256_s16_downscale2_lt16(c_int16_0p0, c_int16_1p0, 0, 1) + BLI_MM256_S16_DOWNSCALE2_LT16(c_int16_0p0, c_int16_1p0, 0, 1) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -1094,7 +1086,7 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1x16) POST_OPS_DOWNSCALE_1x16: { __m128i temp[2]; - __m256i temp_32[2]; + __m256i temp_32[2], zero_reg; __m256 temp_float[2]; __m256 scale_1, scale_2; __m256 res_1, res_2; @@ -1110,9 +1102,9 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1x16) (float *)post_ops_list_temp->scale_factor + post_op_c_j + (1 * 8)); - temp_32[1] = _mm256_setzero_si256(); + zero_reg = _mm256_setzero_si256(); - bli_mm256_s16_downscale2_edge(c_int16_0p0, temp_32[1]) + BLI_MM256_S16_DOWNSCALE2_EDGE(c_int16_0p0, 0) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -1244,14 +1236,14 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1xlt16) POST_OPS_DOWNSCALE_1xlt16: { __m128i temp[2]; - __m256i temp_32[2]; + __m256i temp_32[2], zero_reg; __m256 temp_float[2]; __m256 
scale_1, scale_2; __m256 res_1, res_2; __m256i store_reg; float float_buf[16]; - int8 store_buf[16]; + int8_t store_buf[16]; memcpy( float_buf, ( ( float* )post_ops_list_temp->scale_factor + post_op_c_j ), ( n0_rem * sizeof( float ) ) ); @@ -1260,9 +1252,9 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_1xlt16) scale_1 = _mm256_loadu_ps(float_buf + (0 * 8)); scale_2 = _mm256_loadu_ps(float_buf + (1 * 8)); - temp_32[1] = _mm256_setzero_si256(); + zero_reg = _mm256_setzero_si256(); - bli_mm256_s16_downscale2_edge_lt16(c_int16_0p0, temp_32[1]) + BLI_MM256_S16_DOWNSCALE2_EDGE_LT16(c_int16_0p0, 0) } POST_OPS_1xlt16_DISABLE: ; diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_n_fringe_amd256.c b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_n_fringe_amd256.c index cd5a4113e5..b24d49dac7 100644 --- a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_n_fringe_amd256.c +++ b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_n_fringe_amd256.c @@ -366,15 +366,11 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x16) (float *)post_ops_list_temp->scale_factor + post_op_c_j + (1 * 8)); - bli_mm256_s16_downscale2(c_int16_0p0, c_int16_1p0, 0, 1); + BLI_MM256_S16_DOWNSCALE2(c_int16_0p0, c_int16_1p0, 0, 1); - //-------------------------------------------------------------------------- + BLI_MM256_S16_DOWNSCALE2(c_int16_2p0, c_int16_3p0, 2, 3); - bli_mm256_s16_downscale2(c_int16_2p0, c_int16_3p0, 2, 3); - - //-------------------------------------------------------------------------- - - bli_mm256_s16_downscale2(c_int16_4p0, c_int16_5p0, 4, 5); + BLI_MM256_S16_DOWNSCALE2(c_int16_4p0, c_int16_5p0, 4, 5); POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -427,6 +423,7 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x16) // a pointer increment a = a + (4 * ps_a); m_full_pieces_loop_limit += 4; + post_op_c_i += 4; } if (m_partial2 == 1) @@ -444,6 +441,7 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x16) // a pointer increment a = a + (2 * ps_a); m_full_pieces_loop_limit += 2; + post_op_c_i += 2; } if (m_partial == 1) @@ -457,6 +455,7 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6x16) is_last_k, post_op_c_i, post_op_c_j, post_ops_list, rs_c_downscale); + post_op_c_i += 1; } } } @@ -795,7 +794,7 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6xlt16) __m256i store_reg; float float_buf[16]; - int8 store_buf[16]; + int8_t store_buf[16]; memcpy( float_buf, ( ( float* )post_ops_list_temp->scale_factor + post_op_c_j ), ( n0_rem * sizeof( float ) ) ); @@ -804,15 +803,11 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6xlt16) scale_1 = _mm256_loadu_ps(float_buf + (0 * 8)); scale_2 = _mm256_loadu_ps(float_buf + (1 * 8)); - bli_mm256_s16_downscale2_lt16(c_int16_0p0, c_int16_1p0, 0, 1) - - //-------------------------------------------------------------------------- - - bli_mm256_s16_downscale2_lt16(c_int16_2p0, c_int16_3p0, 2, 3) + BLI_MM256_S16_DOWNSCALE2_LT16(c_int16_0p0, c_int16_1p0, 0, 1) - //-------------------------------------------------------------------------- + BLI_MM256_S16_DOWNSCALE2_LT16(c_int16_2p0, c_int16_3p0, 2, 3) - bli_mm256_s16_downscale2_lt16(c_int16_4p0, c_int16_5p0, 4, 5) + BLI_MM256_S16_DOWNSCALE2_LT16(c_int16_4p0, c_int16_5p0, 4, 5) POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -882,6 +877,7 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6xlt16) // a pointer increment a = a + (4 * ps_a); m_full_pieces_loop_limit += 4; + post_op_c_i += 4; } if (m_partial2 == 1) @@ -899,6 +895,7 @@ 
LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6xlt16) // a pointer increment a = a + (2 * ps_a); m_full_pieces_loop_limit += 2; + post_op_c_i += 2; } if (m_partial == 1) @@ -912,6 +909,7 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int16_t,u8s8s16o16_6xlt16) is_last_k, post_op_c_i, post_op_c_j, post_ops_list, rs_c_downscale); + post_op_c_i += 1; } } } diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_s16_kern_macros.h b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_s16_kern_macros.h index 21876f555e..55a763727d 100644 --- a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_s16_kern_macros.h +++ b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_s16_kern_macros.h @@ -54,7 +54,7 @@ //-------------------------------------------------------------------------- -#define bli_mm256_s16_downscale(c_int16__p0, c_int16__p1, vec_loc)\ +#define BLI_MM256_S16_DOWNSCALE(c_int16__p0, c_int16__p1, vec_loc)\ \ /* Extract the first 128 bits of the register*/\ temp[0] = _mm256_extractf128_si256(c_int16__p0, 0);\ @@ -62,16 +62,13 @@ temp[1] = _mm256_extractf128_si256(c_int16__p0, 1);\ \ temp_32[0] = _mm256_cvtepi16_epi32(temp[0]);\ - temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]);\ -\ - /* Since s16 values cannot be converted to f32 directly, - they are converted to s32, then to f32 and the scale is performed*/\ temp_32[1] = _mm256_cvtepi16_epi32(temp[1]);\ + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]);\ temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]);\ \ /* Multiply the C matrix by the scale value*/\ res_1 = _mm256_mul_ps(temp_float[0], scale_1);\ - res_2 = _mm256_mul_ps(temp_float[0], scale_2);\ + res_2 = _mm256_mul_ps(temp_float[1], scale_2);\ \ /* Round the resultant value to the nearest integer*/\ res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));\ @@ -94,16 +91,13 @@ temp[1] = _mm256_extractf128_si256(c_int16__p1, 1);\ \ temp_32[0] = _mm256_cvtepi16_epi32(temp[0]);\ - temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]);\ -\ - /* Since s16 values cannot be converted to f32 directly, - they are converted to s32, then to f32 and the scale is performed*/\ temp_32[1] = _mm256_cvtepi16_epi32(temp[1]);\ + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]);\ temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]);\ \ /* Multiply the C matrix by the scale value*/\ res_1 = _mm256_mul_ps(temp_float[0], scale_1);\ - res_2 = _mm256_mul_ps(temp_float[0], scale_2);\ + res_2 = _mm256_mul_ps(temp_float[1], scale_2);\ \ /* Round the resultant value to the nearest integer*/\ res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));\ @@ -121,15 +115,16 @@ \ /* Convert the s16 to s8 */\ store_reg = _mm256_packs_epi16(c_int16__p0, c_int16__p1);\ + store_reg = _mm256_permute4x64_epi64(store_reg, 0XD8);\ \ /* Store the result in s8 form */\ _mm256_storeu_si256((__m256i *)(( int8_t* )post_ops_list_temp->op_args3 + \ - ( rs_c_downscale * ( post_op_c_i + vec_loc ) ) + post_op_c_j + ( 0 * 16 )), store_reg);\ + ( rs_c_downscale * ( post_op_c_i + vec_loc ) ) + post_op_c_j), store_reg);\ \ //-------------------------------------------------------------------------- -#define bli_mm256_s16_downscale2(c_int16__p0, c_int16__p1, vec_loc1, vec_loc2)\ +#define BLI_MM256_S16_DOWNSCALE2(c_int16__p0, c_int16__p1, vec_loc1, vec_loc2)\ \ /* Extract the first 128 bits of the register*/\ temp[0] = _mm256_extractf128_si256(c_int16__p0, 0);\ @@ -137,16 +132,13 @@ temp[1] = _mm256_extractf128_si256(c_int16__p0, 1);\ \ temp_32[0] = _mm256_cvtepi16_epi32(temp[0]);\ - temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]);\ -\ - /* 
Since s16 values cannot be converted to f32 directly, - they are converted to s32, then to f32 and the scale is performed*/\ temp_32[1] = _mm256_cvtepi16_epi32(temp[1]);\ + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]);\ temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]);\ \ /* Multiply the C matrix by the scale value*/\ res_1 = _mm256_mul_ps(temp_float[0], scale_1);\ - res_2 = _mm256_mul_ps(temp_float[0], scale_2);\ + res_2 = _mm256_mul_ps(temp_float[1], scale_2);\ \ /* Round the resultant value to the nearest integer*/\ res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));\ @@ -169,16 +161,13 @@ temp[1] = _mm256_extractf128_si256(c_int16__p1, 1);\ \ temp_32[0] = _mm256_cvtepi16_epi32(temp[0]);\ - temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]);\ -\ - /* Since s16 values cannot be converted to f32 directly, - they are converted to s32, then to f32 and the scale is performed*/\ temp_32[1] = _mm256_cvtepi16_epi32(temp[1]);\ + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]);\ temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]);\ \ /* Multiply the C matrix by the scale value*/\ res_1 = _mm256_mul_ps(temp_float[0], scale_1);\ - res_2 = _mm256_mul_ps(temp_float[0], scale_2);\ + res_2 = _mm256_mul_ps(temp_float[1], scale_2);\ \ /* Round the resultant value to the nearest integer*/\ res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));\ @@ -196,6 +185,7 @@ \ /* Convert the s16 to s8 */\ store_reg = _mm256_packs_epi16(c_int16__p0, c_int16__p1);\ + store_reg = _mm256_permute4x64_epi64(store_reg, 0XD8);\ /* Extract the first 128 bits of the register*/\ temp[0] = _mm256_extractf128_si256(store_reg, 0);\ /* Extract the second 128 bits of the register*/\ @@ -203,14 +193,14 @@ \ /* Store the result in s8 form */\ _mm_storeu_si128((__m128i *)(( int8_t* )post_ops_list_temp->op_args3 + \ - ( rs_c_downscale * ( post_op_c_i + vec_loc1 ) ) + post_op_c_j + ( 0 * 16 )), temp[0]);\ + ( rs_c_downscale * ( post_op_c_i + vec_loc1 ) ) + post_op_c_j), temp[0]);\ _mm_storeu_si128((__m128i *)(( int8_t* )post_ops_list_temp->op_args3 + \ - ( rs_c_downscale * ( post_op_c_i + vec_loc2 ) ) + post_op_c_j + ( 0 * 16 )), temp[1]);\ + ( rs_c_downscale * ( post_op_c_i + vec_loc2 ) ) + post_op_c_j), temp[1]);\ \ //-------------------------------------------------------------------------- -#define bli_mm256_s16_downscale2_lt16(c_int16__p0, c_int16__p1, vec_loc1, vec_loc2)\ +#define BLI_MM256_S16_DOWNSCALE2_LT16(c_int16__p0, c_int16__p1, vec_loc1, vec_loc2)\ \ /* Extract the first 128 bits of the register*/\ temp[0] = _mm256_extractf128_si256(c_int16__p0, 0);\ @@ -218,16 +208,13 @@ temp[1] = _mm256_extractf128_si256(c_int16__p0, 1);\ \ temp_32[0] = _mm256_cvtepi16_epi32(temp[0]);\ - temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]);\ -\ - /* Since s16 values cannot be converted to f32 directly, - they are converted to s32, then to f32 and the scale is performed*/\ temp_32[1] = _mm256_cvtepi16_epi32(temp[1]);\ + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]);\ temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]);\ \ /* Multiply the C matrix by the scale value*/\ res_1 = _mm256_mul_ps(temp_float[0], scale_1);\ - res_2 = _mm256_mul_ps(temp_float[0], scale_2);\ + res_2 = _mm256_mul_ps(temp_float[1], scale_2);\ \ /* Round the resultant value to the nearest integer*/\ res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));\ @@ -250,16 +237,13 @@ temp[1] = _mm256_extractf128_si256(c_int16__p1, 1);\ \ temp_32[0] = _mm256_cvtepi16_epi32(temp[0]);\ - temp_float[0] = 
_mm256_cvtepi32_ps(temp_32[0]);\ -\ - /* Since s16 values cannot be converted to f32 directly, - they are converted to s32, then to f32 and the scale is performed*/\ temp_32[1] = _mm256_cvtepi16_epi32(temp[1]);\ + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]);\ temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]);\ \ /* Multiply the C matrix by the scale value*/\ res_1 = _mm256_mul_ps(temp_float[0], scale_1);\ - res_2 = _mm256_mul_ps(temp_float[0], scale_2);\ + res_2 = _mm256_mul_ps(temp_float[1], scale_2);\ \ /* Round the resultant value to the nearest integer*/\ res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));\ @@ -277,6 +261,7 @@ \ /* Convert the s16 to s8 */\ store_reg = _mm256_packs_epi16(c_int16__p0, c_int16__p1);\ + store_reg = _mm256_permute4x64_epi64(store_reg, 0XD8);\ /* Extract the first 128 bits of the register*/\ temp[0] = _mm256_extractf128_si256(store_reg, 0);\ /* Extract the second 128 bits of the register*/\ @@ -285,18 +270,18 @@ /* Store the result in s8 form */\ _mm_storeu_si128((__m128i *)store_buf, temp[0]);\ memcpy( ( int8_t* )post_ops_list_temp->op_args3 + \ - ( rs_c_downscale * ( post_op_c_i + vec_loc1 ) ) + post_op_c_j + \ - ( 0 * 16 ) , store_buf, ( n0_rem * sizeof( int8_t ) ) ); \ + ( rs_c_downscale * ( post_op_c_i + vec_loc1 ) ) + post_op_c_j \ + , store_buf, ( n0_rem * sizeof( int8_t ) ) ); \ \ _mm_storeu_si128((__m128i *)store_buf, temp[1]);\ memcpy( ( int8_t* )post_ops_list_temp->op_args3 + \ - ( rs_c_downscale * ( post_op_c_i + vec_loc1 ) ) + post_op_c_j + \ - ( 0 * 16 ) , store_buf, ( n0_rem * sizeof( int8_t ) ) ); \ + ( rs_c_downscale * ( post_op_c_i + vec_loc2 ) ) + post_op_c_j \ + , store_buf, ( n0_rem * sizeof( int8_t ) ) ); \ \ //-------------------------------------------------------------------------- -#define bli_mm256_s16_downscale2_edge(c_int16__p0, c_int16__p1)\ +#define BLI_MM256_S16_DOWNSCALE2_EDGE(c_int16__p0, vec_ind)\ \ /* Extract the first 128 bits of the register*/\ temp[0] = _mm256_extractf128_si256(c_int16__p0, 0);\ @@ -304,16 +289,13 @@ temp[1] = _mm256_extractf128_si256(c_int16__p0, 1);\ \ temp_32[0] = _mm256_cvtepi16_epi32(temp[0]);\ - temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]);\ -\ - /* Since s16 values cannot be converted to f32 directly, - they are converted to s32, then to f32 and the scale is performed*/\ temp_32[1] = _mm256_cvtepi16_epi32(temp[1]);\ + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]);\ temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]);\ \ /* Multiply the C matrix by the scale value*/\ res_1 = _mm256_mul_ps(temp_float[0], scale_1);\ - res_2 = _mm256_mul_ps(temp_float[0], scale_2);\ + res_2 = _mm256_mul_ps(temp_float[1], scale_2);\ \ /* Round the resultant value to the nearest integer*/\ res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));\ @@ -328,26 +310,21 @@ \ /*Permute to make sure the order is correct*/\ c_int16__p0 = _mm256_permute4x64_epi64(c_int16__p0, 0XD8);\ -\ - /* Convert the s32 to s16 */\ - c_int16__p1 = _mm256_packs_epi32(temp_32[0], temp_32[1]);\ -\ - /*Permute to make sure the order is correct*/\ - c_int16__p1 = _mm256_permute4x64_epi64(c_int16__p1, 0XD8);\ \ /* Convert the s16 to s8 */\ - store_reg = _mm256_packs_epi16(c_int16__p0, c_int16__p1);\ + store_reg = _mm256_packs_epi16(c_int16__p0, zero_reg);\ + store_reg = _mm256_permute4x64_epi64(store_reg, 0XD8);\ /* Extract the first 128 bits of the register*/\ temp[0] = _mm256_extractf128_si256(store_reg, 0);\ \ /* Store the result in s8 form */\ _mm_storeu_si128((__m128i *)(( int8_t* 
)post_ops_list_temp->op_args3 + \ - ( rs_c_downscale * ( post_op_c_i + 0 ) ) + post_op_c_j + ( 0 * 16 )), temp[0]);\ + ( rs_c_downscale * ( post_op_c_i + vec_ind ) ) + post_op_c_j), temp[0]);\ \ //-------------------------------------------------------------------------- -#define bli_mm256_s16_downscale2_edge_lt16(c_int16__p0, c_int16__p1)\ +#define BLI_MM256_S16_DOWNSCALE2_EDGE_LT16(c_int16__p0, vec_ind)\ \ /* Extract the first 128 bits of the register*/\ temp[0] = _mm256_extractf128_si256(c_int16__p0, 0);\ @@ -355,16 +332,13 @@ temp[1] = _mm256_extractf128_si256(c_int16__p0, 1);\ \ temp_32[0] = _mm256_cvtepi16_epi32(temp[0]);\ - temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]);\ -\ - /* Since s16 values cannot be converted to f32 directly, - they are converted to s32, then to f32 and the scale is performed*/\ temp_32[1] = _mm256_cvtepi16_epi32(temp[1]);\ + temp_float[0] = _mm256_cvtepi32_ps(temp_32[0]);\ temp_float[1] = _mm256_cvtepi32_ps(temp_32[1]);\ \ /* Multiply the C matrix by the scale value*/\ res_1 = _mm256_mul_ps(temp_float[0], scale_1);\ - res_2 = _mm256_mul_ps(temp_float[0], scale_2);\ + res_2 = _mm256_mul_ps(temp_float[1], scale_2);\ \ /* Round the resultant value to the nearest integer*/\ res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));\ @@ -381,15 +355,16 @@ c_int16__p0 = _mm256_permute4x64_epi64(c_int16__p0, 0XD8);\ \ /* Convert the s16 to s8 */\ - store_reg = _mm256_packs_epi16(c_int16__p0, c_int16__p1);\ + store_reg = _mm256_packs_epi16(c_int16__p0, zero_reg);\ + store_reg = _mm256_permute4x64_epi64(store_reg, 0XD8);\ /* Extract the first 128 bits of the register*/\ temp[0] = _mm256_extractf128_si256(store_reg, 0);\ \ /* Store the result in s8 form */\ _mm_storeu_si128((__m128i *)store_buf, temp[0]);\ - memcpy( ( int8_t* )post_ops_list_temp->op_args3 + \ - ( rs_c_downscale * ( post_op_c_i + 0 ) ) + post_op_c_j + \ - ( 0 * 16 ) , store_buf, ( n0_rem * sizeof( int8_t ) ) ); \ + memcpy( (( int8_t* )post_ops_list_temp->op_args3 + \ + ( rs_c_downscale * ( post_op_c_i + vec_ind ) ) + post_op_c_j) \ + ,store_buf, ( n0_rem * sizeof( int8_t ) ) ); \ \ #endif //LPGEMM_S16_KERN_MACROS_H From ebacc3798f1795a14a8e73feba3b409cb94a6609 Mon Sep 17 00:00:00 2001 From: mkadavil Date: Tue, 20 Sep 2022 14:49:04 +0530 Subject: [PATCH 223/243] Column major input support for BFloat16 gemm. -The bf16 gemm framework is modified to swap input column major matrices and compute gemm for the transposed matrices (now row major) using the existing row-major kernels. The output is written to C matrix assuming it is transposed. -Framework changes to support leading dimensions that are greater than matrix widths. -Bench changes to test low precision gemm for column major inputs. 
AMD-Internal: [CPUPL-2570] Change-Id: I22c76f52619fd76d0c0e41531828b437a1935495 --- addon/aocl_gemm/aocl_gemm_bf16bf16f32obf16.c | 117 +- addon/aocl_gemm/aocl_gemm_bf16bf16f32of32.c | 117 +- addon/aocl_gemm/aocl_gemm_f32f32f32of32.c | 19 +- addon/aocl_gemm/aocl_gemm_interface_apis.h | 18 +- addon/aocl_gemm/aocl_gemm_u8s8s16os16.c | 19 +- addon/aocl_gemm/aocl_gemm_u8s8s16os8.c | 19 +- addon/aocl_gemm/aocl_gemm_u8s8s32os32.c | 19 +- addon/aocl_gemm/aocl_gemm_u8s8s32os8.c | 19 +- .../aocl_gemm/frame/bf16bf16f32/lpgemm_bf16.c | 2 +- .../frame/f32f32f32/lpgemm_f32f32f32.c | 4 +- .../frame/lpgemm_5loop_interface_apis.h | 1 + addon/aocl_gemm/frame/lpgemm_post_ops.c | 7 +- addon/aocl_gemm/frame/lpgemm_post_ops.h | 5 +- .../threading/lpgemm_thread_decor_openmp.c | 6 +- .../threading/lpgemm_thread_decor_openmp.h | 2 + .../lpgemm_6x64rowmajor_bf16_amd512vnni.c | 226 ++- .../lpgemm_m_fringe_bf16_amd512vnni.c | 633 ++++++--- .../lpgemm_mn_fringe_bf16_amd512vnni.c | 1233 +++++++++++++---- .../lpgemm_n_fringe_bf16_amd512vnni.c | 448 ++++-- bench/bench_aocl_gemm/bench_input.txt | 1054 +++++++------- bench/bench_aocl_gemm/bench_lpgemm.c | 66 +- 21 files changed, 2767 insertions(+), 1267 deletions(-) diff --git a/addon/aocl_gemm/aocl_gemm_bf16bf16f32obf16.c b/addon/aocl_gemm/aocl_gemm_bf16bf16f32obf16.c index 60f66d9405..92171c79c6 100644 --- a/addon/aocl_gemm/aocl_gemm_bf16bf16f32obf16.c +++ b/addon/aocl_gemm/aocl_gemm_bf16bf16f32obf16.c @@ -76,13 +76,25 @@ AOCL_GEMM_MATMUL(bfloat16,bfloat16,bfloat16,bf16bf16f32obf16) { return; // Error. } - if ( ( order != 'r' ) && ( order != 'R' ) ) + + // Sanitize order input. + char order_use = + ( ( order == 'r' ) || ( order == 'R' ) || + ( order == 'c' ) || ( order == 'C' ) ) ? + order : 'r'; + + bool is_row_major = ( ( order_use == 'r' ) || ( order_use == 'R' ) ); + bool is_column_major = ( ( order_use == 'c' ) || ( order_use == 'C' ) ); + + // Row major input expected with leading dimensions >= row stride. + if ( ( is_row_major == TRUE ) && + ( ( lda < k ) || ( ldb < n ) || ( ldc < n ) ) ) { - return; // Only row major supported. + return; // Error. } - - // Row major input expected with leading dimensions equal to row stride. - if ( ( lda != k ) || ( ldb != n ) || ( ldc != n ) ) + // Column major input expected with leading dimensions >= column stride. + else if ( ( is_column_major == TRUE ) && + ( ( lda < m ) || ( ldb < k ) || ( ldc < m ) ) ) { return; // Error. } @@ -99,6 +111,7 @@ AOCL_GEMM_MATMUL(bfloat16,bfloat16,bfloat16,bf16bf16f32obf16) const inc_t rs_b = ldb; const inc_t cs_b = 1; const inc_t rs_c = ldc; + const inc_t cs_c = 1; AOCL_MEMORY_TAG mtag_a; AOCL_MEMORY_TAG mtag_b; @@ -110,20 +123,34 @@ AOCL_GEMM_MATMUL(bfloat16,bfloat16,bfloat16,bf16bf16f32obf16) // and used in bf16 instrution. As such the mtag_b always needs to be either // packed or reordered. B matrix as it is (unpacked) cannot be used, and // the mtag_b is set to packed to enable runtime packing. - if ( mtag_b == UNPACKED ) + if ( ( is_row_major == TRUE ) && ( mtag_b == UNPACKED ) ) { mtag_b = PACK; } + // Inputs swapped in column major, A becomes B from kernel point of view. + else if ( ( is_column_major == TRUE ) && ( mtag_a == UNPACKED ) ) + { + mtag_a = PACK; + } // Only unpacked A supported now. - if ( mtag_a != UNPACKED ) + if ( ( is_row_major == TRUE ) && ( mtag_a != UNPACKED ) ) + { + return; // Error. + } + // Inputs swapped in column major, B becomes A from kernel point of view. + else if ( ( is_column_major == TRUE ) && ( mtag_b != UNPACKED ) ) { return; // Error. 
} // Convert post op struct to post op linked list format. lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS]; - lpgemm_translate_to_post_ops_list( post_op_unparsed, post_op_list, ( void* )c ); + lpgemm_translate_to_post_ops_list + ( + post_op_unparsed, post_op_list, + ( void* )c, ( void* )( &order_use ) + ); // Initialize a local runtime with global settings if necessary. Note // that in the case that a runtime is passed in, we make a local copy. @@ -132,26 +159,60 @@ AOCL_GEMM_MATMUL(bfloat16,bfloat16,bfloat16,bf16bf16f32obf16) bli_membrk_rntm_set_membrk( &rntm_g ); #ifdef BLIS_ENABLE_OPENMP - lpgemm_bf16bf16f32of32_openmp_thread_decorator - ( - m, n, k, - a, rs_a, cs_a, mtag_a, - b, rs_b, cs_b, mtag_b, - ( float* )c, rs_c, - alpha, beta, - &rntm_g, - post_op_list, TRUE - ); + // Swapping inputs to induce row major computation for column major inputs. + if ( is_column_major == TRUE ) + { + lpgemm_bf16bf16f32of32_openmp_thread_decorator + ( + n, m, k, + b, rs_b, cs_b, mtag_b, + a, rs_a, cs_a, mtag_a, + ( float* )c, rs_c, cs_c, + alpha, beta, + &rntm_g, + post_op_list, TRUE + ); + } + else + { + lpgemm_bf16bf16f32of32_openmp_thread_decorator + ( + m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + ( float* )c, rs_c, cs_c, + alpha, beta, + &rntm_g, + post_op_list, TRUE + ); + } #else - lpgemm_bf16bf16f32of32_thread_decorator - ( - m, n, k, - a, rs_a, cs_a, mtag_a, - b, rs_b, cs_b, mtag_b, - ( float* )c, rs_c, - alpha, beta, - &rntm_g, - post_op_list, TRUE - ); + // Swapping inputs to induce row major computation for column major inputs. + if ( is_column_major == TRUE ) + { + lpgemm_bf16bf16f32of32_thread_decorator + ( + n, m, k, + b, rs_b, cs_b, mtag_b, + a, rs_a, cs_a, mtag_a, + ( float* )c, rs_c, cs_c, + alpha, beta, + &rntm_g, + post_op_list, TRUE + ); + } + else + { + lpgemm_bf16bf16f32of32_thread_decorator + ( + m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + ( float* )c, rs_c, cs_c, + alpha, beta, + &rntm_g, + post_op_list, TRUE + ); + } #endif } diff --git a/addon/aocl_gemm/aocl_gemm_bf16bf16f32of32.c b/addon/aocl_gemm/aocl_gemm_bf16bf16f32of32.c index faedc060dc..731afb129a 100644 --- a/addon/aocl_gemm/aocl_gemm_bf16bf16f32of32.c +++ b/addon/aocl_gemm/aocl_gemm_bf16bf16f32of32.c @@ -76,13 +76,25 @@ AOCL_GEMM_MATMUL(bfloat16,bfloat16,float,bf16bf16f32of32) { return; // Error. } - if ( ( order != 'r' ) && ( order != 'R' ) ) + + // Sanitize order input. + char order_use = + ( ( order == 'r' ) || ( order == 'R' ) || + ( order == 'c' ) || ( order == 'C' ) ) ? + order : 'r'; + + bool is_row_major = ( ( order_use == 'r' ) || ( order_use == 'R' ) ); + bool is_column_major = ( ( order_use == 'c' ) || ( order_use == 'C' ) ); + + // Row major input expected with leading dimensions >= row stride. + if ( ( is_row_major == TRUE ) && + ( ( lda < k ) || ( ldb < n ) || ( ldc < n ) ) ) { - return; // Only row major supported. + return; // Error. } - - // Row major input expected with leading dimensions equal to row stride. - if ( ( lda != k ) || ( ldb != n ) || ( ldc != n ) ) + // Column major input expected with leading dimensions >= column stride. + else if ( ( is_column_major == TRUE ) && + ( ( lda < m ) || ( ldb < k ) || ( ldc < m ) ) ) { return; // Error. 
} @@ -99,6 +111,7 @@ AOCL_GEMM_MATMUL(bfloat16,bfloat16,float,bf16bf16f32of32) const inc_t rs_b = ldb; const inc_t cs_b = 1; const inc_t rs_c = ldc; + const inc_t cs_c = 1; AOCL_MEMORY_TAG mtag_a; AOCL_MEMORY_TAG mtag_b; @@ -110,20 +123,34 @@ AOCL_GEMM_MATMUL(bfloat16,bfloat16,float,bf16bf16f32of32) // and used in bf16 instrution. As such the mtag_b always needs to be either // packed or reordered. B matrix as it is (unpacked) cannot be used, and // the mtag_b is set to packed to enable runtime packing. - if ( mtag_b == UNPACKED ) + if ( ( is_row_major == TRUE ) && ( mtag_b == UNPACKED ) ) { mtag_b = PACK; } + // Inputs swapped in column major, A becomes B from kernel point of view. + else if ( ( is_column_major == TRUE ) && ( mtag_a == UNPACKED ) ) + { + mtag_a = PACK; + } // Only unpacked A supported now. - if ( mtag_a != UNPACKED ) + if ( ( is_row_major == TRUE ) && ( mtag_a != UNPACKED ) ) + { + return; // Error. + } + // Inputs swapped in column major, B becomes A from kernel point of view. + else if ( ( is_column_major == TRUE ) && ( mtag_b != UNPACKED ) ) { return; // Error. } // Convert post op struct to post op linked list format. lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS]; - lpgemm_translate_to_post_ops_list( post_op_unparsed, post_op_list, ( void* )c ); + lpgemm_translate_to_post_ops_list + ( + post_op_unparsed, post_op_list, + ( void* )c, ( void* )( &order_use ) + ); // Initialize a local runtime with global settings if necessary. Note // that in the case that a runtime is passed in, we make a local copy. @@ -132,26 +159,60 @@ AOCL_GEMM_MATMUL(bfloat16,bfloat16,float,bf16bf16f32of32) bli_membrk_rntm_set_membrk( &rntm_g ); #ifdef BLIS_ENABLE_OPENMP - lpgemm_bf16bf16f32of32_openmp_thread_decorator - ( - m, n, k, - a, rs_a, cs_a, mtag_a, - b, rs_b, cs_b, mtag_b, - c, rs_c, - alpha, beta, - &rntm_g, - post_op_list, FALSE - ); + // Swapping inputs to induce row major computation for column major inputs. + if ( is_column_major == TRUE ) + { + lpgemm_bf16bf16f32of32_openmp_thread_decorator + ( + n, m, k, + b, rs_b, cs_b, mtag_b, + a, rs_a, cs_a, mtag_a, + c, rs_c, cs_c, + alpha, beta, + &rntm_g, + post_op_list, FALSE + ); + } + else + { + lpgemm_bf16bf16f32of32_openmp_thread_decorator + ( + m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + c, rs_c, cs_c, + alpha, beta, + &rntm_g, + post_op_list, FALSE + ); + } #else - lpgemm_bf16bf16f32of32_thread_decorator - ( - m, n, k, - a, rs_a, cs_a, mtag_a, - b, rs_b, cs_b, mtag_b, - c, rs_c, - alpha, beta, - &rntm_g, - post_op_list, FALSE - ); + // Swapping inputs to induce row major computation for column major inputs. + if ( is_column_major == TRUE ) + { + lpgemm_bf16bf16f32of32_thread_decorator + ( + n, m, k, + b, rs_b, cs_b, mtag_b, + a, rs_a, cs_a, mtag_a, + c, rs_c, cs_c, + alpha, beta, + &rntm_g, + post_op_list, FALSE + ); + } + else + { + lpgemm_bf16bf16f32of32_thread_decorator + ( + m, n, k, + a, rs_a, cs_a, mtag_a, + b, rs_b, cs_b, mtag_b, + c, rs_c, cs_c, + alpha, beta, + &rntm_g, + post_op_list, FALSE + ); + } #endif } diff --git a/addon/aocl_gemm/aocl_gemm_f32f32f32of32.c b/addon/aocl_gemm/aocl_gemm_f32f32f32of32.c index 90c3ff41c2..58f2675df3 100644 --- a/addon/aocl_gemm/aocl_gemm_f32f32f32of32.c +++ b/addon/aocl_gemm/aocl_gemm_f32f32f32of32.c @@ -80,7 +80,13 @@ AOCL_GEMM_MATMUL(float,float,float,f32f32f32of32) "Input matrix transpose not supported."); return; // Error. } - if ( ( order != 'r' ) && ( order != 'R' ) ) + + // Sanitize order input. 
+ char order_use = + ( ( order == 'r' ) || ( order == 'R' ) || + ( order == 'c' ) || ( order == 'C' ) ) ? + order : 'r'; + if ( ( order_use != 'r' ) && ( order_use != 'R' ) ) { return; // Only row major supported. } @@ -107,6 +113,7 @@ AOCL_GEMM_MATMUL(float,float,float,f32f32f32of32) const inc_t rs_b = ldb; const inc_t cs_b = 1; const inc_t rs_c = ldc; + const inc_t cs_c = 1; AOCL_MEMORY_TAG mtag_a; AOCL_MEMORY_TAG mtag_b; @@ -124,7 +131,11 @@ AOCL_GEMM_MATMUL(float,float,float,f32f32f32of32) // Convert post op struct to post op linked list format. lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS]; - lpgemm_translate_to_post_ops_list( post_op_unparsed, post_op_list, ( void* )c ); + lpgemm_translate_to_post_ops_list + ( + post_op_unparsed, post_op_list, + ( void* )c, ( void* )( &order_use ) + ); // Initialize a local runtime with global settings if necessary. Note // that in the case that a runtime is passed in, we make a local copy. @@ -138,7 +149,7 @@ AOCL_GEMM_MATMUL(float,float,float,f32f32f32of32) m, n, k, a, rs_a, cs_a, mtag_a, b, rs_b, cs_b, mtag_b, - c, rs_c, + c, rs_c, cs_c, alpha, beta, &rntm_g, post_op_list, FALSE @@ -152,7 +163,7 @@ AOCL_GEMM_MATMUL(float,float,float,f32f32f32of32) m, n, k, a, rs_a, cs_a, mtag_a, b, rs_b, cs_b, mtag_b, - c, rs_c, + c, rs_c, cs_c, alpha, beta, &rntm_g, post_op_list, FALSE diff --git a/addon/aocl_gemm/aocl_gemm_interface_apis.h b/addon/aocl_gemm/aocl_gemm_interface_apis.h index 69ee4ad8e1..c2840c3468 100644 --- a/addon/aocl_gemm/aocl_gemm_interface_apis.h +++ b/addon/aocl_gemm/aocl_gemm_interface_apis.h @@ -58,12 +58,12 @@ AOCL_GEMM_GET_REORDER_BUF_SIZE(bf16bf16f32of32); #define AOCL_GEMM_REORDER(B_type,LP_SFX) \ BLIS_EXPORT_ADDON void aocl_reorder_ ## LP_SFX \ ( \ - const char mat_type, \ + const char mat_type, \ const B_type* input_buf_addr, \ B_type* reorder_buf_addr, \ - const dim_t k, \ - const dim_t n, \ - const dim_t ldb \ + const dim_t k, \ + const dim_t n, \ + const dim_t ldb \ ) \ AOCL_GEMM_REORDER(float,f32f32f32of32); @@ -83,15 +83,15 @@ BLIS_EXPORT_ADDON void aocl_gemm_ ## LP_SFX \ const dim_t m, \ const dim_t n, \ const dim_t k, \ - const C_type alpha, \ - const A_type* a, \ + const C_type alpha, \ + const A_type* a, \ const dim_t lda, \ const char mem_format_a, \ - const B_type* b, \ + const B_type* b, \ const dim_t ldb, \ const char mem_format_b, \ - const C_type beta, \ - C_type* c, \ + const C_type beta, \ + C_type* c, \ const dim_t ldc, \ aocl_post_op* post_op_unparsed \ ) \ diff --git a/addon/aocl_gemm/aocl_gemm_u8s8s16os16.c b/addon/aocl_gemm/aocl_gemm_u8s8s16os16.c index 20b987d3b6..08084fbd7a 100644 --- a/addon/aocl_gemm/aocl_gemm_u8s8s16os16.c +++ b/addon/aocl_gemm/aocl_gemm_u8s8s16os16.c @@ -76,7 +76,13 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int16_t,u8s8s16os16) { return; // Error. } - if ( ( order != 'r' ) && ( order != 'R' ) ) + + // Sanitize order input. + char order_use = + ( ( order == 'r' ) || ( order == 'R' ) || + ( order == 'c' ) || ( order == 'C' ) ) ? + order : 'r'; + if ( ( order_use != 'r' ) && ( order_use != 'R' ) ) { return; // Only row major supported. } @@ -98,6 +104,7 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int16_t,u8s8s16os16) const inc_t rs_b = ldb; const inc_t cs_b = 1; const inc_t rs_c = ldc; + const inc_t cs_c = 1; AOCL_MEMORY_TAG mtag_a; AOCL_MEMORY_TAG mtag_b; @@ -122,7 +129,11 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int16_t,u8s8s16os16) // Convert post op struct to post op linked list format. 
lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS]; - lpgemm_translate_to_post_ops_list( post_op_unparsed, post_op_list, ( void* )c ); + lpgemm_translate_to_post_ops_list + ( + post_op_unparsed, post_op_list, + ( void* )c, ( void* )( &order_use ) + ); // Initialize a local runtime with global settings if necessary. Note // that in the case that a runtime is passed in, we make a local copy. @@ -136,7 +147,7 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int16_t,u8s8s16os16) m, n, k, a, rs_a, cs_a, mtag_a, b, rs_b, cs_b, mtag_b, - c, rs_c, + c, rs_c, cs_c, alpha, beta, &rntm_g, post_op_list, FALSE @@ -147,7 +158,7 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int16_t,u8s8s16os16) m, n, k, a, rs_a, cs_a, mtag_a, b, rs_b, cs_b, mtag_b, - c, rs_c, + c, rs_c, cs_c, alpha, beta, &rntm_g, post_op_list, FALSE diff --git a/addon/aocl_gemm/aocl_gemm_u8s8s16os8.c b/addon/aocl_gemm/aocl_gemm_u8s8s16os8.c index 90aeb4c7b8..6df22b14ce 100644 --- a/addon/aocl_gemm/aocl_gemm_u8s8s16os8.c +++ b/addon/aocl_gemm/aocl_gemm_u8s8s16os8.c @@ -76,7 +76,13 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int8_t,u8s8s16os8) { return; // Error. } - if ( ( order != 'r' ) && ( order != 'R' ) ) + + // Sanitize order input. + char order_use = + ( ( order == 'r' ) || ( order == 'R' ) || + ( order == 'c' ) || ( order == 'C' ) ) ? + order : 'r'; + if ( ( order_use != 'r' ) && ( order_use != 'R' ) ) { return; // Only row major supported. } @@ -98,6 +104,7 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int8_t,u8s8s16os8) const inc_t rs_b = ldb; const inc_t cs_b = 1; const inc_t rs_c = ldc; + const inc_t cs_c = 1; AOCL_MEMORY_TAG mtag_a; AOCL_MEMORY_TAG mtag_b; @@ -122,7 +129,11 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int8_t,u8s8s16os8) // Convert post op struct to post op linked list format. lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS]; - lpgemm_translate_to_post_ops_list( post_op_unparsed, post_op_list, ( void* )c ); + lpgemm_translate_to_post_ops_list + ( + post_op_unparsed, post_op_list, + ( void* )c, ( void* )( &order_use ) + ); // Initialize a local runtime with global settings if necessary. Note // that in the case that a runtime is passed in, we make a local copy. @@ -136,7 +147,7 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int8_t,u8s8s16os8) m, n, k, a, rs_a, cs_a, mtag_a, b, rs_b, cs_b, mtag_b, - ( int16_t* )c, rs_c, + ( int16_t* )c, rs_c, cs_c, alpha, beta, &rntm_g, post_op_list, TRUE @@ -147,7 +158,7 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int8_t,u8s8s16os8) m, n, k, a, rs_a, cs_a, mtag_a, b, rs_b, cs_b, mtag_b, - ( int16_t* )c, rs_c, + ( int16_t* )c, rs_c, cs_c, alpha, beta, &rntm_g, post_op_list, TRUE diff --git a/addon/aocl_gemm/aocl_gemm_u8s8s32os32.c b/addon/aocl_gemm/aocl_gemm_u8s8s32os32.c index 8f75db1274..29a0c74936 100644 --- a/addon/aocl_gemm/aocl_gemm_u8s8s32os32.c +++ b/addon/aocl_gemm/aocl_gemm_u8s8s32os32.c @@ -76,7 +76,13 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int32_t,u8s8s32os32) { return; // Error. } - if ( ( order != 'r' ) && ( order != 'R' ) ) + + // Sanitize order input. + char order_use = + ( ( order == 'r' ) || ( order == 'R' ) || + ( order == 'c' ) || ( order == 'C' ) ) ? + order : 'r'; + if ( ( order_use != 'r' ) && ( order_use != 'R' ) ) { return; // Only row major supported. } @@ -99,6 +105,7 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int32_t,u8s8s32os32) const inc_t rs_b = ldb; const inc_t cs_b = 1; const inc_t rs_c = ldc; + const inc_t cs_c = 1; AOCL_MEMORY_TAG mtag_a; AOCL_MEMORY_TAG mtag_b; @@ -123,7 +130,11 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int32_t,u8s8s32os32) // Convert post op struct to post op linked list format. 
lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS]; - lpgemm_translate_to_post_ops_list( post_op_unparsed, post_op_list, ( void* )c ); + lpgemm_translate_to_post_ops_list + ( + post_op_unparsed, post_op_list, + ( void* )c, ( void* )( &order_use ) + ); // Initialize a local runtime with global settings if necessary. Note // that in the case that a runtime is passed in, we make a local copy. @@ -137,7 +148,7 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int32_t,u8s8s32os32) m, n, k, a, rs_a, cs_a, mtag_a, b, rs_b, cs_b, mtag_b, - c, rs_c, + c, rs_c, cs_c, alpha, beta, &rntm_g, post_op_list, FALSE @@ -148,7 +159,7 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int32_t,u8s8s32os32) m, n, k, a, rs_a, cs_a, mtag_a, b, rs_b, cs_b, mtag_b, - c, rs_c, + c, rs_c, cs_c, alpha, beta, &rntm_g, post_op_list, FALSE diff --git a/addon/aocl_gemm/aocl_gemm_u8s8s32os8.c b/addon/aocl_gemm/aocl_gemm_u8s8s32os8.c index df90de4968..da2cb1a587 100644 --- a/addon/aocl_gemm/aocl_gemm_u8s8s32os8.c +++ b/addon/aocl_gemm/aocl_gemm_u8s8s32os8.c @@ -76,7 +76,13 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int8_t,u8s8s32os8) { return; // Error. } - if ( ( order != 'r' ) && ( order != 'R' ) ) + + // Sanitize order input. + char order_use = + ( ( order == 'r' ) || ( order == 'R' ) || + ( order == 'c' ) || ( order == 'C' ) ) ? + order : 'r'; + if ( ( order_use != 'r' ) && ( order_use != 'R' ) ) { return; // Only row major supported. } @@ -99,6 +105,7 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int8_t,u8s8s32os8) const inc_t rs_b = ldb; const inc_t cs_b = 1; const inc_t rs_c = ldc; + const inc_t cs_c = 1; AOCL_MEMORY_TAG mtag_a; AOCL_MEMORY_TAG mtag_b; @@ -123,7 +130,11 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int8_t,u8s8s32os8) // Convert post op struct to post op linked list format. lpgemm_post_op post_op_list[AOCL_MAX_POST_OPS]; - lpgemm_translate_to_post_ops_list( post_op_unparsed, post_op_list, ( void* )c ); + lpgemm_translate_to_post_ops_list + ( + post_op_unparsed, post_op_list, + ( void* )c, ( void* )( &order_use ) + ); // Initialize a local runtime with global settings if necessary. Note // that in the case that a runtime is passed in, we make a local copy. @@ -137,7 +148,7 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int8_t,u8s8s32os8) m, n, k, a, rs_a, cs_a, mtag_a, b, rs_b, cs_b, mtag_b, - ( int32_t* )c, rs_c, + ( int32_t* )c, rs_c, cs_c, alpha, beta, &rntm_g, post_op_list, TRUE @@ -148,7 +159,7 @@ AOCL_GEMM_MATMUL(uint8_t,int8_t,int8_t,u8s8s32os8) m, n, k, a, rs_a, cs_a, mtag_a, b, rs_b, cs_b, mtag_b, - ( int32_t* )c, rs_c, + ( int32_t* )c, rs_c, cs_c, alpha, beta, &rntm_g, post_op_list, TRUE diff --git a/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_bf16.c b/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_bf16.c index a92c70c8fc..738840f78e 100644 --- a/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_bf16.c +++ b/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_bf16.c @@ -142,7 +142,7 @@ LPGEMM_5LOOP(bfloat16,bfloat16,float,bf16bf16f32of32) for ( dim_t i_dscale = ic_start; i_dscale < ic_end; ++i_dscale ) { j_temp = 0; - for ( dim_t j_dscale = jc; j_dscale < nc0; ++j_dscale ) + for ( dim_t j_dscale = jc; j_dscale < ( jc + nc0 ); ++j_dscale ) { // Implemented with the idea sizeof(float)=4. 
temp_conv_buf = 0; diff --git a/addon/aocl_gemm/frame/f32f32f32/lpgemm_f32f32f32.c b/addon/aocl_gemm/frame/f32f32f32/lpgemm_f32f32f32.c index 41a5151259..6242ceebe8 100644 --- a/addon/aocl_gemm/frame/f32f32f32/lpgemm_f32f32f32.c +++ b/addon/aocl_gemm/frame/f32f32f32/lpgemm_f32f32f32.c @@ -78,7 +78,7 @@ LPGEMM_5LOOP(float,float,float,f32f32f32of32) float* c_use_ic = NULL; // Only supporting row major with unit column strided C for now. - const dim_t cs_c = 1; + const dim_t cs_c_use = 1; /* Compute partitioning step values for each matrix of each loop. */ inc_t ps_a_use; @@ -222,7 +222,7 @@ LPGEMM_5LOOP(float,float,float,f32f32f32of32) ( float* )a_use, rs_a_use, cs_a_use, ( float* )( b_use + ( jr * ps_b_use ) ), rs_b_use, cs_b_use, &beta0, - ( c_use_ic + jr ), rs_c, cs_c, + ( c_use_ic + jr ), rs_c, cs_c_use, &aux, cntx ); } diff --git a/addon/aocl_gemm/frame/lpgemm_5loop_interface_apis.h b/addon/aocl_gemm/frame/lpgemm_5loop_interface_apis.h index ef63ef09eb..45328669de 100644 --- a/addon/aocl_gemm/frame/lpgemm_5loop_interface_apis.h +++ b/addon/aocl_gemm/frame/lpgemm_5loop_interface_apis.h @@ -55,6 +55,7 @@ void lpgemm_rowvar_ ## LP_SFX \ const AOCL_MEMORY_TAG mtag_b, \ C_type* c, \ const dim_t rs_c, \ + const dim_t cs_c, \ C_type alpha, \ C_type beta, \ rntm_t* rntm, \ diff --git a/addon/aocl_gemm/frame/lpgemm_post_ops.c b/addon/aocl_gemm/frame/lpgemm_post_ops.c index 3b25527a20..63fb25765f 100644 --- a/addon/aocl_gemm/frame/lpgemm_post_ops.c +++ b/addon/aocl_gemm/frame/lpgemm_post_ops.c @@ -59,7 +59,8 @@ void lpgemm_translate_to_post_ops_list ( aocl_post_op* post_op_unparsed, lpgemm_post_op* post_op_list, - void* scale_buffer + void* scale_buffer, + void* meta_arg ) { if ( post_op_unparsed == NULL ) @@ -129,7 +130,7 @@ void lpgemm_translate_to_post_ops_list ( ( post_op_list + i ), POST_OPS_BIAS, post_op_unparsed->bias.bias, - NULL, NULL, NULL, FALSE + meta_arg, NULL, NULL, FALSE ); break; case SCALE: @@ -137,7 +138,7 @@ void lpgemm_translate_to_post_ops_list ( ( post_op_list + i ), POST_OPS_DOWNSCALE, post_op_unparsed->sum.zero_point, - NULL, scale_buffer, + meta_arg, scale_buffer, post_op_unparsed->sum.scale_factor, FALSE ); break; diff --git a/addon/aocl_gemm/frame/lpgemm_post_ops.h b/addon/aocl_gemm/frame/lpgemm_post_ops.h index ce26fe1397..3932daf602 100644 --- a/addon/aocl_gemm/frame/lpgemm_post_ops.h +++ b/addon/aocl_gemm/frame/lpgemm_post_ops.h @@ -50,7 +50,7 @@ typedef struct lpgemm_post_op_t { LPGEMM_POST_OP_CODE op_code; void* op_args1; - void* op_args2; // alpha, zero_point + void* op_args2; // alpha, zero_point, storage order void* op_args3; // beta, downscale buffer/original C matrix void* scale_factor; bool is_power_of_2; @@ -61,7 +61,8 @@ void lpgemm_translate_to_post_ops_list ( aocl_post_op* post_op_unparsed, lpgemm_post_op* post_op_list, - void* scale_buffer + void* scale_buffer, + void* meta_arg ); #define POST_OP_LABEL_LASTK_SAFE_JUMP \ diff --git a/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.c b/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.c index bf4f622dea..0c1df5e7c3 100644 --- a/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.c +++ b/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.c @@ -481,6 +481,7 @@ void lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \ const AOCL_MEMORY_TAG mtag_b, \ C_type* c, \ const dim_t rs_c, \ + const dim_t cs_c, \ C_type alpha, \ C_type beta, \ rntm_t* rntm_g, \ @@ -540,7 +541,7 @@ void lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \ m, n, k, \ a, rs_a, cs_a, mtag_a, \ b, rs_b, 
cs_b, mtag_b, \ - c, rs_c, \ + c, rs_c, cs_c,\ alpha, \ beta, \ &rntm_l, \ @@ -577,6 +578,7 @@ void lpgemm_ ## LPGEMM_SFX ## _thread_decorator \ const AOCL_MEMORY_TAG mtag_b, \ C_type* c, \ const dim_t rs_c, \ + const dim_t cs_c, \ C_type alpha, \ C_type beta, \ rntm_t* rntm_g, \ @@ -615,7 +617,7 @@ void lpgemm_ ## LPGEMM_SFX ## _thread_decorator \ m, n, k, \ a, rs_a, cs_a, mtag_a, \ b, rs_b, cs_b, mtag_b, \ - c, rs_c, \ + c, rs_c, cs_c, \ alpha, \ beta, \ rntm_g, \ diff --git a/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.h b/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.h index 78e01291ac..8055d623e6 100644 --- a/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.h +++ b/addon/aocl_gemm/frame/threading/lpgemm_thread_decor_openmp.h @@ -57,6 +57,7 @@ void lpgemm_ ## LPGEMM_SFX ## _openmp_thread_decorator \ const AOCL_MEMORY_TAG mtag_b, \ C_type* c, \ const dim_t rs_c, \ + const dim_t cs_c, \ C_type alpha, \ C_type beta, \ rntm_t* rntm_g, \ @@ -87,6 +88,7 @@ void lpgemm_ ## LPGEMM_SFX ## _thread_decorator \ const AOCL_MEMORY_TAG mtag_b, \ C_type* c, \ const dim_t rs_c, \ + const dim_t cs_c, \ C_type alpha, \ C_type beta, \ rntm_t* rntm_g, \ diff --git a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_6x64rowmajor_bf16_amd512vnni.c b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_6x64rowmajor_bf16_amd512vnni.c index 87f713705e..14dc7e8e57 100644 --- a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_6x64rowmajor_bf16_amd512vnni.c +++ b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_6x64rowmajor_bf16_amd512vnni.c @@ -547,90 +547,196 @@ LPGEMM_MAIN_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x64) POST_OP_LABEL_LASTK_SAFE_JUMP POST_OPS_BIAS_6x64: { - selector1 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 0 * 16 ) ); - selector2 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 1 * 16 ) ); - __m512 selector3 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 2 * 16 ) ); - __m512 selector4 = - _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 3 * 16 ) ); + __m512 selector3; + __m512 selector4; - // c[0,0-15] - c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + selector1 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 0 * 16 ) ); + selector2 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 1 * 16 ) ); + selector3 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 2 * 16 ) ); + selector4 = + _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + + post_op_c_j + ( 3 * 16 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector3, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_add_ps( selector4, c_float_0p3 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector3, c_float_1p2 ); + + // c[1,48-63] + c_float_1p3 = _mm512_add_ps( selector4, c_float_1p3 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = 
_mm512_add_ps( selector3, c_float_2p2 ); + + // c[2,48-63] + c_float_2p3 = _mm512_add_ps( selector4, c_float_2p3 ); - // c[0, 16-31] - c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); - // c[0,32-47] - c_float_0p2 = _mm512_add_ps( selector3, c_float_0p2 ); + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector2, c_float_3p1 ); - // c[0,48-63] - c_float_0p3 = _mm512_add_ps( selector4, c_float_0p3 ); + // c[3,32-47] + c_float_3p2 = _mm512_add_ps( selector3, c_float_3p2 ); - // c[1,0-15] - c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + // c[3,48-63] + c_float_3p3 = _mm512_add_ps( selector4, c_float_3p3 ); - // c[1, 16-31] - c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); - // c[1,32-47] - c_float_1p2 = _mm512_add_ps( selector3, c_float_1p2 ); + // c[4, 16-31] + c_float_4p1 = _mm512_add_ps( selector2, c_float_4p1 ); - // c[1,48-63] - c_float_1p3 = _mm512_add_ps( selector4, c_float_1p3 ); + // c[4,32-47] + c_float_4p2 = _mm512_add_ps( selector3, c_float_4p2 ); - // c[2,0-15] - c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + // c[4,48-63] + c_float_4p3 = _mm512_add_ps( selector4, c_float_4p3 ); - // c[2, 16-31] - c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); + // c[5,0-15] + c_float_5p0 = _mm512_add_ps( selector1, c_float_5p0 ); - // c[2,32-47] - c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); + // c[5, 16-31] + c_float_5p1 = _mm512_add_ps( selector2, c_float_5p1 ); - // c[2,48-63] - c_float_2p3 = _mm512_add_ps( selector4, c_float_2p3 ); + // c[5,32-47] + c_float_5p2 = _mm512_add_ps( selector3, c_float_5p2 ); - // c[3,0-15] - c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + // c[5,48-63] + c_float_5p3 = _mm512_add_ps( selector4, c_float_5p3 ); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the bias array will be accessed by + // the ic index, and each bias element corresponds to an + // entire row of the transposed output array, instead of an + // entire column. 
+ selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 3 ) ); + a_bf16_0 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 4 ) ); + a_bf16_1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 5 ) ); - // c[3, 16-31] - c_float_3p1 = _mm512_add_ps( selector2, c_float_3p1 ); + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); - // c[3,32-47] - c_float_3p2 = _mm512_add_ps( selector3, c_float_3p2 ); + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); - // c[3,48-63] - c_float_3p3 = _mm512_add_ps( selector4, c_float_3p3 ); + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); - // c[4,0-15] - c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); + // c[0,48-63] + c_float_0p3 = _mm512_add_ps( selector1, c_float_0p3 ); - // c[4, 16-31] - c_float_4p1 = _mm512_add_ps( selector2, c_float_4p1 ); + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); - // c[4,32-47] - c_float_4p2 = _mm512_add_ps( selector3, c_float_4p2 ); + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); - // c[4,48-63] - c_float_4p3 = _mm512_add_ps( selector4, c_float_4p3 ); + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector2, c_float_1p2 ); - // c[5,0-15] - c_float_5p0 = _mm512_add_ps( selector1, c_float_5p0 ); + // c[1,48-63] + c_float_1p3 = _mm512_add_ps( selector2, c_float_1p3 ); - // c[5, 16-31] - c_float_5p1 = _mm512_add_ps( selector2, c_float_5p1 ); + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); - // c[5,32-47] - c_float_5p2 = _mm512_add_ps( selector3, c_float_5p2 ); + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector3, c_float_2p1 ); - // c[5,48-63] - c_float_5p3 = _mm512_add_ps( selector4, c_float_5p3 ); + // c[2,32-47] + c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); + + // c[2,48-63] + c_float_2p3 = _mm512_add_ps( selector3, c_float_2p3 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector4, c_float_3p0 ); + + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector4, c_float_3p1 ); + + // c[3,32-47] + c_float_3p2 = _mm512_add_ps( selector4, c_float_3p2 ); + + // c[3,48-63] + c_float_3p3 = _mm512_add_ps( selector4, c_float_3p3 ); + + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( a_bf16_0, c_float_4p0 ); + + // c[4, 16-31] + c_float_4p1 = _mm512_add_ps( a_bf16_0, c_float_4p1 ); + + // c[4,32-47] + c_float_4p2 = _mm512_add_ps( a_bf16_0, c_float_4p2 ); + + // c[4,48-63] + c_float_4p3 = _mm512_add_ps( a_bf16_0, c_float_4p3 ); + + // c[5,0-15] + c_float_5p0 = _mm512_add_ps( a_bf16_1, c_float_5p0 ); + + // c[5, 16-31] + c_float_5p1 = _mm512_add_ps( a_bf16_1, c_float_5p1 ); + + // c[5,32-47] + c_float_5p2 = _mm512_add_ps( a_bf16_1, c_float_5p2 ); + + // c[5,48-63] + c_float_5p3 = _mm512_add_ps( a_bf16_1, c_float_5p3 ); + } POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } diff --git a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_m_fringe_bf16_amd512vnni.c b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_m_fringe_bf16_amd512vnni.c index 073668ba1f..cfd8f0a0dc 100644 --- a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_m_fringe_bf16_amd512vnni.c +++ 
b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_m_fringe_bf16_amd512vnni.c @@ -385,78 +385,169 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x64) POST_OP_LABEL_LASTK_SAFE_JUMP POST_OPS_BIAS_5x64: { - selector1 = + __m512 selector3; + __m512 selector4; + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + selector1 = _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j ); - selector2 = + post_op_c_j ); + selector2 = _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 1 * 16 ) ); - __m512 selector3 = + post_op_c_j + ( 1 * 16 ) ); + selector3 = _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 2 * 16 ) ); - __m512 selector4 = + post_op_c_j + ( 2 * 16 ) ); + selector4 = _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 3 * 16 ) ); + post_op_c_j + ( 3 * 16 ) ); - // c[0,0-15] - c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); - // c[0, 16-31] - c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); - // c[0,32-47] - c_float_0p2 = _mm512_add_ps( selector3, c_float_0p2 ); + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector3, c_float_0p2 ); - // c[0,48-63] - c_float_0p3 = _mm512_add_ps( selector4, c_float_0p3 ); + // c[0,48-63] + c_float_0p3 = _mm512_add_ps( selector4, c_float_0p3 ); - // c[1,0-15] - c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); - // c[1, 16-31] - c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); - // c[1,32-47] - c_float_1p2 = _mm512_add_ps( selector3, c_float_1p2 ); + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector3, c_float_1p2 ); - // c[1,48-63] - c_float_1p3 = _mm512_add_ps( selector4, c_float_1p3 ); + // c[1,48-63] + c_float_1p3 = _mm512_add_ps( selector4, c_float_1p3 ); - // c[2,0-15] - c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); - // c[2, 16-31] - c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); - // c[2,32-47] - c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); + // c[2,32-47] + c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); - // c[2,48-63] - c_float_2p3 = _mm512_add_ps( selector4, c_float_2p3 ); + // c[2,48-63] + c_float_2p3 = _mm512_add_ps( selector4, c_float_2p3 ); - // c[3,0-15] - c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); - // c[3, 16-31] - c_float_3p1 = _mm512_add_ps( selector2, c_float_3p1 ); + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector2, c_float_3p1 ); - // c[3,32-47] - c_float_3p2 = _mm512_add_ps( selector3, c_float_3p2 ); + // c[3,32-47] + c_float_3p2 = _mm512_add_ps( selector3, c_float_3p2 ); - // c[3,48-63] - c_float_3p3 = _mm512_add_ps( selector4, c_float_3p3 ); + // c[3,48-63] + c_float_3p3 = _mm512_add_ps( selector4, c_float_3p3 ); - // c[4,0-15] - c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); - // c[4, 16-31] - c_float_4p1 = _mm512_add_ps( selector2, c_float_4p1 ); + // c[4, 16-31] + c_float_4p1 = _mm512_add_ps( 
selector2, c_float_4p1 ); - // c[4,32-47] - c_float_4p2 = _mm512_add_ps( selector3, c_float_4p2 ); + // c[4,32-47] + c_float_4p2 = _mm512_add_ps( selector3, c_float_4p2 ); - // c[4,48-63] - c_float_4p3 = _mm512_add_ps( selector4, c_float_4p3 ); + // c[4,48-63] + c_float_4p3 = _mm512_add_ps( selector4, c_float_4p3 ); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the bias array will be accessed by + // the ic index, and each bias element corresponds to an + // entire row of the transposed output array, instead of an + // entire column. + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 3 ) ); + a_bf16_0 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 4 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_add_ps( selector1, c_float_0p3 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector2, c_float_1p2 ); + + // c[1,48-63] + c_float_1p3 = _mm512_add_ps( selector2, c_float_1p3 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector3, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); + + // c[2,48-63] + c_float_2p3 = _mm512_add_ps( selector3, c_float_2p3 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector4, c_float_3p0 ); + + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector4, c_float_3p1 ); + + // c[3,32-47] + c_float_3p2 = _mm512_add_ps( selector4, c_float_3p2 ); + + // c[3,48-63] + c_float_3p3 = _mm512_add_ps( selector4, c_float_3p3 ); + + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( a_bf16_0, c_float_4p0 ); + + // c[4, 16-31] + c_float_4p1 = _mm512_add_ps( a_bf16_0, c_float_4p1 ); + + // c[4,32-47] + c_float_4p2 = _mm512_add_ps( a_bf16_0, c_float_4p2 ); + + // c[4,48-63] + c_float_4p3 = _mm512_add_ps( a_bf16_0, c_float_4p3 ); + } POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -1014,66 +1105,142 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x64) POST_OP_LABEL_LASTK_SAFE_JUMP POST_OPS_BIAS_4x64: { - selector1 = + __m512 selector3; + __m512 selector4; + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + selector1 = _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j ); - selector2 = + post_op_c_j ); + selector2 = _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 1 * 16 ) ); - __m512 selector3 = + post_op_c_j + ( 1 * 16 ) ); + selector3 = _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 2 * 16 ) ); - __m512 selector4 = + post_op_c_j + ( 2 * 16 ) ); + selector4 = _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 3 * 16 ) ); + 
post_op_c_j + ( 3 * 16 ) ); - // c[0,0-15] - c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); - // c[0, 16-31] - c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); - // c[0,32-47] - c_float_0p2 = _mm512_add_ps( selector3, c_float_0p2 ); + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector3, c_float_0p2 ); - // c[0,48-63] - c_float_0p3 = _mm512_add_ps( selector4, c_float_0p3 ); + // c[0,48-63] + c_float_0p3 = _mm512_add_ps( selector4, c_float_0p3 ); - // c[1,0-15] - c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); - // c[1, 16-31] - c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); - // c[1,32-47] - c_float_1p2 = _mm512_add_ps( selector3, c_float_1p2 ); + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector3, c_float_1p2 ); - // c[1,48-63] - c_float_1p3 = _mm512_add_ps( selector4, c_float_1p3 ); + // c[1,48-63] + c_float_1p3 = _mm512_add_ps( selector4, c_float_1p3 ); - // c[2,0-15] - c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); - // c[2, 16-31] - c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); - // c[2,32-47] - c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); + // c[2,32-47] + c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); - // c[2,48-63] - c_float_2p3 = _mm512_add_ps( selector4, c_float_2p3 ); + // c[2,48-63] + c_float_2p3 = _mm512_add_ps( selector4, c_float_2p3 ); - // c[3,0-15] - c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); - // c[3, 16-31] - c_float_3p1 = _mm512_add_ps( selector2, c_float_3p1 ); + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector2, c_float_3p1 ); - // c[3,32-47] - c_float_3p2 = _mm512_add_ps( selector3, c_float_3p2 ); + // c[3,32-47] + c_float_3p2 = _mm512_add_ps( selector3, c_float_3p2 ); - // c[3,48-63] - c_float_3p3 = _mm512_add_ps( selector4, c_float_3p3 ); + // c[3,48-63] + c_float_3p3 = _mm512_add_ps( selector4, c_float_3p3 ); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the bias array will be accessed by + // the ic index, and each bias element corresponds to an + // entire row of the transposed output array, instead of an + // entire column. 
+ selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 2 ) ); + selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 3 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_add_ps( selector1, c_float_0p3 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector2, c_float_1p2 ); + + // c[1,48-63] + c_float_1p3 = _mm512_add_ps( selector2, c_float_1p3 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector3, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); + + // c[2,48-63] + c_float_2p3 = _mm512_add_ps( selector3, c_float_2p3 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector4, c_float_3p0 ); + + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector4, c_float_3p1 ); + + // c[3,32-47] + c_float_3p2 = _mm512_add_ps( selector4, c_float_3p2 ); + + // c[3,48-63] + c_float_3p3 = _mm512_add_ps( selector4, c_float_3p3 ); + } POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -1526,55 +1693,116 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x64) POST_OP_LABEL_LASTK_SAFE_JUMP POST_OPS_BIAS_3x64: { - selector1 = + __m512 selector3; + __m512 selector4; + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + selector1 = _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j ); - selector2 = + post_op_c_j ); + selector2 = _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 1 * 16 ) ); - __m512 selector3 = + post_op_c_j + ( 1 * 16 ) ); + selector3 = _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 2 * 16 ) ); - __m512 selector4 = + post_op_c_j + ( 2 * 16 ) ); + selector4 = _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 3 * 16 ) ); + post_op_c_j + ( 3 * 16 ) ); - // c[0,0-15] - c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); - // c[0, 16-31] - c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); - // c[0,32-47] - c_float_0p2 = _mm512_add_ps( selector3, c_float_0p2 ); + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector3, c_float_0p2 ); - // c[0,48-63] - c_float_0p3 = _mm512_add_ps( selector4, c_float_0p3 ); + // c[0,48-63] + c_float_0p3 = _mm512_add_ps( selector4, c_float_0p3 ); - // c[1,0-15] - c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); - // c[1, 16-31] - c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); - // c[1,32-47] - c_float_1p2 = _mm512_add_ps( selector3, c_float_1p2 ); + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector3, c_float_1p2 ); - // 
c[1,48-63] - c_float_1p3 = _mm512_add_ps( selector4, c_float_1p3 ); + // c[1,48-63] + c_float_1p3 = _mm512_add_ps( selector4, c_float_1p3 ); - // c[2,0-15] - c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); - // c[2, 16-31] - c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); - // c[2,32-47] - c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); + // c[2,32-47] + c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); - // c[2,48-63] - c_float_2p3 = _mm512_add_ps( selector4, c_float_2p3 ); + // c[2,48-63] + c_float_2p3 = _mm512_add_ps( selector4, c_float_2p3 ); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the bias array will be accessed by + // the ic index, and each bias element corresponds to an + // entire row of the transposed output array, instead of an + // entire column. + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 2 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_add_ps( selector1, c_float_0p3 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector2, c_float_1p2 ); + // c[1,48-63] + c_float_1p3 = _mm512_add_ps( selector2, c_float_1p3 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector3, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); + + // c[2,48-63] + c_float_2p3 = _mm512_add_ps( selector3, c_float_2p3 ); + } + POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } POST_OPS_RELU_3x64: @@ -1919,42 +2147,88 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x64) POST_OP_LABEL_LASTK_SAFE_JUMP POST_OPS_BIAS_2x64: { - selector1 = + __m512 selector3; + __m512 selector4; + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + selector1 = _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j ); - selector2 = + post_op_c_j ); + selector2 = _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 1 * 16 ) ); - __m512 selector3 = + post_op_c_j + ( 1 * 16 ) ); + selector3 = _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 2 * 16 ) ); - __m512 selector4 = + post_op_c_j + ( 2 * 16 ) ); + selector4 = _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 3 * 16 ) ); + post_op_c_j + ( 3 * 16 ) ); - // c[0,0-15] - c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); - // c[0, 16-31] - c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); - // c[0,32-47] - c_float_0p2 = _mm512_add_ps( 
selector3, c_float_0p2 ); + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector3, c_float_0p2 ); - // c[0,48-63] - c_float_0p3 = _mm512_add_ps( selector4, c_float_0p3 ); + // c[0,48-63] + c_float_0p3 = _mm512_add_ps( selector4, c_float_0p3 ); - // c[1,0-15] - c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); - // c[1, 16-31] - c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); - // c[1,32-47] - c_float_1p2 = _mm512_add_ps( selector3, c_float_1p2 ); + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector3, c_float_1p2 ); - // c[1,48-63] - c_float_1p3 = _mm512_add_ps( selector4, c_float_1p3 ); + // c[1,48-63] + c_float_1p3 = _mm512_add_ps( selector4, c_float_1p3 ); + } + else + { + // If original output was columns major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the bias array will be accessed by + // the ic index, and each bias element corresponds to an + // entire row of the transposed output array, instead of an + // entire column. + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 1 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_add_ps( selector1, c_float_0p3 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector2, c_float_1p2 ); + + // c[1,48-63] + c_float_1p3 = _mm512_add_ps( selector2, c_float_1p3 ); + } POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -2184,30 +2458,61 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1x64) POST_OP_LABEL_LASTK_SAFE_JUMP POST_OPS_BIAS_1x64: { - selector1 = + __m512 selector3; + __m512 selector4; + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + selector1 = _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j ); - selector2 = + post_op_c_j ); + selector2 = _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 1 * 16 ) ); - __m512 selector3 = + post_op_c_j + ( 1 * 16 ) ); + selector3 = _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 2 * 16 ) ); - __m512 selector4 = + post_op_c_j + ( 2 * 16 ) ); + selector4 = _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 3 * 16 ) ); - - // c[0,0-15] - c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); - - // c[0, 16-31] - c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); - - // c[0,32-47] - c_float_0p2 = _mm512_add_ps( selector3, c_float_0p2 ); - - // c[0,48-63] - c_float_0p3 = _mm512_add_ps( selector4, c_float_0p3 ); + post_op_c_j + ( 3 * 16 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector3, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_add_ps( selector4, c_float_0p3 ); + } + else + { + // If original output was columns 
major, then by the time + // kernel sees it, the matrix would be accessed as if it were + // transposed. Due to this the bias array will be accessed by + // the ic index, and each bias element corresponds to an + // entire row of the transposed output array, instead of an + // entire column. + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 0 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); + + // c[0,48-63] + c_float_0p3 = _mm512_add_ps( selector1, c_float_0p3 ); + } POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } diff --git a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_mn_fringe_bf16_amd512vnni.c b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_mn_fringe_bf16_amd512vnni.c index 77dd30acd2..65bce97799 100644 --- a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_mn_fringe_bf16_amd512vnni.c +++ b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_mn_fringe_bf16_amd512vnni.c @@ -218,24 +218,61 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5xlt16) POST_OP_LABEL_LASTK_SAFE_JUMP POST_OPS_BIAS_5xLT16: { - memcpy( buf0, ( ( float* )post_ops_list_temp->op_args1 + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + memcpy( buf0, ( ( float* )post_ops_list_temp->op_args1 + post_op_c_j ), ( n0_rem * sizeof( float ) ) ); - selector1 = _mm512_loadu_ps( buf0 ); - - // c[0,0-15] - c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); - - // c[1,0-15] - c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); - - // c[2,0-15] - c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); - - // c[3,0-15] - c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); - - // c[4,0-15] - c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); + selector1 = _mm512_loadu_ps( buf0 ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 1 ) ); + __m512 selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 2 ) ); + __m512 selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 3 ) ); + a_bf16_0 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 4 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector4, c_float_3p0 ); + + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( a_bf16_0, c_float_4p0 ); + } POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -496,21 +533,52 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4xlt16) POST_OP_LABEL_LASTK_SAFE_JUMP POST_OPS_BIAS_4xLT16: { - memcpy( buf0, ( ( float* )post_ops_list_temp->op_args1 + + if ( ( *( char* )post_ops_list_temp->op_args2 == 
'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + memcpy( buf0, ( ( float* )post_ops_list_temp->op_args1 + post_op_c_j ), ( n0_rem * sizeof( float ) ) ); - selector1 = _mm512_loadu_ps( buf0 ); - - // c[0,0-15] - c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); - - // c[1,0-15] - c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); - - // c[2,0-15] - c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); - - // c[3,0-15] - c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + selector1 = _mm512_loadu_ps( buf0 ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 1 ) ); + __m512 selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 2 ) ); + a_bf16_0 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 3 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( a_bf16_0, c_float_3p0 ); + } POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -729,18 +797,43 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3xlt16) POST_OP_LABEL_LASTK_SAFE_JUMP POST_OPS_BIAS_3xLT16: { - memcpy( buf0, ( ( float* )post_ops_list_temp->op_args1 + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + memcpy( buf0, ( ( float* )post_ops_list_temp->op_args1 + post_op_c_j ), ( n0_rem * sizeof( float ) ) ); - selector1 = _mm512_loadu_ps( buf0 ); - - // c[0,0-15] - c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); - - // c[1,0-15] - c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); - - // c[2,0-15] - c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + selector1 = _mm512_loadu_ps( buf0 ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 1 ) ); + a_bf16_0 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 2 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( a_bf16_0, c_float_2p0 ); + } POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -918,15 +1011,34 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2xlt16) POST_OP_LABEL_LASTK_SAFE_JUMP POST_OPS_BIAS_2xLT16: { - memcpy( buf0, ( ( float* )post_ops_list_temp->op_args1 + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + memcpy( buf0, ( ( float* )post_ops_list_temp->op_args1 + post_op_c_j 
), ( n0_rem * sizeof( float ) ) ); - selector1 = _mm512_loadu_ps( buf0 ); - - // c[0,0-15] - c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); - - // c[1,0-15] - c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + selector1 = _mm512_loadu_ps( buf0 ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 1 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + } POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -1063,12 +1175,25 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1xlt16) POST_OP_LABEL_LASTK_SAFE_JUMP POST_OPS_BIAS_1xLT16: { - memcpy( buf0, ( ( float* )post_ops_list_temp->op_args1 + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + memcpy( buf0, ( ( float* )post_ops_list_temp->op_args1 + post_op_c_j ), ( n0_rem * sizeof( float ) ) ); - selector1 = _mm512_loadu_ps( buf0 ); + selector1 = _mm512_loadu_ps( buf0 ); - // c[0,0-15] - c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 0 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + } POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -1280,24 +1405,61 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x16) POST_OP_LABEL_LASTK_SAFE_JUMP POST_OPS_BIAS_5x16: { - selector1 = + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + selector1 = _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j ); + post_op_c_j ); - // c[0,0-15] - c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); - // c[1,0-15] - c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); - // c[2,0-15] - c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); - // c[3,0-15] - c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); - // c[4,0-15] - c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 1 ) ); + __m512 selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 2 ) ); + __m512 selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 3 ) ); + __m512 selector5 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 4 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( 
selector3, c_float_2p0 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector4, c_float_3p0 ); + + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector5, c_float_4p0 ); + } POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -1528,21 +1690,52 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x16) POST_OP_LABEL_LASTK_SAFE_JUMP POST_OPS_BIAS_4x16: { - selector1 = + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + selector1 = _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j ); + post_op_c_j ); - // c[0,0-15] - c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); - // c[1,0-15] - c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); - // c[2,0-15] - c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); - // c[3,0-15] - c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 1 ) ); + __m512 selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 2 ) ); + __m512 selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 3 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector4, c_float_3p0 ); + } POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -1737,18 +1930,43 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x16) POST_OP_LABEL_LASTK_SAFE_JUMP POST_OPS_BIAS_3x16: { - selector1 = + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + selector1 = _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j ); + post_op_c_j ); - // c[0,0-15] - c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); - // c[1,0-15] - c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); - // c[2,0-15] - c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 1 ) ); + __m512 selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 2 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); + } POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -1908,15 +2126,34 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x16) POST_OP_LABEL_LASTK_SAFE_JUMP POST_OPS_BIAS_2x16: { - selector1 = + if ( ( *( char* 
)post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + selector1 = _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j ); + post_op_c_j ); - // c[0,0-15] - c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); - // c[1,0-15] - c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 1 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + } POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -2039,12 +2276,25 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1x16) POST_OP_LABEL_LASTK_SAFE_JUMP POST_OPS_BIAS_1x16: { - selector1 = + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + selector1 = _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j ); + post_op_c_j ); - // c[0,0-15] - c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 0 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + } POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -2299,42 +2549,94 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x32) POST_OP_LABEL_LASTK_SAFE_JUMP POST_OPS_BIAS_5x32: { - selector1 = + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + selector1 = _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 0 * 16 ) ); - selector2 = + post_op_c_j + ( 0 * 16 ) ); + selector2 = _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 1 * 16 ) ); + post_op_c_j + ( 1 * 16 ) ); - // c[0,0-15] - c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); - // c[0, 16-31] - c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); - // c[1,0-15] - c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); - // c[1, 16-31] - c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); - // c[2,0-15] - c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); - // c[2, 16-31] - c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); - // c[3,0-15] - c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); - // c[3, 16-31] - c_float_3p1 = _mm512_add_ps( selector2, c_float_3p1 ); + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector2, c_float_3p1 ); - // c[4,0-15] - c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); - // c[4, 
16-31] - c_float_4p1 = _mm512_add_ps( selector2, c_float_4p1 ); + // c[4, 16-31] + c_float_4p1 = _mm512_add_ps( selector2, c_float_4p1 ); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 1 ) ); + __m512 selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 2 ) ); + __m512 selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 3 ) ); + __m512 selector5 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 4 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector3, c_float_2p1 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector4, c_float_3p0 ); + + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector4, c_float_3p1 ); + + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector5, c_float_4p0 ); + + // c[4, 16-31] + c_float_4p1 = _mm512_add_ps( selector5, c_float_4p1 ); + } POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -2664,36 +2966,79 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x32) POST_OP_LABEL_LASTK_SAFE_JUMP POST_OPS_BIAS_4x32: { - selector1 = + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + selector1 = _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 0 * 16 ) ); - selector2 = + post_op_c_j + ( 0 * 16 ) ); + selector2 = _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 1 * 16 ) ); + post_op_c_j + ( 1 * 16 ) ); - // c[0,0-15] - c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); - // c[0, 16-31] - c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); - // c[1,0-15] - c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); - // c[1, 16-31] - c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); - // c[2,0-15] - c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); - // c[2, 16-31] - c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); - // c[3,0-15] - c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); - // c[3, 16-31] - c_float_3p1 = _mm512_add_ps( selector2, c_float_3p1 ); + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector2, c_float_3p1 ); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 1 ) ); + __m512 selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 2 ) ); + __m512 selector4 = + _mm512_set1_ps( *( 
( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 3 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector3, c_float_2p1 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector4, c_float_3p0 ); + + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector4, c_float_3p1 ); + } POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -2966,30 +3311,64 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x32) POST_OP_LABEL_LASTK_SAFE_JUMP POST_OPS_BIAS_3x32: { - selector1 = + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + selector1 = _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 0 * 16 ) ); - selector2 = + post_op_c_j + ( 0 * 16 ) ); + selector2 = _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 1 * 16 ) ); + post_op_c_j + ( 1 * 16 ) ); - // c[0,0-15] - c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); - // c[0, 16-31] - c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); - // c[1,0-15] - c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); - // c[1, 16-31] - c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); - // c[2,0-15] - c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); - // c[2, 16-31] - c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 1 ) ); + __m512 selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 2 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector3, c_float_2p1 ); + } POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -3205,24 +3584,49 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x32) POST_OP_LABEL_LASTK_SAFE_JUMP POST_OPS_BIAS_2x32: { - selector1 = + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + selector1 = _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 0 * 16 ) ); - selector2 = + post_op_c_j + ( 0 * 16 ) ); + selector2 = _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 1 * 16 ) ); + post_op_c_j + ( 1 * 16 ) ); - // c[0,0-15] - c_float_0p0 = 
_mm512_add_ps( selector1, c_float_0p0 ); + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); - // c[0, 16-31] - c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); - // c[1,0-15] - c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); - // c[1, 16-31] - c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 1 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + } POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -3381,18 +3785,34 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1x32) POST_OP_LABEL_LASTK_SAFE_JUMP POST_OPS_BIAS_1x32: { - selector1 = + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + selector1 = _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 0 * 16 ) ); - selector2 = + post_op_c_j + ( 0 * 16 ) ); + selector2 = _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 1 * 16 ) ); + post_op_c_j + ( 1 * 16 ) ); - // c[0,0-15] - c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); - // c[0, 16-31] - c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 0 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + } POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -3707,60 +4127,129 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x48) POST_OP_LABEL_LASTK_SAFE_JUMP POST_OPS_BIAS_5x48: { - selector1 = + __m512 selector3; + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + selector1 = _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 0 * 16 ) ); - selector2 = + post_op_c_j + ( 0 * 16 ) ); + selector2 = _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 1 * 16 ) ); - __m512 selector3 = + post_op_c_j + ( 1 * 16 ) ); + selector3 = _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 2 * 16 ) ); + post_op_c_j + ( 2 * 16 ) ); - // c[0,0-15] - c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); - // c[0, 16-31] - c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); - // c[0,32-47] - c_float_0p2 = _mm512_add_ps( selector3, c_float_0p2 ); + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector3, c_float_0p2 ); - // c[1,0-15] - c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + // c[1,0-15] + c_float_1p0 = 
_mm512_add_ps( selector1, c_float_1p0 ); - // c[1, 16-31] - c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); - // c[1,32-47] - c_float_1p2 = _mm512_add_ps( selector3, c_float_1p2 ); + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector3, c_float_1p2 ); - // c[2,0-15] - c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); - // c[2, 16-31] - c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); - // c[2,32-47] - c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); + // c[2,32-47] + c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); - // c[3,0-15] - c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); - // c[3, 16-31] - c_float_3p1 = _mm512_add_ps( selector2, c_float_3p1 ); + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector2, c_float_3p1 ); - // c[3,32-47] - c_float_3p2 = _mm512_add_ps( selector3, c_float_3p2 ); + // c[3,32-47] + c_float_3p2 = _mm512_add_ps( selector3, c_float_3p2 ); - // c[4,0-15] - c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); - // c[4, 16-31] - c_float_4p1 = _mm512_add_ps( selector2, c_float_4p1 ); + // c[4, 16-31] + c_float_4p1 = _mm512_add_ps( selector2, c_float_4p1 ); - // c[4,32-47] - c_float_4p2 = _mm512_add_ps( selector3, c_float_4p2 ); + // c[4,32-47] + c_float_4p2 = _mm512_add_ps( selector3, c_float_4p2 ); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 2 ) ); + __m512 selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 3 ) ); + __m512 selector5 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 4 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector2, c_float_1p2 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector3, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector4, c_float_3p0 ); + + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector4, c_float_3p1 ); + + // c[3,32-47] + c_float_3p2 = _mm512_add_ps( selector4, c_float_3p2 ); + + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector5, c_float_4p0 ); + + // c[4, 16-31] + c_float_4p1 = _mm512_add_ps( selector5, c_float_4p1 ); + + // c[4,32-47] + c_float_4p2 = _mm512_add_ps( selector5, c_float_4p2 ); + } POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -4189,51 +4678,108 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4x48) POST_OP_LABEL_LASTK_SAFE_JUMP POST_OPS_BIAS_4x48: { - selector1 = + __m512 selector3; + + if ( ( *( char* 
)post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + selector1 = _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 0 * 16 ) ); - selector2 = + post_op_c_j + ( 0 * 16 ) ); + selector2 = _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 1 * 16 ) ); - __m512 selector3 = + post_op_c_j + ( 1 * 16 ) ); + selector3 = _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 2 * 16 ) ); + post_op_c_j + ( 2 * 16 ) ); - // c[0,0-15] - c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); - // c[0, 16-31] - c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); - // c[0,32-47] - c_float_0p2 = _mm512_add_ps( selector3, c_float_0p2 ); + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector3, c_float_0p2 ); - // c[1,0-15] - c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); - // c[1, 16-31] - c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); - // c[1,32-47] - c_float_1p2 = _mm512_add_ps( selector3, c_float_1p2 ); + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector3, c_float_1p2 ); - // c[2,0-15] - c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); - // c[2, 16-31] - c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); - // c[2,32-47] - c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); + // c[2,32-47] + c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); - // c[3,0-15] - c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); - // c[3, 16-31] - c_float_3p1 = _mm512_add_ps( selector2, c_float_3p1 ); + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector2, c_float_3p1 ); - // c[3,32-47] - c_float_3p2 = _mm512_add_ps( selector3, c_float_3p2 ); + // c[3,32-47] + c_float_3p2 = _mm512_add_ps( selector3, c_float_3p2 ); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 2 ) ); + __m512 selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 3 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector2, c_float_1p2 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector3, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector4, c_float_3p0 ); + + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector4, c_float_3p1 ); + + // c[3,32-47] + 
c_float_3p2 = _mm512_add_ps( selector4, c_float_3p2 ); + } POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -4584,42 +5130,87 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3x48) POST_OP_LABEL_LASTK_SAFE_JUMP POST_OPS_BIAS_3x48: { - selector1 = + __m512 selector3; + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + selector1 = _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 0 * 16 ) ); - selector2 = + post_op_c_j + ( 0 * 16 ) ); + selector2 = _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 1 * 16 ) ); - __m512 selector3 = + post_op_c_j + ( 1 * 16 ) ); + selector3 = _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 2 * 16 ) ); + post_op_c_j + ( 2 * 16 ) ); - // c[0,0-15] - c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); - // c[0, 16-31] - c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); - // c[0,32-47] - c_float_0p2 = _mm512_add_ps( selector3, c_float_0p2 ); + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector3, c_float_0p2 ); - // c[1,0-15] - c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); - // c[1, 16-31] - c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); - // c[1,32-47] - c_float_1p2 = _mm512_add_ps( selector3, c_float_1p2 ); + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector3, c_float_1p2 ); - // c[2,0-15] - c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); - // c[2, 16-31] - c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); - // c[2,32-47] - c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); + // c[2,32-47] + c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 2 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector2, c_float_1p2 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector3, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); + } POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -4892,33 +5483,66 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_2x48) POST_OP_LABEL_LASTK_SAFE_JUMP POST_OPS_BIAS_2x48: { - selector1 = + __m512 selector3; + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + selector1 = _mm512_loadu_ps( ( float* 
)post_ops_list_temp->op_args1 + - post_op_c_j + ( 0 * 16 ) ); - selector2 = + post_op_c_j + ( 0 * 16 ) ); + selector2 = _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 1 * 16 ) ); - __m512 selector3 = + post_op_c_j + ( 1 * 16 ) ); + selector3 = _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 2 * 16 ) ); + post_op_c_j + ( 2 * 16 ) ); - // c[0,0-15] - c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); - // c[0, 16-31] - c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); - // c[0,32-47] - c_float_0p2 = _mm512_add_ps( selector3, c_float_0p2 ); + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector3, c_float_0p2 ); - // c[1,0-15] - c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); - // c[1, 16-31] - c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); - // c[1,32-47] - c_float_1p2 = _mm512_add_ps( selector3, c_float_1p2 ); + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector3, c_float_1p2 ); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 1 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector2, c_float_1p2 ); + } POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -5113,24 +5737,45 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1x48) POST_OP_LABEL_LASTK_SAFE_JUMP POST_OPS_BIAS_1x48: { - selector1 = + __m512 selector3; + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + selector1 = _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 0 * 16 ) ); - selector2 = + post_op_c_j + ( 0 * 16 ) ); + selector2 = _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 1 * 16 ) ); - __m512 selector3 = + post_op_c_j + ( 1 * 16 ) ); + selector3 = _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 2 * 16 ) ); + post_op_c_j + ( 2 * 16 ) ); - // c[0,0-15] - c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); - // c[0, 16-31] - c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); - // c[0,32-47] - c_float_0p2 = _mm512_add_ps( selector3, c_float_0p2 ); + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector3, c_float_0p2 ); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 0 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); + } POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } diff --git 
a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_n_fringe_bf16_amd512vnni.c b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_n_fringe_bf16_amd512vnni.c index e3c3f1c524..7207cf41dc 100644 --- a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_n_fringe_bf16_amd512vnni.c +++ b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_n_fringe_bf16_amd512vnni.c @@ -286,27 +286,70 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6xlt16) POST_OP_LABEL_LASTK_SAFE_JUMP POST_OPS_BIAS_6xLT16: { - memcpy( buf0, ( ( float* )post_ops_list_temp->op_args1 + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + memcpy( buf0, ( ( float* )post_ops_list_temp->op_args1 + post_op_c_j ), ( n0_rem * sizeof( float ) ) ); - selector1 = _mm512_loadu_ps( buf0 ); - - // c[0,0-15] - c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); - - // c[1,0-15] - c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); - - // c[2,0-15] - c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); - - // c[3,0-15] - c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); - - // c[4,0-15] - c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); - - // c[5,0-15] - c_float_5p0 = _mm512_add_ps( selector1, c_float_5p0 ); + selector1 = _mm512_loadu_ps( buf0 ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); + + // c[5,0-15] + c_float_5p0 = _mm512_add_ps( selector1, c_float_5p0 ); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 1 ) ); + __m512 selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 2 ) ); + __m512 selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 3 ) ); + __m512 selector5 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 4 ) ); + a_bf16_0 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 5 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector4, c_float_3p0 ); + + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector5, c_float_4p0 ); + + // c[5,0-15] + c_float_5p0 = _mm512_add_ps( a_bf16_0, c_float_5p0 ); + } POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -742,27 +785,70 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x16) POST_OP_LABEL_LASTK_SAFE_JUMP POST_OPS_BIAS_6x16: { - selector1 = + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + selector1 = _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j ); - - // c[0,0-15] - c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); - - // c[1,0-15] - c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); - - // c[2,0-15] - c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); - - // c[3,0-15] - c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); - - // c[4,0-15] - c_float_4p0 = 
_mm512_add_ps( selector1, c_float_4p0 ); - - // c[5,0-15] - c_float_5p0 = _mm512_add_ps( selector1, c_float_5p0 ); + post_op_c_j ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); + + // c[5,0-15] + c_float_5p0 = _mm512_add_ps( selector1, c_float_5p0 ); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 1 ) ); + __m512 selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 2 ) ); + __m512 selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 3 ) ); + __m512 selector5 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 4 ) ); + a_bf16_0 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 5 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector4, c_float_3p0 ); + + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector5, c_float_4p0 ); + + // c[5,0-15] + c_float_5p0 = _mm512_add_ps( a_bf16_0, c_float_5p0 ); + } POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -1234,48 +1320,109 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x32) POST_OP_LABEL_LASTK_SAFE_JUMP POST_OPS_BIAS_6x32: { - selector1 = + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + selector1 = _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 0 * 16 ) ); - selector2 = + post_op_c_j + ( 0 * 16 ) ); + selector2 = _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 1 * 16 ) ); + post_op_c_j + ( 1 * 16 ) ); - // c[0,0-15] - c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); - // c[0, 16-31] - c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); - // c[1,0-15] - c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); - // c[1, 16-31] - c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); - // c[2,0-15] - c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); - // c[2, 16-31] - c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); - // c[3,0-15] - c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); - // c[3, 16-31] - c_float_3p1 = _mm512_add_ps( selector2, c_float_3p1 ); + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector2, c_float_3p1 ); - // c[4,0-15] - c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector1, 
c_float_4p0 ); - // c[4, 16-31] - c_float_4p1 = _mm512_add_ps( selector2, c_float_4p1 ); + // c[4, 16-31] + c_float_4p1 = _mm512_add_ps( selector2, c_float_4p1 ); - // c[5,0-15] - c_float_5p0 = _mm512_add_ps( selector1, c_float_5p0 ); + // c[5,0-15] + c_float_5p0 = _mm512_add_ps( selector1, c_float_5p0 ); - // c[5, 16-31] - c_float_5p1 = _mm512_add_ps( selector2, c_float_5p1 ); + // c[5, 16-31] + c_float_5p1 = _mm512_add_ps( selector2, c_float_5p1 ); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 1 ) ); + __m512 selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 2 ) ); + __m512 selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 3 ) ); + __m512 selector5 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 4 ) ); + a_bf16_0 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 5 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector3, c_float_2p1 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector4, c_float_3p0 ); + + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector4, c_float_3p1 ); + + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector5, c_float_4p0 ); + + // c[4, 16-31] + c_float_4p1 = _mm512_add_ps( selector5, c_float_4p1 ); + + // c[5,0-15] + c_float_5p0 = _mm512_add_ps( a_bf16_0, c_float_5p0 ); + + // c[5, 16-31] + c_float_5p1 = _mm512_add_ps( a_bf16_0, c_float_5p1 ); + } POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } @@ -1879,69 +2026,150 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x48) POST_OP_LABEL_LASTK_SAFE_JUMP POST_OPS_BIAS_6x48: { - selector1 = + __m512 selector3; + + if ( ( *( char* )post_ops_list_temp->op_args2 == 'r' ) || + ( *( char* )post_ops_list_temp->op_args2 == 'R' ) ) + { + selector1 = _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 0 * 16 ) ); - selector2 = + post_op_c_j + ( 0 * 16 ) ); + selector2 = _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 1 * 16 ) ); - __m512 selector3 = + post_op_c_j + ( 1 * 16 ) ); + selector3 = _mm512_loadu_ps( ( float* )post_ops_list_temp->op_args1 + - post_op_c_j + ( 2 * 16 ) ); + post_op_c_j + ( 2 * 16 ) ); - // c[0,0-15] - c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); - // c[0, 16-31] - c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector2, c_float_0p1 ); - // c[0,32-47] - c_float_0p2 = _mm512_add_ps( selector3, c_float_0p2 ); + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector3, c_float_0p2 ); - // c[1,0-15] - c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector1, c_float_1p0 ); - // c[1, 16-31] - c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); - // c[1,32-47] - c_float_1p2 = _mm512_add_ps( selector3, 
c_float_1p2 ); + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector3, c_float_1p2 ); - // c[2,0-15] - c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector1, c_float_2p0 ); - // c[2, 16-31] - c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector2, c_float_2p1 ); - // c[2,32-47] - c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); + // c[2,32-47] + c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); - // c[3,0-15] - c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector1, c_float_3p0 ); - // c[3, 16-31] - c_float_3p1 = _mm512_add_ps( selector2, c_float_3p1 ); + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector2, c_float_3p1 ); - // c[3,32-47] - c_float_3p2 = _mm512_add_ps( selector3, c_float_3p2 ); + // c[3,32-47] + c_float_3p2 = _mm512_add_ps( selector3, c_float_3p2 ); - // c[4,0-15] - c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector1, c_float_4p0 ); - // c[4, 16-31] - c_float_4p1 = _mm512_add_ps( selector2, c_float_4p1 ); + // c[4, 16-31] + c_float_4p1 = _mm512_add_ps( selector2, c_float_4p1 ); - // c[4,32-47] - c_float_4p2 = _mm512_add_ps( selector3, c_float_4p2 ); + // c[4,32-47] + c_float_4p2 = _mm512_add_ps( selector3, c_float_4p2 ); - // c[5,0-15] - c_float_5p0 = _mm512_add_ps( selector1, c_float_5p0 ); + // c[5,0-15] + c_float_5p0 = _mm512_add_ps( selector1, c_float_5p0 ); - // c[5, 16-31] - c_float_5p1 = _mm512_add_ps( selector2, c_float_5p1 ); + // c[5, 16-31] + c_float_5p1 = _mm512_add_ps( selector2, c_float_5p1 ); - // c[5,32-47] - c_float_5p2 = _mm512_add_ps( selector3, c_float_5p2 ); + // c[5,32-47] + c_float_5p2 = _mm512_add_ps( selector3, c_float_5p2 ); + } + else + { + selector1 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 0 ) ); + selector2 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 1 ) ); + selector3 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 2 ) ); + __m512 selector4 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 3 ) ); + __m512 selector5 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 4 ) ); + a_bf16_0 = + _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + + post_op_c_i + 5 ) ); + + // c[0,0-15] + c_float_0p0 = _mm512_add_ps( selector1, c_float_0p0 ); + + // c[0, 16-31] + c_float_0p1 = _mm512_add_ps( selector1, c_float_0p1 ); + + // c[0,32-47] + c_float_0p2 = _mm512_add_ps( selector1, c_float_0p2 ); + + // c[1,0-15] + c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); + + // c[1, 16-31] + c_float_1p1 = _mm512_add_ps( selector2, c_float_1p1 ); + + // c[1,32-47] + c_float_1p2 = _mm512_add_ps( selector2, c_float_1p2 ); + + // c[2,0-15] + c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); + + // c[2, 16-31] + c_float_2p1 = _mm512_add_ps( selector3, c_float_2p1 ); + + // c[2,32-47] + c_float_2p2 = _mm512_add_ps( selector3, c_float_2p2 ); + + // c[3,0-15] + c_float_3p0 = _mm512_add_ps( selector4, c_float_3p0 ); + + // c[3, 16-31] + c_float_3p1 = _mm512_add_ps( selector4, c_float_3p1 ); + + // c[3,32-47] + c_float_3p2 = _mm512_add_ps( selector4, c_float_3p2 ); + + // c[4,0-15] + c_float_4p0 = _mm512_add_ps( selector5, c_float_4p0 ); + + // c[4, 16-31] + c_float_4p1 = _mm512_add_ps( selector5, c_float_4p1 ); + + // c[4,32-47] + c_float_4p2 = 
_mm512_add_ps( selector5, c_float_4p2 ); + + // c[5,0-15] + c_float_5p0 = _mm512_add_ps( a_bf16_0, c_float_5p0 ); + + // c[5, 16-31] + c_float_5p1 = _mm512_add_ps( a_bf16_0, c_float_5p1 ); + + // c[5,32-47] + c_float_5p2 = _mm512_add_ps( a_bf16_0, c_float_5p2 ); + } POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR } diff --git a/bench/bench_aocl_gemm/bench_input.txt b/bench/bench_aocl_gemm/bench_input.txt index 617dced6df..faec5d0b4f 100644 --- a/bench/bench_aocl_gemm/bench_input.txt +++ b/bench/bench_aocl_gemm/bench_input.txt @@ -1,527 +1,527 @@ -s r 480 20 2050 2050 20 20 -s r 481 20 2050 2050 20 20 -s r 482 20 2050 2050 20 20 -s p 483 20 2050 2050 20 20 -s R 484 20 2050 2050 20 20 -s R 485 20 2050 2050 20 20 -s R 480 39 2050 2050 39 39 -s R 481 39 2050 2050 39 39 -s R 482 39 2050 2050 39 39 -s R 483 39 2050 2050 39 39 -s R 484 39 2050 2050 39 39 -s p 485 39 2050 2050 39 39 -s p 480 50 2050 2050 50 50 -s p 481 50 2050 2050 50 50 -s p 482 50 2050 2050 50 50 -s p 483 50 2050 2050 50 50 -s p 484 50 2050 2050 50 50 -s p 485 50 2050 2050 50 50 -s R 480 1108 2050 2050 1108 1108 -s R 481 1108 2050 2050 1108 1108 -s R 482 1108 2050 2050 1108 1108 -s R 483 1108 2050 2050 1108 1108 -s R 484 1108 2050 2050 1108 1108 -s R 485 1108 2050 2050 1108 1108 -s R 480 1127 2050 2050 1127 1127 -s R 481 1127 2050 2050 1127 1127 -s R 482 1127 2050 2050 1127 1127 -s R 483 1127 2050 2050 1127 1127 -s p 484 1127 2050 2050 1127 1127 -s p 485 1127 2050 2050 1127 1127 -s p 480 1138 2050 2050 1138 1138 -s p 481 1138 2050 2050 1138 1138 -s p 482 1138 2050 2050 1138 1138 -s p 483 1138 2050 2050 1138 1138 -s p 484 1138 2050 2050 1138 1138 -s p 485 1138 2050 2050 1138 1138 -s p 1 1 3 3 1 1 -s p 1 9 3 3 9 9 -s p 1 2048 3 3 2048 2048 -s p 1 2048 5192 5192 2048 2048 -s p 9 1 3 3 1 1 -s p 576 1 3500 3500 1 1 -s p 1 1 1 1 1 1 -s p 102 1088 1024 1024 1088 1088 -s p 102 2048 1024 1024 2048 2048 -s p 485 656 1024 1024 656 656 -s p 483 656 1024 1024 656 656 -s p 81 128 3 3 128 128 -s p 1022 512 515 515 512 512 -s p 74 512 515 515 512 512 -s p 253 2048 515 515 2048 2048 -s p 8192 1040 515 515 1040 1040 -s p 10 1029 515 515 1029 1029 -s p 24 1040 2050 2050 1040 1040 -s p 1024 1029 2050 2050 1029 1029 -s p 480 660 2050 2050 660 660 -s p 481 660 2050 2050 660 660 -s p 482 660 2050 2050 660 660 -s p 483 660 2050 2050 660 660 -s p 484 660 2050 2050 660 660 -s p 485 660 2050 2050 660 660 -s p 480 679 2050 2050 679 679 -s p 481 679 2050 2050 679 679 -s p 482 679 2050 2050 679 679 -s p 483 679 2050 2050 679 679 -s p 484 679 2050 2050 679 679 -s p 485 679 2050 2050 679 679 -s p 480 690 2050 2050 690 690 -s p 481 690 2050 2050 690 690 -s p 482 690 2050 2050 690 690 -s p 483 690 2050 2050 690 690 -s p 484 690 2050 2050 690 690 -s p 485 690 2050 2050 690 690 -s p 480 660 2048 2048 660 660 -s p 481 660 2048 2048 660 660 -s p 482 660 2048 2048 660 660 -s p 483 660 2048 2048 660 660 -s p 484 660 2048 2048 660 660 -s p 485 660 2048 2048 660 660 -s p 480 679 2048 2048 679 679 -s p 481 679 2048 2048 679 679 -s p 482 679 2048 2048 679 679 -s p 483 679 2048 2048 679 679 -s p 484 679 2048 2048 679 679 -s p 485 679 2048 2048 679 679 -s p 480 690 2048 2048 690 690 -s p 481 690 2048 2048 690 690 -s p 482 690 2048 2048 690 690 -s p 483 690 2048 2048 690 690 -s p 484 690 2048 2048 690 690 -s p 485 690 2048 2048 690 690 -s p 480 656 1024 1024 656 656 -s p 480 128 3 3 128 128 -s p 1024 512 515 515 512 512 -s p 1024 2048 1024 1024 2048 2048 -s p 1024 2048 515 515 2048 2048 -s p 1024 1040 515 515 1040 1040 -s p 5 1029 515 515 1029 1029 -s p 1024 1029 515 515 1029 
1029 -s p 1024 1040 2050 2050 1040 1040 -s p 1029 1029 2050 2050 1029 1029 -s R 480 646 2050 2050 646 646 -s R 481 646 2050 2050 646 646 -s R 482 646 2050 2050 646 646 -s R 483 646 2050 2050 646 646 -s R 484 646 2050 2050 646 646 -s R 485 646 2050 2050 646 646 -s R 481 656 2050 2050 656 656 -s R 482 656 2050 2050 656 656 -s R 483 656 2050 2050 656 656 -s R 484 656 2050 2050 656 656 -s p 485 656 2050 2050 656 656 -s p 480 672 2050 2050 672 672 -s p 481 672 2050 2050 672 672 -s p 482 672 2050 2050 672 672 -s p 483 672 2050 2050 672 672 -s p 484 672 2050 2050 672 672 -s p 485 672 2050 2050 672 672 -s p 480 688 2050 2050 688 688 -s p 481 688 2050 2050 688 688 -s r 482 688 2050 2050 688 688 -s r 483 688 2050 2050 688 688 -s r 484 688 2050 2050 688 688 -s r 485 688 2050 2050 688 688 -s r 1024 512 64 64 512 512 -s r 16 256 512 512 256 256 -s r 480 640 512 512 640 640 -s r 64 768 512 512 768 768 -s r 128 128 128 128 128 128 -s r 1024 64 512 512 64 64 -s r 1024 256 32 32 256 256 -s r 1024 512 64 64 512 512 -s r 480 640 512 512 640 640 -s p 1024 32 256 256 32 32 -s P 1024 64 512 512 64 64 -s P 64 800 320 320 800 800 -s P 64 768 512 512 768 768 -s P 16 256 512 512 256 256 -s P 128 128 128 128 128 128 -s P 256 512 256 256 512 512 -s P 1024 1024 1024 1024 1024 1024 -s P 480 640 1024 1024 640 640 -s P 480 640 256 256 640 640 -s P 8 64 32 32 64 64 -s P 9 64 32 32 64 64 -s P 10 128 64 64 128 128 -s P 8 8 8 8 8 8 -s P 12 12 12 12 12 12 -s P 25 25 25 25 25 25 -s P 25 25 20 20 25 25 -i p 480 20 2050 2050 20 20 -i p 481 20 2050 2050 20 20 -i p 482 20 2050 2050 20 20 -i p 483 20 2050 2050 20 20 -i R 484 20 2050 2050 20 20 -i R 485 20 2050 2050 20 20 -i R 480 39 2050 2050 39 39 -i R 481 39 2050 2050 39 39 -i R 482 39 2050 2050 39 39 -i R 483 39 2050 2050 39 39 -i R 484 39 2050 2050 39 39 -i p 485 39 2050 2050 39 39 -i p 480 50 2050 2050 50 50 -i p 481 50 2050 2050 50 50 -i p 482 50 2050 2050 50 50 -i p 483 50 2050 2050 50 50 -i p 484 50 2050 2050 50 50 -i p 485 50 2050 2050 50 50 -i R 480 1108 2050 2050 1108 1108 -i R 481 1108 2050 2050 1108 1108 -i R 482 1108 2050 2050 1108 1108 -i R 483 1108 2050 2050 1108 1108 -i R 484 1108 2050 2050 1108 1108 -i R 485 1108 2050 2050 1108 1108 -i R 480 1127 2050 2050 1127 1127 -i R 481 1127 2050 2050 1127 1127 -i R 482 1127 2050 2050 1127 1127 -i R 483 1127 2050 2050 1127 1127 -i p 484 1127 2050 2050 1127 1127 -i p 485 1127 2050 2050 1127 1127 -i p 480 1138 2050 2050 1138 1138 -i p 481 1138 2050 2050 1138 1138 -i p 482 1138 2050 2050 1138 1138 -i p 483 1138 2050 2050 1138 1138 -i p 484 1138 2050 2050 1138 1138 -i p 485 1138 2050 2050 1138 1138 -i p 1 1 3 3 1 1 -i p 1 9 3 3 9 9 -i p 1 2048 3 3 2048 2048 -i p 1 2048 5192 5192 2048 2048 -i p 9 1 3 3 1 1 -i p 576 1 3500 3500 1 1 -i p 1 1 1 1 1 1 -i p 102 1088 1024 1024 1088 1088 -i p 102 2048 1024 1024 2048 2048 -i p 485 656 1024 1024 656 656 -i p 483 656 1024 1024 656 656 -i p 81 128 3 3 128 128 -i p 1022 512 515 515 512 512 -i p 74 512 515 515 512 512 -i p 253 2048 515 515 2048 2048 -i p 8192 1040 515 515 1040 1040 -i p 10 1029 515 515 1029 1029 -i p 24 1040 2050 2050 1040 1040 -i p 1024 1029 2050 2050 1029 1029 -i p 480 660 2050 2050 660 660 -i p 481 660 2050 2050 660 660 -i p 482 660 2050 2050 660 660 -i p 483 660 2050 2050 660 660 -i p 484 660 2050 2050 660 660 -i p 485 660 2050 2050 660 660 -i p 480 679 2050 2050 679 679 -i p 481 679 2050 2050 679 679 -i p 482 679 2050 2050 679 679 -i p 483 679 2050 2050 679 679 -i p 484 679 2050 2050 679 679 -i p 485 679 2050 2050 679 679 -i p 480 690 2050 2050 690 690 -i p 481 690 2050 
2050 690 690 -i p 482 690 2050 2050 690 690 -i p 483 690 2050 2050 690 690 -i p 484 690 2050 2050 690 690 -i p 485 690 2050 2050 690 690 -i p 480 660 2048 2048 660 660 -i p 481 660 2048 2048 660 660 -i p 482 660 2048 2048 660 660 -i p 483 660 2048 2048 660 660 -i p 484 660 2048 2048 660 660 -i p 485 660 2048 2048 660 660 -i p 480 679 2048 2048 679 679 -i p 481 679 2048 2048 679 679 -i p 482 679 2048 2048 679 679 -i p 483 679 2048 2048 679 679 -i p 484 679 2048 2048 679 679 -i p 485 679 2048 2048 679 679 -i p 480 690 2048 2048 690 690 -i p 481 690 2048 2048 690 690 -i p 482 690 2048 2048 690 690 -i p 483 690 2048 2048 690 690 -i p 484 690 2048 2048 690 690 -i p 485 690 2048 2048 690 690 -i p 480 656 1024 1024 656 656 -i p 480 128 3 3 128 128 -i p 1024 512 515 515 512 512 -i p 1024 2048 1024 1024 2048 2048 -i p 1024 2048 515 515 2048 2048 -i p 1024 1040 515 515 1040 1040 -i p 5 1029 515 515 1029 1029 -i p 1024 1029 515 515 1029 1029 -i p 1024 1040 2050 2050 1040 1040 -i p 1029 1029 2050 2050 1029 1029 -i R 480 646 2050 2050 646 646 -i R 481 646 2050 2050 646 646 -i R 482 646 2050 2050 646 646 -i R 483 646 2050 2050 646 646 -i R 484 646 2050 2050 646 646 -i R 485 646 2050 2050 646 646 -i R 481 656 2050 2050 656 656 -i R 482 656 2050 2050 656 656 -i R 483 656 2050 2050 656 656 -i R 484 656 2050 2050 656 656 -i p 485 656 2050 2050 656 656 -i p 480 672 2050 2050 672 672 -i p 481 672 2050 2050 672 672 -i p 482 672 2050 2050 672 672 -i p 483 672 2050 2050 672 672 -i p 484 672 2050 2050 672 672 -i p 485 672 2050 2050 672 672 -i p 480 688 2050 2050 688 688 -i p 481 688 2050 2050 688 688 -i r 482 688 2050 2050 688 688 -i r 483 688 2050 2050 688 688 -i r 484 688 2050 2050 688 688 -i r 485 688 2050 2050 688 688 -i r 1024 512 64 64 512 512 -i r 16 256 512 512 256 256 -i r 480 640 512 512 640 640 -i r 64 768 512 512 768 768 -i r 128 128 128 128 128 128 -i r 1024 64 512 512 64 64 -i r 1024 256 32 32 256 256 -i r 1024 512 64 64 512 512 -i r 480 640 512 512 640 640 -i p 1024 32 256 256 32 32 -i P 1024 64 512 512 64 64 -i P 64 800 320 320 800 800 -i P 64 768 512 512 768 768 -i P 16 256 512 512 256 256 -i P 128 128 128 128 128 128 -i P 256 512 256 256 512 512 -i P 1024 1024 1024 1024 1024 1024 -i P 480 640 1024 1024 640 640 -i P 480 640 256 256 640 640 -i P 8 64 32 32 64 64 -i P 9 64 32 32 64 64 -i P 10 128 64 64 128 128 -i P 8 8 8 8 8 8 -i P 12 12 12 12 12 12 -i P 25 25 25 25 25 25 -i P 25 25 20 20 25 25 -f p 480 20 2050 2050 20 20 -f p 481 20 2050 2050 20 20 -f p 482 20 2050 2050 20 20 -f p 483 20 2050 2050 20 20 -f R 484 20 2050 2050 20 20 -f R 485 20 2050 2050 20 20 -f R 480 39 2050 2050 39 39 -f R 481 39 2050 2050 39 39 -f R 482 39 2050 2050 39 39 -f R 483 39 2050 2050 39 39 -f R 484 39 2050 2050 39 39 -f p 485 39 2050 2050 39 39 -f p 480 50 2050 2050 50 50 -f p 481 50 2050 2050 50 50 -f p 482 50 2050 2050 50 50 -f p 483 50 2050 2050 50 50 -f p 484 50 2050 2050 50 50 -f p 485 50 2050 2050 50 50 -f R 480 1108 2050 2050 1108 1108 -f R 481 1108 2050 2050 1108 1108 -f R 482 1108 2050 2050 1108 1108 -f R 483 1108 2050 2050 1108 1108 -f R 484 1108 2050 2050 1108 1108 -f R 485 1108 2050 2050 1108 1108 -f R 480 1127 2050 2050 1127 1127 -f R 481 1127 2050 2050 1127 1127 -f R 482 1127 2050 2050 1127 1127 -f R 483 1127 2050 2050 1127 1127 -f p 484 1127 2050 2050 1127 1127 -f p 485 1127 2050 2050 1127 1127 -f p 480 1138 2050 2050 1138 1138 -f p 481 1138 2050 2050 1138 1138 -f p 482 1138 2050 2050 1138 1138 -f p 483 1138 2050 2050 1138 1138 -f p 484 1138 2050 2050 1138 1138 -f p 485 1138 2050 2050 1138 1138 -f p 1 1 
3 3 1 1 -f p 1 9 3 3 9 9 -f p 1 2048 3 3 2048 2048 -f p 1 2048 5192 5192 2048 2048 -f p 9 1 3 3 1 1 -f p 576 1 3500 3500 1 1 -f p 1 1 1 1 1 1 -f p 102 1088 1024 1024 1088 1088 -f p 102 2048 1024 1024 2048 2048 -f p 485 656 1024 1024 656 656 -f p 483 656 1024 1024 656 656 -f p 81 128 3 3 128 128 -f p 1022 512 515 515 512 512 -f p 74 512 515 515 512 512 -f p 253 2048 515 515 2048 2048 -f p 8192 1040 515 515 1040 1040 -f p 10 1029 515 515 1029 1029 -f p 24 1040 2050 2050 1040 1040 -f p 1024 1029 2050 2050 1029 1029 -f p 480 660 2050 2050 660 660 -f p 481 660 2050 2050 660 660 -f p 482 660 2050 2050 660 660 -f p 483 660 2050 2050 660 660 -f p 484 660 2050 2050 660 660 -f p 485 660 2050 2050 660 660 -f p 480 679 2050 2050 679 679 -f p 481 679 2050 2050 679 679 -f p 482 679 2050 2050 679 679 -f p 483 679 2050 2050 679 679 -f p 484 679 2050 2050 679 679 -f p 485 679 2050 2050 679 679 -f p 480 690 2050 2050 690 690 -f p 481 690 2050 2050 690 690 -f p 482 690 2050 2050 690 690 -f p 483 690 2050 2050 690 690 -f p 484 690 2050 2050 690 690 -f p 485 690 2050 2050 690 690 -f p 480 660 2048 2048 660 660 -f p 481 660 2048 2048 660 660 -f p 482 660 2048 2048 660 660 -f p 483 660 2048 2048 660 660 -f p 484 660 2048 2048 660 660 -f p 485 660 2048 2048 660 660 -f p 480 679 2048 2048 679 679 -f p 481 679 2048 2048 679 679 -f p 482 679 2048 2048 679 679 -f p 483 679 2048 2048 679 679 -f p 484 679 2048 2048 679 679 -f p 485 679 2048 2048 679 679 -f p 480 690 2048 2048 690 690 -f p 481 690 2048 2048 690 690 -f p 482 690 2048 2048 690 690 -f p 483 690 2048 2048 690 690 -f p 484 690 2048 2048 690 690 -f p 485 690 2048 2048 690 690 -f p 480 656 1024 1024 656 656 -f p 480 128 3 3 128 128 -f p 1024 512 515 515 512 512 -f p 1024 2048 1024 1024 2048 2048 -f p 1024 2048 515 515 2048 2048 -f p 1024 1040 515 515 1040 1040 -f p 5 1029 515 515 1029 1029 -f p 1024 1029 515 515 1029 1029 -f p 1024 1040 2050 2050 1040 1040 -f p 1029 1029 2050 2050 1029 1029 -f R 480 646 2050 2050 646 646 -f R 481 646 2050 2050 646 646 -f R 482 646 2050 2050 646 646 -f R 483 646 2050 2050 646 646 -f R 484 646 2050 2050 646 646 -f R 485 646 2050 2050 646 646 -f R 481 656 2050 2050 656 656 -f R 482 656 2050 2050 656 656 -f R 483 656 2050 2050 656 656 -f R 484 656 2050 2050 656 656 -f p 485 656 2050 2050 656 656 -f p 480 672 2050 2050 672 672 -f p 481 672 2050 2050 672 672 -f p 482 672 2050 2050 672 672 -f p 483 672 2050 2050 672 672 -f p 484 672 2050 2050 672 672 -f p 485 672 2050 2050 672 672 -f p 480 688 2050 2050 688 688 -f p 481 688 2050 2050 688 688 -f r 482 688 2050 2050 688 688 -f r 483 688 2050 2050 688 688 -f r 484 688 2050 2050 688 688 -f r 485 688 2050 2050 688 688 -f r 1024 512 64 64 512 512 -f r 16 256 512 512 256 256 -f r 480 640 512 512 640 640 -f r 64 768 512 512 768 768 -f r 128 128 128 128 128 128 -f r 1024 64 512 512 64 64 -f r 1024 256 32 32 256 256 -f r 1024 512 64 64 512 512 -f r 480 640 512 512 640 640 -f p 1024 32 256 256 32 32 -f P 1024 64 512 512 64 64 -f P 64 800 320 320 800 800 -f P 64 768 512 512 768 768 -f P 16 256 512 512 256 256 -f P 128 128 128 128 128 128 -f P 256 512 256 256 512 512 -f P 1024 1024 1024 1024 1024 1024 -f P 480 640 1024 1024 640 640 -f P 480 640 256 256 640 640 -f P 8 64 32 32 64 64 -f P 9 64 32 32 64 64 -f P 10 128 64 64 128 128 -f P 8 8 8 8 8 8 -f P 12 12 12 12 12 12 -f P 25 25 25 25 25 25 -f P 25 25 20 20 25 25 -i r 4096 256 5 5 256 256 -i r 3000 256 128 128 256 256 -i r 4096 1024 512 512 1024 1024 -i r 144 256 5 5 256 256 -i r 144 256 128 128 256 256 -i r 144 1024 512 512 1024 1024 -i r 480 
688 256 256 688 688 -i r 480 640 512 512 640 640 -i r 480 640 1024 1024 640 640 -i r 64 800 320 320 800 800 -i r 64 768 512 512 768 768 -i r 16 256 512 512 256 256 -i r 128 128 128 128 128 128 -i r 256 512 256 256 512 512 -i r 1024 1024 1024 1024 1024 1024 -i r 1024 32 256 256 32 32 -i r 1024 64 512 512 64 64 -i r 1024 256 32 32 256 256 -i r 1024 512 64 64 512 512 -i r 512 32 256 256 32 32 -i r 512 768 512 512 768 768 -i r 512 256 32 32 256 256 -i r 512 512 64 64 512 512 -i r 512 256 768 768 256 256 -i r 768 768 1024 1024 768 768 -i r 768 768 768 768 768 768 -i r 2048 2048 2048 2048 2048 2048 -i r 4096 4096 4096 4096 4096 4096 -f r 4096 256 5 5 256 256 -f r 3000 256 128 128 256 256 -f r 4096 1024 512 512 1024 1024 -f r 144 256 5 5 256 256 -f r 144 256 128 128 256 256 -f r 144 1024 512 512 1024 1024 -f r 480 688 256 256 688 688 -f r 480 640 512 512 640 640 -f r 480 640 1024 1024 640 640 -f r 64 800 320 320 800 800 -f r 64 768 512 512 768 768 -f r 16 256 512 512 256 256 -f r 128 128 128 128 128 128 -f r 256 512 256 256 512 512 -f r 1024 1024 1024 1024 1024 1024 -f r 1024 32 256 256 32 32 -f r 1024 64 512 512 64 64 -f r 1024 256 32 32 256 256 -f r 1024 512 64 64 512 512 -f r 512 32 256 256 32 32 -f r 512 768 512 512 768 768 -f r 512 256 32 32 256 256 -f r 512 512 64 64 512 512 -f r 512 256 768 768 256 256 -f r 768 768 1024 1024 768 768 -f r 768 768 768 768 768 768 -f r 2048 2048 2048 2048 2048 2048 -f r 4096 4096 4096 4096 4096 4096 -f r 2048 1024 1024 1024 1024 1024 -f r 2048 4096 1024 1024 4096 4096 -f r 2048 1024 4096 4096 1024 1024 -f r 2048 1024 2 2 1024 1024 -f r 128 1024 1024 1024 1024 1024 -f r 1536 768 768 768 768 768 -f r 1536 3072 768 768 3072 3072 -f r 1536 768 3072 3072 768 768 -f r 1536 768 2 2 768 768 -f r 128 768 768 768 768 768 -f r 1024 8 13 13 8 8 -f r 1024 4 8 8 4 4 -f r 1024 128 355 355 128 128 -f r 1024 64 128 128 64 64 -f r 1024 1 64 64 1 1 -f r 480 1 256 256 1 1 -f r 480 256 512 512 256 256 -f r 480 1024 845 845 1024 1024 -f r 480 512 1024 1024 512 512 -f r 10 17191 128 128 17191 17191 -f r 10 512 256 256 512 512 +s r r 480 20 2050 2050 20 20 +s r r 481 20 2050 2050 20 20 +s r r 482 20 2050 2050 20 20 +s r p 483 20 2050 2050 20 20 +s r R 484 20 2050 2050 20 20 +s r R 485 20 2050 2050 20 20 +s r R 480 39 2050 2050 39 39 +s r R 481 39 2050 2050 39 39 +s r R 482 39 2050 2050 39 39 +s r R 483 39 2050 2050 39 39 +s r R 484 39 2050 2050 39 39 +s r p 485 39 2050 2050 39 39 +s r p 480 50 2050 2050 50 50 +s r p 481 50 2050 2050 50 50 +s r p 482 50 2050 2050 50 50 +s r p 483 50 2050 2050 50 50 +s r p 484 50 2050 2050 50 50 +s r p 485 50 2050 2050 50 50 +s r R 480 1108 2050 2050 1108 1108 +s r R 481 1108 2050 2050 1108 1108 +s r R 482 1108 2050 2050 1108 1108 +s r R 483 1108 2050 2050 1108 1108 +s r R 484 1108 2050 2050 1108 1108 +s r R 485 1108 2050 2050 1108 1108 +s r R 480 1127 2050 2050 1127 1127 +s r R 481 1127 2050 2050 1127 1127 +s r R 482 1127 2050 2050 1127 1127 +s r R 483 1127 2050 2050 1127 1127 +s r p 484 1127 2050 2050 1127 1127 +s r p 485 1127 2050 2050 1127 1127 +s r p 480 1138 2050 2050 1138 1138 +s r p 481 1138 2050 2050 1138 1138 +s r p 482 1138 2050 2050 1138 1138 +s r p 483 1138 2050 2050 1138 1138 +s r p 484 1138 2050 2050 1138 1138 +s r p 485 1138 2050 2050 1138 1138 +s r p 1 1 3 3 1 1 +s r p 1 9 3 3 9 9 +s r p 1 2048 3 3 2048 2048 +s r p 1 2048 5192 5192 2048 2048 +s r p 9 1 3 3 1 1 +s r p 576 1 3500 3500 1 1 +s r p 1 1 1 1 1 1 +s r p 102 1088 1024 1024 1088 1088 +s r p 102 2048 1024 1024 2048 2048 +s r p 485 656 1024 1024 656 656 +s r p 483 656 1024 1024 
656 656 +s r p 81 128 3 3 128 128 +s r p 1022 512 515 515 512 512 +s r p 74 512 515 515 512 512 +s r p 253 2048 515 515 2048 2048 +s r p 8192 1040 515 515 1040 1040 +s r p 10 1029 515 515 1029 1029 +s r p 24 1040 2050 2050 1040 1040 +s r p 1024 1029 2050 2050 1029 1029 +s r p 480 660 2050 2050 660 660 +s r p 481 660 2050 2050 660 660 +s r p 482 660 2050 2050 660 660 +s r p 483 660 2050 2050 660 660 +s r p 484 660 2050 2050 660 660 +s r p 485 660 2050 2050 660 660 +s r p 480 679 2050 2050 679 679 +s r p 481 679 2050 2050 679 679 +s r p 482 679 2050 2050 679 679 +s r p 483 679 2050 2050 679 679 +s r p 484 679 2050 2050 679 679 +s r p 485 679 2050 2050 679 679 +s r p 480 690 2050 2050 690 690 +s r p 481 690 2050 2050 690 690 +s r p 482 690 2050 2050 690 690 +s r p 483 690 2050 2050 690 690 +s r p 484 690 2050 2050 690 690 +s r p 485 690 2050 2050 690 690 +s r p 480 660 2048 2048 660 660 +s r p 481 660 2048 2048 660 660 +s r p 482 660 2048 2048 660 660 +s r p 483 660 2048 2048 660 660 +s r p 484 660 2048 2048 660 660 +s r p 485 660 2048 2048 660 660 +s r p 480 679 2048 2048 679 679 +s r p 481 679 2048 2048 679 679 +s r p 482 679 2048 2048 679 679 +s r p 483 679 2048 2048 679 679 +s r p 484 679 2048 2048 679 679 +s r p 485 679 2048 2048 679 679 +s r p 480 690 2048 2048 690 690 +s r p 481 690 2048 2048 690 690 +s r p 482 690 2048 2048 690 690 +s r p 483 690 2048 2048 690 690 +s r p 484 690 2048 2048 690 690 +s r p 485 690 2048 2048 690 690 +s r p 480 656 1024 1024 656 656 +s r p 480 128 3 3 128 128 +s r p 1024 512 515 515 512 512 +s r p 1024 2048 1024 1024 2048 2048 +s r p 1024 2048 515 515 2048 2048 +s r p 1024 1040 515 515 1040 1040 +s r p 5 1029 515 515 1029 1029 +s r p 1024 1029 515 515 1029 1029 +s r p 1024 1040 2050 2050 1040 1040 +s r p 1029 1029 2050 2050 1029 1029 +s r R 480 646 2050 2050 646 646 +s r R 481 646 2050 2050 646 646 +s r R 482 646 2050 2050 646 646 +s r R 483 646 2050 2050 646 646 +s r R 484 646 2050 2050 646 646 +s r R 485 646 2050 2050 646 646 +s r R 481 656 2050 2050 656 656 +s r R 482 656 2050 2050 656 656 +s r R 483 656 2050 2050 656 656 +s r R 484 656 2050 2050 656 656 +s r p 485 656 2050 2050 656 656 +s r p 480 672 2050 2050 672 672 +s r p 481 672 2050 2050 672 672 +s r p 482 672 2050 2050 672 672 +s r p 483 672 2050 2050 672 672 +s r p 484 672 2050 2050 672 672 +s r p 485 672 2050 2050 672 672 +s r p 480 688 2050 2050 688 688 +s r p 481 688 2050 2050 688 688 +s r r 482 688 2050 2050 688 688 +s r r 483 688 2050 2050 688 688 +s r r 484 688 2050 2050 688 688 +s r r 485 688 2050 2050 688 688 +s r r 1024 512 64 64 512 512 +s r r 16 256 512 512 256 256 +s r r 480 640 512 512 640 640 +s r r 64 768 512 512 768 768 +s r r 128 128 128 128 128 128 +s r r 1024 64 512 512 64 64 +s r r 1024 256 32 32 256 256 +s r r 1024 512 64 64 512 512 +s r r 480 640 512 512 640 640 +s r p 1024 32 256 256 32 32 +s r P 1024 64 512 512 64 64 +s r P 64 800 320 320 800 800 +s r P 64 768 512 512 768 768 +s r P 16 256 512 512 256 256 +s r P 128 128 128 128 128 128 +s r P 256 512 256 256 512 512 +s r P 1024 1024 1024 1024 1024 1024 +s r P 480 640 1024 1024 640 640 +s r P 480 640 256 256 640 640 +s r P 8 64 32 32 64 64 +s r P 9 64 32 32 64 64 +s r P 10 128 64 64 128 128 +s r P 8 8 8 8 8 8 +s r P 12 12 12 12 12 12 +s r P 25 25 25 25 25 25 +s r P 25 25 20 20 25 25 +i r p 480 20 2050 2050 20 20 +i r p 481 20 2050 2050 20 20 +i r p 482 20 2050 2050 20 20 +i r p 483 20 2050 2050 20 20 +i r R 484 20 2050 2050 20 20 +i r R 485 20 2050 2050 20 20 +i r R 480 39 2050 2050 39 39 +i r R 481 39 2050 2050 39 39 +i r 
R 482 39 2050 2050 39 39 +i r R 483 39 2050 2050 39 39 +i r R 484 39 2050 2050 39 39 +i r p 485 39 2050 2050 39 39 +i r p 480 50 2050 2050 50 50 +i r p 481 50 2050 2050 50 50 +i r p 482 50 2050 2050 50 50 +i r p 483 50 2050 2050 50 50 +i r p 484 50 2050 2050 50 50 +i r p 485 50 2050 2050 50 50 +i r R 480 1108 2050 2050 1108 1108 +i r R 481 1108 2050 2050 1108 1108 +i r R 482 1108 2050 2050 1108 1108 +i r R 483 1108 2050 2050 1108 1108 +i r R 484 1108 2050 2050 1108 1108 +i r R 485 1108 2050 2050 1108 1108 +i r R 480 1127 2050 2050 1127 1127 +i r R 481 1127 2050 2050 1127 1127 +i r R 482 1127 2050 2050 1127 1127 +i r R 483 1127 2050 2050 1127 1127 +i r p 484 1127 2050 2050 1127 1127 +i r p 485 1127 2050 2050 1127 1127 +i r p 480 1138 2050 2050 1138 1138 +i r p 481 1138 2050 2050 1138 1138 +i r p 482 1138 2050 2050 1138 1138 +i r p 483 1138 2050 2050 1138 1138 +i r p 484 1138 2050 2050 1138 1138 +i r p 485 1138 2050 2050 1138 1138 +i r p 1 1 3 3 1 1 +i r p 1 9 3 3 9 9 +i r p 1 2048 3 3 2048 2048 +i r p 1 2048 5192 5192 2048 2048 +i r p 9 1 3 3 1 1 +i r p 576 1 3500 3500 1 1 +i r p 1 1 1 1 1 1 +i r p 102 1088 1024 1024 1088 1088 +i r p 102 2048 1024 1024 2048 2048 +i r p 485 656 1024 1024 656 656 +i r p 483 656 1024 1024 656 656 +i r p 81 128 3 3 128 128 +i r p 1022 512 515 515 512 512 +i r p 74 512 515 515 512 512 +i r p 253 2048 515 515 2048 2048 +i r p 8192 1040 515 515 1040 1040 +i r p 10 1029 515 515 1029 1029 +i r p 24 1040 2050 2050 1040 1040 +i r p 1024 1029 2050 2050 1029 1029 +i r p 480 660 2050 2050 660 660 +i r p 481 660 2050 2050 660 660 +i r p 482 660 2050 2050 660 660 +i r p 483 660 2050 2050 660 660 +i r p 484 660 2050 2050 660 660 +i r p 485 660 2050 2050 660 660 +i r p 480 679 2050 2050 679 679 +i r p 481 679 2050 2050 679 679 +i r p 482 679 2050 2050 679 679 +i r p 483 679 2050 2050 679 679 +i r p 484 679 2050 2050 679 679 +i r p 485 679 2050 2050 679 679 +i r p 480 690 2050 2050 690 690 +i r p 481 690 2050 2050 690 690 +i r p 482 690 2050 2050 690 690 +i r p 483 690 2050 2050 690 690 +i r p 484 690 2050 2050 690 690 +i r p 485 690 2050 2050 690 690 +i r p 480 660 2048 2048 660 660 +i r p 481 660 2048 2048 660 660 +i r p 482 660 2048 2048 660 660 +i r p 483 660 2048 2048 660 660 +i r p 484 660 2048 2048 660 660 +i r p 485 660 2048 2048 660 660 +i r p 480 679 2048 2048 679 679 +i r p 481 679 2048 2048 679 679 +i r p 482 679 2048 2048 679 679 +i r p 483 679 2048 2048 679 679 +i r p 484 679 2048 2048 679 679 +i r p 485 679 2048 2048 679 679 +i r p 480 690 2048 2048 690 690 +i r p 481 690 2048 2048 690 690 +i r p 482 690 2048 2048 690 690 +i r p 483 690 2048 2048 690 690 +i r p 484 690 2048 2048 690 690 +i r p 485 690 2048 2048 690 690 +i r p 480 656 1024 1024 656 656 +i r p 480 128 3 3 128 128 +i r p 1024 512 515 515 512 512 +i r p 1024 2048 1024 1024 2048 2048 +i r p 1024 2048 515 515 2048 2048 +i r p 1024 1040 515 515 1040 1040 +i r p 5 1029 515 515 1029 1029 +i r p 1024 1029 515 515 1029 1029 +i r p 1024 1040 2050 2050 1040 1040 +i r p 1029 1029 2050 2050 1029 1029 +i r R 480 646 2050 2050 646 646 +i r R 481 646 2050 2050 646 646 +i r R 482 646 2050 2050 646 646 +i r R 483 646 2050 2050 646 646 +i r R 484 646 2050 2050 646 646 +i r R 485 646 2050 2050 646 646 +i r R 481 656 2050 2050 656 656 +i r R 482 656 2050 2050 656 656 +i r R 483 656 2050 2050 656 656 +i r R 484 656 2050 2050 656 656 +i r p 485 656 2050 2050 656 656 +i r p 480 672 2050 2050 672 672 +i r p 481 672 2050 2050 672 672 +i r p 482 672 2050 2050 672 672 +i r p 483 672 2050 2050 672 672 +i r p 484 672 2050 
2050 672 672 +i r p 485 672 2050 2050 672 672 +i r p 480 688 2050 2050 688 688 +i r p 481 688 2050 2050 688 688 +i r r 482 688 2050 2050 688 688 +i r r 483 688 2050 2050 688 688 +i r r 484 688 2050 2050 688 688 +i r r 485 688 2050 2050 688 688 +i r r 1024 512 64 64 512 512 +i r r 16 256 512 512 256 256 +i r r 480 640 512 512 640 640 +i r r 64 768 512 512 768 768 +i r r 128 128 128 128 128 128 +i r r 1024 64 512 512 64 64 +i r r 1024 256 32 32 256 256 +i r r 1024 512 64 64 512 512 +i r r 480 640 512 512 640 640 +i r p 1024 32 256 256 32 32 +i r P 1024 64 512 512 64 64 +i r P 64 800 320 320 800 800 +i r P 64 768 512 512 768 768 +i r P 16 256 512 512 256 256 +i r P 128 128 128 128 128 128 +i r P 256 512 256 256 512 512 +i r P 1024 1024 1024 1024 1024 1024 +i r P 480 640 1024 1024 640 640 +i r P 480 640 256 256 640 640 +i r P 8 64 32 32 64 64 +i r P 9 64 32 32 64 64 +i r P 10 128 64 64 128 128 +i r P 8 8 8 8 8 8 +i r P 12 12 12 12 12 12 +i r P 25 25 25 25 25 25 +i r P 25 25 20 20 25 25 +f r p 480 20 2050 2050 20 20 +f r p 481 20 2050 2050 20 20 +f r p 482 20 2050 2050 20 20 +f r p 483 20 2050 2050 20 20 +f r R 484 20 2050 2050 20 20 +f r R 485 20 2050 2050 20 20 +f r R 480 39 2050 2050 39 39 +f r R 481 39 2050 2050 39 39 +f r R 482 39 2050 2050 39 39 +f r R 483 39 2050 2050 39 39 +f r R 484 39 2050 2050 39 39 +f r p 485 39 2050 2050 39 39 +f r p 480 50 2050 2050 50 50 +f r p 481 50 2050 2050 50 50 +f r p 482 50 2050 2050 50 50 +f r p 483 50 2050 2050 50 50 +f r p 484 50 2050 2050 50 50 +f r p 485 50 2050 2050 50 50 +f r R 480 1108 2050 2050 1108 1108 +f r R 481 1108 2050 2050 1108 1108 +f r R 482 1108 2050 2050 1108 1108 +f r R 483 1108 2050 2050 1108 1108 +f r R 484 1108 2050 2050 1108 1108 +f r R 485 1108 2050 2050 1108 1108 +f r R 480 1127 2050 2050 1127 1127 +f r R 481 1127 2050 2050 1127 1127 +f r R 482 1127 2050 2050 1127 1127 +f r R 483 1127 2050 2050 1127 1127 +f r p 484 1127 2050 2050 1127 1127 +f r p 485 1127 2050 2050 1127 1127 +f r p 480 1138 2050 2050 1138 1138 +f r p 481 1138 2050 2050 1138 1138 +f r p 482 1138 2050 2050 1138 1138 +f r p 483 1138 2050 2050 1138 1138 +f r p 484 1138 2050 2050 1138 1138 +f r p 485 1138 2050 2050 1138 1138 +f r p 1 1 3 3 1 1 +f r p 1 9 3 3 9 9 +f r p 1 2048 3 3 2048 2048 +f r p 1 2048 5192 5192 2048 2048 +f r p 9 1 3 3 1 1 +f r p 576 1 3500 3500 1 1 +f r p 1 1 1 1 1 1 +f r p 102 1088 1024 1024 1088 1088 +f r p 102 2048 1024 1024 2048 2048 +f r p 485 656 1024 1024 656 656 +f r p 483 656 1024 1024 656 656 +f r p 81 128 3 3 128 128 +f r p 1022 512 515 515 512 512 +f r p 74 512 515 515 512 512 +f r p 253 2048 515 515 2048 2048 +f r p 8192 1040 515 515 1040 1040 +f r p 10 1029 515 515 1029 1029 +f r p 24 1040 2050 2050 1040 1040 +f r p 1024 1029 2050 2050 1029 1029 +f r p 480 660 2050 2050 660 660 +f r p 481 660 2050 2050 660 660 +f r p 482 660 2050 2050 660 660 +f r p 483 660 2050 2050 660 660 +f r p 484 660 2050 2050 660 660 +f r p 485 660 2050 2050 660 660 +f r p 480 679 2050 2050 679 679 +f r p 481 679 2050 2050 679 679 +f r p 482 679 2050 2050 679 679 +f r p 483 679 2050 2050 679 679 +f r p 484 679 2050 2050 679 679 +f r p 485 679 2050 2050 679 679 +f r p 480 690 2050 2050 690 690 +f r p 481 690 2050 2050 690 690 +f r p 482 690 2050 2050 690 690 +f r p 483 690 2050 2050 690 690 +f r p 484 690 2050 2050 690 690 +f r p 485 690 2050 2050 690 690 +f r p 480 660 2048 2048 660 660 +f r p 481 660 2048 2048 660 660 +f r p 482 660 2048 2048 660 660 +f r p 483 660 2048 2048 660 660 +f r p 484 660 2048 2048 660 660 +f r p 485 660 2048 2048 660 660 +f r p 480 
679 2048 2048 679 679 +f r p 481 679 2048 2048 679 679 +f r p 482 679 2048 2048 679 679 +f r p 483 679 2048 2048 679 679 +f r p 484 679 2048 2048 679 679 +f r p 485 679 2048 2048 679 679 +f r p 480 690 2048 2048 690 690 +f r p 481 690 2048 2048 690 690 +f r p 482 690 2048 2048 690 690 +f r p 483 690 2048 2048 690 690 +f r p 484 690 2048 2048 690 690 +f r p 485 690 2048 2048 690 690 +f r p 480 656 1024 1024 656 656 +f r p 480 128 3 3 128 128 +f r p 1024 512 515 515 512 512 +f r p 1024 2048 1024 1024 2048 2048 +f r p 1024 2048 515 515 2048 2048 +f r p 1024 1040 515 515 1040 1040 +f r p 5 1029 515 515 1029 1029 +f r p 1024 1029 515 515 1029 1029 +f r p 1024 1040 2050 2050 1040 1040 +f r p 1029 1029 2050 2050 1029 1029 +f r R 480 646 2050 2050 646 646 +f r R 481 646 2050 2050 646 646 +f r R 482 646 2050 2050 646 646 +f r R 483 646 2050 2050 646 646 +f r R 484 646 2050 2050 646 646 +f r R 485 646 2050 2050 646 646 +f r R 481 656 2050 2050 656 656 +f r R 482 656 2050 2050 656 656 +f r R 483 656 2050 2050 656 656 +f r R 484 656 2050 2050 656 656 +f r p 485 656 2050 2050 656 656 +f r p 480 672 2050 2050 672 672 +f r p 481 672 2050 2050 672 672 +f r p 482 672 2050 2050 672 672 +f r p 483 672 2050 2050 672 672 +f r p 484 672 2050 2050 672 672 +f r p 485 672 2050 2050 672 672 +f r p 480 688 2050 2050 688 688 +f r p 481 688 2050 2050 688 688 +f r r 482 688 2050 2050 688 688 +f r r 483 688 2050 2050 688 688 +f r r 484 688 2050 2050 688 688 +f r r 485 688 2050 2050 688 688 +f r r 1024 512 64 64 512 512 +f r r 16 256 512 512 256 256 +f r r 480 640 512 512 640 640 +f r r 64 768 512 512 768 768 +f r r 128 128 128 128 128 128 +f r r 1024 64 512 512 64 64 +f r r 1024 256 32 32 256 256 +f r r 1024 512 64 64 512 512 +f r r 480 640 512 512 640 640 +f r p 1024 32 256 256 32 32 +f r P 1024 64 512 512 64 64 +f r P 64 800 320 320 800 800 +f r P 64 768 512 512 768 768 +f r P 16 256 512 512 256 256 +f r P 128 128 128 128 128 128 +f r P 256 512 256 256 512 512 +f r P 1024 1024 1024 1024 1024 1024 +f r P 480 640 1024 1024 640 640 +f r P 480 640 256 256 640 640 +f r P 8 64 32 32 64 64 +f r P 9 64 32 32 64 64 +f r P 10 128 64 64 128 128 +f r P 8 8 8 8 8 8 +f r P 12 12 12 12 12 12 +f r P 25 25 25 25 25 25 +f r P 25 25 20 20 25 25 +i r r 4096 256 5 5 256 256 +i r r 3000 256 128 128 256 256 +i r r 4096 1024 512 512 1024 1024 +i r r 144 256 5 5 256 256 +i r r 144 256 128 128 256 256 +i r r 144 1024 512 512 1024 1024 +i r r 480 688 256 256 688 688 +i r r 480 640 512 512 640 640 +i r r 480 640 1024 1024 640 640 +i r r 64 800 320 320 800 800 +i r r 64 768 512 512 768 768 +i r r 16 256 512 512 256 256 +i r r 128 128 128 128 128 128 +i r r 256 512 256 256 512 512 +i r r 1024 1024 1024 1024 1024 1024 +i r r 1024 32 256 256 32 32 +i r r 1024 64 512 512 64 64 +i r r 1024 256 32 32 256 256 +i r r 1024 512 64 64 512 512 +i r r 512 32 256 256 32 32 +i r r 512 768 512 512 768 768 +i r r 512 256 32 32 256 256 +i r r 512 512 64 64 512 512 +i r r 512 256 768 768 256 256 +i r r 768 768 1024 1024 768 768 +i r r 768 768 768 768 768 768 +i r r 2048 2048 2048 2048 2048 2048 +i r r 4096 4096 4096 4096 4096 4096 +f r r 4096 256 5 5 256 256 +f r r 3000 256 128 128 256 256 +f r r 4096 1024 512 512 1024 1024 +f r r 144 256 5 5 256 256 +f r r 144 256 128 128 256 256 +f r r 144 1024 512 512 1024 1024 +f r r 480 688 256 256 688 688 +f r r 480 640 512 512 640 640 +f r r 480 640 1024 1024 640 640 +f r r 64 800 320 320 800 800 +f r r 64 768 512 512 768 768 +f r r 16 256 512 512 256 256 +f r r 128 128 128 128 128 128 +f r r 256 512 256 256 512 512 +f r r 
1024 1024 1024 1024 1024 1024 +f r r 1024 32 256 256 32 32 +f r r 1024 64 512 512 64 64 +f r r 1024 256 32 32 256 256 +f r r 1024 512 64 64 512 512 +f r r 512 32 256 256 32 32 +f r r 512 768 512 512 768 768 +f r r 512 256 32 32 256 256 +f r r 512 512 64 64 512 512 +f r r 512 256 768 768 256 256 +f r r 768 768 1024 1024 768 768 +f r r 768 768 768 768 768 768 +f r r 2048 2048 2048 2048 2048 2048 +f r r 4096 4096 4096 4096 4096 4096 +f r r 2048 1024 1024 1024 1024 1024 +f r r 2048 4096 1024 1024 4096 4096 +f r r 2048 1024 4096 4096 1024 1024 +f r r 2048 1024 2 2 1024 1024 +f r r 128 1024 1024 1024 1024 1024 +f r r 1536 768 768 768 768 768 +f r r 1536 3072 768 768 3072 3072 +f r r 1536 768 3072 3072 768 768 +f r r 1536 768 2 2 768 768 +f r r 128 768 768 768 768 768 +f r r 1024 8 13 13 8 8 +f r r 1024 4 8 8 4 4 +f r r 1024 128 355 355 128 128 +f r r 1024 64 128 128 64 64 +f r r 1024 1 64 64 1 1 +f r r 480 1 256 256 1 1 +f r r 480 256 512 512 256 256 +f r r 480 1024 845 845 1024 1024 +f r r 480 512 1024 1024 512 512 +f r r 10 17191 128 128 17191 17191 +f r r 10 512 256 256 512 512 diff --git a/bench/bench_aocl_gemm/bench_lpgemm.c b/bench/bench_aocl_gemm/bench_lpgemm.c index 3484dffb3d..0262a25f36 100644 --- a/bench/bench_aocl_gemm/bench_lpgemm.c +++ b/bench/bench_aocl_gemm/bench_lpgemm.c @@ -55,6 +55,7 @@ GEN_FILL_ARRAY_POST_OPS_FUNC(float) #define GEN_BLIS_MAT_MUL_FUNC(A_type,B_type,C_type,BLAS_SFX) \ void mat_mul_ ## BLAS_SFX \ ( \ + char stor_order, \ char op_t, \ dim_t m, \ dim_t n, \ @@ -70,7 +71,7 @@ void mat_mul_ ## BLAS_SFX \ aocl_post_op* post_op\ ) \ { \ - char storage = 'r'; \ + char storage = stor_order; \ char transa = 'n'; \ char transb = 'n'; \ char reordera = 'n'; \ @@ -191,6 +192,7 @@ void print_result #define GEN_MAT_MUL_BENCH_DRV_FUNC(A_type,B_type,C_type,BLAS_SFX) \ void mat_mul_bench_driver_ ## BLAS_SFX \ ( \ + char stor_order, \ char op_t, \ int32_t n_repeats, \ dim_t m, \ @@ -220,7 +222,7 @@ void mat_mul_bench_driver_ ## BLAS_SFX \ \ GEN_FUNC_NAME(mat_mul_,BLAS_SFX) \ ( \ - op_t, m, n, k, \ + stor_order, op_t, m, n, k, \ alpha, \ a, lda, \ b, ldb, \ @@ -250,6 +252,7 @@ GEN_MAT_MUL_BENCH_DRV_FUNC(float,float,float,f32f32f32of32) void mat_mul_accuracy_check_driver_ ## BLAS_SFX \ ( \ FILE* fout, \ + const char stor_order, \ dim_t m, \ dim_t n, \ dim_t k, \ @@ -266,6 +269,27 @@ void mat_mul_accuracy_check_driver_ ## BLAS_SFX \ aocl_post_op* post_op\ ) \ { \ + dim_t rs_a = lda; \ + dim_t cs_a = 1; \ + dim_t rs_b = ldb; \ + dim_t cs_b = 1; \ + dim_t rs_c = ldc; \ + dim_t cs_c = 1; \ + dim_t rs_c_ref = ldc_ref; \ + dim_t cs_c_ref = 1; \ + \ + if ( ( stor_order == 'C' ) || ( stor_order == 'c' ) ) \ + { \ + rs_a = 1; \ + cs_a = lda; \ + rs_b = 1; \ + cs_b = ldb; \ + rs_c = 1; \ + cs_c = ldc; \ + rs_c_ref = 1; \ + cs_c_ref = ldc_ref; \ + } \ + \ for ( dim_t i = 0; i < m; ++i ) \ { \ for ( dim_t j = 0; j < n; ++j ) \ @@ -275,11 +299,12 @@ void mat_mul_accuracy_check_driver_ ## BLAS_SFX \ \ for ( dim_t p = 0; p < k; ++p) \ { \ - temp_accum += ( *( a + ( i * lda ) + p ) * *( b + ( p * ldb ) + j ) ); \ + temp_accum += ( *( a + ( i * rs_a ) + ( cs_a * p ) ) * \ + *( b + ( rs_b * p ) + ( cs_b * j ) ) ); \ } \ \ - temp_accum = ( beta * ( * (c_ref + ( ldc_ref * i ) + j ) ) ) \ - + ( alpha * temp_accum ); \ + temp_accum = ( beta * ( * (c_ref + ( rs_c_ref * i ) + \ + ( cs_c_ref * j ) ) ) ) + ( alpha * temp_accum ); \ \ if ( post_op != NULL ) \ { \ @@ -337,7 +362,7 @@ void mat_mul_accuracy_check_driver_ ## BLAS_SFX \ out_temp_accum = ( C_type )temp_accum; \ } \ \ - if ( *( c + ( ldc * i ) + j 
) != out_temp_accum ) \ + if ( *( c + ( rs_c * i ) + ( cs_c * j ) ) != out_temp_accum ) \ { \ if ( fout ) \ { \ @@ -515,6 +540,7 @@ void mat_mul_bench_main_ ## BLAS_SFX \ ( \ FILE* fin, \ FILE* fout, \ + char stor_order, \ char op_t, \ int32_t m, \ int32_t n, \ @@ -580,7 +606,7 @@ void mat_mul_bench_main_ ## BLAS_SFX \ /* No reordering of B.*/ \ GEN_FUNC_NAME(mat_mul_bench_driver_,BLAS_SFX) \ ( \ - op_t, n_repeats, m, n, k, \ + stor_order, op_t, n_repeats, m, n, k, \ alpha, \ a, stride_a, \ b, stride_b, \ @@ -600,7 +626,7 @@ void mat_mul_bench_main_ ## BLAS_SFX \ \ GEN_FUNC_NAME(mat_mul_bench_driver_,BLAS_SFX) \ ( \ - op_t, n_repeats, m, n, k, \ + stor_order, op_t, n_repeats, m, n, k, \ alpha, \ a, stride_a, \ b_reorder, stride_b, \ @@ -617,7 +643,7 @@ void mat_mul_bench_main_ ## BLAS_SFX \ printf("Running accuracy check.\n"); \ GEN_FUNC_NAME(mat_mul_accuracy_check_driver_,BLAS_SFX) \ ( \ - fout, m, n, k, \ + fout, stor_order, m, n, k, \ alpha, \ a, stride_a, \ b, stride_b, \ @@ -734,6 +760,7 @@ int main( int argc, char** argv ) char op_type_char; char op_t; + char stor_order; int32_t m, n, k; int32_t stride_a, stride_b, stride_c; @@ -771,17 +798,22 @@ int main( int argc, char** argv ) } } - while ( fscanf( fin, "%c %c %d %d %d %d %d %d\n", - &op_type_char, &op_t, &m, &n, &k, - &stride_a, &stride_b, &stride_c ) == 8 ) + // Input format: data_type stor_type pack/reorder m n k lda ldb ldc + while ( fscanf( fin, "%c %c %c %d %d %d %d %d %d\n", + &op_type_char, &stor_order, &op_t, &m, &n, &k, + &stride_a, &stride_b, &stride_c ) == 9 ) { + stor_order = ( ( stor_order == 'r' ) || ( stor_order == 'R' ) || + ( stor_order == 'c' ) || ( stor_order == 'C' ) ) ? + stor_order : 'r'; + if ( ( op_type_char == 'i' ) || ( op_type_char == 'I' ) ) { if ( global_dscale_out == 'n' ) { GEN_FUNC_NAME(mat_mul_bench_main_,u8s8s32os32) ( - fin, fout, op_t, + fin, fout, stor_order, op_t, m, n, k, stride_a, stride_b, stride_c, post_ops_str_dest ); @@ -790,7 +822,7 @@ int main( int argc, char** argv ) { GEN_FUNC_NAME(mat_mul_bench_main_,u8s8s32os8) ( - fin, fout, op_t, + fin, fout, stor_order, op_t, m, n, k, stride_a, stride_b, stride_c, post_ops_str_dest ); @@ -800,7 +832,7 @@ int main( int argc, char** argv ) { GEN_FUNC_NAME(mat_mul_bench_main_,f32f32f32of32) ( - fin, fout, op_t, + fin, fout, stor_order, op_t, m, n, k, stride_a, stride_b, stride_c, NULL ); @@ -811,7 +843,7 @@ int main( int argc, char** argv ) { GEN_FUNC_NAME(mat_mul_bench_main_,u8s8s16os16) ( - fin, fout, op_t, + fin, fout, stor_order, op_t, m, n, k, stride_a, stride_b, stride_c, post_ops_str_dest ); @@ -820,7 +852,7 @@ int main( int argc, char** argv ) { GEN_FUNC_NAME(mat_mul_bench_main_,u8s8s16os8) ( - fin, fout, op_t, + fin, fout, stor_order, op_t, m, n, k, stride_a, stride_b, stride_c, post_ops_str_dest ); From 0445ee3261d448b5d7e826a6d4b973e4042dca27 Mon Sep 17 00:00:00 2001 From: eashdash Date: Thu, 22 Sep 2022 08:54:59 +0000 Subject: [PATCH 224/243] Accumulation type for alpha, beta values and BF16 bench integration 1. Correcting the type of alpha, and beta values from C_type (output type) to accumulation type. For the downscaled LPGEMM APIs, C_type will be the downscaled type but the required type for alpha and beta values should be the accumulation type. 2. 
BF16 bench integration with the LPGEMM bench for both the BF16 (bf16bf16f32of32 and bf16bf16f32obf16) APIs AMD-Internal: [CPUPL-2561] Change-Id: I3a99336c743f3880be1b96605ceeeae7c3bd4797 --- addon/aocl_gemm/aocl_gemm_bf16bf16f32obf16.c | 2 +- addon/aocl_gemm/aocl_gemm_bf16bf16f32of32.c | 2 +- addon/aocl_gemm/aocl_gemm_f32f32f32of32.c | 2 +- addon/aocl_gemm/aocl_gemm_interface_apis.h | 50 +-- addon/aocl_gemm/aocl_gemm_u8s8s16os16.c | 2 +- addon/aocl_gemm/aocl_gemm_u8s8s16os8.c | 2 +- addon/aocl_gemm/aocl_gemm_u8s8s32os32.c | 2 +- addon/aocl_gemm/aocl_gemm_u8s8s32os8.c | 2 +- bench/bench_aocl_gemm/bench_input.txt | 256 +++++++++++ bench/bench_aocl_gemm/bench_lpgemm.c | 421 +++++++++++++++++-- 10 files changed, 666 insertions(+), 75 deletions(-) diff --git a/addon/aocl_gemm/aocl_gemm_bf16bf16f32obf16.c b/addon/aocl_gemm/aocl_gemm_bf16bf16f32obf16.c index 92171c79c6..fedf3a43c5 100644 --- a/addon/aocl_gemm/aocl_gemm_bf16bf16f32obf16.c +++ b/addon/aocl_gemm/aocl_gemm_bf16bf16f32obf16.c @@ -41,7 +41,7 @@ #include "lpgemm_config.h" #include "lpgemm_utils.h" -AOCL_GEMM_MATMUL(bfloat16,bfloat16,bfloat16,bf16bf16f32obf16) +AOCL_GEMM_MATMUL(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16) { trans_t blis_transa; trans_t blis_transb; diff --git a/addon/aocl_gemm/aocl_gemm_bf16bf16f32of32.c b/addon/aocl_gemm/aocl_gemm_bf16bf16f32of32.c index 731afb129a..8f87f4dff3 100644 --- a/addon/aocl_gemm/aocl_gemm_bf16bf16f32of32.c +++ b/addon/aocl_gemm/aocl_gemm_bf16bf16f32of32.c @@ -41,7 +41,7 @@ #include "lpgemm_config.h" #include "lpgemm_utils.h" -AOCL_GEMM_MATMUL(bfloat16,bfloat16,float,bf16bf16f32of32) +AOCL_GEMM_MATMUL(bfloat16,bfloat16,float,float,bf16bf16f32of32) { trans_t blis_transa; trans_t blis_transb; diff --git a/addon/aocl_gemm/aocl_gemm_f32f32f32of32.c b/addon/aocl_gemm/aocl_gemm_f32f32f32of32.c index 58f2675df3..8366f746cb 100644 --- a/addon/aocl_gemm/aocl_gemm_f32f32f32of32.c +++ b/addon/aocl_gemm/aocl_gemm_f32f32f32of32.c @@ -40,7 +40,7 @@ #include "lpgemm_utils.h" #include "lpgemm_5loop_interface_apis.h" -AOCL_GEMM_MATMUL(float,float,float,f32f32f32of32) +AOCL_GEMM_MATMUL(float,float,float,float,f32f32f32of32) { trans_t blis_transa; trans_t blis_transb; diff --git a/addon/aocl_gemm/aocl_gemm_interface_apis.h b/addon/aocl_gemm/aocl_gemm_interface_apis.h index c2840c3468..40101cbe6a 100644 --- a/addon/aocl_gemm/aocl_gemm_interface_apis.h +++ b/addon/aocl_gemm/aocl_gemm_interface_apis.h @@ -74,34 +74,34 @@ AOCL_GEMM_REORDER(bfloat16,bf16bf16f32of32); // Only supports matrices in row major format. This api can perform gemm with // both normal as well as reordered B matrix as opposesd to sgemm (only // supports former). This api can be considered analogous to packed sgemm api. 
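/* Editorial sketch, not part of the original patch: the Sum_type parameter
 * added below means alpha and beta are passed in the accumulation type rather
 * than the output type. For the downscaled u8s8s32os8 API, a hypothetical call
 * (m, n, k, the matrix pointers and leading dimensions are placeholders) would
 * therefore pass int32_t alpha/beta even though C is stored as int8_t:
 *
 *   int32_t alpha = 2, beta = 9;   // accumulation type, not the int8_t output type
 *   aocl_gemm_u8s8s32os8( 'r', 'n', 'n', m, n, k,
 *                         alpha, a, lda, 'n',
 *                         b, ldb, 'n',
 *                         beta, c, ldc, NULL );
 */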
-#define AOCL_GEMM_MATMUL(A_type,B_type,C_type,LP_SFX) \ +#define AOCL_GEMM_MATMUL(A_type,B_type,C_type,Sum_type,LP_SFX) \ BLIS_EXPORT_ADDON void aocl_gemm_ ## LP_SFX \ ( \ - const char order, \ - const char transa, \ - const char transb, \ - const dim_t m, \ - const dim_t n, \ - const dim_t k, \ - const C_type alpha, \ - const A_type* a, \ - const dim_t lda, \ - const char mem_format_a, \ - const B_type* b, \ - const dim_t ldb, \ - const char mem_format_b, \ - const C_type beta, \ - C_type* c, \ - const dim_t ldc, \ - aocl_post_op* post_op_unparsed \ + const char order, \ + const char transa, \ + const char transb, \ + const dim_t m, \ + const dim_t n, \ + const dim_t k, \ + const Sum_type alpha, \ + const A_type* a, \ + const dim_t lda, \ + const char mem_format_a, \ + const B_type* b, \ + const dim_t ldb, \ + const char mem_format_b, \ + const Sum_type beta, \ + C_type* c, \ + const dim_t ldc, \ + aocl_post_op* post_op_unparsed \ ) \ -AOCL_GEMM_MATMUL(float,float,float,f32f32f32of32); -AOCL_GEMM_MATMUL(uint8_t,int8_t,int32_t,u8s8s32os32); -AOCL_GEMM_MATMUL(uint8_t,int8_t,int16_t,u8s8s16os16); -AOCL_GEMM_MATMUL(bfloat16,bfloat16,float,bf16bf16f32of32); -AOCL_GEMM_MATMUL(uint8_t,int8_t,int8_t,u8s8s32os8); -AOCL_GEMM_MATMUL(uint8_t,int8_t,int8_t,u8s8s16os8); -AOCL_GEMM_MATMUL(bfloat16,bfloat16,bfloat16,bf16bf16f32obf16); +AOCL_GEMM_MATMUL(float,float,float,float,f32f32f32of32); +AOCL_GEMM_MATMUL(uint8_t,int8_t,int32_t,int32_t,u8s8s32os32); +AOCL_GEMM_MATMUL(uint8_t,int8_t,int16_t,int16_t,u8s8s16os16); +AOCL_GEMM_MATMUL(bfloat16,bfloat16,float,float,bf16bf16f32of32); +AOCL_GEMM_MATMUL(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8); +AOCL_GEMM_MATMUL(uint8_t,int8_t,int8_t,int16_t,u8s8s16os8); +AOCL_GEMM_MATMUL(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16); #endif // AOCL_GEMM_INTERFACE_H diff --git a/addon/aocl_gemm/aocl_gemm_u8s8s16os16.c b/addon/aocl_gemm/aocl_gemm_u8s8s16os16.c index 08084fbd7a..1c6b0899ad 100644 --- a/addon/aocl_gemm/aocl_gemm_u8s8s16os16.c +++ b/addon/aocl_gemm/aocl_gemm_u8s8s16os16.c @@ -41,7 +41,7 @@ #include "lpgemm_thread_decor_openmp.h" #include "lpgemm_post_ops.h" -AOCL_GEMM_MATMUL(uint8_t,int8_t,int16_t,u8s8s16os16) +AOCL_GEMM_MATMUL(uint8_t,int8_t,int16_t,int16_t,u8s8s16os16) { trans_t blis_transa; trans_t blis_transb; diff --git a/addon/aocl_gemm/aocl_gemm_u8s8s16os8.c b/addon/aocl_gemm/aocl_gemm_u8s8s16os8.c index 6df22b14ce..fed10c1e01 100644 --- a/addon/aocl_gemm/aocl_gemm_u8s8s16os8.c +++ b/addon/aocl_gemm/aocl_gemm_u8s8s16os8.c @@ -41,7 +41,7 @@ #include "lpgemm_thread_decor_openmp.h" #include "lpgemm_post_ops.h" -AOCL_GEMM_MATMUL(uint8_t,int8_t,int8_t,u8s8s16os8) +AOCL_GEMM_MATMUL(uint8_t,int8_t,int8_t,int16_t,u8s8s16os8) { trans_t blis_transa; trans_t blis_transb; diff --git a/addon/aocl_gemm/aocl_gemm_u8s8s32os32.c b/addon/aocl_gemm/aocl_gemm_u8s8s32os32.c index 29a0c74936..39fd49bca4 100644 --- a/addon/aocl_gemm/aocl_gemm_u8s8s32os32.c +++ b/addon/aocl_gemm/aocl_gemm_u8s8s32os32.c @@ -41,7 +41,7 @@ #include "lpgemm_config.h" #include "lpgemm_utils.h" -AOCL_GEMM_MATMUL(uint8_t,int8_t,int32_t,u8s8s32os32) +AOCL_GEMM_MATMUL(uint8_t,int8_t,int32_t,int32_t,u8s8s32os32) { trans_t blis_transa; trans_t blis_transb; diff --git a/addon/aocl_gemm/aocl_gemm_u8s8s32os8.c b/addon/aocl_gemm/aocl_gemm_u8s8s32os8.c index da2cb1a587..e4a4ce3f2d 100644 --- a/addon/aocl_gemm/aocl_gemm_u8s8s32os8.c +++ b/addon/aocl_gemm/aocl_gemm_u8s8s32os8.c @@ -41,7 +41,7 @@ #include "lpgemm_config.h" #include "lpgemm_utils.h" -AOCL_GEMM_MATMUL(uint8_t,int8_t,int8_t,u8s8s32os8) 
+AOCL_GEMM_MATMUL(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8) { trans_t blis_transa; trans_t blis_transb; diff --git a/bench/bench_aocl_gemm/bench_input.txt b/bench/bench_aocl_gemm/bench_input.txt index faec5d0b4f..d8b8226a13 100644 --- a/bench/bench_aocl_gemm/bench_input.txt +++ b/bench/bench_aocl_gemm/bench_input.txt @@ -1,3 +1,259 @@ +b r r 480 20 2050 2050 20 20 +b r r 481 20 2050 2050 20 20 +b r r 482 20 2050 2050 20 20 +b r p 483 20 2050 2050 20 20 +b r R 484 20 2050 2050 20 20 +b r R 485 20 2050 2050 20 20 +b r R 480 39 2050 2050 39 39 +b r R 481 39 2050 2050 39 39 +b r R 482 39 2050 2050 39 39 +b r R 483 39 2050 2050 39 39 +b r R 484 39 2050 2050 39 39 +b r p 485 39 2050 2050 39 39 +b r p 480 50 2050 2050 50 50 +b r p 481 50 2050 2050 50 50 +b r p 482 50 2050 2050 50 50 +b r p 483 50 2050 2050 50 50 +b r p 484 50 2050 2050 50 50 +b r p 485 50 2050 2050 50 50 +b r R 480 1108 2050 2050 1108 1108 +b r R 481 1108 2050 2050 1108 1108 +b r R 482 1108 2050 2050 1108 1108 +b r R 483 1108 2050 2050 1108 1108 +b r R 484 1108 2050 2050 1108 1108 +b r R 485 1108 2050 2050 1108 1108 +b r R 480 1127 2050 2050 1127 1127 +b r R 481 1127 2050 2050 1127 1127 +b r R 482 1127 2050 2050 1127 1127 +b r R 483 1127 2050 2050 1127 1127 +b r p 484 1127 2050 2050 1127 1127 +b r p 485 1127 2050 2050 1127 1127 +b r p 480 1138 2050 2050 1138 1138 +b r p 481 1138 2050 2050 1138 1138 +b r p 482 1138 2050 2050 1138 1138 +b r p 483 1138 2050 2050 1138 1138 +b r p 484 1138 2050 2050 1138 1138 +b r p 485 1138 2050 2050 1138 1138 +b r p 1 1 3 3 1 1 +b r p 1 9 3 3 9 9 +b r p 1 2048 3 3 2048 2048 +b r p 1 2048 5192 5192 2048 2048 +b r p 9 1 3 3 1 1 +b r p 576 1 3500 3500 1 1 +b r p 1 1 1 1 1 1 +b r p 102 1088 1024 1024 1088 1088 +b r p 102 2048 1024 1024 2048 2048 +b r p 485 656 1024 1024 656 656 +b r p 483 656 1024 1024 656 656 +b r p 81 128 3 3 128 128 +b r p 1022 512 515 515 512 512 +b r p 74 512 515 515 512 512 +b r p 253 2048 515 515 2048 2048 +b r p 8192 1040 515 515 1040 1040 +b r p 10 1029 515 515 1029 1029 +b r p 24 1040 2050 2050 1040 1040 +b r p 1024 1029 2050 2050 1029 1029 +b r p 480 660 2050 2050 660 660 +b r p 481 660 2050 2050 660 660 +b r p 482 660 2050 2050 660 660 +b r p 483 660 2050 2050 660 660 +b r p 484 660 2050 2050 660 660 +b r p 485 660 2050 2050 660 660 +b r p 480 679 2050 2050 679 679 +b r p 481 679 2050 2050 679 679 +b r p 482 679 2050 2050 679 679 +b r p 483 679 2050 2050 679 679 +b r p 484 679 2050 2050 679 679 +b r p 485 679 2050 2050 679 679 +b r p 480 690 2050 2050 690 690 +b r p 481 690 2050 2050 690 690 +b r p 482 690 2050 2050 690 690 +b r p 483 690 2050 2050 690 690 +b r p 484 690 2050 2050 690 690 +b r p 485 690 2050 2050 690 690 +b r p 480 660 2048 2048 660 660 +b r p 481 660 2048 2048 660 660 +b r p 482 660 2048 2048 660 660 +b r p 483 660 2048 2048 660 660 +b r p 484 660 2048 2048 660 660 +b r p 485 660 2048 2048 660 660 +b r p 480 679 2048 2048 679 679 +b r p 481 679 2048 2048 679 679 +b r p 482 679 2048 2048 679 679 +b r p 483 679 2048 2048 679 679 +b r p 484 679 2048 2048 679 679 +b r p 485 679 2048 2048 679 679 +b r p 480 690 2048 2048 690 690 +b r p 481 690 2048 2048 690 690 +b r p 482 690 2048 2048 690 690 +b r p 483 690 2048 2048 690 690 +b r p 484 690 2048 2048 690 690 +b r p 485 690 2048 2048 690 690 +b r p 480 656 1024 1024 656 656 +b r p 480 128 3 3 128 128 +b r p 1024 512 515 515 512 512 +b r p 1024 2048 1024 1024 2048 2048 +b r p 1024 2048 515 515 2048 2048 +b r p 1024 1040 515 515 1040 1040 +b r p 5 1029 515 515 1029 1029 +b r p 1024 1029 515 515 1029 1029 +b r p 
1024 1040 2050 2050 1040 1040 +b r p 1029 1029 2050 2050 1029 1029 +b r R 480 646 2050 2050 646 646 +b r R 481 646 2050 2050 646 646 +b r R 482 646 2050 2050 646 646 +b r R 483 646 2050 2050 646 646 +b r R 484 646 2050 2050 646 646 +b r R 485 646 2050 2050 646 646 +b r R 481 656 2050 2050 656 656 +b r R 482 656 2050 2050 656 656 +b r R 483 656 2050 2050 656 656 +b r R 484 656 2050 2050 656 656 +b r p 485 656 2050 2050 656 656 +b r p 480 672 2050 2050 672 672 +b r p 481 672 2050 2050 672 672 +b r p 482 672 2050 2050 672 672 +b r p 483 672 2050 2050 672 672 +b r p 484 672 2050 2050 672 672 +b r p 485 672 2050 2050 672 672 +b r p 480 688 2050 2050 688 688 +b r p 481 688 2050 2050 688 688 +b r r 482 688 2050 2050 688 688 +b r r 483 688 2050 2050 688 688 +b r r 484 688 2050 2050 688 688 +b r r 485 688 2050 2050 688 688 +b r r 1024 512 64 64 512 512 +b r r 16 256 512 512 256 256 +b r r 480 640 512 512 640 640 +b r r 64 768 512 512 768 768 +b r r 128 128 128 128 128 128 +b r r 1024 64 512 512 64 64 +b r r 1024 256 32 32 256 256 +b r r 1024 512 64 64 512 512 +b r r 480 640 512 512 640 640 +b r p 1024 32 256 256 32 32 +b r P 1024 64 512 512 64 64 +b r P 64 800 320 320 800 800 +b r P 64 768 512 512 768 768 +b r P 16 256 512 512 256 256 +b r P 128 128 128 128 128 128 +b r P 256 512 256 256 512 512 +b r P 1024 1024 1024 1024 1024 1024 +b r P 480 640 1024 1024 640 640 +b r P 480 640 256 256 640 640 +b r P 8 64 32 32 64 64 +b r P 9 64 32 32 64 64 +b r P 10 128 64 64 128 128 +b r P 8 8 8 8 8 8 +b r P 12 12 12 12 12 12 +b r P 25 25 25 25 25 25 +b r P 25 25 20 20 25 25 +b c p 485 39 2050 485 2050 485 +b c p 480 50 2050 480 2050 480 +b c p 481 50 2050 481 2050 481 +b c p 482 50 2050 482 2050 482 +b c p 483 50 2050 483 2050 483 +b c p 484 50 2050 484 2050 484 +b c p 485 50 2050 485 2050 485 +b c p 484 1127 2050 484 2050 484 +b c p 485 1127 2050 485 2050 485 +b c p 480 1138 2050 480 2050 480 +b c p 481 1138 2050 481 2050 481 +b c p 482 1138 2050 482 2050 482 +b c p 483 1138 2050 483 2050 483 +b c p 484 1138 2050 484 2050 484 +b c p 485 1138 2050 485 2050 485 +b c p 1 1 3 1 3 1 +b c p 1 9 3 1 3 1 +b c p 1 2048 3 1 3 1 +b c p 1 2048 5192 1 5192 1 +b c p 9 1 3 9 3 9 +b c p 576 1 3500 576 3500 576 +b c p 1 1 1 1 1 1 +b c p 102 1088 1024 102 1024 102 +b c p 102 2048 1024 102 1024 102 +b c p 485 656 1024 485 1024 485 +b c p 483 656 1024 483 1024 483 +b c p 81 128 3 81 3 81 +b c p 1022 512 515 1022 515 1022 +b c p 74 512 515 74 515 74 +b c p 253 2048 515 253 515 253 +b c p 8192 1040 515 8192 515 8192 +b c p 10 1029 515 10 515 10 +b c p 24 1040 2050 24 2050 24 +b c p 1024 1029 2050 1024 2050 1024 +b c p 480 660 2050 480 2050 480 +b c p 481 660 2050 481 2050 481 +b c p 482 660 2050 482 2050 482 +b c p 483 660 2050 483 2050 483 +b c p 484 660 2050 484 2050 484 +b c p 485 660 2050 485 2050 485 +b c p 480 679 2050 480 2050 480 +b c p 481 679 2050 481 2050 481 +b c p 482 679 2050 482 2050 482 +b c p 483 679 2050 483 2050 483 +b c p 484 679 2050 484 2050 484 +b c p 485 679 2050 485 2050 485 +b c p 480 690 2050 480 2050 480 +b c p 481 690 2050 481 2050 481 +b c p 482 690 2050 482 2050 482 +b c p 483 690 2050 483 2050 483 +b c p 484 690 2050 484 2050 484 +b c p 485 690 2050 485 2050 485 +b c p 480 660 2048 480 2048 480 +b c p 481 660 2048 481 2048 481 +b c p 482 660 2048 482 2048 482 +b c p 483 660 2048 483 2048 483 +b c p 484 660 2048 484 2048 484 +b c p 485 660 2048 485 2048 485 +b c p 480 679 2048 480 2048 480 +b c p 481 679 2048 481 2048 481 +b c p 482 679 2048 482 2048 482 +b c p 483 679 2048 483 2048 483 +b c p 484 679 
2048 484 2048 484 +b c p 485 679 2048 485 2048 485 +b c p 480 690 2048 480 2048 480 +b c p 481 690 2048 481 2048 481 +b c p 482 690 2048 482 2048 482 +b c p 483 690 2048 483 2048 483 +b c p 484 690 2048 484 2048 484 +b c p 485 690 2048 485 2048 485 +b c p 480 656 1024 480 1024 480 +b c p 480 128 3 480 3 480 +b c p 1024 512 515 1024 515 1024 +b c p 1024 2048 1024 1024 1024 1024 +b c p 1024 2048 515 1024 515 1024 +b c p 1024 1040 515 1024 515 1024 +b c p 5 1029 515 5 515 5 +b c p 1024 1029 515 1024 515 1024 +b c p 1024 1040 2050 1024 2050 1024 +b c p 1029 1029 2050 1029 2050 1029 +b c p 485 656 2050 485 2050 485 +b c p 480 672 2050 480 2050 480 +b c p 481 672 2050 481 2050 481 +b c p 482 672 2050 482 2050 482 +b c p 483 672 2050 483 2050 483 +b c p 484 672 2050 484 2050 484 +b c p 485 672 2050 485 2050 485 +b c p 480 688 2050 480 2050 480 +b c p 481 688 2050 481 2050 481 +b c p 1024 32 256 1024 256 1024 +b c P 1024 64 512 1024 512 1024 +b c P 64 800 320 64 320 64 +b c P 64 768 512 64 512 64 +b c P 16 256 512 16 512 16 +b c P 128 128 128 128 128 128 +b c P 256 512 256 256 256 256 +b c P 1024 1024 1024 1024 1024 1024 +b c P 480 640 1024 480 1024 480 +b c P 480 640 256 480 256 480 +b c P 8 64 32 8 32 8 +b c P 9 64 32 9 32 9 +b c P 10 128 64 10 64 10 +b c P 8 8 8 8 8 8 +b c P 12 12 12 12 12 12 +b c P 25 25 25 25 25 25 +b c P 25 25 20 25 20 25 s r r 480 20 2050 2050 20 20 s r r 481 20 2050 2050 20 20 s r r 482 20 2050 2050 20 20 diff --git a/bench/bench_aocl_gemm/bench_lpgemm.c b/bench/bench_aocl_gemm/bench_lpgemm.c index 0262a25f36..df239fccb2 100644 --- a/bench/bench_aocl_gemm/bench_lpgemm.c +++ b/bench/bench_aocl_gemm/bench_lpgemm.c @@ -38,6 +38,21 @@ GEN_FILL_ARRAY_FUNC(int8_t) GEN_FILL_ARRAY_FUNC(float) GEN_FILL_ARRAY_FUNC(int32_t) +inline void float_to_bf16( float* float_value, bfloat16* bf16_val ) +{ + /*Set offset 2 to copy most significant 2 bytes of float + to convert float values to bf16 values*/ + memcpy( ( bf16_val ), (char *)( float_value ) + 2, sizeof ( bfloat16 ) ); +} + +inline void convert_float_arr_to_bf16( float* array, bfloat16* array_bf16, int size ) +{ + for (int i=0; i< size; i++) + { + float_to_bf16( ( array + i ), ( array_bf16 + i ) ); + } +} + #define GEN_FILL_ARRAY_POST_OPS_FUNC(ctype) \ void fill_array_post_ops_ ## ctype ( void* arr, dim_t size ) \ { \ @@ -52,7 +67,7 @@ GEN_FILL_ARRAY_POST_OPS_FUNC(int16_t) GEN_FILL_ARRAY_POST_OPS_FUNC(int32_t) GEN_FILL_ARRAY_POST_OPS_FUNC(float) -#define GEN_BLIS_MAT_MUL_FUNC(A_type,B_type,C_type,BLAS_SFX) \ +#define GEN_BLIS_MAT_MUL_FUNC(A_type,B_type,C_type,ACCUM_type,BLAS_SFX) \ void mat_mul_ ## BLAS_SFX \ ( \ char stor_order, \ @@ -60,12 +75,12 @@ void mat_mul_ ## BLAS_SFX \ dim_t m, \ dim_t n, \ dim_t k, \ - C_type alpha, \ + ACCUM_type alpha, \ A_type* a, \ dim_t lda, \ B_type* b, \ dim_t ldb, \ - C_type beta, \ + ACCUM_type beta, \ C_type* c, \ dim_t ldc, \ aocl_post_op* post_op\ @@ -153,11 +168,13 @@ void mat_mul_ ## BLAS_SFX \ } */\ } \ -GEN_BLIS_MAT_MUL_FUNC(uint8_t,int8_t,int16_t,u8s8s16os16) -GEN_BLIS_MAT_MUL_FUNC(uint8_t,int8_t,int8_t,u8s8s16os8) -GEN_BLIS_MAT_MUL_FUNC(uint8_t,int8_t,int32_t,u8s8s32os32) -GEN_BLIS_MAT_MUL_FUNC(uint8_t,int8_t,int8_t,u8s8s32os8) -GEN_BLIS_MAT_MUL_FUNC(float,float,float,f32f32f32of32) +GEN_BLIS_MAT_MUL_FUNC(uint8_t,int8_t,int16_t,int16_t,u8s8s16os16) +GEN_BLIS_MAT_MUL_FUNC(uint8_t,int8_t,int8_t,int16_t,u8s8s16os8) +GEN_BLIS_MAT_MUL_FUNC(uint8_t,int8_t,int32_t,int32_t,u8s8s32os32) +GEN_BLIS_MAT_MUL_FUNC(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8) 
+GEN_BLIS_MAT_MUL_FUNC(bfloat16,bfloat16,float,float,bf16bf16f32of32) +GEN_BLIS_MAT_MUL_FUNC(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16) +GEN_BLIS_MAT_MUL_FUNC(float,float,float,float,f32f32f32of32) double get_gflops ( @@ -189,7 +206,7 @@ void print_result msg, m, n, k, lda, ldb, ldc, gflops, n_repeats); } -#define GEN_MAT_MUL_BENCH_DRV_FUNC(A_type,B_type,C_type,BLAS_SFX) \ +#define GEN_MAT_MUL_BENCH_DRV_FUNC(A_type,B_type,C_type,ACCUM_type,BLAS_SFX) \ void mat_mul_bench_driver_ ## BLAS_SFX \ ( \ char stor_order, \ @@ -198,12 +215,12 @@ void mat_mul_bench_driver_ ## BLAS_SFX \ dim_t m, \ dim_t n, \ dim_t k, \ - C_type alpha, \ + ACCUM_type alpha, \ A_type* a, \ dim_t lda, \ B_type* b, \ dim_t ldb, \ - C_type beta, \ + ACCUM_type beta, \ C_type* c, \ dim_t ldc, \ aocl_post_op* post_op\ @@ -242,26 +259,167 @@ void mat_mul_bench_driver_ ## BLAS_SFX \ print_result( XSTR(BLAS_SFX), n_repeats, m, n, k, lda, ldb, ldc, min_time_diff); \ } \ -GEN_MAT_MUL_BENCH_DRV_FUNC(uint8_t,int8_t,int16_t,u8s8s16os16) -GEN_MAT_MUL_BENCH_DRV_FUNC(uint8_t,int8_t,int8_t,u8s8s16os8) -GEN_MAT_MUL_BENCH_DRV_FUNC(uint8_t,int8_t,int32_t,u8s8s32os32) -GEN_MAT_MUL_BENCH_DRV_FUNC(uint8_t,int8_t,int8_t,u8s8s32os8) -GEN_MAT_MUL_BENCH_DRV_FUNC(float,float,float,f32f32f32of32) +GEN_MAT_MUL_BENCH_DRV_FUNC(uint8_t,int8_t,int16_t,int16_t,u8s8s16os16) +GEN_MAT_MUL_BENCH_DRV_FUNC(uint8_t,int8_t,int8_t,int16_t,u8s8s16os8) +GEN_MAT_MUL_BENCH_DRV_FUNC(uint8_t,int8_t,int32_t,int32_t,u8s8s32os32) +GEN_MAT_MUL_BENCH_DRV_FUNC(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8) +GEN_MAT_MUL_BENCH_DRV_FUNC(bfloat16,bfloat16,float,float,bf16bf16f32of32) +GEN_MAT_MUL_BENCH_DRV_FUNC(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16) +GEN_MAT_MUL_BENCH_DRV_FUNC(float,float,float,float,f32f32f32of32) + + +#define GEN_MAT_MUL_ACC_CHK_DOWNSCALE(C_type,ACCUM_type,SCALE_type,BLAS_DOWNSCALE_SFX) \ +inline C_type mat_mul_accuracy_check_downscale_ ## BLAS_DOWNSCALE_SFX \ + (\ + ACCUM_type temp_accum,\ + C_type out_temp_accum, \ + aocl_post_op* post_op, \ + dim_t j \ + )\ +{\ + out_temp_accum = ( C_type )lroundf( ( SCALE_type )temp_accum * \ + ( *( ( SCALE_type* )post_op->sum.scale_factor + j ) ) ); \ + return out_temp_accum; \ +}\ + +GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,int16_t,float,u8s8s16os8) +GEN_MAT_MUL_ACC_CHK_DOWNSCALE(int8_t,int32_t,float,u8s8s32os8) + +inline bfloat16 mat_mul_accuracy_check_downscale_bf16bf16f32obf16 + ( + float temp_accum, + bfloat16 out_temp_accum, + aocl_post_op* post_op, + dim_t j + ) +{ + float_to_bf16( ( &temp_accum ), ( &out_temp_accum ) ); + return out_temp_accum; +} + +#define GEN_MAT_MUL_ACC_CHK_ACCUM(A_type, B_type, C_type,ACCUM_type,BLAS_SFX) \ +inline ACCUM_type mat_mul_accuracy_check_accum_ ## BLAS_SFX \ + (\ + A_type* a, \ + B_type* b, \ + C_type* c_ref, \ + ACCUM_type temp_accum,\ + ACCUM_type alpha, \ + ACCUM_type beta, \ + dim_t rs_a, \ + dim_t rs_b, \ + dim_t cs_a, \ + dim_t cs_b, \ + dim_t rs_c_ref, \ + dim_t cs_c_ref, \ + dim_t i, \ + dim_t j, \ + dim_t k \ + )\ +{\ + for ( dim_t p = 0; p < k; ++p) \ + { \ + temp_accum += ( *( a + ( i * rs_a ) + ( cs_a * p ) ) * \ + *( b + ( rs_b * p ) + ( cs_b * j ) ) ); \ + } \ +\ + temp_accum = ( beta * ( * (c_ref + ( rs_c_ref * i ) + ( cs_c_ref * j ) ) ) ) \ + + ( alpha * temp_accum ); \ + return temp_accum; \ +}\ + +GEN_MAT_MUL_ACC_CHK_ACCUM(uint8_t,int8_t,int8_t,int16_t,u8s8s16os8) +GEN_MAT_MUL_ACC_CHK_ACCUM(uint8_t,int8_t,int16_t,int16_t,u8s8s16os16) +GEN_MAT_MUL_ACC_CHK_ACCUM(uint8_t,int8_t,int8_t,int32_t,u8s8s32os8) 
+GEN_MAT_MUL_ACC_CHK_ACCUM(uint8_t,int8_t,int32_t,int32_t,u8s8s32os32) +GEN_MAT_MUL_ACC_CHK_ACCUM(float,float,float,float,f32f32f32of32) -#define GEN_MAT_MUL_ACC_CHK_DRV_FUNC(A_type,B_type,C_type,DSCALE_type,SCALE_type,BLAS_SFX) \ +inline float bf16_to_float + ( + bfloat16 bf16_val + ) +{ + int32_t inter_temp = *( ( int16_t* ) &bf16_val ); + inter_temp = inter_temp << 16; + float float_value = *( float* ) ( &inter_temp ); + return float_value; +} + +inline float mat_mul_accuracy_check_accum_bf16bf16f32of32 + ( + bfloat16* a, + bfloat16* b, + float* c_ref, + float temp_accum, + float alpha, + float beta, + dim_t rs_a, + dim_t rs_b, + dim_t cs_a, + dim_t cs_b, + dim_t rs_c_ref, + dim_t cs_c_ref, + dim_t i, + dim_t j, + dim_t k + ) +{ + for ( dim_t p = 0; p < k; ++p) + { + float a_float = bf16_to_float( *( a + i * rs_a + p * cs_a ) ); + float b_float = bf16_to_float( *( b + p * rs_b + j * cs_b ) ); + temp_accum += ( ( a_float ) * ( b_float ) ); + } + temp_accum = ( beta * ( * (c_ref + ( rs_c_ref * i ) + ( cs_c_ref * j ) ) ) ) + + ( alpha * temp_accum ); + return temp_accum; +} + +inline float mat_mul_accuracy_check_accum_bf16bf16f32obf16 + ( + bfloat16* a, + bfloat16* b, + bfloat16* c_ref, + float temp_accum, + float alpha, + float beta, + dim_t rs_a, + dim_t rs_b, + dim_t cs_a, + dim_t cs_b, + dim_t rs_c_ref, + dim_t cs_c_ref, + dim_t i, + dim_t j, + dim_t k + ) +{ + for ( dim_t p = 0; p < k; ++p) + { + float a_float = bf16_to_float( *( a + i*rs_a + p*cs_a ) ); + float b_float = bf16_to_float( *( b + p*rs_b + j*cs_b ) ); + temp_accum += ( ( a_float ) * ( b_float ) ); + } + float c_ref_float = bf16_to_float( *( c_ref + i*rs_c_ref + j*cs_c_ref ) ); + temp_accum = ( beta * ( c_ref_float ) ) + ( alpha * temp_accum ); + + return temp_accum; +} + +#define GEN_MAT_MUL_ACC_CHK_DRV_FUNC(A_type,B_type,C_type,ACCUM_type,SCALE_type,BLAS_SFX,BLAS_DOWNSCALE_SFX) \ void mat_mul_accuracy_check_driver_ ## BLAS_SFX \ ( \ FILE* fout, \ - const char stor_order, \ + const char stor_order, \ dim_t m, \ dim_t n, \ dim_t k, \ - C_type alpha, \ + ACCUM_type alpha, \ A_type* a, \ dim_t lda, \ B_type* b, \ dim_t ldb, \ - C_type beta, \ + ACCUM_type beta, \ C_type* c, \ dim_t ldc, \ C_type* c_ref, \ @@ -294,18 +452,12 @@ void mat_mul_accuracy_check_driver_ ## BLAS_SFX \ { \ for ( dim_t j = 0; j < n; ++j ) \ { \ - DSCALE_type temp_accum = 0; \ + ACCUM_type temp_accum = 0; \ C_type out_temp_accum = 0; \ \ - for ( dim_t p = 0; p < k; ++p) \ - { \ - temp_accum += ( *( a + ( i * rs_a ) + ( cs_a * p ) ) * \ - *( b + ( rs_b * p ) + ( cs_b * j ) ) ); \ - } \ - \ - temp_accum = ( beta * ( * (c_ref + ( rs_c_ref * i ) + \ - ( cs_c_ref * j ) ) ) ) + ( alpha * temp_accum ); \ - \ + temp_accum = GEN_FUNC_NAME(mat_mul_accuracy_check_accum_,BLAS_SFX) \ + (a,b,c_ref,temp_accum,alpha,beta,rs_a,rs_b,cs_a,cs_b,rs_c_ref,cs_c_ref,i,j,k); \ +\ if ( post_op != NULL ) \ { \ /* Apply bias followed by relu. */ \ @@ -313,7 +465,7 @@ void mat_mul_accuracy_check_driver_ ## BLAS_SFX \ { \ if ( post_op->seq_length >= 1 ) \ { \ - temp_accum += ( *( ( DSCALE_type* )post_op->bias.bias + j ) ); \ + temp_accum += ( *( ( ACCUM_type* )post_op->bias.bias + j ) ); \ } \ if ( ( post_op->seq_length > 1 ) && \ ( post_op->seq_vector[1] == ELTWISE ) ) \ @@ -323,7 +475,7 @@ void mat_mul_accuracy_check_driver_ ## BLAS_SFX \ temp_accum = ( temp_accum > 0 ) ? 
\ temp_accum : \ ( temp_accum * \ - *( ( DSCALE_type* ) post_op->eltwise.algo.alpha ) ); \ + *( ( ACCUM_type* ) post_op->eltwise.algo.alpha ) ); \ } \ else \ { \ @@ -339,7 +491,7 @@ void mat_mul_accuracy_check_driver_ ## BLAS_SFX \ { \ temp_accum = ( temp_accum > 0 ) ? \ temp_accum : \ - ( temp_accum * *( ( DSCALE_type* ) post_op->eltwise.algo.alpha ) ); \ + ( temp_accum * *( ( ACCUM_type* ) post_op->eltwise.algo.alpha ) ); \ } \ else \ { \ @@ -348,14 +500,14 @@ void mat_mul_accuracy_check_driver_ ## BLAS_SFX \ } \ if ( ( post_op->seq_length > 1 ) && ( post_op->seq_vector[1] == BIAS ) ) \ { \ - temp_accum += ( *( ( DSCALE_type* )post_op->bias.bias + j ) ); \ + temp_accum += ( *( ( ACCUM_type* )post_op->bias.bias + j ) ); \ } \ } \ } \ if ( global_dscale_out == 'y' ) \ { \ - out_temp_accum = ( C_type )lroundf( ( SCALE_type )temp_accum * \ - ( *( ( SCALE_type* )post_op->sum.scale_factor + j ) ) ); \ + out_temp_accum = GEN_FUNC_NAME(mat_mul_accuracy_check_downscale_,BLAS_DOWNSCALE_SFX) \ + (temp_accum, out_temp_accum, post_op, j); \ } \ else \ { \ @@ -380,11 +532,13 @@ cleanup_acc: \ return; \ } \ -GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,int16_t,int16_t,float,u8s8s16os16) -GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,int8_t,int16_t,float,u8s8s16os8) -GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,int32_t,int32_t,float,u8s8s32os32) -GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,int8_t,int32_t,float,u8s8s32os8) -GEN_MAT_MUL_ACC_CHK_DRV_FUNC(float,float,float,float,float,f32f32f32of32) +GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,int16_t,int16_t,float,u8s8s16os16,u8s8s16os8) +GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,int8_t,int16_t,float,u8s8s16os8,u8s8s16os8) +GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,int32_t,int32_t,float,u8s8s32os32,u8s8s32os8) +GEN_MAT_MUL_ACC_CHK_DRV_FUNC(uint8_t,int8_t,int8_t,int32_t,float,u8s8s32os8,u8s8s32os8) +GEN_MAT_MUL_ACC_CHK_DRV_FUNC(bfloat16,bfloat16,float,float,float,bf16bf16f32of32,bf16bf16f32obf16) +GEN_MAT_MUL_ACC_CHK_DRV_FUNC(bfloat16,bfloat16,bfloat16,float,float,bf16bf16f32obf16,bf16bf16f32obf16) +GEN_MAT_MUL_ACC_CHK_DRV_FUNC(float,float,float,float,float,f32f32f32of32,bf16bf16f32obf16) /* Only supports bias followed by RELU and vice versa for now.*/ \ #define GEN_MAT_MUL_POST_OPS_CREATOR(C_type,DSCALE_type,BLAS_SFX) \ @@ -506,6 +660,7 @@ aocl_post_op* lpgemm_create_post_ops_struct_ ## BLAS_SFX \ GEN_MAT_MUL_POST_OPS_CREATOR(int16_t,float,u8s8s16os16) GEN_MAT_MUL_POST_OPS_CREATOR(int32_t,float,u8s8s32os32) +GEN_MAT_MUL_POST_OPS_CREATOR(float,float,bf16bf16f32of32) GEN_MAT_MUL_POST_OPS_CREATOR(float,float,f32f32f32of32) void lpgemm_destroy_post_ops_struct( aocl_post_op* post_ops ) @@ -548,7 +703,7 @@ void mat_mul_bench_main_ ## BLAS_SFX \ int32_t stride_a, \ int32_t stride_b, \ int32_t stride_c, \ - char* post_ops_str \ + char* post_ops_str \ ) \ { \ if ( ( op_t != 'p' ) && ( op_t != 'P' ) && ( op_t != 'r' ) && ( op_t != 'R' ) ) \ @@ -680,6 +835,163 @@ GEN_MAT_MUL_BENCH_MAIN_FUNC(uint8_t,int8_t,int32_t,u8s8s32os32,u8s8s32os32) GEN_MAT_MUL_BENCH_MAIN_FUNC(uint8_t,int8_t,int8_t,u8s8s32os8,u8s8s32os32) GEN_MAT_MUL_BENCH_MAIN_FUNC(float,float,float,f32f32f32of32,f32f32f32of32) +#define GEN_MAT_MUL_BENCH_MAIN_FUNC_BF16(C_type, BLAS_SFX) \ +void mat_mul_bench_main_ ## BLAS_SFX \ + ( \ + FILE* fin, \ + FILE* fout, \ + char stor_order, \ + char op_t, \ + int32_t m, \ + int32_t n, \ + int32_t k, \ + int32_t stride_a, \ + int32_t stride_b, \ + int32_t stride_c, \ + char* post_ops_str \ + ) \ +{ \ + if ( ( op_t != 'p' ) && ( op_t != 'P' ) && ( op_t != 'r' ) && ( op_t 
!= 'R' ) ) \ + { \ + printf("The op_t ( 2nd arg in input.txt) is not valid\n");\ + return; \ + } \ + \ + int32_t n_repeats = bli_max( 30, bli_min(( 3e10 / ( ( int64_t )m * n * k )), 1000 )); \ + if ( global_n_repeat > 0 ) \ + { \ + n_repeats = global_n_repeat; \ + } \ + \ + /* Get 64 byte aligned memory.*/ \ + bfloat16* a = ( bfloat16* ) bli_malloc_user( sizeof( bfloat16 ) * m * k ); \ + float *a_float = bli_malloc_user( m * k * sizeof( float )); \ + for ( int32_t i = 0; i < m*k; ++i ) \ + { \ + a_float[i] = ( float ) ( i % 5 ); \ + } \ + convert_float_arr_to_bf16( a_float, a, m * k ); \ + \ + bfloat16* b = ( bfloat16* ) bli_malloc_user( sizeof( bfloat16 ) * n * k ); \ + float *b_float = bli_malloc_user( k * n * sizeof( float )); \ + for ( int32_t i = 0; i < k*n; ++i ) \ + { \ + b_float[i] = ( float ) ( i % 5 );\ + } \ + convert_float_arr_to_bf16( b_float, b, k * n ); \ + \ + C_type* c = ( C_type* ) bli_malloc_user( sizeof( C_type ) * m * n ); \ + memset( ( void* ) c, 0, sizeof( C_type ) * m * n ); \ + \ + C_type* c_ref = ( C_type* ) bli_malloc_user( sizeof( C_type ) * m * n ); \ + memset( ( void* ) c_ref, 0, sizeof( C_type ) * m * n ); \ + \ + float alpha; \ + float beta; \ + if ( bench_mode == 'p' ) \ + { \ + alpha = 1; \ + beta = 0; \ + } \ + else if ( bench_mode == 'a' ) \ + { \ + alpha = 2; \ + beta = 9; \ + } \ + \ + aocl_post_op* post_op = NULL; \ + if ( ( post_ops_str != NULL ) || ( global_dscale_out == 'y' ) ) \ + { \ + post_op = lpgemm_create_post_ops_struct_bf16bf16f32of32( m, n, post_ops_str ); \ + if ( post_op == NULL ) \ + { \ + printf(" post op struct allocation failure, returning.\n"); \ + return; \ + } \ + } \ + \ + if ( ( op_t == 'p' ) || ( op_t == 'P' ) ) \ + { \ + /* No reordering of B.*/ \ + GEN_FUNC_NAME(mat_mul_bench_driver_,BLAS_SFX) \ + ( \ + stor_order, op_t, n_repeats, m, n, k, \ + alpha, \ + a, stride_a, \ + b, stride_b, \ + beta, \ + c, stride_c, \ + post_op \ + ); \ + } \ + else if ( ( op_t == 'r' ) || ( op_t == 'R' ) ) \ + { \ + /* Reorder B.*/ \ + siz_t b_reorder_buf_siz_req = \ + aocl_get_reorder_buf_size_bf16bf16f32of32( 'B', k, n ); \ + \ + bfloat16* b_reorder = ( bfloat16* ) bli_malloc_user( b_reorder_buf_siz_req ); \ + aocl_reorder_bf16bf16f32of32( 'B', b, b_reorder, k, n, stride_b ); \ + \ + GEN_FUNC_NAME(mat_mul_bench_driver_,BLAS_SFX) \ + ( \ + stor_order, op_t, n_repeats, m, n, k, \ + alpha, \ + a, stride_a, \ + b_reorder, stride_b, \ + beta, \ + c, stride_c, \ + post_op \ + ); \ + } \ + \ +if ( bench_mode == 'a' ) \ + { \ + printf(" Running accuracy check.\n"); \ + GEN_FUNC_NAME(mat_mul_accuracy_check_driver_,BLAS_SFX) \ + ( \ + fout, stor_order, m, n, k, \ + alpha, \ + a, stride_a, \ + b, stride_b, \ + beta, \ + c, stride_c, \ + c_ref, stride_c, \ + post_op \ + ); \ + } \ + \ + lpgemm_destroy_post_ops_struct( post_op ); \ + \ + if ( a != NULL ) \ + { \ + bli_free_user( a ); \ + } \ + if ( b != NULL ) \ + { \ + bli_free_user( b ); \ + } \ + if ( a_float != NULL ) \ + { \ + bli_free_user( a_float ); \ + } \ + if ( b_float != NULL ) \ + { \ + bli_free_user( b_float ); \ + } \ + if ( c != NULL ) \ + { \ + bli_free_user( c ); \ + } \ + if ( c_ref != NULL ) \ + { \ + bli_free_user( c_ref ); \ + } \ +} \ + +GEN_MAT_MUL_BENCH_MAIN_FUNC_BF16(float,bf16bf16f32of32) +GEN_MAT_MUL_BENCH_MAIN_FUNC_BF16(bfloat16,bf16bf16f32obf16) + int main( int argc, char** argv ) { FILE* fin = NULL; @@ -692,7 +1004,9 @@ int main( int argc, char** argv ) "\nPost ops can be executed optionaly by providing a " \ "coma separated list of ops after -o arg.\nCurrently " \ "bias and 
relu/prelu is supported and can be specified " \ - "as a single post op or combination of the same. eg: -o bias,relu ; -o prelu.\n" ); + "as a single post op or combination of the same. eg: -o bias,relu ; -o prelu." \ + "\nDownscaled version of an API can be enabled by using -d arg. " \ + "downscale is used to enable- u8s8s32os8, u8s8s16os8 or bf16bf16f32obf16 \n" ); exit( 1 ); } @@ -858,6 +1172,27 @@ int main( int argc, char** argv ) ); } } + if ((op_type_char == 'b') || (op_type_char == 'B')) + { + if ( global_dscale_out == 'n' ) + { + GEN_FUNC_NAME(mat_mul_bench_main_, bf16bf16f32of32) + ( + fin, fout, stor_order, op_t, + m, n, k, stride_a, stride_b, stride_c, + post_ops_str_dest + ); + } + else + { + GEN_FUNC_NAME(mat_mul_bench_main_, bf16bf16f32obf16) + ( + fin, fout, stor_order, op_t, + m, n, k, stride_a, stride_b, stride_c, + post_ops_str_dest + ); + } + } if ( post_ops_str != NULL ) { strcpy( post_ops_str_dest, post_ops_str ); From 8c5b40f597e630416db8360e4e4282e9fb941b67 Mon Sep 17 00:00:00 2001 From: "Field G. Van Zee" Date: Tue, 13 Sep 2022 11:50:23 -0500 Subject: [PATCH 225/243] Use kernel CFLAGS for 'kernels' subdirs in addons. (#658) Details: - Updated Makefile and common.mk so that the targeted configuration's kernel CFLAGS are applied to source files that are found in a 'kernels' subdirectory within an enabled addon. For now, this behavior only applies when the 'kernels' directory is at the top level of the addon directory structure. For example, if there is an addon named 'foobar', the source code must be located in addon/foobar/kernels/ in order for it to be compiled with the target configurations's kernel CFLAGS. Any other source code within addon/foobar/ will be compiled with general-purpose CFLAGS (the same ones that were used on all addon code prior to this commit). Thanks to AMD (esp. Mithun Mohan) for suggesting this change and catching an intermediate bug in the PR. - Comment/whitespace updates. (cherry picked from commit fd885cf98f4fe1d3bc46468e567776c37c670fcc) Change-Id: I9a678f78bde90b23a6293ce90377004876f51067 --- Makefile | 53 ++++++++++++++++++++++++++++++++++++++++++++--------- common.mk | 33 +++++++++++++++++++++------------ 2 files changed, 65 insertions(+), 21 deletions(-) diff --git a/Makefile b/Makefile index e066ea36aa..1f86acc7e5 100644 --- a/Makefile +++ b/Makefile @@ -213,6 +213,20 @@ MK_REFKERN_OBJS := $(foreach arch, $(CONFIG_LIST), \ # Generate object file paths for all of the portable framework source code. MK_FRAME_OBJS := $(call gen-obj-paths-from-src,$(FRAME_SRC_SUFS),$(MK_FRAME_SRC),$(FRAME_PATH),$(BASE_OBJ_FRAME_PATH)) +# Generate object file paths for the addon source code. If one or more addons +# were not enabled a configure-time, these variable will we empty. +# NOTE: We separate the source and objects into kernel and non-kernel lists. 
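# Editorial illustration, not part of the original patch: restating the commit
# message with a hypothetical addon named 'foobar',
#   $(ADDON_PATH)/foobar/$(KERNELS_DIR)/...  -> compiled with the target
#                                               configuration's kernel CFLAGS
#                                               (CKOPTFLAGS/CKVECFLAGS)
#   $(ADDON_PATH)/foobar/<other sources>     -> compiled with the general addon
#                                               CFLAGS (COPTFLAGS), as before
# Only a 'kernels' directory at the top level of the addon is treated this way.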
+MK_ADDON_KERS_SRC := $(foreach addon, $(ADDON_LIST), \ + $(filter $(ADDON_PATH)/$(addon)/$(KERNELS_DIR)/%, \ + $(MK_ADDON_SRC)) \ + ) +MK_ADDON_OTHER_SRC := $(foreach addon, $(ADDON_LIST), \ + $(filter-out $(ADDON_PATH)/$(addon)/$(KERNELS_DIR)/%, \ + $(MK_ADDON_SRC)) \ + ) +MK_ADDON_KERS_OBJS := $(call gen-obj-paths-from-src,$(ADDON_SRC_SUFS),$(MK_ADDON_KERS_SRC),$(ADDON_PATH),$(BASE_OBJ_ADDON_PATH)) +MK_ADDON_OTHER_OBJS := $(call gen-obj-paths-from-src,$(ADDON_SRC_SUFS),$(MK_ADDON_OTHER_SRC),$(ADDON_PATH),$(BASE_OBJ_ADDON_PATH)) +MK_ADDON_OBJS := $(MK_ADDON_KERS_OBJS) $(MK_ADDON_OTHER_OBJS) # AMD has optimized some of the framework files, these optimizations # may not be compatible with other platforms. # @@ -237,11 +251,6 @@ endif # Generate object file paths for all of the debgu and trace logger. MK_AOCLDTL_OBJS := $(call gen-obj-paths-from-src,$(AOCLDTL_SRC_SUFS),$(MK_AOCLDTL_SRC),$(AOCLDTL_PATH),$(BASE_OBJ_AOCLDTL_PATH)) - -# Generate object file paths for the addon source code. If one or more addons -# were not enabled a configure-time, this variable will we empty. -MK_ADDON_OBJS := $(call gen-obj-paths-from-src,$(ADDON_SRC_SUFS),$(MK_ADDON_SRC),$(ADDON_PATH),$(BASE_OBJ_ADDON_PATH)) - # Generate object file paths for the sandbox source code. If a sandbox was not # enabled a configure-time, this variable will we empty. MK_SANDBOX_OBJS := $(call gen-obj-paths-from-src,$(SANDBOX_SRC_SUFS),$(MK_SANDBOX_SRC),$(SANDBOX_PATH),$(BASE_OBJ_SANDBOX_PATH)) @@ -595,18 +604,34 @@ endef # first argument: a configuration name from the union of config_list and # config_name, used to look up the CFLAGS to use during compilation. +# second argument: the C99 addon file suffix being considered. define make-c99-addon-rule $(BASE_OBJ_ADDON_PATH)/%.o: $(ADDON_PATH)/%.$(2) $(BLIS_H_FLAT) $(ADDON_H99_FILES) $(MAKE_DEFS_MK_PATHS) ifeq ($(ENABLE_VERBOSE),yes) - $$(if $$(findstring _amd512vnni,$$<),$$(eval LPGEMM_MARCH_VAR=icelake-server -mavx512bf16),$$(eval LPGEMM_MARCH_VAR=znver3)) - $(CC) -march=$$(LPGEMM_MARCH_VAR) $(call get-addon-c99flags-for,$(1)) -c $$< -o $$@ + $(CC) $(call get-addon-c99flags-for,$(1)) -c $$< -o $$@ else @echo "Compiling $$@" $(call get-addon-c99text-for,$(1)) - $$(if $$(findstring _amd512vnni,$$<),$$(eval LPGEMM_MARCH_VAR=icelake-server -mavx512bf16),$$(eval LPGEMM_MARCH_VAR=znver3)) - @$(CC) -march=$$(LPGEMM_MARCH_VAR) $(call get-addon-c99flags-for,$(1)) -c $$< -o $$@ + @$(CC) $(call get-addon-c99flags-for,$(1)) -c $$< -o $$@ +endif +endef + +# first argument: a configuration name from the union of config_list and +# config_name, used to look up the CFLAGS to use during compilation. +# second argument: the C99 addon file suffix being considered. +# third argument: the name of the addon being considered. +define make-c99-addon-kers-rule +$(BASE_OBJ_ADDON_PATH)/$(3)/$(KERNELS_DIR)/%.o: $(ADDON_PATH)/$(3)/$(KERNELS_DIR)/%.$(2) $(BLIS_H_FLAT) $(ADDON_H99_FILES) $(MAKE_DEFS_MK_PATHS) +ifeq ($(ENABLE_VERBOSE),yes) + $(CC) $(call get-addon-kernel-c99flags-for,$(1)) -c $$< -o $$@ +else + @echo "Compiling $$@" $(call get-addon-kernel-text-for,$(1)) + @$(CC) $(call get-addon-kernel-c99flags-for,$(1)) -c $$< -o $$@ endif endef +# first argument: a configuration name from the union of config_list and +# config_name, used to look up the CFLAGS to use during compilation. +# second argument: the C++ addon file suffix being considered. 
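To make the rule-generation plumbing a little more concrete: $(call) substitutes the arguments into the template text, and $(eval) then parses the expanded result as ordinary makefile syntax, so the triple foreach further below stamps out one rule per (configuration, suffix, addon) combination. A trimmed-down, self-contained sketch of that mechanism, using $(info) in place of a real rule body (the template name show-addon-kers-rule, the addon 'foobar', and the configuration 'zen3' are all illustrative, not part of the build system):

define show-addon-kers-rule
$$(info rule: obj/$(1)/addon/$(3)/kernels/%.o  built from  addon/$(3)/kernels/%.$(2)  with the '$(1)' kernel CFLAGS)
endef

$(foreach addon, foobar, \
$(foreach suf, c, \
$(foreach conf, zen3, $(eval $(call show-addon-kers-rule,$(conf),$(suf),$(addon))))))

all: ;

In the actual make-c99-addon-kers-rule template above, the generated text is of course a pattern rule whose recipe compiles with get-addon-kernel-c99flags-for rather than printing a message.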
define make-cxx-addon-rule $(BASE_OBJ_ADDON_PATH)/%.o: $(ADDON_PATH)/%.$(2) $(BLIS_H_FLAT) $(ADDON_HXX_FILES) $(MAKE_DEFS_MK_PATHS) ifeq ($(ENABLE_VERBOSE),yes) @@ -619,6 +644,7 @@ endef # first argument: a configuration name from the union of config_list and # config_name, used to look up the CFLAGS to use during compilation. +# second argument: the C99 sandbox file suffix being considered. define make-c99-sandbox-rule $(BASE_OBJ_SANDBOX_PATH)/%.o: $(SANDBOX_PATH)/%.$(2) $(BLIS_H_FLAT) $(SANDBOX_H99_FILES) $(MAKE_DEFS_MK_PATHS) ifeq ($(ENABLE_VERBOSE),yes) @@ -629,6 +655,9 @@ else endif endef +# first argument: a configuration name from the union of config_list and +# config_name, used to look up the CFLAGS to use during compilation. +# second argument: the C++ sandbox file suffix being considered. define make-cxx-sandbox-rule $(BASE_OBJ_SANDBOX_PATH)/%.o: $(SANDBOX_PATH)/%.$(2) $(BLIS_H_FLAT) $(SANDBOX_HXX_FILES) $(MAKE_DEFS_MK_PATHS) ifeq ($(ENABLE_VERBOSE),yes) @@ -682,6 +711,12 @@ $(foreach kset, $(KERNEL_LIST), $(eval $(call make-kernels-rule,$(kset),$(call g $(foreach suf, $(ADDON_C99_SUFS), \ $(foreach conf, $(CONFIG_NAME), $(eval $(call make-c99-addon-rule,$(conf),$(suf))))) +# Instantiate the build rule for C addon/kernels files. Use the CFLAGS for the +# configuration family. +$(foreach addon, $(ADDON_LIST), \ +$(foreach suf, $(ADDON_C99_SUFS), \ +$(foreach conf, $(CONFIG_NAME), $(eval $(call make-c99-addon-kers-rule,$(conf),$(suf),$(addon)))))) + # Instantiate the build rule for C++ addon files. Use the CFLAGS for the # configuration family. $(foreach suf, $(ADDON_CXX_SUFS), \ diff --git a/common.mk b/common.mk index 0fdf659a9f..5912f53b49 100644 --- a/common.mk +++ b/common.mk @@ -159,7 +159,7 @@ get-kernel-cflags-for = $(strip $(call load-var-for,CKOPTFLAGS,$(1)) \ $(BUILD_SYMFLAGS) \ ) -# When compiling sandboxes, we use flags similar to those of general framework +# When compiling addons, we use flags similar to those of general framework # source. This ensures that the same code can be linked and run across various # sub-configurations. get-addon-c99flags-for = $(strip $(call load-var-for,COPTFLAGS,$(1)) \ @@ -174,6 +174,15 @@ get-addon-cxxflags-for = $(strip $(call load-var-for,COPTFLAGS,$(1)) \ $(BUILD_CPPFLAGS) \ $(BUILD_SYMFLAGS) \ ) +# When compiling addon kernels, we use flags similar to those of kernels +# flags, except we also include the addon header paths. +get-addon-kernel-c99flags-for = $(strip $(call load-var-for,CKOPTFLAGS,$(1)) \ + $(call load-var-for,CKVECFLAGS,$(1)) \ + $(call get-noopt-cflags-for,$(1)) \ + $(CADDONINCFLAGS) \ + $(BUILD_CPPFLAGS) \ + $(BUILD_SYMFLAGS) \ + ) # When compiling sandboxes, we use flags similar to those of general framework # source. This ensures that the same code can be linked and run across various @@ -208,17 +217,17 @@ get-user-cflags-for = $(strip $(call load-var-for,COPTFLAGS,$(1)) \ # Define functions that return messages appropriate for each non-verbose line # of compilation output. -get-noopt-text = "(CFLAGS for no optimization)" -get-refinit-text-for = "('$(1)' CFLAGS for ref. kernel init)" -get-refkern-text-for = "('$(1)' CFLAGS for ref. 
kernels)" -get-config-text-for = "('$(1)' CFLAGS for config code)" -get-frame-text-for = "('$(1)' CFLAGS for framework code)" -get-aocldtl-text-for = "('$(1)' CFLAGS for AOCL debug and trace code)" -get-kernel-text-for = "('$(1)' CFLAGS for kernels)" -get-addon-c99text-for = "('$(1)' CFLAGS for addons)" -get-addon-cxxtext-for = "('$(1)' CXXFLAGS for addons)" -get-sandbox-c99text-for = "('$(1)' CFLAGS for sandboxes)" -get-sandbox-cxxtext-for = "('$(1)' CXXFLAGS for sandboxes)" +get-noopt-text = "(CFLAGS for no optimization)" +get-refinit-text-for = "('$(1)' CFLAGS for ref. kernel init)" +get-refkern-text-for = "('$(1)' CFLAGS for ref. kernels)" +get-config-text-for = "('$(1)' CFLAGS for config code)" +get-frame-text-for = "('$(1)' CFLAGS for framework code)" +get-kernel-text-for = "('$(1)' CFLAGS for kernels)" +get-addon-c99text-for = "('$(1)' CFLAGS for addons)" +get-addon-cxxtext-for = "('$(1)' CXXFLAGS for addons)" +get-addon-kernel-text-for = "('$(1)' CFLAGS for addon kernels)" +get-sandbox-c99text-for = "('$(1)' CFLAGS for sandboxes)" +get-sandbox-cxxtext-for = "('$(1)' CXXFLAGS for sandboxes)" From 5cd358b76e09ba4672fbf3f5d13532da496cd649 Mon Sep 17 00:00:00 2001 From: mkadavil Date: Tue, 27 Sep 2022 16:59:39 +0530 Subject: [PATCH 226/243] Zen4 compilation flag updates to support low precision gemm. - BFloat16 flags added to zen4 make_defs in order to enable compilation of low precision gemm by using zen4 config. - Avoid -ftree-partial-pre optimization flag with gcc due to non optimal code generation for intrinsics based kernels in low precision gemm. - Enable only Zen3 specific low precision gemm kernels (s16) compilation when aocl_gemm addon is compiled on Zen3 machines. AMD-Internal: [CPUPL-1545] Change-Id: Id3be3410bfbf141bb6fc4b4e3391115a4e0bb79f --- .../aocl_gemm/frame/bf16bf16f32/lpgemm_bf16.c | 16 ++++++++++ .../frame/bf16bf16f32/lpgemm_reorder_bf16.c | 8 ++++- .../aocl_gemm/frame/u8s8s32/lpgemm_reorder.c | 17 ++++++++++ .../aocl_gemm/frame/u8s8s32/lpgemm_u8s8s32.c | 16 ++++++++++ .../lpgemm_6x64rowmajor_bf16_amd512vnni.c | 22 +++++++------ .../lpgemm_m_fringe_bf16_amd512vnni.c | 12 ++++--- .../lpgemm_mn_fringe_bf16_amd512vnni.c | 14 ++++---- .../lpgemm_n_fringe_bf16_amd512vnni.c | 24 +++++++------- .../lpgemm_packb_bf16_amd512vnni.c | 32 ++++++++++--------- .../u8s8s32/lpgemm_6x64rowmajor_amd512vnni.c | 2 ++ .../u8s8s32/lpgemm_m_fringe_amd512vnni.c | 2 ++ .../u8s8s32/lpgemm_mn_fringe_amd512vnni.c | 2 ++ .../u8s8s32/lpgemm_n_fringe_amd512vnni.c | 2 ++ .../kernels/u8s8s32/lpgemm_packa_amd512vnni.c | 22 +++++++------ .../kernels/u8s8s32/lpgemm_packb_amd512vnni.c | 22 +++++++------ common.mk | 1 + config/zen4/make_defs.mk | 22 +++++++------ 17 files changed, 158 insertions(+), 78 deletions(-) diff --git a/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_bf16.c b/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_bf16.c index 738840f78e..5db523f987 100644 --- a/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_bf16.c +++ b/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_bf16.c @@ -235,6 +235,7 @@ LPGEMM_5LOOP(bfloat16,bfloat16,float,bf16bf16f32of32) if ( ( jc_packb_end > jc_packb_start ) && ( jc_packb_start < ( jc + nc0 ) ) ) { +#ifdef BLIS_KERNELS_ZEN4 packb_nr64_bf16bf16f32of32 ( pack_b_buffer_bf16 + ( jc_packb_start * kc0_updated ), @@ -243,6 +244,7 @@ LPGEMM_5LOOP(bfloat16,bfloat16,float,bf16bf16f32of32) ( jc_packb_end - jc_packb_start ), kc0, &rs_b_use, &cs_b_use ); +#endif } else { @@ -302,6 +304,7 @@ LPGEMM_5LOOP(bfloat16,bfloat16,float,bf16bf16f32of32) { dim_t nr0 = bli_min( ( nc0 - jr ), NR ); +#ifdef 
BLIS_KERNELS_ZEN4 // Reorder/Packed B, Reorder/Packed/Unpacked A call. lpgemm_rowvar_bf16bf16f32of32_6x64 ( @@ -312,6 +315,19 @@ LPGEMM_5LOOP(bfloat16,bfloat16,float,bf16bf16f32of32) alpha, beta0, is_last_k, ic, ( jc + jr ), post_op_list, rs_c_downscale ); +#else + // Silence compiler warnings. + ( void )b_use; + ( void )a_block_stride; + ( void )rs_c_downscale; + ( void )is_last_k; + ( void )c_use_ic; + ( void )a_use; + ( void )beta0; + ( void )nr0; + ( void )mc0; + ( void )cs_a_use; +#endif } } } diff --git a/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_reorder_bf16.c b/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_reorder_bf16.c index 07b087b790..5bb217facd 100644 --- a/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_reorder_bf16.c +++ b/addon/aocl_gemm/frame/bf16bf16f32/lpgemm_reorder_bf16.c @@ -151,7 +151,7 @@ void reorderb_nr64_bf16bf16f32of32 // st = ( jc_cur_loop * k ) // + ( n_sub_updated * pc ) // + ( NC' * kc0_updated) - +#ifdef BLIS_KERNELS_ZEN4 // B should always be packed. packb_nr64_bf16bf16f32of32 ( @@ -162,6 +162,12 @@ void reorderb_nr64_bf16bf16f32of32 ( rs_b * pc ) + jc ), rs_b, nc0, kc0, &rs_b_reorder, &cs_b_reorder ); +#else + // Silence compiler warnings. + rs_b_reorder = 0; + cs_b_reorder = 0; + ( void )rs_b; +#endif } adjust_B_panel_reordered_jc( &jc, jc_cur_loop ); diff --git a/addon/aocl_gemm/frame/u8s8s32/lpgemm_reorder.c b/addon/aocl_gemm/frame/u8s8s32/lpgemm_reorder.c index 8d61c757e8..746a134100 100644 --- a/addon/aocl_gemm/frame/u8s8s32/lpgemm_reorder.c +++ b/addon/aocl_gemm/frame/u8s8s32/lpgemm_reorder.c @@ -148,6 +148,7 @@ void reorderb_nr64_u8s8s32o32 // st = ( jc_cur_loop * k ) // + ( n_sub_updated * pc ) // + ( NC' * kc0_updated) +#ifdef BLIS_KERNELS_ZEN4 packb_nr64_u8s8s32o32 ( ( ( ( int8_t* )b_reorder->storage.aligned_buffer ) + @@ -157,6 +158,14 @@ void reorderb_nr64_u8s8s32o32 ( rs_b * pc ) + jc ), rs_b, nc0, kc0, &rs_b_reorder, &cs_b_reorder ); +#else + // Silence compiler warnings. 
+ rs_b_reorder = 0; + cs_b_reorder = 0; + ( void )kc0_updated; + ( void )k_updated; + ( void )rs_b; +#endif } adjust_B_panel_reordered_jc( &jc, jc_cur_loop ); @@ -198,6 +207,7 @@ void reordera_mr6_u8s8s32o32 { dim_t mc0 = bli_min( ( m - ic ), MC ); +#ifdef BLIS_KERNELS_ZEN4 packa_k64_u8s8s32o32 ( ( ( ( uint8_t* )a_reorder->storage.aligned_buffer ) + ( pc * m ) + @@ -205,6 +215,13 @@ void reordera_mr6_u8s8s32o32 ( ( ( uint8_t* )a->storage.aligned_buffer ) + ( rs_a * ic ) + pc ), rs_a, mc0, kc0, &rs_a_reorder, &cs_a_reorder ); +#else + rs_a_reorder = 0; + cs_a_reorder = 0; + ( void )kc0_updated; + ( void )rs_a; + ( void )mc0; +#endif } } diff --git a/addon/aocl_gemm/frame/u8s8s32/lpgemm_u8s8s32.c b/addon/aocl_gemm/frame/u8s8s32/lpgemm_u8s8s32.c index a62d42a1f1..82a745fcf5 100644 --- a/addon/aocl_gemm/frame/u8s8s32/lpgemm_u8s8s32.c +++ b/addon/aocl_gemm/frame/u8s8s32/lpgemm_u8s8s32.c @@ -239,6 +239,7 @@ LPGEMM_5LOOP(uint8_t,int8_t,int32_t,u8s8s32o32) if ( ( jc_packb_end > jc_packb_start ) && ( jc_packb_start < ( jc + nc0 ) ) ) { +#ifdef BLIS_KERNELS_ZEN4 packb_nr64_u8s8s32o32 ( pack_b_buffer_u8s8s32o32 + ( jc_packb_start * kc0_updated ), @@ -247,6 +248,7 @@ LPGEMM_5LOOP(uint8_t,int8_t,int32_t,u8s8s32o32) ( jc_packb_end - jc_packb_start ), kc0, &rs_b_use, &cs_b_use ); +#endif } else { @@ -308,6 +310,7 @@ LPGEMM_5LOOP(uint8_t,int8_t,int32_t,u8s8s32o32) ); pack_a_buffer_u8s8s32o32 = ( uint8_t* )bli_mem_buffer( &mem_a ); +#ifdef BLIS_KERNELS_ZEN4 packa_k64_u8s8s32o32 ( pack_a_buffer_u8s8s32o32, @@ -315,6 +318,7 @@ LPGEMM_5LOOP(uint8_t,int8_t,int32_t,u8s8s32o32) mc0, kc0, &rs_a_use, &cs_a_use ); +#endif a_use = pack_a_buffer_u8s8s32o32; a_block_stride = kc0_updated; } @@ -339,6 +343,7 @@ LPGEMM_5LOOP(uint8_t,int8_t,int32_t,u8s8s32o32) { dim_t nr0 = bli_min( ( nc0 - jr ), NR ); +#ifdef BLIS_KERNELS_ZEN4 // Reorder/Packed B, Reorder/Packed/Unpacked A call. lpgemm_rowvar_u8s8s32o32_6x64 ( @@ -349,6 +354,17 @@ LPGEMM_5LOOP(uint8_t,int8_t,int32_t,u8s8s32o32) alpha, beta0, is_last_k, ic, ( jc + jr ), post_op_list, rs_c_downscale ); +#else + // Silence compiler warnings. 
+ ( void )b_use; + ( void )a_block_stride; + ( void )rs_c_downscale; + ( void )is_last_k; + ( void )c_use_ic; + ( void )a_use; + ( void )beta0; + ( void )nr0; +#endif } } } diff --git a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_6x64rowmajor_bf16_amd512vnni.c b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_6x64rowmajor_bf16_amd512vnni.c index 14dc7e8e57..65a4963dcb 100644 --- a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_6x64rowmajor_bf16_amd512vnni.c +++ b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_6x64rowmajor_bf16_amd512vnni.c @@ -38,6 +38,7 @@ #include "lpgemm_kernels.h" #include "lpgemm_f32_kern_macros.h" +#ifdef BLIS_KERNELS_ZEN4 // 6x64 bf16 kernel LPGEMM_MAIN_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x64) { @@ -658,10 +659,10 @@ LPGEMM_MAIN_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x64) selector4 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + post_op_c_i + 3 ) ); - a_bf16_0 = + __m512 selector5 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + post_op_c_i + 4 ) ); - a_bf16_1 = + __m512 selector6 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + post_op_c_i + 5 ) ); @@ -714,28 +715,28 @@ LPGEMM_MAIN_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x64) c_float_3p3 = _mm512_add_ps( selector4, c_float_3p3 ); // c[4,0-15] - c_float_4p0 = _mm512_add_ps( a_bf16_0, c_float_4p0 ); + c_float_4p0 = _mm512_add_ps( selector5, c_float_4p0 ); // c[4, 16-31] - c_float_4p1 = _mm512_add_ps( a_bf16_0, c_float_4p1 ); + c_float_4p1 = _mm512_add_ps( selector5, c_float_4p1 ); // c[4,32-47] - c_float_4p2 = _mm512_add_ps( a_bf16_0, c_float_4p2 ); + c_float_4p2 = _mm512_add_ps( selector5, c_float_4p2 ); // c[4,48-63] - c_float_4p3 = _mm512_add_ps( a_bf16_0, c_float_4p3 ); + c_float_4p3 = _mm512_add_ps( selector5, c_float_4p3 ); // c[5,0-15] - c_float_5p0 = _mm512_add_ps( a_bf16_1, c_float_5p0 ); + c_float_5p0 = _mm512_add_ps( selector6, c_float_5p0 ); // c[5, 16-31] - c_float_5p1 = _mm512_add_ps( a_bf16_1, c_float_5p1 ); + c_float_5p1 = _mm512_add_ps( selector6, c_float_5p1 ); // c[5,32-47] - c_float_5p2 = _mm512_add_ps( a_bf16_1, c_float_5p2 ); + c_float_5p2 = _mm512_add_ps( selector6, c_float_5p2 ); // c[5,48-63] - c_float_5p3 = _mm512_add_ps( a_bf16_1, c_float_5p3 ); + c_float_5p3 = _mm512_add_ps( selector6, c_float_5p3 ); } POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR @@ -1142,3 +1143,4 @@ LPGEMM_MAIN_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x64) } } } +#endif diff --git a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_m_fringe_bf16_amd512vnni.c b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_m_fringe_bf16_amd512vnni.c index cfd8f0a0dc..e4418b2a0e 100644 --- a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_m_fringe_bf16_amd512vnni.c +++ b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_m_fringe_bf16_amd512vnni.c @@ -39,6 +39,7 @@ #include "lpgemm_kernels.h" #include "lpgemm_f32_kern_macros.h" +#ifdef BLIS_KERNELS_ZEN4 // 5x64 bf16 kernel LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x64) { @@ -484,7 +485,7 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x64) selector4 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + post_op_c_i + 3 ) ); - a_bf16_0 = + __m512 selector5 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + post_op_c_i + 4 ) ); @@ -537,16 +538,16 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5x64) c_float_3p3 = _mm512_add_ps( selector4, c_float_3p3 ); // c[4,0-15] - c_float_4p0 = _mm512_add_ps( a_bf16_0, c_float_4p0 ); + c_float_4p0 = _mm512_add_ps( selector5, 
c_float_4p0 ); // c[4, 16-31] - c_float_4p1 = _mm512_add_ps( a_bf16_0, c_float_4p1 ); + c_float_4p1 = _mm512_add_ps( selector5, c_float_4p1 ); // c[4,32-47] - c_float_4p2 = _mm512_add_ps( a_bf16_0, c_float_4p2 ); + c_float_4p2 = _mm512_add_ps( selector5, c_float_4p2 ); // c[4,48-63] - c_float_4p3 = _mm512_add_ps( a_bf16_0, c_float_4p3 ); + c_float_4p3 = _mm512_add_ps( selector5, c_float_4p3 ); } POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR @@ -2588,3 +2589,4 @@ LPGEMM_M_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1x64) // c[0,48-63] _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 3*16 ), c_float_0p3 ); } +#endif diff --git a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_mn_fringe_bf16_amd512vnni.c b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_mn_fringe_bf16_amd512vnni.c index 65bce97799..6e985f154f 100644 --- a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_mn_fringe_bf16_amd512vnni.c +++ b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_mn_fringe_bf16_amd512vnni.c @@ -39,6 +39,7 @@ #include "lpgemm_kernels.h" #include "lpgemm_f32_kern_macros.h" +#ifdef BLIS_KERNELS_ZEN4 // 5xlt16 bf16 fringe kernel LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5xlt16) { @@ -254,7 +255,7 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5xlt16) __m512 selector4 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + post_op_c_i + 3 ) ); - a_bf16_0 = + __m512 selector5 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + post_op_c_i + 4 ) ); @@ -271,7 +272,7 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_5xlt16) c_float_3p0 = _mm512_add_ps( selector4, c_float_3p0 ); // c[4,0-15] - c_float_4p0 = _mm512_add_ps( a_bf16_0, c_float_4p0 ); + c_float_4p0 = _mm512_add_ps( selector5, c_float_4p0 ); } POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR @@ -563,7 +564,7 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4xlt16) __m512 selector3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + post_op_c_i + 2 ) ); - a_bf16_0 = + __m512 selector4 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + post_op_c_i + 3 ) ); @@ -577,7 +578,7 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_4xlt16) c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); // c[3,0-15] - c_float_3p0 = _mm512_add_ps( a_bf16_0, c_float_3p0 ); + c_float_3p0 = _mm512_add_ps( selector4, c_float_3p0 ); } POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR @@ -821,7 +822,7 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3xlt16) selector2 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + post_op_c_i + 1 ) ); - a_bf16_0 = + __m512 selector3 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + post_op_c_i + 2 ) ); @@ -832,7 +833,7 @@ LPGEMM_MN_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_3xlt16) c_float_1p0 = _mm512_add_ps( selector2, c_float_1p0 ); // c[2,0-15] - c_float_2p0 = _mm512_add_ps( a_bf16_0, c_float_2p0 ); + c_float_2p0 = _mm512_add_ps( selector3, c_float_2p0 ); } POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR @@ -5839,3 +5840,4 @@ LPGEMM_MN_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_1x48) // c[0,32-47] _mm512_storeu_ps( c + ( rs_c * 0 ) + ( 2*16 ), c_float_0p2 ); } +#endif diff --git a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_n_fringe_bf16_amd512vnni.c b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_n_fringe_bf16_amd512vnni.c index 7207cf41dc..1a37ab071a 100644 --- a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_n_fringe_bf16_amd512vnni.c 
+++ b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_n_fringe_bf16_amd512vnni.c @@ -38,6 +38,7 @@ #include "lpgemm_kernels.h" #include "lpgemm_f32_kern_macros.h" +#ifdef BLIS_KERNELS_ZEN4 // 6xlt16 bf16 fringe kernel LPGEMM_N_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6xlt16) { @@ -328,7 +329,7 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6xlt16) __m512 selector5 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + post_op_c_i + 4 ) ); - a_bf16_0 = + __m512 selector6 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + post_op_c_i + 5 ) ); @@ -348,7 +349,7 @@ LPGEMM_N_LT_NR0_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6xlt16) c_float_4p0 = _mm512_add_ps( selector5, c_float_4p0 ); // c[5,0-15] - c_float_5p0 = _mm512_add_ps( a_bf16_0, c_float_5p0 ); + c_float_5p0 = _mm512_add_ps( selector6, c_float_5p0 ); } POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR @@ -827,7 +828,7 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x16) __m512 selector5 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + post_op_c_i + 4 ) ); - a_bf16_0 = + __m512 selector6 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + post_op_c_i + 5 ) ); @@ -847,7 +848,7 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x16) c_float_4p0 = _mm512_add_ps( selector5, c_float_4p0 ); // c[5,0-15] - c_float_5p0 = _mm512_add_ps( a_bf16_0, c_float_5p0 ); + c_float_5p0 = _mm512_add_ps( selector6, c_float_5p0 ); } POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR @@ -1383,7 +1384,7 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x32) __m512 selector5 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + post_op_c_i + 4 ) ); - a_bf16_0 = + __m512 selector6 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + post_op_c_i + 5 ) ); @@ -1418,10 +1419,10 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x32) c_float_4p1 = _mm512_add_ps( selector5, c_float_4p1 ); // c[5,0-15] - c_float_5p0 = _mm512_add_ps( a_bf16_0, c_float_5p0 ); + c_float_5p0 = _mm512_add_ps( selector6, c_float_5p0 ); // c[5, 16-31] - c_float_5p1 = _mm512_add_ps( a_bf16_0, c_float_5p1 ); + c_float_5p1 = _mm512_add_ps( selector6, c_float_5p1 ); } POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR @@ -2112,7 +2113,7 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x48) __m512 selector5 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + post_op_c_i + 4 ) ); - a_bf16_0 = + __m512 selector6 = _mm512_set1_ps( *( ( float* )post_ops_list_temp->op_args1 + post_op_c_i + 5 ) ); @@ -2162,13 +2163,13 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x48) c_float_4p2 = _mm512_add_ps( selector5, c_float_4p2 ); // c[5,0-15] - c_float_5p0 = _mm512_add_ps( a_bf16_0, c_float_5p0 ); + c_float_5p0 = _mm512_add_ps( selector6, c_float_5p0 ); // c[5, 16-31] - c_float_5p1 = _mm512_add_ps( a_bf16_0, c_float_5p1 ); + c_float_5p1 = _mm512_add_ps( selector6, c_float_5p1 ); // c[5,32-47] - c_float_5p2 = _mm512_add_ps( a_bf16_0, c_float_5p2 ); + c_float_5p2 = _mm512_add_ps( selector6, c_float_5p2 ); } POST_OP_LABEL_LASTK_SAFE_JUMP_WITH_NEXT_PTR @@ -2498,3 +2499,4 @@ LPGEMM_N_FRINGE_KERN(bfloat16, bfloat16, float, bf16bf16f32of32_6x48) } } } +#endif diff --git a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_packb_bf16_amd512vnni.c b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_packb_bf16_amd512vnni.c index 19725b2768..374ac3280e 100644 --- 
a/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_packb_bf16_amd512vnni.c +++ b/addon/aocl_gemm/kernels/bf16bf16f32/lpgemm_packb_bf16_amd512vnni.c @@ -39,6 +39,17 @@ #include "lpgemm_config.h" #include "aocl_bf16_type.h" +void get_packb_nr64_bf16bf16f32of32_strides + ( + dim_t* rs_b, + dim_t* cs_b + ) +{ + *rs_b = lpgemm_get_block_size_NR_global_cntx( BF16BF16F32OF32 ) * 2; + *cs_b = lpgemm_get_block_size_NR_global_cntx( BF16BF16F32OF32 ) / 2; +} + +#ifdef BLIS_KERNELS_ZEN4 void packb_nrlt16_bf16bf16f32of32 ( bfloat16* pack_b_buffer_bf16bf16f32of32, @@ -54,7 +65,7 @@ void packb_nr16_bf16bf16f32of32 const bfloat16* b, const dim_t ldb, const dim_t KC - ); + ); void packb_nr32_bf16bf16f32of32 ( @@ -62,7 +73,7 @@ void packb_nr32_bf16bf16f32of32 const bfloat16* b, const dim_t ldb, const dim_t KC - ); + ); void packb_nr48_bf16bf16f32of32 ( @@ -70,17 +81,7 @@ void packb_nr48_bf16bf16f32of32 const bfloat16* b, const dim_t ldb, const dim_t KC - ); - -void get_packb_nr64_bf16bf16f32of32_strides - ( - dim_t* rs_b, - dim_t* cs_b - ) -{ - *rs_b = lpgemm_get_block_size_NR_global_cntx( BF16BF16F32OF32 ) * 2; - *cs_b = lpgemm_get_block_size_NR_global_cntx( BF16BF16F32OF32 ) / 2; -} + ); void packb_nr64_bf16bf16f32of32 ( @@ -91,7 +92,7 @@ void packb_nr64_bf16bf16f32of32 const dim_t KC, dim_t* rs_b, dim_t* cs_b - ) + ) { dim_t NR = 64; @@ -501,4 +502,5 @@ void packb_nrlt16_bf16bf16f32of32 _mm256_storeu_epi64( pack_b_buffer_bf16bf16f32of32 + ( ( kr_new + 0 ) * NR ), b0 ); _mm256_storeu_epi64( pack_b_buffer_bf16bf16f32of32 + ( ( kr_new + 1 ) * NR ), a0 ); } -} +} +#endif diff --git a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_6x64rowmajor_amd512vnni.c b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_6x64rowmajor_amd512vnni.c index ea5c8bf280..f249106a3c 100644 --- a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_6x64rowmajor_amd512vnni.c +++ b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_6x64rowmajor_amd512vnni.c @@ -38,6 +38,7 @@ #include "lpgemm_kernels.h" #include "lpgemm_s32_kern_macros.h" +#ifdef BLIS_KERNELS_ZEN4 // 6x64 int8o32 kernel LPGEMM_MAIN_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x64) { @@ -1057,3 +1058,4 @@ LPGEMM_MAIN_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x64) } } } +#endif diff --git a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_m_fringe_amd512vnni.c b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_m_fringe_amd512vnni.c index e6079cce23..1674a22bd0 100644 --- a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_m_fringe_amd512vnni.c +++ b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_m_fringe_amd512vnni.c @@ -39,6 +39,7 @@ #include "lpgemm_kernels.h" #include "lpgemm_s32_kern_macros.h" +#ifdef BLIS_KERNELS_ZEN4 // 5x64 int8o32 kernel LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5x64) { @@ -2358,3 +2359,4 @@ LPGEMM_M_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1x64) // c[0,48-63] _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 3*16 ), c_int32_0p3 ); } +#endif diff --git a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_mn_fringe_amd512vnni.c b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_mn_fringe_amd512vnni.c index 3900dcb82b..b202061e6a 100644 --- a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_mn_fringe_amd512vnni.c +++ b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_mn_fringe_amd512vnni.c @@ -39,6 +39,7 @@ #include "lpgemm_kernels.h" #include "lpgemm_s32_kern_macros.h" +#ifdef BLIS_KERNELS_ZEN4 // 5xlt16 int8o32 fringe kernel LPGEMM_MN_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_5xlt16) { @@ -5279,3 +5280,4 @@ LPGEMM_MN_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_1x48) // c[0,32-47] _mm512_storeu_epi32( c + ( rs_c * 0 ) + ( 2*16 ), c_int32_0p2 ); } +#endif diff 
--git a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_n_fringe_amd512vnni.c b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_n_fringe_amd512vnni.c index 609a0641db..856dc1355e 100644 --- a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_n_fringe_amd512vnni.c +++ b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_n_fringe_amd512vnni.c @@ -39,6 +39,7 @@ #include "lpgemm_kernels.h" #include "lpgemm_s32_kern_macros.h" +#ifdef BLIS_KERNELS_ZEN4 // 6xlt16 int8o32 fringe kernel LPGEMM_N_LT_NR0_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6xlt16) { @@ -2296,3 +2297,4 @@ LPGEMM_N_FRINGE_KERN(uint8_t,int8_t,int32_t,u8s8s32o32_6x48) } } } +#endif diff --git a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_packa_amd512vnni.c b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_packa_amd512vnni.c index 03bc7db03f..601b8a3eff 100644 --- a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_packa_amd512vnni.c +++ b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_packa_amd512vnni.c @@ -40,6 +40,17 @@ #define MR 6 #define NR 64 +void get_packa_k64_u8s8s32o32_strides + ( + dim_t* rs_a, + dim_t* cs_a + ) +{ + *rs_a = 4; + *cs_a = 24; +} + +#ifdef BLIS_KERNELS_ZEN4 void packa_m5_k64_u8s8s32o32 ( uint8_t* pack_a_buffer_u8s8s32o32, @@ -80,16 +91,6 @@ void packa_m1_k64_u8s8s32o32 const dim_t KC ); -void get_packa_k64_u8s8s32o32_strides - ( - dim_t* rs_a, - dim_t* cs_a - ) -{ - *rs_a = 4; - *cs_a = 24; -} - // TODO: k fringe till k=4, k%4=0 and padding to make k'%4 = 0 if k%4 != 0 originally. void packa_k64_u8s8s32o32 ( @@ -516,3 +517,4 @@ void packa_m1_k64_u8s8s32o32 _mm512_storeu_epi64( pack_a_buffer_u8s8s32o32 + ( ( kr * 1 ) + ( 0 ) ), a0 ); } } +#endif diff --git a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_packb_amd512vnni.c b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_packb_amd512vnni.c index 06a46afb44..d388c476e9 100644 --- a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_packb_amd512vnni.c +++ b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_packb_amd512vnni.c @@ -40,6 +40,17 @@ #define NR 64 +void get_packb_nr64_u8s8s32o32_strides + ( + dim_t* rs_b, + dim_t* cs_b + ) +{ + *rs_b = NR * 4; + *cs_b = NR; +} + +#ifdef BLIS_KERNELS_ZEN4 void packb_nrlt16_u8s8s32o32 ( int8_t* pack_b_buffer_u8s8s32o32, @@ -73,16 +84,6 @@ void packb_nr48_u8s8s32o32 const dim_t KC ); -void get_packb_nr64_u8s8s32o32_strides - ( - dim_t* rs_b, - dim_t* cs_b - ) -{ - *rs_b = NR * 4; - *cs_b = NR; -} - void packb_nr64_u8s8s32o32 ( int8_t* pack_b_buffer_u8s8s32o32, @@ -790,3 +791,4 @@ void packb_nrlt16_u8s8s32o32 _mm512_storeu_epi64( pack_b_buffer_u8s8s32o32 + ( ( kr_new + 0 ) * NR ), a0_zmm ); } } +#endif diff --git a/common.mk b/common.mk index 5912f53b49..220e8ccaa8 100644 --- a/common.mk +++ b/common.mk @@ -222,6 +222,7 @@ get-refinit-text-for = "('$(1)' CFLAGS for ref. kernel init)" get-refkern-text-for = "('$(1)' CFLAGS for ref. kernels)" get-config-text-for = "('$(1)' CFLAGS for config code)" get-frame-text-for = "('$(1)' CFLAGS for framework code)" +get-aocldtl-text-for = "('$(1)' CFLAGS for AOCL debug and trace code)" get-kernel-text-for = "('$(1)' CFLAGS for kernels)" get-addon-c99text-for = "('$(1)' CFLAGS for addons)" get-addon-cxxtext-for = "('$(1)' CXXFLAGS for addons)" diff --git a/config/zen4/make_defs.mk b/config/zen4/make_defs.mk index 501ace287c..062e680910 100644 --- a/config/zen4/make_defs.mk +++ b/config/zen4/make_defs.mk @@ -70,15 +70,18 @@ endif CKOPTFLAGS := $(COPTFLAGS) -fomit-frame-pointer ifeq ($(CC_VENDOR),gcc) GCC_VERSION := $(strip $(shell $(CC) -dumpversion | cut -d. 
-f1)) -# gcc or clang version must be atleast 4.0 -# gcc 12.0 or later: -ifeq ($(shell test $(GCC_VERSION) -ge 12; echo $$?),0) -CKVECFLAGS += -march=znver4 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mfpmath=sse -CRVECFLAGS += -march=znver4 -else + + # gcc 11.0 or later: ifeq ($(shell test $(GCC_VERSION) -ge 11; echo $$?),0) -CKVECFLAGS += -march=znver3 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mfpmath=sse +# Update CKOPTFLAGS for gcc 11+ to use O3 optimization without +# -ftree-partial-pre flag. This flag results in suboptimal code +# generation for instrinsics based kernels. +ifneq ($(DEBUG_TYPE),noopt) +CKOPTFLAGS := -O2 -fgcse-after-reload -fipa-cp-clone -floop-interchange -floop-unroll-and-jam -fpeel-loops -fpredictive-commoning -fsplit-loops -fsplit-paths -ftree-loop-distribution -funswitch-loops -fvect-cost-model=dynamic -fversion-loops-for-strides -fomit-frame-pointer +endif + +CKVECFLAGS += -march=znver3 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mavx512bf16 -mfpmath=sse CRVECFLAGS += -march=znver3 else # gcc 9.0 or later: @@ -97,7 +100,6 @@ CRVECFLAGS += -march=znver1 -mno-avx256-split-unaligned-store endif # GCC 8 endif # GCC 9 endif # GCC 11 -endif # GCC 12 else ifeq ($(CC_VENDOR),clang) @@ -114,12 +116,12 @@ ifeq ($(CC_VENDOR),clang) # for version 4x we will enable znver4 ifeq ($(strip $(shell $(CC) -v |&head -1 |grep -c 'AOCC_4')),1) -CKVECFLAGS += -march=znver4 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mfpmath=sse +CKVECFLAGS += -march=znver4 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512bf16 -mfpmath=sse CRVECFLAGS += -march=znver4 else # for version 3x we will enable znver3 ifeq ($(strip $(shell $(CC) -v |&head -1 |grep -c 'AOCC_3')),1) -CKVECFLAGS += -march=znver3 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mfpmath=sse +CKVECFLAGS += -march=znver3 -mavx512f -mavx512dq -mavx512bw -mavx512vl -mavx512vnni -mavx512bf16 -mfpmath=sse CRVECFLAGS += -march=znver3 else # for version 2x we will enable znver2 From ef62d2e0755b229473a54e60beb322e508d6c797 Mon Sep 17 00:00:00 2001 From: satish kumar nuggu Date: Wed, 14 Sep 2022 13:15:20 +0530 Subject: [PATCH 227/243] Addressed uninitialized variables in trsm small algo 1. Addressed uninitialized variables reported in coverity for all datatypes of trsm small algo. AMD-Internal: [CPUPL-2542] Change-Id: Ifae57ef6435493942732526720e6a9d6bec70e71 --- kernels/zen/3/bli_trsm_small.c | 46 ++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/kernels/zen/3/bli_trsm_small.c b/kernels/zen/3/bli_trsm_small.c index ffb00b83c1..8b6bfd9f67 100644 --- a/kernels/zen/3/bli_trsm_small.c +++ b/kernels/zen/3/bli_trsm_small.c @@ -8445,6 +8445,8 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB __m128d xmm5; + xmm5 = _mm_setzero_pd(); + /* Performs solving TRSM for 6 rows at a time from 0 to n/6 in steps of d_nr a. Load and pack A (a01 block), the size of packing 6x6 to 6x (n-6) @@ -10846,6 +10848,8 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB __m128d xmm5; + xmm5 = _mm_setzero_pd(); + /* Performs solving TRSM for 6 rows at a time from 0 to n/6 in steps of d_nr a. 
Load and pack A (a01 block), the size of packing 6x6 to 6x (n-6) @@ -13136,6 +13140,8 @@ BLIS_INLINE err_t bli_dtrsm_small_AltXB_AuXB __m128d xmm5; + xmm5 = _mm_setzero_pd(); + gint_t required_packing_A = 1; mem_t local_mem_buf_A_s = {0}; double *D_A_pack = NULL; @@ -15138,6 +15144,8 @@ BLIS_INLINE err_t bli_dtrsm_small_AutXB_AlXB __m128d xmm5; + xmm5 = _mm_setzero_pd(); + gint_t required_packing_A = 1; mem_t local_mem_buf_A_s = {0}; double *D_A_pack = NULL; @@ -17764,6 +17772,8 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB __m128 xmm5; + xmm5 = _mm_setzero_ps(); + /* Performs solving TRSM for 6 rows at a time from 0 to n/6 in steps of d_nr a. Load and pack A (a01 block), the size of packing 6x6 to 6x (n-6) @@ -21429,6 +21439,8 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB __m128 xmm5; + xmm5 = _mm_setzero_ps(); + /* Performs solving TRSM for 6 rows at a time from 0 to n/6 in steps of d_nr a. Load and pack A (a01 block), the size of packing 6x6 to 6x (n-6) @@ -33749,6 +33761,9 @@ BLIS_INLINE err_t bli_ztrsm_small_AutXB_AlXB __m128d xmm5, xmm4; + xmm4 = _mm_setzero_pd(); + xmm5 = _mm_setzero_pd(); + gint_t required_packing_A = 1; mem_t local_mem_buf_A_s = {0}; dcomplex *D_A_pack = NULL; @@ -34980,6 +34995,9 @@ BLIS_INLINE err_t bli_ztrsm_small_AltXB_AuXB __m128d xmm5, xmm4; + xmm4 = _mm_setzero_pd(); + xmm5 = _mm_setzero_pd(); + gint_t required_packing_A = 1; mem_t local_mem_buf_A_s = {0}; dcomplex *D_A_pack = NULL; @@ -36231,6 +36249,8 @@ BLIS_INLINE err_t bli_ztrsm_small_XAutB_XAlB __m128d xmm5; + xmm5 = _mm_setzero_pd(); + for(j = (n-d_nr); (j+1) > 0; j -= d_nr) //loop along 'N' direction { a01 = L + (j*rs_a) + (j+d_nr)*cs_a; @@ -37692,6 +37712,8 @@ BLIS_INLINE err_t bli_ztrsm_small_XAltB_XAuB __m128d xmm5; + xmm5 = _mm_setzero_pd(); + for(j = 0; (j+d_nr-1) < n; j += d_nr) //loop along 'N' direction { a01 = L + j*rs_a;//pointer to block of A to be used in GEMM @@ -42248,6 +42270,13 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB __m128 xmm0, xmm1, xmm2, xmm3, xmm4; __m128 xmm5; + xmm0 = _mm_setzero_ps(); + xmm1 = _mm_setzero_ps(); + xmm2 = _mm_setzero_ps(); + xmm3 = _mm_setzero_ps(); + xmm4 = _mm_setzero_ps(); + xmm5 = _mm_setzero_ps(); + gint_t required_packing_A = 1; mem_t local_mem_buf_A_s = {0}; scomplex *D_A_pack = NULL; @@ -44492,6 +44521,13 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB __m128 xmm0, xmm1, xmm2, xmm3, xmm4; __m128 xmm5; + xmm0 = _mm_setzero_ps(); + xmm1 = _mm_setzero_ps(); + xmm2 = _mm_setzero_ps(); + xmm3 = _mm_setzero_ps(); + xmm4 = _mm_setzero_ps(); + xmm5 = _mm_setzero_ps(); + gint_t required_packing_A = 1; mem_t local_mem_buf_A_s = {0}; scomplex *D_A_pack = NULL; @@ -46714,6 +46750,11 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB __m128 xmm0, xmm1, xmm2; __m128 xmm5; + xmm0 = _mm_setzero_ps(); + xmm1 = _mm_setzero_ps(); + xmm2 = _mm_setzero_ps(); + xmm5 = _mm_setzero_ps(); + gint_t required_packing_A = 1; mem_t local_mem_buf_A_s = {0}; scomplex *D_A_pack = NULL; @@ -48156,6 +48197,11 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB __m128 xmm0, xmm1, xmm2; __m128 xmm5; + xmm0 = _mm_setzero_ps(); + xmm1 = _mm_setzero_ps(); + xmm2 = _mm_setzero_ps(); + xmm5 = _mm_setzero_ps(); + gint_t required_packing_A = 1; mem_t local_mem_buf_A_s = {0}; scomplex *D_A_pack = NULL; From 6c5b15c38798c7c11fddaa11ab02c9ad1306ddc6 Mon Sep 17 00:00:00 2001 From: satish kumar nuggu Date: Thu, 29 Sep 2022 16:15:46 +0530 Subject: [PATCH 228/243] Fixed ASAN reported issues in [s/c]trsm small kernels Details: 1. 
Fixed the memory access paritial overflows for the variables AlphaVal,ones reported by ASAN. 2. Using 128 bit packed broadcast with the 64 bit data types after type casting would cause the garbage data to be filled in the destination register. 3. Fixed this issue by using set_ps instruction instead of broadcast. 4. In cases of n remainder being 1, extra elements were accessed that could cause out of memory access. Removed the extra element access. AMD-Internal: [CPUPL-2578][CPUPL-2587] Change-Id: Iaa918060c66287f2f46bcb9f69e9323f6707cf75 --- kernels/zen/3/bli_trsm_small.c | 1807 ++++++++++++++++++++++++++------ 1 file changed, 1481 insertions(+), 326 deletions(-) diff --git a/kernels/zen/3/bli_trsm_small.c b/kernels/zen/3/bli_trsm_small.c index 8b6bfd9f67..f5f7f37c6f 100644 --- a/kernels/zen/3/bli_trsm_small.c +++ b/kernels/zen/3/bli_trsm_small.c @@ -21044,8 +21044,6 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB xmm0 = _mm_loadu_ps((float *)(a01 + rs_a * 0 + loop_count*6)); _mm_storeu_ps((float *)(ptr_a10_dup + p_lda * 0 + loop_count*6), xmm0); - xmm0 = _mm_loadl_pi(xmm1,(__m64 *)(a01 + rs_a * 0 + 4 + loop_count*6)); - _mm_storel_pi((__m64 *)(ptr_a10_dup + p_lda * 0 + 4 + loop_count*6),xmm0); } } @@ -24872,8 +24870,6 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB xmm0 = _mm_loadu_ps((float *)(a01 + rs_a * 0 + loop_count*6)); _mm_storeu_ps((float *)(ptr_a10_dup + p_lda * 0 + loop_count*6), xmm0); - xmm0 = _mm_loadl_pi(xmm1,(__m64 *)(a01 + rs_a * 0 + 4 + loop_count*6)); - _mm_storel_pi((__m64 *)(ptr_a10_dup + p_lda * 0 + 4 + loop_count*6),xmm0); } } @@ -41889,7 +41885,7 @@ BLIS_INLINE void ctrsm_small_pack_diag_element ymm16 = _mm256_permute_ps(ymm16, 0x44);\ \ ymm0 = _mm256_loadu_ps((float const *)(b11));\ - ymm3 = _mm256_broadcast_ps((__m128 const *)&ones);\ + ymm3 = _mm256_broadcast_ps((__m128 const *)&ones_a);\ ymm3 = _mm256_permute_ps(ymm3, 0x44);\ /*in register transpose * ymm0,ymm1,ymm2 holds @@ -41939,7 +41935,7 @@ BLIS_INLINE void ctrsm_small_pack_diag_element \ ymm0 = _mm256_loadu_ps((float const *)(b11));\ ymm1 = _mm256_loadu_ps((float const *)(b11 + cs_b *1));\ - ymm3 = _mm256_broadcast_ps((__m128 const *)&ones);\ + ymm3 = _mm256_broadcast_ps((__m128 const *)&ones_a);\ ymm3 = _mm256_permute_ps(ymm3, 0x44);\ /*in register transpose * ymm0,ymm1,ymm2 holds @@ -41996,7 +41992,7 @@ BLIS_INLINE void ctrsm_small_pack_diag_element ymm0 = _mm256_loadu_ps((float const *)(b11));\ ymm1 = _mm256_loadu_ps((float const *)(b11 + cs_b *1));\ ymm2 = _mm256_loadu_ps((float const *)(b11 + cs_b *2));\ - ymm3 = _mm256_broadcast_ps((__m128 const *)&ones);\ + ymm3 = _mm256_broadcast_ps((__m128 const *)&ones_a);\ ymm3 = _mm256_permute_ps(ymm3, 0x44);\ /*in register transpose * ymm0,ymm1,ymm2 holds @@ -42059,7 +42055,7 @@ BLIS_INLINE void ctrsm_small_pack_diag_element ymm0 = _mm256_loadu_ps((float const *)(b11));\ ymm1 = _mm256_loadu_ps((float const *)(b11 + cs_b *1));\ ymm2 = _mm256_loadu_ps((float const *)(b11 + cs_b *2));\ - ymm3 = _mm256_broadcast_ps((__m128 const *)&ones);\ + ymm3 = _mm256_broadcast_ps((__m128 const *)&ones_a);\ ymm3 = _mm256_permute_ps(ymm3, 0x44);\ /*in register transpose * ymm0,ymm1,ymm2 holds @@ -42216,7 +42212,6 @@ BLIS_INLINE void ctrsm_small_pack_diag_element _mm256_storeu_ps((float *)(b11 + cs_b * 2 + 4), ymm2);\ } - BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ( obj_t* AlphaObj, @@ -42251,13 +42246,17 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB dim_t i, j, k; //loop variables dim_t k_iter; //number of times GEMM to be performed - scomplex AlphaVal = *(scomplex 
*)AlphaObj->buffer; //value of alpha - scomplex *L = a->buffer; //pointer to matrix A - scomplex *B = b->buffer; //pointer to matrix B + scomplex AlphaVal[2]; + AlphaVal[0] = *(scomplex *)AlphaObj->buffer; //value of alpha + AlphaVal[1] = *(scomplex *)AlphaObj->buffer; //value of alpha + + scomplex *L = bli_obj_buffer_at_off(a); //pointer to matrix A + scomplex *B = bli_obj_buffer_at_off(b); //pointer to matrix B scomplex *a10, *a11, *b01, *b11; //pointers that point to blocks for GEMM and TRSM - scomplex ones = {1.0, 1.0}; + float ones = 1.0; + float ones_a[4] = {1.0, 1.0,1.0,1.0}; bool is_unitdiag = bli_obj_has_unit_diag(a); //scratch registers @@ -42270,14 +42269,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB __m128 xmm0, xmm1, xmm2, xmm3, xmm4; __m128 xmm5; - xmm0 = _mm_setzero_ps(); + xmm0 = _mm_setzero_ps(); xmm1 = _mm_setzero_ps(); xmm2 = _mm_setzero_ps(); xmm3 = _mm_setzero_ps(); xmm4 = _mm_setzero_ps(); xmm5 = _mm_setzero_ps(); - - gint_t required_packing_A = 1; + + gint_t required_packing_A = 1; mem_t local_mem_buf_A_s = {0}; scomplex *D_A_pack = NULL; scomplex d11_pack[d_mr] __attribute__((aligned(64))); @@ -42306,8 +42305,8 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); if(NULL==D_A_pack) return BLIS_NULL_POINTER; } - - /* + + /* Performs solving TRSM for 4 colmns at a time from 0 to m/4 in steps of d_mr a. Load, transpose, Pack A (a10 block), the size of packing 4x3 to 4x (m-4) First there will be no GEMM and no packing of a10 because it is only TRSM @@ -42392,14 +42391,20 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB */ ////extract a00 ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm8) #else BLIS_CTRSM_MUL(ymm8) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*1) ); + ymm2 = _mm256_set_ps((a11 + cs_a*1)->imag,(a11 + cs_a*1)->real, + (a11 + cs_a*1)->imag,(a11 + cs_a*1)->real, + (a11 + cs_a*1)->imag,(a11 + cs_a*1)->real, + (a11 + cs_a*1)->imag,(a11 + cs_a*1)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -42416,7 +42421,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm9 = _mm256_sub_ps(ymm9,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2) ); + ymm2 = _mm256_set_ps((a11 + cs_a*2)->imag,(a11 + cs_a*2)->real, + (a11 + cs_a*2)->imag,(a11 + cs_a*2)->real, + (a11 + cs_a*2)->imag,(a11 + cs_a*2)->real, + (a11 + cs_a*2)->imag,(a11 + cs_a*2)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -42431,7 +42439,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm10 = _mm256_sub_ps(ymm10,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*3) ); + ymm2 = _mm256_set_ps((a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -42446,7 +42457,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm11 = _mm256_sub_ps(ymm11,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*4) ); + ymm2 = 
_mm256_set_ps((a11 + cs_a*4)->imag,(a11 + cs_a*4)->real, + (a11 + cs_a*4)->imag,(a11 + cs_a*4)->real, + (a11 + cs_a*4)->imag,(a11 + cs_a*4)->real, + (a11 + cs_a*4)->imag,(a11 + cs_a*4)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -42461,7 +42475,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm12 = _mm256_sub_ps(ymm12,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*5) ); + ymm2 = _mm256_set_ps((a11 + cs_a*5)->imag,(a11 + cs_a*5)->real, + (a11 + cs_a*5)->imag,(a11 + cs_a*5)->real, + (a11 + cs_a*5)->imag,(a11 + cs_a*5)->real, + (a11 + cs_a*5)->imag,(a11 + cs_a*5)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -42476,7 +42493,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm13 = _mm256_sub_ps(ymm13,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*6) ); + ymm2 = _mm256_set_ps((a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -42491,7 +42511,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm14 = _mm256_sub_ps(ymm14,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*7) ); + ymm2 = _mm256_set_ps((a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -42506,7 +42529,11 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm15 = _mm256_sub_ps(ymm15,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_set_ps((d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real); + ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm9) @@ -42517,7 +42544,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB a11 += rs_a; - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2) ); + ymm2 = _mm256_set_ps((a11 + cs_a*2)->imag,(a11 + cs_a*2)->real, + (a11 + cs_a*2)->imag,(a11 + cs_a*2)->real, + (a11 + cs_a*2)->imag,(a11 + cs_a*2)->real, + (a11 + cs_a*2)->imag,(a11 + cs_a*2)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -42532,7 +42562,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm10 = _mm256_sub_ps(ymm10,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*3) ); + ymm2 = _mm256_set_ps((a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -42547,7 +42580,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm11 = _mm256_sub_ps(ymm11,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*4) ); + ymm2 = _mm256_set_ps((a11 + cs_a*4)->imag,(a11 + cs_a*4)->real, + (a11 + cs_a*4)->imag,(a11 + cs_a*4)->real, + (a11 + cs_a*4)->imag,(a11 + cs_a*4)->real, + (a11 + cs_a*4)->imag,(a11 + cs_a*4)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -42562,7 
+42598,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm12 = _mm256_sub_ps(ymm12,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*5) ); + ymm2 = _mm256_set_ps((a11 + cs_a*5)->imag,(a11 + cs_a*5)->real, + (a11 + cs_a*5)->imag,(a11 + cs_a*5)->real, + (a11 + cs_a*5)->imag,(a11 + cs_a*5)->real, + (a11 + cs_a*5)->imag,(a11 + cs_a*5)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -42577,7 +42616,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm13 = _mm256_sub_ps(ymm13,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*6) ); + ymm2 = _mm256_set_ps((a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -42592,7 +42634,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm14 = _mm256_sub_ps(ymm14,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*7) ); + ymm2 = _mm256_set_ps((a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -42608,7 +42653,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm15 = _mm256_sub_ps(ymm15,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); + ymm1 = _mm256_set_ps((d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION @@ -42620,7 +42668,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB a11 += rs_a; - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*3) ); + ymm2 = _mm256_set_ps((a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -42635,7 +42686,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm11 = _mm256_sub_ps(ymm11,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*4) ); + ymm2 = _mm256_set_ps((a11 + cs_a*4)->imag,(a11 + cs_a*4)->real, + (a11 + cs_a*4)->imag,(a11 + cs_a*4)->real, + (a11 + cs_a*4)->imag,(a11 + cs_a*4)->real, + (a11 + cs_a*4)->imag,(a11 + cs_a*4)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -42650,7 +42704,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm12 = _mm256_sub_ps(ymm12,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*5) ); + ymm2 = _mm256_set_ps((a11 + cs_a*5)->imag,(a11 + cs_a*5)->real, + (a11 + cs_a*5)->imag,(a11 + cs_a*5)->real, + (a11 + cs_a*5)->imag,(a11 + cs_a*5)->real, + (a11 + cs_a*5)->imag,(a11 + cs_a*5)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -42665,7 +42722,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm13 = _mm256_sub_ps(ymm13,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*6) ); + ymm2 = _mm256_set_ps((a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + 
cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -42680,7 +42740,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm14 = _mm256_sub_ps(ymm14,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*7) ); + ymm2 = _mm256_set_ps((a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -42696,7 +42759,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm15 = _mm256_sub_ps(ymm15,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 3)); + ymm1 = _mm256_set_ps((d11_pack + 3)->imag,(d11_pack + 3)->real, + (d11_pack + 3)->imag,(d11_pack + 3)->real, + (d11_pack + 3)->imag,(d11_pack + 3)->real, + (d11_pack + 3)->imag,(d11_pack + 3)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm11) @@ -42707,7 +42773,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB a11 += rs_a; - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*4) ); + ymm2 = _mm256_set_ps((a11 + cs_a*4)->imag,(a11 + cs_a*4)->real, + (a11 + cs_a*4)->imag,(a11 + cs_a*4)->real, + (a11 + cs_a*4)->imag,(a11 + cs_a*4)->real, + (a11 + cs_a*4)->imag,(a11 + cs_a*4)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -42722,7 +42791,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm12 = _mm256_sub_ps(ymm12,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*5) ); + ymm2 = _mm256_set_ps((a11 + cs_a*5)->imag,(a11 + cs_a*5)->real, + (a11 + cs_a*5)->imag,(a11 + cs_a*5)->real, + (a11 + cs_a*5)->imag,(a11 + cs_a*5)->real, + (a11 + cs_a*5)->imag,(a11 + cs_a*5)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -42737,7 +42809,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm13 = _mm256_sub_ps(ymm13,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*6) ); + ymm2 = _mm256_set_ps((a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -42752,7 +42827,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm14 = _mm256_sub_ps(ymm14,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*7) ); + ymm2 = _mm256_set_ps((a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -42768,7 +42846,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm15 = _mm256_sub_ps(ymm15,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 4)); + ymm1 = _mm256_set_ps((d11_pack + 4)->imag,(d11_pack + 4)->real, + (d11_pack + 4)->imag,(d11_pack + 4)->real, + (d11_pack + 4)->imag,(d11_pack + 4)->real, + (d11_pack + 4)->imag,(d11_pack + 4)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm12) @@ -42777,7 +42858,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB #endif a11 += rs_a; - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*5) ); + ymm2 = 
_mm256_set_ps((a11 + cs_a*5)->imag,(a11 + cs_a*5)->real, + (a11 + cs_a*5)->imag,(a11 + cs_a*5)->real, + (a11 + cs_a*5)->imag,(a11 + cs_a*5)->real, + (a11 + cs_a*5)->imag,(a11 + cs_a*5)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -42792,7 +42876,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm13 = _mm256_sub_ps(ymm13,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*6) ); + ymm2 = _mm256_set_ps((a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -42807,7 +42894,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm14 = _mm256_sub_ps(ymm14,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*7) ); + ymm2 = _mm256_set_ps((a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -42823,7 +42913,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm15 = _mm256_sub_ps(ymm15,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 5)); + ymm1 = _mm256_set_ps((d11_pack + 5)->imag,(d11_pack + 5)->real, + (d11_pack + 5)->imag,(d11_pack + 5)->real, + (d11_pack + 5)->imag,(d11_pack + 5)->real, + (d11_pack + 5)->imag,(d11_pack + 5)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm13) @@ -42833,7 +42926,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB a11 += rs_a; - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*6) ); + ymm2 = _mm256_set_ps((a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -42848,7 +42944,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm14 = _mm256_sub_ps(ymm14,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*7) ); + ymm2 = _mm256_set_ps((a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -42864,7 +42963,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm15 = _mm256_sub_ps(ymm15,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 6)); + ymm1 = _mm256_set_ps((d11_pack + 6)->imag,(d11_pack + 6)->real, + (d11_pack + 6)->imag,(d11_pack + 6)->real, + (d11_pack + 6)->imag,(d11_pack + 6)->real, + (d11_pack + 6)->imag,(d11_pack + 6)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm14) @@ -42874,7 +42976,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB a11 += rs_a; - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*7) ); + ymm2 = _mm256_set_ps((a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -42890,7 +42995,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm15 = _mm256_sub_ps(ymm15,ymm16); - ymm1 = _mm256_broadcast_ps(( 
__m128 const *)(d11_pack + 7)); + ymm1 = _mm256_set_ps((d11_pack + 7)->imag,(d11_pack + 7)->real, + (d11_pack + 7)->imag,(d11_pack + 7)->real, + (d11_pack + 7)->imag,(d11_pack + 7)->real, + (d11_pack + 7)->imag,(d11_pack + 7)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm15) @@ -42919,8 +43027,8 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB BLIS_CTRSM_SMALL_GEMM_8mx2n(a10,b01,cs_b,p_lda,k_iter) float zero = 0.0; - ymm16 = _mm256_broadcast_ss(&AlphaVal.real); - ymm17 = _mm256_broadcast_ss(&AlphaVal.imag); + ymm16 = _mm256_broadcast_ss(&AlphaVal[0].real); + ymm17 = _mm256_broadcast_ss(&AlphaVal[0].imag); ymm2 = _mm256_broadcast_ss(&zero); ymm3 = _mm256_broadcast_ss(&zero); ymm6 = _mm256_broadcast_ss(&zero); @@ -42971,8 +43079,8 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB BLIS_CTRSM_SMALL_GEMM_8mx1n(a10,b01,cs_b,p_lda,k_iter) float zero = 0.0; - ymm16 = _mm256_broadcast_ss(&AlphaVal.real); - ymm17 = _mm256_broadcast_ss(&AlphaVal.imag); + ymm16 = _mm256_broadcast_ss(&AlphaVal[0].real); + ymm17 = _mm256_broadcast_ss(&AlphaVal[0].imag); ymm2 = _mm256_broadcast_ss(&zero); ymm3 = _mm256_broadcast_ss(&zero); ymm6 = _mm256_broadcast_ss(&zero); @@ -43040,14 +43148,20 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm15 = _mm256_permute2f128_ps(ymm18,ymm19,0x31); ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm8) #else BLIS_CTRSM_MUL(ymm8) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*1) ); + ymm2 = _mm256_set_ps((a11 + cs_a*1)->imag,(a11 + cs_a*1)->real, + (a11 + cs_a*1)->imag,(a11 + cs_a*1)->real, + (a11 + cs_a*1)->imag,(a11 + cs_a*1)->real, + (a11 + cs_a*1)->imag,(a11 + cs_a*1)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -43064,7 +43178,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm9 = _mm256_sub_ps(ymm9,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2) ); + ymm2 = _mm256_set_ps((a11 + cs_a*2)->imag,(a11 + cs_a*2)->real, + (a11 + cs_a*2)->imag,(a11 + cs_a*2)->real, + (a11 + cs_a*2)->imag,(a11 + cs_a*2)->real, + (a11 + cs_a*2)->imag,(a11 + cs_a*2)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -43079,7 +43196,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm10 = _mm256_sub_ps(ymm10,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*3) ); + ymm2 = _mm256_set_ps((a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -43094,7 +43214,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm11 = _mm256_sub_ps(ymm11,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*4) ); + ymm2 = _mm256_set_ps((a11 + cs_a*4)->imag,(a11 + cs_a*4)->real, + (a11 + cs_a*4)->imag,(a11 + cs_a*4)->real, + (a11 + cs_a*4)->imag,(a11 + cs_a*4)->real, + (a11 + cs_a*4)->imag,(a11 + cs_a*4)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -43109,7 +43232,10 @@ BLIS_INLINE err_t 
bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm12 = _mm256_sub_ps(ymm12,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*5) ); + ymm2 = _mm256_set_ps((a11 + cs_a*5)->imag,(a11 + cs_a*5)->real, + (a11 + cs_a*5)->imag,(a11 + cs_a*5)->real, + (a11 + cs_a*5)->imag,(a11 + cs_a*5)->real, + (a11 + cs_a*5)->imag,(a11 + cs_a*5)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -43124,7 +43250,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm13 = _mm256_sub_ps(ymm13,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*6) ); + ymm2 = _mm256_set_ps((a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -43139,7 +43268,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm14 = _mm256_sub_ps(ymm14,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*7) ); + ymm2 = _mm256_set_ps((a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -43154,7 +43286,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm15 = _mm256_sub_ps(ymm15,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_set_ps((d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm9) @@ -43165,7 +43300,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB a11 += rs_a; - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2) ); + ymm2 = _mm256_set_ps((a11 + cs_a*2)->imag,(a11 + cs_a*2)->real, + (a11 + cs_a*2)->imag,(a11 + cs_a*2)->real, + (a11 + cs_a*2)->imag,(a11 + cs_a*2)->real, + (a11 + cs_a*2)->imag,(a11 + cs_a*2)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -43180,7 +43318,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm10 = _mm256_sub_ps(ymm10,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*3) ); + ymm2 = _mm256_set_ps((a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -43195,7 +43336,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm11 = _mm256_sub_ps(ymm11,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*4) ); + ymm2 = _mm256_set_ps((a11 + cs_a*4)->imag,(a11 + cs_a*4)->real, + (a11 + cs_a*4)->imag,(a11 + cs_a*4)->real, + (a11 + cs_a*4)->imag,(a11 + cs_a*4)->real, + (a11 + cs_a*4)->imag,(a11 + cs_a*4)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -43210,7 +43354,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm12 = _mm256_sub_ps(ymm12,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*5) ); + ymm2 = _mm256_set_ps((a11 + cs_a*5)->imag,(a11 + cs_a*5)->real, + (a11 + 
cs_a*5)->imag,(a11 + cs_a*5)->real, + (a11 + cs_a*5)->imag,(a11 + cs_a*5)->real, + (a11 + cs_a*5)->imag,(a11 + cs_a*5)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -43225,7 +43372,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm13 = _mm256_sub_ps(ymm13,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*6) ); + ymm2 = _mm256_set_ps((a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -43240,7 +43390,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm14 = _mm256_sub_ps(ymm14,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*7) ); + ymm2 = _mm256_set_ps((a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -43256,7 +43409,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm15 = _mm256_sub_ps(ymm15,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); + ymm1 = _mm256_set_ps((d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION @@ -43268,7 +43424,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB a11 += rs_a; - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*3) ); + ymm2 = _mm256_set_ps((a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -43283,7 +43442,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm11 = _mm256_sub_ps(ymm11,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*4) ); + ymm2 = _mm256_set_ps((a11 + cs_a*4)->imag,(a11 + cs_a*4)->real, + (a11 + cs_a*4)->imag,(a11 + cs_a*4)->real, + (a11 + cs_a*4)->imag,(a11 + cs_a*4)->real, + (a11 + cs_a*4)->imag,(a11 + cs_a*4)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -43298,7 +43460,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm12 = _mm256_sub_ps(ymm12,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*5) ); + ymm2 = _mm256_set_ps((a11 + cs_a*5)->imag,(a11 + cs_a*5)->real, + (a11 + cs_a*5)->imag,(a11 + cs_a*5)->real, + (a11 + cs_a*5)->imag,(a11 + cs_a*5)->real, + (a11 + cs_a*5)->imag,(a11 + cs_a*5)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -43313,7 +43478,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm13 = _mm256_sub_ps(ymm13,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*6) ); + ymm2 = _mm256_set_ps((a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -43328,7 +43496,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm14 = 
_mm256_sub_ps(ymm14,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*7) ); + ymm2 = _mm256_set_ps((a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -43344,7 +43515,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm15 = _mm256_sub_ps(ymm15,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 3)); + ymm1 = _mm256_set_ps((d11_pack + 3)->imag,(d11_pack + 3)->real, + (d11_pack + 3)->imag,(d11_pack + 3)->real, + (d11_pack + 3)->imag,(d11_pack + 3)->real, + (d11_pack + 3)->imag,(d11_pack + 3)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm11) @@ -43355,7 +43529,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB a11 += rs_a; - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*4) ); + ymm2 = _mm256_set_ps((a11 + cs_a*4)->imag,(a11 + cs_a*4)->real, + (a11 + cs_a*4)->imag,(a11 + cs_a*4)->real, + (a11 + cs_a*4)->imag,(a11 + cs_a*4)->real, + (a11 + cs_a*4)->imag,(a11 + cs_a*4)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -43370,7 +43547,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm12 = _mm256_sub_ps(ymm12,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*5) ); + ymm2 = _mm256_set_ps((a11 + cs_a*5)->imag,(a11 + cs_a*5)->real, + (a11 + cs_a*5)->imag,(a11 + cs_a*5)->real, + (a11 + cs_a*5)->imag,(a11 + cs_a*5)->real, + (a11 + cs_a*5)->imag,(a11 + cs_a*5)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -43385,7 +43565,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm13 = _mm256_sub_ps(ymm13,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*6) ); + ymm2 = _mm256_set_ps((a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -43400,7 +43583,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm14 = _mm256_sub_ps(ymm14,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*7) ); + ymm2 = _mm256_set_ps((a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -43416,7 +43602,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm15 = _mm256_sub_ps(ymm15,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 4)); + ymm1 = _mm256_set_ps((d11_pack + 4)->imag,(d11_pack + 4)->real, + (d11_pack + 4)->imag,(d11_pack + 4)->real, + (d11_pack + 4)->imag,(d11_pack + 4)->real, + (d11_pack + 4)->imag,(d11_pack + 4)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm12) @@ -43425,7 +43614,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB #endif a11 += rs_a; - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*5) ); + ymm2 = _mm256_set_ps((a11 + cs_a*5)->imag,(a11 + cs_a*5)->real, + (a11 + cs_a*5)->imag,(a11 + cs_a*5)->real, + (a11 + cs_a*5)->imag,(a11 + cs_a*5)->real, + (a11 + cs_a*5)->imag,(a11 + cs_a*5)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -43440,7 +43632,10 @@ 
BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm13 = _mm256_sub_ps(ymm13,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*6) ); + ymm2 = _mm256_set_ps((a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -43455,7 +43650,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm14 = _mm256_sub_ps(ymm14,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*7) ); + ymm2 = _mm256_set_ps((a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -43471,7 +43669,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm15 = _mm256_sub_ps(ymm15,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 5)); + ymm1 = _mm256_set_ps((d11_pack + 5)->imag,(d11_pack + 5)->real, + (d11_pack + 5)->imag,(d11_pack + 5)->real, + (d11_pack + 5)->imag,(d11_pack + 5)->real, + (d11_pack + 5)->imag,(d11_pack + 5)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm13) @@ -43481,7 +43682,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB a11 += rs_a; - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*6) ); + ymm2 = _mm256_set_ps((a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real, + (a11 + cs_a*6)->imag,(a11 + cs_a*6)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -43496,7 +43700,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm14 = _mm256_sub_ps(ymm14,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*7) ); + ymm2 = _mm256_set_ps((a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -43512,7 +43719,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm15 = _mm256_sub_ps(ymm15,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 6)); + ymm1 = _mm256_set_ps((d11_pack + 6)->imag,(d11_pack + 6)->real, + (d11_pack + 6)->imag,(d11_pack + 6)->real, + (d11_pack + 6)->imag,(d11_pack + 6)->real, + (d11_pack + 6)->imag,(d11_pack + 6)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm14) @@ -43522,7 +43732,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB a11 += rs_a; - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*7) ); + ymm2 = _mm256_set_ps((a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real, + (a11 + cs_a*7)->imag,(a11 + cs_a*7)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -43538,7 +43751,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm15 = _mm256_sub_ps(ymm15,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 7)); + ymm1 = _mm256_set_ps((d11_pack + 7)->imag,(d11_pack + 7)->real, + (d11_pack + 7)->imag,(d11_pack + 7)->real, + (d11_pack + 7)->imag,(d11_pack + 7)->real, + (d11_pack + 7)->imag,(d11_pack + 7)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); 
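[Reviewer note, not part of the patch] The change repeated throughout these hunks swaps _mm256_broadcast_ps, which loads 16 bytes (two scomplex elements) starting at the given address, for _mm256_set_ps, which dereferences only the one element's real and imag fields. After the existing _mm256_permute_ps(..., 0x44) (which copies lanes {0,1} into {0,1,2,3} of each 128-bit half), both forms leave the same (real, imag, real, imag) pattern in the register for in-bounds data; the set_ps form simply avoids reading past the last element of A. The sketch below is illustrative only: it assumes an AVX-capable build (e.g. -mavx) and uses a local stand-in struct for the BLIS scomplex type.

#include <immintrin.h>
#include <stdio.h>

/* Stand-in for the BLIS scomplex type (a pair of floats), for illustration only. */
typedef struct { float real; float imag; } scomplex_t;

int main(void)
{
    scomplex_t a[2] = { {1.0f, 2.0f}, {3.0f, 4.0f} };

    /* Old form: load two complex elements (16 bytes) into both halves, then the
       0x44 permute keeps only element 0's (real, imag) pair in each lane. */
    __m256 old_v = _mm256_broadcast_ps((__m128 const *)a);
    old_v = _mm256_permute_ps(old_v, 0x44);

    /* New form: touch only a[0].real / a[0].imag; the register already holds the
       (real, imag) pair repeated, so the same permute leaves it unchanged. */
    __m256 new_v = _mm256_set_ps(a[0].imag, a[0].real, a[0].imag, a[0].real,
                                 a[0].imag, a[0].real, a[0].imag, a[0].real);
    new_v = _mm256_permute_ps(new_v, 0x44);

    float o[8], n[8];
    _mm256_storeu_ps(o, old_v);
    _mm256_storeu_ps(n, new_v);
    for (int i = 0; i < 8; i++)
        printf("%f %f\n", o[i], n[i]);  /* both columns print 1 2 1 2 1 2 1 2 */
    return 0;
}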
#ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm15) @@ -43654,14 +43870,20 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB BLIS_CTRSM_SMALL_NREG_TRANSPOSE_3x4(b11,cs_b,AlphaVal) ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm8) #else BLIS_CTRSM_MUL(ymm8) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*1) ); + ymm2 = _mm256_set_ps((a11 + cs_a*1)->imag,(a11 + cs_a*1)->real, + (a11 + cs_a*1)->imag,(a11 + cs_a*1)->real, + (a11 + cs_a*1)->imag,(a11 + cs_a*1)->real, + (a11 + cs_a*1)->imag,(a11 + cs_a*1)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -43677,7 +43899,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm9 = _mm256_sub_ps(ymm9,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2) ); + ymm2 = _mm256_set_ps((a11 + cs_a*2)->imag,(a11 + cs_a*2)->real, + (a11 + cs_a*2)->imag,(a11 + cs_a*2)->real, + (a11 + cs_a*2)->imag,(a11 + cs_a*2)->real, + (a11 + cs_a*2)->imag,(a11 + cs_a*2)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -43691,7 +43916,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm10 = _mm256_sub_ps(ymm10,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*3) ); + ymm2 = _mm256_set_ps((a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -43706,7 +43934,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm11 = _mm256_sub_ps(ymm11,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_set_ps((d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm9) @@ -43717,7 +43948,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB a11 += rs_a; - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2) ); + ymm2 = _mm256_set_ps((a11 + cs_a*2)->imag,(a11 + cs_a*2)->real, + (a11 + cs_a*2)->imag,(a11 + cs_a*2)->real, + (a11 + cs_a*2)->imag,(a11 + cs_a*2)->real, + (a11 + cs_a*2)->imag,(a11 + cs_a*2)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -43731,7 +43965,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm10 = _mm256_sub_ps(ymm10,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*3) ); + ymm2 = _mm256_set_ps((a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -43747,7 +43984,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); + ymm1 = _mm256_set_ps((d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 
2)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION @@ -43759,7 +43999,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB a11 += rs_a; - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*3) ); + ymm2 = _mm256_set_ps((a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -43775,7 +44018,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 3)); + ymm1 = _mm256_set_ps((d11_pack + 3)->imag,(d11_pack + 3)->real, + (d11_pack + 3)->imag,(d11_pack + 3)->real, + (d11_pack + 3)->imag,(d11_pack + 3)->real, + (d11_pack + 3)->imag,(d11_pack + 3)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm11) @@ -43803,7 +44049,7 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ///GEMM code begins/// BLIS_CTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b,p_lda,k_iter) BLIS_CTRSM_SMALL_NREG_TRANSPOSE_2x4(b11,cs_b,AlphaVal) - } + } else if(1 == n_rem) { ///GEMM code begins/// @@ -43813,14 +44059,20 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm8) #else BLIS_CTRSM_MUL(ymm8) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*1) ); + ymm2 = _mm256_set_ps((a11 + cs_a*1)->imag,(a11 + cs_a*1)->real, + (a11 + cs_a*1)->imag,(a11 + cs_a*1)->real, + (a11 + cs_a*1)->imag,(a11 + cs_a*1)->real, + (a11 + cs_a*1)->imag,(a11 + cs_a*1)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -43836,7 +44088,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm9 = _mm256_sub_ps(ymm9,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2) ); + ymm2 = _mm256_set_ps((a11 + cs_a*2)->imag,(a11 + cs_a*2)->real, + (a11 + cs_a*2)->imag,(a11 + cs_a*2)->real, + (a11 + cs_a*2)->imag,(a11 + cs_a*2)->real, + (a11 + cs_a*2)->imag,(a11 + cs_a*2)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -43850,7 +44105,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm10 = _mm256_sub_ps(ymm10,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*3) ); + ymm2 = _mm256_set_ps((a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -43865,7 +44123,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm11 = _mm256_sub_ps(ymm11,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_set_ps((d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm9) @@ -43876,7 +44137,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB a11 += rs_a; - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + 
cs_a*2) ); + ymm2 = _mm256_set_ps((a11 + cs_a*2)->imag,(a11 + cs_a*2)->real, + (a11 + cs_a*2)->imag,(a11 + cs_a*2)->real, + (a11 + cs_a*2)->imag,(a11 + cs_a*2)->real, + (a11 + cs_a*2)->imag,(a11 + cs_a*2)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -43890,7 +44154,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm10 = _mm256_sub_ps(ymm10,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*3) ); + ymm2 = _mm256_set_ps((a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -43906,7 +44173,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); + ymm1 = _mm256_set_ps((d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION @@ -43918,7 +44188,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB a11 += rs_a; - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*3) ); + ymm2 = _mm256_set_ps((a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real, + (a11 + cs_a*3)->imag,(a11 + cs_a*3)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -43934,7 +44207,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 3)); + ymm1 = _mm256_set_ps((d11_pack + 3)->imag,(d11_pack + 3)->real, + (d11_pack + 3)->imag,(d11_pack + 3)->real, + (d11_pack + 3)->imag,(d11_pack + 3)->real, + (d11_pack + 3)->imag,(d11_pack + 3)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm11) @@ -44502,13 +44778,17 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB dim_t i, j, k; //loop variables dim_t k_iter; //number of times GEMM to be performed - scomplex AlphaVal = *(scomplex *)AlphaObj->buffer; //value of alpha - scomplex *L = a->buffer; //pointer to matrix A - scomplex *B = b->buffer; //pointer to matrix B + scomplex AlphaVal[2]; + AlphaVal[0] = *(scomplex *)AlphaObj->buffer; //value of alpha + AlphaVal[1] = *(scomplex *)AlphaObj->buffer; //value of alpha + + scomplex *L = bli_obj_buffer_at_off(a); //pointer to matrix A + scomplex *B = bli_obj_buffer_at_off(b); //pointer to matrix B scomplex *a10, *a11, *b01, *b11; //pointers that point to blocks for GEMM and TRSM - scomplex ones = {1.0, 1.0}; + float ones = 1.0; + float ones_a[4] = {1.0, 1.0,1.0,1.0}; bool is_unitdiag = bli_obj_has_unit_diag(a); //scratch registers @@ -44521,7 +44801,7 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB __m128 xmm0, xmm1, xmm2, xmm3, xmm4; __m128 xmm5; - xmm0 = _mm_setzero_ps(); + xmm0 = _mm_setzero_ps(); xmm1 = _mm_setzero_ps(); xmm2 = _mm_setzero_ps(); xmm3 = _mm_setzero_ps(); @@ -44645,14 +44925,24 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB */ ////extract a00 ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 7)); + ymm1 = _mm256_set_ps((d11_pack + 7)->imag,(d11_pack + 7)->real, + (d11_pack + 7)->imag,(d11_pack + 7)->real, + (d11_pack + 7)->imag,(d11_pack + 7)->real, + (d11_pack + 7)->imag,(d11_pack + 7)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef 
BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm15) #else BLIS_CTRSM_MUL(ymm15) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*6 + 7*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*6 + 7*rs_a)->imag, + (a11 + cs_a*6 + 7*rs_a)->real, + (a11 + cs_a*6 + 7*rs_a)->imag, + (a11 + cs_a*6 + 7*rs_a)->real, + (a11 + cs_a*6 + 7*rs_a)->imag, + (a11 + cs_a*6 + 7*rs_a)->real, + (a11 + cs_a*6 + 7*rs_a)->imag, + (a11 + cs_a*6 + 7*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -44668,7 +44958,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm14 = _mm256_sub_ps(ymm14,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*5 + 7*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*5 + 7*rs_a)->imag, + (a11 + cs_a*5 + 7*rs_a)->real, + (a11 + cs_a*5 + 7*rs_a)->imag, + (a11 + cs_a*5 + 7*rs_a)->real, + (a11 + cs_a*5 + 7*rs_a)->imag, + (a11 + cs_a*5 + 7*rs_a)->real, + (a11 + cs_a*5 + 7*rs_a)->imag, + (a11 + cs_a*5 + 7*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -44683,7 +44980,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm13 = _mm256_sub_ps(ymm13,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*4 + 7*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*4 + 7*rs_a)->imag, + (a11 + cs_a*4 + 7*rs_a)->real, + (a11 + cs_a*4 + 7*rs_a)->imag, + (a11 + cs_a*4 + 7*rs_a)->real, + (a11 + cs_a*4 + 7*rs_a)->imag, + (a11 + cs_a*4 + 7*rs_a)->real, + (a11 + cs_a*4 + 7*rs_a)->imag, + (a11 + cs_a*4 + 7*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -44698,7 +45002,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm12 = _mm256_sub_ps(ymm12,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*3 + 7*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*3 + 7*rs_a)->imag, + (a11 + cs_a*3 + 7*rs_a)->real, + (a11 + cs_a*3 + 7*rs_a)->imag, + (a11 + cs_a*3 + 7*rs_a)->real, + (a11 + cs_a*3 + 7*rs_a)->imag, + (a11 + cs_a*3 + 7*rs_a)->real, + (a11 + cs_a*3 + 7*rs_a)->imag, + (a11 + cs_a*3 + 7*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -44713,7 +45024,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm11 = _mm256_sub_ps(ymm11,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2 + 7*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*2 + 7*rs_a)->imag, + (a11 + cs_a*2 + 7*rs_a)->real, + (a11 + cs_a*2 + 7*rs_a)->imag, + (a11 + cs_a*2 + 7*rs_a)->real, + (a11 + cs_a*2 + 7*rs_a)->imag, + (a11 + cs_a*2 + 7*rs_a)->real, + (a11 + cs_a*2 + 7*rs_a)->imag, + (a11 + cs_a*2 + 7*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -44728,7 +45046,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm10 = _mm256_sub_ps(ymm10,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*1 + 7*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*1 + 7*rs_a)->imag, + (a11 + cs_a*1 + 7*rs_a)->real, + (a11 + cs_a*1 + 7*rs_a)->imag, + (a11 + cs_a*1 + 7*rs_a)->real, + (a11 + cs_a*1 + 7*rs_a)->imag, + (a11 + cs_a*1 + 7*rs_a)->real, + (a11 + cs_a*1 + 7*rs_a)->imag, + (a11 + cs_a*1 + 7*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -44743,7 +45068,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm9 = _mm256_sub_ps(ymm9,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*0 + 7*rs_a)); + 
ymm2 = _mm256_set_ps((a11 + cs_a*0 + 7*rs_a)->imag, + (a11 + cs_a*0 + 7*rs_a)->real, + (a11 + cs_a*0 + 7*rs_a)->imag, + (a11 + cs_a*0 + 7*rs_a)->real, + (a11 + cs_a*0 + 7*rs_a)->imag, + (a11 + cs_a*0 + 7*rs_a)->real, + (a11 + cs_a*0 + 7*rs_a)->imag, + (a11 + cs_a*0 + 7*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -44758,7 +45090,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm8 = _mm256_sub_ps(ymm8,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 6)); + ymm1 = _mm256_set_ps((d11_pack + 6)->imag,(d11_pack + 6)->real, + (d11_pack + 6)->imag,(d11_pack + 6)->real, + (d11_pack + 6)->imag,(d11_pack + 6)->real, + (d11_pack + 6)->imag,(d11_pack + 6)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm14) @@ -44766,7 +45101,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB BLIS_CTRSM_MUL(ymm14) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*5 + 6*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*5 + 6*rs_a)->imag, + (a11 + cs_a*5 + 6*rs_a)->real, + (a11 + cs_a*5 + 6*rs_a)->imag, + (a11 + cs_a*5 + 6*rs_a)->real, + (a11 + cs_a*5 + 6*rs_a)->imag, + (a11 + cs_a*5 + 6*rs_a)->real, + (a11 + cs_a*5 + 6*rs_a)->imag, + (a11 + cs_a*5 + 6*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -44781,7 +45123,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm13 = _mm256_sub_ps(ymm13,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*4 + 6*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*4 + 6*rs_a)->imag, + (a11 + cs_a*4 + 6*rs_a)->real, + (a11 + cs_a*4 + 6*rs_a)->imag, + (a11 + cs_a*4 + 6*rs_a)->real, + (a11 + cs_a*4 + 6*rs_a)->imag, + (a11 + cs_a*4 + 6*rs_a)->real, + (a11 + cs_a*4 + 6*rs_a)->imag, + (a11 + cs_a*4 + 6*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -44796,7 +45145,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm12 = _mm256_sub_ps(ymm12,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*3 + 6*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*3 + 6*rs_a)->imag, + (a11 + cs_a*3 + 6*rs_a)->real, + (a11 + cs_a*3 + 6*rs_a)->imag, + (a11 + cs_a*3 + 6*rs_a)->real, + (a11 + cs_a*3 + 6*rs_a)->imag, + (a11 + cs_a*3 + 6*rs_a)->real, + (a11 + cs_a*3 + 6*rs_a)->imag, + (a11 + cs_a*3 + 6*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -44811,7 +45167,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm11 = _mm256_sub_ps(ymm11,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2 + 6*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*2 + 6*rs_a)->imag, + (a11 + cs_a*2 + 6*rs_a)->real, + (a11 + cs_a*2 + 6*rs_a)->imag, + (a11 + cs_a*2 + 6*rs_a)->real, + (a11 + cs_a*2 + 6*rs_a)->imag, + (a11 + cs_a*2 + 6*rs_a)->real, + (a11 + cs_a*2 + 6*rs_a)->imag, + (a11 + cs_a*2 + 6*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -44826,7 +45189,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm10 = _mm256_sub_ps(ymm10,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*1 + 6*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*1 + 6*rs_a)->imag, + (a11 + cs_a*1 + 6*rs_a)->real, + (a11 + cs_a*1 + 6*rs_a)->imag, + (a11 + cs_a*1 + 6*rs_a)->real, + (a11 + cs_a*1 + 6*rs_a)->imag, + (a11 + cs_a*1 + 6*rs_a)->real, + (a11 + cs_a*1 + 6*rs_a)->imag, + (a11 + 
cs_a*1 + 6*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -44841,7 +45211,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm9 = _mm256_sub_ps(ymm9,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*0 + 6*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*0 + 6*rs_a)->imag, + (a11 + cs_a*0 + 6*rs_a)->real, + (a11 + cs_a*0 + 6*rs_a)->imag, + (a11 + cs_a*0 + 6*rs_a)->real, + (a11 + cs_a*0 + 6*rs_a)->imag, + (a11 + cs_a*0 + 6*rs_a)->real, + (a11 + cs_a*0 + 6*rs_a)->imag, + (a11 + cs_a*0 + 6*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -44857,7 +45234,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm8 = _mm256_sub_ps(ymm8,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 5)); + ymm1 = _mm256_set_ps((d11_pack + 5)->imag,(d11_pack + 5)->real, + (d11_pack + 5)->imag,(d11_pack + 5)->real, + (d11_pack + 5)->imag,(d11_pack + 5)->real, + (d11_pack + 5)->imag,(d11_pack + 5)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION @@ -44867,7 +45247,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*4 + 5*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*4 + 5*rs_a)->imag, + (a11 + cs_a*4 + 5*rs_a)->real, + (a11 + cs_a*4 + 5*rs_a)->imag, + (a11 + cs_a*4 + 5*rs_a)->real, + (a11 + cs_a*4 + 5*rs_a)->imag, + (a11 + cs_a*4 + 5*rs_a)->real, + (a11 + cs_a*4 + 5*rs_a)->imag, + (a11 + cs_a*4 + 5*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -44882,7 +45269,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm12 = _mm256_sub_ps(ymm12,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*3 + 5*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*3 + 5*rs_a)->imag, + (a11 + cs_a*3 + 5*rs_a)->real, + (a11 + cs_a*3 + 5*rs_a)->imag, + (a11 + cs_a*3 + 5*rs_a)->real, + (a11 + cs_a*3 + 5*rs_a)->imag, + (a11 + cs_a*3 + 5*rs_a)->real, + (a11 + cs_a*3 + 5*rs_a)->imag, + (a11 + cs_a*3 + 5*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -44897,7 +45291,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm11 = _mm256_sub_ps(ymm11,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2 + 5*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*2 + 5*rs_a)->imag, + (a11 + cs_a*2 + 5*rs_a)->real, + (a11 + cs_a*2 + 5*rs_a)->imag, + (a11 + cs_a*2 + 5*rs_a)->real, + (a11 + cs_a*2 + 5*rs_a)->imag, + (a11 + cs_a*2 + 5*rs_a)->real, + (a11 + cs_a*2 + 5*rs_a)->imag, + (a11 + cs_a*2 + 5*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -44912,7 +45313,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm10 = _mm256_sub_ps(ymm10,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*1 + 5*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*1 + 5*rs_a)->imag, + (a11 + cs_a*1 + 5*rs_a)->real, + (a11 + cs_a*1 + 5*rs_a)->imag, + (a11 + cs_a*1 + 5*rs_a)->real, + (a11 + cs_a*1 + 5*rs_a)->imag, + (a11 + cs_a*1 + 5*rs_a)->real, + (a11 + cs_a*1 + 5*rs_a)->imag, + (a11 + cs_a*1 + 5*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -44927,7 +45335,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm9 = _mm256_sub_ps(ymm9,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*0 + 5*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*0 + 5*rs_a)->imag, 
+ (a11 + cs_a*0 + 5*rs_a)->real, + (a11 + cs_a*0 + 5*rs_a)->imag, + (a11 + cs_a*0 + 5*rs_a)->real, + (a11 + cs_a*0 + 5*rs_a)->imag, + (a11 + cs_a*0 + 5*rs_a)->real, + (a11 + cs_a*0 + 5*rs_a)->imag, + (a11 + cs_a*0 + 5*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -44943,7 +45358,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm8 = _mm256_sub_ps(ymm8,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 4)); + ymm1 = _mm256_set_ps((d11_pack + 4)->imag,(d11_pack + 4)->real, + (d11_pack + 4)->imag,(d11_pack + 4)->real, + (d11_pack + 4)->imag,(d11_pack + 4)->real, + (d11_pack + 4)->imag,(d11_pack + 4)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm12) @@ -44952,7 +45370,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*3 + 4*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*3 + 4*rs_a)->imag, + (a11 + cs_a*3 + 4*rs_a)->real, + (a11 + cs_a*3 + 4*rs_a)->imag, + (a11 + cs_a*3 + 4*rs_a)->real, + (a11 + cs_a*3 + 4*rs_a)->imag, + (a11 + cs_a*3 + 4*rs_a)->real, + (a11 + cs_a*3 + 4*rs_a)->imag, + (a11 + cs_a*3 + 4*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -44967,7 +45392,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm11 = _mm256_sub_ps(ymm11,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2 + 4*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*2 + 4*rs_a)->imag, + (a11 + cs_a*2 + 4*rs_a)->real, + (a11 + cs_a*2 + 4*rs_a)->imag, + (a11 + cs_a*2 + 4*rs_a)->real, + (a11 + cs_a*2 + 4*rs_a)->imag, + (a11 + cs_a*2 + 4*rs_a)->real, + (a11 + cs_a*2 + 4*rs_a)->imag, + (a11 + cs_a*2 + 4*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -44982,7 +45414,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm10 = _mm256_sub_ps(ymm10,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*1 + 4*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*1 + 4*rs_a)->imag, + (a11 + cs_a*1 + 4*rs_a)->real, + (a11 + cs_a*1 + 4*rs_a)->imag, + (a11 + cs_a*1 + 4*rs_a)->real, + (a11 + cs_a*1 + 4*rs_a)->imag, + (a11 + cs_a*1 + 4*rs_a)->real, + (a11 + cs_a*1 + 4*rs_a)->imag, + (a11 + cs_a*1 + 4*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -44997,7 +45436,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm9 = _mm256_sub_ps(ymm9,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*0 + 4*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*0 + 4*rs_a)->imag, + (a11 + cs_a*0 + 4*rs_a)->real, + (a11 + cs_a*0 + 4*rs_a)->imag, + (a11 + cs_a*0 + 4*rs_a)->real, + (a11 + cs_a*0 + 4*rs_a)->imag, + (a11 + cs_a*0 + 4*rs_a)->real, + (a11 + cs_a*0 + 4*rs_a)->imag, + (a11 + cs_a*0 + 4*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -45013,7 +45459,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm8 = _mm256_sub_ps(ymm8,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 3)); + ymm1 = _mm256_set_ps((d11_pack + 3)->imag,(d11_pack + 3)->real, + (d11_pack + 3)->imag,(d11_pack + 3)->real, + (d11_pack + 3)->imag,(d11_pack + 3)->real, + (d11_pack + 3)->imag,(d11_pack + 3)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm11) @@ -45021,7 +45470,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB BLIS_CTRSM_MUL(ymm11) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) 
(a11 + cs_a*2 + 3*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*2 + 3*rs_a)->imag, + (a11 + cs_a*2 + 3*rs_a)->real, + (a11 + cs_a*2 + 3*rs_a)->imag, + (a11 + cs_a*2 + 3*rs_a)->real, + (a11 + cs_a*2 + 3*rs_a)->imag, + (a11 + cs_a*2 + 3*rs_a)->real, + (a11 + cs_a*2 + 3*rs_a)->imag, + (a11 + cs_a*2 + 3*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -45036,7 +45492,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm10 = _mm256_sub_ps(ymm10,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*1 + 3*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*1 + 3*rs_a)->imag, + (a11 + cs_a*1 + 3*rs_a)->real, + (a11 + cs_a*1 + 3*rs_a)->imag, + (a11 + cs_a*1 + 3*rs_a)->real, + (a11 + cs_a*1 + 3*rs_a)->imag, + (a11 + cs_a*1 + 3*rs_a)->real, + (a11 + cs_a*1 + 3*rs_a)->imag, + (a11 + cs_a*1 + 3*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -45051,7 +45514,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm9 = _mm256_sub_ps(ymm9,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*0 + 3*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*0 + 3*rs_a)->imag, + (a11 + cs_a*0 + 3*rs_a)->real, + (a11 + cs_a*0 + 3*rs_a)->imag, + (a11 + cs_a*0 + 3*rs_a)->real, + (a11 + cs_a*0 + 3*rs_a)->imag, + (a11 + cs_a*0 + 3*rs_a)->real, + (a11 + cs_a*0 + 3*rs_a)->imag, + (a11 + cs_a*0 + 3*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -45067,7 +45537,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm8 = _mm256_sub_ps(ymm8,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); + ymm1 = _mm256_set_ps((d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm10) @@ -45076,7 +45549,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a + 2*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*1 + 2*rs_a)->imag, + (a11 + cs_a*1 + 2*rs_a)->real, + (a11 + cs_a*1 + 2*rs_a)->imag, + (a11 + cs_a*1 + 2*rs_a)->real, + (a11 + cs_a*1 + 2*rs_a)->imag, + (a11 + cs_a*1 + 2*rs_a)->real, + (a11 + cs_a*1 + 2*rs_a)->imag, + (a11 + cs_a*1 + 2*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -45091,7 +45571,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm9 = _mm256_sub_ps(ymm9,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*0 + 2*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*0 + 2*rs_a)->imag, + (a11 + cs_a*0 + 2*rs_a)->real, + (a11 + cs_a*0 + 2*rs_a)->imag, + (a11 + cs_a*0 + 2*rs_a)->real, + (a11 + cs_a*0 + 2*rs_a)->imag, + (a11 + cs_a*0 + 2*rs_a)->real, + (a11 + cs_a*0 + 2*rs_a)->imag, + (a11 + cs_a*0 + 2*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -45107,7 +45594,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm8 = _mm256_sub_ps(ymm8,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_set_ps((d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm9) @@ -45116,7 +45606,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB 
#endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a) ); + ymm2 = _mm256_set_ps((a11 + cs_a*0 + 1*rs_a)->imag, + (a11 + cs_a*0 + 1*rs_a)->real, + (a11 + cs_a*0 + 1*rs_a)->imag, + (a11 + cs_a*0 + 1*rs_a)->real, + (a11 + cs_a*0 + 1*rs_a)->imag, + (a11 + cs_a*0 + 1*rs_a)->real, + (a11 + cs_a*0 + 1*rs_a)->imag, + (a11 + cs_a*0 + 1*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -45132,7 +45629,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm8 = _mm256_sub_ps(ymm8,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm8) @@ -45163,8 +45663,8 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB BLIS_CTRSM_SMALL_GEMM_8mx2n(a10,b01,cs_b,p_lda,k_iter) float zero = 0.0; - ymm16 = _mm256_broadcast_ss(&AlphaVal.real); - ymm17 = _mm256_broadcast_ss(&AlphaVal.imag); + ymm16 = _mm256_broadcast_ss(&AlphaVal[0].real); + ymm17 = _mm256_broadcast_ss(&AlphaVal[0].imag); ymm2 = _mm256_broadcast_ss(&zero); ymm3 = _mm256_broadcast_ss(&zero); ymm6 = _mm256_broadcast_ss(&zero); @@ -45215,8 +45715,8 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB BLIS_CTRSM_SMALL_GEMM_8mx1n(a10,b01,cs_b,p_lda,k_iter) float zero = 0.0; - ymm16 = _mm256_broadcast_ss(&AlphaVal.real); - ymm17 = _mm256_broadcast_ss(&AlphaVal.imag); + ymm16 = _mm256_broadcast_ss(&AlphaVal[0].real); + ymm17 = _mm256_broadcast_ss(&AlphaVal[0].imag); ymm2 = _mm256_broadcast_ss(&zero); ymm3 = _mm256_broadcast_ss(&zero); ymm6 = _mm256_broadcast_ss(&zero); @@ -45285,14 +45785,24 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 7)); + ymm1 = _mm256_set_ps((d11_pack + 7)->imag,(d11_pack + 7)->real, + (d11_pack + 7)->imag,(d11_pack + 7)->real, + (d11_pack + 7)->imag,(d11_pack + 7)->real, + (d11_pack + 7)->imag,(d11_pack + 7)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm15) #else BLIS_CTRSM_MUL(ymm15) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*6 + 7*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*6 + 7*rs_a)->imag, + (a11 + cs_a*6 + 7*rs_a)->real, + (a11 + cs_a*6 + 7*rs_a)->imag, + (a11 + cs_a*6 + 7*rs_a)->real, + (a11 + cs_a*6 + 7*rs_a)->imag, + (a11 + cs_a*6 + 7*rs_a)->real, + (a11 + cs_a*6 + 7*rs_a)->imag, + (a11 + cs_a*6 + 7*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -45309,7 +45819,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm14 = _mm256_sub_ps(ymm14,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*5 + 7*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*5 + 7*rs_a)->imag, + (a11 + cs_a*5 + 7*rs_a)->real, + (a11 + cs_a*5 + 7*rs_a)->imag, + (a11 + cs_a*5 + 7*rs_a)->real, + (a11 + cs_a*5 + 7*rs_a)->imag, + (a11 + cs_a*5 + 7*rs_a)->real, + (a11 + cs_a*5 + 7*rs_a)->imag, + (a11 + cs_a*5 + 7*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -45324,7 +45841,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm13 = _mm256_sub_ps(ymm13,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*4 + 7*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*4 + 7*rs_a)->imag, + (a11 + cs_a*4 + 7*rs_a)->real, + 
(a11 + cs_a*4 + 7*rs_a)->imag, + (a11 + cs_a*4 + 7*rs_a)->real, + (a11 + cs_a*4 + 7*rs_a)->imag, + (a11 + cs_a*4 + 7*rs_a)->real, + (a11 + cs_a*4 + 7*rs_a)->imag, + (a11 + cs_a*4 + 7*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -45339,7 +45863,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm12 = _mm256_sub_ps(ymm12,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*3 + 7*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*3 + 7*rs_a)->imag, + (a11 + cs_a*3 + 7*rs_a)->real, + (a11 + cs_a*3 + 7*rs_a)->imag, + (a11 + cs_a*3 + 7*rs_a)->real, + (a11 + cs_a*3 + 7*rs_a)->imag, + (a11 + cs_a*3 + 7*rs_a)->real, + (a11 + cs_a*3 + 7*rs_a)->imag, + (a11 + cs_a*3 + 7*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -45354,7 +45885,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm11 = _mm256_sub_ps(ymm11,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2 + 7*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*2 + 7*rs_a)->imag, + (a11 + cs_a*2 + 7*rs_a)->real, + (a11 + cs_a*2 + 7*rs_a)->imag, + (a11 + cs_a*2 + 7*rs_a)->real, + (a11 + cs_a*2 + 7*rs_a)->imag, + (a11 + cs_a*2 + 7*rs_a)->real, + (a11 + cs_a*2 + 7*rs_a)->imag, + (a11 + cs_a*2 + 7*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -45369,7 +45907,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm10 = _mm256_sub_ps(ymm10,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*1 + 7*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*1 + 7*rs_a)->imag, + (a11 + cs_a*1 + 7*rs_a)->real, + (a11 + cs_a*1 + 7*rs_a)->imag, + (a11 + cs_a*1 + 7*rs_a)->real, + (a11 + cs_a*1 + 7*rs_a)->imag, + (a11 + cs_a*1 + 7*rs_a)->real, + (a11 + cs_a*1 + 7*rs_a)->imag, + (a11 + cs_a*1 + 7*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -45384,7 +45929,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm9 = _mm256_sub_ps(ymm9,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*0 + 7*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*0 + 7*rs_a)->imag, + (a11 + cs_a*0 + 7*rs_a)->real, + (a11 + cs_a*0 + 7*rs_a)->imag, + (a11 + cs_a*0 + 7*rs_a)->real, + (a11 + cs_a*0 + 7*rs_a)->imag, + (a11 + cs_a*0 + 7*rs_a)->real, + (a11 + cs_a*0 + 7*rs_a)->imag, + (a11 + cs_a*0 + 7*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -45399,7 +45951,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm8 = _mm256_sub_ps(ymm8,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 6)); + ymm1 = _mm256_set_ps((d11_pack + 6)->imag,(d11_pack + 6)->real, + (d11_pack + 6)->imag,(d11_pack + 6)->real, + (d11_pack + 6)->imag,(d11_pack + 6)->real, + (d11_pack + 6)->imag,(d11_pack + 6)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm14) @@ -45407,7 +45962,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB BLIS_CTRSM_MUL(ymm14) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*5 + 6*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*5 + 6*rs_a)->imag, + (a11 + cs_a*5 + 6*rs_a)->real, + (a11 + cs_a*5 + 6*rs_a)->imag, + (a11 + cs_a*5 + 6*rs_a)->real, + (a11 + cs_a*5 + 6*rs_a)->imag, + (a11 + cs_a*5 + 6*rs_a)->real, + (a11 + cs_a*5 + 6*rs_a)->imag, + (a11 + cs_a*5 + 6*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -45422,7 
+45984,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm13 = _mm256_sub_ps(ymm13,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*4 + 6*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*4 + 6*rs_a)->imag, + (a11 + cs_a*4 + 6*rs_a)->real, + (a11 + cs_a*4 + 6*rs_a)->imag, + (a11 + cs_a*4 + 6*rs_a)->real, + (a11 + cs_a*4 + 6*rs_a)->imag, + (a11 + cs_a*4 + 6*rs_a)->real, + (a11 + cs_a*4 + 6*rs_a)->imag, + (a11 + cs_a*4 + 6*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -45437,7 +46006,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm12 = _mm256_sub_ps(ymm12,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*3 + 6*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*3 + 6*rs_a)->imag, + (a11 + cs_a*3 + 6*rs_a)->real, + (a11 + cs_a*3 + 6*rs_a)->imag, + (a11 + cs_a*3 + 6*rs_a)->real, + (a11 + cs_a*3 + 6*rs_a)->imag, + (a11 + cs_a*3 + 6*rs_a)->real, + (a11 + cs_a*3 + 6*rs_a)->imag, + (a11 + cs_a*3 + 6*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -45452,7 +46028,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm11 = _mm256_sub_ps(ymm11,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2 + 6*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*2 + 6*rs_a)->imag, + (a11 + cs_a*2 + 6*rs_a)->real, + (a11 + cs_a*2 + 6*rs_a)->imag, + (a11 + cs_a*2 + 6*rs_a)->real, + (a11 + cs_a*2 + 6*rs_a)->imag, + (a11 + cs_a*2 + 6*rs_a)->real, + (a11 + cs_a*2 + 6*rs_a)->imag, + (a11 + cs_a*2 + 6*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -45467,7 +46050,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm10 = _mm256_sub_ps(ymm10,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*1 + 6*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*1 + 6*rs_a)->imag, + (a11 + cs_a*1 + 6*rs_a)->real, + (a11 + cs_a*1 + 6*rs_a)->imag, + (a11 + cs_a*1 + 6*rs_a)->real, + (a11 + cs_a*1 + 6*rs_a)->imag, + (a11 + cs_a*1 + 6*rs_a)->real, + (a11 + cs_a*1 + 6*rs_a)->imag, + (a11 + cs_a*1 + 6*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -45482,7 +46072,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm9 = _mm256_sub_ps(ymm9,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*0 + 6*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*0 + 6*rs_a)->imag, + (a11 + cs_a*0 + 6*rs_a)->real, + (a11 + cs_a*0 + 6*rs_a)->imag, + (a11 + cs_a*0 + 6*rs_a)->real, + (a11 + cs_a*0 + 6*rs_a)->imag, + (a11 + cs_a*0 + 6*rs_a)->real, + (a11 + cs_a*0 + 6*rs_a)->imag, + (a11 + cs_a*0 + 6*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -45498,7 +46095,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm8 = _mm256_sub_ps(ymm8,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 5)); + ymm1 = _mm256_set_ps((d11_pack + 5)->imag,(d11_pack + 5)->real, + (d11_pack + 5)->imag,(d11_pack + 5)->real, + (d11_pack + 5)->imag,(d11_pack + 5)->real, + (d11_pack + 5)->imag,(d11_pack + 5)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION @@ -45508,7 +46108,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*4 + 5*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*4 + 5*rs_a)->imag, + (a11 + cs_a*4 + 5*rs_a)->real, + (a11 + cs_a*4 + 5*rs_a)->imag, + (a11 + cs_a*4 + 
5*rs_a)->real, + (a11 + cs_a*4 + 5*rs_a)->imag, + (a11 + cs_a*4 + 5*rs_a)->real, + (a11 + cs_a*4 + 5*rs_a)->imag, + (a11 + cs_a*4 + 5*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -45523,7 +46130,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm12 = _mm256_sub_ps(ymm12,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*3 + 5*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*3 + 5*rs_a)->imag, + (a11 + cs_a*3 + 5*rs_a)->real, + (a11 + cs_a*3 + 5*rs_a)->imag, + (a11 + cs_a*3 + 5*rs_a)->real, + (a11 + cs_a*3 + 5*rs_a)->imag, + (a11 + cs_a*3 + 5*rs_a)->real, + (a11 + cs_a*3 + 5*rs_a)->imag, + (a11 + cs_a*3 + 5*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -45538,7 +46152,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm11 = _mm256_sub_ps(ymm11,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2 + 5*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*2 + 5*rs_a)->imag, + (a11 + cs_a*2 + 5*rs_a)->real, + (a11 + cs_a*2 + 5*rs_a)->imag, + (a11 + cs_a*2 + 5*rs_a)->real, + (a11 + cs_a*2 + 5*rs_a)->imag, + (a11 + cs_a*2 + 5*rs_a)->real, + (a11 + cs_a*2 + 5*rs_a)->imag, + (a11 + cs_a*2 + 5*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -45553,7 +46174,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm10 = _mm256_sub_ps(ymm10,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*1 + 5*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*1 + 5*rs_a)->imag, + (a11 + cs_a*1 + 5*rs_a)->real, + (a11 + cs_a*1 + 5*rs_a)->imag, + (a11 + cs_a*1 + 5*rs_a)->real, + (a11 + cs_a*1 + 5*rs_a)->imag, + (a11 + cs_a*1 + 5*rs_a)->real, + (a11 + cs_a*1 + 5*rs_a)->imag, + (a11 + cs_a*1 + 5*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -45568,7 +46196,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm9 = _mm256_sub_ps(ymm9,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*0 + 5*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*0 + 5*rs_a)->imag, + (a11 + cs_a*0 + 5*rs_a)->real, + (a11 + cs_a*0 + 5*rs_a)->imag, + (a11 + cs_a*0 + 5*rs_a)->real, + (a11 + cs_a*0 + 5*rs_a)->imag, + (a11 + cs_a*0 + 5*rs_a)->real, + (a11 + cs_a*0 + 5*rs_a)->imag, + (a11 + cs_a*0 + 5*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -45584,7 +46219,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm8 = _mm256_sub_ps(ymm8,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 4)); + ymm1 = _mm256_set_ps((d11_pack + 4)->imag,(d11_pack + 4)->real, + (d11_pack + 4)->imag,(d11_pack + 4)->real, + (d11_pack + 4)->imag,(d11_pack + 4)->real, + (d11_pack + 4)->imag,(d11_pack + 4)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm12) @@ -45593,7 +46231,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*3 + 4*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*3 + 4*rs_a)->imag, + (a11 + cs_a*3 + 4*rs_a)->real, + (a11 + cs_a*3 + 4*rs_a)->imag, + (a11 + cs_a*3 + 4*rs_a)->real, + (a11 + cs_a*3 + 4*rs_a)->imag, + (a11 + cs_a*3 + 4*rs_a)->real, + (a11 + cs_a*3 + 4*rs_a)->imag, + (a11 + cs_a*3 + 4*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -45608,7 +46253,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm11 = 
_mm256_sub_ps(ymm11,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2 + 4*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*2 + 4*rs_a)->imag, + (a11 + cs_a*2 + 4*rs_a)->real, + (a11 + cs_a*2 + 4*rs_a)->imag, + (a11 + cs_a*2 + 4*rs_a)->real, + (a11 + cs_a*2 + 4*rs_a)->imag, + (a11 + cs_a*2 + 4*rs_a)->real, + (a11 + cs_a*2 + 4*rs_a)->imag, + (a11 + cs_a*2 + 4*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -45623,7 +46275,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm10 = _mm256_sub_ps(ymm10,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*1 + 4*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*1 + 4*rs_a)->imag, + (a11 + cs_a*1 + 4*rs_a)->real, + (a11 + cs_a*1 + 4*rs_a)->imag, + (a11 + cs_a*1 + 4*rs_a)->real, + (a11 + cs_a*1 + 4*rs_a)->imag, + (a11 + cs_a*1 + 4*rs_a)->real, + (a11 + cs_a*1 + 4*rs_a)->imag, + (a11 + cs_a*1 + 4*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -45638,7 +46297,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm9 = _mm256_sub_ps(ymm9,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*0 + 4*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*0 + 4*rs_a)->imag, + (a11 + cs_a*0 + 4*rs_a)->real, + (a11 + cs_a*0 + 4*rs_a)->imag, + (a11 + cs_a*0 + 4*rs_a)->real, + (a11 + cs_a*0 + 4*rs_a)->imag, + (a11 + cs_a*0 + 4*rs_a)->real, + (a11 + cs_a*0 + 4*rs_a)->imag, + (a11 + cs_a*0 + 4*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -45654,7 +46320,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm8 = _mm256_sub_ps(ymm8,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 3)); + ymm1 = _mm256_set_ps((d11_pack + 3)->imag,(d11_pack + 3)->real, + (d11_pack + 3)->imag,(d11_pack + 3)->real, + (d11_pack + 3)->imag,(d11_pack + 3)->real, + (d11_pack + 3)->imag,(d11_pack + 3)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm11) @@ -45662,7 +46331,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB BLIS_CTRSM_MUL(ymm11) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2 + 3*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*2 + 3*rs_a)->imag, + (a11 + cs_a*2 + 3*rs_a)->real, + (a11 + cs_a*2 + 3*rs_a)->imag, + (a11 + cs_a*2 + 3*rs_a)->real, + (a11 + cs_a*2 + 3*rs_a)->imag, + (a11 + cs_a*2 + 3*rs_a)->real, + (a11 + cs_a*2 + 3*rs_a)->imag, + (a11 + cs_a*2 + 3*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -45677,7 +46353,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm10 = _mm256_sub_ps(ymm10,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*1 + 3*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*1 + 3*rs_a)->imag, + (a11 + cs_a*1 + 3*rs_a)->real, + (a11 + cs_a*1 + 3*rs_a)->imag, + (a11 + cs_a*1 + 3*rs_a)->real, + (a11 + cs_a*1 + 3*rs_a)->imag, + (a11 + cs_a*1 + 3*rs_a)->real, + (a11 + cs_a*1 + 3*rs_a)->imag, + (a11 + cs_a*1 + 3*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -45692,7 +46375,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm9 = _mm256_sub_ps(ymm9,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*0 + 3*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*0 + 3*rs_a)->imag, + (a11 + cs_a*0 + 3*rs_a)->real, + (a11 + cs_a*0 + 3*rs_a)->imag, + (a11 + cs_a*0 + 3*rs_a)->real, + (a11 + cs_a*0 + 3*rs_a)->imag, + (a11 + cs_a*0 + 
3*rs_a)->real, + (a11 + cs_a*0 + 3*rs_a)->imag, + (a11 + cs_a*0 + 3*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -45708,7 +46398,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm8 = _mm256_sub_ps(ymm8,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); + ymm1 = _mm256_set_ps((d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm10) @@ -45717,7 +46410,15 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a + 2*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*1 + 2*rs_a)->imag, + (a11 + cs_a*1 + 2*rs_a)->real, + (a11 + cs_a*1 + 2*rs_a)->imag, + (a11 + cs_a*1 + 2*rs_a)->real, + (a11 + cs_a*1 + 2*rs_a)->imag, + (a11 + cs_a*1 + 2*rs_a)->real, + (a11 + cs_a*1 + 2*rs_a)->imag, + (a11 + cs_a*1 + 2*rs_a)->real); + ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -45732,7 +46433,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm9 = _mm256_sub_ps(ymm9,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*0 + 2*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*0 + 2*rs_a)->imag, + (a11 + cs_a*0 + 2*rs_a)->real, + (a11 + cs_a*0 + 2*rs_a)->imag, + (a11 + cs_a*0 + 2*rs_a)->real, + (a11 + cs_a*0 + 2*rs_a)->imag, + (a11 + cs_a*0 + 2*rs_a)->real, + (a11 + cs_a*0 + 2*rs_a)->imag, + (a11 + cs_a*0 + 2*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -45748,7 +46456,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm8 = _mm256_sub_ps(ymm8,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_set_ps((d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm9) @@ -45757,7 +46468,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a) ); + ymm2 = _mm256_set_ps((a11 + cs_a*0 + 1*rs_a)->imag, + (a11 + cs_a*0 + 1*rs_a)->real, + (a11 + cs_a*0 + 1*rs_a)->imag, + (a11 + cs_a*0 + 1*rs_a)->real, + (a11 + cs_a*0 + 1*rs_a)->imag, + (a11 + cs_a*0 + 1*rs_a)->real, + (a11 + cs_a*0 + 1*rs_a)->imag, + (a11 + cs_a*0 + 1*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -45773,7 +46491,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm8 = _mm256_sub_ps(ymm8,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm8) @@ -45891,7 +46612,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB BLIS_CTRSM_SMALL_NREG_TRANSPOSE_3x4(b11,cs_b,AlphaVal) ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 3)); + ymm1 = _mm256_set_ps((d11_pack + 3)->imag,(d11_pack + 3)->real, + (d11_pack + 3)->imag,(d11_pack + 3)->real, + (d11_pack + 3)->imag,(d11_pack + 3)->real, + (d11_pack + 3)->imag,(d11_pack + 3)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef 
BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm11) @@ -45899,7 +46623,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB BLIS_CTRSM_MUL(ymm11) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2 + 3*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*2 + 3*rs_a)->imag, + (a11 + cs_a*2 + 3*rs_a)->real, + (a11 + cs_a*2 + 3*rs_a)->imag, + (a11 + cs_a*2 + 3*rs_a)->real, + (a11 + cs_a*2 + 3*rs_a)->imag, + (a11 + cs_a*2 + 3*rs_a)->real, + (a11 + cs_a*2 + 3*rs_a)->imag, + (a11 + cs_a*2 + 3*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -45913,7 +46644,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm10 = _mm256_sub_ps(ymm10,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*1 + 3*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*1 + 3*rs_a)->imag, + (a11 + cs_a*1 + 3*rs_a)->real, + (a11 + cs_a*1 + 3*rs_a)->imag, + (a11 + cs_a*1 + 3*rs_a)->real, + (a11 + cs_a*1 + 3*rs_a)->imag, + (a11 + cs_a*1 + 3*rs_a)->real, + (a11 + cs_a*1 + 3*rs_a)->imag, + (a11 + cs_a*1 + 3*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -45927,7 +46665,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm9 = _mm256_sub_ps(ymm9,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*0 + 3*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*0 + 3*rs_a)->imag, + (a11 + cs_a*0 + 3*rs_a)->real, + (a11 + cs_a*0 + 3*rs_a)->imag, + (a11 + cs_a*0 + 3*rs_a)->real, + (a11 + cs_a*0 + 3*rs_a)->imag, + (a11 + cs_a*0 + 3*rs_a)->real, + (a11 + cs_a*0 + 3*rs_a)->imag, + (a11 + cs_a*0 + 3*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -45942,7 +46687,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm8 = _mm256_sub_ps(ymm8,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); + ymm1 = _mm256_set_ps((d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm10) @@ -45951,7 +46699,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a + 2*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*1 + 2*rs_a)->imag, + (a11 + cs_a*1 + 2*rs_a)->real, + (a11 + cs_a*1 + 2*rs_a)->imag, + (a11 + cs_a*1 + 2*rs_a)->real, + (a11 + cs_a*1 + 2*rs_a)->imag, + (a11 + cs_a*1 + 2*rs_a)->real, + (a11 + cs_a*1 + 2*rs_a)->imag, + (a11 + cs_a*1 + 2*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -45965,7 +46720,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm9 = _mm256_sub_ps(ymm9,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*0 + 2*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*0 + 2*rs_a)->imag, + (a11 + cs_a*0 + 2*rs_a)->real, + (a11 + cs_a*0 + 2*rs_a)->imag, + (a11 + cs_a*0 + 2*rs_a)->real, + (a11 + cs_a*0 + 2*rs_a)->imag, + (a11 + cs_a*0 + 2*rs_a)->real, + (a11 + cs_a*0 + 2*rs_a)->imag, + (a11 + cs_a*0 + 2*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -45980,7 +46742,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm8 = _mm256_sub_ps(ymm8,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_set_ps((d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 
1)->imag,(d11_pack + 1)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm9) @@ -45989,7 +46754,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a) ); + ymm2 = _mm256_set_ps((a11 + cs_a*0 + 1*rs_a)->imag, + (a11 + cs_a*0 + 1*rs_a)->real, + (a11 + cs_a*0 + 1*rs_a)->imag, + (a11 + cs_a*0 + 1*rs_a)->real, + (a11 + cs_a*0 + 1*rs_a)->imag, + (a11 + cs_a*0 + 1*rs_a)->real, + (a11 + cs_a*0 + 1*rs_a)->imag, + (a11 + cs_a*0 + 1*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -46004,7 +46776,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm8 = _mm256_sub_ps(ymm8,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm8) @@ -46040,7 +46815,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB } ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 3)); + ymm1 = _mm256_set_ps((d11_pack + 3)->imag,(d11_pack + 3)->real, + (d11_pack + 3)->imag,(d11_pack + 3)->real, + (d11_pack + 3)->imag,(d11_pack + 3)->real, + (d11_pack + 3)->imag,(d11_pack + 3)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm11) @@ -46048,7 +46826,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB BLIS_CTRSM_MUL(ymm11) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2 + 3*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*2 + 3*rs_a)->imag, + (a11 + cs_a*2 + 3*rs_a)->real, + (a11 + cs_a*2 + 3*rs_a)->imag, + (a11 + cs_a*2 + 3*rs_a)->real, + (a11 + cs_a*2 + 3*rs_a)->imag, + (a11 + cs_a*2 + 3*rs_a)->real, + (a11 + cs_a*2 + 3*rs_a)->imag, + (a11 + cs_a*2 + 3*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -46062,7 +46847,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm10 = _mm256_sub_ps(ymm10,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*1 + 3*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*1 + 3*rs_a)->imag, + (a11 + cs_a*1 + 3*rs_a)->real, + (a11 + cs_a*1 + 3*rs_a)->imag, + (a11 + cs_a*1 + 3*rs_a)->real, + (a11 + cs_a*1 + 3*rs_a)->imag, + (a11 + cs_a*1 + 3*rs_a)->real, + (a11 + cs_a*1 + 3*rs_a)->imag, + (a11 + cs_a*1 + 3*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -46076,7 +46868,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm9 = _mm256_sub_ps(ymm9,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*0 + 3*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*0 + 3*rs_a)->imag, + (a11 + cs_a*0 + 3*rs_a)->real, + (a11 + cs_a*0 + 3*rs_a)->imag, + (a11 + cs_a*0 + 3*rs_a)->real, + (a11 + cs_a*0 + 3*rs_a)->imag, + (a11 + cs_a*0 + 3*rs_a)->real, + (a11 + cs_a*0 + 3*rs_a)->imag, + (a11 + cs_a*0 + 3*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -46091,7 +46890,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm8 = _mm256_sub_ps(ymm8,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); + ymm1 = _mm256_set_ps((d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real); 
ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm10) @@ -46100,7 +46902,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a + 2*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*1 + 2*rs_a)->imag, + (a11 + cs_a*1 + 2*rs_a)->real, + (a11 + cs_a*1 + 2*rs_a)->imag, + (a11 + cs_a*1 + 2*rs_a)->real, + (a11 + cs_a*1 + 2*rs_a)->imag, + (a11 + cs_a*1 + 2*rs_a)->real, + (a11 + cs_a*1 + 2*rs_a)->imag, + (a11 + cs_a*1 + 2*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -46114,7 +46923,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm9 = _mm256_sub_ps(ymm9,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*0 + 2*rs_a)); + ymm2 = _mm256_set_ps((a11 + cs_a*0 + 2*rs_a)->imag, + (a11 + cs_a*0 + 2*rs_a)->real, + (a11 + cs_a*0 + 2*rs_a)->imag, + (a11 + cs_a*0 + 2*rs_a)->real, + (a11 + cs_a*0 + 2*rs_a)->imag, + (a11 + cs_a*0 + 2*rs_a)->real, + (a11 + cs_a*0 + 2*rs_a)->imag, + (a11 + cs_a*0 + 2*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -46129,7 +46945,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm8 = _mm256_sub_ps(ymm8,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_set_ps((d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm9) @@ -46138,7 +46957,14 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a) ); + ymm2 = _mm256_set_ps((a11 + cs_a*0 + 1*rs_a)->imag, + (a11 + cs_a*0 + 1*rs_a)->real, + (a11 + cs_a*0 + 1*rs_a)->imag, + (a11 + cs_a*0 + 1*rs_a)->real, + (a11 + cs_a*0 + 1*rs_a)->imag, + (a11 + cs_a*0 + 1*rs_a)->real, + (a11 + cs_a*0 + 1*rs_a)->imag, + (a11 + cs_a*0 + 1*rs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -46153,7 +46979,10 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB ymm8 = _mm256_sub_ps(ymm8,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm8) @@ -46696,7 +47525,6 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB return BLIS_SUCCESS; } - BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ( obj_t* AlphaObj, @@ -46731,9 +47559,12 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB dim_t i, j, k; //loop variables dim_t k_iter; //number of times GEMM to be performed - scomplex AlphaVal = *(scomplex *)AlphaObj->buffer; //value of alpha - scomplex *L = a->buffer; //pointer to matrix A - scomplex *B = b->buffer; //pointer to matrix B + scomplex AlphaVal[2]; + AlphaVal[0] = *(scomplex *)AlphaObj->buffer; //value of alpha + AlphaVal[1] = *(scomplex *)AlphaObj->buffer; //value of alpha + + scomplex *L = bli_obj_buffer_at_off(a); //pointer to matrix A + scomplex *B = bli_obj_buffer_at_off(b); //pointer to matrix B scomplex *a01, *a11, *b10, *b11; //pointers that point to blocks for GEMM and TRSM @@ -46750,7 +47581,7 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB __m128 xmm0, xmm1, xmm2; __m128 xmm5; - xmm0 = _mm_setzero_ps(); + xmm0 = _mm_setzero_ps(); xmm1 = 
_mm_setzero_ps(); xmm2 = _mm_setzero_ps(); xmm5 = _mm_setzero_ps(); @@ -46870,7 +47701,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB */ ////extract a00 ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); + ymm1 = _mm256_set_ps((d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_TWO_DIV(ymm12, ymm13) @@ -46878,7 +47712,14 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB BLIS_CTRSM_MUL(ymm12) BLIS_CTRSM_MUL(ymm13) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a *2 + rs_a*1) ); + ymm2 = _mm256_set_ps((a11 + cs_a *2 + rs_a*1)->imag, + (a11 + cs_a *2 + rs_a*1)->real, + (a11 + cs_a *2 + rs_a*1)->imag, + (a11 + cs_a *2 + rs_a*1)->real, + (a11 + cs_a *2 + rs_a*1)->imag, + (a11 + cs_a *2 + rs_a*1)->real, + (a11 + cs_a *2 + rs_a*1)->imag, + (a11 + cs_a *2 + rs_a*1)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -46902,7 +47743,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm11 = _mm256_sub_ps(ymm11,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2) ); + ymm2 = _mm256_set_ps((a11 + cs_a *2)->imag,(a11 + cs_a *2)->real, + (a11 + cs_a *2)->imag,(a11 + cs_a *2)->real, + (a11 + cs_a *2)->imag,(a11 + cs_a *2)->real, + (a11 + cs_a *2)->imag,(a11 + cs_a *2)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -46923,7 +47767,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm9 = _mm256_sub_ps(ymm9,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_set_ps((d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_TWO_DIV(ymm10, ymm11) @@ -46932,7 +47779,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB BLIS_CTRSM_MUL(ymm11) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a) ); + ymm2 = _mm256_set_ps((a11 + cs_a)->imag,(a11 + cs_a)->real, + (a11 + cs_a)->imag,(a11 + cs_a)->real, + (a11 + cs_a)->imag,(a11 + cs_a)->real, + (a11 + cs_a)->imag,(a11 + cs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -46952,7 +47802,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm9 = _mm256_sub_ps(ymm9,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION @@ -46989,14 +47842,24 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB // Load b11 of size 4x6 and multiply with alpha BLIS_PRE_CTRSM_SMALL_3x4(AlphaVal,b11,cs_b) ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); + ymm1 = _mm256_set_ps((d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm12) 
#else BLIS_CTRSM_MUL(ymm12) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a *2 + rs_a*1) ); + ymm2 = _mm256_set_ps((a11 + cs_a *2 + rs_a*1)->imag, + (a11 + cs_a *2 + rs_a*1)->real, + (a11 + cs_a *2 + rs_a*1)->imag, + (a11 + cs_a *2 + rs_a*1)->real, + (a11 + cs_a *2 + rs_a*1)->imag, + (a11 + cs_a *2 + rs_a*1)->real, + (a11 + cs_a *2 + rs_a*1)->imag, + (a11 + cs_a *2 + rs_a*1)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -47014,7 +47877,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm10 = _mm256_sub_ps(ymm10,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2) ); + ymm2 = _mm256_set_ps((a11 + cs_a *2)->imag,(a11 + cs_a *2)->real, + (a11 + cs_a *2)->imag,(a11 + cs_a *2)->real, + (a11 + cs_a *2)->imag,(a11 + cs_a *2)->real, + (a11 + cs_a *2)->imag,(a11 + cs_a *2)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -47028,7 +47894,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm8 = _mm256_sub_ps(ymm8,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_set_ps((d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm10) @@ -47036,7 +47905,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB BLIS_CTRSM_MUL(ymm10) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a) ); + ymm2 = _mm256_set_ps((a11 + cs_a)->imag,(a11 + cs_a)->real, + (a11 + cs_a)->imag,(a11 + cs_a)->real, + (a11 + cs_a)->imag,(a11 + cs_a)->real, + (a11 + cs_a)->imag,(a11 + cs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -47050,7 +47922,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm8 = _mm256_sub_ps(ymm8,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION @@ -47087,14 +47962,24 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB BLIS_PRE_CTRSM_SMALL_3x3(AlphaVal,b11,cs_b) ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); + ymm1 = _mm256_set_ps((d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm12) #else BLIS_CTRSM_MUL(ymm12) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a *2 + rs_a*1) ); + ymm2 = _mm256_set_ps((a11 + cs_a *2 + rs_a*1)->imag, + (a11 + cs_a *2 + rs_a*1)->real, + (a11 + cs_a *2 + rs_a*1)->imag, + (a11 + cs_a *2 + rs_a*1)->real, + (a11 + cs_a *2 + rs_a*1)->imag, + (a11 + cs_a *2 + rs_a*1)->real, + (a11 + cs_a *2 + rs_a*1)->imag, + (a11 + cs_a *2 + rs_a*1)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -47111,7 +47996,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm10 = _mm256_sub_ps(ymm10,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2) ); + ymm2 = 
_mm256_set_ps((a11 + cs_a *2)->imag,(a11 + cs_a *2)->real, + (a11 + cs_a *2)->imag,(a11 + cs_a *2)->real, + (a11 + cs_a *2)->imag,(a11 + cs_a *2)->real, + (a11 + cs_a *2)->imag,(a11 + cs_a *2)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -47125,7 +48013,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm8 = _mm256_sub_ps(ymm8,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_set_ps((d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm10) @@ -47133,7 +48024,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB BLIS_CTRSM_MUL(ymm10) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a) ); + ymm2 = _mm256_set_ps((a11 + cs_a)->imag,(a11 + cs_a)->real, + (a11 + cs_a)->imag,(a11 + cs_a)->real, + (a11 + cs_a)->imag,(a11 + cs_a)->real, + (a11 + cs_a)->imag,(a11 + cs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -47147,7 +48041,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm8 = _mm256_sub_ps(ymm8,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION @@ -47195,14 +48092,24 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB BLIS_PRE_CTRSM_SMALL_3x2(AlphaVal,b11,cs_b) ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); + ymm1 = _mm256_set_ps((d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm12) #else BLIS_CTRSM_MUL(ymm12) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a *2 + rs_a*1) ); + ymm2 = _mm256_set_ps((a11 + cs_a *2 + rs_a*1)->imag, + (a11 + cs_a *2 + rs_a*1)->real, + (a11 + cs_a *2 + rs_a*1)->imag, + (a11 + cs_a *2 + rs_a*1)->real, + (a11 + cs_a *2 + rs_a*1)->imag, + (a11 + cs_a *2 + rs_a*1)->real, + (a11 + cs_a *2 + rs_a*1)->imag, + (a11 + cs_a *2 + rs_a*1)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -47220,7 +48127,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm10 = _mm256_sub_ps(ymm10,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2) ); + ymm2 = _mm256_set_ps((a11 + cs_a *2)->imag,(a11 + cs_a *2)->real, + (a11 + cs_a *2)->imag,(a11 + cs_a *2)->real, + (a11 + cs_a *2)->imag,(a11 + cs_a *2)->real, + (a11 + cs_a *2)->imag,(a11 + cs_a *2)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -47234,7 +48144,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm8 = _mm256_sub_ps(ymm8,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_set_ps((d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real); ymm1 = 
_mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm10) @@ -47242,7 +48155,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB BLIS_CTRSM_MUL(ymm10) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a) ); + ymm2 = _mm256_set_ps((a11 + cs_a)->imag,(a11 + cs_a)->real, + (a11 + cs_a)->imag,(a11 + cs_a)->real, + (a11 + cs_a)->imag,(a11 + cs_a)->real, + (a11 + cs_a)->imag,(a11 + cs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -47256,7 +48172,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm8 = _mm256_sub_ps(ymm8,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION @@ -47294,14 +48213,24 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB BLIS_PRE_CTRSM_SMALL_3x1(AlphaVal,b11,cs_b) ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); + ymm1 = _mm256_set_ps((d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm12) #else BLIS_CTRSM_MUL(ymm12) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a *2 + rs_a*1) ); + ymm2 = _mm256_set_ps((a11 + cs_a *2 + rs_a*1)->imag, + (a11 + cs_a *2 + rs_a*1)->real, + (a11 + cs_a *2 + rs_a*1)->imag, + (a11 + cs_a *2 + rs_a*1)->real, + (a11 + cs_a *2 + rs_a*1)->imag, + (a11 + cs_a *2 + rs_a*1)->real, + (a11 + cs_a *2 + rs_a*1)->imag, + (a11 + cs_a *2 + rs_a*1)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -47318,7 +48247,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm10 = _mm256_sub_ps(ymm10,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*2) ); + ymm2 = _mm256_set_ps((a11 + cs_a *2)->imag,(a11 + cs_a *2)->real, + (a11 + cs_a *2)->imag,(a11 + cs_a *2)->real, + (a11 + cs_a *2)->imag,(a11 + cs_a *2)->real, + (a11 + cs_a *2)->imag,(a11 + cs_a *2)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -47332,7 +48264,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm8 = _mm256_sub_ps(ymm8,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_set_ps((d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm10) @@ -47340,7 +48275,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB BLIS_CTRSM_MUL(ymm10) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a) ); + ymm2 = _mm256_set_ps((a11 + cs_a)->imag,(a11 + cs_a)->real, + (a11 + cs_a)->imag,(a11 + cs_a)->real, + (a11 + cs_a)->imag,(a11 + cs_a)->real, + (a11 + cs_a)->imag,(a11 + cs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -47354,7 +48292,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm8 = _mm256_sub_ps(ymm8,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const 
*)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION @@ -47508,7 +48449,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm11 = _mm256_sub_ps(ymm19, ymm11); ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_set_ps((d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_TWO_DIV(ymm10, ymm11) @@ -47517,7 +48461,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB BLIS_CTRSM_MUL(ymm11) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a) ); + ymm2 = _mm256_set_ps((a11 + cs_a)->imag,(a11 + cs_a)->real, + (a11 + cs_a)->imag,(a11 + cs_a)->real, + (a11 + cs_a)->imag,(a11 + cs_a)->real, + (a11 + cs_a)->imag,(a11 + cs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -47537,7 +48484,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm9 = _mm256_sub_ps(ymm9,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION @@ -47587,14 +48537,20 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm10 = _mm256_sub_ps(ymm19, ymm10); ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_set_ps((d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm10) #else BLIS_CTRSM_MUL(ymm10) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*1) ); + ymm2 = _mm256_set_ps((a11 + cs_a)->imag,(a11 + cs_a)->real, + (a11 + cs_a)->imag,(a11 + cs_a)->real, + (a11 + cs_a)->imag,(a11 + cs_a)->real, + (a11 + cs_a)->imag,(a11 + cs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -47610,7 +48566,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm8 = _mm256_sub_ps(ymm8,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION @@ -47663,14 +48622,20 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm10 = _mm256_sub_ps(ymm19, ymm10); ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_set_ps((d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm10) #else 
BLIS_CTRSM_MUL(ymm10) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*1) ); + ymm2 = _mm256_set_ps((a11 + cs_a)->imag,(a11 + cs_a)->real, + (a11 + cs_a)->imag,(a11 + cs_a)->real, + (a11 + cs_a)->imag,(a11 + cs_a)->real, + (a11 + cs_a)->imag,(a11 + cs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -47686,7 +48651,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm8 = _mm256_sub_ps(ymm8,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION @@ -47743,14 +48711,20 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm10 = _mm256_sub_ps(ymm19, ymm10); ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_set_ps((d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm10) #else BLIS_CTRSM_MUL(ymm10) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*1) ); + ymm2 = _mm256_set_ps((a11 + cs_a)->imag,(a11 + cs_a)->real, + (a11 + cs_a)->imag,(a11 + cs_a)->real, + (a11 + cs_a)->imag,(a11 + cs_a)->real, + (a11 + cs_a)->imag,(a11 + cs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -47766,7 +48740,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm8 = _mm256_sub_ps(ymm8,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION @@ -47817,14 +48794,20 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm10 = _mm256_sub_ps(ymm19, ymm10); ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_set_ps((d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm10) #else BLIS_CTRSM_MUL(ymm10) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + cs_a*1) ); + ymm2 = _mm256_set_ps((a11 + cs_a)->imag,(a11 + cs_a)->real, + (a11 + cs_a)->imag,(a11 + cs_a)->real, + (a11 + cs_a)->imag,(a11 + cs_a)->real, + (a11 + cs_a)->imag,(a11 + cs_a)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -47840,7 +48823,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm8 = _mm256_sub_ps(ymm8,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION @@ -47917,7 +48903,6 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB _mm_storeu_ps((float 
*)(ptr_a10_dup + p_lda * 0 + x*3), xmm0); xmm0 = _mm_loadl_pi(xmm1,(__m64 *)(a01 + rs_a * 0 + 2 + x*3)); _mm_storel_pi((__m64 *)(ptr_a10_dup + p_lda * 0 + 2 + x*3),xmm0); - } } @@ -47963,7 +48948,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_TWO_DIV(ymm8, ymm9) @@ -48003,7 +48991,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm8) @@ -48043,7 +49034,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm8) @@ -48083,7 +49077,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm8) @@ -48120,7 +49117,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm8 = _mm256_sub_ps(ymm19, ymm8); ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm8) @@ -48179,9 +49179,12 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB dim_t i, j, k; //loop variables dim_t k_iter; //number of times GEMM to be performed - scomplex AlphaVal = *(scomplex *)AlphaObj->buffer; //value of alpha - scomplex *L = a->buffer; //pointer to matrix A - scomplex *B = b->buffer; //pointer to matrix B + scomplex AlphaVal[2]; + AlphaVal[0] = *(scomplex *)AlphaObj->buffer; //value of alpha + AlphaVal[1] = *(scomplex *)AlphaObj->buffer; //value of alpha + + scomplex *L = bli_obj_buffer_at_off(a); //pointer to matrix A + scomplex *B = bli_obj_buffer_at_off(b); //pointer to matrix B scomplex *a01, *a11, *b10, *b11; //pointers that point to blocks for GEMM and TRSM @@ -48197,7 +49200,7 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB __m128 xmm0, xmm1, xmm2; __m128 xmm5; - xmm0 = _mm_setzero_ps(); + xmm0 = _mm_setzero_ps(); xmm1 = _mm_setzero_ps(); xmm2 = _mm_setzero_ps(); xmm5 = _mm_setzero_ps(); @@ -48318,7 +49321,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB */ ////extract a00 
ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_TWO_DIV(ymm8, ymm9) @@ -48326,7 +49332,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB BLIS_CTRSM_MUL(ymm8) BLIS_CTRSM_MUL(ymm9) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*1) ); + ymm2 = _mm256_set_ps((a11 + rs_a*1)->imag,(a11 + rs_a*1)->real, + (a11 + rs_a*1)->imag,(a11 + rs_a*1)->real, + (a11 + rs_a*1)->imag,(a11 + rs_a*1)->real, + (a11 + rs_a*1)->imag,(a11 + rs_a*1)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -48350,7 +49359,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm11 = _mm256_sub_ps(ymm11,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*2) ); + ymm2 = _mm256_set_ps((a11 + rs_a*2)->imag,(a11 + rs_a*2)->real, + (a11 + rs_a*2)->imag,(a11 + rs_a*2)->real, + (a11 + rs_a*2)->imag,(a11 + rs_a*2)->real, + (a11 + rs_a*2)->imag,(a11 + rs_a*2)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -48372,7 +49384,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm13 = _mm256_sub_ps(ymm13,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_set_ps((d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_TWO_DIV(ymm10, ymm11) @@ -48384,7 +49399,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB a11 += cs_a; - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*2) ); + ymm2 = _mm256_set_ps((a11 + rs_a*2)->imag,(a11 + rs_a*2)->real, + (a11 + rs_a*2)->imag,(a11 + rs_a*2)->real, + (a11 + rs_a*2)->imag,(a11 + rs_a*2)->real, + (a11 + rs_a*2)->imag,(a11 + rs_a*2)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -48405,7 +49423,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm13 = _mm256_sub_ps(ymm13,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); + ymm1 = _mm256_set_ps((d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION @@ -48444,14 +49465,20 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB BLIS_PRE_CTRSM_SMALL_3x4(AlphaVal,b11,cs_b) ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm8) #else BLIS_CTRSM_MUL(ymm8) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*1) ); + ymm2 = _mm256_set_ps((a11 + rs_a*1)->imag,(a11 + rs_a*1)->real, + (a11 + rs_a*1)->imag,(a11 + rs_a*1)->real, + (a11 + rs_a*1)->imag,(a11 + rs_a*1)->real, + (a11 + rs_a*1)->imag,(a11 + rs_a*1)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); 
if(conjtransa) { @@ -48468,7 +49495,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm10 = _mm256_sub_ps(ymm10,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*2) ); + ymm2 = _mm256_set_ps((a11 + rs_a*2)->imag,(a11 + rs_a*2)->real, + (a11 + rs_a*2)->imag,(a11 + rs_a*2)->real, + (a11 + rs_a*2)->imag,(a11 + rs_a*2)->real, + (a11 + rs_a*2)->imag,(a11 + rs_a*2)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -48484,7 +49514,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm12 = _mm256_sub_ps(ymm12,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_set_ps((d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm10) @@ -48495,7 +49528,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB a11 += cs_a; - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*2) ); + ymm2 = _mm256_set_ps((a11 + rs_a*2)->imag,(a11 + rs_a*2)->real, + (a11 + rs_a*2)->imag,(a11 + rs_a*2)->real, + (a11 + rs_a*2)->imag,(a11 + rs_a*2)->real, + (a11 + rs_a*2)->imag,(a11 + rs_a*2)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -48511,7 +49547,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm12 = _mm256_sub_ps(ymm12,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); + ymm1 = _mm256_set_ps((d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION @@ -48546,14 +49585,20 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB BLIS_PRE_CTRSM_SMALL_3x3(AlphaVal,b11,cs_b) ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm8) #else BLIS_CTRSM_MUL(ymm8) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*1) ); + ymm2 = _mm256_set_ps((a11 + rs_a*1)->imag,(a11 + rs_a*1)->real, + (a11 + rs_a*1)->imag,(a11 + rs_a*1)->real, + (a11 + rs_a*1)->imag,(a11 + rs_a*1)->real, + (a11 + rs_a*1)->imag,(a11 + rs_a*1)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -48570,7 +49615,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm10 = _mm256_sub_ps(ymm10,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*2) ); + ymm2 = _mm256_set_ps((a11 + rs_a*2)->imag,(a11 + rs_a*2)->real, + (a11 + rs_a*2)->imag,(a11 + rs_a*2)->real, + (a11 + rs_a*2)->imag,(a11 + rs_a*2)->real, + (a11 + rs_a*2)->imag,(a11 + rs_a*2)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -48586,7 +49634,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm12 = _mm256_sub_ps(ymm12,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_set_ps((d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real); ymm1 = _mm256_permute_ps(ymm1, 
0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm10) @@ -48597,7 +49648,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB a11 += cs_a; - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*2) ); + ymm2 = _mm256_set_ps((a11 + rs_a*2)->imag,(a11 + rs_a*2)->real, + (a11 + rs_a*2)->imag,(a11 + rs_a*2)->real, + (a11 + rs_a*2)->imag,(a11 + rs_a*2)->real, + (a11 + rs_a*2)->imag,(a11 + rs_a*2)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -48613,7 +49667,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm12 = _mm256_sub_ps(ymm12,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); + ymm1 = _mm256_set_ps((d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION @@ -48659,14 +49716,20 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB BLIS_PRE_CTRSM_SMALL_3x2(AlphaVal,b11,cs_b) ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm8) #else BLIS_CTRSM_MUL(ymm8) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*1) ); + ymm2 = _mm256_set_ps((a11 + rs_a*1)->imag,(a11 + rs_a*1)->real, + (a11 + rs_a*1)->imag,(a11 + rs_a*1)->real, + (a11 + rs_a*1)->imag,(a11 + rs_a*1)->real, + (a11 + rs_a*1)->imag,(a11 + rs_a*1)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -48683,7 +49746,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm10 = _mm256_sub_ps(ymm10,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*2) ); + ymm2 = _mm256_set_ps((a11 + rs_a*2)->imag,(a11 + rs_a*2)->real, + (a11 + rs_a*2)->imag,(a11 + rs_a*2)->real, + (a11 + rs_a*2)->imag,(a11 + rs_a*2)->real, + (a11 + rs_a*2)->imag,(a11 + rs_a*2)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -48699,7 +49765,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm12 = _mm256_sub_ps(ymm12,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_set_ps((d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm10) @@ -48710,7 +49779,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB a11 += cs_a; - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*2) ); + ymm2 = _mm256_set_ps((a11 + rs_a*2)->imag,(a11 + rs_a*2)->real, + (a11 + rs_a*2)->imag,(a11 + rs_a*2)->real, + (a11 + rs_a*2)->imag,(a11 + rs_a*2)->real, + (a11 + rs_a*2)->imag,(a11 + rs_a*2)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -48726,7 +49798,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm12 = _mm256_sub_ps(ymm12,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); + ymm1 = _mm256_set_ps((d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef 
BLIS_ENABLE_TRSM_PREINVERSION @@ -48765,14 +49840,20 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB BLIS_PRE_CTRSM_SMALL_3x1(AlphaVal,b11,cs_b) ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm8) #else BLIS_CTRSM_MUL(ymm8) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*1) ); + ymm2 = _mm256_set_ps((a11 + rs_a*1)->imag,(a11 + rs_a*1)->real, + (a11 + rs_a*1)->imag,(a11 + rs_a*1)->real, + (a11 + rs_a*1)->imag,(a11 + rs_a*1)->real, + (a11 + rs_a*1)->imag,(a11 + rs_a*1)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -48789,7 +49870,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm10 = _mm256_sub_ps(ymm10,ymm16); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*2) ); + ymm2 = _mm256_set_ps((a11 + rs_a*2)->imag,(a11 + rs_a*2)->real, + (a11 + rs_a*2)->imag,(a11 + rs_a*2)->real, + (a11 + rs_a*2)->imag,(a11 + rs_a*2)->real, + (a11 + rs_a*2)->imag,(a11 + rs_a*2)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -48805,7 +49889,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm12 = _mm256_sub_ps(ymm12,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_set_ps((d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm10) @@ -48816,7 +49903,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB a11 += cs_a; - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*2) ); + ymm2 = _mm256_set_ps((a11 + rs_a*2)->imag,(a11 + rs_a*2)->real, + (a11 + rs_a*2)->imag,(a11 + rs_a*2)->real, + (a11 + rs_a*2)->imag,(a11 + rs_a*2)->real, + (a11 + rs_a*2)->imag,(a11 + rs_a*2)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -48832,7 +49922,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm12 = _mm256_sub_ps(ymm12,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); + ymm1 = _mm256_set_ps((d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real, + (d11_pack + 2)->imag,(d11_pack + 2)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION @@ -48983,7 +50076,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_TWO_DIV(ymm8, ymm9) @@ -48991,7 +50087,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB BLIS_CTRSM_MUL(ymm8) BLIS_CTRSM_MUL(ymm9) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*1) ); + ymm2 = _mm256_set_ps((a11 + rs_a*1)->imag,(a11 + rs_a*1)->real, + (a11 + rs_a*1)->imag,(a11 + rs_a*1)->real, + (a11 + rs_a*1)->imag,(a11 + rs_a*1)->real, + (a11 + rs_a*1)->imag,(a11 + rs_a*1)->real); ymm2 = 
_mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -49014,7 +50113,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm11 = _mm256_sub_ps(ymm11,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_set_ps((d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION @@ -49064,14 +50166,20 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm10 = _mm256_sub_ps(ymm19, ymm10); ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm8) #else BLIS_CTRSM_MUL(ymm8) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*1) ); + ymm2 = _mm256_set_ps((a11 + rs_a*1)->imag,(a11 + rs_a*1)->real, + (a11 + rs_a*1)->imag,(a11 + rs_a*1)->real, + (a11 + rs_a*1)->imag,(a11 + rs_a*1)->real, + (a11 + rs_a*1)->imag,(a11 + rs_a*1)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -49088,7 +50196,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm10 = _mm256_sub_ps(ymm10,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_set_ps((d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION @@ -49142,14 +50253,20 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm10 = _mm256_sub_ps(ymm19, ymm10); ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm8) #else BLIS_CTRSM_MUL(ymm8) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*1) ); + ymm2 = _mm256_set_ps((a11 + rs_a*1)->imag,(a11 + rs_a*1)->real, + (a11 + rs_a*1)->imag,(a11 + rs_a*1)->real, + (a11 + rs_a*1)->imag,(a11 + rs_a*1)->real, + (a11 + rs_a*1)->imag,(a11 + rs_a*1)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -49166,7 +50283,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm10 = _mm256_sub_ps(ymm10,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_set_ps((d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION @@ -49202,6 +50322,7 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ///GEMM implementation starts/// BLIS_CTRSM_SMALL_GEMM_2nx2m(a01,b10,cs_b,p_lda,k_iter) ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); + ymm16 = _mm256_permute_ps(ymm16, 0x44); xmm0 = _mm_loadu_ps((float const *)(b11)); @@ 
-49224,14 +50345,20 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm10 = _mm256_sub_ps(ymm19, ymm10); ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm8) #else BLIS_CTRSM_MUL(ymm8) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*1) ); + ymm2 = _mm256_set_ps((a11 + rs_a*1)->imag,(a11 + rs_a*1)->real, + (a11 + rs_a*1)->imag,(a11 + rs_a*1)->real, + (a11 + rs_a*1)->imag,(a11 + rs_a*1)->real, + (a11 + rs_a*1)->imag,(a11 + rs_a*1)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -49248,7 +50375,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm10 = _mm256_sub_ps(ymm10,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_set_ps((d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION @@ -49301,14 +50431,20 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm10 = _mm256_sub_ps(ymm19, ymm10); ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_DIV(ymm8) #else BLIS_CTRSM_MUL(ymm8) #endif - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11 + rs_a*1) ); + ymm2 = _mm256_set_ps((a11 + rs_a*1)->imag,(a11 + rs_a*1)->real, + (a11 + rs_a*1)->imag,(a11 + rs_a*1)->real, + (a11 + rs_a*1)->imag,(a11 + rs_a*1)->real, + (a11 + rs_a*1)->imag,(a11 + rs_a*1)->real); ymm2 = _mm256_permute_ps(ymm2, 0x44); if(conjtransa) { @@ -49325,7 +50461,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm16 = _mm256_fmaddsub_ps(ymm1, ymm2, ymm16); ymm10 = _mm256_sub_ps(ymm10,ymm16); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 1)); + ymm1 = _mm256_set_ps((d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real, + (d11_pack + 1)->imag,(d11_pack + 1)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION @@ -49445,7 +50584,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = _mm256_permute_ps(ymm1, 0x44); #ifndef BLIS_ENABLE_TRSM_PREINVERSION BLIS_CTRSM_TWO_DIV(ymm8, ymm9) @@ -49485,7 +50627,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); + ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real, + (d11_pack)->imag,(d11_pack)->real); ymm1 = 
_mm256_permute_ps(ymm1, 0x44);
#ifndef BLIS_ENABLE_TRSM_PREINVERSION
BLIS_CTRSM_DIV(ymm8)
@@ -49528,7 +50673,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB
ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0);
- ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack));
+ ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real,
+ (d11_pack)->imag,(d11_pack)->real,
+ (d11_pack)->imag,(d11_pack)->real,
+ (d11_pack)->imag,(d11_pack)->real);
ymm1 = _mm256_permute_ps(ymm1, 0x44);
#ifndef BLIS_ENABLE_TRSM_PREINVERSION
BLIS_CTRSM_DIV(ymm8)
@@ -49571,7 +50719,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB
ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0);
- ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack));
+ ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real,
+ (d11_pack)->imag,(d11_pack)->real,
+ (d11_pack)->imag,(d11_pack)->real,
+ (d11_pack)->imag,(d11_pack)->real);
ymm1 = _mm256_permute_ps(ymm1, 0x44);
#ifndef BLIS_ENABLE_TRSM_PREINVERSION
BLIS_CTRSM_DIV(ymm8)
@@ -49611,7 +50762,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB
ymm8 = _mm256_sub_ps(ymm19, ymm8);
ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0);
- ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack));
+ ymm1 = _mm256_set_ps((d11_pack)->imag,(d11_pack)->real,
+ (d11_pack)->imag,(d11_pack)->real,
+ (d11_pack)->imag,(d11_pack)->real,
+ (d11_pack)->imag,(d11_pack)->real);
ymm1 = _mm256_permute_ps(ymm1, 0x44);
#ifndef BLIS_ENABLE_TRSM_PREINVERSION
BLIS_CTRSM_DIV(ymm8)
@@ -49638,6 +50792,7 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB
return BLIS_SUCCESS;
}
+
/*
* Check if the TRSM small path should be taken for this
* input and threads combination

From 0696c5d58d3f4f6e4b4f203e23e6babe30f37ce9 Mon Sep 17 00:00:00 2001
From: Dipal M Zambare
Date: Fri, 23 Sep 2022 14:31:45 +0530
Subject: [PATCH 229/243] Fixed ASAN reported issues in bli_l3_packm.c
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

The local_mem_s is allocated on the stack in the inner scope of the “if”
block; however, it can be accessed through cntl_mem_p outside the if block.
This error is flagged by the address sanitizer. Fixed this issue by moving
the variable declarations to function scope.

This fix addresses the issue reported for the following libflame APIs:
geqp3, geqrf, gerq2, gerqf, gesvd, ggev, ggevx, potrf, potrs, stedc,
steqr, syevd

AMD-Internal: [CPUPL-2587]
Change-Id: I63749c7d406c7339d2b45b0488108ccd3f90a248
---
frame/3/bli_l3_packm.c | 8 ++------
1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/frame/3/bli_l3_packm.c b/frame/3/bli_l3_packm.c
index d6efb593cc..1492313309 100644
--- a/frame/3/bli_l3_packm.c
+++ b/frame/3/bli_l3_packm.c
@@ -47,6 +47,8 @@ void bli_l3_packm
{
packbuf_t pack_buf_type;
mem_t* cntl_mem_p;
+ mem_t* local_mem_p;
+ mem_t local_mem_s;
siz_t size_needed;
// FGVZ: Not sure why we need this barrier, but we do.
@@ -80,9 +82,6 @@ void bli_l3_packm
// all threads in the chief's thread group.
if ( bli_mem_is_unalloc( cntl_mem_p ) )
{
- mem_t* local_mem_p;
- mem_t local_mem_s;
-
if ( bli_thread_am_ochief( thread ) )
{
#ifdef BLIS_ENABLE_MEM_TRACING
@@ -110,9 +109,6 @@ void bli_l3_packm
}
else // ( bli_mem_is_alloc( cntl_mem_p ) )
{
- mem_t* local_mem_p;
- mem_t local_mem_s;
-
// If the mem_t entry in the control tree does NOT contain a NULL
// buffer, then a block has already been acquired from the memory
// broker and cached in the control tree.
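
The scope bug fixed by the patch above is easy to reproduce outside of BLIS. Below is a minimal, self-contained C sketch (not BLIS code; the type and function names are invented for illustration) of the pattern the commit message describes: a stack object declared inside an "if" block stays reachable through a longer-lived pointer after the block ends, which AddressSanitizer flags as stack-use-after-scope, and the fix is to hoist the declaration to function scope, as the patch does for local_mem_p and local_mem_s.

/* Minimal sketch of the stack-use-after-scope pattern (hypothetical names). */
#include <stdio.h>

typedef struct { int size; } mem_sketch_t;

static void packm_sketch( int needs_alloc )
{
    mem_sketch_t* mem_p = NULL; /* pointer that outlives the if block */
    mem_sketch_t  mem_s;        /* fix: backing object declared at function scope */

    if ( needs_alloc )
    {
        /* The buggy version declared 'mem_sketch_t mem_s;' here instead;
           its storage ends with the if block, yet mem_p still points at it,
           which ASAN reports as stack-use-after-scope on the later read. */
        mem_s.size = 42;
        mem_p = &mem_s;
    }

    if ( mem_p != NULL )
        printf( "size = %d\n", mem_p->size ); /* safe once mem_s is hoisted */
}

int main( void )
{
    packm_sketch( 1 );
    return 0;
}

With the declaration hoisted, the object's lifetime matches that of the pointer through which it is later read, which is the same change made to bli_l3_packm.c above.
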
From 67712db5dc8e0a62a9a3466c52a011c9de6912b1 Mon Sep 17 00:00:00 2001
From: Mangala V
Date: Fri, 30 Sep 2022 08:50:34 -0400
Subject: [PATCH 230/243] Fixed ASAN reported issues in bli_dgemmsup_rd_haswell_asm_6x8m

Address sanitizer reports an error when the rbp register is modified.
Register rbp was storing rs_a, which was used during the prefetch of
Matrix A. Usage of rbp is avoided by using the rcx register as a temporary
storage register. Hence rcx is updated with the Matrix C address before
storing the computed data.

This fix addresses the issue reported by the GEQP3 API of libflame.

AMD-Internal: [CPUPL-2587]
Change-Id: Ica790259010d8e71528c3d0ab1cd49069c56fc1d
---
.../3/sup/bli_gemmsup_rd_haswell_asm_d6x8m.c | 22 ++++++++++---------
1 file changed, 12 insertions(+), 10 deletions(-)

diff --git a/kernels/haswell/3/sup/bli_gemmsup_rd_haswell_asm_d6x8m.c b/kernels/haswell/3/sup/bli_gemmsup_rd_haswell_asm_d6x8m.c
index 2f25755ef4..990358db8b 100644
--- a/kernels/haswell/3/sup/bli_gemmsup_rd_haswell_asm_d6x8m.c
+++ b/kernels/haswell/3/sup/bli_gemmsup_rd_haswell_asm_d6x8m.c
@@ -258,7 +258,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8m
prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c
prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c
#endif
- lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a
+ lea(mem(r8, r8, 4), rcx) // rcx = 5*rs_a
@@ -277,7 +277,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8m
#if 1
prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a
prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a
- prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a
+ prefetch(0, mem(rax, rcx, 1, 0*8)) // prefetch rax + 5*rs_a
#endif
vmovupd(mem(rax ), ymm0)
@@ -341,7 +341,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8m
#if 1
prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a
prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a
- prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a
+ prefetch(0, mem(rax, rcx, 1, 0*8)) // prefetch rax + 5*rs_a
#endif
vmovupd(mem(rax ), ymm0)
@@ -423,7 +423,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8m
#if 1
prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a
prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a
- prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a
+ prefetch(0, mem(rax, rcx, 1, 0*8)) // prefetch rax + 5*rs_a
#endif
vmovupd(mem(rax ), ymm0)
@@ -560,6 +560,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8m
//mov(var(rs_c), rdi) // load rs_c
//lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double)
+ lea(mem(r12), rcx) // rcx = c_iijj;
mov(var(alpha), rax) // load address of alpha
mov(var(beta), rbx) // load address of beta
vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate
@@ -677,7 +678,7 @@ void bli_dgemmsup_rd_haswell_asm_6x8m
[a_next] "m" (a_next), [b_next] "m" (b_next)*/
: // register clobber list
- "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp",
+ "rax", "rbx", "rcx", "rdx", "rsi", "rdi",
"r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15",
"xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7",
@@ -12950,7 +12951,7 @@ void bli_dgemmsup_rd_haswell_asm_6x4m
prefetch(0, mem(rcx, rdi, 1, 3*8)) // prefetch c + 1*rs_c
prefetch(0, mem(rcx, rdi, 2, 3*8)) // prefetch c + 2*rs_c
#endif
- lea(mem(r8, r8, 4), rbp) // rbp = 5*rs_a
+ lea(mem(r8, r8, 4), rcx) // rcx = 5*rs_a
@@ -12969,7 +12970,7 @@ void bli_dgemmsup_rd_haswell_asm_6x4m
#if 1
prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a
prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a
- prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a
+
prefetch(0, mem(rax, rcx, 1, 0*8)) // prefetch rax + 5*rs_a #endif vmovupd(mem(rax ), ymm0) @@ -13033,7 +13034,7 @@ void bli_dgemmsup_rd_haswell_asm_6x4m #if 1 prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a - prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + prefetch(0, mem(rax, rcx, 1, 0*8)) // prefetch rax + 5*rs_a #endif vmovupd(mem(rax ), ymm0) @@ -13115,7 +13116,7 @@ void bli_dgemmsup_rd_haswell_asm_6x4m #if 1 prefetch(0, mem(rax, r10, 1, 0*8)) // prefetch rax + 3*rs_a prefetch(0, mem(rax, r8, 4, 0*8)) // prefetch rax + 4*rs_a - prefetch(0, mem(rax, rbp, 1, 0*8)) // prefetch rax + 5*rs_a + prefetch(0, mem(rax, rcx, 1, 0*8)) // prefetch rax + 5*rs_a #endif vmovupd(mem(rax ), ymm0) @@ -13251,6 +13252,7 @@ void bli_dgemmsup_rd_haswell_asm_6x4m //mov(var(rs_c), rdi) // load rs_c //lea(mem(, rdi, 8), rdi) // rs_c *= sizeof(double) + lea(mem(r12), rcx) // rcx = c + 3*ii*rs_c; mov(var(alpha), rax) // load address of alpha mov(var(beta), rbx) // load address of beta vbroadcastsd(mem(rax), ymm0) // load alpha and duplicate @@ -13361,7 +13363,7 @@ void bli_dgemmsup_rd_haswell_asm_6x4m [a_next] "m" (a_next), [b_next] "m" (b_next)*/ : // register clobber list - "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "rbp", + "rax", "rbx", "rcx", "rdx", "rsi", "rdi", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "r15", "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7", From 1e66d8e2d5d47e416af321efe1a03fa11a4b93cf Mon Sep 17 00:00:00 2001 From: Shubham Sharma Date: Tue, 20 Sep 2022 12:06:43 +0000 Subject: [PATCH 231/243] DTRSM Native path AVX 512 kernels are developed - Added DTRSM AVX512 kernels for lower and upper variants in the native path. - Changes in framework are made to accommodate these kernels. 
AMD-Internal: [CPUPL-2588] Change-Id: I1f74273ef2389018343c0645870290373ce25efe --- CMakeLists.txt | 2 + config/zen4/bli_cntx_init_zen4.c | 14 +- frame/3/trsm/bli_trsm_front.c | 4 +- frame/3/trsm/bli_trsm_ll_ker_var2.c | 2 +- frame/3/trsm/bli_trsm_lu_ker_var2.c | 2 +- frame/3/trsm/bli_trsm_rl_ker_var2.c | 2 +- frame/3/trsm/bli_trsm_ru_ker_var2.c | 2 +- frame/include/bli_x86_asm_macros.h | 4 +- kernels/zen4/3/CMakeLists.txt | 7 + kernels/zen4/3/bli_gemmtrsm_l_zen_16x14.c | 1669 ++++++++++++++++++++ kernels/zen4/3/bli_gemmtrsm_u_zen_16x14.c | 1706 +++++++++++++++++++++ kernels/zen4/CMakeLists.txt | 1 + kernels/zen4/bli_kernels_zen4.h | 3 + 13 files changed, 3404 insertions(+), 14 deletions(-) create mode 100644 kernels/zen4/3/CMakeLists.txt create mode 100644 kernels/zen4/3/bli_gemmtrsm_l_zen_16x14.c create mode 100644 kernels/zen4/3/bli_gemmtrsm_u_zen_16x14.c diff --git a/CMakeLists.txt b/CMakeLists.txt index 5b01a8f4e4..2752df7e68 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -337,6 +337,8 @@ endif() if(${TARGET_ARCH} STREQUAL zen4 OR ${TARGET_ARCH} STREQUAL amdzen) set_source_files_properties(${CMAKE_CURRENT_SOURCE_DIR}/kernels/zen4/1/bli_amaxv_zen_int_avx512.c PROPERTIES COMPILE_FLAGS /arch:AVX512) + set_source_files_properties(${CMAKE_CURRENT_SOURCE_DIR}/kernels/zen4/3/bli_gemmtrsm_l_zen_16x14.c PROPERTIES COMPILE_FLAGS /arch:AVX512) + set_source_files_properties(${CMAKE_CURRENT_SOURCE_DIR}/kernels/zen4/3/bli_gemmtrsm_u_zen_16x14.c PROPERTIES COMPILE_FLAGS /arch:AVX512) set_source_files_properties(${CMAKE_CURRENT_SOURCE_DIR}/kernels/skx/3/bli_dgemm_skx_asm_16x14.c PROPERTIES COMPILE_FLAGS /arch:AVX512) set_source_files_properties(${CMAKE_CURRENT_SOURCE_DIR}/kernels/skx/3/bli_sgemm_skx_asm_32x12_l2.c PROPERTIES COMPILE_FLAGS /arch:AVX512) endif() diff --git a/config/zen4/bli_cntx_init_zen4.c b/config/zen4/bli_cntx_init_zen4.c index 1de13061b2..1d9bf3cc05 100644 --- a/config/zen4/bli_cntx_init_zen4.c +++ b/config/zen4/bli_cntx_init_zen4.c @@ -77,10 +77,10 @@ void bli_cntx_init_zen4( cntx_t* cntx ) // gemmtrsm_l BLIS_GEMMTRSM_L_UKR, BLIS_FLOAT, bli_sgemmtrsm_l_haswell_asm_6x16, TRUE, - BLIS_GEMMTRSM_L_UKR, BLIS_DOUBLE, bli_dgemmtrsm_l_haswell_asm_6x8, TRUE, + BLIS_GEMMTRSM_L_UKR, BLIS_DOUBLE, bli_dgemmtrsm_l_zen_asm_16x14, TRUE, // gemmtrsm_u BLIS_GEMMTRSM_U_UKR, BLIS_FLOAT, bli_sgemmtrsm_u_haswell_asm_6x16, TRUE, - BLIS_GEMMTRSM_U_UKR, BLIS_DOUBLE, bli_dgemmtrsm_u_haswell_asm_6x8, TRUE, + BLIS_GEMMTRSM_U_UKR, BLIS_DOUBLE, bli_dgemmtrsm_u_zen_asm_16x14, TRUE, cntx ); @@ -302,11 +302,11 @@ void bli_cntx_init_zen4( cntx_t* cntx ) void bli_zen4_override_trsm_blkszs (cntx_t* cntx) { blksz_t blkszs[ BLIS_NUM_BLKSZS ]; - bli_blksz_init_easy( &blkszs[ BLIS_MR ], 6, 6, 3, 3 ); - bli_blksz_init_easy( &blkszs[ BLIS_NR ], 16, 8, 8, 4 ); - bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 72, 144, 72 ); - bli_blksz_init_easy( &blkszs[ BLIS_KC ], 256, 492, 256, 256 ); - bli_blksz_init_easy( &blkszs[ BLIS_NC ], 4080, 1600, 4080, 4080 ); + bli_blksz_init_easy( &blkszs[ BLIS_MR ], 6, 16, 3, 3 ); + bli_blksz_init_easy( &blkszs[ BLIS_NR ], 16, 14, 8, 4 ); + bli_blksz_init_easy( &blkszs[ BLIS_MC ], 144, 240, 144, 72 ); + bli_blksz_init_easy( &blkszs[ BLIS_KC ], 256, 512, 256, 256 ); + bli_blksz_init_easy( &blkszs[ BLIS_NC ], 4080, 4004, 4080, 4080 ); // Update the context with the current architecture's register and cache diff --git a/frame/3/trsm/bli_trsm_front.c b/frame/3/trsm/bli_trsm_front.c index 9eddd5c42a..35cd2d4b85 100644 --- a/frame/3/trsm/bli_trsm_front.c +++ b/frame/3/trsm/bli_trsm_front.c @@ 
-165,7 +165,7 @@ void bli_trsm_front * We need to revisit this when TRSM AVX-512 kernels are implemented. */ if ( (bli_arch_query_id() == BLIS_ARCH_ZEN4) && - (bli_obj_dt(a) == BLIS_FLOAT || bli_obj_dt(a) == BLIS_DOUBLE) ) + (bli_obj_dt(a) == BLIS_FLOAT) ) { bli_zen4_override_trsm_blkszs(cntx); } @@ -205,7 +205,7 @@ void bli_trsm_front * default block sizes are restored for the subsequent operations. */ if ( (bli_arch_query_id() == BLIS_ARCH_ZEN4) && - (bli_obj_dt(a) == BLIS_FLOAT || bli_obj_dt(a) == BLIS_DOUBLE) ) + (bli_obj_dt(a) == BLIS_FLOAT) ) { bli_zen4_restore_default_blkszs(cntx); } diff --git a/frame/3/trsm/bli_trsm_ll_ker_var2.c b/frame/3/trsm/bli_trsm_ll_ker_var2.c index fe39e6f478..a15f39fc3c 100644 --- a/frame/3/trsm/bli_trsm_ll_ker_var2.c +++ b/frame/3/trsm/bli_trsm_ll_ker_var2.c @@ -189,7 +189,7 @@ void PASTEMAC(ch,varname) \ * We need to revisit this when TRSM AVX-512 kernels are implemented. */ \ if ((bli_arch_query_id() == BLIS_ARCH_ZEN4) && \ - (dt == BLIS_FLOAT || dt == BLIS_DOUBLE)) \ + (dt == BLIS_FLOAT)) \ { \ gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_AVX2_UKR, cntx ); \ } \ diff --git a/frame/3/trsm/bli_trsm_lu_ker_var2.c b/frame/3/trsm/bli_trsm_lu_ker_var2.c index e55b75dff4..48e4588f52 100644 --- a/frame/3/trsm/bli_trsm_lu_ker_var2.c +++ b/frame/3/trsm/bli_trsm_lu_ker_var2.c @@ -189,7 +189,7 @@ void PASTEMAC(ch,varname) \ * We need to revisit this when TRSM AVX-512 kernels are implemented. */ \ if ((bli_arch_query_id() == BLIS_ARCH_ZEN4) && \ - (dt == BLIS_FLOAT || dt == BLIS_DOUBLE)) \ + (dt == BLIS_FLOAT)) \ { \ gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_AVX2_UKR, cntx ); \ } \ diff --git a/frame/3/trsm/bli_trsm_rl_ker_var2.c b/frame/3/trsm/bli_trsm_rl_ker_var2.c index 5e070a760b..2705a747ac 100644 --- a/frame/3/trsm/bli_trsm_rl_ker_var2.c +++ b/frame/3/trsm/bli_trsm_rl_ker_var2.c @@ -195,7 +195,7 @@ void PASTEMAC(ch,varname) \ * We need to revisit this when TRSM AVX-512 kernels are implemented. */ \ if ((bli_arch_query_id() == BLIS_ARCH_ZEN4) && \ - (dt == BLIS_FLOAT || dt == BLIS_DOUBLE)) \ + (dt == BLIS_FLOAT)) \ { \ gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_AVX2_UKR, cntx ); \ } \ diff --git a/frame/3/trsm/bli_trsm_ru_ker_var2.c b/frame/3/trsm/bli_trsm_ru_ker_var2.c index b592c24276..dc37614eb6 100644 --- a/frame/3/trsm/bli_trsm_ru_ker_var2.c +++ b/frame/3/trsm/bli_trsm_ru_ker_var2.c @@ -194,7 +194,7 @@ void PASTEMAC(ch,varname) \ * We need to revisit this when TRSM AVX-512 kernels are implemented. */ \ if ((bli_arch_query_id() == BLIS_ARCH_ZEN4) && \ - (dt == BLIS_FLOAT || dt == BLIS_DOUBLE)) \ + (dt == BLIS_FLOAT)) \ { \ gemm_ukr = bli_cntx_get_l3_vir_ukr_dt( dt, BLIS_GEMM_AVX2_UKR, cntx ); \ } \ diff --git a/frame/include/bli_x86_asm_macros.h b/frame/include/bli_x86_asm_macros.h index a4987b4c5f..ffb2771758 100644 --- a/frame/include/bli_x86_asm_macros.h +++ b/frame/include/bli_x86_asm_macros.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2018, The University of Texas at Austin - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2019-22, Advanced Micro Devices, Inc. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -1271,6 +1271,7 @@ #else #define KMOVW(_0, _1) INSTR_(kmovw, _0, _1) +#define KMOVQ(_0, _1) INSTR_(kmovq, _0, _1) #define JKNZD(_0, _1) INSTR_(kortestw, _0, _0) INSTR_(jnz, _1) #endif @@ -1279,6 +1280,7 @@ #define KSHIFTRW(_0, _1, _2) INSTR_(kshiftrw, _0, _1, _2) #define kmovw(_0, _1) KMOVW(_0, _1) +#define kmovq(_0, _1) KMOVQ(_0, _1) #define jknzd(_0, _1) JKNZD(_0, _1) #define kxnorw(_0, _1, _2) KXNORW(_0, _1, _2) #define kshiftrw(_0, _1, _2) KSHIFTRW(_0, _1, _2) diff --git a/kernels/zen4/3/CMakeLists.txt b/kernels/zen4/3/CMakeLists.txt new file mode 100644 index 0000000000..381204ae68 --- /dev/null +++ b/kernels/zen4/3/CMakeLists.txt @@ -0,0 +1,7 @@ +##Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.## + +target_sources("${PROJECT_NAME}" + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmtrsm_l_zen_16x14.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmtrsm_u_zen_16x14.c + ) diff --git a/kernels/zen4/3/bli_gemmtrsm_l_zen_16x14.c b/kernels/zen4/3/bli_gemmtrsm_l_zen_16x14.c new file mode 100644 index 0000000000..3680c44b05 --- /dev/null +++ b/kernels/zen4/3/bli_gemmtrsm_l_zen_16x14.c @@ -0,0 +1,1669 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + + +#include "blis.h" +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +#define A_L1_PREFETCH_DIST 12 // in units of k iterations +#define B_L1_PREFETCH_DIST 12 // e.g. 4 k iterations ~= 56 cycles +#define TAIL_NITER 5 // in units of 4x unrolled k iterations + // e.g. 
5 -> 4*5 k iterations ~= 280 cycles + +#define PREFETCH_A_L1(n, k) \ + PREFETCH(0, MEM(RAX, A_L1_PREFETCH_DIST*16*8 + (2*n+k)*64)) +#define PREFETCH_B_L1(n, k) \ + PREFETCH(0, MEM(RBX, B_L1_PREFETCH_DIST*14*8 + (2*n+k)*56)) + +#define LOOP_ALIGN ALIGN32 + +#define SUBITER(n) \ +\ + PREFETCH_A_L1(n, 0) \ + \ + VBROADCASTSD(MEM(RBX, (14*n + 0)*8), ZMM(2)) \ + VBROADCASTSD(MEM(RBX, (14*n + 1)*8), ZMM(3)) \ + VFMADD231PD(ZMM(0), ZMM(2), ZMM(4)) \ + VFMADD231PD(ZMM(1), ZMM(2), ZMM(5)) \ + VFMADD231PD(ZMM(0), ZMM(3), ZMM(6)) \ + VFMADD231PD(ZMM(1), ZMM(3), ZMM(7)) \ + \ + VBROADCASTSD(MEM(RBX, (14*n + 2)*8), ZMM(2)) \ + VBROADCASTSD(MEM(RBX, (14*n + 3)*8), ZMM(3)) \ + VFMADD231PD(ZMM(0), ZMM(2), ZMM(8) ) \ + VFMADD231PD(ZMM(1), ZMM(2), ZMM(9) ) \ + VFMADD231PD(ZMM(0), ZMM(3), ZMM(10)) \ + VFMADD231PD(ZMM(1), ZMM(3), ZMM(11)) \ + \ + PREFETCH_B_L1(n, 0) \ + \ + VBROADCASTSD(MEM(RBX, (14*n + 4)*8), ZMM(2)) \ + VBROADCASTSD(MEM(RBX, (14*n + 5)*8), ZMM(3)) \ + VFMADD231PD(ZMM(0), ZMM(2), ZMM(12)) \ + VFMADD231PD(ZMM(1), ZMM(2), ZMM(13)) \ + VFMADD231PD(ZMM(0), ZMM(3), ZMM(14)) \ + VFMADD231PD(ZMM(1), ZMM(3), ZMM(15)) \ + \ + VBROADCASTSD(MEM(RBX, (14*n + 6)*8), ZMM(2)) \ + VBROADCASTSD(MEM(RBX, (14*n + 7)*8), ZMM(3)) \ + VFMADD231PD(ZMM(0), ZMM(2), ZMM(16)) \ + VFMADD231PD(ZMM(1), ZMM(2), ZMM(17)) \ + VFMADD231PD(ZMM(0), ZMM(3), ZMM(18)) \ + VFMADD231PD(ZMM(1), ZMM(3), ZMM(19)) \ + \ + PREFETCH_A_L1(n, 1) \ + \ + VBROADCASTSD(MEM(RBX, (14*n + 8)*8), ZMM(2)) \ + VBROADCASTSD(MEM(RBX, (14*n + 9)*8), ZMM(3)) \ + VFMADD231PD(ZMM(0), ZMM(2), ZMM(20)) \ + VFMADD231PD(ZMM(1), ZMM(2), ZMM(21)) \ + VFMADD231PD(ZMM(0), ZMM(3), ZMM(22)) \ + VFMADD231PD(ZMM(1), ZMM(3), ZMM(23)) \ + \ + VBROADCASTSD(MEM(RBX, (14*n + 10)*8), ZMM(2)) \ + VBROADCASTSD(MEM(RBX, (14*n + 11)*8), ZMM(3)) \ + VFMADD231PD(ZMM(0), ZMM(2), ZMM(24)) \ + VFMADD231PD(ZMM(1), ZMM(2), ZMM(25)) \ + VFMADD231PD(ZMM(0), ZMM(3), ZMM(26)) \ + VFMADD231PD(ZMM(1), ZMM(3), ZMM(27)) \ + \ + PREFETCH_B_L1(n, 1) \ + \ + VBROADCASTSD(MEM(RBX, (14*n + 12)*8), ZMM(2)) \ + VBROADCASTSD(MEM(RBX, (14*n + 13)*8), ZMM(3)) \ + VFMADD231PD(ZMM(0), ZMM(2), ZMM(28)) \ + VFMADD231PD(ZMM(1), ZMM(2), ZMM(29)) \ + VFMADD231PD(ZMM(0), ZMM(3), ZMM(30)) \ + VFMADD231PD(ZMM(1), ZMM(3), ZMM(31)) \ + \ + VMOVAPD(MEM(RAX,((n*2)+2)*8*8), ZMM(0)) \ + VMOVAPD(MEM(RAX,((n*2)+3)*8*8), ZMM(1)) + +#define UPDATE_C_COL_SCATTERED(R1,R2) \ +\ + KXNORW(K(0), K(0), K(1)) \ + KXNORW(K(0), K(0), K(2)) \ + VSCATTERQPD(ZMM(R1), MEM(RCX,ZMM(2),1) MASK_K(1)) \ + VSCATTERQPD(ZMM(R2), MEM(R14,ZMM(2),1) MASK_K(2)) \ + ADD(RDI, RCX) \ + ADD(RDI, R14) \ + +/* scatter only first 6 elements of r1 and r2 */ +#define UPDATE_C_COL_SCATTERED_2x6(R1,R2) \ +\ + KXNORW(K(0), K(0), K(1)) \ + KXNORW(K(0), K(0), K(2)) \ + MOVQ(IMM(0b00111111), RAX) \ + KMOVQ(RAX, K(2)) \ + KMOVQ(RAX, K(1)) \ + VSCATTERQPD(ZMM(R1), MEM(RCX,ZMM(2),1) MASK_K(1)) \ + VSCATTERQPD(ZMM(R2), MEM(R14,ZMM(2),1) MASK_K(2)) \ + ADD(RDI, RCX) \ + ADD(RDI, R14) \ + +/* +Transpose 8 zmm registers and store the output in the given 8 registers + Note: Requires offsetPointer for scatter instruction + and 512 bytes of free memory (rcx) for transpose. 
+ Input : + R1 = [ 0, 1, 2, 3, 4, 5, 6, 7] + R2 = [ 8, 9, 10, 11, 12, 13, 14, 15] + R3 = [16, 17, 18, 19, 20, 21, 22, 23] + R4 = [24, 25, 26, 27, 28, 29, 30, 31] + R5 = [32, 33, 34, 35, 36, 37, 38, 39] + R6 = [40, 41, 42, 43, 44, 45, 46, 47] + R7 = [48, 49, 50, 51, 52, 53, 54, 55] + R18= [56, 57, 58, 59, 60, 61, 62, 63] + Output : + R1 = [0, 8, 16, 24, 32, 40, 48, 56] + R2 = [1, 9, 17, 25, 33, 41, 49, 57] + R3 = [2, 10, 18, 26, 34, 42, 50, 58] + R4 = [3, 11, 19, 27, 35, 43, 51, 59] + R5 = [4, 12, 20, 28, 36, 44, 52, 60] + R6 = [5, 13, 21, 29, 37, 45, 53, 61] + R7 = [6, 14, 22, 30, 38, 46, 54, 62] + R18= [7, 15, 23, 31, 39, 47, 55, 63] +*/ +#define TRANSPOSE_REGISTERS_8x8(R1, R2, R3, R4, R5, R6, R7, R18) \ +\ + MOV(R8, RCX) \ + MOV(VAR(cs_c), RSI) \ + MOV(R9, RDI) \ + LEA(MEM(RCX, RSI, 8), RDX) \ + MOV(VAR(offsetPtr), R13) \ + MOV(RDI, R12) \ + CMP(RSI, R12) \ + CMOVL(RSI, R12) \ + VPBROADCASTQ(R12, ZMM(0)) \ + VPMULLQ(MEM(R13 ), ZMM(0), ZMM(2)) \ + \ + KXNORW(K(0), K(0), K(1)) \ + KXNORW(K(0), K(0), K(2)) \ + KXNORW(K(0), K(0), K(3)) \ + KXNORW(K(0), K(0), K(4)) \ + VSCATTERQPD(ZMM(R1), MEM(RCX,ZMM(2),1) MASK_K(1)) \ + ADD(IMM(1*8), RCX) \ + KXNORW(K(0), K(0), K(1)) \ + VSCATTERQPD(ZMM(R2), MEM(RCX,ZMM(2),1) MASK_K(2)) \ + ADD(IMM(1*8), RCX) \ + KXNORW(K(0), K(0), K(2)) \ + VSCATTERQPD(ZMM(R3), MEM(RCX,ZMM(2),1) MASK_K(3)) \ + ADD(IMM(1*8), RCX) \ + KXNORW(K(0), K(0), K(3)) \ + VSCATTERQPD(ZMM(R4), MEM(RCX,ZMM(2),1) MASK_K(4)) \ + ADD(IMM(1*8), RCX) \ + KXNORW(K(0), K(0), K(4)) \ + VSCATTERQPD(ZMM(R5), MEM(RCX,ZMM(2),1) MASK_K(1)) \ + ADD(IMM(1*8), RCX) \ + KXNORW(K(0), K(0), K(1)) \ + VSCATTERQPD(ZMM(R6), MEM(RCX,ZMM(2),1) MASK_K(2)) \ + ADD(IMM(1*8), RCX) \ + KXNORW(K(0), K(0), K(2)) \ + VSCATTERQPD(ZMM(R7), MEM(RCX,ZMM(2),1) MASK_K(3)) \ + ADD(IMM(1*8), RCX) \ + KXNORW(K(0), K(0), K(3)) \ + VSCATTERQPD(ZMM(R18), MEM(RCX,ZMM(2),1) MASK_K(4)) \ + \ + MOV(R8, RCX) \ + \ + VMOVUPD(MEM(RCX), ZMM(R1))\ + ADD(R12, RCX) \ + VMOVUPD(MEM(RCX), ZMM(R2)) \ + ADD(R12, RCX) \ + VMOVUPD(MEM(RCX), ZMM(R3)) \ + ADD(R12, RCX) \ + VMOVUPD(MEM(RCX), ZMM(R4)) \ + ADD(R12, RCX) \ + VMOVUPD(MEM(RCX), ZMM(R5)) \ + ADD(R12, RCX) \ + VMOVUPD(MEM(RCX), ZMM(R6)) \ + ADD(R12, RCX) \ + VMOVUPD(MEM(RCX), ZMM(R7)) \ + ADD(R12, RCX) \ + VMOVUPD(MEM(RCX), ZMM(R18)) \ + +/* +Transpose six zmm registers and store the output in the given 8 registers + Note: Require offsetPointer for scatter instruction + and 512 bytes of free memory (rcx) for transpose. 
+ Input : + R1 = [ 0, 1, 2, 3, 4, 5, 6, 7] + R2 = [ 8, 9, 10, 11, 12, 13, 14, 15] + R3 = [16, 17, 18, 19, 20, 21, 22, 23] + R4 = [24, 25, 26, 27, 28, 29, 30, 31] + R5 = [32, 33, 34, 35, 36, 37, 38, 39] + R6 = [40, 41, 42, 43, 44, 45, 46, 47] + Output : + R1 = [0, 8, 16, 24, 32, 40, -, -] + R2 = [1, 9, 17, 25, 33, 41, -, -] + R3 = [2, 10, 18, 26, 34, 42, -, -] + R4 = [3, 11, 19, 27, 35, 43, -, -] + R5 = [4, 12, 20, 28, 36, 44, -, -] + R6 = [5, 13, 21, 29, 37, 45, -, -] + R7 = [6, 14, 22, 30, 38, 46, -, -] + R18 = [7, 15, 23, 31, 39, 47, -, -] +*/ +#define TRANSPOSE_REGISTERS_6x8(R1, R2, R3, R4, R5, R6, R7, R18) \ +\ + MOV(R8, RCX) \ + MOV(VAR(cs_c), RSI) \ + MOV(R9, RDI) \ + LEA(MEM(RCX, RSI, 8), RDX) \ + MOV(VAR(offsetPtr), R13) \ + MOV(RDI, R12) \ + CMP(RSI, R12) \ + CMOVL(RSI, R12) \ + VPBROADCASTQ(R12, ZMM(0)) \ + VPMULLQ(MEM(R13 ), ZMM(0), ZMM(2)) \ + LEA(MEM(RCX, R12, 4), RCX) \ + LEA(MEM(RCX, R12, 1), RCX) \ + \ + KXNORW(K(0), K(0), K(1)) \ + KXNORW(K(0), K(0), K(2)) \ + KXNORW(K(0), K(0), K(3)) \ + KXNORW(K(0), K(0), K(4)) \ + VSCATTERQPD(ZMM(R1), MEM(RCX,ZMM(2),1) MASK_K(1)) \ + ADD(IMM(1*8), RCX) \ + KXNORW(K(0), K(0), K(1)) \ + VSCATTERQPD(ZMM(R2), MEM(RCX,ZMM(2),1) MASK_K(2)) \ + ADD(IMM(1*8), RCX) \ + KXNORW(K(0), K(0), K(2)) \ + VSCATTERQPD(ZMM(R3), MEM(RCX,ZMM(2),1) MASK_K(3)) \ + ADD(IMM(1*8), RCX) \ + KXNORW(K(0), K(0), K(3)) \ + VSCATTERQPD(ZMM(R4), MEM(RCX,ZMM(2),1) MASK_K(4)) \ + ADD(IMM(1*8), RCX) \ + KXNORW(K(0), K(0), K(4)) \ + VSCATTERQPD(ZMM(R5), MEM(RCX,ZMM(2),1) MASK_K(1)) \ + ADD(IMM(1*8), RCX) \ + KXNORW(K(0), K(0), K(1)) \ + VSCATTERQPD(ZMM(R6), MEM(RCX,ZMM(2),1) MASK_K(2)) \ + \ + MOV(R8, RCX) \ + LEA(MEM(RCX, R12, 4), RCX) \ + LEA(MEM(RCX, R12, 1), RCX) \ + \ + VMOVUPD(MEM(RCX), ZMM(R1))\ + ADD(R12, RCX) \ + VMOVUPD(MEM(RCX), ZMM(R2)) \ + ADD(R12, RCX) \ + VMOVUPD(MEM(RCX), ZMM(R3)) \ + ADD(R12, RCX) \ + VMOVUPD(MEM(RCX), ZMM(R4)) \ + ADD(R12, RCX) \ + VMOVUPD(MEM(RCX), ZMM(R5)) \ + ADD(R12, RCX) \ + VMOVUPD(MEM(RCX), ZMM(R6)) \ + ADD(R12, RCX) \ + VMOVUPD(MEM(RCX), ZMM(R7)) \ + ADD(R12, RCX) \ + VMOVUPD(MEM(RCX), ZMM(R18)) \ + +// Offsets for scatter/gather instructions +static int64_t offsets[16] __attribute__((aligned(64))) = + { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12,13,14,15}; + + +void bli_dgemmtrsm_l_zen_asm_16x14 +( + dim_t k_, + double* restrict alpha, + double* restrict a10, + double* restrict a11, + double* restrict b01, + double* restrict b11, + double* restrict c11, inc_t rs_c_, inc_t cs_c_, + auxinfo_t* restrict data, + cntx_t* restrict cntx +) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_9); + const int64_t k = k_; + uint64_t rs_c = rs_c_ * 8; + const int64_t* offsetPtr = &offsets[0]; + uint64_t cs_c = cs_c_ * 8; + + BEGIN_ASM() + + //clear out registers + VXORPD(ZMM(4), ZMM(4), ZMM(4)) + VMOVAPD(ZMM(4), ZMM(5) ) + VMOVAPD(ZMM(4), ZMM(6) ) + VMOVAPD(ZMM(4), ZMM(7) ) + VMOVAPD(ZMM(4), ZMM(8) ) + VMOVAPD(ZMM(4), ZMM(9) ) + VMOVAPD(ZMM(4), ZMM(10)) + VMOVAPD(ZMM(4), ZMM(11)) + VMOVAPD(ZMM(4), ZMM(12)) + VMOVAPD(ZMM(4), ZMM(13)) + VMOVAPD(ZMM(4), ZMM(14)) + VMOVAPD(ZMM(4), ZMM(15)) + VMOVAPD(ZMM(4), ZMM(16)) + VMOVAPD(ZMM(4), ZMM(17)) + VMOVAPD(ZMM(4), ZMM(18)) + VMOVAPD(ZMM(4), ZMM(19)) + VMOVAPD(ZMM(4), ZMM(20)) + VMOVAPD(ZMM(4), ZMM(21)) + VMOVAPD(ZMM(4), ZMM(22)) + VMOVAPD(ZMM(4), ZMM(23)) + VMOVAPD(ZMM(4), ZMM(24)) + VMOVAPD(ZMM(4), ZMM(25)) + VMOVAPD(ZMM(4), ZMM(26)) + VMOVAPD(ZMM(4), ZMM(27)) + VMOVAPD(ZMM(4), ZMM(28)) + VMOVAPD(ZMM(4), ZMM(29)) + VMOVAPD(ZMM(4), ZMM(30)) + VMOVAPD(ZMM(4), ZMM(31)) + + MOV(VAR(k), RSI) + + MOV(VAR(a10), 
RAX) // load address of a + MOV(VAR(b01), RBX) // load address of b + MOV(VAR(c11), R8) // load address of c + + LEA(MEM(RSI,RSI,2), RDX) + LEA(MEM(,RDX,4), RDX) + LEA(MEM(RDX,RSI,4), RDX) // 16 * K + LEA(MEM(RAX,RDX,8,-128), RDX) // a_next + LEA(MEM(R8,63), R12) // c for prefetching + + MOV(IMM(14), RDI) + LEA(MEM(, RDI, 8), RDI) + + MOV(VAR(rs_c), R9) + MOV(VAR(cs_c), R13) + + MOV(IMM(0), R11) + MOV(R13, R15) + + CMP(IMM(8), R13) + JNE(.DBEFORELOOP) + MOV(IMM(2), R11) + MOV(R9, R15) + + LABEL(.DBEFORELOOP) + + VMOVAPD(MEM(RAX, 0*8*8), ZMM(0)) + VMOVAPD(MEM(RAX, 1*8*8), ZMM(1)) // preload a + + MOV(RSI, R10) + AND(IMM(3), R10) // R10 = K % 4 + SAR(IMM(2), RSI) // RSI = K / 4 + + /* + MAIN LOOP + Note: This loop runs (K/4 - 14 - TAIL_NITER) times + */ + SUB(R11, RSI) + SUB(IMM(14+TAIL_NITER), RSI) + JLE(K_LE_80) + + LOOP_ALIGN + LABEL(LOOP1) + SUBITER(0) + PREFETCH(1, MEM(RDX)) + SUBITER(1) + SUB(IMM(1), RSI) + SUBITER(2) + PREFETCH(1, MEM(RDX,64)) + SUBITER(3) + + LEA(MEM(RAX,4*16*8), RAX) + LEA(MEM(RBX,4*14*8), RBX) + LEA(MEM(RDX,16*8), RDX) + + JNZ(LOOP1) + + LABEL(K_LE_80) + + /* + C prefetch Loop + Note: This loop runs 14 times, + These 14 iterations are done seperately so that c11 can be prefetched here. + */ + ADD(R11, RSI) + ADD(IMM(14), RSI) + JLE(K_LE_24) + + LOOP_ALIGN + LABEL(LOOP2) + PREFETCH(0, MEM(R12)) + SUBITER(0) + PREFETCH(1, MEM(RDX)) + SUBITER(1) + PREFETCH(0, MEM(R12,64)) + SUB(IMM(1), RSI) + SUBITER(2) + PREFETCH(1, MEM(RDX,64)) + SUBITER(3) + + LEA(MEM(RAX,4*16*8), RAX) + LEA(MEM(RBX,4*14*8), RBX) + LEA(MEM(RDX,16*8), RDX) + LEA(MEM(R12,R15,1), R12) + + JNZ(LOOP2) + + LABEL(K_LE_24) + + /* + TAIL_NITER Loop + Note: This loop runs TAIL_NITER times, + This loop is used to provide some distance between c11 prefetch and usage of c11. + */ + ADD(IMM(0+TAIL_NITER), RSI) + JLE(TAIL) + + LOOP_ALIGN + LABEL(LOOP3) + + SUBITER(0) + PREFETCH(1, MEM(RDX)) + SUBITER(1) + SUB(IMM(1), RSI) + SUBITER(2) + PREFETCH(1, MEM(RDX,64)) + SUBITER(3) + + LEA(MEM(RAX,4*16*8), RAX) + LEA(MEM(RBX,4*14*8), RBX) + LEA(MEM(RDX,16*8), RDX) + + JNZ(LOOP3) + + /* + K Left Loop + This loop runs K % 4 times. 
+ */ + LABEL(TAIL) + MOV(R10, RSI) + TEST(RSI, RSI) + JE(.DPOSTACCUM) + LOOP_ALIGN + LABEL(TAIL_LOOP) + + SUB(IMM(1), RSI) + SUBITER(0) + + LEA(MEM(RAX,16*8), RAX) + LEA(MEM(RBX,14*8), RBX) + + JNZ(TAIL_LOOP) + + LABEL(.DPOSTACCUM) + /* GEMM output before transpose GEMM output after transpose + __________________________________ + ___________________________ |______zmm4______|______zmm20___x x| + | | | | | | | | | | | | | | | |______zmm6______|______zmm22___x x| + |z|z|z|z|z|z|z|z|z|z|z|z|z|z| |______zmm8______|______zmm24___x x| + |m|m|m|m|m|m|m|m|m|m|m|m|m|m| |______zmm10_____|______zmm26___x x| + |m|m|m|m|m|m|m|m|m|m|m|m|m|m| |______zmm12_____|______zmm28___x x| + |4|6|8|1|1|1|1|1|2|2|2|2|2|3| |______zmm14_____|______zmm30___x x| + | | | |0|2|4|6|8|0|2|4|6|8|0| |______zmm16_____|_____c11______x x| + | | | | | | | | | | | | | | | |______zmm18_____|_____c11+cs___x x| + ____________________________ |______zmm5______|______zmm21___x x| + | | | | | | | | | | | | | | | |______zmm7______|______zmm23___x x| + |z|z|z|z|z|z|z|z|z|z|z|z|z|z| |______zmm9______|______zmm25___x x| + |m|m|m|m|m|m|m|m|m|m|m|m|m|m| |______zmm11_____|______zmm27___x x| + |m|m|m|m|m|m|m|m|m|m|m|m|m|m| |______zmm13_____|______zmm29___x x| + |5|7|9|1|1|1|1|1|2|2|2|2|2|3| |______zmm15_____|______zmm31___x x| + | | | |1|3|5|7|9|1|3|5|7|9|1| |______zmm17_____|____c11+cs*2__x x| + | | | | | | | | | | | | | | | |______zmm19_____|____c11+cs*4__x x| + _____________________________ + */ + TRANSPOSE_REGISTERS_8x8(4, 6, 8, 10, 12, 14, 16, 18) // transpose the output of GEMM + TRANSPOSE_REGISTERS_8x8(5, 7, 9, 11, 13, 15, 17, 19) + TRANSPOSE_REGISTERS_6x8(20, 22, 24, 26, 28, 30, 0, 1) + VMOVUPD(ZMM(0), MEM(R8 )) + VMOVUPD(ZMM(1), MEM(R8, R12, 1)) // zmm0 and zmm1 are needed for other computations, + // therefore store zmm0, zmm1 's data in rcx + TRANSPOSE_REGISTERS_6x8(21, 23, 25, 27, 29, 31, 0, 1) + VMOVUPD(ZMM(0), MEM(R8, R12, 2)) + VMOVUPD(ZMM(1), MEM(R8, R12, 4)) // zmm0 and zmm1 are needed for other computations, + // therefore store zmm0, zmm1 's data in rcx + MOV(IMM(14), RDI) + LEA(MEM(, RDI, 8), RDI) + + MOV(VAR(alpha), RBX) + VBROADCASTSD(MEM(RBX), ZMM(3)) + + MOV(IMM(1), RSI) + LEA(MEM(, RSI, 8), RSI) + + MOV(VAR(b11), RCX) + LEA(MEM(RCX, RSI, 8), RDX) + + MOV(RCX, R11) + MOV(RDX, R14) + + // Scale by Alpha + VFMSUB231PD(MEM(RCX), ZMM(3), ZMM(4)) + ADD(RDI, RCX) + VFMSUB231PD(MEM(RCX), ZMM(3), ZMM(6)) + ADD(RDI, RCX) + VFMSUB231PD(MEM(RCX), ZMM(3), ZMM(8)) + ADD(RDI, RCX) + VFMSUB231PD(MEM(RCX), ZMM(3), ZMM(10)) + ADD(RDI, RCX) + VFMSUB231PD(MEM(RCX), ZMM(3), ZMM(12)) + ADD(RDI, RCX) + VFMSUB231PD(MEM(RCX), ZMM(3), ZMM(14)) + ADD(RDI, RCX) + VFMSUB231PD(MEM(RCX), ZMM(3), ZMM(16)) + ADD(RDI, RCX) + VFMSUB231PD(MEM(RCX), ZMM(3), ZMM(18)) + ADD(RDI, RCX) + + VFMSUB231PD(MEM(RCX), ZMM(3), ZMM(5)) + ADD(RDI, RCX) + VFMSUB231PD(MEM(RCX), ZMM(3), ZMM(7)) + ADD(RDI, RCX) + VFMSUB231PD(MEM(RCX), ZMM(3), ZMM(9)) + ADD(RDI, RCX) + VFMSUB231PD(MEM(RCX), ZMM(3), ZMM(11)) + ADD(RDI, RCX) + VFMSUB231PD(MEM(RCX), ZMM(3), ZMM(13)) + ADD(RDI, RCX) + VFMSUB231PD(MEM(RCX), ZMM(3), ZMM(15)) + ADD(RDI, RCX) + VFMSUB231PD(MEM(RCX), ZMM(3), ZMM(17)) + ADD(RDI, RCX) + VFMSUB231PD(MEM(RCX), ZMM(3), ZMM(19)) + + + + VFMSUB231PD(MEM(RDX), ZMM(3), ZMM(20)) + ADD(RDI, RDX) + VFMSUB231PD(MEM(RDX), ZMM(3), ZMM(22)) + ADD(RDI, RDX) + VFMSUB231PD(MEM(RDX), ZMM(3), ZMM(24)) + ADD(RDI, RDX) + VFMSUB231PD(MEM(RDX), ZMM(3), ZMM(26)) + ADD(RDI, RDX) + VFMSUB231PD(MEM(RDX), ZMM(3), ZMM(28)) + ADD(RDI, RDX) + VFMSUB231PD(MEM(RDX), ZMM(3), ZMM(30)) + ADD(RDI, RDX) + 
VMOVUPD(MEM(R8 ), ZMM(0)) + VMOVUPD(MEM(R8, R12, 1), ZMM(1)) + VFMSUB231PD(MEM(RDX), ZMM(3), ZMM(0)) + VMOVUPD(ZMM(0), MEM(R8 )) + ADD(RDI, RDX) + VFMSUB231PD(MEM(RDX), ZMM(3), ZMM(1)) + VMOVUPD(ZMM(1), MEM(R8, R12, 1)) + ADD(RDI, RDX) + + VFMSUB231PD(MEM(RDX), ZMM(3), ZMM(21)) + ADD(RDI, RDX) + VFMSUB231PD(MEM(RDX), ZMM(3), ZMM(23)) + ADD(RDI, RDX) + VFMSUB231PD(MEM(RDX), ZMM(3), ZMM(25)) + ADD(RDI, RDX) + VFMSUB231PD(MEM(RDX), ZMM(3), ZMM(27)) + ADD(RDI, RDX) + VFMSUB231PD(MEM(RDX), ZMM(3), ZMM(29)) + ADD(RDI, RDX) + VFMSUB231PD(MEM(RDX), ZMM(3), ZMM(31)) + ADD(RDI, RDX) + VMOVUPD(MEM(R8, R12, 2), ZMM(0)) + VMOVUPD(MEM(R8, R12, 4), ZMM(1)) + VFMSUB231PD(MEM(RDX), ZMM(3), ZMM(0)) + VMOVUPD(ZMM(0), MEM(R8, R12, 2)) + ADD(RDI, RDX) + VFMSUB231PD(MEM(RDX), ZMM(3), ZMM(1)) + VMOVUPD(ZMM(1), MEM(R8, R12, 4)) + + /* + TRSM region + Each row requires 1 iteration, therefore 16 iterations are present + */ + MOV(VAR(a11), RAX) + MOV(R11, RCX) + MOV(R14, RDX) + + + //iteration 0 -------------------------------------------- + VBROADCASTSD(MEM(RAX, (0+0*16)*8), ZMM(0)) + + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + VMULPD(ZMM(0), ZMM(4) , ZMM(4) ) + VMULPD(ZMM(0), ZMM(20), ZMM(20)) + #else + VDIVPD(ZMM(0), ZMM(4) , ZMM(4) ) + VDIVPD(ZMM(0), ZMM(20), ZMM(20)) + #endif + + VMOVUPD(ZMM(4), MEM(RCX)) + VEXTRACTF64X4(IMM(1), ZMM(20), YMM(0)) + VMOVUPD(YMM(20), MEM(RDX )) + VMOVUPD(XMM(0) , MEM(RDX,4*8)) // move only first six values to rcx + ADD(RDI, RCX) + ADD(RDI, RDX) + + //iteration 1 -------------------------------------------- + VBROADCASTSD(MEM(RAX, (1+0*16)*8), ZMM(0)) + VBROADCASTSD(MEM(RAX, (1+1*16)*8), ZMM(1)) + + VMULPD(ZMM(0), ZMM(4) , ZMM(2)) + VMULPD(ZMM(0), ZMM(20), ZMM(3)) + + VSUBPD(ZMM(2), ZMM(6) , ZMM(6) ) + VSUBPD(ZMM(3), ZMM(22), ZMM(22)) + + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + VMULPD(ZMM(1), ZMM(6) , ZMM(6) ) + VMULPD(ZMM(1), ZMM(22), ZMM(22)) + #else + VDIVPD(ZMM(1), ZMM(6) , ZMM(6)) + VDIVPD(ZMM(1), ZMM(22), ZMM(22)) + #endif + + VMOVUPD(ZMM(6), MEM(RCX)) + VEXTRACTF64X4(IMM(1), ZMM(22), YMM(0)) + VMOVUPD(YMM(22), MEM(RDX )) + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + ADD(RDI, RCX) + ADD(RDI, RDX) + + //iteration 2 -------------------------------------------- + VBROADCASTSD(MEM(RAX, (2+0*16)*8), ZMM(0)) + VBROADCASTSD(MEM(RAX, (2+1*16)*8), ZMM(1)) + + VMULPD(ZMM(0), ZMM(4) , ZMM(2)) + VMULPD(ZMM(0), ZMM(20), ZMM(3)) + + + VBROADCASTSD(MEM(RAX, (2+2*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(6) , ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(22), ZMM(3)) + + VSUBPD(ZMM(2), ZMM(8) , ZMM(8) ) + VSUBPD(ZMM(3), ZMM(24), ZMM(24)) + + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + VMULPD(ZMM(0), ZMM(8) , ZMM(8) ) + VMULPD(ZMM(0), ZMM(24), ZMM(24)) + #else + VDIVPD(ZMM(0), ZMM(8) , ZMM(8) ) + VDIVPD(ZMM(0), ZMM(24), ZMM(24)) + #endif + + VMOVUPD(ZMM(8), MEM(RCX)) + VEXTRACTF64X4(IMM(1), ZMM(24), YMM(0)) + VMOVUPD(YMM(24), MEM(RDX )) + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + ADD(RDI, RCX) + ADD(RDI, RDX) + + //iteration 3 -------------------------------------------- + VBROADCASTSD(MEM(RAX, (3+0*16)*8), ZMM(0)) + VBROADCASTSD(MEM(RAX, (3+1*16)*8), ZMM(1)) + + VMULPD(ZMM(0), ZMM(4) , ZMM(2)) + VMULPD(ZMM(0), ZMM(20), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (3+2*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(6) , ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(22), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (3+3*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(8) , ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(24), ZMM(3)) + + VSUBPD(ZMM(2), ZMM(10), ZMM(10)) + VSUBPD(ZMM(3), ZMM(26), ZMM(26)) + + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + VMULPD(ZMM(1), ZMM(10), ZMM(10)) + 
VMULPD(ZMM(1), ZMM(26), ZMM(26)) + #else + VDIVPD(ZMM(1), ZMM(10), ZMM(10)) + VDIVPD(ZMM(1), ZMM(26), ZMM(26)) + #endif + + VMOVUPD(ZMM(10), MEM(RCX)) + VEXTRACTF64X4(IMM(1), ZMM(26), YMM(0)) + VMOVUPD(YMM(26), MEM(RDX )) + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + ADD(RDI, RCX) + ADD(RDI, RDX) + + //iteration 4 -------------------------------------------- + VBROADCASTSD(MEM(RAX, (4+0*16)*8), ZMM(0)) + VBROADCASTSD(MEM(RAX, (4+1*16)*8), ZMM(1)) + + VMULPD(ZMM(0), ZMM(4) , ZMM(2)) + VMULPD(ZMM(0), ZMM(20), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (4+2*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(6) , ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(22), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (4+3*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(8) , ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(24), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (4+4*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(10), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(26), ZMM(3)) + + VSUBPD(ZMM(2), ZMM(12), ZMM(12)) + VSUBPD(ZMM(3), ZMM(28), ZMM(28)) + + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + VMULPD(ZMM(0), ZMM(12), ZMM(12)) + VMULPD(ZMM(0), ZMM(28), ZMM(28)) + #else + VDIVPD(ZMM(0), ZMM(12), ZMM(12)) + VDIVPD(ZMM(0), ZMM(28), ZMM(28)) + #endif + + VMOVUPD(ZMM(12), MEM(RCX)) + VEXTRACTF64X4(IMM(1), ZMM(28), YMM(0)) + VMOVUPD(YMM(28), MEM(RDX )) + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + ADD(RDI, RCX) + ADD(RDI, RDX) + + //iteration 5 -------------------------------------------- + VBROADCASTSD(MEM(RAX, (5+0*16)*8), ZMM(0)) + VBROADCASTSD(MEM(RAX, (5+1*16)*8), ZMM(1)) + + VMULPD(ZMM(0), ZMM(4) , ZMM(2)) + VMULPD(ZMM(0), ZMM(20), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (5+2*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(6) , ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(22), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (5+3*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(8) , ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(24), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (5+4*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(10), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(26), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (5+5*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(12), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(28), ZMM(3)) + + VSUBPD(ZMM(2), ZMM(14), ZMM(14)) + VSUBPD(ZMM(3), ZMM(30), ZMM(30)) + + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + VMULPD(ZMM(1), ZMM(14), ZMM(14)) + VMULPD(ZMM(1), ZMM(30), ZMM(30)) + #else + VDIVPD(ZMM(1), ZMM(14), ZMM(14)) + VDIVPD(ZMM(1), ZMM(30), ZMM(30)) + #endif + + VMOVUPD(ZMM(14), MEM(RCX)) + VEXTRACTF64X4(IMM(1), ZMM(30), YMM(0)) + VMOVUPD(YMM(30), MEM(RDX )) + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + ADD(RDI, RCX) + ADD(RDI, RDX) + + //iteration 6 -------------------------------------------- + VBROADCASTSD(MEM(RAX, (6+0*16)*8), ZMM(0)) + VBROADCASTSD(MEM(RAX, (6+1*16)*8), ZMM(1)) + + VMULPD(ZMM(0), ZMM(4) , ZMM(2)) + VMULPD(ZMM(0), ZMM(20), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (6+2*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(6) , ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(22), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (6+3*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(8) , ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(24), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (6+4*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(10), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(26), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (6+5*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(12), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(28), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (6+6*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(14), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(30), ZMM(3)) + + VMOVUPD(MEM(R8), ZMM(1)) + VSUBPD(ZMM(2), ZMM(16), ZMM(16)) + VSUBPD(ZMM(3), ZMM(1) , ZMM(1) ) + + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + VMULPD(ZMM(0), ZMM(16), ZMM(16)) + VMULPD(ZMM(0), ZMM(1) , ZMM(1) ) + #else + 
VDIVPD(ZMM(0), ZMM(16), ZMM(16)) + VDIVPD(ZMM(0), ZMM(1) , ZMM(1) ) + #endif + + VMOVUPD(ZMM(1), MEM(R8 )) + + VMOVUPD(ZMM(16), MEM(RCX)) + VEXTRACTF64X4(IMM(1), ZMM(1), YMM(0)) + VMOVUPD(YMM(1), MEM(RDX )) + VMOVUPD(XMM(0), MEM(RDX,4*8)) + ADD(RDI, RCX) + ADD(RDI, RDX) + + //iteration 7 -------------------------------------------- + VBROADCASTSD(MEM(RAX, (7+0*16)*8), ZMM(0)) + VBROADCASTSD(MEM(RAX, (7+1*16)*8), ZMM(1)) + + VMULPD(ZMM(0), ZMM(4) , ZMM(2)) + VMULPD(ZMM(0), ZMM(20), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (7+2*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(6) , ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(22), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (7+3*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(8) , ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(24), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (7+4*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(10), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(26), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (7+5*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(12), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(28), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (7+6*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(14), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(30), ZMM(3)) + + VMOVUPD(MEM(R8 ), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(16), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(1) , ZMM(3)) + VBROADCASTSD(MEM(RAX, (7+7*16)*8), ZMM(1)) + + VMOVUPD(MEM(R8, R12, 1), ZMM(0)) + VSUBPD(ZMM(2), ZMM(18), ZMM(18)) + VSUBPD(ZMM(3), ZMM(0) , ZMM(0) ) + + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + VMULPD(ZMM(1), ZMM(18), ZMM(18)) + VMULPD(ZMM(1), ZMM(0) , ZMM(0) ) + #else + VDIVPD(ZMM(1), ZMM(18), ZMM(18)) + VDIVPD(ZMM(1), ZMM(0) , ZMM(0) ) + #endif + + VMOVUPD(ZMM(0), MEM(R8, R12, 1)) + + VMOVUPD(ZMM(18), MEM(RCX)) + VEXTRACTF64X4(IMM(1), ZMM(0), YMM(1)) + VMOVUPD(YMM(0), MEM(RDX )) + VMOVUPD(XMM(1), MEM(RDX,4*8)) + ADD(RDI, RCX) + ADD(RDI, RDX) + + //iteration 8 -------------------------------------------- + VBROADCASTSD(MEM(RAX, (8+0*16)*8), ZMM(0)) + VBROADCASTSD(MEM(RAX, (8+1*16)*8), ZMM(1)) + + VMULPD(ZMM(0), ZMM(4) , ZMM(2)) + VMULPD(ZMM(0), ZMM(20), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (8+2*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(6) , ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(22), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (8+3*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(8) , ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(24), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (8+4*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(10), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(26), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (8+5*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(12), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(28), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (8+6*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(14), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(30), ZMM(3)) + + VMOVUPD(MEM(R8 ), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(16), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(1) , ZMM(3)) + VBROADCASTSD(MEM(RAX, (8+7*16)*8), ZMM(1)) + + VMOVUPD(MEM(R8, R12, 1), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(18), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(0) , ZMM(3)) + VBROADCASTSD(MEM(RAX, (8+8*16)*8), ZMM(0)) + + VSUBPD(ZMM(2), ZMM(5) , ZMM(5) ) + VSUBPD(ZMM(3), ZMM(21), ZMM(21)) + + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + VMULPD(ZMM(0), ZMM(5), ZMM(5)) + VMULPD(ZMM(0), ZMM(21), ZMM(21)) + #else + VDIVPD(ZMM(0), ZMM(5), ZMM(5)) + VDIVPD(ZMM(0), ZMM(21), ZMM(21)) + #endif + + VMOVUPD(ZMM(5), MEM(RCX)) + VEXTRACTF64X4(IMM(1), ZMM(21), YMM(1)) + VMOVUPD(YMM(21), MEM(RDX )) + VMOVUPD(XMM(1), MEM(RDX,4*8)) + ADD(RDI, RCX) + ADD(RDI, RDX) + + //iteration 9 -------------------------------------------- + VBROADCASTSD(MEM(RAX, (9+0*16)*8), ZMM(0)) + VBROADCASTSD(MEM(RAX, (9+1*16)*8), ZMM(1)) + + VMULPD(ZMM(0), ZMM(4) , 
ZMM(2)) + VMULPD(ZMM(0), ZMM(20), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (9+2*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(6) , ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(22), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (9+3*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(8) , ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(24), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (9+4*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(10), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(26), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (9+5*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(12), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(28), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (9+6*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(14), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(30), ZMM(3)) + + VMOVUPD(MEM(R8 ), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(16), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(1) , ZMM(3)) + VBROADCASTSD(MEM(RAX, (9+7*16)*8), ZMM(1)) + + VMOVUPD(MEM(R8, R12, 1), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(18), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(0) , ZMM(3)) + VBROADCASTSD(MEM(RAX, (9+8*16)*8), ZMM(0)) + + VBROADCASTSD(MEM(RAX, (9+9*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(5), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(21), ZMM(3)) + + VSUBPD(ZMM(2), ZMM(7), ZMM(7)) + VSUBPD(ZMM(3), ZMM(23), ZMM(23)) + + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + VMULPD(ZMM(1), ZMM(7), ZMM(7)) + VMULPD(ZMM(1), ZMM(23), ZMM(23)) + #else + VDIVPD(ZMM(1), ZMM(7), ZMM(7)) + VDIVPD(ZMM(1), ZMM(23), ZMM(23)) + #endif + + VMOVUPD(ZMM(7), MEM(RCX)) + VEXTRACTF64X4(IMM(1), ZMM(23), YMM(1)) + VMOVUPD(YMM(23), MEM(RDX )) + VMOVUPD(XMM(1), MEM(RDX,4*8)) + ADD(RDI, RCX) + ADD(RDI, RDX) + + //iteration 10 -------------------------------------------- + VBROADCASTSD(MEM(RAX, (10+0*16)*8), ZMM(0)) + VBROADCASTSD(MEM(RAX, (10+1*16)*8), ZMM(1)) + + VMULPD(ZMM(0), ZMM(4) , ZMM(2)) + VMULPD(ZMM(0), ZMM(20), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (10+2*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(6) , ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(22), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (10+3*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(8) , ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(24), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (10+4*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(10), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(26), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (10+5*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(12), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(28), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (10+6*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(14), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(30), ZMM(3)) + + VMOVUPD(MEM(R8 ), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(16), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(1) , ZMM(3)) + VBROADCASTSD(MEM(RAX, (10+7*16)*8), ZMM(1)) + + VMOVUPD(MEM(R8, R12, 1), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(18), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(0) , ZMM(3)) + VBROADCASTSD(MEM(RAX, (10+8*16)*8), ZMM(0)) + + VBROADCASTSD(MEM(RAX, (10+9*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(5), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(21), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (10+10*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(7), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(23), ZMM(3)) + + VSUBPD(ZMM(2), ZMM(9), ZMM(9)) + VSUBPD(ZMM(3), ZMM(25), ZMM(25)) + + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + VMULPD(ZMM(0), ZMM(9), ZMM(9)) + VMULPD(ZMM(0), ZMM(25), ZMM(25)) + #else + VDIVPD(ZMM(0), ZMM(9), ZMM(9)) + VDIVPD(ZMM(0), ZMM(25), ZMM(25)) + #endif + + VMOVUPD(ZMM(9), MEM(RCX)) + VEXTRACTF64X4(IMM(1), ZMM(25), YMM(1)) + VMOVUPD(YMM(25), MEM(RDX )) + VMOVUPD(XMM(1), MEM(RDX,4*8)) + ADD(RDI, RCX) + ADD(RDI, RDX) + + //iteration 11 -------------------------------------------- + VBROADCASTSD(MEM(RAX, (11+0*16)*8), ZMM(0)) + VBROADCASTSD(MEM(RAX, (11+1*16)*8), ZMM(1)) + + 
VMULPD(ZMM(0), ZMM(4) , ZMM(2)) + VMULPD(ZMM(0), ZMM(20), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (11+2*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(6) , ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(22), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (11+3*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(8) , ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(24), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (11+4*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(10), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(26), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (11+5*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(12), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(28), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (11+6*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(14), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(30), ZMM(3)) + + VMOVUPD(MEM(R8 ), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(16), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(1) , ZMM(3)) + VBROADCASTSD(MEM(RAX, (11+7*16)*8), ZMM(1)) + + VMOVUPD(MEM(R8, R12, 1), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(18), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(0) , ZMM(3)) + VBROADCASTSD(MEM(RAX, (11+8*16)*8), ZMM(0)) + + VBROADCASTSD(MEM(RAX, (11+9*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(5), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(21), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (11+10*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(7), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(23), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (11+11*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(9), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(25), ZMM(3)) + + VSUBPD(ZMM(2), ZMM(11), ZMM(11)) + VSUBPD(ZMM(3), ZMM(27), ZMM(27)) + + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + VMULPD(ZMM(1), ZMM(11), ZMM(11)) + VMULPD(ZMM(1), ZMM(27), ZMM(27)) + #else + VDIVPD(ZMM(1), ZMM(11), ZMM(11)) + VDIVPD(ZMM(1), ZMM(27), ZMM(27)) + #endif + + VMOVUPD(ZMM(11), MEM(RCX)) + VEXTRACTF64X4(IMM(1), ZMM(27), YMM(1)) + VMOVUPD(YMM(27), MEM(RDX )) + VMOVUPD(XMM(1), MEM(RDX,4*8)) + ADD(RDI, RCX) + ADD(RDI, RDX) + + //iteration 12 -------------------------------------------- + VBROADCASTSD(MEM(RAX, (12+0*16)*8), ZMM(0)) + VBROADCASTSD(MEM(RAX, (12+1*16)*8), ZMM(1)) + + VMULPD(ZMM(0), ZMM(4) , ZMM(2)) + VMULPD(ZMM(0), ZMM(20), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (12+2*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(6) , ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(22), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (12+3*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(8) , ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(24), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (12+4*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(10), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(26), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (12+5*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(12), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(28), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (12+6*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(14), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(30), ZMM(3)) + + VMOVUPD(MEM(R8 ), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(16), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(1) , ZMM(3)) + VBROADCASTSD(MEM(RAX, (12+7*16)*8), ZMM(1)) + + VMOVUPD(MEM(R8, R12, 1), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(18), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(0) , ZMM(3)) + VBROADCASTSD(MEM(RAX, (12+8*16)*8), ZMM(0)) + + VBROADCASTSD(MEM(RAX, (12+9*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(5), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(21), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (12+10*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(7), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(23), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (12+11*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(9), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(25), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (12+12*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(11), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(27), ZMM(3)) + + VSUBPD(ZMM(2), ZMM(13), ZMM(13)) + VSUBPD(ZMM(3), 
ZMM(29), ZMM(29)) + + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + VMULPD(ZMM(0), ZMM(13), ZMM(13)) + VMULPD(ZMM(0), ZMM(29), ZMM(29)) + #else + VDIVPD(ZMM(0), ZMM(13), ZMM(13)) + VDIVPD(ZMM(0), ZMM(29), ZMM(29)) + #endif + + VMOVUPD(ZMM(13), MEM(RCX)) + VEXTRACTF64X4(IMM(1), ZMM(29), YMM(1)) + VMOVUPD(YMM(29), MEM(RDX )) + VMOVUPD(XMM(1), MEM(RDX,4*8)) + ADD(RDI, RCX) + ADD(RDI, RDX) + + //iteration 13 -------------------------------------------- + VBROADCASTSD(MEM(RAX, (13+0*16)*8), ZMM(0)) + VBROADCASTSD(MEM(RAX, (13+1*16)*8), ZMM(1)) + + VMULPD(ZMM(0), ZMM(4) , ZMM(2)) + VMULPD(ZMM(0), ZMM(20), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (13+2*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(6) , ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(22), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (13+3*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(8) , ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(24), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (13+4*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(10), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(26), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (13+5*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(12), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(28), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (13+6*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(14), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(30), ZMM(3)) + + VMOVUPD(MEM(R8 ), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(16), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(1) , ZMM(3)) + VBROADCASTSD(MEM(RAX, (13+7*16)*8), ZMM(1)) + + VMOVUPD(MEM(R8, R12, 1), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(18), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(0) , ZMM(3)) + VBROADCASTSD(MEM(RAX, (13+8*16)*8), ZMM(0)) + + VBROADCASTSD(MEM(RAX, (13+9*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(5), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(21), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (13+10*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(7), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(23), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (13+11*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(9), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(25), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (13+12*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(11), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(27), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (13+13*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(13), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(29), ZMM(3)) + + VSUBPD(ZMM(2), ZMM(15), ZMM(15)) + VSUBPD(ZMM(3), ZMM(31), ZMM(31)) + + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + VMULPD(ZMM(1), ZMM(15), ZMM(15)) + VMULPD(ZMM(1), ZMM(31), ZMM(31)) + #else + VDIVPD(ZMM(1), ZMM(15), ZMM(15)) + VDIVPD(ZMM(1), ZMM(31), ZMM(31)) + #endif + + VMOVUPD(ZMM(15), MEM(RCX)) + VEXTRACTF64X4(IMM(1), ZMM(31), YMM(1)) + VMOVUPD(YMM(31), MEM(RDX )) + VMOVUPD(XMM(1), MEM(RDX,4*8)) + ADD(RDI, RCX) + ADD(RDI, RDX) + + //iteration 14 -------------------------------------------- + VBROADCASTSD(MEM(RAX, (14+0*16)*8), ZMM(0)) + VBROADCASTSD(MEM(RAX, (14+1*16)*8), ZMM(1)) + + VMULPD(ZMM(0), ZMM(4) , ZMM(2)) + VMULPD(ZMM(0), ZMM(20), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (14+2*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(6) , ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(22), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (14+3*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(8) , ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(24), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (14+4*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(10), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(26), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (14+5*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(12), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(28), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (14+6*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(14), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(30), ZMM(3)) + + VMOVUPD(MEM(R8 ), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(16), ZMM(2)) + 
VFMADD231PD(ZMM(0), ZMM(1) , ZMM(3)) + VBROADCASTSD(MEM(RAX, (14+7*16)*8), ZMM(1)) + + VMOVUPD(MEM(R8, R12, 1), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(18), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(0) , ZMM(3)) + VBROADCASTSD(MEM(RAX, (14+8*16)*8), ZMM(0)) + + VBROADCASTSD(MEM(RAX, (14+9*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(5), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(21), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (14+10*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(7), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(23), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (14+11*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(9), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(25), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (14+12*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(11), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(27), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (14+13*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(13), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(29), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (14+14*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(15), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(31), ZMM(3)) + + VMOVUPD(MEM(R8, R12, 2), ZMM(1)) + VSUBPD(ZMM(2), ZMM(17), ZMM(17)) + VSUBPD(ZMM(3), ZMM(1) , ZMM(1) ) + + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + VMULPD(ZMM(0), ZMM(17), ZMM(17)) + VMULPD(ZMM(0), ZMM(1) , ZMM(1) ) + #else + VDIVPD(ZMM(0), ZMM(17), ZMM(17)) + VDIVPD(ZMM(0), ZMM(1) , ZMM(1) ) + #endif + + VMOVUPD(ZMM(1), MEM(R8, R12, 2)) + + VMOVUPD(ZMM(17), MEM(RCX)) + VEXTRACTF64X4(IMM(1), ZMM(1), YMM(0)) + VMOVUPD(YMM(1), MEM(RDX )) + VMOVUPD(XMM(0), MEM(RDX,4*8)) + ADD(RDI, RCX) + ADD(RDI, RDX) + + //iteration 15 -------------------------------------------- + VBROADCASTSD(MEM(RAX, (15+0*16)*8), ZMM(0)) + VBROADCASTSD(MEM(RAX, (15+1*16)*8), ZMM(1)) + + VMULPD(ZMM(0), ZMM(4) , ZMM(2)) + VMULPD(ZMM(0), ZMM(20), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (15+2*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(6) , ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(22), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (15+3*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(8) , ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(24), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (15+4*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(10), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(26), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (15+5*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(12), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(28), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (15+6*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(14), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(30), ZMM(3)) + + VMOVUPD(MEM(R8 ), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(16), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(1) , ZMM(3)) + VBROADCASTSD(MEM(RAX, (15+7*16)*8), ZMM(1)) + + VMOVUPD(MEM(R8, R12, 1), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(18), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(0) , ZMM(3)) + VBROADCASTSD(MEM(RAX, (15+8*16)*8), ZMM(0)) + + VBROADCASTSD(MEM(RAX, (15+9*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(5), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(21), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (15+10*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(7), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(23), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (15+11*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(9), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(25), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (15+12*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(11), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(27), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (15+13*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(13), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(29), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (15+14*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(15), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(31), ZMM(3)) + + VMOVUPD(MEM(R8, R12, 2), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(17), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(1 ), ZMM(3)) + 
VBROADCASTSD(MEM(RAX, (15+15*16)*8), ZMM(1)) + + VMOVUPD(MEM(R8, R12, 4), ZMM(0)) + VSUBPD(ZMM(2), ZMM(19), ZMM(19)) + VSUBPD(ZMM(3), ZMM(0) , ZMM(0) ) + + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + VMULPD(ZMM(1), ZMM(19), ZMM(19)) + VMULPD(ZMM(1), ZMM(0) , ZMM(0) ) + #else + VDIVPD(ZMM(1), ZMM(19), ZMM(19)) + VDIVPD(ZMM(1), ZMM(0) , ZMM(0) ) + #endif + + VMOVUPD(ZMM(0), MEM(R8, R12, 4)) + + VMOVUPD(ZMM(19), MEM(RCX)) + VEXTRACTF64X4(IMM(1), ZMM(0), YMM(1)) + VMOVUPD(YMM(0), MEM(RDX )) + VMOVUPD(XMM(1), MEM(RDX,4*8)) + ADD(RDI, RCX) + ADD(RDI, RDX) + + + /* + Storage Region (Post TRSM) + */ + MOV(R8, RCX) + MOV(R9, RDI) + MOV(VAR(cs_c), RSI) + + LEA(MEM(RCX, RSI, 8), RDX) + LEA(MEM(RCX, RDI, 8), R14) + + LEA(MEM(RSI, RSI, 2), R12) + LEA(MEM(RSI, RSI, 4), R13) + LEA(MEM(R13, RSI, 2), R15) + + CMP(IMM(8), RSI) + JZ(.DROWSTORED) + + CMP(IMM(8), RDI) + JZ(.DCOLSTORED) + + LABEL(.DROWSTORED) + + VMOVUPD(MEM(R8 ), ZMM(1)) + VMOVUPD(ZMM(4), MEM(RCX)) + VMOVUPD(MEM(R8, RDI, 1), ZMM(4)) + + ADD(RDI, RCX) + VEXTRACTF64X4(IMM(1), ZMM(20), YMM(0)) + VMOVUPD(YMM(20), MEM(RDX )) + + VMOVUPD(MEM(R8, RDI, 2), ZMM(20)) + + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + ADD(RDI, RDX) + + VMOVUPD(ZMM(6), MEM(RCX)) + + VMOVUPD(MEM(R8, RDI, 4), ZMM(6)) + + ADD(RDI, RCX) + VEXTRACTF64X4(IMM(1), ZMM(22), YMM(0)) + VMOVUPD(YMM(22), MEM(RDX )) + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + ADD(RDI, RDX) + + VMOVUPD(ZMM(8), MEM(RCX)) + ADD(RDI, RCX) + VEXTRACTF64X4(IMM(1), ZMM(24), YMM(0)) + VMOVUPD(YMM(24), MEM(RDX )) + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + ADD(RDI, RDX) + + VMOVUPD(ZMM(10), MEM(RCX)) + ADD(RDI, RCX) + VEXTRACTF64X4(IMM(1), ZMM(26), YMM(0)) + VMOVUPD(YMM(26), MEM(RDX )) + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + ADD(RDI, RDX) + + VMOVUPD(ZMM(12), MEM(RCX)) + ADD(RDI, RCX) + VEXTRACTF64X4(IMM(1), ZMM(28), YMM(0)) + VMOVUPD(YMM(28), MEM(RDX )) + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + ADD(RDI, RDX) + + VMOVUPD(ZMM(14), MEM(RCX)) + ADD(RDI, RCX) + VEXTRACTF64X4(IMM(1), ZMM(30), YMM(0)) + VMOVUPD(YMM(30), MEM(RDX )) + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + ADD(RDI, RDX) + + VMOVUPD(ZMM(16), MEM(RCX)) + ADD(RDI, RCX) + VEXTRACTF64X4(IMM(1), ZMM(1), YMM(0)) + VMOVUPD(YMM(1), MEM(RDX )) + VMOVUPD(XMM(0), MEM(RDX,4*8)) + ADD(RDI, RDX) + + VMOVUPD(ZMM(18), MEM(RCX)) + ADD(RDI, RCX) + VEXTRACTF64X4(IMM(1), ZMM(4), YMM(0)) + VMOVUPD(YMM(4), MEM(RDX )) + VMOVUPD(XMM(0), MEM(RDX,4*8)) + ADD(RDI, RDX) + + VMOVUPD(ZMM(5), MEM(RCX)) + ADD(RDI, RCX) + VEXTRACTF64X4(IMM(1), ZMM(21), YMM(0)) + VMOVUPD(YMM(21), MEM(RDX )) + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + ADD(RDI, RDX) + + VMOVUPD(ZMM(7), MEM(RCX)) + ADD(RDI, RCX) + VEXTRACTF64X4(IMM(1), ZMM(23), YMM(0)) + VMOVUPD(YMM(23), MEM(RDX )) + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + ADD(RDI, RDX) + + VMOVUPD(ZMM(9), MEM(RCX)) + ADD(RDI, RCX) + VEXTRACTF64X4(IMM(1), ZMM(25), YMM(0)) + VMOVUPD(YMM(25), MEM(RDX )) + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + ADD(RDI, RDX) + + VMOVUPD(ZMM(11), MEM(RCX)) + ADD(RDI, RCX) + VEXTRACTF64X4(IMM(1), ZMM(27), YMM(0)) + VMOVUPD(YMM(27), MEM(RDX )) + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + ADD(RDI, RDX) + + VMOVUPD(ZMM(13), MEM(RCX)) + ADD(RDI, RCX) + VEXTRACTF64X4(IMM(1), ZMM(29), YMM(0)) + VMOVUPD(YMM(29), MEM(RDX )) + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + ADD(RDI, RDX) + + VMOVUPD(ZMM(15), MEM(RCX)) + ADD(RDI, RCX) + VEXTRACTF64X4(IMM(1), ZMM(31), YMM(0)) + VMOVUPD(YMM(31), MEM(RDX )) + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + ADD(RDI, RDX) + + VMOVUPD(ZMM(17), MEM(RCX)) + ADD(RDI, RCX) + VEXTRACTF64X4(IMM(1), ZMM(20), YMM(0)) + VMOVUPD(YMM(20), MEM(RDX )) + VMOVUPD(XMM(0), MEM(RDX,4*8)) + ADD(RDI, RDX) + + VMOVUPD(ZMM(19), 
MEM(RCX)) + VEXTRACTF64X4(IMM(1), ZMM(6), YMM(0)) + VMOVUPD(YMM(6), MEM(RDX )) + VMOVUPD(XMM(0), MEM(RDX,4*8)) + + + JMP(.DDONE) + LABEL(.DCOLSTORED) + + + MOV(VAR(offsetPtr), R12) + LEA(MEM(RCX, RSI, 8), RDX) + VPBROADCASTQ(RSI, ZMM(0)) + VPMULLQ(MEM(R12), ZMM(0), ZMM(2)) + VPMULLQ(MEM(R12,64), ZMM(0), ZMM(3)) + + VMOVUPD(MEM(RCX ), ZMM(0)) + VMOVUPD(MEM(RCX, RSI, 1), ZMM(1)) + + MOV(RDX, RCX) + LEA(MEM(RCX, RDI, 8), R14) + UPDATE_C_COL_SCATTERED_2x6(20,21) + VMOVUPD(MEM(R8, RSI, 2), ZMM(20)) + VMOVUPD(MEM(R8, RSI, 4), ZMM(21)) + UPDATE_C_COL_SCATTERED_2x6(22,23) + UPDATE_C_COL_SCATTERED_2x6(24,25) + UPDATE_C_COL_SCATTERED_2x6(26,27) + UPDATE_C_COL_SCATTERED_2x6(28,29) + UPDATE_C_COL_SCATTERED_2x6(30,31) + UPDATE_C_COL_SCATTERED_2x6(0 ,20 ) + UPDATE_C_COL_SCATTERED_2x6(1 ,21 ) + + MOV(R8, RCX) + LEA(MEM(RCX, RDI, 8), R14) + UPDATE_C_COL_SCATTERED( 4, 5) + UPDATE_C_COL_SCATTERED( 6, 7) + UPDATE_C_COL_SCATTERED( 8, 9) + UPDATE_C_COL_SCATTERED(10,11) + UPDATE_C_COL_SCATTERED(12,13) + UPDATE_C_COL_SCATTERED(14,15) + UPDATE_C_COL_SCATTERED(16,17) + UPDATE_C_COL_SCATTERED(18,19) + + + LABEL(.DDONE) + + VZEROUPPER() + + end_asm( + : // output operands (none) + : // input operands + [a10] "m" (a10), // 1 + [k] "m" (k), // 2 + [b01] "m" (b01), // 3 + [a11] "m" (a11), // 6 + [b11] "m" (b11), // 7 + [c11] "m" (c11), // 8 + [rs_c] "m" (rs_c), // 9 + [cs_c] "m" (cs_c), // 10, + [alpha] "m" (alpha), + [offsetPtr] "m" (offsetPtr) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rdi", "rsi", "r8", "r9", "r10", "r11", "r12", + "r13", "r14", "r15", "zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5", + "zmm6", "zmm7", "zmm8", "zmm9", "zmm10", "zmm11", "zmm12", "zmm13", + "zmm14", "zmm15", "zmm16", "zmm17", "zmm18", "zmm19", "zmm20", "zmm21", + "zmm22", "zmm23", "zmm24", "zmm25", "zmm26", "zmm27", "zmm28", "zmm29", + "zmm30", "zmm31", "memory" + ) + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_9); +} \ No newline at end of file diff --git a/kernels/zen4/3/bli_gemmtrsm_u_zen_16x14.c b/kernels/zen4/3/bli_gemmtrsm_u_zen_16x14.c new file mode 100644 index 0000000000..787f85155c --- /dev/null +++ b/kernels/zen4/3/bli_gemmtrsm_u_zen_16x14.c @@ -0,0 +1,1706 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#define BLIS_ASM_SYNTAX_ATT +#include "bli_x86_asm_macros.h" + +#define A_L1_PREFETCH_DIST 12 // in units of k iterations +#define B_L1_PREFETCH_DIST 12 // e.g. 4 k iterations ~= 56 cycles +#define TAIL_NITER 5 // in units of 4x unrolled k iterations + // e.g. 5 -> 4*5 k iterations ~= 280 cycles + +#define PREFETCH_A_L1(n, k) \ + PREFETCH(0, MEM(RAX, A_L1_PREFETCH_DIST*16*8 + (2*n+k)*64)) +#define PREFETCH_B_L1(n, k) \ + PREFETCH(0, MEM(RBX, B_L1_PREFETCH_DIST*14*8 + (2*n+k)*56)) + +#define LOOP_ALIGN ALIGN32 + +#define SUBITER(n) \ +\ + PREFETCH_A_L1(n, 0) \ + \ + VBROADCASTSD(MEM(RBX, (14*n + 0)*8), ZMM(2)) \ + VBROADCASTSD(MEM(RBX, (14*n + 1)*8), ZMM(3)) \ + VFMADD231PD(ZMM(0), ZMM(2), ZMM(4)) \ + VFMADD231PD(ZMM(1), ZMM(2), ZMM(5)) \ + VFMADD231PD(ZMM(0), ZMM(3), ZMM(6)) \ + VFMADD231PD(ZMM(1), ZMM(3), ZMM(7)) \ + \ + VBROADCASTSD(MEM(RBX, (14*n + 2)*8), ZMM(2)) \ + VBROADCASTSD(MEM(RBX, (14*n + 3)*8), ZMM(3)) \ + VFMADD231PD(ZMM(0), ZMM(2), ZMM(8) ) \ + VFMADD231PD(ZMM(1), ZMM(2), ZMM(9) ) \ + VFMADD231PD(ZMM(0), ZMM(3), ZMM(10)) \ + VFMADD231PD(ZMM(1), ZMM(3), ZMM(11)) \ + \ + PREFETCH_B_L1(n, 0) \ + \ + VBROADCASTSD(MEM(RBX, (14*n + 4)*8), ZMM(2)) \ + VBROADCASTSD(MEM(RBX, (14*n + 5)*8), ZMM(3)) \ + VFMADD231PD(ZMM(0), ZMM(2), ZMM(12)) \ + VFMADD231PD(ZMM(1), ZMM(2), ZMM(13)) \ + VFMADD231PD(ZMM(0), ZMM(3), ZMM(14)) \ + VFMADD231PD(ZMM(1), ZMM(3), ZMM(15)) \ + \ + VBROADCASTSD(MEM(RBX, (14*n + 6)*8), ZMM(2)) \ + VBROADCASTSD(MEM(RBX, (14*n + 7)*8), ZMM(3)) \ + VFMADD231PD(ZMM(0), ZMM(2), ZMM(16)) \ + VFMADD231PD(ZMM(1), ZMM(2), ZMM(17)) \ + VFMADD231PD(ZMM(0), ZMM(3), ZMM(18)) \ + VFMADD231PD(ZMM(1), ZMM(3), ZMM(19)) \ + \ + PREFETCH_A_L1(n, 1) \ + \ + VBROADCASTSD(MEM(RBX, (14*n + 8)*8), ZMM(2)) \ + VBROADCASTSD(MEM(RBX, (14*n + 9)*8), ZMM(3)) \ + VFMADD231PD(ZMM(0), ZMM(2), ZMM(20)) \ + VFMADD231PD(ZMM(1), ZMM(2), ZMM(21)) \ + VFMADD231PD(ZMM(0), ZMM(3), ZMM(22)) \ + VFMADD231PD(ZMM(1), ZMM(3), ZMM(23)) \ + \ + VBROADCASTSD(MEM(RBX, (14*n + 10)*8), ZMM(2)) \ + VBROADCASTSD(MEM(RBX, (14*n + 11)*8), ZMM(3)) \ + VFMADD231PD(ZMM(0), ZMM(2), ZMM(24)) \ + VFMADD231PD(ZMM(1), ZMM(2), ZMM(25)) \ + VFMADD231PD(ZMM(0), ZMM(3), ZMM(26)) \ + VFMADD231PD(ZMM(1), ZMM(3), ZMM(27)) \ + \ + PREFETCH_B_L1(n, 1) \ + \ + VBROADCASTSD(MEM(RBX, (14*n + 12)*8), ZMM(2)) \ + VBROADCASTSD(MEM(RBX, (14*n + 13)*8), ZMM(3)) \ + VFMADD231PD(ZMM(0), ZMM(2), ZMM(28)) \ + VFMADD231PD(ZMM(1), ZMM(2), ZMM(29)) \ + VFMADD231PD(ZMM(0), ZMM(3), ZMM(30)) \ + VFMADD231PD(ZMM(1), ZMM(3), ZMM(31)) \ + \ + VMOVAPD(MEM(RAX,((n*2)+2)*8*8), ZMM(0)) \ + VMOVAPD(MEM(RAX,((n*2)+3)*8*8), ZMM(1)) + +#define UPDATE_C_COL_SCATTERED(R1,R2) \ +\ + KXNORW(K(0), K(0), K(1)) \ + KXNORW(K(0), K(0), K(2)) \ + VSCATTERQPD(ZMM(R1), MEM(RCX,ZMM(2),1) MASK_K(1)) \ + VSCATTERQPD(ZMM(R2), MEM(R14,ZMM(2),1) MASK_K(2)) \ + ADD(RDI, RCX) \ + ADD(RDI, R14) \ + +/* scatter only first 6 elements of r1 and r2 */ +#define UPDATE_C_COL_SCATTERED_2x6(R1,R2) \ +\ + KXNORW(K(0), K(0), K(1)) \ + KXNORW(K(0), 
K(0), K(2)) \ + MOVQ(IMM(0b00111111), RAX) \ + KMOVQ(RAX, K(2)) \ + KMOVQ(RAX, K(1)) \ + VSCATTERQPD(ZMM(R1), MEM(RCX,ZMM(2),1) MASK_K(1)) \ + VSCATTERQPD(ZMM(R2), MEM(R14,ZMM(2),1) MASK_K(2)) \ + ADD(RDI, RCX) \ + ADD(RDI, R14) \ + +/* +Transpose 8 zmm registers and store the output in the given 8 registers + Note: Requires offsetPointer for scatter instruction + and 512 bytes of free memory (rcx) for transpose. + Input : + R1 = [ 0, 1, 2, 3, 4, 5, 6, 7] + R2 = [ 8, 9, 10, 11, 12, 13, 14, 15] + R3 = [16, 17, 18, 19, 20, 21, 22, 23] + R4 = [24, 25, 26, 27, 28, 29, 30, 31] + R5 = [32, 33, 34, 35, 36, 37, 38, 39] + R6 = [40, 41, 42, 43, 44, 45, 46, 47] + R7 = [48, 49, 50, 51, 52, 53, 54, 55] + R18= [56, 57, 58, 59, 60, 61, 62, 63] + Output : + R1 = [0, 8, 16, 24, 32, 40, 48, 56] + R2 = [1, 9, 17, 25, 33, 41, 49, 57] + R3 = [2, 10, 18, 26, 34, 42, 50, 58] + R4 = [3, 11, 19, 27, 35, 43, 51, 59] + R5 = [4, 12, 20, 28, 36, 44, 52, 60] + R6 = [5, 13, 21, 29, 37, 45, 53, 61] + R7 = [6, 14, 22, 30, 38, 46, 54, 62] + R18= [7, 15, 23, 31, 39, 47, 55, 63] +*/ +#define TRANSPOSE_REGISTERS_8x8(R1, R2, R3, R4, R5, R6, R7, R18) \ +\ + MOV(R8, RCX) \ + MOV(VAR(cs_c), RSI) \ + MOV(R9, RDI) \ + LEA(MEM(RCX, RSI, 8), RDX) \ + MOV(VAR(offsetPtr), R13) \ + MOV(RDI, R12) \ + CMP(RSI, R12) \ + CMOVL(RSI, R12) \ + VPBROADCASTQ(R12, ZMM(0)) \ + VPMULLQ(MEM(R13 ), ZMM(0), ZMM(2)) \ + \ + KXNORW(K(0), K(0), K(1)) \ + KXNORW(K(0), K(0), K(2)) \ + KXNORW(K(0), K(0), K(3)) \ + KXNORW(K(0), K(0), K(4)) \ + VSCATTERQPD(ZMM(R1), MEM(RCX,ZMM(2),1) MASK_K(1)) \ + ADD(IMM(1*8), RCX) \ + KXNORW(K(0), K(0), K(1)) \ + VSCATTERQPD(ZMM(R2), MEM(RCX,ZMM(2),1) MASK_K(2)) \ + ADD(IMM(1*8), RCX) \ + KXNORW(K(0), K(0), K(2)) \ + VSCATTERQPD(ZMM(R3), MEM(RCX,ZMM(2),1) MASK_K(3)) \ + ADD(IMM(1*8), RCX) \ + KXNORW(K(0), K(0), K(3)) \ + VSCATTERQPD(ZMM(R4), MEM(RCX,ZMM(2),1) MASK_K(4)) \ + ADD(IMM(1*8), RCX) \ + KXNORW(K(0), K(0), K(4)) \ + VSCATTERQPD(ZMM(R5), MEM(RCX,ZMM(2),1) MASK_K(1)) \ + ADD(IMM(1*8), RCX) \ + KXNORW(K(0), K(0), K(1)) \ + VSCATTERQPD(ZMM(R6), MEM(RCX,ZMM(2),1) MASK_K(2)) \ + ADD(IMM(1*8), RCX) \ + KXNORW(K(0), K(0), K(2)) \ + VSCATTERQPD(ZMM(R7), MEM(RCX,ZMM(2),1) MASK_K(3)) \ + ADD(IMM(1*8), RCX) \ + KXNORW(K(0), K(0), K(3)) \ + VSCATTERQPD(ZMM(R18), MEM(RCX,ZMM(2),1) MASK_K(4)) \ + \ + MOV(R8, RCX) \ + \ + VMOVUPD(MEM(RCX), ZMM(R1))\ + ADD(R12, RCX) \ + VMOVUPD(MEM(RCX), ZMM(R2)) \ + ADD(R12, RCX) \ + VMOVUPD(MEM(RCX), ZMM(R3)) \ + ADD(R12, RCX) \ + VMOVUPD(MEM(RCX), ZMM(R4)) \ + ADD(R12, RCX) \ + VMOVUPD(MEM(RCX), ZMM(R5)) \ + ADD(R12, RCX) \ + VMOVUPD(MEM(RCX), ZMM(R6)) \ + ADD(R12, RCX) \ + VMOVUPD(MEM(RCX), ZMM(R7)) \ + ADD(R12, RCX) \ + VMOVUPD(MEM(RCX), ZMM(R18)) \ + +/* +Transpose six zmm registers and store the output in the given 8 registers + Note: Require offsetPointer for scatter instruction + and 512 bytes of free memory (rcx) for transpose. 
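+ (As in the 8x8 variant above, the registers are scattered into the c11 buffer using max(rs_c, cs_c) as the stride and then reloaded row-wise, which realizes the transpose.)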
+ Input : + R1 = [ 0, 1, 2, 3, 4, 5, 6, 7] + R2 = [ 8, 9, 10, 11, 12, 13, 14, 15] + R3 = [16, 17, 18, 19, 20, 21, 22, 23] + R4 = [24, 25, 26, 27, 28, 29, 30, 31] + R5 = [32, 33, 34, 35, 36, 37, 38, 39] + R6 = [40, 41, 42, 43, 44, 45, 46, 47] + Output : + R1 = [0, 8, 16, 24, 32, 40, -, -] + R2 = [1, 9, 17, 25, 33, 41, -, -] + R3 = [2, 10, 18, 26, 34, 42, -, -] + R4 = [3, 11, 19, 27, 35, 43, -, -] + R5 = [4, 12, 20, 28, 36, 44, -, -] + R6 = [5, 13, 21, 29, 37, 45, -, -] + R7 = [6, 14, 22, 30, 38, 46, -, -] + R18 = [7, 15, 23, 31, 39, 47, -, -] +*/ +#define TRANSPOSE_REGISTERS_6x8(R1, R2, R3, R4, R5, R6, R7, R18) \ +\ + MOV(R8, RCX) \ + MOV(VAR(cs_c), RSI) \ + MOV(R9, RDI) \ + LEA(MEM(RCX, RSI, 8), RDX) \ + MOV(VAR(offsetPtr), R13) \ + MOV(RDI, R12) \ + CMP(RSI, R12) \ + CMOVL(RSI, R12) \ + VPBROADCASTQ(R12, ZMM(0)) \ + VPMULLQ(MEM(R13 ), ZMM(0), ZMM(2)) \ + LEA(MEM(RCX, R12, 4), RCX) \ + LEA(MEM(RCX, R12, 1), RCX) \ + \ + KXNORW(K(0), K(0), K(1)) \ + KXNORW(K(0), K(0), K(2)) \ + KXNORW(K(0), K(0), K(3)) \ + KXNORW(K(0), K(0), K(4)) \ + VSCATTERQPD(ZMM(R1), MEM(RCX,ZMM(2),1) MASK_K(1)) \ + ADD(IMM(1*8), RCX) \ + KXNORW(K(0), K(0), K(1)) \ + VSCATTERQPD(ZMM(R2), MEM(RCX,ZMM(2),1) MASK_K(2)) \ + ADD(IMM(1*8), RCX) \ + KXNORW(K(0), K(0), K(2)) \ + VSCATTERQPD(ZMM(R3), MEM(RCX,ZMM(2),1) MASK_K(3)) \ + ADD(IMM(1*8), RCX) \ + KXNORW(K(0), K(0), K(3)) \ + VSCATTERQPD(ZMM(R4), MEM(RCX,ZMM(2),1) MASK_K(4)) \ + ADD(IMM(1*8), RCX) \ + KXNORW(K(0), K(0), K(4)) \ + VSCATTERQPD(ZMM(R5), MEM(RCX,ZMM(2),1) MASK_K(1)) \ + ADD(IMM(1*8), RCX) \ + KXNORW(K(0), K(0), K(1)) \ + VSCATTERQPD(ZMM(R6), MEM(RCX,ZMM(2),1) MASK_K(2)) \ + \ + MOV(R8, RCX) \ + LEA(MEM(RCX, R12, 4), RCX) \ + LEA(MEM(RCX, R12, 1), RCX) \ + \ + VMOVUPD(MEM(RCX), ZMM(R1))\ + ADD(R12, RCX) \ + VMOVUPD(MEM(RCX), ZMM(R2)) \ + ADD(R12, RCX) \ + VMOVUPD(MEM(RCX), ZMM(R3)) \ + ADD(R12, RCX) \ + VMOVUPD(MEM(RCX), ZMM(R4)) \ + ADD(R12, RCX) \ + VMOVUPD(MEM(RCX), ZMM(R5)) \ + ADD(R12, RCX) \ + VMOVUPD(MEM(RCX), ZMM(R6)) \ + ADD(R12, RCX) \ + VMOVUPD(MEM(RCX), ZMM(R7)) \ + ADD(R12, RCX) \ + VMOVUPD(MEM(RCX), ZMM(R18)) \ + +// Offsets for scatter/gather instructions +static int64_t offsets[16] __attribute__((aligned(64))) = + { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12,13,14,15}; + + +void bli_dgemmtrsm_u_zen_asm_16x14 +( + dim_t k_, + double* restrict alpha, + double* restrict a10, + double* restrict a11, + double* restrict b01, + double* restrict b11, + double* restrict c11, inc_t rs_c_, inc_t cs_c_, + auxinfo_t* restrict data, + cntx_t* restrict cntx +) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_9); + const int64_t k = k_; + uint64_t rs_c = rs_c_ * 8; + const int64_t* offsetPtr = &offsets[0]; + uint64_t cs_c = cs_c_ * 8; + + BEGIN_ASM() + + //clear out registers + VXORPD(ZMM(4), ZMM(4), ZMM(4)) + VMOVAPD(ZMM(4), ZMM(5) ) + VMOVAPD(ZMM(4), ZMM(6) ) + VMOVAPD(ZMM(4), ZMM(7) ) + VMOVAPD(ZMM(4), ZMM(8) ) + VMOVAPD(ZMM(4), ZMM(9) ) + VMOVAPD(ZMM(4), ZMM(10)) + VMOVAPD(ZMM(4), ZMM(11)) + VMOVAPD(ZMM(4), ZMM(12)) + VMOVAPD(ZMM(4), ZMM(13)) + VMOVAPD(ZMM(4), ZMM(14)) + VMOVAPD(ZMM(4), ZMM(15)) + VMOVAPD(ZMM(4), ZMM(16)) + VMOVAPD(ZMM(4), ZMM(17)) + VMOVAPD(ZMM(4), ZMM(18)) + VMOVAPD(ZMM(4), ZMM(19)) + VMOVAPD(ZMM(4), ZMM(20)) + VMOVAPD(ZMM(4), ZMM(21)) + VMOVAPD(ZMM(4), ZMM(22)) + VMOVAPD(ZMM(4), ZMM(23)) + VMOVAPD(ZMM(4), ZMM(24)) + VMOVAPD(ZMM(4), ZMM(25)) + VMOVAPD(ZMM(4), ZMM(26)) + VMOVAPD(ZMM(4), ZMM(27)) + VMOVAPD(ZMM(4), ZMM(28)) + VMOVAPD(ZMM(4), ZMM(29)) + VMOVAPD(ZMM(4), ZMM(30)) + VMOVAPD(ZMM(4), ZMM(31)) + + MOV(VAR(k), RSI) + + MOV(VAR(a10), 
RAX) // load address of a + MOV(VAR(b01), RBX) // load address of b + MOV(VAR(c11), R8) // load address of c + + LEA(MEM(RSI,RSI,2), RDX) + LEA(MEM(,RDX,4), RDX) + LEA(MEM(RDX,RSI,4), RDX) // RDX = 16 * K + LEA(MEM(RAX,RDX,8,-128), RDX) // RDX = a_next for prefetching + LEA(MEM(R8,63), R12) // c for prefetching + + MOV(VAR(rs_c), R9) + MOV(VAR(cs_c), R13) + + MOV(IMM(0), R11) + MOV(R13, R15) + + CMP(IMM(8), R13) + JNE(.DBEFORELOOP) + MOV(IMM(2), R11) + MOV(R9, R15) + + LABEL(.DBEFORELOOP) + + VMOVAPD(MEM(RAX, 0*8*8), ZMM(0)) + VMOVAPD(MEM(RAX, 1*8*8), ZMM(1)) // preload a + + MOV(RSI, R10) + AND(IMM(3), R10) // R10 = K % 4 + SAR(IMM(2), RSI) // RSI = K / 4 + + /* + MAIN LOOP + Note: This loop runs (K/4 - 14 - TAIL_NITER) times + */ + SUB(R11, RSI) + SUB(IMM(14+TAIL_NITER), RSI) + JLE(K_LE_80) + + LOOP_ALIGN + LABEL(LOOP1) + SUBITER(0) + PREFETCH(1, MEM(RDX)) + SUBITER(1) + SUB(IMM(1), RSI) + SUBITER(2) + PREFETCH(1, MEM(RDX,64)) + SUBITER(3) + + LEA(MEM(RAX,4*16*8), RAX) + LEA(MEM(RBX,4*14*8), RBX) + LEA(MEM(RDX,16*8), RDX) + + JNZ(LOOP1) + + LABEL(K_LE_80) + + /* + C prefetch Loop + Note: This loop runs 14 times, + These 14 iterations are done separately so that c11 can be prefetched here. + */ + ADD(R11, RSI) + ADD(IMM(14), RSI) + JLE(K_LE_24) + + LOOP_ALIGN + LABEL(LOOP2) + PREFETCH(0, MEM(R12)) + SUBITER(0) + PREFETCH(1, MEM(RDX)) + SUBITER(1) + PREFETCH(0, MEM(R12,64)) + SUB(IMM(1), RSI) + SUBITER(2) + PREFETCH(1, MEM(RDX,64)) + SUBITER(3) + + LEA(MEM(RAX,4*16*8), RAX) + LEA(MEM(RBX,4*14*8), RBX) + LEA(MEM(RDX,16*8), RDX) + LEA(MEM(R12,R15,1), R12) + + JNZ(LOOP2) + + LABEL(K_LE_24) + + /* + TAIL_NITER Loop + Note: This loop runs TAIL_NITER times, + This loop is used to provide some distance between c11 prefetch and usage of c11. + */ + ADD(IMM(0+TAIL_NITER), RSI) + JLE(TAIL) + + LOOP_ALIGN + LABEL(LOOP3) + + SUBITER(0) + PREFETCH(1, MEM(RDX)) + SUBITER(1) + SUB(IMM(1), RSI) + SUBITER(2) + PREFETCH(1, MEM(RDX,64)) + SUBITER(3) + + LEA(MEM(RAX,4*16*8), RAX) + LEA(MEM(RBX,4*14*8), RBX) + LEA(MEM(RDX,16*8), RDX) + + JNZ(LOOP3) + + /* + K Left Loop + This loop runs K % 4 times. 
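+ Each SUBITER() pass consumes one k iteration (16 elements of A, 14 elements of B).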
+ */ + LABEL(TAIL) + MOV(R10, RSI) + TEST(RSI, RSI) + JE(.DPOSTACCUM) + LOOP_ALIGN + LABEL(TAIL_LOOP) + + SUB(IMM(1), RSI) + SUBITER(0) + + LEA(MEM(RAX,16*8), RAX) + LEA(MEM(RBX,14*8), RBX) + + JNZ(TAIL_LOOP) + + LABEL(.DPOSTACCUM) + + /* GEMM output before transpose GEMM output after transpose + __________________________________ + ___________________________ |______zmm4______|______zmm20___x x| + | | | | | | | | | | | | | | | |______zmm6______|______zmm22___x x| + |z|z|z|z|z|z|z|z|z|z|z|z|z|z| |______zmm8______|______zmm24___x x| + |m|m|m|m|m|m|m|m|m|m|m|m|m|m| |______zmm10_____|______zmm26___x x| + |m|m|m|m|m|m|m|m|m|m|m|m|m|m| |______zmm12_____|______zmm28___x x| + |4|6|8|1|1|1|1|1|2|2|2|2|2|3| |______zmm14_____|______zmm30___x x| + | | | |0|2|4|6|8|0|2|4|6|8|0| |______zmm16_____|_____c11______x x| + | | | | | | | | | | | | | | | |______zmm18_____|_____c11+cs___x x| + ____________________________ |______zmm5______|______zmm21___x x| + | | | | | | | | | | | | | | | |______zmm7______|______zmm23___x x| + |z|z|z|z|z|z|z|z|z|z|z|z|z|z| |______zmm9______|______zmm25___x x| + |m|m|m|m|m|m|m|m|m|m|m|m|m|m| |______zmm11_____|______zmm27___x x| + |m|m|m|m|m|m|m|m|m|m|m|m|m|m| |______zmm13_____|______zmm29___x x| + |5|7|9|1|1|1|1|1|2|2|2|2|2|3| |______zmm15_____|______zmm31___x x| + | | | |1|3|5|7|9|1|3|5|7|9|1| |______zmm17_____|____c11+cs*2__x x| + | | | | | | | | | | | | | | | |______zmm19_____|____c11+cs*4__x x| + _____________________________ + */ + TRANSPOSE_REGISTERS_8x8(4, 6, 8, 10, 12, 14, 16, 18) // transpose the output of GEMM + TRANSPOSE_REGISTERS_8x8(5, 7, 9, 11, 13, 15, 17, 19) + TRANSPOSE_REGISTERS_6x8(20, 22, 24, 26, 28, 30, 0, 1) + VMOVUPD(ZMM(0), MEM(R8 )) + VMOVUPD(ZMM(1), MEM(R8, R12, 1)) // zmm0 and zmm1 are needed for other computations, + // therefore store zmm0, zmm1 's data in rcx + TRANSPOSE_REGISTERS_6x8(21, 23, 25, 27, 29, 31, 0, 1) + VMOVUPD(ZMM(0), MEM(R8, R12, 2)) + VMOVUPD(ZMM(1), MEM(R8, R12, 4)) // zmm0 and zmm1 are needed for other computations, + // therefore store zmm0, zmm1 's data in rcx + MOV(IMM(14), RDI) + LEA(MEM(, RDI, 8), RDI) + + MOV(VAR(alpha), RBX) + VBROADCASTSD(MEM(RBX), ZMM(3)) + + MOV(IMM(1), RSI) + LEA(MEM(, RSI, 8), RSI) + + MOV(VAR(b11), RCX) + LEA(MEM(RCX, RSI, 8), RDX) + + MOV(RCX, R11) + MOV(RDX, R14) + + // Subtract b11 from GEMM output + VFMSUB231PD(MEM(RCX), ZMM(3), ZMM(4)) + ADD(RDI, RCX) + VFMSUB231PD(MEM(RCX), ZMM(3), ZMM(6)) + ADD(RDI, RCX) + VFMSUB231PD(MEM(RCX), ZMM(3), ZMM(8)) + ADD(RDI, RCX) + VFMSUB231PD(MEM(RCX), ZMM(3), ZMM(10)) + ADD(RDI, RCX) + VFMSUB231PD(MEM(RCX), ZMM(3), ZMM(12)) + ADD(RDI, RCX) + VFMSUB231PD(MEM(RCX), ZMM(3), ZMM(14)) + ADD(RDI, RCX) + VFMSUB231PD(MEM(RCX), ZMM(3), ZMM(16)) + ADD(RDI, RCX) + VFMSUB231PD(MEM(RCX), ZMM(3), ZMM(18)) + ADD(RDI, RCX) + + VFMSUB231PD(MEM(RCX), ZMM(3), ZMM(5)) + ADD(RDI, RCX) + VFMSUB231PD(MEM(RCX), ZMM(3), ZMM(7)) + ADD(RDI, RCX) + VFMSUB231PD(MEM(RCX), ZMM(3), ZMM(9)) + ADD(RDI, RCX) + VFMSUB231PD(MEM(RCX), ZMM(3), ZMM(11)) + ADD(RDI, RCX) + VFMSUB231PD(MEM(RCX), ZMM(3), ZMM(13)) + ADD(RDI, RCX) + VFMSUB231PD(MEM(RCX), ZMM(3), ZMM(15)) + ADD(RDI, RCX) + VFMSUB231PD(MEM(RCX), ZMM(3), ZMM(17)) + ADD(RDI, RCX) + VFMSUB231PD(MEM(RCX), ZMM(3), ZMM(19)) + + + + VFMSUB231PD(MEM(RDX), ZMM(3), ZMM(20)) + ADD(RDI, RDX) + VFMSUB231PD(MEM(RDX), ZMM(3), ZMM(22)) + ADD(RDI, RDX) + VFMSUB231PD(MEM(RDX), ZMM(3), ZMM(24)) + ADD(RDI, RDX) + VFMSUB231PD(MEM(RDX), ZMM(3), ZMM(26)) + ADD(RDI, RDX) + VFMSUB231PD(MEM(RDX), ZMM(3), ZMM(28)) + ADD(RDI, RDX) + VFMSUB231PD(MEM(RDX), ZMM(3), ZMM(30)) 
+ ADD(RDI, RDX) + VMOVUPD(MEM(R8 ), ZMM(0)) + VMOVUPD(MEM(R8, R12, 1), ZMM(1)) + VFMSUB231PD(MEM(RDX), ZMM(3), ZMM(0)) + VMOVUPD(ZMM(0), MEM(R8 )) + ADD(RDI, RDX) + VFMSUB231PD(MEM(RDX), ZMM(3), ZMM(1)) + VMOVUPD(ZMM(1), MEM(R8, R12, 1)) + ADD(RDI, RDX) + + VFMSUB231PD(MEM(RDX), ZMM(3), ZMM(21)) + ADD(RDI, RDX) + VFMSUB231PD(MEM(RDX), ZMM(3), ZMM(23)) + ADD(RDI, RDX) + VFMSUB231PD(MEM(RDX), ZMM(3), ZMM(25)) + ADD(RDI, RDX) + VFMSUB231PD(MEM(RDX), ZMM(3), ZMM(27)) + ADD(RDI, RDX) + VFMSUB231PD(MEM(RDX), ZMM(3), ZMM(29)) + ADD(RDI, RDX) + VFMSUB231PD(MEM(RDX), ZMM(3), ZMM(31)) + ADD(RDI, RDX) + VMOVUPD(MEM(R8, R12, 2), ZMM(0)) + VMOVUPD(MEM(R8, R12, 4), ZMM(1)) + VFMSUB231PD(MEM(RDX), ZMM(3), ZMM(0)) + VMOVUPD(ZMM(0), MEM(R8, R12, 2)) + ADD(RDI, RDX) + VFMSUB231PD(MEM(RDX), ZMM(3), ZMM(1)) + VMOVUPD(ZMM(1), MEM(R8, R12, 4)) + + /* + TRSM region + Each row requires 1 iteration, therefore 16 iterations are present + */ + MOV(VAR(a11), RAX) + MOV(R11, RCX) + MOV(R14, RDX) + + LEA(MEM(RDI, RDI, 4), R14) + LEA(MEM(R14, R14, 2), R14) // R14 = RDI * 15 + LEA(MEM(RCX, R14, 1), RCX) // rcx = b11 + (16-1)*rs_b + LEA(MEM(RDX, R14, 1), RDX) // rdx = b11 + (16-1)*rs_b + 8*cs_b + + //iteration 0 -------------------------------------------- + VBROADCASTSD(MEM(RAX, (15+15*16)*8), ZMM(0)) + + VMOVUPD(MEM(R8, R12, 4), ZMM(1)) + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + VMULPD(ZMM(0), ZMM(19), ZMM(19)) + VMULPD(ZMM(0), ZMM(1) , ZMM(1) ) + #else + VDIVPD(ZMM(0), ZMM(19),ZMM(19) ) + VDIVPD(ZMM(0), ZMM(1) , ZMM(1) ) + #endif + VMOVUPD(ZMM(1), MEM(R8, R12, 4)) + + VMOVUPD(ZMM(19), MEM(RCX)) + VEXTRACTF64X4(IMM(1), ZMM(1), YMM(0)) + VMOVUPD(YMM(1), MEM(RDX )) + VMOVUPD(XMM(0), MEM(RDX,4*8)) // move only first six values to rcx + SUB(RDI, RCX) + SUB(RDI, RDX) + + //iteration 1 -------------------------------------------- + VBROADCASTSD(MEM(RAX, (14+15*16)*8), ZMM(0)) + VMOVUPD(MEM(R8, R12, 4), ZMM(1)) + + VMULPD(ZMM(0), ZMM(19), ZMM(2)) + VMULPD(ZMM(0), ZMM(1) , ZMM(3)) + + VBROADCASTSD(MEM(RAX, (14+14*16)*8), ZMM(1)) + VMOVUPD(MEM(R8, R12, 2), ZMM(0)) + + VSUBPD(ZMM(2), ZMM(17), ZMM(17)) + VSUBPD(ZMM(3), ZMM(0) , ZMM(0) ) + + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + VMULPD(ZMM(1), ZMM(17), ZMM(17)) + VMULPD(ZMM(1), ZMM(0) , ZMM(0) ) + #else + VDIVPD(ZMM(1), ZMM(17), ZMM(17)) + VDIVPD(ZMM(1), ZMM(0) , ZMM(0) ) + #endif + + VMOVUPD(ZMM(0), MEM(R8, R12, 2)) + + VMOVUPD(ZMM(17), MEM(RCX)) + + VEXTRACTF64X4(IMM(1), ZMM(0), YMM(1)) + VMOVUPD(YMM(0), MEM(RDX )) + VMOVUPD(XMM(1), MEM(RDX,4*8)) + SUB(RDI, RCX) + SUB(RDI, RDX) + + //iteration 2 -------------------------------------------- + VBROADCASTSD(MEM(RAX, (13+15*16)*8), ZMM(0)) + VMOVUPD(MEM(R8, R12, 4), ZMM(1)) + + VMULPD(ZMM(0), ZMM(19), ZMM(2)) + VMULPD(ZMM(0), ZMM(1) , ZMM(3)) + + + VBROADCASTSD(MEM(RAX, (13+14*16)*8), ZMM(1)) + VMOVUPD(MEM(R8, R12, 2), ZMM(0)) + + VFMADD231PD(ZMM(1), ZMM(17), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(0) , ZMM(3)) + + VBROADCASTSD(MEM(RAX, (13+13*16)*8), ZMM(0)) + VSUBPD(ZMM(2), ZMM(15), ZMM(15)) + VSUBPD(ZMM(3), ZMM(31), ZMM(31)) + + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + VMULPD(ZMM(0), ZMM(15), ZMM(15)) + VMULPD(ZMM(0), ZMM(31), ZMM(31)) + #else + VDIVPD(ZMM(0), ZMM(15), ZMM(15)) + VDIVPD(ZMM(0), ZMM(31), ZMM(31)) + #endif + + VMOVUPD(ZMM(15), MEM(RCX)) + VEXTRACTF64X4(IMM(1), ZMM(31), YMM(0)) + VMOVUPD(YMM(31), MEM(RDX )) + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + SUB(RDI, RCX) + SUB(RDI, RDX) + + //iteration 3 -------------------------------------------- + VBROADCASTSD(MEM(RAX, (12+15*16)*8), ZMM(0)) + VMOVUPD(MEM(R8, R12, 4), ZMM(1)) + + 
VMULPD(ZMM(0), ZMM(19), ZMM(2)) + VMULPD(ZMM(0), ZMM(1) , ZMM(3)) + + VBROADCASTSD(MEM(RAX, (12+14*16)*8), ZMM(1)) + VMOVUPD(MEM(R8, R12, 2), ZMM(0)) + + VFMADD231PD(ZMM(1), ZMM(17), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(0) , ZMM(3)) + + VBROADCASTSD(MEM(RAX, (12+13*16)*8), ZMM(0)) + VBROADCASTSD(MEM(RAX, (12+12*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(15), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(31), ZMM(3)) + + VSUBPD(ZMM(2), ZMM(13), ZMM(13)) + VSUBPD(ZMM(3), ZMM(29), ZMM(29)) + + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + VMULPD(ZMM(1), ZMM(13), ZMM(13)) + VMULPD(ZMM(1), ZMM(29), ZMM(29)) + #else + VDIVPD(ZMM(1), ZMM(13), ZMM(13)) + VDIVPD(ZMM(1), ZMM(29), ZMM(29)) + #endif + + VMOVUPD(ZMM(13), MEM(RCX)) + VEXTRACTF64X4(IMM(1), ZMM(29), YMM(0)) + VMOVUPD(YMM(29), MEM(RDX )) + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + SUB(RDI, RCX) + SUB(RDI, RDX) + + //iteration 4 -------------------------------------------- + VBROADCASTSD(MEM(RAX, (11+15*16)*8), ZMM(0)) + VMOVUPD(MEM(R8, R12, 4), ZMM(1)) + + VMULPD(ZMM(0), ZMM(19), ZMM(2)) + VMULPD(ZMM(0), ZMM(1) , ZMM(3)) + + VBROADCASTSD(MEM(RAX, (11+14*16)*8), ZMM(1)) + VMOVUPD(MEM(R8, R12, 2), ZMM(0)) + + VFMADD231PD(ZMM(1), ZMM(17), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(0) , ZMM(3)) + + VBROADCASTSD(MEM(RAX, (11+13*16)*8), ZMM(0)) + VBROADCASTSD(MEM(RAX, (11+12*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(15), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(31), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (11+11*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(13), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(29), ZMM(3)) + + VSUBPD(ZMM(2), ZMM(11), ZMM(11)) + VSUBPD(ZMM(3), ZMM(27), ZMM(27)) + + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + VMULPD(ZMM(0), ZMM(11), ZMM(11)) + VMULPD(ZMM(0), ZMM(27), ZMM(27)) + #else + VDIVPD(ZMM(0), ZMM(11), ZMM(11)) + VDIVPD(ZMM(0), ZMM(27), ZMM(27)) + #endif + + VMOVUPD(ZMM(11), MEM(RCX)) + VEXTRACTF64X4(IMM(1), ZMM(27), YMM(0)) + VMOVUPD(YMM(27), MEM(RDX )) + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + SUB(RDI, RCX) + SUB(RDI, RDX) + + //iteration 5 -------------------------------------------- + VBROADCASTSD(MEM(RAX, (10+15*16)*8), ZMM(0)) + VMOVUPD(MEM(R8, R12, 4), ZMM(1)) + + VMULPD(ZMM(0), ZMM(19), ZMM(2)) + VMULPD(ZMM(0), ZMM(1) , ZMM(3)) + + VBROADCASTSD(MEM(RAX, (10+14*16)*8), ZMM(1)) + VMOVUPD(MEM(R8, R12, 2), ZMM(0)) + + VFMADD231PD(ZMM(1), ZMM(17), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(0) , ZMM(3)) + + VBROADCASTSD(MEM(RAX, (10+13*16)*8), ZMM(0)) + VBROADCASTSD(MEM(RAX, (10+12*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(15), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(31), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (10+11*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(13), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(29), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (10+10*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(11), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(27), ZMM(3)) + + VSUBPD(ZMM(2), ZMM(9) , ZMM(9) ) + VSUBPD(ZMM(3), ZMM(25), ZMM(25)) + + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + VMULPD(ZMM(1), ZMM(9) , ZMM(9) ) + VMULPD(ZMM(1), ZMM(25), ZMM(25)) + #else + VDIVPD(ZMM(1), ZMM(9) , ZMM(9) ) + VDIVPD(ZMM(1), ZMM(25), ZMM(25)) + #endif + + VMOVUPD(ZMM(9) , MEM(RCX)) + VEXTRACTF64X4(IMM(1), ZMM(25), YMM(0)) + VMOVUPD(YMM(25), MEM(RDX )) + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + SUB(RDI, RCX) + SUB(RDI, RDX) + + //iteration 6 -------------------------------------------- + VBROADCASTSD(MEM(RAX, (9+15*16)*8), ZMM(0)) + VMOVUPD(MEM(R8, R12, 4), ZMM(1)) + + VMULPD(ZMM(0), ZMM(19), ZMM(2)) + VMULPD(ZMM(0), ZMM(1) , ZMM(3)) + + VBROADCASTSD(MEM(RAX, (9+14*16)*8), ZMM(1)) + VMOVUPD(MEM(R8, R12, 2), ZMM(0)) + + VFMADD231PD(ZMM(1), ZMM(17), ZMM(2)) + 
VFMADD231PD(ZMM(1), ZMM(0) , ZMM(3)) + + VBROADCASTSD(MEM(RAX, (9+13*16)*8), ZMM(0)) + VBROADCASTSD(MEM(RAX, (9+12*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(15), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(31), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (9+11*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(13), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(29), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (9+10*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(11), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(27), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (9+9*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(9) , ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(25), ZMM(3)) + + VSUBPD(ZMM(2), ZMM(7) , ZMM(7) ) + VSUBPD(ZMM(3), ZMM(23), ZMM(23)) + + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + VMULPD(ZMM(0), ZMM(7) , ZMM(7) ) + VMULPD(ZMM(0), ZMM(23), ZMM(23)) + #else + VDIVPD(ZMM(0), ZMM(7) , ZMM(7) ) + VDIVPD(ZMM(0), ZMM(23), ZMM(23)) + #endif + + VMOVUPD(ZMM(7), MEM(RCX)) + VEXTRACTF64X4(IMM(1), ZMM(23), YMM(0)) + VMOVUPD(YMM(23), MEM(RDX )) + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + SUB(RDI, RCX) + SUB(RDI, RDX) + + //iteration 7 -------------------------------------------- + VBROADCASTSD(MEM(RAX, (8+15*16)*8), ZMM(0)) + VMOVUPD(MEM(R8, R12, 4), ZMM(1)) + + VMULPD(ZMM(0), ZMM(19), ZMM(2)) + VMULPD(ZMM(0), ZMM(1) , ZMM(3)) + + VBROADCASTSD(MEM(RAX, (8+14*16)*8), ZMM(1)) + VMOVUPD(MEM(R8, R12, 2), ZMM(0)) + + VFMADD231PD(ZMM(1), ZMM(17), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(0) , ZMM(3)) + + VBROADCASTSD(MEM(RAX, (8+13*16)*8), ZMM(0)) + VBROADCASTSD(MEM(RAX, (8+12*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(15), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(31), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (8+11*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(13), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(29), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (8+10*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(11), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(27), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (8+9*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(9) , ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(25), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (8+8*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(7) , ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(23), ZMM(3)) + + VSUBPD(ZMM(2), ZMM(5) , ZMM(5) ) + VSUBPD(ZMM(3), ZMM(21), ZMM(21)) + + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + VMULPD(ZMM(1), ZMM(5) , ZMM(5) ) + VMULPD(ZMM(1), ZMM(21), ZMM(21)) + #else + VDIVPD(ZMM(1), ZMM(5) , ZMM(5) ) + VDIVPD(ZMM(1), ZMM(21), ZMM(21)) + #endif + + VMOVUPD(ZMM(5) , MEM(RCX)) + VEXTRACTF64X4(IMM(1), ZMM(21), YMM(1)) + VMOVUPD(YMM(21), MEM(RDX )) + VMOVUPD(XMM(1), MEM(RDX,4*8)) + SUB(RDI, RCX) + SUB(RDI, RDX) + + //iteration 8 -------------------------------------------- + VBROADCASTSD(MEM(RAX, (7+15*16)*8), ZMM(0)) + VMOVUPD(MEM(R8, R12, 4), ZMM(1)) + + VMULPD(ZMM(0), ZMM(19), ZMM(2)) + VMULPD(ZMM(0), ZMM(1) , ZMM(3)) + + VBROADCASTSD(MEM(RAX, (7+14*16)*8), ZMM(1)) + VMOVUPD(MEM(R8, R12, 2), ZMM(0)) + + VFMADD231PD(ZMM(1), ZMM(17), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(0) , ZMM(3)) + + VBROADCASTSD(MEM(RAX, (7+13*16)*8), ZMM(0)) + VBROADCASTSD(MEM(RAX, (7+12*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(15), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(31), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (7+11*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(13), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(29), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (7+10*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(11), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(27), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (7+9*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(9) , ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(25), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (7+8*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(7) , ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(23), 
ZMM(3)) + + VBROADCASTSD(MEM(RAX, (7+7*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(5), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(21) , ZMM(3)) + + VMOVUPD(MEM(R8, R12, 1), ZMM(1)) + VSUBPD(ZMM(2), ZMM(18), ZMM(18)) + VSUBPD(ZMM(3), ZMM(1) , ZMM(1) ) + + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + VMULPD(ZMM(0), ZMM(18), ZMM(18)) + VMULPD(ZMM(0), ZMM(1) , ZMM(1) ) + #else + VDIVPD(ZMM(0), ZMM(18), ZMM(18)) + VDIVPD(ZMM(0), ZMM(1) , ZMM(1) ) + #endif + VMOVUPD(ZMM(1), MEM(R8, R12, 1)) + + VMOVUPD(ZMM(18), MEM(RCX)) + VEXTRACTF64X4(IMM(1), ZMM(1), YMM(0)) + VMOVUPD(YMM(1), MEM(RDX )) + VMOVUPD(XMM(0), MEM(RDX,4*8)) + SUB(RDI, RCX) + SUB(RDI, RDX) + + //iteration 9 -------------------------------------------- + VBROADCASTSD(MEM(RAX, (6+15*16)*8), ZMM(0)) + VMOVUPD(MEM(R8, R12, 4), ZMM(1)) + + VMULPD(ZMM(0), ZMM(19), ZMM(2)) + VMULPD(ZMM(0), ZMM(1) , ZMM(3)) + + VBROADCASTSD(MEM(RAX, (6+14*16)*8), ZMM(1)) + VMOVUPD(MEM(R8, R12, 2), ZMM(0)) + + VFMADD231PD(ZMM(1), ZMM(17), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(0) , ZMM(3)) + + VBROADCASTSD(MEM(RAX, (6+13*16)*8), ZMM(0)) + VBROADCASTSD(MEM(RAX, (6+12*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(15), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(31), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (6+11*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(13), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(29), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (6+10*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(11), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(27), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (6+9*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(9) , ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(25), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (6+8*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(7) , ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(23), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (6+7*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(5) , ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(21), ZMM(3)) + + VMOVUPD(MEM(R8, R12, 1), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(18), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(1) , ZMM(3)) + VBROADCASTSD(MEM(RAX, (6+6*16)*8), ZMM(1)) + + VMOVUPD(MEM(R8 ), ZMM(0)) + VSUBPD(ZMM(2), ZMM(16), ZMM(16)) + VSUBPD(ZMM(3), ZMM(0) , ZMM(0) ) + + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + VMULPD(ZMM(1), ZMM(16), ZMM(16)) + VMULPD(ZMM(1), ZMM(0) , ZMM(0) ) + #else + VDIVPD(ZMM(1), ZMM(16), ZMM(16)) + VDIVPD(ZMM(1), ZMM(0) , ZMM(0) ) + #endif + + VMOVUPD(ZMM(0), MEM(R8 )) + VMOVUPD(ZMM(16), MEM(RCX)) + VEXTRACTF64X4(IMM(1), ZMM(0), YMM(1)) + VMOVUPD(YMM(0), MEM(RDX )) + VMOVUPD(XMM(1), MEM(RDX,4*8)) + SUB(RDI, RCX) + SUB(RDI, RDX) + + //iteration 10 -------------------------------------------- + VBROADCASTSD(MEM(RAX, (5+15*16)*8), ZMM(0)) + VMOVUPD(MEM(R8, R12, 4), ZMM(1)) + + VMULPD(ZMM(0), ZMM(19), ZMM(2)) + VMULPD(ZMM(0), ZMM(1) , ZMM(3)) + + VBROADCASTSD(MEM(RAX, (5+14*16)*8), ZMM(1)) + VMOVUPD(MEM(R8, R12, 2), ZMM(0)) + + VFMADD231PD(ZMM(1), ZMM(17), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(0) , ZMM(3)) + + VBROADCASTSD(MEM(RAX, (5+13*16)*8), ZMM(0)) + VBROADCASTSD(MEM(RAX, (5+12*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(15), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(31), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (5+11*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(13), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(29), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (5+10*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(11), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(27), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (5+9*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(9) , ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(25), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (5+8*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(7) , ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(23), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (5+7*16)*8), 
ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(5) , ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(21), ZMM(3)) + + VMOVUPD(MEM(R8, R12, 1), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(18), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(1) , ZMM(3)) + VBROADCASTSD(MEM(RAX, (5+6*16)*8), ZMM(1)) + + VMOVUPD(MEM(R8 ), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(16), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(0) , ZMM(3)) + VBROADCASTSD(MEM(RAX, (5+5*16)*8), ZMM(0)) + + VSUBPD(ZMM(2), ZMM(14), ZMM(14)) + VSUBPD(ZMM(3), ZMM(30), ZMM(30)) + + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + VMULPD(ZMM(0), ZMM(14), ZMM(14)) + VMULPD(ZMM(0), ZMM(30), ZMM(30)) + #else + VDIVPD(ZMM(0), ZMM(14), ZMM(14)) + VDIVPD(ZMM(0), ZMM(30), ZMM(30)) + #endif + + VMOVUPD(ZMM(14), MEM(RCX)) + VEXTRACTF64X4(IMM(1), ZMM(30), YMM(1)) + VMOVUPD(YMM(30), MEM(RDX )) + VMOVUPD(XMM(1), MEM(RDX,4*8)) + SUB(RDI, RCX) + SUB(RDI, RDX) + + //iteration 11 -------------------------------------------- + VBROADCASTSD(MEM(RAX, (4+15*16)*8), ZMM(0)) + VMOVUPD(MEM(R8, R12, 4), ZMM(1)) + + VMULPD(ZMM(0), ZMM(19), ZMM(2)) + VMULPD(ZMM(0), ZMM(1) , ZMM(3)) + + VBROADCASTSD(MEM(RAX, (4+14*16)*8), ZMM(1)) + VMOVUPD(MEM(R8, R12, 2), ZMM(0)) + + VFMADD231PD(ZMM(1), ZMM(17), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(0) , ZMM(3)) + + VBROADCASTSD(MEM(RAX, (4+13*16)*8), ZMM(0)) + VBROADCASTSD(MEM(RAX, (4+12*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(15), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(31), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (4+11*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(13), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(29), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (4+10*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(11), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(27), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (4+9*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(9) , ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(25), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (4+8*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(7) , ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(23), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (4+7*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(5) , ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(21), ZMM(3)) + + VMOVUPD(MEM(R8, R12, 1), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(18), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(1) , ZMM(3)) + VBROADCASTSD(MEM(RAX, (4+6*16)*8), ZMM(1)) + + VMOVUPD(MEM(R8 ), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(16), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(0) , ZMM(3)) + VBROADCASTSD(MEM(RAX, (4+5*16)*8), ZMM(0)) + + VBROADCASTSD(MEM(RAX, (4+4*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(14), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(30), ZMM(3)) + + VSUBPD(ZMM(2), ZMM(12), ZMM(12)) + VSUBPD(ZMM(3), ZMM(28), ZMM(28)) + + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + VMULPD(ZMM(1), ZMM(12), ZMM(12)) + VMULPD(ZMM(1), ZMM(28), ZMM(28)) + #else + VDIVPD(ZMM(1), ZMM(12), ZMM(12)) + VDIVPD(ZMM(1), ZMM(28), ZMM(28)) + #endif + + VMOVUPD(ZMM(12), MEM(RCX)) + VEXTRACTF64X4(IMM(1), ZMM(28), YMM(1)) + VMOVUPD(YMM(28), MEM(RDX )) + VMOVUPD(XMM(1), MEM(RDX,4*8)) + SUB(RDI, RCX) + SUB(RDI, RDX) + + //iteration 12 -------------------------------------------- + VBROADCASTSD(MEM(RAX, (3+15*16)*8), ZMM(0)) + VMOVUPD(MEM(R8, R12, 4), ZMM(1)) + + VMULPD(ZMM(0), ZMM(19), ZMM(2)) + VMULPD(ZMM(0), ZMM(1) , ZMM(3)) + + VBROADCASTSD(MEM(RAX, (3+14*16)*8), ZMM(1)) + VMOVUPD(MEM(R8, R12, 2), ZMM(0)) + + VFMADD231PD(ZMM(1), ZMM(17), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(0) , ZMM(3)) + + VBROADCASTSD(MEM(RAX, (3+13*16)*8), ZMM(0)) + VBROADCASTSD(MEM(RAX, (3+12*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(15), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(31), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (3+11*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(13), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(29), 
ZMM(3)) + + VBROADCASTSD(MEM(RAX, (3+10*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(11), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(27), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (3+9*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(9) , ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(25), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (3+8*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(7) , ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(23), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (3+7*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(5) , ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(21), ZMM(3)) + + VMOVUPD(MEM(R8, R12, 1), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(18), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(1) , ZMM(3)) + VBROADCASTSD(MEM(RAX, (3+6*16)*8), ZMM(1)) + + VMOVUPD(MEM(R8 ), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(16), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(0) , ZMM(3)) + VBROADCASTSD(MEM(RAX, (3+5*16)*8), ZMM(0)) + + VBROADCASTSD(MEM(RAX, (3+4*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(14), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(30), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (3+3*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(12), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(28), ZMM(3)) + + VSUBPD(ZMM(2), ZMM(10), ZMM(10)) + VSUBPD(ZMM(3), ZMM(26), ZMM(26)) + + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + VMULPD(ZMM(0), ZMM(10), ZMM(10)) + VMULPD(ZMM(0), ZMM(26), ZMM(26)) + #else + VDIVPD(ZMM(0), ZMM(10), ZMM(10)) + VDIVPD(ZMM(0), ZMM(26), ZMM(26)) + #endif + + VMOVUPD(ZMM(10), MEM(RCX)) + VEXTRACTF64X4(IMM(1), ZMM(26), YMM(1)) + VMOVUPD(YMM(26), MEM(RDX )) + VMOVUPD(XMM(1), MEM(RDX,4*8)) + SUB(RDI, RCX) + SUB(RDI, RDX) + + //iteration 13 -------------------------------------------- + VBROADCASTSD(MEM(RAX, (2+15*16)*8), ZMM(0)) + VMOVUPD(MEM(R8, R12, 4), ZMM(1)) + + VMULPD(ZMM(0), ZMM(19), ZMM(2)) + VMULPD(ZMM(0), ZMM(1) , ZMM(3)) + + VBROADCASTSD(MEM(RAX, (2+14*16)*8), ZMM(1)) + VMOVUPD(MEM(R8, R12, 2), ZMM(0)) + + VFMADD231PD(ZMM(1), ZMM(17), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(0) , ZMM(3)) + + VBROADCASTSD(MEM(RAX, (2+13*16)*8), ZMM(0)) + VBROADCASTSD(MEM(RAX, (2+12*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(15), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(31), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (2+11*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(13), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(29), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (2+10*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(11), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(27), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (2+9*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(9) , ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(25), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (2+8*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(7) , ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(23), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (2+7*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(5) , ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(21), ZMM(3)) + + VMOVUPD(MEM(R8, R12, 1), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(18), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(1) , ZMM(3)) + VBROADCASTSD(MEM(RAX, (2+6*16)*8), ZMM(1)) + + VMOVUPD(MEM(R8 ), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(16), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(0) , ZMM(3)) + VBROADCASTSD(MEM(RAX, (2+5*16)*8), ZMM(0)) + + VBROADCASTSD(MEM(RAX, (2+4*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(14), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(30), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (2+3*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(12), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(28), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (2+2*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(10), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(26), ZMM(3)) + + VSUBPD(ZMM(2), ZMM(8), ZMM(8)) + VSUBPD(ZMM(3), ZMM(24), ZMM(24)) + + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + VMULPD(ZMM(1), ZMM(8), ZMM(8)) + VMULPD(ZMM(1), ZMM(24), ZMM(24)) + 
#else + VDIVPD(ZMM(1), ZMM(8), ZMM(8)) + VDIVPD(ZMM(1), ZMM(24), ZMM(24)) + #endif + + VMOVUPD(ZMM(8), MEM(RCX)) + VEXTRACTF64X4(IMM(1), ZMM(24), YMM(1)) + VMOVUPD(YMM(24), MEM(RDX )) + VMOVUPD(XMM(1), MEM(RDX,4*8)) + SUB(RDI, RCX) + SUB(RDI, RDX) + + //iteration 14 -------------------------------------------- + VBROADCASTSD(MEM(RAX, (1+15*16)*8), ZMM(0)) + VMOVUPD(MEM(R8, R12, 4), ZMM(1)) + + VMULPD(ZMM(0), ZMM(19), ZMM(2)) + VMULPD(ZMM(0), ZMM(1) , ZMM(3)) + + VBROADCASTSD(MEM(RAX, (1+14*16)*8), ZMM(1)) + VMOVUPD(MEM(R8, R12, 2), ZMM(0)) + + VFMADD231PD(ZMM(1), ZMM(17), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(0) , ZMM(3)) + + VBROADCASTSD(MEM(RAX, (1+13*16)*8), ZMM(0)) + VBROADCASTSD(MEM(RAX, (1+12*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(15), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(31), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (1+11*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(13), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(29), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (1+10*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(11), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(27), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (1+9*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(9) , ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(25), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (1+8*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(7) , ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(23), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (1+7*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(5) , ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(21), ZMM(3)) + + VMOVUPD(MEM(R8, R12, 1), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(18), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(1) , ZMM(3)) + VBROADCASTSD(MEM(RAX, (1+6*16)*8), ZMM(1)) + + VMOVUPD(MEM(R8 ), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(16), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(0) , ZMM(3)) + VBROADCASTSD(MEM(RAX, (1+5*16)*8), ZMM(0)) + + VBROADCASTSD(MEM(RAX, (1+4*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(14), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(30), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (1+3*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(12), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(28), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (1+2*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(10), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(26), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (1+1*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(8) , ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(24), ZMM(3)) + + VSUBPD(ZMM(2), ZMM(6) , ZMM(6) ) + VSUBPD(ZMM(3), ZMM(22), ZMM(22)) + + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + VMULPD(ZMM(0), ZMM(6) , ZMM(6) ) + VMULPD(ZMM(0), ZMM(22), ZMM(22)) + #else + VDIVPD(ZMM(0), ZMM(6) , ZMM(6) ) + VDIVPD(ZMM(0), ZMM(22), ZMM(22)) + #endif + + VMOVUPD(ZMM(6), MEM(RCX)) + VEXTRACTF64X4(IMM(1), ZMM(22), YMM(0)) + VMOVUPD(YMM(22), MEM(RDX )) + VMOVUPD(XMM(0), MEM(RDX,4*8)) + SUB(RDI, RCX) + SUB(RDI, RDX) + + //iteration 15 -------------------------------------------- + VBROADCASTSD(MEM(RAX, (0+15*16)*8), ZMM(0)) + VMOVUPD(MEM(R8, R12, 4), ZMM(1)) + + VMULPD(ZMM(0), ZMM(19), ZMM(2)) + VMULPD(ZMM(0), ZMM(1) , ZMM(3)) + + VBROADCASTSD(MEM(RAX, (0+14*16)*8), ZMM(1)) + VMOVUPD(MEM(R8, R12, 2), ZMM(0)) + + VFMADD231PD(ZMM(1), ZMM(17), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(0) , ZMM(3)) + + VBROADCASTSD(MEM(RAX, (0+13*16)*8), ZMM(0)) + VBROADCASTSD(MEM(RAX, (0+12*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(15), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(31), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (0+11*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(13), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(29), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (0+10*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(11), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(27), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (0+9*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(9) , 
ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(25), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (0+8*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(7) , ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(23), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (0+7*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(5) , ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(21), ZMM(3)) + + VMOVUPD(MEM(R8, R12, 1), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(18), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(1) , ZMM(3)) + VBROADCASTSD(MEM(RAX, (0+6*16)*8), ZMM(1)) + + VMOVUPD(MEM(R8 ), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(16), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(0) , ZMM(3)) + VBROADCASTSD(MEM(RAX, (0+5*16)*8), ZMM(0)) + + VBROADCASTSD(MEM(RAX, (0+4*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(14), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(30), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (0+3*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(12), ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(28), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (0+2*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(10), ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(26), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (0+1*16)*8), ZMM(0)) + VFMADD231PD(ZMM(1), ZMM(8) , ZMM(2)) + VFMADD231PD(ZMM(1), ZMM(24), ZMM(3)) + + VBROADCASTSD(MEM(RAX, (0+0*16)*8), ZMM(1)) + VFMADD231PD(ZMM(0), ZMM(6) , ZMM(2)) + VFMADD231PD(ZMM(0), ZMM(22), ZMM(3)) + + VSUBPD(ZMM(2), ZMM(4) , ZMM(4) ) + VSUBPD(ZMM(3), ZMM(20), ZMM(20)) + + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + VMULPD(ZMM(1), ZMM(4) , ZMM(4) ) + VMULPD(ZMM(1), ZMM(20), ZMM(20)) + #else + VDIVPD(ZMM(1), ZMM(4) , ZMM(4) ) + VDIVPD(ZMM(1), ZMM(20), ZMM(20)) + #endif + + VMOVUPD(ZMM(4), MEM(RCX)) + VEXTRACTF64X4(IMM(1), ZMM(20), YMM(1)) + VMOVUPD(YMM(20), MEM(RDX )) + VMOVUPD(XMM(1) , MEM(RDX,4*8)) + SUB(RDI, RCX) + SUB(RDI, RDX) + + /* + Storage Region (Post TRSM) + */ + MOV(R8, RCX) + MOV(R9, RDI) + MOV(VAR(cs_c), RSI) + + LEA(MEM(RCX, RSI, 8), RDX) // rdx = rcx + cs_c * 8 + LEA(MEM(RCX, RDI, 8), R14) // r14 = rcx + rs_c * 8 + + LEA(MEM(RSI, RSI, 2), R12) // cs_c * 3 + LEA(MEM(RSI, RSI, 4), R13) // cs_c * 5 + LEA(MEM(R13, RSI, 2), R15) // cs_c * 7 + + CMP(IMM(8), RSI) + JZ(.DROWSTORED) + + CMP(IMM(8), RDI) + JZ(.DCOLSTORED) + + LABEL(.DROWSTORED) + + VMOVUPD(MEM(R8 ), ZMM(1)) + VMOVUPD(ZMM(4), MEM(RCX)) + VMOVUPD(MEM(R8, RDI, 1), ZMM(4)) + + ADD(RDI, RCX) + VEXTRACTF64X4(IMM(1), ZMM(20), YMM(0)) + VMOVUPD(YMM(20), MEM(RDX )) + + VMOVUPD(MEM(R8, RDI, 2), ZMM(20)) + + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + ADD(RDI, RDX) + + VMOVUPD(ZMM(6), MEM(RCX)) + + VMOVUPD(MEM(R8, RDI, 4), ZMM(6)) + + ADD(RDI, RCX) + VEXTRACTF64X4(IMM(1), ZMM(22), YMM(0)) + VMOVUPD(YMM(22), MEM(RDX )) + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + ADD(RDI, RDX) + + VMOVUPD(ZMM(8), MEM(RCX)) + ADD(RDI, RCX) + VEXTRACTF64X4(IMM(1), ZMM(24), YMM(0)) + VMOVUPD(YMM(24), MEM(RDX )) + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + ADD(RDI, RDX) + + VMOVUPD(ZMM(10), MEM(RCX)) + ADD(RDI, RCX) + VEXTRACTF64X4(IMM(1), ZMM(26), YMM(0)) + VMOVUPD(YMM(26), MEM(RDX )) + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + ADD(RDI, RDX) + + VMOVUPD(ZMM(12), MEM(RCX)) + ADD(RDI, RCX) + VEXTRACTF64X4(IMM(1), ZMM(28), YMM(0)) + VMOVUPD(YMM(28), MEM(RDX )) + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + ADD(RDI, RDX) + + VMOVUPD(ZMM(14), MEM(RCX)) + ADD(RDI, RCX) + VEXTRACTF64X4(IMM(1), ZMM(30), YMM(0)) + VMOVUPD(YMM(30), MEM(RDX )) + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + ADD(RDI, RDX) + + VMOVUPD(ZMM(16), MEM(RCX)) + ADD(RDI, RCX) + VEXTRACTF64X4(IMM(1), ZMM(1), YMM(0)) + VMOVUPD(YMM(1), MEM(RDX )) + VMOVUPD(XMM(0), MEM(RDX,4*8)) + ADD(RDI, RDX) + + VMOVUPD(ZMM(18), MEM(RCX)) + ADD(RDI, RCX) + VEXTRACTF64X4(IMM(1), ZMM(4), YMM(0)) + VMOVUPD(YMM(4), MEM(RDX )) + VMOVUPD(XMM(0), MEM(RDX,4*8)) + 
ADD(RDI, RDX) + + VMOVUPD(ZMM(5), MEM(RCX)) + ADD(RDI, RCX) + VEXTRACTF64X4(IMM(1), ZMM(21), YMM(0)) + VMOVUPD(YMM(21), MEM(RDX )) + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + ADD(RDI, RDX) + + VMOVUPD(ZMM(7), MEM(RCX)) + ADD(RDI, RCX) + VEXTRACTF64X4(IMM(1), ZMM(23), YMM(0)) + VMOVUPD(YMM(23), MEM(RDX )) + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + ADD(RDI, RDX) + + VMOVUPD(ZMM(9), MEM(RCX)) + ADD(RDI, RCX) + VEXTRACTF64X4(IMM(1), ZMM(25), YMM(0)) + VMOVUPD(YMM(25), MEM(RDX )) + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + ADD(RDI, RDX) + + VMOVUPD(ZMM(11), MEM(RCX)) + ADD(RDI, RCX) + VEXTRACTF64X4(IMM(1), ZMM(27), YMM(0)) + VMOVUPD(YMM(27), MEM(RDX )) + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + ADD(RDI, RDX) + + VMOVUPD(ZMM(13), MEM(RCX)) + ADD(RDI, RCX) + VEXTRACTF64X4(IMM(1), ZMM(29), YMM(0)) + VMOVUPD(YMM(29), MEM(RDX )) + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + ADD(RDI, RDX) + + VMOVUPD(ZMM(15), MEM(RCX)) + ADD(RDI, RCX) + VEXTRACTF64X4(IMM(1), ZMM(31), YMM(0)) + VMOVUPD(YMM(31), MEM(RDX )) + VMOVUPD(XMM(0) , MEM(RDX,4*8)) + ADD(RDI, RDX) + + VMOVUPD(ZMM(17), MEM(RCX)) + ADD(RDI, RCX) + VEXTRACTF64X4(IMM(1), ZMM(20), YMM(0)) + VMOVUPD(YMM(20), MEM(RDX )) + VMOVUPD(XMM(0), MEM(RDX,4*8)) + ADD(RDI, RDX) + + VMOVUPD(ZMM(19), MEM(RCX)) + VEXTRACTF64X4(IMM(1), ZMM(6), YMM(0)) + VMOVUPD(YMM(6), MEM(RDX )) + VMOVUPD(XMM(0), MEM(RDX,4*8)) + + + JMP(.DDONE) + LABEL(.DCOLSTORED) + + + MOV(VAR(offsetPtr), R12) + LEA(MEM(RCX, RSI, 8), RDX) // rdx = rcx + cs_c * 8 + VPBROADCASTQ(RSI, ZMM(0)) + VPMULLQ(MEM(R12 ), ZMM(0), ZMM(2)) + VPMULLQ(MEM(R12,64), ZMM(0), ZMM(3)) // load offsets in zmm2, zmm3 + + VMOVUPD(MEM(RCX ), ZMM(0)) + VMOVUPD(MEM(RCX, RSI, 1), ZMM(1)) + + MOV(RDX, RCX) + LEA(MEM(RCX, RDI, 8), R14) + UPDATE_C_COL_SCATTERED_2x6(20,21) + VMOVUPD(MEM(R8, RSI, 2), ZMM(20)) + VMOVUPD(MEM(R8, RSI, 4), ZMM(21)) + UPDATE_C_COL_SCATTERED_2x6(22,23) + UPDATE_C_COL_SCATTERED_2x6(24,25) + UPDATE_C_COL_SCATTERED_2x6(26,27) + UPDATE_C_COL_SCATTERED_2x6(28,29) + UPDATE_C_COL_SCATTERED_2x6(30,31) + UPDATE_C_COL_SCATTERED_2x6(0 ,20 ) + UPDATE_C_COL_SCATTERED_2x6(1 ,21 ) + + MOV(R8, RCX) + LEA(MEM(RCX, RDI, 8), R14) + UPDATE_C_COL_SCATTERED( 4, 5) + UPDATE_C_COL_SCATTERED( 6, 7) + UPDATE_C_COL_SCATTERED( 8, 9) + UPDATE_C_COL_SCATTERED(10,11) + UPDATE_C_COL_SCATTERED(12,13) + UPDATE_C_COL_SCATTERED(14,15) + UPDATE_C_COL_SCATTERED(16,17) + UPDATE_C_COL_SCATTERED(18,19) + + + LABEL(.DDONE) + + VZEROUPPER() + + end_asm( + : // output operands (none) + : // input operands + [a10] "m" (a10), // 1 + [k] "m" (k), // 2 + [b01] "m" (b01), // 3 + [a11] "m" (a11), // 6 + [b11] "m" (b11), // 7 + [c11] "m" (c11), // 8 + [rs_c] "m" (rs_c), // 9 + [cs_c] "m" (cs_c), // 10, + [alpha] "m" (alpha), + [offsetPtr] "m" (offsetPtr) + : // register clobber list + "rax", "rbx", "rcx", "rdx", "rdi", "rsi", "r8", "r9", "r10", "r11", "r12", + "r13", "r14", "r15", "zmm0", "zmm1", "zmm2", "zmm3", "zmm4", "zmm5", + "zmm6", "zmm7", "zmm8", "zmm9", "zmm10", "zmm11", "zmm12", "zmm13", + "zmm14", "zmm15", "zmm16", "zmm17", "zmm18", "zmm19", "zmm20", "zmm21", + "zmm22", "zmm23", "zmm24", "zmm25", "zmm26", "zmm27", "zmm28", "zmm29", + "zmm30", "zmm31", "memory" + ) + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_9); +} \ No newline at end of file diff --git a/kernels/zen4/CMakeLists.txt b/kernels/zen4/CMakeLists.txt index 827d91bbe2..c22c5ba143 100644 --- a/kernels/zen4/CMakeLists.txt +++ b/kernels/zen4/CMakeLists.txt @@ -1,5 +1,6 @@ ##Copyright (C) 2022, Advanced Micro Devices, Inc. 
All rights reserved.## add_subdirectory(1) +add_subdirectory(3) diff --git a/kernels/zen4/bli_kernels_zen4.h b/kernels/zen4/bli_kernels_zen4.h index 476eeaaeed..e518a86047 100644 --- a/kernels/zen4/bli_kernels_zen4.h +++ b/kernels/zen4/bli_kernels_zen4.h @@ -37,3 +37,6 @@ // amaxv (intrinsics) AMAXV_KER_PROT( float, s, amaxv_zen_int_avx512 ) AMAXV_KER_PROT( double, d, amaxv_zen_int_avx512 ) + +GEMMTRSM_UKR_PROT( double, d, gemmtrsm_l_zen_asm_16x14) +GEMMTRSM_UKR_PROT( double, d, gemmtrsm_u_zen_asm_16x14) \ No newline at end of file From b79dc709485605a1b7f9f02c45d78f4d15ab6cb4 Mon Sep 17 00:00:00 2001 From: Eleni Vlachopoulou Date: Wed, 14 Sep 2022 16:06:08 +0530 Subject: [PATCH 232/243] Adding AVX2 support for DNRM2 - For the cases where AVX2 is available, an optimized function is called, based on Blue's algorithm. The fallback method based on sumsqv is used otherwise. - Scaling is used to avoid overflow and underflow. - Works correctly for negative increments. AMD-Internal: [CPUPL-2551] Change-Id: I5d8976b29b5af463a8981061b2be907ea647123c --- bench/Makefile | 3 + bench/bench_nrm2.c | 241 ++++++++++++++++ bench/inputnrm2.txt | 42 +++ config/amdzen/bli_family_amdzen.h | 2 - config/zen/bli_family_zen.h | 4 +- config/zen2/bli_family_zen2.h | 3 +- config/zen3/bli_family_zen3.h | 4 +- config/zen4/bli_family_zen4.h | 2 - frame/util/bli_util_unb_var1.c | 75 +++-- kernels/zen/1/bli_norm2_zen_int.c | 465 +++++++++++++++++++++--------- kernels/zen/bli_kernels_zen.h | 7 +- 11 files changed, 664 insertions(+), 184 deletions(-) create mode 100644 bench/bench_nrm2.c create mode 100644 bench/inputnrm2.txt diff --git a/bench/Makefile b/bench/Makefile index 93cca3298a..0203d5a5b0 100755 --- a/bench/Makefile +++ b/bench/Makefile @@ -186,6 +186,7 @@ blis: \ bench_gemv_blis.x \ bench_syrk_blis.x \ bench_ger_blis.x \ + bench_nrm2_blis.x \ bench_scalv_blis.x \ bench_dotv_blis.x \ bench_trsv_blis.x \ @@ -201,6 +202,7 @@ openblas: \ bench_gemv_openblas.x \ bench_syrk_openblas.x \ bench_ger_openblas.x \ + bench_nrm2_openblas.x \ bench_scalv_openblas.x \ bench_dotv_openblas.x \ bench_trsv_openblas.x \ @@ -231,6 +233,7 @@ mkl: \ bench_gemv_mkl.x \ bench_syrk_mkl.x \ bench_ger_mkl.x \ + bench_nrm2_mkl.x \ bench_scalv_mkl.x \ bench_dotv_mkl.x \ bench_trsv_mkl.x \ diff --git a/bench/bench_nrm2.c b/bench/bench_nrm2.c new file mode 100644 index 0000000000..ae79eb3307 --- /dev/null +++ b/bench/bench_nrm2.c @@ -0,0 +1,241 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name of The University of Texas nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifdef WIN32 +#include +#else +#include +#endif +#include "blis.h" + + +// Benchmark application to process aocl logs generated by BLIS library. +#ifndef DT +#define DT BLIS_DOUBLE +#endif + + +#define AOCL_MATRIX_INITIALISATION + +//#define BLIS_ENABLE_CBLAS + +/* For BLIS since logs are collected at BLAS interfaces + * we disable cblas interfaces for this benchmark application + */ + +/* #ifdef BLIS_ENABLE_CBLAS */ +/* #define CBLAS */ +/* #endif */ + +int main( int argc, char** argv ) +{ + obj_t x; + obj_t normf; + dim_t p_inc = 0; // to keep track of number of inputs + num_t dt; + char dt_ch; + int r, n_repeats; + + double dtime; + double dtime_save; + double gflops; + + FILE* fin = NULL; + FILE* fout = NULL; + + n_repeats = N_REPEAT; // This macro will get from Makefile. + + dt = DT; + + if ( argc < 3 ) + { + printf("Usage: ./test_nrm2_XX.x input.csv output.csv [number_repeats]\n"); + exit(1); + } + fin = fopen( argv[1], "r" ); + if ( argc == 4 ) + { + n_repeats = atoi(argv[3]); + } + if ( fin == NULL ) + { + printf("Error opening the file %s\n", argv[1]); + exit(1); + } + fout = fopen(argv[2], "w"); + if (fout == NULL) + { + printf("Error opening output file %s\n", argv[2]); + exit(1); + } + + fprintf(fout, "Dt\t n\t incx\t gflops\n"); + dim_t n; + inc_t incx; + char tmp[256]; // to store function name, line no present in logs. + + + // {S,D,C,Z} {n incx} + while (fscanf(fin, "%s %c" INT_FS INT_FS "\n", + tmp, &dt_ch, &n, &incx) == 4) + { + +#ifdef PRINT + fprintf (stdout, "Input = %s %c %ld %ld\n", + tmp, dt_ch, n, incx); +#endif + + if (dt_ch == 'D' || dt_ch == 'd') dt = BLIS_DOUBLE; + else if (dt_ch == 'Z' || dt_ch == 'z') dt = BLIS_DCOMPLEX; + else if (dt_ch == 'S' || dt_ch == 's') dt = BLIS_FLOAT; + else if (dt_ch == 'C' || dt_ch == 'c') dt = BLIS_SCOMPLEX; + else + { + printf("Invalid data type %c\n", dt_ch); + continue; + } + + // Create objects with required sizes and strides. + + // The ?nrm2 routines compute the Euclidean norm of a vector X + // norm = ||X|| + // defined as the square root of the sum of squares of the vector elements + // where: + // X is an n-element vector. 
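+        // For example, for X = ( 3, 4 ), norm = sqrt( 3*3 + 4*4 ) = 5.
+        // Note: a naive sum of squares can overflow or underflow for very
+        // large or very small elements; the optimized kernels avoid this by
+        // scaling the accumulation (Blue's algorithm).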
+ + bli_obj_create( dt, n, 1, incx, 1, &x ); + bli_obj_create_1x1( dt, &normf ); +#ifdef AOCL_MATRIX_INITIALISATION + bli_randv( &x ); +#endif + dtime_save = DBL_MAX; + + for ( r = 0; r < n_repeats; ++r ) + { + +#ifdef PRINT + bli_printm( "x", &x, "%4.1f", "" ); +#endif + dtime = bli_clock(); + +#ifdef BLIS + bli_normfv(&x, &normf); +#else // BLIS Interface + + // Set data type independent inputs for BLAS and + // CBLAS API's + + f77_int nn = bli_obj_length( &x ); + f77_int blas_incx = bli_obj_vector_inc( &x ); + + if ( bli_is_float( dt ) ){ + float* xp = bli_obj_buffer( &x ); + float* normfp = bli_obj_buffer( &normf ); +#ifdef CBLAS + *normfp = cblas_snrm2( nn, xp, blas_incx ); +#else // cblas snrm2 + *normfp = snrm2_( &nn, xp, &blas_incx); +#endif // cblas snrm2 + } + else if ( bli_is_double( dt ) ) + { + + double* xp = bli_obj_buffer( &x ); + double* normfp = bli_obj_buffer( &normf ); + +#ifdef CBLAS + *normfp = cblas_dnrm2( nn, xp, blas_incx ); + +#else // cblas dnrm2 + *normfp = dnrm2_( &nn, xp, &blas_incx); +#endif // cblas dnrm2 + } + else if ( bli_is_scomplex( dt ) ) + { + scomplex* xp = bli_obj_buffer( &x ); + float* normfp = bli_obj_buffer( &normf ); + +#ifdef CBLAS + *normfp = cblas_scnrm2( nn, xp, blas_incx ); +#else // cblas cnrm2 + *normfp = scnrm2_( &nn, xp, &blas_incx); +#endif // cblas cnrm2 + } + else if ( bli_is_dcomplex( dt ) ) + { + dcomplex* xp = bli_obj_buffer( &x ); + double* normfp = bli_obj_buffer( &normf ); +#ifdef CBLAS + *normfp = cblas_dznrm2( nn, xp, blas_incx ); +#else // cblas znrm2 + *normfp = dznrm2_( &nn, xp, &blas_incx); +#endif // cblas znrm2 + } + +#endif // BLIS Interface + +#ifdef PRINT + bli_printm( "x after", &x "%4.1f", "" ); + exit(1); +#endif + + dtime_save = bli_clock_min_diff( dtime_save, dtime ); + } + + gflops = (2*n) / ( dtime_save * 1.0e9 ); + + if ( bli_is_complex( dt ) ) gflops *= 2.0; + + printf( "data_nrm2_%s", BLAS ); + + p_inc++; + printf("( %2lu, 1:4 ) = [ %4lu %7.2f ];\n", + (unsigned long)(p_inc), + (unsigned long)n, + gflops); + + fprintf (fout, "%c %ld %ld %6.3f\n", + dt_ch, n, incx, gflops); + + fflush(fout); + + bli_obj_free( &x ); + } + + //bli_finalize(); + fclose(fin); + fclose(fout); + + return 0; +} diff --git a/bench/inputnrm2.txt b/bench/inputnrm2.txt new file mode 100644 index 0000000000..567d6e4691 --- /dev/null +++ b/bench/inputnrm2.txt @@ -0,0 +1,42 @@ +dnrm2:171: D 2 1 +dnrm2:171: D 4 1 +dnrm2:171: D 8 1 +dnrm2:171: D 42 1 +dnrm2:171: D 64 1 +dnrm2:171: D 87 1 +dnrm2:171: D 100 1 +dnrm2:171: D 128 1 +dnrm2:171: D 189 1 +dnrm2:171: D 208 1 +dnrm2:171: D 256 1 +dnrm2:171: D 313 1 +dnrm2:171: D 512 1 +dnrm2:171: D 718 1 +dnrm2:171: D 932 1 +dnrm2:171: D 1024 1 +dnrm2:171: D 1895 1 +dnrm2:171: D 2048 1 +dnrm2:171: D 3275 1 +dnrm2:171: D 4096 1 +dnrm2:171: D 6749 1 +dnrm2:171: D 8192 1 +dnrm2:171: D 10001 1 +dnrm2:171: D 16384 1 +dnrm2:171: D 20976 1 +dnrm2:171: D 32768 1 +dnrm2:171: D 56841 1 +dnrm2:171: D 65536 1 +dnrm2:171: D 8 3 +dnrm2:171: D 64 7 +dnrm2:171: D 87 12 +dnrm2:171: D 189 9 +dnrm2:171: D 313 3 +dnrm2:171: D 718 5 +dnrm2:171: D 1024 2 +dnrm2:171: D 3275 4 +dnrm2:171: D 4096 7 +dnrm2:171: D 8192 5 +dnrm2:171: D 16384 11 +dnrm2:171: D 20976 3 +dnrm2:171: D 56841 19 +dnrm2:171: D 65536 6 \ No newline at end of file diff --git a/config/amdzen/bli_family_amdzen.h b/config/amdzen/bli_family_amdzen.h index 1a8c1234a8..0cf46d5a4e 100644 --- a/config/amdzen/bli_family_amdzen.h +++ b/config/amdzen/bli_family_amdzen.h @@ -59,8 +59,6 @@ // BLIS), defining this macro as 1 yields better performance. 
#define AOCL_BLIS_MULTIINSTANCE 0 -//#define BLIS_ENABLE_FAST_MATH - /* * Override the block sizes in the context to the block sizes used * by AVX2 GEMM+TRSM kernels, this is needed in Zen4 context as default diff --git a/config/zen/bli_family_zen.h b/config/zen/bli_family_zen.h index 23d3d608c7..8b31c32ca0 100644 --- a/config/zen/bli_family_zen.h +++ b/config/zen/bli_family_zen.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -53,6 +53,4 @@ #define BLIS_SMALL_MATRIX_A_THRES_M_SYRK 96 #define BLIS_SMALL_MATRIX_A_THRES_N_SYRK 128 -//#define BLIS_ENABLE_FAST_MATH - #endif diff --git a/config/zen2/bli_family_zen2.h b/config/zen2/bli_family_zen2.h index dbae9752cc..16fe50609e 100644 --- a/config/zen2/bli_family_zen2.h +++ b/config/zen2/bli_family_zen2.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019 - 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -56,6 +56,5 @@ // When running HPL with pure MPI without DGEMM threading (Single-threaded // BLIS), defining this macro as 1 yields better performance. #define AOCL_BLIS_MULTIINSTANCE 0 -//#define BLIS_ENABLE_FAST_MATH #endif diff --git a/config/zen3/bli_family_zen3.h b/config/zen3/bli_family_zen3.h index 69def1422d..ce84104c52 100644 --- a/config/zen3/bli_family_zen3.h +++ b/config/zen3/bli_family_zen3.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020-2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -55,6 +55,4 @@ #define BLIS_SMALL_MATRIX_A_THRES_M_SYRK 96 #define BLIS_SMALL_MATRIX_A_THRES_N_SYRK 128 -//#define BLIS_ENABLE_FAST_MATH - #endif diff --git a/config/zen4/bli_family_zen4.h b/config/zen4/bli_family_zen4.h index fad5f16986..b21d1582f7 100644 --- a/config/zen4/bli_family_zen4.h +++ b/config/zen4/bli_family_zen4.h @@ -53,8 +53,6 @@ #define BLIS_SMALL_MATRIX_A_THRES_M_SYRK 96 #define BLIS_SMALL_MATRIX_A_THRES_N_SYRK 128 -//#define BLIS_ENABLE_FAST_MATH - // -- SIMD config -------------------------------------------------------- #define BLIS_SIMD_ALIGN_SIZE 64 diff --git a/frame/util/bli_util_unb_var1.c b/frame/util/bli_util_unb_var1.c index a2166b7b1f..2d43bb74c3 100644 --- a/frame/util/bli_util_unb_var1.c +++ b/frame/util/bli_util_unb_var1.c @@ -440,34 +440,55 @@ void PASTEMAC(ch,varname) \ } #endif GENTFUNCR( float, float, s, s, normfv_unb_var1, sumsqv_unb_var1 ) -/*call sumsqv_unb_var1 if FAST_MATH is not defined else call dot-norm method*/\ -#ifndef BLIS_ENABLE_FAST_MATH -GENTFUNCR( double, double, d, d, normfv_unb_var1, sumsqv_unb_var1 ) -#else -#undef GENTFUNCR -#define GENTFUNCR( ctype, ctype_r, ch, chr, varname, kername ) \ -\ -void PASTEMAC(ch,varname) \ - ( \ - dim_t n, \ - ctype* x, inc_t incx, \ - ctype_r* norm, \ - cntx_t* cntx, \ - rntm_t* rntm \ - ) \ -{ \ -\ - /* Compute the sum of the squares of the vector. */ \ - PASTEMAC(ch,kername) \ - ( \ - n, \ - x, incx, \ - norm, \ - cntx \ - ); \ + +void bli_dnormfv_unb_var1 + ( + dim_t n, + double* x, + inc_t incx, + double* norm, + cntx_t* cntx, + rntm_t* rntm + ) +{ + + if( bli_cpuid_is_avx_supported() == TRUE ) + { + bli_dnorm2fv_unb_var1_avx( n, x, incx, norm, cntx ); + } + else + { + double* zero = bli_d0; + double* one = bli_d1; + double scale; + double sumsq; + double sqrt_sumsq; + + // Initialize scale and sumsq to begin the summation. + bli_ddcopys( *zero, scale ); + bli_ddcopys( *one, sumsq ); + + // Compute the sum of the squares of the vector. + + bli_dsumsqv_unb_var1 + ( + n, + x, + incx, + &scale, + &sumsq, + cntx, + rntm + ); + + // Compute: norm = scale * sqrt( sumsq ) + bli_dsqrt2s( sumsq, sqrt_sumsq ); + bli_dscals( scale, sqrt_sumsq ); + + // Store the final value to the output variable. + bli_dcopys( sqrt_sumsq, *norm ); + } } -GENTFUNCR( double, double, d, d, normfv_unb_var1, norm2fv_unb_var1 ) -#endif #undef GENTFUNCR #define GENTFUNCR( ctype, ctype_r, ch, chr, varname ) \ diff --git a/kernels/zen/1/bli_norm2_zen_int.c b/kernels/zen/1/bli_norm2_zen_int.c index 0a0f92e36c..a1941667bc 100644 --- a/kernels/zen/1/bli_norm2_zen_int.c +++ b/kernels/zen/1/bli_norm2_zen_int.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -34,203 +34,388 @@ #include "immintrin.h" #include "blis.h" -#ifdef BLIS_ENABLE_FAST_MATH -/* Union data structure to access AVX registers - One 256-bit AVX register holds 8 SP elements. */ +// Union data structure to access AVX registers +// One 256-bit AVX register holds 8 SP elements. 
typedef union { __m256 v; - float f[8] __attribute__((aligned(64))); + float f[8] __attribute__( ( aligned( 64 ) ) ); } v8sf_t; -/* Union data structure to access AVX registers -* One 256-bit AVX register holds 4 DP elements. */ +// Union data structure to access AVX registers +// One 256-bit AVX register holds 4 DP elements. typedef union { __m256d v; - double d[4] __attribute__((aligned(64))); + double d[4] __attribute__( ( aligned( 64 ) ) ); } v4df_t; -// ----------------------------------------------------------------------------- +// Return a mask which indicates either: +// v <= t or v >= T +#define CMP256( v, t, T ) \ + _mm256_or_pd( _mm256_cmp_pd( v, t, _CMP_LE_OS ), _mm256_cmp_pd( v, T, _CMP_GE_OS ) ); -void bli_dnorm2fv_unb_var1 - ( +// Returns true if any of the values in the mask vector is true, +// and false, otherwise. +static inline bool bli_horizontal_or( __m256d a ) { return ! _mm256_testz_pd( a, a ); } + +// Optimized function that computes the Frobenius norm using AVX intrinsics. +void bli_dnorm2fv_unb_var1_avx + ( dim_t n, double* x, inc_t incx, double* norm, cntx_t* cntx - ) + ) { + AOCL_DTL_TRACE_ENTRY( AOCL_DTL_LEVEL_TRACE_3 ); + double sumsq = 0; - double rem_sumsq = 0; /*sum of squares accumulated for n_remainder<8 cases.*/ + dim_t i = 0; dim_t n_remainder = 0; - dim_t i; - /*memory pool declarations for packing vector X. - Initialize mem pool buffer to NULL and size to 0 - "buf" and "size" fields are assigned once memory - is allocated from the pool in bli_membrk_acquire_m(). - This will ensure bli_mem_is_alloc() will be passed on - an allocated memory if created or a NULL .*/ - mem_t mem_bufX = {0}; - rntm_t rntm; double *x_buf = x; - - /*early return if n<=0 or incx =0 */ - if((n <= 0) || (incx == 0)) - return; - - /*packing for non-unit strided Vector X*/ - if(incx != 1) + + // Early return if n<=0 or incx=0 + if ( ( n <= 0) || ( incx == 0 ) ) { - /* In order to get the buffer from pool via rntm access to memory broker - is needed.Following are initializations for rntm */ + return; + } + + // Memory pool declarations for packing vector X. + // Initialize mem pool buffer to NULL and size to 0. + // "buf" and "size" fields are assigned once memory + // is allocated from the pool in bli_membrk_acquire_m(). + // This will ensure bli_mem_is_alloc() will be passed on + // an allocated memory if created or a NULL. + mem_t mem_bufX = {0}; + rntm_t rntm; + // Packing for non-unit strided vector x. + if ( incx != 1 ) + { + // In order to get the buffer from pool via rntm access to memory broker + //is needed. Following are initializations for rntm. bli_rntm_init_from_global( &rntm ); bli_rntm_set_num_threads_only( 1, &rntm ); bli_membrk_rntm_set_membrk( &rntm ); - //calculate the size required for "n" double elements in vector X. - size_t buffer_size = n * sizeof(double); + // Calculate the size required for "n" double elements in vector x. + size_t buffer_size = n * sizeof( double ); #ifdef BLIS_ENABLE_MEM_TRACING printf( "bli_dnorm2fv_unb_var1(): get mem pool block\n" ); - #endif + #endif + + // Acquire a Buffer(n*size(double)) from the memory broker + // and save the associated mem_t entry to mem_bufX. + bli_membrk_acquire_m + ( + &rntm, + buffer_size, + BLIS_BUFFER_FOR_B_PANEL, + &mem_bufX + ); + + // Continue packing X if buffer memory is allocated. + if ( ( bli_mem_is_alloc( &mem_bufX ) ) ) + { + x_buf = bli_mem_buffer( &mem_bufX ); + // Pack vector x with non-unit stride to a temp buffer x_buf with unit stride. 
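+            // For a negative incx, element i is gathered from offset
+            // ( n - 1 - i ) * |incx|, i.e. the vector is copied in reverse
+            // order; this does not change the computed norm.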
+ for ( dim_t x_index = 0; x_index < n; x_index++ ) + { + if (incx > 0) + { + *( x_buf + x_index ) = *( x + ( x_index * incx ) ); + } + else + { + *( x_buf + x_index ) = *( x + ( - ( n - x_index - 1 ) * incx ) ); + } + } + } + } - /*acquire a Buffer(n*size(double)) from the memory broker - and save the associated mem_t entry to mem_bufX.*/ - bli_membrk_acquire_m(&rntm, - buffer_size, - BLIS_BUFFER_FOR_B_PANEL, - &mem_bufX); + double *xt = x_buf; + + // Compute the sum of squares on 3 accumulators to avoid overflow + // and underflow, depending on the vector element value. + // Accumulator for small values; using scaling to avoid underflow. + double sum_sml = 0; + // Accumulator for medium values; no scaling required. + double sum_med = 0; + // Accumulator for big values; using scaling to avoid overflow. + double sum_big = 0; + + // Constants chosen to minimize roundoff, according to Blue's algorithm. + const double thres_sml = pow( ( double )FLT_RADIX, ceil( ( DBL_MIN_EXP - 1 ) * 0.5 ) ); + const double thres_big = pow( ( double )FLT_RADIX, floor( ( DBL_MAX_EXP - 52) * 0.5 ) ); + const double scale_sml = pow( ( double )FLT_RADIX, - floor( ( DBL_MIN_EXP - 53 ) * 0.5 ) ); + const double scale_big = pow( ( double )FLT_RADIX, - ceil( ( DBL_MAX_EXP - 52 ) * 0.5 ) ); + + double scale; + double abs_chi; + bool isbig = false; - /*Continue packing X if buffer memory is allocated*/ - if ((bli_mem_is_alloc( &mem_bufX ))) + if ( n > 4 ) + { + // Constants used for comparisons. + v4df_t temp1, thres_sml_vec, thres_big_vec, zerov, ymm0, ymm1; + temp1.v = _mm256_set1_pd( -0.0 ); + thres_sml_vec.v = _mm256_set1_pd( thres_sml ); + thres_big_vec.v = _mm256_set1_pd( thres_big ); + v4df_t x0v, x1v, mask_vec0, mask_vec1; + zerov.v = _mm256_setzero_pd(); + + // Partial sums used for scaling. + v4df_t sum_med_vec0, sum_big_vec0, sum_sml_vec0, sum_med_vec1, sum_big_vec1, sum_sml_vec1; + sum_med_vec0.v = _mm256_setzero_pd(); + sum_big_vec0.v = _mm256_setzero_pd(); + sum_sml_vec0.v = _mm256_setzero_pd(); + sum_med_vec1.v = _mm256_setzero_pd(); + sum_big_vec1.v = _mm256_setzero_pd(); + sum_sml_vec1.v = _mm256_setzero_pd(); + + + for (; ( i + 8 ) <= n; i = i + 8) { - x_buf = bli_mem_buffer(&mem_bufX); + x0v.v = _mm256_loadu_pd( xt ); + x1v.v = _mm256_loadu_pd( xt + 4 ); - /*pack X vector with non-unit stride to a temp buffer x_buf with unit stride*/ - for(dim_t x_index = 0 ; x_index < n ; x_index++) + // Getting the abs of the vector elements. + x0v.v = _mm256_andnot_pd( temp1.v, x0v.v ); + x1v.v = _mm256_andnot_pd( temp1.v, x1v.v ); + + // Mask vectors which indicate whether + // xi<=thres_sml or xi>=thres_big. + mask_vec0.v = CMP256( x0v.v, thres_sml_vec.v, thres_big_vec.v ); + mask_vec1.v = CMP256( x1v.v, thres_sml_vec.v, thres_big_vec.v ); + + + if ( !bli_horizontal_or( mask_vec0.v ) ) + { + // Scaling is not necessary; only medium values. + sum_med_vec0.v = _mm256_fmadd_pd( x0v.v, x0v.v, sum_med_vec0.v ); + } + else { - *(x_buf + x_index) = *(x + (x_index * incx)) ; + // Mask vector which indicates whether xi < thres_big. + mask_vec0.v = _mm256_cmp_pd( thres_big_vec.v, x0v.v, _CMP_GT_OQ ); + + if ( bli_horizontal_or( mask_vec0.v ) ) + { + isbig = true; + + // Fill sum_med vector without scaling. + ymm0.v = _mm256_blendv_pd( zerov.v, x0v.v, mask_vec0.v ); + sum_med_vec0.v = _mm256_fmadd_pd( ymm0.v, ymm0.v, sum_med_vec0.v ); + + // Fill sum_big vector using scaling. 
+ ymm0.v = _mm256_blendv_pd( _mm256_set1_pd( scale_big ), zerov.v, mask_vec0.v ); + ymm0.v = _mm256_mul_pd( x0v.v, ymm0.v ); + sum_big_vec0.v = _mm256_fmadd_pd( ymm0.v, ymm0.v, sum_big_vec0.v ); + } + else if (!isbig) + { + // Mask vector which indicates whether xi > thres_small. + mask_vec0.v = _mm256_cmp_pd( thres_sml_vec.v, x0v.v, _CMP_LT_OQ ); + + // Fill sum_med vector without scaling. + ymm0.v = _mm256_blendv_pd( zerov.v, x0v.v, mask_vec0.v ); + sum_med_vec0.v = _mm256_fmadd_pd( ymm0.v, ymm0.v, sum_med_vec0.v ); + + // Fill sum_sml vector using scaling. + ymm0.v = _mm256_blendv_pd( _mm256_set1_pd( scale_sml ), zerov.v, mask_vec0.v ); + ymm0.v = _mm256_mul_pd( x0v.v, ymm0.v ); + sum_sml_vec0.v = _mm256_fmadd_pd( ymm0.v, ymm0.v, sum_sml_vec0.v ); + } + } + + if ( !bli_horizontal_or( mask_vec1.v ) ) + { + // Scaling is not necessary; only medium values. + sum_med_vec1.v = _mm256_fmadd_pd( x1v.v, x1v.v, sum_med_vec1.v ); + } + else + { + // Mask vector which indicated whether xi < thres_big. + mask_vec1.v = _mm256_cmp_pd( thres_big_vec.v, x1v.v, _CMP_GT_OQ ); + if ( bli_horizontal_or( mask_vec1.v ) ) + { + isbig = true; + + // Fill sum_med vector without scaling. + ymm1.v = _mm256_blendv_pd( zerov.v, x1v.v, mask_vec1.v ); + sum_med_vec1.v = _mm256_fmadd_pd( ymm1.v, ymm1.v, sum_med_vec1.v ); + + // Fill sum_big vector using scaling. + ymm1.v = _mm256_blendv_pd( _mm256_set1_pd( scale_big ), zerov.v, mask_vec1.v ); + ymm1.v = _mm256_mul_pd( x1v.v, ymm1.v ); + sum_big_vec1.v = _mm256_fmadd_pd( ymm1.v, ymm1.v, sum_big_vec1.v ); + } + else if (!isbig) + { + // Mask vector which indicated whether xi > thres_small. + mask_vec1.v = _mm256_cmp_pd( thres_sml_vec.v, x1v.v, _CMP_LT_OQ ); + ymm1.v = _mm256_blendv_pd( zerov.v, x1v.v, mask_vec1.v ); + sum_med_vec1.v = _mm256_fmadd_pd( ymm1.v, ymm1.v, sum_med_vec1.v ); + + // Fill sum_sml vector using scaling. + ymm1.v = _mm256_blendv_pd( _mm256_set1_pd( scale_sml ), zerov.v, mask_vec1.v ); + ymm1.v = _mm256_mul_pd( x1v.v, ymm1.v ); + sum_sml_vec1.v = _mm256_fmadd_pd( ymm1.v, ymm1.v, sum_sml_vec1.v ); + } } + xt += 8; } - } - v4df_t x0v, x1v, x2v, x3v, x4v, x5v, x6v, x7v; - /* Initialize rho vector accumulators to zero.*/ - v4df_t rho0v; rho0v.v = _mm256_setzero_pd(); - v4df_t rho1v; rho1v.v = _mm256_setzero_pd(); - v4df_t rho2v; rho2v.v = _mm256_setzero_pd(); - v4df_t rho3v; rho3v.v = _mm256_setzero_pd(); - v4df_t rho4v; rho4v.v = _mm256_setzero_pd(); - v4df_t rho5v; rho5v.v = _mm256_setzero_pd(); - v4df_t rho6v; rho6v.v = _mm256_setzero_pd(); - v4df_t rho7v; rho7v.v = _mm256_setzero_pd(); + for ( ; ( i + 4 ) <= n; i = i + 4 ) + { + x0v.v = _mm256_loadu_pd( xt ); - double *x0 = x_buf; + // Getting the abs of the vector elements. + x0v.v = _mm256_andnot_pd( temp1.v, x0v.v ); - for(i = 0 ; i+31 < n ; i = i + 32) - { + // Mask vector which indicates whether + // xi<=thres_sml or xi>=thres_big. 
+ mask_vec0.v = CMP256( x0v.v, thres_sml_vec.v, thres_big_vec.v ); - x0v.v = _mm256_loadu_pd( x0 ); - x1v.v = _mm256_loadu_pd( x0 + 4 ); - x2v.v = _mm256_loadu_pd( x0 + 8 ); - x3v.v = _mm256_loadu_pd( x0 + 12 ); - x4v.v = _mm256_loadu_pd( x0 + 16 ); - x5v.v = _mm256_loadu_pd( x0 + 20 ); - x6v.v = _mm256_loadu_pd( x0 + 24 ); - x7v.v = _mm256_loadu_pd( x0 + 28 ); - - rho0v.v = _mm256_fmadd_pd(x0v.v, x0v.v, rho0v.v); - rho1v.v = _mm256_fmadd_pd(x1v.v, x1v.v, rho1v.v); - rho2v.v = _mm256_fmadd_pd(x2v.v, x2v.v, rho2v.v); - rho3v.v = _mm256_fmadd_pd(x3v.v, x3v.v, rho3v.v); - rho4v.v = _mm256_fmadd_pd(x4v.v, x4v.v, rho4v.v); - rho5v.v = _mm256_fmadd_pd(x5v.v, x5v.v, rho5v.v); - rho6v.v = _mm256_fmadd_pd(x6v.v, x6v.v, rho6v.v); - rho7v.v = _mm256_fmadd_pd(x7v.v, x7v.v, rho7v.v); - - x0 += 32; + if ( !bli_horizontal_or( mask_vec0.v ) ) + { + // Scaling is not necessary; only medium values. + sum_med_vec0.v = _mm256_fmadd_pd(x0v.v, x0v.v, sum_med_vec0.v); + } + else + { + // Mask vector which indicate whether xi < thres_big. + mask_vec0.v = _mm256_cmp_pd( thres_big_vec.v, x0v.v, _CMP_GT_OQ ); + + if ( bli_horizontal_or( mask_vec0.v ) ) + { + isbig = true; + + // Fill sum_med vector without scaling. + ymm0.v = _mm256_blendv_pd( zerov.v, x0v.v, mask_vec0.v ); + sum_med_vec0.v = _mm256_fmadd_pd(ymm0.v, ymm0.v, sum_med_vec0.v); + + // Fill sum_big vector using scaling. + ymm0.v = _mm256_blendv_pd( _mm256_set1_pd(scale_big), zerov.v, mask_vec0.v ); + ymm0.v = _mm256_mul_pd(x0v.v, ymm0.v); + sum_big_vec0.v = _mm256_fmadd_pd(ymm0.v, ymm0.v, sum_big_vec0.v); + } + else if (!isbig) + { + // Mask vector which indicates whether xi > thres_small. + mask_vec0.v = _mm256_cmp_pd( thres_sml_vec.v, x0v.v, _CMP_LT_OQ ); + ymm0.v = _mm256_blendv_pd( zerov.v, x0v.v, mask_vec0.v ); + sum_med_vec0.v = _mm256_fmadd_pd(ymm0.v, ymm0.v, sum_med_vec0.v); + + // Fill sum_sml vector using scaling. + ymm0.v = _mm256_blendv_pd( _mm256_set1_pd(scale_sml), zerov.v, mask_vec0.v ); + ymm0.v = _mm256_mul_pd(x0v.v, ymm0.v); + sum_sml_vec0.v = _mm256_fmadd_pd(ymm0.v, ymm0.v, sum_sml_vec0.v); + } + } + xt += 4; + } + + sum_sml_vec0.v = _mm256_add_pd( sum_sml_vec0.v, sum_sml_vec1.v ); + sum_med_vec0.v = _mm256_add_pd( sum_med_vec0.v, sum_med_vec1.v ); + sum_big_vec0.v = _mm256_add_pd( sum_big_vec0.v, sum_big_vec1.v ); + + sum_sml += sum_sml_vec0.v[0] + sum_sml_vec0.v[1] + + sum_sml_vec0.v[2] + sum_sml_vec0.v[3]; + sum_med += sum_med_vec0.v[0] + sum_med_vec0.v[1] + + sum_med_vec0.v[2] + sum_med_vec0.v[3]; + sum_big += sum_big_vec0.v[0] + sum_big_vec0.v[1] + + sum_big_vec0.v[2] + sum_big_vec0.v[3]; } n_remainder = n - i; - if(n_remainder) + if ( ( n_remainder > 0 ) ) { - if(n_remainder >= 16) + // Put first the most likely to happen to avoid evaluations on if statements. + for (i = 0; i < n_remainder; i++) { - x0v.v = _mm256_loadu_pd( x0 ); - x1v.v = _mm256_loadu_pd( x0 + 4 ); - x2v.v = _mm256_loadu_pd( x0 + 8 ); - x3v.v = _mm256_loadu_pd( x0 + 12 ); - - rho0v.v = _mm256_fmadd_pd(x0v.v, x0v.v, rho0v.v); - rho1v.v = _mm256_fmadd_pd(x1v.v, x1v.v, rho1v.v); - rho2v.v = _mm256_fmadd_pd(x2v.v, x2v.v, rho2v.v); - rho3v.v = _mm256_fmadd_pd(x3v.v, x3v.v, rho3v.v); - - x0 += 16; - n_remainder -= 16; + abs_chi = bli_fabs( *xt ); + // Most likely case: medium values, not over/under-flow. + if ( ( abs_chi <= thres_big ) && ( abs_chi >= thres_sml ) ) + { + sum_med += abs_chi * abs_chi; + } + // Case where there could be an overflow. Scaling is required. 
+ else if ( abs_chi > thres_big ) + { + sum_big += ( abs_chi * scale_big ) * ( abs_chi * scale_big ); + isbig = true; + } + // Case where there could be an underflow. Scaling is required. + else if ( ( ( !isbig ) && abs_chi < thres_sml ) ) + { + sum_sml += ( abs_chi * scale_sml ) * ( abs_chi * scale_sml ); + } + xt++; } - if(n_remainder >= 8) - { - x0v.v = _mm256_loadu_pd( x0 ); - x1v.v = _mm256_loadu_pd( x0 + 4 ); - - rho0v.v = _mm256_fmadd_pd(x0v.v, x0v.v, rho0v.v); - rho1v.v = _mm256_fmadd_pd(x1v.v, x1v.v, rho1v.v); + } - x0 += 8; - n_remainder -= 8; - } - if(n_remainder >= 4) + // Combine accumulators. + if ( isbig ) + { + // Combine sum_big and sum_med if sum_med > 0. + if ( sum_med > 0.0 ) { - x0v.v = _mm256_loadu_pd( x0 ); - - rho0v.v = _mm256_fmadd_pd(x0v.v, x0v.v, rho0v.v); - - x0 += 4; - n_remainder -= 4; + sum_big += ( sum_med * scale_big ) * scale_big; } - if(n_remainder) + scale = 1.0 / scale_big; + sumsq = sum_big; + } + + else if ( sum_sml > 0.0 ) + { + // Combine sum_med and sum_sml if sum_sml>0. + if ( sum_med > 0.0 ) { - for(i=0; i< n_remainder ;i++) + sum_med = sqrt( sum_med ); + sum_sml = sqrt( sum_sml ) / scale_sml; + double ymin, ymax; + if ( sum_sml > sum_med ) + { + ymin = sum_med; + ymax = sum_sml; + } + else { - double x_temp = *x0; - rem_sumsq += x_temp * x_temp ; - x0 += 1; + ymin = sum_sml; + ymax = sum_med; } + scale = 1.0; + sumsq = ymax * ymax * ( 1.0 + ( ymin / ymax ) * ( ymin / ymax ) ); + } + else + { + scale = 1.0 / scale_sml; + sumsq = sum_sml; } } + else + { + // If all values are mid-range: + scale = 1.0; + sumsq = sum_med; + } - /*add all the dot product of x*x into one vector .*/ - rho0v.v = _mm256_add_pd ( rho0v.v, rho1v.v ); - rho1v.v = _mm256_add_pd ( rho2v.v, rho3v.v ); - rho2v.v = _mm256_add_pd ( rho4v.v, rho5v.v ); - rho3v.v = _mm256_add_pd ( rho6v.v, rho7v.v ); - - rho4v.v = _mm256_add_pd ( rho0v.v, rho1v.v ); - rho5v.v = _mm256_add_pd ( rho2v.v, rho3v.v ); - - rho6v.v = _mm256_add_pd ( rho4v.v, rho5v.v ); - - rho7v.v = _mm256_hadd_pd( rho6v.v, rho6v.v ); - - /*rem_sumsq will have sum of squares of n_remainder < 4 cases . - Accumulate all the sum of squares to sumsq*/ - sumsq = rem_sumsq + rho7v.d[0] + rho7v.d[2]; - - PASTEMAC(d,sqrt2s)( sumsq, *norm ); + *norm = scale * sqrt( sumsq ); - if ((incx != 1) && bli_mem_is_alloc( &mem_bufX )) + if ( ( incx != 1 ) && bli_mem_is_alloc( &mem_bufX ) ) { #ifdef BLIS_ENABLE_MEM_TRACING printf( "bli_dnorm2fv_unb_var1(): releasing mem pool block\n" ); #endif - /* Return the buffer to pool*/ - bli_membrk_release(&rntm , &mem_bufX); + // Return the buffer to pool. + bli_membrk_release( &rntm , &mem_bufX ); } - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); - return ; -} -#endif + + AOCL_DTL_TRACE_EXIT( AOCL_DTL_LEVEL_TRACE_3 ); + + return; +} \ No newline at end of file diff --git a/kernels/zen/bli_kernels_zen.h b/kernels/zen/bli_kernels_zen.h index d600c4ac45..210052d069 100644 --- a/kernels/zen/bli_kernels_zen.h +++ b/kernels/zen/bli_kernels_zen.h @@ -379,13 +379,10 @@ bool bli_cntx_trsm_small_thresh_is_met_zen dim_t n ); -#ifdef BLIS_ENABLE_FAST_MATH -void bli_dnorm2fv_unb_var1 +void bli_dnorm2fv_unb_var1_avx ( dim_t n, double* x, inc_t incx, double* norm, cntx_t* cntx - ); -#endif - + ); \ No newline at end of file From 85120634f771fb4481aa3fce65cf232573d958d3 Mon Sep 17 00:00:00 2001 From: Eleni Vlachopoulou Date: Tue, 27 Sep 2022 00:50:09 +0530 Subject: [PATCH 233/243] Bugfix in DNRM2 AVX path Description: Enabled DNRM2 AVX path and fixed bug that caused numerical accuracy errors. 
AMD-Internal: [CPUPL-2576] Change-Id: Ic9fda9d9668bdfe233621f79db6acce518b4d10e --- kernels/zen/1/bli_norm2_zen_int.c | 188 ++++++++++++++++++++---------- 1 file changed, 128 insertions(+), 60 deletions(-) diff --git a/kernels/zen/1/bli_norm2_zen_int.c b/kernels/zen/1/bli_norm2_zen_int.c index a1941667bc..7ca57206e7 100644 --- a/kernels/zen/1/bli_norm2_zen_int.c +++ b/kernels/zen/1/bli_norm2_zen_int.c @@ -159,8 +159,8 @@ void bli_dnorm2fv_unb_var1_avx if ( n > 4 ) { // Constants used for comparisons. - v4df_t temp1, thres_sml_vec, thres_big_vec, zerov, ymm0, ymm1; - temp1.v = _mm256_set1_pd( -0.0 ); + v4df_t temp, thres_sml_vec, thres_big_vec, zerov, ymm0, ymm1; + temp.v = _mm256_set1_pd( -0.0 ); thres_sml_vec.v = _mm256_set1_pd( thres_sml ); thres_big_vec.v = _mm256_set1_pd( thres_big ); v4df_t x0v, x1v, mask_vec0, mask_vec1; @@ -182,8 +182,22 @@ void bli_dnorm2fv_unb_var1_avx x1v.v = _mm256_loadu_pd( xt + 4 ); // Getting the abs of the vector elements. - x0v.v = _mm256_andnot_pd( temp1.v, x0v.v ); - x1v.v = _mm256_andnot_pd( temp1.v, x1v.v ); + x0v.v = _mm256_andnot_pd( temp.v, x0v.v ); + x1v.v = _mm256_andnot_pd( temp.v, x1v.v ); + + // Check if any of the values is a NaN and if so, return. + mask_vec0.v = _mm256_cmp_pd(x0v.v, x0v.v, _CMP_UNORD_Q); + mask_vec1.v = _mm256_cmp_pd(x1v.v, x1v.v, _CMP_UNORD_Q); + if ( bli_horizontal_or( mask_vec0.v ) ) + { + *norm = NAN; + return; + } + if ( bli_horizontal_or( mask_vec1.v ) ) + { + *norm = NAN; + return; + } // Mask vectors which indicate whether // xi<=thres_sml or xi>=thres_big. @@ -198,35 +212,42 @@ void bli_dnorm2fv_unb_var1_avx } else { - // Mask vector which indicates whether xi < thres_big. - mask_vec0.v = _mm256_cmp_pd( thres_big_vec.v, x0v.v, _CMP_GT_OQ ); + // Mask vector which indicate whether xi > thres_big. + mask_vec0.v = _mm256_cmp_pd( x0v.v, thres_big_vec.v, _CMP_GT_OQ ); if ( bli_horizontal_or( mask_vec0.v ) ) { isbig = true; // Fill sum_med vector without scaling. - ymm0.v = _mm256_blendv_pd( zerov.v, x0v.v, mask_vec0.v ); + ymm0.v = _mm256_blendv_pd( x0v.v, zerov.v, mask_vec0.v ); sum_med_vec0.v = _mm256_fmadd_pd( ymm0.v, ymm0.v, sum_med_vec0.v ); - // Fill sum_big vector using scaling. - ymm0.v = _mm256_blendv_pd( _mm256_set1_pd( scale_big ), zerov.v, mask_vec0.v ); + // Fill sum_big vector using scaling. + temp.v = _mm256_set1_pd( scale_big ); + ymm0.v = _mm256_blendv_pd( zerov.v, temp.v, mask_vec0.v ); ymm0.v = _mm256_mul_pd( x0v.v, ymm0.v ); - sum_big_vec0.v = _mm256_fmadd_pd( ymm0.v, ymm0.v, sum_big_vec0.v ); + sum_big_vec0.v = _mm256_fmadd_pd( ymm0.v, ymm0.v, sum_big_vec0.v ); + temp.v = _mm256_set1_pd( -0.0 ); } - else if (!isbig) + else { // Mask vector which indicates whether xi > thres_small. - mask_vec0.v = _mm256_cmp_pd( thres_sml_vec.v, x0v.v, _CMP_LT_OQ ); - + mask_vec0.v = _mm256_cmp_pd( x0v.v, thres_sml_vec.v, _CMP_LT_OQ ); // Fill sum_med vector without scaling. - ymm0.v = _mm256_blendv_pd( zerov.v, x0v.v, mask_vec0.v ); + ymm0.v = _mm256_blendv_pd( x0v.v, zerov.v, mask_vec0.v ); sum_med_vec0.v = _mm256_fmadd_pd( ymm0.v, ymm0.v, sum_med_vec0.v ); - - // Fill sum_sml vector using scaling. - ymm0.v = _mm256_blendv_pd( _mm256_set1_pd( scale_sml ), zerov.v, mask_vec0.v ); - ymm0.v = _mm256_mul_pd( x0v.v, ymm0.v ); - sum_sml_vec0.v = _mm256_fmadd_pd( ymm0.v, ymm0.v, sum_sml_vec0.v ); + + // Accumulate small values only if there have not been any big values so far. + if ( !isbig ) + { + // Fill sum_sml vector using scaling. 
+ temp.v = _mm256_set1_pd( scale_sml ); + ymm0.v = _mm256_blendv_pd( zerov.v, temp.v, mask_vec0.v ); + ymm0.v = _mm256_mul_pd( x0v.v, ymm0.v ); + sum_sml_vec0.v = _mm256_fmadd_pd( ymm0.v, ymm0.v, sum_sml_vec0.v ); + temp.v = _mm256_set1_pd( -0.0 ); + } } } @@ -237,34 +258,45 @@ void bli_dnorm2fv_unb_var1_avx } else { - // Mask vector which indicated whether xi < thres_big. - mask_vec1.v = _mm256_cmp_pd( thres_big_vec.v, x1v.v, _CMP_GT_OQ ); + // Mask vector which indicate whether xi > thres_big. + mask_vec1.v = _mm256_cmp_pd( x1v.v, thres_big_vec.v, _CMP_GT_OQ ); + if ( bli_horizontal_or( mask_vec1.v ) ) { isbig = true; // Fill sum_med vector without scaling. - ymm1.v = _mm256_blendv_pd( zerov.v, x1v.v, mask_vec1.v ); + ymm1.v = _mm256_blendv_pd( x1v.v, zerov.v, mask_vec1.v ); sum_med_vec1.v = _mm256_fmadd_pd( ymm1.v, ymm1.v, sum_med_vec1.v ); - - // Fill sum_big vector using scaling. - ymm1.v = _mm256_blendv_pd( _mm256_set1_pd( scale_big ), zerov.v, mask_vec1.v ); + + // Fill sum_big vector using scaling. + temp.v = _mm256_set1_pd( scale_big ); + ymm1.v = _mm256_blendv_pd( zerov.v, temp.v, mask_vec1.v ); ymm1.v = _mm256_mul_pd( x1v.v, ymm1.v ); - sum_big_vec1.v = _mm256_fmadd_pd( ymm1.v, ymm1.v, sum_big_vec1.v ); + sum_big_vec1.v = _mm256_fmadd_pd( ymm1.v, ymm1.v, sum_big_vec1.v ); + temp.v = _mm256_set1_pd( -0.0 ); } - else if (!isbig) + else { - // Mask vector which indicated whether xi > thres_small. - mask_vec1.v = _mm256_cmp_pd( thres_sml_vec.v, x1v.v, _CMP_LT_OQ ); - ymm1.v = _mm256_blendv_pd( zerov.v, x1v.v, mask_vec1.v ); + // Mask vector which indicates whether xi > thres_small. + mask_vec1.v = _mm256_cmp_pd( x1v.v, thres_sml_vec.v, _CMP_LT_OQ ); + // Fill sum_med vector without scaling. + ymm1.v = _mm256_blendv_pd( x1v.v, zerov.v, mask_vec1.v ); sum_med_vec1.v = _mm256_fmadd_pd( ymm1.v, ymm1.v, sum_med_vec1.v ); - - // Fill sum_sml vector using scaling. - ymm1.v = _mm256_blendv_pd( _mm256_set1_pd( scale_sml ), zerov.v, mask_vec1.v ); - ymm1.v = _mm256_mul_pd( x1v.v, ymm1.v ); - sum_sml_vec1.v = _mm256_fmadd_pd( ymm1.v, ymm1.v, sum_sml_vec1.v ); + + // Accumulate small values only if there have not been any big values so far. + if ( !isbig ) + { + // Fill sum_sml vector using scaling. + temp.v = _mm256_set1_pd( scale_sml ); + ymm1.v = _mm256_blendv_pd( zerov.v, temp.v, mask_vec1.v ); + ymm1.v = _mm256_mul_pd( x1v.v, ymm1.v ); + sum_sml_vec1.v = _mm256_fmadd_pd( ymm1.v, ymm1.v, sum_sml_vec1.v ); + temp.v = _mm256_set1_pd( -0.0 ); + } } - } + } + xt += 8; } @@ -273,46 +305,63 @@ void bli_dnorm2fv_unb_var1_avx x0v.v = _mm256_loadu_pd( xt ); // Getting the abs of the vector elements. - x0v.v = _mm256_andnot_pd( temp1.v, x0v.v ); + x0v.v = _mm256_andnot_pd( temp.v, x0v.v ); + + // Check if any of the values is a NaN and if so, return. + mask_vec0.v = _mm256_cmp_pd(x0v.v, x0v.v, _CMP_UNORD_Q); + if ( bli_horizontal_or( mask_vec0.v ) ) + { + *norm = NAN; + return; + } - // Mask vector which indicates whether + // Mask vectors which indicate whether // xi<=thres_sml or xi>=thres_big. mask_vec0.v = CMP256( x0v.v, thres_sml_vec.v, thres_big_vec.v ); if ( !bli_horizontal_or( mask_vec0.v ) ) { // Scaling is not necessary; only medium values. - sum_med_vec0.v = _mm256_fmadd_pd(x0v.v, x0v.v, sum_med_vec0.v); - } + sum_med_vec0.v = _mm256_fmadd_pd( x0v.v, x0v.v, sum_med_vec0.v ); + } else { - // Mask vector which indicate whether xi < thres_big. - mask_vec0.v = _mm256_cmp_pd( thres_big_vec.v, x0v.v, _CMP_GT_OQ ); - + // Mask vector which indicate whether xi > thres_big. 
+ mask_vec0.v = _mm256_cmp_pd( x0v.v, thres_big_vec.v, _CMP_GT_OQ ); + if ( bli_horizontal_or( mask_vec0.v ) ) { isbig = true; // Fill sum_med vector without scaling. - ymm0.v = _mm256_blendv_pd( zerov.v, x0v.v, mask_vec0.v ); - sum_med_vec0.v = _mm256_fmadd_pd(ymm0.v, ymm0.v, sum_med_vec0.v); - - // Fill sum_big vector using scaling. - ymm0.v = _mm256_blendv_pd( _mm256_set1_pd(scale_big), zerov.v, mask_vec0.v ); - ymm0.v = _mm256_mul_pd(x0v.v, ymm0.v); - sum_big_vec0.v = _mm256_fmadd_pd(ymm0.v, ymm0.v, sum_big_vec0.v); + ymm0.v = _mm256_blendv_pd( x0v.v, zerov.v, mask_vec0.v ); + sum_med_vec0.v = _mm256_fmadd_pd( ymm0.v, ymm0.v, sum_med_vec0.v ); + + // Fill sum_big vector using scaling. + temp.v = _mm256_set1_pd( scale_big ); + ymm0.v = _mm256_blendv_pd( zerov.v, temp.v, mask_vec0.v ); + ymm0.v = _mm256_mul_pd( x0v.v, ymm0.v ); + sum_big_vec0.v = _mm256_fmadd_pd( ymm0.v, ymm0.v, sum_big_vec0.v ); + temp.v = _mm256_set1_pd( -0.0 ); } - else if (!isbig) + else { // Mask vector which indicates whether xi > thres_small. - mask_vec0.v = _mm256_cmp_pd( thres_sml_vec.v, x0v.v, _CMP_LT_OQ ); - ymm0.v = _mm256_blendv_pd( zerov.v, x0v.v, mask_vec0.v ); - sum_med_vec0.v = _mm256_fmadd_pd(ymm0.v, ymm0.v, sum_med_vec0.v); - - // Fill sum_sml vector using scaling. - ymm0.v = _mm256_blendv_pd( _mm256_set1_pd(scale_sml), zerov.v, mask_vec0.v ); - ymm0.v = _mm256_mul_pd(x0v.v, ymm0.v); - sum_sml_vec0.v = _mm256_fmadd_pd(ymm0.v, ymm0.v, sum_sml_vec0.v); + mask_vec0.v = _mm256_cmp_pd( x0v.v, thres_sml_vec.v, _CMP_LT_OQ ); + // Fill sum_med vector without scaling. + ymm0.v = _mm256_blendv_pd( x0v.v, zerov.v, mask_vec0.v ); + sum_med_vec0.v = _mm256_fmadd_pd( ymm0.v, ymm0.v, sum_med_vec0.v ); + + // Accumulate small values only if there have not been any big values so far. + if ( !isbig ) + { + // Fill sum_sml vector using scaling. + temp.v = _mm256_set1_pd( scale_sml ); + ymm0.v = _mm256_blendv_pd( zerov.v, temp.v, mask_vec0.v ); + ymm0.v = _mm256_mul_pd( x0v.v, ymm0.v ); + sum_sml_vec0.v = _mm256_fmadd_pd( ymm0.v, ymm0.v, sum_sml_vec0.v ); + temp.v = _mm256_set1_pd( -0.0 ); + } } } xt += 4; @@ -331,13 +380,29 @@ void bli_dnorm2fv_unb_var1_avx } n_remainder = n - i; - + bool hasInf = false; if ( ( n_remainder > 0 ) ) { // Put first the most likely to happen to avoid evaluations on if statements. for (i = 0; i < n_remainder; i++) { abs_chi = bli_fabs( *xt ); + // If any of the elements is NaN, then return NaN as a result. + if ( bli_isnan( abs_chi ) ) + { + *norm = abs_chi; + return; + } + // Else, if any of the elements is an Inf, then return +Inf as a result. + if ( bli_isinf( abs_chi ) ) + { + *norm = abs_chi; + // Instead of returning immediately, use this flag + // to denote that there is an Inf element in the vector. + // That is used to avoid cases where there is a NaN which comes + // after an Inf. + hasInf = true; + } // Most likely case: medium values, not over/under-flow. if ( ( abs_chi <= thres_big ) && ( abs_chi >= thres_sml ) ) { @@ -358,6 +423,9 @@ void bli_dnorm2fv_unb_var1_avx } } + // Early return if there is an Inf. + if (hasInf) return; + // Combine accumulators. if ( isbig ) { From 92fd790545e0aa128e444e47e4dbdd5045ef65f9 Mon Sep 17 00:00:00 2001 From: Eleni Vlachopoulou Date: Thu, 29 Sep 2022 22:02:27 +0530 Subject: [PATCH 234/243] Adding AVX2 support for DZNRM2 - For the cases where AVX2 is available, an optimized function is called, based on Blue's algorithm. The fallback method based on sumsqv is used otherwise. - Scaling is used to avoid overflow and underflow. 
- Works correctly for negative increments. - Cleaned up some white space in the AVX2 implementation for DNRM2. AMD-Internal: [CPUPL-2551] Change-Id: I0875234ea735540307168fe7efc3f10fe6c40ffc --- frame/util/bli_util_unb_var1.c | 52 ++- kernels/zen/1/bli_norm2_zen_int.c | 526 ++++++++++++++++++++++++++++-- kernels/zen/bli_kernels_zen.h | 18 +- 3 files changed, 563 insertions(+), 33 deletions(-) diff --git a/frame/util/bli_util_unb_var1.c b/frame/util/bli_util_unb_var1.c index 2d43bb74c3..48abe57acc 100644 --- a/frame/util/bli_util_unb_var1.c +++ b/frame/util/bli_util_unb_var1.c @@ -308,7 +308,55 @@ void PASTEMAC(ch,varname) \ //INSERT_GENTFUNCR_BASIC( normfv_unb_var1, sumsqv_unb_var1 ) GENTFUNCR( scomplex, float, c, s, normfv_unb_var1, sumsqv_unb_var1 ) -GENTFUNCR( dcomplex, double, z, d, normfv_unb_var1, sumsqv_unb_var1 ) + +void bli_znormfv_unb_var1 + ( + dim_t n, + dcomplex* x, + inc_t incx, + double* norm, + cntx_t* cntx, + rntm_t* rntm + ) +{ + + if ( bli_cpuid_is_avx_supported() == TRUE ) + { + bli_dznorm2fv_unb_var1_avx2( n, x, incx, norm, cntx ); + } + else + { + double* zero = bli_d0; + double* one = bli_d1; + double scale; + double sumsq; + double sqrt_sumsq; + + // Initialize scale and sumsq to begin the summation. + bli_dcopys( *zero, scale ); + bli_dcopys( *one, sumsq ); + + // Compute the sum of the squares of the vector. + + bli_zsumsqv_unb_var1 + ( + n, + x, + incx, + &scale, + &sumsq, + cntx, + rntm + ); + + // Compute: norm = scale * sqrt( sumsq ) + bli_dsqrt2s( sumsq, sqrt_sumsq ); + bli_dscals( scale, sqrt_sumsq ); + + // Store the final value to the output variable. + bli_dcopys( sqrt_sumsq, *norm ); + } +} #undef GENTFUNCR // We've disabled the dotv-based implementation because that method of @@ -454,7 +502,7 @@ void bli_dnormfv_unb_var1 if( bli_cpuid_is_avx_supported() == TRUE ) { - bli_dnorm2fv_unb_var1_avx( n, x, incx, norm, cntx ); + bli_dnorm2fv_unb_var1_avx2( n, x, incx, norm, cntx ); } else { diff --git a/kernels/zen/1/bli_norm2_zen_int.c b/kernels/zen/1/bli_norm2_zen_int.c index 7ca57206e7..278f752b63 100644 --- a/kernels/zen/1/bli_norm2_zen_int.c +++ b/kernels/zen/1/bli_norm2_zen_int.c @@ -59,8 +59,8 @@ typedef union // and false, otherwise. static inline bool bli_horizontal_or( __m256d a ) { return ! _mm256_testz_pd( a, a ); } -// Optimized function that computes the Frobenius norm using AVX intrinsics. -void bli_dnorm2fv_unb_var1_avx +// Optimized function that computes the Frobenius norm using AVX2 intrinsics. +void bli_dnorm2fv_unb_var1_avx2 ( dim_t n, double* x, inc_t incx, @@ -75,7 +75,7 @@ void bli_dnorm2fv_unb_var1_avx dim_t n_remainder = 0; double *x_buf = x; - // Early return if n<=0 or incx=0 + // Early return if n<=0 or incx=0 if ( ( n <= 0) || ( incx == 0 ) ) { return; @@ -104,7 +104,7 @@ void bli_dnorm2fv_unb_var1_avx #ifdef BLIS_ENABLE_MEM_TRACING printf( "bli_dnorm2fv_unb_var1(): get mem pool block\n" ); - #endif + #endif // Acquire a Buffer(n*size(double)) from the memory broker // and save the associated mem_t entry to mem_bufX. @@ -123,11 +123,11 @@ void bli_dnorm2fv_unb_var1_avx // Pack vector x with non-unit stride to a temp buffer x_buf with unit stride. 
for ( dim_t x_index = 0; x_index < n; x_index++ ) { - if (incx > 0) + if ( incx > 0 ) { *( x_buf + x_index ) = *( x + ( x_index * incx ) ); - } - else + } + else { *( x_buf + x_index ) = *( x + ( - ( n - x_index - 1 ) * incx ) ); } @@ -151,7 +151,7 @@ void bli_dnorm2fv_unb_var1_avx const double thres_big = pow( ( double )FLT_RADIX, floor( ( DBL_MAX_EXP - 52) * 0.5 ) ); const double scale_sml = pow( ( double )FLT_RADIX, - floor( ( DBL_MIN_EXP - 53 ) * 0.5 ) ); const double scale_big = pow( ( double )FLT_RADIX, - ceil( ( DBL_MAX_EXP - 52 ) * 0.5 ) ); - + double scale; double abs_chi; bool isbig = false; @@ -163,7 +163,7 @@ void bli_dnorm2fv_unb_var1_avx temp.v = _mm256_set1_pd( -0.0 ); thres_sml_vec.v = _mm256_set1_pd( thres_sml ); thres_big_vec.v = _mm256_set1_pd( thres_big ); - v4df_t x0v, x1v, mask_vec0, mask_vec1; + v4df_t x0v, x1v, mask_vec0, mask_vec1; zerov.v = _mm256_setzero_pd(); // Partial sums used for scaling. @@ -175,7 +175,6 @@ void bli_dnorm2fv_unb_var1_avx sum_big_vec1.v = _mm256_setzero_pd(); sum_sml_vec1.v = _mm256_setzero_pd(); - for (; ( i + 8 ) <= n; i = i + 8) { x0v.v = _mm256_loadu_pd( xt ); @@ -199,22 +198,21 @@ void bli_dnorm2fv_unb_var1_avx return; } - // Mask vectors which indicate whether + // Mask vectors which indicate whether // xi<=thres_sml or xi>=thres_big. mask_vec0.v = CMP256( x0v.v, thres_sml_vec.v, thres_big_vec.v ); mask_vec1.v = CMP256( x1v.v, thres_sml_vec.v, thres_big_vec.v ); - if ( !bli_horizontal_or( mask_vec0.v ) ) { // Scaling is not necessary; only medium values. sum_med_vec0.v = _mm256_fmadd_pd( x0v.v, x0v.v, sum_med_vec0.v ); - } + } else { // Mask vector which indicate whether xi > thres_big. mask_vec0.v = _mm256_cmp_pd( x0v.v, thres_big_vec.v, _CMP_GT_OQ ); - + if ( bli_horizontal_or( mask_vec0.v ) ) { isbig = true; @@ -222,8 +220,8 @@ void bli_dnorm2fv_unb_var1_avx // Fill sum_med vector without scaling. ymm0.v = _mm256_blendv_pd( x0v.v, zerov.v, mask_vec0.v ); sum_med_vec0.v = _mm256_fmadd_pd( ymm0.v, ymm0.v, sum_med_vec0.v ); - - // Fill sum_big vector using scaling. + + // Fill sum_big vector using scaling. temp.v = _mm256_set1_pd( scale_big ); ymm0.v = _mm256_blendv_pd( zerov.v, temp.v, mask_vec0.v ); ymm0.v = _mm256_mul_pd( x0v.v, ymm0.v ); @@ -249,18 +247,18 @@ void bli_dnorm2fv_unb_var1_avx temp.v = _mm256_set1_pd( -0.0 ); } } - } + } if ( !bli_horizontal_or( mask_vec1.v ) ) { // Scaling is not necessary; only medium values. sum_med_vec1.v = _mm256_fmadd_pd( x1v.v, x1v.v, sum_med_vec1.v ); - } + } else { // Mask vector which indicate whether xi > thres_big. mask_vec1.v = _mm256_cmp_pd( x1v.v, thres_big_vec.v, _CMP_GT_OQ ); - + if ( bli_horizontal_or( mask_vec1.v ) ) { isbig = true; @@ -268,8 +266,8 @@ void bli_dnorm2fv_unb_var1_avx // Fill sum_med vector without scaling. ymm1.v = _mm256_blendv_pd( x1v.v, zerov.v, mask_vec1.v ); sum_med_vec1.v = _mm256_fmadd_pd( ymm1.v, ymm1.v, sum_med_vec1.v ); - - // Fill sum_big vector using scaling. + + // Fill sum_big vector using scaling. temp.v = _mm256_set1_pd( scale_big ); ymm1.v = _mm256_blendv_pd( zerov.v, temp.v, mask_vec1.v ); ymm1.v = _mm256_mul_pd( x1v.v, ymm1.v ); @@ -295,7 +293,7 @@ void bli_dnorm2fv_unb_var1_avx temp.v = _mm256_set1_pd( -0.0 ); } } - } + } xt += 8; } @@ -315,7 +313,7 @@ void bli_dnorm2fv_unb_var1_avx return; } - // Mask vectors which indicate whether + // Mask vectors which indicate whether // xi<=thres_sml or xi>=thres_big. 
mask_vec0.v = CMP256( x0v.v, thres_sml_vec.v, thres_big_vec.v ); @@ -323,12 +321,12 @@ void bli_dnorm2fv_unb_var1_avx { // Scaling is not necessary; only medium values. sum_med_vec0.v = _mm256_fmadd_pd( x0v.v, x0v.v, sum_med_vec0.v ); - } + } else { // Mask vector which indicate whether xi > thres_big. mask_vec0.v = _mm256_cmp_pd( x0v.v, thres_big_vec.v, _CMP_GT_OQ ); - + if ( bli_horizontal_or( mask_vec0.v ) ) { isbig = true; @@ -336,10 +334,10 @@ void bli_dnorm2fv_unb_var1_avx // Fill sum_med vector without scaling. ymm0.v = _mm256_blendv_pd( x0v.v, zerov.v, mask_vec0.v ); sum_med_vec0.v = _mm256_fmadd_pd( ymm0.v, ymm0.v, sum_med_vec0.v ); - - // Fill sum_big vector using scaling. + + // Fill sum_big vector using scaling. temp.v = _mm256_set1_pd( scale_big ); - ymm0.v = _mm256_blendv_pd( zerov.v, temp.v, mask_vec0.v ); + ymm0.v = _mm256_blendv_pd( zerov.v, temp.v, mask_vec0.v ); ymm0.v = _mm256_mul_pd( x0v.v, ymm0.v ); sum_big_vec0.v = _mm256_fmadd_pd( ymm0.v, ymm0.v, sum_big_vec0.v ); temp.v = _mm256_set1_pd( -0.0 ); @@ -415,7 +413,7 @@ void bli_dnorm2fv_unb_var1_avx isbig = true; } // Case where there could be an underflow. Scaling is required. - else if ( ( ( !isbig ) && abs_chi < thres_sml ) ) + else if ( ( !isbig ) && ( abs_chi < thres_sml ) ) { sum_sml += ( abs_chi * scale_sml ) * ( abs_chi * scale_sml ); } @@ -424,7 +422,7 @@ void bli_dnorm2fv_unb_var1_avx } // Early return if there is an Inf. - if (hasInf) return; + if ( hasInf ) return; // Combine accumulators. if ( isbig ) @@ -485,5 +483,473 @@ void bli_dnorm2fv_unb_var1_avx AOCL_DTL_TRACE_EXIT( AOCL_DTL_LEVEL_TRACE_3 ); + return; +} + +// Optimized function that computes the Frobenius norm using AVX2 intrinsics. +void bli_dznorm2fv_unb_var1_avx2 + ( + dim_t n, + dcomplex* x, inc_t incx, + double* norm, + cntx_t* cntx + ) +{ + AOCL_DTL_TRACE_ENTRY( AOCL_DTL_LEVEL_TRACE_3 ); + + double sumsq = 0; + dim_t i = 0; + dim_t n_remainder = 0; + dcomplex *x_buf = x; + + // Early return if n<=0 or incx=0 + if ( ( n <= 0) || ( incx == 0 ) ) + { + return; + } + + // Memory pool declarations for packing vector X. + // Initialize mem pool buffer to NULL and size to 0. + // "buf" and "size" fields are assigned once memory + // is allocated from the pool in bli_membrk_acquire_m(). + // This will ensure bli_mem_is_alloc() will be passed on + // an allocated memory if created or a NULL. + mem_t mem_bufX = {0}; + rntm_t rntm; + + // Packing for non-unit strided vector x. + if ( incx != 1 ) + { + // In order to get the buffer from pool via rntm access to memory broker + //is needed. Following are initializations for rntm. + bli_rntm_init_from_global( &rntm ); + bli_rntm_set_num_threads_only( 1, &rntm ); + bli_membrk_rntm_set_membrk( &rntm ); + + // Calculate the size required for "n" dcomplex elements in vector x. + size_t buffer_size = n * sizeof( dcomplex ); + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_dznorm2fv_unb_var1(): get mem pool block\n" ); + #endif + + // Acquire a Buffer(n*size(dcomplex)) from the memory broker + // and save the associated mem_t entry to mem_bufX. + bli_membrk_acquire_m + ( + &rntm, + buffer_size, + BLIS_BUFFER_FOR_B_PANEL, + &mem_bufX + ); + + // Continue packing X if buffer memory is allocated. + if ( ( bli_mem_is_alloc( &mem_bufX ) ) ) + { + x_buf = bli_mem_buffer( &mem_bufX ); + // Pack vector x with non-unit stride to a temp buffer x_buf with unit stride. 
+ for ( dim_t x_index = 0; x_index < n; x_index++ ) + { + if ( incx > 0 ) + { + *( x_buf + x_index ) = *( x + ( x_index * incx ) ); + } + else + { + *( x_buf + x_index ) = *( x + ( - ( n - x_index - 1 ) * incx ) ); + } + } + } + } + + dcomplex *xt = x_buf; + + // Compute the sum of squares on 3 accumulators to avoid overflow + // and underflow, depending on the vector element value. + // Accumulator for small values; using scaling to avoid underflow. + double sum_sml = 0; + // Accumulator for medium values; no scaling required. + double sum_med = 0; + // Accumulator for big values; using scaling to avoid overflow. + double sum_big = 0; + + // Constants chosen to minimize roundoff, according to Blue's algorithm. + const double thres_sml = pow( ( double )FLT_RADIX, ceil( ( DBL_MIN_EXP - 1 ) * 0.5 ) ); + const double thres_big = pow( ( double )FLT_RADIX, floor( ( DBL_MAX_EXP - 52) * 0.5 ) ); + const double scale_sml = pow( ( double )FLT_RADIX, - floor( ( DBL_MIN_EXP - 53 ) * 0.5 ) ); + const double scale_big = pow( ( double )FLT_RADIX, - ceil( ( DBL_MAX_EXP - 52 ) * 0.5 ) ); + + double scale; + double abs_chi; + bool isbig = false; + + if ( n > 2 ) + { + // Constants used for comparisons. + v4df_t temp, thres_sml_vec, thres_big_vec, zerov, ymm0, ymm1; + temp.v = _mm256_set1_pd( -0.0 ); + thres_sml_vec.v = _mm256_set1_pd( thres_sml ); + thres_big_vec.v = _mm256_set1_pd( thres_big ); + v4df_t x0v, x1v, mask_vec0, mask_vec1; + zerov.v = _mm256_setzero_pd(); + + // Partial sums used for scaling. + v4df_t sum_med_vec0, sum_big_vec0, sum_sml_vec0, sum_med_vec1, sum_big_vec1, sum_sml_vec1; + sum_med_vec0.v = _mm256_setzero_pd(); + sum_big_vec0.v = _mm256_setzero_pd(); + sum_sml_vec0.v = _mm256_setzero_pd(); + sum_med_vec1.v = _mm256_setzero_pd(); + sum_big_vec1.v = _mm256_setzero_pd(); + sum_sml_vec1.v = _mm256_setzero_pd(); + + for (; ( i + 4 ) <= n; i = i + 4) + { + x0v.v = _mm256_loadu_pd( xt ); + x1v.v = _mm256_loadu_pd( xt + 2 ); + + // Getting the abs of the vector elements. + x0v.v = _mm256_andnot_pd( temp.v, x0v.v ); + x1v.v = _mm256_andnot_pd( temp.v, x1v.v ); + + // Check if any of the values is a NaN and if so, return. + mask_vec0.v = _mm256_cmp_pd(x0v.v, x0v.v, _CMP_UNORD_Q); + mask_vec1.v = _mm256_cmp_pd(x1v.v, x1v.v, _CMP_UNORD_Q); + if ( bli_horizontal_or( mask_vec0.v ) ) + { + *norm = NAN; + return; + } + if ( bli_horizontal_or( mask_vec1.v ) ) + { + *norm = NAN; + return; + } + + // Mask vectors which indicate whether + // xi<=thres_sml or xi>=thres_big. + mask_vec0.v = CMP256( x0v.v, thres_sml_vec.v, thres_big_vec.v ); + mask_vec1.v = CMP256( x1v.v, thres_sml_vec.v, thres_big_vec.v ); + + if ( !bli_horizontal_or( mask_vec0.v ) ) + { + // Scaling is not necessary; only medium values. + sum_med_vec0.v = _mm256_fmadd_pd( x0v.v, x0v.v, sum_med_vec0.v ); + } + else + { + // Mask vector which indicate whether xi > thres_big. + mask_vec0.v = _mm256_cmp_pd( x0v.v, thres_big_vec.v, _CMP_GT_OQ ); + + if ( bli_horizontal_or( mask_vec0.v ) ) + { + isbig = true; + + // Fill sum_med vector without scaling. + ymm0.v = _mm256_blendv_pd( x0v.v, zerov.v, mask_vec0.v ); + sum_med_vec0.v = _mm256_fmadd_pd( ymm0.v, ymm0.v, sum_med_vec0.v ); + + // Fill sum_big vector using scaling. + temp.v = _mm256_set1_pd( scale_big ); + ymm0.v = _mm256_blendv_pd( zerov.v, temp.v, mask_vec0.v ); + ymm0.v = _mm256_mul_pd( x0v.v, ymm0.v ); + sum_big_vec0.v = _mm256_fmadd_pd( ymm0.v, ymm0.v, sum_big_vec0.v ); + temp.v = _mm256_set1_pd( -0.0 ); + } + else + { + // Mask vector which indicates whether xi > thres_small. 
+ mask_vec0.v = _mm256_cmp_pd( x0v.v, thres_sml_vec.v, _CMP_LT_OQ ); + // Fill sum_med vector without scaling. + ymm0.v = _mm256_blendv_pd( x0v.v, zerov.v, mask_vec0.v ); + sum_med_vec0.v = _mm256_fmadd_pd( ymm0.v, ymm0.v, sum_med_vec0.v ); + + // Accumulate small values only if there have not been any big values so far. + if ( !isbig ) + { + // Fill sum_sml vector using scaling. + temp.v = _mm256_set1_pd( scale_sml ); + ymm0.v = _mm256_blendv_pd( zerov.v, temp.v, mask_vec0.v ); + ymm0.v = _mm256_mul_pd( x0v.v, ymm0.v ); + sum_sml_vec0.v = _mm256_fmadd_pd( ymm0.v, ymm0.v, sum_sml_vec0.v ); + temp.v = _mm256_set1_pd( -0.0 ); + } + } + } + + if ( !bli_horizontal_or( mask_vec1.v ) ) + { + // Scaling is not necessary; only medium values. + sum_med_vec1.v = _mm256_fmadd_pd( x1v.v, x1v.v, sum_med_vec1.v ); + } + else + { + // Mask vector which indicate whether xi > thres_big. + mask_vec1.v = _mm256_cmp_pd( x1v.v, thres_big_vec.v, _CMP_GT_OQ ); + + if ( bli_horizontal_or( mask_vec1.v ) ) + { + isbig = true; + + // Fill sum_med vector without scaling. + ymm1.v = _mm256_blendv_pd( x1v.v, zerov.v, mask_vec1.v ); + sum_med_vec1.v = _mm256_fmadd_pd( ymm1.v, ymm1.v, sum_med_vec1.v ); + + // Fill sum_big vector using scaling. + temp.v = _mm256_set1_pd( scale_big ); + ymm1.v = _mm256_blendv_pd( zerov.v, temp.v, mask_vec1.v ); + ymm1.v = _mm256_mul_pd( x1v.v, ymm1.v ); + sum_big_vec1.v = _mm256_fmadd_pd( ymm1.v, ymm1.v, sum_big_vec1.v ); + temp.v = _mm256_set1_pd( -0.0 ); + } + else + { + // Mask vector which indicates whether xi > thres_small. + mask_vec1.v = _mm256_cmp_pd( x1v.v, thres_sml_vec.v, _CMP_LT_OQ ); + // Fill sum_med vector without scaling. + ymm1.v = _mm256_blendv_pd( x1v.v, zerov.v, mask_vec1.v ); + sum_med_vec1.v = _mm256_fmadd_pd( ymm1.v, ymm1.v, sum_med_vec1.v ); + + // Accumulate small values only if there have not been any big values so far. + if ( !isbig ) + { + // Fill sum_sml vector using scaling. + temp.v = _mm256_set1_pd( scale_sml ); + ymm1.v = _mm256_blendv_pd( zerov.v, temp.v, mask_vec1.v ); + ymm1.v = _mm256_mul_pd( x1v.v, ymm1.v ); + sum_sml_vec1.v = _mm256_fmadd_pd( ymm1.v, ymm1.v, sum_sml_vec1.v ); + temp.v = _mm256_set1_pd( -0.0 ); + } + } + } + + xt += 4; + } + + for ( ; ( i + 2 ) <= n; i = i + 2 ) + { + x0v.v = _mm256_loadu_pd( xt ); + + // Getting the abs of the vector elements. + x0v.v = _mm256_andnot_pd( temp.v, x0v.v ); + + // Check if any of the values is a NaN and if so, return. + mask_vec0.v = _mm256_cmp_pd(x0v.v, x0v.v, _CMP_UNORD_Q); + if ( bli_horizontal_or( mask_vec0.v ) ) + { + *norm = NAN; + return; + } + + // Mask vectors which indicate whether + // xi<=thres_sml or xi>=thres_big. + mask_vec0.v = CMP256( x0v.v, thres_sml_vec.v, thres_big_vec.v ); + + if ( !bli_horizontal_or( mask_vec0.v ) ) + { + // Scaling is not necessary; only medium values. + sum_med_vec0.v = _mm256_fmadd_pd( x0v.v, x0v.v, sum_med_vec0.v ); + } + else + { + // Mask vector which indicate whether xi > thres_big. + mask_vec0.v = _mm256_cmp_pd( x0v.v, thres_big_vec.v, _CMP_GT_OQ ); + + if ( bli_horizontal_or( mask_vec0.v ) ) + { + isbig = true; + + // Fill sum_med vector without scaling. + ymm0.v = _mm256_blendv_pd( x0v.v, zerov.v, mask_vec0.v ); + sum_med_vec0.v = _mm256_fmadd_pd( ymm0.v, ymm0.v, sum_med_vec0.v ); + + // Fill sum_big vector using scaling. 
+ temp.v = _mm256_set1_pd( scale_big ); + ymm0.v = _mm256_blendv_pd( zerov.v, temp.v, mask_vec0.v ); + ymm0.v = _mm256_mul_pd( x0v.v, ymm0.v ); + sum_big_vec0.v = _mm256_fmadd_pd( ymm0.v, ymm0.v, sum_big_vec0.v ); + temp.v = _mm256_set1_pd( -0.0 ); + } + else + { + // Mask vector which indicates whether xi > thres_small. + mask_vec0.v = _mm256_cmp_pd( x0v.v, thres_sml_vec.v, _CMP_LT_OQ ); + // Fill sum_med vector without scaling. + ymm0.v = _mm256_blendv_pd( x0v.v, zerov.v, mask_vec0.v ); + sum_med_vec0.v = _mm256_fmadd_pd( ymm0.v, ymm0.v, sum_med_vec0.v ); + + // Accumulate small values only if there have not been any big values so far. + if ( !isbig ) + { + // Fill sum_sml vector using scaling. + temp.v = _mm256_set1_pd( scale_sml ); + ymm0.v = _mm256_blendv_pd( zerov.v, temp.v, mask_vec0.v ); + ymm0.v = _mm256_mul_pd( x0v.v, ymm0.v ); + sum_sml_vec0.v = _mm256_fmadd_pd( ymm0.v, ymm0.v, sum_sml_vec0.v ); + temp.v = _mm256_set1_pd( -0.0 ); + } + } + } + xt += 2; + } + + sum_sml_vec0.v = _mm256_add_pd( sum_sml_vec0.v, sum_sml_vec1.v ); + sum_med_vec0.v = _mm256_add_pd( sum_med_vec0.v, sum_med_vec1.v ); + sum_big_vec0.v = _mm256_add_pd( sum_big_vec0.v, sum_big_vec1.v ); + + sum_sml += sum_sml_vec0.v[0] + sum_sml_vec0.v[1] + + sum_sml_vec0.v[2] + sum_sml_vec0.v[3]; + sum_med += sum_med_vec0.v[0] + sum_med_vec0.v[1] + + sum_med_vec0.v[2] + sum_med_vec0.v[3]; + sum_big += sum_big_vec0.v[0] + sum_big_vec0.v[1] + + sum_big_vec0.v[2] + sum_big_vec0.v[3]; + } + + n_remainder = n - i; + bool hasInf = false; + if ( ( n_remainder > 0 ) ) + { + // Put first the most likely to happen to avoid evaluations on if statements. + for (i = 0; i < n_remainder; i++) + { + // Get real and imaginary component of the vector element. + double chi_r, chi_i; + bli_zdgets(*xt, chi_r, chi_i); + + // Start with accumulating the real component of the vector element. + abs_chi = bli_fabs( chi_r ); + // If any of the elements is NaN, then return NaN as a result. + if ( bli_isnan( abs_chi ) ) + { + *norm = abs_chi; + return; + } + // Else, if any of the elements is an Inf, then return +Inf as a result. + if ( bli_isinf( abs_chi ) ) + { + *norm = abs_chi; + // Instead of returning immediately, use this flag + // to denote that there is an Inf element in the vector. + // That is used to avoid cases where there is a NaN which comes + // after an Inf. + hasInf = true; + } + // Most likely case: medium values, not over/under-flow. + if ( ( abs_chi <= thres_big ) && ( abs_chi >= thres_sml ) ) + { + sum_med += abs_chi * abs_chi; + } + // Case where there could be an overflow. Scaling is required. + else if ( abs_chi > thres_big ) + { + sum_big += ( abs_chi * scale_big ) * ( abs_chi * scale_big ); + isbig = true; + } + // Case where there could be an underflow. Scaling is required. + else if ( ( !isbig ) && ( abs_chi < thres_sml ) ) + { + sum_sml += ( abs_chi * scale_sml ) * ( abs_chi * scale_sml ); + } + + // Accumulate the imaginary component of the vector element. + abs_chi = bli_fabs( chi_i ); + // If any of the elements is NaN, then return NaN as a result. + if ( bli_isnan( abs_chi ) ) + { + *norm = abs_chi; + return; + } + // Else, if any of the elements is an Inf, then return +Inf as a result. + if ( bli_isinf( abs_chi ) ) + { + *norm = abs_chi; + // Instead of returning immediately, use this flag + // to denote that there is an Inf element in the vector. + // That is used to avoid cases where there is a NaN which comes + // after an Inf. + hasInf = true; + } + // Most likely case: medium values, not over/under-flow. 
+ if ( ( abs_chi <= thres_big ) && ( abs_chi >= thres_sml ) ) + { + sum_med += abs_chi * abs_chi; + } + // Case where there could be an overflow. Scaling is required. + else if ( abs_chi > thres_big ) + { + sum_big += ( abs_chi * scale_big ) * ( abs_chi * scale_big ); + isbig = true; + } + // Case where there could be an underflow. Scaling is required. + else if ( ( !isbig ) && ( abs_chi < thres_sml ) ) + { + sum_sml += ( abs_chi * scale_sml ) * ( abs_chi * scale_sml ); + } + + xt++; + } + } + + // Early return if there is an Inf. + if ( hasInf ) return; + + // Combine accumulators. + if ( isbig ) + { + // Combine sum_big and sum_med if sum_med > 0. + if ( sum_med > 0.0 ) + { + sum_big += ( sum_med * scale_big ) * scale_big; + } + scale = 1.0 / scale_big; + sumsq = sum_big; + } + + else if ( sum_sml > 0.0 ) + { + // Combine sum_med and sum_sml if sum_sml>0. + if ( sum_med > 0.0 ) + { + sum_med = sqrt( sum_med ); + sum_sml = sqrt( sum_sml ) / scale_sml; + double ymin, ymax; + if ( sum_sml > sum_med ) + { + ymin = sum_med; + ymax = sum_sml; + } + else + { + ymin = sum_sml; + ymax = sum_med; + } + scale = 1.0; + sumsq = ymax * ymax * ( 1.0 + ( ymin / ymax ) * ( ymin / ymax ) ); + } + else + { + scale = 1.0 / scale_sml; + sumsq = sum_sml; + } + } + else + { + // If all values are mid-range: + scale = 1.0; + sumsq = sum_med; + } + + *norm = scale * sqrt( sumsq ); + + if ( ( incx != 1 ) && bli_mem_is_alloc( &mem_bufX ) ) + { + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_dznorm2fv_unb_var1(): releasing mem pool block\n" ); + #endif + // Return the buffer to pool. + bli_membrk_release( &rntm , &mem_bufX ); + } + + AOCL_DTL_TRACE_EXIT( AOCL_DTL_LEVEL_TRACE_3 ); + return; } \ No newline at end of file diff --git a/kernels/zen/bli_kernels_zen.h b/kernels/zen/bli_kernels_zen.h index 210052d069..f16aa5cc98 100644 --- a/kernels/zen/bli_kernels_zen.h +++ b/kernels/zen/bli_kernels_zen.h @@ -379,10 +379,26 @@ bool bli_cntx_trsm_small_thresh_is_met_zen dim_t n ); -void bli_dnorm2fv_unb_var1_avx +void bli_dnorm2fv_unb_var1_avx2 ( dim_t n, double* x, inc_t incx, double* norm, cntx_t* cntx + ); + +void bli_dznorm2fv_unb_var1_avx2 + ( + dim_t n, + dcomplex* x, inc_t incx, + double* norm, + cntx_t* cntx + ); +void bli_zdscalv_zen_int10 + ( + conj_t conjalpha, + dim_t n, + double* restrict alpha, + dcomplex* restrict x, inc_t incx, + cntx_t* restrict cntx ); \ No newline at end of file From 0400fc067d19f8956ca14222668aa3eb2f982428 Mon Sep 17 00:00:00 2001 From: Edward Smyth Date: Mon, 3 Oct 2022 10:55:49 +0000 Subject: [PATCH 235/243] BLIS: Errors in Netlib LAPACK Tests Zen4 kernel bli_damaxv_zen_int_avx512 is causing incorrect results in the netlib LAPACK tests, specifically in: ./xlintstd < ../dtest.in > dtest.out in the TESTING/LIN directory. Given time constraints, i.e. the need to finalize code for AOCL 4.0 release, disable calls to AVX512 kernel (i.e. always use the AVX2 kernel) for now, and aim to correct bli_damaxv_zen_int_avx512 for AOCL 4.1. 
AMD-Internal: [CPUPL-2590] Change-Id: I2603dd97c3931acb9730563e8126b109ec2b2572 (cherry picked from commit 991f2ec0e8a2668cfe80feb33ec2bc5189c973b4) --- config/zen4/bli_cntx_init_zen4.c | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/config/zen4/bli_cntx_init_zen4.c b/config/zen4/bli_cntx_init_zen4.c index 1d9bf3cc05..ac9875abf6 100644 --- a/config/zen4/bli_cntx_init_zen4.c +++ b/config/zen4/bli_cntx_init_zen4.c @@ -137,7 +137,7 @@ void bli_cntx_init_zen4( cntx_t* cntx ) // amaxv BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int_avx512, - BLIS_AMAXV_KER, BLIS_DOUBLE, bli_damaxv_zen_int_avx512, + BLIS_AMAXV_KER, BLIS_DOUBLE, bli_damaxv_zen_int, // axpbyv BLIS_AXPBYV_KER, BLIS_FLOAT, bli_saxpbyv_zen_int10, @@ -355,4 +355,4 @@ void bli_zen4_restore_default_blkszs (cntx_t* cntx) BLIS_DF, &blkszs[ BLIS_DF ], BLIS_DF, cntx ); -} \ No newline at end of file +} From a865f33f0e5c2e3f06489be15d4a39f000eeeac7 Mon Sep 17 00:00:00 2001 From: Vignesh Balasubramanian Date: Thu, 6 Oct 2022 17:33:29 +0530 Subject: [PATCH 236/243] Disabling ZGEMM SUP for single thread -Fine tuned the thresholds for bli_zgemm_small() path for single thread scenarios. -Disabled the sup path for zgemm for single thread cases temporarily. It will be enabled after retuning the performance and updating the thresholds between other paths. -Currently bli_zgemm_small() and native paths handle zgemm inputs for single thread. AMD-Internal: [CPUPL-2495] Change-Id: If068becc7b9093c6d333003b8b72818acd88cab8 --- frame/compat/bla_gemm_amd.c | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/frame/compat/bla_gemm_amd.c b/frame/compat/bla_gemm_amd.c index c9a5c96039..a9478581ef 100644 --- a/frame/compat/bla_gemm_amd.c +++ b/frame/compat/bla_gemm_amd.c @@ -821,8 +821,7 @@ void zgemm_ #ifdef BLIS_ENABLE_SMALL_MATRIX - if (((nt == 0) && (((m0 <= 40) && (n0 <= 40)) || - ((m0 <= 128) && (n0 <= 128) && bli_is_notrans(blis_transb))) && (k0 <= 512)) || + if (((nt == 0) && ((m0 <= 512) && (n0 <= 512) && (k0 <= 512))) || ((nt == 1) && (((m0 <= 32) || (n0 <= 32) || (k0 <= 32)) && ((m0 + n0 + k0) <= 100)))) { err_t status = BLIS_NOT_YET_IMPLEMENTED; @@ -860,12 +859,16 @@ void zgemm_ } #endif - err_t status = bli_gemmsup(&alphao, &ao, &bo, &betao, &co, NULL, NULL); - if (status == BLIS_SUCCESS) + // disabling sup path for single thread in zgemm until further tuning. + if (nt == 1) { - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) - return; + err_t status = bli_gemmsup(&alphao, &ao, &bo, &betao, &co, NULL, NULL); + if (status == BLIS_SUCCESS) + { + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) + return; + } } // fall back on native path when zgemm is not handled in sup path. From be4699bb436926ad88caeab8a3802effc3c69da9 Mon Sep 17 00:00:00 2001 From: Harihara Sudhan S Date: Fri, 7 Oct 2022 07:04:44 +0000 Subject: [PATCH 237/243] Fixed bench accuracy issue in LPGEMM Description: - When the value of the result in s8 for u8s8s32 and u8s8s16 are close to 0. Values are getting ceiled to 1. - Used nearbyintf to round the downscaled values in bench reference. This fixed the result mismatch issue between the vectorized kernel implementation and reference implementation in bench accuracy test. 
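For illustration, the halfway-case difference between the two rounding routines can be reproduced with a small standalone C program (a sketch, not part of this patch): lroundf() rounds ties away from zero, whereas nearbyintf() follows the default round-to-nearest-even mode, which matches the _MM_FROUND_TO_NEAREST_INT behaviour used by the vectorized kernels.

    #include <math.h>
    #include <stdio.h>

    int main( void )
    {
        float v = 0.5f; /* a downscaled accumulator value that lands on a tie */

        /* lroundf: ties round away from zero, so 0.5f becomes 1. */
        printf( "lroundf(0.5f)    = %ld\n", lroundf( v ) );

        /* nearbyintf: default round-to-nearest-even, so 0.5f becomes 0,
           matching the rounding performed in the SIMD kernels. */
        printf( "nearbyintf(0.5f) = %.1f\n", nearbyintf( v ) );

        return 0;
    }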
AMD-Internal: [CPUPL-2617] Change-Id: I7d58ec0c902a7ad540f5fd55794769ec946afbc5 --- bench/bench_aocl_gemm/bench_lpgemm.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bench/bench_aocl_gemm/bench_lpgemm.c b/bench/bench_aocl_gemm/bench_lpgemm.c index df239fccb2..ff82ef9fc2 100644 --- a/bench/bench_aocl_gemm/bench_lpgemm.c +++ b/bench/bench_aocl_gemm/bench_lpgemm.c @@ -277,7 +277,7 @@ inline C_type mat_mul_accuracy_check_downscale_ ## BLAS_DOWNSCALE_SFX \ dim_t j \ )\ {\ - out_temp_accum = ( C_type )lroundf( ( SCALE_type )temp_accum * \ + out_temp_accum = ( C_type )nearbyintf( ( SCALE_type )temp_accum * \ ( *( ( SCALE_type* )post_op->sum.scale_factor + j ) ) ); \ return out_temp_accum; \ }\ From c08222cb57a5a6112f74f7cfcf1884459f71c168 Mon Sep 17 00:00:00 2001 From: Mangala V Date: Fri, 7 Oct 2022 07:29:09 -0400 Subject: [PATCH 238/243] Enabled MT path in DTRSM Small 1. Fixed accuracy issues in dtrsm small multi-thread. 2. Fixed out of bound memory accesses in the patch "Fixes to avoid Out of Bound Memory Access in TRSM small algorithm" 3. Re-enabled DTRSM small MT with above fixes. AMD-Internal: [CPUPL-2567] Change-Id: Ibf2949b25fde4007a92cc635526bed0e0d897800 --- frame/compat/bla_trsm_amd.c | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/frame/compat/bla_trsm_amd.c b/frame/compat/bla_trsm_amd.c index 13330a5d08..f23599258b 100644 --- a/frame/compat/bla_trsm_amd.c +++ b/frame/compat/bla_trsm_amd.c @@ -922,9 +922,41 @@ void dtrsm_ return; } } + // bli_trsm_small_mt is performing better than native multithread + // for certain sizes of m & n. +#ifdef BLIS_ENABLE_OPENMP + rntm_t rntm; + bli_rntm_init_from_global( &rntm ); + // Query the total number of threads from the rntm_t object. + dim_t n_threads = bli_rntm_num_threads( &rntm ); + if ( ( (n_threads > 1) && (m0 <= 1500) && (n0 <= 1500) ) || + ( (n_threads == 32) && (m0 <= 2300) && (n0 <= 2300) ) || + ( (n_threads == 16) && (m0 <= 3800) && (n0 <= 3800) ) || + ( (n_threads == 8) && (m0 <= 2800) && (n0 <= 2800) ) || + ( (n_threads == 4) && (m0 <= 2000) && (n0 <= 2000) ) || + ( (n_threads == 2) && (m0 <= 2000) && (n0 <= 2000) ) ) + { + err_t status; + status = bli_trsm_small_mt( + blis_side, + &alphao, + &ao, + &bo, + NULL, + NULL); + if ( status == BLIS_SUCCESS ) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + /* Finalize BLIS. */ + bli_finalize_auto(); + return; + } + } +#endif// BLIS_ENABLE_OPENMP } // bli_cpuid_is_avx_supported #endif// END of BLIS_ENABLE_SMALL_MATRIX_TRSM + bli_trsmnat ( blis_side, From ce89138cd950db38e005b321b68106007a8a71ee Mon Sep 17 00:00:00 2001 From: Nallani Bhaskar Date: Tue, 11 Oct 2022 13:20:50 +0530 Subject: [PATCH 239/243] Fixed compilation warnings Description: - Fixed few compilation warnings in dnrm2 avx2 implementation. Warnings occurred when dcomplex* type is used instead of double* while loading data using intrinsic. - Type casted the address to double* to remove the warning. 
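As a minimal sketch of why the cast is benign (assuming the usual dcomplex layout of two contiguous doubles, real part followed by imaginary part; the type and function names below are illustrative only), _mm256_loadu_pd() expects a double const* and the bytes loaded are identical with or without the cast:

    #include <immintrin.h>

    /* Stand-in for the dcomplex type used by the kernel (hypothetical name). */
    typedef struct { double real; double imag; } dcomplex_like;

    static inline __m256d load_two_dcomplex( const dcomplex_like* xt )
    {
        /* Passing xt directly triggers an incompatible-pointer-type warning,
           since _mm256_loadu_pd() takes double const*. The explicit cast
           silences the warning while still loading xt[0].real, xt[0].imag,
           xt[1].real and xt[1].imag. */
        return _mm256_loadu_pd( (const double*) xt );
    }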
Change-Id: Ic0fd08635f9368c0075e0ca0605d446e44e5d55d --- kernels/zen/1/bli_norm2_zen_int.c | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/kernels/zen/1/bli_norm2_zen_int.c b/kernels/zen/1/bli_norm2_zen_int.c index 278f752b63..1971b79433 100644 --- a/kernels/zen/1/bli_norm2_zen_int.c +++ b/kernels/zen/1/bli_norm2_zen_int.c @@ -604,8 +604,8 @@ void bli_dznorm2fv_unb_var1_avx2 for (; ( i + 4 ) <= n; i = i + 4) { - x0v.v = _mm256_loadu_pd( xt ); - x1v.v = _mm256_loadu_pd( xt + 2 ); + x0v.v = _mm256_loadu_pd( (double*) xt ); + x1v.v = _mm256_loadu_pd( (double*) (xt + 2) ); // Getting the abs of the vector elements. x0v.v = _mm256_andnot_pd( temp.v, x0v.v ); @@ -727,7 +727,7 @@ void bli_dznorm2fv_unb_var1_avx2 for ( ; ( i + 2 ) <= n; i = i + 2 ) { - x0v.v = _mm256_loadu_pd( xt ); + x0v.v = _mm256_loadu_pd( (double*) xt ); // Getting the abs of the vector elements. x0v.v = _mm256_andnot_pd( temp.v, x0v.v ); @@ -952,4 +952,4 @@ void bli_dznorm2fv_unb_var1_avx2 AOCL_DTL_TRACE_EXIT( AOCL_DTL_LEVEL_TRACE_3 ); return; -} \ No newline at end of file +} From 6dfc1cc3af648b396b740a1933cd52da98644e79 Mon Sep 17 00:00:00 2001 From: eashdash Date: Tue, 11 Oct 2022 12:05:32 +0000 Subject: [PATCH 240/243] Added clipping while downscaling for u8s8s32os8 and u8s8s16os8. Clipping is done during the downscaling of the accumulated result from s32 to s8 for u8s8s32os8 and from s16 to s8 for u8s8s16os8, to saturate the final output values between [-128,127] AMD-Internal: [LWPZENDNN-493] Change-Id: Ibc0ff2fd2a1662bd8a4835a828c5f45f41f7b5ce --- .../kernels/u8s8s16/lpgemm_s16_kern_macros.h | 100 ++++++++++++------ .../kernels/u8s8s32/lpgemm_s32_kern_macros.h | 34 ++++-- bench/bench_aocl_gemm/bench_lpgemm.c | 20 +++- 3 files changed, 109 insertions(+), 45 deletions(-) diff --git a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_s16_kern_macros.h b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_s16_kern_macros.h index 55a763727d..00583977f3 100644 --- a/addon/aocl_gemm/kernels/u8s8s16/lpgemm_s16_kern_macros.h +++ b/addon/aocl_gemm/kernels/u8s8s16/lpgemm_s16_kern_macros.h @@ -34,6 +34,8 @@ #ifndef LPGEMM_S16_KERN_MACROS_H #define LPGEMM_S16_KERN_MACROS_H +#define S8_MIN (-128) +#define S8_MAX (+127) #define RELU_SCALE_OP_S16_AVX2(reg) \ selector1 = _mm256_setzero_si256();\ @@ -70,11 +72,15 @@ res_1 = _mm256_mul_ps(temp_float[0], scale_1);\ res_2 = _mm256_mul_ps(temp_float[1], scale_2);\ \ - /* Round the resultant value to the nearest integer*/\ - res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));\ - res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));\ + /* Round the resultant value to the nearest float value and clip the values between [-128, 127] */\ + res_1 = _mm256_min_ps(_mm256_max_ps \ + (_mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), \ + _mm256_set1_ps(( float )S8_MIN)), _mm256_set1_ps(( float )S8_MAX));\ + res_2 = _mm256_min_ps(_mm256_max_ps \ + (_mm256_round_ps (res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), \ + _mm256_set1_ps(( float )S8_MIN)), _mm256_set1_ps(( float )S8_MAX));\ \ - /* Convert float32 scaled rounded value to int32 */\ + /* Convert the clipped float32 scaled rounded value to int32 */\ temp_32[0] = _mm256_cvtps_epi32(res_1);\ temp_32[1] = _mm256_cvtps_epi32(res_2);\ \ @@ -99,11 +105,15 @@ res_1 = _mm256_mul_ps(temp_float[0], scale_1);\ res_2 = _mm256_mul_ps(temp_float[1], scale_2);\ \ - /* Round the resultant value to the nearest integer*/\ - res_1 = _mm256_round_ps(res_1, 
(_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));\ - res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));\ + /* Round the resultant value to the nearest float value and clip the values between [-128, 127] */\ + res_1 = _mm256_min_ps(_mm256_max_ps \ + (_mm256_round_ps (res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), \ + _mm256_set1_ps(( float )S8_MIN)), _mm256_set1_ps(( float )S8_MAX));\ + res_2 = _mm256_min_ps(_mm256_max_ps \ + (_mm256_round_ps (res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), \ + _mm256_set1_ps(( float )S8_MIN)), _mm256_set1_ps(( float )S8_MAX));\ \ - /* Convert float32 scaled rounded value to int32 */\ + /* Convert the clipped float32 scaled rounded value to int32 */\ temp_32[0] = _mm256_cvtps_epi32(res_1);\ temp_32[1] = _mm256_cvtps_epi32(res_2);\ \ @@ -140,11 +150,15 @@ res_1 = _mm256_mul_ps(temp_float[0], scale_1);\ res_2 = _mm256_mul_ps(temp_float[1], scale_2);\ \ - /* Round the resultant value to the nearest integer*/\ - res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));\ - res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));\ + /* Round the resultant value to the nearest float value and clip the values between [-128, 127] */\ + res_1 = _mm256_min_ps(_mm256_max_ps \ + (_mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), \ + _mm256_set1_ps(( float )S8_MIN)), _mm256_set1_ps(( float )S8_MAX));\ + res_2 = _mm256_min_ps(_mm256_max_ps \ + (_mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), \ + _mm256_set1_ps(( float )S8_MIN)), _mm256_set1_ps(( float )S8_MAX));\ \ - /* Convert float32 scaled rounded value to int32 */\ + /* Convert the clipped float32 scaled rounded value to int32 */\ temp_32[0] = _mm256_cvtps_epi32(res_1);\ temp_32[1] = _mm256_cvtps_epi32(res_2);\ \ @@ -169,11 +183,15 @@ res_1 = _mm256_mul_ps(temp_float[0], scale_1);\ res_2 = _mm256_mul_ps(temp_float[1], scale_2);\ \ - /* Round the resultant value to the nearest integer*/\ - res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));\ - res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));\ + /* Round the resultant value to the nearest float value and clip the values between [-128, 127] */\ + res_1 = _mm256_min_ps(_mm256_max_ps \ + (_mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), \ + _mm256_set1_ps(( float )S8_MIN)), _mm256_set1_ps(( float )S8_MAX));\ + res_2 = _mm256_min_ps(_mm256_max_ps \ + (_mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), \ + _mm256_set1_ps(( float )S8_MIN)), _mm256_set1_ps(( float )S8_MAX));\ \ - /* Convert float32 scaled rounded value to int32 */\ + /* Convert the clipped float32 scaled rounded value to int32 */\ temp_32[0] = _mm256_cvtps_epi32(res_1);\ temp_32[1] = _mm256_cvtps_epi32(res_2);\ \ @@ -216,11 +234,15 @@ res_1 = _mm256_mul_ps(temp_float[0], scale_1);\ res_2 = _mm256_mul_ps(temp_float[1], scale_2);\ \ - /* Round the resultant value to the nearest integer*/\ - res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));\ - res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));\ + /* Round the resultant value to the nearest float value and clip the values between [-128, 127] */\ + res_1 = _mm256_min_ps(_mm256_max_ps \ + (_mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), \ + _mm256_set1_ps (( float )S8_MIN)), _mm256_set1_ps (( float )S8_MAX));\ + res_2 = 
_mm256_min_ps(_mm256_max_ps \ + (_mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), \ + _mm256_set1_ps (( float )S8_MIN)), _mm256_set1_ps (( float )S8_MAX));\ \ - /* Convert float32 scaled rounded value to int32 */\ + /* Convert the clipped float32 scaled rounded value to int32 */\ temp_32[0] = _mm256_cvtps_epi32(res_1);\ temp_32[1] = _mm256_cvtps_epi32(res_2);\ \ @@ -245,11 +267,15 @@ res_1 = _mm256_mul_ps(temp_float[0], scale_1);\ res_2 = _mm256_mul_ps(temp_float[1], scale_2);\ \ - /* Round the resultant value to the nearest integer*/\ - res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));\ - res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));\ + /* Round the resultant value to the nearest float value and clip the values between [-128, 127] */\ + res_1 = _mm256_min_ps(_mm256_max_ps \ + (_mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), \ + _mm256_set1_ps (( float )S8_MIN)), _mm256_set1_ps (( float )S8_MAX));\ + res_2 = _mm256_min_ps(_mm256_max_ps \ + (_mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), \ + _mm256_set1_ps (( float )S8_MIN)), _mm256_set1_ps (( float )S8_MAX));\ \ - /* Convert float32 scaled rounded value to int32 */\ + /* Convert the clipped float32 scaled rounded value to int32 */\ temp_32[0] = _mm256_cvtps_epi32(res_1);\ temp_32[1] = _mm256_cvtps_epi32(res_2);\ \ @@ -297,11 +323,15 @@ res_1 = _mm256_mul_ps(temp_float[0], scale_1);\ res_2 = _mm256_mul_ps(temp_float[1], scale_2);\ \ - /* Round the resultant value to the nearest integer*/\ - res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));\ - res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));\ + /* Round the resultant value to the nearest float value and clip the values between [-128, 127] */\ + res_1 = _mm256_min_ps(_mm256_max_ps \ + (_mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), \ + _mm256_set1_ps (( float )S8_MIN)), _mm256_set1_ps (( float )S8_MAX));\ + res_2 = _mm256_min_ps(_mm256_max_ps \ + (_mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), \ + _mm256_set1_ps (( float )S8_MIN)), _mm256_set1_ps (( float )S8_MAX));\ \ - /* Convert float32 scaled rounded value to int32 */\ + /* Convert the clipped float32 scaled rounded value to int32 */\ temp_32[0] = _mm256_cvtps_epi32(res_1);\ temp_32[1] = _mm256_cvtps_epi32(res_2);\ \ @@ -340,11 +370,15 @@ res_1 = _mm256_mul_ps(temp_float[0], scale_1);\ res_2 = _mm256_mul_ps(temp_float[1], scale_2);\ \ - /* Round the resultant value to the nearest integer*/\ - res_1 = _mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));\ - res_2 = _mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));\ + /* Round the resultant value to the nearest float value and clip the values between [-128, 127] */\ + res_1 = _mm256_min_ps(_mm256_max_ps \ + (_mm256_round_ps(res_1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), \ + _mm256_set1_ps (( float )S8_MIN)), _mm256_set1_ps (( float )S8_MAX));\ + res_2 = _mm256_min_ps(_mm256_max_ps \ + (_mm256_round_ps(res_2, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), \ + _mm256_set1_ps (( float )S8_MIN)), _mm256_set1_ps (( float )S8_MAX));\ \ - /* Convert float32 scaled rounded value to int32 */\ + /* Convert the clipped float32 scaled rounded value to int32 */\ temp_32[0] = _mm256_cvtps_epi32(res_1);\ temp_32[1] = _mm256_cvtps_epi32(res_2);\ \ @@ -367,4 +401,4 @@ ,store_buf, ( n0_rem * sizeof( 
int8_t ) ) ); \ \ -#endif //LPGEMM_S16_KERN_MACROS_H +#endif //LPGEMM_S16_KERN_MACROS_H \ No newline at end of file diff --git a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_s32_kern_macros.h b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_s32_kern_macros.h index d095ef683b..bc3546736c 100644 --- a/addon/aocl_gemm/kernels/u8s8s32/lpgemm_s32_kern_macros.h +++ b/addon/aocl_gemm/kernels/u8s8s32/lpgemm_s32_kern_macros.h @@ -34,6 +34,8 @@ #ifndef LPGEMM_S32_KERN_MACROS_H #define LPGEMM_S32_KERN_MACROS_H +#define S8_MIN (-128) +#define S8_MAX (+127) #define RELU_SCALE_OP_S32_AVX512(reg) \ /* Generate indenx of elements <= 0.*/ \ @@ -51,11 +53,19 @@ ( \ _mm512_cvtps_epi32 \ ( \ - _mm512_mul_round_ps \ + _mm512_min_ps \ ( \ - _mm512_cvtepi32_ps( reg ), \ - ( __m512 )selector, \ - ( _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC ) \ + _mm512_max_ps \ + ( \ + _mm512_mul_round_ps \ + ( \ + _mm512_cvtepi32_ps( reg ), \ + ( __m512 )selector, \ + ( _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC ) \ + ) \ + , _mm512_set1_ps (( float )S8_MIN) \ + ) \ + , _mm512_set1_ps (( float )S8_MAX) \ ) \ ) \ ) \ @@ -69,11 +79,19 @@ ( \ _mm512_cvtps_epi32 \ ( \ - _mm512_mul_round_ps \ + _mm512_min_ps \ ( \ - _mm512_cvtepi32_ps( reg ), \ - ( __m512 )selector, \ - ( _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC ) \ + _mm512_max_ps \ + ( \ + _mm512_mul_round_ps \ + ( \ + _mm512_cvtepi32_ps( reg ), \ + ( __m512 )selector, \ + ( _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC ) \ + ) \ + , _mm512_set1_ps (( float )S8_MIN) \ + ) \ + , _mm512_set1_ps (( float )S8_MAX) \ ) \ ) \ ) \ diff --git a/bench/bench_aocl_gemm/bench_lpgemm.c b/bench/bench_aocl_gemm/bench_lpgemm.c index ff82ef9fc2..92b7a7a1a6 100644 --- a/bench/bench_aocl_gemm/bench_lpgemm.c +++ b/bench/bench_aocl_gemm/bench_lpgemm.c @@ -8,6 +8,9 @@ #include "blis.h" +#define S8_MIN (-128) +#define S8_MAX (+127) + // Mode can be one of the follwoing: // 1. p - performance, used for benchmarks. // 2. a - accuracy, used to test accuracy/correctness. @@ -267,6 +270,15 @@ GEN_MAT_MUL_BENCH_DRV_FUNC(bfloat16,bfloat16,float,float,bf16bf16f32of32) GEN_MAT_MUL_BENCH_DRV_FUNC(bfloat16,bfloat16,bfloat16,float,bf16bf16f32obf16) GEN_MAT_MUL_BENCH_DRV_FUNC(float,float,float,float,f32f32f32of32) +int max (int a, int b) +{ + return ( a > b ? a : b ); +} + +int min (int a, int b) +{ + return ( a < b ? a : b ); +} #define GEN_MAT_MUL_ACC_CHK_DOWNSCALE(C_type,ACCUM_type,SCALE_type,BLAS_DOWNSCALE_SFX) \ inline C_type mat_mul_accuracy_check_downscale_ ## BLAS_DOWNSCALE_SFX \ @@ -277,8 +289,8 @@ inline C_type mat_mul_accuracy_check_downscale_ ## BLAS_DOWNSCALE_SFX \ dim_t j \ )\ {\ - out_temp_accum = ( C_type )nearbyintf( ( SCALE_type )temp_accum * \ - ( *( ( SCALE_type* )post_op->sum.scale_factor + j ) ) ); \ + out_temp_accum = ( C_type ) min ( max ( nearbyintf( ( SCALE_type )temp_accum * \ + ( *( ( SCALE_type* )post_op->sum.scale_factor + j ) ) ), S8_MIN ), S8_MAX ) ; \ return out_temp_accum; \ }\ @@ -1118,8 +1130,8 @@ int main( int argc, char** argv ) &stride_a, &stride_b, &stride_c ) == 9 ) { stor_order = ( ( stor_order == 'r' ) || ( stor_order == 'R' ) || - ( stor_order == 'c' ) || ( stor_order == 'C' ) ) ? - stor_order : 'r'; + ( stor_order == 'c' ) || ( stor_order == 'C' ) ) ? 
+ stor_order : 'r'; if ( ( op_type_char == 'i' ) || ( op_type_char == 'I' ) ) { From c3b08b6ae16e5d3521279e651bd3ccd3ddf18f85 Mon Sep 17 00:00:00 2001 From: Chandrashekara K R Date: Thu, 13 Oct 2022 12:22:16 +0530 Subject: [PATCH 241/243] Updated the version string from 3.2.1 to 4.0.0 AMD-Internal: [CPUPL-2322] Change-Id: I74380b6b5ba4297cf1a846de0da4cfd700c2732f --- so_version | 4 ++-- version | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/so_version b/so_version index 63b08bc9c2..436b8f7fa7 100644 --- a/so_version +++ b/so_version @@ -1,2 +1,2 @@ -3 -2.1 +4 +0.0 diff --git a/version b/version index e4604e3afd..fcdb2e109f 100644 --- a/version +++ b/version @@ -1 +1 @@ -3.2.1 +4.0.0 From 62a23deb42f02d483dd0db2b34e23b29b17a8984 Mon Sep 17 00:00:00 2001 From: Harihara Sudhan S Date: Thu, 13 Oct 2022 06:35:47 +0000 Subject: [PATCH 242/243] Copyright modification - Added copyright information to modified/newly created files missing them Change-Id: I1b951512b4ae60bc2120f70f1512790e958c2428 --- bench/bench_aocl_gemm/data_gen_lpgemm.py | 33 +++++++++++++++++++ build/bli_win_config.h.in | 2 +- config/CMakeLists.txt | 2 +- config/zen4/CMakeLists.txt | 2 +- frame/2/her/CMakeLists.txt | 2 +- frame/3/bli_l3_packm.c | 2 +- frame/3/bli_l3_packm.h | 2 +- frame/3/bli_l3_sup_int.c | 2 +- frame/3/bli_l3_sup_packm_a.c | 2 +- frame/3/bli_l3_sup_packm_b.c | 2 +- frame/3/gemm/bli_gemm_md.c | 2 +- frame/base/bli_cntx.h | 2 +- frame/base/bli_pool.c | 2 +- frame/include/bli_genarray_macro_defs.h | 2 +- frame/include/bli_obj_macro_defs.h | 2 +- frame/include/blis.h | 2 +- frame/util/bli_util_api_wrap.h | 2 +- frame/util/bli_util_unb_var1.c | 2 +- .../haswell/1m/bli_packm_haswell_asm_c3xk.c | 2 +- .../haswell/1m/bli_packm_haswell_asm_c8xk.c | 2 +- .../haswell/1m/bli_packm_haswell_asm_d8xk.c | 2 +- kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c | 2 +- .../3/sup/bli_gemmsup_rv_haswell_asm_s6x16m.c | 2 +- .../d6x8/bli_gemmsup_rd_haswell_asm_dMx1.c | 2 +- .../d6x8/bli_gemmsup_rd_haswell_asm_dMx4.c | 2 +- .../s6x16/bli_gemmsup_rv_haswell_asm_sMx2.c | 2 +- kernels/zen/1/CMakeLists.txt | 2 +- kernels/zen/2/bli_gemv_zen_int_4.c | 2 +- kernels/zen/util/bli_thresh_funcs_zen.c | 2 +- test/CMakeLists.txt | 2 +- testsuite/CMakeLists.txt | 2 +- testsuite/src/test_gemm.c | 2 +- testsuite/src/test_libblis.c | 2 +- 33 files changed, 65 insertions(+), 32 deletions(-) diff --git a/bench/bench_aocl_gemm/data_gen_lpgemm.py b/bench/bench_aocl_gemm/data_gen_lpgemm.py index f9c09689f9..3bc3a24421 100644 --- a/bench/bench_aocl_gemm/data_gen_lpgemm.py +++ b/bench/bench_aocl_gemm/data_gen_lpgemm.py @@ -1,3 +1,36 @@ +# +# +# BLIS +# An object-based framework for developing high-performance BLAS-like +# libraries. +# +# Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# - Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# - Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# - Neither the name(s) of the copyright holder(s) nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# + # Initializing global mnk_array.This array will be used to store all mnk values mnk_array = [] diff --git a/build/bli_win_config.h.in b/build/bli_win_config.h.in index 3903763873..24e1fc3d59 100644 --- a/build/bli_win_config.h.in +++ b/build/bli_win_config.h.in @@ -1,5 +1,5 @@ /* - * Copyright (C) 2020-2021, Advanced Micro Devices, Inc. All Rights Reserved + * Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All Rights Reserved */ #ifndef BLIS_CONFIG_H diff --git a/config/CMakeLists.txt b/config/CMakeLists.txt index dd8305c371..7429ff42ee 100644 --- a/config/CMakeLists.txt +++ b/config/CMakeLists.txt @@ -1,4 +1,4 @@ -##Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. ## +##Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. ## if(${TARGET_ARCH} STREQUAL zen4) message("The configuration is : ${TARGET_ARCH}") diff --git a/config/zen4/CMakeLists.txt b/config/zen4/CMakeLists.txt index 1c18083516..ea166b00c7 100644 --- a/config/zen4/CMakeLists.txt +++ b/config/zen4/CMakeLists.txt @@ -1,4 +1,4 @@ -##Copyright (C) 2021, Advanced Micro Devices, Inc ## +##Copyright (C) 2022, Advanced Micro Devices, Inc ## target_sources("${PROJECT_NAME}" PRIVATE diff --git a/frame/2/her/CMakeLists.txt b/frame/2/her/CMakeLists.txt index 0bcab8ff61..b97ee3874b 100644 --- a/frame/2/her/CMakeLists.txt +++ b/frame/2/her/CMakeLists.txt @@ -1,4 +1,4 @@ -##Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.## +##Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.## target_sources("${PROJECT_NAME}" PRIVATE diff --git a/frame/3/bli_l3_packm.c b/frame/3/bli_l3_packm.c index 1492313309..1134bdc1fd 100644 --- a/frame/3/bli_l3_packm.c +++ b/frame/3/bli_l3_packm.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/frame/3/bli_l3_packm.h b/frame/3/bli_l3_packm.h index 696dabf593..50a6e2f9c1 100644 --- a/frame/3/bli_l3_packm.h +++ b/frame/3/bli_l3_packm.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/frame/3/bli_l3_sup_int.c b/frame/3/bli_l3_sup_int.c index 909f480599..e215a7d825 100644 --- a/frame/3/bli_l3_sup_int.c +++ b/frame/3/bli_l3_sup_int.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2019-21, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019-22, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/frame/3/bli_l3_sup_packm_a.c b/frame/3/bli_l3_sup_packm_a.c index 4e5f0e4444..196dfae0b5 100644 --- a/frame/3/bli_l3_sup_packm_a.c +++ b/frame/3/bli_l3_sup_packm_a.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2022, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/frame/3/bli_l3_sup_packm_b.c b/frame/3/bli_l3_sup_packm_b.c index 7d7c8815ab..b733122825 100644 --- a/frame/3/bli_l3_sup_packm_b.c +++ b/frame/3/bli_l3_sup_packm_b.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2022, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/frame/3/gemm/bli_gemm_md.c b/frame/3/gemm/bli_gemm_md.c index d7e402b7c7..68298c71ca 100644 --- a/frame/3/gemm/bli_gemm_md.c +++ b/frame/3/gemm/bli_gemm_md.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2017 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2017 - 2022, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/frame/base/bli_cntx.h b/frame/base/bli_cntx.h index fcef4738fd..8dab2a5a19 100644 --- a/frame/base/bli_cntx.h +++ b/frame/base/bli_cntx.h @@ -6,7 +6,7 @@ Copyright (C) 2014, The University of Texas at Austin Copyright (C) 2016, Hewlett Packard Enterprise Development LP - Copyright (C) 2020 - 21, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 22, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/frame/base/bli_pool.c b/frame/base/bli_pool.c index 106fbfa10b..8e11b451f6 100644 --- a/frame/base/bli_pool.c +++ b/frame/base/bli_pool.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/frame/include/bli_genarray_macro_defs.h b/frame/include/bli_genarray_macro_defs.h index a63af52a76..c6ba2d7fbb 100644 --- a/frame/include/bli_genarray_macro_defs.h +++ b/frame/include/bli_genarray_macro_defs.h @@ -5,7 +5,7 @@ libraries. 
Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020, Advanced Micro Devices, Inc. + Copyright (C) 2022, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/frame/include/bli_obj_macro_defs.h b/frame/include/bli_obj_macro_defs.h index 855384425e..1499dc4182 100644 --- a/frame/include/bli_obj_macro_defs.h +++ b/frame/include/bli_obj_macro_defs.h @@ -6,7 +6,7 @@ Copyright (C) 2014, The University of Texas at Austin Copyright (C) 2016, Hewlett Packard Enterprise Development LP - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2022, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/frame/include/blis.h b/frame/include/blis.h index b0b0ba5476..251503de27 100644 --- a/frame/include/blis.h +++ b/frame/include/blis.h @@ -6,7 +6,7 @@ Copyright (C) 2014, The University of Texas at Austin Copyright (C) 2016, Hewlett Packard Enterprise Development LP - Copyright (C) 2018 - 2020, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/frame/util/bli_util_api_wrap.h b/frame/util/bli_util_api_wrap.h index 78f088e28e..86471c76f6 100644 --- a/frame/util/bli_util_api_wrap.h +++ b/frame/util/bli_util_api_wrap.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/frame/util/bli_util_unb_var1.c b/frame/util/bli_util_unb_var1.c index 48abe57acc..78c4c9198d 100644 --- a/frame/util/bli_util_unb_var1.c +++ b/frame/util/bli_util_unb_var1.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/haswell/1m/bli_packm_haswell_asm_c3xk.c b/kernels/haswell/1m/bli_packm_haswell_asm_c3xk.c index 6b7321859f..ab42e06aa9 100644 --- a/kernels/haswell/1m/bli_packm_haswell_asm_c3xk.c +++ b/kernels/haswell/1m/bli_packm_haswell_asm_c3xk.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019 - 2021, Advanced Micro Devices, Inc.All rights reserved. + Copyright (C) 2019 - 2022, Advanced Micro Devices, Inc.All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/haswell/1m/bli_packm_haswell_asm_c8xk.c b/kernels/haswell/1m/bli_packm_haswell_asm_c8xk.c index 6184447d58..a101e66d18 100644 --- a/kernels/haswell/1m/bli_packm_haswell_asm_c8xk.c +++ b/kernels/haswell/1m/bli_packm_haswell_asm_c8xk.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019 - 2021, Advanced Micro Devices, Inc. 
+ Copyright (C) 2019 - 2022, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/haswell/1m/bli_packm_haswell_asm_d8xk.c b/kernels/haswell/1m/bli_packm_haswell_asm_d8xk.c index 3b03d38fb7..0cfa2e8d68 100644 --- a/kernels/haswell/1m/bli_packm_haswell_asm_d8xk.c +++ b/kernels/haswell/1m/bli_packm_haswell_asm_d8xk.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019 - 2020, Advanced Micro Devices, Inc. + Copyright (C) 2019 - 2022, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c b/kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c index 5187d0bcb0..79625519c5 100644 --- a/kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c +++ b/kernels/haswell/3/bli_gemm_haswell_asm_d6x8.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2021, Advanced Micro Devices, Inc.All rights reserved. + Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc.All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_s6x16m.c b/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_s6x16m.c index 877e636b80..c299047ff9 100644 --- a/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_s6x16m.c +++ b/kernels/haswell/3/sup/bli_gemmsup_rv_haswell_asm_s6x16m.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2022, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx1.c b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx1.c index 457ef9f22d..8d3900f2e8 100644 --- a/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx1.c +++ b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx1.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2022, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx4.c b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx4.c index 41f73bc9ec..f19b703b41 100644 --- a/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx4.c +++ b/kernels/haswell/3/sup/d6x8/bli_gemmsup_rd_haswell_asm_dMx4.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2022, Advanced Micro Devices, Inc. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/haswell/3/sup/s6x16/bli_gemmsup_rv_haswell_asm_sMx2.c b/kernels/haswell/3/sup/s6x16/bli_gemmsup_rv_haswell_asm_sMx2.c index 3cbb69a50f..3d90e6e4f3 100644 --- a/kernels/haswell/3/sup/s6x16/bli_gemmsup_rv_haswell_asm_sMx2.c +++ b/kernels/haswell/3/sup/s6x16/bli_gemmsup_rv_haswell_asm_sMx2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2022, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/zen/1/CMakeLists.txt b/kernels/zen/1/CMakeLists.txt index 434be490d5..dbdd1533e2 100644 --- a/kernels/zen/1/CMakeLists.txt +++ b/kernels/zen/1/CMakeLists.txt @@ -1,4 +1,4 @@ -##Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.## +##Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.## target_sources("${PROJECT_NAME}" PRIVATE diff --git a/kernels/zen/2/bli_gemv_zen_int_4.c b/kernels/zen/2/bli_gemv_zen_int_4.c index 74904605ee..a4bdfb4499 100644 --- a/kernels/zen/2/bli_gemv_zen_int_4.c +++ b/kernels/zen/2/bli_gemv_zen_int_4.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/kernels/zen/util/bli_thresh_funcs_zen.c b/kernels/zen/util/bli_thresh_funcs_zen.c index 2786f00e43..82e4936fe1 100644 --- a/kernels/zen/util/bli_thresh_funcs_zen.c +++ b/kernels/zen/util/bli_thresh_funcs_zen.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index fe8f7bac98..77b746ba94 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -1,4 +1,4 @@ -##Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved.## +##Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.## add_definitions(-DBLAS="AOCL") diff --git a/testsuite/CMakeLists.txt b/testsuite/CMakeLists.txt index 613f9e3861..b997b8a8d9 100644 --- a/testsuite/CMakeLists.txt +++ b/testsuite/CMakeLists.txt @@ -1,4 +1,4 @@ -##Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved.## +##Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.## include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src) diff --git a/testsuite/src/test_gemm.c b/testsuite/src/test_gemm.c index c49f0a3287..0fbf54df36 100644 --- a/testsuite/src/test_gemm.c +++ b/testsuite/src/test_gemm.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/testsuite/src/test_libblis.c b/testsuite/src/test_libblis.c index 925d2ece1b..b904094267 100644 --- a/testsuite/src/test_libblis.c +++ b/testsuite/src/test_libblis.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2020, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are From ea4acd2661a6be11737e06a71aa4ed4b04f15899 Mon Sep 17 00:00:00 2001 From: Sireesha Sanga Date: Thu, 20 Oct 2022 23:02:41 -0400 Subject: [PATCH 243/243] Revert "Enabled MT path in DTRSM Small" This reverts commit c08222cb57a5a6112f74f7cfcf1884459f71c168. Reason for revert: Functionality Issues Change-Id: I26573cb47e31deda21e0d51851075a5cb50985ca --- frame/compat/bla_trsm_amd.c | 32 -------------------------------- 1 file changed, 32 deletions(-) diff --git a/frame/compat/bla_trsm_amd.c b/frame/compat/bla_trsm_amd.c index f23599258b..13330a5d08 100644 --- a/frame/compat/bla_trsm_amd.c +++ b/frame/compat/bla_trsm_amd.c @@ -922,41 +922,9 @@ void dtrsm_ return; } } - // bli_trsm_small_mt is performing better than native multithread - // for certain sizes of m & n. -#ifdef BLIS_ENABLE_OPENMP - rntm_t rntm; - bli_rntm_init_from_global( &rntm ); - // Query the total number of threads from the rntm_t object. - dim_t n_threads = bli_rntm_num_threads( &rntm ); - if ( ( (n_threads > 1) && (m0 <= 1500) && (n0 <= 1500) ) || - ( (n_threads == 32) && (m0 <= 2300) && (n0 <= 2300) ) || - ( (n_threads == 16) && (m0 <= 3800) && (n0 <= 3800) ) || - ( (n_threads == 8) && (m0 <= 2800) && (n0 <= 2800) ) || - ( (n_threads == 4) && (m0 <= 2000) && (n0 <= 2000) ) || - ( (n_threads == 2) && (m0 <= 2000) && (n0 <= 2000) ) ) - { - err_t status; - status = bli_trsm_small_mt( - blis_side, - &alphao, - &ao, - &bo, - NULL, - NULL); - if ( status == BLIS_SUCCESS ) - { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - /* Finalize BLIS. */ - bli_finalize_auto(); - return; - } - } -#endif// BLIS_ENABLE_OPENMP } // bli_cpuid_is_avx_supported #endif// END of BLIS_ENABLE_SMALL_MATRIX_TRSM - bli_trsmnat ( blis_side,